python - sklearn 上的逻辑回归函数

标签 python numpy machine-learning scipy scikit-learn

我正在从 sklearn 学习逻辑回归并遇到这个:http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html#sklearn.linear_model.LogisticRegression

我创建了一个实现,可以显示训练和测试的准确度分数。然而,目前还不清楚这是如何实现的。我的问题是:什么是最大似然估计?这是如何计算的?什么是误差度量?使用的优化算法是什么?

我在理论上知道以上所有内容,但是我不确定 scikit.learn 在何时何地以及如何计算它,或者我是否需要在某个时候实现它。我的准确率为 83%,这是我的目标,但我对 scikit learn 如何实现这一目标感到非常困惑。

谁能给我指出正确的方向?

最佳答案

我最近开始自己研究 LR,我仍然没有得到很多推导步骤,但我想我知道正在使用哪些公式。

首先,我们假设您使用的是最新版本的 scikit-learn,并且正在使用的求解器是 solver='lbfgs'(我相信这是默认设置)。

代码在这里:https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/linear_model/logistic.py

What is the Maximum likelihood estimate? How is this being calculated?

计算似然估计的函数是这个 https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/linear_model/logistic.py#L57

有趣的是:

# Logistic loss is the negative of the log of the logistic function.
out = -np.sum(sample_weight * log_logistic(yz)) + .5 * alpha * np.dot(w, w)

这是formula 7 of this tutorial .该函数还计算可能性的梯度,然后将其传递给最小化函数(见下文)。一件重要的事情是拦截是教程中公式的 w0。但只有 fit_intercept 是 True 才有效。

What is the error measure?

抱歉,我不确定。

What is the optimisation algorithm used?

查看代码中的以下几行:https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/linear_model/logistic.py#L389

就是这个函数http://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.fmin_l_bfgs_b.html

一件非常重要的事情是类是+1或-1! (对于二进制情况,文献中0和1很常见,但行不通)

另请注意 numpy broadcasting所有公式都使用规则。 (这就是你看不到迭代的原因)

This was my attempt at understanding the code.我慢慢地发疯,直到撕掉 appart scikit-learn 代码(仅适用于二进制情况)。这也用作 inspiration too

希望对您有所帮助。

关于python - sklearn 上的逻辑回归函数,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/24935415/

相关文章:

optimization - 提高以下计算 softmax 导数的代码性能的技巧

python - django-pipeline js 压缩器不工作

python - python中两个不同数据帧的散点图数据

python - 如何将 numpy.int32 转换为 decimal.Decimal

python - Tensorflow Estimator : loss not decreasing when using tf. feature_column.embedding_column 用于分类变量列表

machine-learning - Sense2vec : os error Could not open binary file b

python - 将字符串转换为 Quantlib Date()

python - 获取访问 token 后如何读取用户的谷歌日历事件?

python - 如何将 (3,) numpy 向量转换为 (2,2,3) 矩阵?

python - numpy.histogram2d 在传递子集 pandas.DataFrame 时引发异常