Paper Deepdives, Retrieval Augmented Generation (RAG), I - DrQA

Posted by Ibrahim Cikrikcioglu on October 12, 2023 · 10 mins read

1. Introduction

In this series of paper deep-dives, I aim to cover the Retrieval Augmented Generation (RAG) papers. This is because RAG is increasingly becoming the primary solution for preventing Large Language Models (LLMs) from hallucinating and for ensuring they generate factually grounded output. The paper I'm discussing today isn't strictly a generation model, but it's crucial for understanding the literature and observing the early developments. I'll strive to be as concise as possible.

Big Picture

In the RAG approach, instead of solely relying on the knowledge ingrained within the trained model, information or context is retrieved from external sources like Wikipedia or corporate documents. When the model responds to a query, it leverages this retrieved information to produce a more factually accurate output. The goal is to derive answers from the external information rather than generating potentially inaccurate responses based on its own ‘parameters’ or risking hallucination.

This paper, though relatively “old” in the rapidly evolving NLP landscape (especially after LLMs revolutionized the field), does not utilize a GPT-like model. Instead, the model performs span prediction. Given a paragraph containing the answer, the model identifies tokens likely to signify the start and end of the answer. The tokens in between these two points then form the model’s response.

The overall pipeline can be summarized as follows: First, the retriever component of the model receives a query (q) and searches Wikipedia to identify the top-5 most relevant articles. Each article is then split into paragraphs (p), which are passed to the DocumentReader segment of the model. After encoding both the query and each paragraph into vectors, the predictor component determines the start and end spans for every paragraph. Given that multiple paragraphs might correspond to a single question, the paragraph with the highest multiplicative span probability, $P_{start} \times P_{end}$, is selected. As a result, the model presents the corresponding excerpt from that paragraph. Naturally, there are numerous nuances, from feature engineering to optimal span identification, but these will be delved into in subsequent sections. This section merely provides a high-level overview of the prediction process.
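To make this flow concrete, here is a minimal Python sketch of the inference loop. The helper names (`retrieve_articles`, `split_into_paragraphs`, `score_spans`) are hypothetical placeholders, not the paper's actual API.

```python
def answer(question, k=5):
    """Hypothetical end-to-end DrQA-style inference loop (sketch only)."""
    articles = retrieve_articles(question, top_k=k)        # TF-IDF retriever over Wikipedia
    best_answer, best_score = None, float("-inf")
    for article in articles:
        for paragraph in split_into_paragraphs(article):
            # The reader returns the best span (start, end) in this paragraph
            # together with its score P(start) * P(end).
            start, end, score = score_spans(question, paragraph)
            if score > best_score:
                best_score = score
                best_answer = paragraph[start:end + 1]      # tokens from start to end, inclusive
    return best_answer
```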

Retriever

This component isn’t groundbreaking. The paper employs a traditional Information Retrieval (IR) method to pinpoint the documents most pertinent to a given query.

Inverted Index

The paper harnesses an inverted index to swiftly pull up relevant articles. For those unfamiliar with the concept, an inverted index is essentially a data structure optimized for document search. Every term is linked to a list of documents in which it appears. Thus, when you query the database, it examines each token in your query and swiftly retrieves the documents containing those tokens.
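As a toy illustration of the idea (plain Python dictionaries rather than a real search engine):

```python
from collections import defaultdict

# Build: map each term to the set of document ids that contain it.
index = defaultdict(set)
docs = {0: "open domain question answering",
        1: "wikipedia as a knowledge source",
        2: "question answering over wikipedia"}

for doc_id, text in docs.items():
    for term in text.split():
        index[term].add(doc_id)

# Query: intersect the posting lists of the query terms.
query = "question wikipedia"
candidates = set.intersection(*(index[t] for t in query.split()))
print(candidates)  # {2}
```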

TF-IDF

There’s a possibility that the initial retrieval phase brings up hundreds of articles, leaving us uncertain about their relevance. To further filter and refine our search, we enter the ranking phase of IR. With TF-IDF, we allocate varying scores to the tokens in a query and aggregate them. For instance, the term “retrieval” might crop up 10 times in document A and just once in document B. Our ranking mechanism should favor the document with more occurrences. This is represented by the term frequency (TF) element. Simultaneously, we ought to discount terms that are ubiquitously present across all documents. Consider the word “every”; it’s likely that many documents feature it, irrespective of their pertinence. Hence, this token shouldn’t significantly influence our relevance score. Given a term $i$ and a document $j$, the relevance score can be formulated as:

\[w_{i, j}=t f_{i, j} \times \log \left(\frac{N}{d f_i}\right)\]
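Here $tf_{i,j}$ is the count of term $i$ in document $j$, $N$ is the total number of documents, and $df_i$ is the number of documents containing term $i$. As a rough sketch of this scoring, assuming raw counts for term frequency (the paper's retriever uses its own TF-IDF variant):

```python
import math
from collections import Counter

def tfidf_score(query_terms, doc_tokens, df, n_docs):
    """Sum of tf * idf over the query terms: a simplified ranking score."""
    tf = Counter(doc_tokens)
    score = 0.0
    for term in query_terms:
        if df.get(term, 0) == 0:
            continue  # term never seen in the corpus, contributes nothing
        score += tf[term] * math.log(n_docs / df[term])
    return score
```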

To enhance the retrieval process, the paper also incorporates bigrams.

Bigrams and Hashing

Instead of solely relying on unigrams from the query, the paper integrates bigrams into the vocabulary. This ensures that specific phrases and word sequences are considered. However, with an extensive corpus like Wikipedia, the number of bigrams can be immense. To avoid ballooning the vocabulary size while preserving efficiency, the paper hashes each bigram into one of $2^{24}$ bins. Although collisions may occur, they are generally rare and don’t significantly impair the overall search system.
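A minimal sketch of the hashing trick; the paper mentions an unsigned murmur3 hash, for which Python's built-in `hash` stands in here:

```python
NUM_BINS = 2 ** 24  # fixed feature space for hashed bigrams

def bigram_bins(tokens):
    """Map each bigram to one of NUM_BINS buckets via hashing."""
    bigrams = zip(tokens, tokens[1:])
    # Collisions are possible but rare enough not to hurt retrieval much.
    return [hash(" ".join(bg)) % NUM_BINS for bg in bigrams]

print(bigram_bins(["retrieval", "augmented", "generation"]))
```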

With this, we’ve grasped the retrieval stage. Upon receiving a query, this stage fetches the top 5 most pertinent Wikipedia articles that could contain the answer. Next, let’s delve into how the query, paragraphs, and their tokens are encoded for subsequent predictions.

Encoding

DocumentReader

Given a query, represented as $q = \{q_1, q_2, \dots, q_l\}$, the system retrieves 5 articles. Each article is further subdivided into paragraphs, where each paragraph is a sequence of tokens $p = \{p_1, p_2, \dots, p_m\}$. The objective is to ascertain the probability of span-start and span-end for every token $p_i$ within a paragraph. This necessitates encoding both the query and the paragraph, which are sequences of tokens.

Document Encoding

Remember, we require an encoding at the token level. This means every token in the document should be encoded. The encoding for a token, $\hat{p}_{i}$, is the concatenation of the following:

  • GloVe embedding of the token in $\mathbb{R}^{300}$: $emb(p_i)$
  • Binary exact-match features: 1 if the token exactly matches any query token in one of three forms: original, lemma, or lowercase. For example, [1, 0, 0] means the token matches in its original form only: $match(p_i)$
  • Token features such as part-of-speech, Named Entity Recognition (NER) tags, and term frequency: $extra(p_i)$
  • Aligned question embedding: weighted average of all token embeddings in the query using a soft alignment mechanism:
\[align(p_i) = \sum_j a_{i, j} \mathbf{E}\left(q_j\right)\]

where

\[a_{i, j}=\frac{\exp \left(\alpha\left(\mathbf{E}\left(p_i\right)\right) \cdot \alpha\left(\mathbf{E}\left(q_j\right)\right)\right)}{\sum_{j^{\prime}} \exp \left(\alpha\left(\mathbf{E}\left(p_i\right)\right) \cdot \alpha\left(\mathbf{E}\left(q_{j^{\prime}}\right)\right)\right)}\]

Here, $\alpha$ is a single-layer NN with learnable parameters and a ReLU activation function. This step ensures that words which are semantically similar, but lexically different, are considered during prediction.
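A minimal NumPy sketch of this soft alignment, assuming a single shared projection matrix `W` and bias `b` for $\alpha$:

```python
import numpy as np

def aligned_question_embedding(P, Q, W, b):
    """P: (m, d) paragraph token embeddings, Q: (l, d) query token embeddings.
    W, b parameterize the shared single-layer projection alpha with ReLU."""
    alpha_p = np.maximum(P @ W + b, 0.0)                # (m, h)
    alpha_q = np.maximum(Q @ W + b, 0.0)                # (l, h)
    scores = alpha_p @ alpha_q.T                         # (m, l) dot products
    a = np.exp(scores - scores.max(axis=1, keepdims=True))
    a = a / a.sum(axis=1, keepdims=True)                 # softmax over question tokens
    return a @ Q                                          # (m, d) weighted average of query embeddings
```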

Therefore, $p_i$ is encoded as $[emb(p_i), match(p_i), extra(p_i), align(p_i)]$ and represented as $\hat{p}_{i}$.

That’s not the end of it, though. The encoding $\hat{p}_{i}$ is then passed through an additional RNN (a multi-layer bidirectional LSTM) so that each token receives a final, contextual embedding.
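A hedged PyTorch-style sketch of this step; the layer sizes are illustrative, and the paper's reader actually concatenates the hidden states of all LSTM layers rather than using only the top layer:

```python
import torch.nn as nn

class ParagraphEncoder(nn.Module):
    def __init__(self, feature_dim, hidden_dim=128):
        super().__init__()
        # A bidirectional LSTM turns per-token feature vectors into contextual embeddings.
        self.rnn = nn.LSTM(feature_dim, hidden_dim, num_layers=3,
                           bidirectional=True, batch_first=True)

    def forward(self, features):            # features: (batch, seq_len, feature_dim)
        outputs, _ = self.rnn(features)     # outputs: (batch, seq_len, 2 * hidden_dim)
        return outputs
```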

Question Encoding

Encoding the queries is more straightforward. We don’t require a final encoding at the token level; instead, the query is compressed into a single vector as a weighted sum of its per-token representations (the paper first runs another RNN over the GloVe embeddings and pools the resulting hidden states). For a query $q = \{q_1, q_2, \dots, q_l\}$, we determine:

\[\mathbf{q}=\sum_j b_j \mathbf{q}_j\]

where

\[b_j=\frac{\exp \left(\mathbf{w} \cdot \mathbf{q}_j\right)}{\sum_{j^{\prime}} \exp \left(\mathbf{w} \cdot \mathbf{q}_{j^{\prime}}\right)}\]

In this context, $\mathbf{w}$ is a learnable weight vector.
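A small NumPy sketch of this weighted pooling, where `Q` holds whatever per-token vectors are being combined:

```python
import numpy as np

def encode_question(Q, w):
    """Q: (l, d) per-token vectors; w: (d,) learnable weight vector.
    Returns a single query vector via a softmax-weighted sum."""
    logits = Q @ w                        # (l,) importance score per question token
    b = np.exp(logits - logits.max())
    b = b / b.sum()                       # softmax attention weights
    return b @ Q                          # (d,) pooled query vector
```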

Prediction

With the document and query encodings clarified, what remains is the prediction step. This is a span prediction task: we must identify the token in the paragraph where the answer starts and the token where it ends. One could simply score each paragraph token against the query encoding (say, via an exponentiated similarity) and pick the highest-scoring start and end tokens independently, but there are certain nuances to note:

  1. The index of the start token must not exceed that of the end token; otherwise the span is a logical contradiction!
  2. Given a query, there are multiple paragraphs and even articles to consider. Only one paragraph can be selected, along with its respective end and start indices.

Instead of utilizing plain cosine similarities, the probabilities are derived using bilinear terms with matrices $\mathbf{W}_{s}$ and $\mathbf{W}_{e}$ for the start and end index predictions, respectively.

\[P_{start}(i) \propto \exp \left(\mathbf{p}_i \mathbf{W}_s \mathbf{q}\right)\] \[P_{end}(i) \propto \exp \left(\mathbf{p}_i \mathbf{W}_e \mathbf{q}\right)\]

The paragraph with the peak value of $P_{\text {start }}(i) \times P_{\text {end }}\left(i^{\prime}\right)$ is selected, ensuring that $i \leq i^{\prime} \leq i+15$.
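A minimal sketch of this constrained span selection at inference time, taking the (unnormalized) start and end scores of a single paragraph:

```python
def best_span(start_scores, end_scores, max_len=15):
    """Pick (i, j) maximizing start_scores[i] * end_scores[j] subject to i <= j <= i + max_len."""
    best = (0, 0, float("-inf"))
    n = len(start_scores)
    for i in range(n):
        # Only consider end positions within max_len tokens of the start.
        for j in range(i, min(i + max_len, n - 1) + 1):
            score = start_scores[i] * end_scores[j]
            if score > best[2]:
                best = (i, j, score)
    return best  # (start index, end index, span score)
```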

It’s vital to understand that this is the inference phase. During training, the paragraphs known to contain the answer are fed directly, and two independent classifiers are trained: one for the start index and another for the end index.

While I won’t delve into their evaluations across multiple datasets, it’s worth noting that the Document Retriever isn’t trainable. For a more comprehensive analysis, I recommend perusing the paper directly.

Conclusion

In the ever-evolving landscape of Natural Language Processing, efficient retrieval and accurate prediction are paramount. The methodologies outlined in the discussed paper, from employing traditional IR techniques to leveraging advanced token-level encoding, underscore the significance of grounding machine predictions in substantive data. By integrating both unigrams and bigrams, and applying nuanced weighting and alignment mechanisms, the paper crafts a robust system that addresses the challenges of span prediction. Stay tuned for upcoming RAG papers that are more advanced.

Resources

  1. Chen, D., Fisch, A., Weston, J., & Bordes, A. (2017). Reading Wikipedia to Answer Open-Domain Questions. ACL 2017.