pytorch - 如何在pytorch中有效地重复时间的张量元素变量?

标签 pytorch tensor

例如,如果我有一个张量 A = [[1,1,1], [2,2,2], [3,3,3]],并且 B = [1,2,3]。如何获得 C = [[1,1,1], [2,2,2], [2,2,2], [3,3,3], [3,3,3], [3, 3,3]],并批量执行此操作?

顺便说一句,我当前的逐元素解决方案(需要永远......):

        def get_char_context(valid_embeds, words_lens):
            chars_contexts = []
            for ve, wl in zip(valid_embeds, words_lens):
                for idx, (e, l) in enumerate(zip(ve, wl)):
                    if idx ==0:
                        chars_context = e.view(1,-1).repeat(l, 1)
                    else:
                        chars_context = torch.cat((chars_context, e.view(1,-1).repeat(l, 1)),0)
                chars_contexts.append(chars_context)
            return chars_contexts

我这样做是为了将 bert 词嵌入添加到字符级 seq2seq 任务中...

最佳答案

使用这个:

import torch
# A is your tensor
B = torch.tensor([1, 2, 3])
C = A.repeat_interleave(B, dim = 0)

编辑:

如果 A 是单个二维张量,则上述方法可以正常工作。要以相同的方式重复批量重复所有(2D)张量,这是一个简单的解决方法:

A = torch.tensor([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], 
    [[1, 2, 3], [4, 5, 6], [2,2,2]]]) # A has 2 tensors each of shape (3, 3)
B = torch.tensor([1, 2, 3]) # Rep. of each row of every tensor in the batch

A1 = A.reshape(1, -1, A.shape[2]).squeeze()
B1 = B.repeat(A.shape[0])
C = A1.repeat_interleave(B1, dim = 0).reshape(A.shape[0], -1, A.shape[2])

C 是:

tensor([[[1, 1, 1],
         [2, 2, 2],
         [2, 2, 2],
         [3, 3, 3],
         [3, 3, 3],
         [3, 3, 3]],

        [[1, 2, 3],
         [4, 5, 6],
         [4, 5, 6],
         [2, 2, 2],
         [2, 2, 2],
         [2, 2, 2]]])

正如您所看到的,批处理中的每个内部张量都以相同的方式重复。

关于pytorch - 如何在pytorch中有效地重复时间的张量元素变量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67186064/

相关文章:

python - 如何在具有 2 个优化器的循环中调用 "backward"?

python - 如何使用pytorch获得jacobian多元正态分布的对数概率

python - 这两种神经网络结构有什么区别呢?

python - 通过索引列表选择pytorch张量元素

pytorch - 将张量扩展几个维度

Tensorflow 中张量的 For 循环

python - Pytorch 张量 - 如何通过特定张量获取索引

tensorflow - Tensorflow中哪个函数与Pytorch中的expand_as类似

python - 在 Tensorflow 中删除张量的维度

machine-learning - Pytorch 中的 int8 数据类型