深度学习实战11(进阶版)-BERT模型的微调应用-文本分类案例

2023-10-27

文章目录

一、前期工作

  1. 导入库包
  2. 导入数据

二、模型加载
三、模型训练

四、模型测试

大家好,我是微学AI,今天给大家带来一个基于BERT模型做文本分类的实战案例,在BERT模型基础上做微调,训练自己的数据集,相信之前大家很多都是套用别人的模型直接训练,或者直接用于预训练模型进行预测,没有训练和微调过大模型,因为像BERT这种大模型一般人是训练不了的,我们只能在大模型的基础上进行微调,或者做下游任务改造。

下面来介绍一下BERT模型,BERT是基于transfomer的预训练语言模型,它利用了transfomer中的编码器,进行数据编码,将文本数据转化为词向量。BERT核心内容是利用transfomer中的多头自注意力机制进行编码,关于transfomer的多头自注意力机制详细可以观看网络上的资料。

BERT模型是以两个NLP任务进行训练的,第一个任务是文本中词的预测,将已知训练文本隐掉词的信息,用MASK进行隐码,让模型去预测。第二个任务是在训练数据中随机抽取上下文关系句子或非上下文关系句子,让机器判断是否为上下文关系。BERT模型训练优势是无需进行标注数据。
我们可以利用BERT预训练模型进行下游任务改造,做自己相关任务,比如中文分词、文本分类,命名实体识别,阅读理解,情感分析,文本相似度、信息抽取等任务。

一、前期工作

1. 导入库包

import torch
from datasets import load_dataset
import torch.nn.functional as F
from transformers import BertTokenizer

#加载字典和分词工具
token = BertTokenizer.from_pretrained('bert-base-chinese')

2. 定义数据、数据载入

#定义数据集
class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
        self.dataset = load_dataset(path='data', split=split)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        text = self.dataset[i]['text']
        label = self.dataset[i]['label']

        return text, label

dataset = Dataset('train')
print(len(dataset), dataset[0])

def collate_fn(data):
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]

    #编码
    data = token.batch_encode_plus(batch_text_or_text_pairs=sents,
                                   truncation=True,
                                   padding='max_length',
                                   max_length=500,
                                   return_tensors='pt',
                                   return_length=True)

    #input_ids:编码之后的数字
    #attention_mask:是补零的位置是0,其他位置是1
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)

    #print(data['length'], data['length'].max())
    return input_ids, attention_mask, token_type_ids, labels

#数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=10,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

for i, (input_ids, attention_mask, token_type_ids,
        labels) in enumerate(loader):
    break

print(len(loader))
print(input_ids.shape, attention_mask.shape, token_type_ids.shape, labels)

这里代码需要在同级文件夹下创建data 文件夹, 放入train.csv、test.csv数据集。

数据集格式如下:

二、BERT模型加载

我们可以在BERT输出端接入一个全连接层,输出2分类问题,也可加入CNN卷积层,这些可以自行操作。

from transformers import BertModel

#加载预训练模型
pretrained = BertModel.from_pretrained('bert-base-chinese')

#不训练,不需要计算梯度
for param in pretrained.parameters():
    param.requires_grad_(False)

#模型试算
out = pretrained(input_ids=input_ids,
           attention_mask=attention_mask,
           token_type_ids=token_type_ids)

print(out.last_hidden_state.shape)


#定义下游任务模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(768, 2)
        # 可加入CNN卷积层,可以自行操作
        # self.conv1D = torch.nn.Conv1d(in_channels=500, out_channels=500, kernel_size=1)
        # self.MaxPool1D = torch.nn.MaxPool1d(4, stride=2)
        # self.Dropout = torch.nn.Dropout(p=0.5, inplace=False)

    def forward(self, input_ids, attention_mask, token_type_ids):
        with torch.no_grad():
            out = pretrained(input_ids=input_ids,
                       attention_mask=attention_mask,
                       token_type_ids=token_type_ids)
        out = self.fc(out.last_hidden_state[:, 0])
        out = out.softmax(dim=1)
        print(out.shape)
        return out

三、模型训练

model = Model()
print(model)
#model.summary()
model(input_ids=input_ids,
      attention_mask=attention_mask,
      token_type_ids=token_type_ids).shape

from transformers import AdamW
#训练
optimizer = AdamW(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()

model.train()
epochs = 30

for i, (input_ids, attention_mask, token_type_ids,
        labels) in enumerate(loader):
    out = model(input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids)

    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 1 == 0:
        out = out.argmax(dim=1)
        accuracy = (out == labels).sum().item() / len(labels)

        print('epochs:',i, 'loss:',loss.item(),'accuracy:', accuracy)

    if i == epochs:
        torch.save(model, 'text_classfiy.model')
        #model_load = torch.load('model/命名实体识别_中文.model')
        break

四、模型测试

#测试函数
def test():
    model.eval()
    correct = 0
    total = 0

    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),
                                              batch_size=10,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    for i, (input_ids, attention_mask, token_type_ids,
            labels) in enumerate(loader_test):

        if i == 5:
            break

        with torch.no_grad():
            out = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)

        out = out.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)

    print(correct / total)

可以调用测试函数进行测试,看看模型训练效果。

欢迎继续关注 深度学习实战案例,持续更新。获取数据可私聊。

 往期作品:

深度学习实战项目

1.深度学习实战1-(keras框架)企业数据分析与预测

2.深度学习实战2-(keras框架)企业信用评级与预测

3.深度学习实战3-文本卷积神经网络(TextCNN)新闻文本分类

4.深度学习实战4-卷积神经网络(DenseNet)数学图形识别+题目模式识别

5.深度学习实战5-卷积神经网络(CNN)中文OCR识别项目

6.深度学习实战6-卷积神经网络(Pytorch)+聚类分析实现空气质量与天气预测

7.深度学习实战7-电商产品评论的情感分析

8.深度学习实战8-生活照片转化漫画照片应用

9.深度学习实战9-文本生成图像-本地电脑实现text2img

10.深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)

11.深度学习实战11(进阶版)-BERT模型的微调应用-文本分类案例

12.深度学习实战12(进阶版)-利用Dewarp实现文本扭曲矫正

13.深度学习实战13(进阶版)-文本纠错功能,经常写错别字的小伙伴的福星

14.深度学习实战14(进阶版)-手写文字OCR识别,手写笔记也可以识别了

15.深度学习实战15(进阶版)-让机器进行阅读理解+你可以变成出题者提问

16.深度学习实战16(进阶版)-虚拟截图识别文字-可以做纸质合同和表格识别

17.深度学习实战17(进阶版)-智能辅助编辑平台系统的搭建与开发案例

18.深度学习实战18(进阶版)-NLP的15项任务大融合系统,可实现市面上你能想到的NLP任务

19.深度学习实战19(进阶版)-ChatGPT的本地实现部署测试,自己的平台就可以实现ChatGPT

...(待更新)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

深度学习实战11(进阶版)-BERT模型的微调应用-文本分类案例 的相关文章

  • AM335X外部看门狗及LINUX系统驱动移植(二)

    看门狗定时器 WDT Watch Dog Timer 是嵌入式系统的的一个组成部分 它实际上是一个计数器 一般给看门狗设置复位时间间隔 程序开始运行后看门狗开始计数 如果程序运行正常 过一段时间CPU应发出指令让看门狗置零 重新开始计数 如
  • python打砖块游戏算法设计分析_Python打砖块

    在家闲来无事用Python写了一个打砖块游戏 目前完成度一般 先来段视频 声音有点大 演示https www zhihu com video 1235510400411369472 游戏主要分那么几个板块 小球Ball 挡板Paddle 砖
  • ECS突发性能T6服务器可以用来做哪些事情?

    能做的事情还是挺多的 一般比如个人建站 WordPress建站 小微企业建站 小程序搭建 web开发部署等各种项目基本都是可以的 不过这类入门的就不太时候大型项目了 比如大型电商网站 比如人工智能 机器学习等就不要用突发型实例了 一般的小网

随机推荐

  • 利用maven-war-plugin实现不同环境下的配置文件

    我这是一个标准的maven的目录结构 配置文件都在src main resources根目录下 因为要改成多环境的配置 所以只有properties的文件改变了 公共配置可以原地不动 1 将配置文件放到不同的文件夹下 2 创建2个不同的pr
  • 默认值约束 [MySQL][数据库]

    默认值约束 DEFAULT 默认值约束的作用 给某个字段 某列指定默认值 一旦我们设置了默认值约束之后 在插入数据时 如果此字段没有显式赋值 则赋值为默认值 如果我们没有给一个字段添加默认值约束 这个时候我们如果没给一个字段显式赋值 那么这
  • Android编译之lunch命令

    google官方给的编译步骤 官方的详细编译步骤见 http source android com source building html 按照google给出的编译步骤如下 source build envsetup sh 加载命令 这
  • 巨头刷脸补贴大战自伊始就没有停止过

    随 着5G时代的到来 互联网 AI智能 云计算 物联网等技术的成熟 中国财政科学研究院应用学博士后盘和林认为 刷脸支付比密码支付更安全更便捷 我国在移动支付领域相较于其他国家来说一直处于领先地位 支付宝和微信支付两家在这一领域的竞争就从来没
  • QT笔记之QSpinBox和QSlider的封装使用

    文章目录 1 创建QT测试工程 2 右键添加 新建项 3 添加新的Qt Widget Class 叫做MySpinBox Slider QSpinBox和QSlider的组合使用 4 添加好QSpinBox和QSlider两个控件 并且调整
  • 牛顿第二定律沿流线流动粒子 Python 分析(流体力学)

    当流体粒子从一个位置移动到另一个位置时 它通常会经历加速或减速 根据牛顿第二运动定律 作用在所考虑的流体粒子上的合力 必须等于其质量乘以其加速度 F m a mathbf F m mathbf a F ma 实际上 不存在无粘
  • 微信第三方平台的授权过程整理

    最近碰到微信第三方平台这个东西 就研究了下 由于微信官方文档顺序不是很明确 我特别也整理了一下 官方的概述是 公众平台第三方平台是为了让公众号或小程序运营者 在面向垂直行业需求时 可以一键授权给第三方平台 并且可以同时授权给多家第三方 通过
  • Android基础进阶 - 消息机制 之Native层分析,Android面试回忆录

    synchronized this msg markInUse msg when when Message p mMessages boolean needWake 如果消息链表为空 或者插入的Message比消息链表第一个消息要执行的更早
  • Mockito框架@Mock, @InjectMocks注解使用

    最近写项目Junit 使用Junit4框架 测试的数据都要依赖数据库 而好多接口需要调其他的系统 junit4框架完全无法实现测试功能 大佬推荐用Mockito框架 这篇博客用来记录学习Mockito的使用方法 不足欢迎指点 Mock In
  • Dump文件的生成和使用

    1 简介 第一次遇到程序崩溃的问题 之前为单位开发了一个插件程序 在本机运行没有出现问题 但把生成的可执行文件拷贝到服务器上一运行程序 刚进入插件代码 插件服务就崩溃了 当时被这个问题整的很惨 在同事的帮助下了解到 对于程序崩溃 最快的解决
  • qemu图形界面linux,QEMU 简单几步搭建一个虚拟的ARM开发板

    1 安装QEMU 先在Ubuntu中安装QEMU sudo apt get install qemu 1 安装几个QEMU需要的软件包 sudo apt get install zlib1g dev sudo apt get install
  • 在Windows中使用WSL和VS Code搭建出友好的终端开发环境

    使用WSL Windows Subsystem for Linux 这一适用于 Linux 的 Windows 子系统可让开发人员按原样运行 GNU Linux 环境 包括大多数命令行工具 实用工具和应用程序 且不会产生传统虚拟机或双启动设
  • 平面几何-python

    三角形面积 题目描述 平面直角坐标系中有一个三角形 请你求出它的面积 输入描述 第一行输入一个 TT 代表测试数据量 每组测试数据输入有三行 每行一个实数坐标 x y x y 代表三角形三个顶点 1 T 10 3 10 5 x y 10 5
  • 2018年LeetCode高频算法面试题刷题笔记——验证回文串(字符串)

    1 解答之前的碎碎念 这个题还蛮简单的 大概就是考研机试第一题的水平 所以就不写解法了 2 问题描述 给定一个字符串 验证它是否是回文串 只考虑字母和数字字符 可以忽略字母的大小写 说明 本题中 我们将空字符串定义为有效的回文串 示例 1
  • 自媒体账号ID应该怎么取?

    我们都知道 在做自媒体之前 我们需要注册自媒体账号 这时候我们需要给账号取一个名称 一个好的名称能让你吸取更多的粉丝 让读者记忆深刻 取名规范各平台的规则都差不多 所以在入驻平台之前一定要先认真看清各平台名称规范 防止因为名称不规范而不能通
  • 鹏仔暴力刷导航网页排行榜HTML模板,网站如何测压

    现在很多站长都做了导航站 大多数导航站都有网站排行榜 也就是今日浏览榜 月最高浏览榜 总浏览榜等 你页面阅读量越高 那么排名越靠前 曝光的几率就更高 很多站长在某些导航站提交完收录后 手动刷新增加阅读量 真的特别慢 那么本次鹏仔就给大家简单
  • 头文件路径包含问题

    头文件包含两种 系统头文件和自定义头文件 系统头文件不说了 格式统一 自定义头文件在包含的时候要注意路径 其实是头文件与主文件的相对位置关系的问题 ps 另外 LInux和Windows下也有所区别 举4个例子 应该就能看明白了 一 这种情
  • InnoSetup 脚本打包及管理员权限设置

    InnoSetup使用教程 InnoSetup打包安装 脚本详细 1 定义变量 1 define MyAppName TranslationTool 2 define MyAppChineseName 翻译工具 3 define MyApp
  • ios系统下input边框有默认阴影

    修复代码 1 input outline none webkit appearance none 去除系统默认的样式 webkit tap highlight color rgba 0 0 0 0 点击高亮的颜色 2 input appea
  • 深度学习实战11(进阶版)-BERT模型的微调应用-文本分类案例

    文章目录 一 前期工作 导入库包 导入数据 二 模型加载三 模型训练 四 模型测试 大家好 我是微学AI 今天给大家带来一个基于BERT模型做文本分类的实战案例 在BERT模型基础上做微调 训练自己的数据集 相信之前大家很多都是套用别人的模