模型优化-RMSprop

2023-11-18

RMSprop 全称 root mean square prop 算法,和动量方法一样都可以加快梯度下降速度。关于动量方法的内容可以参考这篇博文模型优化-动量方法

动量方法借助前一时刻的动量,从而能够有效地缓解山谷震荡以及鞍部停滞问题。而 RMSprop 对比动量方法的思想有所不同,以 y = wx + b 为例,因为只有两个参数,因此可以通过可视化的方式进行说明。

RMSprop示例图

假设纵轴代表参数 b,横轴代表参数 w,由于 w 的取值大于 b,因此整个梯度的等高线呈椭圆形。可以看到越接近最低点(谷底),椭圆的横轴与纵轴的差值也越大,正好对应我们先前所说的山谷地形。

上图中可以看到每个点的位置,以及这些点的梯度方向,也就是说,每个位置的梯度方向垂直于等高线。那么在山谷附近,虽然横轴正在推进,但纵轴方向的摆动幅度也越来越大,这就是山谷震荡现象。如果使用的随机梯度下降,则很有可能不断地上下震荡而无法收敛到最优值附近。所以,我们向减缓参数 b 方向(纵轴)的速度,同时加快参数 w 方向(横轴)的速度。

【计算过程】:

  • 单独计算每个参数在当前位置的梯度。
    d w i = ∂ L ( w ) ∂ w i dw_{i} = \frac{\partial L(w)}{\partial w_i} dwi=wiL(w)

  • 计算更新量。
    S d w i = β S d w i + ( 1 − β ) d w i 2 Sdw_{i} = \beta Sdw{i} + (1 - \beta)dw_{i}^2 Sdwi=βSdwi+(1β)dwi2
    需要注意的是 d w 2 dw^2 dw2 是指对 dw 做平方处理。

  • 更新参数。
    w i = w i − η d w i S d w i w_i = w_i - \eta \frac{dw_i}{\sqrt{Sdw_i}} wi=wiηSdwi dwi

需要注意 S d w i Sdw_i Sdwi 有可能为 0,因此可以添加一个极小的常数来防止分母为零的情况出现。
w i = w i − η d w i σ + S d w i w_i = w_i - \eta \frac{dw_i}{\sigma + \sqrt{Sdw_i}} wi=wiησ+Sdwi dwi
也可以把这个极小的值放到根号里面。
w i = w i − η d w i σ + S d w i w_i = w_i - \eta \frac{dw_i}{\sqrt{\sigma + Sdw_i}} wi=wiησ+Sdwi dwi

根据参数更新公式, S d w i Sdw_i Sdwi 越大,则 w 更新得越慢。在先前所讲的山谷地形中,纵轴方向的梯度要大于横轴方向的梯度,也就是说 db 远大于 dw, d b / S d b db/\sqrt{Sdb} db/Sdb 值要小于 d w / S d w dw/\sqrt{Sdw} dw/Sdw ,最终在纵轴方向上更新得较慢,而在横轴上更新得更快。

RMSprop 实际上是将椭圆形的等高线转换为圆形的等高线。怎么理解呢?当采用特征归一化将 w 和 b 都转化为 [0, 1] 区间后,此时的图等同于右图。

特征归一化.png

因为是圆形,无论是纵轴还是横轴的梯度大小都相等,那么计算得到的更新量 Sdw = Sdb。若等高线呈椭圆形,则椭圆形长轴方向更新量要大于椭圆形短轴方向,就好比长轴长度为 10,短轴长度为 5,长轴方向每次更新 1,短轴方向每次更新 0.5。虽然速度上不想等,但两者最终从一端抵达另一端所需的时间是一致的。这也是为什么我将 RMSprop 理解成将椭圆形等高线转换为圆形。

【代码实现】:

def RMSprop(x, y, step=0.01, iter_count=500, batch_size=4, beta=0.9):
    length, features = x.shape
    data = np.column_stack((x, np.ones((length, 1))))
    w = np.zeros((features + 1, 1))
    Sdw, eta = 0, 10e-7
    start, end = 0, batch_size
    for i in range(iter_count):
        # 计算梯度
        dw = np.sum((np.dot(data[start:end], w) - y[start:end]) * data[start:end], axis=0) / length        
        # 计算更新量
        Sdw = beta * Sdw + (1 - beta) * np.dot(dw, dw)                     
        # 更新参数
        w = w - (step / np.sqrt(eta + Sdw)) * dw.reshape((features + 1, 1))
        start = (start + batch_size) % length
        if start > length:
            start -= length
        end = (end + batch_size) % length
        if end > length:
            end -= length
    return w

对比 AdaGrad 的实现代码,我们可以发现 RMSprop 实际上在 AdaGrad 的梯度累积平方计算公式上新增了一个衰减系数 β 来控制历史信息的获取。

  • AdaGrad:
    r = r + d w 2 r = r + dw^2 r=r+dw2
  • RMSprop:
    S d w = β S d w + ( 1 − β ) d w 2 Sdw = \beta Sdw + (1 - \beta)dw^2 Sdw=βSdw+(1β)dw2

从这个角度来说,RMSprop 改变了学习率。

RMSprop 算法可以结合牛顿动量,RMSprop 改变了学习率,而牛顿动量改变了梯度,从两方面改变更新方式。

【代码实现】:

def RMSprop(x, y, step=0.01, iter_count=500, batch_size=4, alpha=0.9, beta=0.9):
    length, features = x.shape
    data = np.column_stack((x, np.ones((length, 1))))
    w = np.zeros((features + 1, 1))
    Sdw, v, eta = 0, 0, 10e-7
    start, end = 0, batch_size
    
    # 开始迭代
    for i in range(iter_count):
        # 计算临时更新参数
        w_temp = w - step * v
        
        # 计算梯度
        dw = np.sum((np.dot(data[start:end], w_temp) - y[start:end]) * data[start:end], axis=0).reshape((features + 1, 1)) / length        
        
        # 计算累积梯度平方
        Sdw = beta * Sdw + (1 - beta) * np.dot(dw.T, dw)
        
        # 计算速度更新量、
        v = alpha * v + (1 - alpha) * dw
        
        # 更新参数
        w = w - (step / np.sqrt(eta + Sdw)) * v
        start = (start + batch_size) % length
        if start > length:
            start -= length
        end = (end + batch_size) % length
        if end > length:
            end -= length
    return w

关于 RMSProp 相关的代码都可从 传送门 中获得。

参考

  • 吴恩达老师的深度学习课程
  • Deep Learning 最优化方法之 RMSProp:https://blog.csdn.net/bvl10101111/article/details/72616378
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

模型优化-RMSprop 的相关文章

  • Java爬虫与Python爬虫有什么区别

    Java爬虫和Python爬虫是两种常见的网络爬虫实现方式 它们在语言特性 开发环境和生态系统等方面存在一些区别 1 语言特性 Java是一种面向对象的编程语言 而Python是一种脚本语言 Java较为严谨 需要明确定义类 方法和变量 而

随机推荐

  • MDT 2013 从入门到精通之软件自动化部署设置

    因为工作时间原因已经很长一段时间没有更新博客 还请大伙见谅哈 有关MDT系列文章也是很久没有更新了 今天就来谈谈一些常规技巧内容 我们在日常使用MDT部署过程中 很多新手总是纠结于软件的安装问题 总是通过SkipApplications N
  • h5页面loading丝滑小妙招,vue+vant

    1 v if 使用v if tag 1 在data声明一个变量tag 0 请求到参数后tag 1 我会在created重新初始化tag 0 为了保险我还会加一个setTimeout定时器 div class main div data re
  • java项目的远程调试

    我们在工作中可能会遇到这样的场景 有时候有个问题在本地环境不重现开发或者测试环境的问题 而这个问题需要急需解决的情况 更有部分项目在本地无法启动 需要依赖在服务器启动 有时候可以尝试远程调试 我这里用springboot项目 做一下演示 在
  • phpexcel导出

    fileName 亚马逊品类数据 date Y m d fileType xlsx sql select a sku b product typename c category status a gender a sales status
  • 数据库表与表的三种方式

    表和表之间 一般就是三种关系 一对一 一对多 多对多 1 一对一 数据库表中的数据结构 我们用人与车一 一对应的方式来描述一对一的数据表结构 type是区分这条数据是人还是车 master对应是的主人 车的主人是哪个id car对应的是那辆
  • 示波器探头碰人的波形,人碰示波器探头的波形

    如上图所示 如图中 点说明电流恒定 导体切割磁场线 向导线方向切割磁场变强 远离导线切割磁场变弱 则图中 点说明导体不动 但是导线电流增大则磁场强度增加 等效成导体往恒定电流磁场切割 导线电流减小则磁场减小 等效成导体往恒定电流磁场反方向切
  • 「 标准 」NTSC、PAL、SECAM 三大制式简介

    NTSC National Televison System Committee 制式 NTSC 电视标准 每秒 29 97 帧 简化为 30 帧 电视扫描线为 525 线 偶场在前 奇场在后 标准的数字化 NTSC 电视标准分辨率为720
  • KubeVela 正式开源:一个高可扩展的云原生应用平台与核心引擎

    来源 阿里巴巴云原生公众号 美国西部时间 2020 年 11 月 18 日 在云原生技术 最高盛宴 的 KubeCon 北美峰会 2020 上 CNCF 应用交付领域小组 CNCF SIG App Delivery 与 Open Appli
  • 信号处理——梅尔滤波器(MFCC)

    信号处理 梅尔滤波器 MFCC 一 概述 在语音识别 Speech Recognition 和话者识别 Speaker Recognition 方面 最常用到的语音特征就是梅尔倒谱系数 Mel scale FrequencyCepstral
  • NoSQL系统的分类

    什么是NoSQL系统 采用最终一致性的数据库系统 统称为NoSQL Not only SQL 系统 根据数据模型的不同 NoSQL系统又分为以下几类 基于键值对的 Memcached Redis 基于列存储的 Bigtable Apache
  • 小米路由器3/3G/4通过串口(ttl)刷机

    准备工作 淘宝购买 USB转TTL CH340模块 杜邦线 排针 https detail tmall com item htm id 525204252260 spm a1z09 2 0 0 19dc2e8doubZVx u blagqs
  • 查看linux下安装了哪些软件

    1 查看是否安装了gcc 命令 rpm ql gcc rpm qa grep gcc 参数 q 询问 a 查询全部 l 显示列表 2 权限 安装和删除只有root和有安装权限的用户才可以进行 查询是每个用户都可以进行操作的 RPM 的介绍和
  • 《Docker 镜像操作》

    Docker 镜像原理 1 Docker 镜像本质是什么 是一个分层文件系统 2 Docker 中一个 centos 镜像为什么只有 200MB 而一个 centos 操作系统的 iso 文件要几个个 G Centos 的 iso 镜像文件
  • IDEA插件-PlantUML

    一 idea安装plantUml插件 在idea中Preferences gt plugins gt Browse repositories gt 搜索 plantUML gt 安装即可 二 通过 brew 安装 Graphviz 安装pl
  • [极客大挑战 2019]RCE ME(取反、异或绕过正则表达式、bypass disable_function)

    题目进去后 很简单的代码 显然命令执行 看到了eval 应该是用system等函数来实现命令执行 但是得要先绕过preg match 中正则表达式的限制 一开始傻乎乎的直接传了个数组 妄图绕过preg match 这很显然是不行的 附上大佬
  • c语言 push,深入了解C语言(局部变量的定义)

    深入了解C语言 这一节我们主要来研究一下C语言如何使用函数中的局部变量的 C语言中对于全局变量和局部变量所分配的空间地址是不一样的 全局变量是放在 DATA段 也就是除开 TEXT代码段的另一块集中的内存空间 而局部变量主要是使用堆栈的内存
  • Java 9:装B之前你必须要会的——泛型,注解,反射

    1 泛型 1 1 基本概念 泛型提供了编译期的类型检查 但问题远非这么简单 原生态类型 List list1 new ArrayList 规避的类型检查 List list1 new ArrayList
  • 【mcuclub】PH酸碱度检测传感器-PH4502C

    一 实物图 型号 PH4502C 二 原理图 编号 名称 功能 1 VCC 供电电压正极 5V 2 GND 供电电压负极 3 GND 模拟信号输出负极 4 PO 模拟信号输出正极 5 2V5 基准电压2 5V输出口 6 T1 温度传感器DS
  • 在 vscode 上刷力扣 Leetcode 可以这样来

    背景 神奇的算法网站 LeetCode 值得驻留 网页版似乎不太方便 作为习惯于在编译器上敲代码的你 如何 vscode 上优雅的刷力扣 Leetcode 在本地配置 记录下来方便备查 环境前置 电脑具备 NodeJs环境 第一步 安装插件
  • 模型优化-RMSprop

    RMSprop 全称 root mean square prop 算法 和动量方法一样都可以加快梯度下降速度 关于动量方法的内容可以参考这篇博文模型优化 动量方法 动量方法借助前一时刻的动量 从而能够有效地缓解山谷震荡以及鞍部停滞问题 而