循环神经网络-LSTM

2023-11-10

参考
长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题。 解决这一问题的最早方法之一是长短期存储器(long short-term memory,LSTM) (Hochreiter and Schmidhuber, 1997)。 它有许多与门控循环单元( 9.1节)一样的属性。 有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些, 却比门控循环单元早诞生了近20年。

从零开始实现

import torch
from torch import nn
from d2l import torch as d2l

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device)*0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # 输入门参数
    W_xf, W_hf, b_f = three()  # 遗忘门参数
    W_xo, W_ho, b_o = three()  # 输出门参数
    W_xc, W_hc, b_c = three()  # 候选记忆元参数
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params
def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),
            torch.zeros((batch_size, num_hiddens), device=device))
def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,
                            init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

简洁实现

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

循环神经网络-LSTM 的相关文章

随机推荐

  • WEB攻击与防御

    这里列举一些常见的攻击类型与基本防御手段 XSS攻击 跨站脚本 Cross site scripting 简称XSS 把JS代码注入到表单中运行例如在表单中提交含有可执行的JS的内容文本 如果服务器端没有过滤或转义这些脚本 而这些脚本由通过
  • 判断带头结点的循环双链表是否对称

    题目 设计一个算法 用于判断带头结点的循环双链表是否对称 分析 循环双链表的特点是 当前结点方便找到前后节点 且尾指针指向第一个结点 对称性 判断第一个结点和最后一个结点的值是否相等 如果相等 再判断第二个结点和倒数第二个结点 以此类推 从
  • [论文阅读] (18)英文论文Model Design和Overview如何撰写及精句摘抄——以系统AI安全顶会为例

    娜璋带你读论文 系列主要是督促自己阅读优秀论文及听取学术讲座 并分享给大家 希望您喜欢 由于作者的英文水平和学术能力不高 需要不断提升 所以还请大家批评指正 非常欢迎大家给我留言评论 学术路上期待与您前行 加油 前一篇介绍CCS2019的P
  • 电子设计大赛作品_电子设计大赛

    为了进一步提高学生对电子和科技的兴趣 培养学生的动手能力和想象能力 增强学生的团队合作意识 提高学生分析和解决问题的能力 现决定开展电子设计大赛 电子设计大赛详情 一 参赛对象 全体全日制在校大学生 1 3人自由组队 并指定队长一名 可自由
  • 华为OD机试 - 简易内存池(Java)

    题目描述 请实现一个简易内存池 根据请求命令完成内存分配和释放 内存池支持两种操作命令 REQUEST和RELEASE 其格式为 REQUEST 请求的内存大小 表示请求分配指定大小内存 如果分配成功 返回分配到的内存首地址 如果内存不足
  • java-map-put方法源码分析

    HashMap是由数组 链表和红黑树组成的数据结构 而其中put方法可以算的上HashMap中的核心方法 这个方法给我们展示了HashMap的大部分精髓 我们首先来看一下map的核心变量 transient Node
  • 2022年一起努力应对互联网寒冬吧,5G音视频时代还不学NDK开发吗

    前言 找工作还是需要大家不要紧张 有我们干这一行的接触人本来就不多 难免看到面试官会紧张 主要是因为怕面试官问的问题到不上来 那时候不要着急 答不上了的千万不然胡扯一些 直接就给面试官说这块我还没接触到 以后如果工作当中遇到的话我可以很快的
  • i2c 编程接口

    1 通信接口 i2c发送或者接收一次数据都以数据包 struct i2c msg 封装 struct i2c msg u16 addr 从机地址 u16 flags 标志 define I2C M TEN 0x0010 十位地址标志 def
  • Vert.X通过Hoverfly满足服务虚拟化

    服务虚拟化是一种用于模拟基于组件的应用程序的依赖关系行为的技术 Hoverfly是用Go语言编写的服务虚拟化工具 可让您模拟HTTP S 服务 它是一个代理 它使用存储的响应来响应HTTP S 请求 并假装它是真正的对应对象 食蚜蝇Java
  • 使用 IO 流 读取 本 地 文 件 (两种方式)

    使用IO 流读取本地文件 public class FileReadWrite public static void main String args FileReader fr null try 1 创建读取文件 fr new FileR
  • [Manjaro] OpenGL 配合着色器实现光线跟踪之引入光线

    概述 本文介绍 GLFW GLAD 在 RayTracing in one weekend 的实现 实验环境 Manjaro Linux 22 0 0 整体思路 使用基于屏幕空间的光线跟踪算法 每个像素点代表一个光线 使用 GLSL 着色器
  • 在IntelliJ IDEA中查看代码覆盖率

    在IDEA中使用Junit测试时 时常需要考虑代码覆盖率 以下是查看代码覆盖率的方法 在test class右键选择 more run debug gt Run Test with Coverage 使用时发现会出现没有Branch Cov
  • ubuntu上安装最新的docker社区版

    如果安装有老的docker先删除老的版本 sudo apt get remove docker docker engine docker io 老的镜像 存储 网络信息保留在 var lib docker 下 可以自行删除 新的社区版本叫d
  • 矩阵的转置怎么编程用C语言,将一个3x3的矩阵转置,怎样用c语言写?

    include
  • 学建模时常遇到的问题(看专业解答)

    常做到一半就卡住 那是不是操作时有问题 没有处理好 选择面数选不中 选择线选不中 还只能用最初始的命令 这些有一种可能就是你操作模型时 按到了空格键 然后界面就会锁死 按空格键就阔以取消啦 还有一个情况就是 进入可编辑模式之后 对模型进行点
  • 权限提升-MYSQL数据库提权

    基础知识 1 需要了解掌握的权限 后台权限 网站权限 数据库权限 接口权限 系统权限 域控权限等 2 权限获取方法简要归类说明 后台权限 SQL注入 数据库备份泄露 默认或弱口令等获取帐号密码进入 网站权限 后台提升至网站权限 RCE或文件
  • zookeeper到nacos的迁移实践

    本文已收录 https github com lkxiaolou lkxiaolou 欢迎star 技术选型 公司的RPC框架是dubbo 配合使用的服务发现组件一直是zookeeper 长久以来也没什么大问题 至于为什么要考虑换掉zook
  • azkaban上传zip报错:Error Chunking during uploading files to db

    上传时页面报 Instalation Failed Error Chunking during uploading files to db 查看web server日志 2021 11 26 11 20 38 253 0800 INFO P
  • vue图片上传组件

    vue图片上传组件 最近在做项目的时候顺便补充了一下公司项目的公共组件库 刚刚手头事情告一段落 就来做个笔记 首先来看看最终效果 1 不允许删除 2 允许用户删除 显示删除按钮 实现的效果就是上图显示内容 接下来说说组件布局那部分直接上代码
  • 循环神经网络-LSTM

    参考 长期以来 隐变量模型存在着长期信息保存和短期输入缺失的问题 解决这一问题的最早方法之一是长短期存储器 long short term memory LSTM Hochreiter and Schmidhuber 1997 它有许多与门