机器学习/深度学习--手写数字识别(MNIST数据集)

2023-10-26

import torch
# 导入torchvision的transform模块,用来处理数据集
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# 创建一个transforms实例,用于对数据集进行处理
transform = transforms.Compose([
    # 将图片转换为tensor
    transforms.ToTensor(),
    # 归一化
    transforms.Normalize((0.1307,), (0.3081,))
])

# 创建训练集
"""
datasets.MNIST是Pytorch的内置函数torchvision.datasets.MNIST,通过这个可以导入数据集。
    root设置数据集路径,如果该路径中有需要的数据,则使用,如果没有且download为True,则下载数据集到目录中。
    train=True 表明这是训练集
    transform=transform 表明使用上面定义的transform来处理数据集。
"""
# 创建训练集
train_dataset = datasets.MNIST(root='DataSet\\',
                               train=True,
                               download=True,
                               transform=transform)
# 创建测试集
test_dataset = datasets.MNIST(root='DataSet\\',
                              train=False,
                              download=True,
                              transform=transform)
# 定义训练数据装载器实例
train_loader = DataLoader(train_dataset,
                          batch_size=64,
                          shuffle=True)
# 定义测试数据装载器实例
test_loader = DataLoader(test_dataset,
                         batch_size=64,
                         shuffle=True)


# 定义模型
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # MNIST数据集中的图片是28*28=784的
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        # view中一个参数定为 - 1,代表自动调整这个维度上的元素个数,以保证元素的总数不变.每行784个元素
        # 如果设置为(-1,392)则x为2行392列
        x = x.view(-1, 784)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        x = torch.relu(self.l3(x))
        x = torch.relu(self.l4(x))
        return self.l5(x)

# 创建模型实例
net = Net()

# 设置损失函数和优化器
# 损失函数使用交叉熵
loss = torch.nn.CrossEntropyLoss(reduction='mean')
# 优化器使用随机梯度下降
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.5)

# 定义训练
def train(epoch):
    running_loss = 0.0
    for data in train_loader:
        inputs, label = data
        y_pre = net(inputs)
        loss_num = loss(y_pre, label)
        # 累加每一次训练的损失
        running_loss += loss_num

        # 梯度置0
        optimizer.zero_grad()
        # 反向传播
        loss_num.backward()
        # 参数更新
        optimizer.step()

    # 输出一下将所有样本训练一遍的损失
    print("epoch:", epoch+1, "one-epoch-loss:", running_loss)


# 定义测试集
def test():
    correct = 0
    sum = 0
    # 在测试集中不需要进行梯度下架
    with torch.no_grad():
        for data in test_loader:
            inputs, label = data
            outputs = net(inputs)
            """
            torch.max(matrix,dim)
            dim有两个取值0和1,取0时表示按列索引matrix每列最大值,取1时表示按行索引matrix每行最大值
            返回值是两个torch,第一个values是找到的最大值,第二个indices是最大值的索引
            
            例如:
            a = torch.tensor([[1,5,62,54], [2,6,2,6], [2,65,2,6]])
            torch.max(a, 1)
            输出:
            torch.return_types.max(
                values=tensor([62, 6, 65]),
                indices=tensor([2, 3, 1]))
            
            """
            # "_",是占位符,接收torch.max()第一个返回值
            _, predicted = torch.max(outputs.data, dim=1)
            sum += label.size(0)
            correct += (predicted == label).sum().item()
        print("Accuracy on test:", correct / sum)


if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()

经过十轮训练后,模型 识别准确率达到了97.8%

在测试集中将图片换成自己要预测的图片

 

def test():
    img = cv2.imread('F:\\pytorch_study\\test01\\DataSet\\num_pic\\7.png')
    img = cv2.resize(img, (28, 28))
    # 训练集是单通道的图片,这里将图片拆分,只取图片一个通道
    img1 = cv2.split(img)[0]
    img_tensor = transform(img1)
    data = img_tensor
    # 在测试集中不需要进行梯度下架
    with torch.no_grad():
        inputs = data
        outputs = net(inputs)
        # "_",是占位符,接收torch.max()第一个返回值
        _, predicted = torch.max(outputs.data, dim=1)
    print('预测结果:', predicted.data.item())

运行结果:

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

机器学习/深度学习--手写数字识别(MNIST数据集) 的相关文章

随机推荐

  • 最小优先级队列 — 使用最小堆实现

    最小优先级支持的操作 1 INSERT S x 将元素x插入队列S 2 MINIMUM S 返回S中最小的元素 3 EXTRACT MIN S 去掉并返回S中最小的元素 4 DECREASE KEY S x key 将下标为x的元素值降低为
  • 获得代理ippython_python爬虫之抓取代理服务器IP

    前言 使用爬虫爬取网站的信息常常会遇到的问题是 你的爬虫行为被对方识别了 对方把你的IP屏蔽了 返回不了正常的数据给你 那么这时候就需要使用代理服务器IP来伪装你的请求了 免费代理服务器网站有 下面我们以西刺免费代理IP为例子看看是如何获取
  • ArcGISMapsSDK for UnrealEngine_AQ

    ArcGISMapsSDK for UnrealEngine AQ Prepare 1 ArcGIS Maps SDK for game engines 2 ArcGIS Maps SDK for Unreal Engine Beta 2
  • jQuery 的 DOM 操作- 中

    文章目录 jQuery 的 DOM 操作 中 复制节点 复制节点介绍 复制节点 应用实例 替换节点 替换节点介绍 替换节点 应用实例 属性操作 样式操作 样式操作介绍 应用实例 jQuery 的 DOM 操作 中 注意本篇和jQuery 的
  • 【java】常用到的一些获取文件内容的方法

    一 前奏准备 获取文件名 根据文件名获取路径 文件路径名 String path public String getPath return path 根据路径获取文件名 return 文件名字符串 public String fileNam
  • Cocos 2dx iOS 平台初始化,OpenGL 初始化,分辨率设置

    Cocos 2dx iOS 平台初始化 OpenGL 初始化 分辨率设置 1 Main m int retVal UIApplicationMain argc argv nil AppController AppController mm
  • 判断操作系统和浏览器类型(苹果还是安卓,微信还是QQ)

    一 获取操作系统类型 function getOS var userAgent navigator in window userAgent in navigator navigator userAgent toLowerCase var v
  • FPGA时序约束学习笔记——IO约束(转)

    一 参考模型 图源来自 抢先版 小梅哥FPGA时序约束从遥望到领悟 二 参数分析 T 0 gt 3 Tclk1 T 3 gt 4 Tco T 4 gt 5 T 5 gt 6 Tdata T 4 gt 5 Tdata Init T 5 gt
  • 渗透测试流程

    文章目录 前言 一 渗透测试流程 二 流程步骤 1 明确目标 2 信息收集 3 漏洞探测 4 漏洞验证 5 提权所需 6 信息分析 7 信息整理 8 形成报告 总结 前言 渗透测试 出于保护信息系统的目的 更全面地找出测试对象的安全隐患 入
  • python数据库连接

    python数据库连接 import os import time import pymysql import sys class Myclass object def init self try self db pymysql conne
  • Springboot整合Activiti详解

    文章目录 版本依赖 配置文件 需要注意的问题 画流程图 activiti服务类进行编写 流程部署 流程定义 启动流程 流程实例 测试流程 启动流程 完成任务 受理任务 版本依赖 开发工具 IDEA SpringBoot 2 4 5 这里我试
  • MySQL

    1 MySQL概述 1 什么是数据库 数据库是一个存储数据的仓库 2 都有哪些公司在用数据库 金融机构 游戏网站 购物网站 论坛网站 3 提供数据库服务的软件 1 软件分类 MySQL SQL Server Oracle Mariadb D
  • 初中计算机试题戏曲进校园,【校园通讯】“戏曲进校园”走进东街学校,春风化新雨,戏曲百媚生!...

    原标题 校园通讯 戏曲进校园 走进东街学校 春风化新雨 戏曲百媚生 戏曲进校园 戏曲进校园 走进东街学校 春风化新雨 戏曲百媚生 文 东街学校 张永慰 弘扬民族文化 展现戏曲精华 10月10日 戏曲进校园 活动走进济水东街学校 为全体师生带
  • 3.18飞书面试(58min)

    3 18飞书面试 58min 1 问项目 首先是问redis是怎么用的 mq的消费是怎么写的呢 mq如何保证消息消费的可靠性 你在项目中用到了本地缓存 放在了业务代码内存中 那如果签到一半你的项目突然崩了 本地缓存都消失了 那不是会出问题啊
  • [人工智能-深度学习-66]:架构 - 人工智能的学习误区与正确思路、方法

    作者主页 文火冰糖的硅基工坊 文火冰糖 王文兵 的博客 文火冰糖的硅基工坊 CSDN博客 本文网址 https blog csdn net HiWangWenBing article details 122116482 目录 前言 第1章
  • spring+struts+ibatis

    原来的系统里面只采用了struts的框架 并且没有使用struts的校验功能 为方便开发 修改框架为spring struts ibatis组合1 添加需要的jar文件2 添加spring配置文件applicationContext xml
  • view-source是一种协议,查看源码

    view source是一种协议 早期基本上每个浏览器都支持这个协议 后来Microsoft考虑安全性 对于WindowsXP pack2以及更高版本以后IE就不再支持此协议 但是这个方法在FireFox和Chrome浏览器都还可以使用 如
  • Linux驱动_spi驱动(ICM20608)

    参考 Linux SPI 驱动分析 1 结构框架 StephenZhou CSDN博客 linux spi驱动 Linux SPI 驱动分析 2 框架层源码分析 StephenZhou CSDN博客 spi message init SPI
  • sql server 加密_列级SQL Server加密

    列加密 创建一个新的数据库并创建CustomerInfo表 CREATE DATABASE CustomerData Go USE CustomerData GO CREATE TABLE CustomerData dbo Customer
  • 机器学习/深度学习--手写数字识别(MNIST数据集)

    import torch 导入torchvision的transform模块 用来处理数据集 from torchvision import transforms from torchvision import datasets from