python - 为什么我在简单的 Keras 功能 API 中收到输入错误?

标签 python tensorflow machine-learning keras deep-learning

我希望我们能够一起解决我的问题,我会尝试总结它并粘贴一些代码。 在我们的研究中,我们使用 TensorFlow/Keras 作为 DNN,对图像进行分类(卷积网络)。这是一个非常简单的顺序模型,但现在我们正在尝试添加更多输入,因此我开始将网络更改为功能性API(老实说,这是我第一次使用它)。我重新创建了原始的卷积网络,一切正常。 因此,我将必要的附加数据生成到简单的文本文件中(每个图像一个),这些将是我们模型中的第二个输入。这是出了问题,因为我遇到了错误,现在我无法找到解决方案。

为了重现错误,我创建了一个新的简单 python 文件,没有卷积部分,只是为了尝试模型的文本文件部分。输入是 5000 个文本文件,仅包含一个 float 。预处理后,我们将有 4000 个用于训练,另外 1000 个用于测试,两者都存储在 numpy 数组中。训练和测试数组通过 sklearn.model_selection.train_test_split 进行分割。

在原始的多 channel 网络中,我尝试将卷积部分与文本文件连接起来,之后出现了一些密集层。但这里没有连接,只是存储在文本文件中的数据的火车。

以下是输入和标签数组的形状:
X_train.shape
(4000,)
y_train.shape
(4000, 5)

非常简单的网络:

inputA = Input(shape=(X_train.shape[0],))
x = Flatten()(inputA)
x = Dense(256, activation='relu')(x)
x = Dense(256, activation='relu')(x)
x = Dense(64, activation='relu')(x)
x = Dense(5, activation='softmax')(x)
x = Model(inputs=inputA, outputs = x)

编译:

x.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

模型拟合:

x.fit(X_train, y_train, 
          validation_data=(X_test, y_test),
          batch_size=batch, 
          verbose=1, 
          epochs=epoch_number)

以及收到的错误消息:

ValueError: Error when checking input: expected input_8 to have shape (4000,) but got array with shape (1,)

问题是,我做错了什么?以前的代码在顺序模型中运行得很好,但这里不行。有人能帮我解开这个谜团吗?

最美好的祝愿, 塔马斯

最佳答案

这是因为这一行:

inputA = 输入(shape=(X_train.shape[0],))

Keras 期望的是特征数量,而不是样本数量。在您的情况下,input_shape=(1,)

关于python - 为什么我在简单的 Keras 功能 API 中收到输入错误?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59142319/

相关文章:

python - Flask SqlAlchemy 多对多关系在按关系名称访问时只返回一个结果

python - django 中 MIDDLEWARE_CLASSES 的顺序

tensorflow - pipenv - 由于 'Could not find a version that matches keras-nightly~=2.5.0.dev' 错误,未生成 Pipfile.lock

python - 在cross_val_score中,参数cv的使用有何不同?

python - 如何处理pythonsoap模块zeep中的complexType参数?

python - 强制 Python 导入是绝对的(忽略本地包目录)

python - Tensorflow : TypeError: int() argument must be a string, 类似字节的对象或数字,而不是 'Tensor'

tensorflow - 为什么Tensorflow的MirroredStrategy和OneDevicestrategy在colab上不起作用?

performance - 在 WEKA 中评估模型

python - Scikit-learn KNN(K 最近邻)使用 Apache Spark 并行化