如何使用HuggingFace训练Transformer

2023-11-08

HuggingFace Transformers

使用BERT和其他各类Transformer模型,绕不开HuggingFace提供的Transformers生态。HuggingFace提供了各类BERT的API(transformers库)、训练好的模型(HuggingFace Hub)还有数据集(datasets)。最初,HuggingFace用PyTorch实现了BERT,并提供了预训练的模型,后来。越来越多的人直接使用HuggingFace提供好的模型进行微调,将自己的模型共享到HuggingFace社区。HuggingFace的社区越来越庞大,不仅覆盖了PyTorch版,还提供TensorFlow版,主流的预训练模型都会提交到HuggingFace社区,供其他人使用。

使用transformers库进行微调,主要包括:

  • Tokenizer:使用提供好的Tokenizer对原始文本处理,得到Token序列;
  • 构建模型:在提供好的模型结构上,增加下游任务所需预测接口,构建所需模型;
  • 微调:将Token序列送入构建的模型,进行训练。

Tokenizer

下面两行代码会创建 BertTokenizer,并将所需的词表加载进来。首次使用这个模型时,transformers 会帮我们将模型从HuggingFace Hub下载到本地。

from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

用得到的tokenizer进行分词:

encoded_input = tokenizer("我是一句话")
print(encoded_input)

输出

{'input_ids': [101, 2769, 3221, 671, 1368, 6413, 102], 
'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 
'attention_mask': [1, 1, 1, 1, 1, 1, 1]}

得到的一个Python dict。其中,input_ids最容易理解,它表示的是句子中的每个Token在词表中的索引数字。词表(Vocabulary)是一个Token到索引数字的映射。可以使用decode()方法,将索引数字转换为Token。

 tokenizer.decode(encoded_input["input_ids"])

输出

'[CLS] 我 是 一 句 话 [SEP]'

可以看到,BertTokenizer在给原始文本处理时,自动给文本加上了[CLS]和[SEP]这两个符号,分别对应在词表中的索引数字为101和102。decode()之后,也将这两个符号反向解析出来了。

token_type_ids主要用于句子对,比如下面的例子,两个句子通过[SEP]分割,0表示Token对应的input_ids属于第一个句子,1表示Token对应的input_ids属于第二个句子。不是所有的模型和场景都用得上token_type_ids。

encoded_input = tokenizer("您贵姓?", "免贵姓李")
print(encoded_input)
{'input_ids': [101, 2644, 6586, 1998, 136, 102, 1048, 6586, 1998, 3330, 102], 
'token_type_ids': [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1], 
'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
text = '[CLS]'
print('Original Text : ', text)
print('Tokenized Text: ', tokenizer.tokenize(text))
print('Token IDs     : ', tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)))
print('\n')

text = '[SEP]'
print('Original Text : ', text)
print('Tokenized Text: ', tokenizer.tokenize(text))
print('Token IDs     : ', tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)))
print('\n')

text = '[PAD]'
print('Original Text : ', text)
print('Tokenized Text: ', tokenizer.tokenize(text))
print('Token IDs     : ', tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)))
Original Text :  [CLS]
Tokenized Text:  ['[CLS]']
Token IDs     :  [101]

Original Text :  [SEP]
Tokenized Text:  ['[SEP]']
Token IDs     :  [102]

Original Text :  [PAD]
Tokenized Text:  ['[PAD]']
Token IDs     :  [0]

句子通常是变长的,多个句子组成一个Batch时,attention_mask就起了至关重要的作用。

batch_sentences = ["我是一句话", "我是另一句话", "我是最后一句话"]
batch = tokenizer(batch_sentences, padding=True, return_tensors="pt")
print(batch) 
{'input_ids': 
 tensor([[ 101, 2769, 3221,  671, 1368, 6413,  102,    0,    0],
        [ 101, 2769, 3221, 1369,  671, 1368, 6413,  102,    0],
        [ 101, 2769, 3221, 3297, 1400,  671, 1368, 6413,  102]]), 
 'token_type_ids': 
 tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0]]), 
 'attention_mask': 
 tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1]])}

对于这种batch_size = 3的场景,不同句子的长度是不同的,padding=True表示短句子的结尾会被填充[PAD]符号,return_tensors="pt"表示返回PyTorch格式的Tensor。attention_mask告诉模型,哪些Token需要被模型关注而加入到模型训练中,哪些Token是被填充进去的无意义的符号,模型无需关注。

Model

下面两行代码会创建BertModel,并将所需的模型参数加载进来。

from transformers import BertModel
model = BertModel.from_pretrained("bert-base-chinese")

BertModel是一个PyTorch中用来包裹网络结构的torch.nn.Module,BertModel里有forward()方法,forward()方法中实现了将Token转化为词向量,再将词向量进行多层的Transformer Encoder的复杂变换。
forward()方法的入参有input_ids、attention_mask、token_type_ids等等,这些参数基本上是刚才Tokenizer部分的输出。

bert_output = model(input_ids=batch['input_ids'])

forward()方法返回模型预测的结果,返回结果是一个tuple(torch.FloatTensor),即多个Tensor组成的tuple。tuple默认返回两个重要的Tensor:

len(bert_output)
2
  • last_hidden_state:输出序列每个位置的语义向量,形状为:(batch_size, sequence_length, hidden_size)。
  • pooler_output:[CLS]符号对应的语义向量,经过了全连接层和tanh激活;该向量可用于下游分类任务。

下游任务

BERT可以进行很多下游任务,transformers库中实现了一些下游任务,我们也可以参考transformers中的实现,来做自己想做的任务。比如单文本分类,transformers库提供了BertForSequenceClassification类。

class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)
        classifier_dropout = ...
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        ...
        
    def forward(
        ...
    ):
        ...

        outputs = self.bert(...)
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        ...

在这段代码中,BertForSequenceClassification在BertModel基础上,增加了nn.Dropout和nn.Linear层,在预测时,将BertModel的输出放入nn.Linear,完成一个分类任务。除了BertForSequenceClassification,还有BertForQuestionAnswering用于问答,BertForTokenClassification用于序列标注,比如命名实体识别。

transformers 中的各个API还有很多其他参数设置,比如得到每一层Transformer Encoder的输出等等,可以访问他们的文档查看使用方法。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何使用HuggingFace训练Transformer 的相关文章

随机推荐

  • 关于Element-ui中Table表格无法显示的问题及解决

    Element ui中Table表格无法显示 1 准备工作 2 引用Element ui官方文档中的Table表格代码 3 启动端口 并在浏览器访问 Element ui表格不生效问题 原因是 Element ui中Table表格无法显示
  • linux设备模型之bus,device,driver分析一

    本文系本站原创 欢迎转载 转载请注明出处 http www cnblogs com gdt a20 内核的开发者将总线 设备 驱动这三者用软件思想抽象了出来 巧妙的建立了其间的关系 使之更形象化 结合前面所学的知识 总的来说其三者间的关系为
  • vscode删除缩进多行tab

    shift tab 转载于 https www cnblogs com v5captain p 9160398 html
  • chrome扩展开发(2)- manifest.json文件简述

    一 本文目标 结合具体应用场景 让读者对manifest json文件的写法和主要属性拥有初步认识 二 目标读者 chrome扩展开发的初学者 想要先从宏观上了解一下chrome扩展能干哪些事情 而不是急于写出一个能运行的demo的人 三
  • 密码学 ~ 数字签名

    概念 数字签名 又称公钥数字签名 是只有信息的发送者才能产生的别人无法伪造的一段数字串 这段数字串同时也是对信息的发送者发送信息真实性的一个有效证明 它是一种类似写在纸上的普通的物理签名 但是使用了公钥加密领域的技术来实现的 用于鉴别数字信
  • PyCharm调试代码的时候出现pydev debugger: process xxxx is connecting

    最近在初学python的时候 在使用PyCharm调试代码的时候出现 pydev debugger process xxxx is connecting 和 Process finished with exit code 0 从从其返回的状
  • matlab神经网络训练图解释,matlab实现神经网络算法

    matlab 神经网络 net newff pr 3 2 logsig logsig 创建一个bp神经网络 10 显示训练迭代过程 0 05 学习速率0 05 1e 10 训练精度net trainParam epochs 50000 最大
  • 自定义单选框和多选框

    说明 作为一个Java后端程序员 有时候也需要自己去写些前端代码 所以将工作中用到的一些小知识做记录分享 1 自定义单选框 有图片 先看效果图 再献上完整代码
  • STM32F103ZET6【HAL函开发】STM32CUBEMX------3.USART串口进行数据的接收的发送

    目的 1 开机后 向串口1发送 hello world 2 串口1收到字节指令 0xA1 打开LED1 发送 LED1 Open 3 串口1收到字节指令 0xA2 关闭LED1 发送 LED1 Closed 4 在串口发送过程中 打开LED
  • python3 pip ipython 安装

    1 安装Python3 6 安装准备 mkdir usr local python3 wget no check certificate https www python org ftp python 3 6 0 Python 3 6 0
  • readis windows servrer 搭建与Java客户端的连接

    1 首先下载redis redis 2 0 2 zip 32 bit 解压 从下面地址下 http code google com p servicestack wiki RedisWindowsDownload 看到下面有redis 2
  • 《SegFormer:Simple and Efficient Design for Semantic Segmentation with Transformers》论文笔记

    参考代码 SegFormer 1 概述 介绍 这篇文章提出的分割方法是基于transformer结构构建的 不过这里使用到的transformer是针对分割任务在patch merge self attention和FFN进行了改进 使其更
  • 计算机内存取证之BitLocker恢复密钥提取还原

    BitLocker是微软Windows自带的用于加密磁盘分卷的技术 通常 解开后的加密卷通过Windows自带的命令工具 manage bde 可以查看其恢复密钥串 如下图所示 如图 这里的数字密码下面的一长串字符串即是下面要提取恢复密钥
  • 华为服务器如何设置网站dns,设置为正确的DNS 服务器地址

    设置为正确的DNS 服务器地址 内容精选 换一换 域名的DNS服务器定义了域名用于解析的权威DNS服务器 通过华为云注册成功的域名默认使用华为云DNS进行解析 详细内容 请参见华为云DNS对用户提供域名服务的DNS是什么 若您选择非华为云D
  • UE4缓存路径修改

    最简单的办法就是通过Evenyting搜索 Engine Config BaseEngine ini 找到你要修改的引擎对应文件 将 ENGINEVERSIONAGNOSTICUSERDIR DerivedDataCache修改为 GAME
  • ORA-00257:archiver error解决办法*

    ORA 00257 archiver error解决办法 出现ORA 00257错误 空间不足错误 通过查找资料 绝大部分说这是由于归档日志太多 占用了全部的硬盘剩余空间导致的 通过简单删除日志或加大存储空间就能够解决 一 更改归档模式 目
  • 为什么C++有多种整型?

    C 中有多种整型是为了满足不同的需求 提供更灵活和高效的整数表示方式 不同的整型具有不同的字节大小 范围和精度 可以根据应用的需求选择合适的整型类型 以下是一些原因解释为什么C 有多种整型 内存和性能优化 不同的整型在内存中占用的空间不同
  • 【读点论文】YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors新集合体

    YOLOv7 Trainable bag of freebies sets new state of the art for real time object detectors Abstract YOLOv7在5 FPS到160 FPS的
  • 实时监控MySQL慢查询

    背景 为了优化SQL 我们首先需要发现有问题的SQL语句 网上诸多教程都在教你使用诸如mysqldumpslow pt query digest这类工具分析MySQL慢查询日志 然而这一系列的工具都存在一个致命的缺陷 无法实时监控 而说起实
  • 如何使用HuggingFace训练Transformer

    文章目录 HuggingFace Transformers Tokenizer Model 下游任务 HuggingFace Transformers 使用BERT和其他各类Transformer模型 绕不开HuggingFace提供的Tr