赶上最新的AI论文

MaskDiT:用于图像生成的低学习成本扩散模型

MaskDiT:用于图像生成的低学习成本扩散模型

生成图像

三个要点
✔️ 扩散模型可以生成高质量和高度多样化的图像,但训练这些模型需要大量的计算资源和时间。
✔️ MaskDiT 旨在通过使用屏蔽变换器来解决这一难题。

✔️ 在 ImageNet 数据集上的实验表明,与最先进的扩散变换器(DiT)相比,MaskDiT 实现了更好的生成性能,并减少了约 30% 的训练时间。

Fast Training of Diffusion Models with Masked Transformers
written by Hongkai Zheng, Weili Nie, Arash Vahdat, Anima Anandkumar
(Submitted on 
15 Jun 2023)
Comments: Published on arxiv.

Subjects: Computer Vision and Pattern Recognition (cs.CV); Artificial Intelligence (cs.AI); Machine Learning (cs.LG)

code:

本文所使用的图片要么来自论文、介绍性幻灯片,要么是参考这些图片制作的。

介绍

扩散模型因其出色的图像生成性能而成为最受欢迎的深度生成模型,尤其是在合成高质量、多样化的文本输入图像方面。其强大的生成性能使许多应用成为可能,如风格转换和背景生成。

要使这些模型能够生成逼真的图像和创造性的艺术作品,大量的训练是必不可少的。然而,训练这些模型需要大量的计算资源和时间,这是进一步扩展的主要瓶颈。

例如,最初的 "稳定扩散 "使用 256 个 A100 GPU 训练了 24 天。虽然基础设施和实施方面的改进可以将 256 个 A100 GPU 的训练成本降低到 13 天,但对于许多研究人员和开发人员来说,这仍然是一个巨大的计算资源。因此,提高扩散模型的训练效率仍然是一个有待解决的难题。

在这篇评论文章中,为解决这一问题,在扩散变换器(DiT)中引入了遮罩技术。在训练过程中,很大部分(约 50%)的扩散输入图像会被随机屏蔽。为了进行遮罩训练,我们引入了非对称编码器-解码器架构,其中变换器编码器只处理未遮罩的斑块,而轻量级变换器解码器则处理所有斑块。此外,它还增加了一项辅助任务,以重建已遮蔽的补丁,从而促进对所有补丁的远距离理解。

在 ImageNet-256×256 和 ImageNet-512×512 上的实验表明,与最先进的扩散变换器(DiT)模型相比,所提出的方法实现了具有竞争力甚至更优越的生成性能,而所需的训练时间仅为原来的 30% 左右。结果如下

建议方法

图 1 提供了拟议方法的概览以及与 DiT 架构的比较。本节将简要介绍为降低学习成本而进行的架构改进。

图 1:拟议方法概述及与 DiT 的比较。

遮蔽

这与图 1-b 的下部相对应。给定干净的图像 $x_0$ 和扩散时间步长 $t$,通过添加高斯噪声 $N$,得到扩散图像 $x_t$。然后,将 $x_t$ 分成大小为 $p \times p$ 的 $N$ 非重叠斑块。对于分辨率为 $H ×times W$ 的图像,$N$ 的计算公式为 $N = \frac{HW}{p^2}$。使用固定的遮挡率 $r$,随机移除 $lfloor rN \rfloor$ 补丁,并将剩余的 $N - \lfloor rN \rfloor$ 未遮挡补丁传递给扩散模型。在所有扩散时间步中保持相同的屏蔽率 $r$。

高掩蔽率可以显著提高计算效率,但可能会降低学习效率。然而,由于 $x_t$ 中存在大量冗余,模型从邻近斑块中补充被遮挡斑块的能力可能会弥补遮挡学习的不足。因此,可能存在一个最佳平衡点,既能实现良好的性能,又能达到较高的训练效率。

非对称编码器-解码器骨干网

拟议方法的扩散主干基于 DiT,这是一种基于 ViT 的标准扩散建模架构,但做了一些修改:与 MAE(He 等人,2022 年)一样,它使用非对称编码器-解码器架构:

  • 编码器:保留了与原始 DiT 相同的结构,但省略了最终线性变换层,只处理未屏蔽的补丁。
  • 解码器:这是另一种 DiT 架构,改编自轻量级 MAE 解码器,旨在处理作为输入的所有标记。

与 DiT 类似,拟议方法中的编码器使用线性变换嵌入补丁,并对所有输入标记添加标准的基于 ViT 频率的位置嵌入。屏蔽的标记在传递到其余编码器层之前会被移除。

解码器接收已编码的未掩码标记和新的掩码标记作为输入。每个掩码标记都是一个共享的可学习向量。在将所有标记传递给解码器之前,会对其进行位置嵌入。

由于采用了这种非对称设计(例如,MAE 解码器的参数小于 DiT-XL/2 的 9%),屏蔽可以显著降低每次迭代的计算成本。

学习损失

学习使用去噪分数匹配损失,就像通常的扩散模型一样,但只针对未屏蔽的标记。

它还增加了重建屏蔽补丁的任务,以便于对所有补丁进行长期了解

最终损失由 λ超参数平衡,具体如下。

试验

拟议方法在降低学习成本方面的有效性

本实验从 GFLOPs、学习速度和内存消耗以及墙时学习收敛等方面比较了 MaskDiT、DiT-XL/2 和 MDT-XL/2 在 8 个 A100 GPU 上的学习效率。这三个模型的大小几乎相同。

  • GFLOPs:如图 2 所示,MaskDiT 的 GFLOPs 明显低于 DiT 和 MDT。具体来说,MaskDiT 的 GFLOPs 仅为 DiT 的 54.0%,MDT 的 31.7%。作为参考,LDM-8 的 GFLOPs 与 MaskDiT 相似,但在 FID 方面不如 MaskDiT。
    图 2:ImageNet-256×256 上最先进的扩散模型生成性能比较
  • 学习速度和内存消耗:如图 3 所示,与其他模型相比,MaskDiT 具有更高的学习速度和更低的内存消耗。这一点在批量较大时尤为明显。例如,分辨率为 256 x 256、批量大小为 1024 时,MaskDiT 的学习速度是 DiT 的 3.5 倍,是 MDT 的 6.5 倍,而内存消耗则是 DiT 的 45.0%,是 MDT 的 19.2%。
    图 3 DiT、MDT 和 MaskDiT 的训练速度(步/秒)与每 GPU 内存(GB)的对比。
  • 墙时学习收敛:MaskDiT 比其他模型收敛更快。例如,对于批量大小为 1024 的 ImageNet 256×256,MaskDiT 可在 40 小时内达到 FID 10,而其他模型则需要 160 小时以上。(见图 4)
    图 4:256×256 图像网的 FID 和训练时间比较。

总之,与 DiT 和 MDT 相比,MaskDiT 在 GFLOPs、学习速度、内存消耗和墙时收敛方面都显示出更高的学习效率。

与以往研究的比较

在本实验中,MaskDiT 模型与最先进的类条件生成模型进行了比较:256×256 分辨率的结果是经过 200 万步训练后得出的,512×512 分辨率的结果是经过 100 万步训练后得出的。结果汇总于表 1 和表 2。

图像网络 - 256×256

  • 无 CFG:MaskDiT 在 37.5 千步后将 FID 从 6.23 提高到 5.69,优于其他非级联扩散模型;CDM 的 FID 更好,但 IS 较差(158.71 对 177.99)。
  • 使用 CFG:经过 75k 步调整后,MaskDiT-G 的 FID 达到 2.28,接近 DiT-XL/2-G 的 2.27。在 8 个 A100 GPU 上的学习时间为 273 小时,是 DiT-XL/2 的 868 小时的 31%;与 MDT-XL/2-G 相比,MaskDiT-G 在 FID 和 IS 方面逊色,但在精度/调用方面不相上下。
表 1.ImageNet 256×256 的比较结果

ImageNet-512×512

  • 无 CFG:MaskDiT 的 FID 为 10.79,优于 DiT-XL/2 的 12.03。
  • 使用 CFG:MaskDiT-G 的 FID 为 2.50,优于 ADM(3.85)和 DiT(3.04)。总学习成本约为 209 A100 GPU 天,是 DiT 712 A100 GPU 天的 29%。
表 2.ImageNet 256×256 的比较结果

摘要

本文介绍了 MaskDiT,这是一种利用遮蔽变换器学习扩散模型的高效方法。

通过随机屏蔽大部分图像片段,每次迭代的学习开销都会大大减少。为了适应屏蔽学习,我们引入了非对称编码器-解码器扩散骨干:DiT 编码器只处理可见标记,而轻量级 DiT 解码器则在注入屏蔽标记后处理所有标记。此外,还增加了一项辅助任务,以重建被遮蔽的片段,从而促进对所有片段的长程理解。

在 ImageNet-256×256 和 ImageNet-512×512 上的实验表明,与最先进的扩散模型相比,所提出的方法达到了有竞争力甚至更优越的生成性能,同时减少了约 30% 的原始训练时间。开发人员,并促进更多研究人员和开发人员参与图像生成(尤其是扩散模型)的改进和研究。

  • メルマガ登録(ver
  • ライター
  • エンジニア_大募集!!

如果您对文章内容有任何改进建议等,请通过 "联系我们 "表格与爱学网编辑部联系。
如果您能通过咨询表与我们联系,我们将非常感激。

联系我们