python - 在 PyTorch 和 Numpy 中快速生成形状为 (1, 1, 256) 和 (10, 1, 256) 的多个 3D 张量

标签 python numpy pytorch

我正在尝试根据自己的任务调整 seq2seq 模型,https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation.ipynb

我在解码阶段有两个张量

rnn_output: (1, 1, 256)       # time_step x batch_size x hidden_dimension
encoder_inputs: (10, 1, 256)  # seq_len   x batch_size x hidden_dimension

它们应该相乘以获得形状的注意力分数(在 softmax 之前)

attn_score: (10, 1, 1) 

最好的方法是什么? notebook好像用了for循环,有没有更好的矩阵乘法之类的操作?

最佳答案

没有使用 pytorch 的经验,但是这样的东西可以工作吗?

torch.einsum('ijk,abk->abc', (rnn_output, encoder_inputs))

它对最后一个轴进行点积,并向后添加几个空轴。

类似的东西可以用纯 numpy 实现(pytorch 0.4 还没有 ... 符号)

np.einsum('...ik,...jk', rnn_output.numpy(), encoder_inputs.numpy())

或者用np.tensordot

np.tensordot(rnn_output.numpy(), encoder_inputs.numpy(), axes=[2,2])

但是在这里你会得到输出形状:(1, 1, 10, 1)

您可以通过挤压和重新展开来解决这个问题(几乎可以肯定必须有一些更清洁的方法来做到这一点)

 np.tensordot(rnn_output.numpy(), encoder_inputs.numpy(), axes=[2,2]).squeeze()[..., None, None]

关于python - 在 PyTorch 和 Numpy 中快速生成形状为 (1, 1, 256) 和 (10, 1, 256) 的多个 3D 张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50570697/

相关文章:

python - Python 中的波浪号运算符

python - 为什么 GPU 上的乘法比 CPU 慢?

python - 为什么在用 python 编写 excel 文件时会得到像 [0, 1, 2, ...] 这样的行作为标题?

python - 为什么我得到 "IndexError: string index out of range"?

Python崩溃没有错误

python - pandas 比应用 lambda 更快的方式在每一行应用逻辑?

python - 找到第一个 np.nan 值位置的最有效方法是什么?

python - 通过无法执行的并行Python执行Fortran子例程

python - 可变大小输入的小批量训练

python - PyTorch 中是否有提取图像补丁的功能?