[NLP] transformers 使用指南

2023-11-10

严格意义上讲 transformers 并不是 PyTorch 的一部分,然而 transformers 与 PyTorch 或 TensorFlow 结合的太紧密了,而且可以把 transformers 看成是 PyTorch 或 TensorFlow 的延伸,所以也在这里一并讨论了。

transformers 内置了 17 种以 transformer 结构为基础的神经网络:

  • T5 model
  • DistilBERT model
  • ALBERT model
  • CamemBERT model
  • XLM-RoBERTa model
  • Longformer model
  • RoBERTa model
  • Reformer model
  • Bert model
  • OpenAI GPT model
  • OpenAI GPT-2 model
  • Transformer-XL model
  • XLNet model
  • XLM model
  • CTRL model
  • Flaubert model
  • ELECTRA model

这些模型的参数、用法大同小异。默认框架为 PyTorch,使用 TensorFlow 框架在类的前面加上 'TF" 即可。

每种模型都有至少一个预训练模型,限于篇幅,这里仅仅列举 Bert 的常用预训练模型:

模型 模型细节
bert-base-uncased 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on lower-cased English text.
bert-large-uncased 24-layer, 1024-hidden, 16-heads, 340M parameters. Trained on lower-cased English text.
bert-base-cased 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on cased English text.
bert-large-cased 24-layer, 1024-hidden, 16-heads, 340M parameters. Trained on cased English text.
bert-base-multilingual-cased 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on cased text in the top 104 languages with the largest Wikipedias
bert-base-chinese 12-layer, 768-hidden, 12-heads, 110M parameters. Trained on cased Chinese Simplified and Traditional text.

完整的预训练模型列表可以在 transformers 官网上找到。

使用 transformers 库有三种方法:

  1. 使用 pipeline
  2. 指定预训练模型;
  3. 使用 AutoModels 加载预训练模型。

1. transformers.pipeline

这个管线函数包含三个部分:

  1. Tokenizer;
  2. 一个模型实例;
  3. 其它增强模型输出的功能。

它只有一个必需参数 task,接受如下变量之一:

  • ”feature-extraction”
  • ”sentiment-analysis”
  • ”ner”
  • ”question-answering”
  • ”fill-mask”
  • ”summarization”
  • ”translation_xx_to_yy”
  • ”text-generation”

这个函数还有其它可选参数,但是我的试用经验是,什么都不要动,使用默认参数即可。

例子:

>>> from transformers import pipeline

>>> nlp = pipeline("sentiment-analysis")

>>> print(nlp("I hate you"))
[{'label': 'NEGATIVE', 'score': 0.9991129040718079}]

>>> print(nlp("I love you"))
[{'label': 'POSITIVE', 'score': 0.9998656511306763}]

2. 指定预训练模型

这里我们以 Bert 为例。

2.1 配置 Bert 模型(可选,推荐不使用)transformers.BertConfig

transformers.BertConfig 可以自定义 Bert 模型的结构,以下参数都是可选的:

  • vocab_size:词汇数,默认 30522;
  • hidden_size:编码器内隐藏层神经元数量,默认 768;
  • num_hidden_layers:编码器内隐藏层层数,默认 12;
  • num_attention_heads:编码器内注意力头数,默认 12;
  • intermediate_size:编码器内全连接层的输入维度,默认 3072;
  • hidden_act:编码器内激活函数,默认 ‘gelu’,还可为 ‘relu’、‘swish’ 或 ‘gelu_new’
  • hidden_dropout_prob:词嵌入层或编码器的 dropout,默认为 0.1;
  • attention_probs_dropout_prob:注意力的 dropout,默认为 0.1;
  • max_position_embeddings:模型使用的最大序列长度,默认为 512;
  • type_vocab_size:词汇表类别,默认为 2;
  • initializer_range:神经元权重的标准差,默认为 0.02;
  • layer_norm_eps:layer normalization 的 epsilon 值,默认为 1e-12.

使用方法:

configuration = BertConfig() # 进行模型的配置,变量为空即使用默认参数

model = BertModel(configuration) # 使用自定义配置实例化 Bert 模型

configuration = model.config # 查看模型参数

2.2 分词 transformers.BertTokenizer

所有的 tokenizer 都继承自 transformers.PreTrainedTokenizer 基类,因此有共同的参数和方法实例化的参数有:

  • model_max_length:可选参数,最大输入长度,默认为 1e30;
  • padding_side:可选参数,填充的方向,应为 ‘left’ 或 ‘right’;
  • bos_token:可选参数,每句话的起始标记,默认为 ‘’;
  • eos_token:可选参数,每句话的结束标记,默认为 ‘’;
  • unk_token:可选参数,未知的标记,默认为 ‘’;
  • sep_token:可选参数,分隔标记,默认为 ‘’;
  • pad_token:可选参数,填充标记,默认为 ‘’;
  • cls_token:可选参数,分类标记,默认为 ‘’;
  • mask_token:可选参数,遮盖标记,默认为 ‘<MASK’。

为了演示,我们先实例化一个 BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

常用的方法有:

  • from_pretrained(model):载入预训练词汇表;
  • tokenizer.tokenize(str):分词;
>>> tokenizer.tokenize('Hello word!')
['Hello', 'word', '!']
  • encode(text, ...):将文本分词后编码为包含对应 id 的列表;
>>> tokenizer.encode('Hello word!')
[101, 8667, 1937, 106, 102]
  • encode_plus(text, ...):将文本分词后创建一个包含对应 id,token 类型及是否遮盖的词典;
tokenizer.encode_plus('Hello world!')
{'input_ids': [101, 8667, 1937, 106, 102], 'token_type_ids': [0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1]}
  • convert_ids_to_tokens(ids, skip_special_tokens):将 id 映射为 token;
>>> tokenizer.convert_ids_to_tokens(tokens)
['[CLS]', 'Hello', 'word', '!', '[SEP]']
  • decode(token_ids):将 id 解码;
>>> tokenizer.decode(tokens)
'[CLS] Hello word! [SEP]'
  • convert_tokens_to_ids(tokens):将 token 映射为 id。
>>> tokenizer.convert_tokens_to_ids(['[CLS]', 'Hello', 'word', '!', '[SEP]'])
[101, 8667, 1937, 106, 102]

2.3 使用预训练模型

根据任务的需要,既可以选择没有为指定任务 finetune 的模型如 transformers.BertModel,也可以选择为指定任务 finetune 之后的模型如 transformers.BertForSequenceClassification。一共有 6 个指定的任务类型:

  • transformers.BertForMaskedLM:语言模型;
  • transformers.BertForNextSentencePrediction:判断下一句话是否与上一句有关;
  • transformers.BertForSequenceClassification:序列分类如 GLUE;
  • transformers.BertForMultipleChoice:文本分类;
  • transformers.BertForTokenClassification:token 分类如 NER,
  • transformers.BertForQuestionAnswering;问答。

3. 使用 AutoModels

使用 AutoModels 与上面的指定模型进行预训练大同小异,只不过是另一种方式加载模型而已。

3.1 加载自动配置 transformers.AutoConfig

使用类方法 from_pretrained 加载模型配置,参数既可以为模型名称,也可以为具体文件。

config = AutoConfig.from_pretrained('bert-base-uncased')
# 或者直接加载模型文件
config = AutoConfig.from_pretrained('./test/bert_saved_model/')

3.2 加载分词器 transformers.AutoTokenizer

与上面的 BertTokenizer 非常相似,也是使用 from_pretrained 类方法加载预训练模型。

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
# 或者直接加载模型文件
tokenizer = AutoTokenizer.from_pretrained('./test/bert_saved_model/')

3.3 加载模型 transformers.AutoModel

可以使用 from_pretrained 加载预训练模型:

model = AutoModel.from_pretrained('bert-base-uncased')
# 或者直接加载模型文件
model = AutoModel.from_pretrained('./test/bert_model/') 

选好了预训练模型以后,只需要给模型接一个全连接层,这个神经网络就搭好了(当然可以根据需要添加更复杂的结构)。是不是香?

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

[NLP] transformers 使用指南 的相关文章

  • 如何有效计算文档流中文档之间的相似度

    我收集文本文档 在 Node js 中 其中一个文档i表示为单词列表 考虑到新文档以文档流的形式出现 计算这些文档之间相似性的有效方法是什么 我目前对每个文档中单词的归一化频率使用余弦相似度 我不使用 TF IDF 词频 逆文档频率 因为我
  • pytorch 中的 keras.layers.Masking 相当于什么?

    我有时间序列序列 我需要通过将零填充到矩阵中并在 keras 中使用 keras layers Masking 来将序列的长度固定为一个数字 我可以忽略这些填充的零以进行进一步的计算 我想知道它怎么可能在 Pytorch 中完成 要么我需要
  • 如何计算两个文本文档之间的相似度?

    我正在考虑使用任何编程语言 尽管我更喜欢 Python 来从事 NLP 项目 我想获取两个文档并确定它们的相似程度 常见的方法是将文档转换为 TF IDF 向量 然后计算它们之间的余弦相似度 任何有关信息检索 IR 的教科书都涵盖了这一点
  • target_vocab_size 在方法 tfds.features.text.SubwordTextEncoder.build_from_corpus 中到底意味着什么?

    根据这个链接 https www tensorflow org datasets api docs python tfds features text SubwordTextEncoder build from corpus target
  • Spacy 中的自定义句子分割

    I want spaCy使用我提供的句子分割边界而不是它自己的处理 例如 get sentences Bob meets Alice SentBoundary They play together gt Bob meets Alice Th
  • 如何计算 CNN 第一个线性层的维度

    目前 我正在使用 CNN 其中附加了一个完全连接的层 并且我正在使用尺寸为 32x32 的 3 通道图像 我想知道是否有一个一致的公式可以用来计算第一个线性层的输入尺寸和最后一个卷积 最大池层的输入 我希望能够计算第一个线性层的尺寸 仅给出
  • Pytorch CUDA 错误:没有内核映像可用于在带有 cuda 11.1 的 RTX 3090 设备上执行

    如果我运行以下命令 import torch import sys print A sys version print B torch version print C torch cuda is available print D torc
  • torch.stack() 和 torch.cat() 函数有什么区别?

    OpenAI 的强化学习 REINFORCE 和 actor critic 示例具有以下代码 加强 https github com pytorch examples blob master reinforcement learning r
  • Node2vec 的工作原理

    我一直在读关于node2vec https cs stanford edu jure pubs node2vec kdd16 pdf嵌入算法 我有点困惑它是如何工作的 作为参考 node2vec 由 p 和 q 参数化 并通过模拟来自节点的
  • 如何使用Python计算多类分割任务的dice系数?

    我想知道如何计算多类分割的骰子系数 这是计算二元分割任务的骰子系数的脚本 如何循环每个类并计算每个类的骰子 先感谢您 import numpy def dice coeff im1 im2 empty score 1 0 im1 numpy
  • ANEW 字典可以用于 Quanteda 中的情感分析吗?

    我正在尝试找到一种方法来实施英语单词情感规范 荷兰语 以便使用 Quanteda 进行纵向情感分析 我最终想要的是每年的 平均情绪 以显示任何纵向趋势 在数据集中 所有单词均由 64 名编码员按照 7 分李克特量表在四个类别上进行评分 这提
  • 将复数名词转换为单数名词

    如何使用 R 将复数名词转换为单数名词 我使用 tagPOS 函数来标记每个文本 然后提取所有标记为 NNS 的复数名词 但是如果我想将这些复数名词转换为单数该怎么办 library openNLP library tm acq o lt
  • BERT 输出不确定

    BERT 输出是不确定的 当我输入相同的输入时 我希望输出值是确定性的 但我的 bert 模型的值正在变化 听起来很尴尬 同一个值返回两次 一次 也就是说 一旦出现另一个值 就会出现相同的值并重复 如何使输出具有确定性 让我展示我的代码片段
  • Pytorch GPU 使用率低

    我正在尝试 pytorch 的例子https pytorch org tutorials beginner blitz cifar10 tutorial html https pytorch org tutorials beginner b
  • 是否可以使用 Google BERT 来计算两个文本文档之间的相似度?

    是否可以使用 Google BERT 来计算两个文本文档之间的相似度 据我了解 BERT 的输入应该是有限大小的句子 一些作品使用 BERT 来计算句子的相似度 例如 https github com AndriyMulyar semant
  • 如何从已安装的云端硬盘文件夹中永久删除?

    我编写了一个脚本 在每次迭代后将我的模型和训练示例上传到 Google Drive 以防发生崩溃或任何阻止笔记本运行的情况 如下所示 drive path drive My Drive Colab Notebooks models if p
  • PyTorch 中的连接张量

    我有一个张量叫做data形状的 128 4 150 150 其中 128 是批量大小 4 是通道数 最后 2 个维度是高度和宽度 我有另一个张量叫做fake形状的 128 1 150 150 我想放弃最后一个list array从第 2 维
  • 样本()和r样本()有什么区别?

    当我从 PyTorch 中的发行版中采样时 两者sample and rsample似乎给出了类似的结果 import torch seaborn as sns x torch distributions Normal torch tens
  • 旧版本的 spaCy 在尝试安装模型时抛出“KeyError: 'package'”错误

    我在 Ubuntu 14 04 4 LTS x64 上使用 spaCy 1 6 0 和 python3 5 为了安装 spaCy 的英文版本 我尝试运行 这给了我错误消息 ubun ner 3 NeuroNER master src pyt
  • 使用 NLP 进行地址分割

    我目前正在开发一个项目 该项目应识别地址的每个部分 例如来自 str Jack London 121 Corvallis ARAD ap 1603 973130 输出应如下所示 street name Jack London no 121

随机推荐

  • 排序算法——基数排序(C语言)

    基数排序的概念 什么是基数排序 基数排序是一种和快排 归并 希尔等等不一样的排序 它不需要比较和移动就可以完成整型的排序 它是时间复杂度是O K N 空间复杂度是O K M 基数排序的思想 基数排序是一种借助多关键字的思想对单逻辑关键字进行
  • python爬虫从零开始_python爬虫---从零开始(一)初识爬虫

    我们开始来谈谈python的爬虫 1 什么是爬虫 网络爬虫是一种按照一定的规则 自动地抓取万维网信息的程序或者脚本 另外一些不常使用的名字还有蚂蚁 自动索引 模拟程序或者蠕虫 互联网犹如一个大蜘蛛网 我们的爬虫就犹如一个蜘蛛 当在互联网遇到
  • 计算机网络mask是什么意思,mask是什么意思

    你知道mask是什么意思吗 可能你在网络上偶尔会看到这样的词 但网络上的新词多到数不清 根本没有时间去仔细去了解 下面就让我们带你一起 来详细了解一下mask是什么意思吧 mask是什么意思 假面具 伪装 遮蔽物 All guests wo
  • ppt拖动就复制_PPT快捷键丨这些快捷键可助你事半功倍

    工欲善其事 必先利其器 如果你常用的快捷键只有Ctrl C Ctrl V 那你要仔细看下这篇文章了 PS 这个键盘是PPT做的哦 后台回复 键盘 获取源文件 快捷键 顾名思义就是快和方便 所以能熟练使用PPT快捷键 会使我们变得更高效 桔子
  • Shiro和Spring Security对比

    一 Shiro简介 1 什么是Shiro Shiro是apache旗下一个开源框架 它将软件系统的安全认证相关的功能抽取出来 实现用户身份 认证 权限授权 加密 会话管理等功能 组成了一个通用的安全认证框架 2 Shiro 的特点 Shir
  • VMware虚拟机连不上网络,最详细排查解决方案

    虚拟机连不上网 ping某个网站时并显示此信息 ping www baidu com Name or service not known 步骤一 排查Windows自身问题 有可能这个问题不是你虚拟机有问题 而是装虚拟机的Windows本身
  • 【数据结构】数组和字符串

    本文是对leetbook 数组和字符串 学习完成后的总结 数组和字符串 数组简介 寻找数组的中心索引 搜索插入位置 合并区间 二维数组简介 旋转矩阵 零矩阵 对角线遍历 字符串简介 最长公共前缀 最长回文子串 翻转字符串里的单词 实现 st
  • 前端开发同步和异步的区别?

    在前端开发中 同步 一般指的是在代码运行的过程中 从上到下逐步运行代码 每一部分代码运行完成之后 下面的代码才能开始运行 异步 指的是当我们需要一些代码在执行的时候不会影响其他代码的执行 也就是在执行代码的同时 可以进行其他的代码的执行 不
  • 转:安装MySQL遇到MySQL Server Instance Configuration Wizard未响应的解决办法

    问题 安装了MySQL之后进入配置界面的时候 总会显示 MySQL Server Instance Configuration Wizard未响应 一直卡死 解决办法 Win7系统中 以管理员的权限登录系统 将C盘的ProgramData中
  • postman接口测试要点及错误总结

    本文主要针对接口测试工具postman出现的常见错误及解决办法进行了总结 请求分类及具体传参介绍 GET请求 GET请求是最常见的请求类型 最常用于向服务器查询信息 必要时 可以将查询字符串参数追加到URL的末尾 以便将信息发送给服务器 P
  • 机器学习的特征工程

    机器学习的特征工程 一 数据集 Kaggle网址 https www kaggle com datasets UCI数据集网址 http archive ics uci edu ml scikit learn网址 http scikit l
  • 蓝桥杯-基础训练-龟兔赛跑预测

    问题描述 话说这个世界上有各种各样的兔子和乌龟 但是研究发现 所有的兔子和乌龟都有一个共同的特点 喜欢赛跑 于是世界上各个角落都不断在发生着乌龟和兔子的比赛 小华对此很感兴趣 于是决定研究不同兔子和乌龟的赛跑 他发现 兔子虽然跑比乌龟快 但
  • Bert的MLM任务loss原理

    bert预训练有MLM和NSP两个任务 其中MLM是类似于 完形填空 的方式 对一个句子里的15 的词进行mask 通过双向transformer feedforward rediual add layer norm完成对每个词的embed
  • CMake支持C++11、14、17

    有个需求是使用C 14会没有C 17支持的std filesystem 使用C 17会有砍掉的std random shuffles的报错 这是因为我在cmake指定C 版本 set CMAKE CXX STANDARD 17 强制使用17
  • 用 Go 语言与 EOS.IO 交互的 API 库

    用 Go 语言与 EOS IO 交互的 API 库 该库提供对数据架构 二进制打包和JSON接口 的简单访问 以及对远程或本地运行的EOS IO RPC服务器的API调用 它提供钱包功能 KeyBag 或者可以通过 keosd 钱包签署交易
  • EasyExcel填充数据EasyExcel填充数据流下载 easyexcel填充excel下载 easyexcel填充

    EasyExcel填充数据EasyExcel填充数据流下载 easyexcel填充excel下载 easyexcel填充 1 填充数据然后将文件输出给浏览器 1 填充数据然后将文件输出给浏览器 官网地址 官网的demo填充生成的是file文
  • python爬虫、某云音乐直链爬取

    1 通过浏览器抓包分析 寻找音乐直链所在的api F12打开开发者工具 然后随便播放一首 在Network的XHR中寻找歌曲的直链 最终发现在v1 csrf token 中返回了歌曲的地址 将链接在浏览器中打开 发现果然是该音乐的下载地址
  • 苹果MDM原理和实现过程

    最近一段时间鼓捣了苹果MDM MDM 顾名思义就是移动管理 现在这里咱就不谈啥是移动设备管理了 直接进入正题 苹果的MDM主要是通过苹果MDM服务器实现整体流程如下 1 首先客户端需要从后台服务器 服务器自己部署 下载苹果配置文件或者说描述
  • SpringBoot整合Shiro实现登录和注册功能

    首先 让我们介绍一下Shiro Shiro是一个非常流行的Java安全框架 它提供了身份验证 授权 加密和会话管理等安全功能 Shiro的一个重要特点是它的易用性和灵活性 它可以与各种Java框架 如Spring Spring Boot S
  • [NLP] transformers 使用指南

    严格意义上讲 transformers 并不是 PyTorch 的一部分 然而 transformers 与 PyTorch 或 TensorFlow 结合的太紧密了 而且可以把 transformers 看成是 PyTorch 或 Ten