
ADD: Diffusion Model With Adversarial Learning And Knowledge Distillation

Image Generation

3 main points
✔️ Diffusion models are very slow at inference, making them difficult to use in real time
✔️ The proposed ADD significantly increases inference speed while maintaining quality
✔️ ADD is the first single-step diffusion model using adversarial learning and knowledge distillation

Adversarial Diffusion Distillation
written by Axel Sauer, Dominik Lorenz, Andreas Blattmann, Robin Rombach
(Submitted on 28 Nov 2023)
Comments: Published on arXiv.

Subjects: Computer Vision and Pattern Recognition (cs.CV)


The images used in this article are from the paper, the introductory slides, or were created based on them.

Introduction

Diffusion models have received a great deal of attention as generative models and have recently made remarkable advances in the generation of high-quality images and videos. Two of the strengths of diffusion models are high image quality and high diversity. However, generating an image requires hundreds to thousands of sampling steps, so inference is very slow.

Generative adversarial networks (GANs), on the other hand, are characterized by single-step formulations and fast sampling. However, despite attempts to extend them to large data sets, GANs often fall short of diffusion models in sample quality. Another weakness is the low diversity of the generated images.

The goal of the paper introduced here is to combine the excellent sample quality of diffusion models with the fast sampling of GANs. To achieve this, the authors introduce a combination of two training objectives:

  • Adversarial loss
  • Distillation loss corresponding to score distillation sampling (SDS)

The adversarial loss avoids the blurring and other artifacts that often occur with other distillation methods by comparing real and generated images through a discriminator. The distillation loss uses another pre-trained (and frozen) diffusion model as a teacher, effectively leveraging the extensive knowledge of pre-trained diffusion models.

The proposed method outperforms SDXL, the state-of-the-art diffusion model, generating high-fidelity images in real time with only one to four sampling steps.

Proposed Method

Figure 1: Overview of the proposed method

Training Procedures

The training procedure is shown in Figure 1. The setup consists of three models: the main model, ADD-student, a diffusion model (UNet-DM) initialized from pre-trained weights \( \theta \); a discriminator with trainable weights \( \phi \); and the DM-Teacher, a diffusion model with frozen weights \( \psi \).

For the adversarial loss, the generated sample \( \hat{x}_\theta \) and a real image \( x_0 \) are passed to the discriminator, which tries to distinguish them. Details of the discriminator design and the adversarial loss are given in the next section. To distill knowledge from the DM-Teacher, the ADD-student's sample \( \hat{x}_\theta \) is diffused with the teacher's forward process to obtain \( \hat{x}_{\theta,t} \), and the distillation loss \( L_{\mathrm{distill}} \) uses the teacher's denoising prediction \( \hat{x}_\psi(\hat{x}_{\theta,t}, t) \) as the reconstruction target. Details are also given below.
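To make the flow of one training step concrete, here is a minimal PyTorch sketch written from the description above and the losses detailed below. TinyDenoiser, TinyHead, the crude noising, and the constant `lam` are stand-ins of this write-up (assumptions), not the paper's SDXL UNet, ViT feature network, or hyperparameters; only the structure (student sample → discriminator, re-noised student sample → frozen teacher target) follows Figure 1.

```python
# Schematic sketch of one ADD training step (toy models, dummy data).
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyDenoiser(nn.Module):
    """Stand-in for the UNet-based student/teacher: predicts a clean image."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.SiLU(),
                                 nn.Conv2d(16, 3, 3, padding=1))

    def forward(self, x_noisy, t):            # t is ignored in this toy model
        return self.net(x_noisy)              # \hat{x}(x_noisy, t)

class TinyHead(nn.Module):
    """Stand-in for one lightweight discriminator head."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.SiLU(),
                                 nn.Flatten(), nn.LazyLinear(1))

    def forward(self, x):
        return self.net(x)

student = TinyDenoiser()                          # trainable weights theta
teacher = TinyDenoiser().requires_grad_(False)    # frozen weights psi
head = TinyHead()                                 # trainable weights phi
opt_g = torch.optim.Adam(student.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(head.parameters(), lr=1e-4)
lam = 2.5                                         # assumed weight on the distillation loss

x0 = torch.rand(4, 3, 32, 32)                     # real images (dummy data)

# Student step: one-step generation from a noised input.
s = torch.randint(0, 1000, (4,))
x_s = x0 + 0.5 * torch.randn_like(x0)             # crude stand-in for forward diffusion
x_hat = student(x_s, s)                           # ADD-student sample \hat{x}_theta

# Adversarial (generator) part: try to fool the discriminator head.
g_adv = -head(x_hat).mean()

# Distillation part: re-noise the student sample, let the frozen teacher
# denoise it, and pull the student sample toward the teacher's prediction.
t = torch.randint(0, 1000, (4,))
x_hat_t = x_hat + 0.5 * torch.randn_like(x_hat)
with torch.no_grad():
    target = teacher(x_hat_t, t)                  # \hat{x}_psi(\hat{x}_{theta,t}, t)
distill = F.mse_loss(x_hat, target)

opt_g.zero_grad()
(g_adv + lam * distill).backward()
opt_g.step()

# Discriminator step: hinge loss on real vs. generated samples.
d_real = head(x0)
d_fake = head(x_hat.detach())
d_loss = F.relu(1 - d_real).mean() + F.relu(1 + d_fake).mean()
opt_d.zero_grad()
d_loss.backward()
opt_d.step()
```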

The overall loss function is then as follows.
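Reconstructed from this description (the weighting factor \( \lambda \) on the distillation term is an assumption of this write-up, not stated above):

\[
L = L_{\mathrm{adv}}^{G}\bigl(\hat{x}_\theta, \phi\bigr) + \lambda\, L_{\mathrm{distill}}\bigl(\hat{x}_\theta, \psi\bigr)
\]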

Adversarial Loss and Discriminator

For the discriminator, we follow the structure and setup of StyleGAN-T (Sauer et al., 2023): a fixed, pre-trained feature network \( F \) and a set of trainable, lightweight discriminator heads \( D_{\phi,k} \). Since Sauer et al. found that vision transformers (ViTs) work well as the feature network \( F \), different choices of ViT and model size are examined in the experiments. The trainable discriminator heads are applied to features \( F_k \) taken from different layers of the feature network.
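As an illustration of this layout, the hypothetical sketch below attaches one small trainable head per feature level of a frozen backbone; the toy ConvNet backbone is a stand-in for the ViT feature network used in the paper.

```python
# Frozen feature network F with one lightweight head D_{phi,k} per feature level.
import torch
import torch.nn as nn

class FrozenFeatureNet(nn.Module):
    """Fixed, pre-trained feature extractor F (weights frozen)."""
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([
            nn.Sequential(nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.SiLU()),
            nn.Sequential(nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.SiLU()),
        ])
        self.requires_grad_(False)

    def forward(self, x):
        feats = []
        for block in self.blocks:
            x = block(x)
            feats.append(x)                       # F_k(x): features at layer k
        return feats

class MultiHeadDiscriminator(nn.Module):
    """One lightweight trainable head D_{phi,k} per feature level."""
    def __init__(self, channels=(16, 32)):
        super().__init__()
        self.heads = nn.ModuleList(
            nn.Sequential(nn.Conv2d(c, 8, 1), nn.SiLU(),
                          nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 1))
            for c in channels)

    def forward(self, feats):
        return [head(f) for head, f in zip(self.heads, feats)]

F_net, D = FrozenFeatureNet(), MultiHeadDiscriminator()
logits = D(F_net(torch.rand(2, 3, 64, 64)))       # one logit tensor per head
```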

The discriminator loss \( L_{\mathrm{adv}}^{D} \) and the ADD-student (generator) loss \( L_{\mathrm{adv}}^{G} \) are as follows.
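A sketch of these losses in the hinge form used in the StyleGAN-T setup, reconstructed from the description (the exact form and the weight \( \gamma \) on the \( R1 \) term are assumptions of this write-up):

\[
L_{\mathrm{adv}}^{D} = \mathbb{E}_{x_0}\Bigl[\sum_k \max\bigl(0,\, 1 - D_{\phi,k}(F_k(x_0))\bigr) + \gamma\, R1(\phi)\Bigr] + \mathbb{E}_{\hat{x}_\theta}\Bigl[\sum_k \max\bigl(0,\, 1 + D_{\phi,k}(F_k(\hat{x}_\theta))\bigr)\Bigr]
\]

\[
L_{\mathrm{adv}}^{G} = -\,\mathbb{E}_{\hat{x}_\theta}\Bigl[\sum_k D_{\phi,k}\bigl(F_k(\hat{x}_\theta)\bigr)\Bigr]
\]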

where \( R1 \) denotes the \( R1 \) gradient penalty. Instead of computing the gradient penalty on the pixel values, we compute it on the feature inputs to each discriminator head \( D_{\phi,k} \). The \( R1 \) penalty is particularly useful when the output resolution is larger than \( 128 \times 128 \) pixels.
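As a concrete illustration, a hypothetical PyTorch snippet of an \( R1 \) penalty taken with respect to the feature input of a head (the `head` module here is a toy stand-in):

```python
# R1 gradient penalty w.r.t. the feature input of a discriminator head.
import torch
import torch.nn as nn

def r1_penalty(head: nn.Module, features: torch.Tensor) -> torch.Tensor:
    """Mean squared norm of the gradient of the head's output w.r.t. its input."""
    features = features.detach().requires_grad_(True)
    out = head(features).sum()
    (grad,) = torch.autograd.grad(out, features, create_graph=True)
    return grad.pow(2).flatten(1).sum(dim=1).mean()

# Toy usage: a 1x1-conv head on 16-channel feature maps.
head = nn.Conv2d(16, 1, 1)
penalty = r1_penalty(head, torch.rand(4, 16, 8, 8))
```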

Score distillation loss

The score distillation loss is as follows.
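A sketch of this loss, following the notation above (the timestep weighting \( c(t) \) is an assumption of this write-up and is not discussed in this article):

\[
L_{\mathrm{distill}} = \mathbb{E}_{t,\,\epsilon}\Bigl[\, c(t)\; d\bigl(\hat{x}_\theta,\; \hat{x}_\psi(\mathrm{sg}(\hat{x}_{\theta,t}),\, t)\bigr)\Bigr]
\]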

\( sg \) indicates a stop-gradient operation. The score distillation loss uses a distance metric \( d \) to measure the difference between the sample \( \hat{x}_\theta \) generated by the ADD-student and the output of the DM-Teacher. To find an appropriate \( d \), many functions were tested in the experiments, and the mean squared error (MSE) proved the most effective.

Experiment

Quantitative Comparison with SOTA Generative Models

Figure 2: User preference study (single step). Results comparing ADD-XL (1 step) with baselines.
Figure 3: User preference study (multiple steps). Results comparing ADD-XL (4 steps) with baselines.

In this experiment, the effectiveness of the proposed method is evaluated with a user preference study rather than the commonly used automatically computed metrics. Users compare pairs of generated results along two criteria: prompt adherence (whether the input prompt is correctly reflected in the output image) and image quality. The results are summarized in Figures 2 and 3. With only a few sampling steps (1-4), the proposed method outperforms representative generative models and achieves SOTA results, especially with 4 steps.

Qualitative Results and Comparisons

Figure 4: Examples of results generated by SDXL and the proposed method

A qualitative comparison between SDXL and the proposed method is shown in Figure 4. The proposed method produces image quality equal to or better than SDXL in only 4 steps. We can also confirm that the input prompt is correctly reflected in the generated results. In particular, as shown in the lower-left image of Figure 4, the proposed method produces less noise and fewer artifacts than SDXL. Together with the quantitative results, this shows that the proposed method outperforms SDXL, the SOTA diffusion model, in both quality and prompt consistency, with fewer sampling steps.

Summary

In this article, we introduced Adversarial Diffusion Distillation (ADD), a method for distilling pre-trained diffusion models into fast, few-step image generation models. The proposed method combines adversarial training and a score distillation loss to distill trained models such as Stable Diffusion and SDXL, leveraging both real data through the discriminator and the structural understanding of the diffusion teacher. The proposed method performs particularly well with ultra-fast sampling of one or two steps, and experimental results show that it outperforms prior work in many cases. Increasing the number of steps further yields even better results, outperforming commonly used multi-step diffusion models such as SDXL, IF, and OpenMUSE. However, there is still room for improvement in single-step generation with respect to image quality and prompt consistency. With further improvement, the proposed method may become the first diffusion model usable in real time.

