python - Keras - 数组形状与 model.predict() 不匹配

标签 python numpy keras

我有一个简单的密集神经网络,有 2 个用 Keras 编写的输入值,在 Tensorflow 和 Python 上运行。我已经成功安装了该网络,并且可以毫无错误地运行评估。但是,当我想要预测单个样本数据的结果时,由于输入数据的维度形状不正确,我收到错误。但是,当我打印 numpy 数组的形状时,它会返回正确的形状:

inputArr = np.array((x[sample][0], x[sample][1]))
print(inputArr)
print(inputArr.shape)
prediction = model.predict(inputArr)

这会产生以下输出:

Input data: [-1. -1.]
Array shape: (2,)

随后出现错误:

Traceback (most recent call last):
  File ".\train3d.py", line 60, in <module>
    prediction = model.predict(inputArr)
  File "C:\Users\svoja\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\engine\training.py", line 1147, in predict
    x, _, _ = self._standardize_user_data(x)
  File "C:\Users\svoja\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\engine\training.py", line 749, in _standardize_user_data
    exception_prefix='input')
  File "C:\Users\svoja\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\engine\training_utils.py", line 137, in standardize_input_data
    str(data_shape))
ValueError: Error when checking input: expected dense_1_input to have shape (2,) but got array with shape (1,)

正如您从错误消息中看到的,网络需要一个维度为 (2, ) 的数组,这与我的输入数组的输出完全相同。

我的问题是,数组到底出了什么问题?

最佳答案

您缺少批量大小,Keras 希望数据隐式(N, D),其中 N 是批量大小,D 是特征数量。在您的情况下 D=2 但您没有矩阵。

要传递单个数据点,您需要形状 (1, 2),它读取具有 2 个特征的 1 个数据点。您可以通过以下方式实现此目的:

inputArr = np.array((x[sample][0], x[sample][1]))
print(inputArr) # [-1, -1]
print(inputArr.shape) # (2,)
inputArr = np.expand_dims(inputArr, 0)
print(inputArr.shape) # (1, 2)

或更短的语法糖版本:

inputArr = inputArr[None, :] # (1, 2)

其中 None 添加新维度。

关于python - Keras - 数组形状与 model.predict() 不匹配,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55436040/

相关文章:

machine-learning - 从 Keras 中的 3 维张量收集 2 维张量列表

machine-learning - 为什么每个 epoch 之后损失都会突然下降?

具有固定数量元素的python列表

python - 加权平均日期时间,关闭但仅限某些月份

Python 将 Counter 附加到 Counter,就像 Python 字典更新一样

java - Scala:相当于 np.digitize 对数据进行分桶

python - 如何在 Python 中计算自协方差

python - TensorFlow "Please provide as model inputs a single array or a list of arrays"

python - "not all arguments converted during string formatting"当to_sql

Python 随机播放列表不起作用