GrootVL: Tree Topology is All You Need
in State Space Model

Yicheng Xiao1, Lin Song2,3 🖂,
Shaoli Huang3, Jiangshan Wang1, Siyu Song4, Yixiao Ge2,3, Xiu Li1 🖂, Ying Shan2,3

1Tsinghua Shenzhen International Graduate School, Tsinghua University
2ARC Lab, Tencent PCG  3Tencent AI Lab  4South China Normal University
Equal contribution. Work done during an internship at Tencent. 🖂 Corresponding author.
Abstract

State space models, which propagate features recursively, demonstrate representation capabilities comparable to Transformer models with superior efficiency. However, constrained by the inherent geometry of sequences, they still fall short in modeling long-range dependencies. To address this issue, we propose the GrootVL network, which first dynamically generates a tree topology based on spatial relationships and input features. Feature propagation is then performed on this graph, breaking the original sequence constraints to achieve stronger representation capabilities. Additionally, we introduce a linear-complexity dynamic programming algorithm to enhance long-range interactions without increasing computational cost. GrootVL is a versatile multimodal framework that can be applied to both visual and textual tasks. Extensive experiments demonstrate that our method significantly outperforms existing structured state space models on image classification, object detection and segmentation. Besides, by fine-tuning large language models, our approach achieves consistent improvements in multiple textual tasks at minor training cost. Code is available at https://github.com/EasonXiao-888/GrootVL.

1 Introduction

Mainstream fundamental models are primarily based on CNN [27, 57, 41, 29, 13] and Transformer architectures [15, 40, 39, 54, 14], which dominate in visual and language tasks. However, the small receptive field of CNNs and the high complexity of Transformers make it challenging to strike a good balance between effectiveness and efficiency. State space models (SSMs) [21, 23, 48], which model sequences in a recurrent form, attempt to break this impasse. Unlike earlier recurrent neural networks [28, 7], these approaches draw inspiration from control systems, leveraging structural parameter initialization to attain stable optimization and superior computing performance. Nevertheless, they remain susceptible to the intrinsic flaw shared by recurrent neural networks, i.e., a deficiency in capturing long-range dependencies.

Recently, an improved selection mechanism known as Mamba [18] is proposed to mitigate the challenges of SSMs. This approach introduces weight modulation during the propagation process, which substantially enlarges the effective receptive field and achieves impressive performance in NLP tasks. Besides, numerous studies aim to extend Mamba into computer vision, by employing various pre-defined strategies to map 2D image features into 1D sequences. ViM [70] and VMamba [38] utilize a multi-directional raster-scanning strategy, while LocalMamba [31] further confines its propagation range within a local window. They have successfully adapted Mamba to image inputs. Nevertheless, as shown in Fig. 1(a), both raster-scanning and local-scanning strategies introduce spatial discontinuities between adjacent pixels, and feature transformations in Mamba rely on the feature relationships, thereby impeding the effective information flow in a sequence. Additionally, PlainMamba [62] introduces a continuous scanning strategy, aiming to alleviate this issue by simply adjusting the propagation direction at discontinuous positions. However, all these methods rely on fixed propagation trajectories, which ignore the inherent spatial structure and cannot dynamically adjust the topology based on input. Therefore, this paper endeavors to explore a new perspective: introducing an input-aware topological network for feature propagation in state space models.

To achieve this, we develop a tree state space model and propose a new framework, termed GrootVL, which adaptively generates a tree topology based on the input feature and then performs feature propagation on it. Specifically, two sub-networks, GrootV and GrootL, are designed for visual and language tasks respectively, as illustrated in Fig. 1(b) and Fig. 1(d). For visual tasks, motivated by [64, 50], we first utilize the dissimilarity between adjacent features to construct a minimum spanning tree on a four-connected planar graph. This process can adaptively encode the spatial and semantic information into a tree graph [64, 50]. Then, we iteratively traverse each pixel, considering it as the root vertex, and aggregate the features of other pixels using the state transition function of Mamba. Intuitively, this operation requires two levels of traversal across the entire pixel set, resulting in an unacceptable quadratic complexity relative to the number of pixels. However, given that the tree graph is acyclic, we propose a dynamic programming algorithm to achieve linear-complexity propagation. With such an input-aware tree topology, our approach enables more effective long-range interactions while maintaining the same linear complexity as Mamba. Furthermore, our method can also be applied to language tasks by constructing a tree topology based on the dissimilarity between token features, which overcomes the geometrical constraints of the text sequence. Using an aggregation process similar to that of GrootV, GrootL can significantly enhance the language representation of a pre-trained Large Language Model [18].

We conduct extensive experiments to validate the effectiveness of GrootV on multiple visual benchmarks, i.e., image classification on ImageNet [12], object detection and instance segmentation on MSCOCO [36] as well as semantic segmentation on ADE20K [68]. Results show that our method notably outperforms existing SSM-based methods for all benchmarks and achieves competitive performance with CNN and Transformer-based approaches. Moreover, with LoRA finetuning [30], GrootL demonstrates consistent improvements for a pre-trained large language model at minor training cost.

2 Related Work

2.1 Conventional Vision Foundation Models

The evolution of deep neural networks has been a significant catalyst in machine vision perception. CNN-based models [27, 47, 32, 24, 56, 65, 35, 51, 66] first emerged as pivotal landmarks, with ResNet [27] notably standing out for its inventive residual connection module, garnering widespread adoption across diverse domains of visual recognition. Furthermore, more efficient convolution operations have been formulated, such as the depth-wise convolutions introduced by MobileNet [29], paving the way for lightweight models. Additionally, deformable convolution [10] has been proposed to enhance the receptive field. Subsequently, ViT [15] significantly improved the vision recognition paradigm. It reformulates the architecture design and training mechanism by adopting the transformer architecture from natural language processing, aiming to improve computational efficiency and broaden the scope of applications. Later research centres on hierarchical ViTs [40, 39, 11, 58, 14, 52, 5], which design networks by gradually decreasing feature resolution across the backbone. Furthermore, recent research built on CNNs serves to re-emphasize the capabilities of convolutional networks. For example, InternImage [57] presents a large model based on deformable CNN, while UniRepLKNet [13] exhibits significant performance through large kernel convolution.

2.2 Explorations about State Space Models

State space models (SSMs) have emerged as a novel class of models within the deep learning paradigm, showing significant potential for sequence transforming [22, 21, 48]. These methods have attracted significant attention due to their linear scalability with sequence length. The early method, LSSL [22], draws inspiration from continuous state space models in control systems and attempts to address the long-range dependency problem through a combination with HIPPO [19] initialization. S4 [21] proposes to normalize the parameters into a diagonal matrix, prompting a subsequent series of research on structured SSMs [23, 20, 25, 18]. Recently, the Selective State Space Model [18], known as Mamba, strikes a balance between effectiveness and efficiency through the design of an input-dependent parameter initialization strategy, which has emerged as a formidable competitor to both transformer and CNN structures. In addition to showcasing superior outcomes in sequence modeling, Mamba has been seamlessly incorporated into the visual domain [70, 38, 31, 62]. These studies often rely on handcrafted fixed scanning mechanisms to mitigate the execution bias of the selective state space model on 2D non-causal images. However, such simplistic approaches cannot effectively capture spatial relationships in an input-dependent paradigm. To address this limitation, we propose an effective framework GrootVL in this work to enhance long-range modeling for both vision and language tasks by introducing an input-aware tree-based topological structure.

Figure 1: Comparison of different propagation strategies for multi-modal tasks. For visual tasks, the previous strategies (a) are based on fixed patterns, while our method can adaptively generate the propagation topology according to input features. For textual tasks, compared to previous methods (c), our approach (d) can break the inherent constraints of text sequences, facilitating the effective transmission of long-range information.

3 Method

In this section, we first revisit the selective state space model [18] and then elaborate on our input-aware topology scanning algorithm for state space modeling. With this superior algorithm, we develop a tree SSM and propose a novel framework called GrootVL, which consists of two sub-networks: GrootV for visual tasks and GrootL for fine-tuning a pre-trained language model [18].

3.1 Revisiting Selective State Space Model

State Space Models (SSMs) are commonly regarded as continuous linear time-invariant systems [59] that map input stimulation $x(t)\in\mathbb{R}^{1\times D}$ to output signal $y(t)\in\mathbb{R}^{1\times D}$ through a state vector $h(t)\in\mathbb{R}^{1\times N}$, where $t$, $D$ and $N$ indicate the time step, channel number of the signal and state size, respectively. These models can be formulated as the following linear ordinary differential equations:

$$h'(t)=\mathbf{A}h(t)+\mathbf{B}x(t),\quad y(t)=\mathbf{C}h(t)+\mathbf{D}x(t), \tag{1}$$

where $\mathbf{A}\in\mathbb{R}^{N\times N}$, $\mathbf{B}\in\mathbb{R}^{N\times D}$, $\mathbf{C}\in\mathbb{R}^{N\times D}$ and feedthrough coefficient $\mathbf{D}\in\mathbb{R}^{D}$.

Discretization.

Although SSM serves as a powerful tool in systems and control engineering, its time-continuous nature poses challenges for integration into deep learning architectures. To alleviate this issue, most methods utilize the zero-order hold rule [18] to discretize the continuous system described by Eq. 1 and convert continuous variables ($\mathbf{A}$, $\mathbf{B}$, $\mathbf{C}$, $\mathbf{D}$) into corresponding discrete parameters ($\bar{\mathbf{A}}$, $\bar{\mathbf{B}}$, $\bar{\mathbf{C}}$, $\bar{\mathbf{D}}$) over the specified sampling time-scale $\Delta\in\mathbb{R}^{D}$:

$$\bar{\mathbf{A}}=e^{\Delta\mathbf{A}},\quad \bar{\mathbf{B}}=\left(e^{\Delta\mathbf{A}}-I\right)\mathbf{A}^{-1}\mathbf{B},\quad \bar{\mathbf{C}}=\mathbf{C},\quad \bar{\mathbf{D}}=\mathbf{D} \tag{2}$$

In addition, many improved methods [38, 18] use an approximation of $\bar{\mathbf{B}}$ based on the first-order Taylor series:

$$\bar{\mathbf{B}}=\left(e^{\Delta\mathbf{A}}-I\right)\mathbf{A}^{-1}\mathbf{B}\approx(\Delta\mathbf{A})(\Delta\mathbf{A})^{-1}\Delta\mathbf{B}=\Delta\mathbf{B} \tag{3}$$
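As a small numerical check of Eqs. 2 and 3, the sketch below discretizes one scalar entry of the system and compares the exact zero-order hold value of $\bar{\mathbf{B}}$ with its first-order approximation $\Delta\mathbf{B}$. It assumes a diagonal $\mathbf{A}$, so that each entry can be treated as a scalar; the function names `zoh_discretize` and `first_order_b_bar` and the numeric values are illustrative and do not come from the paper.

```python
import math


def zoh_discretize(a, b, delta):
    """Zero-order hold (Eq. 2) for one scalar entry a of a diagonal A (assumption)."""
    a_bar = math.exp(delta * a)
    # (e^{delta*a} - 1) / a * b, using expm1 for accuracy when delta*a is small
    b_bar = math.expm1(delta * a) / a * b
    return a_bar, b_bar


def first_order_b_bar(b, delta):
    """First-order approximation of B-bar (Eq. 3): e^{delta*a} - 1 ~ delta*a."""
    return delta * b


# Illustrative values: a stable (negative) entry and two step sizes.
a, b = -1.0, 2.0
for delta in (0.01, 0.5):
    a_bar, b_bar = zoh_discretize(a, b, delta)
    approx = first_order_b_bar(b, delta)
    print(f"delta={delta}: a_bar={a_bar:.5f} exact={b_bar:.5f} approx={approx:.5f}")
```

The gap between the two values grows with the step size, which suggests the approximation is most faithful when $\Delta$ is small.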

Selective Mechanism.

Previous SSMs store information through finite states and inherent time-invariance, which limits their effectiveness. Therefore, Mamba [18] introduces a dynamic mechanism to selectively filter the input into a sequential state. Specifically, it utilizes linear projection to calculate the parameters $\{\mathbf{B}_i\}_{i=1}^{L}$, $\{\mathbf{C}_i\}_{i=1}^{L}$ and $\{\mathbf{\Delta}_i\}_{i=1}^{L}$ directly from the input sequence $\{x_i\}_{i=1}^{L}$ with $x_i\in\mathbb{R}^{1\times D}$ to improve the context-aware ability. Then the output sequence $\{y_i\}_{i=1}^{L}$ can be computed with those input-adaptive discretized parameters as follows:

$$h_i=\bar{\mathbf{A}}_i h_{i-1}+\bar{\mathbf{B}}_i x_i,\quad y_i=\mathbf{C}_i h_i+\mathbf{D}x_i \tag{4}$$
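The recurrence in Eq. 4 can be sketched as a plain sequential loop. The example below is a simplified single-channel version under several assumptions that are not taken from the paper: $\mathbf{A}$ is diagonal (a list of `N` scalars), $\bar{\mathbf{B}}_i$ uses the first-order form $\Delta_i\mathbf{B}_i$ of Eq. 3, and the per-step parameters `delta`, `B` and `C` are passed in directly rather than produced by a learned linear projection. The name `selective_scan` is illustrative.

```python
import math


def selective_scan(x, delta, B, C, a, d):
    """Sequential evaluation of Eq. 4 for one channel.

    x:     list of L scalar inputs
    delta: list of L step sizes (input-dependent in Mamba, given here)
    B, C:  lists of L vectors of length N (input-dependent, given here)
    a:     list of N diagonal entries of A (diagonal A is an assumption)
    d:     scalar feedthrough coefficient
    """
    n = len(a)
    h = [0.0] * n
    ys = []
    for x_i, dt, b_i, c_i in zip(x, delta, B, C):
        for s in range(n):
            a_bar = math.exp(dt * a[s])   # A-bar from Eq. 2
            b_bar = dt * b_i[s]           # B-bar from Eq. 3
            h[s] = a_bar * h[s] + b_bar * x_i
        ys.append(sum(c * hs for c, hs in zip(c_i, h)) + d * x_i)
    return ys


# With a = 0 (so A-bar = 1), delta = 1, B = C = 1 and d = 0,
# the scan reduces to a running sum of the inputs.
print(selective_scan([1.0, 2.0, 3.0], [1.0] * 3, [[1.0]] * 3, [[1.0]] * 3, [0.0], 0.0))
```

Each step only reads the previous state, so the cost is linear in the sequence length, but information from step $j$ reaches step $i$ only through the product of all intermediate $\bar{\mathbf{A}}$ values.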
Figure 2: Illustration of Tree State Space Model. With an image feature map $x$, we perform Tree Scanning Algorithm (TSA) to construct a $4$-connected graph with edge weights measured by dissimilarity between pixels. Then, we obtain an MST with vertices set $\Omega$ through a pruning algorithm and perform the state transition for each vertex in this topology (detailed in Sec. 3.2). Red arrows describe the propagation source of vertex $i$.

3.2 Tree State Space Model

Mamba [18] has showcased remarkable performance in modeling the dependencies of consecutive words in a sequence. However, its applicability in long-context tasks, especially visual modeling, still poses certain challenges. For visual tasks, many methods attempt to address this problem by employing fixed scanning strategies, such as multi-directional raster scan [38, 70], local scan [31], and continuous scan [62]. However, these handcrafted scanning methods fail to effectively preserve the 2D structural information of images.

Following the design in Mamba [18], we construct a transform block as a tree state space model, which is presented in Fig. 2. The only difference between our block and Mamba lies in the replacement of the structured state space block with the proposed tree scanning algorithm. In the tree scanning algorithm, we generate a tree topology and then propagate the state of each vertex along the topological path to obtain strong feature representations. In addition, our algorithm can effectively enhance language representations by incorporating such a tree topology during text processing, which overcomes the geometrical constraints of text sequences. In the following, we elaborate on the proposed tree scanning algorithm and its applications for multi-modal tasks.

Figure 3: Overview of GrootV. LN means LayerNorm and FFN is a feed-forward network in the basic block. S2 and P1 denote stride of $2$ and padding size of $1$ in convolution, respectively.

Tree Scanning Algorithm.

Given an input feature $X=\{x_i\}_{i=1}^{L}$ where $L$ is the sequence length (or the number of input pixels), we construct an undirected $m$-connected graph $G=(V,E)$ for the feature. $m$ is a hyper-parameter that indicates the number of adjacent tokens. Following [64, 50], we set $m=4$ for visual tasks, meaning each pixel is connected to its four neighboring pixels. For language tasks, we set $m=3$ by default, meaning each token is connected to the previous three tokens. In addition, the vertices $V$ represent the pixel (or token) embeddings, and $E$ indicates the edges of the graph. The edge weight is calculated by the feature dissimilarity between adjacent vertices. The dissimilarity metric is cosine distance by default; a comparison with other metrics is given in Table 6.
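The graph construction for the visual case ($m=4$) can be sketched in a few lines of pure Python: every pixel is linked to its right and bottom neighbour, which covers all four-neighbour pairs once, and each edge is weighted by cosine distance. The sketch also reduces the graph to a minimum spanning tree. Note that it uses a plain Kruskal algorithm with union-find as a stand-in, not the Contractive Boruvka algorithm that the paper adopts for this step; any correct MST algorithm yields a tree of the same total weight. The names `cosine_distance`, `grid_edges` and `kruskal_mst` and the toy feature map are illustrative assumptions.

```python
import math


def cosine_distance(u, v):
    """1 - cosine similarity; the small epsilon guards against zero vectors."""
    dot = sum(p * q for p, q in zip(u, v))
    norm = math.sqrt(sum(p * p for p in u)) * math.sqrt(sum(q * q for q in v))
    return 1.0 - dot / (norm + 1e-12)


def grid_edges(feat):
    """feat: H x W nested list of feature vectors -> list of (weight, p, q)."""
    height, width = len(feat), len(feat[0])
    edges = []
    for r in range(height):
        for c in range(width):
            p = r * width + c
            if c + 1 < width:    # right neighbour
                edges.append((cosine_distance(feat[r][c], feat[r][c + 1]), p, p + 1))
            if r + 1 < height:   # bottom neighbour
                edges.append((cosine_distance(feat[r][c], feat[r + 1][c]), p, p + width))
    return edges


def kruskal_mst(num_vertices, edges):
    """Kruskal with union-find (a swapped-in substitute for Contractive Boruvka)."""
    parent = list(range(num_vertices))

    def find(v):
        while parent[v] != v:
            parent[v] = parent[parent[v]]  # path halving
            v = parent[v]
        return v

    tree = []
    for w, p, q in sorted(edges):
        root_p, root_q = find(p), find(q)
        if root_p != root_q:
            parent[root_p] = root_q
            tree.append((w, p, q))
    return tree


# Toy 2 x 3 feature map with two semantic regions.
feat = [[[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]],
        [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]]
edges = grid_edges(feat)
tree = kruskal_mst(6, edges)
print(len(edges), len(tree))
```

In the toy example, pixels with identical features are joined by near-zero-weight edges, so the tree keeps both regions internally connected and crosses the region boundary only once.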

We use the Contractive Boruvka algorithm [2] to prune the edges with significant dissimilarity, which generates a minimum spanning tree (MST) $\mathcal{G}_T$ whose sum of dissimilarity weights is minimum out of all spanning trees. In the propagation process, we iteratively traverse each vertex, treating it as the root, and aggregate the features of the remaining vertices. Intuitively, applying state propagation within such a geometric configuration favors interactions among vertices with small spatial and feature distances. Following Mamba, we employ the data-dependent transition matrix for state propagation. For a vertex $k$, we denote the transition matrix with its parent as $\bar{\mathbf{A}}_k$. Furthermore, following Eq. 4, the state aggregation process for the $i$-th vertex can be formulated as:

$$h_i=\sum_{\forall j\in\Omega}S(E_{ij})\bar{\mathbf{B}}_j x_j,\quad S(E_{ij})=\prod_{k\in N_{ij}}\bar{\mathbf{A}}_k, \tag{5}$$

where $\Omega$ denotes the index set of all vertices in the tree. $S(E_{ij})$ represents the path weight of hyperedge $E_{ij}$ traced from the $j$-th vertex to the $i$-th vertex in the tree $\mathcal{G}_T$, and $N_{ij}$ indicates the index set of all vertices on this hyperedge. For visual tasks, we iterate over each vertex, treating it as the root of the spanning tree $\mathcal{G}_T$, and aggregate the states from the other vertices, thereby obtaining the transformed states $\{h_i\}_{i=1}^{L}$. For textual tasks, because of the causal prediction manner in large language models, we only take the last token as root and aggregate from other tokens. To achieve end-to-end training, we derive the derivative of the output hidden state $h_i$ to the input variables $\bar{\mathbf{A}}_k$, $\bar{\mathbf{B}}_j$ and $x_j$ as follows:

$$\frac{\partial h_i}{\partial x_j}=S(E_{ij})\bar{\mathbf{B}}_j,\quad \frac{\partial h_i}{\partial \bar{\mathbf{B}}_j}=S(E_{ij})x_j \tag{6}$$
$$\frac{\partial h_i}{\partial \bar{\mathbf{A}}_k}=\sum_{\forall j\in C_k^i}\bar{\mathbf{B}}_j x_j S(E_{kj})S(E_{in}), \tag{7}$$

where $C_k^i$ indicates the children of vertex $k$ in tree $\mathcal{G}_T$ whose root is the vertex $i$, and $n$ denotes the parent of vertex $k$ in Eq. 7. Finally, the output feature $Y$ can be formulated as:

$$Y=\mathbf{C}\odot Norm(H)+\mathbf{D}\odot X, \tag{8}$$

where $Y$, $H$ and $X$ indicate the stack of $\{y_i\}_{i=1}^{L}$, $\{h_i\}_{i=1}^{L}$ and $\{x_i\}_{i=1}^{L}$ respectively. $\odot$ denotes the element-wise multiplication.
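To make the complexity argument concrete, the sketch below evaluates Eq. 5 in two ways on a toy tree and checks that they agree: a brute-force version that re-roots the tree at every vertex (quadratic in the number of vertices), and a two-pass version that follows the leaf-to-root and root-to-leaf recursions of the forward part of Algorithm 1 (linear). It relies on simplifying assumptions that are not stated in the paper: all quantities are scalars, `u[j]` stands for $\bar{\mathbf{B}}_j x_j$, `a[k]` is the weight of the edge between vertex `k` and its parent in a fixed rooting, the same weight applies in both directions along that edge, and the path weight of a vertex to itself is 1. The function names and the toy data are illustrative.

```python
def brute_force_states(parent, a, u):
    """Evaluate Eq. 5 literally by walking the tree from every vertex: O(L^2)."""
    n = len(parent)
    adj = [[] for _ in range(n)]
    for k, p in enumerate(parent):
        if p >= 0:  # edge (k, parent[k]) carries weight a[k]
            adj[k].append((p, a[k]))
            adj[p].append((k, a[k]))
    h = []
    for i in range(n):
        total, stack = 0.0, [(i, -1, 1.0)]
        while stack:
            v, prev, path_w = stack.pop()
            total += path_w * u[v]
            for nxt, w in adj[v]:
                if nxt != prev:
                    stack.append((nxt, v, path_w * w))
        h.append(total)
    return h


def two_pass_states(parent, a, u):
    """Two-pass dynamic programme following the forward part of Algorithm 1: O(L)."""
    n = len(parent)
    root = parent.index(-1)
    children = [[] for _ in range(n)]
    for k, p in enumerate(parent):
        if p >= 0:
            children[p].append(k)
    order, head = [root], 0          # breadth-first order, root first
    while head < len(order):
        order.extend(children[order[head]])
        head += 1
    xi = list(u)                     # pass 1: leaves to root, subtree aggregates
    for v in reversed(order):
        for c in children[v]:
            xi[v] += a[c] * xi[c]
    h = [0.0] * n                    # pass 2: root to leaves
    for v in order:
        if v == root:
            h[v] = xi[v]
        else:
            # remove v's own subtree from the parent's state, then add it back unweighted
            h[v] = a[v] * (h[parent[v]] - a[v] * xi[v]) + xi[v]
    return h


# Toy tree: vertex 0 is the root; a[0] is unused.
parent = [-1, 0, 0, 1, 1, 2]
a = [1.0, 0.5, 0.8, 0.9, 0.3, 0.7]
u = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
slow = brute_force_states(parent, a, u)
fast = two_pass_states(parent, a, u)
print(max(abs(p - q) for p, q in zip(slow, fast)))
```

The second pass works because a tree is acyclic: the state of a parent already contains the contribution of the child's subtree scaled by the connecting edge weight, so subtracting that term leaves exactly the contribution of everything outside the subtree.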

Algorithm 1 Vision Tree Scanning
Input: Input feature $\{x_i\}_{i=1}^{L}$; Input matrix $\{\bar{\mathbf{B}}_i\}_{i=1}^{L}$; State matrix $\{\bar{\mathbf{A}}_i\}_{i=1}^{L}$; Gradient of loss to hidden states $\{\frac{\partial Loss}{\partial h_i}\}_{i=1}^{L}$; Minimum Spanning Tree $\mathcal{G}_T$.
$Root,\dots,Leaf\leftarrow BFS(\mathcal{G}_T)$       $\rhd$ Breadth-first topological order of $\mathcal{G}_T$
Forward:
  Initialization: $\{\xi_i\}_{i=1}^{L}\leftarrow\{x_i\}_{i=1}^{L}$
  for $i\leftarrow Leaf$ to $Root$ do
     $\xi_i=\bar{\mathbf{B}}_i x_i+\sum_{\forall j\in\{t\mid\text{Par}(t)=i\}}\xi_j\bar{\mathbf{A}}_j$
  end for
  for $i\leftarrow Root$ to $Leaf$ do
     if $i$ is $Root$ then
        $h_i=\xi_i$
     else
        $h_i=\bar{\mathbf{A}}_i(h_{\text{Par}(i)}-\bar{\mathbf{A}}_i\xi_i)+\xi_i=(1-\bar{\mathbf{A}}_i^{2})\xi_i+\bar{\mathbf{A}}_i h_{\text{Par}(i)}$
     end if
  end for
Backward:
  Initialization: $\{\eta_i\}_{i=1}^{L}\leftarrow\{\frac{\partial Loss}{\partial h_i}\}_{i=1}^{L}$
  for $i\leftarrow Leaf$ to $Root$ do
     $\eta_i=\bar{\mathbf{B}}_i\frac{\partial Loss}{\partial h_i}+\sum_{\forall j\in\{t\mid\text{Par}(t)=i\}}\eta_j\bar{\mathbf{A}}_j$
  end for
  for $i\leftarrow Root$ to $Leaf$ do
     if $i$ is $Root$ then
        $\frac{\partial Loss}{\partial x_i}=\eta_i\bar{\mathbf{B}}_i$,  $\frac{\partial Loss}{\partial\bar{\mathbf{B}}_i}=\eta_i x_i$,  $\frac{\partial Loss}{\partial\bar{\mathbf{A}}_i}=0$
     else
20:        Lossxi=(1𝐀¯i2)ηi𝐁¯i+𝐀¯iLossxPar(i)𝐁¯i𝐿𝑜𝑠𝑠subscript𝑥𝑖1superscriptsubscript¯𝐀𝑖2subscript𝜂𝑖subscript¯𝐁𝑖subscript¯𝐀𝑖𝐿𝑜𝑠𝑠subscript𝑥Par𝑖subscript¯𝐁𝑖\frac{\partial Loss}{\partial x_{i}}=(1-\bar{\mathbf{A}}_{i}^{2})\eta_{i}\bar{% \mathbf{B}}_{i}+\bar{\mathbf{A}}_{i}\frac{\partial Loss}{\partial x_{\text{Par% }(i)}}\bar{\mathbf{B}}_{i}divide start_ARG ∂ italic_L italic_o italic_s italic_s end_ARG start_ARG ∂ italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_ARG = ( 1 - over¯ start_ARG bold_A end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ) italic_η start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT over¯ start_ARG bold_B end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT + over¯ start_ARG bold_A end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT divide start_ARG ∂ italic_L italic_o italic_s italic_s end_ARG start_ARG ∂ italic_x start_POSTSUBSCRIPT Par ( italic_i ) end_POSTSUBSCRIPT end_ARG over¯ start_ARG bold_B end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ,  Loss𝐁¯i=(1𝐀¯i2)ηixi+𝐀¯iLoss𝐁¯Par(i)xi𝐿𝑜𝑠𝑠subscript¯𝐁𝑖1superscriptsubscript¯𝐀𝑖2subscript𝜂𝑖subscript𝑥𝑖subscript¯𝐀𝑖𝐿𝑜𝑠𝑠subscript¯𝐁Par𝑖subscript𝑥𝑖\frac{\partial Loss}{\partial\bar{\mathbf{B}}_{i}}=(1-\bar{\mathbf{A}}_{i}^{2}% )\eta_{i}x_{i}+\bar{\mathbf{A}}_{i}\frac{\partial Loss}{\partial\bar{\mathbf{B% }}_{\text{Par}(i)}}x_{i}divide start_ARG ∂ italic_L italic_o italic_s italic_s end_ARG start_ARG ∂ over¯ start_ARG bold_B end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_ARG = ( 1 - over¯ start_ARG bold_A end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ) italic_η start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT + over¯ start_ARG bold_A end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT divide start_ARG ∂ italic_L italic_o italic_s italic_s end_ARG start_ARG ∂ over¯ start_ARG bold_B end_ARG start_POSTSUBSCRIPT Par ( 
italic_i ) end_POSTSUBSCRIPT end_ARG italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT
        Loss𝐀¯i=ηi(hi𝐀¯iξi)+ξi(Lossxi𝐀¯iηi)=ηihi+ξiLossxi2ηiξi𝐀¯i𝐿𝑜𝑠𝑠subscript¯𝐀𝑖subscript𝜂𝑖subscript𝑖subscript¯𝐀𝑖subscript𝜉𝑖subscript𝜉𝑖𝐿𝑜𝑠𝑠subscript𝑥𝑖subscript¯𝐀𝑖subscript𝜂𝑖subscript𝜂𝑖subscript𝑖subscript𝜉𝑖𝐿𝑜𝑠𝑠subscript𝑥𝑖2subscript𝜂𝑖subscript𝜉𝑖subscript¯𝐀𝑖\frac{\partial Loss}{\partial\bar{\mathbf{A}}_{i}}=\eta_{i}*(h_{i}-\bar{% \mathbf{A}}_{i}\xi_{i})+\xi_{i}*(\frac{\partial Loss}{\partial x_{i}}-\bar{% \mathbf{A}}_{i}\eta_{i})=\eta_{i}h_{i}+\xi_{i}\frac{\partial Loss}{\partial x_% {i}}-2\eta_{i}\xi_{i}\bar{\mathbf{A}}_{i}divide start_ARG ∂ italic_L italic_o italic_s italic_s end_ARG start_ARG ∂ over¯ start_ARG bold_A end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_ARG = italic_η start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ∗ ( italic_h start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT - over¯ start_ARG bold_A end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT italic_ξ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) + italic_ξ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ∗ ( divide start_ARG ∂ italic_L italic_o italic_s italic_s end_ARG start_ARG ∂ italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_ARG - over¯ start_ARG bold_A end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT italic_η start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) = italic_η start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT italic_h start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT + italic_ξ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT divide start_ARG ∂ italic_L italic_o italic_s italic_s end_ARG start_ARG ∂ italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_ARG - 2 italic_η start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT italic_ξ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT over¯ start_ARG bold_A end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT
22:     end if
  end for
  Hidden states {hi}i=1Lsuperscriptsubscriptsubscript𝑖𝑖1𝐿\{h_{i}\}_{i=1}^{L}{ italic_h start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT } start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT; Grad. of loss to input feature {Lossxi}i=1Lsuperscriptsubscript𝐿𝑜𝑠𝑠subscript𝑥𝑖𝑖1𝐿\{\frac{\partial Loss}{\partial x_{i}}\}_{i=1}^{L}{ divide start_ARG ∂ italic_L italic_o italic_s italic_s end_ARG start_ARG ∂ italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_ARG } start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT; Grad. of loss to input matrix {Loss𝐁¯i}i=1Lsuperscriptsubscript𝐿𝑜𝑠𝑠subscript¯𝐁𝑖𝑖1𝐿\{\frac{\partial Loss}{\partial\bar{\mathbf{B}}_{i}}\}_{i=1}^{L}{ divide start_ARG ∂ italic_L italic_o italic_s italic_s end_ARG start_ARG ∂ over¯ start_ARG bold_B end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_ARG } start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT; Grad. of loss to state matrix {Loss𝐀¯i}i=1Lsuperscriptsubscript𝐿𝑜𝑠𝑠subscript¯𝐀𝑖𝑖1𝐿\{\frac{\partial Loss}{\partial\bar{\mathbf{A}}_{i}}\}_{i=1}^{L}{ divide start_ARG ∂ italic_L italic_o italic_s italic_s end_ARG start_ARG ∂ over¯ start_ARG bold_A end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT end_ARG } start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT.

Efficient Implementation for Multi-Modality.

For visual tasks, the tree scanning algorithm requires two levels of traversal across the entire pixel set, resulting in an unacceptable quadratic complexity $\mathcal{O}(L^2)$ relative to the number of pixels. To alleviate this issue, we utilize a dynamic programming procedure to accelerate the inference and training processes, as elaborated in Algorithm 1, which results in linear complexity $\mathcal{O}(L)$. For textual tasks, we perform a unidirectional aggregation approach (shown in Algorithm 2 of Appendix B) in adherence to the causal nature of language. Moreover, we provide the back-propagation process for both the Vision Tree Scanning and Language Tree Scanning processes; detailed proofs are given in Appendix C.
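The reduction from pairwise traversal to two linear passes can be illustrated with a minimal Python sketch. It is not the released implementation: it assumes scalar per-vertex parameters, a tree given as a parent array, and that the aggregated feature $\xi_i$ comes from a leaf-to-root pass of the form $\xi_i = \bar{\mathbf{B}}_i x_i + \sum_{\text{Par}(j)=i} \bar{\mathbf{A}}_j \xi_j$ (an assumption that mirrors the $\eta$ recurrence of the backward pass). The names `tree_scan`, `pairwise_scan`, `parent`, `A`, `B` and `x` are hypothetical.

```python
from collections import deque


def root_first_order(parent):
    """Breadth-first vertex order (root first); parent[root] must be -1."""
    children = [[] for _ in parent]
    root = None
    for v, p in enumerate(parent):
        if p < 0:
            root = v
        else:
            children[p].append(v)
    order, queue = [], deque([root])
    while queue:
        v = queue.popleft()
        order.append(v)
        queue.extend(children[v])
    return order


def tree_scan(parent, A, B, x):
    """Two-pass dynamic program: each vertex is visited once per pass."""
    order = root_first_order(parent)
    # Pass 1 (leaf -> root), assumed form: xi_i = B_i x_i + sum_children A_j xi_j
    xi = [b * v for b, v in zip(B, x)]
    for v in reversed(order):
        if parent[v] >= 0:
            xi[parent[v]] += A[v] * xi[v]
    # Pass 2 (root -> leaf): h_i = (1 - A_i^2) xi_i + A_i h_Par(i)
    h = [0.0] * len(x)
    for v in order:
        p = parent[v]
        h[v] = xi[v] if p < 0 else (1.0 - A[v] ** 2) * xi[v] + A[v] * h[p]
    return h


def pairwise_scan(parent, A, B, x):
    """Naive reference: sum over all vertex pairs of (path weight) * B_j x_j."""
    def ancestors(v):
        path = [v]
        while parent[path[-1]] >= 0:
            path.append(parent[path[-1]])
        return path

    out = []
    for i in range(len(x)):
        up_i = ancestors(i)
        total = 0.0
        for j in range(len(x)):
            up_j = ancestors(j)
            meet = next(v for v in up_j if v in up_i)  # lowest common ancestor
            weight = 1.0
            for v in up_i[:up_i.index(meet)] + up_j[:up_j.index(meet)]:
                weight *= A[v]  # the edge (v, Par(v)) carries weight A_v
            total += weight * B[j] * x[j]
        out.append(total)
    return out


parent = [-1, 0, 0, 1, 1]        # vertex 0 is the root
A = [1.0, 0.5, 0.25, 0.5, 0.2]   # A[0] is unused: the root has no parent edge
B = [1.0] * 5
x = [1.0, 2.0, 3.0, 4.0, 5.0]
fast = tree_scan(parent, A, B, x)
slow = pairwise_scan(parent, A, B, x)
```

On this toy tree both functions return the same hidden states, but `tree_scan` touches each vertex only twice, whereas `pairwise_scan` enumerates every vertex pair.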

3.3 Application for Vision and Language

GrootV

Given an image with a shape of $H\times W\times 3$, our goal is to obtain high-quality visual features for downstream tasks. To this end, we propose an effective vision architecture, GrootV, which consists of a stem module, several basic blocks and downsampling layers to generate hierarchical representations, as illustrated in Fig. 3. Overall, our GrootV comprises four stages, similar to previous general vision backbones [41, 40, 57, 38]. We integrate the stem module before the first stage to decrease the resolution of the input image signal by a factor of 4, resulting in a feature map with a shape of $\frac{H}{4}\times\frac{W}{4}\times C$. It includes two convolutions, two Layer Normalization (LN) layers and one GELU activation function. The kernel size for both convolutions is 3, with a stride of 2 and padding of 1. Similarly, a downsampling layer consists of a $3\times 3$ convolution with a stride of 2 and padding of 1 and an LN layer. Positioned between two stages, it serves to downsample the input feature map by a factor of 2. Motivated by [57, 38], we devise a residual block with skip connections to integrate our fundamental Tree State Space Model in Sec. 3.2. In detail, we first normalize the input features with an LN layer. Spatial priors and long-range dependencies are then obtained through our tree scanning algorithm, with residual connections established alongside the input features. Finally, a feedforward neural network is utilized to project the normalized features to output signals, as shown in Fig. 3. Based on the above components, we develop our GrootV in three scales, i.e., GrootV-Tiny, GrootV-Small and GrootV-Base.
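The resolution arithmetic of the stem and downsampling layers can be checked with a short Python sketch. It assumes the usual floor convention for convolution output sizes and a dilation of 1, neither of which is stated in the text; the helper names `conv_out` and `stage_resolutions` are hypothetical.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Spatial output size of a convolution (floor convention, dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def stage_resolutions(height, width, num_stages=4):
    """Feature-map size per stage: the stem applies two stride-2 convolutions,
    then one stride-2 downsampling layer sits between consecutive stages."""
    h, w = conv_out(conv_out(height)), conv_out(conv_out(width))
    sizes = [(h, w)]
    for _ in range(num_stages - 1):
        h, w = conv_out(h), conv_out(w)
        sizes.append((h, w))
    return sizes


sizes = stage_resolutions(224, 224)
for stage, (h, w) in enumerate(sizes, start=1):
    print(f"stage {stage}: {h} x {w}")
```

For a $224\times 224$ input, the first stage operates at $\frac{H}{4}\times\frac{W}{4}$ and each later stage halves the resolution again.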

GrootL

Recurrent neural networks rely on fixed memory to preserve past information, which poses limitations when handling long contexts where relevant words are distant from the current moment. While Mamba [18] employs a selection mechanism to enhance context awareness, its fixed memory size cannot expand over time, resulting in a restricted state space. Therefore, the ability to extrapolate decreases during scrolling as the prompt lengthens. To mitigate this issue, we propose an effective fine-tuning paradigm. Specifically, the tree-based topology branch is built upon one-way scrolling with a scaling factor, enabling state transitions within such a structure. This arrangement facilitates the preferential interaction of semantically related tokens. It is noteworthy that this paradigm does not introduce any additional trainable parameters. Instead, it utilizes pretrained state transformation parameters to conduct semantic aggregation by incorporating topological structures. Experimental results demonstrate the effectiveness of our approach.

Method Type #Param. #FLOPs Top-1 Acc.
Deit-S [54] T 22M 4.6G 79.9
Swin-T [40] T 28M 4.6G 81.3
CoAtNet-0 [11] T 25M 4.0G 81.6
SG-Former-S [46] T 23M 4.8G 83.2
ConvNeXt-T [41] C 29M 4.5G 82.1
SLaK-T [37] C 30M 5.0G 82.5
UniRepLKNet-T [13] C 31M 4.9G 83.2
InternImage-T [57] C 30M 5.0G 83.5
ViM-S [70] S 26M 5.1G 80.5
LocalViM-S [31] S 28M 4.8G 81.2
PlainMamba-L2 [62] S 25M 8.1G 81.6
Mamba-2D-S [34] S 24M - 81.7
S4ND-ConvNeXt-T [44] S 30M - 82.2
VMamba-T [38] S 31M 4.9G 82.5
LocalVMamba-T [31] S 26M 5.7G 82.7
GrootV-T (Ours) S 30M 4.8G 83.4
Swin-S [40] T 50M 8.7G 83.0
CoAtNet-1 [11] T 42M 8.0G 83.3
ConvNeXt-S [41] C 50M 8.7G 83.1
SLaK-S [37] C 55M 9.8G 83.8
UniRepLKNet-S [13] C 56M 9.1G 83.9
InternImage-S [57] C 50M 8.0G 84.2
HyenaViT-B [16] S 88M - 78.5
S4ND-ViT-B [44] S 89M - 80.4
PlainMamba-L3 [62] S 50M 14.4G 82.3
VMamba-S [38] S 50M 8.7G 83.6
LocalVMamba-S [31] S 50M 11.4G 83.7
GrootV-S (Ours) S 51M 8.5G 84.2
Deit-B [54] T 86M 55.4G 83.1
Swin-B [40] T 88M 15.4G 83.5
CoAtNet-2 [11] T 75M 16.0G 84.1
ConvNeXt-B [41] C 89M 15.4G 83.8
SLaK-B [37] C 95M 17.0G 84.0
Mamba-2D-B [34] S 92M - 83.0
VMamba-B [38] S 89M 15.4G 83.9
GrootV-B (Ours) S 91M 15.1G 84.8
Table 1: Image classification performance on the ImageNet-1K validation set. T, C and S indicate the model type of Transformer, CNN and SSM, respectively. All models take a scale of $224^2$ as input.

4 Experiments

We conduct extensive experiments to evaluate the effectiveness of GrootV and compare it with advanced CNN-based, Transformer-based, and SSM-based models covering various downstream tasks, including image classification, object detection and semantic segmentation. Furthermore, we validate the capability of GrootL in the field of natural language understanding.

4.1 Image Classification

Settings.

We assess the classification performance of GrootV on the ImageNet-1k dataset [12]. Following previous practices [40, 41, 57, 38], all GrootV models are trained for 300 epochs from scratch using the AdamW optimizer with a warm-up strategy of 20 epochs. During training, we utilize a Cosine Scheduler with an initial learning rate of $1\times 10^{-3}$ and weight decay of 0.05. In addition, the exponential moving average (EMA) is also applied.
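The learning-rate schedule described above can be sketched in a few lines of Python. The epoch counts and the peak learning rate come from the text; the linear shape of the warm-up, the final learning rate of zero and the function name `lr_at_epoch` are assumptions made only for illustration.

```python
import math


def lr_at_epoch(epoch, base_lr=1e-3, warmup_epochs=20, total_epochs=300, min_lr=0.0):
    """Learning rate with linear warm-up followed by cosine decay.

    The linear warm-up and min_lr=0.0 are assumptions; the text only states
    a 20-epoch warm-up and a cosine scheduler starting from 1e-3.
    """
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


schedule = [lr_at_epoch(e) for e in range(300)]
print(f"epoch 0: {schedule[0]:.2e}, epoch 19: {schedule[19]:.2e}, epoch 299: {schedule[299]:.2e}")
```

Under these assumptions the rate climbs to $1\times 10^{-3}$ by the end of warm-up and then decays smoothly toward zero over the remaining 280 epochs.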

Results.

The comparison results summarized in Table 1 show that GrootV leads all SSM-based models and is competitive with advanced CNNs and Transformers across the tiny, small, and base scales. Specifically, GrootV-T achieves 83.4% Top-1 accuracy, surpassing ViM-S by 2.9%, LocalViM-S by 2.2%, PlainMamba-L2 by 1.8% and VMamba-T by 0.9% with similar FLOPs. Additionally, it surpasses ConvNeXt-T by 1.3% and Swin-T by 2.1%, demonstrating the effectiveness of our method.

4.2 Object Detection

Settings.

We verify the detection performance of GrootV on the MSCOCO 2017 dataset [36] with the MMDetection library [3]. We follow previous works [38, 57, 40, 31, 49, 51, 67, 63, 6] to validate object detection and instance segmentation tasks with Mask R-CNN [26]. Specifically, we adopt the AdamW optimizer with a learning rate of $1\times 10^{-4}$ and a batch size of 16 to optimize the model built upon our classification backbones pre-trained on ImageNet-1K. The training schedules include $1\times$ (12 epochs) and $3\times$ (36 epochs) with multi-scale data augmentation.

Results.

As depicted in Table 7 (in Appendix A), our method outperforms existing methods on most evaluation metrics, especially for instance segmentation. Under the $1\times$ schedule, GrootV-T achieves 47.0 in box mAP (AP$^b$), which is 1.1 points higher than ViM-S and 0.5 points higher than VMamba-T. It is worth noting that GrootV-T outperforms ViM-S by 1.7 points with the $1\times$ schedule and LocalVMamba-T by 0.4 points with the $3\times$ schedule in mask mAP (AP$^m$). Moreover, the best AP$^b$ of 50.1 and AP$^m$ of 44.6 are obtained by GrootV-S under the $3\times$ schedule with multi-scale training.

Figure 4: Visualization of affinity maps at a specific position. The location is marked by the red cross in each input (a). TP denotes our tree topology scanning algorithm (b), which captures more detailed structural information and has a larger receptive field compared to raster scanning (c).
Method Type #FLOPs mIoU (SS) mIoU (MS)
Swin-T [40] T 945G 44.5 45.8
ConvNeXt-T [41] C 939G 46.0 46.7
SLaK-T [37] C 936G 47.6 -
InternImage-T [57] C 944G 47.9 48.1
UniRepLKNet-T [13] C 946G 48.6 49.1
ViM-S [70] S - 44.9 -
LocalViM-S [31] S 297G 46.4 47.5
PlainMamba-L2 [62] S 285G 46.8 -
VMamba-T [38] S 964G 47.3 48.3
LocalVMamba-T [31] S 970G 47.9 49.1
GrootV-T (Ours) S 941G 48.5 49.4
Swin-S [40] T 1038G 47.6 49.5
ConvNeXt-S [41] C 1027G 48.7 49.6
SLaK-S [37] C 1028G 49.4 -
InternImage-S [57] C 1017G 50.1 50.9
UniRepLKNet-S [13] C 1036G 50.5 51.0
PlainMamba-L3 [62] S 419G 49.1 -
VMamba-S [38] S 1081G 49.5 50.5
LocalVMamba-S [31] S 1095G 50.0 51.0
GrootV-S (Ours) S 1019G 50.7 51.7
Table 2: Semantic segmentation performance on the ADE20K val set. The crop size is set to $512^2$ for all models. SS and MS denote single-scale and multi-scale testing, respectively.

4.3 Semantic Segmentation

Settings.

To evaluate the semantic segmentation performance of our GrootV series, we train our models with UperNet [60] initialized by pre-trained classification weights on ADE20K [68] for 160k iterations, following common practices without additional augmentations for fair comparison.

Results.

Our method performs exceptionally well on segmentation tasks, as shown in Table 2. GrootV-T yields a clear improvement of +3.6 in single-scale mIoU compared to ViM-S and +1.9 in multi-scale mIoU compared to LocalViM-S. Furthermore, GrootV-S surpasses InternImage-S by 0.6 and 0.8 in single-scale and multi-scale mIoU, respectively. We consider the preservation of intricate structural details through tree topology scanning to be particularly advantageous for segmentation tasks that require pixel-level perception.
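For readers unfamiliar with the metric, the following minimal Python sketch shows how mean intersection-over-union (mIoU) is commonly computed from per-pixel labels. It is only an illustration of the metric definition and is not the evaluation code used for the reported numbers; the function name `mean_iou` and the choice to skip classes absent from both prediction and ground truth are assumptions.

```python
def mean_iou(pred, target, num_classes):
    """Mean IoU over classes that appear in the prediction or the ground truth.

    pred and target are flat sequences of integer class labels, one per pixel.
    """
    inter = [0] * num_classes
    union = [0] * num_classes
    for p, t in zip(pred, target):
        if p == t:
            inter[t] += 1
            union[t] += 1
        else:
            union[p] += 1  # false positive for class p
            union[t] += 1  # false negative for class t
    ious = [i / u for i, u in zip(inter, union) if u > 0]
    return sum(ious) / len(ious)


pred = [0, 0, 1, 1]
target = [0, 1, 1, 1]
score = mean_iou(pred, target, num_classes=2)
print(f"mIoU = {score:.4f}")
```

In this toy example class 0 has an IoU of 1/2 and class 1 has an IoU of 2/3, so the mean is 7/12.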

Method PIQA ↑ Arc-E ↑ SST ↑ WG ↑ L-ppl ↓ Race ↑ BQA ↑ Average Acc. ↑
Mamba [18] 64.5 48.0 65.6 51.8 16.1 27.4 16.8 45.7
+ LoRA [30] 64.7 48.3 65.1 52.2 17.7 28.6 17.8 46.1
+ GrootL (Ours) 65.0 49.8 69.5 51.1 15.9 28.9 19.2 47.2
Table 3: Evaluation on language model benchmarks. Arc-E, WG, L-ppl and BQA indicate the Arc-easy [8], WinoGrande, LAMBADA [45] and Openbookqa [43] benchmarks, respectively.

4.4 Language Understanding

We regard Mamba [18] with 130M parameters as the base model. To verify the effectiveness of our GrootL in natural language understanding, we first fine-tune pre-trained Mamba via LoRA [30] and GrootL under the same setting with the Alpaca data [53], which contains 52,000 instruction-tuning examples for supervised fine-tuning. Then we utilize popular language benchmarks provided in the open-sourced lm-evaluation-harness project [17] for evaluation, including PIQA [1], AI2-ARC [8], SST [55], WinoGrande, LAMBADA [45], Race [33] and Openbookqa [43]. The results in Table 3 demonstrate that our GrootL provides a benefit of +1.1% in average Acc. compared to LoRA. Owing to the short prompt length of the WinoGrande dataset, performance on this benchmark degrades by a marginal gap.
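The averaged accuracy can be re-derived from the per-benchmark scores with a short Python sketch. It assumes that the average is taken over the six accuracy columns of Table 3 and that the LAMBADA perplexity column is excluded; this assumption is not stated in the text, but it reproduces the reported averages to within rounding.

```python
# Accuracy columns of Table 3 in the order PIQA, Arc-E, SST, WG, Race, BQA.
# The L-ppl column (perplexity) is left out of the average by assumption.
scores = {
    "Mamba": [64.5, 48.0, 65.6, 51.8, 27.4, 16.8],
    "+ LoRA": [64.7, 48.3, 65.1, 52.2, 28.6, 17.8],
    "+ GrootL": [65.0, 49.8, 69.5, 51.1, 28.9, 19.2],
}

averages = {name: sum(vals) / len(vals) for name, vals in scores.items()}
gain_over_lora = averages["+ GrootL"] - averages["+ LoRA"]

for name, avg in averages.items():
    print(f"{name}: {avg:.2f}")
print(f"GrootL minus LoRA: {gain_over_lora:+.2f}")
```

The gap between the GrootL and LoRA averages comes out slightly above one point, in line with the +1.1% quoted above.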

Scanning Strategy Acc.
Raster Scan 82.6
Cross Scan 83.1
Tree Topology Scan 83.4
Table 4: Effectiveness of our algorithm.

Distance Metric Acc.
Manhattan 82.9
Euclidean 83.2
Cosine 83.4
Table 5: Impact of different distance metrics.

Root Setting Acc.
First vertex 82.9
Last vertex 83.0
All vertices 83.4
Table 6: Superiority of traversing all vertices.

4.5 Ablation Study & Qualitative Results

In this section, we conduct analysis experiments on ImageNet-1K dataset and present some visual results to illustrate the effectiveness of our algorithm.

Scanning Strategy.

We conduct a head-to-head comparison of different scanning strategies, as shown in Table 4. Tree topology scanning outperforms raster scanning and cross scanning by 0.8% and 0.3%, respectively, highlighting the superiority of our algorithm in vision recognition.

Distance Metric.

Before generating a minimum spanning tree from a connected graph, it is important to measure the edge weights between vertices. Therefore, we validate several distance metrics, as illustrated in Table 5. The results indicate that the Cosine distance most effectively represents the relationship between vertices, performing 0.5% better than Manhattan and 0.2% better than Euclidean.
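The role of the distance metric can be made concrete with a small Python sketch that weights the edges of a pixel grid and extracts a minimum spanning tree. Several points are assumptions made for illustration: the grid is taken to be 4-connected, cosine distance is taken to be one minus cosine similarity, and Kruskal's algorithm is swapped in for brevity without any claim that it is the tree construction used by the authors. All function names are hypothetical.

```python
import math


def cosine_distance(u, v):
    """1 - cosine similarity (assumed form; undefined for zero vectors)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)


def euclidean_distance(u, v):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def manhattan_distance(u, v):
    return sum(abs(a - b) for a, b in zip(u, v))


def grid_edges(feats, height, width, dist):
    """Weighted edges of a 4-connected grid; feats is row-major, one vector per pixel."""
    edges = []
    for r in range(height):
        for c in range(width):
            i = r * width + c
            if c + 1 < width:
                edges.append((dist(feats[i], feats[i + 1]), i, i + 1))
            if r + 1 < height:
                edges.append((dist(feats[i], feats[i + width]), i, i + width))
    return edges


def kruskal_mst(num_vertices, edges):
    """Minimum spanning tree via Kruskal's algorithm with union-find."""
    parent = list(range(num_vertices))

    def find(a):
        while parent[a] != a:
            parent[a] = parent[parent[a]]  # path halving
            a = parent[a]
        return a

    tree = []
    for w, u, v in sorted(edges):
        root_u, root_v = find(u), find(v)
        if root_u != root_v:
            parent[root_u] = root_v
            tree.append((u, v, w))
    return tree


# 2x2 toy image: the top row and the bottom row carry different features.
feats = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
edges = grid_edges(feats, 2, 2, cosine_distance)
tree = kruskal_mst(len(feats), edges)
```

In the toy image the tree first links pixels with identical features within each row and then joins the two rows through a single dissimilar edge, which is the behaviour that lets propagation follow content rather than scan order.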

Root Setting.

We traverse all vertices, treating each as a root, and perform state transitions along the topological path from the other vertices toward the root. This traversal ensures that each vertex captures long-range dependencies. To verify the effectiveness of this operation, we consider only the first and last vertices as the root in Table 6. The results show reductions of 0.5% and 0.4%, respectively.

Qualitative Results.

To better illustrate the superiority of our scanning strategy, we visualize the affinity maps of different positions marked by the red cross in each input image. For example, we set the anchor point in the upper left corner of the sky, as shown in the second row of Fig. 4(a). Our method can easily identify white houses, flagpoles, and the sky, which raster scanning fails to achieve. This demonstrates the capability of our algorithm to preserve detailed structural information. More comparisons can be seen in Fig. 6 (in Appendix D).

5 Conclusion & Limitations

In this paper, we propose a tree state space model to perform feature propagation on an input-aware topology. Besides, we introduce a linear complexity dynamic programming algorithm to enhance long-range interactions without increasing computational cost. With the proposed techniques, we establish general multi-modal networks that break the original sequence constraints and achieve stronger representation capabilities. Extensive experiments demonstrate the effectiveness of our method in both visual and language tasks. The limitation of our method is that the tree structure is not a common paradigm, and it needs to be specifically optimized according to the hardware device.

References

  • [1] Bisk, Y., Zellers, R., Gao, J., Choi, Y., et al.: Piqa: Reasoning about physical commonsense in natural language. In: AAAI. pp. 7432–7439 (2020)
  • [2] Borůvka, O.: O jistém problému minimálním (1926)
  • [3] Chen, K., Wang, J., Pang, J., Cao, Y., Xiong, Y., Li, X., Sun, S., Feng, W., Liu, Z., Xu, J., et al.: Mmdetection: Open mmlab detection toolbox and benchmark. arXiv preprint arXiv:1906.07155 (2019)
  • [4] Chen, Z., Duan, Y., Wang, W., He, J., Lu, T., Dai, J., Qiao, Y.: Vision transformer adapter for dense predictions. arXiv preprint arXiv:2205.08534 (2022)
  • [5] Cheng, C., Song, L., Xue, R., Wang, H., Sun, H., Ge, Y., Shan, Y.: Meta-adapter: An online few-shot learner for vision-language model. arXiv preprint arXiv:2311.03774 (2023)
  • [6] Cheng, T., Song, L., Ge, Y., Liu, W., Wang, X., Shan, Y.: Yolo-world: Real-time open-vocabulary object detection. arXiv preprint arXiv:2401.17270 (2024)
  • [7] Chung, J., Gulcehre, C., Cho, K., Bengio, Y.: Empirical evaluation of gated recurrent neural networks on sequence modeling. arXiv preprint arXiv:1412.3555 (2014)
  • [8] Clark, P., Cowhey, I., Etzioni, O., Khot, T., Sabharwal, A., Schoenick, C., Tafjord, O.: Think you have solved question answering? try arc, the ai2 reasoning challenge. arXiv preprint arXiv:1803.05457 (2018)
  • [9] Cubuk, E.D., Zoph, B., Mane, D., Vasudevan, V., Le, Q.V.: Autoaugment: Learning augmentation strategies from data. In: CVPR. pp. 113–123 (2019)
  • [10] Dai, J., Qi, H., Xiong, Y., Li, Y., Zhang, G., Hu, H., Wei, Y.: Deformable convolutional networks. In: ICCV. pp. 764–773 (2017)
  • [11] Dai, Z., Liu, H., Le, Q.V., Tan, M.: Coatnet: Marrying convolution and attention for all data sizes. NeurIPS 34, 3965–3977 (2021)
  • [12] Deng, J., Dong, W., Socher, R., Li, L.J., Li, K., Fei-Fei, L.: Imagenet: A large-scale hierarchical image database. In: CVPR. pp. 248–255. IEEE (2009)
  • [13] Ding, X., Zhang, Y., Ge, Y., Zhao, S., Song, L., Yue, X., Shan, Y.: Unireplknet: A universal perception large-kernel convnet for audio, video, point cloud, time-series and image recognition. CVPR (2023)
  • [14] Dong, X., Bao, J., Chen, D., Zhang, W., Yu, N., Yuan, L., Chen, D., Guo, B.: Cswin transformer: A general vision transformer backbone with cross-shaped windows. In: CVPR. pp. 12124–12134 (2022)
  • [15] Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., Houlsby, N.: An image is worth 16x16 words: Transformers for image recognition at scale. In: ICLR (2021)
  • [16] Fu, D., Arora, S., Grogan, J., Johnson, I., Eyuboglu, E.S., Thomas, A., Spector, B., Poli, M., Rudra, A., Ré, C.: Monarch mixer: A simple sub-quadratic gemm-based architecture. NeurIPS 36 (2023)
  • [17] Gao, L., Tow, J., Abbasi, B., Biderman, S., Black, S., DiPofi, A., Foster, C., Golding, L., Hsu, J., Le Noac’h, A., Li, H., McDonell, K., Muennighoff, N., Ociepa, C., Phang, J., Reynolds, L., Schoelkopf, H., Skowron, A., Sutawika, L., Tang, E., Thite, A., Wang, B., Wang, K., Zou, A.: A framework for few-shot language model evaluation (12 2023)
  • [18] Gu, A., Dao, T.: Mamba: Linear-time sequence modeling with selective state spaces. arXiv preprint arXiv:2312.00752 (2023)
  • [19] Gu, A., Dao, T., Ermon, S., Rudra, A., Ré, C.: Hippo: Recurrent memory with optimal polynomial projections. NeurIPS 33, 1474–1487 (2020)
  • [20] Gu, A., Goel, K., Gupta, A., Ré, C.: On the parameterization and initialization of diagonal state space models. NeurIPS 35, 35971–35983 (2022)
  • [21] Gu, A., Goel, K., Ré, C.: Efficiently modeling long sequences with structured state spaces. In: ICLR (2022)
  • [22] Gu, A., Johnson, I., Goel, K., Saab, K., Dao, T., Rudra, A., Ré, C.: Combining recurrent, convolutional, and continuous-time models with linear state space layers. NeurIPS 34, 572–585 (2021)
  • [23] Gupta, A., Gu, A., Berant, J.: Diagonal state spaces are as effective as structured state spaces. NeurIPS 35, 22982–22994 (2022)
  • [24] Han, K., Wang, Y., Xu, C., Guo, J., Xu, C., Wu, E., Tian, Q.: Ghostnets on heterogeneous devices via cheap operations. IJCV 130(4), 1050–1069 (2022)
  • [25] Hasani, R., Lechner, M., Wang, T.H., Chahine, M., Amini, A., Rus, D.: Liquid structural state-space models. arXiv preprint arXiv:2209.12951 (2022)
  • [26] He, K., Gkioxari, G., Dollár, P., Girshick, R.: Mask r-cnn. In: ICCV. pp. 2961–2969 (2017)
  • [27] He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: CVPR. pp. 770–778 (2016)
  • [28] Hochreiter, S., Schmidhuber, J.: Long short-term memory. Neural computation 9(8), 1735–1780 (1997)
  • [29] Howard, A.G., Zhu, M., Chen, B., Kalenichenko, D., Wang, W., Weyand, T., Andreetto, M., Adam, H.: Mobilenets: Efficient convolutional neural networks for mobile vision applications. arXiv preprint arXiv:1704.04861 (2017)
  • [30] Hu, E.J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S., Wang, L., Chen, W.: Lora: Low-rank adaptation of large language models. In: ICLR (2022)
  • [31] Huang, T., Pei, X., You, S., Wang, F., Qian, C., Xu, C.: Localmamba: Visual state space model with windowed selective scan. arXiv preprint arXiv:2403.09338 (2024)
  • [32] Krizhevsky, A., Sutskever, I., Hinton, G.E.: Imagenet classification with deep convolutional neural networks. NeurIPS 25 (2012)
  • [33] Lai, G., Xie, Q., Liu, H., Yang, Y., Hovy, E.H.: RACE: large-scale reading comprehension dataset from examinations. In: EMNLP. pp. 785–794. Association for Computational Linguistics (2017)
  • [34] Li, S., Singh, H., Grover, A.: Mamba-nd: Selective state space modeling for multi-dimensional data. arXiv preprint arXiv:2402.05892 (2024)
  • [35] Li, Y., Song, L., Chen, Y., Li, Z., Zhang, X., Wang, X., Sun, J.: Learning dynamic routing for semantic segmentation. In: CVPR (2020)
  • [36] Lin, T.Y., Maire, M., Belongie, S., Hays, J., Perona, P., Ramanan, D., Dollár, P., Zitnick, C.L.: Microsoft coco: Common objects in context. In: ECCV. pp. 740–755. Springer (2014)
  • [37] Liu, S., Chen, T., Chen, X., Chen, X., Xiao, Q., Wu, B., Kärkkäinen, T., Pechenizkiy, M., Mocanu, D., Wang, Z.: More convnets in the 2020s: Scaling up kernels beyond 51x51 using sparsity. arXiv preprint arXiv:2207.03620 (2022)
  • [38] Liu, Y., Tian, Y., Zhao, Y., Yu, H., Xie, L., Wang, Y., Ye, Q., Liu, Y.: Vmamba: Visual state space model. arXiv preprint arXiv:2401.10166 (2024)
  • [39] Liu, Z., Hu, H., Lin, Y., Yao, Z., Xie, Z., Wei, Y., Ning, J., Cao, Y., Zhang, Z., Dong, L., et al.: Swin transformer v2: Scaling up capacity and resolution. In: CVPR. pp. 12009–12019 (2022)
  • [40] Liu, Z., Lin, Y., Cao, Y., Hu, H., Wei, Y., Zhang, Z., Lin, S., Guo, B.: Swin transformer: Hierarchical vision transformer using shifted windows. In: ICCV. pp. 10012–10022 (2021)
  • [41] Liu, Z., Mao, H., Wu, C.Y., Feichtenhofer, C., Darrell, T., Xie, S.: A convnet for the 2020s. In: CVPR. pp. 11976–11986 (2022)
  • [42] Loshchilov, I., Hutter, F.: Decoupled weight decay regularization. arXiv preprint arXiv:1711.05101 (2017)
  • [43] Mihaylov, T., Clark, P., Khot, T., Sabharwal, A.: Can a suit of armor conduct electricity? A new dataset for open book question answering. In: EMNLP. pp. 2381–2391. Association for Computational Linguistics (2018)
  • [44] Nguyen, E., Goel, K., Gu, A., Downs, G.W., Shah, P., Dao, T., Baccus, S.A., Ré, C.: S4nd: Modeling images and videos as multidimensional signals using state spaces. arXiv preprint arXiv:2210.06583 (2022)
  • [45] Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., Sutskever, I., et al.: Language models are unsupervised multitask learners. OpenAI blog 1(8), 9 (2019)
  • [46] Ren, S., Yang, X., Liu, S., Wang, X.: Sg-former: Self-guided transformer with evolving token reallocation. In: ICCV. pp. 6003–6014 (2023)
  • [47] Simonyan, K., Zisserman, A.: Very deep convolutional networks for large-scale image recognition. In: Bengio, Y., LeCun, Y. (eds.) ICLR (2015)
  • [48] Smith, J.T., Warrington, A., Linderman, S.W.: Simplified state space layers for sequence modeling. arXiv preprint arXiv:2208.04933 (2022)
  • [49] Song, L., Li, Y., Jiang, Z., Li, Z., Sun, H., Sun, J., Zheng, N.: Fine-grained dynamic head for object detection. NeurIPS (2020)
  • [50] Song, L., Li, Y., Li, Z., Yu, G., Sun, H., Sun, J., Zheng, N.: Learnable tree filter for structure-preserving feature transform. NeurIPS 32 (2019)
  • [51] Song, L., Zhang, S., Yu, G., Sun, H.: Tacnet: Transition-aware context network for spatio-temporal action detection. In: CVPR (2019)
  • [52] Song, L., Zhang, S., Liu, S., Li, Z., He, X., Sun, H., Sun, J., Zheng, N.: Dynamic grained encoder for vision transformers. NeurIPS (2021)
  • [53] Taori, R., Gulrajani, I., Zhang, T., Dubois, Y., Li, X., Guestrin, C., Liang, P., Hashimoto, T.B.: Stanford alpaca: An instruction-following llama model (2023)
  • [54] Touvron, H., Cord, M., Douze, M., Massa, F., Sablayrolles, A., Jégou, H.: Training data-efficient image transformers & distillation through attention. In: ICML. pp. 10347–10357. PMLR (2021)
  • [55] Wang, A., Singh, A., Michael, J., Hill, F., Levy, O., Bowman, S.R.: GLUE: A multi-task benchmark and analysis platform for natural language understanding. In: ICLR (2019)
  • [56] Wang, J., Song, L., Li, Z., Sun, H., Sun, J., Zheng, N.: End-to-end object detection with fully convolutional network. In: CVPR (2021)
  • [57] Wang, W., Dai, J., Chen, Z., Huang, Z., Li, Z., Zhu, X., Hu, X., Lu, T., Lu, L., Li, H., et al.: Internimage: Exploring large-scale vision foundation models with deformable convolutions. In: CVPR. pp. 14408–14419 (2023)
  • [58] Wang, W., Xie, E., Li, X., Fan, D.P., Song, K., Liang, D., Lu, T., Luo, P., Shao, L.: Pyramid vision transformer: A versatile backbone for dense prediction without convolutions. In: ICCV. pp. 568–578 (2021)
  • [59] Williams, R.L., Lawrence, D.A., et al.: Linear state-space control systems. John Wiley & Sons (2007)
  • [60] Xiao, T., Liu, Y., Zhou, B., Jiang, Y., Sun, J.: Unified perceptual parsing for scene understanding. In: ECCV. pp. 418–434 (2018)
  • [61] Xiao, Y., Luo, Z., Liu, Y., Ma, Y., Bian, H., Ji, Y., Yang, Y., Li, X.: Bridging the gap: A unified video comprehension framework for moment retrieval and highlight detection. CVPR (2024)
  • [62] Yang, C., Chen, Z., Espinosa, M., Ericsson, L., Wang, Z., Liu, J., Crowley, E.J.: Plainmamba: Improving non-hierarchical mamba in visual recognition. arXiv preprint arXiv:2403.17695 (2024)
  • [63] Yang, J., Song, L., Liu, S., Li, Z., Li, X., Sun, H., Sun, J., Zheng, N.: Dbq-ssd: Dynamic ball query for efficient 3d object detection. arXiv preprint arXiv:2207.10909 (2022)
  • [64] Yang, Q.: Stereo matching using tree filtering. IEEE TPAMI 37(4), 834–846 (2014)
  • [65] Yang, R., Song, L., Ge, Y., Li, X.: Boxsnake: Polygonal instance segmentation with box supervision. In: ICCV (2023)
  • [66] Zhang, S., Song, L., Gao, C., Sang, N.: Glnet: Global local network for weakly supervised action localization. IEEE Transactions on Multimedia 22(10), 2610–2622 (2019)
  • [67] Zhang, S., Song, L., Liu, S., Ge, Z., Li, Z., He, X., Sun, J.: Workshop on autonomous driving at cvpr 2021: Technical report for streaming perception challenge. arXiv preprint arXiv:2108.04230 (2021)
  • [68] Zhou, B., Zhao, H., Puig, X., Fidler, S., Barriuso, A., Torralba, A.: Scene parsing through ade20k dataset. In: CVPR. pp. 633–641 (2017)
  • [69] Zhou, H., Yang, R., Zhang, Y., Duan, H., Huang, Y., Hu, R., Li, X., Zheng, Y.: Unihead: unifying multi-perception for detection heads. TNNLS (2023)
  • [70] Zhu, L., Liao, B., Zhang, Q., Wang, X., Liu, W., Wang, X.: Vision mamba: Efficient visual representation learning with bidirectional state space model. arXiv preprint arXiv:2401.09417 (2024)

Appendix

Appendix A Detailed Training Settings and Results

A.1 Image Classification.

We follow previous works [57, 38, 40] to conduct the experiments. The models are trained with thirty-two 32GB V100 GPUs by default. We set the betas and momentum of the AdamW optimizer [42, 69, 61] to $(0.9, 0.999)$ and $0.9$, respectively. During training, we use a cosine scheduler with an initial learning rate of $1\times 10^{-3}$ and a weight decay of $0.05$. We adopt the common training data augmentation strategies following [31, 57], including AutoAugment [9] with the policy rand-m9-mstd0.5-inc1. A MixUp strategy with a ratio of $0.8$ is also applied to each batch. Horizontal flipping and random resized cropping are both used during training.

Figure 5: Classification performance comparison among SSM-based vision foundation models.

Performance Comparison.

We compare various SSM-based visual foundation models in Fig. 5, with different colors representing different models and different shapes indicating different model scales. The size of each shape indicates the number of model parameters. The horizontal axis denotes FLOPs and the vertical axis represents the Top-1 accuracy of the corresponding method on the ImageNet-1K validation set. Fig. 5 demonstrates that GrootV offers the best trade-off between efficiency and effectiveness among the compared models.

A.2 Object Detection.

For a fair comparison, we conduct the evaluation following common practice [57, 38, 40]. The models are trained with eight 32GB V100 GPUs by default. During the $1\times$ schedule, the input image is resized so that the shorter side is $800$ pixels, while the longer side does not exceed $1333$ pixels. The number of warmup steps is set to $500$ in the $1\times$ schedule. For the $3\times$ schedule, the shorter side is resized to $480\text{-}800$ pixels and the longer side does not exceed $1333$ pixels. The number of warmup steps is set to $1000$ in the $3\times$ schedule. The results in Table 7 demonstrate the effectiveness of GrootV in object detection and instance segmentation on COCO val2017.

Detector: Mask R-CNN. Columns marked (1×) use the $1\times$ schedule; columns marked (3× MS) use the $3\times$ MS schedule.

| Method | #FLOPs | $AP^b$ (1×) | $AP^b_{50}$ (1×) | $AP^b_{75}$ (1×) | $AP^m$ (1×) | $AP^m_{50}$ (1×) | $AP^m_{75}$ (1×) | $AP^b$ (3× MS) | $AP^b_{50}$ (3× MS) | $AP^b_{75}$ (3× MS) | $AP^m$ (3× MS) | $AP^m_{50}$ (3× MS) | $AP^m_{75}$ (3× MS) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Swin-T [40] | 267G | 42.7 | 65.2 | 46.8 | 39.3 | 62.2 | 42.2 | 46.0 | 68.1 | 50.3 | 41.6 | 65.1 | 44.9 |
| ConvNeXt-T [41] | 262G | 44.2 | 66.6 | 48.3 | 40.1 | 63.3 | 42.8 | 46.2 | 67.9 | 50.8 | 41.7 | 65.0 | 44.9 |
| CSWin-T [14] | 279G | 46.7 | 68.6 | 51.3 | 42.2 | 65.6 | 45.4 | 49.0 | 70.7 | 53.7 | 43.6 | 67.9 | 46.6 |
| ViM-S [70] | 218G | 44.9 | 67.1 | 49.3 | 41.0 | 64.2 | 44.1 | - | - | - | - | - | - |
| VMamba-T [38] | 286G | 46.5 | 68.5 | 50.7 | 42.1 | 65.5 | 45.3 | 48.5 | 69.9 | 52.9 | 43.2 | 66.8 | 46.3 |
| L-Vmamba-T [31] | 291G | 46.7 | 68.7 | 50.8 | 42.2 | 65.7 | 45.5 | 48.7 | 70.1 | 53.0 | 43.4 | 67.0 | 46.4 |
| GrootV-T (Ours) | 265G | 47.0 | 69.4 | 51.5 | 42.7 | 66.4 | 46.0 | 49.0 | 70.8 | 54.0 | 43.8 | 67.6 | 47.1 |
| Vit-Adapter-S [4] | 403G | 44.7 | 65.8 | 48.3 | 39.9 | 62.5 | 42.8 | 48.2 | 69.7 | 52.5 | 42.8 | 66.4 | 45.9 |
| Swin-S [40] | 354G | 44.8 | 66.6 | 48.9 | 40.9 | 63.4 | 44.2 | 48.2 | 69.8 | 52.8 | 43.2 | 67.0 | 46.1 |
| ConvNeXt-S [41] | 348G | 45.4 | 67.9 | 50.0 | 41.8 | 65.2 | 45.1 | 47.9 | 70.0 | 52.7 | 42.9 | 66.9 | 46.2 |
| InternImage-S [57] | 340G | 47.8 | 69.8 | 52.8 | 43.3 | 67.1 | 46.7 | 49.7 | 71.1 | 54.5 | 44.5 | 68.5 | 47.8 |
| VMamba-S [38] | 400G | 48.2 | 69.7 | 52.5 | 43.0 | 66.6 | 46.4 | 49.7 | 70.4 | 54.2 | 44.0 | 67.6 | 47.3 |
| L-Vmamba-S [31] | 414G | 48.4 | 69.9 | 52.7 | 43.2 | 66.7 | 46.5 | 49.9 | 70.5 | 54.4 | 44.1 | 67.8 | 47.4 |
| GrootV-S (Ours) | 341G | 48.6 | 70.3 | 53.5 | 43.6 | 67.5 | 47.1 | 50.1 | 71.2 | 54.9 | 44.6 | 68.7 | 47.8 |

Table 7: Object detection and instance segmentation performance on COCO val2017. $AP^b$ and $AP^m$ indicate the mAP of detection and segmentation, respectively. MS indicates the multi-scale training strategy.
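
The resizing rule from the detection settings (shorter side of 800 pixels, longer side capped at 1333 pixels) can be illustrated with a small sketch. The function name `resize_keep_ratio` and its parameters are hypothetical, and the aspect-preserving interpretation is an assumption based on common detection pipelines rather than a description of the authors' code.

```python
def resize_keep_ratio(width, height, short_side=800, max_long_side=1333):
    """Return (new_width, new_height) for an aspect-preserving resize.

    Assumption: the shorter side is scaled to `short_side` unless that would
    push the longer side past `max_long_side`, in which case the longer side
    is scaled to `max_long_side` instead.
    """
    scale = short_side / min(width, height)
    if max(width, height) * scale > max_long_side:
        # The cap on the longer side takes priority over the short-side target.
        scale = max_long_side / max(width, height)
    return round(width * scale), round(height * scale)


# A 640x480 image fits under the cap; a 1920x1080 image is limited by it.
print(resize_keep_ratio(640, 480))    # (1067, 800)
print(resize_keep_ratio(1920, 1080))  # (1333, 750)
```

For the multi-scale 3× schedule, the same function could be called with a `short_side` drawn from the 480-800 range for each image; how that value is sampled is not specified in the text above.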

A.3 Semantic Segmentation.

We optimize our GrootV-T/S using the AdamW optimizer with an initial learning rate of $6\times 10^{-5}$, which is decayed by a polynomial decay schedule with a power of $1.0$ following [57]. The number of warmup iterations is set to $1600$ with an initial learning rate of $1\times 10^{-6}$. The default input resolution is $512\times 512$, and FLOPs are calculated with an input size of $512\times 2048$. The models are trained with eight 32GB V100 GPUs by default.
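
As a rough illustration of the learning-rate schedule described for semantic segmentation, the sketch below combines a polynomial decay with a warmup phase. The total number of iterations (`total_steps=160_000`) is a hypothetical placeholder that the text does not state, and the linear shape of the warmup ramp is likewise an assumption.

```python
def learning_rate(step, total_steps=160_000, base_lr=6e-5, power=1.0,
                  warmup_steps=1600, warmup_start_lr=1e-6):
    """Polynomial-decay learning rate with an assumed linear warmup."""
    # Polynomial decay from base_lr toward zero over total_steps.
    decayed = base_lr * (1.0 - step / total_steps) ** power
    if step < warmup_steps:
        # Assumed warmup: ramp linearly from warmup_start_lr to the decayed value.
        alpha = step / warmup_steps
        return warmup_start_lr + alpha * (decayed - warmup_start_lr)
    return decayed


for step in (0, 800, 1600, 80_000, 160_000):
    print(step, learning_rate(step))
```

With a power of 1.0 the decay is simply linear, so the learning rate halfway through training is half of the initial value.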

Appendix B Language Tree Topology Scanning Operator

Algorithm 2 Language Tree Scanning
Input: Input feature $\{x_i\}_{i=1}^{L}$; Input matrix $\{\bar{\mathbf{B}}_i\}_{i=1}^{L}$; State matrix $\{\bar{\mathbf{A}}_i\}_{i=1}^{L}$; Gradient of loss to hidden states $\{\frac{\partial Loss}{\partial h_i}\}_{i=1}^{L}$; Minimum Spanning Tree $\mathcal{G}_T$.
Traverse Path: $Root, \dots, Leaf \leftarrow BFS(\mathcal{G}_T)$    $\rhd$ Breadth-first topological order of $\mathcal{G}_T$
Forward:
  Initialization: $\{\xi_i\}_{i=1}^{L} \leftarrow \{x_i\}_{i=1}^{L}$
  for $i \leftarrow Leaf$ to $Root$ do
    $\xi_i = \bar{\mathbf{B}}_i x_i + \sum_{\forall j \in \{t \mid \text{Par}(t) = i\}} \xi_j \bar{\mathbf{A}}_j$
  end for
Backward:
  for $i \leftarrow Root$ to $Leaf$ do
    if $i$ is $Root$ then
      $\frac{\partial Loss}{\partial x_i} = \eta_i \bar{\mathbf{B}}_i$,  $\frac{\partial Loss}{\partial \bar{\mathbf{B}}_i} = \eta_i x_i$,  $\frac{\partial Loss}{\partial \bar{\mathbf{A}}_i} = 0$
    else
      $\frac{\partial Loss}{\partial x_i} = \frac{\partial Loss}{\partial h_i} \bar{\mathbf{B}}_i + \bar{\mathbf{A}}_i \frac{\partial Loss}{\partial x_{\text{Par}(i)}} \bar{\mathbf{B}}_i$,  $\frac{\partial Loss}{\partial \bar{\mathbf{B}}_i} = \frac{\partial Loss}{\partial h_i} x_i + \bar{\mathbf{A}}_i \frac{\partial Loss}{\partial \bar{\mathbf{B}}_{\text{Par}(i)}} x_i$
      $\frac{\partial Loss}{\partial \bar{\mathbf{A}}_i} = \frac{\partial Loss}{\partial x'_{\text{Par}(i)}} h_i$
    end if
  end for
Output: Hidden states $\{h_i\}_{i=1}^{L}$; Grad. of loss to input feature $\{\frac{\partial Loss}{\partial x_i}\}_{i=1}^{L}$; Grad. of loss to input matrix $\{\frac{\partial Loss}{\partial \bar{\mathbf{B}}_i}\}_{i=1}^{L}$; Grad. of loss to state matrix $\{\frac{\partial Loss}{\partial \bar{\mathbf{A}}_i}\}_{i=1}^{L}$.
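
The forward loop of Algorithm 2 can be sketched in a few lines of Python. This is a toy version under stated assumptions: features and matrices are plain scalars instead of tensors, the tree is given as a hypothetical `parent` list (with `-1` for the root) rather than built from an MST, and the function names are illustrative only.

```python
from collections import deque


def bfs_order(parent, root):
    """Return (vertices in breadth-first order from the root, children lists)."""
    children = [[] for _ in parent]
    for v, p in enumerate(parent):
        if v != root:
            children[p].append(v)
    order, queue = [], deque([root])
    while queue:
        v = queue.popleft()
        order.append(v)
        queue.extend(children[v])
    return order, children


def aggregate_leaf_to_root(x, A_bar, B_bar, parent, root):
    """xi_i = B_i * x_i + sum over children j of xi_j * A_j (leaf-to-root pass)."""
    order, children = bfs_order(parent, root)
    xi = list(x)  # initialization: xi_i <- x_i
    # Reversed BFS order guarantees every child is processed before its parent.
    for i in reversed(order):
        xi[i] = B_bar[i] * x[i] + sum(xi[j] * A_bar[j] for j in children[i])
    return xi


# Chain 2 -> 1 -> 0 with vertex 0 as the root.
print(aggregate_leaf_to_root([1.0, 2.0, 3.0], [0.5, 0.5, 0.5],
                             [1.0, 1.0, 1.0], [-1, 0, 1], 0))  # [2.75, 3.5, 3.0]
```

Each vertex is visited once and each edge contributes one multiply-add, so the pass is linear in the number of vertices.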

Appendix C Algorithm Proof

In this section, we present detailed proofs for our tree scanning algorithm. The definitions of symbols are consistent with those in the main paper.

C.1 Proof for Algorithm 1.

We take an arbitrary vertex in the MST $\mathcal{G}_T$ as the $root$. According to the definition of the tree scanning algorithm introduced in Sec. 3.2, we can derive $h_{root}$ as follows:

$$h_{root} = \sum_{\forall j \in C_{root}} S(E_{root,j})\, \bar{\mathbf{B}}_j x_j, \quad S(E_{root,j}) = \prod_{k \in N_{root,j}} \bar{\mathbf{A}}_k, \tag{9}$$

which describes a process of aggregation from all leaf vertices to the $root$. Therefore, each vertex depends only on its children during this stage. Taking vertex $m$ as an example, $\text{Aggr}_m$ can be derived as:

$$\text{Aggr}_m(x) = \bar{\mathbf{B}}_m x_m + \sum_{\forall k \in \{t \mid \text{Par}(t) = m\}} \text{Aggr}_k(x)\, \bar{\mathbf{A}}_k. \tag{10}$$

Let $n$ be one of the children of $m$. Then $h_n$ can be derived as follows:

$$h_n = \text{Aggr}_n(x) + \bar{\mathbf{A}}_n \widetilde{\text{Aggr}}_m(x), \tag{11}$$

where $\widetilde{\text{Aggr}}_m(x)$ indicates the aggregation value from the vertices $\in \Omega \setminus C_m^{root}$ to vertex $m$. Therefore, we can obtain the propagation relationship between the hidden state of parent $m$ and child $n$:

$$\begin{aligned}
h_n &= \text{Aggr}_n(x) + \bar{\mathbf{A}}_n \widetilde{\text{Aggr}}_m(x) \\
&= \text{Aggr}_n(x) + \bar{\mathbf{A}}_n \left(h_m - \bar{\mathbf{A}}_n \text{Aggr}_n(x)\right) \\
&= \bar{\mathbf{A}}_n h_m + \left(1 - \bar{\mathbf{A}}_n^2\right) \text{Aggr}_n(x)
\end{aligned} \tag{12}$$

Through the above derivation, we can calculate $\{h_i\}_{i=1}^{L}$ with only two traversals (i.e., the aggregation from $leaf$ to $root$ and the propagation from $root$ to $leaf$) in the forward process as shown in Algorithm 1, thereby reducing the computational complexity from $\mathcal{O}(L^2)$ to $\mathcal{O}(L)$.
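
The two-traversal computation can be checked numerically on a toy tree. The sketch below is not the authors' implementation: it uses scalar values, a hypothetical `parent` list with `-1` at the root, and it assumes (as a reading of Eqs. (10) and (11)) that the edge between a child $c$ and its parent carries the weight $\bar{\mathbf{A}}_c$, so that $S(E_{ij})$ is the product of these weights along the path between $i$ and $j$. Under that assumption, the linear-time scan agrees with the quadratic pairwise sum.

```python
from collections import deque


def tree_scan(x, A_bar, B_bar, parent, root):
    """Linear-time forward scan: leaf-to-root aggregation, root-to-leaf propagation."""
    n = len(x)
    children = [[] for _ in range(n)]
    for v in range(n):
        if v != root:
            children[parent[v]].append(v)
    order, queue = [], deque([root])
    while queue:
        v = queue.popleft()
        order.append(v)
        queue.extend(children[v])
    # Pass 1 (leaf -> root), Eq. (10).
    aggr = [0.0] * n
    for m in reversed(order):
        aggr[m] = B_bar[m] * x[m] + sum(aggr[k] * A_bar[k] for k in children[m])
    # Pass 2 (root -> leaf), Eq. (12).
    h = [0.0] * n
    h[root] = aggr[root]
    for v in order[1:]:
        h[v] = A_bar[v] * h[parent[v]] + (1.0 - A_bar[v] ** 2) * aggr[v]
    return h


def tree_scan_bruteforce(x, A_bar, B_bar, parent, root):
    """Quadratic reference: h_i = sum_j S(E_ij) * B_j * x_j over all vertex pairs."""
    def path_to_root(v):
        path = [v]
        while v != root:
            v = parent[v]
            path.append(v)
        return path

    h = []
    for i in range(len(x)):
        total = 0.0
        for j in range(len(x)):
            a, b = path_to_root(i), path_to_root(j)
            # Drop shared ancestors; what remains are the child endpoints
            # of the edges on the path between i and j.
            while a and b and a[-1] == b[-1]:
                a.pop()
                b.pop()
            weight = 1.0
            for c in a + b:
                weight *= A_bar[c]
            total += weight * B_bar[j] * x[j]
        h.append(total)
    return h


parent = [-1, 0, 0, 1, 1]  # vertex 0 is the root
x = [1.0, 2.0, 3.0, 4.0, 5.0]
A_bar = [0.9, 0.5, 0.8, 0.7, 0.6]
B_bar = [1.0, 0.5, 2.0, 1.5, 1.0]
fast = tree_scan(x, A_bar, B_bar, parent, 0)
slow = tree_scan_bruteforce(x, A_bar, B_bar, parent, 0)
print(max(abs(f - s) for f, s in zip(fast, slow)))
```

The printed value is the largest discrepancy between the two methods and should be at the level of floating-point rounding.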

Next, we analyze the backpropagation process in Algorithm 1. According to the chain rule, we can easily calculate the derivative of $loss$ with respect to $x_i$:

$$\begin{aligned}
\frac{\partial loss}{\partial x_i} &= \sum_{j \in \Omega} \frac{\partial loss}{\partial h_j} \frac{\partial h_j}{\partial x_i} \\
&= \bar{\mathbf{B}}_i \sum_{j \in \Omega} S(E_{ji}) \frac{\partial loss}{\partial h_j}
\end{aligned} \tag{13}$$

Similarly, the derivative of $loss$ with respect to $\bar{\mathbf{B}}_i$ is:

$$\begin{aligned}
\frac{\partial loss}{\partial \bar{\mathbf{B}}_i} &= \sum_{j \in \Omega} \frac{\partial loss}{\partial h_j} \frac{\partial h_j}{\partial \bar{\mathbf{B}}_i} \\
&= x_i \sum_{j \in \Omega} S(E_{ji}) \frac{\partial loss}{\partial h_j}
\end{aligned} \tag{14}$$

The above formulas are equivalent to replacing the input $x$ with $\frac{\partial loss}{\partial h}$ during the forward process.
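
This equivalence can be illustrated with a dense, quadratic-cost sketch that makes the weights $S(E_{ij})$ explicit. All names here (`path_weights`, `scan`, `input_grads`) are hypothetical, values are scalars, and the edge-weight convention (the edge between a child $c$ and its parent carries $\bar{\mathbf{A}}_c$) is an assumption inferred from the aggregation recursion. The gradients from Eqs. (13) and (14) are obtained by running the same scan on $\frac{\partial loss}{\partial h}$ with a unit input matrix.

```python
def path_weights(A_bar, parent, root):
    """Dense matrix S[i][j]: product of A_bar over child endpoints on the path i-j."""
    n = len(parent)

    def up(v):
        path = [v]
        while v != root:
            v = parent[v]
            path.append(v)
        return path

    S = [[1.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            a, b = up(i), up(j)
            # Remove shared ancestors so only the edges between i and j remain.
            while a and b and a[-1] == b[-1]:
                a.pop()
                b.pop()
            for c in a + b:
                S[i][j] *= A_bar[c]
    return S


def scan(S, B_bar, x):
    """Forward pass in dense form: h_i = sum_j S[i][j] * B_j * x_j."""
    n = len(x)
    return [sum(S[i][j] * B_bar[j] * x[j] for j in range(n)) for i in range(n)]


def input_grads(S, B_bar, x, grad_h):
    """Eqs. (13) and (14): reuse the forward scan with grad_h as the input."""
    n = len(x)
    scanned = scan(S, [1.0] * n, grad_h)  # sum_j S(E_ji) * dloss/dh_j
    grad_x = [B_bar[i] * scanned[i] for i in range(n)]
    grad_B = [x[i] * scanned[i] for i in range(n)]
    return grad_x, grad_B


parent = [-1, 0, 0, 1, 1]  # vertex 0 is the root
A_bar = [0.9, 0.5, 0.8, 0.7, 0.6]
B_bar = [1.0, 0.5, 2.0, 1.5, 1.0]
x = [1.0, 2.0, 3.0, 4.0, 5.0]
grad_h = [0.3, -1.0, 0.2, 0.7, 1.1]  # stands in for dloss/dh
S = path_weights(A_bar, parent, 0)
grad_x, grad_B = input_grads(S, B_bar, x, grad_h)
print(grad_x)
print(grad_B)
```

The reuse works because the path between two vertices is the same in both directions, so $S(E_{ji}) = S(E_{ij})$ under the convention above.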

Subsequently, let vertex $k$ be a child of vertex $l$, and let $C_l^k$ denote the children of vertex $l$ when vertex $k$ is taken as the root. Then $\frac{\partial loss}{\partial \bar{\mathbf{A}}_k}$ is formulated as follows:

$$\begin{aligned}
\frac{\partial loss}{\partial \bar{\mathbf{A}}_k} &= \sum_{j \in \Omega} \frac{\partial loss}{\partial h_j} \frac{\partial h_j}{\partial \bar{\mathbf{A}}_k} \\
&= \sum_{j \in \Omega} \frac{\partial loss}{\partial h_j} \sum_{p \in \Omega} \frac{\partial\, S(E_{jp}) \bar{\mathbf{B}}_p x_p'}{\partial \bar{\mathbf{A}}_k} \\
&= \sum_{j \in C_l^k} \frac{\partial loss}{\partial h_j} \sum_{p \in C_k^l} S(E_{kp}) S(E_{jl}) \bar{\mathbf{B}}_p x_p' + \sum_{j \in C_k^l} \frac{\partial loss}{\partial h_j} \sum_{p \in C_l^k} S(E_{kj}) S(E_{pl}) \bar{\mathbf{B}}_p x_p' \\
&= \sum_{j \in C_l^k} S(E_{jl}) \frac{\partial loss}{\partial h_j} \sum_{p \in C_k^l} S(E_{kp}) \bar{\mathbf{B}}_p x_p' + \sum_{j \in C_k^l} S(E_{kj}) \frac{\partial loss}{\partial h_j} \sum_{p \in C_l^k} S(E_{pl}) \bar{\mathbf{B}}_p x_p' \\
&= \left(\frac{\partial loss}{\partial x_k} - \bar{\mathbf{A}}_k \text{Aggr}_k\!\left(\frac{\partial loss}{\partial h}\right)\right) * \text{Aggr}_k(x) + \text{Aggr}_k\!\left(\frac{\partial loss}{\partial h}\right) * \left(h_k - \bar{\mathbf{A}}_k \text{Aggr}_k(x)\right) \\
&= \frac{\partial loss}{\partial x_k} \text{Aggr}_k(x) + \text{Aggr}_k\!\left(\frac{\partial loss}{\partial h}\right) h_k - 2 \bar{\mathbf{A}}_k \text{Aggr}_k\!\left(\frac{\partial loss}{\partial h}\right) \text{Aggr}_k(x)
\end{aligned} \tag{15}$$
Lossxkξk+ηkhk2𝐀¯kηkξk(definition in Algorithm 1)absent𝐿𝑜𝑠𝑠subscript𝑥𝑘subscript𝜉𝑘subscript𝜂𝑘subscript𝑘2subscript¯𝐀𝑘subscript𝜂𝑘subscript𝜉𝑘definition in Algorithm 1\displaystyle\triangleq\frac{\partial Loss}{\partial x_{k}}\xi_{k}+\eta_{k}h_{% k}-2\bar{\mathbf{A}}_{k}\eta_{k}\xi_{k}\quad(\textit{definition in~{}\lx@cref{% creftype~refnum}{algvis}})≜ divide start_ARG ∂ italic_L italic_o italic_s italic_s end_ARG start_ARG ∂ italic_x start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT end_ARG italic_ξ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT + italic_η start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT italic_h start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT - 2 over¯ start_ARG bold_A end_ARG start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT italic_η start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT italic_ξ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ( definition in )
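
The last three lines of the derivation above differ only by expanding the two products. As a worked step, and assuming (as the symbol $*$ suggests) that all products are element-wise and therefore commute, write $\xi_k = \text{Aggr}_k(x)$ and $\eta_k = \text{Aggr}_k(\partial loss/\partial h)$:

```latex
\begin{aligned}
&\left(\frac{\partial Loss}{\partial x_k} - \bar{\mathbf{A}}_k \eta_k\right) * \xi_k
 + \eta_k * \left(h_k - \bar{\mathbf{A}}_k \xi_k\right) \\
&\quad= \frac{\partial Loss}{\partial x_k}\,\xi_k - \bar{\mathbf{A}}_k \eta_k \xi_k
 + \eta_k h_k - \bar{\mathbf{A}}_k \eta_k \xi_k \\
&\quad= \frac{\partial Loss}{\partial x_k}\,\xi_k + \eta_k h_k - 2\bar{\mathbf{A}}_k \eta_k \xi_k .
\end{aligned}
```

The two cross terms are identical under this assumption, which is where the factor of 2 comes from.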

This completes the proof of the forward and backward propagation of Algorithm 1.

C.2 Proof for Algorithm 2.

In sequence modeling tasks such as natural language understanding, we take only the last token as the root and replace the transition source $\Omega$ with $C_i$ to ensure causality. Therefore, the forward process requires only one traversal (from leaf to root), and the backpropagation process requires one more traversal (from root to leaf). The proof is similar to that of Algorithm 1.
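
The single leaf-to-root pass described above can be illustrated with a minimal sketch. This is not the authors' implementation: it assumes scalar features, a hypothetical per-edge weight `a[i]` on the edge from vertex `i` to its parent, and a precomputed input term `bx[i]` standing in for $\bar{\mathbf{B}}_i x_i$. Under these assumptions each vertex accumulates the weighted contributions of its own subtree in one linear-time traversal, and a brute-force sum over all descendant paths is included only as a cross-check.

```python
def leaf_to_root_aggregate(parent, a, bx):
    """One leaf-to-root pass on a rooted tree (illustrative assumption).

    parent[i]: parent index of vertex i; the root has parent -1.
    a[i]:      hypothetical scalar weight of the edge (i -> parent[i]).
    bx[i]:     stand-in for the input term of vertex i.
    Returns h with h[i] = bx[i] + sum over children c of a[c] * h[c].
    """
    n = len(parent)
    children = [[] for _ in range(n)]
    root = -1
    for i, p in enumerate(parent):
        if p < 0:
            root = i
        else:
            children[p].append(i)

    # Depth-first preorder lists every vertex before its descendants,
    # so the reversed order visits all descendants before the vertex.
    order, stack = [], [root]
    while stack:
        v = stack.pop()
        order.append(v)
        stack.extend(children[v])

    h = list(bx)
    for v in reversed(order):
        if parent[v] >= 0:
            h[parent[v]] += a[v] * h[v]
    return h


def brute_force_aggregate(parent, a, bx):
    """Reference: sum each vertex's input over every ancestor path."""
    h = [0.0] * len(parent)
    for j in range(len(parent)):
        weight, v = 1.0, j
        while v >= 0:
            h[v] += weight * bx[j]   # contribution of vertex j to ancestor v
            weight *= a[v]           # extend the path by the edge above v
            v = parent[v]
    return h


# Four tokens; the last token (index 3) is the root.
parent = [1, 3, 3, -1]
a = [0.5, 0.5, 0.25, 1.0]   # a[3] is unused because the root has no parent
bx = [1.0, 2.0, 4.0, 8.0]
print(leaf_to_root_aggregate(parent, a, bx))  # [1.0, 2.5, 4.0, 10.25]
```

Each edge is touched exactly once, so the pass is linear in the number of tokens, whereas the brute-force check walks every vertex-to-root path.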

Figure 6: Visualization of affinity maps at specific positions. The location is marked by a red cross in each affinity map. TP represents our Tree Scanning Algorithm.

Appendix D More Qualitative Results

Fig. 6 displays additional qualitative comparisons between our algorithm and previous scanning strategies (e.g., cross-scanning and raster-scanning), showing our method's stronger capability to perceive detailed structural information and capture long-range dependencies.

Appendix E Statistical Significance

Method        | PIQA  | Arc-Easy | SST   | WinoGrande | LAM-ppl | Race  | Openbookqa
GrootL (Ours) | 0.011 | 0.010    | 0.016 | 0.014      | 0.553   | 0.014 | 0.018
Table 8: Standard error on language model benchmarks. LAM-ppl indicates LAMBADA [45].

We report the standard error of our GrootL on language model benchmarks, computed with the open-source lm-evaluation-harness project, in Table 8.
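
For readers unfamiliar with the quantity in Table 8, the following is a generic sketch of the standard error of a mean over per-example scores. It is an illustration under stated assumptions rather than the harness's own code: the function name `standard_error_of_mean` is hypothetical, and the estimator actually used for each metric (for instance perplexity) may differ.

```python
import math


def standard_error_of_mean(scores):
    """Sample standard deviation divided by sqrt(n) (illustrative helper)."""
    n = len(scores)
    if n < 2:
        raise ValueError("need at least two samples")
    mean = sum(scores) / n
    # Unbiased sample variance (divides by n - 1).
    variance = sum((s - mean) ** 2 for s in scores) / (n - 1)
    return math.sqrt(variance / n)


# Toy per-example correctness scores (1 = correct, 0 = incorrect).
print(standard_error_of_mean([1, 1, 1, 0]))  # 0.25
```

The standard error shrinks with the square root of the number of evaluated examples, so larger benchmarks yield tighter estimates for the same per-example variability.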