numpy - 澄清 Theano 中的 flatten 函数

标签 numpy theano flatten conv-neural-network

在 [ http://deeplearning.net/tutorial/lenet.html#lenet]它说:

This will generate a matrix of shape (batch_size, nkerns[1] * 4 * 4),
# or (500, 50 * 4 * 4) = (500, 800) with the default values.
layer2_input = layer1.output.flatten(2)

当我在 numpy 3d 数组上使用 flatten 函数时,我得到一个 1D 数组。但在这里它说我得到了一个矩阵。 flatten(2) 在 theano 中如何工作?

numpy 上的类似示例生成一维数组:

     a= array([[[ 1,  2,  3],
    [ 4,  5,  6],
    [ 7,  8,  9]],

   [[10, 11, 12],
    [13, 14, 15],
    [16, 17, 18]],

   [[19, 20, 21],
    [22, 23, 24],
    [25, 26, 27]]])

   a.flatten(2)=array([ 1, 10, 19,  4, 13, 22,  7, 16, 25,  2, 11, 20,  5, 14, 23,  8, 17,
   26,  3, 12, 21,  6, 15, 24,  9, 18, 27])

最佳答案

numpy 不支持仅展平某些维度,但 Theano 支持。

因此,如果 a 是一个 numpy 数组,则 a.flatten(2) 没有任何意义。它运行没有错误,但这只是因为 2 作为 order 参数传递,这似乎导致 numpy 坚持 C 的默认顺序。

Theano 的 flatten 确实 支持轴规范。 The documentation解释它是如何工作的。

Parameters:
    x (any TensorVariable (or compatible)) – variable to be flattened
    outdim (int) – the number of dimensions in the returned variable

Return type:
    variable with same dtype as x and outdim dimensions

Returns:
    variable with the same shape as x in the leading outdim-1 dimensions,
    but with all remaining dimensions of x collapsed into the last dimension.

For example, if we flatten a tensor of shape (2, 3, 4, 5) with flatten(x, outdim=2), then we’ll have the same (2-1=1) leading dimensions (2,), and the remaining dimensions are collapsed. So the output in this example would have shape (2, 60).

一个简单的 Theano 演示:

import numpy
import theano
import theano.tensor as tt


def compile():
    x = tt.tensor3()
    return theano.function([x], x.flatten(2))


def main():
    a = numpy.arange(2 * 3 * 4).reshape((2, 3, 4))
    f = compile()
    print a.shape, f(a).shape


main()

打印

(2L, 3L, 4L) (2L, 12L)

关于numpy - 澄清 Theano 中的 flatten 函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34709950/

相关文章:

python - 有没有办法将 libgpuarray 与 Intel GPU 一起使用?

python - 为什么此方法会在参数数量方面引发错误?

search - 如何为一对多关系配置 Solr

arrays - Julia:扁平化数组/元组数组

python - 创建自定义图像数据集时出现 numpy 数组形状问题

numpy - 安装错误 : ftheader. h:没有那个文件或目录

python - 机器 epsilon 的倍数是什么意思?

python - 使用 Python 沿列插入二维矩阵

python - Keras 函数式 API 有什么特别之处?

python - 使用选择/忽略特定键来展平字典