NVIDIA NeMo Accelerates LLM Innovation with Hybrid State Space Model Support

Today’s large language models (LLMs) are based on the transformer model architecture introduced in 2017. Since then, rapid advances in AI compute performance have enabled the creation of ever-larger transformer-based LLMs, dramatically improving their capabilities. Advanced transformer-based LLMs now power many exciting applications, such as intelligent chatbots, computer code generation, and even chip design.

Training cutting-edge LLMs requires an efficient and versatile software stack. NVIDIA NeMo provides an end-to-end platform to build, customize, and deploy LLMs. Deeply integrated into the NeMo framework is Megatron-Core, a PyTorch-based library that provides the essential components and optimizations needed to train LLMs at scale. As model developers explore new architectures, the NVIDIA platform continues to expand to support their innovations.

Today, NVIDIA is announcing that NeMo and Megatron-Core now support pre-training and fine-tuning of state space models (SSMs). Additionally, NeMo now supports training models based on the Griffin architecture, described by Google DeepMind.

Why explore alternative model architectures?

Transformer models excel at capturing long-range dependencies through the now-famous attention mechanism, which makes them well suited to tasks that require understanding of global context.

However, the computational complexity of attention scales quadratically with sequence length, so training time and cost rise sharply as sequences get longer. Additionally, during inference, attention requires storing a cache of key-value pairs (known as a KV cache) that grows linearly with sequence length, steadily increasing the memory footprint.
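To make these two scaling behaviors concrete, here is a back-of-envelope cost model. It is a sketch under stated assumptions, not a measurement: the layer count (32), the 16-bit K/V storage, and batch size 1 are hypothetical choices, while the model dimension (4,096) and head count (32) reuse the transformer configuration listed under Figure 1. The function names are illustrative only.

```python
# Back-of-envelope cost model for standard multi-head self-attention.
# Hypothetical config: 32 layers, 16-bit (2-byte) K/V storage, batch size 1.
# Model dimension 4,096 and 32 heads match the transformer config in this post.
NUM_LAYERS = 32
MODEL_DIM = 4096
NUM_HEADS = 32
HEAD_DIM = MODEL_DIM // NUM_HEADS  # 128
BYTES_PER_ELEMENT = 2


def kv_cache_bytes(seq_len: int, batch: int = 1) -> int:
    """Memory for cached keys and values: grows linearly with seq_len."""
    # Factor 2: one tensor for keys and one for values, in every layer.
    return 2 * NUM_LAYERS * NUM_HEADS * HEAD_DIM * seq_len * batch * BYTES_PER_ELEMENT


def attention_score_flops(seq_len: int) -> int:
    """Forward FLOPs for QK^T and (scores @ V) in one layer: grows quadratically."""
    # Each of the two matmuls costs about 2 * seq_len * seq_len * MODEL_DIM FLOPs.
    return 2 * (2 * seq_len * seq_len * MODEL_DIM)


for n in (2_048, 32_768, 262_144):
    print(f"seq_len={n:>7}: KV cache {kv_cache_bytes(n) / 2**30:6.1f} GiB, "
          f"attention-score FLOPs/layer {attention_score_flops(n):.2e}")
```

Under these assumptions the cache costs 0.5 MiB per token (1 GiB at 2,048 tokens, 16 GiB at 32,768, and 128 GiB at 262,144), while a 16x longer sequence multiplies the attention-score FLOPs by 256.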

Recently, SSMs have emerged as a compelling model architecture for sequence modeling tasks as they overcome several of the limitations of attention. 

SSMs enable more efficient long-sequence length training

SSMs are a class of models that have gained popularity in the deep learning community as efficient alternatives to attention-based transformer models for sequence modeling tasks. 

SSMs feature the following compelling properties: 

  • Linear complexity: SSMs are linear in both computational and memory complexity, while attention is quadratic in both. This means they can model long-range dependencies in sequences much more efficiently than attention. 
  • High quality and accuracy: Like attention, SSMs look across the tokens of the input sequence, enabling models to focus on the most relevant parts. This results in comparable quality and accuracy to transformer-based models. 
  • Efficient inference: SSMs need only store a constant-size state rather than a KV cache, making inference more memory efficient, particularly at longer sequence lengths.
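The linear-time and constant-memory properties above can be seen in a minimal recurrence. The sketch below is a simplified, non-selective diagonal SSM over a scalar input: the parameters `a`, `b`, and `c` are fixed, whereas Mamba-style selective SSMs make them input-dependent. The function name `ssm_scan` and the toy values are hypothetical; the point is only the cost structure, one fixed-size state update per token.

```python
from typing import List, Sequence


def ssm_scan(x: Sequence[float], a: Sequence[float], b: Sequence[float],
             c: Sequence[float]) -> List[float]:
    """Simplified diagonal SSM (illustrative, not Mamba's actual layer).

    h_t = a * h_{t-1} + b * x_t   (elementwise; state size N = len(a))
    y_t = sum(c * h_t)
    """
    state = [0.0] * len(a)          # fixed-size state, independent of len(x)
    outputs = []
    for x_t in x:                   # one O(N) update per token -> O(L * N) total
        state = [a_i * h_i + b_i * x_t for a_i, h_i, b_i in zip(a, state, b)]
        outputs.append(sum(c_i * h_i for c_i, h_i in zip(c, state)))
    return outputs


# Example: 2-dimensional state driven by an impulse input.
y = ssm_scan([1.0, 0.0, 0.0, 0.0], a=[0.5, 0.9], b=[1.0, 1.0], c=[1.0, 1.0])
print(y)
```

However long the input is, inference only ever holds `state` (here two numbers), which is the contrast with a KV cache that grows with every generated token.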

To illustrate the benefits SSMs provide for longer sequence lengths, the following chart shows the relative speedup of training a layer of Mamba-2 (a state space model variant described later in this post) compared to training a transformer layer as the sequence length increases. As the sequence length increases to 256K, the Mamba-2 layer is 18x faster than the transformer layer. 

Figure 1. Speedup of a Mamba-2 layer relative to a transformer layer; the Mamba-2 advantage grows rapidly as sequence length increases.

Transformer: model dimension 4,096, 32 heads. Mamba-2: model dimension 4,096, state dimension 128, 8 groups.

Several SSM variants have become popular in the AI community, including Hyena, Mamba-1, and, more recently, Mamba-2.

Structured state space duality and Mamba-2

Mamba-2 stands out as a recent variant that achieves strong accuracy across multiple benchmarks. At the core of Mamba-2 is a new structured state space duality (SSD) layer, which is, in practice, a reformulation of the SSM math used in the Mamba-1 model. This reformulation recasts SSM computations as matrix multiplications, allowing them to take advantage of the high matrix-multiplication throughput of NVIDIA Tensor Cores.

As a result, Mamba-2 can be trained far more quickly than Mamba-1. Mamba-2 also offers quality and accuracy competitive with transformers on language modeling tasks, and it can yield even better results when a few attention layers are combined with SSD layers in a hybrid model.
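The "duality" in SSD can be checked numerically. When the state transition is a single scalar per time step (the restriction Mamba-2 makes), the step-by-step recurrence produces exactly the same outputs as one lower-triangular matrix-vector product, which is the form that maps onto matrix-multiply hardware. The sketch below assumes a single head and a scalar input channel, and builds the matrix naively in pure Python; the function names are hypothetical, and Mamba-2's real kernels instead split the sequence into chunks and combine both forms.

```python
import random
from typing import List

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    return sum(ui * vi for ui, vi in zip(u, v))


def ssd_recurrent(x: Vector, a: Vector, B: List[Vector], C: List[Vector]) -> Vector:
    """Linear-time scan: h_t = a_t * h_{t-1} + B_t * x_t,  y_t = C_t . h_t."""
    h = [0.0] * len(B[0])
    y = []
    for x_t, a_t, B_t, C_t in zip(x, a, B, C):
        h = [a_t * h_i + b_i * x_t for h_i, b_i in zip(h, B_t)]
        y.append(dot(C_t, h))
    return y


def ssd_matrix(x: Vector, a: Vector, B: List[Vector], C: List[Vector]) -> Vector:
    """Same outputs as y = M x, with the lower-triangular matrix
    M[t][s] = (a_{s+1} * ... * a_t) * (C_t . B_s) for s <= t, else 0."""
    L = len(x)
    M = [[0.0] * L for _ in range(L)]
    for t in range(L):
        decay = 1.0                      # product of a_k for k in (s, t]
        for s in range(t, -1, -1):
            M[t][s] = decay * dot(C[t], B[s])
            decay *= a[s]
    return [dot(row, x) for row in M]


random.seed(0)
L, N = 6, 4                              # toy sequence length and state size
x = [random.uniform(-1, 1) for _ in range(L)]
a = [random.uniform(0.5, 1.0) for _ in range(L)]
B = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(L)]
C = [[random.uniform(-1, 1) for _ in range(N)] for _ in range(L)]
y_scan = ssd_recurrent(x, a, B, C)
y_mat = ssd_matrix(x, a, B, C)
print(max(abs(p - q) for p, q in zip(y_scan, y_mat)))  # floating-point round-off only
```

The recurrent form is cheap per token; the matrix form does more arithmetic but in large, regular matrix multiplications, which is why the reformulation pays off on Tensor Cores.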

However, pure SSMs are not without limitations. For example, they have been shown to struggle in “needle-in-a-haystack” tasks that require precise recall of information from very long sequences.

Hybrid models can improve results and increase performance

Hybrid models that combine SSMs, SSDs, RNNs, and transformers can leverage the strengths of each model architecture while mitigating their individual weaknesses. 

In a recent paper, researchers, including members of the NVIDIA Applied Deep Learning Research (ADLR) team, described hybrid Mamba-Transformer models. In these hybrid models, standard transformer layers and novel SSM layers can be interleaved in arbitrary configurations. For example, the 8B hybrid model described in the paper has 56 layers: 4 self-attention layers, 24 Mamba-2 layers, and 28 multilayer perceptron (MLP) layers. The layers are allocated such that a Mamba-2 layer comes first, followed by the attention layers, with MLP layers distributed evenly throughout the model.
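To illustrate how such a layer budget might be turned into a concrete stack, here is a hypothetical allocator. It is not the paper's algorithm: it encodes the constraints that are easy to state (a Mamba-2 layer first, MLPs spread evenly) and, as an additional assumption, spaces the attention layers evenly among the sequence-mixing layers. The symbols `M`, `A`, and `F` and the function name are made up for this sketch.

```python
from collections import Counter
from typing import List

# Hypothetical symbols for this sketch: M = Mamba-2, A = self-attention, F = MLP.
MAMBA, ATTENTION, MLP = "M", "A", "F"


def allocate_layers(n_mamba: int, n_attention: int, n_mlp: int) -> List[str]:
    """Illustrative hybrid layout (NOT the paper's exact placement).

    1. Order the sequence-mixing layers (Mamba-2 + attention), spacing the
       attention layers evenly and never putting one first.
    2. Spread the MLP layers evenly by inserting them after mixing layers.
    """
    if n_mamba < 1:
        raise ValueError("need at least one Mamba-2 layer to place first")
    n_mix = n_mamba + n_attention
    # Integer spacing: slots are distinct and never 0 while n_mamba >= 1.
    attn_slots = {(k + 1) * n_mix // (n_attention + 1) for k in range(n_attention)}
    mixers = [ATTENTION if i in attn_slots else MAMBA for i in range(n_mix)]

    layers: List[str] = []
    for i, mixer in enumerate(mixers):
        layers.append(mixer)
        # Add an MLP each time the running quota floor((i+1) * n_mlp / n_mix) increases.
        extra = (i + 1) * n_mlp // n_mix - i * n_mlp // n_mix
        layers.extend([MLP] * extra)
    return layers


layout = allocate_layers(n_mamba=24, n_attention=4, n_mlp=28)
print("".join(layout))
print(len(layout), Counter(layout))
```

With the 8B model's budget (24 Mamba-2, 4 attention, 28 MLP), the mixing layers and MLPs happen to be equal in number, so this sketch pairs every mixing layer with one MLP and yields 56 layers in total.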

According to the paper, the 8B Mamba-2-Hybrid model “exceeds the 8B Transformer on all 12 standard tasks” evaluated by the team. The 8B Mamba-2-Hybrid is also “predicted to be up to 8x faster when generating tokens at inference time.”

Figure 2. The Mamba-2-Hybrid architecture, as described in “An Empirical Study of Mamba-based Language Models.” Credit: arXiv:2406.07887 (Table 6)

Beyond its stronger task accuracy and projected inference speedups, the Mamba-2-Hybrid model also shows greater compute efficiency. The chart below compares the compute needed to train the 8B Mamba-2-Hybrid model with the compute needed to train an 8B Transformer model as sequence length increases.

Figure 3. Compute (in TFLOPs) required for one training iteration of an 8B Transformer model and an 8B Mamba-2-Hybrid model at sequence lengths from 2,048 to 32,768 tokens. The hybrid model's compute grows much more slowly with sequence length than that of the pure transformer model.

At a sequence length of 2,048 tokens, both models require roughly the same compute, with the hybrid model holding a slight advantage. As sequence length scales to 32,768 tokens, however, the compute required by the 8B Transformer model doubles, while that of the hybrid model grows by only 13%. Because some modern language models support sequence lengths of 1M tokens and beyond, this advantage for SSM-Transformer hybrid models will only grow.
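For intuition about why the transformer's compute roughly doubles over this range, the sketch below uses the standard matmul-FLOP count for a GPT-style layer, assuming a fixed number of tokens per iteration and a hypothetical 32-layer, hidden-size-4,096 shape. This is not how Figure 3 was produced; it only shows that the quadratic attention term by itself accounts for roughly a 2x increase between 2,048 and 32,768 tokens.

```python
def train_flops_per_token(seq_len: int, num_layers: int, hidden: int) -> float:
    """Approximate matmul FLOPs per token for one training step of a GPT-style model.

    Per layer, a forward pass costs about 24*h^2 FLOPs per token for the QKV,
    output, and 4h-wide MLP projections, plus 4*s*h for the attention scores
    (QK^T) and the weighted sum over values. Backward is roughly 2x forward,
    so training is ~3x forward: 72*h^2*(1 + s/(6h)) per layer.
    Ignores the vocabulary/logit layer, causal-mask savings, and recomputation.
    """
    return 72 * num_layers * hidden**2 * (1 + seq_len / (6 * hidden))


# Hypothetical transformer shape (assumed, not taken from the paper).
LAYERS, HIDDEN = 32, 4096

short_seq = train_flops_per_token(2_048, LAYERS, HIDDEN)
long_seq = train_flops_per_token(32_768, LAYERS, HIDDEN)
ratio = long_seq / short_seq
print(f"per-token compute grows {ratio:.2f}x from 2,048 to 32,768 tokens")  # 2.15x
```

In this model the ratio is exactly 28/13, and attention's share of per-token compute rises from about 8% (1/13) at 2,048 tokens to about 57% (4/7) at 32,768 tokens. A hybrid model with only a few attention layers shrinks that growing term accordingly.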

A first step in supporting a new class of model architectures

Model architecture innovation is critical to delivering new levels of intelligence. In addition to world-class support for building transformer-based models, NeMo and Megatron-Core now provide the community with the ability to train SSMs and SSDs, as well as hybrid models that combine their benefits with the strengths of transformer models. 

This release of NeMo provides the following initial features so that the community can quickly begin experimenting:

  • Support for SSD models, including Mamba-2.
  • Support for the Real-Gated Linear Recurrent Unit (RG-LRU) used in the Griffin architecture.
  • Support for Transformer/SSM hybrid model combinations.
  • Fine-tuning support for RecurrentGemma (Griffin), pure Mamba-2 models, and the 8B Mamba-2-Hybrid model.
  • Sharding and model parallelism support.

In upcoming releases, NVIDIA plans to add support for more sub-quadratic model architectures, further performance optimizations, and FP8 training.
