ADD: 敵対的学習と知識蒸留を用いた拡散モデル
3つの要点
✔️ 拡散モデルは推論速度が非常に遅いため、リアルタイムでの使用が難しい
✔️ ADDを導入することで、品質を維持しつつ推定速度を大幅に向上
✔️ ADDは敵対的学習と知識蒸留を用いた、初めてのシングルステップ拡散モデル
Adversarial Diffusion Distillation
written by Axel Sauer, Dominik Lorenz, Andreas Blattmann, Robin Rombach
(Submitted on 28 Nov 2023)
Comments: Published on arxiv.
Subjects: Computer Vision and Pattern Recognition (cs.CV)
code:
本記事で使用している画像は論文中のもの、紹介スライドのもの、またはそれを参考に作成したものを使用しております。
はじめに
拡散モデルは、生成モデルとして非常に注目されており、最近では高品質な画像や動画の生成に顕著な進歩をもたらしています。拡散モデルの強みは高い画質、高い多様性という2点を挙げられます。しかし、画像を生成する時、数百~数千のサンプリングステップが必要であり、推定速度が非常に遅いです。
一方、生成敵対的ネットワーク(GAN)は、シングルステップの定式化と高速サンプリングで特徴付けられています。しかし、大規模なデータセットへの拡張の試みにもかかわらず、GANはしばしばサンプル品質で拡散モデルに及びません。また、生成画像の多様性が低いという弱点もあります。
今回の解説論文の目標は拡散モデルの優れたサンプル品質とGANの高速サンプリングを組み合わせることです。これを実現するために、2つのトレーニング目標の組み合わせを導入します。
- 敵対的損失
- スコア蒸留サンプリング(SDS)に対応する蒸留損失
敵対的損失は、実在画像と生成画像を識別器を通じて比較することで、他の蒸留手法でよく生じるぼやけやその他のアーティファクトを回避します。蒸留損失は、別の事前トレーニングされた(および固定された)拡散モデルを教師として使用し、事前トレーニングされた拡散モデルの幅広い知識を効果的に活用します。
提案手法はわずか1~4サンプリングステップで忠実度の高いリアルタイム画像生成ができて、拡散モデルのSOTAであるSDXLを上回りました。
提案手法
トレーニング手順
トレーニング手順は図1に示すように、メインモデルであるADD-studentは、重みθを持つ事前にトレーニングされた拡散モデル(UNet-DM)、学習可能な重み$ϕ$を持つ識別器、および凍結された重み$ψ$を持つDM-Teacher(拡散モデル)という3つのモデルを利用します。
敵対的損失に関して、生成されたサンプル $(\hat{x}_\theta) $と実際の画像$( x_0) $は、それらを区別するための識別器に渡されます。 識別器の設計と敵対的な損失についての詳細は、次のセクションに述べます。 DM教師から知識を蒸留するために、ADD-studentのサンプル$( \hat{x}_\theta) $を教師(DM-Teacher)の前向きプロセスに拡散させて $(\hat{x}_{\theta,t}) $にし、蒸留損失$( L_{\text{distill}})$ の再構成ターゲットとして教師のノイズ除去予測 $( \hat{x}_\psi(\hat{x}_{\theta,t}, t)) $を使用します。詳しくは次のセクションに述べます。
全体の損失関数は次の式となります。
敵対的損失と識別器
識別器に関してはStylegan-t(Sauer et al, 2023)の構造と設定を利用します。固定された事前学習済みの特徴ネットワークFと一連のトレーニング可能な軽量の識別器ヘッド \( D_{(ϕ、k)} \)を使用します。特徴ネットワークFについては、Sauerらがビジョントランスフォーマー(ViT)がうまく機能することを見出したので、次のセクションでViTとモデルサイズの異なる選択肢を検証します。トレーニング可能な識別器ヘッドは、特徴ネットワークの異なる層の特徴Fkに適用されます。
識別器の損失 \( L_{adv}^D \)とメインモデルの \( L_{adv}^G \)は次のようになります。
ここで、\( R1 \) は \( R1 \) 勾配ペナルティを示しています。画素値に対する勾配ペナルティを計算する代わりに、各識別器ヘッド\( D_{(ϕ、k)} \)の出力でそれを計算します。出力解像度が \(128 \times 128 \) ピクセルより大きい場合、R1ペナルティが特に有効です。
スコア蒸留損失
スコア蒸留損失は次の式となります。
\( sg \)は、stop-gradient操作を示します。スコア蒸留損失は、ADD-studentによって生成されたサンプル \( x_\theta \)とDM-teacherの出力との相違を計算する距離メトリック \( d \) を使用します。適切な \( d \)を見つけるために、実験で、 多くの関数を検証しましたが、平均二乗誤差 (MSE)が一番有効でした。
実験
生成モデルのSOTAとの定量的比較
本実験では、よく使われている自動計算の評価指標ではなく、ユーザー嗜好調査でより客観的に提案手法の有効性を検証します。プロンプトの遵守さ(入力プロントは出力画像に正しく反映されるかどうか)と画質という2つの評価指標で、ユーザがより良いものを選びます。結果は図2と3にまとめられています。わずかなサンプリングステップ(1~4)で,提案手法は生成モデルの代表的なモデルを上回り、特に4ステップの場合はSOTAの結果を達成できました。
定性的な結果と比較
SDXLと提案手法の定性的な比較は図4に示されています。提案手法はわずか4ステップで、SDXLと同等以上の画質を生成できていると確認できます。また、入力プロントも生成結果に正しく反映されていると確認できます。特に、図4の左下の画像に示したように、ノイズやアーティファクトに関して、SDXLの生成結果よりも提案手法の方が少ないと確認できます。定量の実験結果も含めて、提案手法はより少ないサンプリングステップで、品質とプロンプトの整合性の両方で、拡散モデルのSOTAであるSDXLを上回っているとわかります。
まとめ
今回の記事では、事前に訓練された拡散モデルを、高速で少ないステップの画像生成モデルに蒸留するためのAdversarial Diffusion Distillation (ADD)を紹介しました。提案手法は、Stable DiffusionとSDXLなどの学習済モデルを蒸留するために、敵対的蒸留とスコア蒸留損失を組み合わせ、識別器による実データと拡散教師による構造理解の両方を活用しました。提案手法の1または2ステップの超高速サンプリングで特に優れた性能を発揮し、実験結果により、多くの場合、先行研究を上回っていることを示しています。一方、ステップ回数をさらに上げると、もっと良い結果を得られ、SDXL、IF、OpenMUSEなどのよく使われる複数ステップの拡散モデルを上回りました。しかし、1サンプリングステップでの生成には、画質やプロントとの整合性に関する改善余地がまだあります。より改善できれれば、提案手法はリアルタイムで利用可能な初めての拡散モデルとなるかもしれません。
この記事に関するカテゴリー