深度学习中的优化算法之AdaGrad

2023-11-18

      之前在https://blog.csdn.net/fengbingchun/article/details/123955067 介绍过SGD(Mini-Batch Gradient Descent(MBGD),有时提到SGD的时候,其实指的是MBGD)。这里介绍下自适应梯度优化算法。

      AdaGrad:全称Adaptive Gradient,自适应梯度,是梯度下降优化算法的扩展。AdaGrad是一种具有自适应学习率的梯度下降优化方法。它使参数的学习率自适应,对不频繁的参数执行较大的更新,对频繁的参数执行较小的更新(It adapts the learning rate to the parameters, performing larger updates for infrequent and smaller updates for frequent parameters)。因此,它非常适合处理稀疏数据。AdaGrad可大大提高SGD的鲁棒性。如下图所示,截图来自:https://arxiv.org/pdf/1609.04747.pdf

      在SGD中,我们每次迭代对所有参数进行更新,因为每个参数使用相同的学习率。而AdaGrad在每个时间步长(every time step)对每个参数使用不同的学习率。在SGD和MBSD中,对于每个权值(each weight)或者说对于每个参数(each parameter),学习率的值都是相同的。但是在AdaGrad中,每个权值都有不同的学习率。在现实世界的数据集中,一些特征是稀疏的(大部分特征为零,所以它是稀疏的),而另一些则是密集的(dense,大部分特征是非零的),因此为所有权值保持相同的学习率不利于优化。

      AdaGrad的主要优点之一是它消除了手动调整学习率的需要AdaGrad在迭代过程中不断调整学习率,并让目标函数中的每个参数都分别拥有自己的学习率。大多数实现使用学习率默认值为0.01,开始设置一个较大的学习率。

      AdaGrad的主要弱点是它在分母中累积平方梯度:由于每个添加项都是正数,因此在训练过程中累积和不断增长。这反过来又导致学习率不断变小并最终变得无限小,此时算法不再能够获得额外的知识即导致模型不会再次学习。Adadelta算法旨在解决此缺陷。

      以下是与SGD不同的代码片段:

      1.在原有枚举类Optimization的基础上新增AdaGrad:

enum class Optimization {
	BGD, // Batch Gradient Descent
	SGD, // Stochastic Gradient Descent
	MBGD, // Mini-batch Gradient Descent
	SGD_Momentum, // SGD with Momentum
	AdaGrad // Adaptive Gradient
};

      2.为了每次运行与SGD产生的随机初始化权值相同,这里使用std::default_random_engine

template<typename T>
void generator_real_random_number(T* data, int length, T a, T b, bool default_random)
{
	// 每次产生固定的不同的值
    std::default_random_engine generator;
 
    std::uniform_real_distribution<T> distribution(a, b);
    for (int i = 0; i < length; ++i)
        data[i] = distribution(generator);
}

      3.为了对数据集每次执行shuffle时结果一致,这里有std::random_shuffle调整为std::shuffle:

//std::srand(unsigned(std::time(0)));
//std::random_shuffle(random_shuffle_.begin(), random_shuffle_.end(), generate_random); // 每次执行后random_shuffle_结果不同
std::default_random_engine generator;
std::shuffle(random_shuffle_.begin(), random_shuffle_.end(), generator); // 每次执行后random_shuffle_结果相同

      4.calculate_gradient_descent函数:

void LogisticRegression2::calculate_gradient_descent(int start, int end)
{
	switch (optim_) {
		case Optimization::AdaGrad: {
			int len = end - start;
			std::vector<float> g(feature_length_, 0.);
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					float dw = data_->samples[random_shuffle_[i]][j] * dz[x];
					g[j] += dw * dw;
					w_[j] = w_[j] - alpha_ * dw / (std::sqrt(g[j]) + eps_);
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::SGD_Momentum: {
			int len = end - start;
			std::vector<float> change(feature_length_, 0.);
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					float new_change = mu_ * change[j] - alpha_ * (data_->samples[random_shuffle_[i]][j] * dz[x]);
					w_[j] += new_change;
					change[j] = new_change;
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::SGD:
		case Optimization::MBGD: {
			int len = end - start;
			std::vector<float> z(len, 0), dz(len, 0);
			for (int i = start, x = 0; i < end; ++i, ++x) {
				z[x] = calculate_z(data_->samples[random_shuffle_[i]]);
				dz[x] = calculate_loss_function_derivative(calculate_activation_function(z[x]), data_->labels[random_shuffle_[i]]);

				for (int j = 0; j < feature_length_; ++j) {
					w_[j] = w_[j] - alpha_ * (data_->samples[random_shuffle_[i]][j] * dz[x]);
				}

				b_ -= (alpha_ * dz[x]);
			}
		}
			break;
		case Optimization::BGD:
		default: // BGD
			std::vector<float> z(m_, 0), dz(m_, 0);
			float db = 0.;
			std::vector<float> dw(feature_length_, 0.);
			for (int i = 0; i < m_; ++i) {
				z[i] = calculate_z(data_->samples[i]);
				o_[i] = calculate_activation_function(z[i]);
				dz[i] = calculate_loss_function_derivative(o_[i], data_->labels[i]);

				for (int j = 0; j < feature_length_; ++j) {
					dw[j] += data_->samples[i][j] * dz[i]; // dw(i)+=x(i)(j)*dz(i)
				}
				db += dz[i]; // db+=dz(i)
			}

			for (int j = 0; j < feature_length_; ++j) {
				dw[j] /= m_;
				w_[j] -= alpha_ * dw[j];
			}

			b_ -= alpha_*(db/m_);
	}
}

      执行结果如下图所示:测试函数为test_logistic_regression2_gradient_descent,多次执行每种配置,最终结果都相同。在它们学习率及其它配置参数相同的情况下,AdaGrad耗时为16秒,而SGD仅为7秒;但AdaGrad的识别率为100%,SGD的识别率为99.89%。

      GitHub: https://github.com/fengbingchun/NN_Test

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习中的优化算法之AdaGrad 的相关文章

  • RNN/LSTM循环神经网络讲解

    转自 https zhuanlan zhihu com p 123211148 一 什么是循环神经网络 循环神经网络 Rerrent Neural Network RNN 历史啊 谁发明的都不重要 说了你也记不住 你只要记住RNN是神经网络
  • 【阅读论文方法总结】

    1 快速浏览摘要 看是否有自己需要的东西 2 如果需要 github上查找相关论文代码 对照着论文进行阅读 这样效率高 能够快速理解
  • 图解NCHW与NHWC数据格式

    图解NCHW与NHWC数据格式 田海立 CSDN CSDN博客 nchw 流行的深度学习框架中有不同的数据格式 典型的有NCHW和NHWC格式 本文从逻辑表达和物理存储角度用图的方式来理解这两种数据格式 最后以RGB图像为例来加深NHWC和
  • 深度学习网络篇——VGGNet(Part1 网络结构&训练环节)

    我们上篇文章了解了一下NIN 接下来我们来了解一下VGGNet 可以说是另一波的跪舔和膜拜 VGGNet主要是分为两篇文章 第一篇文章来分享一下VGGNet的网络结构还有训练环节 第二篇文章是分享VGGNet做的分类实验和总结 此为第一篇
  • 序列模型——自然语言处理与词嵌入(理论部分)

    1 词汇表征 深度学习已经给自然语言处理 Natural Language Process NLP 带来革命性的变革 其中一个很关键的概念是词嵌入 word embedding 这是语言表示的一种方式 可以让算法自动的了解一些类似的词 例如
  • 3D人体重建方法漫谈

    转自 https blog csdn net Asimov Liu article details 96442990 1 概述 2 模型匹配的方法 2 1SMPL Skinned Multi Person Linear model 模型 2
  • 深度学习论文:Deep Residual Learning for Image Recognition

    论文 He Kaiming et al Deep residual learning for image recognition Proceedings of the IEEE conference on computer vision a
  • [NLP] transformers 使用指南

    严格意义上讲 transformers 并不是 PyTorch 的一部分 然而 transformers 与 PyTorch 或 TensorFlow 结合的太紧密了 而且可以把 transformers 看成是 PyTorch 或 Ten
  • Tensorflow错误InvalidArgumentError see above for traceback): No OpKernel was registered to support Op

    调用tensorflow gpu运行错误 错误信息如下 2023 06 21 15 36 14 007389 I tensorflow core platform cpu feature guard cc 141 Your CPU supp
  • pytorch 入门 DenseNet

    知识点0 dense block的结构 知识点1 定义dense block 知识点2 定义DenseNet的主体 知识点3 add module 知识点 densenet是由 多个这种结构串联而成的 import torch import
  • window 7 平台上 MXNET 源码编译

    目的 本文主要描述怎么在windows上编译mxnet源码 得到可用的libmxnet dll和libmxnet lib文件 版本 mxnet x64 release CPU版 运行环境 windows 7 64bit visual stu
  • (#########优化器函数########)TensorFlow实现与优化深度神经网络

    反正是要学一些API的 不如直接从例子里面学习怎么使用API 这样同时可以复习一下一些基本的机器学习知识 但是一开始开始和以前一样 先直接讲类和常用函数用法 然后举例子 这里主要是各种优化器 以及使用 因为大多数机器学习任务就是最小化损失
  • Transformer——《Attention is all you need》

    本文是Google 机器翻译团队在2017 年发表 提出了一个新的简单的网络模型 Transformer 该模型基于纯注意力机制 Attention mechanisms 完全抛弃了RNN和CNN网络结构 在机器翻译任务上取得了很好的效果
  • LoFTR配置运行: Detector-Free Local Feature Matching with Transformers ubuntu18.04 预训练模型分享

    刚装好系统的空白系统ubuntu18 04安装 首先进入 软件与更新 换到国内源 论文下载 代码下载 1 anaconda 3 5 3 安装 Index of anaconda archive 清华大学开源软件镜像站 Tsinghua Op
  • 基于Lasagne实现限制玻尔兹曼机(RBM)

    RBM理论部分大家看懂这个图片就差不多了 Lasagne写代码首先要确定层与层 RBM 正向反向过程可以分别当作一个层 权值矩阵互为转置即可 代码 coding utf 8 data format is bc01 written by Ph
  • Pointpillars for object detection

    博客参考 pointpillars代码阅读 prep pointcloud篇 Little sky jty的博客 CSDN博客Brief这一篇内容主要是对函数prep pointcloud进行debug和记录 这里也是dataloader的
  • 16个车辆信息检测数据集收集汇总(简介及链接)

    16个车辆信息检测数据集收集汇总 简介及链接 目录 1 UA DETRAC 2 BDD100K 自动驾驶数据集 3 综合汽车 CompCars 数据集 4 Stanford Cars Dataset 5 OpenData V11 0 车辆重
  • 谈一谈关于NLP的落地场景和商业价值

    欢迎大家关注微信公众号 baihuaML 白话机器学习 在这里 我们一起分享AI的故事 您可以在后台留言 关于机器学习 深度学习的问题 我们会选择其中的优质问题进行回答 本期的问题 你好 请问下nlp在现在的市场主要应用在哪些方面 什么是N
  • 当我们谈人工智能 我们在谈论什么

    我们对一个事物的认识模糊往往是因为宣传过剩冲淡了理论的真实 我们陷在狂欢里 暂时忘记为什么要狂欢 如何踏上这趟飞速发展的列车成为越来越多人心心念念的事情 人工智能的浪潮更像是新闻舆论炒起来的话题 城外的人想进去 城内的人也不想出来 当我们谈
  • 基于矩阵求解多元线性回归

    多元线性回归法也是深度学习的内容之一 用java实现一下多元线性回归 一元线性回归的公式为 y a x b 多元线性回归的公式与一元线性回归的公式类似 不过是矩阵的形式 可以表示为Y AX b 其中 Y是样本输出的合集 X是样本输入的合集

随机推荐

  • kuangbin的模板

    直接链接 间接链接
  • 使用DbHelperSQL调用存储过程的方法

    下面代码是个调用存储过程的例子 对于学习怎么使用DbHelperSQL调用存储过程很有帮助
  • ceph 维护系列(二)--卸载osd

    一 摘要 本文主要介绍从ceph 某台节点上卸载一块或者多块osd 硬盘 二 环境信息 2 1 操作系统版本 root proceph05 cat etc centos release CentOS Linux release 7 6 18
  • SSM框架搭建,及遇到的问题

    SSM框架搭建 及遇到的问题 1 基本概念 1 1 Spring Spring是一个开源框架 Spring是于2003 年兴起的一个轻量级的Java 开发框架 由Rod Johnson 在其著作Expert One On One J2EE
  • 使用NNI对BERT模型进行粗剪枝、蒸馏与微调

    前言 模型剪枝 Model Pruning 是一种用于减少神经网络模型尺寸和计算复杂度的技术 通过剪枝 可以去除模型中冗余的参数和连接 从而减小模型的存储需求和推理时间 同时保持模型的性能 模型剪枝的一般步骤 训练初始模型 训练一个初始的神
  • win10 WSL2 Ubuntu图像化界面安装和配置

    1 win11 设置 打开虚拟机安装许可 2 开启开发者模式 2 Microsoft Store下载安装ubuntu 我这里使用的是20 04 5LTS版本 3 打开ubuntu 命令窗口 1 打开win11的命令行 在下拉三角下标 打开
  • 【云原生之Docker实战】使用Docker部署宝塔面板

    云原生之Docker实战 使用Docker部署宝塔面板 一 宝塔面板介绍 二 检查本地docker环境 1 检查系统版本 2 检查内核版本 3 检查docker版本 三 下载宝塔镜像 四 部署宝塔面板 1 创建挂载目录 2 创建宝塔容器 3
  • 四、C++语言进阶:Boost入门

    4 Boost入门 4 1 简介 Boost库是一个可移植 提供源代码的C 库 作为标准库的后备 是C 标准化进程的开发引擎之一 是为C 语言标准库提供扩展的一些C 程序库的总称 4 2 使用 4 2 1 lamdba表达式 lambda库
  • 字符设备

    from here 字符设备http blog 163 com sunshine linting blog static 44893323201181102957282 字符设备是一种按字节来访问的设备 字符驱动则负责驱动字符设备 这样的驱
  • C++Static成员

    Static成员 概念 声明为static的类成员称为类的静态成员 用static修饰的成员变量 称之为静态成员变量 用static修饰的成员函数 称之为静态成员函数 静态成员变量一定要在类外进行初始化 例题 实现一个类 计算程序中创建了多
  • Mysql索引原理

    Mysql索引类型及其特性 1 普通索引 最基本的索引 它没有任何限制 也是我们大多数情况下用到的索引 直接创建索引 CREATE INDEX index name ON table column length 修改表结构的方式添加索引 A
  • Linux深度系统分区顺序,深度Deepin 20操作系统默认全盘分区不合理?附建设性意见探讨...

    有的网友认为深度 Deepin 20 操作系统默认全盘分区不合理 以下是某位深度网友的个人意见 首先 必须认为默认全盘分区的确存在一些不合理 以下是建设性意见 供与网友们一起探讨 建设性意见内容如下 1 EFI 引导分区 315M 实际使用
  • Javascript模块化规范之CommonJs,AMD,CMD

    Javascript模块化编程规范 一 模块化编程背景 1 什么是模块化编程 2 Javascript模块化编程有哪些规范 二 Javascript模块化编程 1 CommonJs 2 AMD异步模块定义 3 CMD 通用模块定义 4 ES
  • printf()函数

    printf函数对输出表中各量求值的顺序是自右至左进行的 也即程序执行的过程中参数的压栈顺序是从右至左的 并且压栈时压入的是值 因为参数的压栈是在程序的执行过程中 所以即使参数列表中有函数调用则在压栈时也即计算出来 即调用此函数去执行 把得
  • MathType改变字体大小

    目录 一 MathType中的公式字体 二 临时自定义字体大小 三 更改默认字体大小 四 总结 一 MathType中的公式字体 MathType中默认的字体大小为12pt 在word中即小四 word字体对应MathType的字体大小如下
  • Android Studio开发环境的搭建

    Android Studio开发环境的搭建 一 实验目的及任务 Windows下掌握Android Studio的安装和配置 模拟器的创建 Activity的创建和注册 二 实验环境 Jdk Android Studio 三 实验步骤 An
  • 7 种提升SpringBoot 吞吐量神技

    架构师专栏 2022 04 11 08 44 大家好 我是磊哥 一 异步执行 实现方式二种 1 使用异步注解 aysnc 启动类 添加 EnableAsync注解 2 JDK 8本身有一个非常好用的Future类 CompletableFu
  • 计算两个数之和,不能用+ = 运算符

    在lintcode的一个简单的算法题 计算两数的和 不能用 运算符 对于这个题 我是一点思路都没有 不用 那能用什么计算呢 于是在网上找了找答案 答案其实很简单 主要是涉及到运算 我是觉得应该记一下 所以才将这个题写下来 具体代码 异或 运
  • centos 6.5 连接MySQL 提示:ERROR 1045 (28000): Access denied for user 'root'@'localhost' (using password:

    centos 6 5 连接MySQL 提示 ERROR 1045 28000 Access denied for user root localhost using password NO CentOs 第一次登入MySQL 默认超级用户
  • 深度学习中的优化算法之AdaGrad

    之前在https blog csdn net fengbingchun article details 123955067 介绍过SGD Mini Batch Gradient Descent MBGD 有时提到SGD的时候 其实指的是MB