
[RetrievalAttention] Improved Efficiency Of LLM For Processing Long Contexts!


Large Language Models

3 main points
✔️ Proposed "RetrievalAttention" method to improve inference speed of large-scale language models (LLMs) for long context
✔️ Enabled faster inference while maintaining high accuracy and reducing huge memory usage and computational costs
✔️ Significantly improved inference efficiency on long-text tasks

RetrievalAttention: Accelerating Long-Context LLM Inference via Vector Retrieval
written by Di Liu, Meng Chen, Baotong Lu, Huiqiang Jiang, Zhenhua Han, Qianxi Zhang, Qi Chen, Chengruidong Zhang, Bailu Ding, Kai Zhang, Chen Chen, Fan Yang, Yuqing Yang, Lili Qiu
(Submitted on 16 Sep 2024)
Comments: 16 pages

Subjects:  Machine Learning (cs.LG); Computation and Language (cs.CL)

code:  

 

The images used in this article are from the paper, the introductory slides, or were created based on them.

Summary

This paper primarily addresses the problem of efficient inference for large language models (LLMs) with long contexts.

Transformer-based LLMs are widely used in diverse fields, but the computational cost is very high when dealing with long contexts. In particular, the computation of "attention" increases processing time and memory usage as the context gets longer, which becomes a bottleneck. Many techniques have been developed to solve this problem, but none of them have reached a complete solution.

The proposed method is superior to existing attention optimization techniques in that it allows very efficient inference over long contexts while maintaining nearly equivalent accuracy. In particular, it can efficiently run fairly large models (8B-parameter models) on relatively low-spec GPUs.

In short, RetrievalAttention is a technology that dramatically improves memory and time efficiency in LLM inference for long contexts, and is a major step toward practical application.

Research Background

Large language models can process very long text and perform well across a variety of natural language processing tasks; for example, they can read huge amounts of text and generate responses or summaries based on the content. However, the attention mechanism at the core of these models poses significant challenges.

The attention mechanism determines which parts of the input text are important and predicts the next word based on that determination. However, because attention compares the current "query" vector against the "key" vectors of every previous token (and then aggregates the corresponding "value" vectors), its cost grows rapidly as the text gets longer. This leads to slow inference and huge memory usage, and GPU memory quickly reaches its limit when the context is very long.
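To make the cost concrete, here is a minimal NumPy sketch (not code from the paper) of standard scaled dot-product attention for a single decoding step. The new query has to be compared against every cached key, so each generated token costs work proportional to the context length, and processing the full prompt is quadratic. All sizes are illustrative.

```python
import numpy as np

def attention_single_query(q, K, V):
    """Standard attention for one decoding step.

    q: (d,)   query vector of the current token
    K: (n, d) cached key vectors of all n context tokens
    V: (n, d) cached value vectors
    """
    d = q.shape[-1]
    scores = K @ q / np.sqrt(d)            # n dot products -> O(n * d) per step
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()               # softmax over all n tokens
    return weights @ V                     # weighted sum over all n values

# Illustrative sizes: a 128k-token context with head dimension 128
n, d = 128_000, 128
rng = np.random.default_rng(0)
q = rng.normal(size=d).astype(np.float32)
K = rng.normal(size=(n, d)).astype(np.float32)
V = rng.normal(size=(n, d)).astype(np.float32)
out = attention_single_query(q, K, V)      # every cached token is touched
```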

The traditional solution has been a technique called "KV caching", which avoids redundant computation by retaining and reusing the past "key" and "value" states. However, even with this method, processing long contexts consumes a large amount of memory: caching the states for an extremely long context can require more than 500 GB, far beyond what a single GPU provides. This makes it difficult to use in a realistic system, so a more efficient method was needed.
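As a rough illustration of why the KV cache itself becomes the bottleneck, the back-of-the-envelope sketch below estimates cache size as a function of context length. The model configuration (32 layers, 32 key-value heads of dimension 128, fp16, no grouped-query attention) is an assumption chosen for illustration, not a figure from the paper.

```python
def kv_cache_bytes(n_tokens, n_layers=32, n_kv_heads=32, head_dim=128, dtype_bytes=2):
    """Approximate KV-cache size: one key and one value vector per layer and token.

    The architecture numbers are illustrative assumptions; models using
    grouped-query attention (fewer KV heads) need proportionally less memory.
    """
    per_token = 2 * n_layers * n_kv_heads * head_dim * dtype_bytes  # K and V
    return n_tokens * per_token

for n in (128_000, 1_000_000):
    print(f"{n:>9,} tokens -> ~{kv_cache_bytes(n) / 1024**3:,.0f} GiB of KV cache")
# ~62 GiB at 128k tokens and ~488 GiB (over 500 GB) at 1M tokens under these assumptions
```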

Therefore, this paper focuses on the "dynamic sparsity" of the attention mechanism. In practice, not all tokens are important for predicting the next word; only a small subset of tokens plays an important role. In other words, it is not necessary to include every token in the calculation, and by focusing only on the important tokens, the cost of the computation can be greatly reduced.
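The toy sketch below illustrates that intuition: when a handful of keys match the query strongly (synthesized by construction here, mimicking what is observed in trained models), attending only to the top-k highest-scoring tokens yields nearly the same output as attending to all of them. Note that this toy still scores every key in order to pick the top-k; RetrievalAttention's point is to find those tokens without a full scan.

```python
import numpy as np

def full_attention(q, K, V):
    s = K @ q / np.sqrt(q.shape[-1])
    w = np.exp(s - s.max()); w /= w.sum()
    return w @ V

def topk_attention(q, K, V, k=100):
    # Scores all keys here only to pick the top-k; RetrievalAttention instead
    # retrieves the likely top-k through an index without scanning everything.
    s = K @ q / np.sqrt(q.shape[-1])
    idx = np.argpartition(s, -k)[-k:]
    w = np.exp(s[idx] - s[idx].max()); w /= w.sum()
    return w @ V[idx]

n, d = 50_000, 128
rng = np.random.default_rng(0)
q = rng.normal(size=d)
K, V = rng.normal(size=(n, d)), rng.normal(size=(n, d))
K[:50] += 2.0 * q   # synthetic "important" tokens strongly aligned with the query

err = np.abs(full_attention(q, K, V) - topk_attention(q, K, V)).max()
print("max difference between full and top-k attention:", err)   # tiny
```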

Based on this background, a new method, RetrievalAttention, has been proposed to achieve efficient attention calculation.

Proposed Method

In the Transformer model, the attention mechanism determines which parts of the input text are important and predicts the next token based on that. However, the longer the context, the more computationally expensive it becomes to calculate attention over all tokens.

A key feature of RetrievalAttention is that it addresses the difference in distribution between query and key vectors (the out-of-distribution, or OOD, problem). Ordinary approximate nearest neighbor search assumes that queries come from the same distribution as the indexed data, but in attention, query vectors and key vectors often follow different distributions, which degrades retrieval accuracy. To solve this problem, RetrievalAttention uses a new search algorithm that adapts to the attention-specific distributions. This makes it possible to obtain highly accurate attention results while scanning only 1-3% of the key vectors for each query.
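The snippet below is a deliberately simplified illustration of that idea, adapting the index to the query distribution, and is not the paper's actual algorithm. Keys are bucketed under representative query vectors, which are hypothetical stand-ins for queries observed earlier; an incoming decoding query, coming from the same distribution as those representatives, then only needs to scan a few buckets rather than all keys.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_keys, n_prototypes = 64, 50_000, 256

keys = rng.normal(size=(n_keys, d)).astype(np.float32)
# Hypothetical representative queries (e.g., collected during prefill); their
# distribution differs from the keys', which is exactly the OOD situation.
proto_queries = rng.normal(loc=0.3, size=(n_prototypes, d)).astype(np.float32)

# Build: assign each key to the prototype query that attends to it most strongly,
# i.e., an inverted list organized around the *query* distribution.
owner = np.argmax(keys @ proto_queries.T, axis=1)
buckets = [np.where(owner == b)[0] for b in range(n_prototypes)]

def retrieve(q, n_buckets=4):
    """Scan only the buckets whose prototype queries best match the new query."""
    best = np.argsort(-(proto_queries @ q))[:n_buckets]
    return np.concatenate([buckets[b] for b in best])

q = rng.normal(loc=0.3, size=d).astype(np.float32)   # a new decoding-time query
candidates = retrieve(q)
print(f"scanned about {len(candidates) / n_keys:.1%} of the keys")
```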

RetrievalAttention also makes efficient use of GPU and CPU memory. Specifically, the most important key-value vectors are kept on the GPU, while the remaining data is offloaded to CPU memory, reducing GPU memory consumption while maintaining computational efficiency.
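Splitting the tokens between GPU and CPU requires combining the two partial attention results at the end. A standard way to do this exactly is the log-sum-exp merge sketched below; this is a generic technique shown under assumed shapes, the "recent tokens on GPU, the rest on CPU" split is an illustrative assumption, and in RetrievalAttention the CPU side would hold only the retrieved candidates.

```python
import numpy as np

def partial_attention(q, K, V):
    """Attention over a subset of tokens, plus the softmax statistics
    (running max and normalizer) needed to merge with another subset."""
    s = K @ q / np.sqrt(q.shape[-1])
    m = s.max()
    e = np.exp(s - m)
    z = e.sum()
    return e @ V / z, m, z

def merge(o1, m1, z1, o2, m2, z2):
    """Exactly combine attention computed over two disjoint token sets."""
    m = max(m1, m2)
    w1, w2 = np.exp(m1 - m) * z1, np.exp(m2 - m) * z2
    return (w1 * o1 + w2 * o2) / (w1 + w2)

d, n = 128, 10_000
rng = np.random.default_rng(0)
q = rng.normal(size=d)
K, V = rng.normal(size=(n, d)), rng.normal(size=(n, d))

# Illustrative split: the most recent 512 tokens stay on the GPU, the rest
# live in CPU memory (and would normally be filtered down by retrieval).
gpu_ids, cpu_ids = np.arange(n - 512, n), np.arange(n - 512)
o_gpu = partial_attention(q, K[gpu_ids], V[gpu_ids])
o_cpu = partial_attention(q, K[cpu_ids], V[cpu_ids])

full, _, _ = partial_attention(q, K, V)      # reference: attention over everything
print("max deviation from full attention:", np.abs(merge(*o_gpu, *o_cpu) - full).max())
```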

RetrievalAttention" uses two big ideas to streamline this attention calculation

Leveraging dynamic sparsity

In the attention calculation, not all tokens are equally important; in fact, only some tokens play an important role in predicting the next token. This is called "dynamic sparsity." RetrievalAttention takes advantage of this property, focusing only on the important tokens and omitting the others.

Optimization by vector search

Next, a technique called approximate nearest neighbor search (ANNS) is used to select the most important tokens approximately, rather than scoring every token. ANNS finds the most relevant vectors in a huge collection at high speed, significantly reducing the amount of computation compared to the usual attention calculation.
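As a concrete sketch of this retrieval step, the snippet below uses an off-the-shelf inner-product IVF index from the FAISS library (assumed to be installed, e.g. as faiss-cpu) to fetch candidate tokens and then runs attention over those candidates only. The paper builds its own attention-aware index rather than relying on a generic one, so this only illustrates the general workflow; all sizes are made up.

```python
import numpy as np
import faiss  # generic ANN library, used here purely for illustration

d, n, k = 128, 100_000, 512
rng = np.random.default_rng(0)
keys = rng.normal(size=(n, d)).astype("float32")
values = rng.normal(size=(n, d)).astype("float32")
q = rng.normal(size=(1, d)).astype("float32")

# Inverted-file index with inner-product metric (attention scores are dot products).
quantizer = faiss.IndexFlatIP(d)
index = faiss.IndexIVFFlat(quantizer, d, 1024, faiss.METRIC_INNER_PRODUCT)
index.train(keys)
index.add(keys)
index.nprobe = 16                      # scan only a small fraction of the clusters

scores, ids = index.search(q, k)       # approximate top-k tokens for this query
ids = ids[0][ids[0] >= 0]              # drop -1 padding if fewer than k were found
s = scores[0][: len(ids)] / np.sqrt(d)
w = np.exp(s - s.max()); w /= w.sum()  # softmax over the retrieved tokens only
out = w @ values[ids]                  # approximate attention output
print("attended over", len(ids), "of", n, "tokens; output shape:", out.shape)
```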

Experiment

The experiments in this paper test how effectively RetrievalAttention improves the inference efficiency of large language models (LLMs) that handle long contexts. Several large models and benchmarks are used to evaluate in detail how well the proposed method performs.

First, the experimental environment used an NVIDIA RTX 4090 GPU (24 GB memory) and several LLMs such as Llama-3-8B and Yi-6B. Each of these models has the ability to process long contexts of up to 128,000 tokens. The goal of the experiment is to see how much faster RetrievalAttention can be compared to other methods while maintaining inference accuracy.

The proposed method was evaluated in terms of both accuracy and speed. In terms of accuracy, RetrievalAttention performed almost as well as full attention: by efficiently extracting only the important tokens, the computational cost is reduced without affecting the model's outputs. This is confirmed on the ∞-Bench benchmark.

On the other hand, there was also a significant improvement in inference speed. When processing a 128,000-token context, RetrievalAttention was nearly five times faster at inference than conventional full attention. This speedup comes from drastically reducing accesses to unneeded tokens; in the needle-in-a-haystack task, which requires extracting a specific piece of information from a vast amount of text, RetrievalAttention's retrieval performance is particularly striking.

RetrievalAttention also excels in terms of GPU memory usage. Normally, processing long contexts requires a huge amount of memory, but the proposed method is designed to process 128,000 tokens with 16 GB of GPU memory, which allows for efficient inference while keeping hardware costs low.

Thus, RetrievalAttention is a method that significantly improves inference speed and memory efficiency while maintaining accuracy, and has demonstrated practical performance, especially in tasks involving long contexts.

Summary

The conclusion of this paper is that the proposed RetrievalAttention method is very effective at streamlining the inference of large language models (LLMs) that deal with long contexts.

Normal attention calculation tends to be time- and memory-intensive because of the huge number of tokens involved; in particular, its cost grows quadratically with context length. RetrievalAttention solves this problem by exploiting the "dynamic sparsity" of attention and processing only the necessary portions.

The method dynamically selects the important tokens and achieves comparable accuracy with far less computation and memory than conventional methods. Experimental results also show that RetrievalAttention significantly improves inference speed on long-context tasks, by up to about five times.

In addition, RetrievalAttention can efficiently process long contexts even with limited GPU memory, thus reducing memory usage. This feature opens up the possibility of accomplishing tasks that previously required very expensive hardware, even in more affordable environments.

In conclusion, RetrievalAttention is a powerful technique for maintaining accuracy while significantly reducing inference costs for long contexts, and it is very important for the future development of LLMs.


If you have any suggestions for improvement of the content of the article,
please contact the AI-SCHOLAR editorial team through the contact form.

Contact Us