Paper Reading: NBDT: Neural-Backed Decision Trees

发布时间 2023-08-10 22:58:45作者: 乌漆WhiteMoon


Paper Reading 是从个人角度进行的一些总结分享,受到个人关注点的侧重和实力所限,可能有理解不到位的地方。具体的细节还需要以原文的内容为准,博客中的图表若未另外说明则均来自原文。

论文概况 详细
标题 《NBDT: Neural-Backed Decision Trees》
作者 Alvin Wan, Lisa Dunlap, Daniel Ho, Jihan Yin, Scott Lee, Henry Jin, Suzanne Petryk, Sarah Adel Bargal, Joseph E Gonzalez
发表会议 International Conference on Learning Representations(ICLR)
论文代码 https://github.com/alvinwan/neural-backed-decision-trees

研究动机

许多计算机视觉应用需要了解模型的决策过程,但是深度学习模型是黑盒的,要实现这一点将会复杂。最近的一些可解释计算机视觉方面的研究试图解决这一需求,这些方法可以分为显著性图和顺序决策过程两类。显著性图通过识别哪些像素对预测影响最大来追溯解释模型预测,但是该方法只关注输入,无法获取模型的决策过程。也可以通过将预测分解为一系列更小的语义上有意义的决策来深入了解模型的决策过程,就像在决策树等基于规则的模型中一样。然而现有的融合深度学习和决策树的工作存在预测精度显著低于一些常用的模型,同时对精度进行优化的过程也导致了可解释性降低的问题。

文章贡献

为了提高计算机视觉模型的可解释性,本文融合深度学习和决策树提出了神经支持决策树(NBDTs)。NBDT 使用一个可微的倾斜决策树取代了神经网络的最后一个线性层,和经典的决策树方法不同,NBDT 使用从模型参数派生的层次结构,不使用分层 softmax。NBDT 可以从任何现有的分类神经网络中创建,无需对模型架构进行修改。这样的模型结构不会过度拟合特征空间,减少了决策树对高度不确定决策的依赖,并鼓励对高级概念的准确识别。通过实验证明 NBDT 在 ImageNet、TinyImageNet200 和 CIFAR100 上的性能等同于或优于一些现有的模型,模型提供的解释可以让用户更方便地识别模型的错误,并且可用于识别模糊的 ImageNet 标签。

本文方法

神经支持决策树(NBDTs)用决策树代替网络最后的线性层,NBDT 使用路径概率进行推理以容忍高度不确定的中间决策,从预训练的模型权重构建层次结构以减少过拟合,并使用分层损失进行训练以更好地学习高级决策。

推理

NBDT 使用神经网络对每个样本进行表示,将最终的全连接层修改为使用倾斜决策树。但是这样的设置存在两个问题:

  1. 经典决策树不能修正前面神经网络的早期错误;
  2. 直接将神经网络提取的特征输入决策树会显著降低准确率;

针对这个问题本文使用软决策规则进行替代,经典的倾斜决策树使用的是硬决策,每个节点选择具有最大内积的子节点进行访问,直到叶子节点输出决策结果。软决策中每个节点只是返回每个子节点的归一化内积的概率,通过计算其到根节点路径的概率选择概率最大的叶子结点输出。
软决策和硬决策的效果如下图所示,假设 w4 是正确的输出。如果使用硬决策将访问红色节点进而产生错误的输出,这个错误是不可恢复的。然而如果使用软决策,w3 处的高度确定决策将修正在根和 w2 处的高度不确定决策。说明在前置节点出现错误时,模型可以通过软决策正确地选择 w4。

首先用神经网络的权重生成倾斜决策树的决策规则,将权重向量 ni 与每个节点关联起来。对于叶节点,每个 ni=wk(i=k∈[1,k])是一个来自全连通层的权值 W∈RD×K 的行向量。对于所有的非叶节点,在节点 i 的子树中找到所有的叶子 K∈L(i)(i∈[K+1,N]),并取其权值的平均值 ni=∑K∈L(i)wk/|L(i)|。接着使用 softmax 内积计算节点概率,对于每个样本 x 和节点 i,使用 p(j|i)=SOFTMAX(<ni->,x>)[j] 计算每个子节点 j∈C(i) 的概率,其中 ni->=(<nj,x>)j∈C(i)。最后使用路径概率选择叶子节点,设某个类别 k 和该类别从根节点到叶节点的路径 Pk,i∈Pk 为路径中的一个节点,遍历路径 Ck(i)∈Pk ∩ C(i) 时下一个节点的概率为 p(Ck(i)|i)。则该叶子是的类别 k 的概率如下公式所示:

软决策就是在这些类概率上定义了最终类预测 k^,计算公式如下:

建立层次结构

现有的基于决策树的方法通常根据数据本身启发地构建的层次结构,这样对数据会过拟合。也可以类似 WordNet 使用现有的层次结构,这样关注的是概念上的相似性而不是视觉上的相似性。例如 WordNet 认为鸟是动物所以更接近猫,而不是和飞机更为接近,然而在视觉上鸟更相似于飞机。
为了防止过拟合并反映视觉相似性,本文使用模型权重构建层次结构。首先取全连通层权值 W 中的行向量 wk(k∈[1,k]) 代表一个类,然后对规范化的类向量 wk/||wk||2 进行分层聚集聚类,决定哪些节点和节点组是迭代配对的。3

例如下图 A 中的红色 w1 被分配给 B 中的红色叶子,取 B 中的 w1 和 w2(红色和紫色)的平均值,得到 C 中的 w5 (蓝色),根结点的权值就是所有叶结点 w1、w2、w3、w4 权重的平均值。

用 WordNet 标记决策节点

WordNet 是一个大型英文词汇数据库,以层次结构的形式组织。使用 WordNet 给节点分配含义,方法是计算子树中所有叶子的最早的共同祖先。例如假设 Dog 和 Cat 是共享父节点的两个叶子,在 WordNet 中查找它们共享的所有祖先概念,如哺乳动物、动物和生物。其中最接近的共同祖先是哺乳动物,所以把哺乳动物归为 Dog 和 Cat 的祖先。但 WordNet 语料库缺乏非对象概念的信息(如对象属性)和抽象的视觉概念(如上下文),这是一个待解决的问题。

微调和树监督损失

标准的交叉熵损失定义了类别的损失,但是无法定义非叶节点的损失。本文添加了树监督损失,即路径概率类分布上的交叉熵损失 Dnbdt={p(k)}Kk=1随时间变化的权重 ωt 和 βt,使用如下的公式表示:

当叶权值无意义时,树监督损失会损害训练早期的学习速度。此时需要将树监督权 ωt 线性增长,βt∈[0,1] 随时间线性衰减。只有在原始模型精度不可复制时才使用 Lsoft 进行微调,Lsoft 不成比例地提高了分层早期决策的权重,鼓励准确的高层决策。

实验结果

对比实验结果

对比模型选择了 ResNet、WideResNet、EfficientNet,下表展示了小规模数据集 CIFAR10、CIFAR100、TinyImageNet 的实验结果。实验结果表明本文的 NBDT 的性能优于或等同于对比的算法。在 CIFAR10 和 TinyImageNet 上 NBDT 精度低于基线神经网络 0.15%,在 CIFAR100 上比基线高出约 1%。

在大规模数据集 ImageNet 上的实验结果如下表所示,结果可见 NBDT 准确率最高,比最好的对比算法 NofE 高出 15%。

NBDT 的改进主要是通过显著提高区分高级概念的能力,下表展示了诱导层次结构和、WordNet 层次结构、基于神经特征构建的经典决策树、基于神经特征构建的倾斜决策树的数据,可见诱导层次结构优于其他层次结构。

先前的研究认为使用分层 softmax 对于分层分类器是必要的,但是从下表汇总的实验结果可见,使用分层 softmax 训练的 NBDT 比在 TinyImageNet 上使用树监督损失训练的准确率低约 3%。

经过树监督损失的训练后,模型还可以正常运行原始神经网络的全连接层,实验结果如下表所示,可见原始神经网络在 CIFAR100、TinyImageNet 上的准确率提高了 2%。

零次学习指的是对于要分类的对象一次也不学习,训练集和测试集之间没有交集,而是通过类别之间的描述建立彼此之间的联系。此处将“超类”定义为几个类的上位词,例如 Animal 是 Cat 和 Dog 的超类)。通过实验确定每个 NBDT 内部节点在哪些超类之间做出决定,实验结果如下表所示,课件 NBDT 始终比原始神经网络高出 8% 以上,在区分食肉动物和有蹄类动物时 NBDT 比原始神经网络高出 16%。

还测试了 NBDT 不使用预训练的权重,而是在训练时从部分训练的网络的权重构建层次结构。实验结果如下表所示,可见具有中间训练层次的树监督损失可靠地提高了原始神经网络的精度,最高可达 0.6%。但是这样的性能不如使用完整的结构,表明完全训练的权值仍然是构建层次的首选。

可解释性

根据一些现有的研究对可解释性定义:如果人可以验证模型的预测,并能确定模型何时犯了错误,那么模型就是可解释的。对于 NBDT 模型的可解释性的验证使用 cifar10 数据集训练的模型,使用 ResNet18 作为骨干神经网络。

识别错误的模型预测

这方面的验证以问卷调查的形式展开,每个用户得到 3 张图片,其中 2 张是正确分类的,1 张是错误分类的,用户需要在给定模型解释和没有最终预测的情况下预测哪个图像被错误分类。对于显著图方法来说,这项任务是不可完成的,因为无论分类错误或正确,显著图都只会突出显示图像中的主要对象,而本文的 NBDT 通过分层方法提供了一系列可检查的中间决策。
作者收集了 600 份调查回复,当给出显著性图和类别概率时只有87个预测被正确识别为错误,而当给出 NBDT 预测类别和概率时 237 张图像被正确识别为错误,受访者在 NBDT 解释中识别错误的能力提高了近 3 倍。

引导图像分类

第一次调查中每个用户被要求对一张严重模糊的图像进行分类,该用于判断问题的难度,调查结果是 600 个回答中有 163 个是正确的(准确率为 27.2%)。接下来的调查提供了模糊的图像和两组预测,分别是原始神经网络的预测类及其显著性图,以及 NBDT 预测类和决策序列。其中在 30% 的例子中 NBDT 是正确的而原始模型是错误的,另外 30% 的的情况正好相反,剩下的 40% 中两种预测都是错误的。
这些图片如下图所示,图像非常模糊,因此用户必须依靠模型来进行预测。在本次调查中 600 个回答中有 255 个是正确的(准确率为 42.5%),比没有模型指导提高了 15.3点。在调查中 NBDT 的预测更可信,在 600 份回复中有 312 份回复同意 NBDT 的预测,167 份回复同意基本模型的预测,119 份回复不同意两个模型的预测。大多数用户决策(约 80%)同意模型预测中的任何一个,表明图像已经足够模糊,以至于不得不依赖模型。此外尽管只有 30% 的 NBDT 预测是正确的,但 52% 的回答同意NBDT。

人更倾向的解释

NBDT 预测的解释是遍历决策路径的可视化,然后将这些 NBDT 解释与人类研究其他可解释性方法进行比较。这次调查要求参与者在显著性图和 NBDT 选择一种更可信的解释,此处只使用 ResNet18 和 NBDT 预测一致的样本。在 374 名受访者中,正确分类的样本中有 65.9% 的人更喜欢 NBDT 的解释,错误分类的样本有 73.5% 的人更喜欢 NBDT。

识别有缺陷的数据标签

下图展示了几种类型的模糊标签,任何一种都可能损害图像分类数据集的模型性能,可以通过在 NBDT 预测路径上找到具有高“路径熵”或高度不同熵的样本来识别模糊标签。

下图所示 ImageNet 中最高的“路径熵”样本包含多个对象,每个对象都可以用于图像分类,而在基线神经网络中诱导最高熵的样本并不暗示模糊标签。这表明与标准神经网络相比,NBDT 熵的信息量更大。

优点和创新点

个人认为,本文有如下一些优点和创新点可供参考学习:

  1. 本文的模型使用神经网络作为骨干模型,使用决策树对网络生成低维、稠密的表示进行决策,其他的一些问题也可以参考这样的设计思路;
  2. 和常用的硬决策不同,本文的决策树利用决策路径实现具有修正作用的软决策,同时也提高了可解释性;
  3. 将语料数据库和决策层次结构结合,为这些层次赋予了实际意义,具有很强的可解释性;
  4. 实验部分和可解释调查部分内容非常丰富,利用了大量图表从多方面展示了 NBDT 的优势,非常清晰、具体。