python - 使用 scipy.interpolate.interpn 对 N 维数组进行插值

标签 python python-2.7 numpy scipy

假设我的数据取决于 4 个变量:a、b、c 和 d。我希望插值返回一个二维数组,它对应于 a 和 b 的单个值,以及 c 和 d 的值数组。但是,数组大小不必相同。具体来说,我的数据来自晶体管模拟。当前取决于此处的 4 个变量。我想绘制参数变化。参数上的点数比水平轴上的点数少很多。

import numpy as np
from scipy.interpolate import interpn
arr = np.random.random((4,4,4,4))
x1 = np.array([0, 1, 2, 3])
x2 = np.array([0, 10, 20, 30])
x3 = np.array([0, 10, 20, 30])
x4 = np.array([0, .1, .2, .30])
points = (x1, x2, x3, x4)

以下作品:

xi = (0.1, 9, np.transpose(np.linspace(0, 30, 4)), np.linspace(0, 0.3, 4))
result = interpn(points, arr, xi)

还有这个:

xi = (0.1, 9, 24, np.linspace(0, 0.3, 4))
result = interpn(points, arr, xi)

但不是这个:

xi = (0.1, 9, np.transpose(np.linspace(0, 30, 3)), np.linspace(0, 0.3, 4))
result = interpn(points, arr, xi)

如您所见,在最后一种情况下,xi 中最后两个数组的大小不同。 scipy 不支持这种功能,还是我使用的 interpn 不正确?我需要创建一个绘图,其中 xi 之一是参数,而另一个是水平轴。

最佳答案

我将尝试以二维方式向您解释这一点,以便您更好地了解正在发生的事情。首先,让我们创建一个线性阵列来进行测试。

import numpy as np

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

# Set up grid and array of values
x1 = np.arange(10)
x2 = np.arange(10)
arr = x1 + x2[:, np.newaxis]

# Set up grid for plotting
X, Y = np.meshgrid(x1, x2)

# Plot the values as a surface plot to depict
fig = plt.figure()
ax = fig.gca(projection='3d')
surf = ax.plot_surface(X, Y, arr, rstride=1, cstride=1, cmap=cm.jet,
                       linewidth=0, alpha=0.8)
fig.colorbar(surf, shrink=0.5, aspect=5)

这给了我们: surface plot of values

然后,假设您想要沿一条线进行插值,即沿第一个维度的一个点,但沿第二个维度的所有点。这些点显然不在原始数组 (x1, x2) 中。假设我们要插值到一个点 x1 = 3.5,它位于 x1 轴上的两点之间。

from scipy.interpolate import interpn

interp_x = 3.5           # Only one value on the x1-axis
interp_y = np.arange(10) # A range of values on the x2-axis

# Note the following two lines that are used to set up the
# interpolation points as a 10x2 array!
interp_mesh = np.array(np.meshgrid(interp_x, interp_y))
interp_points = np.rollaxis(interp_mesh, 0, 3).reshape((10, 2))

# Perform the interpolation
interp_arr = interpn((x1, x2), arr, interp_points)

# Plot the result
ax.scatter(interp_x * np.ones(interp_y.shape), interp_y, interp_arr, s=20,
           c='k', depthshade=False)
plt.xlabel('x1')
plt.ylabel('x2')

plt.show()

这会给出您想要的结果:注意黑点正确地位于平面上,x1 值为 3.5surface plot of interpolated points

请注意,大部分“魔法”以及您问题的答案都在于这两行:

interp_mesh = np.array(np.meshgrid(interp_x, interp_y))
interp_points = np.rollaxis(interp_mesh, 0, 3).reshape((10, 2))

我已经解释了这个 elsewhere 的工作原理.简而言之,它所做的是创建一个大小为 10x2 的数组,其中包含您要插入 arr 的 10 个点的坐标。 (那篇文章和这篇文章的唯一区别是我已经为 np.mgrid 写了解释,这是为一堆编写 np.meshgrid 的捷径排列。)

对于您的 4x4x4x4 案例,您可能需要这样的东西:

interp_mesh = np.meshgrid([0.1], [9], np.linspace(0, 30, 3),
                          np.linspace(0, 0.3, 4))
interp_points = np.rollaxis(interp_mesh, 0, 5)
interp_points = interp_points.reshape((interp_mesh.size // 4, 4))
result = interpn(points, arr, interp_points)

希望对您有所帮助!

关于python - 使用 scipy.interpolate.interpn 对 N 维数组进行插值,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39332053/

相关文章:

python - 子进程调用失败

python - 将 self 变量作为参数传递给 mixin 父方法的正确方法

python - 如何使用 Drive API 在 Google 电子表格中获取工作表列表(名称和 "gid")?

python - 在 Python 中对混合列表进行排序

python - 如何使用 SWIG 将 numpy 数组转换为 vector<int>& (引用)

python - 我需要将包含数据的文本文件读入 python 并将数据分配给变量

Python:第 1 行的语法无效(文件 <string>)

Python与Sqlite 3,如何仅提取文本,不带'u?

python - 如何将 numpy ndarrays 列表作为向量传递给 cython

python - 如何在Python中迭代地将数组保存到文件中?