RoPE

发布时间 2023-07-24 17:42:15作者: 馒头and花卷

目录

Su J., Lu Y., Pan S., Murtadha A., Wen B. and Liu Y. RoFormer: Enhanced transformer with rotary position embedding.

原作者的博客已经讲得非常到位了: [here] and [there].

RoPE

  • RoPE 是一种相对位置编码, 特点是它可以像绝对位置编码一样, 在 embedding 上操作后再进行 attention 的运算, 而不限定于在 score 矩阵上操作.

  • 具体的, 假设 \(\bm{x}_m, \bm{x}_n\) 为位置 \(m, n\) 上的两个 embedding, 令:

    \[\bm{z}_m := \mathbf{R}_m \mathbf{W}_q \bm{x}_m, \\ \bm{z}_n := \mathbf{R}_n \mathbf{W}_q \bm{x}_n, \\ \]

    \[\bm{z}_{m}^T \bm{z}_n \]

    就是吸收了相对位置信息 \((m-n)\) 的 score.

  • 其中 \(\mathbf{R} \in \mathbb{R}^{d \times d}\) 是旋转矩阵, 它作用在向量是相当于对两个两个维度地进行旋转. \(\theta_i = 10000^{-2i/d}\) 和最普通的 Sinusoidal 编码保持一致.

  • \(\mathbf{R}\bm{x}\) 有一种更加高效的方式:

  • 下面是 LLaMA 中的实现方式:


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # x: B, S, H, D
    # freqs_cis: S, D // 2
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # (B, S, H, D // 2)
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # (B, S, H, D // 2)
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # (1, S, 1, 1, D // 2)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)