python - 我无法在 gridsearch 中添加优化器参数

标签 python keras scikit-learn deep-learning neural-network

from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import GridSearchCV
def build_classifier():
  classifier = Sequential()
  classifier.add(Dense(units = 6 , init='uniform' , activation= 'relu'))
  classifier.add(Dense(units = 6 , init='uniform' , activation= 'relu'))
  classifier.add(Dense(units = 1 , init='uniform' , activation= 'sigmoid'))
  classifier.compile(optimizer='adam' , loss = 'binary_crossentropy' , 
  metrics=['accuracy'])
  return classifier
KC = KerasClassifier(build_fn=build_classifier)
parameters = {'batch_size' : [25,32],
          'epochs' : [100,500],
          'optimizer':['adam','rmsprop']}
grid_search = GridSearchCV(estimator=KC , 
param_grid=parameters,scoring='accuracy',cv=10)
grid_search.fit(X_train,y_train)

我想用不同的优化器测试模型。但我似乎无法在网格搜索中添加优化器。每当我运行该程序时,它都会显示有关拟合训练集的错误。

ValueError:优化器不是合法参数

最佳答案

keras for scikit-learn 的文档说:

sk_params takes both model parameters and fitting parameters. Legal model parameters are the arguments of build_fn. Note that like all other estimators in scikit-learn, build_fn should provide default values for its arguments, so that you could create the estimator without passing any values to sk_params.

GridSearchCV 将在 KerasClassifier 上调用 get_params() 以根据您的代码获取可以传递给它的有效参数列表:

KC = KerasClassifier(build_fn=build_classifier)

将为空(因为您没有在 build_classifier 中指定任何参数)。

将其更改为:

# Used a parameter to specify the optimizer
def build_classifier(optimizer = 'adam'):
  ...
  classifier.compile(optimizer=optimizer , loss = 'binary_crossentropy' , 
  metrics=['accuracy'])
  ...
  return classifier

之后它应该可以工作了。

关于python - 我无法在 gridsearch 中添加优化器参数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53806892/

相关文章:

python - 如何选择 tf.keras 模型隐藏层输出神经元的值?

python - 如何使用 python 控制台修复 import keras 错误

python-3.x - 为什么 sklearn 标准化数据的方差不等于 1?

python - 随机森林回归 - 如何分析其性能? - python ,sklearn

python - Pytesser 不准确

python - Pyglet 和 FFmpeg

python - 在 Matplotlib 中,如何将 3D 图作为插图包含在内?

python - 关键字参数性能(python)

python - 使用 Keras 在 Python 中加载 .mat 数据

python - 拆分前的数据扩充