我想定义自定义损失,但似乎无法将 keras 张量 K.sum(y_true)
与单个值 0
进行比较。
def custom_loss_keras(y_true, y_pred):
if(K.sum(y_true) > 0):
loss = K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
return loss
else:
loss = 0.0
return loss
我还在损失函数中尝试了 K.eval()
来获取 numpy 数组,但失败了。
def custom_loss_keras(y_true, y_pred):
y_true_np = K.eval(y_true)
#if(K.sum(y_true) > 0):
if(np.sum(y_true_np) > 0):
loss = K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
return loss
else:
loss = 0.0
return loss
更新:
def custom_loss_keras(y_true, y_pred):
if(K.greater(K.sum(y_true), 0)):
loss = K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
return loss
else:
loss = 0.0
return loss
它产生错误:
TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.
我还尝试将建议的 tf.cond
与 keras 函数结合起来:
def custom_loss_keras(y_true, y_pred):
loss = tf.cond(K.greater(K.sum(y_true),0), K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1), 0.0)
return loss
它产生错误:
22 def custom_loss_keras(y_true, y_pred):
---> 23 loss = tf.cond(K.greater(K.sum(y_true),0), K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1), 0.0)
24
25 return loss
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py in cond(pred, fn1, fn2, name)
1718 with ops.name_scope(name, "cond", [pred]) as name:
1719 if not callable(fn1):
-> 1720 raise TypeError("fn1 must be callable.")
1721 if not callable(fn2):
1722 raise TypeError("fn2 must be callable.")
TypeError: fn1 must be callable.
看来我需要用纯 tensorflow 来编写它?
最佳答案
在损失函数中使用 if
和 else
(或 K.eval
)将不起作用,因为 custom_loss_keras 中的行
在模型编译期间执行,而不是模型拟合期间执行。
您可以使用K.switch
,而不是调用tf.cond
:
def custom_loss_keras(y_true, y_pred):
loss = K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)
condition = K.greater(K.sum(y_true), 0)
return K.switch(condition, loss, K.zeros_like(loss))
关于python - Keras:如何将自定义损失中的 K.sum(y_true) 与 0 进行比较?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49533965/