Time Interval Aware Self-Attention for Sequential Recommendation

发布时间 2023-06-20 13:39:15作者: 馒头and花卷

Li J., Wang Y., McAuley J. Time interval aware self-attention for sequential recommendation. WSDM, 2020.

本文介绍了一种更好利用时间戳信息的方式, 引入相对位置编码.

符号说明

Motivation

  • 无论是 GRU4Rec, 还是 SASRec 等, 他们处理序列 \(s=[s_1, s_2, \ldots, s_n]\), 仅仅是利用 \(s_i, s_j\) 的前后顺序关系, 但是, 模型是无法知道两个交互 \(s_i, s_j\) 的时间间隔.

  • 如上图所示, 即便两个用户的以往的交互序列是一致的, 但是倘若发生的时间间隔不同, 我们应该推荐的下一个 item 也应该有所差异 (在某一时间段内频繁观看一种类型的电影, 推荐的下一个 item 可能也应该更加关注短期偏好, 否则更应该关注一点长期兴趣).

TiSASRec

  • 首先, 和 SASRec 一样, 给定序列 \(s = [s_1, s_2, \ldots, s_n]\) 我们可以获得它的 embedding:

    \[\mathbf{E}^I = [m_{s_1}, m_{s_2}, \ldots, m_{s_n}] \in \mathbb{R}^{n \times d}. \]

  • 接着, 不同于 SASRec, 作者引入两个可学习的绝对位置编码:

    \[\mathbf{M}_k^P \in \mathbb{R}^{n \times d}, \mathbf{M}_v^{P} \in \mathbb{R}^{n \times d}, \]

    分别作用于 attention 中的 keys 和 values.

  • 此外, 为了让模型感知到时间间隔, 作者建立:

    \[\mathbf{M}^u = \left[ \begin{array}{cccc} r_{11}^u & r_{12}^u & \ldots & r_{1n}^u \\ r_{21}^u & r_{22}^u & \ldots & r_{2n}^u \\ \vdots & \vdots & \ddots & \vdots \\ r_{n1}^u & r_{n2}^u & \ldots & r_{nn}^u \\ \end{array} \right], \]

    其中

    \[r_{ij}^u = \lfloor \frac{|t_i - t_j|}{r_{min}^u} \rfloor, r_{min}^u = \min_{ij} |t_i - t_j|. \]

  • 此外, 在实际中, 作者将 \(\mathbf{M}^u\) 中元素超过 \(k\) 的元素截断为 \(k\), 即

    \[\mathbf{M}_{clipped}^u = clip(M^u, k): r_{ij}^u \leftarrow \min(k, r_{ij}^u). \]

  • 对于 \(r \in \{1, 2, \ldots, k\}\) 建立 embedding, 我们可以得到:

    \[\mathbf{M}_K^R \in \mathbb{R}^{k \times d}, \mathbf{M}_V^R \in \mathbb{R}^{k \times d}. \]

  • 对于长度为 \(n\) 的序列, 由此得到:

    \[\mathbf{E}_K^R \in \mathbb{R}^{n \times n \times d}, \mathbf{E}_V^R \in \mathbb{R}^{n \times n \times d} \]

    两组 embedding.

  • 作者是在注意力阶段引入位置编码的:

    \[z_i = \sum_{j=1}^n \alpha_{ij} (W^V m_{s_j} + r_{ij}^v + p_j^v) \in \mathbb{R}^d, \\ \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^n \exp(e_{ik})}, \\ e_{ij} = \frac{(W^Q m_{s_i})^T (W^K m_{s_j} + r_{ij}^k + p_j^k)}{\sqrt{d}}. \]

    注: 这里我们省略了 mask.

  • 其它的部分和普通的 transformer 的架构是一致的.

代码

[official]

[PyTorch]