python-2.7 - 如何在 TensorFlow 中提供自定义梯度

标签 python-2.7 tensorflow

我试图了解如何使用 @tf.custom_gradient TensorFlow 1.7 中可用的函数,用于提供向量相对于向量的自定义梯度。下面的代码是解决以下问题以获得dz/dx的最小工作示例.

y=轴
z=||y||2

Also, this attached image describes the solution as expected by manually calulation

如果我不使用 @tf.custom_gradient然后 TensorFlow 按预期提供所需的解决方案。我的问题是如何为 y=Ax 提供自定义渐变?我们知道dy/dx = A^T如上面的附件所示,其中显示了与 TensorFlow 输出匹配的计算步骤。

import tensorflow as tf

#I want to write custom gradient for this function f1
def f1(A,x):
    y=tf.matmul(A,x,name='y')
    return y

#for y= Ax, the derivative is: dy/dx= transpose(A)
@tf.custom_gradient
def f2(A,x):
    y=f1(A,x)
    def grad(dzByDy): # dz/dy = 2y reaches here correctly.
        dzByDx=tf.matmul(A,dzByDy,transpose_a=True) 
        return dzByDx
    return y,grad


x= tf.constant([[1.],[0.]],name='x')
A= tf.constant([ [1., 2.], [3., 4.]],name='A')

y=f1(A,x) # This works as desired
#y=f2(A,x) #This line gives Error


z=tf.reduce_sum(y*y,name='z')

g=tf.gradients(ys=z,xs=x)

with tf.Session() as sess:
    print sess.run(g)

最佳答案

由于您的功能 f2()有两个输入,你必须提供一个梯度流回它们中的每一个。你看到的错误:

Num gradients 2 generated for op name: "IdentityN" [...] do not match num inputs 3



不过,这无疑是相当神秘的。假设你永远不想计算 d y /d , 你可以只返回 None, dzByDx。下面的代码(已测试):
import tensorflow as tf

#I want to write custom gradient for this function f1
def f1(A,x):
    y=tf.matmul(A,x,name='y')
    return y

#for y= Ax, the derivative is: dy/dx= transpose(A)
@tf.custom_gradient
def f2(A,x):
    y=f1(A,x)
    def grad(dzByDy): # dz/dy = 2y reaches here correctly.
        dzByDx=tf.matmul(A,dzByDy,transpose_a=True) 
        return None, dzByDx
    return y,grad

x= tf.constant([[1.],[0.]],name='x')
A= tf.constant([ [1., 2.], [3., 4.]],name='A')

#y=f1(A,x) # This works as desired
y=f2(A,x) #This line gives Error

z=tf.reduce_sum(y*y,name='z')

g=tf.gradients(ys=z,xs=x)

with tf.Session() as sess:
    print sess.run( g )

输出:

[array([[20.], [28.]], dtype=float32)]



如预期的。

关于python-2.7 - 如何在 TensorFlow 中提供自定义梯度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50030026/

相关文章:

python - 如何保持从 Pandas DataFrame 到字典的行顺序

python 请求<响应[520]>

python - 使用 Python 转储 PostgreSQL 数据库模式

python-3.x - 属性错误 : module 'resource' has no attribute 'getpagesize'

python - ValueError : `class_weight` must contain all classes in the data. 类{1,2,3}存在于数据中但不存在于 `class_weight`

python - 如果没有值则最大值 - python

python - 是否可以异常返回?

python - 在纯 Tensorflow 中混合分类数据和连续数据

import - Tensorflow 0.8 import/export 输出张量问题

python - Tensorflow 中逻辑运算符的梯度