python - 异常训练Resnet50 : "The shape of the input to "Flatten"is not fully defined"

标签 python tensorflow keras deep-learning resnet

我想使用 keras.applications.resnet50 通过以下设置来训练 Resnet 以解决两类问题:

from keras.layers import Dropout, Flatten, Dense
from keras.applications.resnet50 import ResNet50
from keras.models import Model

resNet = ResNet50(include_top=False, weights=None)
y = resNet.output
y = Flatten()(y)
y = Dense(2, activation='softmax')(y)
model = Model(inputs=resNet.input, outputs=y)
opt = keras.optimizers.Adam()
model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
epochs = 15
model.fit(train_tensors, train_targets, 
          validation_data=(valid_tensors, valid_targets),
          epochs=epochs, batch_size=10, callbacks=[checkpointer], verbose=1)

运行代码会抛出错误

Exception: The shape of the input to "Flatten" is not fully defined 

所以输出层的输入张量肯定有问题,在我的例子中是一个单热编码向量,即大小为 2 的一维数组。我做错了什么?

最佳答案

你得到

Exception: The shape of the input to "Flatten" is not fully defined

因为您尚未在 resnet 网络中设置输入形状。尝试:

resNet = ResNet50(include_top=False, weights=None, input_shape=(224, 224, 3)) 

此外,由于您在输出层中使用带有 sigmoid 激活的 binary_crossentropy,因此您应该仅使用 1 个神经元而不是 2 个,如下所示:

y = Dense(1, activation='sigmoid')(y)

关于python - 异常训练Resnet50 : "The shape of the input to "Flatten"is not fully defined",我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50693322/

相关文章:

python - django模板日期过滤格式字符串问题

python - 向量化形式以获取大于引用的元素数

python - pyodbc 连接到 sqlite 数据库

r - 错误 : Installation of TensorFlow not found

tensorflow - 尝试在 Keras 中构建一个集成。出现图表断开连接错误

machine-learning - 如何绘制 keras 实验的学习曲线?

python - 如何将python中的元素按n个元素分组

python - 将新单元添加到 Keras 模型层并更改其权重

windows - Tensorflow 需要多少 Visual Studio(最低要求)才能在 Windows 中运行?

python - 如何使用 PyTorch 为堆叠式 LSTM 模型执行 return_sequences?