python - 如何处理 Keras 自定义层中发生的代码错误?

标签 python keras

我想在 Keras 中创建一个自定义层。 在这个例子中,我使用一个变量来乘以张量,但我得到的错误是

in /keras/engine/training_arrays.py, line 304, in predict_loop outs[i][batch_start:batch_end] = batch_out ValueError: could not broadcast input array from shape (36) into shape (2).

实际上我已经检查过这个文件,但我什么也没得到。我的自定义图层有问题吗?

#the definition of mylayer.


 from keras import backend as K
 import keras
 from keras.engine.topology import Layer

class mylayer(Layer):
def __init__(self, output_dim, **kwargs):
    self.output_dim = output_dim
    super(mylayer, self).__init__(**kwargs)

def build(self, input_shape):
    self.kernel = self.add_weight(name = 'kernel',
                                  shape=(1,),dtype='float32',trainable=True,initializer='uniform')
    super(mylayer, self).build(input_shape)

def call(self, inputs, **kwargs):
    return self.kernel * inputs[0]
def compute_output_shape(self, input_shape):
    return (input_shape[0], input_shape[1])


#the test of mylayer.

from mylayer import mylayer
from tensorflow import keras as K
import numpy as np
from keras.layers import Input, Dense, Flatten
from keras.models import Model

x_train = np.random.random((2, 3, 4, 3))
y_train = np.random.random((2, 36))
print(x_train)

x = Input(shape=(3, 4, 3))
y = Flatten()(x)
output = mylayer((36, ))(y)

model = Model(inputs=x, outputs=output)

model.summary()

 model.compile(optimizer='Adam',loss='categorical_crossentropy',metrics=['accuracy'])
model.fit(x_train, y_train, epochs=2)

hist = model.predict(x_train,batch_size=2)

print(hist)

print(model.get_layer(index=1).get_weights())


#So is there some wrong in my custom error?

特别是,当我训练这个网络时,这是可以的,但是当我尝试使用“prdict”时,它是错误的。

最佳答案

您的 self.kernel * input[0] 的形状是 (36,),但您的期望是 (?,36) 。更改它:

def call(self, inputs, **kwargs):
    return self.kernel * inputs

如果要输出mylayer的权重,需要设置index=2

关于python - 如何处理 Keras 自定义层中发生的代码错误?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53994158/

相关文章:

keras - 训练精度很好,但测试精度很差

python - Keras 从 CSV 加载图像

python - 如何在 Pandas 中用空列表[]填充数据框Nan值?

python - 两个分类交叉熵之间的凸组合

python - 如何使用flask_sqlalchemy反射(reflect)现有表

python - 可滚动框架无法使用 tkinter 正确调整大小

machine-learning - 将用户反馈纳入 ML 模型

python - 如何在笔记本中绘制 keras 激活函数

python - 根据 graphviz_layout 中的权重设置边长

python - 在python中随机选择特定范围内具有特定倍数的数字