Graph Neural Networks with Adaptive Residual

发布时间 2023-10-31 19:45:34作者: 馒头and花卷

Liu X., Ding J., Jin W., Xu H., Ma Y., Liu Z. and Tang J. Graph neural networks with adaptive residual. NIPS, 2021.

基于 UGNN 框架的一个更加鲁棒的改进.

符号说明

  • \(\mathbf{A} \in \mathbb{R}^{n \times n}\), 邻接矩阵;
  • \(\mathbf{D} = \text{diag}([d_1, d_2, \ldots, d_n]), \quad d_i = \sum_{j} A_{ij}\).
  • \(\mathbf{\tilde{A}} = \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{1/2}\);

AirGNN

  • 下面是在不同的图任务上的一个训练结果:

  • 可以发现, 残差连接可以帮助 GNNs 利用更多的层去区别正常的结点, 但是却使得在异常结点上的分类恶化.

  • 我们可以这样认为, 简单的没有残差连接的图网络能够平滑结点表示, 所以此时随着层数的加深, 对于异常结点的分类会更好. 相反, 如果加了残差连接, 最后的结点表示始终会受到一开始的异常结点表示的影响, 所以结果并不太好.

  • 但是, 我们也不能直接移除残差连接, 因为这是加深 GNN 的几乎必须的技巧.

  • 一般的 GCN 都可以归结为如下的方式:

    \[\mathbf{X}_{out} = \text{argmin}_{\mathbf{X} \in \mathbb{R}^{n \times d}} \: \lambda \|\mathbf{X} - \mathbf{X}_{in}\|_F^2 + (1 - \lambda) \frac{1}{2} \text{tr}(\mathbf{X}^T (\mathbf{I} - \mathbf{\tilde{A}}) \mathbf{X}). \]

  • \(\|\mathbf{X} - \mathbf{X}_{in}\|_F^2 = \sum_{i=1}^n \|\mathbf{X}_i - (\mathbf{X}_{in})_i\|_2^2\), 我们知道, \(\|\cdot\|_2^2\) 对于异常值是敏感的, 所以作者转而改写成如下的更加鲁棒的方式:

    \[\text{argmin}_{\mathbf{X} \in \mathbb{R}^{n \times d}} \: \lambda \|\mathbf{X} - \mathbf{X}_{in}\|_{21} + (1 - \lambda) \text{tr}(\mathbf{X}^T (\mathbf{I} - \mathbf{\tilde{A}}) \mathbf{X}), \]

    其中

    \[\|\mathbf{X} - \mathbf{X}_{in}\|_{21} := \sum_{i=1}^n \|\mathbf{X}_i - (\mathbf{X}_{in})_i \|_2. \]

  • 通过 proximal gradient descent 来求解上面的问题, 得到如下的迭代方式:

  • 一个直观的理解是:

    • 当结点 \(i\) 的特征异常的时候, 通常 \(\|\mathbf{Y}_i - (\mathbf{X}_{in})_i\|_2\) 比较大;
    • 这就导致 \(\beta_i\) 比较大;
    • 此时 \(\mathbf{X}_i^{k+1}\) 更多由它的邻居决定 (即 \(\mathbf{Y}_i^k\)), 否则由它本身 \(\mathbf{X}_{in}\) 决定.

代码

[official]