深度学习之经典案例 CIFAR10 图形识别(jupyter)

2023-11-04

图像识别:CIFAR10图形识别

1.CIFAR10数据集共有60000张彩色图像,这些图像式32*32*3,分为10个类,每个类6000张

2.这里面有50000张用于训练,构成5个训练批,每一批10000张图;另外10000张用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。

3.一个训练批中的各类图像并不一定数量相同,总的来看训练集,每一类都有5000张图片

 代码如下:与官网代码不一致

# 导入包
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms,datasets

 # 设置transforms
transform = transforms.Compose([
    transforms.ToTensor(), # numpy -> Tensor
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))  # 归一化 ,范围[-1,1]
])

 # 下载训练数据集
# 训练集
trainset = datasets.CIFAR10(root='./CIFAR10',train=True,download=True,transform=transform)
# 测试集
testset = datasets.CIFAR10(root='./CIFAR10',train=False,download=True,transform=transform)

 出现如下图结果数据集下载成功

# 批量获取数据
from torch.utils.data.dataloader import DataLoader

BATCH_SIZE = 32

train_loader = DataLoader(trainset,batch_size=BATCH_SIZE,shuffle=True,num_workers=8,pin_memory=True)

test_loader = DataLoader(testset,batch_size=BATCH_SIZE,shuffle=True,num_workers=8,pin_memory=True)


 注意:其中BATCH_SIZE = 32 中的32 可以根据自己电脑配置来定,配置高可以定128 低可以定16

# 可视化显示
import matplotlib.pyplot as plt
import numpy as np

# 十个类别
classes = ('plane','car','bird','cat','deer',
          'dog','frog','horse','ship','truck')

def imshow(img):
    img = img / 2 + 0.5 # 逆正则化
    np_img = img.numpy()  # tensor --> numpy
    plt.imshow(np.transpose(np_img,(1,2,0)))  # 改变通道顺序
    plt.show()
    
# 随机获取一批数据
imgs,labs = next(iter(train_loader))


print(imgs.shape)
print(labs.shape)
    
#调用方法
imshow(torchvision.utils.make_grid(imgs))

# 输出这批图片对应的标签
print(' '.join('%5s' % classes[labs[i]] for i in range(BATCH_SIZE)))    

 结果如下:

 其中

torch.Size([32, 3, 32, 32])
torch.Size([32])

中32代表32张图片,3代表3个通道,32代表像素

# 定义网络模型
import torch.nn as nn
import torch.nn.functional as F

'''
知识点:
1.特征图尺寸的计算公式为:[(原图片尺寸 = 卷积核尺寸) / 步长] + 1
'''
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        # 卷积层1.输入是32*32*3,计算(32-5)/ 1 + 1 = 28,那么通过conv1输出的结果是28*28*6
        self.conv1 = nn.Conv2d(3,6,5)  # imput:3 output:6, kernel:5
        # 池化层, 输入时28*28*6, 窗口2*2,计算28 / 2 = 14,那么通过max_poll层输出的结果是14*14*6
        self.pool = nn.MaxPool2d(2,2) # kernel:2 stride:2
        # 卷积层2, 输入是14*14*6,计算(14-5)/ 1 + 1 = 10,那么通过conv2输出的结果是10*10*16
        self.conv2 = nn.Conv2d(6,16,5) # imput:6 output:16, kernel:5
        # 全连接层1
        self.fc1 = nn.Linear(16*5*5, 120)  # input:16*5*5,output:120
        # 全连接层2
        self.fc2 = nn.Linear(120, 84)  # input:120,output:84
        # 全连接层3
        self.fc3 = nn.Linear(84, 10)  # input:84,output:10
        
    def forward(self,x):
        # 卷积1
        '''
        32x32x3 --> 28x28x6 -->14x14x6
        '''
        x = self.pool(F.relu(self.conv1(x)))
        # 卷积2
        '''
        14x14x6 --> 10x10x16 --> 5x5x16
        '''
        x = self.pool(F.relu(self.conv2(x)))
        # 改变shape
        x = x.view(-1,16*5*5)
        # 全连接层1
        x = F.relu(self.fc1(x))
        # 全连接层2
        x = F.relu(self.fc2(x))
        # 全连接层3
        
        x = self.fc3(x)
        return x 

 注意:__init__这一块下划线要注意,按理说只要将模型定义到__init__()里就ok了,但是大家容易少打一个下划线会报错,将下划线_改为__即可解决问题。

# 创建模型
net = Net().to('cuda')

电脑有GPU的话这一步是部署到CUDA上运行调用GPU, 这一步容易出现下图问题,

 这时候多运行几次,代码是没有问题的,应为JUPYTER是在网页上运行,需要时间反应,多运行几次

如果出现以下问题:

 注意Linear中的L要大写

# 定义优化器和损失函数
import torch.optim as optim

criterion = nn.CrossEntropyLoss()  # 交叉式损失函数

optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)  # 优化器
# 定义函数
EPOCHS = 200

for epoch in range(EPOCHS):
    
    train_loss = 0.0
    for i, (datas,labels) in enumerate(train_loader):
        datas,labels = datas.to('cuda'),labels.to('cuda')
        # 梯度置零
        optimizer.zero_grad()
        # 训练
        outputs = net(datas)
        # 计算损失
        loss = criterion(outputs,labels)
        # 反向传播
        loss.backward()
        # 参数更新
        optimizer.step()
        # 累计损失
        train_loss += loss.item()
    # 打印信息
    print(epoch+1, i+1,train_loss/len(train_loader.dataset))

 循环次数可以自己设置,这里设置为200轮,for循环读取训练集

输出结果如下:

 (可以参考网上其他输出格式) 

# 测试
correct = 0
total = 0
# flag=True
with torch.no_grad():
    for i , (datas,labels) in enumerate(test_loader):
        # 输出
        outputs = model(datas) # outputs.data,shape --> torch.Size([128,10])
        _, predicted = torch.max(outputs.data, dim=1)  # 第一个是值得张量,第二个是序号张量
        # 累计数据值
        total += labels.size(0)  # labels.size() --> torch.Size([128]) , labels.size(0) --> 128
        # 比较有多少个预测正确
        correct += (predicted == labels).sum()  # 相同为1,不同为0,利用sum()总求和
    print("在1000张测试集图片上的准确率:{:.3f}%".format(torch.true_divide(correct,total))
        
# 显示每一类预测的概率
class_correct = list(0. for i in range(10))
total = list(0. for i in range(10))

with torch.no_grad():
    for (images,labels) in test_loader:
        outputs = model(images)  # 输出
        _,predicted = torch.max(outputs,dim=1)  # 获取到每一行的最大索引
        c = (predicted ==labels).squeeze()  # squeeze() 去掉0维【默认】,unsqueeze() 增加一维
        if labels.shape[0]  == 128:
            for i in range(BATCH_SIZE):
                label = labels[i] # 获取每一个label
                class_correct[label] += c[i].item()  # 累计维True都个数,注意:1 + True = 2,1 + False = 1
                total[label]  += 1 # 该类总的个数
                
# 输出正确率
for i in range(10):
    print("正确率 : %5s : $2d %%" % (classes[i],100 * class_correct[i] / total[i])


参考视频:​​​​​07-02 经典案例 CIFAR10 图像识别【个人实现】_哔哩哔哩_bilibili

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习之经典案例 CIFAR10 图形识别(jupyter) 的相关文章

随机推荐

  • NETCore入门系列(Log4NET组件的使用)

    文章目录 分析 整合Log4net 源码 分析 1 官方自带的Log中间件可在命令行中输出日志 通过在构当前控制器的构造函数中注入 如下图 2 此时如果想要将日志输出到项目的某个文件中 则可以通过整合Log4net组件 3 一般建议日志记录
  • 算法:z字形排列

    将一个给定字符串根据给定的行数 以从上往下 从左到右进行 Z 字形排列 class Solution public string convert string s int numRows string result 如果排序长度为1 或者字
  • Python,OpenCV中的非局部均值去噪(Non-Local Means Denoising)

    Python OpenCV中的非局部均值去噪 Non Local Means Denoising 1 效果图 2 原理 3 源码 2 1 单彩色图去噪 2 2 多连续彩色帧去噪 参考 这篇博客将介绍不同的计算摄影技术 非局部均值去噪 Non
  • MYSQL 命令大全

    一 连接MySQL 格式 mysql h 主机地址 u 用户名 p 用户密码 1 例1 连接到本机上的MYSQL 首先在打开DOS 窗口 然后进入目录 mysqlbin 再键入命令mysql uroot p 回车后提示你输密码 如果刚安装好
  • PHY调试经验

    1 PHY调试过程 1 设备树中配置正确的PHY ADDR PHY ID clause 45或者22协议 PHY ADDR配置不正确会导致MDC MDIO通信不正常或失败 PHY ID用于匹配PHY驱动程序 2 通过MDC MDIO读写PH
  • Google亲儿子 Nexus/Pixel 手机刷机Root之旅

    简介 本文介绍的方法是针对Google亲儿子的教程 其他国内厂商请绕道 1 解锁 1 1 OEM解锁 想要做下面这些事 需要先在开发者选项里打开oem解锁 如果你的手机是V版 运营商定制版 请看这里 oem解锁选项灰色 1 2 进入boot
  • 【JDBC】关于JDBC入门和一些见解

    关于JDBC的一些理解和总结 JDBC连接数据库 刚开始学的时候经常忘记步骤 其实多了几次之后发现完全就是自己没有理解到原理 现在回头看还是挺有意思的 分为以下几个步骤 1 注册加载JDBC驱动 把Driver装进JVM Class for
  • Centos 7 重启网卡报错解决方案

    一 Network 当重启网卡时报错 解决方案 步骤1 修改对应文件 增加命令 步骤2 关闭NetworkManager服务 并重启网卡 systemctl stop NetworkManager systemctl restart net
  • 解谜元宇宙元年的十个疑问

    解谜元宇宙元年的十个疑问 2021年 元宇宙突然出现在大家的视野之中 相关概念受到资本的热捧 成为金融市场的热点 这难免会让我们对元宇宙产生很多好奇和疑问 本文总结了十个对元宇宙的疑问 并一一作出解答 2021年为什么是元宇宙元年 元宇宙
  • 关于运算放大器电流流向的问题

    前言 一 问题的引入 二 提出问题 三 问题解答 写在结尾的话 前言 问题缘起于一次硬件同事之间的讨论 虽然目前我不是做硬件的 但签于我的专业以及之前从事的工作 觉得有必要把记录下来 后期也打算写一些站在学习者的角度 关于硬件知识的学习心得
  • 数字电路和模拟电路-8触发器

    前言 掌握锁存器原理及应用 基本SR锁存器 钟控SR锁存器 钟控D锁存器 钟控D锁存器的动态参数 掌握触发器原理及应用 主从触发器 维持阻塞触发器 其它功能的触发器 目录 一 基本SR锁存器 1 双稳态电路 Bistate Elements
  • Android系统Unity使用HttpWebRequest访问Https请求出现连接超时

    多渠道版本配置网络地址时 http地址替换为了https 由于粗心大意 之前同事遗留的请求框架代码没有对https协议进行 处理 导致在android手机下unity访问https地址进行配置文件下载更新时出现连接超时问题 解决方案 if
  • word vba设置表格样式

    Sub 表格处理 功能 光标在表格中处理当前表格 否则处理所有表格 Application ScreenUpdating False 关闭屏幕刷新 Application DisplayAlerts False 关闭提示 On Error
  • java连接db2数据库示例代码_java实现连接db2数据库的代码实例

    java实现连接db2数据库的代码实例 第一种 目前ibm一直都没有提供type 1的jdbc驱动程序 第二种 类型2驱动 com ibm db2 jdbc app db2driver 该驱动也位于包db2java zip中 jdk必须能访
  • uniapp 微信小程序长按识别二维码,跳转小程序、个人微信

    前言 业务要求是小程序放一个二维码图片 长按可以识别二维码 进而识别出个人微信 添加个人微信 我们可以通过uni previewImage OBJECT 或者 wx previewImage Object object 预览当前图片去实现
  • 24-系统自带的 Win+R 功能

    Win 运行窗口 Win R 开始菜单 gt 运行 是 Windows 的一个原生的功能 从 XP 到 Windows 10 都自带了 当用户按下快捷键 Win R Win 为键盘上Windows图标键 后 系统会弹出一个小窗口让你输入命令
  • 用户态虚拟化IO通道实现概览及实践(上)

    自虚拟化技术诞生起 提升虚拟化场景中IO设备性能和驱动的兼容性 可扩展性一直是备受关注和追求的目标 随着半虚拟化技术的出现 virtio设备及驱动也很快流行并逐步变成了虚拟化应用中的主要IO通道形态 例如 virtio现已支持实现的设备涵盖
  • Dell IDRAC服务器重装系统详解(远程连接)

    主要的操作步骤文末附上的那篇博客写的比较详细了 不足的地方是有一些小问题没有说明白 导致新手可能不太清楚操作 而无法 复现 安装过程 TIPS 1 远程连接登录的时候 用户名root 密码calvin不一定可行 如果不行的话 看一下服务器机
  • 研一Python基础课程第四周课后习题分享(含源代码)

    代码写的较多 有问题可以私聊我 第四周作业分享 一 题目前言 二 题目分享 1 问题1 2 问题2 3 问题3 4 问题4 5 问题5 6 问题6 7 问题7 8 问题8 9 问题9 10 问题10 11 问题11 12 问题12 13 问
  • 深度学习之经典案例 CIFAR10 图形识别(jupyter)

    图像识别 CIFAR10图形识别 1 CIFAR10数据集共有60000张彩色图像 这些图像式32 32 3 分为10个类 每个类6000张 2 这里面有50000张用于训练 构成5个训练批 每一批10000张图 另外10000张用于测试