具有 n 倍交叉验证的精确召回曲线显示标准偏差

2024-02-06

我想生成一条具有 5 倍交叉验证的精确召回曲线,显示标准偏差,如ROC 曲线代码示例在这里 https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc_crossval.html.

下面的代码(改编自如何在 Scikit-Learn 中绘制超过 10 倍交叉验证的 PR 曲线 https://stackoverflow.com/questions/29656550/how-to-plot-pr-curve-over-10-folds-of-cross-validation-in-scikit-learn) 给出了每次交叉验证的 PR 曲线以及平均 PR 曲线。我还想以灰色显示平均 PR 曲线上方和下方一个标准差的区域。但它给出了以下错误(详细信息在代码下面的链接中):

ValueError: operands could not be broadcast together with shapes (91,) (78,)

import matplotlib.pyplot as plt
import numpy
from sklearn.datasets import make_blobs
from sklearn.metrics import precision_recall_curve, auc
from sklearn.model_selection import KFold
from sklearn.svm import SVC


X, y = make_blobs(n_samples=500, n_features=2, centers=2, cluster_std=10.0,
    random_state=10)

k_fold = KFold(n_splits=5, shuffle=True, random_state=10)
predictor = SVC(kernel='linear', C=1.0, probability=True, random_state=10)

y_real = []
y_proba = []

precisions, recalls = [], []

for i, (train_index, test_index) in enumerate(k_fold.split(X)):
    Xtrain, Xtest = X[train_index], X[test_index]
    ytrain, ytest = y[train_index], y[test_index]
    predictor.fit(Xtrain, ytrain)
    pred_proba = predictor.predict_proba(Xtest)
    precision, recall, _ = precision_recall_curve(ytest, pred_proba[:,1])
    lab = 'Fold %d AUC=%.4f' % (i+1, auc(recall, precision))
    plt.plot(recall, precision, alpha=0.3, label=lab)
    y_real.append(ytest)
    y_proba.append(pred_proba[:,1])
    precisions.append(precision)
    recalls.append(recall)

y_real = numpy.concatenate(y_real)
y_proba = numpy.concatenate(y_proba)
precision, recall, _ = precision_recall_curve(y_real, y_proba)
lab = 'Overall AUC=%.4f' % (auc(recall, precision))

plt.plot(recall, precision, lw=2,color='red', label=lab)

std_precision = np.std(precisions, axis=0)
tprs_upper = np.minimum(precisions[median] + std_precision, 1)
tprs_lower = np.maximum(precisions[median] - std_precision, 0)
plt.fill_between(recall_overall, upper_precision, lower_precision, alpha=0.5, linewidth=0, color='grey')

报告错误并生成绘图 https://i.stack.imgur.com/5ZPKh.png

您能否建议我如何添加以下代码以显示平均 PR 曲线周围的一个标准差?


我已经有了一个可行的解决方案,但如果有人能评论它是否在做正确的事情,那将会很有帮助。

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.metrics import precision_recall_curve, auc
from sklearn.model_selection import KFold
from sklearn.svm import SVC
from numpy import interp

X, y = make_blobs(n_samples=500, n_features=2, centers=2, cluster_std=10.0,
    random_state=10)

k_fold = KFold(n_splits=5, shuffle=True, random_state=10)
predictor = SVC(kernel='linear', C=1.0, probability=True, random_state=10)

y_real = []
y_proba = []

precision_array = []
threshold_array=[]
recall_array = np.linspace(0, 1, 100)

for i, (train_index, test_index) in enumerate(k_fold.split(X)):
    Xtrain, Xtest = X[train_index], X[test_index]
    ytrain, ytest = y[train_index], y[test_index]
    predictor.fit(Xtrain, ytrain)
    pred_proba = predictor.predict_proba(Xtest)
    precision_fold, recall_fold, thresh = precision_recall_curve(ytest, pred_proba[:,1])
    precision_fold, recall_fold, thresh = precision_fold[::-1], recall_fold[::-1], thresh[::-1]  # reverse order of results
    thresh = np.insert(thresh, 0, 1.0)
    precision_array = interp(recall_array, recall_fold, precision_fold)
    threshold_array = interp(recall_array, recall_fold, thresh)
    pr_auc = auc(recall_array, precision_array)

    lab_fold = 'Fold %d AUC=%.4f' % (i+1, pr_auc)
    plt.plot(recall_fold, precision_fold, alpha=0.3, label=lab_fold)
    y_real.append(ytest)
    y_proba.append(pred_proba[:,1])

y_real = numpy.concatenate(y_real)
y_proba = numpy.concatenate(y_proba)
precision, recall, _ = precision_recall_curve(y_real, y_proba)
lab = 'Overall AUC=%.4f' % (auc(recall, precision))

plt.plot(recall, precision, lw=2,color='red', label=lab)

plt.legend(loc='lower left', fontsize='small')

mean_precision = np.mean(precision_array)
std_precision = np.std(precision_array)
plt.fill_between(recall, precision + std_precision, precision - std_precision, alpha=0.3, linewidth=0, color='grey')
plt.show()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

具有 n 倍交叉验证的精确召回曲线显示标准偏差 的相关文章

  • 在 Python 中,部分函数应用(柯里化)与显式函数定义

    在 Python 中 以下方式是否被认为是更好的风格 根据更一般的 可能是内部使用的功能显式定义有用的功能 或者 使用偏函数应用来显式描述函数柯里化 我将通过一个人为的例子来解释我的问题 假设编写一个函数 sort by scoring 它
  • Python OverflowError:数学范围错误[重复]

    这个问题在这里已经有答案了 当我尝试这个计算时 出现溢出错误 output math exp 1391 12694245 100 我知道发生这种情况是因为使用的数字 超出了双精度数的范围 但有什么方法可以解决这个问题并获得输出值 有人可以帮
  • 如何使用playsound模块停止音频?

    如何在Python代码中通过playaudio模块停止音频播放 我播放过音乐 但我无法停止音乐 我怎样才能阻止它 playsound playsound name of file 您可以使用多处理模块将声音作为后台进程播放 然后随时终止它
  • 在 Django 中使用 prefetch_lated 连接 ManyToMany 字段

    我可能遗漏了一些明显的东西 但我在连接 ManyToMany 字段以在 Django 应用程序中工作时遇到问题 我有两个模型 class Area models Model name CharField class Role models
  • 在 Python 中同时插入行

    我正在尝试对我的代码进行矢量化 但遇到了障碍 我有 nxd x 值数组 x1 xn 其中每一行 x1 有很多点 x11 x1d nxd y 值数组 y1 y2 y3 其中每一行 y1 有很多点 y11 y1d x 值的 nx1 数组 x 1
  • python中嵌套字典值的总和

    我有一本这样的字典 data 11L a 2 b 1 a 2 b 3 22L a 3 b 2 a 2 b 5 a 4 b 2 a 1 b 5 a 1 b 0 33L a 1 b 2 a 3 b 5 a 5 b 2 a 1 b 3 a 1 b
  • 使用pip安装pylibmc时出错

    您好 当我尝试使用 pip 在 OSX Lion 上安装 pylibmc 时 出现以下错误 pylibmcmodule h 42 10 fatal error libmemcached memcached h file not found
  • 按字符串子字符串的列过滤 Pandas 数据框

    我正在尝试使用列中的字符串值是数据框外部字符串的子字符串的条件来过滤数据框 下面的例子 df a b c hello bye hello reference str hello there output a c 一种方法可能是使用正则表达式
  • 完全定制的Python帮助用法

    我正在尝试使用 Python 创建完全自定义的 帮助 用法 我计划将其导入到许多我想要具有风格一致性的程序中 但遇到了一些麻烦 我不知道为什么我的描述忽略换行符 尝试过 和 我无法让 出现在 ARGS 行的 换行符之后 显然它们坐在自己的行
  • Django Rest Framework 序列化器中的聚合(和其他带注释的)字段

    我正在尝试找出添加带注释字段的最佳方法 例如将任何聚合 计算 字段添加到 DRF 模型 序列化器 我的用例只是一种情况 端点返回的字段未存储在数据库中 而是从数据库计算得出 让我们看下面的例子 模型 py class IceCreamCom
  • t /= d 是什么意思? Python 和错误

    t current time b begInnIng value c change In value d duration def easeOutQuad swing function x t b c d alert jQuery easi
  • 模拟类:Mock() 还是 patch()?

    我在用mock http www voidspace org uk python mock index html使用Python 想知道这两种方法中哪一种更好 阅读 更Pythonic 方法一 只需创建一个模拟对象并使用它 代码如下 def
  • 如何使用 selenium 获取 javascript 结果?

    我有以下代码 from selenium import selenium selenium selenium localhost 4444 chrome http some site com selenium start sel selen
  • 枚举上的 random.choice

    我想用random choice on an Enum I tried class Foo Enum a 0 b 1 c 2 bar random choice Foo 但是这段代码失败了KeyError 我怎样才能随机选择一个成员Enum
  • 对 Python 列表元素进行分组

    我有一个 python 列表 如下所示 my list 25 1 0 65 25 3 0 63 25 2 0 62 50 3 0 65 50 2 0 63 50 1 0 62 我想根据以下规则对它们进行排序 1 gt 0 65 0 62 l
  • 使用神经网络包进行多项分类

    这个问题应该很简单 但文档没有帮助 我正在使用 R 我必须使用neuralnet多项式分类问题的包 所有示例均针对二项式或线性输出 我可以使用二项式输出进行一些一对一的实现 但我相信我应该能够通过使用 3 个单元作为输出层来做到这一点 其中
  • python 相当于 sed

    有没有一种方法 无需双循环即可完成以下 sed 命令的操作 Input Time Banana spinach turkey sed i Banana s Toothpaste file Output Time BananaToothpas
  • float() 参数必须是字符串或数字,而不是“时间戳”

    我无法使 scilearn 与日期时间系列一起工作 找到了这篇文章 但对我没有帮助 Pandas 类型错误 float 参数必须是字符串或数字 https stackoverflow com questions 41256626 panda
  • Python Web 编程的不同方法的优缺点

    我想使用 Python 编写一些服务器端脚本 但我对这样做的方法有点迷失了 它从 DIY CGI 方法开始 似乎以一些相当强大的框架结束 这些框架基本上可以自己完成所有工作 中间有很多东西 比如web py http webpy org P
  • 真实值与预测值的降维可视化

    我有一个数据框 如下所示 label predicted F1 F2 F3 F40 major minor 2 1 4 major major 1 0 10 minor patch 4 3 23 major patch 2 1 11 min

随机推荐

  • Java 8用于计算小数年龄的日期时间[重复]

    这个问题在这里已经有答案了 我是 Java 8 日期时间 API 的新手 想知道如何计算以小数表示的年龄 它返回双精度值 例如 30 5 这意味着 30 年零 6 个月 例如 下面的示例代码得到的输出为 30 0 但不是 30 5 这可能是
  • 散列和索引有什么区别?

    我研究了 DBMS 中的哈希 可扩展 线性 和 DBMS 中的索引 稀疏 密集 基于辅助键的索引等 但我无法理解哈希和索引之间的区别 这两种技术是一起使用还是单独使用 我很困惑 因为这两种技术的目的似乎都是为了让我们能够快速检索数据 所以我
  • EL 语法错误是 en

    JSP页面中的以下语句在第一个附近遇到错误equals出现这种情况 请问是什么原因 如何解决 请尽快纠正 fn length updateStock todayDimensionStones i count DimensionStones
  • 以编程方式获取当前页面

    在 JSF 支持 Bean 托管 Bean 焊接 Bean 无关紧要 中 我可以通过调用获取客户端所在的上下文路径 FacesContext ctx FacesContext getCurrentInstance String path c
  • C# 线程问题和最佳实践

    这是我第一次在 C 应用程序中使用线程 基本上它是一个应用程序 用于检查列表中的一堆网站是死是活 这是我第一次尝试使用多线程 public void StartThread string URL int no Thread newThrea
  • application(_:didFinishLaunchingWithOptions:)' 几乎符合可选要求

    安装 Xcode 8 beta 6 后 我收到一条警告 实例方法 application didFinishLaunchingWithOptions 几乎匹配协议 UIApplicationDelegate 的可选要求 applicatio
  • 如何在 Blackberry BrowserField 中缓存

    我正在创建一个 Blackberry 应用程序来显示某个站点的全屏 Web 视图 我有一个可以正常显示的工作浏览器字段 但从页面到页面的导航速度比本机浏览器慢 浏览器字段似乎没有内置缓存 导致加载时间很慢 当我添加以下代码来管理缓存时 该站
  • FlipSide 上带有导航控制器和表格视图的实用应用程序

    我对整个 MVC 看待事物的方式还比较陌生 我有一个基于 实用程序 应用程序模板的应用程序 MainView 和 FlipsideView 中的所有内容都运行良好 但现在我需要将 TableView 和导航控制器添加到 Flipside 主
  • igraph错误无法创建具有负数顶点的空图

    当我尝试创建下面的简单图表时 为什么会出现错误 如果我用数字替换 a 和 b 那么它可以工作吗 任何解决方案 g1 lt graph c a b directed TRUE error is Error in graph c a b dir
  • 在 hibernate 聚合函数中使用函数作为参数

    我想在 HQL 中执行以下查询 select count distinct year foo date from Foo foo 但是 这会导致以下异常 org hibernate hql ast QuerySyntaxException
  • 如何在 Dynamics 365 On-Premise 中使用 EasyRepro 自动登录

    我正在尝试使用 Dynamics 365 On Premise 中的 EasyRepro 进行自动化 UI 测试 我成功完成了测试 但遇到了一个问题 我无法自动登录到我的 Dynamics 365 Organization 下面是我使用的代
  • 如何在VBA中进行后期绑定?

    我有一个通过 VBA 创建电子邮件的函数 我通过 Excel 2016 做到了这一点 当我的一些同事尝试使用它时 出现了缺少引用的错误 Outlook Library 16 0 我在互联网上寻找解决方案 发现最好的是后期绑定 我已经阅读过它
  • JButton 边距。当雨云普拉夫时不受尊重

    该物业margin of a JButton安装 Nimbus 外观后不会受到尊重 我需要一些 小 按钮 但 nimbus 强制按钮文本周围的空间变大 所以我只得到 非常大 的按钮 我发现在Nimbus 默认页面 http docs ora
  • 将 float 转换为 UInt32 - 哪个表达式更精确

    我有一个号码float x它应该在 范围内 但它经过多次数值运算 结果可能稍微超出 范围 我需要将这个结果转换为uint y使用整个范围的UInt32 当然 我需要夹住x在 范围内并对其进行缩放 但哪种操作顺序更好呢 y uint roun
  • 如何在netbeans中对ejs文件进行语法高亮显示

    我很长时间以来一直在 netbeans IDE 中工作 最近开始在 NodeJs 中编码 但 ejs 文件没有高亮代码 如何摆脱这个问题 您需要使用某些关联文件类型配置 ejs 文件扩展名 脚步 转到工具 gt 选项 单击 其他 选项卡 然
  • 如何在重新安装应用程序时删除数据

    感谢之前的回复 重新安装应用程序后是否可以从 sqlite 中删除存储的内容 我将数据存储在数据库中 一旦我再次重新安装相同的应用程序 以前的数据仍然存储在 sqlite 中 我想在重新安装应用程序时删除存储的内容 我对此不太确定 这看起来
  • 使用spaCy 3.0将数据从旧的Spacy v2格式转换为全新的Spacy v3格式

    我有变量trainData其具有以下简化格式 Paragraph A entities 15 26 DiseaseClass 443 449 DiseaseClass 483 496 DiseaseClass Paragraph B ent
  • 添加对 MEF 插件项目的引用时,为什么会出现警告图标?

    我希望通过直接引用插件项目并实例化插件类来测试插件的核心类 当我创建测试控制台应用程序项目并将项目引用添加到插件项目时 我在引用列表中的引用旁边看到一个警告图标 带有感叹号的黄色三角形 当我添加对 dll 插件的程序集构建输出 的引用时 我
  • AngularJS:拒绝设置不安全标头“Access-Control-Request-Headers”

    我正在尝试使用 AngularJS 调用本地运行的 REST API 这是 AngularJS 代码 http defaults headers common Access Control Request Headers accept or
  • 具有 n 倍交叉验证的精确召回曲线显示标准偏差

    我想生成一条具有 5 倍交叉验证的精确召回曲线 显示标准偏差 如ROC 曲线代码示例在这里 https scikit learn org stable auto examples model selection plot roc cross