【Kaggle】Stable Diffusion - Image to Prompts竞赛代码初步理解

2023-11-18


在这里插入图片描述

一、前言

此次代码集成了 CLIP Interrogator、OFA 模型和 ViT 模型。

首先安装指定版本的 transformers 库:

transformers-4.18.0.dev0-py3-none-any.whl 是一个 transformers 库的文件,它的命名方式表示这是一个开发版本(dev)的预构建轮子(wheel)文件。

轮子文件是 Python 包的一种打包格式,可以通过 pip 安装。

如果您想要安装这个特定版本的 transformers 库,可以使用以下命令:

pip install transformers-4.18.0.dev0-py3-none-any.whl

请确保您位于包含该文件的目录,并且在运行该命令之前已经安装了适当的依赖项。

请注意,这是一个开发版本的预构建文件,可能包含尚未正式发布的功能或存在 bug。如果您只是想使用稳定版本的 transformers 库,建议使用正式发布的版本,例如:

pip install transformers

这将安装最新的稳定版本,而不是开发版本。

在比赛中,我们可以引入的方法为:

!pip install -q /kaggle/input/stable-diffusion-data/transformers-4.18.0.dev0-py3-none-any.whl

Kaggle 环境中使用 Jupyter NotebookJupyterLab 的一个指令。

  • ! 符号是 Jupyter Notebook 或 JupyterLab 中的一个魔术命令前缀,用于执行系统级命令。
  • pip install 是用于安装 Python 包的 pip 命令。
  • -q–quiet 参数是 pip 命令的选项之一,用于使安装过程静默,即不显示安装的详细信息。
  • /kaggle/input/stable-diffusion-data/transformers-4.18.0.dev0-py3-none-any.whl 是一个文件路径,表示要安装的 transformers 库的轮子文件(.whl 文件)。根据路径中的前缀 /kaggle/input,可以推断这是在 Kaggle 环境中安装位于 /kaggle/input 目录下的本地文件。

综上所述,该命令的含义是在 Kaggle 环境中安装指定路径下的 transformers-4.18.0.dev0-py3-none-any.whl 轮子文件,并且在安装过程中不显示详细信息。这将通过使用 pip 命令将该轮子文件作为本地文件进行安装。

二、导包

import os
import sys
import glob
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import OFATokenizer, OFAModel
from transformers.models.ofa.generate import sequence_generator
import gc

这是一个 Python 脚本中的导入语句,它用于将一些常用的 Python 和深度学习库导入到脚本中。以下是每个导入库的简要介绍:

  • os:提供了一个与操作系统交互的简单方法,例如读取和写入文件。
  • sys:提供了一个与 Python 解释器交互的方法,例如修改 sys.path 导入路径、获取命令行参数等。
  • glob:提供了一个通用文件和目录的匹配模式方法。
  • pathlib.Path:提供了一个简单的面向对象的路径类,用于操作文件和目录。
  • numpy:提供了一个用于进行数学和科学计算的 Python 库。
  • pandas:提供了一个用于数据分析的 Python 库,可以用于处理和操作大型数据集。
  • matplotlib.pyplot:提供了一个用于绘制数据可视化图形的 Python 库。
  • PIL.Image:提供了一个 Python 图像处理库,可以用于处理和操作图像。
  • torch:提供了一个 PyTorch 深度学习框架库,用于创建和训练神经网络模型。
  • torch.utils.data.Dataset:提供了一个 PyTorch 数据集抽象类,用于创建自定义数据集类。
  • torchvision.transforms:提供了一些 PyTorch 中用于数据增强和预处理的转换类。
  • transformers.OFATokenizer:提供了一个 OneFlow Aware Tokenizer(OFA)类,用于在 PyTorch 中进行模型搜索和架构优化。
  • transformers.OFAModel:提供了一个 OneFlow Aware Model(OFA)类,用于在 PyTorch 中进行模型搜索和架构优化。
  • transformers.models.ofa.generate.sequence_generator:提供了一个序列生成器,用于生成一系列优化的模型架构。
  • gc:是 Python 中的垃圾回收库,可以用于管理内存。
CKPT_DIR = "/kaggle/input/stable-diffusion-data/OFA-large-caption/"
IMAGE_DIR = "/kaggle/input/stable-diffusion-image-to-prompts/images"

BATCH_SIZE = 24

这些是一些变量的赋值语句:

  • CKPT_DIR:设置为 “/kaggle/input/stable-diffusion-data/OFA-large-caption/”,表示一个目录路径,指向存储 OneFlow Aware (OFA) 模型的检查点文件的位置。
  • IMAGE_DIR:设置为 “/kaggle/input/stable-diffusion-image-to-prompts/images”,表示一个目录路径,指向存储图像文件的位置。
  • BATCH_SIZE:设置为 24,表示批量处理数据时的批量大小,即一次传递给模型的样本数量。

这些变量的值可以根据需要进行调整,用于指定数据的路径、模型的保存位置以及批量处理数据时的批量大小。

三、加载预训练的 OFA 模型

mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
resolution = 480
patch_resize_transform = transforms.Compose([
        lambda image: image.convert("RGB"),
        transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC),
        transforms.ToTensor(), 
        transforms.Normalize(mean=mean, std=std)
    ])

tokenizer = OFATokenizer.from_pretrained(CKPT_DIR)
model = OFAModel.from_pretrained(CKPT_DIR, use_cache=False).cuda()
txt = " what does the image describe?"
inputs = tokenizer([txt], return_tensors="pt").input_ids

让我们逐行解读这段代码:

  • mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]:这行代码定义了 meanstd 两个变量,它们是归一化图像时使用的均值和标准差。这里设置的均值和标准差都是 [0.5, 0.5, 0.5],表示将图像的每个通道的像素值缩放到范围 [-1, 1]
  • resolution = 480:这行代码定义了 resolution 变量,它表示将图像调整为的分辨率大小。在这里,图像将被调整为 480 x 480 像素。
  • patch_resize_transform = transforms.Compose([…]):这行代码定义了一个转换序列 patch_resize_transform,它将应用于输入图像。转换序列中的每个操作按顺序应用于图像。
  • image.convert(“RGB”) 将图像转换为 RGB 模式,确保图像有三个通道。
  • transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC) 调整图像大小为给定的 resolution,使用双三次插值方法进行调整。
  • transforms.ToTensor() 将图像转换为张量形式,将像素值缩放到范围 [0, 1]
  • transforms.Normalize(mean=mean, std=std) 对图像进行标准化,将像素值归一化为均值为 mean,标准差为 std 的分布。
  • tokenizer = OFATokenizer.from_pretrained(CKPT_DIR):这行代码创建一个 OFATokenizer 对象,从预训练模型的检查点文件中加载 tokenizer
  • model = OFAModel.from_pretrained(CKPT_DIR, use_cache=False).cuda():这行代码创建一个 OFAModel 对象,并加载预训练模型的权重。use_cache = False 表示不使用缓存。
  • txt = " what does the image describe?":这行代码定义了一个字符串变量 txt,它包含了图像描述的文本。
  • inputs = tokenizer([txt], return_tensors=“pt”).input_ids:这行代码使用之前创建的 tokenizer 对象将文本 txt 编码为模型输入的张量形式。tokenizer([txt]) 将文本转换为 tokens,并返回一个字典对象。.input_ids 提取了输入 tokens 的张量表示。最终,inputs 变量将包含文本编码后的输入张量。

总的来说,这段代码的目的是为了准备图像和文本数据,以便将它们输入到 OFA 模型中进行处理和生成。

四、模型EDA

sample_images = glob.glob("/kaggle/input/stable-diffusion-image-to-prompts/images/*")[:7]
fig, ax = plt.subplots(7,1, figsize=(4,35))

for i,impath in enumerate(sample_images):
    image = Image.open(impath)
    image_t = patch_resize_transform(image).cuda().unsqueeze(0)
    out = model.generate(inputs.cuda(), patch_images=image_t.cuda(), num_beams=5, no_repeat_ngram_size=2)
    out_captions = tokenizer.batch_decode(out, skip_special_tokens=True)
    ax[i].imshow(image)
    ax[i].text(1.1, .5, out_captions[0], horizontalalignment='left', verticalalignment='center', transform=ax[i].transAxes)

让我们逐行解释这段代码:

  • sample_images = glob.glob(“/kaggle/input/stable-diffusion-image-to-prompts/images/*”)[:7]:这行代码使用 glob.glob 函数获取 /kaggle/input/stable-diffusion-image-to-prompts/images/ 目录下的图像文件路径,并选择前 7 个图像文件。这些文件路径被存储在 sample_images 列表中。
  • fig, ax = plt.subplots(7, 1, figsize=(4, 35)):这行代码创建了一个包含 7 行、1 列的子图布局,每个子图的大小为 (4, 35)。返回的 fig 对象表示整个图像,ax 对象是包含 7 个子图的数组。
  • for i, impath in enumerate(sample_images)::这是一个循环,遍历 sample_images 列表中的图像文件路径。enumerate 函数用于同时迭代列表中的元素和它们的索引。在每次迭代中,i 是索引,impath 是当前的图像文件路径。
  • image = Image.open(impath):这行代码使用 PIL 库的 Image.open 函数打开图像文件,并将图像对象存储在 image 变量中。
  • image_t = patch_resize_transform(image).cuda().unsqueeze(0):这行代码对打开的图像 image 应用之前定义的 patch_resize_transform 转换序列,将图像调整大小并进行标准化处理。然后,使用 .cuda() 将图像张量移动到 GPU 上,并使用 unsqueeze(0) 在批次维度上添加一个维度。
  • out = model.generate(inputs.cuda(), patch_images=image_t.cuda(), num_beams=5, no_repeat_ngram_size=2):这行代码使用预训练的 OFA 模型 model 生成文本。它接收一个输入文本的张量 inputs,以及调整大小和标准化后的图像张量 image_tnum_beams=5 表示使用束搜索方法生成多个可能的文本输出,no_repeat_ngram_size=2 表示生成的文本中不会有连续重复的 2-gram
  • out_captions = tokenizer.batch_decode(out, skip_special_tokens=True):这行代码使用 tokenizer 将模型生成的输出 out 解码为文本。skip_special_tokens = True 表示跳过特殊标记,如起始和结束标记。
  • ax[i].imshow(image):这行代码在第 i 个子图中显示图像。
  • ax[i].text(1.1, .5, out_captions[0], horizontalalignment=‘left’, verticalalignment=‘center’, transform=ax[i].transAxes):这行代码在第 i 个子图中添加文本标注。1.1, .5 是文本的位置坐标,horizontalalignment = ‘left’ 表示文本水平对齐方式为左对齐,verticalalignment = ‘center’ 表示文本垂直对齐方式为居中对齐。ax[i].transAxes 表示使用子图坐标系进行转换。

这段代码的目的是展示图像并在每张图像上显示由 OFA 模型生成的文本描述。它首先遍历了前 7 个图像文件的路径,然后对每张图像进行处理:调整大小、标准化,并传递给 OFA 模型生成文本描述。生成的文本描述被解码后,用作文本标注,并与相应的图像一起显示在子图中。

最终的结果是,在一个具有 7 行、1 列的图像布局中,显示了每张图像以及由 OFA 模型生成的相应文本描述。

五、Inference

sys.path.append('../input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer, models

comp_path = Path('../input/stable-diffusion-image-to-prompts/')
st_model = SentenceTransformer('/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2')

让我们逐行解读上面的代码:

  • sys.path.append(‘…/input/sentence-transformers-222/sentence-transformers’):这行代码将 ‘…/input/sentence-transformers-222/sentence-transformers’ 目录添加到 Python 的模块搜索路径中,以便能够导入其中的模块。
  • from sentence_transformers import SentenceTransformer, models:这行代码从 sentence_transformers 模块中导入 SentenceTransformermodels 类。这些类提供了用于处理和生成句子向量的功能。
  • comp_path = Path(‘…/input/stable-diffusion-image-to-prompts/’):这行代码创建了一个 Path
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【Kaggle】Stable Diffusion - Image to Prompts竞赛代码初步理解 的相关文章

随机推荐

  • LRU LFU 概念、底层原理及其实现 超详细~

    0 前置提要 本篇约为8650字 阅读完需要约40 60分钟 主要介绍页面置换算法 LRU和LFU的原理及其实现 对应leetcode140和460 如果能给个赞就更好了 1 从内存置换算法说起 计算机的运行的程序和数据保存在内存中 内存的
  • C++ MYSQL 多线程并发查询处理数据

    1 ThreadPool hpp pragma once ifndef THREAD POOL H define THREAD POOL H include
  • ERROR: Could not find a version that satisfies the requirement keras-nightly (from versions: none)

    问题描述 这个错误发生于博主在python 3 8的环境里安装tensorflow 2 5 0时 执行pip install tensorflow 2 5 0 出现了以下错误 ERROR Could not find a version t
  • Java正则表达式详解

    1 1 正则表达式的概念以及演示 正则表达式可以用一些规定的字符来制定规则 并用来校验数据格式的合法性 正则表达式就是用来验证各种字符串的规则 它内部描述了一些规则 我们可以验证用户输入的字符串是否匹配这个规则 正则表达式是一种强大的校验机
  • 2023-8-23 堆排序

    题目链接 堆排序 include
  • python写程序计算无穷级数_[python][计算方法]利用无穷级数计算幂运算(开根号)...

    encoding gbk a的n次方函数 def exp a n ret 1 for i in range 0 n ret a return float ret n n 1 n 2 def getN minus n n x ret floa
  • 【教程】加速访问和下载github项目,原来替换一个域名就可以加速了

    目录 前言 gitee方法 更简便方法 使用教程 前言 大家平时下载github项目的时候 非常的慢 有时候浏览某个项目的md介绍时候 图片就是加载不出来 让人很苦恼 想锤电脑 gitee方法 于是有很多人都是用gitee的方法 先导入到g
  • 【存储管理】brk()系统调用

    尽管应用程序编程时很少直接调用brk 系统调用 但是它是最经常使用的系统调用 1 C语言中的malloc以及C 语言中的new都在间接的调用着brk 这个系统调用 内核中含有3GB的用户虚存空间 会部分映射到物理存储空间 用户程序经过编译
  • vue中怎么引入element以及使用的详细教程

    引入element 安装依赖 在使用 Element 之前 需要先安装 Element 的依赖库 可以使用 npm 或者 yarn 安装 npm npm i element ui S yarn yarn add element ui 引入C
  • Qt 如何关闭 Debug信息输出

    在pro文件中加上DEFINES QT NO DEBUG OUTPUT 然后重新构建一下程序 qDebug的信息就不再输出了 不过qWarning qCritical等信息仍然可以输出 类似的宏还有 QT NO INFO OUTPUT QT
  • 剑指Offer第五十八题:对称的二叉树

    题目描述 请实现一个函数 用来判断一颗二叉树是不是对称的 注意 如果一个二叉树同此二叉树的镜像是同样的 定义其为对称的 1 思路 我们通常有三种不同的二叉树遍历算法 即前序遍历 中序遍历和后序遍历 在这三种遍历算法中 都是先遍历左子结点再遍
  • 良许Linux

    Linux 服务器我们天天打交道 特别是 Linux 工程师更是如此 为了保证服务器的安全与性能 我们经常需要监控服务器的一些状态 以保证工作能顺利开展 本文介绍的几个命令 不仅仅适用于服务器监控 也适用于我们日常情况下的开发 1 watc
  • depcheck检测缺失哪些依赖包

    npm install g depcheck 如果不想全局安装 npm i depcheck后可以在package json的scripts中输入 check depcheck 之后使用 npm run check depcheck npm
  • umi-request 网络请求之路

    umi request 网络请求之路 背景 在做中台业务应用开发的过程中 我们发现在请求链路上存在以下问题 请求库各式各样 没有统一 每次新起应用都需要重复实现一套请求层逻辑 切换应用时需要重新学习请求库 API 各应用接口设计不一致 混乱
  • sql注入Less11-20

    Less 11 POST 1 先登录 在表格中输入admin admin 登录成功后为下图 2 在post data中输入以下 uname passwd 1 submit submit 返回的结果显示存在sql语法错误 证明存在注入漏洞 u
  • 修改别人代码的原则

    工作过程中难免会涉及到修改或维护别人写的代码 如 代码原作者请假 离职 或相关的bug落到了你的头上 或用别人写的通用方法不爽时 如果碰到修改别人的代码时 需要注意哪些事项呢 1 和原作者沟通 当用到了他人写的通用方法 又感觉不爽时 如果原
  • 各个版本chrome允许加载使用flash的方法

    根除办法 在html中嵌入标签 用户自动选择是否加载flash 69 0 之前的版本 1 打开 chrome settings content flash 2 禁止网站运行Flash gt 改为 Ask Default 3 允许 gt 添加
  • golang开发的准备 - gvm(go版本管理软件)的安装

    0 系统环境 ubuntu18 04 1 前置条件 sudo apt get install bison 2 安装步骤 1 从github下载安装包文件 git clone https github com moovweb gvm git
  • 【c++】14.编译proto和proto相关用法

    编译proto和proto相关用法 关于proto相关的知识可以参考系列博客 https blog csdn net daaikuaichuan category 9869251 html xx proto文件中如果要注释的话 注释符号也是
  • 【Kaggle】Stable Diffusion - Image to Prompts竞赛代码初步理解

    文章目录 一 前言 二 导包 三 加载预训练的 OFA 模型 四 模型EDA 五 Inference 六 安装并导入所有依赖项 七 设置配置 八 加载示例提交 九 Build index from images 十 CLIP interro