Chatglm2使用及微调教程

2023-11-08

1、下载chatglm2代码

GitHub - THUDM/ChatGLM2-6B: ChatGLM2-6B: An Open Bilingual Chat LLM | 开源双语对话语言模型

github代码见上面所示

2、下载chatglm2-6B模型

git lfs clone THUDM/chatglm2-6b · Hugging Face

如果存在如下报错:OpenSSL SSL_connect: SSL_ERROR_SYSCALL in connection to github.com:443。

使用命令:git config --global --unset http.proxy

然后再多执行几次git lfs clone xxx的命令。

3、运行chatglm2

修改web_demo2.py中model位置的代码

然后执行启动命令:streamlit run web_demo2.py,

运行时模型以 FP16 精度加载,占用GPU显存为:13161MiB

注意:确保transformers的版本为4.30.2,否则会报错:ImportError: cannot import name 'GenerationConfig' from 'transformers.generation.utils'。

4、微调p-tuning

(1)官方INT4量化版本

官方教程地址:https://www.heywhale.com/mw/project/64984a7b72ebe240516ae79c

下载AdvertiseGen数据集(见教程中的链接)到ptuning目录下,如下图所示:

/data/work/xiehao/ChatGLM2-6B/ptuning/AdvertiseGen

安装除ChatGLM2-6B的依赖之外的其他python依赖包

pip install rouge_chinese nltk jieba datasets transformers[torch] -i https://pypi.douban.com/simple/

执行命令:

torchrun --standalone --nnodes=1 --nproc-per-node=1 main.py \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --validation_file AdvertiseGen/dev.json \
    --preprocessing_num_workers 1 \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path /data/work/xiehao/chatglm2-6b-model \
    --output_dir output/adgen-chatglm2-6b-pt \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 128 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 2e-2 \
    --pre_seq_len 128 \
    --quantization_bit 4

运行占用GPU显存为:7945MiB,3000个steps整体需要4个小时。

运行日志如下:

运行完成后,生成的模型位于:

ChatGLM2-6B/ptuning/output/adgen-chatglm2-6b-pt/checkpoint-3000下

模型比较:

Chatglm2的大模型约为12G,而微调模型约为7M。

测试微调前后的效果对比:

测试代码:

from transformers import AutoTokenizer, AutoModel, AutoConfig

import os

import torch



chat_str = "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞"



model_path = "/data/work/xiehao/chatglm2-6b-model"

lora_model_path = "/data/work/xiehao/ChatGLM2-6B/ptuning/output/adgen-chatglm2-6b-pt/checkpoint-3000/pytorch_model.bin"



tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)



# 微调前

#model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device='cuda')

#model = model.eval()

#response, history = model.chat(tokenizer, chat_str, history=[])

#print(response)



# 微调后

config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)

model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)

prefix_state_dict = torch.load(lora_model_path)

new_prefix_state_dict = {}

for k, v in prefix_state_dict.items():

    if k.startswith("transformer.prefix_encoder."):

        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v

model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)



model = model.cuda()

model.transformer.prefix_encoder.float()

model = model.eval()

response, history = model.chat(tokenizer, chat_str, history=[])

print(response)

微调前的输出:

这是一个描述一件上衣的文本,它采用牛仔布材质,颜色为白色,风格是简约的,图案是刺绣的,衣样式是外套,衣款式是破洞的。

微调后的输出:

这一款牛仔外套采用白底黑字的图案设计,简约大方,彰显出帅气又酷酷的气息。衣身上的刺绣图案,在微光下显得特别的帅气有型。衣身上的破洞处理,彰显出酷酷的时尚感,让整件外套充满了个性

(2)非量化版本

官方的微调脚本中包含“--quantization_bit 4”,输出INT4量化后的lora模型。

当然我们也不可以不用量化模型,直接去掉就好了。

此时会要求安装accelerate依赖包,执行命令:pip install accelerate -U,其他参数同之前的一样。

此时占用15393MiB的GPU显存,执行时间只要2个小时左右,loss下降的也更快

5、AutoModel加载使用模型解读

入口调用:model = AutoModel.from_pretrained(model_path, trust_remote_code=True)

首先,通过AutoConfig读取model_path目录下的config.json的参数信息

然后,动态读取模型参数和模型网络结构信息。

        在config.json的auto_map的AutoModel信息如下:modeling_chatglm.ChatGLMForConditionalGeneration

其中modeling_chatglm为模型名称信息,ChatGLMForConditionalGeneration为类型信息

通过modeling_chatglm.py的ChatGLMForConditionalGeneration就可以获取到大模型对应的网络结构信息,接着再加载模型文件进而生成模型的实例

最后,通过model.chat(tokenizer, chat_str, history=[])生成结果,就是调用ChatGLMForConditionalGeneration实例的chat方法。

6、微调部分解读

从微调后使用可以看出,ChatGLM只重新训练了网络结构的PrefixEncoder部分的代码。

这层网络主要是根据prompt的tokens生成embedding,可参考网络源码:

    def get_prompt(self, batch_size, device, dtype=torch.half):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)

微调完成后将这部分的模型信息更新到原来的大模型中。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Chatglm2使用及微调教程 的相关文章

随机推荐

  • linux 下工具

    一 文本比较工具 命令行有 colordiff 图形化的有 meld bcompare 二 记事本 tomboy 工具 快捷键 Bold C B Italic C I Strikeout C S Highlight C H Underlin
  • Hibernate学习笔记 单表映射

    建立实体类 配置好SessionFactory之后 我们就可以开始建立一对一的单表映射了 首先需要建立一个实体类 这里Getter Setter toString equals等方法省略了 我们可以方便的使用IDEA或者Eclipse的生成
  • 【分享】ROM厂商刷机工具合集

    1 MTK刷机 SP Flash Tool下载地址 SP Flash Tool v5 1924 Download SmartPhone Flash Tool MTKClient 下载地址 Releases notmyst33d mtkcli
  • MySQL—存储引擎(下)

    作者 小刘在C站 个人主页 小刘主页 每天分享云计算网络运维课堂笔记 努力不一定有回报 但一定会有收获加油 一起努力 共赴美好人生 树高千尺 落叶归根人生不易 人间真情 前言 上一章讲了存储引擎 本章继续 从特点开始 目录 MySQL 1
  • 中国1-0胜新加坡

    TOM体育讯 北京时间8月16日 亚洲杯预选赛战火重新点燃 中国队在本轮比赛中坐镇天津泰达体育场迎战小组赛的又一个对手新加坡队 结果在赵旭日下半场被早早罚下 中国队以少打多的不利局面下 中国队的顽强感动了上苍 也拯救了自己 补时最后一分钟
  • 【Unity】获取相机画面将其保存成图片

    void CameraCapture Camera m Camera string filename RenderTexture rt new RenderTexture Screen width Screen height 16 m Ca
  • 抓包工具Wireshark使用体会

    这两天在工作上遇到了一些问题 必须要用抓包工具来捕获手机端发送过来的数据包 分析其帧结构 以前虽然学习过网络知识 但是也从未接触过抓包工具Wireshark 迫于工作的压力 自己在摸索中学到了一些基本的使用方法 文件格式 pcap 帧排序
  • 笔记(一)斯坦福CS224W图机器学习、图神经网络、知识图谱

    节点和连接构成的图 如何对图数据进行挖掘 传统机器学习 数据是独立同分布的 解决表格 矩阵 序列等问题 图机器学习处理连接的数据 需要满足以下几个方面 1 图是任意尺寸输入 2 图是动态变化的 有时也是多模态数据 图 可以实现端到端的表示学
  • 矩阵分析学习(补充)

    在系统分析中 会涉及到多项式矩阵互质性的判别问题 此类问题通常归结为两种 1 具有相同行数的多项式左互质 2 具有相同列数的多项式右互质 一 多项式矩阵的右公因子 左公因子 的定义 二 多项式矩阵的最大右公因子 最大左公因子 的定义 首先这
  • Asp.net 移动开发

    Asp net能进行移动开发 移动开发是手机运用 而asp net是网页开发 能合在一起吗 答案是能的 随着科技的发展 现在asp net也能进行移动开发 移动开发也称为手机开发 或叫做移动互联网开发 是指以手机 PDA UMPC等便携终端
  • 基于MATLAB的白鲸算法在太阳能光伏模型参数估计中的应用

    基于MATLAB的白鲸算法在太阳能光伏模型参数估计中的应用 本文将介绍如何使用MATLAB编写基于白鲸算法的太阳能光伏模型参数估计 并提供相应的源代码 太阳能光伏模型的参数估计是对光伏系统性能分析的重要步骤 它可以帮助我们了解和优化光伏系统
  • 如何解决redis的缓存击穿、缓存穿透、缓存雪崩等问题?

    关注我 升职加薪就是你 1 缓存击穿 指一个非常热点的key在缓存过期的一刻 同时有大量的并发请求访问该key 导致所有请求都落到了数据库上 引起数据库压力过大甚至宕机 解决方案 1 设置热点数据永不过期 2 加互斥锁 只允许一个请求去查询
  • Java获取前N个季度的开始时间和结束时间

    获取前N个季度的开始日期和结束日期 param count return private List
  • 【十大经典排序算法】C语言实现

    十大经典排序算法 插入类排序 直接插入排序 折半 二分 插入排序 希尔排序 交换类排序 冒泡排序 快速排序 选择类排序 选择排序 树形选择排序 堆排序 归并排序 计数排序 分配类排序 捅排序 基数排序 插入类排序 直接插入排序 void i
  • 关于pip安装第三方库,但PyCharm中却无法识别的问题;以及PyCharm安装第三方库的方法解析

    Table of Contents 一 问题具体描述 二 解决方法 1 方法一 在PyCharm下载第三方库 即把之前下的库作废 这里重新再下一次 2 方法二 坚持用pip的方法安装第三方库 三 扩展延伸 pip install 安装路径问
  • BP神经网络的非线性系统建模以及matlab神经网络工具箱的使用

    在所有的关系中 数学公式的线性表达是对那些规律性数据的预测统计 而非线性关系的数据 数学方程式只能通过多个参数尽可能模拟数据曲线 神经网络的非线性拟合能力不仅在于参数多还在于激活函数的非线性表达 以拟合拟合的非线性函数为 为例 BP神经网络
  • 【论文阅读 08】Defect Detection in Electronic Surfaces Using Template-Based Fourier Image Reconstruction

    比较老的一篇论文 基于模板的傅里叶图像重建电子表面的缺陷检测 关键词 缺陷检测 傅里叶变换 F T 机器视觉 印刷电路板 PCB 模板匹配 总结 1 Abstract 一种用于检测和定位非周期性模式图像中小缺陷的新方法 在电子工业中 例如在
  • 完美解决E: Unable to lock directory /var/lib/apt/lists/方案

    使用命令 sudo fuser vki var lib apt lists lock 重新执行 sudo apt update
  • QT之动态进度条

    简介 前两天需要接到一个需求需要做一个好看的进度条 在网上搜了一圈发现要不然就是不符合我的需求要不然就是没有源码 最后找到一个大佬写的有部分源码的 自己也折腾了一个 原文链接 效果图 思路 主要就是重写了QProcessBar的paintE
  • Chatglm2使用及微调教程

    1 下载chatglm2代码 GitHub THUDM ChatGLM2 6B ChatGLM2 6B An Open Bilingual Chat LLM 开源双语对话语言模型 github代码见上面所示 2 下载chatglm2 6B模