python - 用掩码为真的另一个张量填充张量

标签 python pytorch tensor

我需要以一定的概率将张量 new 的元素插入到张量 old 中,为简单起见,我们假设它是 0.8。 基本上这就是 masked_fill 会做的,但它只适用于一维张量。 其实我在做

    prob = torch.rand(trgs.shape, dtype=torch.float32).to(trgs.device)
    mask = prob < 0.8

    dim1, dim2, dim3, dim4 = new.shape
    for a in range(dim1):
        for b in range(dim2):
            for c in range(dim3):
                for d in range(dim4):
                    old[a][b][c][d] = old[a][b][c][d] if mask[a][b][c][d] else new[a][b][c][d]

这太糟糕了。我想要类似的东西

    prob = torch.rand(trgs.shape, dtype=torch.float32).to(trgs.device)
    mask = prob < 0.8

    old = trgs.multidimensional_masked_fill(mask, new)

最佳答案

我不确定你的一些对象是什么,但这应该能让你在短时间内完成你需要做的事情:

old 是您现有的数据。

mask 是您以概率 p 生成的掩码

new 是包含您要插入的元素的新张量。

# torch.where
result = old.where(mask, new)

关于python - 用掩码为真的另一个张量填充张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66688647/

相关文章:

python - Google App Engine 中的全局异常处理

python - 生成 1D 张量作为 2D 张量的行的唯一索引

Pytorch Tensor::data_ptr<long long>() 不适用于 Linux

docker - 在TF中服务多个TF模型服务并创建适当的客户端请求以根据请求与特定模型进行交互

python - 调用 storage() 方法时,Pytorch Tensor 存储具有相同的 id

python - 将文本从第一人称转换为第二人称时出现问题,同时忽略引号内的文本 ""

python - 如何将 Django 项目添加/导入到 virtualenv 中?

python - 在 Python 3 中调用命令行参数

tensorflow - tensorflow 中的索引比收集慢

python - RuntimeError : Expected a Tensor of type torch. FloatTensor 但发现序列元素的类型 torch.IntTensor