python - 如何设置keras LSTM的输入形状

标签 python tensorflow keras

我经常遇到这样的问题

ValueError:层顺序的输入 0 与层不兼容:预期 ndim=3,发现 ndim=2。收到完整形状:[10 ,3]

我用谷歌搜索了一下,发现了

LSTM 层期望输入的形状为 (batch_size, timesteps, input_dim)

好吧,但老实说我还是有点困惑。

例如,我有这样的训练数据

x_train (100,3) #it consists of like `[[1,2,3],[3,4,5],[5,6,7]]`
y_train (100,3) #answers

我想使用 10 组 3 对数字并预测下一个 3 对,例如 [7,8,9]

就像从 x_train[1~10] 猜测到 y_train[11]`

下面的代码可以工作,但是我仍然不清楚

input_shape=(3,1)1是什么意思??它应该是 3 (我最终想要得到的维度)

并且batch_size是LSTM请求的第一个参数。

所以,,,当我想从过去 10 项中预测一项时,这里设置 10 是否正确???

x_train = np.array(x).reshape(100, 3,1)
y_train = np.array(x).reshape(100, 3,1)

model.add(LSTM(512, activation=None, input_shape=(3, 1), return_sequences=True))

model.add(Dense(1, activation="linear"))

opt = Adam(lr=0.001)

model.compile(loss='mse', optimizer=opt)
model.summary()
history = model.fit(x_train, y_train, epochs=epoch, batch_size=10) // how to set batch size???

最佳答案

试试这个代码:

import tensorflow as tf
import numpy as np
x = np.random.uniform(0, 10, [101, 3])

x_train = np.array(x[:-1]).reshape(-1, 5, 3) # your data comprise of 20 sequences
y_train = np.array(x[1:]).reshape(-1, 5, 3)

model = tf.keras.Sequential()
model.add(tf.keras.layers.LSTM(512, activation=None, input_shape=(None, 3), return_sequences=True))

model.add(tf.keras.layers.Dense(1, activation="linear"))

opt = tf.keras.optimizers.Adam(lr=0.001)

model.compile(loss='mse', optimizer=opt)
model.summary()
history = model.fit(x_train, y_train, epochs=10, batch_size=10) # here you can set a batch size (your 20 sequences will be splitted into two batches)

关于python - 如何设置keras LSTM的输入形状,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64805127/

相关文章:

python - 无效参数错误 : cannot compute MatMul as input #0(zero-based) was expected to be a float tensor but is a double tensor [Op:MatMul]

python - 将装袋分类器与支持向量机模型结合使用

python pillow(更好的PIL)编码检查bug

python - 将 4 channel RGB-D 图像输入 LSTM

tensorflow - 循环中的 tf.train.string_input_ Producer 行为

python - Tensorflow LSTM 有状态选项不维护批处理之间的状态

python - 如何选择 LSTM 中 Dense 层的维度?

python - 在 NetworkX 中打印图形

python 轮版本

tensorflow - 如何使用 tf.summary.text?