基于Pytorch的模型推理

2023-11-08

训练部分说明

假设我们现在有两个文件
{
first_file: train.py #用于训练模型
second_file: inference.py#用于推理检测
}
在train.py文件中我们使用了定义了一个类,里面声明了我的网络模型,例如。

class Net(nn.Moudle):
........

假设在train.py文件中,我们处理图片是用以下的代码

transforms = transforms.Compose([transforms.Resize((224, 224)),
                                            transforms.ToTensor(),
                                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
                                          )

然后我们最终保存权重文件是通过以下方式。

torch.save(model.state_dict,save_path)#save_path是自定义的路径,注意要加上自定义的文件名。
如:D:/my_file/first_moedl.pth  #其中first_model.pth是保存之后的名字

推理部分实现

我们需要优先导入网络框架,也就是上文提到的Net类

from train import Net

然后还有一些常用的功能包

from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms

接着便开始代码的编写了

def prediect(device):
	net = Net()#实例化net
    net.load_state_dict(torch.load('D:/my_file/first_model.pth'))#加载模型
    net = net.to(device)#同样用GPU
    torch.no_grad()
    for i in range(1000):
        img = Image.open("存放图片路径")
        image = transforms(img).unsqueeze(0)#由于训练的时候还有一个参数,是batch_size,而推理的时候没有,所以我们为了保持维度统一,就得使用.unsqueeze(0)来拓展维度
        image = image.to(device)#同样将图片数据放入cuda(GPU)中
        output = net(image)
        _, pre = torch.max(output, 1)#拿出最高置信度的结果

        print(pre) #打印出结果

if __name__ == '__main__':
	#图片格式的转化,和训练时一样
	transforms = transforms.Compose([transforms.Resize((15, 15)),
	                                            transforms.ToTensor(),
	                                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
	                                          )
	#我训练的时候用的是GPU,所以这里也一样
    device = torch.device("cuda")
    #开始预测输出
    prediect(device)

当然,在很多时候我们并不会把训练和推理放在两个文件夹,也不会直接打印出它的结果,也是在一个文件夹中实现了训练、测试、推理三个部分,其中我们通常会把推理之后的结果比上正确的结果,来查看准确率。但实现过程都是换汤不换药的。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

基于Pytorch的模型推理 的相关文章

随机推荐

  • python中使用apscheduler二步简单完成定时任务设置,用于自动化任务的创建,无人值守后台任务创建

    一 apscheduler的安装 首先需要安装pip 打开CMD输入pip install apscheduler 安装apscheduler模块 安装过程如下图 二 导入apscheduler包 设置参数与需要执行的脚本 coding u
  • Pytorch从0实现Transformer

    文章目录 摘要 一 构造数据 1 1 句子长度 1 2 生成句子 1 3 生成字典 1 4 得到向量化的句子 该阶段总程序 二 位置编码 2 1 计算括号内的值 2 2 得到位置编码 三 多头注意力 3 1 self mask 摘要 Wit
  • Elasticsearch笔记(七):聚合查询

    聚合框架有助于根据搜索查询提供聚合数据 聚合查询是数据库中重要的功能特性 ES作为搜索引擎兼数据库 同样提供了强大的聚合分析能力 它基于查询条件来对数据进行分桶 计算的方法 有点类似于 SQL 中的 group by 再加一些函数方法的操作
  • 高薪全栈工程师必备 Linux 基础

    https mp weixin qq com s biz MzI0MTQwMTMyOQ tempkey OTkzX0xtOTVOZkJQbjVQSnhQaWdFcU5pTXZiZ3BvRW5DaDNiaGg5MXJDdGVCSTdkSlFU
  • 【转载】TCP的seq和ack号计算方法

    seq和ack号存在于TCP报文段的首部中 seq是序号 ack是确认号 大小均为4字节 注意与大写的ACK不同 ACK是6个控制位之一 大小只有一位 仅当 ACK 1 时ack字段才有效 建立 TCP 连接后 所有报文段都必须把 ACK
  • Kotlin 集合框架

    集合概述 Kotlin 标准库提供了一整套用于管理集合的工具 集合是可变数量 可能为零 的一组条目 各种集合对于解决问题都具有重要意义 并且经常用到 集合通常包含相同类型的一些 数目也可以为零 对象 集合中的对象称为元素或条目 例如 一个系
  • Linux 网桥实现分析

    第一部份 源码框架 一 网桥原理 传统的中继器 如HUB 是一个单纯的物理层设备 它将每一个收到的数据包 在其所有的端口上广播 由接收主机来判断这个数据包是否是给自己的 这样 网络资源被极大的浪费掉了 网桥之所以不同于中继器 主要在于其除了
  • kubernetes详解

    kubernetes详解 1 kuberenetes简介 1 1什么是kubernetes 1 2 Kubernetes发展史 1 3 为什么要使用kubernetes 1 4 Kubernetes 特点 1 5 kubernetes特性
  • Class.forName()用法简介说明

    转自 Class forName 用法简介说明 下文笔者讲述Class forName 方法的功能简介说明 如下所示 class对象简介说明 class对象用于表示类 每一个类在JVM中都对应一个class对象 jvm中将使用class对象
  • Keil5 STM32 软件仿真错误

    error 65 access violation at 0x40021000 no read permission 这个错误是Keil引起的 应该是没有识别出芯片的型号 我是在RT Thread OS 仿真运行的时候发现的 这里的仿真芯片
  • ipynb转markdown

    不知道为什么ipynb里面没法显示图片 为了阅读体验 没办法只能转为markdown了 jupyter nbconvert to markdown 特征处理 ipynb
  • 8800个机器学习开源项目为你精选TOP30!

    授权自AI科技大本营 ID rgznai100 本文共图文结合 建议阅读5分钟 本文为大家带来了30个广受好评的机器学习开源项目 最近 Mybridge发布了一篇文章 对比了过去一年中机器学习领域大约8800个开源项目后 选出30个2017
  • chatGPT身份指令

    充当 Linux 终端 我想让你充当 Linux 终端 我将输入命令 您将回复终端应显示的内容 我希望您只在一个唯一的代码块内回复终端输出 而不是其他任何内容 不要写解释 除非我指示您这样做 否则不要键入命令 当我需要用英语告诉你一些事情时
  • 静态变量与静态函数

    堆与栈 1 栈区 stack 由编译器自动分配释放 存放函数的参数值 局部变量的值等 操作是类似于数据结构中的栈 2 堆区 heap 一般有程序员分配和释放 动态存储分配 分配方式类似于链表 3 全局区 static 全局变量和静态变量的存
  • debian ubuntu 设置DNS 永久设置 重启系统不会丢失

    debian ubuntu 设置DNS 永久设置 重启系统不会丢失 1 debian ubuntu 设置DNS 快捷步骤 2 下面是命令解释 2 1 决定系统dns的文件是 etc resolv conf 2 2 谁能最终影响 etc re
  • 封神台——Cookie伪造目标权限(存储型XSS)

    点击传送门看到的是一个留言板 我们首先要判断是否存在XSS 于是输入一串JS代码 看是否会弹出一个内容为 zkaq 的弹窗 出现了 说明存在XSS漏洞 关于XSS漏洞的科普如下 跨站脚本攻击是指恶意攻击者往Web页面里插入恶意Script代
  • loadtxt()读取数据类型转换/string转换float/ValueError: could not convert string to float:

    实验数据样式 Test csv 只显示几行 0 589469 5 000059 0 480721 0 000204 0 000204 12 945284 4 999956 9 671936 0 000145 0 000145 9 70103
  • 每天一个小技巧之Bash Shell Debug

    sh x xxxx sh
  • 多态的定义及原理

    一 多态的概念和定义 1 多态的概念 多态 Polymorphism 同字面意思意为多种形态 本质就是不同对象完成同一行为产生的不同结果 2 多态的构成条件 多态是在不同继承关系的类对象 去调用同一函数 产生了不同的行为 在继承中要构成多态
  • 基于Pytorch的模型推理

    训练部分说明 假设我们现在有两个文件 first file train py 用于训练模型 second file inference py 用于推理检测 在train py文件中我们使用了定义了一个类 里面声明了我的网络模型 例如 cla