使用神经网络对手写体数字图片数据分类(MLP/PCA)

2023-11-19

使用神经网络对手写体数字图片数据分类(MLP/PCA)

使用sklearn.neural_network.MLPClassifier类实现手写数字图片识别
MLP的常用的几个参数一般为activation(选择激活函数,如relu,sigmod等,计算效率不一样),solver(权重优化算法,adam,sgd,lbfgs等,小数据集<如本项目>一般选用lbfgs效果更佳,收敛速度也更快。),alpha(正则化项参数),hidden_layer_sizes(隐藏层参数,(200,100,50)即为三层,每层神经元个数为200、100、50),此外还有batch_size:随机优化的minibatches的大小;learning_rate:学习率,constant、invscaling、adaptive;learning_rate_init:初始学习率。只有当solver为sgd或adam时才使用;power_t:逆扩展学习率的指数,只有当solver为sgd时才使用;max_iter:最大迭代次数,具体用法可以直接查文档。

来看一下不降维直接使用MLP进行识别的情况。

from sklearn.neural_network import MLPClassifier
import joblib
import numpy as np
from sklearn.metrics import accuracy_score
#加载数据集
data=np.loadtxt(r"C:/Users/Downloads/digits_training.csv",skiprows=1, delimiter=',')
#查看前五行数据和数据集大小
print("first five lines::", data[:4,:])
print("data shape:",data.shape[0])
xTrain = data[:, 1:]
yTrain = data[:, 0]
def normalizeData(X):
    return (X - X.mean())/X.max()
#数据初始化
xTrain=normalizeData(xTrain)
#建立模型,拟合mlp模型
model = MLPClassifier(activation='relu',solver= "lbfgs" ,alpha=1e-6, hidden_layer_sizes=(200, 100, 50))
model.fit(xTrain,yTrain)
#使用joblib保存模型
print("save train model")
joblib.dump(model, "mlp_classifier_model1.m")
#测试集
data2=np.loadtxt(r"C:\Users\csh\Downloads\digits_testing.csv",skiprows=1, delimiter=',')
print("first five lines:", data2[:4,:])
print("data shape:",data2.shape[0])
xTest= data2[:, 1:]
yTest = data2[:, 0]
xTest=normalizeData(xTest)
#载入模型
model2=joblib.load("mlp_classifier_model1.m")
#预测模型
pred=model2.predict(xTest)
#打印错误数据
print("error data:",(pred != yTest).sum())
#评价模型
print("accuracy_predict;",accuracy_score(yTest,pred))

mlp

我们可以看到经过训练过后,测试集的准确率能达到93.8%,模型训练耗时39.9s,重复几次发现准确率基本稳定在93%左右。
使用sklearn.decomposition的PCA类对手写体数字图片数据进行降维后的情况。
PCA降维即主成分分析,主成分分析的原理非常简单,概括来说就是选择包含信息量大的维度,去除信息量少的“干扰”维度。原理:
1.数据从原来的坐标系转换到新的坐标系,新坐标系的选择是由数据本身决定的。第一个新坐标轴选择的是原始数据中方差最大的方向(即数据差异性最大的方向),第二个新坐标轴选择与第一个新坐标轴正交且具有最大方差的方向,以此类推,共建立与原始数据特征数目相等的新坐标轴。
2.大部分方差都包含在最前面的几个新坐标轴中,因此我们可以忽略余下的坐标轴,从而实现降维。
我们先来画一个图确认方差解释程度,利用matplotlib库画图

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import numpy as np
data=np.loadtxt(r"C:/Users/Downloads/digits_training.csv",skiprows=1, delimiter=',')
xTrain = data[:, 1:]
yTrain = data[:, 0]
def normalizeData(X):
    return X - np.mean(X, axis=0)
xTrain=normalizeData(xTrain)
pca = PCA(n_components=xTrain.shape[1])
pca.fit(xTrain)
print(pca.explained_variance_ratio_)
'''表示取前n个主成分能解释多少百分比的方差'''
plt.plot([i for i in range(xTrain.shape[1])],\
         [np.sum(pca.explained_variance_ratio_[:i+1]) for i in range(xTrain.shape[1])])
plt.show()

图片
对于方差解释度我们既要顾忌准确程度,也要顾忌效率,我们可以看见大约在150左右的维度能达到95%以上(其实准确计算后发现149时达到95%)
接下来我们使用pca降维后用mlp进行识别

from sklearn.decomposition import PCA
from sklearn.neural_network import MLPClassifier
from timeit import default_timer as timer
import joblib
import numpy as np
from sklearn.metrics import accuracy_score
tic=timer()
data=np.loadtxt(r"C:/Users/Downloads/digits_training.csv",skiprows=1, delimiter=',')
print("data shape:",data.shape[0])
xTrain = data[:, 1:]
yTrain = data[:, 0]
def normalizeData(X):
    return X - np.mean(X, axis=0)
xTrain=normalizeData(xTrain)
pca=PCA(n_components=0.95)
pca.fit(xTrain)
xTrain_re=pca.transform(xTrain)
model = MLPClassifier(activation='relu',solver= "lbfgs" ,alpha=1e-6, hidden_layer_sizes=(200, 100, 50))
model.fit(xTrain_re,yTrain)
print("save train model")
joblib.dump(model, "mlpNN_pca.m")
#测试集
data2=np.loadtxt(r"C:\Users\Downloads\digits_testing.csv",skiprows=1, delimiter=',')
print("first five lines:", data2[:4,:])
print("data shape:",data2.shape[0])
xTest= data2[:, 1:]
yTest = data2[:, 0]
xTest=normalizeData(xTest)
xTest_re=pca.transform(xTest)
#载入模型
model2=joblib.load("mlpNN_pca.m")
#预测模型
pred=model2.predict(xTest_re)
print("error data:",(pred != yTest).sum())
#评价模型
print("accuracy_predict;",accuracy_score(yTest,pred))
toc=timer()
print(toc-tic)

pca

pca降维后分类准确率有所降低,达到90.4%,是正常情况,去掉部分维度后肯定是有所降低,但是我们可以看见训练模型的时间显著缩短,只用了16.2s,差不多是未降维时训练时间的2/5,可见pca降维对训练模型效率的提升有显著帮助。带来的准确率损失相比之下就并不突出了。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用神经网络对手写体数字图片数据分类(MLP/PCA) 的相关文章

  • 在 Python 中处理单值元组的最佳实践是什么?

    我正在使用第三方库函数 它从文件中读取一组关键字 并且应该返回一个值的元组 只要有至少两个关键字 它就能正确执行此操作 但是 在只有一个关键字的情况下 它返回一个原始字符串 而不是大小为 1 的元组 这是特别有害的 因为当我尝试做类似的事情
  • python的_random是什么?

    如果你打开random py看看它是如何工作的 它的类Random子类 random Random import random class Random random Random Random number generator base
  • 地图与星图的性能?

    我试图对两个序列进行纯Python 没有外部依赖 逐元素比较 我的第一个解决方案是 list map operator eq seq1 seq2 然后我发现starmap函数来自itertools 这看起来和我很相似 但事实证明 在最坏的情
  • SMTPAuthenticationError: (535, b'5.7.8 用户名和密码在 Django 生产中不被接受?

    我在 Heroku 上部署了一个 Django 应用程序 在其中一节中 我使用 SMTP Gmail 设置向用户发送电子邮件 当我在本地运行项目时 电子邮件发送成功 但在 Heroku 上部署的项目上却发送失败 我在 Stackoverfl
  • 确定Python模块中的函数是否可用

    我正在研究一些使用Python套接字的代码socket fromfd http docs python org library socket html socket fromfd功能 但是 此方法并非在所有平台上都可用 因此我正在编写一些后
  • Python MySQL 模块

    我正在开发一个需要与 MySQL 数据库交互的 Web 应用程序 但我似乎找不到任何真正适合 Python 的模块 我特别寻找快速模块 能够处理数十万个连接 和查询 所有这些都在短时间内完成 而不会对速度产生重大影响 我想我的答案将是游戏领
  • 如何最好地将包含列表或元组的 Pandas 列提取到多个列中[重复]

    这个问题在这里已经有答案了 我不小心用错误重复的链接关闭了这个问题 这是正确的 Pandas 将列表的列拆分为多列 https stackoverflow com questions 35491274 pandas split column
  • 图像堆栈的最大强度投影

    我正在尝试重新创建该功能 max array 3 来自 MatLab 它可以获取 N 个图像的 300x300px 图像堆栈 我在这里说 图像 因为我正在处理图像 实际上这只是一个大的双数组 300x300xN 并创建一个 300x300
  • 如何抑制 pyinstaller 生成的可执行文件窗口中的所有警告

    我已经使用 pyinstaller 从 python 文件生成了可执行文件 该程序按其应有的方式工作 但在我想隐藏的窗口中出现了一条警告消息 当 python 文件在 IDE 中运行时 以下行会抑制所有警告消息 warnings filte
  • 不重复的Python组合

    我有一个数字列表 我想从中进行组合 如果我有清单 t 2 2 2 2 4 c list itertools combinations t 4 结果是 2 2 2 2 2 2 2 4 2 2 2 4 2 2 2 4 2 2 2 4 但我想得到
  • 对于 pygtk 应用程序来说,什么是好的嵌入式浏览器?

    我计划在我的 pygtk 应用程序中使用嵌入式浏览器 并且我正在 gtkmozembed 和 pywebkitgtk 之间进行辩论 两者之间有什么引人注目的区别吗 还有我不知道的第三种选择吗 应该注意的是 我不会使用它来访问网络上的内容 我
  • 如何替换被测模块的文件访问引用

    pyfakefs https code google com p pyfakefs 听起来非常有用 它 最初是作为核心 Python 模块的一个适度的假实现来开发的 以支持中等复杂的文件系统交互 并于 2006 年 9 月在 Google
  • 具有条件的重复行 pandas dataframe python

    我的数据框有问题 我的 df 是 product power brand product 1 3 x 1500W brand A product 2 2x1000W 1x100W product 3 1x1500W 1x500W brand
  • 在 Django/python 中,如何将内存缓存设置为无限时间?

    cache set key value 9999999 但这并不是无限的时间 def get memcache timeout self timeout Memcached deals with long gt 30 days timeou
  • Beautiful Soup 获取动态表数据

    我有以下代码 url https www basketball reference com leagues NBA 2017 standings html all expanded standings html urlopen url so
  • 从 sublime_plugin.WindowCommand 获取当前文件名

    我开发插件sublime text 3 并想要获取当前打开的文件路径 absolute1 self window view file name 在哪里self is sublime plugin WindowCommand 但失败了 Att
  • 安排 Asyncio 任务每 X 秒执行一次?

    我正在尝试创建一个 python 不和谐机器人 它将每隔 X 秒检查一次活跃会员 并根据会员的在线时间奖励积分 我正在使用 asyncio 来处理聊天命令 这一切都正常 我的问题是找到一种方法来安排每隔 X 秒异步检查一次活动成员 我已经阅
  • 连接运算符 + 或 ,

    var1 abc var2 xyz print literal var1 var2 literalabcxyz print literal var1 var2 literal abc xyz 除了带有 的自动空格之外 两者有什么区别 哪个通
  • 将 pandas 数据框中的多列更改为日期时间

    我有一个 13 列和 55 000 行的数据框 我正在尝试将其中 5 行转换为日期时间 现在它们返回类型 对象 我需要转换这些数据以进行机器学习 我知道如果我这样做 data birth date pd to datetime data b
  • Tensorflow ctc_loss_calculator:找不到有效路径

    当运行我的神经网络 双向 LSTM 进行音频识别时 我使用连接主义时间分类 CTC 但在某些时候 训练网络时我几乎每批都会收到来自 Tensorflow 的警告 W tensorflow core util ctc ctc loss cal

随机推荐

  • 招聘数据采集+Hive数据分析+数据可视化

    1 需求 背景描述 近年来随着IT产业的快速发展 全国各地对IT 类的人才需求数量也在不断 增多 XHS集团 为了明确今后IT产业人才培养方向 在多个省份进行 IT公 司岗位情况调研分析 你所在的小组将承担模拟调研分析的任务 通过在招聘 网
  • 【Apifox Helper】自动生成接口文档,IDEA+Apifox懒人必备

    文章目录 前言 缘由 接口文档对接爽 整理起来真费脑 本文阅读时长 约10分钟 前置条件 1 IDEA开发工具 2 Apifox 不必要 主要目标 一秒生成接口文档 水图 IDEA中项目接结构图 生成到Apifox接口文档图 试用人群 1
  • PowerDesigner15.1如何安装?汉化破解版安装教程(含安装文件、汉化包、破解文件)

    powerdesigner是能进行数据库设计的强大的软件 是一款开发人员常用的数据库建模工具 使用它可以分别从概念数据模型 Conceptual Data Model 和物理数据模型 Physical Data Model 两个层次对数据库
  • glog详解

    glog详解 1 glog的安装与使用 1 下载glog http code google com p google glog downloads detail name glog 0 3 3 tar gz 2 安装glog configu
  • 请用详细列出宝塔搭建Nacos的流程

    好的 为了在宝塔中搭建 Nacos 服务 您需要执行以下步骤 下载 Nacos 的最新版本并解压缩 使用文本编辑器打开 conf application properties 文件 并修改其中的数据库配置 使 Nacos 可以连接到您的数据
  • C++中队列的顺序存储

    队列 引言 示例 开发环境 运行结果 引言 队列先进先出 队头出队 队尾入队 其存储可以分为顺序存储和链式储存 本文记录队列的顺序存储 也就是队列中的元素存储的内存空间是连续的 这里使用数组来模拟线性队列 示例 直接上代码 SeqQueue
  • java arraylist 拷贝_Copy ArrayList的四种方式

    Copy ArrayList的四种方式 简介 ArrayList是我们经常会用到的集合类 有时候我们需要拷贝一个ArrayList 今天向大家介绍拷贝ArrayList常用的四种方式 使用构造函数 ArrayList有个构造函数 可以传入一
  • 第三周课程总结&实验报告一

    实验报告 1 打印输出所有的 水仙花数 所谓 水仙花数 是指一个3位数 其中各位数字立方和等于该数本身 例如 153是一个 水仙花数 实验代码 public class ShuiXianHua public static void main
  • L3-014 周游世界 (30 分)

    题目 题目链接 题解 DFS 采用的数据结构 vector 索引为起点 值为 终点 起点公司编号 当然你也可以保存终点公司编号 但是代码中的语句就需要改一下了 dfs 传入四个信息 当前节点 遇到的节点数 换乘数 当前节点所在公司的编号 由
  • python 的回调函数

    回调函数就是一个通过函数指针调用的函数 如果你把函数的指针 地址 作为参数传递给另一个函数 当这个指针被用来调用其所指向的函数时 我们就说这是回调函数 有些库函数 library function 却要求应用先传给它一个函数 好在合适的时候
  • 指针(一)

    这里写目录标题 一 什么是指针 二 指针和指针类型 三 野指针 四 指针运算 五 指针和数组 六 二级指针 七 指针数组 一 什么是指针 1 指针是内存中一个最小单元的编号 也就是地址 2 平时我们说的指针 通常指的是指针变量 是用来存放地
  • 大数据手册(Spark)--Spark基本概念

    文章目录 Spark 基本概念 Hadoop 生态 Spark 生态 Spark 基本架构 Spark运行基本流程 弹性分布式数据集 RDD Spark安装配置 Spark基本概念 Spark基础知识 PySpark版 Spark机器学习
  • 用户互动优化:微信营销系统实践

    在当今移动互联网时代 微信已经成为企业进行营销的重要平台之一 而用户互动作为营销的关键环节 对于提升品牌影响力 增强用户忠诚度至关重要 本文将深入探讨如何通过微信营销系统实践 优化用户互动 提升营销效果 建立更紧密的用户关系 1 理解用户互
  • 修改本地host文件加入可用ip使谷歌浏览器翻译插件重新生效

    修改本地host文件加入可用ip使谷歌浏览器翻译插件重新生效 第一步 找到host文件 可以使用这个工具进行对Hosts文件进行一个查找 鼠标放到对应路径上面 点击鼠标右键 选择打开路径就到对应 路径了 也可以复制到这个路径下面去找host
  • .NET6-Asp.Net Core webapi -从零开始的webapi项目

    本项目为本人22年毕设项目 后续会不断更新本篇文章 全部内容都会写在这一篇文章里 喜欢的请持续关注 一 如何创建 Asp Net Core webapi 项目 二 如何使用 EntityFrameWorkCore DbFirst 需要用到的
  • 人工智能 ai基础知识_如何使用人工智能改善基础医疗的成果和效率

    人工智能 ai基础知识 Only 7 percent of a message is based on the words it contains The rest 93 percent comes from the speaker s t
  • 第一个只出现一次的字符(Java)

    题目 在字符串中找出第一个只出现一次的字符 如输入 abaccdeff 则输出 b 第一思路 借助于数组来做 开辟一个长度为26的数组 用来存放字符串中每个字符出现的次数 这样第一次扫描去统计这个字符串中字符出现的次数 第二次去统计第一个出
  • [Leetcode] 3.无重复字符的最长子串

    题目描述 给定一个字符串 找出不含有重复字符的最长子串的长度 示例 给定 abcabcbb 没有重复字符的最长子串是 abc 那么长度就是3 给定 bbbbb 最长的子串就是 b 长度是1 给定 pwwkew 最长子串是 wke 长度是3
  • CUDA复制测试

    这里主要是测试了内存数据读写操作的几种方式 记录了一些测试结果 对于二维数组 10244 1024 4 1 二维线程格 每个线程对应一个元素 2 转换为int2类型 线程宽度减半 3 线程宽度和高度减半 单个线程操作邻近的4个元素 4 线程
  • 使用神经网络对手写体数字图片数据分类(MLP/PCA)

    使用神经网络对手写体数字图片数据分类 MLP PCA 使用sklearn neural network MLPClassifier类实现手写数字图片识别 MLP的常用的几个参数一般为activation 选择激活函数 如relu sigmo