Cluster-GCN An Efficient Algorithm for Training Deep Convolution Networks

发布时间 2023-04-27 16:50:43作者: 馒头and花卷

Chiang W., Liu X., Si S., Li Y., Bengio S. and Hsieh C. Cluster-GCN: An efficient algorithm for training deep and large graph convolutional networks. KDD, 2019.

以往的 GraphSage, FastGCN 等方法, 虽然能够实现 mini-batch 的训练, 但是他们所采样的方式效率是很低: 所采样的点之间往往可能具有很少的边, 导致整体的结果非常稀疏. 本文提出了一种高效的采样方式, 首先将所有的点聚类, 再采样.

符号说明

  • \(G = (\mathcal{V, E}, A)\), 图;
  • \(|\mathcal{V}| = N\);
  • \(X \in \mathbb{R}^{N \times F}\), 特征矩阵;
  • GCN 的每一层可以表述为:

    \[Z^{(l+1)} = A' X^{(l)} W^{(l)}, \: X^{(l+1)} = \sigma(Z^{(l+1)}), \]

    其中 \(A'\) 是 normalized 邻接矩阵.
  • 最后的损失可以表述为

    \[\tag{1} \mathcal{L} = \frac{1}{|\mathcal{Y}_L|} \sum_{i \in \mathcal{Y}_L} \text{loss}(y_i, z_i^L), \]

    其中 \(\mathcal{Y}_L\) 表示所有打了标签的结点的集合.

Motivation

  • (1) 是一个整体的在所有的打过标签的结点上的损失, 这在应对特别大规模的数据的时候就很麻烦了, 所以我们更希望的是采用 mini-batch 的方式:

    \[\mathcal{L} = \frac{1}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \text{loss}(y_i, z_i^L). \]

  • 但是有一个问题, 就是图的聚合操作, 需要用到所选的结点邻居, 所以, 即使 \(|\mathcal{B}|\) 本身很小, 为了精准地计算 \(z^L\), 所需的结点也是很多的 (随着层数指数增长). 所以我们只能采样一批点, 然后在较小的邻接矩阵 \(\hat{A}\) 上做聚合操作.

  • 倘若我们采用随机采样的方式, 就容易导致采样的点之间的 edges 很少 (因为我们很难保证恰好采样到那些关系比较紧密的结点). 假设采样的点为 \(\mathcal{B}\), 实际上就是 \(\|A_{\mathcal{B, B}}\|_0\) 很小, 这会使得训练效率异常低下.

Cluster-GCN

  • 本文的思想很简单, 希望通过聚类, 先将结点切分为多个紧密联系的群体 (通过聚类算法 METIS):

    \[[\mathcal{V}_1, \cdots, \mathcal{V}_c], \]

    则我们同样得到 \(c\) 个子图:

    \[[\{\mathcal{V}_1, \mathcal{E}_1, A_{11}\}, \cdots, \{\mathcal{V}_c, \mathcal{E}_c, A_{cc}\}]. \]

  • 于是乎, 在实际上训练的时候, 我们可以直接选择某个子图 \(G_i\) 作为一个 batch 用于训练.

  • 这种做法, 实际上相当于用

    \[\bar{A} = \left[ \begin{array}{ccc} A_{11} & \cdots & 0 \\ \vdots & \ddots & \vdots \\ 0 & \cdots & A_{cc} \\ \end{array} \right] \]

    去逼近 \(A\), 由于我们舍去了很多 Links, 必然会导致性能的下降.

  • 故本文在实际中, 会选择一个 (比预想) 较大的 \(c\), 然后每次采样的时候, 从中选择 \(q\) 个 clusters 作为一个 batch.

代码

Official

PyTorch