PyTorch: 训练分类CIFAR10

2023-11-11

# !/usr/bin/env python
# -- coding: utf-8 --
# @Author zengxiaohui
# Datatime:8/13/2021 11:20 AM
# @File:train_cifar10
import os
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


def shufflenet_v2_x0_5(nc, pretrained):
    model_ft = torchvision.models.shufflenet_v2_x0_5(pretrained=pretrained)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, nc)
    return model_ft


if __name__ == '__main__':
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    epochs = 5
    batch_size = 256
    num_workers = 8
    classes = 10

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                                              pin_memory=True)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    model = shufflenet_v2_x0_5(classes, True)
    model.cuda()
    model.train()

    criterion = nn.CrossEntropyLoss()
    # SGD with momentum
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in tqdm(enumerate(trainloader)):
            inputs, labels = inputs.cuda(), labels.cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            outputs = model(inputs)
            # loss
            loss = criterion(outputs, labels)
            # backward
            loss.backward()
            # update weights
            optimizer.step()

            # print statistics
            running_loss += loss
        print('%d/%d loss: %.3f' % (epochs, epoch + 1, running_loss / len(trainset)))

    correct = 0
    model.eval()
    for j, (images, labels) in tqdm(enumerate(testloader)):
        outputs = model(images.cuda())
        _, predicted = torch.max(outputs.data, 1)
        correct += (predicted.cpu() == labels).sum()
    print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / len(testset)))

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch: 训练分类CIFAR10 的相关文章

  • 在vue项目中设置网站图标

    怎么在我们的Vue项目中设置自定义的网站图标 首先我们需要先制作个ico图标 大小为32 32的 放到static文件夹下 附制作网站 我们把制作好的ico图片改名为 favicon ico 注 必须改名 放到我们项目中的static文件夹
  • softmax分类器 python实现

    转自 http blog csdn net wds2006sdo article details 53699778 utm source itdadao utm medium referral 算法 算法参考的是Andrew 的课件与这篇文

随机推荐

  • Android中获取View宽高方法

    今天遇到一个问题 就是view获取宽度 高度都为0的问题 其实这个大家都遇到过 这里转载别人的 大家好共同学习 本文转载于 http www jianshu com p f56c92e29dea Android开发中经常需要获取控件的宽高
  • FileZilla的下载与安装

    FileZilla的下载与安装 为什么要使用FileZilla进行文件互传呢 Windows下 FileZilla客户端下载与安装 1 FileZilla的下载 1 FileZilla的安装 1 双击运行安装包 点击 i agree 2 n
  • Shader中的一些专业用语的解释

    Shader中的一些专业用语的解释 此文章收录于我主页顶置的 Unity Shader入门精要文章目录 点击即可跳转 一 什么是OPenGL DirectX 简单的来说 就是图像应用编程的接口 这些接口用语渲染二维和三维的图形 架起了上层应
  • 【毕业设计】基于单片机的桌面炫酷律动灯条 -物联网 嵌入式 单片机

    文章目录 0 前言 1 简介 2 主要器件 3 实现效果 4 设计原理 5 部分核心代码 6 最后 0 前言 这两年开始毕业设计和毕业答辩的要求和难度不断提升 传统的毕设题目缺少创新和亮点 往往达不到毕业答辩的要求 这两年不断有学弟学妹告诉
  • 公办幼儿园教师要涨工资了???

    终于盼到这一天了 已在市区公办园上班3年多却一直没有编制的季馨 听说从明年开始要涨工资了 高兴坏了 记者从日前召开的全市学前教育工作会议上获悉 从2012年起 确保市区公办幼儿园中具有国家教师资格的聘用教师最低工资水平不得低于当地最低工资标
  • 蓝桥杯 问题 1083: Hello, world!(C/C++ vector实现)

    问题 1083 Hello world 时间限制 1Sec 内存限制 64MB 提交 944 解决 476 题目描述 This is the first problem for test Since all we know the ASCI
  • 《一周搞定模电》—功率放大器

    系列文章目录 文章目录 系列文章目录 前言 一 功率放大电路三极管的工作模式 二 功率放大器内部结构 前言 功率放大器指一种以输出较大功率为目的的放大电路 特点 输出电压大 输出电流大 放大电路的输出电阻与负载匹配 电压放大器和功率放大器的
  • 三子棋创作(c语言)

    我们写三子棋之前首先要思考一下三子棋的实现逻辑 一 1 游戏菜单 是选择开始游戏还是结束游戏 2 打印一个棋盘出来 并且进行棋盘的初始化 即没有旗子的棋盘 3 玩家下棋 用 表示 4 电脑下棋 用 表示 5 判断胜负 电脑和玩家下完棋之后
  • java使用lambda表达式对List集合进行操作(Java8)

    import java util ArrayList import java util List import java util function Predicate import java util stream Collectors
  • token会被截取吗_OAuth2 为什么要用 code 换 token

    先简单介绍下 OAuth2 再用一个例子说明下为什么要用 code 换 token OAuth2 简单介绍 4 个角色 resource owner 可以授权访问被保护资源的实体 如果是人的话 即是最终用户 resource server
  • h2数据库优缺点

    h2数据库是嵌入式的内存型数据库 也可以存储在磁盘上 效率比通过socket调用的redis执行的要快 纯java编写就一个jar h2数据库的缺点是不适合大数据量高并发的操作
  • centos 安装防火墙,并开启对应端口号

    1 查看防火墙状态 命令 systemctl status firewalld service 开启防火墙时 提示没有安装防火墙 root localhost systemctl start firewalld service Failed
  • 关于锁的面试题

    1 synchronized和ReentranctLock有什么区别 底层实现 synchronized是jvm层面的锁 通过monitor对象完成 对象只能在同步代码块和同步方法中调用wait notify方法 ReentranctLoc
  • Java多线程——线程的sleep方法、中断线程的睡眠

    一 关于Sleep方法的应用 public static void sleep long millis throws InterruptedException 让当前正在执行的线程进入休眠 暂时停止执行 指定的毫秒数 静态方法 Thread
  • 数字媒体技术专业方向

    现在是大三下 这篇文章是大一时 整理知乎青岛大学 某学姐的高赞回答 咱这个专业 你可以根据你的学校进行选择 学校好 按部就班的学 以下几个方向都走得通 学校不好 很普通 那么大概率也不学了什么 普通本科院校的学风啊 教学质量啊 与其都学个皮
  • C++11/14/17中提供的mutex系列区别

    C 11 14 17中提供的mutex系列类型如下 互斥量 C 版本 作用 mutex C 11 基本的互斥量 timed mutex C 11 timed mutex带超时功能 在规定的等待时间内 没有获取锁 线程不会一直阻塞 代码会继续
  • 监听小程序切换到后台

    注意要写在app js里面 onHide wx onAppHide
  • 图像处理学习笔记(三):基于匹配的目标识别

    Matlab图像处理学习笔记 三 基于匹配的目标识别 如果要在一幅图像中寻找已知物体 最常用且最简单的方法之一就是匹配 在目标识别的方法中 匹配属于基于决策理论方法的识别 匹配方法可以是最小距离分类器 相关匹配 本文code是基于最小距离分
  • 三进制计算机_数学糖果S10:N进制

    不同进制各有各的特点 二进制更为基础 十进制匹配人体手指数量 十二进制之基数12所含因数多 十六进制之基数16易被多次二分 六十进制结合了五进制与十二进制 世界可能是由概率控制的 现实世界中十进制被选中 计算机世界中二进制被选中 N 进 制
  • PyTorch: 训练分类CIFAR10

    usr bin env python coding utf 8 Author zengxiaohui Datatime 8 13 2021 11 20 AM File train cifar10 import os import torch