python - 模型有多个输出时的ModelCheckpoint监控值

标签 python tensorflow keras tensorflow2.0

我的模型有两个输出,我想监控一个以保存我的模型。 下面是我的代码的一部分。 TensorFlow的版本是2.0

model = MobileNetBaseModel()()
model.compile(optimizer=tf.keras.optimizers.Adam(),
              metrics={"pitch_yaw_roll": "mae"},
              loss={"pitch_yaw_roll": compute_mse_loss, # or "mse"
                    "total_logits": compute_cross_entropy_loss(num_classes=num_classes)},
              loss_weights= {"pitch_yaw_roll":mse_weight, "total_logits":cross_entropy_weight})
file_path = os.path.join(checkpoint_path, "model.{epoch:2d}-{val_loss:.2f}.h5")
tf.keras.callbacks.ModelCheckpoint(filepath=file_path,
                                   monitor="val_loss",
                                   verbose=1,
                                   save_freq=save_freq,
                                   save_best_only=True)

ModelCheckpoint回调中默认的monitor='val_loss',如何选择我需要的?我想监控 {"pitch_yaw_roll": "mae"}

最佳答案

如果您希望 ModelCheckpoint 根据另一个指标值保存,请在 .compile(metrics={...}, ...) 指标字典。

因此,例如,如果您只想保存最好的 "pitch_yaw_roll" epoch 结果(最好是最小值),您应该使用

tf.keras.callbacks.ModelCheckpoint(filepath=file_path,
                                   monitor="val_pitch_yaw_roll",
                                   verbose=1,
                                   mode="min",
                                   save_freq=save_freq,
                                   save_best_only=True)

如果您选择 "pitch_yaw_roll" 而不是 "val_pitch_yaw_roll" 它将根据训练损失而不是根据验证损失进行保存

关于python - 模型有多个输出时的ModelCheckpoint监控值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60646997/

相关文章:

python - 使用 Python 的 Azure 服务主体和 storage.blob

python - 连接 DataFrame,但仅保留一列

c++ - 如何创建具有 CUDA 支持的最新 Tensorflow 版本的调试版本?

python - Windows 上的 Theano 与 Anaconda : how to setup BLAS?

tensorflow - keras 设置为 TF 格式时以 TH 格式加载权重

python - xinetd 服务调用 python 脚本(无法正确执行)

python - 在从传感器读取数据而不打印从传感器读取的数据时,While 循环无法立即工作?

android - Tensorflow-Lite - 基准测试工具 - 不同的结果

neural-network - Tensorflow Inception 多 GPU 训练损失未求和?

r - RStudio 中的 install.keras() 因 http 连接错误而失败