Department of Mathematics and Computer Science, University of Southern Denmark, Odense, Denmark
Email: {jacn,petersk}@imada.sdu.dk

BitNet b1.58 Reloaded: State-of-the-art Performance Also on Smaller Networks

Jacob Nielsen (ORCID 0009-0009-8141-630X)
Peter Schneider-Kamp (ORCID 0000-0003-4000-5570)
Abstract

Recently proposed methods for 1-bit and 1.58-bit quantization-aware training have been studied in the context of large language models, where they reach state-of-the-art performance for models with more than 3B parameters. In this work, we investigate 1.58-bit quantization for small language and vision models ranging from 100K to 48M parameters. We introduce a variant of BitNet b1.58 that can rely on the median rather than the mean in the quantization process. Through extensive experiments, we investigate the performance of 1.58-bit models obtained through quantization-aware training. We further investigate the robustness of 1.58-bit quantization-aware training to changes in the learning rate and to regularization through weight decay, finding different patterns for small language and vision models than previously reported for large language models. Our results show that 1.58-bit quantization-aware training provides state-of-the-art performance for small language models when doubling hidden layer sizes, and reaches or even surpasses state-of-the-art performance for small vision models of identical size. Ultimately, we demonstrate that 1.58-bit quantization-aware training is a viable and promising approach also for training smaller deep learning networks, facilitating deployment of such models in low-resource use cases and encouraging future research.

Keywords:
deep learning, quantization-aware training, green machine learning, small language models, image classification.

1 Introduction

The recent years of development of natural language processing (NLP) have been dominated by the capabilities offered by Large Language Models (LLMs). However, due to the size of these models, they pose a challenge in deployment and raise concerns regarding the environmental impact. Post-training quantisation methods transform the 16-bit weights to a lower bit-representation, which both reduces the memory and computational needs. The idea is to take the trained weights and find a good way of mapping them to fewer bits, enabling more efficient inference.

Several post-training quantization methods have been proposed, including but not limited to Generative Pre-trained Transformer Quantization [5] and Activation-aware Weight Quantization [9]. However, post-training quantization inherently comes at the cost of precision. Post-training quantization has also been employed in other domains such as vision models [8].

An alternative to post-training quantization is quantization-aware training such as LLM-QAT [10] and QA-LoRA. Here, as the training optimizes the quantized weights, there is no loss of precision when using the quantized model for inference. Recent works on 1-bit [13] and 1.58-bit [11] quantization-aware training architectures have demonstrated the potential of training in very low-bit representation while still maintaining most or all of the performance for LLMs.

The 1.58-bit quantization-aware training architecture BitNet b1.58 [11] proposes a solution based on replacing linear 16-bit layers with layers where the weights only assume the values $-1$, $0$, and $1$. Notably, for large-enough LLMs, BitNet b1.58 can match the 16-bit precision baselines both in capacity and performance. From above 3B parameters, the 1.58-bit models trained from scratch perform just as well as 16-bit models.
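The "1.58" in the name corresponds to the information content of a single ternary weight. This is a standard identity rather than something derived in this paper, stated here only as a reading aid:

```latex
\log_2 3 \approx 1.585 \ \text{bits per weight}, \qquad w \in \{-1, 0, 1\}
```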

In this work, we investigate 1.58-bit quantization-aware training for small language models (SLMs) and vision models ranging from 100K to 48M parameters. We introduce a variant of BitNet b1.58 that relies on the median rather than the mean of the absolute values of the weights. Through extensive experiments, we investigate and compare the scaling, the learning-rate robustness, and the regularization properties of both 1.58-bit variants. Our work demonstrates that 1.58-bit quantization-aware training can get close to state-of-the-art performance on SLMs and even exceed the state-of-the-art performance on vision models, opening a new avenue for research in this direction. This facilitates the deployment of SLMs and small vision models in low-resource settings. Our implementation is available from GitHub (https://github.com/schneiderkamplab/bitlinear) and the Python Package Index (https://pypi.org/project/bitlinear/).

Figure 1: The BitLinear layer is the backbone of the BitNet 1.58 Bits Reloaded architecture. It provides a drop-in replacement for linear layers (often referred to as feed-forward networks or multi-layer perceptrons) in any architecture. AbsMeasure denotes the mean or median of the absolute values of the weights. The two factors $x_{scale}$ and $w_{scale}$ are the scaling factors for the input and the 16-bit weights, respectively, and are used in the dequantization. We employ a straight-through estimator for the backward computation of the gradients.

2 Method

In this section we present our quantization aware training architecture as a generalization of the BitNet b1.58 architecture [11]. First, we present our quantization method. Then, we document our experimental setup.

2.1 b1.58 Quantization

Our BitLinear layer functions as a drop-in replacement for PyTorch’s torch.nn.Linear layer. Figure 1 illustrates BitLinear’s 5-step computation flow:

  1. The activations are normalized.

  2. The normalized activations are quantized to k-bit precision.

  3. The 16-bit shadow weights are quantized to 1.58-bit weights.

  4. The quantized activations are multiplied with the 1.58-bit weights.

  5. The result of the multiplication is dequantized by rescaling.

In the following, we detail the mathematics behind this computation flow. We denote the layer normalization [4] of the input $I$ as $\hat{I}$. We then define the scaling factor $x_{scale}$ of the activations, based on their AbsMax:

$$x_{scale} = \frac{Q_b}{\max(|\hat{I}|) + \epsilon} \tag{1}$$

where $Q_b = 2^{k-1}$ is the range of the $k$ bits used for the quantized activations and $\epsilon$ is a small value preventing division by zero. This means all activations can be scaled to integer values in $\{-Q_b, \ldots, Q_b - 1\}$, matching the clamping bounds of Equation (2). We define the AbsMax Quantization for the activations as follows:

$$x_{quant} = \max(-Q_b, \min(Q_b - 1, \mathrm{round}(\hat{I} \cdot x_{scale}))) \tag{2}$$
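Equations (1) and (2) can be sketched in a few lines of plain Python. This is a minimal illustration, not the released implementation: the function name `absmax_quantize`, the default `eps=1e-5`, and the example activations are assumptions, and the input is taken to be already layer-normalized.

```python
def absmax_quantize(values, k=8, eps=1e-5):
    """Quantize already-normalized activations to k-bit integers (Eqs. 1 and 2).

    The name, the eps default and the list-based interface are illustrative assumptions.
    """
    q_b = 2 ** (k - 1)                                   # Q_b = 2^(k-1), e.g. 128 for k = 8
    x_scale = q_b / (max(abs(v) for v in values) + eps)  # Eq. (1)
    # Eq. (2): scale, round, then clamp into [-Q_b, Q_b - 1]
    x_quant = [max(-q_b, min(q_b - 1, round(v * x_scale))) for v in values]
    return x_quant, x_scale


activations = [0.5, -1.0, 0.25, 0.0]
x_quant, x_scale = absmax_quantize(activations)
print(x_quant)  # [64, -128, 32, 0]
```

Note the asymmetry of the clamp: a largest-magnitude activation that is positive rounds to 128 and is then clamped to 127, whereas a negative one maps to -128 exactly.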

Furthermore, we quantize the 16-bit weights $W \in \mathcal{R}^{n \times m}$ to a ternary system of integer values $\{-1, 0, 1\}$ as follows. We define the scaling of $W$ as:

$$w_{scale} = \frac{1}{Measure(|W|) + \epsilon} \tag{3}$$

where $Measure$ denotes either the mean or the median function, constituting the AbsMeasure Quantization.

We define the quantized weights $W_{quant}$ (denoted as 1.58-Bit Weights in Figure 1) as:

$$W_{quant} = \max(-1, \min(1, \mathrm{round}(W \cdot w_{scale}))) \tag{4}$$
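The effect of choosing the mean or the median in Equation (3) can be seen on a toy weight vector. The sketch below is an illustration under assumptions: the helper name `absmeasure_quantize`, the `eps` default, and the example weights are not taken from the paper or its code.

```python
import statistics


def absmeasure_quantize(weights, measure="mean", eps=1e-5):
    """Ternarize weights following Eqs. (3) and (4); names and eps are assumptions."""
    abs_w = [abs(w) for w in weights]
    m = statistics.mean(abs_w) if measure == "mean" else statistics.median(abs_w)
    w_scale = 1.0 / (m + eps)                                          # Eq. (3)
    w_quant = [max(-1, min(1, round(w * w_scale))) for w in weights]   # Eq. (4)
    return w_quant, w_scale


weights = [0.02, -0.4, 0.1, 0.9, -0.05]
q_mean, s_mean = absmeasure_quantize(weights, "mean")
q_median, s_median = absmeasure_quantize(weights, "median")
print(q_mean)    # [0, -1, 0, 1, 0]
print(q_median)  # [0, -1, 1, 1, 0]
```

In this example the median (0.1) is smaller than the mean (0.294), so the median variant uses a larger scaling factor and maps one more weight away from zero.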

Having quantized both the activations and the weights, we can apply a kernel with $x_{quant}$ and $W_{quant}$ as inputs:

$$y_{quant} = x_{quant} \cdot W_{quant} + b \tag{5}$$

where $b$ is an optional bias. We detach both $x_{quant}$ and $W_{quant}$ from the computation graph to achieve a straight-through estimation of the gradients. The gradients update the "shadow weights", i.e., the 16-bit weights that are quantized by AbsMeasure Quantization.
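The straight-through idea can be illustrated without an autograd framework. The following is a deliberately simplified sketch, not the authors' training loop: it uses a single output, a squared-error loss, unquantized inputs, and treats the scaling factor as a constant; `ternarize` and `ste_step` are hypothetical helper names. The forward pass uses the rescaled ternary weights, while the gradient is applied to the shadow weights as if the quantization were the identity.

```python
def ternarize(w, eps=1e-5):
    """AbsMean ternarization of a weight vector (illustrative helper)."""
    scale = 1.0 / (sum(abs(v) for v in w) / len(w) + eps)
    return [max(-1, min(1, round(v * scale))) for v in w], scale


def ste_step(w, x, target, lr=0.1):
    """One SGD step on the shadow weights w using a straight-through gradient."""
    q, scale = ternarize(w)
    w_eff = [qi / scale for qi in q]               # rescaled ternary weights used in the forward pass
    y = sum(xi * wi for xi, wi in zip(x, w_eff))
    err = y - target                               # dL/dy for L = 0.5 * (y - target)^2
    # Straight-through: pretend d(w_eff)/d(w) = 1, so dL/dw_i = err * x_i
    grad = [err * xi for xi in x]
    new_w = [wi - lr * gi for wi, gi in zip(w, grad)]
    return new_w, 0.5 * err ** 2


w0 = [0.3, -0.1, 0.2]
x = [1.0, 2.0, -1.0]
w1, loss0 = ste_step(w0, x, target=1.0)
w2, loss1 = ste_step(w1, x, target=1.0)
print(loss0, loss1)  # the loss drops from 0.5 to about 0.32 in this toy case
```

Note that the second shadow weight is updated even though its ternary value is 0 in the forward pass; this is what allows a weight to cross a rounding threshold in later steps.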

Finally, we rescale the output $y$ during the Dequantization process:

$$y = \frac{y_{quant}}{w_{scale} \cdot x_{scale}} \tag{6}$$
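Putting Equations (1) to (6) together, the five-step flow can be sketched for a single input vector in plain Python. This is a hedged illustration rather than the released BitLinear layer: the layer normalization here has no learnable gain or bias, the optional bias $b$ is omitted, and the function names and `eps` value are assumptions.

```python
import statistics


def layer_norm(x, eps=1e-5):
    """Plain layer normalization without learnable gain or bias (a simplification)."""
    mu = statistics.fmean(x)
    var = statistics.fmean([(v - mu) ** 2 for v in x])
    return [(v - mu) / (var + eps) ** 0.5 for v in x]


def bitlinear_forward(x, W, k=8, measure=statistics.fmean, eps=1e-5):
    """Forward pass for one input vector x (length n) and weights W (n rows, m columns)."""
    # Steps 1 and 2: normalize, then AbsMax-quantize the activations (Eqs. 1 and 2).
    x_hat = layer_norm(x)
    q_b = 2 ** (k - 1)
    x_scale = q_b / (max(abs(v) for v in x_hat) + eps)
    x_quant = [max(-q_b, min(q_b - 1, round(v * x_scale))) for v in x_hat]

    # Step 3: AbsMeasure-quantize the shadow weights to {-1, 0, 1} (Eqs. 3 and 4).
    w_scale = 1.0 / (measure([abs(w) for row in W for w in row]) + eps)
    W_quant = [[max(-1, min(1, round(w * w_scale))) for w in row] for row in W]

    # Step 4: integer matrix product (Eq. 5, optional bias omitted here).
    n, m = len(W), len(W[0])
    y_quant = [sum(x_quant[i] * W_quant[i][j] for i in range(n)) for j in range(m)]

    # Step 5: dequantize by undoing both scalings (Eq. 6).
    return [y / (w_scale * x_scale) for y in y_quant]


x = [1.0, -1.0, 0.0, 2.0]
W = [[0.5, -0.5], [0.5, 0.5], [-0.5, 0.5], [0.5, 0.5]]
y = bitlinear_forward(x, W)

# Full-precision reference on the same normalized input, for comparison.
x_hat = layer_norm(x)
y_ref = [sum(x_hat[i] * W[i][j] for i in range(4)) for j in range(2)]
print(y, y_ref)
```

Passing `measure=statistics.median` selects the median variant. In this example all weights have the same magnitude, so the ternarization is exact up to $\epsilon$ and the remaining deviation from `y_ref` comes from the 8-bit activation quantization.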

Compared to the original BitNet b1.58, there are a number of differences:

  • We chose to use a standard layer normalization (LayerNorm) rather than RMS normalization, as the computational overhead is minimal and we observed slightly better performance with the standard layer norm in preliminary experiments.

  • We allow the use of both the median and the mean for quantizing weights. Prior works [13, 11] solely employ the mean. We investigate the impact of this choice in Section 3.

  • We actually quantize weights and activations to integer values. This means the matrix multiplications are performed between the 1.58-bit weights with integer values $\{-1, 0, 1\}$ and the 8-bit quantized activations with integer values $\{-128, \ldots, 127\}$. This allows the development of multiplication-free kernels, as multiplication with $-1$ corresponds to the subtraction of an 8-bit integer value, multiplication with $0$ to disregarding a value, and multiplication with $1$ to the addition of an 8-bit integer value.

    This is in contrast to previous work [11], where the quantized weights have floating point values $\{\frac{-1}{w_{scale}}, 0, \frac{1}{w_{scale}}\}$ while the quantized activations have floating point values $\{\frac{-128}{x_{scale}}, \ldots, \frac{127}{x_{scale}}\}$, according to the published information about the implementation [1]. Consequently, our BitNet 1.58 Bits Reloaded architecture is more directly amenable to custom software kernels and hardware implementations.
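The multiplication-free property mentioned in the last item can be made concrete with a small sketch. The function `ternary_dot` below is a hypothetical illustration in plain Python, not a kernel from the paper; a real kernel would operate on packed weights and vectorized integer arithmetic.

```python
def ternary_dot(x_quant, w_column):
    """Dot product of integer activations with ternary weights using only add/subtract."""
    acc = 0
    for x, w in zip(x_quant, w_column):
        if w == 1:
            acc += x      # multiplication by 1 is an addition
        elif w == -1:
            acc -= x      # multiplication by -1 is a subtraction
        # w == 0: the activation is disregarded
    return acc


x_quant = [64, -128, 32, 0]
w_column = [1, -1, 0, 1]
result = ternary_dot(x_quant, w_column)
print(result)  # 192
```

The result agrees with the ordinary dot product `sum(x * w for x, w in zip(x_quant, w_column))`.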

2.2 Experimental setup

We conduct all experiments with standard networks in small configurations, with the torch.nn.Linear layers replaced by our BitLinear layers. We employ the Adam [6] optimizer and a batch size of 128. The number of model parameters is slightly higher in the BitLinear setting, as we have both the 1.58-bit weights and the 16-bit shadow weights. However, this does not change the number of trainable (optimized) parameters in practice.

For SLMs, we train small Mistral-like models with 4 layers and hidden sizes of 32, 64, 128, and 256. The number of attention heads and key-value cache heads is set to the ceiling of the hidden size divided by 64, i.e., 1 head for hidden sizes 32 and 64, and 2 and 4 heads for hidden sizes 128 and 256, respectively. The resulting model sizes are 6M, 12M, 24M, and 48M parameters. We use a text corpus of 135M tokens and train from scratch for 10 epochs unless otherwise noted, corresponding to a total of 1.35B tokens for each training. We trained a Byte Pair Encoding tokenizer with a vocabulary size of 8,000. The experiments are conducted with the standard trainer from the Hugging Face transformers library (https://github.com/huggingface/transformers).
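The head counts and the token budget stated above follow from simple arithmetic, reproduced here as a quick check. The variable names are illustrative only.

```python
import math

DIVISOR = 64  # from the text: heads = ceil(hidden_size / 64)


def num_heads(hidden_size):
    """Number of attention (and key-value) heads for a given hidden size."""
    return math.ceil(hidden_size / DIVISOR)


heads = {h: num_heads(h) for h in (32, 64, 128, 256)}
print(heads)  # {32: 1, 64: 1, 128: 2, 256: 4}

tokens_per_epoch = 135_000_000
total_tokens = tokens_per_epoch * 10
print(total_tokens)  # 1350000000
```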

For vision models, we consider a standard sequential implementation of a classifier for MNIST and standard CNN-based implementations for CIFAR-10 and CIFAR-100. The model for MNIST is the smallest in this paper with only 100K parameters. The CIFAR-10 and CIFAR-100 models are the next smallest, with 2.1M and 2.2M parameters, respectively. The difference in model size is explained by CIFAR-10 having 10 classes and CIFAR-100 having 100 classes. The experiments are based on PyTorch Lightning (https://github.com/Lightning-AI/pytorch-lightning) and use the torchvision (https://pytorch.org/vision/stable/index.html) versions of the datasets.

The MNIST [2] dataset consists of 60,000 training and 10,000 test samples. The CIFAR10 [7] and CIFAR100 [7] datasets both contain 50,000 training and 10,000 test samples. All models are trained from scratch. We calculate the accuracy as the mean of the percentage of correct batches across the test set.
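One possible reading of this accuracy definition is that the percentage of correct predictions is computed per test batch and then averaged over batches. The sketch below illustrates that reading only; it is an assumption, as is the use of the training batch size of 128 for evaluation. With 10,000 test samples this gives 78 full batches plus one batch of 16 samples, so the batch mean can differ slightly from the plain sample-level accuracy.

```python
def batch_mean_accuracy(correct_flags, batch_size=128):
    """Mean over batches of the per-batch percentage of correct predictions (assumed reading)."""
    batches = [correct_flags[i:i + batch_size]
               for i in range(0, len(correct_flags), batch_size)]
    per_batch = [100.0 * sum(b) / len(b) for b in batches]
    return sum(per_batch) / len(per_batch), len(batches)


# Contrived example: every sample is correct except the 16 in the final, smaller batch.
flags = [True] * 9984 + [False] * 16
batch_acc, n_batches = batch_mean_accuracy(flags)
sample_acc = 100.0 * sum(flags) / len(flags)
print(n_batches, round(batch_acc, 2), sample_acc)
```

In this contrived case the final short batch carries the same weight as a full one, so the batch mean falls below the sample-level accuracy.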

Figure 2: Scaling behaviour of 16-bit and 1.58-bit (mean and median) training for SLMs over 10 epochs (= 1,020 evaluations on the test set). Panels: (a) scaling for 16 bit; (b) scaling for 1.58 bit (mean); (c) scaling for 16 vs 1.58 bit (mean); (d) mean vs median for hidden size 64.

3 Results

In this section, we present a comparison of our BitLinear implementation with the 16-bit floating point torch.nn.Linear, showing close-to-state-of-the-art performance on SLMs and better-than-state-of-the-art performance for vision models. We also perform ablation studies on the learning rate and weight decay hyperparameters, as well as on the choice of mean vs median for the quantization of the weights.

3.1 Small Language Models

The first experiment for SLMs is a scaling experiment, where we perform 16-bit and 1.58-bit training on all four model sizes. The second experiment is a hyperparameter tuning for the learning rate and weight decay in a 12M SLM, with a fixed hidden size of 64. We show the results of the first and second experiment in Tables 1 and 2, respectively.

Figure 3: Hyperparameter tuning regarding weight decay and learning rate for SLMs over 10 epochs (= 1,020 evaluations on the test set). Panels: (a) weight decay for 1.58 bit (mean); (b) weight decay for 1.58 bit (median); (c) weight decay for 16 bit; (d) learning rate for 16 bit; (e) learning rate for 1.58 bit (mean); (f) learning rate for 1.58 bit (median).

Both tables show the different configurations and perplexities after 10 epochs. For most configurations, the training has converged or is close to convergence at the end of the experiment. It is important to keep in mind that the reported perplexity is the exponentiation of the entropy, i.e., here the exponentiation of the loss defined via cross-entropy. Thus, minor changes in the loss result in quite discernible changes to the perplexity.
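The sensitivity of perplexity to small loss changes follows directly from the exponential. A short numerical illustration, using two perplexity values that appear in Table 1 for hidden size 64 (the function name is an assumption):

```python
import math


def perplexity(cross_entropy_loss):
    """Perplexity as the exponential of the (natural-log) cross-entropy loss."""
    return math.exp(cross_entropy_loss)


# Losses corresponding to perplexities 36.7 (16 bit) and 60.0 (1.58-median) in Table 1.
loss_16 = math.log(36.7)
loss_158 = math.log(60.0)
print(round(loss_16, 3), round(loss_158, 3))  # 3.603 4.094

# A loss increase of 0.1 already raises perplexity by about 10.5%.
ratio = perplexity(loss_16 + 0.1) / perplexity(loss_16)
print(round(ratio, 3))  # 1.105
```

A loss gap of roughly 0.49 therefore corresponds to a perplexity that is about 63% higher.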

The first two columns give the hidden layer size and the number of parameters. The third column provides the bit-depth and implementation: “16” stands for 16-bit training, “1.58-mean” for our BitLinear implementation with 1.58 bits and AbsMean quantization of weights, and “1.58-median” for our BitLinear implementation with 1.58 bits and AbsMedian quantization of weights.

We show the results of this first experiment in Figure 2. Figure 2(a) shows that the 16-bit training scales as expected when the hidden layer size, and thus the model's capacity, increases. We see in Figure 2(b) that 1.58-bit training follows the same trend, albeit with slightly lower performance. In Figure 2(c), we can visually compare the scaling between 16-bit training for models with hidden sizes 32 and 64 and 1.58-bit training for models with hidden sizes 64 and 128. The observed perplexities suggest that the effective capacities of the models with 1.58-bit weights are around half that of the models with 16-bit weights, i.e., that hidden layers of approximately double size are needed for 1.58-bit models to reach performance comparable with their 16-bit counterparts. Figure 2(d) shows that the median generally converges slower than the mean over the employed weight decays. We discuss this fact in Section 4.
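The doubling rule of thumb can be checked against the final perplexities in Table 1. The sketch below is a reader's check, not part of the paper's analysis: it takes the best perplexity per hidden size from Table 1 (for 1.58 bits, the best of the mean and median variants over all listed settings) and compares a 1.58-bit model of hidden size 2h with a 16-bit model of hidden size h.

```python
# Best perplexity per hidden size, read off Table 1 (lower is better).
best_16bit = {32: 77.8, 64: 36.7, 128: 21.4, 256: 16.6}
best_158bit = {32: 116.6, 64: 60.0, 128: 36.3, 256: 26.8}  # best of mean and median

ratios = {}
for hidden in (32, 64, 128):
    # Ratio below 1 means the doubled 1.58-bit model matches or beats the 16-bit model.
    ratios[hidden] = best_158bit[2 * hidden] / best_16bit[hidden]
    print(f"16-bit h={hidden}: {best_16bit[hidden]:.1f}  "
          f"1.58-bit h={2 * hidden}: {best_158bit[2 * hidden]:.1f}  "
          f"ratio={ratios[hidden]:.2f}")
```

On these numbers the doubled 1.58-bit models match or beat the 16-bit models for hidden sizes 32 and 64, while the 256-wide 1.58-bit model remains behind the 128-wide 16-bit model, which suggests the rule is approximate.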

The fourth column shows the learning rate. For 16-bit training, we took a high but stable learning rate of 0.001 (1e-3). For 1.58-bit training, we used the same or a larger learning rate of 0.01 (1e-2), as 1.58-bit training has been found to be more robust to higher learning rates in the context of LLMs [11]. The fifth column shows the weight decay. We tried both a small but noticeable decay of 5%, which is prevalent in the pre-training and fine-tuning of LLMs, and no weight decay. Both Tables 1 and 2 hint that a weight decay of 5% yields the best or similar performance compared to other values of weight decay. This is also visualized in Figures 3(a), 3(b), and 3(c), where trainings with a weight decay of 5% are represented by a red line.

The sixth column provides the perplexity after 10 epochs. For nearly all configurations, after 10 epochs the training had converged. The best perplexities for 16-bit and 1.58-bit training are marked in bold, respectively. The seventh and last column shows the number of epochs, with total training length corresponding to 135M tokens per epoch, i.e., 1.35B tokens per 10-epoch experiment.

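Table 1: Language modelling benchmarks at different scales.
#Hidden  #Params  Bits         Learning Rate  Weight Decay  Perplexity  Epochs
32       6M       16           0.001          0.00          77.8        10
32       6M       16           0.001          0.05          81.0        10
32       6M       1.58-mean    0.001          0.00          166.2       10
32       6M       1.58-mean    0.001          0.05          164.9       10
32       6M       1.58-mean    0.01           0.00          130.1       10
32       6M       1.58-mean    0.01           0.05          134.4       10
32       6M       1.58-median  0.001          0.00          183.6       10
32       6M       1.58-median  0.001          0.05          183.8       10
32       6M       1.58-median  0.01           0.00          116.6       10
32       6M       1.58-median  0.01           0.05          118.0       10
64       12M      16           0.001          0.00          36.7        10
64       12M      16           0.001          0.05          37.5        10
64       12M      1.58-mean    0.001          0.00          67.4        10
64       12M      1.58-mean    0.001          0.05          68.2        10
64       12M      1.58-mean    0.01           0.00          76.3        10
64       12M      1.58-mean    0.01           0.05          68.2        10
64       12M      1.58-median  0.001          0.00          76.5        10
64       12M      1.58-median  0.001          0.05          68.1        10
64       12M      1.58-median  0.01           0.00          61.1        10
64       12M      1.58-median  0.01           0.05          60.0        10
128      24M      16           0.001          0.00          22.3        10
128      24M      16           0.001          0.05          21.4        10
128      24M      1.58-mean    0.001          0.00          36.8        10
128      24M      1.58-mean    0.001          0.05          36.3        10
128      24M      1.58-mean    0.01           0.00          61.6        10
128      24M      1.58-mean    0.01           0.05          71.0        10
128      24M      1.58-median  0.001          0.00          39.8        10
128      24M      1.58-median  0.001          0.05          37.5        10
128      24M      1.58-median  0.01           0.00          42.3        10
128      24M      1.58-median  0.01           0.05          38.4        10
256      48M      16           0.001          0.00          16.6        10
256      48M      16           0.001          0.05          16.7        10
256      48M      1.58-mean    0.001          0.00          28.7        10
256      48M      1.58-mean    0.001          0.05          27.1        10
256      48M      1.58-mean    0.01           0.00          77.7        10
256      48M      1.58-mean    0.01           0.05          65.6        10
256      48M      1.58-median  0.001          0.00          26.8        10
256      48M      1.58-median  0.001          0.05          27.5        10
256      48M      1.58-median  0.01           0.00          65.1        10
256      48M      1.58-median  0.01           0.05          63.8        10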
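
Table 2: Hyperparameter tuning for a 12M SLM.
#Hidden  #Params  Bits         Learning Rate  Weight Decay  Perplexity  Epochs
64       12M      16           0.001          0.00          36.7        10
64       12M      16           0.001          0.01          37.7        10
64       12M      16           0.001          0.05          37.5        10
64       12M      16           0.001          0.10          36.8        10
64       12M      16           0.01           0.00          44.8        10
64       12M      16           0.01           0.01          51.0        10
64       12M      16           0.01           0.05          45.9        10
64       12M      16           0.01           0.10          44.2        10
64       12M      16           0.1            0.00          871.5       10
64       12M      16           0.1            0.01          938.9       10
64       12M      16           0.1            0.05          191.3       10
64       12M      16           0.1            0.10          65.3        10
64       12M      1.58-mean    0.001          0.00          67.4        10
64       12M      1.58-mean    0.001          0.01          68.1        10
64       12M      1.58-mean    0.001          0.05          68.2        10
64       12M      1.58-mean    0.001          0.10          69.6        10
64       12M      1.58-mean    0.01           0.00          76.3        10
64       12M      1.58-mean    0.01           0.01          76.0        10
64       12M      1.58-mean    0.01           0.05          68.2        10
64       12M      1.58-mean    0.01           0.10          76.0        10
64       12M      1.58-mean    0.1            0.00          203.3       10
64       12M      1.58-mean    0.1            0.01          240.4       10
64       12M      1.58-mean    0.1            0.05          173.5       10
64       12M      1.58-mean    0.1            0.10          204.5       10
64       12M      1.58-median  0.001          0.00          76.5        10
64       12M      1.58-median  0.001          0.01          74.1        10
64       12M      1.58-median  0.001          0.05          68.1        10
64       12M      1.58-median  0.001          0.10          72.2        10
64       12M      1.58-median  0.01           0.00          61.1        10
64       12M      1.58-median  0.01           0.01          63.5        10
64       12M      1.58-median  0.01           0.05          60.0        10
64       12M      1.58-median  0.01           0.10          61.3        10
64       12M      1.58-median  0.1            0.00          197.6       10
64       12M      1.58-median  0.1            0.01          223.2       10
64       12M      1.58-median  0.1            0.05          154.0       10
64       12M      1.58-median  0.1            0.10          157.9       10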

3.2 Small Vision Models

The implementation of BitNet b1.58 [1] adopts from [11] the strategy of employing significantly higher learning rates, arguing that this is crucial for optimising the 1.58-bit weights. The authors also state that this does not carry over to full 16-bit precision, suggesting this might be because of prior fine-tuning. We show in Figure 5 that larger learning rates are sub-optimal for both 1.58-bit and 16-bit weights in small classification models, despite being trained from scratch. We use the mean-based benchmark for comparability, but observe similar results for the median-based counterpart. For 1.58 bits, we see in Figure 5(a) that performance gradually declines as the learning rate grows from 0.0001 to 0.1, with the smallest learning rate providing the best performance. We observe a similar trend for 16 bits in Figure 5(b), where learning with a rate of 0.05 or above distorts the training, preventing the network from learning at all, as evident in the evaluation.

In Figure 4 we document the impact of weight decays of 0%, 1%, 5%, and 10% for the two learning rates 0.001 and 0.0001. Figure 4(a) showcases the effect of weight decay when using the higher learning rate, whereas in Figure 4(b) we see more continuity over the first epochs, and training with a weight decay of 1% appears superior.

As described in Section 2, we conducted experiments with both AbsMean and AbsMedian quantization, all of which are shown in Table 3. The mean-based quantization is superior on MNIST and CIFAR10, with differences of 0.15 and 1.22 percentage points in test accuracy, respectively. On CIFAR100 the median-based quantization is superior with a difference of 0.7 percentage points. From our experiments in Table 3, no clear conclusion can be drawn as to which is preferable in general. Similarly, in Figure 4, we do not see a clear distinction between the two in either the evolving performance or the resulting one. Therefore, we propose the choice of AbsMean vs AbsMedian quantization for the weights as a hyperparameter for 1.58-bit training.

Figure 4: The effect of weight decay (WD) on the training robustness for CIFAR100 over 10 epochs. Panels: (a) weight decay for learning rate 0.001; (b) weight decay for learning rate 0.0001.

Figure 5: The effect of the learning rate (LR) on the training robustness for CIFAR100 over 10 epochs. Panels: (a) 1.58 bit (mean); (b) 16 bit.

We investigate the effect of weight decay on the 1.58-bit networks and compare to 16-bit benchmarks in Table 3. We employed weight decays of 0%, 1%, 5%, and 10%. For both CIFAR10 and CIFAR100, the 1.58-bit models are significantly more robust compared to their 16-bit counterparts, which become so unstable that the training is distorted and evaluates to a test accuracy of 1.00. This might be because of the coarse training scheme inherently associated with 1.58-bit quantization, increasing the robustness of the training against further regularization through, for example, weight decay. For MNIST we see that, while weight decay yields a decrease in accuracy, it does not prevent the models from learning. Similarly, we see that the 1.58-bit training presents more stability in performance when weight decay is employed.

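Table 3: Supervised benchmarks on vision classification datasets.
Dataset   #Params  Bits         Learning Rate  Weight Decay  Test Accuracy (%)  Epochs
MNIST     100K     16           0.0001         0.00          92.29              10
MNIST     100K     16           0.001          0.00          96.93              10
MNIST     100K     16           0.001          0.01          93.35              10
MNIST     100K     16           0.001          0.05          77.06              10
MNIST     100K     1.58-mean    0.0001         0.00          95.63              10
MNIST     100K     1.58-mean    0.0001         0.01          96.08              10
MNIST     100K     1.58-mean    0.001          0.00          96.01              10
MNIST     100K     1.58-mean    0.001          0.01          93.11              10
MNIST     100K     1.58-mean    0.001          0.05          86.57              10
MNIST     100K     1.58-mean    0.01           0.00          94.59              10
MNIST     100K     1.58-mean    0.05           0.00          93.80              10
MNIST     100K     1.58-mean    0.10           0.00          93.15              10
MNIST     100K     1.58-median  0.0001         0.00          94.14              10
MNIST     100K     1.58-median  0.0001         0.01          95.93              10
MNIST     100K     1.58-median  0.001          0.00          95.80              10
MNIST     100K     1.58-median  0.001          0.01          91.27              10
MNIST     100K     1.58-median  0.001          0.05          89.15              10
MNIST     100K     1.58-median  0.01           0.00          93.03              10
MNIST     100K     1.58-median  0.05           0.00          86.61              10
MNIST     100K     1.58-median  0.10           0.00          52.35              10
CIFAR10   2.1M     16           0.0001         0.00          60.86              10
CIFAR10   2.1M     16           0.001          0.00          70.06              10
CIFAR10   2.1M     16           0.001          0.01          58.32              10
CIFAR10   2.1M     16           0.001          0.05          10.0               10
CIFAR10   2.1M     16           0.001          0.10          10.0               10
CIFAR10   2.1M     1.58-mean    0.0001         0.00          68.94              10
CIFAR10   2.1M     1.58-mean    0.0001         0.01          69.1               10
CIFAR10   2.1M     1.58-mean    0.0001         0.05          71.47              10
CIFAR10   2.1M     1.58-mean    0.001          0.00          70.35              10
CIFAR10   2.1M     1.58-mean    0.001          0.01          69.08              10
CIFAR10   2.1M     1.58-mean    0.001          0.05          58.04              10
CIFAR10   2.1M     1.58-mean    0.01           0.00          63.92              10
CIFAR10   2.1M     1.58-mean    0.05           0.00          25.01              10
CIFAR10   2.1M     1.58-mean    0.10           0.00          23.05              10
CIFAR10   2.1M     1.58-median  0.0001         0.00          69.08              10
CIFAR10   2.1M     1.58-median  0.0001         0.01          69.55              10
CIFAR10   2.1M     1.58-median  0.0001         0.05          70.25              10
CIFAR10   2.1M     1.58-median  0.001          0.00          71.21              10
CIFAR10   2.1M     1.58-median  0.001          0.01          69.80              10
CIFAR10   2.1M     1.58-median  0.001          0.05          60.61              10
CIFAR10   2.1M     1.58-median  0.01           0.00          65.80              10
CIFAR10   2.1M     1.58-median  0.05           0.00          54.77              10
CIFAR10   2.1M     1.58-median  0.10           0.00          49.48              10
CIFAR100  2.2M     16           0.0001         0.00          28.30              10
CIFAR100  2.2M     16           0.001          0.00          36.62              10
CIFAR100  2.2M     16           0.001          0.01          17.97              10
CIFAR100  2.2M     16           0.001          0.05          1.0                10
CIFAR100  2.2M     1.58-mean    0.0001         0.00          39.52              10
CIFAR100  2.2M     1.58-mean    0.0001         0.01          41.57              10
CIFAR100  2.2M     1.58-mean    0.0001         0.05          39.60              10
CIFAR100  2.2M     1.58-mean    0.001          0.00          36.09              10
CIFAR100  2.2M     1.58-mean    0.001          0.01          40.05              10
CIFAR100  2.2M     1.58-mean    0.001          0.05          21.78              10
CIFAR100  2.2M     1.58-mean    0.01           0.00          26.48              10
CIFAR100  2.2M     1.58-mean    0.05           0.00          3.73               10
CIFAR100  2.2M     1.58-mean    0.10           0.00          1.03               10
CIFAR100  2.2M     1.58-median  0.0001         0.00          40.12              10
CIFAR100  2.2M     1.58-median  0.0001         0.01          42.27              10
CIFAR100  2.2M     1.58-median  0.0001         0.05          39.63              10
CIFAR100  2.2M     1.58-median  0.001          0.00          36.55              10
CIFAR100  2.2M     1.58-median  0.001          0.01          34.35              10
CIFAR100  2.2M     1.58-median  0.001          0.05          16.12              10
CIFAR100  2.2M     1.58-median  0.01           0.00          30.06              10
CIFAR100  2.2M     1.58-median  0.05           0.00          5.53               10
CIFAR100  2.2M     1.58-median  0.10           0.00          1.93               10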

4 Discussion

Our results demonstrate that 1.58-bit training provides competitive performance on both small language and vision models. We hope our work encourages the community to work on 1.58-bit architectures to facilitate efficient and fast inference independent of model scale. Overall, this enables both more environmentally friendly inference for many applications and the deployment of deep neural networks in various low-resource use cases, with the potential of increased energy efficiency through multiplication-free kernels and even specialised hardware.

As reported in Section 3, using our approach to 1.58-bit training generally yields a small performance penalty in SLMs. In the small vision models, we see 1.58-bit training outperforming the 16-bit CNN-based networks on the CIFAR10 and CIFAR100 datasets, while being within 0.85 percentage points for a sequential network trained on MNIST. This difference between training on text and image data is not entirely unexpected, given that the language modelling data is more complex than the simpler vision classification datasets. While 1.58-bit training still relies on the full-precision 16-bit weights as shadow weights from which the 1.58-bit weights are computed, we are reducing the capacity of each weight in the linear layers and, hence, of the overall network.

This implies that, in some cases, there might be a need for networks with an increased number of parameters to re-introduce some capacity. In our SLM results shown in Figure 2(c), we see the need to use hidden layers of size 64 in the 1.58-mean case to reach the same performance as the 16-bit model with hidden layers of size 32. This also holds for hidden layers of size 128 for 1.58 bits and the corresponding size 64 for 16 bits. Prior works have shown that LLMs do not utilize all parameters effectively [3] and even contain redundant layers [12]. Therefore, we would expect this need to decrease as model size grows, i.e., we would expect 1.58-bit training to work well in networks above a certain size without increasing the number of parameters. This is in line with prior work [11].

The need for increasing the number of parameters seems to depend on the complexity of the downstream task, as is evident from our results on SLMs in Section 3.1 compared to our results on small vision models in Section 3.2, where the same architecture (with an adjusted size of the prediction head) outperforms full-precision 16-bit models on both CIFAR10 and CIFAR100. For SLMs, we do not consider the increased size of hidden layers to be an obstacle to profiting from 1.58-bit architectures, as the models can still be expected to run significantly more efficiently in inference settings when implemented using custom kernels.

In Sections 3.1 and 3.2 we conducted extensive experiments employing either one of the two quantization schemes: "1.58-mean" or "1.58-median". Switching between the two changes the factor by which the 16-bit weights are scaled before rounding to integers in the quantization process. The median will, in some cases, be resilient to weight updates, allowing higher variance without notable effect on the scaling factor. The mean will be more directly affected, particularly by large changes to a few weights. Therefore, the median provides flexibility for weight updates, whereas the mean provides more constrained feedback affecting the gradient magnitude on the shadow weights.
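The different sensitivity of the two measures can be illustrated with a toy example in which a single weight receives a large update. The weight values below are made up for illustration and do not come from the experiments.

```python
import statistics


def abs_measures(weights):
    """Return (mean, median) of the absolute weight values."""
    abs_w = [abs(w) for w in weights]
    return statistics.mean(abs_w), statistics.median(abs_w)


before = [0.1, -0.2, 0.15, -0.05, 0.1]
after = [0.1, -2.0, 0.15, -0.05, 0.1]   # one weight received a large update

mean_before, median_before = abs_measures(before)
mean_after, median_after = abs_measures(after)
print(mean_before, mean_after)      # the mean grows from about 0.12 to about 0.48
print(median_before, median_after)  # the median stays at 0.1
```

Since $w_{scale}$ is the reciprocal of the chosen measure, the mean-based scaling factor shrinks by roughly a factor of four here, while the median-based one is unchanged.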

In Figures 3(d), 3(e), and 3(f), we report the robustness of SLMs to different learning rates for 16-bit training and for both the mean and median 1.58-bit schemes. Contrary to what is reported for 1.58-bit LLMs in [1], the large learning rate of 0.1 (1e-1) produces instability to such a degree that it distorts the training, effectively rendering it unable to optimize the performance of the network. This also happens for 16 bits even though we are training from scratch. The fact that we are training small networks might explain this behaviour. The learning rate of 0.01 (1e-2) in Figure 3(f) shows the effectiveness of the median quantization, as it converges faster than the mean quantization proposed in [13, 11] and actually yields a convergence process similar to the one for 16 bits. This supports the explanation given above, i.e., that the flexibility allowed by median quantization can aid faster convergence in some situations. Interestingly, we see that employing median quantization yields a significant difference in convergence between the learning rates 0.01 (1e-2) and 0.001 (1e-3), in contrast to the same learning rates with mean quantization as displayed in Figure 3(e), making the network more sensitive to the learning rate. SLMs exhibit some, but not the same level of, learning-rate robustness as LLMs [11] when trained with larger learning rates.

5 Conclusion

In this paper, we introduced a variant of BitNet b1.58 quantization-aware training, demonstrating state-of-the-art performance on core downstream tasks for SLMs and vision classification models. To the best of our knowledge, this is the first work studying the characteristics and behaviour of the particular 1.58-bit quantization approach from [13, 11] on small networks. The investigations provided in this work underline the potential of employing 1.58 bits more generally in small networks, countering prior arguments that these weight resolutions only exhibit potential on large networks with billions of parameters. This opens up the efficient deployment of SLMs and small vision models, particularly in low-resource use cases. We encourage future work to investigate 1.58-bit quantization-aware training on other networks, such as object-detection networks in the vision domain and language models with encoders, investigating the degree to which our conclusions hold for such types of networks.

Our results suggest a scaling law for SLMs, with a 1.58-bit network needing approximately double the hidden layer size to achieve performance comparable to 16-bit versions. The learning rate for SLMs and small vision models employing 1.58-bit weights does not follow the recommendation in prior work [11] to employ significantly larger learning rates, even when trained from scratch. Weight decay distorts the training when combined with a high learning rate but, on the contrary, helps when applied with smaller learning rates. Our results on employing AbsMean vs AbsMedian quantization of the 16-bit shadow weights are not conclusive, leaving the choice as a hyperparameter for now and opening avenues for future work on the most advantageous quantization schemes from 16 to 1.58 bits in the context of quantization-aware training.

References

  • [1] The era of 1-bit LLMs: Training tips, code and FAQ. https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
  • [2] The MNIST database of handwritten digits. http://yann.lecun.com/exdb/mnist/
  • [3] Ashkboos, S., Croci, M.L., do Nascimento, M.G., Hoefler, T., Hensman, J.: Slicegpt: Compress large language models by deleting rows and columns (2024)
  • [4] Ba, J.L., Kiros, J.R., Hinton, G.E.: Layer normalization. arXiv preprint arXiv:1607.06450 (2016)
  • [5] Frantar, E., Ashkboos, S., Hoefler, T., Alistarh, D.: Gptq: Accurate post-training quantization for generative pre-trained transformers (2023)
  • [6] Kingma, D.P., Ba, J.: Adam: A method for stochastic optimization. In: Proceedings of the International Conference on Learning Representations (2015)
  • [7] Krizhevsky, A., Hinton, G., et al.: Learning multiple layers of features from tiny images (2009)
  • [8] Li, Z., Gu, Q.: I-vit: integer-only quantization for efficient vision transformer inference. In: Proceedings of the IEEE/CVF International Conference on Computer Vision. pp. 17065–17075 (2023)
  • [9] Lin, J., Tang, J., Tang, H., Yang, S., Chen, W.M., Wang, W.C., Xiao, G., Dang, X., Gan, C., Han, S.: Awq: Activation-aware weight quantization for llm compression and acceleration (2024)
  • [10] Liu, Z., Oguz, B., Zhao, C., Chang, E., Stock, P., Mehdad, Y., Shi, Y., Krishnamoorthi, R., Chandra, V.: Llm-qat: Data-free quantization aware training for large language models (2023)
  • [11] Ma, S., Wang, H., Ma, L., Wang, L., Wang, W., Huang, S., Dong, L., Wang, R., Xue, J., Wei, F.: The era of 1-bit llms: All large language models are in 1.58 bits (2024)
  • [12] Men, X., Xu, M., Zhang, Q., Wang, B., Lin, H., Lu, Y., Han, X., Chen, W.: Shortgpt: Layers in large language models are more redundant than you expect. arXiv preprint arXiv:2403.03853 (2024)
  • [13] Wang, H., Ma, S., Dong, L., Huang, S., Wang, H., Ma, L., Yang, F., Wang, R., Wu, Y., Wei, F.: Bitnet: Scaling 1-bit transformers for large language models (2023)