python - 使用 `scipy.interpolate.griddata` 插值非常慢

标签 python numpy matplotlib scipy

当我尝试将“几乎”规则网格化的数据插值到 map 坐标时,我遇到了 scipy.interpolate.griddata 极其缓慢的性能,以便可以使用 matplotlib 绘制 map 和数据。 pyplot.imshow 因为 matplotlib.pyplot.pcolormesh 花费的时间太长并且在 alpha 等方面表现不佳。

最好举个例子(输入文件可以下载here):

import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import griddata

map_extent = (34.4, 36.2, 30.6, 33.4)
# data corners:
lon = np.array([[34.5,        34.83806236],
                [35.74547079, 36.1173923]])
lat = np.array([[30.8,        33.29936152],
                [30.67890411, 33.17826563]])

# load saved files
topo = np.load('topo.npy')
lons = np.load('lons.npy')
lats = np.load('lats.npy')
data = np.load('data.npy')

# get max res of data
dlon = abs(np.array(np.gradient(lons))).max()
dlat = abs(np.array(np.gradient(lats))).max()

# interpolate the data to the extent of the map
loni,lati = np.meshgrid(np.arange(map_extent[0], map_extent[1]+dlon, dlon),
                        np.arange(map_extent[2], map_extent[3]+dlat, dlat))
zi = griddata((lons.flatten(),lats.flatten()),
              data.flatten(), (loni,lati), method='linear')

绘图:

fig, (ax1,ax2) = plt.subplots(1,2)
ax1.axis(map_extent)
ax1.imshow(topo,extent=extent,cmap='Greys')

ax2.axis(map_extent)
ax2.imshow(topo,extent=extent,cmap='Greys')

ax1.imshow(zi, vmax=0.1, extent=extent, alpha=0.5, origin='lower')
ax1.plot(lon[0],lat[0], '--k', lw=3, zorder=10)
ax1.plot(lon[-1],lat[-1], '--k', lw=3, zorder=10)
ax1.plot(lon.T[0],lat.T[0], '--k', lw=3, zorder=10)
ax1.plot(lon.T[-1],lat.T[-1], '--k', lw=3, zorder=10)


ax2.pcolormesh(lons,lats,data, alpha=0.5)
ax2.plot(lon[0],lat[0], '--k', lw=3, zorder=10)
ax2.plot(lon[-1],lat[-1], '--k', lw=3, zorder=10)
ax2.plot(lon.T[0],lat.T[0], '--k', lw=3, zorder=10)
ax2.plot(lon.T[-1],lat.T[-1], '--k', lw=3, zorder=10)

结果:

enter image description here

注意,这不能通过简单地使用仿射变换旋转数据来完成。

griddata 每次调用我的真实数据需要超过 80 秒,而 pcolormesh 需要更长的时间(超过 2 分钟!)。我看过 Jaimi 的回答 here和 Joe Kington 的回答 here但我想不出一种方法让它为我工作。

我所有的数据集都具有完全相同的 lonslats,所以基本上我需要将这些映射一次到 map 的坐标,并对数据本身应用相同的转换。问题是我该怎么做?

最佳答案

在长期忍受 scipy.interpolate.griddata 极其缓慢的性能之后,我决定放弃 griddata,转而使用 OpenCV 进行图像转换| .具体来说,Perspective Transformation .

所以对于上面的例子,上面问题中的那个,你可以获得输入文件here ,这是一段代码,它需要 1.1 毫秒,而上面示例中的重新网格化部分需要 692 毫秒。

import cv2
new_data = data.T[::-1]

# calculate the pixel coordinates of the
# computational domain corners in the data array
w,e,s,n = map_extent
dx = float(e-w)/new_data.shape[1]
dy = float(n-s)/new_data.shape[0]
x = (lon.ravel()-w)/dx
y = (n-lat.ravel())/dy

computational_domain_corners = np.float32(zip(x,y))

data_array_corners = np.float32([[0,new_data.shape[0]],
                                 [0,0],
                                 [new_data.shape[1],new_data.shape[0]],
                                 [new_data.shape[1],0]])

# Compute the transformation matrix which places
# the corners of the data array at the corners of
# the computational domain in data array pixel coordinates
tranformation_matrix = cv2.getPerspectiveTransform(data_array_corners,
                                                   computational_domain_corners)

# Make the transformation making the final array the same shape
# as the data array, cubic interpolate the data placing NaN's
# outside the new array geometry
mapped_data = cv2.warpPerspective(new_data,tranformation_matrix,
                                  (new_data.shape[1],new_data.shape[0]),
                                  flags=2,
                                  borderMode=0,
                                  borderValue=np.nan)

我看到此解决方案的唯一缺点是数据中存在轻微偏移,如所附图像中的非重叠轮廓所示。黑色重新网格化的数据轮廓(可能更准确)和 'jet' 色标中的 warpPerspective 数据轮廓。

目前,我对性能优势方面的差异感到满意,我希望这个解决方案也能帮助其他人。

有人(不是我...)应该找到提高 griddata 性能的方法 :) 享受吧!

enter image description here

关于python - 使用 `scipy.interpolate.griddata` 插值非常慢,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/28613747/

相关文章:

python - 为什么 np.dot 比使用 for 循环求点积快得多

python - 用最后一个非零值填充 1d numpy 数组的零值

python - 在 Enthought Canopy IDE 中执行脚本时 matplotlib.figure() 不起作用

python - 如何在 Matplotlib 中生成嵌套图例

python - Python 中的非顺序循环优化

python - 减去二维数组中每行的第一个元素

python-3.x - 比较来自两个数据帧的列并删除 df2 中与 df1 中的值相差 +/-0.03 范围内的行

python - 如何向 matplotlib.pylab.imshow() 对话框添加按钮?

Python 多处理如何优雅地退出?

python - 如何使用 Crontab 在 Django 环境中运行 Python 文件?