python - 如何根据损失值告诉 Keras 停止训练?

标签 python machine-learning neural-network conv-neural-network keras

目前我使用以下代码:

callbacks = [
    EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

当损失在 2 个 epoch 内没有改善时,它会告诉 Keras 停止训练。但是我想在损失变得小于某个恒定的“THR”后停止训练:

if val_loss < THR:
    break

我在文档中看到有可能进行自己的回调: http://keras.io/callbacks/ 但没有找到如何停止训练过程。我需要一个建议。

最佳答案

我找到了答案。我查看了 Keras 的源代码并找到了 EarlyStopping 的代码。我做了自己的回调,基于它:

class EarlyStoppingByLossVal(Callback):
    def __init__(self, monitor='val_loss', value=0.00001, verbose=0):
        super(Callback, self).__init__()
        self.monitor = monitor
        self.value = value
        self.verbose = verbose

    def on_epoch_end(self, epoch, logs={}):
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn("Early stopping requires %s available!" % self.monitor, RuntimeWarning)

        if current < self.value:
            if self.verbose > 0:
                print("Epoch %05d: early stopping THR" % epoch)
            self.model.stop_training = True

及用法:

callbacks = [
    EarlyStoppingByLossVal(monitor='val_loss', value=0.00001, verbose=1),
    # EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

关于python - 如何根据损失值告诉 Keras 停止训练?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37293642/

相关文章:

python - 使用 PuLP 进行线性优化,变量附加条件

python - Python不检测单个字母,但是可以检测两个字母

r - caret::train:为 mlpWeightDecay(RSNNS 包)指定更多非调整参数

python - 如何使用tensorflow函数tf.contrib.legacy_seq2seq.sequence_loss_by_example的 'weights'参数?

machine-learning - 如何查看pytorch模型的参数?

python - 如何在 python 中抓取受密码保护的站点?

python - Pandas:如何将数据帧列中的 'timestamp' 值从对象/字符串转换为时间戳?

python - Pybrain交叉验证方法

python - 神经网络中的 Softmax 函数 (Python)

neural-network - 使用 Q-Learning 和函数逼近求解 GridWorld