在 Pytorch 中,假设我有一个 top-k 索引矩阵 P(B,N,k)
、一个权重矩阵 W(B,N,N)
和目标矩阵A(B,N,N)
,我想获得一个按以下循环操作的相邻矩阵:
for i in range(B):
for ii in range(N):
for j in range(k):
if weighted:
A[i][ii][P[i][ii][j]] = W[i][ii][P[i][ii][j]]
else:
A[i][ii][P[i][ii][j]] = 1
如何在Pytorch中更高效、简洁地实现?
最佳答案
我认为您正在寻找 torch.scatter_
:
A.scatter_(dim=2, index=P, src=W) # for the weighted version
A.scatter_(dim=2, index=P, src=torch.ones_like(W)) # for the un-weighted version
关于matrix - 在Pytorch中创建knn邻接矩阵,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65517572/