python - 为什么模式为 "max"的 Pytorch EmbeddingBag 不接受 `per_sample_weights` ?

标签 python pytorch embedding weighted-average

Pytorch 的 EmbeddingBag 允许对不同长度的嵌入索引集合进行高效的查找和归约操作。 reduce 操作有 3 种模式:“sum”、“average”和“max”。使用“sum”,您还可以提供 per_sample_weights 给您一个加权总和。

为什么 per_sample_weights 不允许进行“最大”操作?看着how it's implemented ,我只能假设在“Mul”操作之后执行“ReduceMean”或“ReduceMax”操作存在问题。这可能与计算梯度有关吗??


p.s:通过除以权重总和可以很容易地将加权和转换为加权平均值,但对于“最大值”,您无法得到这样的加权等价物。

最佳答案

参数 per_sample_weights 仅针对 mode='sum' 实现,不是由于技术限制,而是因为开发人员没有发现“最大加权”的用例:

I haven't been able to find use cases for "weighted mean" (which can be emulated via weighted sum) and "weighted max".

关于python - 为什么模式为 "max"的 Pytorch EmbeddingBag 不接受 `per_sample_weights` ?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66666246/

相关文章:

python - 如何使用 WeightedRandomSampler 平衡 PyTorch 中的不平衡数据?

python - 多个 PyTorch 网络在不同 CPU 上并行运行

pytorch - 属性错误: 'builtin_function_or_method' object has no attribute 'requires_grad'

python - 在 Python 中使用 Peewee 的 SQL DDL 中的 SQLite 触发器和日期时间默认值

python - numpy 数组的 zip_longest

python - youtube-dl下载带有formatid的视频

javascript - 如何在 JavaScript 中包含 Ruby 字符串

python - 如何将 Python 嵌入到多平台 C++ 框架(JUCE)中?

python - python 中的总和,其中列表列表呈指数形式

python - 找到python文件中字符串的位置