python - Keras LSTM 保存后继续训练

标签 python tensorflow keras

我正在研究 LSTM 模型,我想保存它并在以后继续处理它积累的额外数据。 我的问题是,保存模型并在下次运行脚本时再次加载它后,预测完全错误,它只是模仿了我输入的数据。

这是模型初始化:

# create and fit the LSTM network
if retrain == 1:
    print "Creating a newly retrained network."
    model = Sequential()
    model.add(LSTM(inputDimension, input_shape=(1, inputDimension)))
    model.add(Dense(inputDimension, activation='relu'))
    model.compile(loss='mean_squared_error', optimizer='adam')
    model.fit(trainX, trainY, epochs=epochs, batch_size=batch_size, verbose=2)
    model.save("model.{}.h5".format(interval))
else:
    print "Using an existing network."
    model = load_model("model.{}.h5".format(interval))
    model.compile(loss='mean_squared_error', optimizer='adam')
    model.fit(trainX, trainY, epochs=epochs, batch_size=batch_size, verbose=2)
    model.save("model.{}.h5".format(interval))
    del model
    model = load_model("model.{}.h5".format(interval))
    model.compile(loss='mean_squared_error', optimizer='adam')

第一个数据集,当 retrain 设置为 1 时,大约有 10000 个条目,大约 3k 个时期和 5% 的批量大小。 第二个数据集是单项数据。就像在一行中一样,又是 3k 个时期和 batch_size=1

已解决

我错误地重新加载了定标器:

scaler = joblib.load('scaler.{}.data'.format(interval))
dataset = scaler.fit_transform(dataset)

正确:

scaler = joblib.load('scaler.{}.data'.format(interval))
dataset = scaler.transform(dataset)

fit_transform 重新计算缩放值的乘数,这意味着原始数据会有偏移量。

最佳答案

来自函数keras model api对于 model.fit():

initial_epoch: Integer. Epoch at which to start training (useful for resuming a previous training run).

设置此参数可能会解决您的问题。

我认为问题的根源是来自 adam 的自适应学习率。在训练期间,为了对模型进行更多微调,学习率会自然下降。当您仅使用一个样本重新训练您的模型时,权重更新太大(因为重新设置了学习率),这可能会完全破坏您之前的权重。

如果 initial_epoch 不好,则尝试以较低的学习率开始第二次训练。

关于python - Keras LSTM 保存后继续训练,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50253014/

相关文章:

python - 使用python转换从ntp服务器检索到的日期

python - 使用 Lab 时,transforms.Normalize() 介于 0 和 1 之间

python-3.x - 尝试使用 Keras 中的 model.predict() 时尺寸错误

tensorflow - 带有图像和标量的 keras 生成器

python - 如何在 keras 自定义回调中访问 tf.data.Dataset?

python UnicodeWarning : Unicode equal comparison. 如何解决这个错误?

python - 如何通过 Yocto/poky 在 Jetson Nano 上使用带摄像头的 OpenCV

python - 如何在 tensorflow 中保存训练好的模型?

tensorflow - 如何 : Import TensorFlow in Jupyter Notebook from Conda with GPU support?

python reversed(list) 和 list.sort(reverse=True) 的区别