一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述(DL笔记整理系列)

3043331995@qq.com

https://fanfansann.blog.csdn.net/

https://github.com/fanfansann/fanfan-deep-learning-note

作者:凡凡

version 1.0  2022-1-20

陈述:

1)《繁凡的深度学习笔记》是我自学完成深度学习相关的教材、课程、论文、项目实战等内容之后,自我总结整理创作的学习笔记。写文章就图一乐,大家能看得开心,能学到些许知识,对我而言就已经足够了 ^q^ 。

2)因个人时间、能力和水平有限,本文并非由我个人完全原创,文章部分内容整理自互联网上的各种资源,引用内容标注在每章末的参考资料之中。

3)本文仅供学术交流,非商用。所以每一部分具体的参考资料并没有详细对应。如果某部分不小心侵犯了大家的利益,还望海涵,并联系博主删除,非常感谢各位为知识传播做出的贡献!

4)本人才疏学浅,整理总结的时候难免出错,还望各位前辈不吝指正,谢谢。

5)本文由我个人( CSDN 博主 「繁凡さん」(博客) , 知乎答主 「繁凡」(专栏), Github 「fanfansann」(全部源码) , 微信公众号 「繁凡的小岛来信」(文章 P D F 下载))整理创作而成,且仅发布于这四个平台,仅做交流学习使用,无任何商业用途。

6)「我希望能够创作出一本清晰易懂、可爱有趣、内容详实的深度学习笔记,而不仅仅只是知识的简单堆砌。」

7)本文《繁凡的深度学习笔记》全汇总链接:《繁凡的深度学习笔记》前言、目录大纲 https://fanfansann.blog.csdn.net/article/details/121702108

8)本文的Github 地址:https://github.com/fanfansann/fanfan-deep-learning-note/孩子的第一个『Github』 !给我个%5Cboxed%7B%E2%AD%90%20%5C%2C%5C%2C%5C%2C%5Ctext%7BStarred%7D%7D嘛!谢谢!!o(〃^▽^〃)o

9)此属 version 1.0 ,若有错误,还需继续修正与增删,还望大家多多指点。本文会随着我的深入学习不断地进行完善更新,Github 中的 P D F 版也会尽量每月进行一次更新,所以建议点赞收藏分享加关注,以便经常过来回看!

如果觉得还不错或者能对你有一点点帮助的话,请帮我给本篇文章点个赞,你的支持是我创作的最大动力!^0^

受篇幅所限(CSDN有字数限制),本文《元学习详解》分为上下两篇,这里是上篇。

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (下)万字中文综述(待更)

15.5 基于度量的方法
15.5.1 Convolutional Siamese Neural Network
15.5.2 Matching Networks
15.5.2.1 Simple Embedding
15.5.2.2 Full Context Embeddings
15.5.3 Relation Network
15.5.4 Prototypical Networks
15.6 基于模型的方法
15.6.1 Memory-Augmented Neural Networks
15.6.1.1 MANN for Meta-Learning
15.6.1.2 Addressing Mechanism for Meta-Learning
15.6.2 Meta Networks
15.6.2.1 Fast Weights
15.6.2.2 Model Components
15.6.2.3 训练过程
15.7 元学习应用
15.7.1 计算机视觉和图形
15.7.2 元强化学习和机器人技术
15.7.3 环境学习与模拟现实
15.7.4 神经架构搜索(NAS)
15.7.5 贝叶斯元学习
15.7.6 无监督元学习和元学习无监督学习
15.7.7 主动学习
15.7.8 持续、在线和适应性学习
15.7.9 领域适应和领域概括
15.7.10 超参数优化
15.7.11 新颖且生物学上可信的学习者
15.7.12 语言和言语
15.7.13 元学习促进社会福利
15.7.14 抽象和合成推理
15.7.15 系统
15.8 未来展望
15.9 参考资料

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

元学习的诞生可以追溯到八十年代,当时的深度学习还没有如今这般火热, Jürgen Schmidhuber 在 1987 年的论文《 Evolutionary principles in self-referential learning》 [1] 一文中宣告了一种全新的机器学习方法的诞生:元学习。后来, Tom Schaul、 Jürgen Schmidhuber 两人在 2010 年的论文《Meta learning》 [2] 中更是确定了元学习的复兴。时间进入 2012 年,随着 Hinton 深度学习崭露头角,元学习与强化学习更是借着深度学习的大潮,在各个领域扩展到了极致(例如人脸识别领域等,均可用元学习来加以强化 cross domain 的性能)。

15.1 元学习 (Meta Learing):Learning To Learn

在现有的机器学习中,我们经常使用某个场景中的大量数据来训练模型,而训练出来的模型只适合这个场景。当场景发生变化时,我们需要重新设计模型并重新训练参数。但是对于人类来说,孩子在成长的过程中会看到很多物体的图片。有一天,孩子认识了这只鸟后,当孩子第一次看到几张狗钩的照片时,他可以很好地对待狗。蜱和鸟来区分。一个学会骑自行车的人可以很快甚至自学如何骑电动自行车……那么是否有可能让机器学习模型具有类似的属性?模型如何仅用少量数据学习新概念和技能?换句话说,你如何让模型学会如何自己学习?这是元学习旨在解决的问题。

元学习(Meta Learning),含义为学会学习,即 learn to learn ,带着学会人类的“学习能力”的期望而诞生的新型机器学习方法。Meta Learning 希望使得模型获取“学会学习”的能力,也可以理解为学会自己调参的能力。在接触到没见过的任务或者迁移到新环境中时,可以根据之前学习到的经验知识和少量的新样本的已有“知识”的基础上快速学习如何应对。元学习能解决的任务可以是任意一类定义好的机器学习任务,如监督学习,强化学习等。例如:

  • 让 Alphago 迅速学会下象棋;
  • 让一个猫图片分类器在只看到几张狗钩的图片后,快速具备分类猫狗图片的能力;

我们期望好的元学习模型能够具备强大的适应能力和泛化能力。在模型进行测试之前,模型会先经过一个自适应环节(Adaptation Process),即根据少量样本学习任务。经过自适应后,模型即可完成新的任务。自适应本质上来说就是一个短暂的学习过程,这就是为什么元学习也被称作“学会”学习。

需要注意的是,虽然元学习同样有“预训练”的思想,但是元学习的内核会有别于迁移学习(Transfer Learning),我们将在下文进行详细探讨。

元学习目前有三种常见的实现方式:

  1. 以快速学习为目标的训练模型(基于优化的方法);
  2. 学习有效的距离测量(基于度量的方法);
  3. 使用具有显式或隐式内存存储的(循环)神经网络(基于模型的方法)。

为了让大家更容易理解,我们尝试通过比较机器学习和元学习这两个概念中已经熟悉的元素来加深理解:

方法目的输入函数输出训练流程
Machine Learing通过训练数据,学习输入 x x x 与输出 y y y 之间的映射,找到函数 f f f x x x f f f y y y1. 初始化 f f f​ 参数
2. 输入训练数据 < x , y > <x,y> <x,y>
3. 计算损失函数,优化 f f f​ 参数
4. 最终得到 y = f ( x ) y=f(x) y=f(x)
Meta Learing通过大量训练任务 T T T​ 以及每个训练任务对应的训练数据 D D D​,找到函数 F F F​, F F F​ 可以输出一个可用于新任务的函数 f f f大量训练任务 T T T 及其对应的训练数据 D D D F F F f f f1. 初始化 F F F 参数
2. 输入大量训练任务 T T T 及其对应的训练数据 D D D
3. 得到 f = F ∗ f=F^* f=F
4. 在新任务中 y = f ( x ) y=f(x) y=f(x)

我们知道,在机器学习中,训练单元就是一段数据,通过数据优化模型的参数。数据分为训练集、测试集和验证集。在元学习中,训练单元是分层的,第一层训练单元是任务,即训练元学习需要准备很多不同的学习任务,第二层训练单元是每个任务对应的数据。两者的目的都是为了找到一个%5Ctext%7BFunction%7D,但是两个%5Ctext%7BFunction%7D的作用和目的是不同的。 %5Ctext%7BFunction%7D在机器学习中直接作用于数据的特征和标签,找出特征和标签之间的相关性。在元学习中,%5Ctext%7BFunction%7D用于寻找适合新任务的%5Ctext%7BFunction%7D%3Af,而f只作用于特定任务本身,即训练模型自行学习和解决问题。

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
图 15.1 火热的元学习

15.2 元学习名词解释

典型的元学习技术包括以下概念:

  • Task:元学习通常将训练数据切分成一个个小的数据子集来训练 meta-learner。“task”的意思与多任务学习的“task”不同,是指元学习训练所使用的数据子集。
  • Support set & query set:每个 task 分成 support set 和 query set 两个子集。Support set 对应于算法中的内部更新,query set 对应于算法中的外部更新。
  • Way:是 class(类别)的别称。
  • Shot:指的是每个类别的样本数量。例如:1-shot 指的并不是一共只有一个数据样本,而是每个类有 1 个样本。

图15.2 中展示了一个典型的 K-shot 的元学习方法的一般套路,其训练阶段的数据和测试阶段的数据包含不同的类别,而训练的每个 task 又被切分成 support set 和 query set 。并且在测试的时候,元学习同样是在 task 上面测,每个 task 测出的准确率,汇总求和后求整体的均值。

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
图 15.2 一个典型的元学习将数据切分成 task 训练,而每个 task 包含的 5 个分类不同,1-shot 是指每个分类只有一个样本。

15.3 元学习问题定义

15.3.1 元学习形式化

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C元学习本身很难定义,它已经以各种不一致的方式被广泛使用,甚至在当前的神经网络文献中也是如此。因此这里引用最新元学习综述Meta-Learning in Neural Networks: A Survey[19] 的定义,介绍一种特定的定义和关键术语,旨在帮助理解大量文献。

%E2%9D%91 传统机器学习

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C我们知道,在传统的监督机器学习中,有训练数据集%5Cmathcal%20D%20%3D%20%5C%7B%28x_1%2C%20y_1%29%EF%BC%8C%E2%80%A6%EF%BC%8C%28x_N%2C%20y_N%29%5C%7D如(输入图像,输出标签)对。我们可以通过解决以下问题来训练预测模型%5Chat%7By%7D%3Df_%7B%5Ctheta%20%7D%5Cleft%20%28%20x%20%5Cright%20%29参数化%CE%B8
%5Ctheta%5E%7B%2A%7D%3D%5Carg%20%5Cmin%20_%7B%5Ctheta%7D%20%5Cmathcal%7BL%7D%28%5Cmathcal%7BD%7D%20%3B%20%5Ctheta%2C%20%5Comega%29
其中%5Cmathcal%20L是用于衡量%5Cmathcal%20L的真实标签和预测标签之间误差的损失函数。泛化能力是通过使用已知标签评估多个测试点来衡量的。传统的机器学习假设是针对每个问题%5Cmathcal%20D从头开始执行这种优化;和%5Comega是预先指定的。但是,%5Comega 的规范会极大地影响性能指标,例如准确性或数据效率。元学习试图通过学习算法本身来改进这些指标,而不是假设它是预先指定和固定的。这通常是通过重新审视上面的第一个假设并从任务分布中学习而不是从头开始来实现的。

%E2%9D%91​元学习:任务分配视图我们之前说过,元学习旨在通过学习“如何学习”来提高性能。具体来说,元学习的总体思路是学习一种通用的跨任务泛化的学习算法,理想情况下让每一个新任务都比上一个学习得更好。这里我们简单地将任务%5Cmathcal%20T​定义为数据集%5Cmathcal%20D​和损失函数%5Cmathcal%20L​:%5Cmathcal%20T%20%3D%20%5C%7B%5Cmathcal%20D%2C%5Cmathcal%20L%5C%7D​,数据集包含特征向量x​和标签y​,任务分布表示为p%28%5Cmathcal%7BT%7D%29​。那么学习如何学习的学习目标可以表示为:
%5Cmin%20_%7B%5Comega%7D%20%5Cunderset%7B%5Cmathcal%7BT%7D%20%5Csim%20p%28%5Cmathcal%7BT%7D%29%7D%7B%5Cmathbb%7BE%7D%7D%20%5Cmathcal%7BL%7D%28%5Cmathcal%7BD%7D%20%3B%20%5Comega%29
其中%5Cmathcal%20L%28%5Cmathcal%20D%3B%5Comega%29为用于测量在数据集%5Cmathcal%20D上使用%5Comega训练的模型的性能的损失函数。“如何学习”的知识(参数)%CF%89通常被称作跨任务知识 (across-task knowledge) 或元知识 (meta-knowledge)。 简而言之,我们希望能够学到一个通用的参数 meta-knowledge:%5Comega,使得不同的 task 的%5Cmathcal%20L损失函数都越小越好。

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C为了在实践中解决这个问题,我们通常假设访问一组从p%28%5Cmathcal%7BT%7D%29中提取的源任务,并通过这些源任务学习%CF%89

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C类似于传统的机器学习,meta learning 同样分为 Meta-train 和 Meta-test 两个阶段。形式上,我们将 Meta-train 阶段使用的使用的M个源任务 (source tasks) 集表示为D%5E%7B%5Ctext%7Bsource%7D%7D%3D%5Cleft%5C%7B%5Cleft%28%5Cmathcal%7BD%7D_%7B%5Ctext%20%7Bsource%20%7D%7D%5E%7B%5Ctext%20%7Btrain%20%7D%7D%2C%20%5Cmathcal%7BD%7D_%7B%5Ctext%20%7Bsource%20%7D%7D%5E%7B%5Ctext%20%7Bval%20%7D%7D%5Cright%29%5E%7B%28i%29%7D%5Cright%5C%7D_%7Bi%3D1%7D%5E%7BM%7D,其中每个任务都有训练和验证数据。通常,源序列和验证数据集分别称为支持集 (support set) 和查询集 (query set) (或者预测集 (prediction set))。我们通过采样大量的源任务来学习 meta knowledge (也即最大化它的最大似然估计):
%5Comega%5E%7B%2A%7D%3D%5Carg%20%5Cmax%20_%7B%5Comega%7D%20%5Clog%20p%5Cleft%28%5Comega%20%5Cmid%20%5Cmathscr%7BD%7D_%7B%5Ctext%20%7Bsource%20%7D%7D%5Cright%29
%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C然后我们将元测试阶段使用的Q目标任务 (target tasks) 集合表示为%5Cmathcal%20D%5E%7B%5Ctext%7Btarget%7D%7D%3D%5Cleft%5C%7B%5Cleft%28%5Cmathcal%7BD%7D_%7B%5Ctext%20%7Btarget%20%7D%7D%5E%7B%5Ctext%20%7Btrain%20%7D%7D%2C%20%5Cmathcal%7BD%7D_%7B%5Ctext%20%7Btarget%20%7D%7D%5E%7B%5Ctext%20%7Btest%20%7D%7D%5Cright%29%5E%7B%28i%29%7D%5Cright%5C%7D_%7Bi%3D1%7D%5E%7BQ%7D,其中每个任务都有训练和测试数据。在 Meta-Testing 阶段,我们使用学习到的元知识对每个之前未见过的目标任务%5Ctext%7Btask%7D_i的基础模型进行训练:
%5Ctheta%5E%7B%2A%28i%29%7D%3D%5Carg%20%5Cmax%20_%7B%5Ctheta%7D%20%5Clog%20p%5Cleft%28%5Ctheta%20%5Cmid%20%5Comega%5E%7B%2A%7D%2C%20%5Cmathcal%7BD%7D_%7B%5Ctext%20%7Btarget%20%7D%7D%5E%7B%5Ctext%20%7Btrain%20%7D%5E%7B%28i%29%7D%7D%5Cright%29
%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C即 Meta-Testing 的目标就是基于已经学到的 Meta-Knowledge (%5Comega) 来寻找当前任务%5Ctext%7Btask%7D_i的最优参数%5Ctheta%5E%2A%28i%29。此时可以通过%5Ctheta%5E%2A%7B%28i%29%7D在每个目标任务%5Cmathcal%7BD%7D_%7B%5Ctext%20%7Btarget%20%7D%7D%5E%7B%5Ctext%20%7Btest%20%7D%28i%29%7D的测试拆分上的表现来评估我们的元学习器 (meta-learner) 的准确性。

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C为了训练 Meta-Knowledge ,meta learning 提出了两种常用的做法:双层优化观(Bilevel Optimization View) 和前馈模型观(Feed-Forward Model View)。

%E2%9D%91元学习:双层优化观(Bilevel Optimization View)

双层优化观(Bilevel Optimization View)是指一个层次优化问题,其中一个优化包含另一个优化作为约束[25],[43]。我们知道正常训练模型的流程是先使用 train set 训练,然后再使用 test/val set 测试。因此,在 Meta-training 的过程中,我们可以构造两层优化过程,在 inner-loop 即内层中使用 train set 更新任务模型参数,然后在 outer-loop 基于更新后的模型对 meta-knowledge 进行优化。双层优化的元训练可以表示为:
%5Cbegin%7Baligned%7D%5Comega%5E%7B%2A%7D%20%26%3D%5Cunderset%7B%5Comega%7D%7B%5Carg%20%5Cmin%20%7D%20%5Csum_%7Bi%3D1%7D%5E%7BM%7D%20%5Cmathcal%7BL%7D%5E%7B%5Ctext%20%7Bmeta%20%7D%7D%5Cleft%28%5Ctheta%5E%7B%2A%28i%29%7D%28%5Comega%29%2C%20%5Comega%2C%20%5Cmathcal%7BD%7D_%7B%5Ctext%20%7Bsource%20%7D%7D%5E%7B%5Ctext%20%7Bval%20%7D%28i%29%7D%5Cright%29%20%5C%5C%5Ctext%20%7B%20s.t.%20%7D%20%5Ctheta%5E%7B%2A%28i%29%7D%28%5Comega%29%20%26%3D%5Cunderset%7B%5Ctheta%7D%7B%5Carg%20%5Cmin%20%7D%20%5Cmathcal%7BL%7D%5E%7B%5Ctext%20%7Btask%20%7D%7D%5Cleft%28%5Ctheta%2C%20%5Comega%2C%20%5Cmathcal%7BD%7D_%7B%5Ctext%20%7Bsource%20%7D%7D%5E%7B%5Ctext%20%7Btrain%20%7D%28i%29%7D%5Cright%29%5Cend%7Baligned%7D
其中%5Cmathcal%20L%5E%7B%5Ctext%7Bmeta%7D%7D​和%5Cmathcal%20L%5E%7B%5Ctext%7Btask%7D%7D​分别指外部目标损失函数和内部目标损失函数,例如few-shot分类情况下的交叉熵。两层范式的一个关键特征是内外层的主从不对称:内层优化Eq.5​以外层定义的学习策略%CF%89​为条件,但在训练期间无法更改%CF%89​ .

训练过程如下图所示:

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
图 15.3 双层优化(图片来源 [26])

假设我们已经找到了最优的 meta knowledge,那么意味着我们使用 train set 优化模型得到的参数也是最优的。

请注意,在上述元训练的形式化中,我们使用了任务分布的概念,并使用了来自该分布的 M​ 任务样本。虽然此步骤在元学习文献中有效且广泛使用,但它不是元学习的要求。更正式地说,如果给定一个单独的训练和测试数据集,我们可以拆分训练集以获得验证数据,以便 %5Cmathscr%7BD%7D_%7B%5Ctext%20%7Bsource%20%7D%7D%3D%5Cleft%28%5Cmathcal%7BD%7D_%7B%5Ctext%20%7Bsource%20%7D%7D%5E%7B%5Ctext%20%7Btrain%20%7D%7D%2C%20%5Cmathcal%7BD%7D_%7B%5Ctext%20%7Bsource%20%7D%7D%5E%7B%5Ctext%20%7Bval%20%7D%7D%5Cright%29​ 用于元训练,而用于元测试,我们可以使用 %5Cmathscr%7BD%7D_%7B%5Ctext%20%7Btarget%20%7D%7D%3D%5Cleft%28%5Cmathcal%7BD%7D_%7B%5Ctext%20%7Bsource%20%7D%7D%5E%7B%5Ctext%20%7Btrain%20%7D%7D%20%5Ccup%20%5Cmathcal%7BD%7D_%7B%5Ctext%20%7Bsource%20%7D%7D%5E%7B%5Ctext%20%7Bval%20%7D%7D%2C%20%5Cmathcal%7BD%7D_%7B%5Ctext%20%7Btarget%20%7D%7D%5E%7B%5Ctext%20%7Best%20%7D%7D%5Cright%29​。虽然元训练中经常使用不同的训练值拆分,但我们仍然可以多次学习%CF%89​,这被认为是M%3DQ%3D1​​。

BiLevel Optimization 的思想非常重要,几乎所有的 meta learning 问题都可以套用。

%E2%9D%91元学习:前馈模型观 (Feed-Forward Model View)

在元学习几十年来的发展过程中,有许多元学习方法以前馈的方式来合成模型,而不是像上面的Eq.4%5C%20%5C%26%5C%20Eq.5​​ 那样通过显式的迭代优化。 因为 BiLevel Optimization 的 inner-loop 中并不一定要用优化的方法,可以是任意的方式。即优化后的模型是可以通过 meta-knowledge 隐式的来表示:
%5Cmin%20_%7B%5Comega%7D%20%5Cunderset%7B%5Cunderset%7B%5Cleft%28%5Cmathcal%7BD%7D%5E%7B%5Ctext%7Btr%7D%7D%2C%20%5Cmathcal%7BD%7D%5E%7B%5Ctext%7Bval%7D%7D%5Cright%29%20%5Cin%20%5Cmathcal%7BT%7D%7D%7B%5Cmathcal%7BT%7D%20%5Csim%20p%28%5Cmathcal%7BT%7D%29%7D%7D%7B%5Cmathbb%7BE%7D%7D%5C%20%5C%20%5C%20%7B%5Csum_%7B%28%5Cmathbf%7Bx%7D%2C%20y%29%20%5Cin%20%5Cmathcal%7BD%7D%5E%7B%5Ctext%7Bval%7D%7D%7D%5Cleft%5B%5Cleft%28%5Cmathbf%7Bx%7D%5E%7BT%7D%20%5Cmathbf%7Bg%7D_%7B%5Comega%7D%5Cleft%28%5Cmathcal%7BD%7D%5E%7B%5Ctext%7Btr%7D%7D%5Cright%29-y%5Cright%29%5E%7B2%7D%5Cright%5D%7D
其中%5Cmathbf%7Bg%7D_%7B%5Comega%7D%28%5Cmathcal%20D%5E%7B%5Ctext%7Btr%7D%7D%29​ 为基于 meta knowledge 和 train set 得到的隐式模型。

这里通过优化任务分布来进行元训练。对于每项任务,都会拆分为一个训练集和一个验证集。 训练集%5Cmathcal%20D%5E%7B%5Ctext%7Btr%7D%7D被嵌入到一个向量%5Cmathbf%20g_%7B%5Comega%7D中,该向量定义了线性回归权重以从验证集中预测示例%5Cmathbf%20x。通过训练函数%5Cmathbf%20g_%7B%5Comega%7D优化Eq.6的 ‘learns to learn’。 将训练集映射到权重向量。我们期望从p%28%5Cmathcal%20T%29中提取的新的元测试任务%5Cmathcal%20T%5E%7B%5Ctext%7Bte%7D%7D也可以在%5Cmathbf%20g_%7B%5Comega%7D得到一个较好的预测。该系列中的方法因所使用的预测模型%5Cmathbf%20g的复杂性以及支持集的嵌入方式而异(例如,通过池化层、CNN 层 或 RNN 层等)。这些模型也被称为 amortized,因为学习新任务的成本通过%5Cmathbf%20g_%7B%5Comega%7D%28%5Ccdot%29减少到前馈操作,在%5Comega的元训练期间已经进行了迭代优化。

15.3.2 简单示例

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C 我们现在假设存在一个任务分布,我们从中抽取许多任务作为训练集。在这个训练集上训练的一个好的元学习模型应该在这个空间的所有任务上表现良好,即使是以前从未见过的任务。每个任务可以表示为一个数据集%5Cmathcal%7BD%7D,其中包括特征向量x和标签y,分布表示为p%28%5Cmathcal%7BD%7D%29。那么最优的元学习模型参数可以表示为:

%5Ctheta%5E%2A%20%3D%20%5Carg%5Cmin_%5Ctheta%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%7BD%7D%5Csim%20p%28%5Cmathcal%7BD%7D%29%7D%20%5B%5Cmathcal%7BL%7D_%5Ctheta%28%5Cmathcal%7BD%7D%29%5D

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C少样本学习(Few-shot classification)是元学习的在监督学习中的一个实例。数据集%5Cmathcal%7BD%7D经常被划分为两部分,一个用于学习的支持集(support set)S,和一个用于训练和测试的预测集(prediction set)B,即%5Cmathcal%7BD%7D%3D%5Clangle%20S%2C%20B%5Crangle。K-shot N-class分类任务,即支持集中有N类数据,每类数据有K个带有标注的样本。

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
图 15.4 4-shot 2-class 图像分类的例子。 (图像来源 https://www.pinterest.com/)

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C一个数据集%5Cmathcal%7BD%7D包含许多对特征向量和标签,即%5Cmathcal%7BD%7D%20%3D%20%5C%7B%28%5Cmathbf%7Bx%7D_i%2C%20y_i%29%5C%7D。每个标签都属于一个标签类%5Cmathcal%7BL%7D。假设我们的分类器f_%5Ctheta的输入是特征向量%5Cmathbf%7Bx%7D,输出是属于第二类的概率P_%5Ctheta%28y%5Cvert%5Cmathbf%7Bx%7D%29%5Ctheta是分类器的参数。

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C如果我们每次选一个B%20%5Csubset%20%5Cmathcal%7BD%7D作为训练的 batch ,则最佳的模型参数,应当能够最大化,多组 batch 的正确标签概率之和。

%5Cbegin%7Baligned%7D%20%5Ctheta%5E%2A%20%26%3D%20%7B%5Carg%5Cmax%7D_%7B%5Ctheta%7D%20%5Cmathbb%7BE%7D_%7B%28%5Cmathbf%7Bx%7D%2C%20y%29%5Cin%20%5Cmathcal%7BD%7D%7D%5BP_%5Ctheta%28y%20%5Cvert%20%5Cmathbf%7Bx%7D%29%5D%20%26%5C%5C%20%5Ctheta%5E%2A%20%26%3D%20%7B%5Carg%5Cmax%7D_%7B%5Ctheta%7D%20%5Cmathbb%7BE%7D_%7BB%5Csubset%20%5Cmathcal%7BD%7D%7D%5B%5Csum_%7B%28%5Cmathbf%7Bx%7D%2C%20y%29%5Cin%20B%7DP_%5Ctheta%28y%20%5Cvert%20%5Cmathbf%7Bx%7D%29%5D%20%26%20%5Cscriptstyle%7B%5Ctext%7B%3B%20trained%20with%20mini-batches.%7D%7D%20%5Cend%7Baligned%7D

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C少样本学习的目标是,在小规模的 support set 上 “快速学习”(类似fine-tuning (微调技巧) )后,能够减少在 prediction set 上的预测误差。为了训练模型快速学习的能力,我们在训练的时候按照以下步骤:

  1. 对标签子集 L%5Csubset%5Cmathcal%7BL%7D 进行采样。
  2. 根据采样的标签子集,采样一个 support setS%5EL%20%5Csubset%20%5Cmathcal%7BD%7D和一个 training batchB%5EL%20%5Csubset%20%5Cmathcal%7BD%7DS%5ELB%5EL中的数据的标签都属于L,即y%20%5Cin%20L%2C%20%5Cforall%20%28x%2C%20y%29%20%5Cin%20S%5EL%2C%20B%5EL .
  3. 把 support set 作为模型的输入,进行“快速学习”。注意,不同的算法具有不同的学习策略,但总的来说,这一步不会永久性更新模型参数。
  4. 把 prediction set 作为模型的输入,计算模型在B%5EL上的 loss,根据这个 loss 进行反向传播更新模型参数。这一步与监督学习一致。

您可以将每一对%28S%5EL%2C%20B%5EL%29视为一个数据点。模型经过训练以在其他数据集上扩展。下面公式中的红色部分是元学习目标相对于监督学习目标的额外部分。

%5Ctheta%20%3D%20%5Carg%5Cmax_%5Ctheta%20%5Ccolor%7Bred%7D%7BE_%7BL%5Csubset%5Cmathcal%7BL%7D%7D%5B%7D%20E_%7B%5Ccolor%7Bred%7D%7BS%5EL%20%5Csubset%5Cmathcal%7BD%7D%2C%20%7DB%5EL%20%5Csubset%5Cmathcal%7BD%7D%7D%20%5B%5Csum_%7B%28x%2C%20y%29%5Cin%20B%5EL%7D%20P_%5Ctheta%28x%2C%20y%5Ccolor%7Bred%7D%7B%2C%20S%5EL%7D%29%5D%20%5Ccolor%7Bred%7D%7B%5D%7D

15.3.2 学习器和元学习器

还有一种常见的看待 meta learning 的视角,把模型的更新划分为了两个阶段:

  • 根据给定的任务,训练一个分类器f_%5Ctheta作为“学习者”模型完成任务
  • 同时,训练一个元学习器g_%5Cphi,根据support setS学习如何更新学习器模型的参数。%5Ctheta%27%20%3D%20g_%5Cphi%28%5Ctheta%2C%20S%29

那么在最终的优化目标中,我们需要更新%5Ctheta%5Cphi来最大化:

%5Cmathbb%7BE%7D_%7BL%5Csubset%5Cmathcal%7BL%7D%7D%5B%20%5Cmathbb%7BE%7D_%7BS%5EL%20%5Csubset%5Cmathcal%7BD%7D%2C%20B%5EL%20%5Csubset%5Cmathcal%7BD%7D%7D%20%5B%5Csum_%7B%28%5Cmathbf%7Bx%7D%2C%20y%29%5Cin%20B%5EL%7D%20P_%7Bg_%5Cphi%28%5Ctheta%2C%20S%5EL%29%7D%28y%20%5Cvert%20%5Cmathbf%7Bx%7D%29%5D%5D

15.3.4 Meta Learning的分类

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C元学习主要有三类常见的方法:基于度量的方法(metric-based),基于模型的方法(model-based),基于优化的方法(optimization-based)。 Oriol Vinyals在NIPS 2018的meta-learning symposium上做了一个很好的总结(http://metalearning-symposium.ml/files/vinyals.pdf):

Model-basedMetric-basedOptimization-based
Key ideaRNN; memoryMetric learningGradient descent
How P θ ( y ∣ x ) P_\theta(y \vert \mathbf{x}) Pθ(yx) is modeled? f θ ( x , S ) f_\theta(\mathbf{x}, S) fθ(x,S) ∑ ( x i , y i ) ∈ S k θ ( x , x i ) y i \sum_{(\mathbf{x}_i, y_i) \in S} k_\theta(\mathbf{x}, \mathbf{x}_i)y_i (xi,yi)Skθ(x,xi)yi (*) P g ϕ ( θ , S L ) ( y ∣ x ) P_{g_\phi(\theta, S^L)}(y \vert \mathbf{x}) Pgϕ(θ,SL)(yx)

(*)k_%5Ctheta是一个衡量%5Cmathbf%7Bx%7D_i%5Cmathbf%7Bx%7D相似度的kernel function。

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C元学习综述Meta-Learning in Neural Networks: A Survey[19]中创新性地提出了一种新的分类方法,即对元学习 Meta Learing 按照是什么(What,Meta-Representation),怎么做(How,Meta-Optimizer),为什么(Why,Objective)来分类。

我们沿着三个独立的轴引入了一个新的分类。对于每一个轴,我们提供了反映当前的元学习环境分类:

%E2%9D%91元的意思(“什么?”)

第一个轴是元知识%CF%89表示的选择。这可以将用于优化器初始化的模型参数 [19] 的估计扩展到程序归纳的可读代码 [89] 。注意,基本模型表示%CE%B8通常是特定于应用的,例如计算机视觉中的卷积神经网络(CNN)[1]。

%E2%9D%91 元优化器(“如何?”)

第二个轴是在元训练(见等式5)1期间用于外部水平的优化器的选择。ω的外层优化器可以有多种形式,从梯度下降[19],到强化学习[89]和进化搜索[23]。

%E2%9D%91 元目标(“为什么?”)

第三个轴是元学习的目标,它由元目标%5Cmathcal%20L_%5Ctext%7Bmeta%7D、任务分布p%28%5Cmathcal%20T%29​​​ 和两个层次之间的数据流的选择决定。它们可以一起为不同的目的定制元学习,例如样本有效的少样本学习[19]、[40]、快速多样本优化[89]、[91]或对域移位的鲁棒性[44]、[92]、标签噪声[93]和对抗攻击[94]。


此处的引用文章详见原综述论文 [19]。

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
图 15.5 元学习环境概述,包括算法设计(元优化器,元表示,元目标)和应用

15.4 基于优化的方法

本节15.4 基于优化的方法部分内容翻译自Meta-Learning: Learning to Learn Fast[18],部分内容由我个人原创整理而成,一些更准确的描述请参见原文。

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C 深度学习模型通过反向传播梯度来学习。基于梯度的优化方法不适用于只有少量训练样本的情况,仅几步就很难达到收敛。为了调整现有的优化算法,使模型可以只用少量样本训练出更好的模型,提出了一种基于优化的元学习算法。

15.4.1 MAML

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2CModel-Agnostic Meta-Learning简称MAML[17],是一种非常通用的优化算法,可以被用于任何基于梯度下降学习的模型。MAML 的目的是获取一组更好的模型初始化参数(即让模型自己学会初始化)。

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C假设我们的模型是f_%5Ctheta,参数是%5Ctheta。给定任务%5Cmathcal%20T_i及其对应的数据集%28%5Cmathcal%7BD%7D%5E%7B%28i%29%7D_%5Ctext%7Btrain%7D%2C%20%5Cmathcal%7BD%7D%5E%7B%28i%29%7D_%5Ctext%7Btest%7D%29,我们可以对模型参数执行一个或多个梯度下降。 (下式只进行了一次迭代):

%5Ctheta%27_i%20%3D%20%5Ctheta%20-%20%5Calpha%20%5Cnabla_%5Ctheta%5Cmathcal%7BL%7D%5E%7B%280%29%7D_%7B%5Cmathcal%20T_i%7D%28f_%5Ctheta%29
其中%5Cmathcal%7BL%7D%5E%7B%280%29%7D​是分批小数据计算的损失函数,编号为一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述0​。

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
图 15.6 MAML图示。 (图像来源:原论文 [17])

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C当然,上面这个式子只针对一个特定的任务进行了优化。而 MAML 为了能够更好地扩展到一系列任务上,我们想要寻找一个在给定任意任务后微调过程最高效的%5Ctheta%5E%2A。现在,假设我们采样了一个编号为1的数据 batch 用于更新元目标。对应的损失函数记为%5Cmathcal%7BL%7D%5E%7B%281%29%7D%5Cmathcal%7BL%7D%5E%7B%280%29%7D%5Cmathcal%7BL%7D%5E%7B%281%29%7D的上标只代表着数据 batch 不同,他们都是同一个目标方程计算得到的。则有:

%5Cbegin%7Baligned%7D%20%5Ctheta%5E%2A%20%26%3D%20%5Carg%5Cmin_%5Ctheta%20%5Csum_%7B%5Cmathcal%20T_i%20%5Csim%20p%28%5Cmathcal%20T%29%7D%20%5Cmathcal%7BL%7D_%7B%5Cmathcal%20T_i%7D%5E%7B%281%29%7D%20%28f_%7B%5Ctheta%27_i%7D%29%20%3D%20%5Carg%5Cmin_%5Ctheta%20%5Csum_%7B%5Cmathcal%20T_i%20%5Csim%20p%28%5Cmathcal%20T%29%7D%20%5Cmathcal%7BL%7D_%7B%5Cmathcal%20T_i%7D%5E%7B%281%29%7D%20%28f_%7B%5Ctheta%20-%20%5Calpha%5Cnabla_%5Ctheta%20%5Cmathcal%7BL%7D_%7B%5Cmathcal%20T_i%7D%5E%7B%280%29%7D%28f_%5Ctheta%29%7D%29%20%26%20%5C%5C%20%5Ctheta%20%26%5Cleftarrow%20%5Ctheta%20-%20%5Cbeta%20%5Cnabla_%7B%5Ctheta%7D%20%5Csum_%7B%5Cmathcal%20T_i%20%5Csim%20p%28%5Cmathcal%20T%29%7D%20%5Cmathcal%7BL%7D_%7B%5Cmathcal%20T_i%7D%5E%7B%281%29%7D%20%28f_%7B%5Ctheta%20-%20%5Calpha%5Cnabla_%5Ctheta%20%5Cmathcal%7BL%7D_%7B%5Cmathcal%20T_i%7D%5E%7B%280%29%7D%28f_%5Ctheta%29%7D%29%20%26%20%5Cscriptstyle%7B%5Ctext%7B%3B%20updating%20rule%7D%7D%20%5Cend%7Baligned%7D
MAML 的损失函数即为每个任务的损失函数之和:
L%28%5Cphi%29%3D%5Csum_%7Bi%3D1%7D%5En%5Cmathcal%20L%5En%28%5Chat%5Ctheta%5En%29
MAML 算法过程的伪代码如下:

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
图 15.7 MAML的训练流程


图 15.4 为 MAML 的训练流程。内部更新将初始化参数的%5Ctheta更新为%5Ctheta%27而外部更新利用%5Ctheta%27作为网络参数去计算损失函数,这个损失函数最后用来计算元梯度来外部更新。

人们创造出了一个可以用来测试元学习效果的另一种语言:火星文字符图像数据集:Omniglot:https://github.com/brendenlake/omniglot 。Omniglot 包含1623​ 个不同的火星文字符,每个字符包含20​ 个手写的 case。这个任务是判断每个手写的 case 属于哪一个火星文字符。我们可以使用 Omniglot 将作为元学习的任务来源。

整个 MAML 算法流程如下:

  1. 准备N个训练任务 (Train Task) 、每个训练任务对应的 Support Set 和 Query Set。再准备几个测试任务,测试任务用于评估 meta learning 学习到的参数的效果。
  2. 定义网络结构,如 CNN,并初始化一个 meta 网络的参数为%5Cphi%5E0​ , meta 网络是最终要用来应用到新的测试任务中的网络,该网络中存储了 meta knowledge%5Comega​​ 。一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
    图 15.8 MAML 算法流程(图片来源 [28])
  3. 开始执行迭代“预训练”:
    a. 采样1个训练任务m(或者1个 batch 的几个训练任务,上图显示的是采样了1个训练任务)。将 meta 网络的参数%5Cphi%5E0赋值给任务m独有的网络,得到%5Chat%5Ctheta%5Em(初始的%5Chat%5Ctheta%5Em%3D%5Cphi%5E0) ;
    b. 使用任务m的 Support Set,基于任务m的学习率%5Calpha_m,对%5Chat%5Ctheta%5Em进行1次优化,更新%5Chat%5Ctheta%5Em
    c. 基于1次优化后的%5Chat%5Ctheta%5Em,使用 Query Set 计算任务m的损失函数 ——%5Cmathcal%20L%5Em%28%5Chat%5Ctheta%5Em%29,并计算%5Cmathcal%20L%5Em%28%5Chat%5Ctheta%5Em%29%5Chat%5Ctheta%5Em的梯度;
    d. 用该梯度,乘以 meta 网络的学习率%5Calpha_%7B%5Ctext%7Bmeta%7D%7D​,更新参数%5Cphi%5E0得到%5Cphi%5E1(注意,第一个蓝色箭头与第二个绿色箭头是平行的,这里的含义是%5Cphi%5E0的更新与%5Chat%5Ctheta%5Em​ 梯度的方向一致);
    e. 采样1个任务n,将参数中赋值给任务n,得到%5Chat%5Ctheta%5En(初始的%5Chat%5Ctheta%5En%3D%5Cphi%5E1);
    f. 然后使用任务n​ 的训练数据,基于任务n​ 的学习率,对%5Chat%5Ctheta%5En​ 进行一次优化更新%5Chat%5Ctheta%5En​ ;
    g. 基于1次优化后的%5Chat%5Ctheta%5En,使用 Query Set 计算任务n的损失函数——%5Cmathcal%20L%5En%28%5Chat%5Ctheta%5En%29,并计算%5Chat%5Ctheta%5En上的梯度;
    h. 用该梯度,乘以 meta 网络的学习率%5Calpha_%7B%5Ctext%7Bmeta%7D%7D,更新%5Cphi%5E1,得到%5Cphi%5E2
    i. 在训练任务上,重复执行 a-h 的过程。
  4. 通过 3 得到 meta 网络的参数,该参数可以在测试任务中,使用测试任务的 Support Set 对 meta 网络的参数进行 finetuing。
  5. 最终使用测试任务的 Query Set 评估 meta learning 的效果。

基于 Tensorflow2.0 实现的 MAML 部分代码如下:

完整代码详见:https://github.com/dragen1860/MAML-TensorFlow

## 网络构建部分: refer: https://github.com/dragen1860/MAML-TensorFlow

#################################################
# 任务描述:5-ways,1-shot图像分类任务,图像统一处理成 84 * 84 * 3 = 21168的尺寸。
# support set:5 * 1
# query set:5 * 15
# 训练取1个batch的任务:batch size:4
# 对训练任务进行训练时,更新5次:K = 5
#################################################

print(support_x) # (4, 5, 21168) 
print(query_x) # (4, 75, 21168)
print(support_y) # (4, 5, 5)
print(query_y) # (4, 75, 5)
print(meta_batchsz) # 4
print(K) # 5

model = MAML()
model.build(support_x, support_y, query_x, query_y, K, meta_batchsz, mode='train')

class MAML:
    def __init__(self):
        pass
    def build(self, support_xb, support_yb, query_xb, query_yb, K, meta_batchsz, mode='train'):
        """
        :param support_xb: [4, 5, 84*84*3] 
        :param support_yb: [4, 5, n-way]
        :param query_xb:  [4, 75, 84*84*3]
        :param query_yb: [4, 75, n-way]
        :param K:  训练任务的网络更新步数
        :param meta_batchsz: 任务数,4
        """

        self.weights = self.conv_weights() # 创建或者复用网络参数;训练任务对应的网络复用meta网络的参数
        training = True if mode is 'train' else False      
        def meta_task(input):
            """
            :param support_x:   [setsz, 84*84*3] (5, 21168)
            :param support_y:   [setsz, n-way] (5, 5)
            :param query_x:     [querysz, 84*84*3] (75, 21168)
            :param query_y:     [querysz, n-way] (75, 5)
            :param training:    training or not, for batch_norm
            :return:
            """

            support_x, support_y, query_x, query_y = input
            query_preds, query_losses, query_accs = [], [], [] # 子网络更新K次,记录每一次queryset的结果
 
            ## 第0次对网络进行更新
            support_pred = self.forward(support_x, self.weights, training) # 前向计算support set
            support_loss = tf.nn.softmax_cross_entropy_with_logits(logits=support_pred, labels=support_y) # support set loss
            support_acc = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(support_pred, dim=1), axis=1),
                                                         tf.argmax(support_y, axis=1))
            grads = tf.gradients(support_loss, list(self.weights.values())) # 计算support set的梯度
            gvs = dict(zip(self.weights.keys(), grads))
            # 使用support set的梯度计算的梯度更新参数,theta_pi = theta - alpha * grads
            fast_weights = dict(zip(self.weights.keys(), \
                    [self.weights[key] - self.train_lr * gvs[key] for key in self.weights.keys()]))

            # 使用梯度更新后的参数对quert set进行前向计算
            query_pred = self.forward(query_x, fast_weights, training)
            query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)
            query_preds.append(query_pred)
            query_losses.append(query_loss)
 
            # 第1到 K-1 次对网络进行更新
            for _ in range(1, K):           
                loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.forward(support_x, fast_weights, training),
                                                               labels=support_y)
                grads = tf.gradients(loss, list(fast_weights.values()))
                gvs = dict(zip(fast_weights.keys(), grads))
                fast_weights = dict(zip(fast_weights.keys(), [fast_weights[key] - self.train_lr * gvs[key]
                                         for key in fast_weights.keys()]))
                query_pred = self.forward(query_x, fast_weights, training)
                query_loss = tf.nn.softmax_cross_entropy_with_logits(logits=query_pred, labels=query_y)
                # 子网络更新K次,记录每一次queryset的结果
                query_preds.append(query_pred)
                query_losses.append(query_loss)

            for i in range(K):
                query_accs.append(tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(query_preds[i], dim=1), axis=1),
                                                                tf.argmax(query_y, axis=1)))
            result = [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]
            return result

        # return: [support_pred, support_loss, support_acc, query_preds, query_losses, query_accs]
        out_dtype = [tf.float32, tf.float32, tf.float32, [tf.float32] * K, [tf.float32] * K, [tf.float32] * K]
        result = tf.map_fn(meta_task, elems=(support_xb, support_yb, query_xb, query_yb),
                           dtype=out_dtype, parallel_iterations=meta_batchsz, name='map_fn')
        support_pred_tasks, support_loss_tasks, support_acc_tasks, \
            query_preds_tasks, query_losses_tasks, query_accs_tasks = result

        if mode is 'train':
            self.support_loss = support_loss = tf.reduce_sum(support_loss_tasks) / meta_batchsz
            self.query_losses = query_losses = [tf.reduce_sum(query_losses_tasks[j]) / meta_batchsz
                                                    for j in range(K)]
            self.support_acc = support_acc = tf.reduce_sum(support_acc_tasks) / meta_batchsz
            self.query_accs = query_accs = [tf.reduce_sum(query_accs_tasks[j]) / meta_batchsz
                                                    for j in range(K)]

            # 更新meta网络,只使用了第 K步的query loss。这里应该是个超参,更新几步可以调调
            optimizer = tf.train.AdamOptimizer(self.meta_lr, name='meta_optim')
            gvs = optimizer.compute_gradients(self.query_losses[-1])
   # def ********

三个 MAML 算法的核心问题(源自 李宏毅2021机器学习课程 [28]):

  1. MAML 的执行过程与 model pre-training & transfer learning 的区别是什么?
  2. 为何在 meta 网络赋值给具体训练任务(如任务m)后,要先更训练任务的参数,再计算梯度,更新 meta 网络?
  3. 在更新训练任务的网络时,只走了一步,然后更新 meta 网络。为什么是一步,可以是多步吗?

问题1:MAML 的执行过程与 model pre-training & transfer learning 的区别是什么?

我们列举出 meta learning 与 model pre-training 的损失函数进行对比:

MAML 的 meta 模型的损失函数:
L%28%5Cphi%29%3D%5Csum_%7Bi%3D1%7D%5En%5Cmathcal%20L%5En%28%5Chat%5Ctheta%5En%29
model pre-training 的损失函数为:
L%28%5Cphi%29%3D%5Csum_%7Bi%3D1%7D%5En%5Cmathcal%20L%5En%28%5Cphi%29
meta learning 的损失函数%5Cmathcal%20L是在使用 support set 将任务网络的参数全部更新过一次后,然后使用 query set 进行计算。在更新之后任务网络的参数与 meta 网络的参数不同。

由于机器学习有且仅有一个模型,model pre-training 的损失函数%5Cmathcal%20L是对同一个模型的参数,使用训练数据计算的损失函数以及梯度,对模型参数进行更新;如果有多个训练任务,我们可以将模型参数在很多任务上进行预训练,训练的所有梯度都会直接更新模型的参数。

两者的更新过程如图:

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
图 15.9 MAML 与 Model Pre-training 的更新区别(图片来源 [28])

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2Cmeta learing 的 MAML 算法更新流程我们在上文中讲过了,这里简单复述一遍,我们首先将模型初始化的参数%5Cphi%5E0直接赋值给模型任务m的参数%5Chat%5Ctheta%5Em,即第一个绿色箭头。然后对m任务网络进行梯度更新,然后对 meta 网络进行同方向的梯度更新,即第一个蓝色箭头平行于第二个绿色箭头。然后对任务n进行一次上述操作,直到所有任务均对 meta 网络进行一次同梯度方向的更新以后,一次训练完成。我们发现尽管 meta 网络的参数更新方向使用的是梯度的方向,但是由于我们每次将 meta 网络的参数赋值给任务网络,然后按照任务网络的梯度方向进行更新,所以更新一次之后,所有人物网络的参数与 meta 网络的参数均不相同。meta Learing 希望最小化每一个子任务训练更新一次之后,第二次在 query set 上对于所有任务得到的损失函数,子任务从状态一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述0到状态1,我们希望状态1的损失函数小,也即更加关心初始化参数未来的潜力。

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C机器学习的 model pre-training使用子任务的梯度更新模型的参数,联合起来一步一步更新模型。希望最小化模型的在所有任务上的损失函数,由于只有一个模型,所以希望找到一个在大多数任务上表现最好的初始化参数,即满足当前表现最好。

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C如图 15.10 所示,model pre-training 找到的参数%5Cphi,在两个任务上当前的表现比较好,也即当前选择最优的参数,但再之后训练并不能保证会更好。

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2CMAML 的参数%5Cphi在子任务当前的表现并不是最好的,因为我们并不关心在任务网络上的结果,我们关心使用%5Cphi训练出的%5Chat%5Ctheta%5En表现如何,并且我们将模型继续训练下去,更大可能达到各自任务的局部最优情况。

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
图 15.10 MAML (图片来源 [28])
一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
图 15.11 Model Pre-training (图片来源 [28])

这里给出一个 toy example 再对二者的区别进行展示。

%E2%9D%91训练任务:给定N函数,y%20%3D%20a%5Csin%20x%20%2B%20b,我们只需要修改a%2Cb的值,得到不同的%5Csin函数,并从每个函数中抽取K点来拟合原来的给定函数。

%E2%9D%91训练过程:用这N个训练任务采样的数据点分别通过 MAML 与 model pre-training 训练网络,得到预训练的参数。

训练结果如图 15.12 所示:

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
图 15.12 Toy Eample(图源 [24])

我们发现 model pre-training 的结果,在测试任务上,在finetuning前后,绿色曲线也即预训练结果一直是一条水平线。因为我们在预训练的时候,给定了很多不同的%5Csin函数,model pre-training 为了找到一个在所有任务上都效果较好的初始化结果,由于若干%5Csin函数波峰和波谷重叠起来,基本就是一条水平线,因此使用这个初始化的结果取 finetuning ,得到的结果仍然是水平线。

而 MAML 的初始化结果是一条与原函数相差较大的曲线。随着 finetuning 的进行,得到的曲线越来越接近。

问题2:为何在 meta 网络赋值给具体训练任务(如任务m)后,要先更新训练任务的参数,再计算梯度,更新 meta 网络?

问题3:在更新训练任务的网络时,只走了一步,然后更新 meta 网络。为什么是一步,可以是多步吗?

  • 只更新一次,速度比较快;因为meta learning中,子任务有很多,都更新很多次,训练时间比较久。
  • MAML希望得到的初始化参数在新的任务中finetuning的时候效果好。如果只更新一次,就可以在新任务上获取很好的表现。把这件事情当成目标,可以使得meta网络参数训练是很好(目标与需求一致)。
  • 当初始化参数应用到具体的任务中时,也可以finetuning很多次。
  • Few-shot learning 往往数据较少。

15.4.2.1 实战 拟合y%3Da%20%5Csin%28x%2Bb%29

我们上面讲解了李宏毅老师给出的 toy example:拟合y%3Da%20%5Csin%28x%2Bb%29,这里简单实现一遍。

%5Ctext%20%7B%20In%20%7D%5B1%5D%3A

import numpy as np
import matplotlib.pyplot as plt

%5Ctext%20%7B%20In%20%7D%5B2%5D%3A

pi = np.pi
def sample_points(k):#k为对函数a*sin(x+b)在0到2π的采样点数
    a,b = np.random.uniform(0,2,2)
    x = np.arange(0,2*pi,2*pi/k)
    y = a*np.sin(x+b)
    return x,y,a,b
def draw_sin(a,b):
    x = np.arange(0,2*pi,0.1)
    y = a*np.sin(x+b)
    plt.plot(x,y)

%5Ctext%20%7B%20In%20%7D%5B3%5D%3A

task_num = 10   #每个batch中含有的task,即一次梯度下降所含的task数目
points_num =10
alpha = 0.01   #子模型的学习率
beta = 0.01    #元学习初始化参数更新的学习率
a_init = np.random.normal()
b_init = np.random.normal()   #要更新的初始化参数的最初数值
epoch = 10000   #元学习模型更新次数

%5Ctext%20%7B%20In%20%7D%5B4%5D%3A

for epoch_ in range(epoch):
    x_train = []
    y_train = []
    a_train = []
    b_train = []
    # 生成每个task的数据
    for i in range(task_num):
        x,y,a_,b_ = sample_points(points_num)
        x_train.append(x)
        y_train.append(y)
    a_gradient = 0
    b_gradient = 0   
    # 对每个task进行一次梯度下降,更新a,b,更新之后再计算一次梯度,并把这第二次的梯度累加,用来更新a_init和b_init
    loss = 0
    for i in range(task_num):
        a_0 = a_init
        b_0 = b_init   #梯度下降的初始值为元学习模型要学习的初始化的值
        x = x_train[i]
        y = y_train[i]  #第i个task的x和y
        y_ = a_0*np.sin(x+b_0)  #通过参数a,b对y的预测值
        a_0 = a_0 - sum(2*alpha*(y_-y)*np.sin(x+b_0))/points_num #MSE损失,复合函数求导
        b_0 = b_0 - sum(2*alpha*a_0*(y_-y)*np.cos(x+b_0))/points_num
        #更新完后,再计算一次梯度,累加
        y_ = a_0*np.sin(x+b_0)
        loss += sum(np.square(y_-y))/points_num
        a_gradient += sum(2*alpha*(y_-y)*np.sin(x+b_0))/points_num
        b_gradient += sum(2*alpha*a_0*(y_-y)*np.cos(x+b_0))/points_num
    a_init -= beta*a_gradient
    b_init -= beta*b_gradient
    if epoch_%1000==0:
        print("epoch:%d,loss:%f"%(epoch_,loss))

%5Ctext%20%7B%20Out%20%7D%5B4%5D%3A

epoch:0,loss:5.782693
epoch:1000,loss:4.591905
epoch:2000,loss:2.454080
epoch:3000,loss:3.021873
epoch:4000,loss:3.729429
epoch:5000,loss:3.780468
epoch:6000,loss:2.573931
epoch:7000,loss:2.781857
epoch:8000,loss:2.733948
epoch:9000,loss:3.111584

对MAML学习到的a_init和b_init,进行0次梯度下降更新,直接画出,和新来的采样样本比较

%5Ctext%20%7B%20In%20%7D%5B5%5D%3A

x,y,a_,b_ = sample_points(points_num)
draw_sin(a_,b_)
draw_sin(a_init,b_init)

%5Ctext%20%7B%20Out%20%7D%5B5%5D%3A

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

MAML学习的初始化参数,迭代1次后

%5Ctext%20%7B%20In%20%7D%5B6%5D%3A

a_0 = a_init
b_0 = b_init   
for i in range(1):    
    y_ = a_0*np.sin(x+b_0)  #通过参数a,b对y的预测值
    a_0 = a_0 - sum(2*alpha*(y_-y)*np.sin(x+b_0))/points_num #MSE损失,复合函数求导
    b_0 = b_0 - sum(2*alpha*a_0*(y_-y)*np.cos(x+b_0))/points_num
draw_sin(a_,b_)
draw_sin(a_0,b_0)

%5Ctext%20%7B%20Out%20%7D%5B6%5D%3A

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

MAML学习的初始化参数,迭代10次

%5Ctext%20%7B%20In%20%7D%5B7%5D%3A

a_0 = a_init
b_0 = b_init   
for i in range(10):
    y_ = a_0*np.sin(x+b_0)  #通过参数a,b对y的预测值
    a_0 = a_0 - sum(2*alpha*(y_-y)*np.sin(x+b_0))/points_num #MSE损失,复合函数求导
    b_0 = b_0 - sum(2*alpha*a_0*(y_-y)*np.cos(x+b_0))/points_num
draw_sin(a_,b_)
draw_sin(a_0,b_0)

%5Ctext%20%7B%20Out%20%7D%5B7%5D%3A

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

MAML学习的初始化参数,迭代100次

%5Ctext%20%7B%20In%20%7D%5B8%5D%3A

a_0 = a_init
b_0 = b_init   
for i in range(100):
    y_ = a_0*np.sin(x+b_0)  #通过参数a,b对y的预测值
    a_0 = a_0 - sum(2*alpha*(y_-y)*np.sin(x+b_0))/points_num #MSE损失,复合函数求导
    b_0 = b_0 - sum(2*alpha*a_0*(y_-y)*np.cos(x+b_0))/points_num
draw_sin(a_,b_)
draw_sin(a_0,b_0)

%5Ctext%20%7B%20Out%20%7D%5B8%5D%3A

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

MAML学习的初始化参数,迭代500次

%5Ctext%20%7B%20In%20%7D%5B9%5D%3A

a_0 = a_init
b_0 = b_init   
for i in range(500):
    y_ = a_0*np.sin(x+b_0)  #通过参数a,b对y的预测值
    a_0 = a_0 - sum(2*alpha*(y_-y)*np.sin(x+b_0))/points_num #MSE损失,复合函数求导
    b_0 = b_0 - sum(2*alpha*a_0*(y_-y)*np.cos(x+b_0))/points_num
draw_sin(a_,b_)
draw_sin(a_0,b_0)

%5Ctext%20%7B%20Out%20%7D%5B9%5D%3A

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

不进行MAML,随机生成a,b,进行0次梯度下降更新,直接画出,和新来的采样样本比较

%5Ctext%20%7B%20In%20%7D%5B10%5D%3A

a = np.random.normal()
b = np.random.normal()
draw_sin(a_,b_)
draw_sin(a,b)

%5Ctext%20%7B%20Out%20%7D%5B10%5D%3A

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

不进行MAML,随机生成a,b,1次迭代

%5Ctext%20%7B%20In%20%7D%5B11%5D%3A

a_0 = a
b_0 = b   
for i in range(1):    
    y_ = a_0*np.sin(x+b_0)  #通过参数a,b对y的预测值
    a_0 = a_0 - sum(2*alpha*(y_-y)*np.sin(x+b_0))/points_num #MSE损失,复合函数求导
    b_0 = b_0 - sum(2*alpha*a_0*(y_-y)*np.cos(x+b_0))/points_num
draw_sin(a_,b_)
draw_sin(a_0,b_0)

%5Ctext%20%7B%20Out%20%7D%5B11%5D%3A

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

不进行MAML,随机生成a,b,10次迭代

%5Ctext%20%7B%20In%20%7D%5B12%5D%3A

a_0 = a
b_0 = b  
for i in range(10):    
    y_ = a_0*np.sin(x+b_0)  #通过参数a,b对y的预测值
    a_0 = a_0 - sum(2*alpha*(y_-y)*np.sin(x+b_0))/points_num #MSE损失,复合函数求导
    b_0 = b_0 - sum(2*alpha*a_0*(y_-y)*np.cos(x+b_0))/points_num
draw_sin(a_,b_)
draw_sin(a_0,b_0)

%5Ctext%20%7B%20Out%20%7D%5B12%5D%3A

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

不进行MAML,随机生成a,b,100次迭代

%5Ctext%20%7B%20In%20%7D%5B13%5D%3A

a_0 = a
b_0 = b  
for i in range(100):    
    y_ = a_0*np.sin(x+b_0)  #通过参数a,b对y的预测值
    a_0 = a_0 - sum(2*alpha*(y_-y)*np.sin(x+b_0))/points_num #MSE损失,复合函数求导
    b_0 = b_0 - sum(2*alpha*a_0*(y_-y)*np.cos(x+b_0))/points_num
draw_sin(a_,b_)
draw_sin(a_0,b_0)

%5Ctext%20%7B%20Out%20%7D%5B13%5D%3A

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

不进行MAML,随机生成a,b,500次迭代

%5Ctext%20%7B%20In%20%7D%5B14%5D%3A

a_0 = a
b_0 = b  
for i in range(500):    
    y_ = a_0*np.sin(x+b_0)  #通过参数a,b对y的预测值
    a_0 = a_0 - sum(2*alpha*(y_-y)*np.sin(x+b_0))/points_num #MSE损失,复合函数求导
    b_0 = b_0 - sum(2*alpha*a_0*(y_-y)*np.cos(x+b_0))/points_num
draw_sin(a_,b_)
draw_sin(a_0,b_0)

%5Ctext%20%7B%20Out%20%7D%5B14%5D%3A

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

15.4.2.2 First-Order MAML

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C上面的元优化过程依赖于二阶导数(多次迭代)。而为了加快计算,简化实现过程,一个忽略了二阶项的简化版MAML被提出了,称为First-Order MAML (FOMAML)。

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C 让我们考虑执行k内循环(微调过程)的梯度下降过程(k%5Cgeq1)。假设初始模型参数为%5Ctheta_%5Ctext%7Bmeta%7D

%5Cbegin%7Baligned%7D%20%5Ctheta_0%20%26%3D%20%5Ctheta_%5Ctext%7Bmeta%7D%5C%5C%20%5Ctheta_1%20%26%3D%20%5Ctheta_0%20-%20%5Calpha%5Cnabla_%5Ctheta%5Cmathcal%7BL%7D%5E%7B%280%29%7D%28%5Ctheta_0%29%5C%5C%20%5Ctheta_2%20%26%3D%20%5Ctheta_1%20-%20%5Calpha%5Cnabla_%5Ctheta%5Cmathcal%7BL%7D%5E%7B%280%29%7D%28%5Ctheta_1%29%5C%5C%20%26%5Cdots%5C%5C%20%5Ctheta_k%20%26%3D%20%5Ctheta_%7Bk-1%7D%20-%20%5Calpha%5Cnabla_%5Ctheta%5Cmathcal%7BL%7D%5E%7B%280%29%7D%28%5Ctheta_%7Bk-1%7D%29%20%5Cend%7Baligned%7D

而在外循环中,我们采样一个新的数据batch用于更新元目标。

%5Cbegin%7Baligned%7D%20%5Ctheta_%5Ctext%7Bmeta%7D%20%26%5Cleftarrow%20%5Ctheta_%5Ctext%7Bmeta%7D%20-%20%5Cbeta%20g_%5Ctext%7BMAML%7D%20%26%20%5Cscriptstyle%7B%5Ctext%7B%3B%20update%20for%20meta-objective%7D%7D%20%5C%5C%5B2mm%5D%20%5Ctext%7Bwhere%20%7D%20g_%5Ctext%7BMAML%7D%20%26%3D%20%5Cnabla_%7B%5Ctheta%7D%20%5Cmathcal%7BL%7D%5E%7B%281%29%7D%28%5Ctheta_k%29%20%26%5C%5C%5B2mm%5D%20%26%3D%20%5Cnabla_%7B%5Ctheta_k%7D%20%5Cmathcal%7BL%7D%5E%7B%281%29%7D%28%5Ctheta_k%29%20%5Ccdot%20%28%5Cnabla_%7B%5Ctheta_%7Bk-1%7D%7D%20%5Ctheta_k%29%20%5Cdots%20%28%5Cnabla_%7B%5Ctheta_0%7D%20%5Ctheta_1%29%20%5Ccdot%20%28%5Cnabla_%7B%5Ctheta%7D%20%5Ctheta_0%29%20%26%20%5Cscriptstyle%7B%5Ctext%7B%3B%20following%20the%20chain%20rule%7D%7D%20%5C%5C%20%26%3D%20%5Cnabla_%7B%5Ctheta_k%7D%20%5Cmathcal%7BL%7D%5E%7B%281%29%7D%28%5Ctheta_k%29%20%5Ccdot%20%5CBig%28%20%5Cprod_%7Bi%3D1%7D%5Ek%20%5Cnabla_%7B%5Ctheta_%7Bi-1%7D%7D%20%5Ctheta_i%20%5CBig%29%20%5Ccdot%20I%20%26%20%5C%5C%20%26%3D%20%5Cnabla_%7B%5Ctheta_k%7D%20%5Cmathcal%7BL%7D%5E%7B%281%29%7D%28%5Ctheta_k%29%20%5Ccdot%20%5Cprod_%7Bi%3D1%7D%5Ek%20%5Cnabla_%7B%5Ctheta_%7Bi-1%7D%7D%20%28%5Ctheta_%7Bi-1%7D%20-%20%5Calpha%5Cnabla_%5Ctheta%5Cmathcal%7BL%7D%5E%7B%280%29%7D%28%5Ctheta_%7Bi-1%7D%29%29%20%26%20%5C%5C%20%26%3D%20%5Cnabla_%7B%5Ctheta_k%7D%20%5Cmathcal%7BL%7D%5E%7B%281%29%7D%28%5Ctheta_k%29%20%5Ccdot%20%5Cprod_%7Bi%3D1%7D%5Ek%20%28I%20-%20%5Calpha%5Cnabla_%7B%5Ctheta_%7Bi-1%7D%7D%28%5Cnabla_%5Ctheta%5Cmathcal%7BL%7D%5E%7B%280%29%7D%28%5Ctheta_%7Bi-1%7D%29%29%29%20%26%20%5Cend%7Baligned%7D

MAML的梯度是:

g_%5Ctext%7BMAML%7D%20%3D%20%5Cnabla_%7B%5Ctheta_k%7D%20%5Cmathcal%7BL%7D%5E%7B%281%29%7D%28%5Ctheta_k%29%20%5Ccdot%20%5Cprod_%7Bi%3D1%7D%5Ek%20%28I%20-%20%5Calpha%20%5Ccolor%7Bred%7D%7B%5Cnabla_%7B%5Ctheta_%7Bi-1%7D%7D%28%5Cnabla_%5Ctheta%5Cmathcal%7BL%7D%5E%7B%280%29%7D%28%5Ctheta_%7Bi-1%7D%29%29%7D%29

一阶 MAML 忽略了用红色标记的二阶导数部分。它被简化为了下式,等价于最后一次内循环梯度更新的结果。

g_%5Ctext%7BFOMAML%7D%20%3D%20%5Cnabla_%7B%5Ctheta_k%7D%20%5Cmathcal%7BL%7D%5E%7B%281%29%7D%28%5Ctheta_k%29

15.4.2 Reptile

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2CReptileNichol, Achiam & Schulman, 2018(https://arxiv.org/abs/1803.02999) 是一个超级简单的元学习优化算法。它跟 MAML 类似,它们都靠梯度下降进行元优化,而且都是模型无关的算法。

Reptiled 的执行流程如下:

  1. 对任务进行抽样;
  2. 在这个任务上做多次梯度下降;
  3. 将模型参数移近新参数。
一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述

如下图中算法所示: 给定初始参数%5Ctheta%5Ctext%7BSGD%7D%28%5Cmathcal%7BL%7D_%7B%5Cmathcal%20T_i%7D%2C%20%5Ctheta%2C%20k%29根据 loss%5Cmathcal%7BL%7D_%7B%5Cmathcal%20T_i%7D进行k次随机梯度下降,之后返回参数向量。带batch的版本则每次采样多个任务。Reptile的梯度定义为%28%5Ctheta%20-%20W%29/%5Calpha,其中%5Calpha是 SGD 所使用的步长。

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
图 15.13 Batch 版本的 Reptile 算法。(图像来源:原论文[13])

一眼看上去,这个算法跟普通的 SGD 很像。但是,因为内循环里的梯度下降可以发生好多次,使得%5Ctext%7BSGD%7D%28%5Cmathbb%7BE%7D%20_%5Cmathcal%20T%5B%5Cmathcal%7BL%7D_%7B%5Cmathcal%20T%7D%5D%2C%20%5Ctheta%2C%20k%29%5Cmathbb%7BE%7D_%5Cmathcal%20T%20%5B%5Ctext%7BSGD%7D%28%5Cmathcal%7BL%7D_%7B%5Cmathcal%20T%7D%2C%20%5Ctheta%2C%20k%29%5D在 k > 1 时产生了区别。

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
图 15.14 Reptile 的训练过程 [图源[28]]

在 Reptile 中:

  • 训练任务的网络可以多次更新;
  • reptile 不再像 MAML 一样计算梯度(因此带来了工程性能的提升),而是直接用一个参数%5Cepsilon乘以 meta 网络与训练任务的网络参数的差来更新 meta 网络参数;
  • 从效果上来看,Reptile 效果与 MAML 基本持平。

15.4.2.1 The Optimization Assumption

%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C%5C%2C假设任务%5Cmathcal%20T%20%5Csim%20p%28%5Cmathcal%20T%29有一个由最优模型参数空间形成的流形%5Cmathcal%7BW%7D_%7B%5Cmathcal%20T%7D%5E%2A。当参数%5Ctheta高于这个流形%5Cmathcal%7BW%7D_%7B%5Cmathcal%20T%7D%5E%2A时,模型f_%5Ctheta对任务%5Ctheta的效果最好。为了找到一个对所有任务都足够好的模型,我们希望找到一个接近所有任务的最优流形的参数,即:

%5Ctheta%5E%2A%20%3D%20%5Carg%5Cmin_%5Ctheta%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%20T%20%5Csim%20p%28%5Cmathcal%20T%29%7D%20%5B%5Cfrac%7B1%7D%7B2%7D%20%5Ctext%7Bdist%7D%28%5Ctheta%2C%20%5Cmathcal%7BW%7D_%5Cmathcal%20T%5E%2A%29%5E2%5D

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
图 15.15 为了靠近不同任务的最优流形,Reptile 算法在交替更新参数。 (图像来源:[原论文https://arxiv.org/abs/1803.02999]

我们设%5Ctext%7Bdist%7D%28.%29为 L2 距离,并定义一个点%5Ctheta和一个集合%5Cmathcal%7BW%7D_%5Cmathcal%20T%5E%2A之间的距离等价于%5Ctheta和一个该流形上最接近%5Ctheta的点W_%7B%5Cmathcal%20T%7D%5E%2A%28%5Ctheta%29的距离。

%5Ctext%7Bdist%7D%28%5Ctheta%2C%20%5Cmathcal%7BW%7D_%7B%5Cmathcal%20T%7D%5E%2A%29%20%3D%20%5Ctext%7Bdist%7D%28%5Ctheta%2C%20W_%7B%5Cmathcal%20T%7D%5E%2A%28%5Ctheta%29%29%20%5Ctext%7B%2C%20where%20%7DW_%7B%5Cmathcal%20T%7D%5E%2A%28%5Ctheta%29%20%3D%20%5Carg%5Cmin_%7BW%5Cin%5Cmathcal%7BW%7D_%7B%5Cmathcal%20T%7D%5E%2A%7D%20%5Ctext%7Bdist%7D%28%5Ctheta%2C%20W%29

平方欧拉距离的梯度为:

%5Cbegin%7Baligned%7D%20%5Cnabla_%5Ctheta%5B%5Cfrac%7B1%7D%7B2%7D%5Ctext%7Bdist%7D%28%5Ctheta%2C%20%5Cmathcal%7BW%7D_%7B%5Cmathcal%20T_i%7D%5E%2A%29%5E2%5D%20%26%3D%20%5Cnabla_%5Ctheta%5B%5Cfrac%7B1%7D%7B2%7D%5Ctext%7Bdist%7D%28%5Ctheta%2C%20W_%7B%5Cmathcal%20T_i%7D%5E%2A%28%5Ctheta%29%29%5E2%5D%20%26%20%5C%5C%20%26%3D%20%5Cnabla_%5Ctheta%5B%5Cfrac%7B1%7D%7B2%7D%28%5Ctheta%20-%20W_%7B%5Cmathcal%20T_i%7D%5E%2A%28%5Ctheta%29%29%5E2%5D%20%26%20%5C%5C%20%26%3D%20%5Ctheta%20-%20W_%7B%5Cmathcal%20T_i%7D%5E%2A%28%5Ctheta%29%20%26%20%5Cscriptstyle%7B%5Ctext%7B%3B%20See%20notes.%7D%7D%20%5Cend%7Baligned%7D

注意:根据原论文,“一个点 Θ 和一个集合 S 的欧拉距离平方的梯度为 2(Θ − p),其中 p 是 S 中离 Θ 最近的点“。理论上来说,S 中最靠近 Θ 的点应该也是一个 Θ 的函数,但我不确定为什么计算梯度的时候不需要担心 p 的导数。(如果有什么想法欢迎讨论。)

因此,随机梯度下降的更新步骤是:

%5Ctheta%20%3D%20%5Ctheta%20-%20%5Calpha%20%5Cnabla_%5Ctheta%5B%5Cfrac%7B1%7D%7B2%7D%20%5Ctext%7Bdist%7D%28%5Ctheta%2C%20%5Cmathcal%7BW%7D_%7B%5Cmathcal%20T_i%7D%5E%2A%29%5E2%5D%20%3D%20%5Ctheta%20-%20%5Calpha%28%5Ctheta%20-%20W_%7B%5Cmathcal%20T_i%7D%5E%2A%28%5Ctheta%29%29%20%3D%20%281-%5Calpha%29%5Ctheta%20%2B%20%5Calpha%20W_%7B%5Cmathcal%20T_i%7D%5E%2A%28%5Ctheta%29

虽然很难精确计算出任务最优流形上的最近点W_%7B%5Cmathcal%20T_i%7D%5E%2A%28%5Ctheta%29,但Reptile 可以根据%5Ctext%7BSGD%7D%28%5Cmathcal%7BL%7D_%5Cmathcal%20T%2C%20%5Ctheta%2C%20k%29拟合出来。

15.4.2.2 Reptile vs FOMAML

为了说明 Reptile 和 MAML 的深层联系,我们用一个做了两步梯度下降的更新公式做例子,即%5Ctext%7BSGD%7D%28.%29k%3D2。 和上面定义的一样,%5Cmathcal%7BL%7D%5E%7B%280%29%7D%5Cmathcal%7BL%7D%5E%7B%281%29%7D只是不太batch对应的loss。为了方便阅读,我们使用两个记号:g%5E%7B%28i%29%7D_j%20%3D%20%5Cnabla_%7B%5Ctheta%7D%20%5Cmathcal%7BL%7D%5E%7B%28i%29%7D%28%5Ctheta_j%29H%5E%7B%28i%29%7D_j%20%3D%20%5Cnabla%5E2_%7B%5Ctheta%7D%20%5Cmathcal%7BL%7D%5E%7B%28i%29%7D%28%5Ctheta_j%29

%5Cbegin%7Baligned%7D%20%5Ctheta_0%20%26%3D%20%5Ctheta_%5Ctext%7Bmeta%7D%5C%5C%20%5Ctheta_1%20%26%3D%20%5Ctheta_0%20-%20%5Calpha%5Cnabla_%5Ctheta%5Cmathcal%7BL%7D%5E%7B%280%29%7D%28%5Ctheta_0%29%3D%20%5Ctheta_0%20-%20%5Calpha%20g%5E%7B%280%29%7D_0%20%5C%5C%20%5Ctheta_2%20%26%3D%20%5Ctheta_1%20-%20%5Calpha%5Cnabla_%5Ctheta%5Cmathcal%7BL%7D%5E%7B%281%29%7D%28%5Ctheta_1%29%20%3D%20%5Ctheta_0%20-%20%5Calpha%20g%5E%7B%280%29%7D_0%20-%20%5Calpha%20g%5E%7B%281%29%7D_1%20%5Cend%7Baligned%7D

如前面章节所述,FOMAML 的梯度是最后一次内循环梯度更新的结果。因此,当k%3D1时:

%5Cbegin%7Baligned%7D%20g_%5Ctext%7BFOMAML%7D%20%26%3D%20%5Cnabla_%7B%5Ctheta_1%7D%20%5Cmathcal%7BL%7D%5E%7B%281%29%7D%28%5Ctheta_1%29%20%3D%20g%5E%7B%281%29%7D_1%20%5C%5C%20g_%5Ctext%7BMAML%7D%20%26%3D%20%5Cnabla_%7B%5Ctheta_1%7D%20%5Cmathcal%7BL%7D%5E%7B%281%29%7D%28%5Ctheta_1%29%20%5Ccdot%20%28I%20-%20%5Calpha%5Cnabla%5E2_%7B%5Ctheta%7D%20%5Cmathcal%7BL%7D%5E%7B%280%29%7D%28%5Ctheta_0%29%29%20%3D%20g%5E%7B%281%29%7D_1%20-%20%5Calpha%20H%5E%7B%280%29%7D_0%20g%5E%7B%281%29%7D_1%20%5Cend%7Baligned%7D

Reptile 的梯度则定义为:

g_%5Ctext%7BReptile%7D%20%3D%20%28%5Ctheta_0%20-%20%5Ctheta_2%29%20/%20%5Calpha%20%3D%20g%5E%7B%280%29%7D_0%20%2B%20g%5E%7B%281%29%7D_1

现在,我们可以得到:

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
图 15.16 Reptile 和 FOMAML 在一次外循环里的元优化方式对比。(图像来源:slides(https://www.slideshare.net/YoonhoLee4/on-firstorder-metalearning-algorithms)

%5Cbegin%7Baligned%7D%20g_%5Ctext%7BFOMAML%7D%20%26%3D%20g%5E%7B%281%29%7D_1%20%5C%5C%20g_%5Ctext%7BMAML%7D%20%26%3D%20g%5E%7B%281%29%7D_1%20-%20%5Calpha%20H%5E%7B%280%29%7D_0%20g%5E%7B%281%29%7D_1%20%5C%5C%20g_%5Ctext%7BReptile%7D%20%26%3D%20g%5E%7B%280%29%7D_0%20%2B%20g%5E%7B%281%29%7D_1%20%5Cend%7Baligned%7D

接下来,我们对g%5E%7B%281%29%7D_1使用泰勒展开式。可微函数f%28x%29a阶展开公式为:

f%28x%29%20%3D%20f%28a%29%20%2B%20%5Cfrac%7Bf%27%28a%29%7D%7B1%21%7D%28x-a%29%20%2B%20%5Cfrac%7Bf%27%27%28a%29%7D%7B2%21%7D%28x-a%29%5E2%20%2B%20%5Cdots%20%3D%20%5Csum_%7Bi%3D0%7D%5E%5Cinfty%20%5Cfrac%7Bf%5E%7B%28i%29%7D%28a%29%7D%7Bi%21%7D%28x-a%29%5Ei

我们将%5Cnabla_%7B%5Ctheta%7D%5Cmathcal%7BL%7D%5E%7B%281%29%7D%28.%29视为一个函数,将%5Ctheta_0视为一个自变量。 g_1%5E%7B%281%29%7D %5Ctheta_0 处的泰勒展开为:

%5Cbegin%7Baligned%7D%20g_1%5E%7B%281%29%7D%20%26%3D%20%5Cnabla_%7B%5Ctheta%7D%5Cmathcal%7BL%7D%5E%7B%281%29%7D%28%5Ctheta_1%29%20%5C%5C%20%26%3D%20%5Cnabla_%7B%5Ctheta%7D%5Cmathcal%7BL%7D%5E%7B%281%29%7D%28%5Ctheta_0%29%20%2B%20%5Cnabla%5E2_%5Ctheta%5Cmathcal%7BL%7D%5E%7B%281%29%7D%28%5Ctheta_0%29%28%5Ctheta_1%20-%20%5Ctheta_0%29%20%2B%20%5Cfrac%7B1%7D%7B2%7D%5Cnabla%5E3_%5Ctheta%5Cmathcal%7BL%7D%5E%7B%281%29%7D%28%5Ctheta_0%29%28%5Ctheta_1%20-%20%5Ctheta_0%29%5E2%20%2B%20%5Cdots%20%26%20%5C%5C%20%26%3D%20g_0%5E%7B%281%29%7D%20-%20%5Calpha%20H%5E%7B%281%29%7D_0%20g_0%5E%7B%280%29%7D%20%2B%20%5Cfrac%7B%5Calpha%5E2%7D%7B2%7D%5Cnabla%5E3_%5Ctheta%5Cmathcal%7BL%7D%5E%7B%281%29%7D%28%5Ctheta_0%29%20%28g_0%5E%7B%280%29%7D%29%5E2%20%2B%20%5Cdots%20%26%20%5Cscriptstyle%7B%5Ctext%7B%3B%20because%20%7D%5Ctheta_1-%5Ctheta_0%3D-%5Calpha%20g_0%5E%7B%280%29%7D%7D%20%5C%5C%20%26%3D%20g_0%5E%7B%281%29%7D%20-%20%5Calpha%20H%5E%7B%281%29%7D_0%20g_0%5E%7B%280%29%7D%20%2B%20O%28%5Calpha%5E2%29%20%5Cend%7Baligned%7D

把 MAML 的一步内循环梯度更新用g_1%5E%7B%281%29%7D的展开式重写:

%5Cbegin%7Baligned%7D%20g_%5Ctext%7BFOMAML%7D%20%26%3D%20g%5E%7B%281%29%7D_1%20%3D%20g_0%5E%7B%281%29%7D%20-%20%5Calpha%20H%5E%7B%281%29%7D_0%20g_0%5E%7B%280%29%7D%20%2B%20O%28%5Calpha%5E2%29%5C%5C%20g_%5Ctext%7BMAML%7D%20%26%3D%20g%5E%7B%281%29%7D_1%20-%20%5Calpha%20H%5E%7B%280%29%7D_0%20g%5E%7B%281%29%7D_1%20%5C%5C%20%26%3D%20g_0%5E%7B%281%29%7D%20-%20%5Calpha%20H%5E%7B%281%29%7D_0%20g_0%5E%7B%280%29%7D%20%2B%20O%28%5Calpha%5E2%29%20-%20%5Calpha%20H%5E%7B%280%29%7D_0%20%28g_0%5E%7B%281%29%7D%20-%20%5Calpha%20H%5E%7B%281%29%7D_0%20g_0%5E%7B%280%29%7D%20%2B%20O%28%5Calpha%5E2%29%29%5C%5C%20%26%3D%20g_0%5E%7B%281%29%7D%20-%20%5Calpha%20H%5E%7B%281%29%7D_0%20g_0%5E%7B%280%29%7D%20-%20%5Calpha%20H%5E%7B%280%29%7D_0%20g_0%5E%7B%281%29%7D%20%2B%20%5Calpha%5E2%20%5Calpha%20H%5E%7B%280%29%7D_0%20H%5E%7B%281%29%7D_0%20g_0%5E%7B%280%29%7D%20%2B%20O%28%5Calpha%5E2%29%5C%5C%20%26%3D%20g_0%5E%7B%281%29%7D%20-%20%5Calpha%20H%5E%7B%281%29%7D_0%20g_0%5E%7B%280%29%7D%20-%20%5Calpha%20H%5E%7B%280%29%7D_0%20g_0%5E%7B%281%29%7D%20%2B%20O%28%5Calpha%5E2%29%20%5Cend%7Baligned%7D

Reptile的梯度变为:

%5Cbegin%7Baligned%7D%20g_%5Ctext%7BReptile%7D%20%26%3D%20g%5E%7B%280%29%7D_0%20%2B%20g%5E%7B%281%29%7D_1%20%5C%5C%20%26%3D%20g%5E%7B%280%29%7D_0%20%2B%20g_0%5E%7B%281%29%7D%20-%20%5Calpha%20H%5E%7B%281%29%7D_0%20g_0%5E%7B%280%29%7D%20%2B%20O%28%5Calpha%5E2%29%20%5Cend%7Baligned%7D

现在,我们有了三种不同的梯度更新法则:

%5Cbegin%7Baligned%7D%20g_%5Ctext%7BFOMAML%7D%20%26%3D%20g_0%5E%7B%281%29%7D%20-%20%5Calpha%20H%5E%7B%281%29%7D_0%20g_0%5E%7B%280%29%7D%20%2B%20O%28%5Calpha%5E2%29%5C%5C%20g_%5Ctext%7BMAML%7D%20%26%3D%20g_0%5E%7B%281%29%7D%20-%20%5Calpha%20H%5E%7B%281%29%7D_0%20g_0%5E%7B%280%29%7D%20-%20%5Calpha%20H%5E%7B%280%29%7D_0%20g_0%5E%7B%281%29%7D%20%2B%20O%28%5Calpha%5E2%29%5C%5C%20g_%5Ctext%7BReptile%7D%20%26%3D%20g%5E%7B%280%29%7D_0%20%2B%20g_0%5E%7B%281%29%7D%20-%20%5Calpha%20H%5E%7B%281%29%7D_0%20g_0%5E%7B%280%29%7D%20%2B%20O%28%5Calpha%5E2%29%20%5Cend%7Baligned%7D

在训练过程中,我们经常会对多个数据batch做平均。在我们的例子中,小batch (0) 和 (1)是可交换的, 因为他们都是随机选取的。期望%5Cmathbb%7BE%7D_%7B%5Cmathcal%20T%2C0%2C1%7D表示在任务%5Cmathcal%20T上,两个标号为 (0) 和 (1) 的数据batch的平均值。

制作,

  • A%20%3D%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%20T%2C0%2C1%7D%20%5Bg_0%5E%7B%280%29%7D%5D%20%3D%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%20T%2C0%2C1%7D%20%5Bg_0%5E%7B%281%29%7D%5D; 代表着任务loss的平均梯度。沿着A代表的方向更新模型,能够使模型在当前任务上的表现更好。
  • B%20%3D%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%20T%2C0%2C1%7D%20%5BH%5E%7B%281%29%7D_0%20g_0%5E%7B%280%29%7D%5D%20%3D%20%5Cfrac%7B1%7D%7B2%7D%5Cmathbb%7BE%7D_%7B%5Cmathcal%20T%2C0%2C1%7D%20%5BH%5E%7B%281%29%7D_0%20g_0%5E%7B%280%29%7D%20%2B%20H%5E%7B%280%29%7D_0%20g_0%5E%7B%281%29%7D%5D%20%3D%20%5Cfrac%7B1%7D%7B2%7D%5Cmathbb%7BE%7D_%7B%5Cmathcal%20T%2C0%2C1%7D%20%5B%5Cnabla_%5Ctheta%28g%5E%7B%280%29%7D_0%20g_0%5E%7B%281%29%7D%29%5D; 代表着增加同个任务上两个小batch梯度的内积的方向。沿着B代表的方向更新模型,能够使得模型在当前任务上具有更好的泛化性。

综上所述, MAML 和 Reptile 的优化目标相同,都是更好的任务表现(由 A 主导)和更好的泛化能力(由 B 主导)。当梯度更新由泰勒展开的前三项近似时:

%5Cbegin%7Baligned%7D%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%20T%2C1%2C2%7D%5Bg_%5Ctext%7BFOMAML%7D%5D%20%26%3D%20A%20-%20%5Calpha%20B%20%2B%20O%28%5Calpha%5E2%29%5C%5C%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%20T%2C1%2C2%7D%5Bg_%5Ctext%7BMAML%7D%5D%20%26%3D%20A%20-%202%5Calpha%20B%20%2B%20O%28%5Calpha%5E2%29%5C%5C%20%5Cmathbb%7BE%7D_%7B%5Cmathcal%20T%2C1%2C2%7D%5Bg_%5Ctext%7BReptile%7D%5D%20%26%3D%202A%20-%20%5Calpha%20B%20%2B%20O%28%5Calpha%5E2%29%20%5Cend%7Baligned%7D

虽然我们没法确定忽略的O%28%5Calpha%5E2%29项是否对于参数更新有着重要作用。但根据 FOMAML 与 MAML 的表现相近来看,说高阶导数对于梯度更新不太重要还是比较靠谱的。

15.4.3 LSTM Meta-Learner

Ravi & Larochelle (2017)(https://openreview.net/pdf?id=rJY0-Kcll)把优化算法显式的建模出来,并命名为“元学习器”,原本处理任务的模型被称为“学习器”。元学习器的目标是使用少量支持集在仅仅几步之内快速更新学习器的参数,使得学习器能够快速适应新任务。

我们用M_%5Ctheta代表参数为%5Ctheta的学习器,用R_%5CTheta代表参数为%5CTheta的元学习器,loss函数为%5Cmathcal%7BL%7D.

%E2%9D%91为什么使用 LSTM?

之所以使用 LSTM 作为元学习器的模型,有这样几点原因:

  1. 反向传播中基于梯度的更新跟 LSTM 中 cell 状态的更新有相似之处。
  2. 知道之前的梯度对当前的梯度更新有好处。可以参考 momentum(http://ruder.io/optimizing-gradient-descent/index.html#momentum) 的原理。

第t步时,设定学习率为%5Calpha_t,更新学习器的参数:

%5Ctheta_t%20%3D%20%5Ctheta_%7Bt-1%7D%20-%20%5Calpha_t%20%5Cnabla_%7B%5Ctheta_%7Bt-1%7D%7D%5Cmathcal%7BL%7D_t
这个过程跟 LSTM 的 cell 状态更新具有相同的形式。如果我们令遗忘门f_t%3D1,输入门i_t%20%3D%20%5Calpha_t, cell 状态c_t%20%3D%20%5Ctheta_t, 新 cell 状态%5Ctilde%7Bc%7D_t%20%3D%20-%5Cnabla_%7B%5Ctheta_%7Bt-1%7D%7D%5Cmathcal%7BL%7D_t,则:

%5Cbegin%7Baligned%7D%20c_t%20%26%3D%20f_t%20%5Codot%20c_%7Bt-1%7D%20%2B%20i_t%20%5Codot%20%5Ctilde%7Bc%7D_t%5C%5C%20%26%3D%20%5Ctheta_%7Bt-1%7D%20-%20%5Calpha_t%5Cnabla_%7B%5Ctheta_%7Bt-1%7D%7D%5Cmathcal%7BL%7D_t%20%5Cend%7Baligned%7D

但是固定f_t%3D1i_t%3D%5Calpha_t可能不是最好的,我们可以让它们随着数据集的变化而变化,通过学习得到它们。

%5Cbegin%7Baligned%7D%20f_t%20%26%3D%20%5Csigma%28%5Cmathbf%7BW%7D_f%20%5Ccdot%20%5B%5Cnabla_%7B%5Ctheta_%7Bt-1%7D%7D%5Cmathcal%7BL%7D_t%2C%20%5Cmathcal%7BL%7D_t%2C%20%5Ctheta_%7Bt-1%7D%2C%20f_%7Bt-1%7D%5D%20%2B%20%5Cmathbf%7Bb%7D_f%29%20%26%20%5Cscriptstyle%7B%5Ctext%7B%3B%20how%20much%20to%20forget%20the%20old%20value%20of%20parameters.%7D%7D%5C%5C%20i_t%20%26%3D%20%5Csigma%28%5Cmathbf%7BW%7D_i%20%5Ccdot%20%5B%5Cnabla_%7B%5Ctheta_%7Bt-1%7D%7D%5Cmathcal%7BL%7D_t%2C%20%5Cmathcal%7BL%7D_t%2C%20%5Ctheta_%7Bt-1%7D%2C%20i_%7Bt-1%7D%5D%20%2B%20%5Cmathbf%7Bb%7D_i%29%20%26%20%5Cscriptstyle%7B%5Ctext%7B%3B%20corresponding%20to%20the%20learning%20rate%20at%20time%20step%20t.%7D%7D%5C%5C%20%5Ctilde%7B%5Ctheta%7D_t%20%26%3D%20-%5Cnabla_%7B%5Ctheta_%7Bt-1%7D%7D%5Cmathcal%7BL%7D_t%20%26%5C%5C%20%5Ctheta_t%20%26%3D%20f_t%20%5Codot%20%5Ctheta_%7Bt-1%7D%20%2B%20i_t%20%5Codot%20%5Ctilde%7B%5Ctheta%7D_t%20%26%5C%5C%20%5Cend%7Baligned%7D

%E2%9D%91Model Setup

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
图 15.17 如何训练学习器 $M_\theta$ 和元学习器 $R_\Theta$。(图像来源:[原论文](https://openreview.net/pdf?id=rJY0-Kcll)

在 Matching Networks(https://wei-tianhao.github.io/blog/2019/09/17/meta-learning.html#matching-networks)中我们已经证明了用模仿测试过程的方式训练能够取得很好的效果,这里也用了类似的方法。在每个训练阶段,我们先采样一个数据集%5Cmathcal%7BD%7D%20%3D%20%28%5Cmathcal%7BD%7D_%5Ctext%7Btrain%7D%2C%20%5Cmathcal%7BD%7D_%5Ctext%7Btest%7D%29%20%5Cin%20%5Chat%7B%5Cmathcal%7BD%7D%7D_%5Ctext%7Bmeta-train%7D,再从%5Cmathcal%7BD%7D_%5Ctext%7Btrain%7D中采样T轮 mini-batches 用于更新%5Ctheta。学习器参数的最终状态%5Ctheta_T被用来在测试数据%5Cmathcal%7BD%7D_%5Ctext%7Btest%7D上训练元学习器。

有两个实现细节需要注意:

  1. 如何压缩 LSTM 元学习的参数空间?元学习器是在建模一个神经网络的参数,所以有上百万个变量要学。为了减小元学习器的参数空间,这篇文章借鉴了共享参数(https://arxiv.org/abs/1606.04474)的方法。元学习器本质上学习的是一种更新原则,即如何根据一个参数的值和其梯度生成这个参数的新值(比如一阶方法,牛顿法等),与参数在学习器中的位置无关。所以我们可以假设所有参数的更新原则都是一样的,即元学习只需要输出一维变量即可。
  2. 为了简化训练过程,元学习器假设损失函数%5Cmathcal%7BL%7D_t和梯度%5Cnabla_%7B%5Ctheta_%7Bt-1%7D%7D%20%5Cmathcal%7BL%7D_t是独立的。

train-meta-learner

剩余章节详见一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (下)万字中文综述(待更):

15.5 基于度量的方法
15.5.1 Convolutional Siamese Neural Network
15.5.2 Matching Networks
15.5.2.1 Simple Embedding
15.5.2.2 Full Context Embeddings
15.5.3 Relation Network
15.5.4 Prototypical Networks
15.6 基于模型的方法
15.6.1 Memory-Augmented Neural Networks
15.6.1.1 MANN for Meta-Learning
15.6.1.2 Addressing Mechanism for Meta-Learning
15.6.2 Meta Networks
15.6.2.1 Fast Weights
15.6.2.2 Model Components
15.6.2.3 训练过程
15.7 元学习应用
15.7.1 计算机视觉和图形
15.7.2 元强化学习和机器人技术
15.7.3 环境学习与模拟现实
15.7.4 神经架构搜索(NAS)
15.7.5 贝叶斯元学习
15.7.6 无监督元学习和元学习无监督学习
15.7.7 主动学习
15.7.8 持续、在线和适应性学习
15.7.9 领域适应和领域概括
15.7.10 超参数优化
15.7.11 新颖且生物学上可信的学习者
15.7.12 语言和言语
15.7.13 元学习促进社会福利
15.7.14 抽象和合成推理
15.7.15 系统
15.8 未来展望
15.9 参考资料

15.9 参考资料

[1] Brenden M。Lake, Ruslan Salakhutdinov, and Joshua B。Tenenbaum。Human-level concept learning through probabilistic program induction.Science 350.6266 (2015): 1332-1338.
https://www.cs.cmu.edu/~rsalakhu/papers/LakeEtAl2015Science.pdf

[2] Oriol Vinyals’ talk onModel vs Optimization Meta Learning
http://metalearning-symposium.ml/files/vinyals.pdf

[3] Gregory Koch, Richard Zemel, and Ruslan Salakhutdinov. *Siamese neural networks for one-shot image recognition. ICML Deep Learning Workshop. 2015.
http://www.cs.toronto.edu/~rsalakhu/papers/oneshot1.pdf

[4] Oriol Vinyals, et al.Matching networks for one shot learning.NIPS。2016.
http://papers.nips.cc/paper/6385-matching-networks-for-one-shot-learning.pdf

[5] Flood Sung, et al.Learning to compare: Relation network for few-shot learning.CVPR。2018.
http://openaccess.thecvf.com/content_cvpr_2018/papers_backup/Sung_Learning_to_Compare_CVPR_2018_paper.pdf

[6] Jake Snell, Kevin Swersky, and Richard Zemel. *Prototypical Networks for Few-shot Learning. CVPR. 2018.
http://papers.nips.cc/paper/6996-prototypical-networks-for-few-shot-learning.pdf

[7] Adam Santoro, et al.Meta-learning with memory-augmented neural networks.ICML。2016.
http://proceedings.mlr.press/v48/santoro16.pdf

[8] Alex Graves, Greg Wayne, and Ivo Danihelka.Neural turing machines.arXiv preprint arXiv:1410.5401 (2014).
https://arxiv.org/abs/1410.5401

[9] Tsendsuren Munkhdalai and Hong Yu.Meta Networks.ICML。2017.
https://arxiv.org/abs/1703.00837

[10] Sachin Ravi and Hugo Larochelle.Optimization as a Model for Few-Shot Learning.ICLR。2017.
https://openreview.net/pdf?id=rJY0-Kcll

[11] Chelsea Finn’s BAIR blog onLearning to Learn.
https://bair.berkeley.edu/blog/2017/07/18/learning-to-learn/

[12] Chelsea Finn, Pieter Abbeel, and Sergey Levine.Model-agnostic meta-learning for fast adaptation of deep networks.ICML 2017.
https://arxiv.org/abs/1703.03400

[13] Alex Nichol, Joshua Achiam, John Schulman.On First-Order Meta-Learning Algorithms.arXiv preprint arXiv:1803.02999 (2018).
https://arxiv.org/abs/1803.02999

[14]Slides on Reptileby Yoonho Lee.
https://www.slideshare.net/YoonhoLee4/on-firstorder-metalearning-algorithms

[15] Schmidhuber J .Evolutionary principles in self-referential learning[J]. genetic programming, 1987.

[16] Tom Schaul and Jürgen Schmidhuber,Metalearning,Scholarpedia,6,4650,2010,https://doi.org/10.4249/scholarpedia.4650

[17] Chelsea Finn and Pieter Abbeel and Sergey Levine,Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks,arXiv preprint arXiv:1703.03400,(2017)

[18] Weng, Lilian,Meta-Learning: Learning to Learn Fast,lilianweng.github.io/lil-log, weng2018metalearning, 2018

http://lilianweng.github.io/lil-log/2018/11/29/meta-learning.html

[19] Timothy Hospedales and Antreas Antoniou and Paul Micaelli and Amos Storkey,Meta-Learning in Neural Networks: A Survey,hospedales2020metalearning,arXiv,cs.LG,2004.05439,2020.
https://arxiv.org/abs/2004.05439

[20] 元学习该何去何从 https://zhuanlan.zhihu.com/p/26636240

[21]MAML-TensorFlowhttps://github.com/dragen1860/MAML-TensorFlow

[22] 元学习(Meta Learning)最全论文、视频、书籍资源整理 https://zhuanlan.zhihu.com/p/70044607

[23] 最前沿:百家争鸣的 Meta Learning/Learning to learn https://zhuanlan.zhihu.com/p/28639662

[24] *Paper repro: Deep Metalearning using “MAML” and “Reptile”*https://towardsdatascience.com/paper-repro-deep-metalearning-using-maml-and-reptile-fd1df1cc81b0

[25] 一文入门元学习(Meta-Learning)(附代码)https://zhuanlan.zhihu.com/p/136975128

[26] 荐读Meta-Learning in Neural Networks: A surveyhttps://zhuanlan.zhihu.com/p/133159617

[27]Agent57: Outperforming the human Atari benchmarkhttps://deepmind.com/blog/article/Agent57-Outperforming-the-human-Atari-benchmark

[28] 《2021春机器学习课程》李宏毅 https://speech.ee.ntu.edu.tw/~hylee/ml/2021-spring.html

一文弄懂元学习 (Meta Learing)(附代码实战)《繁凡的深度学习笔记》第 15 章 元学习详解 (上)万字中文综述
谢谢!

转载请注明出处:https://fanfansann.blog.csdn.net/
版权声明:本文为 CSDN 博主 「繁凡さん」(博客) ,[知乎答主 「繁凡」(专栏) , Github 「fanfansann」(全部源码) , 微信公众号 「繁凡的小岛来信」(文章 P D F 版) )的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
xiaoxingxing的头像xiaoxingxing管理团队
上一篇 2022年3月22日 上午11:42
下一篇 2022年3月22日 下午12:36

相关推荐