Question Answering with BERT: Extract Answers from Text

Learn how BERT excels at question answering (QA) by extracting precise answers from text passages. Understand the question-paragraph input format and the span-extraction process, with examples.

Question Answering with BERT

In a question-answering (QA) task, the primary objective is to extract the precise answer to a given question from a provided text passage (paragraph). The input to a BERT model for QA typically consists of a question-paragraph pair. The model's goal is to identify the exact contiguous span of text within the paragraph that serves as the answer to the question.

Example

Consider the following input:

Question: "What is the immune system?"

Paragraph: "The immune system is a system of many biological structures and processes within an organism that protects against disease. To function properly, an immune system must detect a wide variety of agents, known as pathogens, from viruses to parasitic worms, and distinguish them from the organism’s own healthy tissue."

The expected answer extracted from the paragraph is: "a system of many biological structures and processes within an organism that protects against disease."
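To see what the model actually receives, here is a minimal sketch of how the question-paragraph pair is packed into a single input sequence, using the Hugging Face transformers library (the library and the bert-base-uncased checkpoint are illustrative assumptions, not requirements of the method):

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

question = "What is the immune system?"
paragraph = ("The immune system is a system of many biological structures and "
             "processes within an organism that protects against disease.")

# The tokenizer packs the pair as: [CLS] question [SEP] paragraph [SEP]
encoding = tokenizer(question, paragraph)
print(tokenizer.convert_ids_to_tokens(encoding["input_ids"]))
```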


How to Fine-Tune BERT for Question Answering

Fine-tuning BERT for question answering involves training the model to predict the start and end positions of the answer span within the paragraph. For the example question above, counting the paragraph's words from zero, the answer begins at index 4 (the word "a") and ends at index 18 (the word "disease"); once the model predicts these two positions, it can extract the correct answer span from the paragraph.
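As a hedged sketch of what one training step looks like with the Hugging Face transformers library: BertForQuestionAnswering adds a span-prediction head to BERT and, when gold start/end positions are supplied, returns a cross-entropy loss over the two position distributions. The checkpoint name is illustrative, and the gold positions are derived here from character offsets rather than hard-coded:

```python
import torch
from transformers import BertTokenizerFast, BertForQuestionAnswering

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")

question = "What is the immune system?"
paragraph = ("The immune system is a system of many biological structures and "
             "processes within an organism that protects against disease.")
answer = ("a system of many biological structures and processes within an "
          "organism that protects against disease")

inputs = tokenizer(question, paragraph, return_tensors="pt")

# Map the answer's character span to token indices in the paragraph segment.
char_start = paragraph.find(answer)
char_end = char_start + len(answer) - 1
start_positions = torch.tensor([inputs.char_to_token(0, char_start, sequence_index=1)])
end_positions = torch.tensor([inputs.char_to_token(0, char_end, sequence_index=1)])

# The model averages the cross-entropy losses on the start and end positions.
outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
outputs.loss.backward()  # one gradient step; the optimizer update is omitted
```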

Predicting Start and End Positions

To determine the start and end indices of the answer, BERT assigns probabilities to each token in the paragraph. These probabilities indicate how likely each token is to be the beginning or the end of the answer.

  • Start Vector: A learned vector used to compute the probability of each token being the start of the answer.
  • End Vector: A learned vector used to compute the probability of each token being the end of the answer.

For each token representation in the paragraph:

  1. Start Probability Calculation: Compute the dot product of the token's representation with the learned start_vector. Then, apply the softmax function to the resulting scores to obtain the probabilities for each token being the start of the answer.
  2. End Probability Calculation: Similarly, compute the dot product of the token's representation with the learned end_vector. Apply the softmax function to get the probabilities for each token being the end of the answer.

The token indices with the highest computed start and end probabilities are then identified as the predicted start and end points of the answer span.
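This computation can be sketched in a few lines of PyTorch; the shapes and the randomly initialized tensors below are stand-ins for what fine-tuning would actually learn:

```python
import torch
import torch.nn.functional as F

hidden_size = 768  # BERT-base hidden dimension
seq_len = 30       # number of tokens in the paragraph

# Contextualized token representations from BERT (random stand-ins here).
token_reprs = torch.randn(seq_len, hidden_size)

# Learned start and end vectors (randomly initialized before fine-tuning).
start_vector = torch.randn(hidden_size)
end_vector = torch.randn(hidden_size)

# Dot product of every token representation with each vector -> raw scores.
start_scores = token_reprs @ start_vector  # shape: (seq_len,)
end_scores = token_reprs @ end_vector      # shape: (seq_len,)

# Softmax over the sequence turns the scores into per-token probabilities.
start_probs = F.softmax(start_scores, dim=0)
end_probs = F.softmax(end_scores, dim=0)

# The highest-probability positions are the predicted span boundaries.
start_index = torch.argmax(start_probs).item()
end_index = torch.argmax(end_probs).item()
```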


Workflow Overview

The typical workflow for question answering with a fine-tuned BERT model is as follows:

  1. Tokenization: Tokenize the question and the paragraph together. This usually involves combining them with special tokens (e.g., [CLS] at the beginning, [SEP] between the question and paragraph, and a final [SEP] at the end).
  2. Model Input: Input the combined tokens, along with token_type_ids and attention_mask, into the pre-trained BERT model.
    • token_type_ids help distinguish between tokens belonging to the question and tokens belonging to the paragraph.
    • attention_mask tells the model which positions hold real tokens and which are padding.
  3. Embedding Extraction: Obtain the contextualized token embeddings from the output of the BERT model.
  4. Probability Computation: Utilize the learned start_vector and end_vector to compute the start and end probabilities for each token in the paragraph.
  5. Span Identification: Identify the token indices with the highest start and end probabilities.
  6. Answer Extraction: Reconstruct the answer by extracting the text span from the original paragraph that corresponds to the identified start and end token indices.

Diagrams of this process typically show the tokenization of the question-paragraph pair, the extraction of token embeddings, and the computation of start and end probabilities. The sketch below walks through the same steps in code.
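Here is a hedged end-to-end sketch of the workflow using the Hugging Face transformers library. The checkpoint below is one publicly available BERT model already fine-tuned on SQuAD; any similarly fine-tuned QA checkpoint would work the same way:

```python
import torch
from transformers import BertTokenizer, BertForQuestionAnswering

name = "bert-large-uncased-whole-word-masking-finetuned-squad"
tokenizer = BertTokenizer.from_pretrained(name)
model = BertForQuestionAnswering.from_pretrained(name)

question = "What is the immune system?"
paragraph = ("The immune system is a system of many biological structures and "
             "processes within an organism that protects against disease. "
             "To function properly, an immune system must detect a wide variety "
             "of agents, known as pathogens, from viruses to parasitic worms, "
             "and distinguish them from the organism's own healthy tissue.")

# Steps 1-2: tokenize the pair; token_type_ids and attention_mask are built
# automatically by the tokenizer.
inputs = tokenizer(question, paragraph, return_tensors="pt")

# Steps 3-4: run BERT; the QA head returns raw start and end scores (logits).
with torch.no_grad():
    outputs = model(**inputs)

# Step 5: pick the most likely start and end token indices.
start_index = torch.argmax(outputs.start_logits).item()
end_index = torch.argmax(outputs.end_logits).item()

# Step 6: reconstruct the answer text from the selected token span.
answer_ids = inputs["input_ids"][0][start_index : end_index + 1]
print(tokenizer.decode(answer_ids, skip_special_tokens=True))
```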


Technical Details

When BERT is fine-tuned for extractive question answering, a small span-prediction head is added on top of the encoder. For each token in the input sequence (in practice, only tokens in the paragraph segment are considered as candidates), two scores are predicted:

  • A score indicating the likelihood of that token being the start of the answer.
  • A score indicating the likelihood of that token being the end of the answer.

These scores are then passed through two separate softmax functions (one over all start scores, one over all end scores) to convert them into probabilities. To select a span, the model finds a start token index i and an end token index j (where i <= j) whose combined score, the start logit at i plus the end logit at j, is maximized; see the sketch below. The answer is then the sequence of tokens from index i to j in the original paragraph.
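A minimal sketch of that span search, including the max_answer_length cap commonly used in SQuAD-style post-processing (the cap value is an assumption, not part of BERT itself):

```python
def best_span(start_logits, end_logits, max_answer_length=30):
    """Return (i, j) with i <= j that maximizes start_logits[i] + end_logits[j]."""
    best_score, best = float("-inf"), (0, 0)
    for i in range(len(start_logits)):
        # Consider only ends at or after the start, within the length cap.
        for j in range(i, min(i + max_answer_length, len(end_logits))):
            score = float(start_logits[i]) + float(end_logits[j])
            if score > best_score:
                best_score, best = score, (i, j)
    return best
```

In practice this search is also restricted to paragraph tokens, so that the [CLS] token and question tokens can never be chosen as span boundaries.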


Key Concepts

  • Extractive QA: The task of extracting a contiguous span of text from a document that answers a question.
  • Start/End Token Prediction: The core mechanism by which BERT identifies the answer span.
  • Fine-tuning: Adapting a pre-trained language model like BERT to a specific downstream task (QA in this case) by training it on a task-specific dataset.
  • token_type_ids: Used to differentiate segments in the input, crucial for distinguishing the question from the context.
  • Softmax: Applied to convert raw scores into probabilities, allowing for the selection of the most likely start and end tokens.

Interview Questions

Here are some common interview questions related to BERT for Question Answering:

  • What is the main objective in a BERT-based question answering task?
  • How does BERT process a question-paragraph pair for QA?
  • What are start and end vectors in the context of BERT QA?
  • How does BERT predict the start and end position of an answer?
  • Why is softmax used in predicting answer spans in BERT?
  • What is the input format for BERT in a QA task?
  • How does BERT differentiate between question and context tokens?
  • What are the roles of token_type_ids and attention_mask in QA with BERT?
  • What is the significance of the highest start and end probabilities?
  • How is the extracted answer span reconstructed from token indices?