【神经网络搜索】DARTS: Differentiable Architecture Search

2023-11-17

【GiantPandaCV】DARTS将离散的搜索空间松弛,从而可以用梯度的方式进行优化,从而求解神经网络搜索问题。本文首发于GiantPandaCV,未经允许,不得转载。https://arxiv.org/pdf/1806.09055v2.pdf

1. 简介

此论文之前的NAS大部分都是使用强化学习或者进化算法等在离散的搜索空间中找到最优的网络结构。而DARTS的出现,开辟了一个新的分支,将离散的搜索空间进行松弛,得到连续的搜索空间,进而可以使用梯度优化的方处理神经网络搜索问题。DARTS将NAS建模为一个两级优化问题(Bi-Level Optimization),通过使用Gradient Decent的方法进行交替优化,从而可以求解出最优的网络架构。DARTS也属于One-Shot NAS的方法,也就是先构建一个超网,然后从超网中得到最优子网络的方法。

2. 贡献

DARTS文章一共有三个贡献:

  • 基于二级最优化方法提出了一个全新的可微分的神经网络搜索方法。
  • 在CIFAR-10和PTB(NLP数据集)上都达到了非常好的结果。
  • 和之前的不可微分方式的网络搜索相比,效率大幅度提升,可以在单个GPU上训练出一个满意的模型。

笔者这里补一张对比图,来自之前笔者翻译的一篇综述:<NAS的挑战和解决方案-一份全面的综述>

ImageNet上各种方法对比,DARTS属于Gradient Optimization方法

简单一对比,DARTS开创的Gradient Optimization方法使用的GPU Days就可以看出结果非常惊人,与基于强化学习、进化算法等相比,DARTS不愧是年轻人的第一个NAS模型。

3. 方法

DARTS采用的是Cell-Based网络架构搜索方法,也分为Normal Cell和Reduction Cell两种,分别搜索完成以后会通过拼接的方式形成完整网络。在DARTS中假设每个Cell都有两个输入,一个输出。对于Convolution Cell来说,输入的节点是前两层的输出;对于Recurrent Cell来说,输入为当前步和上一步的隐藏状态。

DARTS核心方法可以用下面这四个图来讲解。

DARTS Overview

(a) 图是一个有向无环图,并且每个后边的节点都会与前边的节点相连,比如节点3一定会和节点0,1,2都相连。这里的节点可以理解为特征图;边代表采用的操作,比如卷积、池化等。

引入数学标记:

节点(特征图)为: x ( i ) x^{(i)} x(i) 代表第i个节点对应的潜在特征表示(特征图)。

边(操作)为: o ( i , j ) o^{(i,j)} o(i,j) 代表从第i个节点到第j个节点采用的操作。

每个节点的输入输出如下面公式表示,每个节点都会和之前的节点相连接,然后将结果通过求和的方式得到第j个节点的特征图。

x ( j ) = ∑ i < j o ( i , j ) ( x ( i ) ) x^{(j)}=\sum_{i\lt j} o^{(i, j)}(x^{(i)}) x(j)=i<jo(i,j)(x(i))

所有的候选操作为 O \mathcal{O} O, 在DARTS中包括了3x3深度可分离卷积、5x5深度可分离卷积、3x3空洞卷积、5x5空洞卷积、3x3最大化池化、3x3平均池化,恒等,直连,共8个操作。

(b) 图是一个超网,将每个边都扩展了8个操作,通过这种方式可以将离散的搜索空间松弛化。具体的操作根据如下公式:

o ˉ ( i , j ) ( x ) = ∑ o ∈ O exp ⁡ ( α o ( i , j ) ) ∑ o ′ ∈ O exp ⁡ ( α o ′ ( i , j ) ) o ( x ) \bar{o}^{(i, j)}(x)=\sum_{o \in \mathcal{O}} \frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)} o(x) oˉ(i,j)(x)=oOoOexp(αo(i,j))exp(αo(i,j))o(x)

这个可以分为两个部分理解,一个是 o ( x ) o(x) o(x)代表操作,一个代表选择概率 exp ⁡ ( α o ( i , j ) ) ∑ o ′ ∈ O exp ⁡ ( α o ′ ( i , j ) ) \frac{\exp \left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp \left(\alpha_{o^{\prime}}^{(i, j)}\right)} oOexp(αo(i,j))exp(αo(i,j)),这是一个softmax构成的概率,其中 α o ( i , j ) \alpha_o^{(i,j)} αo(i,j)表示 第i个节点到第j个节点之间操作的权重,这也是之后需要搜索的网络结构参数,会影响该操作的概率。即以下公式:
s o f t m a x ( α ) × o p e r a t i o n w ( x ) softmax(\alpha)\times operation_{w}(x) softmax(α)×operationw(x)
左侧代表当前操作的概率,右侧代表当前操作的参数。

©和(d)图 是保留的边,训练完成以后,从所有的边中找到概率最大的边,即以下公式:
o ( i , j ) = argmax ⁡ o ∈ O α o ( i , j ) o^{(i, j)}=\operatorname{argmax}_{o \in \mathcal{O}} \alpha_{o}^{(i, j)} o(i,j)=argmaxoOαo(i,j)

4. 数学推导

DARTS将NAS问题看作二级最优化问题,具体定义如下:

min ⁡ α L v a l ( w ∗ ( α ) , α )  s.t.  w ∗ ( α ) = argmin ⁡ w L train  ( w , α ) \begin{aligned} \min _{\alpha} & \mathcal{L}_{v a l}\left(w^{*}(\alpha), \alpha\right) \\ \text { s.t. } & w^{*}(\alpha)=\operatorname{argmin}_{w} \mathcal{L}_{\text {train }}(w, \alpha) \end{aligned} αmin s.t. Lval(w(α),α)w(α)=argminwLtrain (w,α)

w ∗ ( α ) w*(\alpha) w(α) 代表当前网络结构参数 α \alpha α的情况下,训练获得的最优的网络结构参数。

第一行代表:在验证数据集中,在特定网络操作参数w下,通过训练获得最优的网络结构参数 α \alpha α

第二行表示:在训练数据集中,在特定网络结构参数 α \alpha α下,通过训练获得最优的网络操作参数 w w w

条件:在结构确定的情况下,获得最优的网络操作权重

​ ----- 结构确定,训练好卷积核

目标:在网络操作权重确定的情况下,获得最优的结构

​ ----- 卷积核不动,选择更好的结构

最简单的方法是通过交替优化参数 w w w和参数 α \alpha α, 来得到最优的结果,伪代码如下:

DARTS伪代码

交替优化的复杂度非常高,是 O ( ∣ α ∣ ∣ w ∣ ) O(|\alpha||w|) O(αw), 这种复杂度不可能投入使用,所以要将复杂度进行优化,用复杂度低的公式近似目标函数。
∇ α L val  ( w ∗ ( α ) , α ) ≈ ∇ α L v a l ( w − ξ ∇ w L t r a i n ( w , α ) , α ) \nabla_{\alpha} \mathcal{L}_{\text {val }}\left(w^{*}(\alpha), \alpha\right) \approx \nabla_{\alpha} \mathcal{L}_{v a l}\left(w-\xi \nabla_{w} \mathcal{L}_{t r a i n}(w, \alpha), \alpha\right) αLval (w(α),α)αLval(wξwLtrain(w,α),α)
这种近似方法在Meta Learning中经常用到,详见《Model-agnostic meta-learning for fast adaptation of deep networks》,也就是通过使用单个step的训练调整w,让这个结果来近似 w ∗ ( α ) w*(\alpha) w(α)

然后对右侧公式进行推导,得到梯度优化以后的表达式:

师兄提供


这里求梯度使用的是链式法则,回顾一下:
z = f ( g 1 ( x ) , g 2 ( x ) ) z=f(g1(x),g2(x)) z=f(g1(x),g2(x))

则梯度计算为:
∂ z ∂ x = ∂ g 1 ∂ x × ∂ z ∂ g 1 + ∂ g 2 ∂ x × ∂ z ∂ g 2 \frac{\partial z}{\partial x}=\frac{\partial g1}{\partial x} \times \frac{\partial z}{\partial g1} + \frac{\partial g2}{\partial x}\times\frac{\partial z}{\partial g2} xz=xg1×g1z+xg2×g2z

或者

师兄提供

上述公式中Di代表对 f ( g 1 ( α ) , g 2 ( α ) ) f(g1(\alpha),g2(\alpha)) f(g1(α),g2(α))的第i项的偏导。


手敲公式太痛苦了

整理以后结果就是:

计算结果

减号后边的是二次梯度,权重的梯度求解很麻烦,这里使用泰勒公式将二阶转为一阶(h是一个很小的值)。

泰勒公式复习

利用最右下角的公式:

A = ∇ ω ′ L v a l ( ω ′ , α ) A=\nabla_{\omega^{\prime}} \mathcal{L}_{v a l}\left(\omega^{\prime}, \alpha\right) A=ωLval(ω,α), h = ϵ h=\epsilon h=ϵ, x 0 = w x_0=w x0=w, f = ∇ α L train  ( ⋅ , ⋅ ) f=\nabla_{\alpha} \mathcal{L}_{\text {train }}(\cdot, \cdot) f=αLtrain (,), 代入可得(其中经验上设置 ϵ = 0.01 ∣ ∣ ∇ w ′ L v a l ( w ′ , α ) ∣ ∣ 2 \epsilon=\frac{0.01}{||\nabla_{w'}\mathcal{L}_{val}(w',\alpha)||_2} ϵ=wLval(w,α)20.01)
∇ α , ω 2 L train  ( ω , α ) ⋅ ∇ ω ′ L val  ( ω ′ , α ) ≈ ∇ α L train  ( ω + , α ) − ∇ α L train  ( ω − , α ) 2 ϵ \nabla_{\alpha, \omega}^{2} \mathcal{L}_{\text {train }}(\omega, \alpha) \cdot \nabla_{\omega^{\prime}} \mathcal{L}_{\text {val }}\left(\omega^{\prime}, \alpha\right) \approx \frac{\nabla_{\alpha} \mathcal{L}_{\text {train }}\left(\omega^{+}, \alpha\right)-\nabla_{\alpha} \mathcal{L}_{\text {train }}\left(\omega^{-}, \alpha\right)}{2 \epsilon} α,ω2Ltrain (ω,α)ωLval (ω,α)2ϵαLtrain (ω+,α)αLtrain (ω,α)

其中

ω ± = ω ± ϵ ∇ ω ′ L v a l ( ω ′ , α ) \omega^{\pm}=\omega \pm \epsilon \nabla_{\omega^{\prime}} \mathcal{L}_{v a l}\left(\omega^{\prime}, \alpha\right) ω±=ω±ϵωLval(ω,α)

这样就可以将二次梯度转化为多个一次梯度。到这里复杂度从 O ( ∣ α ∣ ∣ w ∣ ) O(|\alpha||w|) O(αw)优化到 O ( ∣ α ∣ + ∣ w ∣ ) O(|\alpha|+|w|) O(α+w)

一阶近似: ξ = 0 \xi=0 ξ=0, 下面式子的二阶倒数部分就消失了,这样模型的梯度计算可能不够准确,效果虽然不如二阶,但是计算速度快。只需要假设当前的 w w w就是 w ∗ ( α ) w*(\alpha) w(α), 然后启发式优化验证集上的loss值即可。

计算结果

代码实现上也有一定的区别,代码将在下一篇讲解。

5. 实验设置

这里我们暂且先关注CIFAR10上的实验效果。DARTS构成网络的方式之前已经提到了,首先为每个单元内布使用DARTS进行搜索,通过在验证集上的表现决定最好的单元然后使用这些单元构建更大的网络架构,然后从头开始训练,报告在测试集上的表现。

CIFAR10上搜索操作有:

  • 3x3 & 5x5 可分离卷积
  • 3x3 & 5x5 空洞可分离卷积
  • 3x3 max & avg pooling
  • identiy
  • zero

实验详细设置:

  • 所有操作的stride=1, 为了保证他们空间分辨率,使用了padding。

  • 卷积操作使用的是ReLU-Conv-BN的顺序,并且每个可分离卷积会被使用两次。

  • 卷积单元包括了7个节点,输出节点为所有中间节点concate以后的结果。

  • 网络整体深度的1/3和2/3处强制设置了reduction cell来降低空间分辨率。

  • 网络结构参数 α normal \alpha_{\text{normal}} αnormal是被所有normal cell共享的,同理 α reduce \alpha_{\text{reduce}} αreduce是被所有reduction cell共享的。

  • 并没有使用全局batch normalization, 使用的是batch-specific statistic batch normalization

  • CIFAR10一半的训练集作为验证集。

  • 8个单元的消亡了使用DARTS训练50个epoch, batch size设置为64, 初始通道个数为16。

  • 使用momentum SGD来优化权重,初始学习率设置为0.025,momentum 0.9 weight decay为0.0004.

  • 网络架构参数 α \alpha α 使用0作为初始化,使用Adam优化器来优化 α \alpha α参数,初始学习率设置为0.0004,momentum为(0.5,0.999)weight decay=0.001。

CIFAR10上搜索结果和其他算法对比

可以看到,搜索结果最终是优于AmoebaNet-A和NASNet-A。具体搜索得到的Normal Cell和Reduction Cell可视化如下:

Normal Cell & Reduction Cell for CIFAR10

网络评价

网络优化对初始化值是非常敏感的,为了确定最终的网络结构,DARTS将使用随机种子运行四次,每次得到的Cell都会在训练集上从头开始训练很短一段时间大概100 epochs , 然后根据验证集上得到的最优结果决定最终的架构。

为了验证被选择的架构:

  • 随机初始化权重
  • 从头开始训练
  • 报告测试集上的模型表现

CIFAR10搜索的模型迁移到ImageNet更多细节:

  • 20个单元的大型网络使用了96的batch size, 训练了600个epoch
  • 初始通道个数由16修改为36,为了让模型的参数和其他模型参数量相当。
  • 其他参数设置和搜索过程中参数一样
  • 使用了cutout的数据增强方法,以0.2的概率进行path dropout
  • 使用了auxiliary tower(辅助头,在这里施加loss, 提前进行反向传播,InceptionV3中提出)
  • 使用PyTorch在单个GPU上花费1.5天时间训练完ImageNet,独立训练10次作为最终的结果。

CIFAR10上搜索结果

使用二阶优化方法+cutout的数据增强方法,DARTS能达到约2.76的准确率,笔者使用nni进行了实验,最终结果是2.6%的Test Error。

nni上darts的实验结果

6. 致谢&参考

感谢师兄提供的资料,以及知乎上两位大佬,他们文章链接如下:

薰风读论文|DARTS—年轻人的第一个NAS模型 https://zhuanlan.zhihu.com/p/156832334

【论文笔记】DARTS公式推导 https://zhuanlan.zhihu.com/p/73037439

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【神经网络搜索】DARTS: Differentiable Architecture Search 的相关文章

随机推荐

  • 数组练习题(编程题)

    1 从终端 键盘 读 20个数据到数组中 统计其中正数的个数 并计算这些正数之和 int sum 0 int count 0 int input int arr 20 0 初始化处理 arr 0x0000002d1b13f8c0 85899
  • 7.25总结,正则表达式+号的含义

    一 正则表达式 由 括起来的是需要判断的字符 eg a z A Z 0 9 在 加 号表示多次并且连续满足 条件的式子 表示有没有 String s1 123qwe13qwe s1 s1 replaceAll 0 9 替换 System o
  • 用Python爬虫接私活,赚了32K!

    网络爬虫 很多人觉得这是技术控的专属 实际上爬虫是人人都能掌握的技能 爬虫到底能干什么 基本你所能看到的全部信息 它都能抓取 例如 收集并批量下载某音乐软件付费歌曲 某视频软件的付费视频 采集北京所有小区的信息及北京所有小区的所有历史成交记
  • centos7 lvm 创建脚本

    centos7 lvm 创建脚本 Centos7 lvm创建 适用场景只有一块新加的磁盘 且未分区 挂载目录为riva 可自定义 date 2023 bin bash 注意此处变量 Disk dev sd 不同的平台会有差异 比如腾讯云为
  • Attempted to load tokenizers/punkt/PY3/english.pickle

    分明已经把punkt放到服务器相应文件下 但是还是显示没成功 错误原因是解压得时候文件目录有两个punkt
  • VUE之自定义插件

    index js文件 import promptBox from prompt box vue 定义插件对象 const PromptBox vue的install方法 用于定义vue插件 PromptBox install functio
  • rocketMq介绍和安装

    rocketMq介绍和安装 Mq介绍 MQ MessageQueue 消息队列 队列 是一种FIFO 先进先出的数据结构 消息由生产者发送到MQ进行排队 然后按原来的顺序交由消息的消费者进行处理 QQ和微信就是典型的MQ MQ的作用 主要有
  • Vue项目中如何引入外部js文件,并使用其中定义的函数

    Vue项目中如何引入外部js文件 并使用其中定义的函数 一些常用的功能函数 我们希望将其封装起来放入一个外部JS文件中 好方便我们在需要的时候使用 vue可以使用import指令引入外部文件 但是作为新手 在使用过程中难免会导致很多错误 这
  • maven导入依赖失败问题——最系统、最彻底的解决方案

    目录 一 idea配置maven 1 配置maven版本及本地仓库 二 清理Local Repository的 lastUpdated文件 三 在idea中重新导入依赖 一 idea配置maven 1 配置maven版本及本地仓库 关于项目
  • 排序算法之堆排序(Heap Sort)——C语言实现

    堆排序 Heapsort 是指利用堆积树 堆 这种数据结构所设计的一种排序算法 它是选择排序的一种 算法分析 在学习堆排序之前我们要先了解堆这种数据结构 堆的定义如下 n个元素的序列 k1 k2 kn 当且满足以下关系时 称之为堆 若将此序
  • 短视频平台-小说推文(知乎)推广任务详情

    知乎会员 知乎日结内测中 可能暂只对部分优质会员开放 2023 03 29通知 知乎拉新项目 由于内部测试转化较低 暂时下线 原有关键词出单不受影响 1 关键词 1 1 选择会员文 在知乎 首页 或者 会员 里面选取 需要选取文章链接 点进
  • 二.centos7和主机实现xftp传文件,包含解决一些主机虚拟机互联的问题

    首先打开安装好的centos7 一 需要获取的信息 虚拟机 ifconfig 信息 主要是ensxx的信息 主机 ipconfig 信息 主要是vm8的信息 虚拟机点 编辑 虚拟网络编辑器 界面 查看虚拟机右下是否连接 点虚拟机右上角 看有
  • Nacos启动: com.mysql.jdbc.exceptions.jdbc4.MySQLSyntaxErrorExceptionTable nacos_config.config_inf

    把nacos 1 1 4 server下载到本地之后 然后直接在bin目录下启动startup cmd报错 本地Mysql版本 5 6 44 nacos server版本 1 1 4 我找了很多解决办法 更多的说的是nacos自带的mysq
  • vue 实时往数组里追加数据

    使用Vue set 以下来解读一下 Vue set this tableDatas this selected obj 1 this tableDatas是我们声明好的数组 以下是自定义数据 tableDatas id 1 caseName
  • Python之数据分析(三维立体图像、极坐标系、半对数坐标)

    文章目录 写在前面 一 三维立体图像 1 三维线框 2 三维曲面 3 三维散点 二 极坐标系 三 半对数坐标 写在前面 import numpy as np import matplotlib pylab as mp 因此文章中的np就代表
  • Linux学习大纲

  • 以太网设计FAQ:以太网MAC和PHY

    问 如何实现单片 以太网 微控制器 答 诀窍是将微控制器 以太网媒体接入控制器 MAC 和物理接口收发器 PHY 整合进同一芯片 这样能去掉许多外接元器件 这种方案可使MAC和PHY实现很好的匹配 同时还可减小引脚数 缩小芯片面积 单片以太
  • MySQL第一讲 一遍让你彻底掌握MVCC多版本并发控制机制原理

    Mysql在可重复读隔离级别下 同样的sql查询语句在一个事务里多次执行查询结果相同 就算其它事务对数据有修改也不会影响当前事务sql语句的查询结果 这个隔离性就是靠MVCC Multi Version Concurrency Contro
  • Python爬虫入门8:BeautifulSoup获取html标签相关属性

    前往老猿Python博客 https blog csdn net LaoYuanPython 一 引言 在上节 https blog csdn net LaoYuanPython article details 113091721 Pyth
  • 【神经网络搜索】DARTS: Differentiable Architecture Search

    GiantPandaCV DARTS将离散的搜索空间松弛 从而可以用梯度的方式进行优化 从而求解神经网络搜索问题 本文首发于GiantPandaCV 未经允许 不得转载 1 简介 此论文之前的NAS大部分都是使用强化学习或者进化算法等在离散