一文读懂LSTM及手写LSTM结构

2023-11-06

        `torch.nn.LSTM`是PyTorch中用于创建长短时记忆网络(Long Short-Term Memory)的类。LSTM是一种用于处理序列数据的循环神经网络(Recurrent Neural Network,RNN)变体。

官方给出的LSTM API 文档 

 以下是 `torch.nn.LSTM` 的主要参数(用于配置和定制 LSTM 层的行为):

1. `input_size`(必需参数):输入数据的特征维度大小。这是输入序列的特征向量的维度。

2. `hidden_size`(必需参数):LSTM 单元的隐藏状态的维度大小。这决定了 LSTM 层的输出和内部隐藏状态的维度。

3. `num_layers`(可选参数,默认为 1):LSTM 层的堆叠层数。你可以将多个 LSTM 层叠加在一起,以增加模型的容量和表示能力。

4. `bias`(可选参数,默认为 True):一个布尔值,确定是否在 LSTM 单元中包含偏置项。

5. `batch_first`(可选参数,默认为 False):一个布尔值,指定输入数据的形状。如果设置为 True,输入数据的形状应为 `(batch_size, sequence_length, input_size)`,否则为 `(sequence_length, batch_size, input_size)`。

6. `dropout`(可选参数,默认为 0.0):应用于除最后一层外的每个 LSTM 层的丢弃率。这有助于防止过拟合。

7. `bidirectional`(可选参数,默认为 False):一个布尔值,指定是否使用双向 LSTM。如果设置为 True,LSTM 将具有前向和后向的隐藏状态,以更好地捕捉序列的上下文信息。

8. `batch_first`(可选参数,默认为 False):一个布尔值,用于指定输入数据的形状。如果设置为 True,则输入数据应为 `(batch_size, sequence_length, input_size)`,否则为 `(sequence_length, batch_size, input_size)`。

9. `device`(可选参数):指定要在哪个设备上创建 LSTM 层,例如 CPU 或 GPU。

10. `dtype`(可选参数):指定数据类型,例如 `torch.float32` 或 `torch.float64`。

11. `return_sequences`(可选参数,默认为 False):一个布尔值,指定是否返回每个时间步的输出序列。如果设置为 True,则返回完整的输出序列;否则,只返回最后一个时间步的输出。

        这些参数允许你根据具体的任务和模型架构来配置 LSTM 层。根据你的需求,你可以灵活地选择不同的参数值来构建不同的 LSTM 模型。

LSTM的输入

 

`torch.nn.LSTM` 层的输入通常是一个包含两个元素的元组 `(input, (h_0, c_0))`,调用方法为:

output, (h_n, c_n) = torch.nn.LSTM(input, (h_0,c_0))

其中:

(1)        

        input 通常是一个三维张量,具体形状取决于是否设置了 `batch_first` 参数。输入张量包括以下维度:

1. 批量维度(Batch Dimension):这是数据中的样本数量。如果 `batch_first` 设置为 True,那么批量维度将是第一个维度;否则,批量维度将是第二个维度。

2. 序列长度维度(Sequence Length Dimension):这是时间步的数量,也是序列的长度。它是输入序列中数据点的数量。

3. 特征维度(Feature Dimension):这是输入数据点的特征数量。它表示每个时间步的输入特征向量 xt 的维度。

根据上述描述,以下是两种常见的输入形状:

- 如果 `batch_first` 为 True:
    - 输入张量的形状为 `(batch_size, sequence_length, input_size)`。
    - `batch_size` 是批量大小,表示同时处理的样本数量。
    - `sequence_length` 是序列的长度,即时间步的数量。
    - `input_size` 是输入特征向量的维度。

- 如果 `batch_first` 为 False:
    - 输入张量的形状为 `(sequence_length, batch_size, input_size)`。
    - `sequence_length` 是序列的长度,即时间步的数量。
    - `batch_size` 是批量大小,表示同时处理的样本数量。
    - `input_size` 是输入特征向量的维度。

        要注意的是,这只是输入的形状,LSTM 层的参数(例如 `input_size` 和 `hidden_size`)必须与输入形状相匹配。根据你的具体任务和数据,你需要将输入数据整理成适当形状的张量,然后将其传递给 `torch.nn.LSTM` 层以进行前向传播。

(2)

  `(h_0, c_0)`:是包含初始隐藏状态和初始细胞状态的元组。
   - `h_0`:是初始隐藏状态,其形状为 `(num_layers * num_directions, batch_size, hidden_size)`。`num_layers` 是 LSTM 层的堆叠层数,`num_directions` 是 1 或 2,取决于是否使用双向 LSTM。
   - `c_0`:是初始细胞状态,其形状也为 `(num_layers * num_directions, batch_size, hidden_size)`。

LSTM的输出

`torch.nn.LSTM` 层的输出通常是一个包含两个元素的元组 `(output, (h_n, c_n))`,其中:

1. `output`:是一个包含每个时间步的 LSTM 输出的张量。其形状为 `(batch_size, sequence_length, num_directions * hidden_size)`【batch_first = True的情况下】,其中:
   - `sequence_length` 是序列的长度,即时间步的数量。
   - `batch_size` 是批量大小,表示同时处理的样本数量。
   - `num_directions` 是 1 或 2,取决于是否使用双向(bidirectional)LSTM。
   - `hidden_size` 是 LSTM 单元的隐藏状态的维度大小。

2. `(h_n, c_n)`:是包含最后一个时间步的隐藏状态和细胞状态的元组。
   - `h_n`:是最后一个时间步的隐藏状态,其形状为 `(num_layers * num_directions, batch_size, hidden_size)`。`num_layers` 是 LSTM 层的堆叠层数,`num_directions` 是 1 或 2,取决于是否使用双向 LSTM。
   - `c_n`:是最后一个时间步的细胞状态,其形状也为 `(num_layers * num_directions, batch_size, hidden_size)`。

        你可以选择是否要使用输出中的全部时间步的输出,或者只使用最后一个时间步的输出,具体取决于你的任务需求。

        通常,如果你只关心最终的输出,你可以使用 `output[-1]` 或 `h_n`。如果你需要完整的时间步输出序列,可以使用 `output`。这些输出可以传递到其他层或用于任务的后续处理。

 LSTM的权重参数

`torch.nn.LSTM`具有以下主要的权重参数(用于捕捉序列中的长期依赖关系):

1. `weight_ih_l[k]`:这是输入到LSTM单元的权重参数,其中k表示LSTM层的索引。`weight_ih_l[k]`的维度是(4 * hidden_size,input_size),其中hidden_size是LSTM隐藏状态的大小,input_size是输入数据的特征维度。这个权重参数控制着输入数据如何影响LSTM单元的状态。

2. `weight_hh_l[k]`:这是隐藏状态到LSTM单元的权重参数,其中k表示LSTM层的索引。`weight_hh_l[k]`的维度是(4 * hidden_size,hidden_size)。这个权重参数控制着前一个时间步的隐藏状态如何影响当前时间步的隐藏状态。

3. `bias_ih_l[k]`和`bias_hh_l[k]`:这是输入到LSTM单元和隐藏状态到LSTM单元的偏置参数,其中k表示LSTM层的索引。`bias_ih_l[k]`的维度是(4 * hidden_size),`bias_hh_l[k]`的维度也是(4 * hidden_size)。这些偏置参数用于调整输入和隐藏状态的影响。

以上权重参数中的4表示LSTM单元的门控机制,通常被称为输入门(input gate)、遗忘门(forget gate)、输出门(output gate)和细胞状态(cell state)。LSTM使用这些门来控制信息的流动,以捕捉长期依赖关系。

        要访问和修改这些权重参数,您可以使用`state_dict`属性来获取或设置模型的权重。例如,如果您有一个名为`lstm_model`的`torch.nn.LSTM`模型,您可以使用以下代码来获取权重参数的字典:lstm_weights = lstm_model.state_dict()。然后,您可以从`lstm_weights`字典中提取和修改特定的权重参数。请注意,修改权重参数可能会影响模型的性能,因此需要谨慎操作。

你还可以使用:

for k, v in lstm_model.named_parameters():
    print(k, v) # 打印权重参数名称及数值

方法得到模型的权重参数。

代码部分

        下述代码包括了官方API以及手写的LSTM源码。 


# 视频链接:
# https://www.bilibili.com/video/BV1zq4y1m7aH/?spm_id_from=333.788&vd_source=fb7bfda367c76676e2483b9b60485e57

# 实现LSTM 源码
# 定义常量
import torch
import torch.nn as nn
batch_size, T, input_size, hidden_size = 2, 3, 4, 5


input = torch.randn(batch_size, T, input_size)
c_0 = torch.randn(batch_size, hidden_size) # 初始细胞单元,不参与网络训练
h_0 = torch.randn(batch_size, hidden_size) # 初始隐藏状态

# 调用官方API
lstm_layer = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
output, (h_n, c_n) = lstm_layer(input, (h_0.unsqueeze(0), c_0.unsqueeze(0)))
print("LSTM API")
print("output:\n", output)
print("h_n:\n", h_n)
print("c_n:\n", c_n)

# for k, v in lstm_layer.named_parameters():
#     print(k, v)
lstm_weight = lstm_layer.state_dict() # 使用`state_dict`属性来获取或设置模型的权重
print("lstm_weight:\n", lstm_weight)

# 自己写一个LSTM模型
def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):
    """

    :param input:
    :param initial_states:
    :param w_ih:
    :param w_hh:
    :param b_ih:
    :param b_hh:
    :return:
    """
    h_0, c_0 = initial_states # 初始状态
    batch_size, T, input_size = input.shape
    hidden_size = w_ih.shape[0] // 4
    prev_h = h_0
    prev_c = c_0

    batch_w_ih = w_ih.unsqueeze(0).tile(batch_size, 1, 1) # [batch_size, 4*hidden_size, input_size]
    batch_w_hh = w_hh.unsqueeze(0).tile(batch_size, 1, 1) # [batch_size, 4*hidden_size, hidden_size]
    output_size = hidden_size
    output = torch.zeros(batch_size, T, output_size) # 输出序列

    for t in range(T):
        x = input[:, t, :] # 当前时刻的输入向量,[batch_size*input_size]
        w_times_x = torch.bmm(batch_w_ih, x.unsqueeze(-1)) # [batch_size, 4*hidden_size, 1]
        w_times_x = w_times_x.squeeze(-1) # [batch_size, 4*hidden_size]

        w_times_h_prev = torch.bmm(batch_w_hh, prev_h.unsqueeze(-1)) # [batch_size, 4*hidden_size, 1]
        w_times_h_prev = w_times_h_prev.squeeze(-1)  # [batch_size, 4*hidden_size]

        # 分别计算输入门(i)、遗忘门(f)、cell(g)、输出门(o)
        i_t = torch.sigmoid(w_times_x[:, :hidden_size] + w_times_h_prev[:, :hidden_size]
                            +b_ih[ :hidden_size] + b_hh[ :hidden_size])
        f_t = torch.sigmoid(w_times_x[:, hidden_size:2*hidden_size] + w_times_h_prev[:, hidden_size:2*hidden_size]
                            + b_ih[hidden_size:2*hidden_size] + b_hh[hidden_size:2*hidden_size])
        g_t = torch.tanh(w_times_x[:, 2*hidden_size:3*hidden_size] + w_times_h_prev[:, 2*hidden_size:3*hidden_size]
                            + b_ih[2*hidden_size:3*hidden_size] + b_hh[2*hidden_size:3*hidden_size])
        o_t = torch.sigmoid(w_times_x[:, 3*hidden_size:4*hidden_size] + w_times_h_prev[:, 3*hidden_size:4*hidden_size]
                            + b_ih[3*hidden_size:4*hidden_size] + b_hh[3*hidden_size:4*hidden_size])
        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t * torch.tanh(prev_c)

        output[:, t, :] = prev_h

    return output, (prev_h, prev_c)


output_custom, (h_final_custom, c_final_custom) = lstm_forward(input=input, initial_states = (h_0, c_0), w_ih=lstm_layer.weight_ih_l0,
             w_hh=lstm_layer.weight_hh_l0, b_ih=lstm_layer.bias_ih_l0, b_hh=lstm_layer.bias_hh_l0)

print("LSTM custom")
print("output_custom:\n", output_custom)
print("h_final_custom:\n", h_final_custom)
print("c_final_custom:\n", c_final_custom)

LSTM模型输入输出可视化理解

 

 

图文来自:pytorch中LSTM参数详解(一张图帮你更好的理解每一个参数)_pytorch lstm 参数一图_xjtuwfj的博客-CSDN博客

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

一文读懂LSTM及手写LSTM结构 的相关文章

随机推荐

  • 每日一题:选数

    选数 题目 Daimayuan Online Judge 原本我的思路是 大致题意就是从n个数中选取若干数 使得它们的和mod n等于0 任意选取 无关顺序 是可以跳着选的 也就是对于每一个数 有两种选择 选与不选 于是我想用01背包 但是
  • linux屏保默认图片,分享

    Ubuntu 4 10 Warty Warthog Ubuntu 5 04 Hoary Hedgehog Ubuntu 5 10 Breezy Badger Ubuntu 6 06 Dapper Drake Ubuntu 6 10 Edgy
  • 多个chatgpt模型

    GPT4All 简介 GPT4AllNomic AI Team 从 Alpaca 获得灵感 使用 GPT 3 5 Turbo OpenAI API 收集了大约 800 000 个提示 响应对 创建了 430 000 个助手式提示和生成训练对
  • SQL中in和not in遇到NULL值的查询情况

    首先 大家可以先试着做这道练习题 题目 给定一个表 T id 是树节点的编号 pid 是它父节点的 id 树中每个节点属于以下三种类型之一 叶子 如果这个节点没有任何孩子节点 根 如果这个节点是整棵树的根 即没有父节点 内部节点 如果这个节
  • 最简单的引入Vue看板娘教程

    最简单的引入Vue看板娘教程 一 项目引入 这里使用的是来自Evgo老哥的 vue live2d 使用简单 直接引入就好 具体的可以看老哥的文档 二 简单实现 在你的项目引入 npm install vue live2d 接着在想要引入看板
  • Unity3D关于iTween知识详解和接口总结

    目录 1 简介 2 物体移动方法 3 物体的颜色变化 4 摄像机淡入淡出 5 音频方法 6 Look类方法 7 旋转方法 8 物体大小缩放 9 晃动效果方法 10 值方法 11 外部工具方法 12 iTweenPath 1 简介 iTwee
  • Windows NT Session 概念的原理与应用浅析 [1] 遍历并获取信息

    转载自 http www cnblogs com flier archive 2004 07 19 25709 html 我在上一篇文章 DACL NULL or not NULL 中曾简要地介绍了 Windows 系统中 Session
  • SonarLint 默认扫描规则

    请看原文 原文链接 https blog csdn net jiaomubai article details 116596868 在平时写代码的时候 为了代码规范和减少 bug 的数量 使用 SonarLint 插件进行代码检查无疑是一个
  • tesseract-ocr

    转自 http www cnblogs com zcsor archive 2011 02 21 1959555 html 关于tesseract ocr3的训练和使用 众所周知 这是一个出色的字符识别软件 这个开源项目可以在http co
  • SpringBoot企业微信公众号开发案例一

    Java Spring Boot企业微信点餐系统 2 3 数据库设计 2 3 数据库设计 数据库表设计 3 1 开发环境搭建 3 1 开发环境搭建 virtualBox安装 ifconfig sequel pro连接数据库 虚拟机安装Mav
  • 【Linux系统安装RocketMQ并整合到SpringBoot项目】

    Linux系统安装RocketMQ并整合到SpringBoot项目 一 基本概念 1 1 NameServer 1 2 Broker 1 3 Message 1 3 Topic 1 4 Tag 1 5 Queue 1 6 MessageId
  • SpringBoot-AOP记录日志+脱敏

    目录 1 引入依赖 2 配置文件 3 注解类 4 切面类 5 需要脱敏的实体类 6 查看日志打印 1 引入依赖 脱敏工具包 implementation com github houbb sensitive core 0 0 9 imple
  • 虚拟机实现拖拽、复制、粘贴功能

    一 设置CD DVD驱动为自动检测 二 重新安装VMware Tools 在重启虚拟机的过程中 重新安装VMware Tools 会变亮 切记是重启过程中会变亮 虚拟机开启后又会变暗 下载完成后 双击进入CD虚拟驱动 将拷贝至桌面 依次执行
  • 青岛科技大学计算机专硕拟录取分数,又一批拟录取名单来了!最低分数...

    原标题 又一批拟录取名单来了 最低分数 最近 考研复试的动静闹得有点大 但是 这也没办法啊 按照往年的3月份就可以结束的事情 今年愣是拖到了五月底 各位20考研学生的心里苦啊 虽然所考研复试和咱们21考研的小伙伴 目前来看关系不大 但是 大
  • 关于el-table鼠标悬浮变色

    deep el table body el table row hover row td background color 颜色
  • HwServiceManager篇-Android10.0 HwBinder通信原理(五)

    Android取经之路 的源码都基于Android Q 10 0 进行分析 Android取经之路 系列文章 系统启动篇 Android系统架构Android是怎么启动的Android 10 0系统启动之init进程Android10 0系
  • 关于在使用tensorflow2.0版本时,出现RuntimeError:tf.placeholder() is not compatible with eager execution.的问题

    关于在使用tensorflow2 0版本时 出现RuntimeError tf placeholder is not compatible with eager execution 的问题 今天在运行程序 部分代码 import tenso
  • 可调时钟频率——IP核PLL/MMCM仿真及上板实验(Vivado)

    一 说明 实现平台 vivado2018 3 实验设备 领航者V2 ZYNQ7020 示波器一台 二 原理 在大多数实验项目中 所需要的时钟频率不尽相同 因此需要根据需求产生特定的时钟频率 PLL Phase Locked loop 锁相环
  • 关于css全局作用域(:global)和局部作用域(:local)

    local title color red 上面这种写法相当于 title color red global title color green
  • 一文读懂LSTM及手写LSTM结构

    torch nn LSTM 是PyTorch中用于创建长短时记忆网络 Long Short Term Memory 的类 LSTM是一种用于处理序列数据的循环神经网络 Recurrent Neural Network RNN 变体 官方给出