PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)

2023-11-13

训练模型时,在众多训练好的模型中会有几个较好的模型,我们希望储存这些模型对应的参数值,避免后续难以训练出更好的结果,同时也方便我们复现这些模型,用于之后的研究。PyTorch提供了模型的保存与重载模块,包括torch.save()和torch.load(),以及pytorchtools中的EarlyStopping,这个模块就是用来解决上述的模型保存与重载问题

一、保存与重载模块

若希望保存/加载模型model的参数,而不保存/加载模型的结构,可以通过如下代码

其中state_dict是torch中的一个字典对象,将每一层与该层的对应参数张量建立映射关系

若希望同时保存/加载模型model的参数以及模型结构,而不保存/加载模型的结构,可以通过如下代码

为了获取性能良好的神经网络,训练网络的过程中需要进行许多对于模型各部分的设置,也就是超参数的调整。超参数之一就是训练周期(epoch),训练周期如果取值过小可能会导致欠拟合,取值过大可能会导致过拟合。为了避免训练周期设置不合适影响模型效果,EarlyStopping应运而生。EarlyStopping解决epoch需要手动设定的问题,也可以认为是一种避免网络发生过拟合的正则化方法 

EarlyStopping的原理可以大致分为三个部分:

将原数据分为训练集和验证集;

只在训练集上进行训练,并每隔一个周期计算模型在验证集上的误差,如果随着周期的增加,在验证集上的测试误差也在增加,则停止训练;

将停止之后的权重作为网络的最终参数

初始化 early_stopping 对象:

EarlyStopping 对象的初始化包括三个参数,其含义如下:

patience(int) : 上次验证集损失值改善后等待几个epoch,默认值:7。

verbose(bool):如果值为True,为每个验证集损失值打印一条信息;若为False,则不打印,默认值:False。

delta(float):损失函数值改善的最小变化,当损失函数值的改善大于该值时,将会保存模型,默认值:0,即损失函数只要有改善即保存模型 

 定义一个函数,表示训练函数,希望通过 EarlyStopping 当测试集上的损失值有所下降时,将此时的信息打印出来,并且保存参数。 先创建将要用到的变量,以及初始化 earlystopping 对象

之后训练模型并保存损失值,计算每次迭代在训练集和测试集上的损失值得均值,并保存

 

调用 EarlyStopping 中的_call_()模块,判断损失值是否下降,若下降则进行保存,并打印信息

最后调用torch.load()加载最后一次的保存点,即最优模型,并返回模型,以及每轮迭代在训练集、测试集上的损失值的均值

 

二、可视化模块

在模型训练过程中,有时不仅需要保持和加载已经训练好的模型,也需要将训练过程中的训练集损失函数、验证集损失函数、模型计算图(即模型框架图、模型数据流图)等保持下来,供后续分析作图使用

例如,通过损失函数变化情况,可以观察模型是否收敛,通过模型计算图,可以观察数据流动情况等

Tensorboard可以将数据、模型计算图等进行可视化,会自动获取最新的数据信息,将其存入日志文件中,并且会在日志文件中更新信息,运行数据或模型最新的状态。Tensorboard中常用的模块包括如下七类

add_graph():添加网络结构图,将计算图可视化。

add_image()/add_images():添加单个图像数据/批量添加图像数据。

add_figure():添加matplotlib图片。

add_scalar()/add_scalars():添加一个标量/批量添加标量,在机器学习中可用于绘制损失函数。

add_histogram():添加统计分布直方图。

add_pr_curve():添加P-R(精准率-召回率)曲线。  

add_txt():添加文字

Tensorboard的整体用法,参见下图 

 

 TensorBoard中可以使用add_graph()函数保存模型计算图,该函数用于在tensorboard中创建存放网络结构的Graphs,函数及其参数如下:

model(torch.nn.Module) 表示需要可视化的网络模型;

input_to_model(torch.Tensor or list of torch.Tensor)表示模型的输入变量,如果模型输入为多个变量,则用list或元组按顺序传入多个变量即可;

verbose(bool)为开关语句,控制是否在控制台中打印输出网络的图形结构 

例如,有一个数据类型为torch.nn.Module的变量model,输入的张量为input1和input2,期望返回模型计算图,则可以输入如下代码,即可在SummaryWriter的日志文件夹中保存数据流图

 PyTorch中SummaryWriter的输出文件夹一般为runs文件,保存的日志文件不可以直接双击打开,需要在cmd命令窗口中将目录导航到runs文件夹的上一级目录,并输入tensorboard –logdir runs即可打开日志文件,打开后复制链接到浏览器中,即可打开保存的模型计算图或数据变量等 

TensorBoard中可以使用add_scalar()/add_scalars()函数保存一个或在一张图中保存多个常量,如训练损失函数值、测试损失函数值、或将训练损失函数值和测试损失函数值保存在一张图中。

add_scalar()函数及参数如下:

  

tag(string)为数据标识符;

scalar_value(float or string)为标量值,即希望保存的数值;

global_step(int)为全局步长值,可理解为x轴坐标 

 add_scalars()函数及参数如下:

main_tag(string)为主标识符,即tag的父级名称;

tag_scalar_dict(dict)为保存tag及tag对应的值的字典类型数据;

global_step(int)为全局步长值,可理解为x轴坐标。 

add_scalars()可以批量添加标量,例如,绘制y=xsinx、y=xcosx、y=tanx的图像,可以输入如下代码,保存的日志文件打开方式与上文所述相同

 

 

创作不易 觉得有帮助请点赞关注收藏~~~ 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码) 的相关文章

随机推荐

  • background-position设置百分比“失效了”!!

    1 background position设置百分比的计算原理 当指定百分比值的时候 实际上执行了以下的计算公式 该公式可以用数学方式定义图片和容器相对位置重合 这是当background size auto 时 百分比有效 contain
  • 组装一台1u服务器

    1 服务器的内存条都是带校验功能的 2 服务器cpu可以多个 一般电脑只有一个
  • NAS个人云存储服务器搭建

    NAS Network Attached Storage 网络附属存储 通过网络提供数据访问服务 可以理解为长时间连网的存储设备 其功能基本和市面上的各种云盘相似 它以数据为中心 将存储设备与服务器彻底分离 集中管理数据 从而释放带家 提高
  • layui 时间选择器laydate开始时间结束时间限制,及设置默认时分秒

    这个时间控件实现了 开始时间和结束时间都不得超过当前时间 结束时间大于开始时间并且小于当前时间 开始时间默认时分秒为00 00 00 结束时间默认时分秒为23 59 59 其他官方自带功能 项目中的需求是 结束时间要大于开始时间 包括时分秒
  • 全网页截图教程,如何截图截全屏

    系统自身的截屏快捷键 台式键盘的电脑 全屏 Ctrl Print Screen 当前窗口 Alt Print Screen 笔记本截图快捷键 FN Prt sc 浏览器自带的 非常好用 在浏览器打开要截取全网页为图片的那个网页 打开那个网页
  • Unity—常用API(续Time类)

    今天整理了Time类 一张很有意思的理解API的图 此图灵感来源于 如何理解API API 是如何工作的 仁杰兄的博客 CSDN博客 api 目录 Time 练习 使用Text制作倒计时预制体 在Update每帧执行的方法中 个别语句实现指
  • ubuntu学习(四)----文件写入操作编程

    1 write函数的详解 ssize t write int fd const void buf size t count 参数说明 fd 是文件描述符 write所对应的是写 即就是1 buf 通常是一个字符串 需要写入的字符串 coun
  • 浅析芝麻信用分征信体系

    在互联网大佬中 能将未雨绸缪之功力炼至极致的非马云莫属了 其所扔下的每一个棋子都会前思三招 后推九步 每一步商业布局都像是一枚运气极佳的的种子总能找到温润肥沃的土壤而后破土而发 十年遮风 百年纳凉 马老板曾苦心孤诣费尽心机地想把淘宝由一家广
  • 教你自己搭建一个ip池(绝对超好用!!!!)

    随着我们爬虫的速度越来越快 很多时候 有人发现 数据爬不了啦 打印出来一看 不返回数据 而且还甩一句话 是不是很熟悉啊 要想想看 人是怎么访问网站的 发请求 对 那么就会带有 request headers 那么当你疯狂请求别人的网站时候
  • .NET/C#下html转PDF,PDF加水印,PDF转图片

    一 添加OpenHtmlToPdf iTextSharp O2S Components PDFRender4NET引用 1 OpenHtmlToPdf 是一个 NET库 开源的 用于将HTML文档呈现为PDF格式 github地址 http
  • hibernate derby 配置文件

    hibernate cfg xml
  • Python数据分析——上海市二手房价格分析

    自学数据分析与机器学习已有两月 近期房价问题引人深思 即兴做个上海市房价的数据分析小项目 上网一查上海市新楼盘价格 高的不忍直视 索性退而求其次 分析上海二手房的价格 一 数据收集 常规做法是编写网络爬虫程序 爬取相关网站的数据信息 捷径是
  • 如何把glb格式模型gltf格式模型导入3dmax和C4D,U3D,UE4这些主流软件中

    咱有时候去glbxz com添加链接描述 官网下载免费glb格式模型 gltf模型下载时候是没有通用格式 例如fbx obj 这个时候3dmax和C4D直接打开导入是不行的 也可以制作glb模型 扣扣 424081801 这个时候 咱们用
  • AI工具汇总

    大家好 我是可夫小子 关注AIGC 读书和自媒体 ChatGPT已经火了这么久 我也写不了少玩ChatGPT的方法 昨天OpenAI又推出了苹果手机的APP 我也介绍下载和安装的攻略 但根据读者反馈 仍然还是有许多同学没能用上 今天我就把我
  • VS2015远程编译Linux项目

    用VS2015开发Linux程序详细教程 配置篇 crazytea的博客 CSDN博客 linux 程序开发VS2015推出了跨平台开发 其中包括了对Linux程序开发的支持 最近刚好需要开发Linux程序 对其进行了一些研究 首先介绍下涉
  • redis的安装与配置

    第一章 redis 1 1redis的概述 1 2关系型数据库与非关系数据库 1 3关系型数据库和非关系型数据库区别 1 4redis优点与缺点 第二章redis的安装 2 1 YUM安装 2 2下载编译安装 2 2 1关闭防火墙 2 2
  • element左侧导航栏el-menu,菜单栏文字超出不折行问题

    在CSS样式中加上这些样式就可以了 el submenu title display flex align items center el submenu title span white space normal word break b
  • C# 面向对象05 StringBuilder的用法

    好处 相比普通的 string处理 提高了字符串的处理速度 注意点 使用时需要使用对象的方式 StringBuilder world new StringBuilder using System using System Diagnosti
  • SimpleDateFormat模式字符串格式

    SimpleDateFormat模式字符串 new SimpleDateFormat String parm parm为一个字符串 表示格式 时间模式 字母 时间元素 表示 示例 y 年 Year 1996 96 M 年中的月份 Month
  • PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)

    训练模型时 在众多训练好的模型中会有几个较好的模型 我们希望储存这些模型对应的参数值 避免后续难以训练出更好的结果 同时也方便我们复现这些模型 用于之后的研究 PyTorch提供了模型的保存与重载模块 包括torch save 和torch