Pytorch-Lightning基本方法介绍

2023-11-16

LIGHTNINGMODULE

LightningModule将PyTorch代码整理成5个部分:

  • Computations (init).
  • Train loop (training_step)
  • Validation loop (validation_step)
  • Test loop (test_step)
  • Optimizers (configure_optimizers)

Minimal Example

所需要的方法:

import pytorch_lightning as pl
class LitModel(pl.LightningModule):

     def __init__(self):
         super().__init__()
         self.l1 = torch.nn.Linear(28 * 28, 10)

     def forward(self, x):
         return torch.relu(self.l1(x.view(x.size(0), -1)))

     def training_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self(x)
         loss = F.cross_entropy(y_hat, y)
         return loss

     def configure_optimizers(self):
         return torch.optim.Adam(self.parameters(), lr=0.02)

使用下面的代码进行训练:

train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
trainer = pl.Trainer()
model = LitModel()

trainer.fit(model, train_loader)

一些基本方法

Training
Training loop

使用training_step方法来增加training loop

class LitClassifier(pl.LightningModule):

     def __init__(self, model):
         super().__init__()
         self.model = model

     def training_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self.model(x)
         loss = F.cross_entropy(y_hat, y)
         return loss

如果需要在epoch-level进行度量,并进行记录,可以使用*.log*方法

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)

    # logs metrics for each training_step,
    # and the average across the epoch, to the progress bar and logger
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss

如果需要对每个training_step的输出做一些操作,可以通过改写training_epoch_end来实现

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    preds = ...
    return {'loss': loss, 'other_stuff': preds}

def training_epoch_end(self, training_step_outputs):
   for pred in training_step_outputs:
       # do something

如果需要对每个batch分配到不同GPU上进行训练,可以采用training_step_end方法来实现

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return {'loss': loss, 'pred': pred}

def training_step_end(self, batch_parts):
    gpu_0_prediction = batch_parts.pred[0]['pred']
    gpu_1_prediction = batch_parts.pred[1]['pred']

    # do something with both outputs
    return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2

def training_epoch_end(self, training_step_outputs):
   for out in training_step_outputs:
       # do something with preds
Validation loop

增加一个validation loop,可以通过改写LightningModule中的validation_step来实现

class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('val_loss', loss)

对validation进行epoch-level度量,可以通过改写validation_epoch_end实现

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred =  ...
    return pred

def validation_epoch_end(self, validation_step_outputs):
   for pred in validation_step_outputs:
       # do something with a pred

如果需要validation进行数据并行计算(多GPU),可以通过validation_step_end方法实现

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return {'loss': loss, 'pred': pred}

def validation_step_end(self, batch_parts):
    gpu_0_prediction = batch_parts.pred[0]['pred']
    gpu_1_prediction = batch_parts.pred[1]['pred']

    # do something with both outputs
    return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2

def validation_epoch_end(self, validation_step_outputs):
   for out in validation_step_outputs:
       # do something with preds
Test loop

增加一个test loop的过程和上面增加validation loop是相同的,唯一不同的是,只有在使用*.test()*的时候,test loop才会被调用

model = Model()
trainer = Trainer()
trainer.fit()

# automatically loads the best weights for you
trainer.test(model)

这里,有两种方式调用test():

# call after training
trainer = Trainer()
trainer.fit(model)

# automatically auto-loads the best weights
trainer.test(test_dataloaders=test_dataloader)

# or call with pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
trainer.test(model, test_dataloaders=test_dataloader)	
Inference

对于研究,LightningModules像系统一样结构化

import pytorch_lightning as pl
import torch
from torch import nn

class Autoencoder(pl.LightningModule):

     def __init__(self, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))

     def training_step(self, batch, batch_idx):
        x, _ = batch

        # encode
        x = x.view(x.size(0), -1)
        z = self.encoder(x)

        # decode
        recons = self.decoder(z)

        # reconstruction
        reconstruction_loss = nn.functional.mse_loss(recons, x)
        return reconstruction_loss

     def validation_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        recons = self.decoder(z)
        reconstruction_loss = nn.functional.mse_loss(recons, x)
        self.log('val_reconstruction', reconstruction_loss)

     def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0002)

可以用如下方式训练

autoencoder = Autoencoder()
trainer = pl.Trainer(gpus=1)
trainer.fit(autoencoder, train_dataloader, val_dataloader)

lightning inference部分的方法:

  • training_step
  • validation_step
  • test_step
  • configure_optimizers

注意到在这个例子中,train loop和val loop完全相同,我们可以重复使用这部分代码

class Autoencoder(pl.LightningModule):

     def __init__(self, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))

     def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)

        return loss

     def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log('val_loss', loss)

     def shared_step(self, batch):
        x, _ = batch

        # encode
        x = x.view(x.size(0), -1)
        z = self.encoder(x)

        # decode
        recons = self.decoder(z)

        # loss
        return nn.functional.mse_loss(recons, x)

     def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0002)

注:我们创建了所有loop都可以使用的一个新方法shared_step,这个方法的名字可以任意取

Inference in research

如果需要进行系统推断,可以将forward方法加入到LightningModule中

class Autoencoder(pl.LightningModule):
    def forward(self, x):
        return self.decoder(x)

在复杂系统中增加forward的优势,使得可以进行包含inference procedure等

class Seq2Seq(pl.LightningModule):

    def forward(self, x):
        embeddings = self(x)
        hidden_states = self.encoder(embeddings)
        for h in hidden_states:
            # decode
            ...
        return decoded
Inference in production

在LightningModule中迭代不同的模型

import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM

class ClassificationTask(pl.LightningModule):

     def __init__(self, model):
         super().__init__()
         self.model = model

     def training_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self.model(x)
         loss = F.cross_entropy(y_hat, y)
         return loss

     def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        acc = FM.accuracy(y_hat, y)

        # loss is tensor. The Checkpoint Callback is monitoring 'checkpoint_on'
        metrics = {'val_acc': acc, 'val_loss': loss}
        self.log_dict(metrics)
        return metrics

     def test_step(self, batch, batch_idx):
        metrics = self.validation_step(batch, batch_idx)
        metrics = {'test_acc': metrics['val_acc'], 'test_loss': metrics['val_loss']}
        self.log_dict(metrics)

     def configure_optimizers(self):
         return torch.optim.Adam(self.model.parameters(), lr=0.02)

然后将任意适合该task的模型传进去

for model in [resnet50(), vgg16(), BidirectionalRNN()]:
    task = ClassificationTask(model)

    trainer = Trainer(gpus=2)
    trainer.fit(task, train_dataloader, val_dataloader)

tasks可以任意复杂,比如,可以实现GAN训练,self-supervised,甚至RL

class GANTask(pl.LightningModule):

     def __init__(self, generator, discriminator):
         super().__init__()
         self.generator = generator
         self.discriminator = discriminator
     ...

del)

trainer = Trainer(gpus=2)
trainer.fit(task, train_dataloader, val_dataloader)

tasks可以任意复杂,比如,可以实现GAN训练,self-supervised,甚至RL

```python
class GANTask(pl.LightningModule):

     def __init__(self, generator, discriminator):
         super().__init__()
         self.generator = generator
         self.discriminator = discriminator
     ...

LightningModule API(略)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch-Lightning基本方法介绍 的相关文章

随机推荐

  • JSP include能包含html页面吗?

    转自 JSP include能包含html页面吗 jsp简介 JSP全称是Java Server Pages 是一种动态网页技术 JSP其实就是在html中插入了java代码和JSP标签之后形成的文件 文件名以 jsp结尾 其实JSP就是一
  • 输入网址后,会经历哪几个步骤

    1 面试官问输入网址后 会经历哪几个步骤 DNS HTTPS TCP 就知道这两个 DNS解析 TCP连接 发送http请求 HTTP请求报文的方法是 get 如果浏览器存储了该域名下的 Cookies 那么会把 Cookies放入 HTT
  • 协议数据处理流程

    数据处理流程 总体流程 数据放入缓冲 PushToComFIFO RecBuffer BufLen 从数据缓冲中解包协议格式 读缓冲 GetDataFromComFIFO ComStr 从数据缓冲中解包协议格式 协议格式解析 Get XXX
  • python实验报告实验总结_python还能干这事

    上文提到python可以干很多事 很多时候生活中的很多问题都可以用代码解决 尤其是那些反复重复的事 今天就拿读研的时候的一个例子给大家说说 如何用代码解决生活中的问题 问题 导师带了3个班的图形学 100多号人 期末了 平时成绩已经出来了
  • web常见的攻击方式有哪些,以及如何进行防御?

    一 是什么 Web攻击 WebAttack 是针对用户上网行为或网站服务器等设备进行攻击的行为 如植入恶意代码 修改网站权限 获取网站用户隐私信息等等 Web应用程序的安全性是任何基于Web业务的重要组成部分 确保Web应用程序安全十分重要
  • react组件中设置多个className

    错误写法
  • c++下的文件批量读写——查找文件的类 struct _finddata_t结构体用法

    查找文件的类 struct finddata t结构体用法 https blog csdn net yang332233 article details 53081785 但是运行原链接的代码时在while findnext handle
  • Android APP的安装路径

    小Tips app安装在哪个路径 2021 6 10更新 1 安装路径共五个 system app 系统自带的应用程序 无法删除 root后可以删除 system priv app 比system app 中的应用权限更加高 如Launch
  • DC/DC和LDO的区别是什么?以及如何选择?

    LDO是线性电源 DC DC是开关电源 SMPS 是两种不同种类电源 工作原理也不相同 开关电源和线性电源的区别 开关电源 SMPS 和低压差线性稳压电源 LDO 从模型理解原理 电源技术与新能源 面包板社区 LDO DC DC如何选型 L
  • DB2约束

    清单 1 查询数据库目录以判断哪些数据库列可为空 db2 select tabname colname nulls from syscat columns where tabschema MELNYK and nulls N 仅单独存在 惟
  • 告别BeanUtils,Mapstruct从入门到精通

    如果你现在还在使用BeanUtils 看了本文 也会像我一样 从此改用Mapstruct 对象之间的属性拷贝 之前用的是Spring的BeanUtils 有一次 在学习领域驱动设计的时候 看了一位大佬的文章 他在文章中提到使用Mapstru
  • LSB(Least Significant Bit)和MSB(Most Significant Bit)

    LSB Least Significant Bit 意为最低有效位 MSB Most Significant Bit 意为最高有效位 若MSB 1 则表示数据为负值 若MSB 0 则表示数据为正 MSB高位前导 LSB低位前导 谈到字节序的
  • MVC架构

    10 MVC 什么是MVC Model view Controller 模型视图控制器 10 1 以前的架构 用户可以直接访问控制层 控制层可以直接操作数据库 Servlet gt CURD gt 数据库 弊端 程序十分臃肿 不利于维护 S
  • hiveSql 重分组聚合问题

    hiveSql 重分组聚合问题 问题 分析 实现 最后 问题 将下图中A表转变为B和C 即A gt B A gt C 分析 1 首先看A gt B 可见是将name列分组 取最大组内最大id 介绍两种求解方式 1 很容易想到 开窗函数fir
  • html使用iframe包含pdf文件,HTML embedded PDF iframe

    It s downloaded probably because there is not Adobe Reader plug in installed In this case IE it doesn t matter which ver
  • 【数据架构系列-06】一文搞懂数据模型的3种类型——概念模型、逻辑模型、物理模型

    数据模型就是模拟现实世界的方法论 是通向智慧世界的基石 从现实世界发展到智慧世界 要数经历现实世界 信息世界 计算机世界 数据世界 智慧世界五个不同的世界 我们天生具有从混沌的世界抽象信息变为信息世界的能力 但是到另外几个世界需要我们懂得计
  • spring的自动装配即装配的各种模式

    Spring的自动装配 无须在Spring配置文件中描述javabean之间的依赖关系 IOC容器会自动建立JavaBean之间的关联关系 根据属性名称自动装配autowire byName 根据数据类型自动装配autowire byTyp
  • 完整安装datax-web教程

    1 安装mysql5 7 a 创建目录下载安装rpm包 mkdir p opt software cd opt software wget i c http dev mysql com get mysql57 community relea
  • 【c++复习笔记】——智能指针详细解析(智能指针的使用,原理分析)

    个人主页 努力学习的少年 版权 本文由 努力学习的少年 原创 在CSDN首发 需要转载请联系博主 如果文章对你有帮助 欢迎关注 点赞 收藏 一键三连 和订阅专栏哦 目录 一 智能指针的基本概念 二 智能指针的定义和使用 三 auto ptr
  • Pytorch-Lightning基本方法介绍

    文章目录 LIGHTNINGMODULE Minimal Example 一些基本方法 Training Training loop Validation loop Test loop Inference Inference in rese