
MaskDiT: A Low-Training-Cost Diffusion Model for Image Generation
3 main points
✔️ Diffusion models can generate high quality, highly diverse images, but training them takes a tremendous amount of computational resources and time.
✔️ MaskDiT aims to solve this challenge by using Masked transformers.
✔️ Experiments on the ImageNet dataset show that MaskDiT matches or outperforms the state-of-the-art Diffusion Transformer (DiT) in generation performance while requiring only about 30% of its training time.
Fast Training of Diffusion Models with Masked Transformers
written by Hongkai Zheng, Weili Nie, Arash Vahdat, Anima Anandkumar
(Submitted on 15 Jun 2023)
Comments: Published on arxiv.
Subjects: Computer Vision and Pattern Recognition (cs.CV); Artificial Intelligence (cs.AI); Machine Learning (cs.LG)
code:
The images used in this article are from the paper, the introductory slides, or were created based on them.
Introduction
Diffusion models have become the most popular class of deep generative models due to their excellent image generation performance, especially in synthesizing high-quality and diverse images from text input. Their strong generative performance enables many applications such as style transfer and background generation.
Extensive training is essential for these models to produce realistic images and creative art. However, training them requires large amounts of computational resources and time, which is a major bottleneck to further scaling.
For example, the original Stable Diffusion was trained over 24 days using 256 A100 GPUs. While infrastructure and implementation improvements can reduce the training cost to 13 days on 256 A100 GPUs, this is still an enormous amount of compute for many researchers and developers. Therefore, improving the training efficiency of diffusion models remains an open problem.
The paper introduced in this article applies masking to the Diffusion Transformer (DiT) to address this problem. During training, it randomly masks a high proportion (about 50%) of the patches of the diffused input image. To support masked training, it introduces an asymmetric encoder-decoder architecture, in which the transformer encoder processes only the unmasked patches and a lightweight transformer decoder processes all patches. It also adds an auxiliary task of reconstructing the masked patches to facilitate a long-range understanding of all patches.
Experiments on ImageNet-256×256 and ImageNet-512×512 show that the proposed method achieves competitive and even superior generation performance compared to the state-of-the-art Diffusion Transformer (DiT), while requiring only about 30% of the original training time. The results are shown in Figure 1.
Proposed Method
Figure 1 provides an overview of the proposed method and a comparison with DiT's architecture. This section briefly describes the architectural changes that reduce training cost.

Masking
This corresponds to the lower part of Figure 1-b. Given a clean image $x_0$ and a diffusion time step $t$, a diffused image $x_t$ is obtained by adding Gaussian noise. Next, $x_t$ is divided into $N$ non-overlapping patches of size $p \times p$; for an image of resolution $H \times W$, $N = \frac{HW}{p^2}$. With a fixed masking ratio $r$, $\lfloor rN \rfloor$ patches are randomly removed and the remaining $N - \lfloor rN \rfloor$ unmasked patches are passed to the diffusion model. The same masking ratio $r$ is used at every diffusion timestep.
A high masking ratio greatly increases computational efficiency but may hurt learning, since the model sees less of each image. However, because $x_t$ contains a large amount of redundancy, the model can infer masked patches from their neighbors, which may compensate for the information removed by masking. There may therefore be an optimal balance that achieves both good generation performance and high training efficiency.
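As an illustration of this masking step, the sketch below shows one common way to implement it in PyTorch, following the MAE-style shuffle-and-gather trick. It is a minimal sketch, not the official implementation: the function names, the index bookkeeping, and the tensor shapes in the example are assumptions made for exposition.

```python
import torch

def patchify(x_t, patch_size):
    """Split a diffused image batch (B, C, H, W) into N = HW / p^2 non-overlapping
    p x p patches, each flattened into a token of dimension p*p*C."""
    B, C, H, W = x_t.shape
    p = patch_size
    x = x_t.reshape(B, C, H // p, p, W // p, p)
    x = x.permute(0, 2, 4, 3, 5, 1)                       # (B, H/p, W/p, p, p, C)
    return x.reshape(B, (H // p) * (W // p), p * p * C)   # (B, N, p*p*C)

def random_mask(tokens, mask_ratio):
    """Randomly drop floor(r*N) patches per image, keeping the same ratio r at
    every diffusion timestep. Returns the kept tokens plus the index tensors
    needed later to restore the original patch order."""
    B, N, D = tokens.shape
    num_keep = N - int(mask_ratio * N)
    noise = torch.rand(B, N, device=tokens.device)        # one random score per patch
    ids_shuffle = noise.argsort(dim=1)                     # random permutation per image
    ids_restore = ids_shuffle.argsort(dim=1)               # inverse permutation
    ids_keep = ids_shuffle[:, :num_keep]
    kept = torch.gather(tokens, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
    mask = torch.ones(B, N, device=tokens.device)          # 1 = masked, 0 = kept
    mask[:, :num_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)              # back to original patch order
    return kept, ids_keep, ids_restore, mask

# Example: a batch of two 32x32 latents with 4 channels, patch size 2, masking ratio 0.5
x_t = torch.randn(2, 4, 32, 32)
tokens = patchify(x_t, patch_size=2)                       # (2, 256, 16)
kept, ids_keep, ids_restore, mask = random_mask(tokens, mask_ratio=0.5)  # kept: (2, 128, 16)
```

The index tensors `ids_keep` and `ids_restore` are kept so that, later on, the decoder can place the encoded tokens and the mask tokens back into their original patch positions.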
Asymmetric Encoder-Decoder Backbone
The diffusion backbone of the proposed method is based on DiT, the standard ViT-based architecture for diffusion models, with some modifications; it uses an asymmetric encoder-decoder architecture similar to MAE (He et al., 2022):
- Encoder: Retains the same architecture as the original DiT, but omits the final linear transformation layer and processes only unmasked patches.
- Decoder: This is another DiT architecture adapted from the lightweight MAE decoder, designed to process all tokens as input.
Similar to DiT, the encoder in the proposed method embeds patches using a linear transformation and adds the standard ViT frequency-based position embeddings to all input tokens. Masked tokens are then removed before being passed to the remaining encoder layers.
The decoder receives both the encoded unmasked tokens and new mask tokens as input. Each mask token is a shared learnable vector. Position embeddings are added to all tokens before they are passed to the decoder.
This asymmetric design (e.g., MAE decoder is less than 9% of DiT-XL/2 parameters) allows masking to significantly reduce computational cost per iteration.
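To make the token flow concrete, the following is a simplified PyTorch sketch of this asymmetric backbone. It is an assumption-laden skeleton rather than the actual MaskDiT code: plain transformer layers stand in for DiT blocks, the timestep and class conditioning that DiT applies via adaLN is omitted, the position embeddings are learnable parameters instead of the fixed sin-cos ViT embeddings described above, and all layer sizes are illustrative. `ids_keep` and `ids_restore` are the index tensors produced by the masking step.

```python
import torch
import torch.nn as nn

class AsymmetricMaskedBackbone(nn.Module):
    """A large encoder sees only the kept (unmasked) tokens; a lightweight decoder
    sees all N token positions, with a shared learnable mask token filling in
    the masked positions before decoding."""
    def __init__(self, token_dim, num_patches,
                 enc_dim=1152, enc_depth=28, enc_heads=16,   # illustrative, DiT-XL/2-like sizes
                 dec_dim=512, dec_depth=8, dec_heads=16):    # illustrative, MAE-decoder-like sizes
        super().__init__()
        self.patch_embed = nn.Linear(token_dim, enc_dim)
        # Position embeddings (learnable here for brevity; the paper uses fixed sin-cos ViT embeddings).
        self.enc_pos = nn.Parameter(torch.zeros(1, num_patches, enc_dim))
        self.dec_pos = nn.Parameter(torch.zeros(1, num_patches, dec_dim))
        self.encoder = nn.ModuleList([
            nn.TransformerEncoderLayer(enc_dim, enc_heads, batch_first=True)
            for _ in range(enc_depth)])
        self.enc_to_dec = nn.Linear(enc_dim, dec_dim)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_dim))   # shared learnable vector
        self.decoder = nn.ModuleList([
            nn.TransformerEncoderLayer(dec_dim, dec_heads, batch_first=True)
            for _ in range(dec_depth)])
        self.head = nn.Linear(dec_dim, token_dim)                    # per-patch prediction

    def forward(self, kept_tokens, ids_keep, ids_restore):
        B, num_keep, _ = kept_tokens.shape
        N = ids_restore.shape[1]
        # Encoder: embed only the unmasked tokens and add their position embeddings.
        h = self.patch_embed(kept_tokens)
        pos = torch.gather(self.enc_pos.expand(B, -1, -1), 1,
                           ids_keep.unsqueeze(-1).expand(-1, -1, h.shape[-1]))
        h = h + pos
        for blk in self.encoder:
            h = blk(h)
        # Decoder: append mask tokens, restore the original patch order, add positions.
        h = self.enc_to_dec(h)
        mask_tokens = self.mask_token.expand(B, N - num_keep, -1)
        full = torch.cat([h, mask_tokens], dim=1)
        full = torch.gather(full, 1,
                            ids_restore.unsqueeze(-1).expand(-1, -1, full.shape[-1]))
        full = full + self.dec_pos
        for blk in self.decoder:
            full = blk(full)
        return self.head(full)   # (B, N, token_dim): one prediction per patch

# Tiny sizes for a quick smoke test, reusing the outputs of the masking sketch above.
model = AsymmetricMaskedBackbone(token_dim=16, num_patches=256,
                                 enc_dim=64, enc_depth=2, enc_heads=4,
                                 dec_dim=32, dec_depth=1, dec_heads=4)
pred = model(kept, ids_keep, ids_restore)   # (2, 256, 16)
```

Per iteration, the expensive encoder runs over only roughly half of the tokens at the 50% masking ratio, which is where most of the compute savings come from; the decoder does process all tokens, but it is small enough that this adds little overhead.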
Training Loss
As in standard diffusion models, a denoising score matching loss is used for training, but it is applied only to the unmasked tokens.
In addition, an auxiliary task of reconstructing the masked patches is added to facilitate a long-range understanding of all patches.
The final loss combines the two terms, balanced by a hyperparameter $\lambda$, roughly as follows.
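The notation below is a schematic reconstruction from the description above rather than the paper's exact formula:

$$\mathcal{L}(\theta) \;=\; \mathcal{L}_{\mathrm{DSM}}(\theta) \;+\; \lambda\,\mathcal{L}_{\mathrm{MAE}}(\theta),$$

where $\mathcal{L}_{\mathrm{DSM}}$ is the denoising score matching loss computed on the unmasked patches only, $\mathcal{L}_{\mathrm{MAE}}$ is the MAE-style (mean-squared) reconstruction error on the masked patches, and $\lambda$ weights the auxiliary reconstruction task.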
Experiment
Effectiveness of the Proposed Method in Reducing Training Costs
In this experiment, we compare the training efficiency of MaskDiT, DiT-XL/2, and MDT-XL/2 on eight A100 GPUs in terms of GFLOPs, training speed and memory consumption, and wall-time convergence. The three models are nearly identical in size.
- GFLOPs: As shown in Figure 2, MaskDiT's GFLOPs are significantly lower than those of DiT and MDT. Specifically, MaskDiT's GFLOPs are only 54.0% of DiT's and 31.7% of MDT's. For reference, LDM-8 has GFLOPs comparable to those of MaskDiT, but is inferior in terms of FID.
Figure 2: Comparison of state-of-the-art diffusion model generation performance on ImageNet-256×256
- Training speed and memory consumption: As shown in Figure 3, MaskDiT trains faster and uses less memory than the other models, especially at large batch sizes. For example, at 256×256 resolution with a batch size of 1024, MaskDiT's training speed is 3.5 times that of DiT and 6.5 times that of MDT, and its memory consumption is 45.0% of DiT's and 19.2% of MDT's.
Figure 3: Training speed (steps/second) vs. memory per GPU (GB) for DiT, MDT, and MaskDiT
- Wall-time convergence: MaskDiT converges faster than the other models. For example, on ImageNet-256×256 with a batch size of 1024, MaskDiT reaches an FID of 10 in less than 40 hours, while the other models take more than 160 hours (see Figure 4).
Figure 4: Comparison of FID and training time on ImageNet-256×256
Overall, MaskDiT shows superior training efficiency compared to DiT and MDT in terms of GFLOPs, training speed, memory consumption, and wall-time convergence.
Comparison with Previous Studies
In this experiment, MaskDiT is compared with state-of-the-art class-conditional generative models. The results at 256×256 resolution are obtained after 2 million training steps and those at 512×512 resolution after 1 million training steps; they are summarized in Tables 1 and 2.
ImageNet-256×256
- Without CFG: MaskDiT improves its FID from 6.23 to 5.69 after an additional 37.5k steps, outperforming the other non-cascaded diffusion models; CDM has a better FID but an inferior IS (158.71 vs. 177.99).
- With CFG: MaskDiT-G achieves an FID of 2.28 after 75k steps of fine-tuning, close to DiT-XL/2-G's 2.27. Its training time is 273 hours on 8 A100 GPUs, 31% of DiT-XL/2's 868 hours. Compared to MDT-XL/2-G, MaskDiT-G is inferior in FID and IS, but comparable in Precision/Recall.

ImageNet-512×512
- Without CFG: MaskDiT achieves an FID of 10.79, better than DiT-XL/2's 12.03.
- With CFG: MaskDiT-G achieves an FID of 2.50, outperforming ADM (3.85) and DiT (3.04). The total training cost is about 209 A100 GPU-days, 29% of DiT's 712 A100 GPU-days.

Summary
This article introduced MaskDiT, an efficient method for training diffusion models with masked transformers.
Randomly masking the majority of image patches significantly reduces the training cost of each iteration. To accommodate masked training, the authors introduce an asymmetric encoder-decoder diffusion backbone: the DiT encoder processes only the visible tokens, while a lightweight DiT decoder processes all tokens after mask tokens are injected. An auxiliary task of reconstructing the masked patches is also added to facilitate a long-range understanding of all patches.
Experiments on ImageNet-256×256 and ImageNet-512×512 show that the proposed method achieves competitive and even superior generation performance compared to state-of-the-art diffusion models, while requiring only about 30% of the original training time. It is hoped that MaskDiT will make it easier for more researchers and developers to participate in improving and studying image generation models, especially diffusion models.