python - 与 CNN 交叉验证

标签 python machine-learning keras conv-neural-network cross-validation

我想知道我的代码是否正在做我想做的事情;为您提供一些背景知识,我正在实现 CNN 进行图像分类。我正在尝试使用交叉验证来比较我的不同神经网络架构

这里是代码:


def create_model():
    model = Sequential()
    model.add(Conv2D(24,kernel_size=3,padding='same',activation='relu',
            input_shape=(96,96,1)))
    model.add(MaxPool2D())
    model.add(Conv2D(48,kernel_size=3,padding='same',activation='relu'))
    model.add(MaxPool2D())
    model.add(Conv2D(64,kernel_size=3,padding='same',activation='relu'))
    model.add(MaxPool2D())
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dense(256, activation='relu'))
    model.add(Dense(12, activation='softmax'))
    model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
    return model
model = KerasClassifier(build_fn=create_model, epochs=5, batch_size=20, verbose=1) 
# 3-Fold Crossvalidation
kfold = KFold(n_splits=3, shuffle=True, random_state=2019) 
results = cross_val_score(model, train_X, train_Y_one_hot, cv=kfold)

model.fit(train_X, train_Y_one_hot,validation_data=(valid_X, valid_label),class_weight=class_weights)
y_pred = model.predict(test_X)

test_eval = model.evaluate(test_X, y_pred, verbose=0)

我在互联网上找到了交叉验证的部分。但我在理解它时遇到了一些问题。

我的问题:1=> 我可以使用交叉验证来提高准确性吗?例如,我运行了 10 次我的神经网络,我的模型得到了出现最佳准确度的权重

2 => 如果我理解得很好,在上面的代码中,结果运行我的 CNN 3 次并向我展示准确性。但是当我使用 model.fit 时,模型仅运行一次;我说得对吗?

感谢您的帮助

最佳答案

  1. 并非如此,交叉验证更多的是一种防止过度拟合/不被来自严重分割的数据集的异常结果所混淆的方法 -> 获得对模型性能的相关估计。如果您想调整模型的超参数,最好使用 sklearn.model_selection.GridSearchCV/sklearn.model_selection.RandomSearchCV

  2. 对每次训练/测试进行cross_val_score时 sklearn 进行拟合,然后预测/评估,因此对于模型的每个新实例, 您有 1 个适合,然后有 1 个预测/评估; 否则,您的交叉验证无效,因为它取决于对先前数据集的拟合(可能还取决于测试数据!)

关于python - 与 CNN 交叉验证,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55923475/

相关文章:

python - 如何使用 python 从 G(n,p) 图创建邻接矩阵?

r - 在 R 中使用 rpart() 时实际使用的字符 (0)

python - 保持示例索引与 tf.keras.predict 和 tf.data.Dataset 的对应关系

tensorflow - 索引错误: index 5 is out of bounds for axis 1 with size 5

Python math.sqrt 损失精度

javascript - 使用 JavaScript/Python/Bash 管理 API 调用

php - 将 php-ai/php-ml 与 php 一起使用

python - 将多行转换为单行并存储在 python 中的变量中

python - 无法像 Javascript 一样使用 RobotFramework 执行 Python .py 文件

python - 如何将 Keras 模型加载到内存中并在需要时使用它?