python - 禁止在 PyTorch 神经网络的 CrossEntropyLoss 中使用 Softmax

标签 python neural-network pytorch softmax cross-entropy

我知道当使用 nn.CrossEntropyLoss 作为损失函数时,不需要在神经网络的输出层使用 nn.Softmax() 函数。

但是我需要这样做,有没有办法抑制 nn.CrossEntropyLoss 中 softmax 的实现使用,而是在我的输出中使用 nn.Softmax()神经网络本身的层?

动机:我正在使用 shap 包来分析之后的特征影响,我只能将训练好的模型作为输入。然而,输出没有任何意义,因为我正在查看未绑定(bind)的值而不是概率。

示例:我想要的不是 -69.36 作为我模型的一个类的输出值,而是介于 0 和 1 之间的值,所有类的总和为 1。由于我之后无法更改它,因此在训练期间输出需要已经像这样。

enter image description here

最佳答案

nn.CrossEntropyLoss 的文档说,

This criterion combines nn.LogSoftmax() and nn.NLLLoss() in one single class.

我建议您坚持使用 CrossEntropyLoss 作为损失标准。但是,您可以使用 softmax 将模型的输出转换为概率值。功能。

请注意,您始终可以使用模型的输出值,不需要为此更改损失标准。

但是如果你仍然想在你的网络中使用Softmax(),那么你可以使用NLLLoss()作为损失准则,只应用log()在将模型的输出提供给准则函数之前。同样,如果您在网络中使用 LogSoftmax,则可以应用 exp()获取概率值。

更新:

要在 Softmax 输出上使用 log(),请执行以下操作:

torch.log(prob_scores + 1e-20)

通过向 prob_scores 添加一个非常小的数字 (1e-20),我们可以避免 log(0) 问题。

关于python - 禁止在 PyTorch 神经网络的 CrossEntropyLoss 中使用 Softmax,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58122505/

相关文章:

python - 如何从列表列表中识别所有相同的列表?

python - JMeter - 在调用每个 HTTP 请求采样器之前运行 python 脚本

如果服务器关闭连接,Python 客户端会断开连接

python - PyTables 列到普通 python 列表

machine-learning - 数据集中所有图像中特定对象的存在是否会影响 CNN 的性能

machine-learning - 为什么我的训练模型不能预测正确的结果?

machine-learning - 在模式识别问题中使用什么更好?机器学习还是神经网络?

python-3.x - 如何使用 SWA :stochastic weights average? 添加 BatchNormalization

python - 由另一个多维张量索引多维 torch 张量

python - PyTorch 模型未进行训练