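I'm having trouble implementing the Keras TimeseriesGenerator. I want to experiment with different values of look_back, a variable that determines how many lagged rows of X are used for each y. Right now it is set to 3, but I'd like to be able to test several values. Essentially, I want to see whether using the last n rows to predict a value improves accuracy. Here is my code: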
### trying with timeseries generator
from keras.preprocessing.sequence import TimeseriesGenerator
look_back = 3
train_data_gen = TimeseriesGenerator(X_train, X_train,
length=look_back, sampling_rate=1,stride=1,
batch_size=3)
test_data_gen = TimeseriesGenerator(X_test, X_test,
length=look_back, sampling_rate=1,stride=1,
batch_size=1)
### Bi_LSTM
Bi_LSTM = Sequential()
Bi_LSTM.add(layers.Bidirectional(layers.LSTM(512, input_shape=(look_back, 11))))
Bi_LSTM.add(layers.Dropout(.5))
# Bi_LSTM.add(layers.Flatten())
Bi_LSTM.add(Dense(11, activation='softmax'))
Bi_LSTM.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
### fitting a small normal model seems to be necessary for compile
Bi_LSTM.fit(X_train[:1],
y_train[:1],
epochs=1,
batch_size=32,
validation_data=(X_test[:1], y_test[:1]),
class_weight=class_weights)
print('ignore above, necessary to run custom generator...')
Bi_LSTM_history = Bi_LSTM.fit_generator(Bi_LSTM.fit_generator(generator,
steps_per_epoch=1,
epochs=20,
verbose=0,
class_weight=class_weights))
This produces the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-35-11561ec7fb92> in <module>()
26 batch_size=32,
27 validation_data=(X_test[:1], y_test[:1]),
---> 28 class_weight=class_weights)
29 print('ignore above, necessary to run custom generator...')
30 Bi_LSTM_history = Bi_LSTM.fit_generator(Bi_LSTM.fit_generator(generator,
2 frames
/usr/local/lib/python3.6/dist-packages/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
143 ': expected ' + names[i] + ' to have shape ' +
144 str(shape) + ' but got array with shape ' +
--> 145 str(data_shape))
146 return data
147
ValueError: Error when checking input: expected lstm_16_input to have shape (3, 11) but got array with shape (1, 11)
If I change the BiLSTM input shape to the (1, 11) listed above, I get this error instead:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-36-7360e3790518> in <module>()
31 epochs=20,
32 verbose=0,
---> 33 class_weight=class_weights))
34
5 frames
/usr/local/lib/python3.6/dist-packages/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
143 ': expected ' + names[i] + ' to have shape ' +
144 str(shape) + ' but got array with shape ' +
--> 145 str(data_shape))
146 return data
147
ValueError: Error when checking input: expected lstm_17_input to have shape (1, 11) but got array with shape (3, 11)
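What is going on here?
In case it matters, my data is read from a df in which each row (observation) is a float vector of shape (1, 11) and each label is an int, which I convert to a one-hot vector of shape (1, 11).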
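Best answer
I found a lot of errors in your code, so I'd like to provide a dummy example that you can follow to carry out your task. Pay attention to the original dimensions of your data and to the dimensions of the data produced by TimeseriesGenerator. This is important for understanding how to build the network.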
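import numpy as np
from keras.models import Sequential
from keras.layers import Bidirectional, LSTM, Dropout, Dense
from keras.preprocessing.sequence import TimeseriesGenerator
# utility variables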
look_back = 3
batch_size = 3
n_feat = 11
n_class = 11
n_train = 200
n_test = 60
# data simulation
X_train = np.random.uniform(0,1, (n_train,n_feat)) # 2D!
X_test = np.random.uniform(0,1, (n_test,n_feat)) # 2D!
y_train = np.random.randint(0,2, (n_train,n_class)) # 2D!
y_test = np.random.randint(0,2, (n_test,n_class)) # 2D!
train_data_gen = TimeseriesGenerator(X_train, y_train, length=look_back, batch_size=batch_size)
test_data_gen = TimeseriesGenerator(X_test, y_test, length=look_back, batch_size=batch_size)
# check generator dimensions
for i in range(len(train_data_gen)):
    x, y = train_data_gen[i]
    print(x.shape, y.shape)
Bi_LSTM = Sequential()
Bi_LSTM.add(Bidirectional(LSTM(512), input_shape=(look_back, n_feat)))
Bi_LSTM.add(Dropout(.5))
Bi_LSTM.add(Dense(n_class, activation='softmax'))
print(Bi_LSTM.summary())
Bi_LSTM.compile(optimizer='rmsprop',
                loss='categorical_crossentropy',
                metrics=['accuracy'])
Bi_LSTM_history = Bi_LSTM.fit_generator(train_data_gen,
                                        steps_per_epoch=50,
                                        epochs=3,
                                        verbose=1,
                                        validation_data=test_data_gen) # class_weight=class_weights)
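To see why the network's input_shape has to be (look_back, n_feat), it may help to look at the windowing itself. The pure-Python sketch below is only an illustration of what the generator does when sampling_rate=1 and stride=1, not Keras's actual implementation; make_windows and the toy data are hypothetical names introduced here. Each target at row i is paired with the look_back rows before it, so a 2D input of n rows yields n - look_back samples, each of shape (look_back, n_feat).

```python
def make_windows(data, targets, look_back):
    """Pair each target row with the look_back rows that precede it.

    Simplified stand-in for the generator's windowing, assuming
    sampling_rate=1 and stride=1 (no batching is done here).
    """
    samples, labels = [], []
    for i in range(look_back, len(data)):
        samples.append(data[i - look_back:i])  # rows i-look_back .. i-1
        labels.append(targets[i])              # label belonging to row i
    return samples, labels


# Toy data (made up): 10 time steps, 2 features per step, one label per step.
data = [[t, t + 0.5] for t in range(10)]
targets = list(range(10))

for look_back in (1, 3, 5):
    samples, labels = make_windows(data, targets, look_back)
    # Print: look_back, number of samples, rows per sample, features per row.
    print(look_back, len(samples), len(samples[0]), len(samples[0][0]))
```

Because the window length is part of every sample's shape, trying several look_back values means rebuilding both the generators and the model (with the matching input_shape) for each value, rather than reusing one compiled model.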
Regarding python - how to use the Keras TimeseriesGenerator, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/61641048/