我需要在 numpy 数组“标签”中找到最频繁出现的元素,前提是这些元素位于掩码数组内。这是蛮力方法:
def getlabel(mask, label):
# get majority label
assert label.shape == mask.shape
tmp = []
for i in range(mask.shape[0]):
for j in range(mask.shape[1]):
if mask[i][j] == True:
tmp.append(label[i][j])
return Counter(tmp).most_common(1)[0][0]
不过,我认为这还不是最优雅、最快速的方法。我应该使用哪些其他数据结构? (hasing、字典等...)?
最佳答案
假设您的 mask
是一个 bool 数组:
import numpy as np
cnt = np.bincount(label[mask].flat)
这为您提供了值 0、1、2 ... max(label) 出现次数的向量
你可以通过以下方式找到最频繁的
most_frequent = np.argmax(cnt)
当然,输入数据中这些元素的数量是
cnt[most_frequent]
通常,np.bincount
很快。让我们尝试使用最大数量为 999 的标签(即 1000 个分箱)和一个由 8 000 000 个值屏蔽的 10 000 000 个元素数组:
data = np.random.randint(0, 1000, (1000, 10000))
mask = np.random.random((1000, 10000)) < 0.8
# time this section
cnt = np.bincount(data[mask].flat)
在我的机器上,这需要 80 毫秒。 argmax
可能需要 2 ns/bin,所以即使您的标签整数有点分散,也没关系。
如果满足以下条件,这种方法可能是最快的方法:
- 标签是0..N范围内的整数,其中N不大于输入数组的大小
- 输入数据在 NumPy 数组中
这个解决方案可能适用于其他一些情况,但更多的是如何以及是否有更好的解决方案可用的问题。 (参见 metaperture
的回答。)例如,将 Python 列表简单转换为 ndarray
。相当昂贵,并且通过bincount
获得了速度优势如果输入的是Python列表,数据量不大,就会丢失。
整数空间中标签的稀疏性本身不是问题。创建和归零输出向量相对较快,使用 np.nonzero
可以轻松快速地压缩回去.但是,如果最大标签值与输入数组的大小相比较大,则可能会失去速度优势。
关于python - 在掩码数组中查找出现频率最高的元素,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/24814397/