PyTorch基础入门六:PyTorch搭建卷积神经网络实现MNIST手写数字识别

2023-11-13

1)卷积神经网络(CNN)简介

关于什么是卷积神经网络(CNN),请自行查阅资料进行学习。如果是初学者,这里推荐一下台湾的李宏毅的深度学习课程。链接就不给了,这些资料网站上随处可见。

值得一提的是,CNN虽然在图像处理的领域具有不可阻挡的势头,但是它绝对不仅仅只能用来图像处理领域,大家熟知的alphaGo下围棋也可以通过CNN的结构进行处理,因为下围棋与图像有着相似之处,所以说,CNN提供给我们的是一种处理问题的思想,有学者归纳出了可以用CNN解决的问题所具备的三个性质:

  • 局部性

对于一张图片而言,需要检测图片中的特征来决定图片的类别,通常情况下这些特征都不是由整张图片决定的,而是由一些局部的区域决定的。例如在某张图片中的某个局部检测出了鸟喙,那么基本可以判定图片中有鸟这种动物。

  • 相同性

对于不同的图片,它们具有同样的特征,这些特征会出现在图片的不同位置,也就是说可以用同样的检测模式去检测不同图片的相同特征,只不过这些特征处于图片中不同的位置,但是特征检测所做的操作几乎一样。例如在不同的图片中,虽然鸟喙处于不同的位置,但是我们可以用相同的模式去检测。

  • 不变性

对于一张图片,如果我们进行下采样,那么图片的性质基本保持不变。

 

2)PyTorch中的卷积神经网络

简要介绍一下PyTorch中卷积神经网络中用到的一些方法。

  • 卷积层:nn.Conv2d()

其参数如下:

参数· 含义
in_channels 输入信号的通道数.
out_channels 卷积后输出结果的通道数.
kernel_size 卷积核的形状. 例如kernel_size=(3, 2)表示3X2的卷积核,如果宽和高相同,可以只用一个数字表示
stride 卷积每次移动的步长, 默认为1.
padding  处理边界时填充0的数量, 默认为0(不填充).
dilation 采样间隔数量, 默认为1, 无间隔采样.
groups 输入与输出通道的分组数量. 当不为1时, 默认为1(全连接).
bias 为 True 时, 添加偏置.

 

 

 

 

 

 

 

 

 

 

当然,这么多参数有一些是不常用的,读者只需要在实践中慢慢体会一些常用的即可,其他参数需要将理论打扎实之后去官网查阅。

  • 池化层:nn.MaxPool2d()

其参数如下:

参数 含义
kernel_size 最大池化操作时的窗口大小
stride 最大池化操作时窗口移动的步长, 默认值是 kernel_size
padding  输入的每条边隐式补0的数量
dilation   用于控制窗口中元素的步长的参数
return_indices 如果等于 True, 在返回 max pooling 结果的同时返回最大值的索引 这在之后的 Unpooling 时很有用
ceil_mode 如果等于 True, 在计算输出大小时,将采用向上取整来代替默认的向下取整的方式

 

3)实现MNIST手写数字识别

一共定义了五层,其中两层卷积层,两层池化层,最后一层为FC层进行分类输出。其网络结构如下:

中间一行表示当前数据块的维度,第一个维度为深度,后面两个为宽度和高度。输入数据为灰度图,所以深度为1,图片像素为28*28的图片,后面经过卷积,池化,会发现深度不断加深,而宽度和高度会逐渐减少,因此,最后CNN处理过的图片只是一个局部的图片,换句话说,计算机在进行CNN对图片进行识别的时候,它通过观察图片局部的信息来进行分类的,这一点和我们通过人眼来观察图片进行分类是不一样的。

下面是CNN网络的代码实现:

# !/usr/bin/python
# coding: utf8
# @Time    : 2018-08-05 19:22
# @Author  : Liam
# @Email   : luyu.real@qq.com
# @Software: PyCharm
#                        .::::.
#                      .::::::::.
#                     :::::::::::
#                  ..:::::::::::'
#               '::::::::::::'
#                 .::::::::::
#            '::::::::::::::..
#                 ..::::::::::::.
#               ``::::::::::::::::
#                ::::``:::::::::'        .:::.
#               ::::'   ':::::'       .::::::::.
#             .::::'      ::::     .:::::::'::::.
#            .:::'       :::::  .:::::::::' ':::::.
#           .::'        :::::.:::::::::'      ':::::.
#          .::'         ::::::::::::::'         ``::::.
#      ...:::           ::::::::::::'              ``::.
#     ```` ':.          ':::::::::'                  ::::..
#                        '.:::::'                    ':'````..
#                     美女保佑 永无BUG
from torch import nn

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 25, kernel_size=3),
            nn.BatchNorm2d(25),
            nn.ReLU(inplace=True)
        )

        self.layer2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(25, 50, kernel_size=3),
            nn.BatchNorm2d(50),
            nn.ReLU(inplace=True)
        )

        self.layer4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc = nn.Sequential(
            nn.Linear(50 * 5 * 5, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

然后利用上述模型进行处理,其处理的方法和上一篇博文中的方法是一样的,这里不再赘述。

可以看到处理结果比上一次好多了:

 

完整代码请移步GitHub

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch基础入门六:PyTorch搭建卷积神经网络实现MNIST手写数字识别 的相关文章

随机推荐

  • uniapp分页u-loadmore加载更多uview-ui小程序分页app分页处理

    uniapp分页u loadmore加载更多uview ui 小程序分页app分页处理 1 前端自定义分页条数传page pageSize参数给后端 根据第一页 总条数判断
  • 解决python爬虫urllib请求报错问题:urllib.error.URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certif

    项目场景 python爬虫urllib请求报错 问题描述 import urllib request url https movie douban com headers User Agent Mozilla 5 0 Macintosh I
  • python3中import time是什么意思_python3_time模块详解

    python提供的时间模块time是需要单独引入 1 time sleep secs 推迟调用线程的运行 secs指的是秒 time sleep secs 2 time time 返回当前时间的时间戳 时间戳都是从1970年1月1日午夜经过
  • java 获取home路径_关于JAVA_HOME等引用路径

    初学者往往在配置环境变量的时候会遇到一些小的细节问题 比如为了方便以后运行tomcat eclipse等 我们常常配置一个名为JAVA HOME的变量 如 D Program Files Java jdk1 6 0 10 这样在我们配置pa
  • 西门子SMART 存储区

    输入过程映像区 I区 Process image input register 范围 I0 0 to I31 7 输出 过程映像区 Q区 Process image output register 范围 Q0 0 to Q31 7 模拟量输
  • python小游戏毕设 滑雪小游戏设计与实现 (源码)

    文章目录 0 项目简介 1 游戏介绍 2 实现效果 3 开发工具 3 1 环境配置 3 2 Pygame介绍 4 具体实现 5 最后 0 项目简介 Hi 各位同学好呀 这里是L学长 今天向大家分享一个今年 2022 最新完成的毕业设计项目作
  • 关于多个 ELement UI Popover的处理

    处理 Element UI 中Popover组件会在页面中出现多个的情况 前言 今天有一个需求 一个列表中每一行都会有一个element ui的 popover 弹窗 使用click或者hover触发 但是 这个组件自身不会自动关闭 处理手
  • 多模型构建的多层级权限管控体系

    在阐述 CloudQuery 权限体系之前 想先跟大家分享下我们团队在客户侧收集到了的一些真实场景与诉求 对特定对象进行操作管控 SQL 命令 对某个字段实现精确动态脱敏 对某一条 SQL 语句进行精确提权 对高危命令进行拦截 实现用户登录
  • Flutter 实现文字向上/下滚动效果(八)

    实现原理 Flutter ListView 定时器 Timer 每隔一段时间通过控制器 scrollController 主动跳转 animateTo 下一条目 可以自定义动画 跳转时间 到达底部时从头开始 循环往复 import dart
  • 学习笔记之什么是持久化和对象关系映射ORM技术

    学习笔记之什么是持久化和对象关系映射ORM技术 by Naven at 2005 09 19 何谓 持久化 持久 Persistence 即把数据 如内存中的对象 保存到可永久保存的存储设备中 如磁盘 持久化的主要应用是将内存中的数据存储在
  • 你认为DAO是否可行?新年计划,卯足干劲,兔必No.1

    文章目录 课前小差 聚沙成塔 社会价值 DAO是什么 国产化 商业化回报 写在最后 课前小差 哈喽 大家好 我是几何心凉 这是一份全新的专栏 唯一得倒CSDN王总的授权 来对于我们每周四的绿萝时间 直达CSDN 直播内容进行总结概括 让大家
  • [mysql]游标和触发器

    目录 游标 或光标 定义 使用过程 示例 总结 触发器 应用场景 定义 使用 创建 查看 删除 示例 一个注意点 优缺点 拓展 MySQL 8 0的新特性 全局变量的持久化 游标 或光标 定义 游标是一种 能够对结果集中的每一条记录进行定位
  • Jetson nano之ROS入门 - - 机器人建模与仿真

    文章目录 前言 一 URDF建模 1 URDF语法详解 a robot b link c joint 2 URDF机器人建模实操 二 Xacro宏优化 1 Xacro宏语法详解 2 Xacro建模实操 三 Rviz与Gazebo仿真 1 G
  • 【人体姿态】Convolutional Pose Machines

    Wei Shih En et al Convolutional Pose Machines CVPR 2016 本论文将深度学习应用于人体姿态分析 同时用卷积图层表达纹理信息和空间信息 目前在2016年的MPII竞赛中名列前茅 作者在git
  • 51单片机之串口通讯应用实例(逻辑分析仪调试)

    硬件 STC89C52RC 开发工具 Keil uVision4 前言 8051是一款很经典的 历史悠久的单片机 作为一款入门级的单片机8051受到很多初学者的欢迎 89c52是8051系列的成员之一 拥有8K字节程序存储空间 512字节随
  • 基于Python Django Mysql数据库 的电商系统实现

    基于Python Django的电商系统实现 最近需要基于Django实现一个电商系统 目前已实现了基本功能 整个系统结构相对简单 没有进行前后端分离 使用的django的最简单的Template模板前后端交互模式 这个项目属于入门级项目
  • 环保行业如何开发废品回收微信小程序

    废品回收是近年来受到越来越多人关注的环保行动 为了推动废品回收的普及和方便 我们可以利用微信小程序进行制作 方便人们随时随地参与废品回收 首先 我们需要注册并登录乔拓云账号 并进入后台 乔拓云是一个提供微信小程序制作平台的服务商 非常适合我
  • php user.ini详解

    0x00 前言 本篇主要是讲解分析一下user ini相关的内容 因为这个知识点涉及到文件上传的绕过 0x01 正文 user ini 文件是PHP的配置文件 用于自定义PHP的配置选项 该文件通常位于PHP安装目录的根目录下 或者在特定的
  • 2. 依赖管理和自动配置

    文章目录 2 1 依赖管理 2 1 1 什么是依赖管理 2 1 2 修改自动仲裁 默认版本号 2 2 starter 场景启动器 2 2 1 starter 场景启动器基本介绍 2 2 2 官方提供的 starter 2 2 2 1 地址
  • PyTorch基础入门六:PyTorch搭建卷积神经网络实现MNIST手写数字识别

    1 卷积神经网络 CNN 简介 关于什么是卷积神经网络 CNN 请自行查阅资料进行学习 如果是初学者 这里推荐一下台湾的李宏毅的深度学习课程 链接就不给了 这些资料网站上随处可见 值得一提的是 CNN虽然在图像处理的领域具有不可阻挡的势头