1.pytorch lightning之验证与测试

2023-11-11

训练

训练部分已在《入门篇》介绍。

验证集和测试集中评估模型

通常将数据集分为三部分,train/val/test,val集在训练时评估模型的泛化性,选择其中表现最好的checkpoint。test集只在模型训练完成后使用,用于评估模型的真实性能。

添加test流程

划分数据集

以下代码使用torchvision包内实现的MNIST。如果使用自定义的数据集,先用pytorch实现Dataset子类,再继承pl.LightningDataModule类,实现相应接口。

import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms

# Load data sets
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)

实现test_step()接口

在trainer.test()阶段会自动调用test_step方法,根据需要内部可以增加保存图片、评估模型等功能。

class LitAutoEncoder(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def test_step(self, batch, batch_idx):
        # this is the test loop
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)

测试

模型训练完成后,即可调用test()方法进入测试流程

from torch.utils.data import DataLoader

# initialize the Trainer
trainer = Trainer()

# 训练模型
trainer.fit(model, data)

# 训练完成后测试
trainer.test(model, dataloaders=DataLoader(test_set))

验证阶段validation的流程

与test 流程类似,实现validation_step()接口,可以配合on_validation_epoch_end()方法在计算所有样例后评估模型。

class LitAutoEncoder(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)
        self.metric.update(x_hat, y)# metric是任务相关的评价方法,比如更新混淆矩阵
    
    def on_validation_epoch_step(self, batch, batch_idx):
        
        # 从混淆矩阵中计算tp,fp, tn, fn, acc, F1等指标 
        score = self.metric.get_scores()
        # 记录,横坐标为epoch
        self.log('val/F1', score['F1'], logger=True, on_epoch=True)

预测predict流程

实现predict_step方法,然后调用trainer.predict()

其它HOOK见LightningModule,了解LightningModule的接口基本就会用pl了。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

1.pytorch lightning之验证与测试 的相关文章

随机推荐

  • BI数据系统的设计流程

    BI大数据产品 数据管理平台可以通过报表或者BI模块来搭建 在专栏 帆软数据应用研究院 里有关于企业数据管理和BI报表平台建设的案例 站在项目实施的角度 可以从技术和业务两个层面来考虑 前期进行需求调研 罗列了一张建设思路图 技术上需要考虑
  • 量化投资学习-30:股性与人性,从傅里叶变换谈谈股市大V的操作风格的观察

    1 名家的操作风格的差异 2 方波的傅里叶变换 徐小明 1次基频率 冯矿伟 3次谐波 东风红 5 7次谐波
  • 把多层次的 XML 文档解析为 TreeView 显示

    XML 文档是一个有多层树形节点的文档 因为节点数不确定 所以要跟踪每个节点 需要用到递归 肉眼阅读 XML 比较累 需要去对付一堆的尖括号 用 Delphi 程序把它显示为一个 TreeView 的树结构 比较容易用眼睛去看 以下是我的代
  • 对比学习MocoV1

    对比学习 希望模型能分辨哪些图片类似 哪些图片不类似 即类似的图片特征空间拉近 不类似的拉远 可以设计不同的代理任务提供监督信号 代理任务例子 个体判别 Xi经过两种Ti变成两张不一样的照片 为正样本 其他都是负样本 损失 NCE loss
  • 项目8—八位数码管动态显示(包含程序化简)

    利用74HC573芯片 74HC573的八个锁存器都是透明的D型锁存器 当使能 G 为高时 Q输出将随数据 D 输入而变 当使能为低时 输出将锁存在已建立的数据电平上 输出控制不影响锁存器的内部工作 即老数据可以保持 甚至当输出被关闭时 新
  • centos sftp配置

    SFTP 即 SSH 文件传输协议 或者说是安全文件传输协议 通过SSH端口加密传输 但是 由于这种传输方式使用了加密 解密技术 所以传输效率比普通的FTP要低得多 SFTP的优势在于SSH软件包中包含SFTP 无需额外安装 iptable
  • ES6中const详解

    我们使用const声明常量时 总认为值一旦声明就不可改变 然后我发现在定义对象时 对象的值是可以改变的 对于数值 var message Hello let age 25 以下两行都会报错 const message Goodbye con
  • 在计算机网络中 主机及主机上运行的程序,防灾科技学院网络协议分析复习题...

    一 单项选择题 1 为了提高传输效率 TCP通常采用 A 三次握手法 B 窗口控制机制 C 自动重发机制 D 端口机制 2 某校园网的地址是202 100 192 0 18 要把该网络分成30个子网 则子网掩码应该是 A 255 255 2
  • 李超树(无脑秒斜率)

    文章目录 一 What 二 How 1 插入 update 2 查询 query 三 板题 JSOI2008 Blue Mary开公司 https www luogu com cn problem P4254 五 高端操作 1 动态开点 2
  • delphi中常见错误提示说明

    Delphi的中文错误提示 not allowed before ELSE ElSE前不允许有 clause not allowed in OLE automation section 在OLE自动区段不允许 子句 is not a typ
  • URL中的转义字符

    URL中的转义字符 当URL的参数中出现诸如 空格 等特殊字符串符号时 因为上述字符有特殊含义 导致服务器端无法正确解析参数 如何处理 解决办法 将这些字符转化成服务器可以识别的字符 如果要在URL中传递特殊符号的原本意义 要对他们进行编码
  • chromium 之 webui 调用逻辑

    chromium之webui详细文档参考 WebUI Explainer googlesource com 本文主要讲述webui的调用逻辑 webui webui 用于管理chrome浏览器 通过 chrome url 的方式 可以进行不
  • K8s知识点梳理

    1 k8s是一个编排容器的工具 其实也是管理应用的全生命周期的一个工具 从创建应用 应用的部署 应用提供服务 扩容缩容应用 应用更新 都非常的方便 而且可以做到故障自愈 例如一个服务器 挂了 可以自动将这个服务器上的服务调度到另一个主机上进
  • Java之对象比较

    目录 1 同一性比较 2 相等性比较 3 需要比较对象之间的大小关系 3 1 Comparable接口比较 两个对象 3 1 1 已经实现了Comparable接口的类 例如下面的String类 3 1 2 自己定义类并实现Comparab
  • Apollo原理

    Apollo原理 https github com ctripcorp apollo wiki Apollo E9 85 8D E7 BD AE E4 B8 AD E5 BF 83 E8 AE BE E8 AE A1 提交就是 提交给客户端
  • QT中UDPSocket丢包问题

    1 配置和编程 下位机向上位机发送UDP数据包 由于UDP小包不能写太大 每个小包也就1kB左右 下位机周期性地发送数据 每个周期发送数百个udp包 并且是使用while死循环来发送的 上位机使用QUdpSocket类接收UDP数据 采用信
  • 利用eclipse比较两个文件的代码差异或者一个文件不同版本之间的异同

    1 比较两个文件之间的代码差异 选中两个文件 右键选择Compare With 再选择Each Other即可 2 比较一个文件不同版本之间的差异 选中文件 右键选择team 选择显示资源历史记录 然后从历史记录中选择需要比较的版本 两个文
  • moment.js的使用方法和日期格式化介绍

    文章目录 1 node js 2 使用方法 日期格式化介绍 fromNow 相对时间 startOf 时间开头 endOf 时间末尾 subtract 时间减 add 时间加 获取星期几 moment 被设计为在浏览器和 Node js 中
  • Java多线程程序设计初步

    Java多线程程序设计初步 在Java语言产生前 传统的程序设计语言的程序同一时刻只能单任务操作 效率非常低 例如程序往往在接收数据输入时发生阻塞 只有等到程序获得数据后才能继续运行 随着Internet的迅猛发展 这种状况越来越不能让人们
  • 1.pytorch lightning之验证与测试

    训练 训练部分已在 入门篇 介绍 验证集和测试集中评估模型 通常将数据集分为三部分 train val test val集在训练时评估模型的泛化性 选择其中表现最好的checkpoint test集只在模型训练完成后使用 用于评估模型的真实