Probabilistic Inference in Language Models
via Twisted Sequential Monte Carlo
Abstract
Numerous capability and safety techniques of Large Language Models (LLMs), including RLHF, automated red-teaming, prompt engineering, and infilling, can be cast as sampling from an unnormalized target distribution defined by a given reward or potential function over the full sequence. In this work, we leverage the rich toolkit of Sequential Monte Carlo (SMC) for these probabilistic inference problems. In particular, we use learned twist functions to estimate the expected future value of the potential at each timestep, which enables us to focus inference-time computation on promising partial sequences. We propose a novel contrastive method for learning the twist functions, and establish connections with the rich literature of soft reinforcement learning. As a complementary application of our twisted SMC framework, we present methods for evaluating the accuracy of language model inference techniques using novel bidirectional SMC bounds on the log partition function. These bounds can be used to estimate the KL divergence between the inference and target distributions in both directions. We apply our inference evaluation techniques to show that twisted SMC is effective for sampling undesirable outputs from a pretrained model (a useful component of harmlessness training and automated red-teaming), generating reviews with varied sentiment, and performing infilling tasks.
1 Introduction
A wide range of language model learning and inference tasks can be viewed as steering a model’s generations to satisfy a specified property. In particular, traditional reinforcement learning from human feedback (RLHF) pipelines (Ziegler et al., 2019; Stiennon et al., 2020; Ouyang et al., 2022; Bai et al., 2022; Rafailov et al., 2023) may be viewed as targeting an unnormalized target modulated by a terminal reward function which reflects human feedback (Korbak et al., 2022b). Red-teaming techniques such as prompt-engineering and infilling may seek target outputs with low reward or (high probability of) undesirable responses (Zou et al., 2023; Perez et al., 2022). In reasoning tasks, we may seek to target outputs which are likely to be deemed valid by a ‘verifier’ (Cobbe et al., 2021; Anil et al., 2021; Dohan et al., 2022; Hu et al., 2023). Specific properties of the generated responses might also be enforced (Khalifa et al., 2020; Yang & Klein, 2021; Lew et al., 2023).
We view the above tasks as instances of probabilistic inference: sampling from a target unnormalized density and estimating its intractable (log) normalization constant. Consider a pretrained base model $p_0(s_{1:T}|s_0)$ which generates responses $s_{1:T}$ of maximum length $T$ based on a variable-length prompt $s_0$. We consider defining the target distribution of interest using the base model modulated by a potential function $\phi(s_{1:T})$ which evaluates full sequences,
$$\sigma(s_{1:T}|s_0) := \frac{1}{\mathcal{Z}_\sigma(s_0)}\, p_0(s_{1:T}|s_0)\,\phi(s_{1:T}), \qquad \mathcal{Z}_\sigma(s_0) := \sum_{s_{1:T}} \tilde{\sigma}(s_{1:T}|s_0), \tag{1}$$
where $\tilde{\sigma}(s_{1:T}|s_0) := p_0(s_{1:T}|s_0)\,\phi(s_{1:T})$ denotes the unnormalized density. We refer to $\mathcal{Z}_\sigma(s_0)$ as the normalization constant or partition function, which is intractable due to the summation over $s_{1:T}$. We drop dependence on $s_0$ to avoid clutter, but note that each prompt induces a different partition function. In the context of the aforementioned applications, $\phi(s_{1:T})$ may be derived from a human preference model (for RLHF), an indication of bad behavior (for automated red-teaming), or a verifier’s prediction of correctness (for reasoning tasks). We refer to Table 5 or Korbak et al. (2022b); Dohan et al. (2022); Phan et al. (2023); Hu et al. (2023) for further examples and discussion of probabilistic inference in language models.
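To make the target in Eq. 1 concrete, the following minimal sketch enumerates a toy target distribution exactly. The two-token vocabulary, the length of three tokens, the base model `p0_next`, and the potential `phi` are all illustrative assumptions and are not taken from the experiments; for a real language model the sum over full sequences is intractable, which is what motivates the inference methods discussed next.

```python
import itertools
import math

V = (0, 1)  # toy vocabulary (assumption)
T = 3       # toy sequence length (assumption)

def p0_next(prefix):
    """Toy base model: next-token distribution given the prefix."""
    p_one = 0.7 if (prefix and prefix[-1] == 1) else 0.3
    return {0: 1.0 - p_one, 1: p_one}

def p0(seq):
    """Probability of a full sequence under the toy base model."""
    prob = 1.0
    for t, tok in enumerate(seq):
        prob *= p0_next(seq[:t])[tok]
    return prob

def phi(seq):
    """Toy terminal potential: favors sequences containing many 1s."""
    return math.exp(sum(seq))

seqs = list(itertools.product(V, repeat=T))
sigma_tilde = {s: p0(s) * phi(s) for s in seqs}      # unnormalized target
Z = sum(sigma_tilde.values())                        # partition function
sigma = {s: v / Z for s, v in sigma_tilde.items()}   # normalized target
```

In this toy setting the potential shifts probability mass toward the all-ones sequence relative to the base model, while the normalized target still sums to one.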
Twisted Sequential Monte Carlo in Language Models
In this work, we leverage tools from (twisted) Sequential Monte Carlo (SMC) (Doucet et al., 2001; Del Moral et al., 2006; Briers et al., 2010; Chopin et al., 2020) to perform and evaluate inference in the language modeling setting (Sec. 3). A particular challenge in sampling from Eq. 1 is that the target distribution $\sigma(s_{1:T})$ is non-causal. In order to sample tokens sequentially, one needs to infer the marginal distribution $\sigma(s_{1:t})$, which involves an intractable marginalization. To address this problem, we propose to learn twist functions $\psi_t(s_{1:t})$ which modulate the base model such that $p_0(s_{1:t})\psi_t(s_{1:t})$ matches the target marginals $\sigma(s_{1:t})$, up to normalization. The twist functions can be used to focus each step of language model generation on promising partial sequences.
Evaluating Inference in Language Modeling
Sampling from the target distribution is closely intertwined with bounding the log partition function. Similarly to variational inference or traditional RLHF objectives (Korbak et al., 2022b), SMC algorithms yield lower bounds on $\log \mathcal{Z}_\sigma$, where tighter bounds typically coincide with more accurate target sampling. However, upper bounds may often be obtained when an exact target sample is available (Grosse et al., 2015, 2016; Brekelmans et al., 2022). The difference between upper and lower bounds on $\log \mathcal{Z}_\sigma$ in fact yields an upper bound on the symmetrized KL divergence between inference samples and the target distribution (Grosse et al., 2016). For these reasons, we argue in Sec. 5 that log partition function estimates are a powerful tool for evaluating language model inference techniques.
Contributions
Our probabilistic inference perspective leads to the following contributions:
• Twisted Sequential Monte Carlo for Language Modeling: We view twisted SMC as a general framework for sampling and evaluation of language models. While twisted SMC is well-known and Lew et al. (2023) consider SMC with fixed, few-step-ahead target information in the language modeling setting, we propose to learn intermediate twist functions for target distributions defined by terminal potential only.
• Contrastive Twist Learning: We develop probabilistic methods for learning intermediate twist functions, presenting a novel contrastive twist learning (CTL) method inspired by energy-based modeling and density ratio estimation in Sec. 4.1. Further, we adapt existing twisted SMC methods (Lawson et al., 2018, 2022; Lioutas et al., 2022) to the language modeling setting, and highlight connections with inference techniques inspired by (soft) reinforcement learning (RL).
• Evaluating Inference in Language Models: Finally, we demonstrate that twisted SMC provides a rich set of tools for evaluating language model fine-tuning or controlled generation techniques. We propose a novel SMC upper bound on $\log \mathcal{Z}_\sigma$ which is applicable when an exact target sample is available and may be of independent interest. We leverage these bounds to evaluate the quality of inference by measuring the KL divergence to the target in both directions, which can be used to diagnose mode-dropping behavior of methods such as proximal policy optimization (PPO) (Schulman et al., 2017) which optimize a mode-seeking divergence.
We proceed to describe background on importance sampling and SMC in Sec. 2, before presenting our framework for twisted SMC in the language modeling setting in Sec. 3. We propose methods to learn the twist functions in Sec. 4 and methods to evaluate inference in Sec. 5. Our experimental results in Sec. 7 showcase the ability of twisted SMC to improve controlled generation and lend insights into inference quality in existing methods.
2 Background
Suppose we are given access to an unnormalized density $\tilde{\sigma}(s_{1:T})$ which can be efficiently evaluated. We focus on estimation of the partition function or normalization constant $\mathcal{Z}_\sigma := \sum_{s_{1:T}} \tilde{\sigma}(s_{1:T})$, since unbiased estimators with low variance yield sampling techniques which closely approximate the target distribution (Finke, 2015; Maddison et al., 2017). We review simple importance sampling (SIS) and SMC techniques in this section.
2.1 Simple Importance Sampling
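Simple importance sampling (SIS) provides an unbiased estimator of $\mathcal{Z}_\sigma$ by calculating importance weights for any normalized proposal distribution $q(s_{1:T})$,
$$w(s_{1:T}) := \frac{\tilde{\sigma}(s_{1:T})}{q(s_{1:T})}, \tag{2}$$
which is unbiased since $\mathbb{E}_{q(s_{1:T})}[w(s_{1:T})] = \mathcal{Z}_\sigma$. The importance weights also yield an unbiased $K$-sample estimator of the partition function,
$$\hat{\mathcal{Z}}^{\text{SIS}}_\sigma := \frac{1}{K}\sum_{k=1}^{K} w(s^k_{1:T}), \qquad s^k_{1:T} \sim q(s_{1:T}). \tag{3}$$
By normalizing the weights in Eq. 2 over $K$ samples from $q(s_{1:T})$, we can obtain (biased) estimators of expectations under $\sigma(s_{1:T})$,
$$\mathbb{E}_{\sigma(s_{1:T})}\big[f(s_{1:T})\big] \approx \sum_{k=1}^{K} \frac{w(s^k_{1:T})}{\sum_{j=1}^{K} w(s^j_{1:T})}\, f(s^k_{1:T}), \tag{4}$$
or select an approximate target sample $s^\sigma_{1:T}$ from a categorical distribution with the self-normalized importance weights
$$s^\sigma_{1:T} \leftarrow s^{\omega}_{1:T}, \qquad \omega \sim \text{cat}\left(\left\{\frac{w(s^i_{1:T})}{\sum_{j=1}^{K} w(s^j_{1:T})}\right\}_{i=1}^{K}\right). \tag{5}$$
The quality of the approximations in Eq. 3-5 depends crucially on how well the proposal $q(s_{1:T})$ (which may be learned, Sec. 3.2) matches the target $\sigma(s_{1:T})$. While we discuss evaluation methods in Sec. 5, note that if inference is exact (i.e., $q(s_{1:T}) = \sigma(s_{1:T})$), then the variance of the importance weights is zero, as $w(s_{1:T}) = \mathcal{Z}_\sigma$ for all $s_{1:T}$.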
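The SIS estimators above can be sketched in a few lines. This is a minimal illustration under stated assumptions: a toy two-token base model `p0_next_one`, a toy potential `phi`, and the base model itself as the proposal, so that each importance weight reduces to the potential. None of these choices come from the experiments.

```python
import math
import random

T = 3  # toy sequence length (assumption)

def p0_next_one(prefix):
    """Toy base model: probability that the next token is 1."""
    return 0.7 if (prefix and prefix[-1] == 1) else 0.3

def phi(seq):
    """Toy terminal potential (assumption)."""
    return math.exp(sum(seq))

def sample_base(rng):
    """Draw a full sequence from the proposal q = p0."""
    seq = []
    for _ in range(T):
        seq.append(1 if rng.random() < p0_next_one(seq) else 0)
    return tuple(seq)

def sis(K, rng):
    samples = [sample_base(rng) for _ in range(K)]
    # Importance weights: with q = p0, w = p0 * phi / q = phi.
    weights = [phi(s) for s in samples]
    z_hat = sum(weights) / K                       # K-sample estimate of Z
    total = sum(weights)
    norm_w = [w / total for w in weights]          # self-normalized weights
    return samples, norm_w, z_hat

rng = random.Random(0)
samples, norm_w, z_hat = sis(K=20000, rng=rng)
# Self-normalized estimate of the expected number of 1s under the target.
mean_ones = sum(w * sum(s) for s, w in zip(samples, norm_w))
# Select one approximate target sample using the normalized weights.
approx_sample = rng.choices(samples, weights=norm_w, k=1)[0]
```

With a proposal that matches the target poorly, a few large weights dominate the normalized weights, which is the failure mode that resampling in SMC is meant to mitigate.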
2.2 Sequential Monte Carlo
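SMC improves inference by decomposing it into easier subproblems involving a set of unnormalized intermediate target distributions $\{\tilde{\pi}_t(s_{1:t})\}_{t=1}^{T}$. A key observation is that as long as $\pi_T(s_{1:T}) = \sigma(s_{1:T})$, we obtain an unbiased estimate of the partition function $\mathcal{Z}_T = \mathcal{Z}_\sigma$, regardless of the intermediate $\pi_t$ and proposal $q$.
We begin by defining the incremental importance weights
$$w_t(s_{1:t}) := \frac{\tilde{\pi}_t(s_{1:t})}{\tilde{\pi}_{t-1}(s_{1:t-1})\, q(s_t|s_{1:t-1})}, \tag{6}$$
where $\tilde{\pi}_t$ is the unnormalized density of $\pi_t$.
SMC maintains a set of $K$ partial sequences, by first sampling from the proposal $q(s^k_t|s^k_{1:t-1})$ in each index $k$. Optional resampling steps may be performed to clone sequences with high incremental importance weights using
$$s^k_{1:t} \leftarrow s^{\omega^k_t}_{1:t}, \qquad \omega^k_t \sim \text{cat}\left(\left\{\frac{w_t(s^i_{1:t})}{\sum_{j=1}^{K} w_t(s^j_{1:t})}\right\}_{i=1}^{K}\right), \tag{7}$$
similarly to Eq. 5. Since resampling is performed with replacement, sequences with high weights may be cloned multiple times. The resulting $s^{\omega^k_t}_{1:t}$ are used as prefixes for the next step of proposal sampling in index $k$ (see 20).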
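We can show that SMC yields an unbiased estimator $\hat{\mathcal{Z}}^{\text{SMC}}_\sigma$ of the normalization constant $\mathcal{Z}_\sigma$, by considering the extended state space $S$ of token and index random variables $\{s^k_t, \omega^k_t\}$ from the sampling procedure in 20. Assuming resampling at every step, the estimator is
$$\hat{\mathcal{Z}}^{\text{SMC}}_\sigma := \prod_{t=1}^{T} \frac{1}{K}\sum_{k=1}^{K} w_t(s^k_{1:t}). \tag{8}$$
(Footnote 1: The decision to resample may be based on an adaptive condition such as Effective Sample Size (ESS) (Chopin et al., 2020). For $n = 1, \dots, N$, let $t_n$ index times where resampling occurs and fix $t_0 = 0$ and $t_N = T$. The estimator becomes $\prod_{n=1}^{N} \frac{1}{K}\sum_{k=1}^{K} \prod_{t=t_{n-1}+1}^{t_n} w_t(s^k_{1:t})$, and the final-step weights for expectations in Eq. 4 or sampling in Eq. 5 are given by $\prod_{t=t_{N-1}+1}^{T} w_t(s^k_{1:t})$.)
To see that $\hat{\mathcal{Z}}^{\text{SMC}}_\sigma$ is unbiased, we can view Eq. 8 as performing simple importance sampling $\hat{\mathcal{Z}}^{\text{SMC}}_\sigma = \tilde{\sigma}_{\text{SMC}}(S) / q_{\text{SMC}}(S)$ in the extended state space, for appropriate definitions of $\tilde{\sigma}_{\text{SMC}}(S)$ and $q_{\text{SMC}}(S)$ detailed in App. F or (Andrieu et al., 2010; Maddison et al., 2017). Intuitively, we may view the average incremental importance weights at each step as estimating the partition function ratio $\mathcal{Z}_t / \mathcal{Z}_{t-1}$. Eq. 8 composes intermediate partition function ratio estimators to obtain an estimate of the final $\mathcal{Z}_T = \mathcal{Z}_\sigma$, with $\mathcal{Z}_0 = 1$.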
With no resampling, SMC reduces to SIS with target $\sigma(s_{1:T})$ and proposal $q(s_{1:T})$. Using the final-step SMC weights, we may estimate expectations or draw approximate samples as in Eq. 4-5.
Fig. 1 illustrates the key advantage of SMC resampling over SIS. While a suboptimal $q(s_{1:T})$ may produce sequences with low probability under the target $\sigma(s_{1:T})$, SMC resampling with well-chosen intermediate targets $\pi_t$ clones the most promising partial sequences $s_{1:t}$ at step $t$. Since later sampling proceeds from these prefixes, we expect to obtain final sequences which better cover the high-probability regions of the target distribution. We discuss techniques to evaluate the quality of SMC or SIS sampling in Sec. 5.
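The SMC procedure of this section, with resampling at every step, can be sketched as follows. The toy base model, the use of the base model as the proposal, and the particular intermediate targets (the base model prefix probability times the exponential of the number of 1s so far) are assumptions chosen so that each incremental weight has a simple closed form; the final target then coincides with a potential equal to the exponential of the number of 1s in the full sequence.

```python
import math
import random

T = 3  # toy sequence length (assumption)

def p0_next_one(prefix):
    """Toy base model: probability that the next token is 1."""
    return 0.7 if (prefix and prefix[-1] == 1) else 0.3

def smc(K, rng):
    """SMC with proposal q = p0 and toy intermediate targets
    proportional to p0(s_{1:t}) * exp(sum(s_{1:t}))."""
    particles = [[] for _ in range(K)]
    log_z_hat = 0.0
    for _ in range(T):
        # Proposal step: extend every particle with one token from q = p0.
        for p in particles:
            p.append(1 if rng.random() < p0_next_one(p) else 0)
        # Incremental weight: target ratio over proposal reduces to exp(s_t).
        w = [math.exp(p[-1]) for p in particles]
        # Accumulate the log of the average incremental weight.
        log_z_hat += math.log(sum(w) / K)
        # Multinomial resampling with replacement, proportional to weights.
        idx = rng.choices(range(K), weights=w, k=K)
        particles = [list(particles[i]) for i in idx]
    return math.exp(log_z_hat), particles

rng = random.Random(0)
z_hat, particles = smc(K=16, rng=rng)
```

Averaging `z_hat` over many independent runs should approach the exact partition function of the toy target, reflecting the unbiasedness discussed above, while any single run returns a set of particles that can serve as approximate target samples.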
3 Twisted Sequential Monte Carlo for Language Modeling
A key design choice in the SMC procedure above is the intermediate targets $\{\pi_t\}_{t=1}^{T-1}$, where we assume $\pi_T = \sigma$ is always the target distribution. In state-space models with observation likelihoods or environments with intermediate rewards, filtering SMC considers target information collected from times $\tau \leq t$ to define $\pi_t$ (Chopin et al., 2020). Previous work on SMC for language models (Lew et al., 2023) has considered per-token or few-step-ahead statistics to define tractable intermediate $\pi_t$. However, we are often interested in target distributions which are determined by a terminal potential $\phi(s_{1:T})$ only, as in Eq. 1.
In such settings, twisted SMC methods (Briers et al., 2010; Whiteley & Lee, 2014; Lawson et al., 2022) consider the full target information (until time $T$) to define $\pi_t$. In other words, our desired intermediate targets are the true marginals $\sigma(s_{1:t})$ of the target distribution. Intuitively, note that in order to exactly sample $s_{1:T} \sim \sigma(s_{1:T})$, we need to ensure partial sequences are distributed according to the intermediate marginals $s_{1:t} \sim \sigma(s_{1:t})$. In Sec. 3.1, we will represent the intermediate targets $\pi_t$ using twist functions $\psi_t(s_{1:t})$ which modulate the base model to (approximately) match the target marginals, thereby summarizing future information relevant to sampling at time $t$.
3.1 Twist Functions
We represent the intermediate target distributions for SMC sampling using the following general form.
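Definition 3.1 (Twisted (Intermediate) Targets).
Using approximate twist functions $\{\psi_t\}_{t=1}^{T-1}$ and the final target potential $\phi$, we define the twisted intermediate target distributions
$$\pi_t(s_{1:t}) := \begin{cases} \frac{1}{\mathcal{Z}^{\psi}_t}\, p_0(s_{1:t})\,\psi_t(s_{1:t}) & t \neq T \\ \frac{1}{\mathcal{Z}_\sigma}\, p_0(s_{1:T})\,\phi(s_{1:T}) & t = T \end{cases} \tag{9}$$
For an arbitrary proposal $q$ and the unnormalized targets $\tilde{\pi}_t$ in Eq. 9, the incremental importance weights are given by
$$w_t(s_{1:t}) = \frac{p_0(s_t|s_{1:t-1})}{q(s_t|s_{1:t-1})}\,\frac{\psi_t(s_{1:t})}{\psi_{t-1}(s_{1:t-1})}, \tag{10}$$
with $\phi$ taking the place of $\psi_T$ in the final step.
While uninformed twist functions $\psi_t$ may result in $\pi_t(s_{1:t})$ which are no closer to the target marginal $\sigma(s_{1:t})$ than the base model $p_0(s_{1:t})$ (for example, in early stages of learning), the crucial fact is that our final target distribution in Eq. 9 reflects the target potential $\phi(s_{1:T})$. As in Sec. 2.2, this ensures that, regardless of the intermediate twists, our resulting importance sampling estimators will be unbiased.
Finally, the optimal twists recover the intermediate marginals $\pi^*_t(s_{1:t}) = \sigma(s_{1:t})$ of the target distribution. We state the sense in which $\pi^*_t$ and $\psi^*_t$ are optimal in Sec. A.1, and prove the following proposition in App. B Prop. B.1.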
Proposition 3.2 (Optimal Twists).
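For a given target distribution $\sigma(s_{1:T})$ in Eq. 1, the optimal twist functions $\psi^*_t(s_{1:t})$ (in regions where $p_0(s_{1:t}) > 0$) correspond to
$$\pi^*_t(s_{1:t}) = \frac{1}{\mathcal{Z}^{\psi^*}_t}\, p_0(s_{1:t})\,\psi^*_t(s_{1:t}) = \sigma(s_{1:t}). \tag{11}$$
Up to a constant independent of $s_{1:t}$, the optimal twists are
$$\psi^*_t(s_{1:t}) \propto \sum_{s_{t+1:T}} p_0(s_{t+1:T}|s_{1:t})\,\phi(s_{1:T}), \tag{12}$$
and satisfy the recursion
$$\psi^*_t(s_{1:t}) = \sum_{s_{t+1}} p_0(s_{t+1}|s_{1:t})\,\psi^*_{t+1}(s_{1:t+1}). \tag{13}$$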
Since the optimal twist functions are unavailable due to the need to marginalize over future timesteps, we consider learning approximate twist functions using methods in Sec. 4.
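For intuition, the optimal twists can be computed exactly in a toy setting by running the recursion of Prop. 3.2 backward from the terminal potential. The sketch below assumes a two-token vocabulary, three-token sequences, and hypothetical `p0_next` and `phi` functions; it then checks that the base model prefix probability multiplied by the optimal twist reproduces the target marginal found by brute-force summation.

```python
import itertools
import math
from functools import lru_cache

V = (0, 1)  # toy vocabulary (assumption)
T = 3       # toy sequence length (assumption)

def p0_next(prefix):
    """Toy base model: next-token distribution given the prefix."""
    p_one = 0.7 if (prefix and prefix[-1] == 1) else 0.3
    return {0: 1.0 - p_one, 1: p_one}

def p0(prefix):
    """Probability of a (partial) sequence under the toy base model."""
    prob = 1.0
    for t, tok in enumerate(prefix):
        prob *= p0_next(prefix[:t])[tok]
    return prob

def phi(seq):
    """Toy terminal potential (assumption)."""
    return math.exp(sum(seq))

@lru_cache(maxsize=None)
def psi_star(prefix):
    """Optimal twist via the one-step backward recursion."""
    if len(prefix) == T:
        return phi(prefix)  # terminal condition: the twist equals the potential
    return sum(p0_next(prefix)[tok] * psi_star(prefix + (tok,)) for tok in V)

Z = psi_star(())  # marginalizing every token recovers the partition function

def target_marginal(prefix):
    """Target marginal of a prefix by brute-force summation over continuations."""
    rest = itertools.product(V, repeat=T - len(prefix))
    return sum(p0(prefix + r) * phi(prefix + r) for r in rest) / Z

def twisted_marginal(prefix):
    """Twisted target built from the base model and the optimal twist."""
    return p0(prefix) * psi_star(prefix) / Z
```

The recursion visits every prefix, so its cost grows exponentially with sequence length; this is the marginalization that makes exact twists unavailable for real models.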
3.2 Proposal Distribution
For a given set of targets $\{\pi_t\}_{t=1}^{T}$, the importance weights in Eq. 10 depend crucially on the choice of proposal.
Base Model as Proposal
The most straightforward choice of proposal is the base pre-trained model, $q = p_0$. While we demonstrate in Sec. 7 that SMC resampling with learned twists and the base model proposal can closely approximate the target distribution, this may require a large number of samples $K$. We can achieve greater efficiency using better choices of proposal.
Twist-Induced Proposal
For given targets $\{\pi_t\}_{t=1}^{T}$, the optimal proposal minimizes the variance of the importance weights (Sec. A.1). In the language model setting with a terminal potential only, we will in fact be able to sample from the optimal proposal for the one-step importance weights.
Proposition 3.3.
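(Twist-Induced Proposal). For a given set of intermediate twisted targets $\pi_t(s_{1:t})$ in Eq. 9, the proposal which minimizes the variance of the one-step incremental importance weights $w_t$ is given by
$$q^{\pi}_t(s_t|s_{1:t-1}) \propto \frac{\pi_t(s_{1:t})}{\pi_{t-1}(s_{1:t-1})} \;\;\Longrightarrow\;\; q^{\pi}_t(s_t|s_{1:t-1}) = \frac{p_0(s_t|s_{1:t-1})\,\psi_t(s_{1:t})}{\sum_{s_t} p_0(s_t|s_{1:t-1})\,\psi_t(s_{1:t})}. \tag{14}$$
See Sec. A.2 for proof. For $t < T$, we can construct a parameterization of $\psi_t$ such that the proposal is tractable to sample in transformer architectures, where the normalization sums over the discrete vocabulary of next tokens $s_t$. However, for the final timestep, note that $\phi(s_{1:T})$ may require calls to a different neural network such as a reward model or classifier. We thus consider an approximate $\psi_T(s_{1:T}) \approx \phi(s_{1:T})$ for the proposal in the final step. With slight abuse of notation, we let $q^{\pi}$ denote this tractable proposal over full sequences,
$$q^{\pi}(s_{1:T}) := \prod_{t=1}^{T} \frac{p_0(s_t|s_{1:t-1})\,\psi_t(s_{1:t})}{\sum_{s_t} p_0(s_t|s_{1:t-1})\,\psi_t(s_{1:t})}. \tag{15}$$
Using this proposal, the incremental weights become
$$w_t(s_{1:t}) = \begin{cases} \dfrac{\sum_{s_t} p_0(s_t|s_{1:t-1})\,\psi_t(s_{1:t})}{\psi_{t-1}(s_{1:t-1})} & t < T \\[2ex] \dfrac{\sum_{s_T} p_0(s_T|s_{1:T-1})\,\psi_T(s_{1:T})}{\psi_{T-1}(s_{1:T-1})}\,\dfrac{\phi(s_{1:T})}{\psi_T(s_{1:T})} & t = T \end{cases} \tag{16}$$
which are independent of $s_t$ for $t < T$.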
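A small sketch can make the twist-induced proposal and its weights concrete. It assumes a toy base model, a deliberately imperfect twist `psi` (with the twist of the empty prefix set to 1), and a toy potential `phi`; these are illustrative choices only. Because the toy space is tiny, the expectation of the full-sequence weight under the proposal is computed exactly by enumeration rather than by sampling.

```python
import itertools
import math

V = (0, 1)  # toy vocabulary (assumption)
T = 3       # toy sequence length (assumption)

def p0_next(prefix):
    """Toy base model: next-token distribution given the prefix."""
    p_one = 0.7 if (prefix and prefix[-1] == 1) else 0.3
    return {0: 1.0 - p_one, 1: p_one}

def phi(seq):
    """Toy terminal potential (assumption)."""
    return math.exp(sum(seq))

def psi(prefix):
    """A deliberately imperfect twist (assumption); equals 1 on the empty prefix."""
    return math.exp(0.8 * sum(prefix))

def twist_proposal(prefix):
    """Next-token proposal proportional to base probability times twist."""
    unnorm = {tok: p0_next(prefix)[tok] * psi(prefix + (tok,)) for tok in V}
    norm = sum(unnorm.values())
    return {tok: u / norm for tok, u in unnorm.items()}, norm

def prob_and_weight(seq):
    """Proposal probability of a full sequence and the product of its weights."""
    q_prob, weight = 1.0, 1.0
    for t in range(T):
        prefix = seq[:t]
        dist, norm = twist_proposal(prefix)
        q_prob *= dist[seq[t]]
        weight *= norm / psi(prefix)   # depends on the prefix, not on seq[t]
    weight *= phi(seq) / psi(seq)      # final-step correction for psi_T != phi
    return q_prob, weight

seqs = list(itertools.product(V, repeat=T))
pairs = [prob_and_weight(s) for s in seqs]
total_prob = sum(q for q, _ in pairs)
expected_weight = sum(q * w for q, w in pairs)  # exact expectation under q
```

Even with an imperfect twist, the exact expected weight equals the partition function of the toy target, since the final-step correction restores the true potential.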
Variational Proposal
As noted in Sec. 2.1, SMC with no resampling steps reduces to SIS with the full target distribution $\sigma(s_{1:T})$. Policy gradient methods (Schulman et al., 2017; Parshakova et al., 2019; Korbak et al., 2022a; Go et al., 2023) which directly learn a tractable approximation $q(s_{1:T})$ to the target distribution may thus be viewed as a particularly simple instance of SMC, or inference more generally (see Korbak et al. (2022b)). We may also evaluate these inference methods using our proposed tools in Sec. 5. See Table 1 and App. E for detailed losses and discussion.
Finally, note that a separate proposal $q$ might also be learned alongside the twisting targets $\{\pi_t\}_{t=1}^{T}$. This may be useful to approximate the variance-minimizing proposal for multi-step or adaptive resampling (Prop. A.5) beyond the tractable optimal one-step proposal in Prop. 3.3. We discuss training losses based on multi-step importance weights in Sec. C.1.
3.3 Conditional Target Distributions
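More generally, we may consider conditional target distributions, obtained by conditioning on an observation random variable $o_T$. This mirrors the standard setting of SMC in state-space models (Doucet et al., 2001; Briers et al., 2010; Gu et al., 2015; Maddison et al., 2017; Lawson et al., 2022), with further discussion in Sec. B.2.
Defining $\phi(s_{1:T}, o_T) := \sigma(o_T|s_{1:T})$ as a probabilistic model of $o_T$, our target distribution is the posterior $\sigma(s_{1:T}|o_T)$,
$$\sigma(s_{1:T}|o_T) := \frac{1}{\mathcal{Z}_\sigma(o_T)}\, p_0(s_{1:T})\,\sigma(o_T|s_{1:T}), \qquad \mathcal{Z}_\sigma(o_T) := \sum_{s_{1:T}} p_0(s_{1:T})\,\sigma(o_T|s_{1:T}), \tag{17}$$
where the partition function $\mathcal{Z}_\sigma(o_T) = \sigma(o_T)$ is the marginal of the given $o_T$.
In this setting, Prop. 3.2 suggests that the optimal twists, which match the marginals $\sigma(s_{1:t}|o_T)$, correspond to the conditional likelihood of $o_T$ given $s_{1:t}$,
$$\psi^*_t(s_{1:t}, o_T) \propto \sigma(o_T|s_{1:t}) = \sum_{s_{t+1:T}} p_0(s_{t+1:T}|s_{1:t})\,\sigma(o_T|s_{1:T}), \tag{18}$$
since $\sigma(s_{1:t}|o_T) \propto p_0(s_{1:t})\,\sigma(o_T|s_{1:t})$. We can proceed to construct intermediate target distributions and proposals as in the previous sections, where $\psi_t(s_{1:t}, o_T)$ and even $q(s_t|s_{1:t-1}, o_T)$ may be conditioned on a particular value of $o_T$.
To recover the unconditional setting, we can fix a binary observational variable $\sigma(o_T = 1|s_{1:T}) := \phi(s_{1:T})$ (Levine, 2018) and omit explicit conditioning, showing that conditional twist learning generalizes our previous exposition. (Footnote 2: To obtain a probabilistic interpretation for $\phi(s_{1:T})$, note we need to ensure $\phi(s_{1:T}) \leq 1$. As a result, sampling from the target or joint is no easier in this interpretation than in Eq. 1, which is intractable in general. For example, finding $\max_{s_{1:T}} \phi(s_{1:T})$ and dividing to rescale is equivalent to being able to perform rejection sampling with the base model proposal (see Sec. 4.1.2).)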
Exact Target Sampling on Simulated Data
Assuming $\sigma(o_T|s_{1:T})$ is tractable to sample, we may obtain an exact sample from the target posterior for simulated $o_T$ using ancestral sampling. In particular, by sampling $s_{1:T}, o_T \sim p_0(s_{1:T})\,\sigma(o_T|s_{1:T})$, we obtain a sample from the joint distribution, which also factorizes as $\sigma(o_T)\,\sigma(s_{1:T}|o_T)$. Using the latter factorization, we may interpret $s_{1:T}$ as an exact sample from the target posterior for the given $o_T$.
We refer to this as the Bidirectional Monte Carlo (BDMC) trick (Grosse et al., 2015, 2016), and will use it to draw exact samples for training in Sec. 4.1.2 or evaluation in Sec. 5.
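The ancestral sampling argument can be illustrated with a toy simulation. The sketch assumes a two-token base model and a hypothetical Bernoulli observation model `obs_prob`; it draws sequence and observation jointly and then reads off the sequences paired with a given observation as exact posterior samples for that observation.

```python
import random

T = 3  # toy sequence length (assumption)

def p0_next_one(prefix):
    """Toy base model: probability that the next token is 1."""
    return 0.7 if (prefix and prefix[-1] == 1) else 0.3

def sample_base(rng):
    """Draw a full sequence from the toy base model."""
    seq = []
    for _ in range(T):
        seq.append(1 if rng.random() < p0_next_one(seq) else 0)
    return tuple(seq)

def obs_prob(seq):
    """Toy observation model: probability that o = 1 given the sequence."""
    return (1 + sum(seq)) / (T + 1)

def sample_joint(rng):
    """Ancestral sampling: sequence from the base model, then the observation."""
    seq = sample_base(rng)
    o = 1 if rng.random() < obs_prob(seq) else 0
    return seq, o

rng = random.Random(0)
pairs = [sample_joint(rng) for _ in range(20000)]
# Each sequence is an exact posterior sample for the observation paired with it.
posterior_o1 = [s for s, o in pairs if o == 1]
freq_all_ones = posterior_o1.count((1, 1, 1)) / len(posterior_o1)
```

Note that the observation is simulated rather than chosen: the construction yields exact posterior samples only for whichever observation values happen to be drawn.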
3.4 Connections with Reinforcement Learning
Twisted SMC shares close connections with (soft) reinforcement learning (Levine, 2018; Piché et al., 2018; Lawson et al., 2018; Heng et al., 2020), which we develop with detailed discussion in Sec. B.3 and App. D. In particular, we consider language modeling as a Markov Decision Process (MDP) with states $s_{1:t-1}$, actions $s_t$, and deterministic transitions to $s_{1:t}$ by concatenation. We describe two different definitions of the reward function in relation to the potential function $\phi(s_{1:T})$ below. In Sec. B.1, we further extend our SMC framework to capture settings with intermediate potentials or rewards over partial sequences.
Base Model Policy Evaluation
Viewing the final potential $\phi(s_{1:T})$ as the reward function, the optimality condition in Eq. 12 corresponds to exact policy evaluation of the future reward under the fixed base model policy $p_0$. Mudgal et al. (2023) adopt this perspective for controlled decoding, and refer to the twist functions as ‘prefix scorers’.
Soft RL with KL Regularization
Alternatively, we may consider the soft or KL-regularized RL target distributions commonly used in language modeling (Levine, 2018; Korbak et al., 2022b) as a special case of our twisted SMC framework. For a regularization strength $\beta$ and terminal reward $r(s_{1:T})$, define the terminal potential as
$$\phi(s_{1:T}) = e^{\beta\, r(s_{1:T})}. \tag{19}$$
In this case, the intermediate twist functions in Def. 3.1 correspond to state-action $Q$-values, $\psi_t(s_{1:t}) = e^{\beta Q(s_t, s_{1:t-1})}$ (Sec. B.3). In particular, consider the recursion for the optimal twists in Eq. 13. Taking the logarithm of both sides and recalling the definition of the soft value function $V^*(s_{1:t}) := \frac{1}{\beta}\log \sum_{s_{t+1}} p_0(s_{t+1}|s_{1:t})\, e^{\beta Q^*(s_{t+1}, s_{1:t})}$ (Levine, 2018), we obtain
$$Q^*(s_t, s_{1:t-1}) = \frac{1}{\beta}\log \sum_{s_{t+1}} p_0(s_{t+1}|s_{1:t})\, e^{\beta Q^*(s_{t+1}, s_{1:t})} = V^*(s_{1:t}), \tag{20}$$
which is a soft Bellman recursion with no intermediate reward. From the soft RL perspective, the twist functions are analogous to a critic, while the proposal plays the role of an actor (Levine, 2018; Haarnoja et al., 2018). We provide detailed discussion of the soft RL case in Sec. B.3, and review RL-inspired losses for twist learning in Sec. C.1.
Benefits of the Probabilistic Perspective
While soft RL is a natural special case of our framework which gives intuition for the role of the twist functions, our approach allows for general target distributions without reference to RL objectives and suggests principled probabilistic resampling using SMC. Further, we develop twist learning techniques inspired by density ratio estimation, including our novel CTL method and the SIXO objective from Lawson et al. (2022), which are more naturally motivated from a probabilistic perspective. Finally, we leverage our probabilistic perspective to propose novel language model evaluation techniques inspired by Bidirectional Monte Carlo (Grosse et al., 2015, 2016) in Sec. 5.
4 Learning the Twist Functions
We next consider methods to learn twist functions parameterized by neural networks, presenting a novel contrastive twist learning (CTL) approach in Sec. 4.1. We summarize twist learning methods from related work in Sec. 4.2.
4.1 Contrastive Twist Learning
To match our approximate $\pi^{\theta}_t$ to the target marginals, we propose to minimize $T$ separate KL divergences,
$$\min_{\theta} \sum_{t=1}^{T} D_{\text{KL}}\big(\sigma(s_{1:t}) \,\big\|\, \pi^{\theta}_t(s_{1:t})\big). \tag{21}$$
While other divergences could be used to learn $\psi^{\theta}_t$, we argue that the mass-covering behavior of Eq. 21 is a desirable property for twist learning. Since we separately match each $\sigma(s_{1:t})$, our hope is that suboptimal learning in early timesteps does not lead to aggressive pruning of partial sequences that would achieve high final target likelihood.
Using Eq. 9, the (negative) gradient of Eq. 21 at each $t$ becomes
$$-\nabla_{\theta} D_{\text{KL}}\big(\sigma(s_{1:t}) \,\big\|\, \pi^{\theta}_t(s_{1:t})\big) = \mathbb{E}_{\sigma(s_{1:t})}\big[\nabla_{\theta} \log \psi^{\theta}_t(s_{1:t})\big] - \mathbb{E}_{\pi^{\theta}_t(s_{1:t})}\big[\nabla_{\theta} \log \psi^{\theta}_t(s_{1:t})\big], \tag{22}$$
which allows us to learn from exact target samples of $\sigma(s_{1:t})$ in the first term when they are available.
We note the similarity of the objective in Eq. 21 and gradient in Eq. 22 to maximum likelihood training of energy-based models (EBMs). Due to the form of the gradient update, we refer to this method as contrastive twist learning (CTL). We proceed to describe approximate techniques for positive sampling (first term) and negative sampling (second term) in the next subsections.
4.1.1 Approximate Negative Sampling
A common challenge in energy-based modeling is that the second term in Eq. 22 involves sampling from the target $\pi^{\theta}_t(s_{1:t})$ with intractable normalization constant $\mathcal{Z}^{\psi}_t$. We proceed to estimate the expectation using SIS as in Eq. 4, using a proposal $q(s_{1:t})$ such as the base model or the twist-induced proposal from Sec. 3.2. Note that SMC resampling with learned intermediate twist functions could also be used.
Name | Loss | Learning Principle
---|---|---
CTL | (Gradient: Eq. 22) | Marginal Matching with MLE
RL | | Twist Consistency / Soft Q-Learning
SIXO | | Noise Contrastive Estimation
FUDGE | | Binary Classification
DPG | | Maximum Likelihood (MLE)
PPO | | Variational Inference
4.1.2 (Approximate) Positive Sampling
In contrast to traditional EBM settings, we do not necessarily have exact samples available from a ‘data’ distribution. We describe several settings related to availability of positive samples, which are explored in our experiments in Sec. 7.
Exact Target Samples
Rejection Sampling
Rejection sampling can yield exact target samples when an upper bound $M$ on the likelihood ratio $\tilde{\sigma}(s_{1:T}) / q(s_{1:T}) \leq M$ is known. In cases where the target is defined by thresholding or an indicator function $\phi(s_{1:T}) = \mathbb{I}[s_{1:T} \in \mathcal{C}]$ or joint distribution $\phi(s_{1:T}) = \sigma(o_T|s_{1:T})$, we can clearly take $M = 1$ for the base model proposal $q = p_0$. If the base model yields posterior samples in reasonable time, we can obtain exact samples for training using rejection sampling, and use our twist learning procedures to greatly improve sampling efficiency at generation time.
While an improved proposal $q$ should more efficiently draw samples meeting the target conditions, exact rejection sampling would require estimation of $M$. Approximate or quasi rejection sampling might be used in this case, as analysed in Eikema et al. (2022).
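The indicator-potential case with the base model as proposal can be sketched as follows. The toy base model and the particular indicator (at least two tokens equal to 1) are assumptions for illustration; with a bound of 1 on the likelihood ratio, a base model sample is accepted exactly when it satisfies the indicator.

```python
import random

T = 3  # toy sequence length (assumption)

def p0_next_one(prefix):
    """Toy base model: probability that the next token is 1."""
    return 0.7 if (prefix and prefix[-1] == 1) else 0.3

def sample_base(rng):
    """Draw a full sequence from the proposal q = p0."""
    seq = []
    for _ in range(T):
        seq.append(1 if rng.random() < p0_next_one(seq) else 0)
    return tuple(seq)

def phi(seq):
    """Indicator potential (assumption): at least two tokens equal 1."""
    return 1.0 if sum(seq) >= 2 else 0.0

def rejection_sample(rng, M=1.0):
    """Accept a base model sample with probability phi / M; return it and the try count."""
    tries = 0
    while True:
        tries += 1
        seq = sample_base(rng)
        if rng.random() < phi(seq) / M:
            return seq, tries

rng = random.Random(0)
draws = [rejection_sample(rng) for _ in range(5000)]
# The acceptance rate estimates the partition function of the indicator target.
accept_rate = len(draws) / sum(tries for _, tries in draws)
```

The expected number of tries per exact sample is the reciprocal of the acceptance rate, which is why this approach becomes expensive when the target event is rare under the base model.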
Approximate Positive Sampling using SIS or SMC
In cases where exact samples are unavailable and rejection sampling is inefficient or inexact, we leverage SMC sampling with twist targets $\{\pi^{\theta}_t\}_{t=1}^{T}$ and any proposal $q$ to first draw a set of $K$ full sequences $s^{1:K}_{1:T}$. As in Eq. 4, we can use the normalized SMC weights since the last resampling step to estimate the expected gradient in the first term of Eq. 22. Without resampling, we recover SIS estimation.
While both our approximate positive and negative sampling for estimating the expectations in Eq. 22 rely on SMC or SIS weights (often with the same proposal), the crucial distinction is that weights for positive sampling are based on the true target potential $\phi(s_{1:T})$ over full sequences.
Truncation to Partial Sequences
For an exact positive sample, we use its truncation to a partial sequence of length $t$ (which corresponds to a sample from the desired marginal $\sigma(s_{1:t})$) to perform the gradient update in Eq. 22. For approximate positive sampling, we use the same set of $K$ final weights to estimate the expected gradient at each timestep.
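The contrastive update can be illustrated with a tabular toy example in which both expectations in Eq. 22 are computed exactly by enumeration, rather than estimated from samples as described above. The sketch assumes a toy base model and potential, and a lookup table `theta` holding the log-twist of each prefix; for such a table, the gradient with respect to one entry is simply the target marginal of that prefix minus its probability under the current twisted target.

```python
import itertools
import math

V = (0, 1)  # toy vocabulary (assumption)
T = 3       # toy sequence length (assumption)

def p0_next(prefix):
    p_one = 0.7 if (prefix and prefix[-1] == 1) else 0.3
    return {0: 1.0 - p_one, 1: p_one}

def p0(prefix):
    prob = 1.0
    for t, tok in enumerate(prefix):
        prob *= p0_next(prefix[:t])[tok]
    return prob

def phi(seq):
    """Toy terminal potential (assumption)."""
    return math.exp(sum(seq))

prefixes = {t: list(itertools.product(V, repeat=t)) for t in range(1, T + 1)}
full = prefixes[T]
Z = sum(p0(s) * phi(s) for s in full)

def sigma_marginal(prefix):
    """Exact target marginal of a prefix (the positive distribution)."""
    return sum(p0(s) * phi(s) for s in full if s[:len(prefix)] == prefix) / Z

sigma_marg = {s: sigma_marginal(s) for t in range(1, T) for s in prefixes[t]}

# Tabular log-twists for the intermediate steps, initialized to zero.
theta = {s: 0.0 for t in range(1, T) for s in prefixes[t]}

def pi_theta(t):
    """Twisted target at step t, normalized exactly (the negative distribution)."""
    unnorm = {s: p0(s) * math.exp(theta[s]) for s in prefixes[t]}
    norm = sum(unnorm.values())
    return {s: u / norm for s, u in unnorm.items()}

lr = 1.0
for _ in range(500):
    for t in range(1, T):
        pi_t = pi_theta(t)
        for s in prefixes[t]:
            # Positive term minus negative term for the table entry of prefix s.
            theta[s] += lr * (sigma_marg[s] - pi_t[s])
```

After training, each twisted intermediate target matches the corresponding target marginal; the learned twists agree with the optimal twists only up to a per-step constant, which is all the intermediate targets require.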
4.2 Twist Learning Methods from Related Work
We briefly describe alternative approaches for twist learning, with detailed discussion in App. C and a summary of the loss functions for methods used in our experiments in Table 1.
Soft Q-Learning (RL)
Enforcing the recursion in Eq. 13 using a squared error loss is analogous to soft $Q$-learning in the RL literature (see Eq. 20), and has been used for twisted SMC in Lioutas et al. (2022). Mudgal et al. (2023) derive a similar squared-error loss, viewing $\phi(s_{1:T})$ as the reward. Finally, we interpret path consistency losses (Nachum et al., 2017), which were derived in the soft RL setting and have been used for language modeling in Guo et al. (2021); Hu et al. (2023), from an importance sampling perspective in Sec. C.1 and E.1.
SIXO
The SIXO loss proposed by Lawson et al. (2022) learns twist functions using a binary classification task to distinguish samples from the target marginal $\sigma(s_{1:t})$ and base model $p_0(s_{1:t})$ at each step, which corresponds to noise contrastive estimation (Gutmann & Hyvärinen, 2010) for learning energy-based models. See Sec. C.3.
FUDGE
Yang & Klein (2021) learn twists by constructing a binary classification task to instead learn the conditional likelihood $\sigma(o_T|s_{1:t})$ (Eq. 18). This may be viewed as enforcing the $(T-t)$-step optimality equation in Eq. 12 or Eq. 18, where rollouts should be obtained using the base model (see Table 1 or Sec. C.4). Mudgal et al. (2023); Deng & Raffel (2023) similarly propose to enforce the $(T-t)$-step optimality condition using a squared-error loss.
5 Evaluating Inference in Language Models
Our SMC framework yields a rich set of tools for evaluating inference techniques in language models, using well-studied quantities such as the log partition function $\log \mathcal{Z}_\sigma$ and KL divergence to the target distribution. Remarkably, with access to a single exact sample from the target distribution, we show in Prop. 5.1 that we can obtain upper bounds on $\log \mathcal{Z}_\sigma$ in addition to lower bounds. These bounds can tightly sandwich $\log \mathcal{Z}_\sigma$ with increasing $K$, thereby ensuring reliable conclusions regarding inference quality.
5.1 Applications of $\log \mathcal{Z}_\sigma$ Estimation
Evaluating Fine-Tuned Models
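To motivate this section and present an important application of our SMC methods, consider evaluating how well a given $q(s_{1:T})$ matches a target distribution for controlled generation or fine-tuning. Assume that $q$ is tractable to sample and evaluate. To calculate the KL divergence to $\sigma$ in either direction, we also require an estimate of the log partition function $\log \mathcal{Z}_\sigma$,
$$D_{\text{KL}}(q \,\|\, \sigma) = \mathbb{E}_{q(s_{1:T})}\left[\log \frac{q(s_{1:T})}{\tilde{\sigma}(s_{1:T})}\right] + \log \mathcal{Z}_\sigma, \qquad D_{\text{KL}}(\sigma \,\|\, q) = \mathbb{E}_{\sigma(s_{1:T})}\left[\log \frac{\tilde{\sigma}(s_{1:T})}{q(s_{1:T})}\right] - \log \mathcal{Z}_\sigma. \tag{23}$$
For $D_{\text{KL}}(\sigma \,\|\, q)$, note that we also require samples from the target $\sigma$, as may be readily available using the BDMC trick when $\sigma$ is defined as a Bayesian posterior (Sec. 3.3). In such cases, we argue that SMC can be used to accurately bound the value of $\log \mathcal{Z}_\sigma$ and estimate each KL divergence above. Estimation of $D_{\text{KL}}(\sigma \,\|\, q)$ may be particularly important to diagnose mode-dropping in inference techniques such as PPO which optimize the mode-seeking $D_{\text{KL}}(q \,\|\, \sigma)$ during fine-tuning (Korbak et al., 2022b).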
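The role of the log partition function in both KL divergences can be checked on a toy example where everything is enumerable. The sketch below assumes a toy base model and potential, and evaluates the base model itself as the inference distribution; in practice the expectations would be Monte Carlo estimates and the log partition function would be replaced by its SMC bounds.

```python
import itertools
import math

V = (0, 1)  # toy vocabulary (assumption)
T = 3       # toy sequence length (assumption)

def p0_next(prefix):
    p_one = 0.7 if (prefix and prefix[-1] == 1) else 0.3
    return {0: 1.0 - p_one, 1: p_one}

def p0(seq):
    prob = 1.0
    for t, tok in enumerate(seq):
        prob *= p0_next(seq[:t])[tok]
    return prob

def phi(seq):
    """Toy terminal potential (assumption)."""
    return math.exp(sum(seq))

seqs = list(itertools.product(V, repeat=T))
sigma_tilde = {s: p0(s) * phi(s) for s in seqs}    # unnormalized target
log_z = math.log(sum(sigma_tilde.values()))
sigma = {s: v / math.exp(log_z) for s, v in sigma_tilde.items()}
q = {s: p0(s) for s in seqs}  # distribution under evaluation (here: base model)

# Both divergences written in terms of the unnormalized target and log Z.
kl_q_sigma = sum(q[s] * math.log(q[s] / sigma_tilde[s]) for s in seqs) + log_z
kl_sigma_q = sum(sigma[s] * math.log(sigma_tilde[s] / q[s]) for s in seqs) - log_z

# Direct computation with the normalized target, for comparison.
direct_q_sigma = sum(q[s] * math.log(q[s] / sigma[s]) for s in seqs)
direct_sigma_q = sum(sigma[s] * math.log(sigma[s] / q[s]) for s in seqs)
```

Since the log partition function enters the two divergences with opposite signs, an error in its estimate biases them in opposite directions, which is one reason two-sided bounds are useful.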
Evaluating Twisted SMC Sampling
After running SIS or SMC with $K$ samples, we can sample a single index as in Eq. 5 to return a single approximate target sample $s^{\sigma}_{1:T}$. However, the marginal distribution of this sample, which we denote as $q_{\text{SMC}}(s^{\sigma}_{1:T})$, is not tractable due to the need to sum over all possible sets of $K$ samples. Nevertheless, we will show below that the tightness of our lower or upper bounds in Prop. 5.1 provides upper bounds on the KL divergences $D_{\text{KL}}\big(q_{\text{SMC}}(s^{\sigma}_{1:T}) \,\|\, \sigma(s^{\sigma}_{1:T})\big)$ or $D_{\text{KL}}\big(\sigma(s^{\sigma}_{1:T}) \,\|\, q_{\text{SMC}}(s^{\sigma}_{1:T})\big)$, respectively.
5.2 Bidirectional SMC Bounds on $\log \mathcal{Z}_\sigma$
Given the importance of $\log \mathcal{Z}_\sigma$ estimation as motivated above, we propose a bidirectional SMC stochastic upper bound which is novel (to the best of our knowledge), and may be of interest outside of the language modeling setting.
Recall from Sec. 2.2 that SMC admits an interpretation as SIS in an extended state space $S$ which includes all tokens and resampling indices. We derive lower and upper bounds on $\log \mathcal{Z}_\sigma$ in Prop. 5.1 below, with proof and detailed description of the extended state space target $\sigma_{\text{SMC}}(S)$ and proposal $q_{\text{SMC}}(S)$ distributions in App. F.
Proposition 5.1.
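(Bidirectional SMC Bounds) The log partition function $\log \mathcal{Z}_\sigma$ of a target distribution $\sigma(s_{1:T})$ can be lower and upper bounded by
$$\mathbb{E}_{q_{\text{SMC}}(S)}\left[\log \prod_{t=1}^{T} \frac{1}{K}\sum_{k=1}^{K} w_t(s^k_{1:t})\right] \;\leq\; \log \mathcal{Z}_\sigma \;\leq\; \mathbb{E}_{\sigma_{\text{SMC}}(S)}\left[\log \prod_{t=1}^{T} \frac{1}{K}\sum_{k=1}^{K} w_t(s^k_{1:t})\right]. \tag{24}$$
The gap in the lower bound is $D_{\text{KL}}\big(q_{\text{SMC}}(S) \,\|\, \sigma_{\text{SMC}}(S)\big)$, and the gap in the upper bound is $D_{\text{KL}}\big(\sigma_{\text{SMC}}(S) \,\|\, q_{\text{SMC}}(S)\big)$.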
See App. F for a detailed discussion and derivations. The proof proceeds by adapting a general approach for extended state space log partition function bounds from Brekelmans et al. (2022) using the probabilistic interpretation of SMC from Andrieu et al. (2010); Maddison et al. (2017). With no resampling, the SIS case recovers the Importance Weighted Autoencoder (IWAE) lower (Burda et al., 2015) and upper (Sobolev & Vetrov, 2019; Brekelmans et al., 2022) bounds.
Sampling from $\sigma_{\text{SMC}}(S)$ for SMC Upper Bounds
We now discuss sampling from $\sigma_{\text{SMC}}(S)$ for the expectation in the upper bound, which requires a single, exact sample from the target distribution $s_{1:T} \sim \sigma(s_{1:T})$. This sample may be obtained, for example, using the BDMC trick in Sec. 3.3. Note that Sec. 2.2 and 20 describe sampling from $q_{\text{SMC}}(S)$, which is used for the expectation in the lower bound.
Sampling from $\sigma_{\text{SMC}}(S)$ differs from sampling from $q_{\text{SMC}}(S)$ by its treatment of the exact target sample. In particular, the partial sequence corresponding to the exact target sample is guaranteed to be cloned once at each resampling step. In the other $K-1$ indices, resampling proceeds as in Sec. 2.2, where the exact sample may be cloned additional times based on its incremental importance weights. Finally, we sample $K-1$ next tokens from the proposal, while the value of the remaining chain is fixed by the exact target sample. See App. F and 30 for detailed discussion.
Tightness of the Bidirectional Bounds
Since the bounds in Prop. 5.1 become exact as $K \to \infty$ for any proposal (Burda et al., 2015; Maddison et al., 2017), we can use SMC or IWAE with large $K$ to sandwich the log partition function when exact target samples are available.
For a given $K$, the gap in the extended state space bounds in Prop. 5.1 provides further insight into the quality of twisted SMC sampling via the distribution of the marginal sample $s^{\sigma}_{1:T}$ (Sec. 5.1). In particular, the data processing inequality suggests that $D_{\text{KL}}\big(q_{\text{SMC}}(s^{\sigma}_{1:T}) \,\|\, \sigma(s^{\sigma}_{1:T})\big) \leq D_{\text{KL}}\big(q_{\text{SMC}}(S) \,\|\, \sigma_{\text{SMC}}(S)\big)$ and $D_{\text{KL}}\big(\sigma(s^{\sigma}_{1:T}) \,\|\, q_{\text{SMC}}(s^{\sigma}_{1:T})\big) \leq D_{\text{KL}}\big(\sigma_{\text{SMC}}(S) \,\|\, q_{\text{SMC}}(S)\big)$ (Grosse et al., 2015, 2016). Thus, if the difference between upper and lower bounds on $\log \mathcal{Z}_\sigma$ is small, then we can conclude that the $K$-sample SMC or SIS procedures in Sec. 2.2 yield a single approximate sample whose distribution is close to the target in symmetrized KL divergence. (Footnote 3: Note that the difference between upper and lower bound yields $D_{\text{KL}}\big(q_{\text{SMC}}(S) \,\|\, \sigma_{\text{SMC}}(S)\big) + D_{\text{KL}}\big(\sigma_{\text{SMC}}(S) \,\|\, q_{\text{SMC}}(S)\big)$.)
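The sandwiching behavior can be illustrated in the simplest case without resampling, where the bounds reduce to the IWAE-style lower and upper bounds mentioned above. The sketch assumes a toy base model used as the proposal, a toy potential, and an exact target sample obtained by enumeration (which is only possible because the toy space is tiny). The lower bound averages the log mean weight over proposal samples only; the upper bound replaces one of the samples with an exact target sample.

```python
import itertools
import math
import random

V = (0, 1)  # toy vocabulary (assumption)
T = 3       # toy sequence length (assumption)

def p0_next_one(prefix):
    """Toy base model: probability that the next token is 1."""
    return 0.7 if (prefix and prefix[-1] == 1) else 0.3

def p0(seq):
    prob = 1.0
    for t, tok in enumerate(seq):
        p_one = p0_next_one(seq[:t])
        prob *= p_one if tok == 1 else 1.0 - p_one
    return prob

def phi(seq):
    """Toy terminal potential (assumption)."""
    return math.exp(sum(seq))

def sample_base(rng):
    seq = []
    for _ in range(T):
        seq.append(1 if rng.random() < p0_next_one(seq) else 0)
    return tuple(seq)

seqs = list(itertools.product(V, repeat=T))
target_w = [p0(s) * phi(s) for s in seqs]   # unnormalized target probabilities
log_z = math.log(sum(target_w))             # exact value, for reference only

def log_mean_weight(samples):
    # With q = p0, each importance weight equals phi(s).
    return math.log(sum(phi(s) for s in samples) / len(samples))

def bounds(K, reps, rng):
    lower = upper = 0.0
    for _ in range(reps):
        # Lower bound: all K samples come from the proposal.
        lower += log_mean_weight([sample_base(rng) for _ in range(K)])
        # Upper bound: one exact target sample plus K - 1 proposal samples.
        exact = rng.choices(seqs, weights=target_w, k=1)[0]
        others = [sample_base(rng) for _ in range(K - 1)]
        upper += log_mean_weight([exact] + others)
    return lower / reps, upper / reps

lower, upper = bounds(K=4, reps=4000, rng=random.Random(0))
```

Increasing `K` in this sketch should narrow the interval between `lower` and `upper` around the exact log partition function, mirroring the tightening described above.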
6 Related Work
In the previous sections, we have discussed related work as it fit within our SMC framework for language modeling. Note that Lew et al. (2023) consider SMC sampling for language models, but do not learn twist functions or proposals.
Decoding from language models to obtain diverse (Holtzman et al., 2019; Vilnis et al., 2023) or controlled generation (Zhang et al., 2023; Dathathri et al., 2019; Krause et al., 2020; Yang & Klein, 2021; Guo et al., 2021; Qin et al., 2022; Snell et al., 2022; Hu et al., 2023) is an active area of research. Our SMC resampling approach may be viewed as a principled probabilistic extension of best-of-$K$ decoding methods. Mudgal et al. (2023) propose a $K$-way decoding scheme based on ‘prefix scorers’ learned using Eq. 13, but also consider using these twists as logits for softmax sampling in the proposal. However, neither of these decoding schemes is aligned with our proposed SMC framework, as we discuss in detail in App. D. For example, greedy decoding with respect to the optimal twists in Prop. 3.2 does not yield samples from the target distribution $\sigma(s_{1:T})$.
Finally, RL-based methods such as PPO maintain both a policy or proposal network and a value network or advantage estimator during training. From the soft RL perspective in Sec. 3.4 and Sec. B.3, the soft values play a role similar to that of our twist functions for SMC resampling. Liu et al. (2023) consider using Monte Carlo Tree Search (MCTS) based on PPO value estimates to improve decoding, while Chaffin et al. (2022) consider discriminator-driven MCTS.
7 Experiments
We now illustrate empirically how our framework can be used to evaluate inference through bounds and KL divergences between the sampling and target distributions, providing meaningful quantitative comparison between various learning methods. We consider a range of tasks throughout this section, including toxic story generation (as an example of uncovering rare undesirable behavior), generating reviews with varied sentiment, and infilling. For the toxicity and infilling tasks, we consider the TinyStories model (Eldan & Li, 2023; https://huggingface.co/roneneldan/TinyStories-33M) as a small-scale model where the generation is coherent, and use the prompt ‘Once upon a time, there was a’. For the toxicity task, we elicit responses judged to be toxic by the classifier from Corrêa (2023) (https://huggingface.co/nicholasKluge/ToxicityModel). For the sentiment task, we consider the GPT2-Medium model (https://huggingface.co/gpt2-medium) and a classifier trained on Amazon reviews (https://huggingface.co/LiYuan/amazon-review-sentiment-analysis). Our code is available at https://github.com/Silent-Zebra/twisted-smc-lm.
7.1 Comparing SIS and SMC for Estimation
We first use our bounds to test how twisted SMC can improve upon SIS and efficiently sample rare events. We consider the task of toxic story generation. The target is defined as where , is the non-toxic logit, and the threshold corresponds to a greater than 99% chance of being toxic. Rejection sampling under yields exact samples for UB estimation, but can require hundreds of thousands of samples. Thus, this setting also allows us to test the effectiveness of approximate positive sampling for twist training when target samples are rare.
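To make the cost of exact sampling concrete, the following is a minimal sketch of rejection sampling for a thresholded (indicator) target. It is not the paper's code: the three-token vocabulary, the per-token probabilities, and the predicate `meets_threshold` (a stand-in for the classifier threshold) are all hypothetical. A draw from the base model is kept only if it satisfies the predicate, so accepted samples are exact target samples, and rare events require many base-model draws.

```python
import random

VOCAB = ["a", "b", "c"]
BASE_PROBS = [0.7, 0.25, 0.05]   # hypothetical per-token base model probabilities
SEQ_LEN = 4


def sample_from_base(rng):
    """Draw a full sequence from the toy base model (tokens are i.i.d. here)."""
    return tuple(rng.choices(VOCAB, weights=BASE_PROBS, k=SEQ_LEN))


def meets_threshold(seq):
    """Hypothetical stand-in for 'the classifier output clears the threshold'."""
    return seq.count("c") >= 2


def rejection_sample(rng, num_accepted):
    """Return exact target samples and the number of base-model draws they cost."""
    accepted, num_draws = [], 0
    while len(accepted) < num_accepted:
        seq = sample_from_base(rng)
        num_draws += 1
        if meets_threshold(seq):
            accepted.append(seq)
    return accepted, num_draws


if __name__ == "__main__":
    samples, draws = rejection_sample(random.Random(0), num_accepted=20)
    print(f"accepted {len(samples)} of {draws} draws")
```

With an indicator potential, the acceptance probability equals the partition function, so the fraction of accepted draws gives a rough sense of how rare the target event is under the base model.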
Fig. 2 demonstrates that training twists with CTL and approximate positive sampling can significantly improve log partition function estimation and sampling efficiency. We first note that both upper and lower bounds tighten as $K$ increases, as expected, for both SIS and SMC. Using the base model as proposal, the SIS LB (orange) generally fails to draw any samples meeting the threshold. By contrast, SMC resampling (red) with the base model proposal eventually achieves tight upper and lower bounds, yielding near-exact target samples (small KL divergence between the distribution over samples and the target distribution) by the reasoning in Sec. 5.
Proposal | Twist Learning | ||
---|---|---|---|
Twisted | Contrastive | ||
Twisted | RL | ||
Twisted | SIXO | ||
Twisted | FUDGE | ||
DPG | - | ||
PPO | - |
Proposal | Twist Learning | ||
---|---|---|---|
Twisted | Contrastive | ||
Twisted | RL | ||
Twisted | SIXO | ||
Twisted | FUDGE | ||
DPG | - | ||
PPO | - |
Proposal | Twist Learning | ||
---|---|---|---|
Twisted | Contrastive | ||
Twisted | RL | ||
Twisted | SIXO | ||
Twisted | FUDGE | ||
DPG | - | ||
PPO | - |
However, both SMC and SIS with the twist-induced proposal achieve tight estimation and near-exact sampling of the target toxic outputs with orders of magnitude lower $K$. Resampling does not appear to help or hurt these bounds, as the effect of the twists has been incorporated in the proposal in Eq. 15. Thus, we conclude that using the twist-induced proposal can provide significant efficiency gains over base model sampling.
7.2 Evaluating Twist-Induced or Variational Proposals
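We next leverage our bounds to evaluate single-sample inference using the KL divergences between the proposal and the target in both directions, as in Sec. 5.1. Across settings, we consider two SIS proposal-learning methods: PPO (Schulman et al., 2017), which minimizes the reverse KL divergence (an expectation under the proposal) during optimization, and distributional policy gradient (DPG), which minimizes the forward KL divergence (an expectation under the target) (Parshakova et al., 2019) (see App. E).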
We consider four twist learning methods, including CTL and RL from Sec. 4, SIXO (Lawson et al., 2022), and FUDGE (Yang & Klein, 2021) (see App. C). For each, we measure KL divergences involving the twist-induced proposal . Thus, these experiments showcase two complementary applications of SMC: as a novel inference method yielding a tractable , and as an evaluation method for any other inference method (such as PPO) using -sample bounds on to estimate the KL divergence.
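The role of the log partition function in KL evaluation can be seen from a standard decomposition, sketched here on a hypothetical three-state example (the distributions `P0`, `PHI`, and `Q` are made up for illustration, and everything is computed by enumeration rather than sampling). Each KL direction equals a tractable expectation plus or minus log Z, so substituting an upper bound on log Z in the reverse direction, or a lower bound in the forward direction, yields an upper bound on the corresponding divergence.

```python
import math

P0 = [0.5, 0.3, 0.2]     # hypothetical base distribution
PHI = [0.1, 1.0, 4.0]    # hypothetical potential
Q = [0.2, 0.3, 0.5]      # hypothetical inference (proposal) distribution

UNNORM = [p * f for p, f in zip(P0, PHI)]   # unnormalized target
LOG_Z = math.log(sum(UNNORM))
SIGMA = [u / math.exp(LOG_Z) for u in UNNORM]


def kl(a, b):
    """Exact KL divergence between two discrete distributions."""
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b))


def kl_q_sigma_from_log_z(log_z):
    """KL(q || sigma) = E_q[log q - log unnormalized target] + log Z."""
    return sum(q * (math.log(q) - math.log(u)) for q, u in zip(Q, UNNORM)) + log_z


def kl_sigma_q_from_log_z(log_z):
    """KL(sigma || q) = E_sigma[log unnormalized target - log q] - log Z."""
    return sum(s * (math.log(u) - math.log(q)) for s, u, q in zip(SIGMA, UNNORM, Q)) - log_z


if __name__ == "__main__":
    print(kl(Q, SIGMA), kl_q_sigma_from_log_z(LOG_Z))
    print(kl(SIGMA, Q), kl_sigma_q_from_log_z(LOG_Z))
```

Note that the forward direction takes an expectation under the target, which is why exact target samples are needed to estimate it in practice.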
7.2.1 Generating Toxic Stories
We consider toxic story generation as in Sec. 7.1, but using a target , where denotes the probability of the text being judged as toxic by a classifier. Compared to the thresholding target, this task provides a smoother gradient signal for learning (see Sec. G.3) but still allows for exact sampling via rejection sampling. We train using approximate positive sampling, but provide an ablation with exact positive sampling results in Sec. H.3.
We report KL divergences in Table 2. We observe that PPO learns the best proposal with respect to while our CTL method performs best with respect to , which is consistent with the divergences minimized during training. Finally, in Sec. H.1 we provide a qualitative example of a toxic story generated with CTL for with , a case where no exact samples are available.
7.2.2 Generation with Varied Sentiment
For the sentiment setting described earlier, we consider a prompt ‘I bought this’ and target , where indicates a 1-star review and exact samples are available by rejection sampling. We train using approximate positive sampling (see Sec. H.3 for comparison with exact). While all methods achieve low KL divergences in Table 3, CTL performs best for both directions.
7.2.3 Infilling
In this section, we demonstrate a conditional twist function parameterization, where takes input which identifies the target distribution as in Sec. 3.3. We consider an infilling task (Lew et al., 2023; Hu et al., 2023), where the observation variables correspond to continuation tokens, and their likelihood is evaluated under the base model, given generated . The target distribution corresponds to the posterior . Instead of training separate for each , we amortize learning of a conditional twist network .
A second distinctive feature of this setting is that we train from exact posterior or target samples, which are readily available using the BDMC trick in Sec. 3.3. In particular, we may sample sequences of length from the base model , and interpret the prefix as a target sample. Note that we do not explicitly control the continuation tokens defining the tasks. We evaluate average KL divergences over 2000 different , with and , and report results in Table 4.
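The sampling trick described above can be sketched on a toy base model. The example below is an illustration under stated assumptions, not the paper's setup: the base model is a hypothetical two-state Markov chain, and the lengths `T` (prefix) and `C` (continuation) are chosen small enough that the posterior can be enumerated. Sampling a full sequence from the base model and splitting it gives a pair (prefix, continuation) in which the prefix is an exact sample from the posterior given that continuation.

```python
import itertools
import random

T, C = 2, 1                       # hypothetical prefix and continuation lengths
INIT = [0.6, 0.4]                 # hypothetical initial token distribution
TRANS = [[0.9, 0.1], [0.3, 0.7]]  # hypothetical transition probabilities


def base_prob(seq):
    """Probability of a full token sequence under the toy Markov base model."""
    prob = INIT[seq[0]]
    for prev, nxt in zip(seq, seq[1:]):
        prob *= TRANS[prev][nxt]
    return prob


def sample_base(rng, length):
    """Ancestral sampling from the toy base model."""
    seq = [rng.choices([0, 1], weights=INIT)[0]]
    while len(seq) < length:
        seq.append(rng.choices([0, 1], weights=TRANS[seq[-1]])[0])
    return tuple(seq)


def sample_pair(rng):
    """Sample T + C tokens, then read the prefix as a posterior sample given the continuation."""
    full = sample_base(rng, T + C)
    return full[:T], full[T:]


def exact_posterior(obs):
    """Posterior over prefixes given the continuation, by enumeration."""
    joint = {p: base_prob(p + obs) for p in itertools.product([0, 1], repeat=T)}
    total = sum(joint.values())
    return {p: v / total for p, v in joint.items()}


if __name__ == "__main__":
    rng = random.Random(0)
    pairs = [sample_pair(rng) for _ in range(20000)]
    obs = (1,)
    prefixes = [p for p, o in pairs if o == obs]
    for prefix, prob in exact_posterior(obs).items():
        print(prefix, round(prefixes.count(prefix) / len(prefixes), 3), round(prob, 3))
```

The empirical prefix frequencies for a given continuation should approach the enumerated posterior, while the continuation itself is not controlled, matching the remark in the text.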
We find that DPG performs best for both directions of the KL divergence in this setting, likely due to its ability to leverage exact positive samples by minimizing . While CTL also learns from exact positive samples, it requires approximate negative sampling and only performs comparably to SIXO, which uses exact positive samples and performs exact negative sampling under . Finally, PPO trains from samples only, and performs relatively poorly with respect to . We show qualitative results in Sec. H.1 to correlate KL divergence results with sample quality.
Using our KL divergence evaluation methods, we conclude DPG may be preferable when exact target samples are available (Sec. 7.2.3, Sec. H.3), while CTL may be preferable with approximate positive sampling (Sec. 7.2.1, Sec. 7.2.2).
8 Conclusion
In this work, we have presented twisted SMC as a principled probabilistic inference framework for solving numerous capability and safety tasks in LLMs. After discussing different design choices for twisted SMC and their relation to related work, we proposed a novel contrastive method for twist learning. Furthermore, we have proposed novel bidirectional SMC bounds for evaluating LLM inference methods. We demonstrated the effectiveness of our methods quantitatively and qualitatively in both sampling and evaluation across a variety of experimental settings.
Acknowledgments
AM and RG acknowledge support from the Canada CIFAR AI Chairs program and from Open Philanthropy. SZ thanks Juhan Bae for helping debug memory issues in the code. Resources used in this research were provided, in part, by the Province of Ontario, the Government of Canada, and companies sponsoring the Vector Institute. We thank the anonymous reviewers for helpful comments on earlier versions of this paper.
Impact Statement
This paper is motivated by the social consequences of recent advances in the field of machine learning. Controlled generation from language models has the potential to improve safety through better steering of generation to human preferences, more efficient automated red-teaming, and the ability to estimate or bound probabilities of rare behaviors. Any such work is inherently a double-edged sword; the same techniques used to generate samples from a harmless distribution of text could, with a single sign change, be repurposed for generating samples from a harmful distribution of text. Thus, better controlled generation (in our framework, better sampling from target distributions) can provide benefits in the hands of responsible users but can also magnify harms in the hands of malevolent users (who have access to model weights).
Overall, we believe the potential positive social benefits of our work in evaluation and steering language model output towards desired target distributions outweigh the potential negatives stemming primarily from misuse.
References
- Andrieu et al. (2010) Andrieu, C., Doucet, A., and Holenstein, R. Particle markov chain monte carlo methods. Journal of the Royal Statistical Society Series B: Statistical Methodology, 72(3):269–342, 2010.
- Anil et al. (2021) Anil, C., Zhang, G., Wu, Y., and Grosse, R. Learning to give checkable answers with prover-verifier games. arXiv preprint arXiv:2108.12099, 2021.
- Bae et al. (2022) Bae, J., Zhang, M. R., Ruan, M., Wang, E., Hasegawa, S., Ba, J., and Grosse, R. B. Multi-rate vae: Train once, get the full rate-distortion curve. In The Eleventh International Conference on Learning Representations, 2022.
- Bai et al. (2022) Bai, Y., Jones, A., Ndousse, K., Askell, A., Chen, A., DasSarma, N., Drain, D., Fort, S., Ganguli, D., Henighan, T., et al. Training a helpful and harmless assistant with reinforcement learning from human feedback. arXiv preprint arXiv:2204.05862, 2022.
- Banerjee et al. (2005) Banerjee, A., Guo, X., and Wang, H. On the optimality of conditional expectation as a bregman predictor. IEEE Transactions on Information Theory, 51(7), 2005.
- Brekelmans et al. (2022) Brekelmans, R., Huang, S., Ghassemi, M., Ver Steeg, G., Grosse, R. B., and Makhzani, A. Improving mutual information estimation with annealed and energy-based bounds. In International Conference on Learning Representations, 2022.
- Briers et al. (2010) Briers, M., Doucet, A., and Maskell, S. Smoothing algorithms for state–space models. Annals of the Institute of Statistical Mathematics, 62:61–89, 2010.
- Burda et al. (2015) Burda, Y., Grosse, R., and Salakhutdinov, R. Importance weighted autoencoders. arXiv preprint arXiv:1509.00519, 2015.
- Chaffin et al. (2022) Chaffin, A., Claveau, V., and Kijak, E. Ppl-mcts: Constrained textual generation through discriminator-guided mcts decoding. In NAACL 2022-Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, 2022.
- Chopin et al. (2020) Chopin, N., Papaspiliopoulos, O., et al. An introduction to sequential Monte Carlo, volume 4. Springer, 2020.
- Cobbe et al. (2021) Cobbe, K., Kosaraju, V., Bavarian, M., Chen, M., Jun, H., Kaiser, L., Plappert, M., Tworek, J., Hilton, J., Nakano, R., et al. Training verifiers to solve math word problems. arXiv preprint arXiv:2110.14168, 2021.
- Corrêa (2023) Corrêa, N. K. Aira, 2023. URL https://huggingface.co/nicholasKluge/ToxicityModel.
- Dathathri et al. (2019) Dathathri, S., Madotto, A., Lan, J., Hung, J., Frank, E., Molino, P., Yosinski, J., and Liu, R. Plug and play language models: A simple approach to controlled text generation. In International Conference on Learning Representations, 2019.
- Del Moral et al. (2006) Del Moral, P., Doucet, A., and Jasra, A. Sequential monte carlo samplers. Journal of the Royal Statistical Society Series B: Statistical Methodology, 68(3):411–436, 2006.
- Deng & Raffel (2023) Deng, H. and Raffel, C. Reward-augmented decoding: Efficient controlled text generation with a unidirectional reward model. In The 2023 Conference on Empirical Methods in Natural Language Processing, 2023.
- Dohan et al. (2022) Dohan, D., Xu, W., Lewkowycz, A., Austin, J., Bieber, D., Lopes, R. G., Wu, Y., Michalewski, H., Saurous, R. A., Sohl-Dickstein, J., et al. Language model cascades. arXiv preprint arXiv:2207.10342, 2022.
- Domke & Sheldon (2018) Domke, J. and Sheldon, D. R. Importance weighting and variational inference. Advances in neural information processing systems, 31, 2018.
- Doucet et al. (2001) Doucet, A., De Freitas, N., Gordon, N. J., et al. Sequential Monte Carlo methods in practice, volume 1. Springer, 2001.
- Eikema et al. (2022) Eikema, B., Kruszewski, G., Dance, C. R., Elsahar, H., and Dymetman, M. An approximate sampler for energy-based models with divergence diagnostics. Transactions on Machine Learning Research, 2022.
- Eldan & Li (2023) Eldan, R. and Li, Y. Tinystories: How small can language models be and still speak coherent english? arXiv preprint arXiv:2305.07759, 2023.
- Finke (2015) Finke, A. On extended state-space constructions for Monte Carlo methods. PhD thesis, University of Warwick, 2015.
- Go et al. (2023) Go, D., Korbak, T., Kruszewski, G., Rozen, J., Ryu, N., and Dymetman, M. Aligning foundation models for language with preferences through $f$-divergence minimization. In International Conference on Machine Learning, 2023.
- Grosse et al. (2015) Grosse, R. B., Ghahramani, Z., and Adams, R. P. Sandwiching the marginal likelihood using bidirectional monte carlo. arXiv preprint arXiv:1511.02543, 2015.
- Grosse et al. (2016) Grosse, R. B., Ancha, S., and Roy, D. Measuring the reliability of mcmc inference with bidirectional monte carlo. Advances in Neural Information Processing Systems, 2016.
- Gu et al. (2015) Gu, S. S., Ghahramani, Z., and Turner, R. E. Neural adaptive sequential monte carlo. Advances in neural information processing systems, 28, 2015.
- Guo et al. (2021) Guo, H., Tan, B., Liu, Z., Xing, E. P., and Hu, Z. Efficient (soft) q-learning for text generation with limited good data. arXiv preprint arXiv:2106.07704, 2021.
- Gutmann & Hyvärinen (2010) Gutmann, M. and Hyvärinen, A. Noise-contrastive estimation: A new estimation principle for unnormalized statistical models. In International conference on artificial intelligence and statistics, pp. 297–304. JMLR Workshop and Conference Proceedings, 2010.
- Haarnoja et al. (2018) Haarnoja, T., Zhou, A., Abbeel, P., and Levine, S. Soft actor-critic: Off-policy maximum entropy deep reinforcement learning with a stochastic actor. In International conference on machine learning. PMLR, 2018.
- Heng et al. (2020) Heng, J., Bishop, A., Deligiannidis, G., and Doucet, A. Controlled sequential monte carlo. Annals of Statistics, 48(5), 2020.
- Holtzman et al. (2019) Holtzman, A., Buys, J., Du, L., Forbes, M., and Choi, Y. The curious case of neural text degeneration. In International Conference on Learning Representations, 2019.
- Hu et al. (2023) Hu, E. J., Jain, M., Elmoznino, E., Kaddar, Y., Lajoie, G., Bengio, Y., and Malkin, N. Amortizing intractable inference in large language models. arXiv preprint arXiv:2310.04363, 2023.
- Khalifa et al. (2020) Khalifa, M., Elsahar, H., and Dymetman, M. A distributional approach to controlled text generation. arXiv preprint arXiv:2012.11635, 2020.
- Khanov et al. (2024) Khanov, M., Burapacheep, J., and Li, Y. ARGS: Alignment as reward-guided search. In The Twelfth International Conference on Learning Representations, 2024. URL https://openreview.net/forum?id=shgx0eqdw6.
- Korbak et al. (2022a) Korbak, T., Elsahar, H., Kruszewski, G., and Dymetman, M. Controlling conditional language models without catastrophic forgetting. In International Conference on Machine Learning, pp. 11499–11528. PMLR, 2022a.
- Korbak et al. (2022b) Korbak, T., Perez, E., and Buckley, C. L. Rl with kl penalties is better viewed as bayesian inference. arXiv preprint arXiv:2205.11275, 2022b.
- Krause et al. (2020) Krause, B., Gotmare, A. D., McCann, B., Keskar, N. S., Joty, S., Socher, R., and Rajani, N. F. Gedi: Generative discriminator guided sequence generation. arXiv preprint arXiv:2009.06367, 2020.
- Lawson et al. (2018) Lawson, D., Tucker, G., Naesseth, C. A., Maddison, C., Adams, R. P., and Teh, Y. W. Twisted variational sequential monte carlo. In Third workshop on Bayesian Deep Learning (NeurIPS), 2018.
- Lawson et al. (2022) Lawson, D., Raventós, A., Warrington, A., and Linderman, S. Sixo: Smoothing inference with twisted objectives, 2022.
- Levine (2018) Levine, S. Reinforcement learning and control as probabilistic inference: Tutorial and review. arXiv preprint arXiv:1805.00909, 2018.
- Lew et al. (2023) Lew, A. K., Zhi-Xuan, T., Grand, G., and Mansinghka, V. K. Sequential monte carlo steering of large language models using probabilistic programs. arXiv preprint arXiv:2306.03081, 2023.
- Lioutas et al. (2022) Lioutas, V., Lavington, J. W., Sefas, J., Niedoba, M., Liu, Y., Zwartsenberg, B., Dabiri, S., Wood, F., and Scibior, A. Critic sequential monte carlo. In The Eleventh International Conference on Learning Representations, 2022.
- Liu et al. (2021) Liu, A., Sap, M., Lu, X., Swayamdipta, S., Bhagavatula, C., Smith, N. A., and Choi, Y. Dexperts: Decoding-time controlled text generation with experts and anti-experts. In 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing, 2021.
- Liu et al. (2023) Liu, J., Cohen, A., Pasunuru, R., Choi, Y., Hajishirzi, H., and Celikyilmaz, A. Don’t throw away your value model! making ppo even better via value-guided monte-carlo tree search decoding. arXiv e-prints, pp. arXiv–2309, 2023.
- Maddison et al. (2017) Maddison, C. J., Lawson, J., Tucker, G., Heess, N., Norouzi, M., Mnih, A., Doucet, A., and Teh, Y. Filtering variational objectives. Advances in Neural Information Processing Systems, 30, 2017.
- Mudgal et al. (2023) Mudgal, S., Lee, J., Ganapathy, H., Li, Y., Wang, T., Huang, Y., Chen, Z., Cheng, H.-T., Collins, M., Strohman, T., et al. Controlled decoding from language models. arXiv preprint arXiv:2310.17022, 2023.
- Nachum et al. (2017) Nachum, O., Norouzi, M., Xu, K., and Schuurmans, D. Bridging the gap between value and policy based reinforcement learning. Advances in neural information processing systems, 30, 2017.
- Ouyang et al. (2022) Ouyang, L., Wu, J., Jiang, X., Almeida, D., Wainwright, C., Mishkin, P., Zhang, C., Agarwal, S., Slama, K., Ray, A., et al. Training language models to follow instructions with human feedback. Advances in Neural Information Processing Systems, 35:27730–27744, 2022.
- Parshakova et al. (2019) Parshakova, T., Andreoli, J.-M., and Dymetman, M. Distributional reinforcement learning for energy-based sequential models. arXiv preprint arXiv:1912.08517, 2019.
- Perez et al. (2022) Perez, E., Huang, S., Song, F., Cai, T., Ring, R., Aslanides, J., Glaese, A., McAleese, N., and Irving, G. Red teaming language models with language models. In Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing, pp. 3419–3448, 2022.
- Phan et al. (2023) Phan, D., Hoffman, M. D., Douglas, S., Le, T. A., Parisi, A. T., Sountsov, P., Sutton, C., Vikram, S., Saurous, R. A., et al. Training chain-of-thought via latent-variable inference. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
- Piché et al. (2018) Piché, A., Thomas, V., Ibrahim, C., Bengio, Y., and Pal, C. Probabilistic planning with sequential monte carlo methods. In International Conference on Learning Representations, 2018.
- Qin et al. (2022) Qin, L., Welleck, S., Khashabi, D., and Choi, Y. Cold decoding: Energy-based constrained text generation with langevin dynamics. Advances in Neural Information Processing Systems, 35:9538–9551, 2022.
- Rafailov et al. (2023) Rafailov, R., Sharma, A., Mitchell, E., Ermon, S., Manning, C. D., and Finn, C. Direct preference optimization: Your language model is secretly a reward model. arXiv preprint arXiv:2305.18290, 2023.
- Schulman et al. (2017) Schulman, J., Wolski, F., Dhariwal, P., Radford, A., and Klimov, O. Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347, 2017.
- Shih et al. (2023) Shih, A., Sadigh, D., and Ermon, S. Long horizon temperature scaling. arXiv preprint arXiv:2302.03686, 2023.
- Snell et al. (2022) Snell, C. V., Kostrikov, I., Su, Y., Yang, S., and Levine, S. Offline rl for natural language generation with implicit language q learning. In The Eleventh International Conference on Learning Representations, 2022.
- Sobolev & Vetrov (2019) Sobolev, A. and Vetrov, D. P. Importance weighted hierarchical variational inference. Advances in Neural Information Processing Systems, 32, 2019.
- Stiennon et al. (2020) Stiennon, N., Ouyang, L., Wu, J., Ziegler, D., Lowe, R., Voss, C., Radford, A., Amodei, D., and Christiano, P. F. Learning to summarize with human feedback. Advances in Neural Information Processing Systems, 33:3008–3021, 2020.
- Vilnis et al. (2023) Vilnis, L., Zemlyanskiy, Y., Murray, P., Passos, A. T., and Sanghai, S. Arithmetic sampling: parallel diverse decoding for large language models. In International Conference on Machine Learning. PMLR, 2023.
- Whiteley & Lee (2014) Whiteley, N. and Lee, A. Twisted particle filters. 2014.
- Yang & Klein (2021) Yang, K. and Klein, D. Fudge: Controlled text generation with future discriminators. In Proceedings of the 2021 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pp. 3511–3535, 2021.
- Zhang et al. (2023) Zhang, H., Song, H., Li, S., Zhou, M., and Song, D. A survey of controllable text generation using transformer-based pre-trained language models. ACM Computing Surveys, 56(3):1–37, 2023.
- Ziegler et al. (2019) Ziegler, D. M., Stiennon, N., Wu, J., Brown, T. B., Radford, A., Amodei, D., Christiano, P., and Irving, G. Fine-tuning language models from human preferences. arXiv preprint arXiv:1909.08593, 2019.
- Zou et al. (2023) Zou, A., Wang, Z., Kolter, J. Z., and Fredrikson, M. Universal and transferable adversarial attacks on aligned language models. arXiv preprint arXiv:2307.15043, 2023.
Appendix
Type | Target | References / Examples |
Reward | RLHF (Ziegler et al., 2019; Ouyang et al., 2022; Korbak et al., 2022b) | |
Continuation | Generates tokens based on likelihood of future tokens | |
For , this is in-filling (Lew et al., 2023). | ||
As , disregard , focus on of continuation prob. | ||
- similar to adversarial prompt generation (Zou et al., 2023) | ||
Indicator | Generations from this target must satisfy the properties of set . | |
where is indicator of set : | - Meeting reward threshold | |
if | - Containing topical or specific words in | |
if | - Having certain structure or rhyme (Yang & Klein, 2021), | |
- Valid output according to verifier (Cobbe et al., 2021; Dohan et al., 2022) | ||
Classifier | Class can be a binary (e.g. toxicity) or multinomial (e.g. 1-5 star reviews) | |
Bayesian posterior for : | ||
(Dathathri et al., 2019; Krause et al., 2020; Liu et al., 2021) | ||
Global | Tempering on entire sequences (long-horizon) vs. per-token (myopic) | |
Temperature | - yields higher quality generation in Shih et al. (2023) | |
Distributional | KL minimization subj. expectation constraints on | |
s.t. | ||
( = optimal Lagrange multipliers for constraints ) | ||
e.g. gender roles/references (Khalifa et al., 2020) | ||
Intermediate | References / Examples | |
Indicator | or | words of specific length, or specific sets of tokens |
(Khalifa et al., 2020; Lew et al., 2023) | ||
Product of | ||
Experts | prompt intersection (Lew et al., 2023) |
Appendix A Proofs
In this section, we present the sense in which the target marginals correspond to the optimal intermediate distributions in twisted SMC. We defer proof of Prop. 3.2 from the main text to a slightly more general version in Sec. B.1 Prop. B.1, although Prop. A.4 provides the analogous statement in terms of the intermediate target distributions instead of the optimal twists .
We also prove Prop. 3.3 from the main text in Sec. A.2 and derive the gradient of the CTL loss (Eq. 22) in Sec. A.3.
A.1 Proof for Optimal Intermediate Target Distributions
In order to achieve sampling from the full joint distribution , each intermediate target must match the intermediate marginal . To formalize this notion, we provide the following definition of optimality, justified by the fact that it yields an exact partition function estimator.
To do so, we will consider the multi-step importance weights
(-Step SMC Weights) |
using a telescoping cancellation in the final equality. The one-step weights correspond to , denoted simply as .
Definition A.1 (Optimal Twisted SMC Sampling).
For a given target distribution , we refer to a twisted SMC procedure, or (with or ), as optimal if -step importance weights for all and .
Note that the role of and is specified in Def. 3.1. We assume for the goal of estimating , and show below that an optimal twisted SMC procedure yields an exact partition function estimator.
Proposition A.2 (Optimal SMC yields Exact Partition Function Estimation).
For any optimal twisted SMC procedure, the resulting estimator of the partition function has zero bias and zero variance.
Proof.
As in Footnote 1 or App. F 30, consider index timesteps where resampling occurs and fix and . The SMC estimator of becomes for . Using the optimality definition in Def. A.1, we have for all partial sequences . Noting the telescoping multiplicative cancellation and the fact that is constant with respect to indices , we have the following estimator for a single run of an optimal SMC procedure,
(25) |
as desired, assuming . Since is independent of , we conclude has zero bias and zero variance.
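The zero-variance property can be checked numerically on a small example. The sketch below makes several assumptions that are not taken from this section: a hypothetical binary vocabulary with a Markov base model, a terminal potential only, optimal twists computed by exact enumeration (the expected terminal potential given the prefix), a twist-induced proposal proportional to the base probability times the twist, the convention that the twist of the empty prefix is 1, and multinomial resampling after every step. Under these assumptions every run of the estimator returns the exact partition function, whatever the random seed.

```python
import math
import random

VOCAB = (0, 1)
T = 3
INIT = [0.6, 0.4]                 # hypothetical initial token distribution
TRANS = [[0.9, 0.1], [0.3, 0.7]]  # hypothetical transition probabilities


def p0_next(prefix, token):
    """Base model probability of the next token given the prefix."""
    return INIT[token] if not prefix else TRANS[prefix[-1]][token]


def phi(seq):
    """Hypothetical strictly positive terminal potential."""
    return math.exp(2.0 * sum(seq))


def future_value(prefix):
    """Expected terminal potential under the base model given the prefix (Z for the empty prefix)."""
    if len(prefix) == T:
        return phi(prefix)
    return sum(p0_next(prefix, tok) * future_value(prefix + (tok,)) for tok in VOCAB)


def twist(prefix):
    """Optimal twist, with the assumed convention that the empty prefix has twist 1."""
    return 1.0 if not prefix else future_value(prefix)


def smc_estimate(rng, num_particles):
    """Twisted SMC estimate of Z with the twist-induced proposal and resampling at every step."""
    particles = [()] * num_particles
    z_hat = 1.0
    for _ in range(T):
        weights, extended = [], []
        for prefix in particles:
            scores = [p0_next(prefix, tok) * twist(prefix + (tok,)) for tok in VOCAB]
            tok = rng.choices(VOCAB, weights=scores)[0]
            q_tok = scores[tok] / sum(scores)              # twist-induced proposal probability
            # Incremental weight: p0 * psi_t / (psi_{t-1} * q).
            weights.append(scores[tok] / (twist(prefix) * q_tok))
            extended.append(prefix + (tok,))
        z_hat *= sum(weights) / num_particles
        particles = rng.choices(extended, weights=weights, k=num_particles)
    return z_hat


if __name__ == "__main__":
    print("exact Z:", future_value(()))
    print("SMC estimates:", [smc_estimate(random.Random(seed), 8) for seed in range(3)])
```

Here the first incremental weight equals the partition function for every particle and all later weights equal one, so the product of averaged weights is constant across runs.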
With this notion of optimality in mind, we demonstrate the following necessary and sufficient conditions.
Proposition A.3 (Optimality Conditions).
The following conditions are necessary and sufficient for twisted SMC optimality,
(26) |
Proof.
(Necessary) Optimal Twisted SMC : We begin by writing the marginalization of the unnormalized density over prefixes of length as
The normalization constant of can easily be seen to be after summing over above, which yields . We now factorize the -step incremental importance weights (at step , see Eq. -Step SMC Weights) using the above identities, which imply that and
(27) |
In order to have in general, we thus require and for all and .
(Sufficient) Optimal Twisted SMC: Consider the incremental importance weights using and ,
(28) |
which matches the optimality definition in Def. A.1. ∎
Proposition A.4 (Optimal Intermediate Target Distributions).
Proof.
: It is clear that as a special case for . To show , we have
Recursively applying until time suggests that
The target marginals clearly satisfy the recursion
∎
A.2 Proof of Twist-Induced Proposal
See Prop. 3.3.
Proof.
We seek to minimize the variance of the resulting importance weights, subject to a constraint on the proposal probabilities summing to 1. Introducing a Lagrange multiplier , we have
|
Taking implies
where the derivative in the second term is zero since the cancel. Finally, we have
where (or ) is chosen to enforce normalization. ∎
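As a numerical check of this result, the sketch below compares the variance of the one-step importance weights under different proposals for a single fixed prefix. The numbers are hypothetical, and the weight is assumed to take the form base probability times current twist, divided by the previous twist times the proposal probability; the names `P0_NEXT`, `TWIST_NEXT`, and `TWIST_PREV` are introduced only for this illustration. The proposal proportional to the base probability times the twist makes the weight constant across tokens, so its variance vanishes.

```python
P0_NEXT = [0.7, 0.2, 0.1]      # hypothetical base model next-token probabilities
TWIST_NEXT = [0.05, 1.0, 6.0]  # hypothetical twist values for each candidate token
TWIST_PREV = 1.3               # hypothetical twist value of the current prefix


def one_step_weights(q):
    """Incremental weight for each candidate token under proposal q."""
    return [p * psi / (TWIST_PREV * qi) for p, psi, qi in zip(P0_NEXT, TWIST_NEXT, q)]


def weight_variance(q):
    """Variance of the incremental weight when the token is drawn from q."""
    w = one_step_weights(q)
    mean = sum(qi * wi for qi, wi in zip(q, w))
    return sum(qi * (wi - mean) ** 2 for qi, wi in zip(q, w))


def twist_induced_proposal():
    """Proposal proportional to base probability times twist, normalized over tokens."""
    scores = [p * psi for p, psi in zip(P0_NEXT, TWIST_NEXT)]
    total = sum(scores)
    return [s / total for s in scores]


if __name__ == "__main__":
    print("base proposal variance:   ", weight_variance(P0_NEXT))
    print("uniform proposal variance:", weight_variance([1 / 3, 1 / 3, 1 / 3]))
    print("twist-induced variance:   ", weight_variance(twist_induced_proposal()))
```

The mean weight is the same under every proposal, so only the variance distinguishes them, which is why the variance is the natural objective in the constrained minimization above.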
We focused on the one-step twist-induced proposal in Prop. 3.3. However, this proposal is not optimal for resampling every steps (as would also occur, for example, with adaptive resampling).
Proposition A.5 (Multi-Step Twist Induced Proposal (Generalization of Prop. 3.3)).
For resampling -steps ahead, the optimal proposal (over ) which minimizes the variance of the importance weights is given by
The proof follows the same reasoning as in the proof of Prop. 3.3 above, using the multistep weights from Eq. -Step SMC Weights.
Note that the denominator is not usually tractable for in language modeling applications.
A.3 Derivation of CTL Gradient
Lemma A.6 (Derivation of CTL Gradient).
For the CTL loss , the (negative) gradient with respect to the parameters is given by
(30) |
Proof.
Consider expanding the form of using Eq. 9, noting that the normalization is independent of . Taking the gradient with respect to using the log derivative identity , we have
∎
Appendix B SMC with Intermediate Potentials and Connection with Soft Reinforcement Learning
In the main text, we focused on settings where the target distribution is defined by a potential depending on full sequences only, as in Eq. 1. This setting highlights the need for (learned) twist functions to summarize the future expected value of the potential in the absence of intermediate target information.
In this appendix, we generalize our exposition to show how our twisted SMC framework can accommodate settings with intermediate potentials, which is evocative of connections with soft reinforcement learning (Levine, 2018). We leverage intuition from soft RL while introducing our general probabilistic interpretation, by using to instantiate the soft RL special case. In particular, soft RL will correspond to the terminal potential
(soft RL Definition) |
which corresponds to if the potential is given at the final step only (as in RLHF, Korbak et al. (2022b)). However, we defer detailed discussion of soft RL to Sec. B.3. See Table 5 for several examples of intermediate potentials.
Finally, we formalize a notion of conditional target distributions and twist functions in Sec. B.2, which generalizes the exposition in the main text and captures our conditional twist learning experiments in Sec. 7.2.3.
B.1 Twisted SMC with Intermediate Potentials
To generalize the exposition in the main text, we might consider defining the target as
(31) |
where Eq. 1 and the main text exposition corresponds to for .
Optimal Twists with Intermediate Potentials
Using Eq. 31, the marginal distribution over tokens becomes
(32) | ||||
(soft RL special case) |
As in Prop. 3.2, the goal of the optimal twist functions is to facilitate sampling from the intermediate marginals of the target distribution .
We consider two different quantities involved in defining the optimal twists, which differ in their treatment of the intermediate reward. For the soft RL setting, this corresponds to the natural distinction between -values and (soft) value functions .
(33) |
where means ‘defined to be proportional to’ and in RL notation. See Sec. B.3 for detailed derivations in the soft RL special case. In general, captures the expectation of future potentials from , analogous to the (soft) value function. The twists play a role analogous to a -value, estimating both the immediate and future value . In particular,
(34) |
We continue to refer to as the twist functions and focus on probabilistic interpretations based on instead of (see Sec. B.4 for additional discussion).
To show that this notation is consistent with the main text, consider the optimal twists with no intermediate potentials, for . For , reflect the future expected potential and for , the terminal potential is , with no future potentials after step , i.e. .
Building on Eqs. 32-33 above, the following generalization of Prop. 3.2 defines the ‘optimal’ twists so as to obtain the intermediate target marginals (see Prop. A.4).
Proposition B.1 (Optimal Twists).
For a given target distribution in Eq. 31, the optimal twist functions yield intermediate which match the target marginals. In regions where , the optimal twists are given by
(35) |
Up to a constant independent of , the optimal twists are given by
(36) |
where is absorbed into the normalization constant . The optimal twists satisfy the recursion
(37) |
Remark B.2 (Equivalence Class of and ).
Note that any rescaling of by a constant with respect to will yield the same intermediate marginals , due to the normalization constant which scales with . This defines an equivalence class in the space of functions. The same statement holds for . We express results such as Eq. 36 using proportionality . We define and as the members of their equivalence classes whose normalization and are equal. Thus, we have .
Proof.
Substituting Eq. 36 into Eq. 35, we obtain the desired marginal Eq. 32,
where the final equality follows from absorbing the constant into , with and which normalizes . We will now use to show the recursion in Eq. 37. Note that Eq. 36 implies
where the second line follows from . This demonstrates Eq. 37. ∎
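As a sanity check of the recursion in Eq. 37, the following sketch enumerates a small toy model and compares twists obtained by brute-force marginalization with twists obtained recursively. The base model table `P0`, the potentials `PHI`, and the sizes `V` and `T` are hypothetical. We assume the unnormalized form in which the twist at step t equals the immediate potential times the expected product of future potentials under the base model.

```python
import itertools
import numpy as np

V, T = 3, 3  # toy vocabulary size and sequence length (assumed values)
rng = np.random.default_rng(0)

def seqs(t):
    """All token tuples of length t."""
    return itertools.product(range(V), repeat=t)

# Hypothetical base model: next-token distribution for every prefix of length < T.
P0 = {p: rng.dirichlet(np.ones(V)) for t in range(T) for p in seqs(t)}
# Hypothetical positive intermediate potentials for every prefix of length 1..T.
PHI = {p: rng.uniform(0.5, 2.0) for t in range(1, T + 1) for p in seqs(t)}

def twist_brute_force(prefix):
    """Immediate potential times the base-model expectation of all future potentials."""
    total = 0.0
    for future in seqs(T - len(prefix)):
        seq, weight = prefix, 1.0
        for tok in future:
            weight *= P0[seq][tok]   # probability of the next token
            seq = seq + (tok,)
            weight *= PHI[seq]       # potential of the extended prefix
        total += weight
    return PHI[prefix] * total

def twist_recursive(prefix):
    """One-step recursion: immediate potential times the expected next-step twist."""
    if len(prefix) == T:
        return PHI[prefix]
    expected_next = sum(P0[prefix][k] * twist_recursive(prefix + (k,)) for k in range(V))
    return PHI[prefix] * expected_next

max_gap = max(
    abs(twist_brute_force(p) - twist_recursive(p))
    for t in range(1, T + 1) for p in seqs(t)
)
print(f"largest disagreement between the two computations: {max_gap:.2e}")
```

Setting every intermediate entry of `PHI` to one recovers the terminal-potential setting of the main text.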
This leads to the following definition of the intermediate twisting targets (we defer the soft RL special case to Sec. B.3).
Definition B.3 (Twisted Intermediate Targets ).
Using approximate twist functions , we define the twisted intermediate target distributions
(Twist Targets () ) | |||
One-Step Twist-Induced Proposal
Using Prop. 3.3 and Def. B.3 and noting that is independent of , we have the optimal one-step proposal
(Twist-Induced Proposal () ) |
where in the second line, we absorb terms which depend only on (and not ) into the normalization. In the soft RL special case, we have (see Eq. Twist-Induced Proposal (soft RL) below).
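The twist-induced proposal can be sketched in a few lines for a discrete vocabulary. The function name `twist_induced_proposal` and its array arguments are illustrative assumptions: `base_logits` stands for the base model's next-token logits and `log_twists` for the log twist value of each candidate next token.

```python
import numpy as np

def twist_induced_proposal(base_logits, log_twists):
    """Return log q(next token | prefix) and the one-step log normalizer.

    q is proportional to the base next-token probability times the twist
    evaluated on the prefix extended by each candidate token.
    """
    base_logp = base_logits - np.logaddexp.reduce(base_logits)  # log-softmax
    unnormalized = base_logp + log_twists
    log_norm = np.logaddexp.reduce(unnormalized)
    return unnormalized - log_norm, log_norm

# Toy example with three tokens (values are arbitrary).
base_logits = np.log(np.array([0.5, 0.3, 0.2]))
log_twists = np.log(np.array([1.0, 2.0, 4.0]))
log_q, log_norm = twist_induced_proposal(base_logits, log_twists)
print(np.exp(log_q), np.exp(log_norm))
```

Terms that depend only on the prefix shift all entries of `unnormalized` equally and therefore cancel in the normalization.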
B.2 Conditional Twisted SMC
To formalize our notion of conditional twists in the infilling experiments (Sec. 7.2.3), we generalize our above framework to explicitly depend on ‘observation’ random variables . This matches the common setting of SMC in state-space models (Briers et al., 2010; Gu et al., 2015; Lawson et al., 2022; Chopin et al., 2020). Our derivations in this section also emphasize that the optimal twist functions in Prop. B.1 learn functions proportional to conditional likelihoods of the future observation variables given the current sequence (see Eq. 40 below). We recover the unconditional targets in the main text for fixed .
Consider a target distribution conditioned on particular observation random variables . We define a probabilistic model over observations as the intermediate potential (footnote 8: rescaling by a constant with respect to does not affect the target posterior in Eq. 38; for example, with terminal potential only: as long as the scaling factor is independent of and ), which yields the target posterior
(38) |
where we interpret and to make the Bayesian posterior explicit in the last equality. Note, we now seek to estimate a different partition function for each set of observation variables.
Using our infilling experiments in Sec. 7.2.3 as an example, consider (a sequence of) subsequent tokens as observation variables, where the observation model is simply the base language model .
The optimal twists take a form similar to Prop. B.1, but now as a function of the future observation or conditioning information. Further, the optimal twists are proportional to the conditional likelihoods (e.g., ) of future observations given , which marginalize over future tokens (e.g., ),
(40) |
where denotes proportionality up to a constant which depends on only: . These equations can be confirmed by comparing Prop. B.1 with the last two lines in Eq. 39.
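To illustrate the conditional-likelihood interpretation in the infilling setting, the sketch below treats the token following a length-`T` sequence as the observation and computes the twist for a prefix by marginalizing over the unobserved tokens. The toy base model `P0` and the helper names are hypothetical. The check at the end confirms that averaging the twist over prefixes drawn from the base model gives the same evidence at every step, as expected when the twist is exactly the likelihood of the observation given the prefix.

```python
import itertools
import numpy as np

V, T = 2, 3  # toy vocabulary size and number of infilled tokens (assumed)
rng = np.random.default_rng(1)

def seqs(t):
    return itertools.product(range(V), repeat=t)

# Hypothetical base model over T + 1 tokens; the last token plays the role of the observation.
P0 = {p: rng.dirichlet(np.ones(V)) for t in range(T + 1) for p in seqs(t)}

def prob(prefix, continuation):
    """Base-model probability of `continuation` given `prefix`."""
    out, seq = 1.0, prefix
    for tok in continuation:
        out *= P0[seq][tok]
        seq = seq + (tok,)
    return out

def conditional_twist(prefix, obs):
    """Likelihood of the observation given the prefix, summing over the missing tokens."""
    missing = T - len(prefix)
    return sum(prob(prefix, fut + (obs,)) for fut in seqs(missing))

obs = 1
evidence = conditional_twist((), obs)  # marginal probability of the observation
per_step = [
    sum(prob((), p) * conditional_twist(p, obs) for p in seqs(t))
    for t in range(1, T + 1)
]
print(evidence, per_step)
```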
The intermediate marginals over partial sequences can finally be rewritten as either
(41) |
We discuss the choice of parameterization using versus in Sec. B.4.
The conditional twist learning formulation matches the setting of Lawson et al. (2022), to which we refer the reader for additional discussion. We use this conditional perspective to derive classification losses for twist learning in Sec. C.3-C.4.
Unconditional Targets as a Special Case
In cases where we are only learning twists for a single set of conditioning information such as a single classifier label or a reward model, note that we can omit explicit conditioning information in and consider setting . With terminal potential only as in the main text, we write and the overall target distribution as . Thus, the formulation in Eq. 38-Eq. 40 strictly generalizes our exposition in the main text and Sec. B.1. With intermediate potentials, we set .
B.3 Connection with Soft Reinforcement Learning
In this section, we more explicitly describe the soft reinforcement learning setting (Levine, 2018) commonly used in RLHF (Korbak et al., 2022b) as a special case of our probabilistic framework. Again, we use notation to indicate that the expressions in this section correspond to a particular instance of our SMC framework where .
Summary of Soft RL Notation
To summarize the below derivations, we state the relevant assignments for the soft RL case. We focus on the optimal case for simplicity, but note that approximate versions play identical roles,
(Twist to Soft RL) |
where or . In the other direction, we have
(Soft RL to Twist) |
MDP Interpretation
To draw connections with soft RL, we view language model controlled decoding as an MDP, where the prompt is drawn from an initial state distribution , an action policy selects the next token given a partial sequence as the state, and deterministic environment transitions append the selected token to update the state. Discounting may also be included without difficulty. The reward is given by .
Final Target Distribution
We define the target distribution as the solution to the following variational optimization which solves the regularized MDP described above,
(42) | ||||
which corresponds to the choice as in Eq. Twist to Soft RL. The soft value is defined as the maximum value of the above optimization for optimal , and corresponds to the scaled log partition function | ||||
(43) |
which can be confirmed by substituting from Eq. 42 into the maximization on the right side of Eq. 43. Although we omit the dependence of on the prompt for notational simplicity (see Eq. 1), note that naturally corresponds to the soft value of the prompt as the initial state in the MDP.
Optimal Intermediate Marginals and Soft Value
Decomposing the maximization in Eq. 43 into optimizations over each , we define the intermediate soft value as the maximum of the expected future regularized reward
(Optimal Intermediate Soft Value) | ||||
where, in the third line, we isolate the optimization over by (i) assuming optimality at and (ii) substituting the optimal value of the maximization over (i.e. recursively applying the second line).
The optimal intermediate marginal can be written using either or form (as in Eq. 33 above, or by substituting the optimal or into the twist targets below).
Twisted Intermediate Targets
We state the approximate twisting targets for both or parameterizations in order to make connections with soft RL losses in App. C. For approximate or , we have
(Twist Targets (Soft RL V) ) | ||||
(Twist Targets (Soft RL Q) ) |
where the final twisting target is given by Eq. 42 and the optimal Q-values are defined as
(44) |
One-Step Proposal
Finally, the optimal one-step proposal (e.g. in form) can be derived either (i) as the twist-induced proposal from Eq. Twist Targets (Soft RL V) and Prop. B.1 or (ii) as the solution to the one-step optimization in the third line of Eq. Optimal Intermediate Soft Value. As in Eq. Twist-Induced Proposal (),
(Twist-Induced Proposal (soft RL)) |
We define the one-step log normalization constant induced by an approximate or as or , respectively,
(45) |
such that, for example, .
RLHF Minimizes
Note that, for a given suboptimal , the value of the variational optimization in Eq. 42 is a lower bound on the (scaled) log partition function . Similarly to the standard Evidence Lower Bound, the gap in this lower bound is given by the KL divergence
(46) |
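The relationship between the regularized objective, the soft value, and the KL gap can be checked numerically on a toy problem. In the sketch below, full sequences are flattened into `N` outcomes, and `p0`, `r`, `q`, and `beta` are arbitrary hypothetical values. We assume the objective takes the usual form of expected reward minus the KL divergence to the base model scaled by `1 / beta`, with the target proportional to the base model times the exponentiated scaled reward.

```python
import numpy as np

rng = np.random.default_rng(0)
N, beta = 6, 2.0                   # number of full sequences and inverse temperature (assumed)
p0 = rng.dirichlet(np.ones(N))     # base model over full sequences
r = rng.normal(size=N)             # terminal reward per sequence
q = rng.dirichlet(np.ones(N))      # an arbitrary suboptimal policy

def kl(a, b):
    """KL divergence between two strictly positive discrete distributions."""
    return float(np.sum(a * (np.log(a) - np.log(b))))

log_Z = float(np.log(np.sum(p0 * np.exp(beta * r))))
soft_value = log_Z / beta                   # scaled log partition function
sigma = p0 * np.exp(beta * r - log_Z)       # target distribution

def objective(policy):
    """Expected reward minus (1 / beta) times the KL divergence to the base model."""
    return float(policy @ r) - kl(policy, p0) / beta

gap = soft_value - objective(q)
print(objective(sigma), soft_value, gap, kl(q, sigma) / beta)
```

In this toy setting the target attains the soft value exactly, and the gap for any other policy equals the scaled KL divergence from that policy to the target.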
B.4 Remarks on Parameterization
While the twisting targets (Eq. Twist Targets ()) and twist-induced proposal (Eq. Twist-Induced Proposal ()) may equivalently be parameterized using approximate , we focus on the parameterization to match the main text. In particular, recall that the optimal twists satisfy for all . With no intermediate potential ( for ), our approximate twists estimate for . In this section, we describe how the presence of intermediate potentials may affect the choice of twist parameterization.
The twist-induced proposal may not be tractable to evaluate at the final timestep, since it may be costly to evaluate the terminal potential for all given a context (as described in Sec. 3.2). Thus, we learn an approximate for proposal sampling, which can be easily evaluated over next tokens. The final is defined using in order to preserve unbiased estimation. However, after sampling the proposal according to , we only need to evaluate over full sequences to calculate the importance weights at the final step (Eq. 16). See the ‘Intermediate Potentials Tractable over Sequences Only’ paragraph below.
Intermediate Potentials Tractable over Sequences
However, in settings where the intermediate potentials are tractable to calculate for all given (e.g. using an indicator function or forward pass in a transformer architecture, as in Table 5), it may be useful to use a parameterization of the twist targets and twist-induced proposal. This allows us to use the exact immediate potentials alongside an estimated , instead of an approximate which estimates both the immediate and future expected value of potentials . Using notation established in Eq. 33 and Prop. B.1, the twisting targets in Eq. Twist Targets () can be rewritten using a parameterization
(Twist Targets () ) |
with as before. The twist-induced proposal and its normalization constant are tractable in this case, by evaluating both the given and parameterized in a single forward pass and normalizing over the discrete vocabulary of next tokens.
Intermediate Potentials Tractable over Sequences Only
In cases where the intermediate potentials are difficult to evaluate, we would like to limit evaluation of to only partial sequences. In this case, parameterizing the twisted targets using or (Eq. Twist Targets (), Eq. Twist Targets (Soft RL Q)) instead of or may be preferable to ensure a tractable twist-induced proposal. Separate parameterizations of the proposal (using ) and targets () might also be considered.
In the case of the final timestep described above or in Sec. 3.2, note that we use a learned to parameterize a tractable variational proposal . In this case, we have no future value and only need to evaluate the terminal potential for calculating importance weights over sequences.
Appendix C Twist Learning Losses
In this section, we describe various methods for twist learning beyond our proposed contrastive twist learning (CTL) procedure from Sec. 4. In Sec. C.1, we first describe several losses from the soft RL literature from a probabilistic perspective, building closely on our developments in Sec. B.1. We then proceed to describe SIXO (Lawson et al., 2022) and FUDGE (Yang & Klein, 2021) in Sec. C.3-C.4.
We emphasize losses found in related work or used as experimental baselines using equation tags (e.g. Eq. SIXO), where equations Eq. RL Baseline, Eq. SIXO, Eq. FUDGE are used in our experiments. We consider settings with intermediate potentials in Sec. C.1, but focus on the ( for ) setting in the remainder of the section, as in the main text.
C.1 Soft Q-Learning (RL) and Path Consistency Losses from Log Importance Weights
From the probabilistic perspective of the SMC log importance weights, we can derive several losses for twist learning, including soft Q-learning and path consistency learning (PCL) (Nachum et al., 2017) losses from the soft RL literature.
A general principle for deriving loss functions would be to minimize the variance of the (log) importance weights under some sampling distribution , which leads to constant importance weights at optimality. To draw connections with previous work, we also consider minimizing the square of the log weights, which at optimality, ensures that and are equal to a particular constant. We will proceed to parameterize the twist functions using parameters and consider loss terms which minimize the variance or square of -step log weights at time ,
(47) |
indicates ‘consistency’ in -weight space for -step-ahead weights at time (see Eq. -Step SMC Weights).
We will consider various choices of parameterization and proposal in the following subsections. For example, let denote the log-consistency loss corresponding to twisting targets parameterized by and the twist induced proposal (note, our notation for the one-step weights does not make these choices explicit).
For reference, we derive the log importance weights with intermediate potentials and arbitrary as
(48) |
Various special cases arise from choices of twist parameterizations and proposals in the following subsections.
C.1.1 Soft Q-Learning and RL Baseline
For single-step log-weights, the -parameterization of the targets (Eq. Twist Targets (), Eq. Twist Targets (Soft RL Q)), and the twist-induced proposal (Eq. Twist-Induced Proposal (), Eq. Twist-Induced Proposal (soft RL)), we have
(49) |
where the second term normalizes the twist-induced proposal (Eq. 14).
Minimizing the sum of one-step log consistency losses (i.e. squared log weights in Eq. 48) will yield the familiar soft Q-learning loss (e.g. Lioutas et al. (2022) Eq. (4)-(5)). Adjusting indexing from Eq. 48 and introducing a stop-gradient within , we have
(Soft Q Learning) | ||||
In the final line, we rewrite the loss for the soft RL special case, using the substitutions in Eq. Twist to Soft RL. Note that the log-normalization term is analogous to an induced soft value , so that each squared error loss has the form . Hence, we refer to this loss as the soft Q-learning loss.
The log-normalization term, which arises from normalizing the twist-induced proposal, is analogous to the ‘target’ value in deep Q-learning. Lioutas et al. (2022) apply the soft Q-learning loss to SMC sampling in self-driving applications where interaction with the environment is expensive. Lawson et al. (2018) adopt a similar loss function (using a parameterization of the value ) in the setting of state-space models with tractable intermediate rewards.
RL Baseline with no Intermediate Reward
The soft Q-learning loss in Eq. Soft Q Learning simplifies nicely in the case of no intermediate rewards, as in the main text ( for and ).
Written in terms of twist functions, we separate the terms at and for purposes of exposition
(RL Baseline) | |||
For intermediate timesteps, note that Eq. RL Baseline enforces the recursion in Eq. 13 of the main text, albeit in log space. In Sec. C.2 below, we consider the one-step squared error loss enforcing this recursion directly (without logarithms), i.e. ,
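A minimal sketch of a log-space consistency loss of this kind, for the setting with a terminal potential only, is given below. The toy tables `P0` and `PHI` and the function names are hypothetical, and the exact weighting of terms is an assumption. Each intermediate term compares the log twist of a prefix with the log of the expected next-step twist under the base model, and the final term compares the last log twist with the log terminal potential. In an autodiff framework the intermediate target would be wrapped in a stop-gradient, and NumPy is used here only to evaluate the loss.

```python
import itertools
import numpy as np

V, T = 3, 3  # toy vocabulary size and sequence length (assumed)
rng = np.random.default_rng(0)

def seqs(t):
    return itertools.product(range(V), repeat=t)

P0 = {p: rng.dirichlet(np.ones(V)) for t in range(T) for p in seqs(t)}  # toy base model
PHI = {s: rng.uniform(0.5, 2.0) for s in seqs(T)}                       # terminal potential

def optimal_log_twist(prefix):
    """Exact log twist from the one-step recursion, computed by enumeration."""
    if len(prefix) == T:
        return float(np.log(PHI[prefix]))
    nxt = np.array([optimal_log_twist(prefix + (k,)) for k in range(V)])
    return float(np.logaddexp.reduce(np.log(P0[prefix]) + nxt))

def log_consistency_loss(seq, log_twist):
    """Sum of squared one-step log inconsistencies along one full sequence."""
    loss = (log_twist(seq) - np.log(PHI[seq])) ** 2  # final-step term
    for t in range(1, T):
        prefix = seq[:t]
        nxt = np.array([log_twist(prefix + (k,)) for k in range(V)])
        # Regression target; treated as a constant (stop-gradient) during training.
        target = np.logaddexp.reduce(np.log(P0[prefix]) + nxt)
        loss += (log_twist(prefix) - target) ** 2
    return float(loss)

def biased_log_twist(prefix):
    """Optimal log twist plus a length-dependent bias of 0.1 per token."""
    return optimal_log_twist(prefix) + 0.1 * len(prefix)

seq = (0, 2, 1)
loss_opt = log_consistency_loss(seq, optimal_log_twist)
loss_biased = log_consistency_loss(seq, biased_log_twist)
print(loss_opt, loss_biased)
```

The optimal twists give zero loss, while the biased twists incur a squared error of 0.01 at each of the two intermediate steps and 0.09 at the final step.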
C.1.2 Path Consistency Learning (for Twist Learning)
Using the value parameterization of the targets ( or , see Eq. Twist Targets (), Eq. Twist Targets (Soft RL V)), the one-step log consistency loss with arbitrary proposal recovers the path-consistency loss (PCL) from Nachum et al. (2017).
Switching to a parameterization of the twisting targets, we substitute into the log importance weights in Eq. 48. The log-consistency loss becomes,
(PCL) | ||||
In particular, substituting the soft RL potential terms from Eq. Twist to Soft RL, Eq. PCL recovers the path consistency loss from Nachum et al. (2017). Note that we derived PCL from an importance sampling perspective, whereas PCL was originally derived by enforcing KKT conditions of the soft RL problem.
We might also consider multi-step losses for various . Minimizing the square of the multi-step log weights with arbitrary recovers the multi-step PCL loss (Nachum et al., 2017),
(multi-step PCL) | ||||
(50) | ||||
where we write the parameterization in Eq. 50 explicitly for use in Sec. D.1. While PCL considers learning a proposal or policy with the goal of approximating the solution of a regularized MDP, we leave joint learning of proposals and SMC target twists or to future work.
In App. E, we describe using PCL to learn the proposal only (Guo et al., 2021), with the values induced from learned proposal twists which define (in similar fashion to Eq. Twist-Induced Proposal (soft RL), but without reference to twisting targets).
C.2 Controlled Decoding Losses via Optimal Twist Identities (Mudgal et al., 2023)
In Prop. B.1 (or Prop. 3.2 and Eq. 13 in the main text), we noted that the optimal twists satisfy the following relationships
(51) |
We proceed to describe two ‘controlled decoding’ (CD) losses from Mudgal et al. (2023) as using a squared error loss to enforce the optimality conditions in Eq. 51, for settings with no intermediate potentials ( for ). Mudgal et al. (2023) also propose two ways to use the learned ‘twists’ at inference time, which we discuss in relation to our proposed SMC framework in Sec. D.1.
CD-Q
The CD-Q loss from Mudgal et al. (2023) corresponds to minimizing the one-step recursion in Eq. 51 using the expected squared error under a (possibly off-policy) sampling distribution . Assuming no intermediate reward and an additional squared error loss to approximate the terminal potential , we have
(CD-Q) |
Eq. CD-Q enforces the same optimality condition as the Eq. RL Baseline loss (i.e. ), without log scaling of each term inside the squared error. At optimality, we have zero-variance one-step importance weights ( in Eq. 10) for the twist-induced proposal.
CD-FUDGE
While we might naively like to consider the loss to enforce Prop. 3.2 or Eq. 51, note that marginalization over multiple steps is not tractable in general.
Instead, the CD-FUDGE loss (footnote 9: we reserve the name FUDGE (Yang & Klein, 2021) for the binary cross-entropy loss described in Sec. C.4, as opposed to the CD-FUDGE squared error loss from Mudgal et al. (2023)), defined as
(CD-FUDGE) |
can be shown to have the same gradient as the desired (but intractable) squared error loss above (Mudgal et al., 2023).
Since the minimizer of the expected squared error (under ) to a single function (which is independent of ) is given by the conditional expectation (Banerjee et al., 2005), we can also see that Eq. CD-FUDGE has the desired minimum . Note, it is crucial that the inner expectation samples rollouts under the base model to obtain the desired conditional expectation as the minimizer. While it appears that any prefix sampling distribution can be used, allows for losses to be calculated at all in a single sampling run.
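The conditional-expectation argument can be illustrated with a small Monte Carlo sketch for a single fixed prefix. The rollout distribution `p_rollout` over five hypothetical completions and the terminal potentials `phi` are arbitrary toy values. Regressing a single scalar onto the terminal potentials of base-model rollouts with a squared error gives the sample mean, which approaches the exact conditional expectation.

```python
import numpy as np

rng = np.random.default_rng(0)
num_completions = 5
p_rollout = rng.dirichlet(np.ones(num_completions))  # base-model completion probabilities
phi = rng.uniform(0.0, 1.0, size=num_completions)    # terminal potential of each completion

# Rollouts must be drawn under the base model to obtain the desired minimizer.
samples = rng.choice(num_completions, size=200_000, p=p_rollout)
targets = phi[samples]

def squared_error(value):
    """Empirical squared-error loss of a scalar twist value for this prefix."""
    return float(np.mean((value - targets) ** 2))

fitted = float(np.mean(targets))   # closed-form least-squares minimizer
exact = float(p_rollout @ phi)     # exact conditional expectation
print(fitted, exact, squared_error(fitted))
```

Drawing `samples` from any other completion distribution would instead recover the expectation under that distribution, which is why the choice of rollout distribution matters.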
CD-FUDGE for
We can also compare Eq. CD-FUDGE with the multi-step PCL loss in Eq. 50, choosing for and the proposal equal to the base model so that the proposal terms cancel. Noting that is fixed to the exact terminal potential and choosing the -step PCL loss for each , note that Eq. 50 would reduce to . Deng & Raffel (2023) optimize this loss with reweighting of terms based on timestep (higher weight for ). Eq. CD-FUDGE optimizes the squared error of the difference without log scaling of each term, under appropriate sampling of rollouts. (Footnote 10: Note the difference in choice of proposal between Eq. CD-Q (twist-induced ) and Eq. CD-FUDGE (base ).)
C.3 SIXO: Smoothing Inference with Twisted Objectives (Lawson et al., 2022)
Lawson et al. (2022) adopt a noise-contrastive estimation loss (Gutmann & Hyvärinen, 2010) to learn the target twist functions using binary classification. For state space models, Lawson et al. (2022) adopt our setting in Sec. B.2 with observation variables emitted based on the sampling state (or simply ) and a known likelihood . As discussed in Sec. B.4, in these settings with easily evaluable intermediate potentials, it may be preferable to parameterize as in Eq. Twist Targets (). Lawson et al. (2022) indeed use this parameterization (see their Eq. 5).
Recall from Eq. 39 that the optimal twists or future values amount to conditional likelihoods,
(52) |
where denotes proportionality up to a constant which depends on only. Using Bayes rule, we have
(53) |
noting that and are marginals of by definition. The above reasoning suggests that we may learn the twists, or likelihood ratio , using a classifier which seeks to distinguish samples from and (Gutmann & Hyvärinen, 2010; Lawson et al., 2022). In particular, at each , we classify the event , indicating that , or , indicating that .
Consider a given , which can be either in the unconditional case or drawn from a behavioral policy as discussed below. The SIXO loss becomes
(SIXO) |
Note that we can perform approximate positive sampling as in Sec. 4 to estimate expectations in the first term.
Exact Conditional Sampling
However, we can also use the BDMC trick in Sec. 3.3 to obtain exact target samples for general observation variables. In order to facilitate tractable sampling, we optimize the Eq. SIXO loss over a sampling distribution for all , such that the objective becomes
With this choice, note that we may sample once from using ancestral sampling and use the appropriate truncations for positive sampling terms involving . By shuffling observation variables across a batch of samples, we may obtain samples from the product of marginals or in the negative sampling term.
In the main text, note that we condition on or (for infilling).
Gradient and Comparison with CTL
Proceeding with the parameterization for the target with fixed and unconditional twists , the gradient of Eq. SIXO with respect to is
(SIXO Gradient) |
The SIXO gradient is superficially similar to our CTL gradient in Sec. 4.1, in that it involves under positive and negative samples. However, viewing as the unnormalized density of our intermediate twisting target, we can see that the second term in the SIXO update includes . Rewriting to highlight differences with our CTL gradient, we have
(SIXO vs. CTL) |
To compare the two, first note that the positive sampling gradient in SIXO is scaled by a factor (which reflects the misclassification probability under ). For the negative sampling terms, note that is divided by a factor of in the SIXO gradient, instead of the true normalization constant for the gradient of our CTL loss Eq. 22.
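The classification view of twist learning can be sketched on a small discrete space of prefixes. We assume a standard noise-contrastive form in which a sigmoid applied to the log twist classifies target samples (positives) against base-model samples (negatives). The distributions `sigma` and `p0` and the function names are hypothetical, and expectations are computed exactly rather than by sampling. Under this assumed form, the loss is minimized when the twist equals the ratio of the target marginal to the base-model marginal.

```python
import numpy as np

def log_sigmoid(x):
    """Numerically stable log of the logistic sigmoid."""
    return -np.logaddexp(0.0, -x)

def nce_twist_loss(log_psi, pos_probs, neg_probs):
    """Negative classification log-likelihood with exact expectations."""
    positive = pos_probs @ log_sigmoid(log_psi)    # target samples labelled as positives
    negative = neg_probs @ log_sigmoid(-log_psi)   # log(1 - sigmoid) on base-model samples
    return float(-(positive + negative))

rng = np.random.default_rng(0)
p0 = rng.dirichlet(np.ones(4))      # base-model marginal over four toy prefixes
sigma = rng.dirichlet(np.ones(4))   # target marginal over the same prefixes

log_ratio = np.log(sigma) - np.log(p0)
loss_at_ratio = nce_twist_loss(log_ratio, sigma, p0)
perturbed_losses = [
    nce_twist_loss(log_ratio + rng.normal(scale=0.3, size=4), sigma, p0)
    for _ in range(20)
]
print(loss_at_ratio, min(perturbed_losses))
```

The gradient of this loss with respect to each log twist weights the positive term by one minus the sigmoid and the negative term by the sigmoid, which matches the scaling factors discussed above.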
C.4 FUDGE: Future Discriminators (Yang & Klein, 2021)
In contrast to SIXO, the FUDGE method from Yang & Klein (2021) seeks to directly learn a discriminative classifier to match the conditional likelihood or (see Sec. B.2).
As before, we define the joint distribution with . From Eq. 52 above or Sec. B.2 Eq. 40, we have
(54) |
Yang & Klein (2021) consider training a ‘future discriminator’ which, as in Eq. 54 marginalizes over future tokens to predict the expected probability that a full sequence with prefix emits (e.g., let be the probability of a classifier for class , or the probability that satisfies a desired attribute indicated by a boolean ).
In similar fashion to SIXO in the previous section, we define a binary random variable such that
(55) |
where we directly parameterize to be a probability distribution (e.g. using a sigmoid or softmax activation). For a given observation random variable and partial sequence , we can define the FUDGE loss
(FUDGE) | ||||
where, in moving from the second to the third line, we have used the fact that from Eq. 54 and Eq. 55. At the optimum, implies , as desired.
While sampling may be done using an arbitrary distribution over prefixes and observation , Eq. FUDGE requires that rollouts be sampled under the base model in order to ensure sampling from the appropriate distribution . This restriction is similar to what we required in Eq. CD-FUDGE, although the loss in Eq. FUDGE is based on cross entropy classification rather than a squared error. We discuss the choices made in our experiments below.
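A small sketch of a cross-entropy loss of this type, for one fixed prefix and observation, is shown below. The completion distribution `p_rollout` and per-completion probabilities `phi` are hypothetical toy values, and the expectation over base-model rollouts is computed exactly. We assume the loss is the binary cross entropy between the rollout's terminal probability and a scalar prediction `psi` in the open unit interval, in which case the minimizer is the expected terminal probability under the base model.

```python
import numpy as np

rng = np.random.default_rng(0)
p_rollout = rng.dirichlet(np.ones(5))   # base-model probabilities of five completions
phi = rng.uniform(0.0, 1.0, size=5)     # probability that each completion emits the observation

def bce_twist_loss(psi):
    """Expected binary cross entropy of a prediction psi under base-model rollouts."""
    per_rollout = phi * np.log(psi) + (1.0 - phi) * np.log1p(-psi)
    return float(-np.sum(p_rollout * per_rollout))

grid = np.linspace(0.001, 0.999, 999)   # candidate predictions, spacing 0.001
best = float(grid[np.argmin([bce_twist_loss(c) for c in grid])])
expected_phi = float(p_rollout @ phi)
print(best, expected_phi)
```

Compared with a squared-error objective, only the loss changes; the requirement that rollouts follow the base model is the same.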
Yang & Klein (2021) Setting
In the original FUDGE paper, Yang & Klein (2021) consider learning from a dataset of labelled examples or for a binary observation variable which defines the target distribution.
Unconditional Twist Setting
For the unconditional twist experiments in Sec. 7.2.1-7.2.2, we sample under the base model proposal where the target distribution conditions on and . In particular, we optimize
Conditional Twist Setting
For conditional twist learning, we can consider amortizing learning the twists over some distribution of observation variables . In particular, in our infilling experiments in Sec. 7.2.3, we consider sampling under the following joint distribution,
which we can sample from by first sampling from and then dropping the subsequence. Therefore, the overall objective becomes
(56) | ||||
where the expectation includes the expectation under from Eq. FUDGE. Note that the rollout of used to sample from should be independent of the rollout used to sample from .
Appendix D Decoding Strategies using Learned Twists from Mudgal et al. (2023)
D.1 Proposal Sampling in Mudgal et al. (2023)
As noted in Sec. C.2 (and in Mudgal et al. (2023)), the CD losses can be seen as enforcing the optimality conditions
(57) |
In RL terms, we interpret the twists as performing policy evaluation of the expected unregularized ‘reward’ under a fixed policy . The notation of Mudgal et al. (2023) (their Eq. (1), (5), our Eq. 57) indeed corresponds to
(CD reward) |
However, Mudgal et al. (2023) propose to use the learned twist functions to perform one-step sampling as
(CD proposal) |
We proceed to explain that this scheme does not correspond to sampling from the twist-induced proposal under two different definitions of the target (or potential ) in our SMC framework.
Comparison with Our Case:
As we have argued above, the CD-Q and CD-FUDGE losses may be viewed as learning twist values for a terminal potential . However, our twist-induced proposal which minimizes the variance of the one-step importance weights with these SMC targets would yield
(Twist-Ind. proposal ()) |
which, compared to Eq. CD proposal, does not exponentiate or scale and is directly proportional to the expected .
Comparison with Our Case (Soft RL):
The stochastic sampling in Eq. CD proposal is also reminiscent of the twist-induced proposal in the soft RL case of our framework where, in contrast to Eq. CD reward, the target is defined via . As in Sec. B.3,
(Twist-Ind. proposal ()) |
We proceed to write both and as the solution to a variational optimization, highlighting similarities in blue, but noting the different definitions of in terms of . We assume no intermediate potential or reward, and consider the optimal twists to emphasize the role of . Using Mudgal et al. (2023) Eq. 2 and Thm 2.1 (for CD) and Eq. Optimal Intermediate Soft Value (for soft RL), we have
(CD proposal optimization) | ||||
(Soft RL proposal optimization) |
The second terms of Eq. CD proposal optimization and Eq. Soft RL proposal optimization match and correspond to one-step KL divergence regularization of the policy . However, the expectation terms differ as we now discuss.
Soft Values Account for Future Regularization
Using Eq. Optimal Intermediate Soft Value to expand the definition of the soft value function, we see that Eq. Soft RL proposal optimization also implicitly contains an expected terminal reward,
(58) |
As in Eq. 58, this optimization strictly enforces , and the soft value function recovers the expected reward under the base model , which appears in the first term of Eq. CD proposal optimization. On the other hand, the second term in Eq. CD proposal optimization uses for optimization of the proposal at the current step. This inconsistency in Eq. CD proposal optimization (using in the first term and in the second term) arises from the fact that Eq. CD proposal optimization does not consider the effect of future regularization, while the MDP formulation in Eq. Soft RL proposal optimization does so via the optimization in Eq. 58 and the log-mean-exp form of the soft value function .
On Mudgal et al. (2023)’s One-Step Proposal and SMC Interpretation
As noted in Eq. 57, the twists learned by Mudgal et al. (2023) correspond to policy evaluation for the reward under the base model . However, we have argued that the one-step proposal in Eq. CD proposal (which considers one-step KL regularization of to ) does not immediately fit within our SMC framework. In particular, it is not apparent that the composition of one-step proposals samples from the marginals of a natural target distribution at optimality.
Flexible Inference-Time Scaling
The experiments in Mudgal et al. (2023) evaluate tradeoff curves between expected reward and for various values of regularization strength . Since the twists learned by Mudgal et al. (2023) in Eq. 57 do not depend on , sampling according to Eq. CD proposal or Eq. CD proposal optimization has the benefit of allowing flexible tempering or -scaling at inference time without additional learning.
Such tradeoff curves are also natural from the perspective of soft RL (cf. Eq. 42 and Eq. 46). While Eq. 58 appears to require separate twist-learning procedures for each , flexible inference-time scaling could be achieved with a single training run in our framework by learning a conditional twist network which considers in its input and training loss, or by adapting the methods of Bae et al. (2022) proposed in the context of rate-distortion optimization.
Comparison with Khanov et al. (2024)
Khanov et al. (2024) consider softmax decoding similar to Eq. Twist-Ind. proposal (). However, instead of as the logit, they use a reward model which is trained from full sequences (), but applied to partial sequences without modification, . This clearly does not correspond to a twist or soft value function .
D.2 Blockwise Greedy Decoding in Mudgal et al. (2023)
As an alternative use of the twist functions at inference time and a generalization of best-of- decoding to partial sequences, Mudgal et al. (2023) also consider a ‘blockwise’ decoding scheme using the learned twist functions . In particular, for partial completions of length (from a prefix ), sampled from the base model, , Mudgal et al. (2023) propose to choose
(59) |
and proceed with sampling further continuations with prefix until the next resampling step or an end-of-string token is reached. The selection strategy may seem natural from the unregularized RL (as ) or expected future reward perspective in Sec. D.1, but does not yield samples from with the corresponding optimal twists.
Our SMC framework instead would advocate probabilistic resampling based on the approximate twist functions using the (- or -step) importance weights in Sec. 3 in order to match the desired target distribution.
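The contrast between greedy blockwise selection and probabilistic resampling can be sketched as follows. We assume that all candidate continuations share the same prefix and were sampled from the base model, so that their importance weights are proportional to the twist values of the extended sequences. The function names and toy twist values are hypothetical.

```python
import numpy as np

def blockwise_argmax(twist_values):
    """Greedy selection: keep only the candidate with the largest twist value."""
    return int(np.argmax(twist_values))

def multinomial_resample(log_weights, num_samples, rng):
    """Resample candidate indices with probability proportional to their weights."""
    weights = np.exp(log_weights - np.max(log_weights))  # stabilized exponentiation
    weights = weights / weights.sum()
    indices = rng.choice(len(weights), size=num_samples, p=weights)
    return indices, weights

rng = np.random.default_rng(0)
twist_values = np.array([0.1, 0.6, 0.3])  # toy twist values of three candidate blocks

greedy_choice = blockwise_argmax(twist_values)
indices, weights = multinomial_resample(np.log(twist_values), 10_000, rng)
frequencies = np.bincount(indices, minlength=len(weights)) / len(indices)
print(greedy_choice, weights, frequencies)
```

Greedy selection always continues from the single highest-scoring candidate, whereas resampling retains lower-scoring candidates in proportion to their weights.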
Finally, Khanov et al. (2024) also consider decoding of next tokens using the unmodified described above.
Appendix E Proposal Learning Methods
We next describe methods for learning variational policies or proposals parameterized by , which can be used for SMC sampling with intermediate targets and learned twists or parameterized by . Alternatively, such proposals may be used directly in the IWAE bounds on , which rely on simple importance sampling over full sequences as in Sec. 2.1 and do not require the definition of intermediate targets .
In Sec. E.3, we provide a detailed description of the DPG policy gradient method, which can be interpreted as a maximum likelihood objective for a sequential energy-based model. To distinguish this EBM approach from our CTL method for twist learning, we emphasize issues which can arise from naive use of a proposal-learning objective to define intermediate twisting targets for SMC in Sec. E.3.1.
E.1 Path Consistency Learning for Controlled Generation
Guo et al. (2021) consider learning Q-values to obtain a fine-tuned variational policy which can be directly used as a sampling distribution for controlled generation. Building on the path consistency learning (PCL) loss in Nachum et al. (2017) and Sec. C.1.2, Guo et al. (2021) consider parameterizing the proposal using ,
(60) |
where enforces normalization.
Guo et al. (2021) define the targets using , a slowly-updated target network based on . Using the implied form of the soft value , the single-step PCL loss becomes
(61) |
where indicates stop gradient. Building on the interpretation in Sec. C.1, we view and as the twisting targets, with a learned proposal parameterized by as in Eq. 60 (or Sec. B.4). While the loss in Eq. 61 is similar in practice to the soft Q-learning loss in Sec. C.1.1, we emphasize that the latter is motivated from the SMC perspective with the twisting targets as the primary object of interest and flexibility in the choice of proposal. By contrast, Guo et al. (2021) are interested in learning a proposal policy and do not consider, for example, resampling according to .
Guo et al. (2021); Nachum et al. (2017) also consider ‘multi-step’ PCL losses (Eq. multi-step PCL) which use observed reward during rollouts of length to limit reliance on estimated intermediate values . The objective in Hu et al. (2023) also corresponds to a PCL objective.
E.2 Policy Gradient Methods
Traditional RLHF pipelines use a policy gradient method such as PPO to optimize the objective in Eq. 42,
(62) |
where corresponds to our final twist. As in Eq. 46, the gap in this optimization is the mode-seeking KL divergence .
Notably, this objective does not make use of exact target samples from when they are available. Further, the mode-seeking behavior has been shown to reduce diversity of fine-tuned models (Stiennon et al., 2020; Go et al., 2023). To combat this, Go et al. (2023) derive policy gradient methods to optimize arbitrary -divergences between the learned variational policy and target .
E.3 Policy Gradient with Mass-Covering / Maximum Likelihood KL Divergence
We focus on the case of minimizing the mass-covering KL divergence to train , which constitutes the distributional policy gradients (DPG) method for language model finetuning (Parshakova et al., 2019; Khalifa et al., 2020; Korbak et al., 2022a; Go et al., 2023) and has been used to learn SMC proposals in state-space models by Gu et al. (2015).
In particular, the gradient of is
(63) |
We recognize the importance weights from Eq. 3. Go et al. (2023) consider estimating Eq. 63 using a moving average estimate of the partition function
(DPG (general )) |
Alternatively, the expectation may be estimated using SIS with the variational policy . Using self-normalized importance sampling (SNIS) to estimate Eq. 63 as in Eq. 5 corresponds to , with
(64) |
We use this gradient for DPG proposal learning in the main text experiments, although we use the parameterization described in Eq. DPG below.
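As an illustration of the SNIS estimate, the sketch below computes self-normalized weights from unnormalized log target values and log proposal probabilities, and uses them to average per-sample score vectors. It assumes the usual SNIS form, in which each sample's score is weighted by its normalized importance weight; the helper names and toy numbers are hypothetical.

```python
import math


def snis_weights(log_target_unnorm, log_proposal):
    """Self-normalized importance weights, computed stably in log space."""
    log_w = [lt - lq for lt, lq in zip(log_target_unnorm, log_proposal)]
    shift = max(log_w)
    unnorm = [math.exp(lw - shift) for lw in log_w]
    total = sum(unnorm)
    return [u / total for u in unnorm]


def snis_score_average(weights, scores):
    """Weighted average of per-sample score vectors (gradients of log q)."""
    dim = len(scores[0])
    return [sum(w * s[d] for w, s in zip(weights, scores)) for d in range(dim)]


# Two hypothetical samples: unnormalized target values 3 and 1, proposal probability 0.5 each.
weights = snis_weights([math.log(3.0), math.log(1.0)], [math.log(0.5), math.log(0.5)])
direction = snis_score_average(weights, [[1.0, 0.0], [0.0, 1.0]])
```

Because the weights are normalized by their sum, any constant offset in the log target (such as the unknown log partition function) cancels, which is why this estimator does not require a separate estimate of the normalization constant.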
DPG as Sequential Maximum Likelihood Objective
We now show Eq. 64 is equivalent to a sequential maximum likelihood EBM objective. Consider minimizing the KL divergence,
(EBM proposal learning) | ||||
(65) |
While this is reminiscent of the twist-induced proposal in Prop. 3.3, we emphasize distinctions between energy-based learning of the proposal (DPG) versus energy-based learning of twist functions (CTL) in Sec. E.3.1.
The gradient of Eq. EBM proposal learning becomes
(66) |
Starting from Eq. 64, we now seek to recover Eq. 66. Using Eq. 65, we can write
Substituting into Eq. 64, we recover
(DPG) |
which is an SNIS estimate of the maximum likelihood EBM gradient in Eq. 66, as desired. Note that the expectation over can be calculated exactly.
Comparison with CTL Objective
The gradient in Eq. DPG above appears similar to our CTL objective and gradient in Sec. 4.1. However, the DPG loss in Eq. EBM proposal learning is a single (joint) KL divergence over the entire sequence, whereas CTL optimizes separate KL divergences for each intermediate marginal.
For the DPG gradient in Eq. 66, negative sampling is performed using a ‘positive’ prefix and an exact ‘negative’ sample from the one-step-ahead (Eq. 65, which we have assumed to be tractable). In practice, we obtain the prefixes using the truncation of exact samples or approximate positive sampling with the final target weights as in Eq. DPG. By contrast, the CTL gradient in Eq. 22 involves approximate negative sampling under each .
E.3.1 Naive Use of Proposal Learning to define Twisted SMC Targets
While we have shown in Prop. 3.3 how one-step proposals can be induced from a given set of twist functions or target distributions , we now emphasize that moving the other direction (inducing intermediate twisting targets from a proposal learning scheme parameterized by ) does not yield the correct intermediate targets for resampling (Sec. A.1), even at optimality in the proposal learning objective.
We focus our arguments on learning with the EBM maximum likelihood objective in Eq. EBM proposal learning as an example. The proposal energies appear to play a role analogous to the twist function in the one-step proposal induced from twist targets in Sec. 3.
However, we proceed to show in Prop. E.2 that naive use of to define twisting targets using (we assume no intermediate potentials in this section, as in the main text)
(67) |
need not lead to an SMC procedure for which , even if for all . We thus argue that learned using Eq. EBM proposal learning should not be used as target twists in Eq. 67, since they do not yield the optimal intermediate target distributions at optimality (Sec. A.1).
We begin by showing a simple lemma for the one-step conditionals in Eq. EBM proposal learning.
Lemma E.1.
Any twist induced proposal (induced by ) is invariant to rescaling by an arbitrary constant with respect to ,
(68) |
Proof.
∎
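The invariance in Lemma E.1 can be checked numerically. The sketch below assumes the twist-induced proposal takes the form of Prop. 3.3, i.e. the base model's next-token probabilities multiplied by the twist values and renormalized over the next token; the function name and numbers are hypothetical.

```python
def twist_induced_proposal(base_probs, twist_values):
    """Normalize base_probs[i] * twist_values[i] over the next-token index i."""
    unnorm = [p * psi for p, psi in zip(base_probs, twist_values)]
    total = sum(unnorm)
    return [u / total for u in unnorm]


base_probs = [0.5, 0.3, 0.2]   # hypothetical base-model next-token probabilities
twists = [0.1, 2.0, 4.0]       # hypothetical twist values for each next token
scale = 7.5                    # constant with respect to the next token

original = twist_induced_proposal(base_probs, twists)
rescaled = twist_induced_proposal(base_probs, [scale * psi for psi in twists])
```

The two proposals coincide, since the constant cancels in the normalization. The unnormalized products, by contrast, do change with the scale, which is the degree of freedom exploited in Prop. E.2.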
Proposition E.2.
There exist such that (i) and (ii) the SMC targets induced by via Eq. 67 are different from .
Proof.
To satisfy condition (i) of the current proposition, we define
(69) |
which for all , yields optimal proposals: via Lemma E.1. However, it is clear that can break the necessary condition for optimality of SMC sampling that (Prop. A.4). In particular, consider
(70) |
for , which introduces an additional factor which depends on . Thus, the twist target induced from in Eq. 69 is not equal to the desired marginal , despite the fact that all proposals are optimal. ∎
We indeed observed experimentally that resampling based on Eq. 67 after training using Eq. EBM proposal learning could lead to worse SMC bounds than simply calculating the SIS or IWAE bound with .
Optimality in CTL Objective implies Optimal Twisted SMC
E.3.2 SMC with Normalized Targets Induced by Learned Proposal Leads to Uniform Weights
The issue in Prop. E.2 arises from the degree of freedom in the normalization constant of the one-step proposal. To avoid this, we can instead define normalized twisted intermediate targets using
(71) |
where arises from the proposal learned according to Eq. EBM proposal learning.
Crucially, in Eq. 71 are automatically normalized for , as the product of normalized proposals. In this case, SMC resampling with or the twist-induced proposal yields uniform resampling weights,
|
(72) |
Although we were able to construct well-behaved intermediate twisting targets from a proposal-learning scheme , Eq. 72 shows that this does not lead to meaningful intermediate SMC resampling. In other words, for , the marginal distributions of SMC samples with this scheme are simply , the same as we would obtain with no resampling (SIS/IWAE).
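A short numerical check of this observation is given below. It assumes the standard incremental weight, namely the ratio of the current target to the product of the previous target and the one-step proposal, and defines each intermediate target as the running product of proposal probabilities as in Eq. 71. The function name and probabilities are hypothetical.

```python
def incremental_weights(proposal_probs):
    """Incremental SMC weights when each target is the product of proposals so far.

    proposal_probs[t] is the one-step proposal probability of the token
    actually sampled at step t.
    """
    weights = []
    previous_target = 1.0
    for q_t in proposal_probs:
        target = previous_target * q_t                    # normalized target, Eq. 71 style
        weights.append(target / (previous_target * q_t))  # target / (previous target * proposal)
        previous_target = target
    return weights


example_weights = incremental_weights([0.4, 0.05, 0.9, 0.3])
```

Every weight equals one regardless of which tokens were sampled, so resampling with these targets draws indices uniformly and carries no information about the final target.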
Appendix F Bidirectional SMC
In this section, we recall the extended state-space probabilistic interpretation of SMC from Maddison et al. (2017) and Andrieu et al. (2010). The idea is to define an unnormalized target distribution and normalized proposal over an extended state space containing all random variables relevant to SMC sampling and importance weighting with sequences of length . Defining such that its normalization constant matches , we can use simple importance sampling (SIS) in this extended state space to show that -sequence SMC sampling yields an unbiased estimator of , for example (as in Eq. 8). Our end goal is to use this probabilistic interpretation to derive the lower and upper bounds on in Prop. 5.1, following Brekelmans et al. (2022) App. A.
We define the extended state space proposal and target distributions below, noting that our bounds will require sampling from normalized or , and evaluating and . We summarize the algorithm for sampling in 30, using concatenation notation for simplicity instead of the probabilistic interpretation using index histories in the text.
Single-Sequence Target and Proposal
We construct our importance sampling bounds with the goal of estimating the (log) partition function and sampling from a target distribution . We leverage a sequence of intermediate target distributions, over partial sequences, with the final target and . We assume for all prompts with . Finally, our bounds and sampling procedures also depend on a given set of proposal distributions , as in Sec. 2.2.
Extended State Space Random Variables
Consider an extended state space containing tokens with and indexing random variables with , to represent the results of resampling (Eq. 7),
(73) |
For ease of notation (and similarly to Maddison et al. (2017); Andrieu et al. (2010)), we call attention to our use of recursive backtracking index operations to collect sequences based on the results of resampling . We use lists of index histories to construct sequences of tokens, with two recursive definitions of histories. Letting indicate appending of lists,
(Index Notation) | ||||
For example, the history will be used to construct prefix sequences (i.e. lists of tokens) for sampling a next token . We denote sequences of tokens with the index history in the superscript and also expand the definition for clarity,
(Sequence Notations) |
In the second line, we define as a sequence of length which concatenates the prefix with next token . The notation represents partial sequences before resampling. By contrast, we will use the notation in the first line of Eq. Sequence Notations to refer to sequences after resampling.
Consider the sequence in a particular index before resampling. Resampling at time may result in choosing for some . Using the first line, we see that for those indices such that . Indeed, this matches the definition of in the second line (before resampling). Thus, the indexing notation in Eq. Sequence Notations reflects resampling or cloning of sequences into the indices such that , which yields prefixes for the next step of sampling () in each index .
(blue indicates changes from SMC proposal algorithm; is an exact posterior sample)
Extended State Space Proposal Distribution
Sampling from the extended state space proposal corresponds to the procedure described in Sec. 2.2 and Alg. 1. Note that , , and are deterministically constructed from during sampling, and simply track the quantities to be calculated when evaluating densities. We write the proposal as
(SMC Extended Proposal) | ||||
where , | (74) |
To recount the description above, note that the next token in index is sampled from the proposal, conditioned on the prefix . We concatenate these tokens (Eq. Sequence Notations) and calculate importance weights. We perform resampling in each index according to , or SNIS with the calculated weights (as in Eq. 7). Finally, after resampling, we clone the sequence in the chosen index into index and proceed to sample with an prefix defined by the indices .
Worked Example: To make this more concrete, we provide a worked example of the procedure in Fig. 3 (a). At step , we resample the token twice (for indices ), with (and in index , set to sample ). We record the prefix history as, for example, , which corresponds to .
At step 2 in (a), we proceed to sample (and similarly ), whereas . We next evaluate the importance weights over three concatenated sequences: , , and , emphasizing that is the final token in each index. Shown in the red circles, we proceed to resample and at step .
Finally, we need to backtrack to obtain the history of the indices for the sequence to be cloned in resampling. Namely, for index where , we concatenate (i.e. the history for time 2, index 1). This list of indices specifies the prefix at step , index . Similar reasoning applies for other indices.
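The backtracking operation can be sketched in a few lines of code. The data layout below is an assumption made for illustration: tokens[t][k] holds the token sampled at step t in index k, and ancestors[t][k] holds the index chosen by resampling at step t for slot k, with both steps and indices counted from zero. The example values are hypothetical and do not correspond to the figure.

```python
def sequence_after_resampling(tokens, ancestors, step, index):
    """Recover the sequence held in `index` after resampling at `step` by backtracking."""
    reversed_sequence = []
    while step >= 0:
        chosen = ancestors[step][index]            # index cloned into this slot
        reversed_sequence.append(tokens[step][chosen])
        index = chosen                             # continue from the cloned index's history
        step -= 1
    return reversed_sequence[::-1]


# Hypothetical run with 3 particles and 2 steps.
tokens = [["a", "b", "c"],     # step 0: one sampled token per index
          ["d", "e", "f"]]     # step 1: sampled conditioned on each index's prefix
ancestors = [[0, 0, 2],        # step 0: index 0 is cloned into slots 0 and 1
             [1, 2, 2]]        # step 1: index 2 is cloned into slots 1 and 2

prefixes_after_step_0 = [sequence_after_resampling(tokens, ancestors, 0, k) for k in range(3)]
sequences_after_step_1 = [sequence_after_resampling(tokens, ancestors, 1, k) for k in range(3)]
```

In this example, slot 0 at step 1 holds the sequence that was extended in index 1, whose prefix was itself a clone of index 0 at step 0, so the recovered sequence is "a" followed by "e".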
Extended State Space Target
We are finally ready to specify the extended state space target distribution. The crucial difference is to identify a single sequence of length (the choice of index 1 is arbitrary). This sequence will be evaluated under the unnormalized target distribution or exactly sampled from the target in the extended state space target distribution.
In particular, we begin by sampling a full sequence of indices uniformly at random . Setting , we let for all . This implies the following,
(75) | ||||
(76) |
To show these identities, note that and Eq. Index Notation imply , which matches Eq. 76. Applying this recursion again yields . Taken together, these notations allow us to interleave a true target sample in particular indices , guaranteeing that at least one target sample appears at each step.
The extended state space target distribution differs from in its handling of this sequence, which is identified as with prefixes using Eq. 75. Noting that sampling amounts to specifying a particular set of as in Eq. 75-(76),
(SMC Extended Target) |
Note, the normalization constant of is equal to since only is unnormalized.
To describe ancestral sampling from Eq. SMC Extended Target, we first sample uniformly as above, and place an exact target sequence in indices (or, equivalently, sequentially sample ). At each step, the remaining indices are sampled from the proposal. For resampling, we fix index to hold the exact sample and resample the remaining indices. Note that the resampling weights in Eq. 74 include the exact sample, which may be cloned additional times into indices other than if its importance weights are high. The procedure above simply ensures that at least one exact sequence is sampled. See 30 for the pseudocode of the algorithm.
Note that Maddison et al. (2017, Alg. 2) presents a different SMC extended state space target distribution than ours. In their work, and they sample , while in ours and we sample . However, both targets result in the same log partition function bounds.
Worked Example: In Fig. 1 (c), we use blue circles and arrows to highlight the exact-sample indices and the target sequence . Using the recursion with fixed, we may also express . At step 2, note the target sequence is sampled/evaluated an additional time in index 3.
Importance Weights in the Extended State Space
Assume we are given a fixed set of , which may be sampled from either or . We proceed to show that the unnormalized importance weights in the extended state space simplify as follows.
Lemma F.1.
Proof.
To evaluate the importance weights (with the goal of estimating ), we consider
(78) | ||||
(79) | ||||
where in , note that terms in the denominator cancel except for the indices . Recalling that from Eq. 76, we expand the resampling weights for the sequence indexed by , , and ,
(80) |
Finally, we obtain a telescoping cancellation of terms using the indexing identities in Eq. 75-(76). In particular, since and with , we can simplify the terms in Eq. 80 as
using the assumption that . Simplifying from Eq. 80, the final unnormalized importance weights become
(81) |
as desired, where we abbreviate the importance weights as for simplicity of notation. Note that we also obtain an unbiased estimate of the partition function via
∎
See Prop. 5.1.
Proof.
The proof follows directly from Brekelmans et al. (2022) App. A, where it is shown that for such that , we can construct lower and upper bounds on
(82) | ||||
(83) |
where the gaps in the lower and upper bounds are and , respectively.
Substituting our SMC probabilistic interpretation in Eq. SMC Extended Proposal and Eq. SMC Extended Target, along with the importance weights in Lemma F.1, into Eq. 83 yields the desired bounds in Eq. 24. ∎
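The sandwich property can be illustrated on a toy problem small enough to enumerate. The sketch below is not our experimental implementation and rests on several assumptions made for brevity: a two-token vocabulary, a uniform base model used as the proposal, a hypothetical potential equal to the exponential of the number of ones, twists of the same form, multinomial resampling at every step, and the exact sample pinned to index 0 rather than to a uniformly sampled index (the estimate is symmetric in the particle indices, so this should not change its distribution).

```python
import itertools
import math
import random

VOCAB = (0, 1)   # toy vocabulary (assumption)
T = 3            # toy sequence length (assumption)


def twist(prefix):
    """Toy twist; equals the terminal potential at length T and 1 for the empty prefix."""
    return math.exp(sum(prefix))


def exact_log_z():
    """Log partition function by enumeration under the uniform base model."""
    base_prob = (1.0 / len(VOCAB)) ** T
    return math.log(sum(base_prob * twist(s) for s in itertools.product(VOCAB, repeat=T)))


def sample_exact(rng):
    """Exact target sample by enumeration (the base model is uniform)."""
    sequences = list(itertools.product(VOCAB, repeat=T))
    return rng.choices(sequences, weights=[twist(s) for s in sequences], k=1)[0]


def smc_log_z(num_particles, rng, exact=None):
    """Log of the SMC partition function estimate.

    With exact=None the particles come from the SMC proposal (lower bound in
    expectation); with an exact target sample pinned to index 0, they come from
    the extended target (upper bound in expectation).
    """
    particles = [()] * num_particles
    log_z_hat = 0.0
    for t in range(T):
        extended = []
        for k, prefix in enumerate(particles):
            if exact is not None and k == 0:
                token = exact[t]               # follow the exact sample
            else:
                token = rng.choice(VOCAB)      # propose from the uniform base model
            extended.append(prefix + (token,))
        # Incremental weight: twist ratio, since the proposal equals the base model.
        weights = [twist(s) / twist(s[:-1]) for s in extended]
        log_z_hat += math.log(sum(weights) / num_particles)
        particles = rng.choices(extended, weights=weights, k=num_particles)
        if exact is not None:
            particles[0] = exact[: t + 1]      # keep the exact prefix in index 0
    return log_z_hat
```

Averaging `smc_log_z(K, rng)` over many runs approximates the lower bound, and averaging `smc_log_z(K, rng, exact=sample_exact(rng))` approximates the upper bound; both should tighten around `exact_log_z()` as the number of particles grows.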
IWAE as a Special Case of our SMC Probabilistic Interpretation
Note that we recover IWAE (or SIS over samples) from SMC with no intermediate resampling. In particular, this corresponds to for all , with importance weighting from resampling occurring at the final step . This yields the average inside the log in the IWAE bounds (i.e., SMC with only one resampling step at ). While the importance weights are crucial to construct the bound, note that ‘resampling’ is not necessary at the final step and we may return all samples along with their weights.
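For this special case, the quantity inside the bound is simply the log of the average importance weight, which is best computed in log space. The helper below is a minimal sketch with a hypothetical name; the same function yields a stochastic lower bound when all samples come from the proposal, and a stochastic upper bound when one of the samples is an exact target sample and the rest come from the proposal.

```python
import math


def log_mean_exp(log_weights):
    """Numerically stable log of the average of exp(log_weights)."""
    shift = max(log_weights)
    total = sum(math.exp(lw - shift) for lw in log_weights)
    return shift + math.log(total / len(log_weights))


# Hypothetical full-sequence log importance weights for K = 3 samples.
bound = log_mean_exp([math.log(3.0), math.log(6.0), math.log(21.0)])
```

Subtracting the maximum before exponentiating avoids overflow when full-sequence log weights are large in magnitude, which is common for long sequences.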
Appendix G Additional Experiment Details
G.1 Common Details Across Experiments
For all experiments, we use the Adam optimizer with . We use custom implementations of SMC. For PPO, we use the HuggingFace TRL PPO Trainer (https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py), modified slightly to accommodate our custom twist parameterizations, as described below. For other methods, we use Optax (Flax) and custom loss functions. We use HuggingFace models (https://huggingface.co/models) for the base models and build custom layers on top of those.
For the twist , we always parameterize for numerical stability. We choose random normal initializations centered at mean 0, with low variance (specifically, a form of Xavier initialization, taking the variance as ), such that at the beginning of training, which means the initial sequences generated by the twist-induced proposal approximately come from the base model . All methods are initialized using the same random seeds, and thus start from the same parameter values. See Sec. G.2 for additional discussion of choices for the twist parameterization.
For methods that directly learn a proposal (DPG and PPO), we could directly finetune a language model that outputs . However, in order to ensure consistency in terms of model capacity and ease of learning compared to our twisted proposals, we instead have these proposal learning methods output a modifier which is added to the base model log probability . Note that using random normal initializations centered at mean 0 with low variance, this scheme results in initial samples coming approximately from .
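The effect of this initialization can be seen in a small sketch. It assumes the proposal is obtained by adding the learned output (a log twist or a modifier) to the base model's next-token log probabilities and renormalizing with a softmax; the function names and numbers are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def proposal_probs(base_log_probs, modifier):
    """Next-token proposal from base log probabilities plus a learned additive term."""
    return softmax([lp + m for lp, m in zip(base_log_probs, modifier)])


base_probs = [0.5, 0.25, 0.15, 0.10]             # hypothetical base-model distribution
base_log_probs = [math.log(p) for p in base_probs]

# Near-zero outputs, as expected from a low-variance initialization.
at_init = proposal_probs(base_log_probs, [0.01, -0.01, 0.0, 0.005])
```

With outputs this close to zero, the proposal differs from the base distribution by well under one percentage point per token, so early training samples are effectively base-model samples for every method.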
For methods that can make use of exact posterior samples, when we have access to them (Sec. 7.2.3, Sec. H.3), we use them. This is straightforward for methods like DPG, SIXO, and our CTL (unless we have only a single sample, as we discuss for infilling in Sec. G.4). For our RL twist learning, we found the best empirical performance training on a combination of and exact samples when they were available (as opposed to just otherwise), and use those results. Similarly, for FUDGE, when exact samples are available, we use them together with samples.
It is not straightforward to compare PPO versus other methods, because of the inner loop in PPO that repeats several clipped gradient steps on a given set of samples. This means that, for a constant number of samples, PPO makes more gradient updates than other methods, while for a constant number of gradient updates, PPO sees fewer samples. Ultimately we decided to normalize based on the number of samples seen; we consider each outer step (including a full PPO inner loop, in our experiments, 4 gradient steps) as a single “gradient update.” We make this choice since sampling is the main bottleneck in terms of computational cost, and the number of inner PPO steps is a hyperparameter which we did not tune.
All of our experiments were run on a single GPU, usually on an NVIDIA A40 with 48G memory. All experiments took no longer than 9 wall-clock hours to run for a single learning method, with infilling (Sec. 7.2.3) experiments taking longest; most other experiments took no longer than 4 hours.
G.2 Choices of Twist Parameterization
The choice of parameterization for the twist is a design decision, independent of our overall framework. While one could keep an entirely separate model for each , this is likely to be memory-inefficient and learn slowly. Instead, we use a shared parameterization across , in the same way that the base language model uses a single architecture to output probability distributions over tokens at each time step . We lay out parameterization choices we considered below.
G.2.1 Linear Head
The simplest choice is to replace the linear head of the base language model with a new linear head, keep the base model fixed, and only train the linear head. This parameterization incurs very little additional computation cost compared to just using the base language model. However, we found this to be capacity constrained in our experiments, achieving worse KL divergences than other parameterizations.
G.2.2 MLP Head
Instead of a linear head, we consider a 3-layer fully connected neural network (MLP) with ReLU non-linearities as a head on top of the base language model. The base model is still kept fixed; only the MLP head is trained. This incurs more computational cost than a linear head (Sec. G.2.1), but the additional cost is still small relative to the cost of a forward pass through the base transformer model. We found this to generally perform well in our experiments, so we use it for the toxicity threshold experiment in Sec. 7.1 and sentiment in Sec. 7.2.2.
G.2.3 Separate Transformer for the Twist
We can also consider an entirely separate transformer that outputs only the twist value. That is, we copy the base model, and repurpose it to output a twist value instead of logits for next-token probabilities. We then train the entire network end-to-end. This is significantly more computationally costly than the former approaches, and does not always do better than just an MLP head (Sec. G.2.2), so we generally do not recommend using this. Still, we found it to perform well in toxicity classification in Sec. 7.2.1, so we use it there.
G.2.4 Separate Transformer for the Twist, with MLP Head
This is similar to Sec. G.2.3, except we also replace the final linear head with an MLP head as in Sec. G.2.2. The model outputs and is trained end-to-end. This is the most computationally costly approach outlined here, and is unnecessary for most of our settings. However, in infilling with 15 generated tokens (Sec. 7.2.3) we found this parameterization to perform materially better than all others, particularly with DPG (Sec. E.3), so we use it for all infilling experiments.
With both this parameterization and Sec. G.2.3, we increase computation time by a factor of around 2 on the forward pass, and significantly increase memory and time usage on the backwards pass during training (though sampling is still the main time bottleneck). Whether this parameterization is worth the potential gain in performance depends on the desired use case. We emphasize that our overall framework is independent of the choice of parameterization.
G.3 Comments on Our Choices of Experiment Settings
Our settings and evaluation metrics in Sec. 7 are chosen to highlight our scientific findings. In particular, the toxicity threshold experiment in Sec. 7.1 demonstrates the improvement of SMC over SIS with the base model with CTL learned twists. In order to highlight this distinction, we have chosen a setting where it is extremely difficult to draw samples satisfying the threshold using the base model (see SIS/IWAE LB line in Fig. 2).
However, twist-learning in the toxicity threshold setting presents challenges. For approximate positive sampling and a thresholded target, all importance weights will be 0 if none of our samples meet the threshold. As noted above, sampling from , or the SMC/twisted proposal for at initialization, is extremely unlikely to draw samples meeting the threshold (i.e., within the support of the target) in the setting of Sec. 7.1. As a result, initial iterations of twist learning receive no learning signal until a thresholded positive sample is drawn from the base model.
To avoid this difficulty for baseline comparisons in Sec. 7.2, we instead focused on settings with given by probabilities. Nevertheless, we note that there are no fundamental differences between the settings considered in Sec. 7.1 and Sec. 7.2. Thus, we may also evaluate single-sample and in the setting of Sec. 7.1, or plot bounds as a function of in for the settings in Sec. 7.2.
G.4 Experiment-Specific Details
Details for SIS and SMC Comparison (Sec. 7.1)
We generate 10 output tokens, and train twists using Sec. 4.1 with approximate positive sampling as discussed in Sec. 4.1.2.
Note that using where directly runs into numerical issues for calculating when and . We instead use everywhere instead of , where . In Fig. 2, this yields a SIS/IWAE LB when no samples are drawn that fall in the set .
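The numerical issue and its workaround can be sketched as follows. The value of the constant below is a hypothetical placeholder chosen for illustration, not the value used in our experiments, and the function name is likewise an assumption.

```python
import math

EPSILON = 1e-12   # hypothetical small constant; the experimental value may differ


def log_potential(in_target_set, eps=EPSILON):
    """Log of (indicator + eps): finite even when the sample misses the target set."""
    indicator = 1.0 if in_target_set else 0.0
    return math.log(indicator + eps)


inside = log_potential(True)     # close to 0
outside = log_potential(False)   # log(eps), finite rather than negative infinity
```

Taking the log of the raw indicator would raise an error (or give negative infinity) for samples outside the set. With the offset, a batch in which no sample lands in the set yields a finite but very low log weight for every sample, which is consistent with the flat SIS/IWAE lower bound described above.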
We use an MLP head to parameterize the twist, as in Sec. G.2.2, with 768 hidden units per layer, matching the TinyStories model’s embedding dimension. We use a batch size (number of SMC particles/samples) of 1000, with a learning rate of 0.0001, and train using CTL for a total of 5000 gradient updates. We did not tune hyperparameters because we found this setting to work well, and we are not comparing across different learning methods.
For each point on each line on Fig. 2, we run SIS or SMC 20 times, each with a different randomly selected true posterior sample for the upper bounds. The line shows the average value across these 20 runs, while the shaded area shows 95% confidence intervals. See also Sec. G.1 for details common across experiments.
Details for Toxicity (Sec. 7.2.1)
We generate 20 output tokens. We parameterize the twist with a separate network as in Sec. G.2.3. We use a batch size (number of SMC particles/samples) of 100, and train for a total of 2048 gradient updates. For each learning method, we used a coarse grid search over learning rates between 0.000001 and 0.001, using the best one found, which was usually 0.00003 or 0.0001. We run each learning method over 5 different random seeds, reporting the average KL divergence and 95% confidence intervals over these 5 seeds.
For each KL divergence evaluation, we first get sandwich bounds on as laid out in Sec. 5, using the learned twists for the twisted proposal with samples. We find SIS/IWAE and SMC bounds to be similarly tight, so use SIS/IWAE for simplicity. We do this 4 times, providing 4 upper bound estimates and 4 lower bound estimates, and take the average midpoint as the estimate for each experiment. We then take the median (across all learning methods and seeds) of these estimates, and use that as our estimate of . This is then used as a common value for the KL divergence across all methods and seeds, which controls for possible noise in bounds and ensures a fair comparison across methods. We generally have tight bounds (upper bound ≈ lower bound), which suggest our estimates are generally accurate, but note that any inaccuracies in estimating would only affect the absolute values of the KL divergences, not the relative differences among different learning methods.
We estimate expectations in Eq. 23 with 2000 samples from and 2000 exact posterior samples for . With 2000 samples, our estimates have 95% confidence intervals generally between 0.05 and 0.10, suggesting that our estimates of expectations are unlikely to be off by more than 0.10. The exact posterior samples were collected offline; such a large number of samples takes several hours to collect, and in practical settings, we would likely only be able to collect a much smaller number of samples. All our methods still apply with fewer exact posterior samples, but the variance in estimates will be higher. See also Sec. G.1 for details common across experiments.
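The evaluation pipeline described above can be summarized in a short sketch. It assumes the standard identities in which the KL divergence from the proposal to the target is the proposal-sample average of the log proposal minus the log unnormalized target, plus the log partition function, and the reverse KL is the target-sample average of the opposite difference, minus the log partition function. All function names and numbers are hypothetical.

```python
import math
import statistics


def midpoint_log_z(lower_bounds, upper_bounds):
    """Average midpoint of paired lower and upper bound estimates for one experiment."""
    return statistics.mean((lo + up) / 2 for lo, up in zip(lower_bounds, upper_bounds))


def kl_proposal_to_target(log_q, log_target_unnorm, log_z):
    """KL from proposal to target, estimated on samples drawn from the proposal."""
    return statistics.mean(lq - lt for lq, lt in zip(log_q, log_target_unnorm)) + log_z


def kl_target_to_proposal(log_target_unnorm, log_q, log_z):
    """KL from target to proposal, estimated on exact samples from the target."""
    return statistics.mean(lt - lq for lt, lq in zip(log_target_unnorm, log_q)) - log_z


# Hypothetical per-experiment estimates, then a common value shared by all methods.
per_experiment = [
    midpoint_log_z([1.0, 1.2], [1.4, 1.6]),
    midpoint_log_z([1.1, 1.1], [1.3, 1.5]),
    midpoint_log_z([0.9, 1.0], [1.5, 1.6]),
]
common_log_z = statistics.median(per_experiment)
```

Because the same `common_log_z` enters every method's KL estimate as an additive constant, any error in it shifts all methods equally and leaves their ranking unchanged.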
Details for Sentiment (Sec. 7.2.2)
We generate 10 output tokens. We parameterize the twist using an MLP head (Sec. G.2.2), with 1024 hidden units per layer, matching the GPT2Medium model’s embedding dimension. Other details are the same as for toxicity above. Collecting exact posterior samples is less time consuming in this case (less than an hour). See Sec. G.1 for common experimental details.
Details for Infilling (Sec. 7.2.3)
We parameterize the twist using a separate transformer with an MLP head (Sec. G.2.4), with 768 hidden units per layer (matching the TinyStories model’s embedding dimension). We make the following adjustments to the forward pass of the language model for the conditional twist setting. Instead of taking in only , the model takes in both and and passes each separately through the body (everything except the head). Thus, can be seen as a second prompt. For , we take the embeddings produced after the last conditioning token has been processed, broadcast it across time steps , and pass that as additional input to the MLP head (concatenated with embeddings for at each ). This allows the MLP head to produce different output depending on the conditioning tokens.
Since we are in the conditional target distribution setting (Sec. 3.3), with , to compare across learning methods using a single quantity, we estimate and where for infilling. Note that,
where for a fixed , and may be evaluated as before, similar to the unconditional setting. In particular, for our experiments, we use 1-sample estimates of these expectations, as we have a single exact sample from by the BDMC trick (Sec. 3.3), and we choose to draw a single sample from the conditional proposal . We average this over 2000 , approximating the outer expectation, giving us a 2000-sample estimate of 1-sample estimates for the first term in the right hand side of both equations above. With 2000 samples, our estimates have 95% confidence intervals generally between 0.20 and 0.30.
Note that is independent of the learning method or proposal , unlike the first term we discussed above. Thus, in order to save computation and provide us with a more accurate estimate of , we estimate this term only once. Specifically, we consider only the learning method with the lowest KL divergence (DPG), and use SIS/IWAE bounds. For each , we estimate with samples, which gives us relatively tight sandwich bounds, again taking the midpoint as our estimate. We average this over 1000 , giving us a 1000-sample estimate of , where each is itself estimated via 500 samples.
For negative sampling with contrastive twist learning (CTL) in this setting, we need at least 2 negative samples per set of conditioning tokens to perform SIS reweighting; this is in contrast with other twist learning methods which can generate a single negative sample per . For the positive sample, we can use our single exact sample directly, or we can run the SMC upper bound sampling procedure (“Sampling from for SMC Upper Bounds” section in Sec. 5.2) to generate more approximate samples using the given exact sample. We find the latter to generally perform slightly better than the former, so adopt that for our infilling experiments.
We use a fixed batch size of 100 across all methods for training twists. To clarify the meaning of this batch size, for methods other than CTL, we have 100 draws of exact samples, each for a different set of conditioning tokens , so we train over 100 different at a time using 1 negative sample per . For CTL, since we need at least 2 negative samples per , we split the batch size of 100 across the number of different and the number of negative samples per , as an additional hyperparameter. We use 25 with 4 negative samples per for the experiments in Sec. 7.2.3 and 10 with 10 negative samples per for the experiments in Sec. H.2. Controlling for batch size in this way is arguably disadvantageous for CTL compared to other learning methods, as it learns on a smaller number of , but this controls for memory requirements, and we feel is more fair than controlling for the number of seen but allowing more negative samples for CTL relative to other methods. We train for a total of 5500 gradient updates. For each method, we used a coarse grid search over learning rates between 0.000001 and 0.001, using the best one found, which was usually 0.0001 or 0.00003. We run each learning method over 5 different random seeds, reporting the average KL divergence and 95% confidence intervals over these 5 seeds. See also Sec. G.1 for details common across experiments.
Appendix H Additional Experimental Results
H.1 Qualitative Results
Toxicity Controlled Generation when No Exact Posterior Samples are Available
In Sec. 7.2.1 we targeted with . We can also target ; higher produces a more peaked distribution of text that is more likely to be of class . However, for we can no longer generate exact posterior samples and thus cannot upper bound . Our twist learning (Sec. 4.1) with approximate positive sampling (Sec. 4.1.2) can learn meaningful twists in this setting, which we illustrate with a qualitative example of a story (200 tokens upper limit) and :
“Once upon a time, there was a little girl named Lily. She had a big thumb that she liked to suck on. One day, Lily went to the park to play with her friends. She was having so much fun until her thumb got stuck in her shoe. She tried to pull it out, but it hurt too much.
Lily started to cry and her friends tried to help her, but they couldn’t get her thumb out either. She was scared and didn’t know what to do. Her friends tried to help her, but they couldn’t get it out either. Sadly, Lily had to go to the hospital and get a big bandage on her thumb. She couldn’t play with her friends anymore. From that day on, Lily never went to the park again.”
The story is coherent and follows the general style of the TinyStories base model, while having a high probability ( 88%) of being toxic according to the toxicity classifier, likely due to the presence of negative words such as ‘suck’, ‘hurt’, ‘cry’, and ‘scared’. This supports the ability of our methods to control outputs based on the chosen posterior distribution.
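The effect of a larger exponent on the target can be illustrated with a toy example over three candidate sequences. This is a sketch under the assumption that the target is proportional to the base model probability times the classifier probability raised to an exponent, called `beta` here; the numbers and function names are hypothetical and are not results from our experiments.

```python
def tempered_posterior(prior, class_prob, beta):
    """Normalize prior[i] * class_prob[i] ** beta over a finite candidate set."""
    unnormalized = [p * (c ** beta) for p, c in zip(prior, class_prob)]
    z = sum(unnormalized)  # normalizing constant (partition function)
    return [u / z for u in unnormalized]

def expected_class_prob(prior, class_prob, beta):
    """Average classifier probability under the tempered posterior."""
    posterior = tempered_posterior(prior, class_prob, beta)
    return sum(q * c for q, c in zip(posterior, class_prob))

# Toy base-model probabilities and classifier probabilities (made up).
prior = [0.5, 0.3, 0.2]
class_prob = [0.1, 0.5, 0.9]
for beta in (0, 1, 10):
    print(beta, round(expected_class_prob(prior, class_prob, beta), 3))
```

As `beta` grows, the posterior mass concentrates on the candidate with the highest classifier probability, which matches the qualitative behaviour described above.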
Sentiment Controlled Generation when No Exact Posterior Samples are Available
As above, we also consider , where , except now is based on the sentiment classifier in Sec. 7.2.2. In Table 6 we provide qualitative examples showing 20 tokens produced with twisted SMC with 500 particles, for , using twists trained with Sec. 4.1. These illustrate our framework’s ability to learn reviews that embody each rating class. (The results are slightly incoherent; this is a result of the base GPT2-Medium model often being incoherent. Qualitatively, we find that these generations are more coherent than the uncontrolled ones from .)
Class (Rating) | Most Text Generated Using Twisted SMC |
---|---|
1 star | “I bought this sucker for my wife to use on her python that she sent me last year. It was terrible!” |
2 stars | “I bought this throat raiser for combating dental caries. I didn’t really like it. I didn’t like” |
3 stars | “I bought this a few months back, and I enjoyed it every time I held it. I’m giving 3 stars” |
4 stars | “I bought this product a few months ago and have really enjoyed it. Only reason I gave it 4 stars is because” |
5 stars | “I bought this phone recently, and I’ve been loving it! Gorgeous design, outstanding battery life, fantastic camera” |
Proposal | Prompt () | Generated Tokens () | Conditioning Tokens () |
---|---|---|---|
DPG | Once upon a time, there was a | little girl named Mia. She had a big heart. Mia loved to help | others and make them feel safe. Mia liked to |
SIXO | Once upon a time, there was a | girl named Mia. Mia was very kind and compassionate. She always helped her | others and make them feel safe. Mia liked to |
CTL | Once upon a time, there was a | girl named Mia. She had a thin, pink dress. Mia liked to | others and make them feel safe. Mia liked to |
Infilling
In Table 7 we compare qualitative results on an example set of conditioning tokens for DPG, SIXO, and CTL (in that order, to reflect increasing KL divergence). The qualitative results correlate with the quantitative measures of KL divergence; the lowest KL divergence (DPG) corresponds to infilled tokens that respect grammar and the topic. SIXO, which has higher KL divergence, fails to respect grammar. CTL generates incorrect grammar and is less on-topic, corresponding to the highest KL divergence among these methods.
Proposal | Twist Learning | ||
---|---|---|---|
Twisted | Contrastive | ||
Twisted | RL | ||
Twisted | SIXO | ||
Twisted | FUDGE | ||
DPG | - | ||
PPO | - |
H.2 Infilling with Fewer Tokens
We consider the same setting as Sec. 7.2.3 but only generating 2 tokens, conditioned on 1 token. We show KL divergence evaluations in Table 8. Our evaluation reveals interesting differences among learning methods, even in this easier setting where most methods achieve low KL divergence in both directions. DPG and RL learn best, while FUDGE learns notably more slowly. PPO suffers on , though this may be unsurprising since PPO does not make use of exact samples.
H.3 Approximate vs. Exact Posterior Sampling
In our toxicity and sentiment experiments, we train using approximate samples to reflect the more common real-world setting where the number of exact samples needed for training is not available. Here, however, we run an additional ablation experiment for insight into the effect of exact versus approximate positive sampling. We use rejection sampling (Sec. 4.1.2) to generate exact posterior samples for training. This is much slower than generating approximate samples, so it is not a practical strategy for training; we investigate it solely for understanding.
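As a rough illustration of why rejection sampling yields exact posterior samples but is slow, the following toy sketch proposes from a base distribution and accepts each proposal with probability equal to its potential. It assumes the potential is a probability in [0, 1]; the candidate set, the probabilities, and all names are hypothetical and stand in for the base language model and classifier.

```python
import random

def rejection_sample(prior_sampler, potential, rng, max_tries=100_000):
    """Draw one exact sample from the target prior(s) * potential(s) / Z,
    assuming potential(s) lies in [0, 1]."""
    for _ in range(max_tries):
        s = prior_sampler(rng)
        # Accept with probability potential(s); otherwise propose again.
        if rng.random() < potential(s):
            return s
    raise RuntimeError("no proposal accepted within max_tries")

# Toy stand-ins for the base model and the classifier (made-up numbers).
items = ["a", "b", "c"]
prior = [0.5, 0.3, 0.2]
phi = {"a": 0.0, "b": 0.5, "c": 1.0}

def toy_prior_sampler(rng):
    return rng.choices(items, weights=prior)[0]

rng = random.Random(0)
samples = [rejection_sample(toy_prior_sampler, phi.get, rng) for _ in range(5000)]
```

In this toy setting the acceptance rate is 0.35, so roughly three proposals are needed per accepted sample. The cost grows quickly when the potential is small for most sequences under the base model, which is the source of the wall-clock overhead discussed in this section.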
We provide a comparison of KL divergences (evaluated the same way as in the main paper) when training using exact versus approximate samples for a selection of methods that performed well in our previous experiments and are able to make use of samples. Toxicity (Sec. 7.2.1) results are in Table 9 and sentiment (Sec. 7.2.2) results are in Table 10. The first two columns of KL divergences are for exact samples. The next two are for training on the same number of samples, but using approximate positive sampling (Sec. 4.1.2). Overall, for a constant number of samples, having exact samples improves performance for most methods. Note however that there is an additional time cost required for rejection sampling to generate exact samples, so the exact training requires significantly more wall-clock time for any given number of samples.
We also plot the single-sample KL divergence in both directions as a function of training time for exact vs. approximate sampling, on toxicity and sentiment experiments, in Fig. 4. The approximate sampling results match those in the main paper (with different colors). The exact sample results cut off earlier because the time cost required for rejection sampling reduces the number of gradient updates that can be made for a given amount of wall-clock time.
Exact Samples | Same # of Approx. Samples | ||||
---|---|---|---|---|---|
Proposal | Type of Twist Learning | ||||
Twisted | Contrastive | ||||
Twisted | RL | ||||
Twisted | SIXO | ||||
DPG | - |
Exact Samples | Same # of Approx. Samples | ||||
---|---|---|---|---|---|
Proposal | Type of Twist Learning | ||||
Twisted | Contrastive | ||||
Twisted | RL | ||||
Twisted | SIXO | ||||
DPG | - |