python - 在numpy中组合多维数据

标签 python arrays numpy

我有一维 NumPy 数组,表示 n 维网格上的点。每个 NumPy 数组代表一个维度中的点。我想生成一个合并的 NumPy 数组,该数组将具有形状为 (n , m) 的 n 维网格,其中 n 将为 len (dim-1) * len(dim-2) * ...m 是维度数

例如(二维情况):

In [1]: x = np.array([1, 2])

In [2]: x
Out[2]: array([1, 2])

In [3]: y = np.array([3, 4, 5])

In [4]: y
Out[4]: array([3, 4, 5])

In [5]: result = np.array([[1, 3], [1, 4],[1, 5],[2, 3],[2, 4],[2, 5]])

In [6]: result
Out[6]: 
array([[1, 3],
       [1, 4],
       [1, 5],
       [2, 3],
       [2, 4],
       [2, 5]])

另一个例子(3-D):

In [7]: x = np.array([1])

In [8]: y = np.array([2, 3])   

In [9]: z = np.array([4, 5, 6])

In [10]: x
Out[10]: array([1])

In [11]: y
Out[11]: array([2, 3])

In [12]: z
Out[12]: array([4, 5, 6])

In [13]: result = np.array([[1, 2, 4], [1, 3, 4], [1, 2, 5], [1, 3, 5], [1, 2, 6], [1, 3, 6]])

In [14]: result
Out[14]: 
array([[1, 2, 4],
       [1, 3, 4],
       [1, 2, 5],
       [1, 3, 5],
       [1, 2, 6],
       [1, 3, 6]])

有没有一种方法可以轻松地在 n 维上做到这一点,而无需循环遍历每个数组?

最佳答案

您可以使用np.meshgrid创建扩展版本,然后使用 np.column_stack在以列为主的扁平化版本上,就像这样 -

X,Y,Z = np.meshgrid(x,y,z)
out = np.column_stack((X.ravel('F'),Y.ravel('F'),Z.ravel('F')))
<小时/>

为了使其通用,以便它适用于任意数量的输入情况,我们需要一些额外的工作,就像这样 -

def combine_arrays(A):
    return np.dstack(np.meshgrid(*A)).ravel('F').reshape(len(A),-1).T

运行示例来测试 2D3D 情况 -

In [67]: # 2D case
    ...: x = np.array([1, 2])
    ...: y = np.array([3, 4, 5])
    ...: 

In [68]: combine_arrays((x,y))
Out[68]: 
array([[1, 3],
       [1, 4],
       [1, 5],
       [2, 3],
       [2, 4],
       [2, 5]])

In [69]: # 3D case
    ...: x = np.array([1])
    ...: y = np.array([2, 3])   
    ...: z = np.array([4, 5, 6])
    ...: 

In [70]: combine_arrays((x,y,z))
Out[70]: 
array([[1, 2, 4],
       [1, 3, 4],
       [1, 2, 5],
       [1, 3, 5],
       [1, 2, 6],
       [1, 3, 6]])

关于python - 在numpy中组合多维数据,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/37617996/

相关文章:

python - python 3.4 的正则表达式

python - 比较 2 个字典中的键和值

python - 简单的Python-3程序中的无效语法错误

java - Java int[] 压缩工具

python - 导入错误: undefined symbol: _PyUnicodeUCS4_IsWhitespace

python - 将 itertools 数组转换为 numpy 数组

python - 将数据点拟合到累积分布

python - python项目结构是什么

java - 检查句子是否包含某些单词

php - 我应该将数组存储到 mysql 还是将每个值单独存储在表中?