使用 pad_sequence

2023-11-18

pad_sequence 是用来干嘛的?

首先 pad_sequence 是用来对对tensor做padding 的,先看官方示例:
文档地址https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pad_sequence.html?highlight=pad_sequence#torch.nn.utils.rnn.pad_sequence

from torch.nn.utils.rnn import pad_sequence
a = torch.ones(25, 300)
b = torch.ones(22, 300)
c = torch.ones(15, 300)
pad_sequence([a, b, c]).size()

首先说应用场景

也就是说我们有了几个矩阵,除了第一个维度不一样,其他维度是一样的
常见的就是很多个sequence( length * embedding_dimension )
此时我们希望用做神经网络的input,但是input都是定长的,所以我们需要统一处理成定长的,以保证input layer
此时,我们需要将它们处理成各个维度都一样

遇到问题

10月22日

问题描述:

a = [[1,2,3,4,5],[1,1,1,1],[2,2,2]]
pad_sequence(a, batch_first = True, padding_value = 0 )
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_79801/829534548.py in <module>
      1 a = [[1,2,3,4,5],[1,1,1,1],[2,2,2]]
----> 2 pad_sequence(a, batch_first = True, padding_value = 0 )

~/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/utils/rnn.py in pad_sequence(sequences, batch_first, padding_value)
    361     # assuming trailing dimensions and type of all the Tensors
    362     # in sequences are same and fetching those from sequences[0]
--> 363     return torch._C._nn.pad_sequence(sequences, batch_first, padding_value)
    364 
    365 

TypeError: expected Tensor as element 0 in argument 0, but got list

原因:我的矩阵是由数字组成的矩阵,也就是基本单元都是数字,但是,再高一层次(既是+2 也是-2 个维度)的数据类型是list,而不是数组。
解决:使用 torch.tensor() 对第二层的数据类型做一个类型转换,不适用list

a = [torch.tensor([1,2,3,4,5]),torch.tensor([1,1,1,1]),torch.tensor([2,2,2])]
pad_sequence(a, batch_first = True, padding_value = 0 )
tensor([[1, 2, 3, 4, 5],
        [1, 1, 1, 1, 0],
        [2, 2, 2, 0, 0]])
11月5日
问题:

使用pad_sequence 时,如果是对一个完整的数据集进行pad,那么操作很简单,但是计算的代价会很大。
比如,100 个sequence 组成的list,当其中只有一个是长度为1000 的,其余都是100,那么多pad的部分就高达99 X 900,
而总的有效部分才 99 X 100 + 1000 ,pad 部分比有效部分还多。

思路:

所以,应当基于一个batch 去做pad_sequence

预备知识

data,dataset,dataloader 的关系

  1. dataloader
    发现,训练使用的dataloader构建时,传入的是class Dataset 的实例,目标是从实例中sample出来一个个小的batch,这些batchs 被dataloader组成了一个list。
    注意,list组成单元类型要一样,但是组成单元的基本单元并不管,因为dataloader只负责给你返回batchs
    但是sample出来一个个小的batch 过程也是一条一条数据的从dataset 中获取的。

  2. class Dataset 作用是什么呢?
    pytorch 的class Dataset 有torch.utils.data.Dataset, torch.utils.data.IterableDataset 等,区别只是取出来一条数据的方式不同
    由于我使用的是torch.utils.data.Dataset,所以以此为例。
    Dataset 数据集嘛,所以其自身的变量里面是有一个大的数据单位。(强调一下,你自己怎么组织无所谓,字典也好,列表也好,数组也好,只要给你一下index 索引,你能够返回一条数据就行)
    看一下官方解释

CLASStorch.utils.data.Dataset(*args, **kwds)[SOURCE]
An abstract class representing a Dataset.
All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite getitem(), supporting fetching a data sample for a given key. Subclasses could also optionally overwrite len(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader.

就是你只要是通过继承torch.utils.data.Dataset构建自己的dataset,那么你一定要实现__getitem__()方法,而这个方法就是输入一个下标,返回一条数据。

  1. 重点部分dataloader 的一个参数 collate_fn

首先看一下说明:

collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.

也就是说在你从dataset 获取一个batch 数据后,你准备对这批数据怎么处理
我们基于batch 做pad_sequence其实就是通过这个参数指向的函数实现,所以我们实现一个这样的函数

def collate_func(batch):
    # print(batch)
    x_list = [dic['x'] for dic in batch]
    y_list = [dic['y'] for dic in batch]
    x_padded = pad_sequence(x_list, batch_first = True, padding_value = 0 )
    y_padded = pad_sequence(y_list, batch_first = True, padding_value = structrue_index_dict['X'] )
    # return tuple(x_padded, y_padded)
    return {'x':x_padded, 'y':y_padded}

就是刚才说的,输入的是一个batch ,batch是通过list 组织的,list 的基本单位dataloader 不在乎,只要你自己能够知道怎么取出来用就行了。
所以我返回的时候是一个字典,字典的key ‘x’ 是input 部分,key ‘y’ 对应的是 label。
也就是说,我这一个batch 组成单元是dict,但是dict的基本单元是torch的数组
我处理成一个字典,字典里面两个key,每个key是一个torch数组。
但是x_list 不等长,此处对这一个batch做pad_sequence,目的达到。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 pad_sequence 的相关文章

随机推荐

  • 系统架构设计师-数据库系统(1)

    目录 一 数据库模式 1 集中式数据库 2 分布式数据库 二 数据库设计过程 1 E R模型 2 概念结构设计 3 逻辑结构设计 三 关系代数 1 并交差 2 投影和选择 3 笛卡尔积 4 自然连接 一 数据库模式 1 集中式数据库 三级模
  • less命令打开两个或多个文件时切换文件的快捷键

    在使用 less 命令查看多个文件时 可以使用快捷键 n 和 p 来切换文件 输入 n 后 将切换到下一个文件 输入 p 后 将切换到上一个文件 如下图 less可以打开两个文件 使用快捷键就可以快速查看 而不用退出后再重新打开另一个文件了
  • 计算机专业PhD申请文书范文,美国统计学博士申请文书范文

    美国统计学博士申请文书范文推荐 美国博士申请文书个人陈述作用十分重要 本文为大家提供了一篇成功获取美国统计学博士申请的PS范文 希望大家可以从这一篇文章中得到一些有用的参考信息 I am applying for acceptance in
  • SQL 连接运算join

    连接运算是 8种关系运算 中的一种 五种JOIN方式 1 INNER JOIN or JOIN 2 OUTER JOIN 2 1LEFT OUTER JOIN or LEFT JOIN 2 2RIGHT OUTER JOIN or RIGH
  • 一图看懂 openpyxl 资料整理+笔记(大全)

    本文由 大侠 AhcaoZhu 原创 转载请声明 链接 https blog csdn net Ahcao2008 一图看懂 openpyxl 资料整理 笔记 大全 摘要 类结构图 一级模块目录 按字序 多级模块 按层级 模块级 doc 及
  • 【已更新代码图表】2023数学建模国赛E题python代码--黄河水沙监测数据分析

    E 题 黄河水沙监测数据分析 黄河是中华民族的母亲河 研究黄河水沙通量的变化规律对沿黄流域的环境治理 气候变 化和人民生活的影响 以及对优化黄河流域水资源分配 协调人地关系 调水调沙 防洪减灾 等方面都具有重要的理论指导意义 附件 1 给出
  • STM32---SPI

    SPI 1 SPI介绍 SPI主要应用在EEPROM FLASH 实时时钟 AD转换器 数字信号处理器 数字信号解码器 4条信号线 MISO 主设备输入 从设备输出引脚 主机从这条信号线读入数据 从机的数据由这条信号线输出到主机 即在这条线
  • 面试必问

    若有收获 请记得分享和转发哦 对于工作3年左右的Java程序员来说 在面试大厂的过程中 面试官可能不会太关注你做了多少个项目 你的CRUD水平如何 更多的是关注你对某项技术点的理解深度 所以说 工作3年左右的小伙伴一定要把自己的重心放到技术
  • HTML基础标签及其语义

    一 HTML语法规范 1 1 基本语法表述 标签通常都是成对出现 双标签 开始结束标签 br 单标签 1 2 标签关系 包含关系与并列关系 包含关系 父标签 子标签 并列关系 1 3 HTML基本结构标签 骨架标签 页面内容也是在这些基本标
  • 陷波器设计

    中心频率 f c H z f c rm Hz fc Hz 3dB陷波器带宽 f
  • emplace_back与push_back异同

    vector的emplace back与push back 文章目录 vector的emplace back与push back 前言 1 区别总览 2 push back 支持右值引用 不支持传入多个构造参数 总是会进行拷贝构造 3 em
  • C++学习笔记——随机数

    利用rand 函数生成随机数如何随机是根据随机数种子来生成 一个程序的随机数种子一般是固定的 所以是伪随机数 若想生成真随机数 则用电脑的时间来初始化这个随机数种子 include
  • LLM 赋能的研发效能:如何探索软件开发新工序?

    上周末 我们 我和我的同事谢保龙 在 QCon 广州 2023 上分享了一个 AI 结合研发效能的话题 探索软件开发新工序 LLM 赋能研发效能提升 我们分享了 Thoughtworks 在过去的两个月里对于 LLM 大语言模型 结合软件开
  • 高防CDN和加速CDN有什么区别?

    高防CDN和加速CDN有什么区别 随着互联网技术的不断发展 CDN Content Delivery Network 已经成为了网络加速和安全保障的重要手段 在CDN的领域中 高防CDN和加速CDN是两种不同的CDN服务 它们有不同的特点和
  • 重积分的计算与理解

    主要分为二重积分和三重积分 二重积分 二重积分的基本思想是变成两次积分 物理意义已知面密度f 算质量 即首先把y方向的每一根线段计算出质量 相当于把y的线捏起来了 然后算x 主要方法如下 计算 D f x
  • 数据结构之链表:单向链表、单向循环链表、双向链表及基本操作

    目录 一 链表 1 1 单向链表 1 1 1 单链表的操作 1 2 单向循环链表 1 3 双向链表 了解 二 链表与顺序表的对比 一 链表 链表 将元素存放在通过链接构造起来的一系列存储块中 在每一个节点 数据存储单元 里存放下一个节点的位
  • 2020美赛建模F题思路和理解

    2020 MCM ICM 美国大学生数学建模竞赛 MCM ICM F题 2020 ICM Weekend 2 Problem F The Place I Called Home 思路和理解 问题中心 设计模型研究海平面上升对相关国家的人口
  • 报错:Dependency annotations: {@org.springframework.beans.factory.annotation.Autowired(required=true)}

    这两天自己搭spingmvc 总是报错 找不到自动注册的bean Could not autowire field private lf service UserService lf controllers UserController u
  • java实现队列_java实现队列

    队列的定义 队列的特点是节点的排队次序和出队次序按入队时间先后确定 即先入队者先出队 后入队者后出队 即我们常说的FIFO first in first out 先进先出 顺序队列定义及相关操作 顺序存储结构存储的队列称为顺序队列 内部使用
  • 使用 pad_sequence

    pad sequence 是用来干嘛的 首先 pad sequence 是用来对对tensor做padding 的 先看官方示例 文档地址https pytorch org docs stable generated torch nn ut