Graph Neural Networks Inspired by Classical Iterative Algorithms

发布时间 2023-06-11 09:32:41作者: 馒头and花卷

Yang Y., Liu T., Wang Y., Zhou J., Gan Q., Wei Z., Zhang Z., Huang Z. and Wipf D. Graph neural networks inspired by classical iterative algorithms. ICML, 2021.

基于广义 energy function (diffusion) 的图神经网络.

符号说明

  • \(\mathcal{G} = \{\mathcal{V, E}\}\), 图;
  • \(n = |\mathcal{V}|\), \(m = |\mathcal{E}|\);
  • energy function:

    \[\tag{1} \ell_Y (Y) := \|Y - f(X; W)\|_{\mathcal{F}}^2 + \lambda \text{tr}(Y^T L Y). \]

    其中 \(X \in \mathbb{R}^{n \times d_0}\) 为 node features. \(f(\cdot)\) 是特征变换, 比如:

    \[f(X; W) = XW, \quad f(X; W) = \text{MLP}[X; W]. \]

    \(L = D - A = B^T B\), 为 \(\mathcal{G}\) 的 Laplacian 矩阵. \(D, A\) 分别为 degree 和 adjacency matrix. \(B \in \mathbb{R}^{m \times n}\) 为入射 (incidence) 矩阵 (注: 说实话, 我不知道这个分解是怎么得到的).

Motivation

  • 对于问题 (1), 我们有显式解:

    \[\tag{2} Y^*(W) = (I + \lambda L)^{-1} f(X; W), \]

    \(n\) 比较大的时候, 求逆是费时费力的, 所以可以通过多步梯度下降来近似, 我们知道:

    \[\nabla_{Y} \ell_Y = 2 \lambda L Y + 2 Y - 2 f(X; W), \]

    取 step size 为 \(\alpha / 2\) 可得迭代公式:

    \[Y^{(k+1)} = Y^{(k)} - \alpha [(I + \lambda L)Y^{(k)} - f(X; W)]. \]

    这种方式存在一个问题, \(I + \lambda L\) 具有很大的条件数, 这会导致整体的收敛非常慢, 所以作者认为可以利用 Jacobi preconditioning 来 rescale 这一步:

    \[\begin{array}{ll} \tag{6} Y^{(k+1)} &= Y^{(k)} - \alpha \tilde{D}^{-1} [(I + \lambda L)Y^{(k)} - f(X; W)] \\ &= Y^{(k)} - \alpha \tilde{D}^{-1} [(I + \lambda D - \lambda A)Y^{(k)} - f(X; W)] \\ &= (1 - \alpha) Y^{(k)} - \alpha \tilde{D}^{-1} [- \lambda AY^{(k)} - f(X; W)] \\ &= (1 - \alpha) Y^{(k)} + \alpha \tilde{D}^{-1} [\lambda AY^{(k)} + f(X; W)]. \end{array} \]

    其中 \(\tilde{D} = \lambda D + I\).

  • 我们知道, 如果 \(A_{ij} \in \{0, 1\}\), 此时有:

    \[\text{tr}[Y^T L Y] = \sum_{(i, j) \in \mathcal{E}} \|\bm{y}_i - \bm{y}_j \|_2^2. \]

  • 这种形式虽然很一般, 但是在实际中, 可能会遇到异常值的问题 (\(\|\cdot\|_2^2\) 对异常值非常敏感).

  • 对于 (1) 我们有一种概率上的解释, 令:

    \[p(X|Y) \propto \exp(-\frac{1}{2\lambda} \|Y - f(X; W)\|_{\mathcal{F}}^2), \\ p(Y) \propto \exp(-\frac{1}{2}\text{tr}(Y^TLY)), \]

    此时

    \[\ell_Y \Leftrightarrow -\log p(X|Y)p(Y) \Leftrightarrow -\log p(Y|X). \]

    故, 我们可以认为, 最小化 \(\ell_Y\) 某种程度上就是在找最大后验概率的点. 这里, 我们假设先验为 \(p(Y)\), 即每条边 \((i, j)\) 的方差均为 \(1\), 这是一个很强的假设, 因为可能某些边是噪声.

Robust Regularization

  • 于是, 作者希望如此建模先验:

    \[p(Y) =\prod_{(i, j) \in \mathcal{E}} p(\bm{y}_i - \bm{y}_j) =\prod_{(i, j) \in \mathcal{E}} p(\bm{u}_{ij}), \: \bm{\mu}_{ij} := \bm{y}_i - \bm{y}_j. \]

  • 再建模:

    \[p(\bm{\mu}_{ij}) = \int p(\bm{\mu}_{ij}, \gamma_{ij}) \mathrm{d} \mu (\gamma_{ij}), \]

    这里 \(\gamma_{ij}\) 是 edge \((i, j)\) 的不确定度的变量, 此时

    \[\tag{13} p(Y) = Z^{-1} \prod_{(i, j) \in \mathcal{E}} \int \mathcal{N}(\bm{\mu}_{ij}|0, \gamma_{ij}^{-1}I) \mathrm{d} \mu(\gamma_{ij}). \]

    这里我们假设 \(p(\bm{\mu}_{ij}|\gamma_{ij}) = \mathcal{N}(\bm{\mu}_{ij}|0, \gamma_{ij}^{-1}I)\).

  • 一个比较重要的结论是: 对于任意的满足 (13) 的先验, 都存在凹的且非降的函数 \(\rho: \mathbb{R}^+ \rightarrow \mathbb{R}\) 使得下列成立:

    \[-\log p(Y) = \pi(Y; \rho) \Leftrightarrow \sum_{(i, j) \in \mathcal{E}} \rho (\|\bm{y}_i - \bm{y}_j\|_2^2) \]

  • 换言之, 我们考虑不同的先验的建模, 实际上等价于寻找不同的 \(\rho\), 故而, 我们可以将 (1) 转换为如下的更加一般的形式:

    \[\tag{14} \ell_Y(Y; \rho) :=\|Y - f(X; W)\|_{\mathcal{F}}^2 + \lambda \sum_{(i, j) \in \mathcal{E}} \rho(\|\bm{y}_i - \bm{y}_j\|_2^2). \]

  • 实际上, (14) 可以进一步改写为:

    \[\sum_{(i, j) \in \mathcal{E}} \Big[ \gamma_{ij} \|\bm{y}_i - \bm{y}_j\|_2^2 - \tilde{\rho}(\gamma_{ij}), \Big] \]

    这里 \(\gamma_{ij}\) 知识一组变分分解系数, \(\tilde{\rho}(\gamma) := \inf_x(\gamma x - \rho(x))\)\(\rho\) 的凹共轭 (concave conjugate). 故

    \[\begin{array}{ll} & \tilde{\rho}(\gamma_{ij}) \le \gamma_{ij} \|\bm{y}_i - \bm{y}_j\|_2^2 - \rho(\|\bm{y}_i - \bm{y}_j\|_2^2) \\ \Rightarrow & \rho(\|\bm{y}_i - \bm{y}_j\|_2^2)\le \gamma_{(i,j} \|\bm{y}_i - \bm{y}_j\|_2^2 - \tilde{\rho}(\gamma_{ij}) \\ \Rightarrow & \sum_{(i, j) \in \mathcal{E}}\rho(\|\bm{y}_i - \bm{y}_j\|_2^2)\le \sum_{(i, j) \in \mathcal{E}} \Big[ \gamma_{(i,j} \|\bm{y}_i - \bm{y}_j\|_2^2 - \tilde{\rho}(\gamma_{ij}) \Big]. \\ \end{array} \]

  • 这有一个什么好处呢, 我们可以通过确定 \(\gamma_{ij}\) 然后优化 \(\ell_Y(Y; \rho)\) 的一个上界:

    \[\hat{\ell}_Y (Y; \Gamma; \tilde{\rho}) = \|Y - f(X; W)\|_{\mathcal{F}}^2 + \lambda \text{tr}(Y^T \hat{L} Y) + f(\Gamma), \]

    其中 \(\Gamma \in \mathbb{R}^{m \times m}\) 为一个对角矩阵, 对角线元素为 \(\gamma_{(i, j)}\), 而 \(\hat{L} := B^T \Gamma B\).

  • 对于一般的 \(\gamma\), 最小化 \(\hat{\ell}_{Y}\) 实际上最小化 \(\ell_Y(Y; \rho)\) 的一个上界, 且倘若

    \[\tag{18} \gamma_{ij} = \frac{\partial \rho(z^2)}{\partial z^2}|_{z = \|\bm{y}_i - \bm{y}_j\|_2} \]

    的时候, 最小化 \(\hat{\ell}_Y\) 等价于最小化 \(\ell_Y(Y; \rho)\).

  • 故, 一种可行的算法是:

    1. 利用 (18) 更新 \(\gamma\);
    2. 利用类似 (6) 的公式更新 \(Y^{(k+1)}\).
  • 作者给出了一些 \(\rho\) 的选择: