The masked multi-head attention mechanism is a crucial component within the Transformer decoder architecture. It serves as the first sublayer in each decoder block and represents a modified version of multi-head self-attention. Its primary purpose is to prevent the model from attending to future tokens during the training phase, thereby ensuring that the decoder's behavior during training closely mirrors its behavior during inference. During inference, the decoder generates the target sequence token by token, and at any given time step, it only has access to the tokens generated up to that point.
Consider the task of translating an English sentence to French:
English: "I am good"
French: "Je vais bien"
During inference, the decoder generates the French sentence sequentially:
Time Step 1: The decoder receives the <sos> (start-of-sequence) token and predicts the first word, "Je".
Time Step 2: The decoder receives <sos> and the predicted word "Je", and predicts the next word, "vais".
Time Step 3: The decoder receives <sos>, "Je", and the predicted word "vais", and predicts the next word, "bien".
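The time steps above can be sketched as a short greedy decoding loop. This is a minimal illustration, not a real model: `toy_decoder_step` and its lookup table `TOY_NEXT_TOKEN` are hypothetical stand-ins for a trained decoder, and the `<eos>` (end-of-sequence) token used to stop generation is an assumption not described above. The loop only shows what the decoder is allowed to see at each step.

```python
# Hypothetical stand-in for a trained decoder: maps the tokens generated so far
# to the next token. A real decoder would compute this from the prefix and the
# encoder output; here it is hard-coded for "I am good" -> "Je vais bien".
TOY_NEXT_TOKEN = {
    ("<sos>",): "Je",
    ("<sos>", "Je"): "vais",
    ("<sos>", "Je", "vais"): "bien",
    ("<sos>", "Je", "vais", "bien"): "<eos>",  # assumed end-of-sequence token
}


def toy_decoder_step(prefix):
    """Return the next token given only the tokens generated so far."""
    return TOY_NEXT_TOKEN[tuple(prefix)]


def greedy_decode(max_steps=10):
    """Generate one token per time step, feeding each prediction back in."""
    tokens = ["<sos>"]
    decoder_inputs = []  # what the decoder saw at each time step
    for _ in range(max_steps):
        decoder_inputs.append(list(tokens))
        next_token = toy_decoder_step(tokens)
        if next_token == "<eos>":
            break
        tokens.append(next_token)
    return tokens[1:], decoder_inputs


output, inputs_per_step = greedy_decode()
for step, seen in enumerate(inputs_per_step, start=1):
    print(f"Time Step {step}: decoder sees {seen}")
print("Output:", " ".join(output))  # Output: Je vais bien
```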
During training, we have access to the entire target sentence ("Je vais bien"). However, to accurately simulate the inference process, we must prevent the decoder from "seeing" future tokens. Masking achieves this by artificially obscuring tokens that appear later in the sequence. This ensures that when the decoder predicts a specific token, it relies solely on the context of previously generated (or available) tokens, mirroring the real-world generation scenario.
Instead of feeding the complete target sentence directly into the decoder, the input sequence is prepared as follows:
Prepend <sos>: The <sos> token is added to the beginning of the target sentence.
Modified Sequence: For the example "Je vais bien", the input becomes: <sos> Je vais bien.
Embedding and Positional Encoding: This modified sequence is then embedded and combined with positional encoding. The resulting matrix serves as the input to the decoder's first sublayer: the masked multi-head attention.
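A minimal sketch of this input preparation, assuming the common teacher-forcing setup: the decoder input is the target sentence shifted one position to the right (hence <sos> at the front), and the prediction labels are the sentence followed by an assumed `<eos>` token. It also assumes the sinusoidal positional encoding of the original Transformer, added element-wise to the embeddings. The embedding table is random and purely hypothetical, and `shift_right` and `sinusoidal_positional_encoding` are illustrative names, not a library API.

```python
import math
import random


def shift_right(target_tokens, sos="<sos>", eos="<eos>"):
    """Decoder input: <sos> + target. Labels: target + <eos> (assumed token)."""
    decoder_input = [sos] + target_tokens
    labels = target_tokens + [eos]
    return decoder_input, labels


def sinusoidal_positional_encoding(seq_len, d_model):
    """sin on even dimensions, cos on odd ones, as in the original Transformer."""
    pe = []
    for pos in range(seq_len):
        row = []
        for j in range(d_model):
            # Dimensions 2i and 2i+1 share the frequency 1 / 10000^(2i / d_model).
            angle = pos / (10000 ** ((j - j % 2) / d_model))
            row.append(math.sin(angle) if j % 2 == 0 else math.cos(angle))
        pe.append(row)
    return pe


decoder_input, labels = shift_right(["Je", "vais", "bien"])
d_model = 4  # toy embedding width (assumption)

# Hypothetical embedding table: random vectors standing in for learned embeddings.
rng = random.Random(0)
vocab = sorted(set(decoder_input + labels))
embedding = {tok: [rng.uniform(-1, 1) for _ in range(d_model)] for tok in vocab}

# Input to the masked multi-head attention: embedding + positional encoding.
pe = sinusoidal_positional_encoding(len(decoder_input), d_model)
X = [[e + p for e, p in zip(embedding[tok], pe[pos])] for pos, tok in enumerate(decoder_input)]

print(decoder_input)  # ['<sos>', 'Je', 'vais', 'bien']
print(labels)         # ['Je', 'vais', 'bien', '<eos>']
```

Under teacher forcing, the decoder's output at the position of <sos> is trained to predict "Je", the output at "Je" to predict "vais", and so on, which is exactly why the mask must stop each position from seeing the token it is supposed to predict.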
The masked multi-head attention operates in a manner similar to standard multi-head self-attention, but with an added masking step to enforce the autoregressive property. The process involves the following steps:
For each of the h attention heads, the input matrix (which is the embedded sequence with positional encodings) is linearly transformed into three distinct matrices: Query (Q), Key (K), and Value (V).
Q = X ⋅ W_Q
K = X ⋅ W_K
V = X ⋅ W_V
Where:
X is the input matrix.
W_Q, W_K, and W_V are learned weight matrices specific to each attention head.
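The per-head projections are ordinary matrix products. The sketch below uses plain Python lists so it runs without extra packages; the sizes (`seq_len`, `d_model`, `d_k`) are toy assumptions, `X` is a random stand-in for the embedded input, and the random weights only stand in for the learned W_Q, W_K, and W_V of a single head.

```python
import random


def matmul(a, b):
    """Multiply an (n x m) matrix by an (m x p) matrix, both stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(42)
seq_len, d_model, d_k = 4, 8, 4  # toy sizes (assumptions)

X = random_matrix(seq_len, d_model, rng)  # stand-in for embeddings + positional encodings
W_Q = random_matrix(d_model, d_k, rng)    # random placeholders for learned weights
W_K = random_matrix(d_model, d_k, rng)
W_V = random_matrix(d_model, d_k, rng)

Q = matmul(X, W_Q)  # one query vector per token
K = matmul(X, W_K)  # one key vector per token
V = matmul(X, W_V)  # one value vector per token
print(len(Q), len(Q[0]))  # 4 4
```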
The next step computes the attention scores and applies the mask.
Calculate Attention Scores: The dot product between the Query (Q) and the transpose of the Key (Kᵀ) matrices is computed.
Scores = Q ⋅ Kᵀ
Scale Scores: The scores are divided by the square root of the key-vector dimension (√d_k). For large d_k the dot products grow large in magnitude and push the softmax into regions with very small gradients; scaling counteracts this and stabilizes training.
Scaled_Scores = Scores / √d_k
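Scoring and scaling can be checked by hand on a tiny example. The `Q` and `K` values below are made up for illustration (three tokens, d_k = 4), and `scaled_scores` is a hypothetical helper name.

```python
import math


def transpose(m):
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def scaled_scores(Q, K):
    """Compute Q . K^T / sqrt(d_k), where d_k is the width of the key vectors."""
    d_k = len(K[0])
    scores = matmul(Q, transpose(K))
    return [[s / math.sqrt(d_k) for s in row] for row in scores]


# Toy 3-token example with d_k = 4 (values chosen by hand).
Q = [[1, 0, 1, 0], [0, 1, 0, 1], [1, 1, 1, 1]]
K = [[1, 0, 1, 0], [0, 1, 0, 1], [1, 1, 0, 0]]
S = scaled_scores(Q, K)
print(S)  # [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [1.0, 1.0, 1.0]]
```

Row i holds the scores of query position i against every key position, including positions after i; the masking step decides which of those a row may keep.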
Apply Masking: Before applying the softmax function, a mask is applied to the Scaled_Scores. This mask sets the scores corresponding to future tokens to negative infinity (-∞). This ensures that when the softmax is applied, the attention weights for these masked positions become zero.
Example: When predicting the word at position i, the decoder should only attend to tokens from position 1 up to i. Tokens at positions i+1, i+2, etc., are masked.
The mask therefore fills the strictly upper triangular part of the score matrix (every entry above the main diagonal) with -∞, while the diagonal and everything below it are left unchanged.
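One common way to implement this masking is additively: build a matrix that holds 0.0 at allowed positions and -inf at future positions, then add it to the scaled scores (adding -inf to any finite score gives -inf). A minimal sketch, with `causal_mask` and `apply_mask` as illustrative names and the scores taken from a hand-made 3-token example:

```python
def causal_mask(n):
    """0.0 where key position j <= query position i (allowed), -inf where j > i (future)."""
    return [[0.0 if j <= i else float("-inf") for j in range(n)] for i in range(n)]


def apply_mask(scores, mask):
    """Add the mask element-wise; entries above the diagonal become -inf."""
    return [[s + m for s, m in zip(s_row, m_row)] for s_row, m_row in zip(scores, mask)]


for row in causal_mask(4):
    print(row)
# [0.0, -inf, -inf, -inf]
# [0.0, 0.0, -inf, -inf]
# [0.0, 0.0, 0.0, -inf]
# [0.0, 0.0, 0.0, 0.0]

scaled = [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [1.0, 1.0, 1.0]]  # hand-made scaled scores
masked = apply_mask(scaled, causal_mask(3))
print(masked)  # [[1.0, -inf, -inf], [0.0, 1.0, -inf], [1.0, 1.0, 1.0]]
```

The additive form is convenient because the same mask matrix can be reused for every head and every sequence of the same length.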
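Softmax: The softmax function is applied row by row to the masked scores, turning each row into a probability distribution: the attention weights of one query position over all key positions.
Attention_Weights = softmax(Scaled_Scores + M)
where M is the mask matrix, 0 at allowed positions and -∞ at future positions.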
Due to the masking, positions that were assigned -∞ will now have an attention weight of 0.
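A minimal masked-softmax sketch, reusing hand-made scaled scores for a 3-token example. Subtracting each row's maximum before exponentiating is a standard numerical-stability trick, not something described above; it is safe here because the diagonal is never masked, so every row has a finite maximum, and `math.exp(-inf)` evaluates to exactly 0.0. The helper names are illustrative.

```python
import math


def softmax_row(row):
    """Numerically stable softmax; exp(-inf) evaluates to 0.0."""
    m = max(row)  # finite as long as at least one position is unmasked
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def masked_softmax(scaled_scores):
    """Set future positions (j > i) to -inf, then apply softmax row by row."""
    masked = [
        [s if j <= i else float("-inf") for j, s in enumerate(row)]
        for i, row in enumerate(scaled_scores)
    ]
    return [softmax_row(row) for row in masked]


scaled = [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [1.0, 1.0, 1.0]]
weights = masked_softmax(scaled)
for row in weights:
    print([round(w, 3) for w in row])
# [1.0, 0.0, 0.0]
# [0.269, 0.731, 0.0]
# [0.333, 0.333, 0.333]
```

The first position can only attend to itself, so its weight is 1.0; every weight above the diagonal is exactly 0.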
Weighted Sum: The attention weights are multiplied by the Value (V) matrix to produce the output of each attention head. Each row of the result is a weighted sum of the Value vectors, using that row's attention weights.
Head_Output = Attention_Weights ⋅ V
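The weighted sum is one more matrix product. The attention weights below are chosen by hand (each row sums to 1, zeros above the diagonal) rather than computed, and `V` is a toy 3 x 2 value matrix; both are illustrative assumptions.

```python
def matmul(a, b):
    """(n x m) @ (m x p) for matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


# Hand-picked masked attention weights for a 3-token sequence.
weights = [[1.0, 0.0, 0.0],
           [0.25, 0.75, 0.0],
           [0.5, 0.25, 0.25]]
# Toy Value vectors, one row per token (d_v = 2).
V = [[1.0, 0.0],
     [0.0, 1.0],
     [2.0, 2.0]]

head_output = matmul(weights, V)
print(head_output)  # [[1.0, 0.0], [0.25, 0.75], [1.0, 0.75]]
```

The first output row ignores the third value vector [2.0, 2.0] entirely: because of the mask, position 1 can only draw on itself.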
The process described above is repeated for all h attention heads.
Concatenate Outputs: The output matrices from all h attention heads are concatenated along the feature dimension.
Linear Projection: This concatenated output is then passed through a final linear transformation using a learned projection matrix W_O.
Final_Output = Concatenated_Outputs ⋅ W_O
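Putting the steps together, a compact sketch of the whole sublayer might look like the following. It assumes plain Python lists, random weights in place of learned ones, and the common (but not mandatory) choice d_k = d_model / h; `masked_attention_head` and `masked_multi_head_attention` are hypothetical names for illustration, not a library API.

```python
import math
import random


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax_row(row):
    m = max(row)  # finite: the diagonal is never masked
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def masked_attention_head(X, W_Q, W_K, W_V):
    """One head: project, score, scale, mask future positions, softmax, weight V."""
    Q, K, V = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)
    d_k = len(K[0])
    scores = matmul(Q, transpose(K))
    masked = [
        [s / math.sqrt(d_k) if j <= i else float("-inf") for j, s in enumerate(row)]
        for i, row in enumerate(scores)
    ]
    weights = [softmax_row(row) for row in masked]
    return matmul(weights, V)


def masked_multi_head_attention(X, heads, W_O):
    """heads is a list of (W_Q, W_K, W_V) tuples, one per attention head."""
    outputs = [masked_attention_head(X, *w) for w in heads]
    # Concatenate the head outputs along the feature dimension, row by row.
    concatenated = [[x for head_out in outputs for x in head_out[i]] for i in range(len(X))]
    return matmul(concatenated, W_O)  # final linear projection


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
seq_len, d_model, h = 4, 8, 2  # toy sizes (assumptions)
d_k = d_model // h
X = random_matrix(seq_len, d_model, rng)
heads = [tuple(random_matrix(d_model, d_k, rng) for _ in range(3)) for _ in range(h)]
W_O = random_matrix(h * d_k, d_model, rng)

out = masked_multi_head_attention(X, heads, W_O)
print(len(out), len(out[0]))  # 4 8
```

Projecting back to `d_model` with W_O keeps the sublayer's output the same shape as its input, so it can be passed straight to the next sublayer.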
This Final_Output is the result of the masked multi-head attention sublayer and is then passed to the subsequent sublayer in the decoder block, which typically involves cross-attention to the encoder's output.
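Prevents Data Leakage: Ensures that the model cannot cheat by looking at future tokens during training. Without the mask, the decoder could simply copy the next token from its input, producing a misleadingly low training loss and a model that fails at inference, when those tokens are not yet available.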
Mimics Inference Conditions: By forcing the model to predict based only on past tokens, it accurately replicates how the decoder will function during actual generation.
Maintains Sequence Integrity: Crucial for tasks involving sequential data generation, like language translation or text summarization, to ensure the output is coherent and follows the correct order.
Enables Autoregressive Behavior: Teaches the decoder to learn the inherent sequential dependencies in the target data, allowing it to generate sequences one token at a time in a meaningful way.
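The data-leakage point can be checked directly: with the causal mask, changing only the last token must leave the outputs at earlier positions untouched, while without the mask those outputs shift. The sketch below is a single-head illustration with random, hypothetical weights; `attention` and `rows_close` are illustrative helpers, not a library API.

```python
import math
import random


def matmul(a, b):
    """(n x m) @ (m x p) for matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax_row(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(X, W_Q, W_K, W_V, causal):
    """Scaled dot-product attention for one head, optionally with the causal mask."""
    Q, K, V = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)
    d_k = len(K[0])
    weights = []
    for i, row in enumerate(matmul(Q, transpose(K))):
        scaled = [s / math.sqrt(d_k) for s in row]
        if causal:
            scaled = [s if j <= i else float("-inf") for j, s in enumerate(scaled)]
        weights.append(softmax_row(scaled))
    return matmul(weights, V)


def rows_close(a, b, tol=1e-12):
    return all(
        math.isclose(x, y, abs_tol=tol)
        for row_a, row_b in zip(a, b)
        for x, y in zip(row_a, row_b)
    )


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(1)
d_model, d_k = 6, 3  # toy sizes (assumptions)
W_Q, W_K, W_V = (random_matrix(d_model, d_k, rng) for _ in range(3))
X = random_matrix(4, d_model, rng)
X_changed = [row[:] for row in X]
X_changed[3] = [x + 5.0 for x in X_changed[3]]  # alter only the last ("future") token

results = {}
for causal in (True, False):
    before = attention(X, W_Q, W_K, W_V, causal)
    after = attention(X_changed, W_Q, W_K, W_V, causal)
    results[causal] = rows_close(before[:3], after[:3])
    print(f"causal={causal}: positions 1-3 unchanged? {results[causal]}")
```

With the mask, the first line should report True; once the mask is removed, the earlier positions attend to the altered token and the second line reports False.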
Masked multi-head attention is a foundational mechanism in Transformer decoders, essential for learning autoregressive sequence generation. By strategically masking future tokens during training, it bridges the gap between training and inference, ensuring that the model learns to predict tokens based solely on preceding context. This adherence to causal masking is what enables the decoder to produce coherent and contextually relevant sequences, token by token.
The output of this sublayer is then fed into the next stage of the decoder, which typically involves multi-head attention that attends to the output of the encoder, allowing the decoder to leverage information from the source sequence.