[DIFFUSSM] Attention-independent Diffusion Model
3 main points
✔️ Diffusion models offer strong image generation capability but are computationally expensive at high resolutions; current remedies such as patching speed up processing at the expense of image quality
✔️ Since most of the computation is spent on Attention, eliminating the reliance on Attention can save a large amount of computational cost
✔️ DIFFUSSM uses an SSM mechanism instead of Attention to improve computational efficiency while preserving the quality of the generated images
Diffusion Models Without Attention
written by Jing Nathan Yan, Jiatao Gu, Alexander M. Rush
(Submitted on 30 Nov 2023)
Comments: Published on arxiv.
Subjects: Computer Vision and Pattern Recognition (cs.CV)
code:
The images used in this article are from the paper, the introductory slides, or were created based on them.
Introduction
Rapid advances in image generation are being driven by denoising diffusion probabilistic models (DDPMs), which iteratively denoise latent variables to produce high-fidelity samples but face computational challenges when scaled to higher resolutions. In particular, the Self-Attention mechanism is the bottleneck, so existing methods compress the representation to reduce computational cost.
High-resolution architectures typically employ patching and multiscale resolution, both of which degrade spatial information and introduce artifacts. DIFFUSSM eliminates the Self-Attention mechanism and instead uses a state-space model (SSM) to reduce computational complexity. It adopts an hourglass architecture to handle fine-grained representations of the image efficiently; experiments on ImageNet show improved FID, sFID, and Inception Score with fewer Gflops than existing methods.
Proposed Method
The goal of this paper is to design a diffusion architecture that learns long-range interactions at high resolution without requiring "length reduction" such as patching. Like the Transformer-based diffusion model (DiT), this approach flattens the image and treats it as a sequence modeling problem. Unlike Transformers, however, its computation is sub-quadratic in the sequence length.
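To make the scaling difference concrete, here is a back-of-the-envelope comparison (illustrative only, not from the paper): flattening a 64x64 latent without patching gives a sequence of L = 4096 tokens, for which self-attention scales quadratically in L while an FFT-based SSM scales roughly as L log L.

```python
import math

L = 64 * 64                        # flattened 64x64 latent, no patching
attention_pairs = L ** 2           # self-attention: ~16.8M token-pair interactions
ssm_cost = L * math.log2(L)        # FFT-based SSM: ~49K, up to constant factors

print(f"L = {L}")
print(f"attention ~ L^2     = {attention_pairs:,}")
print(f"SSM (FFT) ~ L log L = {ssm_cost:,.0f}")
```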
State Space Models (SSMs)
State Space Models (SSMs) are an architecture for processing discrete-time sequences. An SSM maps an input sequence $u_1, \dots, u_L$ to an output sequence $y_1, \dots, y_L$ through a hidden state, behaving like a linear recurrent neural network (RNN):

$$x_k = \mathbf{A}x_{k-1} + \mathbf{B}u_k, \qquad y_k = \mathbf{C}x_k$$

The main advantage of this formulation over alternative architectures such as Transformers or standard RNNs is that its linear structure lets the recurrence be implemented as a long convolution rather than step-by-step recursion. Specifically, computing $y_k$ from $u_k$ with the FFT has complexity $O(L \log L)$, so the model can be applied to much longer sequences. When processing vector inputs, one can stack $D$ independent SSMs and apply $D$ FFTs in a batch.
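As a rough sketch of this idea (a PyTorch-style illustration, not the authors' code), the SSM output can be computed as an FFT-based long convolution, where `kernel` stands for the SSM's impulse response derived from $(\mathbf{A}, \mathbf{B}, \mathbf{C})$:

```python
import torch

def ssm_fft_conv(u: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """u: (batch, D, L) input sequences; kernel: (D, L) per-channel SSM impulse response."""
    L = u.shape[-1]
    n = 2 * L                                   # zero-pad to avoid circular wrap-around
    U = torch.fft.rfft(u, n=n)                  # batched FFT over all D channels at once
    K = torch.fft.rfft(kernel, n=n)
    y = torch.fft.irfft(U * K, n=n)[..., :L]    # linear convolution, truncated to length L
    return y

# Usage: D independent SSMs applied channel-wise, as described above.
u = torch.randn(2, 8, 1024)        # (batch=2, D=8 channels, L=1024)
kernel = torch.randn(8, 1024)      # placeholder kernel; in practice derived from (A, B, C)
print(ssm_fft_conv(u, kernel).shape)            # torch.Size([2, 8, 1024])
```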
A linear RNN on its own is not an effective sequence model, but discretizing an appropriately parameterized continuous-time state-space model yields a stable and effective approach. The backbone model is S4D, a diagonalized SSM neural network that learns a continuous-time SSM parameterization and obtains the equivalent discrete-time parameters by approximation.
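For illustration, a simplified diagonal-SSM kernel in the spirit of S4D might look like the following (the actual S4D parameterization and initialization of Gu et al. differ; `s4d_kernel` and its arguments are hypothetical names used only for this sketch):

```python
import torch

def s4d_kernel(log_dt, A_real, A_imag, B, C, L):
    """Length-L convolution kernel from diagonal SSM parameters (shapes (D,) and (D, N))."""
    dt = log_dt.exp().unsqueeze(-1)             # (D, 1) positive step sizes
    A = -A_real.exp() + 1j * A_imag             # (D, N) diagonal A with negative real part (stable)
    dtA = dt * A                                # (D, N)
    Abar = torch.exp(dtA)                       # zero-order-hold discretization of A
    Bbar = (Abar - 1) / A * B                   # zero-order-hold discretization of B
    k = torch.arange(L)
    powers = torch.exp(dtA.unsqueeze(-1) * k)   # Abar ** k for k = 0..L-1, shape (D, N, L)
    K = torch.einsum("dn,dnl->dl", C * Bbar, powers)
    return K.real                               # real-valued kernel for real inputs

D, N, L = 8, 16, 1024
K = s4d_kernel(
    log_dt=torch.log(torch.rand(D) * 0.1 + 1e-3),
    A_real=torch.zeros(D, N),                   # real part of A is -exp(0) = -1
    A_imag=torch.arange(N).float().repeat(D, 1),
    B=torch.randn(D, N, dtype=torch.cfloat),
    C=torch.randn(D, N, dtype=torch.cfloat),
    L=L,
)
print(K.shape)  # torch.Size([8, 1024]) -- usable as the `kernel` in the FFT convolution above
```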
DIFFUSSM Block
The central component of DIFFUSSM is a gated bidirectional SSM for efficient processing of long sequences. To improve efficiency further, an hourglass architecture is incorporated within the MLP layers: the sequence is kept at full length around the bidirectional SSM but shortened inside the MLP, alternating between contraction and expansion. The complete model architecture is shown in Figure 1.
Specifically, each hourglass layer receives a shortened, flattened input sequence $I \in \mathbb{R}^{J \times D}$, where $M = L/J$ is the downsampling/upsampling ratio. At the same time, the block containing the bidirectional SSM operates on the sequence at its original length $L$, fully exploiting the global context. With $\sigma$ denoting the activation function, and with $l \in \{1, \dots, L\}$, $j = \lfloor l/M \rfloor$, $m = l \bmod M$, and $D_m = 2D/M$, the block expands the shortened sequence, applies the bidirectional SSM at full length, and contracts it again (the precise equations are given in the paper and in Figure 1).
We integrate this gated SSM block at each layer with skip connections. Furthermore, at each position the class label $y \in \mathbb{R}^{L \times 1}$ and the time step $t \in \mathbb{R}^{L \times 1}$ are combined and injected, as shown in Figure 1.
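The following is a loose structural sketch (not the authors' implementation) of such a block: a placeholder gated bidirectional SSM runs at full sequence length, while the hourglass MLP shortens the sequence by a factor M before its dense layers and expands it back afterwards. Conditioning on $y$ and $t$ and the exact projections of the paper are omitted, and names like `HourglassDiffuSSMBlock` are made up for this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BidirectionalSSMPlaceholder(nn.Module):
    """Stand-in for a gated bidirectional S4D layer (here simply a gated depthwise conv)."""
    def __init__(self, dim):
        super().__init__()
        self.mix = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.gate = nn.Linear(dim, dim)

    def forward(self, x):                          # x: (B, L, D)
        h = self.mix(x.transpose(1, 2)).transpose(1, 2)
        return h * torch.sigmoid(self.gate(x))     # gated output, same shape as x

class HourglassDiffuSSMBlock(nn.Module):
    """Sketch: full-length SSM mixing plus an hourglass MLP shortened by a factor M."""
    def __init__(self, dim, M=2, hidden_mult=4):
        super().__init__()
        self.M = M
        self.norm = nn.LayerNorm(dim)
        self.ssm = BidirectionalSSMPlaceholder(dim)
        # hourglass MLP: M neighbouring positions are merged into one wider token
        self.down = nn.Linear(dim * M, dim * hidden_mult)
        self.up = nn.Linear(dim * hidden_mult, dim * M)

    def forward(self, x):                          # x: (B, L, D), L divisible by M
        B, L, D = x.shape
        h = self.ssm(self.norm(x))                 # global mixing at full length L
        s = h.reshape(B, L // self.M, D * self.M)  # shorten the sequence to J = L / M
        s = self.up(F.gelu(self.down(s)))          # dense layers on the short sequence
        h = s.reshape(B, L, D)                     # expand back to full length
        # NOTE: conditioning on class label y and time step t is omitted in this sketch.
        return x + h                               # skip connection

block = HourglassDiffuSSMBlock(dim=64, M=2)
print(block(torch.randn(2, 1024, 64)).shape)       # torch.Size([2, 1024, 64])
```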
Experiment
Class Conditional Image Generation
In this experiment, the effectiveness of the proposed method is validated on class-conditional image generation with the ImageNet 256x256 and 512x512 datasets. The results are summarized in Table 1. The computational cost (Gflops) is significantly reduced compared to previous studies, especially compared to conventional diffusion models.
On ImageNet 256x256, DIFFUSSM outperformed DiT on several metrics, and on ImageNet 512x512 it achieved competitive results with less training.
In other words, the goal of reducing computational cost while preserving the quality of the generated images has been achieved. Figure 2 shows generated examples for each dataset.
Non-Class Conditional Generation Performance
The results show that DIFFUSSM achieves FID scores comparable to LDM (differences of -0.08 and 0.07) at comparable training cost. This highlights the applicability of DIFFUSSM to different benchmarks and tasks. Like LDM, this approach did not outperform ADM on LSUN-Bedrooms; however, it used only 25% of ADM's total training budget.
Model scalability and hourglass architecture effects
We trained the model under different downsampling settings and evaluated the impact of compressing the latent representation. Results are shown in Figure 3 (right). Comparing the standard model (M = 2) with a model using patching (P = 2), the standard model achieved a better FID score, and the gap widened as training progressed. This suggests that information compression can hurt high-quality image generation.
We also trained DIFFUSSM at three different sizes to evaluate how performance scales. FID-50k computed over the first 400k training steps confirmed that larger models use FLOPs more efficiently and that scaling improves FID at every stage of training. Results are shown in Figure 3 (left).
Summary
In this article, we introduced DIFFUSSM, an Attention-independent diffusion model. This approach can handle long-range hidden states without requiring compression of the representation. As a result, it achieves better performance with fewer Gflops than the DiT model at 256x256 resolution and shows competitive results with less training at higher resolutions.
However, several limitations remain. First, it focuses on unconditional image generation and does not support full text-to-image approaches. Also, recent approaches such as masked image training could improve the model.
Nevertheless, DIFFUSSM offers an alternative approach to learning large-scale diffusion models, and removing the Attention bottleneck opens the door to applications in other areas that require long-range diffusion, such as high-fidelity audio, video, and 3D modeling.