python - 在 PyTorch 中计算欧氏距离而不是矩阵乘法

标签 python pytorch matrix-multiplication

假设我们有 2 个矩阵:

mat = torch.randn([20, 7]) * 100
mat2 = torch.randn([7, 20]) * 100

n, m = mat.shape

最简单的常用矩阵乘法如下所示:

def mat_vec_dot_product(mat, vect):
    n, m = mat.shape
    
    res = torch.zeros([n])
    for i in range(n):
        for j in range(m):
            res[i] += mat[i][j] * vect[j]
        
    return res

res = torch.zeros([n, n])
for k in range(n):
    res[:, k] = mat_vec_dot_product(mat, mat2[:, k])
    

但是如果我需要应用 L2 范数而不是点积怎么办?接下来是代码:

def mat_vec_l2_mult(mat, vect):
    n, m = mat.shape
    
    res = torch.zeros([n])
    for i in range(n):
        for j in range(m):
            res[i] += (mat[i][j] - vect[j]) ** 2
            
    res = res.sqrt()
        
    return res

for k in range(n):
    res[:, k] = mat_vec_l2_mult(mat, mat2[:, k])

我们能否使用 Torch 或任何其他库以某种方式以最佳方式做到这一点?因为天真的 O(n^3) Python 代码运行起来真的很慢。

最佳答案

使用torch.cdist对于 L2 范数 - 欧氏距离

res = torch.cdist(mat, mat2.permute(1,0), p=2)

在这里,我使用 permutemat2 的 dim 从 7,20 交换为 20,7

关于python - 在 PyTorch 中计算欧氏距离而不是矩阵乘法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63727907/

相关文章:

c - 这个使用 Strassen 算法进行矩阵乘法的 C 代码有什么问题?

python - Python Click 提示是否有任何预填充选项?

python - PyTorch 中的数据增强

python - 定心矩阵

pytorch - AllenNLP DatasetReader.read 返回生成器而不是 AllennlpDataset

python - 如何降级 Google Colab 的 Torch 版本

c - 2x2 矩阵乘法

python - 在 Python 中,如何解析表示一组关键字参数的字符串,使得顺序无关紧要

python - 类型错误 : translate() takes exactly one argument (2 given)

python - 在 python 中将字节转换为 BufferedReader