debugging - 如何在Keras中调试自定义损失函数?

标签 debugging printing callback keras tensor

我使用参数创建了自定义损失函数。

def w_categorical_crossentropy(weights):
  def loss(y_true, y_pred):
  print(weights)
  print("----")
  print(weights.shape)
  final_mask = K.zeros_like(y_pred[:, 0])
  y_pred_max = K.max(y_pred, axis=1)
  y_pred_max = K.reshape(y_pred_max, (K.shape(y_pred)[0], 1))
  y_pred_max_mat = K.cast(K.equal(y_pred, y_pred_max), K.floatx())
  return K.categorical_crossentropy(y_pred, y_true)
return loss

现在,我需要控制权重参数值,但是打印功能无法正常工作。有什么方法可以打印权重值?

最佳答案

我有时会做的事情(肯定不是最好的解决方案,也永远不可能),只是用np替换K后端,并用一些简单的数据对其进行测试。这是一个例子

原始Keras函数:

def loss(y_true, y_pred):
    means = K.reshape(y_pred[:, 0], (-1, 1))
    stds = K.reshape(y_pred[:, 1], (-1, 1))
    var = K.square(stds)
    denom = K.sqrt(2 * np.pi * var)
    prob_num = - K.square(y_true - means) / (2 * var)
    prob = prob_num - denom
    r = K.exp(prob - old_prediction)
    return -K.mean(K.minimum(r * advantage, K.clip(r, min_value=1 - self.LOSS_CLIPPING, max_value=1 + self.LOSS_CLIPPING) * advantage))

测试功能:
def loss(y_true, y_pred):
    means = np.reshape(y_pred[:, 0], (-1, 1))
    stds = np.reshape(y_pred[:, 1], (-1, 1))
    var = np.square(stds)
    print(var.shape)
    denom = np.sqrt(2 * np.pi * var)
    print(denom.shape)
    prob_num = - np.square(y_true - means) / (2 * var)
    prob = prob_num - denom
    r = np.exp(prob - old_prediction)
    print(r.shape)
    cliped = np.minimum(r * advantage, np.clip(r, a_min=1 - LOSS_CLIPPING, a_max=1 + LOSS_CLIPPING) * advantage)
    print(cliped.shape)
    return -np.mean(cliped)

测试它:
LOSS_CLIPPING = 0.2
y_pred = np.array([[2,1], [1, 1], [5, 1]])
y_true = np.array([[1], [3], [2]])
old_prediction = np.array([[-2], [-5], [-6]])
advantage = np.array([[ 0.51467506],[-0.64960159],[-0.53304715]])
loss(y_true, y_pred)

在上面运行之后,结果:
(3, 1)
(3, 1)
(3, 1)
(3, 1)
0.43409555193679816

关于debugging - 如何在Keras中调试自定义损失函数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49201632/

相关文章:

python - 在pycharm中调试python代码

c - 如何像在 turbo C 中那样在 geany 中逐行执行我的 C 程序

c# - .NET PrintDocument - 文本被截断

java - Twitter Android SDK 不执行回调

linux - 警告 : GDB: Failed to set controlling terminal: Invalid argument

xcode - XCode:调试=>“模拟位置”菜单项不可用(灰色)

PHP - 尽管打印文本,但为什么不打印变量?

java - 打印所有 JVM 标志

objective-c - Objective-C 中的 Typedef 返回类型在 Swift 中不起作用

JavaScript 闭包和回调函数