tensorflow - Tensorflow 中的回调

标签 tensorflow machine-learning keras callback

在Keras中我们可以简单的添加回调,如下所示:

self.model.fit(X_train,y_train,callbacks=[Custom_callback])

回调在doc中定义,但我找不到任何使用它们的例子。谁能告诉我如何将自定义回调添加到 TensorFlow 中?

最佳答案

这是我喜欢使用而不是 tensorboard 的损失记录回调示例。请注意,它需要后处理,我实际上更喜欢这样,因此我可以计算 cross-validation 的平均验证损失。 :

class lossesLogger(tf.keras.callbacks.Callback):
    def __init__(self, fileName):
        self.fileName = fileName
        self.json_log = open(
            self.fileName +'.json',
            mode='w+',
            buffering=1
        )     
    def on_epoch_end(self, epoch, logs=None):
        self.json_log.write(
            json.dumps(
                'epoch {}: '.format(epoch) +
                str(logs)
            ) +
            '\n'
        )        
    def on_train_end(self, logs=None):
        self.json_log.close()

要使用它,请将它添加到您的回调列表中,就像在这个示例中一样,我将它放在三个列表的末尾:

callbacks = [ #allows analyzing output of replicates
        EarlyStopping(
            patience=epochsWithoutValidationLossDecrease,
            verbose=1
        ),
        ModelCheckpoint(
            os.path.join(
                os.getcwd(),
                'Latest_saved_model_{}.h5'.format(uniqueID)
            ), 
            verbose=1,
            save_best_only=True,
            save_weights_only=False
        ),
        lossesLogger(
            os.path.join(
                os.getcwd(),
                ('val_log_per_epoch_' + str(uniqueID))
            )
        )
    ]

关于tensorflow - Tensorflow 中的回调,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53324641/

相关文章:

tensorflow - tf.keras : Evaluating model. 使用 tf.data.Dataset 作为输入时更新中断

keras - 如果我们需要更改 input_shape,为什么我们需要 include_top=False?

python - 从 tflite 模型文件中提取标签

python - 在 Tensorflow 中提取 Adam 更新率

R-confusionMatrix()-sort.list(y) : 'x' must be atomic for 'sort.list' 中的错误

python - 如何在 tensorflow 中将数值分类数据转换为稀疏张量?

python - 展平层的输入必须是张量

machine-learning - 如何使用tensorflow的merge和switch功能?

python - TensorFlow - tf.VariableScope 和 tf.variable_scope 之间的区别

python - Keras 中的自定义损失函数,以屏蔽数组作为输入