基于卷积神经网络(cnn)的手写数字识别(PyTorch)

2023-11-11

目录

1.1 卷积神经网络简介

1.2 神经网络

1.2.1 神经元模型

 1.2.2 神经网络模型

1.3 卷积神经网络

1.3.1卷积的概念

1.3.2 卷积的计算过程

1.3.3 感受野

1.3.4 步长

1.3.5 输出特征尺寸计算

 1.3.6 全零填充

1.3.7 标准化

1.3.7 池化层

 1.4 卷积神经网络的全过程

 1.5 PyTorch的卷积神经网络(cnn)手写数字识别

1.5.1 代码


1.1 卷积神经网络简介

卷积神经网络(Convolutional Neural Networks,简称:CNN)是深度学习当中一个非常重要的神经网络结构。它主要用于用在图像图片处理,视频处理,音频处理以及自然语言处理等等。
早在上世纪80年代左右,卷积神经网络的概念就已经被提出来了。但其真正的崛起却是在21世纪之后,21世纪之后,随着深度学习理论的不断完善,同时,由计算机硬件性能的提升,计算机算力的不断发展,给卷积神经网络这种算法提供了应用的空间。著名的AlphaGo,手机上的人脸识别,大多数都是采用卷积神经网络。因此可以说,卷积神经网络在如今的深度学习领域,有着举足轻重的作用。

在了解卷积神经网络之前,我们务必要知道:什么是神经网络(Neural Networks),关于这个,我们已经在深度学习简介的 第二部分有所介绍。这里就不赘述了。在了解了神经网络的基础上,我们再来探究:卷积神经网络又是什么呢?当中的“卷积”这个词,又意味着什么呢?
 

1.2 神经网络

1.2.1 神经元模型


人工神神经网络(neural networks)方面的研究很早就已出现,今天“神经网络”    已是一个相当大的、多学科交叉的学科领域.各相关学科对神经网络的定义多种多样。简单单元组成的广泛并行互连的网络,它的组织能够模拟生物神经系统对真实世界物体所作出的交互反应” 。
神经网络中最基本的成分是神经元(neuron)模型,即上述定义中的“简单单元”,在生物神经网络中,每个神经元与其他神经元相连,当它“兴奋”时,就会向相连的神经元发送化学物质,从而改变这些神经元内的电位;如果某神经元的电位超过了一个“阈值”(threshold),那么它就会被激活,即“兴奋”起来,向其他神经元发送化学物质。在这个模型中,神经元接收到来自n个其他神经元传递过来的输入信号,这些输入信号通过带权重的连接(connection)进行传递,神经元接收到的总输入值将与神经元的间值进行比较,然后通过激活函数处理,产生神经元输出。

 

 1.2.2 神经网络模型

 神经网络是一种运算模型,由大量的节点(或称神经元)之间相互联接构成。每个节点代表一种特定的输出函数,称为激励函数(activation function)。每两个节点间的连接都代表一个对于通过该连接信号的加权值,称之为权重,这相当于人工神经网络的记忆。网络的输出则依网络的连接方式,权重值和激励函数的不同而不同。而网络自身通常都是对自然界某种算法或者函数的逼近,也可能是对一种逻辑策略的表达。
 

1.3 卷积神经网络

1.3.1卷积的概念

卷积神经网络与普通神经网络的区别在于,卷积神经网络包含了一个由卷积层子采样层(池化层)构成的特征抽取器。在卷积神经网络的卷积层中,一个神经元只与部分邻层神经元连接。在CNN的一个卷积层中,通常包含若干个特征图(featureMap),每个特征图由一些矩形排列的的神经元组成,同一特征图的神经元共享权值,这里共享的权值就是卷积核。卷积核一般以随机小数矩阵的形式初始化,在网络的训练过程中卷积核将学习得到合理的权值。共享权值(卷积核)带来的直接好处是减少网络各层之间的连接,同时又降低了过拟合的风险。子采样也叫做池化(pooling),通常有均值子采样(mean pooling)最大值子采样(max pooling)两种形式。子采样可以看作一种特殊的卷积过程。卷积和子采样大大简化了模型复杂度,减少了模型的参数。

1.3.2 卷积的计算过程

假设我们输入的是5*5*1的图像,中间的那个3*3*1是我们定义的一个卷积核(简单来说可以看做一个矩阵形式运算器),通过原始输入图像和卷积核做运算可以得到绿色部分的结果,怎么样的运算呢?实际很简单就是我们看左图中深色部分,处于中间的数字是图像的像素,处于右下角的数字是我们卷积核的数字,只要对应相乘再相加就可以得到结果。例如图中‘3*0+1*1+2*2+2*2+0*2+0*0+2*0+0*1+0*2=9’

计算过程如下动图:

图中最左边的三个输入矩阵就是我们的相当于输入d=3时有三个通道图,每个通道图都有一个属于自己通道的卷积核,我们可以看到输出(output)的只有两个特征图意味着我们设置的输出d=2,有几个输出通道就有几层卷积核(比如图中就有FilterW0和FilterW1),这意味着我们的卷积核数量就是输入d的个数乘以输出d的个数(图中就是2*3=6个),其中每一层通道图的计算与上文中提到的一层计算相同,再把每一个通道输出的输出再加起来就是绿色的输出数字。

1.3.3 感受野

感受野(Receptive Field):卷积神经网络各输出层每个像素点在原始图像上的映射区域大小。
下图为感受野示意图:

 当我们采用尺寸不同的卷积核时,最大的区别就是感受野的大小不同,所以经常会采用多层小卷积核来替换一层大卷积核,在保持感受野相同的情况下减少参数量和计算量。
例如十分常见的用2层3 * 3卷积核来替换1层5 * 5卷积核的方法,如下图所示。

1.3.4 步长

每次卷积核移动的大小。

1.3.5 输出特征尺寸计算

输出特征尺寸计算:在了解神经网络中卷积计算的整个过程后,就可以对输出特征图的尺寸进行计算。如下图所示,5×5的图像经过3×3大小的卷积核做卷积计算后输出特征尺寸为3×3

 1.3.6 全零填充

当卷积核尺寸大于 1 时,输出特征图的尺寸会小于输入图片尺寸。如果经过多次卷积,输出图片尺寸会不断减小。为了避免卷积之后图片尺寸变小,通常会在图片的外围进行填充(padding),如下图所示

全零填充(padding):为了保持输出图像尺寸与输入图像一致,经常会在输入图像周围进行全零填充,如下所示,在5×5的输入图像周围填0,则输出特征尺寸同为5×5。

当padding=1和paadding=2时,如下图所示:

1.3.7 标准化

使数据符合0均值,1为标准差的分布。
批标准化(Batch Normalization):对一小批数据(batch),做标准化处理。

 Batch Normalization将神经网络每层的输入都调整到均值为0,方差为1的标准正态分布,其目的是解决神经网络中梯度消失的问题.

BN操作的另一个重要步骤是缩放和偏移,值得注意的是,缩放因子γ以及偏移因子β都是可训练参数。 

1.3.7 池化层

池化(Pooling)用于减少特征数据量。
最大值池化可提取图片纹理,均值池化可保留背景特征

 1.4 卷积神经网络的全过程

 1.5 PyTorch的卷积神经网络(cnn)手写数字识别

使用的框架为pytorch。

数据集:MNIST数据集,60000张训练图像,每张图像size为28*28。

可在http://yann.lecun.com/exdb/mnist/中获取

1.5.1 代码

import torch
import torch.nn as nn
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import torch.utils.data as data_utils

#获取数据集
train_data=dataset.MNIST(root="D",
                         train=True,
                         transform=transforms.ToTensor(),
                         download=True
                         )
test_data=dataset.MNIST(root="D",
                         train=False,
                         transform=transforms.ToTensor(),
                         download=False
                         )
train_loader=data_utils.DataLoader(dataset=train_data, batch_size=100, shuffle=True)
test_loader=data_utils.DataLoader(dataset=test_data, batch_size=100, shuffle=True)

#创建网络
class Net(torch.nn.Module):
   def __init__(self):
        super().__init__()
        self.conv=nn.Conv2d(1, 32, kernel_size=5, padding=2)
        self.bat2d=nn.BatchNorm2d(32)
        self.relu=nn.ReLU()
        self.pool=nn.MaxPool2d(2)
        self.linear=nn.Linear(14 * 14 * 32, 70)
        self.tanh=nn.Tanh()
        self.linear1=nn.Linear(70,30)
        self.linear2=nn.Linear(30, 10)
   def forward(self,x):
        y=self.conv(x)
        y=self.bat2d(y)
        y=self.relu(y)
        y=self.pool(y)
        y=y.view(y.size()[0],-1)
        y=self.linear(y)
        y=self.tanh(y)
        y=self.linear1(y)
        y=self.tanh(y)
        y=self.linear2(y)
        return y
cnn=Net()
cnn=cnn.cuda()

#损失函数
los=torch.nn.CrossEntropyLoss()

#优化函数
optime=torch.optim.Adam(cnn.parameters(), lr=0.01)

#训练模型
for epo in range(10):
   for i, (images,lab) in enumerate(train_loader):
        images=images.cuda()
        lab=lab.cuda()
        out = cnn(images)
        loss=los(out,lab)
        optime.zero_grad()
        loss.backward()
        optime.step()
        print("epo:{},i:{},loss:{}".format(epo+1,i,loss))

#测试模型
loss_test=0
accuracy=0
with torch.no_grad():
   for j, (images_test,lab_test) in enumerate(test_loader):
        images_test = images_test.cuda()
        lab_test=lab_test.cuda()
        out1 = cnn(images_test)
        loss_test+=los(out1,lab_test)
        loss_test=loss_test/(len(test_data)//100)
        _,p=out1.max(1)
        accuracy += (p==lab_test).sum().item()
        accuracy=accuracy/len(test_data)
        print("loss_test:{},accuracy:{}".format(loss_test,accuracy))

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于卷积神经网络(cnn)的手写数字识别(PyTorch) 的相关文章

随机推荐

  • numpy中设置始终使用定点表示法显示小数

    默认numpy会在某些情况触发科学计数法显示 scientific notation is used when absolute value of the smallest number is lt 1e 4 or the ratio of
  • 前置++和后置++的区别

    今天在阅读 google c 编程风格 的文档的时候 5 10 前置自增和自减 有一句话引起了我的注意 对于迭代器和其他模板对象使用前缀形式 i 的自增 自减运算符 理由是 前置自增 i 通常要比后置自增 i 效率更高 于是我查了查前置 和
  • 【C++学习第五讲】第一章总结 + 复习题(十一道)

    目录 第一章总结 一 总结 二 复习题 1 C 程序的模块叫什么 2 下面的预处理器编译指令的功能是什么 3 下面的语句的功能是什么 4 什么语句可以用来输出 hello world 然后开始新的一行 5 什么语句可以用来创建名为chees
  • TBDR下msaa 在metal vulkan和ogles的解决方案

    https developer arm com solutions graphics developer guides understanding render passes multi sample anti aliasing msaa在
  • 第十四届蓝桥杯大赛软件赛省赛(Java 大学B组)

    目录 试题 A 阶乘求和 1 题目描述 2 解题思路 3 模板代码 试题 B 幸运数字 1 题目描述 2 解题思路 3 模板代码 试题 C 数组分割 1 题目描述 2 解题思路 3 模板代码 试题 D 矩形总面积 1 问题描述 2 解题思路
  • (译)cocos2d-x跨android&ios平台开发入门教程

    免责申明 必读 本博客提供的所有教程的翻译原稿均来自于互联网 仅供学习交流之用 切勿进行商业传播 同时 转载时不要移除本申明 如产生任何纠纷 均与本博客所有人 发表该翻译稿之人无任何关系 谢谢合作 原文链接地址 http www raywe
  • Windows下小白安装Qt详细教程

    一 软件下载 官网下载地址 http download qt io 1 点击进入 2 archive 和 official releases 两个目录都有最新的 Qt 开发环境安装包 我们以 archive 目录里的内容为例来说明 点击进入
  • 太强大!发现一个数据分析老司机专用神器!

    去年秋招 字节跳动整体报录比降到了2 创造了150000人争3000岗位的盛况 今年909万毕业生再创新高 激烈程度可想而知 除了技术岗 大部分毕业生也瞄准了高薪高前景的数据分析师岗位 教育部关于高校毕业生就业工作通知 人才缺口 巨大 未来
  • 用python进行数据分析(一:数据理解)

    python作为当前主流的语言之一 他的功能是非常强大的 不论是在游戏行业还是数据分析行业还是软件开发啥的好像都可以用python 但作为一个数据分析师 并不需要用到他的全部功能 只是想要达到 能够用python完成数据分析工作 的效果来帮
  • 同步FIFO的verilog实现(2)——高位扩展法

    一 前言 在之前的文章中 我们介绍了同步FIFO的verilog的一种实现方法 计数法 其核心在于 在同步FIFO中 我们可以很容易的使用计数来判断FIFO中还剩下多少可读的数据 从而可以判断空 满 关于计数法实现同步FIFO的详细内容 请
  • logback 配置文件 XML 案例

    logback配置文件案例 1 实现功能 1 控制台输出日志 2 info warn error 三个级别的日志分文件输出 3 日志文件 按天 按文件大小 滚动保存 4 日志文件保存于 项目根目录下的 logs 目录下 2 具体配置
  • STM 32如何实现程序自加密

    在嵌入式应用开发中 应用开发完成后往往需要对芯片中的程序进行加密处理 用以保护程序安全 不至被竞争对手从芯片把程序固件考走 本节将给大学介绍一个如何实现程序自动给芯片加密功能 下面给大家介绍一个STM32 用程序给MCU加密码的方法 标准库
  • 解决stata安装外部命令报错cannot write in directory C:\Users\�ƿ���\ado\plus\_

    参考网址 https bbs pinggu org thread 10685955 1 1 html ado文件下没有plus文件夹 在do文件或命令行中输入以下三个命令 sysdir set PLUS D stata17 MP ado p
  • Search and Replace -- 搜索与替换的高级利器

    对于从事电脑无纸化办公拟文写作的朋友 随着文档的增多 要查找一个遗忘的文件犹如大海捞针 虽然Windows系统已有很强的搜索功能 但依然不能满足我们的要求 如Windows不能搜索WPS格式的文件 不能搜索数据库 而在第三方软件的帮助下便可
  • 【YOLO系列】YOLOv6论文超详细解读(翻译 +学习笔记)

    前言 YOLOv6 是美团视觉智能部研发的一款目标检测框架 致力于工业应用 论文题目是 YOLOv6 A Single Stage Object Detection Framework for Industrial Applications
  • Xilinx FPGA 学习笔记——时钟资源

    在Xilinx的FPGA中 时钟网络资源分为两大类 全局时钟资源和区域时钟资源 全局时钟资源是一种专用互连网络 它可以降低时钟歪斜 占空比失真和功耗 提高抖动容限 Xilinx的全局时钟资源设计了专用时钟缓冲与驱动结构 从而使全局时钟到达C
  • vue cli4.5.13项目兼容IE问题记录

    用脚手架安装项目后 IE遇到问题如下 一 报错1002 点击错误发现在socketjs client 办法 降低socketjs client 版本 npm install sockjs client 1 5 1 二 安装后 仍然白屏 按照
  • 数据搜索之二分查询

    数据搜索中 如果给定数据集是乱序的情况下一般我们使用顺序搜索按位查询是最常用的方法 但是一旦数据是顺序的 二分法则能大大减少数据搜索的工作量 尤其在几十万甚至上亿的数据量情况下 它的效率就能大大的体现 二分法的思想是通过每次把数据集所在小区
  • html怎样去除超链接的样式_前端从入门到精通

    HTML 标题 标题 Heading 是通过 h1 h6 等标签进行定义的 h1 定义最大的标题 h6 定义最小的标题 注释 浏览器会自动地在标题的前后添加空行 注释 默认情况下 HTML 会自动地在块级元素前后添加一个额外的空行 比如段落
  • 基于卷积神经网络(cnn)的手写数字识别(PyTorch)

    目录 1 1 卷积神经网络简介 1 2 神经网络 1 2 1 神经元模型 1 2 2 神经网络模型 1 3 卷积神经网络 1 3 1卷积的概念 1 3 2 卷积的计算过程 1 3 3 感受野 1 3 4 步长 1 3 5 输出特征尺寸计算