از طریق منوی جستجو مطلب مورد نظر خود در وبلاگ را به سرعت پیدا کنید
روش ذخیره و بارگذاری مدل های XGBoost
سرفصلهای مطلب
مدل ها اغلب برای به کارگیری در تولید و ارائه پیش بینی های معنادار برای ورودی های جدید آموزش داده نمی شوند. برای انتقال آنها به خارج از محیط آموزشی خود – باید یک مدل آموزش دیده را ذخیره کنید و آن را در مدل دیگری بارگذاری کنید.
XGBoost یک کتابخانه عالی، منعطف و فوق العاده سریع با عملکرد فوق العاده است و پرچمدار آن
XGBRegressor
وXGBClassifier
معجزه کند
بیایید یک پسرفتگر ساده تربیت کنیم روی مجموعه داده اسباب بازی:
import xgboost as xgb
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
X, y = datasets.load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)
scaler = MinMaxScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
xbg_reg = xgb.XGBRegressor().fit(X_train_scaled, y_train)
حالا بیایید نگاهی به روش ذخیره و بارگذاری مدل ها بیندازیم:
model.save_model() و model.load_model()
به طور رسمی توصیه می شود از آن استفاده کنید save_model()
و load_model()
توابع ذخیره و بارگذاری مدل ها
توجه داشته باشید: dump_model()
استفاده میشه برای dump تنظیمات برای تفسیر توانایی و تجسم، نه برای ذخیره یک حالت آموزش دیده.
هر دو روش نامیده می شوند روی آ Booster
نمونه، مثال:
xbg_reg.save_model("model.json")
xbg_reg.save_model("model.txt")
xbg_reg = xgb.Booster()
xbg_reg.load_model("model.json")
preds = xbg_reg.predict(xgb.DMatrix(X_test_scaled))
print(preds(:10))
xbg_reg.load_model("model.txt")
preds = xbg_reg.predict(xgb.DMatrix(X_test_scaled))
print(preds(:10))
می توانید به طور متناوب مشخص کنید که کدام تقویت کننده پر شده است، در این صورت می توانید آرایه های NumPy را به جای آرایه های پیچیده شده به عنوان تغذیه کنید. DMatrix()
ماتریس ها:
xbg_reg.save_model("model.json")
xbg_reg.save_model("model.txt")
xbg_reg = xgb.XGBRegressor()
xbg_reg.load_model("model.json")
preds = xbg_reg.predict(X_test_scaled)
print(preds(:10))
xbg_reg.load_model("model.txt")
preds = xbg_reg.predict(X_test_scaled)
print(preds(:10))
توجه داشته باشید: این رویکرد سازگاری را تضمین می کند. استفاده از کتابخانه های خارجی مانند joblib
و pickle
اگر نسخه جدید XGBoost سعی کند یک پیکربندی قدیمی را بارگیری کند که قبل از نسخه جدید سریال شده است، ممکن است منجر به مشکلات سازگاری شود. میتوانید با پین کردن نسخهها یا برگشتن و سریالسازی مجدد این موضوع را تنظیم کنید، اما با استفاده از API رسمی میتوانید بهکلی از آن اجتناب کنید.
Joblib
Joblib یک کتابخانه سریال سازی است، با یک API بسیار ساده که به شما امکان می دهد مدل ها را در قالب های مختلف ذخیره کنید:
import joblib
joblib.dump(xbg_reg, "xgb_reg.sav")
xgb_reg = joblib.load("xgb_reg.sav")
preds = xgb_reg.predict(X_test_scaled)
print(preds(:10))
ترشی
Pickle یک کتابخانه سریال سازی دیگر است که به شما امکان می دهد مدل ها را به راحتی سریال سازی کنید، اما با فایل ها کمی بیشتر به صورت دستی کار می کند:
import pickle
pickle.dump(xbg_reg, open("xgb_reg.sav", "wb"))
xgb_reg = pickle.load(open("xgb_reg.sav", "rb"))
preds = xgb_reg.predict(X_test_scaled)
print(preds(:10))
(برچسبها به ترجمه)# python
منتشر شده در 1403-01-05 03:54:03