Unsupervised Data Augmentation For Consistency Training 论文阅读

2023-05-16

Unsupervised Data Augmentation For Consistency Training 论文阅读

UDA这篇文章针对的是半监督学习中无标签数据的增强,论文提出,使用有标签数据的data agumentation方法,也能有效的应用于无标签数据的增强中。论文在文本分类和图像分类问题上进行了实验对比。

这是半监督训练的流程图,左侧是有标签数据,右侧是无标签数据。可以看出有标签数据的做法和普通做法没什么区别,而无标签数据采用的一种叫一致性训练的思想(不是这篇论文的成果):首先对无标签数据做一个增强,然后将增强前的数据和增强后的数据都送进网络,出一个预测结果,将这两个结果算一个KL散度作为无监督的loss,和有监督的loss加在一起做BP。目标函数就是这个样子:

min ⁡ θ J ( θ ) = E x , y ∗ ∈ L [ − log ⁡ p θ ( y ∗ ∣ x ) ] + λ E x ∈ U E x ^ ∼ q ( x ^ ∣ x ) [ D K L ( p θ ~ ( y ∣ x ) ∥ p θ ( y ∣ x ^ ) ) ) ] \left.\min _{\theta} \mathcal{J}(\theta)=\mathbb{E}_{x, y^{*} \in L}\left[-\log p_{\theta}\left(y^{*} | x\right)\right]+\lambda \mathbb{E}_{x \in U} \mathbb{E}_{\hat{x} \sim q(\hat{x} | x)}\left[\mathcal{D}_{\mathrm{KL}}\left(p_{\tilde{\theta}}(y | x) \| p_{\theta}(y | \hat{x})\right)\right)\right] minθJ(θ)=Ex,yL[logpθ(yx)]+λExUEx^q(x^x)[DKL(pθ~(yx)pθ(yx^)))]

前一项就是有标签的loss,后一项就是无标签的loss, λ \lambda λ表示两者之间的比例。这里面有一个训练时的trick,在做BP时,无标签那边未增强数据那一条支路(也就是x那一路)是BP阻断的,就是反向传播不计算也不使用那里的梯度(但不是不更新M),事实上确实挺有用的(因为一开始我没阻断…)

本文关注的问题是,用什么样的方法来做无标签数据的增强,训练的效果好呢?本文得出的结论是,用在有标签数据上增强效果好的方法,用在无标签数据上也好,因此就挑了三种在有标签数据增强上表现好的方法来做实验。文本相关的我不懂,这个RandAugment(RA)是做图像增强的,是从一个叫AutoAugment的方法简化过来的。RA增强时,每次使用两种数据增强方法,是从PIL中选出来的15种中随机抽两种(可以重复)。这两种方法,每种方法都有50%的概率会被执行(或者以另一种方式执行,比如旋转可以是正度数也可以是负度数),还有一个magnitude参数来描述具体的执行程度(比如旋转可以转0°~60°,magnitude用来确定具体旋转多少度)。论文还使用了一个16x16的mask,来随机遮盖(涂黑)cifar10(32*32)上的一个区块。论文在附录还给了一个研究成果,大概是每次做十种增强,训练出来的效果最好。

论文中使用了自己提出的TSA(Training Signal Annealing),来防止有标签数据过拟合。具体来说,如果一个有标签数据的预测概率超过了当前阈值,那么在这次计算loss时就不算它的了(因为它已经表现很好了,再训练也不过是过拟合而已)。如下图,针对容易/一般/不易过拟合三种情况,阈值曲线分别是指数型/线性型/对数型,来影响拟合的速度。


实际上,半监督的研究我是做不起的,看后面的训练参数就知道做不到。就拿Cifar10来说,论文训练了400k个step,有监督数据batchsize64,无监督数据batchsize320,如果有监督数据取4k(论文的某一个实验),这玩意要跑6.4k个epoch,我一个epoch就要8分钟(当然,实现上有点差别),没设备是真的不行…以后不做这种自己根本不可能做的研究和复现了…

我的复现实现细节上有些差别,我用了5k个有监督数据,有监督数据batchsize32,无监督数据batchsize288,刚好一个epoch每个图片用一次。我一共只训练了15k个step,100个epoch,学习率论文是从0.03到0.004余弦下降, λ \lambda λ取1。我是从0.004开始每30个epoch降到1/2, λ \lambda λ取10。这样子训练出来也就那样…我也不想折腾它了,还是干点能干的吧。代码挂在GITHUB上:https://github.com/cyfwry/UDA-repo-pytorch
以上是我一周前的想法,后来我去自己写了个ESPCN来训练,结果也做不出来…折腾了一周了…我想我可能确实不适合DL这个领域,也许应该转行了吧…

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Unsupervised Data Augmentation For Consistency Training 论文阅读 的相关文章

随机推荐