python - 当准确度已经是 1.0 时停止 Keras 训练

标签 python machine-learning keras neural-network

当准确率已经达到 1.0 时,如何停止 Keras 训练?我尝试过监控损失值,但没有尝试在准确度已经为1时停止训练。

我尝试了下面的代码,但没有成功:

stopping_criterions =[
    EarlyStopping(monitor='loss', min_delta=0, patience = 1000),
    EarlyStopping(monitor='acc', base_line=1.0, patience =0)

]

model.summary()
model.compile(Adam(), loss='binary_crossentropy', metrics=['accuracy']) 
model.fit(scaled_train_samples, train_labels, batch_size=1000, epochs=1000000, callbacks=[stopping_criterions], shuffle = True, verbose=2)

更新:

即使准确度仍然不是 1.0,训练也会在第一个 epoch 立即停止。

enter image description here

请帮忙。

最佳答案

更新:在 keras 2.4.3(2020 年 12 月)中测试

我不知道为什么 EarlyStopping 在这种情况下不起作用。相反,我定义了一个自定义回调,当 acc(或 val_acc)达到指定基线时停止训练:

from keras.callbacks import Callback

class TerminateOnBaseline(Callback):
    """Callback that terminates training when either acc or val_acc reaches a specified baseline
    """
    def __init__(self, monitor='accuracy', baseline=0.9):
        super(TerminateOnBaseline, self).__init__()
        self.monitor = monitor
        self.baseline = baseline

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        acc = logs.get(self.monitor)
        if acc is not None:
            if acc >= self.baseline:
                print('Epoch %d: Reached baseline, terminating training' % (epoch))
                self.model.stop_training = True

你可以像这样使用它:

callbacks = [TerminateOnBaseline(monitor='accuracy', baseline=0.8)]
callbacks = [TerminateOnBaseline(monitor='val_accuracy', baseline=0.95)]
<小时/>

注意:此解决方案不起作用。

如果您想在训练(或验证)准确度达到 100% 时停止训练,请使用 EarlyStopping 回调并将 baseline 参数设置为 1.0 和 耐心归零:

EarlyStopping(monitor='acc', baseline=1.0, patience=0)  # use 'val_acc' instead to monitor validation accuarcy

关于python - 当准确度已经是 1.0 时停止 Keras 训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53500047/

相关文章:

tensorflow - 需要使用 if 语句的自定义损失函数

python - 在模块上调用 dir 函数

python - 如何在PyQt5中捕获libpng错误错误的自适应滤波器

python - 优化具有多个客户端的简单服务器

python-3.x - 重新训练机器学习的 Inception V3 模型

python - keras/scikit-学习 : using fit_generator() with cross validation

python - 如何读取使用 "from <module> import *"的代码?

machine-learning - 如何在多层感知器中使用sigmoid函数?

machine-learning - 为什么我们要在分类问题中最大化 AUC?

python - GradientTape 根据是否由 tf.function 修饰的损失函数给出不同的梯度