python - 在展平参数张量上使用 tf.gradients 或 tf.hessians

标签 python tensorflow deep-learning hessian-matrix

假设我想根据某些参数 W(例如,前馈神经网络的权重和偏差)计算标量值函数的 Hessian 矩阵。 如果考虑以下代码,实现经过训练以最小化 MSE 损失的二维线性模型:

import numpy as np
import tensorflow as tf

x = tf.placeholder(dtype=tf.float32, shape=[None, 2])  #inputs
t = tf.placeholder(dtype=tf.float32, shape=[None,])  #labels
W = tf.placeholder(np.eye(2), dtype=tf.float32)  #weights

preds = tf.matmul(x, W)  #linear model
loss = tf.reduce_mean(tf.square(preds-t), axis=0) #mse loss

params = tf.trainable_variables() 
hessian = tf.hessians(loss, params)

您希望 session.run(tf.hessian,feed_dict={}) 返回一个 2x2 矩阵(等于 W)。事实证明,因为 params 是一个 2x2 张量,所以输出是一个形状为 [2, 2, 2, 2] 的张量。虽然我可以很容易地 reshape 张量以获得我想要的矩阵,但当 params 成为不同大小的张量列表时(即当模型是深度神经网络时),似乎这个操作可能会非常麻烦例如)。

似乎有两种解决方法:

  • params 展平为名为 flat_params 的一维张量:

    flat_params = tf.concat([tf.reshape(p, [-1]) for p in params])
    

    因此 tf.hessians(loss, flat_params) 自然返回一个 2x2 矩阵。然而,如 Why does Tensorflow Reshape tf.reshape() break the flow of gradients? 中所述对于 tf.gradients(但也适用于 tf.hessians),tensorflow 无法在图表中看到 paramsflat_paramstf 之间的符号链接(symbolic link)。 hessians(loss, flat_params) 将引发错误,因为梯度将被视为 None

  • https://afqueiruga.github.io/tensorflow/2017/12/28/hessian-mnist.html ,代码作者反其道而行之,首先创建平面参数并将其部分 reshape 为self.params。这个技巧确实有效,并为您提供了具有预期 形状(2x2 矩阵)的粗麻布。然而,在我看来,当你有一个复杂的模型时,这会很麻烦,如果你通过内置函数(比如 tf.layers.dense,. .).

self.params 是一个任意形状的张量列表?如果不是,您如何自动 reshape tf.hessians 的输出张量?

最佳答案

事实证明(根据 TensorFlow r1.13)如果 len(xs) > 1,则 tf.hessians(ys, xs) 返回仅对应于完整 Hessian 矩阵的 block 对角子矩阵的张量。本文中的完整故事和解决方案 https://arxiv.org/pdf/1905.05559 , 代码在 https://github.com/gknilsen/pyhessian

关于python - 在展平参数张量上使用 tf.gradients 或 tf.hessians,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51948171/

相关文章:

c++ - 使用 SWIG 在 Python 中包装 C++ 类

tensorflow - 如何重新排列张量中的元素,就像在 MATLAB 中一样?

python - 在 python 中,测试字符串的 1 或 3 个连续实例,但不是 2 个(没有正则表达式)

python - 如何将多个 pandas 数据框列汇总为父列名称?

python - Apache Zeppelin 问题 - Python 错误

python - tensorflow 中的参数值

python - 在 tensorflow 中嵌套控制依赖上下文

machine-learning - 在 MXNet 上运行 Ptr-Net

python - 咖啡乐网 : Difference between `solver.step(1)` and `solver.net.forward()`

machine-learning - 为什么深度残差网络中的每个 block 都有两个卷积层而不是一个?