Transformer Decoder: Linear & Softmax Layers Explained

Uncover the crucial role of the linear and softmax layers in the Transformer decoder for generating target-language sentences in LLMs and other AI models.

Linear and Softmax Layers in the Transformer Decoder

The Transformer decoder, after processing input through its various blocks, produces a contextualized representation of the target sentence. The final crucial step in generating the actual words of this sentence involves passing the output of the topmost decoder layer through a linear layer followed by a softmax function.

Role of the Linear Layer

The linear layer functions as a classifier. Its primary purpose is to transform the high-dimensional contextualized representation from the decoder's top layer into a logits vector.

  • Logits Vector: This vector has exactly one element per word in the target vocabulary; each element is an unnormalized score for that word.
  • Transformation: It essentially projects the decoder's output into a space where each dimension corresponds to a word in the vocabulary.

Example: If our target vocabulary consists of three words: ["hello", "world", "goodbye"], and the decoder outputs a dense vector of size 512 (common in Transformer models), the linear layer will transform this into a logits vector of size 3. For instance, the output might look like [2.5, 1.8, -0.9]. These are not probabilities yet, but scores indicating relative favorability for each word.
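A rough sketch of this projection in PyTorch is shown below. The hidden size of 512 and the three-word vocabulary are just the illustrative values from above, and `decoder_output` is a stand-in for the vector produced by the top decoder layer:

```python
import torch
import torch.nn as nn

d_model = 512      # hidden size of the decoder (illustrative)
vocab_size = 3     # ["hello", "world", "goodbye"] (illustrative)

# The final projection: one output unit per vocabulary word.
output_projection = nn.Linear(d_model, vocab_size)

# Stand-in for the contextualized vector from the top decoder layer.
decoder_output = torch.randn(d_model)

logits = output_projection(decoder_output)  # shape: (vocab_size,)
print(logits)  # unnormalized scores, e.g. something like [2.5, 1.8, -0.9]
```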

Role of the Softmax Function

The softmax function takes the logits vector generated by the linear layer and converts these unnormalized scores into a probability distribution.

  • Probability Distribution: For each logit score, the softmax function calculates a probability. The sum of all probabilities in the output vector will always be 1.
  • Likelihood: This probability distribution represents the likelihood of each word in the vocabulary being the next word in the generated sequence.

Example (Continuing from above): If the linear layer produced [2.5, 1.8, -0.9], applying the softmax function would yield a probability distribution of approximately [0.65, 0.32, 0.02]. This indicates that "hello" has about a 65% chance of being the next word, "world" about a 32% chance, and "goodbye" about a 2% chance.
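These numbers are easy to verify with a few lines of PyTorch (the logits are the illustrative values from the running example):

```python
import torch

logits = torch.tensor([2.5, 1.8, -0.9])  # scores from the linear layer
probs = torch.softmax(logits, dim=-1)    # exp(x_i) / sum_j exp(x_j)
print(probs)  # tensor([0.6536, 0.3246, 0.0218]) -- sums to 1
```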

In the simplest decoding strategy (greedy decoding), the decoder then selects the word with the highest probability as the predicted next token.

Predicting the Next Word: An Illustrative Process

Let's trace how the decoder predicts the next word, assuming it has already generated the start-of-sentence token <sos> and the word "Je" (e.g., in French translation).

  1. Decoder Output: The decoder processes the current input sequence (<sos>, Je) and generates a rich contextualized output vector from its top decoder layer.
  2. Linear Layer Transformation: This output vector is fed into the linear layer, which projects it into a logits vector of size equal to the vocabulary size.
  3. Softmax Application: The softmax function is applied to the logits vector, yielding a probability distribution over all possible next words in the vocabulary.
  4. Token Selection: If, for example, the word "vais" is at index 2 in the vocabulary and the probability at index 2 in the softmax output is the highest, the decoder predicts "vais" as the next word.

This iterative process of generating a vector, transforming it into logits, converting to probabilities via softmax, and selecting the highest-probability token is repeated. The newly predicted token is then appended to the input sequence for the next prediction step. This continues until the decoder generates the special <eos> (end-of-sentence) token, signaling the completion of the target sentence.
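A minimal sketch of this loop follows, assuming a hypothetical `decoder` callable that returns the top-layer outputs for the current token prefix, hypothetical `sos_id`/`eos_id` token indices, and the `output_projection` layer from the earlier sketch:

```python
import torch

def greedy_decode(decoder, output_projection, sos_id, eos_id, max_len=50):
    """Greedy decoding sketch: repeatedly pick the most probable next token."""
    tokens = [sos_id]                                # start with <sos>
    for _ in range(max_len):
        prefix = torch.tensor(tokens).unsqueeze(0)   # shape: (1, seq_len)
        hidden = decoder(prefix)                     # (1, seq_len, d_model), assumed interface
        logits = output_projection(hidden[:, -1, :]) # project last position -> (1, vocab_size)
        probs = torch.softmax(logits, dim=-1)        # probability distribution over the vocabulary
        next_id = int(torch.argmax(probs, dim=-1))   # select the most probable token
        tokens.append(next_id)
        if next_id == eos_id:                        # stop once <eos> is generated
            break
    return tokens
```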

Summary

  • The linear layer projects the decoder's internal representation to a vector whose size matches the target vocabulary.
  • The softmax function converts these raw scores (logits) into a normalized probability distribution over the entire vocabulary.
  • The word with the highest probability is selected as the decoder's predicted next token.

This mechanism is fundamental to how Transformer decoders achieve sequential word generation, powering tasks like machine translation, text summarization, and more.


In the subsequent sections, we will explore how all the decoder components are integrated to understand the complete operation of the Transformer decoder.