深度学习中的优化算法之AdaMax

2023-10-29

      之前在https://blog.csdn.net/fengbingchun/article/details/125018001 介绍过深度学习中的优化算法Adam,这里介绍下深度学习的另一种优化算法AdaMax。AdaMax与Adam来自于同一篇论文。论文名字为《ADAM: A METHOD FOR STOCHASTIC OPTIMIZATION》,论文地址:https://arxiv.org/pdf/1412.6980.pdf

      AdaMax:是梯度优化算法的扩展,基于无穷范数的Adam的变体(a variant of Adam based on the infinity norm)。此算法对学习率的上限提供了一个更简单的范围,并可能对某些问题进行更有效的优化。如下图所示,截图来自:https://arxiv.org/pdf/1609.04747.pdf

      AdaMax与Adam区别:本质上前者是将L2范数推广到L-infinity范数。AdaMax与Adam最终公式中仅分母的计算方式不同,AdaMax使用公式24,Adam使用公式20。

以下是与Adam不同的代码片段:

      1.在原有枚举类Optimizaiton的基础上新增AdaMax

enum class Optimization {
	BGD, // Batch Gradient Descent
	SGD, // Stochastic Gradient Descent
	MBGD, // Mini-batch Gradient Descent
	SGD_Momentum, // SGD with Momentum
	AdaGrad, // Adaptive Gradient
	RMSProp, // Root Mean Square Propagation
	Adadelta, // an adaptive learning rate method
	Adam, // Adaptive Moment Estimation
	AdaMax // a variant of Adam based on the infinity norm
};

      2.calculate_gradient_descent函数:

void LogisticRegression2::calculate_gradient_descent(int start, int end)
{
	switch (optim_) {
		case Optimization::AdaMax: {
			int len = end - start;
			std::vector<float> m(feature_length_, 0.), u(feature_length_, 1e-8), mhat(feature_length_, 0.);
			std::vector<float> z(len, 0.), dz(len, 0.);
			float beta1t = 1.;
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				beta1t *= beta1_;

				for (int j = 0; j < feature_length_; ++j) {
					float dw = data_->samples[random_shuffle_[i]][j] * dz[x];
					m[j] = beta1_ * m[j] + (1. - beta1_) * dw; // formula 19
					u[j] = std::max(beta2_ * u[j], std::fabs(dw)); // formula 24

					mhat[j] = m[j] / (1. - beta1t); // formula 20

					// Note: need to ensure than u[j] cannot be 0.
					// (1). u[j] is initialized to 1e-8, or
					// (2). if u[j] is initialized to 0., then u[j] adjusts to (u[j] + 1e-8)
					w_[j] = w_[j] - alpha_ * mhat[j] / u[j]; // formula 25
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::Adam: {
			int len = end - start;
			std::vector<float> m(feature_length_, 0.), v(feature_length_, 0.), mhat(feature_length_, 0.), vhat(feature_length_, 0.);
			std::vector<float> z(len, 0.), dz(len, 0.);
			float beta1t = 1., beta2t = 1.;
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				beta1t *= beta1_;
				beta2t *= beta2_;

				for (int j = 0; j < feature_length_; ++j) {
					float dw = data_->samples[random_shuffle_[i]][j] * dz[x];
					m[j] = beta1_ * m[j] + (1. - beta1_) * dw; // formula 19
					v[j] = beta2_ * v[j] + (1. - beta2_) * (dw * dw); // formula 19

					mhat[j] = m[j] / (1. - beta1t); // formula 20
					vhat[j] = v[j] / (1. - beta2t); // formula 20

					w_[j] = w_[j] - alpha_ * mhat[j] / (std::sqrt(vhat[j]) + eps_); // formula 21
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::Adadelta: {
			int len = end - start;
			std::vector<float> g(feature_length_, 0.), p(feature_length_, 0.);
			std::vector<float> z(len, 0.), dz(len, 0.);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					float dw = data_->samples[random_shuffle_[i]][j] * dz[x];
					g[j] = mu_ * g[j] + (1. - mu_) * (dw * dw); // formula 10

					float alpha = (eps_ + std::sqrt(p[j])) / (eps_ + std::sqrt(g[j]));
					float change = alpha * dw;
					p[j] = mu_ * p[j] +  (1. - mu_) * (change * change); // formula 15

					w_[j] = w_[j] - change;
				}

				b_ -= (eps_ * dz[x]);
			}
		}
			break;
		case Optimization::RMSProp: {
			int len = end - start;
			std::vector<float> g(feature_length_, 0.);
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					float dw = data_->samples[random_shuffle_[i]][j] * dz[x];
					g[j] = mu_ * g[j] + (1. - mu_) * (dw * dw); // formula 18
					w_[j] = w_[j] - alpha_ * dw / (std::sqrt(g[j]) + eps_);
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::AdaGrad: {
			int len = end - start;
			std::vector<float> g(feature_length_, 0.);
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					float dw = data_->samples[random_shuffle_[i]][j] * dz[x];
					g[j] += dw * dw;
					w_[j] = w_[j] - alpha_ * dw / (std::sqrt(g[j]) + eps_);
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::SGD_Momentum: {
			int len = end - start;
			std::vector<float> change(feature_length_, 0.);
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					float new_change = mu_ * change[j] - alpha_ * (data_->samples[random_shuffle_[i]][j] * dz[x]);
					w_[j] += new_change;
					change[j] = new_change;
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::SGD:
		case Optimization::MBGD: {
			int len = end - start;
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					w_[j] = w_[j] - alpha_ * (data_->samples[random_shuffle_[i]][j] * dz[x]);
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::BGD:
		default: // BGD
			std::vector<float> z(m_, 0), dz(m_, 0);
			float db = 0.;
			std::vector<float> dw(feature_length_, 0.);
			for (int i = 0; i < m_; ++i) {
				z[i] = calculate_z(data_->samples[i]);
				o_[i] = calculate_activation_function(z[i]);
				dz[i] = calculate_loss_function_derivative(o_[i], data_->labels[i]);

				for (int j = 0; j < feature_length_; ++j) {
					dw[j] += data_->samples[i][j] * dz[i]; // dw(i)+=x(i)(j)*dz(i)
				}
				db += dz[i]; // db+=dz(i)
			}

			for (int j = 0; j < feature_length_; ++j) {
				dw[j] /= m_;
				w_[j] -= alpha_ * dw[j];
			}

			b_ -= alpha_*(db/m_);
	}
}

      执行结果如下图所示:测试函数为test_logistic_regression2_gradient_descent,多次执行每种配置,最终结果都相同。图像集使用MNIST,其中训练图像总共10000张,0和1各5000张,均来自于训练集;预测图像总共1800张,0和1各900张,均来自于测试集。Adam和AdaMax配置参数相同的情况下,即eps为1e-8,学习率为0.002,beta1为0.9,beta2为0.999的情况下,Adam耗时30秒,AdaMax耗时为25秒;它们的识别率均为100%

      GitHub: https://github.com/fengbingchun/NN_Test

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习中的优化算法之AdaMax 的相关文章

随机推荐

  • 高数基础——步长

    目录 1 什么是步长 2 步长 怎么取 1 Armijo conditions 充分下降条件 2 curvature condition 不要取得太小 3 Wolfe conditions 1 什么是步长 在确定了搜索方向的情况下 讨论搜索
  • 【Linux 内核网络协议栈源码剖析】sendto 函数剖析

    前面介绍的函数基本上都是TCP协议的 如listen connect accept 等函数 这都是为可靠传输协议TCP定制的 对于另一个不可靠udp协议 通信系统其可靠性交由上层应用层负责 则主要由两个函数完成 sendto 和 recvf
  • 第15届全国大学生知识竞赛场景实操 2022ciscn初赛 部分writeup

    文章目录 Crypto 签到电台 基于挑战码的双向认证1 基于挑战码的双向认证2 基于挑战码的双向认证3 Misc ez usb everlasting night 问卷 babydisk Web Ezpop Crypto 签到电台 签到
  • 【Spring】浅谈spring为什么推荐使用构造器注入

    一 前言 Spring框架对Java开发的重要性不言而喻 其核心特性就是IOC Inversion of Control 控制反转 和AOP 平时使用最多的就是其中的IOC 我们通过将组件交由Spring的IOC容器管理 将对象的依赖关系由
  • WPF封装VLC播放器控件(方式二:VlcVideoSourceProvider绑定Image控件)

    之前写过一篇文章关于WPF利用VLCPlayer控制Winform窗体句柄封装的视频播放器 链接 https blog csdn net dnazhd article details 102476134 这里换一种方式重写一下视频播放器控件
  • 【算法】回溯算法

    1 概念 回溯算法实际上一个类似枚举的搜索尝试过程 主要是在搜索尝试过程中寻找问题的解 当发现已不满足求解条件时 就 回溯 返回到上一步还能执行的状态 尝试别的路径 类似于走迷宫一样 假设我们到了每一个岔路口都规定 除了走过的地方 按照先往
  • linux more命令用法,linux more命令详解

    大家好 我是时间财富网智能客服时间君 上述问题将由我为大家进行解答 linux more命令详解分析如下 1 使用cat命令显示install log文件 系统会将install log文件完整的显示出来 但是用户只能看到文件的末尾部分 该
  • 网易易盾滑块逆向分析 js 滑动轨迹生成_2

    网易易盾无感逆向 提示 仅学习参考 如有涉及侵权联系本人删除 目标网站已做脱敏处理 aHR0cHM6Ly9kdW4uMTYzLmNvbS90cmlhbC9zZW5zZQ 文章目录 网易易盾无感逆向 加密参数 一 data参数 易盾滑块总结
  • 转码日记——Javascript笔记(3)

    代码块 只具有分组的作用 代码块内部的内容在外部也是完全可见的 console log hello world 一个单独的语句 document write goodbye alert FBI warning 大括号中的是一组语句 也叫代码
  • mmpycocotools包安装的问题:源码安装出现:“gcc: error : ../common/maksApi.c: 没有那个文件或目录“

    mmdetection框架中的mmpycocotools包的安装问题解决 问题背景 解决方案 方案1 不安装mmpycocotools包 方案2 安装mmpycocotools包 问题总结 问题背景 在配一篇detection论文时 安装R
  • Zookeeper可视化工具PrettyZoo

    文章目录 安装 创建连接 虽然市面上 Zookeeper 的 WEB 管理工具很丰富 但是却很难找到一款满意的图形化客户端 鉴于这样的情况 经过时间的查找 找到了这款管理 Zookeeper 的图形化工具 取名 PrettyZoo 意为 P
  • HTML5手机端网页开发

  • 关于postcss-px-to-viewport的使用经验

    最近在工作项目使用中新接触到postcss px to viewport 在使用上遇到一个坑 也有段时间没更新啦 记录分享一下 希望对你有所帮助 直接上重点 节省在网站苦苦寻找有效答案的你 gt gt gt 我所遇到的坑 由于项目需要 需要
  • 服务器实现端口转发的N种方式

    简介 在一些实际的场景里 我们需要通过利用一些端口转发工具 比如系统自带的命令行工具或第三方小软件 来绕过网络访问限制触及目标系统 下文为大家总结了linux系统和windows系统端口转发常用的一些方法 注 Linux实现端口转发需要内核
  • 在react中使用Markdown编辑器

    提示 写完文章后 目录可以自动生成 如何生成可参考右边的帮助文档 文章目录 一 在react中使用Markdown编辑器 二 使用步骤 实现效果 安装 使用 安装 使用 一 在react中使用Markdown编辑器 首先我们需要清楚Mark
  • File类(Java)

    目录 1 File类定义 2 File类构造方法 常用摘要 3 练习 4 主要成员方法 1 File类定义 1 File类主要是JAVA为文件这块的操作 如删除 新增等 而设计的相关类 2 File类的包名是java io 其实现了Seri
  • Vue3常规登录页面模板

    本文基于vue3 JavaScript 使用vue3的setup语法糖书写方式 setup语法糖也是当前各大适用Vue的框架官网都在推崇的书写方式 此外各大主流框架的源码首选是TypeScript 而不是JavaScript 登录页面模板
  • 正则表达式匹配字符串中的任何空格

    a zA Z0 9 匹配空格包括 r t n f 的含义是后面出现一个或多个 s
  • SSM+Layui整合

    文章目录 概述 依赖 各种配置文件 web xml spring配置 springMVC配置 MyBatis配置 Mapper映射文件 关于Layui 概述 刚学完了ssm 打算自己上手做一个项目玩玩 先第一步 整合ssm 自己不会写前端
  • 深度学习中的优化算法之AdaMax

    之前在https blog csdn net fengbingchun article details 125018001 介绍过深度学习中的优化算法Adam 这里介绍下深度学习的另一种优化算法AdaMax AdaMax与Adam来自于同一