CLIP与CoOp代码分析

2023-11-19

CLIP与CoOp代码分析

CoOp是稍微改了下CLIP的text encoder
CLIP代码:https://github.com/OpenAI/CLIP
CoOp代码:https://github.com/KaiyangZhou/CoOp

输入一张图片和三段文本,使用CLIP,计算相似度。以下是向量维度分析

import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)# torch.Size([3,77])  [batch_size,n_ctx]
                                                                # 这里只是对每个词进行编号,空的地方补0
                                                                # 设置n_ctx为77,但只支持输入75个词,
                                                                # 因为有两个词分别是startoftext 和 endoftext,会自动加入

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)# torch.Size([3,512]) 每一条prompt进行encode后会变成512维
    
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

CLIP的encode_text 函数

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size,n_ctx] -> [batch_size, n_ctx, d_model] #这里d_model==512

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

CLIP的encode_text函数有①token_embedding和②positional_embedding。

  • ① token_embedding是nn.Embedding。是把clip.tokenize生成出来的维度为[batch_size,n_ctx]的text向量,转换成[batch_size, n_ctx, d_model]的向量。
  • ② positional_embedding是nn.Parameter。是可学习的。

①和②相加,输进transformer,通过训练学习,更新positional_embedding

CoOp的 TextEncoder

CoOp的encode_text把CLIP的①token_embedding换成了 p r o m p t s prompts prompts p r o m p t s prompts prompts是可学习的。

class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        x = prompts + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        '''
        若输入进text encoder的是一句["this is hair"]
        print("text",text) #tensor([[49406,   589,   533,  2225, 49407,     0,     0,     0,     0, ...]])
        print("text.argmax(dim=-1)",text.argmax(dim=-1)) #tensor([4])
        print("torch.arange(x.shape[0])",torch.arange(x.shape[0])) #tensor([0])
        print("x",x.shape,x) #torch.Size([1, 77, 512])
        print("x[torch.arange(x.shape[0]), text.argmax(dim=-1)]",x[torch.arange(x.shape[0]), text.argmax(dim=-1)].shape) #torch.Size([1, 512])即从[1,77,512]取了index==4的矩阵
        '''
        # ! tokenized_prompts在这里只是为了用了获得eot_token对应的embedding?
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection

        return x

CoOp的 PromptLearner

与CLIP手动设置的Prompt不同(比如“a photo of a {label}”)。CoOp定义了可学习的PromptLearner,可生成 p r o m p t s prompts prompts t o k e n i z e d _ p r o m p t s tokenized\_prompts tokenized_prompts ,tokenized_prompts由如下代码得来:

        nn.init.normal_(ctx_vectors, std=0.02)  #初始化
        prompt_prefix = " ".join(["X"] * n_ctx) #初始化

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames] # 这里只是对每个词进行编号
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
                with torch.no_grad():
            		embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

在前向阶段, p r o m p t s prompts prompts里设置了ctx,ctx是nn.Parameter,是可学习的。这里的 s u f f i x suffix suffix,是包含CLS, EOS的,即类名和end-of-sentence,即suffix包含了类的信息。

            # class_token_position == "end" class_token放在句末的情况
            self.ctx = nn.Parameter(ctx_vectors)  # to be optimized
            ctx = self.ctx                        # ctx是context的缩写
            prompts = torch.cat(
                [
                    prefix,  # (n_cls, 1, dim)    # 相当于startoftext的embedding
                    ctx,     # (n_cls, n_ctx, dim)
                    suffix,  # (n_cls, *, dim)    # 相当于endoftext 的embedding
                ],
                dim=1,
            )

训练CoOp

class CustomCLIP(nn.Module):
    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype
    # 训练CoOp时的前向阶段
    def forward(self, image):
        image_features = self.image_encoder(image.type(self.dtype))

        prompts = self.prompt_learner()
        tokenized_prompts = self.tokenized_prompts
        text_features = self.text_encoder(prompts, tokenized_prompts)

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()

        return logits

CoOp训练是用batch里的(image,label)。

前向阶段:image输入进image_encoder得到image_features;把prompts, tokenized_prompts放进TextEncoder获得text_features,两者算相似度获得logits。

logits与label进行交叉熵运算得到loss

反向阶段:loss反向传播,优化可学习的nn.Parameter。比如prompts里的ctx以及原CLIP里的 positional_embedding。

注:有理解错误的地方请指出

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

CLIP与CoOp代码分析 的相关文章

随机推荐

  • c语言答案计算鸡兔同笼,鸡兔同笼-题解(C语言代码,思路清晰,简单易懂)

    解题思路 设鸡和兔子的数量为x y 则有x y n 2x 4y m 即可得x 4n m 2 y m 2n 2 只有x y为分数 或者为负数时 即为无解情况 详细代码如下 include int main double n m chicken
  • solc安装指定版本

    1 系统linux ubuntu20 04 2 solc安装指定版本 在编译的时候报错 Error Data location must be storage or memory for constructor parameter but
  • 残差神经网络(ResNet)

    残差神经网络的主要贡献是发现了退化现象 并针对退化现象发明了快捷连接 shortcut connection 极大的消除了深度过大的神经网络训练困难问题 1 神经网络越深准确率越高 假设一个层数较少的神经网络已经达到了较高准确率 可以在这个
  • TB-RK3399pro(Fedora28)图形界面与字符界面的切换

    TB RK3399pro Fedora28 使用的是LXDE图形界面 使用时默认打开7个屏幕 分别是tty1到tty6 加上一个没名字的tty7 LXDE为tty1号屏幕 若要切换至字符界面 使用快捷键 Ctrl Alt F2 F2也可以为
  • Wps ppt中无法打开超链接外部文件的解决办法。

    今天突然发现 在原来的Wps ppt中的所有超链接视频或照片都无法打开了 以下是解决办法 供参考 主要原因是Windows10系统升级出现的冲突问题 请卸载这两个补丁 KB5015807和KB5016066 或者卸载其中之一即可打开
  • Oracle 设定允许访问的IP地址

    开启按ip地址访问 修改 oracle10 app db network admin sqlnet ora 在文件最后加下列2行 vim sqlnet ora tcp validnode checking yes tcp invited n
  • 滑雪(记忆化搜索)

    题目 题解 记忆化搜索模板题 记忆化搜索的核心 本质是带剪枝的深搜 当某点的dp已赋值时 返回该值 其他情况进行深度搜索 模板 dfs u点 if u点的 dp 已经有值了 return u点的 dp 值 else 说明第一次到达u 则为u
  • Flume之:二、企业开发案例

    Flume之 二 企业开发案例 文章目录 Flume之 二 企业开发案例 三 企业开发案例 1 监控端口数据官方案例 2 实时读取本地文件到HDFS案例 3 实时读取目录文件到HDFS案例 4 flume监控Kafka gt Spark知识
  • QCC300x笔记(3) -- QCC3007开发调试经验

    哈喽大家好 这是该系列博文的第三篇 篇 lt lt 系列博文索引 快速通道 gt gt 写在前面 这篇博客主要记录 在使用QCC300x平台中所遇到的问题以及解决方法 会不定时更新 1 使用的堆栈空间大小超出或者全局变量超出 会报以下错误
  • R语言回归分析

    R语言回归分析 回归分析可以说是统计学的核心 它其实是一个广义的概念 通指那些用一个或多个预测变量 也称自变量或解释变量 来预测响应变量 也称因变量 效标变量或结果变量 的方法 通常 回归分析可以用来挑选与响应变量相关的解释变量 可以描述两
  • ChatGPT国产平替出现了:APP商店就能下载,还可给AI加人设,背后公司刚成立3个月...

    明敏 发自 凹非寺量子位 公众号 QbitAI ChatGPT太火爆谁不想上手试试 但注册复杂 服务器拥挤 着实有点麻烦 不过很快就有极客网友指路 说国内其实已经有类似的APP上线了 也是上知天文下知地理的那种 比如聊聊 三体 还会说自己喜
  • 股票与债券的区别与联系

    1 股票与债券的联系 2 股票与债券的区别
  • C# Debug.WriteLine 参数显示不对{0}

    最近使用这个函数调试 原始代码 StackTrace st new StackTrace new StackFrame true Debug WriteLine Stack trace for current level 0 st ToSt
  • PgAdmin中的数据库查询功能

    参考博客 https blog csdn net qq 28289405 article details 80249509 utm medium distribute pc relevant none task blog BlogComme
  • 2022-TCGA数据库重大更新后RNASeq的STAR-Counts数据的下载与整理

    TCGA GEO 文献阅读 数据库 理论知识 R语言 Bioconductor 服务器与Linux 最近有粉丝留言 TCGA数据库发生更新 下载的数据和之前的不一样 比如转录组 之前是HTSeq流程的数据 现在是STAR Counts的数据
  • Jupyter Error “bad file descriptor“ in VSCode

    Jupyter Error bad file descriptor in VSCode 直接跑这一行 pip install upgrade force reinstall no cache dir jupyter
  • 已知斐波那契数列 1 1 2 3 5 8… ,求出第10项的值

    1 1 1 2 3 5 8 首先我们可以在这些数中找到规律 斐波那契数列的规定是固定的 从第三项开始等于前两项的和 第一项和第二项固定为 1 在求第N项时 首先把前面两项相加 再重新给前两项赋值 2 我们可以把第三项设为 np 那第二项的值
  • iOS 17更新,让苹果失去了魅力!

    1 iOS17的更新缺乏新意 随着WWDC2023的落幕 苹果发布了iOS17的开发者测试版 不过 由于需要开发者账号才能抢先体验 许多果粉们无法第一时间尝试iOS17的新功能 但实际上 这次的更新并没有带来令人期待的亮点 放眼望去 iOS
  • 优秀软件测试工程师必备的8个能力!-(附思维导图)

    结合自己以往的工作经验 自己梳理出来一些材料 绝对原创 绝对干货 优秀的软件测试工程师必备的 8个能力 作为一名软件工程师 需要的能力并不多 但是要成为一名优秀的软件测试工程师 需要的能力就比较多了 自己整理出来8个方面 每个方面都会分成很
  • CLIP与CoOp代码分析

    CLIP与CoOp代码分析 CoOp是稍微改了下CLIP的text encoder CLIP代码 https github com OpenAI CLIP CoOp代码 https github com KaiyangZhou CoOp 输