python - 绕过不可微分的 tf.argmax

标签 python tensorflow

我为我的神经网络编写了自定义损失函数,但它无法计算任何梯度。我认为这是因为我需要最高值的索引,因此使用 argmax 来获取该索引。

由于 argmax 不可微分,我可以绕过这个问题,但我不知道这怎么可能。

有人能帮忙吗?

最佳答案

正如 aidan 所建议的,它只是一个 softargmax被 beta 拉到了极限。我们可以使用 tf.nn.softmax 来解决数值问题:

def softargmax(x, beta=1e10):
  x = tf.convert_to_tensor(x)
  x_range = tf.range(x.shape.as_list()[-1], dtype=x.dtype)
  return tf.reduce_sum(tf.nn.softmax(x*beta) * x_range, axis=-1)

关于python - 绕过不可微分的 tf.argmax,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46926809/

相关文章:

python - 使用 POST 数据的 HTTPS Flask

android - 在 Android 上保持 TensorFlow 模型加密

python - 分布式 tensorflow 复制训练示例: grpc_tensorflow_server - No such file or directory

python - 无法读取的笔记本 : FileNotFoundError(2, 'No such file or directory' )

python - 在 keras 中禁用 tensorflow 回溯警告

tensorflow - TPU术语混淆

python - 在 PyGObject 中设置样式属性

python - 通过pip安装tensorflow但是失败如下

python - 在 python 中使用 numpy 时出现 "shape mismatch"错误

Python super() 不适用于 Enum 参数