Graph Normalizing Flows

发布时间 2023-05-25 10:55:05作者: 馒头and花卷

Liu J., Kumar A., Ba J., Kiros J. and Swersky K. Graph normalizing flows. NIPS, 2019.

基于 flows 的图的生成模型.

符号说明

  • \(\mathcal{G} = (H, \Omega)\), 图;
  • \(H = (\mathbf{h}^{(1)}, \cdots, \mathbf{h}^{(N)}) \in \mathbb{R}^{N \times d_n}\), node feature matrix;
  • \(\Omega \in \mathbb{R}^{N \times N \times (d_e + 1)}\), 其中 \(\Omega_{:,:,0} \in \mathbb{R}^{N \times N}\) 表示邻接矩阵, \(\Omega_{:, :, :1:(d_e+1)}\) 表示 edge features.
  • 寻常的 MPNN 可以表述为:

    \[\mathbf{m}_{t+1}^{(v)} = \text{Agg} \Bigg( \{ M_t (\mathbf{h}_t^{(v)}, \mathbf{h}_t^{(u)}, \Omega_{u, v}) \}_{u \in \mathcal{N}}(v) \Bigg) \\ \mathbf{h}_{t+1}^{(v)} = U_t(\mathbf{h}_t^{(v)}, \mathbf{m}_{t+1}^{v}), \]

    其中 \(M_t(\cdot), U_t(\cdot)\) 分别是 message generation function 和 vertex update function.

Graph Normalizing Flows

  • 需要注意的是, 本文的 flows 和一般的 flows 有点区别, 它并不具有一个 encoder 先将 \(\mathbf{x}\) 转换为隐变量 \(\mathbf{z}\)\(\mathbf{z}' = f(\mathbf{z})\) 的过程, 而是直接构造 flow \(\mathbf{z} = f(\mathbf{x})\).

  • 简单来说, flow 需要保证 \(f(\cdot)\) 是可逆的, 此时:

    \[P(\mathbf{z}) = P(\mathbf{x})|\text{det}(\frac{\partial f(\mathbf{x})}{\partial \mathbf{x}})|^{-1}. \]

  • 作者是基于 RealNVP 进行的, 该方法将 \(\mathbf{x}\) 切分为 \(\mathbf{x}^{(0)}, \mathbf{x}^{(1)}\), 然后:

    \[\mathbf{z}^{(0)} = \mathbf{x}^{(0)} \\ \mathbf{z}^{(1)} = \mathbf{x}^{(1)} \odot \exp(s(\mathbf{x}^{(0)})) + t(\mathbf{x}^{(0)}), \]

    其中 \(s, t\) 为两个 non-linear 函数. 和显然 \(\nabla_{\mathbf{x}} \mathbf{z}\) 为一个下三角矩阵. 此时行列式就是对角线元素相乘.

GRevNets

  • 让我们来看看作者是怎么构造可以的 flow 的.

  • 首先, 对每个结点 \(v\), 将它的结点特征切分为 \(\mathbf{h}_t^{0}, \mathbf{h}_t^{1}\) (这里我们省略标识 \((v)\)).

  • 前向的过程可以表述为:

    \[H_{t+\frac{1}{2}}^0 = H_t^0 \odot \exp(F_1(H_t^1)) + F_2 (H_t^1), \quad H_{t+1}^0 = H_{t+\frac{1}{2}}^0, \\ H_{t+\frac{1}{2}}^1 = H_t^{1}, \quad H_{t+1}^{1} = H_{t+\frac{1}{2}}^1 \odot \exp(G_1(H_{t+\frac{1}{2}}^{0})) + G_2(H_{t+\frac{1}{2}}^0). \]

  • 给定 \(H_{t+1}^0, H_{t+1}^1\) 我们可以得到:

  • 于是我们有:

    \[P(H_0) = P(H_T) \prod_{t=1}^T |\text{det}(\frac{\partial H_t}{\partial H_{t-1}})|. \]

  • 我们可以通过极大化对数似然 \(\log P(H_0)\) 来优化参数.

  • 但是, 我们最终希望的其实是生成离散的图 (通过邻接矩阵 \(A\) 来刻画).

  • 所以在生成的时候, 比如我们采样 \(H_T \sim \mathcal{N}(0, 1)\), 然后通过 GNF 得到 \(H_0\), 那么我们实际上还需要一个 decoder 将 \(H_0\) 映射为 \(\hat{A}\).

  • 为此, 作者还额外设计了一个 encoder, 将 \(A, H\) 映射为隐变量 \(X\), 不过我不是特别清楚为什么 \(H\) 也是采样子正态分布而不是直接用 node features.

  • 训练编码器是通过如下损失:

    \[\mathcal{L}(\theta) = -\sum_{i=1}^N \sum_{j=1}^{N/2} A_{ij} \log (\hat{A}_{ij}) + (1 - A_{ij}) \log (1 - \hat{A}_{ij}), \]

    这里 \(N/2\) 的原因是作者假设我们生成的是无向图, 所以 \(A\) 是对称的.

  • 对于 decoder, 作者采用的是一种非常简单的方式:

    \[\hat{A}_{ij} = \frac{1}{1 + \exp(C(\|\mathbf{x}_i - \mathbf{x}_j\|_2^2 - 1))}. \]