模型选择、欠拟合和过拟合

2023-11-17

训练误差(training error):模型在训练数据集上表现出的误差。

泛化误差(generalization error):模型在任意一个测试数据样本上表现出的误差的期望,常常通过测试数据集上的误差来近似

机器学习模型应该关注泛化误差。


模型选择(model selection)

1. 验证数据集(validation set):预留一部分在训练数据集和测试数据集以外的数据进行模型选择 ,例如我们可以从给定的训练集中选取一小部分作为验证集,剩余部分作为真正的训练集。

2. k折交叉验证(k-fold cross-validation):由于验证数据集不参与模型训练,当训练数据不够用时,预留大量的验证数据显得太奢侈,并且人们发现用同一数据集,既进行训练,又进行模型误差估计,对误差估计的很不准确,这就是所说的模型误差估计的乐观性。为了克服这个问题,提出了交叉验证:我们把训练数据集分割成k个不重合的子数据集,然后我们做k次模型训练和验证。每一次我们使用一个子数据集验证模型,并使用其他k-1个子数据集来训练模型。最后,我们对这k次训练误差和验证误差分别求平均。


欠拟合和过拟合

欠拟合(underfitting):模型无法得到较低的训练误差

过拟合(overfitting):模型的训练误差远小于其在测试集上的误差

  • 造成过拟合和欠拟合的主要原因是模型复杂度训练数据集的大小。

模型复杂度:

1. 给定训练数据集,如果模型的复杂度过低,很容易出现欠拟合。

2. 如果模型的复杂度过高,容易出现过拟合。

训练数据集大小:

一般来说,训练数据集中样本过少(特别是比模型参数数量更少时),过拟合更容易发生。

测试如下:

1. 正常拟合,虽然这里测试集上的误差比训练误差还好

2. 欠拟合,训练集误差很大,并且训练误差在迭代早期下降后就很难继续降低。(这里选择了一个低阶模型去拟合高阶模型产生的数据)

3. 过拟合,训练误差很小,测试集上的误差很大。(这里使用少量的数据去训练模型)

 

应对过拟合的常用方法:

权重衰减(weight decay)

权重衰减也叫L2范数正则化(regularization)。通过为模型的损失函数添加惩罚项使学出的模型参数值较小。带有L2范数惩罚项的新损失函数为:

                                                                        l(\boldsymbol{w},\boldsymbol{b})+\frac{\lambda }{2n}\left \| \boldsymbol{w} \right \|^2

                                                                        \frac{1}{n}\sum_{i=1}^{n}\frac{1}{2}(\hat{y}^{(i)}-y^{(i)})^2+\frac{\lambda }{2n}\left \| \boldsymbol{w} \right \|^2

其中权重衰减超参数 \lambda > 0。当权重参数 w 均为0时,惩罚项最小。当 \lambda 较大时,惩罚项在损失函数中的比重较大,这通常会使学到的权重参数的元素较接近0。

PS:正则化(regularization)按照个人理解是给模型加上约束(惩罚),用于降低模型的复杂度。

MXNet实现

这里我们直接在构造Trainer实例时通过wd参数来指定权重衰减超参数。默认下,Gluon会对权重和偏差同时衰减。我们可以分别对权重和偏差构造Trainer实例,从而只对权重衰减

%matplotlib inline
import d2lzh as d2l
from mxnet import autograd, gluon, nd
from mxnet.gluon import data as gdata, loss as gloss, nn

def fit_and_plot_gluon(wd):
    """wd即为上式中的lambd值"""
    net = nn.Sequential()
    net.add(nn.Dense(1))
    net.initialize(init.Normal(sigma=1))
    # 对权重参数衰减。权重名称一般是以weight结尾
    trainer_w = gluon.Trainer(net.collect_params('.*weight'), 'sgd',
                              {'learning_rate': lr, 'wd': wd})
    # 不对偏差参数衰减。偏差名称一般是以bias结尾
    trainer_b = gluon.Trainer(net.collect_params('.*bias'), 'sgd',
                              {'learning_rate': lr})
    train_ls, test_ls = [], []
    for _ in range(num_epochs):
        for X, y in train_iter:
            with autograd.record():
                l = loss(net(X), y)
            l.backward()
            # 对两个Trainer实例分别调用step函数,从而分别更新权重和偏差
            trainer_w.step(batch_size)
            trainer_b.step(batch_size)
        train_ls.append(loss(net(train_features),
                             train_labels).mean().asscalar())
        test_ls.append(loss(net(test_features),
                            test_labels).mean().asscalar())
    d2l.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
                 range(1, num_epochs + 1), test_ls, ['train', 'test'])
    print('L2 norm of w:', net[0].weight.data().norm().asscalar())

fit_and_plot_gluon(0)
fit_and_plot_gluon(3)

丢弃法(dropout)

训练过程中对隐藏层的神经元进行丢弃的方法,可以使输出层的计算无法过度依赖某一个神经元,从而在训练模型的时候起到正则化的作用。

注意,这里只是在训练过程中使用丢弃法进行计算

设丢弃概率为 p ,那么有 p 的概率 h_i (隐藏神经元)会被清零,有 1-p 的概率 h_i 会除以 1-p 做拉伸。丢弃概率是丢弃法的超参数。

使用丢弃法计算新的隐藏单元:

                                                                 h_i^{'} = \frac{\xi_i }{1-p}h_i

由于 E(\xi _i) = 1-p ,因此:

                                                                 E(h_i^{'}) = \frac{E(\xi_i) }{1-p}h_i = h_i

即丢弃法不会改变隐藏层的期望值。

使用MXNet实现:

import d2lzh as d2l
from mxnet import autograd, gluon, init, nd
from mxnet.gluon import loss as gloss, nn

"""定义dropout函数,drop_prob为丢弃概率,即使用概率drop_prob对X中的元素清零"""
def dropout(X, drop_prob):
    # assert作用,其条件为假,则终止程序
    assert 0 <= drop_prob <= 1
    keep_prob = 1 - drop_prob
    # 这种情况下把全部元素都丢弃
    if keep_prob == 0:
        return X.zeros_like()
    # 在[0,1]上均匀分布的<keep_prob的mask
    mask = nd.random.uniform(0, 1, X.shape) < keep_prob
    return mask * X / keep_prob

# 定义模型参数
num_inputs, num_outputs, num_hiddens1, num_hiddens2 = 784, 10, 256, 256
W1 = nd.random.normal(scale=0.01,shape=(num_inputs, num_hiddens1))
b1 = nd.zeros(num_hiddens1)
W2 = nd.random.normal(scale=0.01,shape=(num_hiddens1, num_hiddens2))
b2 = nd.zeros(num_hiddens2)
W3 = nd.random.normal(scale=0.01,shape=(num_hiddens2, num_outputs))
b3 = nd.zeros(num_outputs)
params = [W1,b1,W2,b2,W3,b3]
for param in params: # 添加需要求导的参数
    param.attach_grad()

# 定义模型
drop_prob1, drop_prob2 = 0.2, 0.5
def net(X):
    X = X.reshape((-1, num_inputs))
    H1 = (nd.dot(X,W1)+b1).relu()
    if autograd.is_training():  # 只在模型训练的时候丢弃
        H1 = dropout(H1, drop_prob1)
    H2 = (nd.dot(H1,W2)+b2).relu()
    if autograd.is_training():  
        H2 = dropout(H2, drop_prob2)
    return nd.dot(H2, W3)+b3

# 训练和测试模型
num_epochs, lr, batch_size = 5, 0.5, 256
loss = gloss.SoftmaxCrossEntropyLoss()
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)

-----------------------------------------------------------------------------------------------
"""简洁实现"""
net = nn.Sequential()
net.add(nn.Dense(256, activation='relu'),
       nn.Dropout(drop_prob1),
       nn.Dense(256,activation='relu'),
       nn.Dropout(drop_prob2),
       nn.Dense(10))
net.initialize(init.Normal(sigma=0.01))

trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate':lr})
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, trainer)

# 直接使用MXNet的函数实现相对于手动实现的优点:
# 1.不需要手动定义以及初始化权重参数和偏差值
# 2.简洁的模型构建

 

Reference:

《动手学深度学习》-Aston Zhang

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

模型选择、欠拟合和过拟合 的相关文章

  • 正睿OI补题(搜索)

    搜索 目录 P1036 NOIP2002 普及组 选数 P2392 kkksc03考前临时抱佛脚 P1025 NOIP2001 提高组 数的划分 P6201 USACO07OPEN Fliptile S P1460 USACO2 1 健康的
  • Spring学习总结

    Spring学习总结 文章目录 Spring学习总结 toc 一 Spring介绍 1 概念 2 下载路径 二 IOC容器 1 IOC概念和原理 什么是IOC IOC底层原理 2 IOC接口 3 IOC操作 Bean管理 什么是Bean管理
  • Hadoop的伪分布式运行模式

    Hadoop运行模式包括 本地模式 伪分布式模式 以及完全分布式模式 1 本地模式 安装简单 在一台机器上运行服务 几乎不用做任何配置 但仅限于调试用途 没有分布式文件系统 直接读写本地操作系统的文件系统 2 伪分布式模式 在单节点上同时启
  • 浅析GPT2中的autoregressive和BERT的autoencoding源码实现

    经常使用BERT来做研究 因此对Encoder的架构较为熟悉 但是从来没有了解过GPT这样的Decoder架构 尤其对自回归的形式不知道源码是如何实现的 为了方便对比和讨论 接来下所探讨的源码都是基于HuggingFace这个框架的 Ber
  • MAC 命令行拷贝文件夹

    命令 cp R 源文件 目标文件 cp R libsvm 3 23 Applications MATLAB R2017b app toolbox 将当前目录下的libsvm 3 23拷贝到 Applications MATLAB R2017
  • 《Python 黑帽子》学习笔记 - 准备 - Day 1

    信息安全是一个有意思的方向 也是自己的爱好 从零开始 想在工作之余把这个爱好培养为自己的技术能力 而 web 安全相对来说容易入门些 于是选择 web 渗透测试作为学习的起点 并选择同样是容易入门的 Python 作为编程工具 潜心学习 持
  • 浏览器有哪些进程?浏览器进程,渲染进程,网络进程,渲染进程有哪些线程?

    浏览器进程 渲染进程有哪些线程 在浏览器中打开两个页面 会开启几个进程 1个浏览器进程 1个网络进程 一个GPU进程 通常一个Tab页对应一个渲染进程 但有其它情况 1 如果页面中有iframe的话 iframe也会运行在单独的进程中 2
  • Python学习笔记--exe文件打包与UI界面设计

    exe文件打包与UI界面设计 前言 一 基于tkinter实现的UI设计 1 1 库的选择及思路 1 2 定位方法的选用 1 3 Frame控件 1 4 变量设置 1 5 批量设置 1 6 Text文本框 1 7 总体界面设计 1 8 功能
  • 【论文笔记】TNASP:A Transformer-based NAS Predictor with a Self-evolution Framework

    文章目录 0 摘要 摘要解读 1 Introduction 2 相关工作 3 方法 3 1 Training based network performance predictors 3 2 基于Transformer的预测器 3 3 自演
  • 编译器报:lambda表达式中使用的变量应为final或有效final 解决方案

    目录 问题描述 原因分析 解决方案 1 声明为final 2 使用有效final 4 使用数组或集合 错误问题的最终解决示例 总结 问题描述 今天在写代码的过程中想要在stream map 方法内部对外部变量进行赋值 结果发现编译器报错 提
  • python的文件操作

    一 文件的基本操作 1 读文件read f open filename r encoding utf 8 data f read 读文件 f close 关闭文件 1 绝对路径的易错点 文件路径中 前要加转义字符 或者 使用r使转义字符失效
  • CentOS 7 挂载本地光盘作为镜像源

    1 上传iso文件到 usr local src 一定要确保这个ISO文件上传完毕后再进行下面的操作 2 创建挂载目录 mkdir media CentOS7 3 挂载iso文件 mount t iso9660 o loop usr loc
  • 神经网络编程技巧(一):两个矩阵相乘报错,np.random.randn(5,)不是矩阵,np.random.randn(5,1)才能得到1*5的矩阵,np.dot()函数

    np dot函数主要用于向量的点积和矩阵的乘法 格式如下np dot a b 其中a b均为n维向量 具体例子参考下面的代码及其结果 在神经网络中经常使用这个函数 能够节约大量的时间 原来复杂的公式在编程时只需要这一行代码即可实现 在编写p
  • 图像检索传统算法学习笔记

    图像检索领域传统算法学习笔记 与组内同学一起找到的一些图像检索传统算法 作一小结 以防忘记 性能统计 传统图像检索算法 CIFAR 10数据集mAP值 编码数不同 LSH局部敏感哈希 0 116 0 131 SH谱哈希 0 124 0 12
  • 学习笔记(三):Java中的List集合——ArrayList、LinkedList、Vector、Stack、CopyOnWriteArrayList

    目录 引言 一 List简介 二 常用List实现类 一 ArrayList 二 LinkedList 三 LinkedList和ArrayList的比较 三 其他List实现类 一 Vector 二 Stack 三 CopyOnWrite
  • CST2020 安装包和安装步骤

    安装包和破解码的百度云链接 链接 https pan baidu com s 1RNSWxVxb DIu8dg8gkCzAw 提取码 dve7 如果失效可评论留言 谢谢 1 关闭防火墙和杀毒软件 2 解压后 以管理员模式运行setup文件
  • 不在傻傻for循环!完美解决JPA批量插入问题

    前言 jpa在简单的增删改查方面确实帮助我们节省了大部分时间 但是面对复杂的情况就显得心有余而力不足了 最近遇到一个批量插入的情况 jpa虽然提供了saveAll方法 但是底层还是for循环save 如果遇到大量数据插入频繁与数据库交互必然
  • 10个 Python 脚本来自动化你的日常任务

    在这个自动化时代 我们有很多重复无聊的工作要做 想想这些你不再需要一次又一次地做的无聊的事情 让它自动化 让你的生活更轻松 那么在本文中 我将向您介绍 10 个 Python 自动化脚本 以使你的工作更加自动化 生活更加轻松 因此 没有更多
  • 一文图解 Transformer,小白也看得懂(完整版)

    原作者 Jay Alammar 原链接 https jalammar github io illustrated transformer 1 导语 谷歌推出的 BERT 模型在11项NLP任务中夺得SOTA结果 引爆了整个NLP界 而BER
  • C 库函数 - mktime()

    描述 C 库函数 time t mktime struct tm timeptr 把 timeptr 所指向的结构转换为自 1970 年 1 月 1 日以来持续时间的秒数 发生错误时返回 1 声明 下面是 mktime 函数的声明 time

随机推荐

  • 一个矩阵乘以它本身的转置等于什么

    如果一个矩阵 A 乘以它本身的转置 AT 那么结果就是一个对角矩阵 对角线上的元素就是 A 矩阵中每一列的平方和 其余的元素都是 0 例如 如果 A 矩阵是 a11 a12 a21 a22 那么 A 乘以 AT 就是 a11 2 a21 2
  • 网道 JS教程 (第一天)

    地址 https wangdoc com javascript js 特点 单线程 事件驱动 非阻塞式设计 数据类型 数值 number 整数和小数 比如1和3 14 字符串 string 文本 比如Hello World 布尔值 bool
  • 关闭或者半关闭?!

    2017 05 20 LIBnids这个库 对于关闭的两个状态 理解的不是很清楚 就是 CLOSE算一个状态 CLOSE之前并不调用EXITING的语句 这就很尴尬 目前就当这两个是同一个状态 但是看着有些数据包是关闭的 结果显示不关闭 这
  • 模块学习笔记—(1)编码器减速电机

    模块学习笔记 1 编码器减速电机 编码器电机作用 编码器电机转动可以产生脉冲信号 根据脉冲信号 可以得出轮胎的转动速度 轮胎的位移 电机正反转等 电机介绍 我的编码器电机是130TT减速电机 电机轴转一圈可以产生13个脉冲信号输出 电机减速
  • 使用UUID获得一个不重复的16位账号的算法

    public static String getAccountIdByUUId int machineId 1 最大支持1 9个集群机器部署 int hashCodeV UUID randomUUID toString hashCode i
  • java运行环境

    计算机基础知识 二进制 如图 十进制和二进制的转换 字节 计算机中一个0或者一个1就是一个位 bit 不过并不是最小的数据单位 最下的数据单位是字节 Byte 一个字节等于8个位 计算机中任何数据的存储都是以字节的形式存储 1 B 8 bi
  • 2、线程池篇 - 从理论基础到具体代码示例讲解(持续更新中......)

    前言 暂无 一 线程篇 有关线程部分的知识整理请看我下面这篇博客 1 线程篇 从理论到具体代码案例最全线程知识点梳理 持续更新中 二 线程池基础知识 线程池优点 他的主要特点为 线程复用 管理线程 不需要频繁的创建和销毁线程 控制线程数量
  • html 登录页面 简洁,简单登录html页面

    简单的登录页面 一个简单pc 移动端显示的html codeDocument margin 0px padding 0px bg width 100 height 45vh text align center color fff backg
  • Could not generate command line for the ‘VCCLCompilerTool’ tool

    转载自 http blog csdn net shirui1125 article details 6095774 gt ToolBox error PRJ0004 未能为 VCCLCompilerTool 工具生成命令行 从原有的平台复制
  • AD采集中的10种经典软件滤波程序优缺点分析(附程序)

    在AD采集中经常要用到数字滤波 而不同情况下又有不同的滤波需求 下面是10种经典的软件滤波方法的程序和优缺点分析 1 限幅滤波法 又称程序判断滤波法 2 中位值滤波法 3 算术平均滤波法 4 递推平均滤波法 又称滑动平均滤波法 5 中位值平
  • 自定义协议:如何实现keepalive

    高可用协议招式 keepalive 什么是keepalive tcp如何实现keepalive http如何实现keepalive 自定义协议时该怎样实现keepalive 什么是keepalive Keepalive是一种技术 它可以帮助
  • C语言最简单的服务器和客户端程序

    服务器 include
  • SQLServer之DEFAULT约束

    DEFAULT约束添加规则 1 若在表中定义了默认值约束 用户在插入新的数据行时 如果该行没有指定数据 那么系统将默认值赋给该列 如果我们不设置默认值 系统默认为NULL 2 如果 默认值 字段中的项替换绑定的默认值 以不带圆括号的形式显示
  • shell面试题

    第1章 选择 1 1 退出交互模式的 shell 应键入 A B q C exit D quit 1 2 下列变量名中有效的 shell 变量名是 C 2 time 2 3 trust no 1 2004file 1 3 在 shell 编
  • stm32低功耗解决方案-(外部时钟芯片RX8025T)

    首先在入手一个芯片时要先观看芯片手册rx8025t和rx8025as手册是不一样 两者的寄存器也会有很大的差距 RX8025t中文手册 本文介绍的是一个低功耗解决方案 因为我使用的是stm32的待机模式 所以只需要在唤醒时想办法就行了 因此
  • Android 中的线程池

    Android 中的线程池 线程池的优点 重用线程池中的线程 避免因为线程的创建和销毁所带来的性能开销 能有效控制线程池的最大并发数 避免大量的线程之间因互相抢占系统资源而导致的阻塞现象 能够对线程进行简单管理 并提供定时执行以及指定间隔循
  • C#编程中遇到的一些异常及部分异常的解决方法

    以下内容是在本人在C 编程中遇到的异常 针对部分异常给出了解决办法 但是此解决方法是否真的好用 有待进一步考证 仅供参考 1 System Invalid Operation Exception 类型的未经处理的异常 出现在System W
  • itext5创建pdf表格及遇到的一些问题

    0 核心依赖 1 设置页眉图片及下划线 2 document参数传递 3 生成的pdf文件转base64编码 4 平方 上标显示问题 5 压缩包的文件流InputStream输出文件 6 itext5进行pdf合并 0 核心依赖
  • 病例对照研究中—两组组间比较—的统计方法选择,基于R语言

    医学中最常设计的试验就是病例对照研究 以探究某一干预措施是否有改善性 需要根据基线的情况 选择相应的方法 试验数据如下 声明 该数据是随机自动生成的 虚拟的 该计算结果不代表任何真实的事情 该数据不适用于现实世界 数据由试验组长病程10名
  • 模型选择、欠拟合和过拟合

    训练误差 training error 模型在训练数据集上表现出的误差 泛化误差 generalization error 模型在任意一个测试数据样本上表现出的误差的期望 常常通过测试数据集上的误差来近似 机器学习模型应该关注泛化误差 模型