python - 如何实现maxpool : taking a maximum on sliding window on image or tensor

标签 python numpy neural-network conv-neural-network array-broadcasting

简而言之:我正在寻找 Maxpool 的简单 numpy(也许是 oneliner)实现 - 在 上的窗口上最大化numpy.narray 用于跨维度窗口的所有位置。

更多细节:我正在实现一个卷积神经网络(“CNN”),这种网络中的典型层之一是 MaxPool 层(例如here)。写作 y = MaxPool(x, S)x是一个输入narrayS是一个参数,使用伪代码,MaxPool 的输出为:

     y[b,h,w,c] = max(x[b, s*h + i, s*w + j, c]) over i = 0,..., S-1; j = 0,...,S-1.

也就是说,ynarray,其中索引 b,h,w,c 的值等于窗口中的最大值沿输入 x 的第二和第三维大小 S x S,窗口“角”位于索引 b,h,w,c

一些额外的细节:网络是使用numpy实现的。 CNN 有许多“层”,其中一层的输出是下一层的输入。层的输入是 numpy.narray,称为“张量”。在我的例子中,张量是 4 维 numpy.narrayx。即x.shape是一个元组(B,H,W,C)。张量经过一层处理后,维度的每个大小都会发生变化,例如层 i= 4 的输入可以有大小 B = 10, H = 24, W = 24, C = 3,而输出,也就是 i+1 层的输入有 B = 10,H = 12,W = 12,C = 5。如注释中所示,应用 MaxPool 后的大小为 (B, H - S + 1, W - S + 1, C)

具体来说:如果我使用

import numpy as np

y = np.amax(x, axis = (1,2)) 

其中 x.shape 是说 (2,3,3,4) 这将给我我想要的,但对于我正在最大化窗口的退化情况over 的大小是 3 x 3x 的第二和第三维度的大小,这不是我想要的。

最佳答案

这是一个使用 np.lib.stride_tricks.as_strided 来创建滑动窗口的解决方案,从而产生一个形状为 6D 的数组:(B,H-S+ 1,W-S+1,S,S,C) 然后简单地沿第四和第五轴执行 max,得到形状的输出数组:(B,H-S+1, W-S+1,C)。中间 6D 数组将是输入数组的 View ,因此不会再占用内存。 max 作为缩减的后续操作将有效地利用滑动 views

因此,一个实现将是 -

# Based on http://stackoverflow.com/a/41850409/3293881
def patchify(img, patch_shape):
    a, X, Y, b = img.shape
    x, y = patch_shape
    shape = (a, X - x + 1, Y - y + 1, x, y, b)
    a_str, X_str, Y_str, b_str = img.strides
    strides = (a_str, X_str, Y_str, X_str, Y_str, b_str)
    return np.lib.stride_tricks.as_strided(img, shape=shape, strides=strides)

out = patchify(x, (S,S)).max(axis=(3,4))

sample 运行-

In [224]: x = np.random.randint(0,9,(10,24,24,3))

In [225]: S = 5

In [226]: np.may_share_memory(patchify(x, (S,S)), x)
Out[226]: True

In [227]: patchify(x, (S,S)).shape
Out[227]: (10, 20, 20, 5, 5, 3)

In [228]: patchify(x, (S,S)).max(axis=(3,4)).shape
Out[228]: (10, 20, 20, 3)

关于python - 如何实现maxpool : taking a maximum on sliding window on image or tensor,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41881638/

相关文章:

python - 汇总数据框的前 10 行

R - 神经网络 - 传统的反向传播看起来很奇怪

python - 从列中删除带冒号的单词 - 为什么它不起作用?

python - 如何使用 python groupby()

python - 混淆矩阵缺失实例

python - 挤压网络问题

javascript - 大脑.js : XOR example does not work

python - 使用 pelican 的不安全脚本

python - numpy 数组中非唯一行的快速组合,映射到列(即快速数据透视表问题,没有 Pandas)

Python:距离直线最近的点