零基础LSTM入门示例

2023-05-16

最近用pytorch搭了个LSTM模型,由于博主两个都没基础,所以查来查去兜了不少圈子,干脆总结一个极简的LSTM代码示例,供参考

仅使用了torch.nn.Module自定义模型

随便挑了accuracy_score作为评估指标

以及,特征和标签是随便打的,不要指望它收敛ㄟ( ▔, ▔ )ㄏ

 

定义模型结构:

import torch
from sklearn.metrics import accuracy_score

#定义需要的模型结构,继承自torch.nn.Module
#必须包含__init__和forward两个功能
class mylstm(torch.nn.Module):
    def __init__(self, lstm_input_size, lstm_hidden_size, lstm_batch, lstm_layers):
        # 声明继承关系
        super(mylstm, self).__init__()

        self.lstm_input_size, self.lstm_hidden_size = lstm_input_size, lstm_hidden_size
        self.lstm_layers, self.lstm_batch = lstm_layers, lstm_batch

        # 定义lstm层
        self.lstm_layer = torch.nn.LSTM(self.lstm_input_size, self.lstm_hidden_size, num_layers=self.lstm_layers, batch_first=True)
        # 定义全连接层 二分类
        self.out = torch.nn.Linear(self.lstm_hidden_size, 2)

    def forward(self, x):
        # 激活
        x = torch.sigmoid(x)
        # LSTM
        x, _ = self.lstm_layer(x)
        # 保留最后一步的输出
        x = x[:, -1, :]
        # 全连接
        x = self.out(x)
        return x

    def init_hidden(self):
        #初始化隐藏层参数全0
        return torch.zeros(self.lstm_batch, self.lstm_hidden_size)

数据集特征和标签:

#训练集特征
train_feature = [
    [[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.3, 0.4]],
    [[0.4, 0.3, 0.2, 0.1], [0.4, 0.3, 0.2, 0.1], [0.4, 0.3, 0.2, 0.1]],
    [[0.2, 0.3, 0.5, 0.8], [0.2, 0.3, 0.5, 0.8], [0.2, 0.3, 0.5, 0.8]],
    [[0.7, 0.6, 0.5, 0.4], [0.7, 0.6, 0.5, 0.4], [0.7, 0.6, 0.5, 0.4]],
    [[0.1, 0.3, 0.5, 0.7], [0.1, 0.3, 0.5, 0.7], [0.1, 0.3, 0.5, 0.7]],
    [[0.5, 0.4, 0.2, 0.1], [0.5, 0.4, 0.2, 0.1], [0.5, 0.4, 0.2, 0.1]],
    [[0.2, 0.4, 0.6, 0.8], [0.2, 0.4, 0.6, 0.8], [0.2, 0.4, 0.6, 0.8]],
    [[0.7, 0.6, 0.3, 0.2], [0.7, 0.6, 0.3, 0.2], [0.7, 0.6, 0.3, 0.2]]]
#测试集特征
test_feature = [
    [[0.3, 0.4, 0.6, 0.8], [0.3, 0.4, 0.6, 0.8], [0.3, 0.4, 0.6, 0.8]],
    [[0.9, 0.6, 0.3, 0.2], [0.9, 0.6, 0.3, 0.2], [0.9, 0.6, 0.3, 0.2]]]
#训练集、测试集标签
train_label = [1, 0, 1, 0, 1, 0, 1, 0]
test_label = [1, 0]

模型的定义、训练、测试:

#预定义模型参数
dataset_batch_size = 2
learning_rate = 0.001
lstm_input_size = 4
lstm_hidden_size = 4
lstm_batch = 2
lstm_layers = 1

#定义模型
model = mylstm(lstm_input_size, lstm_hidden_size, lstm_batch, lstm_layers)
#初始化隐藏层
hidden = model.init_hidden()
#定义损失函数用交叉熵
criterion = torch.nn.CrossEntropyLoss()
#定义用Adam算法梯度下降
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(5):
    #避免保存的计算图过大,进行梯度清零
    optimizer.zero_grad()
    for i in range((len(train_feature) // dataset_batch_size) - 1):
        #训练
        model.train()
        #取一个batch并转化为张量
        x = train_feature[i * dataset_batch_size:(i + 1) * dataset_batch_size]
        x = torch.tensor(x)
        y = train_label[i * dataset_batch_size:(i + 1) * dataset_batch_size]
        y = torch.tensor(y)
        # 前馈
        y_pred = model(x)
        loss = criterion(y_pred, y)
        # 反馈
        optimizer.zero_grad()
        loss.backward()
        # 更新
        optimizer.step()

    #测试
    model.eval()
    test_feature = torch.tensor(test_feature)
    out = model(test_feature)
    predict_result = torch.argmax(out, dim=1)
    #计算准确率
    acc = accuracy_score(test_label, predict_result)
    print('epoch:', epoch+1, ' acc:', acc)

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

零基础LSTM入门示例 的相关文章

  • 性能监控工具nmon介绍

    性能监控工具nmon介绍 做性能测试 xff0c 服务器监控是至关重要的 xff0c 使用nmon命令可以轻松捕捉系统资源的使用情况 xff0c 便于做性能分析 nmon官方介绍 nmon是一种在 AIX 与各种 Linux 操作系统上广泛
  • 使用gitlab中issues做缺陷管理

    使用Gitlab中Issues做缺陷管理 创建issue bug模板 创建issue bug模板是为了在创建issue时可以选择模板 xff0c 控制issue的格式统一 上传bug templates md文件至git库上 gitlab
  • JMeter学习笔记(七):Linux运行JMeter

    JMeter在Linux下运行测试 安装JDK 首先安装JDK xff0c 并正确配置环境变量 下载jdk并上传至linux服务器 上传并解压jdk压缩包jdk 8u231 linux x64 tar gz xff0c 建议把软件都安装到
  • Jenkins配置Windows节点实现自动化测试(一)

    Jenkins配置Windows节点实现自动化测试 一 配置节点 目前公司已经有jenkins服务器 xff0c 且运维人员已经配置好CI CD持续集成持续部署 xff0c 测试人员期望将自动化测试集成到CI CD任务中 xff0c 由于U
  • 使用WinSW安装Windows服务

    使用WinSW安装Windows服务 背景 配置Jenkins Windows节点时 xff0c 需要手动执行命令启动服务 xff0c 每次手动启动很麻烦 xff0c 写成批处理文件放在C ProgramData Microsoft Win
  • VMware中ubuntu虚拟机重启后找不到ens33网卡问题

    VMware中ubuntu虚拟机关闭重启后ens33网卡找不到问题解决方案 工作中会使用ubuntu 桌面版本 xff0c 虚拟机中ubuntu使用完后直接点击关闭 xff0c 重新打开后无法上网 xff0c 查看无ens33网卡 xff0
  • HTTP之TCP三次握手四次挥手

    HTTP概述 HTTP是hypertext transfer protocol xff08 超文本传输协议 xff09 的简写 xff0c 它是TCP IP协议的一个应用层协议 xff0c 用于定义WEB浏览器与WEB服务器之间交换数据的过
  • 深入理解Wi-Fi P2P

    第7章 深入理解Wi Fi P2P 本章所涉及的源代码文件名及位置 W ifiP2pSettings java packages apps Settings src com android settings wifi p2p W ifiP2
  • JMeter遇到全局变量、BeanShell Sampler、JDBC(postgresql)

    全局变量 BeanShell Sampler JDBC postgresql xff09 设置全局变量 场景 xff1a 在压测指定接口时需要先进行登录才能去访问接口 xff0c 解决方案是先设置一个登录线程 xff0c 登录成功后通过JS
  • Python JSON dumps与loads傻傻分不清

    一 JSON基本概念 JSON代表JavaScript对象符号 它是一种轻量级的数据交换格式 xff0c 用于存储和交换数据 它是一种独立于语言的格式 xff0c 非常容易理解 xff0c 因为它本质上是自描述的 python中有一个内置包
  • tomcat 端口冲突问题的解决办法

    方法1 tomcat开机启动了 xff0c 你可以查看任务管理器 xff0c 把tomcat xff08 或者Apache tomcat xff09 的任务关了 方法 2 更改tomcat的8080端口 打开配置文件 xff08 如下 xf
  • 总结之知识图谱前沿技术课程

    前言正文参考文献 前言 写在前面 xff0c 本文的内容主要基于2017年12月2日在苏州大学举办的知识图谱前沿技术课程 xff08 感谢各位老师的talk xff0c 受益良多 xff09 以及本人在之前阅读的有关paper xff0c
  • Qt 5.15的源码编译(Windows)

    前言 xff1a 在技术革新如此之快的时代 xff0c Qt也在为适应这些变化发生着重大的改变 又一长期 3年 支持版Qt 5 15 LTS在2020年3月发布 xff0c 重大更新的大版本Qt 6 0也在2020年12月发布 但是 xff
  • 【vim编辑器的使用】

    目录 1 vim的编辑器的使用 1 1 vim 文件名 xff1a 表示将文件用vim编辑器打开 2 vim的三种编辑模式 2 1 命令模式 2 2 插入模式 2 3 底行模式 Vim 是从 vi 发展出来的一个文本编辑器 代码补完 编译及
  • gcc编译器

    GCC xff08 GNU Compiler Collection xff0c GNU编译器套件 xff09 是由GNU开发的编程语言译器 GNU编译器套件包括C C 43 43 Objective C Fortran Java Ada和G
  • send()函数recv()函数详解

    目录 1 send xff08 xff09 函数 2 recv xff08 xff09 函数 1 send xff08 xff09 函数 函数原型 xff1a ssize t send int sockfd const void buf s
  • 数据元素、数据项、数据对象的概念详解

    数据元素 xff1a 数据的基本单位 数据项 xff1a 独立包含的数据最小单位 若干数据项组成一个数据元素 数据对象 xff1a 相同数据元素的集合 若干数据元素组成数据对象
  • Linux解压压缩命令tar

    目录 一 tar tar命令打包 tar命令解压 选项解释 一 tar Linux系统中常用的压缩格式有 xff1a tar gz tar bz2 tar xz tar Z 可以用tar进行解压缩 tar命令打包 xff1a tar 选项
  • ubuntu服务器编译源码

    1 xff0c Vmware软件安装后 2 xff0c VMware workstation full 16 0 0 16894299 exe 3 xff0c 新建虚拟磁盘 xff0c 加载镜像文件 ubuntu 16 04 7 deskt
  • 3.1 Linux启动Shell

    系列文章目录 第1章 Linux Shell简介 第2章 Shell基础 第3章 Bash Shell基础命令 lt 本章所在位置 gt 第4章 Bash Shell命令进阶 第5章 Linux Shell深度理解 第6章 Linux环境变

随机推荐

  • IP地址打印格式

    在C语言中 xff0c 可以使用printf 函数打印IP地址 常见的方法是将IP地址转换为点分十进制格式 xff0c 并使用 s或 u u u u等格式说明符进行打印 以下是一些示例代码 xff1a 将IP地址转化为字符串并以 34 s
  • 子网掩码打印方式

    在C语言中 xff0c 可以使用printf 函数打印子网掩码 和打印IP地址类似 xff0c 常见的方法是将掩码转换为点分十进制格式 xff0c 并使用 s或 u u u u等格式说明符进行打印 以下是一些示例代码 xff1a 将子网掩码
  • realloc 用法 .

    最近在写source code时需要在数组的buffer小时重新申请一块buffer 故找了一些资料 xff0c 乖乖 xff0c 竟然原指针还可以 漂移 realloc 原型 xff1a extern void realloc void
  • 多个方面比较电路交换、报文交换和分组交换的主要优缺点

    xff08 1 xff09 电路交换 xff1a 由于电路交换在通信之前要在通信双方之间建立一条被双方独占的物理通路 xff08 由通信双方之间的交换设备和链路逐段连接而成 xff09 xff0c 因而有以下优缺点 优点 xff1a 由于通
  • LVDS,CML,LVPECL,VML接口详细介绍

    在平时的工作中 xff0c 经常会接触到各种差分电平的转换 xff0c 网上也有很多这样的资料 xff0c 但发现有些混乱 xff0c 所以找了TI的这份文档进行翻译 xff0c 一是系统的归类一下 xff0c 二是自己也能通过这个来加深理
  • Linux-网桥原理分析 .

    Linux 网桥原理分析 http biancheng dnbcw info linux 244269 html 目 录 1 前言 6 2 网桥的原理 7 2 1 桥接的概念 7 2 2 linux的桥接实现 8 2 3 网桥的功能 9 3
  • IP头、TCP头、UDP头详解以及定义

    一 MAC帧头定义 数据帧定义 xff0c 头14个字节 xff0c 尾4个字节 typedef struct MAC FRAME HEADER char m cDstMacAddress 6 目的mac地址 char m cSrcMacA
  • SGMII 和 Serdes 的详细说明

    Serdes xff1a SERDES是英文SERializer 串行器 DESerializer 解串器 的简称 它是一种时分多路复用 TDM 点对点的通信技术 xff0c 即在发送端多路低速并行信号被转换成高速串行信号 xff0c 经过
  • mips的内存管理-kseg0,kseg1虚拟和物理地址映射理解

    mips 24kf manual gliethttp pdf p89页 所以mips复位和中断发生 都会自动进入kernel模式 The core enters Kernel mode both at reset and when an e
  • Linux SSH Access denied(拒绝访问)解决方案

    新安装的 CentOS 7 使用 SSH 连接出现 Access denied xff0c 记录一下这个坑 详细问题如下 xff08 见图 xff09 xff1a 解决方案 查了下资料 xff0c Linux 系统默认就是禁止远程登录的 那
  • (数据结构与算法分析 一)------快速求幂算法,Java递归实现

    快速求幂算法 xff0c 递归实现 xff0c 其实算法的思想很简单 xff0c 但是感觉非常经典 xff0c 这个也是我开始看数据结构与算法分析这本书的开始把 xff0c 大学期间感觉就得深究一下算法 xff0c 课堂学习的太肤浅 xff
  • 字符串拷贝函数memcpy和strncpy以及snprintf 的性能比较

    问题 xff1a 函数memcpy dest src sizeof dest strncpy dest src sizeof dest 和snprintf dest sizeof dest s src 都可以将src字符串中的内容拷贝到de
  • snprintf函数使用总结

    一直有接触snprintf 经久不用知识点又会模糊 记录下来以便日后查看 依赖头文件 include lt stdio h gt 函数原型 int snprintf char str size t size const char forma
  • 如何在Ubuntu 18.04 LTS上使用UFW设置防火墙

    正确配置防火墙是整个系统安全中最重要的方面之一 默认情况下 xff0c Ubuntu 18 04 LTS 附带了一个名为 UFW xff08 Uncomplicated Firewall xff09 的防火墙配置工具 xff0c UFW 是
  • 全程技术干货:VR画面渲染性能是这样提升的

    本文您将了解到 xff1a 1 xff0c VR渲染面临什么问题 xff1f 2 xff0c 如何做好VR的渲染 xff1f 3 xff0c 怎样提升VR渲染的性能 xff1f 渲染对于VR内容的开发来说 xff0c 是非常重要的议题 但在
  • Python微信小程序,实现自动回复等功能(itchat模块)

    本文是使用Python的itchat模块进行微信私聊消息以及群消息的自动回复功能 xff0c 必须在自己的微信中添加微信号xiaoice ms xff08 微软的微信机器人 xff09 才能实现 xff0c 直接复制代码运行之后扫一扫二维码
  • 最大完全子图和极大连通子图

    最近学习图论的一串小结之一 完全图 amp 完全子图 amp 最大完全子图 完全图 xff1a 任意两点都恰有一条边相连的图 任意两点都相邻 完全子图 xff1a 满足任意两点都恰有一条边相连的子图 xff0c 也叫团 最大完全子图 xff
  • python3回溯找最大团

    最近学习图论的一串小结之三 数学概念见上上篇 xff1a 最大完全子图和极大连通子图 最大团问题分析可以移步这篇博文 xff1a 回溯 图论 最大团问题 xff08 求最大完全子图 xff09 代码一部分参考了这篇博文 xff1a pyth
  • python3关于经纬度、方向角、目标距离

    博主搞了半天haversin公式倒腾距离之后 xff0c 发现有现成的geopy可用 xff0c 且网上查到的一些函数用法似乎有改变 xff0c 遂整理如下 已知两点经纬度求距离 from geopy distance import geo
  • 零基础LSTM入门示例

    最近用pytorch搭了个LSTM模型 xff0c 由于博主两个都没基础 xff0c 所以查来查去兜了不少圈子 xff0c 干脆总结一个极简的LSTM代码示例 xff0c 供参考 仅使用了torch nn Module自定义模型 随便挑了a