【deep_thoughts】30_PyTorch LSTM和LSTMP的原理及其手写复现

2023-11-13


视频链接: 30、PyTorch LSTM和LSTMP的原理及其手写复现_哔哩哔哩_bilibili

PyTorch LSTM API:https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html?highlight=lstm#torch.nn.LSTM

LSTM API

首先实例化一些参数:

import torch
import torch.nn as nn

# 定义一些常量
batch_size, seq_len, input_size, h_size = 2, 3, 4, 5
input = torch.randn(batch_size, seq_len, input_size)  # 随机初始化一个输入序列
c_0 = torch.randn(batch_size, h_size)  # 初始值,不会参与训练
h_0 = torch.randn(batch_size, h_size)

调用PyTorch中的 LSTM API:

# 调用官方 LSTM API
lstm_layer = nn.LSTM(input_size, h_size, batch_first=True)  # num_layers默认为1
output, (h_n, c_n) = lstm_layer(input, (h_0.unsqueeze(0), c_0.unsqueeze(0)))  # (D*num_layers=1, b, hidden_size)

看一下返回的结果的形状:

print(output.shape)  # [2,3,5] [b, seq_len, hidden_size]
print(h_n.shape)  # [1,2,5] [num_layers, b, hidden_size]
print(c_n.shape)  # [1,2,5] [num_layers, b, hidden_size]

这里输出一下lstm_layer中的参数名称及其形状:

for name, para in lstm_layer.named_parameters():
    print(name, para.shape)

输出结果如下:

weight_ih_l0 torch.Size([20, 4])  # [4*hidden_size, input_size]
weight_hh_l0 torch.Size([20, 5])  # [4*hidden_size, hidden_size]
bias_ih_l0 torch.Size([20])  # [4*hidden_size]
bias_hh_l0 torch.Size([20])  # [4*hidden_size]

手写 lstm_forward 函数

手写一个lstm_forward函数,实现LSTM的计算原理。官网上的计算公式,如下:
i t = σ ( W i i x t + b i i + W h i h t − 1 + b h i ) f t = σ ( W i f x t + b i f + W h f h t − 1 + b h f ) g t = tanh ( W i g x t + b i g + W h g h t − 1 + b h g ) o t = σ ( W i o x t + b i o + W h o h t − 1 + b h o ) c t = f t ⊙ c t + i t ⊙ g t h t = o t ⊙ tanh ( c t ) \begin{align} &i_t = \sigma(W_{ii}x_t + b_{ii} + W_{hi}h_{t-1} + b_{hi}) \\ &f_t = \sigma(W_{if}x_t + b_{if} + W_{hf}h_{t-1} + b_{hf}) \\ &g_t = \textup{tanh}(W_{ig}x_t + b_{ig} + W_{hg}h_{t-1} + b_{hg}) \\ &o_t = \sigma(W_{io}x_t + b_{io} + W_{ho}h_{t-1} + b_{ho}) \\ &c_t = f_t \odot c_t + i_t \odot g_t \\ &h_t = o_t \odot \textup{tanh}(c_t) \end{align} it=σ(Wiixt+bii+Whiht1+bhi)ft=σ(Wifxt+bif+Whfht1+bhf)gt=tanh(Wigxt+big+Whght1+bhg)ot=σ(Wioxt+bio+Whoht1+bho)ct=ftct+itgtht=ottanh(ct)
这里先将lstm_forward函数中的每个参数的维度写出来:

def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):
    h_0, c_0 = initial_states  # 初始状态  [b_size, hidden_size]
    b_size, seq_len, input_size = input.shape
    h_size = h_0.shape[-1]

    h_prev, c_prev = h_0, c_0
    # 需要将权重w在batch_size维进行扩维并复制,才能和x与h进行相乘
    w_ih_batch = w_ih.unsqueeze(0).tile(b_size, 1, 1)  # [4*hidden_size, in_size]->[b_size, ,]
    w_hh_batch = w_hh.unsqueeze(0).tile(b_size, 1, 1)  # [4*hidden_size, hidden_size]->[b_size, ,]

    output_size = h_size
    output = torch.zeros(b_size, seq_len, output_size)  # 初始化一个输出序列
    for t in range(seq_len):
        x = input[:, t, :]  # 当前时刻的输入向量 [b,in_size]->[b,in_size,1]
        w_times_x = torch.bmm(w_ih_batch, x.unsqueeze(-1)).squeeze(-1)   # bmm:含有批量大小的矩阵相乘
        # [b, 4*hidden_size, 1]->[b, 4*hidden_size]
        # 这一步就是计算了 Wii*xt|Wif*xt|Wig*xt|Wio*xt
        w_times_h_prev = torch.bmm(w_hh_batch, h_prev.unsqueeze(-1)).squeeze(-1)
        # [b, 4*hidden_size, hidden_size]*[b, hidden_size, 1]->[b,4*hidden_size, 1]->[b, 4*hidden_size]
        # 这一步就是计算了 Whi*ht-1|Whf*ht-1|Whg*ht-1|Who*ht-1

        # 分别计算输入门(i)、遗忘门(f)、cell门(g)、输出门(o)  维度均为 [b, h_size]
        i_t = torch.sigmoid(w_times_x[:, :h_size] + w_times_h_prev[:, :h_size] + b_ih[:h_size] + b_hh[:h_size])  # 取前四分之一
        f_t = torch.sigmoid(w_times_x[:, h_size:2*h_size] + w_times_h_prev[:, h_size:2*h_size]
                            + b_ih[h_size:2*h_size] + b_hh[h_size:2*h_size])
        g_t = torch.tanh(w_times_x[:, 2*h_size:3*h_size] + w_times_h_prev[:, 2*h_size:3*h_size]
                         + b_ih[2*h_size:3*h_size] + b_hh[2*h_size:3*h_size])
        o_t = torch.sigmoid(w_times_x[:, 3*h_size:] + w_times_h_prev[:, 3*h_size:]
                            + b_ih[3*h_size:] + b_hh[3*h_size:])
        c_prev = f_t * c_prev + i_t * g_t
        h_prev = o_t * torch.tanh(c_prev)

        output[:, t, :] = h_prev

    return output, (h_prev.unsqueeze(0), c_prev.unsqueeze(0))  # 官方是三维,在第0维扩一维

验证一下 lstm_forward 的准确性:

# 这里使用 lstm_layer 中的参数
# 加了me表示自己手写的
output_me, (h_n_me, c_n_me) = lstm_forward(input, (h_0, c_0), lstm_layer.weight_ih_l0,
                                           lstm_layer.weight_hh_l0, lstm_layer.bias_ih_l0, lstm_layer.bias_hh_l0)

打印一下,看两个的计算结果是否相同:

print("PyTorch API output:")
print(output)  # [2,3,5] [b, seq_len, hidden_size]
print(h_n)  # [1,2,5] [num_layers, b, hidden_size]
print(c_n)  # [1,2,5] [num_layers, b, hidden_size]
print("\nlstm_forward function output:")
print(output_me)  # [2,3,5] [b, seq_len, hidden_size]
print(h_n_me)  # [1,2,5] [num_layers, b, hidden_size]
print(c_n_me)

结果如下,完全一致,说明手写的是对的:

PyTorch API output:
tensor([[[ 0.1671,  0.2493,  0.2603, -0.1448, -0.1951],
         [-0.0680,  0.0478,  0.0218,  0.0735, -0.0604],
         [ 0.0144,  0.0507, -0.0556, -0.2600,  0.1234]],

        [[ 0.4561, -0.0015, -0.0776, -0.0644, -0.5319],
         [ 0.1667,  0.0111,  0.0114, -0.1227, -0.2369],
         [-0.0220,  0.0637, -0.2353,  0.0404, -0.1309]]],
       grad_fn=<TransposeBackward0>)
tensor([[[ 0.0144,  0.0507, -0.0556, -0.2600,  0.1234],
         [-0.0220,  0.0637, -0.2353,  0.0404, -0.1309]]],
       grad_fn=<StackBackward0>)
tensor([[[ 0.0223,  0.1574, -0.1572, -0.4663,  0.2110],
         [-0.0382,  0.6440, -0.4334,  0.0779, -0.3198]]],
       grad_fn=<StackBackward0>)

lstm_forward function output:
tensor([[[ 0.1671,  0.2493,  0.2603, -0.1448, -0.1951],
         [-0.0680,  0.0478,  0.0218,  0.0735, -0.0604],
         [ 0.0144,  0.0507, -0.0556, -0.2600,  0.1234]],

        [[ 0.4561, -0.0015, -0.0776, -0.0644, -0.5319],
         [ 0.1667,  0.0111,  0.0114, -0.1227, -0.2369],
         [-0.0220,  0.0637, -0.2353,  0.0404, -0.1309]]], grad_fn=<CopySlices>)
tensor([[[ 0.0144,  0.0507, -0.0556, -0.2600,  0.1234],
         [-0.0220,  0.0637, -0.2353,  0.0404, -0.1309]]],
       grad_fn=<UnsqueezeBackward0>)
tensor([[[ 0.0223,  0.1574, -0.1572, -0.4663,  0.2110],
         [-0.0382,  0.6440, -0.4334,  0.0779, -0.3198]]],
       grad_fn=<UnsqueezeBackward0>)

LSTMP

# 定义一些常量
batch_size, seq_len, input_size, h_size = 2, 3, 4, 5
proj_size = 3  # 要比hidden_size小

input = torch.randn(batch_size, seq_len, input_size)
c_0 = torch.randn(batch_size, h_size)
h_0 = torch.randn(batch_size, proj_size)  # 注意这里从原来的 h_size 换成了 proj_size

# 调用官方 LSTM API
lstm_layer = nn.LSTM(input_size, h_size, batch_first=True, proj_size=proj_size)  
output, (h_n, c_n) = lstm_layer(input, (h_0.unsqueeze(0), c_0.unsqueeze(0)))

打印一下返回的结果的形状:

print(output.shape)  # [2,3,3] [b, seq_len, proj_size]
print(h_n.shape)  # [1,2,3] [num_layers, b, proj_size]
print(c_n.shape)  # [1,2,5] [num_layers, b, hidden_size]

这里输出一下lstm_layer中的参数名称及其形状:

for name, para in lstm_layer.named_parameters():
    print(name, para.shape)

输出结果如下输出结果如下:

weight_ih_l0 torch.Size([20, 4])  # [4*hidden_size, input_size]
weight_hh_l0 torch.Size([20, 3])  # [4*hidden_size, proj_size]
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])
weight_hr_l0 torch.Size([3, 5])  # 这个参数就是对 hidden_state 进行压缩的 [hidden_size, proj_size]

修改 lstm_forward 函数

修改lstm_forward函数,从而能够实现LSTMP:

def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh, w_hr=None):
    h_0, c_0 = initial_states  # 初始状态  [b, proj_size][b, hidden_size]
    b_size, seq_len, input_size = input.shape
    h_size = c_0.shape[-1]

    h_prev, c_prev = h_0, c_0
    # 需要将权重w在batch_size维进行扩维并复制,才能和x与h进行相乘
    w_ih_batch = w_ih.unsqueeze(0).tile(b_size, 1, 1)  # [4*hidden_size, in_size]->[b_size, ,]
    w_hh_batch = w_hh.unsqueeze(0).tile(b_size, 1, 1)  # [4*hidden_size, hidden_size]->[b_size, ,]


    if w_hr is not None:
        proj_size = w_hr.shape[0]
        output_size = proj_size
        w_hr_batch = w_hr.unsqueeze(0).tile(b_size, 1, 1)  # [proj_size, hidden_size]->[b_size, ,]
    else:
        output_size = h_size

    output = torch.zeros(b_size, seq_len, output_size)  # 初始化一个输出序列
    for t in range(seq_len):
        x = input[:, t, :]  # 当前时刻的输入向量 [b,in_size]->[b,in_size,1]
        w_times_x = torch.bmm(w_ih_batch, x.unsqueeze(-1)).squeeze(-1)   # bmm:含有批量大小的矩阵相乘
        # [b, 4*hidden_size, 1]->[b, 4*hidden_size]
        # 这一步就是计算了 Wii*xt|Wif*xt|Wig*xt|Wio*xt
        w_times_h_prev = torch.bmm(w_hh_batch, h_prev.unsqueeze(-1)).squeeze(-1)
        # [b, 4*hidden_size, hidden_size]*[b, hidden_size, 1]->[b,4*hidden_size, 1]->[b, 4*hidden_size]
        # 这一步就是计算了 Whi*ht-1|Whf*ht-1|Whg*ht-1|Who*ht-1

        # 分别计算输入门(i)、遗忘门(f)、cell门(g)、输出门(o)  维度均为 [b, h_size]
        i_t = torch.sigmoid(w_times_x[:, :h_size] + w_times_h_prev[:, :h_size] + b_ih[:h_size] + b_hh[:h_size])  # 取前四分之一
        f_t = torch.sigmoid(w_times_x[:, h_size:2*h_size] + w_times_h_prev[:, h_size:2*h_size]
                            + b_ih[h_size:2*h_size] + b_hh[h_size:2*h_size])
        g_t = torch.tanh(w_times_x[:, 2*h_size:3*h_size] + w_times_h_prev[:, 2*h_size:3*h_size]
                         + b_ih[2*h_size:3*h_size] + b_hh[2*h_size:3*h_size])
        o_t = torch.sigmoid(w_times_x[:, 3*h_size:] + w_times_h_prev[:, 3*h_size:]
                            + b_ih[3*h_size:] + b_hh[3*h_size:])
        c_prev = f_t * c_prev + i_t * g_t
        h_prev = o_t * torch.tanh(c_prev)  # [b_size, h_size]

        if w_hr is not None:  # 对 h_prev 进行压缩,做projection
            h_prev = torch.bmm(w_hr_batch, h_prev.unsqueeze(-1))  # [b,proj_size,hidden_size]*[b,h_size,1]=[b,proj_size,1]
            h_prev = h_prev.squeeze(-1)  # [b, proj_size]

        output[:, t, :] = h_prev

    return output, (h_prev.unsqueeze(0), c_prev.unsqueeze(0))  # 官方是三维,在第0维扩一维

验证一下 lstm_forward 的准确性:

output_me, (h_n_me, c_n_me) = lstm_forward(input, (h_0, c_0), lstm_layer.weight_ih_l0, lstm_layer.weight_hh_l0,
                                           lstm_layer.bias_ih_l0, lstm_layer.bias_hh_l0, lstm_layer.weight_hr_l0)

print("PyTorch API output:")
print(output)  # [2,3,3] [b, seq_len, proj_size]
print(h_n)  # [1,2,3] [num_layers, b, proj_size]
print(c_n)  # [1,2,5] [num_layers, b, hidden_size]
print("\nlstm_forward function output:")
print(output_me)  # [2,3,3] [b, seq_len, proj_size]
print(h_n_me)  # [1,2,3] [num_layers, b, proj_size]
print(c_n_me)  # [1,2,5] [num_layers, b, hidden_size]

输出的结果如下,完全一致,说明手写的是对的:

PyTorch API output:
tensor([[[ 0.0392, -0.3149, -0.1264],
         [ 0.0141, -0.2619, -0.0760],
         [ 0.0306, -0.2166,  0.0915]],

        [[-0.0777, -0.1205, -0.0555],
         [-0.0646, -0.0926,  0.0391],
         [-0.0456, -0.0576,  0.1849]]], grad_fn=<TransposeBackward0>)
tensor([[[ 0.0306, -0.2166,  0.0915],
         [-0.0456, -0.0576,  0.1849]]], grad_fn=<StackBackward0>)
tensor([[[ 1.9913, -0.2683, -0.1221,  0.1751, -0.6072],
         [-0.2383, -0.2253, -0.0385, -0.8820, -0.1794]]],
       grad_fn=<StackBackward0>)

lstm_forward function output:
tensor([[[ 0.0392, -0.3149, -0.1264],
         [ 0.0141, -0.2619, -0.0760],
         [ 0.0306, -0.2166,  0.0915]],

        [[-0.0777, -0.1205, -0.0555],
         [-0.0646, -0.0926,  0.0391],
         [-0.0456, -0.0576,  0.1849]]], grad_fn=<CopySlices>)
tensor([[[ 0.0306, -0.2166,  0.0915],
         [-0.0456, -0.0576,  0.1849]]], grad_fn=<UnsqueezeBackward0>)
tensor([[[ 1.9913, -0.2683, -0.1221,  0.1751, -0.6072],
         [-0.2383, -0.2253, -0.0385, -0.8820, -0.1794]]],
       grad_fn=<UnsqueezeBackward0>)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【deep_thoughts】30_PyTorch LSTM和LSTMP的原理及其手写复现 的相关文章

随机推荐

  • 自动化运维---ansible常用模块之文件操作(file&blockinfile&lineinfile模块)

    自动化运维 ansible常用模块之文件操作 file blockinfile lineinfile模块 文章目录 自动化运维 ansible常用模块之文件操作 file blockinfile lineinfile模块 1 file模块
  • 7. QML类中对象树的创建和销毁顺序是这样的

    简述 有下面一段代码 通常会有需求在Component onCompleted信号之后做一些初始化操作 那这些组件初始化完成的顺序是怎样的 同时有创建完成的信号 也有对应销毁完成的信号 类似C 中的构造和析构函数 但我们这里叫信号处理程序
  • java三种分页查询的方式

    第一种 分页 需要查询出总数 第二种分页如果是以id为主键并且是递增的情况 第三种直接用do while进行分页查询 不需要查询总个数和最大最小值 mybatis plus分页 第四种分页 for循环分页
  • Vue指令学习

    目录 v text 设置标签的内容 v html 设置元素的innerHTML v on 为元素绑定事件 v show 根据布尔值控制元素的样式为显示或隐藏 v if 根据布尔值控制dom为显示或隐藏 v bind 在vue中为元素绑定属性
  • SQLite 如何在Windows下编译?

    SQLite 如何在Windows下编译 发表时间 2007 6 13 12 44 00 评论 打印 字体 大 中 小 本文链接 http blog pfan cn lounger 26745 html 复制链接 分享到 0 标签 C C
  • 计算机中¥符号按哪个键,电脑键盘符号快捷键大全 电脑键盘上每个键的作用?...

    电脑键盘符号快捷键大全 电脑键盘符号怎么打 很多朋友还不太清楚电脑的各个符号要怎么打 快捷键是什么呢 那么下面就一起来看看电脑键盘符号大全吧 电脑键盘符号怎么打 电脑键盘符号大全 常见的标点符号 分号 书名号 双引号 单引号 破折号 竖线
  • sublime简用

    1 使用goto anything 快速查询各种文件 可以快速定位CSS中选择器 或JavaScript中的function 2 其中的输入时选取简化的输入则可 bgc就代表background color 3 多行游标 光标放在单词中 然
  • hashmap为什么8转成红黑树_深入分析HashMap的红黑树实现方式

    在分析jdk1 8的HashMap实现原理之前 咱们先可以了解一下红黑树的设计 相比jdk1 7的HashMap而言 jdk1 8最重要的就是引入了红黑树的设计 当冲突的链表长度超过8个的时候 链表结构就会转为红黑树结构 01 故事的起因
  • Mysql——压缩包方式安装教程

    一 Mysql压缩包下载方式 zip版 5 7及8 0 的下载需到官方网站下载 不同版本对应能安装在不同的操作系统下 本次介绍的是mysql 8 0 30 winx64在win10下的安装方式 下载网址 MySQL Download MyS
  • android模拟器与宿主机通讯

    android模拟器与PC的端口映射 一 概述 Android系统为实现通信将PC电脑IP设置为10 0 2 2 自身设置为127 0 0 1 而PC并没有为Android模拟器系统指定IP 所以PC电脑不能通过IP来直接访问Android
  • Mysql增强半同步模式_MySQL增强半同步的搭建实验,和一些参数的个人理解

    环境信息 role ip port hostname master 192 168 188 101 4306 mysqlvm1 slave 192 168 188 201 4306 mysqlvm1 1 5306 6306 7306 MyS
  • eclipse搜索类快捷键

    习惯的编辑器可以提高编程效率 熟悉的快捷键可以提高工作效率 本文更新eclipse中常用的搜索快捷键 打开资源快捷键 Ctrl Shift R 通过在搜索框中输入名字可以很方便的在项目或工作空间中找某个文件 支持模糊查询功能 例如输入文件的
  • Linux防火墙

    关于linux系统防火墙 centos5 centos6 redhat6系统自带的是iptables防火墙 centos7 redhat7自带firewall防火墙 ubuntu系统使用的是ufw防火墙 必要操作 linux系统防火墙开放相
  • AOP之5种增强方法应用范例

    林炳文Evankaka原创作品 转载请注明出处http blog csdn net evankaka Spring AOP 提供了 5 种类型的通知 它们分别是 Before Advice 前置通知 After Returning Advi
  • PyTorch 手把手搭建神经网络 (MNIST)

    推荐下我自己建的Python学习群 856833272 群里都是学Python的 如果你想学或者正在学习Python 欢迎你加入 大家都是软件开发党 不定期分享干货 还有免费直播课程领取 包括我自己整理的一份2021最新的Python进阶资
  • python写入文件后换行_python写入文件自动换行问题的方法

    现在需要一个写文件方法 将selenium的脚本运行结果写入test result log文件中 首先创建写入方法 def write result str writeresult file r D eclipse4 4 1 script
  • 一些文件头

    由这些文件头即使文件后缀被乱改也可以通过查看二进制文件查出文件的匹配格式 当然这就是一些播放器识别文件的方法 1 从Ultra edit 32中提取出来的 JPEG jpg 文件头 FFD8FF PNG png 文件头 89504E47 G
  • 浅析进程与线程之间的区别

    文章目录 浅析进程与线程之间的区别 从最普遍的答案出发 什么是计算机资源 计算资源 存储资源 I O设备资源 什么是进程 线程 操作系统怎样给进程分配资源的 操作系统怎样调度进 线程的 进程的上下文切换 为什么需要线程 参考链接 浅析进程与
  • mybatis generator

    文章目录 generatorConfig xml GeneratorSqlmap java log4j properties lib maven pom generatorConfig xml
  • 【deep_thoughts】30_PyTorch LSTM和LSTMP的原理及其手写复现

    文章目录 LSTM API 手写 lstm forward 函数 LSTMP 修改 lstm forward 函数 视频链接 30 PyTorch LSTM和LSTMP的原理及其手写复现 哔哩哔哩 bilibili PyTorch LSTM