我正在将我的代码转换为 Tensorflow v2,但我不断收到以下错误:
AssertionError: Called a function referencing variables which have been deleted. This likely means that function-local variables were created and not referenced elsewhere in the program. This is generally a mistake; consider storing variables in an object attribute on first call.
这是一个重现错误的最小示例
import tensorflow as tf
class TEST:
def __init__(self, a=1):
self.a = tf.Variable(a)
@tf.function
def increment(self):
self.a = self.a + 1
return self.a
tst = TEST()
tst.increment()
我应该如何解决这个问题?
最佳答案
当你这样做时:
self.a = self.a + 1
您正在用该操作的结果覆盖 self.a
中的引用,该引用最初与上面创建的变量相关联。您没有更新 TensorFlow 变量的值,只是替换了 Python 引用。您正在创建的新张量(self.a + 1
的结果)反过来会在其计算中使用该变量。问题是,当 self.a
被覆盖时,这个变量就被遗忘了,不能再使用了。这有点像先有鸡还是先有蛋的问题,但是 tf.function
认为这是无效的。如果您想拥有变量并为其分配新值,请执行以下操作:
@tf.function
def increment(self):
self.a.assign(self.a + 1)
return self.a
或者只是他的:
@tf.function
def increment(self):
self.a.assign_add(1)
return self.a
关于python - tensorflow v2 中的变量赋值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57168947/