python - 从 3D numpy 数组中选择多个补丁

标签 python arrays numpy 3d

我有一个 3D numpy 数组,大小为 50x50x4。我还有 50x50 平面上几个点的位置。对于每个点,我需要提取一个以该点为中心的 11x11x4 区域。如果该区域与边界重叠,则必须环绕。请问最有效的方法是什么?

我目前正在使用 for 循环来迭代每个点,对 3D 矩阵进行子集化,并将其存储在预初始化数组中。是否有内置的 numpy 函数可以执行此操作?谢谢。

<小时/>

抱歉回复慢,非常感谢大家的意见。

最佳答案

一种方法是使用 np.pad沿最后一个轴具有包裹功能。然后,我们将使用 np.lib.stride_tricks.as_strided 在此填充版本上创建滑动窗口,它是填充数组的 View ,不会再占用内存。最后,我们将索引到滑动窗口以获得最终输出。

# Based on http://stackoverflow.com/a/41850409/3293881
def patchify(img, patch_shape): 
    X, Y, a = img.shape
    x, y = patch_shape
    shape = (X - x + 1, Y - y + 1, x, y, a)
    X_str, Y_str, a_str = img.strides
    strides = (X_str, Y_str, X_str, Y_str, a_str)
    return np.lib.stride_tricks.as_strided(img, shape=shape, strides=strides)

def sliding_patches(a, BSZ):
    hBSZ = (BSZ-1)//2
    a_ext = np.dstack(np.pad(a[...,i], hBSZ, 'wrap') for i in range(a.shape[2]))
    return patchify(a_ext, (BSZ,BSZ))

示例运行 -

In [51]: a = np.random.randint(0,9,(4,5,2)) # Input array

In [52]: a[...,0]
Out[52]: 
array([[2, 7, 5, 1, 0],
       [4, 1, 2, 0, 7],
       [1, 3, 0, 8, 4],
       [8, 0, 5, 2, 7]])

In [53]: a[...,1]
Out[53]: 
array([[0, 3, 3, 8, 7],
       [3, 8, 2, 8, 2],
       [8, 4, 3, 8, 7],
       [6, 6, 8, 5, 5]])

现在,让我们在 a 中选择一个中心点,假设 (1,0) 并尝试获取 blocksize (BSZ) = 3 围绕它 -

In [54]: out = sliding_patches(a, BSZ=3) # Create sliding windows

In [55]: out[1,0,...,0]  # patch centered at (1,0) for slice-0
Out[55]: 
array([[0, 2, 7],
       [7, 4, 1],
       [4, 1, 3]])

In [56]: out[1,0,...,1]  # patch centered at (1,0) for slice-1
Out[56]: 
array([[7, 0, 3],
       [2, 3, 8],
       [7, 8, 4]])

因此,围绕 (1,0) 获取补丁的最终输出将很简单: out[1,0,...,:]输出[1,0]

无论如何,让我们对原始形状数组进行形状检查 -

In [65]: a = np.random.randint(0,9,(50,50,4))

In [66]: out = sliding_patches(a, BSZ=11)

In [67]: out[1,0].shape
Out[67]: (11, 11, 4)

关于python - 从 3D numpy 数组中选择多个补丁,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41864715/

相关文章:

python - 如何在Tensorflow中删除张量中的重复值?

python - 判断一个句子是英文概率的比较简单的方法是什么?

c++ - 如何有效地将大数存储在整数数组中? C++

C:让用户一个字符一个字符地填充一个char数组。然后打印出来

java - 使用另一个类中的方法将数字添加到数组中?

python - 快速、优雅的方法来计算经验/样本协方差图

python - 如何以干净有效的方式在pytorch中获得小批量?

python - 在具有负斜率的曲线上插值数据点

python - 为什么 numpy ma.average 比 arr.mean 慢 24 倍?

python - 如何将平流扩散 react 偏微分方程与 FiPy 耦合