在Kaggle手写数字数据集上使用Spark MLlib的RandomForest进行手写数字识别

2023-05-16

昨天我使用Spark MLlib的朴素贝叶斯进行手写数字识别,准确率在0.83左右,今天使用了RandomForest来训练模型,并进行了参数调优。

首先来说说RandomForest 训练分类器时使用到的一些参数:

  • numTrees:随机森林中树的数目。增大这个数值可以减小预测的方差,提高预测试验的准确性,训练时间会线性地随之增长。
  • maxDepth:随机森林中每棵树的深度。增加这个值可以是模型更具表征性和更强大,然而训练也更耗时,更容易过拟合。

    在这次的训练过程中,我就是反复调整上面两个参数来提升预测的准确性。首先来设定一下一些参数的初始值。

    val numClasses = 10
    val categoricalFeaturesInfo = Map[Int, Int]()
    val numTrees = 3 
    val featureSubsetStrategy = "auto" 
    val impurity = "gini"
    val maxDepth = 4
    val maxBins = 32

第一次我将树的数目设定为3,每棵树深度为4。下面开始训练模型:

val randomForestModel = RandomForest.trainClassifier(data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

与使用朴素贝叶斯时评估准确率方式一样,我使用训练数据来计算准确率:

    val nbTotalCorrect = data.map { point =>
      if (randomForestModel.predict(point.features) == point.label) 1 else 0
    }.sum
    val numData = data.count()
    println(numData)
    //42000
    val nbAccuracy = nbTotalCorrect / numData

下面是每次对上面所说到的两个参数进行调整后得到的准确率:

    //numTree=3,maxDepth=4,准确率:0.5507619047619048
    //numTree=4,maxDepth=5,准确率:0.7023095238095238
    //numTree=5,maxDepth=6,准确率:0.693595238095238
    //numTree=6,maxDepth=7,准确率:0.8426428571428571
    //numTree=7,maxDepth=8,准确率:0.879452380952381
    //numTree=8,maxDepth=9,准确率:0.9105714285714286
    //numTree=9,maxDepth=10,准确率:0.9446428571428571
    //numTree=10,maxDepth=11,准确率:0.9611428571428572
    //numTree=11,maxDepth=12,准确率:0.9765952380952381
    //numTree=12,maxDepth=13,准确率:0.9859523809523809
    //numTree=13,maxDepth=14,准确率:0.9928333333333333
    //numTree=14,maxDepth=15,准确率:0.9955
    //numTree=15,maxDepth=16,准确率:0.9972857142857143
    //numTree=16,maxDepth=17,准确率:0.9979285714285714
    //numTree=17,maxDepth=18,准确率:0.9983809523809524
    //numTree=18,maxDepth=19,准确率:0.9989285714285714
    //numTree=19,maxDepth=20,准确率:0.9989523809523809
    //numTree=20,maxDepth=21,准确率:0.999
    //numTree=21,maxDepth=22,准确率:0.9994761904761905
    //numTree=22,maxDepth=23,准确率:0.9994761904761905
    //numTree=23,maxDepth=24,准确率:0.9997619047619047
    //numTree=24,maxDepth=25,准确率:0.9997857142857143
    //numTree=25,maxDepth=26,准确率:0.9998333333333334
    //numTree=29,maxDepth=30,准确率:0.9999523809523809

可以发现,准确率在numTree=11,maxDepth=12 附近开始收敛到0.999。这次得到的准确率要比上次使用朴素贝叶斯训练得出的准确率(0.826)要高出许多。现在开始对测试数据进行预测,使用的参数是numTree=29,maxDepth=30

val predictions = randomForestModel.predict(features).map { p => p.toInt }

把训练出来的结果上传到Kaggle上,得到的准确率为0.95929 ,经过我的四次参数调整,得到的最高的准确率是0.96586 ,设置的参数是:numTree=55,maxDepth=30 ,当我将参数改为numTree=70,maxDepth=30 时,准确率有所下降,为0.96271 ,看来这个时候出现过拟合了。不过准确率能从昨天的0.83提高到0.96还是挺兴奋的,我还会继续尝试使用其他方式进行手写数字识别,不知何时能达到1.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在Kaggle手写数字数据集上使用Spark MLlib的RandomForest进行手写数字识别 的相关文章

  • sparkstreamming 消费kafka(1)

    pom
  • spark SQL基础教程

    1 sparkSQL入门 sparksql专门用于处理结构化的数据 而RDD还可以处理非结构化的数据 sparksql的优点之一是sparkfsql使用统一的api读取不同的数据 第二个优点是可以在语言中使用其他语言 例如python 另外
  • JAVA 安装与简单使用

    JAVA简易安装 下载安装 环境变量 进入变量界面 设置变量 验证JAVA环境 运行Java程序 个人站 ghzzz cn 还在备案 很快就能访问了 下载安装 第一步当然是从官网下载安装java了 网上有很多的教程 这里简单的写一下 在这里
  • Spark 配置

    文章目录 1 Spark 配置 1 1 Spark 属性 1 1 1 动态加载Spark属性 1 1 2 查看Spark属性 1 2 环境变量 2 重新指定配置文件目录 3 继承Hadoop集群配置 4 定制的Hadoop Hive配置 1
  • Kaggle 数据集导入 Jupyter Notebook

    我正在尝试将一些数据从 kaggle 导入到笔记本中 我收到的错误是 401 未经授权 但我已接受比赛规则并且能够下载数据 这是我正在运行的代码 from kaggle api kaggle api extended import Kagg
  • python+django基于Spark的国漫画推荐系统 可视化大屏分析

    国漫推荐信息是现如今社会信息交流中一个重要的组成部分 本文将从国漫推荐管理的需求和现状进行分析 使得本系统的设计实现具有可使用的价 做出一个实用性好的国漫推荐系统 使其能满足用户的需求 并可以让用户更方便快捷地国漫推荐 国漫推荐系统的设计开
  • Caret 模型随机森林转化为 PMML 错误

    我想使用 pmml 库导出 Caret 随机森林模型 以便我可以使用它在 Java 中进行预测 这是我收到的错误的重现 data iris require caret require pmml rfGrid2 lt expand grid
  • “R”包“ranger”中的“最大深度”相当于什么?

    其他随机森林工具具有限制特定分支上的最大分割深度的 刻度盘 例如 h2o randomForest 具有 max 深度 游侠 的版本是什么 我不熟悉h2o randomForest包 但我对随机森林的一般理解是 每棵树都会生长 直到树的每片
  • 如何消除使用 randomForest 运行预测的“外部函数调用中的 NA/NaN/Inf (arg 7)”

    我对此进行了广泛的研究 但没有找到解决方案 我已经清理了我的数据集 如下所示 library raster impute mean lt function x replace x is na x is nan x is infinite x
  • 运行 randomForest 时出错:找不到对象

    所以我试图为我的数据集拟合一个随机森林分类器 我对 R 很陌生 我想这是一个简单的格式问题 我读入一个文本文件并转换我的数据集 使其具有以下格式 取出机密信息 gt head df train 2 GOLGA8A ITPR3 GPR174
  • R 插入符:结合 rfe() 和 train()

    我想将递归特征消除与rfe 并与模型选择一起进行调整trainControl 使用该方法rf 随机森林 我想要的是 MAPE 平均绝对百分比误差 而不是标准的汇总统计数据 因此我尝试使用以下代码ChickWeight数据集 library
  • Spark 中 BroadCast 导致的内存溢出(SparkFatalException)

    背景 本文基于 Spark 3 1 1 open jdk 1 8 0 352 目前在排查 Spark 任务的时候 遇到了一个很奇怪的问题 在此记录一下 现象描述 一个 Spark Application Driver端的内存为 5GB 一直
  • 随机森林在 opencv python (cv2) 中不起作用

    我似乎无法正确传递参数来从 python 中训练 opencv 中的随机森林分类器 我用 C 编写了一个可以正常工作的实现 但在 python 中没有得到相同的结果 我在这里找到了一些示例代码 http fossies org linux
  • 为什么 R 和 Python 之间得到不同的 RandomForest 结果?

    我正在尝试比较使用 R 和使用 Python 的随机森林模型的结果 我要比较的模型性能的关键衡量指标是 AUC ROC 曲线下面积 原因是 AUC 值代表预测值 即概率 的分布 我确实发现 R 和 Python 之间的 AUC 值存在一些显
  • PySpark 和 MLLib:随机森林预测的类概率

    我正在尝试提取使用 PySpark 训练过的随机森林对象的类概率 但是 我在文档中没有看到它的示例 也不是一种方法RandomForestModel 我怎样才能从a中提取类别概率RandomForestModelPySpark 中的分类器
  • 随机森林中什么是袋外错误? [关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 随机森林中什么是袋外错误 它是在随机森林中找到正确数量的树的最佳参数吗 我将尝试解释一下 假设我们的训练数据集由 T 表示 并且假设数
  • Pandas scatter_matrix - 绘制分类变量

    我正在查看 Kaggle 竞赛中著名的泰坦尼克号数据集 http www kaggle com c titanic gettingStarted data http www kaggle com c titanic gettingStart
  • scikit-learn:如何计算百分比均方根误差(RMSE)?

    我有一个数据集 在此链接中找到 https drive google com open id 0B2Iv8dfU4fTUY2ltNGVkMG05V00 https drive google com open id 0B2Iv8dfU4fTU
  • R ranger 包中的预测概率

    我正在尝试在 R 中建立一个具有随机森林分类的 模型 通过 Ned Horning 编辑代码 我首先使用randomForest包但后来发现ranger 这保证了更快的计算 首先 我使用下面的代码在拟合模型后获得每个类别的预测概率rando
  • 用于分类的 Python 向量化[重复]

    这个问题在这里已经有答案了 我目前正在尝试构建一个包含大约 80 个类别的文本分类模型 文档分类 当我使用随机森林构建和训练模型时 将文本矢量化为 TF IDF 矩阵后 该模型运行良好 然而 当我引入新数据时 我用来构建 RF 的相同单词不

随机推荐

  • secoclient全版本下载分享

    前言 工作需要使用 secoclient xff0c 同事们大多都用 Windows环境 客户提供的客户端也是Windows版本的 这就让使用Mac几个同事难受啦 用Windows虚拟机 xff1f 根据我的经验 xff0c 一般的VPN客
  • Centos升级ruby

    CentOS7 安装的ruby默认版本是 xff1a ruby v span class token punctuation span 11 43 53 span class token punctuation span ruby 2 0
  • Windows 11下载

    Windows 11是微软于2021年推出的Windows NT系列操作系统 xff0c 为Windows 10的后继者 正式版本于2021年10月5日发行 xff0c 并开放给符合条件的Windows 10设备通过Windows Upda
  • docker容器安装图形桌面

    文章目录 视频教程版本信息创建一个CONTAINERubuntu官方国内源docker镜像unminimize中文环境设置中文环境 安装安装TigerVNC Server安装 xfce4精简版本 配置设置vnc密码 vnc xstartup
  • ubuntu官方国内源

    背景 之前我一直在使用中科大的源 xff0c 还是挺快的 一直也没有感觉有什么问题 直到最近在折腾vnc xff0c 发现中科大的源有一些包会404 xff0c 安装不了 而我在vmware中的正好是默认的cn archive ubuntu
  • mame新版ROM下载网站推荐

    网站地址 https www retroroms info index php 中文插件安装 浏览器插件 https www tampermonkey net UP主自己写的脚本 已经失效 https gitee com lxyoucan
  • RuoYi若依打包发布与部署

    上一节我们已经讲过了如果搭建开发环境 xff0c 那么如果代码写完了 xff0c 如何打包发布 部署到生产环境呢 xff1f RuoYi开发实战 搭建开发环境 https blog csdn net lxyoucan article det
  • vscode设置Prettier为默认格式化插件

    1 目的 xff1a ctrl 43 s保存 xff0c 自动格式化文档 2 所需插件Prettier 3 操作步骤 先打开vscode软件 xff0c 左下角点击设置 gt 打开设置 gt 在右上方有一个搜索框 先设定自动保存文件 xff
  • ASUS X415安装系统找不到硬盘解决办法

    同事让我帮忙安装系统 xff0c 笔记本电脑型号是ASUS X415 原本以为是手到擒来的事情 xff0c 结果我在上面还是消耗了不少时间 现象 老毛桃PE 无法识别到硬盘 微PE可以识别到硬盘 xff0c 但是系统安装以后 xff0c 无
  • archlinux中navicat无法使用fcitx5输入法

    现象 archlinux中navicat无法使用fcitx5输入法 而我在ubuntu中使用navicat调用fcitx输入法是可以正常使用的 在网上搜索了很久 xff0c 这方面的文章比较少 而我的其他程序输入法又是正常的 解决办法 参考
  • JetBrains Gateway IDEA远程开发

    为什么进行远程开发 xff1f 无论身处何处数秒内连接至远程环境 充分利用远程计算机的强大功能 在任何笔记本电脑上都可以轻松工作 xff0c 无论其性能如何 借助远程计算机的计算资源 xff0c 充分利用最大规模的数据集和代码库 在远程服务
  • ubuntu 22.04安装nvm

    执行安装脚本 span class token function sudo span span class token function apt span span class token function install span spa
  • 手推DNN,CNN池化层,卷积层反向传播

    反向传播算法是神经网络中用来学习的算法 xff0c 从网络的输出一直往输出方向计算梯度来更新网络参数 xff0c 达到学习的目的 xff0c 而因为其传播方向与网络的推理方向相反 xff0c 因此成为反向传播 神经网络有很多种 xff0c
  • 软件架构概念和面向服务的架构

    摘要 软件架构作为软件开发过程的一个重要组成部分 xff0c 有着各种各样的方法和路线图 xff0c 它们都有一些共同的原则 基于架构的方法作为控制系统构建和演化复杂性的一种手段得到了推广 引言 在计算机历史中 xff0c 软件变得越来越复
  • 初识强化学习,什么是强化学习?

    相信很多人都听过 机器学习 和 深度学习 但是听过 强化学习 的人可能没有那么多 那么 什么是强化学习呢 强化学习是机器学习的一个子领域 它可以随着时间的推移自动学习到最优的策略 在我们不断变化的纷繁复杂的世界里 从更广的角度来看 即使是单
  • 强化学习形式与关系

    在强化学习中有这么几个术语 智能体 Agent 环境 Environment 动作 Action 奖励 Reward 状态 State 有些地方称作观察 Observation 奖励 Reward 在强化学习中 奖励是一个标量 它是从环境中
  • 多层网络和反向传播笔记

    在我之前的博客中讲到了感知器 xff08 感知器 xff09 xff0c 它是用于线性可分模式分类的最简单的神经网络模型 xff0c 单个感知器只能表示线性的决策面 xff0c 而反向传播算法所学习的多层网络能够表示种类繁多的非线性曲面 对
  • 在Kaggle手写数字数据集上使用Spark MLlib的朴素贝叶斯模型进行手写数字识别

    昨天我在Kaggle上下载了一份用于手写数字识别的数据集 xff0c 想通过最近学习到的一些方法来训练一个模型进行手写数字识别 这些数据集是从28 28像素大小的手写数字灰度图像中得来 xff0c 其中训练数据第一个元素是具体的手写数字 x
  • Ros使用自定义数据通讯无法收到消息的分析和解决

    nbsp 在实际的开发中 和别的模块定义了自定义的 数据类型 比如 userMsg msg文件 Header header int32 nState string strImageName string strYamlName 报错和原因
  • 在Kaggle手写数字数据集上使用Spark MLlib的RandomForest进行手写数字识别

    昨天我使用Spark MLlib的朴素贝叶斯进行手写数字识别 xff0c 准确率在0 83左右 xff0c 今天使用了RandomForest来训练模型 xff0c 并进行了参数调优 首先来说说RandomForest 训练分类器时使用到的