python - pytorch masked_fill : why can't I mask all zeros?

标签 python tensorflow pytorch

我想用 -np.inf 屏蔽分数矩阵中的所有零,但我只能屏蔽部分零,看起来像

enter image description here

你会在右上角看到仍然有没有被 -np.inf 掩盖的零

这是我的代码:

q = torch.Tensor([np.random.random(10),np.random.random(10),np.random.random(10), np.random.random(10), np.zeros((10,1)), np.zeros((10,1))])
k = torch.Tensor([np.random.random(10),np.random.random(10),np.random.random(10), np.random.random(10), np.zeros((10,1)), np.zeros((10,1))])
scores = torch.matmul(q, k.transpose(0,1)) / math.sqrt(10)
mask = torch.Tensor([1,1,1,1,0,0])
mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask==0, -np.inf)

也许面具是错的?

最佳答案

你的面具错了。尝试

scores = scores.masked_fill(scores == 0, -np.inf)
scores 现在看起来像
tensor([[1.4796, 1.2361, 1.2137, 0.9487,   -inf,   -inf],
        [0.6889, 0.4428, 0.6302, 0.4388,   -inf,   -inf],
        [0.8842, 0.7614, 0.8311, 0.6431,   -inf,   -inf],
        [0.9884, 0.8430, 0.7982, 0.7323,   -inf,   -inf],
        [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf],
        [  -inf,   -inf,   -inf,   -inf,   -inf,   -inf]])

关于python - pytorch masked_fill : why can't I mask all zeros?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56328630/

相关文章:

python - 如何在字典列表中查找值并返回特定键

Python - 在 Mac 上安装 opencv 时遇到问题(几个月前 opencv 运行良好之后)

python - 想要通过占位符选择运行时训练的变量

python - 从 Iris CSV 数据训练的完整 Tensorflow 使用

python - Pytorch 训练和评估不同的样本量

python __getattribute__ 返回类变量属性时出现RecursionError

python - query_set 中的 values_list() 只显示数字而不显示国家名称

python - 为什么 ptb_word_ln.py 中 embedding_lookup 仅用作编码器而不用作解码器

deep-learning - 应用softmax后提取标签

image-processing - 带填充的平均池的期望行为是什么?