python - 对keras中的部分张量应用不同的损失函数

标签 python tensorflow keras

我正在尝试构建一个自定义损失函数,它将根据groundtruth将不同的函数应用于张量的不同部分。 .

例如说 groundtruth是:

[0 1 1 0]

我要申请log(n)到输出张量的索引 1, 2(即真实值中值为 1 的索引),并应用 log(n-1)其余的。

我怎样才能实现它?

最佳答案

您可以创建两个蒙版。

  • 第一个函数屏蔽了零,因此您可以将其应用于第一个损失函数,在该函数中仅将 log(n) 应用于那些 1 的值。

  • 第二个掩码会屏蔽掉那些,因此您可以将其应用于第二个损失函数,在该函数中将 log(n-1) 应用于那些 0 值。

类似于:

input = tf.constant([0, 1, 1, 0], tf.float32)
mask1 = tf.cast(tf.equal(input, 1.0), tf.float32)
loss1 = tf.log(input) * mask1

mask2 = tf.cast(tf.equal(input, 0.0), tf.float32)
loss2 = tf.log(input - 1) * mask2

overall_loss = tf.add(loss1, loss2)

关于python - 对keras中的部分张量应用不同的损失函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56942351/

相关文章:

python - MySQLClient 安装错误 : "raise Exception("Wrong MySQL configuration: maybe https://bugs. mysql.com/bug.php?id"

python - 分组后用均值和标准差绘制误差线

python - 如何以 Pythonic 方式将此元组列表转换为此字典?

python - 如何对目录中的所有文件运行单元测试?

machine-learning - 在多个类上进行训练时如何在 Keras 中获取标签 id?

python - 关于 .shuffle、.batch 和 .repeat 的 Tensorflow 数据集问题

python - cuda 的 tensorflow 错误

python - 使用 Keras 微调 ResNet50 - val_loss 不断增加

python - Keras中如何实现RBF激活函数?

python - Keras自定义损失函数访问python全局变量时的内部机制是什么?