Pytorch 的 LSTM 模型的简单示例

2023-10-27

1. 代码

完整的源代码:

import torch
from torch import nn

# 定义一个LSTM模型
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # 初始化隐藏状态h0, c0为全0向量
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)

        # 将输入x和隐藏状态(h0, c0)传入LSTM网络
        out, _ = self.lstm(x, (h0, c0))
        # 取最后一个时间步的输出作为LSTM网络的输出
        out = self.fc(out[:, -1, :])
        return out

# 定义LSTM超参数
input_size = 10   # 输入特征维度
hidden_size = 32  # 隐藏单元数量
num_layers = 2    # LSTM层数
output_size = 2   # 输出类别数量

# 构建一个随机输入x和对应标签y
x = torch.randn(64, 5, 10)  # [batch_size, sequence_length, input_size]
y = torch.randint(0, 2, (64,))  # 二分类任务,标签为0或1

# 创建LSTM模型,并将输入x传入模型计算预测输出
lstm = LSTM(input_size, hidden_size, num_layers, output_size)
pred = lstm(x)  # [batch_size, output_size]

# 定义损失函数和优化器,并进行模型训练
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(lstm.parameters(), lr=1e-3)
num_epochs = 100

for epoch in range(num_epochs):
    # 前向传播计算损失函数值
    pred = lstm(x)  # 在每个epoch中重新计算预测输出
    loss = criterion(pred.squeeze(), y)

    # 反向传播更新模型参数
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 输出每个epoch的训练损失
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

2. 模型结构分析

# 定义一个LSTM模型
class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # 初始化隐藏状态h0, c0为全0向量
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)

        # 将输入x和隐藏状态(h0, c0)传入LSTM网络
        out, _ = self.lstm(x, (h0, c0))
        # 取最后一个时间步的输出作为LSTM网络的输出
        out = self.fc(out[:, -1, :])
        return out

上述代码定义了一个LSTM类,这个类可以用于完成一个基于LSTM的序列模型的搭建。

在初始化函数中,输入的参数分别是输入数据的特征维度(input_size),隐藏层的大小(hidden_size),LSTM层数(num_layers)以及输出数据的维度(output_size)。这里使用batch_first=True表示输入数据的第一个维度是batch size,第二个维度是时间步长和特征维度。

在forward函数中,首先初始化了LSTM网络的隐藏状态为全0向量,并且将其移动到与输入数据相同的设备上。然后调用了nn.LSTM函数进行前向传播操作,并且通过fc层将最后一个时间步的输出映射为输出的数据,最后进行了返回。

3. 代码详解

        # 将输入x和隐藏状态(h0, c0)传入LSTM网络
        out, _ = self.lstm(x, (h0, c0))

这行代码是利用 PyTorch 自带的 LSTM 模块处理输入张量 x(形状为 [batch_size, sequence_length, input_size])并得到 LSTM 层的输出 out 和最终状态。其中,h0 是 LSTM 层的初始隐藏状态,c0 是 LSTM 层的初始细胞状态。

在代码中,调用了 self.lstm(x, (h0, c0)) 函数,该函数的返回值有两个:第一个返回值是 LSTM 层的输出 out,其包含了所有时间步上的隐状态;第二个返回值是一个元组,包含了最后一个时间步的隐藏状态和细胞状态,但我们用“_”丢弃了它。

因为对于许多深度学习任务来说,只需要输出序列的最后一个时间步的隐藏状态,而不需要每个时间步上的隐藏状态。因此,这里我们只保留 LSTM 层的输出 out,而忽略了 LSTM 层最后时间步的状态。

最后,out 的形状为 [batch_size, sequence_length, hidden_size],其中 hidden_size 是 LSTM 层输出的隐藏状态的维度大小。

x = torch.randn(64, 5, 10)

这行代码创建了一个形状为 (64, 5, 10) 的张量 x,它包含 64 个样本,每个样本具有 5 个特征维度和 10 个时间步。该张量的值是由均值为 0,标准差为 1 的正态分布随机生成的。

torch.randn() 是 PyTorch 中生成服从标准正态分布的随机数的函数。它的输入是张量的形状,输出是符合正态分布的张量。在本例中,形状为 (64, 5, 10) 表示该张量包含 64 个样本,每个样本包含 5 个特征维度和 10 个时间步,每个元素都是服从标准正态分布的随机数。这种方式生成的随机数可以用于初始化模型参数、生成噪音数据等许多深度学习应用场景。

y = torch.randint(0, 2, (64,))  # 二分类任务,标签为0或1

y = torch.randint(0, 2, (64,)) 是使用 PyTorch 库中的 randint() 函数来生成一个64个元素的张量 y,张量的每个元素都是从区间 [0, 2) 中随机生成的整数。

具体而言,torch.randint() 函数包含三个参数,分别是 low、high 和 size。其中,low 和 high 分别表示随机生成整数的区间为 [low, high),而 size 参数指定了生成的张量的形状。

在上述代码中,size=(64,) 表示生成的张量 y 的形状为 64x1,即一个包含 64 个元素的一维张量,并且每个元素的值都在 [0, 2) 中随机生成。这种形式的张量通常用于分类问题中的标签向量。在该任务中,一个标签通常由一个整数表示,因此可以采用使用 randint() 函数生成一个长度为标签类别数的一维张量,其每个元素的取值为 0 或 1,表示对应类别是否被选中。

# 创建LSTM模型,并将输入x传入模型计算预测输出
lstm = LSTM(input_size, hidden_size, num_layers, output_size)
pred = lstm(x)  # [batch_size, output_size]

通过定义的LSTM类创建了一个LSTM模型,并将输入x传入模型进行前向计算,得到了一个预测输出pred,其形状为[64, output_size],其中output_size是在LSTM初始化函数中指定的输出数据的维度。

这段代码演示了如何使用已经构建好的代码搭建并训练一个基于LSTM的序列模型,并且展示了其中的一些关键步骤,包括数据输入、模型创建以及前向计算。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch 的 LSTM 模型的简单示例 的相关文章

随机推荐

  • Pandas 之 过滤DateFrame中所有小于0的值并替换

    Outline 前几天 数据清洗时有用到pandas去过滤大量数据中的 负值 把过滤出来的 负值 替换为 NaN 或者指定的值 故做个小记录 读取CSV文件 代码 import pandas as pd import numpy as np
  • 万能指针:void * 指针

    背景 最近看到void 类型的指针不知道该怎么处理 特别学习一下 适用语言 C C 当中都可以使用 但就目前认知水平 C当中用的较为普遍一些 void 指针的机制 指针从某种程度上来说 无非就是一个地址 它的类型只是用于说明数据结构的 指针
  • RISC-V指令集是一种精简的、可编程的指令集,它主要用于实现各种复杂的数据处理与控制任务。它提供了一系列简单的、可编程的指令,可以用来实现复杂的操作,比如addi指令,它可以将一个常数(如0x1)加...

    RISC V指令集是一种精简可编程的指令集 可以用来实现复杂的数据处理和控制操作 它提供了一系列简单可编程的指令 例如addi指令 它可以将一个常数加到寄存器中 并将结果存储到另一个寄存器中 从而实现特定的操作
  • 小白学Linux之#pragma的用法

    预编译指令 pragma的用法 最近在看开源项目中的代码时 发现许多地方都用到了 pragma的程序 因此 就问了下谷歌老师 总结了下 pragma预编译指令的常用用法 现在和大家分享下 一 pragma最常用的方法 1 progma pa
  • 【Node】package.json文件

    package json 文件详解前言一 package json 文件作用二 package json 文件创建三 package json 文件示例四 package json 文件配置说明 五 项目依赖 六 开发依赖 七 Node j
  • 【Linux】工具(5)——gdb

    今天我们来到Linux工具的最后一篇博客 gdb的使用 目录 一 Linux下的release和debug 二 gdb常用指令选项 一 Linux下的release和debug 我们先来写一个Makfile 来方便我们编译代码 再来写一个t
  • C# 中的多线程和异步编程

    目录 前言 1 并发 并行 异步 同步 的概念 区别以及使用场景 1 并发和并行 2 同步和异步 3 何时使用多线程编程 何时使用异步编程 2 基础知识 1 简介及概念 1 1Join 和 Sleep 1 2线程是如何工作的 1 3线程 v
  • MySql事务和存储引擎

    目录 一 MySQL 事物 1 事务的概念 2 事务的ACID特点 2 1 1 原子性 2 1 2 一致性 2 1 3 隔离性 2 1 4 Mysql 及事物隔离级别 查询全局事务隔离级别 查询会话事务隔离级别 设置全局事务隔离级别 设置会
  • DRF---序列化组件

    目录 序列化器Serializer 序列化组件基本使用 使用序列化类 序列化多条数据 使用序列化类 序列化单条数据 反序列化 新增 修改 新增 视图类 序列化类 视图类 序列化类 序列化类的常见字段类和常见参数 常用字段类型 选项参数 通用
  • 【Linux线程同步】生产者消费者模型

    文章目录 1 peach 线程互斥中可能还会存在的问题 peach 2 peach 线程同步 peach 2 1 apple 同步概念与竞态条件 apple 2 2 apple 条件变量函数 apple lemon 初始化 lemon le
  • Qt5.15源码编译详解

    1 请先参考 https blog csdn net weixin 60395515 article details 127284046 spm 1001 2014 3001 5501 2 有以下几个不同的地方需要修改 Qt5的mkspec
  • 超详细解决困扰人的python典例:“有n个人围成一圈”式n里挑一

    自学python No 2 引语 题目 案例实现 range 函数 append 函数 pop 函数 完整代码 引语 记录学习路程 抛砖引玉 如有更好的算法或者出现错误 欢迎指点 题目 有n个人围成一圈 顺序排号 从第一个人开始报数 从1到
  • 汽车之家各种车型参数爬虫

    汽车之家各种车型参数爬虫 结果如下 本案例使用jupyter notebook 用到requests BeautifulSoup lxml urlencode pandas五个库 爬取下来的数据如下图所示 详细过程 整个过程分成三个部分 1
  • ubuntu系统信息查询(主板,内存,硬盘,网卡)

    1 主板型号 主板支持最大内存 单条内存的参数 sudo dmidecode t 2 查看主板信息 sudo dmidecode t 16 grep Maximum 查看主板支持最大内存 sudo dmidecode t memory 查看
  • JDBC、连接步骤(4步)、需要导入的第三方jar包、开发步骤

    1 JDBC Java Database Connectivity java连接数据库的工具 1 1 什么是JDBC 他是java提供的一组API 用来提供连接数据库中需要用到的类和接口 他是一组规范 为不同数据库封装相同接口的一组规范 让
  • 基于 Web 的 LDAP 认证,访问资源就是这么安全

    轻量级目录访问协议 即 LDAP 协议 是微软 Active Directory AD 和 OpenLDAP 等传统身份管理解决方案中的核心身份认证协议 然而 IT 环境的不断发展暴露了传统方案的问题 基于本地部署的设计逻辑无法适应新兴的云
  • Unity2D游戏无限刷新地图

    关于Unity2D游戏如何无限刷新地图的问题 首先在Unity中创建多个大小相同的物体当做刷新的地图对象 然后在创建一个名称为Endless cs的脚本 然后添加如下代码 public float distance void OnBecam
  • cmake(三十五)Cmake之include指令

    一 CMakeLists txt和cmake脚本的联系和区别 cmake脚本 1 cmake文件里面通常是 什么信息 information cmake文件 里包含了一些 公共 复用 的 cmake命令 和一些 宏 函数 当CMakeLis
  • java开发团队认知_一个优秀的研发团队应该具备什么特征

    1 计划执行 计划安排得当 不要老加班 不要老是现实和计划不匹配 不要做到哪儿计划就推后到哪儿 2 研发成果 成功产出几个重影响力级别的 完整成块的 有成就感自豪感的产品或项目 3 团队氛围 这个团队每个人都相处的很融洽 4 团队协作 每个
  • Pytorch 的 LSTM 模型的简单示例

    1 代码 完整的源代码 import torch from torch import nn 定义一个LSTM模型 class LSTM nn Module def init self input size hidden size num l