python - 如何使用 TensorFlow 在 Returnn 中加载经过训练的网络的权重

标签 python tensorflow returnn

当尝试加载经过训练的多个时期保存的权重时 使用以下代码返回网络:

import tensorflow as tf
from returnn.Config import Config
from returnn.TFNetwork import TFNetwork

for i in range(1,11):
    modelFilePath = path/to/model/ + 'network.' + '%03d' % (i,)

    returnnConfig = Config()
    returnnConfig.load_file(path/to/configFile)
    returnnTfNetwork = TFNetwork(config=path/to/configFile, train_flag=False, eval_flag=True)

    returnnTfNetwork.construct_from_dict(returnnConfig.typed_value('network'))

    with tf.Session() as sess:
        returnnTfNetwork.load_params_from_file(modelFilePath, sess)

我收到以下错误:

Variables to restore which are not in checkpoint:
global_step_1

Variables in checkpoint which are not needed for restore:
global_step

Probably we can restore these:
(None)

Error, some entry is missing in the checkpoint

最佳答案

问题在于,您每次在循环中都重新创建 TFNetwork,并且每次都会为全局步骤创建一个新变量,该变量必须以不同的方式调用,因为每个变量都必须有一个唯一的名称。

你可以在循环内做这样的事情:

tf.reset_default_graph()

关于python - 如何使用 TensorFlow 在 Returnn 中加载经过训练的网络的权重,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51043758/

相关文章:

python - 即使 access_type=offline,OAuth2.0 也无法获取刷新 token

python-2.7 - keras自定义损失函数

javascript - 已移植到 TensorFlow.js 的预训练 TensorFlow 模型的权重是否可以在运行时通过模型对象访问?

python - 从 Tensorflow 转换 -> CoreML 3.0 用于插槽/意图检测

python - 从子网内部的层到该子网外部的层重用参数

python - 在 Python 中删除 USB 设备上的树

python - 如何将文本放在python图之外?

python - 我可以加入类(class)列表吗?