我想首先按元素组合三个形状相等的数组,然后将较小数组的重复条目连接到此。我可以做到这一点,但我的方法有很多步骤,我想知道是否有更直接(更快?)的方法。现在看到目标中的规律性,似乎很明显一些列表理解可以解决它。然而,我希望保持可读性,以便其他人可以遵循它,但效率足以扩展到几百万行,因此我们首先想到的是 NumPy 操作。
取三个形状相同的 3D 数组 a x b x c,以及相同长度“a”的浅数组“m”。我们需要组合 a、b 和 c,删除 m 的第二列和第三列,然后在组合后将 m 广播到其他列,并将结果作为二维数组返回。
这是我的方法。下面的通用数组将是这些给定的输入——这些数字只是为了标记,以便目标结果会更清晰。
>>> import numpy as np
>>> a = 2
>>> b = 3
>>> c = 4
>>> n = a * b * c
>>> x = np.arange(n).reshape(a,b,c)
>>> y = x + n
>>> z = y + n
x = [[[ 0. 1. 2. 3.]
[ 4. 5. 6. 7.]
[ 8. 9. 10. 11.]]
[[12. 13. 14. 15.]
[16. 17. 18. 19.]
[20. 21. 22. 23.]]]
>>> m = 3 * n + np.linspace(0, 13 * a - 1, 13 * a).reshape(a, 13)
m = [[72. 73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84.]
[85. 86. 87. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97.]]
组合a、b和c,使每个对应的条目成为一行。
>>> triples = np.stack((x, y, z), axis=3)
不需要这些特定条目。
>>> cleared = np.delete(m, np.s_[1:3], axis = 1)
“手动广播”使尺寸达到正确的形状。 我认为可以一次性做到这一点,但我不确定如何做到。
>>> sizeup1 = np.linspace(cleared, cleared, c, axis=1)
>>> sizeup2 = np.linspace(sizeup1, sizeup1, b, axis=1)
连接每个相应的最内层行。
>>> pre_res = np.concatenate((sizeup2, triples), axis=3)
最终结果是一个二维数组,如下所示。
>>> result = pre_res.reshape(a * b * c, -1)
result = [[72. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 0. 24. 48.]
[72. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 1. 25. 49.]
[72. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 2. 26. 50.]
[72. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 3. 27. 51.]
[72. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 4. 28. 52.]
[72. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 5. 29. 53.]
[72. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 6. 30. 54.]
[72. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 7. 31. 55.]
[72. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 8. 32. 56.]
[72. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 9. 33. 57.]
[72. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 10. 34. 58.]
[72. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 11. 35. 59.]
[85. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97. 12. 36. 60.]
[85. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97. 13. 37. 61.]
[85. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97. 14. 38. 62.]
[85. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97. 15. 39. 63.]
[85. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97. 16. 40. 64.]
[85. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97. 17. 41. 65.]
[85. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97. 18. 42. 66.]
[85. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97. 19. 43. 67.]
[85. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97. 20. 44. 68.]
[85. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97. 21. 45. 69.]
[85. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97. 22. 46. 70.]
[85. 88. 89. 90. 91. 92. 93. 94. 95. 96. 97. 23. 47. 71.]]
最终结果的规律性可能有助于找到更直接的结果,但代码仍然应该可以在没有首先看到目标的情况下遵循。我希望任何熟悉 NumPy 的人都能轻松识别出我的方法中的疏忽,甚至是一两行代码吗?
更新:我已经对 @hpaulj 的建议进行了基准测试,并将分享以防将来有人发现这有用。 对于小型阵列,它们几乎相同。对于任何数组大小,之间的差异可以忽略不计
sizeup_opt1 = np.tile(cleared[:,None,None,:],(1,b,c,1))
和
sizeup_opt2 = np.repeat(cleared[:,None,:],b*c,1).reshape(a,b,c,11)
那么另一种情况是对于大数组,当
triples = np.stack((x, y, z), axis=3)
np.concatenate((sizeup_opt2, triples), axis=3).reshape(a * b * c, -1)
只需要一半的时间
np.concatenate([sizeup_opt2, x[...,None], y[...,None], z[...,None]], axis=3).reshape(a * b * c, -1)
对于非常大的数组。
最佳答案
sizeup
linspace 的一些替代方案:
In [113]: sizeup1.shape
Out[113]: (2, 4, 11)
In [114]: sizeup2.shape
Out[114]: (2, 3, 4, 11)
In [115]: cleared.shape
Out[115]: (2, 11)
In [116]: np.tile(cleared[:,None,None,:],(1,3,4,1)).shape
Out[116]: (2, 3, 4, 11)
In [117]: np.allclose(np.tile(cleared[:,None,None,:],(1,3,4,1)),sizeup2)
Out[117]: True
In [118]: np.allclose(np.repeat(cleared[:,None,:],3*4,1).reshape(2,3,4,11),sizeup2)
我怀疑一个重复案例将是最快的,但我没有对这些进行计时。
In [120]: x.shape
Out[120]: (2, 3, 4)
In [121]: triples = np.stack((x, y, z), axis=3)
In [122]: triples.shape
Out[122]: (2, 3, 4, 3)
In [123]: pre_res = np.concatenate((sizeup2, triples), axis=3)
In [124]: pre_res.shape
Out[124]: (2, 3, 4, 14)
因此最后一个连接
将 (2,3,4,11) 与 (2,3,4,3) 连接起来
stack
有效地执行以下操作:
np.concatenate([x[...,None], y[...,None], z[...,None]], axis=3)
也就是说,它为每个维度添加一个尾随维度并将它们连接起来
所以 pre_res
可能是:
np.concatenate([sizeup2, x[...,None], ....], axis=3)
但我不确定这是否有很大的改进。
您可以之前将尾随尺寸添加到 x
:
x = np.linspace(0, n-1, n).reshape(a, b, c, 1)
关于arrays - 高效堆叠和连接 NumPy 数组,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67052312/