【Pytorch】循环神经网络实现手写体识别

2023-11-13

【Pytorch】循环神经网络实现手写体识别

1 数据集加载

import seaborn as sns
sns.set(font_scale=1.5,style="white") 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
import copy

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torch.utils.data as Data
from torchvision import transforms

## 准备训练数据集Minist
train_data  = torchvision.datasets.MNIST(root='./data',  
                            train=True,   
                            transform=transforms.ToTensor(),  
                            download=True) 
## 定义一个数据加载器
train_loader = Data.DataLoader(
    dataset = train_data, ## 使用的数据集
    batch_size=64, # 批处理样本大小
    shuffle = True, # 每次迭代前打乱数据
    num_workers = 2, # 使用两个进程 
)


##  可视化训练数据集的一个batch的样本来查看图像内容
for step, (b_x, b_y) in enumerate(train_loader):  
    if step > 0:
        break
## 输出训练图像的尺寸和标签的尺寸,都是torch格式的数据
print(b_x.shape)
print(b_y.shape)
train_data

输出

torch.Size([64, 1, 28, 28])
torch.Size([64])
Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: ToTensor()

## 准备需要使用的测试数据集
test_data  = torchvision.datasets.MNIST(root='./data', 
                           train=False, 
                           transform=transforms.ToTensor())
## 定义一个数据加载器
test_loader = Data.DataLoader(
    dataset = test_data, ## 使用的数据集
    batch_size=64, # 批处理样本大小
    shuffle = True, # 每次迭代前打乱数据
    num_workers = 2, # 使用两个进程 
)


##  可视化训练数据集的一个batch的样本来查看图像内容
for step, (b_x, b_y) in enumerate(train_loader):  
    if step > 0:
        break
## 输出训练图像的尺寸和标签的尺寸,都是torch格式的数据
print(b_x.shape)
print(b_y.shape)
test_data

输出

torch.Size([64, 1, 28, 28])
torch.Size([64])
Dataset MNIST
    Number of datapoints: 10000
    Root location: ./data
    Split: Test
    StandardTransform
Transform: ToTensor()

2 搭建RNN模型

class RNNimc(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        """
        input_dim:输入数据的维度(图片每行的数据像素点)
        hidden_dim: RNN神经元个数
        layer_dim: RNN的层数
        output_dim:隐藏层输出的维度(分类的数量)
        """
        super(RNNimc, self).__init__()
        self.hidden_dim = hidden_dim ## RNN神经元个数
        self.layer_dim = layer_dim ## RNN的层数
        # RNN
        self.rnn = nn.RNN(input_dim, hidden_dim, layer_dim,
                          batch_first=True, nonlinearity='relu')
        
        # 连接全连阶层
        self.fc1 = nn.Linear(hidden_dim, output_dim)
    def forward(self, x):
        # x:[batch, time_step, input_dim]
        # 本例中time_step=图像所有像素数量/input_dim
        # out:[batch, time_step, output_size]
        # h_n:[layer_dim, batch, hidden_dim]
        out, h_n = self.rnn(x, None) # None表示h0会使用全0进行初始化
        # 选取最后一个时间点的out输出
        out = self.fc1(out[:, -1, :]) 
        return out

## 模型的调用
input_dim=28   # 图片每行的像素数量
hidden_dim=128  # RNN神经元个数
layer_dim = 1  # RNN的层数
output_dim=10  # 隐藏层输出的维度(10类图像)
MyRNNimc = RNNimc(input_dim, hidden_dim, layer_dim, output_dim)
print(MyRNNimc)

输出

RNNimc(
  (rnn): RNN(28, 128, batch_first=True)
  (fc1): Linear(in_features=128, out_features=10, bias=True)
)

3 训练模型

## 对模型进行训练
optimizer = torch.optim.RMSprop(MyRNNimc.parameters(), lr=0.0003)  
criterion = nn.CrossEntropyLoss()   # 损失函数
train_loss_all = []
train_acc_all = []
test_loss_all = []
test_acc_all = []
num_epochs = 30
for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    MyRNNimc.train() ## 设置模型为训练模式
    corrects = 0
    train_num  = 0
    for step,(b_x, b_y) in enumerate(train_loader):
        # input :[batch, time_step, input_dim]
        xdata = b_x.view(-1, 28, 28)
        output = MyRNNimc(xdata)     
        pre_lab = torch.argmax(output,1)
        loss = criterion(output, b_y) 
        optimizer.zero_grad()        
        loss.backward()       
        optimizer.step()  
        loss += loss.item() * b_x.size(0)
        corrects += torch.sum(pre_lab == b_y.data)
        train_num += b_x.size(0)
    ## 计算经过一个epoch的训练后在训练集上的损失和精度
    train_loss_all.append(loss / train_num)
    train_acc_all.append(corrects.double().item()/train_num)
    print('{} Train Loss: {:.4f}  Train Acc: {:.4f}'.format(
        epoch, train_loss_all[-1], train_acc_all[-1]))
    ## 设置模型为验证模式
    MyRNNimc.eval()
    corrects = 0
    test_num  = 0
    for step,(b_x, b_y) in enumerate(test_loader):
        # input :[batch, time_step, input_dim]
        xdata = b_x.view(-1, 28, 28)
        output = MyRNNimc(xdata)     
        pre_lab = torch.argmax(output,1)
        loss = criterion(output, b_y) 
        loss += loss.item() * b_x.size(0)
        corrects += torch.sum(pre_lab == b_y.data)
        test_num += b_x.size(0)
    ## 计算经过一个epoch的训练后在测试集上的损失和精度
    test_loss_all.append(loss / test_num)
    test_acc_all.append(corrects.double().item()/test_num)
    print('{} Test Loss: {:.4f}  Test Acc: {:.4f}'.format(
        epoch, test_loss_all[-1], test_acc_all[-1]))

4 模型保存和加载

# 保存
torch.save(MyRNNimc, 'rnn.pkl')

model = torch.load('rnn.pkl')
print(model)

输出

RNNimc(
  (rnn): RNN(28, 128, batch_first=True)
  (fc1): Linear(in_features=128, out_features=10, bias=True)
)

模型测试

import cv2
import matplotlib.pyplot as plt

# 第一步:读取图片
img = cv2.imread('./data/test/8.png') 
print(img.shape)

# 第二步:将图片转为灰度图
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
print(img.shape)
plt.imshow(img,cmap='Greys')

# 第三步:将图片的底色和字的颜色取反
img = cv2.bitwise_not(img)
plt.imshow(img,cmap='Greys')

# 第四步:将底变成纯白色,将字变成纯黑色
img[img<=144]=0
img[img>140]=255  # 130

# 显示图片
plt.imshow(img,cmap='Greys')

# 第五步:将图片尺寸缩放为输入规定尺寸
img = cv2.resize(img,(28,28))

# 第六步:将数据类型转为float32
img = img.astype('float32')

# 第七步:数据正则化
img /= 255

img = img.reshape(1,784)
# 第八步:增加维度为输入的规定格式

_img = torch.from_numpy(img).float()
# _img = torch.from_numpy(img).unsqueeze(0)

输出

(234, 182, 3)
(234, 182)

在这里插入图片描述

model.eval()
_img = _img.view(-1, 28, 28)
# 第九步:预测
outputs = model(_img)

# 第十步:输出结果
print(outputs.argmax())
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【Pytorch】循环神经网络实现手写体识别 的相关文章

随机推荐

  • 龙书笔记(13)

    chap 13 地形绘制基础 主要是创建一个 地形类 Terrain 1 高度图 其实是一个数组 每个元素都指定了地形方格中某一顶点的高度值 每个元素只分配了1个字节的存储空间 当加载到程序时 重新分配 浮点型 或 整型 数据来存储这些高度
  • CentOS7开机时的菜单选项及时间的修改

    转载记录 以防丢失 一 在CentOS更新后 并不会自动删除旧内核 所以在启动选项中会有多个内核选项 可以手动使用以下命令删除多余的内核 正常下 第一个选项正常启动 第二个选项急救模式启动 系统出项问题不能正常启动时使用并修复系统 1 查看
  • 记录一下树莓派打内核补丁cjktty的天坑

    首先cjktty的下载地址在此 大家根据自己的linux内核去选择 https github com zhmars cjktty patches 下载好了补丁文件之后 需要下载完整的linux内核 是的完整的 https github co
  • ahut 月赛1

    心得 一点一点理解 对于一段要学习的代码 跟着写下来 理解一点写一点 对于一道题目 用记事本 看题目 看一句题目 用自己的话概括一句 写在记事本上 并将自己的 想法一并写下来 这样做下来 心会很平静 你会发现 理解一段代码并不费力 解决一道
  • Cookie、cookie与session区别

    Cookie Cookie 有时也用其复数形式 Cookies 类型为 小型文本文件 是某些网站为了辨别用户身份 进行Session跟踪而储存在用户本地终端上的数据 通常经过加密 由用户客户端计算机暂时或永久保存的信息 Cookie有什么用
  • 一个字节造成的巨大性能差异——SQL Server存储结构

    今天同事问了我一个SQL的问题 关于SQL Server内部存储结构的 我觉得挺有意思 所以写下这篇博客讨论并归纳了一下 问题是这样的 首先我们创建两张表 一张表的列长度是4039字节 另一张表的长度是4040字节 他们就只有一个字节的差距
  • 阿里巴巴 cola设计架构

    https github com alibaba COLA
  • leetcode 21 合并两个有序链表 (c++和python)

    目录 题目描述 解题思路 C 代码 python代码 题目描述 将两个有序链表合并为一个新的有序链表并返回 新链表是通过拼接给定的两个链表的所有节点组成的 示例 输入 1 gt 2 gt 4 1 gt 3 gt 4 输出 1 gt 1 gt
  • golang的chan(管道)

    golang的chan翻译成中文就是管道 顾名思义 就是管道的一端用来读 另一端用来写 这与write和read函数的性质是非常相似的 比如说管道中没数据 就会发生读阻塞 管道中数据是满的 就会发生写阻塞 又类似生产者和消费者 也就是必须有
  • 大学生python实验心得体会_大学生实训心得体会3篇

    转眼间为期两个星期的实训就结束了 但是安利公司的物流配送 黄埔港 益邦物流公司 南沙港以及学校里面的航海模拟实验中心 轮机实训实验楼这些实训过程仍历历在目 以下是小编整理的大学生实训心得体会 欢迎阅读 大学生实训心得体会1 通过实训中心老师
  • 微信小程序-flex布局:垂直、水平方向-自动填充满剩余空间

    在微信小程序项目中经常需要将水平或垂直方向分成两大部分 一部分内容宽度或高度固定 剩余的一部分需填充满剩余空间 那么 该怎么快速解决这类布局 效果图如下 垂直方向 水平方向 我个人比较喜欢使用flex布局 面对此类布局 最先想到的也是fle
  • 【HDLBits 刷题 12】Circuits(8)Finite State Manchines 27-34

    目录 写在前面 Finite State Manchines 2014 q3c m2014 q6b m2014 q6c m2014 q6 2012 q2fsm 2012 q2b 2013 q2afsm 2013 q2bfsm 写在前面 HD
  • 类和对象笔记(1.类和对象的关系,类基本架构)

    梳理C 基础 纯干货或许会很干燥 gt gt gt gt gt 分界线 类 指对象的类型 类代表了一批对象的共性和特征 抽象的 不占用内存 对象 类的具体实例 具体的 占用储存空间 类是对象的抽象 对象是类的具体实例 可以同结构体进行比较学
  • eclipse常用插件之FindBugs

    1 简介 FindBugs 是由马里兰大学提供的一款开源 Java静态代码分析工具 FindBugs通过检查类文件或 JAR文件 将字节码与一组缺陷模式进行对比从而发现代码缺陷 完成静态代码分析 FindBugs既提供可视化 UI 界面 同
  • 打印金字塔代码

    Description 输入n值 打印下列形状的金字塔 其中n代表金字塔的层数 Input 输入只有一个正整数n Output 打印金字塔图形 其中每个数字之间有一个空格 include
  • JsonMap对象的获取与前台浏览器报错Uncaught TypeError: Cannot read property '0' of undefined

    JsonMap对象的获取与前台浏览器报错Uncaught TypeError Cannot read property 0 of undefined 后台问题在浏览器报错 很多时候在我们遇到浏览器报错的时候都会去前端js里去找错 但有时候确
  • 一起学nRF51xx 15 - spis

    前言 SPIS是一个从SPI 它与EasyDMA一起支持与外部的主SPI超低功耗串行通信 EasyDMA使得SPIS交互无需CPU的介入 在提高数据传输效率的同时还减轻了CPU的负担 SPIS即是SPI从模式 但它比stm32中直接将spi
  • Android开发:最全面、最易懂的Android屏幕适配解决方案

    前言 Android的屏幕适配一直以来都在折磨着我们Android开发者 本文将结合 Google的官方权威适配文档 郭霖 Android官方提供的支持不同屏幕大小的全部方法 Stormzhang Android 屏幕适配 鸿洋 Andro
  • Wireshark抓包分析交换机工作原理

    实验名称 交换机工作原理 实验目的 1 熟悉Linux虚拟网络环境 2 熟悉Linux中network namespace的基本操作 3 熟悉Linux中虚拟以太网设备Tap和veth pair的基本操作 4 熟悉Linux中Bridge设
  • 【Pytorch】循环神经网络实现手写体识别

    Pytorch 循环神经网络实现手写体识别 1 数据集加载 2 搭建RNN模型 3 训练模型 4 模型保存和加载 模型测试 1 数据集加载 import seaborn as sns sns set font scale 1 5 style