keras - Is there any way to use the Keras EarlyStopping callback when I use a stateful LSTM and reset its states?

Tags: keras lstm stateful

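I am using a stateful LSTM regression model and I would like to apply the EarlyStopping callback. From what I have read, the states of a stateful LSTM should be reset after every epoch. However, I noticed that when I reset the states, EarlyStopping does not work at all. My code is attached below.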


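# Imports assumed for TensorFlow 2.x Keras (the original post did not show them)
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dropout, Dense
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

model = Sequential()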
model.add(LSTM(256, batch_input_shape=(batch_size, timesteps, features), return_sequences=False, stateful=True))
model.add(Dropout(rate=0.2))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='mean_squared_error', optimizer='adam')
mc = ModelCheckpoint('best_model.h5', monitor='val_loss', mode='min', verbose=0, save_best_only=True)
es = EarlyStopping(monitor='val_loss', mode='min', patience=1, restore_best_weights=True, verbose=1)

for epoch in range(epochs):
    print("Epoch: ", epoch + 1)

    hist = model.fit(train_x, train_y, epochs=1, batch_size=batch_size, shuffle=False,
                     validation_data=(validation_x, validation_y), verbose=2, callbacks=[mc, es])
    model.reset_states()


If I run the code above without the for loop and without resetting the states, EarlyStopping works fine. Is there any way to apply EarlyStopping inside the for loop?

Thanks in advance.

Best answer

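It seems that EarlyStopping cannot be applied when model.fit() is called with epochs=1. As far as I know, this is because EarlyStopping is evaluated once per epoch and keeps its state (the best value seen so far and the patience counter) only within a single model.fit() call, so it only works when the number of epochs passed to model.fit() is greater than 1.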
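
The effect can be illustrated with a toy model in plain Python. This is only a sketch, not Keras's actual implementation: the function name `epochs_until_stop` and the `reset_every_call` flag are made up for illustration, and it assumes the stopper's state is re-initialized whenever a new training call begins.

```python
def epochs_until_stop(val_losses, patience, reset_every_call):
    """Toy patience-based stopper (hypothetical, not Keras code).

    Each element of val_losses stands for one single-epoch training call.
    Returns the 1-based epoch at which training would stop, or None.
    """
    best, wait = float("inf"), 0
    for epoch, loss in enumerate(val_losses, start=1):
        if reset_every_call:
            # Assumption: the state is re-initialized at the start of each call
            best, wait = float("inf"), 0
        if loss < best:
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


losses = [0.50, 0.40, 0.45, 0.47, 0.46]
print(epochs_until_stop(losses, patience=1, reset_every_call=True))   # None
print(epochs_until_stop(losses, patience=1, reset_every_call=False))  # 3
```

With the state reset on every call, each loss is compared against infinity and always counts as an improvement, so the stopper never fires. Keeping the state outside the loop is what makes the patience logic work.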
I use the following code to save the best model and stop the training process after a number of epochs without improvement.

import numpy as np
import pandas as pd

# Number of epochs to wait before halting the training process
patience = 50

# Store the metrics of each epoch to a pandas dataframe
history = pd.DataFrame()

# Define a high loss value (this may change based on the classification problem that you have)
min_loss = 2.00

# Define a minimum accuracy value
min_acc = 0.25

# Initialize the wait variable        
wait = 0

for epoch in range(epochs):
    print("Epoch: ", epoch + 1)

    # Train for exactly one epoch, then clear the LSTM states before the next one
    hist = model.fit(train_x, train_y, epochs=1, batch_size=batch_size, shuffle=False,
                     validation_data=(validation_x, validation_y), verbose=2)
    model.reset_states()

    # Note: the accuracy keys used below exist only if the model was compiled with metrics=['accuracy']
    val_loss = hist.history['val_loss'][0]

    # Stop immediately if the validation loss has become NaN
    if np.isnan(val_loss):
        break

    if round(val_loss, 4) < min_loss:
        # New best epoch: remember its metrics, save the model, reset the counter
        min_loss = round(val_loss, 4)
        min_acc = hist.history['val_accuracy'][0]
        model.save('best_model')
        wait = 0
    else:
        wait += 1
        print('*' * 50)
        print(f"Patience: {wait}/ {patience}", "-", "Current best val_accuracy:",
              '{0:.5}'.format(min_acc),
              "with loss:", '{0:.5}'.format(min_loss), f"at epoch {epoch - wait + 1}")
        print('*' * 50)

        # Patience exhausted: stop without recording this epoch
        if wait >= patience:
            break

    # Record this epoch's metrics
    history.loc[epoch, 'epoch'] = epoch + 1
    history.loc[epoch, 'loss'] = hist.history['loss'][0]
    history.loc[epoch, 'val_loss'] = hist.history['val_loss'][0]
    history.loc[epoch, 'accuracy'] = hist.history['accuracy'][0]
    history.loc[epoch, 'val_accuracy'] = hist.history['val_accuracy'][0]

history.to_csv('history.csv', header=True, index=False)
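
The patience bookkeeping in the loop above can also be pulled into a small helper that lives outside the loop, so its state survives every single-epoch `model.fit()` call. The sketch below is only one possible arrangement: the class name `PatienceTracker`, its `update` method and the `min_delta` parameter are hypothetical, and the list of losses stands in for the `val_loss` values that `model.fit()` would return.

```python
import math


class PatienceTracker:
    """Hypothetical helper: tracks the best validation loss across separate fit() calls."""

    def __init__(self, patience, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.best_epoch = None
        self.wait = 0

    def update(self, epoch, val_loss):
        """Record one epoch's loss and return (improved, should_stop)."""
        if math.isnan(val_loss):
            return False, True  # loss diverged: stop immediately
        if val_loss < self.best - self.min_delta:
            self.best, self.best_epoch, self.wait = val_loss, epoch, 0
            return True, False
        self.wait += 1
        return False, self.wait >= self.patience


tracker = PatienceTracker(patience=2)
# Stand-in values; in the real loop each one would come from hist.history['val_loss'][0]
for epoch, val_loss in enumerate([0.9, 0.7, 0.8, 0.75, 0.6], start=1):
    improved, should_stop = tracker.update(epoch, val_loss)
    # In the real loop: save the model here when `improved` is True
    if should_stop:
        break

print(tracker.best, tracker.best_epoch, epoch)  # 0.7 2 4
```

In the training loop, `update` would be called right after `model.reset_states()`, with the model saved whenever `improved` is true and the loop left as soon as `should_stop` is true.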

Regarding "keras - Is there any way to use the Keras EarlyStopping callback when I use a stateful LSTM and reset its states?", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/57732296/
