python - 如何使用 Tensorflow 2/Keras 保存和继续训练具有多个模型部分的 GAN

标签 python tensorflow keras deep-learning generative-adversarial-network

我目前正在尝试添加一个功能来中断和恢复基于此示例代码创建的 GAN 的训练:https://machinelearningmastery.com/how-to-develop-an-auxiliary-classifier-gan-ac-gan-from-scratch-with-keras/
我设法让它以一种方式工作,我将整个复合 GAN 的权重保存在 summarise_performance 函数中,该函数每 10 个时期触发一次,如下所示:

# save all weights
filename3 = 'weights_%08d.h5' % (step+1)
gan_model.save_weights(filename3)
print('>Saved: %s and %s and %s' % (filename1, filename2, filename3))
它被加载到我添加到程序开头的一个名为 load_model 的函数中,该函数采用正常构建的 gan 架构,但将其权重更新为最新值,如下所示:
#load model from file and return startBatch number
def load_model(gan_model):
   start_batch = 0
   files = glob.glob("./weights_0*.h5")
   if(len(files) > 0 ):
       most_recent_file = files[len(files)-1]
       gan_model.load_weights(most_recent_file)
       #TODO: breaks if using more than 8 digits for batches
       startBatch = int(most_recent_file[10:18])
       if (start_batch != 0):
           print("> found existing weights; starting at batch %d" % start_batch)
   return start_batch
start_batch 被传递给 train 函数以跳过已经完成的时期。
虽然这种减轻重量的方法确实“有效”,但我仍然认为我在这里的方法是错误的,因为我发现权重数据显然不包括 GAN 的优化器状态,​​因此训练不会像它那样继续没有被打断。
我发现保存进度同时保存优化器状态的方法显然是通过保存整个模型而不仅仅是权重来完成的
在这里我遇到了一个问题,因为在 GAN 中,我不仅训练了一个模型,而且有 3 个模型:
  • 生成器模型 g_model
  • 判别器模型 d_model
  • 和复合 GAN 模型 gan_model

  • 这些都是相互联系和相互依赖的。如果我采用天真的方法并分别保存和恢复这些零件模型中的每一个,我最终会得到 3 个独立的脱节模型而不是 GAN
    有没有一种方法可以让我恢复训练,就好像没有发生中断一样,可以保存和恢复整个 GAN?

    最佳答案

    或许可以考虑使用 tf.train.Checkpoint , 如果您想恢复整个 GAN:

    ### In your training loop
    
    checkpoint_dir = '/checkpoints'
    checkpoint = tf.train.Checkpoint(gan_optimizer=gan_optimizer,
                                discriminator_optimizer=discriminator_optimizer,
                                      generator=generator,
                                      discriminator=discriminator
                                      gan_model = gan_model)
      
    ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
    if ckpt_manager.latest_checkpoint:
        checkpoint.restore(ckpt_manager.latest_checkpoint)  
        print ('Latest checkpoint restored!!')
    
    ....
    ....
    
    
    if (epoch + 1) % 40 == 0:
        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,ckpt_save_path))
    
    ### After x number of epochs, just save your generator model for inference.
    
    generator.save('your_model.h5')
    
    您也可以考虑完全摆脱复合模型。 Here是我的意思的一个例子。

    关于python - 如何使用 Tensorflow 2/Keras 保存和继续训练具有多个模型部分的 GAN,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69708977/

    相关文章:

    python - CountVectorizer 给出错误的单词计数?

    python - 将 Pandas 数据框 reshape 为与重复行一样多的列

    python - vscode 中的 pylint 导入错误,即使 python 执行器成功导入它

    python - Tensorflow:无法将 tf.case 与输入参数一起使用

    python - Tensorflow-gpu获取卷积算法失败

    machine-learning - Keras:连接不同模型的两层以创建新模型

    python - 用 python 生成/合成声音?

    python - 无法保存模型架构(bilstm+attention)

    python - 如何将特征单独输入到LSTM模型中

    python - 如何将变量添加到 Keras 中的进度条?