DiG: Scalable and Efficient Diffusion Models with Gated Linear Attention

 Lianghui Zhu 1,2,⋄     Zilong Huang 2,✉     Bencheng Liao 1     Jun Hao Liew 2     Hanshu Yan 2
Jiashi Feng 2      Xinggang Wang 1,✉
1 School of EIC, Huazhong University of Science & Technology      2 ByteDance
Code & Models: hustvl/DiG
Abstract

Diffusion models with large-scale pre-training have achieved significant success in visual content generation, as exemplified by Diffusion Transformers (DiT). However, DiT models face challenges in scalability and efficiency because of the quadratic complexity of self-attention. In this paper, we leverage the long-sequence modeling capability of Gated Linear Attention (GLA) Transformers and extend their applicability to diffusion models. We introduce Diffusion Gated Linear Attention Transformers (DiG), a simple, adoptable solution with minimal parameter overhead that follows the DiT design but offers superior efficiency and effectiveness. In addition to outperforming DiT, DiG-S/2 trains 2.5× faster than DiT-S/2 and saves 75.7% GPU memory at a resolution of 1792×1792. Moreover, we analyze the scalability of DiG across a range of computational complexities. DiG models with increased depth/width or more input tokens consistently exhibit decreasing FID. We further compare DiG with other subquadratic-time diffusion models. At the same model size, DiG-XL/2 is 4.2× faster than the recent Mamba-based diffusion model at a resolution of 1024, and 1.8× faster than DiT with CUDA-optimized FlashAttention-2 at a resolution of 2048. These results demonstrate its superior efficiency among recent diffusion models.

⋄ This work was done when Lianghui Zhu was interning at ByteDance.
✉ Corresponding authors: Xinggang Wang ([email protected]) and Zilong Huang ([email protected])

1 Introduction

In recent years, diffusion models [20, 51, 4, 54] have emerged as potent deep generative models [46, 21, 13] renowned for their ability to generate high-quality images. Their rapid evolution has spurred extensive applications across various fields, including image-to-image generation [8, 63, 62], text-to-image generation [49, 45, 18, 6], speech synthesis [29, 7], video generation [19, 36, 34], and 3D generation [42, 64, 61]. Concurrent with the rapid development of sampling algorithms [52, 38, 33, 32, 22], the principal techniques have evolved into two main categories based on their architectural backbones: U-Net-based methods [20, 53] and ViT-based methods [14]. U-Net-based approaches continue to leverage the convolutional neural network (CNN) architecture [31, 48], whose hierarchical feature modeling ability benefits visual generation tasks. On the other hand, ViT-based methods [60, 1, 39] innovate by incorporating self-attention mechanisms [56] instead of traditional sampling blocks, resulting in streamlined yet effective performance.

Figure 1: Efficiency comparison among DiT [39], DiS [16], and our DiG model. DiG achieves higher training speed at lower GPU memory cost when processing high-resolution images. For example, DiG is 2.5× faster than DiT and saves 75.7% GPU memory at a resolution of 1792×1792, i.e., 12544 tokens per image. The patch size for all models is 2.
Figure 2: FPS comparison among DiS [16], DiT [39], DiT with FlashAttention-2 (Flash-DiT) [11], and our DiG model across different model sizes. We take DiG as the baseline. At a resolution of 1024×1024, DiG is 2.0× faster than DiS at the small size and 4.2× faster at the XL size. Furthermore, DiG-XL/2 is 1.8× faster than the highly optimized Flash-DiT-XL/2 at a resolution of 2048×2048.

Due to their excellent scalability in terms of performance, ViT-based methods [39] have been adopted as backbones in the most advanced diffusion works, including PixArt [6, 5], Sora [3], Stable Diffusion 3 [15], etc. However, the self-attention mechanism in ViT-based architectures scales quadratically with the input sequence length, making them resource-intensive when dealing with long sequence generation tasks, e.g., high-resolution image generation, video generation, etc. Recent advancements in subquadratic-time methods, i.e., Mamba [17], RWKV [40] and Gated Linear Attention Transformer (GLA) [59], try to improve the long-sequence processing efficiency by integrating Recurrent Neural Network (RNN) like architecture and hardware-aware algorithms. Among them, GLA incorporates data-dependent gating operation and hardware-efficient implementation to the Linear Attention Transformer, showing competitive performance but higher throughput.

Motivated by the success of GLA in natural language processing, we aim to transfer this success from language generation to visual content generation, i.e., to design a scalable and efficient diffusion backbone with an advanced linear attention [26, 10, 25] method. However, visual generation with GLA faces two challenges: unidirectional scanning and a lack of local awareness. To address them, we propose the Diffusion GLA (DiG) model, which incorporates a lightweight spatial reorient & enhancement module (SREM) for layer-wise control of the scanning direction and for local awareness. At the end of each block, the SREM changes the sequence order with an efficient matrix operation so that the next block scans in a different direction. The scanning directions comprise four basic patterns and enable each patch in the sequence to be aware of other patches along crisscross directions. Furthermore, we incorporate a depth-wise convolution (DWConv) [9] in the SREM to provide local awareness with very few parameters. Crucially, this paper presents a systematic ablation study that covers the integration of the SREM and a comprehensive evaluation of the model's architecture. It is important to highlight that DiG adheres to the first practices of linear attention transformers in diffusion generation, renowned for their superior scalability and efficiency in image generation tasks.

Compared with the ViT-based method DiT [39], DiG achieves superior performance on ImageNet [12] generation with the same hyper-parameters. Furthermore, DiG is more efficient in training speed and GPU memory for high-resolution image generation. This efficiency in memory and speed allows DiG to alleviate the resource constraints of long-sequence visual generation tasks. Notably, some Mamba-based subquadratic-time diffusion methods such as DiS [16] become less efficient as the model size scales, owing to their complicated block design and inability to use GPU tensor cores efficiently, as shown in Fig. 2. Thanks to the streamlined yet effective design of the DiG block, DiG maintains high efficiency at larger model sizes and even outperforms DiT equipped with FlashAttention-2 [11], a highly optimized implementation of exact softmax attention, at a resolution of 1024×1024.

Our main contributions can be summarized as follows:

  • We propose Diffusion GLA (DiG), which incorporates an efficient DiG block for both global visual context modeling through layer-wise scanning, and local visual awareness. To the best of our knowledge, DiG is the first exploration for diffusion backbone with linear attention transformer.

  • Without the burden of quadratic attention, the proposed DiG exhibits higher efficiency in both training speed and GPU memory cost while maintaining a similar modeling ability as DiT. Specifically, DiG is 2.5× faster than DiT and saves 75.7% GPU memory at the resolution of 1792×1792 as shown in Fig. 1.

  • We conduct extensive experiments on the ImageNet dataset. The results demonstrate that DiG presents scalable ability and achieves superior performance when compared with DiT. DiG is promising to serve as the next-generation backbone for diffusion models in the context of large-scale long-sequence generation.

2 Related Work

2.1 Linear Attention Transformer

Different from the standard autoregressive Transformer [57], which models the global attention matrix, the original linear attention [26] is essentially a linear RNN with matrix-valued hidden states. Linear attention introduces a similarity kernel $k(x,y)$ with an associated feature map $\phi(\cdot)$, i.e., $k(x,y)=\langle\phi(x),\phi(y)\rangle$. The calculation of the output $\mathbf{O}\in\mathbb{R}^{L\times d}$ (here $L$ is the sequence length and $d$ is the dimension) can be represented as follows:

$$\mathbf{O}_t=\frac{\sum_{i=1}^{t}k(\mathbf{Q}_t,\mathbf{K}_i)\mathbf{V}_i}{\sum_{i=1}^{t}k(\mathbf{Q}_t,\mathbf{K}_i)}=\frac{\sum_{i=1}^{t}\phi(\mathbf{Q}_t)\phi(\mathbf{K}_i)^{\top}\mathbf{V}_i}{\sum_{i=1}^{t}\phi(\mathbf{Q}_t)\phi(\mathbf{K}_i)^{\top}}=\frac{\phi(\mathbf{Q}_t)\sum_{i=1}^{t}\phi(\mathbf{K}_i)^{\top}\mathbf{V}_i}{\phi(\mathbf{Q}_t)\sum_{i=1}^{t}\phi(\mathbf{K}_i)^{\top}}, \tag{1}$$

where the query $\mathbf{Q}$, key $\mathbf{K}$, and value $\mathbf{V}$ have shapes of $L\times d$ and $t$ is the index of the current token. By denoting the hidden state $\mathbf{S}_t=\sum_{i=1}^{t}\phi(\mathbf{K}_i)^{\top}\mathbf{V}_i$ and the normalizer $z_t=\sum_{i=1}^{t}\phi(\mathbf{K}_i)^{\top}$, where $\mathbf{S}_t\in\mathbb{R}^{d\times d}$ and $z_t\in\mathbb{R}^{d\times 1}$, Eq. (1) can be rewritten as:

$$\mathbf{S}_t=\mathbf{S}_{t-1}+\phi(\mathbf{K}_t)^{\top}\mathbf{V}_t,\quad z_t=z_{t-1}+\phi(\mathbf{K}_t)^{\top},\quad \mathbf{O}_t=\frac{\phi(\mathbf{Q}_t)\mathbf{S}_t}{\phi(\mathbf{Q}_t)z_t}. \tag{2}$$

Recent works set $\phi(\cdot)$ to be the identity [35, 55] and remove $z_t$ [43], resulting in a linear attention Transformer of the following form:

$$\mathbf{S}_t=\mathbf{S}_{t-1}+\mathbf{K}_t^{\top}\mathbf{V}_t,\quad \mathbf{O}_t=\mathbf{Q}_t\mathbf{S}_t. \tag{3}$$

Directly using a linear attention Transformer for visual generation leads to poor performance due to the unidirectional modeling, so we propose a lightweight spatial reorient & enhancement module to take care of both modeling global context in crisscross directions and local information.
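As an illustration of Eq. (3), the following NumPy sketch evaluates linear attention with the identity feature map in its recurrent form and compares it with the equivalent causally masked quadratic form. It is not the authors' implementation; the function names and the toy sizes are hypothetical and chosen only for this example.

```python
import numpy as np

def linear_attention_recurrent(Q, K, V):
    """Eq. (3): S_t = S_{t-1} + K_t^T V_t, O_t = Q_t S_t (identity feature map, no normalizer)."""
    L, d_k = Q.shape
    d_v = V.shape[1]
    S = np.zeros((d_k, d_v))          # matrix-valued hidden state
    O = np.empty((L, d_v))
    for t in range(L):
        S = S + np.outer(K[t], V[t])  # rank-1 update K_t^T V_t
        O[t] = Q[t] @ S               # read out with the current query
    return O

def linear_attention_parallel(Q, K, V):
    """Equivalent quadratic form: causally masked (Q K^T) V, without softmax."""
    L = Q.shape[0]
    mask = np.tril(np.ones((L, L)))   # token t only sees tokens i <= t
    return (mask * (Q @ K.T)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((6, 4)) for _ in range(3))  # toy sizes: L = 6, d = 4
O_rec = linear_attention_recurrent(Q, K, V)
O_par = linear_attention_parallel(Q, K, V)
print(np.allclose(O_rec, O_par))
```

The recurrent form carries only a $d\times d$ state from token to token, whereas the masked form materializes an $L\times L$ matrix. The lower-triangular mask also makes the unidirectional nature of the model explicit: the output at position $t$ depends only on tokens up to $t$.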

2.2 Backbones in Diffusion Models

Existing diffusion models typically employ U-Net as the backbone [20, 47] for image generation. Recently, Vision Transformer (ViT)-based backbones [39, 1, 6, 5, 3] have received significant attention due to the scalability of the transformer and its natural fit for multi-modal learning. However, ViT-based architectures suffer from quadratic complexity, limiting their practicality in long-sequence generation tasks such as high-resolution image synthesis and video generation. To mitigate this, recent works explore subquadratic-time approaches to handle long sequences efficiently. For example, DiS [16], DiffuSSM [58], and ZigMa [23] employ state-space models as diffusion backbones for better computational efficiency. Diffusion-RWKV [58] adopts an RWKV architecture in diffusion models for image generation.

Our DiG also follows this line of research, aiming at improving the efficiency of long sequence processing by adopting Gated Linear Attention Transformer (GLA) as diffusion backbones. Our proposed adaptation maintains the fundamental structure and benefits of GLA while introducing a few crucial modifications necessary for generating high-fidelity visual data.

3 Method

3.1 Preliminaries

Gated Linear Attention Transformer.

The Gated Linear Attention Transformer (GLA) [59] combines a data-dependent gating mechanism and linear attention, achieving superior recurrent modeling performance. Given an input $\mathbf{X}\in\mathbb{R}^{L\times d}$ (here $L$ is the sequence length and $d$ is the dimension), GLA calculates the query, key, and value vectors as follows:

$$\mathbf{Q}=\mathbf{X}\mathbf{W}_Q\in\mathbb{R}^{L\times d_k},\quad \mathbf{K}=\mathbf{X}\mathbf{W}_K\in\mathbb{R}^{L\times d_k},\quad \mathbf{V}=\mathbf{X}\mathbf{W}_V\in\mathbb{R}^{L\times d_v}, \tag{4}$$

where $\mathbf{W}_Q$, $\mathbf{W}_K$, and $\mathbf{W}_V$ are linear projection weights, and $d_k$ and $d_v$ are the key and value dimensions. Next, GLA computes the gating matrix $\mathbf{G}$ as follows:

$$\mathbf{G}_t=\alpha_t^{\top}\beta_t\in\mathbb{R}^{d_k\times d_v},\quad \alpha=\frac{\sigma\left(\mathbf{X}\mathbf{W}_{\alpha}+b_{\alpha}\right)}{\tau}\in\mathbb{R}^{L\times d_k},\quad \beta=\frac{\sigma\left(\mathbf{X}\mathbf{W}_{\beta}+b_{\beta}\right)}{\tau}\in\mathbb{R}^{L\times d_v}, \tag{5}$$

where $t$ is the index of the token, $\sigma$ is the $\mathtt{sigmoid}$ function, $b$ is the bias term, and $\tau\in\mathbb{R}$ is a temperature term. As shown in Fig. 3, the final output $\mathbf{Y}_t$ is obtained as follows:

Figure 3: Pipeline of GLA.
\begin{align}
\mathbf{S}_{t-1}' &= \mathbf{G}_t\odot\mathbf{S}_{t-1}\in\mathbb{R}^{d_k\times d_v}, \tag{6}\\
\mathbf{S}_t &= \mathbf{S}_{t-1}'+\mathbf{K}_t^{\top}\mathbf{V}_t\in\mathbb{R}^{d_k\times d_v}, \tag{7}\\
\mathbf{O}_t &= \mathbf{Q}_t\mathbf{S}_t\in\mathbb{R}^{1\times d_v}, \tag{8}\\
\mathbf{R}_t &= \mathtt{Swish}(\mathbf{X}_t\mathbf{W}_r+b_r)\in\mathbb{R}^{1\times d_v}, \tag{9}\\
\mathbf{Y}_t &= (\mathbf{R}_t\odot\mathrm{LN}(\mathbf{O}_t))\mathbf{W}_O\in\mathbb{R}^{1\times d}, \tag{10}
\end{align}

where $\mathtt{Swish}$ is the Swish [44] activation function, and $\odot$ is the element-wise multiplication operation. In subsequent sections, we use $\mathbf{GLA}(\cdot)$ to refer to the gated linear attention computation for the input sequence.
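To make the data flow of Eqs. (4)-(10) concrete, the following NumPy sketch evaluates a single GLA head token by token. It is a minimal sketch under stated assumptions, not the hardware-efficient implementation of [59]: the weights are random, LN is taken to be layer normalization without learnable scale and shift, the temperature default `tau=1.0` is a placeholder, and all names (`gla_forward`, the parameter dictionary keys) are hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def swish(x):
    return x * sigmoid(x)

def layer_norm(x, eps=1e-5):
    # Assumption: LN is layer normalization without learnable scale/shift.
    return (x - x.mean()) / np.sqrt(x.var() + eps)

def gla_forward(X, p, tau=1.0):
    """Token-by-token evaluation of Eqs. (4)-(10) for a single head."""
    Q, K, V = X @ p["W_Q"], X @ p["W_K"], X @ p["W_V"]        # Eq. (4)
    alpha = sigmoid(X @ p["W_alpha"] + p["b_alpha"]) / tau     # Eq. (5), L x d_k
    beta = sigmoid(X @ p["W_beta"] + p["b_beta"]) / tau        # Eq. (5), L x d_v
    L, d = X.shape
    S = np.zeros((K.shape[1], V.shape[1]))                     # d_k x d_v hidden state
    Y = np.empty((L, d))
    for t in range(L):
        G_t = np.outer(alpha[t], beta[t])                      # Eq. (5): alpha_t^T beta_t
        S = G_t * S + np.outer(K[t], V[t])                     # Eqs. (6)-(7)
        O_t = Q[t] @ S                                         # Eq. (8)
        R_t = swish(X[t] @ p["W_r"] + p["b_r"])                # Eq. (9)
        Y[t] = (R_t * layer_norm(O_t)) @ p["W_O"]              # Eq. (10)
    return Y

# Toy sizes and random weights, chosen only for illustration.
rng = np.random.default_rng(0)
L, d, d_k, d_v = 8, 16, 4, 6
shapes = {"W_Q": (d, d_k), "W_K": (d, d_k), "W_V": (d, d_v),
          "W_alpha": (d, d_k), "b_alpha": (d_k,),
          "W_beta": (d, d_v), "b_beta": (d_v,),
          "W_r": (d, d_v), "b_r": (d_v,), "W_O": (d_v, d)}
params = {name: 0.1 * rng.standard_normal(shape) for name, shape in shapes.items()}
X = rng.standard_normal((L, d))
Y = gla_forward(X, params)
print(Y.shape)  # (8, 16)
```

Compared with plain linear attention, the only change to the recurrence is the element-wise decay of the state by the data-dependent gate before the rank-1 update. The loop is still causal, so each output row depends only on the current and earlier tokens.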

Diffusion Models.

Before introducing the proposed method, we provide a concise review of some basic concepts of denoising diffusion probabilistic models (DDPM) [20]. DDPM takes noise as input and samples images by iteratively denoising it. The forward process of DDPM is a stochastic process in which the initial image $x_0$ is gradually corrupted by noise and finally transformed into a simpler, noise-dominated state. The forward noising process can be represented as follows:

\begin{align}
q(x_{1:T}\mid x_0) &= \prod_{t=1}^{T}q(x_t\mid x_{t-1}), \tag{11}\\
q(x_t\mid x_0) &= \mathcal{N}(x_t;\sqrt{\bar{\alpha}_t}\,x_0,(1-\bar{\alpha}_t)I), \tag{12}
\end{align}

where $x_{1:T}$ is the sequence of noised images from time $t=1$ to $t=T$. Then, DDPM learns the reverse process that recovers the original image with learned $\mu_{\theta}$ and $\Sigma_{\theta}$:

$$p_{\theta}(x_{t-1}\mid x_t)=\mathcal{N}(x_{t-1};\mu_{\theta}(x_t),\Sigma_{\theta}(x_t)), \tag{13}$$

where $\theta$ denotes the parameters of the denoiser, which are trained with the variational lower bound [51] on the log-likelihood of the observed data $x_0$:

$$\mathcal{L}(\theta)=-p(x_0\mid x_1)+\sum_{t}D_{KL}(q^{*}(x_{t-1}\mid x_t,x_0)\parallel p_{\theta}(x_{t-1}\mid x_t)), \tag{14}$$

where $\mathcal{L}$ is the full loss. To further simplify the training process of DDPM, researchers reparameterize $\mu_{\theta}$ as a noise prediction network $\epsilon_{\theta}$ and minimize the mean squared error loss $\mathcal{L}_{\text{simple}}$ between $\epsilon_{\theta}(x_t)$ and the true Gaussian noise $\epsilon_t$:

$$\mathcal{L}_{\text{simple}}(\theta)=\|\epsilon_{\theta}(x_t)-\epsilon_t\|_2^2. \tag{15}$$

However, to train a diffusion model that can learn a variable reverse process covariance $\Sigma_{\theta}$, we need to optimize the full $D_{KL}$ term. In this paper, we follow DiT [39] to train the network, where we use the simple loss $\mathcal{L}_{\text{simple}}$ to train the noise prediction network $\epsilon_{\theta}$ and use the full loss $\mathcal{L}$ to train the covariance prediction network $\Sigma_{\theta}$. After the training process, we follow the stochastic sampling process to generate images from the learned $\epsilon_{\theta}$ and $\Sigma_{\theta}$.
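The closed-form forward process of Eq. (12) and the simple loss of Eq. (15) can be sketched in a few lines of NumPy. This is only an illustration: the linear noise schedule with 1000 steps is an assumption (the section does not specify a schedule), the "network" is a placeholder that predicts zero noise, and the function names are hypothetical.

```python
import numpy as np

def make_alpha_bar(T=1000, beta_start=1e-4, beta_end=0.02):
    # Assumed linear beta schedule; alpha_bar_t is the cumulative product of (1 - beta_t).
    betas = np.linspace(beta_start, beta_end, T)
    return np.cumprod(1.0 - betas)

def q_sample(x0, t, alpha_bar, rng):
    """Draw x_t ~ q(x_t | x_0) as in Eq. (12); t is a 0-based index into alpha_bar."""
    eps = rng.standard_normal(x0.shape)
    x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps
    return x_t, eps

def simple_loss(eps_pred, eps):
    """Eq. (15): squared L2 distance between predicted and true noise."""
    return float(np.sum((eps_pred - eps) ** 2))

rng = np.random.default_rng(0)
alpha_bar = make_alpha_bar()
x0 = rng.standard_normal((32, 32, 4))          # stand-in with the shape of a latent
x_t, eps = q_sample(x0, t=500, alpha_bar=alpha_bar, rng=rng)
loss = simple_loss(np.zeros_like(eps), eps)    # placeholder predictor: always zero noise
```

Because $x_t$ is available in closed form for any $t$, a training step can sample a random timestep, noise the input once, and regress the noise, without simulating the whole chain of Eq. (11).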

3.2 Diffusion GLA

We present Diffusion GLA (DiG), a new architecture for diffusion generation. Our goal is to be as faithful to the standard GLA architecture as possible to retain its scaling ability and high-efficiency properties. An overview of the proposed GLA is shown in Fig. 3. The standard GLA is designed for the causal language modeling of 1-D sequences. To process the DDPM training of images, we follow some of the best practices of previous vision transformer architectures [14, 39]. DiG first takes a spatial representation $z$ output by the VAE encoder [27, 47] as input. For a $256\times256\times3$ image fed to the VAE encoder, the spatial representation $z$ has shape $32\times32\times4$. DiG subsequently converts the spatial input into a token sequence $\mathbf{z}_p\in\mathbb{R}^{T\times(P^2\cdot C)}$ through the patchify layer, where $T$ is the length of the token sequence, $C$ is the number of channels of the spatial representation, and $P$ is the patch size; halving $P$ quadruples $T$. Next, we linearly project $\mathbf{z}_p$ to dimension $D$ and add frequency-based positional embeddings $\mathbf{E}_{pos}\in\mathbb{R}^{T\times D}$ to all projected tokens, as follows:

$$\mathbf{z}_0 = [\mathbf{z}_p^1\mathbf{W};\ \mathbf{z}_p^2\mathbf{W};\ \cdots;\ \mathbf{z}_p^T\mathbf{W}] + \mathbf{E}_{pos}, \tag{16}$$

where $\mathbf{z}_p^t$ is the $t$-th patch of $\mathbf{z}_p$ and $\mathbf{W}\in\mathbb{R}^{(P^2\cdot C)\times D}$ is the learnable projection matrix. For conditional information such as the noise timestep $t\in\mathbb{R}$ and the class label $y\in\mathbb{R}$, we adopt a multi-layer perceptron (MLP) and an embedding layer as the timestep embedder and label embedder, respectively.

$$\mathbf{t} = \mathbf{MLP}(t), \quad \mathbf{y} = \mathbf{Embed}(y), \tag{17}$$

where $\mathbf{t}\in\mathbb{R}^{1\times D}$ is the time embedding and $\mathbf{y}\in\mathbb{R}^{1\times D}$ is the label embedding. We then send the token sequence $\mathbf{z}_{l-1}$ to the $l$-th layer of the DiG encoder and get the output $\mathbf{z}_l$. Finally, we normalize the output token sequence $\mathbf{z}_L$ and feed it to the linear projection head to get the final predicted noise $\hat{\mathbf{p}}_{noise}$ and predicted covariance $\hat{\mathbf{p}}_{covariance}$, as follows:

$$\mathbf{z}_l = \mathbf{DiG}_l(\mathbf{z}_{l-1}, \mathbf{t}, \mathbf{y}), \quad \mathbf{z}_n = \mathbf{Norm}(\mathbf{z}_L), \quad \hat{\mathbf{p}}_{noise}, \hat{\mathbf{p}}_{covariance} = \mathbf{Linear}(\mathbf{z}_n), \tag{18}$$

where $\mathbf{DiG}_l$ is the $l$-th diffusion GLA block, $L$ is the number of layers, and $\mathbf{Norm}$ is the normalization layer. The $\hat{\mathbf{p}}_{noise}$ and $\hat{\mathbf{p}}_{covariance}$ have the same shape as the input spatial representation, i.e., $32\times 32\times 4$.
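The tokenization in Eq. (16) can be sketched in a few lines of NumPy. This is a minimal illustration, not the authors' implementation: the function names `patchify` and `embed_tokens` are hypothetical, the projection matrix is a random stand-in for the learned $\mathbf{W}$, and the positional embedding is set to zeros instead of the frequency-based $\mathbf{E}_{pos}$.

```python
import numpy as np


def patchify(z, patch_size):
    """Split a latent of shape (H, W, C) into T = (H/P) * (W/P) flattened patches."""
    h, w, c = z.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "latent size must be divisible by the patch size"
    # (H/P, P, W/P, P, C) -> (H/P, W/P, P, P, C) -> (T, P*P*C)
    patches = z.reshape(h // p, p, w // p, p, c).transpose(0, 2, 1, 3, 4)
    return patches.reshape((h // p) * (w // p), p * p * c)


def embed_tokens(z, patch_size, weight, pos_embed):
    """Eq. (16): project every patch with W, then add positional embeddings."""
    return patchify(z, patch_size) @ weight + pos_embed


rng = np.random.default_rng(0)
latent = rng.standard_normal((32, 32, 4))   # VAE latent of a 256x256x3 image
hidden = 384                                 # D of DiG-S (Table 1)

for p in (4, 2):
    t = (32 // p) ** 2
    w = rng.standard_normal((p * p * 4, hidden)) * 0.02   # stand-in for the learned W
    e_pos = np.zeros((t, hidden))                          # stand-in for E_pos
    z0 = embed_tokens(latent, p, w, e_pos)
    print(f"P={p}: T={t}, z0 shape={z0.shape}")
```

Going from $P=4$ to $P=2$ raises the sequence length from 64 to 256 tokens, which is the "halving $P$ quadruples $T$" relation stated above.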

Figure 4: The overview of the proposed DiG model. The figure presents the whole Latent DiG, DiG block, details of spatial reorient & enhancement module (SREM), and layer-wise DiG scanning directions controlled by the SREM. We mark the scanning order and indices on each patch.
Input: token sequence $\mathbf{z}_{l-1}$ : $(\mathtt{B},\mathtt{T},\mathtt{D})$, timestep embed $\mathbf{t}$ : $(\mathtt{B},1,\mathtt{D})$, label embed $\mathbf{y}$ : $(\mathtt{B},1,\mathtt{D})$
Output: token sequence $\mathbf{z}_l$ : $(\mathtt{B},\mathtt{T},\mathtt{D})$
1 $\alpha_1, \beta_1, \gamma_1, \alpha_2, \beta_2, \gamma_2$ : $(\mathtt{B},1,\mathtt{D})$ $\leftarrow$ $\mathbf{MLP}(\mathbf{t}+\mathbf{y})$ // regress parameters of adaLN
2 $\mathbf{z}_{l-1}'$ : $(\mathtt{B},\mathtt{T},\mathtt{D})$ $\leftarrow$ $\mathbf{z}_{l-1} + \alpha_1 \odot \mathbf{GLA}(\mathbf{Norm}(\mathbf{z}_{l-1}) \odot (1+\gamma_1) + \beta_1)$
3 $\mathbf{z}_{l-1}''$ : $(\mathtt{B},\mathtt{T},\mathtt{D})$ $\leftarrow$ $\mathbf{z}_{l-1}' + \alpha_2 \odot \mathbf{FFN}(\mathbf{Norm}(\mathbf{z}_{l-1}') \odot (1+\gamma_2) + \beta_2)$
4 $\mathbf{z}_{l-1}''$ : $(\mathtt{B},\sqrt{\mathtt{T}},\sqrt{\mathtt{T}},\mathtt{D})$ $\leftarrow$ $\mathbf{DWConv2d}(\mathbf{reshape2d}(\mathbf{z}_{l-1}''))$ // lightweight spatial modeling
5 if $l$ % 2 == 0 then
6       $\mathbf{z}_{l-1}''$ : $(\mathtt{B},\sqrt{\mathtt{T}},\sqrt{\mathtt{T}},\mathtt{D})$ $\leftarrow$ $\mathbf{transpose}(\mathbf{z}_{l-1}'')$ // transpose the token matrix every two blocks
7 end if
8 $\mathbf{z}_l$ : $(\mathtt{B},\mathtt{T},\mathtt{D})$ $\leftarrow$ $\mathbf{flip}(\mathbf{flatten}(\mathbf{z}_{l-1}''))$ // flip the token sequence at the end of each block
Return: $\mathbf{z}_l$
Algorithm 1 DiG Block Process.
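To see which scanning directions the transpose and flip steps at the end of Algorithm 1 produce, the following sketch tracks patch indices through several blocks. It is an illustration under stated assumptions: the helper names `reorient` and `scan_orders` are hypothetical, the layer index is taken to be 1-based, the batch dimension is dropped, and the GLA, FFN, and DWConv2d steps are omitted because they do not reorder tokens.

```python
import numpy as np


def reorient(tokens, layer_index):
    """Transpose-and-flip tail of the DiG block for a (T, D) token sequence."""
    t, d = tokens.shape
    side = int(round(t ** 0.5))
    assert side * side == t, "T must be a perfect square"
    grid = tokens.reshape(side, side, d)
    if layer_index % 2 == 0:
        grid = grid.transpose(1, 0, 2)   # swap rows and columns every two blocks
    return grid.reshape(t, d)[::-1]       # flatten, then reverse the sequence


def scan_orders(side, num_layers):
    """Return, for each layer, the patch ids in the order that layer processes them."""
    tokens = np.arange(side * side).reshape(-1, 1)   # the patch id stands in for features
    orders = []
    for layer_index in range(1, num_layers + 1):
        orders.append(tokens[:, 0].tolist())
        tokens = reorient(tokens, layer_index)
    return orders


for layer_index, order in enumerate(scan_orders(side=3, num_layers=5), start=1):
    print(layer_index, order)
```

On a $3\times 3$ grid with row-major patch ids, the first four layers see row-major forward, row-major backward, column-major forward, and column-major backward orders, and the fifth layer returns to the first order. Under these assumptions, a stack of causal blocks therefore cycles through four scanning directions.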

3.3 DiG Block

The original GLA block processes the input sequence in a recurrent form, which only enables causal modeling of 1-D sequences. In this section, we introduce the DiG block, which incorporates a spatial reorient & enhancement module (SREM) that enables lightweight spatial recognition and controls layer-wise scanning directions. The DiG block is shown in Fig. 4.

Specifically, we present the forward process of the DiG block in Algo. 1. Following the widespread usage of adaptive normalization layers [41] in GANs [2, 24] and diffusion models [13, 39], we add and normalize the input timestep embedding $\mathbf{t}$ and label embedding $\mathbf{y}$ to regress the scale parameters $\alpha$, $\gamma$ and the shift parameter $\beta$. Next, we apply gated linear attention (GLA) and the feedforward network (FFN), each modulated by the regressed adaptive layer norm (adaLN) parameters. Then, we reshape the sequence to 2D and apply a lightweight $3\times 3$ depth-wise convolution (DWConv2d) layer to perceive local spatial information. However, conventional initialization of DWConv2d leads to slow convergence because the kernel weights are spread across all positions. To address this problem, we propose identity initialization, which sets the center of the convolutional kernel to 1 and all surrounding weights to 0. Last, we transpose the 2D token matrix every two blocks and flip the flattened sequence to control the scanning direction of the next block. As shown in the right part of Fig. 4, each layer scans in only one direction.
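The effect of identity initialization can be checked directly: a depth-wise kernel whose center is 1 and whose other weights are 0 leaves the feature map unchanged, so the block initially behaves as if the convolution were absent. The sketch below is a plain NumPy illustration with hypothetical helper names; it assumes stride 1, zero "same" padding, and no bias, none of which is specified in the text.

```python
import numpy as np


def identity_dw_kernel(channels, k=3):
    """Per-channel k x k kernel with 1 at the center and 0 elsewhere."""
    kernel = np.zeros((channels, k, k))
    kernel[:, k // 2, k // 2] = 1.0
    return kernel


def depthwise_conv2d(x, kernel):
    """Depth-wise cross-correlation. x: (H, W, C), kernel: (C, k, k)."""
    h, w, _ = x.shape
    k = kernel.shape[-1]
    pad = k // 2
    xp = np.pad(x, ((pad, pad), (pad, pad), (0, 0)))   # zero padding on H and W only
    out = np.zeros_like(x)
    for i in range(k):
        for j in range(k):
            # Each channel is filtered with its own kernel, with no channel mixing.
            out += xp[i:i + h, j:j + w, :] * kernel[:, i, j]
    return out


rng = np.random.default_rng(0)
feature_map = rng.standard_normal((16, 16, 8))
out = depthwise_conv2d(feature_map, identity_dw_kernel(8))
print(np.allclose(out, feature_map))   # True: the layer starts as an identity map
```

Training then only has to learn how far to move the off-center weights away from zero, rather than first undoing a random spatial mixing.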

3.4 Architecture Details

Table 1: Details of DiG models. We follow DiT [39] model configurations for the Small (S), Base (B), Large (L), and XLarge (XL) variants. Given $I=32$, $p=4$.
| Model | Layers $N$ | Hidden Size $D$ | Heads | Parameters (M) | Gflops | $\frac{\text{Gflops}_{\text{DiG}}}{\text{Gflops}_{\text{DiT}}}$ |
| --- | --- | --- | --- | --- | --- | --- |
| DiG-S | 12 | 384 | 6 | 31.5 | 1.09 | 77.9% |
| DiG-B | 12 | 768 | 12 | 124.6 | 4.31 | 77.0% |
| DiG-L | 24 | 1024 | 16 | 443.4 | 15.54 | 78.9% |
| DiG-XL | 28 | 1152 | 16 | 644.6 | 22.53 | 77.4% |

We use a total of $N$ DiG blocks, each operating at hidden dimension $D$. Following previous works [39, 14, 59], we use standard transformer configurations that scale $N$, $D$, and the number of attention heads. Specifically, we provide four configurations, DiG-S, DiG-B, DiG-L, and DiG-XL, as shown in Tab. 1. They cover a wide range of parameter counts and flop allocations, from 31.5 M to 644.6 M parameters and from 1.09 to 22.53 Gflops, providing a way to gauge scaling performance and efficiency. Notably, DiG consumes only 77.0% to 78.9% of the Gflops of the same-size baseline models, i.e., DiTs.
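For readers who want the four configurations in machine-readable form, the sketch below transcribes Table 1 into a small Python structure and derives two quantities from it: the per-head width $D/\text{heads}$ and the DiT Gflops implied by the reported ratio. The class and field names are hypothetical, and the implied DiT figures are back-computed from rounded table entries, so they are only approximate.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DiGConfig:
    layers: int          # N
    hidden: int          # D
    heads: int
    params_m: float      # parameters, in millions
    gflops: float        # at I=32, p=4
    ratio_to_dit: float  # Gflops_DiG / Gflops_DiT, as reported in Table 1


CONFIGS = {
    "DiG-S": DiGConfig(12, 384, 6, 31.5, 1.09, 0.779),
    "DiG-B": DiGConfig(12, 768, 12, 124.6, 4.31, 0.770),
    "DiG-L": DiGConfig(24, 1024, 16, 443.4, 15.54, 0.789),
    "DiG-XL": DiGConfig(28, 1152, 16, 644.6, 22.53, 0.774),
}

for name, cfg in CONFIGS.items():
    head_dim = cfg.hidden // cfg.heads
    implied_dit_gflops = cfg.gflops / cfg.ratio_to_dit   # approximate, from rounded values
    print(f"{name}: head_dim={head_dim}, implied DiT Gflops ~ {implied_dit_gflops:.2f}")
```

The S, B, and L variants use 64 channels per head and XL uses 72, which follows directly from the hidden sizes and head counts in the table.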

3.5 Efficiency Analysis

A GPU contains two important memory components: high bandwidth memory (HBM) and SRAM. HBM has a larger capacity, while SRAM has a higher bandwidth. To make full use of SRAM and model sequences in a parallel form, we follow GLA and split the whole sequence into chunks whose computation can be completed in SRAM. Denoting the chunk size as $M$, the training complexity is $O(\frac{T}{M}(M^2D+MD^2))=O(TMD+TD^2)$, which is less than the $O(T^2D)$ complexity of traditional attention when $T>D$. Furthermore, the lightweight DWConv2d and efficient matrix operations in the DiG block also contribute to the efficiency shown in Fig. 1 and Fig. 2.
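A quick numeric comparison makes the complexity argument concrete. The sketch below evaluates the two cost expressions from the text without constants, for token counts that correspond to several image resolutions. The chunk size $M=64$ is an assumed value for illustration; the resolution-to-token mapping assumes the $8\times$ VAE downsampling implied by the $256\to 32$ example and a patch size of 2; projections and the FFN are ignored.

```python
def chunked_gla_cost(t, d, m):
    """(T / M) chunks, each costing M^2 * D + M * D^2, as in the text."""
    assert t % m == 0, "this sketch assumes T is a multiple of the chunk size"
    return (t // m) * (m * m * d + m * d * d)


def softmax_attention_cost(t, d):
    """Quadratic attention cost T^2 * D."""
    return t * t * d


D, M = 1152, 64   # D of DiG-XL; M is an assumed chunk size
for resolution in (256, 512, 1024, 2048):
    t = (resolution // 8 // 2) ** 2   # 8x VAE downsampling, patch size 2
    ratio = softmax_attention_cost(t, D) / chunked_gla_cost(t, D, M)
    print(f"{resolution}px  T={t:6d}  attention / chunked = {ratio:.2f}")
```

Since $TMD+TD^2 < T^2D$ is equivalent to $M+D<T$, the crossover under this count sits at $T=M+D$, which is close to the $T>D$ condition when $M \ll D$. Below that length the quadratic term is the cheaper one, and above it the advantage of the chunked form grows linearly with $T$.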

4 Experiment

4.1 Experimental Settings

Datasets and metrics.

Following previous works [39], we use ImageNet [12] for class-conditional image generation learning at a resolution of $256\times 256$. The ImageNet dataset contains 1,281,167 training images spanning 1,000 classes. We use horizontal flips as data augmentation. We measure generation performance with Fréchet Inception Distance (FID) [37], Inception Score [50], sFID [37], and Precision/Recall [30].

Implementation details.

We use the AdamW optimizer with a constant learning rate of $1\times 10^{-4}$. Following previous work [39], we maintain an exponential moving average (EMA) of the DiG weights during training with a decay rate of 0.9999, and we generate all images with the EMA model. For training on ImageNet, we use an off-the-shelf pretrained variational autoencoder (VAE) [46, 28].
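The EMA rule with decay 0.9999 can be sketched as follows. This is a generic illustration, not the authors' training code: the function name `ema_update` and the dictionary-of-arrays parameter format are assumptions, and the EMA copy is initialized to zeros here only to expose the time constant of the average.

```python
import numpy as np


def ema_update(ema_params, params, decay=0.9999):
    """In-place EMA step: ema <- decay * ema + (1 - decay) * current."""
    for name, value in params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params


params = {"w": np.ones(3)}    # stand-in for the current model weights
ema = {"w": np.zeros(3)}      # zero init, chosen only to expose the time constant

steps = 10_000
for _ in range(steps):
    ema_update(ema, params)

# After n steps toward a constant target, the EMA has covered 1 - decay**n of the gap.
print(ema["w"][0], 1.0 - 0.9999 ** steps)
```

With a decay of 0.9999, roughly 10,000 optimizer steps are needed for the average to cover about 63% of the distance to a fixed target, which is why the EMA weights change slowly and smoothly over training.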

4.2 Model Analysis

Effect of spatial reorient & enhancement module.

Table 2: Ablation of the proposed Spatial Reorient & Enhancement Module (SREM). We validate the effectiveness of each SREM component and use the same hyperparameters for all models. The “half right and half wrong symbol”, shown here as “w/o identity init”, means DWConv2d is used without the proposed identity initialization.
| Model | $\mathtt{Bidirectional}$ | $\mathtt{DWConv2d}$ | $\mathtt{Crisscross}$ | Flops (G) | Params (M) | FID-50K |
| --- | --- | --- | --- | --- | --- | --- |
| Baseline Method. | | | | | | |
| DiT-S/2 | - | - | - | 6.06 | 33.0 | 68.4 |
| Ours. | | | | | | |
| DiG-S/2 | ✗ | ✗ | ✗ | 4.29 | 33.0 | 175.84 |
| DiG-S/2 | ✓ | ✗ | ✗ | 4.29 | 33.0 | 69.28 |
| DiG-S/2 | ✓ | w/o identity init | ✗ | 4.30 | 33.1 | 96.83 |
| DiG-S/2 | ✓ | ✓ | ✗ | 4.30 | 33.1 | 63.84 |
| DiG-S/2 | ✓ | ✓ | ✓ | 4.30 | 33.1 | 62.06 |

As shown in Tab. 2, we analyze the effectiveness of the proposed spatial reorient & enhancement module (SREM), taking DiT-S/2 as our baseline method. The naive DiG with only causal modeling has significantly fewer flops and parameters, but also a poor FID (175.84) due to the lack of global context. Adding bidirectional scanning to DiG brings a significant improvement, to 69.28 FID, which demonstrates the importance of global context. Using DWConv2d without identity initialization, i.e., the row marked with the half right and half wrong symbol, leads to a worse FID (96.83), while DWConv2d with identity initialization improves it to 63.84. These DWConv2d experiments show the importance of both identity initialization and local awareness. The last row shows that the full SREM achieves the best performance (62.06 FID), accounting for both local information and global context.

Figure 5: The scaling analysis with DiG model sizes and patch sizes.

Scaling model size.

We investigate the scaling ability of DiG across four model scales on the ImageNet dataset. As depicted in Fig. 5(a), performance improves as the models scale from S/2 to XL/2. The results demonstrate the scaling ability of DiG, indicating its potential as a large foundational diffusion model.

Effect of patch size.

We train DiG-S with patch sizes of 2, 4, and 8 on the ImageNet dataset. As shown in Fig. 5(b), FID improves discernibly throughout the training process as the patch size decreases. Consequently, optimal performance requires a smaller patch size and therefore a longer sequence length, and DiG handles such long-sequence generation tasks more efficiently than the DiT [39] baseline.

4.3 Main Results

We mainly compare the proposed DiG with our baseline method, DiT [39], under the same hyperparameters. The proposed DiG outperforms DiT at all four model scales with 400K training iterations. Furthermore, DiG-XL/2-1200K with classifier-free guidance also presents competitive results when compared with previous state-of-the-art methods.

Table 3: Benchmarking class-conditional image generation on ImageNet $256\times 256$. DiG models adopt the same hyperparameters as DiT [39] for fair comparison. We mark the best results in bold.
| Model | FID↓ | sFID↓ | IS↑ | Precision↑ | Recall↑ |
| --- | --- | --- | --- | --- | --- |
| Previous state-of-the-art diffusion methods. | | | | | |
| ADM [13] | 10.94 | 6.02 | 100.98 | 0.69 | 0.63 |
| ADM-U | 7.49 | 5.13 | 127.49 | 0.72 | 0.63 |
| ADM-G | 4.59 | 5.25 | 186.70 | 0.82 | 0.52 |
| ADM-G, ADM-U | 3.94 | 6.14 | 215.84 | 0.83 | 0.53 |
| CDM [21] | 4.88 | - | 158.71 | - | - |
| LDM-8 [46] | 15.51 | - | 79.03 | 0.65 | 0.63 |
| LDM-8-G | 7.76 | - | 209.52 | 0.84 | 0.35 |
| LDM-4-G (cfg=1.25) | 3.95 | - | 178.22 | 0.81 | 0.55 |
| LDM-4-G (cfg=1.50) | 3.60 | - | 247.67 | 0.87 | 0.48 |
| Baselines and Ours. | | | | | |
| DiT-S/2-400K [39] | 68.40 | - | - | - | - |
| DiG-S/2-400K | 62.06 | 11.77 | 22.81 | 0.39 | 0.56 |
| DiT-B/2-400K | 43.47 | - | - | - | - |
| DiG-B/2-400K | 39.50 | 8.50 | 37.21 | 0.51 | 0.63 |
| DiT-L/2-400K | 23.33 | - | - | - | - |
| DiG-L/2-400K | 22.90 | 6.91 | 59.87 | 0.60 | 0.64 |
| DiT-XL/2-400K | 19.47 | - | - | - | - |
| DiG-XL/2-400K | 18.53 | 6.06 | 68.53 | 0.63 | 0.64 |
| DiG-XL/2-1200K | 11.96 | 7.39 | 106.65 | 0.65 | 0.67 |
| DiG-XL/2-1200K (cfg=1.5) | 2.84 | 5.47 | 250.36 | 0.82 | 0.56 |
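The cfg entries in Table 3 refer to classifier-free guidance [22]. As a reminder of what the scale controls, the sketch below shows the commonly used combination of conditional and unconditional noise predictions. It assumes the convention in which a scale of 1 means no guidance; the function name `guided_noise` is hypothetical, and the text does not spell out the exact guidance implementation used for DiG.

```python
import numpy as np


def guided_noise(eps_cond, eps_uncond, scale):
    """Classifier-free guidance: move from the unconditional toward the conditional prediction."""
    return eps_uncond + scale * (eps_cond - eps_uncond)


eps_c = np.array([1.0, 2.0])   # toy conditional noise prediction
eps_u = np.array([0.0, 1.0])   # toy unconditional noise prediction

print(guided_noise(eps_c, eps_u, 1.0))   # scale 1: identical to the conditional prediction
print(guided_noise(eps_c, eps_u, 1.5))   # scale 1.5: extrapolates past it
```

A scale above 1 pushes samples further in the class-conditional direction, which typically trades diversity for fidelity; this is consistent with the higher Precision and lower Recall of the guided DiG-XL/2 row compared with the unguided one.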

4.4 Case Study

Figure 6: Image results generated from the proposed DiG-XL/2 model.

Fig. 6 showcases a selection of samples from DiG-XL/2 trained on the ImageNet dataset at a resolution of $256\times 256$. The results demonstrate correct semantics and accurate spatial relationships.

5 Conclusion

In this work, we present DiG, a cost-effective alternative to the vanilla Transformer for diffusion models in image generation tasks. In particular, DiG explores Gated Linear Attention Transformers (GLA), attaining superior efficiency and effectiveness in long-sequence image generation tasks. Experimentally, DiG shows comparable performance to prior diffusion models on class-conditional ImageNet benchmarks while significantly reducing the computational burden. We hope this work can open up the possibility for other long-sequence generation tasks, such as video and audio modeling.

Limitations.

Although DiG shows superior efficiency in diffusion image generation, building a large foundation model like Sora [3] upon DiG is still an area that needs to be explored further.

References

  • [1] Fan Bao, Shen Nie, Kaiwen Xue, Yue Cao, Chongxuan Li, Hang Su, and Jun Zhu. All are worth words: A vit backbone for diffusion models. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 22669–22679, 2023.
  • [2] Andrew Brock, Jeff Donahue, and Karen Simonyan. Large scale gan training for high fidelity natural image synthesis. arXiv preprint arXiv:1809.11096, 2018.
  • [3] Tim Brooks, Bill Peebles, Connor Holmes, Will DePue, Yufei Guo, Li Jing, David Schnurr, Joe Taylor, Troy Luhman, Eric Luhman, Clarence Ng, Ricky Wang, and Aditya Ramesh. Video generation models as world simulators. 2024.
  • [4] Hanqun Cao, Cheng Tan, Zhangyang Gao, Yilun Xu, Guangyong Chen, Pheng-Ann Heng, and Stan Z Li. A survey on generative diffusion models. IEEE Transactions on Knowledge and Data Engineering, 2024.
  • [5] Junsong Chen, Chongjian Ge, Enze Xie, Yue Wu, Lewei Yao, Xiaozhe Ren, Zhongdao Wang, Ping Luo, Huchuan Lu, and Zhenguo Li. Pixart-sigma: Weak-to-strong training of diffusion transformer for 4k text-to-image generation. arXiv preprint arXiv:2403.04692, 2024.
  • [6] Junsong Chen, Jincheng Yu, Chongjian Ge, Lewei Yao, Enze Xie, Yue Wu, Zhongdao Wang, James Kwok, Ping Luo, Huchuan Lu, et al. Pixart-alpha: Fast training of diffusion transformer for photorealistic text-to-image synthesis. arXiv preprint arXiv:2310.00426, 2023.
  • [7] Nanxin Chen, Yu Zhang, Heiga Zen, Ron J Weiss, Mohammad Norouzi, and William Chan. Wavegrad: Estimating gradients for waveform generation. arXiv preprint arXiv:2009.00713, 2020.
  • [8] Jooyoung Choi, Sungwon Kim, Yonghyun Jeong, Youngjune Gwon, and Sungroh Yoon. Ilvr: Conditioning method for denoising diffusion probabilistic models. arXiv preprint arXiv:2108.02938, 2021.
  • [9] François Chollet. Xception: Deep learning with depthwise separable convolutions. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 1251–1258, 2017.
  • [10] Krzysztof Choromanski, Valerii Likhosherstov, David Dohan, Xingyou Song, Andreea Gane, Tamas Sarlos, Peter Hawkins, Jared Davis, Afroz Mohiuddin, Lukasz Kaiser, et al. Rethinking attention with performers. arXiv preprint arXiv:2009.14794, 2020.
  • [11] Tri Dao. Flashattention-2: Faster attention with better parallelism and work partitioning. arXiv preprint arXiv:2307.08691, 2023.
  • [12] Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. Imagenet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pages 248–255. Ieee, 2009.
  • [13] Prafulla Dhariwal and Alexander Nichol. Diffusion models beat gans on image synthesis. Advances in neural information processing systems, 34:8780–8794, 2021.
  • [14] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020.
  • [15] Patrick Esser, Sumith Kulal, Andreas Blattmann, Rahim Entezari, Jonas Müller, Harry Saini, Yam Levi, Dominik Lorenz, Axel Sauer, Frederic Boesel, et al. Scaling rectified flow transformers for high-resolution image synthesis. arXiv preprint arXiv:2403.03206, 2024.
  • [16] Zhengcong Fei, Mingyuan Fan, Changqian Yu, and Junshi Huang. Scalable diffusion models with state space backbone. arXiv preprint arXiv:2402.05608, 2024.
  • [17] Albert Gu and Tri Dao. Mamba: Linear-time sequence modeling with selective state spaces. arXiv preprint arXiv:2312.00752, 2023.
  • [18] Shuyang Gu, Dong Chen, Jianmin Bao, Fang Wen, Bo Zhang, Dongdong Chen, Lu Yuan, and Baining Guo. Vector quantized diffusion model for text-to-image synthesis. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 10696–10706, 2022.
  • [19] Jonathan Ho, William Chan, Chitwan Saharia, Jay Whang, Ruiqi Gao, Alexey Gritsenko, Diederik P Kingma, Ben Poole, Mohammad Norouzi, David J Fleet, et al. Imagen video: High definition video generation with diffusion models. arXiv preprint arXiv:2210.02303, 2022.
  • [20] Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. Advances in neural information processing systems, 33:6840–6851, 2020.
  • [21] Jonathan Ho, Chitwan Saharia, William Chan, David J Fleet, Mohammad Norouzi, and Tim Salimans. Cascaded diffusion models for high fidelity image generation. Journal of Machine Learning Research, 23(47):1–33, 2022.
  • [22] Jonathan Ho and Tim Salimans. Classifier-free diffusion guidance. arXiv preprint arXiv:2207.12598, 2022.
  • [23] Vincent Tao Hu, Stefan Andreas Baumann, Ming Gui, Olga Grebenkova, Pingchuan Ma, Johannes Fischer, and Bjorn Ommer. Zigma: Zigzag mamba diffusion model. arXiv preprint arXiv:2403.13802, 2024.
  • [24] Tero Karras, Samuli Laine, and Timo Aila. A style-based generator architecture for generative adversarial networks. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 4401–4410, 2019.
  • [25] Jungo Kasai, Hao Peng, Yizhe Zhang, Dani Yogatama, Gabriel Ilharco, Nikolaos Pappas, Yi Mao, Weizhu Chen, and Noah A Smith. Finetuning pretrained transformers into rnns. arXiv preprint arXiv:2103.13076, 2021.
  • [26] Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, and François Fleuret. Transformers are rnns: Fast autoregressive transformers with linear attention. In International conference on machine learning, pages 5156–5165. PMLR, 2020.
  • [27] Diederik P Kingma and Max Welling. Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114, 2013.
  • [28] Diederik P Kingma and Max Welling. Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114, 2013.
  • [29] Zhifeng Kong, Wei Ping, Jiaji Huang, Kexin Zhao, and Bryan Catanzaro. Diffwave: A versatile diffusion model for audio synthesis. arXiv preprint arXiv:2009.09761, 2020.
  • [30] Tuomas Kynkäänniemi, Tero Karras, Samuli Laine, Jaakko Lehtinen, and Timo Aila. Improved precision and recall metric for assessing generative models. Advances in neural information processing systems, 32, 2019.
  • [31] Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
  • [32] Luping Liu, Yi Ren, Zhijie Lin, and Zhou Zhao. Pseudo numerical methods for diffusion models on manifolds. arXiv preprint arXiv:2202.09778, 2022.
  • [33] Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu. Dpm-solver: A fast ode solver for diffusion probabilistic model sampling in around 10 steps. Advances in Neural Information Processing Systems, 35:5775–5787, 2022.
  • [34] Xin Ma, Yaohui Wang, Gengyun Jia, Xinyuan Chen, Ziwei Liu, Yuan-Fang Li, Cunjian Chen, and Yu Qiao. Latte: Latent diffusion transformer for video generation. arXiv preprint arXiv:2401.03048, 2024.
  • [35] Huanru Henry Mao. Fine-tuning pre-trained transformers into decaying fast weights. arXiv preprint arXiv:2210.04243, 2022.
  • [36] Kangfu Mei and Vishal Patel. Vidm: Video implicit diffusion models. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 37, pages 9117–9125, 2023.
  • [37] Charlie Nash, Jacob Menick, Sander Dieleman, and Peter W Battaglia. Generating images with sparse representations. arXiv preprint arXiv:2103.03841, 2021.
  • [38] Alexander Quinn Nichol and Prafulla Dhariwal. Improved denoising diffusion probabilistic models. In International conference on machine learning, pages 8162–8171. PMLR, 2021.
  • [39] William Peebles and Saining Xie. Scalable diffusion models with transformers. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 4195–4205, 2023.
  • [40] Bo Peng, Eric Alcaide, Quentin Anthony, Alon Albalak, Samuel Arcadinho, Huanqi Cao, Xin Cheng, Michael Chung, Matteo Grella, Kranthi Kiran GV, et al. Rwkv: Reinventing rnns for the transformer era. arXiv preprint arXiv:2305.13048, 2023.
  • [41] Ethan Perez, Florian Strub, Harm De Vries, Vincent Dumoulin, and Aaron Courville. Film: Visual reasoning with a general conditioning layer. In Proceedings of the AAAI conference on artificial intelligence, volume 32, 2018.
  • [42] Ben Poole, Ajay Jain, Jonathan T Barron, and Ben Mildenhall. Dreamfusion: Text-to-3d using 2d diffusion. arXiv preprint arXiv:2209.14988, 2022.
  • [43] Zhen Qin, Dong Li, Weigao Sun, Weixuan Sun, Xuyang Shen, Xiaodong Han, Yunshen Wei, Baohong Lv, Fei Yuan, Xiao Luo, et al. Scaling transnormer to 175 billion parameters. arXiv preprint arXiv:2307.14995, 2023.
  • [44] Prajit Ramachandran, Barret Zoph, and Quoc V. Le. Searching for activation functions, 2017.
  • [45] Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, and Mark Chen. Hierarchical text-conditional image generation with clip latents. arXiv preprint arXiv:2204.06125, 1(2):3, 2022.
  • [46] Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, and Björn Ommer. High-resolution image synthesis with latent diffusion models. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 10684–10695, 2022.
  • [47] Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, and Björn Ommer. High-resolution image synthesis with latent diffusion models. In Proceedings of the IEEE/CVF conference on computer vision and pattern recognition, pages 10684–10695, 2022.
  • [48] Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-net: Convolutional networks for biomedical image segmentation. In Medical image computing and computer-assisted intervention–MICCAI 2015: 18th international conference, Munich, Germany, October 5-9, 2015, proceedings, part III 18, pages 234–241. Springer, 2015.
  • [49] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily L Denton, Kamyar Ghasemipour, Raphael Gontijo Lopes, Burcu Karagol Ayan, Tim Salimans, et al. Photorealistic text-to-image diffusion models with deep language understanding. Advances in neural information processing systems, 35:36479–36494, 2022.
  • [50] Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, and Xi Chen. Improved techniques for training gans. Advances in neural information processing systems, 29, 2016.
  • [51] Jascha Sohl-Dickstein, Eric Weiss, Niru Maheswaranathan, and Surya Ganguli. Deep unsupervised learning using nonequilibrium thermodynamics. In International conference on machine learning, pages 2256–2265. PMLR, 2015.
  • [52] Jiaming Song, Chenlin Meng, and Stefano Ermon. Denoising diffusion implicit models. arXiv preprint arXiv:2010.02502, 2020.
  • [53] Yang Song and Stefano Ermon. Generative modeling by estimating gradients of the data distribution. Advances in neural information processing systems, 32, 2019.
  • [54] Yang Song, Jascha Sohl-Dickstein, Diederik P Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456, 2020.
  • [55] Yutao Sun, Li Dong, Shaohan Huang, Shuming Ma, Yuqing Xia, Jilong Xue, Jianyong Wang, and Furu Wei. Retentive network: A successor to transformer for large language models. arXiv preprint arXiv:2307.08621, 2023.
  • [56] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  • [57] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems, 30, 2017.
  • [58] Jing Nathan Yan, Jiatao Gu, and Alexander M Rush. Diffusion models without attention. arXiv preprint arXiv:2311.18257, 2023.
  • [59] Songlin Yang, Bailin Wang, Yikang Shen, Rameswar Panda, and Yoon Kim. Gated linear attention transformers with hardware-efficient training. arXiv preprint arXiv:2312.06635, 2023.
  • [60] Xiulong Yang, Sheng-Min Shih, Yinlin Fu, Xiaoting Zhao, and Shihao Ji. Your vit is secretly a hybrid discriminative-generative diffusion model. arXiv preprint arXiv:2208.07791, 2022.
  • [61] Taoran Yi, Jiemin Fang, Guanjun Wu, Lingxi Xie, Xiaopeng Zhang, Wenyu Liu, Qi Tian, and Xinggang Wang. Gaussiandreamer: Fast generation from text to 3d gaussian splatting with point cloud priors. arXiv preprint arXiv:2310.08529, 2023.
  • [62] Lvmin Zhang, Anyi Rao, and Maneesh Agrawala. Adding conditional control to text-to-image diffusion models. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pages 3836–3847, 2023.
  • [63] Min Zhao, Fan Bao, Chongxuan Li, and Jun Zhu. Egsde: Unpaired image-to-image translation via energy-guided stochastic differential equations. Advances in Neural Information Processing Systems, 35:3609–3623, 2022.
  • [64] Zhizhuo Zhou and Shubham Tulsiani. Sparsefusion: Distilling view-conditioned diffusion for 3d reconstruction. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pages 12588–12597, 2023.