自己教師あり学習が自己教師あり学習を改善する!!
3つの要点
✔️ 階層的に自己教師あり学習を行うHierarchical PreTraining (HPT)を提案
✔️ 16種類もの多様なデータセットを使用した検証実験を実施
✔️ HPTによって80倍の学習高速化、ロバスト性の改善を実現
Self-Supervised Pretraining Improves Self-Supervised Pretraining
written by Colorado J. Reed, Xiangyu Yue, Ani Nrusimha, Sayna Ebrahimi, Vivek Vijaykumar, Richard Mao, Bo Li, Shanghang Zhang, Devin Guillory, Sean Metzger, Kurt Keutzer, Trevor Darrell
(Submitted on 23 Mar 2021 (v1), last revised 25 Mar 2021 (this version, v2))
Comments: WACV 2022
Subjects: Computer Vision and Pattern Recognition (cs.CV)
code:
本記事で使用している画像は論文中のもの、紹介スライドのもの、またはそれを参考に作成したものを使用しております。
はじめに
自己教師あり学習 (SSL)は、様々な画像認識タスクで有効であることが知られています。しかし、その能力を十分に発揮するためには、豊富なデータと計算量が必要とされています。
したがって、多くの場合は下の図に示すようにImageNetで事前に自己教師あり学習を済ませたモデルを使用することになります。
しかし、ImageNetで自己教師あり学習済みのモデルを新たな画像認識のタスクに転用する際に、そのタスクで使用する画像 (医用画像や航空写真など)がImageNetの画像と特徴が異なる場合、性能が低下してしまうことが知られています。
本論文では、この問題を解決するべく下の図に示す階層的に自己教師あり学習を行うHierarchical PreTraining (HPT)を提案しています。HPTは、BaseデータでのSSL→SourceデータでのSSL→TargetデータでのSSLというように徐々に目的のタスクに近づけるように繰り返しSSLを行うというものです。
ここで、Baseデータとは大規模なデータセット (ImageNet)、Sourceデータとは比較的規模が大きくかつTargetデータに似た特徴をもつデータセット、Targetデータとは目的とするタスクでのデータセットのことを表しています。本論文では、HPTの有効性を、16個もの多様なデータセットを使用した検証実験を行うことで確認しました。HPTの有効性を検証する実験及びその結果について解説します。
実験設定
データセット
実験では5個のドメインを含む合計16個のデータセットを使用しています。
比較手法
HPTを含むSSLのプロセスが異なる4つの手法の比較を行っています。また、SSLとしてはMoCo-v2を使用しています。
- Base:ImageNetを使ったSSL (ただし、Batch Normalization層のみTargetデータを使って更新)
- Target:Targetデータを使ったSSL
- HPT (提案手法):ImageNetを使ったSSL→(Sourceデータを使ったSSL)→Targetデータを使ったSSL
- HPT-BN:ImageNetを使ったSSL→(Sourceデータを使ったSSL)→Targetデータを使ったSSLでBatch Normalization層のみ更新
実験結果
Separability analysis
SSLで学習した特徴量を用いて、線形識別器による識別を行った結果を検証します。
SSLで学習した特徴量を入力とした線形識別器でラベルを用いて学習を行います。線形識別器自体の性能はあまり強力ではありません。したがって、SSLで特徴量をうまく抽出できていればいるほど、線形識別器でも良い性能を出すことができます。
上図に実験結果を示します。各グラフの上にTargetデータが表示されていて、横軸が線形識別器の更新回数、縦軸はその性能 (AccuracyまたはAUROC)を表しています。
また、この実験では、HPT、HPT-BNにおいてSourceデータによるSSLは行わず、ImageNetによるSSL→TargetデータによるSSLとしています。実験結果から確認できたことは以下のとおりです。
- HPTは16種類中15種類のデータセットでBase、Targetと同等かそれ以上の性能に収束した。
- HPTはBase、Targetと比較して80倍速く学習が収束した。(HPTは5k steps、Base・Targetは400k stepsで収束)
- DomainNet quickdrawでは、HPTの性能はTargetの性能に劣っていて、原因としては、ImageNetとDomainNet quickdrawでの特徴の差が大きいことが考えられた。
Semi-supervised transferability
半教師あり学習を行ったときの各手法の性能を検証します。
自己教師あり学習を行った後、Targetデータからランダムに選択した1000個のラベル付きデータを使用してファインチューニングを行います。ただし、各クラスのデータが1個は含まれるように選択します。
上の図が実験結果を示しています。BはBase、TがTargetを表しています。
また、この実験でもHPT、HPT-BNにおいてSourceデータによるSSLは行わず、ImageNetによるSSL→TargetデータによるSSLとしています。
実験結果から確認できたことは以下のとおりです。
- HPTは16種類中15種類 (DomainNet quickdrawを除く)のデータセットでBase、Target以上の性能に収束した。
- HPT-BNがHPTの性能を超えることはなかった。
Sequential pretraining transferability
転移学習を行ったときの各手法の性能を検証します。
上の図が実験結果を示しています。BはBaseデータ、SがSourceデータ、TはTargetデータを使用したSSLを表しています。例えば、B+SはImageNetでのSSL→SourceデータでのSSLを、B+S+TがImageNetでのSSL→SourceデータでのSSL→TargetデータでのSSLを表しています。
さらに、各グラフ上部には使用したSourceデータとTargetデータが表示されています。例えば、左のグラフはBaseデータとしてImageNet、SourceデータとしてChexpert、TargetデータとしてChet-X-ray-kidsを使用していることを表しています。実験結果から確認できたことは以下のとおりです。
- B+S+T (つまり、HPT)でSSLを行った場合の性能が最も良かった。
Augmentation robustness
SSLを行うときのデータ拡張に対するロバスト性について検証します。使用するデータ拡張の種類を減らしてSSLを行い、学習後に得られた特徴量を用いて、線形識別器による識別を行います。
使用するデータ拡張は、RandomResizedCrop、ColorJitter、Grayscale、GaussianBlur、 RandomHorizontalFlipの5種類です。
また、この実験でもHPTにおいてSourceデータによるSSLは行わず、ImageNetによるSSL→TargetデータによるSSLとしています。
上の図が実験結果を示しています。各グラフでは、右に行くほど使用するデータ拡張の種類が減少しています。実験結果から確認できたことは以下のとおりです。
- HPTは、Targetと比較して、使用するデータ拡張を減らしても高い性能を維持した。
- HPTは、Chexpertデータを使用した場合(右図)、使用するデータ拡張を減らすと性能が低下したが、Targetの性能を下回ることはなかった。
Pretraining data robustness
自己教師あり学習で使用するTargetデータのデータ数に対するロバスト性について検証します。
上の図が実験結果を示しています。各グラフでは、右に行くほどSSLで使用するTargetデータのデータ数が増加しています。実験結果から確認できたことは以下のとおりです。
- HPTは、使用できるデータ数が少ないほど、他の手法より優れていた。
- HPT-BNは、使用できるデータ数が5k以下の場合、他の手法より優れていた。
まとめ
今回は階層的なSSLを行うHPTを紹介しました。検証実験を通して、HPTはシンプルでありながら強力な手法であるということがわかりました。HPTは、実装が容易でデータ量・計算量ともに節約できる実践的な手法であるため、今後の発展に期待したいです。
この記事に関するカテゴリー