Paper Reading: Adaptive Neural Trees

发布时间 2023-05-27 21:30:58作者: 乌漆WhiteMoon


Paper Reading 是从个人角度进行的一些总结分享,受到个人关注点的侧重和实力所限,可能有理解不到位的地方。具体的细节还需要以原文的内容为准,博客中的图表若未另外说明则均来自原文。

论文概况 详细
标题 《Adaptive Neural Trees》
作者 Ryutaro Tanno, Kai Arulkumaran, Daniel C. Alexander, Antonio Criminisi, Aditya Nori
发表会议 International Conference on Machine Learning (ICML)
发表年份 2019
会议等级 CCF-A
论文代码 https://github.com/rtanno21609/AdaptiveNeuralTrees

作者单位:

  1. University College London, UK
  2. Imperial College London, UK
  3. Microsoft Research, Cambridge, UK

研究动机

神经网络和决策树都是强大且实用的机器学习模型,但是两种方法通常具有相互排斥的优点和局限性。神经网络是通过非线性变换的组合来学习数据的分层表示,该方法对特征工程的需求很小。同时神经网络是用随机优化器训练的,允许训练扩展到大型数据集。但是神经网络的架构通常需要手工设计,并根据任务或数据集进行固定,且有些任务下需要巨大的计算开销。决策树的特点是学习如何分割输入空间,令每个子集中都能用线性模型来解决问题。决策树的架构是基于训练数据进行优化的,在数据稀缺的情况下特别有优势。但是应用决策树时通常需要手工设计的数据特征,且损失函数是不可微的,所以限制了基于梯度下降的优化和复杂分割函数的使用。

文章贡献

本文设计了自适应神经树(ANT)将 NN 和 DT 的优点结合起来,ANT 将树结构中的路由决策和根到叶的计算路径表示为 NN,从而实现了分层表示学习。ANT 以树形拓扑作为一个强结构先验,通过该结构令特征以分层方式共享和分离。同时提出了一种基于反向传播的训练算法,基于一系列决策来生长 ANT 的结构。总而言之,ANT同时具备了表示学习、架构学习、轻量级推理的能力。通过SARCOS、MNIST 和 CIFAR-10 数据集的实验,证明了本文方法具有较好的性能,具有多种良好的特性。

自适应神经树

模型拓扑与操作

ANT(Adaptive Neural Trees)是一种树状结构模型,是基于三个可微操作的基本模块构成的,分别是 Routers、Transformers 和 Solvers,在图中分别用白色圆圈、黑色圆圈和灰色圆圈标出。

  1. Routers:路由器,以样本作为输入,并确定将样本发送到左分支或右分支。例如可以将定义为一个小的 CNN,对该 CNN 的输出求平均后从伯努利分布中采样来决策进入哪个分支,左分支为 1,右分支为 0。
  2. Transformers:变压器,ANT 的每条边都由一个或多个 Transformers 组成。每个 Transformers 都是一个非线性函数,用于将样本进行非线性变换后继续向下传递。例如它可以是一个卷积层 + ReLU,并且可以在一条边上堆叠多个 Transformers 来实现对特征的深度转换。
  3. Solvers:求解器,它对转换后的输入数据进行决策,对于分类任务可以将其定义为线性分类器。


当样本输入时,就会经过 Transformers 进行转换,通过 Routers 进行分支决策,最后利用到达的 Solvers 进行求解。针对不同的问题,Routers、Transformers 和 Solvers 的设置如下表所示:

表中的符号及其含义如下所示:

符号 含义
conv5-40 空间大小为 5×5 的 40 核二维卷积
GAP global-average-pooling
FC fully connected layer
LC linear classifier
LR linear regressor
Downsample Freq 2×2 max-pooling

概率模型与推理

ANT 由树的根结点到叶节点构成的一个层级混合专家系统(hierarchical mixture of experts, HMEs)产生输出,每个 HMEs 都被定义为一个神经网络。ANT 的每个输入 x 根据 Routers 的决策来遍历树,并经历一系列 Transformers 的变换,直到到达一个叶节点,用对应的 Solvers 预测标签 y。假设树中有 L 个叶节点,参数为 Θ = (Θ, ψ, φ),则模型对样本输出的条件概率如下:

公式中的符号及其含义如下所示:

符号 含义
Θ Routers 的参数
ψ Transformers 的参数
φ Solvers 的参数
z z∈{0,1}^L,是 L 维的二值变量,为左右分支的选择

混合系数 π 量化了 x 被分配到叶节点 l 的概率,由从根节点到叶节点 l 的唯一路径 Pl上 所有 Routers 的决策概率的乘积给出,公式如下所示。式中的 l→j 为一个二值关系,且仅当叶子 l 位于节点 j 的左子树时为 true,xψj 为 x 在节点 j 处的特征表示。

令 Tj 表示从根节点到节点 j 路径上 n 个变压器模块的有序集,特征向量 xψj 可由如下公式求得:

最后的 Solvers 的输出由 sφl(xψparent(l)) 给出,算法支持多路径推理和单路径推理。多路径推理将给出的完整预测分布,但是计算时需要对所有叶子的分布进行平均,如果 ANT 模型较大则会带来巨大的计算开销。单路径推理仅使用沿路由器置信度最高的方向,从该方向贪婪遍历树所选择的叶节点上的预测分布,从而降低了计算开销。

优化

ANT 的训练分为两个阶段,在生长阶段 ANT 基于局部优化学习模型架构,在细化阶段 ANT 基于全局优化进一步调整模型参数。

损失函数方面,使用负对数似然(NLL)作为两个阶段共同目标函数来最小化 Loss,公式如下所示。由于所有模块的参数都是可微的,因此使用反向传播进行梯度计算,并使用梯度下降最小化 NLL 以学习参数。

在生长阶段,ANT 从根节点开始以宽度优先的顺序选择一个叶节点,向其添加计算模块来增量修改体系结构。ANT 在每个叶节点上评估以下 3 种策略:

生长策略 说明
split data 增加一个新的 Routers 拆分节点
deepen transform 增加一个新的 Transformers 来增加来边深度
keep 保留当前的结构

接着通过梯度下降最小化 NLL 来局部优化新增模块的参数,同时固定前一部分的参数。最后选择具有最低验证 NLL 的模型,如果它改善了之前的最低 NLL 则保留,这个过程逐级重复直到收敛。

评估生长策略的基本原理是让模型在深入或分割数据空间之间选择最有效的选项,拆分节点相当于对传入数据的特征空间进行软划分,并产生两个新的叶节点。加深边缘则是通过额外的非线性变换来学习更丰富的表示,并用新的求解器替换旧的求解器。在生长阶段确定了模型拓扑之后,通过执行全局优化来细化模型的参数,此时将针对 ANT 的所有模块的参数在 NLL 上执行梯度下降。

实验结果

本文使用 SARCOS 多元回归数据集、MNIST 和 CIFAR-10 数据集进行评估,同时进行消融实验,所有的模型都是在 PyTorch 中实现。

模型性能

将 ANT 的性能与一系列 DT 和 NN 模型进行比较,ANT 在 SARCOS 上实现了最低的误差,并且在 MNIST 和 CIFAR-10 上表现良好。在 SARCOS 数据集中,全路径的 ANT-SARCOS 的 MSE 优于所有其他方法,单路径时 GBT 的性能略好于单个 ANT 同时需要更少的参数。

在 MNIST 数据集上 ANT-MNISTA 在精度上优于最先进的 GBT 和 RF 方法,与 LeNet-5 相比 ANT-MNIST-A 和 ANT-MNIST-B 在参数数量较少的情况下获得了更好的精度。CapsNets 比 ANT-MNIST-A 具有更多的参数,ANT 则可以在参数更少的情况下达到类似的性能。ANT-MNIST-C 使用了最简单的模块,单路径推理能在使用大致相同的参数量的情况下明显优于线性分类器。

CIFAR-10 数据集上 ANT 优于 gcForest,在单路径推理时 ANT-CIFAR-A 比没有捷径连接的 CNN 模型具有更高的准确率。

消融实验

本文比较了在禁用 Transformers 或 Routers 的情况下,ANT 的不同结构的预测误差如下表所示。禁用 Transformers 时模型相当于 HMEs,禁用 Routers 相当于使用标准 CNN。在所有三个数据集上,任何一种消融都会导致不同模块配置的更高误差,证明 ANT 中特征学习和分层划分的结合是合理的。

可解释性

ANT 算法的生长过程能够发现有用的层次结构,在没有对 Routers 施加任何正则化的情况下,学习到的层次结构通常会在 MNIST 和 CIFAR-10 数据集中显示某些类别的强专门化路径。

细化阶段的影响

全局细化阶段改善了泛化误差,下图显示了 CIFAR-10 上各种 ANT 的泛化误差,垂直虚线表示模型进入细化阶段的代数。几种设置都获得更高的测试精度,并且全局优化使 Routers 的决策概率两极分化,导致一定的剪枝效果。

自适应模型复杂度

在 CIFAR-10 上设置大小为 50、250、500、2.5k、5k、25k 和完整训练集的数据集子集上训练 ANT、All-CNN 和线性分类器的三个变体,选择 All-CNN 作为基线,利用 5k 个样本的验证集的性能来选择最优模型。下图展示了实验的分类性能,随着数据集越来越小,ANT 和 All-CNN/线性分类器的测试精度之间的差距增加。

下图显示了 ANT 的模型大小随数据集大小的变化情况,参数的数量通常因数据集大小的增加而增加。All-CNN 的参数是固定的,始终比构造的 ANT 要大,并且存在过拟合的问题,线性分类器对则体现为欠拟合。ANT 能得到足够复杂的模型,具有更好的泛化能力。

优点和创新点

个人认为,本文有如下一些优点和创新点可供参考学习:

  1. 本文用树形结构来自适应地构建模型,决策路径构成了一个 NN,思路非常具有创新性;
  2. ANT 的结构包括子空间的划分、特征构造和决策模型的部分,模型结构设计有参考价值;
  3. 本文的方法在不同的问题上可通过不同的方式实现,模型的迁移性强。