Adam和AdamW的区别

2023-10-26

Adam 与 Adamw的区别

一句话版本

Adamw 即 Adam + weight decate ,效果与 Adam + L2正则化相同,但是计算效率更高,因为L2正则化需要在loss中加入正则项,之后再算梯度,最后在反向传播,而Adamw直接将正则项的梯度加入反向传播的公式中,省去了手动在loss中加正则项这一步

实验

Adamw算法:

在这里插入图片描述

图片参考:https://www.cnblogs.com/tfknight/p/13425532.html

ps:以现在的源码为标准,图中的代码12行中:
λ = w d ∗ l r \lambda = wd * lr λ=wdlr
wd=weight decay, lr=learning rate

实验代码:

my_opt是手动实现的代码,主要验证梯度下降是否正确

import torch
import torch.nn as nn
import torch.optim.lr_scheduler
from torch.optim import AdamW, Adam


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1, bias=False)

    def forward(self, x):
        return self.fc(x)


class my_opt():
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2):
        self.params = params
        self.lr = lr
        self.b1 = betas[0]
        self.b2 = betas[1]
        self.eps = eps
        self.wd = weight_decay
        self.m = 0
        self.v = 0
        self.b1t = 1.0
        self.b2t = 1.0

    # 模拟AdamW的step
    def stepW(self):
        for name, param in self.params.named_parameters():
            if param.grad is None:
                continue
            g = param.grad
            self.m = self.b1 * self.m + (1 - self.b1) * g
            self.v = self.b2 * self.v + (1 - self.b2) * g * g
            self.b1t *= self.b1
            self.b2t *= self.b2
            m = self.m / (1 - self.b1t)
            v = self.v / (1 - self.b2t)
            n = 1.0
            param.data -= n * (self.lr * (m / (v.sqrt() + self.eps) + self.wd * param.data))

    # 模拟Adam的step
    def step(self):
        for name, param in self.params.named_parameters():
            if param.grad is None:
                continue
            g = param.grad
            self.m = self.b1 * self.m + (1 - self.b1) * g
            self.v = self.b2 * self.v + (1 - self.b2) * g * g
            self.b1t *= self.b1
            self.b2t *= self.b2
            m = self.m / (1 - self.b1t)
            v = self.v / (1 - self.b2t)
            n = 1.0
            param.data -= n * (self.lr * m / (v.sqrt() + self.eps))


adam_model = M()
adamw_model = M()
my_adam_model = M()
my_adamw_model = M()
# 使4个模型参数相同
adamw_model.load_state_dict(adam_model.state_dict())
my_adam_model.load_state_dict(adam_model.state_dict())
my_adamw_model.load_state_dict(adam_model.state_dict())

model_ls = {'adam_model': adam_model,
            'adamw_model': adamw_model,
            'my_adam_model': my_adam_model,
            'my_adamw_model': my_adamw_model}

# 检查4个模型初始参数
for m in model_ls:
    print(f"Model : {m}")
    model = model_ls[m]
    for name, parma in model.named_parameters():
        print(name)
        print(parma)

adam_opt = Adam(adam_model.parameters(), lr=0.1)
adamw_opt = AdamW(adamw_model.parameters(), lr=0.1)
my_adam_opt = my_opt(my_adam_model, lr=0.1)
my_adamw_opt = my_opt(my_adamw_model, lr=0.1)

opt_ls = {'adam_model': adam_opt,
          'adamw_model': adamw_opt,
          'my_adam_model': my_adam_opt,
          'my_adamw_model': my_adamw_opt}

for i in range(5):
    print(">>>>>>>>>>>>>>>>>>>>>>> epoch", i)
    ip = torch.rand(2, 3)
    for m in model_ls:
        print(f">>> Model : {m}")
        model = model_ls[m]
        opt = opt_ls[m]
        loss = (model(ip).sum()) ** 2
        loss.backward()

        if m != 'my_adamw_model':
            opt.step()
        else:
            opt.stepW()

        for name, parma in model.named_parameters():
            print(name)
            print(parma)

        model.zero_grad()

代码是照着图片公式写的,和源码还是有点区别的,但是可以验证每一步的梯度下降都对了

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Adam和AdamW的区别 的相关文章

  • 如何查找分布式dask中任务失败的原因?

    我正在开发一个分布式计算系统dask distributed 我通过以下方式提交给它的任务Executor map功能有时会失败 而其他看起来相同的功能却可以成功运行 该框架是否提供了诊断问题的方法 update我所说的失败是指增加 Bok
  • 为什么方法无法访问类变量?

    我试图理解Python中的变量作用域 除了我不明白为什么类变量不能从其方法访问的部分之外 大多数事情对我来说都很清楚 在下面的例子中mydef1 无法访问a 但如果a可以在全局范围 类定义之外 声明 class MyClass1 a 25
  • 蟒蛇 |如何将元素随机添加到列表中

    有没有一种方法可以将元素随机添加到列表中 内置函数 ex def random append lst a lst append b lst append c lst append d lst append e return print ls
  • 返回不包括指定键的字典副本

    我想创建一个函数 返回字典的副本 不包括列表中指定的键 考虑这本词典 my dict keyA 1 keyB 2 keyC 3 致电without keys my dict keyB keyC 应该返回 keyA 1 我想用一行简洁的字典理
  • multiprocessing.freeze_support()

    为什么多处理模块需要调用特定的function http docs python org dev library multiprocessing html multiprocessing freeze support在被 冻结 以生成 Wi
  • 在python中调用subprocess.Popen时“系统找不到指定的文件”

    我正在尝试使用svnmerge py合并一些文件 它在底层使用 python 当我使用它时 我收到一个错误 系统找不到指定的文件 工作中的同事正在运行相同版本的svnmerge py 以及 python 2 5 2 特别是 r252 609
  • 熊猫记忆

    我有冗长的计算 我重复了很多次 因此 我想使用记忆 诸如jug http packages python org Jug and joblib http packages python org joblib memory html 与Pan
  • 这可能是因为 cuDNN 初始化失败,因此请尝试查看上面是否打印了警告日志消息。 [操作:Conv2D]

    我在 anaconda 中安装了 TensorFlow GPU 2 0 当我安装它并导入包 然后运行我的 CNN 模型时 它工作正常 但当我尝试运行训练模型时 出现错误 这是我的错误报告 Epoch 1 50 UnknownError Tr
  • 别碰我的女人

    我讨厌的一件事迪斯图尔斯 http docs python org distutils 我猜他是邪恶的人 他这样做了 https github com python cpython blob 300dd552b15825abfe0e367a
  • 如何获取 Matplotlib 生成的散点图的像素坐标?

    我使用 Matplotlib 生成散点图的 PNG 文件 现在 对于每个散点图 除了 PNG 文件之外 我还会also就像生成散点图中各个点的像素坐标列表一样 我用来生成散点图 PNG 文件的代码基本上是这样的 from matplotli
  • 从 Apache 运行 python 脚本的最简单方法

    我花了很长时间试图弄清楚这一点 我基本上正在尝试开发一个网站 当用户单击特定按钮时 我必须在其中执行 python 脚本 在研究了 Stack Overflow 和 Google 之后 我需要配置 Apache 以便能够运行 CGI 脚本
  • django如何将字符串转换为模块?

    我试图了解 django 的另一个神奇之处 它可以将字符串转换为模块 In settings py INSTALLED APPS声明如下 INSTALLED APPS django contrib auth django contrib c
  • 向伪 shell (pty) 发出命令

    我尝试使用 subprocess popen os spawn 来运行进程 但似乎需要伪终端 import pty master slave pty openpty os write master ls l 应该发送 ls l 到从属终端
  • Pandas DataFrame:如何计算组中第一行和最后一行的差异?

    这是我的熊猫数据框 import pandas as pd import numpy as np data column1 338 519 871 1731 2693 2963 3379 3789 3910 4109 4307 4800 4
  • scrapy python 请求未定义

    我在这里找到了答案 code for site in sites Link site xpath a href extract CompleteLink urlparse urljoin response url Link yield Re
  • Python“self”关键字[重复]

    这个问题在这里已经有答案了 我是 Python 新手 通常使用 C 最近几天开始使用它 在类中 是否需要在对该类的数据成员和方法的任何调用前添加前缀 因此 如果我在该类中调用方法或从该类获取值 我需要使用self method or sel
  • Django - 缺少 1 个必需的位置参数:'request'

    我收到错误 get indiceComercioVarejista 缺少 1 个必需的位置参数 要求 当尝试访问 get indiceComercioVarejista 方法时 我不知道这是怎么回事 views from django ht
  • 将数组从 .npy 文件读入 Fortran 90

    我使用 Python 以二维数组 例如 X 的形式生成一些初始数据 然后使用 Fortran 对它们进行一些计算 最初 当数组大小约为 10 000 x 10 000 时 np savetxt 在速度方面表现良好 但是一旦我开始增加数组的维
  • 如何抑制 Pandas Future 警告?

    当我运行该程序时 Pandas 每次都会给出如下所示的 未来警告 D Python lib site packages pandas core frame py 3581 FutureWarning rename with inplace
  • 从 Flask 中的 S3 返回 PDF

    我正在尝试在 Flask 应用程序的浏览器中返回 PDF 我使用 AWS S3 来存储文件 并使用 boto3 作为与 S3 交互的 SDK 到目前为止我的代码是 s3 boto3 resource s3 aws access key id

随机推荐

  • R手册(Tidy+Transform)--缺失处理(naniar and simputation)

    文章目录 naniar 缺失数据摘要 阴影矩阵 可视化缺失值变量分布关系 simputation make imputation simpler for missing data 缺失值是指粗糙数据中由于缺少信息而造成的数据的聚类 分组 删
  • cloudstack GuestNetwork Ingress-Egress rule

    Egress 1 创建 egress 规则 1 向management发出api命令 createEgressFirewallRulecmd 的create 方法 最终在cloud数据库firewall rules表中插一条state Ad
  • 阿里druid-spring-boot-starter 配置,个人整理以及遇到的问题(防止之后找不到)

    简介 什么是Druid Druid是阿里巴巴开源平台上的一个项目 整个项目由数据库连接池 插件框架和SQL解析器组成 该项目主要是为了扩展JDBC的一些限制 可以让程序员实现一些特殊的需求 比如向密钥服务请求凭证 统计SQL信息 SQL性能
  • 服务器上使用screen和linux的基本操作

    临时换源 pip install torch 1 7 1 i https pypi tuna tsinghua edu cn simple some package pip install torch 1 7 1 i http pypi d
  • MATLAB入门到精通(三):常用函数及数学应用

    合集如下 MATLAB入门到精通 一 简介及数据类型 MATLAB入门到精通 二 基本语句及绘图 MATLAB入门到精通 三 常用函数及数学应用 十一 常用函数 11 1 随机数函数 11 1 1 rand 函数 rand 函数用来产生均匀
  • 创建一个React项目实现一个计算器

    使用环境react脚手架 node js create react app 文件名 配置完这些就让我们开始把 count js import React Component from react import store from redu
  • mybatis中的typeAlias

    mybatis 的 xml 文件中需要写类的全限定名 较繁琐 可以配置自动扫描包路径给类配置别名 有两种配置方式 方式一 mybatis config xml 中配置
  • AB32VG1项目之智能晾衣架

    智能晾机架项目 开发过程 前期准备 分离工程 导入工程 安装包 安装最近的rt thread 包 AB32VG1的 SDK包 RISC V GCC工具链 下载 硬件搭建 开发板上的3 3V能否可用的问题 大体的硬件规划 软件设计 控制逻辑设
  • 关于Unrecognized Windows Sockets error: 5: socket write error 错误

    最近有个需求是从A数据库读取数据导入到B数据库 demo的数据量也就几万条 但是遇到了一个非常罕见的问题 后端框架是mybatis plus spring boot 在insertBatch到数据库B时 没有立即报错 而是执行插入了几百条数
  • 小程序通过webView打开H5页面并传参(包含webView业务域名配置)、H5页面实现返回小程序并实现传参

    小程序内嵌webview实现跳转 传参 1 小程序通过webView打开H5页面并传参 2 H5接收小程序传参 H5返回小程序并实现传参 小程序接收H5传参 目录 一 小程序通过webView打开H5页面并传参 1 业务域名 2 在小程序中
  • (转)认识SAP SD销售模式之跨公司销售

    跨公司销售 销售订单的发货工厂对应的公司和销售组织对应的公司不同 比如 9801公司为销售性公司 9901为生产性的公司 当公司9801接到订单后 直接从9901公司发货 如果不通过跨公司销售 需要9801像9901公司下虚拟的采购订单 然
  • win10 下 Linux使用方法笔记

    最近想学习一下比特币源码 官方推荐是在Linux系统下学习 且推荐在win10 下的Linux系统进行编译运行 所以下面将学习过程记录一下 1 参考了这篇文章中的方法 进行安装WSL https www cnblogs com JettTa
  • agoda获取酒店数据

    最近改了改代码 正好解决了一些报错问题 更新出来 个别处会加蜜 数据库以及线程控制 from DBUtils PooledDB import PooledDB import requests import demjson import ti
  • 堆和栈的区别

    1 1内存分配方面 堆 一般由程序员分配释放 若程序员不释放 程序结束时可能由OS回收 注意它与数据结构中的堆是两回事 分配方式是类似于链表 可能用到的关键字如下 new malloc delete free等等 栈 由编译器 Compil
  • leetcode622-设计循环队列

    本题重点 1 选择合适的数据结构 2 针对选择的数据结构判断 空 和 满 这两点是不分先后次序的 在思考时应该被综合起来 事实上 无论我们选择链表还是数组 最终都能实现题中描述的 循环队列 的功能 只不过选择不同结构时 我们面临和需要解决的
  • 不是一个PDF文件或该文件已损坏

    之前用公司电脑打开PDF文档的时候 出现了这样的一种现象 就是提示格式错误 不是一个PDF文件或该文件已被损坏 有三种解决方法 1 有可能是电脑上自带的PDF阅读软件版本太低 出现了不兼容的现象 换个最新的PDF阅读器吧 我用了福昕阅读器很
  • 【死磕NIO】— 探索 SocketChannel 的核心原理

    大家好 我是大明哥 一个专注于 死磕 Java 系列创作的程序员 死磕 Java 系列为作者 chenssy 倾情打造的 Java 系列文章 深入分析 Java 相关技术核心原理及源码 死磕 Java https www cmsblogs
  • oracle不小心将表update修改了如何回滚

    oracle提供了一种闪回的方法 可以将某个时间的数据给还原回来 SELECT FROM T DIS EVENT RELATION TYPE AS OF TIMESTAMP TO TIMESTAMP 2023 08 08 15 31 00
  • python opencv 在线读取网络图片图像资源

    opencv在线读取网络图片图像资源 照例打开opencv3 3 0 python3 6官方文档 https docs opencv org master d8 dfe classcv 1 1VideoCapture html 详解 官方文
  • Adam和AdamW的区别

    Adam 与 Adamw的区别 一句话版本 Adamw 即 Adam weight decate 效果与 Adam L2正则化相同 但是计算效率更高 因为L2正则化需要在loss中加入正则项 之后再算梯度 最后在反向传播 而Adamw直接将