python - 为什么决定系数 R² 的实现会产生不同的结果?

标签 python numpy statistics coefficient-of-determination

当尝试实现一个 python 函数来计算决定系数 R² 时,我注意到根据我使用的计算顺序,我得到了截然不同的结果。

wikipedia page on R²对于如何计算 R² 给出了看似非常清楚的解释。我对 wiki 页面上所说内容的 numpy 解释如下:

def calcR2_wikipedia(y, yhat):
    # Mean value of the observed data y.
    y_mean = np.mean(y)
    # Total sum of squares.
    SS_tot = np.sum((y - y_mean)**2)
    # Residual sum of squares.
    SS_res = np.sum((y - yhat)**2)
    # Coefficient of determination.
    R2 = 1.0 - (SS_res / SS_tot)
    return R2

当我使用目标向量 y 和建模估计向量 yhat 尝试此方法时,此函数生成的 R² 值为 -0.00301。

但是,this stackoverflow post discussing how to calculate R² 接受的答案,给出以下定义:

def calcR2_stackOverflow(y, yhat):
    SST = np.sum((y - np.mean(y))**2)
    SSReg = np.sum((yhat - np.mean(y))**2)
    R2 = SSReg/SST
    return R2

使用与之前相同的 yyhat 向量的方法,我现在得到的 R² 为 0.319。

此外,在同一篇 stackoverflow 帖子中,很多人似乎都赞成使用 scipy 模块计算 R²,如下所示:

import scipy
slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(yhat, y)
R2 = r_value**2

在我的例子中产生 0.261。

所以我的问题是:为什么看似广为接受的来源产生的 R² 值彼此截然不同?计算两个向量之间的 R² 的正确方法是什么?

最佳答案

定义

这是一种符号滥用,常常会导致误解。您正在比较两个不同的系数:

如果你仔细阅读维基百科页面上关于决定系数的介绍,你会发现那里有讨论,它的开头如下:

There are several definitions of R2 that are only sometimes equivalent.

MCVE

您可以确认这些分数的经典实现会返回预期结果:

import numpy as np
import scipy
from sklearn import metrics

np.random.seed(12345)
x = np.linspace(-3, 3, 1001)
yh = np.polynomial.polynomial.polyval(x, [1, 2])
e = np.random.randn(x.size)
yn = yh + e

然后你的函数calcR2_wikipedia (0.9265536406736125)返回决定系数,可以确认它返回与sklearn.metrics.r2_score相同的结果。 :

metrics.r2_score(yn, yh) # 0.9265536406736125

另一方面,scipy.stats.linregress返回相关系数(仅对线性回归有效):

slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(yh, yn)
r_value # 0.9625821384210018

你可以通过它的定义来交叉确认:

C = np.cov(yh, yn)
C[1,0]/np.sqrt(C[0,0]*C[1,1]) # 0.9625821384210017

关于python - 为什么决定系数 R² 的实现会产生不同的结果?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64192772/

相关文章:

python - 在 Python 中将任何可迭代对象转换为数组

python - numpy python 3.4.1 安装 : Python 3. 4 在注册表中找不到

python - 将峰彼此等距分开

python - Cython编译错误-赋值前引用的局部变量

mysql - MySQL中是否有统计函数可以找到列中最流行的值?

python - 如何使用生成器一次性计算百分位数和排名?

python - 如何删除列表中每个列表中的最后一个元素

python - 如何在可执行文件上运行脚本?

c# - 加倍游戏模拟

python - 保存抓取的项目和文件时,Scrapy 在输出 csv 文件中插入空行