Transformer Computation Formulas

Published: 2023-08-23 19:05:13, Author: 鸽鸽的书房

LLM inference workflow

Generative inference. A typical LLM generative inference task consists of two stages: i) the prefill stage, which takes a prompt sequence and generates the key-value cache (KV cache) for each transformer layer of the LLM; and ii) the decoding stage, which uses and updates the KV cache to generate tokens one step at a time, where each new token depends on the previously generated tokens.
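The two stages can be sketched as a short generation loop. This is only a toy illustration, not a real model: it assumes a single attention layer with no MLP, greedy token selection, and an output head tied to the embedding table. The names `generate`, `attend`, `embed`, and the random weights are all hypothetical.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attend(q_in, k, v, w_q):
    # Single-head attention with a residual connection (MLP omitted for brevity).
    q = q_in @ w_q
    return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v + q_in

def generate(prompt_ids, n_new, embed, w_q, w_k, w_v):
    # Stage i) prefill: process the whole prompt once and build the KV cache.
    x = embed[prompt_ids]                         # (s, h)
    k_cache, v_cache = x @ w_k, x @ w_v           # (s, h) each
    last = attend(x[-1:], k_cache, v_cache, w_q)  # output at the last prompt position
    out = [int(np.argmax(last @ embed.T))]        # greedy pick, tied output head (assumption)

    # Stage ii) decoding: one token per step, reusing and extending the cache.
    while len(out) < n_new:
        t = embed[out[-1]][None, :]               # (1, h) embedding of the latest token
        k_cache = np.concatenate([k_cache, t @ w_k])
        v_cache = np.concatenate([v_cache, t @ w_v])
        last = attend(t, k_cache, v_cache, w_q)
        out.append(int(np.argmax(last @ embed.T)))
    return out, k_cache

rng = np.random.default_rng(0)
vocab, h = 20, 8
embed = rng.standard_normal((vocab, h))
w_q, w_k, w_v = (rng.standard_normal((h, h)) * 0.1 for _ in range(3))

tokens, k_cache = generate([3, 1, 4], 4, embed, w_q, w_k, w_v)
print(len(tokens), k_cache.shape)  # 4 (6, 8)
```

The cache ends with 6 rows: 3 from the prompt plus one for each of the 3 decoding steps. The first new token comes directly out of the prefill stage, so it needs no decoding step of its own.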

prefill phase

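During the prefill phase, let \(\mathbf{x}^i\) denote the input of the \(i\)-th layer, computed over the whole prompt sequence, and let \(\mathbf{w}_K^i, \mathbf{w}_V^i, \mathbf{w}_Q^i, \mathbf{w}_O^i\) be that layer's attention weight matrices. The cached key and value are then computed by: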

\[\mathbf{x}_K^i=\mathbf{x}^i \cdot \mathbf{w}_K^i ; \quad \mathbf{x}_V^i=\mathbf{x}^i \cdot \mathbf{w}_V^i \]

The rest of the computation in the i-th layer is:

\[\begin{gathered} \mathbf{x}_Q^i=\mathbf{x}^i \cdot \mathbf{w}_Q^i \\ \mathbf{x}_{\text {Out }}^i=f_{\text {Softmax }}\left(\frac{\mathbf{x}_Q^i \left(\mathbf{x}_K^i\right)^{\top}}{\sqrt{h}}\right) \cdot \mathbf{x}_V^i \cdot \mathbf{w}_O^i+\mathbf{x}^i \\ \mathbf{x}^{i+1}=f_{\text {relu }}\left(\mathbf{x}_{\text {Out }}^i \cdot \mathbf{w}_1\right) \cdot \mathbf{w}_2+\mathbf{x}_{\text {Out }}^i \end{gathered} \]
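As a rough illustration, the prefill computation of one layer can be written in NumPy as follows. This is a minimal sketch that follows the formulas literally, so it has a single attention head, no causal mask, and no layer normalization. The function name `prefill_layer`, the sizes `b, s, h1, h2`, and the random weights are assumptions made for the example, and \(h\) in \(\sqrt{h}\) is taken to be the hidden dimension of the input.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def prefill_layer(x, w_q, w_k, w_v, w_o, w_1, w_2):
    """x has shape (b, s, h1). Returns the next-layer input and the KV cache."""
    h = x.shape[-1]
    x_k = x @ w_k                                          # cached key,   (b, s, h1)
    x_v = x @ w_v                                          # cached value, (b, s, h1)
    x_q = x @ w_q                                          # query,        (b, s, h1)
    scores = x_q @ np.swapaxes(x_k, -1, -2) / np.sqrt(h)   # (b, s, s)
    x_out = softmax(scores) @ x_v @ w_o + x                # attention + residual
    x_next = np.maximum(x_out @ w_1, 0.0) @ w_2 + x_out    # ReLU MLP + residual
    return x_next, x_k, x_v

rng = np.random.default_rng(0)
b, s, h1, h2 = 2, 5, 8, 32                                 # hypothetical sizes
x = rng.standard_normal((b, s, h1))
w_q, w_k, w_v, w_o = (rng.standard_normal((h1, h1)) * 0.1 for _ in range(4))
w_1 = rng.standard_normal((h1, h2)) * 0.1
w_2 = rng.standard_normal((h2, h1)) * 0.1

x_next, x_k, x_v = prefill_layer(x, w_q, w_k, w_v, w_o, w_1, w_2)
print(x_next.shape, x_k.shape, x_v.shape)  # (2, 5, 8) (2, 5, 8) (2, 5, 8)
```

The score matrix has shape `(b, s, s)`, which is why the key must be transposed over its last two axes before the product with the query.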

decode phase

During the decode phase, let \(\mathbf{t}^i \in \mathbb{R}^{b \times 1 \times h_1}\) be the embedding of the currently generated token at the \(i\)-th layer (batch size \(b\), one token, hidden dimension \(h_1\)). The inference computation then needs to i) update the KV cache:

\[\begin{aligned} & \mathbf{x}_K^i \leftarrow \text { Concat }\left(\mathbf{x}_K^i, \mathbf{t}^i \cdot \mathbf{w}_K^i\right) \\ & \mathbf{x}_V^i \leftarrow \text { Concat }\left(\mathbf{x}_V^i, \mathbf{t}^i \cdot \mathbf{w}_V^i\right) \end{aligned} \]

and ii) compute the output of the current layer:

\[\begin{gathered} \mathbf{t}_Q^i=\mathbf{t}^i \cdot \mathbf{w}_Q^i \\ \mathbf{t}_{\text {Out }}^i=f_{\text {Softmax }}\left(\frac{\mathbf{t}_Q^i \left(\mathbf{x}_K^i\right)^{\top}}{\sqrt{h}}\right) \cdot \mathbf{x}_V^i \cdot \mathbf{w}_O^i+\mathbf{t}^i \\ \mathbf{t}^{i+1}=f_{\text {relu }}\left(\mathbf{t}_{\text {Out }}^i \cdot \mathbf{w}_1\right) \cdot \mathbf{w}_2+\mathbf{t}_{\text {Out }}^i \end{gathered} \]
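The decode step can be sketched in the same style. The example below is a hedged illustration for a single layer, with hypothetical names (`decode_step`, `attention_mlp`, the weight dictionary `w`) and random weights. It appends the new token's key and value to the cache, computes the layer output, and then checks that the result matches recomputing the last position over the full sequence without a cache.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention_mlp(q_in, k, v, w_q, w_o, w_1, w_2):
    # Attention over the given keys/values, then a ReLU MLP, both with residuals.
    h = q_in.shape[-1]
    q = q_in @ w_q
    out = softmax(q @ np.swapaxes(k, -1, -2) / np.sqrt(h)) @ v @ w_o + q_in
    return np.maximum(out @ w_1, 0.0) @ w_2 + out

def decode_step(t, x_k, x_v, w):
    """t has shape (b, 1, h1); x_k and x_v are the cached key/value, (b, s, h1)."""
    # i) update the KV cache with the new token's key and value
    x_k = np.concatenate([x_k, t @ w["k"]], axis=1)
    x_v = np.concatenate([x_v, t @ w["v"]], axis=1)
    # ii) compute the output of the current layer for the new token only
    t_next = attention_mlp(t, x_k, x_v, w["q"], w["o"], w["1"], w["2"])
    return t_next, x_k, x_v

rng = np.random.default_rng(0)
b, s, h1, h2 = 2, 5, 8, 32                      # hypothetical sizes
w = {n: rng.standard_normal((h1, h1)) * 0.1 for n in "qkvo"}
w["1"] = rng.standard_normal((h1, h2)) * 0.1
w["2"] = rng.standard_normal((h2, h1)) * 0.1

x = rng.standard_normal((b, s, h1))             # prompt input to the layer
t = rng.standard_normal((b, 1, h1))             # newly generated token

x_k, x_v = x @ w["k"], x @ w["v"]               # KV cache produced by prefill
t_next, x_k, x_v = decode_step(t, x_k, x_v, w)

# Reference: recompute over the full sequence without a cache, keep the last position.
full = np.concatenate([x, t], axis=1)
ref = attention_mlp(full, full @ w["k"], full @ w["v"],
                    w["q"], w["o"], w["1"], w["2"])[:, -1:, :]
print(x_k.shape, np.allclose(t_next, ref))      # (2, 6, 8) True
```

The check holds here because the layer input is the same in both paths. In a stack of several layers the cached and uncached results agree only if the prefill attention is causally masked, which the formulas above leave implicit.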