为了标准化区域提案算法(即,对图像的每个 X-y-Y 区域应用回归),我需要在对每个提案的激活求和时创建区域提案标准化。目前,对于图像的 128x128 补丁,我在 Python 中运行这段代码
region_normalization = np.zeros(image.shape)
for x in range(0,image.shape[0]-128):
for y in range(0,image.shape[0]-128):
region_normalization[x:x+128,y:y+128] =
np.add(region_normalization[x:x+128,y:y+128],1)`
但这效率特别低。该算法的更快和/或更Python化的实现是什么?
谢谢!
最佳答案
对其进行逆向工程!
好吧,让我们看一下小图像和较小的 N
情况的输出,因为我们将尝试对这个循环代码进行逆向工程。因此,使用 N = 4
(其中 N
在原始情况下为 128
)和 image.shape = (10,10)
,我们会有:
In [106]: region_normalization
Out[106]:
array([[ 1, 2, 3, 4, 4, 4, 3, 2, 1, 0],
[ 2, 4, 6, 8, 8, 8, 6, 4, 2, 0],
[ 3, 6, 9, 12, 12, 12, 9, 6, 3, 0],
[ 4, 8, 12, 16, 16, 16, 12, 8, 4, 0],
[ 4, 8, 12, 16, 16, 16, 12, 8, 4, 0],
[ 4, 8, 12, 16, 16, 16, 12, 8, 4, 0],
[ 3, 6, 9, 12, 12, 12, 9, 6, 3, 0],
[ 2, 4, 6, 8, 8, 8, 6, 4, 2, 0],
[ 1, 2, 3, 4, 4, 4, 3, 2, 1, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
我们确实在那里看到了对称性,并且这种对称性恰好横跨 X
和 Y
轴。我们注意到的另一件事是每个元素都是其起始行和列元素的乘积。因此,我们的想法是获取第一行和第一列,并在它们的元素之间执行逐元素乘法。由于第一行和第一列是相同的,我们只需要获取一次并将其与附加轴一起使用,然后让 NumPy broadcasting
处理这些乘法。因此,实现将是 -
N = 128
a1D = np.hstack((np.arange(N)+1,np.full(image.shape[0]-2*N-1,N,dtype=int),\
np.arange(N,-1,-1)))
out = a1D[:,None]*a1D
运行时测试
In [137]: def original_app(image):
...: region_normalization = np.zeros(image.shape,dtype=int)
...: for x in range(0,image.shape[0]-128):
...: for y in range(0,image.shape[0]-128):
...: region_normalization[x:x+128,y:y+128] = \
...: np.add(region_normalization[x:x+128,y:y+128],1)
...: return region_normalization
...:
...: def vectorized_app(image):
...: N = 128
...: a1D = np.hstack((np.arange(N)+1,np.full(image.shape[0]-2*N-1,N,\
...: dtype=int),np.arange(N,-1,-1)))
...:
...: return a1D[:,None]*a1D
...:
In [138]: # Input
...: image = np.random.randint(0,255,(512,512))
In [139]: np.allclose(original_app(image),vectorized_app(image)) #Verify
Out[139]: True
In [140]: %timeit original_app(image)
1 loops, best of 3: 13 s per loop
In [141]: %timeit vectorized_app(image)
1000 loops, best of 3: 1.4 ms per loop
super 加速!
关于python - 区域提案标准化的最快算法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/38643385/