PyTorch-06 过拟合&欠拟合、Train-Val-Test划分、Regularization减轻防止overfitting、动量与学习率衰减、其他技巧Early Stop,Dropout

2023-11-16

PyTorch-06 (过拟合&欠拟合、Train-Val-Test划分、Regularization减轻防止overfitting、动量与学习率衰减、其他技巧Tricks(Early Stop,Dropout))

一、过拟合&欠拟合

讨论过拟合和欠拟合之前,先了解一下数据真实的模态(即数据的真实分布):Pr(x)

讨论一下过拟合和欠拟合:

场景1 Scenario1: house price 房价
是一个线性模型:
在这里插入图片描述
场景2 Scenario2: GPA
非线性模型:
在这里插入图片描述
首先,对于上面两个图,我们知道真实的分布吗?The ground-truth distribution?
▪ That’s perfect if known 如果知道那就完美了
▪ However 我们是不知道的。

其次,另一个因素Another factor,噪声noise
在这里插入图片描述

对模型本身进行度量Let’s assume

对于高次方越高,则模型抖动越大,波形越复杂。
在这里插入图片描述

衡量不同模型的学习能力 Mismatch: ground-truth VS estimated

model capacity模型容量:
在这里插入图片描述
对于常数模型来说,学习能力非常的弱。
次方越高模型所表达的能力越强。
在这里插入图片描述

案例Case1: Estimated < Ground-truth(Underfitting)

Estimated 模型本身的capacity,即模型的表达能力。
当模型的表达能力<真实模型的复杂度,出现under-fitting的情况。
这种时候会造成我们模型的表达能力不够。
在这里插入图片描述
for example :WGAN
WGAN早期版本也是增加了一个约束,将模型的复杂度降低下来。
在这里插入图片描述
Underfitting:
发现不论是训练数据还是测试数据的loss、acc表现都不是很好的时候,尝试将模型复杂度增加一些(比如说堆叠更多的层数、每一层的单元数量会增加)。通过这种方式增加模型的复杂度后,查看Underfitting这种情况是不是有所改变。
在这里插入图片描述

案例Case2: Ground-truth < Estimated(Overfitting)

使用模型的复杂度>真实模型的复杂度
这样的情况在训练的时候,模型会尝试将模型每一个点的loss都降低。这样模型会逼近每一个点。这种情况会使得train的结果非常的好,但是test的结果会比较差。
在这里插入图片描述
Overfitting:
▪ train loss and acc. is much better 训练模型的loss和acc都非常好。
▪ test acc. is worse 但是训练模型的loss和acc会很差。
▪ => Generalization Performance 泛化能力效果。
当Overfitting很严重的情况,泛化能力会很差。

本节总结

在这里插入图片描述
在现实生活中,大多数情况都是overfitting,因为现在计算机的计算能力很强,优化的网络复杂度会变得非常非常深,这样很容易网络的表达能力超过了现实模型的能力(数据集足够多,就不会overfitting,如果数据集有限的话,因为包含了噪声,这样就很容易overfitting),因此我们需要如何检测overfitting,并且降低overfitting的情况,下节会讲。

二、Train-Val-Test划分(交叉验证)

如何检测Overfitting

首先我们将数据划分为Train dataset 和 Test dataset
在这里插入图片描述
划分了之后,在train dataset进行training训练,在training的过程中就会去学习pattern,train dataset和test dataset都是来自同一个数据集,所以他们的真实分布肯定是一样的,当我们在train dataset上学习到了一个分布情况以后,我们要检测是不是overfitting,就需要用train dataset训练好的模型对test dataset进行loss和acc的检测,发现在train dataset表现很好,test dataset表现很差就是overfitting了。

源数据集的划分 splitting:
在这里插入图片描述

For example

在这里插入图片描述

另一种常用的对源数据的划分情况 splitting:数据集划分为三部

常用的划分splitting情况是划分三部分:新增了validation dataset,这时test dataset就不在是对模型进行挑选了。将原来test dataset换成了validation dataset。将挑选模型的功能给了validation dataset,而test dataset是交给客户,客户在验收的时候查看模型性能的表现。所以在Kaggle这样的比赛,主办方是不会将test set提供给我们的。
在这里插入图片描述
在这里插入图片描述

数据集划分为三部分如何划分:train-val-test

一般来说,现有的数据集只提供两个划分,一个是train一个是test。
通过参数train=True划分为train dataset,train=False划分为test dataset。
在这里插入图片描述
之后我们在train dataset 的基础上进行人为的划分:
假设train_db有60k个sample样本,将60k的样本数据划分为50k和10k。其中50k依然是train_db,10k就是val_db,再加上之前数据默认划分的test_db,这样就划分好了三部分。
在这里插入图片描述
for example:使用了三个的划分

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transforms

batch_size=200
learning_rate=0.01
epochs=10

#将参数train=True
#获得train_db
train_db = datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
train_loader = torch.utils.data.DataLoader(
    train_db,
    batch_size=batch_size, shuffle=True)

#将参数train=False
#获得test_db
test_db = datasets.MNIST('../data', train=False, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
]))
test_loader = torch.utils.data.DataLoader(test_db,
    batch_size=batch_size, shuffle=True)


print('train:', len(train_db), 'test:', len(test_db))
#再将train_db划分为train_db和val_db
train_db, val_db = torch.utils.data.random_split(train_db, [50000, 10000])
print('db1:', len(train_db), 'db2:', len(val_db))
train_loader = torch.utils.data.DataLoader(
    train_db,
    batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(
    val_db,
    batch_size=batch_size, shuffle=True)


class MLP(nn.Module):

    def __init__(self):
        super(MLP, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 200),
            nn.LeakyReLU(inplace=True),
            nn.Linear(200, 10),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        x = self.model(x)

        return x

device = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss().to(device)

for epoch in range(epochs):

    #使用train_loader来做训练
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)
        data, target = data.to(device), target.cuda()

        logits = net(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        # print(w1.grad.norm(), w2.grad.norm())
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))


    #在总的epoch中添加监视是否overfitting,就使用val_loader数据集
    test_loss = 0
    correct = 0
    for data, target in val_loader:
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.cuda()
        logits = net(data)
        test_loss += criteon(logits, target).item()

        pred = logits.data.max(1)[1]
        correct += pred.eq(target.data).sum()

    test_loss /= len(val_loader.dataset)
    print('\nVAL set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)))


#test_loader这部分代码和val_loader的代码是一样的,只不过使用模型最好时所对应的参数
test_loss = 0
correct = 0
for data, target in test_loader:
    data = data.view(-1, 28 * 28)
    data, target = data.to(device), target.cuda()
    logits = net(data)
    test_loss += criteon(logits, target).item()

    pred = logits.data.max(1)[1]
    correct += pred.eq(target.data).sum()

test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

在这里插入图片描述
在这里插入图片描述

另一种数据划分的方式:交叉验证 K-fold cross-validation

首先test dataset是不能动的,之后第一次划分train dataset:
在train dataset基础上划分出了val set蓝色部分,和剩余train set部分。
在这里插入图片描述
第二次再切割划分:
黄色部分为切割划分后的train set,白色部分是val set。
在这里插入图片描述
在这里插入图片描述
这样做的好处是每个数据集都有可能成为validation dataset。

K-fold cross-validation就是将train dataset的数据集划分成k份,每次取k-1/k 做train,另外剩下的1份做val。这样做是提升了train dataset数据的利用度。

三、Regularization 减轻防止overfitting

Reduce Overfitting

▪ More data 获得更多的数据,但是代价最大。

▪ Constraint model complexity 约束模型的复杂性
▪ shallow 降低模型的表达能力
▪ regularization 本节课所讲解的内容

▪ Dropout 对于神经网络单元,按照一定的概率将其暂时从网络中丢弃。

▪ Data argumentation 做数据增强

▪ Early Stopping 提前停止

Regularization

累加不同cross entropy的成为最终的loss,如果给原来的loss添加一项。
在这里插入图片描述
让|θi|接近于0,即模型参数的范数接近于0,这里的范数可以是1范数,也可以是2范数或者无穷范数。 这样可以使得参数的范数接近于0,使得β0、β1、β2、β3…这些参数接近于0,因为其参数范数接近于零,可以使得高维项的参数接近于0,达到降维效果。为了保证模型的解释能力,因此前几个参数保持较大的数值,而后面的参数接近于0,即β0,β1,β2保持不变这样可以保证模型的解释能力,β3,β4…很小,使得高维特征网络几乎没有了。可以将下图的网络,退化成更小次方的网络。即退化成 y = β0+β1x+β2x^2 失去了高维,也保证了模型的性能,也降低了overfitting。
在这里插入图片描述
该过程叫Regularization,在pytorch或则其他文本中也被称作Weight Decay,Weight权值参数w,使得w接近于0,迫使w越来越接近于0,有衰减的意思Decay。
在这里插入图片描述
通过直观Intuition的角度来解释:右图加了一个L2-regularization
在这里插入图片描述

How 如何Regularization:

常用的是L2-regularization
此外lambda也是一个超参数,需要人为的进行调整。
在这里插入图片描述
在pytorch中如何做L2-regularization:

注意:没有overfitting,设置了L2-regularization,就会使得原有网络的复杂度降低,使得模型的性能急剧下降。
如果有overfitting,设置了weight decay ,该参数设置的好的话,该网络模型表现不会有大的影响,但是这个网络的test performance会有一定的提升。

在这里插入图片描述
在这里插入图片描述

四、动量与学习率衰减

动量 momentum

梯度更新的公式:
在这里插入图片描述
在这里插入图片描述
情况1,no momentum
在这里插入图片描述
情况2,with appr.momentum
添加了动量后,会更加偏向原来运动的方向,动量越大,偏向越多。
此外如果有动量,会有可能使原本停止的最小点被增加的原本的动量所冲出去,会有一定的惯性,使其达到新的最小值点,这个最小值点更加接近全局最小值点。
在这里插入图片描述
在这里插入图片描述
有一些优化器,如Adam是没有momentum这个参数的,这个Adam本身就是用momentum做了一些事情,所以不需要额外管理这个变量,内部已经设置好如何处理这个momentum了。只有最原始的SGD才没有负责处理momentum这个属性,需要我们人为来设置。
在这里插入图片描述

学习率衰减 learning rate decay

学习速率调整Learning rate tunning:

learning rate学习率越大,更新的幅度也就越大。
在这里插入图片描述

学习率衰减 Learning rate decay

学习率衰减目的是一开始学习率很大,后期学习率逐步减小。
在这里插入图片描述
在这里插入图片描述
学习率的衰减可以让模型找到更小的最小值点,可以将模型寻找最小值点的性能提升上来。
在这里插入图片描述

监听learning rate的方案:

1、方案1
在这里插入图片描述
在这里插入图片描述2、方案2:简单粗暴的
这里step_size=30设置的有点小,一般我们会设置为1k或10k。
在这里插入图片描述

五、其他技巧Tricks(Early Stop,Dropout)

Early Stopping

在这里插入图片描述
How-To
▪ Validation set to select parameters
▪ Monitor validation performance
▪ Stop at the highest val performance.

Dropout

不使用全部w,使用最少且表现最好的w数量,其他的w被抛弃。
▪ Learning less to learn better
▪ Each connection has

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch-06 过拟合&欠拟合、Train-Val-Test划分、Regularization减轻防止overfitting、动量与学习率衰减、其他技巧Early Stop,Dropout 的相关文章

  • Python 3 os.urandom

    在哪里可以找到完整的教程或文档os urandom 我需要获得一个随机 int 来从 80 个字符的字符串中选择一个字符 如果你只需要一个随机整数 你可以使用random randint a b 来自随机模块 http docs pytho
  • 如何在python 3.7中生成条形码

    我正在使用 python 3 7 为了生成条形码 我尝试使用安装 pyBarcode 库pip install pyBarcode 但它显示以下错误 找不到满足 pyBarcode 要求的版本 来自版本 找不到 pyBarcode 的匹配分
  • 使用 python 中的公式函数使从 Excel 中提取的值的百分比相等

    import xlrd numpy excel Users Bob Desktop wb1 xlrd open workbook excel assignment3 xlsx sh1 wb1 sheet by index 0 colA co
  • 如何检查python xlrd库中的excel文件是否有效

    有什么办法与xlrd库来检查您使用的文件是否是有效的 Excel 文件 我知道还有其他库可以检查文件头 我可以使用文件扩展名检查 但为了多平台性我想知道是否有任何我可以使用的功能xlrd库本身在尝试打开文件时可能会返回类似 false 的内
  • Kivy - 有所有颜色名称的列表吗?

    在 Kivy 中 小部件 color属性允许输入其值作为字符串颜色名称 也 例如在 kv file Label color red 是否有所有可能的颜色名称的列表 就在这里 来自Kivy 的文档 https kivy org doc sta
  • Python - 来自 .进口

    我第一次尝试图书馆 我注意到解决图书馆内导入问题的最简单方法是使用如下结构 from import x from some module import y 我觉得这件事有些 糟糕 也许只是因为我不记得经常看到它 尽管公平地说我还没有深入研究
  • Python3.0 - 标记化和取消标记化

    我正在使用类似于以下简化脚本的内容来解析较大文件中的 python 片段 import io import tokenize src foo bar src bytes src encode src io BytesIO src src l
  • 如何在 Python 中加密并在 Java 中解密?

    我正在尝试在 Python 程序中加密一些数据并将其保存 然后在 Java 程序中解密该数据 在Python中 我像这样加密它 from Crypto Cipher import AES KEY 1234567890123456789012
  • 使用 python 绘制正值小提琴图

    我发现小提琴图信息丰富且有用 我使用 python 库 seaborn 然而 当应用于正值时 它们几乎总是在低端显示负值 我发现这确实具有误导性 尤其是在处理现实数据集时 在seaborn的官方文档中https seaborn pydata
  • 使用Python计算目录的大小?

    在我重新发明这个特殊的轮子之前 有没有人有一个很好的例程来使用 Python 计算目录的大小 如果例程能够很好地以 Mb Gb 等格式格式化大小 那就太好了 这会遍历所有子目录 总结文件大小 import os def get size s
  • Matplotlib 中 x 轴标签的频率和旋转

    我在下面编写了一个简单的脚本来使用 matplotlib 生成图形 我想将 x tick 频率从每月增加到每周并轮换标签 我不知道从哪里开始 x 轴频率 我的旋转线产生错误 TypeError set xticks got an unexp
  • 使用“默认”环境变量启动新的子进程

    我正在编写一个构建脚本来解析依赖的共享库 及其共享库等 这些共享库在正常情况下是不存在的PATH环境变量 为了使构建过程正常工作 让编译器找到这些库 PATH已更改为包含这些库的目录 构建过程是这样的 加载器脚本 更改 PATH gt 基于
  • 返回表示每组内最大值的索引的一系列数字位置

    考虑一下这个系列 np random seed 3 1415 s pd Series np random rand 100 pd MultiIndex from product list ABDCE list abcde One Two T
  • Anaconda 无法导入 ssl 但 Python 可以

    Anaconda 3 Jupyter笔记本无法导入ssl 但使用Atom终端导入ssl没有问题 我尝试在 Jupyter 笔记本中导入 ssl 但出现以下错误 C ProgramData Anaconda3 lib ssl py in
  • 在系统托盘中隐藏 tkinter 窗口 [重复]

    这个问题在这里已经有答案了 我正在制作一个程序来提醒我朋友的生日 这样我就不会忘记祝福他们 为此 我制作了两个 tkinter 窗口 1 First one is for entering name and birth date 2 Sec
  • Elasticsearch 通过搜索返回拼音标记

    我用语音分析插件 https www elastic co guide en elasticsearch plugins current analysis phonetic html由于语音转换 从弹性搜索中进行一些字符串匹配 我的问题是
  • Django Admin 中的反向内联

    我有以下 2 个型号 现在我需要将模型 A 内联到模型 B 的页面上 模型 py class A models Model name models CharField max length 50 class B models Model n
  • Django 与谷歌图表

    我试图让谷歌图表显示在我的页面上 但我不知道如何将值从 django 视图传递到 javascript 以便我可以绘制图表 姜戈代码 array Year Sales Expenses 2004 1000 400 2005 1170 460
  • TKinter 中的禁用/启用按钮

    我正在尝试制作一个像开关一样的按钮 所以如果我单击禁用按钮 它将禁用 按钮 有效 如果我再次按下它 它将再次启用它 我尝试了 if else 之类的东西 但没有成功 这是一个例子 from tkinter import fenster Tk
  • 如何为不同操作系统/Python 版本编译 Python C/C++ 扩展?

    我注意到一些成熟的Python库已经为大多数架构 Win32 Win amd64 MacOS 和Python版本提供了预编译版本 针对不同环境交叉编译扩展的标准方法是什么 葡萄酒 虚拟机 众包 我们使用虚拟机和Hudson http hud

随机推荐

  • 传感器尺寸与像素密度对相片分辨率的影响

    在人们日常生活摄影中 相机的传感器尺寸以及像素素往往决定了一幅图像的清晰度 当然 不同的镜头 不同的CMOS质量等等都会对相片的质量产生影响 今天就简单讨论讨论传感器尺寸和像素密度对图像分辨率的影响 当传感器尺寸一定时 像素越多 也就是像素
  • Python-集合

    探索Python集合的奇妙世界 在Python编程中 集合 Set 是一种强大且有用的数据结构 它用于存储多个不重复的元素 集合的独特之处在于它的元素是无序的 并且每个元素都是唯一的 这使得集合在处理去重和进行快速成员检查时非常有效 创建集
  • 手把手带你打造自己的UI样式库(第五章)之常用页面切图的设计与开发

    常用页面切图的设计与开发 在一些大的前端团队中 前端工程师这个职位会出现一个分支 叫做重构工程师 重构工程师主要负责 HTML 和 CSS 的制作 也就是把设计稿转换成 HTML 和 CSS 代码 重构工作完成以后 把制作好的 HTML 和
  • 【第十四届蓝桥杯单片机组底层驱动测试】

    第十四届蓝桥杯单片机组底层驱动测试 下面分享的是第十四届蓝桥杯单片机组底层驱动代码的测试和相关说明 今年官方提供的资料包中底层驱动代码和以往有了变化 主要代码还是提供给了我们 只是此次没有了相关头文件iic h ds1302 onewire
  • win10剪贴板快捷键win+v

    win v可以出现最近10多次粘贴的数据
  • Ioc容器refresh总结(3)--- Spring源码从入门到精通(三十三)

    上篇文章介绍了 调用bean工厂的后置处理器 主要分为两步 他是在beanFactory预准备标准初始化之后执行invokBeanFactoryPostProcessor 先调用beanDefinitionRegistryPostProce
  • [paper] MTCNN

    MTCNN 论文全称 Joint Face Detection and Alignment using Multi task Cascaded Convolutional Networks 论文下载链接 https arxiv org ab
  • vue.js基础学习(模板语法)

    基础入门 vue js模板语法 1 模板语法 methods 给vue定义方法 this 指向当前vue实例 v html 让内容以HTML形式编译 v bind 绑定动态数据 v noce 当数据发生改变时 插值处内容不发生改变 动态属性
  • maven相关

    1 webxml attribute is required or pre existing WEB INF web xml if executing in update 原因 web项目下缺少 WEB INF web xml 在servl
  • 【AWS】API Gateway创建Rest API--从S3下载文件

    一 背景 在不给AK SK的前提下 用户查看s3上文件 从s3下载文件 二 创建API 1 打开API Gateway 点击创建API 选择REST API REST API和HTTP API区别 来自AWS官网 REST API 和 HT
  • 算法——查找

    文章目录 一 基本概念和评价 1 相关概念 2 查找表 2 1 常见操作 2 2 分类 3 查找算法的评价指标 二 线性结构查找 1 顺序查找算法 1 1 定义 1 2 算法思想 1 3 特点 1 4 分类 1 无哨兵的无序线性表的顺序查找
  • Unity 安卓报错 failed to extract resources needed by IL2cpp

    Unity打出来的包在自己的PC放置好文件后 运行能够正常运行 但是git提交之后 别的机器拉代码下来报错 failed to extract resources needed by IL2cpp 这里推测原因是 安卓包打出来的Asset
  • TikTok逆向,全球的小姐姐们,我来啦!

    作者 AYJk 链接 https juejin im post 5c19a38ae51d453e0a209256 开源地址 首先抛出GitHub地址吧 多多支持指点 谢谢 AYTikTokPod https github com AYJk
  • 知识索引目录

    author skate time 2012 11 22 存储 io系统的压力测试工具 fio http blog csdn net wyzxg article details 7454072 iozone使用 http blog csdn
  • 从零开发一套完整的react项目开发环境

    不管是工作需要还是面试加分 除了Vue相关技术以外 React技术栈也已经成为了前端开发工程师必备的技术点 接下来 我将从零开发一套完整的React全家桶项目开发环境 提供给需要的同行小伙伴观看也方便自己以后复习 篇幅很长 请需要的小伙伴耐
  • ZLMediaKit+wvp-GB28181-pro 安装文档

    文章目录 前言 1 安装zlm 1 1 镜像说明 1 2 docker安装 1 2 1 docker安装命令 1 2 2 docker compose安装 1 3 zlm配置和日志重点说明 2 安装wvp 2 1 目录结构说明 2 1 1
  • 汇编宏伪指令介绍

    1 汇编宏伪指令介绍 macro macname macargs endm 1 macro 和 endm 表示宏定义的开始和结束 2 macro 后面接着宏定义的名字 然后是参数 参数后面的宏定义的实现 3 在宏定义中使用参数 需要添加前缀
  • Eclipse中Web项目开发与Tomcat发布的的路径问题详解

    本人以前对Web项目的开发路径和发布路径等一直都很懵逼 今天找到了一片文章 里面写得很详细 在这里转载分享给大家 https www cnblogs com teach p 5669873 html
  • STM32F407输入捕获应用--PWM 输入模式测量脉冲频率与宽度

    STM32F407输入捕获应用 PWM 输入模式测量脉冲频率与宽度 一 测量脉宽或者频率 二 PWM 输入模式 三 软件实现 3 1 硬件准备 3 2代码 3 4 验证 输入捕获一般应用在两个方面 一个方面是脉冲跳变沿时间测量 另一方面 是
  • PyTorch-06 过拟合&欠拟合、Train-Val-Test划分、Regularization减轻防止overfitting、动量与学习率衰减、其他技巧Early Stop,Dropout

    PyTorch 06 过拟合 欠拟合 Train Val Test划分 Regularization减轻防止overfitting 动量与学习率衰减 其他技巧Tricks Early Stop Dropout 一 过拟合 欠拟合 讨论过拟合