Bert模型做多标签文本分类

2023-11-12

Bert模型做多标签文本分类

参考链接

BERT模型的详细介绍

图解BERT模型:从零开始构建BERT

(强推)李宏毅2021春机器学习课程

我们现在来说,怎么把Bert应用到多标签文本分类的问题上。注意,本文的重点是Bert的应用,对多标签文本分类的介绍并不全面

单标签文本分类

对应单标签文本分类来说,例如二元的文本分类,我们首先用一层或多层LSTM提取文本序列特征,然后接一个dropout层防止过拟合,最后激活函数采用sigmoid,或者计算损失的时候使用sigmoid交叉熵损失函数。对于多元分类则激活函数采用softmax,其它没有差别

多标签文本分类

怎么从单标签分类问题拓展到多标签分类呢?

我们可以把二元分类的情况归并到多元分类

至少有以下两种方案(我懂的):

1,最后的全连接层以sigmoid作为激活函数,把每个神经元都当成是二元分类。另外,也可以直接把最后的全连接层改成n个全连接层,每个全连接层再接一个神经元做二元分类(激活函数是sigmoid),我认为二者本质上没有区别。

2,将多标签分类任务视作seq2seq的问题,对于给定的文本序列,生成不定长的标签序列。

这篇文章将介绍第一种方案。

首先我们先看看怎么使用Bert模型

下载transformers包,pip install transformers

如果是处理英文问题,并且不用统一大小写的话,可以按照下方链接下载

其次手动下载模型,下载bert-base-uncasedconfig.josn,vocab.txt,pytorch_model.bin三个文件

配置文件下载地址:https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json

模型文件下载地址:https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin

词汇表下载地址:https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt

下载完成后,按照config.json,vocab.txt,pytorch_model.bin重命名,放在bert-base-uncased文件夹下,此例中bert-base-uncased文件夹放置在项目根目录下

如果是处理中文任务,把链接中的bert-base-uncased替换成bert-base-chinese即可,存放文件夹名可根据习惯修改为相应模型的名称

下面的demo是基于中文bert演示的,真正的多标签分类的项目代码里用的是bert-base-uncased

导入包,加载预训练模型

import numpy as np
import torch 
from transformers import BertTokenizer, BertConfig, BertForMaskedLM, BertForNextSentencePrediction
from transformers import BertModel

model_name = "bert-base-chinese"
# a. 通过词典导入分词器
tokenizer = BertTokenizer.from_pretrained(model_name)
# b. 导入配置文件
model_config = BertConfig.from_pretrained(model_name)
# 修改配置
model_config.output_hidden_states = True
model_config.output_attentions = True
# 通过配置和路径导入模型
bert_model = BertModel.from_pretrained(model_name, config = model_config)

完成模型加载后,我们来看看Bert的输入输出

输入

假设我们输入了一句话是“我爱你,你爱我”,我们需要利用tokernizer做初步的embedding处理

sen_code = tokenizer.encode_plus("我爱你,你爱我")

得到的sen_code是这样的

{‘input_ids’: [101, 2769, 4263, 872, 102, 872, 4263, 2769, 102],

‘token_type_ids’: [0, 0, 0, 0, 0, 1, 1, 1, 1],

‘attention_mask’: [1, 1, 1, 1, 1, 1, 1, 1, 1]}

input_ids就是每个字符在字符表中的编号,101表示[CLS]开始符号,[102]表示[SEP]句子结尾分割符号。

token_type_ids是区分上下句的编码,上句全0,下句全1,用在Bert的句子预测任务上

attention_mask表示指定哪些词作为query进行attention操作,全为1表示self-attention,即每个词都作为query计算跟其它词的相关度

将input_ids转化回token

tokenizer.convert_ids_to_tokens(sen_code['input_ids'])
#output:['[CLS]', '我', '爱', '你', '[SEP]', '你', '爱', '我', '[SEP]']

Bert的输入是三个embedding的求和,token embedding,segment embedding和position embedding

# token embedding
tokens_tensor = torch.tensor([sen_code['input_ids']]) # 添加batch维度
# segment embedding
segments_tensors = torch.tensor([sen_code['token_type_ids']]) # 添加batch维度

输出

Bert是按照两个任务进行预训练的,分别是遮蔽语言任务(MLM)和句子预测任务。

我先简单解释一下这两个任务

遮蔽语言任务(Masked Language Model

对输入的语句中的字词 随机用 [MASK] 标签覆盖,然后模型对mask位置的单词进行预测。这个过程类似CBOW训练的过程,我们利用这个训练任务从而得到每个字符对应的embedding。特别的,[CLS]字符的embedding我们可以视为整个句子的embedding。我们可以理解为[CLS]字符跟句子中的其它字符都没有关系,能较为公平的考虑整个句子。

句子预测任务(NextSentence Prediction

该任务就是给定一篇文章中的两句话,判断第二句话在文本中是否紧跟在第一句话之后。如果我们训练的时候将问题和答案作为上下句作为模型输入,该任务也可以理解为判断问题和答案是否匹配

现在我们根据代码看看bert的输出

bert_model.eval()
with torch.no_grad():
    outputs = bert_model(tokens_tensor, token_type_ids = segments_tensors)
    encoded_layers = outputs   # outputs类型为tuple

最后一个隐藏层的输出,即遮蔽语言任务的输出,亦即每个字符的embedding

print("sequence output",encoded_layers[0].shape)
# sequence output torch.Size([1, 9, 768])

最后一个隐藏层的第一个输出[CLS]的embedding,然后进行pool操作的结果,所谓的pool操作就是接一个全连接层+tanh激活函数层。它可以作为整个句子的语义表示,但也有将所有向量的平均作为句子的表示的做法

print("pooled output",encoded_layers[1].shape)
# pooled output torch.Size([1, 768])

所有隐藏层的输出,hidden_states有13个元素,第一个是[CLS]的embedding,后面12个元素表示12个隐藏层的输出,对于seq2seq的任务,它们将作为decoder的输入

print("hidden_states",len(encoded_layers[2]),encoded_layers[2][0].shape)
# hidden_states 13 torch.Size([1, 9, 768])

attention分布,有12个元素,每个隐藏层的hidden_states经过self-attention层得到的attention分布,没有乘以V矩阵。因为是multi-head,一共有12个头,所以每个attention分布的维度是1x12x9x9(1是batch_size,9是序列长度)

print("attentions",len(encoded_layers[3]),encoded_layers[3][0].shape)
# attentions 12 torch.Size([1, 12, 9, 9])

要明白上面的输出为什么是那个意思,还是得看源码Bert代码详解(一)

模型构建

搞明白bert的输入输出之后我们就可以试着做fine-tune了,我们是要做多标签文本分类,根据第一个方案,我们首先提取出文本的特征,然后接全连接层,最后接一个sigmoid激活函数。

前面已经说过,pooled output就是表示bert得到的整个句子的语义特征,这正是我们需要的。将这个特征作为全连接层的输入即可。代码里面还定义了dropout层,这都是训练的常用技巧,防止过拟合

class BertForMultiLabel(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForMultiLabel, self).__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, head_mask=None):
        outputs = self.bert(input_ids, token_type_ids,attention_mask,head_mask)
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return self.sigmoid(logits)
    def unfreeze(self, start_layer, end_layer):
        def children(m):
            return m if isinstance(m, (list, tuple)) else list(m.children())

        def set_trainable_attr(m, b):
            m.trainable = b
            for p in m.parameters():
                p.requires_grad = b

        def apply_leaf(m, f):
            c = children(m)
            if isinstance(m, nn.Module):
                f(m)
            if len(c) > 0:
                for l in c:
                    apply_leaf(l, f)

        def set_trainable(l, b):
            apply_leaf(l, lambda m: set_trainable_attr(m, b))

        set_trainable(self.bert, False)
        for i in range(start_layer, end_layer + 1):
            set_trainable(self.bert.encoder.layer[i], True)

定义损失函数,优化器和超参数

Bert原项目对训练使用了很多性能、显存消耗的优化技术,包括warmup,gradient accumulation,还有fp16,这些技术我暂时也没有全部搞懂,所以暂时抛弃部分优化技术,写一个最简单的优化器。AdamW是Bert预训练采用的优化算法,大家如果不懂可以去百度一下,我也不是很了解,所以就直接用了

# 定义超参数
batch_size = 8
lr = 2e-5
adam_epsilon = 1e-8
grad_clip = 1.0
start_layer = 11  #[0,11]
end_layer = 11	  #[start_layer,11]

# 定义损失函数
loss = nn.BCELoss()
# 定义优化器
optimizer = optim.AdamW(model.parameter(), lr=lr, eps=adam_epsilon)

# 加载模型
model = BertForMultiLabel(config)
# 现在使用的Bert模型是12层,我们可以自由调节冻结bert模型的层数,当前是只训练最后一层
model.unfreeze(start_layer, end_layer)
model = model.cuda()

加载处理数据集

一个模型想要跑起来必然需要数据输入,Bert对参与训练的数据格式要求为input_ids, input_mask, segment_ids, label_ids。而原始的数据格式为string,label_ids

所以我们需要对数据做一些处理,为此我们定义一个BertProcessor类,这个类的主要方法为read_dataset和train_val_split。

注意我现在的做法和那些好的做法有很多差别,那些好的做法是基于优化的考虑,但我们现在暂时不用考虑这么多,把重心放在bert的使用和模型的成功训练上,优化做法读者可进一步研究。

先看类中部分代码,完整项目在最后

class BertProcessor:
    def __init__(self, vocab_path, do_lower_case, max_seq_length) -> None:
        self.tokenizer = BertTokenizer(vocab_path, do_lower_case)
        self.max_seq_length = max_seq_length

    def get_input_ids(self, x):
        # 使用tokenizer对字符编码
        # 并将字符串填充或裁剪到max_seq_length的长度
        ...

    def get_label_ids(self, x):
        # 合并标签为一个list
        ...

    def read_dataset(self, file_path, train=True):
        data = pd.read_csv(file_path)
        if train:
            data['label_ids'] = data.iloc[:, 2:].apply(self.get_label_ids, axis=1)
            label_ids = torch.tensor(list(data['label_ids'].values))
        # 英文预处理,包括去除停用词,大小写转换,删除无关字符,拆解单词等等
        preprocessor = EnglishPreProcessor()
        tqdm.pandas(desc="english preprocess")
        data['comment_text'] = data['comment_text'].progress_apply(preprocessor)
        # 对每一个comment_text做encode操作
        tqdm.pandas(desc="convert tokens to ids")
        data['input_ids'] = data['comment_text'].progress_apply(self.get_input_ids)
        input_ids = torch.tensor(list(data['input_ids'].values), dtype=torch.int)
        input_mask = torch.ones(size=(len(data), self.max_seq_length), dtype=torch.int)
        segment_ids = torch.zeros(size=(len(data), self.max_seq_length), dtype=torch.int)
        if train:
            dataset = Data.TensorDataset(input_ids, input_mask, segment_ids, label_ids)
        else:
            dataset = Data.TensorDataset(input_ids, input_mask, segment_ids)
        return dataset

我想如果前面输入输出部分大家看懂的话,read_dataset函数很容易看懂

模型训练

有几点需要注意一下,为了使用gpu,需要调用cuda方法将数据转移到gpu上,然后在反向传播计算梯度后,需要做一个梯度裁剪,即当梯度超过grad_clip的时候就把梯度设为grad_clip

def train(model, train_iter, valid_iter, n_epoch, loss, optimizer):
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_acc': []
    }
    for epoch in range(n_epoch):
        train_loss, n = 0.0, 0
        for input_ids, input_mask, segment_ids, label_ids in tqdm(train_iter):
            input_ids = input_ids.cuda()
            input_mask = input_mask.cuda()
            segment_ids = segment_ids.cuda()
            logits = model(input_ids, segment_ids, input_mask)
            l = loss(logits, label_ids.float().cuda())
            l.backward()
            clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            train_loss += l.item()
        train_loss = train_loss/n
        val_loss, val_acc, n = 0.0, 0.0, 0
        with torch.no_grad():
            for input_ids, input_mask, segment_ids, label_ids in valid_iter:
                input_ids = input_ids.cuda()
                input_mask = input_mask.cuda()
                segment_ids = segment_ids.cuda()
                logits = model(input_ids, segment_ids, input_mask)
                label_ids = label_ids.float().cuda()
                l = loss(logits, label_ids)
                val_loss += l.item()
                val_acc += (torch.where(logits > 0.5, 1, 0) == label_ids).min(axis=1)[0].sum()
                n += len(label_ids)
            val_acc = val_acc / n
            val_loss = val_loss / n
        print("epoch %s train loss:%s val loss:%s" % (epoch + 1, train_loss, val_loss))
        history['train_loss'].append(train_loss)
        history['val_acc'].append(val_acc)
        history['val_loss'].append(val_loss)

        # save model checkpoint
        model.save_pretrained("models%s" % (epoch + 1))
    return history

把损失曲线和准确率曲线绘制出来就是这样

plt.plot(range(len(history['train_loss'])), history['train_loss'], label="train loss")
plt.show()
plt.plot(range(len(history['val loss'])), history['val loss'], label="val loss")
plt.show()
plt.plot(range(len(history['val acc'])), history['val_acc'])
plt.show()

暂时没图,待添加。。。

补充陈述

事实上,当我们评估多标签分类的模型的时候,前面只考虑了总体的Accuracy这个指标,但是还有很多更详细的metric需要考虑。这个也交给大家去查阅资料吧。

为了便于理解,前面的实现非常的粗糙,摒弃了很多好的优化策略。大家可以看看这个repo里面的实现,我就是参考的这个代码。代码里面在读取数据时设置了缓存机制,方便再次运行的时候快速读取数据,然后模型保存,日志输出,训练性能显存优化,模型评估等方面都有更好的处理。

至于我这份代码,后续应该会逐渐改进,如果大家有需要可以评论或私信留下邮箱地址。

补充:没想到要代码的人还有点多,发到评论区置顶了

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Bert模型做多标签文本分类 的相关文章

随机推荐

  • Vue - 去掉路由中的#号

    vue router默认是hash模式 在hash模式下 是会有 号在URL上 可以在路由的第一行添加 mode history来去掉 号 const router new Router mode history routes 一开始用的t
  • nginx实践(一)、安装和部署

    很长一段时间没有更新blog 因为这一段时间 都在研究一个实时事件处理应用 计划把该实时事件处理服务 修改为分布式服务 相关内容以后再与大家汇报 好了 因为工作缘故 要分析一个使用nginx的应用 所以从本节开始 介绍一下nginx的相关实
  • H5页面在ios的浏览里返回不刷新页面,解决浏览器自带缓存的问题。

    1 利用pageshow来解决 pageshow的web api pageshow Web API 接口参考 MDN 2 解决 在app vue里面 isPageHide false 定义变量 created void window add
  • Connections between cities 【HDU - 2874】【在线LCA算法】

    题目链接 昨天刚学了在线LCA 今天就来硬刚这道题还是花了一整天的时间 不过对于LCA却有了更多的理解 这道题在讲述不同根的做法上尤其是很好的 题目告诉我们有N个节点和M条边 以及C次询问 每次查询的是 L R 这两个节点间的距离 还是算得
  • dbeaver 配置mysql数据库驱动

    右键点击要数据库连接选择 编辑连接 然后点击 编辑驱动设置 从mysql8版本后 mysql的驱动类名发生改变 变成了com mysql cj jdbc Driver 所以如果要连接的数据库版本在8之前 需将 设置 界面的 类名 处改为 c
  • 华为od机试 Python【快递装载】

    前言 本题使用python解答 如果需要Java版本 请参考 点我 题目 快递需要按照一定的规则装载 所有的快递放在长方体的盒子当中 我们的需要是尽可能装载更多的快递 并且不能让货车超载 需要计算最多能装多少个快递 快递数最多1000个 货
  • 双列集合系列之Map集合的初了解

    Welcome Huihui s Code World 接下来看看由辉辉所写的关于双列集合的相关操作吧 目录 Welcome Huihui s Code World 顶级接口Map 一 Map集合的特点 二 Map集合的常见子类 HashM
  • xss渗透(跨站脚本攻击)

    一 什么是XSS XSS全称是Cross Site Scripting即跨站脚本 当目标网站目标用户浏览器渲染HTML文档的过程中 出现了不被预期的脚本指令并执行时 XSS就发生了 这里我们主要注意四点 1 目标网站目标用户 2 浏览器 3
  • 项目管理中什么最重要?

    被问过多次这个问题 尤其是在面试的时候 有说需求最重要 有说控制最重要 有的冠冕堂皇 来个成本 质量 时间三要素 美其名曰都重要 免得以偏概全 经多方求证 思索 结合十余年的项目管理经历 敝以为 项目管理中干系人管理最重要 尤其是关键干系人
  • Java面向对象编程

    一个关系数据库文件中的各条记录 A 前后顺序不能任意颠倒 一定要按照输入的顺序排列 B 前后顺序可以任意颠倒 不影响库中的数据关系 C 前后顺序可以任意颠倒 但排列顺序不同 统计处理的结果就可能不同 D 前后顺序不能任意颠倒 一定要按照关键
  • textarea placeholder不显示

    textarea placeholder不显示 textarea 的 placeholder 属性值不显示的原因可能是
  • DirectD3D-纹理映射

    DirectD3D 纹理映射 标签 Direct3Ddirectx游戏游戏开发 2014 11 12 14 03 321人阅读 评论 0 收藏 举报 分类 DirectX 8 版权声明 本文为博主原创文章 未经博主允许不得转载 纹理映射的概
  • python哪些类型可以作为迭代器_Python教程|全面理解Python迭代器和生成器

    在Python中 很多对象都是可以通过for语句来直接遍历的 例如list string dict等等 这些对象都可以被称为可迭代对象 至于说哪些对象是可以被迭代访问的 就要了解一下迭代器相关的知识了 迭代器 迭代器对象要求支持迭代器协议的
  • Golang架构直通车——理解Go GC

    文章目录 设计原理 三色抽象 三色不变性 插入写屏障 删除写屏障 垃圾收集器的增量和并发 增量式垃圾收集 并发式垃圾收集器 Go GC演进过程 并发垃圾收集 回收堆目标 混合写屏障 设计原理 三色抽象 标记清除 Mark Sweep 算法是
  • 数学建模--退火算法求解最值的Python实现

    目录 1 算法流程简介 2 算法核心代码 3 算法效果展示 1 算法流程简介 1 设定退火算法的基础参数 2 设定需要优化的函数 求解该函数的最小值 最大值 3 进行退火过程 随机产生退火解并且纠正 直到冷却 4 绘制可视化图片进行了解退火
  • 异步javaScript

    在本文中 我们将解释什么是异步编程 为什么我们需要它 并简要讨论 JavaScript 历史上异步函数是怎样被实现的 预备知识 基本的计算机素养 以及对 JavaScript 基础知识的一定了解 包括函数和事件处理程序 目标 熟悉异步 Ja
  • 日增30-40亿数据量的数据库

    author skate time 2010 08 13 前几天和个朋友聊天 他说他有每天30 40亿条数据量的数据库如何规划与优化 简单了解需求是这30 40亿数据是每天采集的 然后同时还对这些采集的数据进行分析挖掘 对于这么大量的数据量
  • MySQL数据库使用小皮系统(phpstudy)的安装及配置流程

    小皮系统phpstudy的安装及配置流程 一 小皮系统 phpstudy 的下载 二 数据库管理工具 一 小皮系统 phpstudy 的下载 搜索 phpStudy V8 1 下载大约 78m 左右 官网下载地址 phpStudy 可以随时
  • Android红外遥控器移植

    1 编译hal层代码 红外的hal代码路径 hardware libhardware modules consumerir 最终生成consumerir default so 但system文件系统中并没有该库 选择安装该库即可 在devi
  • Bert模型做多标签文本分类

    Bert模型做多标签文本分类 参考链接 BERT模型的详细介绍 图解BERT模型 从零开始构建BERT 强推 李宏毅2021春机器学习课程 我们现在来说 怎么把Bert应用到多标签文本分类的问题上 注意 本文的重点是Bert的应用 对多标签