python - tensorflow 中不同范围的子网权重共享

标签 python tensorflow

使用 tensorflow,我试图在不同变量范围内共享来自相同网络的相同权重以节省内存。但是,似乎没有简单的方法可以做到这一点。我准备了一个小代码示例,以在较小的规模上说明我想对较大的子网执行的操作:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    with tf.variable_scope("super_scope_one"):
        scope1 = tf.variable_scope("sub_scope_one")
        with scope1:
            number_one = tf.get_variable("number_one", shape=[1],
                                         initializer=tf.ones_initializer)
    with tf.variable_scope("super_scope_two"):
        with tf.variable_scope("sub_scope_one", reuse=True) as scope2:
            # Here is the problem.
            # scope1.reuse_variables() # this crashes too if reuse=None.
            number_one = tf.get_variable("number_one", shape=[1])
        with tf.variable_scope("sub_scope_two"):
            number_two = tf.get_variable("number_two", shape=[1],
                                         initializer=tf.ones_initializer)
        number_three = number_one + number_two

    init_op = tf.global_variables_initializer()

with tf.Session(graph=graph):
    init_op.run()
    print(number_three.eval())

有没有办法在不删除的情况下共享两个子作用域中的变量 上面的范围?如果不是,是否有充分的理由说明这将是一个坏主意?

最佳答案

您可以在 "super_scope_one" 中简单地定义一次 number_one 并在 "super_scope_two" 中使用它。

不同作用域的两个变量可以一起使用。见下文:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    with tf.variable_scope("super_scope_one"):
        scope1 = tf.variable_scope("sub_scope_one")
        with scope1:
            number_one = tf.get_variable("number_one", shape=[1],
                                         initializer=tf.ones_initializer)
    with tf.variable_scope("super_scope_two"):
        with tf.variable_scope("sub_scope_two"):
            number_two = tf.get_variable("number_two", shape=[1],
                                         initializer=tf.ones_initializer)
        number_three = number_one + number_two

    init_op = tf.global_variables_initializer()

    with tf.Session(graph=graph):
        init_op.run()
        print(number_three.eval())

返回 [2]

关于python - tensorflow 中不同范围的子网权重共享,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46585245/

相关文章:

python - 如何找出 python pandas dataframe 列(日期格式)中的空白?

python - 将 Flask 表单值转换为 int

python - 在Raspberry Pi上使用相机进行运动检测

tensorflow - 不了解类 UNET 架构中的数据流,并且对 Conv2DTranspose 层的输出有问题

python - Tensorflow GradientBoostedDecisionTreeClassifier错误: "Dense float feature must be a matrix"

Python检查列表项是否为整数?

python - 在python中将本地时间从UTC更改为UTC + 2

python - Pandas Dataframe - 向下移动行并维护数据

Tensorflow Lite GPU 对 python 的支持

python - 如何在 tensorflow 中复制 numpy.choose() ?