超長いプロンプト文を圧縮してメモリを抑えるGoogleの高性能LLM
3つの要点
✔️ LLMは入力できるプロンプト長に限度があり、長い文章の要約ができない等の問題
✔️ プロンプトをパラメータに圧縮して記憶する部分を導入したLLMの注意機構を提案
✔️ 無限の長さのプロンプトを処理可能に。本の要約タスクで最高性能を達成
Leave No Context Behind: Efficient Infinite Context Transformers with Infini-attention
written by Tsendsuren Munkhdalai, Manaal Faruqui, Siddharth Gopal
(Submitted on 10 Apr 2024)
Comments: 9 pages, 4 figures, 4 tables
Subjects: Computation and Language (cs.CL); Artificial Intelligence (cs.AI); Machine Learning (cs.LG); Neural and Evolutionary Computing (cs.NE)
code:
本記事で使用している画像は論文中のもの、紹介スライドのもの、またはそれを参考に作成したものを使用しております。
はじめに
文章(テキスト)を理解するには全体の文脈を理解したうえで、一つ一つのトークン(文字の塊)を理解していく必要があります。
どのくらいの長さの文脈を理解できるかを大規模言語モデル (LLM)ではコンテキストウインドウサイズと呼んでいます。プロンプトを十分に理解するには、入力プロンプト長に対して、十分なコンテキストウインドウサイズがなければなりません。したがって、コンテキストウインドウサイズは、十分処理可能な入力プロンプトの長さを同時に意味します。
2024年5月、Open AIが最新のLLMとして、GPT-4oを発表しましたが、GPT-4oのコンテキストウィンドウサイズは12万8千トークンと発表されています。Open AIのブログによると、12万8千トークンはテキスト300ページ分ぐらいの量です。2022年11月にChatGPTが発表された当時のコンテキストウィンドウサイズは4千ですので、GPT-4oは初期のChatGPTに比べ32倍のコンテキストウィンドウサイズになります。
このように、当時に比べコンテキストウィンドウサイズが長くなっているのは、長い文脈を理解した処理が求められているからです。
たとえば、長い文脈を把握できないと、長い文章の要約がうまくいかなかったり、文脈内学習(in-context learning)においてタスクの長い説明を考慮しきれず思ったような回答をしてくれない、十分なバリエーションの回答の見本をLLMに教えることができない、RAG(Retrieval-Augmented Generation、検索した情報をプロンプトなどに埋め込む)によって取得した関連文書情報を十分に与えられないという問題が起きえます。
LLMに人間がやりたいことを明確に伝え、十分な情報を与え、ハルシネーションを抑えるには、コンテキストウインドウサイズを大きくできることが望ましいです。最近では、LLMの文脈で、メガプロンプト(非常に長いプロンプト)やメニーショット(文脈内学習でたくさんの教示例を与えること)という言葉が登場してきています。さらには、正しい事実に基づいて回答する能力については、ファインチューニングよりもRAGが向いているという報告も出てきています。コンテキストウィンドウを目一杯活用したLLMの高性能化が期待されているといえるでしょう。
よって、テキスト300ページ分のコンテキストウィンドウサイズがあるといっても、文脈内学習で解きたいタスクがより大規模に、より複雑になればなるほど、必要なコンテキストウィンドウサイズは増えていくと考えられます。
そのため、コンテキストウィンドウサイズ(処理可能なプロンプトの長さ)は、LLMにとって重要です。
このコンテキストウィンドウサイズが無限長のLLMがあったらどうでしょうか?
非常に長い文章を読まずに、自分の聞きたいところだけをLLMに問い合わせて、自分のペースで知りたいことを知れるでしょうし、人間側が情報を精査せずにありうる関連文献を突っ込んで最適な回答をもらうといったことも期待できるでしょう。LLMの事前学習のやり直しやファインチューニングをしなくても、手軽に文脈内に教師データを与え、文脈内学習でチューニングできる余地が広がるでしょう。
今回紹介する論文は、このコンテキストウィンドウサイズを無限大にするというGoogleのチャレンジングな技術に関するものです。LLMのベースとして使用されている機械学習モデルのTransformerではプロンプトの長さの二乗に比例し、計算量、メモリ使用量が増大するという問題がありましたが、提案手法のInfini-Transformersでは、プロンプト自体もパラメータに圧縮して記憶してしまうことで、メモリ使用量を抑制します。
では、モデルと評価結果を紹介してゆきます。
新しいアテンション機構Infini-Attentionを導入したTransformer:Infini-Transformers
課題:トランスフォーマのアテンション機構はキー、バリューを保持するためのメモリ使用量が多い
LLMで使われているトランスフォーマのアテンション構造では、入力プロンプト内の一つ一つのトークン(クエリ)を理解する際に、理解の対象となるトークン以外の前後のトークン(キー)との類似性を計算し、この類似性に応じてトークンの特徴量(バリュー)を更新することで、文脈(理解の対象自体を含む前後のトークン)を加味した理解をしてます。
このクエリとキーの類似度はクエリとキーの内積で計算されます。この行列サイズがクエリの長さ×クエリの長さになるので、プロンプトの長さに対し二次関数的に計算量、メモリ使用量が増加します。
たとえば、キーとバリューを保持するのに、ある5000億モデルパラメータのLLMのモデルでは、文脈の長さが2048(トークン)で、3テラバイトのメモリを使用したという報告があります。
計算機のメモリを超えるようなプロンプトの長さになると、物理的にデバイスの記憶可能容量を超え、受け付けることはできませんし、計算量が多すぎると、LLMから全然応答が返ってこないという問題が生じます。
以上の通り、従来のアテンション機構は、プロンプト(入力トークン列)が長くなるにつれて大きくなるキー、バリューの行列をつくるため、メモリ使用量がプロンプトの長さに応じて増大し続けるという課題があります。
解決策:圧縮メモリで前のキー、バリューを保持する機構を組み込んだアテンション機構 Infini-Attention
・解決アイデア
そこで、本論文では、入力プロンプトを分割して、前から順に処理してゆき、固定サイズのパラメータを持った圧縮メモリで、前のキーとバリューを保持するアテンション機構を提案しています。メモリ使用量は圧縮メモリのパラメータ(行列サイズ)に依存することになりますが、固定サイズのパラメータにすれば、プロンプトが長くなっても、メモリ使用量や計算量に歯止めがかかります。
・Infini-Attentionの全体構造
Infini-Attentionでは、入力トークン列をセグメント列に変換し、各セグメント内で内積を計算します。セグメント列は、入力トークン列を長さNのセグメントに分割したもので、各セグメントはインデックスSで区別します。
このようにローカルに注意処理(アテンション)を行う手法は既存にもありますが、通常のローカルアテンション手法では、処理した以前のセグメントのキー、バリューは捨ててしまいます。Infini-Attentionでは捨てずに圧縮メモリで保持します。
いわば、グローバルなアテンション(過去のセグメントを考慮したアテンション)とローカルのアテンション(現在の処理対象のセグメントに対するアテンション)を組み合わせた構造となっています。この構造を図1に示します。
・ローカルなアテンション
図1の紫のブロックは、ローカルなアテンションを実行するブロックです。対象の入力セグメントのクエリ $Q_s$ に対して通常のscaled dot product attentionを実行します。つまり、対象のセグメントの"セグメント長さN×キーの次元数の行列" $K_s$ と"キーの次元数×バリューの次元数の行列" $V_s$ から正規化した内積を計算し、”N×バリューの次元数の行列” $V_s$ との積を計算することで、アテンション文脈 $A_{dot}$ を得ます。
・グローバルなアテンション(圧縮メモリを用いたアテンション)
この紫のブロックだけですと、ローカルのアテンションだけになってしまいますので、グローバルなアテンションを計算する緑のブロックがあります。このブロックは過去のセグメントに関してのキー、バリューに基づくアテンション計算結果を保持しています。これは過去のキー、バリューをすべて圧縮して記憶するというより、1セグメントの処理が終わるたびに状態を更新していくイメージで、”前”の状態を常にアップデート(図1のUpdate)しておくイメージです。アップデータされた圧縮メモリを検索(図1のRetrieve)し、ローカルアテンションの出力と組み合わされます(図1のConcat)。
圧縮メモリの検索では、対象セグメントのセグメント長さN×キーの次元数のサイズのクエリの行列 $Q_{s}$ とキーの次元数×バリューの次元数の圧縮メモリの行列 $M_{s-1}$ 、正規化項 $z$ に基づき、圧縮メモリの中身(N×キーの次元数の行列) $A_{mem}$ を得ます。
圧縮メモリの更新では、圧縮メモリの行列と正規化項を対象のセグメントのキー、バリューに基づく計算結果で更新し、次のセグメントの処理で利用します。
更新後の圧縮メモリの行列=更新前の圧縮メモリの行列+活性化関数と対象のセグメントのキー $K$ 、バリュー $V$ に基づく行列積。
この加算した項は連想割付作用素(associative binding operator)と呼ばれています。
Inifini-Transformersでは、上記の圧縮メモリの処理を改善した既存手法であるDeltaルールという方式を採用しています。Deltaルールでは、新しいバリューから圧縮メモリの検索結果 $A_{mem}$ を差し引いた上で連想割付作用素を適用します。
・入力セグメントとトランスフォーマブロックと圧縮メモリの入出力関係
図2は入力セグメントとトランスフォーマブロックと圧縮メモリの入出力関係を図にしたもので、各セグメントが複数層のトランスフォーマで処理(灰色の矢印)されますが、各層のキー、バリューの計算結果によって圧縮メモリは更新(青の矢印)され、圧縮メモリとして次のセグメントの処理時に利用(紫の矢印)されています。これにより、各セグメントの処理時に有効な文脈は過去のセグメントも含んだものになります。
圧縮メモリ側のアテンションは、scaled dot product attentionではなく、計算量が線形オーダとされるlinear attention(Linear)が用いられています。
ローカルのアテンションの計算結果 $A_{dot}$ とグローバルのアテンションの計算結果 $A_{mem}$ は学習可能なトレードオフ調整パラメータに基づいて組み合わせます。
圧縮メモリは、 $M_{s-1}$ と $z$ を記憶するため、トランスフォーマーの各層で、キーの次元数×バリューの次元数( $M_{s-1}$ 分)+キーの次元数のサイズ( $z$ 分)のメモリを使用しますが、この固定のメモリ使用量で理論的には無限の入力セグメントを受け付けることができます。
・既存に比べたメモリ使用量削減効果
既存手法にも、LLMへの入力シーケンスをモデルパラメータとして記憶するトランスフォーマの提案はありますが、提案手法では、たとえば、既存手法のメモリ長さ6万5千のMemorizing Transformersに対し114倍の圧縮比を達成します。提案のInfini-Transformersは非常に圧縮比が高いことが分かります。従来に比べ非常に少ないメモリ使用量で長い文脈を処理できることが分かります。
評価結果
パスキータスク
長い文脈を理解できているかを評価するタスクとして、パスキータスクがあります。LLMに長いプロンプトを与え、その中にランダムな数字(パスキー)を忍ばせ、パスキー(合言葉)を見つけるよう指示を出し、パスキーを正確に抽出できるかを確かめます。
与えるプロンプトの長さは32K(3万2千), 128K(12万8千), 256K(25万6千), 512K(51万2千), 1M(100万)トークンの5パタンです。
パスキーを忍ばせる場所については、プロンプトのはじめ、真ん中、終わりの3パタンを試します。
Zero-shotで解かせる場合と5Kトークンの入力サイズで400ステップのファインチューニング(FT)をする場合、両方を実験しました。
比較手法は提案のInfini-Transformersにおいて、Linearは、Deltaルールを用いない場合、Linear+DeltaはDeltaルールを用いる場合です。
図3にパスキータスクの評価結果を示します。/で区切られた数字は、それぞれパスキーをプロンプトのはじめ、真ん中、終わりに置いた場合のパスキー抽出の成功率を示します。パスキーをプロンプトの終わりに置くと、正解率は高いですが、それ以外は成功率が低いです。
本論文の著者による考察はないですが、圧縮メモリによって、過去のコンテキストを保持するため、圧縮メモリの副作用として、遠い過去の情報が薄まってしまうということだと考えられます。
与えるプロンプトの長さに対する成功率の依存性は、緩やかで、長いほど、成功率は下がりますが、プロンプトの長さが2倍になったら成功率が1/2になるような大きな依存性はなく、緩やかに成功率は減少します。
FTをする場合、Zero-shotに比べ大きく成功率が向上しています。有効な文脈範囲がたとえ長くなったとしても、その文脈の情報を有効活用できるわけではなく、ファインチューニングによって情報をうまく使えるように訓練する必要があるのかもしれません。
Infini-Transformersにおいて、Deltaを用いない場合(Linear)とDeltaを用いる場合(Linear+Delta)の比較がありますが、基本性能に違いはあまりありません。図3のZero-shotの128Kの場合のみ、顕著な差が出ています。
本の要約
80億モデルパラメータのLLMを入力長8千で3万ステップの事前学習を行い、本の翻訳タスクにファインチューニングした。ファインチューニングは入力長3万2千で実施しました。評価は、入力長8万で実施しました。
評価指標はRougeで、機械要約と正解要約の一致度を測る指標であり、いくつかバリエーションがありますが、高いほど良いです。
評価結果を図4に示します。提案手法であるInfini-Transformers(Linear + Delta)は、既存の最高性能手法を上回る性能を達成しました。
おわりに
本記事では、入力プロンプトを分割したセグメントに対するアテンションであるローカルアテンションと圧縮メモリを用いて入力プロンプト全体の文脈を考慮したグローバルアテンションをTransfomerに導入することで無限の長さのプロンプトに対応可能にしたLLMについて説明しました。
従来のTransformerのメモリ使用量は、入力プロンプトの長さに依存し、その長さの二乗に比例するメモリ使用量を要します。一方で、Inifini-Transfomersのメモリ使用量は、入力プロンプトの長さに依りません。そのメモリ使用量は、キーの次元数×バリューの次元数のサイズの圧縮メモリの行列 $M_{s-1}$ とキーの次元数のサイズの正規化項 $z_{s-1}$ に依って決まります。そのため、無限の長さのプロンプトに対して、メモリ使用量が無限大にならずに済みます。
長い文脈の処理が必要な、本の内容全体の要約タスクにおいて、Infini-Transformersは既存を上回る性能を示しました。
今後、文脈内学習で様々なことをLLMにやらせようと思うと、文脈の長さは長大化するはずです。この文脈の長さが無限になっても対処できるというのは、非常にインパクトがあります。
ただし、提案のInfini-Transformersにより、入力プロンプトの長さにメモリ使用量が依存しなくなったといっても、入力プロンプトの最適なセグメントサイズをどう決めるべきで、それがどれほど精度に影響するのか気になります。
また、事前学習、ファインチューニングにおいて、学習時の入力プロンプトの長さについても、どのぐらいの長さに設定しておけば、どのくらいの長さのプロンプトを十分に理解できるのか気になります。
評価上は、3万2千トークンの入力プロンプトでファインチューニングすれば、50万トークンの文脈を処理できたということで、学習時の入力プロンプトの10倍ほどのトークンを推論時に処理できそうということは分かりますが、常にこのような比例関係があるのか、本の要約タスクに限定なのか?、毎回タスク別にファインチューニングしなくてはいけないのか?気になるところです。
ChatGPTなどでこういった入力プロンプト長を無限大にする技術が採用され、入力プロンプトの制限が大幅に緩和されると、メガプロンプトのトレンドが加速すると思います。一方で、ユーザからの入力プロンプトサイズが巨大になり、サーバ負荷が異常に高まるなど別の問題を引き起こしそうです。また、学習時に長い入力プロンプトを教示する必要がありますので、入力長の長い適切なデータセットの準備の方がボトルネックになるのかもしれません。
この記事に関するカテゴリー