python - 在 numpy 中,如何有效地列出所有固定大小的子矩阵?

标签 python numpy matrix

我有一个任意的 NxM 矩阵,例如:

1 2 3 4 5 6
7 8 9 0 1 2
3 4 5 6 7 8
9 0 1 2 3 4

我想得到这个矩阵中所有 3x3 子矩阵的列表:

1 2 3       2 3 4               0 1 2
7 8 9   ;   8 9 0   ;  ...  ;   6 7 8
3 4 5       4 5 6               2 3 4

我可以用两个嵌套循环来做到这一点:

rows, cols = input_matrix.shape
patches = []
for row in np.arange(0, rows - 3):
    for col in np.arange(0, cols - 3):
        patches.append(input_matrix[row:row+3, col:col+3])

但是对于一个大的输入矩阵,这是很慢的。 有没有办法用 numpy 更快地做到这一点?

我查看了 np.split,但这给了我非重叠子矩阵,而我想要所有可能的子矩阵,无论重叠如何。

最佳答案

你想要一个窗口 View :

from numpy.lib.stride_tricks import as_strided

arr = np.arange(1, 25).reshape(4, 6) % 10
sub_shape = (3, 3)
view_shape = tuple(np.subtract(arr.shape, sub_shape) + 1) + sub_shape
arr_view = as_strided(arr, view_shape, arr.strides * 2
arr_view = arr_view.reshape((-1,) + sub_shape)

>>> arr_view
array([[[[1, 2, 3],
         [7, 8, 9],
         [3, 4, 5]],

        [[2, 3, 4],
         [8, 9, 0],
         [4, 5, 6]],

        ...

        [[9, 0, 1],
         [5, 6, 7],
         [1, 2, 3]],

        [[0, 1, 2],
         [6, 7, 8],
         [2, 3, 4]]]])

这样做的好处是您没有复制任何数据,您只是以不同的方式访问原始数组的数据。对于大型阵列,这可以节省大量内存。

关于python - 在 numpy 中,如何有效地列出所有固定大小的子矩阵?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/19414673/

相关文章:

javascript - 来自正方形的矩阵与 JavaScript

python - tkinter 和 asyncio,窗口拖动/调整大小阻止事件循环,单线程

python - 将二维数组乘以一维数组

python - Pandas : How to get groups of each n rows after row matching query?

python-3.x - sklearn : Pandas Dataframe vs Numpy ndarray - Which is more efficient to hold a [600k * 1k] data of different data types

复制矩阵时出现 C++ 段错误

python - 如何以最佳方式调整我的面部识别常数(即阈值)?

python - 图形解析错误

python - 如何使用 python mechanize 发送有效负载

python - 如果数字为 1,如何按值检查矩阵值