
License: CC BY 4.0
arXiv:2401.06799v1 [cs.CL] 09 Jan 2024

Make Prompts Adaptable: Bayesian Modeling for Vision-Language Prompt Learning with Data-Dependent Prior

Youngjae Cho¹, HeeSun Bae¹, Seungjae Shin¹, Yeo Dong Youn², Weonyoung Joo³, Il-Chul Moon¹
Abstract

Recent Vision-Language Pretrained (VLP) models have become the backbone for many downstream tasks, but they are used as frozen models without further training. Prompt learning improves a pre-trained VLP model by adding a learnable context vector to the inputs of the text encoder. In few-shot learning scenarios of the downstream task, maximum likelihood estimation (MLE) training can cause the context vector to overfit to dominant image features in the training data. This overfitting can harm generalization, especially under a distribution shift between the training and test datasets. This paper presents a Bayesian framework for prompt learning that alleviates overfitting in few-shot applications and increases the adaptability of prompts to unseen instances. Specifically, modeling a data-dependent prior enhances the adaptability of text features to both seen and unseen image features without trading off performance between them. Within this Bayesian framework, we use Wasserstein Gradient Flow to estimate the target posterior distribution, which makes our prompts flexible enough to capture the complex modes of image features. We demonstrate the effectiveness of our method on several benchmark datasets, showing statistically significant performance improvements over existing methods. The code is available at https://github.com/youngjae-cho/APP.

Introduction

Recently, Vision-Language Pretrained (VLP) models (Radford et al. 2021; Jia et al. 2021) have been used as backbones for various downstream tasks (Shen et al. 2022; Ruan, Dubois, and Maddison 2022), and these pre-trained models have adapted successfully. Because the pre-trained models are used as-is in downstream tasks, prompt learning adds a context vector to the input of the pre-trained model, so that the context vector becomes the learnable parameter that improves the pre-trained representation for the downstream task (Zhou et al. 2022b). For instance, a text input is concatenated with a context vector, and the resulting input is fed to the text encoder. The context vector is learned by back-propagation after the forward pass of the concatenated text input. Since there is only a single context vector that is not conditioned on the input, the learned context vector is a single static context for the given downstream task.

Although improving a parameter-frozen VLP model with an additional input context vector is a feasible solution, the context vector can overfit to dense regions of the image features in few-shot learning. Because text features trained by MLE struggle to capture the multiple modes of image features, they may fail to cover the sparse regions of the image features, eventually degrading performance. In addition, MLE training undermines the generalization capability of VLP models, especially when there is a distribution shift between training and testing (Zhou et al. 2022a). Although several input-conditioned prompt learning methods (Zhou et al. 2022a; Derakhshani et al. 2023) attempt to generalize to unseen datasets, they inevitably sacrifice performance on seen datasets.

Figure 1: Structure (left) and learning dynamics (right) of APP. (a) Structure view of APP. (b) Learning dynamics of APP. Multiple context vectors are particles of the approximate posterior distribution, and the image-conditioned prior guides the context vectors to capture multiple modes.

To mitigate the uncertainty inherent in few-shot learning, we propose Adaptive Particle-based Prompt Learning (APP), which performs Bayesian inference for prompt learning with a data-dependent prior, as shown in Figure 1. Regularization with this data-dependent prior steers the context vectors toward the diverse modes of the image features among the seen data instances. In addition, we approximate the posterior distribution via Wasserstein Gradient Flow, which makes the text features flexible enough to describe complex image features. Furthermore, we extend the data-dependent prior to unseen test instances to adapt to distribution shift. Adapting the context vector to unseen instances makes the model more robust to such shifts.

Our contributions are twofold:

  1. Enhancing Flexibility of Prompt: By approximating the prompt posterior with Wasserstein Gradient Flow, our context vectors can more flexibly describe complex image feature spaces.

  2. Enhancing Adaptability of Prompt: By modeling a data-dependent prior based on image feature information, the text features capture the multiple modes of seen image features and adapt to unseen image features, which improves performance on both seen and unseen datasets without a trade-off.

Preliminary

Formulation of Prompt Learning

Deterministic Prompt Learning

The goal of prompt learning, e.g., CoOp (Zhou et al. 2022b), is to adapt a given Vision-Language Pretrained model to a target task by learning context vectors. In this setting, the pre-trained VLP model is frozen, and an additional learnable input is added to the text inputs. The learning gradient is obtained from the target task, a.k.a. the downstream task. Defining $(X, Y)$ as pairs of images and their labels (i.e., text phrases), Eq. 1 gives the negative log-likelihood of prompt learning for image classification.

$$\mathcal{L}_{CE}(\theta,X,Y) = -\log p(Y\mid X,\theta) = -\sum_{i=1}^{N}\log\frac{\exp\big(\mathrm{sim}(g(\theta,y_i),f(x_i))/\tau\big)}{\sum_{k=1}^{C}\exp\big(\mathrm{sim}(g(\theta,y_k),f(x_i))/\tau\big)} \tag{1}$$

Here, $f$ and $g$ are the image and text encoders of the VLP model, respectively; both are frozen during prompt learning. Therefore, the only learnable part is the context vector $\theta$. $\theta \in \mathbb{R}^d$ is learned as a single vector for each downstream task, shared across all data instances. Since $g$ is often implemented as a transformer that accepts sequential inputs of any length, $g$ does not need to be modified to accept $y$ and $\theta$. Additionally, $\mathrm{sim}(\cdot,\cdot)$ denotes cosine similarity, and $\tau$ is the temperature.
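A minimal sketch of Eq. 1, assuming the frozen encoders have already produced the image features $f(x_i)$ and the per-class text features $g(\theta, y_k)$ as plain Python lists. The function names and the default $\tau = 0.01$ are illustrative placeholders, not taken from the paper:

```python
import math

def cosine(u, v):
    """Cosine similarity sim(u, v) between two vectors given as lists."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def coop_loss(image_feats, labels, text_feats, tau=0.01):
    """Eq. 1: -sum_i log softmax_k(sim(g(theta, y_k), f(x_i)) / tau)[y_i].

    image_feats: list of image features f(x_i)
    labels:      list of class indices y_i
    text_feats:  list over classes of text features g(theta, y_k)
    """
    loss = 0.0
    for x, y in zip(image_feats, labels):
        logits = [cosine(t, x) / tau for t in text_feats]
        mx = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_z = mx + math.log(sum(math.exp(l - mx) for l in logits))
        loss -= logits[y] - log_z
    return loss
```

In CoOp-style training, only $\theta$ would receive gradients of this loss through the frozen text encoder; the sketch only evaluates the loss.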

From Eq. 1, we define the prompt as $\{\theta, y_i\}$, the concatenation of the context vector and the label. By minimizing Eq. 1, $\theta$ is learned to maximize the alignment between the image feature $f(x)$ and the text feature $g(\theta, y_i)$. Therefore, prompt learning adapts the model from the input side rather than through the frozen VLP model parameters. Several follow-up works have built on CoOp. For example, CoCoOp (Zhou et al. 2022a) extends CoOp by learning an image-conditional context vector $\theta + \phi(f(x_i))$, where $\phi$ is a neural network that maps an image feature to the prompt space.
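To make the concatenation $\{\theta, y_i\}$ and the CoCoOp shift concrete, the hypothetical helpers below treat $\theta$ as a short sequence of context-token embeddings (an assumption; the text above writes $\theta$ as a single vector in $\mathbb{R}^d$) and add the same image-conditioned shift $\phi(f(x_i))$ to every context token. The argument `meta_net` stands in for $\phi$ and is illustrative only:

```python
def build_prompt(context_tokens, class_tokens):
    """CoOp-style prompt {theta, y_i}: context-token embeddings followed by class-name token embeddings."""
    return [list(tok) for tok in context_tokens] + [list(tok) for tok in class_tokens]

def cocoop_context(context_tokens, image_feat, meta_net):
    """CoCoOp-style conditional context theta + phi(f(x_i)); the same shift is added to every context token."""
    shift = meta_net(image_feat)
    return [[c + s for c, s in zip(tok, shift)] for tok in context_tokens]

# Usage with a hypothetical meta-network phi that halves the image feature
theta = [[0.0, 0.0], [1.0, 1.0]]
prompt = build_prompt(cocoop_context(theta, [2.0, 4.0], lambda f: [0.5 * v for v in f]), [[9.0, 9.0]])
```

The resulting `prompt` would then be passed to the frozen text encoder $g$; only the context tokens (and, for CoCoOp, $\phi$) are trainable.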

Probabilistic Prompt Learning

Given only a few training instances, a point estimate of the text feature given a prompt has difficulty capturing unseen image features. ProDA (Lu et al. 2022) is the first probabilistic prompt learning model: it approximates the text features given prompts with a Gaussian distribution and uses a regularizer to enhance the diversity of the text features. PLOT (Chen et al. 2023) formulates prompt learning as optimal transport, where image and text features are defined as discrete distributions of Dirac measures. The text features, given by multiple prompts, are assigned to local image features so as to learn diverse semantics.
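As an illustration of the optimal-transport view in PLOT, the sketch below computes an entropic transport plan with Sinkhorn iterations between uniform discrete measures over local image features and prompt text features, using the cost $1 - \mathrm{sim}$. The cost, the uniform marginals, and `eps` are assumptions for illustration and may differ from PLOT's actual configuration:

```python
import math

def cosine(u, v):
    """Cosine similarity between two vectors given as lists."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def cost_matrix(image_locals, prompt_feats):
    """C[i][j] = 1 - cos(local image feature i, prompt text feature j)."""
    return [[1.0 - cosine(x, t) for t in prompt_feats] for x in image_locals]

def sinkhorn(cost, eps=0.1, iters=200):
    """Entropic OT plan between uniform measures on the rows and columns of `cost`."""
    n, m = len(cost), len(cost[0])
    a, b = [1.0 / n] * n, [1.0 / m] * m  # uniform marginals (assumption)
    K = [[math.exp(-c / eps) for c in row] for row in cost]  # Gibbs kernel
    u, v = [1.0] * n, [1.0] * m
    for _ in range(iters):
        # Alternately rescale rows and columns to match the marginals
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]

def ot_distance(cost, plan):
    """Transport cost <plan, cost>, which could serve as an image-text distance."""
    return sum(p * c for prow, crow in zip(plan, cost) for p, c in zip(prow, crow))
```

When each prompt feature matches a different local feature, the plan concentrates on those pairs and the transport cost is close to zero.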

Bayesian Probabilistic Prompt Learning

Since MLE training can overfit with a small training dataset, Bayesian inference is needed to mitigate the high variance arising from such limited data. BPL (Derakhshani et al. 2023) is the first prompt learning model from the perspective of Bayesian inference; it uses variational inference to approximate the posterior distribution with a parameterized Gaussian distribution. The objective function of BPL is defined as follows:

$$\mathbb{E}_{q(r\mid X)}\big[\log p(Y\mid X,\theta,r)\big] - D_{KL}\big(q(r\mid X)\,\|\,p(r)\big) \tag{2}$$

where $r$ is a random variable conditioned on $X$ that is added to the deterministic $\theta$ to turn the context into a random variable, via the reparameterization trick. In detail, $r$ follows $q(r\mid X)$, a Gaussian distribution with mean $m(f(X))$ and covariance $\Sigma(f(X))$, both of which are functions of the image feature $f(X)$. The prior over $r$ is $p(r)$, the standard Gaussian distribution $\mathcal{N}(0, I)$.
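The two ingredients of Eq. 2 can be sketched as follows, assuming a diagonal covariance $\Sigma(f(X)) = \mathrm{diag}(\exp(\text{log\_var}))$ and treating $m(f(X))$ and the log-variances as already-computed lists. The helper names are hypothetical:

```python
import math
import random

def sample_context(theta, mean, log_var, rng):
    """theta + r with r = m(f(X)) + sigma(f(X)) * eps, eps ~ N(0, I) (reparameterization; diagonal Sigma assumed)."""
    return [t + m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for t, m, lv in zip(theta, mean, log_var)]

def kl_to_standard_normal(mean, log_var):
    """D_KL(N(m, diag(exp(log_var))) || N(0, I)) = 0.5 * sum(exp(lv) + m^2 - 1 - lv)."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mean, log_var))
```

The ELBO in Eq. 2 would average the log-likelihood over samples from `sample_context` and subtract `kl_to_standard_normal`; because the noise enters additively, gradients flow to $m$ and $\Sigma$.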

Wasserstein Gradient Flow

Bayesian inference mitigates uncertainty by modeling the posterior distribution of the parameters. Often, the main hurdle of Bayesian inference is computing the posterior distribution, which becomes difficult when either the prior or the likelihood is flexible and the two are not conjugate. To make the posterior distribution flexible and multi-modal, we need an inference tool for such complex posteriors.

For instance, the JKO scheme (Jordan, Kinderlehrer, and Otto 1998) interprets variational inference as a gradient flow in Wasserstein space that minimizes the KL divergence between the variational distribution $q$ and the true posterior distribution $\pi \propto \exp(-V(\theta))$, where $V(\theta)$ is the energy function of the posterior. The learning objective $F(q)$ of this variational posterior inference is the KL divergence:

$$F(q) \coloneqq D_{KL}(q\,\|\,\pi) = \mathbb{E}_{q}\big[V(\theta) + \log q\big] + \text{const} \tag{3}$$

where the constant is the log-normalizer of $\pi$; since it does not depend on $q$, minimizing $F$ amounts to minimizing $\mathbb{E}_{q}[V(\theta) + \log q]$.

To obtain the steepest-descent direction of $F(q)$, we use the Wasserstein Gradient Flow (WGF), defined as follows:

Definition 1.

Suppose we have a Wasserstein space $\mathcal{W}_2 = (\mathcal{P}_2(\mathbb{R}^d), W_2)$, where $\mathcal{P}_2(\mathbb{R}^d) = \{\mu \in \mathcal{P}(\mathbb{R}^d) : \int \|\theta\|^2\, d\mu(\theta) < \infty\}$ is the set of probability measures with finite second moment, $\Pi(\mu_1,\mu_2)$ is the set of couplings of $\mu_1$ and $\mu_2$, and $W_2^2(\mu_1,\mu_2) = \min_{\omega \in \Pi(\mu_1,\mu_2)} \int \|\theta - \theta'\|^2\, d\omega(\theta,\theta')$.

A curve $\mu_t$ is a Wasserstein Gradient Flow for the functional $F$ if it satisfies Eq. 4.

$$\partial_t \mu_t = \nabla\cdot\Big(\mu_t \nabla_\theta \frac{\delta F(\mu_t)}{\delta \mu}\Big) = \nabla\cdot\big(\mu_t \nabla_{W_2} F(\mu_t)\big) \tag{4}$$
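As a worked check of Eq. 4 (our own derivation using the standard first variation, not a result stated in the paper), substituting $F(\mu) = \mathbb{E}_\mu[V(\theta) + \log \mu]$ from Eq. 3 recovers the Fokker-Planck equation, which is why Langevin dynamics is a natural discretization of this flow:

```latex
\frac{\delta F}{\delta \mu} = V(\theta) + \log \mu(\theta) + 1
\quad\Longrightarrow\quad
\nabla_\theta \frac{\delta F}{\delta \mu} = \nabla_\theta V(\theta) + \frac{\nabla_\theta \mu(\theta)}{\mu(\theta)},

\partial_t \mu_t
  = \nabla\cdot\big(\mu_t \nabla_\theta V\big) + \nabla\cdot\big(\nabla_\theta \mu_t\big)
  = \nabla\cdot\big(\mu_t \nabla_\theta V\big) + \Delta \mu_t .
```

Its stationary solution is $\mu \propto \exp(-V) = \pi$, since $\mu\nabla V + \nabla\mu = 0$ there, and the Euler-Maruyama discretization of the associated diffusion $d\theta_t = -\nabla V(\theta_t)\,dt + \sqrt{2}\,dW_t$ with step size $h$ is the update in Eq. 5.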

WGF can be discretized as Stochastic Gradient Langevin Dynamics (SGLD) (Welling and Teh 2011; Chen et al. 2018), as in Eq. 5, where each particle $\theta^i$ descends the energy $V$ with an added Gaussian perturbation, so that the particles are approximately distributed according to the true posterior distribution $\pi$.

$$\theta_{t+1}^{i} = \theta_{t}^{i} - h\,\nabla_{\theta_{t}^{i}}V(\theta_{t}^{i}) + \sqrt{2h}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I) \tag{5}$$
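The sketch below runs Eq. 5 on a hypothetical one-dimensional Gaussian target $\pi = \mathcal{N}(2, 0.5^2)$ with exact gradients; strictly this is unadjusted Langevin dynamics, and SGLD would replace $\nabla V$ with a minibatch estimate. All names and constants are illustrative assumptions, not part of APP:

```python
import math
import random

# Hypothetical 1-D target: pi = N(MU, S^2), i.e. V(theta) = (theta - MU)^2 / (2 S^2) + const
MU, S = 2.0, 0.5

def grad_V(theta):
    """Gradient of the Gaussian energy above."""
    return (theta - MU) / (S * S)

def langevin_step(theta, h, rng):
    """One step of Eq. 5: theta <- theta - h * grad V(theta) + sqrt(2h) * eps, eps ~ N(0, 1)."""
    return theta - h * grad_V(theta) + math.sqrt(2.0 * h) * rng.gauss(0.0, 1.0)

def run_langevin(num_particles=500, steps=1000, h=0.01, seed=0):
    """Run independent Langevin chains; unlike SVGD, the particles do not interact."""
    rng = random.Random(seed)
    particles = [rng.gauss(0.0, 1.0) for _ in range(num_particles)]
    for _ in range(steps):
        particles = [langevin_step(p, h, rng) for p in particles]
    return particles
```

With $h = 0.01$, the empirical mean and variance of `run_langevin()` should land near 2.0 and 0.25; a larger $h$ increases the discretization bias, and for this target any $h > 2S^2 = 0.5$ makes the chains diverge.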

Although the Gaussian noise promotes diversity among the particles, learning can become unstable when the learning rate is high.

Hence, this paper relies on Stein Variational Gradient Descent (SVGD) (Liu and Wang 2016), which can be viewed as a Wasserstein Gradient Flow restricted to a Reproducing Kernel Hilbert Space (RKHS) (Chen et al. 2018). In SVGD, the interaction between particles $\theta^i$ drives them toward the true posterior distribution while enforcing diversity among them.

$$\theta_{t+1}^{i} = \theta_{t}^{i} - \frac{h}{M}\sum_{j=1}^{M}\Big[K(\theta_{t}^{i},\theta_{t}^{j})\,\nabla_{\theta_{t}^{j}}V(\theta_{t}^{j}) - \nabla_{\theta_{t}^{j}}K(\theta_{t}^{i},\theta_{t}^{j})\Big] \tag{6}$$

By using SVGD to approximate the true posterior distribution, the context vectors $\theta^j$ can be optimized to follow it, thereby capturing the representation space of the image features.
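A small, self-contained illustration of Eq. 6: SVGD with an RBF kernel on a hypothetical one-dimensional bimodal target $p(x) = \tfrac12\mathcal{N}(-2,1) + \tfrac12\mathcal{N}(2,1)$, with $V = -\log p$. The target, the fixed bandwidth, and the step size are assumptions chosen for the example; Liu and Wang (2016) set the bandwidth with a median heuristic instead:

```python
import math

def grad_V(x):
    """Gradient of V = -log p for the hypothetical target p = 0.5 N(-2, 1) + 0.5 N(2, 1)."""
    a = math.exp(-0.5 * (x + 2.0) ** 2)  # unnormalized weight of the left mode
    b = math.exp(-0.5 * (x - 2.0) ** 2)  # unnormalized weight of the right mode
    return (a * (x + 2.0) + b * (x - 2.0)) / (a + b)

def rbf(x, y, bw):
    """RBF kernel K(x, y) = exp(-(x - y)^2 / bw)."""
    return math.exp(-((x - y) ** 2) / bw)

def svgd_step(particles, h, bw):
    """One SVGD update (Eq. 6) for 1-D particles."""
    m = len(particles)
    grads = [grad_V(x) for x in particles]
    new = []
    for xi in particles:
        total = 0.0
        for xj, gj in zip(particles, grads):
            k = rbf(xi, xj, bw)
            grad_k_wrt_xj = k * 2.0 * (xi - xj) / bw  # gradient of K(xi, xj) w.r.t. xj
            total += k * gj - grad_k_wrt_xj
        new.append(xi - h / m * total)
    return new

def run_svgd(m=50, steps=500, h=0.1, bw=1.0):
    """Deterministic, symmetric initialization on [-1, 1], then repeated SVGD steps."""
    particles = [-1.0 + 2.0 * k / (m - 1) for k in range(m)]
    for _ in range(steps):
        particles = svgd_step(particles, h, bw)
    return particles
```

Here the driving term sends each particle toward its nearer mode, and the repulsive term keeps the particles within a mode spread out instead of collapsing onto one point, so both modes near $\pm 2$ end up populated.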

Data-Dependent Prior

In Bayesian neural networks, the prior is commonly chosen as the standard normal distribution, e.g., in BPL, which is data-independent with zero mean. Since such a distribution carries no information about the data, it only regularizes the context vector toward a neighborhood of zero. In many domains (Li et al. 2020; gil Lee et al. 2022), data-dependent priors are used to make the prior knowledge more informative.

This paper uses a data-dependent prior for prompt learning that is not restricted to the standard Gaussian distribution. Specifically, we make the prior distribution depend on image features, whose distribution can have multiple modes.

Method

This section introduces Adaptive Particle-based Prompt Learning (APP) by describing its model formulation and its inference method.

Formulation of Prompt Posterior Distribution

Following the CoOp formulation (Zhou et al. 2022b), we formulate the posterior distribution of the context vector $\theta$ as Eq. 7.

$$\pi(\theta) = p(\theta\mid X,Y) = \frac{p(Y\mid X,\theta)\,p(\theta\mid X)\,p(X)}{p(X,Y)} \propto p(Y\mid X,\theta)\,p(\theta\mid X) \tag{7}$$

$\pi(\theta)$ is the true posterior distribution, which factorizes into the likelihood $p(Y\mid X,\theta)$ and the conditional prior $p(\theta\mid X)$. For the likelihood, we follow the CoOp formulation in Eq. 1, and we propose the image-feature-conditioned prior in Eq. 8, whose mean is parameterized by $\phi$ as $\overline{\phi(f)} \coloneqq \frac{1}{N}\sum_{i=1}^{N}\phi(f(x_i))$. We set the standard deviation $\sigma$ of the prior as a hyperparameter.

$$\log p(\theta\mid X) \propto -\frac{\big(\theta-\overline{\phi(f)}\big)^{T}\big(\theta-\overline{\phi(f)}\big)}{\sigma^{2}} \tag{8}$$
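A sketch of the prior in Eq. 8, assuming a hypothetical linear map $\phi$ (a weight matrix given as nested lists) and reading Eq. 8 literally, i.e., without the usual factor $1/2$ of a Gaussian log-density. The gradient is the prior part of $\nabla V$ that enters the particle updates:

```python
def phi(feature, weight):
    """Hypothetical linear map phi from image-feature space to prompt space (weight: d_prompt x d_img)."""
    return [sum(w * v for w, v in zip(row, feature)) for row in weight]

def prior_mean(features, weight):
    """Prior mean (1/N) * sum_i phi(f(x_i)) over the training image features."""
    mapped = [phi(f, weight) for f in features]
    n = len(mapped)
    return [sum(col) / n for col in zip(*mapped)]

def log_prior(theta, mean, sigma):
    """Unnormalized log p(theta | X) = -(theta - mean)^T (theta - mean) / sigma^2, following Eq. 8."""
    return -sum((t - m) ** 2 for t, m in zip(theta, mean)) / sigma ** 2

def grad_neg_log_prior(theta, mean, sigma):
    """Gradient of -log p(theta | X): 2 (theta - mean) / sigma^2."""
    return [2.0 * (t - m) / sigma ** 2 for t, m in zip(theta, mean)]
```

A smaller $\sigma$ pulls every particle more strongly toward the image-dependent mean; since the mean is computed from $f(x_i)$, the prior moves with the data rather than staying at zero.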

Bayesian Adaptation of Prompt to Test data

In few-shot learning, it is uncertain whether the test data will follow the distribution of the training data. If the two distributions differ, the difference can harm the generalization ability of the VLP model. To model the uncertainty arising from this mismatch between training and test datasets in the few-shot setting, we reformulate the posterior distribution to account for a test image $x'$. Since the training and test data are i.i.d. samples and the prior is assumed to be Gaussian, we derive the posterior distribution as Eq. 9.

$$\pi(\theta) = p(\theta\mid X,Y,x') = \frac{p(Y\mid X,\theta)\,p(\theta\mid X)\,p(\theta\mid x')\,p(X)}{p(X,Y)} \propto \underbrace{p(Y\mid X,\theta)\,p(\theta\mid X)}_{\text{Training}}\;\underbrace{p(\theta\mid x')}_{\text{Testing}} \tag{9}$$
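To see why Eq. 9 suggests a linear combination at test time, assume (our illustrative assumption; the exact form of $p(\theta\mid x')$ is not spelled out here) that the test prior has the same form as Eq. 8, with mean $\phi(f(x'))$ and the same $\sigma$. Completing the square then merges the two prior factors into one Gaussian centered at the midpoint:

```latex
\log p(\theta\mid X) + \log p(\theta\mid x')
  = -\frac{\big\|\theta-\overline{\phi(f)}\big\|^2 + \big\|\theta-\phi(f(x'))\big\|^2}{\sigma^2} + \text{const}
  = -\frac{\big\|\theta-\tfrac{1}{2}\big(\overline{\phi(f)}+\phi(f(x'))\big)\big\|^2}{\sigma^2/2} + \text{const}
```

Under this assumption, the prior part of the adapted posterior pulls $\theta$ toward $\tfrac{1}{2}\big(\overline{\phi(f)} + \phi(f(x'))\big)$ with half the variance; unequal variances would instead give a precision-weighted average of the two means.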

After approximating the posterior in Eq. 7, we adapt the context vector $\theta$ with the test data-dependent prior.

Variational Inference for Prompt Posterior

Since the posterior in Eq. 7 is intractable, we approximate $\pi$ using particle-based variational inference. Let $q$ be the probability measure of the variational distribution that generates the context vector $\theta$ in Wasserstein space. Eq. 11 defines the optimization problem of approximating the model posterior distribution with the variational distribution.

$$V(\theta) \coloneqq -\log p(Y\mid X,\theta) - \log p(\theta\mid X) \tag{10}$$
$$F(q) \coloneqq D_{KL}(q\,\|\,\pi) = \mathbb{E}_{q}\big[V(\theta) + \log q\big] + \text{const} \tag{11}$$

where the constant does not depend on $q$.

To obtain the steepest-descent direction of Eq. 11, we follow the Wasserstein Gradient Flow in Eq. 4. By solving the Wasserstein Gradient Flow in a Reproducing Kernel Hilbert Space (RKHS), we obtain the following gradient flow for the variational distribution $q$, where the linear operator is $\mathcal{K}_q T(\theta) \coloneqq \mathbb{E}_{\theta'\sim q}[K(\theta,\theta')\,T(\theta')]$.

$$\partial_t q_t = \nabla\cdot\Big(q_t\,\mathcal{K}_{q}\nabla_{\theta}\frac{\delta F(q)}{\delta q}\Big) \tag{12}$$

Following the continuity equation, the particles $\theta_t \sim q_t$ evolve as follows.

$$\partial_t\theta_t = -\mathcal{K}_q\nabla_\theta\Big(\frac{\delta F(q)}{\delta q}\Big) = -\Big[\int K(\theta,\theta')\,\nabla_{\theta'}V(\theta')\,dq(\theta') - \int \nabla_{\theta'}K(\theta,\theta')\,dq(\theta')\Big] \tag{13}$$

By discretizing Eq. 13 in time, we obtain Stein Variational Gradient Descent (Liu and Wang 2016), in which each context vector $\theta^i$ is updated as follows.

$$\theta_{t+1}^{i} = \theta_{t}^{i} - \frac{h}{M}\sum_{j=1}^{M}\Big[K(\theta_{t}^{i},\theta_{t}^{j})\,\nabla_{\theta_{t}^{j}}V(\theta_{t}^{j}) - \nabla_{\theta_{t}^{j}}K(\theta_{t}^{i},\theta_{t}^{j})\Big] \tag{14}$$

In Eq. 14, the first term is a kernel-smoothed gradient shared among the context vectors $\theta^j$, which drives them toward the true posterior distribution. The second term acts as a repulsive force between the context vectors $\theta^j$, which spreads the text features so that they cover multiple modes.
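To see the two roles in Eq. 14 separately, the hypothetical helper below returns the driving term $\frac{1}{M}\sum_j K(\theta^i,\theta^j)\nabla V(\theta^j)$ and the repulsive term $\frac{1}{M}\sum_j \nabla_{\theta^j}K(\theta^i,\theta^j)$ for one-dimensional particles, assuming an RBF kernel $K(a,b)=\exp(-(a-b)^2/\text{bw})$; the kernel choice and bandwidth are illustrative:

```python
import math

def svgd_terms(thetas, grads, i, bw=1.0):
    """Split the SVGD direction for particle i (Eq. 14) into its driving and repulsive terms (1-D)."""
    m = len(thetas)
    driving = 0.0
    repulsive = 0.0
    for tj, gj in zip(thetas, grads):
        k = math.exp(-((thetas[i] - tj) ** 2) / bw)
        driving += k * gj / m                             # (1/M) sum_j K * grad V(theta_j)
        repulsive += k * 2.0 * (thetas[i] - tj) / bw / m  # (1/M) sum_j grad_{theta_j} K(theta_i, theta_j)
    return driving, repulsive

def svgd_update(thetas, grads, i, h, bw=1.0):
    """theta_i <- theta_i - h * (driving - repulsive), i.e. Eq. 14 for a single particle."""
    driving, repulsive = svgd_terms(thetas, grads, i, bw)
    return thetas[i] - h * (driving - repulsive)
```

With $\nabla V = 0$, two nearby particles at $\pm 0.1$ receive repulsive terms of opposite sign, so the update pushes them apart even without any data term; with a single particle, the repulsive term vanishes and the update reduces to plain gradient descent on $V$.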

Parameter Training of Data-Dependent Prior

Since the prior mean is parameterized by $\phi$, we pre-train $\phi$ to map image features to the prompt space. To preserve image feature information in the prior distribution, we propose to maximize the mutual information $I(\phi(f(x)); f(x))$. In other words, this mutual information encourages the prior to capture the dependencies between the image features and the parameter of the prompt distribution. By the data processing inequality (Beaudry and Renner 2012), we obtain the following inequality:

$$I(\phi(f(X)); f(X)) \geq I(g(\phi(f(X)), \cdot); f(X)) \tag{15}$$

where $\phi$ is learned to maximize the mutual information.

Proposition 1.

Suppose that the Markov chain $f(X) \rightarrow \phi(f(X)) \rightarrow g(\phi(f(X)), \cdot)$ holds. Then the mutual information $I(f(X); \phi(f(X)))$ is lower-bounded as follows:
$$I(f(X); \phi(f(X))) \geq I(f(X); g(\phi(f(X)), \cdot)) \geq \log C - \mathcal{L}_{CE}(\phi(f(X)), X, Y)$$

Based on Proposition 1, we can maximize a lower bound on the mutual information by minimizing the cross-entropy loss $\mathcal{L}_{CE}$. The full training procedure is given in Algorithm 1.
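Proposition 1 turns mutual-information maximization into ordinary cross-entropy training of $\phi$. The NumPy sketch below only evaluates $\mathcal{L}_{CE}$ and the bound $\log C - \mathcal{L}_{CE}$. It assumes CLIP-style logits (scaled cosine similarity between $f(x)$ and the per-class text features), reads $C$ as the number of classes, and replaces both $\phi$ and the text encoder $g$ with hypothetical toy functions: a linear map `A` and a "context plus class embedding" sum. The temperature value is also an assumption. The gradient step on $\phi$ in Algorithm 1 would be handled by an autodiff framework and is omitted.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)              # numerical stability
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def prior_ce_loss(image_feats, labels, A, class_emb, temperature=100.0):
    # phi(f(x)): hypothetical linear prior network mapping image features to prompt space
    ctx = image_feats @ A                                       # (N, d)
    # Toy stand-in for the text encoder g(phi(f(x)), y_c): context plus a class embedding
    text = ctx[:, None, :] + class_emb[None, :, :]              # (N, C, d)
    text = text / np.linalg.norm(text, axis=-1, keepdims=True)
    img = image_feats / np.linalg.norm(image_feats, axis=-1, keepdims=True)
    logits = temperature * np.einsum("nd,ncd->nc", img, text)  # scaled cosine similarity
    logp = log_softmax(logits)
    return float(-logp[np.arange(len(labels)), labels].mean())

def mi_lower_bound(ce_loss, num_classes):
    # Right-hand side of Proposition 1: log C - L_CE
    return float(np.log(num_classes)) - ce_loss
```

With uniform logits (temperature 0) the loss equals $\log C$ and the bound is 0; any reduction of the cross-entropy below $\log C$ raises the bound.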

Adapting $\theta$ with a Test Data-Dependent Prior

Following the training of the posterior distribution described in the section Formulation of Prompt Posterior Distribution, we extend our approach to accommodate an unseen data instance $x'$ within the posterior distribution. This involves updating the context vector $\theta \sim q$ through a linear combination with the prior mean $\phi(f(x'))$. For simplicity, we take a weighted average of the text features, as follows. The adaptation scenario is given in Algorithm 2.

$$g(\theta^{*}, y) = \alpha\, g(\theta, y) + (1 - \alpha)\, g(\phi(f(x')), y) \tag{16}$$
Algorithm 1 Training Scenario of APP
1: Input: Dataset $\mathcal{D} = \{X, Y\}$, context vectors $\theta^i$, prior network $\phi$
2: while not converged do
3:     Compute $\mathcal{L}_{CE}(\phi(f(X)), X, Y)$
4:     Update $\phi_{t+1} = \phi_t - h\nabla_{\phi}\mathcal{L}_{CE}$
5: end while
6: while not converged do
7:     Compute $V(\theta) = -\log p(Y \mid X, \theta) - \log p(\theta \mid X)$
8:     Update $\theta_{t+1}^{i} = \theta_{t}^{i} - \frac{h}{M}\sum_{j=1}^{M}\left[K(\theta_{t}^{i}, \theta_{t}^{j})\,\nabla_{\theta_{t}^{j}}V(\theta_{t}^{j}) - \nabla_{\theta_{t}^{j}}K(\theta_{t}^{i}, \theta_{t}^{j})\right]$, $\forall i \in [1, \dots, M]$
9: end while
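Step 7 of Algorithm 1 needs $V$ and its gradient for every particle. The sketch below assumes, as the fixed precision $1/\sigma^2$ in the experiment settings suggests, that the data-dependent prior is an isotropic Gaussian $\mathcal{N}(\phi(f(X)), \sigma^2 I)$, and that its mean is computed from the averaged image features, as the conclusion describes. Under that assumption the prior contributes a quadratic term. `nll` and `grad_nll` are hypothetical callables for the likelihood term, and $\phi$ is a toy linear map `A`.

```python
import numpy as np

def prior_mean(image_feats, A):
    # Hypothetical linear prior network phi applied to the batch-averaged image feature
    return image_feats.mean(axis=0) @ A

def potential_and_grad(theta, nll, grad_nll, mu, precision=1.0):
    # V(theta) = -log p(Y|X,theta) - log p(theta|X), assuming p(theta|X) = N(mu, sigma^2 I)
    # with precision = 1/sigma^2; additive Gaussian constants are dropped.
    diff = theta - mu                                    # (M, d); mu broadcasts over particles
    v = nll(theta) + 0.5 * precision * (diff ** 2).sum(axis=-1)
    g = grad_nll(theta) + precision * diff               # prior pulls each particle toward mu
    return v, g
```

The returned gradient is what the kernelized update in step 8 consumes; the prior term pulls every context vector toward the image-conditioned mean with strength $1/\sigma^2$.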
Algorithm 2 Test Scenario of APP
1: Input: Test image instance $x'$, context vectors $\theta^i$, prior network $\phi$
2: Train as in Algorithm 1
3: Compute $g(\theta^{*}, y_j)$ as in Eq. (16), $\forall j \in [1, \dots, K]$
4: $y' = \operatorname{argmax}_{y_j} \mathrm{sim}(g(\theta^{*}, y_j), f(x'))$
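Steps 3 and 4 of Algorithm 2 can be sketched in NumPy as follows. We assume that the posterior text features $g(\theta, y_j)$ have already been aggregated over the $M$ context vectors into one (K, d) array, that `sim` is cosine similarity as in CLIP, and that the features are mixed before normalization. These choices and the function names are illustrative assumptions, not taken from the paper.

```python
import numpy as np

def normalize(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def adapt_and_predict(text_post, text_prior, image_feat, alpha=0.9):
    # Hypothetical helper for Algorithm 2, steps 3-4.
    # text_post:  (K, d) text features g(theta, y_j) from the learned posterior
    # text_prior: (K, d) text features g(phi(f(x')), y_j) from the test data-dependent prior
    # image_feat: (d,)   image feature f(x') of the test instance
    text_star = alpha * text_post + (1.0 - alpha) * text_prior   # Eq. (16)
    sims = normalize(text_star) @ normalize(image_feat)          # cosine similarity per class
    return int(np.argmax(sims)), sims
```

The default `alpha=0.9` matches the few-shot setting; the generalization experiments use 0.7. With `alpha=1.0` the prior is ignored and the prediction reduces to that of the unadapted posterior.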
Figure 2: Results of few-shot classification. Each experiment is replicated three times.

Results

Experiment Settings

We conduct three experiments: few-shot classification, domain generalization on ImageNet, and base-to-new generalization, following PLOT (Chen et al. 2023). For few-shot classification, we use 11 image datasets: Caltech101 (Fei-Fei, Fergus, and Perona 2004), DTD (Cimpoi et al. 2014), EuroSAT (Helber et al. 2019), FGVCAircraft (Maji et al. 2013), Oxford 102 Flower (Nilsback and Zisserman 2008), OxfordPets (Parkhi et al. 2012), Food101 (Bossard, Guillaumin, and Van Gool 2014), StanfordCars (Krause et al. 2013), Sun397 (Xiao et al. 2010), UCF101 (Soomro, Zamir, and Shah 2012), and ImageNet (Deng et al. 2009). We follow the training setting of PLOT (Chen et al. 2023): with 1, 2, 4, 8, and 16 training shots, we train for 50, 100, 100, 200, and 200 epochs, respectively. For ImageNet, we train the prompts for 50 epochs at every shot count. Before training the context vectors $\theta$, we train the prior mean $\phi$ for 10, 20, 20, 40, and 40 epochs for the respective shot counts. For domain generalization on ImageNet, we train prompts on ImageNet as the source dataset and report accuracy on the source dataset and on the target datasets ImageNetV2 (Recht et al. 2019), ImageNet-A (Hendrycks et al. 2021b), ImageNet-R (Hendrycks et al. 2021a), and ImageNet-Sketch (Wang et al. 2019). For base-to-new generalization, we train prompts with 16 shots on the base classes of each of the 11 datasets and report performance on both base and new classes.

As a common setting, we replicate each experiment three times and use CLIP (Radford et al. 2021) as the backbone network, with ResNet50 (He et al. 2016) as the image encoder. All context vectors have 16 dimensions and are initialized by sampling from $\mathcal{N}(0, 0.02I)$, and the class information is inserted at the end of the context vectors. We fix the precision $1/\sigma^2$ to 1.0 in all settings and use an RBF kernel for $K$. The adaptation weight $\alpha$ is set to 0.9 for few-shot classification and 0.7 for the generalization experiments. More details of the setting are given in the Appendix.

Baselines

We compare our method, APP, with CoOp (Zhou et al. 2022b), CoCoOp (Zhou et al. 2022a), PLOT (Chen et al. 2023), and ProDA (Lu et al. 2022). We do not include BPL (Derakhshani et al. 2023) as a baseline, for reasons given in the Appendix. For PLOT, ProDA, and APP, we randomly initialize four context vectors.

Few-Shot Classification

Quantitative Analysis

Figure 2 shows that our method outperforms the baselines on average on all benchmark datasets, achieving the best performance in 45 of the 55 experimental settings (11 datasets × 5 shot counts). The advantage of APP is most pronounced on Caltech101, DTD, EuroSAT, and ImageNet, which contain more diverse images. We attribute this to the data-dependent prior and the repulsive force in Eq. (14), which allow the text features to cover multiple modes of the image features, so that our context vectors learn to capture the diverse semantics of the images.

Qualitative Analysis

Figure 3: UMAP visualization of image features and text features for EuroSAT: (a) PLOT, (b) ProDA, (c) APP. (Upper) Histograms correspond to image features, and ⋆ marks the text features of all classes. (Lower) Image and text features of two arbitrarily chosen classes. Colors denote classes.

To demonstrate how well our method captures intricate image features, we visualize both image features ($f(x_i)$) and text features ($g(\theta_j, y_i)$). Figure 3 shows UMAP (McInnes, Healy, and Melville 2018) projections for the EuroSAT test set. The upper panels depict the image and text representations of all classes; a more saturated yellow indicates a denser concentration of image features. APP captures and spans the various modes of the image features. The lower panels highlight two arbitrarily selected classes and show that APP's text features align closely with the image features of each class, more so than those of the other methods.

Table 1 quantifies how well the text feature representations align with the modes of the image features. We fit k-means clustering to the test image features, treating k as a hyperparameter, and then count the number of prompts assigned to each cluster. If the prompts are adequately distributed over the modes of the image features, these counts should be similar across clusters, so a smaller variance indicates better coverage. The variance of the prompt counts shows that the text prompt representations of APP are well distributed across the image feature representations.
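The counting procedure behind Table 1 can be sketched with plain NumPy. Details the text leaves open are our assumptions: text features are assigned to the nearest k-means centroid by Euclidean distance, the reported variance is the population variance, and a deterministic farthest-point initialization stands in for whatever initialization was actually used. All function names are hypothetical.

```python
import numpy as np

def farthest_point_init(x, k):
    # Deterministic init: start at x[0], then repeatedly take the point farthest
    # from all centroids chosen so far (a simple stand-in for k-means++).
    idx = [0]
    min_d = ((x - x[0]) ** 2).sum(axis=1)
    for _ in range(1, k):
        nxt = int(min_d.argmax())
        idx.append(nxt)
        min_d = np.minimum(min_d, ((x - x[nxt]) ** 2).sum(axis=1))
    return x[idx].astype(float)

def kmeans(x, k, n_iter=50):
    # Lloyd's algorithm: alternate nearest-centroid assignment and mean updates
    centroids = farthest_point_init(x, k)
    for _ in range(n_iter):
        dists = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        for c in range(k):
            members = x[labels == c]
            if len(members) > 0:
                centroids[c] = members.mean(axis=0)
    return centroids

def prompt_count_variance(image_feats, text_feats, k):
    # Fit k-means on image features, assign each text feature to its nearest
    # centroid, and return the per-cluster counts and their population variance.
    centroids = kmeans(image_feats, k)
    dists = ((text_feats[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    counts = np.bincount(dists.argmin(axis=1), minlength=k)
    return counts, float(counts.var())
```

A lower variance means that the prompts are spread more evenly across the image-feature clusters; prompts collapsed onto a single mode produce a large variance.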

Table 1: Variance of the number of text feature representations assigned to each cluster after fitting k-means clustering to image features. Bold means the smallest of each column.
Methods k=5 k=6 k=7 k=8 k=9 k=10
PLOT 26.0 27.6 11.9 7.5 8.5 7.8
ProDA 53.2 26.6 25.3 12.0 9.8 9.0
APP $\mathbf{25.2}$ $\mathbf{21.9}$ $\mathbf{9.9}$ $\mathbf{6.3}$ $\mathbf{4.0}$ $\mathbf{2.6}$

Ablation Study

We present two ablation studies: one on the components of our method and one on the number of prompts.

To identify the key contributor, we conduct ablation studies on the two components of APP: 1) the data-dependent prior and 2) Stein Variational Gradient Descent (SVGD). In all cases, four context vectors are initialized, and MLE training optimized by SGD serves as the baseline. Table 2 shows that posterior approximation with the data-dependent prior generally improves performance over MLE training. SVGD is more robust than SGD when only a few data instances are available, and SVGD-based posterior approximation consistently outperforms SGD on Caltech101 and EuroSAT.

Table 2: An ablation study of our method, APP. Experiments are replicated three times.
Dataset Methods 1 shot 2 shots 4 shots
Caltech101 SGD 89.87 ± 0.02 90.29 ± 0.20 91.00 ± 0.29
SVGD 89.90 ± 0.07 90.56 ± 0.10 91.02 ± 0.29
SGD+Prior 90.26 ± 0.21 90.70 ± 0.25 91.01 ± 0.14
SVGD+Prior (APP) $\mathbf{90.30 \pm 0.33}$ $\mathbf{90.87 \pm 0.25}$ $\mathbf{91.05 \pm 0.45}$
EuroSAT SGD 56.43 ± 2.31 63.93 ± 2.85 72.63 ± 2.74
SVGD 58.93 ± 4.28 64.36 ± 3.16 73.64 ± 2.00
SGD+Prior 59.96 ± 4.77 63.99 ± 2.88 72.63 ± 2.76
SVGD+Prior (APP) $\mathbf{60.04 \pm 2.08}$ $\mathbf{66.02 \pm 2.70}$ $\mathbf{74.08 \pm 2.23}$
Food101 SGD 78.53 ± 0.14 78.88 ± 0.09 78.88 ± 0.25
SVGD 78.51 ± 0.12 78.86 ± 0.07 78.87 ± 0.26
SGD+Prior 78.81 ± 0.16 79.15 ± 0.07 79.32 ± 0.08
SVGD+Prior (APP) $\mathbf{78.87 \pm 0.17}$ $\mathbf{79.25 \pm 0.06}$ $\mathbf{79.43 \pm 0.10}$

We also vary the number of prompts $M$ to explore the influence of the number of samples from the posterior distribution (Table 3). Although performance tends to increase with more prompts, approximately four prompts are sufficient to achieve strong performance.

Table 3: An ablation study on the number of prompts (M). Experiments are replicated three times.
Dataset M 1 shot 2 shots 4 shots
Caltech101 2 89.19 ± 0.20 90.08 ± 0.22 90.67 ± 0.69
4 $\mathbf{90.30 \pm 0.33}$ $\mathbf{90.87 \pm 0.25}$ 91.05 ± 0.45
8 90.18 ± 0.28 90.33 ± 0.29 $\mathbf{91.38 \pm 0.16}$
EuroSAT 2 52.06 ± 4.98 60.65 ± 3.00 71.19 ± 1.78
4 $\mathbf{60.04 \pm 2.08}$ $\mathbf{66.02 \pm 2.70}$ 74.08 ± 2.23
8 58.60 ± 2.35 64.22 ± 5.32 $\mathbf{74.32 \pm 1.97}$
Food101 2 78.27 ± 0.16 78.74 ± 0.09 78.93 ± 0.12
4 78.87 ± 0.17 79.25 ± 0.06 79.43 ± 0.10
8 $\mathbf{78.90 \pm 0.21}$ $\mathbf{79.29 \pm 0.09}$ $\mathbf{79.45 \pm 0.12}$

Time and Memory Complexity Comparison

In this paragraph, we compare the time and memory costs of APP and CoCoOp to show that our approach is suitable for prompt learning in terms of efficiency. Although the data-dependent prior and conditioned prompts are similar in spirit, as illustrated in Figure 1(a), Table 4 shows that the data-dependent prior is both more efficient and more effective. This is because regularization through the data-dependent prior does not require gradient updates.

Table 4: Time and memory costs on Caltech101. M is the number of prompts, and "-" denotes a training failure due to insufficient memory. BS denotes batch size, and Acc denotes accuracy.
Methods BS Memory (GB) Time (s) Acc (%)
CoCoOp (M=1) 10 20 289 84.4
CoCoOp (M=1) 128 - - -
APP (M=4) 10 7.2 260 89.3
APP (M=4) 128 9.6 112 90.2
Table 5: Test accuracies (%) in the unseen-class generalization setting. H denotes the harmonic mean of the base accuracy and the new accuracy. Bold marks the best accuracy in each column. Due to space constraints, we report mean accuracies over three replications.
(a) Average
Method Base New H
CoCoOp 71.7 53.6 61.4
PLOT 82.2 60.5 69.7
ProDA 82.4 63.6 71.8
APP $\mathbf{83.0}$ $\mathbf{65.8}$ $\mathbf{73.4}$

(b) Caltech101
Method Base New H
CoCoOp $\mathbf{95.2}$ 87.4 91.2
PLOT 94.7 88.1 91.3
ProDA $\mathbf{95.2}$ 86.8 90.8
APP $\mathbf{95.2}$ $\mathbf{91.0}$ $\mathbf{93.0}$

(c) DTD
Method Base New H
CoCoOp 74.6 38.9 51.1
PLOT 78.1 42.7 55.2
ProDA 78.0 47.0 58.6
APP $\mathbf{78.4}$ $\mathbf{48.9}$ $\mathbf{60.2}$

(d) EuroSAT
Method Base New H
CoCoOp 91.4 35.6 51.3
PLOT 92.9 39.3 55.2
ProDA 89.6 39.0 54.4
APP $\mathbf{93.6}$ $\mathbf{47.6}$ $\mathbf{63.1}$

(e) FGVC-Aircraft
Method Base New H
CoCoOp 29.1 14.1 19.0
PLOT 43.3 20.4 27.8
ProDA 44.3 24.1 31.2
APP $\mathbf{44.9}$ $\mathbf{26.0}$ $\mathbf{33.0}$

(f) Food101
Method Base New H
CoCoOp 80.7 78.8 79.7
PLOT 83.4 84.2 83.8
ProDA 84.5 $\mathbf{86.2}$ $\mathbf{85.4}$
APP $\mathbf{84.6}$ 86.1 $\mathbf{85.4}$

(g) ImageNet
Method Base New H
CoCoOp 68.3 60.5 64.1
PLOT 68.3 58.4 62.9
ProDA 68.8 63.0 65.7
APP $\mathbf{69.9}$ $\mathbf{63.2}$ $\mathbf{66.4}$

(h) Flower102
Method Base New H
CoCoOp 94.7 58.6 72.4
PLOT $\mathbf{97.4}$ 54.2 69.6
ProDA 97.0 58.5 73.0
APP 96.8 $\mathbf{61.0}$ $\mathbf{74.8}$

(i) Oxford Pets
Method Base New H
CoCoOp 89.4 $\mathbf{91.0}$ 90.2
PLOT 95.9 87.6 91.5
ProDA 96.4 88.6 $\mathbf{92.4}$
APP $\mathbf{96.8}$ 88.3 $\mathbf{92.4}$

(j) Stanford Cars
Method Base New H
CoCoOp 68.7 51.6 58.9
PLOT 84.2 62.6 71.8
ProDA 84.5 68.1 75.5
APP $\mathbf{85.9}$ $\mathbf{69.5}$ $\mathbf{76.8}$

(k) Sun397
Method Base New H
CoCoOp 73.3 64.0 68.4
PLOT 79.8 65.3 71.8
ProDA $\mathbf{80.9}$ 70.8 75.5
APP 80.6 $\mathbf{73.3}$ $\mathbf{76.8}$

(l) UCF101
Method Base New H
CoCoOp 79.2 47.0 59.0
PLOT 86.5 62.7 72.7
ProDA $\mathbf{86.9}$ 67.9 76.2
APP 86.2 $\mathbf{69.2}$ $\mathbf{76.8}$

Generalization Experiment

VLP models are known to be relatively robust to domain shift (Radford et al. 2021), yet this property can be degraded when the model parameters are fine-tuned on a downstream task (Wortsman et al. 2022). Therefore, if prompt learning can preserve the robustness of VLP models to domain shift, the technique can be applied more broadly. To compare robustness, we conduct two experiments: 1) unseen-class generalization on 11 datasets and 2) domain generalization on ImageNet.

Unseen-Class Generalization on 11 Datasets

Following Zhou et al. (2022a), we evaluate robustness to unseen classes on 11 datasets. Table 5 reports test accuracies on both seen (base) and unseen (new) classes. APP is robust on unseen (new) classes while maintaining its performance on seen classes, whereas the other baselines trade off performance between seen and unseen classes.
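The H column in Table 5 is the harmonic mean of base and new accuracy, $H = 2 \cdot \text{Base} \cdot \text{New} / (\text{Base} + \text{New})$. The short sketch below (the function name is our own) computes it; applied to the average APP row it reproduces the reported value. Because the table reports means over three replications, recomputing H from the rounded means can differ from some per-dataset entries in the last digit.

```python
def harmonic_mean(base: float, new: float) -> float:
    """Harmonic mean H of base-class and new-class accuracy (both in %)."""
    if base + new == 0:
        return 0.0
    return 2.0 * base * new / (base + new)

# Average row of APP in Table 5: Base 83.0, New 65.8
print(round(harmonic_mean(83.0, 65.8), 1))  # prints 73.4
```

Unlike the arithmetic mean, H is dominated by the weaker of the two accuracies, so a method cannot score well by sacrificing new-class accuracy for base-class accuracy.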

Sensitivity Analysis of $\alpha$

We further investigate the impact of the test data-dependent prior $p(\theta \mid x')$, which adapts our posterior distribution to unseen instances. Figure 4 shows that adapting to test data benefits performance on both seen and unseen classes. Moreover, balancing the posterior learned from seen data against the prior of the unseen data, which the weight $\alpha$ in Eq. (16) controls, is important for effective generalization in both scenarios.

Figure 4: Sensitivity analysis on $\alpha$, showing the effect of the test data-dependent prior on (a) Oxford Pets, (b) UCF101, and (c) Flower102. Experiments are replicated three times.

Domain Generalization on ImageNet

Although there is typically a trade-off between performance on the source and target datasets, Table 6 shows that APP improves accuracy on both the source and the target datasets, highlighting the robustness of APP to distribution shifts.

Table 6: Result of domain generalization in ImageNet. Acc represents the accuracy. Bold means the best accuracy.
Dataset Methods Acc (%)
Source ImageNet CoCoOp 63.13 ± 0.12
PLOT 63.14 ± 0.16
ProDA 62.73 ± 0.11
APP $\mathbf{64.50 \pm 0.06}$
Target ImageNetV2 CoCoOp 55.23 ± 0.25
PLOT 54.23 ± 0.45
ProDA 54.97 ± 0.05
APP $\mathbf{57.10 \pm 0.29}$
ImageNet-Sketch CoCoOp 34.07 ± 0.46
PLOT 33.93 ± 0.12
ProDA 34.60 ± 0.22
APP $\mathbf{35.70 \pm 0.14}$
ImageNet-R CoCoOp 56.03 ± 0.38
PLOT 56.86 ± 0.42
ProDA 58.57 ± 0.49
APP $\mathbf{58.70 \pm 0.08}$
ImageNet-A CoCoOp 22.37 ± 0.09
PLOT 22.63 ± 0.12
ProDA 23.47 ± 0.12
APP $\mathbf{23.80 \pm 0.22}$

Conclusion

We propose a Bayesian framework for prompt learning that accounts for the uncertainty of few-shot learning scenarios, in which image features may be multi-modal and a distribution shift may exist between the training and test datasets. We enhance flexibility via Wasserstein Gradient Flow. Furthermore, we propose a novel data-dependent prior distribution conditioned on averaged image features, designed to capture minor modes of image features and to facilitate adaptation to previously unseen distributions. We demonstrate substantial performance improvements in various scenarios, including few-shot classification, domain generalization, and unseen-class generalization. Additionally, our qualitative analyses indicate that our prompt learning helps the prompts spread out to capture multiple modes of the image features.

Acknowledgments

This research was supported by AI Technology Development for Commonsense Extraction, Reasoning, and Inference from Heterogeneous Data (IITP) funded by the Ministry of Science and ICT (2022-0-00077).

References

  • Barber and Agakov (2003) Barber, D.; and Agakov, F. 2003. Information Maximization in Noisy Channels: A Variational Approach. In Thrun, S.; Saul, L.; and Schölkopf, B., eds., Advances in Neural Information Processing Systems, volume 16. MIT Press.
  • Beaudry and Renner (2012) Beaudry, N. J.; and Renner, R. 2012. An intuitive proof of the data processing inequality. arXiv:1107.0740.
  • Bossard, Guillaumin, and Van Gool (2014) Bossard, L.; Guillaumin, M.; and Van Gool, L. 2014. Food-101–mining discriminative components with random forests. In Computer Vision–ECCV 2014: 13th European Conference, Zurich, Switzerland, September 6-12, 2014, Proceedings, Part VI 13, 446–461. Springer.
  • Chen et al. (2018) Chen, C.; Zhang, R.; Wang, W.; Li, B.; and Chen, L. 2018. A Unified Particle-Optimization Framework for Scalable Bayesian Sampling. arXiv:1805.11659.
  • Chen et al. (2023) Chen, G.; Yao, W.; Song, X.; Li, X.; Rao, Y.; and Zhang, K. 2023. PLOT: Prompt Learning with Optimal Transport for Vision-Language Models. In The Eleventh International Conference on Learning Representations.
  • Cimpoi et al. (2014) Cimpoi, M.; Maji, S.; Kokkinos, I.; Mohamed, S.; and Vedaldi, A. 2014. Describing textures in the wild. In Proceedings of the IEEE conference on computer vision and pattern recognition, 3606–3613.
  • Deng et al. (2009) Deng, J.; Dong, W.; Socher, R.; Li, L.-J.; Li, K.; and Fei-Fei, L. 2009. ImageNet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, 248–255. IEEE.
  • Derakhshani et al. (2023) Derakhshani, M. M.; Sanchez, E.; Bulat, A.; da Costa, V. G. T.; Snoek, C. G. M.; Tzimiropoulos, G.; and Martinez, B. 2023. Bayesian Prompt Learning for Image-Language Model Generalization. arXiv:2210.02390.
  • Fei-Fei, Fergus, and Perona (2004) Fei-Fei, L.; Fergus, R.; and Perona, P. 2004. Learning generative visual models from few training examples: An incremental bayesian approach tested on 101 object categories. In 2004 conference on computer vision and pattern recognition workshop, 178–178. IEEE.
  • Lee et al. (2022) Lee, S.-g.; Kim, H.; Shin, C.; Tan, X.; Liu, C.; Meng, Q.; Qin, T.; Chen, W.; Yoon, S.; and Liu, T.-Y. 2022. PriorGrad: Improving Conditional Denoising Diffusion Models with Data-Dependent Adaptive Prior. In International Conference on Learning Representations.
  • He et al. (2016) He, K.; Zhang, X.; Ren, S.; and Sun, J. 2016. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, 770–778.
  • Helber et al. (2019) Helber, P.; Bischke, B.; Dengel, A.; and Borth, D. 2019. EuroSAT: A novel dataset and deep learning benchmark for land use and land cover classification. IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing, 12(7): 2217–2226.
  • Hendrycks et al. (2021a) Hendrycks, D.; Basart, S.; Mu, N.; Kadavath, S.; Wang, F.; Dorundo, E.; Desai, R.; Zhu, T.; Parajuli, S.; Guo, M.; et al. 2021a. The many faces of robustness: A critical analysis of out-of-distribution generalization. In Proceedings of the IEEE/CVF International Conference on Computer Vision, 8340–8349.
  • Hendrycks et al. (2021b) Hendrycks, D.; Zhao, K.; Basart, S.; Steinhardt, J.; and Song, D. 2021b. Natural adversarial examples. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 15262–15271.
  • Jia et al. (2021) Jia, C.; Yang, Y.; Xia, Y.; Chen, Y.-T.; Parekh, Z.; Pham, H.; Le, Q.; Sung, Y.-H.; Li, Z.; and Duerig, T. 2021. Scaling up visual and vision-language representation learning with noisy text supervision. In International Conference on Machine Learning, 4904–4916. PMLR.
  • Jordan, Kinderlehrer, and Otto (1998) Jordan, R.; Kinderlehrer, D.; and Otto, F. 1998. The Variational Formulation of the Fokker–Planck Equation. SIAM Journal on Mathematical Analysis, 29(1): 1–17.
  • Krause et al. (2013) Krause, J.; Stark, M.; Deng, J.; and Fei-Fei, L. 2013. 3D object representations for fine-grained categorization. In Proceedings of the IEEE international conference on computer vision workshops, 554–561.
  • Li et al. (2020) Li, Z.; Wang, R.; Chen, K.; Utiyama, M.; Sumita, E.; Zhang, Z.; and Zhao, H. 2020. Data-dependent Gaussian Prior Objective for Language Generation. In International Conference on Learning Representations.
  • Liu and Wang (2016) Liu, Q.; and Wang, D. 2016. Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm. In Lee, D.; Sugiyama, M.; Luxburg, U.; Guyon, I.; and Garnett, R., eds., Advances in Neural Information Processing Systems, volume 29. Curran Associates, Inc.
  • Lu et al. (2022) Lu, Y.; Liu, J.; Zhang, Y.; Liu, Y.; and Tian, X. 2022. Prompt distribution learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 5206–5215.
  • Maji et al. (2013) Maji, S.; Rahtu, E.; Kannala, J.; Blaschko, M.; and Vedaldi, A. 2013. Fine-grained visual classification of aircraft. arXiv preprint arXiv:1306.5151.
  • McInnes, Healy, and Melville (2018) McInnes, L.; Healy, J.; and Melville, J. 2018. UMAP: Uniform manifold approximation and projection for dimension reduction. arXiv preprint arXiv:1802.03426.
  • Nilsback and Zisserman (2008) Nilsback, M.-E.; and Zisserman, A. 2008. Automated flower classification over a large number of classes. In 2008 Sixth Indian Conference on Computer Vision, Graphics & Image Processing, 722–729. IEEE.
  • Parkhi et al. (2012) Parkhi, O. M.; Vedaldi, A.; Zisserman, A.; and Jawahar, C. V. 2012. Cats and dogs. In 2012 IEEE Conference on Computer Vision and Pattern Recognition, 3498–3505.
  • Radford et al. (2021) Radford, A.; Kim, J. W.; Hallacy, C.; Ramesh, A.; Goh, G.; Agarwal, S.; Sastry, G.; Askell, A.; Mishkin, P.; Clark, J.; et al. 2021. Learning transferable visual models from natural language supervision. In International conference on machine learning, 8748–8763. PMLR.
  • Recht et al. (2019) Recht, B.; Roelofs, R.; Schmidt, L.; and Shankar, V. 2019. Do ImageNet classifiers generalize to ImageNet? In International conference on machine learning, 5389–5400. PMLR.
  • Ruan, Dubois, and Maddison (2022) Ruan, Y.; Dubois, Y.; and Maddison, C. J. 2022. Optimal Representations for Covariate Shift. In International Conference on Learning Representations.
  • Shen et al. (2022) Shen, S.; Li, L. H.; Tan, H.; Bansal, M.; Rohrbach, A.; Chang, K.-W.; Yao, Z.; and Keutzer, K. 2022. How Much Can CLIP Benefit Vision-and-Language Tasks? In International Conference on Learning Representations.
  • Soomro, Zamir, and Shah (2012) Soomro, K.; Zamir, A. R.; and Shah, M. 2012. UCF101: A dataset of 101 human actions classes from videos in the wild. arXiv preprint arXiv:1212.0402.
  • Sordoni et al. (2021) Sordoni, A.; Dziri, N.; Schulz, H.; Gordon, G.; Bachman, P.; and Des Combes, R. T. 2021. Decomposed mutual information estimation for contrastive representation learning. In International Conference on Machine Learning, 9859–9869. PMLR.
  • Wang et al. (2019) Wang, H.; Ge, S.; Lipton, Z.; and Xing, E. P. 2019. Learning robust global representations by penalizing local predictive power. Advances in Neural Information Processing Systems, 32.
  • Welling and Teh (2011) Welling, M.; and Teh, Y. W. 2011. Bayesian learning via stochastic gradient Langevin dynamics. In Proceedings of the 28th international conference on machine learning (ICML-11), 681–688.
  • Wortsman et al. (2022) Wortsman, M.; Ilharco, G.; Kim, J. W.; Li, M.; Kornblith, S.; Roelofs, R.; Lopes, R. G.; Hajishirzi, H.; Farhadi, A.; Namkoong, H.; et al. 2022. Robust fine-tuning of zero-shot models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 7959–7971.
  • Xiao et al. (2010) Xiao, J.; Hays, J.; Ehinger, K. A.; Oliva, A.; and Torralba, A. 2010. SUN database: Large-scale scene recognition from abbey to zoo. In 2010 IEEE computer society conference on computer vision and pattern recognition, 3485–3492. IEEE.
  • Zhou et al. (2022a) Zhou, K.; Yang, J.; Loy, C. C.; and Liu, Z. 2022a. Conditional prompt learning for vision-language models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 16816–16825.
  • Zhou et al. (2022b) Zhou, K.; Yang, J.; Loy, C. C.; and Liu, Z. 2022b. Learning to prompt for vision-language models. International Journal of Computer Vision, 130(9): 2337–2348.

Supplementary Material

Proof of Proposition 1

For simplicity, define $A \coloneqq f(X)$ and $B \coloneqq g(\phi(f(X)),\cdot)$. Following (Barber and Agakov 2003), we formulate the variational lower bound on the mutual information as follows:

$$I\big(f(X);\, g(\phi(f(X)),\cdot)\big) \coloneqq I(A;B) = \mathbb{E}_{p(A,B)}\log\frac{p(B\mid A)}{p(B)} \ge \mathbb{E}_{p(A,B)}\log\frac{q(B\mid A)}{p(B)} \qquad (17)$$

where $q$ is a variational distribution. Following (Sordoni et al. 2021), we define the variational distribution $q$ by sampling $B_1,\dots,B_C$ from $p(B)$, where $B_i \coloneqq g(\phi(f(X)), y_i)$. To form the variational distribution, we choose $B_j$ according to the importance weight $\frac{\exp(\mathrm{sim}(A,B_j)/\tau)}{\sum_{i=1}^{C}\exp(\mathrm{sim}(A,B_i)/\tau)}$.

Therefore, the unnormalized variational distribution for $B_1$ is defined as follows:

$$q(B_1\mid A,B_{2:C}) = p(B_1)\,\frac{C\cdot\exp(\mathrm{sim}(A,B_1)/\tau)}{\sum_{i=1}^{C}\exp(\mathrm{sim}(A,B_i)/\tau)} \qquad (18)$$

By Jensen’s inequality, the bound in Eq. 17 is further lower-bounded as follows:

$$\begin{aligned}
\mathbb{E}_{p(A,B_1)}\log\mathbb{E}_{p(B_{2:C})}\left(\frac{q(B_1\mid A,B_{2:C})}{p(B_1)}\right)
&\ge \mathbb{E}_{p(A,B_1)p(B_{2:C})}\log\left(\frac{p(B_1)\,\frac{C\cdot\exp(\mathrm{sim}(A,B_1)/\tau)}{\sum_{i=1}^{C}\exp(\mathrm{sim}(A,B_i)/\tau)}}{p(B_1)}\right)\\
&= \mathbb{E}_{p(A,B_1)p(B_{2:C})}\log\left(\frac{C\cdot\exp(\mathrm{sim}(A,B_1)/\tau)}{\sum_{i=1}^{C}\exp(\mathrm{sim}(A,B_i)/\tau)}\right)\\
&= \log C - \mathcal{L}_{CE}(\phi(f(X)),X,Y) && (19)
\end{aligned}$$

To train the prior network, we minimize $\mathcal{L}_{CE}$, which is equivalent to maximizing the lower bound $\log C - \mathcal{L}_{CE}$ on the mutual information.
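The bound in Eq. 19 has the same form as the InfoNCE bound: with temperature-scaled similarities as logits, $\log C - \mathcal{L}_{CE}$ can be computed directly from a batch. The NumPy sketch below is an illustration, not the authors' code. It assumes cosine similarity for $\mathrm{sim}$, one text feature per class playing the role of $B_1,\dots,B_C$, and a hypothetical default temperature $\tau = 0.01$; the function names are hypothetical.

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between rows of a (n, d) and rows of b (C, d)."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a @ b.T


def mi_lower_bound(image_feats, class_text_feats, labels, tau=0.01):
    """Return (log C - L_CE, L_CE) for a batch, following the form of Eq. 19.

    image_feats:      (n, d) array, plays the role of A = f(X)
    class_text_feats: (C, d) array, one text feature B_i per class (assumption)
    labels:           (n,) integer class indices Y
    tau:              temperature (hypothetical default)
    """
    num_classes = class_text_feats.shape[0]
    logits = cosine_similarity(image_feats, class_text_feats) / tau  # (n, C)
    logits = logits - logits.max(axis=1, keepdims=True)              # stabilize exp
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    ce = -log_probs[np.arange(len(labels)), labels].mean()           # L_CE
    return np.log(num_classes) - ce, ce
```

Because $\mathcal{L}_{CE} \ge 0$, the estimate never exceeds $\log C$, and any decrease in the cross-entropy raises the bound by the same amount.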

Derivation from Wasserstein Gradient Flow to Stein Variational Gradient Descent

Define the functional $F$ as follows:

$$F(\mu) \coloneqq D_{KL}(\mu\,\|\,\pi) \qquad (20)$$

where $\pi \propto \exp(-V(\theta))$. The Wasserstein gradient of $F$ is then given by:

$$\nabla\frac{\delta F(\mu)}{\delta\mu} = \nabla\log\frac{\mu}{\pi} \qquad (21)$$

Considering the Wasserstein gradient in a reproducing kernel Hilbert space (RKHS) with the transformation $\mathcal{K}_{\mu}T(\theta) \coloneqq \mathbb{E}_{\theta'\sim\mu}[K(\theta,\theta')T(\theta')]$, we derive the kernelized Wasserstein gradient as follows:

$$\begin{aligned}
\mathcal{K}_{\mu}\nabla_{\theta'}\frac{\delta F(\mu)}{\delta\mu}
&= \mathcal{K}_{\mu}\nabla_{\theta'}\log\frac{\mu}{\pi} && (22)\\
&= \int K(\theta,\theta')\big(-\nabla_{\theta'}\log\pi + \nabla_{\theta'}\log\mu\big)\mu(\theta')\,d\theta' && (23)\\
&= \int K(\theta,\theta')\big(\nabla_{\theta'}V(\theta') + \nabla_{\theta'}\log\mu\big)\mu(\theta')\,d\theta' && (24)\\
&= \int K(\theta,\theta')\nabla_{\theta'}V(\theta')\,\mu\,d\theta' - \int\nabla_{\theta'}K(\theta,\theta')\,\mu\,d\theta' + \underbrace{\big[K(\theta,\theta')\,\mu\big]\Big|_{\|\theta'\|\to\infty}}_{=0} && (25)
\end{aligned}$$
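The step from Eq. 24 to Eq. 25 is an integration by parts: $\int K(\theta,\theta')\,\nabla_{\theta'}\log\mu\;\mu\,d\theta' = \int K(\theta,\theta')\,\nabla_{\theta'}\mu\,d\theta' = -\int \nabla_{\theta'}K(\theta,\theta')\,\mu\,d\theta'$ once the boundary term vanishes. The short numerical check below evaluates both sides in one dimension. It assumes, for illustration only, a standard normal $\mu$ and an RBF kernel with unit bandwidth; neither choice comes from the paper.

```python
import numpy as np


def ibp_sides(theta: float, bandwidth: float = 1.0):
    """Evaluate both sides of the integration-by-parts identity behind Eq. 25.

    Illustrative assumptions: mu = N(0, 1) and
    K(theta, theta') = exp(-(theta - theta')**2 / (2 * bandwidth)).
    """
    grid = np.linspace(-12.0, 12.0, 20001)    # theta' values; mu is negligible beyond
    dx = grid[1] - grid[0]
    mu = np.exp(-grid**2 / 2) / np.sqrt(2 * np.pi)
    score = -grid                              # grad log mu for N(0, 1)
    k = np.exp(-(theta - grid) ** 2 / (2 * bandwidth))
    grad_k = (theta - grid) / bandwidth * k    # d/dtheta' of K(theta, theta')
    lhs = np.sum(k * score * mu) * dx          # int K * grad log mu * mu dtheta'
    rhs = -np.sum(grad_k * mu) * dx            # -int grad K * mu dtheta'
    return lhs, rhs


for theta in (-1.5, 0.0, 0.7, 2.0):
    lhs, rhs = ibp_sides(theta)
    print(f"theta={theta:+.1f}  lhs={lhs:.8f}  rhs={rhs:.8f}")
```

For this Gaussian case both sides also equal $-\tfrac{\theta}{2}\,e^{-\theta^2/4}/\sqrt{2}$ in closed form, which gives an independent check on the quadrature.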

By the continuity equation, the evolution path of $\theta$ is derived as follows:

$$\begin{aligned}
\partial_t\theta_t &= -\mathcal{K}_{\mu}\nabla_{\theta'}\left(\frac{\delta F(\mu)}{\delta\mu}\right)\\
&= -\left[\int K(\theta,\theta')\nabla_{\theta'}V(\theta')\,\mu\,d\theta' - \int\nabla_{\theta'}K(\theta,\theta')\,\mu\,d\theta'\right] && (26)
\end{aligned}$$

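Discretizing Eq. 26 in time with step size $h$ and approximating $\mu$ by the empirical distribution of $M$ particles $\{\theta_t^{j}\}_{j=1}^{M}$ yields the following update rule: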

$$\theta_{t+1}^{i} = \theta_t^{i} - \frac{h}{M}\sum_{j=1}^{M}\left[K(\theta_t^{i},\theta_t^{j})\,\nabla_{\theta_t^{j}}V(\theta_t^{j}) - \nabla_{\theta_t^{j}}K(\theta_t^{i},\theta_t^{j})\right] \qquad (27)$$
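Eq. 27 is the Stein variational gradient descent (SVGD) update of Liu and Wang (2016). The NumPy sketch below implements it for a generic potential $V$; it illustrates the update and is not the authors' implementation. Assumptions: an RBF kernel $K(\theta,\theta') = \exp(-\|\theta-\theta'\|^2/b)$ with a median-heuristic bandwidth $b$ in the spirit of Liu and Wang (2016), particles stored as flat vectors, and a hypothetical `grad_V` callable. To avoid a name clash, the step size $h$ of Eq. 27 is called `step_size` and the kernel bandwidth is `b`.

```python
import numpy as np


def svgd_step(particles: np.ndarray, grad_V, step_size: float) -> np.ndarray:
    """One update of Eq. 27 for M particles of dimension d, shape (M, d)."""
    M = particles.shape[0]
    diffs = particles[:, None, :] - particles[None, :, :]   # [i, j] = theta_i - theta_j
    sq_dists = np.sum(diffs**2, axis=-1)                     # (M, M)
    # Median-heuristic bandwidth (assumption); guard against all-equal particles.
    b = max(np.median(sq_dists) / np.log(M + 1), 1e-8)
    K = np.exp(-sq_dists / b)                                # K(theta_i, theta_j)
    # Gradient w.r.t. theta_j of K(theta_i, theta_j) = (2 / b) (theta_i - theta_j) K_ij
    grad_K = (2.0 / b) * diffs * K[:, :, None]               # (M, M, d)
    driving = K @ grad_V(particles)                          # sum_j K_ij grad V(theta_j)
    repulsive = grad_K.sum(axis=1)                           # sum_j grad_{theta_j} K_ij
    return particles - step_size / M * (driving - repulsive)


# Usage: sample from pi ∝ exp(-V) with V(theta) = ||theta||^2 / 2 (a standard normal).
rng = np.random.default_rng(0)
particles = rng.normal(loc=3.0, scale=0.5, size=(50, 2))
for _ in range(1000):
    particles = svgd_step(particles, grad_V=lambda th: th, step_size=0.1)
print(particles.mean(axis=0), particles.std(axis=0))
```

The `driving` term moves particles toward low $V$ (high posterior density), while the `repulsive` term pushes them apart; this is what lets a finite particle set spread over the target instead of collapsing onto a single mode. With a single particle the repulsive term vanishes and the update reduces to plain gradient descent on $V$.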

Implementation Details

We adopted the training settings of PLOT (Chen et al. 2023). We used a batch size of 32 for the Oxford-Flowers, FGVC-Aircraft, and Stanford-Cars datasets and a batch size of 128 for all other datasets. The prior network was trained for 10, 20, 20, 40, and 40 epochs in the 1-, 2-, 4-, 8-, and 16-shot settings, respectively. We used the SGD optimizer with an initial learning rate of 0.002, annealed with a cosine schedule. All experiments were conducted on a single NVIDIA A100 GPU.
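The batch sizes and learning-rate schedule above can be summarized as a small configuration sketch. It assumes annealing to a minimum learning rate of 0 over the whole run and uses hypothetical dataset identifiers; warm-up and other details of the PLOT recipe are not reproduced. In PyTorch, the same closed-form schedule is provided by `torch.optim.lr_scheduler.CosineAnnealingLR`, with `eta_min` as the floor.

```python
import math

# Batch sizes stated above; the dataset keys are hypothetical identifiers.
SMALL_BATCH_DATASETS = {"OxfordFlowers", "FGVCAircraft", "StanfordCars"}


def batch_size_for(dataset: str) -> int:
    """32 for Oxford-Flowers, FGVC-Aircraft and Stanford-Cars; 128 otherwise."""
    return 32 if dataset in SMALL_BATCH_DATASETS else 128


def cosine_annealed_lr(step: int, total_steps: int,
                       base_lr: float = 0.002, min_lr: float = 0.0) -> float:
    """Cosine annealing from base_lr (at step 0) to min_lr (at total_steps)."""
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

Halfway through training the rate is half of the initial 0.002, and it reaches the floor at the final step.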

For evaluation, we treat the model output within a Bayesian framework, taking the posterior distribution $p(\theta\mid X,Y)$ into account. The predictive distribution for a test point $x'$ is therefore defined as follows:

$$p(y'\mid x',\mathcal{D}) = \int p(y'\mid x',\theta)\,p(\theta\mid\mathcal{D},x')\,d\theta \qquad (28)$$

where $\mathcal{D} \coloneqq \{X, Y\}$.
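A common way to evaluate Eq. 28 is a Monte Carlo average: compute the class probabilities under each sampled prompt and average them. The sketch below assumes that $M$ particles $\theta_1,\dots,\theta_M$ (for example, those produced by the update in Eq. 27) act as equally weighted samples from $p(\theta\mid\mathcal{D},x')$ and that each yields a logit vector for $x'$. How those logits are computed is outside the sketch, and the function name is hypothetical.

```python
import numpy as np


def predictive_distribution(logits_per_particle: np.ndarray) -> np.ndarray:
    """Monte Carlo estimate of Eq. 28.

    logits_per_particle: (M, n, C) logits for n test points from M sampled prompts.
    Returns an (n, C) array: p(y' | x', D) ≈ (1/M) * sum_m softmax(logits_m).
    """
    z = logits_per_particle - logits_per_particle.max(axis=-1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)  # per-particle softmax
    return probs.mean(axis=0)                                   # average over particles


# Two particles, one test point, two classes.
# Per-particle probabilities are [0.5, 0.5] and [0.75, 0.25]; their mean is [0.625, 0.375].
logits = np.array([[[0.0, 0.0]], [[np.log(3.0), 0.0]]])
print(predictive_distribution(logits))
```

Averaging probabilities rather than logits keeps the estimate a proper mixture over the sampled prompts, which matches the integral in Eq. 28.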

Table 7 compares the performance of our method with other multi-prompt methods.

Table 7: Few-shot classification results. We report the mean accuracy (%) over three replicated experiments.

| Dataset | Method | 1 shot | 2 shots | 4 shots | 8 shots | 16 shots |
|---|---|---|---|---|---|---|
| Caltech101 | PLOT (Chen et al. 2023) | 89.72 ± 0.44 | 90.48 ± 0.22 | 90.99 ± 0.30 | 91.23 ± 0.46 | 92.29 ± 0.22 |
| | ProDA (Lu et al. 2022) | 89.14 ± 0.19 | 90.08 ± 0.50 | 90.94 ± 0.19 | 91.83 ± 0.27 | 92.47 ± 0.08 |
| | Ours | **90.30 ± 0.33** | **90.87 ± 0.25** | **91.05 ± 0.45** | **91.92 ± 0.08** | **92.89 ± 0.22** |
| DTD | PLOT (Chen et al. 2023) | 46.94 ± 1.89 | 51.46 ± 2.27 | 55.95 ± 0.74 | 61.68 ± 0.34 | 65.27 ± 0.23 |
| | ProDA (Lu et al. 2022) | 47.52 ± 1.69 | **52.17 ± 1.89** | 55.85 ± 0.92 | 62.74 ± 0.38 | 66.27 ± 1.11 |
| | Ours | **49.09 ± 2.29** | 51.73 ± 3.47 | **58.06 ± 0.69** | **63.42 ± 0.12** | **66.57 ± 0.07** |
| EuroSAT | PLOT (Chen et al. 2023) | 54.15 ± 2.04 | 63.64 ± 3.08 | **74.91 ± 2.13** | **78.55 ± 1.33** | 84.06 ± 0.91 |
| | ProDA (Lu et al. 2022) | 49.89 ± 2.37 | 62.37 ± 1.41 | 70.38 ± 1.73 | 77.20 ± 1.97 | 80.62 ± 1.51 |
| | Ours | **60.04 ± 2.08** | **66.02 ± 2.70** | 74.08 ± 2.23 | 77.87 ± 2.51 | **84.08 ± 1.1** |
| FGVC-Aircraft | PLOT (Chen et al. 2023) | 18.22 ± 0.59 | 19.26 ± 0.96 | 22.39 ± 1.41 | 27.58 ± 0.83 | 32.28 ± 0.43 |
| | ProDA (Lu et al. 2022) | **20.25 ± 0.88** | **22.45 ± 0.37** | 24.76 ± 1.00 | 28.46 ± 0.43 | 32.84 ± 0.34 |
| | Ours | 18.10 ± 0.92 | 21.87 ± 1.21 | **25.65 ± 0.59** | **29.40 ± 0.42** | **33.80 ± 0.53** |
| Food101 | PLOT (Chen et al. 2023) | 77.87 ± 0.13 | 77.78 ± 0.37 | 77.20 ± 0.37 | 75.42 ± 0.09 | 77.18 ± 0.16 |
| | ProDA (Lu et al. 2022) | 78.65 ± 0.17 | 79.04 ± 0.17 | **79.47 ± 0.17** | 78.88 ± 0.22 | 79.82 ± 0.10 |
| | Ours | **78.87 ± 0.17** | **79.25 ± 0.06** | 79.43 ± 0.10 | **79.49 ± 0.17** | **79.83 ± 0.06** |
| ImageNet | PLOT (Chen et al. 2023) | 59.32 ± 0.58 | 59.81 ± 0.25 | 61.12 ± 0.14 | 61.91 ± 0.09 | 63.14 ± 0.16 |
| | ProDA (Lu et al. 2022) | 61.65 ± 0.46 | 62.20 ± 0.11 | 61.96 ± 0.30 | 61.78 ± 0.29 | 62.73 ± 0.11 |
| | Ours | **62.11 ± 0.26** | **62.50 ± 0.10** | **63.05 ± 0.15** | **63.70 ± 0.17** | **64.50 ± 0.06** |
| Oxford Flowers | PLOT (Chen et al. 2023) | 72.32 ± 0.94 | 82.65 ± 0.38 | 88.13 ± 0.27 | 92.56 ± 0.02 | 95.64 ± 0.28 |
| | ProDA (Lu et al. 2022) | 70.56 ± 0.48 | 83.73 ± 0.50 | 88.94 ± 0.30 | 93.30 ± 0.35 | 95.47 ± 0.18 |
| | Ours | **74.42 ± 0.58** | **84.23 ± 0.14** | **89.13 ± 0.31** | **93.38 ± 0.09** | **95.67 ± 0.09** |
| Oxford Pets | PLOT (Chen et al. 2023) | 87.35 ± 0.67 | 87.16 ± 0.37 | 88.25 ± 0.58 | 87.39 ± 0.39 | 87.20 ± 0.20 |
| | ProDA (Lu et al. 2022) | 88.76 ± 0.08 | 88.15 ± 0.39 | 89.17 ± 0.41 | **89.97 ± 0.23** | 89.61 ± 0.36 |
| | Ours | **88.97 ± 0.52** | **88.28 ± 0.39** | **89.48 ± 0.28** | 89.79 ± 0.16 | **89.73 ± 0.63** |
| Stanford Cars | PLOT (Chen et al. 2023) | 56.21 ± 0.78 | 57.35 ± 0.27 | 63.35 ± 0.54 | 67.60 ± 0.05 | 73.78 ± 0.11 |
| | ProDA (Lu et al. 2022) | 59.55 ± 0.13 | 62.08 ± 0.65 | 66.32 ± 0.18 | 71.11 ± 0.21 | 74.86 ± 0.21 |
| | Ours | **59.65 ± 0.41** | **63.39 ± 0.35** | **67.18 ± 0.30** | **71.86 ± 0.11** | **76.14 ± 0.23** |
| Sun397 | PLOT (Chen et al. 2023) | 63.03 ± 0.32 | 62.10 ± 0.50 | 65.60 ± 0.44 | 66.90 ± 0.25 | 69.76 ± 0.29 |
| | ProDA (Lu et al. 2022) | **63.99 ± 0.15** | 65.56 ± 0.40 | **67.79 ± 0.49** | **69.50 ± 0.13** | **71.61 ± 0.18** |
| | Ours | **63.99 ± 0.20** | **65.59 ± 0.19** | 67.01 ± 0.36 | 69.46 ± 0.08 | 71.59 ± 0.19 |
| UCF101 | PLOT (Chen et al. 2023) | 64.31 ± 0.36 | 67.64 ± 0.42 | 70.87 ± 0.32 | 75.77 ± 0.25 | 77.58 ± 0.29 |
| | ProDA (Lu et al. 2022) | 64.44 ± 0.68 | 67.80 ± 0.11 | 70.78 ± 0.23 | **77.01 ± 0.35** | **79.34 ± 0.68** |
| | Ours | **65.83 ± 0.84** | **69.05 ± 0.19** | **71.79 ± 0.13** | 76.71 ± 0.60 | 79.07 ± 0.12 |
| Average | PLOT (Chen et al. 2023) | 62.68 ± 0.09 | 65.39 ± 0.48 | 68.98 ± 0.24 | 71.51 ± 0.10 | 74.38 ± 0.06 |
| | ProDA (Lu et al. 2022) | 63.13 ± 0.26 | 66.88 ± 0.22 | 69.67 ± 0.16 | 72.89 ± 0.25 | 75.06 ± 0.05 |
| | Ours | **64.67 ± 0.26** | **67.53 ± 0.50** | **70.54 ± 0.15** | **73.36 ± 0.24** | **75.81 ± 0.06** |

Semantics of Prompts

To analyze the semantics of the learned context vectors in our approach, we retrieved, for each context vector, the nearest words in the token-embedding space. Although not every vector has a straightforward semantic interpretation in this continuous space, we observed that some of the tokens in each context vector are clearly related to the dataset and describe the image data well.
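The nearest-word lookup can be sketched as follows. This is a minimal illustration under stated assumptions: it ranks vocabulary words by Euclidean distance between a learned token vector and each word's token embedding, and the distance metric, the toy 2-D vocabulary, and the helper names `nearest_words` and `interpret_prompt` are hypothetical, not taken from the released code. Setting `k=4` corresponds to the top-4 word lists summarized in Table 8.

```python
import math
from typing import Dict, List, Sequence


def euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def nearest_words(
    token_vec: Sequence[float],
    vocab: Dict[str, Sequence[float]],
    k: int = 4,
) -> List[str]:
    """Return the k vocabulary words whose embeddings are closest to token_vec."""
    ranked = sorted(vocab, key=lambda word: euclidean(token_vec, vocab[word]))
    return ranked[:k]


def interpret_prompt(
    contexts: List[List[Sequence[float]]],
    vocab: Dict[str, Sequence[float]],
    k: int = 4,
) -> List[List[List[str]]]:
    """For every context vector and every token position in it, list the k nearest words."""
    return [[nearest_words(tok, vocab, k) for tok in ctx] for ctx in contexts]


# Hypothetical toy vocabulary; real token embeddings are high-dimensional.
vocab = {
    "aircraft": (1.0, 0.0),
    "navy": (0.9, 0.2),
    "grocery": (-1.0, 0.0),
    "wig": (-0.8, -0.3),
}
# One hypothetical context vector made of two learned token vectors.
contexts = [[(0.95, 0.05), (-0.9, -0.1)]]
print(interpret_prompt(contexts, vocab, k=2))
# → [[['aircraft', 'navy'], ['grocery', 'wig']]]
```

Because the learned vectors live in a continuous space, the nearest word may be far from the vector itself, which is why only some tokens in Table 8 yield interpretable words.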

Table 8 indicates that each context vector is trained toward directions that correspond to specific aspects of the image data, suggesting that our model learns meaningful representations of the image content within the context vectors. For instance, Context 1 contains the words "americanair" and "usnavy", which are associated with the civil aircraft and fighter aircraft in the dataset, respectively.

Table 8: Interpretable words closest to our context vectors for FGVC-Aircraft. For each token, we report one interpretable word among its top-4 nearest words. "-" means that none of the top-4 words is interpretable, and bold marks words highly relevant to the dataset.
Number Context 1 Context 2 Context 3 Context 4
1 sculpt calling salazar -
2 espn byte grassroots bucketlist
3 postponed turquoise installation below
4 grocery wig followed accepted
5 taking shells translates turkish
6 usnavy likes administrator fair
7 fort - centred qantas
8 staten fo blaze modern
9 rusher tue attempting serve
10 wielding ata - crossing
11 trick brown blowing -
12 pate vfl donkey occasion
13 americanair - legendary times
14 eight ima reunite inevit
15 ..? thanku flew letting
16 brox facing ! grand

Empirical comparison with BPL

Table 9 (unseen-class generalization experiments) shows that APP requires less memory and less training time than BPL while achieving a higher harmonic mean (H) of the base-class and new-class accuracies. This gain suggests that the Gaussian variational distribution of BPL is not expressive enough to capture the multiple modes of the image features. In addition, BPL's conditional posterior assumption does not scale to large batch sizes: with a batch size of 128, BPL fails to train due to memory limits.
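For reference, H is the harmonic mean of the base-class accuracy $B$ and the new-class accuracy $N$, $H = 2BN/(B+N)$, so it is high only when both accuracies are high. The sketch below uses hypothetical accuracies (not values from the paper) to show that, for the same arithmetic mean, an imbalance between base and new accuracy lowers H.

```python
def harmonic_mean(base_acc: float, new_acc: float) -> float:
    """Harmonic mean H of base-class and new-class accuracy (both in %)."""
    if base_acc + new_acc == 0:
        return 0.0
    return 2 * base_acc * new_acc / (base_acc + new_acc)


# Hypothetical accuracies, chosen only to illustrate how H behaves.
balanced = harmonic_mean(80.0, 80.0)    # equal accuracies: H equals them
unbalanced = harmonic_mean(90.0, 70.0)  # same arithmetic mean (80), lower H
print(balanced, unbalanced)
# → 80.0 78.75
```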

Dataset Method Batch size Memory (GB) Time (hr) H (%)
Caltech101 BPL 1 11.0 5.75 92.3
128 - - -
APP (Ours) 1 $\mathbf{10.6}$ $\mathbf{2.50}$ $\mathbf{93.0}$
128 $\mathbf{10.8}$ $\mathbf{0.25}$ $\mathbf{93.0}$
Oxford pets BPL 1 7.4 1.5 91.5
128 - - -
APP (Ours) 1 $\mathbf{7.0}$ $\mathbf{0.83}$ $\mathbf{91.7}$
128 $\mathbf{8.2}$ $\mathbf{0.06}$ $\mathbf{92.4}$
UCF101 BPL 1 11.5 9.26 72.8
128 - - -
APP (Ours) 1 $\mathbf{7.0}$ $\mathbf{3.25}$ $\mathbf{73.3}$
128 $\mathbf{8.5}$ $\mathbf{0.1}$ $\mathbf{76.8}$
Table 9: Comparison of BPL and APP (ours) in the unseen-class generalization experiments. We report the performance over three replicated experiments. "-" indicates that training failed due to insufficient memory.
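To make the batch-size-1 time comparison in Table 9 concrete, the ratio of BPL's training time to APP's can be computed directly from the reported hours. This is plain arithmetic on the table values; the dictionary and function names are illustrative only.

```python
from typing import Dict

# Training times (hours) at batch size 1, copied from Table 9.
bpl_hours = {"Caltech101": 5.75, "Oxford pets": 1.5, "UCF101": 9.26}
app_hours = {"Caltech101": 2.50, "Oxford pets": 0.83, "UCF101": 3.25}


def speedups(slow: Dict[str, float], fast: Dict[str, float]) -> Dict[str, float]:
    """Ratio of the slower method's time to the faster method's time, per dataset."""
    return {name: slow[name] / fast[name] for name in slow}


for name, ratio in speedups(bpl_hours, app_hours).items():
    print(f"{name}: {ratio:.2f}x")
# → Caltech101: 2.30x
# → Oxford pets: 1.81x
# → UCF101: 2.85x
```

At batch size 128 no ratio can be formed, since BPL runs out of memory.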