Main Stages of Auto-regressive Decoding for LLM Inference
A Graphical Explanation of the Prefill Stage and the Decoding Stage
If you’ve used ChatGPT or other Large Language Models (LLMs), you may have noticed a brief delay before the model starts outputting its response word by word. This reflects the structure of auto-regressive inference: the prompt is processed in a single forward pass, but each output token requires its own decoding step. The initial delay after submitting a question is the prompt-processing time (the prefill stage), while the word-by-word generation of the answer is the decoding time (the decoding stage).
Most popular decoder-only LLMs, such as Llama, are pretrained with a causal language modeling objective: they are essentially next-token predictors. These LLMs take a sequence of tokens as input and generate subsequent tokens auto-regressively. Generation continues until a stopping condition is met, such as reaching a limit on the number of generated tokens, encountering a stop word from a predefined list, or generating a special <end> token that marks the end of the output.
The main content of this article is about the two main stages of auto-regressive decoding for LLM inference:
Prefill stage: This stage takes the prompt sequence and generates the key-value (KV) cache for each Transformer layer of the LLM.
Decoding stage: This stage utilizes and updates the KV cache to progressively generate tokens. The generation of each token depends on the previously generated tokens.
In general, the process of LLM generative inference with a KV cache is illustrated in Figure 1:
Prefill Stage: Processing the Prompt
In the prefill stage, as shown in Figure 1, the LLM processes all of the input tokens of the prompt to compute the intermediate states (keys and values) needed to generate the “first” new token. In other words, this stage computes and caches the keys and values of every prompt token for each layer; other intermediate results (such as queries and attention outputs) do not need to be cached. This cache is referred to as the KV cache, and it is a critical component throughout the following decoding process.
Below is a formal description of this stage (for the sake of simplicity, we only consider the case of a single head and omit the causal mask). The parameters, weight matrices, and intermediate variables are as follows:

$b$: batch size
$s$: sequence length of the input prompt
$h_1$: hidden dimension of the model
$x^i \in \mathbb{R}^{b \times s \times h_1}$: input of the $i$-th Transformer layer
$W_Q^i, W_K^i, W_V^i, W_O^i \in \mathbb{R}^{h_1 \times h_1}$: query, key, value, and output projection matrices of the $i$-th layer

Then, the cached key and value for the $i$-th layer can be computed by:

$$x_K^i = x^i \cdot W_K^i, \qquad x_V^i = x^i \cdot W_V^i$$

The remaining computation in the $i$-th layer is as follows:

$$x_Q^i = x^i \cdot W_Q^i$$

$$x_{out}^i = \mathrm{softmax}\!\left(\frac{x_Q^i \, (x_K^i)^\top}{\sqrt{h_1}}\right) \cdot x_V^i \cdot W_O^i + x^i$$

where $x_{out}^i$ is then passed through the feed-forward sublayer and becomes the input of the next layer.
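To make this concrete, here is a minimal PyTorch sketch of the prefill computation for one layer, under the same single-head simplification; the function name prefill_layer and the explicit causal-mask handling (omitted from the formulas above) are illustrative choices rather than part of any specific framework.

```python
import math
import torch

def prefill_layer(x, Wq, Wk, Wv, Wo):
    """Single-head attention layer over the full prompt (prefill).

    x:              [b, s, h1] embeddings of all prompt tokens
    Wq, Wk, Wv, Wo: [h1, h1]  projection matrices of this layer
    Returns the layer output and this layer's (K, V) cache.
    """
    k = x @ Wk                                   # [b, s, h1] -- cached
    v = x @ Wv                                   # [b, s, h1] -- cached
    q = x @ Wq                                   # [b, s, h1]

    h1 = x.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(h1)   # [b, s, s]
    # Causal mask: token t may only attend to tokens <= t.
    mask = torch.triu(
        torch.ones(scores.shape[-2:], dtype=torch.bool, device=x.device),
        diagonal=1,
    )
    scores = scores.masked_fill(mask, float("-inf"))

    out = torch.softmax(scores, dim=-1) @ v            # [b, s, h1]
    out = out @ Wo + x                                  # residual connection
    return out, (k, v)
```

Because q, k, and v cover all s prompt tokens at once, every product here is a dense matrix-matrix multiplication over the whole prompt.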
Since the entire input is known up front, these computations are, at a high level, highly parallel matrix-matrix operations, which make effective use of the GPU.
Decoding Stage: Generating the Output
In the decoding stage, as shown in Figure 1, the LLM generates output tokens one by one, so this stage is inherently sequential: it uses the previously generated token to produce the next token until the stopping condition is met. Each newly generated token needs to attend to the states (keys and values) of all previous tokens. To avoid recalculating $K$ and $V$ for all tokens in every iteration, most implementations store these values in the KV cache.
During the decoding stage, the embedding of the currently generated token at the $i$-th layer is $t^i \in \mathbb{R}^{b \times 1 \times h_1}$.
At this stage, two things need to be done (a code sketch follows the list):

1. Update the KV cache:

$$x_K^i \leftarrow \mathrm{Concat}\!\left(x_K^i,\; t^i \cdot W_K^i\right), \qquad x_V^i \leftarrow \mathrm{Concat}\!\left(x_V^i,\; t^i \cdot W_V^i\right)$$

2. Compute the output of the current layer:

$$t_Q^i = t^i \cdot W_Q^i$$

$$t_{out}^i = \mathrm{softmax}\!\left(\frac{t_Q^i \, (x_K^i)^\top}{\sqrt{h_1}}\right) \cdot x_V^i \cdot W_O^i + t^i$$
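Continuing the sketch from the prefill stage, a single decoding step for one layer might look like this in PyTorch; the decode_layer name and the simple torch.cat update of the cache are illustrative, and real inference engines typically preallocate the cache instead.

```python
import math
import torch

def decode_layer(t, k_cache, v_cache, Wq, Wk, Wv, Wo):
    """Single-head attention layer for one new token (decoding).

    t:                [b, 1, h1] embedding of the current token
    k_cache, v_cache: [b, n, h1] keys/values of all previous tokens
    Wq, Wk, Wv, Wo:   [h1, h1]  projection matrices of this layer
    Returns the layer output and the updated (K, V) cache.
    """
    # 1. Update the KV cache with the current token's key and value.
    k_cache = torch.cat([k_cache, t @ Wk], dim=1)             # [b, n+1, h1]
    v_cache = torch.cat([v_cache, t @ Wv], dim=1)             # [b, n+1, h1]

    # 2. Attend over the full cache with the single new query.
    q = t @ Wq                                                # [b, 1, h1]
    h1 = t.shape[-1]
    scores = q @ k_cache.transpose(-2, -1) / math.sqrt(h1)    # [b, 1, n+1]
    out = torch.softmax(scores, dim=-1) @ v_cache             # [b, 1, h1]
    out = out @ Wo + t                                        # residual connection
    return out, (k_cache, v_cache)
```

No causal mask is needed here, because the single query belongs to the newest token, which is allowed to attend to every cached position.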
From the above process, it is evident that the decoding stage performs essentially the same operations as the prefill stage. The difference lies in the input tensor shape: $[b, 1, h_1]$ in the decoding stage versus $[b, s, h_1]$ in the prefill stage, so the attention computation here is closer to a matrix-vector operation.
Unlike the prefill stage, the decoding stage does not fully utilize the computing power of the GPU.
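To show how the two stages fit together, here is a toy end-to-end generation loop built on the two sketches above; embed, lm_head, stop_token, and max_new_tokens are hypothetical stand-ins for the embedding lookup, output head, and sampling logic of a real model.

```python
def generate(prompt_embeddings, layers, embed, lm_head, stop_token, max_new_tokens=128):
    """Toy auto-regressive loop: one prefill pass, then repeated decoding steps.

    prompt_embeddings: [b, s, h1] embeddings of the prompt tokens
    layers:            list of per-layer weight tuples (Wq, Wk, Wv, Wo)
    embed:             hypothetical callable mapping token ids [b, 1] -> embeddings [b, 1, h1]
    lm_head:           hypothetical callable mapping hidden states [b, 1, h1] -> token ids [b, 1]
    """
    # Prefill: one pass over the whole prompt, building a KV cache per layer.
    x, kv_caches = prompt_embeddings, []
    for Wq, Wk, Wv, Wo in layers:
        x, kv = prefill_layer(x, Wq, Wk, Wv, Wo)
        kv_caches.append(kv)

    generated = []
    t = x[:, -1:, :]                         # the last prompt position drives the first new token
    for _ in range(max_new_tokens):
        next_id = lm_head(t)                 # pick the next token (greedy or sampled)
        if (next_id == stop_token).all():
            break
        generated.append(next_id)

        # Decoding: push only the new token through every layer, updating each KV cache.
        t = embed(next_id)                   # [b, 1, h1]
        for i, (Wq, Wk, Wv, Wo) in enumerate(layers):
            k_cache, v_cache = kv_caches[i]
            t, kv_caches[i] = decode_layer(t, k_cache, v_cache, Wq, Wk, Wv, Wo)
    return generated
```

Note the shapes: the prefill pass works on a [b, s, h1] tensor once, while every decoding step works on a [b, 1, h1] tensor, which is why the prefill stage saturates the GPU far more easily than the decoding stage.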
Conclusion
The benefit of dividing inference into these two stages is evident: the prefill stage fills the KV cache only once, while the decoding stage merely updates and looks up the KV cache. This avoids a large amount of redundant computation.
Finally, if there are any errors or omissions in this article, or if you have any thoughts, please point them out in the comments section.