Collaborative Distillation for Top-N Recommendation

发布时间 2023-09-21 21:03:00作者: 馒头and花卷

Lee J., Choi M., Lee J. and Shim H. Collaborative distillation for top-N recommendation. ICDM, 2019.

Ranking-aware 的蒸馏.

符号说明

  • \(\mathcal{U} = \{u_1, \ldots, u_m\}\), users;
  • \(\mathcal{I} = \{i_1, \ldots, i_n\}\), items;
  • \(\mathbf{R} \in \{0, 1\}^{m \times n}\), 1 表示观测到的 (且大概率是 positive), 0 表示未观测到的 (可能是 positive 的, 也可能是 negative 的)
  • \(\mathcal{I}_u^+ = \{i \in \mathcal{I}| r_{ui} = 1\}\);
  • \(\mathcal{I}_u^- = \{i \in \mathcal{I}| r_{ui} = 0\}\);
  • \(M(u, i; \theta)\), ranking model.

Collaborative distillation (CD)

  • 本文的目标是解决针对推荐系统的蒸馏的问题: 主要是稀疏性问题. 具体来说, 由于在推荐系统中, 真正的 positive 的交互数量相较于那些未被观测到的实在是过于少了, 所以倘如我们采用一般的 (二元) 交叉熵损失来训练就会存在一定的问题:

    \[\mathcal{L}(\theta) = -\sum_{i \in \mathcal{I}_u^+} \log (P(r=1|u, i)) - \sum_{i \in \mathcal{I}_u^-} \log (1 - P(r=1|u, i)). \]

    其中 \(\hat{r}_{ui} = P(r=1|u, i)\) 表示 user \(u\) 中意 item \(i\) 的预测概率. 对于 \(i \in \mathcal{I}_u^+\), 这个是靠谱的, 但是对于 \(i \in \mathcal{I}_u^-\), 由于 \(0\) 即可能是 unlabled positive 交互, 也可能是 negative 的交互, 所以显得并不那么可靠.

  • 所以, 作者的思路是先训练一个教师网络 \(M_T\) (通过上述的损失), 然后利用如下损失再去训练一个学生网络:

    \[\mathcal{L}(\theta_S; \theta_T) = \mathcal{L}_{CF}(\theta_S) + \lambda \mathcal{L}_{KD} (\theta_S; \theta_T). \]

  • 这里

    \[\mathcal{L}_{CF}(\theta_S) = - \sum_{i \in \mathcal{I}_u^+} \log (\hat{r}_{ui}), \]

    仅仅针对 positive 的交互. 当然了, 仅仅这部分是不够的, 负样本的信息交由了 \(\mathcal{L}_{KD}\):

    \[\mathcal{L}_{KD}(\theta_S; \theta_T) = - \sum_{i \in S(\mathcal{I}_u^-)} \Big \{ q_{ui} \log (P(r = 1|u, i)) + (1 - q_{ui}) \log (1 - \hat{r}_{ui}) \Big \}. \]

    其中 \(S(\mathcal{I}_u^-)\) 表示在 \(\mathcal{I}_u^-\) 进行采样. \(q_{ui}\) 教师模型 \(M_T\) 输出的 logit 经过一些微调后的概率.

  • 首先一个问题是为什么需要采样, 作者认为`负样本'中不同位置的重要性是不同的, 一般来说越靠前的越重要 (说实话, 我不认为需要采样). 假设 item \(i\) 的重要性为 (仅针对 user \(u\)):

    \[\pi (i) = \frac{rank(i)}{|\mathcal{I}_u^-|}, \]

    我们有三种采样方式:

    1. uniform sampling;
    2. teacher guided, 比如 \(rank(i)\) 是按照教师模型的打分进行排名的, 然后采样概率满足

      \[p_i \propto 1 - \pi(i). \]

    3. student guided, 此时 \(rank(i)\) 是按照学生模型的打分进行排名的, 然后在采样. 这种方式被认为能够更加好地考虑到学生的感受 (当然了, 需要更多的一点的计算开销).