ReMamba: Equip Mamba with Effective Long-Sequence Modeling

Danlong Yuan1 2 5, Jiahao Liu5, Bei Li5, Huishuai Zhang1 3, Jingang Wang5, Xunliang Cai5, Dongyan Zhao1 2 3 4*
*Corresponding author.
Abstract

While the Mamba architecture demonstrates superior inference efficiency and competitive performance on short-context natural language processing (NLP) tasks, empirical evidence suggests that its capacity to comprehend long contexts is limited compared to transformer-based models. In this study, we investigate the long-context deficiencies of Mamba models and propose ReMamba, which enhances Mamba’s ability to comprehend long contexts. ReMamba incorporates selective compression and adaptation techniques within a two-stage re-forward process, incurring minimal additional inference overhead. Experimental results on the LongBench and L-Eval benchmarks demonstrate ReMamba’s efficacy, improving over the baselines by 3.2 and 1.6 points, respectively, and attaining performance almost on par with same-size transformer models.

Introduction

Transformers (Vaswani et al. 2017), which form the backbone of most LLMs, face substantial challenges when dealing with long texts. The quadratic computational cost and linear memory cost of the attention mechanism become prohibitive as text length grows. This poses a significant barrier to modeling long texts effectively, a capability that is crucial for the development of LLMs. To address this, Mamba was proposed (Gu and Dao 2024). Mamba models use a recurrent inference mode with linear time complexity and compress information into a fixed-size state, resulting in constant memory demands during inference. Furthermore, Mamba models need no positional encoding, which in theory allows them to handle inputs of any length. Mamba performs competitively against transformers on downstream tasks. Shortly afterwards, Mamba2 was introduced, simplifying the structured $A$ matrix of Mamba to enable faster training and a larger state size (Dao and Gu 2024).

Figure 1: A comparison of pretrained Mamba models and Transformers of equivalent size in terms of speed, short-context performance, and long-context performance. Speed is measured with 6k input tokens and 1k output tokens. “short scores” is the average accuracy across six tasks (HellaSwag, PIQA, Arc-E, Arc-C, WinoGrande, OpenbookQA) evaluated with the LM evaluation harness (Gao et al. 2023). “long scores” is the average score on the LongBench-E benchmark (Bai et al. 2024). All LongBench evaluations use a maximum length of 2k tokens to match the model’s training configuration.

Despite these advantages, some studies report that Mamba models do not perform as well as expected on long texts of 2k tokens or more (Waleffe et al. 2024). As depicted in Figure 1, our experiments show that the pretrained Mamba model surpasses pretrained Transformers of comparable size, such as llama2-3b (Geng and Liu 2023), on short-context tasks. On long-context tasks, however, Mamba degrades substantially relative to Transformers. This disparity underscores a significant limitation of Mamba models in practical long-context applications.

This long-context deficiency of Mamba is usually attributed to its RNN-like nature: because its memory has a fixed size, this kind of architecture struggles to preserve crucial information from earlier parts of the input as the context grows (Wen, Dang, and Lyu 2024; Yang et al. 2024b). Hybrid architectures (Lieber et al. 2024; Ren et al. 2024; Park et al. 2024) have sought to mitigate this issue by integrating attention mechanisms from transformers. However, these approaches often reduce computational efficiency and increase memory consumption. A parallel study, DeciMamba (Ben-Kish et al. 2024), also attributes Mamba’s limitations to a restricted effective receptive field and proposes discarding less important tokens in specific layers to extend the lengths Mamba can handle. However, DeciMamba focuses on training-free length extrapolation, and its performance gains remain limited.

To improve the long-context performance of Mamba, we introduce ReMamba. The core idea is straightforward: distant information in Mamba degrades excessively, so an effective compression strategy that condenses information and shortens the distance it must travel can help. Our approach achieves this compression by selecting the top-k hidden states during the first forward pass and leveraging Mamba’s selective mechanism to incorporate them into the state space during the second forward pass. ReMamba incurs minimal additional computational overhead (a single extra forward pass). Experimental results demonstrate that our approach significantly improves Mamba’s long-context performance, bringing it close to that of transformers: ReMamba achieves a 3.2-point improvement over the baseline on LongBench (Bai et al. 2024) and a 1.6-point improvement on L-Eval (An et al. 2023). Furthermore, our method transfers to Mamba2, yielding a 1.6-point improvement on LongBench, which suggests broader applicability across the Mamba model family.

Preliminaries

This section provides an overview of the Mamba architecture’s development.

State Space Models

State space sequence models (SSMs) are a class of models inspired by classical continuous dynamical systems. To compute state transformations efficiently, structured SSMs impose specific constraints on the state transition matrix $\hat{A}$. Mamba models are a prime example, employing a diagonal structure for this matrix. For a one-dimensional input, discrete structured SSMs transform sequences as follows:

$$h_{t+1} = \hat{A} h_{t} + \hat{B} x_{t}, \qquad y_{t+1} = C h_{t+1} \tag{1}$$

Here $h_t \in \mathbb{R}^{N \times 1}$ is the state vector, $\hat{A} \in \mathbb{R}^{N \times N}$ is the state transition matrix, $\hat{B} \in \mathbb{R}^{N}$ is the input coefficient matrix, $x_t \in \mathbb{R}$ is the input, $C \in \mathbb{R}^{1 \times N}$ is the output matrix, and $N$ is the state dimension. For simplicity, matrix-multiplication symbols are omitted.
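As a concrete illustration of Eq. (1), the minimal Python sketch below runs the discrete recurrence over a one-dimensional input sequence. The zero initial state, the function name `ssm_scan`, and the toy values are assumptions chosen for illustration; plain lists keep the sketch dependency-free.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def ssm_scan(A_hat: Matrix, B_hat: Vector, C: Vector, xs: Vector) -> Vector:
    """Apply h_{t+1} = A_hat h_t + B_hat x_t and y_{t+1} = C h_{t+1} over a sequence."""
    n = len(B_hat)
    h = [0.0] * n  # initial state h_0 = 0 (an assumption of this sketch)
    ys = []
    for x in xs:
        # State update: matrix-vector product A_hat h_t plus the scaled input B_hat x_t.
        h = [sum(A_hat[i][j] * h[j] for j in range(n)) + B_hat[i] * x for i in range(n)]
        # Readout: the 1 x N output matrix C applied to the new state.
        ys.append(sum(C[i] * h[i] for i in range(n)))
    return ys


# A single impulse passed through a diagonal A_hat (the structure Mamba uses).
print(ssm_scan([[0.5, 0.0], [0.0, 0.9]], [1.0, 1.0], [1.0, 1.0], [1.0, 0.0, 0.0]))
```

For this impulse the outputs are approximately 2.0, 1.4, and 1.06: with diagonal entries below one in magnitude, the contribution of an early input shrinks geometrically, which offers one intuition for the long-range degradation discussed in the introduction.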

Mamba

The state space model above uses a time-invariant $\hat{A}$ (state transition matrix) and $\hat{B}$ (input coefficient matrix) and therefore lacks expressiveness and flexibility. Mamba (Gu and Dao 2024) proposes making $\hat{A}$ and $\hat{B}$ depend dynamically on the input.

In Mamba layer $l$, the SSM state $S$ is updated as follows:

$$\begin{aligned}
\Delta_{t-1}^{l} &= \mathrm{Softplus}\left(\mathrm{Proj}_{1}\left(h_{t-1}^{l-1}\right)\right), && \text{(2a)}\\
B_{t-1}^{l} &= \mathrm{Proj}_{2}\left(h_{t-1}^{l-1}\right), && \text{(2b)}\\
\hat{A}^{l}, \hat{B}_{t-1}^{l} &= \mathrm{discretize}\left(A^{l}, B_{t-1}^{l}, \Delta_{t-1}^{l}\right), && \text{(2c)}\\
h^{\prime l}_{t-1} &= \mathrm{Proj}_{3}\left(h_{t-1}^{l-1}\right), && \text{(2d)}\\
S_{t}^{l} &= \hat{A}^{l} \otimes S_{t-1}^{l} + \hat{B}_{t-1}^{l}\left(h^{\prime l}_{t-1}\right)^{T}. && \text{(2e)}
\end{aligned}$$

Here, $h_{t-1}^{l-1} \in \mathbb{R}^{H}$ is the output hidden state of Mamba at layer $l-1$ and time step $t-1$. $\mathrm{Softplus}$ denotes the Softplus function, and $\mathrm{Proj}_{1}$, $\mathrm{Proj}_{2}$, and $\mathrm{Proj}_{3}$ are shorthand for the respective projection operations.

Furthermore, $\Delta_{t-1}^{l} \in \mathbb{R}^{H'}$ is the discrete time step corresponding to the selective mechanism in Mamba, where $H'$ is the intermediate hidden size. The continuous and discrete state transition matrices at layer $l$ are $A^{l}, \hat{A}^{l} \in \mathbb{R}^{H' \times N}$, respectively. The continuous and discrete input coefficient matrices are $B_{t-1}^{l}, \hat{B}_{t-1}^{l} \in \mathbb{R}^{N \times 1}$. The state size is $N$, and “discretize” denotes the discretization method used to compute $\hat{A}$ and $\hat{B}$.
The vector $h_{t-1}^{\prime l} \in \mathbb{R}^{H' \times 1}$ is the projected input to the SSM, and $S_{t}^{l} \in \mathbb{R}^{H' \times N}$ is the SSM state. The symbol $\otimes$ denotes element-wise multiplication, and $\hat{B}_{t-1}^{l}\left(h_{t-1}^{\prime l}\right)^{T}$ denotes matrix multiplication. Note that the definitions of $A^{l}$ and $\hat{A}^{l}$ given here differ from their original definitions because Mamba simplifies these matrices to diagonal form.

Mamba improves the sequence-modeling ability of the state space model by enabling it to selectively focus on or filter out inputs as they are written into its sequential state.
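To make Eq. (2) and this filtering behavior concrete, here is a hedged sketch of a single state update for one layer. It assumes the discretization $\hat{A} = \exp(\Delta A)$ (element-wise) and $\hat{B} = \Delta B$, the choice used in Mamba’s reference selective scan; the text above leaves “discretize” abstract. The projections are not modeled: their outputs are passed in directly, and the input term is written as the outer product of $h'$ and $\hat{B}$ so that it has the $H' \times N$ shape of $S$. The function names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def selective_step(S: Matrix, A: Matrix, delta_pre: List[float],
                   B: List[float], h_prime: List[float]) -> Matrix:
    """One update of an H' x N state, following the shape of Eq. (2e).

    Assumed discretization: A_hat = exp(delta * A) element-wise, B_hat = delta * B.
    delta_pre stands for Proj1's output, B for Proj2's output, h_prime for Proj3's.
    """
    delta = [softplus(d) for d in delta_pre]  # Eq. (2a): one step size per channel
    new_S = []
    for c, row in enumerate(S):
        new_row = []
        for n, s_cn in enumerate(row):
            a_hat = math.exp(delta[c] * A[c][n])  # discretized transition entry
            b_hat = delta[c] * B[n]               # discretized input coefficient
            # Element-wise decay of the old state plus the outer-product input term.
            new_row.append(a_hat * s_cn + b_hat * h_prime[c])
        new_S.append(new_row)
    return new_S
```

With a very negative pre-activation ($\Delta \approx 0$) the state is left essentially unchanged and the input is ignored, whereas a large $\Delta$ with negative $A$ makes the state forget its past and take on the new input. This is the sense in which the selective mechanism lets Mamba focus on or filter out inputs.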

Mamba2

Dao and Gu (2024) theoretically prove connections between structured state space models and attention mechanisms. They further simplify the structured matrix $\hat{A}$ into a scalar-times-identity structure and develop a new state space duality (SSD) framework with multi-head patterns similar to transformers. This modification trades some of the expressiveness of $\hat{A}$ for faster training and a larger state size.
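The duality can be checked numerically. In the minimal sketch below (one head, head dimension 1, and already-discretized scalar decays $a_t$; all simplifications made for brevity), the scalar-times-identity recurrence and an explicit masked, attention-like sum over cumulative decays produce the same outputs. Indices follow the convention $h_t = a_t h_{t-1} + B_t x_t$, which is Eq. (1) shifted by one step. This illustrates the idea only; it is not the SSD algorithm itself, which computes the same quantity with a more efficient chunked procedure.

```python
from typing import List


def recurrent_form(a: List[float], B: List[List[float]], C: List[List[float]],
                   x: List[float]) -> List[float]:
    """Linear-time scan: h_t = a_t * h_{t-1} + B_t * x_t and y_t = <C_t, h_t>."""
    h = [0.0] * len(B[0])  # zero initial state (assumed)
    ys = []
    for a_t, B_t, C_t, x_t in zip(a, B, C, x):
        h = [a_t * h_n + b_n * x_t for h_n, b_n in zip(h, B_t)]
        ys.append(sum(c_n * h_n for c_n, h_n in zip(C_t, h)))
    return ys


def attention_form(a: List[float], B: List[List[float]], C: List[List[float]],
                   x: List[float]) -> List[float]:
    """Quadratic form: y_t = sum_{s <= t} L[t][s] * <C_t, B_s> * x_s,
    where L[t][s] = a_{s+1} * ... * a_t is the cumulative decay (1 when s == t)."""
    ys = []
    for t in range(len(x)):
        y_t = 0.0
        for s in range(t + 1):  # causal mask: only s <= t contributes
            decay = 1.0
            for r in range(s + 1, t + 1):
                decay *= a[r]
            score = sum(c * b for c, b in zip(C[t], B[s]))  # like a query-key product
            y_t += decay * score * x[s]
        ys.append(y_t)
    return ys
```

The recurrent form costs time linear in sequence length, while the attention-like form materializes a lower-triangular matrix of weights; the equality of the two is what lets Mamba2 borrow attention-style, hardware-friendly computation.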

Figure 2: ReMamba architecture. Only one layer is shown, and $A$, $B$, and the discretization method are omitted. In Stage 2, only the selected value vectors go through selective adaptation, while normal token embeddings flow through as usual. We select the top-$K$ (here, top-2) hidden states of the last layer according to importance scores computed with the last hidden state $h_L$, and we make these scores trainable by incorporating them through Mamba’s selective mechanism.

Methodology

ReMamba consists of two forward stages. In the first stage, three feed-forward networks are used to determine the importance of the hidden states from Mamba’s final layer, and hidden states are selected according to these importance scores. The second stage integrates the compressed hidden states with the input context, adapting Mamba’s selective mechanism to incorporate them into the state space.

Our method draws inspiration from KV cache compression techniques (Mu, Li, and Goodman 2023; Ge et al. 2024; Yang et al. 2024a; Chevalier et al. 2023; Hwang et al. 2024; Gao, Cao, and Li 2024), which leverage the language model itself to aggregate information into hidden states and use a scoring mechanism to select the most salient representations. However, unlike these transformer-oriented methods, ReMamba’s compression strategy targets two objectives: 1) compressing and selectively retaining crucial information to minimize information degradation, and 2) reducing the number of state space updates to further alleviate information loss.

Stage 1: Selective Compression

Selective compression compresses part of the input prompt using the final-layer hidden states of the Mamba model, which reduces the number of state updates and consolidates information.

Suppose the sequence length is $L$ and the context token embeddings are $\{t_i\}_{i=1}^{L}$. We define the relative range to be compressed as $\mathrm{range} := (s, e)$, where $e = s + p$, with $s$ and $e$ denoting the relative start and end positions, respectively, and $p$ representing the relative length to compress. These values satisfy $0 \leq s, p, e \leq 1$. The index set of the context to compress is $\mathcal{R} := [S, E]$, where $S = L \cdot s + 1$ and $E = L \cdot (s + p)$. Consequently, the length of the prompt to compress is $L' = E - S + 1$. For convenience, we use $\mathcal{R}$ to represent both the set of indices and the set of actual tokens within the context to be compressed. Furthermore, we define the compression ratio $\rho$ and compress the selected context $\mathcal{R}$ into $K := |\mathcal{R}| \cdot \rho$ hidden representations.

In Figure 2, the compression hyperparameters are $s = 0.2$, $p = 0.4$, $\mathrm{range} = (0.2, 0.6)$, $\mathcal{R} = [3, 6]$, $\rho = 0.5$, and $K = 2$. In our experiments, $s = 0$ yields the best results, which can be attributed to the causal language modeling nature of Mamba (discussed in more detail later).
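The index arithmetic above can be written out as a small helper. The text does not specify how $L \cdot s$, $L \cdot (s + p)$, or $|\mathcal{R}| \cdot \rho$ are rounded when they are not integers, so rounding to the nearest integer is an assumption of this sketch, as is the name `compression_indices`. With $L = 10$ (the sequence length implied by $s = 0.2$ and $\mathcal{R} = [3, 6]$), it reproduces the Figure 2 settings.

```python
from typing import Tuple


def compression_indices(L: int, s: float, p: float, rho: float) -> Tuple[int, int, int, int]:
    """Return (S, E, L_prime, K) for the span R = [S, E] (1-based, inclusive).

    Rounding products to the nearest integer is an assumption; the source does
    not state how non-integer values are handled.
    """
    if not (0.0 <= s <= 1.0 and 0.0 <= p <= 1.0 and s + p <= 1.0):
        raise ValueError("require 0 <= s, p, e <= 1 with e = s + p")
    S = round(L * s) + 1        # S = L * s + 1
    E = round(L * (s + p))      # E = L * (s + p)
    L_prime = E - S + 1         # number of tokens in R
    K = round(L_prime * rho)    # K = |R| * rho compressed representations
    return S, E, L_prime, K
```

For example, `compression_indices(10, 0.2, 0.4, 0.5)` returns `(3, 6, 4, 2)`, i.e. $S = 3$, $E = 6$, $L' = 4$, and $K = 2$.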

As shown in Stage 1 of Figure 2, we denote the last layer’s output hidden states as $\{h_i\}_{i=1}^{L}$, where each $h_i \in \mathbb{R}^{H}$ and $H$ is the hidden size. We then transform the last hidden state $h_L$ into a query hidden state $q$ through a feed-forward layer named $\mathrm{Query}$. Additionally, the hidden states to be compressed, $\{h_i\}_{i=S}^{E}$, are transformed into $\{k_i\}_{i=S}^{E}$ via a $\mathrm{Key}$ layer (this transformation is not shown in Figure 2).
Finally, the cosine similarity scores $Cos = \{cos_i\}_{i=S}^{E}$ are computed to serve as importance scores for the hidden states $\{h_i\}_{i=S}^{E}$. The calculation of $q$, $k_i$, and $cos_i$ is formulated as follows:

$$\begin{aligned}
q &= \mathrm{Query}(h_L) \\
\{k_i\}_{i=S}^{E} &= \mathrm{Key}\left(\{h_i\}_{i=S}^{E}\right) \\
cos_i &= \frac{k_i \cdot q}{\max\left(\|k_i\|_2 \cdot \|q\|_2, \epsilon\right)}
\end{aligned} \tag{3}$$

where $k_i$ is the transformed hidden state at position $i$, and $cos_i$ is the cosine similarity between $q$ and $k_i$. The constant $\epsilon$ prevents division by zero.
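A minimal sketch of the scoring in Eq. (3) follows. The Query and Key layers are described only as feed-forward layers, so modeling them as bias-free linear maps given by hypothetical weight matrices `W_query` and `W_key` is an assumption, as is the default value of `eps`. Positions $S$ and $E$ are 1-based and inclusive, as in the text.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def linear(W: Matrix, x: Vector) -> Vector:
    """Bias-free linear map, standing in for a feed-forward layer (assumption)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def importance_scores(hidden: List[Vector], S: int, E: int,
                      W_query: Matrix, W_key: Matrix, eps: float = 1e-8) -> Vector:
    """Cosine scores cos_i for positions i = S..E (1-based, inclusive), per Eq. (3)."""
    q = linear(W_query, hidden[-1])               # q = Query(h_L)
    q_norm = math.sqrt(sum(v * v for v in q))
    scores = []
    for i in range(S - 1, E):                     # 0-based positions of S..E
        k = linear(W_key, hidden[i])              # k_i = Key(h_i)
        k_norm = math.sqrt(sum(v * v for v in k))
        dot = sum(a * b for a, b in zip(k, q))
        scores.append(dot / max(k_norm * q_norm, eps))  # eps guards against 0 / 0
    return scores
```

If a key vector is all zeros, the denominator falls back to $\epsilon$ and the score is 0 rather than a division by zero.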

We select the top-$K$ hidden states $h_j$, $j \in G$, from $\{h_i\}_{i=S}^{E}$ according to their importance scores $Cos$. The index set $G$ is defined as:

$$G = \mathop{\arg\max}_{A \subset \{S, S+1, \ldots, E\},\ |A| = K} \sum_{i \in A} cos_i \tag{4}$$

Note that the original order of these indices is preserved.
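Because the objective in Eq. (4) is a sum of independent per-position scores, the maximizing set $G$ is simply the $K$ positions with the highest individual scores. The sketch below makes this explicit and then re-sorts the chosen positions so that their original order is preserved. The function name and the tie-breaking rule (earlier positions first) are assumptions; the text does not specify either.

```python
from typing import List


def select_top_k(scores: List[float], S: int, K: int) -> List[int]:
    """Return the 1-based index set G of Eq. (4), in original sequence order.

    scores[j] is the score of position S + j. Taking the K largest individual
    scores maximizes their sum; ties go to the earlier position (an assumption).
    """
    if not 0 <= K <= len(scores):
        raise ValueError("K must lie between 0 and the number of candidates")
    ranked = sorted(range(len(scores)), key=lambda j: (-scores[j], j))
    chosen = sorted(ranked[:K])  # restore the original order of the positions
    return [S + j for j in chosen]
```

For instance, `select_top_k([0.1, 0.9, -0.3, 0.5], S=3, K=2)` returns `[4, 6]`, because the two highest scores (0.9 and 0.5) sit at positions 4 and 6.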

In our model, after selecting the top-$K$ hidden states $h_j$, we apply a feed-forward layer, $\mathrm{Value}$, to project them into the token embedding space:

$$\{v_i\}_{i=1}^{K} = \mathrm{Value}\left(\{h_j\}_{j \in G}\right) \tag{5}$$

Their corresponding cosine similarity scores are $\{cos'_i\}_{i=1}^{K}$. We then replace the token embeddings $\{t_i\}_{i=S}^{E}$ (i.e., $\mathcal{R}$) with $\{v_i\}_{i=1}^{K}$, so the new input embeddings for Mamba become:

$$\begin{aligned}
T_{\text{new}} &= \mathrm{Cat}\left(\{t_i\}_{i=1}^{S-1}, \{v_i\}_{i=1}^{K}, \{t_i\}_{i=E+1}^{L}\right) && \text{(6)}\\
&= \{t'_i\}_{i=1}^{L-L'+K} && \text{(7)}
\end{aligned}$$

where $\mathrm{Cat}$ denotes concatenation. The length of $T_{\text{new}}$ is $L - L' + K$; since $K = \rho L'$, the second forward pass processes $L' - K = (1 - \rho)L'$ fewer tokens than the first, which makes its input sequence significantly shorter for small $\rho$.
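Eq. (6) amounts to list splicing with 1-based, inclusive bounds. The sketch below uses strings as stand-ins for embedding vectors (an illustrative simplification), and the function name is hypothetical.

```python
from typing import List, Sequence, TypeVar

Item = TypeVar("Item")


def build_new_inputs(tokens: Sequence[Item], values: Sequence[Item],
                     S: int, E: int) -> List[Item]:
    """Cat(t_1..t_{S-1}, v_1..v_K, t_{E+1}..t_L) with 1-based, inclusive [S, E]."""
    prefix = list(tokens[:S - 1])  # t_1 .. t_{S-1}
    suffix = list(tokens[E:])      # t_{E+1} .. t_L
    return prefix + list(values) + suffix


tokens = [f"t{i}" for i in range(1, 11)]  # L = 10 placeholder embeddings
new_inputs = build_new_inputs(tokens, ["v1", "v2"], S=3, E=6)
print(new_inputs, len(new_inputs))
# ['t1', 't2', 'v1', 'v2', 't7', 't8', 't9', 't10'] 8
```

With $\mathcal{R} = [3, 6]$ and $K = 2$ as in Figure 2, and $L = 10$, the result has $10 - 4 + 2 = 8$ entries, matching the length $L - L' + K$.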

Stage 2: Selective Adaption

One significant challenge of top-$K$ selection based on importance scores is that it is non-differentiable: the scores receive no gradient through the selection, which prevents training them effectively. We therefore propose a framework that integrates the importance scores into the selective mechanism of the Mamba model.

For hidden states (embeddings) that are not compressed in Stage 1, namely $\{t_i\}_{i=1}^{S-1}$ and $\{t_i\}_{i=E+1}^{L}$, the standard Mamba algorithm is applied during the second forward pass. For embeddings at the selected positions, namely $\{t'_i\}_{i=S}^{S+K-1}$ or equivalently $\{v_i\}_{i=1}^{K}$, Equation 2a is reformulated as follows:

$$\begin{aligned}
\alpha &= \mathrm{ReLU}\left(cos'_{t-1}\right) \\
{\Delta_{t-1}^{l}}' &= \mathrm{Proj}_{1}\left(h_{t-1}^{l-1}\right) \\
\delta &= {\Delta_{t-1}^{l}}' \cdot \alpha + \Theta^{l} \\
\Delta_{t-1}^{l} &= \mathrm{Softplus}(\delta)
\end{aligned} \tag{8}$$

where $\Theta^{l} \in \mathbb{R}^{H'}$ is a layer-wise trainable offset parameter that controls the scale intensity, and $\mathrm{ReLU}$ is the activation function. Intuitively, hidden states with low importance scores should have minimal impact on the model’s computation. We approximate this behavior by setting their corresponding $\Delta$ values close to zero: when $\alpha = 0$, $\delta$ reduces to $\Theta^{l}$ regardless of the input. Directly multiplying $\Delta$ by $\alpha$ would be more precise, but it would require modifying the selective scan algorithm, so we adopt this simpler approach.
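The per-channel computation in Eq. (8) is sketched below for one selected position. Here `delta_pre` stands for $\mathrm{Proj}_1(h_{t-1}^{l-1})$ and `theta` for $\Theta^l$; both are supplied directly rather than computed, and the function names are hypothetical.

```python
import math
from typing import List


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def adapted_delta(delta_pre: List[float], cos_score: float,
                  theta: List[float]) -> List[float]:
    """Eq. (8): Softplus(Proj1(h) * ReLU(cos') + Theta), one entry per channel.

    delta_pre plays the role of Proj1(h_{t-1}^{l-1}) and theta of Theta^l;
    both are supplied directly in this sketch.
    """
    alpha = max(cos_score, 0.0)              # alpha = ReLU(cos')
    return [softplus(d * alpha + th)          # delta = Delta' * alpha + Theta
            for d, th in zip(delta_pre, theta)]
```

When the score is non-positive, $\alpha = 0$ and every channel receives $\mathrm{Softplus}(\Theta^l)$ independently of the input; this equals $\ln 2$ at the all-zeros initialization of $\Theta$ mentioned under Training, and a sufficiently negative learned offset drives $\Delta$ toward zero.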

Training

Following the forward encoding process, standard causal language generation is applied using the Mamba architecture. During training, we optimize the parameters newly introduced by the selective compression mechanism. Except for $\Theta$, which is initialized to all zeros, these parameters are initialized with a subset of the weights of the first layer's in_proj matrix. Among the original Mamba parameters, the dt_proj matrix is fully trained, while in_proj, out_proj, embeddings, and lm_head are updated with Low-Rank Adaptation (LoRA) (Hu et al. 2022). In our best-performing implementation, to emphasize the significance of important information, the gradients flowing into the importance scores are scaled in proportion to these scores, which intuitively prioritizes training of the more critical representations.
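One way to obtain gradients scaled in proportion to the importance scores is a stop-gradient identity; this is our assumption about a possible implementation, not necessarily the authors' mechanism. With $\mathrm{sg}(\cdot)$ denoting stop-gradient, $y = \mathrm{sg}(\alpha) + \mathrm{sg}(\alpha)\,(\alpha - \mathrm{sg}(\alpha))$ has forward value $\alpha$ but derivative $\partial y / \partial \alpha = \alpha$, so an upstream gradient $g$ reaches $\alpha$ as $\alpha g$. The sketch checks this with a minimal forward-mode dual number (the `Dual` class is hypothetical, for illustration only); in PyTorch the same expression would use `alpha.detach()`.

```python
from dataclasses import dataclass


@dataclass
class Dual:
    """Minimal forward-mode AD value: `val` is the value, `dot` its derivative w.r.t. alpha."""
    val: float
    dot: float

    def __add__(self, other: "Dual") -> "Dual":
        return Dual(self.val + other.val, self.dot + other.dot)

    def __sub__(self, other: "Dual") -> "Dual":
        return Dual(self.val - other.val, self.dot - other.dot)

    def __mul__(self, other: "Dual") -> "Dual":
        # Product rule
        return Dual(self.val * other.val, self.val * other.dot + self.dot * other.val)


def stop_gradient(x: Dual) -> Dual:
    """Keep the value, drop the derivative (the role of `.detach()` in PyTorch)."""
    return Dual(x.val, 0.0)


def grad_scaled_score(alpha: Dual) -> Dual:
    """Forward value equals alpha; derivative w.r.t. alpha equals alpha itself.

    y = sg(a) + sg(a) * (a - sg(a))  =>  y = a,  dy/da = sg(a) = a
    """
    a_sg = stop_gradient(alpha)
    return a_sg + a_sg * (alpha - a_sg)
```

For example, `grad_scaled_score(Dual(0.8, 1.0))` has value 0.8 and derivative 0.8, whereas a plain identity would have derivative 1.0, so higher-scoring positions receive proportionally larger gradients.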

Model 2WikiMQA GovReport HotpotQA LCC MultiNews MultiQA PassCount PassRetrie. Qasper RepoBench SAMSum TREC TriviaQA Average

llama2-3b (Pre) 13.24 25.98 12.64 60.20 16.36 27.75 1.03 6.33 8.41 46.92 32.68 52.33 62.43 28.18
llama2-3b (SFT) 17.10 23.79 22.51 57.79 20.27 33.69 0.00 6.67 20.56 43.47 34.62 52.33 49.64 29.42
Mamba (Pre) 3.73 8.72 4.03 24.03 11.31 4.95 0.80 1.75 3.67 12.83 6.86 9.00 17.40 8.39
Mamba (SFT) 22.10 19.08 15.90 40.20 19.36 30.28 0.00 4.67 19.04 36.02 28.30 39.33 45.97 24.63
ReMamba (SFT) 21.18 19.67 20.56 48.21 18.86 26.39 3.21 6.83 16.76 40.40 33.65 48.67 57.73 27.86
Table 1: Performance on LongBench-E (English branch). “MultiQA” denotes MultiFieldQA, “PassCount” denotes PassageCount, and “PassRetrie.” denotes PassageRetrieval. Mamba and ReMamba models are evaluated with a maximum length of 6K tokens, matching their training configuration. llama2-3b is capped at 2K tokens because of its positional encoding limit, which proved to be its best setting. “(Pre)” denotes a pretrained model and “(SFT)” a finetuned model.
Model Finetuned Tokens CodeU Coursera GSM QuALITY SFictio TOEFL Average
llama2-3b (Pre) 2k 0.00 19.33 3.00 23.27 56.25 18.22 20.01
llama2-3b (SFT) 2k 2.22 25.00 4.00 27.23 58.59 20.45 22.92
Mamba (Pre) 6k 2.22 23.26 0.00 25.74 23.44 17.10 15.29
Mamba (SFT) 6k 4.44 26.16 1.00 27.72 50.78 23.05 22.19
ReMamba (SFT) 6k 2.22 22.97 3.00 25.74 58.59 30.48 23.83
Table 2: Model performance on closed-ended tasks of L-Eval. “Tokens” denotes the max length. “Finetuned” denotes whether the model is finetuned or not. Mamba and ReMamba models are evaluated using a maximum length of 6K tokens, matching their training configuration. In contrast, llama2-3B is capped at 2K tokens due to positional encoding limitations.

Experiments

Experimental Setups

Our model targets long-context question-answering tasks, which requires a substantial corpus of long-context instruction-tuning data. To this end, we use the OpenOrca dataset (Mukherjee et al. 2023) and LongAlpaca-12k (Chen et al. 2024). The former is a large collection of ChatGPT-augmented FLAN data, while the latter is a long-context alignment dataset. We first filter long instruction-tuning instances from OpenOrca and concatenate them with LongAlpaca. To fit within device memory, prompts are truncated to a maximum length of 6,000 tokens. This process yields approximately 200,000 long-context training examples. To increase data diversity, we also add the first 300,000 standard instances from OpenOrca. During training, the hyperparameter $s$ is fixed at 0, $p$ is randomly sampled from the interval [0.1, 0.3], and $\rho$ is randomly sampled from the interval [0.05, 0.2].
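The per-example hyperparameter sampling can be sketched as follows. `sample_compression_config` follows the stated ranges, assuming uniform sampling (the text says only "randomly sampled"). `compression_window` is a hypothetical helper: mapping the relative $(s, p, \rho)$ to token indices, and in particular treating $\rho$ as the fraction of the compressed span retained as hidden states, is our assumption.

```python
import random
from typing import Dict, Tuple

MAX_PROMPT_TOKENS = 6000  # prompts are truncated to 6,000 tokens


def sample_compression_config(rng: random.Random) -> Dict[str, float]:
    """Training-time hyperparameters: s fixed at 0, p in [0.1, 0.3], rho in [0.05, 0.2].
    Uniform sampling is an assumption."""
    return {"s": 0.0, "p": rng.uniform(0.1, 0.3), "rho": rng.uniform(0.05, 0.2)}


def compression_window(prompt_len: int, s: float, p: float, rho: float) -> Tuple[int, int, int]:
    """Hypothetical mapping from relative hyperparameters to token positions.

    Assumes the compressed span is [s*L, (s+p)*L) and that rho is the fraction of
    that span kept as hidden states (at least one). Returns (start, end, num_kept).
    """
    prompt_len = min(prompt_len, MAX_PROMPT_TOKENS)  # truncation as described
    start = int(s * prompt_len)
    end = min(prompt_len, int((s + p) * prompt_len))
    num_kept = max(1, round(rho * (end - start)))
    return start, end, num_kept
```

Under these assumptions, a 6,000-token prompt with the LongBench evaluation setting ($s = 0$, $p = 0.18$, $\rho = 0.009$) would have roughly its first 1,080 tokens compressed into about 10 hidden states; that count rests entirely on the assumed meaning of $\rho$.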

We finetune the baseline Mamba 2.8b model and our ReMamba model on the same dataset using LoRA with identical hyperparameters (r=32, others default). For reference, we also finetune a llama2-3b model (Geng and Liu 2023). Because llama2-3b's positional encoding supports at most 2K tokens, we fine-tune it under two configurations: one using the same data as Mamba, and another with data truncated to a maximum length of 2K. The truncated configuration performs better and is adopted as the default setting.
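For readers unfamiliar with LoRA, the adapter used for linear weights such as in_proj and out_proj can be sketched for a single layer. This follows the standard LoRA formulation $y = W_0 x + \frac{\alpha_{\text{LoRA}}}{r} B A x$, with $W_0$ frozen and $B$ initialized to zero. The class name, the Gaussian initialization scale, and the `lora_alpha` value are placeholders (the text only says "others default"), and plain Python lists stand in for tensors.

```python
import random
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(w * x for w, x in zip(row, v)) for row in m]


class LoRALinear:
    """Linear layer with a frozen weight W0 and a trainable low-rank update B @ A."""

    def __init__(self, w0: Matrix, r: int = 32, lora_alpha: float = 32.0, seed: int = 0):
        rng = random.Random(seed)
        d_out, d_in = len(w0), len(w0[0])
        self.w0 = w0                                   # frozen pretrained weight (d_out x d_in)
        self.a = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(r)]  # r x d_in
        self.b = [[0.0] * r for _ in range(d_out)]     # d_out x r, zero init
        self.scale = lora_alpha / r                    # placeholder scaling factor

    def __call__(self, x: Vector) -> Vector:
        base = matvec(self.w0, x)
        update = matvec(self.b, matvec(self.a, x))     # B (A x): rank at most r
        return [y + self.scale * u for y, u in zip(base, update)]
```

Because `b` starts at zero, the adapted layer initially reproduces the frozen layer exactly (`LoRALinear([[1.0, 2.0], [3.0, 4.0]], r=4)([1.0, 1.0])` returns `[3.0, 7.0]`); only `a` and `b`, that is $r(d_{\text{in}} + d_{\text{out}})$ values, would be trained.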

Evaluations

We compare our model against the baseline Mamba 2.8b (both finetuned and pretrained) on the widely adopted LongBench (Bai et al. 2024) and L-Eval (An et al. 2023) benchmarks, which cover a diverse set of challenging real-world long-context tasks. For consistency, all models use the same prompt templates and greedy decoding configuration. As a reference point, we also report a similarly sized transformer (llama2-3b). Because llama2-3b's positional encoding is limited to 2k tokens, its performance in the 6k setting is poor, so we evaluate it in its best-performing 2k configuration.

Results

Model 2WikiMQA GovReport HotpotQA LCC MultiNews MultiQA PassCount PassRetrie. Qasper RepoBench SAMSum TREC TriviaQA Average

Mamba2 (Pre) 2.18 4.76 1.54 23.46 7.71 2.88 0.60 1.47 1.17 14.97 2.07 8.33 10.60 6.29
Mamba2 (SFT) 13.73 19.97 15.05 41.78 19.52 25.20 0.67 5.67 13.44 38.79 33.95 49.67 43.54 24.69
ReMamba2 (SFT) 18.90 19.03 18.09 51.15 17.68 25.82 3.33 5.00 14.84 43.99 23.55 44.00 56.90 26.33
Table 3: Performance comparison on LongBench-E (English branch) for Mamba2. “Mamba2 (Pre)” is the pretrained Mamba2, “Mamba2 (SFT)” the finetuned Mamba2, and “ReMamba2 (SFT)” our model. All models use a maximum length of 6k.

Results on LongBench

We use the English branch of LongBench because our training set contains only English data. Higher is better for all metrics. Table 1 reports per-task scores at a maximum length of 6k, matching the training setting. Here the ReMamba hyperparameters are $s=0$, $p=0.18$, and $\rho=0.009$; we show later that the model is robust to various hyperparameter combinations. As Table 1 shows, ReMamba improves the average LongBench score by 3.23 points over the SFT Mamba baseline (27.86 vs. 24.63), approaching the pretrained (28.18) and finetuned (29.42) transformer baselines.
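The average score and the reported gain can be reproduced from the per-task scores in Table 1. The sketch below assumes, consistent with the table, that "Average" is the unweighted mean over the 13 tasks, rounded to two decimals.

```python
from statistics import mean
from typing import Sequence

# Per-task LongBench-E scores copied from Table 1 (13 tasks, same column order).
MAMBA_SFT = [22.10, 19.08, 15.90, 40.20, 19.36, 30.28, 0.00,
             4.67, 19.04, 36.02, 28.30, 39.33, 45.97]
REMAMBA_SFT = [21.18, 19.67, 20.56, 48.21, 18.86, 26.39, 3.21,
               6.83, 16.76, 40.40, 33.65, 48.67, 57.73]


def average(scores: Sequence[float]) -> float:
    """Unweighted mean over tasks, rounded to two decimals like the "Average" column."""
    return round(mean(scores), 2)


gain = round(average(REMAMBA_SFT) - average(MAMBA_SFT), 2)
```

This yields averages of 27.86 and 24.63 and a gain of 3.23, matching the table and the text.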

Results on L-Eval

We compare performance on the closed-ended tasks of L-Eval, where higher is better for all metrics. Table 2 presents detailed task scores at a maximum length of 6k. ReMamba improves the average score by 1.64 points over the SFT Mamba baseline (23.83 vs. 22.19). Here the ReMamba hyperparameters are $s=0$, $p=0.20$, and $\rho=0.05$.

Analyses and Discussions

Varying Length

To complement our main results, which use a maximum sequence length of 6k tokens to match the training setting, we further evaluate performance at input lengths ranging from 2k to 9k tokens on the LongBench and L-Eval benchmarks. As shown in Figure 3, ReMamba consistently outperforms the baseline Mamba model at all tested context lengths on LongBench, and the performance gap widens as the context length increases. Furthermore, ReMamba extends the efficient context length (the length at which peak performance is observed) to 6k tokens, compared to 4k tokens for the finetuned Mamba baseline. Figure 4 shows that ReMamba improves performance at all context lengths on L-Eval and even surpasses the transformer baseline.

Speed Performance and Memory Expense

Our model introduces a single additional forward pass during inference and requires no additional memory. To evaluate speed, we vary the input sequence length from 1k to 8k tokens while fixing the output length at 1k tokens. For each configuration, we use a batch size of 1 and measure speed on an NVIDIA A100 80GB GPU. Figure 6 compares ReMamba, Mamba, and a vanilla transformer model (llama2-3b), with speed reported in tokens per second. ReMamba runs at speeds comparable to the Mamba baseline while retaining a significant speed advantage over the transformer.
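A timing harness of the kind described can be sketched as below. `generate` is a hypothetical callable standing in for whichever model's decoding loop is measured; counting only generated tokens and using 1k steps between input lengths are assumptions, since the text does not specify them. On a GPU, the device should be synchronized (for example with `torch.cuda.synchronize()`) before reading the timer; otherwise asynchronous kernels make the measurement look faster than it is.

```python
import time
from typing import Callable, Dict, Sequence

# Hypothetical interface: generate(prompt_ids, max_new_tokens) -> list of new token ids.
GenerateFn = Callable[[Sequence[int], int], Sequence[int]]


def tokens_per_second(generate: GenerateFn, prompt: Sequence[int], new_tokens: int) -> float:
    """Time a single batch-size-1 generation call; report generated tokens per second."""
    start = time.perf_counter()
    out = generate(prompt, new_tokens)
    elapsed = time.perf_counter() - start
    return len(out) / elapsed


def sweep(generate: GenerateFn,
          input_lengths: Sequence[int] = (1024, 2048, 3072, 4096, 5120, 6144, 7168, 8192),
          output_len: int = 1024) -> Dict[str, float]:
    """Run the speed test for each input length; keys follow the "input_output"
    labels of Figure 6, e.g. "1024_1024"."""
    return {f"{n}_{output_len}": tokens_per_second(generate, [0] * n, output_len)
            for n in input_lengths}
```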

Generalizing to Mamba2

Although our method is designed for Mamba, we also verify its applicability to Mamba2. As shown in Table 3, applying the same method to Mamba2 (denoted ReMamba2) improves the average LongBench score by 1.64 points (26.33 vs. 24.69). Here we use $s=0$, $p=0.25$, and $\rho=0.05$, again with a maximum length of 6k. Notably, Mamba2 shows almost no improvement over Mamba on LongBench (24.69 vs. 24.63 after finetuning), suggesting potential limitations shared across the Mamba model series.

Ablation Study

To verify the effectiveness of the modules we introduce, we conduct an ablation study comparing ReMamba with three alternative methods: 1. Random Selection: hidden states are selected at random as the compressed information, according to $\rho$. 2. Fix Selection: given $\rho$, hidden states are selected at every $k$-th position, where the interval $k$ is computed from the compression ratio. 3. Multiplicative Selection: this variant modifies only the selective adaptation process, directly multiplying the selected hidden states by their importance scores, following the approach proposed by Raposo et al. (2024). All variants are trained on the same data as ReMamba.
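The selection strategies differ only in which positions are kept. The sketch below contrasts Random and Fix Selection with score-based top-$k$ selection. The importance scores are taken as given (their computation is defined elsewhere in the paper and not reproduced), `k` stands for the number of hidden states implied by $\rho$, and taking the last position of each interval in `fixed_selection` is an arbitrary choice of ours.

```python
import random
from typing import List, Sequence


def random_selection(n: int, k: int, seed: int = 0) -> List[int]:
    """Random Selection: k distinct positions drawn uniformly, returned in order."""
    return sorted(random.Random(seed).sample(range(n), k))


def fixed_selection(n: int, k: int) -> List[int]:
    """Fix Selection: one position every `interval` tokens, with interval = n // k."""
    interval = max(1, n // k)
    return list(range(interval - 1, n, interval))[:k]


def score_selection(scores: Sequence[float], k: int) -> List[int]:
    """Score-based selection: the k positions with the highest importance scores, in order."""
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return sorted(top)
```

For example, `fixed_selection(10, 3)` returns `[2, 5, 8]` and `score_selection([0.1, 0.9, 0.3, 0.8, 0.2], 2)` returns `[1, 3]`. All three return positions in their original order, which seems the natural choice for a recurrent second pass.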

We report average LongBench scores across various maximum input lengths. As illustrated in Figure 5, both the fixed and random selection methods perform comparably to the finetuned Mamba baseline, and they even outperform it at lengths of 5k and 6k. This observation supports our hypothesis that Mamba models suffer from severe forgetting: even simple strategies that drop some information appear beneficial. The multiplicative selection method yields some improvement across input lengths. However, the substantial gap between it and ReMamba, which differ only in the adaptation step, shows the critical role of the selective adaptation module in ReMamba. The module not only mitigates the forgetting problem but also significantly improves the model's ability to handle longer input sequences.

Figure 3: Average scores on LongBench varying max length from 2k to 9k. The “Pre” means pretrained model while “SFT” means finetuned model. The performance of llama2-3b (SFT) and llama2-3b (Pre) is for reference, using the max length of 2k due to positional encoding problems.
Figure 4: Average scores on L-Eval varying max length from 2k to 9k. The performance of llama2-3b (SFT) and llama2-3b (Pre) is for reference.
Figure 5: Ablation study about average scores on LongBench varying max length from 2k to 9k. “Mamba(SFT)” is the finetuned Mamba. “fix_select” is the Fix Selection. “random_select” is the Random Selection. “multiplicative_select” is the Multiplicative Selection.

Robustness to varying choices of hyperparameters

The LongBench results above were obtained with the hyperparameter setting $s=0$, $p=0.18$, and $\rho=0.009$, which performs relatively well. Figure 7 shows the stability of our model as the hyperparameters $p$ and $\rho$ vary; in these experiments, $s$ is fixed at 0.

Figure 6: Speed (tokens/second) performance comparisons. Here 1024_1024 means input 1024 tokens and output 1024 tokens.
Figure 7: Robustness of the ReMamba model with varying hyperparameters. The row label denotes the relative ratio of the prompt to be compressed, corresponding to parameter p𝑝pitalic_p. The column label indicates the compression ratio, corresponding to parameter ρ𝜌\rhoitalic_ρ.


Model ρ p s model_type average
ReMamba 0.009 0.18 0.00 ReMamba 27.86
middle0.0 0.009 0.18 0.00 middle 25.96
middle0.1 0.009 0.18 0.10 middle 26.45
middle0.2 0.009 0.18 0.20 middle 26.86
middle0.3 0.009 0.18 0.30 middle 26.56
middle0.4 0.009 0.18 0.40 middle 26.43
special 0.009 1.00 1.00 special 15.76
Table 4: Performance of different model variants on LongBench. The “ReMamba” model type is our best model. The “middle” type is the variant trained with non-zero $s$ and evaluated with the listed $s$. The “special” variant compresses the entire prompt using $\rho=0.009$ and then appends the compressed hidden states to the end of the original prompt in the second stage.

Why compress from the start

Our experiments indicate that $s=0$ works best. However, one might wonder how effective compression in the middle of the sequence is. We therefore conduct additional analyses of the impact of compressing the input sequence from different starting positions.

We train a model with $s$ sampled uniformly from the interval [0.1, 0.3] during training. We then evaluate it on LongBench under the same conditions as the ReMamba model (a maximum length of 6k tokens, $p=0.18$, and $\rho=0.009$), reporting average scores for $s$ ranging from 0 to 0.4. Additionally, we train a special model variant that compresses the entire prompt with $\rho=0.009$ and appends the compressed hidden states to the end of the original prompt in the second stage.

Table 4 presents the results of these experiments. We observe a performance degradation when the compression is applied in the middle of the sequence. The special model variant performs even worse than the finetuned Mamba baseline.

This degradation can be explained by the disruption caused to the causal language modeling nature of the Mamba model. When compressed information is integrated into the initial position, the subsequent language modeling process can proceed without modification, effectively treating the compressed data as a specialized non-zero initial state. Conversely, inserting those compressed hidden states as tokens within the sequence disrupts the causal language modeling paradigm, which assumes complete sentences as input. This incongruity hinders the model’s ability to maintain a coherent state space and can lead to performance degradation. Among the tested models, the special model variant that appends compressed hidden states to the end of the original prompt exhibits the most pronounced negative impact due to the significant disruption of the model’s expected input structure.
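The claim that compressed information placed at the start acts as a non-zero initial state can be checked on a toy scalar recurrence. In the sketch below, a simplified stand-in for Mamba's selective scan with assumed constants `a` and `b` and hypothetical toy inputs, scanning a prefix followed by the rest of the sequence gives exactly the same final state as scanning only the rest from the state left by the prefix. Appending the same items at the end yields a different state; the sketch shows only that position matters to the recurrence, while the accuracy drop is the paper's empirical finding.

```python
import math
from typing import Sequence


def scan(xs: Sequence[float], deltas: Sequence[float], h0: float = 0.0,
         a: float = -1.0, b: float = 1.0) -> float:
    """Toy selective-SSM recurrence h_t = exp(delta_t * a) * h_{t-1} + delta_t * b * x_t.

    Only the final state is kept, so memory is constant in the sequence length.
    """
    h = h0
    for x, d in zip(xs, deltas):
        h = math.exp(d * a) * h + d * b * x
    return h


# Hypothetical toy values: a "compressed" prefix and the remaining prompt.
prefix_x, prefix_d = [1.0, 2.0], [0.5, 0.5]
rest_x, rest_d = [3.0], [0.2]

joint = scan(prefix_x + rest_x, prefix_d + rest_d)           # prefix placed at the start
as_init = scan(rest_x, rest_d, h0=scan(prefix_x, prefix_d))  # prefix folded into h0
appended = scan(rest_x + prefix_x, rest_d + prefix_d)        # prefix appended at the end
```

Here `joint` equals `as_init` exactly, while `appended` differs (about 1.52 versus 1.67).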

Despite these challenges, the models that compress in the middle still outperform the finetuned Mamba baseline (25.96 to 26.86 vs. 24.63), indicating that the method remains effective even when compression does not start at the beginning of the prompt.

Conclusions

This study investigates the long-context efficiency challenges posed by Mamba models, hypothesizing that distant information within these models is subject to substantial degradation. In response, we introduce ReMamba, a novel approach that compresses and selectively preserves critical information during an initial forward pass. This compressed information is subsequently integrated into the state space during a second forward pass, capitalizing on Mamba’s inherent selective mechanism. Notably, ReMamba incurs minimal computational overhead while substantially enhancing Mamba’s long-context performance, thereby offering a promising avenue for advancing the Mamba model family.

References

  • An et al. (2023) An, C.; Gong, S.; Zhong, M.; Zhao, X.; Li, M.; Zhang, J.; Kong, L.; and Qiu, X. 2023. L-Eval: Instituting Standardized Evaluation for Long Context Language Models. arXiv:2307.11088.
  • Bai et al. (2024) Bai, Y.; Lv, X.; Zhang, J.; Lyu, H.; Tang, J.; Huang, Z.; Du, Z.; Liu, X.; Zeng, A.; Hou, L.; Dong, Y.; Tang, J.; and Li, J. 2024. LongBench: A Bilingual, Multitask Benchmark for Long Context Understanding. arXiv:2308.14508.
  • Ben-Kish et al. (2024) Ben-Kish, A.; Zimerman, I.; Abu-Hussein, S.; Cohen, N.; Globerson, A.; Wolf, L.; and Giryes, R. 2024. DeciMamba: Exploring the Length Extrapolation Potential of Mamba. arXiv:2406.14528.
  • Chen et al. (2024) Chen, Y.; Qian, S.; Tang, H.; Lai, X.; Liu, Z.; Han, S.; and Jia, J. 2024. LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models. In The Twelfth International Conference on Learning Representations, ICLR 2024, Vienna, Austria, May 7-11, 2024. OpenReview.net.
  • Chevalier et al. (2023) Chevalier, A.; Wettig, A.; Ajith, A.; and Chen, D. 2023. Adapting Language Models to Compress Contexts. In Bouamor, H.; Pino, J.; and Bali, K., eds., Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing, EMNLP 2023, Singapore, December 6-10, 2023, 3829–3846. Association for Computational Linguistics.
  • Dao and Gu (2024) Dao, T.; and Gu, A. 2024. Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality. arXiv:2405.21060.
  • Gao, Cao, and Li (2024) Gao, J.; Cao, Z.; and Li, W. 2024. SelfCP: Compressing Over-Limit Prompt via the Frozen Large Language Model Itself. arXiv:2405.17052.
  • Gao et al. (2023) Gao, L.; Tow, J.; Abbasi, B.; Biderman, S.; Black, S.; DiPofi, A.; Foster, C.; Golding, L.; Hsu, J.; Le Noac’h, A.; Li, H.; McDonell, K.; Muennighoff, N.; Ociepa, C.; Phang, J.; Reynolds, L.; Schoelkopf, H.; Skowron, A.; Sutawika, L.; Tang, E.; Thite, A.; Wang, B.; Wang, K.; and Zou, A. 2023. A framework for few-shot language model evaluation. https://zenodo.org/records/10256836. Accessed: 2024-May-29.
  • Ge et al. (2024) Ge, T.; Hu, J.; Wang, L.; Wang, X.; Chen, S.-Q.; and Wei, F. 2024. In-context Autoencoder for Context Compression in a Large Language Model. arXiv:2307.06945.
  • Geng and Liu (2023) Geng, X.; and Liu, H. 2023. OpenLLaMA: An Open Reproduction of LLaMA. https://github.com/openlm-research/open_llama. Accessed: 2024-May-29.
  • Gu and Dao (2024) Gu, A.; and Dao, T. 2024. Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv:2312.00752.
  • Hu et al. (2022) Hu, E. J.; Shen, Y.; Wallis, P.; Allen-Zhu, Z.; Li, Y.; Wang, S.; Wang, L.; and Chen, W. 2022. LoRA: Low-Rank Adaptation of Large Language Models. In The Tenth International Conference on Learning Representations, ICLR 2022, Virtual Event, April 25-29, 2022. OpenReview.net.
  • Hwang et al. (2024) Hwang, D.; Wang, W.; Huo, Z.; Sim, K. C.; and Mengibar, P. M. 2024. TransformerFAM: Feedback attention is working memory. arXiv:2404.09173.
  • Lieber et al. (2024) Lieber, O.; Lenz, B.; Bata, H.; Cohen, G.; Osin, J.; Dalmedigos, I.; Safahi, E.; Meirom, S.; Belinkov, Y.; Shalev-Shwartz, S.; Abend, O.; Alon, R.; Asida, T.; Bergman, A.; Glozman, R.; Gokhman, M.; Manevich, A.; Ratner, N.; Rozen, N.; Shwartz, E.; Zusman, M.; and Shoham, Y. 2024. Jamba: A Hybrid Transformer-Mamba Language Model. arXiv:2403.19887.
  • Mu, Li, and Goodman (2023) Mu, J.; Li, X.; and Goodman, N. 2023. Learning to Compress Prompts with Gist Tokens. In Oh, A.; Naumann, T.; Globerson, A.; Saenko, K.; Hardt, M.; and Levine, S., eds., Advances in Neural Information Processing Systems, volume 36, 19327–19352. Curran Associates, Inc.
  • Mukherjee et al. (2023) Mukherjee, S.; Mitra, A.; Jawahar, G.; Agarwal, S.; Palangi, H.; and Awadallah, A. 2023. Orca: Progressive Learning from Complex Explanation Traces of GPT-4. arXiv:2306.02707.
  • Park et al. (2024) Park, J.; Park, J.; Xiong, Z.; Lee, N.; Cho, J.; Oymak, S.; Lee, K.; and Papailiopoulos, D. 2024. Can Mamba Learn How to Learn? A Comparative Study on In-Context Learning Tasks. arXiv:2402.04248.
  • Raposo et al. (2024) Raposo, D.; Ritter, S.; Richards, B.; Lillicrap, T.; Humphreys, P. C.; and Santoro, A. 2024. Mixture-of-Depths: Dynamically allocating compute in transformer-based language models. arXiv:2404.02258.
  • Ren et al. (2024) Ren, L.; Liu, Y.; Lu, Y.; Shen, Y.; Liang, C.; and Chen, W. 2024. Samba: Simple Hybrid State Space Models for Efficient Unlimited Context Language Modeling. arXiv:2406.07522.
  • Vaswani et al. (2017) Vaswani, A.; Shazeer, N.; Parmar, N.; Uszkoreit, J.; Jones, L.; Gomez, A. N.; Kaiser, Ł.; and Polosukhin, I. 2017. Attention is All you Need. In Guyon, I.; Luxburg, U. V.; Bengio, S.; Wallach, H.; Fergus, R.; Vishwanathan, S.; and Garnett, R., eds., Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc.
  • Waleffe et al. (2024) Waleffe, R.; Byeon, W.; Riach, D.; Norick, B.; Korthikanti, V.; Dao, T.; Gu, A.; Hatamizadeh, A.; Singh, S.; Narayanan, D.; Kulshreshtha, G.; Singh, V.; Casper, J.; Kautz, J.; Shoeybi, M.; and Catanzaro, B. 2024. An Empirical Study of Mamba-based Language Models. arXiv:2406.07887.
  • Wen, Dang, and Lyu (2024) Wen, K.; Dang, X.; and Lyu, K. 2024. RNNs are not Transformers (Yet): The Key Bottleneck on In-context Retrieval. arXiv:2402.18510.
  • Yang et al. (2024a) Yang, D.; Han, X.; Gao, Y.; Hu, Y.; Zhang, S.; and Zhao, H. 2024a. PyramidInfer: Pyramid KV Cache Compression for High-throughput LLM Inference. arXiv:2405.12532.
  • Yang et al. (2024b) Yang, K.; Ackermann, J.; He, Z.; Feng, G.; Zhang, B.; Feng, Y.; Ye, Q.; He, D.; and Wang, L. 2024b. Do Efficient Transformers Really Save Computation? arXiv:2402.13934.