深度学习中的优化算法之RMSProp

2023-11-17

      之前在https://blog.csdn.net/fengbingchun/article/details/124766283 中介绍过深度学习中的优化算法AdaGrad,这里介绍下深度学习的另一种优化算法RMSProp。

      RMSProp全称为Root Mean Square Propagation,是一种未发表的自适应学习率方法,由Geoff Hinton提出,是梯度下降优化算法的扩展。如下图所示,截图来自:https://arxiv.org/pdf/1609.04747.pdf

     

       AdaGrad的一个限制是,它可能会在搜索结束时导致每个参数的步长(学习率)非常小,这可能会大大减慢搜索进度,并且可能意味着无法找到最优值。RMSProp和Adadelta都是在同一时间独立开发的,可认为是AdaGrad的扩展,都是为了解决AdaGrad急剧下降的学习率问题。

      RMSProp采用了指数加权移动平均(exponentially weighted moving average)。

      RMSProp比AdaGrad只多了一个超参数,其作用类似于动量(momentum),其值通常置为0.9

      RMSProp旨在加速优化过程,例如减少达到最优值所需的迭代次数,或提高优化算法的能力,例如获得更好的最终结果。

      以下是与AdaGrad不同的代码片段:

      1.在原有枚举类Optimizaiton的基础上新增RMSProp:

enum class Optimization {
	BGD, // Batch Gradient Descent
	SGD, // Stochastic Gradient Descent
	MBGD, // Mini-batch Gradient Descent
	SGD_Momentum, // SGD with Momentum
	AdaGrad, // Adaptive Gradient
	RMSProp // Root Mean Square Propagation
};

      2.calculate_gradient_descent函数:RMSProp与AdaGrad只有g[j]的计算不同

void LogisticRegression2::calculate_gradient_descent(int start, int end)
{
	switch (optim_) {
		case Optimization::RMSProp: {
			int len = end - start;
			std::vector<float> g(feature_length_, 0.);
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					float dw = data_->samples[random_shuffle_[i]][j] * dz[x];
					g[j] = mu_ * g[j] + (1. - mu_) * (dw * dw);
					w_[j] = w_[j] - alpha_ * dw / (std::sqrt(g[j]) + eps_);
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::AdaGrad: {
			int len = end - start;
			std::vector<float> g(feature_length_, 0.);
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					float dw = data_->samples[random_shuffle_[i]][j] * dz[x];
					g[j] += dw * dw;
					w_[j] = w_[j] - alpha_ * dw / (std::sqrt(g[j]) + eps_);
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::SGD_Momentum: {
			int len = end - start;
			std::vector<float> change(feature_length_, 0.);
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					float new_change = mu_ * change[j] - alpha_ * (data_->samples[random_shuffle_[i]][j] * dz[x]);
					w_[j] += new_change;
					change[j] = new_change;
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::SGD:
		case Optimization::MBGD: {
			int len = end - start;
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					w_[j] = w_[j] - alpha_ * (data_->samples[random_shuffle_[i]][j] * dz[x]);
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::BGD:
		default: // BGD
			std::vector<float> z(m_, 0), dz(m_, 0);
			float db = 0.;
			std::vector<float> dw(feature_length_, 0.);
			for (int i = 0; i < m_; ++i) {
				z[i] = calculate_z(data_->samples[i]);
				o_[i] = calculate_activation_function(z[i]);
				dz[i] = calculate_loss_function_derivative(o_[i], data_->labels[i]);

				for (int j = 0; j < feature_length_; ++j) {
					dw[j] += data_->samples[i][j] * dz[i]; // dw(i)+=x(i)(j)*dz(i)
				}
				db += dz[i]; // db+=dz(i)
			}

			for (int j = 0; j < feature_length_; ++j) {
				dw[j] /= m_;
				w_[j] -= alpha_ * dw[j];
			}

			b_ -= alpha_*(db/m_);
	}
}

      执行结果如下图所示:测试函数为test_logistic_regression2_gradient_descent,多次执行每种配置,最终结果都相同。图像集使用MNIST,其中训练图像总共10000张,0和1各5000张,均来自于训练集;预测图像总共1800张,0和1各900张,均来自于测试集。在它们学习率为0.01及其它配置参数相同的情况下,AdaGrad耗时为17秒,RMSProp耗时为33秒;它们的识别率均为100%。当学习率调整为0.001时,AdaGrad耗时为26秒,RMSProp耗时为19秒;它们的识别率均为100%。

      GitHub: https://github.com/fengbingchun/NN_Test

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习中的优化算法之RMSProp 的相关文章

  • Bug解决:ModuleNotFoundError: No module named ‘taming‘

    from taming modules vqvae quantize import VectorQuantizer2 as VectorQuantizer ModuleNotFoundError No module named taming
  • 【阅读论文方法总结】

    1 快速浏览摘要 看是否有自己需要的东西 2 如果需要 github上查找相关论文代码 对照着论文进行阅读 这样效率高 能够快速理解
  • 图像识别中,目标分割、目标识别、目标检测和目标跟踪这几个方面区别是什么?+资料列表

    目标识别 深度学习进行目标识别的资源列表 转 https zhuanlan zhihu com p 26076489 以下转自 https www zhihu com question 36500536 作者 知乎用户 链接 https w
  • 图解NCHW与NHWC数据格式

    图解NCHW与NHWC数据格式 田海立 CSDN CSDN博客 nchw 流行的深度学习框架中有不同的数据格式 典型的有NCHW和NHWC格式 本文从逻辑表达和物理存储角度用图的方式来理解这两种数据格式 最后以RGB图像为例来加深NHWC和
  • 【生成式网络】入门篇(二):GAN的 代码和结果记录

    GAN非常经典 我就不介绍具体原理了 直接上代码 感兴趣的可以阅读 里面有更多变体 https github com rasbt deeplearning models tree master pytorch ipynb gan GAN 在
  • windows下运行pointnet(全)

    放假闲着在家没事 本人突然想跑一下3d深度学习的开山之作 pointnet玩一玩 可是目前网上大部分pointnet的运行教程都是在Ubuntu系统下的 其实本人也曾装过双系统 但是因为我太菜了 在Ubuntu下装完显卡驱动和cuda后切换
  • 序列模型——自然语言处理与词嵌入(理论部分)

    1 词汇表征 深度学习已经给自然语言处理 Natural Language Process NLP 带来革命性的变革 其中一个很关键的概念是词嵌入 word embedding 这是语言表示的一种方式 可以让算法自动的了解一些类似的词 例如
  • Android平台深度学习--NNAPI

    转自 http blog sina com cn s blog 602f87700102y62v html 1 Android 8 1 API 27 NNAPI 人工智能神经网络API 如 TensorFlow 神经网络 API 能够向设备
  • 火爆科研圈的三维重建技术:Neural radiance fields (NeRF)

    如果说最近两年最火的三维重建技术是什么 相信NeRF 1 是一个绝对绕不过去的名字 这项强到逆天的技术 一经提出 就被众多研究者所重视 对该技术进行深入研究并提出改进已经成为一个热点 仅仅过了不到两年的时间 NeRF及其变种已经成为重建领域
  • Deep Learning Tutorials(一):开头语

    万事开头难 当你开始看这些时候 有可能你已经开始了研究生生活 不在像本科时候过着那种得过且过 考试不挂科的日子 你整天盲目 漫无目的的过日子实际上是在浪费自己的生命 所以坚持每天进步吧 回到正事 你可能开始从事深度学习研究或者有关机器学习方
  • Mac电脑配置李沐深度学习环境[pytorch版本]使用vscode

    文章目录 第一步 M1芯片安装Pytorch环境 安装Miniforge 创建虚拟环境 安装Pytorch 第二步 下载李沐Jupyter文件 第三步 配置vscode 参考 第一步 M1芯片安装Pytorch环境 安装Miniforge
  • Softmax分类和两层神经网络以及反向传播的代码推导

    发现草稿箱里还有一篇很早之前的学习笔记 希望可以帮助到有需要的童鞋 目录 序 Softmax分类器 反向传播 数据构建以及网络训练 交叉验证参数优化 序 原来都是用的c 学习的传统图像分割算法 主要学习聚类分割 水平集 图割 欢迎一起讨论学
  • Transformer——《Attention is all you need》

    本文是Google 机器翻译团队在2017 年发表 提出了一个新的简单的网络模型 Transformer 该模型基于纯注意力机制 Attention mechanisms 完全抛弃了RNN和CNN网络结构 在机器翻译任务上取得了很好的效果
  • 词向量的运算与Emoji生成器

    本文参考参考 没有对框架内容进行学习 旨在学习思路和方法 1 词向量运算 之前学习RNN和LSTM的时候 输入的语句都是一个向量 比如恐龙的名字那个例子就是将一个单词中的字母按顺序依次输入 这对于一个单词的预测是可行的 但是对于想让机器学习
  • 深度学习中的优化算法之Adam

    之前在https blog csdn net fengbingchun article details 124909910 介绍过深度学习中的优化算法Adadelta 这里介绍下深度学习的另一种优化算法Adam 论文名字为 ADAM A M
  • 深度学习系统为什么容易受到对抗样本的欺骗?

    转自 https zhuanlan zhihu com p 89665397 本文作者 kurffzhou 腾讯 TEG 安全工程师 最近 Nature发表了一篇关于深度学习系统被欺骗的新闻文章 该文指出了对抗样本存在的广泛性和深度学习的脆
  • PyTorch训练简单的全连接神经网络:手写数字识别

    文章目录 pytorch 神经网络训练demo 输出结果 来源 pytorch 神经网络训练demo 数据集 MNIST 该数据集的内容是手写数字识别 其分为两部分 分别含有60000张训练图片和10000张测试图片 神经网络 全连接网络
  • 决策树(Decision Tree)简介

    决策树 Decision Tree 及其变种是另一类将输入空间分成不同的区域 每个区域有独立参数的算法 决策树分类算法是一种基于实例的归纳学习方法 它能从给定的无序的训练样本中 提炼出树型的分类模型 树中的每个非叶子节点记录了使用哪个特征来
  • cifar数据集介绍及到图像转换的实现

    CIFAR是一个用于普通物体识别的数据集 CIFAR数据集分为两种 CIFAR 10和CIFAR 100 The CIFAR 10 and CIFAR 100 are labeled subsets of the 80 million ti
  • pthread_create返回值错误码11 (EAGAIN)或libgomp: Thread creation failed: Resource temporarily unavailable错误

    在主机上开发torch xla时 使用非root用户在conda环境 遇到tensorflow中报pthread create 11错误 大意为系统资源不足 解决方案 分析 此主机多用户使用 资源占用非常大 且大多数情况下在docker容器

随机推荐

  • 三维旋转:旋转矩阵,欧拉角,四元数

    在介绍下面的文章前 大家如果接触到欧拉角的话 就一定要关注一个词 要顺规 在欧拉角体系里面 有12种顺规 这一点是好多文章没有让读书意识到 导致后面学习图形学里面的 heading pitch bank 时对不上号 一般百度百科里面说到的
  • 课程笔记2

    一 实现 1 区块链是去中心化的账本 比特币采用的是基于交易的账本模式 区块链的全节点需要维护一种名叫UTXO的数据结构 所有未花掉的交易的输出的集合 可以有效检测双花攻击 交易的总输入略微大于总输出 这是因为比特币的第二个激励机制 获得记
  • load data inpath出错原因及解决方法

    hive gt load data inpath hdfs Master hdp 9000 person txt into table Person1 FAILED SemanticException Error 10028 Line 1
  • java setcellvalue NA_java minioClient.setBucketPolicy 调用失败 折腾好几天了 求大佬解惑...

    方法调用后 提示 Request processing failed nested exception is java lang IllegalArgumentException unknown error code string Malf
  • 简要损益科目口诀,营业外收支和其他业务收支的区别

    一 损益科目口诀 三收三费所得税 两成三益外加减 三收 主营业务收入 其他业务收入 营业外收入 三费 管理费用 财务费用 销售费用 这是常用费用 某些企业可能还分有研究开发费用 两成 主营业务成本 其他业务成本 三益 投资收益 公允价值变动
  • java查看包的源代码

    把鼠标放在方法上 按Ctrl进去 打开的 class文件就是Java jdk1 7 0 src zip中的源码 但是在Java jdk1 7 0 src zip 中是以 java为扩展名
  • ios开发教程入门到精通

    第1集 初识macOS 点击观看 第2集 开发工具Xcode 点击观看 第3集 初识Objective C 点击观看 待续
  • 华为机试 牛客网 HJ1 字符串最后一个单词的长度

    华为机试 牛客网 HJ1 字符串最后一个单词的长度 描述 输入描述 输出描述 示例一 解法一 解法二 反思 描述 计算字符串最后一个单词的长度 单词以空格隔开 字符串长度小于5000 输入描述 输入一行 代表要计算的字符串 非空 长度小于5
  • shell简单脚本编写

    1 第一步 安装邮件服务 root server yum install s nail y 第二步 编辑配置文件 root server vim etc s nail rc set from 自己的qq邮箱地址 set smtp smtp
  • OpenCV - 基本知识

    1 读取并显示图片 namedWindow新建一个显示窗口 imshow输出图片 namedwindow可有可无 Mat image cv imread E 其他文档 图片 2 jpg 2 cv namedWindow 照片 CV WIND
  • window中gcc编译程序、编辑环境配置以及gcc编译程序的过程(含system函数以及CMD快捷键)

    1 system函数的使用 include
  • 关于rocketmq 中日志文件路径的配置

    前些天发现了一个巨牛的人工智能学习网站 通俗易懂 风趣幽默 忍不住分享一下给大家 点击跳转到网站 rocketmq 中的数据和日志文件默认都是存储在user home路径下面的 往往我们都需要修改这些路径到指定文件夹以便管理 服务端日志 网
  • ML-朴素贝叶斯

    参考 西瓜书 P151 以前对贝叶斯参数的计算过程不是很清楚 在西瓜书里讲的很详细 原来可以把X属性分为离散型与连续型 离散型的话可以直接按照频率计算 连续型的话 要用极大似然估计 首先假设概率密度函数满足一个分布 比如正态分布 然后利用已
  • 动态控制ToolStrip上ToolStripButton的大小(包括图标的大小)

    一 设置固定大小的ToolStripButton 设置固定大小的ToolStripButton很简单 ToolStripButton gt AutoSize属性设置为false size调整为自己想要的大小即可 同时配合的是ToolStri
  • Flutter与android原生通信

    Flutter 与 Android iOS 之间信息交互通过 Platform Channel 进行桥接 Flutter 定义了三种不同的 Channel 但无论是传递方法还是传递事件 其本质上都是数据的传递 MethodChannel 用
  • 因Redis分布式锁造成的P0级重大事故,整个项目组被扣了绩效......,请慎用

    作者 浪漫先生 出处 juejin im post 5f159cd8f265da22e425f71d 背景 我们项目中的抢购订单采用的是分布式锁来解决的 有一次 运营做了一个飞天茅台的抢购活动 库存 100 瓶 但是却超卖了 要知道 这个地
  • 在C++11通过SFINAE机制实现静态检查类成员是否存在并分情况处理,以及一种通用宏的实现

    目录 引入 目的 代码 测试 TIPS 引入 c 模板中 我们无法知道参数类是否具有某个成员 例如下面代码 我们希望下面的代码中能够打印t的成员变量a的值 然而当类型T不包含成员a时 调用下面的代码就会报错 template
  • iOS Push详述,了解一下?

    欢迎大家前往腾讯云 社区 获取更多腾讯海量技术实践干货哦 本文由WeTest质量开放平台团队发表于云 社区专栏 作者 陈裕发 腾讯系统测试工程师 商业转载请联系腾讯WeTest获得授权 非商业转载请注明出处 原文链接 http wetest
  • 安装eli5库的踩雷

    报错方法 在Anaconda Prompt中输入pip install eli5 conda install eli5指令 分别显示安装失败和未找到包 解决方法 在Anaconda Powershell Prompt中输入conda ins
  • 深度学习中的优化算法之RMSProp

    之前在https blog csdn net fengbingchun article details 124766283 中介绍过深度学习中的优化算法AdaGrad 这里介绍下深度学习的另一种优化算法RMSProp RMSProp全称为R