python - 保存和加载模型优化器状态

标签 python tensorflow machine-learning keras

我正在训练一组相当复杂的模型,并且正在寻找一种保存和加载模型优化器状态的方法。 “训练者模型”由其他几个“权重模型”的不同组合组成,其中一些模型具有共享权重,一些模型根据训练者的不同而卡住权重等等。这个例子有点太复杂了,无法分享,但总之,我无法使用model.save('model_file.h5')keras.models.load_model('model_file.h5')当我停止和开始训练时。

使用model.load_weights('weight_file.h5')如果训练完成,则可以很好地测试我的模型,但如果我尝试使用此方法继续训练模型,则损失甚至不会接近返回到其最后位置。我读到这是因为优化器状态没有使用这种方法保存,这是有意义的。但是,我需要一种方法来保存和加载训练模型优化器的状态。好像 keras 曾经有一个 model.optimizer.get_sate()model.optimizer.set_sate()这将完成我所追求的目标,但情况似乎不再如此(至少对于 Adam 优化器而言)。当前的 Keras 还有其他解决方案吗?

最佳答案

您可以从 load_modelsave_model 函数中提取重要的行。

用于保存优化器状态,​​在save_model中:

# Save optimizer weights.
symbolic_weights = getattr(model.optimizer, 'weights')
if symbolic_weights:
    optimizer_weights_group = f.create_group('optimizer_weights')
    weight_values = K.batch_get_value(symbolic_weights)

用于加载优化器状态,​​在 load_model 中:

# Set optimizer weights.
if 'optimizer_weights' in f:
    # Build train function (to get weight updates).
    if isinstance(model, Sequential):
        model.model._make_train_function()
    else:
        model._make_train_function()

    # ...

    try:
        model.optimizer.set_weights(optimizer_weight_values)

结合上面的行,这是一个示例:

  1. 首先将模型拟合 5 个时期。
X, y = np.random.rand(100, 50), np.random.randint(2, size=100)
x = Input((50,))
out = Dense(1, activation='sigmoid')(x)
model = Model(x, out)
model.compile(optimizer='adam', loss='binary_crossentropy')
model.fit(X, y, epochs=5)

Epoch 1/5
100/100 [==============================] - 0s 4ms/step - loss: 0.7716
Epoch 2/5
100/100 [==============================] - 0s 64us/step - loss: 0.7678
Epoch 3/5
100/100 [==============================] - 0s 82us/step - loss: 0.7665
Epoch 4/5
100/100 [==============================] - 0s 56us/step - loss: 0.7647
Epoch 5/5
100/100 [==============================] - 0s 76us/step - loss: 0.7638
  • 现在保存权重和优化器状态。
  • model.save_weights('weights.h5')
    symbolic_weights = getattr(model.optimizer, 'weights')
    weight_values = K.batch_get_value(symbolic_weights)
    with open('optimizer.pkl', 'wb') as f:
        pickle.dump(weight_values, f)
    
  • 在另一个 Python session 中重建模型,并加载权重。
  • x = Input((50,))
    out = Dense(1, activation='sigmoid')(x)
    model = Model(x, out)
    model.compile(optimizer='adam', loss='binary_crossentropy')
    
    model.load_weights('weights.h5')
    model._make_train_function()
    with open('optimizer.pkl', 'rb') as f:
        weight_values = pickle.load(f)
    model.optimizer.set_weights(weight_values)
    
  • 继续模型训练。
  • model.fit(X, y, epochs=5)
    
    Epoch 1/5
    100/100 [==============================] - 0s 674us/step - loss: 0.7629
    Epoch 2/5
    100/100 [==============================] - 0s 49us/step - loss: 0.7617
    Epoch 3/5
    100/100 [==============================] - 0s 49us/step - loss: 0.7611
    Epoch 4/5
    100/100 [==============================] - 0s 55us/step - loss: 0.7601
    Epoch 5/5
    100/100 [==============================] - 0s 49us/step - loss: 0.7594
    

    关于python - 保存和加载模型优化器状态,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49503748/

    相关文章:

    python - 模型有多个输出时的ModelCheckpoint监控值

    java - 为什么在 WEKA 中读取此 ARFF 文件时会出现过早的 EOF?

    python - pytorch:来自两个网络时损失的表现如何?

    python - 我想了解为什么会出现这种特殊的 tensorflow 警告......?

    python - Tensorflow 新手 - 尝试将 MNIST 多层网络重新用作计算器

    javascript - 为什么我需要删除Child表单输入来调用自定义GAE API

    python - 用向量/列表元素展平数据框python

    python - 如何从文本文件中读取非英语文本并在 python 中打印?

    python - 通过node-red将Python连接到网页

    python - 使用 argmax 在 tensorflow 中切片张量