python - Numpy 根据条件拆分数组,无需 for 循环

标签 python arrays performance numpy vectorization

假设我有一个 numpy 数组,它在二维空间中保存点,如下所示

np.array([[3, 2], [4, 4], [5, 4], [4, 2], [4, 6], [9, 5]]) 

我还有一个 numpy 数组,将每个点标记为一个数字,这个数组是一个一维数组,长度为点数组中的点数。

np.array([0, 1, 1, 0, 2, 1])

现在我想取标签数组中具有索引的每个点的平均值。因此,对于所有标签为 0 的点,取这些点的平均值。 我目前解决这个问题的方法是以下方法

return np.array([points[labels==k].mean(axis=0) for k in range(k)])

其中 k 是标签数组中的最大数字,或者称为标记点的方法数。

我想要一种不使用 for 循环来执行此操作的方法,也许我还没有发现一些 numpy 功能?

最佳答案

方法#1:braodcasting 的一些帮助下,我们可以利用矩阵乘法 -

mask = labels == np.arange(labels.max()+1)[:,None]
out = mask.dot(points)/np.bincount(labels).astype(float)[:,None]

sample 运行-

In [36]: points = np.array([[3, 2], [4, 4], [5, 4], [4, 2], [4, 6], [9, 5]]) 
    ...: labels = np.array([0, 1, 1, 0, 2, 1])

# Original soln
In [37]: L = labels.max()+1

In [38]: np.array([points[labels==k].mean(axis=0) for k in range(L)])
Out[38]: 
array([[3.5       , 2.        ],
       [6.        , 4.33333333],
       [4.        , 6.        ]])

# Proposed soln
In [39]: mask = labels == np.arange(labels.max()+1)[:,None]
    ...: out = mask.dot(points)/np.bincount(labels).astype(float)[:,None]

In [40]: out
Out[40]: 
array([[3.5       , 2.        ],
       [6.        , 4.33333333],
       [4.        , 6.        ]])

方法#2:使用np.add.at -

sums = np.zeros((labels.max()+1,points.shape[1]),dtype=float)
np.add.at(sums,labels,points)
out = sums/np.bincount(labels).astype(float)[:,None]

方法#3:如果从 0 到 max-label 中的所有数字都出现在 labels 中,我们也可以使用 np.add.reduceat -

sidx = labels.argsort()
sorted_points = points[sidx]
sums = np.add.reduceat(sorted_points,np.r_[0,np.bincount(labels)[:-1].cumsum()])
out = sums/np.bincount(labels).astype(float)[:,None]

关于python - Numpy 根据条件拆分数组,无需 for 循环,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54889608/

相关文章:

python - 为什么 Python itertools 没有归类为生成器 (GeneratorType)?

python - 安装用 C++ : g++ unrecognized command line option --output-lib 编写的 Python 包 (leven) 时出错

java - 为什么我要为String数组获取java.lang.NullPointerException?

c# - 如果有的话,使用 System.Diagnostics.Stopwatch 的资源损失是多少?

python - 为什么列表理解可以比 Python 中的 map() 更快?

python - re.match 在两个不同的字符串上返回 true

c++ - 使用 cout 后对象数组被破坏

c++ - C/C++ 中的指针编译但给出段错误

performance - 如何在 JMeter 中显示实际循环计数

python - 查找列表开头的相等元素的数量