transformers库的使用【二】tokenizer的使用,模型的保存自定义

2023-11-02

使用标记器(tokenizer)

在之前提到过,标记器(tokenizer)是用来对文本进行预处理的一个工具。

首先,标记器会把输入的文档进行分割,将一个句子分成单个的word(或者词语的一部分,或者是标点符号)

这些进行分割以后的到的单个的word被称为tokens。

第二步,标记器会把这些得到的单个的词tokens转换成为数字,经过转换成数字之后,我们就可以把它们送入到模型当中。

为了实现这种能把tokens转换成数字的功能,标记器拥有一个词表,这个词汇表是在我们进行实例化并指明模型的时候下载的,这个标记器使用的词汇表与模型在预训练时使用的词汇表相同。

举个例子说:

from transformers import AutoTokenizer,AutoModelForSequenceClassification

Model_name = 'distillery-base-uncashed-finetuned-still-2-english'

model=AutoModelForSequenceClassification.from_pretrained(model_name)

tokenizer=AutoTokenizer.from_pretrained(model_name)

sentence="We are very happy to show you the Transformers library"

inputs = tokenizer(sentence)

然后打印一下得到的结果:

print(inputs)

{'input_ids': [101, 2057, 2024, 2200, 3407, 2000, 2265, 2017, 1996, 100, 19081, 3075, 1012, 102],

'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

可以看到,返回值是一个字典,这个字典里面有两个键值对,第一个键值对'input_ids'是对输入的句子转换成数字以后的结果,并且长度为这个句子的单词的个数。

第二个'attention_mask'这里面全部都是1,表示让模型关注里面所有的词,具体相关的应用后面会再提到。

上面的例子是拿一个句子放入标记器中得到的结果,如果希望一次放入一批(batch)语句,希望将这一批句子都转换成为数字送到模型里面去,那么你可以这么做

sentences=["We are very happy to show you the Transformers library",

"We hope you don't hate it"]


 

Pt_batch = tokenizer(

Sentences,

padding=True,

truncation=True,

max_length=512,

return_tensors="Pt"

)

首先padding属性是用来指明是否启用填补。他会自动补全结果中的input_ids以及attention_mask右边缺失的值。

打印一下结果来看一下:

for key,value in pt_batch.items():

print(f"{key}:{value.numpy().tolist()}")

input_ids: [[101, 2057, 2024, 2200, 3407, 2000, 2265, 2017, 1996, 100, 19081, 3075, 1012, 102], [101, 2057, 3246, 2017, 2123, 1005, 1056, 5223, 2009, 1012, 102, 0, 0, 0]]
attention_mask: [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]]

可以看到结果中第二个句子的最右边补充了一些0,这是因为使用了padding属性,第二个句子没有第一个句子长,而我们希望得到的结果都是一样长的,所以会自动的在结尾补充0,并且在attention_mask字段里面也补充了0。

使用模型

当我们对输入的数据使用标记器进行处理之后,可以直接把它送到模型当中,这些数据会包含所有模型需要的相关信息。

在使用pytorch的时候,你需要可以用下面的方法对字典类型进行解包:

Pt_outputs = pt_model(**pt_batch)

在Transformers中,所有的输出都是一个元组(tuple)

Print(pt_ourputs)

(tensor([[-4.0833,  4.3364],
        [ 0.0818, -0.0418]], grad_fn=<AddmmBackward>),)

可以看到得到的结果

接下来使用SoftMax激活函数进行预测,并打印一下最终的结果

Import torch.nn.functional as F

pt_predictions = F.softmax(py_output[0],dim=-1)

print(pt_predictions)
tensor([[2.2043e-04, 9.9978e-01],
        [5.3086e-01, 4.6914e-01]], grad_fn=<SoftmaxBackward>)

这里输出的只是经过了softmax函数后得到的结果,那么如果有标签的时候,需要在使用模型的时候,在label字段指明标签

import torch

pt_output = pt_model(**pt_batch,labels = torch.tensor([1,0]))

在Transformers提供了一个Trainer类来帮助训练

模型的保存

在模型进行微调之后,可以对模型以及标记器进行保存操作

save_directory='E:/my model/'

tokenizer.save_pretrained(save_directory)

model.save_pretrained(save_directory)

这样就可以将模型进行保存

模型的加载

如果想要重新加载之前训练好并保存的模型,可以使用一个from_pretrained()方法,通过传入保存了模型的文件夹路径。

tokenizer = AutoTokenizer.from_pretrained(save_directory)

model = AutoModel.from_pretrained(save_directory)

如果希望读取TensorFlow模型,那么需要一点点改变

model=AutoModel.from_pretrained(save_directory,from_tf=True)

最终,如果在使用模型的时候,你希望得到的不仅仅是最终的输出,还希望能得到所有的隐藏层状态以及注意力权重,你可以这样做:

pt_outputs = pt_model(**pt_batch,output_hidden_states= True,output_attentions=True)

All_hidden_states ,all_attentions = pt_outputs[-2:]

访问代码

之前用到的AutoModel与AutoTokenizer两个类实际上可以和任何的预训练模型一起工作。

在之前的实例中,模型使用的是"distilbert-base-uncashed-finetuned-still-2-enghish",这意味着我们使用的是DistilBERT的结构。

在创建模型的时候用到的AutoModelForSequenceClassification会自动创建一个DistilBertForSequenceCLassification。

如果不使用自动的方式构建,我们可以使用下面的代码:

from transformers import DistilBertTokenizer,DistilBertForSequenceClassification

model_name = "distilbert-base-uncashed-fintuned-still-2-english"

model = DistilBertForSequenceClassification.from_pretrain(model_name)

tokenizer = DIstilBertTokenizer.from_pretrained(model_name)

自定义模型

如果希望改变的一些参数,来定义自己的特殊的类,那么可以使用模型特定的或者说相关的配置文件(configuration)比如说,在之前用熬的DistilBERT中,可以使用DistilBertConfig来设置隐藏层纬度,dropout rate等等。

具体来说:

from transformers import DIstilBertConfig,DIstilBertTokenizer,DistilBertForSequence

config = DistilBertTokenizer(n_heads=8,dim=512,hidden_dim=4*512)

tokenizer=DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

model = DistilBertForSequenceClassification(config)

如果你希望改变的只是模型的头,比如说标签的数量,那么你只需要直接改变模型创建时候的参数即可

from transformers import DIstilBertConfig,DistilBertTokenizer,DistilBertForSequenceClassification

model_name='distilbert-base-uncased'

model = DistilBertForSequenceClassification.from_pretrained(model_name,num_labels=10)

tokenizer = DistilBertTokenizer.from_pretrained(model_name)

 

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

transformers库的使用【二】tokenizer的使用,模型的保存自定义 的相关文章

  • VMware Workstation安装

    VMware Workstation安装 1 安装步骤 双击运行安装包程序 接受许可证协议 关键不接受不让安装啊 选择安装位置 建议非中文无空格 增强型键盘驱动程序可选 按照自身使用习惯勾选产品更新和客户体验提升计划 快捷方式 开始安装 稍
  • MD5加密

    1 md5是什么 md信息摘要算法 一种被广泛使用的密码散列函数 2 md5的特征 一 长度固定 任意长度的数据都会输出长度相等的md5值 二 不可逆 三 对原密码进行改动改变成一个字节输出数据 四 很少碰到两个不同的数据产生相同的md5值
  • 算法该不该刷?如何高效刷算法?

    一 算法该不该刷 最近有小伙伴向我咨询一个问题 就是算法该不该刷 该如何刷算法呢 这个问题可谓太大众化了 只要你去某乎 某度搜索一下相关的解答 会有无数种回答 可见这个问题困扰了多少学习计算机的同学们 但不管回答有多少种 总结一句话就是 算
  • 科大奥锐密立根油滴实验数据_密立根油滴实验数据表格

    静态法 平衡法 第1粒油滴数据 序数 U V t g s v g m s 1 q i C n i 个 e C 10 19 u e e 0 1 235 9 98 1 50E 04 1 12E 18 7 1 61 0 62 2 235 9 88
  • chatglm-6b模型在windows的详细安装教程

    1 先是看了github的文章 如果打不开这篇文章 可能需要科学上网 即访问外网的VPN https github com THUDM ChatGLM 6B 2 准备 台式机 GPU是8G 关于是否可以在笔记本运行 我后面测试下 等我下一篇

随机推荐

  • 什么是频谱仪的RBW带宽和VBW带宽

    1 RBW Resolution Bandwidth 代表两个不同频率的信号能够被清楚的分辨出来的最低频宽差异 两个不同频率的信号频宽如低于频谱分析仪的RBW 此时该两信号将重叠 难以分辨 RBW 分辨率带宽 有人也叫参考带宽 表示测试的是
  • 在laravel中合并路由_一些实用的 Laravel 小技巧

    Laravel 中一些常用的小技巧 说不定你就用上了 1 侧栏 1 网站一般都有侧栏 用来显示分类 标签 热门文章 热门评论啥的 但是这些侧栏都是相对独立的模块 如果在每一个引入侧栏的视图中都单独导入与视图有关的数据的话 未免太冗余了 所以
  • 算法——回溯法(子集、全排列、皇后问题)

    参考 http www cnblogs com wuyuegb2312 p 3273337 html intro 参考 算法竞赛入门经典 P120 1 定义 回溯算法也叫试探法 它是一种系统地搜索问题的解的方法 回溯算法的基本思想是 从一条
  • IDA宏定义

    This file contains definitions used by the Hex Rays decompiler output It has type definitions and convenience macros to
  • 机器学习中的 K-均值聚类算法及其优缺点。

    K 均值聚类算法是一种常见的无监督学习算法 它可以将数据集分成 K 个簇 每个簇内部的数据点尽可能相似 而不同簇之间的数据点应尽可能不同 下面详细讲解 K 均值聚类算法的优缺点 优点 简单易用 K 均值聚类算法是一种简单易懂的算法 容易理解
  • String index out of range错误与解决方法

    在做算法题时遇到了报错 原因是字符串的索引越界 查看自己的代码 原来int的类型范围越界 int的范围 2147483648 2147483647 long的范围 9223372036854775808 922337203685477580
  • golang-nil切片和空切片

    package main import fmt func main var a int b make int 0 if a nil fmt Println a is nil else fmt Println a is not nil if
  • Spring Boot中使用token:jwt

    token由3部分组成 Header Payload Signature 其中Header记录了签名算法和token 的类型 Payload是以明文存储的一些信息 包括用户自定义信息 Signature是使用签名算法 对Payload结合服
  • Android RxJava:组合 / 合并操作符 详细教程

    前言 Rxjava 由于其基于事件流的链式调用 逻辑简洁 使用简单的特点 深受各大 Android开发者的欢迎 Github截图 如果还不了解 RxJava 请看文章 Android 这是一篇 清晰 易懂的Rxjava 入门教程 RxJav
  • 04-----关于Qt下编译大文件的源码时报too many section

    1 关于Qt下编译大文件的源码时报too many section 这种问题是因为编译源码文件太大造成的 解决的方法如下 因为不同Qt版本可能添加的宏不一样 所以大家可能需要试一试下面的编译参数 我是用 Wa mbig obj 这个参数解决
  • Java面试题大全(整理版)附答案详解最全面看完稳了

    文末有彩蛋 进大厂是大部分程序员的梦想 而进大厂的门槛也是比较高的 所以这里整理了一份阿里 美团 滴滴 头条等大厂面试大全 其中概括的知识点有 Java MyBatis ZooKeeper Dubbo Elasticsearch Memca
  • 前端面试之道

    小册介绍 如果需要用一句话来介绍这本小册的话 一年磨一剑 应该是最好的答案了 为什么这样说呢 在出小册之前 我收集了大量的一线大厂面试题 通过大数据统计出了近百个常考知识点 然后根据这些知识点写成了这本小册 这本小册可以说是一线互联网大厂的
  • Win7封装全过程

    安装操作系统是个漫长而无聊的过程 我们个人安装原版系统都要花费半小时以上的时间 想象一下 一个500 1000 上万人的公司要是按这种方式装的话要花费多少时间 人力 物力 还好 系统制造商早就考虑到了这一点 有自己的应对之策 这就是操作系统
  • P10.编程生成Excel内图表

    P10 编程生成Excel内图表 md 插入图片 openpyxl插入图片 openpyxl drawing image sheet add image 例 from openpyxl drawing image import Image
  • Chapter Two : Python 语言基础、运算符与表达式、程序的控制结构合集

    目录 一 Python 语言基础 1 Python 语法规则 2 关键字与标识符 3 变量 4 基本数据类型 5 输入与输出 二 运算符与表达式 1 算术运算符 2 赋值运算符 3 比较 关系 运算符 4 逻辑运算符 5 位运算符 6 赋值
  • 2023华为OD统一考试(B卷)题库清单(按算法分类),如果你时间紧迫,就按这个刷

    目录 专栏导读 华为OD机试算法题太多了 知识点繁杂 如何刷题更有效率呢 一 逻辑分析 二 数据结构 1 线性表 数组 双指针 2 map与list 3 优先队列 4 滑动窗口 5 二叉树 6 并查集 7 栈 三 算法 1 基础算法 贪心算
  • xml文件报错Unable to resolve column ‘xxx‘

    项目场景 问题描述 我在使用mybatis的逆向工程时生成的xml文件报错Unable to resolve column xxx 原因分析 需要连接到数据库 解决方案 点击右侧 填写数据库信息 点击测试 报错的话点击下放Set time
  • shell 格式化输出密码

    格式化输出 etc passwd 效果如下 root zabbix server day6 awk F BEGIN print 用户名 UID 家目录 print 1 3 6 etc passwd 用户名 UID 家目录 root 0 ro
  • Unity 移动方法总结

    Unity移动方法总结 在Unity3D中 有多重方式可以改变物体的坐标 实现移动的目的 其本质是每帧改变物体的position 通过Transform组件移动物体 Transform组件用于描述物体在空间中的状态 它包括位置 positi
  • transformers库的使用【二】tokenizer的使用,模型的保存自定义

    使用标记器 tokenizer 在之前提到过 标记器 tokenizer 是用来对文本进行预处理的一个工具 首先 标记器会把输入的文档进行分割 将一个句子分成单个的word 或者词语的一部分 或者是标点符号 这些进行分割以后的到的单个的wo