tensorflow - 如何将变量复制到 tensorflow 中的另一个图

标签 tensorflow

我希望将 tensorflow 变量从旧图复制到新图,然后删除旧图并将新图设置为默认图。下面是我的代码,但它引发了一个 AttributeError: 'Graph' object has no attribute 'variable1'。我是 tensorflow 新手。谁能给我一个具体的例子吗?

import tensorflow as tf
import numpy as np

graph1 = tf.Graph()
with graph1.as_default():
    variable1 = tf.Variable(np.array([2.,-1.]), name='x1')
    initialize = tf.initialize_all_variables()

sess1 = tf.Session(graph=graph1)
sess1.run(initialize)
print('sess1:',variable1.eval(session=sess1))

graph2 = tf.Graph()
with graph2.as_default():
    variable2 = tf.contrib.copy_graph.copy_variable_to_graph(graph1.variable1, graph2)
sess2 = tf.Session(graph=graph2)
#I want to remove graph1 and sess1, and make graph2 and sess2 the default here.
print('sess2:', variable2.eval(session=sess2))

最佳答案

  1. tf.initialize_all_variables() 已弃用。请改用 tf.global_variables_initializer()

  2. 您不需要graph1.variable1。只需传递 variable1

  3. 您忘记在第二个 session 中初始化变量:

    initialize2 = tf.global_variables_initializer()
    sess2=tf.Session(graph=graph2)
    sess2.run(initialize2)
    

所以你的代码应该是这样的:

import tensorflow as tf
import numpy as np

graph1 = tf.Graph()
with graph1.as_default():
    variable1 = tf.Variable(np.array([2.,-1.]), name='x1')
    initialize = tf.global_variables_initializer()
    sess1=tf.Session(graph=graph1)
    sess1.run(initialize)
    print('sess1:',variable1.eval(session=sess1))

graph2 = tf.Graph()
with graph2.as_default():
    variable2=tf.contrib.copy_graph.copy_variable_to_graph(variable1,graph2)
    initialize2 = tf.global_variables_initializer()
    sess2=tf.Session(graph=graph2)
    sess2.run(initialize2)
    #I want to remove graph1 and sess1, ande make graph2 and sess2 as default here.
    print('sess2:',variable2.eval(session=sess2))

关于tensorflow - 如何将变量复制到 tensorflow 中的另一个图,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45273461/

相关文章:

python - 如何将 perreplica 转换为张量?

TensorFlow:numpy.repeat() 替代方案

python - Windows 10 上的 TensorFlow : “not a supported wheel on this platform” error

python - Tensorflow:静态和动态形状

python - 使用 Tensorflow 在二元分类中改变准确度值并且不改变损失值

python - 如何在带有tensorflow v2.x后端的keras中加载带有tensorflow v1.x后端的keras模型?

python - 如何在 TensorFlow 中使用索引数组?

tensorflow - 如何在 Tensorflow 中获取 CNN 内核值

machine-learning - Tensorflow:启动新 session 时出现扭矩和 GPU 问题:CUDA_ERROR_INVALID_DEVICE

python - Tensorflow 上的 GPU 是否进行了非极大值抑制?