【BitNet b1.58】モデルパラメータを3値で表現しLlama以上の精度を達成!?
3つの要点
✔️ 大規模言語モデルは、計算量、メモリ使用量、消費電力が膨大
✔️ モデルパラメータ数×モデルパラメータ精度分だけ計算量、メモリ使用量、消費電力が増大するのが問題点
✔️ 問題解決のため、モデルパラメータ精度を16bit(約7万値)から1.58bit(3値)に落としても、LLaMAと同等な回答精度を達成する言語モデルを提案
The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits
written by Shuming Ma, Hongyu Wang, Lingxiao Ma, Lei Wang, Wenhui Wang, Shaohan Huang, Li Dong, Ruiping Wang, Jilong Xue, Furu Wei
(Submitted on 27 Feb 2024)
Comments: Work in progress
Subjects: Computation and Language (cs.CL); Machine Learning (cs.LG)
code:
本記事で使用している画像は論文中のもの、紹介スライドのもの、またはそれを参考に作成したものを使用しております。
はじめに
大規模言語モデル(Large Language Models, LLM)に対して、小規模言語モデル(Small Language Models, SLM)が脚光を浴びつつあります。
大規模なモデルパラメータ数、大規模なデータセットで学習を行う大規模言語モデルは、その回答能力の高さで人工知能に対する一般の人々の期待値を底上げしました。しかし、大規模言語モデルの学習、推論を実行するには、超ハイスペックな計算機が必要です。そのため、オンプレミス(自前の建物内に設置した計算機)やエッジ(スマホなどの端末内の計算機)ではなく、クラウドサーバを介したサービスとしてLLMを利用する人が大半です。
企業にとって、クラウドサーバを介したサービスを使う場合、そういったサービスを利用するアカウント管理、予算申請、情報流出などのセキュリティリスクへの対策が必要となります。特に、独自の大規模データを活用したい企業にとっては、大きな足かせです。
そういった足かせを外し、誰もが気軽にAIの恩恵を受けられるようにする上で注目されるのが、小規模言語モデルです。大規模言語モデルに対して、小規模な言語モデルを採用することで、ハードウエア要件として必要な計算量、メモリ使用量、消費電力の基準を下げることができます。
したがって、小規模言語モデルは、AIの恩恵を受けるためのハードウエア要件を緩和することができ、オンプレミスやエッジでの利用を促進する期待があります。これは、AIの民主化を加速する一つの流れになると考えられます。
小規模言語モデルの中でも、チャレンジングなものが、今回紹介する1bit LLMです(正確には1.58bit LLMです。1bit LLMの後続として1.58bit LLMが提案されています)。
大規模言語モデルで、何がハードウエア要件を引き上げているかというと、モデルパラメータ数の多さです。さらに細かいことを言うと、モデルパラメータ数×モデルパラメータ精度が問題となってきます。
モデルパラメータ精度とは、数値を何段階の値で表すかです。たとえば、円周率を3桁で表現すると3.14ですが、これを1桁で表現すると3です。3桁で表現する場合は9.99~0.00の1000段階で値を区別しているわけです。1桁で表現すると、0~9の10段階で値を区別するわけですから、1/100ぐらいの精度で値を区別していると見なせます。
ここで、1桁の場合と3桁の場合の計算量、メモリ使用量、消費電力を比べると、どうでしょうか?桁数が少ない方、つまり、円周率を3と考えた方が計算は楽ですし、覚えるのも楽だと思います。計算するときのカロリーも少ないと感覚的にわかると思います。
計算機の世界では、10になったら桁の繰り上げが起こる10進数ではなく2になったら桁の繰り上げが起こる2進数で考えることが基本のため、2進数における桁数であるbitを用いて、モデルパラメータ精度を表現します。
今回の論文は、-1,0,1の3値でモデルパラメータを表現します。今までの一般的なモデルパラメータ精度が16bitであったのに対して、たったの1.58bitです(たとえば、10進数で1000は10の3乗で、この値が大きいほど桁数が大きいと考えることができます。同じように、3は2の何乗か?を計算すると、3値は2進数で何桁相当か換算でき、その答えが1.58です)。
このようなモデルパラメータ精度で、LLMの回答精度がガタ落ちしないのか心配になりますが、驚くべきことにLLMと同等、モデルパラメータ数によっては、むしろ3値にした方が回答精度が良いという結果が得られています。
このように、LLMsの回答精度をなるべく維持しつつ、モデルパラメータの精度を下げるような技術は量子化と呼ばれ、研究されています。
それでは、本論文の提案手法であるBitNet b1.58と評価結果について、具体的に説明してゆきます。
BitNet b1.58のメリット
従来のLLMsに対する提案のBitNet b1.58のコストパフォーマンス比較を図1に示します。
BitNet b1.58の特徴は、モデルパラメータ、いわゆるニューラルネットワークの重みが図1の左のWのように、-1,0,1の3値のどれかになるという点です。
従来のLLMsの重みは、16bitの浮動小数点数で表現されます。浮動小数点数は2.961×10の-1のように、(仮数部)×(指数部)乗の形式で値を表現するものです。仮数部、指数部、符号にそれぞれ16bitを配分するもので、図1の右のTransformer LLMsのWの通り、小数を表現できます。
基本的に計算機は、bit単位で計算する演算器が備わっており、bitごとに演算するので、bitが多いほど計算コストはかかりますし、bitの値を保持するメモリのコストも増えます。演算器を多数用意し並列計算できるようにすれば計算時間は短くなりますが、その分エネルギー消費は大きくなります。
モデルパラメータ数が多いと、実はモデルパラメータ値を使って推論をする際に、大量のモデルパラメータ値情報をメモリから転送する時間自体も推論時間(LLMsが入力に対して回答するのにかかる時間)増大の要因になってしまいます。
したがって、BitNet b1.58は、モデルパラメータ精度を3値まで下げることで、BitNet b1.58の方がTransformer LLMsよりコストが低くできると図1の横軸で主張されています。
性能(Performance)に関しては、BitNet b1.58は従来と同等だと主張しています。
このように、性能とコストの2軸で従来のLLMsと比べた場合に、どの軸でも従来に劣らず、少なくとも一つの軸、ここではCostの観点では優れていること(Pareto Improvement)を主張しています。
BitNet b1.58で必要な演算を図2に示します。
従来はモデルパラメータと入力の掛け算(Multiplication)と足し算(Addition)が必要でした。
対して、BitNet b1.58は足し算だけになります。つまり、より単純な計算で済むようになるということです。
従来のGPUは掛け算と足し算、いわゆる積和が多数出る行列演算を高速化するものでしたが、今回のBitNet b1.58は、足し算(と、0の時は入力を0に、-1の時は符号を負に、1の時は符号を正にするような演算)を高速化できれば良いです。
BitNet b1.58の技術ポイント
BitNet b1.58は、BitNetをベースにしています。
量子化では、学習後のモデルパラメータの精度を削減してしまうものと、モデルパラメータの精度を下げることを意識して学習する方式があります。
前者は、後処理としてモデルパラメータの精度を削減するので、既存のものに対して適用しやすい利便性はありますが、モデルの性能は下がりやすいという欠点があります。
後者は、モデルパラメータの精度を下げることを意識して学習した場合はその分計算コストがかかるデメリットがありますがモデルの性能低下を抑えられるとされています。
BitNetは、後者の量子化を意識して学習をするイメージです。
このように学習中に量子化処理が挟まる場合、連続値を離散値に丸めるような処理が通常入ってくるので、非連続的な変換になり、微分不可能になります。
このとき、ニューラルネットワークのモデルパラメータ値を効率よく計算するための誤差逆伝搬法が使えなくなってしまい、問題となります。BitNetでは、既存のストレートスルー方式と呼ばれる、微分不可能な関数はスルーして、そのまま結果を伝搬し微分を計算できるところで計算して、誤差逆伝搬するという経験的な方式を採用しています。
これらはBitNet b1.58でも引き継いでいる処理になります。
BitNetとBitNet b1.58(今回)の違いは、モデルパラメータ値を-1,1の2値で表すか、-1,0,1の3値で表すかです。
BitNet b1.58は、BitNetのメリットはそのままに、BitNetに追加のメリットがあります。
-1,1だけでなく0を追加することで当たり前ですが、モデルパラメータの値の表現精度が高くなるので、モデルの表現力が上がります。
また、0を含めることの意味として、特徴フィルタリングの効果が期待できます。
機械学習では一般的に不要な特徴量が含まれていると、モデルの予測性能に大きな悪影響を与えます。0を含めると、直接的に不要な特徴を削減できます。
従来のBitNetの量子化方法
従来の-1,1の2値をとるBitNetの場合、モデルパラメータ値に関しては、0以上なら1で、0より小さいなら-1に変換します。ただし、現状のモデルパラメータ値の中心というのは、いわばモデルパラメータ値の平均となりますので、中心が0からずれていると変換に偏りが生じ、誤差が大きくなってしまいます。
そこで、モデルパラメータ値の平均を差し引いた上で(ゼロポイント調整した上で)、0以上なら1で、0より小さいなら-1に変換します。
活性化関数の量子化では、入力行列の要素の絶対値の最大値で割り、値の範囲を[-1,1]にし、Q(2のn-1乗、何ビットに量子化するかに依る)をかけることで、[-Q, Q]とします。活性化関数が非対称な場合、たとえばRelu関数を想定していた場合は、0が閾値となるので、最小値を引いてから、同様な処理をすることで[0,Q]の範囲とします。
ニューラルネットワークでは、入力とモデルパラメータの積和を計算し、その出力に活性化関数を適用します。活性化関数は、極端にいえば、ネットワークのあるニューロンを発火するか、しないかの0,1を決めるもので、閾値を超えるものを1、そうでないものを0と判断します。
したがって、適切に処理しないと、極端に0ばかり出力されるなどの問題が生じる可能性があるので、閾値を気にして範囲調整していると考えられます。
提案のBitNetの量子化方法
提案では、モデルパラメータ値をモデルパラメータ値の絶対値平均で割り(スケール処理)、値を整数に丸め(丸め処理)、-1より小さければ-1に、1より大きければ1に変換(クリップ処理)します。
たとえば、従来の活性化関数の量子化同様に、最大の絶対値で割るという処理も考えられると思いますが、一つの値だけ極端に絶対値が大きい場合に、それを元にスケールしてしまうと、1つの値だけ、絶対値が大きく、他の値は大きな値で割られた結果、(-1,0,1の3値あるのにそのうちの一つの)0付近に固まってしまうということが考えられます。
そうなると0に偏った変換になるので、平均値以上ならば、+1 or -1、平均値より小さいならば、0となるように、平均値でスケールし、-1,1をはみ出す値は-1 or 1の近い方の値に変換する方式をBitNet b1.58では採用していると想像されます。
活性化関数の量子化は従来と同様です。ただし、従来では、ReLuの場合、最小値を差し引いて[0,Q]に値の範囲を変換していましたが、よりシンプルな処理にするために、最小値を差し引く処理は入れず、常に値の範囲を[-Q, Q]に変換します。なぜ、今回はシンプルにするために最小値を差し引くようなゼロポイント調整処理を省いてもよいと考えられるのかは不明なところです。おそらく試してみたら、そんなに問題なかったということなのでしょう。
評価結果
メモリ使用量と応答速度と予測精度
メタが開発したLLMであるLLaMA LLMとBitNet b1.58のメモリ使用量と応答速度と予測精度を図3に示します。
モデルパラメータ数(Size)が7億(700M)、13億(1.3B)、30億(3B)、39億(3.9B)の場合のGPUメモリサイズ(Memory)、応答速度(Latency)、予測誤差(PPL)が示されています。表の矢印の通り、どの指標も小さいほどよいです。
BitNetはLLaMA LLMに比べ、GPUメモリサイズは2.6倍~3.6倍小さく、応答速度は、1.23倍~2.7倍、予測誤差はほとんど変わらない結果となっています。
モデルパラメータ数が多いほど、メモリ使用サイズ、応答速度、予測誤差は小さくなる傾向で、モデルパラメータ数が3B、3.9Bのときに、GPUメモリサイズ、応答速度、予測誤差は、LLaMA LLMにすべて優れているという結果になっています。
エネルギー消費量
512トークン入力時のLLaMAとBitNet b1.58のエネルギーコストの比較を図4に示します。
BitNet b1.58は、LLaMAに比べ、19倍~41倍エネルギーコストが小さく、モデルパラメータ数(Model Size)が多いほど、エネルギー消費(Energy)が小さいです。
終わりに
本記事では、BitNet b1.58について説明しました。
従来のLLMはモデルパラメータ数が増えると、予測性能は向上してゆきますが、メモリ使用量、応答時間、エネルギー消費も大幅に増えてしまうという問題がありました。
従来のLLMは1つのモデルパラメータ値を16bitの浮動小数点数で表現するため、モデルパラメータ数×16bitの情報量を保存する必要があり、モデルパラメータ数を増やすほど、メモリ使用量が増え、その情報を転送する時間がかかり応答時間も伸び、計算量が増えるのでエネルギーの消費も大きくなります。
この問題を緩和するために、なるべくLLMの回答精度を維持しつつモデルパラメータ値の精度を16bitの浮動小数点数から1.58bit(-1,0,1の3値)に下げる技術(量子化)を紹介しました。-1,1の2値に削減するBitNetがすでに提案されていましたが、-1,1に0を加えることで、ほとんどの計算を足し算だけで計算できるというメリットを維持しつつ、特徴を絞り込むフィルタ効果を得ることができ、メモリ使用量、応答時間、エネルギー消費量の削減効果を少し犠牲にするだけで、LLMの予測精度を向上することができます。モデルパラメータ数によっては、予測精度の劣化を微小に抑えるどころか、逆に向上する結果も得られています。
LLaMAと比較した結果、メモリ使用量、応答時間、エネルギー消費量を削減しつつ、予測性能を従来以上にできることが示されました。
本論文では、提案手法では、ほとんどの計算を足し算だけできることから、GPUとは異なる新しいハードウエハを作ることで更なる高速処理やエネルギー消費の削減が期待できるとされています。
現在、NVIDIA株が異常な注目を浴びていますが、GPUとは異なるハードウエアでよいとなれば、NVIDIA一強を覆すことになりえます。とはいっても、BitNet b1.58は、論文ではGPUを用いて評価していますし、活性化関数は後述するように8bitですので、全ての処理過程で3値というわけではないと考えられます。すべての過程で、足し算だけとなれば抜本的にハードウエアを変えることができそうですが、掛け算が残ってしまうのであれば、大きく変わりにくいと思われます。
また、メーカーとしてはより最適な新しいハードウエアがあると言われた方が、NVIDIAに付け入るスキがあるとうれしくなるかもしれないですが、一般ユーザからすれば、新しいハードウエアを買わされるよりかは、汎用のCPUもしくはより廉価なハードウエアで高速計算できると言われた方がうれしいような気がします。
この記事に関するカテゴリー