【Diffusion Transformer】OpenAIのSoraにも使われた最新技術
3つの要点
✔️ 拡散モデルとTranformerを合わせたモデル
✔️ 従来の拡散モデルやGANを超える画質と多様性を実現
✔️ 従来のU-Netモデルを上回る性能を示す
Scalable Diffusion Models with Transformers
written by William Peebles, Saining Xie
(Submitted on 2 Mar 2023)
Comments: Code, project page and videos available at this https URL
Subjects: Computer Vision and Pattern Recognition (cs.CV); Machine Learning (cs.LG)
code:
本記事で使用している画像は論文中のもの、紹介スライドのもの、またはそれを参考に作成したものを使用しております。
はじめに
機械学習は、トランスフォーマーによって牽引されるルネッサンスを経験しています。過去5年間、自然言語処理、画像解析などの分野で、ニューラルアーキテクチャの多くがトランスフォーマーによって大きく取って代わられました。しかし、画像生成モデルの多くはまだこのトレンドに追いついていません。たとえば、拡散モデルは画像生成モデルの最近の進歩の中心にありますが、すべてのモデルはデフォルトのバックボーンとして畳み込みU-Netアーキテクチャを採用しています。
今回の解説論文では、拡散モデルにおけるアーキテクチャの選択の重要性を解明し、将来の生成モデリング研究のための実証的な基準を提供することを目指しています。今回の論文は、U-Netの帰納的なバイアスが拡散モデルの性能に必須ではないことを示し、これらはトランスフォーマーなどの標準的な設計に簡単に置き換えることができることを示しています。その結果、拡散モデルは、スケーラビリティ、ロバスト性、効率性といった有利な特性を保持するだけでなく、他のドメインからのベストプラクティスやトレーニングレシピを継承することにより、アーキテクチャの統一化という最近のトレンドから恩恵を受けることができる。標準化されたアーキテクチャは、領域横断的な研究に新たな可能性をもたらすだろう。
今回の論文では、トランスフォーマーに基づく新しいクラスの拡散モデルに焦点を当てます。これらをDiffusion Transformers、または略してDiTsと呼びます。 DiTsは、従来の畳み込みネットワークよりも視覚認識に効果的にスケーリングできることが示されているVision Transformers(ViTs)のベストプラクティスに従います。
実験結果より、DiTsは高いスケーラビリティをもって、クラス条件付きの256×256 ImageNet生成ベンチマークで2.27 FIDの最先端の結果を達成できました。
提案手法
Patchify
DiTの入力は空間表現zです(256×256×3の画像の場合、zの形状は32×32×4です)。DiTの最初の層は「Patchify」であり、これにより空間入力が次元dの各パッチを線形に埋め込んだT個のトークンのシーケンスに変換されます。Patchifyに続いて、すべての入力トークンに標準のViT周波数ベースの位置埋め込み(サイン・コサインバージョン)を適用します。パッチ化によって生成されるトークンの数Tは、パッチサイズハイパーパラメータpによって決まります。図2に示されているように、pを半分にするとTが4倍になり、したがって合計のトランスフォーマーGflopsも少なくとも4倍になります。pを変更しても下流のパラメータ数に影響はありません。
DiTブロックの設計
パッチ化に続いて、入力トークンは一連のトランスフォーマーブロックで処理されます。入力ノイズ画像に加えて、拡散モデルは時刻 t、クラスラベルc、テキストなどの追加の条件付き情報を処理することがあります。この点に基づいて、今回の論文は、次の4つDiTのブロックを検討しました。これらの設計は、標準のViTブロック設計に対して小さいながらも重要な変更です。すべてのブロックの設計は図1に示されています。
・In-context conditioning
tとcのベクトル埋め込みを入力シーケンスに2つの追加トークンとして追加し、これらを画像トークンと同じように扱います。これはViTsのclsトークンと類似しており、修正なしに標準のViTブロックを使用できるようにします。最終ブロックの後、シーケンスから条件付きトークンを削除します。このアプローチにより生じるGflopsは非常に小さいので、無視できます。
・Cross-attentionブロック
tとcの埋め込みを、画像トークンのシーケンスとは別の長さ2のシーケンスに連結します。トランスフォーマーブロックは、Vaswaniらの元の設計と同様に、また、クラスラベルに対する条件付けに使用されるLDMに似た、マルチヘッドのセルフアテンションブロックの後に追加のマルチヘッドのクロスアテンションレイヤーを含むように修正されます。クロスアテンションは、モデルに最も多くのGflopsを追加し、おおよそ15%のオーバーヘッドを生じます。
・適応的正規化レイヤー (adaLN)
GANやUNetバックボーンを持つ拡散モデルでの適応的正規化レイヤーの広範な使用に続き、今回の論文はトランスフォーマーブロック内の標準的なレイヤー正規化層を適応的レイヤー正規化(adaLN)に置き換えることを検討しました。このadaLNでは、次元ごとのスケールおよびシフトパラメータγとβを直接学習する代わりに、これらを時刻tとクラスラベルcの埋め込みベクトルの合計から回帰します。今回の論文が検討したした3つのブロック設計の中で、adaLNはGflopsを最も追加せず、したがって最も計算効率が良いです。また、すべてのトークンに同じ関数を適用することに制限される唯一の調整メカニズムです。
・adaLN-Zeroブロック
ResNetの先行研究では、各残差ブロックを恒等関数として初期化することが有益であることがわかっています。たとえば、Goyalらは、各ブロックでの最終バッチ正規化スケールファクターγをゼロで初期化することで、教師あり学習の設定で大規模なトレーニングを加速させることができると検証しました。Diffusion U-Netモデルでは、各ブロックの残差接続の前に最終の畳み込み層をゼロで初期化するという類似した初期化戦略が使用されています。今回の解説論文では、これと同じ操作を行うadaLN DiTブロックを検討しました。γとβを回帰させるだけでなく、DiTブロック内の残差接続の直前に適用される次元ごとのスケーリングパラメーターαも回帰します。
モデルのサイズ
DiTsは、隠れ層のサイズdで動作するN個のDiTブロックのシーケンスを適用します。ViTに続いて、N、d、およびattentionの数を共にスケーリングする標準のトランスフォーマー設定を使用します。具体的には、DiT-S、DiT-B、DiT-L、およびDiT-XLの4つの設定を使用します。これらは、0.3から118.6 Gflopsまでの広範なモデルサイズとflop割り当てをカバーしており、スケーリングのパフォーマンスを評価することができます。表1には、設定の詳細が記載されています。
Transformer decoder
最終的なDiTブロックの後、画像トークンのシーケンスを出力ノイズ予測と出力対角共分散予測にデコードする必要があります。これらの出力の形状は、元の空間入力と同じです。これを行うために、標準の線形デコーダーを使用します。最終的なレイヤーノーム(adaLNを使用する場合は適応的)を適用し、各トークンをp×p×2Cテンソルに線形にデコードします。ここで、CはDiTの空間入力のチャンネル数です。最後に、デコードされたトークンを元の空間レイアウトに再配置して、予測されたノイズと共分散を取得します。今回の論文が検討するDiT設計空間は、パッチサイズ、トランスフォーマーブロックアーキテクチャ、およびモデルサイズです。
実験
DiTブロックの設計
最も高いGflopのDiT-XL/2モデルを4つ訓練しました。それぞれ、異なるブロック設計を使用しています。それらは、In-context conditioning(119.4 Gflops)、Cross-attention (137.6 Gflops)、適応的レイヤー正規化(adaLN、118.6 Gflops)、またはadaLN-zero(118.6 Gflops)です。トレーニングの間にFIDを測定しました。FIDは生成画質を表し、FIDが低いほど画質が高いです。図3はその結果を示しています。adaLN-Zeroブロックは、最も計算効率が良いにもかかわらず、クロスアテンションとインコンテキストの両方よりも低いFIDを提供します。トレーニングイテレーションが400Kの時点で、adaLN-Zeroモデルで達成されるFIDは、インコンテキストモデルのほぼ半分であり、条件付けメカニズムがモデルの品質に重大な影響を与えることを示しています。初期化も重要で、各DiTブロックを同一関数として初期化するadaLNZeroは、バニラadaLNを大幅に上回りました。以降、すべてのモデルでadaLN-Zero DiTブロックを使用します。
スケーラビリティの検証
私たちは、モデル構成(S、B、L、XL)とパッチサイズ(8、4、2)を使って、12のDiTモデルを訓練しました。DiT-LとDiT-XLは、他の構成に比べて、相対Gflopsの観点から明らかに互いに近いことに注意してください。図4(左)は、各モデルのGflopsと400KのトレーニングイテレーションでのFIDの概要を示しています。すべての場合で、モデルサイズを増やし、パッチサイズを減らすことで、拡散モデルをかなり改善できることがわかりました。
図5(上)は、モデルサイズが増加し、パッチサイズが一定のままである場合のFIDの変化を示しています。すべての構成で、トランスフォーマーをより深く、広くすることで、トレーニングのすべての段階でFIDが著しく改善されることがわかります。同様に、図5(下)は、パッチサイズを減少させ、モデルサイズを一定に保った場合のFIDを示しています。DiTによって処理されるトークンの数を単純にスケーリングすることで、トレーニング中にかなりのFIDの改善が見られることが再び観察されます。
画質の検証
本実験では、DiT-XL/2をImageNet 256×256とImageNet 512×512のデータセットで訓練して、従来の拡散モデルとGANのSOTAとの比較を行います。両方のデータセットにおいて、提案手法は従来の拡散モデルを大きく上回ったと確認できます。GANのSOTAであるStyleGANと比較するとき、同等以上の画質(FID)を達成し、多様性(Recall)でGANを大きく上回りました。
生成の例は図6に示します。どの場合でも非常に鮮明な画像で、人間でも真偽を区別できない程度です。
まとめ
今回の解説論文は、Diffusion Transformers(DiTs)を提案しました。これは、従来のCNNのU-Netモデルを上回り、トランスフォーマーモデルクラスの優れたスケーリング特性を継承した、拡散モデルのシンプルなトランスフォーマーベースのバックボーンです。この論文で示された有望なスケーリング結果を考慮すると、将来の研究では、DiTsをより大きなモデルやトークン数にスケールさせることが継続されるべきです。また、DiTは、DALL·E 2やStable Diffusionのような有名なテキストから画像へのモデルのバックボーンとしても検討される可能性があります。
この記事に関するカテゴリー