将Hugging Face模型转换成LibTorch模型

2023-11-02

Hugging Face的模型

waifu-diffusion模型为例,给出的实现一般是基于diffuser库,示例代码如下:

import torch
from torch import autocast
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    'hakurei/waifu-diffusion',
    torch_dtype=torch.float32
).to('cuda')

prompt = "1girl, aqua eyes, baseball cap, blonde hair, closed mouth, earrings, green background, hat, hoop earrings, jewelry, looking at viewer, shirt, short hair, simple background, solo, upper body, yellow shirt"
with autocast("cuda"):
    image = pipe(prompt, guidance_scale=6)["sample"][0]  
    
image.save("test.png")

通过网络下载预训练模型,预训练模型直接加载,但其实这个模型是下载到了本地的,只不过看起来不是很轻松:

因为模型太大,分成了一些小的文件进行了下载,而且后面可以看出来模型实际上是由一些子模型组成的,所以这里面有几个比较大的文件应该是对应了unet、vae这种,看大小也差不多。

下载好了可以直接print(pipe),发现:

StableDiffusionPipeline {
  "_class_name": "StableDiffusionPipeline",
  "_diffusers_version": "0.11.0",
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "requires_safety_checker": true,
  "safety_checker": [
    "stable_diffusion",
    "StableDiffusionSafetyChecker"
  ],
  "scheduler": [
    "diffusers",
    "PNDMScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

果然是一系列的小模型以及一些不重要的参数,这个模型可以直接保存为.pth文件,同样也可以使用torch.load(pipe.pth)读入,但是在实例化模型的时候,会出现

Traceback (most recent call last):
  File "/home/gaoyi/example-app/test.py", line 59, in <module>
    traced_script_module = torch.jit.trace(model, example)
  File "/home/gaoyi/anaconda3/lib/python3.9/site-packages/torch/jit/_trace.py", line 803, in trace
    name = _qualified_name(func)
  File "/home/gaoyi/anaconda3/lib/python3.9/site-packages/torch/_jit_internal.py", line 1125, in _qualified_name
    raise RuntimeError("Could not get name of python class object")
RuntimeError: Could not get name of python class object

这是因为这个大家伙不能作为一个模型类加载,故也不能直接通过torch.jit.trace进行转化,我们换个方式,将子模型进行转化

模型转化

通过打印print(pipe.unet),可以看出这个unet是一个普通的网络,拥有一堆熟悉的网络层:

UNet2DConditionModel(
  (conv_in): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (time_proj): Timesteps()
  (time_embedding): TimestepEmbedding(
    (linear_1): Linear(in_features=320, out_features=1280, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (down_blocks): ModuleList(
    (0): CrossAttnDownBlock2D(
      (attentions): ModuleList(
        (0): Transformer2DModel(
          (norm): GroupNorm(32, 320, eps=1e-06, affine=True)
          (proj_in): Linear(in_features=320, out_features=320, bias=True)
          (transformer_blocks): ModuleList(
            (0): BasicTransformerBlock(
              (attn1): CrossAttention(
                (to_q): Linear(in_features=320, out_features=320, bias=False)
                (to_k): Linear(in_features=320, out_features=320, bias=False)
                (to_v): Linear(in_features=320, out_features=320, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=320, out_features=320, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (ff): FeedForward(
                (net): ModuleList(
                  (0): GEGLU(
                    (proj): Linear(in_features=320, out_features=2560, bias=True)
                  )
                  (1): Dropout(p=0.0, inplace=False)
                  (2): Linear(in_features=1280, out_features=320, bias=True)
                )
              )
              (attn2): CrossAttention(
                (to_q): Linear(in_features=320, out_features=320, bias=False)
                (to_k): Linear(in_features=1024, out_features=320, bias=False)
                (to_v): Linear(in_features=1024, out_features=320, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=320, out_features=320, bias=True)
                  (1): Dropout(p=0.0, inplace=False)
                )
              )
              (norm1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              (norm2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
              (norm3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
            )
          )
          (proj_out): Linear(in_features=320, out_features=320, bias=True)
        )
        (1): Transformer2DModel(
       
        ...
        ...略
        ...
        
  (conv_norm_out): GroupNorm(32, 320, eps=1e-05, affine=True)
  (conv_act): SiLU()
  (conv_out): Conv2d(320, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

好,那我们就可以将这个子模型进行转化,变成需要的LibTorch模型,但是我们不知道这个模型需要的输入,通过打印的信息我们知道了这个模型的名字是UNet2DConditionModel,所以我们可以从Hugging Face的官方文档进行查询:UNet2DConditionModel

查询发现模型的输入为:

但是具体的数值依旧不知道,这时候可以通过print(model.config)进行查看:

FrozenDict([('sample_size', 64), ('in_channels', 4), ('out_channels', 4), ('center_input_sample', False), 
('flip_sin_to_cos', True), ('freq_shift', 0), ('down_block_types', ['CrossAttnDownBlock2D', 
'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D']), ('mid_block_type', 
'UNetMidBlock2DCrossAttn'), ('up_block_types', ['UpBlock2D', 'CrossAttnUpBlock2D', 
'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D']), ('only_cross_attention', False), 
('block_out_channels', [320, 640, 1280, 1280]), ('layers_per_block', 2), ('downsample_padding', 1), 
('mid_block_scale_factor', 1), ('act_fn', 'silu'), ('norm_num_groups', 32), ('norm_eps', 1e-05), 
('cross_attention_dim', 1024), ('attention_head_dim', [5, 10, 20, 20]), ('dual_cross_attention', False), 
('use_linear_projection', True), ('class_embed_type', None), ('num_class_embeds', None), 
('upcast_attention', False), ('resnet_time_scale_shift', 'default'), ('_class_name', 'UNet2DConditionModel'), 
('_diffusers_version', '0.10.2'), ('_name_or_path', 
'/home/gaoyi/.cache/huggingface/diffusers/models--hakurei--waifu-diffusion/snapshots/55fd50bfae0dd8bcc4bd3a6f25cb167580b972a0/unet')])

一个大字典,找到我们所需要的('sample_size', 64), ('in_channels', 4), ('out_channels', 4),作为用于实例化的输入,此时我们的.py文件如下:

model = torch.load("pipe-unet.pth")

# print(model.config)
# print(model)

example = torch.rand(1, 4, 64, 64)
timestep = torch.rand(1)
encoder_hidden_states = torch.rand(1, 4, 64, 64)

traced_script_module = torch.jit.trace(model, (example, timestep, encoder_hidden_states))
traced_script_module.save("pipe-unet.pt")

但是报错mat1 can not be multiplied with mat2, shape 256x64 and 1024x320,大概是这么个问题,具体的信息就不粘贴了,既然是矩阵形状不对,那就改形状,之前理解的encoder_hidden_states形状与example应该是一样的,但看起来不对,可是改了1024x1024之后又遇到了新的问题,计算注意力的时候数据太多,接受的参数只有三个,所以干脆将encoder_hidden_states = torch.rand(1, 4, 1024),实测通过

之后的新问题,好像是实例化的时候输入元组的问题,具体如下:

RuntimeError: Encountering a dict at the output of the tracer might cause the trace to be incorrect, 
this is only valid if the container structure does not change based on the module's inputs. 
Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, 
use a `NamedTuple` instead). If you absolutely need this and know the side effects, 
pass strict=False to trace() to allow this behavior.

应该是需要在转换的时候传一个参数strict=False,调整完之后的代码如下:

model = torch.load("pipe-unet.pth")

# print(model.config)
# print(model)

example = torch.rand(1, 4, 64, 64)
timestep = torch.rand(1)
encoder_hidden_states = torch.rand(1, 4, 1024)

traced_script_module = torch.jit.trace(model, (example, timestep, encoder_hidden_states), strict=False)
traced_script_module.save("pipe-unet.pt")

成功保存!

模型测试

根据PyTorch官网的测试教程,编写相应的C++文件,然后使用CMake进行编译,最终生成example-app的可执行文件,运行:

./example-app ../pipe-unet.pt

输出ok,成功转化!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

将Hugging Face模型转换成LibTorch模型 的相关文章

  • Python setuptools:如何在 setup.py 中添加私有存储库 (gitlab)?

    我上传了 2 个包 它们位于我的 gitlab 存储库中 如果我想使用 pip 将它们安装在我的系统中 这很容易 因为 gitlab 可以帮助您 https docs gitlab com ee user packages pypi rep
  • 如何将base64字符串直接解码为二进制音频格式

    音频文件通过 API 发送给我们 该文件是 Base64 编码的 PCM 格式 我需要将其转换为 PCM 然后再转换为 WAV 进行处理 我能够使用以下代码解码 gt 保存到 pcm gt 从 pcm 读取 gt 保存为 wav decod
  • 从Python中的字符串中提取货币金额

    我正在制作一个程序 从字符串中获取货币并将其转换为其他货币 例如 如果字符串是 the car cost me 13 250 我需要得到 and 13250 我已经有了这个正则表达式 1 确实如此 但是该字符串很有可能有多个价格 并且全部使
  • 将 numpy 数组写入文本文件的速度

    我需要将一个非常 高 的两列数组写入文本文件 而且速度非常慢 我发现如果我将数组改造成更宽的数组 写入速度会快得多 例如 import time import numpy as np dataMat1 np random rand 1000
  • 为什么我的代码不能根据字典解码加密字符串?

    我有一本字典 其中包含代表字母的键和值 例如一个简单的 DICT CODE b g n a p o x d t y 我收到了一个加密代码 并将该字符串转换为一个列表 其中每个项目都是一个单词 我需要根据字典中的项目来解决它 代码示例是 wo
  • Tweepy StreamListener 到 CSV

    我是 python 新手 我正在尝试开发一个应用程序 使用 Tweepy 和 Streaming API 从 Twitter 检索数据并将数据转换为 CSV 文件 问题是此代码不会创建输出 CSV 文件 也许是因为我应该将代码设置为在实现例
  • 当单词以“|”分隔时如何读取文件(埃因霍温)?

    在Python中 我有一个文件 其中的单词由 例如 city state zipcode 我的文件阅读器无法区分单词 另外 我希望我的文件阅读器从第 2 行而不是第 1 行开始 如何让我的文件阅读器分隔单词 import os import
  • 根据开始列和结束列扩展数据框(速度)

    我有一个pandas DataFrame含有start and end列 加上几个附加列 我想将此数据框扩展为一个时间序列 从start值并结束于end值 但复制我的其他专栏 到目前为止 我想出了以下内容 import pandas as
  • Python Pandas 根据另一列的总计从另一个数据帧中选择值

    我下面有一个 DataFrame 但我需要根据取消和订单列从每个代码中选择行 假设代码 xxx 的阶数为 6 1 5 1 阶数为 11 我需要一种算法 可以选择满足总共 11 行的行 阶数为 6 5 如果没有行匹配 则选择最接近的 id 并
  • Apache Spark 中的高效字符串匹配

    我使用 OCR 工具从屏幕截图中提取文本 每个大约 1 5 句话 然而 当手动验证提取的文本时 我注意到时不时会出现一些错误 鉴于文本 你好 我真的很喜欢 Spark 我注意到 1 像 I 和 l 这样的字母被 替换 2 表情符号未被正确提
  • 为什么我无法在 Mac OS X Terminal.app 上的 Python 解释器中显示 unicode 字符?

    如果我尝试粘贴 unicode 字符 例如中间的点 在我的 python 解释器中它什么也不做 我在 Mac OS X 上使用 Terminal app 当我只是在 bash 中时 我没有遇到任何问题 但在解释器中 python Pytho
  • `list()` 被认为是一个函数吗?

    list显然是内置类型 https docs python org 3 library stdtypes html list在Python中 我看到底下有一条评论this https stackoverflow com a 53645813
  • 具有屏蔽无效值的 pcolormesh

    我试图将一维数组绘制为 pcolormesh 因此颜色沿 x 轴变化 但每个 x 的 y 轴保持不变 但我的数据有一些错误值 因此我使用屏蔽数组和自定义颜色图 其中屏蔽值设置为蓝色 import numpy as np import mat
  • 解析根元素内元素之间的 XML 文本

    我正在尝试用 Python 解析 XML 以下是 XML 结构的示例 a aaaa1 b bbbb b aaaa2 a
  • 如何使用 Keras ImageDataGenerator 预测单个图像?

    我已经训练 CNN 对图像进行 3 类分类 在训练模型时 我使用 keras 的 ImageDataGenerator 类对图像应用预处理功能并重新缩放它 现在我的网络在测试集上训练得非常准确 但我不知道如何在单图像预测上应用预处理功能 如
  • Python 通过从现有 csv 文件中过滤选定的行来写入新的 csv 文件

    只是一个问题 我试图将 csv 文件中的选定行写入新的 csv 文件 但出现错误 我试图读取的 test csv 文件是这样的 两列 2013 9 1 2013 10 2 2013 11 3 2013 12 4 2014 1 5 2014
  • 如何在单元测试中使用 JSON 发送请求

    我的 Flask 应用程序中有在请求中使用 JSON 的代码 我可以像这样获取 JSON 对象 Request request get json 这一直工作得很好 但是我正在尝试使用 Python 的 unittest 模块创建单元测试 但
  • 如何循环遍历字典列表并打印特定键的值?

    我是 Python 新手 有一个问题 我知道这是一个非常简单的问题 运行Python 3 4 我有一个需要迭代并提取特定信息的列表 以下是列表 称为部分 的示例 已截断 数千个项目 state DEAD id phwl type name
  • 在 Django shell 会话期间获取 SQL 查询计数

    有没有办法打印 Django ORM 在 Django shell 会话期间执行的原始 SQL 查询的数量 Django 调试工具栏已经提供了此类信息 例如 5 QUERIES in 5 83MS但如何从 shell 中获取它并不明显 您可
  • 如何为所有用户安装 Anaconda python?

    Anaconda python 发行版 https store continuum io cshop anaconda 非常方便地部署科学计算环境 SCE 并根据需要切换python版本 默认情况下 安装会将 python 定位到 anac

随机推荐

  • AI 协助办公 |记一次用 GPT-4 写一个消息同步 App

    GPT 4 最近风头正劲 作为 NebulaGraph 的研发人员的我自然是跟进新技术步伐 恰好 现在有一个将 Slack channel 消息同步到其他 IM 的需求 看看 GPT 4 能不能帮我完成这次的信息同步工具的代码编写工作 本文
  • 二叉树的翻转

    目录 一 题目 二 解题思路 1 二叉树翻转 2 具体步骤 迭代法 三 代码实现 一 题目 1 leetcode链接 力扣 2 题目内容 给你一棵二叉树的根节点 root 翻转这棵二叉树 并返回其根节点 示例 1 输入 root 4 2 7
  • LeetCode No3. 无重复字符的最长子串 题解

    文章目录 一 题目 二 算法思想 三 示例 四 代码 五 复杂度分析 六 算法评价 一 题目 给定一个字符串 s 请你找出其中不含有重复字符的 最长子串 的长度 示例 1 输入 s abcabcbb 输出 3 解释 因为无重复字符的最长子串
  • 从高中到大学 寻找真实的自己

    写在前面 这是这个寒假刚开始在CSDN上写博客的时候发的第一个blink 当时想说的话有点多 但blink的文字限制是1024字 所以那时控制了字数 现在放开重新写 写在正文 因为疫情原因在家上了差不多3个月的网课 大一回来过个寒假 再次回
  • 2020年研究生数学建模竞赛优秀论文汇总

    A题 ASIC 芯片上的载波恢复 DSP算法设计与实现论文1 论文2 论文3 论文4 论文5 B题 降低汽油精制过程中的辛烷值损失模型论文1 论文2 论文3 论文4 论文5 论文6 论文7 论文8 论文9 论文10 C题 面向康复工程的脑电
  • HTTP协议2)----对于传输层的详细讲解

    大家好 我是 兔7 一位努力学习C 的博主 如果文章知识点有错误的地方 请指正 和大家一起学习 一起进步 如有不懂 可以随时向我提问 我会全力讲解 如果感觉博主的文章还不错的话 希望大家关注 点赞 收藏三连支持一下博主哦 你们的支持是我创作
  • pythonfilter_Python如何用filter函数筛选数据

    一 filter函数简介 filter函数主要用来筛选数据 过滤掉不符合条件的元素 并返回一个迭代器对象 如果要转换为列表list或者元祖tuple 可以使用内置函数list 或者内置函数tuple 来转换 filter函数接收两个参数 第
  • Altium Designer可以实现选中整条同网络线路的快捷键

    选中一段线路 按Tab键 可以选中同网络的整条线路
  • Masked Autoencoders Are Scalable Vision Learners

    Masked Autoencoders Are Scalable Vision Learners Author Unit Facebook AI Research FAIR Authors Kaiming He
  • Finclip小程序目录结构与微信小程序目录结构

    Finclip小程序目录结构 小程序包含一个描述整体程序的 app 和多个描述各自页面的 page 一个小程序主体部分由三个文件组成 必须放在项目的根目录 如下 文件 必需 作用 app js 是 小程序逻辑 app json 是 小程序公
  • 两个无序的数组 如何进行合并 为一个有序的数组

    这里我们首先来看 自己也才毕业半年 这些题比较适合新手练练思想 技术之路且行且珍惜 算法绝对是核心竞争力 两个无序的数组 那么首先第一步合并 第二步 使用正则表达式去掉 第三步 split进行划分 第四步 最核心的排序 此处用了Arrays
  • MYSQL索引那些事

    一 关系型和非关系型的区别 以及使用场景 关系型数据库 采用关系模型来组织数据的数据库 关系模型就是二维表格模型 一张二维表的表名就是关系 二维表中的一行就是一条记录 二维表中的一列就是一个字段 优点 容易理解 使用方便 通用的 sql 语
  • Ceph OSD Down

    CEPH集群跑了一段时间后有几个OSD变成down的状态了 但是我用这个命令去activate也不行 ceph deploy osd activate osd1 dev sdb2 dev sdb1 只能把osd从集群中移除 然后再重建了 这
  • 【我的Android进阶之旅】如何快速寻找Android第三方开源库在Jcenter上的最新版本...

    问题描述 解决方法 先了解compile comsquareupokhttpokhttp240的意义 了解Jcenter和Maven jcenter Maven Central 理解jcenter和Maven Central 快速搜索方法1
  • 改造我们的学习

    我们知道 程序员必须得不断的学习 才能跟上日新月异的技术 但是很多朋友陷入了误区 比如学习C 总觉得我要把 C Primier 看完 再开始编程 学习图像处理也是 非要把数字图像处理与Opencv的书籍看完 才开始上机调试 最后云里雾里 感
  • 零基础Qt笔记<传智教育>Qt版本:2022 5.15

    目录 1 创建第一个Qt程序 2 命名规范以及快捷键 3 QPushBottom的创建 4 对象树 5 Qt中的坐标系 6 信号和槽 6 1 实现点击按钮关闭窗口 6 2 自定义的信号和槽 6 3 自定义的信号和槽发生重载的解决 6 4 信
  • 电话号码升位(拷贝构造函数)

    题目描述 定义一个电话号码类CTelNumber 包含1个字符指针数据成员 以及构造 析构 打印及拷贝构造函数 字符指针是用于动态创建一个字符数组 然后保存外来输入的电话号码 构造函数的功能是为对象设置键盘输入的7位电话号码 拷贝构造函数的
  • python编程实战(三):暴力破解WIFI密码!亲测运行有效!

    本文非原创 参考 Python破解WIFI密码详细介绍 对于代码有细微修改 增加注意事项介绍 声明 本文只是从技术的角度来阐述学习Pywifi库 并不建议大家做任何破坏性的操作和任何不当的行为 并不建议大家做任何破坏性的操作和任何不当的行为
  • js分治法入门级教程,二分搜索的解法

    一 分治法定义 在计算机科学中 分治法是一种很重要的算法 字面上的解释是 分而治之 分治法就是把一个复杂的问题分成两个或更多的相同或相似的子问题 再把子问题分成更小的子问题 直到最后子问题可以简单的直接求解 原问题的解即子问题的解的合并 分
  • 将Hugging Face模型转换成LibTorch模型

    Hugging Face的模型 以waifu diffusion模型为例 给出的实现一般是基于diffuser库 示例代码如下 import torch from torch import autocast from diffusers i