FastGCN Fast Learning with Graph Convolutional Networks via Importance Sampling

发布时间 2023-04-16 17:16:10作者: 馒头and花卷

Chen J., Ma T. and Xiao C. FastGCN: fast learning with graph convolutional networks via importance sampling. ICLR, 2018.

一般的 GCN 每层通常需要经过所有的结点的 propagation, 但是这是费时的. 像普通的深度学习方法一样利用 mini-batch 的训练方式由于缺乏独立性的保证而不能非常有效地施展. 本文提出地 FastGCN 希望解决这个问题.

符号说明

  • \(H^{(l)}\), 第 \(l\) 层的 embeddings;
  • \(W^{(l)}\), 第 \(l\) 层的 权重;
  • \(\hat{A}\), 邻接矩阵;
  • GCN 的 propagation 过程:

    \[H^{(l+1)} = \sigma (\hat{A}H^{(l)}W^{(l)}). \]

Motivation

  • 在深度学习中, 我们通常需要优化这样的一个损失:

    \[L = \mathbb{E}_{x \sim D} [g(W; x)], \]

    通常我们通过独立地采样点来近似期望:

    \[L_{emp} = \frac{1}{n} \sum_{i=1}^n g(W; x_i), \quad x_i \sim D, \forall i. \]

  • 但是在图中, 独立性是难以保障的, 因为, 在 GCN 中, 数据成了结构的一部分, 即 \(\hat{A}\) 的存在导致即使我们独立地采样样本点, 但是邻接矩阵是无法保证和完整的邻接矩阵一个效果的.

FastGCN

  • 假设我们拥有图 \(G' = (V', E')\), 它包含了所有在现在和未来可能遇到的结点, 且我们假设结点集合 \(V'\) 上存在这样的一个概率空间: \((V', F, P)\), 其中 \(F\) 是定义的域 (如 \(2^{V'}\)), \(P\) 是某个概率测度.

  • 我们把 GCN 的每一层看成是如下的积分变换:

    \[\tilde{h}^{(l+1)}(v) = \int \hat{A}(v, u) h^{(l)}(u) W^{(l)} \mathrm{d} P(u), \\ h^{(l+1)}(v) = \sigma(\tilde{h}^{(l+1)}(v)), \\ l=0,\ldots, M-1. \]

    这里, 我们将 \(h\) 看成是一个 embedding function, 给定结点返回对应的特征.

  • 类似地, 我们可以将图上的训练损失表示为:

    \[L = \mathbb{E}_{v \sim P}[g(h^{(M)}(v))] = \int g(h^{(M)}(v)) \mathrm{d} P(v). \]

  • 于是, 我们可以通过采样来近似:

    \[\hat{h}^{(l+1)} (v) := \frac{1}{t_l} \sum_{j=1}^{t_l} \hat{A}(v, u_j^{(l)}) \hat{h}^{(l)} (u_j^{(l)}) W^{(l)}, \\ \hat{L} := \frac{1}{t_M} \sum_{i=1}^{t_M} g(\hat{h}^{(M)}(u_i^{(M)})). \]

  • 可以证明 (依概率 1):

    \[\lim_{t_0, t_1, \ldots, t_M \rightarrow + \infty} \hat{L} = L. \]

  • 当我们选择 \(P\) 为均匀采样的时候, 算法如下:

  • 其中 \(n\) 是因为, 原来的 GCN 的 aggregation 过程:

    \[\begin{array}{ll} h(v) &= \sum_{u} \hat{A}(v, u) h(u) W \\ &= n \cdot \frac{1}{n} \sum_{u} \hat{A}(v, u) h(u) W \\ &= n \cdot \mathbb{E}_{u} [\hat{A}(v, u) h(u) W] &\approx n \cdot \frac{1}{t} \sum_{j=1}^{t} [\hat{A}(v, u_j) h(u_j) W]. \end{array} \]

方差分析

  • 令:

    \[y(v_i) := \frac{1}{t} \sum_{j=1}^t \hat{A}(v, u_j) x(u_j) \]

    表示一个结点的估计, 作者希望估计

    \[G := \frac{1}{s} \sum_{i=1}^s y(v_i) \]

    的方差.

  • 其方差如下 (\(u, v\) 是独立的):

    \[\text{Var}(G) = R + \frac{1}{st} \int \int \hat{A}(v, u)^2 x(u)^2 dP(u) dP(v), \]

    其中

    \[R = \frac{1}{s}(1 - \frac{1}{t}) \int e(v)^2 dP(v) - \frac{1}{s} (\int e(v) d P(v))^2, \\ e(v) = \int \hat{A}(v, u) x(u) dP(u). \]

    注: 作者证明的时候, 用到了一个很有意思的性质:

    \[\begin{array}{ll} \text{Var}_{u,v}\Big\{ f(u, v) \Big\} &=\mathbb{E}_{u,v}\Big\{ (f(u, v) - \mathbb{E}_{u,v}[f(u, v)])^2 \Big\} \\ &=\mathbb{E}_{u,v}\Big\{ (f(u, v) - \mathbb{E}_{u}[f(u, v)])^2 \Big\} + \mathbb{E}_{v}\Big\{ (\mathbb{E}_{u}[f(u, v)] - \mathbb{E}_{u,v}[f(u, v))^2 \Big\} \\ &=\mathbb{E}_{v}\Big\{ \text{Var}_u(f(u, v)) \Big\} + \text{Var}_{v}\Big\{ \mathbb{E}_{u}[f(u, v)] \Big\}. \end{array} \]

  • 通过改变采样策略, 我们可以改进第二项的值从而改进方差, 从而作者引入了 importance sampling, 即

    \[y_Q(v) := \frac{1}{t} \sum_{j=1}^t \hat{A}(v, u_j) x(u_j) (\frac{dP(u)}{dQ(u)}|_{u_j}), \quad, u_1, \ldots, u_t \sim Q. \]

    从而:

    \[G_{Q} := \frac{1}{s} \sum_{i=1}^s y_Q(v_i). \]

  • 这样最优的 \(Q\) 为:

    \[dQ(u) = \frac{b(u)|x(u)| dP(u)}{\int b(u)|x(u)| dP(u)}, \: b(u) = [\int \hat{A}(v, u) dP(v)]^{1/2}, \]

    使得

    \[\text{Var}\{G_Q\} = R + \frac{1}{st} [\int b(u) |x(u)| dP(u)]^2. \]

  • 但是这个有一个问题, 就是在训练过程中 \(x(u)\) 是时刻在变化的, 所以这个分布是不稳定的, 故实际中, 作者选择

    \[dQ(u) = \frac{b(u)^2 dP(u)}{\int b(u)^2 dP(u)}. \]

  • 在实际中, 我们可以定义:

    \[q(u) = \|\hat{A}(:, u)\|^2 / \sum_{u' \in V} \|\hat{A}(:, u')\|^2, \quad \forall u \in V. \]

  • 基于重要性采样的算法如下:

代码

official