Paper Reading: Interpretable Rule Discovery Through Bilevel Optimization of Split-Rules of Nonlinear Decision Trees

发布时间 2023-03-27 00:50:49作者: 乌漆WhiteMoon


Paper Reading 是从个人角度进行的一些总结分享,受到个人关注点的侧重和实力所限,可能有理解不到位的地方。具体的细节还需要以原文的内容为准,博客中的图表若未另外说明则均来自原文。

论文概况 详细
标题 《Interpretable Rule Discovery Through Bilevel Optimization of Split-Rules of Nonlinear Decision Trees》
作者 Yashesh Dhebar, Kalyanmoy Deb
发表期刊 《IEEE Transactions on Cybernetics》
发表年份 2021
期刊等级 中科院 SCI 期刊分区(2022年12月最新升级版)1区、CCF-B

作者单位:
Department of Mechanical Engineering, Michigan State University, East Lansing
Department of Electrical and Computer Engineering, Michigan State University, East Lansin

研究动机

分类任务的任务是设计一种算法,在给定的数据集上得到一个或多个分类规则,能够以最大的分类精度进行预测。然而除了分类的准确性,在许多应用中更注重于得到一个易于解释的分类器。可解释的分类器有助于识别完成分类任务的最重要的规则,还有助于提供更直接的特征关系,以增强知识并有利于未来的发展。容易解释的分类器的定义很大程度上取决于上下文,但现有的文献将它们表示为只涉及少数特征的线性、多项式或正项函数,从而无法得到良好的可解释性。
例如下图表示二维数据上不同的分割规则,分割规则 A 可能过于复杂且不可解释,分割规则 B 和 C 具有类似的简单且可解释的结构,但 C 比 B 更准确。

在理想情况下,最简单的拆分规则将只涉及一个特征。但在大多数复杂问题中,涉及单个特征的简单分割规则可能会导致太深的 DT,从而使分类器不可解释,例如使用 ID3 和 C45 诱导的 DT。在另一个极端,拓扑上最简单的树将对应于只涉及一个条件节点的 DT,但相关规则可能太复杂而无法解释,NN 和 SVM 属于这一类。

文章贡献

本文设计了一个高精度且易于解释的分类方法,首先提出了一个非线性 DT(NLDT)作为分类器,每个非终端条件节点将表示一个非线性特征函数来表达分裂规则。其次为了得到给定条件节点上的分裂规则,采用了双层优化算法,将规则结构和相关系数的学习视为两层的、相互关联的优化任务。通过自定义的双层优化方法来进化简单的规则结构,不仅计算效率高,而且获得的规则也是可解释的。

Nonlinear Decision Tree

分类器的表示

本文的 Nonlinear Decision Tree(NLDT) 由条件节点和终端叶节点组成,每个条件节点都有一个与之关联的规则(fi(x)≤0),其中 x 代表特征向量。为了使 DT 更具解释性,需要考虑两个方面:

  1. 分裂规则 fi(x) 在每个条件节点上的简单性;
  2. DT 的拓扑结构是简单的,这通过计算条件节点的总数来衡量。

因此新算法需要针对上述两个方面进行折衷,使 DT 不会太深,并且每个条件节点 fi(x) 上的相关分裂规则函数具有可控复杂性的非线性形式,从而易于解释。同时非线性 DT 需要使用有效的非线性优化方法来进化非线性规则,为此本文利用了两层优化的方法来实现。
在训练阶段主要使用递归算法诱导 NLDT,如下面的算法 1 所示:

在算法 2 中给出了获取分裂规则的双层算法伪代码:

在算法 3 中给出了上层遗传算法中种群评估的伪代码:

分裂规则的表示

本文中将划分规则的表达式限制在运算特征向量 x 上的 DT 的一个条件节点上,假设结构如下公式 1 所示:

其中 f(x,w,Θ,B) 可以用两种不同的形式表示,如下公式 2 所示,这取决于是否寻求模算子 m。其中 wi 是几个幂律(Bi)的系数或权重,θi 是偏差,m 表示模算子的存在或不存在,p 是超参数,为 f(x) 的表达式中可以存在的幂律(Bi)的最大数量。

Bi 表示一种幂律规则,如下公式 3 所示:

B 是指数 bij 的分块矩阵,如下公式 4 所示。bij 取值于指定的离散集 E 中的假设值,本文设置 p=3 和 E={-3,-2,-1,0,1,2,3} 以使规则可解释。

参数 wi 和 θi 为 [−1,1]中的实值变量,特征向量 x 是 d 维空间中的一个数据。另一个用户定义的参数 amax 控制每个幂律 Bi 中可以出现的变量的最大数量,默认值是 d(即特征空间的维度)。

树的诱导和剪枝

根据以下终止条件检查所得到的子节点,如果满足这些条件中的任何一个,则将该节点设置为叶结点。否则为附加条件节点,进行另一次拆分,
并重复该过程。最终产生需要的 NLDT,然后进行剪枝。

  1. 节点深度 > 最大允许深度;
  2. 节点内数据点数 < Nmin;
  3. 节点杂质 ≤ τmin。

剪枝将系统地删除根节点后的分割,直到修剪树的训练精度不低于预先指定的阈值 τprune=3% (预实验得到)。这使得生成的树在结构上不那么复杂,并提供了更好的泛化性。

双层优化方法

推导分割规则的层次目标

f(x)的几何形状由公式 2 和 3 的指数项 bij 定义,特征空间中的方向和位置由系数 wi 和偏差 θi 决定。因此 f(x) 优化任务涉及两类变量:

  1. 表示 b 项指数的 B 矩阵和表示f(x)中是否存在模算子的模标志 m;
  2. 每个规则中函数 f(x) 的权重 w 和偏差。

与权重和偏差相比,确定 B 项和 m 值是一项更困难的任务,最好是将搜索 (B,m) 与同级的权重偏向搜索分开,作为一个详细的任务。变量的这种层次结构促使本文采用一种双层优化方法,利用上层优化算法搜索 B 和 m 的空间。然后对于每组 (B,m),调用下层优化算法来确定 w 和 θ 的最优值。
双层优化问题可以表述为如下公式 5,上层目标 FU 是分割规则简单性的量化,下层目标 FL 量化由于分割规则 f(x)≤0 而产生的分割质量。只有在能够在参数 τI 规定的可接受限度内划分数据时,才认为上层解决办法是可行的。

对于 FL 的计算,本文使用 Gini 来生成结点,对于父节点 P 和两个子节点 L 和 R,子节点(FL)的净杂质用如下的公式 6 计算。其中 NP 为 P 的数据总数,NL 和 NR 分别为 L 和 R 中的数据总数。P 中满足分割规则 fP(x)≤0(1) 的数据点到左子树 L,其余数据到右子树 R,最小化子节点的净杂质的目标 FL 有利于创建更纯的子节点。

对于 FU 的计算,FU 的目标是产生视觉上简单的分割规则方程的主观概念,通常有更多变量和项的方程看起来是更复杂的。因此本文将 FU 设为公式 1 的总方程中出现的非零指数项的总数,数学上可以用如下的公式 7 表示。

上层优化(ULGA)

对于上层的优化问题,采用 GA 对上层 B 和 m 进行探索,基因组用 (B,m) 表示,其中 B 为如 4 所示的矩阵,m 为布尔值 0 或 1。上层 GA 侧重在子节点的净杂质(FL)的期望值内估计分裂规则的简单方程,因此上层为单目标约束优化问题,如 5 所示。FL 在框架的下层进行评估,阈值 τI 表示子节点的 6 的期望值,实验中设置 τI=0.05。

初始化上层 GA

最小化的公式 7 中需要更少数量的非零指数的公式 1,因此初始化种群限制只有一个非零指数在 split-rule 的表达。满足上述限制的唯一个体数量只有 2d 个(d 个 m=0 的个体和 d 个 m=1 的个体)。如果的种群大小超过 2d,则剩余的个体将被初始化为两个非零活动项。

上层解的排序

Select 使用如下定义的分层排序标准,对于上层的两个个体 i 和 j,当下列任意一项成立时 rank(i)>rank(j)。

  1. i and j are both infeasible AND FL(i) < FL(j).
  2. i is feasible (i.e., FL(i) ≤ τI) AND j is infeasible(i.e.,FL(j)>τI).
  3. i and j both are feasible AND FU(i) < FU(j).
  4. i and j both are feasible AND FU(i) = FU(j) AND FL(i) < FL(j).
  5. i and j both are feasible AND FU(i) = FU(j) AND FL(i) = FL(j) AND m(i) < m(j).

交叉算子

首先根据种群成员的 m 值进行聚类,m 为 0 和 1 的个体各属于一簇,交叉操作被限制在同簇的个体上。交叉从父群体中获取两个块矩阵(BP1 和 BP2),按照公式 2 中的权重 wi 的降序重新排列得到 B’P1 和 B’P2,然后在每一行上按元素进行交叉操作。

突变算子

突变操作将改变 bij 和 m 的值,上层解发生突变的概率由参数 pUmut 控制,通过实验后将它设置为 pUmut=1/d,bij 的修改方式如下图所示。图中的红色表示 bij 在 E 中的 id,橙色竖条表示突变值的概率分布。bij 可以突变为 k−2、k−1、k+1 或 k+2 的 id 值,概率分别为 α、βα、βα 和 α。超参数参数 β 最好大于1,参数 α=[1/(2(1+β))],概率之和等于1。在实验中设 β=3,m 的值以 50% 的概率随机突变为 0 或 1。

为了使搜索偏向于创建更简单的规则(即使用少量非零 bij 的 split-rules),引入参数 pzero 表示变量 bij 参与突变的概率设置为 0。本文使用 pzero=0.75,使得 bij→0 的净概率为 pUmut×pzero。然后检测和改变来自子代的副本,使整个子代由独立存在的个体组成,确保新规则的创建和多样性。然后将子代和父代组合,被选中的精英将进入下一次迭代。

下层优化(LLGA)

下层优化确定系数 wi(公式2) 和偏差θi,使 FL(w,θ)|(B,m)(公式 6)最小化,下层优化问题可以表述为如下的公式 8。

初始化方面,由于需要计算速度快的低级问题算法,此处使用混合偶极子对的概念来促进更促进更快地收敛到 w 和 θ 的最优值。令 xA 和 xB 分别是属于 Class1 和 Class2 的数据,则向量 xA xB 正交的分离的超平面对应的权值 wh 和偏置 θh 可由公式 9、10 给出。

△ 是 0 到 1 的随机数,则初始化的方法可由下图来描述:

初始种群中所有个体的偶极子对 (xA,xB) 从训练集中随机选择,变量 w 和 θ 然后对个体按照公式 11~13 进行初始化。

选择、交叉和突变算子,分别使用二元锦标赛选择、SBX 交叉和多项式突变来创建后代解,采用 (μ+λ) 生存选择策略保存精英。终止条件是达到最多 50 代或连续 10 代中目标值(FL)变化小于 0.01% 时终止。

实验分析

实验设置

本文在多个数据集上实验,训练集和测试集比例分别为 7:3 运行 50 次,指标为 Acc平均值和标准偏差。最后给出 NLDT 中条件节点的数量、DT 的 split-rules 中项的平均数量(即 FU/Rule)和规则长度(整个 NLDT 中活动项的总数)的数量,以量化分类器的简单性和可解释性。对比算法是 CART 和 SVM。

自定义数据集 DS1-DS4

下表列出在数据集 DS1-DS4 上运行 50 次的数据,结果表明提出的非线性、基于双层的 DT 方法优于经典的 CART 和基于 SVM 的分类方法。双层方法找到 1 个规则,在规则中有 2-3 个变量出现,而 CART 需要11-31个规则,每个规则涉及一个变量,SVM只需要一个规则,但在规则中有 44-90 个变量出现。

Breast Cancer Wisconsin 数据集

Breast 数据集有良性和恶性两个类别,良性有 458 个数据(占 65.5%),恶性有 241 个数据点(占 34.5%),每个数据点用 10 个特征表示。下表是的 Breast 数据集的结果,本文的算法和 SVM 具有相似的性能,但 SVM 在规则中需要大约 90 个变量,而本文的在单个规则中只需要大约 6 个变量。在准确性和可理解性/紧凑性方面,所提出的方法更优。

下图展示了通过双层优化方法运行获得的 NLDT 分类器,具有五个变量。

下图为本文方法得到的决策边界 b 空间可视化结果,该方法能够识别两个含变量的非线性 b 项,对数据进行线性分割,获得较高的精度。

Wisconsin Diagnostic Breast Cancer 数据集

WDBC 数据集是前一节数据集的扩展,共有 30 个特征,良性有 356 个数据,恶性有 212 个数据。下表的结果表明,基于两层优化的 NLDT 能够优于标准 CART 和 SVM 算法。

如下图所示,本文的算法仅使用了 7 个特征,它的结果几乎与支持向量机获得的结果一样准确,而且更具可解释性。

下图显示了本文算法能以高精度对提供的数据进行线性分类。

Real-World Autoindustry

在 Autoindustry 数据集中有 36 个特征、8 个约束条件和 1 个目标函数,创建的数据集是不均衡的,有 188 个数据属于好类,996 个属于坏类。RW-Problem 数据集的结果下表所示。

下图展示了本文算法仅需要两条分割规则就能获得接近 90% 的精度分数。

SVM 在 Acc 方面表现最好,但所得到的规则中约有 241 个特征,非常复杂。本文算法只需要大约两个规则,每个规则大约只有 10 个特征,以实现稍微不太准确的分类。CART 需要大约 30 条规则和很深的 DT,这使得分类器难以解释。

多目标优化问题

NLDT 在多个真实多目标数据集上进行了测试,分别是 welded-beam、2-D truss、m-ZDT 和 m-DTLZ,它们的实验结果如下表所示。

2-D truss 是一个三特征的数据集,与 CART 和 SVM 生成的分类器相比,NLDT 具有最好的精度和更少的特征。下图的 (a) 显示了 NLDT 在单个规则下达到 100% 正确率的分类器,(b) 显示了具有 19 个规则的 CART 分类器,显然后者更难以解释。

welded-beam 是双目标优化问题,有 4 个特征和 4 个约束,由实验数据所示 NLDTs 具有与 CART 和 SVM 相似的精度,且复杂度非常低。m-ZDT 和 m-DTLZ 问题有 500 个特征,并对两目标和三目标优化问题都做了实验,在所有这些问题中 CART 和 SVM 方法很难进行正确的分类任务。

多分类的 iris 数据集

目前研究的算法是针对二分类问题的,但它有可能可以扩展到多分类问题。iris 数据集有三个类别和四个特征,NLDT 在测试集上的分类准确率为 94.8%。下图是 NLDT 的一个样例,橙色为 versicolor 类,黄色为 virginica 类,绿色为 setosa 类。

优点和创新点

个人认为,本文有如下一些优点和创新点可供参考学习:

  1. 本文使用的 NLDT 结构用一个规则表达式作为分割的依据,是对传统 DT 的一种创新;
  2. 针对不同的参数,本文设计了上下两层用不同的策略分别优化,都取得了良好的效果;
  3. 本文的算法只需要用少量的特征和较浅的结构就能完成二分类任务,具有良好的可解释性;
  4. 由于涉及树模型和遗传算法,需要设计的组件很多,显得文章的工作较为复杂。