python - 如何同时监控loss和val_loss以避免神经网络对训练集或测试集过度拟合?

标签 python tensorflow scikit-learn neural-network

我一直在参加这个黑客马拉松并使用 keras 回调和神经网络,我是否可以知道是否有一种方法不仅可以监控损失或 val_loss,还可以监控它们两者以避免过度拟合测试集或训练集? 例如:我可以为监视器字段添加一个函数,而不仅仅是一个字段名称吗?

如果我想监控 val_loss 以选择最低值,但我还想要第二个标准来选择 val_loss 和 loss 之间的最小差异。

最佳答案

我有一个与此非常相似的问题的答案,here .

基本上,不可能使用 keras 回调来监控多个指标。但是,您可以定义一个自定义回调(请参阅 documentation 了解更多信息),它可以访问每个时期的日志并执行一些操作。

假设您想监控 lossval_loss,您可以执行以下操作:

import tensorflow as tf
from tensorflow import keras

class CombineCallback(tf.keras.callbacks.Callback):

    def __init__(self, **kargs):
        super(CombineCallback, self).__init__(**kargs)

    def on_epoch_end(self, epoch, logs={}):
        logs['combine_metric'] = logs['val_loss'] + logs['loss']

旁注:我认为最重要的是监控验证损失。当然,火车损失会持续下降,因此观察并没有多大意义。如果你真的想监控它们,我建议你添加一个乘法因子并给予验证损失更多的权重。在这种情况下:

class CombineCallback(tf.keras.callbacks.Callback):

    def __init__(self, **kargs):
        super(CombineCallback, self).__init__(**kargs)

    def on_epoch_end(self, epoch, logs={}):
        factor = 0.8
        logs['combine_metric'] = factor * logs['val_loss'] + (1-factor) * logs['loss']

然后,如果您只想在训练期间监控这个新指标,您可以这样使用它:

model.fit(
    ...
    callbacks=[CombineCallback()],
)

相反,如果您还想使用新指标停止训练,则应该将新回调与提前停止回调结合起来:

combined_cb = CombineCallback()
early_stopping_cb = keras.callbacks.EarlyStopping(monitor="combine_metric")
model.fit(
    ...
    callbacks=[combined_cb, early_stopping_cb],
)

确保在回调列表中提前停止回调之前获取CombinedCallback

而且,还可以汲取更多灵感here .

关于python - 如何同时监控loss和val_loss以避免神经网络对训练集或测试集过度拟合?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/73353241/

相关文章:

python - 如何在(GridSearchCV)拟合模型后打印估计系数? (SGDRegressor)

python - 这个 bool 值我在做什么错

python - 如何使用 Spidermon 监控特定的蜘蛛?

python - tkinter:如何让用户在 asksaveasfilename 对话框中选择文件类型?

python - 如何在 Azure 中使用 Python SDK 将 LogAnalyticsWorkSpace 扩展添加到 VM

python - 运行 tensorflow-gpu 的段错误

python - TensorFlow:实现网络,其中所有一层的特征图未连接到下一层的所有特征图

tensorflow - 如何使用 TensorFlow 后端计算 Keras 中的 KL 散度?

python - SVC scikit 预测奇数和偶数

python - 使用 scikit learn 进行稀疏编码后应用最大/平均池化进行分类