python - 在 keras 中使用额外输入加载自定义损失

标签 python tensorflow keras

我有一个自定义损失函数,它将模型的输入作为参数之一。如果我在我训练的同一个 session 中加载,我可以使用 this 加载它没有问题。技术。


def custom_loss(inputs):
    def loss(y_true, y_pred):
        return ...
    return loss

inputs = keras.layers.Input(shape=(...))
y = keras.layers.Activation('tanh')(inputs)

model = keras.models.Model(inputs=inputs, outputs=y)

model.compile(loss=custom_loss(inputs), optimizer='Adam')
model.fit(...)
model.save('mymodel.h5')
load_model('mymodel.h5', custom_objects={'custom_loss': custom_loss(inputs})

但是,当我尝试在稍后的 session 中加载模型时遇到问题,因为这次我无法访问原始输入张量。如果我创建一个新的输入占位符,那么模型需要两组不同的输入,我就会出错。

inputs = keras.layers.Input(shape=(...))
load_model('mymodel.h5', custom_objects={'custom_loss': custom_loss(inputs)})

有什么好的方法可以解决这个问题吗?最终的问题是输入尚未反序列化,因此无法将它们传递给自定义对象。

我不想只保存权重并创建具有相同权重的新模型,因为我失去了优化器状态。

最佳答案

另一种方法是计算 Keras 层内的损失,并传递一个虚拟损失函数,该函数仅返回模型的输出作为编译方法中的损失。 还有其他方法可以做到这一点。但这是我更喜欢的。

import tensorflow as tf
print('Tensorflow', tf.__version__)

def custom_loss(tensor):
    y_true, y_pred, inputs = tensor[0], tensor[1], tensor[1]
    loss = ...
    return tf.constant([0], dtype=tf.float32)

def dummy_loss(y_true, y_pred):
    return y_pred

def get_model(training=False):
    inputs = tf.keras.layers.Input(shape=(10,))
    y = tf.keras.layers.Activation('tanh')(inputs)
    if training:
        targets = tf.keras.layers.Input(shape=(10,)) 
        loss_layer = tf.keras.layers.Lambda(custom_loss)([targets, y, inputs])
        model = tf.keras.models.Model(inputs=[inputs, targets], outputs=loss_layer)
    else:
        model = tf.keras.models.Model(inputs=inputs, outputs=y)
    return model


model = get_model(training=True)
model.compile(optimizer='sgd', loss=dummy_loss)
model.save('model.h5')

new_model = tf.keras.models.load_model('model.h5', custom_objects={'dummy_loss':dummy_loss})

关于python - 在 keras 中使用额外输入加载自定义损失,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57897080/

相关文章:

python - 如何使用 python-gitlab 插件在 python 中搜索标签?

python - keras 分类和二元交叉熵

python - Keras 合并层错误

python - 将字符串列表转换为 Python 数据框 - pyspark python Sparksql

python - PyMC:利用 Adaptive Metropolis MCMC 中的稀疏模型结构

machine-learning - TensorFlow 中的步骤和纪元有什么区别?

python - 使用 LSTM ptb 模型 tensorflow 示例预测下一个词

python - 将使用 make_csv_dataset 创建的 TensorFlow 数据集拆分为 3 部分(X1_Train、X2_Train 和 Y_Train)以用于多输入模型

python - 当图像的尺寸为(2048 * 1536)时,我应该在顺序模型的第一层中采用什么输入形状

python - 以编程方式从 QMenubar 中删除顶级 QMenu