python - 为什么这段Theano代码运行成功没有任何错误?

标签 python numpy theano deep-learning

我从在线教程中借用了以下代码。我看到下面这行写在代码的主要方法中

c = broadcasted_add(a, b)

是添加维度 (2,1,2,2) 的张量“a”和维度 (2,2,2,2) 的张量“b”。即使我们在 make_tensor 方法中将 broadcastable 声明为“false”,它如何能够正确添加?我们不应该将 broadcastable 声明为 True 以便它可以添加不同的维度吗?它不应该抛出一个错误说尺寸不匹配吗?我对可广播的理解有误吗?

import numpy as np
from theano import function
import theano.tensor as T

def make_tensor(dim):
    """
    Returns a new Theano tensor with no broadcastable dimensions.
    dim: the total number of dimensions of the tensor.
    """

    return T.TensorType(broadcastable=tuple([False] * dim), dtype='float32')()

def broadcasted_add(a, b):
    """
    a: a 3D theano tensor
    b: a 4D theano tensor
    Returns c, a 4D theano tensor, where
    c[i, j, k, l] = a[l, k, i] + b[i, j, k, l]
    for all i, j, k, l
    """

return a.dimshuffle(2, 'x', 1, 0) + b

def partial_max(a):
    """
    a: a 4D theano tensor
    Returns b, a theano matrix, where
    b[i, j] = max_{k,l} a[i, k, l, j]
    for all i, j
    """

return a.max(axis=(1, 2))

if __name__ == "__main__":
    a = make_tensor(3)
    b = make_tensor(4)
    c = broadcasted_add(a, b)
    d = partial_max(c)

    f = function([a, b,], d)

    rng = np.random.RandomState([1, 2, 3])
    a_value = rng.randn(2, 2, 2).astype(a.dtype)
    b_value = rng.rand(2, 2, 2, 2).astype(b.dtype)
    c_value = np.transpose(a_value, (2, 1, 0))[:, None, :, :] + b_value
    expected = c_value.max(axis=1).max(axis=1)

    actual = f(a_value, b_value)

    assert np.allclose(actual, expected), (actual, expected)
    print "SUCCESS!"

最佳答案

这样做的原因是 dimshuffle 通过 'x' 参数值添加的新维度总是可广播的。

请注意,在 broadcasted_add 中,唯一需要广播的维度是通过 dimshuffle 添加到 a 的维度。其他维度都不需要广播。

关于python - 为什么这段Theano代码运行成功没有任何错误?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/33101410/

相关文章:

python - 如果我们使用 Django 通用 View ,如何发送成功消息

python - 获取 keras 预测的真实标签

python - 如何有效地使用索引数组作为掩码将 numpy 数组转换为 bool 数组?

python - 动态追加N维数组

python - 在同一台计算机上安装 Python 2.x 和 python 3.x

Python 行继续字符后出现意外字符

python - 多线程 Python 应用程序和套接字连接的问题

python - 使用 numpy reshape 数组 - ValueError : cannot reshape array

python - 从 GPU 核心/线程的角度理解 Theano 示例

python - 在theano中扫描不同维度的张量