因果論の独立メカニズム仮設を深層学習で実現してみた
3つの要点
✔️ 環境の変化に関連する部分のみが反応する「独立メカニズム」を実現したRIMsを提案
✔️ AttentionとLSTMを組み合わせた競争を誘発する機構を組み入れる
✔️ 幅広い実験を通じてRIMsによる汎化性能の向上を検証した
Recurrent Independent Mechanisms
written by Anirudh Goyal, Alex Lamb, Jordan Hoffmann, Shagun Sodhani, Sergey Levine, Yoshua Bengio, Bernhard Schölkopf
(Submitted on 24 Sep 2019 (v1), last revised 17 Nov 2020 (this version, v6))
Comments: Published on arxiv.
Subjects: Machine Learning (cs.LG); Artificial Intelligence (cs.AI); Machine Learning (stat.ML)
code:
本記事で使用している画像は論文中のもの、またはそれを参考に作成したものを使用しております。
はじめに:世の中に存在する独立構造
物理プロセスは単純なモジュール構造を持つサブシステムの組み合わせによって生み出されると考えられます。例えば、二つのボールは重力によって相互に連動しているが、たまに衝突して強く相互作用する以外には、ほとんど独立したメカニズムとしてモデル化できます。
紹介する論文の提案手法は、回帰的独立メカニズム(Recurrent Independent Mechanisms, RIMs)と名付けて、人間が自然界を認知する際に用いる独立メカニズムを持つモジュール構造を、深層学習のアプローチで実現しようとしています。
RIMsをデザインする上で、独立するメカニズムであることと疎なコミュニケーションを行うことが重要です。モジュール間の独立性は、因果推論における局所的な干渉ができることの前提条件となっており、非常に重要な性質と考えられています。一方で、モジュール間で疎なコミュニケーションを行うことは密な交換を避けることであり、入力に関連するモジュールのみが反応する独立性に繋がります。具体的に、RIMsのアーキテクチャがどうデザインすべきなのか?RIMsの汎化性能をどの切り口から評価できるのかについて、詳しく説明していきます。
提案手法:Recurrent Independent Mechanisms(RIMs)
RIMsの全体像とアーキテクチャデザインを概観した後、RIMsの詳細を4ステップに分けて具体的に説明します。
全体のモデルをk分割した個々のサブシステムは、反復して観測した系列情報の変化を捉えられるようにしています。これらのサブシステムを循環独立メカニズム(Recurrent Independent Mechanisms,RIMs)と呼び、各RIMはそれぞれ異なった機能を持つようにデータから学習する仕組みになっています。
# k個目RIMは時刻tにおいて価値ベクトルh_(t,k)とパラメータθ_kを持ちます。ただし、t=1,...,Tである。
Figure 1にRIMsの全体図を示しました。著者たちは、各RIMsが分化して独自のダイナミックスを持つながら、たまに他のRIMや選択した入力の埋め込みとインタラクションするようにデザインしました。特に注意機構(Attention Mechanisms)を用いることで、RIMsは少数のキー/バリュー変数のみを操作することが可能であり、パラメータの総数が小さく抑えています。この分化とモジュール化は計算や統計上にメリットだけでなく、個々のRIMが計算を支配することを防ぎ、計算の再結合や再利用しやすいように単純な要素に分解することを促します。
RIMsアーキテクチャによって、1つの大きな均質なシステムを学習するよりロバスト性の高いシステムを学習できることに期待しています。その際にRIMsには他のRIMが変化しても独自の機能を維持すべきといったさまざまな性質が必要と考えられており、論文の付録Aに詳細が記しています。
独立変数の処理を行うキー/バリュー注意機構
まず入力を受け取る段階において、各RIMは自身と関係する入力があった時のみ、励起と更新を行うことが望ましいです。各時刻においてk個のRIMsがそれぞれ入力との関連度をAttention機構を用いて計算し、リソースを競争する仕組みを取り入れます。これでもし、はじめににて紹介したデータの裏に独立した物理現象によって支配されているという仮説が正しければ、RIMsは自然と独立メカニズムを学習できると言われています(Parascandolo et al. 2018)。
多くの分野で有効性が示されているSoft-Attentionを使います。RIMsからQuery(Q)を、入力情報からKey(K)とValue(V)を生成し、各RIMが入力情報との関係度は下記のSoft-Attentionで算出できます。
# 各RIMsの入力と出力が複数のオブジェクトとなった時に、RIMs自身のQuery(Q)を用いて入力のKey(K)とValue(V)にSoft-Attention計算を通じて動的に入力するオブジェクトを選択できるようになります。
Top-Down形式でRIMsの活性化を選択
提案手法は、現時刻のRIMsと入力との相互作用の結果によって活性化するRIMsが決まることで、入力に関係するRIMsを動的に選ぶことを学習します。各時刻において2.1節で述べたAttention機構で得られたScoreの上位kつのRIMsを選びます。つまり、各時刻において入力から読み込みできるのは他のRIMsより高いAttention Scoreを得たkつのRIMsになり、また選ばれたRIMsしか更新できません。
このRIMsから入力にアクセスするTop-Down形式で活性化するRIMsを選択する過程はFigure 1の右に示されています。式(2)のようにSoft-Attentionの計算におけるKeyとValueは、入力XをそれぞれマトリックスWで線形変換して得られます。またQueryは個々のRIMsが独自の変換マトリックスWを用いて算出します。これらのマトリックスWはRIMsのパラメータとなり、Queryを算出するWは各RIMが異なります。
また、系列情報の入力Xは時刻tで簡単に処理できるだけでなく、画像といった空間構造を持つ入力でも、CNNのような埋め込みネットワークの出力をXとして扱えば、同様な手順で活性化するRIMsの選択ができます。
独立したRIMのダイナミックス
ここではRIMs間で情報を流さない独立したダイナミックスを考えます。形式はいくつか考えられますが、著者たちはGRU(LSTM)アーキテクチャを採用します。次式で示すように、時刻tにk個目のRIMの状態を潜在状態h_(t,k)とし、2.2節で説明したAttention計算で得られたA_kを入力とします。また、Stは活性化したRIMsの集合であるため、活性化したRIMのみが独自のGRU(LSTM)を通して更新します。
RIMs間のコミュニケーション
基本的にRIMsは独自のパラメータを持って学習を進めていますが、Attention機構を用いて活性化したRIMsが他のRIMから情報を取得するチャンスを用意されています。この理由は、活性化されていないRIMsは入力と直接に関係しなくても、活性化されたRIMにとって役立つ情報を含んでいることが考えられるからです。
上の式から分かるように、2.1節で紹介したAttention機構をベースに、勾配消失を防ぐために残差接続(residual connection)を導入しました(Soft-Attentionの最後に追加した潜在状態h_(t,k))。また、RIMs間のスパースな情報交換を実現するには、同様にtop-k Attentionを用いました。
関連研究
- Neural Turing Machine(NTM) and Relational Memory Core(RMC)
NTMは独立したメモリを注意機構で用いて読み書きを行っています。RIMsの入力情報はメモリの一部にしか影響を与えないところは、NTMと同じアイデアを持っています。RMCはマルチヘッド注意機構を用いて複数のメモリ間で情報を流すようにしているが、RIMsはできるだけ独立したメモリを保つようにします。RIMsはそれぞれのダイナミックスを持つが、RMCはお互いに影響を及ぼします。 - Separate Recurrent Models
EntNetと、IndRNNは独立した循環モデルとして考えられるが、RIMsはスパースにコミュニケーションを行うことを注意機構を使って実現しています。 - Modularity and Neural Networks
ニューラルネットを複数のモジュールから構成されていると考えられますが、RIMsは複数のモジュールを活性化でき、お高いに情報を交換するようにしています。 - Computation on demand
他にもたくさんのアーキテクチャがRNNの潜在表現(h)を一時的に休止させるものがありますが、RIMsは入力情報を選択していることがそれらとの異なります。
実験
ここでは変化が生じる環境またはモジュール化したタスクにおいて、RIMsが汎化性能に貢献することを示した上で、その原因を調べていきます。時系列への汎化とオブジェクトベースの汎化をはじめ、それら両方を求める複雑な環境における検証を行います。
時系列パターンの汎化性能
まずは、時系列データの異なるパターンに対して異なるRIMが活性化されることを可視化します。Figure 2の中間部で時系列情報が入力されない時にRIMsが決まったパターンで活性化されることが分かります。次にコピータスクと系列MNIST分類タスクについて説明します。
短い文字列の次に長い空白を入力した後、文字列の再現を評価するコピータスクでは、RIMsは学習時に空白が50ステップでテスト時に200ステップの空白に汎化できるが、比較手法のLSTM、NTMとRMCは上手くいきません(Table 1左)。これを実現するのに2章で紹介したRIMsの要素が必要であることを追加実験で示しています(付録D.1)。
MNISTの画像をピクセルの[0,1]系列入力で分類するタスクにおいて、14x14の画像で学習した後、16x16, 19x19, 24x24と異なる解像度で評価実験を行います。コピータスクよりも情報の長期記憶が求められている中で、RIMsはTransformersをはじめ有力な手法よりも優れており、関連した部分にのみ反応するRIMsの仕組みが解像度の変化で生じる系列長の変化にロバストであることが言えます(Table 1右)。
オブジェクトベースの汎化性能
次に、異なるサイズや速度のバンディングボールの動きを予測タスクを用いてRIMsのオブジェクトベースの汎化性能を評価します。
Figure 3左は、15フレームの動画を与えた後50フレームまでのボールの動きを予測するタスクの結果を示しています。訓練時のボール数がテスト時に変化したり、入力の一部分が見えなかったりしてもRIMsはベースラインより再構成誤差が低い結果だとなりました。また、三つのボールで学習し、テスト時にボール数が1~6に変化してもRIMsは全ての比較手法より優れており、オベジェクトベースの汎化性能が高いことがわかります(Figure 3右)。
また、Figure 4よりグリッド環境のPickupタスクにおいてRIMsはBaseline LSTMより既知や未知のいずれの障害物に対してもロバスト性が高いことがわかります。
複雑環境における汎化性能
時空間的な汎化性能は4.1節と4.2節にて確認しました。これらを組み合わせたより複雑な環境における汎化性能を、Atari gamesで検証します。
BaselineはLSTMを用いて変化する環境を捉えた上で、PPO Agentでポリシーを出力する手法とします。学習済みPPO Agentを用いるので、単にLSTMをRIMsに置き換えた手法で評価します。Baselineの精度より優れた場合は0以上となる評価指標を用いたFigure 5のような結果となり、RIMs-PPOはLSTM-PPOより多くのタスクにおいて精度が高いことがわかります。特にDemon Attackといった環境の変化が現在の状態に対して大きく関連するタスクにおいて、RIMsの環境に対して適切なRIMが選べる仕組みが大きく精度に貢献できます。
Discussion & Ablation
提案手法RIMsが入力の変化を上手に捉えられることが実験でわかりました。著者たちは追加実験(付録参照)を通じて、1.スパースな活性化 2.入力を処理するAttentionが必要であることや 3. RIMs間のコミュニケーション(2.4節)が精度に貢献することが示しました。
まとめ
世の中に存在する多くのシステムは独立したメカニズムのもとで動作しながら、たまに交互作用するようになっています。一方で多くの機械学習の手法は逆に全てが入力に関連するようなバイアスを暗黙に導入していると著者たちが主張しています。今回紹介したRecurrent Independent Mechanisms(RIMs)はあくまで独立メカニズムを深層学習で実装した一例に過ぎないが、大量な実験でその有効性を示しています。今後、読者が環境変化時(特にOut-Of-Distribution)の汎化性能で悩む際に、独立メカニズムのアイデアを思い出していただくだけでも、本論文を紹介した甲斐があると考えています。
この記事に関するカテゴリー