【优化器】(一) SGD原理 & pytorch代码解析

2023-11-19

1.简介

很多情况下,我们调用优化器的时候都不清楚里面的原理和构造,主要基于自己数据集和模型的特点,然后再根据别人的经验来选择或者尝试优化器。下面分别对SGD的原理、pytorch代码进行介绍和解析。


2.梯度下降

梯度下降方法可以分为3种,分别是:

  • BGD (Batch gradient descent)

这种方法是最朴素的梯度下降方法,将全部的数据样本输入网络计算梯度后进行一次更新:

w^{^{k+1}} =w^{^{k}}-\alpha *\bigtriangledown f(w^{k})

其中 w为模型参数, \bigtriangledown f(w^{k})为模型参数更新梯度,\alpha为学习率。

这个方法的最大问题就是容易落入局部最优点或者鞍点。

局部最优点很好理解,就是梯度在下降过程中遇到下图的情况,导致在local minimum区间不断震荡最终收敛。

鞍点(saddle point)是指一个非局部极值点的驻点,如下图所示,长得像一个马鞍因此得名。以红点的位置来说,在x轴方向是一个向上弯曲的曲线,在y轴方向是一个向下弯曲的曲线。当点从x轴方向向下滑动时,最终也会落入鞍点,导致梯度为0。

  • SGD (Stochastic gradient descent)

为了解决BGD落入鞍点或局部最优点的问题,SGD引入了随机性,即将每个数据样本输入网络计算梯度后就进行一次更新:

w^{^{k+1}} =w^{^{k}}-\alpha *\bigtriangledown f(w^{k};x^{_{i}};y^{_{i}})

其中 w为模型参数, \bigtriangledown f(w^{k};x^{_{i}};y^{_{i}})为一个样本输入后的模型参数更新梯度,\alpha为学习率。

由于要对每个样本都单独计算梯度,那么相当于引入了许多噪声,梯度下降时就会跳出鞍点和局部最优点。但要对每个样本都计算一次梯度就导致了时间复杂度较高,模型收敛较慢,而且loss和梯度会有大幅度的震荡。

  • MBGD (Mini-batch gradient descent)

MBGD相当于缝合了SGD和BGD,即将多个数据样本输入网络计算梯度后就进行一次更新:

w^{^{k+1}} =w^{^{k}}-\alpha *\bigtriangledown f(w^{k};x^{_{i:i+b}};y^{_{i:i+b}})

其中 w为模型参数, \bigtriangledown f(w^{k};x^{_{i:i+b}};y^{_{i:i+b}})为batch size个样本输入后的模型参数更新梯度,\alpha为学习率。

MBGD同时解决了两者的缺点,使得参数更新更稳定更快速,这也是我们最常用的方法,pytorch代码里SGD类也是指的MBGD(当然可以自己设置特殊的batch size,就会退化为SGD或BGD)。


3.SGD优化

实际在pytorch的代码中,还加了两个优化,分别是:

  • Momentum

从原理上可以很好理解,Momentum就是在当前step的参数更新中加入了部分上一个step的梯度,公式表示为:

v^{k} =\gamma *v^{k-1}-\alpha *\bigtriangledown f(w^{k};x^{_{i:i+b}};y^{_{i:i+b}})

w^{^{k+1}} =w^{^{k}}-v^{^{k}}

其中 v^{^{k}}v^{^{k-1}}为当前step和上一个step的动量,即当前step的动量会有当前step的梯度和上一个step的动量叠加计算而来,其中\gamma一般设置为0.9或者0.99。

我们可以从以下两幅示意图中看到区别,第一张图没有加Momentum,第二张图加了Momentum。可以看到在第一张图中,点一开始往梯度变化的方向移动,但是到后来梯度逐渐变小,到最后变为了0,所以最终没有到达最优点。而第二张图由于加了Momentum,所以点会有一个横向移动的惯性,即使到了梯度为0的地方也能依靠惯性跳出。

  • Nesterov accelerated gradient(NAG)

加了Momentum之后,实际上模型参数更新的方向就不是当前点的梯度方向,所以这会一定程度上导致模型更新的不准确。NAG方法就是让参数先根据惯性预测出下一步点应该在的位置,然后根据预测点的梯度再更新一次:

w^{^{k{}'}} =w^{^{k}}-\gamma *v^{^{k-1}}

v^{k} =\gamma *v^{k-1}-\alpha *\bigtriangledown f(w^{k{}'};x^{_{i:i+b}};y^{_{i:i+b}})

w^{^{k+1}} =w^{^{k}}-v^{^{k}}


4.pytorch代码

以下代码为pytorch官方SGD代码,其中关键部分在step()中。

import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required


class SGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et. al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
                  v = \rho * v + g \\
                  p = p - lr * v

        where p, g, v and :math:`\rho` denote the parameters, gradient,
        velocity, and momentum respectively.

        This is in contrast to Sutskever et. al. and
        other frameworks which employ an update of the form

        .. math::
             v = \rho * v + lr * g \\
             p = p - v

        The Nesterov version is analogously modified.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(weight_decay, p.data)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(1 - dampening, d_p)
                    if nesterov:
                        d_p = d_p.add(momentum, buf)
                    else:
                        d_p = buf

                p.data.add_(-group['lr'], d_p)

        return loss

业务合作/学习交流+v:lizhiTechnology

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【优化器】(一) SGD原理 & pytorch代码解析 的相关文章

随机推荐

  • web服务器响应的端口号,web服务器端口号

    web服务器端口号 内容精选 换一换 Nginx Web Server场景是以Nginx作为Web Server的场景 Nginx作为Web Server 可以被配置部署为静态资源Web Server 在该配置下可以高效的进行静态资源的请求
  • python学习笔记——条件判断

    上篇 https blog csdn net qq 42489308 article details 89388218 条件判断 条件判断是通过一条或多条判断语句的执行结果 True或者False 来决定执行的代码块 在Python语法中
  • uboot分析之Makefile

    Uboot分析之Makefile 1 uboot根目录下执行 make smdk2410 config smdk2410 config unconfig MKCONFIG config arm arm920t smdk2410 samsun
  • 数据集下载OTB,VOT,UAV,鸢尾花

    OTB数据集下载百度网盘链接 链接 https pan baidu com s 1snsJF 7Sw EbKtzdvLO1nw 提取码 ls23 VOT数据集下载百度网盘链接 链接 https pan baidu com s 1UiTG1z
  • AI顶级会议列表 & ACL相关

    The First Class tier 1的conferences 其实基本上就是AI里面大家比较公认的top conference 下面同分的按字母序排列 IJCAI 1 AI最好的综合性会议 1969年开始 每两年开一次 奇数年开 因
  • 基于互补搜索技术和新颖架构设计,结合MobileNetV3主干网络,打造不同的目标检测器

    基于互补搜索技术和新颖架构设计 结合MobileNetV3主干网络 打造不同的目标检测器 目标检测是计算机视觉中的一个重要任务 随着深度学习技术的发展和神经网络的不断优化 YOLOv5已成为目前最流行的目标检测框架之一 然而 为了进一步提高
  • opengl shader 使用札记

    一 shader的使用步骤 创建shader 1 创建一个shader对象 GLuint glCreateShader GLenum shaderType 2 将shader源代码传入前面创建的shader对象 void glShaderS
  • 老嫂子的保姆级科普 选择视频剪辑软件就从阅读本文开始

    选错一款视频剪辑软件 是种什么样的体验 就好像新婚当晚 发现老婆是人妖一样 浪费了感情 又错付了青春 新手在学习视频剪辑的初期 需要花费大量精力去熟悉剪辑软件的基础功能 而软件挑选本身没有对错可言 适合自己的才是最好的 因此 本文仅从事实与
  • 初识Java(一)

    Java开发语言 前言 一 Java是什么 二 应用领域 特点及核心机制 1 应用领域 2 特性及特点 特性 特点 3 两种核心机制 三 JDK JRE JVM的关系 四 Java环境变量配置 五 编写我的第一个程序 总结 前言 计算机语言
  • C# 实现rabbitmq 延迟队列功能(不堵塞)

    最近在研究rabbitmq 项目中有这样一个场景 在用户要支付订单的时候 如果超过30分钟未支付 会把订单关掉 当然我们可以做一个定时任务 每个一段时间来扫描未支付的订单 如果该订单超过支付时间就关闭 但是在数据量小的时候并没有什么大的问题
  • 计算机基础msoffice等宽两缆,一级计算机基础及《MSOffice应用》模拟题

    一级计算机基础及 MSOffice应用 模拟题 三 字处理题 共25分 26 在考生文件夹下打开文档WORD DOCX 按照要求完成下列操作并以该文件名 WORD oocx 保存文档 文档开始 IBM电子商务专利的特点 通过对IBM e c
  • Typora主题下载

    1 0前言 Typora有很多主题可以使用 默认的主题很少 想要自己的主题更加个性化 可以去添加更多的主题来优化自己的使用体验 2 0下载主题 2 1 找到Typora主题的网站 1 打开一个typora文件此点击 2 进入偏好设置 3依次
  • 【目标检测】32、让你一文看懂且看全 NMS 及其变体

    文章目录 一 NMS 1 1 背景 1 2 方法 1 3 代码 1 4 不足 二 Soft NMS 2 1 背景 2 2 方法 2 3 效果 2 4 代码 2 5 不足 三 Softer NMS 3 1 背景 3 2 方法 四 IoU Ne
  • MySQL开启bin_log后导致创建函数、存储过程失败。Error:Result_ 1418 - This function has none of DETERMINISTIC

    搭建分布式服务 使用了主从数据库 需要使用MySQL的binlog去同步数据 但是开启binlog后导致新增函数 存储过程等报错 具体报错信息如下 Result 1418 This function has none of DETERMIN
  • kitti depth complement

    代码 运行环境 windows10 open3d版本 0 12 0 import cv2 import numpy as np import os import math import open3d as o3d basic path D
  • 【好工具】网页剪藏+免费云端笔记+一键变博客

    欢迎大家来到 好工具 专栏 这个专栏面向所有希望获得高效生产力工具的朋友 在这个专栏里 我们会和大家聊聊那些狂拽酷霸炫的生产力工具 相信大家一定我一样 茫然于庞大的工具海洋 却仍找不到称心的它来使用 这也是 好工具 专栏存在的意义 发掘 折
  • 贝叶斯优化及其python实现

    贝叶斯优化是机器学习中一种常用的优化技术 其目的是在有限步数内寻找函数的最大值或最小值 它可以被视为在探索不同参数配置与观察这些配置结果之间寻求平衡点的过程 基本思想是将我们在过去的观察和体验 传递到下一个尝试中 从而在等待数据的反馈时 逐
  • 微信小程序开发实战第五讲之授权登录

    上一节 我们实现了简单的通过用户名和密码调用接口进行登录的实战 但是在小程序中 有个特殊的情况 就是很少有厂商去开发一个注册功能或者是通过用户名 密码来登录的逻辑 为什么 因为APP 小程序为了用户体验 是尽量多的避免用户多次输入交互 所以
  • 物联网LoRa系列-17:LoRa终端Sx1262芯片内部的射频信号放大器

    至此 我们已经拆解了天线是如何发送和接收空中的无线电磁波信号 拆解了无线终端如何对射频前端的高频电信号进行进一步处理的 还拆解了无线终端的发送和接收如何分时复用天线的半双工模式 本篇将进一步拆解无线终端是如何对射频电信号进行进一步的处理 包
  • 【优化器】(一) SGD原理 & pytorch代码解析

    1 简介 很多情况下 我们调用优化器的时候都不清楚里面的原理和构造 主要基于自己数据集和模型的特点 然后再根据别人的经验来选择或者尝试优化器 下面分别对SGD的原理 pytorch代码进行介绍和解析 2 梯度下降 梯度下降方法可以分为3种