Xiuying Wei (CLAIRE, EPFL), Skander Moalla (CLAIRE, EPFL), Razvan Pascanu (Google DeepMind), Caglar Gulcehre (CLAIRE, EPFL)

Investigating Low-Rank Training in Transformer Language Models: Efficiency and Scaling Analysis

Abstract

State-of-the-art LLMs often rely on scale, which comes with high computational costs and has sparked a research agenda to reduce parameter counts and costs without significantly impacting performance. Our study focuses on Transformer-based LLMs, specifically applying low-rank parametrization to the computationally intensive feedforward networks (FFNs), which are less studied than attention blocks. In contrast to previous works, (i) we explore low-rank parametrization at scale, up to 1.3B parameters; (ii) within Transformer language models rather than convolutional architectures; and (iii) training from scratch. Experiments on the large RefinedWeb dataset show that low-rank parametrization is both efficient (e.g., a 2.6$\times$ FFN speed-up with 32% of the parameters) and effective during training. Interestingly, these structured FFNs exhibit steeper scaling curves than the original models. Motivated by this finding, we develop wide and structured networks that surpass the current medium-sized and large-sized Transformers in perplexity and throughput.

1 Introduction

Transformer language models (Vaswani et al., 2017b) have gained significant attention for their performance and scalability. These models have grown from hundreds of millions of parameters (Radford et al., 2019) to hundreds of billions (Brown et al., 2020; Touvron et al., 2023; Smith et al., 2022), increasing the need for efficient training and inference. While much research focuses on attention, feedforward networks (FFNs) account for over 60% of the model’s parameters and FLOPs, significantly impacting latency. Low-rank parametrization, one of the most popular families of structured matrices, is an important technique for making linear layers efficient. However, it has not yet been thoroughly explored at sufficient scale as a modification of modern LLM architectures.

In this work, we investigate low-rank matrices for FFN blocks, trained from initialization, on recent Transformer language models ranging from 110M to 1.3B parameters. Specifically, by using low-rank parametrization with 32% of the FFN parameters, the training speed of the 1.3B model can be boosted by 1.35$\times$ with only a 1 PPL increase. Interestingly, the low-rank parametrization has steeper loss scaling curves than the traditional Transformer at its optimal trade-off (Figure 1(a)), suggesting a high potential for even better performance at larger scales. Finally, combined with GQA (Ainslie et al., 2023) for attention, we design wide and structured networks with slightly better PPL and higher maximum throughput under the same training FLOPs (e.g., 8% and 17% throughput boosts on medium- and large-sized models). We hope our findings and results shed new light on the study of efficient NLP architectures.

2 Related work

Low-rank matrices have been widely used to decompose pre-trained weights for downstream compression (Sharma et al., 2023) and to construct adapters for efficient fine-tuning, as in LoRA (Hu et al., 2021). LoRA uses a low-rank approximation to reduce trainable parameters, while Sharma et al. (2023) selectively apply low-rank decomposition to well-trained weights.

Several works investigate low-rank training. Arora et al. (2019) argue that dense layers naturally converge to low-rank solutions during training, making this parametrization ideal. Early works such as Denil et al. (2013) and Tai et al. (2015) showed the high efficiency of low-rank training. Some studies (Yang et al., 2020; Xu et al., 2020; Vodrahalli et al., 2022) adapt the rank during training and suggest regularizers for better accuracy. Khodak et al. (2021) propose spectral initialization and aligned weight decay for matrix products. Vodrahalli et al. (2022) suggest learning the initialization of low-rank matrices from data. However, these studies mainly focus on ResNets (He et al., 2016) rather than recent LLMs.

In this paper, we train low-rank matrices with a fixed rank as a replacement for the FFN linear layers of recent Transformers from scratch and investigate the performance of the new architecture. Formally, the low-rank parametrization of a linear layer can be given as ${\bm{W}}{\bm{x}} \approx {\bm{U}}({\bm{V}}{\bm{x}})$, where ${\bm{W}}$ is the original weight, ${\bm{x}}$ is the input, ${\bm{U}} \in \mathbb{R}^{M \times R}$, ${\bm{V}} \in \mathbb{R}^{R \times N}$, and $R < \min(M, N)$. This reduces parameter count and FLOPs from $M \cdot N$ to $(M + N) \cdot R$.
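To make the parametrization concrete, the following minimal sketch applies a factorized layer as two successive matrix-vector products and compares its parameter count with that of the dense layer. It is written in pure Python for illustration; the helper names `matvec`, `low_rank_forward`, and `num_params`, the toy sizes, and the omission of bias terms are assumptions of this sketch, not the implementation used in the experiments.

```python
import random


def matvec(A, x):
    """Multiply a matrix (given as a list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, x)) for row in A]


def low_rank_forward(U, V, x):
    """Compute U (V x): project x to R dimensions, then expand to M dimensions."""
    return matvec(U, matvec(V, x))


def num_params(M, N, R=None):
    """Parameter count (also multiply-accumulates per input vector), biases ignored."""
    return M * N if R is None else (M + N) * R


# Toy sizes chosen for illustration only; they satisfy R < min(M, N).
M, N, R = 8, 6, 2
rng = random.Random(0)
U = [[rng.gauss(0, 1) for _ in range(R)] for _ in range(M)]  # M x R
V = [[rng.gauss(0, 1) for _ in range(N)] for _ in range(R)]  # R x N
x = [rng.gauss(0, 1) for _ in range(N)]

# Dense matrix implicitly represented by the factors: W = U V (rank at most R).
W = [[sum(U[i][k] * V[k][j] for k in range(R)) for j in range(N)] for i in range(M)]

y_low_rank = low_rank_forward(U, V, x)
y_dense = matvec(W, x)
max_diff = max(abs(a - b) for a, b in zip(y_low_rank, y_dense))

print(num_params(M, N), num_params(M, N, R), max_diff < 1e-9)
```

The factorized path never materializes the $M \times N$ matrix, which is where the savings in both parameters and FLOPs come from.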

3 Experiments

3.1 Settings

Implementation

We replace only the FFN modules with low-rank parametrization, as the attention module is already well studied (Ainslie et al., 2023; Shazeer, 2019). We use ranks that are half or a quarter of the original hidden state dimension, reducing FFN parameters to 63% or 32% of the original size. The first FFN module remains unchanged to avoid significant performance degradation. For initialization, we follow the spectral initialization suggested by prior work (Khodak et al., 2021).
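The 63% and 32% figures follow from the parameter formula $(M + N) \cdot R$ applied to both FFN matrices. The sketch below recomputes the FFN sizes under two assumptions: an intermediate dimension of four times the width (consistent with the Hidden and Intermediate columns of Table 3) and no bias terms. The function name `ffn_params` is hypothetical. Under these assumptions the totals agree with the FFN sizes reported in Table 2 after rounding to millions.

```python
def ffn_params(width, num_layers, rank=None, expansion=4):
    """Total FFN parameters (no biases) for a stack of two-layer FFNs.

    With rank=None every FFN is dense; otherwise every FFN except the first
    uses rank-`rank` factors for both of its linear layers.
    """
    hidden = expansion * width
    dense = 2 * width * hidden                  # two dense matrices per FFN
    if rank is None:
        return num_layers * dense
    low_rank = 2 * (width + hidden) * rank      # two factorized matrices per FFN
    return dense + (num_layers - 1) * low_rank  # the first FFN stays dense


# (width, number of layers) taken from Table 1.
configs = {"s": (768, 12), "m": (1024, 24), "l": (1536, 24), "xl": (2048, 24)}

for name, (width, layers) in configs.items():
    full = ffn_params(width, layers)
    half = ffn_params(width, layers, rank=width // 2)
    quarter = ffn_params(width, layers, rank=width // 4)
    print(f"Transformer-{name}: dense {full / 1e6:.0f}M, "
          f"R={width // 2}: {half / 1e6:.0f}M, R={width // 4}: {quarter / 1e6:.0f}M")
```

Per replaced FFN, the ratio is $(d + 4d) \cdot (d/2) / (4d^2) = 62.5\%$ for half rank and $31.25\%$ for quarter rank, where $d$ is the width; the totals are slightly higher because the first FFN is kept dense.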

Training

We use a basic Transformer architecture (Vaswani et al., 2017a; Radford et al., 2019) with Rotary Embedding (Su et al., 2024) and a basic FFN module composed of two linear layers and a GELU activation function. Our models range from 110M to 1.3B parameters and are trained on the RefinedWeb dataset (Penedo et al., 2023). We randomly select 0.5B tokens as the validation set, while the number of training tokens is allocated according to the scaling law of Hoffmann et al. (2022). We measure training FLOPs as in Megatron (Narayanan et al., 2021), including all matrix multiplications. Hyperparameters, such as learning rates and global batch size, are set according to recent studies (Gu and Dao, 2023; Zhang et al., 2022). Details are summarized in Table 1.

Table 1: Model and training configuration. We report the number of layers (Layers), hidden state dimension (Width), training tokens (Tokens), global batch size in number of tokens (Batch), peak learning rate (LR), and total training steps (Steps).

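\begin{tabular}{lrrrrrrrr}
\hline
Name & Size & Width & Layers & Tokens & Batch & LR & Steps & Training FLOPs \\
\hline
Transformer-s & 110M & 768 & 12 & 2.2B & 0.5M & 6.0e-4 & 4.2K & 1.69e+18 \\
Transformer-m & 335M & 1024 & 24 & 6.7B & 0.5M & 3.0e-4 & 13K & 1.55e+19 \\
Transformer-l & 729M & 1536 & 24 & 14.6B & 0.5M & 2.5e-4 & 28K & 7.03e+19 \\
Transformer-xl & 1274M & 2048 & 24 & 25.5B & 0.5M & 2.0e-4 & 49K & 2.10e+20 \\
\hline
\end{tabular}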
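The token and step counts in Table 1 can be approximately reproduced with a short calculation. The sketch below assumes a budget of 20 training tokens per parameter and a global batch of $2^{19}$ tokens; both values are inferred from the numbers in the table and are not stated explicitly in the text, and the names `TOKENS_PER_PARAM`, `BATCH_TOKENS`, and `training_budget` are hypothetical.

```python
TOKENS_PER_PARAM = 20   # assumption inferred from Table 1, not stated in the text
BATCH_TOKENS = 2 ** 19  # assumption: "0.5M" read as 524,288 tokens per step

# Model sizes (number of parameters) from Table 1.
model_sizes = {
    "Transformer-s": 110_000_000,
    "Transformer-m": 335_000_000,
    "Transformer-l": 729_000_000,
    "Transformer-xl": 1_274_000_000,
}


def training_budget(num_params):
    """Return (training tokens, optimizer steps) under the assumptions above."""
    tokens = TOKENS_PER_PARAM * num_params
    steps = tokens / BATCH_TOKENS
    return tokens, steps


for name, size in model_sizes.items():
    tokens, steps = training_budget(size)
    print(f"{name}: {tokens / 1e9:.1f}B tokens, {steps / 1e3:.1f}K steps")
```

Small differences from the Steps column are expected, since the table rounds to the nearest thousand for the larger models.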

3.2 Efficiency and accuracy performance

We evaluate both the efficiency and the accuracy of low-rank parametrization in the FFN. First, as shown in Figure 1(b), GPU resources are utilized more thoroughly as the FFN width increases, and this parametrization brings 1.4$\times$ and 2.6$\times$ speed-ups with 63% and 32% of the parameters, respectively, compared to the width of 1536.

Second, in Table 2, we observe that on Transformer-xl the parametrization with 63% of the FFN parameters results in about a 0.4 PPL increase with a 15% reduction in training time, while the one with 32% of the FFN parameters yields about a 1.0 higher PPL with a 1.35$\times$ speed-up for the whole model.

Table 2: Performance of low-rank parametrization with 63% and 32% of the original FFN module’s parameters, where $R$ indicates the rank. Note that the total structured FFN is not exactly 63% of the original because we don’t replace the first FFN module.

\begin{tabular}{lrrrrrr}
\hline
Architecture & Model Size (M) & FFN Size (M) & Training Tokens (B) & Training FLOPs & Training Time (h) & PPL \\
\hline
Transformer-s & 110 & 57 & 2.2 & 1.69e+18 & 4.0 & 25.97 \\
Low-Rank (R=384) & 90 & 37 & 2.2 & 1.44e+18 & 3.8 & 27.16 \\
Low-Rank (R=192) & 74 & 21 & 2.2 & 1.22e+18 & 3.6 & 29.22 \\
\hline
Transformer-m & 335 & 201 & 6.7 & 1.55e+19 & 32.5 & 18.29 \\
Low-Rank (R=512) & 263 & 129 & 6.7 & 1.26e+19 & 29.6 & 19.12 \\
Low-Rank (R=256) & 202 & 69 & 6.7 & 1.01e+19 & 26.9 & 20.60 \\
\hline
Transformer-l & 729 & 453 & 14.6 & 7.03e+19 & 130.5 & 14.29 \\
Low-Rank (R=768) & 566 & 290 & 14.6 & 5.61e+19 & 113.6 & 14.82 \\
Low-Rank (R=384) & 431 & 155 & 14.6 & 4.42e+19 & 100.0 & 15.69 \\
\hline
Transformer-xl & 1274 & 805 & 25.5 & 2.10e+20 & 352.2 & 12.46 \\
Low-Rank (R=1024) & 985 & 516 & 25.5 & 1.66e+20 & 302.2 & 12.86 \\
Low-Rank (R=512) & 744 & 275 & 25.5 & 1.29e+20 & 260.2 & 13.55 \\
\hline
\end{tabular}
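The speed-up and perplexity gaps quoted for Transformer-xl can be read off Table 2 directly. The short sketch below does the arithmetic on the training times and PPL values copied from the table; the function name `compare` is a hypothetical helper introduced for illustration.

```python
# (training time in hours, perplexity) for Transformer-xl, copied from Table 2.
baseline = (352.2, 12.46)
low_rank = {"R=1024": (302.2, 12.86), "R=512": (260.2, 13.55)}


def compare(base, variant):
    """Return (whole-model training speed-up, PPL increase) relative to the baseline."""
    base_time, base_ppl = base
    time, ppl = variant
    return base_time / time, ppl - base_ppl


for name, variant in low_rank.items():
    speedup, ppl_gap = compare(baseline, variant)
    print(f"{name}: {speedup:.2f}x speed-up, +{ppl_gap:.2f} PPL")
```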

3.3 Scaling analysis

Figure 1: (a): The training scaling curves between the standard Transformer and the modified version with low-rank parametrization, which retains 63% and 32% of the original parameters, respectively. (b): FFN latency performance across different widths, measured on 30,000 tokens.

From Figure 1(a), it can be seen that the low-rank parametrization gets closer to the baseline as the model size increases. Specifically, we observe that: (i) The low-rank parametrization exhibits steeper scaling curves than the dense networks, indicating significant potential for these efficient designs in LLMs. (ii) The scaling curve with 32% of the FFN parameters is steeper than the one with 63% of the FFN parameters, which highlights the scaling potential of highly structured large models. (iii) Given a fixed training FLOPs budget, a wider and structured network trained on more tokens may achieve comparable or superior performance to dense networks at the optimal trade-off.
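As a rough numerical check of observations (i) and (ii), one can fit a straight line in log-log space to the (training FLOPs, PPL) pairs from Table 2. The sketch below assumes that the loss is the natural logarithm of the perplexity and that the scaling curve is approximately linear in log(loss) versus log(FLOPs); these choices, and the helper name `loglog_slope`, are assumptions of this sketch and not necessarily the exact procedure behind Figure 1(a).

```python
import math

# (training FLOPs, PPL) for the s, m, l, and xl models, copied from Table 2.
curves = {
    "dense": [(1.69e18, 25.97), (1.55e19, 18.29), (7.03e19, 14.29), (2.10e20, 12.46)],
    "63% FFN": [(1.44e18, 27.16), (1.26e19, 19.12), (5.61e19, 14.82), (1.66e20, 12.86)],
    "32% FFN": [(1.22e18, 29.22), (1.01e19, 20.60), (4.42e19, 15.69), (1.29e20, 13.55)],
}


def loglog_slope(points):
    """Least-squares slope of log(loss) against log(FLOPs), with loss = log(PPL)."""
    xs = [math.log(flops) for flops, _ in points]
    ys = [math.log(math.log(ppl)) for _, ppl in points]
    x_mean = sum(xs) / len(xs)
    y_mean = sum(ys) / len(ys)
    sxy = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    sxx = sum((x - x_mean) ** 2 for x in xs)
    return sxy / sxx


slopes = {name: loglog_slope(points) for name, points in curves.items()}
for name, slope in slopes.items():
    print(f"{name}: slope = {slope:.4f}")
```

A more negative slope means the loss falls faster as compute grows. With the Table 2 numbers, the fitted slope is most negative for the 32% setting, followed by the 63% setting and then the dense baseline.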

The scaling curves can be further optimized: (1) unlike the baseline, they are not drawn at their optimal training-compute trade-off; (2) only the FFN is made structured, while attention remains dense, contributing more to the model’s performance. The second point also explains why the current 32% parameter curve shows a larger validation loss than the 63% parameter curve under the same training FLOPs. This motivates us to further reduce attention using existing techniques in Section 3.4.

3.4 Wide and Structured network

Motivated by the scaling curves, we reduce both the attention and the FFN to create a wide and structured network, as shown in Table 3. This approach aims to enhance efficiency with a much smaller network, achieving 8% and 17% maximum throughput boosts compared to medium- and large-sized GQA (Ainslie et al., 2023) models while maintaining or slightly improving perplexity.

Table 3: We compare the performance of GQA and our wide, structured networks. Left: TP indicates the maximum throughput measured for a generation length of 256. Right: Dimensions of various components, including hidden states, FFN intermediate states, attention, and KVCache. GQA’s intermediate size is increased to match parameters, as in Meta (2024).

\begin{tabular}{lrrrrrrrrr}
\hline
Method & \#Param & Training FLOPs & Tokens & PPL & TP (256) & Hidden & Intermediate & Attention & KV \\
\hline
Transformer-m & 335M & 1.55e+19 & 6.7B & 18.29 & 30229 & 1024 & 4096 & 1024 & 1024 \\
Transformer-m (GQA) & 335M & 1.55e+19 & 6.7B & 18.23 & 84202 & 1024 & 4864 & 1024 & 256 \\
Low-Rank (R=512) & 219M & 1.55e+19 & 10.6B & 17.89 & 91147 & 1024 & 4864 & 512 & 256 \\
\hline
Transformer-l & 729M & 7.03e+19 & 14.6B & 14.29 & 23351 & 1536 & 6144 & 1536 & 1536 \\
Transformer-l (GQA) & 729M & 7.03e+19 & 14.6B & 14.40 & 64737 & 1536 & 7424 & 1536 & 256 \\
Low-Rank (R=768) & 464M & 7.03e+19 & 22.3B & 14.27 & 75930 & 1536 & 7424 & 768 & 256 \\
\hline
\end{tabular}
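The 8% and 17% throughput boosts follow from the TP (256) column of Table 3. The sketch below recomputes them, together with the corresponding reduction in parameter count, from the values in the table; the helper name `relative_change` and the dictionary layout are hypothetical and introduced only for illustration.

```python
# (parameters in millions, maximum throughput at generation length 256), from Table 3.
pairs = {
    "medium": {"gqa": (335, 84202), "wide_low_rank": (219, 91147)},
    "large": {"gqa": (729, 64737), "wide_low_rank": (464, 75930)},
}


def relative_change(new, old):
    """Fractional change of `new` relative to `old`."""
    return new / old - 1


for size, rows in pairs.items():
    (p_gqa, tp_gqa), (p_lr, tp_lr) = rows["gqa"], rows["wide_low_rank"]
    print(f"{size}: throughput {relative_change(tp_lr, tp_gqa):+.0%}, "
          f"parameters {relative_change(p_lr, p_gqa):+.0%}")
```

Both wide and structured models are trained with the same FLOPs as their GQA counterparts, so the smaller parameter count is spent on more training tokens (10.6B versus 6.7B, and 22.3B versus 14.6B).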

4 Conclusion and Limitation

In this paper, we investigate low-rank parametrization in the FFN of Transformer language models. Training such structured models from scratch shows promising scaling curves and efficiency. However, we have not explored their optimal scaling laws, and our exploration is limited to language modeling. Studying the upper limits and other applications of low-rank training would also be very valuable.

References

  • Ainslie et al. (2023) Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, and Sumit Sanghai. GQA: Training generalized multi-query transformer models from multi-head checkpoints. arXiv preprint arXiv:2305.13245, 2023.
  • Arora et al. (2019) Sanjeev Arora, Nadav Cohen, Wei Hu, and Yuping Luo. Implicit regularization in deep matrix factorization. Advances in Neural Information Processing Systems, 32, 2019.
  • Brown et al. (2020) Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, et al. Language models are few-shot learners. Advances in neural information processing systems, 33:1877–1901, 2020.
  • Denil et al. (2013) Misha Denil, Babak Shakibi, Laurent Dinh, Marc’Aurelio Ranzato, and Nando De Freitas. Predicting parameters in deep learning. Advances in neural information processing systems, 26, 2013.
  • Gu and Dao (2023) Albert Gu and Tri Dao. Mamba: Linear-time sequence modeling with selective state spaces. arXiv preprint arXiv:2312.00752, 2023.
  • He et al. (2016) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), June 2016.
  • Hoffmann et al. (2022) Jordan Hoffmann, Sebastian Borgeaud, Arthur Mensch, Elena Buchatskaya, Trevor Cai, Eliza Rutherford, Diego de Las Casas, Lisa Anne Hendricks, Johannes Welbl, Aidan Clark, et al. Training compute-optimal large language models. arXiv preprint arXiv:2203.15556, 2022.
  • Hu et al. (2021) Edward J Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, and Weizhu Chen. LoRA: Low-rank adaptation of large language models. arXiv preprint arXiv:2106.09685, 2021.
  • Khodak et al. (2021) Mikhail Khodak, Neil Tenenholtz, Lester Mackey, and Nicolo Fusi. Initialization and regularization of factorized neural layers. arXiv preprint arXiv:2105.01029, 2021.
  • Meta (2024) Meta. Llama 3. https://llama.meta.com/llama3/, 2024.
  • Narayanan et al. (2021) Deepak Narayanan, Mohammad Shoeybi, Jared Casper, Patrick LeGresley, Mostofa Patwary, Vijay Korthikanti, Dmitri Vainbrand, Prethvi Kashinkunti, Julie Bernauer, Bryan Catanzaro, et al. Efficient large-scale language model training on GPU clusters using Megatron-LM. In Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis, pages 1–15, 2021.
  • Penedo et al. (2023) Guilherme Penedo, Quentin Malartic, Daniel Hesslow, Ruxandra Cojocaru, Alessandro Cappelli, Hamza Alobeidli, Baptiste Pannier, Ebtesam Almazrouei, and Julien Launay. The RefinedWeb dataset for Falcon LLM: outperforming curated corpora with web data, and web data only. arXiv preprint arXiv:2306.01116, 2023.
  • Radford et al. (2019) Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, Ilya Sutskever, et al. Language models are unsupervised multitask learners. OpenAI blog, 1(8):9, 2019.
  • Sharma et al. (2023) Pratyusha Sharma, Jordan T Ash, and Dipendra Misra. The truth is in there: Improving reasoning in language models with layer-selective rank reduction. arXiv preprint arXiv:2312.13558, 2023.
  • Shazeer (2019) Noam Shazeer. Fast transformer decoding: One write-head is all you need. arXiv preprint arXiv:1911.02150, 2019.
  • Smith et al. (2022) Shaden Smith, Mostofa Patwary, Brandon Norick, Patrick LeGresley, Samyam Rajbhandari, Jared Casper, Zhun Liu, Shrimai Prabhumoye, George Zerveas, Vijay Korthikanti, et al. Using DeepSpeed and Megatron to train Megatron-Turing NLG 530B, a large-scale generative language model. arXiv preprint arXiv:2201.11990, 2022.
  • Su et al. (2024) Jianlin Su, Murtadha Ahmed, Yu Lu, Shengfeng Pan, Wen Bo, and Yunfeng Liu. RoFormer: Enhanced transformer with rotary position embedding. Neurocomputing, 568:127063, 2024.
  • Tai et al. (2015) Cheng Tai, Tong Xiao, Yi Zhang, Xiaogang Wang, et al. Convolutional neural networks with low-rank regularization. arXiv preprint arXiv:1511.06067, 2015.
  • Touvron et al. (2023) Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, et al. Llama 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288, 2023.
  • Vaswani et al. (2017a) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017a.
  • Vaswani et al. (2017b) Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017b.
  • Vodrahalli et al. (2022) Kiran Vodrahalli, Rakesh Shivanna, Maheswaran Sathiamoorthy, Sagar Jain, and Ed H Chi. Nonlinear initialization methods for low-rank neural networks. arXiv preprint arXiv:2202.00834, 2022.
  • Xu et al. (2020) Yuhui Xu, Yuxi Li, Shuai Zhang, Wei Wen, Botao Wang, Yingyong Qi, Yiran Chen, Weiyao Lin, and Hongkai Xiong. TRP: Trained rank pruning for efficient deep neural networks. arXiv preprint arXiv:2004.14566, 2020.
  • Yang et al. (2020) Huanrui Yang, Minxue Tang, Wei Wen, Feng Yan, Daniel Hu, Ang Li, Hai Li, and Yiran Chen. Learning low-rank deep neural networks via singular vector orthogonality regularization and singular value sparsification. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition workshops, pages 678–679, 2020.
  • Zhang et al. (2022) Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen, Christopher Dewan, Mona Diab, Xian Li, Xi Victoria Lin, et al. OPT: Open pre-trained transformer language models. arXiv preprint arXiv:2205.01068, 2022.