使用LeNet-5识别手写数字MNIST

2023-10-27

LeNet5

LeNet-5卷积神经网络模型
LeNet-5:是Yann LeCun在1998年设计的用于手写数字识别的卷积神经网络,当年美国大多数银行就是用它来识别支票上面的手写数字的,它是早期卷积神经网络中最有代表性的实验系统之一。

LenNet-5共有7层(不包括输入层),每层都包含不同数量的训练参数,如下图所示。
在这里插入图片描述
LeNet-5中主要有2个卷积层、2个下抽样层(池化层)、3个全连接层3种连接方式

使用LeNet5识别MNIST

初试版本:

import torch
import torchvision

import torch.nn as nn
from matplotlib import pyplot as plt

from torch.utils.data import DataLoader

# 先定义一个绘图工具
def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)),data,color = 'blue')
    plt.legend(['value'],loc = 'upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()

device=torch.device('cuda' if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1=nn.Sequential(
            nn.Conv2d(1,6,5,1,2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2,stride=2)
        )
        self.conv2=nn.Sequential(
            nn.Conv2d(6,16,5),
            nn.ReLU(),
            nn.MaxPool2d(2,2)
        )
        self.fc1=nn.Sequential(
            nn.Linear(16*5*5,120),
            nn.ReLU()
        )
        self.fc2=nn.Sequential(
            nn.Linear(120,84),
            nn.ReLU()
        )
        self.fc3=nn.Linear(84,10)

        # self.model=nn.Sequential(
        #     nn.Conv2d(1,6,5,1,2),
        #     nn.ReLU(),
        #     nn.MaxPool2d(2,2),
        #     nn.Conv2d(6,16,5),
        #     nn.ReLU(),
        #     nn.MaxPool2d(2,2),
        #     nn.Flatten(),
        #     nn.Linear(16*5*5,120),
        #     nn.ReLU(),
        #     nn.Linear(120,84),
        #     nn.ReLU(),
        #     nn.Linear(84,10)
        # )

    def forward(self, x):
        x=self.conv1(x)
        x=self.conv2(x)
        # nn.Linear()的输入输出都是维度为1的值,所以要把多维度的tensor展平或一维
        x=x.view(x.size()[0], -1)
        x=self.fc1(x)
        x=self.fc2(x)
        x=self.fc3(x)
        # x=self.model(x)
        return x

epoch=8
batch_size=64
lr=0.001

traindata=torchvision.datasets.MNIST(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),download=True)
testdata=torchvision.datasets.MNIST(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),download=True)

trainloader=DataLoader(traindata,batch_size=batch_size,shuffle=True)
testloader=DataLoader(testdata,batch_size=batch_size,shuffle=False)

net=LeNet().to(device)

loss_fn=nn.CrossEntropyLoss().to(device)

optimizer=torch.optim.SGD(net.parameters(),lr=lr,momentum=0.9)

train_loss=[]
accuracy=[]
train_step=0
for epoch in range(epoch):
    sum_loss=0
    for data in trainloader:
        inputs,labels=data
        inputs,labels=inputs.to(device),labels.to(device)

        optimizer.zero_grad()
        outputs=net(inputs)
        loss=loss_fn(outputs,labels)
        loss.backward()
        optimizer.step()
        train_step+=1
        sum_loss+=loss.item()
        if train_step % 100==99:
            print("[epoch:{},轮次:{},sum_loss:{}".format(epoch+1,train_step,sum_loss/100))
            train_loss.append(sum_loss/100)
            sum_loss=0

    with torch.no_grad():
        correct=0
        total=0
        for data in testloader:
            images, labels=data
            images,labels=images.to(device),labels.to(device)
            outputs=net(images)
            _,predicted=torch.max(outputs.data,1)
            total+=labels.size(0)
            correct+=(predicted==labels).sum()
        accuracy.append(correct)
        print("第{}个epoch的识别准确率为:{}".format(epoch+1,correct/total))

plot_curve(train_loss)
plot_curve(accuracy)

运行结果:识别准确率还是不错的
在这里插入图片描述

每一步的训练损失值变化:
在这里插入图片描述
每轮测试集的识别准确率:

在这里插入图片描述

代码优化一下:

import torch
import torchvision

import torch.nn as nn
from matplotlib import pyplot as plt

from torch.utils.data import DataLoader

# 先定义一个绘图工具
def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)),data,color = 'blue')
    plt.legend(['value'],loc = 'upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()

device=torch.device('cuda' if torch.cuda.is_available() else "cpu")

# 定义LeNet网络
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.model=nn.Sequential(
            # MNIST数据集大小为28x28,要先做padding=2的填充才满足32x32的输入大小
            nn.Conv2d(1,6,5,1,2),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Conv2d(6,16,5),
            nn.ReLU(),
            nn.MaxPool2d(2,2),
            nn.Flatten(),
            nn.Linear(16*5*5,120),
            nn.ReLU(),
            nn.Linear(120,84),
            nn.ReLU(),
            nn.Linear(84,10)
        )

    def forward(self, x):
        x=self.model(x)
        return x

epoch=8
batch_size=64
lr=0.001

# 导入数据集
traindata=torchvision.datasets.MNIST(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),download=True)
testdata=torchvision.datasets.MNIST(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),download=True)

test_size=len(testdata)

# 加载数据集
trainloader=DataLoader(traindata,batch_size=batch_size,shuffle=True)
testloader=DataLoader(testdata,batch_size=batch_size,shuffle=False)

net=LeNet().to(device)

loss_fn=nn.CrossEntropyLoss().to(device)

optimizer=torch.optim.SGD(net.parameters(),lr=lr,momentum=0.9)

train_loss=[]
precision=[]
train_step=0
for epoch in range(epoch):
    net.train()
    sum_loss=0
    for data in trainloader:
        inputs,labels=data
        inputs,labels=inputs.to(device),labels.to(device)

        # 更新梯度
        optimizer.zero_grad()
        outputs=net(inputs)
        loss=loss_fn(outputs,labels)
        loss.backward()
        optimizer.step()

        train_step+=1
        sum_loss+=loss.item()
        if train_step % 100==99:
            print("[epoch:{},轮次:{},sum_loss:{}]".format(epoch+1,train_step,sum_loss/100))
            train_loss.append(sum_loss/100)
            sum_loss=0

    net.eval()
    with torch.no_grad():
        correct=0
        # total=0
        accuracy=0
        for data in testloader:
            images, labels=data
            images,labels=images.to(device),labels.to(device)
            outputs=net(images)
            # _,predicted=torch.max(outputs.data,1)
            # total+=labels.size(0)
            # correct+=(predicted==labels).sum()
            correct+=(outputs.argmax(1)==labels).sum()
        accuracy=correct/test_size
        print("第{}个epoch的识别准确率为:{}".format(epoch+1,accuracy))
        precision.append(accuracy.cpu())

plot_curve(train_loss)
plot_curve(precision)

运行结果:
在这里插入图片描述

每一步的训练loss变化
在这里插入图片描述
测试集每一轮的准确率
在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用LeNet-5识别手写数字MNIST 的相关文章

  • 1.3 手写数字识别任务

    文章目录 横纵式 教学法 一 加载类库 二 数据处理 飞桨API的使用方法 三 模型设计 四 训练配置 五 训练过程 六 模型测试 横纵式 教学法 在本教程中 我们采用了专门为读者设计的创新性的 横纵式 教学法进行深度学习建模介绍 如 图4
  • collection和collections区别

    区别 Collection它是一个集合的接口 它提供了对集合对象进行基本操作的通用接口方法 Collection接口在java类库当中有很多具体的实现 Collection接口的意义就是为各种具体的集合提供最大化的统一操作方式 主要实现的C
  • 软能力那点事,你知多少

    目录 一 软能力是什么 二 软能力 程序猿生存指南 A 沟通能力 B 学习能力 C 时间管理 D 分解能力 E 总结改进 三 程序猿成长路线 1 架构师 2 项目经理 3 产品经理 四 小结 在我们日常工作中 常常会听到软能力这一个词汇 尤
  • “wget: 无法解析主机地址”的解决方法

    问题 root iZ2zefny2a19ms6azli2pwZ wget https download redis io releases redis 5 0 10 tar gz 2020 11 01 14 30 12 https down
  • Jmeter 集合点

    概念 对于性能测试可以理解为多用户并发 但是真正的并发是不存在的 为了更真实的实现并发的概念 我们可以在需要的地方设置集合点 所有虚拟用户都互相之间等一等 然后一起访问 Jmeter集合点是通过添加定时器 Synchronizing tim

随机推荐

  • 针对Failed to execute goal org.apache.maven.pluginsmaven-compiler-plugin3.1的解决方案

    背景 本项目使用JDK1 8 编译maven工程的时候出现如下错误 Failed to execute goal org apache maven plugins maven compiler plugin 3 1 pom中如下配置mave
  • 不同项目中,S7-300 DP 和 S7-1200 PROFINET 的profibus通信(300做主站,1200做从站)

    使用 S7 1200 与 S7 300 的集成 DP 接口进行主从通信 这里是将 S7 300 做为主站 将 S7 1200 做为从站 即 S7 300 集成的 DP 接口做主站 S7 1200 通过 CM1242 5 做从站 两个PLC在
  • Mysql-Galera Cluster

    使用Galera Cluster需要下载包含wsrep补丁的mysql版本 官网下载地址 http galeracluster com downloads 安装前要卸载之前安装的mariadb或者mysql 或者迁移也可以 不过就是另一套操
  • elementui的el-table的插槽功能,添加判断值,在单元格添加输入框,点击事件等等

    elementui的el table的插槽功能 添加判断值 在单元格添加输入框 点击事件等等
  • 三、 HBuilderX运行到手机上看效果

    以下均已录制 点击查看B站视频 1 运行 gt 运行到手机或模拟器 出现如下问题 未检测到手机或模拟器 请稍后重试 2 这时需要手机上做一些设置 设置 gt 关于手机 gt 连续多次点击版本号 就能打开开发者模式 设置 gt 系统和更新 g
  • 优化pxe网启动时tftp的传输速度 --- 针对pxelinux和bootmgr

    作为一名IT人士 一般的计算机维护当然不好意思找别人 于是自己用pxelinux搭了个网络启动环境 可以启动各种WinPE 以供折腾电脑系统 刷新固件的需要 只是一般的网络启动都是基于tftp协议的 传输文件那叫一个慢 启动时光是加载映像文
  • 交易中间件消息中间件_什么是中间件

    交易中间件消息中间件 什么是中间件 What Is Middleware In network architecture a middleware is a layer of software that creates a network
  • 答题小程序常用脚本整理

    答题小程序常用脚本整理 本文主要描述答题活动小程序运营过程中 高频使用的 几个脚本操作 1 如何清理当前题库 在开发控制台的高级操作右侧有个加号按钮 点击下 选择空白模板即可 将下面的脚本复制进去 db collection questio
  • 聚簇索引和二级索引

    原文链接 https blog csdn net jijianshuai article details 79084874
  • SpringCloud项目如何成功打包以及其中的一些坑

    我的项目结构 其中edu online和edu admin是前端项目 其他是后端模块 首先需要在父工程中添加需要打包的模块和打包依赖 如果在父工程中配置过打包依赖则子模块中不需要配置打包依赖 但是如果有子模块需要被其他模块依赖 则需要在被依
  • ROS:解决Error:cannot launch node of type [map_server/map_server]: can't locate node [map_server] in......

    写在前面 本文为原创 如需转载请注明出处 https www jianshu com p e9981bc35cff 欢迎大家留言共同探讨 有误的地方也希望指出 另如果有好的SLAM ROS等相关交流群也希望可以留言给我 在此先谢过了 1 E
  • Gof23设计模式之建造者模式

    1 概述 建造者模式 Builder Pattern 又叫生成器模式 是一种对象构建模式 它可以将复杂对象的建造过程抽象出来 抽象类别 使这个抽象过程的不同实现方法可以构造出不同表现 属性 的对象 建造者模式是一步一步创建一个复杂的对象 它
  • 要求用成员函数实现以下功能由键盘输入,计算长方体的体积,输出3个长方体的体积。

    题目 需要求三个长方体的体积 请编写一个基于对象的程序 数据成员包括length 长 width 宽 height 高 要求用成员函数实现以下功能 1 由键盘输入3个长方体的长 宽 高 2 计算长方体的体积 3 输出3个长方体的体积 请编程
  • linux 关于修改命令提示符

    1 首先 进入root 用户获得权限 输入 su root 2 进入修改提示符的文件 输入 vim etc profile 3 进入文件 不要修改任何地方 在最后加入命令 1 输入 export PS1 e 1 32 40m 孔子曰 e 1
  • Flink将本地数据写入Redis

    第一步 配置文件redis conf cd usr apps redis vim redis conf 先输入 set nu 打开行号标识 69行 bind 127 0 0 1加上注释 取消IP绑定 否则其他主机不能连接 88行 prote
  • sqli-labs第十八十九关

    这两关为头注入 Less 18 POST Header Injection Uagent field Error based 手工注入 这关和下一关必须要抓包才能完成 因为在这里怎么是都没有反应 全是报错的状态 那么我估计就要抓包了 根本判
  • pythonqt对比_用 Python 和 C++ 创建 Qt 程序的简单对比

    假设要做一个简单的小窗口 如下图所示 PyQt 和 C 要用多少代码可以完成呢 效果图 注 本文内容较多 主要是 C 的部分 若有必要请直接跳到最后看结论 一 C 版本 除了最基础的 pro 文件之外 我一共创建了 5 个文件 custom
  • 电脑固定ip地址之后重启却失效了的解决办法

    开始 运行 cmd 回车 英文状态下输入 netsh winsock reset 回车后会提示重启 先不重启 继续输入 netsh int ip reset reset log 回车后会提示重启 此时先重启电脑 重启之后再次设置好固定ip地
  • SQL 映射文件

    SQL 映射文件 SQL 映射文件只有很少的几个顶级元素 按照应被定义的顺序列出 cache 对给定命名空间的缓存配置 cache ref 对其他命名空间缓存配置的引用 resultMap 是最复杂也是最强大的元素 用来描述如何从数据库结果
  • 使用LeNet-5识别手写数字MNIST

    LeNet5 LeNet 5卷积神经网络模型 LeNet 5 是Yann LeCun在1998年设计的用于手写数字识别的卷积神经网络 当年美国大多数银行就是用它来识别支票上面的手写数字的 它是早期卷积神经网络中最有代表性的实验系统之一 Le