pytorch源码解读——RNN/LSTM篇

2023-11-16

文章的字母中:

b: batch_size
t: time_step
n: num_feature
h: hidden_size

假设输入数据维度input = (b, t, n)
所设计的LSTM模型如下:

class MYLSTM(nn.Module):

    def __init__(self, input_size, hidden_size, out_size):
        super(MYLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size
        
        self.lstm = nn.LSTM(
            input_size=self.input_size + self.hidden_size,
            hidden_size=self.hidden_size,
            num_layers=1,
            batch_first=True,
        )

        self.out = nn.Linear(self.hidden_size, out_size)

    def forward(self, x):
        hidden, cell = Variable(torch.zeros(1, x.size(0), self.hidden_size)),\
                       Variable(torch.zeros(1, x.size(0), self.hidden_size))
        for i in range(x.size(1)):
            curx = x[:, i, :].unsqueeze(1)
            curx = torch.cat((curx, hidden.permute(1, 0, 2)), dim=2)
            _, lstm_state = self.lstm(curx, (hidden, cell))
            hidden, cell = lstm_state[0], lstm_state[1]
            outs = self.out(hidden)
        return outs

由于num_layer=1,因此hidden,cell的维度均为(1, b, h)
对于每一个时间步,将其与hidden拼接,得到(b, 1, h + n)维度的curx,此对应下图中红框torch.cat
这个整体作为torch中LSTM单元的输入
在modules\rnn.py中,存在这样一段代码:

        if mode == 'LSTM':
            gate_size = 4 * hidden_size
        elif mode == 'GRU':
            gate_size = 3 * hidden_size
        else:
            gate_size = hidden_size
        
		self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                layer_input_size = input_size if layer == 0 else hidden_size * num_directions

                w_ih = Parameter(torch.Tensor(gate_size, layer_input_size))
                w_hh = Parameter(torch.Tensor(gate_size, hidden_size))
                b_ih = Parameter(torch.Tensor(gate_size))
                b_hh = Parameter(torch.Tensor(gate_size))
                layer_params = (w_ih, w_hh, b_ih, b_hh)

                suffix = '_reverse' if direction == 1 else ''
                param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
                if bias:
                    param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
                param_names = [x.format(layer, suffix) for x in param_names]

                for name, param in zip(param_names, layer_params):
                    setattr(self, name, param)
                self._all_weights.append(param_names)

这里的符号跟我上面的图略有不符,因为我习惯于纵向拼接,放下面这个原始的LSTM状态公式可能更好对应一些:
gate
首先根据LSTM网络的特点,或直接看状态计算公式,共有四个地方用到了拼接的输入即计算,因此gate_size = 4 * hidden_size,即相当于把上面的四个Wh和Wx各自合并在一起,各自偏置也合并,方便定义域运算,这个后面还会拆分,分别用于各部分的计算
而由于我们每次的输入均为(b, 1, n + h),因此layer_input_size = n + h
这样所有需要用到的权重和偏置均已求得,用_all_weights进行包装

此后,同样是在modules\rnn.py文件中

			func = self._backend.RNN(
            self.mode,
            self.input_size,
            self.hidden_size,
            num_layers=self.num_layers,
            batch_first=self.batch_first,
            dropout=self.dropout,
            train=self.training,
            bidirectional=self.bidirectional,
            dropout_state=self.dropout_state,
            variable_length=is_packed,
            flat_weight=flat_weight
        )
        output, hidden = func(input, self.all_weights, hx, batch_sizes)

func将所有参数重新包装并计算,计算过程在_functions\rnn.py中:

    def forward(input, weight, hidden, batch_sizes):
        if batch_first and not variable_length:
            input = input.transpose(0, 1)

        nexth, output = func(input, hidden, weight, batch_sizes)

        if batch_first and not variable_length:
            output = output.transpose(0, 1)

        return output, nexth

上面提到input = (b, 1, n + h),第一维为batch_size, 即batch_first = True, 于是先将其前两维转置,即此时input = (1, b, n + h)
第一维的1实际代表了LSTM的层数与是否双向,因此此后的运算仅针对单层LSTM进行运算,即此后的input = (b, n + h)
_functions\rnn.py

	hx, cx = hidden
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)

    return hy, cy

input = (b, n + h), w_hh = (4 * h, h), w_ih = (4* h, n + h)
F.linear是线性操作,无论是CNN、RNN都很常用,其定义如下:

def linear(input, weight, bias=None):
    r"""
    Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.

    Shape:

        - Input: :math:`(N, *, in\_features)` where `*` means any number of
          additional dimensions
        - Weight: :math:`(out\_features, in\_features)`
        - Bias: :math:`(out\_features)`
        - Output: :math:`(N, *, out\_features)`
    """
    if input.dim() == 2 and bias is not None:
        # fused op is marginally faster
        return torch.addmm(bias, input, weight.t())

    output = input.matmul(weight.t())
    if bias is not None:
        output += bias
    return output

比较容易看懂,返回input * weight.t() + bias这样的矩阵
于是经过线性变换后,返回的gates = (b, 4 * h)
然后通过chunk()函数,将gates的第一维切分为四份
于是ingate, forgetgate, cellgate, outgate = (b, h)
分别对ingate forgetgate outgate作sigmoid,对cellgate作tanh,注意 * 运算是点积,而不是矩阵乘法,前述代码配合下图饮用更佳,感觉均能一一对应:
lstm
如此即结束了第一个时间步的hidden、cell计算,有多少个时间步,循环迭代即可,最后一步的hidden即可作为最终输出

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch源码解读——RNN/LSTM篇 的相关文章

随机推荐

  • 解决微信小程序报错:[渲染层网络层错误] Failed to load local image resource

    一 场景 写了一个图片点击 全屏展示的组件 页面图片 gt 点击 gt 打开全屏遮罩层显示大图片 1控制元素展示的变量 data photoShow false 2图片点击函数 onClick const url null e curren
  • Shell的read 读取控制台输入、read的使用

    文章目录 1 read 读取控制台输入 1 1基本语法 1 2read的使用 如果想看更详细的Shell总结请到我之前写的博客https blog csdn net Redamancy06 article details 126048299
  • com.sun.org.apache.xerces.internal.impl.io.MalformedByteSequenceException: Invalid byte 2 of 2-byte

    com sun org apache xerces internal impl io MalformedByteSequenceException Invalid byte 2 of 2 byte UTF 8 sequence 分析 这个问
  • YOLO-----关于正负样本、Loss、IOU、怎样去平衡正负样本的问题?

    关于正负样本 Loss IOU 怎样去平衡正负样本的问题 1 关于正负样本 2 Loss计算 3 IOU GIOU DIOU CIOU 4 怎样去平衡正负样本的问题 先整理一下anchor的概念 常用的anchor定义 Faster R C
  • MySQL 8 安装教程

    MySQL 8发布了 据说相比MySQL 5速度提升了2倍 今天来搞一搞MySQL 8 一 下载MySQL 8 1 首先当然是下载安装包了 下载地址 点击下载MySQL 8 这个页面相信大家都熟悉 我就不多说了 2 将下载的压缩包解压 解压
  • 全网最简洁的mpy-cross教程

    大家知道我一向精干 不喜欢搞花儿的 如果去mpy官网看mpy cross的相关资料 估计又得绕蒙 跟我来 保证你三分钟学会 但是本文不涉及原理 第一 mpy cross是干嘛滴 答 把py文件转成mpy系统读的mpy文件 术语咱不懂 叫交叉
  • H3C交换机如何配置SNMP协议?

    1 使用telnet 登陆设备 system view snmp agent snmp agent community read public snmp agent sys infoversion all dis cur save 保存 Y
  • 操作系统原理大题

    一 地址变换和求FAT表大小 某一页表内容自0 7依次为03 07 0B 11 1A 1D 20 22 请计算页面大小为1K和4K时的逻辑地址134D对应的物理地址 首先 将134D转换为二进制数为 0001001101001101 1k为
  • 【2024届校招内推:NTAA84y】腾讯云智研发中心

    云智校招新官网查看最新岗位情况 云智研发中心2024届校园招聘官网 内推码 NTAA84y 云智研发公司2024届校园招聘启动啦 腾讯旗下子公司 八大类岗位 五大城市全面开放 在喜欢的城市 做喜欢的工作 期待正能量 共担当 实干家的你加入云
  • dumpsys meminfo 的原理和应用

    什么是dumpsys meminfo Android中通过命令dumpsys meminfo package name pid 查看指定进程的内存使用情况 通过输出的信息 可以看出来应用在内存哪里分配出现了问题 比如native heap
  • 华为服务器sn号查询网站,linux 查询服务器sn

    linux 查询服务器sn 内容精选 换一换 Linux云服务器变更规格时 可能会发生磁盘挂载失败的情况 因此 变更规格后 需检查磁盘挂载状态是否正常 本节操作介绍变更规格后检查磁盘挂载状态的操作步骤 以root用户登录云服务器 执行以下命
  • top 命令

    NAME top display Linux tasks SYNOPSIS top hv abcHimMsS d delay n iterations p pid pid a 按内存使用排序 b 批处理 c 显示完整的命令 d 指定间隔时间
  • 文章目录 定义 抽象类型定义 存储结构 顺序存储 定长顺序存储结构 堆式顺序存储结构 链式存储 串的链式存储结构 定义 串是一种内容受限的线性表 串 字符串 由零个或多个字符组成的有限序列 子串 串的任意个连续的字符组成的子序列 主串 包含
  • 深度学习部署--tensorflow 用c++调用前向

    使用TensorFlow C API构建线上预测服务 第一篇 Oct 9 2017 tensorflow 文章目录 1 使用Python接口训练模型 2 源码编译TensorFlow 3 使用TensorFlow C API编写预测代码 3
  • 线下零售场景的消费者商品场景终端数字升级

    按照识别的精度排序 确实是虹膜 指纹 人脸的识别精度依次降低 但人脸识别可以根据摄像头的提升而提升 双目摄像头 结构光摄像头 TOF等等 这个上升空间很大 从应用性来看 你现在让所有的用户都去提取虹膜信息 指纹信息 这个很难 不现实 而我们
  • 微信小游戏 can't find variable: window

    最近测试微信小游戏的时候 需要加入一些SDK代码 在加入这些文件到项目并require相应的库的时候 小游戏开发者工具一直报错 can t find variable window 查找了相关资料 https developers weix
  • 逐点插入法实现 Delaunary三角网 ( 附 C++ 代码)

    逐点插入法作为一种经典的凸闭包收缩算法 其思想是 首先找到包含数据区域的最小凸包边形 并从该多边形开始从外到内形成Delaunary三角网 因此其每次插入一个新的点就会删除相应的三角形来构建性的三角网 这个过程中常常伴随着大量的查询计算过程
  • flutter 从A到B,然后在从B返回A页面,A页面刷新数据

    flutter 从A到B 然后在从B返回A页面 A页面刷新数据实现代码如下 Navigator push context MaterialPageRoute builder context gt NoticePage then value
  • 解决java.lang.IllegalArgumentException: Could not resolve placeholder xx.xx.addr 的问题,思路:一定是配置文件问题

    今天启动SpringBoot遇到一个问题 提示 java lang IllegalArgumentException Could not resolve placeholder xx xx addr in value xx xx addr
  • pytorch源码解读——RNN/LSTM篇

    文章的字母中 b batch size t time step n num feature h hidden size 假设输入数据维度input b t n 所设计的LSTM模型如下 class MYLSTM nn Module def