python - tensorflow 集线器 : module spec export with checkpoint path doesn't save all variables

标签 python tensorflow tensorflow-hub

我想用tensorflow训练GAN,然后将生成器和鉴别器导出为tensorflow_hub模块。
为此:
- 我用tensorflow定义我的GAN架构
- 训练它并保存检查点
- 使用不同的标签创建 module_spec,例如:
(set(), {'batch_size': 8, 'model': 'gen'})
({'bs8', 'gen'}, {'batch_size': 8, 'model': 'gen'})
({'bs8', 'disc'}, {'batch_size': 8, 'model': 'disc'})
- 使用我在训练期间保存的 checkpoint_path 在 tf_hub_path 处使用 module_spec 导出

然后,我可以使用以下命令加载我的生成器:

hub.Module(tf_hub_path, tags={"gen", "bs8"})

但是,当我尝试使用类似的命令加载鉴别器时:

hub.Module(tf_hub_path, tags={"disc", "bs8"})

我收到错误:

ValueError: Tensor discriminator/linear/bias is not found in b'/tf_hub/variables/variables' checkpoint {'generator/fc_noise/kernel': [2, 48], 'generator/fc_noise/bias': [48]}

因此,我得出的结论是,鉴别器中存在的变量未保存在磁盘上的模块中。我检查了我想象的不同错误来源:

  • 模块规范已正确定义。为此,我决定训练我的模型,创建模块规范并直接从该 module_spec 加载模块。这对于生成器和鉴别器来说效果很好。然后,我假设我的 module_spec 是正确的
  • 然后,我想知道检查点是否正确保存了图表中的所有变量。

    checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
    inspect_list = tf.train.list_variables(checkpoint_path)
    print(inspect_list)
    
    [('disc_step_1/beta1_power', []),
    ('disc_step_1/beta2_power', []),
    ('discriminator/linear/bias', [1]),
    ('discriminator/linear/bias/d_opt', [1]),
    ('discriminator/linear/bias/d_opt_1', [1]),
    ('discriminator/linear/kernel', [3, 1]),
    ('discriminator/linear/kernel/d_opt', [3, 1]),
    ('discriminator/linear/kernel/d_opt_1', [3, 1]),
    ('gen_step/beta1_power', []),
    ('gen_step/beta2_power', []),
    ('generator/fc_noise/bias', [48]),
    ('generator/fc_noise/bias/g_opt', [48]),
    ('generator/fc_noise/bias/g_opt_1', [48]),
    ('generator/fc_noise/kernel', [2, 48]),
    ('generator/fc_noise/kernel/g_opt', [2, 48]),
    ('generator/fc_noise/kernel/g_opt_1', [2, 48]),
    ('global_step', []),
    ('global_step_disc', [])]
    

    因此,我看到所有变量都正确保存在检查点内。只有与生成器相关的两个变量被正确导出到磁盘上的 tf hub 模块中。

最后,我认为我的错误来自:

module_spec.export(tf_hub_path, checkpoint_path=checkpoint_path)

从 checkpoint_path 导出变量时仅考虑标签“gen”。我还检查了 module.variable_map 和检查点路径中的列表变量之间的变量名称是否对应。这是带有标签“disc”的模块的变量映射:

print(module.variable_map)
{'discriminator/linear/bias': <tf.Variable 'module_8/discriminator/linear/bias:0' shape=(1,) dtype=float32>, 'discriminator/linear/kernel': <tf.Variable 'module_8/discriminator/linear/kernel:0' shape=(3, 1) dtype=float32>}

我有

  • tensorflow :1.13.1
  • tensorflow_hub:0.4.0
  • Python:3.5.2

感谢您的帮助

最佳答案

我找到了解决这个问题的方法,尽管我认为这不是最干净的方法:

当调用不带标签的 hub.Module 时,下一行代码默认定义模块:

(set(), {'batch_size': 8, 'model': 'gen'})

事实上,我意识到这组参数正在定义通过 module_spec.export 导出哪个图。它解释了为什么我在导入模块时能够访问生成器的变量,但不能访问鉴别器的变量。
因此,我决定默认使用这组参数:

(set(), {'batch_size': 8, 'model': 'both'})

并且,在 hub.create_module_spec 调用的 _module_fn 方法中,我将生成器和鉴别器的输入(分别是输出)定义为模型的输入(分别是输出)。因此,在导出 module_spec 时,我能够访问图表的所有变量。

关于python - tensorflow 集线器 : module spec export with checkpoint path doesn't save all variables,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56756705/

相关文章:

python - 更改 tensorflow 中张量的比例

tensorflow - tensorflow hub : Table not initialized 的问题

jquery - Django 自动完成功能与 django-ajax-selects

python - CNN 对所有输入数据预测相同的类别

python - 从 Selenium 导入 webdriver 不工作?

python - pytorch 中的 tf.cast 等价物?

python - 导入错误: cannot import name 'regex_replace'

python - 如何使用 math.atan2 函数检测一个对象是否在另一个对象的 "line of sight"中?

python - 稍后如何获取长时间运行的 Google Cloud Speech API 操作的结果?