Relational Knowledge Distillation

Published 2023-09-20 16:02:27 · Author: 馒头and花卷

Park W., Kim D., Lu Y. and Cho M. Relational knowledge distillation. CVPR, 2019.

Notation

  • \(f_T, f_S\), the teacher and student models;
  • \(\mathcal{X}^2 := \mathcal{X} \times \mathcal{X}\);
  • \(\mathcal{X}^3 := \mathcal{X} \times \mathcal{X} \times \mathcal{X}\);

RKD

  • IKD (Individual Knowledge Distillation): the distillation loss is

    \[\mathcal{L}_{IKD} = \sum_{x_i \in \mathcal{X}} \ell (f_T(x_i), f_S(x_i)), \]

    That is, samples are treated independently; there is no interaction between them.

  • RKD (Relational Knowledge Distillation): the distillation loss is

    \[\mathcal{L}_{RKD} = \sum_{(x_1, x_2, \ldots, x_N) \in \mathcal{X}^N} \ell (\psi(t_1, \ldots, t_N), \psi(s_1, \ldots, s_N)), \]

    That is, relations of order \(N\) among samples are distilled, where \(t_i = f_T(x_i), s_i = f_S(x_i)\).

  • The authors give two concrete forms of \(\psi\):

    1. distance-wise:

      \[\mathcal{L}_{RKD-D} = \sum_{(x_i, x_j) \in \mathcal{X}^2} \ell_{\delta} (\psi_D(t_i, t_j), \psi_D(s_i, s_j)) \\ \psi_D(t_i, t_j) = \frac{1}{\mu} \|t_i - t_j\|_2, \\ \mu = \frac{1}{|\mathcal{X}^2|} \sum_{(x_i, x_j) \in \mathcal{X}^2} \|t_i - t_j\|_2. \]

      That is, for both the teacher and the student, compute all pairwise distances, and encourage the distances of corresponding pairs to be close; see the code sketch after this list.

    2. angle-wise:

      \[\mathcal{L}_{RKD-A} = \sum_{(x_i, x_j, x_k) \in \mathcal{X}^3} \ell_{\delta} (\psi_A(t_i, t_j, t_k), \psi_A(s_i, s_j, s_k)), \\ \psi_A(t_i, t_j, t_k) = \cos \angle t_i t_j t_k = \langle \mathbf{e}^{ij}, \mathbf{e}^{kj} \rangle, \\ \mathbf{e}^{ij} = \frac{t_i - t_j}{\|t_i - t_j\|_2}, \mathbf{e}^{kj} = \frac{t_k - t_j}{\|t_k - t_j\|_2}. \]

      In effect, this requires the angles formed by corresponding triplets to match between teacher and student.

  • On the choice of \(\ell_{\delta}\):

    \[\ell_{\delta}(x, y) = \left \{ \begin{array}{ll} \frac{1}{2}(x - y)^2 & \text{ for } |x - y| \le 1, \\ |x - y| - \frac{1}{2} & \text{otherwise}, \end{array} \right . \]

    This is the Huber loss (with \(\delta = 1\)); a code sketch of both relational losses using it follows.
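
As a concrete reference, here is a minimal PyTorch sketch of the two losses above, written directly from the formulas in this post rather than taken from the official code. The names `pdist`, `rkd_distance_loss` and `rkd_angle_loss` are introduced here for illustration; `F.smooth_l1_loss` (default `beta = 1`) plays the role of \(\ell_{\delta}\), and batched embeddings `t` (teacher) and `s` (student) of shape `(N, D)` are assumed.

```python
import torch
import torch.nn.functional as F


def pdist(e, eps=1e-12):
    # Pairwise Euclidean distances between the rows of e; returns an (N, N) matrix.
    e_sq = e.pow(2).sum(dim=1)
    dist = (e_sq.unsqueeze(1) + e_sq.unsqueeze(0) - 2.0 * (e @ e.t())).clamp(min=eps).sqrt()
    dist = dist.clone()
    dist[range(len(e)), range(len(e))] = 0  # force exact zeros on the diagonal
    return dist


def rkd_distance_loss(t, s):
    # RKD-D: match pairwise distances, each normalized by its own mean distance mu.
    with torch.no_grad():                    # no gradients through the (frozen) teacher
        td = pdist(t)
        td = td / td[td > 0].mean()          # mu for the teacher
    sd = pdist(s)
    sd = sd / sd[sd > 0].mean()              # mu for the student
    return F.smooth_l1_loss(sd, td)          # Huber loss with delta = 1


def rkd_angle_loss(t, s):
    # RKD-A: match cos(angle t_i t_j t_k) = <e^{ij}, e^{kj}> over all triplets.
    with torch.no_grad():
        te = t.unsqueeze(0) - t.unsqueeze(1)   # te[j, i] = t_i - t_j, shape (N, N, D)
        te = F.normalize(te, p=2, dim=2)       # unit vectors e^{ij}
        t_angle = torch.bmm(te, te.transpose(1, 2)).view(-1)
    se = s.unsqueeze(0) - s.unsqueeze(1)
    se = F.normalize(se, p=2, dim=2)
    s_angle = torch.bmm(se, se.transpose(1, 2)).view(-1)
    return F.smooth_l1_loss(s_angle, t_angle)
```

In practice the two terms are added, with separate weights, to the student's task loss (the paper also evaluates such a combination); the weights are hyperparameters. Wrapping the teacher-side computations in `torch.no_grad()` simply reflects that the teacher is frozen during distillation.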

Code

[official]