python-3.x - PIL 逊线性系数keras

标签 python-3.x keras metrics pearson-correlation

我尝试在 Keras 中实现 PIL 逊线性系数作为度量,但是由于占位符,我无法使用该度量编译我的模型。

def CC(y_true, y_pred):

y_true = K.clip(y_true, K.epsilon(), 1)
y_pred = K.clip(y_pred, K.epsilon(), 1)
n_y_true=y_true/(K.sum(y_true)+K.epsilon())
n_y_pred=y_pred/(K.sum(y_pred)+K.epsilon())
y_true_average=K.mean(y_true)
y_pred_average=K.mean(y_pred)
print((K.map_fn(lambda x: x-y_pred_average,n_y_pred)).shape[0])
if not(K.map_fn(lambda x: x-y_pred_average,n_y_pred)).shape[0]==None:
    return (K.sum(K.dot((K.map_fn(lambda x: x-y_pred_average,n_y_pred)),(K.map_fn(lambda x: x-y_true_average,n_y_true))))/(K.count_params(n_y_true)-1))/(K.dot(K.std(n_y_pred),K.std(n_y_true)))
else:
    return 0

我尝试使用 K.dot 而不是 *,但仍然存在相同的错误。在编译过程中,我收到错误unsupported operand type(s) for *: 'NoneType' and 'NoneType。我不知道如何解决它。发生这种情况是因为我想将两个张量按元素相乘,但形状中的批量大小在编译期间未定义并表示为 ?形状为(?,224,224,3)。有没有办法设置或解决它?

最佳答案

问题在于两个事实:

  1. 张量的第一个维度是批量维度(这就是为什么在模型编译期间设置为None)。
  2. 您使用 summean 的方式使得您也在计算中包含了这个附加维度。

您的 Pearson 相关损失应如下所示:

def pearson_loss(y_true, y_pred):
    y_true = K.clip(y_true, K.epsilon(), 1)
    y_pred = K.clip(y_pred, K.epsilon(), 1)
    # reshape stage
    y_true = K.reshape(y_true, shape=(-1, 224 * 224 * 3))
    y_pred = K.reshape(y_pred, shape=(-1, 224 * 224 * 3))
    # normalizing stage - setting a 0 mean.
    y_true -= y_true.mean(axis=-1)
    y_pred -= y_pred.mean(axis=-1)
    # normalizing stage - setting a 1 variance
    y_true = K.l2_normalize(y_true, axis=-1)
    y_pred = K.l2_normalize(y_pred, axis=-1)
    # final result
    pearson_correlation = K.sum(y_true * y_pred, axis=-1)
    return pearson_correlation

关于python-3.x - PIL 逊线性系数keras,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48171188/

相关文章:

maven - 收集 Maven 构建指标,例如花费的时间,构建是否成功

java - 有没有办法在 SonarQube 中测量 Java 代码的各种指标

python - 使用Apply、Python 3.6从函数返回值更新数据框的多个字段

Python 对短整型而不是整型执行字节操作

keras - 如何在Keras中将ModelCheckpoint与自定义指标一起使用?

python - 在 Keras Lambda 层中调整输入图像的大小

python - Keras - 检查目标时出错

python - 评估 k-means 算法找到的邻居

python-3.x - 使用 python pandas 读取 LabVIEW TDMS 文件

python - 从 numpy 数组中选择最小索引的第一次出现