pytorch使用早停策略

2023-11-01

早停的目的与流程

目的:防止模型过拟合,由于深度学习模型可以无限迭代下去,因此希望在即将过拟合时、或训练效果微乎其微时停止训练。
在这里插入图片描述
在这里插入图片描述

流程如下:

  1. 将数据集切分为三部分:训练数据(数据量最多),验证数据(数据量最少,一般10%-20%左右即可),测试数据(数据量第二多)
  2. 模型通过训练集,得到训练集的 L o s s t r a i n Loss_{train} Losstrain
  3. 然后模型通过验证集,此时不是训练,不需要反向传播,得到验证集的 L o s s v a l i d Loss_{valid} Lossvalid
  4. 早停策略通过 L o s s t r a i n Loss_{train} Losstrain L o s s v a l i d Loss_{valid} Lossvalid来判断,是否需要中断训练

早停策略

早停策略,我们都是拿着验证集训练集来说事:

  1. 常用的策略:

    ♣ 如果训练集loss与验证集loss连续几次下降不明显,就早停
    ♣ 验证集loss连续n次不降反升则早停。(通常是3次)

  2. 根据泛化损失卡阈值的策略

    ♣ 将目前已有的验证集的最小loss记录下来,看当前的验证集loss与最小的loss之间的差距
    ♣ 通过公式: G L ( t ) = 100 ⋅ ( E v a ( t ) E o p t ( t ) − 1 ) {GL(t)} = 100 \cdot \big( \frac{E_{va}(t)}{E_{opt}(t)} - 1) GL(t)=100(Eopt(t)Eva(t)1)计算一个值,并称之为泛化损失
    ♣ 当这个泛化损失超过阈值的时候停止训练

  3. 根据度量进展卡阈值的策略:我们通常假设过拟合会出现在训练集loss很难下降的时候,此时模型继续强行下降loss会导致过拟合的风险,因此,

    ♣ 定一个迭代周期,为训练k次,判断本次迭代的时候平均训练loss比最小训练loss大多少
    (公式: P k ( t ) = 1000 ⋅ ( ∑ t ′ = t − k + 1 t E t r ( t ′ ) k ⋅ m i n t ′ = t − k + 1 t E t r ( t ′ ) − 1 ) P_k(t) = 1000 \cdot \big( \frac{ \sum_{t' = t-k+1}^t E_{tr}(t') }{ k \cdot min_{t' = t-k+1}^t E_{tr}(t') } -1 \big) Pk(t)=1000(kmint=tk+1tEtr(t)t=tk+1tEtr(t)1)
    ♣ 然后结合上面的泛化损失,计算 G L ( t ) P k ( t ) \frac{GL(t)}{P_k(t)} Pk(t)GL(t)
    ♣ 当这个值大于一个阈值时,停止训练

pytorch使用示例

我们参考https://github.com/Bjarten/early-stopping-pytorch这个项目的早停策略

EarlyStopping类在:https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py

结合深度学习的示例如下:

import torch
import torch.nn as nn
import os
from sklearn.datasets import make_regression
from torch.utils.data import Dataset, DataLoader
import numpy as np


class EarlyStopping: # 这个是别人写的工具类,大家可以把它放到别的地方
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(
                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


class MyDataSet(Dataset):  # 定义数据格式
    def __init__(self, train_x, train_y, sample):
        self.train_x = train_x
        self.train_y = train_y
        self._len = sample

    def __getitem__(self, item: int):
        return self.train_x[item], self.train_y[item]

    def __len__(self):
        return self._len


def get_data():
    """构造数据"""
    sample = 20000
    data_x, data_y = make_regression(n_samples=sample, n_features=100)  # 生成数据集
    train_data_x = data_x[:int(sample * 0.8)]
    train_data_y = data_y[:int(sample * 0.8)]
    valid_data_x = data_x[int(sample * 0.8):]
    valid_data_y = data_y[int(sample * 0.8):]
    train_loader = DataLoader(MyDataSet(train_data_x, train_data_y, len(train_data_x)), batch_size=10)
    valid_loader = DataLoader(MyDataSet(valid_data_x, valid_data_y, len(valid_data_x)), batch_size=10)
    return train_loader, valid_loader


class LinearRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)  # 输入的个数,输出的个数

    def forward(self, x):
        out = self.linear(x)
        return out


def main():
    train_loader, valid_loader = get_data()
    model = LinearRegressionModel(input_dim=100, output_dim=1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()
    early_stopping = EarlyStopping(patience=4, verbose=True)  # 早停

    # 开始训练模型
    for epoch in range(1000):
        # 正常的训练
        print("迭代第{}次".format(epoch))
        model.train()
        train_loss_list = []
        for train_x, train_y in train_loader:
            optimizer.zero_grad()
            outputs = model(train_x.float())
            loss = criterion(outputs.flatten(), train_y.float())
            loss.backward()
            train_loss_list.append(loss.item())
            optimizer.step()
        print("训练loss:{}".format(np.average(train_loss_list)))
        # 早停策略判断
        model.eval()
        with torch.no_grad():
            valid_loss_list = []
            for valid_x, valid_y in valid_loader:
                outputs = model(valid_x.float())
                loss = criterion(outputs.flatten(), valid_y.float())
                valid_loss_list.append(loss.item())
            avg_valid_loss = np.average(valid_loss_list)
            print("验证集loss:{}".format(avg_valid_loss))
            early_stopping(avg_valid_loss, model)
            if early_stopping.early_stop:
                print("此时早停!")
                break


if __name__ == '__main__':
    main()

参考网站

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch使用早停策略 的相关文章

随机推荐

  • VS2019 带参数启动调式程序

    有一些程序需要通过argv传入相应的参数 这类程序在vs中debug的方式为 参数是在 工程属性 调试 命令参数 设置的 默认是空 没有参数 想要加上 如下图所示
  • Anaconda常用命令小结

    在服务器上面使用Anaconda 经常会涉及一些命令 在各种任务之间徘徊 有时会忘记一些命令行 经常也会百度查一下 不如做一个小结 给自己蹭一下流量 1 检验是否安装以及当前conda的版本 conda V 2 创建python虚拟环境 c
  • Tomcat8 WEB-INF更改class后 用manager实现reload

    Tomcat8 更改WEB INF 下的class后 用自带manager实现reload 解决了困扰很长时间的问题 1 适用场景 2 解决思路之一 用tomcat自带manager工具 3 详细步骤 4 官网文档地址 5 不适用情况 解决
  • Vue键盘事件

    键盘事件
  • Fastjson

    Fastjson 是一个java类库 可以被用来把Java对象转换成Json方式 也可以把Json字符串转换成对应的Java对象 Fastjson可以作用于任何Java对象 包含没有源代码已存在的对象 目标 在服务器端或是adroid提供一
  • AtCoder Beginner Contest 286—C—Rotate and Palindrome

    题目链接 题意 给我一个长度为N的字符串S 你可以执行以下两种操作 0次或者更多次 以任意的顺序执行 1 支付A元 移动S的最左边的字符到最右边 换句话说 就是 S1 S2 Sn gt S2 Sn S1 2 支付B元 在 1 n之间选择一个
  • 前端难点,坑点总结

    问题总结 前言 登录验证码图片显示 post下载文件 js调用ie浏览器的打印功能打印网页上内容 移动端适配不同的屏幕 ie9浏览器异步上传文件 ie浏览器 input标签会出现ie自带叉号 使用flex布局 文字超出部分变省略号 移动端1
  • 初级测试开发工程师应该学些什么

    作为一个毕业半年的我来说 换了两份工作 现在在游戏公司做测试开发工程师 也就不到两个月吧 之前在学校学了C C 数据结构 算法设计等 但也只是考试过了 还是菜鸟一枚 然后来到公司 有做一些兼容性测试之类的 前一个星期给我一个星期做一个网页爬
  • sbt使用教程

    sbt使用教程 sbt 配置 sbt 单项目构建 sbt 多项目构建 sbt 配置定义 sbt 任务定义 sbt 作用域 sbt 插件 总结 项目地址 https gitee com jyq 18792721831 studyspark g
  • UBUNTU 18.04 安装CUDA 10.1 (解决循环登入的问题)

    我之前安装CUDA 会导致重启后卡在登入页面 查询了很多资料后 终于安装成功了 以下记录了我的安装过程 0 安装gcc和make sudo apt get install gcc sudo apt get install make 1 禁用
  • Oracle12报错:ERROR at line 1: ORA-01109: database not open

    描述 想要修改用户密码的时候发现报错 ERROR at line 1 ORA 01109 database not open 解决 发现当前容器的模式为MOUNTED 将其open即可 SQL gt select con id name o
  • IDEA中使用Debug教程

    Debug用来追踪代码的运行流程 通常在程序运行过程中出现异常 启用Debug模式可以分析定位异常发生的位置 以及在运行过程中参数的变化 通常我们也可以启用Debug模式来跟踪代码的运行流程去学习三方框架的源码 一 Debug开篇 首先看下
  • 【Linux专栏】Linux 常用文件管理命令(常用命令大全)

    个人博客 https blog csdn net Newin2020 spm 1011 2415 3001 5343 专栏定位 为 0 基础刚入门 Linux 的小伙伴整理的详细笔记 也欢迎大佬们一起交流 专栏地址 https blog c
  • 运算符重载、模板、标准模板库STL

    C day4 运算符重载 当我们要对自己定义的数据类型进行运算的时候 编译器识别不了 所以没法进行 这时就需要我们自己来写对应的运算符计算的规则 运算符对应的操作数有几个 重载完之后操作数的个数是不能发生变化 重载的方式 1 成员函数进行重
  • 失业在家做什么赚钱好?失业在家怎么赚钱?

    在目前经济形势下 由于诸多客观因素的影响 导致很多人失业在家 无事可做 这样就会导致一个家庭陷入生活困境 面临这种情况 一个人失业在家 可以做什么赚钱呢 1 直播 现在直播经济那么火 很多人也开始各种直播 最省事的就是直播睡觉 当然 也可以
  • js日期的格式化

    function formatDate cellValue 传入毫秒数 if cellValue null cellValue return var date new Date cellValue var year date getFull
  • MyEclipse安装JRebel插件实现热部署

    为什么要使用JRebel 之前用MyEclipse做Java Web开发的时候 有一个很头疼的问题 每次修改后台代码之后 都需要重新将项目部署到tomcat 然后启动tomcat重新运行项目才能查看修改后的结果 浪费不少时间 现在 给MyE
  • MySQL触发器怎么写?

    废话不多说 这篇文章主要讲 从0 到写两个简单的触发器 3分钟学会 工具 Navicat Premium 黄色的三叶草图标 触发器1 BEGIN IF new state in 2 3 then INSERT INTO userservic
  • Django学习 day1

    目录 Django简介 HTTP原理 Django简介 Python语言里最流行 强大的Web框架 同时亦是全球第5大WEB框架 可快速构建稳定强大的WEB项目 大大提高开发效率 很多知名项目都是基于Django开发 如Disqus Pin
  • pytorch使用早停策略

    文章目录 早停的目的与流程 早停策略 pytorch使用示例 参考网站 早停的目的与流程 目的 防止模型过拟合 由于深度学习模型可以无限迭代下去 因此希望在即将过拟合时 或训练效果微乎其微时停止训练 流程如下 将数据集切分为三部分 训练数据