LeNet-5识别数字

2023-11-14

LeNet识别数字

前言

实现经典卷积神经网络LeNet(LeNet-5)识别数字,这里将激活函数从sigmoid换成ReLU,参考资料《动手学深度学习》。

环境

python+pytorch

实现

import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
from torch.utils import data
from torchvision import transforms
import cv2 as cv

#下载数据集
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.MNIST(
    root="./drive/MyDrive/ex/data", transform=trans, train=True, download=True) 
mnist_test = torchvision.datasets.MNIST(
    root="./drive/MyDrive/ex/data", transform=trans, train=False, download=True)
# root后的路径为数据集下载后的保存路径

# 定义网络
class Reshape(nn.Module):
    def forward(self, x):
        return x.view(-1, 1, 28, 28)

net = nn.Sequential(
    Reshape(),
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.ReLU(),
    nn.Linear(120, 84), nn.ReLU(),
    nn.Linear(84, 10))

#读取数据
batch_size = 256
train_iter = data.DataLoader(mnist_train, shuffle=True, batch_size=batch_size, num_workers=4)
test_iter = data.DataLoader(mnist_test, shuffle=True, batch_size=batch_size, num_workers=4)

class Accumulator:
    """在`n`个变量上累加。"""
    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def accuracy(y_hat, y):
    """计算预测正确的数量。"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

def evaluate_accuracy_gpu(net, data_iter, device=None):
    """使用GPU计算模型在数据集上的精度。"""
    if isinstance(net, nn.Module):
        net.eval()  # 设置为评估模式
        if not device:
            device = next(iter(net.parameters())).device
    # 正确预测的数量,总预测的数量
    metric = Accumulator(2)
    for X, y in data_iter:
        if isinstance(X, list):
            X = [x.to(device) for x in X]
        else:
            X = X.to(device)
        y = y.to(device)
        metric.add(accuracy(net(X), y), y.numel())
    return metric[0] / metric[1]

def train(net, train_iter, test_iter, num_epochs, lr, device):
    """用GPU训练"""
    def init_weights(m):
      if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)
    net.apply(init_weights)
    net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        metric = Accumulator(3)
        net.train()
        for i, (X, y) in enumerate(train_iter):
            optimizer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                metric.add(l * X.shape[0], accuracy(y_hat, y), X.shape[0])
            train_l = metric[0] / metric[2]
            train_acc = metric[1] / metric[2]
        test_acc = evaluate_accuracy_gpu(net, test_iter)
    print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
          f'test acc {test_acc:.3f}')

lr, num_epochs = 0.1, 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train(net, train_iter, test_iter, num_epochs, lr, device)

def get_mnist_labels(labels):
    """返回数字标签。"""
    text_labels = ['0', '1', '2', '3', '4',
                   '5', '6', '7', '8', '9']
    return [text_labels[int(i)] for i in labels]
def show_image(img_tensor):
	"""用于数据集中的显示图片"""
      plt.imshow(img.numpy())
      plt.axis('off')

def predict(net, device=None):
    """预测图片数字"""
    if isinstance(net, nn.Module):
        net.eval()  # 设置为评估模式
        if not device:
            device = next(iter(net.parameters())).device
    img = cv.imread('./drive/MyDrive/ex/data/MNIST/test.png') # 放预测图片的位置,图片大小为28*28,黑底白字
    img = cv.cvtColor(img, cv.COLOR_RGB2GRAY)
    transf = transforms.ToTensor()
    img_tensor = transf(img)
    show_image(img_tensor)
    pred = get_mnist_labels(net(img_tensor).argmax(axis=1))
    print(pred)

predict(net)

结果

loss:0.046, train accuracy:0.986, test accuracy:0.978

导入下面的预测图片
导入的预测图片

导入的图片显示效果和识别的数字。
结果图片

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

LeNet-5识别数字 的相关文章

随机推荐

  • C++不定参数,模板函数,模板类详解附实例

    前言 在 C 中 有时我们在写一个函数时并不知道参数的数量和类型 这时需要用到不定参数 模板函数 正文 不定参数 不定参数怎么表示 对于不定参数的表示 就是三个点 注意是英文的点 那么我们在正常使用时函数参数写成这样 funtionType
  • Python 容器序列切片

    视频版教程 Python3零基础7天入门实战视频教程 序列是指内容连续且有序的一类数据容器 前面学的列表 元组 字符串都是序列 并且支持下标索引 切片是指从一个序列中 取出一个子序列 语法 序列 起始下标 结束下标 步长 返回一个新的序列
  • 短文阅读3:Variational Autoencoders (VAEs)

    深度生成网络 VAEs introduction 降维方法 PCA and Autoencoders 降维架构 PCA 问题1 什么是自动编码器autoencoder PCA和Autoencoders之间的关系 Variational Au
  • 【建议收藏】数据库 SQL 入门——数据查询操作(内附演示)

    引言 在上一节中 我们讨论了DML的使用方法 本节我们继续开始DQL的学习 首先回归一下DQL的基于定义 DQL Data Query Language 数据查询语言 用来查询数据库中表的记录 在本节中我们主要讨论DQL的用法以及基本语法
  • 计算机视觉之人脸识别(Yale数据集)--HOG和ResNet两种方法实现

    1 问题描述 在给定Yale数据集上完成以下工作 在给定的人脸库中 通过算法完成人脸识别 算法需要做到能判断出测试的人脸是否属于给定的数据集 如果属于 需要判断出测试的人脸属于数据集中的哪一位 否则 需要声明测试的人脸不属于数据集 这是一个
  • 思维导图 函数

  • PCL点云处理之最小二乘空间直线拟合(3D) (二百零二)

    PCL点云处理之最小二乘空间直线拟合 3D 二百零二 一 算法简介 二 实现代码 三 效果展示 一 算法简介 对于空间中的这样一组点 大致呈直线分布 散乱分布在直线左右 我们可采用最小二乘方法拟合直线 更进一步地 可以通过点到直线的投影 最
  • 5款程序员必备的免费在线画图工具,超级好用!

    点击上方 芋道源码 选择 设为星标 管她前浪 还是后浪 能浪的浪 才是好浪 每天 10 33 更新文章 每天掉亿点点头发 源码精品专栏 原创 Java 2021 超神之路 很肝 中文详细注释的开源项目 RPC 框架 Dubbo 源码解析 网
  • java中的集合基础

    集合介绍 集合类的特点 提供一种存储空间可变的存储模型 存储的数据容量可以发生改变 集合和数组的区别 共同点 都是存储数据的容器 不同点 数组的容量是固定的 集合的容量是可变的 数组可以存基本数据类型和引用数据类型 集合只能存引用数据类型
  • 【Android进阶篇】WebView显示网页详解

    概述 WebView是Android用于显示网页的控件 通过WebView 我们可以查看本地的网页 也可以查看网络资源 本文内容如下 一 加载本地网页 二 加载网络资源 三 在WebView中使用JavaScript和CSS 四 WebCh
  • 多线程案例(1) - 单例模式

    目录 单例模式 饿汉模式 懒汉模式 前言 多线程中有许多非常经典的设计模式 这就类似于围棋的棋谱 这是用来解决我们在开发中遇到很多 经典场景 简单来说 设计模式就是一份模板 可以套用 单例模式 顾名思义 就是一个程序只能含有一个实例 有的场
  • Permission denied

    Permission denied 出现的原因的是 没有权限进行读 写 创建文件 删除文件等操作 解决方法 输入命令 sudo chmod R 777 工作目录 例如 sudo chmode R 777 home HDD 此时就可以在该路径
  • poium测试库介绍

    poium测试库前身为selenium page objects测试库 我在以前的文章中也有介绍过 这可能是最简单的Page Object库 项目的核心是基于Page Objects实现元素定位的封装 该项目由我个人在维护 目前在公司项目中
  • 使用ChatGPT的方式与在其他地方使用它的方式基本相同。以下是一些步骤:

    在中国使用ChatGPT的方式与在其他地方使用它的方式基本相同 以下是一些步骤 访问OpenAI的官方网站 OpenAI 在网站上找到GPT 3或ChatGPT的相关信息 OpenAI提供了详细的API文档 可以帮助你理解如何使用它们 你需
  • mysql数据库之跨表复制

    背景说明 目标库 target db 目标数据表 target tb 将目标库的制定表复制到当前数据库中 包括一下几个方面 一 表结构复制 仅仅复制了表的结构 没有数据 create table current db new tb like
  • Logitech G系鼠标脚本编程,实现鼠标自动定位控制

    利用罗技官方提供的API来写一个鼠标自动定位移动脚本 点击脚本编辑器中的帮助选项 查看罗技官方提供的API说明 有很多实现好的鼠标功能 G series Lua API V8 45 Overview and Reference 下面是我写的
  • 深入解析SpringBoot启动原理

    1 启动类中的SpringApplication run方法会创建一个SpringApplication的实例 并做一些初始化工作 SpringBootApplication Slf4j public class HuotuUserServ
  • Linux C编程基础:获取时间

    1 前言 对于linux下的编程 无论是用户态还是内核态 时间获取都是经常需要使用到的 以下分别从用户态和内核态整理了几个常用的时间获取接口 供编写代码时快速查阅 linux时间子系统的发展历史及详细介绍 可以参考 深入理解Linux时间子
  • stm32 机械周期_STM32定时器周期计算

    STM32定时器周期计算 公式是 1 TIM Prescaler 时钟 1 TIM Period F103配置生成1ms的时钟 1 35 36M 1 999 1MS TIM TimeBaseInitTypeDef TIM TimeBaseS
  • LeNet-5识别数字

    LeNet识别数字 前言 环境 实现 结果 前言 实现经典卷积神经网络LeNet LeNet 5 识别数字 这里将激活函数从sigmoid换成ReLU 参考资料 动手学深度学习 环境 python pytorch 实现 import tor