DIFFormer Scalable (Graph) Transformers Induced by Energy Constrained Diffusion

发布时间 2023-06-08 13:45:08作者: 馒头and花卷

Wu Q., Yang C., Zhao W., He Y., Wipf D. and Yan J. DIFFormer: Scalable (graph) transformers induced by energy constrained diffusion. ICLR, 2023.

图的 diffusion 用于模型设计. 完成度相当高的一篇文章, 但是说实话感觉讲得太复杂了, 本来应该是一个很直观的东西.

符号说明

  • \(\mathcal{G} = (\mathcal{V, E})\), 图;
  • \(\mathcal{V} = \{\mathbf{x}_i\}_{i=1}^N\), nodes;
  • \(\mathcal{E} = \{e_{ij}\}\).

Difformer

  • 对于如下问题:

    \[\min_{\mathbf{Z}} \quad \text{trace}(\mathbf{Z}^T \mathbf{L}\mathbf{Z}) = \frac{1}{2} \sum_{ij} w_{ij} \|\mathbf{z}_j - \mathbf{z}_i\|_2^2. \]

    其中 \(\mathbf{L = I - A}\), \(\mathbf{A}_{ij} = w_{ij}\).

  • 它的一步迭代 (在 point \(\mathbf{Z}^{(k)}\)) 可以写作:

    \[\tag{1} \mathbf{Z}^{(k+1)} = \mathbf{Z}^{(k)} - \tau \mathbf{L} \mathbf{Z}^{(k)} \Rightarrow \mathbf{z}_i^{(k + 1)} = \mathbf{z}^{(k)} + \tau \sum_{j=1}^N w_{ij} (\mathbf{z}_j^{(k)} - \mathbf{z}_i^{(k)}), \]

    其中 \(\tau\) 为步长. 作者将这一步写作:

    \[\mathbf{z}_i^{(k + 1)} = (1 - \tau \sum_{j=1}^N w_{ij}) \mathbf{z}_i^{(k)} + \tau \sum_{j=1}^N w_{ij}\mathbf{z}_j^{(k)}, \]

    并理解为热扩散方差的一步 Euler 近似.

  • 但是, 作者希望在扩散的过程中, 保持 \(\mathbf{z}_i, \mathbf{z}_j\) 间的一些性质, 故而引入了 energy constraints:

    \[\tag{2} E(\mathbf{Z}, k; \delta) = \|\mathbf{Z} - \mathbf{Z}^{(k)}\|_{\mathcal{F}}^2 + \lambda \sum_{i, j}\delta(\|\mathbf{z}_i - \mathbf{z}_j\|_2^2), \]

    其中 \(\delta: \mathbb{R}^+ \rightarrow \mathbb{R}\) 为非降的凹函数.

  • 作者的目标就是, 在进行 (1) 的扩散时, 同时要求:

    \[(3) E(\mathbf{Z}, k;\delta) \le E(\mathbf{Z}, k-1; \delta), \quad k \ge 1 \]

    成立.

  • 作者是通过设计合适 \(w_{ij}, \delta\) 来满足这一条件的, 具体地, 可以通过如下定理满足: 对于任意 \(\lambda\), 当

    \[\hat{\mathbf{A}}_{ij} = \frac{\omega_{ij}}{\sum_{l=1}^N \omega_{il}}, \quad \omega_{ij} = \frac{\partial \delta(z^2)}{\partial z^2}|_{z^2 = \|\mathbf{z}_i - \mathbf{z}_j\|_2^2}. \]

    存在 \(0 < \tau < 1\) 使得 (1) 更新的同时有 (3) 条件满足,

  • 注意, 上面的 \(\mathbf{S}_{ij}\), 就是 \(A_{ij} = w_{ij}\).

  • 故而, 我们只需要拟合:

    \[\omega_{ij} = f(\|\mathbf{z}_i - \mathbf{z}_j\|_2^2), \]

    然后令

    \[w_{ij} = A_{ij} = \frac{\omega_{ij}}{\sum_{l=1}^N \omega_{il}}, \]

    然后通过如下迭代公式更新即可:

    \[\mathbf{Z}^{(k+1)} = (1 - \tau) \mathbf{Z}^{(k)} + \tau \underbrace{\mathbf{A}^{(k)}}_{{\mathbf{D}_{\Omega}^{(k)}}^{-1}\Omega^{(k)}} \mathbf{Z}^{(k)}. \]

  • \(f(\cdot)\) 需要是非负且递减的函数, 比如:

    \[f(\|\tilde{\mathbf{z}}_i - \tilde{\mathbf{z}}_j\|_2^2) = 1 + \tilde{\mathbf{z}}_i^T \tilde{\mathbf{z}}_j. \]

    又或者

    \[f(\|\tilde{\mathbf{z}}_i - \tilde{\mathbf{z}}_j\|_2^2) = \frac{1}{1 + \exp(-\tilde{\mathbf{z}}_i^T\tilde{\mathbf{z}}_j)}. \]

    注: 这里我们令 \(\tilde{\mathbf{z}} = \mathbf{z}/\|\mathbf{z}\|\), 故此时 squared \(\ell_2\) norm 等价于内积.

  • 我们可以这样引进可学习的参数:

    \[E(\mathbf{Z}, k; \delta) = \|\mathbf{Z} - h^{(k)}(\mathbf{Z}^{(k)})\|_{\mathcal{F}}^2 + \lambda \sum_{i, j}\delta(\|\mathbf{z}_i - \mathbf{z}_j\|_2^2), \]

    \(h^{(k)}\), 实际上是就是 GCN 中的 update 的部分.

  • 注意, 倘若我们有一些先验的知识, 比如图的结构, 邻接矩阵 \(\tilde{\mathbf{A}}\), 我们可以用 \(\Omega + \tilde{A}\) 来替换 \(\Omega\).

代码

official