在numpy中有没有快速获取唯一元素的方法?我有类似的代码(最后一行)
tab = numpy.arange(100000000)
indices1 = numpy.random.permutation(10000)
indices2 = indices1.copy()
indices3 = indices1.copy()
indices4 = indices1.copy()
result = numpy.unique(numpy.array([tab[indices1], tab[indices2], tab[indices3], tab[indices4]]))
这只是一个示例,在我的情况下 indices1, indices2,...,indices4
包含不同的索引集并具有不同的大小。最后一行执行了很多次,我注意到它实际上是我代码中的瓶颈({numpy.core.multiarray.arange}
是先行的)。此外,顺序并不重要,索引数组中的元素是 int32
类型。我在考虑使用以元素值作为键的哈希表并尝试过:
seq = itertools.chain(tab[indices1].flatten(), tab[indices2].flatten(), tab[indices3].flatten(), tab[indices4].flatten())
myset = {}
map(myset.__setitem__, seq, [])
result = numpy.array(myset.keys())
但情况更糟。
有什么办法可以加快速度吗?我想性能损失来自复制数组的“花式索引”,但我只需要结果元素来读取(我不修改任何东西)。
最佳答案
[以下内容实际上部分不正确(请参阅 PS):]
下面的获取所有子数组中唯一元素的方式非常快:
seq = itertools.chain(tab[indices1].flat, tab[indices2].flat, tab[indices3].flat, tab[indices4].flat)
result = set(seq)
请注意 flat
(它返回一个迭代器)被用来代替 flatten()
(它返回一个完整的数组),并且 set()
可以直接调用(而不是像第二种方法那样使用 map()
和字典)。
以下是计时结果(在 IPython shell 中获得):
>>> %timeit result = numpy.unique(numpy.array([tab[indices1], tab[indices2], tab[indices3], tab[indices4]]))
100 loops, best of 3: 8.04 ms per loop
>>> seq = itertools.chain(tab[indices1].flat, tab[indices2].flat, tab[indices3].flat, tab[indices4].flat)
>>> %timeit set(seq)
1000000 loops, best of 3: 223 ns per loop
在这个例子中,set/flat 方法因此快了 40 倍。
PS:set(seq)
的时间其实不具有代表性。事实上,计时的第一个循环清空了seq
。迭代器和后续的 set()
评估返回一个空集!正确的时序测试如下
>>> %timeit set(itertools.chain(tab[indices1].flat, tab[indices2].flat, tab[indices3].flat, tab[indices4].flat))
100 loops, best of 3: 9.12 ms per loop
这表明 set/flat 方法实际上并不快。
PPS:这里是对 mtrw 的建议的(不成功的)探索;事先找到唯一索引可能是个好主意,但我找不到比上述方法更快的实现方法:
>>> %timeit set(indices1).union(indices2).union(indices3).union(indices4)
100 loops, best of 3: 11.9 ms per loop
>>> %timeit set(itertools.chain(indices1.flat, indices2.flat, indices3.flat, indices4.flat))
100 loops, best of 3: 10.8 ms per loop
因此,找到所有不同索引的集合本身就很慢。
PPPS:numpy.unique(<concatenated array of indices>)
实际上比 set(<concatenated array of indices>)
快 2-3 倍.这是在 Bago 的回答(unique(concatenate((…)))
)中获得加速的关键。原因可能是让 NumPy 自己处理它的数组通常比将纯 Python ( set
) 与 NumPy 数组连接起来更快。
结论:因此,此答案仅记录不应完全遵循的失败尝试,以及关于迭代器计时代码的可能有用的评论......
关于python - 在 numpy 和 python 中快速删除重复项,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/8620521/