python - 如何使用三角矩阵使 np.where 更高效?

标签 python numpy scipy python-itertools

我得到了这段代码,其中距离是定义如下的下三角矩阵:

distance = np.tril(scipy.spatial.distance.cdist(points, points))  
def make_them_touch(distance):
    """
    Return the every distance where two points touched each other. See example below.
    """
    thresholds = np.unique(distance)[1:] # to avoid 0 at the beginning, not taking a lot of time at all
    result = dict()
    for t in thresholds:
            x, y = np.where(distance == t)
            result[t] = [i for i in zip(x,y)]
    return result

我的问题是 np.where 对于大矩阵(例如 2000*100)来说非常慢。
如何通过改进 np.where 或更改算法来加速此代码?

编辑:MaxU指出,这里最好的优化不是生成方阵并使用迭代器。

示例:

points = np.array([                                                                        
...: [0,0,0,0],                                                            
...: [1,1,1,1],         
...: [3,3,3,3],              
...: [6,6,6,6]                             
...: ])  

In [106]: distance = np.tril(scipy.spatial.distance.cdist(points, points))

In [107]: distance
Out[107]: 
array([[ 0.,  0.,  0.,  0.],
   [ 2.,  0.,  0.,  0.],
   [ 6.,  4.,  0.,  0.],
   [12., 10.,  6.,  0.]])

In [108]: make_them_touch(distance)
Out[108]: 
{2.0: [(1, 0)],
 4.0: [(2, 1)],
 6.0: [(2, 0), (3, 2)],
 10.0: [(3, 1)],
 12.0: [(3, 0)]}

最佳答案

更新1:这是上三角距离矩阵的片段(这并不重要,因为距离矩阵始终是对称的):

from itertools import combinations

res = {tup[0]:tup[1] for tup in zip(pdist(points), list(combinations(range(len(points)), 2)))}

结果:

In [111]: res
Out[111]:
{1.4142135623730951: (0, 1),
 4.69041575982343: (0, 2),
 4.898979485566356: (1, 2)}
<小时/>

更新2:此版本将支持距离重复:

In [164]: import pandas as pd

首先我们构建一个 Pandas.Series:

In [165]: s = pd.Series(list(combinations(range(len(points)), 2)), index=pdist(points))

In [166]: s
Out[166]:
2.0     (0, 1)
6.0     (0, 2)
12.0    (0, 3)
4.0     (1, 2)
10.0    (1, 3)
6.0     (2, 3)
dtype: object

现在我们可以按索引分组并生成坐标列表:

In [167]: s.groupby(s.index).apply(list)
Out[167]:
2.0             [(0, 1)]
4.0             [(1, 2)]
6.0     [(0, 2), (2, 3)]
10.0            [(1, 3)]
12.0            [(0, 3)]
dtype: object

PS 这里的主要思想是,如果您打算随后将其展平并消除重复项,则不应构建平方距离矩阵。

关于python - 如何使用三角矩阵使 np.where 更高效?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50907049/

相关文章:

Python终端打开后报错

python - tkinter 中不断更新图表

python - 比较两个 numpy 数组是否符合两个条件

python - 在 Odoo 8 中,如何将模块 .py 文件导入自定义模块?

python - 代码适用于 Python 3.6,但不适用于 3.7

python - 在 Synology 上安装 python 模块 - pip 错误

python - 如何从 Python 中的两个 n 维数组中获取匹配的行?

python - 用python求解联立多元多项式方程

python - 如何用python获得复杂网络的拉普拉斯矩阵的第二小特征值?

python - 如何从python中的wav文件绘制波形?