我的模型有两个输出,我想监控一个以保存我的模型。 下面是我的代码的一部分。 TensorFlow的版本是2.0
model = MobileNetBaseModel()()
model.compile(optimizer=tf.keras.optimizers.Adam(),
metrics={"pitch_yaw_roll": "mae"},
loss={"pitch_yaw_roll": compute_mse_loss, # or "mse"
"total_logits": compute_cross_entropy_loss(num_classes=num_classes)},
loss_weights= {"pitch_yaw_roll":mse_weight, "total_logits":cross_entropy_weight})
file_path = os.path.join(checkpoint_path, "model.{epoch:2d}-{val_loss:.2f}.h5")
tf.keras.callbacks.ModelCheckpoint(filepath=file_path,
monitor="val_loss",
verbose=1,
save_freq=save_freq,
save_best_only=True)
ModelCheckpoint
回调中默认的monitor='val_loss'
,如何选择我需要的?我想监控 {"pitch_yaw_roll": "mae"}
。
最佳答案
如果您希望 ModelCheckpoint
根据另一个指标值保存,请在 .compile(metrics={...}, ...)
指标字典。
因此,例如,如果您只想保存最好的 "pitch_yaw_roll"
epoch 结果(最好是最小值),您应该使用
tf.keras.callbacks.ModelCheckpoint(filepath=file_path,
monitor="val_pitch_yaw_roll",
verbose=1,
mode="min",
save_freq=save_freq,
save_best_only=True)
如果您选择 "pitch_yaw_roll"
而不是 "val_pitch_yaw_roll"
它将根据训练损失而不是根据验证损失进行保存
关于python - 模型有多个输出时的ModelCheckpoint监控值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60646997/