Stochastic Training of Graph Convolutional Networks with Variance Reduction

发布时间 2023-04-15 15:59:29作者: 馒头and花卷

Chen J., Zhu J. and Song L. Stochastic training of graph convolutional networks with variance reduction. ICML, 2018.

我们都知道, GCN 虽然形式简单, 但是对于结点个数非常多的情形是不易操作的: 多层的卷积之后基本上每个结点的感受野都会变得非常大 (指数级上升), 这对导致 mini-batch 训练的思想在图的任务中是不那么普遍的. GraphSage 通过采样结点的方式缓解了这个问题, 但是不想以往的 mini-batch 那样有很好的收敛性保证. 本文主要解决的就是这两个问题.

符号说明

  • \(\mathcal{G} = (\mathcal{V}, \mathcal{E})\), 图;
  • \(V = |\mathcal{V}|, E = |\mathcal{E}|\);
  • \(A\), 邻接矩阵;
  • \(\tilde{A} = A + I\);
  • \(P = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}\);
  • graph convolution layer:

    \[Z^{(l+1)} = P H^{(l)} W^{(l)}, \\ H^{(l+1)} = \sigma(Z^{l+1}). \]

Motivation

  • 我们首先将任务场景限定在 node-level 的问题上, 假设训练集 \(\mathcal{V}_l \subset \mathcal{V}\), 我们希望根据这些结点以及他们的标签为其余的结点 \(\mathcal{V} \setminus \mathcal{V}_l\) 打标签;

  • 假设我们通过如下损失进行训练:

    \[\mathcal{L} = \frac{1}{\mathcal{V}_l} \sum_{v \in \mathcal{V}_l} f(y_v, z_v^{(L)}), \]

    则每次训练都要计算如下的梯度:

    \[\nabla \mathcal{L} = \frac{1}{\mathcal{V}_l} \sum_{v \in \mathcal{V}_l} \nabla f (y_v, z_v^{(L)}), \]

    作者认为这一步的计算是非常大的.

  • 一些方法, 比如 GraphSage 采用采样邻居的方式来模拟 mini-batch 的训练方式, 它相当于:

    1. 采样部分结点;
    2. 利用这些结点构建新的邻接矩阵 \(\hat{P}\);
    3. 然后通过如下方式进行更新:

      \[Z^{(l+1)} = \hat{P}^{(l)} H^{(l)} W^{(l)}, \\ H^{(l+1)} = \sigma(Z^{(l+1)}), \]

      这就大大减轻了计算.
  • 但是这种方式, 由于非线性 \(\sigma\) 的存在, 无法保证 \(\hat{P}\) 会收敛到 \(\hat{P}\), 也因此整体的训练方式也缺乏严格的收敛保证.

本文方法

  • 作者维护历史变量 \(\bar{H}^{(l)}\), 然后通过如下方式更新:

    \[Z^{(l+1)} = (\hat{P}^{(l)}(H^{(l)} - \bar{H}^{l}) + P\bar{H}^{(l)}) W, \]

    由于 \(\bar{H}^{(l)}\) 本身是不带梯度的, 所以在反向计算梯度的时候:

    \[\begin{array}{ll} \mathrm{d} Z^{(l+1)} &= \mathrm{d} (\hat{P}^{(l)}(H^{(l)} - \bar{H}^{l}) + P\bar{H}^{(l)}) W \\ &= (\hat{P}^{(l)} \mathrm{d} H^{(l)}) W \\ & \quad + (\hat{P}^{(l)}(H^{(l)} - \bar{H}^{l}) + P\bar{H}^{(l)}) \mathrm{d} W, \end{array} \]

  • 说实话, 从训练的方式来看, 并没有减少计算量的感觉. 不过, 因为这种方式更加稳定, 所以方差比较小, 因此只需要为每个结点采样 2 个邻居就可以了, 这应该是计算量降低的主要原因.

代码

official