python - 显式清除/重置嵌套 TensorFlow Graph 范围

标签 python tensorflow

所以,我正在使用 OpenAI baselines 中的一堆函数用于强化学习。在这些函数中,策略网络使用如下语句进行初始化:

with tf.variable_scope('deepq', reuse=True):
    ...
    return output

问题在于,指向这些网络输出的指针仍在范围内返回,这意味着当从另一个 .py 文件访问这些函数时,我仍然在这些范围内。

基本上,我想运行第一个函数train_policy(output_dir),它训练网络并使用tf.Saver()将检查点转储到磁盘。 接下来,我运行一个函数 run_policy(output_dir),该函数重新初始化相同的 tf 图并使用检查点目录加载其预训练值。

现在,当我尝试这个时,我收到一个 ValueError: “变量 deepq/... 已存在,不允许。您的意思是在 VarScope 中设置reuse=True 或reuse=tf.AUTO_REUSE 吗?” 因为在运行时第二个函数,我仍然在第一个函数定义的范围内。我检查了 code来自 OpenAI 基线(非常嵌套的代码,很难看到正在发生的一切),并且重用已设置为 True

所以我尝试做类似的事情:

tf.get_default_session().close() 后跟:

tf.reset_default_graph()

第一次函数调用之后。 (我不需要 session 保持事件状态,因为我将所有内容转储到磁盘)

但这给了我错误,因为我仍然在嵌套图形范围内,所以我无法重置默认图形...(参见例如 here )

或者我尝试了以下方法:

tf.get_default_graph().as_graph_def().__exit__() 

tf.name_scope('deepq').__exit__()

但是 exit() 函数需要一大堆我不知道如何获取的参数...(而且我找不到关于如何使用此函数的好 documentation)。

我当前的解决方案是在 Python 中的单独子进程中运行这些函数(并让垃圾收集器完成所有工作),但这感觉不是一个令人满意的解决方案。

关于如何处理这个问题有什么想法吗?理想情况下,我需要类似的东西: tf.clear_all_graphs_and_sessions()

最佳答案

一种解决方案确实是重置默认图表: 我只是将每个函数调用包装在一个新的默认图形对象中,如下所示:

with tf.Graph().as_default():
  train_policy(output_dir)

with tf.Graph().as_default():
   run_policy(output_dir)

...

这样,默认图表就会重新初始化为空,您可以加载检查点文件中的任何内容。 (在每个函数中,我还会在返回之前关闭默认 session )。

关于python - 显式清除/重置嵌套 TensorFlow Graph 范围,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48115096/

相关文章:

javascript - 带有 HTMLVideoElement 的 Tensorflow.js Posenet 返回位置 0,0

python - Django 将静态值注释到查询集

python - 为什么数组的形状不正确?

tensorflow - keras 使用 tensorflow 作为后端 :Cannot interpret feed_dict key as Tensor: Can not convert a int into a Tensor

python - 将 keras 模型权重直接保存到字节/内存?

python - 使用批归一化层创建顺序模型会卡住程序

python - Appengine - 从标准数据库升级到导航台 - ReferenceProperties

python - PyQt:QDateComboBox setDate 使用字符串?

python - 无需多次迭代即可从文件中获取数据

python - TensorFlow tf.reshape Fortran 命令(像 numpy)