BERT从零详细解读:如何微调BERT,提升BERT在下游任务中的效果

2023-11-17

在这里插入图片描述
a)是句子对的分类任务
b)是单个句子的分类任务
c) 是问答任务
d)是序列标注任务

首先我自己最常用的就是:文本分类、序列标注和文本匹配。
这四个都是比较简单的,我们来看d)序列标注,其实就是把所有的token输出,做了一个softmax,去看它属于实体中的哪一个。对于单个样本,它的一个文本分类就是使用CLS这边,第一个CLS的输出,去做一个微调,做一个二分类,或者是多分类。
a)这个其实本质是一个文本匹配的一个任务,文本匹配就是把两个句子拼接起来,去判断它是否相似。左上角也是用CLS输出判断,0不相似,1相似。基本上其实就是这样,其实在下游任务中它使用还是比较简单的。

如何提升BERT在下游任务中的效果或者是表现。因为我们在实际应用中,很少会让你自己去从头训练一个bert。一般都是用训练好的,就是大公司放出来的bert,然后我们自己在自己的任务中做一些微调。

很多朋友的做法都是,先获取谷歌中文或者是其它公司的bert,然后基于自己的任务数据去做微调。但是我们想要更好的性能的话,现在有很多tirck需要去做。

首先,我想提的第一点就是去做 Post training。

四步骤

比如做微博文本情感分析:

  1. 在大量通用预料上训练一个LM(pretrain);- 中文谷歌BERT
  2. 在相同领域上继续训练LM(Domain transfer); - 在大量微博文本上继续训练这个BERT
  3. 在任务相关的小数据上继续训练LM(Task transfer);- 在微博情感文本上(有的文本不属于情感分析的范畴)
  4. 在任务相关数据上做具体任务(Fine-tune)。

一般经验是,先做Domain transfer,再进行 Task transfer,最后Fine-tune 性能是最好的。

如何再相同领域数据中进行further pre-training

  1. 动态mask:就是每次epoch去训练的时候mask,而不是一直使用同一个。

    bert在训练的时候使用的是固定的mask,就是把文本mask之后存在本地,然后每次训练的时候都是使用同一个文件,也就是说每次训练的时候我们使用的都是同样的mask标志。比如之前的例子【我爱吃饭】,每次训练的时候都是mask掉了这个”吃“,这样其实不太好。然后动态mask呢,就是每个epoch训练之前,去对数据进行mask。
    刚才说bert一直使用同一套mask,也不太准确,它是有做一些改进,他有复制一些文本,大家具体去看一下论文。

  2. n-gram mask:其实比如ERINE 和 SpanBert都是类似于做了实体词的mask。

    我们可以退一步,就是如果你自己训练的时候,你没有特别准确的实体词,你可以不做实体词的mask,你可以做n-gram mask.

我们在做的时候参数一定要设置得特别的好,Batch size其实16,32,64,128影响不太大;Learning rate(Adam)5e-5,3e-5,2e-5,尽可能小一点避免灾难性遗忘;在微调的时候number of epochs,一般是3、4个,一般不会太大;weighted decay修改后的adam,使用warmup,搭配线性衰减,这个是比较重要的;

还有就是比如在预训练的时候做数据增强(一些简单的EDA)、自蒸馏、外部知识的融入(比如融入知识图谱的知识,或者加一些实体词的信息),这些都可以,不过比较吃机器。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

BERT从零详细解读:如何微调BERT,提升BERT在下游任务中的效果 的相关文章

随机推荐

  • CMake入门教程:使用target_include_directories指定头文件目录

    CMake入门教程 使用target include directories指定头文件目录 在进行软件开发时 我们经常需要引用一些外部库或模块的头文件以便使用其功能 CMake是一个强大的跨平台构建工具 能够帮助我们管理项目并生成相应的构建
  • WSL2 局域网访问以及hosts注意事项

    说明 WSL2用的是NAT方式 虚拟机有内部的ip 所以访问虚拟机可用代理访问方法 要点 根据微软文档 powershell 下做端口转发代理 netsh interface portproxy add v4tov4 listenport
  • 双引号后面要加句号吗_小学二年级老师容易疏忽的一个知识点:冒号和双引号...

    标点符号是特殊的文字 使用得当 会为文章增色不少 同时也是考试丢分的一个知识点 应引起师生重视 到了小学二年级 必须学会使用冒号和双引号 冒号 是常用的标点符号之一 通常表示提示语后的停顿或表示提示下文或总结上文 它用在提示语的后面 如果老
  • 用Python写一个比大小的小游戏(代码解释)

    代码解释 游戏 猜数字 玩法 程序会随机生成一个1 30的数字 玩家有无限次 机会去猜这个数字程序会告诉你是大了还是小了 在最后猜中的时候程序会告诉你猜中了 并且告诉你结束游戏以及猜中该数字所花费的次数 代码 Python import r
  • PyTorch的官方bug:torch.optim.lr_scheduler.CosineAnnealingWarmRestarts

    torch optim lr scheduler CosineAnnealingWarmRestarts 低版本 如torch1 7 1 指定last epoch参数时报错 已有人反馈指出 升级torch1 11 0可以解决该问题 升级之后
  • Python数据可视化——图型参数介绍

    前言 利用Python 绘制常见的统计图形 例如条形 图 饼图 直方图 折线图 散点图等 通过这些常用图形的展现 将 复杂的数据简单化 这些图形的绘制可以通过matplotlib 模块 pandas 模 块或者 seaborn 模块实现 饼
  • java 垃圾回收 sys_深入理解Java虚拟机学习笔记2.1-G1垃圾回收

    G1 GC是Jdk7的新特性之一 Jdk7 版本都可以自主配置G1作为JVM GC选项 作为JVM GC算法的一次重大升级 DK7u后G1已相对稳定 且未来计划替代CMS 所以有必要深入了解下 不同于其他的分代回收算法 G1将堆空间划分成了
  • springmvc中的resolveView(视图解析器)

    视图解析器接口只有一个方法 就是根据名称解析出视图信息 一个视图对象View 采用的是模板模式 抽象模板类 AbstractCachingViewResolver 主要处理缓存 如果视图对象在缓存中有 则从缓存中取 如果没有则创建 publ
  • 整理最全的图床集合——三千图床

    2021 09 25 更新 去除部分图床 添加新的图床 优化排版 引言 古有弱水三千 今有三千图床 勿埋我心 图床一般是指储存图片的服务器 有国内和国外之分 国外的图床由于有空间距离等因素决定访问速度很慢影响图片显示速度 国内也分为单线空间
  • remote: HTTP Basic: Access deniedfatal: Authentication failed for ‘xxxxx‘的问题解决

    在没有修改git密码的情况下 使用vs code推送代码 总是会报错 remote HTTP Basic Access denied fatal Authentication failed for xxxxxxxx git仓库地址 网上试了
  • YOLOV7开源代码讲解--训练参数解释

    目录 训练参数说明 weights cfg data hpy epoch batch size img size rect resume nosave notest noautoanchor evolve bucket cach image
  • 【Basis】狄利克雷分布

    初次看狄利克雷分布 比较懵 主要是它有很多先行知识 所以我先介绍狄利克雷分布用到的多项式分布 gamma 函数 beta分布 然后再介绍狄利克雷分布 参考文献见文章末 目录 一 多项式分布 multinomial distribution
  • 仅仅是一张照片就是不能刷脸支付的

    科技改变未来并不是一句口号 就拿买东西来讲 以前人们都是一手交钱一手交货 拿到大额的纸币 还要验真假 而现在移动支付成为主要付款方式 只要一部手机 扫一扫就能付款 一开始也有很多人不习惯手机支付 因为觉得没有现金实在 整天就是一堆数字转来转
  • 解决TypeError: 'function' object is not subscriptable

    一 解决问题 在tensorflow中使用零矩阵初始化变量的时候出现的该异常 TypeError function object is not subscriptable 二 解决方法 问题代码如下 bias tf Variable tf
  • 深度学习(9):Inception危险物品检测

    目标 基于Inception网络实现对危险物品检测 将危险物品图片或视频经过图像预处理后输入模型推理 最后将检测结果进行可视化输出 一 原理 Google的Inception网络介绍 Inception为Google开源的CNN模型 至今已
  • Java的变量

    1 Java 变量类型 答 在Java语言中 所有的变量在使用前必须声明 声明变量的基本格式如下 type identifier value identifier value 格式说明 type为Java数据类型 identifier是变量
  • Java实现生成csv文件并导入数据

    一 需求 下载列表 在没有过滤之前下载列表所有数据 点击过滤之后 下载过滤之后对数据 生成csv文件 二 思路 先根据条件 是否过滤了数据 筛选出数据 将数据导入csv文件 生成文件并返回 三 代码实现 1 controller层 文件下载
  • Gbase 8s存储结构简介及空间管理

    Gbase 8s实例可以创建多个dbspace 一个dbspace可以包含多个物理chunk 一个chunk分成多个连续扩展区extent 一个表或者索引占用的空间被称为一个tablespace 一个extent包含多个物理页page 其中
  • 利用多线程来实现一个简单的服务器,来实现处理多个用户的请求

    服务器来实现接受多个客户的请求 并且处理响应 服务器采用了多线程 代码如下服务器 package cn kgc basic tcpthread import java io IOException import java net Serve
  • BERT从零详细解读:如何微调BERT,提升BERT在下游任务中的效果

    a 是句子对的分类任务 b 是单个句子的分类任务 c 是问答任务 d 是序列标注任务 首先我自己最常用的就是 文本分类 序列标注和文本匹配 这四个都是比较简单的 我们来看d 序列标注 其实就是把所有的token输出 做了一个softmax