python - Why does TensorFlow's automatic differentiation fail when .numpy() is used in the loss function?

Tags: python numpy tensorflow

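I noticed that when the loss function converts its input to a NumPy array to compute the output value, TensorFlow's automatic differentiation does not give the same values as finite differences. Here is a minimal working example of the problem: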

import tensorflow as tf
import numpy as np

def lossFn(inputTensor):
    # Input is a rank-2 square tensor
    return tf.linalg.trace(inputTensor @ inputTensor)

def lossFnWithNumpy(inputTensor):
    # Same function, but converts the input to a NumPy array before computing the trace
    inputArray = inputTensor.numpy()

    return tf.linalg.trace(inputArray @ inputArray)

N = 2
tf.random.set_seed(0)
randomTensor = tf.random.uniform([N, N])

# Prove that the two functions give the same output; evaluates to exactly zero
print(lossFn(randomTensor) - lossFnWithNumpy(randomTensor)) 

theoretical, numerical = tf.test.compute_gradient(lossFn, [randomTensor])
# These two values match
print(theoretical[0])
print(numerical[0])

theoretical, numerical = tf.test.compute_gradient(lossFnWithNumpy, [randomTensor])
# The theoretical value is [0 0 0 0]
print(theoretical[0])
print(numerical[0])

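The function tf.test.compute_gradient computes the "theoretical" gradient using automatic differentiation and the numerical gradient using finite differences. As the code shows, if I use .numpy() in the loss function, automatic differentiation returns an all-zero gradient instead of the correct one.

Can anyone explain why?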

Best answer

From the guide: Introduction to Gradients and Automatic Differentiation

The tape can't record the gradient path if the calculation exits TensorFlow. For example:

x = tf.Variable([[1.0, 2.0],
                 [3.0, 4.0]], dtype=tf.float32)

with tf.GradientTape() as tape:   
  x2 = x**2
  # This step is calculated with NumPy   
  y = np.mean(x2, axis=0)
  # Like most ops, reduce_mean will cast the NumPy array to a constant tensor 
  # using `tf.convert_to_tensor`. 
  y = tf.reduce_mean(y,axis=0)

print(tape.gradient(y, x)) 

outputs None
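For comparison, here is a minimal sketch of the same calculation kept entirely inside TensorFlow. It is not part of the guide; it only assumes that swapping np.mean for tf.reduce_mean is acceptable for the computation at hand. Because every step is now a TensorFlow op, the tape can record the full path from x to y.

```python
import tensorflow as tf

x = tf.Variable([[1.0, 2.0],
                 [3.0, 4.0]], dtype=tf.float32)

with tf.GradientTape() as tape:
    x2 = x**2
    # Same reduction as np.mean(x2, axis=0), but recorded on the tape
    y = tf.reduce_mean(x2, axis=0)
    y = tf.reduce_mean(y, axis=0)

# y is the mean of all four squared entries, so dy/dx = x / 2
grad = tape.gradient(y, x)
print(grad)
```

This time the gradient is a tensor rather than None.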

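In your case, the NumPy value is converted back into a constant tensor when tf.linalg.trace is called. That constant has no recorded connection to the input tensor, so TensorFlow cannot compute a gradient through it.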
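A minimal sketch applying this to the loss from the question, using tf.GradientTape directly instead of tf.test.compute_gradient. The 2x2 input values are chosen arbitrarily for illustration. The pure TensorFlow version yields the analytic gradient of trace(X @ X), which is 2 * transpose(X), while the version that goes through .numpy() yields None.

```python
import tensorflow as tf

# Arbitrary example input (an assumption for illustration)
x = tf.Variable([[1.0, 2.0],
                 [3.0, 4.0]])

# Pure TensorFlow: every op is recorded on the tape
with tf.GradientTape() as tape:
    loss_tf = tf.linalg.trace(x @ x)
grad_tf = tape.gradient(loss_tf, x)

# Detour through NumPy: the matmul happens outside TensorFlow,
# so the result re-enters as a constant with no link back to x
with tf.GradientTape() as tape:
    x_np = x.numpy()
    loss_np = tf.linalg.trace(x_np @ x_np)
grad_np = tape.gradient(loss_np, x)

print(grad_tf)  # equals 2 * transpose(x)
print(grad_np)  # None
```

The practical fix is to keep every step between the input and the loss as TensorFlow ops, for example by dropping the .numpy() call in lossFnWithNumpy.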

A similar question about "python - Why does TensorFlow's automatic differentiation fail when .numpy() is used in the loss function?" can be found on Stack Overflow: https://stackoverflow.com/questions/64771324/
