Cross-Layer AttentionによってTransformerのメモリを大幅に削減
3つの要点
✔️ Cross-Layer Attention(CLA)によるKVキャッシュのメモリ削減
✔️ CLAを用いた1Bおよび3Bパラメータモデルでの精度維持
✔️ CLAとMulti-Query Attention(MQA)やGrouped-Query Attention(GQA)との組み合わせによる効果的なメモリ効率向上
Reducing Transformer Key-Value Cache Size with Cross-Layer Attention
written by William Brandon, Mayank Mishra, Aniruddha Nrusimha, Rameswar Panda, Jonathan Ragan Kelly
(Submitted on 21 May 2024)
Comments: Published on arxiv.
Subjects: Machine Learning (cs.LG); Computation and Language (cs.CL)
code:
本記事で使用している画像は論文中のもの、紹介スライドのもの、またはそれを参考に作成したものを使用しております。
序論
近年、トランスフォーマーモデルは自然言語処理の分野で飛躍的な進歩を遂げ、様々な応用分野で優れた成果を上げています。しかし、大規模な言語モデルの性能を最大限に引き出すためには、高いメモリ要件を持つキー・バリュー(KV)キャッシュが不可欠です。特に長いシーケンスや大規模なバッチサイズを扱う場合、そのメモリ消費は非常に高くなり、実用上の課題となっています。
この課題を解決するために、多くの研究者がKVキャッシュのメモリ効率を改善する方法を模索してきました。中でも、Multi-Query Attention(MQA)やGrouped-Query Attention(GQA)は、複数のクエリヘッドが単一のキー/バリューヘッドを共有することで、KVキャッシュのサイズを削減する有効な手段として広く採用されています。しかし、さらなるメモリ効率の向上が求められています。
このような背景から、MITとMIT-IBM Watson AI Labの研究者たちは、新たなアプローチとして「Cross-Layer Attention(CLA)」を提案しました。CLAは、隣接するレイヤー間でキーとバリューのヘッドを共有することで、KVキャッシュのサイズをさらに削減しつつ、モデルの精度をほぼ維持することを目指しています。
関連研究
トランスフォーマーモデルの性能と効率を最大限に引き出すため、多くの研究者が様々なアプローチを模索しています。この論文では、特にKVキャッシュのメモリ効率を改善するための関連研究に焦点を当てています。以下に、論文内で紹介されている主要な関連研究をまとめます。
Multi-Query Attention(MQA)とGrouped-Query Attention(GQA)
最も関連性の高い研究として、トランスフォーマーモデルの注意メカニズムを改良するMulti-Query Attention(MQA)とGrouped-Query Attention(GQA)が挙げられます。Shazeer(2019)が提案したMQAは、複数のクエリヘッドが単一のキー/バリューヘッドを共有することで、KVキャッシュのサイズを大幅に削減します。Ainslieら(2023)がこれを一般化し、Grouped-Query Attention(GQA)として、クエリヘッドをグループ化し、各グループが単一のキー/バリューヘッドを共有するアーキテクチャを提案しました。これにより、メモリ効率が向上しつつ、精度の低下を最小限に抑えることが可能となります。
KVキャッシュ圧縮
KVキャッシュのサイズを削減する別のアプローチとして、KVキャッシュ圧縮が挙げられます。Hooperら(2024)は、キーとバリューの量子化を行い、低精度でのストレージを実現するKVQuantを提案しました。また、Zhangら(2024)は、Coupled Quantizationという手法を用いて、キーとバリューの非均一なエンコーディングを行い、KVキャッシュを1ビットまたは2ビットに圧縮する方法を示しています。
不要なKVキャッシュエントリの削除
他のアプローチとして、不要なKVキャッシュエントリを削除する方法があります。Zhangら(2023)は、重要でないKVキャッシュエントリを削除するH2Oを提案しました。Liuら(2023)は、Scissorhandsという手法を用いて、生成中の重要なトークンのみを保存する方法を示しました。これらの手法により、KVキャッシュのメモリ使用量を効果的に削減することが可能となります。
アーキテクチャの変更によるKVキャッシュサイズの削減
本論文で提案されているCross-Layer Attention(CLA)は、アーキテクチャの変更によってKVキャッシュサイズを削減するアプローチです。これは、従来のGQAやMQAが単一のレイヤー内でキー/バリューの共有を行っていたのに対し、CLAは隣接するレイヤー間でキー/バリューの共有を行う点で独自性があります。これにより、KVキャッシュのメモリフットプリントをさらに削減しつつ、モデルの精度を維持することが可能となります。
トレーニングメモリの効率化
トレーニング時のメモリ効率を向上させる研究も多く行われています。Shoeybiら(2020)は、巨大なニューラルネットワークのトレーニングを効率化するためのモデル並列化技術であるMegatron-LMを提案しました。Huangら(2019)は、パイプライン並列化を用いてトレーニングメモリの使用を最適化するGPipeを紹介しました。CLAはこれらの技術とも互換性があり、さらなるメモリ効率の向上が期待されます。
提案手法(Cross-Layer Attention)
トランスフォーマーモデルにおけるKVキャッシュのメモリ問題を解決するために、研究者たちは新たな手法である「Cross・Layer Attention(CLA)」を提案しました。CLAは、隣接するレイヤー間でキーとバリューのヘッドを共有することで、KVキャッシュのサイズを削減しつつ、モデルの精度を維持することを目指しています。このセクションでは、CLAの設計とその具体的な動作について詳しく説明します。
CLAの基本概念
従来のトランスフォーマーアーキテクチャでは、各レイヤーが独自のキーとバリューを計算し、それらをKVキャッシュに保存します。この方法では、長いシーケンスや大規模なバッチサイズに対応するために大量のメモリが必要となります。これに対し、CLAは一部のレイヤーで計算されたキーとバリューを隣接するレイヤーでも共有することで、メモリの使用量を削減します。
具体的には、CLAは次のように動作します。
・キー/バリューの計算と共有: 一部のレイヤーは独自にキーとバリューを計算し、これをKVキャッシュに保存します。その後、隣接するレイヤーはこの計算済みのキーとバリューを再利用します(図1参照)。
・シェアリングファクター: キーとバリューが共有されるレイヤーの数を「シェアリングファクター」と呼びます。例えば、シェアリングファクターが2の場合、各ペアのレイヤーが同じキーとバリューを使用します(図2参照)。
この手法により、KVキャッシュのメモリ使用量がシェアリングファクターの分だけ削減されます。
図1: Cross・Layer Attention(CLA)の概念図
CLAのアーキテクチャ
CLAの設計は、従来のMulti・Query Attention(MQA)やGrouped・Query Attention(GQA)と組み合わせることが可能です。従来のMQAやGQAが同一レイヤー内でキーとバリューを共有するのに対し、CLAは複数のレイヤー間で共有を行います。これにより、さらなるメモリ削減が可能となります。
CLAの具体的な構造は以下の通りです。
・ キー/バリューの投影: いくつかのレイヤーでは独自にキーとバリューの投影を行い、その結果をKVキャッシュに保存します。他のレイヤーでは、これらの投影結果を再利用します。
・ 組み合わせの柔軟性: CLAは、MQAやGQAと組み合わせることができ、それぞれの利点を組み合わせることで最適なメモリ効率を実現します。
図2: シェアリングファクターの異なるCLAの構成
実験
この研究では、提案されたCross・Layer Attention(CLA)手法の有効性を検証するため、1Bおよび3Bパラメータのモデルを使用して一連の実験が行われました。
すべての実験において、モデルはSlimPajamaデータセットを用いて訓練されました。モデルのトークナイザーにはGPT・NeoXトークナイザーが使用され、Byte・Pair Encoding(BPE)でトークン化が行われました。また、Llamaのアーキテクチャに基づき、前正規化、SwiGLU活性化関数、回転位置エンベディングが採用されました。トレーニングはNVIDIA H100 GPUを使用し、PyTorchフレームワークで実施されました。
1Bパラメータモデルの実験結果
1Bパラメータモデルでは、様々なCLA構成がテストされました。特に、MQA・CLA2構成が優れた性能を示しました(図3参照)。
図3: 1Bパラメータモデルの実験結果
・ MQA・CLA2モデル: 頭部次元を64から512まで変化させたMQA・CLA2モデルは、KVキャッシュメモリを削減しつつ、従来のMQAモデルに比べて精度を向上させました。特に、頭部次元が128のモデルでは、従来のMQAモデルと比較してメモリ使用量が半減し、精度がほぼ同等であることが確認されました。
・ GQA・CLA2モデル: GQAとCLA2を組み合わせたモデルもテストされましたが、最も効果的だったのはGQA2・CLA2構成であり、他の構成に比べて優れた精度を示しました。
3Bパラメータモデルの実験結果
3BパラメータモデルでもCLAの効果を検証するための実験が行われました。ここでも、MQA・CLA2構成が最も効果的であることが確認されました。
・ 頭部次元128のMQAモデル: 学習率を調整した結果、MQA・CLA2モデルは頭部次元128の従来のMQAモデルと比較して優れた精度を示しました。特に、Wikitextデータセットでのパープレキシティ(perplexity)で大きな差が見られました(表5参照)。
表5: 3Bパラメータモデルの実験結果
考察
1. メモリ効率の向上: CLAは、特にシェアリングファクターが2の場合において、KVキャッシュのメモリ使用量を効果的に削減しながら精度をほぼ維持できることが確認されました。これにより、従来のアーキテクチャに比べて大幅なメモリ効率の向上が実現されました。
2. 精度の維持: CLAを使用することで、精度の低下を最小限に抑えつつ、メモリ使用量を削減できるため、特に長いシーケンスや大規模なバッチサイズを扱うシナリオにおいて有用です。
3. 学習率の重要性: 学習率の調整がモデルの性能に与える影響が大きく、CLAモデルでは特に高い学習率が効果的であることが示されました。これは、CLAがメモリ効率を向上させるだけでなく、トレーニングプロセス自体の効率も向上させる可能性を示しています。
これらの結果から、CLAはトランスフォーマーモデルの設計における新たな標準となりうる手法であり、実用性と効率性の両面で大きな利点を提供することが明らかになりました。
結論
本稿では、トランスフォーマーモデルにおけるKVキャッシュのメモリ使用量を削減する新たな手法として、Cross-Layer Attention(CLA)を提案しました。CLAは、隣接するレイヤー間でキーとバリューを共有することで、KVキャッシュのサイズを半減させつつ、精度をほぼ維持できることを示しました。特に1Bおよび3Bパラメータのモデルでの実験結果は、CLAがメモリ効率と精度の両方で優れた性能を発揮することを示しています。
今後の展望としては、CLAのさらなる最適化と拡張が考えられます。例えば、異なるモデルアーキテクチャやより大規模なモデルへの適用、また長期的なシーケンスに対するCLAの効果の検証などが挙げられます。また、実際のアプリケーションでのCLAの効果を評価することで、その実用性と有効性をさらに確認することが重要です。CLAは、トランスフォーマーモデルの進化に貢献する重要なステップとなるでしょう。
この記事に関するカテゴリー