计算梯度的三种方法: 数值法,解析法,反向传播法

2023-11-07

计算梯度的三种方法: 数值法,解析法,反向传播法

一个简单的函数:

Python:

f(x,y,z)=(x+y)z

# coding=gbk

"""
function : f(x,y,z) = (x+y)z
"""
# first method   解析法
def grad1(x,y,z):
    dx = z
    dy = z
    dz = (x+y)
    return (dx,dy,dz)
# second method  数值法
def grad2(x,y,z,epi): 
    # dx
    fx1 = (x+epi+y)*z
    fx2 = (x-epi+y)*z
    dx = (fx1-fx2)/(2*epi)
    # dy
    fy1 = (x+y+epi)*z
    fy2 = (x+y-epi)*z
    dy = (fy1-fy2)/(2*epi)
    # dz
    fz1 = (x+y)*(z+epi)
    fz2 = (x+y)*(z-epi)
    dz = (fz1-fz2)/(2*epi)
    return (dx,dy,dz)
# third method 反向传播法
def grad3(x,y,z): 
    # forward
    p = x+y;
    f = p*z;    
    # backward
    dp = z
    dz = p
    dx = 1 * dp
    dy = 1 * dp
    return (dx,dy,dz)

print ("<df/dx,df/dy,df/dz>: %.2f %.2f %.2f"%(grad1(1,2,3)))       
print ("<df/dx,df/dy,df/dz>: %.2f %.2f %.2f"%(grad2(1,2,3,1e-5)))
print ("<df/dx,df/dy,df/dz>: %.2f %.2f %.2f"%(grad3(1,2,3)))

结果:

<df/dx,df/dy,df/dz>: 3.00 3.00 3.00
<df/dx,df/dy,df/dz>: 3.00 3.00 3.00
<df/dx,df/dy,df/dz>: 3.00 3.00 3.00

复杂一点的函数

以Sigmoid 为例:

f(w,x)=11+e(w0x0+w1x1+w2)

上面的Sigmoid 函数是输入二维的情况。 x=[x0x1]T
, w=[w0,w1]T , w2=b

显然函数是一个复合函数,是简单函数: f(x)=1x,f(x)=ex,f(x)=ax,f(x)=c+x 复合而成。

因此,我们可以写成: 波兰表达式树的形式。

这里我们只关心关于 w 的梯度,我们将函数写为:

f(w)=11+e(w0x0+w1x1+w2)

Matlab:


clc;
%% 下面向量书写的格式不采用统一规范形式。例如全部采用列向量的形式等。
w = [2,-3,-3];
x = [-1,-2];
% 一般形式的反向传播
[dw0,dw1,dw2] = grad1(w(1),w(2),w(3),x(1),x(2));
fprintf('%.8f,%.8f,%.8f \n',dw0,dw1,dw2);
% 数值法
[dw0,dw1,dw2] = grad2(w(1),w(2),w(3),x(1),x(2),1e-5);
fprintf('%.8f,%.8f,%.8f \n',dw0,dw1,dw2);
% 技巧形式的反向传播
dw = grad3(w,x);
fprintf('%.8f,%.8f,%.8f \n',dw(1),dw(2),dw(3));
% 解析法
dw = grad4(w,x);
fprintf('%.8f,%.8f,%.8f \n',dw(1),dw(2),dw(3));

% 一般形式的反向传播
function  [dw0,dw1,dw2] = grad1(w0,w1,w2,x0,x1)

% forward
p0 = -1*(w0*x0+w1*x1+w2);
p1 = exp(p0);
p2 = 1+p1;
p3 = 1/p2;
% backward
dp2 = (-1)*(p2^(-2));
dp1 = 1*dp2;
dp0 = dp1*exp(p0);

dw0 = dp0*(-x0);
dw1 = dp0*(-x1);
dw2 = dp0 *(-1);
end
% 数值法
function  [dw0,dw1,dw2] = grad2(w0,w1,w2,x0,x1,epi)
% dw0
f1w0 = 1.0/(1+exp(-1*((w0+epi)*x0+w1*x1+w2)));
f2w0 = 1.0/(1+exp(-1*((w0-epi)*x0+w1*x1+w2)));
dw0 = (f1w0 - f2w0)/(2*epi);
% dw1
f1w1 = 1.0/(1+exp(-1*(w0*x0+(w1+epi)*x1+w2)));
f2w1 = 1.0/(1+exp(-1*(w0*x0+(w1-epi)*x1+w2)));
dw1 = (f1w1 - f2w1)/(2*epi);
% dw2
f1w2 = 1.0/(1+exp(-1*(w0*x0+w1*x1+(w2+epi))));
f2w2 = 1.0/(1+exp(-1*(w0*x0+w1*x1+(w2-epi))));
dw2 = (f1w2 - f2w2)/(2*epi);
end
% 技巧形式的反向传播
% 利用sigmoid 函数的技巧:  sigma(x)' = (1-sigma(x))*sigma(x)
function  dw = grad3(w,x)
% forward
dot = w(1)*x(1) + w(2)*x(2) + w(3);
f = 1.0/(1+exp(-dot));
% backward
ddot = (1-f)*f;
dx = [w(1)*ddot,w(2)*ddot]; % 不输出
dw = [x(1)*ddot,x(2)*ddot,1.0*ddot];
end
% 解析法
%  f(w)' = 1/(1+e^())  * e^() * (-x0)
function  dw = grad4(w,x)
x = [x 1];
dw = (-1)*(1+exp(- w*x'))^(-2)*exp(- w*x').*(-x);
end

结果:

-0.19661193,-0.39322387,0.19661193 
-0.19661193,-0.39322387,0.19661193 
-0.19661193,-0.39322387,0.19661193 
-0.19661193,-0.39322387,0.19661193 

更复杂一些的函数

如下函数:

f(x,y)=x+σ(y)σ(x)+(x+y)2

其中
σ(x)=11+ex

上述公式写出解析形式的表达式,似乎吃力。

略… 请参考[参考文献].

参考文献:

  1. https://zhuanlan.zhihu.com/p/21407711?refer=intelligentunit [CS231n课程笔记翻译:反向传播笔记]
  2. http://cs231n.github.io/optimization-2/ [CS231n backpropagation]
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

计算梯度的三种方法: 数值法,解析法,反向传播法 的相关文章

随机推荐

  • Jmeter系列-测试计划详细介绍(3)

    测试计划的作用 测试计划描述了 Jmeter 在执行时 一系列的步骤 一个完整的测试计划包含了一个或多个 线程组 逻辑控制器 采样器 监听器 定时器 断言和配置元素 Jmeter原件和组件的介绍 基本元件的介绍 多个类似功能组件的 容器 类
  • 浅谈Unity资源异步加载和Coroutine的使用

    为了节省内存 游戏的一些资源往往需要在运行时 runtime 动态加载 如果资源本身加载比较耗时 采用同步方法会产生卡顿现象 对此的解决方法通常采用多线程或者使用引擎本身自带的异步加载方法 在Unity开发中 由于一些方法 如Resourc
  • 微信小程序 audio 音频 组件

    完整微信小程序 Java后端 技术贴目录清单页面 必看 音频 1 6 0版本开始 该组件不再维护 建议使用能力更强的 wx createInnerAudioContext 接口 属性 类型 默认值 必填 说明 最低版本 id string
  • 知识图谱——Python操作Neo4j导入CSV文件建立图谱

    首先Neo4j是图数据库 最重要的就是结点和边的关系 每两个结点和边都可以看成三元组 主谓宾的关系 当然结点也是可以添加属性的 但是首先要有结点 在添加属性 本片文章就是用简单的方式一次性给大家讲解清楚 简单起见 我们用西游记师徒四人为例子
  • HC-SR505红外感应模块驱动(STM32)

    一 前期准备 单片机 STM32F103ZET6 开发环境 MDK5 14 库函数 标准库V3 5 HC SR505红外感应模块 淘宝有售 二 实验效果 三 驱动原理 这个模块比较简单 当有人靠近时候其IO输出3 3V STM32可以直接采
  • Scrapy知识系列:使用CrawlerProcess从外部运行多个spider时,运行脚本需要与scrapy.cfg在同级目录

    说明 如题 否则settings pipelines middlewares都没有办法直接使用 修改起来非常麻烦
  • JAVAWEB学习笔记-前端基础

    文章目录 HTML篇 HTML简介 HTML元素 开始编写 CSS篇 认识css CSS 规则集 解释 css的初步使用 在HTML里使用CSS 外部样式表 内部样式表 内联样式 规则 速记属性 CSS工作原理 HTML篇 HTML简介 参
  • postgres wal2json插件jsonb字段数据丢失问题解决

    使用pg wal2json debezium进行数据同步时 发现偶尔会有jsonb字段数据丢失的问题 进行测试时发现 1 发生数据丢失的jsonb字段长度都比较大 超过toast阈值 使用toast表存储 2 针对发生jsonb字段丢失的数
  • llvm libLLVMCore源码分析 04 - Use Class

    源码路径 llvm include llvm IR Use h llvm include llvm IR Value h llvm include llvm IR User h llvm Use class 在之前的系列文章中 我们讲到Us
  • npm,cnpm,yarn,tyarn 区别

    做前端的应该都用过标题提到的包管理器 简单说一下这4个包管理器的区别 npm 这应该是最常用的 在某些情况会出现丢包 而且由于某种原因会下载很慢 通常会配置国内镜像 我已经很少用npm了 主要用它下载 cnpm 或 yarn cnpm 这个
  • 为什么您的WordPress网站会容易被黑客攻击

    首先 不仅是WordPress 互联网上所有具有内容管理系统 CMS 的网站都容易受到黑客攻击 WordPress网站成为通用目标的原因是因为WordPress是世界上最受欢迎的网站CMS 它为全球超过33 的网站提供支持 这种巨大的流行度
  • 【Spring Boot 源码学习】@SpringBootApplication 注解

    Spring Boot 源码学习系列 SpringBootApplication 注解 引言 主要内容 1 创建 Spring Boot 项目 2 Spring Boot 入口类 3 SpringBootApplication 介绍 总结
  • 结构体中定义函数指针

    结构体指针变量的定义 定义结构体变量的一般形式如下 形式 先定义结构体类型 再定义变量 struct 结构体标识符 成员变量列表 struct 结构体标识符 指针变量名 变量初始化 struct 结构体标识符 变量名 初始化值1 初始化值2
  • 【MYSQL】排序时 如何将0排到最后,并让其他值按正序展示?

    背景 展示排名时需要1 2 3 4 5 这样展示但是有些没有排名得数据字段默认值时0 这时直接用ASC就会出现问题 实现效果 实现方式 使用MySQL的ORDER BY语句来实现 以下是一个示例的SQL查询语句 SELECT FROM ta
  • 划拳 C语言

    划拳是古老中国酒文化的一个有趣的组成部分 酒桌上两人划拳的方法为 每人口中喊出一个数字 同时用手比划出一个数字 如果谁比划出的数字正好等于两人喊出的数字之和 谁就赢了 输家罚一杯酒 两人同赢或两人同输则继续下一轮 直到唯一的赢家出现 下面给
  • IDEA导入maven依赖失败解决方法

    由于网络问题 maven依赖经常会导入失败 一般的jar包是从中央仓库或阿里云仓库进行拉取 网络加载慢超时等原因导致相关依赖jar包导入不全 下面就我在实际的项目导入操作中遇到的问题及解决方法进行总结梳理 希望可以帮助到大家 方法一 更换仓
  • 数据结构-单链表交换相邻两个元素-java

    1 递归法 时间复杂度O n 递归的时间复杂度一般看层数 这个层数是n层 每层执行一次操作 所以是O n 原理 把后半部分看成已经反转好的数据 public ListNode reverseAdjoinList ListNode head
  • 运放的PID电路

    PID就是 比例 proportion 积分 integral 导数 derivative 在工程实际中 应用最为广泛的调节器控制规律为比例 积分 微分控制 简称PID控制 又称PID调节 运放的积分电路 运放的微分电路 微分电路的输出端和
  • 【Python机器学习】KNN进行水果分类和分类器实战(附源码和数据集)

    需要源码和数据集请点赞关注收藏后评论区留言私信 KNN算法简介 KNN K Nearest Neighbor 算法是机器学习算法中最基础 最简单的算法之一 它既能用于分类 也能用于回归 KNN通过测量不同特征值之间的距离来进行分类 KNN算
  • 计算梯度的三种方法: 数值法,解析法,反向传播法

    计算梯度的三种方法 数值法 解析法 反向传播法 一个简单的函数 Python f x y z x y z begin equation begin aligned f x y z x y z end aligned end equation