我想要一段代码,如果它不存在,则在范围内创建一个变量,如果它已经存在,则访问该变量。我需要它是 same 代码,因为它将被多次调用。
但是,Tensorflow 需要我指定是要创建还是重用变量,如下所示:
with tf.variable_scope("foo"): #create the first time
v = tf.get_variable("v", [1])
with tf.variable_scope("foo", reuse=True): #reuse the second time
v = tf.get_variable("v", [1])
如何让它确定是自动创建还是重用它?即,我希望上面的两个代码块是 same 并让程序运行。
最佳答案
get_variable()
在创建新变量且未声明形状时,或在创建变量期间违反重用时会引发 ValueError
。因此,你可以试试这个:
def get_scope_variable(scope_name, var, shape=None):
with tf.variable_scope(scope_name) as scope:
try:
v = tf.get_variable(var, shape)
except ValueError:
scope.reuse_variables()
v = tf.get_variable(var)
return v
v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v')
assert v1 == v2
请注意,以下内容也有效:
v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v', [1])
assert v1 == v2
更新。新的 API 现在支持自动重用:
def get_scope_variable(scope, var, shape=None):
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
v = tf.get_variable(var, shape)
return v
关于python - TensorFlow 变量范围 : reuse if variable exists,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38545362/