假设我们有 2 个矩阵:
mat = torch.randn([20, 7]) * 100
mat2 = torch.randn([7, 20]) * 100
n, m = mat.shape
最简单的常用矩阵乘法如下所示:
def mat_vec_dot_product(mat, vect):
n, m = mat.shape
res = torch.zeros([n])
for i in range(n):
for j in range(m):
res[i] += mat[i][j] * vect[j]
return res
res = torch.zeros([n, n])
for k in range(n):
res[:, k] = mat_vec_dot_product(mat, mat2[:, k])
但是如果我需要应用 L2 范数而不是点积怎么办?接下来是代码:
def mat_vec_l2_mult(mat, vect):
n, m = mat.shape
res = torch.zeros([n])
for i in range(n):
for j in range(m):
res[i] += (mat[i][j] - vect[j]) ** 2
res = res.sqrt()
return res
for k in range(n):
res[:, k] = mat_vec_l2_mult(mat, mat2[:, k])
我们能否使用 Torch 或任何其他库以某种方式以最佳方式做到这一点?因为天真的 O(n^3) Python 代码运行起来真的很慢。
最佳答案
使用torch.cdist
对于 L2 范数 - 欧氏距离
res = torch.cdist(mat, mat2.permute(1,0), p=2)
在这里,我使用 permute
将 mat2
的 dim 从 7,20
交换为 20,7
关于python - 在 PyTorch 中计算欧氏距离而不是矩阵乘法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63727907/