python - 如何计算 Tensorflow 中的所有二阶导数(仅 Hessian 矩阵的对角线)?

标签 python tensorflow

我有一个损失值/函数,我想计算关于张量f(大小为n)的所有二阶导数。我设法使用了 tf.gradients 两次,但在第二次应用它时,它对第一个输入的导数求和(请参阅我的代码中的 second_derivatives)。

我还设法检索了 Hessian 矩阵,但我只想计算它的对角线以避免额外计算。

import tensorflow as tf
import numpy as np

f = tf.Variable(np.array([[1., 2., 0]]).T)
loss = tf.reduce_prod(f ** 2 - 3 * f + 1)

first_derivatives = tf.gradients(loss, f)[0]

second_derivatives = tf.gradients(first_derivatives, f)[0]

hessian = [tf.gradients(first_derivatives[i,0], f)[0][:,0] for i in range(3)]

model = tf.initialize_all_variables()
with tf.Session() as sess:
    sess.run(model)
    print "\nloss\n", sess.run(loss)
    print "\nloss'\n", sess.run(first_derivatives)
    print "\nloss''\n", sess.run(second_derivatives)
    hessian_value = np.array(map(list, sess.run(hessian)))
    print "\nHessian\n", hessian_value

我的想法是 tf.gradients(first_derivatives, f[0, 0])[0] 可以检索例如关于 f_0 的二阶导数,但似乎 tensorflow 没有不允许从张量的切片中导出。

最佳答案

tf.gradients([f1,f2,f3],...) 计算 f=f1+f2+f3 的梯度 此外,区分 x[0] 是有问题的,因为 x[0] 指的是一个新的 Slice 节点,它不是你的损失,因此关于它的导数将是 None。你可以通过使用 packx[0], x[1], ... 粘合到 xx 中来绕过它,并让你的损失取决于 xx 而不是 x 。另一种方法是为各个组件使用单独的变量,在这种情况下计算 Hessian 看起来像这样。

def replace_none_with_zero(l):
  return [0 if i==None else i for i in l] 

tf.reset_default_graph()

x = tf.Variable(1.)
y = tf.Variable(1.)
loss = tf.square(x) + tf.square(y)
grads = tf.gradients([loss], [x, y])
hess0 = replace_none_with_zero(tf.gradients([grads[0]], [x, y]))
hess1 = replace_none_with_zero(tf.gradients([grads[1]], [x, y]))
hessian = tf.pack([tf.pack(hess0), tf.pack(hess1)])
sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())
print hessian.eval()

你会看到

[[ 2.  0.]
 [ 0.  2.]]

关于python - 如何计算 Tensorflow 中的所有二阶导数(仅 Hessian 矩阵的对角线)?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38200982/

相关文章:

python - 哪种模型/技术用于特定句子提取?

python - 网页抓取 - 不显示内容

python - 基于键和值的嵌套字典的条件递归搜索

python - Gmail API 插入带有附件的电子邮件不显示邮件列表显示中存在附件

python - 如何将值是列表的字典传递给函数

python - Tensorflow2.0 - 如何将张量转换为 numpy() 数组

tensorflow - strip_unused_nodes 的正确参数

python - Django : Testing admin user

python - 如何在 Resnet 50 分类中输出置信度?

python - 要保存的变量应该在字典或列表中传递