Permutation Invariant Graph Generation via Score-Based Generative Modeling

发布时间 2023-05-26 11:22:05作者: 馒头and花卷

Niu C., Song Y., Song J., Zhao S., Grover A. and Ermon S. Permutation invariant graph generation via score-based generative modeling. AISTATS, 2020.

本文利用 diffusion 进行图的生成, 很朴素.

符号说明

  • \(\mathbf{A}^{\pi}\), 邻接矩阵, \(\pi\) 可以理解为结点的序;
  • \(\mathcal{A} = \{\mathbf{A} \in \mathbb{R}^{N \times N}| \mathbf{A} = \mathbf{A}^T, N \in \mathbb{N}^{+}\}\), 邻接矩阵的集合 (这里只考虑无向图);

本文方法

  • 首先, 我们需要了解 score-based 生成模型, 总而言之, 它要求我们估计概率模型的 score:

    \[\mathbf{s}_{\theta}(\mathbf{A}; \sigma): \mathcal{A} \rightarrow \mathcal{A}. \]

  • 我们的目的是生成图 \(G\), 但是图 \(G\) 实际上是由它的邻接矩阵 \(\mathbf{A}\) 决定的:

    \[p(G) = \sum_{\pi} p(G, \pi) = \sum_{\mathbf{A}^{\pi}} p(\mathbf{A}^{\pi}). \]

    所以, 我们要做的就是估计 \(p(\mathbf{A})\).

  • 按照 diffusion 的思想, 我们需要先构造一个前向的模糊过程:

    \[q_{\sigma}(\mathbf{\tilde{A}}| \mathbf{A})= \left \{ \begin{array}{ll} \prod_{i < j} \frac{1}{\sqrt{2\pi} \sigma} \exp\{ - \frac{(\mathbf{\tilde{A}}_{ij} - \mathbf{A}_{ij})^2}{2 \sigma^2} \}, & \text{if } \mathbf{\tilde{A}} = \mathbf{\tilde{A}}^T \\ 0, & \text{otherwise}. \end{array} \right . \]

    可以发现, 作者在扰动的时候, 对于矩阵的各元素独立地添加高斯噪声, 这和在图片上的 diffusion 是完全一致的. 特别地是, 因为我们考虑的是无向图, 所以假设非对称的扰动是不被允许的.

  • 因为 \(\nabla_{\mathbf{\tilde{A}}} \log q_{\sigma} (\mathbf{\tilde{A}}|\mathbf{A}) = - (\mathbf{\tilde{A}} - \mathbf{A}) / \sigma^2\), 所以最后的损失是:

    \[\mathcal{L}(\bm{\theta}; \{\sigma_i\}_{i=1}^L) = \frac{1}{2L} \sum_{i=1}^L \sigma_i^2 \mathbb{E} \Bigg[ \|\mathbf{s_{\theta}}(\mathbf{\tilde{A}}, \sigma_i) + \frac{\mathbf{\tilde{A} - A}}{\sigma_i^2}\|_2^2 \Bigg], \]

    其中 \(\sigma_i, i=1,2,\ldots, L\) 表示不同的噪声程度, 它依旧是来自 diffusion 的概念.

  • 实际采样的时候和 diffusion 的反向过程也是一致的, 就是从正态分布中采样, 然后逐步地利用如下算法来恢复:

  • 不同的是, \(\mathbf{z}_t\) 只需要采样 下 (或者上)三角即可, 另一半对称复制一下就行以保证邻接矩阵的对称性, 即:

    \[[\mathbf{z}_t]_{i,j} = \left \{ \begin{array}{ll} [\mathbf{z}_t]_{i,j} & i < j, \\ [\mathbf{z}_{t}]_{j, i} & i \ge j. \end{array} \right . \]

  • 最后, 通过比如 \(\mathbb{I}_{\mathbf{\tilde{A}}_{i, j} > 0.5}\) 来得到离散的邻接矩阵.

  • 关于 \(\mathbf{s_{\theta}}\) 的模型是如何设计的, 请参考原文.

代码

official