手写数字识别代码详解

2023-11-16

文件目录如下,其中数据集data目录运行时在与手写数字识别同级目录自动生成,具体文件内代码见下文

一、conf.py文件

"""
项目配置
"""
import torch

train_batch_size = 128   # 训练批次大小,表示每次训练神经网络时每次使用的图像张量和标签张量的数量为128
test_batch_size = 1000   # 测试批次大小,即测试时为1000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")   # 判断当前系统是否支持cuda来加速运算,后续用于设置模型和数据在何处运行

二、dataset.py文件

"""
准备数据集
"""
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision
import conf

def mnist_dataset(train): #准备minist的dataset数据集,输入参数train,表示是否需要加载训练数据集,True则加载,否则加载测试数据集
    func = torchvision.transforms.Compose([   # 将多个数据处理函数组合在一起,以便对数据进行连续的转换操作,以下用来两个转换操作
        torchvision.transforms.ToTensor(),   # 将PIL image或numpy.ndarray的数据类型转为Tensor,并进行归一化处理(0~1之间)
        torchvision.transforms.Normalize(   # 对每个通道进行标准化,即先减均值再除以标准差
            mean=(0.1307,),
            std=(0.3081,)
        )]
    )

    # 1. 准备Mnist数据集
    return  MNIST(root="../data/mnist", train=train, download=True, transform=func)   # root为存储MNIST数据集的目录,train表示是否加载训练数据集,download表示是否从远程下载MNIST数据,transform表示对数据集进行的数据处理操作,最后返回处理后的MNIST数据集

def get_dataloader(train=True):   # 获取数据加载器,train=True表示是否是训练集
    mnist = mnist_dataset(train)   # 调用函数返回MNIST数据集,并将其赋值给mnist变量
    batch_size = conf.train_batch_size if train else conf.test_batch_size  # 根据是否为训练集确定批次大小
    return DataLoader(mnist,batch_size=batch_size,shuffle=True)  # 返回数据加载器对象,DataLoader是pytorch中用于数据加载的工具类,用于从给定的数据集中读取数据并返回迭代器
    # 此处将 MNIST 数据集传递给 DataLoader 对象,并指定批次大小和是否打乱数据顺序。这样,就可以在训练或测试神经网络模型时使用该迭代器来获取数据

if __name__ == '__main__':
    for (images,labels) in get_dataloader():   # 遍历数据加载器中的所有图像和标签数据,将每次循环得到的图像数据赋值给images变量,标签数据赋值给labels变量
        print(images.size())  # 打印图像数据的大小,即图形张量的形状,具体为形如(batch_size,channel,height,width)的4维张量,依次为批次大小、图像通道数、图像高度和宽度
        print(labels.size())   # 打印标签数据的大小,即标签张量的形状,具体为形如(batch_size,)的1维张量,表示批次大小
        break  # 查看第一个批次数据后停止循环,即输出第一个批次的图像和标签数据的大小

三、models.py文件

"""定义模型"""

import torch.nn as nn
import torch.nn.functional as F

class MnistModel(nn.Module):   # 定义神经网络模型,继承pytorch的nn.Module类
    def __init__(self):  # 初始化函数,用于定义网络结构和模型参数
        super(MnistModel,self).__init__()   # 继承nn.Module类的属性和方法
        self.fc1 = nn.Linear(1*28*28,100)   # 定义全连接层,包含100个神经元,即输出特征数量为100
        self.fc2 = nn.Linear(100,10)  # 定义全连接层,包含10个神经元,即输出特征数量为10,输入特征的数量为100

    def forward(self, image):   # 前向传播函数,用于进行模型推断操作,image是一个手写数字图像,
        image_viwed = image.view(-1,1*28*28) #[batch_size,1*28*28]  #形态转换,以便能够输入到全连接层
        fc1_out = self.fc1(image_viwed) #[batch_size,100]    # 将数据输入到第一个全连接层self.fc1中
        fc1_out_relu = F.relu(fc1_out) #[batch_siz3,100]   # ReLU激活函数
        out = self.fc2(fc1_out_relu) #[batch_size,10]   # 输入到第二个全连接层
        return F.log_softmax(out,dim=-1)  # 输出层使用该函数对得到的结果进行对数概率归一化处理,以便能够进行后续的损失计算和反向传播操作

四、train.py文件

"""
进行模型的训练
"""
from dataset import get_dataloader
from models import MnistModel
from torch import optim
import torch.nn.functional as F
import conf
from tqdm import tqdm
from test2 import eval
import numpy as np
import torch
import os

#1. 实例化模型,优化器,损失函数
model = MnistModel().to(conf.device)  # 创建一个MnistModel的实例,并将其移动到指定设备
optimizer = optim.Adam(model.parameters(),lr=1e-3)   # 使用Adam优化算法创建优化器,拥有优化神经网络模型中的参数,前者返回模型中所有需要训练的权重和偏置张量,后者指定学习率为0.001
                                                    # Adam优化算法是一种自适应梯度下降优化算法,可在高维空间中快速、稳定的优化神经网络模型

# if os.path.exists("./models/model.pkl"):
#     model.load_state_dict(torch.load("./models/model.pkl"))
#     optimizer.load_state_dict(torch.load("./models/optimizer.pkl"))


#2. 进行循环,进行训练
def train(epoch):   # 训练神经网络模型,epoch为训练的轮数
    train_dataloader = get_dataloader(train=True)   # 调用函数获取训练数据集的数据加载器
    bar = tqdm(enumerate(train_dataloader),total=len(train_dataloader))   # 使用tqdm()函数创建进度条,enumerate(train_dataloader)表示将数据加载器train_dataloader转换成可迭代的枚举对象,total=len(train_dataloader)表示进度条的总长度
    total_loss = []
    for idx,(input,target) in bar:   # 遍历数据加载器中的所有数据,并将其赋值给input和target
        input = input.to(conf.device)   # 将输入数据转移到指定的计算设备上,即GPU或CPU
        target = target.to(conf.device)  # 将目标数据转移到指定计算设备上
        optimizer.zero_grad()  # 将梯度清零,以避免梯度累加导致错误的梯度更新
        output = model(input)  # 将输入数据input输入到定义好的神经网络模型model中,得到输出预测结果output
        loss = F.nll_loss(output,target)   # 计算损失函数,使用负对数似然损失函数(Negative Log Likelihood Loss)进行分类问题的学习,该函数会将网络的输出结果output和目标数据target代入公式中计算出损失值并返回
        loss.backward()  # 计算反向传播,计算损失对网络参数的梯度
        total_loss.append(loss.item())   # 将计算得到的损失值添加到列表total_loss中
        optimizer.step()   # 根据梯度更新参数
        #打印数据
        if idx%10 ==0 :   # 每迭代10个批次就保存一下模型和优化器参数的状态
            bar.set_description("epcoh:{} idx:{},loss:{:.6f}".format(epoch,idx,np.mean(total_loss)))
            torch.save(model.state_dict(),"./models/model.pkl")   # # 将模型的参数保存到磁盘中
            torch.save(optimizer.state_dict(),"./models/optimizer.pkl")   # 将优化器的参数保存到磁盘中

if __name__ == '__main__':
    for i in range(10):  # 开启训练轮数为10轮的循环
        train(i)  # 调用train()函数开始训练模型
        eval()   # 训练完成后,调用eval()函数对模型进行评估

五、test2.py文件

"""
进行模型的评估
"""
from dataset import get_dataloader
from models import  MnistModel
from torch import optim
import torch.nn.functional as F
import conf
from tqdm import tqdm
import numpy as np
import torch
import os

def eval():   # 测试模型在测试集上的性能表现
    # 1. 实例化模型,优化器,损失函数
    model = MnistModel().to(conf.device)   # 实例化MnistModel的对象并将其移动到指定的设备上(CPU或GPU)

    if os.path.exists("./models/model.pkl"):  # 判断模型是否已保存到本地来恢复模型权重,若不存在,则根据MnistModel类中定义的初始化函数随机生成权重
        model.load_state_dict(torch.load("./models/model.pkl"))
    test_dataloader = get_dataloader(train=False)   # 获取测试数据集的数据加载器
    total_loss = []
    total_acc = []
    with torch.no_grad():
        for input,target in test_dataloader: #2. 进行循环,进行训练,每次使用一个批次的图像输入和对于的标签,并将数据一道指定设备上
            input = input.to(conf.device)
            target = target.to(conf.device)
            #计算得到预测值
            output = model(input)  # 使用模型前向传播获取模型输出
            #得到损失
            loss = F.nll_loss(output,target)   # 计算模型预测与真实标签间的损失loss
            #反向传播,计算损失
            total_loss.append(loss.item())   # 将该批次数据的loss值添加到数组中

            #计算准确率
            ###计算预测值
            pred = output.max(dim=-1)[-1]
            total_acc.append(pred.eq(target).float().mean().item())   # 计算该批次数据的准确率并添加到数组中
    print("test loss:{},test acc:{}".format(np.mean(total_loss),np.mean(total_acc)))   # 计算平均损失和平均准确率并打印

if __name__ == '__main__':
    # for i in range(10):
    #     train(i)
    eval()

六、运行train.py文件

输出如下

epcoh:0 idx:460,loss:0.321847: 100%|██████████| 469/469 [00:22<00:00, 21.20it/s]
test loss:0.1768832817673683,test acc:0.9481000065803528
epcoh:1 idx:460,loss:0.147730: 100%|██████████| 469/469 [00:26<00:00, 17.82it/s]
test loss:0.1185844399034977,test acc:0.9652999997138977
epcoh:2 idx:460,loss:0.101420: 100%|██████████| 469/469 [00:25<00:00, 18.67it/s]
test loss:0.09583198800683021,test acc:0.9711999952793121
epcoh:3 idx:460,loss:0.079119: 100%|██████████| 469/469 [00:25<00:00, 18.42it/s]
test loss:0.08689267784357071,test acc:0.9724000036716461
epcoh:4 idx:460,loss:0.061671: 100%|██████████| 469/469 [00:29<00:00, 15.77it/s]
test loss:0.0784779790788889,test acc:0.9767000019550324
epcoh:5 idx:460,loss:0.051475: 100%|██████████| 469/469 [00:30<00:00, 15.20it/s]
test loss:0.07767344787716865,test acc:0.975299996137619
epcoh:6 idx:460,loss:0.042745: 100%|██████████| 469/469 [00:29<00:00, 16.05it/s]
test loss:0.07635272592306137,test acc:0.9771999955177307
epcoh:7 idx:460,loss:0.034893: 100%|██████████| 469/469 [00:31<00:00, 14.98it/s]
test loss:0.0841908399015665,test acc:0.9741999983787537
epcoh:8 idx:460,loss:0.029657: 100%|██████████| 469/469 [00:31<00:00, 15.10it/s]
test loss:0.09327867105603219,test acc:0.9734000027179718
epcoh:9 idx:460,loss:0.023495: 100%|██████████| 469/469 [00:34<00:00, 13.76it/s]
test loss:0.08141878321766853,test acc:0.976800000667572

学习导航:https://www.xqnav.top/

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

手写数字识别代码详解 的相关文章

随机推荐

  • Sqli-labs之Less-25和Less-25a

    Less 25 GET 基于错误 您所有的OR AND都属于我们 字符串单引号 Notice Undefined variable hint in C phpStudy WWW sqli Less 25 index php on line
  • Android自动化测试框架实现

    背景介绍 最近打算梳理一下不同产品领域的自动化测试实现方案 如 Android终端 Web 服务端 智能硬件等 就先从Android终端产品开始梳理吧 本文主要介绍UI自动化测试的实现 因为这类测试解决方案比较通用 Android系统层 内
  • CentOS 7下安装nginx+php+mysql

    目录 一 安装Nginx 1 安装make 2 安装g 3 安装PCRE库 4 安装zlib库 5 安装ssl 6 下载和解压nginx 7 添加nginx用户和用户组 8 配置nginx安装参数 9 编译并且安装nginx 10 启动ng
  • Windows下的mingw-Qt开发环境安装及helloworld实现

    Windows下的mingw Qt开发环境安装及helloworld实现 我用的是Qt5 7 因此本次总结是基于Qt5 7 0的 我在自学的时候使用的IDE是Qt自带的Qt creator 上手简单 配置属于自己顺手的设置很方便 此外 如果
  • element主题色切换

    在网上搜了很多主题切换方案 发现没有适合自己项目的 不得已结合根据实际情况做一个子主题切换的功能 其中参考了element 官方的theme chalk preview 感兴趣的可以自己研究一下 主要功能是基于less切换主题色 可以自定义
  • 网络安全工程师需要学什么?零基础怎么从入门到精通,看这一篇就够了

    网络安全工程师需要学什么 零基础怎么从入门到精通 看这一篇就够了 我发现关于网络安全的学习路线网上有非常多看似高大上却无任何参考意义的回答 大多数的路线都是给了一个大概的框架 告诉你那些东西要考 以及建议了一个学习顺序 但是这对于小白来说是
  • 计算机退出程序的四种方法,退出windows10系统账户的四种方法

    网友反馈说Win10系统打开某些程序时 经常会弹出提示 你要允许以下程序对此计算机进行更改吗 每回都要手动关闭 而且频繁的弹出影响办公效率 有什么办法能将此窗口给永久关闭 退出微软账户即可 接下去看下具体操作方法 退出Win10账户的方法
  • 同步与异步的区别(一看则懂)

    前端面试经常被问 同步与异步的区别是什么 答案呢 大家都知道 只是在于你怎么表达 这种问题也不是很复杂 建议在回答的时候最好结合自己的实际项目开发以及自己的理解来回答 这样的效果会比较好 面试上提的问题本来目的就是想考察你是否熟悉 是否有用
  • TSN协议之冗余协议——IEEE 802.1 CB

    在车载通信领域 我们时常面临一个困惑 要是通信线路异常断开了怎么办 这里的异常断开不仅指物理上的断开 也可能是受电磁干扰等导致线路通信功能的异常等 解决此类问题 一个显而易见的解决方案就是增加冗余路径 即把数据传输2 N份以进行备份 这样就
  • 【转载】阿里数据技术大图详解

    架构图从下往上看 从数据采集和接入为始 抽取到计算平台 通过OneData体系 以 业务板块 分析维度 为架构去构建 公共数据中心 基于公共数据中心在上层根据业务需求去建设消费者数据体系 企业数据体系 内容数据体系等核心数据资产 深度加工后
  • JS判断数组是否包含其他数组中的一个值

    Test var a 2 3 4 5 6 7 8 9 10 var b 2 3 var c 1 var x S1 var y S2 c findIndex val gt x y a includes val Demo POC primary
  • 读取nacos配置_Nacos入门指南01 Nacos是什么?

    你好 欢迎阅读 本文是系列文章中的第1篇 Part1 Nacos 是什么 Part2 Nacos 环境搭建 Part3 Nacos 服务发现实践 Part4 Nacos 分布式配置实践 本文的目标是理解 Nacos 的概念作用 并理解服务发
  • 【发布】ChatGLM又开源了一个6B多模态版本

    点击蓝字 关注我们 AI TIME欢迎每一位AI爱好者的加入 OpenAI 的GPT 4样例中展现出令人印象深刻的多模态理解能力 但是能理解图像的中文开源对话模型仍是空白 近期 智谱AI 和清华大学 KEG 实验室开源了基于 ChatGLM
  • Quartus Ⅱ 15.1 将Verilog模块程序封装

    将模块程序封装 我们可以更加直观查看每个模块间的联系 先放一张成果图 博主做完数电实验就忘干净了 所以自己又摸索了一遍 最后成品可能不是太好看 怪自己手残 下面是详细步骤 首先要在files一栏 右击想要封装的模块 然后选择 Create
  • 如何在PC上查看一个web页面在移动端的展示效果

    最近在chrome上发现一个东东 emulation 这个果断可以用来模拟web页面在移动端的显示结果 F12的界面 点击 Show drawer 就可以看到这个界面了 这里可以选择各种设备 选中之后 点击emulate就可以模拟了 这个就
  • python 基础语法使用Demo

    基本模型 usr bin python coding UTF 8 print 你好 世界 一行显示多条语句 方法是用分号 分开 print hello print world 编写格式注意点 没有严格缩进 在执行时会报错 if True p
  • tw8836flash制作

    TW8836 Flash的bin制作 2 选bitmap 在选menu 3 4 压缩需勾选 5 添加图片 制作bin文件 6 改生成的MRLE为Bin后缀 BIN文件为烧写 INF文件为图片存储信息 代码要用到 7 这里用到BIN文件作为烧
  • chapter6可视化(不想看版)

    pip install visdom python m visdom server 直接使用 http localhost 8097 def linspace start stop num 50 endpoint True retstep
  • [个人笔记] origin学习 入门教程

    良心官方 已经入驻bilibili 官号 Origin Pro软件官方 投稿了许多基础教程 还有技术交流群等 打算学习的同学可以去找一下看看 2020 7 5官号只有三级 快去欺负 晚了就欺负不到了 图片中包含引用于官方视频教程的图片左下角
  • 手写数字识别代码详解

    文件目录如下 其中数据集data目录运行时在与手写数字识别同级目录自动生成 具体文件内代码见下文 一 conf py文件 项目配置 import torch train batch size 128 训练批次大小 表示每次训练神经网络时每次