python - 保存/加载带有常量的 keras 模型

标签 python serialization keras constants

我有一个 Keras 模型,我想在预测中添加一个常量。经过一番谷歌搜索后,我最终得到了以下代码,它正是我想要的:

import numpy as np
from keras.layers import Input, Add
from keras.backend import variable
from keras.models import Model, load_model

inputs = Input(shape=(1,))
add_in = Input(tensor=variable([[5]]), name='add')
output = Add()([inputs, add_in])

model = Model([inputs, add_in], output)
model.compile(loss='mse', optimizer='adam')

X = np.array([1,2,3,4,5,6,7,8,9,10])
model.predict(X)

但是,如果我保存并加载此模型,Keras 似乎会失去对常量的跟踪:

p = 'k_model.hdf5'
model.save(p)
del model
model2 = load_model(p)
model2.predict(X)

返回结果:

Error when checking model : the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 2 array(s), but instead got the following list of 1 arrays:

保存/加载 Keras 模型时如何包含常量?

最佳答案

由于正如您提到的,它始终是一个常量,因此为它定义单独的输入层是没有意义的;特别考虑到它不是您模型的输入。我建议您使用Lambda层代替:

import numpy as np
from keras.layers import Input, Lambda
from keras.models import Model, load_model

def add_five(a):
    return a + 5

inputs = Input(shape=(1,))
output = Lambda(add_five)(inputs)

model = Model(inputs, output)
model.compile(loss='mse', optimizer='adam')

X = np.array([1,2,3,4,5,6,7,8,9,10])
model.predict(X)

输出:

array([[ 6.],
       [ 7.],
       [ 8.],
       [ 9.],
       [10.],
       [11.],
       [12.],
       [13.],
       [14.],
       [15.]], dtype=float32)

保存并重新加载模型时不会有任何问题,因为 add_ Five 函数已存储在模型文件中。

更新:您可以将其扩展到每个输入样本包含多个元素的情况。例如,如果输入形状为 (2,),并且您希望每个样本的第一个元素添加 5,第二个元素添加 10,则可以轻松修改 add_ Five code> 函数并像这样重新定义它:

def add_constants(a):
    return a + [5, 10]  

# ... the same as above (just change the function name and input shape)

X = np.array([1,2,3,4,5,6,7,8,9,10]).reshape(5, 2)
model.predict(X)

输出:

# X
array([[ 1,  2],
       [ 3,  4],
       [ 5,  6],
       [ 7,  8],
       [ 9, 10]])

# predictions
array([[ 6., 12.],
       [ 8., 14.],
       [10., 16.],
       [12., 18.],
       [14., 20.]], dtype=float32)

关于python - 保存/加载带有常量的 keras 模型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52413371/

相关文章:

python - 在rtree中,如何指定 float 相等性测试的阈值?

python - 来自几张图片的动画 Sprite

asp.net-mvc - 将接口(interface)对象序列化为 JSON

python - 使用 scipy.linalg.solve_triangular 求解 xA=b

python - 模拟线程上的死锁

java - 通过代理反序列化 AMF

java - 如何将单个 GSON 自定义序列化器应用于所有子类?

python - 如何让 Keras 模型在不同的拟合调用中提前停止

deep-learning - Keras 输入数组(x)不能有不同数量的样本吗?

python - Keras中如何解释清楚units参数的含义?