神经网络(十四)Pytorch完整模型训练和调用GPU加速

2023-11-13

一、模型的训练

        Step1.准备数据集

import torchvision

train_data = torchvision.dataset.CIFAR10("../data",train=True,
        transform=torchvision.ToTensor,download=True)    --载入训练集
test_data = torchvision.dataset.CIFAR10("../data",train=False,
        transform=torchvision.ToTensor,download=True)    --载入测试集

                        Tips.获取数据集长度

train_data_size = len(train_data)

        Step2.加载数据集

train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)

        Step3.搭建网络

class MyNerNet(nn.Moduel):
    def __init__(self):
        super(MyNerNet,self).__init__()    --基类初始化
        self.model = nn.Sequential(                  --网络序列器
                            nn.Cov2d(3,32,5,1,2),    --卷积
                            nn.MaxPool2d(2),         --池化
                            nn.Conv2d(32,32,5,1,2),
                            nn.MaxPool2d(2),
                            nn.Conv2d(32,64,5,1,2),
                            nn.MaxPool2d(2),
                            nn.Flatten(),            --展平
                            nn.Linear(64*4*4,64),    --线性层
                            nn.Linear(64,10)         --线性层代分类器)

    def forward(self,x):    --传递函数
        return self,model(x)

                可将网络相关代码放置在一个单独的文件中,但是在主文件中需要使用引用

from model import *    --将model文件中所有的内容引用

        Step4.创建损失函数和优化器

loss_fn = nn.CrossEntropyLoss()    --交叉损失函数

learing_rate=0.01    --学习速率,外置方便修改
optimizer = torch.optim.SGD(MyNerNet.parameters(),lr=learing_rate)    --随机梯度下降

        Step5.训练+测试

--设置一些计数器
    total_train_step = 0    --记录训练次数
    total_test_step = 0    --记录训练次数
    epoch = 10    --训练轮数

--开始训练
    for i in range(epoch):
        print("---第{}轮训练---".format(i+1))
        for data in train_dataloader
            imgs,targets = data    --拆包
            outputs = mynet(imgs)    --使用网络
            loss = loss_fn(outputs,targets)    --计算损失函数

            optimizer.zero_grad()    --梯度清零
            loss.backforward()    --前向传递
            optimizer.step()    --逐步优化

            total_train_step += 1    --计数
            print("训练次数:{},Loss:{}".format(total_train_step,loss.item()))    
                --展示//也可以使用TensorBorad进行展示

    --开始测试
        total_test_loss = 0    --总损失函数计数
        with torch.no_grad():    --不设置梯度(保证不进行调优)
            for data in test_dataloader:
                imgs,targets = data    --拆包
                outputs = mynet(imgs)    --使用网络
                loss = loss_fn(outputs,targets)    --计算损失函数
                total_test_loss = total_test_loss + loss  --添加此次部分损失函数
        print("整个测试集上的Loss:{}".format(total_test_loss))
        total_test_step = total_test_step + 1

    --保存每轮的模型
        torch.save(mynet,"MyNerNet_Ver{}.pth".format(total_train_step))

                Tips.正确率展示(用于分类问题)

outputs = torch.Tensor([[0.1,0.2],
                        [0.3,0.4]])

preds = outputs.argmax(1)    --最大延展
targets = Torch.Tensor([0,1])    --真实输入

print(preds == targets)    --检验(对应位置是否相等),输出正确的个数

二、模型训练需要注意的事项

        1.网络训练/测试模式

                当网络中含有Dropout、BatchNorm时,必须调用

                但是如果没有对应的内容不是必须的,使用无效

mynet.train()    --训练模式
mynet.test()    --测试模型

        2.测试时关闭梯度

                测试之前需要调用这行代码,关闭网络的梯度

with torch.no_grad():    

三、使用GPU进行训练加速

        1.方式一:在原有的网络模型数据(输入、标注)、损失函数中调用.cuda()函数即可

mynet = mynet.cuda()    --对网络调用

loss_fn = loss_fn.cuda()    --对损失函数调用

imgs = imgs.cuda()    --仅对部分数据生效(数据集的输入数据)

                但如果电脑没有N卡就会报错,最好在代码前部加上验证函数

if torch.cuda.is_available()
    mynet = mynet.cuda()    --有GPU再将网络进行转移

        2.方式二:在原有的网络模型数据(输入、标注)、损失函数中调用.to(device)函数--流转到其他设备

Device = torch.device("cpu")    --调用CPU
Device = torch.device("cuda")    --调用GPU
Device = torch.device("cuda:0")    --调用第一块GPU(当存在多块GPU时)
mynet.to(device)    --流转到设备

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

神经网络(十四)Pytorch完整模型训练和调用GPU加速 的相关文章

随机推荐

  • c++实现引用计数

    概述 当有指针指向同一块内存空间时 计数器加1 没增加一个指向该内存空间的指针 计数器加1 同理 当原本指向该内存空间的指针指向另一块内存 计数器减1 被指向的另一个内存的计数器加1 下面是一个引用计数的一种实现 示例 直接上代码 总共分为
  • uni-app项目中如何使用scss less

    前言 由于公司业务调整 特意学习下uni项目框架 其实根据官方api就是实现很多功能 其实都是一些小坑要走 下面来说一下uni项目中如何使用scss vue编写中我们可以直接使用下面这样方法 多方便
  • Eclispse中Run on Server窗口让选择Server,但已经存在的选择不了

    对于这种问题 通常是因为版本不匹配造成的 jdk版本 Dynamic Web Modules版本 只要改到相应版本就好了 jdk7 时Dynamic Web Modules应设为2 5 如果无法修改 可以新建一个工程 在新建工程时选择Dyn
  • 记忆深处有尘埃——Memory Compiler

    Memory是大家Floorplan中经常使用到一个器件 而且需要花费不少时间去摆放它 Memory的种类很多 各种类型还分别具有不同的参数 那大家有没有想过 对一个设计来说 我们是如何去选择合适的memory类型 不同的类型有什么区别 在
  • 作为一名程序员,如何开展自己的副业?月赚三万的真实故事

    作为一名程序员 除了敲代码之外还应该有一些副业 我们都是程序员 大多数都是普通人 都在替别人打工 虽然收入在别人眼中挺高 但是连个首付都付不起 这时 首先得要发展副业 与其拿着死工资 还不如做些啥 今天 我所说的不是教大家如何去挣很多钱 而
  • mavon-editor 页面回显使用turndown将HTML转为markdown

    1 安装npm install turndown npm install turndown 2页面使用 v model markdowntext
  • 后端接口返回近万条数据,前端渲染缓慢,content Download 时间长的优化方案

    前言 性能优化 是前端绕过不去的一道门槛 甚是重要 最近一年 也很少有机会在项目中进行前端性能优化 一直在忙于业务开发 最近终于是来了机会 遇到了这样的场景 心里也甚是激动 写个随笔记录下性能优化的过程及逻辑 有需要的可以参考下 场景 后端
  • 机器学习实战笔记8(kmeans)

    前面的7次笔记介绍的都是分类问题 本次开始介绍聚类问题 分类和聚类的区别在于前者属于监督学习算法 已知样本的标签 后者属于无监督的学习 不知道样本的标签 下面我们来讲解最常用的kmeans算法 1 kmeans算法 算法过程 Kmeans中
  • Spring核心思想 IOC 、 AOP

    Spring核心思想 IOC AOP IOC 1 什么是IOC 2 IOC解决了什么问题 IoC解决对象之间的耦合问题 3 IOC和DI的区别 AOP 1 什么是AOP 2 AOP在解决什么问题 3 为什么叫切面编程 内容就不展示了 里面已
  • Python自动检查哪位学生未提交作业

    最近期未需要对学生提交的作业进行统计 给平时成绩 总共交了8次作业 每个作业都有2个班 数量太多 于是就利用Python写了一个程序来自动实现 思想 获取指定路径下的所有文件名 如果文件名中包含了学生的名字 因为提交作业的时候以学号 名字进
  • 基于TMF SID的高可扩展性数据模型

    基于TMF SID的高可扩展性数据模型 前言 此文根据TMF SID规范撰写 欢迎大家提出建议和意见 TMF文档版权信息 Copyright TeleManagement Forum 2013 All Rights Reserved Thi
  • Flutter Windows应用开发环境配置

    为什么要入Flutter开发的坑 首先在当今Windows开发已经逐渐成为一个偏小众的领域 不仅要涉及的知识面广 还对开发人员的要求不低 界面的精美也成为一个重要因素 目前已知的Windows 客户端主要分成以下几种 开发语言 Qt C C
  • Android登录 之 Twitter登录

    作为Android登录 之 GooglePlay登录的姊妹篇 这俩篇主要是对接国外平台登录的文章 作者文笔并不好 但是 管他呢 实现功能不就得了嘛 Twitter官网 兄弟们自带梯子啊 然后按照流程 创建申请什么的 也就不多说了 接下来就是
  • Google C++风格指南 阅读笔记

    这个Google C 风格指南出得太好了 有很多C 的问题 其实通过阅读这份文档就可以了 相信读完后 可以在简历上加上一句 具有良好的编码风格 哈哈 下面记录一下我的读书笔记吧 整份文档的中文版本我已经上传到了资源里面 1 头文件 1 1头
  • 在vue使用jsx来解决template中复杂的逻辑处理

    1 首先安装依赖 npm install postcss loader autoprefixer babel loader babel core 2 在 babelrc文件中修改 把 presets env stage 2 plugins
  • 【Python】Windows如何在cmd中切换python版本

    相信很多小伙伴都会有像我一样经历 在windows中装了很多python版本 那么如果我们正式使用的时候应该如何切换呢 方法一 从环境变量中切换python 第一步 打开环境变量 第二步 打开系统变量中Path变量 第三步 将你想使用的Py
  • spring 多个切面的执行顺序及原理

    最近和同事聊起来了springAOP的话题 说了到多个切面的时候程序是怎么执行的 我们常用的spring事务本身也是一个切面 使用的AOP原理 本人从网上找了一些资料 然后根据这些资料进行一下总结 资料地址 1 https blog csd
  • CodeLlama本地部署的实战方案

    大家好 我是herosunly 985院校硕士毕业 现担任算法研究员一职 热衷于机器学习算法研究与应用 曾获得阿里云天池比赛第一名 CCF比赛第二名 科大讯飞比赛第三名 拥有多项发明专利 对机器学习和深度学习拥有自己独到的见解 曾经辅导过若
  • C++:没有与参数列表匹配的构造函数

    报错 E0289 没有与参数列表匹配的构造函数 sales data sales data 实例 初始化一个实例对象 类内定义的构造函数 报错原因 构造函数中第二个参数的类型为 unsigned 而引用只能是引用一个对象 实例化对象时 括号
  • 神经网络(十四)Pytorch完整模型训练和调用GPU加速

    一 模型的训练 Step1 准备数据集 import torchvision train data torchvision dataset CIFAR10 data train True transform torchvision ToTe