PyTorch框架中使用早停止Early Stopping(含详细代码)

2023-10-29

1.什么是早停止?为什么使用早停止?

早停止(Early Stopping)是 当达到某种或某些条件时,认为模型已经收敛,结束模型训练,保存现有模型的一种手段

机器学习或深度学习中,有很大一批算法是依靠梯度下降,求来优化模型的。是通过更新参数,让Loss往小的方向走,来优化模型的。可参考BP神经网络推导过程详解

关于模型何时收敛(模型训练好了,性能达到要求了或不能再优化了),此时我们可以采取一些判断标准:

1.验证集上的Loss在模型多次迭代后,没有下降
2.验证集上的Loss开始上升
3.验证集上的准确率在模型多次迭代后,没有上升
3.验证集上的准确率开始下降
……
这时,一般可以认为,模型没必要再训练了,可以及时结束训练了。这就被称为早停止,也是避免模型过拟合的一种方法(不等模型拟合,就结束训练了)。

2.如何使用早停止?

early_stopping.py

import numpy as np
import torch
import os

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, save_path, patience=7, verbose=False, delta=0):
        """
        Args:
            save_path : 模型保存文件夹
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.save_path = save_path
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        path = os.path.join(self.save_path, 'best_network.pth')
        torch.save(model.state_dict(), path)	# 这里会存储迄今最优模型的参数
        self.val_loss_min = val_loss

把该文件拷贝到自己项目中,
在需要使用早停止的文件中,导入:

from early_stopping import EarlyStopping

使用示例(大致代码):

train_losses = []
train_acces = []
# 用数组保存每一轮迭代中,在测试数据上测试的损失值和精确度,也是为了通过画图展示出来。
eval_losses = []
eval_acces = []

save_path = ".\\" #当前目录下
early_stopping = EarlyStopping(save_path)

for e in range(20000):


    # 4.1==========================训练模式==========================
    train_loss = 0
    train_acc = 0
    model.train()   # 将模型改为训练模式

    # 每次迭代都是处理一个小批量的数据,batch_size是64
    for im, label in train_data:
        im = Variable(im)
        targets = Variable(label)

        # 计算前向传播,并且得到损失函数的值
        outputs = model(im)
        loss = criterion(outputs, targets)

        #add by tyb

        #model.save_metrics(metrics)
        # 反向传播,记得要把上一次的梯度清0,反向传播,并且step更新相应的参数。
        optimizer.zero_grad()

        loss.backward()
        optimizer.step()
        #scheduler.step()

        # 记录误差
        train_loss += loss.item()

        # 计算分类的准确率
        out_t = outputs.argmax(dim=1) #取出预测的最大值
        num_correct = (out_t == targets).sum().item()
        acc = num_correct / im.shape[0]
        train_acc += acc

    train_losses.append(train_loss / len(train_data))
    train_acces.append(train_acc / len(train_data))



    # 4.2==========================每次进行完一个训练迭代,就去测试一把看看此时的效果==========================
    # 在测试集上检验效果
    eval_loss = 0
    eval_acc = 0

    model.eval()  # 将模型改为预测模式



    # 每次迭代都是处理一个小批量的数据,batch_size是128
    for im, label in test_data:

        #print("test_data len:",len(test_data))
        im = Variable(im)  # torch中训练需要将其封装即Variable,此处封装像素即784
        label = Variable(label)  # 此处为标签

        out = model(im)  # 经网络输出的结果
        loss = criterion(out, label)  # 得到误差

        # 记录误差
        eval_loss += loss.item()

        # 记录准确率
        out_t = out.argmax(dim=1)  # 取出预测的最大值的索引
        num_correct = (out_t == label).sum().item()  # 判断是否预测正确
        acc = num_correct / im.shape[0]  # 计算准确率
        eval_acc += acc

    eval_losses.append(eval_loss / len(test_data))
    eval_acces.append(eval_acc / len(test_data))
    #scheduler.step()

    print('epoch: {}, Train Loss: {:.6f}, Train Acc: {:.6f}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
          .format(e, train_loss / len(train_data), train_acc / len(train_data),
                  eval_loss / len(test_data), eval_acc / len(test_data)))

     
    # 早停止
    early_stopping(eval_loss, model)
    #达到早停止条件时,early_stop会被置为True
    if early_stopping.early_stop:
        print("Early stopping")
        break #跳出迭代,结束训练

未用早停止:训练集和验证集上的accuracy和loss曲线
在这里插入图片描述
使用早停止:训练集和验证集上的accuracy和loss曲线
在这里插入图片描述

3. Refferences

  1. 在 Pytorch 中实现 early stopping
  2. 线性代数及其应用——“早停止”与“L2正则”的关系
  3. BP神经网络推导过程详解
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch框架中使用早停止Early Stopping(含详细代码) 的相关文章

随机推荐

  • Qt自绘控件之扇形统计图

    首先绘制区域扇形需要先注意一下几点 QPainter中绘制完整的圆等于5760 16 360 此处数值用于计算每一块扇形区域所显示的 需要了解一下扇形二等分线的计算方法 要注意做坐标原点转换 此处为屏幕分辨率自适应 const qreal
  • Keil 中出现“encountered an improper argument” 解决办法

    Keil 中出现 encountered an improper argument 解决办法 出现这种情况就是因为目录文件下带有中文路径 不要弄成中文路径就可以解决了
  • HyperledgerFabric资产案例链码实例

    案例分析 功能 用户开户和销户 资产登记 资产上链 与具体的用户绑定 资产转让 资产所有权变更 查询功能 用户查询 资产查询 资产变更的历史查询 业务实体 用户 名字 身份证 标识 资产列表 资产 名字 标识 特殊属性列表 车 排量 品牌
  • linux基础操作命令符(下)

    linux基础操作命令符 上 linux笔记查询 关于用户 用户的管理 用户组的管理 权限的管理 SSH 解决等待缓存 无法获得锁问题 关于ping命令 ssh 远程连接 ssh远程拷贝的命令 查看linux本地配置 查看磁盘文件目录 df
  • 亚马逊+纽约大学开源图神经网络框架DGL:新手友好,与主流框架无缝衔接

    量子位 授权转载 公众号 QbitAI 最近 纽约大学 纽约大学上海分校 AWS上海研究院以及AWS MXNet Science Team共同开源了一个面向图神经网络及图机器学习的全新框架 命名为Deep Graph Library DGL
  • 1800*D. Nested Segments(数组数组&&离散化)

    解析 按照右端点进行排序 这样某个区间包含的区间只能是在其前面的区间中 所以维护左端点 x 的出现次数 这样我们在查询某个区间 x y 的时候 只需要求 x y 之间包含多少个前面区间的 x 即可 前缀和 因为 前面区间的 y 显然小于当前
  • 微信小程序——常用组件的属性介绍

    常用的组件内容标签 text 文本组件 类似于HTML中的span标签 是一个行内元素 rich text 富文本标签 支持把HTML字符串渲染为WXML结构 text标签的基本使用 通过text组件的selectable属性 实现长按选中
  • C++中的左值与右值(二

    C 中的左值与右值 二 以前以为自己把左值和右值已经弄清楚了 果然发现自己还是太年轻了 下面的这些东西是自己通过在网上拾人牙慧 加上自己的理解写的 1 2 怎么区分左值和右值 知乎大神 顾露的回答 3 我们不能直接定义一个引用的引用 但是
  • ts重点学习85-map类型

  • Idea更新新版本报错,Some conflicts were found in the installation area.

    笔者使用的Idea是2021 2版本 今天直接升级发现报错 找了下解决方法 供大家参考 升级过程 请添加图片描述 https img blog csdnimg cn eaa75e5af7d243d8a2a3f8a731feb6c1 png
  • 【计算机视觉

    文章目录 一 检测相关 8篇 1 1 Explainable Cost Sensitive Deep Neural Networks for Brain Tumor Detection from Brain MRI Images consi
  • 基于java的出租车预约网站

    出租车预约网站能够有效的解决大家上班下班打不到车 加快吃饭逛街的效率 天阴下雨无障碍出行 自己有车不舍得开等问题 使得用户查询车辆信息更加方面快捷 同时便于管理员对车辆和用户的管理 从而给出租车管理公司的预约管理工作带来更高的效率 因此 我
  • CentOS 6 时间,时区,设置修改及时间同步

    一 时区 显示时区 date help 获取帮助 date R date z 上面两个命令都可 root localhost date R date z
  • 数据清洗基本概念

    1 基本概念 数据清洗从名字上也看的出就是把 脏 的 洗掉 指发现并纠正数据文件中可识别的错误的最后一道程序 包括检查数据一致性 处理无效值和缺失值等 因为数据仓库中的数据是面向某一主题的数据的集合 这些数据从多个业务系统中抽取而来而且包含
  • 从零实现一个在线考试系统

    晚上好 我是老北 公众号 GitHub 指北 会推荐 GitHub 上有用有趣的项目 挖掘开源的价值 欢迎关注 基于 SpringBoot Mybatis Plus Shiro mysql redis 构建的智慧云智能教育平台 架构上使用完
  • 前端交互设计利器--MVVM框架avalon.js

    前端交互设计 少不了使用js框架 特别是近来非常火爆的MVVM框架 MVVM框架的确是前端交互设计的利器 最近接触到国内大牛编写的前端框架 avalon js 功能强大 兼容性好 非常好用 MVVM框架核心思想是模型数据与视图绑定 改变了模
  • AI聊天机器人,你更爱哪个?

    嗨 各位同学 最近这几个人工智能助手可是火得很啊 叮咚 AI哥们儿ChatGPT已经很强了 轻松应对各种问题 文笔挺不错的 咻 Anthropic公司的Claude也很给力 聊天能力十分强大 嗖 Google新出的Bard看着也很厉害 刚一
  • 中国自主可控计算机大会、,2019CCF自主可控计算机大会召开

    nbsp nbsp nbsp nbsp光明网讯 齐柳明 7月23日至24日 2019CCF自主可控计算机大会 在北京召开 会议以 应用驱动 协同创新 自主可控发展的源泉和动力 为主题 大会在目前自主可控计算机发展态势良好基础上 针对相关信息
  • css 选择除了某个类下的所有某种元素

    要求 选择除了某个类下的所有input输入框 非页脚下的输入框高度 input not bs table foot input height 40px important line height 40px important
  • PyTorch框架中使用早停止Early Stopping(含详细代码)

    文章目录 1 什么是早停止 为什么使用早停止 2 如何使用早停止 3 Refferences 1 什么是早停止 为什么使用早停止 早停止 Early Stopping 是 当达到某种或某些条件时 认为模型已经收敛 结束模型训练 保存现有模型