وبلاگ رسانگار
با ما حرفه ای باشید

سرور مجازی NVMe

روش ذخیره و بارگذاری مدل های XGBoost

0 11
زمان لازم برای مطالعه: 2 دقیقه


مدل ها اغلب برای به کارگیری در تولید و ارائه پیش بینی های معنادار برای ورودی های جدید آموزش داده نمی شوند. برای انتقال آنها به خارج از محیط آموزشی خود – باید یک مدل آموزش دیده را ذخیره کنید و آن را در مدل دیگری بارگذاری کنید.

XGBoost یک کتابخانه عالی، منعطف و فوق العاده سریع با عملکرد فوق العاده است و پرچمدار آن XGBRegressor و XGBClassifier معجزه کند

بیایید یک پسرفتگر ساده تربیت کنیم روی مجموعه داده اسباب بازی:

import xgboost as xgb

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

X, y = datasets.load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)

scaler = MinMaxScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)


xbg_reg = xgb.XGBRegressor().fit(X_train_scaled, y_train)

حالا بیایید نگاهی به روش ذخیره و بارگذاری مدل ها بیندازیم:

model.save_model() و model.load_model()

به طور رسمی توصیه می شود از آن استفاده کنید save_model() و load_model() توابع ذخیره و بارگذاری مدل ها

توجه داشته باشید: dump_model() استفاده میشه برای dump تنظیمات برای تفسیر توانایی و تجسم، نه برای ذخیره یک حالت آموزش دیده.

هر دو روش نامیده می شوند روی آ Booster نمونه، مثال:


xbg_reg.save_model("model.json")

xbg_reg.save_model("model.txt")


xbg_reg = xgb.Booster()

xbg_reg.load_model("model.json")
preds = xbg_reg.predict(xgb.DMatrix(X_test_scaled))
print(preds(:10)) 

xbg_reg.load_model("model.txt")
preds = xbg_reg.predict(xgb.DMatrix(X_test_scaled))
print(preds(:10)) 

می توانید به طور متناوب مشخص کنید که کدام تقویت کننده پر شده است، در این صورت می توانید آرایه های NumPy را به جای آرایه های پیچیده شده به عنوان تغذیه کنید. DMatrix() ماتریس ها:

xbg_reg.save_model("model.json")
xbg_reg.save_model("model.txt")

xbg_reg = xgb.XGBRegressor()

xbg_reg.load_model("model.json")
preds = xbg_reg.predict(X_test_scaled)
print(preds(:10)) 

xbg_reg.load_model("model.txt")
preds = xbg_reg.predict(X_test_scaled)
print(preds(:10)) 

توجه داشته باشید: این رویکرد سازگاری را تضمین می کند. استفاده از کتابخانه های خارجی مانند joblib و pickle اگر نسخه جدید XGBoost سعی کند یک پیکربندی قدیمی را بارگیری کند که قبل از نسخه جدید سریال شده است، ممکن است منجر به مشکلات سازگاری شود. می‌توانید با پین کردن نسخه‌ها یا برگشتن و سریال‌سازی مجدد این موضوع را تنظیم کنید، اما با استفاده از API رسمی می‌توانید به‌کلی از آن اجتناب کنید.

Joblib

Joblib یک کتابخانه سریال سازی است، با یک API بسیار ساده که به شما امکان می دهد مدل ها را در قالب های مختلف ذخیره کنید:

import joblib

joblib.dump(xbg_reg, "xgb_reg.sav")
xgb_reg = joblib.load("xgb_reg.sav")

preds = xgb_reg.predict(X_test_scaled)
print(preds(:10)) 

ترشی

Pickle یک کتابخانه سریال سازی دیگر است که به شما امکان می دهد مدل ها را به راحتی سریال سازی کنید، اما با فایل ها کمی بیشتر به صورت دستی کار می کند:

import pickle

pickle.dump(xbg_reg, open("xgb_reg.sav", "wb"))
xgb_reg = pickle.load(open("xgb_reg.sav", "rb"))

preds = xgb_reg.predict(X_test_scaled)
print(preds(:10)) 

(برچسب‌ها به ترجمه)# python



منتشر شده در 1403-01-05 03:54:03

امتیاز شما به این مطلب
دیدگاه شما در خصوص مطلب چیست ؟

آدرس ایمیل شما منتشر نخواهد شد.

لطفا دیدگاه خود را با احترام به دیدگاه های دیگران و با توجه به محتوای مطلب درج کنید