AI-Master-Book
  • about AI-Master-Book
  • AI Master Book
    • 이상치 탐지 with Python
    • 베이지안 뉴럴네트워크 (BNN) with Python
    • 그래프 뉴럴네트워크 (GNN) with Python
    • 데이터 마케팅 분석 with Python
  • LLM MASTER BOOK
    • OpenAI API 쿡북 with Python
    • 기초부터 심화까지 RAG 쿡북 with Python
    • MCP 에이전트 쿡북 with Python
  • LLMs
    • OpenAI API
      • 1️⃣ChatCompletion
      • 2️⃣DALL-E
      • 3️⃣Text to Speech
      • 4️⃣Video to Transcripts
      • 5️⃣Assistants API
      • 6️⃣Prompt Engineering
      • 7️⃣OpenAI New GPT-4o
    • LangChain
      • LangChain Basic
        • 1️⃣Basic Modules
        • 2️⃣Model I/O
        • 3️⃣Prompts
        • 4️⃣Chains
        • 5️⃣Agents
        • 6️⃣Tools
        • 7️⃣Memory
      • LangChain Intermediate
        • 1️⃣OpenAI LLM
        • 2️⃣Prompt Template
        • 3️⃣Retrieval
        • 4️⃣RAG ChatBot
        • 5️⃣RAG with Gemini
        • 6️⃣New Huggingface-LangChain
        • 7️⃣Huggingface Hub
        • 8️⃣SQL Agent & Chain
        • 9️⃣Expression Language(LCEL)
        • 🔟Llama3-8B with LangChain
      • LangChain Advanced
        • 1️⃣LLM Evaluation
        • 2️⃣RAG Evaluation with RAGAS
        • 3️⃣LangChain with RAGAS
        • 4️⃣RAG Paradigms
        • 5️⃣LangChain: Advance Techniques
        • 6️⃣LangChain with NeMo-Guardrails
        • 7️⃣LangChain vs. LlamaIndex
        • 8️⃣LangChain LCEL vs. LangGraph
    • LlamaIndex
      • LlamaIndex Basic
        • 1️⃣Introduction
        • 2️⃣Customization
        • 3️⃣Data Connectors
        • 4️⃣Documents & Nodes
        • 5️⃣Naive RAG
        • 6️⃣Advanced RAG
        • 7️⃣Llama3-8B with LlamaIndex
        • 8️⃣LlmaPack
      • LlamaIndex Intermediate
        • 1️⃣QueryEngine
        • 2️⃣Agent
        • 3️⃣Evaluation
        • 4️⃣Evaluation-Driven Development
        • 5️⃣Fine-tuning
        • 6️⃣Prompt Compression with LLMLingua
      • LlamaIndex Advanced
        • 1️⃣Agentic RAG: Router Engine
        • 2️⃣Agentic RAG: Tool Calling
        • 3️⃣Building Agent Reasoning Loop
        • 4️⃣Building Multi-document Agent
    • Hugging Face
      • Huggingface Basic
        • 1️⃣Datasets
        • 2️⃣Tokenizer
        • 3️⃣Sentence Embeddings
        • 4️⃣Transformers
        • 5️⃣Sentence Transformers
        • 6️⃣Evaluate
        • 7️⃣Diffusers
      • Huggingface Tasks
        • NLP
          • 1️⃣Sentiment Analysis
          • 2️⃣Zero-shot Classification
          • 3️⃣Aspect-Based Sentiment Analysis
          • 4️⃣Feature Extraction
          • 5️⃣Intent Classification
          • 6️⃣Topic Modeling: BERTopic
          • 7️⃣NER: Token Classification
          • 8️⃣Summarization
          • 9️⃣Translation
          • 🔟Text Generation
        • Audio & Tabular
          • 1️⃣Text-to-Speech: TTS
          • 2️⃣Speech Recognition: Whisper
          • 3️⃣Audio Classification
          • 4️⃣Tabular Qustaion & Answering
        • Vision & Multimodal
          • 1️⃣Image-to-Text
          • 2️⃣Text to Image
          • 3️⃣Image to Image
          • 4️⃣Text or Image-to-Video
          • 5️⃣Depth Estimation
          • 6️⃣Image Classification
          • 7️⃣Object Detection
          • 8️⃣Segmentatio
      • Huggingface Optimization
        • 1️⃣Accelerator
        • 2️⃣Bitsandbytes
        • 3️⃣Flash Attention
        • 4️⃣Quantization
        • 5️⃣Safetensors
        • 6️⃣Optimum-ONNX
        • 7️⃣Optimum-NVIDIA
        • 8️⃣Optimum-Intel
      • Huggingface Fine-tuning
        • 1️⃣Transformer Fine-tuning
        • 2️⃣PEFT Fine-tuning
        • 3️⃣PEFT: Fine-tuning with QLoRA
        • 4️⃣PEFT: Fine-tuning Phi-2 with QLoRA
        • 5️⃣Axoltl Fine-tuning with QLoRA
        • 6️⃣TRL: RLHF Alignment Fine-tuning
        • 7️⃣TRL: DPO Fine-tuning with Phi-3-4k-instruct
        • 8️⃣TRL: ORPO Fine-tuning with Llama3-8B
        • 9️⃣Convert GGUF gemma-2b with llama.cpp
        • 🔟Apple Silicon Fine-tuning Gemma-2B with MLX
        • 🔢LLM Mergekit
    • Agentic LLM
      • Agentic LLM
        • 1️⃣Basic Agentic LLM
        • 2️⃣Multi-agent with CrewAI
        • 3️⃣LangGraph: Multi-agent Basic
        • 4️⃣LangGraph: Agentic RAG with LangChain
        • 5️⃣LangGraph: Agentic RAG with Llama3-8B by Groq
      • Autonomous Agent
        • 1️⃣LLM Autonomous Agent?
        • 2️⃣AutoGPT: Worldcup Winner Search with LangChain
        • 3️⃣BabyAGI: Weather Report with LangChain
        • 4️⃣AutoGen: Writing Blog Post with LangChain
        • 5️⃣LangChain: Autonomous-agent Debates with Tools
        • 6️⃣CAMEL Role-playing Autonomous Cooperative Agents
        • 7️⃣LangChain: Two-player Harry Potter D&D based CAMEL
        • 8️⃣LangChain: Multi-agent Bid for K-Pop Debate
        • 9️⃣LangChain: Multi-agent Authoritarian Speaker Selection
        • 🔟LangChain: Multi-Agent Simulated Environment with PettingZoo
    • Multimodal
      • 1️⃣PaliGemma: Open Vision LLM
      • 2️⃣FLUX.1: Generative Image
    • Building LLM
      • 1️⃣DSPy
      • 2️⃣DSPy RAG
      • 3️⃣DSPy with LangChain
      • 4️⃣Mamba
      • 5️⃣Mamba RAG with LangChain
      • 7️⃣PostgreSQL VectorDB with pgvorco.rs
Powered by GitBook
On this page
  • About Flash Attention-2
  • Flash Attention
  • Memory Test
  • 1. Standard Inference
  • 2. Inference with Bitsandbytes
  • 3. Inference with Bitsandbytes, FlashAttention2
  • Conclusion
  1. LLMs
  2. Hugging Face
  3. Huggingface Optimization

Flash Attention

PreviousBitsandbytesNextQuantization

Last updated 1 year ago

About Flash Attention-2

GPU는 메모리 대역폭과 병렬 처리에 최적화되어 있기 때문에 CPU와 달리 머신 러닝을 위한 표준 하드웨어로 선택됩니다. 최신 모델의 더 큰 크기를 따라잡거나 기존 및 구형 하드웨어에서 이러한 대형 모델을 실행하기 위해 GPU 추론 속도를 높이는 데 사용할 수 있는 몇 가지 최적화 방법이 있습니다.

Flash Attention(현재 FlashAttention-2)는 메모리 효율이 높은 주의 메커니즘으로 BetterTransformer(파이토치 기본 빠른 경로 실행)와 Bitsandbytes를 사용하여 모델을 더 낮은 정밀도로 정량화합니다.

FlashAttention-2는 standard attention 메커니즘을 더 빠르고 효율적으로 구현한 것으로, 추론 속도를 크게 높일 수 있습니다:

  • 시퀀스 길이에 대한 주의 계산을 추가로 병렬화합니다.

  • GPU 스레드 간 작업을 분할하여 스레드 간 통신 및 공유 메모리 읽기/쓰기를 감소시킵니다.

논문에서 40GB GPU의 타일형 플래시어텐션 계산 패턴과 메모리 계층구조를 보여줍니다. 오른쪽 차트는 어텐션 메커니즘의 여러 구성 요소를 융합하고 재정렬하여 얻을 수 있는 상대적인 속도 향상을 보여줍니다.

Flash Attention

설치 방법은 아래와 같습니다. 설치를 하면 Transformer 라이브러리와 함께 사용할 수 있습니다.

%pip install flash-attn --no-build-isolation
import torch
import bitsandbytes as bnb
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

Memory Test

Flash Attention-2를 사용했을 때 실제 GPU에서 메모리가 어떻게 최적화되고 소모하는지 결과를 비교하겠습니다. 아래는 테스트 조건 입니다:

  • Model: Mistral-7B-v0.3

  • Batchsize: 2

  • Sequence Lenght: 3072

테스트는 3가지 조건으로 실행할 예정입니다.

  1. AutoModelForCausalML

  2. AutoModelForCausalML + BitsAndBytes

  3. AutoModelForCausalML + BitsAndBytes + flash_attention_2

1. Standard Inference

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.3",
    device_map="auto",
    use_flash_attention_2=False,
)

unit_gib = 1024 ** 3

curr_mem = torch.cuda.memory_reserved() / unit_gib
print(f"Initial: {curr_mem:.4f} GiB")

batch_size = 2
seq_len = 3072
inn = torch.LongTensor([[0] * seq_len] * batch_size)

try:
    with torch.no_grad():
        out = model(inn)

    infer_mem = torch.cuda.memory_reserved() / unit_gib
    print(f"Inference: {infer_mem:.4f} GiB")
except:
    print(f"Inference: OOM")
Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Initial: 7.5684 GiB
Inference: 8.7402 GiB

2. Inference with Bitsandbytes

bnb_config = BitsAndBytesConfig(  
    load_in_4bit= True,
    bnb_4bit_quant_type= "nf4",
    bnb_4bit_compute_dtype= torch.bfloat16,
    bnb_4bit_use_double_quant= False,
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.3",
    device_map="auto",
    use_flash_attention_2=False,
)

unit_gib = 1024 ** 3

curr_mem = torch.cuda.memory_reserved() / unit_gib
print(f"Initial: {curr_mem:.4f} GiB")

batch_size = 2
seq_len = 3072
inn = torch.LongTensor([[0] * seq_len] * batch_size)

try:
    with torch.no_grad():
        out = model(inn)

    infer_mem = torch.cuda.memory_reserved() / unit_gib
    print(f"Inference: {infer_mem:.4f} GiB")
except:
    print(f"Inference: OOM")
Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Initial: 7.5527 GiB
Inference: 8.7246 GiB

3. Inference with Bitsandbytes, FlashAttention2

bnb_config = BitsAndBytesConfig(  
    load_in_4bit= True,
    bnb_4bit_quant_type= "nf4",
    bnb_4bit_compute_dtype= torch.bfloat16,
    bnb_4bit_use_double_quant= False,
)

model: ModelCls = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.3",
    device_map="auto",
    quantization_config=bnb_config,
    use_flash_attention_2=True,
)

unit_gib = 1024 ** 3

curr_mem = torch.cuda.memory_reserved() / unit_gib
print(f"Initial: {curr_mem:.4f} GiB")

batch_size = 2
seq_len = 3072
inn = torch.LongTensor([[0] * seq_len] * batch_size)

try:
    with torch.no_grad():
        out = model(inn)

    infer_mem = torch.cuda.memory_reserved() / unit_gib
    print(f"Inference: {infer_mem:.4f} GiB")
except:
    print(f"Inference: OOM")
Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Initial: 9.3535 GiB
Inference: OOM

Conclusion

첫번째 일반적인 AutoModelForCausalML로 모델을 불러와 Inference 했을 때 8.742G의 GPU 메모리를 사용했습니다. 두번째 의 조합을 사용했을 때 8.724G의 메모리를 사용했습니다. 중요한 사실은 Initial Memory로 초기에 GPU 메모리 확보 입니다. 두 조건 모두 차이가 없습니다.

이제 마지막으로 flash_attention_2=True 로 했을 때 입니다. Initial Memory는 9.3535G로 일반적인 조건보다 더 많이 확보했고, Inference 후에는 0으로 GPU 공간을 잡아먹지 않았습니다.

Train이나 Inference 시 Flash Attention을 사용하면 메모리 최적화를 이룰 수 있습니다.

3️⃣
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness