python - TensorFlow 变量范围 : reuse if variable exists

标签 python tensorflow

我想要一段代码,如果它不存在,则在范围内创建一个变量,如果它已经存在,则访问该变量。我需要它是 same 代码,因为它将被多次调用。

但是,Tensorflow 需要我指定是要创建还是重用变量,如下所示:

with tf.variable_scope("foo"): #create the first time
    v = tf.get_variable("v", [1])

with tf.variable_scope("foo", reuse=True): #reuse the second time
    v = tf.get_variable("v", [1])

如何让它确定是自动创建还是重用它?即,我希望上面的两个代码块是 same 并让程序运行。

最佳答案

get_variable() 在创建新变量且未声明形状时,或在创建变量期间违反重用时会引发 ValueError。因此,你可以试试这个:

def get_scope_variable(scope_name, var, shape=None):
    with tf.variable_scope(scope_name) as scope:
        try:
            v = tf.get_variable(var, shape)
        except ValueError:
            scope.reuse_variables()
            v = tf.get_variable(var)
    return v

v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v')
assert v1 == v2

请注意,以下内容也有效:

v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v', [1])
assert v1 == v2

更新。新的 API 现在支持自动重用:

def get_scope_variable(scope, var, shape=None):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        v = tf.get_variable(var, shape)
    return v

关于python - TensorFlow 变量范围 : reuse if variable exists,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38545362/

相关文章:

tensorflow - 获取 SparseTensor 的非零行

python - 字符串化列表到两个单独的列表

python - 从不同版本的 tf.keras 加载保存的模型

python - 类型错误 : Input 'b' of 'MatMul' Op has type float32 that does not match type int32 of argument 'a'

python - Pandas :检查列的子集中的任何值是否符合条件

tensorflow - 不了解类 UNET 架构中的数据流,并且对 Conv2DTranspose 层的输出有问题

tensorflow - 如何在Tensorflow中使用多层双向LSTM?

python - BeautifulSoup 找不到标签 li

python - 从列表中增量删除元素

python - 为什么部分在 View 之外的矩形被绘制为三角形?