python - 使用 tf.estimator 的自定义指标

标签 python tensorflow tensorflow-estimator

我希望 tensorflow 在评估我的估算器期间计算决定系数(R 平方)。我尝试基于官方指标的实现以以下方式松散地实现它:

def r_squared(labels, predictions, weights=None,
              metrics_collections=None,
              updates_collections=None,
              name=None):

    total_error = tf.reduce_sum(tf.square(labels - tf.reduce_mean(labels)))
    unexplained_error = tf.reduce_sum(tf.square(labels - predictions))
    r_sq = 1 - tf.div(unexplained_error, total_error)

    # update_rsq_op = ?

    if metrics_collections:
        ops.add_to_collections(metrics_collections, r_sq)

    # if updates_collections:
    #     ops.add_to_collections(updates_collections, update_rsq_op)

    return r_sq #, update_rsq_op

然后,我将此函数用作 EstimatorSpec 中的指标:

estim_specs = tf.estimator.EstimatorSpec(
    ...
    eval_metric_ops={
        'r_squared': r_squared(labels, predictions),
        ...
    })

但是,这失败了,因为我的 R 平方实现没有返回 update_op。

TypeError: Values of eval_metric_ops must be (metric_value, update_op) tuples, given: Tensor("sub_4:0", dtype=float64) for key: r_squared

现在我想知道,update_op 究竟应该做什么?我真的需要实现 update_op 还是可以创建某种虚拟 update_op?如果有必要,我将如何实现?

最佳答案

好的,所以我能够弄明白。我可以将我的指标包装在一个平均指标中并使用它的 update_op。这似乎对我有用。

def r_squared(labels, predictions, weights=None,
              metrics_collections=None,
              updates_collections=None,
              name=None):

    total_error = tf.reduce_sum(tf.square(labels - tf.reduce_mean(labels)))
    unexplained_error = tf.reduce_sum(tf.square(labels - predictions))
    r_sq = 1 - tf.div(unexplained_error, total_error)

    m_r_sq, update_rsq_op = tf.metrics.mean(r_sq)

    if metrics_collections:
        ops.add_to_collections(metrics_collections, m_r_sq)

    if updates_collections:
        ops.add_to_collections(updates_collections, update_rsq_op)

    return m_r_sq, update_rsq_op

关于python - 使用 tf.estimator 的自定义指标,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47753736/

相关文章:

python - python的递归itemgetter

python - 如何将变换应用于单个神经元?

python - 对keras中的部分张量应用不同的损失函数

tensorflow - 结合使用 Estimators API 和 tf.data.Dataset 时如何加快批处理准备

python-dev安装错误: ImportError: No module named apt_pkg

python - Python 中的两行平均值,同时忽略 NaN

用于正则表达式的 Python 逻辑 NOT 运算符

python - 在 Tensorflow 中检测损坏的图像

python - 如何正确组合 tf.data.Dataset 和 tf.estimator.DNNRegressor

tensorflow - 如何在不保存检查点的情况下运行 estimator.train