复现BART finetune历程

2023-10-29

复现BART finetune历程

准备

  • 安装fairseq,使用fairseq官方提供的finetune代码

    git clone https://github.com/pytorch/fairseq
    cd fairseq
    pip install --editable ./
    
  • 下载Xsum与DailyCNN数据集,已处理为train.source等形式。解压保存在/home/DataSets/Xsum和/home/DataSets/DailyCNN

    https://github.com/huggingface/transformers/blob/master/examples/seq2seq/README.md
    
  • 下载官方release的bart_large模型,解压保存至/home/LM/bart_large

    https://github.com/pytorch/fairseq/tree/master/examples/bart
    
  • 安装files2rouge使用paper使用的ROUGE计算方法

    git clone https://github.com/pltrdy/files2rouge.git     
    cd files2rouge
    python setup_rouge.py
    python setup.py install
    

    在Linux系统安装前,修改setup.py文件第29行。

        install_requires=[
    	"pyrouge==0.1.3"
        ],
    

    若安装后运行时出现BUG

    TypeError: __init__() got an unexpected keyword argument 'log_level'
    

    则使用命令手动安装pyrouge

    pip install -U git+https://github.com/pltrdy/pyrouge
    

预处理

数据预处理

  • BPE分词,使用bart_large 模型的词典进行分词

    wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
    wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
    wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
    
    TASK=Xsum
    for SPLIT in train val
    do
      for LANG in source target
      do
        python -m examples.roberta.multiprocessing_bpe_encoder \
        --encoder-json encoder.json \
        --vocab-bpe vocab.bpe \
        --inputs "$TASK/$SPLIT.$LANG" \
        --outputs "$TASK/$SPLIT.bpe.$LANG" \
        --workers 60 \
        --keep-empty;
      done
    done
    

    bart_large

  • Binarize dataset.MODE可以切换使用bart_base或者bart_large模型的词典

    MODE=large
    TASK=Xsum
    fairseq-preprocess \
      --source-lang "source" \
      --target-lang "target" \
      --trainpref "${TASK}/train.bpe" \
      --validpref "${TASK}/val.bpe" \
      --destdir "${TASK}-${MODE}-bin/" \
      --workers 60 \
      --srcdict dict.txt \
      --tgtdict dict.txt;
    

训练

  • finetune和官方提供版本做了相应调整,但超参数设置保持一致。如果使用1块GPU卡进行训练,TOTAL_NUM_UPDATES, WARMUP_UPDATES, 和UPDATE_FREQ分别再*8

    DailyCNN使用16G Tesla V-100进行单卡finetune时MAX_TOKENS只能设为1024否则爆显存。

    #!/bin/bash
    export PYTHONUNBUFFERED=1 
    
    MODE=large
    TASK=/home/DataSets/Xsum
    TOTAL_NUM_UPDATES=15000
    WARMUP_UPDATES=500   
    LR=3e-05
    MAX_TOKENS=2048
    UPDATE_FREQ=2
    BART_PATH=/home/LM/bart_${MODE}/model.pt
    
    
    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train $TASK-${MODE}-bin \
        --restore-file $BART_PATH \
        --reset-optimizer --reset-dataloader --reset-meters \
        --save-dir Xsum_checkpoints_${MODE} \
        --max-tokens $MAX_TOKENS \
        --task translation \
        --source-lang source --target-lang target \
        --truncate-source \
        --layernorm-embedding \
        --share-all-embeddings \
        --share-decoder-input-output-embed \
        --required-batch-size-multiple 1 \
        --arch bart_${MODE} \
        --criterion label_smoothed_cross_entropy \
        --label-smoothing 0.1 \
        --dropout 0.1 --attention-dropout 0.1 \
        --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
        --clip-norm 0.1 \
        --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
        --fp16 --update-freq $UPDATE_FREQ \
        --skip-invalid-size-inputs-valid-test \
        --no-epoch-checkpoints \
        --find-unused-parameters;
    
    

评测

以Xsum为例,使用单卡训练在4个epoch后大约就得到最佳checkpoint。

使用

import torch
from fairseq.models.bart import BARTModel
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default="Xsum", type=str, required=False, help='path of pre-trained model files')
parser.add_argument('--model_size', default="large", type=str, required=False, help='预测结果提交文件')
parser.add_argument('--file_prefix', default="large_model", type=str, required=False, help='预测结果提交文件')
args = parser.parse_args()


result_path = 'results/'+args.dataset+'/'+args.model_size
checkpoint_path = './'+args.dataset+'_checkpoints_'+args.model_size
#checkpoint_path = '/home/LM/bart-large-xsum' #下载官方finetune好的ckpt进行再次验证
print(args.dataset)
print(args.model_size)
print(result_path)
'''
bart = BARTModel.from_pretrained(
    '/home/LM/bart_'+args.model_size+'_pt',
    checkpoint_file='model.pt',
    data_name_or_path=result_path,
    task='translation',
    source_lang = "source",
    target_lang = "target",
)
'''
bart = BARTModel.from_pretrained(
    checkpoint_path,
    checkpoint_file='checkpoint_best.pt',
    #checkpoint_file = 'model.pt', # 载官方finetune好的ckpt进行再次验证
    data_name_or_path=result_path,
    #task='translation',
    #source_lang = "source",
    #target_lang = "target",
)

bart.cuda()
bart.eval()
bart.half()
count = 1
bsz = 32
with open('/home/DataSets/'+args.dataset+'/test.source',encoding="utf8") as source, open(result_path+'/'+args.file_prefix+'_test.hypo', 'w',encoding = "utf8") as fout:
    sline = source.readline().strip()
    slines = [sline]
    for sline in source:
        if count % bsz == 0:
            with torch.no_grad():
                #print(slines,"\n\n\n\n\n\n")
                hypotheses_batch = bart.sample(slines, beam=6, lenpen=1.0, max_len_b=60, min_len=10, no_repeat_ngram_size=3)
                #print(hypotheses_batch)
                #exit()
            for hypothesis in hypotheses_batch:
                fout.write(hypothesis + '\n')
                fout.flush()
            slines = []

        slines.append(sline.strip())
        count += 1
    if slines != []:
        hypotheses_batch = bart.sample(slines, beam=6, lenpen=1.0, max_len_b=60, min_len=10, no_repeat_ngram_size=3)
        for hypothesis in hypotheses_batch:
            fout.write(hypothesis + '\n')
            fout.flush()

进行测试,并将结果保存在 ./result/Xsum/large下。对于DailyCNN数据,使用

beam=4, lenpen=2.0, max_len_b=140, min_len=55

之后使用files2rouge计算ROUGE得分比较Average_F

$ sudo files2rouge xsum_test.target xsum_my_finetuned_model.hypo 
---------------------------------------------
1 ROUGE-1 Average_R: 0.49389 (95%-conf.int. 0.49098 - 0.49668)
1 ROUGE-1 Average_P: 0.39964 (95%-conf.int. 0.39682 - 0.40230)
1 ROUGE-1 Average_F: 0.43521 (95%-conf.int. 0.43248 - 0.43775)
---------------------------------------------
1 ROUGE-2 Average_R: 0.23021 (95%-conf.int. 0.22718 - 0.23331)
1 ROUGE-2 Average_P: 0.18480 (95%-conf.int. 0.18230 - 0.18739)
1 ROUGE-2 Average_F: 0.20179 (95%-conf.int. 0.19915 - 0.20462)
---------------------------------------------
1 ROUGE-L Average_R: 0.38980 (95%-conf.int. 0.38671 - 0.39278)
1 ROUGE-L Average_P: 0.31509 (95%-conf.int. 0.31253 - 0.31770)
1 ROUGE-L Average_F: 0.34325 (95%-conf.int. 0.34069 - 0.34584)


$ sudo files2rouge xsum_test.target using_offical_model_test.hypo
---------------------------------------------
1 ROUGE-1 Average_R: 0.49030 (95%-conf.int. 0.48742 - 0.49314)
1 ROUGE-1 Average_P: 0.41509 (95%-conf.int. 0.41240 - 0.41798)
1 ROUGE-1 Average_F: 0.44299 (95%-conf.int. 0.44036 - 0.44571)
---------------------------------------------
1 ROUGE-2 Average_R: 0.23270 (95%-conf.int. 0.22979 - 0.23559)
1 ROUGE-2 Average_P: 0.19613 (95%-conf.int. 0.19347 - 0.19868)
1 ROUGE-2 Average_F: 0.20964 (95%-conf.int. 0.20691 - 0.21222)
---------------------------------------------
1 ROUGE-L Average_R: 0.38979 (95%-conf.int. 0.38700 - 0.39267)
1 ROUGE-L Average_P: 0.32982 (95%-conf.int. 0.32691 - 0.33261)
1 ROUGE-L Average_F: 0.35206 (95%-conf.int. 0.34932 - 0.35488)

在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

复现BART finetune历程 的相关文章

随机推荐

  • vue自定义指令---页面水印

    一些页面为了防止用户截图 可以添加水印 下面介绍以下思路 主要是创建一个新的节点作为水印 设置好水印的样式以后 再添加到目标节点上面去 水印 export default inserted el text let dom document
  • bootstrap导航栏鼠标移入展开

    bootstrap鼠标移入导航展开下拉菜单则加以下jq代码 function dropdown mouseover function this children a addClass show next ul addClass show d
  • 记录第一次部署streamlit应用

    网上相关教程很多 经过多方尝试 记录自己成功的方法 一 通过git将项目文件上传至github 参考教程 23条消息 部署项目到github Gao 的博客 CSDN博客 github部署项目 二 添加requirements 部署在Str
  • 计算机主机mac地址怎么查,怎么查看电脑的Mac地址

    第一种方法 利用dos命令查看Mac地址 1 点击 开始 菜单 在 搜索程序和文件 输入框 输入 cmd 会找到进入dos命令的cmd程序 然后回车 快捷方式 WIN R 在输入cmd 2 回车后 弹出命令符窗口 输入 ipconfig a
  • 技术积累 — Ellisys软件及抓包器用户使用指南

    一 前言 Ellisys号称是业界最先进的蓝牙 Wi Fi USB协议分析仪 支持低功耗蓝牙协议分析测试 支持蓝牙5低功耗以及Wi Fi的物联网应用 支持与原始频谱 UART SPI HCI 逻辑信号等同步的宽带蓝牙5低能耗BLE Wi F
  • [Linux]-进程间通信之消息队列

    目录 消息队列的概述 消息队列的API 1 获取系统唯一Key值 IPC键值 2 创建消息队列 2 1查看消息队列的一些Linux命令 2 2消息队列的创建 3 消息的发送以及定义 3 1 通过消息队列发送信息 4 信息的接收 5 通过消息
  • Codeforces Round #660 (Div. 2)1388C - Uncle Bogdan and Country Happiness (好题,条件判断,DFS)

    题目大意 国家有N个城市 1号城市为首都 有M个国民 每个国民都在首都工作 晚上返回家中 给定每个城市有多少国民居住 每个城市都有一个心情检测器 当国民经过城市时 心情检测器根据国民的心情加减1 但是心情检测器并不精确 所以要求你去判断在所
  • 汽车电子相关术语

    SOA SOA Service Oriented Architecture 面向服务的架构 是一种在计算机环境中设计 开发 部署和管理离散模型的方法 是由Garnter1996年提出的概念 将应用程序的不同功能单元 称为服务 进行拆分 并通
  • NeRF论文翻译笔记

    分享 NeRF神经辐射场理解 深兰深延AI的博客 CSDN博客 神经辐射场 githubNeRF总结 https github com yenchenlin awesome NeRF 目录 摘要 1 介绍 2 相关工作 2 1 神经三维形状
  • ModuleNotFoundError: No module named 'exceptions'

    ModuleNotFoundError No module named exceptions 意味着你在你的代码中尝试使用了一个名为 exceptions 的模块 但是你的程序运行环境中找不到这个模块 这可能是因为这个模块没有安装 或者是你
  • MPC学习记录

    参考 无人驾驶车辆模型预测控制 第二版 第四章详细学习 算法部分 总系学不废的博客 CSDN博客 控制 模型预测控制MPC08 01总结修正 105664978 哔哩哔哩 bilibiliMPC 3 常用车辆模型 MATLAB 无人驾驶车辆
  • 用python计算工程量_使用python计算vintage

    coding utf 8 Created on Mon Jan 14 18 57 19 2019 author hinnc importnumpy as npimportpandas as pd from pandas tseries of
  • Python爬虫实战

    在本篇博客中 我们将使用Scrapy框架完成一个入门爬虫程序 在命令行创建scrapy项目 首先在命令行进入PyCharm的项目目录 然后执行 scrapy startproject 项目名 如ScrapyExample 生产爬虫项目 会自
  • eclipse如何安装server

    在eclipse中想添加配置server的是否 发现Preference目录里并没有Server这个选项 也就是说 我们并没有办法新建服务器 所以要安装一个server 1 eclipse help Install New Software
  • Java获取当前电脑的ip地址

    import java net Inet4Address import java net InetAddress import java net UnknownHostException author guochao version 1 0
  • 一文玩转pytorch转onnx-tensorRT ——(A)onnx转tensorRT

    说明 onnx和tensorRT是分开的 onnx像是prototxt和weight的打包在一起的东西 所以由onnx转到tensorRT下 还需要让onnx能搜索到 或parsing 所对应的层 caffeparsing有注册自定义层的函
  • C# 中的依赖注入模式

    依赖注入模式 DI 首先 依赖注入模式 是一种软件设计模式 它被称为 模式 因为它建议针对特定问题的低级特定实现 该模式旨在解决的主要问题是如何创建 松散耦合 的组件 它通过将组件的创建与其依赖项分开来实现这一点 此模式中有四个主要角色 类
  • Bitbucket入门手册

    老大要我去调研一下有什么好用的免费软件版本管理工具 有利于小团队开发的 我第一个想到的就是git 经常在git下东西 听说它的代码仓库好用 于是就注册了一个github的账号 创建仓库的时候才发现只能创建开源项目 私有仓库要收费 于是就在网
  • pyglet 绝对路径 相对路径

    加载绝对路径 加载 3D 模型文件 model path path to model obj model pyglet resource file model path import pyglet window pyglet window
  • 复现BART finetune历程

    复现BART finetune历程 准备 安装fairseq 使用fairseq官方提供的finetune代码 git clone https github com pytorch fairseq cd fairseq pip instal