使用自动模型

2023-11-08

本文通过文本分类任务演示了HuggingFace自动模型使用方法,既不需要手动计算loss,也不需要手动定义下游任务模型,通过阅读自动模型实现源码,提高NLP建模能力。

一.任务和数据集介绍
1.任务介绍
前面章节通过手动方式定义下游任务模型,HuggingFace也提供了一些常见的预定义下游任务模型,如下所示:

说明:包括预测下一个词,文本填空,问答任务,文本摘要,文本分类,命名实体识别,翻译等。

2.数据集介绍
本文使用ChnSentiCorp数据集,不清楚的可以参考中文情感分类介绍。一些样例如下所示:

二.准备数据集
1.使用编码工具

def load_encode_tool(pretrained_model_name_or_path):
    """
    加载编码工具
    """
    tokenizer = BertTokenizer.from_pretrained(Path(f'{pretrained_model_name_or_path}'))
    return tokenizer
if __name__ == '__main__':
    # 测试编码工具
    pretrained_model_name_or_path = r'L:/20230713_HuggingFaceModel/bert-base-chinese'
    tokenizer = load_encode_tool(pretrained_model_name_or_path)
    print(tokenizer)

输出结果如下所示:

BertTokenizer(name_or_path='L:\20230713_HuggingFaceModel\bert-base-chinese', vocab_size=21128, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True)

2.定义数据集
直接使用HuggingFace数据集对象,如下所示:

def load_dataset_from_disk():
    pretrained_model_name_or_path = r'L:\20230713_HuggingFaceModel\ChnSentiCorp'
    dataset = load_from_disk(pretrained_model_name_or_path)
    return dataset
if __name__ == '__main__':
    # 加载数据集
    dataset = load_dataset_from_disk()
    print(dataset)

输出结果如下所示:

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 9600
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 1200
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1200
    })
})

3.定义计算设备

# 定义计算设备
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
# print(device)

4.定义数据整理函数

def collate_fn(data):
    sents = [i['text'] for i in data]
    labels = [i['label'] for i in data]
    #编码
    data = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents, # 输入文本
            truncation=True, # 是否截断
            padding=True, # 是否填充
            max_length=512, # 最大长度
            return_tensors='pt') # 返回的类型
    #转移到计算设备
    for k, v in data.items():
        data[k] = v.to(device)
    data['labels'] = torch.LongTensor(labels).to(device)
    return data

5.定义数据集加载器

# 数据集加载器
loader = torch.utils.data.DataLoader(dataset=dataset['train'], batch_size=16, collate_fn=collate_fn, shuffle=True, drop_last=True)
print(len(loader))

# 查看数据样例
for i, data in enumerate(loader):
    break
for k, v in data.items():
    print(k, v.shape)

输出结果如下所示:

600
input_ids torch.Size([16, 200])
token_type_ids torch.Size([16, 200])
attention_mask torch.Size([16, 200])
labels torch.Size([16])

三.加载自动模型
使用HuggingFace的AutoModelForSequenceClassification工具类加载自动模型,来实现文本分类任务,代码如下:

# 加载预训练模型
model = AutoModelForSequenceClassification.from_pretrained(Path(f'{pretrained_model_name_or_path}'), num_labels=2)
model.to(device)
print(sum(i.numel() for i in model.parameters()) / 10000)

四.训练和测试
1.训练
需要说明自动模型本身包括loss计算,因此在train()中就不再需要手工计算loss,如下所示:

def train():
    # 定义优化器
    optimizer = AdamW(model.parameters(), lr=5e-4)
    # 定义学习率调节器
    scheduler = get_scheduler(name='linear', # 调节器名称
                              num_warmup_steps=0, # 预热步数
                              num_training_steps=len(loader), # 训练步数
                              optimizer=optimizer) # 优化器
    # 将模型切换到训练模式
    model.train()
    # 按批次遍历训练集中的数据
    for i, data in enumerate(loader):
        # print(i, data)
        # 模型计算
        out = model(**data)
        # 计算1oss并使用梯度下降法优化模型参数
        out['loss'].backward() # 反向传播
        optimizer.step() # 优化器更新
        scheduler.step() # 学习率调节器更新
        optimizer.zero_grad() # 梯度清零
        model.zero_grad() # 梯度清零
        # 输出各项数据的情况,便于观察
        if i % 10 == 0:
            out_result = out['logits'].argmax(dim=1)
            accuracy = (out_result == data.labels).sum().item() / len(data.labels)
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print(i, out['loss'].item(), lr, accuracy)

其中,out数据结构如下所示:

2.测试

def test():
    # 定义测试数据集加载器
    loader_test = torch.utils.data.DataLoader(dataset=dataset['test'],
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)
    # 将下游任务模型切换到运行模式
    model.eval()
    correct = 0
    total = 0
    # 按批次遍历测试集中的数据
    for i, data in enumerate(loader_test):
        # 计算5个批次即可,不需要全部遍历
        if i == 5:
            break
        print(i)
        # 计算
        with torch.no_grad():
            out = model(**data)
        # 统计正确率
        out = out['logits'].argmax(dim=1)
        correct += (out == data.labels).sum().item()
        total += len(data.labels)
    print(correct / total)

五.深入自动模型源代码
1.加载配置文件过程
在执行AutoModelForSequenceClassification.from_pretrained(Path(f'{pretrained_model_name_or_path}'), num_labels=2)时,实际上调用了AutoConfig.from_pretrained(),该函数返回的config对象内容如下所示:

config对象如下所示:

BertConfig {
  "_name_or_path": "L:\\20230713_HuggingFaceModel\\bert-base-chinese",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "transformers_version": "4.32.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128
}

(1)_name_or_path=bert-base-chinese:模型名字。
(2)attention_probs_DropOut_prob=0.1:注意力层DropOut的比例。
(3)hidden_act=gelu:隐藏层的激活函数。
(4)hidden_DropOut_prob=0.1:隐藏层DropOut的比例。
(5)hidden_size=768:隐藏层神经元的数量。
(6)layer_norm_eps=1e-12:标准化层的eps参数。
(7)max_position_embeddings=512:句子的最大长度。
(8)model_type=bert:模型类型。
(9)num_attention_heads=12:注意力层的头数量。
(10)num_hidden_layers=12:隐藏层层数。
(11)pad_token_id=0:PAD的编号。
(12)pooler_fc_size=768:池化层的神经元数量。
(13)pooler_num_attention_heads=12:池化层的注意力头数。
(14)pooler_num_fc_layers=3:池化层的全连接神经网络层数。
(15)vocab_size=21128:字典的大小。

2.初始化模型过程
BertForSequenceClassification类构造函数包括一个BERT模型和全连接神经网络,基本思路为通过BERT提取特征,通过全连接神经网络进行分类,如下所示:

def __init__(self, config):
    super().__init__(config)
    self.num_labels = config.num_labels
    self.config = config

    self.bert = BertModel(config)
    classifier_dropout = (
        config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
    )
    self.dropout = nn.Dropout(classifier_dropout)
    self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    # Initialize weights and apply final processing
    self.post_init()

通过forward()函数可证明以上推测,根据问题类型为regression(MSELoss()损失函数)、single_label_classification(CrossEntropyLoss()损失函数)和multi_label_classification(BCEWithLogitsLoss()损失函数)选择损失函数。

参考文献:
[1]HuggingFace自然语言处理详解:基于BERT中文模型的任务实战
[2]https://github.com/ai408/nlp-engineering/blob/main/20230625_HuggingFace自然语言处理详解/第12章:使用自动模型.py

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用自动模型 的相关文章

随机推荐

  • ruby网站部署到服务器,Ruby China 已迁移到新的服务器,基于 Docker 部署

    终于决定要迁移新服务器了 之前那台老机器陪同 Ruby China 运作了 6 年 如果我没记错的话 系统还是 Ubuntu 12 04 昨天下班前还是准备 同步数据库到 UCloud 的 PostgreSQL 服务上 避免以后自己维护和备
  • Elasticsearch(一):入门篇

    文章目录 一 Docker安装ES和Kibana 二 基本概念 文档 index 索引 type 类型 id ID 三 保存或修改文档数据 POST PUT 四 检索文档 GET 1 检索一个文档 2 检索多个文档 mget 五 删除文档
  • 计算机组成原理——万字详解

    引言 作为还在学习的学生和不断进步的同事 学习计算机组成原理具有以下几个重要的好处 它可以帮助你深入理解计算机系统的工作原理 包括处理器 存储器 输入输出设备等组成部分之间的交互关系 这种深入理解可以提高你对计算机系统的整体把握能力 让你能
  • Selenium 自动化测试实战笔记1

    1 安装 selenium pip install selenium 3 11 0 安装指定版本 pip install selenium U 安装最新版本 pip show selenium 查看当前版本 pip uninstall se
  • linux关机等待90秒

    ubuntu关机时 提示 A stop job is running for Session c2 of user 1min 30s 解决方法 sudo gedit etc systemd system conf 去除默认的注释 修改为 D
  • Lua : 回调函数不用怕,用法简单仿C/C++

    Lua也可以做回调函数 那当然 不明觉厉 嘿嘿嘿 那是不是可以在Lua编程时候搞点飞机啦 加 function add x y return x y end 减法 function minux x y return x y end func
  • 使用Iframe+Post请求的方式嵌入第三方页面

    背景描述 本身我们有自己的一个系统 之后采购了一个新系统 新系统的页面要嵌入到我们自己系统页面来 两个系统之间的权限交互通过token来进行传递和认证 本身嵌入采用如下方式就非常简单了 就是常规的iframe嵌入页面的方式 常规的ifram
  • Windows10上使用VS2017编译OpenCV3.4.2+OpenCV_Contrib3.4.2+Python3.6.2操作步骤

    1 从https github com opencv opencv releases 下载opencv 3 4 2 zip并解压缩到D soft OpenCV3 4 2 opencv 3 4 2目录下 2 从https github com
  • ps2021神经网络AI滤镜下载,ps神经网络滤镜安装包

    如何解决ps2021 新版 AI神经滤镜不能用 网上买正版 更新下就好了 盗版的都会有各种这样的问题 ps2021神经AI滤镜是需简要上传云端 由Adobe官方服务器人工智能运算的 Ps2021版本新增了Ai神经元滤镜 它不是与软件一起安装
  • 如何实现Android app开机自启动

    这里写目录标题 前言 代码实现 AndroidManifest xml BootReceiver java MainActivity java MyService java 问题解决 前言 上一篇文章如何实现无界面Android app介绍
  • 深度学习中Epoch、Batch以及Batch size的设定

    Epoch 时期 当一个完整的数据集通过了神经网络一次并且返回了一次 这个过程称为一次 gt epoch 也就是说 所有训练样本在神经网络中都 进行了一次正向传播 和一次反向传播 再通俗一点 一个Epoch就是将所有训练样本训练一次的过程
  • ajax url传递中文乱码,jquery.ajax的url中传递中文乱码问题的解决方法

    JQuery JQuery默认的contentType application x www form urlencoded 这才是JQuery正在乱码的原因 在未指定字符集的时候 是使用ISO 8859 1 ISO8859 1 通常叫做La
  • 数据结构与算法实验3(栈) 括号匹配

    数据结构与算法实验3 栈 括号匹配 用栈ADT应用 对称符号匹配判断 输入一行符号 以 结束 判断其中的对称符号是否匹配 对称符号包括 lt gt 输出分为以下几种情况 1 对称符号都匹配 输出 right 2 如果处理到最后出现了失配 则
  • 54331 DCDC 纹波 干扰 收音机 原因

    用了一个TPS54331 把12 V 转5V后 再经过一个LDO转换为3 3V给收音IC 结果干扰非常大 这之前用的是LM2596 对收音机干扰很小 分析输出纹波大 但是一直找不到原因 最后 经过排查 对比54331的datasheet 发
  • docker第三讲 docker启动redis容器以及解决redis-server启动redis直接挂的问题

    本地启动配置redis 安装包安装 下载安装包 下载地址 Download Redis 安装gcc yum install gcc 把下载好的redis 6 2 1r 1 tar gz放在 usr local文件夹下 并解压 wget ht
  • 小波教程-part2-傅立叶变换和短时傅立叶变换

    1 基本原理 让我们简要回顾一下第一部分 我们基本上需要小波变换 WT 来分析非平稳信号 即其频率响应随时间变化的信号 我已经写过傅立叶变换 FT 不适合非平稳信号 并且已经展示了一些例子以使其更加清晰 快速回顾一下 让我举一个例子 假设我
  • 使用 Live555 搭建流媒体服务器

    搭建环境为Centos 7 2 64bit 一 安装gcc编译器 yum install gcc c 二 安装live555 wget http www live555 com liveMedia public live555 latest
  • Swagger使用详解(基于knife4j方案)

    1 简介 Swagger 是一个规范和完整的框架 用于生成 描述 调用和可视化 RESTful 风格的 Web 服务 总体目标是使客户端和文件系统作为服务器以同样的速度来更新 文件的方法 参数和模型紧密集成到服务器端的代码 允许 API 来
  • AttributeError: ‘function‘ object has no attribute ‘xxx‘报错问题

    问题描述 AttributeError function object has no attribute send bp route mail def mail message Message subject 邮箱测试 recipients
  • 使用自动模型

    本文通过文本分类任务演示了HuggingFace自动模型使用方法 既不需要手动计算loss 也不需要手动定义下游任务模型 通过阅读自动模型实现源码 提高NLP建模能力 一 任务和数据集介绍 1 任务介绍 前面章节通过手动方式定义下游任务模型