Beyond KV Caching: Shared Attention for Efficient LLMs

Liao Bingli1, Danilo Vasconcellos Vargas1
Abstract

The efficiency of large language models (LLMs) remains a critical challenge, particularly in contexts where computational resources are limited. Traditional attention mechanisms in these models, while powerful, require significant computational and memory resources due to the necessity of recalculating and storing attention weights across different layers. This paper introduces a novel Shared Attention (SA) mechanism, designed to enhance the efficiency of LLMs by directly sharing computed attention weights across multiple layers. Unlike previous methods that focus on sharing intermediate Key-Value (KV) caches, our approach utilizes the isotropic tendencies of attention distributions observed in advanced LLMs post-pretraining to reduce both the FLOPs and the size of the KV cache required during inference. We empirically demonstrate that implementing SA across various LLMs results in minimal accuracy loss on standard benchmarks. Our findings suggest that SA not only conserves computational resources but also maintains robust model performance, thereby facilitating the deployment of more efficient LLMs in resource-constrained environments. Code: https://github.com/metacarbon/shareAtt

Introduction

The rapid growth of large language models (LLMs) has brought forth significant challenges in terms of computational and memory efficiency during inference. Traditional approaches, such as Multi-Query Attention (MQA) (Shazeer 2019) and Grouped-Query Attention (GQA) (Ainslie et al. 2023), have made strides in reducing the key-value (KV) cache size by sharing keys and values across multiple heads within a layer. More recently, Cross-Layer Attention (CLA) has extended this concept by sharing keys and values across adjacent layers, further reducing memory requirements without substantially impacting model performance (Brandon et al. 2024). Despite these advancements, the need for more efficient methods continues to grow, particularly as models scale and are deployed in resource-constrained environments.

In this paper, we introduce a novel method termed Shared Attention (SA), which significantly reduces the KV cache requirements and computational load during inference for LLMs. Unlike previous methods that focus on sharing KV caches either within the same layer or between adjacent layers, our approach is inspired by the inherent similarity of attention weight distributions across layers: sharing these weights directly can further reduce the need for repeated key and value computations. This approach not only reduces the KV cache size but also circumvents the need for the computationally expensive softmax operation in the layers that reuse the shared weights, leading to a more efficient inference process.

The key contributions of our work are summarized as follows:

  1. We propose a novel Shared Attention mechanism that reduces computational and memory overhead by directly sharing pre-computed attention weights across multiple layers in LLMs.

  2. We empirically validate the effectiveness of Shared Attention by implementing it across various benchmarks and demonstrate that it achieves comparable accuracy.

  3. Our analysis of attention isotropy across pretrained LLMs provides insights into how attention mechanisms stabilize and become more uniform across layers as training progresses. This understanding informs the optimal layer ranges for applying Shared Attention.

Shared Attention

In this section, we present the motivation for Shared Attention (SA), the method itself, and a comparison with existing KV-sharing mechanisms.

Figure 1: Illustration of various sharing algorithms. The MQA and GQA methods share the Key and Value caches across query heads within the same layer to reduce memory usage. The CLA method extends this by sharing the Key and Value caches across different layers. Our method, Shared Attention, advances this concept further by sharing the attention weights across multiple layers.

Motivation

The self-attention mechanism in transformer models is typically defined as $\text{softmax}(\frac{QK^{T}}{\sqrt{d}})V$, where $Q$, $K$, and $V$ represent the query, key, and value matrices respectively, and $d$ is the dimension of the key vectors. This formulation necessitates the recomputation of attention weights at each layer, a computationally intensive task, particularly when the model is deployed in inference mode. To mitigate this, the concept of a KV-cache is employed, reducing the need to recompute the $K$ and $V$ matrices for previously encountered tokens.

While prior methodologies have focused on sharing KV caches at different levels to minimize memory overhead, they predominantly operate under the assumption that attention weights differ significantly across layers, thereby necessitating individual computations to capture diverse contextual dependencies effectively. This assumption prompts a critical inquiry: Are the attention weights indeed markedly different across layers, or is this variation minimal enough to allow for a unified approach across multiple layers?

To explore this, we conducted an empirical analysis on the distribution of attention weights across different layers of the model. Based on the Llama2-7B-chat model, we processed the Massive Multitask Language Understanding (MMLU) dataset (Hendrycks et al. 2020) to extract the attention matrices, $\text{softmax}(\frac{QK^{T}}{\sqrt{d}})$, for each layer. Given the variability in sequence lengths, we standardized these matrices to a uniform size by applying zero-padding to align them to a consistent shape of $\text{maxlen} \times \text{maxlen}$.

Our analysis employed the cosine similarity metric to compare the attention matrices of all layers, revealing a notably high degree of similarity across most layers, particularly from indices 3 to 30. In contrast, the initial layers (0 and 1) and the final output layer (31) exhibited substantially lower similarity to the middle layers. This observation is intuitive, as the early layers are closer to the input token embeddings and must adjust their attention distribution frequently to abstract semantic meaning from diverse inputs. Similarly, the final layer's unique role in predicting the next token justifies its distinct attention pattern.
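The padding and comparison steps described above can be sketched as follows. This is a minimal NumPy illustration rather than the authors' pipeline: the helper names (`pad_to_maxlen`, `layer_similarity`, `random_causal_attention`) are hypothetical, the attention matrices are random stand-ins for matrices extracted from a real model, and it assumes one aggregated matrix per layer (for example, averaged over heads), which the text does not specify.

```python
import numpy as np


def pad_to_maxlen(attn, maxlen):
    """Zero-pad a (seq_len, seq_len) attention matrix to (maxlen, maxlen)."""
    seq_len = attn.shape[0]
    padded = np.zeros((maxlen, maxlen), dtype=attn.dtype)
    padded[:seq_len, :seq_len] = attn
    return padded


def layer_similarity(layer_attn):
    """Pairwise cosine similarity between flattened per-layer matrices.

    layer_attn has shape (num_layers, maxlen, maxlen); the result has
    shape (num_layers, num_layers).
    """
    flat = layer_attn.reshape(layer_attn.shape[0], -1)
    norms = np.linalg.norm(flat, axis=1, keepdims=True)
    unit = flat / np.clip(norms, 1e-12, None)
    return unit @ unit.T


def random_causal_attention(seq_len, rng):
    """Random row-stochastic causal attention matrix (stand-in for real data)."""
    scores = rng.normal(size=(seq_len, seq_len))
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores[future] = -np.inf
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
maxlen = 8
# Four hypothetical "layers", each observed on a sequence of length 5.
layers = np.stack(
    [pad_to_maxlen(random_causal_attention(5, rng), maxlen) for _ in range(4)]
)
similarity = layer_similarity(layers)
print(similarity.round(3))
```

With real attention maps in place of the random ones, entry (i, j) of `similarity` corresponds to one cell of the layer-by-layer similarity surface.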

Inspired by these findings, we hypothesize that the high similarity in attention weights across the majority of layers could allow for a shared representation of these weights, thus eliminating the need for separate softmax computations in each layer and reducing the key cache size. Such a strategy could not only streamline the inference process but also enhance computational efficiency significantly.

Based on the observed uniformity in attention weights, we propose a novel algorithm as shown in Algorithm 1, Shared Attention, which utilizes a single shared attention matrix across multiple layers. This approach fundamentally redefines the operational paradigm by maintaining a consistent attention mechanism across various contextual layers, thereby reducing redundancy and enhancing inference speed.

Algorithm 1 Shared Attention Algorithm

Input: Set of layers $L$, input tokens $X$
Parameters: Attention span $S$ (e.g., layers 23 to 30)
Output: Updated attention weights across specified layers

1:  Initialize attention weights $A \leftarrow \emptyset$
2:  for each layer $l \in S$ do
3:     if first layer in $S$ then
4:        Compute initial attention weights $A_{l} \leftarrow \text{softmax}(\frac{Q_{l}K_{l}^{T}}{\sqrt{d_{k}}})$
5:        Set $A \leftarrow A_{l}$
6:     else
7:        Share attention weights $A_{l} \leftarrow A$
8:     end if
9:     Apply shared attention to compute outputs $O_{l} \leftarrow A_{l} \cdot V_{l}$
10:  end for
11:  Adjust subsequent layers' inputs using outputs from $S$
12:  Continue processing remaining layers with standard attention
13:  return Final output after processing all layers
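As an illustration of the loop in Algorithm 1, the following NumPy sketch computes the attention weights once, at the first layer of the span, and reuses them for every later layer in the span. It is a simplified single-head version under stated assumptions, not the released implementation: the function names are hypothetical, a causal mask is assumed, and the per-layer value matrices are supplied directly, whereas in a real model each layer's values depend on the previous layer's output.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def causal_attention_weights(q, k):
    """softmax(Q K^T / sqrt(d_k)) with a causal mask; q and k are (seq_len, d_k)."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)
    future = np.triu(np.ones(scores.shape, dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    return softmax(scores)


def shared_attention_span(qs, ks, vs, span):
    """Apply shared attention over the layer indices in `span`.

    qs, ks, vs map layer index -> (seq_len, d) arrays. Only the first
    layer of the span needs Q and K; later layers reuse its weights.
    """
    shared = None
    outputs = {}
    for position, layer in enumerate(span):
        if position == 0:
            shared = causal_attention_weights(qs[layer], ks[layer])
        # Every layer in the span applies the same weights to its own values.
        outputs[layer] = shared @ vs[layer]
    return shared, outputs


rng = np.random.default_rng(0)
span = [23, 24, 25]
# Q and K are provided for the first layer of the span only.
qs = {23: rng.normal(size=(4, 8))}
ks = {23: rng.normal(size=(4, 8))}
vs = {layer: rng.normal(size=(4, 8)) for layer in span}
weights, outs = shared_attention_span(qs, ks, vs, span)
print(weights.round(3))
```

Because `qs` and `ks` hold entries only for layer 23, the example also shows that the remaining layers in the span need neither queries nor keys.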
Figure 2: Layer-wise similarity of attention weights across various LLMs. The x-axis and y-axis represent the layer indices, while the z-axis depicts the cosine similarity values. The distinct similarity patterns are indicative of the specific functional roles each group of layers plays within the overall architecture.

Comparison with Existing Approaches

The original self-attention mechanism in Transformers, characterized by the Multi-Head Attention (MHA) model, necessitates caching the keys (K𝐾Kitalic_K) and values (V𝑉Vitalic_V) in each head and layer to accelerate inference (Vaswani et al. 2017). This requirement has historically imposed a significant memory overhead, prompting a series of innovations aimed at reducing this burden.

Among these, Multi-Query Attention (MQA) and its more generalized counterpart, Grouped-Query Attention (GQA), consolidate the KV cache by allowing multiple query heads within the same layer to share a singular set of K and V matrices. This approach effectively reduces the number of unique key and value pairs that must be stored and retrieved during the computation process. Subsequently, Cross-Layer Attention (CLA) extends this concept by facilitating the sharing of K and V matrices across different layers, thereby offering further reductions in the memory footprint required for KV storage.

Our method, however, introduces a fundamentally different paradigm in addressing the challenges of self-attention. While previous methods have focused on reducing the redundancy in storing K and V matrices, our approach centers on the optimization of the computation of attention weights themselves. In standard practice, the cached keys (K𝐾Kitalic_K) are primarily utilized to compute attention weights in conjunction with the queries (Q𝑄Qitalic_Q). Instead of indirectly facilitating this interaction through shared KV matrices, our method proposes the direct sharing of the resultant attention weights—specifically, the softmax-normalized scores.

This not only diminishes the memory requirements by obviating the need to store separate sets of keys for each layer but also significantly reduces the computational complexity. By sharing the pre-computed softmax results across layers, our approach circumvents the repeated calculation of softmax, which is often one of the most computationally intensive operations in the attention mechanism. This efficiency gain is reflected in a substantial reduction in the number of floating-point operations (FLOPs) required during model inference, enhancing both the speed and scalability of Transformer deployments.
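To make the key-cache saving concrete, the short calculation below estimates the per-token KV cache size with and without sharing. It is a back-of-the-envelope sketch, not a measurement from the paper: the function name is hypothetical, and the configuration (32 layers, 32 key-value heads, head dimension 128, 2 bytes per value) is assumed to match a Llama2-7B-style model stored in 16-bit precision, with an eight-layer span in which seven layers reuse the first layer's weights.

```python
def kv_cache_bytes_per_token(num_layers, num_kv_heads, head_dim,
                             bytes_per_value, shared_layers=0):
    """Per-token KV cache size in bytes.

    `shared_layers` counts layers that reuse attention weights from an
    earlier layer and therefore store values but no keys.
    """
    per_layer_tensor = num_kv_heads * head_dim * bytes_per_value
    key_bytes = (num_layers - shared_layers) * per_layer_tensor
    value_bytes = num_layers * per_layer_tensor
    return key_bytes + value_bytes


# Assumed configuration: 32 layers, 32 KV heads, head_dim 128, 16-bit values.
baseline = kv_cache_bytes_per_token(32, 32, 128, 2)
# An eight-layer span: the first layer keeps its keys, seven layers drop them.
with_sa = kv_cache_bytes_per_token(32, 32, 128, 2, shared_layers=7)
saving = 1 - with_sa / baseline

print(f"baseline: {baseline} bytes/token")
print(f"with SA:  {with_sa} bytes/token")
print(f"saving:   {saving:.1%}")
```

Under these assumptions the saving equals 7 of the 64 per-layer key and value tensors; the values still have to be cached in every layer.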

Unlike traditional methods that optimize memory use by sharing physical keys and values across layers or heads, our Shared Attention model innovates on the computational process itself, exploiting the consistent patterns in attention weights to streamline operations across multiple layers of the Transformer architecture.

Isotropic Attention Distribution

In an extensive analysis of layer-specific attention weights across a spectrum of LLMs, we explored the attention dynamics within models such as Llama2-7B-chat, Llama3-8B-instruct, Llama3-70B-instruct, Baichuan2-7B-chat, Qwen2-7B-instruct, and Qwen2-72B-instruct (Touvron et al. 2023; Yang et al. 2023; Bai et al. 2023). These models were evaluated using the MMLU.

Our investigations reveal a self-organization pattern in the attention weights across these diverse models. As depicted in Figure 2, there exists a consistent global similarity pattern in the layers’ attention weights across all tested models. This pattern suggests an inherent structural characteristic in the way LLMs process information, which can be broadly segmented into four distinct groups:

  • Group 1: Comprising the initial layers (indices 0 and 1), this group is situated closest to the input tokens and primarily focuses on abstracting token-level semantic information. These layers exhibit data-dependent attention patterns that are crucial for the initial semantic processing of the inputs.

  • Group 2: This group includes layers immediately following the first group and extends up to layer index 5. Layers in this segment demonstrate high internal similarity in attention weights but are markedly different from those in other groups. These layers likely serve as transitional zones where intermediate semantic features are refined.

  • Group 3: Encompassing layers post-Group 2 and extending to the penultimate layer, this is the largest group both in terms of the number of layers and their role within the architecture. The layers within this group display a high degree of similarity, suggesting an isotropy in the attention mechanism where the refined features are consistently utilized to inform the model’s deeper contextual understanding.

  • Group 4: The final group, consisting solely of the output layer, distinctively processes the aggregated contextual information to generate outputs. This layer’s attention weights diverge from those observed in other layers, underscoring its specialized role in the final decision-making process.

The distinct attention weight patterns identified across these groups reinforce the concept of functional specialization within LLMs. This segmentation not only highlights the diverse roles of different layers in processing inputs but also supports the potential for optimizing computational strategies, such as our proposed Shared Attention method, by manipulating these inherent patterns to reduce computational redundancy.

Figure 3: Evolution of layer attention weight similarity throughout the pretraining phase of the Baichuan2 7B model, as the number of training tokens grows from 220 billion to 2.6 trillion. The color gradient in the visualization represents cosine similarity, illustrating the transition in attention patterns from the initial to the advanced stages of pretraining.

Dynamics During Pretraining

To elucidate the formation and evolution of attention weight patterns during the pretraining phase of LLMs, we utilized intermediate checkpoints of the Baichuan2 7B model, provided by the model developers. These checkpoints, spanning from 0.2T to 2.6T tokens processed, offer a unique vantage point from which to observe the dynamic shifts in attention mechanisms as the model gains exposure to an increasing volume of data.

We applied a consistent metric for measuring the similarity of attention weights across layers at each pretraining checkpoint. Additionally, the final chat model, fine-tuned to align with human reference responses, was included to benchmark the evolution against practical application outcomes. The dynamics of these attention weights are visualized in Figure 3, which illustrates the progressive differentiation and stabilization of attention patterns across the model’s layers.

As observed in the early pretraining stage at 0.2T tokens, Groups 1 and 2 appear merged, indicating a less differentiated processing strategy across these initial layers. This combination suggests that early in training, the model does not distinctly separate token-level semantic processing from intermediate semantic refinement. However, as the model progresses to 1.0T tokens, a clear division emerges between Groups 1 and 2. This separation aligns with the model beginning to form more specialized and efficient strategies for handling different types of information across its architecture.

The similarity within Group 3, which encompasses the bulk of the model’s layers, shows a marked improvement from a similarity score of 0.8 to 0.9. This increase is indicative of the model’s attention mechanism stabilizing and becoming more consistent in its approach to processing the bulk of contextual information.

The training advancements observed across the pretraining checkpoints not only demonstrate significant shifts in the internal structure of the model’s attention mechanisms but also correlate positively with performance improvements on multiple benchmarks. This includes results on the MMLU, CMMLU (Li et al. 2023), and C-Eval (Huang et al. 2024) 5-shot accuracy tests, which have reportedly improved from a baseline accuracy of 0.25 to 0.50 (Yang et al. 2023). This notable enhancement underscores the intrinsic link between the refinement of attention mechanisms within LLMs and their enhanced capabilities in natural language understanding tasks.

Moreover, further examination of the model's development, as reported in the supplementary material, reveals that the similarity within Group 3 (the core contextual processing layers of the model) continues to increase after the alignment stage. This observation suggests that the alignment process, typically aimed at fine-tuning the model to more closely mirror human-like understanding and response generation, also contributes to the stabilization of the model's attention mechanisms.

Experiments and Discussion

To validate the efficacy of our proposed Shared Attention (SA) method, we conducted a series of experiments. These experiments were designed to test the robustness of SA under various configurations and to evaluate its performance on widely recognized benchmarks.

Initially, we applied the SA mechanism directly to advanced LLMs without any additional training to assess its impact on pre-trained models. This experiment aimed to understand the immediate effects of SA when integrated into existing model architectures. We evaluated the performance of these models on standard LLM benchmarks, including GLUE (General), GSM8k (Arithmetic), HellaSwag (Reasoning), and MMLU (Knowledge) (Wang et al. 2018; Cobbe et al. 2021; Zellers et al. 2019). As anticipated, the direct application of SA resulted in a loss of accuracy on some benchmarks, since the models were not retrained to adapt to the Shared Attention mechanism. Due to computational constraints, it was impractical for our team to pretrain an LLM from scratch incorporating SA.

To further probe the capabilities of SA under a training regimen, we fine-tuned base LLMs equipped with Shared Attention on the publicly available Instruct dataset (Taori et al. 2023). After fine-tuning, these models were tested against the same benchmarks to identify any performance changes. This approach allowed us to measure the adaptability of SA when models are trained to accommodate its dynamics.

These experiments collectively demonstrate the potential of Shared Attention to modify the traditional attention mechanism in LLMs, showing a promising avenue for reducing computational demands while maintaining, and in some cases enhancing, model performance. The detailed results and further discussion on each benchmark and dataset are provided in the subsequent sections.

Experimental Setup

For the fine-tuning experiments, we utilized the Llama2-7B and Llama3-8B base models. These experiments were conducted on a robust hardware configuration consisting of two NVIDIA A100 80GB GPUs. Optimization of the models was carried out using the AdamW optimizer, with an initial learning rate set at $2 \times 10^{-5}$. We employed the bf16 datatype for model parameters, which enhances the numeric range and stability during backpropagation, crucial for maintaining precision in large model training.

Each GPU handled a micro-batch size of 16, leveraging gradient accumulation techniques to effectively manage the computational load. Additionally, we utilized DeepSpeed Zero Stage 3 to optimize the distribution of model and optimizer parameters and enhance memory management across the GPUs, ensuring efficient use of available resources. The fine-tuning process spanned two epochs and employed the standard Alpaca instruction format, which is designed to improve the responsiveness and accuracy of the models in handling instruction-based tasks.
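For readers who want to reproduce a comparable setup, the reported hyperparameters map onto a DeepSpeed configuration roughly as sketched below. This is an assumed reconstruction, not the authors' configuration file: only the values stated in the text (micro-batch size 16, bf16, ZeRO Stage 3, AdamW with learning rate 2e-5) are filled in. The number of gradient accumulation steps is not reported, so it is left out of the configuration and appears only as a hypothetical argument in the helper function.

```python
# Assumed DeepSpeed-style configuration built from the values reported above.
ds_config = {
    "train_micro_batch_size_per_gpu": 16,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 3},
    "optimizer": {"type": "AdamW", "params": {"lr": 2e-5}},
}


def samples_per_optimizer_step(num_gpus, micro_batch, grad_accum_steps):
    """Effective batch size for one optimizer update."""
    return num_gpus * micro_batch * grad_accum_steps


# Two GPUs as reported; grad_accum_steps=4 is a hypothetical value used
# purely for illustration, because the paper does not state it.
effective_batch = samples_per_optimizer_step(
    num_gpus=2,
    micro_batch=ds_config["train_micro_batch_size_per_gpu"],
    grad_accum_steps=4,
)
print(effective_batch)
```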

Direct Application of Shared Attention

The application of SA was tested across discrete segments of layers within the Llama2-7B and Llama3-8B models, each comprising 32 layers in total. To evaluate the robustness and adaptability of SA as shown in Figure 4, it was implemented in varying layer segments, ranging from narrower spans such as four layers (e.g., SA:15∼18) to broader spans such as eight layers (e.g., SA:23∼30).

Figure 4: The figure illustrates the implementation of Shared Attention within specific layer segments of the model. Shared Attention spans from layer 27 to 30 for a four-layer segment and from layer 23 to 30 for an eight-layer segment.

Preliminary assessments of SA in the earlier layers of Llama2-7B (e.g., layers 3 to 6) resulted in an explosion of perplexity, indicating significant disruptions in the model's ability to predict subsequent tokens accurately. This phenomenon underscores the crucial role that attention score variances play in the model's early stages of processing, which are essential for initial context setting and feature extraction. To quantitatively assess the impact of attention variance throughout the model, we conducted a detailed variance analysis. We applied the same computational method used to obtain attention mean scores to calculate the variance of attention weights in Llama2-7B and Llama3-8B while processing the MMLU dataset. We further explored the potential influence of attention variance in downstream layers by computing a weighted cumulative variance. This metric aggregates the variances of all downstream layers starting from each specific layer, weighted by the average of these summed variances. As illustrated in Figure 5, the analysis revealed that early layers exhibited significantly higher weighted variances than later layers. This variance tends to decrease as one progresses through the model's architecture, suggesting a stabilization of attention mechanisms in the later layers. Given these results, our experiments predominantly focused on the application of SA in the later layers, where such variances appear to stabilize.
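The weighted cumulative variance is described only in words, so the sketch below shows one possible reading rather than the authors' exact formula. It assumes that, for each layer, the variances of that layer and all later layers are summed, and that each sum is then divided by the mean of these sums. The function name and the toy input values are hypothetical.

```python
import numpy as np


def weighted_cumulative_variance(layer_var):
    """One reading of the metric: downstream variance sums, mean-normalised.

    layer_var has shape (num_layers,) or (num_layers, num_heads) and holds
    the attention-weight variance of each layer (and head).
    """
    layer_var = np.asarray(layer_var, dtype=float)
    # Sum of variances from each layer through the last layer.
    downstream = np.cumsum(layer_var[::-1], axis=0)[::-1]
    # Normalise by the average of the downstream sums (assumed weighting).
    return downstream / downstream.mean(axis=0, keepdims=True)


# Toy per-layer variances for a four-layer model (illustrative values only).
toy_variances = [4.0, 2.0, 1.0, 1.0]
wcv = weighted_cumulative_variance(toy_variances)
print(wcv.round(4))
```

Under this reading the curve is non-increasing by construction, so the informative part of such a plot is how quickly it falls from one layer to the next.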

| Model | GLUE | GSM8K (5-shot) | HellaSwag | MMLU |
|---|---|---|---|---|
| Llama2-7B | 0.4050 ± 0.0019 | 0.1395 ± 0.0095 | 0.5713 ± 0.0049 | 0.4119 ± 0.0041 |
| Llama2-7B (SA:23∼30) | 0.3819 ± 0.0019 | 0.0728 ± 0.0072 | 0.5575 ± 0.0050 | 0.3794 ± 0.0040 |
| Llama2-7B (SA:27∼30) | 0.3882 ± 0.0019 | 0.1243 ± 0.0091 | 0.5616 ± 0.0050 | 0.4056 ± 0.0041 |
| Llama2-7B (SA:23∼26) | 0.4351 ± 0.0019 | 0.1122 ± 0.0087 | 0.5681 ± 0.0049 | 0.3994 ± 0.0040 |
| Llama2-7B (SA:19∼22) | 0.3996 ± 0.0019 | 0.0834 ± 0.0076 | 0.5553 ± 0.0050 | 0.3926 ± 0.0040 |
| Llama2-7B (SA:15∼18) | 0.3731 ± 0.0019 | 0.0220 ± 0.0040 | 0.4790 ± 0.0050 | 0.3378 ± 0.0047 |
| Llama2-7B-Instruct-SFT | 0.5372 ± 0.0019 | 0.1440 ± 0.0097 | 0.5772 ± 0.0049 | 0.3722 ± 0.0040 |
| Llama2-7B-Instruct-SFT (SA:23∼30) | 0.5401 ± 0.0019 | 0.0758 ± 0.0073 | 0.5671 ± 0.0049 | 0.3717 ± 0.0040 |
| Llama3-8B | 0.4804 ± 0.0019 | 0.5155 ± 0.0138 | 0.6009 ± 0.0049 | 0.6198 ± 0.0038 |
| Llama3-8B (SA:23∼30) | 0.5595 ± 0.0019 | 0.3275 ± 0.0129 | 0.6011 ± 0.0049 | 0.6122 ± 0.0038 |
| Llama3-8B (SA:27∼30) | 0.5532 ± 0.0019 | 0.4526 ± 0.0137 | 0.6060 ± 0.0049 | 0.6163 ± 0.0038 |
| Llama3-8B (SA:23∼26) | 0.5024 ± 0.0019 | 0.4556 ± 0.0137 | 0.5993 ± 0.0049 | 0.6189 ± 0.0038 |
| Llama3-8B (SA:19∼22) | 0.5115 ± 0.0019 | 0.3745 ± 0.0133 | 0.5829 ± 0.0049 | 0.6181 ± 0.0038 |
| Llama3-8B (SA:15∼18) | 0.4685 ± 0.0019 | 0.0136 ± 0.0032 | 0.5307 ± 0.0050 | 0.3019 ± 0.0038 |
Table 1: Performance metrics for different models across tasks
Figure 5: The figure displays the weighted cumulative variance for the Llama2-7B-chat and Llama3-8B-instruct models. The two lower axes represent the model’s structure: the left axis details the 32 layers, and the right axis shows the 32 heads within each layer. The z-axis represents the variance values.

The outcomes of these experiments, as summarized in Table 1, reveal interesting patterns. For the Llama2-7B model, implementing SA in the later layers (e.g., SA:23∼26 and SA:27∼30) maintained relatively stable performance across a variety of benchmarks, including GLUE and MMLU. Conversely, extending the scope of SA to encompass more layers, particularly mid-level layers such as SA:15∼18, led to a noticeable degradation in tasks requiring mathematical reasoning (GSM8K).
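To put the Table 1 figures on a common scale, the following short script computes the relative change of two Llama2-7B SA variants against the unmodified baseline. The scores are copied from Table 1; the helper name `relative_change` and the choice of rows are illustrative only.

```python
# Scores copied from Table 1 (Llama2-7B rows, direct application of SA).
baseline = {"GLUE": 0.4050, "GSM8K": 0.1395, "HellaSwag": 0.5713, "MMLU": 0.4119}
sa_23_30 = {"GLUE": 0.3819, "GSM8K": 0.0728, "HellaSwag": 0.5575, "MMLU": 0.3794}
sa_27_30 = {"GLUE": 0.3882, "GSM8K": 0.1243, "HellaSwag": 0.5616, "MMLU": 0.4056}


def relative_change(base, variant):
    """Per-task relative change of `variant` with respect to `base`."""
    return {task: (variant[task] - base[task]) / base[task] for task in base}


for name, scores in [("SA:23~30", sa_23_30), ("SA:27~30", sa_27_30)]:
    changes = relative_change(baseline, scores)
    print(name, {task: f"{100 * c:+.1f}%" for task, c in changes.items()})
```

Relative changes make the asymmetry across tasks easier to see: GSM8K starts from a low absolute score, so a small absolute drop is a large relative one.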

In comparison, the Llama3-8B model, which inherently showed higher layer-wise attention similarity as discussed in the previous sections, exhibited less performance deterioration when SA was applied. After implementing SA in the layers closer to the model's output (e.g., SA:27∼30), Llama3-8B even outperformed its original configuration on the GLUE benchmark, suggesting that strategic placement of SA can potentially enhance the model's performance in complex natural language understanding tasks.

Fine-Tuning on Instruct Dataset

Given the computational constraints that preclude the pretraining of LLMs with SA from scratch, we opted to fine-tune existing LLMs to evaluate whether fine-tuning could ameliorate the performance deficits observed with the direct application of SA. This approach was particularly aimed at understanding the adaptability of SA under a more controlled learning regimen.

Fine-tuning was conducted on the publicly available Instruct dataset, which is designed to evaluate models on tasks that require following complex instructions. This dataset was chosen because it challenges the models to utilize their learned representations effectively, making it an ideal benchmark for testing the efficacy of modifications like SA.

The results, as summarized in Table 1, demonstrate a narrowed performance gap between the original models and those modified with SA. For instance, while the original Llama2-7B model outperformed the SA version in direct application tests, the fine-tuned Llama2-7B (SA:23∼30) showed significant improvements across multiple metrics. This suggests that fine-tuning enables the model to better integrate and leverage the Shared Attention mechanism, effectively regaining some of the lost performance noted in the initial application of SA.

These findings indicate the potential of fine-tuning as a viable method for integrating new architectural changes like SA into existing models. The recovery in performance indicates that with adequate training, the initial disadvantages of directly applying SA can be mitigated, leading to enhanced model capabilities that more closely align with or even exceed their original configurations.

Future Directions

Our experimental investigations have demonstrated that implementing Shared Attention (SA) across multiple later layers in LLMs incurs minimal accuracy loss, making it a promising approach for enhancing model efficiency. Furthermore, our analysis reveals a trend towards isotropic attention patterns during the pretraining process, indicating that the models' attention mechanisms tend to stabilize as they process more data.

Given these insights, integrating SA from the pretraining stage appears to be a particularly beneficial strategy. This early integration could allow models to better adapt to the streamlined attention mechanism, potentially improving performance and efficiency across various tasks. Embedding SA from the start might also simplify later adaptations and inherently support efficient attention dynamics.

Another promising research direction involves exploring combinations of SA with other attention-sharing strategies like Cross-Layer Attention (CLA). Combining SA with methods such as CLA could exploit the strengths of both approaches, leading to a more robust and flexible attention mechanism. This holistic approach to attention management could provide a comprehensive solution that maximizes both computational efficiency and model scalability.

By pursuing these avenues, future research can not only refine the application of Shared Attention within LLMs but also explore its full potential in enhancing the architectural and operational efficiency of next-generation language models. These efforts could lead to models that are better equipped to handle the increasing complexity and diversity of tasks in natural language processing.

Related Work

Efficient memory management in transformers is a critical area of research with diverse objectives ranging from reducing memory bandwidth and storage requirements to optimizing computational costs during both training and inference phases. Notably, our work focuses on minimizing the size of the inference Key-Value (KV) cache that persists between model passes, thereby enhancing model efficiency without a significant compromise in performance.

Memory Efficiency in Attention Mechanisms

Significant efforts have been made to improve the efficiency of the KV cache after training. Techniques such as KV cache compression have been explored extensively. For instance, methods like KVQuant (Hooper et al. 2024) and KIVI (Liu et al. 2024b) employ quantization strategies to reduce the memory footprint of KV pairs to just a few bits. Moreover, works such as AttentionSink (Xiao et al. 2023) and Scissorhands (Liu et al. 2024a) introduce sparsity into the KV cache by selectively storing elements based on their proximity or importance to the token being generated, thus reducing the overall storage requirements.
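To make the quantization idea concrete, the following is a minimal sketch of generic uniform asymmetric quantization applied to one group of cached values. It is not the KVQuant or KIVI algorithm; the function names `quantize_group` and `dequantize_group`, the group contents, and the 2-bit setting are assumptions chosen for illustration only.

```python
def quantize_group(values, bits=2):
    """Map floats to integer codes in [0, 2**bits - 1] using the group's min/max.

    Illustrative only: generic asymmetric quantization, not a published method.
    """
    lo, hi = min(values), max(values)
    levels = (1 << bits) - 1
    # Guard against a constant group, where hi == lo would give a zero scale.
    scale = (hi - lo) / levels if hi > lo else 1.0
    codes = [round((v - lo) / scale) for v in values]
    return codes, scale, lo


def dequantize_group(codes, scale, zero_point):
    """Recover approximate floats from the integer codes."""
    return [c * scale + zero_point for c in codes]


# Hypothetical slice of a cached key or value vector.
group = [0.1, -0.4, 0.9, 0.3]
codes, scale, zero_point = quantize_group(group, bits=2)
restored = dequantize_group(codes, scale, zero_point)
print(codes, [round(r, 3) for r in restored])
```

Each value is stored as a 2-bit code plus a shared scale and offset per group, so the reconstruction error per element is bounded by half the quantization step.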

Architectural Innovations for Reducing KV Cache

Architectural modifications aimed at reducing the KV cache size are pivotal in enhancing the efficiency of large language models. Such strategies include limiting the effective sequence length, as seen in Sparse Attention (Child et al. 2019), which constrains attention to local windows to reduce both computational load and memory overhead. Another approach involves replacing traditional softmax attention with scalable alternatives like linear attention (Katharopoulos et al. 2020), which maintains constant space complexity and scales more gracefully with the token count. Additionally, methods such as Grouped-Query Attention (GQA) (Ainslie et al. 2023) and Multi-Query Attention (MQA) (Shazeer 2019) share key and value heads across multiple query heads, significantly decreasing the memory footprint of the KV cache. These innovations collectively reduce redundancy in attention calculations and are directly relevant to our work, informing our development of the Shared Attention mechanism, which further optimizes memory usage by sharing attention weights across layers.
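The effect of sharing KV heads on cache size can be illustrated with a short back-of-the-envelope calculation. The sketch below uses the standard accounting of one key tensor and one value tensor per layer; the helper name `kv_cache_bytes` and the configuration (32 layers, 32 query heads, head dimension 128, 4096 tokens, 16-bit elements) are hypothetical and not taken from the models evaluated in this paper.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    """Bytes needed to cache keys and values for a batch of sequences."""
    # Factor 2: one tensor for keys and one for values in every layer.
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem


# Hypothetical configuration, for illustration only.
config = dict(n_layers=32, head_dim=128, seq_len=4096)
n_query_heads = 32

for name, n_kv_heads in [("MHA", n_query_heads), ("GQA (8 KV heads)", 8), ("MQA", 1)]:
    size = kv_cache_bytes(n_kv_heads=n_kv_heads, **config)
    print(f"{name:>17}: {size / 2**20:8.1f} MiB")
```

Under these assumptions the cache shrinks in proportion to the number of KV heads: 2 GiB for full multi-head attention, 512 MiB with 8 KV heads, and 64 MiB with a single KV head.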

Conclusion

In this paper, we explored the attention dynamics within advanced LLMs and observed that the attention distribution across layers tends to isotropize following extensive pretraining. This isotropic pattern of attention, in which layers exhibit similar attention distributions, inspired a novel approach to attention sharing that departs from conventional methods.

Methods such as MQA and CLA share KV caches to reduce memory overhead, but they still compute attention weights independently in every layer. Our proposed Shared Attention (SA) method removes this redundancy by directly sharing the computed attention weights across multiple layers. This approach not only significantly reduces the size of the KV cache but also decreases the FLOPs required during model inference.
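The idea of reusing attention weights across layers can be sketched in a few lines. The code below is a simplified illustration under strong assumptions, not the released implementation: it uses a single attention head, omits normalization and feed-forward blocks, and introduces a hypothetical `share_from` parameter marking the first layer of the shared span. Layers after `share_from` skip the query and key projections and the softmax, and only project values.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def causal_attention_weights(q, k):
    """Row-wise softmax(q k^T / sqrt(d)) under a causal mask (single head)."""
    n, d = len(q), len(q[0])
    weights = []
    for i in range(n):
        logits = [sum(x * y for x, y in zip(q[i], k[j])) / math.sqrt(d) for j in range(i + 1)]
        peak = max(logits)
        exps = [math.exp(l - peak) for l in logits]
        total = sum(exps)
        weights.append([e / total for e in exps] + [0.0] * (n - i - 1))
    return weights


def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) / math.sqrt(rows) for _ in range(cols)] for _ in range(rows)]


def make_layer(d_model, rng):
    return {name: random_matrix(d_model, d_model, rng) for name in ("wq", "wk", "wv")}


def forward(x, layers, share_from=None):
    """Run attention-only layers; layers after `share_from` reuse its weights.

    Returns the output and the number of times attention weights were computed.
    """
    shared, n_computed = None, 0
    for idx, layer in enumerate(layers):
        if shared is not None and idx > share_from:
            weights = shared  # no Q/K projection, no QK^T, no softmax
        else:
            weights = causal_attention_weights(matmul(x, layer["wq"]), matmul(x, layer["wk"]))
            n_computed += 1
            if idx == share_from:
                shared = weights
        attn_out = matmul(weights, matmul(x, layer["wv"]))
        x = [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(x, attn_out)]  # residual
    return x, n_computed


rng = random.Random(0)
layers = [make_layer(4, rng) for _ in range(4)]
tokens = random_matrix(3, 4, rng)
_, baseline = forward(tokens, layers)
_, with_sharing = forward(tokens, layers, share_from=1)
print(baseline, with_sharing)
```

In this toy setting, sharing from the second layer onward halves the number of attention-weight computations, and the sharing layers would need to cache only values rather than keys and values.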

The introduction of Shared Attention represents a paradigm shift in the design of attention mechanisms in neural networks, emphasizing efficiency without compromising the model’s performance. By reducing both the computational burden and memory requirements, SA enables more scalable and efficient deployment of LLMs, particularly in environments where resources are constrained.

This research paves the way for further explorations into efficient model architectures and opens up new possibilities for the application of LLMs across a broader spectrum of tasks and datasets. Future work will focus on expanding the applicability of Shared Attention, exploring its integration during the initial phases of model training, and combining it with other optimization techniques to maximize the operational efficiency of LLMs.

References

  • Ainslie et al. (2023) Ainslie, J.; Lee-Thorp, J.; de Jong, M.; Zemlyanskiy, Y.; Lebrón, F.; and Sanghai, S. 2023. Gqa: Training generalized multi-query transformer models from multi-head checkpoints. arXiv preprint arXiv:2305.13245.
  • Bai et al. (2023) Bai, J.; Bai, S.; Chu, Y.; Cui, Z.; Dang, K.; Deng, X.; Fan, Y.; Ge, W.; Han, Y.; Huang, F.; et al. 2023. Qwen technical report. arXiv preprint arXiv:2309.16609.
  • Brandon et al. (2024) Brandon, W.; Mishra, M.; Nrusimha, A.; Panda, R.; and Kelly, J. R. 2024. Reducing Transformer Key-Value Cache Size with Cross-Layer Attention. arXiv preprint arXiv:2405.12981.
  • Child et al. (2019) Child, R.; Gray, S.; Radford, A.; and Sutskever, I. 2019. Generating long sequences with sparse transformers. arXiv preprint arXiv:1904.10509.
  • Cobbe et al. (2021) Cobbe, K.; Kosaraju, V.; Bavarian, M.; Chen, M.; Jun, H.; Kaiser, L.; Plappert, M.; Tworek, J.; Hilton, J.; Nakano, R.; et al. 2021. Training verifiers to solve math word problems. arXiv preprint arXiv:2110.14168.
  • Hendrycks et al. (2020) Hendrycks, D.; Burns, C.; Basart, S.; Zou, A.; Mazeika, M.; Song, D.; and Steinhardt, J. 2020. Measuring massive multitask language understanding. arXiv preprint arXiv:2009.03300.
  • Hooper et al. (2024) Hooper, C.; Kim, S.; Mohammadzadeh, H.; Mahoney, M. W.; Shao, Y. S.; Keutzer, K.; and Gholami, A. 2024. Kvquant: Towards 10 million context length llm inference with kv cache quantization. arXiv preprint arXiv:2401.18079.
  • Huang et al. (2024) Huang, Y.; Bai, Y.; Zhu, Z.; Zhang, J.; Zhang, J.; Su, T.; Liu, J.; Lv, C.; Zhang, Y.; Fu, Y.; et al. 2024. C-eval: A multi-level multi-discipline chinese evaluation suite for foundation models. Advances in Neural Information Processing Systems, 36.
  • Katharopoulos et al. (2020) Katharopoulos, A.; Vyas, A.; Pappas, N.; and Fleuret, F. 2020. Transformers are rnns: Fast autoregressive transformers with linear attention. In International conference on machine learning, 5156–5165. PMLR.
  • Li et al. (2023) Li, H.; Zhang, Y.; Koto, F.; Yang, Y.; Zhao, H.; Gong, Y.; Duan, N.; and Baldwin, T. 2023. Cmmlu: Measuring massive multitask language understanding in chinese. arXiv preprint arXiv:2306.09212.
  • Liu et al. (2024a) Liu, Z.; Desai, A.; Liao, F.; Wang, W.; Xie, V.; Xu, Z.; Kyrillidis, A.; and Shrivastava, A. 2024a. Scissorhands: Exploiting the persistence of importance hypothesis for llm kv cache compression at test time. Advances in Neural Information Processing Systems, 36.
  • Liu et al. (2024b) Liu, Z.; Yuan, J.; Jin, H.; Zhong, S.; Xu, Z.; Braverman, V.; Chen, B.; and Hu, X. 2024b. KIVI: A tuning-free asymmetric 2bit quantization for kv cache. arXiv preprint arXiv:2402.02750.
  • Shazeer (2019) Shazeer, N. 2019. Fast transformer decoding: One write-head is all you need. arXiv preprint arXiv:1911.02150.
  • Taori et al. (2023) Taori, R.; Gulrajani, I.; Zhang, T.; Dubois, Y.; Li, X.; Guestrin, C.; Liang, P.; and Hashimoto, T. B. 2023. Alpaca: A strong, replicable instruction-following model. Stanford Center for Research on Foundation Models. https://crfm.stanford.edu/2023/03/13/alpaca.html, 3(6): 7.
  • Touvron et al. (2023) Touvron, H.; Martin, L.; Stone, K.; Albert, P.; Almahairi, A.; Babaei, Y.; Bashlykov, N.; Batra, S.; Bhargava, P.; Bhosale, S.; et al. 2023. Llama 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288.
  • Vaswani et al. (2017) Vaswani, A.; Shazeer, N.; Parmar, N.; Uszkoreit, J.; Jones, L.; Gomez, A. N.; Kaiser, Ł.; and Polosukhin, I. 2017. Attention is all you need. Advances in neural information processing systems, 30.
  • Wang et al. (2018) Wang, A.; Singh, A.; Michael, J.; Hill, F.; Levy, O.; and Bowman, S. R. 2018. GLUE: A multi-task benchmark and analysis platform for natural language understanding. arXiv preprint arXiv:1804.07461.
  • Xiao et al. (2023) Xiao, G.; Tian, Y.; Chen, B.; Han, S.; and Lewis, M. 2023. Efficient streaming language models with attention sinks. arXiv preprint arXiv:2309.17453.
  • Yang et al. (2023) Yang, A.; Xiao, B.; Wang, B.; Zhang, B.; Bian, C.; Yin, C.; Lv, C.; Pan, D.; Wang, D.; Yan, D.; et al. 2023. Baichuan 2: Open large-scale language models. arXiv preprint arXiv:2309.10305.
  • Zellers et al. (2019) Zellers, R.; Holtzman, A.; Bisk, Y.; Farhadi, A.; and Choi, Y. 2019. Hellaswag: Can a machine really finish your sentence? arXiv preprint arXiv:1905.07830.