Multi-Token Joint Speculative Decoding for Accelerating Large Language Model Inference

Zongyue Qin
&Ziniu Hu111
&Zifan He111
&Neha Prakriya111
&Jason Cong111
&Yizhou Sun111
Department of Computer Science, University of California, Los Angeles, USA. Correspondence to: Zongyue Qin [email protected]
Abstract

Transformer-based Large language models (LLMs) have demonstrated their power in various tasks, but their inference incurs significant time and energy costs. To accelerate LLM inference, speculative decoding uses a smaller model to propose one sequence of tokens, which are subsequently validated in batch by the target large model. Compared with autoregressive decoding, speculative decoding generates the same number of tokens with fewer runs of the large model, hence accelerating the overall inference by 1111-2×2\times2 ×. However, greedy decoding is not the optimal decoding algorithm in terms of output perplexity, which is a direct measurement of the effectiveness of a decoding algorithm. An algorithm that has better output perplexity and even better efficiency than speculative decoding can be more useful in practice. To achieve this seemingly contradictory goal, we first introduce multi-token joint greedy decoding (MJGD), which greedily generates multiple tokens at each step based on their joint perplexity. We show that it leads to better perplexity for the whole output. But the computation cost of MJGD is infeasible in practice. So we further propose multi-token joint speculative decoding (MJSD), which approximates and accelerates the MJGD from two aspects: it approximates the joint distribution of the large model with that of a small model, and uses a verification step to guarantee the accuracy of approximation; then it uses beam decoding to accelerate the sequence generation from the joint distribution. Compared with vanilla speculative decoding, MJSD has two advantages: (1) it is an approximation of MJGD, thus achieving better output perplexity; (2) verification with joint likelihood allows it to accept the longest prefix sub-sequence of the draft tokens with valid perplexity, leading to better efficiency. In addition, we analyze energy consumption during LLM inference and provide theoretical and empirical evidence that, surprisingly, MJSD reduces energy consumption even when the number of operations increases. Experiment results show that our approach enhances perplexity by 21.2%percent21.221.2\%21.2 % than greedy decoding. Moreover, MJSD achieves 2.21×2.21\times2.21 × speed-up and 2.84×2.84\times2.84 × less energy consumption than greedy decoding, and 1.49×1.49\times1.49 × speed-up and 1.62×1.62\times1.62 × less energy consumption than vanilla speculative decoding.

1 Introduction

Large Language Models (LLMs) have exhibited remarkable performance across real-world tasks spanning text and image domains [4, 7, 33, 34]. However, the substantial parameter size of these models leads to high computational cost during inference [27]. For example, ChatGPT is estimated to have an energy demand of 564 MWh per day [8]. Compounding this issue, the autoregressive generation of K𝐾Kitalic_K tokens necessitates calling the model K𝐾Kitalic_K times, involving the repeated loading of weight matrices and intermediate results from GPU global memory to computing units. As a result, LLM inference is often hampered by constraints related to memory bandwidth and communication [20], leading to diminished hardware utilization, heightened inference latency, and elevated energy cost.

To tackle this challenge, researchers have delved into non-autoregressive decoding approaches. Earlier methods [11, 12, 13] aimed at reducing inference latency by concurrently generating multiple tokens. But these methods usually require task-dependent techniques and information to achieve comparable performance to that of autoregressive decoding [17, 36]. More recently, speculative decoding has emerged [20, 5, 17, 32], exploiting the observation that, even when a small model is ten times smaller than a large model, most of the small model’s prediction aligns well with that of a large model. So it leverages a smaller auxiliary model to draft a few future tokens autoregressively, which are subsequently validated in parallel by the larger model. As the smaller model operates significantly faster and parallel token verification incurs a similar time cost as generating a single token, speculative decoding achieves an overall speed-up of 1111-2×2\times2 ×. Meanwhile, it always has an identical sampling distribution as greedy decoding.

However, there are some limitations of the existing studies on speculative decoding. First, they aim to have an identical sampling distribution as greedy sampling. But the goal of decoding algorithms is to generate sequences with optimal perplexity, which is defined as the exponentiated average negative log-likelihood of all tokens. An algorithm that does not follow the same sampling distribution as greedy decoding, but yields output with better perplexity could be more useful in the real world. Second, in these algorithms, if a draft token is rejected, all the draft tokens after it are discarded. But the overall draft sequences may have good quality despite there being a few low-quality tokens. Rejecting these sequences will shorten the average acceptance length for each iteration, lowering the decoding efficiency. Third, existing research does not investigate the impact of speculative decoding on inference energy consumption. While speculative decoding increases the number of FLOPs during inference, it simultaneously reduces the overall inference time, communication, and memory operations, which are crucial factors in determining energy consumption [1, 6]. Thus, it remains unclear if speculative decoding increases or decreases the energy cost of inference.

So we aim to design a new algorithm that is both more efficient and more effective than greedy decoding and vanilla speculative decoding. We first introduce multi-token joint greedy decoding (MJGD), which generates multiple tokens at each step greedily based on their joint likelihood. We empirically show that it has better overall perplexity than greedy decoding does. However, the computational cost of MJGD is infeasible in practice. So we further propose multi-token joint speculative decoding (MJSD), which approximates and accelerates MJGD from two aspects: (1) it approximates the joint distribution of the large model with that of a small model by generating draft tokens with the small model and validating them with the large model; (2) it uses beam decoding to accelerate the sequence generation from the joint distribution of the small model. Compared with vanilla speculative decoding, MJSD has two advantages: (1) it is more effective, as it approximates MJGD, which has better output perplexity than greedy decoding does; (2) it is more efficient because it uses joint likelihood as the verification criterion, which allows it to accept the longest prefix sub-sequence of the draft tokens with valid perplexity. In addition, we analyze the energy consumption of LLM inference. We give theoretical and empirical evidence that, despite that MJSD and other speculative decoding algorithms increase the number of FLOPs needed during LLM inference, they reduce the overall energy consumption by reducing the overhead induced by accessing GPU global memory.

We evaluate our method on text generation tasks with Llama-2 [34] and OPT models [38]. Experiment results show that our approach enhances perplexity by 21.2%percent21.221.2\%21.2 % than the baselines with identical sampling distributions as greedy decoding. In addition, MJSD achieves 2.21×2.21\times2.21 × speed-up and 2.84×2.84\times2.84 × smaller energy consumption than greedy decoding, and 1.49×1.49\times1.49 × speed-up and 1.62×1.62\times1.62 × smaller energy consumption than vanilla speculative decoding. Our code is open-sourced111https://anonymous.4open.science/r/LLMSpeculativeSampling-EE52.

2 Preliminaries

2.1 Decodings of LLMs

Let p𝑝pitalic_p denote the distribution defined by LLM model Mpsubscript𝑀𝑝M_{p}italic_M start_POSTSUBSCRIPT italic_p end_POSTSUBSCRIPT. Given an input prefix prefix𝑝𝑟𝑒𝑓𝑖𝑥prefixitalic_p italic_r italic_e italic_f italic_i italic_x, the optimal decoding algorithm is to generate a sequence of N𝑁Nitalic_N tokens with maximum likelihood p(x1:N|prefix)𝑝conditionalsubscript𝑥:1𝑁𝑝𝑟𝑒𝑓𝑖𝑥p(x_{1:N}|prefix)italic_p ( italic_x start_POSTSUBSCRIPT 1 : italic_N end_POSTSUBSCRIPT | italic_p italic_r italic_e italic_f italic_i italic_x ). The likelihood of the sequence is directly linked to perplexity of the sequence, which is the exponentiated average negative log-likelihood of all tokens. Based on autoregressive decomposition p(x1:N|prefix)=t=1Np(xt|x1:t1,prefix)𝑝conditionalsubscript𝑥:1𝑁𝑝𝑟𝑒𝑓𝑖𝑥superscriptsubscriptproduct𝑡1𝑁𝑝conditionalsubscript𝑥𝑡subscript𝑥:1𝑡1𝑝𝑟𝑒𝑓𝑖𝑥p(x_{1:N}|prefix)=\prod_{t=1}^{N}p(x_{t}|x_{1:t-1},prefix)italic_p ( italic_x start_POSTSUBSCRIPT 1 : italic_N end_POSTSUBSCRIPT | italic_p italic_r italic_e italic_f italic_i italic_x ) = ∏ start_POSTSUBSCRIPT italic_t = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT italic_p ( italic_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t - 1 end_POSTSUBSCRIPT , italic_p italic_r italic_e italic_f italic_i italic_x ), the perplexity is defined as:

PPL(x1:N)=exp{1Nt=1Nlogp(xt|x1:t1)}𝑃𝑃𝐿subscript𝑥:1𝑁1𝑁superscriptsubscript𝑡1𝑁𝑝conditionalsubscript𝑥𝑡subscript𝑥:1𝑡1PPL(x_{1:N})=\exp\left\{-\frac{1}{N}\sum_{t=1}^{N}\log p(x_{t}|x_{1:t-1})\right\}italic_P italic_P italic_L ( italic_x start_POSTSUBSCRIPT 1 : italic_N end_POSTSUBSCRIPT ) = roman_exp { - divide start_ARG 1 end_ARG start_ARG italic_N end_ARG ∑ start_POSTSUBSCRIPT italic_t = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT roman_log italic_p ( italic_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t - 1 end_POSTSUBSCRIPT ) } (1)

Now we introduce commonly used decoding approaches.

Greedy Decoding.

Greedy decoding samples the next single token xtsubscript𝑥𝑡x_{t}italic_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT based on p(|x1:t1,prefix)p(\cdot|x_{1:t-1},prefix)italic_p ( ⋅ | italic_x start_POSTSUBSCRIPT 1 : italic_t - 1 end_POSTSUBSCRIPT , italic_p italic_r italic_e italic_f italic_i italic_x ). According to the autoregressive decomposition, greedy sampling leads to the desired joint distribution for the sequence. In practice, however, the objective of decoding algorithms is to generate sequences with high likelihood. Instead of sampling from the joint distribution, a warping operation 𝒯𝒯\mathcal{T}caligraphic_T usually is applied to boost the high probability region, which transforms the original distribution p𝑝pitalic_p to p~=𝒯p~𝑝𝒯𝑝\tilde{p}=\mathcal{T}\circ pover~ start_ARG italic_p end_ARG = caligraphic_T ∘ italic_p. A very popular warping operation is argmax𝑎𝑟𝑔𝑚𝑎𝑥argmaxitalic_a italic_r italic_g italic_m italic_a italic_x, which only selects the sample with the highest probability. Other possible warping operations include top-k. For most of the warping operations, autoregressive decomposition no longer holds after the warping. For example, argmaxx1:Np(x1:N|prefix)t=1Nargmaxxtp(xt|x1:t1,prefix)subscriptsubscript𝑥:1𝑁𝑝conditionalsubscript𝑥:1𝑁𝑝𝑟𝑒𝑓𝑖𝑥superscriptsubscriptproduct𝑡1𝑁subscriptsubscript𝑥𝑡𝑝conditionalsubscript𝑥𝑡subscript𝑥:1𝑡1𝑝𝑟𝑒𝑓𝑖𝑥\arg\max_{x_{1:N}}{p}(x_{1:N}|prefix)\neq\prod_{t=1}^{N}\arg\max_{x_{t}}{p}(x_% {t}|x_{1:t-1},prefix)roman_arg roman_max start_POSTSUBSCRIPT italic_x start_POSTSUBSCRIPT 1 : italic_N end_POSTSUBSCRIPT end_POSTSUBSCRIPT italic_p ( italic_x start_POSTSUBSCRIPT 1 : italic_N end_POSTSUBSCRIPT | italic_p italic_r italic_e italic_f italic_i italic_x ) ≠ ∏ start_POSTSUBSCRIPT italic_t = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT roman_arg roman_max start_POSTSUBSCRIPT italic_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT end_POSTSUBSCRIPT italic_p ( italic_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t - 1 end_POSTSUBSCRIPT , italic_p italic_r italic_e italic_f italic_i italic_x ).

Beam Decoding.

Beam decoding aims to do a better job than greedy decoding. For each position t𝑡titalic_t (1tN1𝑡𝑁1\leq t\leq N1 ≤ italic_t ≤ italic_N), it maintains M>1𝑀1M>1italic_M > 1 candidate sequences, which are also called beams. Still take argmax𝑎𝑟𝑔𝑚𝑎𝑥argmaxitalic_a italic_r italic_g italic_m italic_a italic_x warping operator as an example and assume we have already kept the M𝑀Mitalic_M sequences t1={x1:t1(1),,x1:t1(M)}subscript𝑡1superscriptsubscript𝑥:1𝑡11superscriptsubscript𝑥:1𝑡1𝑀\mathcal{I}_{t-1}=\{x_{1:t-1}^{(1)},\ldots,x_{1:t-1}^{(M)}\}caligraphic_I start_POSTSUBSCRIPT italic_t - 1 end_POSTSUBSCRIPT = { italic_x start_POSTSUBSCRIPT 1 : italic_t - 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( 1 ) end_POSTSUPERSCRIPT , … , italic_x start_POSTSUBSCRIPT 1 : italic_t - 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_M ) end_POSTSUPERSCRIPT } at position t1𝑡1t-1italic_t - 1 with highest likelihood, M𝑀Mitalic_M tokens are then sampled from p(|x1:t1(m),prefix)p(\cdot|x_{1:t-1}^{(m)},prefix)italic_p ( ⋅ | italic_x start_POSTSUBSCRIPT 1 : italic_t - 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ( italic_m ) end_POSTSUPERSCRIPT , italic_p italic_r italic_e italic_f italic_i italic_x ) for each sequence m𝑚mitalic_m and the top M𝑀Mitalic_M sequences with the highest likelihood p(x1:t|prefix)𝑝conditionalsubscript𝑥:1𝑡𝑝𝑟𝑒𝑓𝑖𝑥p(x_{1:t}|prefix)italic_p ( italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT | italic_p italic_r italic_e italic_f italic_i italic_x ) will be kept. It is widely recognized that beam decoding is a closer approximation to p~(x1:N|prefix)~𝑝conditionalsubscript𝑥:1𝑁𝑝𝑟𝑒𝑓𝑖𝑥\tilde{p}(x_{1:N}|prefix)over~ start_ARG italic_p end_ARG ( italic_x start_POSTSUBSCRIPT 1 : italic_N end_POSTSUBSCRIPT | italic_p italic_r italic_e italic_f italic_i italic_x ) than greedy decoding [18], with few exceptions under special construction. The computation for each beam can be parallelized, thus its run time is about the same as greedy decoding. The fact that beam decoding is more effective than greedy decoding also suggests that the output perplexity of greedy decoding still has room for improvement.

2.2 Vanilla Speculative Decoding 

Besides effectiveness, speculative decoding is proposed by [20, 5] to accelerate the inference of LLMs. It utilizes a small model to generate the next γ𝛾\gammaitalic_γ tokens and then uses the large model to verify the drafted tokens in parallel, which is summarized below:

  1. 1.

    Let prefix𝑝𝑟𝑒𝑓𝑖𝑥prefixitalic_p italic_r italic_e italic_f italic_i italic_x be the input context, the small model samples γ𝛾\gammaitalic_γ draft tokens x1,,xγsubscript𝑥1subscript𝑥𝛾x_{1},\ldots,x_{\gamma}italic_x start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , … , italic_x start_POSTSUBSCRIPT italic_γ end_POSTSUBSCRIPT using greedy decoding based on the warped predicted conditional probability q~(xt|x1:t1,prefix))\tilde{q}(x_{t}|x_{1:t-1},prefix))over~ start_ARG italic_q end_ARG ( italic_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t - 1 end_POSTSUBSCRIPT , italic_p italic_r italic_e italic_f italic_i italic_x ) ) for t=1,,γ𝑡1𝛾t=1,\ldots,\gammaitalic_t = 1 , … , italic_γ, where q~=𝒯q~𝑞𝒯𝑞\tilde{q}=\mathcal{T}\circ qover~ start_ARG italic_q end_ARG = caligraphic_T ∘ italic_q and q𝑞qitalic_q is the small model’s output distribution.

  2. 2.

    The large model verifies the draft tokens in parallel by computing the conditional probability p~(xt|x1:t1,prefix)~𝑝conditionalsubscript𝑥𝑡subscript𝑥:1𝑡1𝑝𝑟𝑒𝑓𝑖𝑥\tilde{p}(x_{t}|x_{1:t-1},prefix)over~ start_ARG italic_p end_ARG ( italic_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t - 1 end_POSTSUBSCRIPT , italic_p italic_r italic_e italic_f italic_i italic_x ) for t=1,,γ𝑡1𝛾t=1,\ldots,\gammaitalic_t = 1 , … , italic_γ.

  3. 3.

    Each draft token xtsubscript𝑥𝑡x_{t}italic_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT is accepted with probability min(1,p~(xt)/q~(xt))1~𝑝subscript𝑥𝑡~𝑞subscript𝑥𝑡\min(1,\tilde{p}(x_{t})/\tilde{q}(x_{t}))roman_min ( 1 , over~ start_ARG italic_p end_ARG ( italic_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ) / over~ start_ARG italic_q end_ARG ( italic_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT ) ). The draft tokens before the first rejected token are kept as the decoding output. An additional token is sampled from a residual distribution as a correction to the first rejected token. Then the accepted tokens and the resampled token are appended to the context prefix𝑝𝑟𝑒𝑓𝑖𝑥prefixitalic_p italic_r italic_e italic_f italic_i italic_x as the input to the next iteration.

  4. 4.

    Repeat step 1-3 until reaching the stopping criteria, e.g., reaching the length limit..

Because the large model verifies γ𝛾\gammaitalic_γ tokens in parallel with one run, the time cost is smaller than calling it γ𝛾\gammaitalic_γ times. Meanwhile, although the small model still runs in an autoregressive way, its inference speed is much faster than the large model. So speculative decoding can accelerate the inference process of LLMs. Additionally, it is proven that each token xisubscript𝑥𝑖x_{i}italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT generated by speculative sampling follows the identical sampling distribution p~(xt|x1:t1,prefix)~𝑝conditionalsubscript𝑥𝑡subscript𝑥:1𝑡1𝑝𝑟𝑒𝑓𝑖𝑥\tilde{p}(x_{t}|x_{1:t-1},prefix)over~ start_ARG italic_p end_ARG ( italic_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t - 1 end_POSTSUBSCRIPT , italic_p italic_r italic_e italic_f italic_i italic_x ) as greedy decoding [20]. However, since p~(x1:N|prefix))t=1Np~(xt|x1:t1,prefix)\tilde{p}(x_{1:N}|prefix))\neq\prod_{t=1}^{N}\tilde{p}(x_{t}|x_{1:t-1},prefix)over~ start_ARG italic_p end_ARG ( italic_x start_POSTSUBSCRIPT 1 : italic_N end_POSTSUBSCRIPT | italic_p italic_r italic_e italic_f italic_i italic_x ) ) ≠ ∏ start_POSTSUBSCRIPT italic_t = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT over~ start_ARG italic_p end_ARG ( italic_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t - 1 end_POSTSUBSCRIPT , italic_p italic_r italic_e italic_f italic_i italic_x ), the effectiveness of speculative decoding also has room for improvement. Therefore, instead of having an identical sampling distributon as greedy decoding, our algorithm aims to achieve a better approximation to p~(x1:N|prefix)~𝑝conditionalsubscript𝑥:1𝑁𝑝𝑟𝑒𝑓𝑖𝑥\tilde{p}(x_{1:N}|prefix)over~ start_ARG italic_p end_ARG ( italic_x start_POSTSUBSCRIPT 1 : italic_N end_POSTSUBSCRIPT | italic_p italic_r italic_e italic_f italic_i italic_x ).

3 Methodology

The perplexity is a direct measurement of the effectiveness of a decoding algorithm. In practice, when the model is well-trained, a better perplexity usually leads to better downstream effectiveness. Although it is possible that improving perplexity does not improve the performance on downstream tasks, the problem is due to the model itself instead of the decoding algorithm. Therefore, the goal of this work is to design an algorithm that yields better output perplexity and better efficiency than greedy decoding and vanilla speculative decoding.

In this section, we first introduce multi-token joint greedy decoding (MJGD) to give a motivating example for why generating multiple tokens based on their joint likelihood can lead to better perplexity. Then we introduce multi-token joint speculative decoding (MJSD), which approximates and accelerates MJGD by exploiting a small model.

3.1 Multi-Token Joint Greedy Decoding

Let Mpsubscript𝑀𝑝M_{p}italic_M start_POSTSUBSCRIPT italic_p end_POSTSUBSCRIPT be the large target model, the goal is to generate a sequence of N𝑁Nitalic_N tokens with maximum likelihood p(x1:N|prefix)𝑝conditionalsubscript𝑥:1𝑁𝑝𝑟𝑒𝑓𝑖𝑥p(x_{1:N}|prefix)italic_p ( italic_x start_POSTSUBSCRIPT 1 : italic_N end_POSTSUBSCRIPT | italic_p italic_r italic_e italic_f italic_i italic_x ). We now introduce a new decoding algorithm to improve greedy decoding in terms of perplexity.

Definition 3.1.

Multi-Token Joint Greedy Decoding. Let Mpsubscript𝑀𝑝M_{p}italic_M start_POSTSUBSCRIPT italic_p end_POSTSUBSCRIPT be the large target model with distribution p𝑝pitalic_p. Different from single-token greedy decoding, multi-token joint greedy decoding (MJGD) generates the next γisubscript𝛾𝑖\gamma_{i}italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT tokens at step i𝑖iitalic_i based on their joint conditional probability p(xt+1:t+γi|x1:t)𝑝conditionalsubscript𝑥:𝑡1𝑡subscript𝛾𝑖subscript𝑥:1𝑡p(x_{t+1:t+\gamma_{i}}|x_{1:t})italic_p ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT )222In the paper, we omit prefix𝑝𝑟𝑒𝑓𝑖𝑥prefixitalic_p italic_r italic_e italic_f italic_i italic_x when there is no ambiguity., where γisubscript𝛾𝑖\gamma_{i}italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT is an integer no less than 1 and t=i=1i1γi𝑡superscriptsubscriptsuperscript𝑖1𝑖1subscript𝛾superscript𝑖t=\sum_{i^{\prime}=1}^{i-1}{\gamma_{i^{\prime}}}italic_t = ∑ start_POSTSUBSCRIPT italic_i start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_i - 1 end_POSTSUPERSCRIPT italic_γ start_POSTSUBSCRIPT italic_i start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT, i.e., the total tokens generated in the previous i1𝑖1i-1italic_i - 1 steps.

Refer to caption
Refer to caption
Figure 1: Perplexity and rougeL score of the output when γi=Ksubscript𝛾𝑖𝐾\gamma_{i}=Kitalic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = italic_K for MJGD with OPT-125M and Llama-2-68M fined-tuned on ChatGPT-Prompts [26] dataset.

Greedy decoding is a special case of MJGD where γi=1subscript𝛾𝑖1\gamma_{i}=1italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = 1, ifor-all𝑖\forall i∀ italic_i. When γ1=Nsubscript𝛾1𝑁\gamma_{1}=Nitalic_γ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT = italic_N, MJGD generates the optimal sequence by returning the sequence with the highest likelihood. So intuitively, output perplexity should improve as γisubscript𝛾𝑖\gamma_{i}italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT increases. Additionally, generating γisubscript𝛾𝑖\gamma_{i}italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT tokens simultaneously allows MJGD to return the optimal γisubscript𝛾𝑖\gamma_{i}italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT tokens for step i𝑖iitalic_i. In contrast, vanilla greedy decoding selects each token without considering any future tokens. So MJGD is less prone to choosing local optima than greedy decoding.

Empirical evidence supports our claim. Figure 1 shows the output perplexity and downstream ROUGE-L scores of MJGD with γisubscript𝛾𝑖\gamma_{i}italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT set to a constant K𝐾Kitalic_K, where K=1,,5𝐾15K=1,\ldots,5italic_K = 1 , … , 5. Notice that setting K=1𝐾1K=1italic_K = 1 is equivalent to vanilla greedy decoding. We use beam decoding to approximate the argmax\arg\maxroman_arg roman_max sampling from the joint distribution p(xt+1:t+K|x+1:t,prefix)p(x_{t+1:t+K}|x+{1:t},prefix)italic_p ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_K end_POSTSUBSCRIPT | italic_x + 1 : italic_t , italic_p italic_r italic_e italic_f italic_i italic_x ). We can see that when K𝐾Kitalic_K increases from 1 to 2, there is a significant improvement in perplexity. The perplexity continues to decrease when K𝐾Kitalic_K increases further. It confirms our claim that increasing γisubscript𝛾𝑖\gamma_{i}italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT will increase the output perplexity. So i=1np~(xΓi1+1:Γi|prefix,x1:Γi1)superscriptsubscriptproduct𝑖1𝑛~𝑝conditionalsubscript𝑥:subscriptΓ𝑖11subscriptΓ𝑖𝑝𝑟𝑒𝑓𝑖𝑥subscript𝑥:1subscriptΓ𝑖1\prod_{i=1}^{n}\tilde{p}(x_{\Gamma_{i-1}+1:\Gamma_{i}}|prefix,x_{1:\Gamma_{i-1% }})∏ start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_n end_POSTSUPERSCRIPT over~ start_ARG italic_p end_ARG ( italic_x start_POSTSUBSCRIPT roman_Γ start_POSTSUBSCRIPT italic_i - 1 end_POSTSUBSCRIPT + 1 : roman_Γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT | italic_p italic_r italic_e italic_f italic_i italic_x , italic_x start_POSTSUBSCRIPT 1 : roman_Γ start_POSTSUBSCRIPT italic_i - 1 end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) is a better approximation to p~(x1:N|prefix)~𝑝conditionalsubscript𝑥:1𝑁𝑝𝑟𝑒𝑓𝑖𝑥\tilde{p}(x_{1:N}|prefix)over~ start_ARG italic_p end_ARG ( italic_x start_POSTSUBSCRIPT 1 : italic_N end_POSTSUBSCRIPT | italic_p italic_r italic_e italic_f italic_i italic_x ) than t=1Np~(xt|prefix,x1:t1)superscriptsubscriptproduct𝑡1𝑁~𝑝conditionalsubscript𝑥𝑡𝑝𝑟𝑒𝑓𝑖𝑥subscript𝑥:1𝑡1\prod_{t=1}^{N}\tilde{p}(x_{t}|prefix,x_{1:t-1})∏ start_POSTSUBSCRIPT italic_t = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT over~ start_ARG italic_p end_ARG ( italic_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT | italic_p italic_r italic_e italic_f italic_i italic_x , italic_x start_POSTSUBSCRIPT 1 : italic_t - 1 end_POSTSUBSCRIPT ). Here Γ0=0subscriptΓ00\Gamma_{0}=0roman_Γ start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT = 0, and Γi=j=1iγisubscriptΓ𝑖superscriptsubscript𝑗1𝑖subscript𝛾𝑖\Gamma_{i}=\sum_{j=1}^{i}\gamma_{i}roman_Γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = ∑ start_POSTSUBSCRIPT italic_j = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_i end_POSTSUPERSCRIPT italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT for i1𝑖1i\geq 1italic_i ≥ 1, and n𝑛nitalic_n is the total number of steps of MJGD.

Moreover, the ROUGE-L score also improves with K𝐾Kitalic_K, supporting our claim that better perplexity reflects enhanced performance in downstream tasks. This demonstrates that MJGD not only achieves lower perplexity but also yields higher quality text generation, making it a superior decoding strategy compared to vanilla greedy decoding333Additional discussions comparing MJGD and greedy decoding are provided in the Appendix F..

3.2 Multi-Token Joint Speculative Decoding

Unfortunately, the computation cost of MJGD is infeasible in practice, since the time and space complexity to compute the joint distribution of γisubscript𝛾𝑖\gamma_{i}italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT tokens is |V|γisuperscript𝑉subscript𝛾𝑖|V|^{\gamma_{i}}| italic_V | start_POSTSUPERSCRIPT italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUPERSCRIPT. Inspired by speculative decoding and the facts that “even when a small model is an order of magnitude smaller than a large model, only a small fraction of the small model’s prediction deviate from those of the large model” [20, 17], we propose multi-token joint speculative decoding (MJSD), which exploits a small model Mqsubscript𝑀𝑞M_{q}italic_M start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT to accelerate MJGD approximately. The core idea is to (1) use the joint distribution q(xt+1:t+γi|x1:t)𝑞conditionalsubscript𝑥:𝑡1𝑡subscript𝛾𝑖subscript𝑥:1𝑡q(x_{t+1:t+\gamma_{i}}|x_{1:t})italic_q ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) output by Mqsubscript𝑀𝑞M_{q}italic_M start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT to approximate p(xt+1:t+γi|x1:t)𝑝conditionalsubscript𝑥:𝑡1𝑡subscript𝛾𝑖subscript𝑥:1𝑡p(x_{t+1:t+\gamma_{i}}|x_{1:t})italic_p ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT )444It is also valid to approximate p~~𝑝\tilde{p}over~ start_ARG italic_p end_ARG with q~~𝑞\tilde{q}over~ start_ARG italic_q end_ARG. Without loss of generality, we consider non-warped distribution in the illustration of MJSD. and generate γ𝛾\gammaitalic_γ draft tokens from q(xt+1:t+γi|x1:t)𝑞conditionalsubscript𝑥:𝑡1𝑡subscript𝛾𝑖subscript𝑥:1𝑡q(x_{t+1:t+\gamma_{i}}|x_{1:t})italic_q ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ), then (2) use the large model to validate draft tokens in parallel and accept the longest draft sub-sequence that passes verification, and (3) sample an additional token from the distribution of the large model without extra overhead to ensure at least one token is generated at each iteration. However, it is still infeasible to directly generate draft tokens from the joint distribution q(xt+1:t+γi|x1:t)𝑞conditionalsubscript𝑥:𝑡1𝑡subscript𝛾𝑖subscript𝑥:1𝑡q(x_{t+1:t+\gamma_{i}}|x_{1:t})italic_q ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ). So we propose to further approximate this process with beam decoding, which is an effective and efficient algorithm to generate sequences with high likelihood. In this way, MJSD reduces the number of runs of the large model to generate N𝑁Nitalic_N tokens, thus accelerating the inference in the same way as vanilla speculative decoding does. Algorithm 1 in the Appendix illustrates the pseudocode of MJSD algorithm.

Refer to caption
Figure 2: An example of MJSD’s verification process. MJSD accepts the longest draft sub-sequence that passes verification based on joint likelihood.
Draft Tokens Verification

Figure 2 illustrates the verification process of MJSD. Let xt+1,,xt+γsubscript𝑥𝑡1subscript𝑥𝑡𝛾x_{t+1},\ldots,x_{t+\gamma}italic_x start_POSTSUBSCRIPT italic_t + 1 end_POSTSUBSCRIPT , … , italic_x start_POSTSUBSCRIPT italic_t + italic_γ end_POSTSUBSCRIPT be the draft tokens generated by beam decoding with the small model. Since beam decoding is a widely recognized algorithm to generate sequences with high overall likelihood [18], it is reasonable to assume q(xt+1:t+γ|x1:t)𝑞conditionalsubscript𝑥:𝑡1𝑡𝛾subscript𝑥:1𝑡q(x_{t+1:t+\gamma}|x_{1:t})italic_q ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) is large. Also, since beam decoding works in an autoregressive way, we can also assume that j{1,,γ}for-all𝑗1𝛾\forall j\in\{1,\ldots,\gamma\}∀ italic_j ∈ { 1 , … , italic_γ }, q(xt+1:t+j|x1:t)𝑞conditionalsubscript𝑥:𝑡1𝑡𝑗subscript𝑥:1𝑡q(x_{t+1:t+j}|x_{1:t})italic_q ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_j end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) is large. To approximate MJGD, for each step i𝑖iitalic_i, MJSD needs to ensure the accepted tokens xt+1:t+γisubscript𝑥:𝑡1𝑡subscript𝛾𝑖x_{t+1:t+\gamma_{i}}italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT (0γiγ0subscript𝛾𝑖𝛾0\leq\gamma_{i}\leq\gamma0 ≤ italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ≤ italic_γ) also have high joint likelihood with the large model Mpsubscript𝑀𝑝M_{p}italic_M start_POSTSUBSCRIPT italic_p end_POSTSUBSCRIPT. So MJSD first computes the joint likelihood p(xt+1:t+j|x1:t)𝑝conditionalsubscript𝑥:𝑡1𝑡𝑗subscript𝑥:1𝑡p(x_{t+1:t+j}|x_{1:t})italic_p ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_j end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) for j=1,,γ𝑗1𝛾j=1,\ldots,\gammaitalic_j = 1 , … , italic_γ. Then for each prefix sub-sequence xt+1:t+jsubscript𝑥:𝑡1𝑡𝑗x_{t+1:t+j}italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_j end_POSTSUBSCRIPT, it passes verification if and only if min(1,p(xt+1:t+j|x1:t)q(xt+1:t+j|x1:t))>τ1𝑝conditionalsubscript𝑥:𝑡1𝑡𝑗subscript𝑥:1𝑡𝑞conditionalsubscript𝑥:𝑡1𝑡𝑗subscript𝑥:1𝑡𝜏\min(1,\frac{p(x_{t+1:t+j}|x_{1:t})}{q(x_{t+1:t+j}|x_{1:t})})>\tauroman_min ( 1 , divide start_ARG italic_p ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_j end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) end_ARG start_ARG italic_q ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_j end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) end_ARG ) > italic_τ, where τ[0,1)𝜏01\tau\in[0,1)italic_τ ∈ [ 0 , 1 ) is a pre-defined threshold. Notice that if min(1,p(xt+1:t+j|x1:t)q(xt+1:t+j|x1:t))>τ1𝑝conditionalsubscript𝑥:𝑡1𝑡𝑗subscript𝑥:1𝑡𝑞conditionalsubscript𝑥:𝑡1𝑡𝑗subscript𝑥:1𝑡𝜏\min(1,\frac{p(x_{t+1:t+j}|x_{1:t})}{q(x_{t+1:t+j}|x_{1:t})})>\tauroman_min ( 1 , divide start_ARG italic_p ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_j end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) end_ARG start_ARG italic_q ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_j end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) end_ARG ) > italic_τ, we have p(xt+1:t+j|x1:t)q(xt+1:t+j|x1:t)>τ𝑝conditionalsubscript𝑥:𝑡1𝑡𝑗subscript𝑥:1𝑡𝑞conditionalsubscript𝑥:𝑡1𝑡𝑗subscript𝑥:1𝑡𝜏\frac{p(x_{t+1:t+j}|x_{1:t})}{q(x_{t+1:t+j}|x_{1:t})}>\taudivide start_ARG italic_p ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_j end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) end_ARG start_ARG italic_q ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_j end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) end_ARG > italic_τ, which means q(xt+1:t+j|x1:t)p(xt+1:t+j|x1:t)p(xt+1:t+j|x1:t)<1τ1𝑞conditionalsubscript𝑥:𝑡1𝑡𝑗subscript𝑥:1𝑡𝑝conditionalsubscript𝑥:𝑡1𝑡𝑗subscript𝑥:1𝑡𝑝conditionalsubscript𝑥:𝑡1𝑡𝑗subscript𝑥:1𝑡1𝜏1\frac{q(x_{t+1:t+j}|x_{1:t})-p(x_{t+1:t+j}|x_{1:t})}{p(x_{t+1:t+j}|x_{1:t})}<% \frac{1}{\tau}-1divide start_ARG italic_q ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_j end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) - italic_p ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_j end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) end_ARG start_ARG italic_p ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_j end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) end_ARG < divide start_ARG 1 end_ARG start_ARG italic_τ end_ARG - 1. Therefore, our acceptance policy guarantees that when q(xt+1:t+j|x1:t)>p(xt+1:t+j|x1:t)𝑞conditionalsubscript𝑥:𝑡1𝑡𝑗subscript𝑥:1𝑡𝑝conditionalsubscript𝑥:𝑡1𝑡𝑗subscript𝑥:1𝑡q(x_{t+1:t+j}|x_{1:t})>p(x_{t+1:t+j}|x_{1:t})italic_q ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_j end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) > italic_p ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_j end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ), the relative error is bounded. And if q(xt+1:t+j|x1:t)p(xt+1:t+j|x1:t)𝑞conditionalsubscript𝑥:𝑡1𝑡𝑗subscript𝑥:1𝑡𝑝conditionalsubscript𝑥:𝑡1𝑡𝑗subscript𝑥:1𝑡q(x_{t+1:t+j}|x_{1:t})\leq p(x_{t+1:t+j}|x_{1:t})italic_q ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_j end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) ≤ italic_p ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_j end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ), it means the sub-sequence has higher likelihood in the large model, then it is reasonable to accept it. After verifying all the sub-sequences, MJSD accepts the longest draft sub-sequence that passes verification.

The verification step of MJSD ensures that the accepted tokens have a high joint likelihood with the large model. We have shown that selecting multiple tokens based on their joint likelihood lead to better output perplexity. Thus, MJSD is more effective than greedy decoding and vanilla speculative decoding. Furthermore, since MJSD accepts the longest draft sub-sequence with high likelihood, it can tolerate low-quality tokens as long as the joint likelihood is high. So at each iteration, MJSD can accept more draft tokens than vanilla speculative decoding, which results in better efficiency.

4 Energy Efficiency Analysis

Previous studies [20, 5, 17, 32] only focus on the speed of speculative decoding. However, an equally important consideration is energy consumption. To our knowledge, there is no existing work evaluating the impact of speculative decoding on inference energy consumption. Although MJSD and speculative decoding increase the number of FLOPs due to the involvement of an additional small model and the rollback operation, they concurrently reduce the inference time and memory operations, which are key factors of GPU energy consumption [1, 6]. Consequently, it poses an open question regarding whether speculative decoding increases or decreases overall energy consumption.

To understand the net effect of speculative decoding, we decompose the total energy consumption into two parts following [1]:

Etotal=PWflopTflop+PWmemTmemsubscript𝐸𝑡𝑜𝑡𝑎𝑙𝑃subscript𝑊𝑓𝑙𝑜𝑝subscript𝑇𝑓𝑙𝑜𝑝𝑃subscript𝑊𝑚𝑒𝑚subscript𝑇𝑚𝑒𝑚E_{total}=PW_{flop}T_{flop}+PW_{mem}T_{mem}italic_E start_POSTSUBSCRIPT italic_t italic_o italic_t italic_a italic_l end_POSTSUBSCRIPT = italic_P italic_W start_POSTSUBSCRIPT italic_f italic_l italic_o italic_p end_POSTSUBSCRIPT italic_T start_POSTSUBSCRIPT italic_f italic_l italic_o italic_p end_POSTSUBSCRIPT + italic_P italic_W start_POSTSUBSCRIPT italic_m italic_e italic_m end_POSTSUBSCRIPT italic_T start_POSTSUBSCRIPT italic_m italic_e italic_m end_POSTSUBSCRIPT (2)

where PWflop,PWmem𝑃subscript𝑊𝑓𝑙𝑜𝑝𝑃subscript𝑊𝑚𝑒𝑚PW_{flop},PW_{mem}italic_P italic_W start_POSTSUBSCRIPT italic_f italic_l italic_o italic_p end_POSTSUBSCRIPT , italic_P italic_W start_POSTSUBSCRIPT italic_m italic_e italic_m end_POSTSUBSCRIPT denote the power (energy/second) of FLOPs and memory operations, and Tflop,Tmemsubscript𝑇𝑓𝑙𝑜𝑝subscript𝑇𝑚𝑒𝑚T_{flop},T_{mem}italic_T start_POSTSUBSCRIPT italic_f italic_l italic_o italic_p end_POSTSUBSCRIPT , italic_T start_POSTSUBSCRIPT italic_m italic_e italic_m end_POSTSUBSCRIPT denote the time spent on these operations. When input batch size increases, PWflop𝑃subscript𝑊𝑓𝑙𝑜𝑝PW_{flop}italic_P italic_W start_POSTSUBSCRIPT italic_f italic_l italic_o italic_p end_POSTSUBSCRIPT increases until it reaches the power of maximum FLOPs, denoted as PWflop𝑃subscriptsuperscript𝑊𝑓𝑙𝑜𝑝PW^{*}_{flop}italic_P italic_W start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_f italic_l italic_o italic_p end_POSTSUBSCRIPT. Meanwhile, PWmem𝑃subscript𝑊𝑚𝑒𝑚PW_{mem}italic_P italic_W start_POSTSUBSCRIPT italic_m italic_e italic_m end_POSTSUBSCRIPT is irrelevant to the input batch size, as it only depends on the memory hardware.

To determine the relative magnitude relationship between PWflop𝑃subscript𝑊𝑓𝑙𝑜𝑝PW_{flop}italic_P italic_W start_POSTSUBSCRIPT italic_f italic_l italic_o italic_p end_POSTSUBSCRIPT and PWmem𝑃subscript𝑊𝑚𝑒𝑚PW_{mem}italic_P italic_W start_POSTSUBSCRIPT italic_m italic_e italic_m end_POSTSUBSCRIPT, we first point out the fact that GPU memory operations in LLM inference are dominated by accessing off-chip global memory, which consumes about 100×100\times100 × of energy compared to accessing on-chip shared memory [16]. It is because each multiprocessor on GPU usually has 64KB of on-chip memory shared by multiple threads, while storing a single layer of LLM, say T5-11b [25], requires about 1GB memory. Moreover, Allen and Ge showed that doing sequential read from off-chip memory consumes 20202020-30%percent3030\%30 % more power than running maximum FLOPs [1]. So we have PWmem>PWflopPWflop𝑃subscript𝑊𝑚𝑒𝑚𝑃subscriptsuperscript𝑊𝑓𝑙𝑜𝑝𝑃subscript𝑊𝑓𝑙𝑜𝑝PW_{mem}>PW^{*}_{flop}\geq PW_{flop}italic_P italic_W start_POSTSUBSCRIPT italic_m italic_e italic_m end_POSTSUBSCRIPT > italic_P italic_W start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_f italic_l italic_o italic_p end_POSTSUBSCRIPT ≥ italic_P italic_W start_POSTSUBSCRIPT italic_f italic_l italic_o italic_p end_POSTSUBSCRIPT. Notice that PWflop=PWflop𝑃subscriptsuperscript𝑊𝑓𝑙𝑜𝑝𝑃subscript𝑊𝑓𝑙𝑜𝑝PW^{*}_{flop}=PW_{flop}italic_P italic_W start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_f italic_l italic_o italic_p end_POSTSUBSCRIPT = italic_P italic_W start_POSTSUBSCRIPT italic_f italic_l italic_o italic_p end_POSTSUBSCRIPT only if the batch size reaches the maximum parallelization capacity of GPUs. During greedy decoding and speculative decoding, the batch size is usually small [20]. So most of the computing power is not utilized [20], which means PWmemPWflopmuch-greater-than𝑃subscript𝑊𝑚𝑒𝑚𝑃subscript𝑊𝑓𝑙𝑜𝑝PW_{mem}\gg PW_{flop}italic_P italic_W start_POSTSUBSCRIPT italic_m italic_e italic_m end_POSTSUBSCRIPT ≫ italic_P italic_W start_POSTSUBSCRIPT italic_f italic_l italic_o italic_p end_POSTSUBSCRIPT.

In addition, previous studies have shown that during LLM inference TmemTflopmuch-greater-thansubscript𝑇𝑚𝑒𝑚subscript𝑇𝑓𝑙𝑜𝑝T_{mem}\gg T_{flop}italic_T start_POSTSUBSCRIPT italic_m italic_e italic_m end_POSTSUBSCRIPT ≫ italic_T start_POSTSUBSCRIPT italic_f italic_l italic_o italic_p end_POSTSUBSCRIPT [20]. Therefore, the energy induced by memory operations, i.e., PWmemTmem𝑃subscript𝑊𝑚𝑒𝑚subscript𝑇𝑚𝑒𝑚PW_{mem}T_{mem}italic_P italic_W start_POSTSUBSCRIPT italic_m italic_e italic_m end_POSTSUBSCRIPT italic_T start_POSTSUBSCRIPT italic_m italic_e italic_m end_POSTSUBSCRIPT dominates Etotalsubscript𝐸𝑡𝑜𝑡𝑎𝑙E_{total}italic_E start_POSTSUBSCRIPT italic_t italic_o italic_t italic_a italic_l end_POSTSUBSCRIPT. Since speculative decoding reduces Tmemsubscript𝑇𝑚𝑒𝑚T_{mem}italic_T start_POSTSUBSCRIPT italic_m italic_e italic_m end_POSTSUBSCRIPT by reducing the number of runs of the large model, it should reduce the inference energy consumption to a similar extent as it reduces time consumption.

To validate our hypothesis, we conducted an experiment to evaluate how batch size influences energy consumption during inference. We ran OPT-13b models on 2 Nvidia A10 GPUs with 24GB memory. Fixing the total number of runs of the large model while varying the input batch size b{1,2,4,8,16}𝑏124816b\in\{1,2,4,8,16\}italic_b ∈ { 1 , 2 , 4 , 8 , 16 } for each run, we measured time and energy cost. The details of energy measurement are illustrated in the Appendix B. Table 1 shows the results. As batch size doubles, although the number of FLOPs doubles, the energy consumption per run increases slightly. This observation demonstrates that PWmemTmem𝑃subscript𝑊𝑚𝑒𝑚subscript𝑇𝑚𝑒𝑚PW_{mem}T_{mem}italic_P italic_W start_POSTSUBSCRIPT italic_m italic_e italic_m end_POSTSUBSCRIPT italic_T start_POSTSUBSCRIPT italic_m italic_e italic_m end_POSTSUBSCRIPT dominates Etotalsubscript𝐸𝑡𝑜𝑡𝑎𝑙E_{total}italic_E start_POSTSUBSCRIPT italic_t italic_o italic_t italic_a italic_l end_POSTSUBSCRIPT. In the next section, we will show more direct evidence for the energy efficiency of MJSD and other speculative decoding algorithms.

Table 1: The effect of batch size to inference speed and energy consumption. The number of inputs is the product of the number of LLM runs and input batch size.
Batch Size Energy (J) Energy/run (J) Energy/Input (J) Time (s) Time/run (s) Time/input(s)
1 42,450 14.1 14.1 1,129 0.376 0.376
2 49,621 16.5 8.26 1,191 0.397 0.198
4 53,325 17.7 4.43 1,178 0.392 0.098
8 59,210 19.7 2.46 1,211 0.403 0.050
16 74,058 24.7 1.54 1,255 0.418 0.026

5 Experiments

Datasets and Models.We use three public datasets for evaluation: (1) ChatGPT-Prompt [26], (2) ChatAlpaca [3], and (3) CNN Dailymail [30]. The three datasets represent three common tasks for LLMs: instruction, multi-turn chat, and summarization. Table 4 in the Appendix shows more details of the datasets. We use two pairs of small and large models in our experiments: OPT-125m and OPT-13B [38], and Llama-68M [24] and Llama-2-13B [34].

Baselines. For each pair of small and large models, we compare our method with greedy decoding (greedy), and three speculative decoding methods: vanilla speculative decoding (speculative[19, 5], Spectr [32], and BiLD [17]. Our implementation of MJSD and all the baselines are based on a public implementation of speculative decoding [2]. We also tried to compare with SpecInfer [24] using their released implementation. But it runs out of memory on our machine. For each method, we let it generate at most 128 tokens for each input and run it for 1,00010001,0001 , 000 seconds. We open-sourced our code for reproduction.

More details of the hyper-parameters, warping operations, and machine configurations of the experiments can be found in the Appendix B, C, and D .

Table 2: Inference efficiency and output perplexity of different methods on ChatGPT-Prompt (CP), ChatAlpaca (CA), and CNNDailyMail (CD) datasets. Bold numbers mark the best result, underlined numbers mark the second best.
\hlineB3 greedy speculative BiLD Spectr MJSD
\hlineB3 CP Llama-2 speed (token/s) \uparrow 22.6±plus-or-minus\pm±0.03 36.8±plus-or-minus\pm±0.53 34.4±plus-or-minus\pm±0.87 45.1±plus-or-minus\pm±1.32 63.0±plus-or-minus\pm±0.20
speed up \uparrow 1.00±plus-or-minus\pm±0.00 1.63±plus-or-minus\pm±0.02 1.52±plus-or-minus\pm±0.04 1.99±plus-or-minus\pm±0.06 2.78±plus-or-minus\pm±0.01
energy (J/token) \downarrow 11.7±plus-or-minus\pm±1.56 6.62±plus-or-minus\pm±0.91 7.45±plus-or-minus\pm±0.90 5.17±plus-or-minus\pm±0.88 3.38±plus-or-minus\pm±0.02
perplexity \downarrow 3.74±plus-or-minus\pm±0.14 3.64±plus-or-minus\pm±0.11 3.15±plus-or-minus\pm±0.06 3.64±plus-or-minus\pm±0.08 2.06±plus-or-minus\pm±0.06
OPT speed (token/s) \uparrow 22.4±plus-or-minus\pm±0.48 33.8±plus-or-minus\pm±2.47 31.5±plus-or-minus\pm±1.87 38.0±plus-or-minus\pm±2.20 55.8±plus-or-minus\pm±0.30
speed up\uparrow 1.00±plus-or-minus\pm±0.00 1.51±plus-or-minus\pm±0.08 1.41±plus-or-minus\pm±0.06 1.70±plus-or-minus\pm±0.06 2.50±plus-or-minus\pm±0.05
energy (J/token) \downarrow 13.2±plus-or-minus\pm±0.28 7.48±plus-or-minus\pm±0.07 8.75±plus-or-minus\pm±0.13 6.08±plus-or-minus\pm±0.11 3.61±plus-or-minus\pm±0.03
perplexity \downarrow 5.49±plus-or-minus\pm±0.15 5.47±plus-or-minus\pm±0.11 4.51±plus-or-minus\pm±0.09 5.27±plus-or-minus\pm±0.09 3.00±plus-or-minus\pm±0.09
CA Llama-2 speed (token/s) \uparrow 22.0±plus-or-minus\pm±0.15 31.6±plus-or-minus\pm±0.35 28.8±plus-or-minus\pm±0.20 27.7±plus-or-minus\pm±0.29 44.1±plus-or-minus\pm±0.25
speed up \uparrow 1.00±plus-or-minus\pm±0.00 1.43±plus-or-minus\pm±0.02 1.31±plus-or-minus\pm±0.01 1.26±plus-or-minus\pm±0.00 2.00±plus-or-minus\pm±0.02
energy (J/token) \downarrow 11.2±plus-or-minus\pm±0.16 6.98±plus-or-minus\pm±0.15 7.99±plus-or-minus\pm±0.15 7.20±plus-or-minus\pm±0.08 4.72±plus-or-minus\pm±0.03
perplexity \downarrow 2.11±plus-or-minus\pm±0.01 2.13±plus-or-minus\pm±0.03 1.95±plus-or-minus\pm±0.03 2.15±plus-or-minus\pm±0.01 1.88±plus-or-minus\pm±0.05
OPT speed (token/s) \uparrow 23.8±plus-or-minus\pm±0.10 35.6±plus-or-minus\pm±0.45 38.5±plus-or-minus\pm±0.93 28.4±plus-or-minus\pm±0.34 49.6±plus-or-minus\pm±0.42
speed up\uparrow 1.00±plus-or-minus\pm±0.00 1.49±plus-or-minus\pm±0.01 1.62±plus-or-minus\pm±0.03 1.19±plus-or-minus\pm±0.01 2.08±plus-or-minus\pm±0.03
energy (J/token) \downarrow 11.3±plus-or-minus\pm±0.22 5.74±plus-or-minus\pm±0.11 5.12±plus-or-minus\pm±0.06 6.24±plus-or-minus\pm±0.11 4.03±plus-or-minus\pm±0.02
perplexity \downarrow 3.28±plus-or-minus\pm±0.06 3.32±plus-or-minus\pm±0.10 2.60±plus-or-minus\pm±0.06 3.16±plus-or-minus\pm±0.06 2.07±plus-or-minus\pm±0.03
CD Llama-2 speed (token/s) \uparrow 21.7±plus-or-minus\pm±0.08 30.7±plus-or-minus\pm±0.18 30.5±plus-or-minus\pm±0.21 25.0±plus-or-minus\pm±0.31 44.2±plus-or-minus\pm±0.99
speed up \uparrow 1.00±plus-or-minus\pm±0.00 1.41±plus-or-minus\pm±0.00 1.41±plus-or-minus\pm±0.01 1.15±plus-or-minus\pm±0.01 2.04±plus-or-minus\pm±0.05
energy (J/token) \downarrow 11.3±plus-or-minus\pm±0.21 7.07±plus-or-minus\pm±0.19 7.41±plus-or-minus\pm±0.16 8.22±plus-or-minus\pm±0.19 4.80±plus-or-minus\pm±0.12
perplexity \downarrow 2.88±plus-or-minus\pm±0.04 2.87±plus-or-minus\pm±0.08 2.93±plus-or-minus\pm±0.03 3.06±plus-or-minus\pm±0.11 2.63±plus-or-minus\pm±0.10
OPT speed (token/s) \uparrow 23.3±plus-or-minus\pm±0.81 31.7±plus-or-minus\pm±0.91 30.9±plus-or-minus\pm±0.80 23.7±plus-or-minus\pm±0.40 43.6±plus-or-minus\pm±0.33
speed up\uparrow 1.00±plus-or-minus\pm±0.00 1.36±plus-or-minus\pm±0.04 1.32±plus-or-minus\pm±0.04 1.01±plus-or-minus\pm±0.02 1.87±plus-or-minus\pm±0.02
energy (J/token) \downarrow 11.5±plus-or-minus\pm±0.14 6.37±plus-or-minus\pm±0.11 6.71±plus-or-minus\pm±0.17 7.31±plus-or-minus\pm±0.17 4.86±plus-or-minus\pm±0.03
perplexity \downarrow 3.93±plus-or-minus\pm±0.14 3.97±plus-or-minus\pm±0.06 3.74±plus-or-minus\pm±0.09 4.04±plus-or-minus\pm±0.07 3.17±plus-or-minus\pm±0.06
\hlineB3

5.1 Comparison with Baselines

Table 2 shows the main results of our experiments. The standard deviations in the table are computed by repeating the experiment four times. First, we compare the output perplexity of different algorithms. We can see that the perplexity of greedy decoding, vanilla speculative decoding, and Spectr decoding are close because they have identical sampling distributions for each token. Meanwhile, BiLD approximates the sampling distribution of greedy decoding but yields better perplexity. It is because we set a strict acceptance threshold for BiLD decoding, which lowers the acceptance rate but ensures every token has a high probability in the large model. It can be viewed as having an additional warping operation towards the sampling distribution. Nevertheless, there is a significant gap between the perplexity of MJSD and other baselines. On average, the output perplexity of MJSD is 21.2%percent21.221.2\%21.2 % lower than that of BiLD. It provides strong evidence that MJSD has better effectiveness than existing speculative decoding methods that aim to have identical sampling distribution as greedy decoding.

Next, we compare the speeds of different algorithms. While the speeds of all speculative decoding baselines are close, MJSD is significantly faster. On average, MJSD is 2.21×2.21\times2.21 × faster than greedy decoding, 1.49×1.49\times1.49 × faster than vanilla speculative decoding, 1.55×1.55\times1.55 × faster than BiLD, and 1.64×1.64\times1.64 × faster than Spectr. The speed improvement is because MJSD has a larger average acceptance length at each iteration, as MJSD accepts the longest sub-sequence that passes verification. Figure 3 shows the average acceptance length of vanilla speculative decoding and MJSD. We can see that the acceptance length of MJSD is on average 2.15×2.15\times2.15 × larger than that of vanilla speculative decoding.

Finally, we focus on energy efficiency. We can see that all speculative decoding algorithms have significantly better energy efficiency over greedy decoding. It supports our analysis that speculative decoding can improve the LLM inference energy efficiency by reducing the number of memory operations. In addition, we can see that the energy efficiency of MJSD is significantly better than that of all baseline algorithms. We believe it is also because of the longer acceptance sequence at each iteration. The number of memory operations is proportional to the number of LLM runs, which is inversely proportional to the average acceptance length. Since by always accepting the longest prefix that passes verification, the energy consumption of MJSD is on average 1.62×1.62\times1.62 × smaller than that of vanilla speculative decoding, and 2.84×2.84\times2.84 × smaller than that of greedy decoding.

Refer to caption
Figure 3: Average number of accepted tokens per iteration of speculative decoding and MJSD. Each iteration generates 4 draft tokens.

5.2 Ablation Study

Table 3: Effect of number of beams to the inference performance on ChatGPT-Prompts dataset.
# beams 2 4 6 8
speed (token/s) \uparrow 55.9 59.9 60.2 61.3
Llama-2 energy (J/token) \downarrow 2.43 2.25 2.22 2.20
perplexity \downarrow 2.44 2.12 2.14 2.10
speed (token/s) \uparrow 51.0 54.1 54.3 55.9
OPT energy (J/token) \downarrow 2.50 2.32 2.36 2.30
perplexity \downarrow 3.63 3.16 3.42 3.19

5.2.1 Effects of Number of Beams

First, we investigate how the number of beams used in the beam decoding of the small model affects the inference performance. Table 3 shows the results. Increasing the number of beams improves the quality of the draft tokens, which not only improves the output perplexity but also increases the average acceptance length and hence leads to better efficiency. But we can see that the increment slows down when the number of beams is large enough. In addition, when the number of beams is too large, the inference cost of the small model will become too high.

5.2.2 Effects of Acceptance Threshold

Refer to caption
Refer to caption
Figure 4: Effect of acceptance threshold on output perplexity and decoding speed on ChatGPT-Prompts dataset.

Next, we evaluate the effect of acceptance threshold τ𝜏\tauitalic_τ. Intuitively, when we increase τ𝜏\tauitalic_τ from 0 to 1, the acceptance criterion becomes more strict, the efficiency drops while the output perplexity increases. Surprisingly, this expectation is only partially correct. As shown in Figure 4, the efficiency indeed drops when τ𝜏\tauitalic_τ increases. However, the perplexity increases when τ𝜏\tauitalic_τ is close to 1. When τ=1𝜏1\tau=1italic_τ = 1, all the draft tokens are rejected, which makes MJSD equivalent to greedy decoding. Similarly, when τ𝜏\tauitalic_τ is close to 1, the advantage of multi-token joint greedy decoding on effectiveness disappears, hence the perplexity becomes close to the perplexity of greedy decoding. Another surprising observation is that the perplexity of MJSD is good when τ=0𝜏0\tau=0italic_τ = 0. When τ=0𝜏0\tau=0italic_τ = 0, MJSD is equivalent to generating γ𝛾\gammaitalic_γ tokens using beam decoding with the small model, then generating an additional token with the large model. The fact that MJSD achieves good perplexity when τ=0𝜏0\tau=0italic_τ = 0 can be explained by the fact that “even when a small model is an order of magnitude smaller than a large model, only a small fraction of the small model’s predictions deviate from those of the large model” [17, 20]. Moreover, when τ𝜏\tauitalic_τ ranges from 0.10.10.10.1 to 0.90.90.90.9, the performance of MJSD is relatively stable, suggesting that MJSD is not sensitive to the acceptance threshold.

6 Related Work

Efficient Decoding Inference. There are extensive studies on improving large model inference efficiency. Well-known methods include model quantization [9, 22], model pruning [10, 28], and model distillation [15]. Despite achieving significant speed-ups, a common drawback of these methods is that they have to sacrifice the model’s effectiveness.

A direction closer to our work is non-autoregressive decoding. It is first proposed by [12] to generate multiple tokens in parallel. That is, the model simultaneously predicts p(xt+k|x1:t)𝑝conditionalsubscript𝑥𝑡𝑘subscript𝑥:1𝑡p(x_{t+k}|x_{1:t})italic_p ( italic_x start_POSTSUBSCRIPT italic_t + italic_k end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) (k=1,2,𝑘12k=1,2,\ldotsitalic_k = 1 , 2 , …). Subsequent studies further improved the performance of parallel decoding by incorporating additional information [35, 31, 21] or using additional iterations to refine predictions [11, 19, 13]. However, these works require continuous training of the model and usually either compromise the model effectiveness or require task-dependent techniques to achieve comparable performance [17].

Speculative Decoding. Speculative decoding was recently proposed in [20, 5] as a way to accelerate LLM inference. Spectr [32] enhances speculative decoding by letting the small model generate multiple i.i.d. draft sequences. While speculative decoding and Spectr use the large model to verify all the tokens drafted by the small model, BiLD [17] only calls the large model when the probability output by the small model is below a pre-defined threshold τ1subscript𝜏1\tau_{1}italic_τ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT. The large model rejects a token if its negative log-likelihood is larger than threshold τ2subscript𝜏2\tau_{2}italic_τ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT. SpecInfer [24] uses one or multiple small models to generate a draft token tree to increase the average acceptance length for each iteration. All these methods can be perceived as exact or approximate versions of sampling tokens from the conditional distribution p(xt|x<t)𝑝conditionalsubscript𝑥𝑡subscript𝑥absent𝑡p(x_{t}|x_{<t})italic_p ( italic_x start_POSTSUBSCRIPT italic_t end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT < italic_t end_POSTSUBSCRIPT ). Therefore, their output perplexity is bounded by greedy decoding.

An orthogonal direction to improve speculative decoding is to improve the effectiveness of the small draft model. It is obvious that if more draft tokens are accepted, the overall inference speed will increase. BiLD [17] uses a model prediction alignment technique to better train the small model. Liu et al. [23] propose online speculative decoding to continually update the draft model based on observed input data. Instead, Rest [14] uses a retrieval model to produce draft tokens. The improvement of small models can benefit all speculative decoding algorithms, including our method.

7 Conclusion

We introduce multi-token joint speculative decoding that significantly enhances output quality with better time and energy efficiency. A distinctive aspect of our work is the exploration of speculative decoding’s impact on inference energy consumption, an often neglected area in existing studies. Experiment results demonstrate our method achieves significant energy reduction. This research contributes not only a novel decoding approach but also valuable insights for optimizing LLM deployment in real-world applications where considerations of both quality and efficiency are crucial.

8 Limitations and Impact Statements

Limitations. This paper mainly considers improving the output perplexity of decoding algorithms. Under the assumption that the model is well-trained for downstream tasks, improving output perplexity usually leads to better downstream effectiveness. But if the model is not well-trained, improving the perplexity may not necessarily improve the downstream effectiveness. However, it is mainly the problem of the model itself. In real-world applications, it is reasonable to assume the model is properly trained for downstream tasks.

Impact Statements. The goal of this work is to advance the field of Large Language Model (LLM) Inference, which has received tremendous attention from both academia and industry. However, LLMs also have received many critiques, including their extremely large carbon footage emission during both training and inference. Our work pays special attention to their energy consumption during inference to provide high-quality and fast-inference LLMs with reduced energy consumption, which has become a serious concern with the rapid increase of LLM-related workloads in the past few years.

References

  • [1] Tyler Allen and Rong Ge. Characterizing power and performance of gpu memory access. In 2016 4th International Workshop on Energy Efficient Supercomputing (E2SC), pages 46–53. IEEE, 2016.
  • [2] Feifei Bear. Llmspeculativesampling. https://github.com/feifeibear/LLMSpeculativeSampling, 2024. Accessed: 2024-05-19.
  • [3] Ning Bian, Hongyu Lin, Yaojie Lu, Xianpei Han, Le Sun, and Ben He. Chatalpaca: A multi-turn dialogue corpus based on alpaca instructions. https://github.com/cascip/ChatAlpaca, 2023.
  • [4] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. Advances in neural information processing systems, 33:1877–1901, 2020.
  • [5] Charlie Chen, Sebastian Borgeaud, Geoffrey Irving, Jean-Baptiste Lespiau, Laurent Sifre, and John Jumper. Accelerating large language model decoding with speculative sampling. arXiv preprint arXiv:2302.01318, 2023.
  • [6] Jianmin Chen, Bin Li, Ying Zhang, Lu Peng, and Jih-kwon Peir. Tree structured analysis on gpu power study. In 2011 IEEE 29th International Conference on Computer Design (ICCD), pages 57–64. IEEE, 2011.
  • [7] Aakanksha Chowdhery, Sharan Narang, Jacob Devlin, Maarten Bosma, Gaurav Mishra, Adam Roberts, Paul Barham, Hyung Won Chung, Charles Sutton, Sebastian Gehrmann, et al. Palm: Scaling language modeling with pathways. Journal of Machine Learning Research, 24(240):1–113, 2023.
  • [8] Alex de Vries. The growing energy footprint of artificial intelligence. Joule, 7(10):2191–2194, 2023.
  • [9] Elias Frantar, Saleh Ashkboos, Torsten Hoefler, and Dan Alistarh. Gptq: Accurate post-training quantization for generative pre-trained transformers. arXiv preprint arXiv:2210.17323, 2022.
  • [10] Trevor Gale, Erich Elsen, and Sara Hooker. The state of sparsity in deep neural networks. arXiv preprint arXiv:1902.09574, 2019.
  • [11] Marjan Ghazvininejad, Omer Levy, Yinhan Liu, and Luke Zettlemoyer. Mask-predict: Parallel decoding of conditional masked language models. arXiv preprint arXiv:1904.09324, 2019.
  • [12] Jiatao Gu, James Bradbury, Caiming Xiong, Victor OK Li, and Richard Socher. Non-autoregressive neural machine translation. arXiv preprint arXiv:1711.02281, 2017.
  • [13] Junliang Guo, Linli Xu, and Enhong Chen. Jointly masked sequence-to-sequence model for non-autoregressive neural machine translation. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pages 376–385, 2020.
  • [14] Zhenyu He, Zexuan Zhong, Tianle Cai, Jason D Lee, and Di He. Rest: Retrieval-based speculative decoding. arXiv preprint arXiv:2311.08252, 2023.
  • [15] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015.
  • [16] Norman P Jouppi, Doe Hyun Yoon, Matthew Ashcraft, Mark Gottscho, Thomas B Jablin, George Kurian, James Laudon, Sheng Li, Peter Ma, Xiaoyu Ma, et al. Ten lessons from three generations shaped google’s tpuv4i: Industrial product. In 2021 ACM/IEEE 48th Annual International Symposium on Computer Architecture (ISCA), pages 1–14. IEEE, 2021.
  • [17] Sehoon Kim, Karttikeya Mangalam, Suhong Moon, Jitendra Malik, Michael W Mahoney, Amir Gholami, and Kurt Keutzer. Speculative decoding with big little decoder. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
  • [18] Rémi Leblond, Jean-Baptiste Alayrac, Laurent Sifre, Miruna Pislar, Jean-Baptiste Lespiau, Ioannis Antonoglou, Karen Simonyan, and Oriol Vinyals. Machine translation decoding beyond beam search. arXiv preprint arXiv:2104.05336, 2021.
  • [19] Jason Lee, Elman Mansimov, and Kyunghyun Cho. Deterministic non-autoregressive neural sequence modeling by iterative refinement. arXiv preprint arXiv:1802.06901, 2018.
  • [20] Yaniv Leviathan, Matan Kalman, and Yossi Matias. Fast inference from transformers via speculative decoding. In International Conference on Machine Learning, pages 19274–19286. PMLR, 2023.
  • [21] Zhuohan Li, Zi Lin, Di He, Fei Tian, Tao Qin, Liwei Wang, and Tie-Yan Liu. Hint-based training for non-autoregressive machine translation. arXiv preprint arXiv:1909.06708, 2019.
  • [22] Ji Lin, Jiaming Tang, Haotian Tang, Shang Yang, Xingyu Dang, and Song Han. Awq: Activation-aware weight quantization for llm compression and acceleration. arXiv preprint arXiv:2306.00978, 2023.
  • [23] Xiaoxuan Liu, Lanxiang Hu, Peter Bailis, Ion Stoica, Zhijie Deng, Alvin Cheung, and Hao Zhang. Online speculative decoding. arXiv preprint arXiv:2310.07177, 2023.
  • [24] Xupeng Miao, Gabriele Oliaro, Zhihao Zhang, Xinhao Cheng, Zeyu Wang, Rae Ying Yee Wong, Zhuoming Chen, Daiyaan Arfeen, Reyna Abhyankar, and Zhihao Jia. Specinfer: Accelerating generative llm serving with speculative inference and token tree verification. arXiv preprint arXiv:2305.09781, 1(2):4, 2023.
  • [25] Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, and Peter J Liu. Exploring the limits of transfer learning with a unified text-to-text transformer. The Journal of Machine Learning Research, 21(1):5485–5551, 2020.
  • [26] Mohamed Rashad. Chatgpt-prompts, 2023.
  • [27] Siddharth Samsi, Dan Zhao, Joseph McDonald, Baolin Li, Adam Michaleas, Michael Jones, William Bergeron, Jeremy Kepner, Devesh Tiwari, and Vijay Gadepally. From words to watts: Benchmarking the energy costs of large language model inference. In 2023 IEEE High Performance Extreme Computing Conference (HPEC), pages 1–9. IEEE, 2023.
  • [28] Victor Sanh, Thomas Wolf, and Alexander Rush. Movement pruning: Adaptive sparsity by fine-tuning. Advances in Neural Information Processing Systems, 33:20378–20389, 2020.
  • [29] Tal Schuster, Adam Fisch, Jai Gupta, Mostafa Dehghani, Dara Bahri, Vinh Tran, Yi Tay, and Donald Metzler. Confident adaptive language modeling. Advances in Neural Information Processing Systems, 35:17456–17472, 2022.
  • [30] Abigail See, Peter J. Liu, and Christopher D. Manning. Get to the point: Summarization with pointer-generator networks. In Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 1073–1083, Vancouver, Canada, July 2017. Association for Computational Linguistics.
  • [31] Zhiqing Sun, Zhuohan Li, Haoqing Wang, Di He, Zi Lin, and Zhihong Deng. Fast structured decoding for sequence models. Advances in Neural Information Processing Systems, 32, 2019.
  • [32] Ziteng Sun, Ananda Theertha Suresh, Jae Hun Ro, Ahmad Beirami, Himanshu Jain, and Felix Yu. Spectr: Fast speculative decoding via optimal transport. arXiv preprint arXiv:2310.15141, 2023.
  • [33] Romal Thoppilan, Daniel De Freitas, Jamie Hall, Noam Shazeer, Apoorv Kulshreshtha, Heng-Tze Cheng, Alicia Jin, Taylor Bos, Leslie Baker, Yu Du, et al. Lamda: Language models for dialog applications. arXiv preprint arXiv:2201.08239, 2022.
  • [34] Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, et al. Llama: Open and efficient foundation language models. arXiv preprint arXiv:2302.13971, 2023.
  • [35] Yiren Wang, Fei Tian, Di He, Tao Qin, ChengXiang Zhai, and Tie-Yan Liu. Non-autoregressive machine translation with auxiliary regularization. In Proceedings of the AAAI conference on artificial intelligence, volume 33, pages 5377–5384, 2019.
  • [36] Yisheng Xiao, Lijun Wu, Junliang Guo, Juntao Li, Min Zhang, Tao Qin, and Tie-yan Liu. A survey on non-autoregressive generation for neural machine translation and beyond. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2023.
  • [37] Zeyu Yang, Karel Adamek, and Wesley Armour. Part-time power measurements: nvidia-smi’s lack of attention. arXiv preprint arXiv:2312.02741, 2023.
  • [38] Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen, Christopher Dewan, Mona Diab, Xian Li, Xi Victoria Lin, et al. Opt: Open pre-trained transformer language models. arXiv preprint arXiv:2205.01068, 2022.

Appendix A Pseudocode of MJSD

See Algorithm 1.

Algorithm 1 One Iteration of MJSD Algorithm
1:Input: draft model Mqsubscript𝑀𝑞M_{q}italic_M start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT, target model Mpsubscript𝑀𝑝M_{p}italic_M start_POSTSUBSCRIPT italic_p end_POSTSUBSCRIPT, input𝑖𝑛𝑝𝑢𝑡inputitalic_i italic_n italic_p italic_u italic_t, threshold τ𝜏\tauitalic_τ
2:\eqparboxCOMMENT# Sample draft sequences from Mqsubscript𝑀𝑞M_{q}italic_M start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT with beam sample.
3:𝒙,𝒒𝒙𝒒absent\bm{x},\bm{q}\leftarrowbold_italic_x , bold_italic_q ← beamSample(Mqsubscript𝑀𝑞M_{q}italic_M start_POSTSUBSCRIPT italic_q end_POSTSUBSCRIPT, input𝑖𝑛𝑝𝑢𝑡inputitalic_i italic_n italic_p italic_u italic_t) \eqparboxCOMMENT# 𝒙isubscript𝒙𝑖\bm{x}_{i}bold_italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT is the i𝑖iitalic_i-th draft token. 𝒒i=q(𝒙1:i|input)subscript𝒒𝑖𝑞conditionalsubscript𝒙:1𝑖𝑖𝑛𝑝𝑢𝑡\bm{q}_{i}=q(\bm{x}_{1:i}|input)bold_italic_q start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = italic_q ( bold_italic_x start_POSTSUBSCRIPT 1 : italic_i end_POSTSUBSCRIPT | italic_i italic_n italic_p italic_u italic_t )
4:𝑷Mp(input,𝑿)𝑷subscript𝑀𝑝𝑖𝑛𝑝𝑢𝑡𝑿\bm{P}\leftarrow M_{p}(input,\bm{X})bold_italic_P ← italic_M start_POSTSUBSCRIPT italic_p end_POSTSUBSCRIPT ( italic_i italic_n italic_p italic_u italic_t , bold_italic_X ) \eqparboxCOMMENT# 𝑷𝐑(γ+1)×|V|𝑷superscript𝐑𝛾1𝑉\bm{P}\in\mathbf{R}^{(\gamma+1)\times|V|}bold_italic_P ∈ bold_R start_POSTSUPERSCRIPT ( italic_γ + 1 ) × | italic_V | end_POSTSUPERSCRIPT, 𝑷i,j=p(x=j|𝒙1:i1,input)subscript𝑷𝑖𝑗𝑝𝑥conditional𝑗subscript𝒙:1𝑖1𝑖𝑛𝑝𝑢𝑡\bm{P}_{i,j}=p(x=j|\bm{x}_{1:i-1},input)bold_italic_P start_POSTSUBSCRIPT italic_i , italic_j end_POSTSUBSCRIPT = italic_p ( italic_x = italic_j | bold_italic_x start_POSTSUBSCRIPT 1 : italic_i - 1 end_POSTSUBSCRIPT , italic_i italic_n italic_p italic_u italic_t )
5:\eqparboxCOMMENT# Select the longest accepted draft sequence
6:p1𝑝1p\leftarrow 1italic_p ← 1, η1𝜂1\eta\leftarrow-1italic_η ← - 1
7:for i=1𝑖1i=1italic_i = 1 to γ𝛾\gammaitalic_γ do
8:     j𝒙i𝑗subscript𝒙𝑖j\leftarrow\bm{x}_{i}italic_j ← bold_italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT
9:     pp𝑷i,j𝑝𝑝subscript𝑷𝑖𝑗p\leftarrow p*\bm{P}_{i,j}italic_p ← italic_p ∗ bold_italic_P start_POSTSUBSCRIPT italic_i , italic_j end_POSTSUBSCRIPT, q𝒒i𝑞subscript𝒒𝑖q\leftarrow\bm{q}_{i}italic_q ← bold_italic_q start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT
10:     if τ<min(1,pq)𝜏1𝑝𝑞\tau<\min(1,\frac{p}{q})italic_τ < roman_min ( 1 , divide start_ARG italic_p end_ARG start_ARG italic_q end_ARG ) then
11:         ηj𝜂𝑗\eta\leftarrow jitalic_η ← italic_j \eqparboxCOMMENT# longest accepted prefix so far
12:     end if
13:end for
14:\eqparboxCOMMENT# Sample the next token using results of Mpsubscript𝑀𝑝M_{p}italic_M start_POSTSUBSCRIPT italic_p end_POSTSUBSCRIPT
15:𝒑Pη+1superscript𝒑subscript𝑃𝜂1\bm{p}^{\prime}\leftarrow P_{\eta+1}bold_italic_p start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ← italic_P start_POSTSUBSCRIPT italic_η + 1 end_POSTSUBSCRIPT
16:t𝒑similar-to𝑡superscript𝒑t\sim\bm{p}^{\prime}italic_t ∼ bold_italic_p start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT
17:return [𝒙1,,𝒙η,t]subscript𝒙1subscript𝒙𝜂𝑡[\bm{x}_{1},\ldots,\bm{x}_{\eta},t][ bold_italic_x start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , … , bold_italic_x start_POSTSUBSCRIPT italic_η end_POSTSUBSCRIPT , italic_t ]
Table 4: Dataset Statistics
Dataset Task # Test Input Avg. Input Len
ChatGPT-Prompt Instruction 360 25.2
ChatAlpaca Chat 43.081 277.7
CNNDM Summarization 11,490 3,967.1

Appendix B Energy Consumption Measurement

We use the command "nvidia-smi --query-gpu=power.draw --format=csv" to get GPU power every second, and sum them up as the total energy consumption. We use average energy consumption per token to measure energy efficiency. There is a recent study pointing out the measurement error using nvidia-smi [37]. We follow the three principles proposed in [37] to minimize the error.

Appendix C Configuration

The experiments are conducted on a machine with 1 Nvidia L40 GPU (48 GB), 4 CPUs, and 50 GB main memory, using a batch size of 1, which is common for online serving [29]. We set the maximum running time to be an hour for each baseline. We use average tokens/second to measure the inference speed and use perplexity (exponentiated average negative log-likelihood) based on the probability of the large model to measure the output quality. Because different methods might finish different numbers of inputs, we only calculate the perplexity of the first M𝑀Mitalic_M inputs, where M𝑀Mitalic_M is the number of inputs finished by greedy decoding. We use average energy consumption per token to measure energy efficiency. The details of energy measurement are illustrated in the Appendix.

Appendix D Hyper-parameter Details

In the experiments, we follow the default settings in [2] to warp the sampling distribution p𝑝pitalic_p and q𝑞qitalic_q with the following steps, which are the default warpping operations in a public speculative decoding implementation.

  1. 1.

    Keep the probabilities of top 20 tokens unchanged, and set the probabilities of other tokens to 0, then normalize the distribution.

  2. 2.

    Sort the tokens based on their distributions descendingly. Keep the first K𝐾Kitalic_K tokens such that their cumulative probabilities is larger than 0.9, while set the probabilities of other tokens to 0.

For different methods, we choose their hyper-parameters by using a small validation set. We select the set of hyper-parameters that make the corresponding method have best output perplexity. Table 5 shows the hyper-parameters used in the experiments.

Table 5: Hyper-parameters of different methods for different models and datasets. CP: ChatGPT-Prompts, CA: ChatAlpaca, CD: CNNDaily.
Llama,CP OPT,CP Llama,CA OPT,CA Llama,D OPT,CD
speculative step len γ𝛾\gammaitalic_γ 4 4 4 4 4 4
Spectr step len γ𝛾\gammaitalic_γ 4 4 4 4 4 4
num seq m𝑚mitalic_m 4 4 2 2 4 2
BiLD step len γ𝛾\gammaitalic_γ 10 10 10 10 10 10
fallback thres τ1subscript𝜏1\tau_{1}italic_τ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT 0.9 0.9 0.9 0.3 0.9 0.3
rollback thres τ2subscript𝜏2\tau_{2}italic_τ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT 2 2 1 2 3 2
MJSD step len γ𝛾\gammaitalic_γ 4 4 4 4 4 4
num beams 8 8 8 8 8 8
acc/rej thres τ𝜏\tauitalic_τ 0.1 0.1 0.1 0.1 0.1 0.1

Appendix E License of Datasets and Models

Datasets:

  • ChatGPT-Prompts: Non (https://huggingface.co/datasets/MohamedRashad/ChatGPT-prompts)

  • ChatAlpaca: Apache-2.0 License

  • CNN Dailymail: Apache-2.0 License

Models

Codes

Appendix F More Discussion on Multi-Token Joint Greedy Decoding

It is not trivial to prove that MJGD is better than vanilla greedy decoding. We might tend to show that given the same x1:tsubscript𝑥:1𝑡x_{1:t}italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT and prefix𝑝𝑟𝑒𝑓𝑖𝑥prefixitalic_p italic_r italic_e italic_f italic_i italic_x, the γisubscript𝛾𝑖\gamma_{i}italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT tokens generated by MJGD has higher likelihood than the γisubscript𝛾𝑖\gamma_{i}italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT tokens generated by calling greedy decoding γisubscript𝛾𝑖\gamma_{i}italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT times. However, at step i𝑖iitalic_i (i2for-all𝑖2\forall i\geq 2∀ italic_i ≥ 2), the previously generated tokens x1:tsubscript𝑥:1𝑡x_{1:t}italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT are different for the two decoding algorithms, making it difficult to compare the overall perplexity.

To explain why MJGD is more effective, let us consider an optimal algorithm that generates next γisubscript𝛾𝑖\gamma_{i}italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT tokens at one step based on a score function:

s(xt+1:t+γi)=maxxt+γi+1:Nlogp(xt+1:N|x1:t)superscript𝑠subscript𝑥:𝑡1𝑡subscript𝛾𝑖subscriptsubscript𝑥:𝑡subscript𝛾𝑖1𝑁𝑝conditionalsubscript𝑥:𝑡1𝑁subscript𝑥:1𝑡s^{*}(x_{t+1:t+\gamma_{i}})=\max_{x_{t+\gamma_{i}+1:N}}\log p(x_{t+1:N}|x_{1:t})italic_s start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) = roman_max start_POSTSUBSCRIPT italic_x start_POSTSUBSCRIPT italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT + 1 : italic_N end_POSTSUBSCRIPT end_POSTSUBSCRIPT roman_log italic_p ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_N end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) (3)

which is the log-likelihood of the optimal future sequence with prefix x1:tsubscript𝑥:1𝑡x_{1:t}italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT and the generated γisubscript𝛾𝑖\gamma_{i}italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT tokens. The optimal greedy decoding chooses xt+1:t+γisubscript𝑥:𝑡1𝑡subscript𝛾𝑖x_{t+1:t+\gamma_{i}}italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT that maximizes the score function and is guaranteed to return the sequence with the optimal perplexity.

Both MJGD and vanilla greedy decoding can be viewed as approximations to the optimal greedy decoding by replacing s()superscript𝑠s^{*}(\cdot)italic_s start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ( ⋅ ) with sMJGDsubscript𝑠𝑀𝐽𝐺𝐷s_{MJGD{}}italic_s start_POSTSUBSCRIPT italic_M italic_J italic_G italic_D end_POSTSUBSCRIPT oder sGDsubscript𝑠𝐺𝐷s_{GD}italic_s start_POSTSUBSCRIPT italic_G italic_D end_POSTSUBSCRIPT that estimates the values of s()superscript𝑠s^{*}(\cdot)italic_s start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ( ⋅ ). Specifically, since MJGD selects the γisubscript𝛾𝑖\gamma_{i}italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT tokens with the largest joint likelihood, we have sMJGD()=logp(x1:t+γi|x1:t)+c1subscript𝑠𝑀𝐽𝐺𝐷𝑝conditionalsubscript𝑥:1𝑡subscript𝛾𝑖subscript𝑥:1𝑡subscript𝑐1s_{MJGD}(\cdot)=\log p(x_{1:t+\gamma_{i}}|x_{1:t})+c_{1}italic_s start_POSTSUBSCRIPT italic_M italic_J italic_G italic_D end_POSTSUBSCRIPT ( ⋅ ) = roman_log italic_p ( italic_x start_POSTSUBSCRIPT 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) + italic_c start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT where c1subscript𝑐1c_{1}italic_c start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT is an arbitrary constant. Notice that adding a constant to s()𝑠s(\cdot)italic_s ( ⋅ ) changes the error of the estimation but does not change the behavior of the algorithm.

On the other hand, greedy decoding generates the γisubscript𝛾𝑖\gamma_{i}italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT tokens one by one. At the last token, greedy decoding selects the token with the largest p(xt+γi|x1:t+γi1)𝑝conditionalsubscript𝑥𝑡subscript𝛾𝑖subscript𝑥:1𝑡subscript𝛾𝑖1p(x_{t+\gamma_{i}}|x_{1:t+\gamma_{i}-1})italic_p ( italic_x start_POSTSUBSCRIPT italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT - 1 end_POSTSUBSCRIPT ), which is equivalent to selecting the tokens with the largest p(xt+1:t+γi|x1:t)𝑝conditionalsubscript𝑥:𝑡1𝑡subscript𝛾𝑖subscript𝑥:1𝑡p(x_{t+1:t+\gamma_{i}}|x_{1:t})italic_p ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) provided that xt+1:t+γi1subscript𝑥:𝑡1𝑡subscript𝛾𝑖1x_{t+1:t+\gamma_{i}-1}italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT - 1 end_POSTSUBSCRIPT are the same tokens generated by previous steps. So for greedy decoding, sGD(xt+1:t+γi)=logp(x1:t+γi|x1:t)+c2subscript𝑠𝐺𝐷subscript𝑥:𝑡1𝑡subscript𝛾𝑖𝑝conditionalsubscript𝑥:1𝑡subscript𝛾𝑖subscript𝑥:1𝑡subscript𝑐2s_{GD}(x_{t+1:t+\gamma_{i}})=\log p(x_{1:t+\gamma_{i}}|x_{1:t})+c_{2}italic_s start_POSTSUBSCRIPT italic_G italic_D end_POSTSUBSCRIPT ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) = roman_log italic_p ( italic_x start_POSTSUBSCRIPT 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) + italic_c start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT (c2subscript𝑐2c_{2}italic_c start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT is also an arbitrary constant) if and only if xt+i=argmaxxp(x|x1:t+i1)subscript𝑥𝑡𝑖subscript𝑥𝑝conditional𝑥subscript𝑥:1𝑡𝑖1x_{t+i}=\arg\max_{x}p(x|x_{1:t+i-1})italic_x start_POSTSUBSCRIPT italic_t + italic_i end_POSTSUBSCRIPT = roman_arg roman_max start_POSTSUBSCRIPT italic_x end_POSTSUBSCRIPT italic_p ( italic_x | italic_x start_POSTSUBSCRIPT 1 : italic_t + italic_i - 1 end_POSTSUBSCRIPT ) for i=1,,γi1𝑖1subscript𝛾𝑖1i=1,\ldots,\gamma_{i}-1italic_i = 1 , … , italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT - 1. Otherwise, the score of xt+1:t+γisubscript𝑥:𝑡1𝑡subscript𝛾𝑖x_{t+1:t+\gamma_{i}}italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT has to be smaller than any possible value of logp(x1:t+γi|x1:t)+c2𝑝conditionalsubscript𝑥:1𝑡subscript𝛾𝑖subscript𝑥:1𝑡subscript𝑐2\log p(x_{1:t+\gamma_{i}}|x_{1:t})+c_{2}roman_log italic_p ( italic_x start_POSTSUBSCRIPT 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) + italic_c start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT. Notice that the lower bound of logp(x1:t+γi|x1:t)𝑝conditionalsubscript𝑥:1𝑡subscript𝛾𝑖subscript𝑥:1𝑡\log p(x_{1:t+\gamma_{i}}|x_{1:t})roman_log italic_p ( italic_x start_POSTSUBSCRIPT 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT ) is -\infty- ∞, so sGD(xt+1:t+γi)=subscript𝑠𝐺𝐷subscript𝑥:𝑡1𝑡subscript𝛾𝑖s_{GD}(x_{t+1:t+\gamma_{i}})=-\inftyitalic_s start_POSTSUBSCRIPT italic_G italic_D end_POSTSUBSCRIPT ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) = - ∞ if and only if j=1,,γi1𝑗1subscript𝛾𝑖1\exists j=1,\ldots,\gamma_{i}-1∃ italic_j = 1 , … , italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT - 1 such that xt+jargmaxxp(x|x1:t+i1)subscript𝑥𝑡𝑗subscript𝑥𝑝conditional𝑥subscript𝑥:1𝑡𝑖1x_{t+j}\neq\arg\max_{x}p(x|x_{1:t+i-1})italic_x start_POSTSUBSCRIPT italic_t + italic_j end_POSTSUBSCRIPT ≠ roman_arg roman_max start_POSTSUBSCRIPT italic_x end_POSTSUBSCRIPT italic_p ( italic_x | italic_x start_POSTSUBSCRIPT 1 : italic_t + italic_i - 1 end_POSTSUBSCRIPT ).

Let c=argminc2xt+1:t+γi|sGD(xt+1:t+γi)s(xt+1:t+γi)|superscript𝑐subscriptsubscript𝑐2subscriptsubscript𝑥:𝑡1𝑡subscript𝛾𝑖subscript𝑠𝐺𝐷subscript𝑥:𝑡1𝑡subscript𝛾𝑖superscript𝑠subscript𝑥:𝑡1𝑡subscript𝛾𝑖c^{*}=\arg\min_{c_{2}}\sum_{x_{t+1:t+\gamma_{i}}}|s_{GD}(x_{t+1:t+\gamma_{i}})% -s^{*}(x_{t+1:t+\gamma_{i}})|italic_c start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT = roman_arg roman_min start_POSTSUBSCRIPT italic_c start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT end_POSTSUBSCRIPT ∑ start_POSTSUBSCRIPT italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT end_POSTSUBSCRIPT | italic_s start_POSTSUBSCRIPT italic_G italic_D end_POSTSUBSCRIPT ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) - italic_s start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) |. By setting c1=csubscript𝑐1superscript𝑐c_{1}=c^{*}italic_c start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT = italic_c start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT, for xt+1,,xt+γi1subscript𝑥𝑡1subscript𝑥𝑡subscript𝛾𝑖1x_{t+1},\ldots,x_{t+\gamma_{i}-1}italic_x start_POSTSUBSCRIPT italic_t + 1 end_POSTSUBSCRIPT , … , italic_x start_POSTSUBSCRIPT italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT - 1 end_POSTSUBSCRIPT satisfying that xt+i=argmaxxp(x|x1:t+i1)subscript𝑥𝑡𝑖subscript𝑥𝑝conditional𝑥subscript𝑥:1𝑡𝑖1x_{t+i}=\arg\max_{x}p(x|x_{1:t+i-1})italic_x start_POSTSUBSCRIPT italic_t + italic_i end_POSTSUBSCRIPT = roman_arg roman_max start_POSTSUBSCRIPT italic_x end_POSTSUBSCRIPT italic_p ( italic_x | italic_x start_POSTSUBSCRIPT 1 : italic_t + italic_i - 1 end_POSTSUBSCRIPT ), i=1,,γi1for-all𝑖1subscript𝛾𝑖1\forall i=1,\ldots,\gamma_{i}-1∀ italic_i = 1 , … , italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT - 1, we have

sMJGD(xt+1:t+γi)=sGD(xt+1:t+γi)=logp(x1:t+γi|x1:t)+csubscript𝑠𝑀𝐽𝐺𝐷subscript𝑥:𝑡1𝑡subscript𝛾𝑖subscript𝑠𝐺𝐷subscript𝑥:𝑡1𝑡subscript𝛾𝑖𝑝subscript𝑥:1𝑡conditionalsubscript𝛾𝑖subscript𝑥:1𝑡superscript𝑐\displaystyle s_{MJGD}(x_{t+1:t+\gamma_{i}})=s_{GD}(x_{t+1:t+\gamma_{i}})=\log p% (x_{1:t+\gamma_{i}|x_{1:t}})+c^{*}italic_s start_POSTSUBSCRIPT italic_M italic_J italic_G italic_D end_POSTSUBSCRIPT ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) = italic_s start_POSTSUBSCRIPT italic_G italic_D end_POSTSUBSCRIPT ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) = roman_log italic_p ( italic_x start_POSTSUBSCRIPT 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) + italic_c start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT

Otherwise, we have

|sMJGD(xt+1:t+γi)s(xt+1:t+γi)|=|ch(xt+1:t+γi)||infs(xt+1:t+γi)|subscript𝑠𝑀𝐽𝐺𝐷subscript𝑥:𝑡1𝑡subscript𝛾𝑖superscript𝑠subscript𝑥:𝑡1𝑡subscript𝛾𝑖superscript𝑐superscriptsubscript𝑥:𝑡1𝑡subscript𝛾𝑖infimumsuperscript𝑠subscript𝑥:𝑡1𝑡subscript𝛾𝑖|s_{MJGD}(x_{t+1:t+\gamma_{i}})-s^{*}(x_{t+1:t+\gamma_{i}})|=|c^{*}-h^{*}(x_{t% +1:t+\gamma_{i}})|\leq|-\inf-s^{*}(x_{t+1:t+\gamma_{i}})|| italic_s start_POSTSUBSCRIPT italic_M italic_J italic_G italic_D end_POSTSUBSCRIPT ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) - italic_s start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) | = | italic_c start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT - italic_h start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) | ≤ | - roman_inf - italic_s start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) |

where h(xt+1:t+γi)=maxxt+γi+1:Nlogp(xt+γi+1:N|x1:t+γi,prefix)superscriptsubscript𝑥:𝑡1𝑡subscript𝛾𝑖subscriptsubscript𝑥:𝑡subscript𝛾𝑖1𝑁𝑝conditionalsubscript𝑥:𝑡subscript𝛾𝑖1𝑁subscript𝑥:1𝑡subscript𝛾𝑖𝑝𝑟𝑒𝑓𝑖𝑥h^{*}(x_{t+1:t+\gamma_{i}})=\max_{x_{t+\gamma_{i}+1:N}}\log p(x_{t+\gamma_{i}+% 1:N}|x_{1:t+\gamma_{i}},prefix)italic_h start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) = roman_max start_POSTSUBSCRIPT italic_x start_POSTSUBSCRIPT italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT + 1 : italic_N end_POSTSUBSCRIPT end_POSTSUBSCRIPT roman_log italic_p ( italic_x start_POSTSUBSCRIPT italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT + 1 : italic_N end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT , italic_p italic_r italic_e italic_f italic_i italic_x ). Therefore,

|sMJGD(xt+1:t+γi)s(xt+1:t+γi)||sGD(x1:K)s(xt+1:t+γi)|subscript𝑠𝑀𝐽𝐺𝐷subscript𝑥:𝑡1𝑡subscript𝛾𝑖superscript𝑠subscript𝑥:𝑡1𝑡subscript𝛾𝑖subscript𝑠𝐺𝐷subscript𝑥:1𝐾superscript𝑠subscript𝑥:𝑡1𝑡subscript𝛾𝑖|s_{MJGD}(x_{t+1:t+\gamma_{i}})-s^{*}(x_{t+1:t+\gamma_{i}})|\leq|s_{GD}(x_{1:K% })-s^{*}(x_{t+1:t+\gamma_{i}})|| italic_s start_POSTSUBSCRIPT italic_M italic_J italic_G italic_D end_POSTSUBSCRIPT ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) - italic_s start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) | ≤ | italic_s start_POSTSUBSCRIPT italic_G italic_D end_POSTSUBSCRIPT ( italic_x start_POSTSUBSCRIPT 1 : italic_K end_POSTSUBSCRIPT ) - italic_s start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ( italic_x start_POSTSUBSCRIPT italic_t + 1 : italic_t + italic_γ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_POSTSUBSCRIPT ) | (4)

The derivation above provides a heuristic explanation to show that MJGD is a closer approximation to the optimal greedy decoding than vanilla greedy decoding. Hence, its output perplexity is generally better.