pytorch - 排列后如何进行张量运算

标签 pytorch permute tensordot

我有 2 个张量,A 和 B:

A = torch.randn([32,128,64,12],dtype=torch.float64)
B = torch.randn([64,12,64,12],dtype=torch.float64)
C = torch.tensordot(A,B,([2,3],[0,1]))
D = C.permute(0,2,1,3) # shape:[32,64,128,12]

张量 D 来自操作“tensordot -> permute”。如何实现新的操作 f() 以在 f() 之后进行张量操作,如下所示:

A_2 = f(A)
B_2 = f(B)
D = torch.tensordot(A_2,B_2)

最佳答案

您是否考虑过使用torch.einsum哪个非常灵活?

D = torch.einsum('ijab,abkl->ikjl', A, B)

tensordot 的问题在于,它在 B 的维度之前输出 A 的所有维度以及您要查找的内容(排列时)是“交错”AB 的维度。

关于pytorch - 排列后如何进行张量运算,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65716540/

相关文章:

python - torch 服务加载模型失败

python - 零丢包率的丢包层

javascript - 很长的排列 - 句子字谜

python - 使用排列在 Python 中广播

python - 张量运算python中的内存和时间

python - 输出 :\ntorch-1. 1.0-cp27-cp27mu-linux_x86_64.whl 不是此平台上支持的轮子 - Pytorch/云函数

python - 如何将 model.state_dict() 存储在临时变量中以供以后使用?

c - 将排列后的字符串存储到数组中

python - 执行大点/张量点积同时仅保留对角线条目的最有效方法

python - 将 np.einsum 转换为性能更高的内容