Structured Inverse-Free Natural Gradient Descent:
Memory-Efficient & Numerically-Stable KFAC
Abstract
Second-order methods such as KFAC can be useful for neural net training. However, they are often memory-inefficient since their preconditioning Kronecker factors are dense, and numerically unstable in low precision as they require matrix inversion or decomposition. These limitations render such methods unpopular for modern mixed-precision training. We address them by (i) formulating an inverse-free KFAC update and (ii) imposing structures in the Kronecker factors, resulting in structured inverse-free natural gradient descent (SINGD). On modern neural networks, we show that SINGD is memory-efficient and numerically robust, in contrast to KFAC, and often outperforms AdamW even in half precision. Our work closes a gap between first- and second-order methods in modern low-precision training.
1 Introduction
The continuing success of deep learning (DL) is—to a large extent—powered by scaling up computational power (Thompson et al., 2020) to increase the number of trainable neural network (NN) parameters. Contemporary natural language processing (Radford et al., 2019; Brown et al., 2020; Touvron et al., 2023) and computer vision (Dehghani et al., 2023) models often consist of billions of parameters, and will likely grow further in the future. To compensate for increasing computational demands, many training pipelines use lower precision data types (Micikevicius et al., 2018) and memory-efficient first-order optimizers like SGD (Robbins & Monro, 1951) or Adam(W) (Kingma & Ba, 2015; Loshchilov & Hutter, 2019).
Second-order methods, like natural gradient descent (NGD, Amari, 1998), leverage curvature information, which has many applications in DL: it is useful for improving training dynamics (Martens & Grosse, 2015; Osawa et al., 2023), pruning (Wang et al., 2019), understanding the influence of training examples (Bae et al., 2022), and uncertainty estimation (Zhang et al., 2018; Immer et al., 2021; Daxberger et al., 2021). A major reason these methods are rarely used is their higher memory consumption and iteration cost.
Perhaps the most common approach to scaling second-order methods for DL is Kronecker-factored approximate curvature (KFAC, Heskes, 2000; Martens & Grosse, 2015), which approximates the Fisher's diagonal blocks via Kronecker products. While the KFAC optimizer built on top of this curvature approximation and its variants, such as George et al. (2018), show promising results for medium-sized NNs (e.g., Osawa et al., 2023), their usefulness is often limited by (i) memory consumption and (ii) low-precision floating-point (FP) training, which renders the matrix decompositions/inversions required to precondition the gradient numerically unstable.
Recently, Lin et al. (2023) proposed an inverse-free Kronecker-factored natural gradient descent (INGD) algorithm that replaces matrix inversion with subtraction in a matrix logarithm space. Their update is purely based on matrix multiplications and therefore numerically stable in single precision (FP-32); however, it is unclear whether this extends to half precision (BFP-16). Furthermore, INGD has not been derived from the popular natural gradient approaches for DL, and it is unclear if and how the method is connected to the predominant KFAC optimizer. Also, INGD does not improve over KFAC's memory complexity, since its Kronecker factors are dense matrices of the same size. Lastly, INGD has only been tested on convolution-based models, and it is unclear whether it is useful for training modern transformer-based architectures (Vaswani et al., 2017).
Here, we extend INGD to lower its computational cost and theoretically resolve its connection to other approximate NGD methods for DL (overview in Figure 2): First, we show that a special case of INGD recovers the KFAC method. This allows us to effectively perform KFAC updates in an inverse-free fashion. We call this modification of INGD inverse-free KFAC (IKFAC). Second, we exploit an algebraic structure in the matrix logarithm space and propose structure-preserving updates to maintain sparse structures on Kronecker factors. This significantly reduces memory and leads to a novel, scalable second-order optimization algorithm we call structured inverse-free natural gradient descent (SINGD) which contains INGD and IKFAC as special cases. We evaluate SINGD on convolution- and transformer-based models and show that it can (i) outperform SGD and AdamW while using as little memory as the latter thanks to structured Kronecker factors and (ii) yield better performance than KFAC while being stable in half-precision:
- (a)
- (b) We impose various structures (block-diagonal, low-rank, Toeplitz, hierarchical) on INGD's Kronecker factors, making them sparse to lower memory consumption and run time (Figure 1, right, and Table 1). Unlike many existing second-order methods tailored to a single form of structure, our proposed update rule (Figure 4) is unified, efficient, and inverse-free for a range of structures. We analyze the impact of structures on downstream performance and find that structures with considerably lower memory consumption (even lower than AdamW) can yield competitive performance.
- (c) Unlike other second-order methods, we show that SINGD can stably train a range of modern architectures (transformers, CNNs, GNNs) in BFP-16. In contrast to first-order methods, which are often useful in narrower scopes (SGD is best for CNNs, AdamW is best for transformers), SINGD works well and outperforms SGD and AdamW in many cases (see Section 4).
Our work closes a gap between first- and second-order methods in modern low-precision neural network training.¹

¹ PyTorch implementation: github.com/f-dangel/singd.
Method | Peak memory [GiB] | Training time [min] |
---|---|---|
SGD (BFP-16) | 2.63 (1.00 x) | 18.5 (1.00 x) |
AdamW (BFP-16) | 2.69 (1.02 x) | 19.7 (1.07 x) |
SINGD-Diag* (BFP-16) | 2.67 (1.02 x) | 23.8 (1.29 x) |
IKFAC* (BFP-16) | 3.18 (1.21 x) | 34.0 (1.84 x) |
INGD (BFP-16) | 3.39 (1.29 x) | 34.1 (1.84 x) |
KFAC (FP-32) | 4.00 (1.52 x) | 83.2 (4.49 x) |
2 Preliminaries
We first introduce the necessary ingredients to establish a connection between INGD and KFAC, which are derived from different perspectives. We start by describing Newton's method, since both methods can be seen as approximate Newton methods using NGD. NN training often corresponds to an unconstrained minimization problem. Consider training a NN for image classification. Given a set of $N$ examples $\{(y_i, x_i)\}_{i=1}^{N}$ with labels $y_i$ and images $x_i$, the optimization problem is

$$\min_{\mu}\; \ell(\mu), \qquad \ell(\mu) := \sum_{i=1}^{N} c(y_i, \hat{y}_i), \tag{1}$$

where $\hat{y}_i := f(x_i; \mu)$ and $f$ is a NN that outputs a predicted label $\hat{y}_i$ for an image $x_i$. Parameters $\mu$ denote learnable weights of the NN and $c(y, \hat{y})$ is a differentiable loss function that measures the difference between a true label $y$ and a predicted label $\hat{y}$. To solve Equation 1, Newton's method follows the update

$$\mu \leftarrow \mu - \left(\nabla_{\mu}^{2} \ell(\mu)\right)^{-1} \nabla_{\mu} \ell(\mu), \tag{2}$$

where $\nabla_{\mu}^{2} \ell(\mu)$ is the Hessian of the loss.
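As a minimal sketch of Newton's update (2), the snippet below assumes a hypothetical quadratic loss with a constant, positive-definite Hessian; a single step then lands exactly on the minimizer. The helper `newton_step` and the test problem are our own illustration, and the step uses a linear solve rather than forming the inverse explicitly.

```python
import numpy as np

def newton_step(grad, hess, mu):
    """One Newton update mu <- mu - H(mu)^{-1} g(mu), computed with a linear solve."""
    return mu - np.linalg.solve(hess(mu), grad(mu))

# Hypothetical quadratic loss l(mu) = 0.5 * mu^T A mu - b^T mu with SPD A.
rng = np.random.default_rng(0)
M = rng.standard_normal((5, 5))
A = M @ M.T + 5.0 * np.eye(5)  # symmetric positive-definite Hessian
b = rng.standard_normal(5)

def grad(mu):
    return A @ mu - b

def hess(mu):
    return A  # constant for a quadratic loss

mu0 = rng.standard_normal(5)
mu1 = newton_step(grad, hess, mu0)
mu_star = np.linalg.solve(A, b)  # closed-form minimizer
```

For non-quadratic losses the same step only uses a local quadratic model, which is why practical variants add step sizes or damping.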
2.1 KFAC: Approximate NGD for MLE
Computing the Hessian, as required by Newton's method, is usually intractable for NNs. NGD uses a Fisher information matrix (FIM) instead of the Hessian by reformulating problem (1) as maximum likelihood estimation (MLE) of $\max_{\mu} \prod_{i=1}^{N} p(y_i \mid x_i; \mu)$, where $p(y \mid x; \mu) \propto \exp\left(-c(y, f(x; \mu))\right)$. The maximization problem is equivalent to the MLE problem

$$\min_{\mu} \sum_{i=1}^{N} -\log p(y_i \mid x_i; \mu). \tag{3}$$

This formulation allows us to exploit additional statistical structure such as the FIM, which is defined below (Kunstner et al., 2019), where we assume a label $y$ is sampled from the likelihood $p(y \mid x_i; \mu)$ given an image $x_i$. With $\ell_i(y, \mu) := -\log p(y \mid x_i; \mu)$, we have

$$F(\mu) := \sum_{i=1}^{N} \mathbb{E}_{y \sim p(y \mid x_i; \mu)}\left[\nabla_{\mu} \ell_i(y, \mu)\, \nabla_{\mu} \ell_i(y, \mu)^{\top}\right]. \tag{4}$$

For ubiquitous loss functions like the mean-squared error and cross-entropy, and more generally for many members of the exponential family with natural parameterization, the FIM coincides with the generalized Gauss-Newton (GGN) matrix (Wang, 2010; Martens, 2014), a common approximation of the Hessian in deep learning (Schraudolph, 2002; Botev et al., 2017). This relationship connects NGD to Newton's method. A common approximation of the FIM/GGN and Hessian is the so-called empirical Fisher $\tilde{F}(\mu)$, which replaces the samples $y$ from the model's predictive distribution in Equation 4 with the empirical data labels $y_i$:

$$\tilde{F}(\mu) := \sum_{i=1}^{N} \nabla_{\mu} \ell_i(y_i, \mu)\, \nabla_{\mu} \ell_i(y_i, \mu)^{\top}.$$

While there is no clear theoretical justification for this Hessian approximation (Kunstner et al., 2019), it simplifies the implementation, reduces cost, and has been shown to work well in practice (Graves, 2011; Osawa et al., 2019). This approximation is also known as Fisher's scoring with observed FIM for nonlinear models (Osborne, 1992; Smyth, 1996, 2015). With this, we can formulate an NGD update with the empirical FIM to approximate Newton's method as

$$S \leftarrow (1 - \beta_1)\, S + \beta_1 \tilde{F}(\mu), \qquad \mu \leftarrow \mu - \beta_2\, S^{-1} \nabla_{\mu} \ell(\mu).$$

We call this update NGD for MLE.
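To make the difference between the FIM in Equation 4 and the empirical Fisher concrete, here is a small sketch on hypothetical binary logistic-regression data (our choice of model, not one used in the paper). For a Bernoulli likelihood the expectation over labels can be taken exactly; the resulting FIM coincides with the Hessian of the negative log-likelihood, whereas the empirical Fisher built from the observed labels generally does not.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical binary logistic-regression data.
rng = np.random.default_rng(1)
N, D = 50, 3
X = rng.standard_normal((N, D))
y_obs = rng.integers(0, 2, size=N).astype(float)  # observed labels
mu = rng.standard_normal(D)
p = sigmoid(X @ mu)  # model probability that the label is 1

def per_example_grad(y, i):
    """Gradient of -log p(y | x_i; mu) for a Bernoulli likelihood."""
    return (p[i] - y) * X[i]

# Fisher: exact expectation over y ~ p(y | x_i; mu), with y in {0, 1}.
F = np.zeros((D, D))
for i in range(N):
    for y, prob in ((1.0, p[i]), (0.0, 1.0 - p[i])):
        g = per_example_grad(y, i)
        F += prob * np.outer(g, g)

# Empirical Fisher: plug in the observed labels instead of model samples.
F_emp = np.zeros((D, D))
for i in range(N):
    g = per_example_grad(y_obs[i], i)
    F_emp += np.outer(g, g)

# Closed-form Hessian of the negative log-likelihood: sum_i p_i (1 - p_i) x_i x_i^T.
H = (X * (p * (1.0 - p))[:, None]).T @ X
```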
KFAC (Heskes, 2000; Martens & Grosse, 2015) is probably the most common second-order optimizer in DL. The KFAC algorithm is based on a Kronecker-factored approximation of the Fisher, which is sometimes also referred to as KFAC. Here, we refer to the algorithm as KFAC or the KFAC method, and to the approximation as the Kronecker approximation; we will consider the empirical Fisher's Kronecker approximation. It approximates the FIM with a Kronecker-factored block for each layer of the net. This approximation was first derived for linear layers, later for convolutional (Grosse & Martens, 2016) and recurrent layers (Martens et al., 2018), and has recently been generalized to all linear layers that use weight sharing (Eschenhagen et al., 2023), e.g., in graph neural networks and transformers. For the $l$-th layer, a block is given by the Kronecker product of $U := a a^{\top}$ and $G := g g^{\top}$, where $a$ is the $l$-th layer's input and $g$ is the gradient of the loss w.r.t. the layer's output. We suppress the dependence on the parameters and the input and, for simplicity, assume no weight sharing. KFAC also uses exponential moving averages (with weight $\beta_1$) over $U$ and $G$ (yielding $S_K$ and $S_C$) and damping $\lambda$, see Figure 4.
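The Kronecker structure is exact for a single example passing through a linear layer: the weight gradient is the outer product of the output gradient and the input, so its outer product factorizes. The sketch below checks this with NumPy, assuming a bias-free layer and column-major vectorization of the weight matrix (both are our assumptions for illustration).

```python
import numpy as np

rng = np.random.default_rng(2)
d_in, d_out = 4, 3
a = rng.standard_normal(d_in)   # layer input
g = rng.standard_normal(d_out)  # gradient of the loss w.r.t. the layer output

# For z = W a with W of shape (d_out, d_in), the weight gradient is g a^T.
grad_W = np.outer(g, a)
# Column-major ("F") vectorization stacks the columns a_j * g of grad_W.
v = grad_W.flatten(order="F")

block = np.outer(v, v)                            # per-example FIM block
kron = np.kron(np.outer(a, a), np.outer(g, g))    # U kron G with U = a a^T, G = g g^T
```

For a mini-batch, a sum of such Kronecker products is generally not itself a Kronecker product; this is where KFAC's approximation enters.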
While the Kronecker approximation enables more efficient gradient preconditioning, KFAC has two drawbacks. First, KFAC needs to store the dense Kronecker factors and invert them at every preconditioner update. The run-time overhead is usually amortized by updating the preconditioner less frequently, but this can cause instabilities, especially in low-precision settings. Second, the Kronecker factors introduce significant memory overhead, which poses issues in large models. Since low-precision training is becoming the norm in fields like natural language processing, these issues will become more apparent in modern DL. There are multiple numerical concerns when using KFAC or its variants in low precision. In PyTorch (Paszke et al., 2019) and JAX (Bradbury et al., 2018) implementations, all tensors must be cast to FP-32, as (B)FP-16 matrix inverses/decompositions are not supported. Moreover, the gradient w.r.t. the layer output has to be rescaled to avoid over- or underflows when computing its outer product. Memory consumption has previously been addressed through diagonal or block-diagonal versions of the Kronecker factors (Zhang et al., 2018; Grosse et al., 2023). However, it is unclear whether these simple structures maintain downstream performance.
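The rescaling concern can be reproduced in a few lines. NumPy has no bfloat16 type, so this sketch uses IEEE float16, whose exponent range is far narrower than BFP-16's; the gradient values and the scale factor are hypothetical and chosen only to make the overflow visible, not taken from any experiment.

```python
import numpy as np

g = np.array([300.0, -2.0, 0.5], dtype=np.float16)  # hypothetical output gradient

with np.errstate(over="ignore"):
    G_naive = np.outer(g, g)  # float16: 300 * 300 = 90000 exceeds the max of 65504 -> inf

scale = 256.0
g_scaled = g / np.float16(scale)            # rescale before forming the outer product
G_scaled = np.outer(g_scaled, g_scaled)     # stays within the float16 range
G_recovered = G_scaled.astype(np.float32) * scale**2  # undo the scaling in float32
```

The recovered entry is close to, but not exactly, 90000, because the scaled product is rounded to float16's 10-bit mantissa.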
2.2 INGD: Approximate NGD for Bayesian estimation
Derived from Bayesian principles, INGD (Lin et al., 2023) directly approximates the Hessian inverse. We first introduce two ingredients INGD builds on: the Bayesian learning rule (BLR, Khan & Lin, 2017; Zhang et al., 2018; Khan et al., 2018; Osawa et al., 2019; Lin et al., 2020; Khan & Rue, 2021; Tan, 2022) and an inverse-free second-order method from Lin et al. (2021). By the BLR, Newton's method for solving the MLE problem (3) can be seen as another natural-gradient update for solving a variational inference (VI) problem with a delta approximation (Khan & Rue, 2021). This interpretation allows us to view a precision matrix in the variational problem as a Hessian estimate in the MLE problem. Thus, Lin et al. (2021) suggest reparameterizing the Hessian as the precision of the Gaussian posterior in a matrix logarithm space and exploiting the parameterization invariance of natural gradients to obtain an inverse-free update.
BLR
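Consider a Bayesian problem formulation, where NN weights are random variables. We denote these weights by new parameters $w$, since random variables are no longer learnable, and use a variational Gaussian distribution $q(w \mid \tau) = \mathcal{N}(w \mid \mu, S^{-1})$ to approximate the posterior over the random variables. Its mean $\mu$ and precision $S$ will be treated as the learnable weights and the Hessian estimate in Newton's step (2), respectively.
The VI problem considered in the learning rule is defined as $\max_{\tau} \mathcal{L}(\tau)$ with the evidence lower bound (ELBO)

$$\mathcal{L}(\tau) := \mathbb{E}_{q(w \mid \tau)}\Big[\sum_{i=1}^{N} \log p(y_i \mid x_i; w) + \log p(w)\Big] + \mathcal{H}(q). \tag{5}$$

Here, $\tau := (\mu, S)$ are the learnable parameters of the variational Gaussian distribution with mean $\mu$ and precision $S$. The likelihood $p(y_i \mid x_i; w)$ takes the same form as in the MLE setting, while the prior $p(w) \propto \exp(-R(w))$ is defined by a regularizer $R(w)$. To recover the MLE problem, we consider an uninformative prior (i.e., $R(w) = 0$). $\mathcal{H}(q)$ is the entropy of $q$.
Similar to the MLE case, the Bayesian formulation allows us to exploit additional statistical structure in the form of another FIM, namely that of the variational Gaussian, defined as

$$F(\tau) := \mathbb{E}_{q(w \mid \tau)}\left[\nabla_{\tau} \log q(w \mid \tau)\, \nabla_{\tau} \log q(w \mid \tau)^{\top}\right],$$

which has a closed-form expression. This FIM should not be confused with the FIM used for MLE (4).
Under the BLR, we perform NGD updates not only on $\mu$ but also on $S$. Khan & Rue (2021) formulate a step with the exact FIM $F(\tau)$ and stepsize $\beta$ to update $\tau$,

$$\tau \leftarrow \tau + \beta\, F(\tau)^{-1} \nabla_{\tau} \mathcal{L}(\tau).$$

This is the NGD update for BLR, in contrast to NGD for MLE. Following Khan & Nielsen (2018), the update simplifies to

$$S \leftarrow (1 - \beta)\, S + \beta\, \mathbb{E}_{q(w \mid \tau)}\left[\nabla_{w}^{2} \ell(w)\right], \qquad \mu \leftarrow \mu - \beta\, S^{-1}\, \mathbb{E}_{q(w \mid \tau)}\left[\nabla_{w} \ell(w)\right],$$

where $\ell(w) := -\sum_{i=1}^{N} \log p(y_i \mid x_i; w)$ is the negative log-likelihood from (3) evaluated at $w$. Further simplifying the expectations with a delta approximation at the mean $\mu$, i.e., $\mathbb{E}_{q(w \mid \tau)}[h(w)] \approx h(\mu)$, we obtain

$$S \leftarrow (1 - \beta)\, S + \beta\, \nabla_{\mu}^{2} \ell(\mu), \qquad \mu \leftarrow \mu - \beta\, S^{-1} \nabla_{\mu} \ell(\mu),$$

which recovers Newton's method in (2) for $\beta = 1$.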
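A minimal sketch of the delta-approximated BLR update, assuming a hypothetical separable loss whose gradient and Hessian are available in closed form. With a stepsize of one, the precision is replaced by the Hessian and the mean update coincides with a Newton step; smaller stepsizes average the curvature over iterations.

```python
import numpy as np

def blr_delta_step(mu, S, grad, hess, beta):
    """BLR Gaussian update with expectations replaced by their value at the mean."""
    S_new = (1.0 - beta) * S + beta * hess(mu)             # precision update
    mu_new = mu - beta * np.linalg.solve(S_new, grad(mu))  # mean update uses the new precision
    return mu_new, S_new

# Hypothetical loss l(mu) = sum(exp(mu)) - sum(mu), minimized at mu = 0.
def grad(mu):
    return np.exp(mu) - 1.0

def hess(mu):
    return np.diag(np.exp(mu))

mu0 = np.array([0.5, -0.3, 1.2])
S0 = 3.0 * np.eye(3)  # arbitrary initial precision

# beta = 1: one BLR step equals one Newton step.
mu_blr, S_blr = blr_delta_step(mu0, S0, grad, hess, beta=1.0)
mu_newton = mu0 - np.linalg.solve(hess(mu0), grad(mu0))

# beta < 1: damped iterations still converge to the minimizer.
mu, S = mu0, S0
for _ in range(60):
    mu, S = blr_delta_step(mu, S, grad, hess, beta=0.5)
```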
Removing inversion
Lin et al. (2021) reparameterize the precision matrix in a matrix logarithm space and perform natural-gradient updates in this space, which transforms inversion into subtraction. One can go back directly to the original space, without explicitly inverting a matrix, via a truncated matrix exponential. The method is inverse-free and, since natural gradients are parameterization-invariant, Newton-like.
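Two facts behind the inverse-free construction can be checked numerically: the inverse of a matrix exponential is the exponential of the negated argument, and truncating the exponential after the linear term incurs an error that is quadratic in the argument. The sketch below uses a plain Taylor series for the exponential (our choice, to stay dependency-free) on a small random matrix.

```python
import numpy as np

def expm_taylor(A, terms=30):
    """Matrix exponential from a truncated Taylor series (adequate for small ||A||)."""
    n = A.shape[0]
    result = np.eye(n)
    term = np.eye(n)
    for k in range(1, terms):
        term = term @ A / k  # A^k / k!
        result = result + term
    return result

rng = np.random.default_rng(3)
A = 0.02 * rng.standard_normal((4, 4))

# Inversion becomes negation in the log space: Exp(A) @ Exp(-A) = I.
identity_err = np.linalg.norm(expm_taylor(A) @ expm_taylor(-A) - np.eye(4))

def first_order_err(t):
    """Error of the first-order truncation Exp(tA) ~ I + tA."""
    return np.linalg.norm(expm_taylor(t * A) - (np.eye(4) + t * A))

ratio = first_order_err(1.0) / first_order_err(0.5)  # close to 4 for an O(t^2) error
```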
The first step is to express the precision matrix using a non-singular square matrix as and perform a natural gradient step using the exact FIM in a tangent space (denoted by ) of at iteration . We then construct a new map as using both the current point and as input, where is the matrix exponential. Observe that stays in a matrix logarithm space. At each iteration , we use a new matrix logarithm space associated to and generate a new origin in this space to represent since . The map is a local reparameterization map that takes not only but also as input. Thanks to this map, the Fisher block is locally orthonormalized (Lin et al., 2023) at origin . Since we used the origin to represent in the local coordinate , a natural gradient step becomes a (Euclidean) gradient step in the space of , which makes it easy to add Riemannian momentum (Lin et al., 2023) into the structured positive-definite matrix . This allows to perform updates in the logarithmic space of and avoid matrix inversions:
(6) |
where and . Equation 6 is a Newton-like update without matrix inverse. To see that, we can reexpress the update of in terms of and use properties of the matrix exponential function,
Next, we can construct a structured precision matrix as a structured Hessian estimate using a sparse non-singular matrix . As we will discuss in Section 3.2, it is essential to update to preserve sparsity in . The space of as a tangent/logarithm space of allows us to efficiently impose sparse structures on without requiring the Hessian or a Hessian approximation to be sparse or structured. This is different from another inverse-free method (Tan, 2022) that considers directly performing NGD updates of instead of , where must be restricted to a (triangular) Cholesky factor. This does not preserve sparsity in unless the Hessian or its approximation admits a special structure, which is usually not the case in DL problems.
INGD
Our work builds on INGD (Figure 4), where is factorized into two Kronecker factors. The exact FIM under this parameterization is singular due to a correlation between and : the Kronecker factorization is not unique. Lin et al. (2023) propose a (non-singular) block-diagonal approximation of the FIM that ignores the correlation in the original FIM, and perform NGD with this block-diagonal FIM on the tangent spaces of the factors. Riemannian momentum is further introduced in the update of and . They use the Kronecker approximation discussed in Section 2.1 to approximate the Hessian and truncate the matrix exponential to obtain a purely matrix-multiplication-based update scheme. It is unclear how INGD is related to KFAC, which uses another Kronecker factorization . INGD also remains memory-inefficient due to the use of dense Kronecker factors. The authors only consider and evaluate it on convolution-based models in single precision. It remains unclear whether INGD is useful for training transformer-based models and whether it remains stable in half precision.
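The non-uniqueness of a Kronecker factorization is easy to see: scaling one factor by a constant and the other by its reciprocal leaves the product unchanged. In this sketch, `K` and `C` are hypothetical square factors and `c` is an arbitrary non-zero scale.

```python
import numpy as np

rng = np.random.default_rng(4)
K = rng.standard_normal((3, 3))  # hypothetical factor of the first Kronecker term
C = rng.standard_normal((2, 2))  # hypothetical factor of the second Kronecker term
c = 7.5                          # arbitrary non-zero scale

original = np.kron(K @ K.T, C @ C.T)
rescaled = np.kron((c * K) @ (c * K).T, (C / c) @ (C / c).T)  # c^2 and 1/c^2 cancel
```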
3 Structured inverse-free NGD
Inspired by INGD, we propose an inverse-free KFAC update as a specific setting of INGD to address KFAC’s numerical instability in low precision. We show that this scheme effectively recovers KFAC. We then address the memory inefficiency of KFAC and INGD for training transformer-based models by extending INGD with structures.
3.1 Inverse-free KFAC Updates for Numerical Stability
Matrix Lie sub-group structure | Subspace of the log (Lie-algebraic) space | Subspace projection map |
---|---|---|
Lower-triangular (Tril.) | | |
(Block) diagonal (block size k) | | |
Hierarchical | | |
Rank-k upper-triangular | | |
Upper-triangular Toeplitz (Triu-Toepl.) | | |
Structure | | | |
---|---|---|---|
Dense | | | |
Diagonal | | | |
Block-diag. | | | |
Tril-Toepl. | | | |
Triu-Toepl. | | | |
Hierarchical | | | |
Sparse Triu. | | | |
Sparse Triu. | | | |
Sparse Tril. | | | |
Sparse Tril. | | | |
We first propose a new inverse-free update to mimic the behavior of the KFAC update; we call this update IKFAC. We then show that IKFAC corresponds to a specific setting of INGD. This bridges the gap between INGD and KFAC and sheds light on the difference between both methods.
Inspired by INGD, we replace matrix inversion with matrix subtraction in a matrix logarithm space, then go back to the original space without explicitly inverting any matrix using a truncated matrix exponential map. The IKFAC update is related to the KFAC update in that we use $KK^{\top}$ and $CC^{\top}$ to approximate the inverse Kronecker factors $S_K^{-1}$ and $S_C^{-1}$ in KFAC, respectively. We propose the following IKFAC update with learning rate $\beta_1$ for $K$ and $C$ using a truncated matrix exponential:
(7) |
where , , , . This update is inverse- and matrix-decomposition-free. Since we truncate the matrix exponential , indeed stays in a matrix logarithm space (see Appendix C). The logarithm space allows us to impose structural constraints on the Kronecker factors, which we discuss in Section 3.2.
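The following is a hedged sketch of an inverse-free factor update in the spirit of Equation 7, assuming the form K ← K(I − (β₁/2)(KᵀUK + λKᵀK − I)) with a curvature factor U and damping λ. This form is our assumption for illustration, chosen so that KKᵀ tracks the inverse of the damped curvature; it is not copied from Figure 4. Because it uses only matrix products, the same function runs unchanged in float16.

```python
import numpy as np

def ikfac_factor_step(K, U, lam, beta1):
    """Hypothetical inverse-free factor update using only matrix products:
    K <- K (I - beta1/2 * (K^T U K + lam * K^T K - I))."""
    eye = np.eye(K.shape[0], dtype=K.dtype)
    m = K.T @ U @ K + lam * (K.T @ K) - eye  # symmetric when U is symmetric
    return K @ (eye - (beta1 / 2) * m)

rng = np.random.default_rng(7)
d, lam, beta1 = 4, 0.1, 0.1
M = rng.standard_normal((d, d))
U = M @ M.T / d  # hypothetical symmetric PSD curvature factor

# Fixed point: if K K^T = (U + lam I)^{-1}, then K^T (U + lam I) K = I and K is unchanged.
K_star = np.linalg.cholesky(np.linalg.inv(U + lam * np.eye(d)))
K_next = ikfac_factor_step(K_star, U, lam, beta1)

# No inverse or decomposition is needed, so the step also runs in float16.
K0 = 0.5 * np.eye(d)
K64 = ikfac_factor_step(K0, U, lam, beta1)
K16 = ikfac_factor_step(K0.astype(np.float16), U.astype(np.float16), lam, beta1)
```

The Cholesky factorization above is only used to construct a reference fixed point for the check; the update itself never calls a decomposition.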
The following theorem (proof in Appendix D) formally shows that $KK^{\top}$ used in IKFAC is an approximation of $S_K^{-1}$ in KFAC at every step, even with a truncated matrix exponential. Similarly, $CC^{\top}$ is an approximation of $S_C^{-1}$. Thus, IKFAC effectively recovers KFAC up to first-order accuracy.
Theorem 1.
If $K$ is updated according to the IKFAC scheme (Figure 4) with the truncation of the matrix exponential, $S_K$ is updated according to the KFAC scheme, and these two updates use the same initialization and the same sequence of curvature matrices $U$, then the product $KK^{\top}$ has first-order accuracy with respect to the KFAC update of $S_K^{-1}$ at each iteration, i.e., $KK^{\top} = S_K^{-1} + O(\beta_1^2)$.
Theorem 1 trivially extends to diagonal and block-diagonal structures, i.e., KFAC with diagonal or block-diagonal Kronecker factors is equivalent to IKFAC with diagonal or block-diagonal structure up to first order in $\beta_1$.
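The first-order claim can be probed numerically. Assuming an IKFAC form K ← K(I − (β₁/2)(KᵀHK − I)), where H stands for a damped curvature matrix U + λI (our assumption, not copied from Figure 4), the sketch runs KFAC's moving average and the inverse-free update on the same hypothetical curvature sequence from a shared initialization. If the gap between KKᵀ and the inverse of KFAC's moving-average factor is O(β₁²), halving β₁ should shrink it by roughly a factor of four.

```python
import numpy as np

def kfac_ema(S, H, beta1):
    """KFAC-style moving average of a damped curvature factor."""
    return (1.0 - beta1) * S + beta1 * H

def ikfac_step(K, H, beta1):
    """Assumed inverse-free update with a first-order (truncated) matrix exponential."""
    eye = np.eye(K.shape[0])
    return K @ (eye - 0.5 * beta1 * (K.T @ H @ K - eye))

def gap(beta1, Hs):
    """Frobenius norm of K K^T - S^{-1} after running both schemes on the same sequence."""
    d = Hs[0].shape[0]
    S, K = np.eye(d), np.eye(d)  # same initialization: K K^T = S^{-1}
    for H in Hs:
        S = kfac_ema(S, H, beta1)
        K = ikfac_step(K, H, beta1)
    return np.linalg.norm(K @ K.T - np.linalg.inv(S))

rng = np.random.default_rng(5)
d, lam = 4, 0.1
Hs = []
for _ in range(5):
    M = rng.standard_normal((d, d))
    Hs.append(M @ M.T / d + lam * np.eye(d))  # hypothetical U_t + lam * I

ratio = gap(0.004, Hs) / gap(0.002, Hs)  # close to 4 if the gap is O(beta1^2)
```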
Now, we show that IKFAC is a specific case of INGD, whose update of without Riemannian momentum () is
(8) |
Since , , , and , we can obtain IKFAC from INGD by simply replacing and with :
(9) |
This sheds light on the difference between both methods. In IKFAC (see Appendix C for details), and are used for incorporating KFAC’s curvature and damping , respectively. In contrast, the curvature and damping are adaptively incorporated in INGD using and . The updates of and are correlated in INGD due to the trace terms, while and are updated independently in IKFAC—just like and in KFAC. These trace terms are needed to satisfy the orthonormalization condition of the Fisher matrix (Lin et al., 2023). They make INGD and SINGD scale-invariant to the Kronecker approximation (see Appendix E) as the approximation is not unique. In contrast, KFAC and IKFAC are not scale-invariant. The trace terms together with Riemannian momentum () are missing in KFAC and IKFAC. Our experiments show that they can contribute to stability.
3.2 Sparse Kronecker Factors for Reducing Memory
Now, we extend INGD to reduce its memory and iteration cost. Existing sparse KFAC methods use (block-)diagonal structures for and (Zhang et al., 2019; Grosse et al., 2023). In contrast, we propose using sparse Kronecker factors and in INGD and exploiting Lie-algebraic properties in the logarithm space and algebraic sparsity of the Kronecker factors. This enables more flexible structures (Figure 5) that potentially achieve better downstream performance than (block-)diagonal structures in , .
Other related works are Lie group preconditioners (Li, 2018, 2022) originally derived from directly approximating the Hessian inverse. Some Hessian-vector-product-based versions of these methods can be expensive and unavailable in pure low-precision settings due to sampling random weights and solving linear systems that are unstable in low precision. Our approach is sampling-free and available in pure half-precision settings.
We want to construct sparse factors and without requiring the Kronecker/Hessian approximation () to be further sparse or structured. Imposing sparsity often leads to a complicated FIM which makes it difficult to perform NGD due to the FIM inversion. It is essential to update as the logarithm space of to impose sparsity on as the FIM in this (moving) coordinate is simplified and becomes an identity matrix due to the orthonormalization condition. This condition (Lin et al., 2023) makes it easy for us to impose a range of sparse structures on through a unified and inverse-free update rule (Figure 4) since we can avoid inverting the Fisher block regarding the sparse structures. We also exploit the algebraic sparsity in these structures to make our rule more efficient than INGD (Table 3).
We exploit Lie-algebraic properties in the log space of to construct sparse structures of . As a general design principle, we consider structures of preserved under (i) elementwise matrix operations (subtraction and scalar multiplication) and (ii) matrix multiplication, which are needed for our updates. Concretely, we construct a new local reparameterization for at iteration via
where the subspace projection map projects the dense log-space matrix onto a subspace (and identically for the other Kronecker factor, potentially using a different structure).
Many popular structures, such as tri-diagonal matrices, do not satisfy our requirements, as they are not closed under matrix multiplication. Moreover, it can be difficult to construct the projection map so that it satisfies the orthonormalization condition. One subspace structure satisfying the requirements is the set of upper/lower triangular matrices. The subspace projection is a weighted extraction map, since projecting the logarithm space onto a subspace is like projecting a dense square matrix onto a triangular matrix. Technically, we use
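The closure requirement can be checked directly. This sketch, on small random matrices of our choosing, confirms that products of lower-triangular matrices stay lower-triangular, that products of tri-diagonal matrices spill onto a second off-diagonal, and that upper-triangular Toeplitz matrices (polynomials in the shift matrix) are closed under multiplication.

```python
import numpy as np

rng = np.random.default_rng(6)
d = 6

def bandwidth(M, tol=1e-12):
    """Largest |i - j| over the non-zero entries of M."""
    rows, cols = np.nonzero(np.abs(M) > tol)
    return int(np.max(np.abs(rows - cols))) if rows.size else 0

def upper_toeplitz(c):
    """Upper-triangular Toeplitz matrix with first row c, built as a polynomial in the shift matrix."""
    shift = np.eye(len(c), k=1)
    return sum(ck * np.linalg.matrix_power(shift, k) for k, ck in enumerate(c))

def is_upper_toeplitz(M):
    return np.allclose(M, upper_toeplitz(M[0]))

# Lower-triangular matrices are closed under multiplication.
L1 = np.tril(rng.standard_normal((d, d)))
L2 = np.tril(rng.standard_normal((d, d)))
tril_closed = np.allclose(L1 @ L2, np.tril(L1 @ L2))

# Tri-diagonal matrices are not: their product is generically penta-diagonal.
band = np.abs(np.subtract.outer(np.arange(d), np.arange(d))) <= 1
T1 = rng.standard_normal((d, d)) * band
T2 = rng.standard_normal((d, d)) * band
tri_product_bandwidth = bandwidth(T1 @ T2)  # 2 rather than 1

# Upper-triangular Toeplitz matrices are closed as well.
P = upper_toeplitz(rng.standard_normal(d)) @ upper_toeplitz(rng.standard_normal(d))
toeplitz_closed = is_upper_toeplitz(P)
```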
to update at iteration , treating and as constants. Given a subspace in the matrix logarithm space, the subspace projection map is specified by satisfying the local orthonormalization condition of the Fisher block regarding :
with the variational Gaussian with mean , precision and the set of symmetric square real matrices. Similarly, we can obtain for .
We consider several sparsities and block extensions of triangular matrices illustrated in Figure 5. E.g., the subspace projection map for a diagonal structure simply extracts diagonal entries of its input. As a non-trivial example, the subspace projection map for a lower-triangular structure extracts lower-triangular entries of its input and multiplies the entries below the main diagonal by 2. Table 2 summarizes structures and their projection maps mathematically.
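The two projection maps described above translate directly into NumPy. This is a minimal sketch for dense symmetric inputs; an efficient implementation would store only the non-zero entries. A side observation from the example, not a claim from the text: for a symmetric input, the symmetric part of the lower-triangular projection recovers the input.

```python
import numpy as np

def project_diag(M):
    """Diagonal structure: keep only the diagonal entries."""
    return np.diag(np.diag(M))

def project_tril(M):
    """Lower-triangular structure: keep the lower triangle, doubling entries below the diagonal."""
    return 2.0 * np.tril(M, k=-1) + np.diag(np.diag(M))

M = np.array([[1.0, 2.0, 3.0],
              [2.0, 4.0, 5.0],
              [3.0, 5.0, 6.0]])  # symmetric input, as for a log-space matrix

P_diag = project_diag(M)  # diag(1, 4, 6)
P_tril = project_tril(M)  # [[1, 0, 0], [4, 4, 0], [6, 10, 6]]
```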
Using such a subspace and its projection map, we obtain a structured INGD update (Figure 4), and similarly for IKFAC. Our approach allows using more expressive structures than the block-diagonal structure shown in Figure 5, e.g., low-rank, flexible hierarchical, and Toeplitz structures, while existing methods mainly support low-rank structures. For an efficient implementation, we only compute and store the non-zero entries of the Kronecker factors without explicitly forming dense matrices. These structures lower not only the memory consumption (Table 4) but also the iteration cost (Table 3).
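The memory savings follow from simple counting. The hypothetical helper below counts stored non-zeros for one d × d Kronecker factor, ignoring implementation overhead; hierarchical and low-rank structures are omitted because their sizes depend on additional parameters.

```python
def stored_entries(d, structure, k=1):
    """Number of stored entries for a d x d Kronecker factor (non-zeros only)."""
    if structure == "dense":
        return d * d
    if structure == "diagonal":
        return d
    if structure == "block_diagonal":
        if d % k != 0:
            raise ValueError("block size k must divide d")
        return (d // k) * k * k  # d / k blocks of size k x k
    if structure == "lower_triangular":
        return d * (d + 1) // 2
    if structure == "upper_toeplitz":
        return d  # one value per diagonal
    raise ValueError(f"unknown structure: {structure}")

d = 1024
for name, k in [("dense", 1), ("diagonal", 1), ("block_diagonal", 32),
                ("lower_triangular", 1), ("upper_toeplitz", 1)]:
    n = stored_entries(d, name, k)
    # 2 bytes per entry in a 16-bit format such as BFP-16
    print(f"{name:>17}: {n:>9} entries, {2 * n / 2**20:.3f} MiB in 16-bit")
```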
4 Experiments
We evaluate SINGD on convolutional, transformer, and graph NNs, using mixed-precision training in BFP-16 with KFAC-reduce (Eschenhagen et al., 2023) and numerical tricks (Dangel, 2023) to further reduce memory consumption and iteration cost for convolutions. The performance metric is test error. To be memory-efficient, we consider SINGD with sparse structures such as ‘diagonal’, ‘block-diagonal’, and ‘hierarchical’. We also consider IKFAC, INGD (recall that SINGD with a dense structure becomes INGD), and AdamW as baselines. All methods except KFAC directly support training in BFP-16. For KFAC, we have to cast a matrix to FP-32 and then cast its inverse back to BFP-16. We find that KFAC performs unstably in BFP-16. For ‘VGG’ and ‘ConvMixer’, we also consider SGD as a strong baseline. We fix momentum to 0.9 and tune the other hyper-parameters of each optimizer using random search. For ‘VGG’ and ‘ConvMixer’, we decrease the learning rate every 40 epochs. For ‘GNN’, we use a constant learning rate; all other models use a cosine learning rate schedule. We consider KFAC as a strong baseline for the GNN, as suggested by Izadi et al. (2020). We train the GNN in FP-32 so that KFAC performs stably. The search space for the random search can be found in Table 5 in Appendix B.
From Figures 6 and 7, we observe that SINGD, including IKFAC and INGD as special cases, outperforms AdamW in many cases. SINGD works well for mixed-precision training. We do not show KFAC in the plots, as it performs unstably due to numerical issues. We also observe that the hierarchical structure often performs as well as the dense structure (INGD) across all models. In several cases, the hierarchical structure outperforms the block-diagonal and diagonal structures. However, on the models shown in Figure 7, even the diagonal structure can perform as well as the dense one. Thus, we can reduce INGD's memory consumption and make SINGD as competitive as AdamW. We also train a ViT model on “ImageNet-100” to demonstrate the superior performance of SINGD over AdamW in large-scale settings (see Figure 9 in Appendix B).
5 Conclusion
We propose an inverse-free, memory-efficient natural gradient descent method—SINGD—which addresses the numerical instability and memory inefficiency of second-order methods like KFAC (Martens & Grosse, 2015). The algorithm is an extension of the inverse-free natural gradient (INGD) method from Lin et al. (2023), whose update relies only on matrix multiplications. We theoretically establish the algorithm’s relation to KFAC by showing that a modification of INGD effectively performs KFAC-like updates and further improve its memory efficiency through sparse Kronecker factors. We showed that SINGD supports low-precision training and often outperforms AdamW on transformer-based models. Our work expands the scope of second-order methods to training transformer-based NNs and in low precision, making them more widely applicable.
Acknowledgements
Resources used in preparing this research were provided, in part, by the Province of Ontario, the Government of Canada through CIFAR, and companies sponsoring Vector Institute. Runa Eschenhagen is supported by ARM and the Cambridge Trust. Richard E. Turner is supported by Google, Amazon, ARM, Improbable and EPSRC grant EP/T005386/1.
Impact Statement
This paper presents work whose goal is to advance the field of Machine Learning. There are many potential societal consequences of our work, none of which we feel must be specifically highlighted here.
References
- Amari (1998) Amari, S.-I. Natural gradient works efficiently in learning. Neural computation, 10(2):251–276, 1998.
- Bae et al. (2022) Bae, J., Ng, N., Lo, A., Ghassemi, M., and Grosse, R. B. If influence functions are the answer, then what is the question? In NeurIPS, 2022.
- Botev et al. (2017) Botev, A., Ritter, H., and Barber, D. Practical Gauss-Newton optimisation for deep learning. In ICML, 2017.
- Bradbury et al. (2018) Bradbury, J., Frostig, R., Hawkins, P., Johnson, M. J., Leary, C., Maclaurin, D., Necula, G., Paszke, A., VanderPlas, J., Wanderman-Milne, S., and Zhang, Q. JAX: composable transformations of Python+NumPy programs, 2018. URL http://github.com/google/jax.
- Brown et al. (2020) Brown, T., Mann, B., Ryder, N., Subbiah, M., Kaplan, J. D., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., et al. Language models are few-shot learners. In NeurIPS, 2020.
- Dangel (2023) Dangel, F. Convolutions through the lens of tensor networks. arXiv 2307.02275, 2023.
- Daxberger et al. (2021) Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., and Hennig, P. Laplace redux—effortless Bayesian deep learning. In NeurIPS, 2021.
- Dehghani et al. (2023) Dehghani, M., Djolonga, J., Mustafa, B., Padlewski, P., Heek, J., Gilmer, J., Steiner, A. P., Caron, M., Geirhos, R., Alabdulmohsin, I., et al. Scaling vision transformers to 22 billion parameters. In ICML, 2023.
- Eschenhagen et al. (2023) Eschenhagen, R., Immer, A., Turner, R. E., Schneider, F., and Hennig, P. Kronecker-Factored Approximate Curvature for modern neural network architectures. In NeurIPS, 2023.
- George et al. (2018) George, T., Laurent, C., Bouthillier, X., Ballas, N., and Vincent, P. Fast approximate natural gradient descent in a kronecker factored eigenbasis. In NeurIPS, 2018.
- Graves (2011) Graves, A. Practical variational inference for neural networks. In NeurIPS, 2011.
- Grosse & Martens (2016) Grosse, R. and Martens, J. A kronecker-factored approximate fisher matrix for convolution layers. In ICML, 2016.
- Grosse et al. (2023) Grosse, R., Bae, J., Anil, C., Elhage, N., Tamkin, A., Tajdini, A., Steiner, B., Li, D., Durmus, E., Perez, E., et al. Studying large language model generalization with influence functions. arXiv preprint arXiv:2308.03296, 2023.
- Hassani et al. (2021) Hassani, A., Walton, S., Shah, N., Abuduweili, A., Li, J., and Shi, H. Escaping the big data paradigm with compact transformers. arXiv preprint arXiv:2104.05704, 2021.
- Hatamizadeh et al. (2023) Hatamizadeh, A., Yin, H., Heinrich, G., Kautz, J., and Molchanov, P. Global context vision transformers. In International Conference on Machine Learning, pp. 12633–12646. PMLR, 2023.
- Heskes (2000) Heskes, T. On “natural” learning and pruning in multilayered perceptrons. Neural Computation, 12(4), 2000.
- Immer et al. (2021) Immer, A., Bauer, M., Fortuin, V., Rätsch, G., and Emtiyaz, K. M. Scalable marginal likelihood estimation for model selection in deep learning. In ICML, 2021.
- Izadi et al. (2020) Izadi, M. R., Fang, Y., Stevenson, R., and Lin, L. Optimization of graph neural networks with natural gradient descent. In 2020 IEEE international conference on big data (big data), pp. 171–179. IEEE, 2020.
- Khan & Lin (2017) Khan, M. and Lin, W. Conjugate-computation variational inference: Converting variational inference in non-conjugate models to inferences in conjugate models. In Artificial Intelligence and Statistics, pp. 878–887, 2017.
- Khan & Nielsen (2018) Khan, M. E. and Nielsen, D. Fast yet Simple Natural-Gradient Descent for Variational Inference in Complex Models. arXiv preprint arXiv:1807.04489, 2018.
- Khan & Rue (2021) Khan, M. E. and Rue, H. The bayesian learning rule. arXiv preprint arXiv:2107.04562, 2021.
- Khan et al. (2018) Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. In ICML, 2018.
- Kingma & Ba (2015) Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2015.
- Kipf & Welling (2016) Kipf, T. N. and Welling, M. Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907, 2016.
- Kunstner et al. (2019) Kunstner, F., Balles, L., and Hennig, P. Limitations of the empirical Fisher approximation for natural gradient descent. In NeurIPS, 2019.
- Li (2022) Li, X. Black box Lie group preconditioners for SGD. arXiv preprint arXiv:2211.04422, 2022.
- Li (2018) Li, X.-L. Preconditioner on matrix Lie group for SGD. In International Conference on Learning Representations, 2018.
- Lin et al. (2020) Lin, W., Schmidt, M., and Khan, M. E. Handling the positive-definite constraint in the Bayesian learning rule. In ICML, 2020.
- Lin et al. (2021) Lin, W., Nielsen, F., Khan, M. E., and Schmidt, M. Tractable structured natural-gradient descent using local parameterizations. In ICML, 2021.
- Lin et al. (2023) Lin, W., Duruisseaux, V., Leok, M., Nielsen, F., Khan, M. E., and Schmidt, M. Simplifying momentum-based positive-definite submanifold optimization with applications to deep learning. In ICML, 2023.
- Liu et al. (2021) Liu, Z., Lin, Y., Cao, Y., Hu, H., Wei, Y., Zhang, Z., Lin, S., and Guo, B. Swin transformer: Hierarchical vision transformer using shifted windows. In Proceedings of the IEEE/CVF international conference on computer vision, pp. 10012–10022, 2021.
- Loshchilov & Hutter (2019) Loshchilov, I. and Hutter, F. Decoupled weight decay regularization. In ICLR, 2019.
- Lu et al. (2022) Lu, Z., Xie, H., Liu, C., and Zhang, Y. Bridging the gap between vision transformers and convolutional neural networks on small datasets. Advances in Neural Information Processing Systems, 35:14663–14677, 2022.
- Martens (2014) Martens, J. New insights and perspectives on the natural gradient method. JMLR, 21(146), 2014.
- Martens & Grosse (2015) Martens, J. and Grosse, R. Optimizing neural networks with Kronecker-factored approximate curvature. In ICML, 2015.
- Martens et al. (2018) Martens, J., Ba, J., and Johnson, M. Kronecker-factored curvature approximations for recurrent neural networks. In ICLR, 2018.
- Micikevicius et al. (2018) Micikevicius, P., Narang, S., Alben, J., Diamos, G., Elsen, E., Garcia, D., Ginsburg, B., Houston, M., Kuchaiev, O., Venkatesh, G., and Wu, H. Mixed precision training. In International Conference on Learning Representations (ICLR), 2018.
- Osawa et al. (2019) Osawa, K., Swaroop, S., Khan, M. E., Jain, A., Eschenhagen, R., Turner, R. E., and Yokota, R. Practical deep learning with Bayesian principles. In NeurIPS, 2019.
- Osawa et al. (2023) Osawa, K., Li, S., and Hoefler, T. PipeFisher: Efficient training of large language models using pipelining and Fisher information matrices. In MLSys, 2023.
- Osborne (1992) Osborne, M. R. Fisher’s method of scoring. International Statistical Review/Revue Internationale de Statistique, pp. 99–117, 1992.
- Paszke et al. (2019) Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L., et al. PyTorch: An imperative style, high-performance deep learning library. In NeurIPS, 2019.
- Radford et al. (2019) Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., Sutskever, I., et al. Language models are unsupervised multitask learners. OpenAI blog, 1(8):9, 2019.
- Robbins & Monro (1951) Robbins, H. and Monro, S. A Stochastic Approximation Method. The Annals of Mathematical Statistics, 1951.
- Schraudolph (2002) Schraudolph, N. N. Fast curvature matrix-vector products for second-order gradient descent. Neural Computation, 14(7), 2002.
- Simonyan & Zisserman (2014) Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.
- Smyth (1996) Smyth, G. K. Partitioned algorithms for maximum likelihood and other non-linear estimation. Statistics and Computing, 6:201–216, 1996.
- Smyth (2015) Smyth, G. K. Optimization and nonlinear equations. Statistics reference online, 1:1–9, 2015.
- Tan (2022) Tan, L. S. Analytic natural gradient updates for Cholesky factor in Gaussian variational approximation. arXiv preprint arXiv:2109.00375, 2022.
- Thompson et al. (2020) Thompson, N. C., Greenewald, K., Lee, K., and Manso, G. F. The computational limits of deep learning. 2020.
- Touvron et al. (2023) Touvron, H., Lavril, T., Izacard, G., Martinet, X., Lachaux, M.-A., Lacroix, T., Rozière, B., Goyal, N., Hambro, E., Azhar, F., et al. LLaMA: Open and efficient foundation language models. arXiv preprint arXiv:2302.13971, 2023.
- Trockman & Kolter (2023) Trockman, A. and Kolter, J. Z. Patches are all you need? Transactions on Machine Learning Research, 2023.
- Vaswani et al. (2017) Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., and Polosukhin, I. Attention is all you need. In NIPS, 2017.
- Wang et al. (2023) Wang, A., Chen, H., Lin, Z., Pu, H., and Ding, G. RepViT: Revisiting mobile CNN from ViT perspective. arXiv preprint arXiv:2307.09283, 2023.
- Wang et al. (2019) Wang, C., Grosse, R., Fidler, S., and Zhang, G. EigenDamage: Structured pruning in the Kronecker-factored eigenbasis. In ICML, 2019.
- Wang (2010) Wang, Y. Fisher scoring: An interpolation family and its Monte Carlo implementations. Comput. Stat. Data Anal., 54(7), 2010.
- Zhang et al. (2018) Zhang, G., Sun, S., Duvenaud, D., and Grosse, R. Noisy natural gradient as variational inference. In ICML, 2018.
- Zhang et al. (2019) Zhang, G., Li, L., Nado, Z., Martens, J., Sachdeva, S., Dahl, G. E., Shallue, C. J., and Grosse, R. B. Which algorithmic choices matter at which batch sizes? Insights from a noisy quadratic model. In NeurIPS, 2019.
Appendix A Space and Time Complexity
Table: Iteration cost of KFAC, INGD/SINGD (dense), SINGD with block-diagonal, Toeplitz, rank-1 triangular, and hierarchical structures, and AdamW.
Table: Memory usage of KFAC, INGD/SINGD (dense), SINGD with block-diagonal, Toeplitz, rank-1 triangular, and hierarchical structures, and AdamW.
Appendix B Details of the Experiments
To demonstrate the robustness and memory efficiency of our method, we consider image classification tasks with transformer-based models such as “Compact-ViT” (Hassani et al., 2021), “Swin-ViT” (Liu et al., 2021), “GC-ViT” (Hatamizadeh et al., 2023), and “HDVT” (Lu et al., 2022). We also consider convolution-based models such as “VGG” (Simonyan & Zisserman, 2014), “ConvMixer” (Trockman & Kolter, 2023), and “Rep-ViT” (Wang et al., 2023). We train these models on the datasets “CIFAR-100” and “ImageWoof-10”. Note that “Rep-ViT” is a CNN model inspired by transformers, while “Compact-ViT” is a data-efficient transformer that uses convolutional tokenization. We also consider a graph convolutional model (Kipf & Welling, 2016), denoted by “GNN”, for node classification on the dataset “Cora”. Finally, we train a ViT model on “ImageNet-100” (https://www.kaggle.com/datasets/ambityga/imagenet100) to demonstrate the performance of SINGD in large-scale settings (see Fig. 9).
B.1 Hyper-parameter Tuning
| Hyper-parameter | | |
|---|---|---|
| Standard stepsize | Tuned | Tuned |
| Standard momentum weight | 0.9 | 0.9 |
| (L2) weight decay | Tuned | Tuned |
| Damping | Tuned | Tuned |
| Stepsize for preconditioner | Tuned | Tuned |
| Riemannian momentum | Tuned (SINGD only) | NA |
| Method | Peak memory [GiB] | Training time [min] |
|---|---|---|
| SGD (BFP-16) | 15.6 (1.00×) | 190 (1.00×) |
| AdamW (BFP-16) | 15.7 (1.00×) | 191 (1.01×) |
| SINGD-Diag* (BFP-16) | 15.8 (1.02×) | 200 (1.06×) |
| IKFAC* (BFP-16) | 16.0 (1.02×) | 197 (1.04×) |
| INGD (BFP-16) | 16.0 (1.02×) | 203 (1.07×) |
| KFAC (FP-32) | 16.0 (1.02×) | 359 (1.89×) |
Appendix C Connection between IKFAC and KFAC
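To relate our method to KFAC, we now show that $\mathbf{K}\mathbf{K}^\top$ is an approximation of $\mathbf{S}_K^{-1}$ at a new step of our scheme. For simplicity, we first assume that $\mathbf{K}\mathbf{K}^\top$ exactly equals $\mathbf{S}_K^{-1}$ at the current step. Later, we will relax this assumption and prove that $\mathbf{K}\mathbf{K}^\top$ is an approximation of $\mathbf{S}_K^{-1}$ at every step, as stated in Theorem 1. For notational simplicity, we denote $\hat{\mathbf{H}}_K := \mathbf{U} + \lambda\mathbf{I}$. The update of $\mathbf{S}_K$ with damping $\lambda$ can be re-expressed as an update of $\mathbf{S}_K^{-1}$:
$$\mathbf{S}_K^{-1} \leftarrow \big((1-\beta_1)\mathbf{S}_K + \beta_1\hat{\mathbf{H}}_K\big)^{-1}.$$
Since $\mathbf{S}_K^{-1} = \mathbf{K}\mathbf{K}^\top$ by our assumption, we can express the update of $\mathbf{S}_K^{-1}$ in terms of $\mathbf{K}$ as follows:
$$\mathbf{S}_K^{-1} \leftarrow \big((1-\beta_1)\mathbf{K}^{-\top}\mathbf{K}^{-1} + \beta_1\hat{\mathbf{H}}_K\big)^{-1} = \mathbf{K}\big(\mathbf{I} + \beta_1\mathbf{m}_K\big)^{-1}\mathbf{K}^\top, \qquad \mathbf{m}_K := \mathbf{K}^\top\hat{\mathbf{H}}_K\mathbf{K} - \mathbf{I} = \mathbf{K}^\top\mathbf{U}\mathbf{K} + \lambda\mathbf{K}^\top\mathbf{K} - \mathbf{I}.$$
The factor $(\mathbf{I} + \beta_1\mathbf{m}_K)^{-1}$ in the KFAC update can be approximated as below. Here we treat $\mathbf{I} - \beta_1\mathbf{m}_K$ as a first-order approximation of the matrix exponential and use the fact that $\mathbf{m}_K$ is symmetric:
$$(\mathbf{I} + \beta_1\mathbf{m}_K)^{-1} \approx \mathbf{I} - \beta_1\mathbf{m}_K \approx \operatorname{Exp}(-\beta_1\mathbf{m}_K) = \operatorname{Exp}\big(-\tfrac{\beta_1}{2}\mathbf{m}_K\big)\operatorname{Exp}\big(-\tfrac{\beta_1}{2}\mathbf{m}_K\big)^\top \approx \big(\mathbf{I} - \tfrac{\beta_1}{2}\mathbf{m}_K\big)\big(\mathbf{I} - \tfrac{\beta_1}{2}\mathbf{m}_K\big)^\top.$$
Hence, after the inverse-free update $\mathbf{K} \leftarrow \mathbf{K}(\mathbf{I} - \frac{\beta_1}{2}\mathbf{m}_K)$, the new $\mathbf{K}\mathbf{K}^\top$ agrees with the new $\mathbf{S}_K^{-1}$ up to $O(\beta_1^2)$.
Informally, $\mathbf{K}\mathbf{K}^\top$ approximates $\mathbf{S}_K^{-1}$ through the matrix exponential, so $\mathbf{m}_K$ can be viewed as living in a matrix logarithm space.
Theorem 1 formally shows that the $\mathbf{K}\mathbf{K}^\top$ used in our update approximates the $\mathbf{S}_K^{-1}$ of the KFAC update at every step. This holds even when the matrix exponential is truncated.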
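The first-order argument above can be checked numerically. The sketch below is our own illustration, not the authors' implementation. It assumes NumPy, random symmetric positive-definite matrices, the damped KFAC update $\mathbf{S}_K \leftarrow (1-\beta_1)\mathbf{S}_K + \beta_1(\mathbf{U} + \lambda\mathbf{I})$, and the truncated inverse-free update $\mathbf{K} \leftarrow \mathbf{K}(\mathbf{I} - \frac{\beta_1}{2}\mathbf{m}_K)$ with $\mathbf{m}_K = \mathbf{K}^\top\mathbf{U}\mathbf{K} + \lambda\mathbf{K}^\top\mathbf{K} - \mathbf{I}$. Starting from $\mathbf{K}\mathbf{K}^\top = \mathbf{S}_K^{-1}$, one step of each method should disagree only at order $\beta_1^2$. The helper names `kfac_step`, `ikfac_step`, and `one_step_gap` are hypothetical.

```python
import numpy as np

# Illustrative sketch (hypothetical helper names), not the authors' implementation.


def kfac_step(S, U, beta1, lam):
    """Damped KFAC moving-average update of one Kronecker factor."""
    return (1.0 - beta1) * S + beta1 * (U + lam * np.eye(S.shape[0]))


def ikfac_step(K, U, beta1, lam):
    """Inverse-free update K <- K (I - beta1/2 * m_K) with a truncated exponential."""
    I = np.eye(K.shape[0])
    m = K.T @ U @ K + lam * K.T @ K - I
    return K @ (I - 0.5 * beta1 * m)


def one_step_gap(beta1, d=5, lam=1e-2, seed=0):
    """Frobenius distance between K_new K_new^T and S_new^{-1} after one step."""
    rng = np.random.default_rng(seed)
    A = rng.standard_normal((d, 2 * d))
    U = A @ A.T / (2 * d)                       # random SPD curvature matrix
    B = rng.standard_normal((d, d))
    S = B @ B.T / d + np.eye(d)                 # random SPD KFAC factor
    K = np.linalg.cholesky(np.linalg.inv(S))    # so that K K^T = S^{-1}
    S_new = kfac_step(S, U, beta1, lam)
    K_new = ikfac_step(K, U, beta1, lam)
    return np.linalg.norm(K_new @ K_new.T - np.linalg.inv(S_new))


for beta1 in (1e-1, 1e-2, 1e-3):
    print(f"beta1={beta1:g}  gap={one_step_gap(beta1):.2e}")
```

Each tenfold reduction of $\beta_1$ should shrink the gap by roughly a factor of 100. The leading term follows from comparing $(\mathbf{I}+\beta_1\mathbf{m}_K)^{-1} = \mathbf{I} - \beta_1\mathbf{m}_K + \beta_1^2\mathbf{m}_K^2 - \dots$ with $(\mathbf{I} - \frac{\beta_1}{2}\mathbf{m}_K)^2 = \mathbf{I} - \beta_1\mathbf{m}_K + \frac{\beta_1^2}{4}\mathbf{m}_K^2$. Their difference is $\frac{3}{4}\beta_1^2\mathbf{m}_K^2$, so the gap behaves like $\frac{3}{4}\beta_1^2\|\mathbf{K}\mathbf{m}_K^2\mathbf{K}^\top\|$.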
Appendix D Proof of Theorem 1
We first consider the following lemmas in order to prove Theorem 1.
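Recall that we denote $\hat{\mathbf{H}}_K := \mathbf{U} + \lambda\mathbf{I}$. For notational simplicity, we drop the subscript $K$ in this section. We use $\hat{\mathbf{H}}_t$, $\mathbf{S}_t$, and $\mathbf{K}_t$ to denote $\hat{\mathbf{H}}_K$, $\mathbf{S}_K$, and $\mathbf{K}$ at iteration $t$. Notice that $\mathbf{S}_t$ is non-singular at each iteration, so that we can invert it in the original KFAC update (see Figure 4).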
Lemma D.1.
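Consider the following update in the original KFAC update at iteration $t$:
$$\mathbf{S}_{t+1} := (1-\beta_1)\mathbf{S}_t + \beta_1\hat{\mathbf{H}}_t,$$
where $\mathbf{S}_t$ is the factor used in the original KFAC update, $\beta_1$ is known as the weight of the moving average, and $\hat{\mathbf{H}}_t$ is a curvature matrix.
The initial factor can be decomposed as $\mathbf{S}_0 = \mathbf{K}_0^{-\top}\mathbf{K}_0^{-1}$, since $\mathbf{S}_0$, as a preconditioning factor, is symmetric positive definite.
Define $\mathbf{N}_t := \mathbf{K}_0^\top\hat{\mathbf{H}}_t\mathbf{K}_0 - \mathbf{I}$.
The Kronecker factor $\mathbf{S}_{t+1}$ can be re-expressed as
$$\mathbf{S}_{t+1} = \mathbf{K}_0^{-\top}\Big(\mathbf{I} + \beta_1\sum_{i=0}^{t}(1-\beta_1)^{t-i}\mathbf{N}_i\Big)\mathbf{K}_0^{-1}.$$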
Lemma D.2.
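Consider the following update in our inverse-free KFAC at iteration $t$:
$$\mathbf{K}_{t+1} := \mathbf{K}_t\big(\mathbf{I} - \tfrac{\beta_1}{2}\mathbf{m}_t\big),$$
where $\mathbf{K}_t$ is used in our update and $\hat{\mathbf{H}}_t$ is a curvature matrix.
Define $\mathbf{m}_t := \mathbf{K}_t^\top\hat{\mathbf{H}}_t\mathbf{K}_t - \mathbf{I}$.
Our update of $\mathbf{K}$ can be re-expressed as
$$\mathbf{K}_{t+1} = \mathbf{K}_0\Big(\mathbf{I} - \frac{\beta_1}{2}\sum_{i=0}^{t}\mathbf{m}_i + O(\beta_1^2)\Big).$$
Moreover, the product $\mathbf{K}_{t+1}\mathbf{K}_{t+1}^\top$ can be re-expressed as
$$\mathbf{K}_{t+1}\mathbf{K}_{t+1}^\top = \mathbf{K}_0\Big(\mathbf{I} - \beta_1\sum_{i=0}^{t}\mathbf{m}_i\Big)\mathbf{K}_0^\top + O(\beta_1^2).$$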
Lemma D.3 is useful to establish a relationship between the KFAC update and our inverse-free update.
Lemma D.3.
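Suppose both the original KFAC update and our update use the same curvature matrix $\hat{\mathbf{H}}_t$ at each iteration $t$, and that the initializations satisfy $\mathbf{S}_0 = \mathbf{K}_0^{-\top}\mathbf{K}_0^{-1}$. Then we have the following expression:
$$\mathbf{m}_t = \mathbf{N}_t + O(\beta_1).$$
Similarly, we have the following result for the weighted sum in Lemma D.1:
$$\beta_1\sum_{i=0}^{t}(1-\beta_1)^{t-i}\mathbf{N}_i = \beta_1\sum_{i=0}^{t}\mathbf{m}_i + O(\beta_1^2).$$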
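Theorem 1 (restated).
Suppose $\mathbf{K}$ is updated according to Figure 4 with the truncation of the matrix exponential, and both updates use the same initialization and the same sequence of curvature matrices $\hat{\mathbf{H}}_t$. Then the product $\mathbf{K}_t\mathbf{K}_t^\top$ is a first-order accurate approximation of the KFAC quantity $\mathbf{S}_t^{-1}$ at each iteration, i.e., $\mathbf{K}_t\mathbf{K}_t^\top = \mathbf{S}_t^{-1} + O(\beta_1^2)$.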
D.1 Proof of Lemma D.1
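We prove the lemma by induction. We first show the base case $t = 0$. By definition, we have
$$\begin{aligned}
\mathbf{S}_1 &= (1-\beta_1)\mathbf{S}_0 + \beta_1\hat{\mathbf{H}}_0 && (10)\\
&= (1-\beta_1)\mathbf{K}_0^{-\top}\mathbf{K}_0^{-1} + \beta_1\mathbf{K}_0^{-\top}\mathbf{K}_0^\top\hat{\mathbf{H}}_0\mathbf{K}_0\mathbf{K}_0^{-1} && (11)\\
&= \mathbf{K}_0^{-\top}\big((1-\beta_1)\mathbf{I} + \beta_1\mathbf{K}_0^\top\hat{\mathbf{H}}_0\mathbf{K}_0\big)\mathbf{K}_0^{-1} && (12)\\
&= \mathbf{K}_0^{-\top}\big(\mathbf{I} + \beta_1\mathbf{N}_0\big)\mathbf{K}_0^{-1}. && (13)
\end{aligned}$$
Thus, the claim holds when $t = 0$.
Suppose the claim holds for $t-1$, that is,
$$\mathbf{S}_t = \mathbf{K}_0^{-\top}\Big(\mathbf{I} + \beta_1\sum_{i=0}^{t-1}(1-\beta_1)^{t-1-i}\mathbf{N}_i\Big)\mathbf{K}_0^{-1}. \tag{14}$$
Now, we consider the case $t$. Notice that $\hat{\mathbf{H}}_t = \mathbf{K}_0^{-\top}(\mathbf{I} + \mathbf{N}_t)\mathbf{K}_0^{-1}$ by the definition of $\mathbf{N}_t$.
By the definition of $\mathbf{S}_{t+1}$, we have
$$\begin{aligned}
\mathbf{S}_{t+1} &= (1-\beta_1)\mathbf{S}_t + \beta_1\hat{\mathbf{H}}_t && (15)\\
&= \mathbf{K}_0^{-\top}\Big((1-\beta_1)\mathbf{I} + \beta_1\sum_{i=0}^{t-1}(1-\beta_1)^{t-i}\mathbf{N}_i + \beta_1\mathbf{I} + \beta_1\mathbf{N}_t\Big)\mathbf{K}_0^{-1} && (16)\\
&= \mathbf{K}_0^{-\top}\Big(\mathbf{I} + \beta_1\sum_{i=0}^{t}(1-\beta_1)^{t-i}\mathbf{N}_i\Big)\mathbf{K}_0^{-1}, && (17)
\end{aligned}$$
which is exactly the claim for $t$.
Thus, by induction, the claim holds.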
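Unlike the truncated inverse-free update, the closed form for KFAC involves no approximation, so it can be verified to machine precision. The sketch below is an illustrative check under our own assumptions: NumPy, random SPD matrices, and damping $\lambda = 0.1$ folded into $\hat{\mathbf{H}}_t$. It iterates $\mathbf{S}_{t+1} = (1-\beta_1)\mathbf{S}_t + \beta_1\hat{\mathbf{H}}_t$ from $\mathbf{S}_0 = \mathbf{K}_0^{-\top}\mathbf{K}_0^{-1}$. Each iterate is compared with $\mathbf{K}_0^{-\top}\big(\mathbf{I} + \beta_1\sum_{i=0}^{t}(1-\beta_1)^{t-i}\mathbf{N}_i\big)\mathbf{K}_0^{-1}$, where $\mathbf{N}_i = \mathbf{K}_0^\top\hat{\mathbf{H}}_i\mathbf{K}_0 - \mathbf{I}$. The function name `kfac_closed_form_gap` is hypothetical.

```python
import numpy as np

# Illustrative sketch (hypothetical function name), not the authors' implementation.


def kfac_closed_form_gap(T=6, d=4, beta1=0.3, lam=0.1, seed=1):
    """Max distance between iterated KFAC factors and the closed form."""
    rng = np.random.default_rng(seed)
    I = np.eye(d)
    B = rng.standard_normal((d, d))
    S0 = B @ B.T + I                                # SPD initial factor S_0
    K0 = np.linalg.inv(np.linalg.cholesky(S0)).T    # S_0 = K0^{-T} K0^{-1}
    K0_inv = np.linalg.inv(K0)

    S = S0.copy()
    weighted_sum = np.zeros((d, d))                 # sum_i (1 - beta1)^(t - i) N_i
    max_gap = 0.0
    for _ in range(T):
        A = rng.standard_normal((d, 2 * d))
        H = A @ A.T / (2 * d) + lam * I             # damped curvature H_t = U_t + lam I
        S = (1.0 - beta1) * S + beta1 * H           # KFAC update: S_{t+1}
        N = K0.T @ H @ K0 - I                       # N_t
        weighted_sum = (1.0 - beta1) * weighted_sum + N
        closed_form = K0_inv.T @ (I + beta1 * weighted_sum) @ K0_inv
        max_gap = max(max_gap, np.linalg.norm(S - closed_form))
    return max_gap


print(f"max gap: {kfac_closed_form_gap():.1e}")
```

The running sum uses the recursion $\sum_{i\le t}(1-\beta_1)^{t-i}\mathbf{N}_i = (1-\beta_1)\sum_{i\le t-1}(1-\beta_1)^{t-1-i}\mathbf{N}_i + \mathbf{N}_t$. This is the same step used in the induction above.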
D.2 Proof of Lemma D.2
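We prove the lemma by induction. We first show the base case $t = 0$. By definition, we have
$$\mathbf{K}_1 = \mathbf{K}_0\big(\mathbf{I} - \tfrac{\beta_1}{2}\mathbf{m}_0\big). \tag{18}$$
Thus, the claim holds when $t = 0$.
Suppose the claim holds for $t-1$, that is,
$$\mathbf{K}_t = \mathbf{K}_0\Big(\mathbf{I} - \frac{\beta_1}{2}\sum_{i=0}^{t-1}\mathbf{m}_i + O(\beta_1^2)\Big). \tag{19}$$
Now, we consider the case $t$. Notice that
$$\begin{aligned}
\mathbf{K}_{t+1} &= \mathbf{K}_t\big(\mathbf{I} - \tfrac{\beta_1}{2}\mathbf{m}_t\big) && (20)\\
&= \mathbf{K}_0\Big(\mathbf{I} - \frac{\beta_1}{2}\sum_{i=0}^{t-1}\mathbf{m}_i + O(\beta_1^2)\Big)\big(\mathbf{I} - \tfrac{\beta_1}{2}\mathbf{m}_t\big) && (21)\\
&= \mathbf{K}_0\Big(\mathbf{I} - \frac{\beta_1}{2}\sum_{i=0}^{t-1}\mathbf{m}_i - \frac{\beta_1}{2}\mathbf{m}_t + O(\beta_1^2)\Big) && (22)\\
&= \mathbf{K}_0\Big(\mathbf{I} - \frac{\beta_1}{2}\sum_{i=0}^{t}\mathbf{m}_i + O(\beta_1^2)\Big), && (23)
\end{aligned}$$
which is exactly the claim for $t$.
Thus, by induction, the claim holds.
Notice that $\mathbf{m}_i$ is symmetric by definition, since $\hat{\mathbf{H}}_i$ is symmetric. It is easy to see that
$$\begin{aligned}
\mathbf{K}_{t+1}\mathbf{K}_{t+1}^\top &= \mathbf{K}_0\Big(\mathbf{I} - \frac{\beta_1}{2}\sum_{i=0}^{t}\mathbf{m}_i + O(\beta_1^2)\Big)\Big(\mathbf{I} - \frac{\beta_1}{2}\sum_{i=0}^{t}\mathbf{m}_i + O(\beta_1^2)\Big)^\top\mathbf{K}_0^\top && (24)\\
&= \mathbf{K}_0\Big(\mathbf{I} - \frac{\beta_1}{2}\sum_{i=0}^{t}\mathbf{m}_i\Big)^2\mathbf{K}_0^\top + O(\beta_1^2) && (25)\\
&= \mathbf{K}_0\Big(\mathbf{I} - \beta_1\sum_{i=0}^{t}\mathbf{m}_i\Big)\mathbf{K}_0^\top + O(\beta_1^2). && (26)
\end{aligned}$$
Thus, the claim also holds.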
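The product expansion of the truncated inverse-free update holds only up to $O(\beta_1^2)$. A minimal numerical sketch can illustrate this. It rests on our own choices: NumPy, an orthogonal (hence well-conditioned) $\mathbf{K}_0$, and a fixed random sequence of damped curvature matrices $\hat{\mathbf{H}}_t$. The sketch runs $\mathbf{K}_{t+1} = \mathbf{K}_t(\mathbf{I} - \frac{\beta_1}{2}\mathbf{m}_t)$ with $\mathbf{m}_t = \mathbf{K}_t^\top\hat{\mathbf{H}}_t\mathbf{K}_t - \mathbf{I}$. It then measures how far $\mathbf{K}_{t+1}\mathbf{K}_{t+1}^\top$ is from $\mathbf{K}_0(\mathbf{I} - \beta_1\sum_{i=0}^{t}\mathbf{m}_i)\mathbf{K}_0^\top$. The gap should shrink roughly quadratically in $\beta_1$. The function name `ikfac_expansion_gap` is hypothetical.

```python
import numpy as np

# Illustrative sketch (hypothetical function name), not the authors' implementation.


def ikfac_expansion_gap(beta1, T=5, d=4, lam=0.1, seed=2):
    """Distance between K_T K_T^T and its first-order expansion."""
    rng = np.random.default_rng(seed)
    I = np.eye(d)
    K0, _ = np.linalg.qr(rng.standard_normal((d, d)))  # orthogonal K_0
    # Fixed sequence of damped curvature matrices, independent of beta1.
    Hs = []
    for _ in range(T):
        A = rng.standard_normal((d, 2 * d))
        Hs.append(A @ A.T / (2 * d) + lam * I)

    K = K0.copy()
    m_sum = np.zeros((d, d))
    for H in Hs:
        m = K.T @ H @ K - I                  # m_t, symmetric because H is symmetric
        K = K @ (I - 0.5 * beta1 * m)        # truncated inverse-free update
        m_sum += m
    first_order = K0 @ (I - beta1 * m_sum) @ K0.T
    return np.linalg.norm(K @ K.T - first_order)


for beta1 in (1e-2, 1e-3, 1e-4):
    print(f"beta1={beta1:g}  gap={ikfac_expansion_gap(beta1):.2e}")
```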
D.3 Proof of Lemma D.3
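We first show the base case $t = 0$. By the assumption, both updates start from the same $\mathbf{K}_0$ with $\mathbf{S}_0 = \mathbf{K}_0^{-\top}\mathbf{K}_0^{-1}$. Similarly, both updates use the same $\hat{\mathbf{H}}_0$ by the assumption.
By definition, we have
$$\begin{aligned}
\mathbf{m}_0 &= \mathbf{K}_0^\top\hat{\mathbf{H}}_0\mathbf{K}_0 - \mathbf{I} && (27)\\
&= \mathbf{N}_0, && (28)\\
\beta_1(1-\beta_1)^{0}\mathbf{N}_0 &= \beta_1\mathbf{m}_0. && (29)
\end{aligned}$$
Thus, the claim holds when $t = 0$.
When $t > 0$, we can use Lemma D.2 to obtain the claim. Notice that Lemma D.2 implies $\mathbf{K}_t = \mathbf{K}_0(\mathbf{I} + O(\beta_1))$, so that
$$\begin{aligned}
\mathbf{m}_t &= \mathbf{K}_t^\top\hat{\mathbf{H}}_t\mathbf{K}_t - \mathbf{I} && (30)\\
&= \big(\mathbf{I} + O(\beta_1)\big)^\top\mathbf{K}_0^\top\hat{\mathbf{H}}_t\mathbf{K}_0\big(\mathbf{I} + O(\beta_1)\big) - \mathbf{I} && (31)\\
&= \mathbf{K}_0^\top\hat{\mathbf{H}}_t\mathbf{K}_0 - \mathbf{I} + O(\beta_1) && (32)\\
&= \mathbf{N}_t + O(\beta_1). && (33)
\end{aligned}$$
Since $(1-\beta_1)^{t-i} = 1 + O(\beta_1)$ for a fixed $t$, it follows that
$$\beta_1\sum_{i=0}^{t}(1-\beta_1)^{t-i}\mathbf{N}_i = \beta_1\sum_{i=0}^{t}\big(1 + O(\beta_1)\big)\big(\mathbf{m}_i + O(\beta_1)\big) = \beta_1\sum_{i=0}^{t}\mathbf{m}_i + O(\beta_1^2). \tag{34}$$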
D.4 Proof of Theorem 1
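It is sufficient to show that the following claim holds at iteration $t$, since $\mathbf{S}_{t+1}$ is non-singular:
$$\mathbf{S}_{t+1}\mathbf{K}_{t+1}\mathbf{K}_{t+1}^\top = \mathbf{I} + O(\beta_1^2),$$
where we use $\mathbf{K}_{t+1}$ to denote $\mathbf{K}$ at iteration $t+1$. By Lemmas D.1 and D.3, $\mathbf{S}_{t+1} = \mathbf{K}_0^{-\top}\big(\mathbf{I} + \beta_1\sum_{i=0}^{t}\mathbf{m}_i + O(\beta_1^2)\big)\mathbf{K}_0^{-1}$. By Lemma D.2, $\mathbf{K}_{t+1}\mathbf{K}_{t+1}^\top = \mathbf{K}_0\big(\mathbf{I} - \beta_1\sum_{i=0}^{t}\mathbf{m}_i\big)\mathbf{K}_0^\top + O(\beta_1^2)$. Multiplying the two expressions gives
$$\mathbf{S}_{t+1}\mathbf{K}_{t+1}\mathbf{K}_{t+1}^\top = \mathbf{K}_0^{-\top}\Big(\mathbf{I} + \beta_1\sum_{i=0}^{t}\mathbf{m}_i\Big)\Big(\mathbf{I} - \beta_1\sum_{i=0}^{t}\mathbf{m}_i\Big)\mathbf{K}_0^\top + O(\beta_1^2) = \mathbf{I} + O(\beta_1^2).$$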
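The statement of Theorem 1 can also be probed numerically over several steps. The sketch below is our own illustration. It assumes NumPy, random SPD matrices, and a damped curvature sequence shared by both methods. It runs KFAC and the truncated inverse-free update from the same initialization $\mathbf{S}_0 = \mathbf{K}_0^{-\top}\mathbf{K}_0^{-1}$ for $T$ steps and reports $\|\mathbf{S}_T\mathbf{K}_T\mathbf{K}_T^\top - \mathbf{I}\|_F$. The $O(\beta_1^2)$ constant grows with $T$, so the check compares two small step sizes rather than asserting a fixed tolerance. The function name `theorem_gap` is hypothetical.

```python
import numpy as np

# Illustrative sketch (hypothetical function name), not the authors' implementation.


def theorem_gap(beta1, T=10, d=4, lam=0.1, seed=3):
    """||S_T K_T K_T^T - I||_F after T matched KFAC / inverse-free steps."""
    rng = np.random.default_rng(seed)
    I = np.eye(d)
    B = rng.standard_normal((d, d))
    S = B @ B.T / d + I                              # S_0 (SPD)
    K = np.linalg.inv(np.linalg.cholesky(S)).T       # K_0 with S_0 = K_0^{-T} K_0^{-1}
    for _ in range(T):
        A = rng.standard_normal((d, 2 * d))
        H = A @ A.T / (2 * d) + lam * I              # same damped curvature for both
        S = (1.0 - beta1) * S + beta1 * H            # KFAC update
        m = K.T @ H @ K - I
        K = K @ (I - 0.5 * beta1 * m)                # inverse-free update
    return np.linalg.norm(S @ K @ K.T - I)


for beta1 in (1e-2, 1e-3, 1e-4):
    print(f"beta1={beta1:g}  gap={theorem_gap(beta1):.2e}")
```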
Appendix E Invariance of INGD and SINGD
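INGD and SINGD are invariant to the scale of the Kronecker approximation, while KFAC and IKFAC are not. Recall that we use the following Kronecker approximation to approximate the Hessian:
$$\mathbf{H} \approx \mathbf{U}\otimes\mathbf{G}.$$
However, such an approximation is not unique. We can consider an equivalent approximation such as
$$\mathbf{H} \approx (\alpha\mathbf{U})\otimes(\alpha^{-1}\mathbf{G}),$$
where $\alpha$ can be any non-zero scalar.
INGD is invariant because the approximation enters its update only through scale-invariant products such as $\operatorname{tr}\big(\mathbf{C}^\top(\alpha^{-1}\mathbf{G})\mathbf{C}\big)\,\mathbf{K}^\top(\alpha\mathbf{U})\mathbf{K} = \operatorname{tr}\big(\mathbf{C}^\top\mathbf{G}\mathbf{C}\big)\,\mathbf{K}^\top\mathbf{U}\mathbf{K}$. SINGD preserves this invariance because its structures and their subspace projection maps are closed under scalar multiplication.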
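In contrast, the updates of KFAC and IKFAC are not scale invariant. As an example, consider using the curvature approximations $\mathbf{U}$ and $\alpha\mathbf{U}$ to update $\mathbf{S}_K$ in KFAC, starting from the correspondingly scaled factors $\mathbf{S}_K$ and $\alpha\mathbf{S}_K$. Denote the resulting factors by $\mathbf{S}_K^{\text{new}}$ and $\tilde{\mathbf{S}}_K^{\text{new}}$, respectively. As shown below, $\tilde{\mathbf{S}}_K^{\text{new}}$ cannot be recovered from $\mathbf{S}_K^{\text{new}}$ by a scale transformation. Thus, the KFAC update is not scale invariant:
$$\tilde{\mathbf{S}}_K^{\text{new}} = (1-\beta_1)\alpha\mathbf{S}_K + \beta_1(\alpha\mathbf{U} + \lambda\mathbf{I}) = \alpha\Big((1-\beta_1)\mathbf{S}_K + \beta_1\big(\mathbf{U} + \tfrac{\lambda}{\alpha}\mathbf{I}\big)\Big) \neq \alpha\mathbf{S}_K^{\text{new}} = \alpha\big((1-\beta_1)\mathbf{S}_K + \beta_1(\mathbf{U} + \lambda\mathbf{I})\big) \quad \text{for } \alpha \neq 1.$$
One attempt to make the update of $\mathbf{S}_K$ invariant is to set the damping weight to $\alpha\lambda$, which gives $\tilde{\mathbf{S}}_K^{\text{new}} = \alpha\mathbf{S}_K^{\text{new}}$. However, the update of $\mathbf{S}_C$, which uses $\alpha^{-1}\mathbf{G}$, requires the damping weight to be $\lambda/\alpha$:
$$\tilde{\mathbf{S}}_C^{\text{new}} = (1-\beta_1)\alpha^{-1}\mathbf{S}_C + \beta_1\big(\alpha^{-1}\mathbf{G} + \tfrac{\lambda}{\alpha}\mathbf{I}\big) = \alpha^{-1}\big((1-\beta_1)\mathbf{S}_C + \beta_1(\mathbf{G} + \lambda\mathbf{I})\big) = \alpha^{-1}\mathbf{S}_C^{\text{new}}.$$
Thus, it is impossible to make KFAC invariant without introducing individual damping weights for the two Kronecker factors.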
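A few lines of NumPy can illustrate both the non-invariance of KFAC and the fix via factor-specific damping. This is our own construction, not the authors' code. Both runs start from consistently rescaled factors, $(\mathbf{S}_K, \mathbf{S}_C)$ versus $(\alpha\mathbf{S}_K, \alpha^{-1}\mathbf{S}_C)$, and use $\mathbf{U}\otimes\mathbf{G}$ versus $(\alpha\mathbf{U})\otimes(\alpha^{-1}\mathbf{G})$. We compare the Kronecker products $\mathbf{S}_K\otimes\mathbf{S}_C$ after one damped KFAC step, since equal products yield the same preconditioner. The names `kfac_update`, `random_spd`, and `kron_gap` are hypothetical.

```python
import numpy as np

# Illustrative sketch (hypothetical helper names), not the authors' implementation.


def kfac_update(S, curvature, beta1, lam):
    """Damped KFAC moving-average update of one Kronecker factor."""
    return (1.0 - beta1) * S + beta1 * (curvature + lam * np.eye(S.shape[0]))


def random_spd(rng, d):
    """Random symmetric positive-definite d x d matrix."""
    A = rng.standard_normal((d, d))
    return A @ A.T + np.eye(d)


def kron_gap(alpha, per_factor_damping, beta1=0.5, lam=0.5, seed=4):
    """Distance between S_K (x) S_C obtained from U (x) G and from (aU) (x) (G/a)."""
    rng = np.random.default_rng(seed)
    U, G = random_spd(rng, 3), random_spd(rng, 2)
    S_K, S_C = random_spd(rng, 3), random_spd(rng, 2)

    # Reference run: original approximation with a shared damping weight.
    ref = np.kron(kfac_update(S_K, U, beta1, lam), kfac_update(S_C, G, beta1, lam))

    # Rescaled run: per-factor damping (alpha*lam, lam/alpha) or the shared lam.
    lam_K, lam_C = (alpha * lam, lam / alpha) if per_factor_damping else (lam, lam)
    scaled = np.kron(kfac_update(alpha * S_K, alpha * U, beta1, lam_K),
                     kfac_update(S_C / alpha, G / alpha, beta1, lam_C))
    return np.linalg.norm(ref - scaled)


print(kron_gap(3.0, per_factor_damping=False))  # clearly non-zero
print(kron_gap(3.0, per_factor_damping=True))   # zero up to rounding
```

With a shared damping weight, the rescaled run effectively uses damping $\lambda/\alpha$ on one factor and $\alpha\lambda$ on the other, so the products differ. With per-factor damping, the two runs produce $\alpha\mathbf{S}_K^{\text{new}}$ and $\alpha^{-1}\mathbf{S}_C^{\text{new}}$, whose Kronecker product is unchanged.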