在Keras中我们可以简单的添加回调,如下所示:
self.model.fit(X_train,y_train,callbacks=[Custom_callback])
回调在doc中定义,但我找不到任何使用它们的例子。谁能告诉我如何将自定义回调添加到 TensorFlow 中?
最佳答案
这是我喜欢使用而不是 tensorboard 的损失记录回调示例。请注意,它需要后处理,我实际上更喜欢这样,因此我可以计算 cross-validation 的平均验证损失。 :
class lossesLogger(tf.keras.callbacks.Callback):
def __init__(self, fileName):
self.fileName = fileName
self.json_log = open(
self.fileName +'.json',
mode='w+',
buffering=1
)
def on_epoch_end(self, epoch, logs=None):
self.json_log.write(
json.dumps(
'epoch {}: '.format(epoch) +
str(logs)
) +
'\n'
)
def on_train_end(self, logs=None):
self.json_log.close()
要使用它,请将它添加到您的回调列表中,就像在这个示例中一样,我将它放在三个列表的末尾:
callbacks = [ #allows analyzing output of replicates
EarlyStopping(
patience=epochsWithoutValidationLossDecrease,
verbose=1
),
ModelCheckpoint(
os.path.join(
os.getcwd(),
'Latest_saved_model_{}.h5'.format(uniqueID)
),
verbose=1,
save_best_only=True,
save_weights_only=False
),
lossesLogger(
os.path.join(
os.getcwd(),
('val_log_per_epoch_' + str(uniqueID))
)
)
]
关于tensorflow - Tensorflow 中的回调,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53324641/