python - 绘制阈值(precision_recall 曲线) matplotlib/sklearn.metrics

标签 python matplotlib scikit-learn precision-recall

我正在尝试为我的精度/召回曲线绘制阈值。我只是使用 MNSIT 数据,并以《使用 scikit-learn、keras 和 TensorFlow 进行机器学习动手》一书中的示例为例。尝试训练模型检测 5 的图像。不知道你需要看多少代码。我已经为训练集制作了混淆矩阵,并计算了精度和召回值以及阈值。我已经绘制了 pre/rec 曲线,书中的示例说要添加轴标签、壁架、网格并突出显示阈值,但代码在我在下面放置星号的书中被切断了。除了如何让阈值显示在图上之外,我能够弄清楚所有这些。我已经附上了一张图,说明了书中的图表与我所拥有的相比。
这就是书中所展示的:
enter image description here
对比我的图表:
enter image description here
我无法显示带有两个阈值点的红色点线。有谁知道我将如何做到这一点?这是我的代码如下:

from sklearn.metrics import precision_recall_curve

precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)

def plot_precision_recall_vs_thresholds(precisions, recalls, thresholds):
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
    plt.plot(thresholds, recalls[:-1], "g--", label="Recall")
    plt.xlabel("Threshold")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    plt.grid(b=True, which="both", axis="both", color='gray', linestyle='-', linewidth=1)

plot_precision_recall_vs_thresholds(precisions, recalls, thresholds)
plt.show()
我知道这里有很多关于 sklearn 的问题,但似乎没有一个问题包括让那条红线出现。我将不胜感激!

最佳答案

您可以使用以下代码绘制水平线和垂直线:

plt.axhline(y_value, c='r', ls=':')
plt.axvline(x_value, c='r', ls=':')

关于python - 绘制阈值(precision_recall 曲线) matplotlib/sklearn.metrics,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65975815/

相关文章:

python - 逻辑回归系数在Python中编写函数

python - 列表串联,将以 ':' 结尾的元素添加到列表中,将其他元素添加到子列表中

python - Matplotlib 用户界面示例因超过三个补丁而中断

Python:Rabin-Karp算法散列

显示均值和置信区间的 Python 箱线图

python - imshow() 子图生成不需要的空白

python - sklearn 和 n_jobs 中的超参数优化 > 1 : Pickling

python - 我怎样才能用sklearn做一维高斯混合的直方图?

python - 如何从 [1,4,256] cuda.FloatTensor 形成 [1,1,256]?

python - 分组箱线图