numpy - Torch 广播如何为 (8, 8) @ (4, 8, 2) 工作?

标签 numpy pytorch

我的张量大小为:

a.shape 
=> (8, 8)
b.shape 
=> (4, 8, 2)

c = a @ b 
c.shape
=> (4, 8, 2)

我很惊讶地发现 a @ b 是可以广播的。我检查了 pytorch 的广播规则,似乎这些规则不兼容。

有人可以说明一下这是如何计算的吗?

最佳答案

查看 torch.matmul 的文档

在本例中,b 被解释为批处理矩阵,第一个轴是批处理维度。

关于numpy - Torch 广播如何为 (8, 8) @ (4, 8, 2) 工作?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/77299241/

相关文章:

python - 将张量列表转换为张量

python-3.x - torch.stack([t1,t1,t1],dim=1) 和 torch.hstack([t1,t1,t1]) 之间有什么区别?

pytorch - 在训练模式下 DeepLabV3 的输入应该是什么?

python - 如何优化推理脚本以获得更快的分类器预测?

python - 使用 Python JAX/Autograd 计算向量值函数的雅可比行列式

Python:元素明智除法运算符错误

python - Pandas Series 中的名称参数是什么?

Python Numpy 掩码 NaN 不起作用

python - 使用 pytorch/python,覆盖变量还是定义新变量更好?

python - numpy.linalg.inv() 是否给出了正确的矩阵逆?编辑: Why does inv() gives numerical errors?