轻量微调和推理stanford_alpca

2023-10-26

当前的Alpaca模型是在Self-Instruct论文中使用的技术生成的52K条指令数据,从7B LLaMA模型微调而来,并进行了一些修改。

A10 gpu显存:22G,cu117,驱动470.103.01

absl-py                  1.4.0
accelerate               0.18.0
addict                   2.4.0
aenum                    3.1.12
aiofiles                 23.1.0
aiohttp                  3.8.4
aiosignal                1.3.1
albumentations           0.4.3
altair                   4.2.2
antlr4-python3-runtime   4.9.3
anyio                    3.6.2
appdirs                  1.4.4
asttokens                2.2.1
async-timeout            4.0.2
attrs                    22.2.0
backcall                 0.2.0
basicsr                  1.4.2
bcrypt                   4.0.1
beautifulsoup4           4.12.1
blendmodes               2022
blinker                  1.6
boltons                  23.0.0
braceexpand              0.1.7
cachetools               5.3.0
certifi                  2022.12.7
cffi                     1.15.1
chardet                  4.0.0
charset-normalizer       3.1.0
clean-fid                0.1.29
click                    8.1.3
clip-anytorch            2.5.2
cmake                    3.26.1
comm                     0.1.3
contourpy                1.0.7
cryptography             40.0.1
cssselect2               0.7.0
cycler                   0.11.0
datasets                 2.11.0
debugpy                  1.6.7
decorator                5.1.1
deprecation              2.1.0
diffusers                0.15.0.dev0
dill                     0.3.6
docker-pycreds           0.4.0
einops                   0.4.1
entrypoints              0.4
executing                1.2.0
facexlib                 0.2.5
fastapi                  0.94.0
ffmpy                    0.3.0
filelock                 3.10.7
filterpy                 1.4.5
fire                     0.5.0
font-roboto              0.0.1
fonts                    0.0.3
fonttools                4.39.3
frozenlist               1.3.3
fsspec                   2023.3.0
ftfy                     6.1.1
future                   0.18.3
gdown                    4.7.1
gfpgan                   1.3.8
gitdb                    4.0.10
GitPython                3.1.30
google-auth              2.17.2
google-auth-oauthlib     1.0.0
gradio                   3.16.2
grpcio                   1.53.0
h11                      0.12.0
httpcore                 0.15.0
httpx                    0.23.3
huggingface-hub          0.15.1
idna                     2.10
imageio                  2.9.0
imageio-ffmpeg           0.4.2
imgaug                   0.2.6
importlib-metadata       6.1.0
inflection               0.5.1
ipykernel                6.23.1
ipython                  8.13.2
jedi                     0.18.2
Jinja2                   3.1.2
joblib                   1.2.0
jsonmerge                1.8.0
jsonschema               4.17.3
jupyter_client           8.2.0
jupyter_core             5.3.0
kiwisolver               1.4.4
kornia                   0.6.7
lark                     1.1.2
lazy_loader              0.2
linkify-it-py            2.0.0
lit                      16.0.0
llvmlite                 0.39.1
lmdb                     1.4.0
lpips                    0.1.4
lxml                     4.9.2
Markdown                 3.4.3
markdown-it-py           2.2.0
MarkupSafe               2.1.2
matplotlib               3.7.1
matplotlib-inline        0.1.6
mdit-py-plugins          0.3.5
mdurl                    0.1.2
mpmath                   1.3.0
multidict                6.0.4
multiprocess             0.70.14
mypy-extensions          1.0.0
nest-asyncio             1.5.6
networkx                 3.1rc0
nltk                     3.8.1
numba                    0.56.4
numexpr                  2.8.4
numpy                    1.23.3
nvidia-cublas-cu11       11.10.3.66
nvidia-cuda-cupti-cu11   11.7.101
nvidia-cuda-nvrtc-cu11   11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11        8.5.0.96
nvidia-cufft-cu11        10.9.0.58
nvidia-curand-cu11       10.2.10.91
nvidia-cusolver-cu11     11.4.0.1
nvidia-cusparse-cu11     11.7.4.91
nvidia-nccl-cu11         2.14.3
nvidia-nvtx-cu11         11.7.91
oauthlib                 3.2.2
omegaconf                2.2.3
open-clip-torch          2.7.0
openai                   0.27.7
opencv-python            4.7.0.72
opencv-python-headless   4.7.0.72
orjson                   3.8.9
packaging                23.0
pandas                   1.5.3
paramiko                 3.1.0
parso                    0.8.3
pathtools                0.1.2
pexpect                  4.8.0
pickleshare              0.7.5
piexif                   1.1.3
Pillow                   9.4.0
pip                      23.0.1
platformdirs             3.5.1
prompt-toolkit           3.0.38
protobuf                 3.20.3
psutil                   5.9.4
ptyprocess               0.7.0
pudb                     2019.2
pure-eval                0.2.2
pyarrow                  11.0.0
pyasn1                   0.4.8
pyasn1-modules           0.2.8
pycparser                2.21
pycryptodome             3.17
pydantic                 1.10.7
pydeck                   0.8.0
pyDeprecate              0.3.1
pydub                    0.25.1
Pygments                 2.14.0
Pympler                  1.0.1
PyNaCl                   1.5.0
pyparsing                3.0.9
pyre-extensions          0.0.23
pyrsistent               0.19.3
PySocks                  1.7.1
python-dateutil          2.8.2
python-multipart         0.0.6
pytorch-lightning        1.7.6
pytz                     2023.3
pytz-deprecation-shim    0.1.0.post0
PyWavelets               1.4.1
PyYAML                   6.0
pyzmq                    25.1.0
realesrgan               0.3.0
regex                    2023.3.23
reportlab                3.6.12
requests                 2.25.1
requests-oauthlib        1.3.1
resize-right             0.0.2
responses                0.18.0
rfc3986                  1.5.0
rich                     13.3.3
rouge-score              0.1.2
rsa                      4.9
safetensors              0.2.7
scikit-image             0.19.2
scipy                    1.10.1
semver                   3.0.0
sentencepiece            0.1.99
sentry-sdk               1.19.0
setproctitle             1.3.2
setuptools               59.6.0
six                      1.16.0
smmap                    5.0.0
sniffio                  1.3.0
soupsieve                2.4
stack-data               0.6.2
starlette                0.26.1
streamlit                1.20.0
svglib                   1.5.1
sympy                    1.12rc1
tb-nightly               2.13.0a20230405
tensorboard              2.12.1
tensorboard-data-server  0.7.0
tensorboard-plugin-wit   1.8.1
termcolor                2.3.0
test-tube                0.7.5
tifffile                 2023.3.21
timm                     0.6.7
tinycss2                 1.2.1
tokenizers               0.12.1
toml                     0.10.2
toolz                    0.12.0
torch                    2.0.1
torchdiffeq              0.2.3
torchmetrics             0.11.4
torchsde                 0.2.5
tornado                  6.2
tqdm                     4.65.0
traitlets                5.9.0
trampoline               0.1.2
transformers             4.28.0.dev0     /mnt/workspace/demos/alpaca/transformers
triton                   2.0.0
typing_extensions        4.5.0
typing-inspect           0.8.0
tzdata                   2023.3
tzlocal                  4.3
uc-micro-py              1.0.1
urllib3                  1.26.15
urwid                    2.1.2
uvicorn                  0.21.1
validators               0.20.0
wandb                    0.14.0
watchdog                 3.0.0
wcwidth                  0.2.6
webdataset               0.2.5
webencodings             0.5.1
websockets               11.0
Werkzeug                 2.2.3
wheel                    0.37.1
xformers                 0.0.16rc425
xxhash                   3.2.0
yapf                     0.32.0
yarl                     1.8.2
zipp                     3.15.0

aplaca的显存要求是比较大的,目前来看基本要保证32G的显存,当然我们可以通过调整模型的结构大小来减小显存。

1.下载stanford_alpaca

!wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/alpaca/stanford_alpaca.tgz
!tar -xvf stanford_alpaca.tgz

2.安装依赖

!cd stanford_alpaca &&  echo y | pip uninstall torch &&  echo y | pip uninstall torchvision && pip install -r requirements.txt && pip install gradio

!git clone https://github.com/huggingface/transformers.git && \
cd transformers && \
git checkout 165dd6dc916a43ed9b6ce8c1ed62c3fe8c28b6ef && \
pip install -e .

3.数据准备 

数据格式如下,如需使用自己的数据进行微调可以转化成如下形式:
"instruction":用于描述模型应该执行的任务
"input" : 任务的可选上下文或输入。例如,当指令是“总结以下文章”时,输入就是文章。
"output" :需要模型输出的答案

格式如下
[
    {
        "instruction": "Give three tips for staying healthy.",
        "input": "",
        "output": "1.Eat a balanced diet and make sure to include plenty of fruits and vegetables. \n2. Exercise regularly to keep your body active and strong. \n3. Get enough sleep and maintain a consistent sleep schedule."
    }
]

# 下载数据集,如有重名文件,先将文件夹中的重名文件重命名。
!wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/alpaca/alpaca_data.json

4.微调模型

4.1 准备权重

llama-7B的权重大概有12G

!wget https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/alpaca/llama-7b-hf.tar.gz && tar -xvf llama-7b-hf.tar.gz

4.2 参数调节

可以对参数进行微调以适应显存,可以修改部分参数来保证在较小显存和单卡上也可以测试,根据预训练路径找到对应的config.json文件,并按照下面的参数修改./llama-7b-hf路径下面的config.json文件修改max_sequence_length和num_hidden_layers等参数可以保证较小显存也可以训练。

{
    "architectures": ["LLaMAForCausalLM"], 
    "bos_token_id": 0, 
    "eos_token_id": 1, 
    "hidden_act": "silu", 
    "hidden_size": 4096, 
    "intermediate_size": 11008, 
    "initializer_range": 0.02, 
    "max_sequence_length": 4, 
    "model_type": "llama", 
    "num_attention_heads": 32, 
    "num_hidden_layers": 4, 
    "pad_token_id": -1, 
    "rms_norm_eps": 1e-06, 
    "torch_dtype": "float16", 
    "transformers_version": "4.27.0.dev0", 
    "use_cache": true, 
    "vocab_size": 32000
}

4.3 训练

在stanford_alpaca/train.py中加上

import os
os.environ["WANDB_DISABLED"] = "true"
# 执行训练指令
!torchrun --nproc_per_node=1 --master_port=29588 ./stanford_alpaca/train.py \
 --model_name_or_path "./llama-7b-hf" \
 --data_path ./alpaca_data.json \
 --bf16 False \
 --output_dir ./models/alpaca-2 \
 --num_train_epochs 1 \
 --per_device_train_batch_size 1 \
 --per_device_eval_batch_size 1 \
 --gradient_accumulation_steps 8 \
 --evaluation_strategy "no" \
 --save_strategy "steps" \
 --save_steps 20 \
 --save_total_limit 1 \
 --learning_rate 2e-5 \
 --model_max_length 4 \
 --weight_decay 0. \
 --warmup_ratio 0.03 \
 --lr_scheduler_type "cosine" \
 --logging_steps 1 \
 --fsdp "full_shard auto_wrap" \
 --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
 --tf32 False 

5.推理阶段

import transformers
tokenizers = transformers.LlamaTokenizer.from_pretrained("./models/alpaca-2")
model = transformers.LlamaForCausalLM.from_pretrained("./models/alpaca-2").cuda()
model.eval()
def gen(req):
    batch = tokenizers(req, return_tensors='pt', add_special_tokens=False)
    batch = {k: v.cuda() for k, v in batch.items()}
    full_completion = model.generate(inputs=batch["input_ids"],
                                    attention_mask=batch["attention_mask"],
                                    temperature=0.7,
                                    top_p=0.9,
                                    do_sample=True,
                                    num_beams=1,
                                    max_new_tokens=600,
                                    eos_token_id=tokenizers.eos_token_id,
                                    pad_token_id=tokenizers.pad_token_id)
    print(tokenizers.decode(full_completion[0]))

gen("List all Canadian provinces in alphabetical order.")

在这个路径中有完整的原始权重

!wget  https://atp-modelzoo-sh.oss-cn-shanghai.aliyuncs.com/release/tutorials/alpaca/gen.py

6.demo

import gradio as gr
import requests
import json
import transformers

tokenizers = transformers.LlamaTokenizer.from_pretrained("./models/alpaca-2")
model = transformers.LlamaForCausalLM.from_pretrained("./models/alpaca-2").cuda()
model.eval()


def inference(text):
    batch  = tokenizers(text, return_tensors="pt", add_special_tokens=False)                                                                                                                                                      
    batch = {k: v.cuda() for k, v in batch.items()}                                                                                                                                                                              
    full_completion = model.generate(inputs=batch["input_ids"],                                                                                                                                                                  
                                     attention_mask=batch["attention_mask"],                                                                                                                                                      
                                     temperature=0.7,                                                                                                                                                                             
                                     top_p=0.9,                                                                                                                                                                                   
                                     do_sample=True,                                                                                                                                                                              
                                     num_beams=1,                                                                                                                                                                                 
                                     max_new_tokens=600,                                                                                                                                                                          
                                     eos_token_id=tokenizers.eos_token_id,                                                                                                                                                        
                                     pad_token_id=tokenizers.pad_token_id)                                                                                                                                                                                                                                                                                                                                                              
    print(tokenizers.decode(full_completion[0]))
    return tokenizers.decode(full_completion[0])

demo = gr.Blocks()
with demo:
    input_prompt = gr.Textbox(label="请输入需求", 
                                value="帮我写一篇安全检查的新闻稿件。",
                                lines=6)
    generated_txt = gr.Textbox(lines=6)

    b1 = gr.Button("发送")
    b1.click(inference, inputs=[input_prompt], outputs=generated_txt) 

demo.launch(enable_queue=True, share=True)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

轻量微调和推理stanford_alpca 的相关文章

随机推荐

  • curl 命令的学习笔记

    curl 命令的学习笔记 curl 官网 https curl haxx se curl 全称 CommmandLine URL 或 CommandLine Uniform Resource Locator 是用于从服务器传输传输数据或向服
  • 【ARM】Linux内核驱动之定时器

    作者主页 凉开水白菜 作者简介 共同学习 互相监督 热于分享 多加讨论 一起进步 专栏资料 https gitee com stylle linux code 点赞 收藏 再看 养成习惯 订阅的粉丝可通过PC端文末加我微信 可对文章的内容进
  • set -e -x 等等的作用

    set指令能设置所使用shell的执行方式 可依照不同的需求来做设置 a 标示已修改的变量 以供输出至环境变量 b 使被中止的后台程序立刻回报执行状态 C 转向所产生的文件无法覆盖已存在的文件 d Shell预设会用杂凑表记忆使用过的指令
  • VUE 输入框实现光标插入,设置光标位置并删除光标内容

    最近做项目遇到这样一个需求 可以往输入框指定光标出插入内容 并且当删除插入的内容时会先将插入的内容进行光标选中给用户进行提示 当再次删除时才删除内容 而这个需求的核心就在 setSelectionRange 设置光标位置 这个dom api
  • Docker之Nacos的持久化和集群部署

    注1 小插曲 由于虚拟机分配的内存为1G 开到第四个容器时 由于内存不够导致容器启动失败 重新设置4G内存后启动成功 ok 正式进入主题 一 Docker mysql 5 7的持久化存储及远程连接 1 拉取相关镜像 目前网络模式为 brid
  • Flutter中 解决自定义阿里妈妈图标一直显示不出来的问题

    前些天发现了一个蛮有意思的人工智能学习网站 8个字形容一下 通俗易懂 风趣幽默 感觉非常有意思 忍不住分享一下给大家 点击跳转到教程 前言 Flutter中 自定义图标一直显示出来的问题 这里引用的是阿里妈妈图标 问题解决 位置一定要对应好
  • 解决小程序报错getLocation:fail the api need to be declared in the requiredPrivateInfos field in app.json

    报错 1 uniapp项目 在manifest json中打开源码视图 小程序特有相关 mp weixin appid 你的开发者id setting urlCheck true es6 true postcss true minified
  • 【Scala入门】scala基础语法:类和对象,变量和常量

    上一篇请移步 Scala入门 Scala下载及安装 Windows 以及Idea创建第一个scala项目 水w的博客 CSDN博客 目录 一 Scala 二 Scala基础语法 2 1 注释与标识符规范 2 2 变量与常量 案例 变量声明和
  • 摩尔定律到摩尔第二定律

    摩尔定律相信大家都不陌生 由英特尔创始人之一戈登 摩尔提出来的 其内容为 当价格不变时 集成电路上可容纳的元器件的数目 约每隔两年便会增加一倍 而普遍的说法是约每隔18个月便会增加一倍 各种说法总结起来就是 1 集成电路芯片上所集成的电路的
  • 【docker】/var/lib/docker/overlay2/ 占用磁盘问题 最终解决方案

    找IT 挂载了新磁盘 比如 data2 100G 在docker配置文件中 加上这个 systemctl daemon reload 重启docker服务即可 会导致此服务器上的所有docker 容器丢失 需要重新部署 还会导致一个问题 d
  • Hibernate参数校验报错:No validator'javax.validation.constraints.Size' validating type 'java.lang.Integer'.

    javax validation UnexpectedTypeException HV000030 No validator could be found for constraint javax validation constraint
  • Python网络爬虫学习笔记(三)正则表达式

    正则表达式 正则表达式是处理字符串的强大工具 它有自己特定的语法结构 有了它 实现字符串的检索 替换 匹配验证 1 实例引入 正则表达式匹配 也就是用一定的规则将特定的文本提取出来 开源中国提供了正则表达式测试工具 https tool o
  • 虚拟机升级glibc(libc), 导致段错误等问题

    由于确实glibc高版本 需要升级glibc 导致出现段错误等信息 只剩下pwd cd等命令可以执行 这个时候需要靠补全命令查询到原系统使用的libc 2 xx文件 然后使用sln 原系统的重新索引libc so 6文件 sln lib64
  • SOA是什么?

    写这样的blog很容易被人砸砖头 而且我现在在专心做BPEL的研究 http hongsoft iteye com admin blogs 287353 也没有必要现在趟这个混水 不过想想 还是有话要说 定义 SOA是一种做架构的范式 这个
  • FreeSwitch数据库

    Freeswitch数据库 一 ODBC DSN 1 概念 ODBC 开放数据库连接 Open Database Connectivity ODBC https baike baidu com item ODBC 是为解决异构数据库间的数据
  • 线性回归(两种方式代码实现)

    方式一 最小二乘法 正规方程 公式推导 其中 代码实现 1 导入库 import numpy as np from sklearn datasets import load boston boston load boston x bosto
  • 前端面试总结及建议

    最近 由于项目组刚成立不久 团队处于天地初开的混沌状态 人员配置不齐 急需一大股新鲜血液融入 为此 开启了一段时间与求职面试者的博弈之路 如今的IT大环境 似乎每个公司一年四季都处于招人状态 而同时又有一大批无论是离职还是在职人员期许找一个
  • Linux操作命令笔记

    Linux Linux的字母大小写 下载和卸载 软件更新 查看空间使用情况 当前目录所在的位置 查看文件中的内容 查看目录下的文件 重启 关机 移动文件 磁盘管理软件 修改权限 删除文件或文件夹 新建文件夹 移动一个文件夹 文件重命名 编译
  • CMake中define_property的使用

    CMake中的define property命令用于定义和记录自定义属性 其格式如下 define property
  • 轻量微调和推理stanford_alpca

    当前的Alpaca模型是在Self Instruct论文中使用的技术生成的52K条指令数据 从7B LLaMA模型微调而来 并进行了一些修改 A10 gpu显存 22G cu117 驱动470 103 01 absl py 1 4 0 ac