python - 带有 TimeSeriesGenerator 的 Keras LSTM 自定义数据生成器

标签 python keras lstm

所以我正在尝试使用 Keras 的 fit_generator使用自定义数据生成器输入 LSTM 网络。

什么有效

为了说明这个问题,我创建了一个玩具示例,试图以简单的升序预测下一个数字,我使用 Keras TimeseriesGenerator创建一个 Sequence 实例:

WINDOW_LENGTH = 4
data = np.arange(0,100).reshape(-1,1)
data_gen = TimeseriesGenerator(data, data, length=WINDOW_LENGTH,
                               sampling_rate=1, batch_size=1)

我使用一个简单的 LSTM 网络:
data_dim = 1
input1 = Input(shape=(WINDOW_LENGTH, data_dim))
lstm1 = LSTM(100)(input1)
hidden = Dense(20, activation='relu')(lstm1)
output = Dense(data_dim, activation='linear')(hidden)

model = Model(inputs=input1, outputs=output)
model.compile(loss='mse', optimizer='rmsprop', metrics=['accuracy'])

并使用 fit_generator 训练它功能:
model.fit_generator(generator=data_gen,
                    steps_per_epoch=32,
                    epochs=10)

这可以完美地训练,并且模型会按预期进行预测。

问题

现在的问题是,在我的非玩具情况下,我想在将数据输入 fit_generator 之前处理来自 TimeseriesGenerator 的数据。 .作为朝着这个方向迈出的一步,我创建了一个生成器函数,它只包装了之前使用的 TimeseriesGenerator。
def get_generator(data, targets, window_length = 5, batch_size = 32):
    while True:
        data_gen = TimeseriesGenerator(data, targets, length=window_length, 
                                       sampling_rate=1, batch_size=batch_size)
        for i in range(len(data_gen)):
            x, y = data_gen[i]
            yield x, y

data_gen_custom = get_generator(data, data,
                                window_length=WINDOW_LENGTH, batch_size=1)

但现在奇怪的是,当我像以前一样训练模型,但使用这个生成器作为输入时,
model.fit_generator(generator=data_gen_custom,
                    steps_per_epoch=32,
                    epochs=10)

没有错误,但训练错误无处不在(上下跳跃而不是像其他方法那样持续下降),并且模型没有学会做出好的预测。

任何想法我的自定义生成器方法做错了什么?

最佳答案

可能是因为对象类型从 Sequence 改变了这是什么TimeseriesGenerator是通用生成器。 fit_generator函数以不同的方式对待这些。更简洁的解决方案是继承该类并覆盖处理位:

class CustomGen(TimeseriesGenerator):
  def __getitem__(self, idx):
    x, y = super()[idx]
    # do processing here
    return x, y

并像以前一样使用这个类,因为内部逻辑的其余部分将保持不变。

关于python - 带有 TimeSeriesGenerator 的 Keras LSTM 自定义数据生成器,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50322660/

相关文章:

python - 来自 Yelp API 的错误请求

Python - 带请求的curl请求, header 设置不正确? (可能是 GitLab API 问题)

python - 在多维列表索引中使用 splat 运算符

neural-network - 使用 Keras 的 LSTM 网络中的验证损失和准确性

machine-learning - 我们应该何时何地使用这些 keras LSTM 模型

python - 戈朗 : Unexpected EOF error while reading Gzip Reader

machine-learning - 将 keras CNN 应用于新数据集

neural-network - Keras 上的域适配

python - Keras:如何从张量中仅提取某些层

deep-learning - LSTM神经网络中的损失函数