深度学习入门之SGD随机梯度下降法

2023-10-27

SGD

SGD为随机梯度下降法。用数学式可以将 SGD 写成如下的式(6.1)。
在这里插入图片描述
这里把需要更新的权重参数记为W,把损失函数关于W的梯度记为 ∂L/∂W 。 η η η 表示学习率,实际上会取 0.01 或 0.001 这些事先决定好的值。式子中的←表示用右边的值更新左边的值。

如式(6.1)所示,SGD 是朝着梯度方向只前进一定距离的简单方法。现在,将 SGD 实现为一个 Python 类(为方便后面使用,将其实现为一个名为 SGD 的类)。

class SGD:
    def __init__(self, lr=0.01):
        self.lr = lr#学习率

    def update(self, params, grads):
        for key in params.keys():
            params[key] -= self.lr * grads[key]

进行初始化时的参数 lr 表示 learning rate(学习率)。这个学习率会保存为实例变量。此外,代码段中还定义了 update(params, grads) 方法,这个方法在 SGD 中会被反复调用。

参数 params 和 grads 是字典型变量,按 params[‘W1’] 、grads[‘W1’] 的形式,分别保存了权重参数和它们对应的梯度。

使用这个 SGD 类,可以按如下方式进行神经网络的参数的更新(下面的代码是不能实际运行的伪代码)。

network = TwoLayerNet(...)
optimizer = SGD()
for i in range(10000):#更新次数
    ...
    x_batch, t_batch = get_mini_batch(...) # mini-batch
    grads = network.gradient(x_batch, t_batch)
    params = network.params
    optimizer.update(params, grads)
    ...

这里首次出现的变量名 optimizer 表示“进行最优化的人”的意思,这里由 SGD 承担这个角色。参数的更新由 optimizer 负责完成。我们在这里需要做的只是将参数和梯度的信息传给 optimizer 。

SGD的缺点

虽然 SGD 简单,并且容易实现,但是在解决某些问题时可能没有效率。这里,在指出 SGD 的缺点之际,我们来思考一下求下面这个函数的最小值的问题。
在这里插入图片描述

如图 6-1 所示,式(6.2)表示的函数是向 x 轴方向延伸的“碗”状函数。实际上,式(6.2)的等高线呈向 x 轴方向延伸的椭圆状。
在这里插入图片描述

图 6-1  的图形(左图)和它的等高线(右图)

现在看一下式(6.2)表示的函数的梯度。如果用图表示梯度的话,则如图 6-2 所示。这个梯度的特征是, y y y 轴方向上大, x x x 轴方向上小。换句话说,就是 y y y 轴方向的坡度大,而 x x x 轴方向的坡度小。

这里需要注意的是,虽然式 (6.2) 的最小值在 ( x , y ) = ( 0 , 0 ) (x , y ) = (0, 0) (x,y)=(0,0) 处,但是图 6-2 中的梯度在很多地方并没有指向 ( 0 , 0 ) (0, 0) (0,0)
在这里插入图片描述

图 6-2 f(x,y)=1/20 x^2 + y^2的梯度

我们来尝试对图 6-1 这种形状的函数应用 SGD。从 ( x , y ) = ( − 7.0 , 2.0 ) (x , y ) = (-7.0, 2.0) (x,y)=(7.0,2.0) 处(初始值)开始搜索,结果如图 6-3 所示。
在这里插入图片描述

图 6-3 基于 SGD 的最优化的更新路径:呈“之”字形朝最小值 (0, 0) 移动,**效率低**

在图 6-3 中,SGD 呈“之”字形移动。这是一个相当低效的路径。也就是说,SGD 的缺点是,如果函数的形状非均向(anisotropic),比如呈延伸状,搜索的路径就会非常低效。因此,我们需要比单纯朝梯度方向前进的 SGD 更聪明的方法。

SGD 低效的根本原因是,梯度的方向并没有指向最小值的方向。为了改正SGD的缺点,引入了MomentumAdaGrad、Adam这 3 种方法来取代SGD。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习入门之SGD随机梯度下降法 的相关文章

  • 查找 with: 块中定义的函数

    这是一些代码理查德 琼斯的博客 http www mechanicalcat net richard log Python Something I m working on 3 with gui vertical text gui labe
  • 将 pandas 数据框中的列减去其第一个值

    我需要将 pandas 数据帧的一列中的所有元素减去其第一个值 在这段代码中 pandas 抱怨 self inferred type 我猜这是循环引用 df Time df Time df Time 0 在这段代码中 pandas 抱怨为
  • Matplotlib 标准化颜色条 (Python)

    我正在尝试使用 matplotlib 当然还有 numpy 绘制轮廓图 它有效 它绘制了它应该绘制的内容 但不幸的是我无法设置颜色条范围 问题是我有很多图 并且需要所有图都具有相同的颜色条 相同的最小值和最大值 相同的颜色 我复制并粘贴了在
  • 如何在 Ubuntu 上安装 Python 模块

    我刚刚用Python写了一个函数 然后 我想将其做成模块并安装在我的 Ubuntu 11 04 上 这就是我所做的 创建 setup py 和 function py 文件 使用 Python2 7 setup py sdist 构建分发文
  • 用 Python 编写一个无操作或虚拟类

    假设我有这样的代码 foo fooFactory create 由于种种原因 fooFactory create 可能无法创建实例Foo 如果可以的话我想要fooFactory create 返回一个虚拟 无操作对象 这个对象应该是完全惰性
  • 如何自动替换多个文件的文本内容中的字符?

    我有一个文件夹 myfolder包含许多乳胶表 我需要替换其中每个字符 即替换任何minus sign by an en dash 只是为了确定 我们正在替换连字符INSIDE该文件夹中的所有 tex 文件 我不关心 tex 文件名 手动执
  • 无法包含外部 pandas 文档 Pycharm v--2018.1.2

    我无法包含外部 pandas 文档Pycharm v 2018 1 2 例如 numpy gt http docs scipy org doc numpy reference generated module name element na
  • 如何使用 openpyxl 对工作簿中的 Excel 工作表/选项卡进行排序

    我需要按字母数字对工作簿中的选项卡 工作表进行排序 我在用openpyxl https openpyxl readthedocs io en default 操作工作表 您可以尝试排序workbook sheets list workboo
  • 唯一的图像哈希值即使 EXIF 信息更新也不会改变

    我正在寻找一种方法来为 python 和 php 中的图像创建唯一的哈希值 我考虑过对原始文件使用 md5 和 因为它们可以快速生成 但是当我更新 EXIF 信息 有时时区关闭 时 它会更改总和 并且哈希也会更改 有没有其他方法可以为这些文
  • 反加入熊猫

    我有两个表 我想附加它们 以便仅保留表 A 中的所有数据 并且仅在其键唯一时添加表 B 中的数据 键值在表 A 和 B 中是唯一的 但在某些情况下键将出现在表 A 和 B 中 我认为执行此操作的方法将涉及某种过滤联接 反联接 以获取表 B
  • 我可以使用 dask 创建 multivariate_normal 矩阵吗?

    有点相关这个帖子 https stackoverflow com questions 52337612 random multivariate normal on a dask array 我正在尝试复制multivariate norma
  • 如何逐像素绘制正方形(Python,PIL)

    在空白画布上 我想使用 Pillow 逐像素绘制一个正方形 我尝试使用 img putpixel 30 60 155 155 55 绘制一个像素 但它没有执行任何操作 from PIL import Image def newImg img
  • Python Flask 是否定义了路由顺序?

    在我看来 我的设置类似于以下内容 app route test def test app route
  • python中的sys.stdin.fileno()是什么

    如果这是非常基本的或之前已经问过的 我很抱歉 我用谷歌搜索但找不到简单且令人满意的解释 我想知道什么sys stdin fileno is 我在代码中看到了它 但不明白它的作用 这是实际的代码块 fileno sys stdin filen
  • WindowsError:[错误 5] 访问被拒绝

    我一直在尝试终止一个进程 但我的所有选项都给出了 Windows 访问被拒绝错误 我通过以下方式打开进程 一个python脚本 test subprocess Popen sys executable testsc py 我想杀死那个进程
  • 使用 lambda 函数更改属性值

    我可以使用 lambda 函数循环遍历类对象列表并更改属性值 对于所有对象或满足特定条件的对象 吗 class Student object def init self name age self name name self age ag
  • 使用 Doc2vec 后如何解释 Clusters 结果?

    我正在使用 doc2vec 将关注者的前 100 条推文转换为矢量表示形式 例如 v1 v100 之后 我使用向量表示来进行 K 均值聚类 model Doc2Vec documents t size 100 alpha 035 windo
  • Python模块单元测试的最佳文件结构组织?

    遗憾的是 我发现有太多方法可以在 Python 中保存单元测试 而且它们通常没有很好的文档记录 我正在寻找一种 终极 结构 它可以满足以下大部分要求 be discoverable by test frameworks including
  • 使用 Keras 和 fit_generator 绘制 TensorBoard 分布和直方图

    我正在使用 Keras 使用 fit generator 函数训练 CNN 这似乎是一个已知问题 https github com fchollet keras issues 3358TensorBoard 在此设置中不显示直方图和分布 有
  • 从时间序列生成日期特征

    我有一个数据框 其中包含如下列 Date temp data holiday day 01 01 2000 10000 0 1 02 01 2000 0 1 2 03 01 2000 2000 0 3 30 01 2000 200 0 30

随机推荐

  • 最新Python爬虫有道翻译JS逆向解析详细介绍版,附源码

    我的第一篇文章 写的很详细 这里方便刚接触爬虫帅哥们理解 大家一起加油 前两步为js的逆向分析过程 了解过程的请跳到第三步 源码 最后打包成exe文件 有道翻译网址 在线翻译 有道 第一步 找到有道翻译发送请求的Url地址 老规矩进去界面F
  • 从方法到目标了解什么是机器学习?

    一 什么是机器学习 1 简述 机器学习是 人工智能 AI 和计算机科学的一个分支 专注于利用数据和算法来模仿人类的学习方式 逐步提高其准确性 过去几十年来 存储和处理能力方面的技术进步催生了一些基于机器学习的创新产品 例如 Netflix
  • NodeJs快速入门

    NodeJs入门介绍 Node js是一个Javascript运行环境 runtime 实际上它是对Google V8引擎进行了封装 所以 语法还是JavaScript的语法 只不过它封装了一些类库 可以更多的事 nodejs官网 在命令行
  • Python接口基础: WSDL 文件(soap )照样可以用requests进行post

    昨天 遇到一个难题 我接到一个webservice API 接口进行批量出单任务 造数据 方便测试report XML 内容如下
  • java实现通过共享文件夹实现文件上传下载(附源码工具类)

    1 简介 要实现文件上传下载基于smb协议 SMB Server Message Block 通信协议是微软 Microsoft 和英特尔 Intel 在1987年制定的协议 主要是作为Microsoft网络的通讯协议 SMB 是在会话层
  • 报错注入(主键重复)攻击原理

    基本原理 利用数据表中主键不能重复的特点 通过构造重复的主键 使得数据库报错 并将报错结果返回到前端 SQL说明函数 以pet数据表为例进行说明 rond 返回 0 1 区间内的任意浮点数 count 返回每个组的列行数 如 返回test表
  • 饿了么 (Element)的 日历(Calendar)组件 自定义

    笔记 由于本人用vue elementui 写了一个关于日历的项目 需求是每个日期对应不同的价格并且点击两次之后取其区间的值并计算出总价 后来翻了很多资料才找到一些思路 由于在饿了么ui的日历组件库里面没有这些方法 所以很奇怪 具体实现的方
  • CSS——三种导入方式

    h1 h1
  • Linux服务器安装anaconda和pytorch

    Linux服务器安装anaconda 参考链接 如何在Linux服务器上安装Anaconda 超详细 在官网上下载需要的版本 https repo anaconda com archive 注意尽量安装最新版本 Linux服务器安装pyto
  • 【计算机视觉

    文章目录 一 检测相关 16篇 1 1 Contextual Object Detection with Multimodal Large Language Models 1 2 Towards minimizing efforts for
  • 2023年直播行业的困境是什么?未来有哪些发展趋势?

    仅仅两年 现场直播货物完全着火 疫情再次将现场直播货物推向新的热潮 现场直播货物真的是未来的趋势吗 从比亚 李佳琦的货物神话到网红 明星 主持人 创业者 选手 企业干部 社长 法官 县长 市长等都陆续进入 与其他电子商务模式相比 直播电子商
  • 李宏毅2021年机器学习作业5(Seq2seq)实验记录

    李宏毅2021年机器学习作业5学习笔记 前言 一 问题描述 二 实验过程 2 1 基于RNN 2 2 基于Transformer 三 总结 前言 声明 本文参考了李宏毅机器学习2021年作业例程 开发平台是colab 一 问题描述 机器翻译
  • ClickHouse数据库与PHP的无缝集成

    ClickHouse数据库是一种基于列的数据库 支持高效数据的存储和查询 而PHP是一种流行的Web编程语言 被广泛应用于Web开发 在实际应用中 我们经常需要将PHP与ClickHouse进行集成 以实现高效的数据处理和查询 本文将探讨如
  • QT中事件及事件处理

    QT中事件及事件处理 什么是事件 事件与Qt中信号的区别 个人所见 事件是应用程序对内部或者外部的动作的统称 信号是事件的后续响应通知 例如你点击了一个按钮 物理上的鼠标点击动作就是事件 而程序收到事件时 就会发出按钮被按下的信号 通知按钮
  • mybatis+MySQL 新增数据返回主键id问题

    今天遇到个问题 怕自己又忘记 记录一下 有个需求 需要存入数据到MySQL后要返回主键id 我按照以前设置的方式得到的结果始终是1 就非常奇怪 找了原因 记录一下 int count userMapper insert user 拿到的是插
  • BigDecimal处理 四舍五入

    最近项目中遇到了关于BigDecimal取舍精度的问题 还遇到了一些坑 在此记录一下 public static void main String args BigDecimal bd new BigDecimal 10 5 int cou
  • python 微信授权 昵称乱码解决

    微信采用的是 ISO 8859 1 编码 所以只需要进行下面的转码 就可以了 先iso8859 1 解码 然后转换成 utf8 即可 print info nickname encode iso8859 1 decode utf8
  • oracle SGA

    三 实例内存结构和进程结构 由于内存结构和进程结构关系较紧密 进程会作用到对应的内存区域 比如数据库写入器作用到数据库缓冲区缓存中 日志写入器会作用到日志缓冲区 所以内存结构和进程结构会相互配合地进行描述 oracle实例内存结构由两部分组
  • Tensorboard打不开的解决方法

    最近在学习tensorflow 遇到了tensorboard打不开的现象 在网上在了一些方法 把他们全部总结在这里 1 如果在调用tensorboard之后 cmd中的链接打不开的话 可以试试127 0 0 1 6006或者localhos
  • 深度学习入门之SGD随机梯度下降法

    SGD SGD为随机梯度下降法 用数学式可以将 SGD 写成如下的式 6 1 这里把需要更新的权重参数记为W 把损失函数关于W的梯度记为 L W 表示学习率 实际上会取 0 01 或 0 001 这些事先决定好的值 式子中的 表示用右边的值