python - tensorflow v2 中的变量赋值

标签 python tensorflow

我正在将我的代码转换为 Tensorflow v2,但我不断收到以下错误:

AssertionError: Called a function referencing variables which have been deleted. This likely means that function-local variables were created and not referenced elsewhere in the program. This is generally a mistake; consider storing variables in an object attribute on first call.

这是一个重现错误的最小示例

import tensorflow as tf

class TEST:
    def __init__(self, a=1):
        self.a = tf.Variable(a)

    @tf.function
    def increment(self):
        self.a = self.a + 1
        return self.a

tst = TEST()
tst.increment()

我应该如何解决这个问题?

最佳答案

当你这样做时:

self.a = self.a + 1

您正在用该操作的结果覆盖 self.a 中的引用,该引用最初与上面创建的变量相关联。您没有更新 TensorFlow 变量的值,只是替换了 Python 引用。您正在创建的新张量(self.a + 1 的结果)反过来会在其计算中使用该变量。问题是,当 self.a 被覆盖时,这个变量就被遗忘了,不能再使用了。这有点像先有鸡还是先有蛋的问题,但是 tf.function 认为这是无效的。如果您想拥有变量并为其分配新值,请执行以下操作:

@tf.function
def increment(self):
    self.a.assign(self.a + 1)
    return self.a

或者只是他的:

@tf.function
def increment(self):
    self.a.assign_add(1)
    return self.a

关于python - tensorflow v2 中的变量赋值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57168947/

相关文章:

python - pybrain - ClassificationDataSet - 如何理解使用 SoftmaxLayer 时的输出

python - pandas.read_json() 未按预期工作

python - Keras InvalidArgumentError 未知输入节点

python - 验证前馈网络的有效性

python - TensorFlow tf.data.Dataset 和分桶

python - Google App Engine 应用程序是否可以交流或控制机器学习模型或任务?

python - Networkx:如何绘制彩色边缘?

python - 用于使用 Mac 在 python IDLE 中访问先前语句的键盘快捷方式

python - pycharm中有没有办法关闭运行窗口并使用键盘快捷键查看编辑器窗口

python - 为什么在 RTX 3070/cudnn8/CUDA11.1 上运行时添加卷积/池层会使 Keras/Tensorflow 模型崩溃?