Pytorch中实现CPU和GPU之间的切换

2023-11-17

如何在pytorch中指定CPU和GPU进行训练,以及cpu和gpu之间切换

由CPU切换到GPU,要修改的几个地方:

网络模型、损失函数、数据(输入,标注)

# 创建网络模型
tudui = Tudui()
if torch.cuda.is_available():
   tudui = tudui.cuda()

# 损失函数
loss_fn = nn.CrossEntropyLoss()
if torch.cuda.is_available():
    loss_fn = loss_fn.cuda()

# 数据输入   包括训练和测试的代码,二者都需要添加此代码
if torch.cuda.is_available():
   imgs = imgs.cuda()
   targets = targets.cuda()

方法一:.to(device)

1.不知道电脑GPU可不可用时:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu' )
a.to(device)

第一行代码的意思是判断电脑GPU可不可用,如果可用的话device就采用cuda()即调用GPU,不可用的话就采用cpu()即调用CPU。

第二行代码的意思就是把变量放到对应的device上(当然如果你用的是CPU的话就不用这一步了,因为变量默认是存在CPU上的,调用GPU的话要先把变量放到GPU上跑,跑完之后再调回CPU上)

2.指定GPU

# 定义训练的设备
device = torch.device("cuda:0")

# 网络模型创建
tudui = Tudui()
tudui = tudui.to(device)

# 损失函数
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)

# 训练步骤开始
    tudui.train()
    for data in train_dataloader:
        imgs, targets=data
        imgs = imgs.to(device)
        targets = targets.to(device)
        outputs = tudui(imgs)
        loss = loss_fn(outputs, targets)

# 测试步骤开始
    tudui.eval()
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets=data
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = tudui(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss.item()
            accuracy = (outputs.argmax(1)==targets).sum()
            total_accuracy = total_accuracy + accuracy

3.指定cpu时:

device = torch.device('cpu')

方法二:

1、需要修改的

# 三种常见的写法
device = torch.device('cuda')
device = torch.device('cuda: 0')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

2、代码

# 创建模型
tudui = Tudui()
if torch.cuda.is_available():
   tudui = tudui.cuda()

# 损失函数
loss_fn = nn.CrossEntropyLoss()
if torch.cuda.is_available():
    loss_fn = loss_fn.cuda()

# 训练步骤开始
    tudui.train()
    for data in train_dataloader:
        imgs, targets=data
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            targets = targets.cuda()
        outputs = tudui(imgs)
        loss = loss_fn(outputs, targets)

 # 测试步骤开始
    tudui.eval()
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets=data
            if torch.cuda.is_available():
                imgs = imgs.cuda()
                targets = targets.cuda()
            outputs = tudui(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss.item()
            accuracy = (outputs.argmax(1)==targets).sum()
            total_accuracy = total_accuracy + accuracy

总结:

推荐方法一,如果自己电脑是只有CPU,可以推荐使用云端服务器,比如PaddlePaddle,Google colab,这些服务器由每周免费八个小时的使用时间,可供我们基本的需求。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch中实现CPU和GPU之间的切换 的相关文章

随机推荐

  • 解除Discuz!X2的15分钟锁定

    第一种方法 清两个failedlogin空表 解除用户锁定 mysql gt delete from pre common failedlogin Query OK 1 row affected 0 02 sec 解除UC用户锁定 mysq
  • C/C++学习记录--double和float的区别

    单精度浮点数 float 与双精度浮点数 double 的区别如下 1 在内存中占有的字节数不同 单精度浮点数在机内占4个字节 双精度浮点数在机内占8个字节 2 有效数字位数不同 单精度浮点数有效数字8位 双精度浮点数有效数字16位 3 所
  • HertzBeat监控部署及使用

    易用友好的高性能监控告警系统 网站监测 PING连通性 端口可用性 数据库监控 API监控 自定义监控 阈值告警 告警通知 邮件微信钉钉飞书 安装部署 HertzBeat最少依赖于 关系型数据库MYSQL8 实际亲测用mysql5 7 也行
  • 用java.util.Timer定时执行任务

    用java util Timer定时执行任务 如果要在程序中定时执行任务 可以使用java util Timer这个类实现 使用Timer类需要一个继承了java util TimerTask的类 TimerTask是一个虚类 需要实现它的
  • 基于spring+struts2+hibernate实现的Java web论坛

    源码及论文下载 源码及论文下载 http www byamd xyz tag java 1 绪论 这次的实训项目是开发一个java论坛系统 而开发java论坛系统的目的是提供一个供java学习交流的平台 为Java程序员提供交流经验 探讨问
  • angular的form表单验证

    angular的form表单验证 注 基于日常工作的总结 只是基础用法 一般的情况下应该是够用了 本次总结是angular的响应式表单验证 不足之处欢迎指正 首先要在你的页面组件的ts文件中引入angular的表单模块 import For
  • ESP8266 无线wifi AT 指令操作详解

    分享一下 ESP8266 无线wifi AT 的常见指令操作详解 按照官方说明整理 如有问题请私信 再次修改 指令集分为 基础 AT 命令 Wifi 功能 AT 命令 TCP IP 工具箱 AT 命令等 指令分类 测试命令 该命令用于查询设
  • Qt关闭子线程时程序崩溃及解决

    在Qt关闭子线程时 一般使用quit 函数和wait 函数关闭子线程 但可能关闭子线程时 子线程正在接受信号工作 因此 需要在子线程工作之前使用while 工作 进行判断 同时在关闭线程的按钮中需要设置flag的布尔值 问题 点击关闭按钮的
  • Unity3D C#数学系列之点积

    文章目录 1 定义 2 几何意义 3 向量a 向量b xaxb yayb zazb 4 应用案例 4 1 求两向量的夹角 4 2 判断两向量是否垂直 4 3 判断NPC是否在攻击范围内 4 4 已知入射光线和表面法线求反射光线 5 项目 1
  • 期货反向跟单小资金适合做吗?

    反向交易得到了越来越多人的青睐 但我们对其依然停留在一个很朦胧的阶段 仿佛雾里看花 一看三不知 或许是听别人一说 或许哪里留意过 但是真正的去实践 去落地 反而不知从哪里下手了 需要做什么品种 招多少盘手 用多少资金 模拟多久 培训多久 等
  • C、C++、Qt类型转换总结

    一 C类型转换 转换格式如下 Type b Type a 二 C 类型转换 1 const cast 去掉类型的const或volatile属性 const int a 10 a 20 compile error int b const c
  • Ubunt文件压缩和解压、打包和解包

    Ubunt文件压缩和解压 打包和解包 一 压缩和解压 zip tar gz tar bz2 1 zip 优点 支持不同的操作系统平台 如Linux Windows Mac OS 缺点 支持的压缩率不是很高 压缩 zip r file nam
  • 最全Mac&Win软件分享

    由于诸多因素影响 无法再分享相关的资料 如果无法访问GitHub的话大家可以去搜一下 GitHub加速 直接搜索找到相关的解决方案即可 包含常用的所有软件以及在线工具等等 GitHub地址 other doc Tools at main c
  • 微信小程序审核需要多久?微信小程序审核时间加快至2小时!

    8月15日起 微信将上线小程序全新审核机制 为第三方服务商的代码提审铺设 快车道 以往 小程序审核更像是 单车道 同一个第三方 同一时间审核大批量的小程序 也只能一一排队等候通过 8月15日起 平台将上线第三方预检加速机制 同一时间大批量提
  • PYTHON 编写 识别图片中两个峰值的代码

    Python 编写用于识别图片中的两个峰值的代码的方法有很多种 主要可以使用 OpenCV 和 NumPy 等库来实现 具体的代码可以参考网上的一些文章 例如 https www geeksforgeeks org python detec
  • Linux高性能服务器编程 学习笔记 第二章 IP协议详解

    本章从两方面探讨IP协议 1 IP头部信息 IP头部出现在每个IP数据报中 用于指定IP通信的源端IP地址 目的端IP地址 指导IP分片和重组 指定部分通信行为 2 IP数据报的路由和转发 IP数据报的路由和转发发生在除目标机器外的所有主机
  • msvcp140.dll重新安装的解决方法

    在打开游戏或者软件的时候 电脑提示msvcp140 dll丢失无法运行需要怎么办 相信这个问题困扰着不少小伙伴 msvcp140 dll是Windows系统中非常重要的动态连接组件 是连接程序与系统的必不可少的文件 小编今天就把重新安装的解
  • Java中long的表达式问题

    今天在代码里发现了有个抛错 是由下面这段分片上传时定位的代码捕获的 第一想法是是不是由于包太大 6 4G 导致long的offset超限 虽然long好像没有这么短 然后查了下long的最大值Long MAX VALUE 2的63次方 1
  • python监听端口获取数据_python从网络端口读取文本数据

    python从网络端口读取文本数据 To test it with netcat start the script and execute echo Hello cat ncat exe 127 0 0 1 12345 import soc
  • Pytorch中实现CPU和GPU之间的切换

    如何在pytorch中指定CPU和GPU进行训练 以及cpu和gpu之间切换 由CPU切换到GPU 要修改的几个地方 网络模型 损失函数 数据 输入 标注 创建网络模型 tudui Tudui if torch cuda is availa