python - 自定义 Keras Metric 抛出错误

标签 python keras metric

在尝试实现 Intersection over Union (IoU) 时,我遇到了一个我似乎无法定位的 python/keras 错误。 在一个单独的文件中,我定义了以下指标:

def computeIoU(y_pred_batch, y_true_batch):
    print y_true_batch.shape[0]
    return np.mean(np.asarray([imageIoU(y_pred_batch[i], y_true_batch[i]) for i in range(y_true_batch.shape[0])]))

def imageIoU(y_pred, y_true):
    y_pred = np.argmax(y_pred, axis=2)
    y_true = np.argmax(y_true, axis=2)
    inter = 0
    union = 0
    for x in range(imCols):
        for y in range(imRows):
            for i in range(num_classes):
                inter += (y_pred[y][x] == y_true[y][x] == i)
                union += (y_pred[y][x] == i or y_true[y][x] == i)
    print inter
    print union
    return float(inter)/union

在主文件中,我导入了函数并使用指标,如下所示:

fcn32_model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy', computeIoU])

抛出的错误是

TypeError: __int__ should return int object

使用此处答案和另一个问题中建议的 Keras/tf 语法实现上述算法后,代码更改为:

def iou(y_pred_batch, y_true_batch):
    intersection = tf.zeros(())
    union = tf.zeros(())
    y_pred_batch = K.argmax(y_pred_batch, axis=-1)
    y_true_batch = K.argmax(y_true_batch, axis=-1)
    for i in range(num_classes):
        iTensor = tf.to_int64(tf.fill(y_pred_batch.shape, i))
        intersection = tf.add(intersection, tf.to_float(tf.count_nonzero(tf.logical_and(K.equal(y_true_batch, y_pred_batch), K.equal(y_true_batch, iTensor)))))
        union = tf.add(union, tf.to_float(tf.count_nonzero(tf.logical_or(K.equal(y_true_batch, iTensor), K.equal(y_pred_batch, iTensor)))))
    return intersection/union

最佳答案

问题似乎是您正在尝试以普通整数进行计算,而不是 keras 变量。

intersection = K.sum(K.abs(y_true * y_pred), axis=-1)
union_sum = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1)
IOU = (intersection) / (union_sum- intersection)

关于python - 自定义 Keras Metric 抛出错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49858233/

相关文章:

python-3.x - Keras 在每个 epoch 中占用的内存量无限增加

python - ValueError : `class_weight` must contain all classes in the data. 类{1,2,3}存在于数据中但不存在于 `class_weight`

java - 关于带焊接/CDI 的 LCOM4 的问题?

python - Alembic --autogenerate 对我来说失败了

Python - 'Error: AudioFileOpen failed (' wht ?')',在乒乓球游戏中播放音频文件

python - matplotlib 动画绘图散点图

python - keras中的精度计算不匹配

facebook - 获取 "People talking about this"指标 (PTAT)

java - 使用 BigInteger

python - 在 Tensorflow 中使用 InceptionV3 进行预测