ACGANの再来~ReACGAN~
3つの要点
✔️ ACGANの訓練が不安定になってしまう原因が判別器の勾配爆発にあることを発見
✔️ データ間の関係性も考慮できる新しい損失関数D2D-CEを提案
✔️ ACGANの改良版ReACGANを開発しBigGANに匹敵する画像生成能を達成
Rebooting ACGAN: Auxiliary Classifier GANs with Stable Training
written by Minguk Kang, Woohyeon Shim, Minsu Cho, Jaesik Park
(Submitted on 1 Nov 2021)
Comments: NeurIPS 2021
Subjects: Computer Vision and Pattern Recognition (cs.CV); Artificial Intelligence (cs.AI); Machine Learning (cs.LG)
code:
本記事で使用している画像は論文中のもの、紹介スライドのもの、またはそれを参考に作成したものを使用しております。
はじめに
これまでの敵対的生成ネットワーク(GAN)における研究は生成サンプルの多様性を損なうモード崩壊や訓練の難しさといった問題に対処する方向性で行われてきました。具体的には勾配消失を起こさない目的関数の定義や訓練を安定化する正則化手法の開発、データセットにおけるラベルデータの有効活用などが挙げられます。
データセットにおいてラベルデータを利用するGANは特に条件付GAN(Conditional GAN)と呼ばれており、判別器をクラスラベルにより条件付ける方法によって判別器ベースGAN(classifier-based GAN)と投影ベースGAN(projection-based GAN)に大別されます。
判別器ベースGANでは判別器がデータが本物かどうかだけでなくそのクラスラベルも同時に予測します。クラス数が増加すると訓練初期でモード崩壊を起こしてしまう点や画像生成においてクラス内での生成サンプルの多様性が少ない点などの問題点が知られています。
一方で投影ベースGANでは判別器においてクラスラベルから投影した条件ベクトルとデータの特徴ベクトルの内積を計算するアーキテクチャを採用しています。データとラベルの間における一対一の関係性しか考慮しないためデータセット内のデータ同士の関係性を考慮する余地が残されています。
本稿で紹介する論文ではACGANの改良版であるReACGANを開発し、入力を超球に射影することで勾配爆発の問題を抑制するとともに、データ間の関係性を考慮した新しい損失関数Data-to-Data Cross-Entropy(D2D-CE)を提案しています。
ACGANの学習における不安定性
従来のGANでは本物データと生成データを見分けるニューラルネットワークを判別器(Discriminator, D)と呼びます。ACGANにおいてはこの判別器が画像に写っているオブジェクトのクラスラベルも同時に分類します。
このクラス分類に対する損失関数として交差エントロピーが用いられているのですが、ソフトマックス層を伴うニューラルネットワークにおいてこの損失関数を計算した場合、最終線形層の重みに関する偏微分は以下の式のように計算されます。
$1_{y_i}=k$はクラスラベル$y$が$k$である場合には1、それ以外の場合には0を取る関数を表しており、$p_{i,k}$はサンプル$i$がクラス$k$に属する確率のことを表しています。
訓練初期にこの$p_{i,k}$が小さい値を取っていることにより勾配ノルムが大きくなってしまうことでモード崩壊が引き起こされ、ACGANの学習が不安定になることがわかりました。
元論文では単純に判別器における最終中間層の出力$F(x)$を正規化する(上式のノルムを1にする)だけで学習安定化の効果を得られることが確認されています。
Data-to-Data Cross-Entropy Loss(D2D-CE)
前述したように投影ベースのGANにおいてはデータとクラスラベル間の関係性のみしか考慮できていなかった点を受け、ReACGANではデータ間の関係性も考慮できる損失関数D2D-CEを提案しています。
D2D-CEでは従来の交差エントロピーがクラスラベルに対応する特徴ベクトル(最終線形層における重みベクトル)とデータから抽出される特徴ベクトルの間の内積を用いて計算されていることを踏まえ、異なるクラスに属するサンプル同士でもこの内積を計算することでデータ間の関係性を反映しています。
具体的にD2D-CEは以下の式で表すことができます。
式中の$f$は画像を特徴抽出器($F$)に入力し、さらに投影層(Projection layer)に通して得られた正規化済みの埋め込み表現、$v$はクラスに対応する正規化済みの埋め込み表現、$\tau$は温度パラメータです。$N(i)$は異なるクラスに属するサンプルの集合を表しています。ここで正規化とは前のセクションで述べた勾配ノルムを1にするための正規化のことです。
このようにデータ間の関係性を考慮した損失関数を導入することでどのような効果を期待しているのでしょうか。下の図は各モデルの訓練においてサンプルを特徴量空間にどう配置するよう学習が進むのか示したものです。図中の青はイヌ以外のクラス、赤はイヌのクラスを表しており、★はクラス分類の線形層重み$w$、矢印は学習の方向を表しています。
判別器ベースGANではサンプルが異なるクラスからは離れるように、また、所属するクラスには近づくように重みの更新が行われます。投影ベースGANでは所属するクラスに近づくように重みの更新が行われます。
一方で、D2D-CEを用いたReACGANの訓練ではマージン項の導入により簡単に異なるクラスに分類できるサンプルは無視され、判別が難しい負例は正例から遠ざけ、所属するクラスには近づくように重みの更新が行われます。これによりクラス間の分離能を保ったままクラス内のサンプルのばらつきを確保した学習が期待できます。
また、実験的にD2D-CEが訓練の安定化にも寄与することが確かめられています。Tiny-ImageNetのデータセットを用いて特徴マップのノルムと分類器における勾配ノルムを各イテレーションで算出したところ、D2D-CEの導入により各値が低い値で推移することがわかりました。D2D-CEが勾配爆発を抑制し、訓練の安定化に寄与することが示唆されています。
ReACGANの全体像
ReACGANの全体像を示すと以下の図のようになります。ACGANからの変化として大きなところは新たな損失関数を導入したことでクラス分類というよりもサンプルが同一のクラスに属すか否か判定するようになっている点です。
従来のGANにおける敵対的訓練はそのままにして、追加のタスクとしてD2D-CEによる訓練を行う仕組みとなっています。
D2D-CEでは異なるクラスに属するサンプル間で特徴量ベクトルの内積を計算するため、サンプル間でクラスが同じか否かを表す対称行列(False negatives mask)が用いられている点が特徴的です。
ReACGANの性能
ReACGANにより生成された画像を5つのベンチマークデータセットについて示したものが下の図です。
以前のACGANとは比べ物にならないほど生成能力が向上しています。しかし、BigGANのアーキテクチャを採用することでネットワークパラメータのサイズが大幅に大きくなっている点には注意が必要です。
ImageNetにおいてInception Score(IS)やFrechet Inception Distance(FID)による評価を行った結果は以下のようになっています。
ReACGANはバッチサイズが256と比較的小さいときに生成能力が高く、2048と大きいバッチサイズになるとBigGANに負けてしまっています。
学習曲線を見るとReACGANではISの立ち上がりが早く、FIDの下がり方も早い傾向にあります。これは学習初期における訓練安定化の効果ではないかと考えられます。
メモリが少ない環境など、バッチサイズをとにかく大きくして勾配の信頼性を上げることができない場合にはReACGANによる訓練を検討してみてはどうでしょうか。
まとめ
いかがだったでしょうか。ACGANの再来というタイトルにしていますが、損失関数の設計などの観点からACGANよりもContraGANに近いのではないかと感じました。
今回提案された損失関数D2D-CEを他の損失関数と比較した結果や、異なるアーキテクチャに応用した結果が元論文には載っていますので興味のある方は参照してみてください。
この記事に関するカテゴリー