python - 向量化 numpy 对于子数组是唯一的

标签 python numpy

我有一个形状为 (N, 20, 20) 的 numpy 数组数据,其中 N 是一个非常大的数字。 我想获取每个 20x20 子数组中唯一值的数量。 有一个循环:

values = []
for i in data:
    values.append(len(np.unique(i)))

我如何向量化这个循环?速度是一个问题。

如果我尝试 np.unique(data),我会得到整个数据数组的唯一值,而不是单个 20x20 block ,所以这不是我需要的。

最佳答案

首先,您可以使用 data.reshape(N,-1) ,因为您有兴趣对最后 2 个维度进行排序。

获取每一行的唯一值数量的一种简单方法是将每一行转储到一个集合中并让它进行排序:

[len(set(i)) for i in data.reshape(data.shape[0],-1)]

但这是一个迭代,可能是一个快速的迭代。

“矢量化”的一个问题是每行中唯一值的集合或列表的长度不同。在“矢量化”方面,“不同长度的行”是一个危险信号。您不再拥有使大多数矢量化成为可能的“矩形”数据布局。

您可以对每一行进行排序:

np.sort(data.reshape(N,-1))

array([[1, 2, 2, 3, 3, 5, 5, 5, 6, 6],
       [1, 1, 1, 2, 2, 2, 3, 3, 5, 7],
       [0, 0, 2, 3, 4, 4, 4, 5, 5, 9],
       [2, 2, 3, 3, 4, 4, 5, 7, 8, 9],
       [0, 2, 2, 2, 2, 5, 5, 5, 7, 9]])

但是如何在不迭代的情况下识别每一行中的唯一值?计算非零差异的数量可能就可以解决问题:

In [530]: data=np.random.randint(10,size=(5,10))

In [531]: [len(set(i)) for i in data.reshape(data.shape[0],-1)]
Out[531]: [7, 6, 6, 8, 6]

In [532]: sdata=np.sort(data,axis=1)
In [533]: (np.diff(sdata)>0).sum(axis=1)+1            
Out[533]: array([7, 6, 6, 8, 6])

我打算添加一个关于 float 的警告,但是如果 np.unique正在为您的数据工作,我的方法应该同样有效。


[(np.bincount(i)>0).sum() for i in data]

这是一个迭代解决方案,显然比我的 len(set(i)) 快版本,并与 diff...sort 竞争.

在 [585] 中:data.shape 输出[585]: (10000, 400)

In [586]: timeit [(np.bincount(i)>0).sum() for i in data]
1 loops, best of 3: 248 ms per loop

In [587]: %%timeit                                       
sdata=np.sort(data,axis=1)
(np.diff(sdata)>0).sum(axis=1)+1
   .....: 
1 loops, best of 3: 280 ms per loop

我刚刚找到了一种使用 bincount 的更快方法, np.count_nonzero

In [715]: timeit np.array([np.count_nonzero(np.bincount(i)) for i in data])
10 loops, best of 3: 59.6 ms per loop

我对速度的提升感到惊讶。但后来我想起了count_nonzero用于其他函数(例如 np.nonzero )为其返回结果分配空间。因此,将此功能编码为最大速度是有道理的。 (它在 diff...sort 情况下没有帮助,因为它不采用轴参数)。

关于python - 向量化 numpy 对于子数组是唯一的,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/32379895/

相关文章:

python - 在 open cv 中检测橙色的推荐颜色空间是什么?

python - 确保 aiohttp/asyncio 中递归函数的 future

Python Numpy 掩码 NaN 不起作用

python - 拆分二维 numpy 数组,其中可能存在不均匀拆分

python:检查子字符串是否在字符串元组中

python - 我该如何解决这个 list.remove() 错误?

python - 将信号中继到包含的小部件

python - 确定 x 对于 float 和 timedelta 是否均为正

python-3.x - 如何加速以下for循环和函数的应用?

python/numpy 生成二进制文件以供 C 读取