元学习—模型不可知元学习(MAML)

2023-05-16

元学习—模型不可知元学习(MAML)

在之前的文章中,我们介绍了神经图灵机和记忆增强网络(MANN),主要介绍了其对于内存中信息的读取与写入。有兴趣的读者可以参考我之前的博客元学习—神经图灵机。在今天的文章中,我们来介绍一种更加常见的元学习的学习方法,即模型不可知元学习。

1. MAML原理

1.1 MAML引入

MAML是一种最近被提出的,最为主流的一种元学习的方法。其是元学习上的一个重大突破。在元学习中,众所周知,其目标是学会学习。在元学习中,我们从大量的相关学习任务中获取一小部分的样本点,然后通过元学习器来生成一个快速的学习器,再通过少量的样本作用在新的相关的任务之上。

MAML背后的思想是寻找出更好的初始化参数。通过这种更好的初始化参数,模型可以通过少量的梯度下降的步骤来应用到新的任务之上。

下面我们举一个使用神经网络的分类任务作为例子。一般的来讲,我们初始训练过程往往是从一组随机参数开始的,通过最小化loss函数来实现梯度下降的过程,以此对于参数进行调优。即我们通过Loss函数来计算损失,通过梯度下降的方式来寻找新的参数值,新的参数值能够保证Loss变的更小,通过不断的迭代,我们将Loss值降到最小,同时最小的Loss值对应的参数值即为最优值(注意:这个Loss值最小,大多数是局部最小,并非全局最小。)

在MAML中,根据我们上面的描述,我们的目标是希望获取一组相对最优的参数来作为模型的初始化参数,那么应该如何获取这种最优参数呢? 在MAML中,我们使用的是从一些相似数据分布和相似任务上来进行获取。因此,当有一个新的任务开始时,我们不会使用一个随机的参数来进行初始化,我们可以通过将其他相关任务的最优参数进行迁移,作为新任务的初始化参数。这样做的好处有两个,第一个是可以减少梯度下降的步骤,而第二个是可以减少训练过程的数据需求。

这里,我们举一个例子来理解一下MAML计算参数与一般模型计算参数的过程对比,假设我们当前有三个任务,分别使用 T 1 , T 2 , T 3 T_1,T_2,T_3 T1,T2,T3来进行标记。对于一般的模型而言,首先,我们随机的初始化我们的模型参数θ,并利用模型来实现对任务 T 1 T_1 T1进行训练。然后,通过梯度下降的方式来最小化损失函数L。通过这一次的训练过程,我们可以为任务 T 1 T_1 T1寻找到一个相对最优的参数 θ 1 ′ θ_1' θ1。类似的方式,通过随机初始化参数,可以为任务 T 2 , T 3 T_2,T_3 T2,T3寻找相对最优的参数 θ 2 ′ , θ 3 ′ θ_2',θ_3' θ2,θ3。即,我们通过一组随机初始化的参数θ,可以生成三个相对最优的参数 θ 1 ′ , θ 2 ′ , θ 3 ′ θ_1',θ_2',θ_3' θ1,θ2,θ3。即如下图所示:

在这里插入图片描述

进一步,在MAML中。为了在初始化的时候替换掉随机生成的参数,以此来减少梯度下降的步数,缩短训练时间。这里选择其他相关任务训练出来的参数 θ ′ θ' θ来指导初始的参数θ,即如下图所示:
在这里插入图片描述
这里,值得考虑的一个问题是,我们选择的指导参数 θ ′ θ' θ是否能够同时适应三个任务 T 1 , T 2 , T 3 T_1,T_2,T_3 T1,T2,T3?,从这个角度出发,就需要我们考虑的指导参数 θ ′ θ' θ应该是一种共同的,泛化的参数。

进一步,当有新的任务 T 4 T_4 T4的时候,我们可以选择使用优化之后的参数 θ θ θ来进行作为新任务的初始化参数。

最后,我们简单的总结一下MAML的基本思路,即寻找一个优化的参数θ,这个参数对于相关任务是通用的,其能够帮助我们使用更少量的样本进行学习,缩短训练时间。这也意味着我们可以将MAML应用到任意的使用梯度下降的学习方法中。下面,我们来具体探索MAML中原理和细节。

1.2 MAML算法流程

通过之前的描述,我们对于MAML的背景已经有了一定的了解,下面我们来探索MAML中的一些细节问题。假设,我们的模型为 f f f,并且其可以通过参数 θ θ θ来进行描述,即 f θ f_θ fθ。这里,我们在定义一些相关的任务T,T中任务的分布概率为 P ( T ) P(T) P(T)

首先,我们先用随机值对于参数 θ θ θ进行随机的初始化。进一步,我们通过概率分布 P ( T ) P(T) P(T)对于任务集合中的任务进行采用,这里选择5个相关任务,作为一个batch,即表达为 T = { T 1 , T 2 , T 3 , T 4 , T 5 } T=\{T_1,T_2,T_3,T_4,T_5\} T={T1,T2,T3,T4,T5}。然后,对于每一个任务 T i T_i Ti,我们可以采用k个样本点来训练这个模型。至此,根据每一个任务,我们可以计算出来其损失函数 L T i ( f θ ) L_{T_i}(f_θ) LTi(fθ),我们通过梯度下降来最小化这个损失,寻找能够使得的损失函数最小的参数,即:
θ i ′ = θ − α ▽ θ L T i ( f θ ) θ_i'=θ-α▽_θL_{T_i}(f_θ) θi=θαθLTi(fθ)
其中, θ i ′ θ_i' θi表示的是对于任务 T i T_i Ti的最优化参数, θ θ θ表示的是初始化参数,α是一个超参数, L T i ( f θ ) L_{T_i}(f_θ) LTi(fθ)表示的是梯度计算结果。

对于T中5个任务都进行计算之后,我们可以获得各个任务的相对最优的参数集合,即 θ ′ = { θ 1 ′ , θ 2 ′ , θ 3 ′ , θ 4 ′ , θ 5 ′ } θ'=\{θ_1',θ_2',θ_3',θ_4',θ_5'\} θ={θ1,θ2,θ3,θ4,θ5}。在采样下一个batch的任务之前,我们使用一个元更新或者元优化的策略。在之前的一步中,我们通过梯度下降计算出了相对最优的参数 θ i ′ θ_i' θi,并且通过任务 T i T_i Ti中的参数对应的梯度,来更新了我们初始化的随机参数θ,这使得我们初始随机的参数θ,移动到了一个相对最优的位置。在一个批次的任务的训练中,减少了梯度下降的步数,这一步被称为“元步”,“元更新”,“元优化”或者“元训练”。通过公式,可以将其描述为:
θ = θ − β ▽ θ ∑ T i − p ( T ) L T i ( f θ i ′ ) θ=θ-β▽_θ∑_{T_i-p(T)}L_{T_i}(f_{θ_i'}) θ=θβθTip(T)LTi(fθi)
在上述的公式中,θ表示的是初始化的参数,β表示的是一个超参数。 L T i ( f θ i ′ ) L_{T_i}(f_{θ_i'}) LTi(fθi)表示的是通过参数 θ i ′ θ_i' θi所计算出来的关于任务 T i T_i Ti的梯度结果。这里,我们可以进一步的使用对于各个任务的相对最优参数 θ i ′ θ_i' θi对于的梯度和的平均值来进行计算。

最后,我们对于MAML算法的流程进行一下简单的总结。MAML算法一共可以分成两个循环,其中一个内部循环被用来确定当前任务集合中的各个任务对应的最优参数 θ i ′ θ_i' θi。外层的循环用于通过内层计算出来的最优参数对应的梯度来更新我们的初始的随机参数θ。我们使用一张图来描述一下这个过程:

在这里插入图片描述

2 MAML模型的应用

2.1 监督学习中的MAML模型

MAML模型善于去寻找最优的模型初始化参数。进一步,我们来描述一下其在监督学习过程中的使用过程。首先,我们先给出监督学习的损失函数的定义形式:

如果是监督学习中的回归学习,我们可以采用均方误差的形式来定义其损失函数:
L T i ( f θ ) = ∑ x j , y j − T i ∣ ∣ f θ ( x i ) − y i ∣ ∣ 2 2 L_{T_i}(f_θ)=∑_{x_j,y_j-T_i}||f_θ(x_i)-y_i||_2^2 LTi(fθ)=xj,yjTifθ(xi)yi22
如果是监督学习中的分类任务,我们使用交叉熵的损失函数:
L T i ( f θ ) = ∑ x j , y j − T i y j l o g f θ ( x j ) + ( 1 − y j ) l o g ( 1 − f θ ( x j ) ) L_{T_i}(f_θ)=∑_{x_j,y_j-T_i}y_jlogf_θ(x_j)+(1-y_j)log(1-f_θ(x_j)) LTi(fθ)=xj,yjTiyjlogfθ(xj)+(1yj)log(1fθ(xj))

下面,我们来逐步的介绍MAML的使用过程

  1. 假设我们当前拥有一个模型f,可以通过参数θ来进行描述。并且我们有一个分布为 p ( T ) p(T) p(T)的相关任务集合。首先,我们来随机初始化参数θ。
  2. 我们对任务集合中的任务进行采样,假设我们当前采样的任务集合为 T = { T 1 , T 2 , T 3 } T=\{T_1,T_2,T_3\} T={T1,T2,T3}
  3. 内层循环:对于当前任务集合T中的每一个任务 T i T_i Ti,我们采样K个样本点来生成当前任务的训练集和测试集
    D i t r a i n = { ( x 1 , y 1 ) , ( x 2 , y 2 ) , . . . , ( x k , y k ) } D_i^{train}=\{(x_1,y_1),(x_2,y_2),...,(x_k,y_k)\} Ditrain={(x1,y1),(x2,y2),...,(xk,yk)}
    D i t e s t = { ( x 1 , y 1 ) , ( x 2 , y 2 ) , . . . . , ( x k , y k ) } D_i^{test}=\{(x_1,y_1),(x_2,y_2),....,(x_k,y_k)\} Ditest={(x1,y1),(x2,y2),....,(xk,yk)}
    这里值得注意的是,我们的这里训练集的样本和测试集的样本是相同的,训练数据集的样本是在内层循环中为具体任务寻找最优参数θi的时候用的。而测试集是在外层循环中,寻找最优的参数θ时被用到。这里的测试集的目的不是来检查模型的表现。其基础的作用是作为外层循环的训练集。我们也可以将我们的测试集称为元训练集
    至此,我们使用监督学习算法作用在 D i t r a i n D_i^{train} Ditrain上面,计算出损失,并使用梯度下降算法来减小损失,获取相对最优参数 θ i ′ θ_i' θi,即: θ i ′ = θ − α ▽ θ L T i ( f θ ) θ_i'=θ-α▽_θL_{T_i}(f_θ) θi=θαθLTi(fθ)。对于任务集合中的每一个任务,我们都采样K个样本点来在其训练集上进行最小化损失,获取最优参数的操作。最后,我们可以获取一组最优参数: { θ 1 ′ , θ 2 ′ , θ 3 ′ } \{θ_1',θ_2',θ_3'\} {θ1,θ2,θ3}
  4. 外层循环: 这里我们使用之前定义的测试集来进行元优化。这里,我们使用测试集 D i t e s t D_i^{test} Ditest来最小化损失。通过我们之前计算出来的最优参数 { θ 1 ′ , θ 2 ′ , θ 3 ′ } \{θ_1',θ_2',θ_3'\} {θ1,θ2,θ3}对应的梯度结果,我们来最小化外层循环的损失,更新之前的随机参数,即 θ = θ − β ▽ θ ∑ T i − p ( T ) L T i ( f θ i ′ ) θ=θ-β▽_θ∑_{T_i-p(T)}L_{T_i}(f_{θ_i'}) θ=θβθTip(T)LTi(fθi)
  5. 我们重复第2步到第5步来进行迭代,以此来获取最优的参数θ’。

最后,我们使用一个图来总结一下上述的流程:

·

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

元学习—模型不可知元学习(MAML) 的相关文章

  • FreeRTOS任务调度最后篇

    FreeRTOS开启任务调度 一篇说到启动任务调度最后启动Systick定时器 xff0c 通过SVC中断引导第一个任务执行 然后系统就在Systick的定时中断下调度任务执行 xff0c 这次介绍最后的部分 xff0c Systick和P
  • 从STM32-FreeRTOS到linux

    之前买的STM32的开发板学习裸机开发 了解裸机之后学习FreeRTOS来作为小型操作系统学习 xff0c 理解操作系统调度实现 一直想学习一下linux的内核 xff0c 之前下载源码和初步看了下感觉无从下手 有了RTOS的基础后 xff
  • C#实现图片旋转

    C 绘图正常是不涉及到旋转的 有时候会有旋转画笔的情况 比如条码打印字竖着打印 旋转图片一定角度绘制 或者斜着画水印 这时候就涉及到旋转画笔了 源码地址 通过graphics TranslateTransform Pcenter X Pce
  • C#调C++库返回字符串

    用C 调C 43 43 库函数返回字符串 xff0c 由于C 43 43 本身方法之间调用返回字符串都是一般都是申明void或int返回的方法 xff0c 然后通过char变量带出返回值 在C 43 43 调用这种之前自己先初始化char空
  • Asp.NetCore在CentOS网站卡死

    最近碰到项目的网站在高峰期卡死的现象 刚开始以为是数据库问题导致的卡死 xff0c 就排查和改了数据的设置 然后观察几天发现网站还是会在高峰期卡死 xff0c 然后改了点网站设置 xff0c 准备第二天观察一下 xff0c 星期二竟然又没出
  • 使用IRIS碰到的坑

    最近换新电脑了 xff0c 然后直接不安装cache2016了 xff0c 直接上IRIS啊 然后碰到几个坑 xff0c 一是在win11不知道是兼容性不好还是怎么了 每次重启电脑后数据库就无法启动 xff0c 为此祭出多年保存的方子 xf
  • K8s 配置高可用提示Configuration file ‘/etc/keepalived/keepalived.conf‘ is not a regular non-executable file

    k8s配置keepalived高可用 xff0c systemctl start keepalived提示 检查keepalived配置文件 xff0c 查询配置也正常 从报错提示显示keepalived conf 配置文件是一个非执行的文
  • Std数据M的荣光

    对检验的上线 xff0c 实施和开发的大部分时间都用在做基础数据和联设备对通道这些 对相同的仪器每次都有做项目数据 xff0c 对通道那些我一直深有感触 xff0c 一直在构思怎么减少仪器对通道这些做数据的工作量 奈何以前只是浅显的使用M
  • matlab从图表中提取数据

    有如下的波形图 xff0c 如何从中精确提取出全部的数据 1 将波形图片 截图 保存为test png或test jpg xff0c 并将图片放于matlab工作目录中 xff0c 如下图示例所指定的目录中 xff1a 2 xff0c 新建
  • STM32 基础系列教程 1- CubeMX+GPIO

    前言 学习stm32 GPIO 的使用 xff0c 设置某一GPIO引脚为输出功能 xff0c 将对应引脚拉高或拉低输出 xff0c 同时学会初步认识STM32最新的HAL库的使用 xff0c 用代码实现控制GPIO引脚输出产生周期出1s
  • STM32 基础系列教程 29 - FreeRTOS

    前言 学习stm32 中 FreeRTOS嵌入式实时操作系统的使用 xff0c 学会在FreeRTOS时行任务创建与任务运动 xff0c 学习在嵌入式实时操作系统下编程 xff0c 用串口打印相应信息 xff0c 并控制LED闪烁 示例详解
  • 对本地的代码进行修改后,直接git pull会提示本地代码和github代码冲突,需要先commit本地代码,或者stash他们

    对本地的代码进行修改后 xff0c 直接git pull会提示本地代码和github代码冲突 xff0c 需要先commit本地代码 xff0c 或者stash他们 对本地的代码进行修改后 xff0c 直接git pull会提示本地代码和g
  • linux查询内存、CPU、硬盘等系统信息的命令

    一 linux CPU大小 root 64 idc cat proc cpuinfo grep 34 model name 34 amp amp cat proc cpuinfo grep 34 physical id 34 model n
  • ubuntu无法更新的问题,提示错误Err http://mirrors.163.com trusty Release.gpg Could not resolve 'mirrors.163.com

    最近在安装使用ubuntu xff0c 并且配置源文件下载相应gcc xff0c gdb时候 xff0c 出现错误 xff0c 提示报错内容为 Err http mirrors 163 com trusty Release gpg Coul
  • 在 GitHub 下载某个程序的特定版本代码

    情况 github中某个项目已经更新到2 1 0版本 但是想要它的1 0 1版本怎么办 方法一 xff1a 首先点击这个repository下的这个branch按钮 点开了以后你会看到这个 xff0c 然后点tags 选择你想要下载的版本
  • Pixhawk之姿态控制

    原文地址 xff1a http blog csdn net qq 21842557 1 写在前面 无人机控制部分主要分为两个部分 xff0c 姿态控制部分和位置控制部分 xff1b 位置控制可用远程遥控控制 xff0c 而姿态控制一般由无人
  • Android注册表文件

    data system packages plist com google android ears 10043 0 data data com google android ears default 3003 1028 1015 com
  • Java 爬虫系列丨(一)爬虫介绍

    1 简介 1 1 背景 随着互联网的迅速发展 xff0c 网络资源越来越丰富 xff0c 信息需求者如何从网络中抽取信息变得至关重要 目前 xff0c 有效的获取网络数据资源的重要方式 xff0c 便是网络爬虫技术 简单的理解 xff0c
  • 基于龙伯格观测器的永磁同步电机仿真与实现

    摘 要 xff1a 在永磁同步电动机控制系统中 xff0c 使用转子位置传感器不仅会增加设计和制造的成本 xff0c 还会使系统的可靠性降低 因此 xff0c 无位置传感器技术已成为永磁同步电机控制领域的研究热点之一 本文对龙伯格观测器技术
  • 拷贝cp大文件报错“文件太大”

    问题 xff1a 今天在centos7系统下 xff0c u盘位vfat格式16个G xff0c 拷贝7个G大小的问文件 xff0c 无论是用dd还是cp都在拷贝到4 3G大小的时候显示失败 故写下这篇博客 无论什么系统 xff0c 只要分

随机推荐

  • CMakeList.txt

    一 Cmake 简介 cmake 是一个跨平台 开源的构建系统 它是一个集软件构建 测试 打包于一身的软件 它使用与平台和编译器独立的配置文件来对软件编译过程进行控制 二 常用命令 1 指定 cmake 的最小版本 cmake minimu
  • 安装centos7 卡在 “正在安装引导装载程序”界面

    今天系统突然起不来 xff0c 不知道什么原因删掉了一些文件 修复太浪费时间 xff0c 还是重新装一个系统 xff08 原来的分区有很多个人资料 xff0c 所以一定不能格调 xff0c 在无用的分区上装新的系统 所以你装系统的时候尽量不
  • insmod: ERROR: could not insert module: Invalid module format

    root 64 zn pc home zn sedriver 5000 new sedriver 5000 span class token comment insmod wst se echip drv ko span insmod ER
  • LoongArch上正常使用`pip install`

    原创 xff1a 你在使用loongarch架构操作系统时 xff0c 是否遇到pip install 安装失败的情况 xff1f 刷到这篇文章 xff0c 大家可添加评论或者私信我 xff0c 及时满足大家的需求 那么 xff0c 下面讲
  • python SOABI兼容性问题

    首先说明一点 xff1a 龙芯发布的仓库都是基于configure ac 中包含loongarch64 linux gnu定义的python所构建 https blog csdn net zhangna20151015 article de
  • python中为什么加上中文注释就会报错

    由于Python源代码也是一个文本文件 xff0c 所以 xff0c 当你的源代码中包含中文的时候 xff0c 在保存源代码时 xff0c 就需要务必指定保存为UTF 8编码 当Python解释器读取源代码时 xff0c 为了让它按UTF
  • 关于在linux操作系统下打不出汉字或者在敲打汉字时无法显示拼音的问题

    在linux下出现问题不比在window下形象 在window下 你发现哪个软件有问题了 xff0c 点击几下鼠标就完事了 xff1b 要是在linux系统下 xff0c 不懂代码 xff0c 可修复不了 打不出汉字 xff0c 在这我就说
  • 解析/etc/hosts文件

    1 xff0c etc hosts xff0c 主机名和ip配置文件 hosts The static table lookup for host name 主机名查询静态表 linux 的 etc hosts是配置ip地址和其对应主机名的
  • c++语法大全

    c 43 43 语法大全 一 变量和简单数据类型 1 变量名只能包含字母 数字和下划线 可以以字母和下划线开头 xff0c 但是不能从数字开头 xff1b 变量名不能包含空格 2 数据类型 字符串 字符串可以用双引号或者单引号括起来 xff
  • libxml2的安装及使用

    本文着重介绍解析xml的libxml2库的安装及使用 xff0c 举例说明创建和解析xml的过程 是针对C语言开发人员使用 你若想详细学习前端的一套东西 xff0c 即xml html css javascript JS 等 xff0c 可
  • dd 与cp的区别

    dd命令和cp命令的区别 cp与dd的区别在于cp可能是以字节方式读取文件 xff0c 而dd是以扇区方式记取 显然dd方式效率要高些 dd最大的用处是他可以进行格式转换和格式化 dd是对块进行操作的 xff0c cp是对文件操作的 比如有
  • 畸变校正与极线校正(具体原理+Matlab代码)

    附 xff1a 相关需要的工具函数源代码 xff08 投影函数 校正矩阵计算等 xff09 见最下面 1 畸变校正 1 1 形成原因 图像畸变一般有两种 xff0c 第一种是透镜本身的形状有问题 xff0c 使得图像发生径向畸变 xff1b
  • 无人驾驶项目——交通标志识别

    在无人驾驶项目中 xff0c 实现交通标志识别是一项重要工作 本文以德国交通标志数据集为训练对象 xff0c 采用深度神经网络LeNet架构处理图像 xff0c 实现交通标志识别 具体处理过程包括包括 xff1a 数据导入 探索和可视化数据
  • 使用机器人操作系统ROS 2和仿真软件Gazebo 9主题进阶实战(七)- mobot速度发布与里程计订阅

    在ROS2课程中已经学过并掌握了一个基本的发布器和订阅器 xff08 C 43 43 xff09 xff0c 官网的教程全部掌握大致需要20分钟吧 这过程包括 xff1a 创建一个功能包编程实现一个发布节点编程实现一个订阅节点编译与运行 这
  • ROS + Caffe 机器人操作系统框架和深度学习框架笔记 (機器人控制與人工智能)

    ROS 43 Caffe xff0c 这里以环境中物体识别为示例 xff0c 机器人怎么知道环境里面有什么呢 xff1f 0 0567392 n03376595 folding chair 0 0566773 n04099969 rocki
  • Ubuntu 16.04 使用docker资料汇总与应用docker安装caffe并使用Classifier(ros kinetic+usb_cam+caffe)

    Docker是开源的应用容器引擎 若想简单了解一下 xff0c 可以参考百度百科词条Docker 好像只支持64位系统 Docker官网 xff1a https www docker com Docker 从入门到实践 xff1a http
  • Ubuntu与ROS的Docker桌面系统与ROS在线练习课程(在线Linux虚拟机)

    ROS在线练习课程正在逐步完善中 xff0c 目前以ROS官网中文资料制作 xff0c 可参考 xff1a https www shiyanlou com courses 854 邀请码 U23ERF8H 安装Ubuntu 43 ROS对于
  • 用于ARM和Debian的ROS Docker镜像

    这里推荐两个链接 xff1a 1 Using ROS with Docker in macOS xff1a https www xiaokeyang com blog using ros with docker in macos 2 Get
  • 2021电赛F题之openmv巡线(附代码)

    效果展示 xff1a 出错解决方法 openmv数字识别源代码 gitee 通过使用不同阈值的方法可以得到当前区域中什么区域有红线 xff0c 对于电控而言作用类似于红外对管 xff0c 之后电控通过逻辑判断如何运动 xff0c 这就是我们
  • 元学习—模型不可知元学习(MAML)

    元学习 模型不可知元学习 MAML 在之前的文章中 xff0c 我们介绍了神经图灵机和记忆增强网络 MANN xff0c 主要介绍了其对于内存中信息的读取与写入 有兴趣的读者可以参考我之前的博客元学习 神经图灵机 在今天的文章中 xff0c