机器学习之朴素贝叶斯: sklearn.naive_bayes

2023-11-17

1. 贝叶斯原理

贝叶斯分类是以贝叶斯定理为基础的一种分类算法。

已知某条件概率,如何得到事件交换后的概率;即在已知P(A|B)的情况下求得P(B|A)。条件概率P(A|B)表示事件B已经发生的前提下,事件A发生的概率。其基本求解公式为:P(A|B)=P(AB)/P(B)。贝叶斯定理:
在这里插入图片描述
贝叶斯的主要思想可以概括为:先验概率+数据=后验概率。贝叶斯定理换个表达形式:

在这里插入图片描述
对于给定训练集,首先基于特征条件独立性的假设,学习输入/输出联合概率(计算出先验概率条件概率 ,然后求出联合概率 。然后基于此模型,给定输入x,利用贝叶斯概率定理求出最大的后验概率作为输出y。

例如:一个事物具有很多属性(features),把它的众多属性看作一个向量X,即X=(x1,x2,x3,…,xn),称为属性集。事物的类别(labels)也有很多种,用集合C={c1,c2,…cm}表示。一般X和C的关系是不确定的,可以将X和C看作是随机变量,P(C|X)称为C的后验概率,与之相对的,P( C)称为C的先验概率

根据贝叶斯公式,后验概率P(C|X)=P(X|C)P( C)/P(X),计算后验概率时可以巧妙利用以下几点:

  • 在比较不同C值的后验概率时,分母P(X)总是常数,忽略掉,即是后验概率P(C|X)=P(X|C)P©,
  • 先验概率P©可以通过计算训练集中属于每一个类的训练样本所占的比例,
  • 难度在于**条件概率P(X|C)**的估计
    【针对朴素贝叶斯,因为朴素贝叶斯假设事物属性之间相互条件独立,P(X|C)=∏P(xi|ci)。】

2. 朴素贝叶斯

朴素贝叶斯分类是贝叶斯分类中最简单最常见的一种。它是基于贝叶斯定理特征条件独立假设分类方法。朴素贝叶斯的含义是:朴素——特征条件独立,贝叶斯——基于贝叶斯定理。

在这里插入图片描述

3. 朴素贝叶斯模型:

朴素贝叶斯分类器是一种有监督学习,常见有两种模型,多项式模型即为词频型和伯努利模型即文档型,还有一种高斯模型。

3.1 多项式模型MultinomialNB

当特征是离散的时候,使用多项式模型。多项式模型在计算先验概率和条件概率时,会做一些平滑处理
如果不做平滑,当某一维特征的值xi没在训练样本中出现过时,会导致条件概率为0,从而导致后验概率为0,加上平滑后可以克服这个问题。

3.2 高斯模型GaussianNB

当特征是连续变量的时候,运用多项式模型会导致很多条件概率为0。 以处理连续的特征变量,应该采用高斯模型。

3.3 伯努利模型BernoulliNB

伯努利模型适用于离散特征的情况,不同的是,伯努利模型中每个特征的取值只能是0和1。
伯努利模型中,条件概率的计算方式是:
当特征值xi=1时, P ( x i ∣ c i ) = P ( x i = 1 ∣ c i ) P(x_i|c_i)=P(x_i=1|c_i) P(xici)=P(xi=1ci)
当特征值xi=0时, P ( x i ∣ c i ) = 1 − P ( x i = 1 ∣ c i ) P(x_i|c_i)=1-P(x_i=1|c_i) P(xici)=1P(xi=1ci)

在这里插入图片描述

当特征值xi=0时,

4. sklearn 实现 朴素贝叶斯分类

在这里插入图片描述
基于Cnews 数据集,跑了一下,准确率是0.86,偏低,想着用K折交叉验证改进一下:

from sklearn.metrics import accuracy_score,f1_score,roc_auc_score,recall_score,precision_score
from sklearn.svm import LinearSVC
from sklearn.naive_bayes import MultinomialNB, GaussianNB,BernoulliNB
def train_model(X, X_test, y, folds,params=None, model_type='LSVC', plot_feature_importance=False):
    n_fold=5
    iteration=3000
    nrepeats = 2
    prediction = np.zeros((X_test.shape[0], n_fold*nrepeats))
    scores = []
    feature_importance = pd.DataFrame()
    fold_n=0
    #split method 需要写得通用些,需支持两种策略的split!!!
    for  train_index, valid_index in folds.split(X, y):
        fold_n+=1
        print('Fold', fold_n, 'started at', time.ctime())
        X_trn, X_val = X[trn_index], X[val_index]
        y_trn, y_val = y[trn_index], y[val_index]
        if model_type=='LSVC':
            model= LinearSVC()
            model.fit(X_train,y_train)
            y_valid_pred=model.predict(X_valid)
            y_pred=model.predict(X_test)
        if mode_type=='mnb':
            model=MultinomialNB()
            model.fit(X_trn,y_trn)
            y_val_pred=model.predict(X_val)  # 使用逻辑回归函数对测试集进行预测          
            y_pred=model.predict(X_test)
        if mode_type=='gnb':
            model=GaussianNB()
            model.fit(X_trn,y_trn)
            y_val_pred=model.predict(X_val)  # 使用逻辑回归函数对测试集进行预测          
            y_pred=model.predict(X_test)
        if mode_type=='bnb':
            model=BernoulliNB()
            model.fit(X_trn,y_trn)
            y_val_pred=model.predict(X_val)  # 使用逻辑回归函数对测试集进行预测          
            y_pred=model.predict(X_test)
  
        f1_scores.append(f1_score(np.array(y_val), y_val_pred,average='micro'))
        accuracy_scores.append(accuracy_score(np.array(y_val), y_val_pred,average='micro'))
        roc_auc_scores.append(roc_auc_score(np.array(y_val), y_val_pred,average='micro'))
        recall_scores.append(recall_score(np.array(y_val), y_val_pred,average='micro'))
        precision_scores.append(precision_score(np.array(y_val), y_val_pred,average='micro'))
    print('CV mean f1_scores: {0:.4f}, std: {1:.4f}.'.format(np.mean(f1_scores), np.std(f1_scores)))
    print('CV mean accuracy_scores: {0:.4f}, std: {1:.4f}.'.format(np.mean(accuracy_scores), np.std(accuracy_scores)))
    print('CV mean roc_auc_scores: {0:.4f}, std: {1:.4f}.'.format(np.mean(roc_auc_scores), np.std(roc_auc_scores)))
    print('CV mean recall_scores: {0:.4f}, std: {1:.4f}.'.format(np.mean(recall_scores), np.std(recall_scores)))
    print('CV mean precision_scores: {0:.4f}, std: {1:.4f}.'.format(np.mean(precision_scores), np.std(precision_scores)))
    
    prediction[:,fold_n]=y_pred
    if model_type == 'lgb':
            # feature importance
            fold_importance = pd.DataFrame()
            fold_importance["feature"] = X.columns
            fold_importance["importance"] = model.feature_importance()
            fold_importance["fold"] = fold_n + 1
            feature_importance = pd.concat([feature_importance, fold_importance], axis=0)
    
    #   """对K个模型的结果进行融合,融合策略:投票机制"""
    y_test_pred = []
    for i in range(len(prediction)):
        result_vote = np.argmax(np.bincount(prediction[i,:]))
        y_test_pred.append(result_vote)
    print('CV mean f1_scores: {0:.4f}, std: {1:.4f}.'.format(np.mean(f1_scores), np.std(f1_scores)))
    print('CV mean accuracy_scores: {0:.4f}, std: {1:.4f}.'.format(np.mean(accuracy_scores), np.std(accuracy_scores)))
    print('CV mean roc_auc_scores: {0:.4f}, std: {1:.4f}.'.format(np.mean(roc_auc_scores), np.std(roc_auc_scores)))
    print('CV mean recall_scores: {0:.4f}, std: {1:.4f}.'.format(np.mean(recall_scores), np.std(recall_scores)))
    print('CV mean precision_scores: {0:.4f}, std: {1:.4f}.'.format(np.mean(precision_scores), np.std(precision_scores))) 
    return y_test_pred
    

实验跑崩溃了,,,

参考链接:

https://blog.csdn.net/u013710265/article/details/72780520
https://blog.csdn.net/Kaiyuan_sjtu/article/details/80030005
https://blog.csdn.net/ivy_reny/article/details/79132162

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

机器学习之朴素贝叶斯: sklearn.naive_bayes 的相关文章

随机推荐

  • 简单有效,如何彻底卸载删除AlibabaProtect.exe

    简单有效 如何彻底卸载删除AlibabaProtect exe Process Hacker https www isharepc com 33781 html
  • Java常量池理解和经典总结

    Java常量池理解和经典总结 一 相关知识 1 什么是常量 第一种 是一个值 这个值本身 我们就叫做常量 整型常量 1024 实型常量 1 024 字符常量 g c w 字符串常量 gcw 逻辑常量 true false 这只是我们平时我们
  • JPEG数据格式分析

    添加链接描述 参考如让 感谢原创分享 JPEG数据分析 分析对象是一幅8x8的jpg图片 如下 图片已被放大并被虚线切分 这里写图片描述 用windows照片查看器查看图片详细信息 信息 参数 大小 667字节 尺寸 8x8 宽度 8像素
  • 【干货】Spring远程命令执行漏洞(CVE-2022-22965)原理分析和思考

    前言 上周网上爆出Spring框架存在RCE漏洞 野外流传了一小段时间后 Spring官方在3月31日正式发布了漏洞信息 漏洞编号为CVE 2022 22965 本文章对该漏洞进行了复现和分析 希望能够帮助到有相关有需要的人员进一步研究 1
  • 《热题100》字符串、双指针、贪心算法篇

    思路 对于输入的的字符串 只有三种可能 ipv4 ipv6 和neither ipv4 四位 十进制 无前导0 小于256 ipv6 八位 十六进制 无多余0 00情况不允许 不为空 class Solution def solve sel
  • 区块链扩容系列之Plasma MVP

    以太坊低TPS一直被诟病 最近V神提出一种将以太坊TPS提升到500的方案 一经发表就被BM调侃 可见以太坊低TPS目前确实严重阻碍了以太坊的发展 连V神都不得不经常发声 我们知道以太坊低TPS的一个关键原因是以太坊采用POW 因而将部分交
  • selenium爬虫检测之如何避免对isTrusted属性检测

    如何避免对isTrusted属性检测 检测原理 什么是isTrusted属性 在web api官方网站mozilla org有如下解释 Event接口的 isTrusted 属性是一个只读属性 它是一个布尔值 Boolean 当事件是由用户
  • java中访问数组元素的方法

    1 使用普通 for 循环 这是最常见的遍历数组的方法 使用传统的 for 循环语法 通过索引来访问数组中的每个元素 int arr 1 2 3 4 5 for int i 0 i lt arr length i int element a
  • 【线上死锁分析】由index_merge引发的死锁事件

    1 事情背景 背景由于更换新的短信供应商 同事之前可能对这块业务不太熟 原本是回执ID recordId 一个手机号一个 但是同事接的时候将这个批量发送接口只设置了一个recordId 导致了多个手机号共用了一个recordId 2 线上d
  • Linux系统发生故障时,所有文件会以只读方式挂载

    解决办法 执行mount o remount rw 让文件可以修改 原因 挂载磁盘时 没有写fstab文件 或者fstab文件里写的是磁盘名称而不是uuid
  • 解决问题:EXT4 filefield 文件上传在IE8上返回状态无效,弹出下载页面

    解决描述 EXT4 filefield 以form 文件上传 基于IE8浏览器 不管上传成功与否 返回状态无效 即success function fp o 方法无效 并弹出下载页面 原代码情况如下 1 EXT4前台视图层view view
  • civetweb框架学习和使用(一)

    背景 CivetWeb基于Mongoose项目 是一个易于使用 功能强大的C C 嵌入式Web服务器 在2013年8月16日 在编写和分发此项目所依据的原始代码后 Mongoosed的许可证已经更改了 因此 CivetWeb已从上一个MIT
  • Windows下在后台运行jar包

    为什么80 的码农都做不了架构师 gt gt gt 新建一个bat文件 输入 echo off start javaw jar xxx jar exit 执行这个批处理程序就可以在后台运行jar包了 转载于 https my oschina
  • FIddler之Fiddler移动端抓包

    前言 笔者今天的这篇文章呢 想使用通俗易懂的话语 让大家明白以下内容 什么是抓包哪些场景需要用到抓包Fiddler抓包的原理怎样使用Fiddler进行移动端抓包 一 抓包 包 Packet 是TCP IP协议通信传输中的数据单位 一般也称
  • Apache/Tomcat/JBOSS/Jetty/Nginx区别 与选择

    总结 Apache Tomcat JBOSS Nginx区别 1 Apache是Web服务器 Tomcat是应用 Java 服务器 Tomcat在中小型系统和并发访问用户不是很多的场合下被普遍使用 Apache支持静态页 Tomcat支持动
  • 千行代码bug率统计

    1 计算公式 千行代码bug率 bug数 代码行数 1000 2 bug率标准 CMMI级别中做出了相关的指标规定 千行代码缺陷率 bug率 CMM1级 11 95 CMM2级 5 52 CMM3级 2 39 CMM4级 0 92 CMM5
  • JWT(Json Web Token)的原理、渗透与防御

    关于JWT kid安全部分后期整理完毕再进行更新 2023 05 16 JWT的原理 渗透与防御 目录 JWT的原理 渗透与防御 含义 原理 JWT的起源 传统session认证问题 token与session区别 JWT的结构与内容 JW
  • CVPR 2020-Object Detection

    目录 2D目标检测 视频目标检测 2D目标检测 Large Scale Object Detection in the Wild From Imbalanced Multi Labels Rethinking Classification
  • 芯片手册中的英文的表示含义

    芯片手册中的英文的表示含义 在读芯片的数据手册的时候 会有一些英文表示不知道是什么含义 现在整理了一些在下面 1 ppm 在一些电压芯片数据手册里 有一个描述基准性能的直流参数 称为温度漂移 也称温度系数 或简称TC Temperature
  • 机器学习之朴素贝叶斯: sklearn.naive_bayes

    朴素贝叶斯 sklearn naive bayes 1 贝叶斯原理 2 朴素贝叶斯 3 朴素贝叶斯模型 3 1 多项式模型MultinomialNB 3 2 高斯模型GaussianNB 3 3 伯努利模型BernoulliNB 4 skl