python - 如何保存 GridSearchCV 对象?

标签 python scikit-learn keras save grid-search

最近,我一直致力于应用网格搜索交叉验证 (sklearn GridSearchCV) 在具有 Tensorflow 后端的 Keras 中进行超参数调整。一旦我的模型被调整
我试图保存 GridSearchCV 对象供以后使用但没有成功。

超参数调整如下:

x_train, x_val, y_train, y_val = train_test_split(NN_input, NN_target, train_size = 0.85, random_state = 4)

history = History() 
kfold = 10


regressor = KerasRegressor(build_fn = create_keras_model, epochs = 100, batch_size=1000, verbose=1)

neurons = np.arange(10,101,10) 
hidden_layers = [1,2]
optimizer = ['adam','sgd']
activation = ['relu'] 
dropout = [0.1] 

parameters = dict(neurons = neurons,
                  hidden_layers = hidden_layers,
                  optimizer = optimizer,
                  activation = activation,
                  dropout = dropout)

gs = GridSearchCV(estimator = regressor,
                  param_grid = parameters,
                  scoring='mean_squared_error',
                  n_jobs = 1,
                  cv = kfold,
                  verbose = 3,
                  return_train_score=True))

grid_result = gs.fit(NN_input,
                    NN_target,
                    callbacks=[history],
                    verbose=1,
                    validation_data=(x_val, y_val))

备注:create_keras_model 函数初始化和编译 Keras Sequential 模型。

执行交叉验证后,我尝试使用以下代码保存网格搜索对象 (gs):
from sklearn.externals import joblib

joblib.dump(gs, 'GS_obj.pkl')

我得到的错误如下:
TypeError: can't pickle _thread.RLock objects

你能告诉我这个错误的原因是什么吗?

谢谢!

P.S.:joblib.dump 方法适用于保存使用的 GridSearchCV 对象
用于训练来自 sklearn 的 MLPRegressors。

最佳答案

import joblib 直接
代替from sklearn.externals import joblib使用以下方法保存对象或结果:joblib.dump(gs, 'model_file_name.pkl')并使用以下方法加载您的结果:joblib.load("model_file_name.pkl")这是一个简单的工作示例:


import joblib

#save your model or results
joblib.dump(gs, 'model_file_name.pkl')

#load your model for further usage
joblib.load("model_file_name.pkl")

关于python - 如何保存 GridSearchCV 对象?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51424312/

相关文章:

python - 为什么 pyproj.Proj 前向投影似乎没有考虑经纬度原点?

python - 检查轮廓的任何点是否与一条线相交的有效方法是什么?

python-3.x - 值错误: Unknown label type

machine-learning - Keras 中的自定义损失函数用于惩罚漏报

python - 如何将 1D numpy 数组从 keras 层输出更改为图像(3D numpy 数组)

python - 在遍历for循环的python中使用ftell()的意外文件指针位置

python - 从 python 列表中删除 '

python - 属性错误: vocabulary not found

python - KneighborsClassifier 给出与 linalg.norm 和 scipy.spatial.distance.euclidean 不同的欧几里德值

python-3.x - 我应该如何使用 mode.predict_generator 来评估混淆矩阵中的模型性能?