Exploring Medusa and Multi-Token Prediction
By Matthew Gunton

Speculative decoding was introduced as a way to speed up inference for an LLM. You see, LLMs are autoregressive, meaning we take the output token we just predicted and use it to help predict the next token. Typically we predict one token at a time (one token per forward pass of the neural network). However, because the attention pattern for the next token is very similar to the attention pattern for the previous one, we are repeating most of the same calculations while gaining relatively little new information.

With speculative decoding, rather than doing one forward pass per token, we try to get as many tokens as we can out of each forward pass. In general there are three steps to this (sketched in code after the list):

(1) Generate the candidates

(2) Process the candidates

(3) Accept certain candidates
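A minimal Python sketch of that loop, where draft_candidates, score_candidates, and accept_tokens are hypothetical stand-ins for the three steps (they are not part of any real library):

```python
# A minimal sketch of the generic speculative-decoding loop. The three callables
# passed in are hypothetical stand-ins for steps (1), (2), and (3).
def speculative_decode(prompt_ids, max_new_tokens,
                       draft_candidates, score_candidates, accept_tokens):
    tokens = list(prompt_ids)
    while len(tokens) - len(prompt_ids) < max_new_tokens:
        candidates = draft_candidates(tokens)            # (1) generate the candidates
        verified = score_candidates(tokens, candidates)  # (2) process them in one forward pass
        accepted = accept_tokens(candidates, verified)   # (3) keep the ones that pass
        tokens.extend(accepted)                          # at minimum, the model's own next token
    return tokens
```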

Medusa is a type of speculative decoding, and so its steps map directly onto these. Medusa appends decoding heads to the final layer of the model as its implementation of (1). Tree attention is how it processes the candidates for (2). Finally, Medusa uses either rejection sampling or a typical acceptance scheme to accomplish (3). Let's go through each of these in detail.

A decoding head takes the internal representation of the hidden state produced by a forward pass of the model and converts it into a probability for every token in the vocabulary. In essence, it maps what the model has learned into the probability distribution that determines the next token.
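For intuition, here is a minimal PyTorch sketch of a standard decoding head, just a linear projection onto the vocabulary followed by a softmax; hidden_size and vocab_size are placeholder values:

```python
import torch
import torch.nn as nn

hidden_size, vocab_size = 4096, 32000            # placeholder sizes
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

hidden_state = torch.randn(1, hidden_size)       # h_t from the last layer of the model
next_token_probs = torch.softmax(lm_head(hidden_state), dim=-1)
next_token = next_token_probs.argmax(dim=-1)     # greedy pick of the next token
```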

Medusa adjusts the architecture of a typical Transformer by appending multiple decoding heads to the last hidden layer of the model. By doing so, it can predict more than one token per forward pass. Each additional head predicts one token further into the future. So if you have 3 Medusa heads, you predict the first token from the forward pass itself, and then 3 more tokens after that with the Medusa heads. In the paper, the authors recommend using 5 heads, as they saw this gave the best balance between speed-up and quality.

To accomplish this, the authors of the paper propose the following decoding head for Medusa:
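Written out (following the explanation in the next paragraph), the k-th head computes:

$$
p_t^{(k)} = \mathrm{softmax}\!\left( W_2^{(k)} \left( \mathrm{SiLU}\!\left( W_1^{(k)} \cdot h_t \right) + h_t \right) \right)
$$

where $h_t$ is the model's hidden state for token $t$, and $W_1^{(k)}$, $W_2^{(k)}$ are the trained weight matrices of the $k$-th Medusa head.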

This equation gives us the probability of token t from the k-th head. We start by taking the weights we've learned through training the Medusa head, W1, and multiplying them by our internal state for token t. We pass the result through the SiLU activation function, which lets through only selective information (SiLU(x) = x * sigmoid(x)). We add the internal state a second time as a skip connection, which keeps the model performant by ensuring no information is lost when the SiLU gates part of the signal. We then multiply the sum by the second set of weights we've trained for the head, W2, and run that product through a softmax to get our probability distribution.
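Here is a sketch of that head in PyTorch (not the reference implementation, and with placeholder sizes), plus how several heads would be applied to the same hidden state:

```python
import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """Sketch of one Medusa decoding head: softmax(W2 · (SiLU(W1 · h_t) + h_t))."""
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, hidden_size, bias=False)  # W1
        self.w2 = nn.Linear(hidden_size, vocab_size, bias=False)   # W2
        self.act = nn.SiLU()                                       # SiLU(x) = x * sigmoid(x)

    def forward(self, h_t: torch.Tensor) -> torch.Tensor:
        residual = h_t + self.act(self.w1(h_t))        # skip connection around SiLU(W1 · h_t)
        return torch.softmax(self.w2(residual), dim=-1)

# With 3 Medusa heads, one forward pass yields the usual next token plus 3 speculative ones.
heads = nn.ModuleList([MedusaHead(4096, 32000) for _ in range(3)])
h_t = torch.randn(1, 4096)                             # last hidden state for the current token
speculative_probs = [head(h_t) for head in heads]      # one distribution per future position
```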

The first Medusa head gives the model probabilities to consider based on the forward pass, but each subsequent Medusa head needs to figure out which token it should pick based on what the prior Medusa heads chose.

Naturally, the more options the earlier Medusa heads put forward (the hyperparameter s_k), the more options future heads need to consider. For example, when we consider just the top two candidates from head 1 (s1 = 2) and the top three from head 2 (s2 = 3), we wind up with 2 × 3 = 6 different continuations we need to compute.
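A quick illustration in Python, using the head-1 tokens that come up later ("It", "I") and made-up tokens for head 2:

```python
from itertools import product

head1_candidates = ["It", "I"]                 # top s1 = 2 tokens from head 1
head2_candidates = ["is", "was", "will"]       # top s2 = 3 tokens from head 2 (made-up)

# Every combination is a continuation we have to verify: 2 x 3 = 6.
continuations = list(product(head1_candidates, head2_candidates))
print(len(continuations))                      # 6
```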

Due to this expansion, we would like to generate and verify these candidates as concurrently as possible.

The above matrix shows how we can run all of these calculations within the same batch via tree attention. Unlike typical causal self-attention, only tokens from the same continuation are treated as relevant to the attention pattern. As the matrix illustrates, this restriction lets us fit all of our candidates into one batch and run attention on them concurrently.

The challenge here is that each prediction needs to consider only the candidate tokens directly behind it on its own branch. In other words, if we choose "It" from head 1 and we are evaluating which token should come next, we do not want the attention pattern for "I" (the other head-1 candidate) being used for those tokens.

The authors avoid this kind of interference with a mask that keeps data about irrelevant tokens out of the attention calculation. This mask lets them stay memory efficient while they calculate the attention pattern and then use that information in the decoding heads to generate the subsequent token candidates.
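As an illustration (not the paper's exact implementation), here is one way such a tree-attention mask could be built for the 2 × 3 candidate tree from the earlier example, where each node may attend only to itself and its ancestors:

```python
import torch

# Node 0 is the current token; nodes 1-2 are the head-1 candidates ("It", "I");
# nodes 3-5 continue "It" and nodes 6-8 continue "I".
parents = [-1, 0, 0, 1, 1, 1, 2, 2, 2]   # parent index of each node in the candidate tree

n = len(parents)
mask = torch.zeros(n, n, dtype=torch.bool)
for i in range(n):
    j = i
    while j != -1:          # walk up the tree: attend to yourself and your ancestors only
        mask[i, j] = True
        j = parents[j]

# mask[i, j] == True means token i may attend to token j; candidates on different
# branches never attend to each other, so their attention patterns don't interfere.
```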

While the above matrix treats every prediction the same, if we have a probability for each prediction we can treat them differently based on how likely they are to be the best choice. The below tree visualizes just that.
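As a rough sketch of the idea (with made-up probabilities, and a much simpler rule than what the paper actually uses to build its tree), we could rank paths by their joint probability and only spend our verification budget on the most promising ones:

```python
from itertools import product

head1 = {"It": 0.6, "I": 0.4}                   # made-up head-1 probabilities
head2 = {"is": 0.5, "was": 0.3, "will": 0.2}    # made-up head-2 probabilities

# Joint probability of each two-token continuation.
paths = {(t1, t2): p1 * p2
         for (t1, p1), (t2, p2) in product(head1.items(), head2.items())}

budget = 4                                      # only verify the 4 most likely continuations
top_paths = sorted(paths, key=paths.get, reverse=True)[:budget]
print(top_paths)   # [('It', 'is'), ('I', 'is'), ('It', 'was'), ('It', 'will')]
```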
