A Statistical Framework for Data-dependent Retrieval-Augmented Models
Abstract
Modern ML systems increasingly augment input instances with additional relevant information to enhance final prediction. Despite growing interest in such retrieval-augmented models, their fundamental properties and training are not well understood. We propose a statistical framework to study such models with two components: 1) a retriever to identify the relevant information out of a large corpus via a data-dependent metric; and 2) a predictor that consumes the input instances along with the retrieved information to make the final predictions. We present a principled method for end-to-end training of both components and draw connections with various training approaches in the literature. Furthermore, we establish excess risk bounds for retrieval-augmented models while delineating the contributions of both retriever and predictor towards the model performance. We validate the utility of our proposed training methods, along with the key takeaways from our statistical analysis, on open-domain question answering tasks where retrieval augmentation is important.
1 Introduction
Recent advancements in machine learning (ML) have not only led to breakthroughs on long-standing challenging tasks across various fields, but they have also inspired a great deal of interest in developing ML models that can solve even harder tasks (Meinhardt et al., 2022; Lewkowycz et al., 2022; Cramer, 2021) or focus on completely new fields (Austin et al., 2021; OpenAI, 2023; Singhal et al., 2023). While scaling the size of parametric ML models, such as neural networks, is becoming the predominant approach to meet such demands (Brown et al., 2020; Chowdhery et al., 2022; Touvron et al., 2023; Dosovitskiy et al., 2021; Dehghani et al., 2023), the excellent performance realized by this approach is marred by drawbacks such as high computational cost, inefficient storage of world knowledge in parameters, lack of transparency in model behavior, and reduced grounding/factuality of model predictions.
Recognizing these shortcomings, retrieval-augmented models (RAMs) have emerged as a promising alternative. Such models typically employ two components, namely a retriever and a predictor, during inference on a given input instance: the retriever first identifies instance-specific relevant information from a data-store, and then the predictor jointly processes the retrieved information and the input instance to make a final prediction. In practice, RAMs have enjoyed a favorable performance vs. compute trade-off (Borgeaud et al., 2021; Das et al., 2021; Thai et al., 2023), as employing moderate-size parametric models as retriever and predictor in a RAM often matches or exceeds the performance of a much larger standalone ML model that directly maps input instances to predictions. Similarly, conditioning prediction on the retrieved information has been shown to improve grounding (Shuster et al., 2021; Lin et al., 2023; Asai et al., 2023). Furthermore, having access to an external corpus can obviate the need to store task-specific world knowledge in model parameters and enable incorporating dynamically evolving knowledge (Izacard et al., 2022; Liska et al., 2022).
Despite these desirable characteristics, training RAMs presents multiple challenges. The natural approach of independently training retriever and predictor can be sub-optimal (Izacard et al., 2022). Moreover, it requires collecting intermediate supervision on the instance-dependent relevant information to retrieve, which is missing in common datasets and expensive to obtain in general. A common strategy to circumvent the lack of intermediate supervision is to perform end-to-end training, which presents its own unique challenges in the context of RAMs. Fundamentally, the retrieval corresponds to the non-differentiable discrete operation of selecting relevant information from a data-store, e.g., via top-k selection based on retriever scores, which prevents direct gradient propagation to the entire retriever. Several clever solutions to the above-mentioned issues have been proposed in the literature that focus on different training objectives to propagate learning signal from the predictor into the retriever. However, a formal study that unifies these solutions is missing from the literature.
Another key challenge that prevents the resource-efficient development and deployment of RAMs is the limited understanding of their basic properties such as their generalization behavior and expressive power. For instance, how do the retriever and predictor components interact to ensure good task-specific performance? Are there any principles guiding the selection of the retriever and predictor components? How does (size of) the data-store feature in the final performance of a RAM?
In this paper, we address both of the aforementioned shortcomings in the literature pertaining to RAMs. To unify the training of RAMs, we begin by writing down the natural objective function, which has so far eluded the literature. This natural objective simply minimizes the expected prediction loss, where the expectation is taken over the distribution induced by the retriever. Empirically, we find this objective to be effective on standard benchmarks: NaturalQuestions (NQ; Kwiatkowski et al., 2019) and TriviaQA (Joshi et al., 2017).
As for the generalization and expressive power, we present an excess risk bound for RAMs that captures the effect of retrieval and prediction function classes. The proposed bound allows us to highlight how retriever and predictor components play complementary roles to reduce approximation error as we increase their respective function class complexity. We also capture the role of data store in improving the model performance by reducing the approximation error. On the generalization front, we carefully decouple the generalization term in the excess risk over the predictor and retriever function classes. This allows us to tightly control the generalization term with only logarithmic dependence on the data store size. As a concrete instantiation for our excess risk bounds, we consider feed-forward neural networks of varying depth for both the retriever and the predictor.
To summarize, our main contributions include:
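- We write down a natural end-to-end training objective for RAMs that minimizes the expected prediction loss under the retriever-induced distribution, and relate it to existing training approaches in the literature (Sec. 3 and Sec. 3.6).
- We present an excess risk bound for RAMs that separates the contributions of the retriever and predictor function classes and captures the role of the data-store, and we instantiate it for feed-forward neural networks (Sec. 3.1 to Sec. 3.5).
- We validate the utility of the proposed objective on two standard QA benchmarks: NaturalQuestions (NQ) and TriviaQA (Sec. 4).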
2 Problem setup
In this paper, we focus on developing a systematic understanding of RAMs with learned retrievers in a classification setting where the model has access to a data-store. Towards this, we begin by formally defining the problem setup and providing the necessary background along with the notations used.
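Let’s first consider the standard classification setting, which requires predicting a class in $\mathcal{Y}$ for a given instance $x \in \mathcal{X}$. Assume that $\mathbb{D}_{XY}$ captures the underlying data distribution and that one has access to $n$ training examples $S_n = \{(x_i, y_i)\}_{i \in [n]}$ that are independent and identically distributed (i.i.d.) according to $\mathbb{D}_{XY}$. Given $S_n$, one hopes to learn a classifier $h: \mathcal{X} \to \mathbb{R}^{|\mathcal{Y}|}$ that minimizes the misclassification error:
$$ \mathbb{P}_{(X, Y) \sim \mathbb{D}_{XY}}\Big( \arg\max_{y \in \mathcal{Y}} h_y(X) \neq Y \Big), \tag{1} $$
where $h_y(x)$ denotes the score that $h$ assigns to the $y$-th class, given the input instance $x$. Since directly optimizing the misclassification error or 0/1-loss poses computational challenges, one typically selects the classifier that minimizes the empirical risk associated with a well-behaved surrogate loss function $\ell$ on the training sample $S_n$:
$$ \hat{R}_{\ell, n}(h) = \frac{1}{n} \sum_{i \in [n]} \ell\big(h(x_i), y_i\big). \tag{2} $$
The (population) risk associated with the surrogate loss function $\ell$ takes the following form:
$$ R_{\ell}(h) = \mathbb{E}_{(X, Y) \sim \mathbb{D}_{XY}}\big[ \ell\big(h(X), Y\big) \big]. \tag{3} $$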
Different from the standard classification setup described above, we now consider the classification task with access to a data-store: Given an instance , the classifier can potentially leverage a data-store – a collection of potentially relevant information or evidences, where denotes the space of all possible evidences. Accordingly, one can define the empirical and population risks of a classifier as follows:
(4) | ||||
(5) |
where expectation is taken over in as well as the possible randomness in . However, due to its prohibitive computational cost, such a general classifier that directly processes the entire data-store for each prediction is far from how an additional data-store is utilized by ML models in practice.
This motivates us to study the following explicit retrieval-augmented classification setup to utilize the data-store: Given an input instance , one first retrieves input-dependent supporting evidences with the help of a retriever model which has access to the entire data-store . Now, given and , one invokes a predictor model to predict the class associated with . Thus, a retrieval-augmented classification setup consists of two key component models: 1) a retriever model and 2) a predictor model, which we formally introduce next.
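Retriever model. For the retrieval stage, we rely on a retriever model to capture the relevance of an evidence $z \in \mathcal{I}$ towards the input instance $x$, where $\mathcal{I} \subseteq \mathcal{Z}$ denotes the data-store. Let $r_\theta: \mathcal{X} \times \mathcal{Z} \to \mathbb{R}$ be the retriever model parameterized by $\theta$ that assigns a relevance score $r_\theta(x, z)$ to the instance-evidence pair $(x, z)$. Furthermore, for each instance $x$, the retriever model induces the following distribution over the set of potential evidences $\mathcal{I}$:
$$ p_{\theta, \mathcal{I}}(z \mid x) = \frac{\exp\big(r_\theta(x, z)\big)}{\sum_{z' \in \mathcal{I}} \exp\big(r_\theta(x, z')\big)}, \qquad z \in \mathcal{I}. \tag{6} $$
There are multiple strategies to construct the set of input-dependent supporting evidences based on the retriever scores. For example, for a fixed integer $k$, one could select the $k$ evidences corresponding to the $k$ highest scores in $\{r_\theta(x, z)\}_{z \in \mathcal{I}}$. Another strategy is to sample $k$ evidences according to the distribution in (6). Here, one could perform the sampling with or without replacement. In what follows, the notation for the supporting evidence retrieved for an instance $x$ carries the retriever parameters $\theta$ to highlight the dependence on the underlying retriever model.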
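The retrieval strategies described above can be sketched in a few lines. The following is a minimal illustration rather than the implementation used in the paper: it assumes the retriever-induced distribution in (6) is a softmax over relevance scores, and the function names and the toy scores are hypothetical.

```python
import math
import random


def retriever_distribution(scores):
    """Softmax over the relevance scores of one instance (assumed form of (6))."""
    top = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [w / total for w in weights]


def top_k_evidences(scores, k):
    """Indices of the k highest-scoring evidences, best first."""
    order = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    return order[:k]


def sample_with_replacement(scores, k, rng):
    """Draw k evidence indices i.i.d. from the retriever distribution."""
    probs = retriever_distribution(scores)
    return rng.choices(range(len(scores)), weights=probs, k=k)


def sample_without_replacement(scores, k, rng):
    """Draw k distinct indices one at a time, renormalizing after each draw."""
    remaining = list(range(len(scores)))
    chosen = []
    for _ in range(k):
        probs = retriever_distribution([scores[j] for j in remaining])
        pick = rng.choices(remaining, weights=probs, k=1)[0]
        chosen.append(pick)
        remaining.remove(pick)
    return chosen


# Hypothetical relevance scores of four evidences for a single instance.
scores = [2.0, 0.5, 1.0, -1.0]
rng = random.Random(0)
probs = retriever_distribution(scores)
top2 = top_k_evidences(scores, k=2)
with_repl = sample_with_replacement(scores, k=3, rng=rng)
without_repl = sample_without_replacement(scores, k=3, rng=rng)
```

Top-k selection is deterministic given the scores, whereas the two sampling variants keep some probability mass on lower-scored evidences, which matters when the retriever is still being trained.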
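Predictor model. Let $h_\xi: \mathcal{X} \times \mathcal{Z}^{*} \to \mathbb{R}^{|\mathcal{Y}|}$ be the predictor model parameterized by $\xi$, where $\mathcal{Z}^{*}$ denotes the Kleene star on $\mathcal{Z}$. Given an instance $x$ and the retrieved evidence $z$, the predictor model assigns a score to each class in $\mathcal{Y}$, defining a distribution over $\mathcal{Y}$ as follows:
$$ p_\xi(y \mid x, z) = \frac{\exp\big(h_\xi^{y}(x, z)\big)}{\sum_{y' \in \mathcal{Y}} \exp\big(h_\xi^{y'}(x, z)\big)}, \qquad y \in \mathcal{Y}, \tag{7} $$
where $h_\xi^{y}(x, z)$ denotes the score assigned to the $y$-th class by the predictor model $h_\xi$.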
For ease of exposition, we focus on the setting with $k = 1$ in our analysis throughout this paper. This corresponds to retrieving a single supporting evidence for each input instance. Our analysis can be generalized to $k > 1$ by working with $\mathcal{I}^k$ as the new data-store and with a distribution over $\mathcal{I}^k$ obtained by suitably modifying the distribution in (6). For example, when the $k$ supporting evidences are sampled with replacement, the probability of retrieving $(z_1, \ldots, z_k)$ is the product over $j \in [k]$ of the single-evidence probabilities that (6) assigns to $z_j$.
Empirical risk minimization and excess risk for RAMs. For a pair of retriever and predictor models parameterized by and , respectively, we can define the empirical and population risks associated with a (surrogate) loss function as follows:
(8) | ||||
(9) |
Note that the expectation in (9) is taken over as well as the randomness involved in the retrieval stage, e.g., sampling the evidences according to in (6). Given a pair of predictor class and retriever class , let denote the predictor-retriever pair obtained via empirical risk minimization (ERM) as follows:
(10) |
Let denote the set of all measurable functions from to . The optimal risk for the classification with access to the data-store is achieved by the best possible predictor when it has access to the best retrieved evidence in . In particular, we have
(11) |
Given , we define the excess risk of a predictor-retriever pair as follows:
(12) |
With the formal definition of the classification setting with access to a data-store and the necessary background in place, we proceed to address the two key objectives of this work: 1) Proposing a natural and efficient joint end-to-end training procedure for the predictor-retriever pair in a RAM; and 2) Developing a rigorous statistical understanding of RAMs focusing on the interaction between predictor and retriever components towards reducing overall excess risk.
3 Joint training and excess risk
Recall that training a RAM involves training both the retriever and the predictor components of the model without access to intermediate supervision on retrieval, which is infeasible to obtain in most practical settings. Thus, it becomes critical to devise methods to jointly train and with access to only labeled instances with the predictor guiding the retriever training based on how valuable the retriever-provided evidences are towards the correct final prediction.
Towards this, we leverage the empirical risk from (8) along with the log-loss $\ell\big(h_\xi(x, z), y\big) = -\log p_\xi(y \mid x, z)$, where $p_\xi$ is defined in (7). In particular, this leads to the following joint end-to-end training objective:
$$ \hat{\mathcal{L}}_n(\xi, \theta) = \frac{1}{n} \sum_{i \in [n]} \sum_{z \in \mathcal{I}} p_{\theta, \mathcal{I}}(z \mid x_i) \cdot \big( -\log p_\xi(y_i \mid x_i, z) \big). \tag{13} $$
Note that the objective in (13) aims to improve the end-to-end performance of a RAM in deployment, in the sense that it minimizes the expected loss over evidences selected according to the retriever-induced distribution. One can use gradient-based methods to jointly minimize the objective in (13) with respect to $(\xi, \theta)$; however, its efficient implementation is non-trivial due to the sum over the entire data-store $\mathcal{I}$. In App. C.1, we discuss some design choices for approximating it. Lastly, please refer to Sec. 3.6 for connections between our proposed objective in (13) and some of the existing end-to-end training approaches for RAMs.
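To make the structure of the objective in (13) concrete, the sketch below evaluates it exactly on a toy problem where the data-store is small enough to sum over. It assumes softmax forms for both the retriever distribution in (6) and the predictor distribution in (7); the nested-list layout of the scores and the function names are hypothetical choices made for illustration.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax of a list of scores."""
    top = max(scores)
    log_norm = top + math.log(sum(math.exp(s - top) for s in scores))
    return [s - log_norm for s in scores]


def joint_objective(retriever_scores, predictor_scores, labels):
    """Expected log-loss under the retriever-induced distribution.

    retriever_scores[i][j]    : relevance score of evidence j for instance i
    predictor_scores[i][j][c] : class-c score given instance i and evidence j
    labels[i]                 : index of the correct class for instance i
    """
    total = 0.0
    for r_i, h_i, y_i in zip(retriever_scores, predictor_scores, labels):
        retr_probs = [math.exp(v) for v in log_softmax(r_i)]
        for p_z, h_iz in zip(retr_probs, h_i):
            # log-loss of the predictor when it conditions on this evidence
            total += p_z * (-log_softmax(h_iz)[y_i])
    return total / len(labels)


# One instance, two evidences, two classes; the predictor is right only
# when it conditions on evidence 0 (hypothetical numbers).
labels = [0]
predictor_scores = [[[5.0, 0.0], [0.0, 5.0]]]
loss_good_retriever = joint_objective([[3.0, 0.0]], predictor_scores, labels)
loss_bad_retriever = joint_objective([[0.0, 3.0]], predictor_scores, labels)
loss_uninformative = joint_objective([[0.0, 0.0]], [[[0.0, 0.0], [0.0, 0.0]]], labels)
```

With the predictor held fixed, the objective is lower when the retriever puts its mass on the evidence that leads to the correct prediction, which is how the predictor's loss can guide retriever training without intermediate supervision.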
Next, to study the generalization and expressive power of RAMs, we want to bound the excess risk as defined in (12). We consider to be a compact subspace of and, for simplicity, take . Similarly, we consider that each retrieval example is embedded in the space . We consider a data-store that polynomially scales with training data size, i.e., . For the purpose of analysis, we specialize our log-loss to be bounded by , which is given as
where and are defined in (7).
3.1 Excess risk decomposition
Our excess risk bound relies on separating out the contributions of the retriever and the predictor during joint training. Moreover, the retriever and predictor errors can each be split into a generalization error and an approximation error.
The population risk optimizer of our joint training over the space is defined as
For a predictor , sample and retrieved example , let us denote the risk averaged over the labels as
(14) |
For any fixed predictor (not necessarily in ) and fixed data-store , the retriever that optimizes the joint population risk is given as , where a tie is broken arbitrarily. Note that, for each sample , the best retrieved evidence may change. We define the optimal predictor within the class with best possible retriever as
The optimal retriever within the class for a given predictor is defined as
The excess risk for the classes and can be bounded as
(15) |
3.2 Generalization error
We first bound the generalization error and relate it to the covering number of the retriever and predictor class.
As our loss is bounded by , through standard concentration bounds (Shalev-Shwartz and Ben-David, 2014), we obtain that, for any , with probability at least :
However, is learned from the data. A high-probability generalization bound therefore requires a union bound over the space of . We employ generalization error bounds based on Rademacher complexity, and then use the covering number of the space to bound the associated Rademacher complexity. See Shalev-Shwartz and Ben-David (2014) for details.
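For a single predictor-retriever pair that is fixed in advance (not learned from the data), the concentration step referred to above can be instantiated with Hoeffding's inequality. The statement below is a standard textbook form given only as an illustration; the symbols $B$ for the loss bound, $\hat{R}_n$ and $R$ for the empirical and population risks, and $(\xi, \theta)$ for the pair are assumed notation and may differ from the form used in the paper.

```latex
% Hoeffding's inequality for a loss taking values in [0, B] and a fixed pair (\xi, \theta):
% with probability at least 1 - \delta over the draw of the n i.i.d. training examples,
\left| \hat{R}_n(\xi, \theta) - R(\xi, \theta) \right|
  \;\le\; B \sqrt{\frac{\log(2/\delta)}{2n}} .
```

The bound holds only for a pair chosen independently of the sample, which is why a uniform argument over the function classes is needed for the learned pair.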
We define two norms which are used in defining the covering numbers for and . In particular, and fixed ,
(16) |
We also define to be the -covering number for the class with respect to the norm , and to be the -covering number for the class with respect to the norm . Then we have the generalization bound given as
(17) |
for
3.3 Approximation error
We next proceed to bound the retriever and predictor approximation errors. Towards this, we make extensive use of Sobolev function spaces. A Sobolev space for a domain is characterized by two quantities, – the number of weak-derivatives a (real-valued) function within it possesses, and – the norm with respect to which these derivatives are integrable. Please see Appendix A for a complete definition.
3.3.1 Retriever error
The retriever error is given by how well the score function approximates the optimal retriever given . In order to do so we first need to impose some smoothness constraints on the function . In particular, we assume the following.
Assumption 3.1 (Complexity of ).
There exists a baseline function such that the function defined by lies in the Sobolev space with derivatives and norm.
The above assumption says that for the predictor the loss profile (averaged over labels in ) , has two components – a (possibly) complex component that is uniform over , and a ‘smooth’ component. In other words, given two similar retrieved evidences, the predictor incurs similar losses when each of the evidences is utilized with an input instance.
Then, for any , we can bound the retriever loss as follows:
(18) |
3.3.2 Predictor error
The predictor error is measured with the optimal retrieval (as the retriever error is considered separately above). For this, we need to first quantify how the retrieval augmentation using the data-store helps.
Usefulness of retrieval set:
We start with characterization of the prediction task in the presence of the data-store . We assume that there exists a score function , and the corresponding probability distribution
(19) |
that approximates well for all and . Furthermore, we want this score function to lie coordinate wise in a Sobolev space. The following assumption formalizes this.
Assumption 3.2 (Retrieval quality).
There exists a score function such that
-
1.
for each , the function lies in the Sobolev space with derivatives and finite norm,
-
2.
for any , there exists a retrieved evidence such that , as defined in (19), satisfies
Note that this is independent of the retriever class and , and captures intrinsic property of the data-store . The tuple defines the usefulness of . In particular, the higher the closer the approximation; and the higher the and smaller the embedding dimension the ‘simpler’ the score function used for this approximation.
3.4 Final excess risk bound
3.5 Illustrative example: MLPs
We instantiate both our retriever and predictor classes to be multi-layer perceptron (MLP) with depth & width and depth & width , respectively. The class is defined in Appendix A. The specialized excess risk bound for this setting is given as
Theorem 3.4 (Excess risk for MLP).
Finally, to capture the optimal trade-off under finite data size , we consider classes of retriever and predictors that change with the data size, denoted by and , with growing depths and respectively. Similarly, we also consider growing upper bound on the loss function by . Let . For , , and , the excess risk is bounded by
We should contrast the above result with the prediction when there is no retrieval. Let us assume that the functions for all lie in the Sobolev space with derivative and norm. The predictor excess risk rate with is .
Note that our analysis indicates that we may gain through retrieval: For large enough data store , as the data size increases and and (see Fig. 1).
Table 1: Exact match accuracy on NQ. Each column is a predictor size / retriever size pair; the no-retriever row has a single value per predictor size.

Method | small / small | small / base | small / large | base / small | base / base | base / large | large / small | large / base | large / large
---|---|---|---|---|---|---|---|---|---
No retriever, train predictor | | | | | | | | |
Reverse Cross-Entropy | 19.6 | | | 25.5 | | | 29.1 | |
Fixed retriever, train predictor | | | | | | | | |
Reverse Cross-Entropy | 23.2 | 26.6 | 28.3 | 27.5 | 32.4 | 34.7 | 32.2 | 36.4 | 37.8
Fixed predictor, train retriever | | | | | | | | |
EMDR2 | 23.9 | 28.5 | 31.0 | 29.2 | 34.2 | 36.6 | 33.4 | 38.0 | 40.8
PDist | 30.1 | 34.5 | 38.4 | 34.0 | 39.7 | 42.8 | 37.6 | 42.8 | 44.7
Reverse Cross-Entropy + PG | 25.9 | 30.6 | 31.7 | 31.5 | 36.4 | 37.9 | 36.0 | 40.2 | 41.4
Reverse Cross-Entropy + TopK | 29.4 | 35.5 | 37.9 | 33.8 | 39.7 | 43.0 | 37.2 | 42.3 | 45.0
Jointly train predictor and retriever | | | | | | | | |
EMDR2 | 24.1 | 30.4 | 32.7 | 30.4 | 35.6 | 39.3 | 34.5 | 39.7 | 42.1
PDist | 28.7 | 33.2 | 36.6 | 33.3 | 37.1 | 38.8 | 36.2 | 40.2 | 41.6
Reverse Cross-Entropy + PG | 27.1 | 31.0 | 32.7 | 33.3 | 37.2 | 38.2 | 36.5 | 39.8 | 41.4
Reverse Cross-Entropy + TopK | 32.8 | 37.8 | 40.1 | 36.6 | 41.8 | 44.8 | 38.8 | 43.8 | 46.4
3.6 Connections with prior end-to-end training
We conclude our treatment of end-to-end training of RAMs by drawing parallels between our proposed method and some representative approaches from the literature.
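EMDR2 Sachan et al. (2021) minimize the following objective based on the negative log-likelihood:
$$ -\frac{1}{n} \sum_{i \in [n]} \log \sum_{z \in \mathcal{I}} p_{\theta, \mathcal{I}}(z \mid x_i) \cdot p_\xi(y_i \mid x_i, z). \tag{21} $$
It follows from the convexity of $-\log(\cdot)$ and Jensen’s inequality that our objective in (13) upper bounds the EMDR2 objective in (21); as a result, a small value of the former guarantees a small value of the latter, but not vice versa.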
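The Jensen relation above can be checked numerically on a single example. The sketch below compares the per-example expected log-loss (the form of (13)) with the per-example marginal negative log-likelihood (the form of (21)); the probability values are hypothetical and the function names are introduced only for illustration.

```python
import math


def expected_log_loss(retr_probs, label_probs):
    """Per-example form of (13): E_z[-log p(y | x, z)] under the retriever."""
    return -sum(p * math.log(q) for p, q in zip(retr_probs, label_probs))


def marginal_log_loss(retr_probs, label_probs):
    """Per-example form of (21): -log E_z[p(y | x, z)] under the retriever."""
    return -math.log(sum(p * q for p, q in zip(retr_probs, label_probs)))


# Retriever probabilities over three evidences, and the probability the
# predictor assigns to the correct label given each evidence.
retr_probs = [0.7, 0.2, 0.1]
label_probs = [0.9, 0.5, 0.05]
ours = expected_log_loss(retr_probs, label_probs)
emdr2 = marginal_log_loss(retr_probs, label_probs)

# Equality holds when the predictor is equally good on every evidence.
flat_ours = expected_log_loss(retr_probs, [0.5, 0.5, 0.5])
flat_emdr2 = marginal_log_loss(retr_probs, [0.5, 0.5, 0.5])

# The converse fails: the marginal loss can stay moderate while the
# expected loss is large, if one evidence with substantial mass is useless.
gap_ours = expected_log_loss([0.5, 0.5], [1.0, 1e-9])
gap_emdr2 = marginal_log_loss([0.5, 0.5], [1.0, 1e-9])
```

The last pair shows why the two objectives treat the retriever differently: the marginal likelihood tolerates mass on an unhelpful evidence as long as some helpful evidence is covered, whereas the expected loss penalizes it.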
Perplexity distillation (PDist) Another approach for joint training of RAMs in the literature involves optimizing two distinct objectives for training the predictor and retriever components. For example, Izacard et al. (2022) propose multiple objectives for retriever training, including PDist (Sachan et al., 2023), which is defined as follows:
$$ \frac{1}{n} \sum_{i \in [n]} \mathrm{CE}\big( p_{\xi, \mathcal{I}}(\cdot \mid x_i, y_i),\; p_{\theta, \mathcal{I}}(\cdot \mid x_i) \big), \tag{22} $$
where $\mathrm{CE}(p, q) = -\sum_{z \in \mathcal{I}} p(z) \log q(z)$ denotes the cross entropy between two distributions and
$$ p_{\xi, \mathcal{I}}(z \mid x, y) = \frac{p_\xi(y \mid x, z)}{\sum_{z' \in \mathcal{I}} p_\xi(y \mid x, z')} $$
represents a predictor-assigned distribution over evidences based on their utility towards making the correct prediction. As for the predictor training, they optimize an objective akin to (13) with respect to $\xi$. Besides this similarity in the predictor training, our approach for retriever training has a subtle connection with PDist. Note that PDist optimizes the forward cross-entropy between the predictor-induced and retriever-induced distributions to train the retriever. On the other hand, our objective in (13) is closer to $\mathrm{CE}\big( p_{\theta, \mathcal{I}}(\cdot \mid x),\; p_{\xi, \mathcal{I}}(\cdot \mid x, y) \big)$, the reversed cross-entropy between the two distributions. The former has the “mean-seeking” behavior whereas the latter has the “mode-seeking” behavior (Huszár, 2015; Gu et al., 2023; Agarwal et al., 2023).
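The contrast between the forward and reversed cross-entropy can be illustrated with a small numerical example. The sketch below assumes the predictor-assigned distribution over evidences is obtained by normalizing the label likelihoods; the two candidate retriever distributions and all numbers are hypothetical.

```python
import math


def cross_entropy(p, q):
    """CE(p, q) = -sum_z p(z) log q(z), skipping terms with p(z) = 0."""
    return -sum(pz * math.log(qz) for pz, qz in zip(p, q) if pz > 0.0)


def predictor_posterior(label_likelihoods):
    """Normalize p(y | x, z) over evidences z into a distribution on z."""
    total = sum(label_likelihoods)
    return [v / total for v in label_likelihoods]


# Two evidences are equally useful to the predictor, a third is not.
posterior = predictor_posterior([0.245, 0.245, 0.01])

# Two candidate retriever distributions (hypothetical):
spread = [0.45, 0.45, 0.10]  # covers both useful evidences
peaked = [0.98, 0.01, 0.01]  # commits to a single useful evidence

forward_spread = cross_entropy(posterior, spread)  # PDist-style direction
forward_peaked = cross_entropy(posterior, peaked)
reverse_spread = cross_entropy(spread, posterior)  # direction closer to (13)
reverse_peaked = cross_entropy(peaked, posterior)
```

In this example the forward direction prefers the spread retriever, since missing a useful evidence is heavily penalized, while the reversed direction prefers the peaked retriever, since any mass placed on a weak evidence is penalized.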
Similarity with RLHF/RLAIF Note that the per-example objective of our retriever training approach takes the form:
$$ \mathbb{E}_{z \sim p_{\theta, \mathcal{I}}(\cdot \mid x)}\big[ -\log p_\xi(y \mid x, z) \big], \tag{23} $$
i.e., the predictor model provides feedback on the value of the evidences sampled by the retriever model. Alternatively, one can view $\log p_\xi(y \mid x, z)$ as the reward assigned to the evidence $z$ by the predictor model, and the retriever model aims to select those evidences that maximize this reward. This is similar to the RLHF (Ziegler et al., 2019) or RLAIF (Bai et al., 2022) paradigm, where the underlying LLM aims to sample those generations which maximize the reward assigned by a reward model. However, note that in the RLHF/RLAIF paradigm the policy network and the reward model are not trained jointly, unlike in a RAM.
4 Experiments
There have been numerous successful practical applications of RAMs in the literature (e.g., Sachan et al. (2021); Izacard et al. (2022)). Here, we present a brief empirical study of such models in order to corroborate the benefits predicted by our theoretical results. In particular, we consider the task of open-domain question answering, show that the proposed objective is competitive with the objectives proposed in the literature, and observe the trade-offs in model capacity between the retriever and predictor models.
Data Our evaluation is based on two benchmark datasets: NQOpen (Kwiatkowski et al., 2019) and TriviaQA (Joshi et al., 2017), which serve as sources for supervised examples, while chunked Wikipedia 2018 is used as the data-store, following the literature (Karpukhin et al., 2020a). Consistent with established practice, we employ the exact match metric to assess the correspondence between the predicted answers and the ground truth. Additionally, we introduce a recall metric to measure the frequency at which the answer string appears within the retrieved documents.
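As a rough illustration of the two metrics, the sketch below computes exact match and answer-string recall for one question. The normalization (lowercasing, stripping punctuation and English articles) follows a convention common in open-domain QA evaluation and is an assumption here, since the text does not specify it; the function names are hypothetical.

```python
import re
import string


def normalize_answer(text):
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, gold_answers):
    """True if the normalized prediction equals any normalized gold answer."""
    normalized_gold = {normalize_answer(g) for g in gold_answers}
    return normalize_answer(prediction) in normalized_gold


def answer_recall(retrieved_passages, gold_answers):
    """True if any gold answer string occurs in any retrieved passage."""
    passages = [normalize_answer(p) for p in retrieved_passages]
    golds = [normalize_answer(g) for g in gold_answers]
    # Skip gold answers that normalize to the empty string.
    return any(g in p for g in golds if g for p in passages)


em_hit = exact_match("The Eiffel Tower.", ["Eiffel Tower"])
em_miss = exact_match("Paris", ["London"])
recall_hit = answer_recall(["The tower is located in Paris, France."], ["Paris"])
recall_miss = answer_recall(["The tower is located in France."], ["Paris"])
```

Averaging these per-question indicators over an evaluation set gives the exact match accuracy and the recall, respectively.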
Models We implement the retriever component using GTR (Ni et al., 2022) and the predictor component using T5 (Raffel et al., 2020). We sweep across small, base, and large configurations for both retriever and predictor. The details regarding the model sizes, expressed in terms of the number of parameters, are provided in Table 6 (App. C).
Table 2: Exact match accuracy on TriviaQA. Each column is a predictor size / retriever size pair; the no-retriever row has a single value per predictor size.

Method | small / small | small / base | small / large | base / small | base / base | base / large | large / small | large / base | large / large
---|---|---|---|---|---|---|---|---|---
No retriever, train predictor | | | | | | | | |
Reverse Cross-Entropy | 17.9 | | | 23.1 | | | 28.0 | |
Fixed retriever, train predictor | | | | | | | | |
Reverse Cross-Entropy | 31.5 | 34.9 | 38.8 | 37.0 | 40.6 | 44.4 | 43.4 | 45.9 | 49.7
Fixed predictor, train retriever | | | | | | | | |
EMDR2 | 34.6 | 41.3 | 48.3 | 40.1 | 48.2 | 53.4 | 46.0 | 50.7 | 54.9
PDist | 45.7 | 53.3 | 57.2 | 50.8 | 53.2 | 61.6 | 53.5 | 55.4 | 62.3
Reverse Cross-Entropy + PG | 43.2 | 46.7 | 54.3 | 48.6 | 56.1 | 55.1 | 51.7 | 56.4 | 56.7
Reverse Cross-Entropy + TopK | 43.6 | 50.4 | 54.4 | 48.6 | 54.9 | 58.5 | 52.1 | 56.6 | 60.3
Jointly train predictor and retriever | | | | | | | | |
EMDR2 | 37.0 | 43.1 | 49.7 | 42.4 | 50.5 | 55.6 | 47.1 | 53.4 | 59.2
PDist | 46.7 | 54.3 | 57.3 | 48.8 | 56.7 | 60.7 | 51.0 | 58.5 | 63.3
Reverse Cross-Entropy + PG | 47.0 | 52.9 | 55.7 | 49.9 | 57.6 | 61.1 | 52.1 | 59.8 | 59.2
Reverse Cross-Entropy + TopK | 46.8 | 52.9 | 56.0 | 49.2 | 56.6 | 60.1 | 52.3 | 58.8 | 62.4
Methods We compare the following approaches: 1) utilizing no retriever and directly training the predictor, 2) employing a fixed retriever but training the predictor, 3) using a fixed predictor but training the retriever, and 4) jointly training both components. For the joint training and the retriever training phases, we experiment with multiple objectives: EMDR2 (cf. (21)), PDist (cf. (22)), Reverse Cross-Entropy + PG (cf. (45) in App. C.1), and Reverse Cross-Entropy + TopK (cf. (44) in App. C.1). Efficiently implementing any of these objectives is challenging due to the need to compute the gradient with respect to an expectation over the entire data-store. We consider two approaches for computing the gradients approximately: 1) restricting the expectation to the top-K elements, similar to EMDR2 and PDist; and 2) using REINFORCE (Williams, 1992) to obtain an unbiased estimate. More details can be found in App. C.1.
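The two gradient approximations can be sketched for a single instance with a fixed predictor. The code below is an illustration under stated assumptions, not the implementation from App. C.1: it assumes a softmax retriever over scores, treats the per-evidence predictor losses as given constants, and differentiates the expected loss with respect to the retriever scores. All names and numbers are hypothetical.

```python
import math
import random


def softmax(scores):
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [w / total for w in weights]


def exact_gradient(scores, losses):
    """Gradient of sum_z p(z) * loss(z) w.r.t. the scores: p_j * (loss_j - mean)."""
    probs = softmax(scores)
    mean_loss = sum(p * l for p, l in zip(probs, losses))
    return [p * (l - mean_loss) for p, l in zip(probs, losses)]


def reinforce_gradient(scores, losses, num_samples, rng):
    """Score-function estimate: average of loss(z) * grad log p(z), z ~ p."""
    probs = softmax(scores)
    n = len(scores)
    draws = rng.choices(range(n), weights=probs, k=num_samples)
    grad = [0.0] * n
    for z in draws:
        for j in range(n):
            indicator = 1.0 if j == z else 0.0
            grad[j] += losses[z] * (indicator - probs[j])
    return [g / num_samples for g in grad]


def top_k_gradient(scores, losses, k):
    """Exact gradient of the expectation restricted (and renormalized) to the top-k."""
    n = len(scores)
    top = sorted(range(n), key=lambda j: scores[j], reverse=True)[:k]
    sub = exact_gradient([scores[j] for j in top], [losses[j] for j in top])
    grad = [0.0] * n
    for j, g in zip(top, sub):
        grad[j] = g
    return grad


scores = [1.0, 0.0, -1.0, 0.5]  # retriever scores for four evidences
losses = [0.2, 1.5, 3.0, 0.7]   # predictor log-loss given each evidence
exact = exact_gradient(scores, losses)
estimate = reinforce_gradient(scores, losses, num_samples=20000, rng=random.Random(0))
top2 = top_k_gradient(scores, losses, k=2)
full = top_k_gradient(scores, losses, k=len(scores))
```

The score-function estimate is unbiased but noisy, while the top-K restriction is deterministic but ignores evidences outside the top-K, so it is biased whenever those evidences carry non-negligible probability.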
Observation 1 The addition of a retrieval component markedly enhances performance, as demonstrated in Tables 1 and 2, which present the exact match accuracy. Further improvements are observed when the retriever is specifically trained while keeping the predictor fixed. Joint training emerges as the most effective strategy.
Observation 2 Tables 4 and 5 (App. C) list the recall for the presence of the answer string within the retrieved content. PDist consistently achieves the highest recall, aligning with expectations given its design for distilling the retriever based on the predictor’s scores. However, despite its superior recall, other objectives may lead to better overall performance than PDist, suggesting that different objectives optimize the retriever and predictor with varying efficiencies.
Observation 3 Finally, in Table 3, we report the queries per second (QPS), as a proxy for computational cost, achieved by different configurations of retriever and predictor model sizes. For achieving a specific accuracy threshold (e.g., 38.8 on NQ), multiple configurations are viable, such as pairing a large predictor with a small retriever, a base model for both, or a small predictor with a large retriever. The associated QPS rates for these configurations are 135, 333, and 800, respectively, illustrating that equivalent accuracy levels can be attained at significantly different QPS rates. This corroborates the trade-offs in our excess risk bounds for MLPs with different capacities in the retriever and predictor components, as illustrated in Figure 1. Thus, adding capacity to different parts of the model has different repercussions on quality and computational cost.
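The configuration trade-off in Observation 3 can be reproduced from the reported numbers. The sketch below combines the NQ exact match values of the jointly trained Reverse Cross-Entropy + TopK row with the QPS values of Table 3, assuming that the outer size grouping in both tables refers to the predictor and the inner one to the retriever (this reading is inferred from the QPS figures quoted in the text). The helper names are hypothetical.

```python
SIZES = ["small", "base", "large"]

# (predictor size, retriever size) -> exact match on NQ, joint training
# with Reverse Cross-Entropy + TopK.
EXACT_MATCH = dict(zip(
    [(p, r) for p in SIZES for r in SIZES],
    [32.8, 37.8, 40.1, 36.6, 41.8, 44.8, 38.8, 43.8, 46.4],
))

# (predictor size, retriever size) -> queries per second.
QPS = dict(zip(
    [(p, r) for p in SIZES for r in SIZES],
    [822.60, 819.83, 800.89, 334.30, 333.22, 331.06, 135.06, 135.34, 134.87],
))


def viable_configs(threshold):
    """Configurations whose exact match reaches the threshold, fastest first."""
    hits = [cfg for cfg, em in EXACT_MATCH.items() if em >= threshold]
    return sorted(hits, key=lambda cfg: QPS[cfg], reverse=True)


def pareto_front():
    """Configurations not dominated in both exact match and QPS."""
    front = []
    for cfg in EXACT_MATCH:
        dominated = any(
            EXACT_MATCH[o] >= EXACT_MATCH[cfg] and QPS[o] >= QPS[cfg]
            and (EXACT_MATCH[o] > EXACT_MATCH[cfg] or QPS[o] > QPS[cfg])
            for o in EXACT_MATCH if o != cfg
        )
        if not dominated:
            front.append(cfg)
    return front


viable = viable_configs(38.8)
fastest = viable[0]
front = pareto_front()
```

Under these numbers, the fastest configuration that reaches the 38.8 threshold pairs a small predictor with a large retriever, which is consistent with the observation that retriever capacity is much cheaper than predictor capacity at inference time.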
5 Discussion and related work
Table 3: Queries per second (QPS) for each predictor size / retriever size pair.

small / small | small / base | small / large | base / small | base / base | base / large | large / small | large / base | large / large
---|---|---|---|---|---|---|---|---
822.60 | 819.83 | 800.89 | 334.30 | 333.22 | 331.06 | 135.06 | 135.34 | 134.87
Several works have proposed some form of retrieval augmented models. Here, we provide a brief account of the evolution of RAMs and discuss how our proposed joint-learning objective and the framework for excess risk analysis compare with existing end-to-end training methods.
Augment with local neighborhood The first approaches, dating back to the 1970s, simply augmented the input with training instances from its local neighborhood in the input space (Stone, 1977, 1980). Such approaches gained a lot of attention as parametric regression was not adequate in various practical applications of the time. This line of work aims to fit a low-degree polynomial at each point in the data set based on a subset of data points, which resulted in a rich literature on local polynomial regression in low dimensions (Katkovnik and Kheisin, 1979; Cleveland, 1979; Pinsker, 1980; Donoho and Liu, 1988; Ruppert and Wand, 1994; Ibragimov and Has Minskii, 2013). These classical ideas have found their application in many ML algorithms such as face recognition (Jain and Learned-Miller, 2011), dimensionality reduction via local linear embeddings (Roweis and Saul, 2000), domain adaptation (Yang et al., 2021), test time training on neighboring points (Sun et al., 2020; Gandelsman et al., 2022), etc. Recently, Basu et al. (2023) generalized this setup of augmenting with a local neighborhood of the input instance in the context of modern ML models like neural networks and proposed a statistical framework to study such retrieval-based models. However, they do not consider a learned or a specialized distance metric to find the augmenting set, which is critical for realizing good performance in practice (Schonberger et al., 2017; Karpukhin et al., 2020b) and is studied in the present work.
Fixed retriever augmentation The next generation of retrieval-augmented models started to deploy either a hand-crafted or a learned retriever. Zhang et al. (2006) employed SIFT-based (Lowe, 1999) retrieval followed by an SVM (Cortes and Vapnik, 1995) classifier to improve performance on multiple vision tasks. Chen et al. (2009) studied generalization bounds for SVM-kNN methods – one of the few works in this domain with a formal analysis. For natural language understanding, methods like TF-IDF (Sparck Jones, 1972) were employed in tasks like case-based reasoning (Leake et al., 1996) and open-domain question answering (ODQA; Voorhees et al. 1999). Unlike many previous methods, one retrieves relevant text passages in ODQA settings as opposed to retrieving labelled training pairs. With the introduction of transformers (Vaswani et al., 2017), retriever and predictor models based on encoders and decoders, respectively, have become popular across various domains, including image classification (Long et al., 2022; Iscen et al., 2023), text classification (Wang et al., 2022; Zemlyanskiy et al., 2022), ODQA (Lee et al., 2019; Izacard and Grave, 2021), language modelling (Borgeaud et al., 2021), and even protein folding prediction (Cramer, 2021). Even using the same transformer model as both retriever and predictor boosts performance in language modeling (Khandelwal et al., 2020). Unlike for SVM-kNN (Chen et al., 2009), to the best of our knowledge, a formal analysis of retrieval-augmented approaches with modern neural networks is missing from the literature. Interestingly, retrieving examples also helps in-context learning (Rubin et al., 2022; Li et al., 2023). Our framework covers this scenario, with the retrieved evidences being in-context examples drawn from a data-store of examples. Our risk bounds can provide insights into why in-context learning with retrieved few-shot examples performs better than a zero-shot model.
End-to-end trained retriever augmentation For ODQA, Guu et al. (2020) proposed maximizing the marginalized likelihood by considering the retrieved set as a latent variable. EMDR2 (Sachan et al., 2021) optimized the same objective by approximating it based on the retriever-induced distribution over the elements that receive the top-K scores from the retriever. Hindsight (Paranjape et al., 2022) instead optimizes the ELBO by introducing a variational distribution with access to the outputs. VOD (Liévin et al., 2023) further generalized the standard ELBO based on KL divergence by employing Rényi divergence, thereby tightening the lower bound. On the other hand, Atlas (Izacard et al., 2022) proposed an auxiliary loss for training the retriever directly rather than following the latent variable approach. Interestingly, RAG (Lewis et al., 2020) proposed to only train the query encoder for the retriever, leaving the retrieval index fixed, thereby alleviating much of the end-to-end training difficulty of RAMs, but at the cost of limiting model adaptation flexibility. None of these prior works studied the statistical properties of RAMs vis-à-vis expressivity and generalization.
6 Conclusion
In this work, we initiate the development of a theoretical framework to study the statistical properties of RAMs with data-dependent retrieval. Our excess risk analysis highlights how the retriever and predictor components play complementary roles in reducing approximation error as we increase their respective function class complexity. We surface, both theoretically and empirically, a Pareto surface along which predictors and retrievers of different sizes achieve the same performance. As future work, it would be interesting to study the effect of a dynamically updatable data-store and of multi-step retrieval on predictions.
References
- Agarwal et al. [2023] Rishabh Agarwal, Nino Vieillard, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem. Gkd: Generalized knowledge distillation for auto-regressive sequence models. arXiv preprint arXiv:2306.13649, 2023.
- Asai et al. [2023] Akari Asai, Zeqiu Wu, Yizhong Wang, Avirup Sil, and Hannaneh Hajishirzi. Self-rag: Learning to retrieve, generate, and critique through self-reflection. arXiv preprint arXiv:2310.11511, 2023.
- Austin et al. [2021] Jacob Austin, Augustus Odena, Maxwell Nye, Maarten Bosma, Henryk Michalewski, David Dohan, Ellen Jiang, Carrie Cai, Michael Terry, Quoc Le, et al. Program synthesis with large language models. arXiv preprint arXiv:2108.07732, 2021.
- Bai et al. [2022] Yuntao Bai, Saurav Kadavath, Sandipan Kundu, Amanda Askell, Jackson Kernion, Andy Jones, Anna Chen, Anna Goldie, Azalia Mirhoseini, Cameron McKinnon, et al. Constitutional ai: Harmlessness from ai feedback. arXiv preprint arXiv:2212.08073, 2022.
- Bartlett et al. [2019] Peter L Bartlett, Nick Harvey, Christopher Liaw, and Abbas Mehrabian. Nearly-tight vc-dimension and pseudodimension bounds for piecewise linear neural networks. The Journal of Machine Learning Research, 20(1):2285–2301, 2019.
- Basu et al. [2023] Soumya Basu, Ankit Singh Rawat, and Manzil Zaheer. A statistical perspective on retrieval-based models. In Andreas Krause, Emma Brunskill, Kyunghyun Cho, Barbara Engelhardt, Sivan Sabato, and Jonathan Scarlett, editors, Proceedings of the 40th International Conference on Machine Learning, volume 202 of Proceedings of Machine Learning Research, pages 1852–1886. PMLR, 23–29 Jul 2023. URL https://proceedings.mlr.press/v202/basu23a.html.
- Borgeaud et al. [2021] Sebastian Borgeaud, Arthur Mensch, Jordan Hoffmann, Trevor Cai, Eliza Rutherford, Katie Millican, George van den Driessche, Jean-Baptiste Lespiau, Bogdan Damoc, Aidan Clark, Diego de Las Casas, Aurelia Guy, Jacob Menick, Roman Ring, Tom Hennigan, Saffron Huang, Loren Maggiore, Chris Jones, Albin Cassirer, Andy Brock, Michela Paganini, Geoffrey Irving, Oriol Vinyals, Simon Osindero, Karen Simonyan, Jack W. Rae, Erich Elsen, and Laurent Sifre. Improving language models by retrieving from trillions of tokens. CoRR, abs/2112.04426, 2021.
- Brown et al. [2020] Tom B. Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh, Daniel M. Ziegler, Jeffrey Wu, Clemens Winter, Christopher Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish, Alec Radford, Ilya Sutskever, and Dario Amodei. Language models are few-shot learners, 2020.
- Burda et al. [2015] Yuri Burda, Roger Grosse, and Ruslan Salakhutdinov. Importance weighted autoencoders. arXiv preprint arXiv:1509.00519, 2015.
- Chen et al. [2009] Yihua Chen, Eric K Garcia, Maya R Gupta, Ali Rahimi, and Luca Cazzanti. Similarity-based classification: Concepts and algorithms. Journal of Machine Learning Research, 10(3), 2009.
- Chowdhery et al. [2022] Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, et al. Palm: Scaling language modeling with pathways. arXiv preprint arXiv:2204.02311, 2022.
- Cleveland [1979] William S Cleveland. Robust locally weighted regression and smoothing scatterplots. Journal of the American statistical association, 74(368):829–836, 1979.
- Cortes and Vapnik [1995] Corinna Cortes and Vladimir Vapnik. Support-vector networks. Machine learning, 20(3):273–297, 1995.
- Cramer [2021] Patrick Cramer. Alphafold2 and the future of structural biology. Nature Structural & Molecular Biology, 28(9):704–705, 2021.
- Das et al. [2021] Rajarshi Das, Manzil Zaheer, Dung Thai, Ameya Godbole, Ethan Perez, Jay Yoon Lee, Lizhen Tan, Lazaros Polymenakos, and Andrew McCallum. Case-based reasoning for natural language queries over knowledge bases. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, pages 9594–9611, Online and Punta Cana, Dominican Republic, November 2021. Association for Computational Linguistics. doi: 10.18653/v1/2021.emnlp-main.755.
- Dehghani et al. [2023] Mostafa Dehghani, Josip Djolonga, Basil Mustafa, Piotr Padlewski, Jonathan Heek, Justin Gilmer, Andreas Peter Steiner, Mathilde Caron, Robert Geirhos, Ibrahim Alabdulmohsin, et al. Scaling vision transformers to 22 billion parameters. In International Conference on Machine Learning, pages 7480–7512. PMLR, 2023.
- Donoho and Liu [1988] David L Donoho and Richard C Liu. The" automatic" robustness of minimum distance functionals. The Annals of Statistics, 16(2):552–586, 1988.
- Dosovitskiy et al. [2021] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations, 2021.
- Epasto et al. [2020] Alessandro Epasto, Mohammad Mahdian, Vahab Mirrokni, and Emmanouil Zampetakis. Optimal approximation-smoothness tradeoffs for soft-max functions. Advances in Neural Information Processing Systems, 33:2651–2660, 2020.
- Gandelsman et al. [2022] Yossi Gandelsman, Yu Sun, Xinlei Chen, and Alexei Efros. Test-time training with masked autoencoders. Advances in Neural Information Processing Systems, 35:29374–29385, 2022.
- Grathwohl et al. [2021] Will Grathwohl, Kevin Swersky, Milad Hashemi, David Duvenaud, and Chris Maddison. Oops i took a gradient: Scalable sampling for discrete distributions. In International Conference on Machine Learning, pages 3831–3841. PMLR, 2021.
- Gu et al. [2023] Yuxian Gu, Li Dong, Furu Wei, and Minlie Huang. Knowledge distillation of large language models. arXiv preprint arXiv:2306.08543, 2023.
- Guu et al. [2020] Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat, and Ming-Wei Chang. Realm: Retrieval-augmented language model pre-training. In Proceedings of the 37th International Conference on Machine Learning, ICML’20. JMLR.org, 2020.
- henrikl [https://math.stackexchange.com/users/351007/henrikl] henrikl (https://math.stackexchange.com/users/351007/henrikl). 1-smoothness of the symmetric softmax function. Mathematics Stack Exchange, 2021. URL https://math.stackexchange.com/q/4170855 (version: 2021-06-12).
- Huszár [2015] Ferenc Huszár. How (not) to train your generative model: Scheduled sampling, likelihood, adversary? arXiv preprint arXiv:1511.05101, 2015.
- Ibragimov and Has Minskii [2013] Ildar Abdulovich Ibragimov and Rafail Zalmanovich Has Minskii. Statistical estimation: asymptotic theory, volume 16. Springer Science & Business Media, 2013.
- Iscen et al. [2023] Ahmet Iscen, Alireza Fathi, and Cordelia Schmid. Improving image recognition by retrieving from web-scale image-text data. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 19295–19304, 2023.
- Izacard and Grave [2021] Gautier Izacard and Edouard Grave. Leveraging passage retrieval with generative models for open domain question answering. In Proceedings of the 16th Conference of the European Chapter of the Association for Computational Linguistics: Main Volume, pages 874–880, Online, April 2021. Association for Computational Linguistics. doi: 10.18653/v1/2021.eacl-main.74. URL https://aclanthology.org/2021.eacl-main.74.
- Izacard et al. [2022] Gautier Izacard, Patrick Lewis, Maria Lomeli, Lucas Hosseini, Fabio Petroni, Timo Schick, Jane Dwivedi-Yu, Armand Joulin, Sebastian Riedel, and Edouard Grave. Few-shot learning with retrieval augmented language models. arXiv preprint arXiv:2208.03299, 2022.
- Jain and Learned-Miller [2011] Vidit Jain and Erik Learned-Miller. Online domain adaptation of a pre-trained cascade of classifiers. In CVPR 2011, pages 577–584. IEEE, 2011.
- Joshi et al. [2017] Mandar Joshi, Eunsol Choi, Daniel S Weld, and Luke Zettlemoyer. Triviaqa: A large scale distantly supervised challenge dataset for reading comprehension. In Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 1601–1611, 2017.
- Karpukhin et al. [2020a] Vladimir Karpukhin, Barlas Oguz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. Dense passage retrieval for open-domain question answering. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP), pages 6769–6781, Online, November 2020a. Association for Computational Linguistics.
- Karpukhin et al. [2020b] Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. Dense passage retrieval for open-domain question answering. arXiv preprint arXiv:2004.04906, 2020b.
- Katkovnik and Kheisin [1979] Vladimir Yakovlevich Katkovnik and VE Kheisin. Dynamic stochastic approximation of polynomials drifts. Avtomatika i Telemekhanika, pages 89–98, 1979.
- Khandelwal et al. [2020] Urvashi Khandelwal, Omer Levy, Dan Jurafsky, Luke Zettlemoyer, and Mike Lewis. Generalization through memorization: Nearest neighbor language models. In International Conference on Learning Representations, 2020.
- Kwiatkowski et al. [2019] Tom Kwiatkowski, Jennimaria Palomaki, Olivia Redfield, Michael Collins, Ankur Parikh, Chris Alberti, Danielle Epstein, Illia Polosukhin, Jacob Devlin, Kenton Lee, et al. Natural questions: a benchmark for question answering research. Transactions of the Association for Computational Linguistics, 7:453–466, 2019.
- Leake et al. [1996] David B. Leake, Andrew Kinley, and David C. Wilson. Acquiring case adaptation knowledge: A hybrid approach. In AAAI/IAAI, Vol. 1, 1996. URL https://api.semanticscholar.org/CorpusID:11169287.
- Lee et al. [2019] Kenton Lee, Ming-Wei Chang, and Kristina Toutanova. Latent retrieval for weakly supervised open domain question answering. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 6086–6096, Florence, Italy, July 2019. Association for Computational Linguistics. doi: 10.18653/v1/P19-1612. URL https://aclanthology.org/P19-1612.
- Lewis et al. [2020] Patrick Lewis, Ethan Perez, Aleksandra Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, et al. Retrieval-augmented generation for knowledge-intensive nlp tasks. Advances in Neural Information Processing Systems, 33:9459–9474, 2020.
- Lewkowycz et al. [2022] Aitor Lewkowycz, Anders Andreassen, David Dohan, Ethan Dyer, Henryk Michalewski, Vinay Ramasesh, Ambrose Slone, Cem Anil, Imanol Schlag, Theo Gutman-Solo, et al. Solving quantitative reasoning problems with language models. Advances in Neural Information Processing Systems, 35:3843–3857, 2022.
- Li et al. [2023] Yingcong Li, Muhammed Emrullah Ildiz, Dimitris Papailiopoulos, and Samet Oymak. Transformers as algorithms: Generalization and stability in in-context learning. In Andreas Krause, Emma Brunskill, Kyunghyun Cho, Barbara Engelhardt, Sivan Sabato, and Jonathan Scarlett, editors, Proceedings of the 40th International Conference on Machine Learning, volume 202 of Proceedings of Machine Learning Research, pages 19565–19594. PMLR, 23–29 Jul 2023. URL https://proceedings.mlr.press/v202/li23l.html.
- Liévin et al. [2023] Valentin Liévin, Andreas Geert Motzfeldt, Ida Riis Jensen, and Ole Winther. Variational open-domain question answering. In Andreas Krause, Emma Brunskill, Kyunghyun Cho, Barbara Engelhardt, Sivan Sabato, and Jonathan Scarlett, editors, Proceedings of the 40th International Conference on Machine Learning, volume 202 of Proceedings of Machine Learning Research, pages 20950–20977. PMLR, 23–29 Jul 2023.
- Lin et al. [2023] Xi Victoria Lin, Xilun Chen, Mingda Chen, Weijia Shi, Maria Lomeli, Rich James, Pedro Rodriguez, Jacob Kahn, Gergely Szilvasy, Mike Lewis, et al. Ra-dit: Retrieval-augmented dual instruction tuning. arXiv preprint arXiv:2310.01352, 2023.
- Liska et al. [2022] Adam Liska, Tomas Kocisky, Elena Gribovskaya, Tayfun Terzi, Eren Sezener, Devang Agrawal, Cyprien De Masson D’Autume, Tim Scholtes, Manzil Zaheer, Susannah Young, Ellen Gilsenan-Mcmahon, Sophia Austin, Phil Blunsom, and Angeliki Lazaridou. StreamingQA: A benchmark for adaptation to new knowledge over time in question answering models. In Kamalika Chaudhuri, Stefanie Jegelka, Le Song, Csaba Szepesvari, Gang Niu, and Sivan Sabato, editors, Proceedings of the 39th International Conference on Machine Learning, volume 162 of Proceedings of Machine Learning Research, pages 13604–13622. PMLR, 17–23 Jul 2022. URL https://proceedings.mlr.press/v162/liska22a.html.
- Long et al. [2022] Alexander Long, Wei Yin, Thalaiyasingam Ajanthan, Vu Nguyen, Pulak Purkait, Ravi Garg, Alan Blair, Chunhua Shen, and Anton van den Hengel. Retrieval augmented classification for long-tail visual recognition. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 6959–6969, 2022.
- Lowe [1999] David G Lowe. Object recognition from local scale-invariant features. In Proceedings of the seventh IEEE international conference on computer vision, volume 2, pages 1150–1157. Ieee, 1999.
- McSherry and Talwar [2007] Frank McSherry and Kunal Talwar. Mechanism design via differential privacy. In 48th Annual IEEE Symposium on Foundations of Computer Science (FOCS’07), pages 94–103. IEEE, 2007.
- Meinhardt et al. [2022] Tim Meinhardt, Alexander Kirillov, Laura Leal-Taixe, and Christoph Feichtenhofer. Trackformer: Multi-object tracking with transformers. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 8844–8854, 2022.
- Ni et al. [2022] Jianmo Ni, Chen Qu, Jing Lu, Zhuyun Dai, Gustavo Hernandez Abrego, Ji Ma, Vincent Zhao, Yi Luan, Keith Hall, Ming-Wei Chang, et al. Large dual encoders are generalizable retrievers. In Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing, pages 9844–9855, 2022.
- OpenAI [2023] OpenAI. Gpt-4 technical report. ArXiv, abs/2303.08774, 2023. URL https://api.semanticscholar.org/CorpusID:257532815.
- Paranjape et al. [2022] Ashwin Paranjape, Omar Khattab, Christopher Potts, Matei Zaharia, and Christopher D Manning. Hindsight: Posterior-guided training of retrievers for improved open-ended generation. In International Conference on Learning Representations, 2022. URL https://openreview.net/forum?id=Vr_BTpw3wz.
- Pinsker [1980] Mark Semenovich Pinsker. Optimal filtering of square-integrable signals in gaussian noise. Problemy Peredachi Informatsii, 16(2):52–68, 1980.
- Raffel et al. [2020] Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, and Peter J Liu. Exploring the limits of transfer learning with a unified text-to-text transformer. The Journal of Machine Learning Research, 21(1):5485–5551, 2020.
- Roweis and Saul [2000] Sam T Roweis and Lawrence K Saul. Nonlinear dimensionality reduction by locally linear embedding. science, 290(5500):2323–2326, 2000.
- Rubin et al. [2022] Ohad Rubin, Jonathan Herzig, and Jonathan Berant. Learning to retrieve prompts for in-context learning. In Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pages 2655–2671, Seattle, United States, July 2022. Association for Computational Linguistics. doi: 10.18653/v1/2022.naacl-main.191. URL https://aclanthology.org/2022.naacl-main.191.
- Ruppert and Wand [1994] David Ruppert and Matthew P Wand. Multivariate locally weighted least squares regression. The annals of statistics, pages 1346–1370, 1994.
- Sachan et al. [2021] Devendra Singh Sachan, Siva Reddy, William L. Hamilton, Chris Dyer, and Dani Yogatama. End-to-end training of multi-document reader and retriever for open-domain question answering. In A. Beygelzimer, Y. Dauphin, P. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, 2021. URL https://openreview.net/forum?id=5KWmB6JePx.
- Sachan et al. [2023] Devendra Singh Sachan, Mike Lewis, Dani Yogatama, Luke Zettlemoyer, Joelle Pineau, and Manzil Zaheer. Questions are all you need to train a dense passage retriever. Transactions of the Association for Computational Linguistics, 11:600–616, 2023.
- Schonberger et al. [2017] Johannes L Schonberger, Hans Hardmeier, Torsten Sattler, and Marc Pollefeys. Comparative evaluation of hand-crafted and learned local features. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 1482–1491, 2017.
- Shalev-Shwartz and Ben-David [2014] Shai Shalev-Shwartz and Shai Ben-David. Understanding machine learning: From theory to algorithms. Cambridge university press, 2014.
- Shuster et al. [2021] Kurt Shuster, Spencer Poff, Moya Chen, Douwe Kiela, and Jason Weston. Retrieval augmentation reduces hallucination in conversation. arXiv preprint arXiv:2104.07567, 2021.
- Siegel [2023] Jonathan W Siegel. Optimal approximation rates for deep relu neural networks on sobolev and besov spaces. Journal of Machine Learning Research, 24(357):1–52, 2023.
- Singhal et al. [2023] Karan Singhal, Shekoofeh Azizi, Tao Tu, S Sara Mahdavi, Jason Wei, Hyung Won Chung, Nathan Scales, Ajay Tanwani, Heather Cole-Lewis, Stephen Pfohl, et al. Large language models encode clinical knowledge. Nature, pages 1–9, 2023.
- Sparck Jones [1972] Karen Sparck Jones. A statistical interpretation of term specificity and its application in retrieval. Journal of documentation, 28(1):11–21, 1972.
- Stone [1977] Charles J Stone. Consistent nonparametric regression. The annals of statistics, pages 595–620, 1977.
- Stone [1980] Charles J Stone. Optimal rates of convergence for nonparametric estimators. The annals of Statistics, pages 1348–1360, 1980.
- Sun et al. [2020] Yu Sun, Xiaolong Wang, Zhuang Liu, John Miller, Alexei Efros, and Moritz Hardt. Test-time training with self-supervision for generalization under distribution shifts. In International conference on machine learning, pages 9229–9248. PMLR, 2020.
- Sutton and Barto [2018] Richard S Sutton and Andrew G Barto. Reinforcement learning: An introduction. MIT press, 2018.
- Thai et al. [2023] Dung Thai, Dhruv Agarwal, Mudit Chaudhary, Rajarshi Das, Manzil Zaheer, Jay-Yoon Lee, Hannaneh Hajishirzi, and Andrew McCallum. Machine reading comprehension using case-based reasoning. arXiv preprint arXiv:2305.14815, 2023.
- Touvron et al. [2023] Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, et al. Llama 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288, 2023.
- Vaswani et al. [2017] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.
- Voorhees et al. [1999] Ellen M Voorhees et al. The trec-8 question answering track report. In Trec, volume 99, pages 77–82, 1999.
- Wang et al. [2022] Shuohang Wang, Yichong Xu, Yuwei Fang, Yang Liu, Siqi Sun, Ruochen Xu, Chenguang Zhu, and Michael Zeng. Training data is more valuable than you think: A simple and effective method by retrieving from training data. In Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 3170–3179, Dublin, Ireland, May 2022. Association for Computational Linguistics. doi: 10.18653/v1/2022.acl-long.226. URL https://aclanthology.org/2022.acl-long.226.
- Williams [1992] Ronald J Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8:229–256, 1992.
- Yang et al. [2021] Shiqi Yang, Joost van de Weijer, Luis Herranz, Shangling Jui, et al. Exploiting the intrinsic neighborhood structure for source-free domain adaptation. Advances in neural information processing systems, 34:29393–29405, 2021.
- Yarotsky [2017] Dmitry Yarotsky. Error bounds for approximations with deep relu networks. Neural Networks, 94:103–114, 2017.
- Zemlyanskiy et al. [2022] Yury Zemlyanskiy, Michiel de Jong, Joshua Ainslie, Panupong Pasupat, Peter Shaw, Linlu Qiu, Sumit Sanghai, and Fei Sha. Generate-and-retrieve: Use your predictions to improve retrieval for semantic parsing. In Proceedings of the 29th International Conference on Computational Linguistics, pages 4946–4951, Gyeongju, Republic of Korea, October 2022. International Committee on Computational Linguistics. URL https://aclanthology.org/2022.coling-1.438.
- Zhang et al. [2006] Hao Zhang, Alexander C Berg, Michael Maire, and Jitendra Malik. Svm-knn: Discriminative nearest neighbor classification for visual category recognition. In 2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR’06), volume 2, pages 2126–2136. IEEE, 2006.
- Zhang [2023] Tong Zhang. Mathematical analysis of machine learning algorithms. Cambridge University Press, 2023.
- Ziegler et al. [2019] Daniel M Ziegler, Nisan Stiennon, Jeffrey Wu, Tom B Brown, Alec Radford, Dario Amodei, Paul Christiano, and Geoffrey Irving. Fine-tuning language models from human preferences. arXiv preprint arXiv:1909.08593, 2019.
Appendix A Preliminaries
Definition A.1 (Rademacher complexity).
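Given a sample $S = \{z_i\}_{i \in [n]} \subset \mathcal{Z}$ and a real-valued function class $\mathcal{F}: \mathcal{Z} \to \mathbb{R}$, the empirical Rademacher complexity of $\mathcal{F}$ with respect to $S$ is defined as
$$\mathfrak{R}_S(\mathcal{F}) = \frac{1}{n}\, \mathbb{E}_{\sigma}\Big[\sup_{f \in \mathcal{F}} \sum_{i=1}^{n} \sigma_i f(z_i)\Big], \tag{24}$$
where $\sigma = \{\sigma_i\}_{i \in [n]}$ is a collection of i.i.d. Rademacher random variables, i.e., symmetric Bernoulli variables taking values in $\{-1, +1\}$. For $n \in \mathbb{N}$, the Rademacher complexity and worst case Rademacher complexity are defined as follows.
$$\mathfrak{R}_n(\mathcal{F}) = \mathbb{E}_{S \sim \mathcal{D}^n}\big[\mathfrak{R}_S(\mathcal{F})\big], \qquad \bar{\mathfrak{R}}_n(\mathcal{F}) = \sup_{S \in \mathcal{Z}^n} \mathfrak{R}_S(\mathcal{F}), \tag{25}$$
where $\mathcal{D}$ denotes the distribution from which the sample is drawn.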
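To make the definition concrete, the following is a minimal sketch (not part of the original analysis) that estimates the empirical Rademacher complexity of a finite function class by Monte Carlo: it averages, over random sign vectors in {-1, +1}^n, the largest sign-weighted sum attained by any function in the class, and divides by the sample size. The function name `empirical_rademacher`, the toy classes, and the number of draws are illustrative assumptions.

```python
import numpy as np

def empirical_rademacher(values, num_draws=2000, seed=0):
    """Monte Carlo estimate of (1/n) * E_sigma[ sup_f sum_i sigma_i f(z_i) ].

    values: array of shape (num_functions, n); row j holds (f_j(z_1), ..., f_j(z_n)).
    """
    rng = np.random.default_rng(seed)
    _, n = values.shape
    # Each row of sigma is one draw of n i.i.d. uniform {-1, +1} signs.
    sigma = rng.choice([-1.0, 1.0], size=(num_draws, n))
    # For every draw, take the supremum over the (finite) class.
    sup_per_draw = (sigma @ values.T).max(axis=1)
    return sup_per_draw.mean() / n

# Toy function classes (assumption: functions bounded in [-1, 1]).
data_rng = np.random.default_rng(1)
n = 50
small_class = data_rng.uniform(-1, 1, size=(2, n))
large_class = np.vstack([small_class, data_rng.uniform(-1, 1, size=(200, n))])

r_small = empirical_rademacher(small_class)
r_large = empirical_rademacher(large_class)   # richer class, larger complexity
r_single = empirical_rademacher(np.ones((1, n)))  # one function: close to zero
```

Because both estimates reuse the same sign draws, the estimate for the larger class is never below that of the smaller class it contains, mirroring the monotonicity of the definition in the function class.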
Definition A.2 (Covering number).
Let and be a norm defined over . Given a function class and a collection of points , we call a set of points an -cover of with respect to , if we have
(26) |
where . The -covering number denotes the cardinality of the minimal -cover of with respect to . In particular, if is a norm (e.g. for ), then we simply use to denote the corresponding -covering number.
When is unambiguous we may drop it, i.e., we use to represent the covering number.
Definition A.3 (Multi-layer perceptron (MLP)).
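We consider, for both the retriever and the predictor, the class of multi-layer perceptrons (MLPs), also known as fully connected deep neural networks, with ReLU nonlinearity $\sigma(x) = \max\{0, x\}$. An MLP is specified by the number of layers $L$ and the width $W$. We define, with weight $M$ and bias $b$, an affine transform $A_{M,b}(x) = Mx + b$. Let $\sigma \circ A_{M,b}$ denote the elementwise application of the ReLU nonlinearity on the affine transform. The class of MLPs with $L$ layers and width $W$ is defined as
$$\mathrm{MLP}(\mathbb{R}^{d}, \mathbb{R}^{k}; W, L) = \big\{A_{M_L, b_L} \circ \sigma \circ A_{M_{L-1}, b_{L-1}} \circ \cdots \circ \sigma \circ A_{M_0, b_0}\big\}, \tag{27}$$
where $M_0 \in \mathbb{R}^{W \times d}$ and $b_0 \in \mathbb{R}^{W}$; $M_i \in \mathbb{R}^{W \times W}$ and $b_i \in \mathbb{R}^{W}$, for $i = 1, \dots, L-1$; and $M_L \in \mathbb{R}^{k \times W}$ and $b_L \in \mathbb{R}^{k}$.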
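As an illustration of this function class, the sketch below builds a fully connected ReLU network as a composition of affine maps with an elementwise ReLU between them, and counts its weights and biases. It assumes that "number of layers" means the number of hidden (ReLU) layers, so a depth-L network uses L + 1 affine maps; the helper names `init_mlp`, `mlp_forward`, and `num_parameters`, and the random initialization, are hypothetical choices for the example.

```python
import numpy as np

def init_mlp(d_in, d_out, width, depth, seed=0):
    """Random parameters for an MLP with `depth` hidden layers of size `width`."""
    rng = np.random.default_rng(seed)
    dims = [d_in] + [width] * depth + [d_out]
    # One (weight, bias) pair per affine map; weights have shape (out, in).
    return [(rng.standard_normal((m, k)) / np.sqrt(k), np.zeros(m))
            for k, m in zip(dims[:-1], dims[1:])]

def mlp_forward(params, x):
    """Apply affine map + ReLU for every hidden layer, then a final affine map."""
    h = x
    for M, b in params[:-1]:
        h = np.maximum(M @ h + b, 0.0)   # elementwise ReLU
    M, b = params[-1]
    return M @ h + b                     # no nonlinearity on the output

def num_parameters(params):
    """Total number of weights and biases."""
    return sum(M.size + b.size for M, b in params)

params = init_mlp(d_in=3, d_out=2, width=4, depth=2)
y = mlp_forward(params, np.array([0.5, -1.0, 2.0]))
total = num_parameters(params)   # (3*4 + 4) + (4*4 + 4) + (4*2 + 2) = 46
```

The parameter count is the quantity that enters VC-dimension and pseudo-dimension bounds for such networks, which is why the sketch exposes it explicitly.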
Definition A.4 (Sobolev space).
For , we denote the set of functions with finite norm over as , i.e., for any , . Note that for , we have Let denote a multi-index, and be its degree. We denote by the weak-derivative with respect to the multi-index for any function.
For an integer , the Sobolev semi-norm for a function that has weak-derivatives of order is defined as
The Sobolev norm for the same function is defined as A function with all weak-derivatives of order , and a finite norm lies in the Sobolev space with derivatives and norm.
In our approximation guarantees for the MLP retriever and predictor classes later, we use [Siegel, 2023, Theorem 1]. We restate the result here for completeness.
Theorem A.5 (Restated Siegel [2023] Theorem 1).
Let be a function in the Sobolev space with derivatives and norm , for and . For and any satisfying , we have for , and
Our generalization bounds leverage the VC-dimension bounds for MLPs from Bartlett et al. [2019]. Here, we state some results from Bartlett et al. [2019] for completeness.
Definition A.6 (VC dimension and growth of a binary function class).
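For $\mathcal{H}$, a class of functions from $\mathcal{X}$ to $\{0, 1\}$, the growth function of $\mathcal{H}$ evaluated on an input set of size $m$ is defined as
$$\Pi_{\mathcal{H}}(m) = \max_{x_1, \dots, x_m \in \mathcal{X}} \big|\{(h(x_1), \dots, h(x_m)) : h \in \mathcal{H}\}\big|.$$
The $\mathrm{VCdim}(\mathcal{H})$ is defined as the largest $m$ such that $\Pi_{\mathcal{H}}(m) = 2^m$, where if no such largest $m$ exists we have $\mathrm{VCdim}(\mathcal{H}) = \infty$.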
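The growth function and VC dimension can be checked by brute force on a simple class. The sketch below is an illustrative example rather than anything taken from the paper: it counts the distinct 0/1 labelings that one-dimensional threshold classifiers h_t(x) = 1[x >= t] induce on small point sets. The finite threshold grid is an assumption standing in for the infinite class; it is fine enough to realize every labeling on the chosen points.

```python
def count_labelings(points, thresholds):
    """Number of distinct 0/1 labelings of `points` by h_t(x) = 1[x >= t]."""
    return len({tuple(int(x >= t) for x in points) for t in thresholds})

# Hypothetical grid of thresholds from -1.00 to 2.00 in steps of 0.01.
thresholds = [i / 100 for i in range(-100, 201)]

one_point = count_labelings([0.5], thresholds)               # 2 = 2^1: shattered
two_points = count_labelings([0.2, 0.8], thresholds)         # 3 < 2^2: not shattered
three_points = count_labelings([0.1, 0.4, 0.7], thresholds)  # 4 = m + 1
```

For thresholds, any m distinct points receive exactly m + 1 labelings, so a single point is shattered but no pair is, which gives a VC dimension of 1 for this class.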
Definition A.7 (Pseudo dimension of real valued function class).
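Let $\mathcal{F}$ be a class of functions from some space $\mathcal{X}$ to the reals $\mathbb{R}$. The pseudo-dimension of class $\mathcal{F}$, denoted by $\mathrm{Pdim}(\mathcal{F})$, is the largest $m$ such that there exists $(x_1, \dots, x_m, y_1, \dots, y_m) \in \mathcal{X}^m \times \mathbb{R}^m$ such that for any binary sequence $(b_1, \dots, b_m) \in \{0, 1\}^m$ there exists a function $f \in \mathcal{F}$ satisfying $f(x_i) > y_i \iff b_i = 1$ for all $i \in [m]$.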
Note that the pseudo-dimension is the same as the VC dimension of the subgraph of class which is used in Zhang [2023]. Let . We denote by the sign of the function . We define , and the VC dimension of the real valued function class as . It is mentioned in Bartlett et al. [2019] that for a neural network with a fixed architecture and fixed activation functions, namely class , we have that .
We now adapt [Bartlett et al., 2019, Theorem 6] to use it for the class that employs the ReLU non-linearity. In the terminology of Bartlett et al. [2019], it amounts to focusing on the number of breakpoints , and degree of polynomial . (Footnote 1: Originally in Bartlett et al. [2019] degree is denoted by and break point by , but we use and , respectively, to avoid confusion. These notations are used for the rest of the paper.)
Theorem A.8 (Adaptation of Bartlett et al. [2019] Theorem 6).
Consider the neural network class that has the ReLU non-linearity. Let denote the total number of parameters (weights and biases) up to layer , and denote the number of non-linear units (output width) in layer . Also define the parameters , and . Then for the function class of all real-valued functions computed by the MLP class and
Moreover, we have
We generalize the above result to capture MLPs with multi-dimensional output, as used by our predictor.
Theorem A.9 (Multi-output version of Bartlett et al. [2019] Theorem 6).
Consider the neural network class that has ReLU non-linearity with , , , and as defined in Theorem A.8. We denote by the class of functions where is the -th output coordinate of a neural network in class . Then, we have
Proof.
Let parameterize one function . Based on the discussions, we need to find the of the set . Note that here we have is a function mapping the tuple to a real number.
We obtain the following inequality.
In the first inequality, we partition the set with respect to . For the second inequality we notice that for a fixed the function is computed by and bound it with the growth function over points. Therefore, for the third inequality we can apply the specified bound for inside the proof of Theorem 6 in Bartlett et al. [2019]. Note that here we have specialized to the ReLU nonlinearity, i.e., breaking point , and degree . Applying Lemma 6 in Bartlett et al. [2019] we obtain
∎
Finally, we state a bounded version of Gibbs' inequality, which lower bounds the cross entropy of two discrete probability distributions.
Proposition A.10 (Truncated Gibbs' inequality).
Let us consider two discrete distributions over alphabet size . Then for any constant , we have
Proof.
For two discrete distributions over alphabet size .
The first inequality follows from the log-sum inequality. The second inequality follows as is maximized by setting one for some , while the rest are set to . The second-to-last inequality follows by . The final inequality follows because taking a minimum with can only decrease the value. ∎
Appendix B Derivations of main result
As discussed in Section 2, the objective here is to study the excess risk in Eq. (12), which has three main components: generalization error, retriever approximation error, and predictor approximation error (cf. (3.1)). In this section, we structure our results somewhat differently from the main body to capture the general setting of learning the retriever with a fixed predictor, and vice versa. We first prove excess risk bounds for learning the retriever, then excess risk bounds for learning the predictor. Finally, we combine the results to obtain the guarantees for jointly learning the retriever and the predictor presented in the paper. For the rest of the analysis we need to specify the space of retrieved examples to define the complexity of the gap function (cf. 3.1). We recall that our retrieved samples are embedded in a compact subspace of , and is a compact subspace of . In particular, for simplicity, we assume that and .
B.1 Learning the retriever
We first study learning the retriever over class when the predictor is fixed. The task of learning the retriever corresponds to minimizing the following over ,
where . We have a closed form for the optimal retriever when not restricted within a function class. The optimal retriever is , where a tie is broken arbitrarily.
For the fixed predictor , let minimize the empirical risk given, and minimize the population risk over the class , i.e.
Here, the probability is defined using the softmax operator for a given as follows:
Hardness of retrieval:
We recall the Sobolev space with derivatives as defined in Appendix A. The following is a restatement of Assumption 3.1, but for any and not just the optimal one .
Assumption B.1 (Complexity of ).
For any , there exists a baseline such that the function with baseline , as defined by lies in the Sobolev space with derivatives and norm.
As noted in the main text, this means that the predictor loss has a possibly ‘complex’ component , and a relatively ‘smooth’ component that ensures that two retrieved examples that are close lead to similar losses for the predictor for any . As solely determines the optimal retrieved set, its smoothness defines the hardness of the underlying retrieval task.
Excess risk decomposition:
With the fixed predictor , excess risk in (12) takes the following form
B.1.1 Generalization error
We now proceed to bound the generalization error using the Rademacher complexity. With probability at least for any ,
(28) |
Using covering number bound with chaining we obtain the final inequality, where
and denote the covering number of the retriever function with error in norm w.r.t. the set and fixed,
The generalization error in retriever learning depends on the covering number of (which we shall see is dependent on the embedding space of the retrieved examples).
As is a fixed retriever, we do not need to take any union bound over the retriever space. Therefore, we have
B.1.2 Approximation error
The approximation error for learning the retriever depends on the hardness of the function . We recall that this term is approximated using softmax over (cf. (6)).
We want to approximate the term for all , by . We can break down the approximation into two parts. First we show that the function approximates for large . In particular, if then softmax approximates minimum with error (see, McSherry and Talwar [2007], Epasto et al. [2020]). Second, we show that can approximate well in norm.
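The first step, softmax approximating the minimum, can be checked numerically. The sketch below relies on the standard exponential-mechanism style guarantee: for N candidate losses and softmax weights proportional to exp(-tau * loss), the weighted average loss exceeds the minimum loss by at most log(N) / tau. This is a stated assumption about the form of the guarantee; the constants and conditions used in the analysis above may differ. The name `softmin_average` and the toy losses are illustrative.

```python
import numpy as np

def softmin_average(losses, tau):
    """Average loss under softmax weights proportional to exp(-tau * loss)."""
    logits = -tau * (losses - losses.min())   # shift for numerical stability
    weights = np.exp(logits)
    weights /= weights.sum()
    return float(weights @ losses)

rng = np.random.default_rng(0)
losses = rng.uniform(0.0, 1.0, size=64)       # toy per-candidate losses

# Gap between the softmax-weighted loss and the true minimum, per temperature.
gaps = {tau: softmin_average(losses, tau) - losses.min()
        for tau in (1.0, 10.0, 100.0)}
bounds = {tau: np.log(losses.size) / tau for tau in gaps}
```

Each gap is nonnegative and no larger than the corresponding entry of `bounds`, and the bound shrinks as `tau` grows. This is the sense in which a sharp enough softmax over retrieved candidates behaves like picking the best one.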
We define
Here recall that is the baseline function in Assumption 3.1. An example of such baseline is the loss under the optimal retrieved sample for each , i.e. .
For any , we have
In the first inequality , we replace which is the optimal retriever for predictor with an arbitrary retriever . The first term in the inequality uses the norm bounds for inner product, while the second term follows from Theorem 3.1 in [Epasto et al., 2020] (which originates from [McSherry and Talwar, 2007]). The inequality uses the fact that softmax functions over classes follow (see henrikl [https://math.stackexchange.com/users/351007/henrikl]). In the final inequality , we use to bound the norm of .
As the above bound holds for any , by optimizing over and we obtain,
(29) |
Since the right-hand side of the inequality holds for any , if there exists a such that the function approximates the function well, we end up with a small approximation error.
B.1.3 Instantiation of MLP retriever
We consider to be the class of MLPs defined in Equation (27). Since MLPs with appropriate depth and width have universal approximation properties, this choice of ensures that the function approximates the function well. To bound the excess risk of learning the retriever, we need to prove generalization error and approximation error bounds for the MLP class.
Generalization error for MLP retriever:
To bound the generalization error, we need to first bound the covering number , for . Here, and i.e., the retrieved space is embedded in . We first want to bound the covering number with a covering number of .
For a fixed data set ; predictor ; and two retrievers
Above, the inequality follows by upper bounding with . The inequality uses the fact that softmax functions over classes follow .
Let us define the norm as Now consider a norm cover of , with cardinality .
Note that, by definition, for any , there exists a such that . This means, that forms a -cover in the norm. In other words, we have
Most existing results on covering number bounds for MLPs assume norm bounds on the MLP weights and biases. However, we do not impose such norm bounds. Therefore, we use the pseudo-dimension of the class from Bartlett et al. [2019] to bound the covering number using Zhang [2023]. In particular, if the pseudo-dimension of is , then we have as per [Zhang, 2023, Theorem 5.11]. From [Bartlett et al., 2019, Theorem 6] we know that for the class the pseudo-dimension is , where is the number of parameters, and is the number of computation units. By setting for a constant , and in Equation (28), for large enough (we will set as a function of the data size ) we obtain the final generalization error as
(30) |
Approximation error for MLP retriever:
Our excess risk bounds closely follow the work of Siegel [2023], which generalizes Yarotsky [2017]. (We note that Siegel [2023] works with , and as mentioned therein, the analysis can be extended to a bounded domain, e.g. which includes our setting. Furthermore, one can extend the analysis to non-integer Sobolev and Besov spaces following Siegel [2023].) Under Assumption B.1, by specializing [Siegel, 2023, Theorem 1] with we get that
for , and (independent of L). Note that is the number of derivatives of the Sobolev space under consideration in Assumption B.1.
Therefore, under Assumption B.1 for we show that
(31) |
This follows from the following series of inequalities:
The first inequality follows from Equation (29). The second equality , replaces . The inequality follows by optimizing over the class , as we see then also lies in , and applying Theorem 1 in Siegel [2023]. The final inequality combines and bounds .
Note that the choice of is not algorithmic, so we can optimize for . In particular, we choose to obtain the approximation error bound as , where we treat the remaining terms that are independent of and as constants.
Excess risk for MLP retriever learning:
B.2 Learning the predictor
We now quantify the excess risk of a predictor for a fixed retriever . For a fixed retriever , the learning task of the predictor is to minimize
The predictor now learns from the joint distribution . We assume that the hardness of the classification task performed by the predictor varies with the selected retriever .
Similar to retriever learning in Section B.1, for a fixed retriever , the predictor that minimizes the empirical risk , and the predictor that minimizes the population risk over the class are defined as
where . We also define the predictor over the class with ‘optimal’ retrieval (possibly outside of ) that minimizes the population risk as
Usefulness of data-store:
We start with characterization of the prediction task in the presence of the data-store . We consider that there exists a score function and corresponding probability distribution
(34) |
that approximates well for all and . Furthermore, this score function lies coordinate-wise in the Sobolev space (see Definition A.4). Assumption 3.2 captures the above. We restate the assumption here for convenience.
Assumption B.2 (Retrieval quality).
There exists a score function such that
1. for each , the function (the -th coordinate of ) lies in the Sobolev space with derivatives and finite norm,
2. for any there exists a retrieved example such that for as defined in Equation (34)
Note that the tuple defines the usefulness of the data-store . In particular, the higher the the closer the approximation, and the higher the and the smaller the embedding dimension the ‘easier’ the score function used for this approximation.
Excess risk decomposition
The excess risk decomposition for the learned predictor takes the following form.
(35) |
Note that in the inequality , the predictor which is optimised for the fixed retriever has lower risk compared to the predictor , i.e. .
B.2.1 Approximation error
We specialize our analysis for the log-loss bounded by given as
(36) |
Note that we need to bound the predictor error for the bounded log-loss. We want to relate this term to the (cf. Equation (34)) for which we have good control over its complexity. We first need a lower bound for as a function of . We proceed as follows:
(37) |
In the first inequality, applying Proposition A.10 to our setting with and we obtain the lower bound. The second inequality follows from mean-value theorem as below,
The next inequality is obtained from Assumption B.2, with as defined therein. The final inequality substitutes where is the score function used in Equation (34).
We now derive an upper bound for the predictor error part of our excess risk bound in Equation (35). Let be an arbitrary predictor
Predictor Error | |||
The second inequality follows by substituting the lower bound of from Equation (37). As optimizes -risk over , we can substitute with the arbitrary predictor to obtain an upper bound. The final inequality is obtained by substituting instead of minimizing with respect to . Note that the final inequality holds for all as the initial choice of was arbitrary.
Bounding the term is similar to bounding the -risk for classification with the data distribution . Our strategy is to bound the -risk by the distance between the score functions and the score function which lies in the Sobolev space as given in Assumption B.2. In particular, we have the following norm bound.
The inequality follows by substituting the bounded log-loss, and using the fact that for any two , . The final inequality follows by bounding the first term by second.
We note that the above holds for all . This gives the general approximation error bound as
(38) |
Note that the predictor approximation error is independent of retriever learning, as it is compared with respect to the Bayes optimal retriever (i.e. ) as seen in Equation (35).
B.2.2 Generalization error
The generalization error in Equation (35) can be bounded in a similar manner as the retriever learning in Section B.1. Note that the predictor is learnt over the space while the retriever is fixed in this setup.
The final inequality again follows using covering number based bounds with chaining (cf. Shalev-Shwartz and Ben-David [2014]). We have used for a fixed retriever
and denote the covering number of the predictor function class with error in norm w.r.t. the set and fixed ,
As is fixed for a fixed , we can directly bound without any union over the learner/predictor space,
B.2.3 Instantiation of MLP predictor
As a concrete example, we now consider the space as defined in Equation (27).
Approximation error of MLP predictor:
Our approximation results rely mainly on the results in Siegel [2023]. The key difference here is that the output is now dimensional. We find an MLP of depth and width at most to individually approximate the functions for each . We can then join these networks in parallel to obtain a final network with depth and width at most . In principle, these networks may share sub-networks (e.g. the bit extraction networks, the sub-domain indexation network for in Siegel [2023]) used for constructing the approximation. However, this is out of scope for this work, and we leave it open.
From Theorem 1 in Siegel [2023], by taking in the theorem statement, under Assumption B.2 we get that for each there exists an MLP such that
for , and (independent of L). By concatenating the networks for in parallel (cf. Lemma 5 in Siegel [2023]), and using the first layer to share the input to these parallel networks, we obtain an MLP , , such that we have
By using in our bounds we obtain the predictor error as
(39) |
Generalization error for MLP predictor:
We now bound the generalization error in Equation (35) when denotes a class of multi-layer perceptrons (MLPs) with ReLU nonlinearity .
The first step is to bound the covering number norm with the covering number , where is defined as
For a fixed data set and retriever , and two predictors , we have
The first inequality follows from Jensen. For the case of bounded log-loss, we obtain the second inequality using the fact that for any two , .
Let be a norm cover for the space of cardinality . That implies, for any there exists a such that . Therefore, due to the above inequality, we have . So forms a cover of with respect to the norm. Hence,
We need to bound next. Similar to the retrieval analysis in Section B.1, we first apply Zhang [2023] to bound the covering number with the pseudo-dimension. However, we need a slight reformulation of the function to apply the results therein. Let us define the function , where for each we have . It is easy to see that the covering of the set remains unchanged under this reformulation. In particular, if the pseudo-dimension of is , then we have as per Theorem 5.11 in Zhang [2023].
Next we derive the pseudo-dimension of the class using Bartlett et al. [2019]. One challenge here is that, for the MLP we are considering, the label does not lie in the input space; rather, it corresponds to one coordinate of the -dimensional output. This can be captured with a slight modification of Theorem 6 in Bartlett et al. [2019], namely Theorem A.9 in Appendix A. By Theorem A.9 we have for the VC dimension of as . The final generalization bound obtained is
(40) |
Excess risk of predictor learning:
We can now combine the generalization error (40) and approximation error (38) to obtain the final excess risk. The final excess risk is upper bounded as
Excess Risk | ||||
(41) |
We let the data-store grow polynomially with the data, , and we let . For and , the final error bound for the predictor follows by setting . Note that the choices of and here are related to the predictor size, and are independent of the choices for the retriever size. Moreover, we see that Assumption B.2 forces the quality of the retrieved set to become the bottleneck in the predictor excess risk if we have for .
B.3 Joint learning of retriever and predictor
In this section, we consider the task of jointly learning the predictor and retriever from the space and , respectively. The empirical optimizer pair and the population optimizer for the joint task are given as follows.
Recall, the optimal predictor with best possible retrieval is We denote the optimal retriever for as .
The excess risk for the classes and can be bounded as
In the inequality , we substitute the pair for as the former may have higher loss than latter. For the pair the predictor error is easily controlled. Also, note that the retriever is optimized for the optimal predictor . Therefore, unlike the fixed predictor case in Section B.1 we do not have additional predictor error. We next bound the generalization and approximation errors separately by combining the retriever and predictor errors derived earlier.
B.3.1 Generalization Error
First, for the fixed pair we bound the generalization error as
Next, the generalization error can be bounded as
(42) |
The second inequality again follows using covering number based bounds with chaining (cf. Shalev-Shwartz and Ben-David [2014]). We have used for a fixed retriever
and denotes the covering number of the retriever function class with error in norm w.r.t. the set , i.e.,
The covering number in Equation (42) can be bounded using the retriever and predictor learning complexities as
B.3.2 Approximation error
The approximation errors of the predictor and retriever decouple under our decomposition and under Assumptions B.1 and B.2. The approximation error is therefore bounded by the sum of the approximation error of retriever learning with the optimal predictor and the approximation error of predictor learning. Our derived bound on the approximation error of the retriever holds uniformly for all predictors, so it also holds for the optimal predictor. This implies that the joint retriever and predictor learning error is bounded (orderwise) by the sum of the retriever and predictor errors derived earlier in (29) and (38).
Proof of Theorem 3.3.
We define . Putting the approximation and generalization errors together we obtain the final excess risk bound as
This completes the proof. ∎
B.3.3 Instantiation of MLP retriever and predictor
For the scenario where the retriever and predictor are MLPs, we can reuse the earlier analysis to provide the excess risk bound here.
Proof of Theorem 3.4.
Recall from Appendix B.1.3 (Equation (32)) that a retriever MLP with depth and width gives an approximation error and a generalization error .
Similarly, from Appendix B.2.3 (Equation (39)), an MLP predictor with depth and width has an approximation error , and a generalization error .
Thus, the combined error in this case is given as
This completes the proof. ∎
Finally, letting and combining the excess risk of retriever learning (2nd term in (33)) and of predictor learning (2nd term in (41)), the joint learning excess error rate is given as
Joint Excess Risk MLP | |||
(43) |
Here is defined in Assumption B.1, and are defined in Assumption B.2. Also, is the embedding dimension of input and is the embedding dimension of retrieved example .
Appendix C More experiments
| Method | small | | | base | | | large | | |
|---|---|---|---|---|---|---|---|---|---|
| | small | base | large | small | base | large | small | base | large |
| EMDR2 | 40.0 | 47.7 | 52.0 | 41.5 | 48.0 | 51.4 | 41.6 | 48.8 | 52.6 |
| PDist | 49.7 | 57.4 | 61.3 | 48.6 | 57.0 | 61.0 | 47.7 | 55.7 | 58.9 |
| Reverse Cross-Entropy + PG | 44.9 | 52.6 | 54.7 | 45.3 | 53.3 | 55.2 | 44.9 | 51.7 | 54.9 |
| Reverse Cross-Entropy + TopK | 48.9 | 56.8 | 60.9 | 47.9 | 55.5 | 59.6 | 46.7 | 54.3 | 58.2 |

| Method | small | | | base | | | large | | |
|---|---|---|---|---|---|---|---|---|---|
| | small | base | large | small | base | large | small | base | large |
| EMDR2 | 46.6 | 54.7 | 62.4 | 46.1 | 55.7 | 61.6 | 46.0 | 53.9 | 59.5 |
| PDist | 59.6 | 68.6 | 72.8 | 59.1 | 61.9 | 72.2 | 56.4 | 59.3 | 69.3 |
| Reverse Cross-Entropy + PG | 58.1 | 60.7 | 70.7 | 56.9 | 66.1 | 64.2 | 54.2 | 61.4 | 61.3 |
| Reverse Cross-Entropy + TopK | 57.1 | 64.5 | 69.1 | 55.9 | 63.5 | 68.1 | 54.2 | 61.2 | 65.8 |

| small | | | base | | | large | | |
|---|---|---|---|---|---|---|---|---|
| small | base | large | small | base | large | small | base | large |
| 96.4M | 170.9M | 396.4M | 258.8M | 333.3M | 558.9M | 773.6M | 848.1M | 1073.7M |
C.1 Implementation details
Computing the objective (13), let alone its gradient, requires evaluating the retriever and predictor over the entire data-store, making it prohibitively expensive. We explore two ways to approximately compute the objective:
Top-K approximation
This approach involves constraining the summation to a specific subset. Periodically we compute for all items based on the current value of . We use this to obtain a set of documents with the highest (stale) scores, i.e. and evaluate the sum on this.
(44) |
This methodology is akin to those adopted by EMDR2 and PDist, with the set being refreshed every 500 training steps and the selection of .
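The Top-K procedure can be sketched in a few lines of Python. This is a toy illustration under stated assumptions, not the implementation used in the experiments: it assumes the objective is a softmax-weighted sum of per-document predictor losses, and the helpers `top_k_indices`, `top_k_objective`, `score_fn` and `loss_fn`, as well as all numeric values, are hypothetical.

```python
import heapq
import math

def top_k_indices(stale_scores, k):
    """Indices of the k largest (possibly stale) retriever scores."""
    return heapq.nlargest(k, range(len(stale_scores)), key=lambda i: stale_scores[i])

def top_k_objective(subset, score_fn, loss_fn):
    """Softmax-weighted loss with the sum restricted to `subset`.

    score_fn(i): current retriever score of document i (hypothetical callable).
    loss_fn(i):  predictor loss when document i is retrieved (hypothetical callable).
    Only documents in `subset` are ever scored or passed to the predictor.
    """
    subset = list(subset)
    scores = [score_fn(i) for i in subset]
    shift = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - shift) for s in scores]
    total = sum(weights)
    return sum(w / total * loss_fn(i) for w, i in zip(weights, subset))

# Toy data-store with six documents.
stale_scores = [0.1, 2.0, -1.0, 1.5, 0.3, 1.9]    # computed at the last refresh
current_scores = [0.2, 2.1, -0.9, 1.4, 0.3, 2.0]  # scores under current parameters
losses = [1.0, 0.2, 1.5, 0.6, 0.9, 0.3]

subset = top_k_indices(stale_scores, k=3)
approx = top_k_objective(subset, lambda i: current_scores[i], lambda i: losses[i])
full = top_k_objective(range(len(losses)), lambda i: current_scores[i], lambda i: losses[i])
print(subset, approx, full)
```

The subset is selected with stale scores, while the weights inside the restricted sum use current scores, so gradients can still flow to the retriever between refreshes.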
Policy gradient
Based on the connection to RLHF/RLAIF, we propose to use a policy gradient method [Sutton and Barto, 2018] to efficiently obtain an unbiased estimate of the gradient with respect to . However, as policy gradients suffer from high variance [Burda et al., 2015, Grathwohl et al., 2021], we use a constant baseline [Williams, 1992] for variance reduction, i.e. our objective becomes
(45) | ||||
where are i.i.d. samples from the retriever distribution. We use and .
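To illustrate why a constant baseline leaves the gradient estimate unbiased, the following Python sketch uses a toy categorical retriever whose distribution is a softmax over one logit per document. It is a sketch under assumptions rather than the training code: the objective is taken to be the expected loss under the retriever distribution, and the logits `theta`, the `losses`, the number of samples and the baseline value are all hypothetical.

```python
import math
import random

def softmax(theta):
    shift = max(theta)
    weights = [math.exp(t - shift) for t in theta]
    total = sum(weights)
    return [w / total for w in weights]

def exact_gradient(theta, losses):
    """Gradient of J(theta) = sum_i p_i * loss_i, i.e. p_j * (loss_j - J)."""
    p = softmax(theta)
    objective = sum(pi * li for pi, li in zip(p, losses))
    return [pi * (li - objective) for pi, li in zip(p, losses)]

def reinforce_estimate(theta, losses, samples, baseline):
    """Score-function estimate: mean of (loss_z - baseline) * grad log p(z)."""
    p = softmax(theta)
    grad = [0.0] * len(theta)
    for z in samples:
        advantage = losses[z] - baseline
        for j in range(len(theta)):
            # d/d theta_j of log softmax(theta)[z] equals 1{j == z} - p_j
            score_j = (1.0 if j == z else 0.0) - p[j]
            grad[j] += advantage * score_j / len(samples)
    return grad

rng = random.Random(0)
theta = [0.2, -0.5, 1.0, 0.0]   # hypothetical retriever logits
losses = [1.2, 0.4, 0.9, 2.0]   # hypothetical predictor losses per document
samples = rng.choices(range(len(theta)), weights=softmax(theta), k=8)
estimate = reinforce_estimate(theta, losses, samples, baseline=1.0)
print(estimate)
print(exact_gradient(theta, losses))
```

Because the expectation of the score term is zero, subtracting any constant from the loss changes only the variance of the estimate, not its mean.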
C.2 Training details
Dataset. The versions of the open-domain QA datasets we use are:
- •
- •
Optimization. For all of our experiments, we use the ADAM optimizer with weight decay, a short warm-up period (2000 steps), and a linear decay schedule. We use a peak learning rate of . The weight decay factor is 0.1. We chose batch sizes to be . The number of total training steps is as follows:
- No retriever, train predictor : 40,000
- Fixed retriever , train predictor : 20,000
- Fixed predictor , train retriever : 20,000
- Jointly train predictor and retriever : 40,000
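The warm-up and decay schedule from the optimization setup above can be written as a small Python function. This is a sketch under assumptions: the learning rate is taken to rise linearly from zero during warm-up and to decay linearly to zero at the final step, and `peak_lr` is a placeholder argument because the peak value is not reproduced here.

```python
def learning_rate(step, total_steps, peak_lr, warmup_steps=2000):
    """Linear warm-up to `peak_lr`, then linear decay to zero at `total_steps`.

    Decaying exactly to zero is an assumption; the end value is not stated.
    """
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    remaining = max(total_steps - step, 0)
    return peak_lr * remaining / (total_steps - warmup_steps)

# Hypothetical peak value, used for illustration only.
for step in (0, 1000, 2000, 21000, 40000):
    print(step, learning_rate(step, total_steps=40000, peak_lr=1e-3))
```

The same function covers the 20,000-step runs by changing `total_steps`.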
Initializations. We initialize models for different configurations as follows:
- No retriever, train predictor : We initialize the predictor from the public pretrained T5 checkpoint.
- Fixed retriever , train predictor : We initialize the fixed retriever from the public pretrained GTR checkpoint and the predictor from the public pretrained T5 checkpoint.
- Fixed predictor , train retriever : We initialize the fixed predictor from the final checkpoint of the previous run, i.e. “Fixed retriever , train predictor ”. The retriever is initialized from the public pretrained GTR checkpoint.
- Jointly train predictor and retriever : We initialize the retriever from the public pretrained GTR checkpoint and the predictor from the public pretrained T5 checkpoint.