python - 了解 TensorFlow LSTM 输入形状

标签 python tensorflow regression lstm

我有一个数据集 X,它包含 N = 4000 个样本,每个样本包含 d = 2 个特征(连续值),跨越 t = 10 次步骤。在时间步 11,我也有每个样本的相应“标签”,它们也是连续值。

目前我的数据集的形状是 X:[4000,20],Y:[4000]。

我想使用 TensorFlow 训练一个 LSTM 来预测 Y 的值(回归),给定 d 特征的 10 个先前输入,但我很难在 TensorFlow 中实现这一点。

我目前遇到的主要问题是了解 TensorFlow 如何期望输入被格式化。我见过各种例子,例如 this ,但这些示例处理一大串连续的时间序列数据。我的数据是不同的样本,每个样本都是独立的时间序列。

最佳答案

documentation of tf.nn.dynamic_rnn状态:

inputs: The RNN inputs. If time_major == False (default), this must be a Tensor of shape: [batch_size, max_time, ...], or a nested tuple of such elements.

在您的情况下,这意味着输入的形状应为 [batch_size, 10, 2]。您不必一次对所有 4000 个序列进行训练,而是在每次训练迭代中只使用 batch_size 其中许多序列。类似以下的东西应该可以工作(为清楚起见添加了 reshape ):

batch_size = 32
# batch_size sequences of length 10 with 2 values for each timestep
input = get_batch(X, batch_size).reshape([batch_size, 10, 2])
# Create LSTM cell with state size 256. Could also use GRUCell, ...
# Note: state_is_tuple=False is deprecated;
# the option might be completely removed in the future
cell = tf.nn.rnn_cell.LSTMCell(256, state_is_tuple=True)
outputs, state = tf.nn.dynamic_rnn(cell,
                                   input,
                                   sequence_length=[10]*batch_size,
                                   dtype=tf.float32)

来自 documentation , outputs 的形状为 [batch_size, 10, 256],即每个时间步有一个 256 输出。 state 将是 tuple形状 [batch_size, 256]。你可以从中预测你的最终值,每个序列一个:

predictions = tf.contrib.layers.fully_connected(state.h,
                                                num_outputs=1,
                                                activation_fn=None)
loss = get_loss(get_batch(Y).reshape([batch_size, 1]), predictions)

outputsstate 形状中的数字 256 分别由 cell.output_size 决定。 cell.state_size。像上面那样创建 LSTMCell 时,它们是相同的。另见 LSTMCell documentation .

关于python - 了解 TensorFlow LSTM 输入形状,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39324520/

相关文章:

javascript - bootstrap 导航栏下拉菜单未在 django 中显示

python - TensorFlow LSTM 预测相同的值

python - Python 中的 SVM 回归速度更快

python - 好的高级 python ftp/http 库?

python - Django:使用 user.get_profile()

python - 将 CSV 中的列编码为 Base64

python - 如何为此机器学习模型设置 request.py?

python - 如何使用 TensorFlow 的 sketch RNN 教程对 QuickDraw 涂鸦进行分类?

具有多处理功能的 Tensorflow2.x 自定义数据生成器

variables - Outreg2 和具有交互作用的回归 c.var1##c.var2