python - 从三个一维数组创建一个 3D 坐标的 numpy 数组

标签 python arrays performance numpy

假设我有三个任意一维数组,例如:

x_p = np.array((1.0, 2.0, 3.0, 4.0, 5.0))
y_p = np.array((2.0, 3.0, 4.0))
z_p = np.array((8.0, 9.0))

这三个数组表示 3D 网格中的采样间隔,我想为所有交叉点构造一个三维向量的 1D 数组,类似于

points = np.array([[1.0, 2.0, 8.0],
                   [1.0, 2.0, 9.0],
                   [1.0, 3.0, 8.0],
                   ...
                   [5.0, 4.0, 9.0]])

顺序实际上并不重要。生成它们的明显方法:

npoints = len(x_p) * len(y_p) * len(z_p)
points = np.zeros((npoints, 3))
i = 0
for x in x_p:
    for y in y_p:
        for z in z_p:
            points[i, :] = (x, y, z)
            i += 1

所以问题是……有没有更快的方法?我已经查找但没有找到(可能只是没有找到正确的 Google 关键字)。

我目前正在使用这个:

npoints = len(x_p) * len(y_p) * len(z_p)
points = np.zeros((npoints, 3))
i = 0
nz = len(z_p)
for x in x_p:
    for y in y_p:
        points[i:i+nz, 0] = x
        points[i:i+nz, 1] = y
        points[i:i+nz, 2] = z_p
        i += nz

但我觉得我缺少一些巧妙的 Numpy 方式?

最佳答案

要在上面的例子中使用 numpy 网格,下面的方法将起作用:

np.vstack(np.meshgrid(x_p,y_p,z_p)).reshape(3,-1).T

用于二维以上网格的 Numpy meshgrid 需要 numpy 1.7。为了避免这种情况并从 source code 中提取相关数据.

def ndmesh(*xi,**kwargs):
    if len(xi) < 2:
        msg = 'meshgrid() takes 2 or more arguments (%d given)' % int(len(xi) > 0)
        raise ValueError(msg)

    args = np.atleast_1d(*xi)
    ndim = len(args)
    copy_ = kwargs.get('copy', True)

    s0 = (1,) * ndim
    output = [x.reshape(s0[:i] + (-1,) + s0[i + 1::]) for i, x in enumerate(args)]

    shape = [x.size for x in output]

    # Return the full N-D matrix (not only the 1-D vector)
    if copy_:
        mult_fact = np.ones(shape, dtype=int)
        return [x * mult_fact for x in output]
    else:
        return np.broadcast_arrays(*output)

检查结果:

print np.vstack((ndmesh(x_p,y_p,z_p))).reshape(3,-1).T

[[ 1.  2.  8.]
 [ 1.  2.  9.]
 [ 1.  3.  8.]
 ....
 [ 5.  3.  9.]
 [ 5.  4.  8.]
 [ 5.  4.  9.]]

对于上面的例子:

%timeit sol2()
10000 loops, best of 3: 56.1 us per loop

%timeit np.vstack((ndmesh(x_p,y_p,z_p))).reshape(3,-1).T
10000 loops, best of 3: 55.1 us per loop

当每个维度为 100 时:

%timeit sol2()
1 loops, best of 3: 655 ms per loop
In [10]:

%timeit points = np.vstack((ndmesh(x_p,y_p,z_p))).reshape(3,-1).T
10 loops, best of 3: 21.8 ms per loop

根据你想对数据做什么,你可以返回一个 View :

%timeit np.vstack((ndmesh(x_p,y_p,z_p,copy=False))).reshape(3,-1).T
100 loops, best of 3: 8.16 ms per loop

关于python - 从三个一维数组创建一个 3D 坐标的 numpy 数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/18253210/

相关文章:

python - 为什么列表的长度必须包含 -1 才能在 for 循环中容纳 +1?

python - 我收到类型错误 : '<' not supported between instances of 'str' and 'int'

arrays - 重新排列整数数组

arrays - 将字符串复制到模型(Swift 3)

sql - 从 PostgreSQL 中的时间戳中提取日期的最有效方法是什么?

arrays - 使用一个函数将序列 3Un² + 2Un +1 的前 n 个值存储在数组中

php - 在python中解析php数组

python - "variable//= a value"语法在 Python 中意味着什么?

Java :Comparing two lists

python - Python 中的二维与一维字典效率