用自己的数据增量训练预训练语言模型

2023-11-07

预训练模型给各类NLP任务的性能带来了巨大的提升,预训练模型通常是在通用领域的大规模文本上进行训练的。而很多场景下,使用预训练语言模型的下游任务是某些特定场景,如金融,法律等。这是如果可以用这些垂直领域的语料继续训练原始的预训练模型,对于下游任务往往会有更大的提升。

以BERT为例,利用huggingface的tranformers介绍一下再训练的方式:

1. 定义tokenizer

bert的预训练模式一般分为,Masked language model (MLM)与 next sentence prediction(NSP),主要利用MLM在自己的语料上进行预训练

from transformers import RobertaConfig,BertTokenizer
from transformers import BertForMaskedLM as Model
from transformers import MaskedLMDataset,Split
from transformers import DataCollatorForLanguageModeling
from transformers.trainer_utils import get_last_checkpoint
# 定义tokenizer
tokenizer = BertTokenizer.from_pretrained(retrained_bert_path, max_len=max_seq_length)

2. 定义预训练模型的参数

# 定义预训练模型的参数
config = RobertaConfig(
    vocab_size=tokenizer.vocab_size,
    max_position_embeddings=max_seq_length,
    num_attention_heads=12,
    num_hidden_layers=12,
    type_vocab_size=2,
)

预训练的模式为MLM,直接调用 DataCollatorForLanguageModeling API即可方便得以自己的语料定义生成器。

retrained_model = Model(config=config)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

3. 加载MLM训练数据

train_data = MaskedLMDataset(data_file=train_file,
                                 tokenizer=tokenizer,
                                 tag=train_tags,
                                 max_seq_length=max_seq_length,
                                 mode=Split.train,
                                 overwrite_cache=overwrite_cache)
train_data = [feature.convert_feature_to_dict() for feature in train_data]

 

4. 开始预训练

这里可以设置的参数有,输入端的batch_size、语料文件、tokenizer,训练过程方面则有 训练轮数epochs、batch_size 以及保存频率。经过这些简单的即可成功训练好一个基于MLM的bert模型了(损失loss降到0.5左右就可以了),也可以通过MLM模型所带的接口来做MLM预测,当然我们这里需要的只是bert的权重。

设置训练参数

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
        output_dir=out_model_path,
        overwrite_output_dir=True,
        num_train_epochs=train_epoches,
        per_device_train_batch_size=batch_size,
        save_steps=2000,
        save_total_limit=2,
        prediction_loss_only=True,
    )

训练

trainer = Trainer(
        model=retrained_model,
        args=training_args,
        train_dataset=train_data,
        data_collator=data_collator,
    )

last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is not None:
    train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
else:
    train_result = trainer.train()

保存模型

trainer.save_model()  # Saves the tokenizer too for easy upload

本篇介绍了 增量训练预训练语言模型的方法,下一篇将介绍fine-tunning再训练好的语言模型的使用方法

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

用自己的数据增量训练预训练语言模型 的相关文章

随机推荐

  • Java RMI 解析

    1 什么是RMI Java RMI 即 远程方法调用 Remote Method Invocation 一种用于实现远程过程调用 RPC Remote procedure call 的Java API 能直接传输序列化后的Java对象和分布
  • GD32替换STM32后 写片上闪存(flash)失败的解决方法

    目录 型号 问题 解决办法 下载gd的fmc操作库 修改fmc文件 使用 擦除一页 写一页 型号 使用的GD32C103CB等引脚替换STM32F103CB 问题 使用hal库的flash操作接口 片上flash可以正常擦除 但是无法写入
  • Flash Player 10 中的RTMFP协议(实现P2P技术)

    RTMFP是Adobe公司开发的一套新的通信协议 该协议可以让使用Adobe Flash Player的终端用户之间进行直接通信 用Adobe AIR框架开发的程序也可以用此协议来发布直播 实时信息 通过使用RTMFP 那些以来直播 实时通
  • 03多线程之间通讯

    线程之间的通信 一 为什么要线程通信 1 多个线程并发执行时 在默认情况下CPU是随机切换线程的 当我们需要多个线程来共同完成一件任务 并且我们希望他们有规律的执行 那么多线程之间需要一些协调通信 以此来帮我们达到多线程共同操作一份数据 2
  • linux内存文件系统

    写文件时 太耗内存的话 可以使用dma拷贝 或者使用内存文件系统的方式 但首先要搞清楚一点 正常的文件操作 多久会真正保存到磁盘中去呢 参考 浅谈Linux系统写磁盘机制 http blog sina com cn s blog 96757
  • mybatis通用mapper的Example查询

    mybatis的通用mapper 多用于单表查询 接口内部为我们提供了单表查询的基础查询语法 可以极大地帮助我们简化编程 接下来让我们动手试一试 我建的是springboot项目 先导依赖
  • 词云下载jieba成功后仍然报错

    下载jieba终端 pip install i https pypi tuna tsinghua edu cn simple jieba 成功下载后仍然报错 TransposedFont object has no attribute ge
  • 牛顿-拉夫逊法潮流计算matlab程序,牛顿—拉夫逊法潮流计算MATLAB程序.doc

    牛顿 拉夫逊法潮流计算程序By Yuluo 牛顿 拉夫逊法进行潮流计算 n input 请输入节点数 n n1 input 请输入支路数 n1 isb input 请输入平衡母线节点号 isb pr input 请输入误差精度 pr B1
  • python之struct详解

    python之struct详解 醉小义的博客 CSDN博客 python struct 尊重原创
  • Unity中,在按钮的处理事件中,显示UI(Panel)的一些问题

    问题来源 自己遇到的 32条消息 Unity SetActive True 滞后严重 游戏 CSDN问答 简单概括就是 点击按钮 开始处理某个事件 这个事件需要花费较长时间 我的想法是加入一个加载中界面 方便告知用户当前程序没有卡住 在完成
  • kodi刮削器 中文_手把手教你用Kodi,搭建最强私人娱乐/学习中心!(小白篇)...

    喜欢本篇内容请给我们点个在看 什么是KODI 简单的说 Kodi 就是一个功能强大且免费的媒体播放器 支持全平台 如Windows Linux iOS Android Xbox 以及树莓派等 可播放电影 电视剧 音乐 电视直播 电台等等 特
  • JS逆向解析---某知名小说网站内容加密

    该小说网站的全部内容都是经过一个JS的加密 要想爬取这个网站那么将其内容解析是不可避免的 本文将讲解如何对其进行JS的逆向解析 网站 shuqi 随便点开一本书 打开浏览器自带的抓包工具 点击第一个包 但是在这里找不到我们想要的数据 说明不
  • 实现ListView中每行显示进度条,并且各自显示自己的进度

    package com sagaware process list import java util ArrayList import java util HashMap import java util List import java
  • Web2.0网站一些通用业务采用NoSql的解决方案

    首先理解NoSql的划分 Often NoSQL databases are categorized according to the way they store the data and fall under categories su
  • MySQL生产环境高可用架构实战

    分布式技术MongoDB 1 MySQL高可用集群介绍 1 1 数据库主从架构与分库分表 1 2 MySQL主从同步原理 2 动手搭建MySQL主从集群 2 1 基础环境搭建 2 2 安装MySQL服务 2 2 1 初始化MySQL 2 2
  • 仿射密码 affine

    参考链接 https www cnblogs com 0yst3r 2046 p 12172757 html 仿射加密法 在仿射加密法中 字母表的字母被赋予一个数字 例如 a 0 b 1 c 2 z 25 仿射加密法的密钥为0 25直接的数
  • Incorrect integer value: '' for column 'id' at row 1 错误解决办法

    最近一个项目 在本地php环境里一切正常 ftp上传到虚拟空间后 当执行更新操作 我的目的是为了设置id为空 set id 时提示 Incorrect integer value for column id at row 1 解决办法 方法
  • 广工人福利,openwrt+gduth3c通过inode认证,妈妈再也不用担心我要用电脑开wifi了

    刚开校园网的时候 天天都只能用电脑开wifi 用类似于360wifi 猎豹wifi之类的软件要经常开着电脑 而且电脑网卡发射功率又小 上个厕所wifi就断了 睡觉前在床上还没wifi用 超级不爽 于是从家里面拿来了放在自己房间挂迅雷百度云的
  • x86下的C函数调用惯例

    1 从汇编到C 1 1 汇编语言的局限性 汇编语言是一种符号化了的机器语言 machine code 即用指令助记符 符号地址 标号等符号书写程序的语言 汇编语句与机器语句一一对应 它只是把每条指令及数据用便于记忆的符号书写而已 汇编语言
  • 用自己的数据增量训练预训练语言模型

    预训练模型给各类NLP任务的性能带来了巨大的提升 预训练模型通常是在通用领域的大规模文本上进行训练的 而很多场景下 使用预训练语言模型的下游任务是某些特定场景 如金融 法律等 这是如果可以用这些垂直领域的语料继续训练原始的预训练模型 对于下