pytorch - torch 在 2 个 2D 张量中找到匹配行的索引

标签 pytorch tensor

我有两个二维张量,长度不同,都是同一个原始二维张量的不同子集,我想找到所有匹配的“行”
例如

A = [[1,2,3],[4,5,6],[7,8,9],[3,3,3]
B = [[1,2,3],[7,8,9],[4,4,4]]
torch.2dintersect(A,B) -> [0,2] (the indecies of A that B also have)

我只看到 numpy 解决方案,它使用 dtype 作为字典,不适用于 pytorch。


这是我在 numpy 中的做法

arr1 = edge_index_dense.numpy().view(np.int32)
arr2 = edge_index2_dense.numpy().view(np.int32)
arr1_view = arr1.view([('', arr1.dtype)] * arr1.shape[1])
arr2_view = arr2.view([('', arr2.dtype)] * arr2.shape[1])
intersected = np.intersect1d(arr1_view, arr2_view, return_indices=True)

最佳答案

这个答案是在 OP 使用其他限制更新问题之前发布的,这些限制使问题发生了很大变化。

TL;DR 你可以这样做:

torch.where((A == B).all(dim=1))[0]

首先,假设您有:

import torch
A = torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
B = torch.Tensor([[1,2,3],[4,4,4],[7,8,9]])

我们可以检查 A == B 返回:

>>> A == B
tensor([[ True,  True,  True],
        [ True, False, False],
        [ True,  True,  True]])

所以,我们想要的是:它们都为 True 的行。为此,我们可以使用 .all() 操作并指定感兴趣的维度,在我们的例子中是 1:

>>> (A == B).all(dim=1)
tensor([ True, False,  True])

您真正想知道的是 True 在哪里。为此,我们可以获得 torch.where() 函数的第一个输出:

>>> torch.where((A == B).all(dim=1))[0]
tensor([0, 2])

关于pytorch - torch 在 2 个 2D 张量中找到匹配行的索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59705001/

相关文章:

python - Pytorch DataLoader 不适用于远程解释器

python - 如何使用 TensorFlow 的 eager execution 检查张量的内容?

neural-network - 在pytorch中执行卷积(不互相关)

python - pytorch中自定义数据集的数据预处理(transform.Normalize)

python - 'ToPILImage' 对象没有属性 'show'

python - Pytorch Siamese 网络不收敛

python - Tensorflow Tf.tf.squared_difference 显示密集层的值错误

python - tensorflow 合并并压缩两个张量

cuda - 如何访问 CUDA 中的稀疏张量核心功能?

python - 为什么 PyTorch nn.Module.cuda() 不移动模块张量而只移动参数和缓冲区到 GPU?