元学习算法MAML论文详解

2023-05-16

论文信息

题目:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks

作者:Chelsea Finn(伯克利大学),Pieter Abbeel ,Sergey Levine

年份:19

论文地址:论文地址

代码:maml-pytorch

基础补充

小样本学习(Few Shot Learing)概念

  • N-way N-shot。N-way 的意思是N分类,N-shot是在学习的样本中,每个类只提供5个样本,比如说让你学习辨认一只猫,只有5张猫的照片供你学习。

内容

摘要

提出一种 meta-learning 算法,该算法是模型无关的,适用于任何利用梯度下降的方法来训练的模型,并且适用于任何任务,包括:classification,regression,and reinforcement learning. meta-learning ,目标是在不同的任务上训练一个模型,使得该模型可以仅仅利用少量的数据,就可以解决新的任务。在提出的方法中,该模型的参数,是显示的进行训练的,使得新任务中使用少量的梯度步和少量的训练数据就可以产生良好的泛化性能 效果上来说,我们的方法得到的模型更加容易进行微调。该论文表明这种方法可以在两个 few-shot image classification benchmarks 得到极好的效果,并且加速了策略梯度强化学习算法。

动机

动机

  • 快速的进行学习是最近机器学习领域的一个研究热点问题
  • 人类可以根据已有的先验知识,就可以根据少量新的信息,快速的掌握一项新的技能
  • 但是这种快速且灵活的学习任务对于机器来说,确是非常困难的
  • 能不能是机器从少量样本中学习,然后应用于不同的新任务时也能表现出很好的能力

创新:(key idea)

  • 该论文提出的方法是训练模型的初始参数,使模型在参数更新后对新任务具有最大的性能。用在来自新任务的少量数据进行,计算通过一个或多个梯度步骤,就能满足要求;
  • 与前人的工作不同,他们通常基于 learn an update function or learning rule, 本论文的方法不会增加所需要学习参数的数量,也不会限制模型的框架,可以很容易与各种全连接,全卷积,或者循环神经网络。也可以与各种损失函数结合,如:可微分的监督损失函数,或者不可导的强化学习目标函数。

核心算法

从一个动态系统的角度,本论文的学习过程可以看做是:使新任务相对于参数的损失函数的灵敏度最大化:,当敏感度高的时候,对参数进行较小的局部改变可以带来任务损失上大的改善。找到某参数使得在多个任务上loss都发生较大改变,即找到灵敏性强的参数

算法1
分为两步,

  • 第一步是:在 T i \mathcal{T}_{i} Ti上训练得到 θ i \theta_{i} θi
    θ i ′ = θ − α ∇ θ L T i ( f θ ) \theta_{i}^{\prime}=\theta-\alpha \nabla_{\theta} \mathcal{L}_{\mathcal{T}_{i}}\left(f_{\theta}\right) θi=θαθLTi(fθ)

  • 第二步是:
    θ ← θ − β ∇ θ ∑ T i ∼ p ( T ) L T i ( f θ i ′ ) \theta \leftarrow \theta-\beta \nabla_{\theta} \sum_{\mathcal{T}_{i} \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_{i}}\left(f_{\theta_{i}^{\prime}}\right) θθβθTip(T)LTi(fθi)
    the meta-optimization 是通过 model parameters θ \theta θ 来实现的,其中,模型被用新的模型参数 θ ′ \theta' θ 进行计算。实际上,我们提出的模型,目标是优化模型的参数,使得 one or a small number of gradient steps on a new task will produce maximally effective behavior on that task.

其中, α , β α,β α,β分别是task中的进行梯度下降的学习率、和meta-learning过程的学习率, θ θ θ 是模型(神经网络的参数) f f f的权重参数。

算法2

在一个task中,使用左边的训练集做5次SGD的过程,再使用右边的测试集计算test error,在meta-learning过程中,把一个batch的4个task的test error平均一下作为loss再去进行优化。这个过程结束后,神经网络的权重到达了下图中的P点

我们再使用这个模型或者测试这个模型的准确度怎么用呢?我们说把100类图片分成了3个子集被划分成了train(64)、test(20)、val(16)三个子集。train中有64个类,用于上述的meta-learning。现在要将这个模型用在新的任务集具有16个类的test数据集上。仔细一想,训练好的模型并没有看见过test数据集中任何类啊。现在就是要说title中的Fast Adaptation的关键字了,在5-way 5-shot设定中,在测试的时候从test数据集中随机抽取5个类,每个类抽取N(>5)张照片,其中每个类抽取5张照片,用来微调模型中的参数,比如说在一个新任务下,把模型的参数调整至 θ 3 ∗ \theta_{3}^{∗} θ3​的位置,就是task做的事,即在新任务下只用5张照片来学习一下,用剩下的照片来计算精度

实验

实验1
在使用MAML的方法,与直接使用使用数据进行预训练对比

  • 可以看出使用了maml方法在k为5都能得到比较好的效果

实验2:

结论

Model-Agnostic(与模型无关的)是说,可以把task换成其他可以进行SGD过程的模型;Deep Networks(深度网络)可以适用于所有的深度学习模型。提出的MAML能在不同的任务上训练一个模型,使得该模型可以仅仅利用少量的数据,就可以解决新的任务。

可借鉴地方

  • 算法很经典有效,可以用于PINN等网络预训练作为初始化参数,将会加速训练收敛

参考
小样本学习论文

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

元学习算法MAML论文详解 的相关文章

随机推荐

  • 【ESP32+freeRTOS学习笔记之“ESP32环境下使用freeRTOS的特性分析(4-多核下的临界区)”】

    目录 关于临界区API的更改临界区API的工作过程使用临界区的限制和注意事项 关于临界区API的更改 Vanilla FreeRTOS通过禁用中断来实现临界区域 xff0c 这可以防止抢占式上下文切换和在临界区域提供ISR xff08 中断
  • 【嵌入式环境下linux内核及驱动学习笔记-(8-内核 I/O)-信号驱动】

    目录 3 信号驱动的异步通知3 1 linux异步通知编程3 1 1 什么是信号3 1 2 信号的工作流程 3 2 应用层3 2 1 信号接收 signal函数3 2 2 应用层 fcntl 函数3 2 3 应用层信号驱动机制步骤 3 3
  • TensorFlow、Python、CUDA版本对应及下载链接

    关于版本对应 xff0c 官网很详细了 xff1a https tensorflow google cn install source 偷个懒 xff0c 我就把截图放这里吧 xff1a 1 Windows xff1a 2 Linux 和
  • 表格驱动编程在代码中的应用

    1 毕业设计中的使用 第一次使用表格驱动编程 xff0c 是在大学毕业设计的时候 做一个LL 1 的词法分析程序 xff0c 需要读取终结符 非终结符 以及推导公式 程序会根据以上信息生成FIRST集合和LAST集合 xff0c 然后根据递
  • 【嵌入式环境下linux内核及驱动学习笔记-(9-内核定时器)】

    目录 1 时钟tick中断等概念2 延时机制2 1 短延时 xff08 忙等待类 非阻塞害 xff09 2 1 1 ndelay 忙等待延迟多少纳秒2 1 2 udelay 忙等待延迟多少微秒2 1 3 mdelay 忙等待延迟多少毫秒 2
  • 【嵌入式环境下linux内核及驱动学习笔记-(10-内核内存管理)】

    目录 1 linux内核管理内存1 1 页1 2 区1 2 1 了解x86系统的内核地址映射区 xff1a 1 2 2 了解32位ARM系统的内核地址映射区 xff1a 2 内存存取2 1 kmalloc2 1 1 kfree2 1 2 k
  • 力扣刷题常用的c++库函数

    文章目录 1 xff0c max和min1 max函数2 xff0c min函数 2 xff0c sort函数sort 函数和lambda表达式 3 xff0c reverse 函数1 reverse函数可以反转一个字符串2 反转字符数组3
  • STM32学习(4)串口实验

    串口设置的一般步骤可以总结为如下几个步骤 xff1a 串口时钟使能 xff0c GPIO 时钟使能串口复位GPIO 端口模式设置串口参数初始化开启中断并且初始化 NVIC xff08 如果需要开启中断才需要这个步骤 xff09 使能串口编写
  • 【Docker】 入门与实战学习(Docker图形化工具和Docker Compose)

    文章目录 前言Docker图形化工具1 查看portainer镜像2 portainer镜像下载3 启动dockerui容器4 浏览器访问5 单机版Docker xff0c 直接选择Local xff0c 点击连接6 使用即可 Docker
  • 第三天_DOM

    第三天 Web APIs 学习目标 xff1a 能够使用removeChild 方法删除节点 能够完成动态生成表格案例 能够使用传统方式和监听方式给元素注册事件 能够说出事件流执行的三个阶段 能够在事件处理函数中获取事件对象 能够使用事件对
  • MySQL知识点整理汇总

    文章目录 前言一 数据库与SQL1 数据库与数据库管理系统2 关系数据库3 MySQL语句的种类4 MySQL语句的基本书写规则 二 MySQL语句的两大顺序1 MySQL 语句的书写顺序2 MySQL 语句的执行顺序 三 表的创建 删除与
  • 麦克科马克

    这里写自定义目录标题 欢迎使用Markdown编辑器新的改变功能快捷键合理的创建标题 xff0c 有助于目录的生成如何改变文本的样式插入链接与图片如何插入一段漂亮的代码片生成一个适合你的列表创建一个表格设定内容居中 居左 居右SmartyP
  • ROS-创建工作空间与功能包

    这里写目录标题 一 工作空间的组成与结构二 创建工作空间三 创建功能包四 设置环境变量五 功能包的package xml文件和CMakeLists txt文件 一 工作空间的组成与结构 工作空间的组成 xff1a src用于存放功能包源码
  • 「NeurIPS 2020」基于局部子图的图元学习

    点击蓝字 xff0c 设为星标 NeurIPS 2020 的接收论文 Graph Meta Learning via Local Subgraphs xff0c G META 是第一个使用局部子图来进行元学习的模型 Graph Meta L
  • Keras:Input()函数

    目录 1 Keras Input 函数 2 函数定义 xff1a 3 参数解释 4 例子 1 Keras Input 函数 作用 xff1a 初始化深度学习网络输入层的tensor 返回值 xff1a 一个tensor 2 函数定义 xff
  • JDBC入门笔记

    目录 1 xff0c JDBC概述 1 1 JDBC概念 2 xff0c JDBC快速入门 Java操作数据库的流程 2 1 编写代码步骤 3 JDBC API详解 3 1 DriverManager 3 2 Connection 3 2
  • 对抗样本入门详解

    文章目录 对抗样本基本原理对抗样本的发生对抗样本防御难在哪里对抗训练隐藏梯度defensive distillation 对抗样本的生成对抗样本生成方法介绍利用GAN生成对抗样本利用FGSM生成对抗样本代码复现 xff08 基于mnist
  • white/black-box attack(黑盒白盒攻击基础)

    基本概念 攻击方法分类标准 xff1a 假正性攻击 false positive 与伪负性攻击 false negative 假正性攻击 xff1a 原本是错误的但被被攻击模型识别为正例的攻击 eg 一张人类不可识别的图像 xff0c 被D
  • KL散度公式详解

    目录 文章目录 Jensen 39 s inequality讲解KL散度 xff08 又名relative entropy xff09 mutual information Jensen s inequality f x
  • 元学习算法MAML论文详解

    论文信息 题目 xff1a Model Agnostic Meta Learning for Fast Adaptation of Deep Networks 作者 xff1a Chelsea Finn 伯克利大学 xff0c Pieter