Rotary Position Embedding分析

发布时间 2023-11-07 21:52:10作者: liangyming

1 旋转角度计算

计算公式如下,其中d为词嵌入维度,

\[\theta_j=10000^{-2(j-1)/d},j\in [1,2,\ldots,d/2] \]

# 计算词向量元素两两分组之后,每组元素对应的旋转角度
# 维度:[dim / 2]
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))

2 计算整个seq的cos_sin矩阵

def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    # 计算词向量元素两两分组之后,每组元素对应的旋转角度
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # 生成 token 序列索引 t = [0, 1,..., seq_len-1]
    t = torch.arange(seq_len, device=freqs.device)
    # freqs.shape = [seq_len, dim // 2] 
    freqs = torch.outer(t, freqs).float()
    # torch.polar计算得到每个值的复数向量
    # 假设 freqs = [[x, ..., y]]
    # 则 freqs_cis = [[cos(x) + sin(x)i, ..., cos(y) + sin(y)i]]
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

3 计算旋转式位置编码

\[\begin{aligned}f_q(x_m,m)&=(W_qx_m)e^{im\theta} \\f_k(x_n,n)&=(W_kx_n)e^{in\theta}\end{aligned} \]

公式根据欧拉公式转化后为