【无监督】3、SimCLRv1

2023-11-11

在这里插入图片描述

论文:A Simple Framework for Contrastive Learning of Visual Representations

代码: https://github.com/google-research/simclr

出处:ICML 2020 | Hinton 大佬 | Google

贡献:

  • 证明不同数据增强的结合很重要
  • 在特征表达和 contrastive loss 之间引入了可学习的非线性 transformer 结构,取得了很大的效果提升
  • 在大的 batch size 和大的 epoch 的加持下对比学习能获得比有监督学习更好的效果

效果:

  • 使用自监督对比学习的方式训练 ImageNet 提取特征后,训练了一个线性分类器,就获得了 76.5% top-1 acc,比当时的 SOTA 高 7%,和有监督基线网络 ResNet50 获得了同样的效果

一、背景

目前来说,大致有两个不同的路线来做无标签的视觉特征提取,分别是 generative 和 discriminative,也就是生成式和判别式

  • 生成式的方法是学习如何生成和输入空间相同的像素 ,但是 pixel-level 的生成计算量很大而且没有很强的特征表达意义
  • 判别式的方法是使用目标函数来判断两个输入是否来源于同一个数据,一般都是需要使用代理任务来对同一输入生成不同的样本,所以代理任务如果用的不好,有可能会限制模型的泛化性。

基于判别式的方法在目前取得了 SOTA 的效果(如 MOCO),所以本文作者为了探究其原因,就做了一些探索和实验,并且证明了下面这几个结论:

  • 在代理任务中,结合使用多种不同的数据增强方式能得到更好的特征表达,而且数据增强为无监督对比学习带来的效果提升大于有监督学习
  • 作者在特征表达的计算 contrastive loss 之间引入了一个可学习的非线性 transformer,能很大程度的提高模型效果
  • 对特征进行归一化更有利于使用 contrastive cross entropy 学习的方法
  • 自监督学习需要更大的 batch size 和更长的训练时间(相比有监督学习而言)

作者正式结合了上面的几种发现,所以才构建了一个简单的网络框架 SimCLR

二、方法

2.1 对比学习框架

在这里插入图片描述

SimCLR 是通过最大化同一样本的不同视角在特征空间中的一致性来学习的,网络结构如图 2 所示

  • 首先,给定一个输入样本 x,作者使用数据增强来生成两个图片,这两个图片就是一对 positive pairs。

    本文中会顺序的使用 3 种数据增强:random cropping → resize 回原来的尺寸 → random color distortion →随机高斯噪声。因为作者通过实验发现 random crop 和 color distortion 的结合能取得最好的效果。

  • 然后,使用基础 encoder f ( . ) f(.) f(.) 来抽取数据的特征,这里的 encoder 选择的是 ResNet

  • 接着,对得到的特征使用 projection head g ( . ) g(.) g(.) 来将特征映射到 contrastive loss space。这里的 g ( . ) g(.) g(.) 是有一层隐藏层的 MLP。这里的 g ( . ) g(.) g(.) 是非线性的,因为使用了 ReLU 激活函数。

  • 最后,在最终的特征上进行对比预测任务,使用的是对比学习 loss,也就是在给定一堆经过变换后的样本,模型要能通过给定的 x i x_i xi 识别出其对应的正样本 x j x_j xj

对比学习具体是怎么学习的呢:

  • 首先,假设一个 batch 输入了 N 个 samples,经过代理任务后,就能得到 2N 个 augumented samples

  • 然后,使用 f ( . ) f(.) f(.) g ( . ) g(.) g(.) 进行对应的特征提取,得到 z i z_i zi z j z_j zj

  • 接着,计算对比学习 loss,对于一个样本 z i z_i zi,只有一个正样本 z j z_j zj,其余所有的 2(N-1) 个 augumented samples 都是负样本,所以样本 i 对应的 loss 函数如下,分母是剔除了 i 自己,sim 表示点乘, τ \tau τ 表示温度参数

    在这里插入图片描述

SimCLR 的整体过程:

这里为什么是 2k-1 次呢,因为一个 sample 得到的两个 aug samples 都是当前 batch 内的样本,所以每个样本都会和其他所有的样本计算 loss,i 和 j 计算一次,j 和 i 也会计算一次,所以每个样本都会计算 2k-1 次 loss。然后最后的 L 也除以 2 了,因为每个样本都计算了 2 次。

在这里插入图片描述

2.2 训练所使用的 batch size

我们已知对比学习比较依赖于负样本的数量,只有在负样本数量较大的时候才能学习到更有区分力的特征

所以作者使用了从 256~8192 大小的 batch size,当 batch size 为 8192 时,每个样本对应的负样本的数量就是 16382 (16382=2x(8192-1))

如果使用这么大的 batch size 同时使用 SGD/Momentum 优化器结合线性学习率变化的话,就会不稳定,所以作者使用了 LARS 优化器。

Global BN:所有机器上的数据共同计算 BN 的 mean 和 variance

在分布式训练中,BN 的 mean 和 variance 是计算单个卡上的所有样本得到的。在对比学习中,positive pairs 是同一机器上得到的,会导致信息泄露,泄露给模型说所有的正样本对都在对角线上,模型能使用泄露的局部信息来提高准确率,而不用提高学习效果。

为了避免这个问题,作者使用一次迭代的所有机器上的全部数据来计算 mean 和 variance,MOCO 中使用 shuffling data 的方式解决。

2.3 数据增强方式

在这里插入图片描述

图 4 中展示了不同的数据增强方式,作者也对不同的数据增强方式进行了凉凉组合,最后发现结合 random crop 和 color 的方式能够得到最好的效果

在这里插入图片描述

且作者证明了 strong augmentation 在无监督学习中更重要

在这里插入图片描述

2.4 更大的模型更有利于无监督对比学习

在这里插入图片描述

2.5 非线性映射头能带来更好的效果

在这里插入图片描述

2.6 更大的 batch size 和更长的训练时间更有利于对比学习

在这里插入图片描述

2.7 测评方式

之前和很多无监督与村里方法的测评都是在 ImageNet 上,还有一些在 cifar-10 上。

作者在本文中也会使用迁移学习的方式来测评预训练模型的效果

作者测评使用的方式是 linear protocal(就是冻结预训练 backbone,只训练最后添加的分类头)

设置:

  • base encoder:R50
  • projection head(从输出映射到 128-d 特征):2 层 MLP
  • loss: NT-Xent,使用 LARS ,学习率为 4.8,weight decay 为 10^-6
  • batch size:4096
  • epoch:100

三、效果

在这里插入图片描述

在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【无监督】3、SimCLRv1 的相关文章

随机推荐

  • centos8安装mysql8

    本文主要介绍如何在Centos8下安装Mysql 一 下载Mysql 使用wget命令下载mysql安装包 确保系统已经安装了wget 如果没有安装 执行 yum install wget 安装 wget https repo mysql
  • React的基础概念JSX

    1 创建一个用react写的页面 div div
  • 面试那些题(1)

    更新ing 一 Canvas和SVG的区别是什么 1 Canvas主要是用笔刷来绘制2D图形的 2 SVG 主要是用标签来绘制不规则矢量图的 3 相同点 都是主要用来画2D图形的 4 不同点 Canvas画的是位图 SVG画的是矢量图 5
  • 检查服务器的系统类型,查看服务器的操作系统类型

    查看服务器的操作系统类型 内容精选 换一换 查看用户的镜像类型 如果是公共镜像则排除私有镜像的源镜像问题 镜像类型单击 申请服务器 查看能否创建出此镜像的弹性云服务器 申请完成后未出现此镜像对应的弹性云服务器 则此类镜像可能已经下线 属于老
  • git出现fatal: Authentication failed for ‘http:xxxx.git/‘‘

    在git上clone的时候 输入用户名和密码第一遍输错以后 之后就无法再自动弹出输入用户名和密码的窗口了 出现错误如下 fatal Authentication failed for http xxxx git 解决办法 git confi
  • 【Linux】UDP、TCP协议

    目录 前言 1 UDP协议 1 1 UDP协议段格式 1 2 UDP的特点 1 3 UDP的缓冲区 2 TCP协议 2 1 TCP报文格式 2 2 TCP的确认应答机制 2 3 流量控制 2 4 标志位 2 4 1 ACK SYN 2 4
  • 时间序列预测方法总结

    时间序列预测方法总结 数据准备 方法1 朴素法 方法2 简单平均法 方法3 移动平均法 方法4 简单指数平滑法 平面预测 优化 方法5 霍尔特 Holt 线性趋势法 方法6 Holt Winters季节性预测模型 加法分量形式 方法7 自回
  • #pragma预处理指令

    pragma是C和C 编译器提供的一种预处理指令 preprocessor directive 用于控制编译器的行为或指示特定的编译器选项 它以 pragma开头 后面跟着不同的命令或参数 pragma指令在源代码被编译之前由预处理器进行处
  • 热烈祝贺开源社顾问委员会委员姜宁当选 2022 Apache 软件基金会新任董事~

    设计 张千禧 内容 SegmentFault思否 Apache软件基金会官网 责编 李明康 在刚刚结束的 ASF Annual Meeting 上 2022 年新任 ASF Member 及董事会成员诞生了 Apache 软件基金会通过官方
  • iOS vs Flutter(语法篇)

    iOS开发者入门Flutter 首先说一下 为什么要关心iOS和Flutter的区别问题 因为移动端开发的业务逻辑设计模式等是一致的 区别可能只在于使用的语言不同 实现逻辑的风格不同而已 所以这里我们先分析一下iOS和Flutter的区别到
  • 华为OD机试真题-单向链表的中间节点/哈希表【2023Q1】

    题目描述 求单向链表中间的节点值 如果奇数个节点取中间 偶数个取偏右边的那个值 输入描述 第一行 链表头节点地址path 后续输入的节点数n 后续输入每行表示一个节点 格式 节点地址 节点值 下一个节点地址 1表示空指针 输入保证链表不会出
  • 多线程与高并发v2.0版

    多线程是程序员面试时常常会面对的问题 对多线程概念的掌握和理解水平 也会被一些面试官用来衡量一个人的编程实力的重要参考指标 另附一张思维导图供大家参考学习 不论是实际工作需要还是为了应付面试 掌握多线程都是程序员职业生涯中一个必须经过的环节
  • SSM整合web项目访问同时html和jsp页面

    ssm 配置请看 https blog csdn net qq 19688207 article details 114578526 spm 1001 2014 3001 5501 html页面访问路径 http localhost 808
  • 第八章 多态(下)

    第八章 多态 下 本章多态目前就到此介绍完毕 可能还会有些疑问 不过后面还会有很多设计到多态的地方 通过不断学习 最后一定会将这些知识掌握的 同时这一章所讲解的知识也还是比较详细的 多态基本上都会在我们的代码中体现出来 只是之前不知道 不明
  • java.sql.SQLRecoverableException: IO 错误: Undefined Error错误解决办法

    应用报错 java sql SQLRecoverableException IO 错误 Got minus one from a read call 据开发人员描述 起多个服务 最后服务的时候报这个错 无论最后的服务是啥 提供的报错日志 是
  • 带权 有向 图 邻接矩阵 建立 插入 删除

    include
  • curl移植

    一 下载curl https curl se download html 二 交叉编译 configure prefix PWD install host arm buildroot linux CC arm buildroot linux
  • 蓝桥杯冲击-02约数篇(必考)

    文章目录 前言 一 约数是什么 二 三大模板 1 试除法求约数个数 2 求约数个数 3 求约数之和 三 真题演练 前言 约数和质数一样在蓝桥杯考试中是在数论中考察频率较高的一种 在省赛考察的时候往往就是模板题 难度大一点会结合其他知识点考察
  • 密码技术-数字签名

    一 数字签名 用私钥生成数字签名 用公钥验证签名 数字签名的方法 直接对消息签名 很少用这个 1 Alice 用自己的私钥对消息进行加密 2 Alice 将消息和签名发送给 Bob 3 Bob 用 Alice 的公钥对收到的签名进行解密 4
  • 【无监督】3、SimCLRv1

    文章目录 一 背景 二 方法 2 1 对比学习框架 2 2 训练所使用的 batch size 2 3 数据增强方式 2 4 更大的模型更有利于无监督对比学习 2 5 非线性映射头能带来更好的效果 2 6 更大的 batch size 和更