Structured Inverse-Free Natural Gradient Descent:
Memory-Efficient & Numerically-Stable KFAC

Wu Lin Felix Dangel Runa Eschenhagen Kirill Neklyudov Agustinus Kristiadi Richard E. Turner Alireza Makhzani
Abstract

Second-order methods such as KFAC can be useful for neural net training. However, they are often memory-inefficient since their preconditioning Kronecker factors are dense, and numerically unstable in low precision as they require matrix inversion or decomposition. These limitations render such methods unpopular for modern mixed-precision training. We address them by (i) formulating an inverse-free KFAC update and (ii) imposing structures in the Kronecker factors, resulting in structured inverse-free natural gradient descent (SINGD). On modern neural networks, we show that SINGD is memory-efficient and numerically robust, in contrast to KFAC, and often outperforms AdamW even in half precision. Our work closes a gap between first- and second-order methods in modern low-precision training.

Machine Learning, ICML

1 Introduction

The continuing success of deep learning (DL) is—to a large extent—powered by scaling up computational power (Thompson et al., 2020) to increase the number of trainable neural network (NN) parameters. Contemporary natural language processing (Radford et al., 2019; Brown et al., 2020; Touvron et al., 2023) and computer vision (Dehghani et al., 2023) models often consist of billions of parameters, and will likely grow further in the future. To compensate for increasing computational demands, many training pipelines use lower precision data types (Micikevicius et al., 2018) and memory-efficient first-order optimizers like SGD (Robbins & Monro, 1951) or Adam(W) (Kingma & Ba, 2015; Loshchilov & Hutter, 2019).

Second-order methods, like natural gradient descent (NGD, Amari, 1998), leverage curvature information, which has many applications in DL: it is useful for improving training dynamics (Martens & Grosse, 2015; Osawa et al., 2023), pruning (Wang et al., 2019), understanding the influence of training examples (Bae et al., 2022), and uncertainty estimation (Zhang et al., 2018; Immer et al., 2021; Daxberger et al., 2021). A major reason these methods are rarely used is their higher memory consumption and per-iteration cost.

Perhaps the most common approach to scaling second-order methods for DL is Kronecker-factored approximate curvature (KFAC, Heskes, 2000; Martens & Grosse, 2015), which approximates the Fisher's diagonal blocks via Kronecker products. While the KFAC optimizer built on top of this curvature approximation and its variants, such as George et al. (2018), show promising results for medium-sized NNs (e.g. Osawa et al., 2023), their usefulness is often limited by (i) memory consumption and (ii) low-precision floating-point (FP) training, which renders the matrix decompositions or inversions required to precondition the gradient numerically unstable.

Recently, Lin et al. (2023) proposed an inverse-free Kronecker-factored natural gradient descent (INGD) algorithm that replaces matrix inversion with subtraction in a matrix logarithm space. Their update is purely based on matrix multiplications and is therefore numerically stable in single precision (FP-32); however, it is unclear whether this extends to half precision (BFP-16). Furthermore, INGD has not been derived from the popular natural gradient approaches for DL, so it is unclear if and how the method is connected to the predominant KFAC optimizer. Also, INGD does not improve over KFAC's memory complexity, since its Kronecker factors are dense matrices of the same size. Lastly, INGD has only been tested on convolution-based models, and it is unclear whether it is useful for training modern transformer-based architectures (Vaswani et al., 2017).

Figure 1: CIFAR-100 experiments on VGG net. Left/Center: Our methods (IKFAC and SINGD) outperform AdamW and perform stably in FP-32 and BFP-16—unlike KFAC—as they do not require matrix inversions. IKFAC effectively performs KFAC updates and achieves similar performance in FP-32. For this task, replacing the dense Kronecker factors (INGD = SINGD-Dense) with diagonal ones (SINGD-Diag) does not harm performance while reducing cost. Right: Memory consumption. Removing Riemannian momentum (IKFAC) or using structured Kronecker factors (SINGD-Diag) reduces INGD’s memory in FP-32 and BFP-16. In BFP-16, SINGD-Diag achieves AdamW’s memory consumption (dashed line).

Here, we extend INGD to lower its computational cost and theoretically resolve its connection to other approximate NGD methods for DL (overview in Figure 2): First, we show that a special case of INGD recovers the KFAC method. This allows us to effectively perform KFAC updates in an inverse-free fashion. We call this modification of INGD inverse-free KFAC (IKFAC). Second, we exploit an algebraic structure in the matrix logarithm space and propose structure-preserving updates to maintain sparse structures on Kronecker factors. This significantly reduces memory and leads to a novel, scalable second-order optimization algorithm we call structured inverse-free natural gradient descent (SINGD) which contains INGD and IKFAC as special cases. We evaluate SINGD on convolution- and transformer-based models and show that it can (i) outperform SGD and AdamW while using as little memory as the latter thanks to structured Kronecker factors and (ii) yield better performance than KFAC while being stable in half-precision:

(a) We bridge the gap between INGD (Lin et al., 2023) and the original KFAC (Martens & Grosse, 2015), whose matrix inversions are unstable in low precision. Thereby, we effectively make KFAC inverse-free and amenable to low-precision training (Figure 1, left/center).

(b) We impose various structures (block-diagonal, low-rank, Toeplitz, hierarchical) on INGD's Kronecker factors, making them sparse to lower memory consumption and run time (Figure 1, right, and Table 1). Unlike many existing second-order methods tailored to a specific structure, our proposed update rule (Figure 4) is unified, efficient, and inverse-free for a range of structures. We analyze the impact of structures on downstream performance and find that structures with considerably lower memory consumption (even lower than AdamW) can yield competitive performance.

(c) Unlike other second-order methods, we show that SINGD can stably train a range of modern architectures (transformers, CNNs, GNNs) in BFP-16. In contrast to first-order methods, which are often useful in narrower scopes (SGD is best for CNNs, AdamW is best for transformers), SINGD works well across architectures and outperforms SGD and AdamW in many cases (see Section 4).

Our work closes a gap between first- and second-order methods in modern low-precision neural network training.¹

¹ PyTorch implementation: github.com/f-dangel/singd.

Table 1: Training times and memory consumption for the optimizers shown in Figure 1 (parenthesized values are relative to SGD; our methods are marked with an asterisk). INGD has about 84 % time and 29 % memory overhead compared to SGD. In contrast, our SINGD-Diag has only 29 % time and 2 % memory overhead. Hence, by using structures we reduce INGD's time overhead by more than half and almost eliminate its memory overhead compared to first-order competitors.
| Method | Peak memory [GiB] | Training time [min] |
|---|---|---|
| SGD (BFP-16) | 2.63 (1.00x) | 18.5 (1.00x) |
| AdamW (BFP-16) | 2.69 (1.02x) | 19.7 (1.07x) |
| SINGD-Diag* (BFP-16) | 2.67 (1.02x) | 23.8 (1.29x) |
| IKFAC* (BFP-16) | 3.18 (1.21x) | 34.0 (1.84x) |
| INGD (BFP-16) | 3.39 (1.29x) | 34.1 (1.84x) |
| KFAC (FP-32) | 4.00 (1.52x) | 83.2 (4.49x) |

2 Preliminaries

We first introduce the necessary ingredients to establish a connection between INGD and KFAC, which are derived from different perspectives. We start by describing Newton's method, since both methods can be seen as approximate Newton methods using NGD. NN training often corresponds to an unconstrained minimization problem. Consider training a NN for image classification. Given a set of $N$ examples $\{y_i, \mathbf{x}_i\}_{i=1}^{N}$ with labels $y_i$ and images $\mathbf{x}_i$, the optimization problem is

$$\min_{\mu} \ell(\boldsymbol{\mu}; \mathbf{y}, \mathbf{X}) \coloneqq \min_{\mu} \sum_{i=1}^{N} c\big(y_i, f(\boldsymbol{\mu}; \mathbf{x}_i)\big), \tag{1}$$

where $\mathbf{y} \coloneqq (y_1, \dots, y_N)$, $\mathbf{X} \coloneqq (\mathbf{x}_1, \dots, \mathbf{x}_N)$, and $f$ is a NN that outputs a predicted label $\hat{y}_i \coloneqq f(\boldsymbol{\mu}; \mathbf{x}_i)$ for an image $\mathbf{x}_i$. The parameters $\boldsymbol{\mu}$ denote the learnable weights of the NN, and $c(y_i, \hat{y}_i)$ is a differentiable loss function that measures the difference between a true label $y_i$ and a predicted label $\hat{y}_i$. To solve Equation 1, Newton's method follows the update

$$\boldsymbol{\mu} \leftarrow \boldsymbol{\mu} - \mathbf{S}^{-1} \nabla_{\mu} \ell(\boldsymbol{\mu}; \mathbf{y}, \mathbf{X}), \tag{2}$$

where $\mathbf{S} \coloneqq \nabla_{\mu}^{2} \ell(\boldsymbol{\mu}; \mathbf{y}, \mathbf{X})$ is the Hessian of the loss.
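As a minimal illustration of Equation 2 (a NumPy sketch, not the authors' code), the snippet below takes one Newton step on a hypothetical least-squares problem with a linear model $f(\boldsymbol{\mu}; \mathbf{x}) = \mathbf{x}^\top \boldsymbol{\mu}$ and $c(y, \hat{y}) = \tfrac{1}{2}(y - \hat{y})^2$. Because this loss is quadratic, its Hessian is constant and a single step lands on the least-squares solution. The names `newton_step`, `grad_fn`, and `hess_fn` are illustrative.

```python
import numpy as np

def newton_step(mu, grad_fn, hess_fn):
    """One Newton update mu <- mu - S^{-1} grad (Equation 2).

    Solves the linear system S d = grad instead of forming S^{-1} explicitly.
    """
    S = hess_fn(mu)
    return mu - np.linalg.solve(S, grad_fn(mu))

# Hypothetical toy problem: squared error with a linear model, summed over N examples.
rng = np.random.default_rng(0)
N, D = 50, 3
X = rng.normal(size=(N, D))
mu_true = np.array([1.0, -2.0, 0.5])
y = X @ mu_true + 0.01 * rng.normal(size=N)

grad_fn = lambda mu: X.T @ (X @ mu - y)  # gradient of the summed loss
hess_fn = lambda mu: X.T @ X             # Hessian S (constant for this loss)

mu0 = np.zeros(D)
mu1 = newton_step(mu0, grad_fn, hess_fn)
mu_ls = np.linalg.lstsq(X, y, rcond=None)[0]
print(np.allclose(mu1, mu_ls))  # True: a quadratic loss is minimized in one step
```

Solving a linear system rather than inverting $\mathbf{S}$ is the usual numerical choice; for a NN, however, even storing $\mathbf{S}$ is infeasible.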

2.1 KFAC: Approximate NGD for MLE

Figure 2: Existing methods and their relation to our proposed methods. IKFAC behaves like KFAC (1), but is numerically stable in low precision. In contrast to IKFAC, INGD has Riemannian momenta and adaptive damping and curvature, which can yield better performance in practice (Section 4). INGD is equivalent to SINGD with unstructured Kronecker factors (SINGD-Dense). Structured Kronecker factors reduce memory and computational cost.

Computing the Hessian, as required by Newton's method, is usually intractable for NNs. NGD uses a Fisher information matrix (FIM) instead of the Hessian by reformulating problem (1) as maximum likelihood estimation (MLE) of $p(\mathbf{y} \mid \boldsymbol{\mu}, \mathbf{X}) = \prod_i p(y_i \mid \boldsymbol{\mu}, \mathbf{x}_i)$, where $p(y_i \mid \boldsymbol{\mu}, \mathbf{x}_i) \coloneqq \exp\big(-c(y_i, f(\boldsymbol{\mu}; \mathbf{x}_i))\big)$. The maximization problem $\max_{\mu} p(\mathbf{y} \mid \boldsymbol{\mu}, \mathbf{X})$ is equivalent to the MLE problem

$$\min_{\mu} -\log p(\mathbf{y} \mid \boldsymbol{\mu}, \mathbf{X}) = \min_{\mu} \ell(\boldsymbol{\mu}; \mathbf{y}, \mathbf{X}). \tag{3}$$

This formulation allows us to exploit additional statistical structure such as the FIM, which is defined below (Kunstner et al., 2019); here, we assume a label $y$ is sampled from the likelihood $p(y \mid \boldsymbol{\mu}, \mathbf{x}_i)$ given an image $\mathbf{x}_i$. With $\mathbf{s}_i(y) \coloneqq \log p(y \mid \boldsymbol{\mu}, \mathbf{x}_i)$, we have

$$\begin{aligned} F(\boldsymbol{\mu}) &\coloneqq \sum_{i=1}^{N} \mathbb{E}_{y \sim p(y \mid \boldsymbol{\mu}, \mathbf{x}_i)}\left[\nabla_{\mu} \mathbf{s}_i(y) \big(\nabla_{\mu} \mathbf{s}_i(y)\big)^{\top}\right] \\ &= \sum_{i=1}^{N} \mathbb{E}_{y \sim p(y \mid \boldsymbol{\mu}, \mathbf{x}_i)}\left[-\nabla_{\mu}^{2} \mathbf{s}_i(y)\right]. \end{aligned} \tag{4}$$
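To make the two forms of Equation 4 concrete, the NumPy sketch below assumes a hypothetical logistic-regression model with a Bernoulli likelihood $p(y \mid \boldsymbol{\mu}, \mathbf{x}) = \sigma(\mathbf{x}^\top\boldsymbol{\mu})^{y}(1 - \sigma(\mathbf{x}^\top\boldsymbol{\mu}))^{1-y}$. Since $y \in \{0, 1\}$, the expectation over the model's predictive distribution is computed exactly by summing over both labels, and the outer-product and negative-Hessian forms agree. The helper `fim_two_ways` is illustrative.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fim_two_ways(mu, X):
    """Return both forms of Equation 4 for a Bernoulli (logistic) likelihood.

    The expectation over y ~ p(y | mu, x_i) is computed exactly by summing
    over y in {0, 1} instead of sampling.
    """
    D = X.shape[1]
    F_outer = np.zeros((D, D))  # E[grad s_i(y) grad s_i(y)^T]
    F_hess = np.zeros((D, D))   # E[-grad^2 s_i(y)]
    for x in X:
        p = sigmoid(x @ mu)
        for y, prob in ((1.0, p), (0.0, 1.0 - p)):
            score = (y - p) * x                              # grad_mu s_i(y)
            F_outer += prob * np.outer(score, score)
            F_hess += prob * p * (1.0 - p) * np.outer(x, x)  # -grad^2_mu s_i(y), independent of y
    return F_outer, F_hess

rng = np.random.default_rng(1)
X = rng.normal(size=(20, 3))
mu = rng.normal(size=3)
F_outer, F_hess = fim_two_ways(mu, X)
print(np.allclose(F_outer, F_hess))  # True
```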

For ubiquitous loss functions like the mean-squared error and cross-entropy, and more generally, many members of the exponential family with natural parameterization, the FIM coincides with the generalized Gauss-Newton (GGN) matrix (Wang, 2010; Martens, 2014), a common approximation of the Hessian in deep learning (Schraudolph, 2002; Botev et al., 2017). This relationship connects NGD to Newton's method. A common approximation of the FIM/GGN and the Hessian is the so-called empirical Fisher $\hat{F}(\boldsymbol{\mu})$, which replaces the samples $y$ from the model's predictive distribution in Equation 4 with the empirical data labels $y_i$:

$$\begin{aligned} \hat{F}(\boldsymbol{\mu}) &\coloneqq \sum_{i=1}^{N} \nabla_{\mu} \mathbf{s}_i(y_i) \big(\nabla_{\mu} \mathbf{s}_i(y_i)\big)^{\top} \\ &\approx -\sum_{i=1}^{N} \nabla_{\mu}^{2} \mathbf{s}_i(y_i) = \mathbf{S}. \end{aligned}$$

While there is no clear theoretical justification for this Hessian approximation (Kunstner et al., 2019), it simplifies the implementation, reduces cost, and has been shown to work well in practice (Graves, 2011; Osawa et al., 2019). This approximation is also known as Fisher scoring with the observed FIM for nonlinear models (Osborne, 1992; Smyth, 1996, 2015). With this, we can formulate an NGD update with the empirical FIM $\hat{F}(\boldsymbol{\mu})$ and stepsize $\beta > 0$ to approximate Newton's method as

$$\begin{aligned} \boldsymbol{\mu} &\leftarrow \boldsymbol{\mu} - \beta \big(\hat{F}(\boldsymbol{\mu})\big)^{-1} \nabla_{\mu} \ell(\boldsymbol{\mu}; \mathbf{y}, \mathbf{X}) \\ &\phantom{\leftarrow}\approx \boldsymbol{\mu} - \beta \mathbf{S}^{-1} \nabla_{\mu} \ell(\boldsymbol{\mu}; \mathbf{y}, \mathbf{X}). \end{aligned}$$

We call this update NGD for MLE.
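The NumPy sketch below runs a few NGD-for-MLE steps with the empirical Fisher on hypothetical logistic-regression data. It is an illustration under stated assumptions, not the paper's implementation: the small `damping` term is an addition (not part of the update above) that keeps $\hat{F}$ invertible, and `ngd_mle_step` and `bce_loss` are illustrative names. Note that the per-example gradients of $c$ are $-\nabla_\mu \mathbf{s}_i(y_i)$, so the sign cancels in the outer products.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bce_loss(mu, X, y):
    """Summed binary cross-entropy, the loss in Equation 1 for logistic regression."""
    z = X @ mu
    return np.sum(np.logaddexp(0.0, z) - y * z)

def ngd_mle_step(mu, X, y, beta=0.5, damping=1e-4):
    """One NGD-for-MLE step with the empirical Fisher (damping is an added assumption)."""
    p = sigmoid(X @ mu)
    per_example = (p - y)[:, None] * X     # row i: grad of c(y_i, f(mu; x_i)) = -grad s_i(y_i)
    grad = per_example.sum(axis=0)         # grad_mu of the summed loss
    F_hat = per_example.T @ per_example    # sum of per-example outer products
    F_hat += damping * np.eye(mu.shape[0])
    return mu - beta * np.linalg.solve(F_hat, grad)

# Hypothetical data drawn from a logistic model.
rng = np.random.default_rng(0)
N, D = 200, 3
X = rng.normal(size=(N, D))
y = (rng.random(N) < sigmoid(X @ np.array([1.5, -1.0, 0.5]))).astype(float)

mu = np.zeros(D)
initial_loss = bce_loss(mu, X, y)
for _ in range(10):
    mu = ngd_mle_step(mu, X, y)
final_loss = bce_loss(mu, X, y)
print(initial_loss, final_loss)
```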

KFAC (Heskes, 2000; Martens & Grosse, 2015) is probably the most common second-order optimizer in DL. The KFAC algorithm is based on a Kronecker-factored approximation of the Fisher, which is sometimes also referred to as KFAC. Here, we refer to the algorithm as KFAC or the KFAC method and to the approximation as the Kronecker approximation; we consider the Kronecker approximation of the empirical Fisher. It approximates the per-layer FIM with a Kronecker-factored block $\tilde{F}_l$ for each layer $l$ of the net. This approximation was first derived for linear layers, later for convolutional (Grosse & Martens, 2016) and recurrent layers (Martens et al., 2018), and has recently been generalized to all linear layers that use weight sharing (Eschenhagen et al., 2023), e.g. in graph neural networks and transformers. A block is given by $\tilde{F}_l(\boldsymbol{\mu}) \coloneqq \mathbf{U}_l \otimes \mathbf{G}_l$, with $\mathbf{U}_l \coloneqq \mathbf{u}_l \mathbf{u}_l^{\top} \in \mathbb{R}^{d_i \times d_i}$ and $\mathbf{G}_l \coloneqq \mathbf{g}_l \mathbf{g}_l^{\top} \in \mathbb{R}^{d_o \times d_o}$, where $\mathbf{u}_l \in \mathbb{R}^{d_i}$ is the $l$th layer's input and $\mathbf{g}_l \in \mathbb{R}^{d_o}$ is the gradient of the loss w.r.t. the layer's output. We suppress the dependence on the parameters $\boldsymbol{\mu}$ and the input $\mathbf{x}_i$ and, for simplicity, assume no weight sharing. KFAC also uses exponential moving averages (with weight $\beta_1$) over $\mathbf{U}$ and $\mathbf{G}$, yielding $\mathbf{S}_K$ and $\mathbf{S}_C$, and damping $\lambda$; see Figure 4.
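For a single example and a linear layer $\mathbf{z} = \mathbf{W}\mathbf{u}_l$ without bias (an assumption for this sketch), the per-example gradient is $\mathbf{g}_l \mathbf{u}_l^\top$, and its empirical-Fisher block factorizes exactly as $\mathbf{U}_l \otimes \mathbf{G}_l$ under column-major vectorization. The NumPy check below illustrates this and compares storage of the dense block with that of the two factors; for sums over many examples, the Kronecker form becomes an approximation.

```python
import numpy as np

rng = np.random.default_rng(0)
d_i, d_o = 4, 3
u = rng.normal(size=d_i)             # layer input u_l
g = rng.normal(size=d_o)             # gradient of the loss w.r.t. the layer output, g_l
grad_W = np.outer(g, u)              # gradient w.r.t. W for z = W @ u (W has shape d_o x d_i)

vec = grad_W.reshape(-1, order="F")  # column-major (stacked-columns) vectorization
block = np.outer(vec, vec)           # dense (d_i*d_o) x (d_i*d_o) empirical-Fisher block
U = np.outer(u, u)                   # U_l = u_l u_l^T
G = np.outer(g, g)                   # G_l = g_l g_l^T
print(np.allclose(block, np.kron(U, G)))  # True

# Storage: dense block vs. the two Kronecker factors.
print(block.size, U.size + G.size)   # 144 25
```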

While the Kronecker approximation makes gradient preconditioning more efficient, KFAC still has two drawbacks. First, it must invert the dense Kronecker factors $\mathbf{S}_K$ and $\mathbf{S}_C$ at every preconditioner update. This run time overhead is usually amortized by updating the preconditioner less frequently, but doing so can cause instabilities, especially in low-precision settings. Second, storing the dense Kronecker factors introduces significant memory overhead, which poses issues for large models. Since low-precision training is becoming the norm in fields like natural language processing, these issues will become more apparent in modern DL. There are multiple numerical concerns when using KFAC or its variants in low precision. In PyTorch (Paszke et al., 2019) and JAX (Bradbury et al., 2018) implementations, all tensors must be cast to FP-32 because (B)FP-16 matrix inverses and decompositions are not supported. Moreover, $\mathbf{g}_l$ has to be rescaled to avoid over- or underflows when calculating $\mathbf{G}_l$. Memory consumption has previously been addressed through diagonal or block-diagonal versions of $\mathbf{U}_l$ and $\mathbf{G}_l$ (Zhang et al., 2018; Grosse et al., 2023). However, it is unclear whether these simple structures maintain downstream performance.
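The sketch below shows, in NumPy, why the Kronecker form makes preconditioning cheaper: using $(\mathbf{A} \otimes \mathbf{B})^{-1}\mathrm{vec}(\mathbf{V}) = \mathrm{vec}(\mathbf{B}^{-1}\mathbf{V}\mathbf{A}^{-1})$ for symmetric factors, only $d_i \times d_i$ and $d_o \times d_o$ systems are solved, yet each preconditioner update still needs these dense solves. The EMA weighting (new term weighted by $\beta_1$) and adding damping to each factor separately are assumptions for illustration, not necessarily the exact form in Figure 4; `ema` and `kfac_precondition` are hypothetical names.

```python
import numpy as np

def ema(S, new, beta1):
    """Exponential moving average; weighting the new term by beta1 is an assumed convention."""
    return (1.0 - beta1) * S + beta1 * new

def kfac_precondition(grad_W, S_K, S_C, damping):
    """Apply ((S_K + damping I) kron (S_C + damping I))^{-1} to vec(grad_W).

    Uses (A kron B)^{-1} vec(V) = vec(B^{-1} V A^{-1}) for symmetric A, B and
    column-major vec, so the Kronecker product is never formed.
    """
    A = S_K + damping * np.eye(S_K.shape[0])  # input-side factor, d_i x d_i
    B = S_C + damping * np.eye(S_C.shape[0])  # output-side factor, d_o x d_o
    BV = np.linalg.solve(B, grad_W)           # B^{-1} V
    return np.linalg.solve(A, BV.T).T         # (A^{-1} (B^{-1} V)^T)^T = B^{-1} V A^{-1}

rng = np.random.default_rng(0)
d_i, d_o, beta1, damping = 5, 4, 0.1, 1e-2
S_K, S_C = np.eye(d_i), np.eye(d_o)
for _ in range(20):                           # accumulate factors from hypothetical u_l, g_l
    u, g = rng.normal(size=d_i), rng.normal(size=d_o)
    S_K = ema(S_K, np.outer(u, u), beta1)
    S_C = ema(S_C, np.outer(g, g), beta1)

grad_W = rng.normal(size=(d_o, d_i))
fast = kfac_precondition(grad_W, S_K, S_C, damping)

# Reference: form the dense Kronecker product explicitly (what KFAC avoids).
dense = np.kron(S_K + damping * np.eye(d_i), S_C + damping * np.eye(d_o))
slow = np.linalg.solve(dense, grad_W.reshape(-1, order="F")).reshape(d_o, d_i, order="F")
print(np.allclose(fast, slow))  # True
```

In low precision, it is exactly these `np.linalg.solve`-style calls on the dense factors that become problematic, which motivates the inverse-free updates discussed next.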

2.2 INGD: Approximate NGD for Bayesian estimation

Derived from Bayesian principles, INGD (Lin et al., 2023) directly approximates the inverse Hessian. We first introduce the two ingredients INGD builds on: the Bayesian learning rule (BLR, Khan & Lin, 2017; Zhang et al., 2018; Khan et al., 2018; Osawa et al., 2019; Lin et al., 2020; Khan & Rue, 2021; Tan, 2022) and an inverse-free second-order method from Lin et al. (2021). Under the BLR, Newton's method for solving the MLE problem (3) can be seen as another natural-gradient update for solving a variational inference (VI) problem with a delta approximation (Khan & Rue, 2021). This interpretation allows us to view the precision matrix in the variational problem as a Hessian estimate in the MLE problem. Thus, Lin et al. (2021) suggest reparameterizing the Hessian estimate as the precision of the Gaussian posterior in a matrix logarithm space and exploiting the parameterization invariance of natural gradients to obtain an inverse-free update.

BLR

Consider a Bayesian problem formulation in which the NN weights are random variables. We denote these weights by new parameters $\mathbf{w}$, since random variables are no longer learnable, and use a variational Gaussian distribution to approximate the posterior over them. Its mean and precision will be treated as the learnable weights $\boldsymbol{\mu}$ and the Hessian estimate $\mathbf{S}$ in Newton's step (2).

The VI problem considered in the learning rule is defined as $\min_{\tau} -\mathcal{L}(\boldsymbol{\tau})$ with the evidence lower bound (ELBO)

$$\mathcal{L}(\boldsymbol{\tau}) \coloneqq \mathbb{E}_{\mathbf{w} \sim q(\mathbf{w} \mid \boldsymbol{\tau})}\left[\log p(\mathbf{w}) + \log p(\mathbf{y} \mid \mathbf{w}, \mathbf{X})\right] + H_q(\boldsymbol{\tau}). \tag{5}$$

Here, $\boldsymbol{\tau} = \{\boldsymbol{\mu}, \mathbf{S}\}$ are the learnable parameters of the variational Gaussian distribution $q(\mathbf{w} \mid \boldsymbol{\tau}) = \mathcal{N}(\mathbf{w} \mid \boldsymbol{\mu}, \mathbf{S})$ with mean $\boldsymbol{\mu}$ and precision $\mathbf{S}$. The likelihood $p(\mathbf{y} \mid \mathbf{w}, \mathbf{X}) = \exp(-\ell(\mathbf{w}; \mathbf{y}, \mathbf{X}))$ takes the same form as in the MLE setting, while the prior $p(\mathbf{w}) \propto \exp(-R(\mathbf{w}))$ is defined by a regularizer $R(\mathbf{w}) \geq 0$. To recover the MLE problem, we consider an uninformative prior $p(\mathbf{w})$ (i.e., $R(\mathbf{w}) = 0$). Finally, $H_q(\boldsymbol{\tau}) \coloneqq \mathbb{E}_{\mathbf{w} \sim q}\left[-\log q(\mathbf{w} \mid \boldsymbol{\tau})\right]$ is the entropy of $q(\mathbf{w} \mid \boldsymbol{\tau})$.
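As a rough illustration of Equation 5 with $R(\mathbf{w}) = 0$ (so the prior term is a dropped constant), the NumPy sketch below estimates the ELBO by Monte Carlo for a Gaussian parameterized by its precision $\mathbf{S}$, using the closed-form entropy $H_q = \tfrac{d}{2}(1 + \log 2\pi) - \tfrac{1}{2}\log\det\mathbf{S}$. Samples are drawn as $\mathbf{w} = \boldsymbol{\mu} + \mathbf{L}^{-\top}\boldsymbol{\epsilon}$ with $\mathbf{S} = \mathbf{L}\mathbf{L}^\top$. The quadratic loss, its parameters, and all function names are hypothetical choices that admit a closed-form check.

```python
import numpy as np

def gaussian_entropy(S):
    """Entropy of N(mu, S^{-1}) parameterized by its precision S."""
    d = S.shape[0]
    _, logdet = np.linalg.slogdet(S)
    return 0.5 * d * (1.0 + np.log(2.0 * np.pi)) - 0.5 * logdet

def sample_q(mu, S, n, rng):
    """Draw n samples w = mu + L^{-T} eps with S = L L^T, so Cov[w] = S^{-1}."""
    L = np.linalg.cholesky(S)
    eps = rng.normal(size=(n, mu.shape[0]))
    return mu + np.linalg.solve(L.T, eps.T).T

def elbo_mc(mu, S, loss_fn, n=100_000, seed=0):
    """Monte Carlo estimate of Equation 5 with R(w) = 0 (prior term dropped as a constant)."""
    w = sample_q(mu, S, n, np.random.default_rng(seed))
    return -np.mean(loss_fn(w)) + gaussian_entropy(S)

# Hypothetical quadratic loss l(w) = 0.5 (w - a)^T H (w - a), evaluated row-wise.
H = np.array([[2.0, 0.3], [0.3, 1.0]])
a = np.array([1.0, -1.0])
loss_fn = lambda w: 0.5 * np.einsum("ni,ij,nj->n", w - a, H, w - a)

mu = np.array([0.5, -0.5])
S = np.array([[4.0, 1.0], [1.0, 3.0]])
# Closed form for comparison: E_q[l] = l(mu) + 0.5 tr(H S^{-1}).
exact = -(loss_fn(mu[None])[0] + 0.5 * np.trace(H @ np.linalg.inv(S))) + gaussian_entropy(S)
print(elbo_mc(mu, S, loss_fn), exact)
```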

Similar to the MLE case, the Bayesian formulation allows us to exploit additional statistical structure in the form of another FIM, namely that of the variational Gaussian, defined as

$$\begin{aligned} F(\boldsymbol{\tau}) &\coloneqq \mathbb{E}_{\mathbf{w} \sim q(\mathbf{w} \mid \boldsymbol{\tau})}\left[\nabla_{\tau} \log q(\mathbf{w} \mid \boldsymbol{\tau}) \nabla_{\tau}^{\top} \log q(\mathbf{w} \mid \boldsymbol{\tau})\right] \\ &= -\mathbb{E}_{\mathbf{w} \sim q}\left[\nabla_{\tau}^{2} \log q(\mathbf{w} \mid \boldsymbol{\tau})\right], \end{aligned}$$

which has a closed-form expression. This FIM should not be confused with the FIM in (4) used for MLE.
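Restricted to the mean, the closed form is simple: the score is $\nabla_{\mu}\log q=\mathbf{S}(\mathbf{w}-\boldsymbol{\mu})$ and $-\nabla_{\mu}^{2}\log q=\mathbf{S}$, so the $\boldsymbol{\mu}$-block of $F(\boldsymbol{\tau})$ is the precision $\mathbf{S}$ itself. A small NumPy check (illustrative only; the helper name is hypothetical) confirms that the outer-product form in the first line of the definition agrees with this value.

```python
import numpy as np


def fisher_mean_block_mc(mu, S, n_samples=200_000, seed=0):
    """Monte Carlo estimate of E[grad_mu log q grad_mu log q^T] for q = N(mu, S^{-1})."""
    rng = np.random.default_rng(seed)
    L = np.linalg.cholesky(S)                    # S = L L^T
    z = rng.standard_normal((n_samples, mu.size))
    w = mu + np.linalg.solve(L.T, z.T).T         # w ~ N(mu, S^{-1})
    score = (w - mu) @ S                         # rows are S (w - mu), since S is symmetric
    return score.T @ score / n_samples           # outer-product form of the FIM


S = np.array([[2.0, 0.5], [0.5, 1.0]])
mu = np.array([1.0, -1.0])
F_mu = fisher_mean_block_mc(mu, S)
# F_mu should be close to S = -E[grad^2_mu log q], the Hessian form of the same block
```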

Under the BLR, we perform NGD updates not only on $\boldsymbol{\mu}$ but also on $\mathbf{S}$. Khan & Rue (2021) formulate a step with the exact FIM $F(\boldsymbol{\tau})$ and stepsize $\beta>0$ to update $\boldsymbol{\tau}=\{\boldsymbol{\mu},\mathbf{S}\}$,

$$
\boldsymbol{\tau}\leftarrow\boldsymbol{\tau}-\beta\Big(F(\boldsymbol{\tau})\Big)^{-1}\nabla_{\tau}\left(-\mathcal{L}(\boldsymbol{\tau})\right)\,.
$$

This is the BLR counterpart of the NGD update for MLE. Following Khan & Nielsen (2018), the update simplifies to

$$
\begin{aligned}
\mathbf{S} &\leftarrow (1-\beta)\mathbf{S}+\beta\,{\color{red}\mathbb{E}_{w\sim q(w\mid\mu,S)}\left[\nabla_{w}^{2}\ell(\mathbf{w};\mathbf{y},\mathbf{X})\right]}\,,\\
\boldsymbol{\mu} &\leftarrow \boldsymbol{\mu}-\beta\,\mathbf{S}^{-1}{\color{red}\mathbb{E}_{w\sim q(w\mid\mu,S)}\left[\nabla_{w}\ell(\mathbf{w};\mathbf{y},\mathbf{X})\right]}\,.
\end{aligned}
$$
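These updates only need expectations of the loss gradient and Hessian under $q$. A minimal NumPy sketch, assuming a hypothetical elementwise toy loss $\ell(\mathbf{w})=\sum_i\log\cosh(w_i-t_i)+\tfrac12\|\mathbf{w}\|^2$ (so the Hessian is diagonal) and Monte Carlo estimates of the two red expectations, is shown below. We apply the two lines in the order written, so the mean step uses the freshly updated $\mathbf{S}$; both expectations are taken under the current $q$.

```python
import numpy as np

# Hypothetical toy loss: l(w) = sum_i log cosh(w_i - t_i) + 0.5 * ||w||^2
t = np.array([1.0, -2.0, 0.5])


def grad_loss(w):
    """Gradient of the toy loss for a batch of samples w with shape (n, d)."""
    return np.tanh(w - t) + w


def hess_diag_loss(w):
    """Diagonal of the (diagonal) Hessian of the toy loss, shape (n, d)."""
    return 2.0 - np.tanh(w - t) ** 2


def blr_step(mu, S, z, beta):
    """One BLR step with Monte Carlo expectations under q(w | mu, S)."""
    L = np.linalg.cholesky(S)
    w = mu + np.linalg.solve(L.T, z.T).T            # w ~ N(mu, S^{-1})
    S_new = (1 - beta) * S + beta * np.diag(hess_diag_loss(w).mean(axis=0))
    mu_new = mu - beta * np.linalg.solve(S_new, grad_loss(w).mean(axis=0))
    return mu_new, S_new


def run_blr(n_iters=200, beta=0.5, n_samples=1000, seed=0):
    rng = np.random.default_rng(seed)
    # fixed base samples (common random numbers) make the iteration deterministic
    z = rng.standard_normal((n_samples, t.size))
    mu, S = np.zeros(t.size), np.eye(t.size)
    for _ in range(n_iters):
        mu, S = blr_step(mu, S, z, beta)
    return mu, S, z
```

Reusing the same base samples $\mathbf{z}$ at every step is a choice made only for this illustration: at convergence, the sample-average gradient vanishes and $\mathbf{S}$ equals the sample-average Hessian exactly, which makes the fixed point easy to check.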

Further simplifying the expectations (highlighted in red) with a delta approximation at the mean $\boldsymbol{\mu}$, we obtain

$$
\begin{aligned}
\mathbf{S} &\leftarrow (1-\beta)\mathbf{S}+\beta\,{\color{red}\nabla_{\mu}^{2}\ell(\boldsymbol{\mu};\mathbf{y},\mathbf{X})}\,,\\
\boldsymbol{\mu} &\leftarrow \boldsymbol{\mu}-\beta\,\mathbf{S}^{-1}{\color{red}\nabla_{\mu}\ell(\boldsymbol{\mu};\mathbf{y},\mathbf{X})}\,,
\end{aligned}
$$

which recovers Newton's method in (2) for $\beta=1$, since $\mathbf{S}$ then equals the Hessian $\nabla_{\mu}^{2}\ell(\boldsymbol{\mu};\mathbf{y},\mathbf{X})$.
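To make the last statement concrete, here is a short NumPy sketch (our illustration; the toy loss and function names are hypothetical) of the delta-approximated update. With $\beta=1$ one step coincides with a Newton step, and with $\beta<1$, $\mathbf{S}$ becomes an exponential moving average of Hessians while $\boldsymbol{\mu}$ still converges to a stationary point.

```python
import numpy as np

# Hypothetical non-quadratic toy loss: l(w) = sum_i log cosh(w_i - t_i) + 0.5 * ||w||^2
t = np.array([1.0, -2.0, 0.5])


def grad_fn(w):
    return np.tanh(w - t) + w


def hess_fn(w):
    return np.diag(2.0 - np.tanh(w - t) ** 2)


def delta_blr_step(mu, S, beta):
    """BLR step with the delta approximation: expectations replaced by values at mu."""
    S_new = (1 - beta) * S + beta * hess_fn(mu)
    mu_new = mu - beta * np.linalg.solve(S_new, grad_fn(mu))
    return mu_new, S_new


def newton_step(mu):
    return mu - np.linalg.solve(hess_fn(mu), grad_fn(mu))
```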

KFAC (Martens & Grosse, 2015)

1: Each $T$ iterations, update $\mathbf{S}_K$, $\mathbf{S}_C$
Obtain $\mathbf{U}\otimes\mathbf{G}$ to approximate $\nabla_{\mu}^{2}\ell(\boldsymbol{\mu})$
$\mathbf{S}_K\leftarrow(1-\beta_1)\mathbf{S}_K+\beta_1\mathbf{U}$
$\mathbf{S}_C\leftarrow(1-\beta_1)\mathbf{S}_C+\beta_1\mathbf{G}$
$\mathbf{S}_K^{-1}\leftarrow\left(\mathbf{S}_K+\lambda\mathbf{I}_{d_i}\right)^{-1}$
$\mathbf{S}_C^{-1}\leftarrow\left(\mathbf{S}_C+\lambda\mathbf{I}_{d_o}\right)^{-1}$
2: $\mathbf{m}_{\mu}\leftarrow\alpha_2\mathbf{m}_{\mu}+\mathbf{S}_C^{-1}\mathrm{vec}^{-1}(\mathbf{g})\mathbf{S}_K^{-1}+\gamma\,\mathrm{vec}^{-1}(\boldsymbol{\mu})$
3: $\boldsymbol{\mu}\leftarrow\boldsymbol{\mu}-\beta_2\,\mathrm{vec}(\mathbf{m}_{\mu})$

IKFAC (ours)

1: Each $T$ iterations, update $\mathbf{m}_K$, $\mathbf{m}_C$, $\mathbf{K}$, $\mathbf{C}$
Obtain $\mathbf{U}\otimes\mathbf{G}$ to approximate $\nabla_{\mu}^{2}\ell(\boldsymbol{\mu})$
$\mathbf{m}_K\leftarrow{\color{red}0}\,\mathbf{m}_K+\frac{1}{2d_o}\big({\color{red}d_o}\mathbf{H}_K+{\color{red}\lambda d_o}\mathbf{K}^{\top}\mathbf{K}-d_o\mathbf{I}_{d_i}\big)$
$\mathbf{m}_C\leftarrow{\color{red}0}\,\mathbf{m}_C+\frac{1}{2d_i}\big({\color{red}d_i}\mathbf{H}_C+{\color{red}\lambda d_i}\mathbf{C}^{\top}\mathbf{C}-d_i\mathbf{I}_{d_o}\big)$
$\mathbf{K}\leftarrow\mathbf{K}(\mathbf{I}_{d_i}-\beta_1\mathbf{m}_K)$
$\mathbf{C}\leftarrow\mathbf{C}(\mathbf{I}_{d_o}-\beta_1\mathbf{m}_C)$
2: $\mathbf{m}_{\mu}\leftarrow\alpha_2\mathbf{m}_{\mu}+\mathbf{C}\mathbf{C}^{\top}\mathrm{vec}^{-1}(\mathbf{g})\mathbf{K}\mathbf{K}^{\top}+\gamma\,\mathrm{vec}^{-1}(\boldsymbol{\mu})$
3: $\boldsymbol{\mu}\leftarrow\boldsymbol{\mu}-\beta_2\,\mathrm{vec}(\mathbf{m}_{\mu})$

Figure 3: Comparison between the KFAC and IKFAC updates for one weight matrix $\mathrm{vec}^{-1}(\boldsymbol{\mu})\in\mathbb{R}^{d_o\times d_i}$. The flattened gradient is $\mathbf{g}\coloneq\nabla_{\mu}\ell(\boldsymbol{\mu})\in\mathbb{R}^{d_o d_i}$, and $\mathrm{vec}^{-1}(\mathbf{g})\in\mathbb{R}^{d_o\times d_i}$ is its matrix reshape. IKFAC uses $\mathbf{H}_K\coloneq\mathbf{K}^{\top}\mathbf{U}\mathbf{K}$ and $\mathbf{H}_C\coloneq\mathbf{C}^{\top}\mathbf{G}\mathbf{C}$ to incorporate the Kronecker curvature factors $\mathbf{U}$ and $\mathbf{G}$.
Both methods use a momentum buffer $\mathbf{m}_{\mu}$ for the weight-decayed update direction, with momentum $\alpha_2$ and weight decay $\gamma$, and a learning rate $\beta_2$ for the parameter update. (Left) KFAC uses an exponential moving average with decay $1-\beta_1$ to accumulate the Kronecker factors and adds a damping term $\lambda\mathbf{I}$ before inversion to handle potential singularities in $\mathbf{S}_K$, $\mathbf{S}_C$. (Right) In contrast to KFAC, IKFAC directly approximates $(\mathbf{S}_K+\lambda\mathbf{I})^{-1}$ and $(\mathbf{S}_C+\lambda\mathbf{I})^{-1}$ by $\mathbf{K}\mathbf{K}^{\top}$ and $\mathbf{C}\mathbf{C}^{\top}$. The preconditioner update is a modification of INGD (Lin et al., 2023); the changes (zero Riemannian momentum, non-adaptive damping, and non-adaptive curvature) are highlighted in red.
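The two panels of Figure 3 differ only in how the preconditioning factors are obtained. The NumPy sketch below is our own illustration, not the authors' code: all function names are hypothetical, and $\mathbf{U}$, $\mathbf{G}$ are random symmetric positive-definite stand-ins for the Kronecker curvature. KFAC forms explicit damped inverses, whereas the IKFAC refresh only multiplies matrices. Assuming $\mathbf{U}$ and $\mathbf{G}$ stay fixed, repeating the IKFAC refresh drives $\mathbf{K}^{\top}(\mathbf{U}+\lambda\mathbf{I})\mathbf{K}$ towards $\mathbf{I}$, i.e. $\mathbf{K}\mathbf{K}^{\top}\to(\mathbf{U}+\lambda\mathbf{I})^{-1}$, which is KFAC's $\mathbf{S}_K^{-1}$ when $\beta_1=1$.

```python
import numpy as np


def kfac_refresh(S_K, S_C, U, G, beta1, lam):
    """KFAC step 1: EMA of the Kronecker factors, then explicit damped inverses."""
    S_K = (1 - beta1) * S_K + beta1 * U
    S_C = (1 - beta1) * S_C + beta1 * G
    S_K_inv = np.linalg.inv(S_K + lam * np.eye(S_K.shape[0]))  # requires a matrix inversion
    S_C_inv = np.linalg.inv(S_C + lam * np.eye(S_C.shape[0]))
    return S_K, S_C, S_K_inv, S_C_inv


def ikfac_refresh(K, C, U, G, beta1, lam):
    """IKFAC step 1: inverse-free update of K and C (zero Riemannian momentum)."""
    d_i, d_o = K.shape[0], C.shape[0]
    H_K = K.T @ U @ K
    H_C = C.T @ G @ C
    m_K = (d_o * H_K + lam * d_o * K.T @ K - d_o * np.eye(d_i)) / (2 * d_o)
    m_C = (d_i * H_C + lam * d_i * C.T @ C - d_i * np.eye(d_o)) / (2 * d_i)
    return K @ (np.eye(d_i) - beta1 * m_K), C @ (np.eye(d_o) - beta1 * m_C)


def momentum_step(mu_mat, m_mu, g_mat, P_C, P_K, alpha2, gamma, beta2):
    """Steps 2-3, shared by both methods.

    P_C, P_K are S_C^{-1}, S_K^{-1} for KFAC, or C C^T, K K^T for IKFAC.
    """
    m_mu = alpha2 * m_mu + P_C @ g_mat @ P_K + gamma * mu_mat
    return mu_mat - beta2 * m_mu, m_mu
```

In a training loop, the refresh functions would be called only when the iteration counter is a multiple of $T$. If $\mathrm{vec}$ stacks columns (our assumption about the convention), then $\mathrm{vec}(\mathbf{S}_C^{-1}\mathrm{vec}^{-1}(\mathbf{g})\mathbf{S}_K^{-1})=(\mathbf{S}_K^{-1}\otimes\mathbf{S}_C^{-1})\mathbf{g}$ for the symmetric damped inverses, so step 2 applies the Kronecker-factored preconditioner without ever forming a $d_od_i\times d_od_i$ matrix.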

INGD (Lin et al., 2023)

1: Each $T$ iterations, update $\mathbf{m}_K$, $\mathbf{m}_C$, $\mathbf{K}$, $\mathbf{C}$
Obtain $\mathbf{U}\otimes\mathbf{G}$ to approximate $\nabla_{\mu}^{2}\ell(\boldsymbol{\mu})$
$\mathbf{m}_K\leftarrow{\color{red}\alpha_1}\mathbf{m}_K+\frac{1}{2d_o}\big({\color{red}\mathrm{Tr}(\mathbf{H}_C)}\mathbf{H}_K+{\color{red}c^2}\mathbf{K}^{\top}\mathbf{K}-d_o\mathbf{I}_{d_i}\big)$
$\mathbf{m}_C\leftarrow{\color{red}\alpha_1}\mathbf{m}_C+\frac{1}{2d_i}\big({\color{red}\mathrm{Tr}(\mathbf{H}_K)}\mathbf{H}_C+{\color{red}\kappa^2}\mathbf{C}^{\top}\mathbf{C}-d_i\mathbf{I}_{d_o}\big)$
$\mathbf{K}\leftarrow\mathbf{K}(\mathbf{I}_{d_i}-\beta_1\mathbf{m}_K)$
$\mathbf{C}\leftarrow\mathbf{C}(\mathbf{I}_{d_o}-\beta_1\mathbf{m}_C)$
2: $\mathbf{m}_{\mu}\leftarrow\alpha_2\mathbf{m}_{\mu}+\mathbf{C}\mathbf{C}^{\top}\mathrm{vec}^{-1}(\mathbf{g})\mathbf{K}\mathbf{K}^{\top}+\gamma\,\mathrm{vec}^{-1}(\boldsymbol{\mu})$
3: $\boldsymbol{\mu}\leftarrow\boldsymbol{\mu}-\beta_2\,\mathrm{vec}(\mathbf{m}_{\mu})$

SINGD (ours)

1: Each $T$ iterations, update ${\color{blue}\hat{\mathcal{L}}}_{m_K}$, ${\color{blue}\hat{\mathcal{L}}}_{m_C}$, ${\color{blue}\hat{\mathcal{L}}}_K$, ${\color{blue}\hat{\mathcal{L}}}_C$
Obtain $\mathbf{U}\otimes\mathbf{G}$ to approximate $\nabla_{\mu}^{2}\ell(\boldsymbol{\mu})$
${\color{blue}\hat{\mathcal{L}}}_{m_K}\leftarrow{\color{red}\alpha_1}{\color{blue}\hat{\mathcal{L}}}_{m_K}+\frac{1}{2d_o}{\color{blue}\hat{\Pi}}_K{\color{blue}\big(}{\color{red}\mathrm{Tr}(\mathbf{H}_{{\color{blue}\hat{\mathcal{L}}}_C})}\mathbf{H}_{{\color{blue}\hat{\mathcal{L}}}_K}+{\color{red}c^2}({\color{blue}\hat{\mathcal{L}}}_K)^{\top}{\color{blue}\hat{\mathcal{L}}}_K-d_o\mathbf{I}_{d_i}{\color{blue}\big)}$
${\color{blue}\hat{\mathcal{L}}}_{m_C}\leftarrow{\color{red}\alpha_1}{\color{blue}\hat{\mathcal{L}}}_{m_C}+\frac{1}{2d_i}{\color{blue}\hat{\Pi}}_C{\color{blue}\big(}{\color{red}\mathrm{Tr}(\mathbf{H}_{{\color{blue}\hat{\mathcal{L}}}_K})}\mathbf{H}_{{\color{blue}\hat{\mathcal{L}}}_C}+{\color{red}\kappa^2}({\color{blue}\hat{\mathcal{L}}}_C)^{\top}{\color{blue}\hat{\mathcal{L}}}_C-d_i\mathbf{I}_{d_o}{\color{blue}\big)}$
${\color{blue}\hat{\mathcal{L}}}_K\leftarrow{\color{blue}\hat{\mathcal{L}}}_K(\mathbf{I}_{d_i}-\beta_1{\color{blue}\hat{\mathcal{L}}}_{m_K})$
${\color{blue}\hat{\mathcal{L}}}_C\leftarrow{\color{blue}\hat{\mathcal{L}}}_C(\mathbf{I}_{d_o}-\beta_1{\color{blue}\hat{\mathcal{L}}}_{m_C})$
2:  

𝐦μα2𝐦μ+^C(^C)vec1(𝐠)^K(^K)+γvec1(𝝁)subscript𝐦𝜇subscript𝛼2subscript𝐦𝜇subscript^𝐶superscriptsubscript^𝐶topsuperscriptvec1𝐠subscript^𝐾superscriptsubscript^𝐾top𝛾superscriptvec1𝝁\mbox{$\mbox{$\mathbf{m}$}$}_{\mu}\leftarrow\alpha_{2}\mbox{$\mbox{$\mathbf{m}% $}$}_{\mu}+{\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{0,0,1}% \hat{\mathcal{L}}}_{C}({\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{% rgb}{0,0,1}\hat{\mathcal{L}}}_{C})^{\top}\mathrm{vec}^{-1}(\mbox{$\mbox{$% \mathbf{g}$}$}){\color[rgb]{0,0,1}\definecolor[named]{pgfstrokecolor}{rgb}{% 0,0,1}\hat{\mathcal{L}}}_{K}({\color[rgb]{0,0,1}\definecolor[named]{% pgfstrokecolor}{rgb}{0,0,1}\hat{\mathcal{L}}}_{K})^{\top}+\gamma\mathrm{vec}^{% -1}(\mbox{$\mbox{$\boldsymbol{\mu}$}$})bold_m start_POSTSUBSCRIPT italic_μ end_POSTSUBSCRIPT ← italic_α start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT bold_m start_POSTSUBSCRIPT italic_μ end_POSTSUBSCRIPT + over^ start_ARG caligraphic_L end_ARG start_POSTSUBSCRIPT italic_C end_POSTSUBSCRIPT ( over^ start_ARG caligraphic_L end_ARG start_POSTSUBSCRIPT italic_C end_POSTSUBSCRIPT ) start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT roman_vec start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT ( bold_g ) over^ start_ARG caligraphic_L end_ARG start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ( over^ start_ARG caligraphic_L end_ARG start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ) start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT + italic_γ roman_vec start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT ( bold_italic_μ )

3: $\boldsymbol{\mu} \leftarrow \boldsymbol{\mu} - \beta_2\, \mathrm{vec}(\mathbf{m}_\mu)$
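Steps 2 and 3 above involve only matrix products. Below is a minimal NumPy sketch. It assumes both structured factors are diagonal, stored as vectors `l_c` (length $d_o$) and `l_k` (length $d_i$), and reads $\mathrm{vec}^{-1}(\mathbf{g})$ as the gradient reshaped to a $d_o \times d_i$ matrix `grad`. The function name and this storage layout are our own hypothetical choices. With diagonal factors, the two-sided preconditioning reduces to row and column scaling.

```python
import numpy as np

def singd_diag_param_step(W, m_mu, grad, l_c, l_k, alpha2, beta2, gamma):
    """Sketch of steps 2-3 for one weight matrix, assuming diagonal factors.

    W, m_mu, grad: arrays of shape (d_o, d_i); grad plays the role of vec^{-1}(g).
    l_c: diagonal of L_C, shape (d_o,); l_k: diagonal of L_K, shape (d_i,).
    """
    # Diagonal factors: L_C L_C^T = diag(l_c**2) and L_K L_K^T = diag(l_k**2),
    # so the two-sided preconditioning becomes row and column scaling.
    precond_grad = (l_c**2)[:, None] * grad * (l_k**2)[None, :]
    # Step 2: momentum buffer on the preconditioned, weight-decayed update.
    m_mu = alpha2 * m_mu + precond_grad + gamma * W
    # Step 3: parameter step with learning rate beta2.
    W = W - beta2 * m_mu
    return W, m_mu

rng = np.random.default_rng(0)
d_o, d_i = 3, 4
W0 = rng.standard_normal((d_o, d_i))
m0 = np.zeros((d_o, d_i))
grad = rng.standard_normal((d_o, d_i))
l_c = rng.uniform(0.5, 1.5, d_o)
l_k = rng.uniform(0.5, 1.5, d_i)
W1, m1 = singd_diag_param_step(W0, m0, grad, l_c, l_k, alpha2=0.9, beta2=0.1, gamma=1e-4)
```

For dense (block-)structured factors, the elementwise scaling would become the full products shown in step 2.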

Figure 4: Comparison of a single weight matrix's update between INGD and our extension SINGD, which uses structured Kronecker factors. (Left) INGD features Riemannian momentum ($\alpha_1$), adaptive curvature ($\mathrm{Tr}(\mathbf{H}_C)$, $\mathrm{Tr}(\mathbf{H}_K)$), adaptive damping ($c^2 := \lambda\, \mathrm{Tr}(\mathbf{C}^\top\mathbf{C})$, $\kappa^2 := \lambda\, \mathrm{Tr}(\mathbf{K}^\top\mathbf{K})$), and correlated updates of $\mathbf{K}$ and $\mathbf{C}$ ($\mathbf{m}_K$, $\mathbf{m}_C$). The preconditioner matrices are updated with learning rate $\beta_1$, and the optimizer keeps a momentum buffer, with momentum $\alpha_2$, on the weight-decayed update (weight decay $\gamma$). The parameters are updated with learning rate $\beta_2$.
(Right) SINGD's update is similar, but each Kronecker factor and its momentum ($\bullet$) is replaced by a structured version ($\hat{\mathcal{L}}_\bullet$, e.g. (block-)diagonal); the same replacement is applied when computing $c^2$, $\kappa^2$, $\mathbf{H}_K$, and $\mathbf{H}_C$. When the momenta are updated, their structure is preserved by a subspace projection map $\hat{\Pi}_\bullet(\cdot)$ that extracts the structure of $\hat{\mathcal{L}}_\bullet$ from a dense symmetric matrix $\cdot$ (e.g. by taking its (block) diagonal). Importantly, this projection can be computed without expanding its argument in dense form, which reduces memory and run time. The extension of IKFAC to SIKFAC is analogous. A notable property of INGD and SINGD is that they are invariant to how the scale is split between the two factors of the Kronecker approximation (see Appendix E), which matters because that approximation is not unique.
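To illustrate why the projection need not materialize its argument, consider the diagonal case. As a sketch, assume $\hat{\mathcal{L}}_K = \mathrm{diag}(\mathbf{k})$ and a KFAC-style factor $\mathbf{U} = \mathbf{X}^\top\mathbf{X}/n$ built from $n$ layer inputs $\mathbf{X}$. This form of $\mathbf{U}$ and the name `diag_proj_HK` are assumptions, not taken from the text. Then the diagonal of $\mathbf{H}_K = \mathbf{K}^\top\mathbf{U}\mathbf{K}$ is $\mathbf{k}^2 \odot \mathrm{diag}(\mathbf{U})$. Because $\mathrm{diag}(\mathbf{U})$ is a column-wise mean of squares, the dense $d_i \times d_i$ matrix is never formed.

```python
import numpy as np

def diag_proj_HK(k_diag, X):
    """Diagonal projection of H_K = K^T U K without forming U.

    Assumes K = diag(k_diag) and U = X^T X / n, where X has shape (n, d_i)
    (a KFAC-style input factor; this form of U is an assumption here).
    """
    # (K^T U K)_jj = k_j^2 * U_jj, and U_jj = mean over rows of X[:, j]^2,
    # so only O(d_i) extra memory is needed instead of a d_i x d_i matrix.
    return k_diag**2 * np.mean(X**2, axis=0)

rng = np.random.default_rng(0)
n, d_i = 32, 6
X = rng.standard_normal((n, d_i))
k_diag = rng.uniform(0.5, 2.0, d_i)
h_diag = diag_proj_HK(k_diag, X)
```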

Removing inversion

Lin et al. (2021) reparameterize the precision matrix $\mathbf{S}$ in a matrix logarithm space and perform natural gradient updates in this space, which turns matrix inversion into subtraction. One can map back to the original space with a truncated matrix exponential, without explicitly inverting any matrix. The resulting method is inverse-free and, since natural gradients are invariant to the choice of parameterization, Newton-like.
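The logarithm space helps for a simple reason. If $\mathbf{S} = \mathrm{Expm}(\mathbf{P})$ for a symmetric $\mathbf{P}$, then $\mathbf{S}^{-1} = \mathrm{Expm}(-\mathbf{P})$, so inverting $\mathbf{S}$ amounts to negating its log-space coordinate. The sketch below checks this numerically. It computes $\mathrm{Expm}$ through an eigendecomposition purely for brevity; the inverse-free methods discussed here avoid decompositions by truncating the exponential series. `expm_sym` is our own helper name.

```python
import numpy as np

def expm_sym(P):
    """Matrix exponential of a symmetric matrix via its eigendecomposition."""
    evals, evecs = np.linalg.eigh(P)
    # Scale column j of evecs by exp(evals[j]), then form V diag(exp) V^T.
    return (evecs * np.exp(evals)) @ evecs.T

rng = np.random.default_rng(0)
B = rng.standard_normal((4, 4))
P = 0.5 * (B + B.T)        # symmetric log-space coordinate
S = expm_sym(P)            # symmetric positive-definite matrix
S_inv = expm_sym(-P)       # "inversion" is just negation of P
```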

The first step is to express the precision matrix $\mathbf{S}$ through a non-singular square matrix $\mathbf{A}$ as $\mathbf{S} = \mathbf{A}^{-\top}\mathbf{A}^{-1}$ and to perform a natural gradient step, using the exact FIM, in a tangent space of $\mathbf{A}_t$ at iteration $t$ (with coordinates denoted by $\mathbf{M}$). We then construct a new map $\mathbf{A} := \boldsymbol{\phi}(\mathbf{A}_t, \mathbf{M}) := \mathbf{A}_t\, \mathrm{Expm}(\tfrac{1}{2}\mathbf{M})$ that takes both the current point $\mathbf{A}_t$ and $\mathbf{M}$ as input, where $\mathrm{Expm}(\mathbf{N}) = \mathbf{I} + \sum_{j=1}^{\infty} \mathbf{N}^j / j!$ is the matrix exponential. Observe that $\mathbf{M}$ lives in a matrix logarithm space.
At each iteration $t$, we use a new matrix logarithm space associated with $\mathbf{A}_t$ and place a new origin $\mathbf{M}_0 = \mathbf{0}$ in this space to represent $\mathbf{A}_t$, since $\mathbf{A}_t \equiv \boldsymbol{\phi}(\mathbf{A}_t, \mathbf{0}) = \mathbf{A}_t\, \mathrm{Expm}(\tfrac{1}{2}\mathbf{M}_0)$. The map $\boldsymbol{\phi}$ is a local reparameterization map that takes not only $\mathbf{M}$ but also $\mathbf{A}_t$ as input. Thanks to this map, the Fisher block is locally orthonormalized (Lin et al., 2023) at the origin $\mathbf{M}_0$. Since the origin represents $\mathbf{A}_t$ in the local coordinate $\mathbf{M}$, a natural gradient step becomes a (Euclidean) gradient step in the space of $\mathbf{M}$, which makes it easy to add Riemannian momentum (Lin et al., 2023) to the update of the structured positive-definite matrix $\mathbf{S}$.
This allows us to perform updates in the logarithm space of $\mathbf{M}$ and to avoid matrix inversions:

$$\mathbf{M} \leftarrow \mathbf{M}_0 - \beta\mathbf{N}, \qquad \boldsymbol{\mu} \leftarrow \boldsymbol{\mu} - \beta\, \mathbf{A}_{t+1}\mathbf{A}_{t+1}^\top \nabla_\mu \ell(\boldsymbol{\mu}; \mathbf{y}, \mathbf{X}), \tag{6}$$

where $\mathbf{A}_{t+1} := \boldsymbol{\phi}(\mathbf{A}_t, \mathbf{M}) = \mathbf{A}_t\, \mathrm{Expm}(\tfrac{1}{2}\mathbf{M})$ and $\mathbf{N} := \mathbf{A}_t^\top \nabla_\mu^2 \ell(\boldsymbol{\mu}; \mathbf{y}, \mathbf{X})\, \mathbf{A}_t - \mathbf{I}$. Equation 6 is a Newton-like update without matrix inversion. To see this, we re-express the update of $\mathbf{A}$ in terms of $\mathbf{S}$ and use properties of the matrix exponential, noting that $\mathbf{M}_0 = \mathbf{0}$ gives $\mathbf{M} = -\beta\mathbf{N}$ and that $\mathbf{N}$ is symmetric:

$$\begin{aligned} \mathbf{S}_{t+1} &= \mathbf{A}_{t+1}^{-\top}\mathbf{A}_{t+1}^{-1} = \mathbf{A}_t^{-\top}\, \mathrm{Expm}(\beta\mathbf{N})\, \mathbf{A}_t^{-1} \\ &= (1-\beta)\,\mathbf{S}_t + \beta\, \nabla_\mu^2 \ell(\boldsymbol{\mu}; \mathbf{y}, \mathbf{X}) + O(\beta^2). \end{aligned}$$
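The first-order claim can be checked numerically. The sketch below implements the update of $\mathbf{A}$ in Eq. (6) with $\mathbf{M}_0 = \mathbf{0}$, applied to a synthetic symmetric positive-definite Hessian (our assumption). It then measures the gap between $\mathbf{S}_{t+1}$ and $(1-\beta)\mathbf{S}_t + \beta\nabla_\mu^2\ell$. If the remainder is $O(\beta^2)$, shrinking $\beta$ by a factor of 10 should shrink the gap by roughly a factor of 100. The matrix inverses and the eigendecomposition-based `expm_sym` appear only to verify the identity; they are not part of the inverse-free update. All function names are hypothetical.

```python
import numpy as np

def expm_sym(P):
    """Matrix exponential of a symmetric matrix via its eigendecomposition."""
    evals, evecs = np.linalg.eigh(P)
    return (evecs * np.exp(evals)) @ evecs.T

def inverse_free_A_step(A_t, hess, beta):
    """Update of A from Eq. (6): M = M_0 - beta*N with M_0 = 0, A_{t+1} = A_t Expm(M/2)."""
    N = A_t.T @ hess @ A_t - np.eye(A_t.shape[0])   # N := A_t^T H A_t - I
    M = -beta * N
    return A_t @ expm_sym(0.5 * M)

def first_order_gap(A_t, hess, beta):
    """Norm of S_{t+1} - ((1 - beta) S_t + beta H); inverses used only for checking."""
    S_t = np.linalg.inv(A_t @ A_t.T)                 # S = A^{-T} A^{-1}
    A_next = inverse_free_A_step(A_t, hess, beta)
    S_next = np.linalg.inv(A_next @ A_next.T)
    return np.linalg.norm(S_next - ((1 - beta) * S_t + beta * hess))

rng = np.random.default_rng(0)
d = 4
A_t = np.eye(d) + 0.1 * rng.standard_normal((d, d))
B = rng.standard_normal((d, d))
hess = B @ B.T / d + np.eye(d)                       # synthetic SPD Hessian
ratio = first_order_gap(A_t, hess, 1e-2) / first_order_gap(A_t, hess, 1e-3)
```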

Next, we can construct a structured precision matrix $\mathbf{S}$, serving as a structured Hessian estimate, from a sparse non-singular matrix $\mathbf{A}$. As we discuss in Section 3.2, updating $\mathbf{M}$ is essential to preserve sparsity in $\mathbf{A}$. Treating $\mathbf{M}$ as a tangent (logarithm) space of $\mathbf{A}$ allows us to efficiently impose sparse structures on $\mathbf{A}$ without requiring the Hessian $\nabla_\mu^2 \ell(\boldsymbol{\mu}; \mathbf{y}, \mathbf{X})$ or its approximation to be sparse or structured. This differs from another inverse-free method (Tan, 2022), which performs NGD updates directly on $\mathbf{A}$ instead of $\mathbf{M}$ and must restrict $\mathbf{A}$ to a (triangular) Cholesky factor. That approach does not preserve sparsity in $\mathbf{A}$ unless the Hessian or its approximation admits a special structure, which is usually not the case in DL problems.
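The following minimal illustration shows why the update is performed in the space of $\mathbf{M}$. It assumes a lower-triangular Kronecker factor $\mathbf{K}$ and the first-order truncation $\mathrm{Expm}(-\beta\mathbf{m}) \approx \mathbf{I} - \beta\mathbf{m}$. Products of lower-triangular matrices stay lower-triangular, so a step projected onto the lower-triangular subspace (the Tril. projection of Table 2) keeps $\mathbf{K}$ sparse, while an unstructured symmetric step fills in the upper triangle. The helper names and the random data are illustrative only.

```python
import numpy as np

def truncated_update(K, m, beta):
    """K <- K (I - beta * m), i.e. Expm(-beta * m) truncated to first order."""
    return K @ (np.eye(K.shape[0]) - beta * m)

def proj_tril(M):
    """Tril. projection of a symmetric M: keep the diagonal, double the strictly lower part."""
    return np.diag(np.diag(M)) + 2.0 * np.tril(M, k=-1)

rng = np.random.default_rng(0)
d = 5
K = np.tril(rng.standard_normal((d, d))) + d * np.eye(d)   # sparse (lower-triangular) factor
B = rng.standard_normal((d, d))
m_dense = B + B.T                                          # unstructured symmetric step

K_structured = truncated_update(K, proj_tril(m_dense), 0.1)
K_dense = truncated_update(K, m_dense, 0.1)
```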

INGD

Our work builds on INGD (Figure 4), where $\mathbf{A} = \mathbf{K} \otimes \mathbf{C}$ is factorized into two Kronecker factors. The exact FIM under this parameterization is singular because of the correlation between $\mathbf{K}$ and $\mathbf{C}$: the Kronecker factorization is not unique. Lin et al. (2023) propose a (non-singular) block-diagonal approximation of the FIM that ignores this correlation, and perform NGD with it on the tangent spaces of the two factors. They further introduce Riemannian momentum into the updates of $\mathbf{K}$ and $\mathbf{C}$, use the Kronecker approximation discussed in Section 2.1 to approximate the Hessian $\nabla_\mu^2 \ell(\boldsymbol{\mu}; \mathbf{y}, \mathbf{X})$, and truncate the matrix exponential to obtain an update scheme based purely on matrix multiplications. It is unclear how INGD relates to KFAC, which uses a different Kronecker factorization, $\mathbf{S} = \mathbf{S}_K \otimes \mathbf{S}_C$. INGD also remains memory-inefficient because its Kronecker factors are dense. Its authors only evaluate it on convolution-based models in single precision, so it remains unclear whether INGD is useful for training transformer-based models, or in half precision.
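The non-uniqueness behind the singular FIM is easy to check. Rescaling $\mathbf{K}$ by any $c \neq 0$ and $\mathbf{C}$ by $1/c$ leaves $\mathbf{A} = \mathbf{K} \otimes \mathbf{C}$ unchanged. Equivalently, the Jacobian of $(\mathbf{K}, \mathbf{C}) \mapsto \mathbf{A}$ has a null direction $(\mathbf{K}, -\mathbf{C})$, so any Fisher matrix pulled back through it is singular. Here is a small NumPy sketch with random factors:

```python
import numpy as np

rng = np.random.default_rng(0)
K = rng.standard_normal((3, 3))
C = rng.standard_normal((2, 2))
c = 2.5

A = np.kron(K, C)
A_rescaled = np.kron(c * K, C / c)       # same A, different (K, C)

# Directional derivative of (K, C) -> kron(K, C) along (dK, dC) = (K, -C):
# by the product rule it is kron(dK, C) + kron(K, dC), which vanishes.
jvp = np.kron(K, C) + np.kron(K, -C)
```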

3 Structured inverse-free NGD

Inspired by INGD, we propose an inverse-free KFAC update, obtained as a specific setting of INGD, to address KFAC's numerical instability in low precision. We show that this scheme effectively recovers KFAC. We then address the memory inefficiency of KFAC and INGD for training transformer-based models by extending INGD with structured Kronecker factors.

3.1 Inverse-free KFAC Updates for Numerical Stability

Table 2: Subspaces of the logarithm space and their projection maps $\hat{\Pi}(\mathbf{M})$, where $\mathbf{M}$ is a symmetric matrix. The hierarchical structure is constructed by replacing the diagonal matrix $\mathbf{D}_{22}$ in the rank-$k$ upper-triangular structure with another rank-$k$ triangular matrix $\begin{bmatrix}\mathbf{A}_{22} & \mathbf{0} \\ \mathbf{A}_{32} & \mathbf{A}_{33}\end{bmatrix}$ for a better approximation.
Subspace of the log (Lie-algebraic) space | Matrix Lie subgroup structure in $\mathbf{K}$ | Subspace projection map $\hat{\Pi}(\mathbf{M})$

Lower-triangular (Tril.) | $\begin{bmatrix} a_{1,1} & 0 & \cdots & 0 \\ a_{2,1} & a_{2,2} & & 0 \\ \vdots & \vdots & \ddots & \vdots \\ a_{d_i,1} & a_{d_i,2} & \cdots & a_{d_i,d_i} \end{bmatrix}$ | $\begin{bmatrix} m_{1,1} & 0 & \cdots & 0 \\ 2m_{2,1} & m_{2,2} & & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 2m_{d_i,1} & 2m_{d_i,2} & \cdots & m_{d_i,d_i} \end{bmatrix}$

(Block) Diagonal (block size $k$) | $\begin{bmatrix} \mathbf{A}_{11} & \mathbf{0} & \cdots & \mathbf{0} \\ \mathbf{0} & \mathbf{A}_{22} & \cdots & \mathbf{0} \\ \vdots & \vdots & \ddots & \vdots \\ \mathbf{0} & \mathbf{0} & \cdots & \mathbf{A}_{qq} \end{bmatrix}$ | $\begin{bmatrix} \mathbf{M}_{11} & \mathbf{0} & \cdots & \mathbf{0} \\ \mathbf{0} & \mathbf{M}_{22} & \cdots & \mathbf{0} \\ \vdots & \vdots & \ddots & \vdots \\ \mathbf{0} & \mathbf{0} & \cdots & \mathbf{M}_{qq} \end{bmatrix}$

Hierarchical ($k := d_2 + d_3$) | $\begin{bmatrix} \mathbf{A}_{11} & \mathbf{A}_{12} & \mathbf{A}_{13} \\ \mathbf{0} & \mathbf{A}_{22} & \mathbf{0} \\ \mathbf{0} & \mathbf{A}_{32} & \mathbf{A}_{33} \end{bmatrix}$, $\mathbf{A}_{22}$ diagonal, $\mathbf{A}_{11} \in \mathbb{R}^{d_2 \times d_2}$, $\mathbf{A}_{33} \in \mathbb{R}^{d_3 \times d_3}$ | $\begin{bmatrix} \mathbf{M}_{11} & 2\mathbf{M}_{12} & 2\mathbf{M}_{13} \\ \mathbf{0} & \mathrm{Diag}(\mathbf{M}_{22}) & \mathbf{0} \\ \mathbf{0} & 2\mathbf{M}_{32} & \mathbf{M}_{33} \end{bmatrix}$

Rank-$k$ upper-triangular | $\begin{bmatrix} \mathbf{A}_{11} & \mathbf{A}_{12} \\ \mathbf{0} & \mathbf{D}_{22} \end{bmatrix}$, $\mathbf{D}_{22}$ diagonal, $\mathbf{A}_{11} \in \mathbb{R}^{k \times k}$ | $\begin{bmatrix} \mathbf{M}_{11} & 2\mathbf{M}_{12} \\ \mathbf{0} & \mathrm{Diag}(\mathbf{M}_{22}) \end{bmatrix}$

Upper-triangular Toeplitz (Triu-Toepl.) | $\begin{bmatrix} a_0 & a_1 & a_2 & \cdots & a_{d_i-1} \\ 0 & a_0 & a_1 & \ddots & \vdots \\ 0 & 0 & \ddots & \ddots & a_2 \\ \vdots & \ddots & \ddots & \ddots & a_1 \\ 0 & \cdots & \cdots & 0 & a_0 \end{bmatrix}$ | $\begin{bmatrix} b_0 & 2b_1 & 2b_2 & \cdots & 2b_{d_i-1} \\ 0 & b_0 & 2b_1 & \ddots & \vdots \\ 0 & 0 & \ddots & \ddots & 2b_2 \\ \vdots & \ddots & \ddots & \ddots & 2b_1 \\ 0 & \cdots & \cdots & 0 & b_0 \end{bmatrix}$ with $b_j := \frac{1}{d_i - j}\sum_{k=1}^{d_i - j} m_{k,k+j}$
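The three elementwise projection maps of Table 2 can be sketched in NumPy as follows. This is a direct, dense transcription of the table for a symmetric input `M`. It is written for clarity rather than efficiency and does materialize `M`. The function names are hypothetical, and letting the last block be smaller when $k$ does not divide $d$ is our own choice.

```python
import numpy as np

def proj_tril(M):
    """Tril.: keep the diagonal m_ii, double the strictly lower entries m_ij."""
    return np.diag(np.diag(M)) + 2.0 * np.tril(M, k=-1)

def proj_block_diag(M, k):
    """(Block) diagonal with block size k; the last block is smaller if k does not divide d."""
    d = M.shape[0]
    out = np.zeros_like(M, dtype=float)
    for start in range(0, d, k):
        stop = start + k
        out[start:stop, start:stop] = M[start:stop, start:stop]
    return out

def proj_triu_toeplitz(M):
    """Triu-Toepl.: b_j averages the j-th superdiagonal; b_j is doubled for j >= 1."""
    d = M.shape[0]
    out = np.zeros_like(M, dtype=float)
    for j in range(d):
        b_j = np.mean(np.diagonal(M, offset=j))   # (1/(d-j)) * sum_k m_{k,k+j}
        scale = 1.0 if j == 0 else 2.0
        out += scale * b_j * np.eye(d, k=j)       # place on the j-th superdiagonal
    return out

M = np.array([[1.0, 2.0, 3.0],
              [2.0, 4.0, 5.0],
              [3.0, 5.0, 6.0]])
```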
[Figure 5 panels: columns $\mathbf{K}$, $\mathbf{K}\mathbf{K}^\top$, $(\mathbf{K}\mathbf{K}^\top)^{-1}$; rows Dense, Diagonal, Block-diag., Tril-Toepl., Triu-Toepl., Hierarchical, Sparse Triu. (two variants), Sparse Tril. (two variants).]
Figure 5: Illustration of structured matrices (Kronecker factors) supported by SINGD, their self-outer product (approximate inverse Hessian factor), and its inverse (approximate Hessian factor). With rank-one triangular matrices $\mathbf{K}$, we can easily impose a low-rank structure on $\mathbf{K}\mathbf{K}^\top$ or $(\mathbf{K}\mathbf{K}^\top)^{-1}$; the latter is difficult to achieve with other approaches.
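The structural pattern behind Figure 5 can be reproduced for two of the structures. In this sketch with random factors (helper names and sizes are illustrative), a block-diagonal $\mathbf{K}$ yields a block-diagonal $\mathbf{K}\mathbf{K}^\top$ and $(\mathbf{K}\mathbf{K}^\top)^{-1}$. A lower-triangular $\mathbf{K}$, by contrast, generally yields a dense $\mathbf{K}\mathbf{K}^\top$.

```python
import numpy as np

def random_block_diag(d, k, rng):
    """Hypothetical helper: random non-singular block-diagonal K with block size k."""
    K = np.zeros((d, d))
    for s in range(0, d, k):
        e = min(s + k, d)
        K[s:e, s:e] = rng.standard_normal((e - s, e - s)) + 3.0 * np.eye(e - s)
    return K

rng = np.random.default_rng(0)
d, k = 6, 2
block_mask = np.kron(np.eye(d // k), np.ones((k, k))) != 0   # block-diagonal pattern

K_block = random_block_diag(d, k, rng)
KKt_block = K_block @ K_block.T
KKt_block_inv = np.linalg.inv(KKt_block)

K_tril = np.tril(rng.standard_normal((d, d))) + 3.0 * np.eye(d)
KKt_tril = K_tril @ K_tril.T
```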

We first propose a new inverse-free update that mimics the behavior of the KFAC update; we call this update IKFAC. We then show that IKFAC corresponds to a specific setting of INGD. This bridges the gap between INGD and KFAC and sheds light on the difference between the two methods.

Inspired by INGD, we replace matrix inversion with matrix subtraction in a matrix logarithm space, then map back to the original space with a truncated matrix exponential, without explicitly inverting any matrix. IKFAC is related to KFAC because we use $\mathbf{K}\mathbf{K}^\top$ and $\mathbf{C}\mathbf{C}^\top$ to approximate KFAC's inverse Kronecker factors $(\mathbf{S}_K + \lambda\mathbf{I})^{-1}$ and $(\mathbf{S}_C + \lambda\mathbf{I})^{-1}$, respectively. We propose the following IKFAC update for $\mathbf{K}$ and $\mathbf{C}$, with learning rate $\beta_1$ and a truncated matrix exponential:

$$\mathbf{K}^{\text{new}} \leftarrow \mathbf{K}\left(\mathbf{I} - \tfrac{\beta_1}{2}\,\mathbf{m}_K\right), \qquad \mathbf{C}^{\text{new}} \leftarrow \mathbf{C}\left(\mathbf{I} - \tfrac{\beta_1}{2}\,\mathbf{m}_C\right), \tag{7}$$

where $\mathbf{H}_K := \mathbf{K}^\top\mathbf{U}\mathbf{K}$, $\mathbf{H}_C := \mathbf{C}^\top\mathbf{G}\mathbf{C}$, $\mathbf{m}_K := \mathbf{H}_K + \lambda\mathbf{K}^\top\mathbf{K} - \mathbf{I}$, and $\mathbf{m}_C := \mathbf{H}_C + \lambda\mathbf{C}^\top\mathbf{C} - \mathbf{I}$. This update is free of matrix inversions and matrix decompositions.
Since we truncate the matrix exponential, $\mathrm{Expm}(-\tfrac{\beta_1}{2}\mathbf{m}_K) \approx \mathbf{I} - \tfrac{\beta_1}{2}\mathbf{m}_K$, the matrix $\mathbf{m}_K$ indeed lives in a matrix logarithm space (see Appendix C). This logarithm space allows us to impose the structural constraints on $\mathbf{K}$ discussed in Section 3.2.
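The update in Eq. (7) can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the function name `ikfac_update`, the dense matrices, and the random curvature factors `U` and `G` are hypothetical choices for demonstration. The check at the end relies on $\mathbf{m}_K = \mathbf{K}^\top(\mathbf{U}+\lambda\mathbf{I})\mathbf{K} - \mathbf{I}$ vanishing whenever $\mathbf{K}\mathbf{K}^\top = (\mathbf{U}+\lambda\mathbf{I})^{-1}$, so such a $\mathbf{K}$ should be a fixed point; the inverse and Cholesky factor only build that test case and are never used inside the update.

```python
import numpy as np


def ikfac_update(K, C, U, G, lam, beta1):
    """One IKFAC step (Eq. 7) for both Kronecker factors.

    Only matrix products and additions are used: no inverse, no decomposition.
    """
    def step(F, M):
        I = np.eye(F.shape[0], dtype=F.dtype)
        H = F.T @ M @ F                  # H_K := K^T U K  (or H_C := C^T G C)
        m = H + lam * (F.T @ F) - I      # m_K := H_K + lam K^T K - I  (or m_C)
        # Truncated matrix exponential: Expm(-beta1/2 * m) ~ I - beta1/2 * m
        return F @ (I - 0.5 * beta1 * m)

    return step(K, U), step(C, G)


# Hypothetical test case: random SPD curvature factors U (d_i x d_i) and G (d_o x d_o).
rng = np.random.default_rng(0)
d_i, d_o, lam = 4, 3, 0.1
A = rng.standard_normal((d_i, d_i))
B = rng.standard_normal((d_o, d_o))
U, G = A @ A.T / d_i, B @ B.T / d_o

# If K K^T = (U + lam I)^{-1}, then m_K = 0 and K should not move.
# The inverse and Cholesky factor only build this check; the update never uses them.
K_star = np.linalg.cholesky(np.linalg.inv(U + lam * np.eye(d_i)))
C_star = np.linalg.cholesky(np.linalg.inv(G + lam * np.eye(d_o)))
K_new, C_new = ikfac_update(K_star, C_star, U, G, lam, beta1=0.1)
```

Because each factor is touched only through products with itself and its curvature matrix, the same code runs unchanged in lower-precision dtypes, which is the property the inverse-free formulation is after.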

The following theorem (proof in Appendix D) formally shows that $\mathbf{K}\mathbf{K}^\top$ in IKFAC approximates KFAC's $(\mathbf{S}_K+\lambda\mathbf{I})^{-1}$ at every step, even with a truncated matrix exponential. Similarly, $\mathbf{C}\mathbf{C}^\top$ approximates $(\mathbf{S}_C+\lambda\mathbf{I})^{-1}$. Thus, IKFAC effectively recovers KFAC up to first-order accuracy.

Theorem 1.

If $\mathbf{K}$ is updated according to the IKFAC scheme (Figure 4) with the truncated matrix exponential, and IKFAC and KFAC use the same initialization and the same sequence of curvature matrices $\mathbf{U}$, then the product $\mathbf{K}\mathbf{K}^\top$ is a first-order accurate approximation of KFAC's $(\mathbf{S}_K+\lambda\mathbf{I})^{-1}$ at each iteration, i.e., $\mathbf{K}\mathbf{K}^\top = (\mathbf{S}_K+\lambda\mathbf{I})^{-1} + O(\beta_1^2)$.
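Theorem 1 can be probed numerically for a single step. The sketch below assumes that KFAC maintains $\mathbf{S}_K$ as an exponential moving average, $\mathbf{S}_K \leftarrow (1-\beta_1)\mathbf{S}_K + \beta_1\mathbf{U}$ (an assumption about the KFAC variant, not stated in this section), starts both methods from the same point, and measures $\|\mathbf{K}\mathbf{K}^\top - (\mathbf{S}_K+\lambda\mathbf{I})^{-1}\|$ after one step. If the gap is $O(\beta_1^2)$, shrinking $\beta_1$ tenfold should shrink the gap roughly a hundredfold. All function names are hypothetical.

```python
import numpy as np


def ikfac_step(K, U, lam, beta1):
    """IKFAC update of K (Eq. 7) with the truncated matrix exponential."""
    I = np.eye(K.shape[0])
    m_K = K.T @ U @ K + lam * (K.T @ K) - I
    return K @ (I - 0.5 * beta1 * m_K)


def kfac_step(S, U, beta1):
    """Assumed KFAC moving average: S_K <- (1 - beta1) S_K + beta1 U."""
    return (1.0 - beta1) * S + beta1 * U


def one_step_gap(beta1, d=5, lam=1.0, seed=0):
    """Norm of K_new K_new^T - (S_new + lam I)^{-1} after one step from a shared start."""
    rng = np.random.default_rng(seed)
    X, Y = rng.standard_normal((2, d, d))
    S, U = X @ X.T / d, Y @ Y.T / d      # current factor and new curvature estimate
    I = np.eye(d)
    # Shared initialization: K K^T = (S + lam I)^{-1}; the inverse is only used for checking.
    K = np.linalg.cholesky(np.linalg.inv(S + lam * I))
    K_new = ikfac_step(K, U, lam, beta1)
    target = np.linalg.inv(kfac_step(S, U, beta1) + lam * I)
    return np.linalg.norm(K_new @ K_new.T - target)


for beta1 in (1e-1, 1e-2, 1e-3):
    print(f"beta1={beta1:g}  gap={one_step_gap(beta1):.2e}")
```

Under this assumed KFAC variant, the first-order terms of both updates coincide, and the remaining discrepancy comes from the second-order terms, which is why the gap scales with $\beta_1^2$ rather than $\beta_1$.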

Theorem 1 trivially extends to diagonal and block-diagonal structures, i.e., KFAC with diagonal or block-diagonal Kronecker factors is equivalent to IKFAC with diagonal or block-diagonal structure up to first order in $\beta_1$.

Now, we show that IKFAC is a special case of INGD, whose update of $\mathbf{K}$ without Riemannian momentum ($\alpha_1 = 0$) is

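$$\mathbf{K}^{\text{new}} \leftarrow \mathbf{K}\left(\mathbf{I} - \tfrac{\beta_1}{2}\,\mathbf{m}_K\right), \qquad \mathbf{m}_K := \frac{\mathrm{Tr}(\mathbf{H}_C)}{d_o}\,\mathbf{H}_K + \frac{\lambda\,\mathrm{Tr}(\mathbf{C}^\top\mathbf{C})}{d_o}\,\mathbf{K}^\top\mathbf{K} - \mathbf{I}. \tag{8}$$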

Since $\mathrm{Tr}(\mathbf{I}_{d_o}) = d_o$, $\mathbf{H}_C \in \mathbb{R}^{d_o\times d_o}$, $\mathbf{C} \in \mathbb{R}^{d_o\times d_o}$, and $\mathbf{K} \in \mathbb{R}^{d_i\times d_i}$, we can obtain IKFAC from INGD by simply replacing $\mathrm{Tr}(\mathbf{H}_C)$ and $\mathrm{Tr}(\mathbf{C}^\top\mathbf{C})$ with $\mathrm{Tr}(\mathbf{I}_{d_o})$:

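$$\mathbf{m}_K := \frac{\mathrm{Tr}(\mathbf{I}_{d_o})}{d_o}\,\mathbf{H}_K + \frac{\lambda\,\mathrm{Tr}(\mathbf{I}_{d_o})}{d_o}\,\mathbf{K}^\top\mathbf{K} - \mathbf{I} = \mathbf{H}_K + \lambda\mathbf{K}^\top\mathbf{K} - \mathbf{I}. \tag{9}$$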

This reveals the difference between the two methods. In IKFAC (see Appendix C for details), $\mathbf{H}_K$ and $\lambda\mathbf{K}^\top\mathbf{K}$ incorporate KFAC's curvature $\mathbf{U}$ and damping $\lambda\mathbf{I}$, respectively. In INGD, by contrast, the curvature and damping are incorporated adaptively through $(\mathrm{Tr}(\mathbf{H}_C)/d_o)\,\mathbf{H}_K$ and $(\lambda\,\mathrm{Tr}(\mathbf{C}^\top\mathbf{C})/d_o)\,\mathbf{K}^\top\mathbf{K}$. The trace terms couple the updates of $\mathbf{K}$ and $\mathbf{C}$ in INGD, whereas IKFAC updates $\mathbf{K}$ and $\mathbf{C}$ independently, just as KFAC updates $\mathbf{S}_K$ and $\mathbf{S}_C$ independently. These trace terms are needed to satisfy the orthonormalization condition of the Fisher matrix (Lin et al., 2023). They also make INGD and SINGD invariant to the scale of the Kronecker approximation (see Appendix E), which matters because the factorization is not unique: for any $c > 0$, $\mathbf{U}\otimes\mathbf{G} = (c\,\mathbf{U})\otimes(c^{-1}\mathbf{G})$.
KFAC and IKFAC, in contrast, are not scale-invariant: they lack both the trace terms and Riemannian momentum ($\alpha_1 > 0$). Our experiments show that these components can contribute to stability.
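The scale invariance discussed in this paragraph can be checked directly. The sketch below is a hypothetical illustration: it builds INGD's $\mathbf{m}_K$ (without momentum) from the trace-weighted terms $(\mathrm{Tr}(\mathbf{H}_C)/d_o)\,\mathbf{H}_K$ and $(\lambda\,\mathrm{Tr}(\mathbf{C}^\top\mathbf{C})/d_o)\,\mathbf{K}^\top\mathbf{K}$, assuming the same $-\mathbf{I}$ offset as in IKFAC, and compares it with IKFAC's $\mathbf{m}_K$ after rescaling $\mathbf{U} \to c\,\mathbf{U}$ and $\mathbf{G} \to \mathbf{G}/c$, which leaves the Kronecker product unchanged. It also checks that choosing $\mathbf{C} = \mathbf{G} = \mathbf{I}$, so that both traces equal $d_o$, reduces INGD's $\mathbf{m}_K$ to IKFAC's. The function names are illustrative only.

```python
import numpy as np


def m_K_ingd(K, C, U, G, lam):
    """INGD-style m_K without momentum: trace weights couple K and C."""
    d_i, d_o = K.shape[0], C.shape[0]
    H_K, H_C = K.T @ U @ K, C.T @ G @ C
    curvature = (np.trace(H_C) / d_o) * H_K
    damping = (lam * np.trace(C.T @ C) / d_o) * (K.T @ K)
    return curvature + damping - np.eye(d_i)


def m_K_ikfac(K, U, lam):
    """IKFAC's m_K: both traces replaced by Tr(I_{d_o}) = d_o."""
    return K.T @ U @ K + lam * (K.T @ K) - np.eye(K.shape[0])


rng = np.random.default_rng(0)
d_i, d_o, lam, c = 4, 3, 0.1, 10.0
K, A = rng.standard_normal((2, d_i, d_i))
C, B = rng.standard_normal((2, d_o, d_o))
U, G = A @ A.T, B @ B.T

# The Kronecker approximation is only defined up to scale: U (x) G == (cU) (x) (G/c).
same_curvature = np.allclose(np.kron(U, G), np.kron(c * U, G / c))
ingd_invariant = np.allclose(m_K_ingd(K, C, U, G, lam), m_K_ingd(K, C, c * U, G / c, lam))
ikfac_invariant = np.allclose(m_K_ikfac(K, U, lam), m_K_ikfac(K, c * U, lam))
# With C = G = I, Tr(H_C) = Tr(C^T C) = d_o and INGD's m_K reduces to IKFAC's.
reduces = np.allclose(m_K_ingd(K, np.eye(d_o), U, np.eye(d_o), lam), m_K_ikfac(K, U, lam))
print(same_curvature, ingd_invariant, ikfac_invariant, reduces)  # True True False True
```

The invariance comes from $\mathrm{Tr}(\mathbf{H}_C)$ scaling by $1/c$ exactly when $\mathbf{H}_K$ scales by $c$, so their product is unchanged; IKFAC has no such compensating factor.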

3.2 Sparse Kronecker Factors for Reducing Memory

Figure 6: Test error curves for mixed-precision training of transformer-based models in BFP-16 on 'CIFAR-100' and 'ImageWoof-10'. SINGD performs as well as INGD while being memory-efficient, and SINGD, including IKFAC and INGD as special cases, outperforms AdamW in most cases. We omit KFAC since it is unstable in BFP-16. The hierarchical structure often performs as well as the dense structure and outperforms the block-diagonal structure.

Now, we extend INGD to reduce its memory consumption and iteration cost. Existing sparse KFAC methods use (block-)diagonal structures for $\mathbf{S}_K$ and $\mathbf{S}_C$ (Zhang et al., 2019; Grosse et al., 2023). In contrast, we propose using sparse Kronecker factors $\mathbf{K}$ and $\mathbf{C}$ in INGD, exploiting Lie-algebraic properties in the logarithm space and the algebraic sparsity of the Kronecker factors. This enables more flexible structures (Figure 5) that can potentially achieve better downstream performance than (block-)diagonal structures in $\mathbf{S}_K$ and $\mathbf{S}_C$.

Other related works are Lie group preconditioners (Li, 2018, 2022), originally derived by directly approximating the inverse Hessian. Some Hessian-vector-product-based versions of these methods can be expensive and unavailable in pure low-precision settings, because they sample random weights and solve linear systems, which is unstable in low precision. Our approach is sampling-free and works in pure half precision.

We want to construct sparse factors $\mathbf{K}$ and $\mathbf{C}$ without requiring the Kronecker (Hessian) approximation $\mathbf{U}\otimes\mathbf{G}$ to be sparse or structured. Imposing sparsity often leads to a complicated FIM, which makes NGD difficult because it requires inverting the FIM. To impose sparsity on $\mathbf{K}$, it is essential to update $\mathbf{m}_K$ in the logarithm space of $\mathbf{K}$, because in this (moving) coordinate the FIM simplifies to the identity matrix due to the orthonormalization condition. This condition (Lin et al., 2023) makes it easy to impose a range of sparse structures on $\mathbf{K}$ through a unified, inverse-free update rule (Figure 4), since we avoid inverting the Fisher block associated with the sparse structure. We also exploit the algebraic sparsity of these structures to make our rule more efficient than INGD (Table 3).

We exploit Lie-algebraic properties in the logarithm space, where $\mathbf{m}_K$ lives, to construct sparse structures of $\mathbf{K}$. As a general design principle, we consider structures of $\mathbf{K}$ that are preserved under (i) elementwise matrix operations (subtraction and scalar multiplication) and (ii) matrix multiplication, both of which our updates require. Concretely, we construct a new local reparameterization of $\mathbf{K}$ at iteration $t$ via

$$\mathbf{K} := \boldsymbol{\psi}(\mathbf{K}_t, \mathbf{m}_K) := \mathbf{K}_t\,\mathrm{Expm}\left(\frac{1}{\sqrt{2d_i}}\,\hat{\Pi}_K(\mathbf{m}_K)\right),$$

where $\hat{\Pi}_K(\mathbf{m}_K)$ projects the dense $\mathbf{m}_K$ onto a subspace (and identically for $\mathbf{C}$, potentially with a different structure $\hat{\Pi}_C$).
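A minimal sketch of the local reparameterization $\boldsymbol{\psi}$, assuming a truncated Taylor series as a stand-in for $\mathrm{Expm}$ and plain `np.tril` as a placeholder projection onto lower-triangular matrices (the weighted projection maps are those summarized in Table 2; `expm_taylor` and `psi` are hypothetical names). It illustrates two properties used in the text: $\boldsymbol{\psi}(\mathbf{K}_t, \mathbf{0}) = \mathbf{K}_t$, so the coordinates are centered at the current iterate, and a lower-triangular $\mathbf{K}_t$ stays lower-triangular, because the exponential of a triangular matrix is a series of products of triangular matrices.

```python
import numpy as np


def expm_taylor(X, terms=30):
    """Matrix exponential via a truncated Taylor series; adequate for small ||X||."""
    out = np.eye(X.shape[0])
    term = np.eye(X.shape[0])
    for k in range(1, terms):
        term = term @ X / k          # accumulates X^k / k!
        out = out + term
    return out


def psi(K_t, m_K, project):
    """Local reparameterization K = K_t Expm(project(m_K) / sqrt(2 d_i))."""
    d_i = K_t.shape[0]
    return K_t @ expm_taylor(project(m_K) / np.sqrt(2.0 * d_i))


rng = np.random.default_rng(0)
d_i = 4
K_t = np.tril(rng.standard_normal((d_i, d_i)))   # current lower-triangular iterate
M = rng.standard_normal((d_i, d_i))
m_K = 0.1 * (M + M.T)                            # small symmetric point in log space

# np.tril is a placeholder for a projection onto lower-triangular matrices.
K_center = psi(K_t, np.zeros((d_i, d_i)), np.tril)   # equals K_t
K = psi(K_t, m_K, np.tril)                           # still lower-triangular
```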

Many popular structures, such as tri-diagonal matrices, do not satisfy our requirements because they are not closed under matrix multiplication. Moreover, it can be difficult to construct a projection map that satisfies the orthonormalization condition. One subspace structure that satisfies the requirements is the set of upper (or lower) triangular matrices. The subspace projection $\hat{\Pi}_K$ is then a weighted extraction map, since projecting the logarithm space onto this subspace amounts to projecting a dense square matrix onto a triangular one. Technically, we use

$$\mathbf{A} := \mathbf{K}_t\,\mathrm{Expm}\left(\frac{\hat{\Pi}_K(\mathbf{m}_K)}{\sqrt{2d_i}}\right) \otimes \mathbf{C}_t$$

to update $\mathbf{K}$ at iteration $t$, treating $\mathbf{C}_t$ and $\mathbf{K}_t$ as constants. Given a subspace $\Omega_K \subset \mathbb{R}^{d_i\times d_i}$ in the matrix logarithm space, the subspace projection map $\hat{\Pi}_K: \mathrm{Sym}^{d_i\times d_i} \to \Omega_K$ is specified by requiring the local orthonormalization condition of the Fisher block with respect to $\mathbf{m}_K$:

$$F\big|_{\mathbf{m}_K=\mathbf{0}} := -\mathbb{E}_{\mathbf{w}\sim q}\left[\nabla^2_{\mathbf{m}_K}\log q(\mathbf{w}\mid\boldsymbol{\mu},\mathbf{S})\right]\Big|_{\mathbf{m}_K=\mathbf{0}} = \mathbf{I},$$

with the variational Gaussian $q(\mathbf{w}\mid\boldsymbol{\mu},\mathbf{S})$ with mean $\boldsymbol{\mu}$ and precision $\mathbf{S} := \mathbf{A}^{-\top}\mathbf{A}^{-1}$, where $\mathrm{Sym}^{d_i\times d_i}$ denotes the set of real symmetric $d_i\times d_i$ matrices. Similarly, we obtain $\hat{\Pi}_C$ for $\mathbf{C}$.
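The closure requirements stated in this subsection are easy to test numerically. The sketch below (helper names are hypothetical) shows that products and linear combinations of lower-triangular matrices remain lower-triangular, whereas the product of two tri-diagonal matrices generally has a nonzero second off-diagonal band, so tri-diagonal matrices fail requirement (ii) even though they satisfy (i).

```python
import numpy as np


def is_lower_triangular(M):
    return np.allclose(M, np.tril(M))


def band(M, width=1):
    """Keep only entries with |i - j| <= width (width=1 gives a tri-diagonal matrix)."""
    return np.triu(np.tril(M, width), -width)


def is_tridiagonal(M):
    return np.allclose(M, band(M, 1))


rng = np.random.default_rng(0)
d = 5
L1, L2 = (np.tril(rng.standard_normal((d, d))) for _ in range(2))
T1, T2 = (band(rng.standard_normal((d, d))) for _ in range(2))

print(is_lower_triangular(L1 @ L2), is_lower_triangular(L1 - 0.5 * L2))  # True True
print(is_tridiagonal(T1 @ T2), is_tridiagonal(T1 - 0.5 * T2))            # False True
```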

We consider several sparsity patterns and block extensions of triangular matrices, illustrated in Figure 5. For example, the subspace projection map for a diagonal structure simply extracts the diagonal entries of its input. As a non-trivial example, the projection map for a lower-triangular structure extracts the lower-triangular entries of its input and multiplies the entries below the main diagonal by 2. Table 2 summarizes the structures and their projection maps.
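The two projection maps described in the preceding paragraph can be written in a few lines. This sketch covers only the diagonal and lower-triangular cases given in the text (the remaining structures are listed in Table 2), and the function names are hypothetical. For a symmetric input $\mathbf{M}$, the lower-triangular projection $\mathbf{L}$ satisfies $\mathbf{L} + \mathbf{L}^\top = 2\mathbf{M}$, which is one way to read the doubling: no information about the symmetric input is discarded.

```python
import numpy as np


def project_diagonal(M):
    """Diagonal structure: keep only the diagonal entries of M."""
    return np.diag(np.diag(M))


def project_lower_triangular(M):
    """Lower-triangular structure: keep the lower triangle, doubling entries below the diagonal."""
    return 2.0 * np.tril(M, -1) + np.diag(np.diag(M))


M = np.array([[1.0, 2.0, 3.0],
              [2.0, 4.0, 5.0],
              [3.0, 5.0, 6.0]])        # symmetric input from the log space
print(project_diagonal(M))          # diag(1, 4, 6)
print(project_lower_triangular(M))  # [[1, 0, 0], [4, 4, 0], [6, 10, 6]]
```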

Using such a subspace and its projection map, we obtain a structured INGD update (Figure 4), and similarly for IKFAC. Our approach allows for more expressive structures than the block-diagonal structure shown in Figure 5, e.g., low-rank, flexible hierarchical, and Toeplitz structures, whereas existing methods mainly support low-rank structures. For an efficient implementation, we compute and store only the non-zero entries of $\hat{\Pi}_K(\mathbf{m}_K)$ and $\mathbf{K}$, without explicitly forming dense matrices. These structures reduce not only memory consumption (Table 4) but also iteration cost (Table 3).
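To illustrate the memory savings, the sketch below implements a diagonal version of Eq. (7) that stores only the vector $\mathbf{k}$ with $\mathbf{K} = \mathrm{diag}(\mathbf{k})$. It assumes that the diagonal structure simply projects $\mathbf{m}_K$ onto its diagonal; in that case $\mathbf{m}_K$ reduces to $\mathbf{k}^2 \odot (\mathrm{diag}(\mathbf{U}) + \lambda) - 1$, only the diagonal of $\mathbf{U}$ is needed, and memory drops from $O(d_i^2)$ to $O(d_i)$. The name `ikfac_diag_step` is hypothetical, and the check compares against a dense computation that explicitly forms $\mathbf{K}$.

```python
import numpy as np


def ikfac_diag_step(k, u_diag, lam, beta1):
    """Diagonal IKFAC step storing only k, where K = diag(k).

    Projecting m_K onto its diagonal gives m = k**2 * (diag(U) + lam) - 1,
    so only diag(U) is needed: O(d_i) memory instead of O(d_i**2).
    """
    m = k**2 * (u_diag + lam) - 1.0
    return k * (1.0 - 0.5 * beta1 * m)


rng = np.random.default_rng(0)
d_i, lam, beta1 = 6, 0.1, 0.05
A = rng.standard_normal((d_i, d_i))
U = A @ A.T / d_i
k = rng.uniform(0.5, 1.5, size=d_i)

k_new = ikfac_diag_step(k, np.diag(U), lam, beta1)

# Dense reference: form K explicitly and project m_K onto its diagonal.
K, I = np.diag(k), np.eye(d_i)
m_dense = K.T @ U @ K + lam * (K.T @ K) - I
K_ref = K @ (I - 0.5 * beta1 * np.diag(np.diag(m_dense)))
```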

4 Experiments

We evaluate SINGD on convolutional, transformer, and graph NNs, using mixed-precision training in BFP-16 with KFAC-reduce (Eschenhagen et al., 2023) and numerical tricks (Dangel, 2023) to further reduce memory consumption and iteration cost for convolutions. The performance metric is test error. To be memory-efficient, we consider SINGD with sparse structures such as 'diagonal', 'block-diagonal', and 'hierarchical'. We also consider IKFAC, INGD (recall that SINGD with a dense structure becomes INGD), and AdamW as baselines. All methods except KFAC directly support training in BFP-16. For KFAC, we have to cast a matrix to FP-32 before inverting it and then cast its inverse back to BFP-16. We find that KFAC is unstable in BFP-16. For 'VGG' and 'ConvMixer', we also consider SGD as a strong baseline. We fix momentum to 0.9 and tune the other hyper-parameters of each optimizer using random search. For 'VGG' and 'ConvMixer', we decrease the learning rate $\beta_2$ every 40 epochs. For 'GNN', we use a constant learning rate; all other models use a cosine learning rate schedule. We consider KFAC as a strong baseline for the GNN, as suggested by Izadi et al. (2020), and train the GNN in FP-32 so that KFAC performs stably. The search space for the random search is given in Table 5 in Appendix B.
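For KFAC, the text describes casting a factor to FP-32 before inversion and casting the inverse back to BFP-16. NumPy has no bfloat16 dtype, so the sketch below emulates it by rounding float32 bit patterns to their upper 16 bits (round-to-nearest-even, ignoring NaN and infinity). The helpers `round_to_bf16` and `kfac_inverse_bf16` are hypothetical and are not part of any library or of the authors' code.

```python
import numpy as np


def round_to_bf16(x):
    """Round float32 values to bfloat16 precision (round-to-nearest-even), returned as float32.

    Emulation only: keeps the upper 16 bits of each float32; NaN/Inf are not handled.
    """
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    bias = ((bits >> np.uint32(16)) & np.uint32(1)) + np.uint32(0x7FFF)
    return ((bits + bias) & np.uint32(0xFFFF0000)).view(np.float32)


def kfac_inverse_bf16(S_bf16, lam):
    """Upcast the factor to FP-32, invert the damped matrix, round the inverse back to BF16."""
    d = S_bf16.shape[0]
    S32 = S_bf16.astype(np.float32) + np.float32(lam) * np.eye(d, dtype=np.float32)
    return round_to_bf16(np.linalg.inv(S32))


rng = np.random.default_rng(0)
d, lam = 8, 0.1
X = rng.standard_normal((d, d))
S_bf16 = round_to_bf16(X @ X.T / d + np.eye(d))   # Kronecker factor held in (emulated) BF16
P_bf16 = kfac_inverse_bf16(S_bf16, lam)          # (S + lam I)^{-1}, stored in BF16
```

Only the storage of the inverse is in BF16 here; the inversion itself runs in FP-32, which is the extra casting step that the inverse-free methods avoid.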

From Figures 6 and 7, we observe that SINGD, including IKFAC and INGD as special cases, outperforms AdamW in many cases and works well for mixed-precision training. We omit KFAC from the BFP-16 plots because it is unstable due to numerical issues. We also observe that the hierarchical structure often performs as well as the dense structure (INGD) across models, and in several cases it outperforms the block-diagonal and diagonal structures. On the models shown in Figure 7, however, even the diagonal structure can perform as well as the dense one. Thus, we can reduce INGD's memory consumption while keeping SINGD competitive with AdamW. We also train a ViT model on 'ImageNet-100' to demonstrate the superior performance of SINGD over AdamW in a large-scale setting (see Figure 9 in Appendix B).

Figure 7: Test error curves for mixed-precision training of CNN and GNN models on 'ImageWoof-10', 'CIFAR-100', and 'Cora'. 'Rep-ViT' is a CNN model inspired by transformers. SINGD performs as well as INGD while being memory-efficient. SINGD, including IKFAC and INGD as special cases, outperforms AdamW on all models. The diagonal structure can perform as well as the dense structure on these models. KFAC only appears in the rightmost plot, since it is unstable in the other plots due to numerical issues in half precision.

5 Conclusion

We propose an inverse-free, memory-efficient natural gradient descent method, SINGD, which addresses the numerical instability and memory inefficiency of second-order methods like KFAC (Martens & Grosse, 2015). The algorithm extends the inverse-free natural gradient (INGD) method of Lin et al. (2023), whose update relies only on matrix multiplications. We theoretically establish the algorithm's relation to KFAC by showing that a modification of INGD effectively performs KFAC-like updates, and we further improve its memory efficiency through sparse Kronecker factors. We show that SINGD supports low-precision training and often outperforms AdamW on transformer-based models. Our work expands the scope of second-order methods to training transformer-based NNs in low precision, making them more widely applicable.

Acknowledgements

Resources used in preparing this research were provided, in part, by the Province of Ontario, the Government of Canada through CIFAR, and companies sponsoring Vector Institute. Runa Eschenhagen is supported by ARM and the Cambridge Trust. Richard E. Turner is supported by Google, Amazon, ARM, Improbable and EPSRC grant EP/T005386/1.

Impact Statement

This paper presents work whose goal is to advance the field of Machine Learning. There are many potential societal consequences of our work, none of which we feel must be specifically highlighted here.

References

  • Amari (1998) Amari, S.-I. Natural gradient works efficiently in learning. Neural computation, 10(2):251–276, 1998.
  • Bae et al. (2022) Bae, J., Ng, N., Lo, A., Ghassemi, M., and Grosse, R. B. If influence functions are the answer, then what is the question? In NeurIPS, 2022.
  • Botev et al. (2017) Botev, A., Ritter, H., and Barber, D. Practical Gauss-Newton optimisation for deep learning. In ICML, 2017.
  • Bradbury et al. (2018) Bradbury, J., Frostig, R., Hawkins, P., Johnson, M. J., Leary, C., Maclaurin, D., Necula, G., Paszke, A., VanderPlas, J., Wanderman-Milne, S., and Zhang, Q. JAX: composable transformations of Python+NumPy programs, 2018. URL http://github.com/google/jax.
  • Brown et al. (2020) Brown, T., Mann, B., Ryder, N., Subbiah, M., Kaplan, J. D., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., et al. Language models are few-shot learners. In NeurIPS, 2020.
  • Dangel (2023) Dangel, F. Convolutions through the lens of tensor networks. arXiv 2307.02275, 2023.
  • Daxberger et al. (2021) Daxberger, E., Kristiadi, A., Immer, A., Eschenhagen, R., Bauer, M., and Hennig, P. Laplace redux—effortless Bayesian deep learning. In NeurIPS, 2021.
  • Dehghani et al. (2023) Dehghani, M., Djolonga, J., Mustafa, B., Padlewski, P., Heek, J., Gilmer, J., Steiner, A. P., Caron, M., Geirhos, R., Alabdulmohsin, I., et al. Scaling vision transformers to 22 billion parameters. In ICML, 2023.
  • Eschenhagen et al. (2023) Eschenhagen, R., Immer, A., Turner, R. E., Schneider, F., and Hennig, P. Kronecker-Factored Approximate Curvature for modern neural network architectures. In NeurIPS, 2023.
  • George et al. (2018) George, T., Laurent, C., Bouthillier, X., Ballas, N., and Vincent, P. Fast approximate natural gradient descent in a Kronecker factored eigenbasis. In NeurIPS, 2018.
  • Graves (2011) Graves, A. Practical variational inference for neural networks. In NeurIPS, 2011.
  • Grosse & Martens (2016) Grosse, R. and Martens, J. A Kronecker-factored approximate Fisher matrix for convolution layers. In ICML, 2016.
  • Grosse et al. (2023) Grosse, R., Bae, J., Anil, C., Elhage, N., Tamkin, A., Tajdini, A., Steiner, B., Li, D., Durmus, E., Perez, E., et al. Studying large language model generalization with influence functions. arXiv preprint arXiv:2308.03296, 2023.
  • Hassani et al. (2021) Hassani, A., Walton, S., Shah, N., Abuduweili, A., Li, J., and Shi, H. Escaping the big data paradigm with compact transformers. arXiv preprint arXiv:2104.05704, 2021.
  • Hatamizadeh et al. (2023) Hatamizadeh, A., Yin, H., Heinrich, G., Kautz, J., and Molchanov, P. Global context vision transformers. In International Conference on Machine Learning, pp.  12633–12646. PMLR, 2023.
  • Heskes (2000) Heskes, T. On “natural” learning and pruning in multilayered perceptrons. Neural Computation, 12(4), 2000.
  • Immer et al. (2021) Immer, A., Bauer, M., Fortuin, V., Rätsch, G., and Emtiyaz, K. M. Scalable marginal likelihood estimation for model selection in deep learning. In ICML, 2021.
  • Izadi et al. (2020) Izadi, M. R., Fang, Y., Stevenson, R., and Lin, L. Optimization of graph neural networks with natural gradient descent. In 2020 IEEE international conference on big data (big data), pp.  171–179. IEEE, 2020.
  • Khan & Lin (2017) Khan, M. and Lin, W. Conjugate-computation variational inference: Converting variational inference in non-conjugate models to inferences in conjugate models. In Artificial Intelligence and Statistics, pp.  878–887, 2017.
  • Khan & Nielsen (2018) Khan, M. E. and Nielsen, D. Fast yet Simple Natural-Gradient Descent for Variational Inference in Complex Models. arXiv preprint arXiv:1807.04489, 2018.
  • Khan & Rue (2021) Khan, M. E. and Rue, H. The Bayesian learning rule. arXiv preprint arXiv:2107.04562, 2021.
  • Khan et al. (2018) Khan, M. E., Nielsen, D., Tangkaratt, V., Lin, W., Gal, Y., and Srivastava, A. Fast and scalable Bayesian deep learning by weight-perturbation in Adam. In ICML, 2018.
  • Kingma & Ba (2015) Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. In International Conference on Learning Representations, 2015.
  • Kipf & Welling (2016) Kipf, T. N. and Welling, M. Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907, 2016.
  • Kunstner et al. (2019) Kunstner, F., Balles, L., and Hennig, P. Limitations of the empirical Fisher approximation for natural gradient descent. In NeurIPS, 2019.
  • Li (2022) Li, X. Black box Lie group preconditioners for SGD. arXiv preprint arXiv:2211.04422, 2022.
  • Li (2018) Li, X.-L. Preconditioner on matrix Lie group for SGD. In International Conference on Learning Representations, 2018.
  • Lin et al. (2020) Lin, W., Schmidt, M., and Khan, M. E. Handling the positive-definite constraint in the Bayesian learning rule. In ICML, 2020.
  • Lin et al. (2021) Lin, W., Nielsen, F., Emtiyaz, K. M., and Schmidt, M. Tractable structured natural-gradient descent using local parameterizations. In ICML, 2021.
  • Lin et al. (2023) Lin, W., Duruisseaux, V., Leok, M., Nielsen, F., Khan, M. E., and Schmidt, M. Simplifying momentum-based positive-definite submanifold optimization with applications to deep learning. In ICML, 2023.
  • Liu et al. (2021) Liu, Z., Lin, Y., Cao, Y., Hu, H., Wei, Y., Zhang, Z., Lin, S., and Guo, B. Swin transformer: Hierarchical vision transformer using shifted windows. In Proceedings of the IEEE/CVF international conference on computer vision, pp.  10012–10022, 2021.
  • Loshchilov & Hutter (2019) Loshchilov, I. and Hutter, F. Decoupled weight decay regularization. In ICLR, 2019.
  • Lu et al. (2022) Lu, Z., Xie, H., Liu, C., and Zhang, Y. Bridging the gap between vision transformers and convolutional neural networks on small datasets. Advances in Neural Information Processing Systems, 35:14663–14677, 2022.
  • Martens (2014) Martens, J. New insights and perspectives on the natural gradient method. JMLR, 21(146), 2014.
  • Martens & Grosse (2015) Martens, J. and Grosse, R. Optimizing neural networks with Kronecker-factored approximate curvature. In ICML, 2015.
  • Martens et al. (2018) Martens, J., Ba, J., and Johnson, M. Kronecker-factored curvature approximations for recurrent neural networks. In ICLR, 2018.
  • Micikevicius et al. (2018) Micikevicius, P., Narang, S., Alben, J., Diamos, G., Elsen, E., Garcia, D., Ginsburg, B., Houston, M., Kuchaiev, O., Venkatesh, G., and Wu, H. Mixed precision training. In International Conference on Learning Representations (ICLR), 2018.
  • Osawa et al. (2019) Osawa, K., Swaroop, S., Khan, M. E. E., Jain, A., Eschenhagen, R., Turner, R. E., and Yokota, R. Practical deep learning with Bayesian principles. In NeurIPS, 2019.
  • Osawa et al. (2023) Osawa, K., Li, S., and Hoefler, T. PipeFisher: Efficient training of large language models using pipelining and Fisher information matrices. In MLSys, 2023.
  • Osborne (1992) Osborne, M. R. Fisher’s method of scoring. International Statistical Review/Revue Internationale de Statistique, pp.  99–117, 1992.
  • Paszke et al. (2019) Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L., et al. PyTorch: An imperative style, high-performance deep learning library. In NeurIPS, 2019.
  • Radford et al. (2019) Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., Sutskever, I., et al. Language models are unsupervised multitask learners. OpenAI blog, 1(8):9, 2019.
  • Robbins & Monro (1951) Robbins, H. and Monro, S. A Stochastic Approximation Method. The Annals of Mathematical Statistics, 1951.
  • Schraudolph (2002) Schraudolph, N. N. Fast curvature matrix-vector products for second-order gradient descent. Neural Computation, 14(7), 2002.
  • Simonyan & Zisserman (2014) Simonyan, K. and Zisserman, A. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.
  • Smyth (1996) Smyth, G. K. Partitioned algorithms for maximum likelihood and other non-linear estimation. Statistics and Computing, 6:201–216, 1996.
  • Smyth (2015) Smyth, G. K. Optimization and nonlinear equations. Statistics Reference Online, 1:1–9, 2015.
  • Tan (2022) Tan, L. S. Analytic natural gradient updates for Cholesky factor in Gaussian variational approximation. arXiv preprint arXiv:2109.00375, 2022.
  • Thompson et al. (2020) Thompson, N. C., Greenewald, K., Lee, K., and Manso, G. F. The computational limits of deep learning. 2020.
  • Touvron et al. (2023) Touvron, H., Lavril, T., Izacard, G., Martinet, X., Lachaux, M.-A., Lacroix, T., Rozière, B., Goyal, N., Hambro, E., Azhar, F., et al. LLaMA: Open and efficient foundation language models. arXiv preprint arXiv:2302.13971, 2023.
  • Trockman & Kolter (2023) Trockman, A. and Kolter, J. Z. Patches are all you need? Transactions on Machine Learning Research, 2023.
  • Vaswani et al. (2017) Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., and Polosukhin, I. Attention is all you need. In NIPS, 2017.
  • Wang et al. (2023) Wang, A., Chen, H., Lin, Z., Pu, H., and Ding, G. RepViT: Revisiting mobile CNN from ViT perspective. arXiv preprint arXiv:2307.09283, 2023.
  • Wang et al. (2019) Wang, C., Grosse, R., Fidler, S., and Zhang, G. EigenDamage: Structured pruning in the Kronecker-factored eigenbasis. In ICML, 2019.
  • Wang (2010) Wang, Y. Fisher scoring: An interpolation family and its Monte Carlo implementations. Computational Statistics & Data Analysis, 54(7), 2010.
  • Zhang et al. (2018) Zhang, G., Sun, S., Duvenaud, D., and Grosse, R. Noisy natural gradient as variational inference. In ICML, 2018.
  • Zhang et al. (2019) Zhang, G., Li, L., Nado, Z., Martens, J., Sachdeva, S., Dahl, G. E., Shallue, C. J., and Grosse, R. B. Which algorithmic choices matter at which batch sizes? Insights from a noisy quadratic model. In NeurIPS, 2019.

Appendix A Space and Time Complexity

Iteration cost:

| Method | $\triangle\boldsymbol{\mu}$ (descent direction) | Update $\mathbf{S}_K$ or $\mathbf{K}$ | Update $\mathbf{S}_C$ or $\mathbf{C}$ | $\nabla_\mu\ell$ (BackProp) |
|---|---|---|---|---|
| KFAC | $O(d_i^2 d_o + d_o^2 d_i)$ | $O(\frac{1}{T}(m d_i^2 + d_i^3))$ | $O(\frac{1}{T}(m d_o^2 + d_o^3))$ | $O(m d_i d_o)$ |
| INGD/SINGD (Dense) | $O(d_i^2 d_o + d_o^2 d_i)$ | $O(\frac{1}{T}(m d_i^2 + d_i^3))$ | $O(\frac{1}{T}(m d_o^2 + d_o^3))$ | $O(m d_i d_o)$ |
| SINGD (Block-Diag. with block size $k$) | $O(k d_i d_o)$ | $O(\frac{1}{T}(k m d_i))$ | $O(\frac{1}{T}(k m d_o))$ | $O(m d_i d_o)$ |
| SINGD (Toeplitz) | $O(d_i d_o \log(d_o d_i))$ | $O(\frac{1}{T}(m d_i \log d_i))$ | $O(\frac{1}{T}(m d_o \log d_o))$ | $O(m d_i d_o)$ |
| SINGD (Rank-1 Triangular) | $O(d_i d_o)$ | $O(\frac{1}{T}(m d_i))$ | $O(\frac{1}{T}(m d_o))$ | $O(m d_i d_o)$ |
| SINGD (Hierarchical with parameter $k$) | $O(k d_i d_o)$ | $O(\frac{1}{T}(k m d_i))$ | $O(\frac{1}{T}(k m d_o))$ | $O(m d_i d_o)$ |
| AdamW | $O(d_i d_o)$ | NA | NA | $O(m d_i d_o)$ |
Table 3: Iteration cost for a non-weight-sharing layer, where $m$ is the mini-batch size, $T$ is the number of iterations between preconditioner updates, and $\boldsymbol{\mu} \in \mathbb{R}^{d_i \times d_o}$ is a learnable weight matrix. We assume that factors $\mathbf{K}$ and $\mathbf{C}$ use the same structure.
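To illustrate where the logarithmic factor in the Toeplitz row can come from, here is a minimal NumPy sketch of a Toeplitz matrix-vector product computed in $O(n \log n)$ time by embedding the matrix in a $2n \times 2n$ circulant matrix and using the FFT. This is a standard technique and our assumption about the source of the log factor, not necessarily how SINGD implements its Toeplitz factors; the function names `toeplitz_matvec` and `toeplitz_dense` are hypothetical, and the sketch treats a general (not necessarily triangular) Toeplitz matrix.

```python
import numpy as np


def toeplitz_matvec(c, r, x):
    """Compute T @ x for the n x n Toeplitz matrix T with first column c and
    first row r (r[0] must equal c[0]) in O(n log n) time via the FFT."""
    n = len(x)
    # First column of a 2n x 2n circulant matrix whose top-left n x n block is T.
    v = np.concatenate([c, [0.0], r[:0:-1]])
    # A circulant matrix-vector product is an element-wise product in Fourier space.
    x_padded = np.concatenate([x, np.zeros(n)])
    y = np.fft.ifft(np.fft.fft(v) * np.fft.fft(x_padded))
    return y[:n].real


def toeplitz_dense(c, r):
    """Dense O(n^2) reference: T[i, j] = c[i - j] if i >= j else r[j - i]."""
    n = len(c)
    return np.array([[c[i - j] if i >= j else r[j - i] for j in range(n)]
                     for i in range(n)])


rng = np.random.default_rng(0)
n = 8
c = rng.standard_normal(n)
r = rng.standard_normal(n)
r[0] = c[0]  # row and column share the (0, 0) entry
x = rng.standard_normal(n)
print(np.max(np.abs(toeplitz_matvec(c, r, x) - toeplitz_dense(c, r) @ x)))
```

Under this assumption, applying a Toeplitz $\mathbf{K}$ to all $d_o$ columns and a Toeplitz $\mathbf{C}$ to all $d_i$ rows of a $d_i \times d_o$ matrix costs $O(d_i d_o(\log d_i + \log d_o)) = O(d_i d_o \log(d_i d_o))$, the same order as the Toeplitz entry for $\triangle\boldsymbol{\mu}$.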
Memory usage:

| Method | $\nabla_\mu\ell \odot \nabla_\mu\ell$ | $\mathbf{S}_K$ or $\mathbf{K}$ | $\mathbf{S}_C$ or $\mathbf{C}$ |
|---|---|---|---|
| KFAC | NA | $O(d_i^2)$ | $O(d_o^2)$ |
| INGD/SINGD (Dense) | NA | $O(d_i^2)$ | $O(d_o^2)$ |
| SINGD (Block-Diag. with block size $k$) | NA | $O(k d_i)$ | $O(k d_o)$ |
| SINGD (Toeplitz) | NA | $O(d_i)$ | $O(d_o)$ |
| SINGD (Rank-1 Triangular) | NA | $O(d_i)$ | $O(d_o)$ |
| SINGD (Hierarchical with parameter $k$) | NA | $O(k d_i)$ | $O(k d_o)$ |
| AdamW | $O(d_i d_o)$ | NA | NA |
Table 4: Additional storage required by each method (same notation as Table 3).
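As a rough illustration of the storage columns for $\mathbf{K}$ and $\mathbf{C}$, the hypothetical helper below counts the entries that a single $d \times d$ factor would store under a few structures. We assume that a general Toeplitz factor stores its first row and column ($2d-1$ values) and that the block size $k$ divides $d$; the rank-1 triangular and hierarchical structures are omitted because their exact parameterizations are not given in this appendix. The layer sizes are illustrative and not taken from the experiments.

```python
def factor_storage(d, structure, k=None):
    """Number of stored entries for a d x d Kronecker factor (hypothetical helper)."""
    if structure == "dense":
        return d * d
    if structure == "block_diag":
        if k is None or d % k != 0:
            raise ValueError("block size k must divide d")
        return (d // k) * k * k  # d/k blocks of size k x k, i.e. k * d entries
    if structure == "toeplitz":
        return 2 * d - 1  # first row and first column (assumed general Toeplitz)
    if structure == "diagonal":
        return d
    raise ValueError(f"unknown structure: {structure}")


# Illustrative layer with d_i = 768 inputs and d_o = 3072 outputs.
d_i, d_o = 768, 3072
for s, k in [("dense", None), ("block_diag", 32), ("toeplitz", None), ("diagonal", None)]:
    total = factor_storage(d_i, s, k) + factor_storage(d_o, s, k)
    print(f"{s:>10}: {total:,} entries for K and C (weights: {d_i * d_o:,})")
```

For these sizes, the dense factors store more entries than the weight matrix itself, whereas the structured factors add only a small fraction of it.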

Appendix B Details of the Experiments

To demonstrate the robustness and memory efficiency of our method, we consider image classification tasks with transformer-based models such as “Compact-ViT” (Hassani et al., 2021), “Swin-ViT” (Liu et al., 2021), “GC-ViT” (Hatamizadeh et al., 2023), and “HDVT” (Lu et al., 2022). We also consider convolution-based models such as “VGG” (Simonyan & Zisserman, 2014), “ConvMixer” (Trockman & Kolter, 2023), and “Rep-ViT” (Wang et al., 2023). We train these models on the CIFAR-100 and ImageWoof-10 datasets. Note that “Rep-ViT” is a CNN inspired by transformers, while “Compact-ViT” is a data-efficient transformer that uses convolutional tokenization. We also consider a graph convolutional network (Kipf & Welling, 2016), denoted by “GNN”, for node classification on the Cora dataset. Finally, we train a vision transformer (GCViT) on ImageNet-100 (https://www.kaggle.com/datasets/ambityga/imagenet100) to demonstrate the performance of SINGD in a larger-scale setting (see Fig. 9).

B.1 Hyperparameter Tuning

| Hyperparameter | Meaning | KFAC/IKFAC/SINGD in Figures 4 and 8 | AdamW in Figure 8 |
|---|---|---|---|
| $\beta_2$ | Standard stepsize | Tuned | Tuned |
| $\alpha_2$ | Standard momentum weight | 0.9 | 0.9 |
| $\gamma$ | (L2) weight decay | Tuned | Tuned |
| $\lambda$ | Damping | Tuned | Tuned |
| $\beta_1$ | Stepsize for preconditioner | Tuned | Tuned |
| $\alpha_1$ | Riemannian momentum (SINGD only) | Tuned | NA |
Table 5: Hyperparameters in the random search. "Tuned" entries are selected by the search; "NA" means that the hyperparameter does not apply.
Table 6: Peak memory and run time of different optimizers for GCViT on ImageWoof-10 (Figure 6, right). Parenthesized values are normalized relative to SGD. For this vision transformer task, we observe that backpropagation dominates both run time and memory. In this setting, all our methods as well as INGD add only a small overhead over the first-order methods (at most about 2% more peak memory and 7% more training time). INGD and our proposed methods even achieve lower test error than AdamW and SGD. INGD, KFAC, and SINGD update their preconditioners every $T=5$ iterations.
| Method | Peak memory [GiB] | Training time [min] |
|---|---|---|
| SGD (BFP-16) | 15.6 (1.00x) | 190 (1.00x) |
| AdamW (BFP-16) | 15.7 (1.00x) | 191 (1.01x) |
| SINGD-Diag* (BFP-16) | 15.8 (1.02x) | 200 (1.06x) |
| IKFAC* (BFP-16) | 16.0 (1.02x) | 197 (1.04x) |
| INGD (BFP-16) | 16.0 (1.02x) | 203 (1.07x) |
| KFAC (FP-32) | 16.0 (1.02x) | 359 (1.89x) |

INGD
1: Every $T$ iterations, update $\mathbf{m}_K$, $\mathbf{m}_C$, $\mathbf{K}$, $\mathbf{C}$:
   Obtain $\boldsymbol{\mu}_{AA} \otimes \boldsymbol{\mu}_{GG}$ to approximate $\nabla_\mu^2\ell(\boldsymbol{\mu})$
   $\mathbf{m}_K \leftarrow \alpha_1\mathbf{m}_K + \frac{1}{2d}\big(\mathrm{Tr}(\mathbf{H}_C)\mathbf{H}_K + c^2\mathbf{K}^T\mathbf{K} - d\,\mathbf{I}_p\big)$
   $\mathbf{m}_C \leftarrow \alpha_1\mathbf{m}_C + \frac{1}{2p}\big(\mathrm{Tr}(\mathbf{H}_K)\mathbf{H}_C + \kappa^2\mathbf{C}^T\mathbf{C} - p\,\mathbf{I}_d\big)$
   $\mathbf{K} \leftarrow \mathbf{K}\,\mathrm{Expm}(-\beta_1\mathbf{m}_K) \approx \mathbf{K}(\mathbf{I}_p - \beta_1\mathbf{m}_K)$
   $\mathbf{C} \leftarrow \mathbf{C}\,\mathrm{Expm}(-\beta_1\mathbf{m}_C) \approx \mathbf{C}(\mathbf{I}_d - \beta_1\mathbf{m}_C)$
2: $\mathbf{M}_\mu \leftarrow \alpha_2\mathbf{M}_\mu + \mathbf{C}\mathbf{C}^T\,\mathrm{vec}^{-1}(\nabla_\mu\ell(\boldsymbol{\mu}))\,\mathbf{K}\mathbf{K}^T + \gamma\,\mathrm{vec}^{-1}(\boldsymbol{\mu})$
3: $\boldsymbol{\mu} \leftarrow \boldsymbol{\mu} - \beta_2\,\mathrm{vec}(\mathbf{M}_\mu)$

AdamW Optimizer
1: At iteration $t$, update $\mathbf{m}_s$, $\mathbf{s}$:
   Use $(\nabla_\mu\ell(\boldsymbol{\mu}))^2$ to approximate $\mathrm{diag}(\nabla_\mu^2\ell(\boldsymbol{\mu}))$
   $\mathbf{m}_s \leftarrow (1-\beta_1)\mathbf{m}_s + \beta_1(\nabla_\mu\ell(\boldsymbol{\mu}))^2$
   $\mathbf{s}^2 \leftarrow \mathbf{m}_s / \big(1 - (1-\beta_1)^t\big)$
   $\mathbf{s} \leftarrow \sqrt{\mathbf{s}^2} + \lambda$
2: $\mathbf{m}_\mu \leftarrow \alpha_2\mathbf{m}_\mu + (1-\alpha_2)\nabla_\mu\ell(\boldsymbol{\mu})$
   $\mathbf{M}_\mu \leftarrow \mathbf{s}^{-1}\mathbf{m}_\mu / (1 - \alpha_2^t)$
3: $\boldsymbol{\mu} \leftarrow \boldsymbol{\mu} - \beta_2(\mathbf{M}_\mu + \gamma\boldsymbol{\mu})$

Figure 8: Baseline methods (INGD and AdamW) written in the same notation, as used for the hyperparameter search.
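The INGD steps in Figure 8 can be sketched for a single weight matrix as follows. This is a minimal NumPy illustration, not the authors' implementation. $\mathbf{H}_K$, $\mathbf{H}_C$, $c^2$, and $\kappa^2$ are not defined in this appendix; we assume $\mathbf{H}_K = \mathbf{K}^T\mathbf{U}\mathbf{K}$, $\mathbf{H}_C = \mathbf{C}^T\mathbf{G}\mathbf{C}$, $c^2 = \lambda\,\mathrm{Tr}(\mathbf{C}^T\mathbf{C})$, and $\kappa^2 = \lambda\,\mathrm{Tr}(\mathbf{K}^T\mathbf{K})$, where `U` ($p \times p$) and `G` ($d \times d$) are hypothetical names for the two Kronecker curvature estimates. The function name `ingd_step` and the default hyperparameter values are also hypothetical. The sketch uses the first-order truncation $\mathbf{K}(\mathbf{I}_p - \beta_1\mathbf{m}_K)$ from Figure 8 instead of an exact matrix exponential.

```python
import numpy as np


def ingd_step(mu, grad, U, G, K, C, m_K, m_C, M_mu, *, beta1=1e-2, beta2=1e-2,
              alpha1=0.5, alpha2=0.9, lam=1e-3, gamma=0.0, update_precond=True):
    """One INGD iteration (Figure 8) for a weight matrix mu of shape (d, p).

    Assumed (hypothetical) inputs: U (p x p) and G (d x d) are the Kronecker
    curvature estimates; H_K = K^T U K, H_C = C^T G C,
    c^2 = lam * Tr(C^T C), kappa^2 = lam * Tr(K^T K).
    """
    d, p = mu.shape
    if update_precond:  # Figure 8 performs this block only every T iterations
        H_K = K.T @ U @ K
        H_C = C.T @ G @ C
        c2 = lam * np.trace(C.T @ C)
        kappa2 = lam * np.trace(K.T @ K)
        # Momentum (weight alpha1) on the symmetric matrices m_K and m_C;
        # both use the old K and C.
        m_K = alpha1 * m_K + (np.trace(H_C) * H_K + c2 * (K.T @ K) - d * np.eye(p)) / (2 * d)
        m_C = alpha1 * m_C + (np.trace(H_K) * H_C + kappa2 * (C.T @ C) - p * np.eye(d)) / (2 * p)
        # Truncated matrix exponential: no matrix inverse or decomposition needed.
        K = K @ (np.eye(p) - beta1 * m_K)
        C = C @ (np.eye(d) - beta1 * m_C)
    # Preconditioned momentum plus L2 weight decay, then the parameter step.
    M_mu = alpha2 * M_mu + C @ C.T @ grad @ K @ K.T + gamma * mu
    mu = mu - beta2 * M_mu
    return mu, K, C, m_K, m_C, M_mu
```

With $\mathbf{K} = \mathbf{I}_p$, $\mathbf{C} = \mathbf{I}_d$, $\alpha_2 = 0$, and $\gamma = 0$, the step reduces to plain gradient descent $\boldsymbol{\mu} \leftarrow \boldsymbol{\mu} - \beta_2\nabla_\mu\ell$, which is a convenient sanity check for an implementation.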
Figure 9: Test error curves for mixed-precision training of a GCViT model on ImageNet-100. SINGD has an iteration cost similar to that of AdamW while achieving lower test error.

Appendix C Connection between IKFAC and KFAC

To relate IKFAC to KFAC, we now show that $\mathbf{K}^{\text{new}}(\mathbf{K}^{\text{new}})^\top$ approximates $(\mathbf{S}_K^{\text{new}} + \lambda\mathbf{I})^{-1}$ after one step of our scheme. For simplicity, we first assume that $\mathbf{K}\mathbf{K}^\top$ exactly equals $(\mathbf{S}_K^{\text{cur}} + \lambda\mathbf{I})^{-1}$ at the current step. Later, we relax this assumption and prove that $\mathbf{K}\mathbf{K}^\top$ approximates $(\mathbf{S}_K + \lambda\mathbf{I})^{-1}$ at every step, as stated in Theorem 1.
For notational simplicity, we denote $\bar{\mathbf{S}}_K \coloneqq \mathbf{S}_K + \lambda\mathbf{I}$. Adding $\lambda\mathbf{I}$ to both sides of the update $\mathbf{S}_K^{\text{new}} \leftarrow (1-\beta_1)\mathbf{S}_K^{\text{cur}} + \beta_1\mathbf{U}$, the damped update can be re-expressed as an update of $\bar{\mathbf{S}}_K$:

$$\big(\mathbf{S}_K^{\text{new}} + \lambda\mathbf{I}\big) = \bar{\mathbf{S}}_K^{\text{new}} \leftarrow (1-\beta_1)\,\bar{\mathbf{S}}_K^{\text{cur}} + \beta_1\big(\mathbf{U} + \lambda\mathbf{I}\big).$$

Since $\bar{\mathbf{S}}_K^{\text{cur}} = \mathbf{K}^{-\top}\mathbf{K}^{-1}$ by our assumption, we can express the update of $\bar{\mathbf{S}}_K$ in terms of $\mathbf{K}$ as follows, where $\mathbf{m}_K \coloneqq \mathbf{K}^\top\mathbf{U}\mathbf{K} + \lambda\mathbf{K}^\top\mathbf{K} - \mathbf{I}$ (expanding the middle expression and using $\mathbf{K}^{-\top}\mathbf{K}^{-1} = \bar{\mathbf{S}}_K^{\text{cur}}$ recovers the left-hand side):

$$\bar{\mathbf{S}}_K^{\text{new}} \leftarrow (1-\beta_1)\,\bar{\mathbf{S}}_K^{\text{cur}} + \beta_1\big(\mathbf{U} + \lambda\mathbf{I}\big) = \mathbf{K}^{-\top}\Big(\mathbf{I} + \beta_1\big(\mathbf{K}^\top\mathbf{U}\mathbf{K} + \lambda\mathbf{K}^\top\mathbf{K} - \mathbf{I}\big)\Big)\mathbf{K}^{-1} = \mathbf{K}^{-\top}\big(\mathbf{I} + \beta_1\mathbf{m}_K\big)\mathbf{K}^{-1}.$$

Using the first-order approximation of the matrix exponential, $\mathrm{Expm}(\beta_1\mathbf{m}_K) \approx \mathbf{I} + \beta_1\mathbf{m}_K$, and noting that $\mathbf{m}_K$ is symmetric (so that $\mathrm{Expm}(\frac{\beta_1}{2}\mathbf{m}_K)^\top = \mathrm{Expm}(\frac{\beta_1}{2}\mathbf{m}_K)$), we can approximate $\bar{\mathbf{S}}_K^{\text{new}}$ in the KFAC update as follows:
$$\bar{\mathbf{S}}_K^{\text{new}} = \mathbf{K}^{-\top}\big(\mathbf{I} + \beta_1\mathbf{m}_K\big)\mathbf{K}^{-1} \approx \mathbf{K}^{-\top}\,\mathrm{Expm}\big(\beta_1\mathbf{m}_K\big)\,\mathbf{K}^{-1} = \mathbf{K}^{-\top}\,\mathrm{Expm}\Big(\frac{\beta_1}{2}\mathbf{m}_K\Big)^\top\mathrm{Expm}\Big(\frac{\beta_1}{2}\mathbf{m}_K\Big)\,\mathbf{K}^{-1}.$$

Informally, using the matrix exponential, we can see that $\mathbf{K}^{\text{new}}\big(\mathbf{K}^{\text{new}}\big)^{\top}$ approximates $\big(\bar{\mathbf{S}}_K^{\text{new}}\big)^{-1}$. Note that $\mathbf{m}_K$ lives in a matrix-logarithm space.

$$\big(\bar{\mathbf{S}}_K^{\text{new}}\big)^{-1}\approx\mathbf{K}\,\mathrm{Expm}\Big(-\frac{\beta_1}{2}\mathbf{m}_K\Big)\mathrm{Expm}\Big(-\frac{\beta_1}{2}\mathbf{m}_K\Big)^{\top}\mathbf{K}^{\top}\approx\mathbf{K}\Big(\mathbf{I}-\frac{\beta_1}{2}\mathbf{m}_K\Big)\Big(\mathbf{I}-\frac{\beta_1}{2}\mathbf{m}_K\Big)^{\top}\mathbf{K}^{\top}=\mathbf{K}^{\text{new}}\big(\mathbf{K}^{\text{new}}\big)^{\top}$$
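A similar sketch, again assuming NumPy and hypothetical random $\mathbf{K}$ and $\mathbf{m}_K$, compares $\big(\bar{\mathbf{S}}_K^{\text{new}}\big)^{-1}$ (computed with an explicit inverse) against $\mathbf{K}^{\text{new}}(\mathbf{K}^{\text{new}})^{\top}$ with $\mathbf{K}^{\text{new}}=\mathbf{K}(\mathbf{I}-\frac{\beta_1}{2}\mathbf{m}_K)$. Expanding both sides suggests the gap is $\mathbf{K}\big(\tfrac{3}{4}\beta_1^2\mathbf{m}_K^2+O(\beta_1^3)\big)\mathbf{K}^{\top}$, so it should scale quadratically in $\beta_1$:

```python
import numpy as np

rng = np.random.default_rng(1)
d = 4
K = np.eye(d) + 0.1 * rng.standard_normal((d, d))  # hypothetical current factor K
b = rng.standard_normal((d, d))
m = 0.5 * (b + b.T)                                # hypothetical symmetric m_K


def gap(beta: float) -> float:
    """Spectral norm of (S_bar_new)^{-1} - K_new K_new^T."""
    k_inv = np.linalg.inv(K)
    s_new = k_inv.T @ (np.eye(d) + beta * m) @ k_inv  # K^{-T}(I + beta m)K^{-1}
    k_new = K @ (np.eye(d) - 0.5 * beta * m)         # K(I - beta/2 m), no inverse
    return np.linalg.norm(np.linalg.inv(s_new) - k_new @ k_new.T, 2)


for beta in (4e-2, 2e-2, 1e-2):
    print(f"beta={beta:g}  gap={gap(beta):.2e}")
```

Only the reference side needs matrix inverses; the update of $\mathbf{K}$ itself is a single matrix product.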

Theorem 1 formally shows that the product $\mathbf{K}\mathbf{K}^{\top}$ used in our update approximates $\big(\mathbf{S}_K+\lambda\mathbf{I}\big)^{-1}$ from the KFAC update at every step, even when the matrix exponential is truncated.
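To illustrate Theorem 1 over several iterations, here is a minimal simulation sketch. It assumes NumPy, a hypothetical sequence of random PSD curvature matrices $\mathbf{U}_i$ shared by both methods, $\lambda=10^{-3}$, and the initialization $\mathbf{K}_0\mathbf{K}_0^{\top}=\bar{\mathbf{S}}_0^{-1}$ obtained from a Cholesky factor. It runs the KFAC moving average followed by an explicit inverse, and the inverse-free update of $\mathbf{K}$ with the truncated exponential (in the form stated in Lemma D.2), then reports the spectral-norm gap between the two:

```python
import numpy as np

rng = np.random.default_rng(2)
d, steps, lam = 4, 5, 1e-3


def random_psd() -> np.ndarray:
    """Hypothetical symmetric PSD curvature matrix."""
    a = rng.standard_normal((d, d))
    return a @ a.T / d


U_seq = [random_psd() for _ in range(steps)]   # curvature sequence shared by both methods
S0 = random_psd() + np.eye(d)                  # initial damped factor S_bar_0 (SPD)
K0 = np.linalg.inv(np.linalg.cholesky(S0)).T   # K0 K0^T = S0^{-1}


def kfac_inverse(beta: float) -> np.ndarray:
    """KFAC: moving average of S_bar, then an explicit matrix inverse."""
    S = S0.copy()
    for U in U_seq:
        S = (1 - beta) * S + beta * (U + lam * np.eye(d))
    return np.linalg.inv(S)


def inverse_free(beta: float) -> np.ndarray:
    """Inverse-free update of K (truncated exponential); returns K K^T."""
    K = K0.copy()
    for U in U_seq:
        N = K.T @ U @ K + lam * K.T @ K - np.eye(d)
        K = K @ (np.eye(d) - 0.5 * beta * N)
    return K @ K.T


def gap(beta: float) -> float:
    return np.linalg.norm(kfac_inverse(beta) - inverse_free(beta), 2)


for beta in (2e-2, 1e-2, 5e-3):
    print(f"beta={beta:g}  gap={gap(beta):.2e}")
```

For a fixed number of steps, halving $\beta_1$ should shrink the gap by roughly a factor of four. Note that the inverse-free loop uses only matrix products, whereas the KFAC path needs `np.linalg.inv`.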

Appendix D Proof of Theorem 1

To prove Theorem 1, we first establish the following lemmas.

Recall that we denote $\bar{\mathbf{S}}_K\coloneq\mathbf{S}_K+\lambda\mathbf{I}$. For notational simplicity, we drop the subscript $K$ in this section and use $\bar{\mathbf{S}}_t$ to denote $\bar{\mathbf{S}}_K$ at iteration $t$. Note that $\bar{\mathbf{S}}_t$ is non-singular at every iteration $t$, so it can be inverted in the original KFAC update (see Figure 4).

Lemma D.1.

Consider the following step of the original KFAC update at iteration $t$:

$$\bar{\mathbf{S}}_t\coloneq(1-\beta_1)\bar{\mathbf{S}}_{t-1}+\beta_1\big(\hat{\mathbf{U}}_{t-1}+\lambda\mathbf{I}\big)$$

where $\mathbf{S}_t$ is the factor $\mathbf{S}_K$ used in the original KFAC update, $\beta_1$ is the weight of the moving average, and $\hat{\mathbf{U}}_{t-1}$ is a curvature matrix.

Since $\bar{\mathbf{S}}_0$, as a preconditioning factor, is symmetric positive definite, the initial factor can be decomposed as $\bar{\mathbf{S}}_0=\hat{\mathbf{K}}_0^{-T}\hat{\mathbf{K}}_0^{-1}$.
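One concrete way to obtain such a $\hat{\mathbf{K}}_0$ is via the Cholesky factorization $\bar{\mathbf{S}}_0=\mathbf{L}\mathbf{L}^{\top}$, taking $\hat{\mathbf{K}}_0=\mathbf{L}^{-T}$ so that $\hat{\mathbf{K}}_0^{-T}=\mathbf{L}$ and $\hat{\mathbf{K}}_0^{-1}=\mathbf{L}^{\top}$. This is a sketch assuming NumPy and a hypothetical random SPD $\bar{\mathbf{S}}_0$; the paper's own initialization may differ:

```python
import numpy as np

rng = np.random.default_rng(3)
d = 4
a = rng.standard_normal((d, d))
S0 = a @ a.T + np.eye(d)            # hypothetical SPD initial factor S_bar_0

L = np.linalg.cholesky(S0)          # S0 = L L^T, L lower triangular
K0_hat = np.linalg.inv(L).T         # K0_hat = L^{-T} (upper triangular)

K0_inv = np.linalg.inv(K0_hat)
reconstructed = K0_inv.T @ K0_inv   # K0_hat^{-T} K0_hat^{-1} = L L^T
print(np.allclose(reconstructed, S0))                     # True
print(np.allclose(K0_hat @ K0_hat.T, np.linalg.inv(S0)))  # True
```

The factor is not unique: for any orthogonal $\mathbf{Q}$, $\hat{\mathbf{K}}_0\mathbf{Q}$ yields the same $\bar{\mathbf{S}}_0$.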

Define $\hat{\mathbf{N}}_i\coloneq\hat{\mathbf{K}}_0^{T}\hat{\mathbf{U}}_i\hat{\mathbf{K}}_0+\lambda\hat{\mathbf{K}}_0^{T}\hat{\mathbf{K}}_0-\mathbf{I}$.

Then the Kronecker factor can be re-expressed as

$$\bar{\mathbf{S}}_t=\hat{\mathbf{K}}_0^{-T}\Big(\mathbf{I}+\beta_1\sum_{i=0}^{t-1}\hat{\mathbf{N}}_i\Big)\hat{\mathbf{K}}_0^{-1}+O(\beta_1^2)$$

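Lemma D.1 can be sanity-checked numerically. The sketch below assumes NumPy, hypothetical random curvature matrices, and a Cholesky-based $\hat{\mathbf{K}}_0$. It runs the exact KFAC moving average for a few steps and compares the result with the first-order expression, where every $\hat{\mathbf{N}}_i$ is built from the fixed initial factor $\hat{\mathbf{K}}_0$:

```python
import numpy as np

rng = np.random.default_rng(4)
d, steps, lam = 4, 5, 1e-3


def random_psd() -> np.ndarray:
    a = rng.standard_normal((d, d))
    return a @ a.T / d  # hypothetical curvature matrix U_hat_i


U_seq = [random_psd() for _ in range(steps)]
S0 = random_psd() + np.eye(d)                     # SPD initial factor S_bar_0
K0 = np.linalg.inv(np.linalg.cholesky(S0)).T      # S_bar_0 = K0^{-T} K0^{-1}
K0_inv = np.linalg.inv(K0)

# N_hat_i is built from the fixed *initial* factor K0 (not from any iterate).
N_hat = [K0.T @ U @ K0 + lam * K0.T @ K0 - np.eye(d) for U in U_seq]


def lemma_gap(beta: float) -> float:
    """Gap between the exact KFAC factor S_bar_t and its first-order form."""
    S = S0.copy()
    for U in U_seq:
        S = (1 - beta) * S + beta * (U + lam * np.eye(d))
    first_order = K0_inv.T @ (np.eye(d) + beta * sum(N_hat)) @ K0_inv
    return np.linalg.norm(S - first_order, 2)


for beta in (2e-2, 1e-2, 5e-3):
    print(f"beta={beta:g}  gap={lemma_gap(beta):.2e}")
```

The gap should fall by roughly a factor of four each time $\beta_1$ is halved. With `steps = 1` it vanishes up to rounding, matching the exact base case (13).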
Lemma D.2.

Consider the following step of our inverse-free KFAC update at iteration $t$:

$$\mathbf{K}_t\coloneq\mathbf{K}_{t-1}\Big(\mathbf{I}-\frac{\beta_1}{2}\big(\mathbf{K}_{t-1}^{\top}\mathbf{U}_{t-1}\mathbf{K}_{t-1}+\lambda\mathbf{K}_{t-1}^{\top}\mathbf{K}_{t-1}-\mathbf{I}\big)\Big)$$

where $\mathbf{K}_{t-1}^{\top}\mathbf{U}_{t-1}\mathbf{K}_{t-1}$ is the term used in our update and $\mathbf{U}_{t-1}$ is a curvature matrix.

Define $\mathbf{N}_i\coloneq\mathbf{K}_i^{\top}\mathbf{U}_i\mathbf{K}_i+\lambda\mathbf{K}_i^{\top}\mathbf{K}_i-\mathbf{I}$.

Then our update of $\mathbf{K}$ can be re-expressed as

$$\mathbf{K}_t=\mathbf{K}_0\Big(\mathbf{I}-\frac{\beta_1}{2}\sum_{i=0}^{t-1}\mathbf{N}_i\Big)+O(\beta_1^2)$$

Moreover, the product $\mathbf{K}\mathbf{K}^{\top}$ can be re-expressed as

$$\mathbf{K}_t\mathbf{K}_t^{\top}=\mathbf{K}_0\Big(\mathbf{I}-\beta_1\sum_{i=0}^{t-1}\mathbf{N}_i\Big)\mathbf{K}_0^{\top}+O(\beta_1^2)$$
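Likewise, here is a hedged numerical check of Lemma D.2, assuming NumPy and a hypothetical random $\mathbf{K}_0$ and curvature sequence. Unlike $\hat{\mathbf{N}}_i$ in Lemma D.1, each $\mathbf{N}_i$ is evaluated at the current iterate $\mathbf{K}_i$. The check compares both $\mathbf{K}_t$ and $\mathbf{K}_t\mathbf{K}_t^{\top}$ with their first-order expressions:

```python
import numpy as np

rng = np.random.default_rng(5)
d, steps, lam = 4, 5, 1e-3


def random_psd() -> np.ndarray:
    a = rng.standard_normal((d, d))
    return a @ a.T / d  # hypothetical curvature matrix U_i


U_seq = [random_psd() for _ in range(steps)]
K0 = np.eye(d) + 0.1 * rng.standard_normal((d, d))  # hypothetical initial factor


def lemma_gaps(beta: float) -> tuple[float, float]:
    """Gaps of K_t and K_t K_t^T from their first-order expressions."""
    K, Ns = K0.copy(), []
    for U in U_seq:
        N = K.T @ U @ K + lam * K.T @ K - np.eye(d)  # N_i at the current iterate K_i
        Ns.append(N)
        K = K @ (np.eye(d) - 0.5 * beta * N)         # inverse-free step
    S_N = sum(Ns)
    gap_K = np.linalg.norm(K - K0 @ (np.eye(d) - 0.5 * beta * S_N), 2)
    gap_KKt = np.linalg.norm(K @ K.T - K0 @ (np.eye(d) - beta * S_N) @ K0.T, 2)
    return gap_K, gap_KKt


for beta in (2e-2, 1e-2, 5e-3):
    gk, gkk = lemma_gaps(beta)
    print(f"beta={beta:g}  K gap={gk:.2e}  KK^T gap={gkk:.2e}")
```

The passage from $\mathbf{K}_t$ to $\mathbf{K}_t\mathbf{K}_t^{\top}$ relies on each $\mathbf{N}_i$ being symmetric, which holds whenever $\mathbf{U}_i$ is symmetric.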

Lemma D.3 establishes a relationship between the KFAC update and our inverse-free update.

Lemma D.3.

If the original KFAC update and our update use the same sequence of curvature matrices, i.e., $\hat{\mathbf{U}}_i=\mathbf{U}_i$ for each iteration $i$, and the same initialization $\hat{\mathbf{K}}_0=\mathbf{K}_0$, then

$$\mathbf{N}_i=\hat{\mathbf{N}}_i+O(\beta_1)$$
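For reference, one way to chain the three lemmas into Theorem 1 is sketched below. This is our reading of the argument rather than a verbatim excerpt; it assumes $\hat{\mathbf{K}}_0=\mathbf{K}_0$ and $\hat{\mathbf{U}}_i=\mathbf{U}_i$ as in Lemma D.3, a fixed number of steps $t$, and the Neumann expansion $(\mathbf{I}+\beta_1\mathbf{A})^{-1}=\mathbf{I}-\beta_1\mathbf{A}+O(\beta_1^2)$:

$$
\begin{aligned}
\bar{\mathbf{S}}_t^{-1}
&=\hat{\mathbf{K}}_0\Big(\mathbf{I}+\beta_1\sum_{i=0}^{t-1}\hat{\mathbf{N}}_i+O(\beta_1^2)\Big)^{-1}\hat{\mathbf{K}}_0^{T} && \text{(Lemma D.1)}\\
&=\mathbf{K}_0\Big(\mathbf{I}-\beta_1\sum_{i=0}^{t-1}\hat{\mathbf{N}}_i\Big)\mathbf{K}_0^{\top}+O(\beta_1^2) && \text{(Neumann expansion, } \hat{\mathbf{K}}_0=\mathbf{K}_0\text{)}\\
&=\mathbf{K}_0\Big(\mathbf{I}-\beta_1\sum_{i=0}^{t-1}\mathbf{N}_i\Big)\mathbf{K}_0^{\top}+O(\beta_1^2) && \text{(Lemma D.3: } \beta_1\hat{\mathbf{N}}_i=\beta_1\mathbf{N}_i+O(\beta_1^2)\text{)}\\
&=\mathbf{K}_t\mathbf{K}_t^{\top}+O(\beta_1^2) && \text{(Lemma D.2)}
\end{aligned}
$$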

Similarly, we have the following result for $\mathbf{C}$.

Theorem 2.

The product $\mathbf{C}\mathbf{C}^{\top}$ is a first-order accurate approximation of $\big(\mathbf{S}_C+\lambda\mathbf{I}\big)^{-1}$ from the KFAC update at each iteration, provided that $\mathbf{C}$ is updated according to Figure 4 with the truncated matrix exponential and that both updates use the same initialization and the same sequence of curvature matrices $\mathbf{G}$:

$$\mathbf{C}\mathbf{C}^{\top}=\big(\mathbf{S}_C+\lambda\mathbf{I}\big)^{-1}+O(\beta_1^2)$$

D.1 Proof of Lemma D.1

We prove the lemma by induction. We first show the base case $t=1$. By definition, we have

\begin{align}
\bar{\mathbf{S}}_1
&=(1-\beta_1)\bar{\mathbf{S}}_0+\beta_1\big(\hat{\mathbf{U}}_0+\lambda\mathbf{I}\big) \tag{10}\\
&=(1-\beta_1)\hat{\mathbf{K}}_0^{-T}\hat{\mathbf{K}}_0^{-1}+\beta_1\big(\hat{\mathbf{U}}_0+\lambda\mathbf{I}\big) \tag{11}\\
&=\hat{\mathbf{K}}_0^{-T}\Big[\mathbf{I}+\beta_1\underbrace{\big(\hat{\mathbf{K}}_0^{T}\hat{\mathbf{U}}_0\hat{\mathbf{K}}_0+\lambda\hat{\mathbf{K}}_0^{T}\hat{\mathbf{K}}_0-\mathbf{I}\big)}_{=\hat{\mathbf{N}}_0}\Big]\hat{\mathbf{K}}_0^{-1} \tag{12}\\
&=\hat{\mathbf{K}}_0^{-T}\big[\mathbf{I}+\beta_1\hat{\mathbf{N}}_0\big]\hat{\mathbf{K}}_0^{-1} \tag{13}
\end{align}

Thus, the claim holds for $t=1$.

Suppose the claim holds for $t=n$, that is,

$$\bar{\mathbf{S}}_n=\hat{\mathbf{K}}_0^{-T}\Big(\mathbf{I}+\beta_1\sum_{i=0}^{n-1}\hat{\mathbf{N}}_i\Big)\hat{\mathbf{K}}_0^{-1}+O(\beta_1^2) \tag{14}$$

Now consider the case $t=n+1$. Note that

\begin{align*}
(1-\beta_1)\bar{\mathbf{S}}_n
&=\hat{\mathbf{K}}_0^{-T}\Big(\mathbf{I}+\beta_1\sum_{i=0}^{n-1}\hat{\mathbf{N}}_i-\beta_1\mathbf{I}+O(\beta_1^2)\Big)\hat{\mathbf{K}}_0^{-1}+O(\beta_1^2)\\
&=\hat{\mathbf{K}}_0^{-T}\Big(\mathbf{I}+\beta_1\sum_{i=0}^{n-1}\hat{\mathbf{N}}_i-\beta_1\mathbf{I}\Big)\hat{\mathbf{K}}_0^{-1}+O(\beta_1^2)
\end{align*}

By the definition of $\bar{\mathbf{S}}_{n+1}$, we have

\begin{align}
\bar{\mathbf{S}}_{n+1}
&=(1-\beta_1)\bar{\mathbf{S}}_n+\beta_1\big(\hat{\mathbf{U}}_n+\lambda\mathbf{I}\big) \tag{15}\\
&=\hat{\mathbf{K}}_0^{-T}\Big(\mathbf{I}+\beta_1\sum_{i=0}^{n-1}\hat{\mathbf{N}}_i\underbrace{-\beta_1\mathbf{I}+\beta_1\hat{\mathbf{K}}_0^{T}\hat{\mathbf{U}}_n\hat{\mathbf{K}}_0+\beta_1\lambda\hat{\mathbf{K}}_0^{T}\hat{\mathbf{K}}_0}_{=\beta_1\hat{\mathbf{N}}_n}\Big)\hat{\mathbf{K}}_0^{-1}+O(\beta_1^2) \tag{16}\\
&=\hat{\mathbf{K}}_0^{-T}\Big(\mathbf{I}+\beta_1\sum_{i=0}^{n}\hat{\mathbf{N}}_i\Big)\hat{\mathbf{K}}_0^{-1}+O(\beta_1^2) \tag{17}
\end{align}

which is exactly the claim for $t=n+1$.

Thus, by induction, the claim holds.

D.2 Proof of Lemma D.2

We prove the lemma by induction. We first show the base case $t=1$. By definition, we have

$$\mathbf{K}_1=\mathbf{K}_0\Big(\mathbf{I}-\frac{\beta_1}{2}\underbrace{\big(\mathbf{K}_0^{\top}\mathbf{U}_0\mathbf{K}_0+\lambda\mathbf{K}_0^{\top}\mathbf{K}_0-\mathbf{I}\big)}_{=\mathbf{N}_0}\Big) \tag{18}$$

Thus, the claim holds for $t=1$.

Suppose the claim holds for $t=n$, that is,

$$\mathbf{K}_n=\mathbf{K}_0\Big(\mathbf{I}-\frac{\beta_1}{2}\sum_{i=0}^{n-1}\mathbf{N}_i\Big)+O(\beta_1^2) \tag{19}$$

Now consider the case $t=n+1$. Note that

\begin{align}
\mathbf{K}_{n+1} &= \mathbf{K}_n\Big(\mathbf{I} - \frac{\beta_1}{2}\underbrace{\big(\mathbf{K}_n^\top\mathbf{U}_n\mathbf{K}_n + \lambda\mathbf{K}_n^\top\mathbf{K}_n - \mathbf{I}\big)}_{=\mathbf{N}_n}\Big) \tag{20}\\
&= \underbrace{\mathbf{K}_0\Big(\mathbf{I} - \frac{\beta_1}{2}\sum_{i=0}^{n-1}\mathbf{N}_i\Big)}_{=\mathbf{K}_n - O(\beta_1^2)}\Big(\mathbf{I} - \frac{\beta_1}{2}\mathbf{N}_n\Big) + O(\beta_1^2) \tag{21}\\
&= \mathbf{K}_0\Big(\mathbf{I} - \frac{\beta_1}{2}\sum_{i=0}^{n-1}\mathbf{N}_i - \frac{\beta_1}{2}\mathbf{N}_n + O(\beta_1^2)\Big) + O(\beta_1^2) \tag{22}\\
&= \mathbf{K}_0\Big(\mathbf{I} - \frac{\beta_1}{2}\sum_{i=0}^{n}\mathbf{N}_i\Big) + O(\beta_1^2) \tag{23}
\end{align}

which is exactly the claim for $t=n+1$.

Thus, by induction, the claim holds.
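The first-order expansion in (19) can be checked numerically. The NumPy sketch below is a minimal illustration under assumptions of our own: random symmetric positive-definite matrices stand in for the curvature factors $\mathbf{U}_t$, $\mathbf{K}_0$ is a random perturbation of the identity, and the helper names `random_spd` and `lemma_d2_gap` are hypothetical. It runs the exact multiplicative update from (20) for a fixed number of steps and measures the distance to the right-hand side of (19). If the remainder is $O(\beta_1^2)$, halving $\beta_1$ should shrink this gap by roughly a factor of four.

```python
import numpy as np


def random_spd(rng, d):
    """Random symmetric positive-definite matrix, an illustrative stand-in for U_t."""
    b = rng.standard_normal((d, d))
    return (b @ b.T) / d + 0.1 * np.eye(d)


def lemma_d2_gap(beta1, steps=10, d=4, lam=0.1, seed=0):
    """Frobenius distance between K_t and K_0 (I - beta1/2 * sum_i N_i) after `steps` updates."""
    rng = np.random.default_rng(seed)  # same seed -> same U_t and K_0 for every beta1
    us = [random_spd(rng, d) for _ in range(steps)]
    k0 = np.eye(d) + 0.1 * rng.standard_normal((d, d))
    eye = np.eye(d)

    k = k0.copy()
    n_sum = np.zeros((d, d))
    for u in us:
        n = k.T @ u @ k + lam * k.T @ k - eye  # N_t as defined in (20)
        n_sum += n
        k = k @ (eye - 0.5 * beta1 * n)        # exact update K_{t+1} = K_t (I - beta1/2 N_t)

    first_order = k0 @ (eye - 0.5 * beta1 * n_sum)  # right-hand side of (19), remainder dropped
    return np.linalg.norm(k - first_order)


# Halving beta1 should reduce the gap by about 4x if the remainder is O(beta1^2).
ratio = lemma_d2_gap(2e-3) / lemma_d2_gap(1e-3)
print(f"gap ratio when halving beta1: {ratio:.2f}")
```

The test uses a fixed number of steps on purpose: the constant hidden in $O(\beta_1^2)$ grows with $t$, so the check is only meaningful for a fixed horizon.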

Notice that $\mathbf{N}_i$ is symmetric by definition. Using this, we obtain

\begin{align}
\mathbf{K}_t\mathbf{K}_t^\top &= \mathbf{K}_0\Big(\mathbf{I} - \frac{\beta_1}{2}\sum_{i=0}^{t-1}\mathbf{N}_i\Big)\Big(\mathbf{I} - \frac{\beta_1}{2}\sum_{i=0}^{t-1}\mathbf{N}_i\Big)^\top\mathbf{K}_0^\top + O(\beta_1^2) \tag{24}\\
&= \mathbf{K}_0\Big(\mathbf{I} - \frac{\beta_1}{2}\sum_{i=0}^{t-1}\mathbf{N}_i\Big)\Big(\mathbf{I} - \frac{\beta_1}{2}\sum_{i=0}^{t-1}\mathbf{N}_i\Big)\mathbf{K}_0^\top + O(\beta_1^2) \tag{25}\\
&= \mathbf{K}_0\Big(\mathbf{I} - \beta_1\sum_{i=0}^{t-1}\mathbf{N}_i\Big)\mathbf{K}_0^\top + O(\beta_1^2) \tag{26}
\end{align}

Thus, the claim also holds.
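The only term dropped between (25) and (26) is quadratic in $\beta_1$. Writing $\mathbf{A} := \sum_{i=0}^{t-1}\mathbf{N}_i$ (shorthand introduced here only for this step), the expansion reads

$$
\Big(\mathbf{I} - \frac{\beta_1}{2}\mathbf{A}\Big)\Big(\mathbf{I} - \frac{\beta_1}{2}\mathbf{A}\Big) = \mathbf{I} - \beta_1\mathbf{A} + \frac{\beta_1^2}{4}\mathbf{A}^2
\quad\Longrightarrow\quad
\mathbf{K}_0\Big(\mathbf{I} - \frac{\beta_1}{2}\mathbf{A}\Big)^2\mathbf{K}_0^\top = \mathbf{K}_0\big(\mathbf{I} - \beta_1\mathbf{A}\big)\mathbf{K}_0^\top + \frac{\beta_1^2}{4}\mathbf{K}_0\mathbf{A}^2\mathbf{K}_0^\top ,
$$

so, for a fixed iteration count $t$ and provided the $\mathbf{N}_i$ remain bounded, the last term is absorbed into $O(\beta_1^2)$.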

D.3 Proof of Lemma D.3

We first show the base case $t=0$. By assumption, $\mathbf{K}_0 = \hat{\mathbf{K}}_0$ and $\mathbf{U}_0 = \hat{\mathbf{U}}_0$.

By definition, we have

\begin{align}
\mathbf{N}_0 &= \mathbf{K}_0^\top\mathbf{U}_0\mathbf{K}_0 + \lambda\mathbf{K}_0^\top\mathbf{K}_0 - \mathbf{I} \tag{27}\\
&= \hat{\mathbf{K}}_0^\top\hat{\mathbf{U}}_0\hat{\mathbf{K}}_0 + \lambda\hat{\mathbf{K}}_0^\top\hat{\mathbf{K}}_0 - \mathbf{I} \tag{28}\\
&= \hat{\mathbf{N}}_0 \tag{29}
\end{align}

Thus, the claim holds for $t=0$.

For $t = n+1 > 0$, we use Lemma D.2 to obtain the claim. Notice that

\begin{align}
\mathbf{N}_{n+1} &= \mathbf{K}_{n+1}^\top\mathbf{U}_{n+1}\mathbf{K}_{n+1} + \lambda\mathbf{K}_{n+1}^\top\mathbf{K}_{n+1} - \mathbf{I} \tag{30}\\
&= \Big(\mathbf{I} - \frac{\beta_1}{2}\sum_{i=0}^{n}\mathbf{N}_i\Big)^\top\mathbf{K}_0^\top\big(\mathbf{U}_{n+1} + \lambda\mathbf{I}\big)\mathbf{K}_0\Big(\mathbf{I} - \frac{\beta_1}{2}\sum_{i=0}^{n}\mathbf{N}_i\Big) - \mathbf{I} + O(\beta_1^2) \quad\text{(by Lemma D.2)} \tag{31}\\
&= \mathbf{K}_0^\top\big(\mathbf{U}_{n+1} + \lambda\mathbf{I}\big)\mathbf{K}_0 - \mathbf{I} + O(\beta_1) + O(\beta_1^2) \tag{32}\\
&= \hat{\mathbf{K}}_0^\top\big(\hat{\mathbf{U}}_{n+1} + \lambda\mathbf{I}\big)\hat{\mathbf{K}}_0 - \mathbf{I} + O(\beta_1) \quad\text{(by assumption)} \tag{33}\\
&= \hat{\mathbf{N}}_{n+1} + O(\beta_1) \tag{34}
\end{align}
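Lemma D.3 only claims agreement up to $O(\beta_1)$, one order weaker than Lemma D.2. The NumPy sketch below illustrates this under illustrative assumptions of our own: random symmetric positive-definite matrices stand in for $\mathbf{U}_t = \hat{\mathbf{U}}_t$, we take $\mathbf{K}_0 = \hat{\mathbf{K}}_0$, and $\hat{\mathbf{N}}_t := \hat{\mathbf{K}}_0^\top(\hat{\mathbf{U}}_t + \lambda\mathbf{I})\hat{\mathbf{K}}_0 - \mathbf{I}$, following the form of $\hat{\mathbf{N}}_0$ in (28)-(29). The helper names are hypothetical. With an $O(\beta_1)$ gap, halving $\beta_1$ should roughly halve the difference between $\mathbf{N}_t$ and $\hat{\mathbf{N}}_t$.

```python
import numpy as np


def random_spd(rng, d):
    """Random symmetric positive-definite matrix, an illustrative stand-in for U_t."""
    b = rng.standard_normal((d, d))
    return (b @ b.T) / d + 0.1 * np.eye(d)


def lemma_d3_gap(beta1, steps=10, d=4, lam=0.1, seed=0):
    """Frobenius distance between N_t (inverse-free run) and hat N_t at t = steps - 1."""
    rng = np.random.default_rng(seed)
    us = [random_spd(rng, d) for _ in range(steps)]
    k0 = np.eye(d) + 0.1 * rng.standard_normal((d, d))  # K_0 = hat K_0
    eye = np.eye(d)

    k = k0.copy()
    for u in us[:-1]:
        n = k.T @ (u + lam * eye) @ k - eye  # N_t
        k = k @ (eye - 0.5 * beta1 * n)      # K_{t+1} = K_t (I - beta1/2 N_t)

    u_last = us[-1]                                   # U_t = hat U_t
    n_t = k.T @ (u_last + lam * eye) @ k - eye        # N_t along the inverse-free trajectory
    n_hat_t = k0.T @ (u_last + lam * eye) @ k0 - eye  # hat N_t built from K_0
    return np.linalg.norm(n_t - n_hat_t)


# An O(beta1) gap should shrink by about 2x when beta1 is halved.
ratio = lemma_d3_gap(2e-3) / lemma_d3_gap(1e-3)
print(f"gap ratio when halving beta1: {ratio:.2f}")
```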

D.4 Proof of Theorem 1

Since $\bar{\mathbf{S}}_t$ is non-singular, it is sufficient to show that the following claim holds at iteration $t$:

$$\mathbf{K}_t\mathbf{K}_t^\top\bar{\mathbf{S}}_t = \mathbf{I} + O(\beta_1^2)$$

where $\bar{\mathbf{S}}_t$ denotes $\bar{\mathbf{S}}_K$ at iteration $t$.

By assumption, Lemmas D.1, D.2, and D.3 hold, and $\mathbf{K}_0 = \hat{\mathbf{K}}_0$. Thus, we have

\begin{align}
\mathbf{K}_t\mathbf{K}_t^\top\bar{\mathbf{S}}_t &= \mathbf{K}_0\Big(\mathbf{I} - \beta_1\sum_{i=0}^{t-1}\mathbf{N}_i\Big)\mathbf{K}_0^\top\bar{\mathbf{S}}_t + O(\beta_1^2) \quad\text{(by Lemma D.2)} \tag{35}\\
&= \mathbf{K}_0\Big(\mathbf{I} - \beta_1\sum_{i=0}^{t-1}\mathbf{N}_i\Big)\mathbf{K}_0^\top\hat{\mathbf{K}}_0^{-\top}\Big(\mathbf{I} + \beta_1\sum_{i=0}^{t-1}\hat{\mathbf{N}}_i\Big)\hat{\mathbf{K}}_0^{-1} + O(\beta_1^2) \quad\text{(by Lemma D.1)} \tag{36}\\
&= \hat{\mathbf{K}}_0\Big(\mathbf{I} - \beta_1\sum_{i=0}^{t-1}\hat{\mathbf{N}}_i + O(\beta_1^2)\Big)\Big(\mathbf{I} + \beta_1\sum_{i=0}^{t-1}\hat{\mathbf{N}}_i\Big)\hat{\mathbf{K}}_0^{-1} + O(\beta_1^2) \quad\text{(by Lemma D.3 and } \mathbf{K}_0 = \hat{\mathbf{K}}_0\text{)} \tag{37}\\
&= \hat{\mathbf{K}}_0\,\mathbf{I}\,\hat{\mathbf{K}}_0^{-1} + O(\beta_1^2) \quad\big(\text{since } (\mathbf{I} - \beta_1\mathbf{A})(\mathbf{I} + \beta_1\mathbf{A}) = \mathbf{I} - \beta_1^2\mathbf{A}^2 \text{ with } \mathbf{A} = \textstyle\sum_{i=0}^{t-1}\hat{\mathbf{N}}_i\big) \tag{38}\\
&= \mathbf{I} + O(\beta_1^2) \tag{39}
\end{align}
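Theorem 1 can also be probed numerically. The following NumPy sketch is not the paper's implementation. It assumes that the KFAC factor follows the moving average $\bar{\mathbf{S}}_{t+1} = (1-\beta_1)\bar{\mathbf{S}}_t + \beta_1(\mathbf{U}_t + \lambda\mathbf{I})$ initialised at $\bar{\mathbf{S}}_0 = (\mathbf{K}_0\mathbf{K}_0^\top)^{-1}$; to first order in $\beta_1$ this has the form used in (36) with $\hat{\mathbf{N}}_i = \mathbf{K}_0^\top(\mathbf{U}_i + \lambda\mathbf{I})\mathbf{K}_0 - \mathbf{I}$, but it is our illustrative choice. Random symmetric positive-definite matrices stand in for $\mathbf{U}_t$, and `theorem1_gap` is a hypothetical helper. If the residual $\mathbf{K}_t\mathbf{K}_t^\top\bar{\mathbf{S}}_t - \mathbf{I}$ is $O(\beta_1^2)$, halving $\beta_1$ should shrink it roughly fourfold.

```python
import numpy as np


def random_spd(rng, d):
    """Random symmetric positive-definite matrix, an illustrative stand-in for U_t."""
    b = rng.standard_normal((d, d))
    return (b @ b.T) / d + 0.1 * np.eye(d)


def theorem1_gap(beta1, steps=10, d=4, lam=0.1, seed=0):
    """Frobenius norm of K_t K_t^T S_t - I after `steps` paired updates."""
    rng = np.random.default_rng(seed)
    us = [random_spd(rng, d) for _ in range(steps)]
    k = np.eye(d) + 0.1 * rng.standard_normal((d, d))
    eye = np.eye(d)
    s = np.linalg.inv(k @ k.T)  # assumed initialisation: S_0 = (K_0 K_0^T)^{-1}

    for u in us:
        a = u + lam * eye
        n = k.T @ a @ k - eye            # N_t
        k = k @ (eye - 0.5 * beta1 * n)  # inverse-free update of K
        s = (1 - beta1) * s + beta1 * a  # assumed KFAC-style moving average of S
    return np.linalg.norm(k @ k.T @ s - eye)


ratio = theorem1_gap(2e-3) / theorem1_gap(1e-3)
print(f"residual ratio when halving beta1: {ratio:.2f}")
```

Per step, the two updates agree in their first-order terms and differ only at order $\beta_1^2$, which is why the residual does not vanish exactly but shrinks quadratically.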

Appendix E Invariance of INGD and SINGD

INGD and SINGD are scale invariant with respect to the choice of Kronecker approximation, whereas KFAC and IKFAC are not. Recall that we use the following Kronecker approximation of the Hessian:

$$\mathbf{U}\otimes\mathbf{G} \approx \nabla_{\mu}^2\ell(\boldsymbol{\mu})$$

However, such an approximation is not unique. We can consider an equivalent approximation such as

$$(\alpha\mathbf{U})\otimes(\alpha^{-1}\mathbf{G}) \approx \nabla_{\mu}^2\ell(\boldsymbol{\mu})$$

where $\alpha \neq 0$ is an arbitrary scalar.

INGD is invariant because the part of its update scheme that involves the approximation is scale invariant:

$$\mathrm{Tr}(\mathbf{H}_C)\,\mathbf{H}_K = \mathrm{Tr}(\mathbf{C}^\top\mathbf{G}\mathbf{C})\,\mathbf{K}^\top\mathbf{U}\mathbf{K} = \mathrm{Tr}\big(\mathbf{C}^\top(\alpha^{-1}\mathbf{G})\mathbf{C}\big)\,\mathbf{K}^\top(\alpha\mathbf{U})\mathbf{K}.$$

The invariance is also preserved in SINGD, since the structures and their subspace projection maps are closed under scalar multiplication.
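The cancellation of $\alpha$ in the product above can be confirmed numerically. The NumPy check below uses arbitrary random matrices for $\mathbf{U}$, $\mathbf{G}$, $\mathbf{K}$, and $\mathbf{C}$ (an assumption for illustration; in INGD they would come from the layer's curvature and the current factors) and tests only the term $\mathrm{Tr}(\mathbf{H}_C)\mathbf{H}_K$, not a full INGD step. The helper `trace_weighted_hk` is a hypothetical name.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, d_c, alpha = 3, 5, 7.0

U = rng.standard_normal((d_k, d_k))
U = U @ U.T  # symmetric curvature factor paired with K
G = rng.standard_normal((d_c, d_c))
G = G @ G.T  # symmetric curvature factor paired with C
K = rng.standard_normal((d_k, d_k))
C = rng.standard_normal((d_c, d_c))


def trace_weighted_hk(u, g, k, c):
    """Tr(H_C) * H_K with H_C = C^T G C and H_K = K^T U K."""
    return np.trace(c.T @ g @ c) * (k.T @ u @ k)


original = trace_weighted_hk(U, G, K, C)
rescaled = trace_weighted_hk(alpha * U, G / alpha, K, C)

# Both parametrisations describe the same Kronecker approximation ...
print(np.allclose(np.kron(U, G), np.kron(alpha * U, G / alpha)))
# ... and yield the same term in the update of K.
print(np.allclose(original, rescaled))
```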

In contrast, the updates of KFAC and IKFAC are not scale invariant. As an example, we consider using the curvature approximations $\mathbf{U}$ and $\alpha\mathbf{U}$ to update $\mathbf{S}_K^{-1}$ in KFAC, and denote the updated $\mathbf{S}_K^{-1}$ by $\hat{\mathbf{S}}_K^{-1}$ and $\bar{\mathbf{S}}_K^{-1}$, respectively. As shown below, we cannot recover $\hat{\mathbf{S}}_K^{-1}$ from $\bar{\mathbf{S}}_K^{-1}$ by a scale transformation; thus, the KFAC update is not scale invariant.

$$\hat{\mathbf{S}}_K^{-1} = \big[(1-\beta_1)\hat{\mathbf{S}}_K + \beta_1\mathbf{U} + \lambda\mathbf{I}\big]^{-1} \neq \big[(1-\beta_1)\bar{\mathbf{S}}_K + \beta_1(\alpha\mathbf{U}) + \lambda\mathbf{I}\big]^{-1} = \bar{\mathbf{S}}_K^{-1}$$

One attempt to make the update of $\mathbf{S}_K$ invariant is to set the damping weight to $\alpha\lambda$. However, as shown below, the update of $\mathbf{S}_C$ requires the damping weight $\alpha^{-1}\lambda$. Thus, it is impossible to make KFAC invariant without introducing separate damping weights for the two factors.

$$\hat{\mathbf{S}}_C^{-1} = \big[(1-\beta_1)\hat{\mathbf{S}}_C + \beta_1\mathbf{G} + \lambda\mathbf{I}\big]^{-1} \neq \big[(1-\beta_1)\bar{\mathbf{S}}_C + \beta_1(\alpha^{-1}\mathbf{G}) + \lambda\mathbf{I}\big]^{-1} = \bar{\mathbf{S}}_C^{-1}$$
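The damping argument can be reproduced with a small NumPy sketch. It applies the KFAC step literally as displayed, $\mathbf{S} \mapsto (1-\beta_1)\mathbf{S} + \beta_1\,(\text{curvature}) + \lambda\mathbf{I}$, to both parametrisations, assuming the rescaled run starts from consistently rescaled factors $\alpha\mathbf{S}_K$ and $\alpha^{-1}\mathbf{S}_C$. With one shared damping weight, the Kronecker preconditioner $\mathbf{S}_K^{-1}\otimes\mathbf{S}_C^{-1}$ changes; with the individual weights $\alpha\lambda$ and $\alpha^{-1}\lambda$, it does not. The matrices are random stand-ins, and `kfac_factor_update` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)
beta1, lam, alpha = 0.1, 0.1, 4.0


def spd(d):
    """Random symmetric positive-definite matrix (illustrative stand-in)."""
    b = rng.standard_normal((d, d))
    return b @ b.T + np.eye(d)


def kfac_factor_update(s_prev, curv, beta1, damping):
    """Literal form of the displayed step: (1 - beta1) S + beta1 * curvature + damping * I."""
    return (1 - beta1) * s_prev + beta1 * curv + damping * np.eye(s_prev.shape[0])


S_K, U = spd(3), spd(3)
S_C, G = spd(2), spd(2)
inv = np.linalg.inv

# Reference run with U (x) G and a shared damping weight lam.
ref = np.kron(inv(kfac_factor_update(S_K, U, beta1, lam)),
              inv(kfac_factor_update(S_C, G, beta1, lam)))

# Rescaled run with (alpha U) (x) (G / alpha) and the same shared damping.
shared = np.kron(inv(kfac_factor_update(alpha * S_K, alpha * U, beta1, lam)),
                 inv(kfac_factor_update(S_C / alpha, G / alpha, beta1, lam)))

# Rescaled run with individual damping weights alpha * lam and lam / alpha.
individual = np.kron(inv(kfac_factor_update(alpha * S_K, alpha * U, beta1, alpha * lam)),
                     inv(kfac_factor_update(S_C / alpha, G / alpha, beta1, lam / alpha)))

print("shared damping matches:", np.allclose(shared, ref))
print("individual damping matches:", np.allclose(individual, ref))
```

With individual weights, each updated factor is exactly the reference factor times $\alpha$ or $\alpha^{-1}$, so the scalars cancel in the Kronecker product; a single shared $\lambda$ breaks this because the damping term does not scale with the factor.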