我有一个很大的NumPy.array
field_array
和一个较小的数组 match_array
, 均由 int
组成值。使用以下示例,我如何检查 field_array
的任何 match_array-shaped 段?包含与 match_array
中的值完全对应的值?
import numpy
raw_field = ( 24, 25, 26, 27, 28, 29, 30, 31, 23, \
33, 34, 35, 36, 37, 38, 39, 40, 32, \
-39, -38, -37, -36, -35, -34, -33, -32, -40, \
-30, -29, -28, -27, -26, -25, -24, -23, -31, \
-21, -20, -19, -18, -17, -16, -15, -14, -22, \
-12, -11, -10, -9, -8, -7, -6, -5, -13, \
-3, -2, -1, 0, 1, 2, 3, 4, -4, \
6, 7, 8, 4, 5, 6, 7, 13, 5, \
15, 16, 17, 8, 9, 10, 11, 22, 14)
field_array = numpy.array(raw_field, int).reshape(9,9)
match_array = numpy.arange(12).reshape(3,4)
这些例子应该返回 True
由于 match_array
描述的模式对齐 [6:9,3:7]
.
最佳答案
方法 #1
此方法源自 a solution
至 Implement Matlab's im2col 'sliding' in python
旨在将 slider 从 2D 数组重新排列到列中
。因此,为了解决我们这里的情况,可以将来自 field_array
的 slider 堆叠为列,并与 match_array
的列向量版本进行比较。
这是重新排列/堆叠函数的正式定义 -
def im2col(A,BLKSZ):
# Parameters
M,N = A.shape
col_extent = N - BLKSZ[1] + 1
row_extent = M - BLKSZ[0] + 1
# Get Starting block indices
start_idx = np.arange(BLKSZ[0])[:,None]*N + np.arange(BLKSZ[1])
# Get offsetted indices across the height and width of input array
offset_idx = np.arange(row_extent)[:,None]*N + np.arange(col_extent)
# Get all actual indices & index into input array for final output
return np.take (A,start_idx.ravel()[:,None] + offset_idx.ravel())
为了解决我们的问题,下面是基于 im2col
的实现 -
# Get sliding blocks of shape same as match_array from field_array into columns
# Then, compare them with a column vector version of match array.
col_match = im2col(field_array,match_array.shape) == match_array.ravel()[:,None]
# Shape of output array that has field_array compared against a sliding match_array
out_shape = np.asarray(field_array.shape) - np.asarray(match_array.shape) + 1
# Now, see if all elements in a column are ONES and reshape to out_shape.
# Finally, find the position of TRUE indices
R,C = np.where(col_match.all(0).reshape(out_shape))
问题中给定样本的输出将是 -
In [151]: R,C
Out[151]: (array([6]), array([3]))
方法 #2
鉴于 opencv 已经具有计算差异平方的模板匹配功能,您可以使用它并寻找零差异,这将是您的匹配位置。因此,如果您有权访问 cv2(opencv 模块),则实现将如下所示 -
import cv2
from cv2 import matchTemplate as cv2m
M = cv2m(field_array.astype('uint8'),match_array.astype('uint8'),cv2.TM_SQDIFF)
R,C = np.where(M==0)
给我们-
In [204]: R,C
Out[204]: (array([6]), array([3]))
基准测试
本部分比较了解决该问题的所有建议方法的运行时间。本节中列出的各种方法归功于它们的贡献者。
方法定义-
def seek_array(search_in, search_for, return_coords = False):
si_x, si_y = search_in.shape
sf_x, sf_y = search_for.shape
for y in xrange(si_y-sf_y+1):
for x in xrange(si_x-sf_x+1):
if numpy.array_equal(search_for, search_in[x:x+sf_x, y:y+sf_y]):
return (x,y) if return_coords else True
return None if return_coords else False
def skimage_based(field_array,match_array):
windows = view_as_windows(field_array, match_array.shape)
return (windows == match_array).all(axis=(2,3)).nonzero()
def im2col_based(field_array,match_array):
col_match = im2col(field_array,match_array.shape)==match_array.ravel()[:,None]
out_shape = np.asarray(field_array.shape) - np.asarray(match_array.shape) + 1
return np.where(col_match.all(0).reshape(out_shape))
def cv2_based(field_array,match_array):
M = cv2m(field_array.astype('uint8'),match_array.astype('uint8'),cv2.TM_SQDIFF)
return np.where(M==0)
运行时测试 -
案例 #1(来自问题的示例数据):
In [11]: field_array
Out[11]:
array([[ 24, 25, 26, 27, 28, 29, 30, 31, 23],
[ 33, 34, 35, 36, 37, 38, 39, 40, 32],
[-39, -38, -37, -36, -35, -34, -33, -32, -40],
[-30, -29, -28, -27, -26, -25, -24, -23, -31],
[-21, -20, -19, -18, -17, -16, -15, -14, -22],
[-12, -11, -10, -9, -8, -7, -6, -5, -13],
[ -3, -2, -1, 0, 1, 2, 3, 4, -4],
[ 6, 7, 8, 4, 5, 6, 7, 13, 5],
[ 15, 16, 17, 8, 9, 10, 11, 22, 14]])
In [12]: match_array
Out[12]:
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
In [13]: %timeit seek_array(field_array, match_array, return_coords = False)
1000 loops, best of 3: 465 µs per loop
In [14]: %timeit skimage_based(field_array,match_array)
10000 loops, best of 3: 97.9 µs per loop
In [15]: %timeit im2col_based(field_array,match_array)
10000 loops, best of 3: 74.3 µs per loop
In [16]: %timeit cv2_based(field_array,match_array)
10000 loops, best of 3: 30 µs per loop
案例 #2(更大的随机数据):
In [17]: field_array = np.random.randint(0,4,(256,256))
In [18]: match_array = field_array[100:116,100:116].copy()
In [19]: %timeit seek_array(field_array, match_array, return_coords = False)
1 loops, best of 3: 400 ms per loop
In [20]: %timeit skimage_based(field_array,match_array)
10 loops, best of 3: 54.3 ms per loop
In [21]: %timeit im2col_based(field_array,match_array)
10 loops, best of 3: 125 ms per loop
In [22]: %timeit cv2_based(field_array,match_array)
100 loops, best of 3: 4.08 ms per loop
关于python - 如何检查一个二维 NumPy 数组中是否包含特定的值模式?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/32531377/