python - Keras Lambda CTC 无法加载模型

标签 python keras keras-2

您好,我有一个基于此 https://github.com/igormq/asr-study/tree/keras-2 的模型这能够保存好但无法加载(完整模式或 json/weights),因为损失没有正确定义。

inputs = Input(name='inputs', shape=(None, num_features))
...
o = TimeDistributed(Dense(num_hiddens))(inputs)

# Output layer
outputs = TimeDistributed(Dense(num_classes))(o)

# Define placeholders
labels = Input(name='labels', shape=(None,), dtype='int32', sparse=True)
inputs_length = Input(name='inputs_length', shape=(None,), dtype='int32')

# Define a decoder
dec = Lambda(ctc_utils.decode, output_shape=ctc_utils.decode_output_shape,
             arguments={'is_greedy': True}, name='decoder')
y_pred = dec([output, inputs_length])

loss = ctc_utils.ctc_loss(output, labels, input_length)


model = Model(input=[inputs, labels, inputs_length], output=y_pred)
model.add_loss(loss)

opt = Adam(lr=args.lr, clipnorm=args.clipnorm)

 # Compile with dummy loss
 model.compile(optimizer=opt, loss=None, metrics=[metrics.ler])

这将编译并运行(注意它使用了 add_loss 函数,该函数没有很好的文档记录)。甚至可以说服它通过一些工作来保存 - 正如这篇文章提示的那样(https://github.com/fchollet/keras/issues/5179)您可以通过强制图表完成来保存它。我通过制作一个虚拟 lambda 损失函数来引入不完全属于图表的输入来做到这一点,现在这似乎有效。

#this captures all the dangling nodes so will now save
fake_dummy_loss = Lambda(fake_ctc_loss,output_shape(1,),name=ctc)([y_pred,labels,inputs_length])

def fake_ctc_loss(args):
return tf.Variable(tf.zeros([1]),name="fakeloss")

我们可以像这样将其添加到模型中:

model = Model(input=[inputs, labels, inputs_length], output=[y_pred, fake_dummy_loss])

现在尝试加载时的损失表示它不能,因为它缺少损失函数(我猜这是因为尽管使用了 add_loss 但它被设置为 None。

在此感谢任何帮助

最佳答案

我在我的一个项目中遇到了类似的问题,其中 add_loss 用于手动向我的模型添加自定义损失函数。你可以在这里看到我的模型:Keras Loss Function with Additional Dynamic Parameter如您所见,使用 load_model 加载模型失败,提示缺少损失函数。

无论如何,我的解决方案是保存和加载模型的权重而不是整个模型。 Model 类有一个 save_weights 方法,在这里讨论:https://keras.io/models/about-keras-models/同样,还有一个 load_weights 方法。使用这些方法,您应该能够很好地保存和加载模型。缺点是您必须预先定义模型,然后加载权重。在我的项目中,这不是问题,只涉及一个小的重构。

希望对您有所帮助。

关于python - Keras Lambda CTC 无法加载模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45702765/

相关文章:

将单词列表与句子列表进行比较并打印匹配行的 Pythonic 方法

python - Statsmodel 多元 OLS 错误 "matrices are not aligned"

python - tf.train.AdamOptimizer 和在 keras.compile 中使用 adam 有什么区别?

tensorflow - 使用 tensorflow_datasets API 访问已下载的数据集

python - 没有提供数据。需要每个键的数据

python-3.x - 在 Keras 中使用单热编码创建模型

Python:请求 session 登录 Cookie

machine-learning - keras(cnn + nn)在4个类别中仅预测一个类别

Keras TimeDistributed Conv1D 错误

python - 在上下文管理器 (with) 和异常处理程序中分配给具有 `as` 的成员