Multi-Granularity Semantic Revision for Large Language Model Distillation

Xiaoyu Liu University of Science and Technology of China Yun Zhang DSA Thrust, INFO Hub, Hong Kong University of Science and Technology (GZ) Wei Li Huawei Noah’s Ark Lab Simiao Li Huawei Noah’s Ark Lab Xudong Huang Huawei Noah’s Ark Lab Hanting Chen Huawei Noah’s Ark Lab Yehui Tang Huawei Noah’s Ark Lab Jie Hu Huawei Noah’s Ark Lab Zhiwei Xiong University of Science and Technology of China Yunhe Wang Huawei Noah’s Ark Lab
Abstract

Knowledge distillation plays a key role in compressing Large Language Models (LLMs), boosting a small student model under the guidance of a large teacher model. However, existing LLM distillation methods rely heavily on student-generated outputs, which may introduce generation errors and misguide the distillation process. Moreover, the distillation loss functions introduced in prior work struggle to align the most informative part of the teacher's output because LLM output distributions are complex. To address these problems, we propose a multi-granularity semantic revision method for LLM distillation. At the sequence level, we propose a sequence correction and re-generation (SCRG) strategy. SCRG first calculates the semantic cognitive difference between the teacher and student to detect the error token, then corrects it with the teacher-generated token, and re-generates the sequence to reduce generation errors and enhance generation diversity. At the token level, we design a distribution adaptive clipping Kullback-Leibler (DAC-KL) loss as the distillation objective function. DAC-KL loss exploits a learnable sub-network to adaptively extract semantically dense areas from the teacher's output, avoiding interference from redundant information in the distillation process. Finally, at the span level, we leverage the span priors of a sequence to compute the probability correlations within spans, and constrain the teacher's and student's probability correlations to be consistent, further enhancing the transfer of semantic information. Extensive experiments across different model families with parameters ranging from 0.1B to 13B demonstrate the superiority of our method compared to existing methods.

1 Introduction

The remarkable advancements in auto-regressive Large Language Models (LLMs) kaplan2020scaling ; Wei_2022 ; Radford_Wu_Child_Luan_Amodei_Sutskever ; Zhangetal ; Brown_Mann have led to unprecedented breakthroughs in a diverse array of text generation tasks, with numerous open-source models Touvron ; zhang2022opt now available. A crucial factor in this success is the ability to scale up models, increasing both the amount of training data and the number of model parameters. However, the massive size and computational intensity of these state-of-the-art models pose significant challenges, particularly for deployment and real-time applications. Conversely, smaller models with limited parameters often sacrifice performance on real-world generation tasks wang2022self . To mitigate these challenges, Knowledge Distillation (KD) hinton2015distilling has emerged as a pivotal technique, enabling the development of smaller, more efficient student models that inherit the strengths of their larger teacher counterparts.

Figure 1: Knowledge Distillation using Different Sampled Datasets. (a) Traditional KD using a fixed dataset hinton2015distilling . (b) KD using the student-generated dataset, which can be categorized into on-policy based methods agarwal2024policy ; gu2023minillm and the off-policy based method ko2024distillm . (c) Our proposed KD approach, which leverages a sequence correction and re-generation strategy and can be seamlessly integrated with both on-policy and off-policy generation schedules.

Traditional knowledge distillation methods hinton2015distilling ; kim2016sequence directly employ the Kullback-Leibler divergence (KLD) as the distillation loss for aligning the output distributions of teacher and student models on a static dataset (see Figure 1 (a)). In contrast, recent LLM distillation methods explore divergence loss functions tailored to LLMs and train on student-generated datasets to avoid a mismatch between the sequences seen during training and those the student generates at inference. GKD agarwal2024policy and MiniLLM gu2023minillm propose using the reverse KLD as the distillation objective instead of the commonly used forward KLD, aiming to prevent the student from overestimating the low-probability regions of the teacher's distribution. These methods also train the student on self-generated (on-policy) sequences instead of a fixed set of output sequences. More recently, DistiLLM ko2024distillm proposes an adaptive off-policy student-generation strategy to improve sample efficiency and reduce the high generation time of on-policy generation (see Figure 1 (b)). It also designs a new distillation objective, the skew KLD loss, for better generalizability and convergence. However, relying on student-generated sequences may introduce generation errors and lead to suboptimal learning, as the distillation process becomes vulnerable to the inaccuracies inherent in the student's predictions. The student model's limited capacity and biases can further perpetuate these errors, resulting in a distorted representation of the teacher's knowledge. Moreover, the rich semantic knowledge in the teacher's output distribution and its large variance across tokens make it challenging for existing distillation objectives to capture and transfer the essential knowledge it contains.

To address the above-mentioned issues, we introduce a novel multi-level semantic revision approach, spanning the sequence, token, and span levels, to significantly improve the KD process for LLMs. At the sequence level, we propose a sequence correction and re-generation (SCRG) strategy. We detect the error token in the student-generated sequence and re-generate the sequence from the position of that token to reduce generation errors and enhance generation diversity. As shown in Figure 1 (c), by assessing the semantic cognitive differences between teacher and student outputs token by token, we identify and correct errors; the re-generated sequences steer the student model towards generating more reliable and diverse samples, and the strategy can be seamlessly integrated with both on-policy and off-policy generation schedules. At the token level, we employ a distribution adaptive clipping Kullback-Leibler (DAC-KL) loss function, which leverages a learnable sub-network to target semantically salient regions of the output distribution. This loss function filters out redundant information, preserving only the most relevant signals for distillation. Finally, at the span level, we incorporate pre-defined span priors of sequences to align the relations among the probability vectors of the student and teacher models, ensuring a consistent transfer of semantic information across related tokens within the same span. Through extensive experiments with different models, including the LLaMA2, OpenLLaMA2, OPT, and GPT2 series, ranging from 0.1B to 13B parameters, we demonstrate the superiority of our approach over existing knowledge distillation methods.

The contributions of this paper are summarized as follows:

  • We introduce a novel multi-level semantic revision approach to enhance the knowledge distillation (KD) process for large language models (LLMs).

  • At the sequence level, we propose a sequence correction and re-generation strategy to steer the student model towards generating more reliable and diverse sequences.

  • At the token level, we propose a distribution adaptive clipping Kullback-Leibler loss to capture semantically salient regions of the output space.

  • At the span level, we incorporate input span priors to ensure a consistent transfer of semantic knowledge across related tokens.

  • Through extensive experimentation with models ranging from 0.1B to 13B parameters, we demonstrate the superiority of our method over existing KD methods for LLMs.

2 Related work

KD for encoder-only language models.

Pretrained encoder-only language models, such as BERT jiao2019tinybert , can be compressed using traditional logit distillation hinton2015distilling and feature distillation adriana2015fitnets . These knowledge distillation methods minimize the Kullback-Leibler divergence between the outputs of the student and teacher models on a fixed dataset kim2016sequence . Liang et al. liang2020mixkd applied this objective to train students on masked language modelling and text classification tasks. Jiao et al. jiao2019tinybert used intermediate representations from each transformer layer of the teacher as transferable knowledge. Despite the potential of KD for encoder-only language models sanh2019distilbert ; liang2023less ; sun2019patient ; liu2022multi , the complex predictions that large language models (LLMs) generate through auto-regressive inference present a new challenge. This paper primarily discusses KD for auto-regressive LLMs.

KD for auto-regressive large language models.

Existing knowledge distillation (KD) methods for auto-regressive large language models (LLMs) can be divided into black-box methods for closed-source models such as GPT-3.5 ouyang2022training and GPT-4 achiam2023gpt , and white-box methods for open-source models such as LLaMA Touvron . Black-box methods chen2024knowledge ; jiang2023lion ; hsieh2023distilling cannot access the internal parameters of the teacher model and use only the inference results provided by the teacher API taori2023stanford ; chiang2023vicuna ; peng2023instruction . The teacher's inference results are used to construct prompt-response pairs, which serve as a new training dataset for fine-tuning the student model. In contrast, white-box KD methods ko2024distillm ; agarwal2024policy ; gu2023minillm leverage the internal parameters of the teacher model, providing richer training signals such as the predicted probability distribution, which can lead to better student performance. Our method primarily addresses the challenges of existing methods in the realm of white-box KD.

3 Preliminary

Before introducing our method, we provide some preliminaries on KD for LLMs. We treat LLM inference as a vocabulary classification task, where a model $p$ predicts the conditional probability distribution of a target response $y$ given a prompt and target sequence pair $(x, y)$. Let $y_{<i} = (y_1, y_2, \ldots, y_{i-1})$ denote the generated output sequence up to the $(i-1)^{th}$ token $y_{i-1}$. A token-level auto-regressive model outputs a next-token probability distribution over an $M$-token vocabulary. Specifically, for the model $p$, $\hat{y}_i = p(\cdot \mid y_{<i}, x) \in (0,1)^M$ represents the probability distribution of the generated $i^{th}$ token, and $y_i \sim p(\cdot \mid y_{<i}, x)$ is the corresponding output token.

We formulate KD as an optimization problem that minimizes the difference between the prediction distribution of a fixed teacher model $p(\cdot \mid y_{<i}, x)$ and that of a parameterized student model $q_\theta(\cdot \mid y_{<i}, x)$, using input-output sequence pairs $(x, y)$ sampled from the fixed dataset $(X, Y)$, where $\theta$ denotes the student's parameters to be optimized. Sequence-level distillation over $L_y$ tokens employs the KL divergence $D_{KLD}$ as the distillation objective, and the total distillation loss $\mathcal{L}_{KD}$ decomposes into an average of token-wise distillation losses:

$$\mathcal{L}_{KD} = \frac{1}{L_y}\sum_{i=1}^{L_y} D_{KLD}\left(p(\cdot \mid y_{<i}, x) \,\|\, q_\theta(\cdot \mid y_{<i}, x)\right) = \frac{1}{L_y}\sum_{i=1}^{L_y} p(\cdot \mid y_{<i}, x) \log \frac{p(\cdot \mid y_{<i}, x)}{q_\theta(\cdot \mid y_{<i}, x)}, \qquad (1)$$

where the conditional sequence $y$ can be easily generated by sampling from the teacher or student policy, i.e., $\{x \in X, y \sim p(\cdot \mid x)\}$ or $\{x \in X, y \sim q_\theta(\cdot \mid x)\}$, instead of being taken directly from $\{(x, y) \in (X, Y)\}$.
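For concreteness, Eq. (1) can be computed as below. This is a minimal NumPy sketch, not the authors' implementation: it assumes the teacher and student distributions for all $L_y$ positions of one sequence (conditioned on the same prefixes) are already available as dense `(L_y, M)` arrays, and the helper name `sequence_kd_loss` and the `eps` floor are our own illustrative choices.

```python
import numpy as np


def sequence_kd_loss(teacher_probs: np.ndarray, student_probs: np.ndarray,
                     eps: float = 1e-12) -> float:
    """Eq. (1): mean over the L_y positions of KL(p || q_theta).

    teacher_probs, student_probs: shape (L_y, M); row i holds the next-token
    distribution given the same prefix y_{<i} and prompt x.
    """
    # Floor probabilities so log() never sees an exact zero.
    p = np.clip(teacher_probs, eps, 1.0)
    q = np.clip(student_probs, eps, 1.0)
    # Forward KL per position: sum over the M vocabulary entries.
    token_kl = np.sum(p * (np.log(p) - np.log(q)), axis=-1)  # shape (L_y,)
    # Average the token-wise terms over the sequence length L_y.
    return float(token_kl.mean())
```

In a training loop the same computation would normally be written with an autodiff framework so gradients flow into $q_\theta$; NumPy is used here only to keep the arithmetic visible.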

During the distillation process, the student model is also guided by the ground-truth output sequence without querying the policies of the teacher or student models. The supervised fine-tuning (SFT) loss is formulated as

$$\mathcal{L}_{\text{SFT}} = \mathbb{E}_{(x,y)\sim(X,Y)}\left[-\log q_\theta(y \mid x)\right]. \qquad (2)$$
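Under the auto-regressive factorisation, Eq. (2) reduces to the negative log-likelihood of the ground-truth tokens under teacher forcing. The sketch below handles a single $(x, y)$ pair and assumes the student's per-position distributions are given as an `(L_y, M)` array; `sft_loss` is a hypothetical helper. It sums over tokens as Eq. (2) is written, whereas many training codebases average per token instead.

```python
import numpy as np


def sft_loss(student_probs: np.ndarray, target_ids: np.ndarray,
             eps: float = 1e-12) -> float:
    """Eq. (2) for one (x, y) pair: -log q_theta(y | x).

    student_probs: (L_y, M), row i is q_theta(. | y_{<i}, x) with the
    ground-truth prefix fed in (teacher forcing).
    target_ids: (L_y,) integer ids of the ground-truth tokens y_i.
    """
    # Probability the student assigns to each ground-truth token.
    picked = student_probs[np.arange(len(target_ids)), target_ids]
    # log q(y|x) = sum_i log q(y_i | y_{<i}, x).
    return float(-np.sum(np.log(np.clip(picked, eps, 1.0))))
```

The expectation in Eq. (2) would then be the mean of this value over sampled pairs.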

4 Multi-Granularity Semantic Revision

In this section, we introduce the proposed multi-granularity semantic revision for LLM distillation, which revises the semantic representation during the knowledge transfer stage at three levels: sequence-level, token-level, and span-level.

4.1 Sequence-level correction and re-generation

Figure 2: The workflow of sequence correction and re-generation strategy.

As illustrated by Eq. (1), prevalent KD methods agarwal2024policy ; gu2023minillm ; ko2024distillm use conditional sequences generated by the student model (denoted as $y \sim q_\theta(\cdot \mid x)$) for the distillation process. While these methods are designed to mitigate the training-inference mismatch between the fixed training data and the student's auto-regressive inference, they simultaneously risk introducing generation errors. Owing to the student model's limited capabilities, the generated sequences may contain additional errors that reduce the effectiveness of KD. To address this issue, we propose a sequence correction and re-generation (SCRG) strategy (shown in Fig. 2) that detects generation errors and re-generates sequences, steering the student model towards generating reliable and diverse sequences.

We denote the $n$-token sequence generated by the student model $q_\theta$ as $y^s_{<n+1} = (y^s_1, y^s_2, \ldots, y^s_n)$, with corresponding probability outputs $(\hat{y}^s_1, \hat{y}^s_2, \ldots, \hat{y}^s_n)$, where $y^s_i \sim q_\theta(\cdot \mid y^s_{<i}, x)$ for $1 \le i \le n$.
Similarly, we denote the teacher model's output sequence as $y^t_{<n+1} = (y^t_1, y^t_2, \ldots, y^t_n)$ with probability outputs $(\hat{y}^t_1, \hat{y}^t_2, \ldots, \hat{y}^t_n)$, where each teacher token is sampled conditioned on the student's prefix, $y^t_i \sim p(\cdot \mid y^s_{<i}, x)$. Following previous methods agarwal2024policy ; gu2023minillm ; ko2024distillm , we use the student-generated outputs as the distillation dataset, and we calculate a token-wise KLD to evaluate the semantic cognitive difference between the teacher and student at each token, which we use to locate the error token within the student sequence $y^s_{<n+1}$.
We formulate the detection of the error token $y^s_j$ as

$$j = \mathop{\arg\max}_{1 \le i \le n,\; y^s_i \neq y^t_i} \mathrm{KLD}\left(\hat{y}^s_i \,\|\, \hat{y}^t_i\right). \qquad (3)$$
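A minimal sketch of the error-token detection in Eq. (3), assuming NumPy arrays for the per-position distributions and token ids of both models (the function name `detect_error_position` is hypothetical, and indices are 0-based rather than the paper's $1 \le i \le n$). Positions where the two models produced the same token are masked out before the argmax, and the function returns `None` when no position disagrees, a case the text does not discuss.

```python
import numpy as np


def detect_error_position(student_probs, teacher_probs, student_ids, teacher_ids,
                          eps=1e-12):
    """Eq. (3): argmax of KLD(y_hat^s_i || y_hat^t_i) over positions where
    y^s_i != y^t_i. Returns a 0-based index, or None if the sequences agree."""
    s = np.clip(np.asarray(student_probs, dtype=float), eps, 1.0)
    t = np.clip(np.asarray(teacher_probs, dtype=float), eps, 1.0)
    # Per-position KL with the student distribution first, as in Eq. (3).
    kld = np.sum(s * (np.log(s) - np.log(t)), axis=-1)  # shape (n,)
    mismatch = np.asarray(student_ids) != np.asarray(teacher_ids)
    if not mismatch.any():
        return None
    # Exclude positions where both models chose the same token.
    kld = np.where(mismatch, kld, -np.inf)
    return int(np.argmax(kld))
```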

We then replace $y^s_j$ with $y^t_j$ to construct a new sample and re-generate the remainder of the student output sequence, so that each token in $y^s_{<n+1}$ is sampled as

$$y^s_i \sim \begin{cases} q_\theta(\cdot \mid y^s_{<i}, x) & \text{if } i < j, \\ p(\cdot \mid y^s_{<i}, x) & \text{if } i = j, \\ q_\theta(\cdot \mid y^s_{<i, \neq j}, y^t_j, x) & \text{if } i > j. \end{cases} \qquad (4)$$
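The re-generation rule in Eq. (4) can be sketched as follows. The `student_sample` callable, which maps a prompt and a prefix to the next student token, is a hypothetical stand-in for auto-regressive sampling from $q_\theta$; the optional `eos_id` stop condition and the `max_len` cap are likewise our assumptions, and `j` is 0-based. Because $y^t_j$ is already sampled from $p(\cdot \mid y^s_{<j}, x)$, the correction step simply copies it into position $j$.

```python
from typing import Callable, List, Optional, Sequence

# Hypothetical interface: (prompt_ids, prefix_ids) -> next student token id.
Sampler = Callable[[Sequence[int], List[int]], int]


def correct_and_regenerate(prompt: Sequence[int], student_seq: Sequence[int],
                           teacher_seq: Sequence[int], j: int,
                           student_sample: Sampler, max_len: int,
                           eos_id: Optional[int] = None) -> List[int]:
    """Eq. (4): keep y^s_{<j}, put the teacher token y^t_j at position j,
    then let the student continue auto-regressively from that prefix."""
    new_seq = list(student_seq[:j]) + [teacher_seq[j]]
    while len(new_seq) < max_len:
        # Stop early if the last token is the (assumed) end-of-sequence id.
        if eos_id is not None and new_seq[-1] == eos_id:
            break
        new_seq.append(student_sample(prompt, new_seq))
    return new_seq
```

With a toy sampler that returns the previous token plus one, `correct_and_regenerate([0], [1, 2, 3, 4], [1, 2, 7, 4], 2, sampler, max_len=5)` keeps `[1, 2]`, inserts the teacher token `7`, and continues to `[1, 2, 7, 8, 9]`.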

Our SCRG strategy can be seamlessly integrated with existing on-policy sampling agarwal2024policy and off-policy sampling ko2024distillm . By incorporating an adaptive scheduler ko2024distillm for student-model generation, we enhance the efficiency of our sampling process.

4.2 Token-level DAC-KL loss function

The probability output of an LLM at each token is a high-dimensional vector. However, existing modified Kullback-Leibler divergence (KLD) loss functions used as knowledge distillation objectives struggle to capture the parts of the teacher's distribution that carry the most semantic knowledge. They either underfit the teacher's distribution, as with the forward KLD, or tend to overfit a part of the high-probability region, as with the reverse KLD. To address this issue, we design a Distribution-Adaptive Clipping Kullback-Leibler (DAC-KL) loss function (Fig. 3) that captures the high-density semantic regions of the teacher's output probability distribution, which are easier for a limited-capacity student model to imitate.

The probability outputs of the teacher and student models at the $i^{th}$ token position are $M$-dimensional probability vectors over the vocabulary, denoted as

$$\hat{y}^t_i = p(\cdot \mid y^s_{<i}, x) = [v^t_1, v^t_2, \ldots, v^t_M] \in \mathbb{R}^M, \qquad (5)$$
$$\hat{y}^s_i = q_\theta(\cdot \mid y^s_{<i}, x) = [v^s_1, v^s_2, \ldots, v^s_M] \in \mathbb{R}^M.$$

We feed these two probability vectors into a learnable MLP sub-network $f_{\text{sub}}$ to predict the upper limit quantile $u \in [0, 1]$ and the lower limit quantile $l \in [0, u]$ of the probability distribution $\hat{y}^t_i$. We formulate this process as

$$u, l = \sigma\left(f_{\text{sub}}\left(\hat{y}^t_i \mid \mathrm{sort}(\hat{y}^t_i) \mid \hat{y}^s_i\right)\right), \qquad (6)$$

where $\sigma(\cdot)$ is the sigmoid activation, $\mathrm{sort}(\cdot)$ is the descending sort operation, $\mid$ denotes concatenation, and $l$ is clipped into the range $[0, u]$.
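To illustrate Eq. (6), the sketch below builds the concatenated input $\hat{y}^t_i \mid \mathrm{sort}(\hat{y}^t_i) \mid \hat{y}^s_i$ and maps it to $(u, l)$. The text only states that $f_{\text{sub}}$ is an MLP, so the two-layer ReLU architecture, hidden width, and random initialisation here are assumptions, as are the names `init_fsub` and `predict_clip_bounds`; how $f_{\text{sub}}$ is trained is not shown.

```python
import numpy as np


def init_fsub(vocab_size, hidden=32, seed=0):
    """Hypothetical two-layer MLP f_sub: 3M -> hidden -> 2 (random weights)."""
    rng = np.random.default_rng(seed)
    return {
        "W1": rng.normal(0.0, 0.02, (3 * vocab_size, hidden)),
        "b1": np.zeros(hidden),
        "W2": rng.normal(0.0, 0.02, (hidden, 2)),
        "b2": np.zeros(2),
    }


def predict_clip_bounds(params, teacher_probs, student_probs):
    """Eq. (6): u, l = sigmoid(f_sub(y_hat^t | sort(y_hat^t) | y_hat^s)),
    with l clipped into [0, u]."""
    sorted_t = np.sort(teacher_probs)[::-1]  # descending sort
    # '|' in Eq. (6) is concatenation: a vector of length 3M.
    feats = np.concatenate([teacher_probs, sorted_t, student_probs])
    h = np.maximum(feats @ params["W1"] + params["b1"], 0.0)  # ReLU (assumed)
    logits = h @ params["W2"] + params["b2"]
    u, l = 1.0 / (1.0 + np.exp(-logits))  # sigmoid keeps both in (0, 1)
    l = min(l, u)  # clip l into [0, u]
    return float(u), float(l)
```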

The predicted quantiles $u$ and $l$ are used to adaptively clip the high-density semantic classes out of the teacher's probability vector $\hat{y}^t_i$. We combine the clipped high-density classes with the target class, i.e., the class with the highest probability, to construct a new probability vector $\hat{y}^{t*}_i$, formulated as

$$\hat{y}^{t*}_i = \left[\, \{v^t_m \mid l \le v^t_m \le u\}_{1 \le m \le M} \mid \max(v^t_1, v^t_2, \ldots, v^t_M) \,\right]. \qquad (7)$$
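Eq. (7) can be read as a mask over the vocabulary. In this sketch (hypothetical name `clip_teacher_distribution`), $u$ and $l$ are compared directly with probability values, as Eq. (7) is written; if they are meant as quantile levels instead, they would first need converting into probability thresholds. We also assume the arg-max class is kept only once when it already lies inside $[l, u]$.

```python
import numpy as np


def clip_teacher_distribution(teacher_probs, u, l):
    """Eq. (7): keep teacher entries with l <= v_m <= u, plus the arg-max
    (target) entry. Returns the kept vocabulary indices and their values."""
    in_band = np.flatnonzero((teacher_probs >= l) & (teacher_probs <= u))
    target = int(np.argmax(teacher_probs))
    # union1d returns sorted, de-duplicated indices.
    idx = np.union1d(in_band, [target])
    return idx, teacher_probs[idx]
```

For a teacher row `[0.5, 0.3, 0.15, 0.05]` with `u=0.35, l=0.1`, the band keeps indices 1 and 2 and the target class adds index 0, giving indices `[0, 1, 2]`.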

The high-density classes and the target class contain most of the knowledge in the teacher's probability distribution. Using the vocabulary positions of the clipped classes and the target class in $\hat{y}^{t*}_i$, we gather the corresponding entries of the student's output to construct its new probability vector $\hat{y}^{s*}_i$. We then apply the vanilla KLD to each token, average the token-wise losses over the sequence, and take the expectation over prompts $x$ drawn from the dataset $(X, Y)$:

$$\mathcal{L}_{\text{DAC-KLD}} = \mathbb{E}_{x \sim X}\left[\frac{1}{L_{y^{s*}}} \sum_{i=1}^{L_{y^{s*}}} \hat{y}^{t*}_i \log \frac{\hat{y}^{t*}_i}{\hat{y}^{s*}_i}\right], \qquad (8)$$

where $L_{y^{s*}}$ is the length of the sequence generated by the proposed SCRG strategy.
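The token-level objective can be illustrated with a minimal Python sketch of Eq. (8) for one sequence. It rests on stated assumptions: the learnable sub-network that extracts the semantically dense region is replaced by a fixed, hypothetical probability threshold `tau`; the kept positions are the teacher's classes at or above `tau` plus the target class; and both the teacher and student probabilities gathered at those positions are renormalized to sum to one (the text does not specify renormalization). All function names are illustrative, not the authors' code.

```python
import math

def clip_indices(teacher_probs, target, tau):
    """Kept positions: the teacher's high-density classes plus the target class.
    A fixed threshold tau is a hypothetical stand-in for the learnable sub-network."""
    kept = {k for k, p in enumerate(teacher_probs) if p >= tau}
    kept.add(target)
    return sorted(kept)

def gather_and_renormalize(probs, indices):
    # Keep only the selected classes and rescale them to sum to one (assumption).
    sub = [probs[k] for k in indices]
    total = sum(sub)
    return [p / total for p in sub]

def kl(p, q, eps=1e-12):
    # Vanilla KL(p || q); classes with p_k = 0 contribute nothing.
    return sum(pk * math.log(pk / max(qk, eps)) for pk, qk in zip(p, q) if pk > 0)

def dac_kl_sequence_loss(teacher_seq, student_seq, targets, tau=0.05):
    """Eq. (8) for one sequence: mean over tokens of KL(y_t* || y_s*)."""
    losses = []
    for y_t, y_s, tgt in zip(teacher_seq, student_seq, targets):
        idx = clip_indices(y_t, tgt, tau)
        y_t_star = gather_and_renormalize(y_t, idx)
        y_s_star = gather_and_renormalize(y_s, idx)
        losses.append(kl(y_t_star, y_s_star))
    return sum(losses) / len(losses)  # 1 / L_{y^{s*}} normalization

teacher = [[0.7, 0.2, 0.05, 0.05]]
student = [[0.4, 0.4, 0.1, 0.1]]
print(dac_kl_sequence_loss(teacher, student, targets=[0], tau=0.1))
```

The expectation over $x \sim X$ in Eq. (8) would correspond to averaging this per-sequence value over a batch or dataset.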

Figure 3: The workflow of the DAC-KL loss function.

4.3 Span-level correlation consistency

Motivated by liu2022multi , we use the pre-defined chunker kiss2006unsupervised to extract spans with complete meanings (noun phrases, verb phrases, and prepositional phrases) from the input sequences, splitting each sequence into several token sets. For each token in the input sequence, an LLM predicts a high-dimensional probability vector, and tokens within the same span should preserve their relations in this probability space. Constraining the student's intra-span relations to be consistent with the teacher's is therefore crucial for transferring semantic knowledge, as shown in Fig. 4.

We divide a probability sequence $[\hat{y}_{1}, \hat{y}_{2}, \dots, \hat{y}_{n}]$ into $n_s$ spans $s = [s_{1}, s_{2}, \dots, s_{n_s}]$ according to the span priors pre-defined on the token sequence $[y_{1}, y_{2}, \dots, y_{n}]$.
Here, $s_{i} = [\hat{y}_{j}, \hat{y}_{j+1}, \dots, \hat{y}_{j+n_{s_i}-1}]$ denotes the $i^{th}$ span, which starts at the $j^{th}$ token of the sequence and contains $n_{s_i}$ tokens. The student and teacher outputs follow the same span priors, so we divide both models' probability outputs into spans in the same way, denoting the $i^{th}$ span as

$$s^{s}_{i} = [\hat{y}^{s}_{j}, \hat{y}^{s}_{j+1}, \dots, \hat{y}^{s}_{j+n_{s_i}-1}], \quad s^{t}_{i} = [\hat{y}^{t}_{j}, \hat{y}^{t}_{j+1}, \dots, \hat{y}^{t}_{j+n_{s_i}-1}]. \qquad (9)$$

Next, we compute the correlation between each pair of adjacent tokens within a span and require this correlation to be consistent between the probability outputs of the student and the teacher, measuring the discrepancy with the L2 distance. The span consistency loss is defined as follows:

$$\mathcal{L}_{\text{span}} = E_{x\sim X}\left[\frac{1}{n_s}\sum_{i=1}^{n_s}\frac{1}{n_{s_i}}\sum_{(\hat{y}^{s}_{j},\hat{y}^{s}_{j+1})\in s^{s}_{i},\,(\hat{y}^{t}_{j},\hat{y}^{t}_{j+1})\in s^{t}_{i}}\left\|\hat{y}^{s}_{j}\circ\hat{y}^{s}_{j+1}-\hat{y}^{t}_{j}\circ\hat{y}^{t}_{j+1}\right\|_{2}\right], \qquad (10)$$

where $\|\cdot\|_{2}$ denotes the L2 distance and $\circ$ denotes the Hadamard (element-wise) product, which we use to compute the correlation between two tokens in the high-dimensional probability space. Note that the output sequence here is also generated by the student with the SCRG strategy; for simplicity, we write $\hat{y}^{t}_{j}$ and $\hat{y}^{s}_{j}$ instead of $\hat{y}^{t*}_{j}$ and $\hat{y}^{s*}_{j}$.
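The span-level objective can be made concrete with a minimal Python sketch of Eq. (9) and Eq. (10) for a single sequence. It assumes the span boundaries arrive as a list of span lengths from an external chunker (hard-coded here for illustration). For every adjacent pair inside a span it takes the Hadamard product of the two probability vectors, measures the L2 distance between the student's and the teacher's products, divides the per-span sum by $n_{s_i}$ as written in Eq. (10), and averages over spans. A single-token span has no adjacent pair and therefore contributes zero; this handling, like the function names, is our assumption.

```python
import math

def split_into_spans(seq, span_lengths):
    """Eq. (9): cut a per-token sequence into consecutive spans.
    span_lengths is assumed to come from an external chunker."""
    assert sum(span_lengths) == len(seq), "spans must cover the sequence"
    spans, start = [], 0
    for n in span_lengths:
        spans.append(seq[start:start + n])
        start += n
    return spans

def hadamard(u, v):
    # Element-wise product of two probability vectors (token-pair correlation).
    return [a * b for a, b in zip(u, v)]

def l2(u, v):
    # Euclidean distance between two vectors.
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def span_loss(student_seq, teacher_seq, span_lengths):
    """Eq. (10) for one sequence; the expectation over x would be a batch mean."""
    s_spans = split_into_spans(student_seq, span_lengths)
    t_spans = split_into_spans(teacher_seq, span_lengths)
    total = 0.0
    for s_span, t_span in zip(s_spans, t_spans):
        n_si = len(s_span)
        pair_sum = 0.0
        for j in range(n_si - 1):  # adjacent pairs (j, j+1) inside the span
            corr_s = hadamard(s_span[j], s_span[j + 1])
            corr_t = hadamard(t_span[j], t_span[j + 1])
            pair_sum += l2(corr_s, corr_t)
        total += pair_sum / n_si  # 1 / n_{s_i}
    return total / len(s_spans)   # 1 / n_s

student = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
teacher = [[0.5, 0.5], [0.5, 0.5], [0.0, 1.0]]
print(span_loss(student, teacher, span_lengths=[2, 1]))
```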

Figure 4: The workflow of the span-level correlation distillation. \circ denotes Hadamard multiplication.

4.4 Overall Optimization

We apply the proposed KD method during the SFT stage of the base models. The student model is trained with the distillation losses, computed against the fine-tuned teacher model, together with the SFT loss. The overall optimization objective for the student model is formulated as

$$\mathcal{L}_{\text{overall}} = \mathcal{L}_{\text{SFT}} + \mathcal{L}_{\text{DAC-KLD}} + \mathcal{L}_{\text{span}}, \qquad (11)$$

where $\mathcal{L}_{\text{SFT}}$ is the SFT loss, $\mathcal{L}_{\text{DAC-KLD}}$ is the distillation loss based on the DAC-KLD objective, and $\mathcal{L}_{\text{span}}$ is the span consistency loss that assists the distillation process.

5 Experiments

In this section, we first fine-tune a large model on a dataset of instructions and corresponding responses $(X, Y)$ and use it as the teacher model $p$. We then compare various knowledge distillation methods for distilling a smaller student model under the teacher's guidance and evaluate the instruction-following performance of the distilled models.

5.1 Experimental description

Dataset and evaluation metrics. We conduct the KD experiments on five instruction-following datasets: (1) Dolly Evaluation dolly2023introducing is a 500-sample subset of the databricks-dolly-15k (Dolly) dataset (https://github.com/databrickslabs/dolly/tree/master), covering behavioural categories such as brainstorming, classification, closed QA, generation, information extraction, open QA, and summarization; (2) Self-Instruct wang2022self evaluates language models' ability to understand and follow instructions and comprises 252 expert-written tasks; (3) Vicuna wang2022super consists of the 80 challenging questions used to evaluate the Vicuna model, and we follow the evaluation methodology of MiniLLM gu2023minillm ; (4) Super-Natural Instruction wang2022super is a benchmark of 1,616 diverse NLP tasks with expert-written instructions covering 76 task types, and its test set consists of 9K samples from 119 tasks; (5) Unnatural Instruction honovich2022unnatural comprises 240K AI-generated instructions collected with minimal human involvement, with a core set of 60K samples. Its authors show that AI-generated data can be as effective as human-written data for training language models.

We evaluate the model-generated responses with the ROUGE-L lin2004rouge metric and, for all test datasets, report the average score of 5 generations per prompt with different random seeds (10, 20, 30, 40, 50). ROUGE-L measures the overlap between the generated text and the reference text through their longest common subsequence (LCS). Because the LCS rewards in-order matches without requiring them to be consecutive, ROUGE-L captures both sentence-level structure and content, which makes it well suited for large-scale instruction-following evaluation.
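To make the metric concrete, below is a minimal Python sketch of an LCS-based ROUGE-L F-measure and of averaging it over the seeded generations for one prompt. Lower-casing, whitespace tokenization, and the balanced F1 form are our assumptions; established ROUGE implementations differ in tokenization and stemming, so this sketch will not exactly reproduce the reported numbers, which appear to be on a 0-100 scale.

```python
def lcs_length(a, b):
    # Classic dynamic program over two token lists, keeping one row at a time.
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]

def rouge_l_f1(candidate, reference):
    """ROUGE-L F1 from LCS-based precision and recall (whitespace tokens, assumption)."""
    cand = candidate.lower().split()
    ref = reference.lower().split()
    if not cand or not ref:
        return 0.0
    lcs = lcs_length(cand, ref)
    if lcs == 0:
        return 0.0
    precision = lcs / len(cand)
    recall = lcs / len(ref)
    return 2 * precision * recall / (precision + recall)

def average_over_seeds(generations, reference):
    # One generation per random seed (e.g., seeds 10, 20, 30, 40, 50).
    return sum(rouge_l_f1(g, reference) for g in generations) / len(generations)

print(rouge_l_f1("the cat sat on the mat", "the cat is on the mat"))
```

In the example, the LCS is "the cat on the mat" (5 of 6 tokens on each side), so precision, recall, and F1 all equal 5/6.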

Base models and baselines. We distill four teacher-student model pairs of different sizes: LLAMA2 touvron2023llama (13B teacher, 7B student), OpenLLAMA2 geng2023openllama (7B teacher, 3B student), OPT zhang2022opt (6.7B teacher, 1.3B student), and GPT2 radford2019language (1.5B teacher, 0.1B student).

We benchmark our method against several advanced knowledge distillation methods: (1) SFT fine-tunes the student model on a fixed dataset in a vanilla manner; (2) KD hinton2015distilling uses KLD on a fixed dataset; (3) SeqKD kim2016sequence fine-tunes on a teacher-generated dataset; (4) ImitKD lin2020autoregressive uses KLD on a dataset generated by the student model; (5) GKD agarwal2024policy uses the Jensen-Shannon divergence (JSD) agarwal2024policy on a mixture of a student-generated dataset and a fixed dataset; (6) MiniLLM gu2023minillm uses a policy gradient approach on a dataset generated by the student model; (7) DistiLLM ko2024distillm uses the skew KLD on a student-generated dataset sampled with an off-policy scheduler.
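For reference, the token-level divergences behind several of these baselines (and the loss functions compared in the ablation of Table 3(c)) can be sketched in Python for a teacher distribution `p` and a student distribution `q` over the same vocabulary. The mixing weights (`beta` for the generalized JSD, `alpha` for the skew variants) and their default values are illustrative assumptions; the exact forms and hyper-parameters used in the original codebases may differ.

```python
import math

def kl(p, q):
    # KL(p || q); terms with p_k = 0 contribute nothing.
    return sum(pk * math.log(pk / qk) for pk, qk in zip(p, q) if pk > 0)

def mix(p, q, w):
    # Convex combination w * p + (1 - w) * q.
    return [w * pk + (1 - w) * qk for pk, qk in zip(p, q)]

def forward_kl(p, q):
    return kl(p, q)

def reverse_kl(p, q):
    return kl(q, p)

def generalized_jsd(p, q, beta=0.5):
    # beta * KL(p || m) + (1 - beta) * KL(q || m), with m = beta * p + (1 - beta) * q.
    m = mix(p, q, beta)
    return beta * kl(p, m) + (1 - beta) * kl(q, m)

def skew_kl(p, q, alpha=0.1):
    # KL(p || alpha * p + (1 - alpha) * q); mixing keeps the log-ratio finite.
    return kl(p, mix(p, q, alpha))

def skew_reverse_kl(p, q, alpha=0.1):
    # KL(q || alpha * q + (1 - alpha) * p).
    return kl(q, mix(q, p, alpha))

def tvd(p, q):
    # Total variation distance.
    return 0.5 * sum(abs(pk - qk) for pk, qk in zip(p, q))

p, q = [0.5, 0.5], [0.9, 0.1]
for f in (forward_kl, reverse_kl, generalized_jsd, skew_kl, skew_reverse_kl, tvd):
    print(f.__name__, round(f(p, q), 4))
```

With `alpha=0` the skew variants reduce to the forward and reverse KL, and with `beta=0.5` the generalized JSD is symmetric in `p` and `q`.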

All baseline experiments are re-implemented with the open-source code at https://github.com/jongwooko/distillm on the same GPU servers used for our method, using the exact hyper-parameters specified in the original codebase.

Training details. Following MiniLLM gu2023minillm , we fine-tune the base models on the training set of databricks-dolly-15k. Dolly is divided into 14K training samples, with 500 samples each held out for validation and testing. After fine-tuning, we select the best-performing model on the Dolly validation set and then evaluate it on the test sets of the five datasets described above.

For training the teacher and student models, we use four A100 (40GB) GPUs for the OpenLLAMA2, OPT, and GPT2 models and four A800 (80GB) GPUs for the LLAMA2 models. A fixed learning rate of 5e-4 is used in all experiments. For the LLAMA2, OpenLLAMA2, and OPT models, we follow DistiLLM ko2024distillm and apply low-rank adaptation (LoRA) with rank 16 to the query and value weights, training for 10 epochs. For the GPT2 models, we fine-tune all parameters for 20 epochs.
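As a worked example of what rank-16 LoRA on the query and value weights implies, the Python sketch below counts the added trainable parameters: each adapted d×d projection gains two low-rank factors of sizes r×d and d×r. The hidden sizes and layer counts used here (4096 and 32 for LLaMA-2-7B, 2048 and 24 for OPT-1.3B) come from the publicly released model configurations, not from this paper, and we assume square projection matrices with no LoRA bias terms.

```python
def lora_trainable_params(hidden_size, num_layers, rank, num_adapted=2):
    """Extra parameters from LoRA on `num_adapted` square projections per layer
    (query and value here): factors A (r x d) and B (d x r) for each."""
    per_matrix = rank * hidden_size + hidden_size * rank
    return num_adapted * per_matrix * num_layers

# Model dimensions are assumptions taken from public configs.
for name, d, layers in [("LLaMA-2-7B", 4096, 32), ("OPT-1.3B", 2048, 24)]:
    n = lora_trainable_params(d, layers, rank=16)
    print(f"{name}: {n:,} trainable LoRA parameters")
```

Under these assumptions the 7B student trains about 8.4M adapter parameters, a small fraction (well under 1%) of the full model.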

Table 1: Comparison of state-of-the-art knowledge distillation methods evaluated by the ROUGE-L metric lin2004rouge . 'Average' is the average score on the five test datasets. The bold and underlined markings signify the best and second-best results, respectively.
| Methods | Parameters | Dolly Evaluation | Self-Instruct | Vicuna | Super-Natural | Unnatural | Average |
|---|---|---|---|---|---|---|---|
| LLAMA2 Teacher (SFT) | 13B | 29.8241 | 21.0617 | 19.4909 | 35.8318 | 35.7802 | 28.3978 |
| SFT | 7B | 27.3504 | 28.4430 | 18.7567 | 28.4430 | 30.2788 | 26.6544 |
| KD hinton2015distilling | | 27.0737 | 20.7076 | 17.9850 | 30.3350 | 31.4926 | 25.5188 |
| SeqKD kim2016sequence | | 26.2689 | 19.0278 | 18.4602 | 25.9461 | 28.1010 | 23.5608 |
| ImitKD lin2020autoregressive | | 27.4359 | 20.6792 | 18.8058 | 29.1726 | 30.5764 | 25.3340 |
| GKD agarwal2024policy | | 28.4662 | 22.1717 | 20.7564 | 33.3325 | 33.2682 | 27.5990 |
| MiniLLM gu2023minillm | | 30.6447 | 23.9493 | 22.3010 | 34.3454 | 36.0828 | 29.4646 |
| DistiLLM ko2024distillm | | 30.7277 | 25.2181 | 20.8356 | 36.1154 | 37.5072 | 30.0808 |
| Ours | | 31.9195 | 25.4937 | 21.7810 | 37.9154 | 38.1257 | 31.0471 |
| OpenLLAMA2 Teacher (SFT) | 7B | 27.5100 | 17.9400 | 17.6900 | 32.7500 | 31.4000 | 25.4580 |
| SFT | 3B | 24.4000 | 16.1300 | 16.5600 | 27.4862 | 28.0500 | 22.5252 |
| KD hinton2015distilling | | 25.4814 | 19.1805 | 16.6562 | 31.3307 | 31.8136 | 24.8924 |
| SeqKD kim2016sequence | | 24.8184 | 16.0980 | 17.2718 | 29.4081 | 28.7395 | 23.2672 |
| ImitKD lin2020autoregressive | | 25.3600 | 18.1600 | 17.5700 | 31.0900 | 28.9600 | 24.2280 |
| GKD agarwal2024policy | | 26.8525 | 20.1060 | 18.4337 | 34.4383 | 32.4797 | 26.4621 |
| MiniLLM gu2023minillm | | 28.4950 | 21.7770 | 20.6260 | 35.4001 | 34.7011 | 28.1999 |
| DistiLLM ko2024distillm | | 27.8546 | 19.3456 | 19.1723 | 34.4973 | 34.9434 | 27.1627 |
| Ours | | 29.3062 | 20.5835 | 19.0086 | 37.6171 | 37.2410 | 28.8724 |
| OPT Teacher (SFT) | 6.7B | 25.8758 | 14.8408 | 16.4199 | 24.9551 | 25.8377 | 21.5859 |
| SFT | 1.3B | 22.7595 | 11.9784 | 15.2267 | 22.8556 | 24.5763 | 19.4793 |
| KD hinton2015distilling | | 22.4476 | 13.4676 | 13.9975 | 23.7679 | 25.4132 | 19.8188 |
| SeqKD kim2016sequence | | 22.4556 | 12.1588 | 14.8157 | 21.4574 | 24.5907 | 19.0956 |
| ImitKD lin2020autoregressive | | 21.6624 | 12.9286 | 15.8039 | 22.0426 | 24.9619 | 19.4799 |
| GKD agarwal2024policy | | 22.5062 | 12.8309 | 15.3303 | 23.8537 | 26.6441 | 20.2330 |
| MiniLLM gu2023minillm | | 24.3168 | 13.5880 | 17.4633 | 26.6789 | 28.7968 | 22.1688 |
| DistiLLM ko2024distillm | | 24.7311 | 14.9932 | 16.3293 | 27.1037 | 29.3285 | 22.4972 |
| Ours | | 27.1486 | 17.3016 | 14.8491 | 32.0618 | 34.9709 | 25.2664 |
| GPT2 Teacher (SFT) | 1.5B | 27.0357 | 14.5594 | 16.7390 | 24.9659 | 29.4874 | 22.5575 |
| SFT | 0.1B | 23.8269 | 9.6682 | 14.9022 | 16.4117 | 18.3221 | 16.6262 |
| KD hinton2015distilling | | 23.2172 | 10.0899 | 14.9954 | 15.4826 | 18.9597 | 16.5490 |
| SeqKD kim2016sequence | | 23.7248 | 10.3935 | 14.6558 | 19.8119 | 22.7425 | 18.2657 |
| ImitKD lin2020autoregressive | | 21.7724 | 10.1876 | 15.4640 | 17.1918 | 20.8907 | 17.1013 |
| GKD agarwal2024policy | | 23.3150 | 10.3364 | 15.9384 | 16.0802 | 17.7699 | 16.6880 |
| MiniLLM gu2023minillm | | 23.8142 | 12.2771 | 17.0158 | 23.8555 | 24.9101 | 20.3745 |
| DistiLLM ko2024distillm | | 25.6114 | 12.5988 | 16.7521 | 24.6374 | 27.5827 | 21.4365 |
| Ours | | 26.5614 | 13.1174 | 17.6781 | 24.6973 | 27.4025 | 21.8913 |

5.2 Comparison with state-of-the-art KD methods

We present the quantitative comparison of state-of-the-art knowledge distillation methods evaluated using the ROUGE-L metric in Table 1. It is observed that:

(1) Our method achieves the best average score across the five test datasets for all four model families (LLAMA2, OpenLLAMA2, OPT, and GPT2), and the best result on most individual test datasets. In particular, for the OPT models, our method improves the average score by over 12% relative to the second-best method.

(2) KD methods that train on student-generated datasets, such as GKD, MiniLLM, and DistiLLM, improve student performance more than those that use the fixed dataset. Moreover, in several settings (e.g., LLAMA2, OpenLLAMA2, and OPT) the distilled students even outperform their teachers, which may be attributed to the mismatch between teacher-forcing training and free-run generation, i.e., exposure bias bengio2015scheduled . Relative to the SFT baseline, our method improves the average score over the five test datasets by at least 15% for every student model.

(3) We also provide some representative instruction-following cases in Section A.3, further highlighting the effectiveness and superiority of our method in achieving high-quality answers.
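The relative gains quoted in observations (1) and (2) can be checked directly against the 'Average' column of Table 1. The short Python snippet below recomputes them from values copied from the table; the helper name is ours.

```python
# 'Average' column of Table 1: (ours, SFT student, second-best method)
averages = {
    "LLAMA2":     (31.0471, 26.6544, 30.0808),  # second best: DistiLLM
    "OpenLLAMA2": (28.8724, 22.5252, 28.1999),  # second best: MiniLLM
    "OPT":        (25.2664, 19.4793, 22.4972),  # second best: DistiLLM
    "GPT2":       (21.8913, 16.6262, 21.4365),  # second best: DistiLLM
}

def relative_gain(new, old):
    # Percentage improvement of `new` over `old`.
    return 100.0 * (new - old) / old

for family, (ours, sft, second) in averages.items():
    print(f"{family}: +{relative_gain(ours, sft):.1f}% vs SFT, "
          f"+{relative_gain(ours, second):.1f}% vs second best")
```

The smallest gain over SFT is for LLAMA2 (about 16.5%), and the OPT gain over the second-best method is about 12.3%.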

5.3 Ablation analysis

We conduct an ablation analysis of the proposed methods on the Dolly Validation set, Dolly Evaluation set and Self-Instruct dataset.

Overall Ablation. We conduct an overall ablation study in Table 2 to validate the effectiveness of the proposed multi-granularity semantic revision. Sequence correction alone yields a moderate improvement over the vanilla result on all evaluation datasets. Adding DAC-KL brings a further improvement, and including span-level relation distillation yields more notable gains. Combining all components gives the highest scores on all evaluation datasets, indicating that each component contributes to the overall performance.

Different student-generation methods. To validate the effectiveness of the proposed SCRG strategy, we compare it with different student-generation methods for sampling the distillation dataset. As shown in Table 3(a), SCRG consistently improves over existing student-generation methods. For on-policy sampling, we follow GKD agarwal2024policy and use a mixture of student-generated and fixed datasets. For off-policy sampling, we follow DistiLLM ko2024distillm and adopt an adaptive student-generation schedule for improved sample efficiency. Combined with either off-policy or on-policy sampling, SCRG achieves higher scores than all baseline generation methods on all three evaluation sets. This suggests that SCRG improves performance by raising the quality and diversity of the generated sequences. Additionally, we provide an example of SCRG in Section A.4.

Different distillation loss functions. To validate the effectiveness of the proposed DAC-KL loss, we compare it with different loss functions in Table 3(c). DAC-KL achieves the best scores on all three evaluation sets, suggesting that it effectively captures the high-density semantic regions of the teacher's output probability distribution, which are easier for the student models to imitate. Additionally, Section A.2 provides visualized examples of the impact of DAC-KL on the teacher's output probability distribution, depicted using kernel density estimation.

Different components involved in DAC-KL. The DAC-KL loss guides the distillation process to effectively transfer knowledge from the high-density semantic classes and the target class of the teacher’s probability outputs. As evidenced by the results in Table 3(b), when both high-density and target classes are considered, the DAC-KL loss achieves the highest validation, evaluation, and self-instruct scores compared to other configurations. This indicates that focusing on these specific classes leads to better performance in knowledge distillation, highlighting the importance of targeting relevant semantic regions for the effective transfer of knowledge.

Table 2: Ablation study of the proposed multi-granularity semantic revision.
Sequence-correcting DAC-KL Span Relation Dolly Validation Dolly Evaluation Self-Instruct
29.1874 24.1603 14.8578
29.6982 24.5307 14.9485
30.3486 26.9012 17.2392
31.2575 27.1486 17.3016
Table 3: Ablation studies on the proposed SCRG strategy and the DAC-KL loss.
(a) Different student-generation methods
Generation Validation Evaluation Self-Instruct
On-policy lin2020autoregressive 30.3786 26.0948 16.1853
Mixed agarwal2024policy 30.8335 26.4667 16.7789
Off-policy ko2024distillm 30.4539 27.0961 16.7745
SCRG w. off policy 31.0444 27.1453 17.2574
SCRG w. on policy 31.2575 27.1486 17.3016
(b) Components involved in DAC-KL losses
High-density Target Validation Evaluation Self-Instruct
29.3490 24.3130 14.3810
21.3936 19.5050 11.5035
31.2575 27.1486 17.3016
(c) Different distillation loss functions
Loss Function Validation Evaluation Self-Instruct
Forward KL 28.9631 24.1922 14.5108
Reverse KL 30.0209 25.6688 14.7184
Symmetric KL 30.2584 27.0961 16.7745
Generalized JSD agarwal2024policy 27.8759 23.3144 14.3154
TVD wen2023f 30.1973 25.0033 14.6138
SRKL ko2024distillm 29.9858 25.4849 14.9514
SFKL ko2024distillm 29.1226 25.1400 14.4412
DAC-KL 31.2575 27.1486 17.3016

6 Conclusion and Limitation

In this paper, we address challenges in knowledge distillation for LLMs by proposing a multi-granularity semantic revision approach at the sequence, token, and span levels. At the sequence level, our sequence correction and re-generation strategy improves the reliability and diversity of student-generated sequences. At the token level, the DAC-KL loss function targets semantically salient regions of the teacher's probability distribution, filtering out redundant information. At the span level, input span priors ensure the consistent transfer of semantic information across related tokens. Experiments on four model families demonstrate the effectiveness of our approach, which improves student model performance over existing KD methods. As a limitation, our experiments and evaluations were conducted primarily on language models in specific domains; the effectiveness of our approach in other domains or tasks may vary, and further research is needed to explore its generalizability.

References

  • [1] Jared Kaplan, Sam McCandlish, Tom Henighan, Tom B Brown, Benjamin Chess, Rewon Child, Scott Gray, Alec Radford, Jeffrey Wu, and Dario Amodei. Scaling laws for neural language models. arXiv preprint arXiv:2001.08361, 2020.
  • [2] Jason Wei, Yi Tay, Rishi Bommasani, Colin Raffel, Barret Zoph, Sebastian Borgeaud, Dani Yogatama, Maarten Bosma, Denny Zhou, Donald Metzler, Ed H. Chi, Tatsunori Hashimoto, Oriol Vinyals, Percy Liang, Jeff Dean, and William Fedus. Emergent abilities of large language models. Jun 2022.
  • [3] Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, and Ilya Sutskever. Language models are unsupervised multitask learners.
  • [4] Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen, Christopher Dewan, Mona Diab, Xian Li, Victoria Lin, Todor Mihaylov, Myle Ott, Sam Shleifer, Kurt Shuster, Daniel Simig, Singh Koura, Anjali Sridhar, Tianlu Wang, and Luke Zettlemoyer. Opt: Open pre-trained transformer language models.
  • [5] Tom B. Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, Sandhini Agarwal, Ariel Herbert-Voss, Gretchen Krueger, Thomas Henighan, Rewon Child, Aditya Ramesh, Daniel M. Ziegler, Jeffrey Wu, Clemens Winter, Christopher Hesse, Mark Chen, Eric Sigler, Mateusz Litwin, Scott Gray, Benjamin Chess, Jack Clark, Christopher Berner, Samuel McCandlish, Alec Radford, Ilya Sutskever, and Dario Amodei. Language models are few-shot learners. arXiv: Computation and Language, May 2020.
  • [6] Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix, Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, Aurelien Rodriguez, Armand Joulin, Edouard Grave, and Guillaume Lample. Llama: Open and efficient foundation language models.
  • [7] Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen, Christopher Dewan, Mona Diab, Xian Li, Xi Victoria Lin, et al. Opt: Open pre-trained transformer language models. arXiv preprint arXiv:2205.01068, 2022.
  • [8] Yizhong Wang, Yeganeh Kordi, Swaroop Mishra, Alisa Liu, Noah A Smith, Daniel Khashabi, and Hannaneh Hajishirzi. Self-instruct: Aligning language models with self-generated instructions. arXiv preprint arXiv:2212.10560, 2022.
  • [9] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015.
  • [10] Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos Garea, Matthieu Geist, and Olivier Bachem. On-policy distillation of language models: Learning from self-generated mistakes. In The Twelfth International Conference on Learning Representations, 2024.
  • [11] Yuxian Gu, Li Dong, Furu Wei, and Minlie Huang. Minillm: Knowledge distillation of large language models. In The Twelfth International Conference on Learning Representations, 2023.
  • [12] Jongwoo Ko, Sungnyun Kim, Tianyi Chen, and Se-Young Yun. Distillm: Towards streamlined distillation for large language models. arXiv preprint arXiv:2402.03898, 2024.
  • [13] Yoon Kim and Alexander M Rush. Sequence-level knowledge distillation. arXiv preprint arXiv:1606.07947, 2016.
  • [14] Xiaoqi Jiao, Yichun Yin, Lifeng Shang, Xin Jiang, Xiao Chen, Linlin Li, Fang Wang, and Qun Liu. Tinybert: Distilling bert for natural language understanding. arXiv preprint arXiv:1909.10351, 2019.
  • [15] Romero Adriana, Ballas Nicolas, K Samira Ebrahimi, Chassang Antoine, Gatta Carlo, and Bengio Yoshua. Fitnets: Hints for thin deep nets. Proc. ICLR, 2(3):1, 2015.
  • [16] Kevin J Liang, Weituo Hao, Dinghan Shen, Yufan Zhou, Weizhu Chen, Changyou Chen, and Lawrence Carin. Mixkd: Towards efficient distillation of large-scale language models. arXiv preprint arXiv:2011.00593, 2020.
  • [17] Victor Sanh, Lysandre Debut, Julien Chaumond, and Thomas Wolf. Distilbert, a distilled version of bert: smaller, faster, cheaper and lighter. arXiv preprint arXiv:1910.01108, 2019.
  • [18] Chen Liang, Simiao Zuo, Qingru Zhang, Pengcheng He, Weizhu Chen, and Tuo Zhao. Less is more: Task-aware layer-wise distillation for language model compression. In International Conference on Machine Learning, pages 20852--20867. PMLR, 2023.
  • [19] Siqi Sun, Yu Cheng, Zhe Gan, and Jingjing Liu. Patient knowledge distillation for bert model compression. arXiv preprint arXiv:1908.09355, 2019.
  • [20] Chang Liu, Chongyang Tao, Jiazhan Feng, and Dongyan Zhao. Multi-granularity structural knowledge distillation for language model compression. In Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pages 1001--1011, 2022.
  • [21] Long Ouyang, Jeffrey Wu, Xu Jiang, Diogo Almeida, Carroll Wainwright, Pamela Mishkin, Chong Zhang, Sandhini Agarwal, Katarina Slama, Alex Ray, et al. Training language models to follow instructions with human feedback. Advances in neural information processing systems, 35:27730--27744, 2022.
  • [22] Josh Achiam, Steven Adler, Sandhini Agarwal, Lama Ahmad, Ilge Akkaya, Florencia Leoni Aleman, Diogo Almeida, Janko Altenschmidt, Sam Altman, Shyamal Anadkat, et al. Gpt-4 technical report. arXiv preprint arXiv:2303.08774, 2023.
  • [23] Hongzhan Chen, Xiaojun Quan, Hehong Chen, Ming Yan, and Ji Zhang. Knowledge distillation for closed-source language models. arXiv preprint arXiv:2401.07013, 2024.
  • [24] Yuxin Jiang, Chunkit Chan, Mingyang Chen, and Wei Wang. Lion: Adversarial distillation of closed-source large language model. arXiv preprint arXiv:2305.12870, 2023.
  • [25] Cheng-Yu Hsieh, Chun-Liang Li, Chih-Kuan Yeh, Hootan Nakhost, Yasuhisa Fujii, Alexander Ratner, Ranjay Krishna, Chen-Yu Lee, and Tomas Pfister. Distilling step-by-step! outperforming larger language models with less training data and smaller model sizes. arXiv preprint arXiv:2305.02301, 2023.
  • [26] Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li, Carlos Guestrin, Percy Liang, and Tatsunori B Hashimoto. Stanford alpaca: An instruction-following llama model, 2023.
  • [27] Wei-Lin Chiang, Zhuohan Li, Zi Lin, Ying Sheng, Zhanghao Wu, Hao Zhang, Lianmin Zheng, Siyuan Zhuang, Yonghao Zhuang, Joseph E Gonzalez, et al. Vicuna: An open-source chatbot impressing gpt-4 with 90%* chatgpt quality. See https://vicuna.lmsys.org (accessed 14 April 2023), 2(3):6, 2023.
  • [28] Baolin Peng, Chunyuan Li, Pengcheng He, Michel Galley, and Jianfeng Gao. Instruction tuning with gpt-4. arXiv preprint arXiv:2304.03277, 2023.
  • [29] Tibor Kiss and Jan Strunk. Unsupervised multilingual sentence boundary detection. Computational linguistics, 32(4):485--525, 2006.
  • [30] Free Dolly. Introducing the world's first truly open instruction-tuned llm. databricks.com, 2023.
  • [31] Yizhong Wang, Swaroop Mishra, Pegah Alipoormolabashi, Yeganeh Kordi, Amirreza Mirzaei, Anjana Arunkumar, Arjun Ashok, Arut Selvan Dhanasekaran, Atharva Naik, David Stap, et al. Super-naturalinstructions: Generalization via declarative instructions on 1600+ nlp tasks. arXiv preprint arXiv:2204.07705, 2022.
  • [32] Or Honovich, Thomas Scialom, Omer Levy, and Timo Schick. Unnatural instructions: Tuning language models with (almost) no human labor. arXiv preprint arXiv:2212.09689, 2022.
  • [33] Chin-Yew Lin. ROUGE: A package for automatic evaluation of summaries. In Text Summarization Branches Out, pages 74–81, 2004.
  • [34] Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, et al. Llama 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288, 2023.
  • [35] Xinyang Geng and Hao Liu. OpenLLaMA: An open reproduction of LLaMA. URL: https://github.com/openlm-research/open_llama, 2023.
  • [36] Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, Ilya Sutskever, et al. Language models are unsupervised multitask learners. OpenAI blog, 1(8):9, 2019.
  • [37] Alexander Lin, Jeremy Wohlwend, Howard Chen, and Tao Lei. Autoregressive knowledge distillation through imitation learning. arXiv preprint arXiv:2009.07253, 2020.
  • [38] Samy Bengio, Oriol Vinyals, Navdeep Jaitly, and Noam Shazeer. Scheduled sampling for sequence prediction with recurrent neural networks. Advances in Neural Information Processing Systems, 28, 2015.
  • [39] Yuqiao Wen, Zichao Li, Wenyu Du, and Lili Mou. f-divergence minimization for sequence-level knowledge distillation. arXiv preprint arXiv:2307.15190, 2023.

Appendix A Appendix / supplemental material

A.1 Social Impact

The primary objective of this study is to advance the field of machine learning; it does not explicitly target any specific societal consequence. Smaller models can lead to positive outcomes, such as reduced emissions, but model compression also carries potential negative impacts, and the biases associated with it warrant a comprehensive study. Compression may inadvertently exacerbate biases already present in the data, leading to unfair outcomes, particularly for underrepresented groups. Additionally, the simplification involved in compression could discard critical nuances and reduce the model's ability to handle complex tasks accurately.

A.2 Visualized probability distribution of the teacher model

Figure 5: Kernel density estimates of example probability distributions of the teacher's output. The blue line shows the original distribution, and the red line shows the distribution of the adaptively clipped probability classes. The DAC-KL loss constrains the regions of the probability distribution that carry dense semantic knowledge. Forcing the student model to imitate the distribution within these regions effectively mitigates the training interference that low-semantic regions cause for student models with limited learning capacity.
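To make the clipping idea in Figure 5 concrete, here is a minimal, hypothetical sketch of a clipped KL divergence in plain Python. It assumes the semantically dense region can be described as the set of vocabulary classes whose teacher probability lies in an interval [lower, upper], and it takes that interval as fixed arguments. The DAC-KL loss instead selects the region adaptively with a learnable sub-network, which is not reproduced here; the function name `clipped_kl` and its parameters are illustrative, not the authors' code.

```python
import math
from typing import Sequence


def clipped_kl(teacher: Sequence[float], student: Sequence[float],
               lower: float, upper: float, eps: float = 1e-12) -> float:
    """KL(teacher || student) restricted to classes whose teacher probability
    lies in [lower, upper].

    Hypothetical sketch: the bounds are fixed arguments here, whereas DAC-KL
    selects the dense region with a learnable sub-network.
    """
    if len(teacher) != len(student):
        raise ValueError("teacher and student must cover the same vocabulary")
    # Retained classes are chosen from the teacher's distribution only.
    kept = [i for i, p in enumerate(teacher) if lower <= p <= upper]
    t_mass = sum(teacher[i] for i in kept)
    if not kept or t_mass <= 0:
        return 0.0
    s_mass = sum(student[i] for i in kept)
    kl = 0.0
    for i in kept:
        # Renormalise both distributions over the retained classes.
        t = teacher[i] / t_mass
        s = student[i] / s_mass if s_mass > 0 else 0.0
        if t > 0:
            kl += t * math.log(t / max(s, eps))
    return kl
```

Because only the retained classes enter the sum, changing the student's probabilities on the clipped tail leaves this loss unchanged, which is the sense in which low-semantic regions stop interfering with the student's training.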

A.3 Examples of model responses on the instruction-following task

Table 4: Instruction-following cases from the databricks-dolly-15k dataset.
Case #1, Category: classification
Instruction If we were playing a game where we had to identify things that can be found inside a house, which of these would we call out: car, chair, table, park, cloud, microwave.
Input -
Groundtruth Inside a house you can find a chair, a table and a microwave. Cars, clouds and parks are found outside of houses.
SFT car, chair, table, microwave.
KD Car, chair, table and microwave are things that can be found in a house. Park and cloud are unlikely to be found in a house but are things that can be seen from a house.
SeqKD We would call out: chair, table, car and microwave.
ImitKD Car, chair, table and microwave are items that can be found inside of a house. Park and cloud are outdoors items.
MiniLLM Car, chair, table, and microwave can all be found inside a house. Park and cloud cannot be found inside a house because they are neither objects nor locations.
GKD You can find a car, chair, table and microwave inside a house. A park and cloud are not found inside a house.
DISTILLM In this list, we can find a car, chair, table and microwave. A park and a cloud are not things you would typically find inside a house.
Ours In this game, a car, park, and cloud would not be things that could be found inside a house. A chair, table, and microwave would be things that could be find inside a house.
Table 5: Instruction-following cases from the databricks-dolly-15k dataset.
Case #2, Category: closed QA
Instruction How many managers have Arsenal had?
Input There have been twenty permanent and eight caretaker managers of Arsenal since 1897; Stewart Houston has managed the club in two separate spells as caretaker. The most successful person to manage Arsenal is Arsène Wenger, who won three Premier League titles, seven FA Cups and seven Community Shields between 1996 and 2018. Wenger is the club's longest-serving manager; he surpassed George Allison's record of 13 years in October 2009. Two Arsenal managers have died in the job – Herbert Chapman and Tom Whittaker.
Groundtruth Arsenal have had twenty permanent and eight caretake managers.
SFT There have been thirty-seven managers since 1896.
KD There have been 28 managers of Arsenalf FC.
SeqKD Arsenal has had 20 permanent managers, 8 caretaker managers and 2 manager with dead in the job.
ImitKD There have been thirty-seven managers since 1896.
MiniLLM Since 1897, Arsenal have had twenty-eight managers. Eight of these were caretaker managers. The most successful manager is Arséne Wener, who won three Premier Leagues, seven FA Cups, and seven Community Shield trophies. He is the longest-serving manager, as he surpassed George Alison’s record of twelve years in October 2,009. Two managers have died in this job; Herbert Chapman and Tom Whitaker.
GKD There have been thirty managers of Arsenals football club since 1896. This includes twenty permanent managers and eight caretaker or interim managers.
DISTILLM There have been 20 permanent managers and 8 caretaker managers for a total of 28 managers.
Ours There have been 20 permanent and 8 caretaker managers at Arsenal.
Table 6: Instruction-following cases from the databricks-dolly-15k dataset.
Case #3, Category: brainstorming
Instruction What is a good way to get around without a car?
Input -
Groundtruth If you live in a city you can access public transport such as buses and trains. You can also walk, use a bicycle, or electric scooter to get around.
SFT When you don’t have a car, you need to find alternative ways to get around. If you are able-bodied and live in a city, your best bet would be to use public transportation. Every city has trains, buses and other modes of transportation that are accessible without a car.If you are unable-bodied or have a car that doesn’t work, you may want to consider ride-sharing services such as Uber and Lyft.
KD A few examples of ways to get around without a personal car are using public transportation such as the bus or subway, carpooling with friends or coworkers, or using ride-sharing services such as Uber or Lyft. If none of these options are viable for you, it is still possible to get around without a vehicle; for example, it is often possible to walk or ride a bike for short distances, and in some cases, it is possible to use a bicycle, scooter, or skateboard for longer distances.
SeqKD There are multiple options for getting around without a car. Biking, walking, riding a scooter, riding a bike with a scooter, using public transportation, rideshare, or a moped are all good options. If you are in a smaller town where public transportation is available, walking or riding a bike may also be a good option.
ImitKD Public transportation is a good way to avoid a need for a car. In many cities in the US, public transportation includes buses, subways, and trains that can take you where you need to go. For shorter distances, many cities and towns have bike-shares which can provide an easy way to get around without having to rely on a car.
MiniLLM There are many ways to get around without a driver’s license or car. The most common way is to walk, but you can also use a bicycle, public transportation, or ride-sharing services. If you live in a city, you can also take advantage of public transportation, which is often more affordable than owning a car.
GKD There are multiple ways to get around without a personal vehicle. Public transportation is available in many cities and towns. Many cities also have bike share programs, where you can rent a bike or low-cost electric scooter for short trips. You can also walk, but this may depend on the distance and weather.
DISTILLM There are many ways to get around without cars. You can walk, ride a bike, use public transportation, or ride share services such as Uber and Lyft.
Ours If you don’t have a car, there are still many ways to get around. You can take public transportation, which includes buses, trains, and subways. If you are near a large city, you may also have access to light rail, which is like a bus that runs on rails. If you are in a suburb, you can use public transportation such as a bus or train. If you live in a small town, you may need to take a taxi or ride share service.

A.4 Example of the sequence correction and re-generation (SCRG) strategy

Table 7: Example of a student-generated output sequence under the sequence correction and re-generation strategy. The red token marks the detected position of the error token.
Instruction: What is the difference between men’s and women’s lacrosse

Samples from student: Men’s lacrosse has a limited amount of time to play play play as as as as as as as as as as as as as as as as as as as as

Student’s Prediction: Men’s lacrosse is a of of of to play and play play as as as as as as as as as as as as as as as as as as as as

Teacher’s Prediction: Men’s lacrosse is a limited number of movesouts play each each. opposed opposed they a they opposed opposed opposed opposed opposed opposed opposed opposed opposed opposed opposed opposed they a they opposed opposed opposed opposed opposed opposed opposed opposed opposed opposed

Re-sample: Men’s lacrosse has a limited number of players and women’s lacrosse has a maximum number of players.
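The SCRG step illustrated in Table 7 can be sketched as follows. This is a hypothetical illustration, not the authors' implementation: it uses the total variation distance between the teacher's and student's next-token distributions as a stand-in for the semantic cognitive difference, flags the first position whose distance exceeds an assumed `threshold` as the error token, replaces that token with the teacher's most likely token, and passes the corrected prefix to a caller-supplied `regenerate` function that returns the continuation. All names and the threshold value are illustrative.

```python
from typing import Callable, List, Sequence

Dist = Sequence[float]  # next-token probability distribution over the vocabulary


def total_variation(p: Dist, q: Dist) -> float:
    """Total variation distance; a stand-in for the semantic cognitive difference."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


def argmax(p: Dist) -> int:
    return max(range(len(p)), key=p.__getitem__)


def correct_and_regenerate(
    tokens: List[int],
    teacher_dists: Sequence[Dist],
    student_dists: Sequence[Dist],
    regenerate: Callable[[List[int]], List[int]],
    threshold: float = 0.5,
) -> List[int]:
    """Hypothetical SCRG-style step.

    teacher_dists[i] and student_dists[i] are the two models' predictions for
    position i given tokens[:i]. The first position whose divergence exceeds
    `threshold` is treated as the error token.
    """
    for i, (t, s) in enumerate(zip(teacher_dists, student_dists)):
        if total_variation(t, s) > threshold:
            # Replace the error token with the teacher's prediction ...
            prefix = tokens[:i] + [argmax(t)]
            # ... and re-generate the rest of the sequence from that prefix.
            return prefix + regenerate(prefix)
    return list(tokens)  # no error token detected: keep the sample as-is
```

With a real model, `regenerate` would wrap the model's sampling loop, so everything after the corrected token is produced afresh, analogous to the "Re-sample" row above.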

A.5 Prompt template for the instruction-following task

Table 8: The prompt template for training and evaluation of instruction-following task experiments.
Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{input}
### Response:
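For reference, a minimal Python sketch of filling the Table 8 template for a single example. The constant `PROMPT_TEMPLATE` and the helper `build_prompt` are illustrative names, and the line breaks between sections are an assumption, since the extracted table does not preserve the exact whitespace.

```python
# Hypothetical rendering of the Table 8 template; line breaks are assumed.
PROMPT_TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n"
    "### Instruction:\n{instruction}\n"
    "### Input:\n{input}\n"
    "### Response:"
)


def build_prompt(instruction: str, input_text: str = "") -> str:
    """Substitute one example's instruction and input into the template."""
    return PROMPT_TEMPLATE.format(instruction=instruction, input=input_text)
```

For example, `build_prompt("How many managers have Arsenal had?", context)` yields a prompt ending in `### Response:`, after which the model's response is generated.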