机器学习——Dropout原理介绍

2023-10-30

一:引言

  因为在机器学习的一些模型中,如果模型的参数太多,而训练样本又太少的话,这样训练出来的模型很容易产生过拟合现象。在训练bp网络时经常遇到的一个问题,过拟合指的是模型在训练数据上损失函数比较小,预测准确率较高(如果通过画图来表示的话,就是拟合曲线比较尖,不平滑,泛化能力不好),但是在测试数据上损失函数比较大,预测准确率较低。

  常用的防治过拟合的方法是在模型的损失函数中,需要对模型的参数进行“惩罚”,这样的话这些参数就不会太大,而越小的参数说明模型越简单,越简单的模型则越不容易产生过拟合现象。因此在添加权值惩罚项后,应用梯度下降算法迭代优化计算时,如果参数theta比较大,则此时的正则项数值也比较大,那么在下一次更新参数时,参数削减的也比较大。可以使拟合结果看起来更平滑,不至于过拟合。

  Dropout是hintion最近2年提出的;为了防止模型过拟合,Dropout可以作为一种trikc供选择。在hinton的论文摘要中指出,在每个训练批次中,通过忽略一半的特征检测器(让一半的隐层节点值为0),可以明显地减少过拟合现象。这种方式可以减少特征检测器间的相互作用,检测器相互作用是指某些检测器依赖其他检测器才能发挥作用。

二 Dropout方法

训练阶段:

  1.Dropout是在标准的bp网络的的结构上,使bp网的隐层激活值,以一定的比例v变为0,即按照一定比例v,随机地让一部分隐层节点失效;在后面benchmark实验测试时,部分实验让隐层节点失效的基础上,使输入数据也以一定比例(试验用20%)是部分输入数据失效(这个有点像denoising autoencoder),这样得到了更好的结果。

  2.去掉权值惩罚项,取而代之的事,限制权值的范围,给每个权值设置一个上限范围;如果在训练跟新的过程中,权值超过了这个上限,则把权值设置为这个上限的值(这个上限值得设定作者并没有说设置多少最好,后面的试验中作者说这个上限设置为15时,最好;为啥?估计是交叉验证得出的实验结论)。

  这样处理,不论权值更新量有多大,权值都不会过大。此外,还可以使算法使用一个比较大的学习率,来加快学习速度,从而使算法在一个更广阔的权值空间中搜索更好的权值,而不用担心权值过大。

测试阶段:

  在网络前向传播到输出层前时隐含层节点的输出值都要缩减到(1-v)倍;例如正常的隐层输出为a,此时需要缩减为a(1-v)。

  这里我的解释是:假设比例v=0.5,即在训练阶段,以0.5的比例忽略隐层节点;那么假设隐层有80个节点,每个节点输出值为1,那么此时只有40个节点正常工作;也就是说总的输出为40个1和40个0;输出总和为40;而在测试阶段,由于我们的权值已经训练完成,此时就不在按照0.5的比例忽略隐层输出,假设此时每个隐层的输出还是1,那么此时总的输出为80个1,明显比dropout训练时输出大一倍(由于dropout比例为0.5);所以为了得到和训练时一样的输出结果,就缩减隐层输出为a(1-v);即此时输出80个0.5,总和也为40.这样就使得测试阶段和训练阶段的输出“一致”了。(个人见解)

三 Dropout原理分析

  Dropout可以看做是一种模型平均,所谓模型平均,顾名思义,就是把来自不同模型的估计或者预测通过一定的权重平均起来,在一些文献中也称为模型组合,它一般包括组合估计和组合预测。

  Dropout中哪里体现了“不同模型”;这个奥秘就是我们随机选择忽略隐层节点,在每个批次的训练过程中,由于每次随机忽略的隐层节点都不同,这样就使每次训练的网络都是不一样的,每次训练都可以单做一个“新”的模型;此外,隐含节点都是以一定概率随机出现,因此不能保证每2个隐含节点每次都同时出现,这样权值的更新不再依赖于有固定关系隐含节点的共同作用,阻止了某些特征仅仅在其它特定特征下才有效果的情况。

  这样dropout过程就是一个非常有效的神经网络模型平均方法,通过训练大量的不同的网络,来平均预测概率。不同的模型在不同的训练集上训练(每个批次的训练数据都是随机选择),最后在每个模型用相同的权重来“融合”,介个有点类似boosting算法。

四 代码详解

  首先先介绍一个基于matlab deeplearning toolbox版本的dropout代码,主要参考(tornadomeet大牛博客),如果了解DenoisingAutoencoder的训练过程,则这个dropout的训练过程如出一辙;不需要怎么修改,就可以直接运行,因为在toolbox中已经修改完成了。

  这个过程比较简单,而且也没有使用L2规则项,来限制权值的范围;主要是用于理解dropout网络,在训练样本比较少的情况下,dropout可以很好的防止网络过拟合。

训练步骤:

1.提取数据(只提取2000个训练样本)

2 初始化网络结构:这里主要利用nnsetup函数构建一个[784 100 10]的网络。由于是练习用途,所以不进行pre_training。

3 采用minibatch方法,设置dropout比例nn.dropoutFraction=0.5;利用nntrain函数训练网络。

  按比例随机忽略隐层节点:

if(nn.dropoutFraction > 0)

           if(nn.testing)%测试阶段实现mean network,详见上篇博文

                nn.a{i} = nn.a{i}.*(1 - nn.dropoutFraction);

           else%训练阶段使用
                nn.dropOutMask{i} =(rand(size(nn.a{i}))>nn.dropoutFraction);

                nn.a{i} =nn.a{i}.*nn.dropOutMask{i};
           end
end
>> a=rand(1,6)

>> temp=(rand(size(a))>0.5)

>> dropout_a=a.*temp

误差delta反向传播实现:

% delta(i)=delta(i+1)W(i)*a(i)(1-a(i)) ;之后再进行dropout

if(nn.dropoutFraction>0)

   d{i} = d{i} .* [ones(size(d{i},1),1) nn.dropOutMask{i}];

end

权值更新值delta_w实现:

%  delta_w(i)=delta(i+1)*a(i) 
for i = 1 : (n - 1)
    if i+1==n
       nn.dW{i} = (d{i + 1}' * nn.a{i}) / size(d{i + 1}, 1);
    else
   nn.dW{i} = (d{i + 1}(:,2:end)' * nn.a{i}) / size(d{i + 1}, 1);
    end
end

测试样本错误率:15.500% without dropout

测试样本错误率:12.100% with dropout

参考文献:

http://www.cnblogs.com/tornadomeet/p/3258122.html

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

机器学习——Dropout原理介绍 的相关文章

  • I don't know what to say 事件的 NPM 包中奖名单,有你在用的吗?

    事件详情请看 GitHub Issue 及 justjavac 发布的文章 有人统计出目前引用了 event stream 的 3900 多个包 如下 名次越靠前使用的人越多 ps tree nodemon flatmap stream p
  • 电脑开机就进入bios的解决方法

    最近很多人反映自己的电脑一开机就直接进入bios里 无法正常进入系统 这是怎么回事呢 开机进入bios无法进入系统怎么办呢 别着急 今天就为大家带来电脑开机就进入bios的解决方法 电脑开机就进入bios的解决方法 1 如果是电脑的硬盘出了

随机推荐

  • 区块链应用对金融科技行业的未来造成的巨大冲击力

    区块链技术被认为是继蒸汽机 电力 互联网之后 第四次技术 最具颠覆性的技术 有可能彻底改变整个人类社会价值传递的方式 也将深刻地变革金融的未来 今年早春区块链的呼喊仍在回荡 9月的朗迪峰会上区块链再度成为本次金融科技峰会的关注热点 像天使数
  • 【VUE】-使用VUE进行移动端H5页面开发前的推荐准备工作

    在正式使用Vue进行移动端页面开发前 需要做一些前置工作 以此保证用户在访问页面时看到的东西不会因设备的差异而出现各种不同的效果 比如一个页面在iphone7 plus上显示的很正常 然后切换到了Iphone5上因为屏幕太小部分页面内容被遮
  • 七大编程语言

    编程入门之hello world 1 java 编程语言之首 Java是种开发者用来创造计算机应用的程序语言 Java也有一些Web插件允许你在浏览器中运行 Java可以用来安卓和IOS应用开发 视频游戏开发 桌面GUI 软件开发 Java
  • 宋浩概率论与数理统计笔记(一)

    基本信息 本篇是根据宋浩老师在B站的概率论与数理统计完成 标明了每一个知识点所在的时间点 在学数学的时候笔记必不可少 但频繁暂停记笔记又浪费时间 那你就借他山之石 快速掌握基本数学知识 宋浩老师视频的时长分布 Lesson1 随机试验与随机
  • [Unity][Unity光照][Unity摄像机]代码来改变场景变黑

    要使得场景中完全变黑 同时需要几个操作 1 控制场景的光源 比如新场景中的 直射光Directional Light 把所有光源的active设置为false 2 对摄像机背景进行设置 设置Camera的ClearFlags不为Skybox
  • 国内镜像安装Centos7的jenkins.rpm

    Jenkins官方推荐的安装方式 不过该方式有两个弊端 第一 该方式默认安装的是Jenkins的最新版本 所以无法自定义安装版本 第二 使用国外的镜像源 所以其下载速度极慢 Jenkins华为镜像源 Jenkins清华大学镜像源 Jenki
  • python邮件合并的基本操作步骤_Python SMTP:将电子邮件合并为一个

    The objective is to send the email to two people at a time I prepare the email message I iterate over the pairs and send
  • 数组顺序颠倒php,php怎么将数组顺序反转

    PHP中可以使用array reverse 函数来将数组顺序反转 语法格式为 array reverse array preserve 参数preserve可省略 用于规定是否保留原始数组的键名 只针对数字键名 非数字的键则不受影响 本教程
  • 解决jupyter notebook中出现"Figure size 640x480 with 1 Axes"不显示图片的方案

    问题代码 可忽略代码 import numpy as np from sklearn feature selection import SelectKBest f classif import matplotlib pyplot as pl
  • R语言实现读取excel

    可以使用R语言中的 readxl 包来读取excel文件 可以使用read excel 函数读取整个工作簿或指定工作表 示例代码如下 安装包 install packages readxl 载入包 library readxl 读取整个工作
  • SSD咯````

    文章目录 SSD咯 为什用卷积代替全连接 为什么conv4 3有一个Normalize操作 为什么采用anchor 如何匹配anchor 损失函数 SSD咯 SSD Single Shot MultiBox Detector 的主干网络基于
  • c++求余的用处

    求余符号常常用于数组的数值重新定位的问题 求余符号会把数组穿成一个环状的结构 例如0 10 0 1 10 1 如果将一个值向右平移两个位置则 9 2 10 1 则会在1的位置上出现
  • ARMV8体系结构简介:AArch64系统级体系结构之VMSA

    1 前言 2 VMSA概述 2 1 ARMv8 VMSA naming VMSAv8 整个转换机中 地址转换有一个或两个stage VMSAv8 32 由运行AArch32的异常级别来管理 VMSAv8 64 由运行AArch64的异常级别
  • Bulma Tracy 小笔记

    https bulma zcopy site column cd C Program Files nginx start nginx exe 启动服务 55555 cd front account cd service gateway cd
  • 十八年开发经历小结

    原文地址 http blog csdn net binarytreeex article details 7999853 comments 本来题目想写为 十八年开发经历总结 但是一想我的开发生涯还没结束 怎么就总结了呢 再说个人的一些积累
  • Nuget 配置文件的位置

    最近在 Visual Studio 中使用 Nuget 时 发现总是连接代理服务器 忘了什么时候配置的了 找了半天没找到配置位置 最后发现在这个地方 appdata NuGet 找到 NuGet Config 文件 其中的
  • L2-001 紧急救援 (25 分)

    题目 题目链接 题解 最短路 扩展 算是朴素Dijkstra模板吧 Dijkstra算法 额外加上记录路径 记录到达此处的最短距离 记录以最短距离到达此处的最多人数 更新方式 假设未确定距离的点集中的点t距离已确定距离的点集最近 以t对其他
  • Redis 网络模型

    redis网络模型背景 1 进程分为用户空间和内核空间 用户空间和内核空间共同目标是对系统资源的访问 为了提高IO效率 给用户空间和内核空间都加入了缓存 访问的流程为读写两部分 读 用户空间访问内核空间的缓存产看是否存在资源 若没有内核空间
  • log4j自定类的日志信息打印到指定文件

    需求 现在有需要把每月的定时任务的日志信息 INFO级别的 打印到自定义的emailAccount log文件中 这个跑批类日志信息需要跟其他文件中INFO区分开来 也就是说emailAccount log文件不能有别的文件的INFO级别的
  • 机器学习——Dropout原理介绍

    一 引言 因为在机器学习的一些模型中 如果模型的参数太多 而训练样本又太少的话 这样训练出来的模型很容易产生过拟合现象 在训练bp网络时经常遇到的一个问题 过拟合指的是模型在训练数据上损失函数比较小 预测准确率较高 如果通过画图来表示的话