A Statistical Framework for Data-dependent Retrieval-Augmented Models

Soumya Basu, Ankit Singh Rawat, Manzil Zaheer (equal contribution, in alphabetical order)
Google, New York
{basusoumya,ankitsrawat,manzilzaheer}@google.com
Abstract

Modern ML systems increasingly augment input instances with additional relevant information to enhance final prediction. Despite growing interest in such retrieval-augmented models, their fundamental properties and training are not well understood. We propose a statistical framework to study such models with two components: 1) a retriever to identify the relevant information out of a large corpus via a data-dependent metric; and 2) a predictor that consumes the input instances along with the retrieved information to make the final predictions. We present a principled method for end-to-end training of both components and draw connections with various training approaches in the literature. Furthermore, we establish excess risk bounds for retrieval-augmented models while delineating the contributions of both retriever and predictor towards the model performance. We validate the utility of our proposed training methods along with the key takeaways from our statistical analysis on the open-domain question answering task, where retrieval augmentation is important.

1 Introduction

Recent advancements in machine learning (ML) have not only led to breakthroughs on long-standing challenging tasks across various fields, but they have also inspired a great deal of interest in developing ML models that can solve even harder tasks (Meinhardt et al., 2022; Lewkowycz et al., 2022; Cramer, 2021) or target completely new fields (Austin et al., 2021; OpenAI, 2023; Singhal et al., 2023). While scaling the size of parametric ML models, such as neural networks, is becoming the predominant approach to meet such demands (Brown et al., 2020; Chowdhery et al., 2022; Touvron et al., 2023; Dosovitskiy et al., 2021; Dehghani et al., 2023), the excellent performance realized by this approach is marred by drawbacks such as high computational cost, inefficient storage of world knowledge in parameters, lack of transparency in model behavior, and reduced grounding/factuality of model predictions.

Recognizing these shortcomings, retrieval-augmented models (RAMs) have emerged as a promising alternative. Such models typically employ two components, namely a retriever and a predictor, during inference on a given input instance: the retriever first identifies instance-specific relevant information from a data-store, and then the predictor jointly processes the retrieved information and the input instance to make a final prediction. In practice, RAMs have enjoyed a favorable performance vs. compute trade-off (Borgeaud et al., 2021; Das et al., 2021; Thai et al., 2023), as employing moderate-size parametric models as retriever and predictor in a RAM often matches or exceeds the performance of a much larger standalone ML model that directly maps input instances to predictions. Similarly, conditioning predictions on the retrieved information has been shown to improve grounding (Shuster et al., 2021; Lin et al., 2023; Asai et al., 2023). Furthermore, having access to an external corpus can obviate the need to store task-specific world knowledge in model parameters and enables incorporating dynamically evolving knowledge (Izacard et al., 2022; Liska et al., 2022).

Despite these desirable characteristics, training RAMs presents multiple challenges. The natural approach of independently training the retriever and predictor can be sub-optimal (Izacard et al., 2022). Moreover, it requires collecting intermediate supervision on the instance-dependent relevant information to retrieve, which is missing in common datasets and expensive to obtain in general. A common strategy to circumvent the lack of intermediate supervision is to perform end-to-end training, which presents its own unique challenges in the context of RAMs. Fundamentally, retrieval corresponds to the non-differentiable discrete operation of selecting relevant information from a data-store, e.g., via top-k selection based on retriever scores, which prevents direct gradient propagation to the entire retriever. Several clever solutions to the above-mentioned issues have been proposed in the literature, focusing on different training objectives to propagate the learning signal from the predictor into the retriever. However, a formal study that unifies these solutions is missing from the literature.

Another key challenge that prevents the resource-efficient development and deployment of RAMs is the limited understanding of their basic properties, such as their generalization behavior and expressive power. For instance, how do the retriever and predictor components interact to ensure good task-specific performance? Are there any principles guiding the selection of the retriever and predictor components? How do the data-store and its size affect the final performance of a RAM?

In this paper, we address both aforementioned shortcomings in the literature pertaining to RAMs. To unify the training of RAMs, we begin by writing down the natural objective function, which has somehow eluded the literature. This natural objective simply minimizes the expected prediction loss, where the expectation is taken over the distribution induced by the retriever. Empirically, we find this objective to be effective on standard benchmarks: NaturalQuestions (NQ; Kwiatkowski et al., 2019) and TriviaQA (Joshi et al., 2017).

As for the generalization and expressive power, we present an excess risk bound for RAMs that captures the effect of the retriever and predictor function classes. The proposed bound highlights how the retriever and predictor components play complementary roles in reducing the approximation error as we increase their respective function class complexity. We also capture the role of the data-store in improving model performance by reducing the approximation error. On the generalization front, we carefully decouple the generalization term in the excess risk over the predictor and retriever function classes. This allows us to tightly control the generalization term with only a logarithmic dependence on the data-store size. As a concrete instantiation of our excess risk bounds, we consider feed-forward neural networks of varying depth for both the retriever and the predictor.

To summarize, our main contributions include:

  • We present a principled objective for end-to-end training of RAMs focusing on a classification setting (Secs. 2 and 3) and draw connections between existing approaches for training RAMs (Sec. 3.6).

  • We derive an excess risk bound highlighting the role played by the retriever and predictor function classes as well as the data-store towards ensuring improved performance by RAMs (Sec. 3.4), capturing the trade-off between model capacities at the retriever and predictor (Sec. 3.5).

  • We validate the utility of the proposed objective on two standard QA benchmarks: NaturalQuestions (NQ) and TriviaQA (Sec. 4).

2 Problem setup

In this paper, we focus on developing a systematic understanding of RAMs with learned retrievers in a classification setting where the model has access to a data-store. Towards this, we begin by formally defining the problem setup and providing the necessary background along with the notations used.

Let us first consider the standard classification setting, which requires predicting a class in $\mathscr{Y}$ for a given instance $x \in \mathscr{X}$. Assume that $\mathsf{D}_{XY}$ captures the underlying data distribution and one has access to $n$ training examples $\mathscr{S}_n \triangleq \{(x_i, y_i)\}_{i \in [n]}$ that are independent and identically distributed (i.i.d.) according to $\mathsf{D}_{XY}$. Given $\mathscr{S}_n$, one hopes to learn a classifier $f: \mathscr{X} \to \mathbb{R}^{|\mathscr{Y}|}$ that minimizes the misclassification error:

$$R(f) = \mathbb{P}_{(X,Y) \sim \mathsf{D}_{XY}}\Big[\operatorname*{arg\,max}_{y \in \mathscr{Y}} f^{y}(X) \neq Y\Big], \tag{1}$$

where $f^{y}(x)$ denotes the score that $f$ assigns to the $y$-th class, given the input instance $x$. Since directly optimizing the misclassification error or $0/1$-loss poses computational challenges, one typically selects the classifier that minimizes the empirical risk associated with a well-behaved surrogate loss function $\ell: \mathbb{R}^{|\mathscr{Y}|} \times \mathscr{Y} \to \mathbb{R}$ on the training sample $\mathscr{S}_n$:

$$R_{\ell,n}(f) = \frac{1}{n} \sum_{i \in [n]} \ell\big(f(x_i), y_i\big). \tag{2}$$

The (population) risk associated with the surrogate loss function takes the following form:

$$R_{\ell}(f) = \mathbb{E}_{(X,Y) \sim \mathsf{D}_{XY}}\big[\ell\big(f(X), Y\big)\big]. \tag{3}$$
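As a small illustration of (1) through (3), the Python sketch below contrasts the empirical 0/1 error with the empirical surrogate risk (2), assuming the softmax cross-entropy as the surrogate $\ell$ (one common choice, not the only one). Score vectors stand in for the classifier outputs $f(x_i)$; all function names are illustrative.

```python
import math


def argmax(scores):
    # Index of the highest score; ties are resolved toward the smaller index.
    return max(range(len(scores)), key=lambda j: scores[j])


def misclassification_error(score_vectors, labels):
    # Empirical analogue of (1): fraction of examples with argmax_y f^y(x) != y.
    wrong = sum(1 for s, y in zip(score_vectors, labels) if argmax(s) != y)
    return wrong / len(labels)


def cross_entropy(scores, y):
    # Softmax cross-entropy surrogate: log sum_{y'} exp(s_{y'}) - s_y,
    # computed with a max shift so large scores do not overflow.
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_z - scores[y]


def empirical_risk(score_vectors, labels, loss=cross_entropy):
    # Eq. (2): R_{l,n}(f) = (1/n) * sum_i l(f(x_i), y_i).
    return sum(loss(s, y) for s, y in zip(score_vectors, labels)) / len(labels)
```

For score vectors `[[2.0, 0.0], [0.0, 1.0], [1.0, 3.0]]` with labels `[0, 1, 0]`, only the last example is misclassified, so the empirical 0/1 error is 1/3, while the surrogate risk is a smooth function of the scores that gradient-based training can work with.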

Different from the standard classification setup described above, we now consider the classification task with access to a data-store: given an instance $x$, the classifier can potentially leverage a data-store $\mathscr{I} \subseteq \mathscr{Z}$, a collection of potentially relevant information or evidences, where $\mathscr{Z}$ denotes the space of all possible evidences. Accordingly, one can define the empirical and population risks of a classifier $f(\cdot, \mathscr{I}): \mathscr{X} \to \mathbb{R}^{|\mathscr{Y}|}$ as follows:

$$R_{\ell,\mathscr{I},n}(f) = \frac{1}{n} \sum_{i \in [n]} \ell\big(f(x_i, \mathscr{I}), y_i\big), \tag{4}$$
$$R_{\ell,\mathscr{I}}(f) = \mathbb{E}\big[\ell\big(f(X, \mathscr{I}), Y\big)\big], \tag{5}$$

where the expectation is taken over $(X,Y) \sim \mathsf{D}_{XY}$ as well as the possible randomness in $f(\cdot, \mathscr{I})$. However, due to its prohibitive computational cost, such a general classifier that directly processes the entire data-store for each prediction is far from how an additional data-store is utilized by ML models in practice.

This motivates us to study the following explicit retrieval-augmented classification setup to utilize the data-store: given an input instance $x \in \mathscr{X}$, one first retrieves input-dependent supporting evidences $\mathscr{E}^{x} \subset \mathscr{I}$ with the help of a retriever model which has access to the entire data-store $\mathscr{I}$. Then, given $x$ and $\mathscr{E}^{x}$, one invokes a predictor model to predict the class associated with $x$. Thus, a retrieval-augmented classification setup consists of two key component models: 1) a retriever model and 2) a predictor model, which we formally introduce next.

Retriever model. For the retrieval stage, we rely on a retriever model to capture the relevance of an evidence $z \in \mathscr{I}$ towards the input instance $x \in \mathscr{X}$. Let $r_{\theta}: \mathscr{X} \times \mathscr{Z} \to \mathbb{R}$ be the retriever model parameterized by $\theta \in \Theta$ that assigns a relevance score $r_{\theta}(x, z)$ to the instance-evidence pair $(x, z)$. Furthermore, for each instance $x$, the retriever model $r_{\theta}$ induces the following distribution over the set of potential evidences:

$$p_{\theta,\mathscr{I}}(z \mid x) = \frac{\exp\big(r_{\theta}(x, z)\big)}{\sum_{z' \in \mathscr{I}} \exp\big(r_{\theta}(x, z')\big)}, \quad \forall\, z \in \mathscr{I}. \tag{6}$$
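A minimal sketch of (6), assuming for illustration a dot-product retriever $r_\theta(x, z) = \langle x, z \rangle$ on fixed embeddings (a hypothetical choice; the framework allows any score function). Subtracting the maximum score before exponentiating leaves the distribution unchanged but avoids overflow.

```python
import math


def dot_retriever(x, z):
    # Hypothetical retriever score r_theta(x, z): inner product of embeddings.
    return sum(a * b for a, b in zip(x, z))


def retrieval_distribution(x, datastore, score_fn=dot_retriever):
    # Eq. (6): p_{theta,I}(z | x) = exp(r(x, z)) / sum_{z'} exp(r(x, z')).
    scores = [score_fn(x, z) for z in datastore]
    m = max(scores)  # shift by the max score for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [w / total for w in weights]
```

With $x = (1, 0)$ and a data-store of embeddings $(1, 0)$, $(0, 1)$, $(0, 0)$, the scores are $1, 0, 0$, so the first evidence receives probability $e/(e+2)$ and the other two share the remainder equally.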

There are multiple strategies to construct the set of input-dependent supporting evidences $\mathscr{E}^{x}$ based on $r_{\theta}$. For example, for a fixed integer $k \geq 1$, one could select the $k$ evidences corresponding to the $k$ highest scores in $\{r_{\theta}(x, z)\}_{z \in \mathscr{I}}$. Another strategy is to sample $k$ evidences according to the distribution $p_{\theta,\mathscr{I}}(\cdot \mid x)$ in (6). Here, one could perform the sampling with or without replacement. In what follows, we denote the retrieved supporting evidence for the instance $x$ as $\mathscr{E}^{x}_{\theta}$ to highlight the dependence on the underlying retriever model.
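The retrieval strategies above can be sketched as follows. The helper names `top_k`, `sample_with_replacement`, and `sample_without_replacement` are hypothetical, and sampling without replacement is implemented here by sequential draws that renormalize over the remaining evidences, which is one of several possible schemes.

```python
import random


def top_k(scores, k):
    # Deterministic retrieval: indices of the k highest retriever scores.
    return sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)[:k]


def sample_with_replacement(probs, k, rng):
    # k i.i.d. draws from p_{theta,I}(. | x); an evidence may repeat.
    return rng.choices(range(len(probs)), weights=probs, k=k)


def sample_without_replacement(probs, k, rng):
    # Sequential draws: each chosen evidence is removed and the rest are
    # implicitly renormalized by random.choices on the remaining weights.
    remaining = list(range(len(probs)))
    chosen = []
    for _ in range(k):
        w = [probs[j] for j in remaining]
        j = rng.choices(remaining, weights=w, k=1)[0]
        chosen.append(j)
        remaining.remove(j)
    return chosen
```

Top-$k$ selection is deterministic given the scores, whereas the two sampling variants make $\mathscr{E}^{x}_{\theta}$ random; the latter is the source of the retrieval randomness in the population risk defined below.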

Predictor model. Let $h_{\xi}: \mathscr{X} \times \mathscr{I}^{\ast} \to \mathbb{R}^{|\mathscr{Y}|}$ be the predictor model parameterized by $\xi \in \Xi$, where $\mathscr{I}^{\ast}$ denotes the Kleene star on $\mathscr{I}$. Given $x \in \mathscr{X}$ and $\mathscr{E} \in \mathscr{I}^{\ast}$, the predictor model $h_{\xi}$ assigns a score to each class in $\mathscr{Y}$, defining a distribution over $\mathscr{Y}$ as follows:

$$p_{\xi}(y \mid x, \mathscr{E}) = \frac{\exp\big(h^{y}_{\xi}(x, \mathscr{E})\big)}{\sum_{y' \in \mathscr{Y}} \exp\big(h^{y'}_{\xi}(x, \mathscr{E})\big)}, \quad \forall\, y \in \mathscr{Y}, \tag{7}$$

where $h^{y}_{\xi}(\cdot, \cdot)$ denotes the score assigned to the $y$-th class by the predictor model $h_{\xi}$.

For ease of exposition, we focus on the setting with $k = |\mathscr{E}^{x}_{\theta}| = 1$ for all $x \in \mathscr{X}$ in our analysis throughout this paper. This corresponds to retrieving a single supporting evidence for each input instance. Our analysis can be generalized to $k > 1$ by working with $\tilde{\mathscr{I}} \subseteq \mathscr{I}^{k}$ as the new data-store and $\tilde{p}_{\theta,\mathscr{I}}(\cdot \mid x)$ as a distribution over $\tilde{\mathscr{I}}$ obtained by suitably modifying $p_{\theta,\mathscr{I}}$ in (6). For example, when $k$ supporting evidences are sampled with replacement, the following holds for all $(z_1, \ldots, z_k) \in \mathscr{I}^{k}$:

$$\tilde{p}_{\theta,\mathscr{I}}\big((z_1, \ldots, z_k) \,\big|\, x\big) = \prod_{j \in [k]} p_{\theta,\mathscr{I}}(z_j \mid x).$$
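For a tiny data-store, the lifted distribution $\tilde{p}_{\theta,\mathscr{I}}$ can be enumerated directly. This sketch (the helper `product_distribution` is hypothetical) assumes sampling with replacement, so ordered tuples are kept as distinct elements of $\mathscr{I}^{k}$. Since $|\mathscr{I}^{k}| = |\mathscr{I}|^{k}$, it only makes the reduction concrete and is not meant as a practical implementation.

```python
import itertools
import math


def product_distribution(probs, k):
    # p~((z_1, ..., z_k) | x) = prod_j p(z_j | x) for k draws with replacement.
    # Keys are ordered k-tuples of evidence indices, i.e. elements of I^k.
    return {
        combo: math.prod(probs[j] for j in combo)
        for combo in itertools.product(range(len(probs)), repeat=k)
    }
```

With per-evidence probabilities `[0.2, 0.8]` and $k = 2$, the four ordered pairs receive probabilities 0.04, 0.16, 0.16, and 0.64, which sum to one as required of a distribution over the new data-store.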

Empirical risk minimization and excess risk for RAMs. For a pair of retriever and predictor models parameterized by $\theta$ and $\xi$, respectively, we can define the empirical and population risks associated with a (surrogate) loss function $\ell$ as follows:

$$R_{\ell,\mathscr{I},n}(\xi, \theta) = \frac{1}{n} \sum_{i \in [n]} \sum_{z \in \mathscr{I}} p_{\theta,\mathscr{I}}(z \mid x_i)\, \ell\big(h_{\xi}(x_i, z), y_i\big), \tag{8}$$
$$R_{\ell,\mathscr{I}}(\xi, \theta) = \mathbb{E}\big[\ell\big(h_{\xi}(X, \mathscr{E}^{X}_{\theta}), Y\big)\big]. \tag{9}$$

Note that the expectation in (9) is taken over $(X,Y) \sim \mathsf{D}_{XY}$ as well as the randomness involved in the retrieval stage, e.g., sampling the evidences according to $p_{\theta,\mathscr{I}}(\cdot \mid x)$ in (6). Given a pair of predictor class $\Xi$ and retriever class $\Theta$, let $(\hat{\xi}, \hat{\theta})$ denote the predictor-retriever pair obtained via empirical risk minimization (ERM) as follows:

$$(\hat{\xi}, \hat{\theta}) \in \operatorname*{arg\,min}_{(\xi, \theta) \in \Xi \times \Theta} R_{\ell,\mathscr{I},n}(\xi, \theta). \tag{10}$$
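To make the ERM objective (8) concrete, the sketch below computes it exactly for $k = 1$ by marginalizing over the data-store, and also estimates it by sampling one evidence per example from $p_{\theta,\mathscr{I}}(\cdot \mid x_i)$, mirroring the retrieval randomness in (9). It assumes precomputed retriever probabilities and predictor class scores for every (example, evidence) pair and uses the log-loss as $\ell$; all names are illustrative.

```python
import math
import random


def log_loss(class_scores, y):
    # -log softmax(class_scores)[y], with a max shift for stability.
    m = max(class_scores)
    return m + math.log(sum(math.exp(s - m) for s in class_scores)) - class_scores[y]


def ram_empirical_risk(retr_probs, pred_scores, labels):
    # Eq. (8) with k = 1: (1/n) sum_i sum_z p(z | x_i) * loss(h(x_i, z), y_i).
    # retr_probs[i][z] = p_{theta,I}(z | x_i); pred_scores[i][z] = h_xi(x_i, z).
    total = 0.0
    for p_i, h_i, y_i in zip(retr_probs, pred_scores, labels):
        total += sum(p * log_loss(h, y_i) for p, h in zip(p_i, h_i))
    return total / len(labels)


def ram_sampled_risk(retr_probs, pred_scores, labels, num_samples, rng):
    # Monte Carlo estimate of (8): draw Z ~ p(. | x_i), average the loss at Z.
    total = 0.0
    for p_i, h_i, y_i in zip(retr_probs, pred_scores, labels):
        for _ in range(num_samples):
            z = rng.choices(range(len(p_i)), weights=p_i, k=1)[0]
            total += log_loss(h_i[z], y_i)
    return total / (len(labels) * num_samples)
```

Averaging the loss at sampled evidences gives an unbiased estimate of (8); the exact version requires a pass over the whole data-store for every example, which is what makes direct minimization of (10) expensive.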

Let $\mathscr{F}_{\rm all}$ denote the set of all measurable functions from $\mathscr{X} \times \mathscr{Z}$ to $\mathbb{R}^{|\mathscr{Y}|}$. The optimal risk for classification with access to the data-store is achieved by the best possible predictor $f^{\ell}_{\rm opt,\mathscr{I}} \in \mathscr{F}_{\rm all}$ when it has access to the best retrieved evidence in $\mathscr{I}$. In particular, we have

$$f^{\ell}_{\mathrm{opt},\mathscr{I}} = \operatorname*{arg\,min}_{f \in \mathscr{F}_{\mathrm{all}}} \mathbb{E}\Big[\min_{z \in \mathscr{I}} \ell\big(f(X, z), Y\big)\Big]. \tag{11}$$

Given $f^{\ell}_{\mathrm{opt},\mathscr{I}}$, we define the excess risk of a predictor-retriever pair $(\xi, \theta)$ as follows:

$$\Delta_{\ell,\mathscr{I}}(\xi, \theta) = R_{\ell,\mathscr{I}}(\xi, \theta) - R_{\ell,\mathscr{I}}(f^{\ell}_{\mathrm{opt},\mathscr{I}}) \triangleq R_{\ell,\mathscr{I}}(\xi, \theta) - \mathbb{E}\Big[\min_{z \in \mathscr{I}} \ell\big(f^{\ell}_{\mathrm{opt},\mathscr{I}}(X, z), Y\big)\Big]. \tag{12}$$
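For a fixed predictor on a finite sample, the oracle term in (12), which picks the best evidence per example, can be compared with the retrieval-weighted risk. The sketch below (hypothetical helper `oracle_gap`, with per-evidence losses assumed precomputed) illustrates that any retriever distribution yields a weighted loss at least as large as the per-example minimum, because a convex combination cannot fall below its smallest term. This gap isolates only the retrieval-related part for that predictor; (12) additionally compares against the best measurable predictor $f^{\ell}_{\mathrm{opt},\mathscr{I}}$.

```python
def oracle_gap(losses, retr_probs):
    # losses[i][z] = l(f(x_i, z), y_i) for a fixed predictor f;
    # retr_probs[i][z] = p(z | x_i) for a retriever.
    # Returns (retrieval-weighted risk, oracle risk using min over z).
    n = len(losses)
    weighted = sum(
        sum(p * l for p, l in zip(p_i, l_i))
        for p_i, l_i in zip(retr_probs, losses)
    ) / n
    oracle = sum(min(l_i) for l_i in losses) / n
    return weighted, oracle
```

For losses `[[1.0, 0.2], [0.5, 0.5]]` and retriever probabilities `[[0.5, 0.5], [1.0, 0.0]]`, the weighted risk is 0.55 and the oracle risk is 0.35; a retriever concentrating all mass on the second evidence for the first example would close this gap.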

With the formal definition of the classification setting with access to a data-store and the necessary background in place, we proceed to address the two key objectives of this work: 1) Proposing a natural and efficient joint end-to-end training procedure for the predictor-retriever pair in a RAM; and 2) Developing a rigorous statistical understanding of RAMs focusing on the interaction between predictor and retriever components towards reducing overall excess risk.

3 Joint training and excess risk

Recall that training a RAM involves training both the retriever $r_{\theta}: \mathscr{X} \times \mathscr{Z} \to \mathbb{R}$ and the predictor $h_{\xi}: \mathscr{X} \times \mathscr{I} \to \mathbb{R}^{|\mathscr{Y}|}$ components of the model without access to intermediate supervision on retrieval, which is infeasible to obtain in most practical settings. Thus, it becomes critical to devise methods to jointly train $r_{\theta}$ and $h_{\xi}$ with access to only labeled instances $\mathscr{S}_n = \{(x_i, y_i)\}_{i \in [n]} \subseteq \mathscr{X} \times \mathscr{Y}$, with the predictor guiding the retriever training based on how valuable the retriever-provided evidences are towards the correct final prediction.

Towards this, we leverage the empirical risk from (8) along with the log-loss $\ell(h_{\xi}(x, z), y) = -\log p_{\xi}(y \mid x, z)$, where $p_{\xi}(y \mid x, z)$ is defined in (7). In particular, this leads to the following joint end-to-end training objective:

$$\mathscr{L}_n(\xi, \theta; \mathscr{I}) \triangleq R_{\log,\mathscr{I},n}(\xi, \theta) = -\frac{1}{n} \sum_{i \in [n]} \sum_{z \in \mathscr{I}} p_{\theta,\mathscr{I}}(z \mid x_i) \cdot \log p_{\xi}(y_i \mid x_i, z). \tag{13}$$

Note that the objective in (13) directly targets the end-to-end performance of a RAM in deployment: it minimizes the expected loss of the predictor when evidences are selected according to the retriever-induced distribution. One can use gradient-based methods to jointly minimize the objective in (13) with respect to $(\xi, \theta)$; however, its efficient implementation is non-trivial due to the sum over the entire data-store $\mathscr{I}$. In App. C.1, we discuss some approximate design choices. Lastly, please refer to Sec. 3.6 for connections between our proposed objective in (13) and some of the existing end-to-end training approaches for RAMs.
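For a single example, write $\ell_z = -\log p_{\xi}(y \mid x, z)$ and $\bar{\ell} = \sum_{z} p_{\theta,\mathscr{I}}(z \mid x)\, \ell_z$. Differentiating through the softmax in (6) gives $\partial \bar{\ell} / \partial r_{\theta}(x, z) = p_{\theta,\mathscr{I}}(z \mid x)\, (\ell_z - \bar{\ell})$, so gradient descent on (13) raises the scores of evidences on which the predictor does better than average; this is one way to see how the predictor supplies a learning signal to the retriever. The sketch below assumes the full data-store is enumerated and per-evidence losses are precomputed; it makes no attempt at the approximations discussed in App. C.1, and the function names are illustrative.

```python
import math


def softmax(v):
    # Numerically stable softmax, as in (6).
    m = max(v)
    e = [math.exp(a - m) for a in v]
    s = sum(e)
    return [a / s for a in e]


def joint_objective_and_retriever_grad(retr_scores, evidence_losses):
    # Per-example term of (13): L = sum_z p(z | x) * loss_z, where
    # p = softmax(retriever scores) and loss_z = -log p_xi(y | x, z).
    # Returns L and the gradient dL/dr_z = p_z * (loss_z - L).
    p = softmax(retr_scores)
    L = sum(pz * lz for pz, lz in zip(p, evidence_losses))
    grad = [pz * (lz - L) for pz, lz in zip(p, evidence_losses)]
    return L, grad
```

A central finite-difference check agrees with the analytic gradient, and the gradient entries sum to zero because adding a constant to all retriever scores leaves (6) unchanged.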

Next, to study the generalization and expressive power of RAMs, we want to bound the excess risk $\Delta_{\ell,\mathscr{I}}(\hat{\xi}, \hat{\theta})$ as defined in (12). We consider $\mathscr{X}$ to be a compact subset of $\mathbb{R}^{d_x}$ and, for simplicity, take $\mathscr{X} \subseteq [-1, 1]^{d_x}$. Similarly, we consider that each retrieval example $z \in \mathscr{I}$ is embedded in the space $[-1, 1]^{d_z}$. We consider a data-store that scales polynomially with the training data size, i.e., $|\mathscr{I}| = \mathrm{poly}(n)$. For the purpose of analysis, we specialize our log-loss to be bounded by $\ell_{\max} > 0$, which is given as

$$\ell(h_{\xi}(x,z),y)=\min\big(\ell_{\max},-\log p_{\xi}(y|x,z)\big)=\min\Big(\ell_{\max},\log\Big(\sum_{y'\in\mathscr{Y}}\exp\big(h^{y'}_{\xi}(x,z)\big)\Big)-h^{y}_{\xi}(x,z)\Big),$$

where $p_{\xi}(y|x,z)$ and $h^{y}_{\xi}(x,z)$ are defined in (7).
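As a concrete illustration of the truncated loss above, the following minimal sketch evaluates it from a vector of per-label scores, treating `scores` as a stand-in for $h_\xi(x,z)$ (the function name and the example numbers are hypothetical). It uses a max-shifted log-sum-exp so that large scores do not overflow.

```python
import math

def truncated_log_loss(scores, y, ell_max):
    """Bounded log-loss min(ell_max, log sum_y' exp(scores[y']) - scores[y]).

    `scores` stands in for the per-label scores h_xi(x, z), `y` is the index
    of the observed label, and `ell_max` is the truncation level.
    """
    m = max(scores)  # shift by the max score for a numerically stable log-sum-exp
    log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
    return min(ell_max, log_sum_exp - scores[y])

# Uniform scores over three labels give -log(1/3) = log 3.
print(truncated_log_loss([0.0, 0.0, 0.0], y=1, ell_max=5.0))  # ≈ 1.0986
# A confidently wrong prediction is clipped at ell_max.
print(truncated_log_loss([20.0, 0.0, 0.0], y=1, ell_max=5.0))  # 5.0
```

The clipping is what makes the loss bounded by $\ell_{\max}$, which the concentration arguments below rely on.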

3.1 Excess risk decomposition

Our excess risk bound relies on separating the contributions of the retriever and the predictor during joint training. Specifically, we split the excess risk into a generalization error and two approximation errors, one for the retriever and one for the predictor.

The population risk minimizer of our joint training over the space $\Xi\times\Theta$ is defined as

$$\xi^{\ast}_{\rm joint},\theta^{\ast}_{\rm joint}=\operatorname*{arg\,min}_{(\xi,\theta)\in\Xi\times\Theta}\mathbb{E}_{X}\Big[\mathbb{E}_{Z\sim p_{\theta}(\cdot|X)}\mathbb{E}_{Y|X}\,\ell\big(h_{\xi}(X,Z),Y\big)\Big].$$

For a predictor $\xi$, a sample $x\in\mathscr{X}$, and a retrieved example $z\in\mathscr{I}$, we denote the loss averaged over the labels in $\mathscr{Y}$ by

$$g_{\xi}(x,z)=\mathbb{E}_{Y|X=x}\big[\ell\big(h_{\xi}(x,z),Y\big)\big]. \tag{14}$$
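On a finite toy problem, $g_\xi$ in (14) and the joint objective above can be computed directly. The sketch below is illustrative only: it assumes a softmax retriever $p_\theta(z|x)\propto\exp(r_\theta(x,z))$ given by a table of hypothetical scores, and replaces the predictor by a hypothetical loss table `loss[x][z][y]` $=\ell(h_\xi(x,z),y)$.

```python
import math

def softmax(v):
    m = max(v)
    e = [math.exp(a - m) for a in v]
    s = sum(e)
    return [a / s for a in e]

def label_averaged_loss(loss_xz, p_y_given_x):
    """g_xi(x, z) in (14): expected loss over Y ~ P(Y | X = x) for fixed (x, z)."""
    return sum(p * l for p, l in zip(p_y_given_x, loss_xz))

def joint_risk(p_x, p_y_given_x, loss, retr_scores):
    """E_X[ E_{Z ~ p_theta(.|X)} g_xi(X, Z) ] on a finite toy problem.

    loss[x][z][y]     -- hypothetical table of ell(h_xi(x, z), y)
    retr_scores[x][z] -- hypothetical retriever scores; p_theta(.|x) is their softmax
    """
    total = 0.0
    for x, px in enumerate(p_x):
        p_z = softmax(retr_scores[x])
        g = [label_averaged_loss(loss[x][z], p_y_given_x[x]) for z in range(len(p_z))]
        total += px * sum(pz * gz for pz, gz in zip(p_z, g))
    return total

# Hypothetical toy problem: one input, two evidences, two labels.
loss = [[[0.1, 2.0],   # evidence z = 0: low loss on the true label y = 0
         [1.0, 0.5]]]  # evidence z = 1: less useful
p_y = [[1.0, 0.0]]     # P(Y = 0 | X = x) = 1
print(joint_risk([1.0], p_y, loss, [[0.0, 0.0]]))   # uniform retriever
print(joint_risk([1.0], p_y, loss, [[10.0, 0.0]]))  # retriever favoring z = 0
```

Shifting retrieval mass toward the evidence with lower $g_\xi$ lowers the joint risk, which is the effect the retriever error term below quantifies.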

For any fixed predictor $\xi$ (not necessarily in $\Xi$) and a fixed data-store $\mathscr{I}$, the retriever that minimizes the joint population risk is $p^{\ast,\xi}(z|x)=\mathbb{1}_{\operatorname*{arg\,min}_{z'\in\mathscr{I}}g_{\xi}(x,z')}(z)$, with ties broken arbitrarily. Note that the best retrieved evidence $z$ may differ from one sample $x$ to another. We define the optimal predictor within the class $\Xi$ under the best possible retriever as

$$\xi^{\ast}=\operatorname*{arg\,min}_{\xi\in\Xi}\mathbb{E}_{X}\Big[\min_{z\in\mathscr{I}}g_{\xi}(X,z)\Big].$$
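The oracle retriever $p^{\ast,\xi}$ and the inner objective $\mathbb{E}_X[\min_{z}g_\xi(X,z)]$ can be sketched on a finite problem as follows. The table `g` of values $g_\xi(x,z)$ and the input distribution `p_x` are hypothetical, and ties are broken toward the lowest index as one admissible arbitrary rule. Any stochastic retriever averages $g_\xi(x,\cdot)$ and so cannot beat the minimum, which is why the point mass is optimal.

```python
def oracle_retriever(g_row):
    """p^{*,xi}(. | x): point mass on argmin_z g_xi(x, z).

    Ties go to the lowest index (one arbitrary tie-breaking rule)."""
    z_best = min(range(len(g_row)), key=lambda z: g_row[z])
    return [1.0 if z == z_best else 0.0 for z in range(len(g_row))]

def oracle_risk(p_x, g):
    """E_X[min_z g_xi(X, z)]: the quantity minimized over xi in Xi to define xi^*."""
    return sum(px * min(g_row) for px, g_row in zip(p_x, g))

g = [[0.1, 1.0],   # hypothetical g_xi(x, z) for input x = 0
     [0.7, 0.3]]   # and for input x = 1
p_x = [0.5, 0.5]
print([oracle_retriever(row) for row in g])  # [[1.0, 0.0], [0.0, 1.0]]
print(oracle_risk(p_x, g))
```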

The optimal retriever within the class $\Theta$ for a given predictor $\xi$ is defined as

$$\theta(\xi)=\operatorname*{arg\,min}_{\theta\in\Theta}\mathbb{E}_{X}\Big[\mathbb{E}_{Z\sim p_{\theta}(\cdot|X)}\,g_{\xi}(X,Z)\Big].$$

The excess risk for the classes $\Theta$ and $\Xi$ can be bounded as

$$\begin{aligned}
\Delta_{\ell,\mathscr{I}}(\hat{\xi},\hat{\theta})\leq{}&\underbrace{\sum_{(\theta,\xi)\in\{(\hat{\theta},\hat{\xi}),(\theta^{\ast}_{\rm joint},\xi^{\ast}_{\rm joint})\}}\big|R_{\ell,\mathscr{I}}(\xi,\theta)-R_{\ell,\mathscr{I},n}(\xi,\theta)\big|}_{\text{Generalization Error}}\\
&+\underbrace{R_{\ell,\mathscr{I}}(\xi^{\ast},\theta(\xi^{\ast}))-\mathbb{E}_{X}\Big[\min_{z\in\mathscr{I}}g_{\xi^{\ast}}(X,z)\Big]}_{\text{retriever error}}+\underbrace{\mathbb{E}_{X}\Big[\min_{z\in\mathscr{I}}g_{\xi^{\ast}}(X,z)\Big]-R_{\ell,\mathscr{I}}(f_{{\rm opt},\mathscr{I}}^{\ell})}_{\text{predictor error}}
\end{aligned}\tag{15}$$
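The decomposition (15) follows from a telescoping argument that the text leaves implicit. A sketch, assuming that (12) defines $\Delta_{\ell,\mathscr{I}}(\hat\xi,\hat\theta)=R_{\ell,\mathscr{I}}(\hat\xi,\hat\theta)-R_{\ell,\mathscr{I}}(f^{\ell}_{{\rm opt},\mathscr{I}})$ and that $(\hat\xi,\hat\theta)$ minimizes the empirical risk $R_{\ell,\mathscr{I},n}$ over $\Xi\times\Theta$ (both are set up outside this excerpt):

$$\begin{aligned}
\Delta_{\ell,\mathscr{I}}(\hat\xi,\hat\theta)
={}&\big[R_{\ell,\mathscr{I}}(\hat\xi,\hat\theta)-R_{\ell,\mathscr{I},n}(\hat\xi,\hat\theta)\big]
+\underbrace{\big[R_{\ell,\mathscr{I},n}(\hat\xi,\hat\theta)-R_{\ell,\mathscr{I},n}(\xi^{\ast}_{\rm joint},\theta^{\ast}_{\rm joint})\big]}_{\le 0}\\
&+\big[R_{\ell,\mathscr{I},n}(\xi^{\ast}_{\rm joint},\theta^{\ast}_{\rm joint})-R_{\ell,\mathscr{I}}(\xi^{\ast}_{\rm joint},\theta^{\ast}_{\rm joint})\big]
+\underbrace{\big[R_{\ell,\mathscr{I}}(\xi^{\ast}_{\rm joint},\theta^{\ast}_{\rm joint})-R_{\ell,\mathscr{I}}(\xi^{\ast},\theta(\xi^{\ast}))\big]}_{\le 0}\\
&+\Big[R_{\ell,\mathscr{I}}(\xi^{\ast},\theta(\xi^{\ast}))-\mathbb{E}_{X}\Big[\min_{z\in\mathscr{I}}g_{\xi^{\ast}}(X,z)\Big]\Big]
+\Big[\mathbb{E}_{X}\Big[\min_{z\in\mathscr{I}}g_{\xi^{\ast}}(X,z)\Big]-R_{\ell,\mathscr{I}}(f^{\ell}_{{\rm opt},\mathscr{I}})\Big].
\end{aligned}$$

The first braced term is non-positive because $(\hat\xi,\hat\theta)$ minimizes the empirical risk; the second because $(\xi^{\ast}_{\rm joint},\theta^{\ast}_{\rm joint})$ minimizes the population risk over $\Xi\times\Theta$ and $(\xi^{\ast},\theta(\xi^{\ast}))\in\Xi\times\Theta$. Dropping both and bounding the two remaining population-versus-empirical differences by their absolute values gives (15).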

3.2 Generalization error

We first bound the generalization error and relate it to the covering numbers of the retriever and predictor classes.

As our loss is bounded by $\ell_{\max}$, standard concentration bounds (Shalev-Shwartz and Ben-David, 2014) give that, for any $\delta>0$, with probability at least $1-\delta$,

$$\big|R_{\ell,\mathscr{I}}(\xi^{\ast}_{\rm joint},\theta^{\ast}_{\rm joint})-R_{\ell,\mathscr{I},n}(\xi^{\ast}_{\rm joint},\theta^{\ast}_{\rm joint})\big|\leq 3\ell_{\max}\sqrt{\tfrac{\log(1/\delta)}{n}}.$$
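To get a feel for the size of this fixed-pair term, the sketch below simply evaluates its right-hand side; the values of $\ell_{\max}$, $n$, and $\delta$ are hypothetical. The $1/\sqrt{n}$ rate means that quadrupling the sample size halves the bound.

```python
import math

def fixed_pair_deviation(ell_max, n, delta):
    """Right-hand side 3 * ell_max * sqrt(log(1/delta) / n) of the concentration
    bound for the fixed pair (xi*_joint, theta*_joint)."""
    return 3.0 * ell_max * math.sqrt(math.log(1.0 / delta) / n)

print(round(fixed_pair_deviation(ell_max=5.0, n=10_000, delta=0.05), 4))  # 0.2596
print(round(fixed_pair_deviation(ell_max=5.0, n=40_000, delta=0.05), 4))  # 0.1298
```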

However, $(\hat{\xi},\hat{\theta})$ is learned from the data, so a high-probability bound on its generalization error requires a uniform (union-type) argument over $\Xi\times\Theta$. We employ Rademacher-complexity-based generalization bounds and then bound the associated Rademacher complexity via covering numbers of the classes $\Xi$ and $\Theta$. See Shalev-Shwartz and Ben-David (2014) for details.

We define two norms that are used to define the covering numbers of $\Theta$ and $\Xi$. In particular, for all $\mathbf{u}\in\mathbb{R}^{n\times|\mathscr{I}|}$ and fixed $\xi\in\Xi$, $\theta\in\Theta$,

$$\begin{aligned}
\|\mathbf{u}\|_{2,[n],\xi}&=\Big(\tfrac{1}{n}\sum_{i\in[n]}\Big(\sum_{z\in\mathscr{I}}u_{i,z}\,\ell\big(h_{\xi}(x_{i},z),y_{i}\big)\Big)^{2}\Big)^{1/2},\\
\|\mathbf{u}\|_{2,[n],\theta}&=\Big(\tfrac{1}{n}\sum_{i\in[n]}\Big(\sum_{z\in\mathscr{I}}p_{\theta}(z|x_{i})\,u_{i,z}\Big)^{2}\Big)^{1/2}.
\end{aligned}\tag{16}$$
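The two norms in (16) are straightforward to evaluate on an $n\times|\mathscr{I}|$ array. In the sketch below, `losses[i][z]` stands for $\ell(h_\xi(x_i,z),y_i)$ and `probs[i][z]` for $p_\theta(z|x_i)$; all numbers are hypothetical. Our reading (an assumption, not stated above) is that when covering $\Xi$ at a fixed $\theta$, the distance between predictors $\xi_1,\xi_2$ is $\|\mathbf{u}\|_{2,[n],\theta}$ with $u_{i,z}=\ell(h_{\xi_1}(x_i,z),y_i)-\ell(h_{\xi_2}(x_i,z),y_i)$, and symmetrically for $\Theta$ with $u_{i,z}=p_{\theta_1}(z|x_i)-p_{\theta_2}(z|x_i)$.

```python
import math

def norm_xi(u, losses):
    """||u||_{2,[n],xi}, where losses[i][z] = ell(h_xi(x_i, z), y_i) for a fixed xi."""
    n = len(u)
    return math.sqrt(sum(sum(u_iz * l_iz for u_iz, l_iz in zip(u_i, l_i)) ** 2
                         for u_i, l_i in zip(u, losses)) / n)

def norm_theta(u, probs):
    """||u||_{2,[n],theta}, where probs[i][z] = p_theta(z | x_i) for a fixed theta."""
    n = len(u)
    return math.sqrt(sum(sum(p_iz * u_iz for p_iz, u_iz in zip(p_i, u_i)) ** 2
                         for u_i, p_i in zip(u, probs)) / n)

u = [[1.0, 0.0], [0.0, 1.0]]                     # n = 2 samples, |I| = 2 evidences
print(norm_theta(u, [[0.5, 0.5], [0.5, 0.5]]))   # 0.5
print(norm_xi(u, [[2.0, 0.0], [0.0, 4.0]]))      # sqrt(10) ≈ 3.1623
```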

We also define $\mathcal{N}(\Xi,\nu,\|\cdot\|_{2,[n],\theta})$ to be the $\nu$-covering number of the class $\Xi$ with respect to the norm $\|\cdot\|_{2,[n],\theta}$, and $\mathcal{N}(\Theta,\nu,\|\cdot\|_{2,[n],\xi})$ to be the $\nu$-covering number of the class $\Theta$ with respect to the norm $\|\cdot\|_{2,[n],\xi}$. We then obtain the generalization bound

$$\big|R_{\ell,\mathscr{I}}(\hat{\xi},\hat{\theta})-R_{\ell,\mathscr{I},n}(\hat{\xi},\hat{\theta})\big|\leq\inf_{\varepsilon\in[0,\ell_{\max}/2]}\Big(8\varepsilon+\tfrac{24}{\sqrt{n}}\int_{\varepsilon}^{\ell_{\max}/2}\big(f_{\mathcal{N}}(\nu/2;\Theta,\Xi)+f_{\mathcal{N}}(\nu/2;\Xi,\Theta)\big)\,d\nu\Big), \tag{17}$$

where $f_{\mathcal{N}}(\nu;\mathcal{A},\mathcal{B})=\sup_{b\in\mathcal{B}}\sqrt{\log\mathcal{N}(\mathcal{A},\nu,\|\cdot\|_{2,[n],b})}$.
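The right-hand side of (17) is a Dudley-type entropy integral, and it can be evaluated numerically once $f_{\mathcal{N}}$ is known. The sketch below approximates the infimum over a finite grid of $\varepsilon\in(0,\ell_{\max}/2]$ and the integral by the trapezoidal rule. The entropy profile `parametric_entropy` is a hypothetical stand-in, of the form $\sqrt{d\log(1+\ell_{\max}/\nu)}$ typical of a class with $d$ effective parameters; it is not the pseudo-dimension bound used in the paper.

```python
import math

def entropy_integral_bound(n, ell_max, f_theta, f_xi, num_eps=100, num_nu=200):
    """Approximate inf_eps 8*eps + 24/sqrt(n) * int_eps^{ell_max/2} (f_theta(nu/2) + f_xi(nu/2)) dnu."""
    top = ell_max / 2.0
    best = float("inf")
    for k in range(1, num_eps + 1):
        eps = top * k / num_eps              # grid over (0, ell_max / 2]
        step = (top - eps) / num_nu
        integral = 0.0
        for j in range(num_nu):              # trapezoidal rule on [eps, ell_max / 2]
            a = eps + j * step
            b = a + step
            fa = f_theta(a / 2.0) + f_xi(a / 2.0)
            fb = f_theta(b / 2.0) + f_xi(b / 2.0)
            integral += 0.5 * step * (fa + fb)
        best = min(best, 8.0 * eps + 24.0 / math.sqrt(n) * integral)
    return best

def parametric_entropy(dim, ell_max):
    """Hypothetical f_N(nu) = sqrt(dim * log(1 + ell_max / nu)); NOT the paper's bound."""
    return lambda nu: math.sqrt(dim * math.log(1.0 + ell_max / nu))

ell_max = 5.0
f_theta = parametric_entropy(10, ell_max)    # hypothetical retriever class
f_xi = parametric_entropy(100, ell_max)      # hypothetical predictor class
for n in (10**3, 10**5, 10**7):
    print(n, round(entropy_integral_bound(n, ell_max, f_theta, f_xi), 3))
```

Choosing $\varepsilon=\ell_{\max}/2$ makes the integral vanish, so the bound never exceeds $4\ell_{\max}$; for small $n$ that trivial choice can be the minimizer.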

We use ideas from Zhang (2023) to upper bound the covering numbers in terms of the pseudo-dimension (defined in Appendix A) of the function classes. This yields a $\log|\mathscr{I}|$ dependence in the generalization error, while allowing function classes with unbounded norms.

3.3 Approximation error

We next bound the retriever and predictor approximation errors. To this end, we make extensive use of Sobolev function spaces. A Sobolev space on a domain $\Omega$ is characterized by two quantities: $\kappa$, the number of weak derivatives that a (real-valued) function in the space possesses, and $L_p(\Omega)$, the norm with respect to which these derivatives are integrable. Please see Appendix A for a complete definition.
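For orientation, one common way to write the Sobolev norm is given below, with $\alpha$ a multi-index and $D^{\alpha}$ the corresponding weak partial derivative. This is the standard textbook form; Appendix A may use a different but equivalent normalization.

$$\|f\|_{W^{\kappa,p}(\Omega)}=\Big(\sum_{|\alpha|\le\kappa}\|D^{\alpha}f\|_{L_p(\Omega)}^{p}\Big)^{1/p}\ \ (1\le p<\infty),\qquad \|f\|_{W^{\kappa,\infty}(\Omega)}=\max_{|\alpha|\le\kappa}\operatorname*{ess\,sup}_{x\in\Omega}|D^{\alpha}f(x)|.$$

Assumptions 3.1 and 3.2 below use the case $p=\infty$ with $\Omega=[-1,1]^{d_x+d_z}$.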

3.3.1 Retriever error

The retriever error is determined by how well the score function $r_{\theta}(x,z)$ approximates the optimal retriever for the predictor $\xi^{*}$. To bound it, we first impose some smoothness constraints on the function $g_{\xi^{*}}:\mathscr{X}\times\mathscr{Z}\to\mathbb{R}$. In particular, we assume the following.

Assumption 3.1 (Complexity of $g_{\xi^{*}}$).

There exists a baseline function $b_{\xi^{*}}:[-1,1]^{d_x}\to\mathbb{R}$ such that the function $\mathrm{gap}_{\xi^{*}}:[-1,1]^{d_x+d_z}\to\mathbb{R}$ defined by $\mathrm{gap}_{\xi^{*}}(x,z)\triangleq g_{\xi^{*}}(x,z)-b_{\xi^{*}}(x)$ lies in the Sobolev space with $\kappa$ derivatives and $L_{\infty}([-1,1]^{d_x+d_z})$ norm.

The above assumption says that, for the predictor $\xi^{*}$, the loss profile $g_{\xi^{*}}(x,z)$ (averaged over labels in $\mathscr{Y}$) has two components: a (possibly) complex component $b_{\xi^{*}}(x)$ that does not depend on $z$, and a 'smooth' component $\mathrm{gap}_{\xi^{*}}(x,z)$. In other words, when two similar evidences are each paired with the same input instance, the predictor incurs similar losses.
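Why only $\mathrm{gap}_{\xi^*}$ needs to be smooth: since $b_{\xi^*}(x)$ does not depend on $z$, subtracting it changes neither $\operatorname*{arg\,min}_z g_{\xi^*}(x,z)$ nor any softmax over $-\tau\,g_{\xi^*}(x,\cdot)$. The toy check below illustrates this invariance; the values of $g_{\xi^*}$, $b_{\xi^*}$, and $\tau$ are hypothetical.

```python
import math

def softmax(scores):
    m = max(scores)
    e = [math.exp(s - m) for s in scores]
    total = sum(e)
    return [v / total for v in e]

def loss_based_retrieval(loss_row, tau):
    """Softmax over -tau * loss_row: lower loss means higher retrieval probability."""
    return softmax([-tau * v for v in loss_row])

g_row = [0.9, 0.4, 0.7]    # hypothetical g_{xi*}(x, z) for three evidences z
baseline = 3.2             # hypothetical b_{xi*}(x); depends on x only
gap_row = [g - baseline for g in g_row]

p_from_g = loss_based_retrieval(g_row, tau=5.0)
p_from_gap = loss_based_retrieval(gap_row, tau=5.0)
print(p_from_g)
print(p_from_gap)   # same distribution up to floating-point rounding
```

So a retriever only has to track the $z$-dependent part $\mathrm{gap}_{\xi^*}$, however complex the baseline is.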

Under Assumption 3.1, for any $\tau>0$, we can bound the retriever error as follows:

$$R_{\ell,\mathscr{I}}(\xi^{*},\theta(\xi^{\ast}))-\mathbb{E}_{X}\Big[\min_{z\in\mathscr{I}}g_{\xi^{*}}(X,z)\Big]\leq\inf_{\theta\in\Theta}\Big(\ell_{\max}\|r_{\theta}+\tau\cdot\mathrm{gap}_{\xi^{*}}\|_{\infty}+\frac{\log|\mathscr{I}|}{\tau^{2}}\Big) \tag{18}$$
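The role of $\tau$ in (18) can be illustrated in the idealized case $r_\theta=-\tau\cdot\mathrm{gap}_{\xi^*}$, assuming a softmax retriever $p_\theta(z|x)\propto\exp(r_\theta(x,z))$: the first term of (18) vanishes and, at each $x$, the retriever's suboptimality $\mathbb{E}_{z\sim p_\theta}[\mathrm{gap}]-\min_z\mathrm{gap}$ (equal to the same quantity for $g_{\xi^*}$, since $b_{\xi^*}$ cancels) shrinks as $\tau$ grows. The check below uses the elementary Gibbs-distribution inequality $\mathbb{E}_p[\mathrm{gap}]-\min\mathrm{gap}\le\log|\mathscr{I}|/\tau$; it illustrates the trade-off and is not a reproduction of the constants or exponent in (18). The gap values are hypothetical.

```python
import math

def gibbs_suboptimality(gap_row, tau):
    """E_{z ~ p}[gap(z)] - min_z gap(z) for the retriever p(z) ∝ exp(-tau * gap(z))."""
    m = min(gap_row)
    weights = [math.exp(-tau * (g - m)) for g in gap_row]   # shift by the min for stability
    total = sum(weights)
    return sum(w * g for w, g in zip(weights, gap_row)) / total - m

gap_row = [0.0, 0.05, 0.1, 0.3, 0.6, 1.0, 1.0, 2.0]   # hypothetical gap values, |I| = 8
for tau in (1.0, 10.0, 100.0):
    print(tau, gibbs_suboptimality(gap_row, tau), math.log(len(gap_row)) / tau)
```

In (18), a larger $\tau$ makes the softmax sharper but also makes $\tau\cdot\mathrm{gap}_{\xi^*}$ harder for $r_\theta$ to match in sup-norm, which is the tension the infimum resolves.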

3.3.2 Predictor error

The predictor error is measured under optimal retrieval, since the retriever error is accounted for separately above. To bound it, we first need to quantify how much retrieval augmentation with the data-store $\mathscr{I}$ helps.

Usefulness of retrieval set:

We start by characterizing the prediction task in the presence of the data-store $\mathscr{I}\subset\mathscr{Z}$. We assume that there exists a score function $h_{*}:\mathscr{X}\times\mathscr{Z}\to\mathbb{R}^{|\mathscr{Y}|}$, and the corresponding probability distribution

$$p_{*}^{y}(x,z)=\frac{\exp\big(h_{*}^{y}(x,z)\big)}{\sum_{y'}\exp\big(h_{*}^{y'}(x,z)\big)}, \tag{19}$$

that approximates $p_{\mathsf{D}_{XY}}^{y}(x):=\mathbb{P}_{Y\sim\mathsf{D}_{Y|X}}(y|X=x)$ well for all $x\in\mathscr{X}$ and $y\in\mathscr{Y}$. Furthermore, we want each coordinate of this score function $h_{*}$ to lie in a Sobolev space. The following assumption formalizes this.

Assumption 3.2 (Retrieval quality).

There exists a score function $h_{*}:\mathscr{X}\times\mathscr{Z}\to\mathbb{R}^{|\mathscr{Y}|}$ such that

  1. for each $y\in\mathscr{Y}$, the function $h_{*}^{y}$ lies in the Sobolev space with $\kappa_{\mathscr{I}}$ derivatives and finite $L_{\infty}([-1,1]^{d_x+d_z})$ norm, and

  2. for any $x\in\mathscr{X}$, there exists a retrieved evidence $z^{*}(x)\in\mathscr{I}$ such that $p_{*}^{y}(x,z)$, as defined in (19), satisfies

     $$\max_{y\in\mathscr{Y}}\sup_{x\in\mathscr{X}}\big|p_{*}^{y}(x,z^{*}(x))-p_{\mathsf{D}_{XY}}^{y}(x)\big|\leq c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}.$$

Note that this assumption is independent of the retriever class $\Theta$ and the predictor class $\Xi$, and captures an intrinsic property of the data-store $\mathscr{I}$. The tuple $(\gamma_{\mathscr{I}},d_z,\kappa_{\mathscr{I}})$ defines the usefulness of $\mathscr{I}$. In particular, the higher $\gamma_{\mathscr{I}}$, the closer the approximation; and the higher $\kappa_{\mathscr{I}}$ and the smaller the embedding dimension $d_z$, the 'simpler' the score function used for this approximation.
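Item 2 of Assumption 3.2 can be probed numerically on a toy problem: for each input $x$ on a finite grid (standing in for the supremum over $\mathscr{X}$), pick the evidence in the data-store whose induced distribution $p_*(x,z)$ best matches $p_{\mathsf{D}_{XY}}(x)$, and record the worst case. The score function, data distribution, and grids below are all hypothetical. In this toy the achieved error shrinks roughly like $1/|\mathscr{I}|$, i.e., it behaves as if $\gamma_{\mathscr{I}}\approx 1$.

```python
import math

def softmax(scores):
    m = max(scores)
    e = [math.exp(s - m) for s in scores]
    total = sum(e)
    return [v / total for v in e]

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def retrieval_quality(h_star, p_data, xs, store):
    """Grid approximation of max_y sup_x |p_*^y(x, z*(x)) - p_D^y(x)|, where
    z*(x) is the evidence in `store` with the smallest error at x."""
    worst = 0.0
    for x in xs:
        target = p_data(x)
        best = min(
            max(abs(a - b) for a, b in zip(softmax(h_star(x, z)), target))
            for z in store
        )
        worst = max(worst, best)
    return worst

def h_star(x, z):
    """Hypothetical two-label score function: p_*^1(x, z) = sigmoid(4 z)."""
    return [0.0, 4.0 * z]

def p_data(x):
    """Hypothetical label distribution: P(Y = 1 | X = x) = sigmoid(4 x)."""
    return [1.0 - sigmoid(4.0 * x), sigmoid(4.0 * x)]

xs = [-1.0 + 2.0 * i / 200 for i in range(201)]
for size in (11, 101):
    store = [-1.0 + 2.0 * k / (size - 1) for k in range(size)]
    print(size, retrieval_quality(h_star, p_data, xs, store))
```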

Under Assumption 3.2, we bound the predictor error as

$$\begin{aligned}
\mathbb{E}_{X}\Big[\min_{z\in\mathscr{I}}g_{\xi^{*}}(X,z)\Big]-R_{\ell,\mathscr{I}}(f_{{\rm opt},\mathscr{I}}^{\ell})
\leq{}&\inf_{\xi\in\Xi}2\,\mathbb{E}_{X}\Big[\max_{y\in\mathscr{Y}}\big|h_{\xi}^{y}(X,z^{*}(X))-h_{*}^{y}(X,z^{*}(X))\big|\Big]\\
&+(|\mathscr{Y}|-1)\exp(-\ell_{\max})+c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}\exp(\ell_{\max}).
\end{aligned}\tag{20}$$

One key step in arriving at the above inequality is expressing the loss of $f_{{\rm opt},\mathscr{I}}^{\ell}$ in terms of the score function $h_{*}$ defined in Assumption 3.2. In particular, under Assumption 3.2, we show that

$$\mathbb{E}_{X}\Big[\min_{z\in\mathscr{I}}g_{f_{{\rm opt},\mathscr{I}}^{\ell}}(X,z)\Big]\geq\mathbb{E}_{X}\big[g_{h_{*}}(X,z^{*}(X))\big]-(|\mathscr{Y}|-1)\exp(-\ell_{\max})-c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}\exp(\ell_{\max}).$$
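The last two terms of (20) pull $\ell_{\max}$ in opposite directions: a larger truncation level shrinks $(|\mathscr{Y}|-1)\exp(-\ell_{\max})$ but inflates $c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}\exp(\ell_{\max})$. Setting the derivative of their sum to zero gives $\ell^{\star}=\tfrac12\log\big((|\mathscr{Y}|-1)|\mathscr{I}|^{\gamma_{\mathscr{I}}}/c_{\mathscr{I}}\big)$ with minimum value $2\sqrt{(|\mathscr{Y}|-1)\,c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}}$. The sketch below checks this arithmetic with hypothetical constants. It ignores that $\ell_{\max}$ also scales the generalization and retriever terms, so it is not a recommended choice of $\ell_{\max}$.

```python
import math

def truncation_terms(ell_max, num_labels, c, store_size, gamma):
    """(|Y| - 1) exp(-ell_max) + c |I|^{-gamma} exp(ell_max): the last two terms of (20)."""
    return (num_labels - 1) * math.exp(-ell_max) + c * store_size ** (-gamma) * math.exp(ell_max)

def best_truncation(num_labels, c, store_size, gamma):
    """Closed-form minimizer over ell_max and the minimum value of truncation_terms."""
    a = num_labels - 1
    b = c * store_size ** (-gamma)
    ell_star = 0.5 * math.log(a / b)   # solves a * exp(-l) = b * exp(l)
    return ell_star, 2.0 * math.sqrt(a * b)

# Hypothetical constants: |Y| = 10, c_I = 1, |I| = 10**6, gamma_I = 0.5.
ell_star, value = best_truncation(10, 1.0, 10**6, 0.5)
print(round(ell_star, 3), round(value, 4))  # 4.552 0.1897
```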

3.4 Final excess risk bound

We now combine the three components of the excess risk bound under Assumptions 3.1 and 3.2 and discuss the design trade-offs. The following theorem states our main theoretical result.

Theorem 3.3 (Excess risk of joint training).

Under Assumptions 3.1 and 3.2, the excess risk for the retriever class $\Theta$ and the predictor class $\Xi$ is bounded as

$$\begin{aligned}
\Delta_{\ell,\mathscr{I}}(\hat{\xi},\hat{\theta}) \leq{}& 3\ell_{\max}\Big(\frac{1}{n}+\sqrt{\frac{\log(n)}{n}}\Big) + \inf_{\varepsilon\in[0,\frac{\ell_{\max}}{2}]}\Big\{8\varepsilon + \frac{24}{\sqrt{n}}\int_{\varepsilon}^{\frac{\ell_{\max}}{2}} \Big(f_{\mathcal{N}}\big(\tfrac{\nu}{2};\Theta,\Xi\big) + f_{\mathcal{N}}\big(\tfrac{\nu}{2};\Xi,\Theta\big)\Big)\,d\nu\Big\}\\
&+ \inf_{\theta\in\Theta}\inf_{\tau>0}\Big\{\ell_{\max}\big\|r_{\theta} + \tau\cdot\mathrm{gap}_{\xi^{*}}\big\|_{\infty} + \frac{\log|\mathscr{I}|}{\tau^{2}}\Big\}\\
&+ \inf_{\xi\in\Xi} 2\,\mathbb{E}_{X}\Big[\max_{y\in\mathscr{Y}}\big|h_{\xi}^{y}(X,z^{*}(X)) - h_{*}^{y}(X,z^{*}(X))\big|\Big] + (|\mathscr{Y}|-1)\exp(-\ell_{\max}) + c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}\exp(\ell_{\max}),
\end{aligned}$$

where $f_{\mathcal{N}}(\nu;\mathcal{A},\mathcal{B}) \triangleq \sup_{b\in\mathcal{B}} \sqrt{\log \mathcal{N}(\mathcal{A},\nu,\|\cdot\|_{2,[n],b})}$, and the norms $\|\cdot\|_{2,[n],\theta}$ and $\|\cdot\|_{2,[n],\xi}$ are defined in (3.2).
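To illustrate how the chaining term of Theorem 3.3 scales with $n$, the following Python sketch evaluates it numerically under a purely hypothetical assumption: both classes are taken to satisfy $\log\mathcal{N}(\nu) \le p\log(1+2B/\nu)$ for $p$ parameters (a standard parametric covering bound), so that $f_{\mathcal{N}}(\nu/2;\cdot,\cdot) = \sqrt{p\log(1+4B/\nu)}$. The helpers `f_cover` and `chaining_term`, the parameter counts, and the grid approximation of the infimum are illustrative choices, not part of the paper's analysis.

```python
import numpy as np

def f_cover(nu, p, B=1.0):
    """Hypothetical sqrt(log N(nu)) for a class assumed to satisfy log N(nu) <= p log(1 + 2B/nu)."""
    return np.sqrt(p * np.log1p(2.0 * B / nu))

def chaining_term(n, l_max, p_ret, p_pred, B=1.0, grid=2000):
    """Grid approximation of inf_eps 8 eps + 24/sqrt(n) * int_eps^{l_max/2} [f(nu/2; ret) + f(nu/2; pred)] dnu."""
    upper = l_max / 2.0
    nus = np.linspace(upper / grid, upper, grid)  # candidate eps values (eps = 0 excluded)
    integrand = f_cover(nus / 2.0, p_ret, B) + f_cover(nus / 2.0, p_pred, B)
    seg = 0.5 * (integrand[1:] + integrand[:-1]) * np.diff(nus)  # trapezoid pieces
    tail = np.append(np.cumsum(seg[::-1])[::-1], 0.0)  # tail[k] ~ integral from nus[k] to upper
    values = 8.0 * nus + 24.0 / np.sqrt(n) * tail
    return float(values.min())

# Illustrative (hypothetical) parameter counts with l_max = 4.
terms = {n: chaining_term(n, l_max=4.0, p_ret=10, p_pred=50) for n in (10**2, 10**4, 10**6)}
```

With these made-up parameter counts, the term saturates at its trivial value $4\ell_{\max}$ (attained at $\varepsilon = \ell_{\max}/2$) for small $n$ and then decreases as $n$ grows.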

3.5 Illustrative example: MLPs

We instantiate the retriever and predictor classes as multi-layer perceptrons (MLPs): the retriever has depth $L_{\rm ret}$ and width $W_{\rm ret} = O(d_x + d_z)$, and the predictor has depth $L_{\rm pred}$ and width $W_{\rm pred} = O(|\mathscr{Y}|(d_x + d_z))$. The class $\mathrm{MLP}(\mathbb{R}^{d}, \mathbb{R}^{k}; L, W)$ is defined in Appendix A. The specialized excess risk bound for this setting is as follows.

Theorem 3.4 (Excess risk for MLP).

Under Assumptions 3.1 and 3.2, the excess risk for the retriever class $\Theta = \mathrm{MLP}\big(\mathbb{R}^{d_x+d_z}, \mathbb{R}; L_{\rm ret}, O(d_x+d_z)\big)$ and the predictor class $\Xi = \mathrm{MLP}\big(\mathbb{R}^{d_x+d_z}, \mathbb{R}^{|\mathscr{Y}|}; L_{\rm pred}, O(|\mathscr{Y}|(d_x+d_z))\big)$ is bounded as

$$\begin{aligned}
\Delta_{\ell,\mathscr{I}}(\hat{\xi},\hat{\theta}) \leq{}& \tilde{O}\Big(\frac{\ell_{\max}}{\sqrt{n}}\big(L_{\rm ret} + L_{\rm pred}|\mathscr{Y}|\big)\Big) + O\Big(\ell_{\max} L_{\rm ret}^{-\frac{4\kappa}{3(d_x+d_z)}}\log^{1/3}(|\mathscr{I}|)\Big)\\
&+ O\Big(L_{\rm pred}^{-\frac{2\kappa_{\mathscr{I}}}{d_x+d_z}} + (|\mathscr{Y}|-1)\exp(-\ell_{\max}) + c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}\exp(\ell_{\max})\Big).
\end{aligned}$$
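For concreteness, the sketch below instantiates the two MLP classes of Theorem 3.4 in NumPy and forms the RAM prediction $\sum_{z} p_{\theta,\mathscr{I}}(z|x)\,p_\xi(y|x,z)$, the mixture that appears in (21). We assume that $p_{\theta,\mathscr{I}}(z|x)$ is a softmax of the retriever score over the data-store, take the constants hidden in $O(\cdot)$ to be 1, and count depth as the number of hidden layers; the exact MLP definition in Appendix A may differ, and all sizes and helper names are hypothetical.

```python
import numpy as np

def init_mlp(sizes, rng):
    """Hypothetical He-style random weights for an MLP with layer widths `sizes`."""
    return [(rng.normal(0.0, np.sqrt(2.0 / m), size=(m, k)), np.zeros(k))
            for m, k in zip(sizes[:-1], sizes[1:])]

def mlp(params, inp):
    """ReLU hidden layers followed by a linear output layer."""
    h = inp
    for i, (W, b) in enumerate(params):
        h = h @ W + b
        if i < len(params) - 1:
            h = np.maximum(h, 0.0)
    return h

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def ram_predict(x, datastore, ret_params, pred_params):
    """Return p_theta(z|x) over the data-store and the mixture sum_z p_theta(z|x) p_xi(y|x,z)."""
    xz = np.concatenate([np.repeat(x[None, :], len(datastore), axis=0), datastore], axis=1)
    p_z = softmax(mlp(ret_params, xz)[:, 0])      # assumed: softmax of retriever scores r_theta(x, z)
    p_y_given_z = softmax(mlp(pred_params, xz))   # predictor class probabilities, shape (|I|, |Y|)
    return p_z, p_z @ p_y_given_z

# Hypothetical sizes: d_x = 4, d_z = 6, |Y| = 3, L_ret = 3, L_pred = 2.
d_x, d_z, num_classes, L_ret, L_pred = 4, 6, 3, 3, 2
rng = np.random.default_rng(0)
ret_params = init_mlp([d_x + d_z] + [d_x + d_z] * L_ret + [1], rng)
pred_params = init_mlp([d_x + d_z] + [num_classes * (d_x + d_z)] * L_pred + [num_classes], rng)

x = rng.normal(size=d_x)
datastore = rng.normal(size=(100, d_z))  # |I| = 100 hypothetical evidence embeddings
p_z, p_y = ram_predict(x, datastore, ret_params, pred_params)
```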

Finally, to capture the optimal trade-off under a finite data size $n$, we consider retriever and predictor classes that change with the data size, denoted by $\Theta_n$ and $\Xi_n$, with growing depths $L_{{\rm ret},n}$ and $L_{{\rm pred},n}$, respectively. Similarly, we also consider a growing upper bound $\ell_{\max,n}$ on the loss function. Let $d_{\rm tot} = d_x + d_z$. For $L_{{\rm ret},n} = n^{\frac{3d_{\rm tot}}{6d_{\rm tot}+8\kappa}}$, $L_{{\rm pred},n} = (\sqrt{n}/|\mathscr{Y}|)^{\frac{d_{\rm tot}}{2d_{\rm tot}+4\kappa_{\mathscr{I}}}}$, and $\ell_{\max,n} = \log|\mathscr{Y}| + \frac{\kappa_{\mathscr{I}}}{d_{\rm tot}+2\kappa_{\mathscr{I}}}\log n$, the excess risk is bounded by

$$O\Big(n^{-\frac{2\kappa}{3d_{\rm tot}+4\kappa}} + \max\Big(|\mathscr{I}|^{-\gamma_{\mathscr{I}}}|\mathscr{Y}|\, n^{\frac{\kappa_{\mathscr{I}}}{d_{\rm tot}+2\kappa_{\mathscr{I}}}},\ \big(\tfrac{n}{|\mathscr{Y}|^{2}}\big)^{-\frac{\kappa_{\mathscr{I}}}{d_{\rm tot}+2\kappa_{\mathscr{I}}}}\Big)\Big).$$
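The choice of $L_{{\rm ret},n}$ can be checked with exact rational arithmetic: it equalizes the estimation term $L_{\rm ret}/\sqrt{n}$ and the approximation term $L_{\rm ret}^{-4\kappa/(3d_{\rm tot})}$ of Theorem 3.4, both of which then decay as $n^{-2\kappa/(3d_{\rm tot}+4\kappa)}$. The sketch also evaluates the two $\ell_{\max}$-dependent terms under $\ell_{\max,n}$, which become $\frac{|\mathscr{Y}|-1}{|\mathscr{Y}|}n^{-\kappa_{\mathscr{I}}/(d_{\rm tot}+2\kappa_{\mathscr{I}})}$ and $c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}|\mathscr{Y}|\,n^{\kappa_{\mathscr{I}}/(d_{\rm tot}+2\kappa_{\mathscr{I}})}$. Logarithmic factors and $\ell_{\max}$ multipliers are ignored, and the function names and numeric values are illustrative.

```python
import math
from fractions import Fraction

def retriever_term_exponents(d_tot, kappa):
    """Exponents of n in the two retriever terms of Theorem 3.4 under L_ret,n = n^{3 d_tot / (6 d_tot + 8 kappa)}.

    l_max and logarithmic factors are ignored; d_tot and kappa are integers or Fractions.
    """
    a = Fraction(3 * d_tot, 6 * d_tot + 8 * kappa)        # L_ret,n = n^a
    estimation = a - Fraction(1, 2)                       # L_ret,n / sqrt(n)
    approximation = -Fraction(4 * kappa, 3 * d_tot) * a   # L_ret,n^{-4 kappa / (3 d_tot)}
    return estimation, approximation

def loss_cap_terms(n, num_classes, datastore_size, d_tot, kappa_I, gamma_I, c_I=1.0):
    """The two l_max-dependent terms of Theorem 3.4 under the growing cap l_max,n."""
    l_max = math.log(num_classes) + kappa_I / (d_tot + 2 * kappa_I) * math.log(n)
    under_cap = (num_classes - 1) * math.exp(-l_max)                   # (|Y|-1) exp(-l_max)
    store_term = c_I * datastore_size ** (-gamma_I) * math.exp(l_max)  # c_I |I|^{-gamma_I} exp(l_max)
    return under_cap, store_term

est, app = retriever_term_exponents(d_tot=10, kappa=5)
under, store = loss_cap_terms(10**4, num_classes=10, datastore_size=10**6, d_tot=8, kappa_I=4, gamma_I=0.5)
```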

It is instructive to contrast this result with prediction without retrieval. Assume that, for all $y\in\mathscr{Y}$, the functions $p^{y}_{\mathsf{D}_{XY}}(x)$ lie in the Sobolev space with smoothness (derivative order) $\kappa_{\rm true}$ under the $L_\infty$ norm. With $L_{{\rm pred},n} = (\sqrt{n}/|\mathscr{Y}|)^{\frac{d_x}{d_x+2\kappa_{\rm true}}}$, the excess risk of the predictor without retrieval is $O\big((n/|\mathscr{Y}|^{2})^{-\frac{\kappa_{\rm true}}{d_x+2\kappa_{\rm true}}}\big)$.

Our analysis thus indicates that retrieval can yield a gain: for a sufficiently large data-store, $|\mathscr{I}| \geq |\mathscr{Y}|^{\frac{d_{\rm tot}\gamma_{\mathscr{I}}^{-1}}{d_{\rm tot}+2\kappa_{\mathscr{I}}}}\, n^{\frac{2\kappa_{\mathscr{I}}\gamma_{\mathscr{I}}^{-1}}{d_{\rm tot}+2\kappa_{\mathscr{I}}}}$, the RAM bound decays faster in the data size $n$ than the no-retrieval bound whenever $\kappa > \frac{3d_{\rm tot}}{2d_x}\kappa_{\rm true}$ and $\kappa_{\mathscr{I}} > \frac{d_{\rm tot}}{d_x}\kappa_{\rm true}$ (see Fig. 1).
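The comparison above reduces to inequalities between decay exponents in $n$. The following sketch (hypothetical helper names, illustrative dimensions) computes the exponents of the RAM bound and of the no-retrieval rate, checks them against the closed-form conditions on $\kappa$ and $\kappa_{\mathscr{I}}$, and computes the data-store threshold. At the threshold, the data-store term $|\mathscr{I}|^{-\gamma_{\mathscr{I}}}|\mathscr{Y}|n^{\kappa_{\mathscr{I}}/(d_{\rm tot}+2\kappa_{\mathscr{I}})}$ equals the predictor term $(n/|\mathscr{Y}|^2)^{-\kappa_{\mathscr{I}}/(d_{\rm tot}+2\kappa_{\mathscr{I}})}$, so for any larger data-store the maximum in the bound is attained by the latter.

```python
def ram_exponents(d_tot, kappa, kappa_I):
    """Decay exponents in n of the two RAM terms of the bound, once |I| is large enough."""
    return 2 * kappa / (3 * d_tot + 4 * kappa), kappa_I / (d_tot + 2 * kappa_I)

def no_retrieval_exponent(d_x, kappa_true):
    """Decay exponent in n of (n / |Y|^2)^{-kappa_true / (d_x + 2 kappa_true)}."""
    return kappa_true / (d_x + 2 * kappa_true)

def retrieval_helps(d_x, d_z, kappa, kappa_I, kappa_true):
    """True when both RAM exponents beat the no-retrieval exponent."""
    e_ret, e_pred = ram_exponents(d_x + d_z, kappa, kappa_I)
    e_base = no_retrieval_exponent(d_x, kappa_true)
    return e_ret > e_base and e_pred > e_base

def stated_conditions(d_x, d_z, kappa, kappa_I, kappa_true):
    """Closed-form conditions from the text."""
    d_tot = d_x + d_z
    return kappa > 3 * d_tot / (2 * d_x) * kappa_true and kappa_I > d_tot / d_x * kappa_true

def min_datastore_size(n, num_classes, d_tot, kappa_I, gamma_I):
    """|Y|^{d_tot / (gamma (d_tot + 2 kappa_I))} * n^{2 kappa_I / (gamma (d_tot + 2 kappa_I))}."""
    denom = gamma_I * (d_tot + 2 * kappa_I)
    return num_classes ** (d_tot / denom) * n ** (2 * kappa_I / denom)

# Illustrative values: d_tot = 8, kappa_I = 4, gamma_I = 0.5, |Y| = 10, n = 10^4.
I_min = min_datastore_size(1e4, 10, d_tot=8, kappa_I=4, gamma_I=0.5)
store_term = I_min ** (-0.5) * 10 * (1e4) ** (4 / 16)
predictor_term = (1e4 / 10**2) ** (-4 / 16)
```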

Figure 1: Left: Excess risk bound from Theorem 3.4 as we vary the retriever and predictor sizes for a fixed $n$ and $\mathscr{I}$. Note that different size combinations of predictor and retriever achieve the same risk bound. Right: Excess risk bound of a RAM as we increase the data-store size, in contrast to a direct MLP predictor with no retrieval. We plot for various values of $n$, with each color corresponding to a fixed $n$.
Predictor size:                 small               base                large
Retriever size:                 small base  large   small base  large   small base  large
No retriever, train predictor $\xi$
  Reverse Cross-Entropy         19.6                25.5                29.1
Fixed retriever $\theta_0$, train predictor $\xi$
  Reverse Cross-Entropy         23.2  26.6  28.3    27.5  32.4  34.7    32.2  36.4  37.8
Fixed predictor $\xi^{\star}(\theta_0)$, train retriever $\theta$
  EMDR²                         23.9  28.5  31.0    29.2  34.2  36.6    33.4  38.0  40.8
  PDist                         30.1  34.5  38.4    34.0  39.7  42.8    37.6  42.8  44.7
  Reverse Cross-Entropy + PG    25.9  30.6  31.7    31.5  36.4  37.9    36.0  40.2  41.4
  Reverse Cross-Entropy + TopK  29.4  35.5  37.9    33.8  39.7  43.0    37.2  42.3  45.0
Jointly train predictor $\xi$ and retriever $\theta$
  EMDR²                         24.1  30.4  32.7    30.4  35.6  39.3    34.5  39.7  42.1
  PDist                         28.7  33.2  36.6    33.3  37.1  38.8    36.2  40.2  41.6
  Reverse Cross-Entropy + PG    27.1  31.0  32.7    33.3  37.2  38.2    36.5  39.8  41.4
  Reverse Cross-Entropy + TopK  32.8  37.8  40.1    36.6  41.8  44.8    38.8  43.8  46.4
Table 1: Exact match accuracy on NQ. We measure the performance of RAMs across various training paradigms and model sizes. Top row specifies the predictor size and the second row specifies the retriever size.

3.6 Connections with prior end-to-end training

We conclude our treatment of end-to-end training of RAMs by drawing parallels between our proposed method and some representative approaches from the literature.

EMDR²  Sachan et al. (2021) minimize the following objective based on the negative log-likelihood:

$$\mathscr{L}^{\textsc{Emdr}^{2}}_{n}(\xi,\theta;\mathscr{I}) = -\frac{1}{n}\sum_{i\in[n]}\log p_{\xi,\theta,\mathscr{I}}(y_i|x_i) = -\frac{1}{n}\sum_{i\in[n]}\log\Big(\sum_{z\in\mathscr{I}} p_{\theta,\mathscr{I}}(z|x_i)\cdot p_{\xi}(y_i|x_i,z)\Big). \tag{21}$$

It follows from the convexity of $-\log(\cdot)$ and Jensen's inequality that our objective in (13) upper bounds the EMDR² objective in (21); as a result, driving down the former also drives down the latter, but not vice versa.
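A small numerical check of this Jensen argument for a single example: the sketch computes the per-example EMDR² loss in (21) with a stable log-sum-exp and compares it with the expected log-loss $\mathbb{E}_{Z\sim p_{\theta,\mathscr{I}}(\cdot|x)}[-\log p_\xi(y|x,Z)]$. We assume here that the per-example form of (13) is (23) below with $\ell$ the log-loss, and that $p_{\theta,\mathscr{I}}$ is a softmax over retriever scores; the scores and likelihoods are random placeholders.

```python
import numpy as np

def log_softmax(s):
    s = s - s.max()
    return s - np.log(np.exp(s).sum())

def emdr2_loss(ret_scores, log_pred_lik):
    """-log sum_z p_theta(z|x) p_xi(y|x,z), computed stably with log-sum-exp."""
    a = log_softmax(ret_scores) + log_pred_lik
    m = a.max()
    return float(-(m + np.log(np.exp(a - m).sum())))

def expected_log_loss(ret_scores, log_pred_lik):
    """E_{z ~ p_theta(.|x)}[-log p_xi(y|x,z)]: assumed per-example form of (13) with the log-loss."""
    p = np.exp(log_softmax(ret_scores))
    return float(-(p * log_pred_lik).sum())

rng = np.random.default_rng(0)
scores = rng.normal(size=50)                       # hypothetical retriever scores r_theta(x, z)
log_lik = np.log(rng.uniform(0.01, 1.0, size=50))  # hypothetical log p_xi(y|x, z)
gap = expected_log_loss(scores, log_lik) - emdr2_loss(scores, log_lik)  # >= 0 by Jensen
```

The gap vanishes when $p_\xi(y|x,z)$ does not depend on $z$, which is the equality case of Jensen's inequality.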

Perplexity distillation (PDist)  Another approach for joint training of RAMs in the literature involves optimizing two distinct objectives for training the predictor and retriever components. For example, Izacard et al. (2022) propose multiple objectives for retriever training, including PDist (Sachan et al., 2023), which is defined as follows:

$$\mathscr{L}^{\textsc{PDist}}_{\mathscr{I},n}(\theta;\xi,\mathscr{I}) = \frac{1}{n}\sum_{i\in[n]}\mathrm{CE}\big(p^{\textsc{PDist}}_{\xi,\mathscr{I}}(Z|x_i,y_i),\, p_{\theta,\mathscr{I}}(Z|x_i)\big), \tag{22}$$

where $\mathrm{CE}(\cdot,\cdot)$ denotes the cross-entropy between two distributions and

$$p^{\textsc{PDist}}_{\xi,\mathscr{I}}(z|x,y) = p_{\xi}(y|x,z) \Big/ \sum_{z'\in\mathscr{I}} p_{\xi}(y|x,z') \quad \forall\, z\in\mathscr{I},$$

represents a predictor-assigned distribution over evidences that reflects their utility towards making the correct prediction. For predictor training, they optimize an objective akin to (13) with respect to $\xi$. Beyond this similarity in predictor training, our approach to retriever training has a subtle connection with PDist. PDist trains the retriever by optimizing the forward cross-entropy between the predictor-induced and the retriever-induced distributions. In contrast, our objective in (13) is closer to $\frac{1}{n}\sum_{i}\mathrm{CE}\big(p_{\theta,\mathscr{I}}(Z|x_i),\, p^{\textsc{PDist}}_{\xi,\mathscr{I}}(Z|x_i,y_i)\big)$, the reverse cross-entropy between the two distributions. The former exhibits “mean-seeking” behavior, whereas the latter exhibits “mode-seeking” behavior (Huszár, 2015; Gu et al., 2023; Agarwal et al., 2023).
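The mean-seeking versus mode-seeking contrast can be seen on a toy example. The sketch builds $p^{\textsc{PDist}}_{\xi,\mathscr{I}}$ from hypothetical predictor likelihoods $p_\xi(y|x,z)$ for four evidences and then scores two candidate retriever distributions: one that copies the target and one that concentrates on the best evidence. Forward cross-entropy (as in PDist) prefers the copy, whereas reverse cross-entropy prefers the concentrated distribution; in fact, reverse cross-entropy is linear in the retriever distribution and is minimized by putting all mass on the evidence with the highest target probability. All numbers are made up for illustration.

```python
import numpy as np

def pdist_target(pred_lik):
    """p^PDist(z|x,y): predictor likelihoods p_xi(y|x,z) normalized over the evidences z."""
    pred_lik = np.asarray(pred_lik, dtype=float)
    return pred_lik / pred_lik.sum()

def cross_entropy(p, q, eps=1e-12):
    """CE(p, q) = -sum_z p(z) log q(z); eps guards against log(0)."""
    p, q = np.asarray(p, dtype=float), np.asarray(q, dtype=float)
    return float(-(p * np.log(q + eps)).sum())

target = pdist_target([0.6, 0.5, 0.1, 0.05])  # hypothetical likelihoods -> [0.48, 0.40, 0.08, 0.04]
candidates = {
    "copy": target.copy(),                         # matches the target exactly
    "peaked": np.array([0.97, 0.01, 0.01, 0.01]),  # concentrates on the best evidence
}
forward = {k: cross_entropy(target, q) for k, q in candidates.items()}  # PDist-style: CE(target, q)
reverse = {k: cross_entropy(q, target) for k, q in candidates.items()}  # reverse: CE(q, target)
```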

Similarity with RLHF/RLAIF  Note that the per-example objective of our retrieval training approach takes the form:

$$\mathbb{E}_{Z\sim p_{\theta,\mathscr{I}}(\cdot|x_i)}\big[\ell\big(h_{\xi}(x_i,Z),y_i\big)\big], \tag{23}$$

i.e., the predictor model provides feedback on the value of the evidences sampled by the retriever model. Alternatively, one can view $-\ell\big(h_{\xi}(x_i,Z),y_i\big)$ as the reward that the predictor model $h_{\xi}$ assigns to the evidence $Z$, with the retriever model aiming to select evidences that maximize this reward. This is similar to the RLHF (Ziegler et al., 2019) and RLAIF (Bai et al., 2022) paradigms, where the underlying LLM aims to sample generations that maximize the reward assigned by a reward model. Unlike in RAMs, however, the policy network and the reward model in RLHF/RLAIF are not trained jointly.

4 Experiments

There have been numerous successful practical applications of RAMs in the literature (e.g., Sachan et al. (2021); Izacard et al. (2022)). Here, we present a brief empirical study of such models to corroborate the benefits predicted by our theoretical results. In particular, we consider the task of open-domain question answering, show that the proposed objective is competitive with objectives from the literature, and examine the trade-offs in model capacity between the retriever and predictor models.

Data  Our evaluation is based on two benchmark datasets, NQOpen (Kwiatkowski et al., 2019) and TriviaQA (Joshi et al., 2017), which serve as sources of supervised examples $(x,y)$, while chunked Wikipedia 2018 serves as the data-store $\mathscr{I}$, following the literature (Karpukhin et al., 2020a). Consistent with established practice, we use the exact match metric to assess agreement between the predicted answers and the ground truth. Additionally, we introduce a recall metric that measures how often the answer string appears within the retrieved documents.
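The two metrics can be sketched as follows. The paper does not spell out its answer normalization, so we assume the common SQuAD-style convention (lowercasing, removing punctuation and the articles a/an/the, collapsing whitespace); the recall check is a simple normalized substring test, which may differ from the token-level matching used in some open-domain QA codebases. Function names are hypothetical.

```python
import re
import string

PUNCT = set(string.punctuation)

def normalize_answer(s):
    """Assumed SQuAD-style normalization: lowercase, drop punctuation and articles, collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in PUNCT)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def exact_match(prediction, gold_answers):
    """1.0 if the normalized prediction equals any normalized gold answer, else 0.0."""
    pred = normalize_answer(prediction)
    return float(any(pred == normalize_answer(g) for g in gold_answers))

def answer_recall(retrieved_passages, gold_answers):
    """1.0 if any normalized gold answer occurs as a substring of any normalized retrieved passage."""
    passages = [normalize_answer(p) for p in retrieved_passages]
    return float(any(normalize_answer(g) in p for g in gold_answers for p in passages))
```

Averaging these per-question scores over a test set gives the exact match accuracy and the answer recall.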

Models  We implement the retriever component using GTR (Ni et al., 2022) and the predictor component using T5 (Raffel et al., 2020). We sweep across small, base, and large configurations for both retriever and predictor. The details regarding the model sizes, expressed in terms of the number of parameters, are provided in Table 6 (App. C).

Predictor size:                 small               base                large
Retriever size:                 small base  large   small base  large   small base  large
No retriever, train predictor $\xi$
  Reverse Cross-Entropy         17.9                23.1                28.0
Fixed retriever $\theta_0$, train predictor $\xi$
  Reverse Cross-Entropy         31.5  34.9  38.8    37.0  40.6  44.4    43.4  45.9  49.7
Fixed predictor $\xi^{\star}(\theta_0)$, train retriever $\theta$
  EMDR²                         34.6  41.3  48.3    40.1  48.2  53.4    46.0  50.7  54.9
  PDist                         45.7  53.3  57.2    50.8  53.2  61.6    53.5  55.4  62.3
  Reverse Cross-Entropy + PG    43.2  46.7  54.3    48.6  56.1  55.1    51.7  56.4  56.7
  Reverse Cross-Entropy + TopK  43.6  50.4  54.4    48.6  54.9  58.5    52.1  56.6  60.3
Jointly train predictor $\xi$ and retriever $\theta$
  EMDR²                         37.0  43.1  49.7    42.4  50.5  55.6    47.1  53.4  59.2
  PDist                         46.7  54.3  57.3    48.8  56.7  60.7    51.0  58.5  63.3
  Reverse Cross-Entropy + PG    47.0  52.9  55.7    49.9  57.6  61.1    52.1  59.8  59.2
  Reverse Cross-Entropy + TopK  46.8  52.9  56.0    49.2  56.6  60.1    52.3  58.8  62.4
Table 2: Exact match accuracy on TriviaQA. We measure the performance of RAMs across various training paradigms and model sizes. Top row specifies the predictor size and the second row specifies the retriever size.

Methods  We compare the following approaches: 1) using no retriever and directly training the predictor; 2) using a fixed retriever and training the predictor; 3) using a fixed predictor and training the retriever; and 4) jointly training both components. For joint training and retriever training, we experiment with multiple objectives: EMDR² (cf. (21)), PDist (cf. (22)), Reverse Cross-Entropy + PG (cf. (45) in App. C.1), and Reverse Cross-Entropy + TopK (cf. (44) in App. C.1). Implementing any of these objectives efficiently is challenging because it requires the gradient of an expectation over the entire data-store. We consider two approaches to approximating the gradients: 1) restricting the expectation to the top-K elements, similar to EMDR² and PDist; and 2) using REINFORCE (Williams, 1992) to obtain an unbiased estimate. More details can be found in App. C.1.
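The two gradient approximations can be compared on a single query with a softmax retriever over a small data-store. For $\mathbb{E}_{Z\sim\mathrm{softmax}(s)}[\ell_Z]$, the exact gradient with respect to the scores is $p_j(\ell_j - \mathbb{E}[\ell])$; the top-K variant restricts and renormalizes the softmax over the $K$ highest-scoring evidences, and REINFORCE averages $\ell_Z\nabla_s\log p(Z) = \ell_Z(e_Z - p)$ over samples. This is a generic sketch under these assumptions, not the exact objectives (44) and (45) of App. C.1; the scores, losses, and helper names are hypothetical.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max())
    return e / e.sum()

def exact_grad(scores, losses):
    """Gradient of E_{Z ~ softmax(scores)}[losses[Z]] w.r.t. scores: p_j * (losses_j - E[loss])."""
    p = softmax(scores)
    return p * (losses - p @ losses)

def topk_grad(scores, losses, k):
    """Same gradient with the expectation restricted to the k highest-scoring evidences (renormalized)."""
    idx = np.argsort(scores)[::-1][:k]
    g = np.zeros_like(scores)
    g[idx] = exact_grad(scores[idx], losses[idx])
    return g

def reinforce_grad(scores, losses, num_samples, rng):
    """Unbiased REINFORCE estimate: sample average of losses[Z] * (e_Z - p)."""
    p = softmax(scores)
    z = rng.choice(len(scores), size=num_samples, p=p)
    freq = np.bincount(z, minlength=len(scores)) / num_samples  # empirical frequency of each evidence
    return freq * losses - (freq @ losses) * p

rng = np.random.default_rng(0)
scores = rng.normal(size=20)             # hypothetical retriever scores, |I| = 20
losses = rng.uniform(0.0, 3.0, size=20)  # hypothetical per-evidence predictor losses
exact = exact_grad(scores, losses)
approx_top5 = topk_grad(scores, losses, k=5)
approx_pg = reinforce_grad(scores, losses, num_samples=200_000, rng=np.random.default_rng(1))
```

With $K$ equal to the data-store size, the top-K gradient coincides with the exact one; for smaller $K$ it is biased but cheap. The REINFORCE average approaches the exact gradient as the number of samples grows, at the cost of Monte Carlo variance.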

Observation 1  Adding a retrieval component markedly improves performance, as shown by the exact match accuracies in Tables 1 and 2. Further improvements are observed when the retriever is specifically trained while the predictor is kept fixed. Joint training generally emerges as the most effective strategy.

Observation 2  Tables 4 and 5 (App. C) list the recall for the presence of the answer string within the retrieved content. PDist consistently achieves the highest recall, aligning with expectations given its design for distilling the retriever based on the predictor’s scores. However, despite its superior recall, other objectives may lead to better overall performance than PDist, suggesting that different objectives optimize the retriever and predictor with varying efficiencies.

Observation 3  Finally, in Table 3, we report the queries per second (QPS) achieved by different configurations of retriever and predictor model sizes, as a proxy for computational cost. Multiple configurations reach a given accuracy threshold (e.g., $\geq$38.8 on NQ): pairing a large predictor with a small retriever, using base models for both, or pairing a small predictor with a large retriever. The associated QPS values for these configurations are 135, 333, and 800, respectively, illustrating that equivalent accuracy can be attained at markedly different throughput. This corroborates the trade-offs in our excess risk bounds for MLPs with different capacities in the retriever and predictor components, as illustrated in Figure 1. Thus, adding capacity to different parts of the model has different repercussions on quality and computational cost.
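The selection in Observation 3 can be reproduced directly from the tables. The sketch below assumes that the quoted 38.8 threshold refers to the jointly trained Reverse Cross-Entropy + TopK row of Table 1 (the row whose entries match the three configurations quoted above) and pairs it with the QPS values of Table 3; it lists every configuration meeting the threshold, fastest first. The names are hypothetical.

```python
# (predictor size, retriever size) pairs in the column order of Tables 1 and 3.
SIZES = ("small", "base", "large")
CONFIGS = [(p, r) for p in SIZES for r in SIZES]

# Assumed source row: Table 1, jointly trained Reverse Cross-Entropy + TopK (NQ exact match).
EM_JOINT_TOPK = dict(zip(CONFIGS, [32.8, 37.8, 40.1, 36.6, 41.8, 44.8, 38.8, 43.8, 46.4]))
# Table 3: queries per second.
QPS = dict(zip(CONFIGS, [822.60, 819.83, 800.89, 334.30, 333.22, 331.06, 135.06, 135.34, 134.87]))

def viable_configs(threshold):
    """Configurations whose exact match meets the threshold, sorted by QPS (fastest first)."""
    ok = [cfg for cfg, em in EM_JOINT_TOPK.items() if em >= threshold]
    return sorted(ok, key=lambda cfg: QPS[cfg], reverse=True)

cfgs = viable_configs(38.8)
```

Because QPS is governed mostly by the predictor size, moving capacity from the predictor to the retriever is the cheaper way to reach this accuracy level in this table.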

5 Discussion and related work

Predictor \ Retriever | small | base | large
small | 822.60 | 819.83 | 800.89
base | 334.30 | 333.22 | 331.06
large | 135.06 | 135.34 | 134.87
Table 3: Queries per second (QPS) processed by RAMs, used as a proxy for computational cost, across model sizes. Rows give the predictor size and columns give the retriever size.

Many works have proposed some form of retrieval-augmented model. Here, we briefly trace the evolution of RAMs and discuss how our joint-learning objective and our framework for excess risk analysis compare with existing end-to-end training methods.

Augment with local neighborhood  The earliest approaches, dating back to the 1970s, simply augmented the input with training instances from its local neighborhood in the input space (Stone, 1977, 1980). Such approaches gained considerable attention because parametric regression was inadequate for many practical applications of the time. This line of work fits a low-degree polynomial at each point of the data set using a subset of nearby data points, and it produced a rich literature on local polynomial regression in low dimensions (Katkovnik and Kheisin, 1979; Cleveland, 1979; Pinsker, 1980; Donoho and Liu, 1988; Ruppert and Wand, 1994; Ibragimov and Has Minskii, 2013). These classical ideas have found applications in many ML settings, such as face recognition (Jain and Learned-Miller, 2011), dimensionality reduction via locally linear embedding (Roweis and Saul, 2000), domain adaptation (Yang et al., 2021), and test-time training on neighboring points (Sun et al., 2020; Gandelsman et al., 2022). Recently, Basu et al. (2023) generalized this setup of augmenting the input with its local neighborhood to modern ML models such as neural networks and proposed a statistical framework to study such retrieval-based models. However, they do not consider a learned or specialized distance metric for finding the augmenting set, which is critical for good performance in practice (Schonberger et al., 2017; Karpukhin et al., 2020b) and is the focus of the present work.
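To make the classical recipe concrete, the sketch below implements the simplest member of this family, a k-nearest-neighbour local linear fit in one dimension: it retrieves the $k$ training points closest to the query under the fixed metric $|x - x_{\text{query}}|$ and fits a line to them by ordinary least squares. The uniform weighting and the name `local_linear_predict` are illustrative choices; methods such as Cleveland's (1979) use distance-based kernel weights and robustness iterations instead.

```python
import math


def local_linear_predict(x_query, xs, ys, k):
    """Predict y at x_query from the k training points nearest to it.

    Retrieval: k nearest neighbours under the fixed metric |x - x_query|.
    Prediction: ordinary least-squares line y = a + b * x on those points,
    with uniform weights, evaluated at x_query.
    """
    idx = sorted(range(len(xs)), key=lambda i: abs(xs[i] - x_query))[:k]
    nx = [xs[i] for i in idx]
    ny = [ys[i] for i in idx]
    mx = sum(nx) / len(nx)
    my = sum(ny) / len(ny)
    sxx = sum((x - mx) ** 2 for x in nx)
    if sxx == 0.0:  # all neighbours share the same x: fall back to their mean
        return my
    slope = sum((x - mx) * (y - my) for x, y in zip(nx, ny)) / sxx
    return my + slope * (x_query - mx)


# Usage: noiseless samples of sin(x) on a grid of spacing 0.1.
xs = [i / 10 for i in range(50)]
ys = [math.sin(x) for x in xs]
print(local_linear_predict(1.23, xs, ys, k=5), math.sin(1.23))
```

The retrieval step here uses a hand-chosen metric; the point of contrast with the present work is that the metric itself is fixed rather than learned from data.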

Fixed retriever augmentation  The next generation of retrieval-augmented models deployed either a hand-crafted or a learned retriever. Zhang et al. (2006) employed SIFT-based retrieval (Lowe, 1999) followed by an SVM classifier (Cortes and Vapnik, 1995) to improve performance on multiple vision tasks. Chen et al. (2009) studied generalization bounds for SVM-kNN methods, one of the few works in this domain with a formal analysis. For natural language understanding, methods such as TF-IDF (Sparck Jones, 1972) were employed in tasks such as case-based reasoning (Leake et al., 1996) and open-domain question answering (ODQA; Voorhees et al. 1999). Unlike many earlier methods, which retrieve labelled training pairs, ODQA methods retrieve relevant text passages. With the introduction of transformers (Vaswani et al., 2017), encoder-based retrievers paired with decoder-based predictors have become popular across various domains, including image classification (Long et al., 2022; Iscen et al., 2023), text classification (Wang et al., 2022; Zemlyanskiy et al., 2022), ODQA (Lee et al., 2019; Izacard and Grave, 2021), language modeling (Borgeaud et al., 2021), and even protein folding prediction (Cramer, 2021). Even using the same transformer model as both retriever and predictor boosts performance in language modeling (Khandelwal et al., 2020). Unlike for SVM-kNN (Chen et al., 2009), to the best of our knowledge, a formal analysis of retrieval-augmented approaches with modern neural networks is missing from the literature. Interestingly, retrieving examples also helps in-context learning (Rubin et al., 2022; Li et al., 2023). Our framework covers this scenario, with $z$ representing the in-context examples retrieved from a data-store of examples. Our risk bounds can provide insights into why in-context learning with retrieved few-shot examples performs better than a zero-shot model.
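As an illustration of a hand-crafted (not learned) retriever of the TF-IDF kind mentioned above, the sketch below scores each document by summing, over query terms, the term frequency in the document times the inverse document frequency $\log(N/\mathrm{df}(t))$. It is a minimal version under stated assumptions (lowercase whitespace tokenization, length-normalized term frequency, no smoothing); practical systems typically use refinements such as BM25, and the function name `tfidf_retrieve` is hypothetical.

```python
import math
from collections import Counter


def tfidf_retrieve(query, corpus, k):
    """Return indices of the k documents with the highest TF-IDF score.

    score(q, d) = sum over query tokens t of tf(t, d) * idf(t), where
    tf(t, d) = count(t, d) / len(d) and idf(t) = log(N / df(t)).
    Tokens are lowercase whitespace-separated words; query tokens that
    never occur in the corpus contribute zero.
    """
    docs = [doc.lower().split() for doc in corpus]
    n_docs = len(docs)
    df = Counter(t for d in docs for t in set(d))
    idf = {t: math.log(n_docs / c) for t, c in df.items()}
    tfs = [Counter(d) for d in docs]
    q_tokens = query.lower().split()

    def score(i):
        length = max(len(docs[i]), 1)  # guard against empty documents
        return sum(tfs[i][t] / length * idf.get(t, 0.0) for t in q_tokens)

    # sorted() is stable even with reverse=True, so ties keep corpus order.
    return sorted(range(n_docs), key=score, reverse=True)[:k]


corpus = ["the cat sat on the mat", "dogs chase cats", "the quick brown fox"]
print(tfidf_retrieve("cat mat", corpus, k=1))  # [0]
```

Nothing in this scoring rule adapts to the downstream predictor, which is the sense in which such retrievers are fixed.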

End-to-end trained retriever augmentation  For ODQA, Guu et al. (2020) proposed maximizing the marginal likelihood by treating the retrieved set as a latent variable. EMDR2 (Sachan et al., 2021) optimized the same objective, approximating it with the retriever-induced distribution restricted to the elements that receive the top-K retriever scores. Hindsight (Paranjape et al., 2022) instead optimizes the ELBO by introducing a variational distribution that has access to the outputs. VOD (Liévin et al., 2023) further generalized the standard KL-divergence-based ELBO by employing the Rényi divergence, thereby tightening the lower bound. On the other hand, Atlas (Izacard et al., 2022) proposed an auxiliary loss for training the retriever directly rather than following the latent variable approach. Interestingly, RAG (Lewis et al., 2020) proposed training only the retriever's query encoder while keeping the retrieval index fixed, which alleviates many of the difficulties of end-to-end training of RAMs, but at the cost of limiting the model's flexibility to adapt. None of these prior works studies the statistical properties of RAMs, such as their expressivity and generalization.
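For reference, the objectives named above can be written in a common notation. The symbols are assumptions made for this sketch and need not match the paper's own equations: $p_\eta(z \mid x)$ is the retriever's distribution over data-store elements $z \in \mathcal{D}$, $p_\theta(y \mid x, z)$ is the predictor, $\mathrm{TopK}_\eta(x)$ is the set of $K$ highest-scoring elements, and $q(z \mid x, y)$ is a variational distribution that also sees the target.

```latex
\begin{align*}
\text{(marginal likelihood)}\quad
  \log p(y \mid x) &= \log \sum_{z \in \mathcal{D}} p_\eta(z \mid x)\, p_\theta(y \mid x, z), \\
\text{(top-$K$ approximation)}\quad
  \log p(y \mid x) &\approx \log \sum_{z \in \mathrm{TopK}_\eta(x)} \tilde{p}_\eta(z \mid x)\, p_\theta(y \mid x, z),
  \qquad \tilde{p}_\eta(z \mid x) = \frac{p_\eta(z \mid x)}{\sum_{z' \in \mathrm{TopK}_\eta(x)} p_\eta(z' \mid x)}, \\
\text{(ELBO)}\quad
  \log p(y \mid x) &\geq \mathbb{E}_{q(z \mid x, y)}\big[\log p_\theta(y \mid x, z)\big]
  - \mathrm{KL}\big(q(\cdot \mid x, y) \,\|\, p_\eta(\cdot \mid x)\big), \\
\text{(R\'enyi bound, } \alpha \in [0, 1))\quad
  \mathcal{L}_\alpha &= \frac{1}{1 - \alpha} \log \mathbb{E}_{q(z \mid x, y)}\!\left[\left(\frac{p_\eta(z \mid x)\, p_\theta(y \mid x, z)}{q(z \mid x, y)}\right)^{1 - \alpha}\right].
\end{align*}
```

$\mathcal{L}_\alpha$ recovers $\log p(y \mid x)$ at $\alpha = 0$ and the ELBO in the limit $\alpha \to 1$, and it is non-increasing in $\alpha$, which is the sense in which Rényi-divergence objectives can tighten the ELBO.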

6 Conclusion

In this work, we initiate the development of a theoretical framework to study the statistical properties of RAMs with data-dependent retrieval. Our excess risk analysis highlights how the retriever and predictor components play complementary roles in reducing the approximation error as their respective function classes grow in complexity. We show, both theoretically and empirically, a Pareto surface along which predictors and retrievers of different sizes achieve the same performance. As future work, it would be interesting to study the effect of a dynamically updatable data-store and of multi-step retrieval on predictions.

References

  • Agarwal et al. [2023] Rishabh Agarwal, Nino Vieillard, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem. GKD: Generalized knowledge distillation for auto-regressive sequence models. arXiv preprint arXiv:2306.13649, 2023.
  • Asai et al. [2023] Akari Asai, Zeqiu Wu, Yizhong Wang, Avirup Sil, and Hannaneh Hajishirzi. Self-RAG: Learning to retrieve, generate, and critique through self-reflection. arXiv preprint arXiv:2310.11511, 2023.
  • Austin et al. [2021] Jacob Austin, Augustus Odena, Maxwell Nye, Maarten Bosma, Henryk Michalewski, David Dohan, Ellen Jiang, Carrie Cai, Michael Terry, Quoc Le, et al. Program synthesis with large language models. arXiv preprint arXiv:2108.07732, 2021.
  • Bai et al. [2022] Yuntao Bai, Saurav Kadavath, Sandipan Kundu, Amanda Askell, Jackson Kernion, Andy Jones, Anna Chen, Anna Goldie, Azalia Mirhoseini, Cameron McKinnon, et al. Constitutional AI: Harmlessness from AI feedback. arXiv preprint arXiv:2212.08073, 2022.
  • Bartlett et al. [2019] Peter L Bartlett, Nick Harvey, Christopher Liaw, and Abbas Mehrabian. Nearly-tight VC-dimension and pseudodimension bounds for piecewise linear neural networks. The Journal of Machine Learning Research, 20(1):2285–2301, 2019.
  • Basu et al. [2023] Soumya Basu, Ankit Singh Rawat, and Manzil Zaheer. A statistical perspective on retrieval-based models. In Andreas Krause, Emma Brunskill, Kyunghyun Cho, Barbara Engelhardt, Sivan Sabato, and Jonathan Scarlett, editors, Proceedings of the 40th International Conference on Machine Learning, volume 202 of Proceedings of Machine Learning Research, pages 1852–1886. PMLR, 23–29 Jul 2023. URL https://proceedings.mlr.press/v202/basu23a.html.
  • Borgeaud et al. [2021] Sebastian Borgeaud, Arthur Mensch, Jordan Hoffmann, Trevor Cai, Eliza Rutherford, Katie Millican, George van den Driessche, Jean-Baptiste Lespiau, Bogdan Damoc, Aidan Clark, Diego de Las Casas, Aurelia Guy, Jacob Menick, Roman Ring, Tom Hennigan, Saffron Huang, Loren Maggiore, Chris Jones, Albin Cassirer, Andy Brock, Michela Paganini, Geoffrey Irving, Oriol Vinyals, Simon Osindero, Karen Simonyan, Jack W. Rae, Erich Elsen, and Laurent Sifre. Improving language models by retrieving from trillions of tokens. CoRR, abs/2112.04426, 2021.
  • Brown et al. [2020] Tom B. Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, Tom Henighan, Rewon Child, Aditya Ramesh, Daniel M. Ziegler, Jeffrey Wu, Clemens Winter, Christopher Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Sam McCandlish, Alec Radford, Ilya Sutskever, and Dario Amodei. Language models are few-shot learners, 2020.
  • Burda et al. [2015] Yuri Burda, Roger Grosse, and Ruslan Salakhutdinov. Importance weighted autoencoders. arXiv preprint arXiv:1509.00519, 2015.
  • Chen et al. [2009] Yihua Chen, Eric K Garcia, Maya R Gupta, Ali Rahimi, and Luca Cazzanti. Similarity-based classification: Concepts and algorithms. Journal of Machine Learning Research, 10(3), 2009.
  • Chowdhery et al. [2022] Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, et al. PaLM: Scaling language modeling with pathways. arXiv preprint arXiv:2204.02311, 2022.
  • Cleveland [1979] William S Cleveland. Robust locally weighted regression and smoothing scatterplots. Journal of the American Statistical Association, 74(368):829–836, 1979.
  • Cortes and Vapnik [1995] Corinna Cortes and Vladimir Vapnik. Support-vector networks. Machine Learning, 20(3):273–297, 1995.
  • Cramer [2021] Patrick Cramer. AlphaFold2 and the future of structural biology. Nature Structural & Molecular Biology, 28(9):704–705, 2021.
  • Das et al. [2021] Rajarshi Das, Manzil Zaheer, Dung Thai, Ameya Godbole, Ethan Perez, Jay Yoon Lee, Lizhen Tan, Lazaros Polymenakos, and Andrew McCallum. Case-based reasoning for natural language queries over knowledge bases. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, pages 9594–9611, Online and Punta Cana, Dominican Republic, November 2021. Association for Computational Linguistics. doi: 10.18653/v1/2021.emnlp-main.755.
  • Dehghani et al. [2023] Mostafa Dehghani, Josip Djolonga, Basil Mustafa, Piotr Padlewski, Jonathan Heek, Justin Gilmer, Andreas Peter Steiner, Mathilde Caron, Robert Geirhos, Ibrahim Alabdulmohsin, et al. Scaling vision transformers to 22 billion parameters. In International Conference on Machine Learning, pages 7480–7512. PMLR, 2023.
  • Donoho and Liu [1988] David L Donoho and Richard C Liu. The "automatic" robustness of minimum distance functionals. The Annals of Statistics, 16(2):552–586, 1988.
  • Dosovitskiy et al. [2021] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby. An image is worth 16x16 words: Transformers for image recognition at scale. In International Conference on Learning Representations, 2021.
  • Epasto et al. [2020] Alessandro Epasto, Mohammad Mahdian, Vahab Mirrokni, and Emmanouil Zampetakis. Optimal approximation-smoothness tradeoffs for soft-max functions. Advances in Neural Information Processing Systems, 33:2651–2660, 2020.
  • Gandelsman et al. [2022] Yossi Gandelsman, Yu Sun, Xinlei Chen, and Alexei Efros. Test-time training with masked autoencoders. Advances in Neural Information Processing Systems, 35:29374–29385, 2022.
  • Grathwohl et al. [2021] Will Grathwohl, Kevin Swersky, Milad Hashemi, David Duvenaud, and Chris Maddison. Oops I took a gradient: Scalable sampling for discrete distributions. In International Conference on Machine Learning, pages 3831–3841. PMLR, 2021.
  • Gu et al. [2023] Yuxian Gu, Li Dong, Furu Wei, and Minlie Huang. Knowledge distillation of large language models. arXiv preprint arXiv:2306.08543, 2023.
  • Guu et al. [2020] Kelvin Guu, Kenton Lee, Zora Tung, Panupong Pasupat, and Ming-Wei Chang. REALM: Retrieval-augmented language model pre-training. In Proceedings of the 37th International Conference on Machine Learning, ICML’20. JMLR.org, 2020.
  • henrikl [2021] henrikl (https://math.stackexchange.com/users/351007/henrikl). 1-smoothness of the symmetric softmax function. Mathematics Stack Exchange, 2021. URL https://math.stackexchange.com/q/4170855 (version: 2021-06-12).
  • Huszár [2015] Ferenc Huszár. How (not) to train your generative model: Scheduled sampling, likelihood, adversary? arXiv preprint arXiv:1511.05101, 2015.
  • Ibragimov and Has Minskii [2013] Ildar Abdulovich Ibragimov and Rafail Zalmanovich Has Minskii. Statistical estimation: asymptotic theory, volume 16. Springer Science & Business Media, 2013.
  • Iscen et al. [2023] Ahmet Iscen, Alireza Fathi, and Cordelia Schmid. Improving image recognition by retrieving from web-scale image-text data. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 19295–19304, 2023.
  • Izacard and Grave [2021] Gautier Izacard and Edouard Grave. Leveraging passage retrieval with generative models for open domain question answering. In Proceedings of the 16th Conference of the European Chapter of the Association for Computational Linguistics: Main Volume, pages 874–880, Online, April 2021. Association for Computational Linguistics. doi: 10.18653/v1/2021.eacl-main.74. URL https://aclanthology.org/2021.eacl-main.74.
  • Izacard et al. [2022] Gautier Izacard, Patrick Lewis, Maria Lomeli, Lucas Hosseini, Fabio Petroni, Timo Schick, Jane Dwivedi-Yu, Armand Joulin, Sebastian Riedel, and Edouard Grave. Few-shot learning with retrieval augmented language models. arXiv preprint arXiv:2208.03299, 2022.
  • Jain and Learned-Miller [2011] Vidit Jain and Erik Learned-Miller. Online domain adaptation of a pre-trained cascade of classifiers. In CVPR 2011, pages 577–584. IEEE, 2011.
  • Joshi et al. [2017] Mandar Joshi, Eunsol Choi, Daniel S Weld, and Luke Zettlemoyer. TriviaQA: A large scale distantly supervised challenge dataset for reading comprehension. In Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 1601–1611, 2017.
  • Karpukhin et al. [2020a] Vladimir Karpukhin, Barlas Oguz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. Dense passage retrieval for open-domain question answering. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP), pages 6769–6781, Online, November 2020a. Association for Computational Linguistics.
  • Karpukhin et al. [2020b] Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. Dense passage retrieval for open-domain question answering. arXiv preprint arXiv:2004.04906, 2020b.
  • Katkovnik and Kheisin [1979] Vladimir Yakovlevich Katkovnik and VE Kheisin. Dynamic stochastic approximation of polynomials drifts. Avtomatika i Telemekhanika, pages 89–98, 1979.
  • Khandelwal et al. [2020] Urvashi Khandelwal, Omer Levy, Dan Jurafsky, Luke Zettlemoyer, and Mike Lewis. Generalization through memorization: Nearest neighbor language models. In International Conference on Learning Representations, 2020.
  • Kwiatkowski et al. [2019] Tom Kwiatkowski, Jennimaria Palomaki, Olivia Redfield, Michael Collins, Ankur Parikh, Chris Alberti, Danielle Epstein, Illia Polosukhin, Jacob Devlin, Kenton Lee, et al. Natural questions: a benchmark for question answering research. Transactions of the Association for Computational Linguistics, 7:453–466, 2019.
  • Leake et al. [1996] David B. Leake, Andrew Kinley, and David C. Wilson. Acquiring case adaptation knowledge: A hybrid approach. In AAAI/IAAI, Vol. 1, 1996. URL https://api.semanticscholar.org/CorpusID:11169287.
  • Lee et al. [2019] Kenton Lee, Ming-Wei Chang, and Kristina Toutanova. Latent retrieval for weakly supervised open domain question answering. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 6086–6096, Florence, Italy, July 2019. Association for Computational Linguistics. doi: 10.18653/v1/P19-1612. URL https://aclanthology.org/P19-1612.
  • Lewis et al. [2020] Patrick Lewis, Ethan Perez, Aleksandra Piktus, Fabio Petroni, Vladimir Karpukhin, Naman Goyal, Heinrich Küttler, Mike Lewis, Wen-tau Yih, Tim Rocktäschel, et al. Retrieval-augmented generation for knowledge-intensive NLP tasks. Advances in Neural Information Processing Systems, 33:9459–9474, 2020.
  • Lewkowycz et al. [2022] Aitor Lewkowycz, Anders Andreassen, David Dohan, Ethan Dyer, Henryk Michalewski, Vinay Ramasesh, Ambrose Slone, Cem Anil, Imanol Schlag, Theo Gutman-Solo, et al. Solving quantitative reasoning problems with language models. Advances in Neural Information Processing Systems, 35:3843–3857, 2022.
  • Li et al. [2023] Yingcong Li, Muhammed Emrullah Ildiz, Dimitris Papailiopoulos, and Samet Oymak. Transformers as algorithms: Generalization and stability in in-context learning. In Andreas Krause, Emma Brunskill, Kyunghyun Cho, Barbara Engelhardt, Sivan Sabato, and Jonathan Scarlett, editors, Proceedings of the 40th International Conference on Machine Learning, volume 202 of Proceedings of Machine Learning Research, pages 19565–19594. PMLR, 23–29 Jul 2023. URL https://proceedings.mlr.press/v202/li23l.html.
  • Liévin et al. [2023] Valentin Liévin, Andreas Geert Motzfeldt, Ida Riis Jensen, and Ole Winther. Variational open-domain question answering. In Andreas Krause, Emma Brunskill, Kyunghyun Cho, Barbara Engelhardt, Sivan Sabato, and Jonathan Scarlett, editors, Proceedings of the 40th International Conference on Machine Learning, volume 202 of Proceedings of Machine Learning Research, pages 20950–20977. PMLR, 23–29 Jul 2023.
  • Lin et al. [2023] Xi Victoria Lin, Xilun Chen, Mingda Chen, Weijia Shi, Maria Lomeli, Rich James, Pedro Rodriguez, Jacob Kahn, Gergely Szilvasy, Mike Lewis, et al. RA-DIT: Retrieval-augmented dual instruction tuning. arXiv preprint arXiv:2310.01352, 2023.
  • Liska et al. [2022] Adam Liska, Tomas Kocisky, Elena Gribovskaya, Tayfun Terzi, Eren Sezener, Devang Agrawal, Cyprien De Masson D’Autume, Tim Scholtes, Manzil Zaheer, Susannah Young, Ellen Gilsenan-Mcmahon, Sophia Austin, Phil Blunsom, and Angeliki Lazaridou. StreamingQA: A benchmark for adaptation to new knowledge over time in question answering models. In Kamalika Chaudhuri, Stefanie Jegelka, Le Song, Csaba Szepesvari, Gang Niu, and Sivan Sabato, editors, Proceedings of the 39th International Conference on Machine Learning, volume 162 of Proceedings of Machine Learning Research, pages 13604–13622. PMLR, 17–23 Jul 2022. URL https://proceedings.mlr.press/v162/liska22a.html.
  • Long et al. [2022] Alexander Long, Wei Yin, Thalaiyasingam Ajanthan, Vu Nguyen, Pulak Purkait, Ravi Garg, Alan Blair, Chunhua Shen, and Anton van den Hengel. Retrieval augmented classification for long-tail visual recognition. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 6959–6969, 2022.
  • Lowe [1999] David G Lowe. Object recognition from local scale-invariant features. In Proceedings of the Seventh IEEE International Conference on Computer Vision, volume 2, pages 1150–1157. IEEE, 1999.
  • McSherry and Talwar [2007] Frank McSherry and Kunal Talwar. Mechanism design via differential privacy. In 48th Annual IEEE Symposium on Foundations of Computer Science (FOCS’07), pages 94–103. IEEE, 2007.
  • Meinhardt et al. [2022] Tim Meinhardt, Alexander Kirillov, Laura Leal-Taixe, and Christoph Feichtenhofer. TrackFormer: Multi-object tracking with transformers. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 8844–8854, 2022.
  • Ni et al. [2022] Jianmo Ni, Chen Qu, Jing Lu, Zhuyun Dai, Gustavo Hernandez Abrego, Ji Ma, Vincent Zhao, Yi Luan, Keith Hall, Ming-Wei Chang, et al. Large dual encoders are generalizable retrievers. In Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing, pages 9844–9855, 2022.
  • OpenAI [2023] OpenAI. GPT-4 technical report. ArXiv, abs/2303.08774, 2023. URL https://api.semanticscholar.org/CorpusID:257532815.
  • Paranjape et al. [2022] Ashwin Paranjape, Omar Khattab, Christopher Potts, Matei Zaharia, and Christopher D Manning. Hindsight: Posterior-guided training of retrievers for improved open-ended generation. In International Conference on Learning Representations, 2022. URL https://openreview.net/forum?id=Vr_BTpw3wz.
  • Pinsker [1980] Mark Semenovich Pinsker. Optimal filtering of square-integrable signals in Gaussian noise. Problemy Peredachi Informatsii, 16(2):52–68, 1980.
  • Raffel et al. [2020] Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, and Peter J Liu. Exploring the limits of transfer learning with a unified text-to-text transformer. The Journal of Machine Learning Research, 21(1):5485–5551, 2020.
  • Roweis and Saul [2000] Sam T Roweis and Lawrence K Saul. Nonlinear dimensionality reduction by locally linear embedding. Science, 290(5500):2323–2326, 2000.
  • Rubin et al. [2022] Ohad Rubin, Jonathan Herzig, and Jonathan Berant. Learning to retrieve prompts for in-context learning. In Proceedings of the 2022 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pages 2655–2671, Seattle, United States, July 2022. Association for Computational Linguistics. doi: 10.18653/v1/2022.naacl-main.191. URL https://aclanthology.org/2022.naacl-main.191.
  • Ruppert and Wand [1994] David Ruppert and Matthew P Wand. Multivariate locally weighted least squares regression. The Annals of Statistics, pages 1346–1370, 1994.
  • Sachan et al. [2021] Devendra Singh Sachan, Siva Reddy, William L. Hamilton, Chris Dyer, and Dani Yogatama. End-to-end training of multi-document reader and retriever for open-domain question answering. In A. Beygelzimer, Y. Dauphin, P. Liang, and J. Wortman Vaughan, editors, Advances in Neural Information Processing Systems, 2021. URL https://openreview.net/forum?id=5KWmB6JePx.
  • Sachan et al. [2023] Devendra Singh Sachan, Mike Lewis, Dani Yogatama, Luke Zettlemoyer, Joelle Pineau, and Manzil Zaheer. Questions are all you need to train a dense passage retriever. Transactions of the Association for Computational Linguistics, 11:600–616, 2023.
  • Schonberger et al. [2017] Johannes L Schonberger, Hans Hardmeier, Torsten Sattler, and Marc Pollefeys. Comparative evaluation of hand-crafted and learned local features. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 1482–1491, 2017.
  • Shalev-Shwartz and Ben-David [2014] Shai Shalev-Shwartz and Shai Ben-David. Understanding machine learning: From theory to algorithms. Cambridge University Press, 2014.
  • Shuster et al. [2021] Kurt Shuster, Spencer Poff, Moya Chen, Douwe Kiela, and Jason Weston. Retrieval augmentation reduces hallucination in conversation. arXiv preprint arXiv:2104.07567, 2021.
  • Siegel [2023] Jonathan W Siegel. Optimal approximation rates for deep ReLU neural networks on Sobolev and Besov spaces. Journal of Machine Learning Research, 24(357):1–52, 2023.
  • Singhal et al. [2023] Karan Singhal, Shekoofeh Azizi, Tao Tu, S Sara Mahdavi, Jason Wei, Hyung Won Chung, Nathan Scales, Ajay Tanwani, Heather Cole-Lewis, Stephen Pfohl, et al. Large language models encode clinical knowledge. Nature, pages 1–9, 2023.
  • Sparck Jones [1972] Karen Sparck Jones. A statistical interpretation of term specificity and its application in retrieval. Journal of Documentation, 28(1):11–21, 1972.
  • Stone [1977] Charles J Stone. Consistent nonparametric regression. The Annals of Statistics, pages 595–620, 1977.
  • Stone [1980] Charles J Stone. Optimal rates of convergence for nonparametric estimators. The Annals of Statistics, pages 1348–1360, 1980.
  • Sun et al. [2020] Yu Sun, Xiaolong Wang, Zhuang Liu, John Miller, Alexei Efros, and Moritz Hardt. Test-time training with self-supervision for generalization under distribution shifts. In International Conference on Machine Learning, pages 9229–9248. PMLR, 2020.
  • Sutton and Barto [2018] Richard S Sutton and Andrew G Barto. Reinforcement learning: An introduction. MIT Press, 2018.
  • Thai et al. [2023] Dung Thai, Dhruv Agarwal, Mudit Chaudhary, Rajarshi Das, Manzil Zaheer, Jay-Yoon Lee, Hannaneh Hajishirzi, and Andrew McCallum. Machine reading comprehension using case-based reasoning. arXiv preprint arXiv:2305.14815, 2023.
  • Touvron et al. [2023] Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, et al. Llama 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288, 2023.
  • Vaswani et al. [2017] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in Neural Information Processing Systems, 30, 2017.
  • Voorhees et al. [1999] Ellen M Voorhees et al. The TREC-8 question answering track report. In TREC, volume 99, pages 77–82, 1999.
  • Wang et al. [2022] Shuohang Wang, Yichong Xu, Yuwei Fang, Yang Liu, Siqi Sun, Ruochen Xu, Chenguang Zhu, and Michael Zeng. Training data is more valuable than you think: A simple and effective method by retrieving from training data. In Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 3170–3179, Dublin, Ireland, May 2022. Association for Computational Linguistics. doi: 10.18653/v1/2022.acl-long.226. URL https://aclanthology.org/2022.acl-long.226.
  • Williams [1992] Ronald J Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8:229–256, 1992.
  • Yang et al. [2021] Shiqi Yang, Joost van de Weijer, Luis Herranz, Shangling Jui, et al. Exploiting the intrinsic neighborhood structure for source-free domain adaptation. Advances in Neural Information Processing Systems, 34:29393–29405, 2021.
  • Yarotsky [2017] Dmitry Yarotsky. Error bounds for approximations with deep ReLU networks. Neural Networks, 94:103–114, 2017.
  • Zemlyanskiy et al. [2022] Yury Zemlyanskiy, Michiel de Jong, Joshua Ainslie, Panupong Pasupat, Peter Shaw, Linlu Qiu, Sumit Sanghai, and Fei Sha. Generate-and-retrieve: Use your predictions to improve retrieval for semantic parsing. In Proceedings of the 29th International Conference on Computational Linguistics, pages 4946–4951, Gyeongju, Republic of Korea, October 2022. International Committee on Computational Linguistics. URL https://aclanthology.org/2022.coling-1.438.
  • Zhang et al. [2006] Hao Zhang, Alexander C Berg, Michael Maire, and Jitendra Malik. SVM-KNN: Discriminative nearest neighbor classification for visual category recognition. In 2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR’06), volume 2, pages 2126–2136. IEEE, 2006.
  • Zhang [2023] Tong Zhang. Mathematical analysis of machine learning algorithms. Cambridge University Press, 2023.
  • Ziegler et al. [2019] Daniel M Ziegler, Nisan Stiennon, Jeffrey Wu, Tom B Brown, Alec Radford, Dario Amodei, Paul Christiano, and Geoffrey Irving. Fine-tuning language models from human preferences. arXiv preprint arXiv:1909.08593, 2019.

Appendix A Preliminaries

Definition A.1 (Rademacher complexity).

Given a sample $\mathscr{S}_n = \{(x_i, y_i)\}_{i \in [n]} \subset \mathscr{X} \times \mathscr{Y}$ and a real-valued function class $\mathscr{F}: \mathscr{X} \times \mathscr{Y} \to \mathbb{R}$, the empirical Rademacher complexity of $\mathscr{F}$ with respect to $\mathscr{S}_n$ is defined as

$$\mathfrak{R}_{\mathscr{S}_n}(\mathscr{F}) = \frac{1}{n}\, \mathbb{E}_{\bm{\sigma}}\left[\sup_{f \in \mathscr{F}} \sum_{i=1}^{n} \sigma_i f(x_i, y_i)\right], \tag{24}$$

where $\bm{\sigma} = \{\sigma_i\}_{i \in [n]}$ is a collection of $n$ i.i.d. Rademacher random variables, i.e., each $\sigma_i$ is uniformly distributed on $\{-1, +1\}$. For $n \in \mathbb{N}$, the Rademacher complexity $\bar{\mathfrak{R}}_n(\mathscr{F})$ and the worst-case Rademacher complexity $\mathfrak{R}_n(\mathscr{F})$ are defined as follows.

$$\bar{\mathfrak{R}}_n(\mathscr{F}) = \mathbb{E}_{\mathscr{S}_n \sim \mathsf{D}^n}\left[\mathfrak{R}_{\mathscr{S}_n}(\mathscr{F})\right], \quad \text{and} \quad \mathfrak{R}_n(\mathscr{F}) = \sup_{\mathscr{S}_n \in (\mathscr{X} \times \mathscr{Y})^n} \mathfrak{R}_{\mathscr{S}_n}(\mathscr{F}). \tag{25}$$
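For intuition about definition (24), the empirical Rademacher complexity of a finite function class can be computed exactly on a small sample by enumerating all $2^n$ sign vectors $\bm{\sigma}$. In the sketch below, which is illustrative and not part of the paper, the class is represented only through its value vectors $f(\mathscr{S}_n) = (f(x_1, y_1), \ldots, f(x_n, y_n))$, and the name `empirical_rademacher` is hypothetical. The exponential enumeration makes it practical only for small $n$; a Monte Carlo average over random $\bm{\sigma}$ would be used otherwise.

```python
import itertools


def empirical_rademacher(values):
    """Exact empirical Rademacher complexity (24) of a finite function class.

    `values` is a list of vectors, one per f in the class, each holding
    (f(x_1, y_1), ..., f(x_n, y_n)) on the fixed sample S_n. The expectation
    over sigma is computed by enumerating all 2^n sign vectors.
    """
    n = len(values[0])
    total = 0.0
    for sigma in itertools.product((-1.0, 1.0), repeat=n):
        # sup over the finite class of the correlation sum_i sigma_i f(x_i, y_i)
        total += max(sum(s * v for s, v in zip(sigma, f_vals)) for f_vals in values)
    return total / (2 ** n) / n


# Usage: every sign pattern on n = 3 points, so the class can match any sigma.
all_signs = [list(s) for s in itertools.product((-1.0, 1.0), repeat=3)]
print(empirical_rademacher(all_signs))  # 1.0
```

A single fixed function has complexity zero, since $\mathbb{E}[\sigma_i] = 0$, while the class of all sign patterns attains the maximum value 1 for $\{-1, +1\}$-valued functions. For general finite classes, Massart's lemma gives $\mathfrak{R}_{\mathscr{S}_n}(\mathscr{F}) \le \max_{f \in \mathscr{F}} \|f(\mathscr{S}_n)\|_2 \sqrt{2 \log |\mathscr{F}|}/n$ (see, e.g., Shalev-Shwartz and Ben-David, 2014).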
Definition A.2 (Covering number).

Let $\epsilon > 0$ and let $\|\cdot\|$ be a norm defined over $\mathbb{R}^n$. Given a function class $\mathscr{F}: \mathscr{X} \times \mathscr{Y} \to \mathbb{R}$ and a collection of points $\mathscr{S}_n = \{(x_i, y_i)\}_{i \in [n]} \subset \mathscr{X} \times \mathscr{Y}$, we call a set of points $\{u_j\}_{j \in [m]} \subset \mathbb{R}^n$ an $(\epsilon, \|\cdot\|)$-cover of $\mathscr{F}$ with respect to $\mathscr{S}_n$ if

$$\sup_{f\in\mathscr{F}} \min_{j\in[m]} \|f(\mathscr{S}_n) - u_j\| \leq \epsilon, \tag{26}$$

where $f(\mathscr{S}_n) = \big(f(x_1, y_1), \ldots, f(x_n, y_n)\big) \in \mathbb{R}^n$. The $\|\cdot\|$-covering number $\mathcal{N}(\mathscr{F}, \epsilon, \|\cdot\|; \mathscr{S}_n)$ denotes the cardinality of the minimal $(\epsilon, \|\cdot\|)$-cover of $\mathscr{F}$ with respect to $\mathscr{S}_n$. In particular, if $\|\cdot\|$ is an $\ell_p$ norm, i.e., $\|v\| = \big(\sum_{j=1}^{d} |v_j|^p\big)^{1/p}$ for $v \in \mathbb{R}^d$, then we simply write $\mathcal{N}(\mathscr{F}, \epsilon, \|\cdot\|_{L_p}; \mathscr{S}_n)$ for the corresponding $\ell_p$-covering number.

When $\mathscr{S}_n$ is unambiguous we may drop it, i.e., we use $\mathcal{N}(\mathscr{F}, \epsilon, \|\cdot\|_{L_p})$ to represent the covering number.
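As an illustration of Definition A.2, the sketch below works on a finite class represented by its value vectors $f(\mathscr{S}_n) \in \mathbb{R}^n$. It greedily picks centres from those vectors until every vector lies within $\epsilon$ of a centre. The result is a valid $(\epsilon, \ell_p)$-cover, so its size is an upper bound on $\mathcal{N}(\mathscr{F}, \epsilon, \|\cdot\|_{L_p}; \mathscr{S}_n)$, not the minimal cover itself. The helper names are hypothetical, and the norm is the plain (unnormalised) $\ell_p$ norm used in the definition.

```python
def lp_norm(v, p):
    """Plain (unnormalised) l_p norm of a vector, as in Definition A.2."""
    if p == float("inf"):
        return max(abs(x) for x in v)
    return sum(abs(x) ** p for x in v) ** (1.0 / p)

def greedy_cover(vectors, eps, p=2):
    """Greedily choose centres among `vectors` (the rows f(S_n)) until every
    vector is within eps of some centre. The output is a valid (eps, l_p)-cover,
    so len(output) upper-bounds the covering number N(F, eps, ||.||_{L_p}; S_n)."""
    centres = []
    uncovered = list(vectors)
    while uncovered:
        c = uncovered[0]           # c covers itself, so the loop always makes progress
        centres.append(c)
        uncovered = [v for v in uncovered
                     if lp_norm([a - b for a, b in zip(v, c)], p) > eps]
    return centres
```

For instance, the four vectors `[0, 0]`, `[0.1, 0]`, `[1, 1]`, `[1, 1.05]` are covered at $\epsilon = 0.2$ by two centres.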

Definition A.3 (Multi-layer perceptron (MLP)).

For both the retriever and the predictor, we consider the class of multi-layer perceptrons (MLPs), also known as fully connected deep neural networks, with the ReLU nonlinearity $\sigma(x) = \max(x, 0)$. An MLP is specified by its number of layers $L$ and its width $W$. For a weight matrix $\mathbf{W} \in \mathbb{R}^{d_2 \times d_1}$ and a bias $b \in \mathbb{R}^{d_2}$, we define the affine transform $A_{\mathbf{W},b}(\mathbb{R}^{d_1}, \mathbb{R}^{d_2}): x \mapsto \mathbf{W}x + b$. Let $\sigma \circ A_{\mathbf{W},b}(\mathbb{R}^{d_1}, \mathbb{R}^{d_2})$ denote the elementwise application of the ReLU nonlinearity to the output of the affine transform. The class of MLPs with $L$ layers and width $W$ is defined as

$${\rm MLP}(\mathbb{R}^d, \mathbb{R}^k; W, L) = \{A_{\mathbf{W}_L, b_L} \circ \sigma \circ A_{\mathbf{W}_{L-1}, b_{L-1}} \circ \dots \circ \sigma \circ A_{\mathbf{W}_0, b_0}\}, \tag{27}$$

where $\mathbf{W}_L \in \mathbb{R}^{k\times W}$ and $b_L \in \mathbb{R}^k$; $\mathbf{W}_i \in \mathbb{R}^{W\times W}$ and $b_i \in \mathbb{R}^W$ for $1 \leq i \leq L-1$; and $\mathbf{W}_0 \in \mathbb{R}^{W\times d}$ and $b_0 \in \mathbb{R}^W$.
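To make the shapes in (27) concrete, here is a minimal pure-Python sketch of a forward pass through ${\rm MLP}(\mathbb{R}^d, \mathbb{R}^k; W, L)$: $L+1$ affine maps with a ReLU after each of the first $L$. The Gaussian initialisation and the helper names (`init_mlp`, `mlp_forward`, `num_params`) are hypothetical and only serve to instantiate a member of the class; `num_params` counts weights and biases, the quantity that appears as $W_{total,L}$ in the VC bounds below.

```python
import random

def relu(v):
    return [max(x, 0.0) for x in v]

def affine(Wm, b, x):
    # Computes W x + b for a weight matrix stored as a list of rows.
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(Wm, b)]

def init_mlp(d, k, width, depth, seed=0):
    """Hypothetical random member of MLP(R^d, R^k; W, L) with the shapes of
    Definition A.3: W_0 is W x d, W_1..W_{L-1} are W x W, and W_L is k x W."""
    rng = random.Random(seed)
    shapes = [(width, d)] + [(width, width)] * (depth - 1) + [(k, width)]
    return [([[rng.gauss(0, 1) for _ in range(cols)] for _ in range(rows)],
             [rng.gauss(0, 1) for _ in range(rows)])
            for rows, cols in shapes]

def mlp_forward(params, x):
    """A_{W_L,b_L} o sigma o A_{W_{L-1},b_{L-1}} o ... o sigma o A_{W_0,b_0} (x)."""
    *hidden, (W_last, b_last) = params
    for Wm, b in hidden:
        x = relu(affine(Wm, b, x))
    return affine(W_last, b_last, x)  # no ReLU after the output layer

def num_params(params):
    """Total number of weights and biases."""
    return sum(len(Wm) * len(Wm[0]) + len(b) for Wm, b in params)
```

With $d = 3$, $k = 2$, $W = 4$, $L = 3$ there are $4$ affine maps and $(12+4) + 2(16+4) + (8+2) = 66$ parameters.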

Definition A.4 (Sobolev space).

For $p \geq 1$, we denote by $L_p(\Omega)$ the set of functions with finite $L_p$ norm over $\Omega$, i.e., for any $f \in L_p(\Omega)$, $\|f\|_{L_p(\Omega)} \triangleq \big(\int_{s\in\Omega} |f(s)|^p \, ds\big)^{1/p} < \infty$. Note that for $p = \infty$, we have $\|f\|_{L_\infty(\Omega)} = \mathrm{ess}\sup_{s\in\Omega} |f(s)|$. Let $\alpha \in \mathbb{N}^d$ denote a multi-index and $|\alpha| = \sum_{i\in[d]} \alpha_i$ its degree. For any function, we denote by $D^\alpha$ the weak derivative with respect to the multi-index $\alpha$.

For an integer $\kappa > 0$, the Sobolev semi-norm $W^\kappa(L_p(\Omega))$ of a function $f$ that has weak derivatives of order $\kappa$ is defined as

$$\forall\, 1 \leq p < \infty, \quad |f|_{W^\kappa(L_p(\Omega))} \triangleq \Big(\sum_{\alpha: |\alpha| = \kappa} \|D^\alpha f\|_{L_p(\Omega)}^p\Big)^{1/p} \quad \text{and} \quad |f|_{W^\kappa(L_\infty(\Omega))} \triangleq \max_{\alpha: |\alpha| = \kappa} \|D^\alpha f\|_{L_\infty(\Omega)}.$$

The Sobolev norm $W^\kappa(L_p(\Omega))$ of the same function $f$ is defined as $\|f\|_{W^\kappa(L_p(\Omega))} = \|f\|_{L_p(\Omega)} + |f|_{W^\kappa(L_p(\Omega))}$. A function $f$ that has all weak derivatives of order $\kappa$ and a finite $W^\kappa(L_p(\Omega))$ norm lies in the Sobolev space with $\kappa$ derivatives and $L_p(\Omega)$ norm.
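As a numerical sanity check of these definitions, consider the simplest case $d = 1$, $\kappa = 1$, $\Omega = [-1, 1]$. The only multi-index with $|\alpha| = 1$ is $\alpha = (1)$, so $|f|_{W^1(L_p)} = \|f'\|_{L_p}$. The sketch below assumes $f$ is smooth (so the weak derivative is the classical one) and approximates both integrals with the trapezoid rule and the derivative with central differences; the function names are hypothetical.

```python
def trapezoid(values, h):
    """Composite trapezoid rule on a uniform grid with spacing h."""
    return h * (sum(values) - 0.5 * (values[0] + values[-1]))

def sobolev_norm_1d(f, p=2, a=-1.0, b=1.0, n=20001):
    """Approximate ||f||_{W^1(L_p([a, b]))} = ||f||_{L_p} + ||f'||_{L_p}
    for a smooth function f of one variable (d = 1, kappa = 1).
    The derivative uses central differences, so this is only a sanity check."""
    h = (b - a) / (n - 1)
    xs = [a + i * h for i in range(n)]
    fx = [f(x) for x in xs]
    dfx = [(f(x + 1e-6) - f(x - 1e-6)) / 2e-6 for x in xs]
    lp = trapezoid([abs(v) ** p for v in fx], h) ** (1 / p)
    semi = trapezoid([abs(v) ** p for v in dfx], h) ** (1 / p)
    return lp + semi
```

For $f(s) = s$ and $p = 2$, the exact value is $\sqrt{2/3} + \sqrt{2} \approx 2.2307$, which the sketch reproduces to within about $10^{-6}$.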

For our approximation guarantees for the MLP retriever and predictor classes, we use [Siegel, 2023, Theorem 1], which we restate here for completeness.

Theorem A.5 (Restated Siegel [2023] Theorem 1).

Let $f_0: \Omega \to \mathbb{R}$ be a function in the Sobolev space with $\kappa$ derivatives and $L_q(\Omega)$ norm, for $q \in [1, \infty)$ and $\kappa \in (0, \infty)$. For $\Omega = [-1,1]^d$ and any $p \in [1, \infty)$ satisfying $(1/q - 1/p) \leq \kappa/d$, we have, for a constant $C = c(\kappa, d) < \infty$ and width $W = 25d + 31$,

$$\inf_{f \in {\rm MLP}(\mathbb{R}^d, \mathbb{R}; W, L)} \|f - f_0\|_{L_p(\Omega)} \leq C \|f_0\|_{W^\kappa(L_q(\Omega))} \, L^{-\frac{2\kappa}{d}}.$$
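Rearranging the rate in Theorem A.5 gives a sufficient depth for a target accuracy: $C\|f_0\|_{W^\kappa(L_q)} L^{-2\kappa/d} \leq \epsilon$ holds as soon as $L \geq \big(C\|f_0\|_{W^\kappa(L_q)}/\epsilon\big)^{d/(2\kappa)}$. The constant $C = c(\kappa, d)$ is not given explicitly by the theorem, so the sketch below treats it as an input with a placeholder default of $1$ (an assumption); only the scaling in $\epsilon$, $d$ and $\kappa$ is meaningful. The helper names are hypothetical.

```python
import math

def siegel_width(d):
    """Width W = 25 d + 31 used in Theorem A.5."""
    return 25 * d + 31

def depth_for_error(eps, d, kappa, sobolev_norm, C=1.0):
    """Smallest integer L with C * sobolev_norm * L^(-2 kappa / d) <= eps.

    C = c(kappa, d) is unspecified in the theorem; C = 1.0 is a placeholder,
    so absolute values are illustrative and only the scaling is informative."""
    L = math.ceil((C * sobolev_norm / eps) ** (d / (2 * kappa)))
    return max(L, 1)
```

For example, with $d = 2$ and $\kappa = 1$ the required depth grows like $1/\epsilon$: halving $\epsilon$ from $0.25$ to $0.125$ doubles the depth from $4$ to $8$, while the width stays fixed at $25d + 31$.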

Our generalization bounds leverage the VC-dimension bounds for MLPs from Bartlett et al. [2019]. We state some of their results here for completeness.

Definition A.6 (VC dimension and growth of a binary function class).

For a class $\mathcal{H}$ of functions from $\mathcal{A}$ to $\{0,1\}$, the growth function of $\mathcal{H}$ evaluated on an input set of size $m$ is defined as

$$\Pi_{\mathcal{H}}(m) = \max_{a_1, \dots, a_m \in \mathcal{A}} \big|\{(h(a_1), \dots, h(a_m)) : h \in \mathcal{H}\}\big|.$$

The VC dimension ${\rm VCdim}(\mathcal{H})$ is defined as the largest $m$ such that $\Pi_{\mathcal{H}}(m) = 2^m$; if no such largest $m$ exists, we set ${\rm VCdim}(\mathcal{H}) = \infty$.
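Both quantities can be computed by brute force when the class and the candidate inputs are finite. The sketch below (hypothetical helper names) evaluates $\Pi_{\mathcal{H}}(m)$ by enumerating $m$-subsets of a finite pool of points, and returns the largest $m$ with $\Pi_{\mathcal{H}}(m) = 2^m$. Because the search is restricted to the pool and to the listed classifiers, the result is a lower bound on the true VC dimension. The loop may stop at the first failure because a class that shatters $m+1$ points also shatters any $m$ of them.

```python
from itertools import combinations

def growth_function(classifiers, points_pool, m):
    """Brute-force Pi_H(m): the maximum, over m-subsets of `points_pool`, of the
    number of distinct labelings (h(a_1), ..., h(a_m)) realised by `classifiers`."""
    return max(len({tuple(h(a) for a in pts) for h in classifiers})
               for pts in combinations(points_pool, m))

def vc_dimension(classifiers, points_pool):
    """Largest m with Pi_H(m) = 2^m, restricted to the given finite pool."""
    vc = 0
    for m in range(1, len(points_pool) + 1):
        if growth_function(classifiers, points_pool, m) == 2 ** m:
            vc = m
        else:
            break  # shattering m points is necessary for shattering m + 1
    return vc
```

For the threshold class $h_t(a) = \mathbb{1}(a \geq t)$ on the points $\{0, 1, 2, 3\}$, the growth function is $\Pi(m) = m + 1$ and the VC dimension is $1$:

```python
H = [lambda a, t=t: int(a >= t) for t in (-0.5, 0.5, 1.5, 2.5, 3.5)]
growth_function(H, [0, 1, 2, 3], 3)   # 4
vc_dimension(H, [0, 1, 2, 3])         # 1
```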

Definition A.7 (Pseudo dimension of real valued function class).

Let $\mathcal{F}$ be a class of functions from some space $\mathcal{A}$ to the reals $\mathbb{R}$. The pseudo-dimension of $\mathcal{F}$, denoted by $Pdim(\mathcal{F})$, is the largest $m$ for which there exist $(a_1, \dots, a_m, r_1, \dots, r_m) \in \mathcal{A}^m \times \mathbb{R}^m$ such that for any binary sequence $(b_1, \dots, b_m) \in \{0,1\}^m$ there exists a function $f \in \mathcal{F}$ satisfying $\forall i: f(a_i) > r_i \iff b_i = 1$.

Note that the pseudo-dimension coincides with the VC dimension of the subgraph of the class $\mathcal{F}$, which is the notion used in Zhang [2023]. Let $sgn(x) = \mathbb{1}(x \geq 0)$. For a function $f: \mathcal{A} \to \mathbb{R}$, we denote by $sgn(f)$ the binary-valued function $a \mapsto sgn(f(a))$. We define $sgn(\mathcal{F}) \triangleq \{sgn(f) : f \in \mathcal{F}\}$, and the VC dimension of the real-valued function class $\mathcal{F}$ as ${\rm VCdim}(\mathcal{F}) \triangleq {\rm VCdim}(sgn(\mathcal{F}))$. As noted in Bartlett et al. [2019], for neural networks with a fixed architecture and fixed activation functions, such as the class ${\rm MLP}$, we have ${\rm VCdim}(sgn({\rm MLP})) = Pdim({\rm MLP})$.
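Definition A.7 can be checked mechanically for a given choice of points $a_1, \dots, a_m$ and witnesses $r_1, \dots, r_m$: collect the patterns $(\mathbb{1}(f(a_i) > r_i))_i$ realised by the class and test whether all $2^m$ patterns appear. The sketch below (hypothetical name `pseudo_shatters`) does this for a finite list of functions. It uses the strict inequality of Definition A.7, whereas $sgn$ above uses $x \geq 0$; the two conventions differ only on the boundary.

```python
from itertools import product

def pseudo_shatters(functions, points, witnesses):
    """Check Definition A.7 for fixed a_1..a_m and witnesses r_1..r_m: every binary
    pattern b must be realised by some f with f(a_i) > r_i  <=>  b_i = 1."""
    realised = {tuple(int(f(a) > r) for a, r in zip(points, witnesses))
                for f in functions}
    return all(b in realised for b in product((0, 1), repeat=len(points)))
```

For example, four affine functions $f(a) = wa + c$ suffice to pseudo-shatter the points $(0, 1)$ with witnesses $(0, 0)$, while two of them do not:

```python
affine = [lambda a, w=w, c=c: w * a + c for w, c in [(0, 1), (-2, 1), (2, -1), (0, -1)]]
pseudo_shatters(affine, [0, 1], [0, 0])      # True
pseudo_shatters(affine[:2], [0, 1], [0, 0])  # False
```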

We now adapt [Bartlett et al., 2019, Theorem 6] to the class ${\rm MLP}(\mathbb{R}^d, \mathbb{R}; L, W)$, which employs the ReLU nonlinearity. In the terminology of Bartlett et al. [2019], this amounts to setting the number of breakpoints to $pnt = 1$ and the degree of the polynomial pieces to $deg = 1$.¹

¹Bartlett et al. [2019] denote the degree by $d$ and the number of breakpoints by $p$; we use $deg$ and $pnt$, respectively, to avoid confusion. These notations are used for the rest of the paper.

Theorem A.8 (Adaptation of Bartlett et al. [2019] Theorem 6).

Consider the neural network class ${\rm MLP}(\mathbb{R}^d, \mathbb{R}; L, W)$ with the ReLU nonlinearity. Let $W_{total,l}$ denote the total number of parameters (weights and biases) up to layer $l \leq (L-1)$, and let $k_l$ denote the number of nonlinear units (output width) in layer $l$. Also define $\bar{L} = \frac{1}{W_{total,L}} \sum_{l=1}^{L} W_{total,l} \leq L$ and $R = \sum_{l=1}^{L} l k_l \leq L^2 W$. Then, for the class $\mathcal{F}$ of all real-valued functions computed by the MLP class and any number of points $m$, we have

$$\Pi_{sgn(\mathcal{F})}(m) \leq \prod_{l=1}^{L} 2\left(\frac{2emk_l l}{W_{total,l}}\right)^{W_{total,l}} \leq (4emL)^{W_{total,L}}.$$

Moreover, we have

$${\rm VCdim}(\mathcal{F}) \leq L + \bar{L} W_{total,L} \log_2\Big(4e \sum_l l k_l \log_2\Big(\sum_l 2e l k_l\Big)\Big) = O\big(\bar{L} W_{total,L} \log(L^2 W)\big).$$
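The quantities in Theorem A.8 are simple functions of the layer widths. The sketch below computes $W_{total,l}$, $k_l$, $\bar{L}$, $R = \sum_l l k_l$ and the bound $L + \bar{L} W_{total,L} \log_2(4eR\log_2(2eR))$ for a fully connected ReLU network. It assumes that the layers are indexed $l = 1, \dots, L$ with the output layer counted as layer $L$, and that `widths = [d, k_1, ..., k_L]`; how the paper's depth parameter maps onto this count is an assumption of the sketch, as are the helper names.

```python
import math

def bartlett_quantities(widths):
    """widths = [d, k_1, ..., k_L]: input dimension followed by the number of
    units in each of the L layers (the last entry is the output layer).
    Returns the cumulative parameter counts W_total,l, the k_l, L_bar and R."""
    # Layer l has k_l * (k_{l-1} + 1) parameters (weights plus biases).
    per_layer = [k * (k_prev + 1) for k_prev, k in zip(widths, widths[1:])]
    w_total = [sum(per_layer[:l + 1]) for l in range(len(per_layer))]
    ks = widths[1:]
    L_bar = sum(w_total) / w_total[-1]
    R = sum(l * k for l, k in enumerate(ks, start=1))
    return w_total, ks, L_bar, R

def vc_upper_bound(widths):
    """ReLU case (pnt = 1, deg = 1) of the bound in Theorem A.8:
    L + L_bar * W_total,L * log2(4 e R * log2(2 e R))."""
    w_total, ks, L_bar, R = bartlett_quantities(widths)
    L = len(ks)
    return L + L_bar * w_total[-1] * math.log2(4 * math.e * R * math.log2(2 * math.e * R))
```

For `widths = [3, 4, 4, 1]` this gives $W_{total,l} = (16, 36, 41)$, $\bar{L} = 93/41 \leq L = 3$ and $R = 15 \leq L^2 W = 36$, in line with the inequalities stated in the theorem.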

We generalize the above result to MLPs with multi-dimensional outputs, as used by our predictor.

Theorem A.9 (Multi-output version of Bartlett et al. [2019] Theorem 6).

Consider the neural network class ${\rm MLP}(\mathbb{R}^d, \mathbb{R}^k; L, W)$ with the ReLU nonlinearity, and let $W_{total,l}$, $k_l$, $\bar{L}$, and $R$ be as defined in Theorem A.8. We denote by $\mathcal{F}$ the class of functions $f: \mathbb{R}^d \times [k] \to \mathbb{R}$, where $f(\cdot, j)$ is the $j$-th output coordinate of a neural network in the class ${\rm MLP}(\mathbb{R}^d, \mathbb{R}^k; L, W)$. Then, we have

$${\rm VCdim}(\mathcal{F}) \leq L + \bar{L} W_{total,L} \log_2\Big(4e \sum_l l k_l \log_2\Big(\sum_l 2e l k_l\Big)\Big) = O\big(\bar{L} W_{total,L} \log(L^2 W)\big).$$
Proof.

Let $a \in \mathbb{R}^{W_{total,L}}$ parameterize a function $f \in \mathcal{F}$. Based on the discussion above, we need to find the VC dimension of the set $\{sgn(f(x_i, j, a)) : a \in \mathbb{R}^{W_{total,L}}, j \in [k], i \in [m]\}$. Note that here $f: \mathbb{R}^{W_{total,L}} \times [m] \times [k] \to \mathbb{R}$ is a function mapping the tuple $(x_i, j, a)$ to a real number.

We obtain the following inequality.

$$\begin{aligned}
&\big|\{sgn(f(x_i, j, a)) : a \in \mathbb{R}^{W_{total,L}}, i \in [m], j \in [k]\}\big| \\
&\qquad \leq \sum_{j\in[k]} \big|\{sgn(f(x_i, j, a)) : a \in \mathbb{R}^{W_{total,L}}, i \in [m]\}\big| \\
&\qquad \leq \sum_{j\in[k]} \Pi_{sgn({\rm MLP}(\mathbb{R}^d, \mathbb{R}; L, W))}(m) \\
&\qquad \leq k\, 2^L \left(2eRm / W_{total,L}\right)^{W_{total,L}}.
\end{aligned}$$

In the first inequality, we partition the set with respect to $j \in [k]$. For the second inequality, we note that for a fixed $j$ the function $f(x_i, j, a)$ is computed by ${\rm MLP}(\mathbb{R}^d, \mathbb{R}; L, W)$, and we bound the number of sign patterns by the growth function $\Pi_{sgn({\rm MLP}(\mathbb{R}^d, \mathbb{R}; L, W))}$ over $m$ points. For the third inequality, we apply the bound on $\Pi_{sgn({\rm MLP}(\mathbb{R}^d, \mathbb{R}; L, W))}(m)$ derived inside the proof of Theorem 6 in Bartlett et al. [2019]. Note that here we have specialized to the ReLU nonlinearity, i.e., the number of breakpoints is $pnt = 1$ and the degree is $deg = 1$. Applying Lemma 6 in Bartlett et al. [2019], we obtain

$${\rm VCdim}(\mathcal{F}) \leq L\log(k) + W_{total,L} \log_2\big(4eR \log_2(4eR)\big) = O\big(L\log(k) + L^2W^2\log(LW)\big).$$
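The step from the growth-function chain to a VC bound rests on one observation: shattering $m$ points requires $2^m$ sign patterns, so ${\rm VCdim}(\mathcal{F})$ is at most the largest $m$ with $k\,2^L(2eRm/W_{total,L})^{W_{total,L}} \geq 2^m$. The sketch below finds that $m$ numerically in log space instead of using the closed form of Lemma 6. It relies on $\log_2(\text{bound}) - m$ being concave in $m$ with its peak at $m = W_{total,L}/\ln 2$, so the scan can stop at the first failure past the peak. The helper names are hypothetical.

```python
import math

def log2_growth_bound(m, k, L, R, w_total):
    """log2 of k * 2^L * (2 e R m / W_total,L)^W_total,L, the last line of the chain."""
    return math.log2(k) + L + w_total * math.log2(2 * math.e * R * m / w_total)

def vc_from_growth(k, L, R, w_total):
    """Largest m with log2_growth_bound(m) >= m; VCdim(F) cannot exceed it,
    since shattering m points needs 2^m distinct sign patterns."""
    peak = w_total / math.log(2)  # maximiser of the concave gap bound(m) - m
    best, m = 0, 1
    while m <= peak or log2_growth_bound(m, k, L, R, w_total) >= m:
        if log2_growth_bound(m, k, L, R, w_total) >= m:
            best = m
        m += 1
    return best
```

Increasing the number of outputs $k$ only adds $\log_2 k$ to the left-hand side, so the resulting bound grows slowly with $k$, which is the behaviour captured by the $\log(k)$ term above.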

Finally, we state a truncated version of Gibbs' inequality, which lower-bounds the cross-entropy of two discrete probability distributions.

Proposition A.10 (Truncated Gibbs' inequality).

Let $\alpha,\beta$ be two discrete distributions over an alphabet of size $K$. Then, for any constant $C>0$, we have

$$
\sum_{i=1}^{K}\alpha_i\min\big(C,-\log(\beta_i)\big) \ge \sum_{i=1}^{K}\alpha_i\min\big(C,-\log(\alpha_i)\big) - (K-1)\exp(-C).
$$
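As a quick sanity check of Proposition A.10, the following Python/NumPy sketch evaluates both sides on randomly drawn distributions and reports the smallest observed slack (left side minus right side). The sampling scheme (Dirichlet draws, the range of $C$) is an illustrative choice of ours, not part of the analysis; a check of this kind can only fail to find counterexamples and does not replace the proof.

```python
import numpy as np

def truncated_cross_entropy(alpha, beta, C):
    # sum_i alpha_i * min(C, -log beta_i); a zero beta_i gives -log 0 = +inf, truncated to C
    with np.errstate(divide="ignore"):
        neg_log = -np.log(beta)
    return float(np.sum(alpha * np.minimum(C, neg_log)))

def gibbs_slack(alpha, beta, C):
    # Left side minus right side of Proposition A.10; should be >= 0 up to float error
    K = len(alpha)
    lhs = truncated_cross_entropy(alpha, beta, C)
    rhs = truncated_cross_entropy(alpha, alpha, C) - (K - 1) * np.exp(-C)
    return lhs - rhs

rng = np.random.default_rng(0)
worst = np.inf
for _ in range(2000):
    K = int(rng.integers(2, 8))
    alpha = rng.dirichlet(np.ones(K))
    beta = rng.dirichlet(0.2 * np.ones(K))  # concentrated beta exercises the truncation at C
    C = float(rng.uniform(0.1, 5.0))
    worst = min(worst, gibbs_slack(alpha, beta, C))
print(f"smallest slack over random trials: {worst:.3e}")
```

With $\beta=\alpha$ the slack equals exactly $(K-1)\exp(-C)$, which shows that the correction term is what the truncation costs.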
Proof.

For two discrete distributions $\alpha,\beta$ over an alphabet of size $K$, we have

$$
\begin{aligned}
\sum_{i=1}^{K}\alpha_i\min\big(C,-\log(\beta_i)\big)
&= -\sum_{i=1}^{K}\alpha_i\log\big(\max(\exp(-C),\beta_i)\big)\\
&= -\sum_{i=1}^{K}\alpha_i\log(\alpha_i) + \sum_{i=1}^{K}\alpha_i\log\big(\alpha_i/\max(\exp(-C),\beta_i)\big)\\
&\ge -\sum_{i=1}^{K}\alpha_i\log(\alpha_i) + \Big(\sum_{i=1}^{K}\alpha_i\Big)\log\Big(\sum_{i=1}^{K}\alpha_i \Big/ \sum_{i=1}^{K}\max(\exp(-C),\beta_i)\Big)\\
&\ge -\sum_{i=1}^{K}\alpha_i\log(\alpha_i) - \log\big(1+(K-1)\exp(-C)\big)\\
&\ge -\sum_{i=1}^{K}\alpha_i\log(\alpha_i) - (K-1)\exp(-C)\\
&\ge \sum_{i=1}^{K}\alpha_i\min\big(C,-\log(\alpha_i)\big) - (K-1)\exp(-C).
\end{aligned}
$$

The first inequality follows from the log-sum inequality. The second inequality uses $\sum_{i=1}^{K}\alpha_i=1$ together with $\sum_{i=1}^{K}\max(\exp(-C),\beta_i)\le 1+(K-1)\exp(-C)$: this sum is a convex function of $\beta$ over the probability simplex, so it is maximized at a vertex, i.e., by setting $\beta_i=1$ for some $1\le i\le K$ and the remaining entries to $0$. The second-to-last inequality follows from $\log(1+x)\le x$. The final inequality holds because taking a minimum with $C$ can only decrease the value. ∎
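The key bound in the second inequality, $\sum_{i}\max(\exp(-C),\beta_i)\le 1+(K-1)\exp(-C)$ over the probability simplex, can be illustrated numerically. The NumPy sketch below uses arbitrary illustrative values of $K$ and $C$: it evaluates the sum at a vertex, where the bound is attained, and at random interior points, which never exceed it.

```python
import numpy as np

def clipped_mass(beta, C):
    # sum_i max(exp(-C), beta_i): the denominator inside the log-sum step
    return float(np.sum(np.maximum(np.exp(-C), beta)))

K, C = 5, 1.5
bound = 1 + (K - 1) * np.exp(-C)

# A vertex of the simplex (a single beta_i = 1) attains the bound.
vertex = np.eye(K)[0]
print(clipped_mass(vertex, C), bound)

# Random interior points of the simplex stay below it (the map is convex in beta).
rng = np.random.default_rng(1)
samples = rng.dirichlet(np.ones(K), size=10_000)
max_interior = max(clipped_mass(b, C) for b in samples)
print(max_interior <= bound + 1e-12)
```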

Appendix B Derivations of main result

As discussed in Section 2, our objective here is to study the excess risk in Eq. (12), which has three main components: the generalization error, the retriever approximation error, and the predictor approximation error (cf. (3.1)). In this section, we structure our results somewhat differently from the main body in order to capture the general settings of learning the retriever with a fixed predictor, and vice versa. We first prove excess risk bounds for learning the retriever, and then excess risk bounds for learning the predictor. Finally, we combine these results to obtain the guarantees for jointly learning the retriever and the predictor presented in the paper. For the rest of the analysis, we need to specify the space of retrieved examples in order to define the complexity of the gap function (cf. 3.1). Recall that our retrieved samples are embedded in a compact subset of $\mathbb{R}^{d_z}$, and $\mathscr{X}$ is a compact subset of $\mathbb{R}^{d_x}$. In particular, for simplicity, we assume that $\mathscr{X}\subseteq[-1,1]^{d_x}$ and $\mathscr{Z}\subseteq[-1,1]^{d_z}$.

B.1 Learning the retriever

We first study learning the retriever over the class $\Theta$ when the predictor $\xi$ is fixed. Learning the retriever corresponds to minimizing the following objective over $\theta\in\Theta$:

$$
\mathbb{E}_{(X,Y)\sim\mathscr{D}}\big[\mathbb{E}_{Z\sim p_\theta(\cdot|X)}\ell(h_\xi(X,Z),Y)\big] = \mathbb{E}_X\big[\mathbb{E}_{Z\sim p_\theta(\cdot|X)}\mathbb{E}_{Y|X}[\ell(h_\xi(X,Z),Y)\mid X]\big] = \mathbb{E}_X\big[\mathbb{E}_{Z\sim p_\theta(\cdot|X)}g_\xi(X,Z)\big],
$$

where $g_\xi(X,Z)=\mathbb{E}_{Y|X}\,\ell(h_\xi(X,Z),Y)$. When the retriever is not restricted to a function class, the optimal retriever has a closed form: $p^{\ast,\xi}(z|x)=\mathbbm{1}_{\operatorname*{arg\,min}_{z'\in\mathscr{I}}g_\xi(x,z')}(z)$, with ties broken arbitrarily.
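For a finite candidate set $\mathscr{I}$ and a finite sample of inputs, the unrestricted optimal retriever is simply a row-wise argmin. The NumPy sketch below assumes the conditional losses $g_\xi(x_i,z)$ are already available as a matrix (the numbers are hypothetical) and checks that the one-hot retriever attains $\frac{1}{n}\sum_i\min_z g_\xi(x_i,z)$, the empirical analogue of $\mathbb{E}_X[\min_z g_\xi(X,z)]$.

```python
import numpy as np

def optimal_retriever(g):
    # g[i, z] = g_xi(x_i, z) for n inputs and |I| candidates.
    # Returns one-hot rows p*(z | x_i) = 1{z = argmin_z' g[i, z']}.
    # np.argmin takes the first minimizer, which is one admissible tie-breaking rule.
    n, m = g.shape
    p = np.zeros((n, m))
    p[np.arange(n), np.argmin(g, axis=1)] = 1.0
    return p

def retriever_risk(p, g):
    # Average over inputs of E_{Z ~ p(.|x_i)} g(x_i, Z)
    return float(np.mean(np.sum(p * g, axis=1)))

# Hypothetical conditional losses for 3 inputs and 4 candidates (row 3 has a tie)
g = np.array([[0.9, 0.2, 0.5, 0.7],
              [0.3, 0.3, 0.8, 0.1],
              [0.6, 0.4, 0.4, 0.9]])
p_star = optimal_retriever(g)
print(retriever_risk(p_star, g), np.mean(g.min(axis=1)))
```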

For the fixed predictor $\xi$, let $\hat{\theta}(\xi)$ minimize the empirical risk and $\theta(\xi)$ minimize the population risk over the class $\Theta$, i.e.,

$$
\begin{aligned}
\hat{\theta}(\xi) &= \operatorname*{arg\,min}_{\theta\in\Theta}\frac{1}{n}\sum_{i\in[n]}\sum_{z\in\mathscr{I}}p_\theta(z|x_i)\,\ell\big(h_\xi(x_i,z),y_i\big),\\
\theta(\xi) &= \operatorname*{arg\,min}_{\theta\in\Theta}\mathbb{E}_X\big[\mathbb{E}_{Z\sim p_\theta(\cdot|X)}g_\xi(X,Z)\big].
\end{aligned}
$$

Here, for a given $\theta\in\Theta$, the probability is defined using the softmax operator as follows:

$$
p_{\theta,\mathscr{I}}(z|x) = \frac{\exp\big(r_\theta(x,z)\big)}{\sum_{z'\in\mathscr{I}}\exp\big(r_\theta(x,z')\big)},\quad\forall\,z\in\mathscr{I},\ x\in\mathscr{X}.
$$
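To make the softmax retriever and the empirical objective defining $\hat{\theta}(\xi)$ concrete, here is a minimal NumPy sketch. It assumes a hypothetical bilinear score $r_\theta(x,z)=x^\top\theta z$ purely for illustration (the analysis allows a general $r_\theta$), and it uses random placeholder values for the losses $\ell(h_\xi(x_i,z),y_i)$ of a fixed predictor.

```python
import numpy as np

def softmax_retriever(scores):
    # scores[i, z] = r_theta(x_i, z); returns p_{theta,I}(z | x_i) row-wise.
    # Subtracting the row max leaves the softmax unchanged but avoids overflow.
    shifted = scores - scores.max(axis=1, keepdims=True)
    w = np.exp(shifted)
    return w / w.sum(axis=1, keepdims=True)

def bilinear_scores(theta, X, Z):
    # Hypothetical retriever score r_theta(x, z) = x^T theta z (illustration only)
    return X @ theta @ Z.T

def empirical_retriever_risk(theta, X, Z, losses):
    # (1/n) sum_i sum_z p_theta(z | x_i) * losses[i, z],
    # where losses[i, z] stands for loss(h_xi(x_i, z), y_i) with xi fixed.
    p = softmax_retriever(bilinear_scores(theta, X, Z))
    return float(np.mean(np.sum(p * losses, axis=1)))

rng = np.random.default_rng(0)
n, m, dx, dz = 8, 5, 3, 4
X = rng.uniform(-1, 1, size=(n, dx))    # inputs in [-1, 1]^{d_x}
Z = rng.uniform(-1, 1, size=(m, dz))    # toy retrieval corpus I in [-1, 1]^{d_z}
losses = rng.uniform(0, 1, size=(n, m))  # placeholder losses of the fixed predictor
theta = rng.normal(size=(dx, dz))
print(empirical_retriever_risk(theta, X, Z, losses))
```

Minimizing this quantity over $\theta$ (e.g., by gradient descent) would give an approximation of $\hat{\theta}(\xi)$ for this toy setup.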
Hardness of retrieval:

Recall the Sobolev space with $\kappa$ derivatives as defined in Section A. The following restates Assumption 3.1 for any $\xi\in\Xi$, not just the optimal predictor $\xi^{*}$.

Assumption B.1 (Complexity of $g_\xi$).

For any $\xi\in\Xi$, there exists a baseline $b_\xi:[-1,1]^{d_x}\to\mathbb{R}$ such that the function $\mathrm{gap}_\xi:[-1,1]^{d_x+d_z}\to\mathbb{R}$ with baseline $b_\xi$, defined by $\mathrm{gap}_\xi(x,z)=g_\xi(x,z)-b_\xi(x)$, lies in the Sobolev space with $\kappa$ derivatives and $L_\infty([-1,1]^{d_x+d_z})$ norm.

As noted in the main text, this means that the predictor loss has a possibly ‘complex’ component $b_\xi(x)$ and a relatively ‘smooth’ component $\mathrm{gap}_\xi(x,z)$, which ensures that, for any $x\in\mathscr{X}$, two retrieved examples that are close lead to similar losses for the predictor $\xi$. Since $b_\xi(x)$ does not depend on $z$, we have $\operatorname*{arg\,min}_{z}g_\xi(x,z)=\operatorname*{arg\,min}_{z}\mathrm{gap}_\xi(x,z)$; thus $\mathrm{gap}_\xi$ alone determines the optimal retrieved set, and its smoothness defines the hardness of the underlying retrieval task.
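The fact that the gap function alone determines the optimal retrieved set can be seen on toy numbers: adding an arbitrary, even wildly varying, $z$-independent baseline to every candidate's loss leaves the row-wise argmin unchanged. All values in this NumPy sketch are random placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 6, 10
gap = rng.uniform(0, 1, size=(n, m))      # toy smooth part gap_xi(x_i, z)
b = rng.normal(scale=100.0, size=(n, 1))  # toy baseline b_xi(x_i), one value per input
g = gap + b                               # g_xi(x, z) = gap_xi(x, z) + b_xi(x)

# b does not depend on z, so the optimal candidate for each input is unchanged.
same_argmin = np.array_equal(np.argmin(g, axis=1), np.argmin(gap, axis=1))
print(same_argmin)
```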

Excess risk decomposition:

With the predictor $\xi$ fixed, the excess risk in (12) is bounded as follows:

$$
\begin{aligned}
&R_{\ell,\mathscr{I}}(\xi,\hat{\theta}(\xi)) - R_{\ell,\mathscr{I}}(f^{\ell}_{{\rm opt},\mathscr{I}})\\
&\qquad\le \underbrace{\sum_{\theta\in\{\theta(\xi),\hat{\theta}(\xi)\}}\Big|\frac{1}{n}\sum_{i\in[n]}\sum_{z\in\mathscr{I}}p_\theta(z|x_i)\,\ell\big(h_\xi(x_i,z),y_i\big) - \mathbb{E}_X\big[\mathbb{E}_{Z\sim p_\theta(\cdot|X)}g_\xi(X,Z)\big]\Big|}_{\text{retriever generalization error}}\\
&\qquad\quad + \underbrace{R_{\ell,\mathscr{I}}(\xi,\theta(\xi)) - \mathbb{E}_X\big[\min_{z\in\mathscr{I}}g_\xi(X,z)\big]}_{\text{retriever approximation error}} + \underbrace{\mathbb{E}_X\big[\min_{z\in\mathscr{I}}g_\xi(X,z)\big] - R_{\ell,\mathscr{I}}(f^{\ell}_{{\rm opt},\mathscr{I}})}_{\text{error from predictor }\xi}.
\end{aligned}
$$

The inequality uses the fact that $\hat{\theta}(\xi)$ minimizes the empirical risk, so its empirical risk is at most that of $\theta(\xi)$.

B.1.1 Generalization error

We now bound the generalization error using Rademacher complexity. For any $\delta\in(0,1)$, with probability at least $1-\delta$,

$$
\begin{aligned}
&\Big|\mathbb{E}_X\big[\mathbb{E}_{Z\sim p_{\hat{\theta}(\xi)}(\cdot|X)}g_\xi(X,Z)\big] - \frac{1}{n}\sum_{i\in[n]}\sum_{z\in\mathscr{I}}p_{\hat{\theta}(\xi)}(z|x_i)\,\ell\big(h_\xi(x_i,z),y_i\big)\Big|\\
&\qquad\le 2\,\mathbb{E}_{\bm{\sigma}}\Big[\max_{\theta\in\Theta}\frac{1}{n}\sum_{i\in[n]}\sigma_i\sum_{z\in\mathscr{I}}p_\theta(z|x_i)\,\ell\big(h_\xi(x_i,z),y_i\big)\Big] + 3\ell_{\max}\sqrt{\tfrac{\log(2/\delta)}{n}}\\
&\qquad\le 2\times\inf_{\varepsilon\in[0,c_\xi/2]}\Big(4\varepsilon + \tfrac{12}{\sqrt{n}}\int_{\varepsilon}^{c_\xi/2}\sqrt{\log\mathcal{N}(\Theta,\nu,\|\cdot\|_{2,[n],\xi})}\,d\nu\Big) + 3\ell_{\max}\sqrt{\tfrac{\log(2/\delta)}{n}}.
\end{aligned}
\tag{28}
$$

The final inequality follows from a chaining argument that bounds the empirical Rademacher complexity via covering numbers (Dudley's entropy integral), where

$$
c_\xi = \sup_{\theta\in\Theta}\Big(\frac{1}{n}\sum_{i\in[n]}\Big(\sum_{z\in\mathscr{I}}p_\theta(z|x_i)\,\ell\big(h_\xi(x_i,z),y_i\big)\Big)^2\Big)^{1/2},
$$

and $\mathcal{N}(\Theta,\nu,\|\cdot\|_{2,[n],\xi})$ denotes the covering number of the retriever class $\Theta$ at scale $\nu$ with respect to the following empirical $L_2$ norm, defined on the sample $\{(x_i,y_i):i\in[n]\}$ with $\xi$ fixed:

$$
\|\mathbf{u}\|_{2,[n],\xi} = \Big(\frac{1}{n}\sum_{i\in[n]}\Big(\sum_{z\in\mathscr{I}}u_{i,z}\,\ell\big(h_\xi(x_i,z),y_i\big)\Big)^2\Big)^{1/2},\quad\forall\,\mathbf{u}\in\mathbb{R}^{n\times|\mathscr{I}|}.
$$
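The empirical norm $\|\cdot\|_{2,[n],\xi}$ is a seminorm on $\mathbb{R}^{n\times|\mathscr{I}|}$ that weights each retrieval distribution by the fixed predictor's losses on the sample. The NumPy sketch below, with random placeholder losses and two toy retrievers, shows how the distance $\|p_\theta-p_{\theta'}\|_{2,[n],\xi}$ used when covering $\Theta$ would be computed.

```python
import numpy as np

def retriever_norm(u, losses):
    # ||u||_{2,[n],xi} = ( (1/n) sum_i ( sum_z u[i, z] * losses[i, z] )^2 )^{1/2},
    # where losses[i, z] stands for loss(h_xi(x_i, z), y_i) with xi fixed.
    inner = np.sum(u * losses, axis=1)
    return float(np.sqrt(np.mean(inner ** 2)))

def softmax_rows(s):
    # Row-wise softmax with max-subtraction for numerical stability
    w = np.exp(s - s.max(axis=1, keepdims=True))
    return w / w.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
n, m = 20, 6
losses = rng.uniform(0, 1, size=(n, m))
p1 = softmax_rows(rng.normal(size=(n, m)))  # outputs of two toy retrievers on the sample
p2 = softmax_rows(rng.normal(size=(n, m)))

# Distance between two retrievers, as used when covering Theta at scale nu
print(retriever_norm(p1 - p2, losses))
# c_xi is the sup of ||p_theta|| over Theta; here only two members are evaluated
print(retriever_norm(p1, losses), retriever_norm(p2, losses))
```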

The generalization error in retriever learning thus depends on the covering number of $\Theta$, which, as we shall see, depends on the embedding space of the retrieved examples.
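To see how the right-hand side of (28) behaves as $n$ grows, the following NumPy sketch evaluates it numerically. The covering-number bound $\log\mathcal{N}(\Theta,\nu,\|\cdot\|_{2,[n],\xi})\le D\log(3c_\xi/\nu)$ is a hypothetical parametric-style assumption chosen only for illustration, as are $D$, $c_\xi=\ell_{\max}=1$ and $\delta$; the grid search over $\varepsilon$ and the trapezoid rule give an approximation, not the exact infimum.

```python
import numpy as np

def dudley_term(log_cover, c, n, num_eps=200, num_grid=2000):
    # Grid approximation of
    #   inf_{eps in [0, c/2]} ( 4*eps + (12/sqrt(n)) * int_eps^{c/2} sqrt(log N(nu)) dnu ),
    # where log_cover(nu) is an assumed upper bound on log N(Theta, nu, ||.||_{2,[n],xi}).
    best = np.inf
    for eps in np.linspace(1e-6, c / 2, num_eps):
        nu = np.linspace(eps, c / 2, num_grid)
        f = np.sqrt(np.maximum(log_cover(nu), 0.0))
        integral = float(np.sum((f[1:] + f[:-1]) * np.diff(nu)) / 2.0)  # trapezoid rule
        best = min(best, 4.0 * eps + 12.0 / np.sqrt(n) * integral)
    return best

def generalization_bound(log_cover, c, n, l_max, delta):
    # Right-hand side of (28): 2 * (chaining term) + 3 * l_max * sqrt(log(2/delta) / n)
    return 2.0 * dudley_term(log_cover, c, n) + 3.0 * l_max * np.sqrt(np.log(2.0 / delta) / n)

# Hypothetical values: D "effective parameters", c_xi = l_max = 1, delta = 0.05
D, c, l_max, delta = 10, 1.0, 1.0, 0.05

def log_cover(nu):
    # Hypothetical covering bound log N(nu) = D * log(3c / nu)
    return D * np.log(3.0 * c / nu)

bounds = [generalization_bound(log_cover, c, n, l_max, delta) for n in (10**4, 10**5, 10**6)]
print(bounds)
```

Under this assumed covering bound the chaining term shrinks roughly like $\sqrt{D/n}$, so a richer retriever class (larger $D$) needs proportionally more samples.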

Since $\theta(\xi)$ is a fixed retriever that does not depend on the sample, we do not need a union bound over the retriever class. Therefore, we have

$$
\Big|\mathbb{E}_X\big[\mathbb{E}_{Z\sim p_{\theta(\xi)}(\cdot|X)}g_\xi(X,Z)\big] - \frac{1}{n}\sum_{i\in[n]}\sum_{z\in\mathscr{I}}p_{\theta(\xi)}(z|x_i)\,\ell\big(h_\xi(x_i,z),y_i\big)\Big| \le 3\ell_{\max}\sqrt{\tfrac{\log(2/\delta)}{n}}.
$$

B.1.2 Approximation error

The approximation error for learning the retriever depends on the hardness of the function $\min_{z\in\mathscr{I}}g_\xi(X,z)$. Recall that this term is approximated using a softmax over $r_\theta(X,Z)$ (cf. (6)).

We want to approximate $\min_{z\in\mathscr{I}}g_\xi(x,z)$ for all $x\in\mathscr{X}$ by $\sum_{z\in\mathscr{I}}p_{\theta,\mathscr{I}}(z|x)\,g_\xi(x,z)$. We break this approximation into two parts. First, we show that the average of $g_\xi(x,\cdot)$ under $\mathrm{softmax}(-\tau\times g_\xi(x,\cdot))$ approximates $\min_z g_\xi(x,z)$ for large $\tau$. In particular, if $\tau=O(\log(|\mathscr{I}|)/\delta)$, then the softmax approximates the minimum with error $\delta$ (see McSherry and Talwar [2007], Epasto et al. [2020]). Second, we show that $p_{\theta,\mathscr{I}}(z|x)$ can approximate $\mathrm{softmax}(-\tau\times g_\xi(x,z))$ well in $L_2$ norm.
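The two-part split amounts to adding and subtracting $\sum_z\tilde{p}_\xi(z|x)\,g_\xi(x,z)$, where $\tilde{p}_\xi(\cdot|x)\propto\exp(-\tau g_\xi(x,\cdot))$ is the tempered softmax. The NumPy sketch below checks this identity on random toy values; the retriever $p_\theta$ and the losses are placeholders, and $\tau$ is arbitrary.

```python
import numpy as np

def row_softmax(s):
    # Row-wise softmax with max-subtraction for numerical stability
    w = np.exp(s - s.max(axis=1, keepdims=True))
    return w / w.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
n, m, tau = 50, 8, 20.0
g = rng.uniform(0, 1, size=(n, m))              # toy g_xi(x_i, z) on n sampled inputs
p_theta = row_softmax(rng.normal(size=(n, m)))  # placeholder retriever p_{theta,I}(z | x_i)
p_tilde = row_softmax(-tau * g)                 # tempered softmax, proportional to exp(-tau * g)

gap_to_min = np.mean(np.sum(p_theta * g, axis=1) - g.min(axis=1))
fit_term = np.mean(np.sum((p_theta - p_tilde) * g, axis=1))       # p_theta vs. tilde p
temp_term = np.mean(np.sum(p_tilde * g, axis=1) - g.min(axis=1))  # tilde p vs. the minimum
print(gap_to_min, fit_term + temp_term)
```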

We define

$$
\tilde{p}_\xi(z|x) = \frac{\exp(-\tau g_\xi(x,z))}{\sum_{z'}\exp(-\tau g_\xi(x,z'))} = \frac{\exp\big(-\tau(g_\xi(x,z)-b_\xi(x))\big)}{\sum_{z'}\exp\big(-\tau(g_\xi(x,z')-b_\xi(x))\big)}.
$$

Recall that $b_{\xi}(x)$ is the baseline function in Assumption 3.1. An example of such a baseline is the loss under the optimal retrieved sample for each $x\in\mathscr{X}$, i.e., $b_{\xi}(x) = \min_{\tilde{z}} g_{\xi}(x,\tilde{z})$.
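As a numerical sanity check of the first part, the sketch below computes $\tilde{p}_{\xi}(\cdot|x)$ for a single hypothetical $x$ with made-up losses $g_{\xi}(x,z)$ (random values in $[0,1]$, an assumption for illustration only), confirms that subtracting the baseline $b_{\xi}(x)$ leaves the distribution unchanged, and compares the softmax-weighted average with the minimum. The function names are illustrative. A standard entropy argument ($\tilde{p}_{\xi}$ minimizes $\mathbb{E}_{p}[g_{\xi}] - H(p)/\tau$ over distributions $p$, and $H(p)\leq\log|\mathscr{I}|$) bounds the reported error by $\log(|\mathscr{I}|)/\tau$, which matches the $\tau = O(\log(|\mathscr{I}|)/\delta)$ scaling quoted above.

```python
import math
import random

def tilde_p(g, tau, baseline=0.0):
    """p~(z|x) proportional to exp(-tau * (g(x, z) - b(x))) over the corpus."""
    logits = [-tau * (gz - baseline) for gz in g]
    m = max(logits)  # subtract the largest logit to avoid overflow/underflow
    w = [math.exp(l - m) for l in logits]
    total = sum(w)
    return [wi / total for wi in w]

def soft_min(g, tau):
    """Softmax-weighted average sum_z p~(z|x) g(x, z), a smooth surrogate for min_z g(x, z)."""
    return sum(p * gz for p, gz in zip(tilde_p(g, tau), g))

random.seed(0)
g = [random.uniform(0.0, 1.0) for _ in range(1000)]  # hypothetical g_xi(x, z), |I| = 1000

# The baseline b(x) = min_z g(x, z) cancels in the ratio, as in the definition of p~.
p_plain = tilde_p(g, 5.0)
p_shift = tilde_p(g, 5.0, baseline=min(g))
print(max(abs(a - b) for a, b in zip(p_plain, p_shift)))

for tau in (1.0, 10.0, 100.0, 1000.0):
    err = soft_min(g, tau) - min(g)
    print(f"tau={tau:7.1f}  error={err:.5f}  log|I|/tau={math.log(len(g)) / tau:.5f}")
```

The error shrinks as $\tau$ grows, and it stays below the $\log(|\mathscr{I}|)/\tau$ column in every row.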

For any $\theta\in\Theta$, we have

$$
\begin{aligned}
R_{\ell,\mathscr{I}}(\xi,\theta(\xi)) - \mathbb{E}_X\big[\min_{z\in\mathscr{I}} g_{\xi}(X,z)\big]
&\overset{(i)}{\leq} R_{\ell,\mathscr{I}}(\xi,\theta) - \mathbb{E}_X\big[\min_{z\in\mathscr{I}} g_{\xi}(X,z)\big]\\
&\overset{(ii)}{=} \mathbb{E}_X\Big[\sum_{z\in\mathscr{I}} \big(p_{\theta,\mathscr{I}}(z|X) - \tilde{p}_{\xi}(z|X)\big)\, g_{\xi}(X,z)\Big] + \mathbb{E}_X\Big[\sum_{z\in\mathscr{I}} \tilde{p}_{\xi}(z|X)\, g_{\xi}(X,z) - \min_{z\in\mathscr{I}} g_{\xi}(X,z)\Big]\\
&\overset{(iii)}{\leq} \mathbb{E}_X\big[\|g_{\xi}(X,\cdot)\|_{\infty}\, \|p_{\theta,\mathscr{I}}(\cdot|X) - \tilde{p}_{\xi}(\cdot|X)\|_{1}\big] + \frac{\log(|\mathscr{I}|)}{\tau^{2}}\\
&\overset{(iv)}{\leq} \mathbb{E}_X\big[\|g_{\xi}(X,\cdot)\|_{\infty}\, \|r_{\theta}(X,\cdot) + \tau\,\mathrm{gap}_{\xi}(X,\cdot)\|_{\infty}\big] + \frac{\log(|\mathscr{I}|)}{\tau^{2}}\\
&\overset{(v)}{\leq} \ell_{\max}\, \|r_{\theta} + \tau\,\mathrm{gap}_{\xi}\|_{\infty} + \frac{\log(|\mathscr{I}|)}{\tau^{2}}
\end{aligned}
$$

In inequality $(i)$, we replace $\theta(\xi)$, the optimal retriever for predictor $\xi$, with an arbitrary retriever $\theta$. Equality $(ii)$ adds and subtracts $\mathbb{E}_X\big[\sum_{z}\tilde{p}_{\xi}(z|X)\, g_{\xi}(X,z)\big]$. The first term in inequality $(iii)$ uses Hölder's inequality for inner products, while the second term follows from Theorem 3.1 in [Epasto et al., 2020] (which originates from [McSherry and Talwar, 2007]). Inequality $(iv)$ uses the fact that the softmax function over $K$ classes satisfies $\|\mathrm{softmax}(x)-\mathrm{softmax}(y)\|_{1}\leq\|x-y\|_{\infty}$ (see henrikl [https://math.stackexchange.com/users/351007/henrikl]), applied to the logits $r_{\theta}(X,\cdot)$ and $-\tau\,\mathrm{gap}_{\xi}(X,\cdot)$. In the final inequality $(v)$, we use $\ell_{\max}$ to bound the norm of $g_{\xi}$.
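The softmax Lipschitz property used in step $(iv)$ can also be checked numerically. A short argument: the Jacobian of softmax at logits $v$ maps a direction $u$ to $\big(p_i(u_i - \sum_j p_j u_j)\big)_i$, whose $\ell_1$ norm is the mean absolute deviation of $u$ under $p$, at most $(\max_i u_i - \min_i u_i)/2 \leq \|u\|_{\infty}$; integrating along the segment from $y$ to $x$ gives the bound. The sketch below (plain Python, with random Gaussian logits chosen only for illustration) records the largest observed ratio $\|\mathrm{softmax}(x)-\mathrm{softmax}(y)\|_1/\|x-y\|_{\infty}$, which should never exceed 1.

```python
import math
import random

def softmax(v):
    m = max(v)  # shift by the max logit for numerical stability
    w = [math.exp(a - m) for a in v]
    s = sum(w)
    return [a / s for a in w]

def l1_dist(p, q):
    return sum(abs(a - b) for a, b in zip(p, q))

def linf_dist(u, v):
    return max(abs(a - b) for a, b in zip(u, v))

random.seed(1)
worst_ratio = 0.0
for _ in range(2000):
    k = random.randint(2, 20)  # number of classes K
    x = [random.gauss(0.0, 3.0) for _ in range(k)]
    y = [a + random.gauss(0.0, 1.0) for a in x]
    ratio = l1_dist(softmax(x), softmax(y)) / linf_dist(x, y)
    worst_ratio = max(worst_ratio, ratio)
print(f"largest observed ||softmax(x)-softmax(y)||_1 / ||x-y||_inf = {worst_ratio:.4f}")
```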

Since the above bound holds for any $\tau>0$ and any $\theta\in\Theta$, optimizing over $\tau$ and $\theta$ gives

$$R_{\ell,\mathscr{I}}(\xi,\theta(\xi)) - \mathbb{E}_X\big[\min_{z\in\mathscr{I}} g_{\xi}(X,z)\big] \leq \inf_{\theta\in\Theta}\inf_{\tau>0}\; \ell_{\max}\|r_{\theta}+\tau\,\mathrm{gap}_{\xi}\|_{\infty} + \frac{\log(|\mathscr{I}|)}{\tau^{2}}. \tag{29}$$

Since the right-hand side of inequality $(v)$ holds for any $\theta\in\Theta$, if there exists a $\theta\in\Theta$ such that $r_{\theta}(x,z)$ approximates $-\tau\,\mathrm{gap}_{\xi}(x,z)$ well, the resulting approximation error is small.

B.1.3 Instantiation of MLP retriever

We take $\Theta$ to be the class of MLPs defined in Equation (27). Since MLPs of appropriate depth and width have universal approximation properties, this choice of $\Theta$ ensures that $r_{\theta}(x,z)$ can approximate $-\tau\,\mathrm{gap}_{\xi}(x,z)$ well. To bound the excess risk of learning the retriever, we need both generalization error and approximation error bounds for the MLP class.
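To make the retriever class concrete, here is a minimal sketch under stated assumptions. Equation (27) is not reproduced in this section, so we assume a fully connected ReLU network applied to the concatenation $[x; z]\in\mathbb{R}^{d_x+d_z}$ with a scalar output $r_{\theta}(x,z)$, and $p_{\theta}(z|x)=\mathrm{softmax}_z\, r_{\theta}(x,z)$ over a finite corpus, which is how the softmax bounds in this appendix treat the retriever. The helper names, the random initialization, and the toy corpus are all hypothetical. The `scale_output` helper shows that multiplying the last layer by a constant $c$ multiplies $r_{\theta}$ by $c$, so the class is closed under rescalings such as $c=-1/\tau$.

```python
import math
import random

def init_mlp(d_in, width, depth, seed=0):
    """Hypothetical ReLU MLP: `depth` hidden layers of size `width`, scalar output."""
    rng = random.Random(seed)
    sizes = [d_in] + [width] * depth + [1]
    layers = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        W = [[rng.gauss(0.0, 1.0 / math.sqrt(fan_in)) for _ in range(fan_in)] for _ in range(fan_out)]
        b = [0.0] * fan_out
        layers.append((W, b))
    return layers

def mlp_score(layers, x, z):
    """r_theta(x, z): the MLP applied to the concatenation [x; z]."""
    h = list(x) + list(z)
    for i, (W, b) in enumerate(layers):
        h = [sum(w * a for w, a in zip(row, h)) + bi for row, bi in zip(W, b)]
        if i < len(layers) - 1:
            h = [max(0.0, a) for a in h]  # ReLU on hidden layers only
    return h[0]

def retriever_probs(layers, x, corpus):
    """p_theta(z | x) = softmax over the corpus of r_theta(x, z)."""
    r = [mlp_score(layers, x, z) for z in corpus]
    m = max(r)
    w = [math.exp(a - m) for a in r]
    s = sum(w)
    return [a / s for a in w]

def scale_output(layers, c):
    """Copy of the network whose output is c * r_theta(x, z) (only the last layer changes)."""
    *hidden, (W_out, b_out) = layers
    return hidden + [([[c * w for w in row] for row in W_out], [c * b for b in b_out])]

rng = random.Random(1)
d_x, d_z = 3, 2
corpus = [[rng.uniform(-1, 1) for _ in range(d_z)] for _ in range(5)]  # toy corpus in [-1, 1]^{d_z}
x = [rng.uniform(-1, 1) for _ in range(d_x)]
theta = init_mlp(d_x + d_z, width=25 * (d_x + d_z) + 31, depth=2)
probs = retriever_probs(theta, x, corpus)
print([round(p, 3) for p in probs])
```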

Generalization error for MLP retriever:

To bound the generalization error, we first need to bound the covering number $\mathcal{N}(\Theta,\nu,\|\cdot\|_{2,[n],\xi})$ for $\Theta=\mathrm{MLP}(\mathbb{R}^{d_x+d_z},\mathbb{R};W,L)$. Here, $\mathscr{X}\subseteq\mathbb{R}^{d_x}$ and $\mathscr{I}\subseteq\mathbb{R}^{d_z}$, i.e., the retrieval space is embedded in $\mathbb{R}^{d_z}$. We start by bounding $\mathcal{N}(\Theta,\nu,\|\cdot\|_{2,[n],\xi})$ by a covering number of $\mathrm{MLP}(\mathbb{R}^{d_x+d_z},\mathbb{R};W,L)$.

For a fixed data set $\mathcal{S}_n := \{(x_1,y_1),\dots,(x_n,y_n)\}$, predictor $\xi$, and two retrievers $\theta,\theta'\in\Theta$, we have

$$
\begin{aligned}
&\Big(\tfrac{1}{n}\sum_{i\in[n]}\Big(\sum_{z\in\mathscr{I}}\big(p_{\theta}(z|x_i)-p_{\theta'}(z|x_i)\big)\,\ell\big(h_{\xi}(x_i,z),y_i\big)\Big)^{2}\Big)^{1/2}\\
&\qquad\overset{(i)}{\leq}\ell_{\max}\Big(\tfrac{1}{n}\sum_{i\in[n]}\Big(\sum_{z\in\mathscr{I}}\big|p_{\theta}(z|x_i)-p_{\theta'}(z|x_i)\big|\Big)^{2}\Big)^{1/2}\\
&\qquad\overset{(ii)}{\leq}\ell_{\max}\Big(\tfrac{1}{n}\sum_{i\in[n]}\Big(\max_{z\in\mathscr{I}}\big|r_{\theta}(x_i,z)-r_{\theta'}(x_i,z)\big|\Big)^{2}\Big)^{1/2}\\
&\qquad\overset{(iii)}{\leq}\ell_{\max}\sup_{x\in\mathcal{S}_n,\,z\in\mathscr{I}}\big|r_{\theta}(x,z)-r_{\theta'}(x,z)\big|
\end{aligned}
$$

Above, inequality $(i)$ follows by upper bounding $\ell\big(h_{\xi}(x_i,z),y_i\big)$ with $\ell_{\max}$. Inequality $(ii)$ uses the fact that the softmax function over $K$ classes satisfies $\|\mathrm{softmax}(x)-\mathrm{softmax}(y)\|_{1}\leq\|x-y\|_{\infty}$. Inequality $(iii)$ bounds the average over $i\in[n]$ by its maximum.
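The chain $(i)$ to $(iii)$ can be checked on synthetic data. The sketch below draws hypothetical retriever logits $r_{\theta}(x_i,z)$, perturbed copies standing in for $r_{\theta'}$, and losses in $[0,\ell_{\max}]$ (all random, for illustration only; sizes and names are assumptions), then compares the left-hand side with $\ell_{\max}\sup_{x\in\mathcal{S}_n,z\in\mathscr{I}}|r_{\theta}(x,z)-r_{\theta'}(x,z)|$. This sup-norm bound is what lets a cover of the logits serve as a cover in $\|\cdot\|_{2,[n],\xi}$.

```python
import math
import random

def softmax(v):
    m = max(v)
    w = [math.exp(a - m) for a in v]
    s = sum(w)
    return [a / s for a in w]

def lhs_empirical_distance(r, r_prime, losses):
    """(1/n * sum_i (sum_z (p_theta - p_theta')(z|x_i) * loss_{i,z})^2)^{1/2}."""
    total = 0.0
    for r_i, rp_i, l_i in zip(r, r_prime, losses):
        diff = sum((p - q) * l for p, q, l in zip(softmax(r_i), softmax(rp_i), l_i))
        total += diff ** 2
    return math.sqrt(total / len(r))

def rhs_sup_bound(r, r_prime, ell_max):
    """ell_max * sup_{i, z} |r_theta(x_i, z) - r_theta'(x_i, z)|."""
    return ell_max * max(abs(a - b) for r_i, rp_i in zip(r, r_prime) for a, b in zip(r_i, rp_i))

random.seed(2)
n, corpus_size, ell_max = 50, 30, 2.0  # hypothetical sample size, corpus size, loss bound
r = [[random.gauss(0, 2) for _ in range(corpus_size)] for _ in range(n)]
losses = [[random.uniform(0, ell_max) for _ in range(corpus_size)] for _ in range(n)]
for scale in (0.01, 0.1, 1.0):
    r_prime = [[a + random.gauss(0, scale) for a in row] for row in r]
    print(scale, lhs_empirical_distance(r, r_prime, losses), rhs_sup_bound(r, r_prime, ell_max))
```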

Let us define the norm $\|\cdot\|_{\infty,n|\mathscr{I}|}$ as $\|\mathbf{u}\|_{\infty,n|\mathscr{I}|}\triangleq\sup_{x_i\in\mathcal{S}_n}\sup_{z\in\mathscr{I}}|u_{i,z}|$ for all $\mathbf{u}\in\mathbb{R}^{n\times|\mathscr{I}|}$; for a retriever $\theta$ it is applied to the matrix $\big(r_{\theta}(x_i,z)\big)_{i\in[n],z\in\mathscr{I}}$. Now consider a $\|\cdot\|_{\infty,n|\mathscr{I}|}$-norm cover $\Theta_{\rm cov}$ of $\Theta$ with cardinality $\mathcal{N}(\Theta,\nu/\ell_{\max},\|\cdot\|_{\infty,n|\mathscr{I}|})$.

Note that, by definition, for any $\theta\in\Theta$ there exists $\theta_{\rm cov}(\theta)\in\Theta_{\rm cov}$ such that $\sup_{x\in\mathcal{S}_n,z\in\mathscr{I}}|r_{\theta}(x,z)-r_{\theta_{\rm cov}(\theta)}(x,z)|\leq\nu/\ell_{\max}$. By the chain of inequalities above, this means that $\Theta_{\rm cov}$ forms a $\nu$-cover in the $\|\cdot\|_{2,[n],\xi}$ norm. In other words, $\mathcal{N}(\Theta,\nu,\|\cdot\|_{2,[n],\xi})\leq\mathcal{N}(\Theta,\nu/\ell_{\max},\|\cdot\|_{\infty,n|\mathscr{I}|})$.

Most existing covering number bounds for MLPs assume norm bounds on the weights and biases. We do not impose such norm bounds. Therefore, we use the pseudo-dimension of the class $\Theta$ from Bartlett et al. [2019] to bound the covering number $\mathcal{N}(\Theta,\nu,\|\cdot\|_{\infty,n|\mathscr{I}|})$ via Zhang [2023]. In particular, if the pseudo-dimension of $\Theta$ is $d_{VC}$, then $\log\mathcal{N}(\Theta,\nu,\|\cdot\|_{\infty,n|\mathscr{I}|})\leq 1+\log(1+d_{VC})+d_{VC}\log\big(\max\{2,\, en|\mathscr{I}|/(d_{VC}\nu)\}\big)$ [Zhang, 2023, Theorem 5.11]. From [Bartlett et al., 2019, Theorem 6], the pseudo-dimension of the class $\mathrm{MLP}(\mathbb{R}^{d},\mathbb{R};W,L)$ is $O(LN\log(M))$, where $N=O(LW^{2})$ is the number of parameters and $M=O(LW)$ is the number of computation units.
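The two quoted results can be combined into a rough calculator. The sketch evaluates the right-hand side of the Zhang [2023, Theorem 5.11] bound as written above and plugs in the pseudo-dimension order $LN\log(M)$ with $N=LW^2$ and $M=LW$. Dropping the hidden constants in the $O(\cdot)$ terms is our assumption, and the sample size, corpus size, and $\nu$ are hypothetical, so the output indicates scaling in $L$, $n$, and $\nu$ only, not actual values.

```python
import math

def log_cover_bound(d_vc, n, corpus_size, nu):
    """1 + log(1 + d_vc) + d_vc * log(max{2, e * n * |I| / (d_vc * nu)})."""
    ratio = math.e * n * corpus_size / (d_vc * nu)
    return 1.0 + math.log(1.0 + d_vc) + d_vc * math.log(max(2.0, ratio))

def pseudo_dim_order(width, depth):
    """L * N * log(M) with N = L * W^2 parameters and M = L * W units (constants dropped)."""
    num_params = depth * width ** 2
    num_units = depth * width
    return depth * num_params * math.log(num_units)

width = 25 * 4 + 31  # hypothetical d_x + d_z = 4, giving W = 131
for depth in (2, 4, 8):
    d_vc = pseudo_dim_order(width, depth)
    print(depth, round(d_vc), round(log_cover_bound(d_vc, n=10_000, corpus_size=10**6, nu=0.01)))
```

The logarithm of the covering number grows roughly linearly in $d_{VC}$ but only logarithmically in $n|\mathscr{I}|/\nu$, which is why $|\mathscr{I}|$ enters the generalization bound only through $\log(n|\mathscr{I}|)$.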
By setting $\varepsilon=c/\sqrt{n}$ for a constant $c$ and $\delta=1/n$ in Equation (28), for large enough $L$ (we will set $L$ as a function of the data size $n$), we obtain the final generalization error

$$\Big|\mathbb{E}_X\big[\mathbb{E}_{Z\sim p_{\hat{\theta}(\xi)}(\cdot|X)}\, g_{\xi}(X,Z)\big]-\frac{1}{n}\sum_{i\in[n]}\sum_{z\in\mathscr{I}}p_{\hat{\theta}(\xi)}(z|x_i)\,\ell\big(h_{\xi}(x_i,z),y_i\big)\Big| = O\Big(\frac{\ell_{\max}LW\sqrt{\log(LW)\log(n|\mathscr{I}|)}}{\sqrt{n}}\Big). \tag{30}$$
Approximation error for MLP retriever:

Our approximation error bounds closely follow the work of Siegel [2023], which generalizes Yarotsky [2017]. (Footnote 2: We note that Siegel [2023] works with $\Omega=[0,1]^{d}$ and, as mentioned therein, the analysis can be extended to bounded domains, e.g., $[a,b]^{d}$, which includes our setting. Furthermore, one can extend the analysis to non-integer Sobolev and Besov spaces following Siegel [2023].) Under Assumption B.1, by specializing [Siegel, 2023, Theorem 1] with $p=q=\infty$, we get

$$\inf_{f\in\mathrm{MLP}(\mathbb{R}^{d_x+d_z},\mathbb{R};W,L)}\|f-\mathrm{gap}_{\xi}\|_{L_{\infty}(\Omega)}\leq C\,\|\mathrm{gap}_{\xi}\|_{W^{\kappa}(L_{\infty}(\Omega))}\,L^{-2\kappa/(d_x+d_z)}$$

for $\Omega=[-1,1]^{d_x+d_z}$, $W=25(d_x+d_z)+31$, and $C=c(\kappa,d_x+d_z)<\infty$ (independent of $L$). Here $\kappa$ is the number of derivatives of the Sobolev space considered in Assumption B.1.
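The approximation rate is explicit enough to invert. The sketch below returns the width $25(d_x+d_z)+31$ and the smallest depth $L$ for which $C\|\mathrm{gap}_{\xi}\|_{W^{\kappa}(L_{\infty}(\Omega))}L^{-2\kappa/(d_x+d_z)}\leq\varepsilon$, i.e., $L\geq\big(C\|\mathrm{gap}_{\xi}\|_{W^{\kappa}(L_{\infty}(\Omega))}/\varepsilon\big)^{(d_x+d_z)/(2\kappa)}$. The constant $c(\kappa,d_x+d_z)$ is not given here, so the product $C\|\mathrm{gap}_{\xi}\|$ is folded into an assumed number `c_gap`; the values of $d_x+d_z$ and $\kappa$ are also hypothetical.

```python
import math

def siegel_width(d):
    """Width W = 25 d + 31 from the specialization of Siegel [2023, Theorem 1], d = d_x + d_z."""
    return 25 * d + 31

def approx_error_bound(c_gap, depth, kappa, d):
    """c_gap * L^{-2 kappa / d}, where c_gap stands for C * ||gap_xi||_{W^kappa(L_inf(Omega))}."""
    return c_gap * depth ** (-2.0 * kappa / d)

def depth_for_accuracy(c_gap, eps, kappa, d):
    """Smallest integer depth L with c_gap * L^{-2 kappa / d} <= eps."""
    return max(1, math.ceil((c_gap / eps) ** (d / (2.0 * kappa))))

# Hypothetical values: d_x + d_z = 4, kappa = 1 derivative, c_gap = 2.0.
d, kappa, c_gap = 4, 1, 2.0
print("width:", siegel_width(d))
for eps in (0.3, 0.1, 0.03):
    L = depth_for_accuracy(c_gap, eps, kappa, d)
    print(eps, L, approx_error_bound(c_gap, L, kappa, d))
```

The required depth grows polynomially in $1/\varepsilon$ with exponent $(d_x+d_z)/(2\kappa)$, so smoother gap functions (larger $\kappa$) need much shallower networks.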

Therefore, under Assumption B.1, for $\Theta=\mathrm{MLP}(\mathbb{R}^{d_x+d_z},\mathbb{R};25(d_x+d_z)+31,L)$ we show that

$$R_{\ell,\mathscr{I}}(\xi,\theta(\xi))-\mathbb{E}_X\big[\min_{z\in\mathscr{I}} g_{\xi}(X,z)\big]\leq C'\ell_{\max}\,\tau\,L^{-2\kappa/(d_x+d_z)}+\frac{\log(|\mathscr{I}|)}{\tau^{2}}. \tag{31}$$

This follows from the following series of inequalities:

$$
\begin{aligned}
R_{\ell,\mathscr{I}}(\xi,\theta(\xi))-\mathbb{E}_X\big[\min_{z\in\mathscr{I}} g_{\xi}(X,z)\big]
&\overset{(i)}{\leq}\mathbb{E}_X\big[\|g_{\xi}(X,\cdot)\|_{\infty}\big]\,\|r_{\theta}+\tau\,\mathrm{gap}_{\xi}\|_{L_{\infty}(\Omega)}+\frac{\log(|\mathscr{I}|)}{\tau^{2}}\\
&\overset{(ii)}{=}\tau\,\mathbb{E}_X\big[\|g_{\xi}(X,\cdot)\|_{\infty}\big]\,\|\tilde{r}_{\theta}-\mathrm{gap}_{\xi}\|_{L_{\infty}(\Omega)}+\frac{\log(|\mathscr{I}|)}{\tau^{2}}\\
&\overset{(iii)}{\leq}C\tau\,\mathbb{E}_X\big[\|g_{\xi}(X,\cdot)\|_{\infty}\big]\,\|\mathrm{gap}_{\xi}\|_{W^{\kappa}(L_{\infty}(\Omega))}\,L^{-2\kappa/(d_x+d_z)}+\frac{\log(|\mathscr{I}|)}{\tau^{2}}\\
&\overset{(iv)}{\leq}C'\ell_{\max}\,\tau\,L^{-2\kappa/(d_x+d_z)}+\frac{\log(|\mathscr{I}|)}{\tau^{2}}
\end{aligned}
$$

The first inequality (i) follows from Equation (29). The equality (ii) substitutes $\tilde{r}_{\theta}=-\tau r_{\theta}$. Inequality (iii) follows by optimizing $\tilde{r}_{\theta}$ over the class $\Theta$, which is admissible because $-\tau r_{\theta}$ also lies in $\Theta$, and then applying Theorem 1 of Siegel [2023]. The final inequality (iv) sets $C^{\prime}=C\|\mathrm{gap}_{\xi}\|_{W^{\kappa}(L_{\infty}(\Omega))}$ and uses the bound $\mathbb{E}_{X}\big[\|g_{\xi}(x,\cdot)\|_{\infty}\big]\leq\ell_{\max}$.

Since $\tau$ is not an algorithmic choice, we are free to optimize the bound over $\tau$. Balancing the two terms, we choose $\tau=cL^{2\kappa/(3(d_{x}+d_{z}))}\log^{1/3}(|\mathscr{I}|)$ to obtain the approximation error bound $O\big(\ell_{\max}L^{-4\kappa/(3(d_{x}+d_{z}))}\log^{1/3}(|\mathscr{I}|)\big)$, where we treat the remaining terms that are independent of $\tau$ and $L$ as constants.
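The optimization over $\tau$ in step (iv) is a one-dimensional trade-off of the form $A\tau + B/\tau^{2}$, with $A = C^{\prime}\ell_{\max}L^{-2\kappa/(d_x+d_z)}$ and $B=\log|\mathscr{I}|$. The following minimal sketch (function names and numeric constants are hypothetical, chosen only for illustration) computes the minimizer $\tau^{*}=(2B/A)^{1/3}$ and the minimum $3(A^{2}B/4)^{1/3}$; since $A^{2/3}\propto L^{-4\kappa/(3(d_x+d_z))}$ and $B^{1/3}=\log^{1/3}|\mathscr{I}|$, this matches the stated rate, and $\tau^{*}\propto L^{2\kappa/(3(d_x+d_z))}$ grows with $L$.

```python
import math

def tau_tradeoff(tau, A, B):
    """Right-hand side of step (iv): A * tau + B / tau**2.

    A stands for C' * l_max * L**(-2*kappa/(d_x + d_z)) and B for log|I|.
    """
    return A * tau + B / tau**2

def optimal_tau(A, B):
    # Setting the derivative A - 2*B / tau**3 to zero gives tau* = (2B / A)**(1/3).
    return (2.0 * B / A) ** (1.0 / 3.0)

def optimal_value(A, B):
    # Plugging tau* back in gives 3 * (A**2 * B / 4)**(1/3) = Theta(A**(2/3) * B**(1/3)).
    return 3.0 * (A * A * B / 4.0) ** (1.0 / 3.0)

# Hypothetical constants: C' * l_max = 1, kappa = 2, d_x + d_z = 8, |I| = 10**6.
kappa, d, B = 2.0, 8.0, math.log(10**6)
for L in (10, 100, 1000):
    A = L ** (-2 * kappa / d)
    tau_star = optimal_tau(A, B)
    print(f"L={L}: tau*={tau_star:.3f}, bound={optimal_value(A, B):.4f}")
```

Increasing $L$ tenfold shrinks the optimized bound by a factor $10^{-4\kappa/(3(d_x+d_z))}$, which is the only $L$-dependence that survives the optimization.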

Excess risk for MLP retriever learning:

Adding the approximation error (31) and the generalization error (30), we bound the excess risk as

$$\text{Excess Risk}\leq\underbrace{\mathbb{E}_{X}\big[\min_{z\in\mathscr{I}}g_{\xi}(X,z)\big]-R_{\ell,\mathscr{I}}(f_{{\rm opt},\mathscr{I}}^{\ell})}_{\text{error from predictor }\xi}+\underbrace{O\big(\ell_{\max}L^{-\frac{4\kappa}{3(d_{x}+d_{z})}}\log^{1/3}(|\mathscr{I}|)\big)}_{\text{retriever approximation error}}+\underbrace{O\Big(\frac{\ell_{\max}LW\sqrt{\log(LW)\log(n|\mathscr{I}|)}}{\sqrt{n}}\Big)}_{\text{retriever generalization error}}\tag{32}$$

By choosing $L=n^{\frac{3(d_{x}+d_{z})}{6(d_{x}+d_{z})+8\kappa}}$ and taking the data-store size $|\mathscr{I}|=\mathrm{poly}(n)$ and width $W=O(d_{x}+d_{z})$, we obtain

$$\text{Excess Risk}\leq\underbrace{\mathbb{E}_{X}\big[\min_{z\in\mathscr{I}}g_{\xi}(X,z)\big]-R_{\ell,\mathscr{I}}(f_{{\rm opt},\mathscr{I}}^{\ell})}_{\text{error from predictor }\xi}+\underbrace{\tilde{O}\big(\ell_{\max}n^{-\frac{2\kappa}{3(d_{x}+d_{z})+4\kappa}}\big)}_{\text{retriever combined error}}.\tag{33}$$
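The choice of $L$ balances the approximation term $L^{-4\kappa/(3d)}$ against the generalization term $L/\sqrt{n}$, writing $d=d_x+d_z$ and ignoring $W$, $\ell_{\max}$, and logarithmic factors (which $\tilde{O}$ hides). The exponent arithmetic can be checked exactly with rational numbers; this is a small sketch, and the function names are our own.

```python
from fractions import Fraction

def depth_exponent(d, kappa):
    """Exponent a in L = n**a that balances L**(-4*kappa/(3d)) against L / sqrt(n).

    Solving L**(1 + 4*kappa/(3d)) = n**(1/2) gives a = 3d / (6d + 8*kappa).
    """
    return Fraction(1, 2) / (1 + Fraction(4 * kappa, 3 * d))

def approx_exponent(d, kappa):
    # Approximation term L**(-4*kappa/(3d)) evaluated at L = n**a.
    return -Fraction(4 * kappa, 3 * d) * depth_exponent(d, kappa)

def rate_exponent(d, kappa):
    # Generalization term L / sqrt(n) evaluated at L = n**a, ignoring W and log factors.
    return depth_exponent(d, kappa) - Fraction(1, 2)

# Illustrative values: d = d_x + d_z = 10, kappa = 3.
print(depth_exponent(10, 3), rate_exponent(10, 3))  # prints "5/14 -1/7"
```

At the balanced depth both terms decay as $n^{-2\kappa/(3d+4\kappa)}$, which is the retriever combined error in (33).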

B.2 Learning the predictor

We now quantify the excess risk of the predictor $\xi$ for a fixed retriever $\theta$. In this setting, the learning task of the predictor is to minimize

$$\mathbb{E}_{(X,Y)\sim\mathscr{D}_{XY}}\big[\mathbb{E}_{Z\sim p_{\theta}(\cdot|X)}\,\ell(h_{\xi}(X,Z),Y)\big]=\mathbb{E}_{((X,Z),Y)\sim\mathscr{D}_{XY}\times p_{\theta}(\cdot|X)}\big[\ell(h_{\xi}(X,Z),Y)\big].$$

The predictor thus learns from the joint distribution $\mathscr{D}_{XY}\times p_{\theta}(\cdot|X)$. We allow the hardness of the classification task faced by the predictor to vary with the selected retriever $\theta$.

As in the retriever analysis of Section B.1, for a fixed retriever $\theta$ we define the empirical risk minimizer $\hat{\xi}(\theta)$ and the population risk minimizer $\xi^{\ast}(\theta)$ over the class $\Xi$ as

$$\hat{\xi}(\theta)=\operatorname*{arg\,min}_{\xi\in\Xi}\frac{1}{n}\sum_{i\in[n]}\sum_{z\in\mathscr{I}}p_{\theta}(z|x_{i})\,\ell\big(h_{\xi}(x_{i},z),y_{i}\big),$$
$$\xi^{\ast}(\theta)=\operatorname*{arg\,min}_{\xi\in\Xi}\mathbb{E}_{X}\big[\mathbb{E}_{Z\sim p_{\theta}(\cdot|X)}\,g_{\xi}(X,Z)\big],$$

where $g_{\xi}(X,Z)=\mathbb{E}_{Y|X}\,\ell(h_{\xi}(X,Z),Y)$. We also define $\xi^{\ast}$ as the predictor in the class $\Xi$ that minimizes the population risk under 'optimal' retrieval (possibly not achievable by any retriever in $\Theta$): $$\xi^{\ast}=\operatorname*{arg\,min}_{\xi\in\Xi}\mathbb{E}_{X}\big[\min_{z\in\mathscr{I}}g_{\xi}(X,z)\big].$$
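To make the three objectives concrete, the sketch below evaluates them on arrays. It assumes, for illustration only, that $p_\theta(z|x)$ is a softmax of retriever scores over the data-store, and that the values $g_\xi(x_i,z)$ (or realized losses $\ell(h_\xi(x_i,z),y_i)$) have already been computed for a sample of inputs; the function names and array layout are hypothetical.

```python
import numpy as np

def retrieval_probs(scores):
    """Assumed form of p_theta(z | x_i): softmax of scores[i, z] over the data-store axis."""
    shifted = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    w = np.exp(shifted)
    return w / w.sum(axis=1, keepdims=True)

def empirical_predictor_risk(p_theta, losses):
    """Objective of xi_hat(theta): (1/n) sum_i sum_z p_theta(z|x_i) * l(h_xi(x_i, z), y_i).

    Both arrays have shape (n, |I|).
    """
    return float((p_theta * losses).sum(axis=1).mean())

def population_objectives(p_theta, g):
    """Monte Carlo estimates over a sample of X, with g[i, z] = g_xi(x_i, z).

    Returns (E_X E_{Z~p_theta} g_xi(X, Z), E_X min_z g_xi(X, z)): the objective minimized
    by xi*(theta) and the 'optimal retrieval' objective minimized by xi*.
    """
    return float((p_theta * g).sum(axis=1).mean()), float(g.min(axis=1).mean())

rng = np.random.default_rng(0)
P = retrieval_probs(rng.normal(size=(4, 5)))  # hypothetical retriever scores
g = rng.uniform(0.0, 2.0, size=(4, 5))        # hypothetical g_xi(x_i, z) values
print(population_objectives(P, g))
```

The first returned value can never fall below the second, because averaging $g_\xi(x,\cdot)$ under $p_\theta(\cdot|x)$ cannot beat its minimum over $\mathscr{I}$.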

Usefulness of data-store:

We start by characterizing the prediction task in the presence of the data-store $\mathscr{I}$. We assume that there exists a score function $h_{*}:\mathscr{X}\times\mathscr{Z}\to\mathbb{R}^{|\mathscr{Y}|}$ with corresponding probability distribution

$$p_{*}^{y}(x,z)=\frac{\exp(h_{*}^{y}(x,z))}{\sum_{y^{\prime}}\exp(h_{*}^{y^{\prime}}(x,z))}\tag{34}$$

that closely approximates $p_{\mathsf{D}_{XY}}^{y}(x)\triangleq\mathbb{P}_{Y\sim\mathsf{D}_{XY}}(y|X=x)$ for all $x\in\mathscr{X}$ and $y\in\mathscr{Y}$. Furthermore, each coordinate of the score function $h_{*}$ lies in a Sobolev space (see Definition A.4). Assumption 3.2 formalizes these requirements; we restate it here for convenience.

Assumption B.2 (Retrieval quality).

There exists a score function $h_{*}:\mathscr{X}\times\mathscr{Z}\to\mathbb{R}^{|\mathscr{Y}|}$ such that

  1. for each $y\in\mathscr{Y}$, the function $h_{*}^{y}$ (the $y$-th coordinate of $h_{*}$) lies in the Sobolev space with $\kappa_{\mathscr{I}}$ derivatives and has finite $L_{\infty}([-1,1]^{d_{x}+d_{z}})$ norm;

  2. for any $x\in\mathscr{X}$ there exists a retrieved example $z^{*}(x)\in\mathscr{I}$ such that, for $p_{*}^{y}(x,z)$ as defined in Equation (34),

     $$\max_{y\in\mathscr{Y}}\sup_{x\in\mathscr{X}}\big|p_{*}^{y}(x,z^{*}(x))-p_{\mathsf{D}_{XY}}^{y}(x)\big|\leq c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}.$$

The tuple $(\gamma_{\mathscr{I}},d_{z},\kappa_{\mathscr{I}})$ determines the usefulness of the data-store $\mathscr{I}$. In particular, a larger $\gamma_{\mathscr{I}}$ yields a closer approximation, while a larger $\kappa_{\mathscr{I}}$ and a smaller embedding dimension $d_{z}$ make the score function used for this approximation 'easier'.
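Condition 2 of Assumption B.2 can be probed on a finite sample of inputs: for each $x$, pick the store item minimizing $\max_y|p_*^y(x,z)-p_{\mathsf{D}_{XY}}^y(x)|$, then take the worst case over $x$. Conversely, the rate $c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}$ tells how large a store suffices for a target tolerance $\epsilon$. This is a hypothetical diagnostic under the assumption that both probability tables are available as arrays; it is not part of the paper's method.

```python
import math
import numpy as np

def datastore_gap(p_star, p_data):
    """Empirical analogue of the left-hand side of Assumption B.2(2) on a finite sample of x.

    p_star: (n_x, |I|, |Y|) array of p_*^y(x, z); p_data: (n_x, |Y|) array of p_D^y(x).
    Returns (max over x of min over z of max_y |p_* - p_D|, the chosen z*(x) indices).
    """
    per_pair = np.abs(p_star - p_data[:, None, :]).max(axis=2)  # shape (n_x, |I|)
    z_star = per_pair.argmin(axis=1)
    return float(per_pair.min(axis=1).max()), z_star

def required_store_size(c, gamma, eps):
    # Smallest |I| with c * |I|**(-gamma) <= eps, i.e. |I| >= (c / eps)**(1 / gamma).
    return math.ceil((c / eps) ** (1.0 / gamma))

# Hypothetical example: 3 inputs, a 4-item store, 2 labels; item 2 reproduces p_D exactly.
p_data = np.array([[0.9, 0.1], [0.5, 0.5], [0.2, 0.8]])
p_star = np.tile(np.array([0.6, 0.4]), (3, 4, 1))
p_star[:, 2, :] = p_data
gap, z_star = datastore_gap(p_star, p_data)
print(gap, z_star, required_store_size(2.0, 1.0, 0.5))
```

A larger $\gamma_{\mathscr{I}}$ makes `required_store_size` grow more slowly as $\epsilon$ shrinks, which is one way to read "higher $\gamma_{\mathscr{I}}$, closer approximation".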

Excess risk decomposition

The excess risk decomposition for the learned predictor $\hat{\xi}(\theta)$ takes the following form.

$$
\begin{aligned}
&R_{\ell,\mathscr{I}}(\hat{\xi}(\theta),\theta)-R_{\ell,\mathscr{I}}(f_{{\rm opt},\mathscr{I}}^{\ell})\\
&\quad\overset{(i)}{\leq}\sum_{\xi\in\{\xi^{\ast}(\theta),\hat{\xi}(\theta)\}}\Big|\frac{1}{n}\sum_{i\in[n]}\sum_{z\in\mathscr{I}}p_{\theta}(z|x_{i})\,\ell\big(h_{\xi}(x_{i},z),y_{i}\big)-\mathbb{E}_{X}\big[\mathbb{E}_{Z\sim p_{\theta}(\cdot|X)}g_{\xi}(X,Z)\big]\Big|\\
&\qquad\quad+R_{\ell,\mathscr{I}}(\xi^{\ast}(\theta),\theta)-R_{\ell,\mathscr{I}}(f_{{\rm opt},\mathscr{I}}^{\ell})\\
&\quad\overset{(ii)}{\leq}\sum_{\xi\in\{\xi^{\ast}(\theta),\hat{\xi}(\theta)\}}\Big|\frac{1}{n}\sum_{i\in[n]}\sum_{z\in\mathscr{I}}p_{\theta}(z|x_{i})\,\ell\big(h_{\xi}(x_{i},z),y_{i}\big)-\mathbb{E}_{X}\big[\mathbb{E}_{Z\sim p_{\theta}(\cdot|X)}g_{\xi}(X,Z)\big]\Big|\\
&\qquad\quad+\underbrace{R_{\ell,\mathscr{I}}(\xi^{\ast}(\theta),\theta)-R_{\ell,\mathscr{I}}(\xi^{\ast},\theta)}_{\leq 0}+R_{\ell,\mathscr{I}}(\xi^{\ast},\theta)-R_{\ell,\mathscr{I}}(f_{{\rm opt},\mathscr{I}}^{\ell})\\
&\quad\overset{(iii)}{\leq}\underbrace{\sum_{\xi\in\{\xi^{\ast}(\theta),\hat{\xi}(\theta)\}}\Big|\frac{1}{n}\sum_{i\in[n]}\sum_{z\in\mathscr{I}}p_{\theta}(z|x_{i})\,\ell\big(h_{\xi}(x_{i},z),y_{i}\big)-\mathbb{E}_{X}\big[\mathbb{E}_{Z\sim p_{\theta}(\cdot|X)}g_{\xi}(X,Z)\big]\Big|}_{\text{generalization error}}\\
&\qquad\quad+\underbrace{R_{\ell,\mathscr{I}}(\xi^{\ast},\theta)-\mathbb{E}_{X}\big[\min_{z\in\mathscr{I}}g_{\xi^{\ast}}(X,z)\big]}_{\text{retriever error}}+\underbrace{\mathbb{E}_{X}\big[\min_{z\in\mathscr{I}}g_{\xi^{\ast}}(X,z)\big]-\mathbb{E}_{X}\big[\min_{z\in\mathscr{I}}g_{f_{{\rm opt},\mathscr{I}}^{\ell}}(X,z)\big]}_{\text{predictor error}}
\end{aligned}\tag{35}
$$

In (i), we use the standard empirical risk minimization argument: since $\hat{\xi}(\theta)$ minimizes the empirical risk, its empirical risk is at most that of $\xi^{\ast}(\theta)$. In (ii), we add and subtract $R_{\ell,\mathscr{I}}(\xi^{\ast},\theta)$; the underbraced term is nonpositive because the predictor $\xi^{\ast}(\theta)$, which is optimized for the fixed retriever $\theta$, has risk no larger than that of $\xi^{\ast}$, i.e., $R_{\ell,\mathscr{I}}(\xi^{\ast}(\theta),\theta)\leq R_{\ell,\mathscr{I}}(\xi^{\ast},\theta)$. In (iii), we drop this term and split the remainder into the retriever error and the predictor error.
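The two facts used in steps (ii) and (iii) can be illustrated on a toy finite predictor class. In the sketch below, the values $g_{\xi_k}(x_i,z)$ and the retriever distribution are randomly generated stand-ins (hypothetical, not from the paper), and sample averages replace $\mathbb{E}_X$; the comparator $f_{{\rm opt},\mathscr{I}}^{\ell}$ is not modeled.

```python
import numpy as np

def decomposition_terms(G, P):
    """Toy version of terms in (35) for a finite predictor class.

    G[k, i, z] = g_{xi_k}(x_i, z) for predictor k, sampled input x_i, and store item z.
    P[i, z] = p_theta(z | x_i) for the fixed retriever theta.
    """
    risk_theta = (P[None, :, :] * G).sum(axis=2).mean(axis=1)  # R(xi_k, theta) for each k
    oracle = G.min(axis=2).mean(axis=1)                       # E_X[min_z g_{xi_k}(X, z)]
    k_theta = int(np.argmin(risk_theta))  # xi*(theta): best predictor for this retriever
    k_star = int(np.argmin(oracle))       # xi*: best predictor under optimal retrieval
    return {
        "risk_xi_theta": float(risk_theta[k_theta]),
        "risk_xi_star": float(risk_theta[k_star]),
        "retriever_error": float(risk_theta[k_star] - oracle[k_star]),
    }

rng = np.random.default_rng(0)
G = rng.uniform(0.0, 2.0, size=(5, 40, 8))  # 5 predictors, 40 inputs, 8 store items
S = rng.normal(size=(40, 8))                 # hypothetical retriever scores
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)
print(decomposition_terms(G, P))
```

The first two entries reflect the nonpositive term dropped in (iii), and the retriever error is nonnegative because averaging over $p_\theta(\cdot|x)$ never beats the minimum over $\mathscr{I}$.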

B.2.1 Approximation error

We specialize our analysis to the log-loss truncated at $\ell_{\max}>0$, given by

$$\ell(h_{\xi}(x,z),y)=\min\big(\ell_{\max},-\log(p_{\xi}(y|x,z))\big)=\min\Big(\ell_{\max},\log\Big(\sum_{y^{\prime}\in\mathscr{Y}}\exp(h^{y^{\prime}}_{\xi}(x,z))\Big)-h^{y}_{\xi}(x,z)\Big).\tag{36}$$
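The truncated log-loss (36) can be computed stably with a log-sum-exp. Averaging it over $Y\sim p_{\mathsf{D}_{XY}}(\cdot|x)$ gives the truncated cross-entropy $\sum_y p_y\min(\ell_{\max},-\ln q_y)$ with $q=\mathrm{softmax}(h_\xi(x,z))$, which is $g_\xi(x,z)$. As we read the derivation that follows, step (i) relies on the pointwise bound $\sum_y p_y\min(\ell_{\max},-\ln q_y)\geq\sum_y p_y\min(\ell_{\max},-\ln p_y)-(|\mathscr{Y}|-1)e^{-\ell_{\max}}$ for every distribution $q$. The sketch below implements both quantities and checks this bound on random distributions; function names are our own.

```python
import numpy as np

def truncated_log_loss(logits, y, l_max):
    """Equation (36): min(l_max, log sum_{y'} exp(h^{y'}) - h^{y}) for one logit vector h."""
    m = float(np.max(logits))
    log_norm = m + float(np.log(np.sum(np.exp(logits - m))))  # stable log-sum-exp
    return min(l_max, log_norm - float(logits[y]))

def truncated_cross_entropy(p, q, l_max):
    """sum_y p_y * min(l_max, -ln q_y): expected truncated log-loss of predicting q when Y ~ p."""
    q_safe = np.maximum(q, 1e-300)  # avoid log(0); for moderate l_max these terms are capped anyway
    return float(np.sum(p * np.minimum(l_max, -np.log(q_safe))))

# Empirical check of the pointwise lower bound on random label distributions p and predictions q.
rng = np.random.default_rng(0)
worst_slack = np.inf
for _ in range(2000):
    K = int(rng.integers(2, 7))  # number of labels |Y|
    l_max = float(rng.uniform(0.1, 5.0))
    p, q = rng.dirichlet(np.ones(K)), rng.dirichlet(np.ones(K))
    lower = truncated_cross_entropy(p, p, l_max) - (K - 1) * np.exp(-l_max)
    worst_slack = min(worst_slack, truncated_cross_entropy(p, q, l_max) - lower)
print(f"smallest slack over trials: {worst_slack:.3e}")
```

The $(|\mathscr{Y}|-1)e^{-\ell_{\max}}$ slack accounts for the fact that truncation lets a prediction $q$ put almost no mass on some labels while paying at most $\ell_{\max}$ for them.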

We need to bound the predictor error $\big(\mathbb{E}_{X}[\min_{z\in\mathscr{I}}g_{\xi^{\ast}}(X,z)]-\mathbb{E}_{X}[\min_{z\in\mathscr{I}}g_{f_{{\rm opt},\mathscr{I}}^{\ell}}(X,z)]\big)$ for the truncated log-loss. To do so, we relate this term to $p_{*}^{y}(x,z)$ (cf. Equation (34)), whose complexity we control through Assumption B.2. We first lower bound $\mathbb{E}_{X}[\min_{z\in\mathscr{I}}g_{f_{{\rm opt},\mathscr{I}}^{\ell}}(X,z)]$ as a function of $p_{*}^{y}(x,z)$.
We proceed as follows:

$$
\begin{aligned}
\mathbb{E}_{X}\big[\min_{z\in\mathscr{I}}g_{f_{{\rm opt},\mathscr{I}}^{\ell}}(X,z)\big]
&\overset{(i)}{\geq}\mathbb{E}_{X}\Big[\sum_{y\in\mathscr{Y}}p_{\mathsf{D}_{XY}}^{y}(X)\min\big(\ell_{\max},-\ln p_{\mathsf{D}_{XY}}^{y}(X)\big)\Big]-(|\mathscr{Y}|-1)\exp(-\ell_{\max})\\
&\overset{(ii)}{\geq}\mathbb{E}_{X}\Big[\sum_{y\in\mathscr{Y}}p_{\mathsf{D}_{XY}}^{y}(X)\min\big(\ell_{\max},-\ln p_{*}^{y}(X,z^{*}(X))\big)\Big]-(|\mathscr{Y}|-1)\exp(-\ell_{\max})\\
&\qquad-\exp(\ell_{\max})\,\mathbb{E}_{X}\Big[\max_{y\in\mathscr{Y}}\big|p_{*}^{y}(X,z^{*}(X))-p_{\mathsf{D}_{XY}}^{y}(X)\big|\Big]\\
&\overset{(iii)}{\geq}\mathbb{E}_{X}\Big[\sum_{y\in\mathscr{Y}}p_{\mathsf{D}_{XY}}^{y}(X)\min\big(\ell_{\max},-\ln p_{*}^{y}(X,z^{*}(X))\big)\Big]-(|\mathscr{Y}|-1)\exp(-\ell_{\max})-c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}\exp(\ell_{\max})\\
&\overset{(iv)}{=}\mathbb{E}_{X}\big[g_{h_{*}}(X,z^{*}(X))\big]-(|\mathscr{Y}|-1)\exp(-\ell_{\max})-c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}\exp(\ell_{\max}).
\end{aligned}
\tag{37}
$$

Inequality (i) follows by applying Proposition A.10 to our setting with $C=\ell_{\max}$ and $K=|\mathscr{Y}|$. Inequality (ii) follows from the mean-value theorem: for all $a,b\in(0,1]$,

$$
\big|\min(C,-\log a)-\min(C,-\log b)\big|\leq\sup_{t}\Big|\frac{\partial}{\partial t}\min(C,-\log t)\Big|\times|a-b|\leq\exp(C)\,|a-b|,
$$

since the derivative vanishes where the minimum equals $C$ and has magnitude $1/t\leq\exp(C)$ where $-\log t\leq C$. Applying this with $a=p_{\mathsf{D}_{XY}}^{y}(X)$ and $b=p_{*}^{y}(X,z^{*}(X))$, and using $\sum_{y\in\mathscr{Y}}p_{\mathsf{D}_{XY}}^{y}(X)=1$, yields the additional term in (ii).
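The Lipschitz bound used in (ii) can also be checked numerically. The sketch below is a minimal illustration that assumes nothing beyond the inequality itself. It samples random pairs in $(0,1)$ and records the largest observed ratio $|\min(C,-\log a)-\min(C,-\log b)|/|a-b|$, which should never exceed $\exp(C)$. The function names are hypothetical.

```python
import math
import random


def clipped_neglog(a: float, C: float) -> float:
    """Truncated negative log, min(C, -log a), for a in (0, 1]."""
    return min(C, -math.log(a))


def worst_lipschitz_ratio(C: float, trials: int = 20_000, seed: int = 0) -> float:
    """Largest observed |f(a) - f(b)| / |a - b| over random pairs, with f = clipped_neglog."""
    rng = random.Random(seed)
    worst = 0.0
    for _ in range(trials):
        a, b = rng.random(), rng.random()
        # Skip log(0) and near-identical pairs, where rounding error dominates the ratio.
        if a == 0.0 or b == 0.0 or abs(a - b) < 1e-6:
            continue
        ratio = abs(clipped_neglog(a, C) - clipped_neglog(b, C)) / abs(a - b)
        worst = max(worst, ratio)
    return worst


for C in (0.5, 1.0, 3.0):
    print(f"C={C}: worst ratio {worst_lipschitz_ratio(C):.4f} <= exp(C) = {math.exp(C):.4f}")
```

The observed ratio approaches $\exp(C)$ for pairs just above $e^{-C}$, the point where the truncation becomes active. This is consistent with the constant in (ii).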

Inequality (iii) follows from Assumption B.2, with $z^{*}(x)$ as defined therein, which bounds $\mathbb{E}_{X}\big[\max_{y\in\mathscr{Y}}|p_{*}^{y}(X,z^{*}(X))-p_{\mathsf{D}_{XY}}^{y}(X)|\big]$ by $c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}$. The final equality (iv) substitutes $g_{h_{*}}(x,z^{*}(x))=\mathbb{E}_{Y|X=x}\big[\ell\big(h_{*}(x,z^{*}(x)),Y\big)\big]$, where $h_{*}(x,z)$ is the score function used in Equation (34).

We now derive an upper bound for the predictor-error part of our excess risk bound in Equation (35). Let $\xi\in\Xi$ be an arbitrary predictor. Then

$$
\begin{aligned}
\text{Predictor Error}&\triangleq\mathbb{E}_{X}\big[\min_{z\in\mathscr{I}}g_{\xi^{\ast}}(X,z)\big]-\mathbb{E}_{X}\big[\min_{z\in\mathscr{I}}g_{f_{{\rm opt},\mathscr{I}}^{\ell}}(X,z)\big]\\
&\overset{(i)}{\leq}\mathbb{E}_{X}\big[\min_{z\in\mathscr{I}}g_{\xi^{\ast}}(X,z)\big]-\mathbb{E}_{X}\big[g_{h_{*}}(X,z^{*}(X))\big]+(|\mathscr{Y}|-1)\exp(-\ell_{\max})+c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}\exp(\ell_{\max})\\
&\overset{(ii)}{\leq}\mathbb{E}_{X}\big[\min_{z\in\mathscr{I}}g_{\xi}(X,z)\big]-\mathbb{E}_{X}\big[g_{h_{*}}(X,z^{*}(X))\big]+(|\mathscr{Y}|-1)\exp(-\ell_{\max})+c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}\exp(\ell_{\max})\\
&\overset{(iii)}{\leq}\mathbb{E}_{X}\big[g_{\xi}(X,z^{*}(X))\big]-\mathbb{E}_{X}\big[g_{h_{*}}(X,z^{*}(X))\big]+(|\mathscr{Y}|-1)\exp(-\ell_{\max})+c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}\exp(\ell_{\max}).
\end{aligned}
$$

Inequality (i) follows by substituting the lower bound on $\mathbb{E}_{X}\big[\min_{z\in\mathscr{I}}g_{f_{{\rm opt},\mathscr{I}}^{\ell}}(X,z)\big]$ from Equation (37). Inequality (ii) holds because $\xi^{\ast}$ minimizes the $\ell$-risk over $\Xi$, so replacing it with the arbitrary predictor $\xi$ can only increase the first term. Inequality (iii) follows by evaluating at $z^{*}(X)$ instead of minimizing over $z\in\mathscr{I}$. Since the initial choice of $\xi$ was arbitrary, the final bound holds for all $\xi\in\Xi$.

Bounding the term $\mathbb{E}_{X}\big[g_{\xi}(X,z^{*}(X))\big]-\mathbb{E}_{X}\big[g_{h_{*}}(X,z^{*}(X))\big]$ is analogous to bounding the $\ell$-risk for classification under the data distribution $\mathbb{P}(X=x,Z=z,Y=y)=\mathbb{P}_{\mathsf{D}_{XY}}(X=x,Y=y)\,\mathbbm{1}(z=z^{*}(x))$. Our strategy is to bound the $\ell$-risk by the $L_{\infty}$ distance between the score function $h_{\xi}^{y}(x,z)$ and the score function $h_{*}^{y}(x,z)$, which lies in the Sobolev space given in Assumption B.2. In particular, we have the following $L_{\infty}$-norm bound:

$$
\begin{aligned}
&\mathbb{E}_{X}\big[g_{\xi}(X,z^{*}(X))\big]-\mathbb{E}_{X}\big[g_{h_{*}}(X,z^{*}(X))\big]\\
&\qquad\overset{(i)}{=}\mathbb{E}_{XY}\Big[\ell\big(h_{\xi}(X,z^{*}(X)),Y\big)-\ell\big(h_{*}(X,z^{*}(X)),Y\big)\Big]\\
&\qquad\overset{(ii)}{\leq}\mathbb{E}_{XY}\Big[\big|h_{\xi}^{Y}(X,z^{*}(X))-h_{*}^{Y}(X,z^{*}(X))\big|+\max_{y\in\mathscr{Y}}\big|h_{\xi}^{y}(X,z^{*}(X))-h_{*}^{y}(X,z^{*}(X))\big|\Big]\\
&\qquad\overset{(iii)}{\leq}2\,\mathbb{E}_{X}\Big[\max_{y\in\mathscr{Y}}\big|h_{\xi}^{y}(X,z^{*}(X))-h_{*}^{y}(X,z^{*}(X))\big|\Big]
\end{aligned}
$$

Equality (i) uses the definition of $g$. Inequality (ii) follows by substituting the bounded log-loss, using that truncation $t\mapsto\min(\ell_{\max},t)$ is 1-Lipschitz, and using the fact that for any $s,s^{\prime}\in\mathbb{R}^{K}$, $\big|\log\big(\sum_{k}\exp(s_{k})\big)-\log\big(\sum_{k}\exp(s^{\prime}_{k})\big)\big|\leq\max_{k}|s_{k}-s^{\prime}_{k}|$. The final inequality (iii) follows by bounding the first term by the second, since $|h_{\xi}^{Y}-h_{*}^{Y}|\leq\max_{y\in\mathscr{Y}}|h_{\xi}^{y}-h_{*}^{y}|$ pointwise.
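Both facts behind step (ii) can be checked on random score vectors. This is a minimal sketch. It assumes the bounded log-loss has the form $\min(\ell_{\max},\,-\log\mathrm{softmax}(s)_{y})$, which is consistent with its use above but is written out here as an assumption. All names are hypothetical.

```python
import math
import random


def logsumexp(s):
    """Numerically stable log(sum_k exp(s_k))."""
    m = max(s)
    return m + math.log(sum(math.exp(v - m) for v in s))


def bounded_log_loss(s, y, ell_max):
    """Assumed bounded log-loss: min(ell_max, -log softmax(s)_y)."""
    return min(ell_max, logsumexp(s) - s[y])


def check_step_ii(K=5, ell_max=3.0, trials=10_000, seed=1):
    """Return True if both inequalities behind step (ii) hold on random score pairs."""
    rng = random.Random(seed)
    tol = 1e-12
    for _ in range(trials):
        s = [rng.uniform(-4.0, 4.0) for _ in range(K)]
        t = [rng.uniform(-4.0, 4.0) for _ in range(K)]
        y = rng.randrange(K)
        sup_gap = max(abs(a - b) for a, b in zip(s, t))
        # log-sum-exp is 1-Lipschitz with respect to the max-norm.
        if abs(logsumexp(s) - logsumexp(t)) > sup_gap + tol:
            return False
        # Truncation is 1-Lipschitz, so the loss gap is at most |s_y - t_y| + sup_gap.
        loss_gap = abs(bounded_log_loss(s, y, ell_max) - bounded_log_loss(t, y, ell_max))
        if loss_gap > abs(s[y] - t[y]) + sup_gap + tol:
            return False
    return True


print(check_step_ii())
```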

Since the above holds for every $\xi\in\Xi$, we obtain the general approximation error bound

$$
\text{Predictor Error}\leq\inf_{\xi\in\Xi}2\,\mathbb{E}_{X}\Big[\max_{y\in\mathscr{Y}}\big|h_{\xi}^{y}(X,z^{*}(X))-h_{*}^{y}(X,z^{*}(X))\big|\Big]+(|\mathscr{Y}|-1)\exp(-\ell_{\max})+c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}\exp(\ell_{\max}).\tag{38}
$$

Note that the predictor approximation error is independent of retriever learning, since it is measured with respect to the Bayes-optimal retriever (i.e., $\min_{z\in\mathscr{I}}g_{\xi}(x,z)$), as seen in Equation (35).
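To illustrate how the two $\ell_{\max}$-dependent terms in (38) interact, consider them in isolation: $(|\mathscr{Y}|-1)e^{-\ell_{\max}}+c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}e^{\ell_{\max}}$. Minimizing $ae^{-t}+be^{t}$ over $t$ gives $t^{\star}=\tfrac{1}{2}\ln(a/b)$ and minimum value $2\sqrt{ab}$. For this part of the bound alone, the rate in $|\mathscr{I}|$ would therefore be $|\mathscr{I}|^{-\gamma_{\mathscr{I}}/2}$. The sketch below checks the closed form against a grid search. It ignores the fact that $\ell_{\max}$ also enters the generalization term, and it assumes $t^{\star}$ is an admissible value of $\ell_{\max}$. The parameter values are hypothetical.

```python
import math


def ell_max_terms(ell_max, num_labels, c_I, size_I, gamma_I):
    """The two ell_max-dependent terms of Equation (38)."""
    a = num_labels - 1
    b = c_I * size_I ** (-gamma_I)
    return a * math.exp(-ell_max) + b * math.exp(ell_max)


def best_ell_max(num_labels, c_I, size_I, gamma_I):
    """Closed-form minimizer of a*exp(-t) + b*exp(t): t* = ln(a/b)/2, value 2*sqrt(a*b)."""
    a = num_labels - 1
    b = c_I * size_I ** (-gamma_I)
    return 0.5 * math.log(a / b), 2.0 * math.sqrt(a * b)


# Hypothetical values: 11 labels, c_I = 1, |I| = 10**6, gamma_I = 0.5.
params = dict(num_labels=11, c_I=1.0, size_I=10**6, gamma_I=0.5)
t_star, value = best_ell_max(**params)
grid_min = min(ell_max_terms(k * 1e-3, **params) for k in range(20_000))
print(f"t* = {t_star:.3f}, closed-form min = {value:.4f}, grid min = {grid_min:.4f}")
```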

B.2.2 Generalization error

The generalization error in Equation (35) can be bounded in the same manner as for retriever learning in Section B.1. Note that here the predictor is learned over the space $\Xi$, while the retriever is held fixed.

$$
\begin{aligned}
&\Big|\mathbb{E}_{X}\big[\mathbb{E}_{Z\sim p_{\theta}(\cdot|X)}g_{\hat{\xi}(\theta)}(X,Z)\big]-\frac{1}{n}\sum_{i\in[n]}\sum_{z\in\mathscr{I}}p_{\theta}(z|x_{i})\,\ell\big(h_{\hat{\xi}(\theta)}(x_{i},z),y_{i}\big)\Big|\\
&\qquad\overset{(i)}{\leq}2\,\mathbb{E}_{\bm{\sigma}}\Big[\max_{\xi\in\Xi}\frac{1}{n}\sum_{i\in[n]}\sigma_{i}\sum_{z\in\mathscr{I}}p_{\theta}(z|x_{i})\,\ell\big(h_{\xi}(x_{i},z),y_{i}\big)\Big]+3\ell_{\max}\sqrt{\tfrac{\log(2/\delta)}{n}}\\
&\qquad\overset{(ii)}{\leq}2\times\inf_{\varepsilon\in[0,c_{\theta}/2]}\Big(4\varepsilon+\tfrac{12}{\sqrt{n}}\int_{\varepsilon}^{c_{\theta}/2}\sqrt{\log\mathcal{N}(\Xi,\nu,\|\cdot\|_{2,[n],\theta})}\,d\nu\Big)+3\ell_{\max}\sqrt{\tfrac{\log(2/\delta)}{n}}
\end{aligned}
$$

Here $\bm{\sigma}=(\sigma_{1},\ldots,\sigma_{n})$ is a vector of i.i.d. Rademacher random variables. The final inequality (ii) again follows from covering-number-based bounds with chaining (cf. Shalev-Shwartz and Ben-David [2014]). For a fixed retriever $\theta$, we have used

$$
c_{\theta}=\sup_{\xi\in\Xi}\Big(\tfrac{1}{n}\sum_{i\in[n]}\Big(\sum_{z\in\mathscr{I}}p_{\theta}(z|x_{i})\,\ell\big(h_{\xi}(x_{i},z),y_{i}\big)\Big)^{2}\Big)^{1/2},
$$

and $\mathcal{N}(\Xi,\nu,\|\cdot\|_{2,[n],\theta})$ denotes the covering number of the predictor function class $\Xi$ at scale $\nu$ in the following $L_{2}$-type norm, defined with respect to the sample $\{(x_{i},y_{i}):i\in[n]\}$ and the fixed retriever $\theta$:

$$
\|\mathbf{u}\|_{2,[n],\theta}:=\Big(\tfrac{1}{n}\sum_{i\in[n]}\Big(\sum_{z\in\mathscr{I}}p_{\theta}(z|x_{i})\,u_{i,z}\Big)^{2}\Big)^{1/2},\quad\forall\,\mathbf{u}\in\mathbb{R}^{n\times|\mathscr{I}|}.
$$
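The sketch below makes $\|\cdot\|_{2,[n],\theta}$, $c_{\theta}$ and the Rademacher term in (i) concrete by evaluating them for a toy finite predictor class. Each hypothetical predictor $\xi$ is represented only by its loss matrix $L_{\xi}[i][z]=\ell(h_{\xi}(x_{i},z),y_{i})$, and the retriever by the matrix $P[i][z]=p_{\theta}(z|x_{i})$. The Rademacher expectation is computed exactly by enumerating all $2^{n}$ sign vectors, which is feasible only for tiny $n$. The names and data are illustrative assumptions.

```python
import itertools
import math


def weighted_losses(P, L):
    """a_i = sum_z p_theta(z|x_i) * L[i][z], the retriever-weighted loss of example i."""
    return [sum(p * l for p, l in zip(P_i, L_i)) for P_i, L_i in zip(P, L)]


def weighted_norm(P, U):
    """||U||_{2,[n],theta} = sqrt((1/n) * sum_i (sum_z p_theta(z|x_i) * U[i][z])**2)."""
    a = weighted_losses(P, U)
    return math.sqrt(sum(v * v for v in a) / len(a))


def c_theta(P, loss_matrices):
    """c_theta for a finite class: the largest weighted norm over the loss matrices."""
    return max(weighted_norm(P, L) for L in loss_matrices)


def empirical_rademacher(P, loss_matrices):
    """Exact E_sigma[max_xi (1/n) sum_i sigma_i a_{xi,i}] over all 2**n sign vectors."""
    vectors = [weighted_losses(P, L) for L in loss_matrices]
    n = len(P)
    total = 0.0
    for sigma in itertools.product((-1.0, 1.0), repeat=n):
        total += max(sum(s * v for s, v in zip(sigma, a)) for a in vectors) / n
    return total / 2 ** n


# Toy setting: n = 2 examples, |I| = 2 retrievable items, two predictors.
P = [[0.5, 0.5], [1.0, 0.0]]
L_a = [[1.0, 1.0], [1.0, 3.0]]  # weighted losses (1.0, 1.0)
L_b = [[0.0, 0.0], [0.0, 0.0]]  # weighted losses (0.0, 0.0)
print(c_theta(P, [L_a, L_b]), empirical_rademacher(P, [L_a, L_b]))
```

The loss of 3.0 in `L_a` receives zero retriever weight, so it does not affect either quantity. This illustrates how the fixed retriever $\theta$ shapes the norm used for covering $\Xi$.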

Since $\xi^{\ast}(\theta)$ is fixed once $\theta$ is fixed, we can bound its deviation directly, without a uniform bound over the learner/predictor space:

$$
\Big|\mathbb{E}_{X}\big[\mathbb{E}_{Z\sim p_{\theta}(\cdot|X)}g_{\xi^{\ast}(\theta)}(X,Z)\big]-\frac{1}{n}\sum_{i\in[n]}\sum_{z\in\mathscr{I}}p_{\theta}(z|x_{i})\,\ell\big(h_{\xi^{\ast}(\theta)}(x_{i},z),y_{i}\big)\Big|\leq3\ell_{\max}\sqrt{\tfrac{\log(2/\delta)}{n}}.
$$
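The fixed-predictor bound above can be inverted to give the number of samples that suffices for a target deviation $\epsilon$: $3\ell_{\max}\sqrt{\log(2/\delta)/n}\leq\epsilon$ holds whenever $n\geq9\ell_{\max}^{2}\log(2/\delta)/\epsilon^{2}$. Here is a short sketch of this arithmetic, with hypothetical function names:

```python
import math


def deviation_bound(ell_max, n, delta):
    """Right-hand side 3 * ell_max * sqrt(log(2/delta) / n) of the fixed-predictor bound."""
    return 3.0 * ell_max * math.sqrt(math.log(2.0 / delta) / n)


def samples_needed(ell_max, delta, eps):
    """Smallest integer n with deviation_bound(ell_max, n, delta) <= eps (up to rounding)."""
    return math.ceil(9.0 * ell_max ** 2 * math.log(2.0 / delta) / eps ** 2)


n = samples_needed(ell_max=1.0, delta=0.05, eps=0.1)
print(n, deviation_bound(1.0, n, 0.05))
```

With $\ell_{\max}=1$, $\delta=0.05$ and $\epsilon=0.1$, this gives $n=3320$. The required sample size grows quadratically in $\ell_{\max}$.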

B.2.3 Instantiation of MLP predictor

As a concrete example, we now consider the space $\Xi={\rm MLP}(\mathbb{R}^{d_{x}+d_{z}},\mathbb{R}^{\mathscr{Y}};W,L)$ as defined in Equation (27).

Approximation error of MLP predictor:

Our approximation results rely mainly on the results in Siegel [2023]. The key difference is that the output is now $|\mathscr{Y}|$-dimensional. For each $y\in\mathscr{Y}$, we find an MLP of depth $L$ and width at most $W^{\prime}=O(d_{x}+d_{z})$ that approximates $h_{*}^{y}(x,z)$. We then join these networks in parallel to obtain a final network of depth $L$ and width at most $O((d_{x}+d_{z})|\mathscr{Y}|)$. In principle, these networks could share some of the sub-networks used to construct the approximation (e.g., the bit-extraction networks and the sub-domain indexing network for $p=q$ in Siegel [2023]). Exploiting such sharing is beyond the scope of this work, and we leave it open.
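The parallel joining step can be sketched for plain ReLU MLPs. The sketch assumes each per-label network $f_{y}$ is a list of affine layers with ReLU activations in between, which simplifies the class in Equation (27). Because all $|\mathscr{Y}|$ networks read the same input, their first-layer weights can be stacked, and the remaining layers placed block-diagonally. The result is one network whose $y$-th output equals $f_{y}$, with hidden width multiplied by $|\mathscr{Y}|$. The construction in the main text spends an extra layer on sharing the input, which gives depth $L+1$; this simplified sketch stacks first-layer weights instead and keeps the depth unchanged. The random weights and helper names are hypothetical.

```python
import numpy as np


def relu(v):
    return np.maximum(v, 0.0)


def mlp_forward(layers, x):
    """Apply a ReLU MLP given as a list of (W, b); no activation after the last layer."""
    h = x
    for W, b in layers[:-1]:
        h = relu(W @ h + b)
    W, b = layers[-1]
    return W @ h + b


def random_scalar_mlp(rng, d_in, width, depth):
    """Hypothetical stand-in for f_y: `depth` hidden layers of size `width`, scalar output."""
    dims = [d_in] + [width] * depth + [1]
    return [(rng.standard_normal((dims[k + 1], dims[k])), rng.standard_normal(dims[k + 1]))
            for k in range(len(dims) - 1)]


def block_diag(mats):
    """Place the given matrices along the diagonal of a zero matrix."""
    out = np.zeros((sum(m.shape[0] for m in mats), sum(m.shape[1] for m in mats)))
    r = c = 0
    for m in mats:
        out[r:r + m.shape[0], c:c + m.shape[1]] = m
        r += m.shape[0]
        c += m.shape[1]
    return out


def parallelize(nets):
    """Join |Y| scalar MLPs of equal depth into one MLP with |Y| outputs."""
    joint = []
    for k, layer_group in enumerate(zip(*nets)):
        Ws = [W for W, _ in layer_group]
        bs = [b for _, b in layer_group]
        # First layer: every sub-network reads the same input, so stack rows.
        W = np.vstack(Ws) if k == 0 else block_diag(Ws)
        joint.append((W, np.concatenate(bs)))
    return joint


rng = np.random.default_rng(0)
d_in, width, depth, num_labels = 3, 5, 2, 4
nets = [random_scalar_mlp(rng, d_in, width, depth) for _ in range(num_labels)]
joint = parallelize(nets)
x = rng.standard_normal(d_in)
print(mlp_forward(joint, x))
print([mlp_forward(net, x)[0] for net in nets])
```

The two printed vectors coincide, and each hidden layer of `joint` has width `num_labels * width`. This matches the $O((d_{x}+d_{z})|\mathscr{Y}|)$ width count in the paragraph above when `width` is $O(d_{x}+d_{z})$.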

From Theorem 1 in Siegel [2023], taking $p=q=\infty$ in the theorem statement, under Assumption B.2 we get that for each $y\in\mathscr{Y}$ there exists an MLP $f_{y}\in{\rm MLP}(\mathbb{R}^{d_{x}+d_{z}},\mathbb{R};W,L)$ such that

$$\|f_y-h_*^y\|_{L_\infty(\Omega)}\leq C_y\|h_*^y\|_{W^{\kappa}(L_\infty(\Omega))}\,L^{-2\kappa_{\mathscr{I}}/(d_x+d_z)}$$

for $\Omega=[-1,1]^{d_x+d_z}$, $W=25(d_x+d_z)+31$, and $C_y=c(\kappa_{\mathscr{I}},d_x+d_z)<\infty$ (independent of $L$). By concatenating the networks $f_y$ for $y\in\mathscr{Y}$ in parallel (cf. Lemma 5 in Siegel [2023]) and using the first layer to share the $(d_x+d_z)$-dimensional input among these parallel networks, we obtain an MLP $f_{\rm opt}\in{\rm MLP}(\mathbb{R}^{d_x+d_z},\mathbb{R}^{|\mathscr{Y}|};W_{\mathscr{Y}},L+1)$ with $W_{\mathscr{Y}}=O(|\mathscr{Y}|(d_x+d_z))$ such that, for every $y\in\mathscr{Y}$,

$$\|f^y_{\rm opt}-h_*^y\|_{L_\infty(\Omega)}\leq\Big(\max_{y\in\mathscr{Y}}C_y\|h_*^y\|_{W^{\kappa}(L_\infty(\Omega))}\Big)L^{-2\kappa_{\mathscr{I}}/(d_x+d_z)}.$$

Using $\xi=f_{\rm opt}$ in our bounds, we obtain the predictor error

$$\text{Predictor Error}\leq 2\Big(\max_{y\in\mathscr{Y}}C_y\|h_*^y\|_{W^{\kappa}(L_\infty(\Omega))}\Big)L^{-2\kappa_{\mathscr{I}}/(d_x+d_z)}+(|\mathscr{Y}|-1)\exp(-\ell_{\max})+c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}\tag{39}$$

Generalization error for MLP predictor:

We now bound the generalization error in Equation (35) when $\Xi$ is the class of multi-layer perceptrons (MLPs) with ReLU nonlinearity, ${\rm MLP}(\mathbb{R}^{d_x+d_z},\mathbb{R}^{|\mathscr{Y}|};W,L)$.

The first step is to bound the covering number of $\Xi$ under the $\|\cdot\|_{2,[n],\theta}$ norm by its covering number under the $\|\cdot\|_{\infty,n|\mathscr{I}||\mathscr{Y}|}$ norm, where $\|\mathbf{u}\|_{\infty,n|\mathscr{I}||\mathscr{Y}|}=\sup_{x_i\in\mathcal{S}_n}\sup_{z\in\mathscr{I}}\sup_{y\in\mathscr{Y}}|u_{i,z,y}|$ for all $\mathbf{u}\in\mathbb{R}^{n\times|\mathscr{I}|\times|\mathscr{Y}|}$.

For a fixed data set $\mathcal{S}_n:=\{(x_1,y_1),\dots,(x_n,y_n)\}$, a fixed retriever $\theta$, and two predictors $\xi,\xi'\in\Xi$, we have

$$
\begin{aligned}
&\Big(\tfrac{1}{n}\sum_{i\in[n]}\Big(\sum_{z\in\mathscr{I}}p_{\theta}(z|x_i)\big(\ell(h_{\xi}(x_i,z),y_i)-\ell(h_{\xi'}(x_i,z),y_i)\big)\Big)^{2}\Big)^{1/2}\\
&\qquad\overset{(i)}{\leq}\Big(\tfrac{1}{n}\sum_{i\in[n]}\sum_{z\in\mathscr{I}}p_{\theta}(z|x_i)\big(\ell(h_{\xi}(x_i,z),y_i)-\ell(h_{\xi'}(x_i,z),y_i)\big)^{2}\Big)^{1/2}\\
&\qquad\overset{(ii)}{\leq}\Big(\tfrac{1}{n}\sum_{i\in[n]}\sum_{z\in\mathscr{I}}p_{\theta}(z|x_i)\big(|h^{y_i}_{\xi}(x_i,z)-h^{y_i}_{\xi'}(x_i,z)|+\max_{y\in\mathscr{Y}}|h^{y}_{\xi}(x_i,z)-h^{y}_{\xi'}(x_i,z)|\big)^{2}\Big)^{1/2}\\
&\qquad\overset{(iii)}{\leq}\Big(\tfrac{1}{n}\sum_{i\in[n]}\sum_{z\in\mathscr{I}}p_{\theta}(z|x_i)\big(2\max_{y\in\mathscr{Y}}|h^{y}_{\xi}(x_i,z)-h^{y}_{\xi'}(x_i,z)|\big)^{2}\Big)^{1/2}\\
&\qquad\overset{(iv)}{\leq}2\sup_{x\in\mathscr{X}}\sup_{y\in\mathscr{Y}}\sup_{z\in\mathscr{I}}|h^{y}_{\xi}(x,z)-h^{y}_{\xi'}(x,z)|
\end{aligned}
$$

The first inequality follows from Jensen's inequality, since $p_\theta(\cdot|x_i)$ is a probability distribution over $\mathscr{I}$. For the bounded log-loss, we obtain the second inequality using the fact that for any $s,s'\in\mathbb{R}^K$, $|\log(\sum_k\exp(s_k))-\log(\sum_k\exp(s'_k))|\leq\max_k|s_k-s'_k|$. The remaining steps bound each term by its supremum over $x$, $y$, and $z$ and use $\sum_{z\in\mathscr{I}}p_\theta(z|x_i)=1$.
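The log-sum-exp inequality behind the second step is easy to sanity-check numerically. The sketch below assumes the loss is the softmax cross-entropy $\ell(s,y)=-s_y+\log\sum_k\exp(s_k)$ evaluated on random score vectors. It checks $|{\rm LSE}(s)-{\rm LSE}(s')|\leq\max_k|s_k-s'_k|$ and the resulting per-example bound $|\ell(s,y)-\ell(s',y)|\leq|s_y-s'_y|+\max_k|s_k-s'_k|$. If "bounded" is implemented by clipping the loss at $\ell_{\max}$, this can only shrink these differences, since $t\mapsto\min(t,\ell_{\max})$ is 1-Lipschitz.

```python
import numpy as np


def logsumexp(s):
    """Numerically stable log(sum_k exp(s_k))."""
    m = np.max(s)
    return m + np.log(np.sum(np.exp(s - m)))


def log_loss(s, y):
    """Softmax cross-entropy of score vector s at label y (assumed loss form)."""
    return -s[y] + logsumexp(s)


rng = np.random.default_rng(1)
K = 6
worst_ratio = 0.0
for _ in range(10_000):
    s = rng.normal(scale=3.0, size=K)
    s_prime = s + rng.normal(scale=0.5, size=K)
    gap = np.max(np.abs(s - s_prime))
    # |LSE(s) - LSE(s')| <= max_k |s_k - s'_k|
    assert abs(logsumexp(s) - logsumexp(s_prime)) <= gap + 1e-12
    # |l(s, y) - l(s', y)| <= |s_y - s'_y| + max_k |s_k - s'_k| <= 2 * gap
    y = rng.integers(K)
    diff = abs(log_loss(s, y) - log_loss(s_prime, y))
    assert diff <= abs(s[y] - s_prime[y]) + gap + 1e-12
    worst_ratio = max(worst_ratio, diff / gap)
print(worst_ratio <= 2.0)   # True
```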

Let $\Xi_{\rm cov}$ be a $\nu$-cover of $\Xi$ with respect to the $\|\cdot\|_{\infty,n|\mathscr{I}||\mathscr{Y}|}$ norm, of cardinality $\mathcal{N}(\Xi,\nu,\|\cdot\|_{\infty,n|\mathscr{I}||\mathscr{Y}|})$. That is, for any $\xi\in\Xi$ there exists $\xi'(\xi)\in\Xi_{\rm cov}$ such that $\sup_{x\in\mathscr{X}}\sup_{y\in\mathscr{Y}}\sup_{z\in\mathscr{I}}|h^y_\xi(x,z)-h^y_{\xi'}(x,z)|\leq\nu$.
Therefore, by the above chain of inequalities, $\Big(\tfrac{1}{n}\sum_{i\in[n]}\big(\sum_{z\in\mathscr{I}}p_\theta(z|x_i)(\ell(h_\xi(x_i,z),y_i)-\ell(h_{\xi'}(x_i,z),y_i))\big)^2\Big)^{1/2}\leq 2\nu$. So $\Xi_{\rm cov}$ forms a $2\nu$-cover of $\Xi$ with respect to the $\|\cdot\|_{2,[n],\theta}$ norm. Hence, $\mathcal{N}(\Xi,2\nu,\|\cdot\|_{2,[n],\theta})\leq\mathcal{N}(\Xi,\nu,\|\cdot\|_{\infty,n|\mathscr{I}||\mathscr{Y}|})$.
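To see the covering comparison concretely, the following NumPy sketch builds synthetic scores $h^y_\xi(x_i,z)$, a nearby predictor $\xi'$, retrieval probabilities $p_\theta(z|x_i)$, and labels, all randomly generated and hypothetical. It compares the empirical $\|\cdot\|_{2,[n],\theta}$ distance between the two predictors' losses with the $\|\cdot\|_{\infty,n|\mathscr{I}||\mathscr{Y}|}$ distance between their scores. Under the assumed softmax log-loss, each per-example loss difference is at most $|h^{y_i}_\xi-h^{y_i}_{\xi'}|+\max_y|h^y_\xi-h^y_{\xi'}|$, so the empirical distance never exceeds twice the sup-norm distance.

```python
import numpy as np


def logsumexp(s, axis=-1):
    """Numerically stable log-sum-exp along `axis`."""
    m = np.max(s, axis=axis, keepdims=True)
    return np.squeeze(m, axis) + np.log(np.sum(np.exp(s - m), axis=axis))


def empirical_l2_theta(h, h_prime, p, y):
    """(1/n sum_i (sum_z p[i,z] (l(h[i,z], y_i) - l(h'[i,z], y_i)))^2)^{1/2}.

    h, h_prime: scores of shape (n, |I|, |Y|); p: retrieval probs (n, |I|);
    y: labels (n,). The loss is softmax cross-entropy (assumed).
    """
    idx = np.arange(h.shape[0])
    loss = -h[idx, :, y] + logsumexp(h)             # shape (n, |I|)
    loss_prime = -h_prime[idx, :, y] + logsumexp(h_prime)
    per_example = np.sum(p * (loss - loss_prime), axis=1)
    return np.sqrt(np.mean(per_example ** 2))


def sup_norm(h, h_prime):
    """||h - h'||_{infty, n|I||Y|}: max over examples, retrieved items, labels."""
    return np.max(np.abs(h - h_prime))


rng = np.random.default_rng(2)
n, num_items, num_labels = 50, 7, 4
h = rng.normal(size=(n, num_items, num_labels))
h_prime = h + rng.uniform(-0.1, 0.1, size=h.shape)   # a nearby cover element
logits = rng.normal(size=(n, num_items))
p = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # p_theta(z | x_i)
y = rng.integers(num_labels, size=n)
lhs = empirical_l2_theta(h, h_prime, p, y)
nu = sup_norm(h, h_prime)
print(lhs <= 2 * nu)   # True
```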

Next, we bound $\mathcal{N}(\Xi,\nu,\|\cdot\|_{\infty,n|\mathscr{I}||\mathscr{Y}|})$. Similar to the retrieval analysis in Section B.1, we first apply Zhang [2023] to bound this covering number in terms of the pseudo-dimension. To apply the results therein, we slightly reformulate the function $h_\xi:\mathscr{X}\times\mathscr{Z}\to\mathbb{R}^{|\mathscr{Y}|}$. Define $\tilde{h}_\xi:\mathscr{X}\times\mathscr{Z}\times\mathscr{Y}\to\mathbb{R}$ by $\tilde{h}_\xi(x,z,y)=h^y_\xi(x,z)$ for each $y\in\mathscr{Y}$. This reformulation leaves the covering number $\mathcal{N}(\Xi,\nu,\|\cdot\|_{\infty,n|\mathscr{I}||\mathscr{Y}|})$ of $\Xi$ unchanged, since in both views the norm takes the supremum over the same $n|\mathscr{I}||\mathscr{Y}|$ triples $(x_i,z,y)$.
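The reformulation amounts to treating each triple $(x_i,z,y)$ as a separate input to a scalar-valued function. The small sketch below (random scores, hypothetical sizes) enumerates these $n|\mathscr{I}||\mathscr{Y}|$ evaluation points and checks that the sup-norm distance between two predictors is the same in the vector-valued and scalar-valued views.

```python
import itertools

import numpy as np

rng = np.random.default_rng(3)
n, num_items, num_labels = 4, 3, 5
h = rng.normal(size=(n, num_items, num_labels))        # h^y_xi(x_i, z)
h_prime = rng.normal(size=(n, num_items, num_labels))  # h^y_xi'(x_i, z)

# Scalar reformulation: one input point per triple (x_i, z, y).
points = list(itertools.product(range(n), range(num_items), range(num_labels)))
h_tilde = np.array([h[i, z, y] for i, z, y in points])
h_tilde_prime = np.array([h_prime[i, z, y] for i, z, y in points])

print(len(points))  # n * |I| * |Y| = 60
# The empirical sup-norm distance is identical under both views.
print(np.max(np.abs(h_tilde - h_tilde_prime)) == np.max(np.abs(h - h_prime)))  # True
```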
In particular, if the pseudo-dimension of $\{\tilde{h}_\xi:\xi\in\Xi\}$ is $\tilde{d}_{VC}$, then by Theorem 5.11 in Zhang [2023] we have $\log\mathcal{N}(\Xi,\nu,\|\cdot\|_{\infty,n|\mathscr{I}||\mathscr{Y}|})\leq 1+\log(1+\tilde{d}_{VC})+\tilde{d}_{VC}\log\big(\max\{2,\,en|\mathscr{I}||\mathscr{Y}|/(\tilde{d}_{VC}\nu)\}\big)$.
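The right-hand side above is straightforward to evaluate. The helper below (a hypothetical name) implements the stated expression for given $\tilde d_{VC}$, $n$, $|\mathscr{I}|$, $|\mathscr{Y}|$, and $\nu$; the example values are illustrative only. It also shows that, once the $\max$ is attained by its second argument, doubling $n$ adds exactly $\tilde d_{VC}\log 2$ to the bound, so the sizes $n$, $|\mathscr{I}|$, and $|\mathscr{Y}|$ enter only logarithmically.

```python
import math


def log_covering_bound(d_vc, n, num_items, num_labels, nu):
    """Evaluate 1 + log(1 + d) + d * log(max(2, e * m / (d * nu))).

    m = n * |I| * |Y| is the number of (x_i, z, y) evaluation points;
    d_vc is the pseudo-dimension of {h~_xi : xi in Xi}.
    """
    m = n * num_items * num_labels
    ratio = math.e * m / (d_vc * nu)
    return 1 + math.log(1 + d_vc) + d_vc * math.log(max(2.0, ratio))


# Illustrative values only (not taken from the paper).
d_vc, n, num_items, num_labels, nu = 100, 1000, 10, 5, 0.01
base = log_covering_bound(d_vc, n, num_items, num_labels, nu)
doubled = log_covering_bound(d_vc, 2 * n, num_items, num_labels, nu)
print(math.isclose(doubled - base, d_vc * math.log(2)))  # True
```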

Next, we derive the pseudo-dimension of the class $\{\tilde{h}_\xi:\xi\in\Xi\}$ using Bartlett et al. [2019]. One challenge is that, for the MLP under consideration, the label $y$ does not lie in the input space; rather, it corresponds to one coordinate of the $|\mathscr{Y}|$-dimensional output. This can be handled by a slight modification of Theorem 6 in Bartlett et al. [2019], namely Theorem A.9 in Appendix A. By Theorem A.9, for $\Xi={\rm MLP}(\mathbb{R}^{d_x+d_z},\mathbb{R}^{|\mathscr{Y}|};W,L)$ we have ${\rm VCdim}(\Xi)=O(L\log(|\mathscr{Y}|)+L^2W^2\log(LW))$. The resulting generalization bound is

$$\text{Generalization Error}\leq O\bigg(\frac{\ell_{\max}\sqrt{(L\log(|\mathscr{Y}|)+L^2W^2\log(LW))\log(n|\mathscr{I}||\mathscr{Y}|)}}{\sqrt{n}}\bigg).\tag{40}$$
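Since $O(\cdot)$ hides constants, the bound can only be explored up to scale. The sketch below drops all constants (an assumption made purely for illustration) and evaluates the VC-dimension order and the generalization term (40) as functions of $L$, $W$, $n$, $|\mathscr{I}|$, $|\mathscr{Y}|$, and $\ell_{\max}$. Here $W$ plays the role of the combined width $W_{\mathscr{Y}}=O(|\mathscr{Y}|(d_x+d_z))$, and all numeric values are hypothetical. The example shows the expected trade-off: the term grows with depth and shrinks roughly like $n^{-1/2}$ up to logarithmic factors.

```python
import math


def vc_dim_order(L, W, num_labels):
    """L*log|Y| + L^2 W^2 log(L W): the VC-dimension order with constants dropped."""
    return L * math.log(num_labels) + L**2 * W**2 * math.log(L * W)


def generalization_term(L, W, n, num_items, num_labels, ell_max):
    """ell_max * sqrt(VC * log(n |I| |Y|)) / sqrt(n), constants dropped."""
    vc = vc_dim_order(L, W, num_labels)
    return ell_max * math.sqrt(vc * math.log(n * num_items * num_labels)) / math.sqrt(n)


# Illustrative values only.
num_labels, d = 10, 8                  # |Y| and d_x + d_z
W = num_labels * (25 * d + 31)         # |Y| copies of the per-label width 25(d_x+d_z)+31
args = dict(W=W, num_items=10**5, num_labels=num_labels, ell_max=5.0)
small_n = generalization_term(L=4, n=10**6, **args)
large_n = generalization_term(L=4, n=10**8, **args)
deeper = generalization_term(L=8, n=10**6, **args)
print(large_n < small_n < deeper)      # True
```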
Excess risk of predictor learning:

We now combine the generalization error (40) and the approximation error (38) to obtain the final excess risk, which is upper bounded as

$$
\begin{aligned}
\text{Excess Risk}
&\leq\underbrace{R_{\ell,\mathscr{I}}(\xi^{\ast},\theta)-\mathbb{E}_X\big[\min_{z\in\mathscr{I}}g_{\xi^{\ast}}(X,z)\big]}_{\text{error from retriever }\theta}+\underbrace{O\big(L^{-2\kappa_{\mathscr{I}}/(d_x+d_z)}+(|\mathscr{Y}|-1)\exp(-\ell_{\max})+c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}\exp(\ell_{\max})\big)}_{\text{predictor approximation error}}\\
&\quad+\underbrace{O\bigg(\frac{\ell_{\max}\sqrt{(L\log(|\mathscr{Y}|)+L^2W^2\log(LW))\log(n|\mathscr{I}||\mathscr{Y}|)}}{\sqrt{n}}\bigg)}_{\text{predictor generalization error}}\\
&=\underbrace{R_{\ell,\mathscr{I}}(\xi^{\ast},\theta)-\mathbb{E}_X\big[\min_{z\in\mathscr{I}}g_{\xi^{\ast}}(X,z)\big]}_{\text{error from retriever }\theta}+\underbrace{\tilde{O}\Big(|\mathscr{Y}|^{\frac{2\kappa_{\mathscr{I}}}{(d_x+d_z)+2\kappa_{\mathscr{I}}}}\,n^{-\frac{\kappa_{\mathscr{I}}}{(d_x+d_z)+2\kappa_{\mathscr{I}}}}\Big)}_{\text{predictor combined error}}
\end{aligned}
\tag{41}
$$

We let the data store grow polynomially with the sample size, $|\mathscr{I}|=\Omega(n^{s}|\mathscr{Y}|^{1/\gamma_{\mathscr{I}}})$, and set $\ell_{\max}=\log(|\mathscr{Y}|)+s^{\prime}\log(n)$. For $s\geq\frac{2\kappa_{\mathscr{I}}}{((d_{x}+d_{z})+2\kappa_{\mathscr{I}})\gamma_{\mathscr{I}}}$ and $s^{\prime}\geq\frac{\kappa_{\mathscr{I}}}{(d_{x}+d_{z})+2\kappa_{\mathscr{I}}}$, the final error bound for the predictor follows by setting $L=n^{\frac{d_{x}+d_{z}}{2(d_{x}+d_{z})+4\kappa_{\mathscr{I}}}}|\mathscr{Y}|^{-\frac{d_{x}+d_{z}}{(d_{x}+d_{z})+2\kappa_{\mathscr{I}}}}$. Note that the choices of $L$ and $W$ here determine the predictor size and are independent of the choices made for the retriever size. Moreover, Assumption B.2 makes the quality of the data store the bottleneck in the predictor excess risk if $|\mathscr{I}|=o(n^{s}|\mathscr{Y}|^{1/\gamma_{\mathscr{I}}})$ for $s=\frac{2\kappa_{\mathscr{I}}}{((d_{x}+d_{z})+2\kappa_{\mathscr{I}})\gamma_{\mathscr{I}}}$.
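The thresholds on $s$ and $s^{\prime}$ and the depth choice above can be checked numerically. The sketch below is a direct transcription of those formulas into Python, writing $d=d_x+d_z$; the helper name `predictor_size_schedule` is hypothetical, $\kappa_{\mathscr{I}}$ and $\gamma_{\mathscr{I}}$ are treated as given inputs, and the constants hidden in $\Omega(\cdot)$ are dropped.

```python
import math


def predictor_size_schedule(n, num_labels, d_x, d_z, kappa, gamma):
    """Transcribe the exponent thresholds and the depth choice for the predictor bound.

    n: number of training samples; num_labels: |Y|;
    kappa, gamma: the constants kappa_I and gamma_I (assumed given).
    Constants hidden in Omega(.) are ignored (illustrative assumption).
    """
    d = d_x + d_z
    # Smallest exponent s such that |I| = Omega(n^s |Y|^(1/gamma)) suffices.
    s_min = 2 * kappa / ((d + 2 * kappa) * gamma)
    # Smallest exponent s' in ell_max = log|Y| + s' log n.
    s_prime_min = kappa / (d + 2 * kappa)
    ell_max = math.log(num_labels) + s_prime_min * math.log(n)
    # Depth choice L = n^{d / (2d + 4 kappa)} * |Y|^{-d / (d + 2 kappa)}.
    depth = n ** (d / (2 * d + 4 * kappa)) * num_labels ** (-d / (d + 2 * kappa))
    # Data store size at the threshold exponent, without the Omega constant.
    store_size = n ** s_min * num_labels ** (1 / gamma)
    return {"s_min": s_min, "s_prime_min": s_prime_min, "ell_max": ell_max,
            "depth": depth, "store_size": store_size}
```

For example, with $d_x=d_z=1$, $\kappa_{\mathscr{I}}=\gamma_{\mathscr{I}}=1$, $n=10^4$ and $|\mathscr{Y}|=4$, it gives $s\geq 1/2$, $s^{\prime}\geq 1/4$, $L=5$, and a data store of size on the order of $400$.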

B.3 Joint learning of retriever and predictor

In this section, we consider jointly learning the predictor and the retriever from the classes $\Xi$ and $\Theta$, respectively. The empirical optimizer pair $(\hat{\xi}_{\rm joint},\hat{\theta}_{\rm joint})$ and the population optimizer pair $(\xi^{\ast}_{\rm joint},\theta^{\ast}_{\rm joint})$ for the joint task are given as follows.

$$
\begin{aligned}
\hat{\xi}_{\rm joint},\hat{\theta}_{\rm joint}&=\operatorname*{arg\,min}_{\xi\in\Xi,\,\theta\in\Theta}\frac{1}{n}\sum_{i\in[n]}\sum_{z\in\mathscr{I}}p_{\theta}(z|x_{i})\,\ell\big(h_{\xi}(x_{i},z),y_{i}\big),\\
\xi^{\ast}_{\rm joint},\theta^{\ast}_{\rm joint}&=\operatorname*{arg\,min}_{\xi\in\Xi,\,\theta\in\Theta}\mathbb{E}_{X}\big[\mathbb{E}_{Z\sim p_{\theta}(\cdot|X)}g_{\xi}(X,Z)\big].
\end{aligned}
$$
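To make the empirical joint objective concrete, the sketch below evaluates $\frac{1}{n}\sum_{i}\sum_{z}p_{\theta}(z|x_i)\,\ell(h_{\xi}(x_i,z),y_i)$ over a finite data store and brute-forces its minimizer over finite classes. The softmax retriever over inner-product scores with a scalar $\theta$, the linear-softmax predictor acting on $x+z$, and the cross-entropy loss are illustrative assumptions, not the models analyzed here; all helper names are hypothetical.

```python
import math


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def retriever_probs(theta, x, store):
    # Assumed retriever: p_theta(z | x) proportional to exp(theta * <x, z>).
    return softmax([theta * sum(a * b for a, b in zip(x, z)) for z in store])


def predictor_logits(xi, x, z):
    # Assumed predictor: the logit of label k is <xi[k], x + z>.
    feats = [a + b for a, b in zip(x, z)]
    return [sum(w * f for w, f in zip(row, feats)) for row in xi]


def cross_entropy(logits, y):
    # Assumed loss: ell(h, y) = -log softmax(h)_y.
    return -math.log(softmax(logits)[y])


def joint_empirical_risk(theta, xi, data, store):
    """(1/n) * sum_i sum_z p_theta(z | x_i) * ell(h_xi(x_i, z), y_i)."""
    total = 0.0
    for x, y in data:
        probs = retriever_probs(theta, x, store)
        total += sum(p * cross_entropy(predictor_logits(xi, x, z), y)
                     for p, z in zip(probs, store))
    return total / len(data)


def joint_erm(thetas, xis, data, store):
    # Brute-force argmin over finite classes Theta x Xi (illustration only).
    return min(((t, w) for t in thetas for w in xis),
               key=lambda pair: joint_empirical_risk(pair[0], pair[1], data, store))
```

Because the retrieval weights multiply the per-document losses, a retriever that concentrates on documents aligned with the input lowers the objective whenever the predictor benefits from those documents.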

Recall that the optimal predictor with the best possible retrieval is $\xi^{\ast}=\operatorname*{arg\,min}_{\xi\in\Xi}\mathbb{E}_{X}\big[\min_{z\in\mathscr{I}}g_{\xi}(X,z)\big]$. We denote the optimal retriever for $\xi^{\ast}$ as $\theta(\xi^{\ast})=\operatorname*{arg\,min}_{\theta\in\Theta}\mathbb{E}_{X}\big[\mathbb{E}_{Z\sim p_{\theta}(\cdot|X)}g_{\xi^{\ast}}(X,Z)\big]$.

The excess risk for the classes $\Theta$ and $\Xi$ can be bounded as

$$
\begin{aligned}
&R_{\ell,\mathscr{I}}(\hat{\xi}_{\rm joint},\hat{\theta}_{\rm joint})-R_{\ell,\mathscr{I}}(f_{{\rm opt},\mathscr{I}}^{\ell})\\
&\qquad\overset{(i)}{\leq}R_{\ell,\mathscr{I}}(\hat{\xi}_{\rm joint},\hat{\theta}_{\rm joint})-\underbrace{\Big(R_{\ell,\mathscr{I},n}(\hat{\xi}_{\rm joint},\hat{\theta}_{\rm joint})-R_{\ell,\mathscr{I},n}(\xi^{\ast}_{\rm joint},\theta^{\ast}_{\rm joint})\Big)}_{\leq 0\text{ as ERM minimizes empirical risk}}\\
&\qquad\quad-R_{\ell,\mathscr{I}}(\xi^{\ast}_{\rm joint},\theta^{\ast}_{\rm joint})+R_{\ell,\mathscr{I}}(\xi^{\ast}_{\rm joint},\theta^{\ast}_{\rm joint})-R_{\ell,\mathscr{I}}(f_{{\rm opt},\mathscr{I}}^{\ell})\\
&\qquad\overset{(ii)}{\leq}\sum_{(\theta,\xi)\in\{(\hat{\theta}_{\rm joint},\hat{\xi}_{\rm joint}),(\theta^{\ast}_{\rm joint},\xi^{\ast}_{\rm joint})\}}\big|R_{\ell,\mathscr{I}}(\xi,\theta)-R_{\ell,\mathscr{I},n}(\xi,\theta)\big|+R_{\ell,\mathscr{I}}(\xi^{\ast}_{\rm joint},\theta^{\ast}_{\rm joint})-R_{\ell,\mathscr{I}}(f_{{\rm opt},\mathscr{I}}^{\ell})\\
&\qquad\overset{(iii)}{\leq}\sum_{(\theta,\xi)\in\{(\hat{\theta}_{\rm joint},\hat{\xi}_{\rm joint}),(\theta^{\ast}_{\rm joint},\xi^{\ast}_{\rm joint})\}}\big|R_{\ell,\mathscr{I}}(\xi,\theta)-R_{\ell,\mathscr{I},n}(\xi,\theta)\big|+R_{\ell,\mathscr{I}}(\xi^{\ast},\theta(\xi^{\ast}))-R_{\ell,\mathscr{I}}(f_{{\rm opt},\mathscr{I}}^{\ell})\\
&\qquad\overset{(iv)}{\leq}\underbrace{\sum_{(\theta,\xi)\in\{(\hat{\theta}_{\rm joint},\hat{\xi}_{\rm joint}),(\theta^{\ast}_{\rm joint},\xi^{\ast}_{\rm joint})\}}\big|R_{\ell,\mathscr{I}}(\xi,\theta)-R_{\ell,\mathscr{I},n}(\xi,\theta)\big|}_{\text{Generalization Error}}\\
&\qquad\quad+\underbrace{R_{\ell,\mathscr{I}}(\xi^{\ast},\theta(\xi^{\ast}))-\mathbb{E}_{X}\big[\min_{z\in\mathscr{I}}g_{\xi^{\ast}}(X,z)\big]}_{\text{retriever error}}+\underbrace{\mathbb{E}_{X}\big[\min_{z\in\mathscr{I}}g_{\xi^{\ast}}(X,z)\big]-R_{\ell,\mathscr{I}}(f_{{\rm opt},\mathscr{I}}^{\ell})}_{\text{predictor error}}
\end{aligned}
$$

In inequality $(iii)$, we replace $(\xi^{\ast}_{\rm joint},\theta^{\ast}_{\rm joint})$ by the pair $(\xi^{\ast},\theta(\xi^{\ast}))$: since $(\xi^{\ast}_{\rm joint},\theta^{\ast}_{\rm joint})$ minimizes the population risk over $\Xi\times\Theta$, the pair $(\xi^{\ast},\theta(\xi^{\ast}))$ can only have a higher (or equal) risk. Step $(iv)$ adds and subtracts $\mathbb{E}_{X}[\min_{z\in\mathscr{I}}g_{\xi^{\ast}}(X,z)]$. For the pair $(\xi^{\ast},\theta(\xi^{\ast}))$, the predictor error is easily controlled. Also, note that the retriever $\theta(\xi^{\ast})$ is optimized for the optimal predictor $\xi^{\ast}$. Therefore, unlike the fixed-predictor case in Section B.1, we do not incur an additional predictor error. We next bound the generalization and approximation errors separately by combining the retriever and predictor errors derived earlier.
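The population part of the chain, steps $(iii)$ and $(iv)$, can be checked on a toy problem with a finite input space and finite classes. In the sketch below, each predictor $\xi$ is represented directly by its table of conditional losses $g_{\xi}(x,z)$ and each retriever $\theta$ by its table $p_{\theta}(z|x)$; these tabular representations, the hypothetical helper names, and the stand-in value `opt_risk` for $R_{\ell,\mathscr{I}}(f^{\ell}_{{\rm opt},\mathscr{I}})$ are illustrative assumptions. The code returns the excess risk of the joint population optimum together with the retriever and predictor errors, so that the inequality can be verified.

```python
from itertools import product


def population_risk(px, g, p):
    # R(xi, theta) = E_X[ E_{Z ~ p_theta(.|X)} g_xi(X, Z) ] on a finite input space.
    return sum(px[x] * sum(p[x][z] * g[x][z] for z in range(len(g[x])))
               for x in range(len(px)))


def best_retrieval_risk(px, g):
    # E_X[ min_z g_xi(X, z) ]: risk of xi under oracle (best possible) retrieval.
    return sum(px[x] * min(g[x]) for x in range(len(px)))


def decompose(px, predictors, retrievers, opt_risk):
    """Return (joint excess risk, retriever error, predictor error).

    predictors: list of tables g[x][z]; retrievers: list of tables p[x][z];
    opt_risk: assumed stand-in for R(f_opt), no larger than any best-retrieval risk.
    """
    xi_star = min(predictors, key=lambda g: best_retrieval_risk(px, g))
    theta_of_xi_star = min(retrievers, key=lambda p: population_risk(px, xi_star, p))
    joint_risk = min(population_risk(px, g, p) for g, p in product(predictors, retrievers))
    retr_err = population_risk(px, xi_star, theta_of_xi_star) - best_retrieval_risk(px, xi_star)
    pred_err = best_retrieval_risk(px, xi_star) - opt_risk
    return joint_risk - opt_risk, retr_err, pred_err
```

With two equally likely inputs, $g_1=[[0.1,0.9],[0.8,0.2]]$, a constant predictor $g_2\equiv 0.5$, retrievers $[[0.5,0.5],[0.5,0.5]]$ and $[[0.9,0.1],[0.1,0.9]]$, and `opt_risk=0.1`, the retriever class contains no perfect retriever, so the retriever error is strictly positive ($0.07$) while the predictor error is $0.05$, and the joint excess risk ($0.12$) does not exceed their sum.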

B.3.1 Generalization Error

First, since the pair $(\theta^{\ast}_{\rm joint},\xi^{\ast}_{\rm joint})$ does not depend on the training sample, a concentration argument bounds its generalization error, with probability at least $1-\delta$, as

$$
\Big|\mathbb{E}_{X}\big[\mathbb{E}_{Z\sim p_{\theta^{\ast}_{\rm joint}}(\cdot|X)}g_{\xi^{\ast}_{\rm joint}}(X,Z)\big]-\frac{1}{n}\sum_{i\in[n]}\sum_{z\in\mathscr{I}}p_{\theta^{\ast}_{\rm joint}}(z|x_{i})\,\ell\big(h_{\xi^{\ast}_{\rm joint}}(x_{i},z),y_{i}\big)\Big|\leq 3\ell_{\max}\sqrt{\tfrac{\log(2/\delta)}{n}}.
$$
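The bound above for a fixed pair only requires concentration of an average of i.i.d. terms. Assuming each per-sample term $\sum_{z}p_{\theta}(z|x_i)\,\ell(h_{\xi}(x_i,z),y_i)$ lies in $[0,\ell_{\max}]$, the following simulation estimates how often the deviation exceeds $3\ell_{\max}\sqrt{\log(2/\delta)/n}$. It uses only the Python standard library; the scaled-Bernoulli loss distribution and the helper name `violation_rate` are arbitrary illustrative choices.

```python
import math
import random


def violation_rate(n, ell_max, delta, trials=2000, seed=0):
    """Fraction of trials where |empirical mean - true mean| exceeds
    3 * ell_max * sqrt(log(2/delta) / n), for i.i.d. terms in [0, ell_max].

    Per-sample terms are ell_max * Bernoulli(0.25), an illustrative assumption,
    so the true mean is ell_max / 4.
    """
    rng = random.Random(seed)  # fixed seed for reproducibility
    true_mean = ell_max / 4
    radius = 3 * ell_max * math.sqrt(math.log(2 / delta) / n)
    violations = 0
    for _ in range(trials):
        sample_mean = sum(ell_max * (rng.random() < 0.25) for _ in range(n)) / n
        violations += abs(sample_mean - true_mean) > radius
    return violations / trials
```

The observed violation rate should stay below $\delta$. Hoeffding's inequality alone already gives the smaller radius $\ell_{\max}\sqrt{\log(2/\delta)/(2n)}$, so the constant $3$ leaves ample slack.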

Next, the generalization error for the sample-dependent pair $(\hat{\xi},\hat{\theta})=(\hat{\xi}_{\rm joint},\hat{\theta}_{\rm joint})$ is bounded uniformly over $\Theta\times\Xi$ as follows.

$$
\begin{aligned}
&\Big|\mathbb{E}_{X}\big[\mathbb{E}_{Z\sim p_{\hat{\theta}}(\cdot|X)}g_{\hat{\xi}}(X,Z)\big]-\frac{1}{n}\sum_{i\in[n]}\sum_{z\in\mathscr{I}}p_{\hat{\theta}}(z|x_{i})\,\ell\big(h_{\hat{\xi}}(x_{i},z),y_{i}\big)\Big|\\
&\qquad\overset{(i)}{\leq}2\,\mathbb{E}_{\bm{\sigma}}\Big[\max_{(\theta,\xi)\in\Theta\times\Xi}\frac{1}{n}\sum_{i\in[n]}\sigma_{i}\sum_{z\in\mathscr{I}}p_{\theta}(z|x_{i})\,\ell\big(h_{\xi}(x_{i},z),y_{i}\big)\Big]+3\ell_{\max}\sqrt{\tfrac{\log(2/\delta)}{n}}\\
&\qquad\overset{(ii)}{\leq}2\inf_{\varepsilon\in[0,c_{\max}/2]}\Big(4\varepsilon+\tfrac{12}{\sqrt{n}}\int_{\varepsilon}^{c_{\max}/2}\sqrt{\log\big(\mathcal{N}(\Theta\times\Xi,\nu,\|\cdot\|_{2,[n]})\big)}\,d\nu\Big)+3\ell_{\max}\sqrt{\tfrac{\log(2/\delta)}{n}}.
\end{aligned}
\qquad (42)
$$

Inequality $(i)$ follows from symmetrization together with a concentration bound, and inequality $(ii)$ again follows from covering-number-based bounds with chaining (see Shalev-Shwartz and Ben-David [2014]). Here, we have used

$$
c_{\max}=\sup_{(\theta,\xi)\in\Theta\times\Xi}\Big(\sum_{i\in[n]}\big(\sum_{z\in\mathscr{I}}p_{\theta}(z|x_{i})\,\ell\big(h_{\xi}(x_{i},z),y_{i}\big)\big)^{2}\Big)^{1/2},
$$

and $\mathcal{N}(\Theta\times\Xi,\nu,\|\cdot\|_{2,[n]})$ denotes the covering number of the joint retriever and predictor class $\Theta\times\Xi$ at scale $\nu$ in the $L_{2}$ norm w.r.t. the set $\{(x_{i},y_{i}):i\in[n]\}$, where each pair $(\theta,\xi)$ is identified with the array $u_{i,z}=p_{\theta}(z|x_{i})\,\ell(h_{\xi}(x_{i},z),y_{i})$, i.e.,

$$
\|\mathbf{u}\|_{2,[n]}:=\Big(\sum_{i\in[n]}\big(\sum_{z\in\mathscr{I}}u_{i,z}\big)^{2}\Big)^{1/2},\quad\forall\,\mathbf{u}\in\mathbb{R}^{n\times|\mathscr{I}|}.
$$
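The norm $\|\cdot\|_{2,[n]}$ first sums each row over the data store, which for $u_{i,z}=p_{\theta}(z|x_i)\,\ell(h_{\xi}(x_i,z),y_i)$ yields the retrieval-averaged loss of sample $i$, and only then aggregates over samples. The sketch below computes this norm and $c_{\max}$ for finite classes. Representing retrievers and predictors by explicit $n\times|\mathscr{I}|$ tables of retrieval probabilities and losses, and the helper names, are assumptions for illustration.

```python
import math


def norm_2n(u):
    """||u||_{2,[n]} = ( sum_i ( sum_z u[i][z] )^2 )^{1/2} for an n x |I| array u."""
    return math.sqrt(sum(sum(row) ** 2 for row in u))


def loss_array(p_theta, losses):
    # u[i][z] = p_theta(z | x_i) * ell(h_xi(x_i, z), y_i); both inputs are n x |I| tables.
    return [[p * l for p, l in zip(p_row, l_row)]
            for p_row, l_row in zip(p_theta, losses)]


def c_max(retriever_tables, loss_tables):
    # Maximum of ||u||_{2,[n]} over all (theta, xi) pairs of finite classes:
    # one retrieval table per theta and one loss table per xi.
    return max(norm_2n(loss_array(p, l)) for p in retriever_tables for l in loss_tables)
```

For instance, the array $[[1,2],[3,-1]]$ has row sums $3$ and $2$, so its norm is $\sqrt{13}$.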

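The covering number in Equation (42) can be bounded in terms of the retriever and predictor learning complexities (by combining $\nu/2$-covers of the two classes via the triangle inequality and using $\sqrt{a+b}\leq\sqrt{a}+\sqrt{b}$) as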

$$
\sqrt{\log\big(\mathcal{N}(\Theta\times\Xi,\nu,\|\cdot\|_{2,[n]})\big)}\leq\max_{\xi\in\Xi}\sqrt{\log\big(\mathcal{N}(\Theta,\nu/2,\|\cdot\|_{2,[n],\xi})\big)}+\max_{\theta\in\Theta}\sqrt{\log\big(\mathcal{N}(\Xi,\nu/2,\|\cdot\|_{2,[n],\theta})\big)}.
$$

This implies that the generalization error of joint learning is (orderwise) bounded by the sum of the generalization error of retriever learning (cf. (30)) and predictor learning (cf. (40)).
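The chaining term in (42) can be evaluated numerically once a bound on the log covering number is available. The sketch below approximates $2\inf_{\varepsilon}\big(4\varepsilon+\tfrac{12}{\sqrt{n}}\int_{\varepsilon}^{c_{\max}/2}\sqrt{\log\mathcal{N}(\nu)}\,d\nu\big)$ with a midpoint rule and a grid over $\varepsilon$, and builds the joint bound from per-class bounds at scale $\nu/2$ as in the covering-number inequality above. The parametric form $\log\mathcal{N}(\nu)\leq D\log(1+2B/\nu)$ for each class, with dimension $D$ and radius $B$, is an assumption made only for illustration, as are the helper names.

```python
import math


def chaining_bound(n, c_max, log_cover, grid=400):
    """Approximate 2 * inf_eps ( 4 eps + 12/sqrt(n) * int_eps^{c_max/2} sqrt(log N(nu)) dnu ).

    log_cover(nu) is an assumed, user-supplied bound on log N(Theta x Xi, nu, ||.||_{2,[n]}).
    eps ranges over a uniform grid on [0, c_max/2]; the integral uses the midpoint rule.
    """
    upper = c_max / 2
    h = upper / grid
    mids = [(k + 0.5) * h for k in range(grid)]  # midpoints never hit nu = 0
    vals = [math.sqrt(max(log_cover(nu), 0.0)) for nu in mids]
    best = float("inf")
    for j in range(grid + 1):  # eps = j * h
        tail = h * sum(vals[j:])  # midpoint rule on [eps, c_max / 2]
        best = min(best, 4 * j * h + 12 / math.sqrt(n) * tail)
    return 2 * best


def parametric_log_cover(dim_ret, dim_pred, radius):
    # Assumed per-class bounds log N(., nu) <= dim * log(1 + 2 * radius / nu), evaluated at
    # nu / 2 and combined via sqrt(log N_joint(nu)) <= sqrt(log N_ret) + sqrt(log N_pred).
    def log_cover(nu):
        half_scale_term = math.log(1 + 4 * radius / nu)  # log(1 + 2 * radius / (nu / 2))
        s = math.sqrt(dim_ret * half_scale_term) + math.sqrt(dim_pred * half_scale_term)
        return s * s
    return log_cover
```

Under this assumed form, increasing $n$ shrinks the chaining term, while for small enough $n$ the infimum is attained at $\varepsilon=c_{\max}/2$, where the bound reduces to the trivial value $4c_{\max}$.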

B.3.2 Approximation error

Under our decomposition and Assumptions B.1 and B.2, the approximation errors of the predictor and the retriever decouple. Hence, the approximation error is bounded by the sum of the approximation error of retriever learning with the optimal predictor and the approximation error of predictor learning. Our approximation-error bound for the retriever holds uniformly over all predictors, so in particular it holds for the optimal predictor. This implies that the excess risk of joint retriever and predictor learning is bounded (orderwise) by the sum of the predictor and retriever errors derived earlier in (29) and (38).

Proof of Theorem 3.3.

We define $f_{\mathcal{N}}(\nu;\mathcal{A},\mathcal{B})=\sup_{b\in\mathcal{B}}\sqrt{\log\big(\mathcal{N}(\mathcal{A},\nu,\|\cdot\|_{2,[n],b})\big)}$. Putting the approximation and generalization errors together, we obtain the final excess risk bound as

$$
\begin{aligned}
\Delta_{\ell,\xi}(\hat{\xi},\hat{\theta})
&\leq 3\ell_{\max}\Big(\tfrac{1}{n}+\sqrt{\tfrac{\log(n)}{n}}\Big)+\inf_{\varepsilon\in[0,\frac{\ell_{\max}}{2}]}\Big(8\varepsilon+\tfrac{24}{\sqrt{n}}\int_{\varepsilon}^{\frac{\ell_{\max}}{2}}\big(f_{\mathcal{N}}(\tfrac{\nu}{2};\Theta,\Xi)+f_{\mathcal{N}}(\tfrac{\nu}{2};\Xi,\Theta)\big)\,d\nu\Big)\\
&\quad+\inf_{\theta\in\Theta}\inf_{\tau>0}\Big(\ell_{\max}\|r_{\theta}+\tau\,\mathrm{gap}_{\xi}\|_{\infty}+\frac{\log(|\mathscr{I}|)}{\tau^{2}}\Big)\\
&\quad+\inf_{\xi\in\Xi}2\,\mathbb{E}_{X}\Big[\max_{y\in\mathscr{Y}}\big|h_{\xi}^{y}(X,z^{*}(X))-h_{*}^{y}(X,z^{*}(X))\big|\Big]+(|\mathscr{Y}|-1)\exp(-\ell_{\max})+c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}\exp(\ell_{\max}).
\end{aligned}
$$

This completes the proof. ∎

B.3.3 Instantiation of MLP retriever and predictor

When both the retriever and the predictor are MLPs, we can reuse the earlier analysis to obtain the excess risk bound.

Proof of Theorem 3.4.

Recall from Appendix B.1.3 (Equation (32)) that a retriever MLP with depth $L_{\rm ret}$ and width $O(d_x+d_z)$ has approximation error $O\big(\ell_{\max}L_{\rm ret}^{-\frac{4\kappa}{3(d_x+d_z)}}\log^{1/3}(|\mathscr{I}|)\big)$ and generalization error $O\Big(\frac{\ell_{\max}LW\sqrt{\log(LW)\log(n|\mathscr{I}|)}}{\sqrt{n}}\Big)$.

Similarly, from Appendix B.2.3 (Equation (39)), an MLP predictor with depth $L_{\rm pred}$ and width $O(|\mathscr{Y}|(d_x+d_z))$ has approximation error $O\big(L_{\rm pred}^{-2\kappa_{\mathscr{I}}/(d_x+d_z)}+(|\mathscr{Y}|-1)\exp(-\ell_{\max})+c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}\exp(\ell_{\max})\big)$ and generalization error $O\Big(\frac{\ell_{\max}\sqrt{(L_{\rm pred}\log(|\mathscr{Y}|)+L_{\rm pred}|\mathscr{Y}|\log(L_{\rm pred}|\mathscr{Y}|))\log(n|\mathscr{I}||\mathscr{Y}|)}}{\sqrt{n}}\Big)$.

Thus, the combined error in this case is given as

$$
\begin{aligned}
\Delta_{\ell,\mathscr{I}}(\hat{\xi},\hat{\theta}) &\leq \tilde{O}\Big(\frac{\ell_{\max}}{\sqrt{n}}\big(L_{\rm ret}+L_{\rm pred}|\mathscr{Y}|\big)\Big)+O\Big(\ell_{\max}L_{\rm ret}^{-\frac{4\kappa}{3(d_x+d_z)}}\log^{1/3}(|\mathscr{I}|)\Big)\\
&\quad+O\Big(L_{\rm pred}^{-\frac{2\kappa_{\mathscr{I}}}{d_x+d_z}}+(|\mathscr{Y}|-1)\exp(-\ell_{\max})+c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}\exp(\ell_{\max})\Big).
\end{aligned}
$$

This completes the proof. ∎
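As a sanity check on how the depths enter the final rates, one can balance each estimation term in the bound above against the matching approximation term. This is a heuristic sketch, not a step of the original proof: it treats $\ell_{\max}$ and $\log|\mathscr{I}|$ as logarithmic factors absorbed into $\tilde{O}$, drops constants, and uses the shorthand $d := d_x+d_z$ introduced here.

$$
\begin{aligned}
\frac{L_{\rm ret}}{\sqrt{n}} = L_{\rm ret}^{-\frac{4\kappa}{3d}}
&\;\Longrightarrow\; L_{\rm ret} = n^{\frac{3d}{2(3d+4\kappa)}},
\qquad \frac{L_{\rm ret}}{\sqrt{n}} = n^{-\frac{2\kappa}{3d+4\kappa}},\\
\frac{L_{\rm pred}|\mathscr{Y}|}{\sqrt{n}} = L_{\rm pred}^{-\frac{2\kappa_{\mathscr{I}}}{d}}
&\;\Longrightarrow\; L_{\rm pred} = \Big(\frac{\sqrt{n}}{|\mathscr{Y}|}\Big)^{\frac{d}{d+2\kappa_{\mathscr{I}}}},
\qquad L_{\rm pred}^{-\frac{2\kappa_{\mathscr{I}}}{d}} = |\mathscr{Y}|^{\frac{2\kappa_{\mathscr{I}}}{d+2\kappa_{\mathscr{I}}}}\, n^{-\frac{\kappa_{\mathscr{I}}}{d+2\kappa_{\mathscr{I}}}}.
\end{aligned}
$$

The two resulting rates match the $n^{-\frac{2\kappa}{3(d_x+d_z)+4\kappa}}$ term common to both cases of (43) and the $|\mathscr{Y}|$-dependent term of its first case; the remaining $\ell_{\max}$-dependent terms are controlled by the choice of $\ell_{\max}$.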

Finally, setting $\ell_{\max}=\log(|\mathscr{Y}|)+\frac{\kappa_{\mathscr{I}}}{(d_x+d_z)+2\kappa_{\mathscr{I}}}\log(n)$ and combining the excess risk of retriever learning (second term in (33)) with that of predictor learning (second term in (41)), the joint excess risk is bounded as

$$
\text{Joint Excess Risk (MLP)} \leq \begin{cases}
\tilde{O}\Big(n^{-\frac{2\kappa}{3(d_x+d_z)+4\kappa}}+|\mathscr{Y}|^{\frac{2\kappa_{\mathscr{I}}}{(d_x+d_z)+2\kappa_{\mathscr{I}}}}\,n^{-\frac{\kappa_{\mathscr{I}}}{(d_x+d_z)+2\kappa_{\mathscr{I}}}}\Big), & \text{if } |\mathscr{I}|=\Omega\Big(|\mathscr{Y}|^{\gamma_{\mathscr{I}}^{-1}}\,n^{\frac{2\kappa_{\mathscr{I}}\gamma_{\mathscr{I}}^{-1}}{(d_x+d_z)+2\kappa_{\mathscr{I}}}}\Big),\\
\tilde{O}\Big(n^{-\frac{2\kappa}{3(d_x+d_z)+4\kappa}}+|\mathscr{I}|^{-\gamma_{\mathscr{I}}}|\mathscr{Y}|\,n^{\frac{\kappa_{\mathscr{I}}}{(d_x+d_z)+2\kappa_{\mathscr{I}}}}\Big), & \text{otherwise.}
\end{cases} \qquad (43)
$$

Here, $\kappa$ is defined in Assumption B.1, and $(\kappa_{\mathscr{I}},\gamma_{\mathscr{I}})$ are defined in Assumption B.2. Also, $d_x$ is the embedding dimension of an input $x\in\mathscr{X}$ and $d_z$ is the embedding dimension of a retrieved example $z\in\mathscr{I}$.
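To see where the case split in (43) comes from, one can substitute the chosen $\ell_{\max}$ into the two $\ell_{\max}$-dependent approximation terms. The following is our own sketch of that bookkeeping, using the shorthand $a := \kappa_{\mathscr{I}}/((d_x+d_z)+2\kappa_{\mathscr{I}})$ introduced here for brevity, so that $\ell_{\max}=\log(|\mathscr{Y}|)+a\log(n)$:

$$
\begin{aligned}
(|\mathscr{Y}|-1)\exp(-\ell_{\max}) &= \frac{|\mathscr{Y}|-1}{|\mathscr{Y}|}\,n^{-a} \leq n^{-a},\\
c_{\mathscr{I}}|\mathscr{I}|^{-\gamma_{\mathscr{I}}}\exp(\ell_{\max}) &= c_{\mathscr{I}}\,|\mathscr{I}|^{-\gamma_{\mathscr{I}}}\,|\mathscr{Y}|\,n^{a}.
\end{aligned}
$$

The second term is $O(n^{-a})$ precisely when $|\mathscr{I}|^{\gamma_{\mathscr{I}}}=\Omega(|\mathscr{Y}|\,n^{2a})$, i.e., $|\mathscr{I}|=\Omega\big(|\mathscr{Y}|^{\gamma_{\mathscr{I}}^{-1}}n^{2a\gamma_{\mathscr{I}}^{-1}}\big)$, which is the condition of the first case of (43); otherwise the term $|\mathscr{I}|^{-\gamma_{\mathscr{I}}}|\mathscr{Y}|\,n^{a}$ remains, as in the second case.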

Appendix C More experiments

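| Predictor size | small | small | small | base | base | base | large | large | large |
|---|---|---|---|---|---|---|---|---|---|
| Retriever size | small | base | large | small | base | large | small | base | large |
| EMDR2 | 40.0 | 47.7 | 52.0 | 41.5 | 48.0 | 51.4 | 41.6 | 48.8 | 52.6 |
| PDist | 49.7 | 57.4 | 61.3 | 48.6 | 57.0 | 61.0 | 47.7 | 55.7 | 58.9 |
| Reverse Cross-Entropy + PG | 44.9 | 52.6 | 54.7 | 45.3 | 53.3 | 55.2 | 44.9 | 51.7 | 54.9 |
| Reverse Cross-Entropy + TopK | 48.9 | 56.8 | 60.9 | 47.9 | 55.5 | 59.6 | 46.7 | 54.3 | 58.2 |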
Table 4: Recall on NQ. We report the recall of RAMs, measured as the presence of the answer string in the retrieved passage, across various training objectives and model sizes. The top row specifies the predictor size and the second row specifies the retriever size.
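| Predictor size | small | small | small | base | base | base | large | large | large |
|---|---|---|---|---|---|---|---|---|---|
| Retriever size | small | base | large | small | base | large | small | base | large |
| EMDR2 | 46.6 | 54.7 | 62.4 | 46.1 | 55.7 | 61.6 | 46.0 | 53.9 | 59.5 |
| PDist | 59.6 | 68.6 | 72.8 | 59.1 | 61.9 | 72.2 | 56.4 | 59.3 | 69.3 |
| Reverse Cross-Entropy + PG | 58.1 | 60.7 | 70.7 | 56.9 | 66.1 | 64.2 | 54.2 | 61.4 | 61.3 |
| Reverse Cross-Entropy + TopK | 57.1 | 64.5 | 69.1 | 55.9 | 63.5 | 68.1 | 54.2 | 61.2 | 65.8 |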
Table 5: Recall on TriviaQA. We report the recall of RAMs, measured as the presence of the answer string in the retrieved passage, across various training objectives and model sizes. The top row specifies the predictor size and the second row specifies the retriever size.
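| Predictor size | small | small | small | base | base | base | large | large | large |
|---|---|---|---|---|---|---|---|---|---|
| Retriever size | small | base | large | small | base | large | small | base | large |
| Parameters | 96.4M | 170.9M | 396.4M | 258.8M | 333.3M | 558.9M | 773.6M | 848.1M | 1073.7M |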
Table 6: Parameters. We report the number of model parameters of RAMs across various model sizes. The top row specifies the predictor size and the second row specifies the retriever size.

C.1 Implementation details

Computing the objective (13), let alone its gradient, requires evaluating the retriever and predictor over the entire data-store $\mathscr{I}$, which makes it prohibitively expensive. We explore two ways to approximately compute the objective:

Top-K approximation

This approach restricts the summation to a subset of the data-store. Periodically, we compute $p_{\theta}(z|x)$ for all items $z\in\mathscr{I}$ using the current value of $\theta$. We use these scores to obtain the set $\mathscr{Z}(x_i)$ of $K$ documents with the highest (stale) scores, i.e., $\mathscr{Z}(x_i)=\mathcal{T}_K(p_{\theta}(\cdot|x_i))$, and evaluate the sum over this set:

$$
\mathscr{L}^{\textsc{RCE+TopK}}_{\mathscr{I},n}(\theta;\xi,\mathscr{I}) = -\frac{1}{n}\sum_{i\in[n]}\sum_{z\in\mathscr{Z}(x_i)} p_{\theta,\mathscr{I}}(z|x_i)\cdot\log p_{\xi}(y_i|x_i,z) \qquad (44)
$$

This approach is similar to those adopted by EMDR2 and PDist. We refresh the set every 500 training steps and use $K=64$.
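The Top-K approximation in (44) can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the retriever is represented by a hypothetical matrix of logits `scores[i, z]` over a small in-memory corpus, the predictor log-likelihoods $\log p_\xi(y_i|x_i,z)$ for the $K$ candidates are passed in as an array, and $p_{\theta,\mathscr{I}}(z|x_i)$ is approximated by a softmax renormalized over $\mathscr{Z}(x_i)$ (our assumption; (44) does not say how the normalizer is handled). The function names are hypothetical.

```python
import numpy as np


def log_softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax along the last axis."""
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def refresh_topk(stale_scores: np.ndarray, k: int) -> np.ndarray:
    """Return Z(x_i) = T_K(p_theta(.|x_i)) as an (n, k) array of item indices.

    `stale_scores` are retriever logits computed with an older theta. Softmax is
    monotone, so the k largest logits are also the k largest probabilities.
    """
    # argpartition finds the k largest entries per row without a full sort ...
    idx = np.argpartition(-stale_scores, k - 1, axis=1)[:, :k]
    # ... and we then order them by decreasing score for readability.
    order = np.argsort(-np.take_along_axis(stale_scores, idx, axis=1), axis=1)
    return np.take_along_axis(idx, order, axis=1)


def rce_topk_loss(scores: np.ndarray, topk_idx: np.ndarray,
                  log_lik: np.ndarray) -> float:
    """Top-K approximation of the reverse cross-entropy objective (44).

    scores:   (n, N) current retriever logits (fresh theta).
    topk_idx: (n, K) candidate sets from the last refresh.
    log_lik:  (n, K) predictor log-likelihoods log p_xi(y_i | x_i, z), z in Z(x_i).
    """
    cand = np.take_along_axis(scores, topk_idx, axis=1)
    p = np.exp(log_softmax(cand))  # assumption: renormalize over Z(x_i) only
    return float(-(p * log_lik).sum(axis=1).mean())
```

In a training loop, `refresh_topk` would only be called periodically (every 500 steps with $K=64$, following the text), while `rce_topk_loss` is recomputed at every step with the current retriever scores, so gradients flow through the fresh $\theta$ even though the candidate set is stale.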

Policy gradient

Motivated by the connection to RLHF/RLAIF, we propose using the policy gradient method [Sutton and Barto, 2018] to efficiently obtain an unbiased estimate of the gradient with respect to $\theta$. Since policy gradient estimates suffer from high variance [Burda et al., 2015, Grathwohl et al., 2021], we use a constant baseline [Williams, 1992] for variance reduction, i.e., our objective becomes

$$
\begin{aligned}
\mathscr{L}^{\textsc{RCE+PG}}_{\mathscr{I},n}(\theta;\xi,\mathscr{I}) &= -\frac{1}{n}\sum_{i\in[n]}\sum_{j\in[K]} p_{\theta,\mathscr{I}}(z_j(x_i)|x_i)\cdot\big[\log p_{\xi}(y_i|x_i,z_j(x_i))-b\big] \qquad (45)\\
\nabla_{\theta}\mathscr{L}^{\textsc{RCE+PG}}_{\mathscr{I},n}(\theta;\xi,\mathscr{I}) &= -\frac{1}{n}\sum_{i\in[n]}\sum_{j\in[K]} \nabla_{\theta}\log p_{\theta,\mathscr{I}}(z_j(x_i)|x_i)\cdot\big[\log p_{\xi}(y_i|x_i,z_j(x_i))-b\big],
\end{aligned}
$$

where $z_j(x_i)\sim p_{\theta}(\cdot|x_i)$ are $K$ i.i.d. samples from the retriever distribution. We use $K=64$ and $b=5$.
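The gradient expression following (45) is a score-function (REINFORCE) estimator, which can be sketched in NumPy as below. This is an illustration under assumptions of our own: the retriever is a softmax over hypothetical logits `scores[i, z]` on a small in-memory corpus, so $\nabla\log p_\theta(z|x_i)$ is taken with respect to those logits (where it equals the one-hot vector of $z$ minus the probability vector), and the predictor is a hypothetical callback `log_lik_fn(i, z)` returning $\log p_\xi(y_i|x_i,z)$ for an array of sampled indices. Like the expression in the text, the estimator sums (rather than averages) over the $K$ samples.

```python
import numpy as np


def log_softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax along the last axis."""
    shifted = x - x.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def rce_pg_grad(scores, log_lik_fn, k=64, b=5.0, rng=None):
    """Policy-gradient estimate of the RCE+PG gradient w.r.t. retriever logits.

    scores:     (n, N) retriever logits; p_theta(.|x_i) = softmax(scores[i]).
    log_lik_fn: hypothetical callback (i, z_array) -> log p_xi(y_i | x_i, z) per z.
    k, b:       number of samples K and constant baseline b (defaults follow the text).
    Returns an (n, N) array with the same shape as `scores`.
    """
    rng = np.random.default_rng() if rng is None else rng
    n, num_items = scores.shape
    probs = np.exp(log_softmax(scores))
    grad = np.zeros_like(scores)
    for i in range(n):
        # K i.i.d. samples z_j(x_i) ~ p_theta(.|x_i)
        z = rng.choice(num_items, size=k, p=probs[i])
        # reward log p_xi(y_i | x_i, z_j) minus the constant baseline b
        adv = np.asarray(log_lik_fn(i, z), dtype=float) - b
        # sum_j (onehot(z_j) - probs) * adv_j, i.e. sum_j grad log p(z_j|x_i) * adv_j
        g = np.bincount(z, weights=adv, minlength=num_items) - probs[i] * adv.sum()
        grad[i] = -g / n
    return grad
```

The expectation of this estimate is $K$ times the gradient, with respect to the logits, of $-\frac{1}{n}\sum_i\mathbb{E}_{z\sim p_\theta(\cdot|x_i)}[\log p_\xi(y_i|x_i,z)]$. The baseline $b$ leaves that expectation unchanged, because $\mathbb{E}_{z\sim p_\theta}[\nabla\log p_\theta(z|x)]=0$; it only affects the variance. The predictor is evaluated only on the sampled items, never on the whole data-store.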

C.2 Training details

Optimization. For all of our experiments, we use the Adam optimizer with weight decay (AdamW), a short warm-up period (2,000 steps), and a linear decay schedule. We use a peak learning rate of $1\times10^{-4}$, a weight decay factor of 0.1, and a batch size of 64. The total number of training steps is as follows:

  • No retriever, train predictor $\xi$: 40,000

  • Fixed retriever $\theta_0$, train predictor $\xi$: 20,000

  • Fixed predictor $\xi^{\star}(\theta_0)$, train retriever $\theta$: 20,000

  • Jointly train predictor $\xi$ and retriever $\theta$: 40,000
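The optimization recipe above can be sketched as a learning-rate schedule plus an AdamW update. This NumPy sketch is illustrative, not the training code: the linear shape of the warm-up, the decay reaching zero exactly at the final step, scaling the decoupled weight-decay term by the learning rate (as common AdamW implementations do), and the Adam moment parameters ($\beta_1=0.9$, $\beta_2=0.999$, $\epsilon=10^{-8}$, common defaults) are all assumptions not stated in the text. The peak learning rate, warm-up length, weight decay, and step counts follow the description; the function names are hypothetical.

```python
import numpy as np


def learning_rate(step: int, total_steps: int, peak_lr: float = 1e-4,
                  warmup_steps: int = 2_000) -> float:
    """Linear warm-up to peak_lr, then linear decay to 0 at total_steps (assumed shape)."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    remaining = max(total_steps - step, 0)
    return peak_lr * remaining / (total_steps - warmup_steps)


def adamw_step(theta, grad, m, v, t, lr, weight_decay=0.1,
               beta1=0.9, beta2=0.999, eps=1e-8):
    """One AdamW update with decoupled weight decay; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad          # first-moment estimate
    v = beta2 * v + (1 - beta2) * grad ** 2     # second-moment estimate
    m_hat = m / (1 - beta1 ** t)                # bias corrections
    v_hat = v / (1 - beta2 ** t)
    # Weight decay acts on the parameters directly instead of being added to grad.
    theta = theta - lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * theta)
    return theta, m, v
```

For example, the joint-training run would use `total_steps=40_000`, so `learning_rate(21_000, 40_000)` is halfway through the decay and returns 5e-5 (up to floating-point rounding).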

Initializations. We initialize the models for the different configurations as follows:

  • No retriever, train predictor $\xi$: We initialize the predictor from a public pretrained T5 checkpoint.

  • Fixed retriever $\theta_0$, train predictor $\xi$: We initialize the fixed retriever from a public pretrained GTR checkpoint and the predictor from a public pretrained T5 checkpoint.

  • Fixed predictor $\xi^{\star}(\theta_0)$, train retriever $\theta$: We initialize the fixed predictor from the final checkpoint of the previous run, i.e., “Fixed retriever $\theta_0$, train predictor $\xi$”. The retriever is initialized from a public pretrained GTR checkpoint.

  • Jointly train predictor $\xi$ and retriever $\theta$: We initialize the retriever from a public pretrained GTR checkpoint and the predictor from a public pretrained T5 checkpoint.