python - 在 Numpy/PyTorch 中快速查找值大于阈值的索引

标签 python numpy pytorch

任务

给定一个 numpy 或 pytorch 矩阵,找到值大于给定阈值的单元格的索引。

我的实现

#abs_cosine is the matrix
#sim_vec is the wanted

sim_vec = []
for m in range(abs_cosine.shape[0]):
    for n in range(abs_cosine.shape[1]):
        # exclude diagonal cells
        if m != n and abs_cosine[m][n] >= threshold:
            sim_vec.append((m, n))

疑虑

速度。所有其他计算都建立在 Pytorch 上,使用 numpy 已经是一种妥协,因为它将计算从 GPU 转移到了 CPU。纯 python for 循环会使整个过程变得更糟(对于小数据集已经慢了 5 倍)。 我想知道我们是否可以在不调用任何 for 循环的情况下将整个计算转移到 Numpy(或 pytorch)?

我能想到的改进(但卡住了...)

bool_cosine = abs_cosine > threshold

返回 TrueFalse 的 bool 矩阵。但是我找不到快速检索 True 单元格索引的方法。

最佳答案

以下是 PyTorch(完全在 GPU 上)

# abs_cosine should be a Tensor of shape (m, m)
mask = torch.ones(abs_cosine.size()[0])
mask = 1 - mask.diag()
sim_vec = torch.nonzero((abs_cosine >= threshold)*mask)

# sim_vec is a tensor of shape (?, 2) where the first column is the row index and second is the column index

以下在 numpy 中工作

mask = 1 - np.diag(np.ones(abs_cosine.shape[0]))
sim_vec = np.nonzero((abs_cosine >= 0.2)*mask)
# sim_vec is a 2-array tuple where the first array is the row index and the second array is column index

关于python - 在 Numpy/PyTorch 中快速查找值大于阈值的索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50045202/

相关文章:

python - 如何将稀疏 pandas 数据帧转换为 2d numpy 数组

python - python 中的 Turtle 模块不导入

python - Log(2 ** 62,2).is_integer() 返回 false 。任何想法为什么?

python - 从 numpy 数组创建字典

python - 内置计算协方差的函数

python - layout = torch.strided 是什么意思?

python - Keras、TorchVision 中的预训练模型

python - RuntimeError : output with shape [1, 224, 224] 与广播形状 [3, 224, 224] 不匹配

python - Numpy 索引 : Set values of an array given by conditions in different array

python - 在 Python 中使用 igraph 创建网络的性能瓶颈