知識蒸留で効果的な教師の条件とは?
3つの要点
✔️ 知識蒸留を成功させる効果的な方法について検討
✔️ 一貫した(consistent)・忍耐強い(patient)教師が重要であることを特定
✔️ ResNet-50モデルでImageNet 82.8%のTop-1精度を達成
Knowledge distillation: A good teacher is patient and consistent
written by Lucas Beyer, Xiaohua Zhai, Amélie Royer, Larisa Markeeva, Rohan Anil, Alexander Kolesnikov
(Submitted on 9 Jun 2021 (v1), last revised 21 Jun 2022 (this version, v2))
Comments: CVPR2022.
Subjects: Computer Vision and Pattern Recognition (cs.CV); Artificial Intelligence (cs.AI); Machine Learning (cs.LG)
code:
本記事で使用している画像は論文中のもの、紹介スライドのもの、またはそれを参考に作成したものを使用しております。
はじめに
画像分類、物体検出、セマンティックセグメンテーションなどのコンピュータビジョンタスクでは、非常に大規模なモデルが最先端の性能を示しています。
しかし、その計算コストの高さゆえ、高性能な大規模モデルよりも、ResNet-50やMobileNetなどの小さなモデルの方が一般的に使用されています。
本記事で紹介する論文では、この問題に対処するため、知識蒸留(Knowledge Distillation)により大規模モデルを圧縮し優れた性能を発揮するための、より効果的な方法を特定することに取り組みました。
その結果、教師モデルと生徒モデルの入力を同一にすること、積極的な増強を行い学習時間を長くする事により、様々なビジョンデータセットで優れた結果が示され、特にImageNetにおけるResNet-50モデルで82.8%のTop-1精度を達成しました。
実験設定
実験では、特定のタスクで高い精度を示す大規模な視覚モデル(教師モデル)を、性能低下を抑えながらより小さなモデル(生徒モデル)に圧縮します。
実験で用いるデータセットは以下の五つです。
- flower102
- pets
- food101
- sun397
- ILSVRC-2012(ImageNet)
これらは多様な画像分類タスクであり、クラス数は37~1000、画像数は1020~1281167にわたります。
評価指標には分類精度を用いています。
教師・生徒モデル
教師モデルとして、実験ではBiT(ILSVRC-2012とImageNet-21kで事前学習されたResNetモデルの大規模なコレクション)のモデルを使用します。標準的なResNetとの大きな違いとして、Batch Normalizationの代わりにGroup Normalization層とWeight Standardizationが使用されています。
生徒モデルにはBiT-ResNet-40の変種を使用しています(以後簡略化しResNet-50と呼びます)。
蒸留損失
蒸留損失には教師モデル・生徒モデルの予測クラス確率ベクトル$p_t,p_s$間のKL-divergenceを使用します。
$C$はクラス集合です。また、温度パラメータ$T$も使用しています($p_s \propto exp(log p_s/T), p_t \propto exp(log p_t/T)$)。
実験結果
論文では、知識蒸留は教師・生徒モデルの関数をマッチングさせるタスクであると解釈しています。
この解釈に基づき、モデル圧縮のための知識蒸留における二つの原則として、以下の二つを結論として示しています。
- 教師・生徒モデルは一貫した同一の(全く同じクロップ・Augmentationがなされた)入力を処理すること("consistent teacher")。
- 汎化を良くするため、積極的にAugmentation処理を行い、多くのエポック数で学習を行うこと("patient teacher")。
一貫した教師(consistent teacher)の重要性
はじめに、"consistent teacher"仮説について検証するため、以下の四つの知識蒸留におけるオプションについて検討します。
- Fixed teacher:教師モデルの予測を固定します。
- fix/rs:教師・生徒モデルはどちらも224x224にリサイズされた画像を入力とします。
- fix/cc:教師モデルは中心をクロップし、生徒モデルではランダムなクロップを行います。
- fix/ic_ens:教師モデルの予測として、1k通りのクロップに対する予測の平均を使用します(inception crop)。生徒モデルは画像の一部をランダムにクロップします。
- Independent noise:教師・生徒モデルで異なる入力が与えられます。
- ind/rc:教師・生徒モデルでそれぞれ独立したランダムなクロップを適用します。
- ind/ic:教師・生徒モデルでそれぞれ独立したinception cropを適用します。
- Consistent teaching:教師・生徒モデルで同一の入力が与えられます。
- same/rc:教師・生徒モデルで同一のランダムなクロップを適用します。
- same/ic:教師・生徒モデルで同一のinception cropを適用します。
- Function matching:教師・生徒モデルで同一の入力(+Augmentation)が与えられます。
- same/ir,rc, mix:Consistent teachingの拡張で、画像に対してmixup処理を行い多様性を高めた上で、教師・生徒モデルに同一の入力を与えます。
これらの設定について、Flowers102上で10,000エポックの学習を行った場合の学習曲線は以下の通りです。
図の通り、一貫した教師(same/rc、same/ic)設定がより優れた結果を示しており、教師・生徒モデルで一貫した入力が与えられることの重要性が示されました。またtrain/val曲線を比較すると、教師の予測を固定する場合(fix、黒線)にはオーバーフィッティングが発生していることがわかります。
忍耐強い教師(patient teacher)の重要性
通常の教師あり学習の場合、積極的な画像増強(augmentation)は、画像ラベルに対して実際の画像が大きく歪んでしまうリスクが生じます。
しかし、知識蒸留を教師・生徒モデル関数のマッチング処理であると解釈し、教師・生徒モデルに一貫した同一の入力を与えるならば、入力が大きく歪んでいても関数マッチングには有効であるため、積極的に画像増強を行うことができます。
この考えに基づき、積極的に画像増強を行うことでオーバーフィッティングを回避しつつ、非常に長い時間の最適化を行う(patient teacher)場合について検証します。結果は以下の通りです。
この図では、各データセットに対して異なるエポック数で学習を行った場合のテスト精度の推移が示されています。
図の通り、非常に多くのエポック数で学習を行った結果、最終的に生徒モデルが教師モデルの性能(赤線)に到達していることがわかります。また、1Mエポックもの学習を経てもオーバーフィッティングが起きていないことも注目に値します。そして、ゼロからの学習または転移学習を行った場合と比較した場合、小さなエポック数では劣りますが、最終的には上回る結果が得られました。
ImageNetへのスケールアップ
上記の実験は比較的小規模なデータセットに対するものであったため、より大規模なImageNetに対して同様の実験を行います。このときの結果は以下の通りです。
上記の実験と同様、consistent teaching設定はオーバーフィッティングすることなく、学習時間の増加に伴い性能が向上していることがわかります。
また、積極的に増強処理を行うfunction matching設定では、エポック数が少ない場合はアンダーフィッティングが起きていますが、学習時間が長くなるとより優れた性能を示しています。
最終的に、ResNet-50生徒モデルは、ImageNetで82.31%のTop-1精度を達成しました。
異なる解像度における知識蒸留
これまでの実験では、教師・生徒が同一の解像度(224x224)の入力を受け取っています。しかし、生徒モデルでの解像度を低下させることで、より高速な処理が実現できるかもしれません。
そこで、生徒モデルにおける入力画像解像度を教師モデルより小さくした場合のエポック数とTop-1精度は以下の通りです。
表の通り、生徒モデルと教師モデルの解像度が異なる場合でも知識蒸留が有効に機能することがわかります。
また、より解像度が高く高精度な教師モデルから知識蒸留を行うことで(S384→S224)、生徒モデルが同じ224x224の解像度でもより優れた性能を発揮できることが示されました。
second order preconditionerによる学習効率向上
積極的に増強処理を行うfunction matching設定は、最終的な性能は高いものの、より多くの学習時間が必要になります。
ここで、より強力なOptimizerを使用する(Adam→Shampoo)ことで、学習時間の増加を抑えることができるかについて検証します。結果は以下の通りです。
図の通り、Adamの代わりにShampooを使用することで、学習速度を4倍向上させることに成功しました。
事前学習済みモデルの使用について
転移学習の成功を踏まえて、学生モデルを事前学習済みモデルで初期化した場合の結果は以下の通りです。
学習時間が短い場合、事前学習済みモデルによる初期化は良い結果を示しています。しかし、学習時間が長くなると、最終的にはゼロからの学習の方がより優れた性能を達成しました。
異なるモデル系列における知識蒸留
異なる解像度でも知識蒸留が成功することを踏まえて、異なるモデル系列間での知識蒸留についても検証します。
まず、生徒モデルをMobileNet v3(Large)に変更した場合では、300エポックで74.60、1200エポックで76.31のTop-1精度を達成しました。また、生徒モデルはResNet50で、教師モデルをアンサンブル設定にした場合(224x224のデフォルト+384x384のロジット平均)には、9600エポック後に82.82ものTop-1精度を達成しました。
総じて、教師・生徒モデルが異なるモデルアーキテクチャであったり、教師モデルがアンサンブル設定であっても知識蒸留が成功することがわかりました。
既存手法との比較
これらの実験の中で最良の結果と、既存のResNet-50モデルとの比較結果は以下の通りです。
総じて、論文で提示された知識蒸留設定は既存の最先端の結果を上回る性能を示しました。
"out-of-domain"データにおける知識蒸留
知識蒸留を関数マッチングと捉える場合、任意の画像入力に対して知識蒸留が有効であると予想されます。
この仮説を検討するため、petsとsun397データセットで実験を行います。具体的には、food101とImageNetの画像(out-of-domain)で知識蒸留を行い、ipetsとsun397の画像(in-domain)で知識蒸留を行った場合と比較します。結果は以下の通りです。
総じて、in-domainの知識蒸留が最も優れた性能を示しましたが、out-of-domainの画像でもある程度は知識蒸留が機能することが示されました。
また、ドメインに関連性があるか重複がある場合(petsとImageNet、sun397とImageNet)では、長い学習時間が必要になるものの、in-domainに近い性能を達成しうることがわかりました。
蒸留損失なしとの比較
最後に、これらの実験結果が特有の学習設定(積極的なmixup増強と長い学習時間)によるものではないことを確認するため、蒸留損失を消去して通常の教師あり学習を行った場合と比較した結果は以下の通りです。
図の通り、知識蒸留を行わずに積極的なmixup増強と長期の学習を行った場合、性能が低下しオーバーフィッティングが生じました。このように、積極的なmixup増強と長期の学習は、知識蒸留と組み合わせることで有効に機能することが示されました。
まとめ
モデル圧縮の新たな手法を提案するのではなく、既存の知識蒸留プロセスについて再考し、より効果的な学習プロセスについて提案した論文について紹介しました。この結果は、知識蒸留を「教師・生徒モデルの関数マッチング」として捉えるという新たな解釈に基づいています。
そして、(1)教師と生徒の入力を同一にすること、(2)積極的に増強を適用して入力の多様性を高めること、(3)学習時間を長くすることで、知識蒸留の性能を高められることを示しました。
これらの知見に基づき、大規模なモデルをResNet-50に圧縮することで、既存の最先端の性能を上回り、今後の研究における強力なベースラインとなりうる重要な結果を達成しました。
この記事に関するカテゴリー