Huggingface训练Transformer

2023-11-20

在之前的博客中,我采用SFT(监督优化训练)的方法训练一个GPT2的模型,使得这个模型可以根据提示语进行回答。具体可见博客召唤神龙打造自己的ChatGPT_gzroy的博客-CSDN博客

Huggingface提供了一个TRL的扩展库,可以对transformer模型进行强化学习,SFT是其中的一个训练步骤,为此我也测试一下如何用Huggingface来进行SFT训练,和Pytorch的训练方式做一个比较。

训练数据

首先是获取训练数据,这里同样是采用Huggingface的chatbot_instruction_prompts的数据集,这个数据集涵盖了不同类型的问答,可以用作我们的模型优化之用。

from datasets import load_dataset

ds = load_dataset("alespalla/chatbot_instruction_prompts")
train_ds = ds['train']
eval_dataset = ds['test']
eval_dataset = eval_dataset.select(range(1024))

训练集总共包括了258042条问答数据,对于验证集我只选取了头1024条记录,因为总的数据集太长,如果在训练过程中全部验证的话耗时太长。

加载GPT2模型

Huggingface提供了很多大模型的训练好的参数,这里我们可以直接加载一个已经训练好的GPT2模型来做优化

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

然后我们可以定义TRL提供的SFTTrainer来进行训练,首先需要对训练数据处理一下,因为训练数据包括了两列,分别是prompt和response,我们需要把两列的文本合为一起,通过格式化字符来区分,如以下格式化函数:

def formatting_func(example):
    text = f"### Prompt: {example['prompt']}\n ### Response: {example['response']}"
    return text

定义SFTTrainer的训练参数,具体每个参数的含义可见官网的文档:

args = TrainingArguments(
    output_dir='checkpoints_hf_sft',
    overwrite_output_dir=True, 
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    fp16=True,
    torch_compile=True,
    evaluation_strategy='steps',
    prediction_loss_only=True,
    eval_accumulation_steps=1,
    learning_rate=0.00006,
    weight_decay=0.01,
    adam_beta1=0.9,
    adam_beta2=0.95,
    warmup_steps=1000,
    eval_steps=4000,
    save_steps=4000,
    save_total_limit=4,
    dataloader_num_workers=4,
    max_steps=12000,
    optim='adamw_torch_fused')

最后就可以定义一个trainer来训练了

trainer = SFTTrainer(
    model,
    args = args,
    train_dataset=dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    packing=True,
    formatting_func=formatting_func,
    max_seq_length=1024
)

trainer.train(resume_from_checkpoint=False)

因为是第一次训练,我设置了resume_from_checkpoint=False,如果是继续训练,把这个参数设为True即可从上次checkpoint目录自动加载最新的checkpoint来训练。

训练结果如下:

 [12000/12000 45:05, Epoch 0/1]

Step Training Loss Validation Loss
4000 2.137900 2.262321
8000 2.187800 2.232235
12000 2.218500 2.210413

总共耗时45分钟,比我在pytorch上的训练要快一些(快了10分钟多一些),但是这个训练集的Loss随着Step的增加反而增加了,Validation Loss就减少了,有些奇怪。在Pytorch上我同样训练12000个迭代,最后training loss是去到1.8556的。可能还要再调整一下trainer的参数看看。

测试

最后我们把SFT训练完成的模型,通过huggingface的pipeline就可加载进行测试了。

from transformers import pipeline

model = AutoModelForCausalLM.from_pretrained('checkpoints_hf_sft/checkpoint-12000/')
pipe = pipeline(task='text-generation', model=model, tokenizer=tokenizer, device=0)

pipe('### Prompt: Who is the current president of USA?')
[{'generated_text': '### Prompt: Who is the current president of USA?\n ### Response: Harry K. Busby is currently President of the United States.'}]

回答的语法没问题,不过内容是错的。

再测试另一个问题

### Prompt: How to make a cup of coffee?

[{'generated_text': '### Prompt: How to make a cup of coffee?\n ### Response: 1. Preheat the oven to 350°F (175°C).\n\n2. Boil the coffee beans according to package instructions.\n\n3. In a'}]

这个回答就正确了。

总结

通过用Huggingface可以很方便的对大模型进行强化学习的训练,不过也正因为太方便了,很多训练的细节被包装了,所以训练的结果不太容易优化,不像在Pytorch里面控制的自由度更高一些。当然可能我对huggingface的trainer参数的细节还不太了解,这个有待后续继续了解。

另外我还发现huggingface的一个小的bug,就是模型从头训练的时候没有问题,但是当我从之前的checkpoint继续训练时,会报CUDA OOM的错误,从nvidia-smi命令看到的显存占用率来看,好像trainer定义模型和装载Checkpoint会重复占用了显存,因此同样的batch_size,在继续训练时就报内存不够了,这个也有待后溪继续了解。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Huggingface训练Transformer 的相关文章

随机推荐

  • android10官方支持机型,Andorid10.0支持哪些手机?附安卓10支持机型介绍

    Andorid10 0支持哪些手机 附安卓10支持机型介绍 2019 03 12 09 58 24 来源 qqtn com 扫码可以 1 在手机上浏览 2 分享给微信好友或朋友圈 摘要 Andorid10支持手机型号目前正在逐渐公布 许多小
  • c#窗体开发俄罗斯方块小游戏

    在个人电脑日益普及的今天 一些有趣的桌面游戏已经成为人们在使用计算机进行工作或学习之余休闲娱乐的首选 而俄罗斯方块游戏是人们最熟悉的小游戏知益 它趣味性极强 变化无穷 易上手等诸多特点得到了大众的认可 此外对运动的方块进行组合 可以训练玩家
  • Session&Cookie&token

    一 Session 什么是Session 服务器为了保存用户状态而创建的一个特殊的对象 当浏览器第一次访问服务器时 服务器创建一个session对象 该对象有一个唯一的id 一般称之为sessionId 服务器会将sessionId以coo
  • 原创的 20 个 Python 自动化案例,一口一个,高效办公!

    导读 大家好 自从 发布第一篇 Python 办公自动化办公系列文章以来 目前已经马不停蹄的更新了 20 个案例 累计阅读超 10W 为了方便大家阅读学习 我将这二十个案例再次进行分类汇总 内容涵盖 Python 操作Word Excel
  • electron添加SQLite数据库

    序 在之前 我曾经使用electron开发过一个番茄钟应用 但是当时的应用数据存储是在JSON文件当中 通过node的fs文件系统进行读写的 但是感觉不用数据库总有点不太专业 所以还是打算使用数据库来作为存储的地方 不过数据库的选择就多了
  • html效果总结记录

    自动换行 style word break break all 改变布局或者显示文字 在html中插入如下类似代码
  • html5 图片下拉加载动画效果,HTML5和CSS3炫酷彩色loading加载动画特效

    这是一款HTML5和CSS3炫酷彩色loading加载动画特效 该loading加载动画特效共15种动画效果 它们分别通过div盒子或svg元素 配合CSS3来制作loading动画效果 使用方法 HTML结构 第一种loading加载动画
  • MySQL基本操作四:数据的查询

    之前我介绍了MySQL中 数据记录的增 删 改操作 本文我们看查询操作 为方便后面举例 还是先建立一个表 并插入一些数据 我在这里依旧建立一个学生信息表 建表的代码如下 CREATE TABLE tab student StuID CHAR
  • linux docker 基础命令

    docker常用命令 docker ps 显示所有的操作命令集合列表 docker login container registry oracle com 登录官方远程docker私有仓库 docker login username tes
  • spss练习数据_SPSS篇——数据的筛选

    前言 我们知道 在做数据分析时 经常要对数据库中的个案进行选择性的分析 由此排除一些数据 那么该如何利用SPSS进行数据的筛选呢 Step 1 打开SPSS 打开 数据 选择个案 Step 2 选择个案 对话框右侧 选择 栏有5个选项 所有
  • 云计算入门基础命令行

    严重声明 本人支持一切正规软件开发行为 接受知识付费理念 并坚决抵制盗版行为 用于学习交流的非盈利目的的 且法律允许且支持的条件下 可以进行相关文件交流 他人利用交流文件进行非法售卖等一切违法犯罪行为 本人概不负责 分享的网页链接能保证截止
  • Dynamics CRM online 添加附件

    1 创建一个字段 类型为文件 2 通过PoweApps将该字段添加到Form表单上
  • ESP8266上电时串口打印乱码原因和发送AT指令串口返回信息含义

    因为esp8266模块上电时 默认打印波特率为74880 其固件中通信与串口默认为115200 所以如果把串口设置为74880 之后再上电模块打印的信息就不会有乱码了 但AT指令默认通信波特率是115200 如果在使用AT通信时 在调到11
  • Docker 容器迁移

    1 把容器打包成镜像 docker commit m 描述 a 作者 CONTAINER ID 新的镜像名 docker commit m my rabbitmq a eric a922049125c4 rabbitmq my 1 0 2
  • QT 改编官方example QQuickView实现QQuickWidgt示波器

    先找到官方实例 将官方实例的QML文件全部搞到自己的资源中 pro中加入这个 DISTFILES 刚刚的QML文件路径 qml 还有这个 QT core gui quickwidgets QT qml 将官方的datasource cpp
  • k8s部署Prometheus抓取pods的metrics

    1 暴露pods给Prometheus抓取 spec replicas app replicas template metadata annotations prometheus io scrape true prometheus io p
  • centos7没有ens33网卡的解决方案

    1 首先设置在系统启动时激活网卡 vim etc sysconfig network scripts ifcfg ens33 将ONBOOT no改为ONBOOT yes 2 执行下面的命令 此时ifconfig 显示有了ens33网卡 但
  • 【Linux】Libevent库

    Libevent 高性能I O框架库 底层封装了select poll epoll 便于使用 I O框架库以库函数的形式 封装了较为底层的系统调用 给应用程序提供了一组更便于使用的接口 特点 1 跨平台 2 统一事件源 I O事件 信号 定
  • Jeesite框架实用 前端页面展示表如何插入其他表数据

    目录 前言 问题 解决方法 第一步 第二步 第三步 前言 最近做类似于OA产品时选用了Jeesite框架 也学习有一段时间 这个框架的初级操作嘛就是 设计好表然后用代码生成器生成 一张表生成一个前端页面显示 问题 代码生成器根据一张数据库表
  • Huggingface训练Transformer

    在之前的博客中 我采用SFT 监督优化训练 的方法训练一个GPT2的模型 使得这个模型可以根据提示语进行回答 具体可见博客召唤神龙打造自己的ChatGPT gzroy的博客 CSDN博客 Huggingface提供了一个TRL的扩展库 可以