最新AI論文をキャッチアップ

分布シフトの入念な分析!

分布シフトの入念な分析!

Domain Shift

3つの要点
✔️ 分布シフトについてのフレームワークを提案
✔️ 重要な三つの分布シフトについて定義
✔️ 様々な手法を包括的に比較評価

A Fine-Grained Analysis on Distribution Shift
written by Olivia WilesSven GowalFlorian StimbergSylvestre Alvise-RebuffiIra KtenaKrishnamurthy DvijothamTaylan Cemgil
(Submitted on 21 Oct 2021 (v1), last revised 25 Nov 2021 (this version, v2))
Comments: ICLR2022.

Subjects: Machine Learning (cs.LG); Computer Vision and Pattern Recognition (cs.CV)

code:  

本記事で使用している画像は論文中のもの、紹介スライドのもの、またはそれを参考に作成したものを使用しております。  

はじめに

機械学習モデルをアプリケーションとして幅広く利用するには、分布の変化に対しロバストであることが重要になります。例えば、ある病院群の画像で訓練されたモデルは、別の病院の画像に対しては適切に機能しないことがあるかもしれません。そのため、分布シフトに対するロバスト性を高めること、あるいは特定のモデルが分布シフトに対しどれだけロバストであるかを把握することは非常に重要な課題であり、このような問題に取り組むDomain Generalizationは活発に研究されています。

しかしながら、実際に発生しうる分布シフトを定義すること、複数の異なる分布シフトに対するアルゴリズムのロバスト性を評価するといった研究はほとんど行われていませんでした。

この記事で紹介する論文では、この重要な問題に取り組むため、分布の変化を細かく分析できるフレームワークを導入し、実世界に影響を与えうる三つの分布シフト(spurious correlation, lowdata drift, unseen data shift)を定義しました。さらに、二つの追加条件(ラベルノイズ、データセットサイズ)を導入し、実データ・合成データの両方について、19の既存手法の評価を行いました。(これらの貢献が評価され、この論文はICLR2022にAccept(Oral)されています。)

提案手法

Generalization評価のためのフレームワーク

はじめに、入力を$x$とし、これに対応する属性を$y^1,y^2,...,y^K$($y^{1:K}$)とします。ここで、属性のうち一つはラベルであり、これは$y^l$と表記します。例えば医療画像であれば、$y^l$は良性・悪性、$y^i(i \neq l)$は画像が撮影された病院の情報等が考えられます。また、$x$と$y^{1:K}$の合同分布を$p$とします。このとき、モデルの学習目的はリスク$R(f)=E_{(x,y^l)~p}[L(y^l,f(x))]$を最小化する分類器$f$を構築することです(Lは損失関数)。実際には、入力と属性のサイズは有限の数$n$なため、代わりに経験的リスク$\hat{R}(f;p)=\frac{1}{n} \sum_{\{(y^l_i,x_i)~p\}^n_{i=1}}L(y^l_i,f(x_i))$を最小化します。

分布シフトが起こりうる条件下では、モデルのtrain・test時でデータの分布$p_{train},p_{test}$は異なると考えられます。例えば、$p_{train}$と$p_{test}$は、異なる病院で撮影された画像であったり、撮影に使われた機器が異なるかもしれません。このとき、モデルは$\hat{R}(f;p_{train})$を最小化するように学習を行いますが、実際にはテスト時の経験的リスク$\hat{R}(f;p_{test})$を小さくすることが望ましいです。

ここで注目すべきなのは、$p_{train},p_{test}$は異なる分布であっても、真の分布$p$($x$と$y^{1:K}$の合同分布)に関係しているということです。そこで、この関係を表現するために、潜在因子$z$を利用して因数分解を行います。ここで、$z$について以下の関係が成り立つとします。

 

このとき、真の分布$p(y^{1:K},x)$について、以下のように因数分解を行うことができます。

 

つまり、真の分布は属性$y^{1:K}$の周辺分布$p(y^{1:K})$と、条件付き生成モデル$p(x|y^{1:K})$の積として表すことができます。

これを踏まえて、一つ重要な仮定をします。すなわち、分布シフトは属性の周辺分布が変化することによって生じると考えます。

つまり、分布シフトは$p(y^{1:K}) \neq p_{train}(y^{1:K}) \neq p_{test}(y^{1:K})$のような場合に生じ、一方で条件付き生成モデル$p(x|y^{1:K})$は、すべての分布について変化せず共有されると考えます。

すなわち、$p_{test}(y^{1:K}, x) = p_{test}(y^{1:K}) \int p(x|z)p(z|y^{1:K})dz$や$p_{train}(y^{1:K}, x) = p_{train}(y^{1:K}) \int p(x|z)p(z|y^{1:K})dz$が成り立ちます。

分布シフトについて

上述のフレームワークに従い、実世界で起こりうる代表的な三種類の分布シフトについて考えます。実際の例として、dSpritesデータセットの場合の分布シフトの事例は以下の図のようになります。

 

このとき、属性$y^1$は色(赤、緑、青)、$y^2$は形(ハート、楕円、四角形)にあたります。

テスト分布$p_test$について

テスト分布$p_test$では、属性$y^{1:K}$が一様に分布していると仮定します。つまり、$p_{test}(y^{1:K})=\frac{1}{\prod_i}|A^i|$とします。

これは上述の図(d)のように、全ての属性について一様に偏りなくデータが分布している状態にあたります。

擬似相関(Spurious correlation)

はじめに、属性が$p_train$では相関しているものの、$p_test$では相関していない場合について考えます。この疑似相関は、上述のフレームワークによれば、二つの属性$y^a,y^b$がtrain時に相関している(独立でない)場合に生じます。

具体的には、$p_{train}(y^a|y^1,...,y^b,...,y^K) > p_{train}(y^a|y^1,...,y^{b-1},y^{b+1},...,y^K)$が成り立つ場合にあたります。この疑似相関は、相関がある二つの属性の一方がラベルの場合に特に問題となります。

先程の図(a)の例で言えば、もしラベルが形状($y^2$)であった場合、$y^1=赤$ならば$y^2=ハート$、$y^1=緑$ならば$y^2=楕円$といったように、モデルは色に基づいて形状の予測を行ってしまうかもしれません。この場合、属性間に相関がない$p_test$ではGeneralizationに失敗してしまいます。

 

低データドリフト(Low-data drift)

低データドリフトは、$p_{train}$では属性値が偏っているものの、$p_{test}$では偏っていない場合にあたります(先程の図(b))。この分布シフトは、データセットの収集が属性値によって偏りがある場合に生じます。上述のフレームワークによれば、$p_{train}(y^a=v) << p_{test}(y^a=v)$の場合にあたります。

未見データシフト(Unseen data shift)

低データドリフトの特殊なケースとして、特定の属性値のデータがtrain時に欠落している場合が考えられます。上述のフレームワークによれば、これは以下の式で表されます。

 

より複雑な分布シフトについて

ラベルの周辺分布$p(y^l)$は、特定の属性値の確率$p(y^a)$と条件付き確率$p(y^l|y^a)$からなる二つの項に分解することができ、これは$p(y^l)=\sum_{y^a}p(y^l|y^a)p(y^a)$と表されます。

このとき、擬似相関は$p(y^l|y^a)$を制御し、低データドリフト・未見データシフトは$p(y^a)$を制御します。そのため、より複雑な分布シフトは、これら三つの分布シフトを構成要素として記述することができます。 

追加条件

これらの分布シフトに加え、実際の環境では以下の二つの追加条件が生じる可能性があります。

ラベルノイズ

ラベルノイズは、アノテーターの意見の相違やエラーがある場合に生じます。これは、観測された属性(例えばラベル)がノイズによって破壊されるものとしてモデル化されます。これは、破壊されたラベルを$\hat{y}^i$、真のラベルを$y^i$とすると、$\hat{y}^i~c(y^i)$と表されます。

データセットサイズ

trainデータセットのサイズの制約によって、モデルの性能が変化することが考えられます。

元論文では実験の際、これらの条件を導入してモデルの評価を行っています。

ロバスト性を高めるための手法

モデル学習時に$p_train$にのアクセスできる条件下で、真の分布$p$やテスト分布$p_{test}$におけるリスクを小さくすることが、分布シフトに対するロバスト性を高めることにあたります。この目的を実現するためには、以下のような方法があります。

1.重み付けリサンプリング(Weighted resampling)

trainセットについて、重要度の重み$W(y^{1:K}) = p(y^{1:K})/p_{train}(y^{1:K})$を用いてリサンプリングを行います。

このとき、$i$番目のデータポイント$(y^{1:K}_i,x_i)$は、$1/n$ではなく$W(y^{1:K}_i)/\sum^n_{i'=1} W(y^{1:K}_{i'})$の確率で選択されます。

実際には真の分布$p(y^{1:K})$にアクセスできるとは限らないため、属性の全ての組み合わせが一様にランダムに発生すると仮定することが多いです。

2.ヒューリスティックなデータ補強(Heuristic Data Augmentation)

重み付けリサンプリングでは、同じサンプルが何度も再使用される可能性があるため、ヒューリスティックなデータ増強を行うことで、オーバーフィットを軽減します。

3.学習されたデータ増強(Learned Data Augmentation)

真の分布が属性$y^{1:K}$の周辺分布$p(y^{1:K})$と、条件付き生成モデル$p(x|y^{1:K})$の積として表すことができることを踏まえて、trainデータから条件付き生成モデル$\hat{p}(x|y^{1:K})$を学習し、新たな合成データのサンプリングを行います。

このとき、増強されたデータ分布$p_{aug}=(1-\alpha)p_{train}+alpha \hat{p}(x|y^{1:K})p(y^{1:K})$から得られたデータセットで、教師付き分類器の学習を行います。

4.表現学習(Representation Learning)

別の因数分解として、$p_{train}(y^{1:K}, x) = \int p(z|x)p_{train}(y^{1:K}|z)dz$が考えられます。

これを元に、trainデータから教師なしで$p(z|x)$の表現学習を行い、潜在変数$z$を元に予測を行う分類器ヘッド$p{train}(y^l|z)$を学習することができます。適切に表現学習を行うことができれば、特定の属性分布の影響を受けることなく、$p_test,p$へのGeneralizationが実現できます。

実験設定

実験では、分布シフトに対するモデルのロバスト性を向上させるための手法について、19個のアルゴリズムを評価します。

アーキテクチャ

実験に用いるモデルのアーキテクチャは以下のとおりです。 

  • ResNet18,ResNet50,ResNet101
  • ViT
  • MLP

また、学習時には重み付きリサンプリングを行い、$p_{train}$の低い確率の部分からオーバーサンプリングします。

ヒューリスティックなデータ補強

ロバスト性向上のため、以下のデータ増強手法について分析を行います。

  • 標準的なImageNetの増強
  • JSDなしのAugMix
  • RandAugment
  • AutoAugment

学習されたデータ増強

条件付き生成モデル$p(x|y^{1:K})$を近似し、生成された画像をデータ増強として利用します。

近似にはCycleGANを利用します。

Domain generalization

Domain generalization手法では、属性に依存しない表現$z$を回復することを目的としており、以下の手法について実験を行います。

  • IRM
  • DeepCORAL
  • domain MixUp
  • DANN
  • SagNet

Adaptive approaches

Adaptive approaches(適応的手法)として、以下の手法について実験を行います。

  • JTT
  • BN-Adapt

表現学習

表現学習手法として、以下の手法について実験を行います。

  • β-VAE
  • ImageNetでの事前学習($D_{train}$に追加データを使用)

データセットとモデル選択

実験には、6つの画像分類データセットを利用します。

  • DSPRITES
  • MPI3D
  • SMALLNORB
  • SHAPES3D
  • CAMELYON17
  • IWILDCAM

ここで、単純な合成データセット(DSPRITES・MPI3D・SHAPES3D・SMALLNORB)ではResNet18を、複雑な実世界データセット(CAMELYON17・IWILDCAM)ではResNet50を使用します。実験では5つのシード値について実行しています。

実験結果

偽相関(Spurious Correlation)、低データドリフト(Low-data drift)、未見データシフト(Unseen data shift)における結果は以下の通りです。

また、ラベルノイズ・データサイズ制約が存在する場合の結果は以下の通りです。

 

全体として得られた結果は以下のようにまとめられます。

  • 常に最高の性能を示す手法は存在しない。
  • 事前学習は様々なデータシフト・データセットにわたり強力なツールとなる。
  • ヒューリスティックなデータ増強は必ずしも結果を改善するとは限らない。
  • 学習されたデータ増強は様々な条件と分布シフトにわたり効果的である。
  • Domain generalizationによる性能向上は限定的だった。
  • 最適なアルゴリズムは、詳細な条件によっては異なるかもしれない。
  • 考慮する属性は結果に直接影響を与える。

実験結果のより詳細な内容については元論文を参照ください。

また、実用的な情報として、論文では以下のヒントが勧められています。

  • ヒューリスティックなデータ増強が不変性を促進するなら、それを利用する。
  • ヒューリスティックなデータ増強が役立たない場合、学習されたデータ増強を利用する。
  • 事前学習を用いる。
  • 複雑な手法による改善は限られている。

最後に、論文では実験の結果から、以下のように考察を行っています。

  • データセットだけから、最適な手法を事前に決めることはできない。
  • 分布シフトについての知識がある場合には、それに焦点を当てるべきである。
  • 様々な条件で手法を評価することが極めて重要である。

まとめ

この記事では、分布シフトについての包括的なフレームワークの提案、並びに様々な手法についての実験による詳細な分析を行った論文について紹介しました。このフレームワークとベンチマークは、分布シフトに関連する手法の評価のための有用なツールとなりうるでしょう。

 

 

記事の内容等について改善箇所などございましたら、
お問い合わせフォームよりAI-SCHOLAR編集部の方にご連絡を頂けますと幸いです。
どうぞよろしくお願いします。

お問い合わせする