python - 随机获取 PyTorch 张量中最大值之一的索引

标签 python python-3.x pytorch

我需要在一维张量上执行类似于内置 torch.argmax() 函数的操作,但我希望能够随机选择一个最大值而不是选择第一个最大值的索引最大值之一的索引。例如:

my_tensor = torch.tensor([0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.1])
index_1 = random_max_val_index_fn(my_tensor)
index_2 = random_max_val_index_fn(my_tensor)

print(f"{index_1}, {index_2}")
> 5, 1

最佳答案

可以先获取所有最大值的索引,然后从中随机选择:

def rand_argmax(tens):
    max_inds, = torch.where(tens == tens.max())
    return np.random.choice(max_inds)

样本运行:

>>> my_tensor = torch.tensor([0.1, 0.2, 0.2, 0.1, 0.1, 0.2, 0.1])
>>> rand_argmax(my_tensor)
2
>>> rand_argmax(my_tensor)
5
>>> rand_argmax(my_tensor)
2
>>> rand_argmax(my_tensor)
1

关于python - 随机获取 PyTorch 张量中最大值之一的索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68499637/

相关文章:

python - 使用Python OpenCV捏缩/凸出变形

python - 从数据框中绘制并排堆叠条形图

python - 有没有办法在训练步骤后直接更新层/变量的权重?

python-3.x - 使用 apply lambda 检查条件时,系列的真值不明确

python - pytorch/torchtext 中的单热编码

python - PyTorch张量高级索引

python - 为什么 sympy 不简化导数的傅里叶变换?

python - 通过添加小数时间增量增量来删除重复的日期时间索引值

python - 具有随机幂律分布权重的网络

python-3.x - 无法使用从 git 安装的 pip 包