MaskDiT: 画像生成向けた低学習コストの拡散モデル
3つの要点
✔️ 拡散モデルは高品質で多様性の高い画像生成ができますが、学習には膨大な計算リソースと時間がかかります。
✔️ MaskDiTはMasked transformersを用いることで、この課題の解決を目指しています。
✔️ ImageNetデータセットでの実験では、MaskDiTが最先端のDiffusion Transformer(DiT)を上回る生成性能を達成し、トレーニング時間を約30%削減することが示されました。
Fast Training of Diffusion Models with Masked Transformers
written by Hongkai Zheng, Weili Nie, Arash Vahdat, Anima Anandkumar
(Submitted on 15 Jun 2023)
Comments: Published on arxiv.
Subjects: Computer Vision and Pattern Recognition (cs.CV); Artificial Intelligence (cs.AI); Machine Learning (cs.LG)
code:
本記事で使用している画像は論文中のもの、紹介スライドのもの、またはそれを参考に作成したものを使用しております。
はじめに
拡散モデルは、特にテキスト入力を用いて高品質で多様な画像を合成する際の優れた画像生成性能により、最も人気のある深層生成モデルのクラスとなっています。その強力な生成性能は、スタイル変換や背景生成など、多くのアプリケーションを可能にしています。
これらのモデルが現実的な画像や創造的なアートを生成する能力を持つためには、大規模なトレーニングが不可欠です。しかし、これらのモデルのトレーニングには大量の計算資源と時間が必要であり、さらなるスケーリングの大きなボトルネックとなっています。
例えば、オリジナルのStable Diffusionは、256台のA100 GPUを使用して24日以上かけてトレーニングされました。インフラと実装の改善により、256台のA100 GPUでトレーニングコストを13日間に削減することは可能ですが、それでも多くの研究者や開発者にとっては膨大な計算リソースです。したがって、拡散モデルのトレーニング効率を改善することは依然として未解決の課題です。
今回の解説論文では、マスキングをDiffusion Transformer(DiT)に導入して、この課題の解決を目指しています。トレーニング中に拡散された入力画像のパッチの高い割合(約50%)をランダムにマスクします。マスクトレーニングのために、非対称のエンコーダーデコーダーアーキテクチャを導入し、トランスフォーマーエンコーダーはマスクされていないパッチのみを処理し、軽量なトランスフォーマーデコーダーは全てのパッチを処理します。また、全パッチの長距離理解を促進するために、マスクされたパッチを再構築する補助タスクを追加します。
ImageNet-256×256およびImageNet-512×512での実験により、提案手法は最先端の拡散トランスフォーマー(DiT)モデルと比較して競争力のある、さらにはそれを上回る生成性能を達成しながら、元のトレーニング時間の約30%で済むことが示されました。
提案手法
図1は、提案手法の概要とDiTのアーキテクチャとの比較を示しています。本節では、学習コストを削減するためのアーキテクチャの改善について簡単に説明します。
マスキング
図1-bの下部に該当します。クリーンな画像 $x_0$ と拡散タイムステップ $t$ が与えられた場合、ガウスノイズ $n$ を加えることで拡散画像 $x_t$ を得ます。次に、$x_t$ を $p \times p$ サイズの重ならない $N$ 個のパッチに分割します。解像度が $H \times W$ の画像の場合、$N$ は $N = \frac{HW}{p^2}$ と計算されます。固定されたマスキング比率 $r$ を用いて、ランダムに $\lfloor rN \rfloor$ 個のパッチを削除し、残りの$N - \lfloor rN \rfloor$ 個のマスクされていないパッチを拡散モデルに渡します。すべての拡散タイムステップにわたって同じマスキング比率 $r$ を維持します。
高いマスキング比率は計算効率を大幅に向上させますが、学習効率を減少させる可能性があります。しかし、$x_t$ に大きな冗長性があるため、マスキングによる学習は、モデルが隣接するパッチからマスクされたパッチを補完する能力によって補われる可能性があります。したがって、良好な性能と高いトレーニング効率の両方を達成できる最適なバランスが存在するかもしれません。
非対称の$Encoder$ー$Decoder$バックボン
提案手法の拡散バックボーンは、拡散モデルの標準的なViTベースのアーキテクチャであるDiTに基づいており、いくつかの修正を加えています。MAE(He et al., 2022)と同様に、非対称のエンコーダーデコーダーアーキテクチャを使用します:
- エンコーダー:元のDiTと同じアーキテクチャを保持していますが、最終的な線形変換層を省略し、マスクされていないパッチのみを処理します。
- デコーダー:これは軽量なMAEデコーダーから適応された別のDiTアーキテクチャで、全てのトークンを入力として処理するように設計されています。
DiTと同様に、提案手法のエンコーダーは線形変換を使用してパッチを埋め込み、標準的なViTの周波数ベースの位置埋め込みをすべての入力トークンに追加します。マスクされたトークンは、残りのエンコーダーレイヤーに渡される前に削除されます。
デコーダーは、エンコードされたマスクされていないトークンと新しいマスクトークンの両方を入力として受け取ります。各マスクトークンは共有される学習可能なベクトルです。位置埋め込みは、デコーダーに渡す前にすべてのトークンに追加されます。
この非対称設計(例:MAEデコーダーはDiT-XL/2のパラメータの9%未満)により、マスキングは1回のイテレーションあたりの計算コストを大幅に削減できます。
学習損失
学習には、通常の拡散モデルと同様、denoising score matching lossを利用しますが、マスクされないトークンだけに適用します。
また、全パッチの長距離理解を促進するために、マスクされたパッチの再構築タスクを追加します。
最終の損失は次のように、λのハイパーパラメータでバランスを調整します。
実験
提案手法の学習コスト削減効果
本実験では、8台のA100 GPU上でMaskDiT、DiT-XL/2、およびMDT-XL/2の学習効率を、GFLOPs、学習速度とメモリ消費、そしてウォールタイム学習収束の3つの観点から比較します。これら3つのモデルは、サイズがほぼ同じです。
- GFLOPs: 図2に示されているように、MaskDiTのGFLOPsは、DiTおよびMDTのそれよりも著しく低いです。具体的には、MaskDiTのGFLOPsはDiTの54.0%、MDTの31.7%に過ぎません。参考までに、LDM-8はMaskDiTと同程度のGFLOPsを持っていますが、FIDの点では劣ります。
- 学習速度とメモリ消費: 図3に示されるように、MaskDiTは他のモデルよりも学習速度が高く、メモリ消費が低いです。特に大きなバッチサイズで顕著です。例えば、解像度256×256でバッチサイズ1024の場合、MaskDiTの学習速度はDiTの3.5倍、MDTの6.5倍で、メモリ消費はDiTの45.0%、MDTの19.2%です。
- ウォールタイム学習収束: MaskDiTは他のモデルよりも早く収束します。例えば、ImageNet 256×256でバッチサイズ1024の場合、MaskDiTは40時間以内にFID 10を達成しますが、他のモデルは160時間以上かかります。(図4を参照)
総合的に、MaskDiTはGFLOPs、学習速度、メモリ消費、およびウォールタイム収束の観点で、DiTおよびMDTに比べて優れた学習効率を示しています。
先行研究との比較
本実験では、MaskDiTモデルを最先端のクラス条件付き生成モデルと比較しまます。256×256解像度の結果は200万ステップの学習後、512×512解像度の結果は100万ステップの学習後に得られました。結果は表1と2にまとめられます。
・ImageNet-256×256
- CFGなし: MaskDiTは37.5kステップ後、FIDを6.23から5.69に改善し、他の非カスケード拡散モデルを上回りました。CDMはより良いFIDを持ちますが、ISは劣ります(158.71対177.99)。
- CFGあり: MaskDiT-Gは75kステップの調整後、FID 2.28を達成し、DiT-XL/2-Gの2.27に近い値を示しました。学習時間は8台のA100 GPUで273時間で、DiT-XL/2の868時間の31%です。MDT-XL/2-Gと比較すると、MaskDiT-GはFIDとISで劣りますが、Precision/Recallは同等です。
・ImageNet-512×512
- CFGなし: MaskDiTはFID 10.79を達成し、DiT-XL/2の12.03よりも優れています。
- CFGあり: MaskDiT-GはFID 2.50を達成し、ADM(3.85)やDiT(3.04)を上回りました。総学習コストは約209 A100 GPU日で、DiTの712 A100 GPU日の29%です。
まとめ
本記事では、マスクされたトランスフォーマーを用いた効率的な拡散モデルの学習手法であるMaskDiTを紹介しました。
画像パッチの大部分をランダムにマスクすることで、各イテレーションの学習オーバーヘッドを大幅に削減します。マスクされた学習に対応するため、非対称のエンコーダーデコーダー拡散バックボーンを導入しました。DiTエンコーダーは可視トークンのみを処理し、軽量なDiTデコーダーはマスクされたトークンが注入された後に全トークンを処理します。また、全パッチの長距離理解を促進するために、マスクされたパッチを再構築する補助タスクも追加しました。
ImageNet-256×256およびImageNet-512×512での実験により、提案手法は最先端の拡散モデルと比較して競争力のある、さらにはそれを上回る生成性能を達成しながら、元のトレーニング時間を約30%削減できました。MaskDiTは、より多くの研究者や開発者が、画像生成、特に拡散モデルの改善と研究に参加しやすくなることが期待されています。
この記事に関するカテゴリー