机器学习案例6:基于SVM的数字识别

2023-10-27

案例6:基于SVM的数字识别

为什么写本博客

​ 前人种树,后人乘凉。希望自己的学习笔记可以帮助到需要的人。

需要的基础

​ 懂不懂原理不重要,本系列的目标是使用python实现机器学习。

​ 必须会的东西:python基础、numpy、pandas、matplotlib和库的使用技巧。

说明

​ 完整的代码在最后,另外之前案例中出现过的方法不会再讲解。

目录结构

1. 涉及的新方法:

模型创建

from sklearn import svm
# 创建模型
model = svm.SVC()
'''
核心参数:
	C 正则化因子,越大,意味着划分越严格,即越接近线性分割
	kernel 核函数,常用的 poly\rbf\sigmoid
	degree 当为poly多项式核函数的时候启用,指定多项式次数
'''

2. 数据集介绍与处理:

数据集介绍与下载

​ MNIST是一个经典的手写数字数据,也是一个公开的小型数据。

​ 数据集可以通过官网进行下载http://yann.lecun.com/exdb/mnist/,不过,这个网址似乎要用点魔法才可以打开,所以,我也用百度云分享下:

链接:https://pan.baidu.com/s/1P5c4GJQqfuWDP2g6wM5Slw 
提取码:6666 

​ 其中,主要分为两个文件夹,一个是原始数据文件夹,即从网站下载的;第二个是做出处理后的文件夹,即划分为测试集和训练集,并且将图像数据转为矩阵数据了,这也是我们需要使用的数据(这个数据集来自于网上黑马的案例,特此声明)。

MNIST中的图像每个都是28*28=784的大小,并且为灰度图,值为0-255。训练集共有785列,第一列为标签列,后面每列为一个像素值。

数据集加载和提取

​ 首先,使用pandas加载数据:

# 加载数据
data = pd.read_csv('./data/MNIST/train.csv')
print(data.head())
print(data.shape)

​ 打印结果为:

   label  pixel0  pixel1  pixel2  ...  pixel780  pixel781  pixel782  pixel783
0      1       0       0       0  ...         0         0         0         0
1      0       0       0       0  ...         0         0         0         0
2      1       0       0       0  ...         0         0         0         0
3      4       0       0       0  ...         0         0         0         0
4      0       0       0       0  ...         0         0         0         0

[5 rows x 785 columns]
(42000, 785)

​ 接着,提取出x和y数据,但是由于整体数据量太大,足足有42000条,而这里我们又不采用降维或者特征提取等手段,会导致模型计算时间非常长,因此我们取出3000条数据进行测试:

# 提取x和y
x_train = data.iloc[:3000,1:]
y_train = data.iloc[:3000,0]
print(y_train.head())

​ 打印结果为:

0    1
1    0
2    1
3    4
4    0
Name: label, dtype: int64

数据集显示

​ 我们定义一个显示函数,可以把任意一条数据显示为数字,处理的思路:首先把行向量转为矩阵,再用matplotlib显示即可;

# 定义显示函数
def show_image(index):
    # index : 传入的索引
    target = x_train.iloc[index,:].values.reshape(28,28)
    plt.imshow(target)
    plt.show()

# 尝试
show_image(0)

​ 显示的结果为:
在这里插入图片描述

一个明显的问题,这里显然颜色不对劲,这是因为matplotlib的颜色空间和我们认为的颜色空间不一致所致。不过,这也无伤大雅,问题不大(当然,你想改还是很轻松的,可以使用opencv库中的颜色空间转变)。

归一化

​ 下面,将数据归一化,本来归一化公式为:

(xi - x_min) / (x_max - x_min)

​ 而这里最小值为0,最大值为255,所以直接改写为:

x / 255

​ 因此,代码可写为:

# 归一化处理
x_train = x_train.values / 255
y_train = y_train.values

划分数据集

​ 由于给出的test.csv中没有标签,因此暂时没有办法用,只好将训练集划分了:

# 数据集划分
x_train,x_test,y_train,y_test = train_test_split(x_train,y_train,test_size=0.2,random_state=2)

3. 创建模型、训练和评估:

​ 创建模型、训练和评估:

# 创建模型
model = svm.SVC()
model.fit(x_train,y_train)
# 评估
score = model.score(x_test,y_test)
print('准确率:',score)

​ 打印结果为:

准确率: 0.9416666666666667

4. 探究不同参数准确率结果:

探究数据量影响

​ 这里,我分别使用1000\2000\3000\4000的数据来测试,所得结果为:

# 3000 : 准确率: 0.9416666666666667
# 2000 : 准确率: 0.945
# 1000 : 准确率: 0.92
# 4000 : 准确率: 0.9325

不同正则化因子的影响:

​ 数据集为3000条,进行测试,结果如下:

# 3000
# C=0.5 准确率: 0.9266666666666666
# C=0.7 准确率: 0.93
# C=1 准确率: 0.9416666666666667
# C=1.2 准确率: 0.9433333333333334
# C=5 准确率: 0.94
# C=20 准确率: 0.94

不同核函数

​ 采取了高斯核函数、sigmoid核函数和多项式核函数:

# RBF : 准确率: 0.9416666666666667
# sigmoid : 准确率: 0.865
# poly : 次数为3 准确率: 0.915 ; 次数为4 准确率: 0.875 ; 次数为5 准确率: 0.82

​ 通过上面的结果,知道不同的参数具有不同的影响,当然单纯通过数据集大小来判断模型好坏不可取,主要还是对数据集的利用程度。

5. 总结和完整代码:

​ 这里不得不提一句,明明有这么多的数据,但是由于没有进行特征提取,导致我们只能对原始数据加载,这会导致我们的计算时间拉长。这告诉我们,特征提取在机器学习中的重要性。

​ 完整代码:

# author: baiCai
# 导包
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import svm
from sklearn.model_selection import train_test_split

# 加载数据
data = pd.read_csv('./data/MNIST/train.csv')
# print(data.head())
# print(data.shape) # (42000, 785)
# 提取x和y
x_train = data.iloc[:3000,1:]
y_train = data.iloc[:3000,0]
# print(y_train.head())

# 定义显示函数
def show_image(index):
    # index : 传入的索引
    target = x_train.iloc[index,:].values.reshape(28,28)
    plt.imshow(target)
    plt.show()
# 尝试
# show_image(0)

# 归一化处理
x_train = x_train.values / 255
y_train = y_train.values
# 数据集划分
x_train,x_test,y_train,y_test = train_test_split(x_train,y_train,test_size=0.2,random_state=2)


# 创建模型
model = svm.SVC()
model.fit(x_train,y_train)
# 评估
score = model.score(x_test,y_test)
print('准确率:',score)

# 3000 : 准确率: 0.9416666666666667
# 2000 : 准确率: 0.945
# 1000 : 准确率: 0.92
# 4000 : 准确率: 0.9325


# 3000
# C=0.5 准确率: 0.9266666666666666
# C=0.7 准确率: 0.93
# C=1 准确率: 0.9416666666666667
# C=1.2 准确率: 0.9433333333333334
# C=5 准确率: 0.94
# C=20 准确率: 0.94

# RBF : 准确率: 0.9416666666666667
# sigmoid : 准确率: 0.865
# poly : 次数为3 准确率: 0.915 ; 次数为4 准确率: 0.875 ; 次数为5 准确率: 0.82
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

机器学习案例6:基于SVM的数字识别 的相关文章

随机推荐

  • A Survey of Large Language Models

    本文是LLM系列的第一篇文章 针对 A Survey of Large Language Models 的翻译 大语言模型综述 摘要 1 引言 2 概述 2 1 LLM的背景 2 2 GPT系列模型 的技术演化 3 LLMs的资源 3 1
  • 新优选商城上线发布会,京庐空间执行总裁赵娅勤女士接受采访!

    新优选商城上线发布会 京庐空间执行总裁赵娅勤女士接受采访现场 互联网的快速发展对于推动各行各业重构商业业态具有深远的历史影响 互联网发展进入21世纪 在全球信息化进程加快的背景下 电子商务 新零售 社交电商等行业业态也为传统行业提供了新的增
  • 正交试验设计例题及答案_正交矩阵求解

    我对题的难度进行划分 满难度是 难度中包含 计算难度和 思路难度 满难度各 难度评测是以我初见题目视角写出的标准 比较主观 另外 考试时的解答时间的长短也需要考虑 毕竟有些东西复习多了见过了 或者看过答案了就不难了 假如每颗 代表20分 我
  • ubuntu:关闭某个进程

    参考 http blog csdn net chen861201 article details 6980677 ps aux grep xxx 程序名称 kill xxx 某个PID
  • ganglia监控hadoop 容器节点

    hadoop容器运行参考上篇博客 http blog csdn net wenwenxiong article details 78973755 参看网址 https gist github com ameizi 0c77e3dbb13de
  • (二十五)admin-boot项目之集成消息队列Rabbitmq

    目录 项目地址 https gitee com springzb admin boot 如果觉得不错 给个 star 简介 这是一个基础的企业级基础后端脚手架项目 主要由springboot为基础搭建 后期整合一些基础插件例如 redis
  • java 动态线程池_线程池的参数动态调整

    经典面试题 这次的文章还是绕回了我写的第三篇原创文章 有的线程它死了 于是它变成一道面试题 中留下的几个问题 哎 兜兜转转 走走停停 天道好轮回 苍天饶过谁 在这篇文章中我主要回答上面抛出的这个问题 你这几个参数的值怎么来的呀 要回答这个问
  • OpenGL2 spec releases at the SIGGRAPH2004

    发信人 chsoft 珍惜光华 善待光华 信区 Graphics标 题 OpenGL2 spec releases at the SIGGRAPH2004发信站 日月光华 2004年08月11日13 22 06 星期三 站内信件 SIGGR
  • 编译busybox有这个提示,是怎么回事

    我编译busybox有这个提示 是怎么回事 有人知道吗 分类 海思论坛 https www ebaina com questions 100000031827
  • DC-DC直流斩波---BUCK降压斩波电路

    降压斩波电路 Buck Chopper 的原理图及工作波形 该电路使用一个全控型器件V 图中为IGBT 也可使用其他器件 若采用晶闸管 需设置使晶闸管关断的辅助电路 图5 1中 为在V关断时给负载中电感电流提供通道 设置了续流二极管VD 斩
  • Linux 文件权限

    一 文件权限 Linux系统中的每个文件和目录都有访问许可权限 用他来确定谁能通过何种方式对文件和目录进行访问和操作 文件或目录的访问权限分为只读 只写和可执行三种 Linux文件权限一共10位长度 分成四段 第一段1位 表示文件类型 d表
  • glGetString(GL_VERSION)、 glGetIntegerv(GL_MAX_TEXTURE_SIZE, &max)为何老是得不到正确的值

    今天碰到了的问题如下 在程序里调用 printf s r n glGetString GL VERSION 总是输出 null glGetIntegerv GL MAX TEXTURE SIZE max 得不到max的值 答 在MFC情况下
  • @RefreshScope注解处理

    spring启动时会调用ClassPathBeanDefinitionScanner java类中的doScan 对包路径下的所有class进行扫描 获取bean的定义 同时对bean的 RefreshScope Scope的父类 进行处理
  • JSP基础_0700_HelloWorld 全局变量和局部变量

    本文讨论jsp中生成的servlet代码中全局变量和局部变量的问题 请看下面一段代码
  • tar包安装

    在Linux操作系统中 常用的软件包一共有两种 rpm包 相当于Windows中的exe软件包 tar gz包 未编译的源码包 软件的编译需要使用gcc编译器 Linux安装 开发工具 gt gcc gcc c tar包解压 基本语法 ta
  • WIN7&WIN10共享打印机0x000000709错误解决方法

    这两天连续碰到709错误 打印机安装正常 共享正常 但通过 ip 打印机名添加共享时就709错误 开始以为是printspooler服务的问题 但试过故障依旧 也检查了共享设置 密码共享是关闭的 guest用户是开启的 为什么呢 最终发现了
  • numpy ndarray 打印格式化

    1 ndarray打印省略问题 np set printoptions threshold np inf 2 ndarray打印换行限制 加上下面这句代码 输出时打印不换行 np set printoptions linewidth 400
  • 用Servlet结合c3p0连接池等写一个简单的注册登录

    首先 给一张截图 上面图是我的整体内容 1 先进入工具类 代码如下 package com qf util import javax sql DataSource import com mchange v2 c3p0 ComboPooled
  • oracle体验实验,Oracle实验三

    1 实验目的 1 掌握表的创建与管理 2 掌握索引的创建与管理 3 掌握视图的创建与管理 4 掌握序列的创建与应用 2 实验环境 Win10 以及Oracle 11g 3 实验要求 1 为图书销售系统创建表 2 在图书销售系统适当表的适当列
  • 机器学习案例6:基于SVM的数字识别

    案例6 基于SVM的数字识别 为什么写本博客 前人种树 后人乘凉 希望自己的学习笔记可以帮助到需要的人 需要的基础 懂不懂原理不重要 本系列的目标是使用python实现机器学习 必须会的东西 python基础 numpy pandas ma