python - 如何广播或矢量化使用 scipy.ndimage map_coordinates 的 2D 数组的线性插值?

标签 python numpy image-processing scipy ndimage

我最近在性能方面遇到了障碍。我知道如何通过暴力/循环二维数组中的每一行和每一列来手动循环并从原始单元格到所有其他单元格进行插值。

但是,当我处理形状为 (3000, 3000) 的 2D 数组时,线性间距和插值会陷入停滞并严重损害性能。

我正在寻找一种可以优化此循环的方法,我知道矢量化和广播,只是不确定如何在这种情况下应用它。

我用代码和图来解释

import numpy as np
from scipy.ndimage import map_coordinates
m = np.array([
    [10,10,10,10,10,10],
    [9,9,9,10,9,9],
    [9,8,9,10,8,9],
    [9,7,8,0,8,9],
    [8,7,7,8,8,9],
    [5,6,7,7,6,7]])

origin_row = 3
origin_col = 3
m_max = np.zeros(m.shape)
m_dist = np.zeros(m.shape)

rows, cols = m.shape
for col in range(cols):
    for row in range(rows):
        # Get spacing linear interpolation
        x_plot = np.linspace(col, origin_col, 5)
        y_plot = np.linspace(row, origin_row, 5)

        # grab the interpolated line
        interpolated_line = map_coordinates(m,
                                      np.vstack((y_plot,
                                                 x_plot)),
                                      order=1, mode='nearest')
        m_max[row][col] = max(interpolated_line)
        m_dist[row][col] = np.argmax(interpolated_line)

print(m)
print(m_max)
print(m_dist)
正如你所看到的,这是非常暴力的,我已经设法广播了这部分周围的所有代码,但停留在这部分。 这是我想要实现的目标的说明,我将经历第一次迭代

1.) 输入数组

input array

2.) 从 0,0 到原点 (3,3) 的第一个循环

first cell to origin

3.) 这将返回 [10 9 9 8 0],最大值将为 10,索引将为 0

5.) 这是我使用的示例数组的输出

sample output of m

这是基于已接受答案的性能更新。

times

最佳答案

为了加快代码速度,您可以首先在循环外部创建 x_ploty_plot,而不是每次创建多次:

#this would be outside of the loops
num = 5
lin_col = np.array([np.linspace(i, origin_col, num) for i in range(cols)])
lin_row = np.array([np.linspace(i, origin_row, num) for i in range(rows)])

然后您可以通过x_plot = lin_col[col]y_plot = lin_row[row]在每个循环中访问它们

其次,您可以通过对每一对(rowcol )。为此,您可以使用 np.tile 创建 x_ploty_plot 的所有组合。和 np.ravel如:

arr_vs = np.vstack(( np.tile( lin_row, cols).ravel(),
                     np.tile( lin_col.ravel(), rows)))

请注意,ravel 并不是每次都在同一位置使用来获取所有组合。现在,您可以将 map_cooperatives 与此 arr_vs 结合使用,并使用 列数 reshape 结果num 获取 3D 数组最后一个轴中的每个 interpolated_line:

arr_map = map_coordinates(m, arr_vs, order=1, mode='nearest').reshape(rows,cols,num)

最后,您可以在arr_map的最后一个轴上使用np.maxnp.argmax来获取结果m_max m_dist。所以所有的代码都是:

import numpy as np
from scipy.ndimage import map_coordinates
m = np.array([
    [10,10,10,10,10,10],
    [9,9,9,10,9,9],
    [9,8,9,10,8,9],
    [9,7,8,0,8,9],
    [8,7,7,8,8,9],
    [5,6,7,7,6,7]])

origin_row = 3
origin_col = 3
rows, cols = m.shape

num = 5
lin_col = np.array([np.linspace(i, origin_col, num) for i in range(cols)])
lin_row = np.array([np.linspace(i, origin_row, num) for i in range(rows)])

arr_vs = np.vstack(( np.tile( lin_row, cols).ravel(),
                     np.tile( lin_col.ravel(), rows)))

arr_map = map_coordinates(m, arr_vs, order=1, mode='nearest').reshape(rows,cols,num)
m_max = np.max( arr_map, axis=-1)
m_dist = np.argmax( arr_map, axis=-1)

print (m_max)
print (m_dist)

你会得到预期的结果:

#m_max
array([[10, 10, 10, 10, 10, 10],
       [ 9,  9, 10, 10,  9,  9],
       [ 9,  9,  9, 10,  8,  9],
       [ 9,  8,  8,  0,  8,  9],
       [ 8,  8,  7,  8,  8,  9],
       [ 7,  7,  8,  8,  8,  8]])
#m_dist
array([[0, 0, 0, 0, 0, 0],
       [0, 0, 2, 0, 0, 0],
       [0, 2, 0, 0, 0, 0],
       [0, 1, 0, 0, 0, 0],
       [0, 2, 0, 0, 0, 0],
       [1, 1, 2, 1, 2, 1]])

编辑:lin_collin_row 相关,因此您可以做得更快:

if cols >= rows:
    arr = np.arange(cols)[:,None]
    lin_col = arr + (origin_col-arr)/(num-1.)*np.arange(num)
    lin_row = lin_col[:rows] + np.linspace(0, origin_row - origin_col, num)[None,:]
else:
    arr = np.arange(rows)[:,None]
    lin_row = arr + (origin_row-arr)/(num-1.)*np.arange(num)
    lin_col = lin_row[:cols] + np.linspace(0, origin_col - origin_row, num)[None,:]

关于python - 如何广播或矢量化使用 scipy.ndimage map_coordinates 的 2D 数组的线性插值?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53898073/

相关文章:

python - 是否有等同于 np.empty 的 tensorflow ?

python - 使用 Numpy 的最快方法 - 多维和与乘积

python - genfromtxt 返回 NaN 行

java - 如何从 Java 中的 URL 获取图像格式?

ruby - OS X 上的 Camellia Ruby 计算机视觉库

Python for-in-loop 停止迭代从 for-in-loop 创建的列表对象

python - 在 Django 中哪里可以翻译表单字段标签?

python - 排序二维列表python

matlab - 用 SIFT 检测数字图像上的印章(印章)印记

python - 使用两个掩码 a[mask1][mask2]=value 更新数组值