meta—learning调研及MAML概述

2023-05-16

背景

Meta Learning,又称为 learning to learn,Meta Learning希望使得模型获取一种“学会学习”的能力,使其可以在获取已有“知识”的基础上快速学习新的任务,对于新的类别,只需要少量的样本就能快速学习(Few-shot Learning)。

Few-shot Learning 是 Meta Learning 在监督学习领域的应用。

数据集

早期研究都基于以下两个图像数据集:

Omniglot:https://github.com/brendenlake/omniglot

包含1623个不同的火星文字符,每个字符包含20个手写的case

miniImageNet:https://github.com/yaoyao-liu/mini-imagenet-tools

包含100类共60000张彩色图片,其中每类有600个样本

主流算法

MAML(入门+重要)

2017年发表,到2022年7月12日已经收获493的引用 https://arxiv.org/pdf/1703.03400.pdf

MAML与其说是一个深度学习模型,倒不如说是一个框架,提供一个meta-learner(MAML的精髓所在,learing to learn)用于训练base-learner(根据新数据实际用于预测任务的模型)。

绝大多数深度学习模型都可以作为base-learner无缝嵌入MAML中。

(一)目的

MAML的目的是获取一组更好的模型初始化参数(即让模型自己学会初始化)

可以这么理解:假设我们目前有3个tasks,分别为T 1 , T 2 , T 3 。按照以前模型的训练方式,首先,我们随机初始化模型参数θ,然后开始训练任务T 1 ,接着最小化损失函数L 来更新网络的参数,这样我们就会得到新的参数θ 1 。同理,我们可以接着更新其他两个任务。但以前模型的训练方式,是每个任务都是随机初始化θ开始,每个任务都是独立的。如果我们把三个任务初始化的θ到公用的位置,则不需要更多的梯度更新步骤。MAML就是做这件事的。

(二)专有术语介绍:

构建的任务分为训练任务(Train Task),测试任务(Test Task)。

每个任务都有自己的训练集(Support Set)、测试集( Query Set

N-ways,K-shot(数据中包含N个类别,每个类别有K个样本)

(三)训练流程

以训练 miniImage 数据集为例,按4:1划分数据集

Train Task:从训练集(80 个类,每类 600 个样本)中随机采样 5 个类,每个类 1 个样本(5-way 1-shot),构成Support Set,去学习 learner;然后从训练集的样本(采出的5 个类,每类剩下的样本)中抽 15 个样本采样构成Query Set,用来获得 learner 的 loss,去学习 meta leaner。

Test Task:(20 个类,每类 600 个样本)中随机采样5个类,每个类1 个样本(与training阶段一致,5-way 1-shot),构成支撑集 Support Set,去学习 learner;然后从测试集剩余的样本(采出的5 个类,每类剩下的样本)中抽 15 个样本采样构成 Query Set,用来获得 learner 的参数,进而得到预测的类别概率。

(四)实现代码

## 网络构建部分: refer: https://github.com/dragen1860/MAML-TensorFlow
​
#################################################
# 任务描述:5-ways,1-shot图像分类任务,图像统一处理成 84 * 84 * 3 = 21168的尺寸。
# support set:5 * 1
# query set:5 * 15
# 训练取1个batch的任务:batch size:4
# 对训练任务进行训练时,更新5次:K = 5
#################################################
​
print(support_x) # (4, 5, 21168) 
print(query_x) # (4, 75, 21168)
print(support_y) # (4, 5, 5)
print(query_y) # (4, 75, 5)
print(meta_batchsz) # 4
print(K) # 5
​
model = MAML()
model.build(support_x, support_y, query_x, query_y, K, meta_batchsz, mode='train')
​
class MAML:
    def __init__(self):
        pass
    def build(self, support_xb, support_yb, query_xb, query_yb, K, meta_batchsz, mode='train'):
        """
        :param support_xb: [4, 5, 84*84*3] 
        :param support_yb: [4, 5, n-way]
        :param query_xb:  [4, 75, 84*84*3]
        :param query_yb: [4, 75, n-way]
        :param K:  训练任务的网络更新步数
        :param meta_batchsz: 任务数,4
        """
​
        self.weights = self.conv_weights() # 创建或者复用网络参数;训练任务对应的网络复用meta网络的参数
        training = True if mode is 'train' else False      
        def meta_task(input):
            """
            :param support_x:   [setsz, 84*84*3] (5, 21168)
            :param support_y:   [setsz, n-way] (5, 5)
            :param query_x:     [querysz, 84*84*3] (75, 21168)
            :param query_y:     [querysz, n-way] (75, 5)
            :param training:    training or not, for batch_norm
            :return:
            """
​
            support_x, support_y, query_x, query_y = input
            query_preds, query_losses, query_accs = [], [], [] # 子网络更新K次,记录每一次queryset的结果
 
            ## 第0次对网络进行更新
            support_pred = self.forward(support_x, self.weights, training) # 前向计算support set
            support_loss = tf.nn.softmax_cross_entropy_with_logits(logits=support_pred, labels=support_y) # support set loss
            support_acc = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(support_pred, dim=1), axis=1),
                                                         tf.argmax(support_y, axis=1))
            grads = tf.gradients(support_loss, list(self.weights.values())) # 计算support set的梯度
            gvs = dict(zip(self.weights.keys(), grads))
            # 使用support set的梯度计算的梯度更新参数,theta_pi = theta - alpha * grads
            fast_weights = dict(zip(self.weights.keys(), \
                    [self.weights[key] - self.train_lr * gvs[key] for key in self.weights.keys()]))
​
            # 使用梯度更新后的参数对quert set进行前向计算
            query_pred = self.forward(query_x, fast_weights, training)
            query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)
            query_preds.append(query_pred)
            query_losses.append(query_loss)
 
            # 第1到 K-1次对网络进行更新
            for _ in range(1, K):           
                loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.forward(support_x, fast_weights, training),
                                                               labels=support_y)
                grads = tf.gradients(loss, list(fast_weights.values()))
                gvs = dict(zip(fast_weights.keys(), grads))
                fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.train_lr * gvs[key]
                                         for key in fast_weights.keys()]))
                query_pred = self.forward(query_x, fast_weights, training)
                query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)
                # 子网络更新K次,记录每一次queryset的结果
                query_preds.append(query_pred)
                query_losses.append(query_loss)
​
            for i in range(K):
                query_accs.append(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(query_preds[i], dim=1), axis=1),
                                                                tf.argmax(query_y, axis=1)))
            result = [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]
            return result
​
        # return: [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]
        out_dtype = [tf.float32, tf.float32, tf.float32, [tf.float32] * K, [tf.float32] * K, [tf.float32] * K]
        result = tf.map_fn(meta_task, elems=(support_xb, support_yb, query_xb, query_yb),
                           dtype=out_dtype, parallel_iterations=meta_batchsz, name='map_fn')
        support_pred_tasks, support_loss_tasks, support_acc_tasks, \
            query_preds_tasks, query_losses_tasks, query_accs_tasks = result
​
        if mode is 'train':
            self.support_loss = support_loss = tf.reduce_sum(support_loss_tasks) / meta_batchsz
            self.query_losses = query_losses = [tf.reduce_sum(query_losses_tasks[j]) / meta_batchsz
                                                    for j in range(K)]
            self.support_acc = support_acc = tf.reduce_sum(support_acc_tasks) / meta_batchsz
            self.query_accs = query_accs = [tf.reduce_sum(query_accs_tasks[j]) / meta_batchsz
                                                    for j in range(K)]
​
            # 更新meta网络,只使用了第 K步的query loss。这里应该是个超参,更新几步可以调调
            optimizer = tf.train.AdamOptimizer(self.meta_lr, name='meta_optim')
            gvs = optimizer.compute_gradients(self.query_losses[-1])
   # def ********

参考:

1.原论文:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networkshttps://arxiv.org/pdf/1703.03400.pdf

2.小样本学习(Few-shot Learning)综述小样本学习(Few-shot Learning)综述

3.一文入门元学习(Meta-Learning)(附代码)一文入门元学习(Meta-Learning)(附代码) - 知乎

4.Model-Agnostic Meta-Learning (MAML)模型介绍及算法详解Model-Agnostic Meta-Learning (MAML)模型介绍及算法详解 - 知乎

5.从代码上解析Meta-learning从代码上解析Meta-learning_洛克-李的博客-CSDN博客

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

meta—learning调研及MAML概述 的相关文章

随机推荐

  • 尚硅谷react课程-day04

    目录 1 回调形式的ref 2 回调ref中调用次数问题 3 受控组件 4 非受控组件 1 回调形式的ref 1 利用react提供的ref属性名通过回调函数的属性值去调用节点自身 currentnode 61 gt this input1
  • 尚硅谷react课程-day05

    目录 1 高阶函数 2 组件的生命周期 onChange 61 this saveFormData 39 username 39 这个代码的意思是把saveFormData函数调用后的返回值交给onChange回调 xff0c 不是把sav
  • 快捷式~node.js环境搭建

    1 安装包官网下载 xff1a Node js nodejs org 2 安装完成后修改环境变量 在上面已经完成了 node js 的安装 xff0c 即使不进行此步骤的环境变量配置也不影响node js的使用 但是 xff0c 若不进行环
  • 51单片机LCD1602液晶屏显示方法

    以显示hello world 2022 10 17 为例 首先把LCD1602的模块化程序添加到项目目录中 xff0c 模块化方法在51单片机之程序模块化 学习笔记吧的博客 CSDN博客这里可以学习 实验程序 xff1a include l
  • 深度剖析C语言符号篇

    致前行的人 xff1a 人生像攀登一座山 xff0c 而找寻出路 xff0c 却是一种学习的过程 xff0c 我们应当在这过程中 xff0c 学习稳定冷静 xff0c 学习如何从慌乱中找到生机 目录 1 注释符号 xff1a 2 续接符和转
  • HTML5(入门)

    目录 一 HTML5概念和基本的结构 二 基本标签学习 三 图像标签 四 连接标签 五 列表标签 六 表格标签 table 七 媒体标签 八 网页结构 九 内联框架 iframe 十 表单标签 form 十一 初级验证 一 HTML5概念和
  • Arduino驱动oled

    1 模块介绍 I2C显示屏 xff08 驱动为ssd1306 xff0c 分辨率为128 64 xff09 Arduino nano xff08 Atmega168p xff09 2 模块连接 参考开发板管教定义图可知SCL SDA应该连接
  • 4.3.2、分类编址的 IPv4 地址

    分类编址的 IPv4 地址分为 A B C D E 五类 A 类地址的网络号部分占 8 8 8 比特 xff0c 主机号部分占 24 24 24
  • 解决idea2020版本无法使用actiBPM插件问题

    下载 由于在idea自带的插件商店中搜索不到此插件 xff0c 所以我们需要去官网下载 xff1a 地址 xff1a JetBrains Marketplace 点击下载 xff1a 安装 下载完成之后 xff0c 打开idea的设置 xf
  • 【Ubuntu小工具安装】

    span class token number 1 span 安装谷歌中文拼音输入法 span class token number 2 span 双显示器屏幕设置 和独立显卡显示设置 span class token number 3 s
  • 图像的底层特征、高层特征是什么,语义信息是什么意思

    底层特征指的是 xff1a 轮廓 边缘 颜色 纹理和形状特征 颜色特征 是一种全局特征 描述了图像或图像区域所对应的景物的表面性质 纹理特征 也是一种全局特征 它也描述了图像或图像区域所对应景物的表面性质 形状特征 有两类表示方法 一类是轮
  • 配置与管理samba服务器(Linux)

    实验目的 1 了解samba服务器的功能 2 掌握samba服务器的配置管理 3 掌握samba 客户端程序的使用 4 掌握Windows主机和Linux主机共享文件互访的方法 准备工作 1 物理机 xff08 windows客户端 xff
  • IPV4地址详解

    文章目录 IPV4地址分类编址划分子网无分类编制CIDR路由聚合 应用规划 xff08 子网划分的细节 xff09 定长的子网掩码FLSM变长的子网掩码VLSM IPV4地址 IPV4地址就是给因特网 xff08 Internet xff0
  • 字符串拆分函数strtok实现对字符串的拆分

    前言 xff1a 在本章 xff0c 将介绍如何通过strtok函数来分隔字符串 问 xff1a 现有一段字符串 34 chatgpt 64 wenxin baidu 34 如何才能将 64 去掉打印出剩下的部分呢 xff1f 下面将先介绍
  • python语法糖总结

    python语法糖总结 语法糖 是指在编程语言中一些命令的特殊用法 xff0c 以提升编程速度 xff0c 但不一定降低复杂度 xff0c 还可能增加程序的不可读性 xff0c 但在大部分情况下 xff0c 利大于弊 if 语句 span
  • 互联网职场技术分享的必备技能:VNC 远程桌面演示

    VNC 远程桌面控制 职场必备技能点 初衷引子远程桌面软件被需要言归正传VNCVNC服务端SSH 远程访问协议安装图形管理界面继续安装VNC Server VNC 客户端一些小碎语 初衷 不断涌入高科技开发产业圈的新生代 xff0c 助长了
  • SQL 错误 [1055] [42000]: Expression #2 of SELECT list is not in GROUP BY clause and contains nonaggreg

    在使用group by时 xff0c 报错信息如下 xff1a ERROR 1055 42000 Expression 1 of SELECT list is not in GROUP BY clause and contains nona
  • android手机执行shell脚本

    注意 xff1a 1 手机必须root 2 shell脚本需要有执行权限 流程 xff1a 1 编写shell脚本 system bin sh i 61 1 while i le 100 do let i 43 43 sleep 2 inp
  • 毕业设计使用第三方api

    最近要着手毕业设计了 xff0c 本人的毕设是基于android的 xff0c 和公交有关 xff0c 所以想引用第三方的API xff0c 你们觉得可以吗 xff1f
  • meta—learning调研及MAML概述

    背景 Meta Learning xff0c 又称为 learning to learn xff0c Meta Learning希望使得模型获取一种 学会学习 的能力 xff0c 使其可以在获取已有 知识 的基础上快速学习新的任务 xff0