python - 定义自定义 Op theano 的 grad

标签 python theano pymc3

我正在尝试定义一个带有渐变的自定义 theano Op 以便将其与 pymc3 一起使用,但我不明白如何定义 grad 方法。

下面的代码是我陷入困境的地方。函数 phi() 是一个模拟函数(实际上,它是一个外部程序);对于标量输入x,它返回一个向量(phi_0(x), phi_1(x), ...)。函数 phi_diff()(也是一个模拟函数)返回向量 (dphi_0/dx, dphi_1/dx, ...)

我将 phi()phi_diff() 包装在 theano.Op 对象中,但我实现了 grad 功能不起作用。 theano 的文档包含更简单的示例,我不明白如何在这种情况下调整它们。任何帮助将不胜感激。

import numpy as np
import theano.tensor as T
import theano

theano.config.optimizer = "None"
theano.config.exception_verbosity = "high"


def phi(x):
    return np.arange(n) * x


def phi_diff(x):
    return np.arange(n)


class PhiOp(theano.Op):
    itypes = [theano.tensor.dscalar]
    otypes = [theano.tensor.dvector]

    def perform(self, node, inputs, output_storage):
        x = inputs[0]
        output_storage[0][0] = phi(x)

    def grad(self, inputs, output_grads):
        x = inputs[0]
        # ???
        return [PhiDiffOp()(x) * output_grads[0]]


class PhiDiffOp(theano.Op):
    itypes = [theano.tensor.dscalar]
    otypes = [theano.tensor.dvector]

    def perform(self, node, inputs, output_storage):
        x = inputs[0]
        output_storage[0][0] = phi_diff(x)


n = 5
x = 777.

phi_op = PhiOp()
x_tensor = T.dscalar("x_tensor")
phi_func = theano.function([x_tensor], phi_op(x_tensor))
np.testing.assert_allclose(phi_func(x), phi(x))

T.jacobian(phi_op(x_tensor), x_tensor)

最佳答案

找到解决方案,更改如下:

def phi_diff(x):
    return np.arange(n, dtype=np.float_)

class PhiOp(theano.Op):
    def grad(self, inputs, output_grads):
        x = inputs[0]
        gg = (PhiDiffOp()(x) * output_grads[0]).sum()
        return [gg]

关于python - 定义自定义 Op theano 的 grad,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52759259/

相关文章:

python - 你应该总是喜欢 xrange() 而不是 range() 吗?

python - 在 TestCase 执行期间 django.db.connection.cursor() SQL 查询从主数据库返回数据,而不是从测试数据库返回数据

python - 在Tensorflow中实现Theano运算

python - 安装和导入后,colaboratory 将不承认 arviz

python - 搜索没有 html 标签的漂亮的 soup 输出

python - 从多个 txt 文件中读取 - 剥离数据并保存到 xls

python - 在 Keras 模型中获取中间层输出的正确方法?

python - 错误 : NVCC compiler not in $path

python - 如何评估 pymc3 中的对数后验

python - 将 Theano 共享变量与 PyMC3 结合使用