最新AI論文をキャッチアップ

【Lambda Networks】Attentionは要らない!? これが新世代のラムダネットワークだ!

【Lambda Networks】Attentionは要らない!? これが新世代のラムダネットワークだ!

機械学習

3つの要点
✔️ Attentionを代替するLambdaNetworksが提案され、ICLR2021に採択
✔️ Attentionを用いずに、広い範囲の相互作用を考慮
✔️ 計算効率と精度の両面でAttentionやConvolutionモデルを凌駕

LambdaNetworks: Modeling Long-Range Interactions Without Attention
written by Irwan Bello
(Submitted on 17 Feb 2021)
Comments: Accepted by the International Conference in Learning Representations 2021 (Spotlight)

Subjects: Computer Vision and Pattern Recognition (cs.CV); Machine Learning (cs.LG)

code:  

本記事で使用している画像は論文中のもの、またはそれを参考に作成したものを使用しております。

はじめに

データの長期にわたる依存関係をモデル化することは,機械学習の中心的な問題として研究されてきました.Self-Attentionは,そのための一般的なアプローチとして近年注目を集めていますが,非常にコンピューティングコストが高く,長いシーケンスや画像などの多次元データへの適用が進んできません.Linear Attentionメカニズムは,この高いメモリ要件に対するスケーラブルな解決策を提案していますが,ピクセル間の相対的な距離やグラフのノード間のエッジ関係などの内部データ構造をモデル化することができていません.

そこで,本研究で提案しているLambdaNetworksではこの2つの問題を解決しています.本研究ではQueryと構造化されたContext要素との長距離の相互作用を少ないメモリコストでモデル化するLambda Layerを提案しています.Lambda LayerContextLambdaと呼ばれる線形関数に変換し,それを対応するQueryに直接適用します.Self-AttentionQueryContext要素の間の類似性カーネルを定義するのに対し,Lambda LayerContext情報を固定サイズの線形関数(すなわち行列)に要約するため,メモリを大量に消費するAttention mapの必要性を回避することができます.

後述する実験結果からも分かる通り,LambdaNetworksConvolutionAttentionに基づいたモデルよりも精度の面で大幅に上回っていると同時に,計算効率が高く高速になっています.このLambda Layereinsum演算と畳み込みカーネルで簡単に実装でき,Github上でも公開されています.また,既存のResNetなどのモデルの一部をLambda Layerに置き換えるだけでも性能が向上することが分かっています.

Lambda Layer解説

全体イメージ

長距離にわたる相互作用を捉えるというのは,上図におけるGlobal Contextをいかに考慮するかという問題と同じ意味です.通常のConvolutionでは小さなカーネルサイズ部分しか考慮できず,Global Contextを考慮できないということが問題でした.そこで最近よく用いられるAttentionでは,各ピクセルとそのほかの全体のピクセルの重要度を算出したAttention Mapを導入することにより,Global Contextを考慮しています.ここで,上図ではLocal Contextが画像の一部のみを示していますが,Self-AttentionLambdaではしばしばLocal Contextは画像全体,すなわちGlobal Contextと同じサイズにまで広げて利用していることに注意して下さい.さて,このAttention MapGlobal Contextを考慮できるようになりましたが,同時に一つのピクセルごとに画像全体に対して異なるAttention Mapを算出する必要があるため,非常にコンピューティングコストが高いことが知られています.そこで,Lambdaでは上図からも分かる通り,Attention Mapよりも抽象的に情報を集約した後に,一度だけGlobal Contextに対するLambdaを求めています.詳細は後述しますが,大まかにAttentionLambdaの違いをイメージできましたでしょうか.

Attention

Lambdaについて説明するためには,Attentionとの違いを比較することが最も分かりやすいため,まずAttentionについて復習したいと思います.

上図がこれ以上説明する必要がないくらいにAttentionに関して簡潔にまとまっています.ここではmemorykeyvalueという形で二度使われています.Keyは探索用,valueは値そのものとして用いられ,この仕組みはLambdaでも同じように利用されています(memoryという言葉は,lambda中ではcontextと呼ばれています).Self-Attentionではinputmemoryが同一の内容ですが,Lambdaでも同様にinputcontextが同じ内容であることが多いです.Attentionでのポイントは,querykeyで内積を取ることで,ベクトルが似ていれば値が大きくなる性質を利用して関連度を算出し,その後softmaxで非線形化してAttention Mapを得,valueと積を取ってoutputとしていることです.これは,式にすると以下のように表すことができます.

以上のAttentionの流れを,さらにLambdaの説明のために書き直すと下図のようになります.ここで,contextとはmemoryと同じ意味合いであることに注意して下さい.

Lambda Layer

前置きが長くなりましたが,ここでようやくLambda Layerの説明に入ります.まずAttentionの場合と同様に図示したものが下図になります.

Attentionの図と比較すると違いが分かりやすいでしょう.違いは,QueryKeyの内積を取ったものにsoftmaxを取っているのではなく,Keyのみに対してsoftmaxを取っていることです.このKeysoftmaxで非線形化し,valueとの積を取ることが意味しているのは,著者によればcontentsummarizeしているということだそうです.前に説明したLambdaのイメージを思い出して下さい.Attentionに比べて抽象的に情報がまとまっているかと思います.実際,Contextからλに変換したことによって次元がmからkに減り,次元削減のような働きが起きていることが分かります.そして,最後にこの情報が集約されたλとQueryの積を取ってoutputとしています.Attentionを参考にしつつも,Attentionとは全く異なるアーキテクチャを提案していることが分かりますでしょうか.

position lambda

上記で説明したlambdaは,論文中ではcontent lambdaと呼ばれています.お分かりの通り,実はcontent lambdaはコンテキストをsummarizeすることができた一方で,画像中の位置関係を捉えることができていません.そこで,相対的な位置関係も考慮するためにcontent lambdaと合わせて導入されたのが,position lambdaとなります.

Eは,次元N×M×Kのインデックス付きテンソルになります.これらは学習されたパラメータの固定セットであり,位置埋め込みとしての役割を果たしています.ただし,実際には埋め込みが入力に直接は位置されるため,サイズはM×Kであり,これがN個の入力にたいして掛けられることになります.このEには埋め込みと呼ばれるベクトルが関連づけられていて,このN×MおよびK次元のベクトル行列がレイヤーごとに学習され,例ごとに計算内容が変更されないことがポイントになります.このため,N×M×Kはメモリ内で固定され,バッチサイズに応じて大きくなることがありません.

全体像

ここまで説明してきたcontent lambdaposition lambda,それに最後にQueryとλの積を取るところまでを図示したものが上図になります.そして,式に直すと以下のようになります.今までの説明を踏まえれば,特に問題はないと思います.

 

実装コード

Lambda LayerGithubhttps://github.com/lucidrains/lambda-networks)でPyTorch実装したものが公開されていますが,論文中でも簡単にコードに関する説明があります.

上記コードからも,非常に簡単にLambda Layerを定義できることが分かると思います.ここでeinsumとは,アインシュタインの縮約記法に基づいて,多次元線形代数配列演算を簡略形式で表せるpytorchに組み込みの関数で,今までtorch.matmultorch.transpose, torch.view, torch.squeezeなどで次元を合わせてきましたが,einsumを使うことで簡単にテンソルの積の演算をすることができます.

Lambda Layerは非常にシンプルな仕組みでデザインされているため,ResNetなどの既存モデルの一部だけをlambda layerに書き換えることも容易に行うことができます.論文中でも言及されているLambdaResNetも有志の方がPyTorch実装をGithubで公開しています(https://github.com/leaderj1001/LambdaNetworks).

実験結果

著者はLambda Layerを畳み込みネットワークとSelf-Attentionと比較しています.

上図からも分かる通り,Lambda Layerはパラメータ数が他の手法と比べて非常に少なく,15, 16Mしかありません.しかし,ImageNetの分類タスクにおいて他の手法を上回る精度を達成することができています.

さらに,下図ではメモリコストと精度が示されていますが,ここでもLambda Layerが,特にSelf-Attentionと比べて非常に少ないメモリで高い精度を達成していることが分かると思います. 

さらに下図では精度と訓練にかかる時間のトレードオフが図示されています.ここでも,LambdaResNetEfficientNetよりも少ない訓練時間で高い精度を達成していることが示されています.

最後に

本論文では,LambdaNetworksというAttentionを用いずに画像全体の情報も考慮できる新しいモデルが提案されました.Attentionの高いコンピューティングコストという問題を解決しつつも,高い精度を達成しており,近年のAttentionが非常に高い注目が集まっている流れを変え得る研究結果を示しました.実装自体は非常にシンプルなので,皆さんも一度試してもよいかもしれません.

記事の内容等について改善箇所などございましたら、
お問い合わせフォームよりAI-SCHOLAR編集部の方にご連絡を頂けますと幸いです。
どうぞよろしくお願いします。

お問い合わせする