python - 查找 3 维 numpy 数组的唯一值的索引

标签 python arrays for-loop optimization numpy

我有一个包含 N 个点坐标的数组。另一个数组包含这 N 个点的质量。

 >>> import numpy as np
 >>> N=10
 >>> xyz=np.random.randint(0,2,(N,3))
 >>> mass=np.random.rand(len(xyz))
 >>> xyz
 array([[1, 0, 1],
   [1, 1, 0],
   [0, 1, 1],
   [0, 0, 0],
   [0, 1, 0],
   [1, 1, 0],
   [1, 0, 1],
   [0, 0, 1],
   [1, 0, 1],
   [0, 0, 1]])
 >>> mass
 array([ 0.38668401,  0.44385111,  0.47756182,  0.74896529,  0.20424403,
    0.21828435,  0.98937523,  0.08736635,  0.24790248,  0.67759276])

现在我想获得一个具有唯一 xyz 值的数组以及相应的总质量数组。这意味着以下数组:

 >>> xyz_unique
 array([[0, 1, 1],
   [1, 1, 0],
   [0, 0, 1],
   [1, 0, 1],
   [0, 0, 0],
   [0, 1, 0]])
 >>> mass_unique
 array([ 0.47756182,  0.66213546,  0.76495911,  1.62396172,  0.74896529,
    0.20424403])

我的尝试是使用双 for 循环的以下代码:

 >>> xyz_unique=np.array(list(set(tuple(p) for p in xyz)))
 >>> mass_unique=np.zeros(len(xyz_unique))
 >>> for j in np.arange(len(xyz_unique)):
 ...     indices=np.array([],dtype=np.int64)
 ...     for i in np.arange(len(xyz)):
 ...         if np.all(xyz[i]==xyz_unique[j]):
 ...             indices=np.append(indices,i)
 ...     mass_unique[j]=np.sum(mass[indices])

问题是这花了太长时间,我实际上有 N=100000。 有没有更快的方法或者如何改进我的代码?

编辑 我的坐标实际上是 float 。为了简单起见,我制作了随机整数以在低 N 处具有重复项。

最佳答案

情况 1:xyz 中的二进制数

如果输入数组xyz中的元素是01,则可以将每一行转换为十进制数,然后根据每一行的唯一性使用其他十进制数字来标记它们。然后,根据这些标签,您可以使用 np.bincount累积总和,就像在 MATLAB 中一样,可以使用 accumarray 。这是实现所有这些的实现 -

import numpy as np

# Input arrays xyz and mass
xyz = np.array([
   [1, 0, 1],
   [1, 1, 0],
   [0, 1, 1],
   [0, 0, 0],
   [0, 1, 0],
   [1, 1, 0],
   [1, 0, 1],
   [0, 0, 1],
   [1, 0, 1],
   [0, 0, 1]])

mass = np.array([ 0.38668401,  0.44385111,  0.47756182,  0.74896529,  0.20424403,
    0.21828435,  0.98937523,  0.08736635,  0.24790248,  0.67759276])

# Convert each row entry in xyz into equivalent decimal numbers
dec_num = np.dot(xyz,2**np.arange(xyz.shape[1])[:,None])

# Get indices of the first occurrences of the unique values and also label each row
_, unq_idx,row_labels = np.unique(dec_num, return_index=True, return_inverse=True)

# Find unique rows from xyz array
xyz_unique = xyz[unq_idx,:]

# Accumulate the summations from mass based on the row labels
mass_unique = np.bincount(row_labels, weights=mass)

输出 -

In [148]: xyz_unique
Out[148]: 
array([[0, 0, 0],
       [0, 1, 0],
       [1, 1, 0],
       [0, 0, 1],
       [1, 0, 1],
       [0, 1, 1]])

In [149]: mass_unique
Out[149]: 
array([ 0.74896529,  0.20424403,  0.66213546,  0.76495911,  1.62396172,
        0.47756182])

案例 2:通用

对于一般情况,您可以使用此 -

import numpy as np

# Perform lex sort and get the sorted indices
sorted_idx = np.lexsort(xyz.T)
sorted_xyz =  xyz[sorted_idx,:]

# Differentiation along rows for sorted array
df1 = np.diff(sorted_xyz,axis=0)
df2 = np.append([True],np.any(df1!=0,1),0)

# Get unique sorted labels
sorted_labels = df2.cumsum(0)-1

# Get labels
labels = np.zeros_like(sorted_idx)
labels[sorted_idx] = sorted_labels

# Get unique indices
unq_idx  = sorted_idx[df2]

# Get unique xyz's and the mass counts using accumulation with bincount
xyz_unique = xyz[unq_idx,:]
mass_unique = np.bincount(labels, weights=mass)

示例运行 -

In [238]: xyz
Out[238]: 
array([[1, 2, 1],
       [1, 2, 1],
       [0, 1, 0],
       [1, 0, 1],
       [2, 1, 2],
       [2, 1, 1],
       [0, 1, 0],
       [1, 0, 0],
       [2, 1, 0],
       [2, 0, 1]])

In [239]: mass
Out[239]: 
array([ 0.5126308 ,  0.69075674,  0.02749734,  0.384824  ,  0.65151772,
        0.77718427,  0.18839268,  0.78364902,  0.15962722,  0.09906355])

In [240]: xyz_unique
Out[240]: 
array([[1, 0, 0],
       [0, 1, 0],
       [2, 1, 0],
       [1, 0, 1],
       [2, 0, 1],
       [2, 1, 1],
       [1, 2, 1],
       [2, 1, 2]])

In [241]: mass_unique
Out[241]: 
array([ 0.78364902,  0.21589002,  0.15962722,  0.384824  ,  0.09906355,
        0.77718427,  1.20338754,  0.65151772])

关于python - 查找 3 维 numpy 数组的唯一值的索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/29566885/

相关文章:

python - 在生产环境中使用 Flask 编译 Coffeescript

ruby - 如何删除数组哈希中的部分?

node.js - Mongoose for var in loop 在子文档中

Java - 同时执行两个相应的for循环

python - 如何在终端模式执行(批处理模式)中找到节点的宽度?

python - 更新 python 时避免将 SAME 字符串写入文本文件

python - 在 python 中捕获 ONLY stdout 的输出

php - 此数组插入的更短语法

javascript - 尝试将这个大的 if else 语句转换为循环

swift - 在 for ... where 子句中测试枚举是否相等