我为我的神经网络编写了自定义损失函数,但它无法计算任何梯度。我认为这是因为我需要最高值的索引,因此使用 argmax 来获取该索引。
由于 argmax 不可微分,我可以绕过这个问题,但我不知道这怎么可能。
有人能帮忙吗?
最佳答案
正如 aidan 所建议的,它只是一个 softargmax被 beta 拉到了极限。我们可以使用 tf.nn.softmax
来解决数值问题:
def softargmax(x, beta=1e10):
x = tf.convert_to_tensor(x)
x_range = tf.range(x.shape.as_list()[-1], dtype=x.dtype)
return tf.reduce_sum(tf.nn.softmax(x*beta) * x_range, axis=-1)
关于python - 绕过不可微分的 tf.argmax,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46926809/