Scikit-learn Predict_proba 给出错误答案

2024-01-09

这是来自的后续问题如何知道 Scikit-learn 中的 Predict_proba 返回数组中表示哪些类 https://stackoverflow.com/questions/16937243/how-to-know-what-classes-are-represented-in-return-array-from-predict-proba-in-s

在那个问题中,我引用了以下代码:

>>> import sklearn
>>> sklearn.__version__
'0.13.1'
>>> from sklearn import svm
>>> model = svm.SVC(probability=True)
>>> X = [[1,2,3], [2,3,4]] # feature vectors
>>> Y = ['apple', 'orange'] # classes
>>> model.fit(X, Y)
>>> model.predict_proba([1,2,3])
array([[ 0.39097541,  0.60902459]])

我在那个问题中发现这个结果代表了点属于每个类的概率,按照 model.classes_ 给出的顺序

>>> zip(model.classes_, model.predict_proba([1,2,3])[0])
[('apple', 0.39097541289393828), ('orange', 0.60902458710606167)]

所以......这个答案,如果解释正确的话,表示该点可能是一个“橙色”(由于数据量很小,置信度相当低)。但直观上,这个结果显然是不正确的,因为给出的点与“apple”的训练数据相同。为了确定起见,我也测试了相反的情况:

>>> zip(model.classes_, model.predict_proba([2,3,4])[0])
[('apple', 0.60705475211840931), ('orange', 0.39294524788159074)]

同样,显然是错误的,但方向相反。

最后,我尝试了距离更远的点。

>>> X = [[1,1,1], [20,20,20]] # feature vectors
>>> model.fit(X, Y)
>>> zip(model.classes_, model.predict_proba([1,1,1])[0])
[('apple', 0.33333332048410247), ('orange', 0.66666667951589786)]

同样,该模型预测了错误的概率。但是, model.predict 函数是正确的!

>>> model.predict([1,1,1])[0]
'apple'

现在,我记得在文档中读过一些关于 Predict_proba 对于小数据集不准确的内容,尽管我似乎无法再次找到它。这是预期的行为,还是我做错了什么?如果这是预期的行为,那么为什么 Predict 和 Predict_proba 函数的输出不一致?重要的是,数据集需要有多大,我才能信任 Predict_proba 的结果?

- - - - 更新 - - - -

好的,所以我对此做了一些更多的“实验”:predict_proba 的行为严重依赖于“n”,但不是以任何可预测的方式!

>>> def train_test(n):
...     X = [[1,2,3], [2,3,4]] * n
...     Y = ['apple', 'orange'] * n
...     model.fit(X, Y)
...     print "n =", n, zip(model.classes_, model.predict_proba([1,2,3])[0])
... 
>>> train_test(1)
n = 1 [('apple', 0.39097541289393828), ('orange', 0.60902458710606167)]
>>> for n in range(1,10):
...     train_test(n)
... 
n = 1 [('apple', 0.39097541289393828), ('orange', 0.60902458710606167)]
n = 2 [('apple', 0.98437355278112448), ('orange', 0.015626447218875527)]
n = 3 [('apple', 0.90235408180319321), ('orange', 0.097645918196806694)]
n = 4 [('apple', 0.83333299908143665), ('orange', 0.16666700091856332)]
n = 5 [('apple', 0.85714254878984497), ('orange', 0.14285745121015511)]
n = 6 [('apple', 0.87499969631893626), ('orange', 0.1250003036810636)]
n = 7 [('apple', 0.88888844127886335), ('orange', 0.11111155872113669)]
n = 8 [('apple', 0.89999988018127364), ('orange', 0.10000011981872642)]
n = 9 [('apple', 0.90909082368682159), ('orange', 0.090909176313178491)]

我应该如何在我的代码中安全地使用这个函数?至少,是否有任何 n 值可以保证与 model.predict 的结果一致?


predict_probas正在使用 libsvm 的 Platt 缩放功能来校准概率,请参阅:

  • sklearn.svm.svc的函数predict_proba()内部如何工作? https://stackoverflow.com/questions/15111408/how-does-sklearn-svm-svcs-function-predict-proba-work-internally

因此,超平面预测和概率校准确实可能不一致,特别是当数据集中只有 2 个样本时。奇怪的是,在这种情况下,libsvm 为缩放概率所做的内部交叉验证并没有(明确)失败。也许这是一个错误。人们必须深入研究 libsvm 的 Platt 扩展代码才能了解发生了什么。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Scikit-learn Predict_proba 给出错误答案 的相关文章

  • Python Popen 与 psexec 挂起 - 不良结果

    我对 subprocess Popen 和我认为是管道的问题有疑问 我有以下代码块 从 cli 运行时 100 都不会出现问题 p subprocess Popen psexec serverName get cmd c ver echo
  • 如何在序列化器创建方法中获取 URL Id?

    我有以下网址 url r member P
  • python 模拟第三方模块

    我正在尝试测试一些处理推文的类 我使用 Sixohsix twitter 来处理 Twitter API 我有一个类充当 Twitter 类的外观 我的想法是模拟实际的 Sixohsix 类 通过随机生成新推文或从数据库检索它们来模拟推文的
  • Python逻辑运算符优先级[重复]

    这个问题在这里已经有答案了 哪个运算符优先4 gt 5 or 3 lt 4 and 9 gt 8 这会被评估为真还是假 我知道该声明3 gt 4 or 2 lt 3 and 9 gt 10 显然应该评估为 false 但我不太确定 pyth
  • 从 ffmpeg 获取实时输出以在进度条中使用(PyQt4,stdout)

    我已经查看了很多问题 但仍然无法完全弄清楚 我正在使用 PyQt 并且希望能够运行ffmpeg i file mp4 file avi并获取流式输出 以便我可以创建进度条 我看过这些问题 ffmpeg可以显示进度条吗 https stack
  • Django 模型在模板中不可迭代

    我试图迭代模型以获取列表中的第一个图像 但它给了我错误 即模型不可迭代 以下是我的模型和模板的代码 我只需要获取与单个产品相关的列表中的第一个图像 模型 py class Product models Model title models
  • if 语句未命中中的 continue 断点

    在下面的代码中 两者a and b是生成器函数的输出 并且可以评估为None或者有一个值 def testBehaviour self a None b 5 while True if not a or not b continue pri
  • Argparse nargs="+" 正在吃位置参数

    这是我的解析器配置的一小部分 parser add argument infile help The file to be imported type argparse FileType r default sys stdin parser
  • 从零开始的 numpy 形状意味着什么

    好的 我发现数组的形状中可以包含 0 对于将 0 作为唯一维度的情况 这对我来说是有意义的 它是一个空数组 np zeros 0 但如果你有这样的情况 np zeros 0 100 让我很困惑 为什么这么定义呢 据我所知 这只是表达空数组的
  • 为什么在 Python 2.4 中使用 Unicode 数据会出现 ASCII 编码错误,而在 2.7 中却不会?

    我有一个程序 当在 Python 2 7 中运行时 会生成正确的 Unicode 输出到标准输出 当在 Python 2 4 中运行时 我得到UnicodeEncodeError ascii codec can t encode chara
  • 如何通过在 Python 3.x 上按键来启动和中断循环

    我有这段代码 当按下 P 键时会中断循环 但除非我按下非 P 键 否则循环不会工作 def main openGame while True purchase imageGrab if a sum gt 1200 fleaButton ti
  • 对图像块进行多重处理

    我有一个函数必须循环遍历图像的各个像素并计算一些几何形状 此函数需要很长时间才能运行 在 24 兆像素图像上大约需要 5 小时 但似乎应该很容易在多个内核上并行运行 然而 我一生都找不到一个有据可查 解释充分的例子来使用 Multiproc
  • 将 matplotlib 颜色图集中在特定值上

    我正在使用 matplotlib 颜色图 seismic 绘制绘图 并且希望白色以 0 为中心 当我在不进行任何更改的情况下运行脚本时 白色从 0 下降到 10 我尝试设置 vmin 50 vmax 50 但在这种情况下我完全失去了白色 关
  • 如何在 python 中没有 csv.reader 迭代器的情况下解析单行 csv 字符串?

    我有一个 CSV 文件 需要重新排列和重新编码 我想跑 line line decode windows 1250 encode utf 8 在由 CSV 读取器解析和分割之前的每一行 或者我想自己迭代行 运行重新编码 并仅使用单行解析表单
  • 使用 Firefox 绕过弹出窗口下载文件:Selenium Python

    我正在使用 selenium 和 python 来从中下载某些文件web page http www oceanenergyireland com testfacility corkharbour observations 我之前一直使用设
  • 限制 django 应用程序模型中的单个记录?

    我想使用模型来保存 django 应用程序的系统设置 因此 我想限制该模型 使其只能有一条记录 极限怎么办 尝试这个 class MyModel models Model onefield models CharField The fiel
  • 如何读取Python字节码?

    我很难理解 Python 的字节码及其dis module import dis def func x 1 dis dis func 上述代码在解释器中输入时会产生以下输出 0 LOAD CONST 1 1 3 STORE FAST 0 x
  • Elastic Beanstalk 中的 enum34 问题

    我正在尝试在 Elastic Beanstalk 中设置 django 环境 当我尝试通过requirements txt 文件安装时 我遇到了python3 6 问题 File opt python run venv bin pip li
  • 从 Twitter API 2.0 获取 user.fields 时出现问题

    我想从 Twitter API 2 0 端点加载推文 并尝试获取标准字段 作者 文本 和一些扩展字段 尤其是 用户 字段 端点和参数的定义工作没有错误 在生成的 json 中 我只找到标准字段 但没有找到所需的 user fields 用户
  • 列表值的意外更改

    这是我的课 class variable object def init self name name alias parents values table name of the variable self name 这是有问题的函数 f

随机推荐

  • PHP 在循环内使用查询的替代方案

    有人告诉我 在循环中使用查询 选择 是一种不好的做法 因为它会降低服务器性能 我有一个数组 例如 Array 1 gt Los Angeles Array 2 gt New York Array 3 gt Chicago 这些只是3个索引
  • 如何在 MySQL 中使用游标循环遍历表?

    我的数据库中有下表 我编写了以下存储过程来循环该表 当我调用这个存储过程时 我只得到一条记录 我可能犯了什么错误 如何解决 Field Type Null Key Default Extra date date NO NULL inQty
  • 通过Java像查询(JSON)一样执行Mongo

    我想知道是否有一种方法可以直接通过Java执行类似mongo的查询 即我们将类似mongoDB的查询作为字符串提供给mongoDB的Java驱动程序中的函数作为字符串对象 并返回一个DBCursor对象 就像是 import com mon
  • Angular2 xlink:href 问题

    我有一个 ngFor在里面我正在写一些SVG
  • 如何撤消 git 中的最后一次提交[重复]

    这个问题在这里已经有答案了 错误地 我做到了git add and git commit in the develop分支 但幸运的是 我没有这样做git push 所以我想把它恢复到原来的状态 I tried git reset soft
  • 为什么使用无符号字符写入二进制文件?为什么不应该使用流运算符写入二进制文件?

    我的第一个问题是 为什么习惯上使用无符号字符以二进制模式写入文件 在我见过的所有示例中 在写入二进制文件之前 任何其他数值都会被转换为 unsigned char 我的第二个问题是 使用流运算符写入二进制文件有什么不好 我听说 read 和
  • 尝试使用 npx create-react-app 时出现超时错误

    当我尝试运行此程序时 出现此错误 npm ERR Response timeout while trying to fetch https registry npmjs org typescript eslint 2fparser over
  • 如何使用开发数据填充生产数据库(heroku)? (导轨)

    heroku run rake db migrate可以很好地改变生产数据库的结构 Migrating to CreateUsers 20120318090252 Migrating to AddIndexToUsersEmail 2012
  • 内容长度和其他 HTTP 标头?

    如果我在生成普通 HTML 页面时设置此标头 会给我带来什么好处吗 我看到一些框架会设置这个标头属性 我想知道为什么 与其他标头一起 例如Content Type text html 浏览器加载网站是否更快或更流畅 PS 他们这样做是这样的
  • 如何解压蟒蛇蛋?

    我试图在使用 py2exe 时捆绑一些 Egg 依赖项 如 py2exe 网站上所述 它不适用于这些依赖项 我需要先解压缩它们 我尝试过先运行easy install m lxml进而easy install always unzip lx
  • pandas 将日期时间列转换为时间戳

    我是熊猫初学者 我的数据框第一列是日期时间 例如 2016 年 9 月 19 日 10 30 00 并且许多记录都喜欢它 我正在尝试将此列转换为时间戳并将其写入另一个数据帧 我正在尝试一步完成 我正在尝试用 python 3 编写 impo
  • 设置 Apache CouchDB 屏幕在容器重新启动时重新出现

    我使用官方 Docker 镜像运行 CouchDB v2 3 我已使用 Fauxton 将数据库配置为单节点 data 目录挂载到本地目录 当我重新启动容器时 数据库仍然存在 所以卷绑定按预期工作 现在 每次我重新启动容器并导航到 设置 选
  • Spark Streaming + Kafka:SparkException:无法找到 Set 的领导者偏移量

    我正在尝试设置 Spark Streaming 以从 Kafka 队列获取消息 我收到以下错误 py4j protocol Py4JJavaError An error occurred while calling o30 createDi
  • 库中存储库的 NoSuchBeanDefinitionException

    我创建了一个用于在多个 Spring Boot 应用程序上共享代码的库 该库包含一个 Repository 类RequestRepository 将库添加到 Spring Boot 项目后 编译并成功运行单元测试 Library Reque
  • YSlow 规则“不要在 HTML 中缩放图像”背后的基本原理是什么

    我在以下地方遇到过这个规则YSlow http developer yahoo com performance rules html no scale为了提高性能 表示不应在 HTML 中调整图像大小 他们没有提到这条规则的任何具体原因 有
  • 为什么OpenGL(IOS)中有.pvr文件

    我正在 IOS 中使用 OpenGL 制作应用程序 使用 PVR 纹理来制作 3D 效果 我无法理解 pvr 文件 所以请朋友们了解一下 pvr 文件以及它在 OpenGL 中的重要性以及我该如何制作它 PVR 文件是各种纹理格式的容器 例
  • 从 condaenvironment.yaml 安装时的依赖项的 pip 依赖项

    我正在尝试为项目的用户创建一个 condaenvironment yml 文件 其中一种依赖项不是由 conda 分发的 而是通过 pip github 提供 我假设基于这个例子 https github com conda conda b
  • 使用 avro-tools 连接 Avro 文件

    我正在尝试将 avro 文件合并为一个大文件 问题是concat命令不接受通配符 hadoop jar avro tools jar concat input part output bigfile avro I get 线程 main 中
  • MYSQL INSERT 中的德语变音

    我的 mysql 插入语句有问题 我有一个将 utf 8 字符正确提交到插入文件的表单 我已经检查了 POST 变量 现在 当我查看数据库中的 INSERT 时 没有变音符号 而是问号 该错误必须位于插入语句之前 如果我从数据库输出 手动输
  • Scikit-learn Predict_proba 给出错误答案

    这是来自的后续问题如何知道 Scikit learn 中的 Predict proba 返回数组中表示哪些类 https stackoverflow com questions 16937243 how to know what class