pytorch用LeNet5识别Mnist手写体数据集(训练+预测单张输入图片代码)

2023-11-06

首先,在论文上的LeNet5的结构如下,由于论文的数据集是32x32的,mnist数据集是28x28的,所有只有INPUT变了,其余地方会严格按照LeNet5的结构编写程序:

训练代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable


lr = 0.01 #学习率
momentum = 0.5
log_interval = 10 #跑多少次batch进行一次日志记录
epochs = 10
batch_size = 64
test_batch_size = 1000


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(  # input_size=(1*28*28)
            nn.Conv2d(1, 6, 5, 1, 2),  # padding=2保证输入输出尺寸相同
            nn.ReLU(),  # input_size=(6*28*28)
            nn.MaxPool2d(kernel_size=2, stride=2),  # output_size=(6*14*14)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),  # input_size=(16*10*10)
            nn.MaxPool2d(2, 2)  # output_size=(16*5*5)
        )
        self.fc1 = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.ReLU()
        )
        self.fc3 = nn.Linear(84, 10)

    # 定义前向传播过程,输入为x
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x #F.softmax(x, dim=1)



def train(epoch):  # 定义每个epoch的训练细节
    model.train()  # 设置为trainning模式
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)
        data, target = Variable(data), Variable(target)  # 把数据转换成Variable
        optimizer.zero_grad()  # 优化器梯度初始化为零
        output = model(data)  # 把数据输入网络并得到输出,即进行前向传播
        loss = F.cross_entropy(output,target)  #交叉熵损失函数
        loss.backward()  # 反向传播梯度
        optimizer.step()  # 结束一次前传+反传之后,更新参数
        if batch_idx % log_interval == 0:  # 准备打印相关信息
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))


def test():
    model.eval()  # 设置为test模式
    test_loss = 0  # 初始化测试损失值为0
    correct = 0  # 初始化预测正确的数据个数为0
    for data, target in test_loader:

        data = data.to(device)
        target = target.to(device)
        data, target = Variable(data), Variable(target)  #计算前要把变量变成Variable形式,因为这样子才有梯度

        output = model(data)
        test_loss += F.cross_entropy(output, target, size_average=False).item()  # sum up batch loss 把所有loss值进行累加
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()  # 对预测正确的数据个数进行累加

    test_loss /= len(test_loader.dataset)  # 因为把所有loss值进行过累加,所以最后要除以总得数据长度才得平均loss
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))



if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') #启用GPU

    train_loader = torch.utils.data.DataLoader(  # 加载训练数据
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))  #数据集给出的均值和标准差系数,每个数据集都不同的,都数据集提供方给出的
                       ])),
        batch_size=batch_size, shuffle=True)

    test_loader = torch.utils.data.DataLoader(  # 加载训练数据,详细用法参考我的Pytorch打怪路(一)系列-(1)
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)) #数据集给出的均值和标准差系数,每个数据集都不同的,都数据集提供方给出的
        ])),
        batch_size=test_batch_size, shuffle=True)

    model = LeNet()  # 实例化一个网络对象
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)  # 初始化优化器

    for epoch in range(1, epochs + 1):  # 以epoch为单位进行循环
        train(epoch)
        test()

    torch.save(model, 'model.pth') #保存模型

 

预测代码:

import torch
import cv2
import torch.nn.functional as F
from modela import LeNet  ##重要,虽然显示灰色(即在次代码中没用到),但若没有引入这个模型代码,加载模型时会找不到模型
from torch.autograd import Variable
from torchvision import datasets, transforms
import numpy as np

if __name__ =='__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = torch.load('model.pth') #加载模型
    model = model.to(device)
    model.eval()    #把模型转为test模式

    img = cv2.imread("3.jpg")  #读取要预测的图片
    trans = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)#图片转为灰度图,因为mnist数据集都是灰度图
    img = trans(img)
    img = img.to(device)
    img = img.unsqueeze(0)  #图片扩展多一维,因为输入到保存的模型中是4维的[batch_size,通道,长,宽],而普通图片只有三维,[通道,长,宽]
    #扩展后,为[1,1,28,28]
    output = model(img)
    prob = F.softmax(output, dim=1)
    prob = Variable(prob)
    prob = prob.cpu().numpy()  #用GPU的数据训练的模型保存的参数都是gpu形式的,要显示则先要转回cpu,再转回numpy模式
    print(prob)  #prob是10个分类的概率
    pred = np.argmax(prob) #选出概率最大的一个
    print(pred.item())

 

用画图软件画一张28x28的灰度图:

输入到预测代码中,效果:

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch用LeNet5识别Mnist手写体数据集(训练+预测单张输入图片代码) 的相关文章

  • Win10安装Linux虚拟机-安装与使用

    Win10安装Linux虚拟机 安装与使用 1 VMware 的下载 VMWare虚拟机软件是一个 虚拟PC 软件 它使你可以在一台机器上同时运行二个或更多Windows DOS LINUX系统 下载地址 https customercon
  • ie浏览器打不开闪退_教你修复win7IE浏览器闪退的问题

    使用win7系统的朋友不少会使用IE浏览器来访问网页的时候 经常会出现IE浏览器自动退出了 另外在闪退前会有个提示 出现一个问题导致程序停止正常工作 那么这样的问题该怎样解决呢 下面就跟小编来了解一下怎样修复IE浏览器问题吧 Win7 IE
  • Flex3.2 Lists & Grids 内存泄漏

    所有继承于ListBase的类List DataGrid AdvancedDataGrid and TileList 在选中列表中的一项后 增加了鼠标相关Listener 导致泄漏 SDK3 3中已经修改 Sdk3 2中修复补丁http w
  • 使用plsql访问远程数据库

    1 plsql输入ip端口数据库实例名直接登录 Username 用户名 如 scott Password 用户对应密码 如 tiger Database 数据库位置 语法为 ip 端口号 数据库实例名 如 192 168 1 156 15
  • Nand Flash的同步、异步、ONFI、Toggle

    1 SDR和DDR SDR Single Data Rate 写读数据使用上升沿或下降沿来触发 因为只用上升沿或下降沿 对信号准确性要求较低 DDR Double Data Rate 写数据时通过MCU来控制DQS信号跳变沿来触发 即上升沿
  • android fragment 重复创建的问题

    解决fragment重复创建目前用到有两个方法 1 fragment同viewpager一起使用 vp setOffscreenPageLimit 3 设置缓存页面的个数 2 fragment单独使用 在onCreateView 方法中加入
  • 用C语言写UTF-8编码的文件

    原文地址 http blog csdn net zaffix article details 7217701 为实现用C语言写UTF 8编码的文件 测试了以下两种情况 第一种情况 为 fopen 指定一个编码 然后写入 wchar t 字符
  • Flink笔记14:Flink之window起始点的确定与watermark使用详解

    1 window起始时间的确定 在TimeWindow java中有如下方法来确定window的起始时间 public static long getWindowStartWithOffset long timestamp long off
  • win32 API函数大全

    1 API之网络函数 WNetAddConnection 创建同一个网络资源的永久性连接 WNetAddConnection2 创建同一个网络资源的连接 WNetAddConnection3 创建同一个网络资源的连接 WNetCancelC
  • Python网络爬虫之数美滑块的加密及轨迹之动态js参数分析

    前言 数美滑块的加密及轨迹等应该是入门级别的吧 用他们的教程和话来说 就一个des 然后识别缺口位置可以用cv2或者ddddoc 轨迹 也可以随便模拟一个 这些简单的教程 在csdn已经有一大把可以搜到的 但是却很少人告诉你 它的js好像是
  • CMake 命令

    1 Usage cmake options
  • MAC安装渗透测试靶机

    1 mac 安装docker 直接到docker官网下载docker dmg 下载前先要注册docker 下载后直接安装就可以了 docker version 就能看见安装的版本 我的版本17 03 1 ce 2 下载docker镜像ima
  • 了解l电源纹波PSRR----转摘

    PSRR 就是 Power Supply Rejection Ratio 的缩写 中文含意为 电源纹波抑制比 也就是说 PSRR 表示把输入与电源视为两个独立的信号源时 所得到的两个电压增益的比值 基本计算公式为 PSRR 20log Ri
  • 【C语言】-- 整型数据的存储

    目录 1 数据类型的分类 2 基本类型 2 1 基本类型大小 2 2 整型家族 2 3 数据的存储形式 2 4 整形数据的存储方式 1 数据类型的分类 在C语言中有如下类型 2 基本类型 2 1 基本类型大小 一个变量的创建是要在内存中开辟
  • node の SQLite

    node操作SQLite 之前在做electron桌面制作番茄钟应用时曾经想过用数据库存储数据 一开始打算mongodb 但是发现不能实现无服务器 那么只能使用SQLite了 介绍 SQLite 是一个软件库 实现了自给自足的 无服务器的
  • android 前端常用布局文件升级总结(一)

    问题一 android support design widget CoordinatorLayout 报红 不显示页面 解决方法 把xml布局文件里面的 android support design widget CoordinatorL
  • SpringBoot + Prometheus + Grafana 打造可视化监控

    SpringBoot Prometheus Grafana 打造可视化监控 文章目录 SpringBoot Prometheus Grafana 打造可视化监控 常见的监控组件搭配 安装Prometheus 安装Grafana 搭建Spri

随机推荐

  • [正能量系列]失业的程序员(一)

    注 本文原型为作者的好友 全文不完全代表作者本人的意图 不小心 我失业了 原因是前几天和我的部门经理拍了桌子 我的组员去内蒙古出差 项目没有中标 年后 长得很像猪刚烈的部门经理发飙了 要辞退我的组员 我纳闷了 我的组员是技术支持 要退也应该
  • Proxmox VE虚拟化从入门到应用实战-服务器管理篇(网络配置2

    Proxmox VE虚拟化从入门到应用实战 服务器管理篇 网络配置2 一 Linux多网口绑定 多网口绑定 也称为网卡组或链路聚合 是一种将多个网卡绑定单个网络设备的技术 利用该技术可以实现某个或多个目标 例如提高网络链容错能力 增加网络通
  • 哈希算法总结

    目录 1 Hash是什么 它的作用 2 Hash算法有什么特点 2 1 Hash在管理数据结构中的应用 2 1 Hash在在密码学中的应用 3 Hash算法是如何实现的 4 Hash有哪些流行的算法 5 那么 何谓Hash算法的 碰撞 5
  • Markdown文件关机没保存,怎么恢复

    1 2 点开找到你想恢复的时间段的文件
  • JS date格式化

    Date prototype Format function fmt author meizz use strict jshint var o M this getMonth 1 月份 d this getDate 日 h this get
  • Qt Creator中,include路径包含过程(或如何找到对应的头文件)

    Qt Creator中 include路径包含过程 或如何找到对应的头文件 利用Qt Creator开发程序时 需要包含利用 include来添加头文件 大家都知道 include lt gt 用于包含标准库头文件 路径在安装软件的incl
  • centos7环境下mysql8的tar包的安装及配置

    内网环境下安装及配置 并将数据保存指向某个文件夹 因为博主这里的数据文件夹是有硬盘挂靠的 centos 7 aliyun CentOS 7 x86 64 DVD 1810 mysql mysql 8 0 17 linux glibc2 12
  • 【题解】闯关游戏

    题目描述 艾伦正在闯关 游戏有N个关卡 按照必须完成的顺序编号为1到N 每个关卡可以用两个参数来描述 prob i 和value i 这些参数的含义如下 每当艾伦尝试闯第i关时 他要么顺利通过 要么 挂掉 他完成该关卡的概率总是prob i
  • Red5应用开发(二)直播串流与录制

    环境 操作系统 win10 1803 Eclipse版本 4 7 3a Oxygen J2EE版本 Red5 Server版本 1 0 8 Release 环境搭建参考前一篇文章 Red5应用开发 一 开发环境搭建 后续不再涉及red5 f
  • 职场加班

    总是听到形形色色的职场加班过劳死的故事 甚至有人写了一篇文章 别让老板杀了你 职场果真那么恐怖吗 其实公司怎么想那是公司的事情 公司有权想着把你干掉 把你榨干 因为这样对于公司最有利 但是 问题在于我们自己怎么想呢 在我看来 在这个社会上混
  • python 泛型函数--singledispatch的使用

    functools singledispatch 将一个函数转变为单一分派的泛型函数 用 singledispatch装饰一个函数 将定义一个泛型函数 注意 我们创建的函数获得分派的依据是第一个参数的类型 from functools im
  • SSM框架整合方案(Spring+SpringMVC+Mybatis)

    一 将application进行纵向切分 每一个配置文件只配置与之相关的Bean 除此之外 项目中通常还有log4j properties SqlMapConfig xml db properties文件 二 各文件配置方案详解 1 日志组
  • IDDPM原理和代码剖析

    前言 Improved Denoising Diffusion Probabilistic Models IDDPM 是上一篇 Denoising Diffusion Probabilistic Models DDPM 的改进工作 之前一些
  • 2022年比若依更香的开源项目

    项目名 cpms 是Concise practical management system 的首字母缩写 意思是 简洁实用的后台管理系统 cpms开源项目目前分为 cpms cloud微服务架构和cpms单体应用架构 cpms cloud是
  • Ant Design:Form表单组件的正确使用

    Form 设置表单初始默认值 initialValues 只在初始化和重置表单时生效 Object 表单字段状态发生改变触发的回调函数 onValuesChange function changedValues allValues Form
  • vue3.2 父子组件传参

    父组件father vue 子组件child vue 1 父传子 把子组件引入到父组件里 定义数据 然后在子组件里使用props接收数据 father vue
  • python已知两边求第三边_探究“已知一个三角形两边及其夹角,求第三边”的问题...

    探究 已知一个三角形两边及其夹角求第三边 的问题 知识点 余弦定理 对应版本章节 本节课是人民教育出版社出版的高中数学 A 版数学必修 5 第一章 解三角 形 第一节第二课时余弦定理的内容 教学目标 1 理解利用向量猜想证明余弦定理 2 掌
  • 下载及安装Python详细步骤

    安装python分三个步骤 下载python 安装python 检查是否安装成功 1 下载Python 1 python下载地址https www python org downloads 2 选择下载的版本 3 点开Download后 找
  • NXP i.MX6ULL 移植python3.9.5

    项目场景 在眺望电子TW AC6G EVM开发板上移植python3 9 5 编译环境及开发包 主机 ubuntu18 04 交叉编译器 arm linux gnueabihf gcc QT5 12 8 qt everywhere open
  • pytorch用LeNet5识别Mnist手写体数据集(训练+预测单张输入图片代码)

    首先 在论文上的LeNet5的结构如下 由于论文的数据集是32x32的 mnist数据集是28x28的 所有只有INPUT变了 其余地方会严格按照LeNet5的结构编写程序 训练代码 import torch import torch nn