From 65806044fa39d3e1a728845a15ae9ce8d0f7bd2c Mon Sep 17 00:00:00 2001 From: Simran Arora Date: Mon, 14 Oct 2024 05:53:29 -0700 Subject: [PATCH] update demos --- README.md | 2 +- ...ill_alpaca_clean_xent0_mse1000_lr1e-2.yaml | 2 +- configs/experiment/eval_alpaca_clean.yaml | 2 +- .../finetune_lora_qkvo_alpaca_clean.yaml | 2 +- ...ill_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml | 2 +- ...ill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml | 2 +- demos/README.md | 1 + demos/benchmark_8b.sh | 22 +- demos/demo_8b.sh | 17 +- demos/vllm_integration/README.md | 82 -- demos/vllm_integration/vllm_files/__init__.py | 205 ---- demos/vllm_integration/vllm_files/lolcats.py | 792 --------------- .../vllm_files/lolcats_inference.py | 870 ----------------- .../vllm_files/lolcats_inference_paged.py | 912 ------------------ .../vllm_files/test_vllm_aw.py | 82 -- src/model/convert_model.py | 6 +- src/model/linear_attention/__init__.py | 4 + .../linear_window_attention_tk_gen.py | 28 +- 18 files changed, 56 insertions(+), 2977 deletions(-) delete mode 100644 demos/vllm_integration/README.md delete mode 100644 demos/vllm_integration/vllm_files/__init__.py delete mode 100644 demos/vllm_integration/vllm_files/lolcats.py delete mode 100644 demos/vllm_integration/vllm_files/lolcats_inference.py delete mode 100644 demos/vllm_integration/vllm_files/lolcats_inference_paged.py delete mode 100644 demos/vllm_integration/vllm_files/test_vllm_aw.py diff --git a/README.md b/README.md index 83c97c5..1025af1 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,7 @@ We also implemented a fused linear attention + sliding window kernel with the [T For the linearizng layer, see [`./src/model/linear_attention/linear_window_attention_tk_gen.py`](https://github.com/HazyResearch/lolcats/blob/main/src/model/linear_attention/linear_window_attention_tk_gen.py) -You can install the kernel and benchmark 8B models (LoLCATS linearized and Llama Transformer) with and without our ThunderKittens CUDA kernel using the details [in this README.md](). Our 8B model will auto-download from our [HuggingFace checkpoint](https://huggingface.co/hazyresearch/lolcats-llama-3.1-8b-distill). +You can install the kernel and benchmark 8B models (LoLCATS linearized and Llama Transformer) with and without our ThunderKittens CUDA kernel using the details [in this README.md](). Our 8B model will auto-download from our [HuggingFace checkpoint](https://huggingface.co/hazyresearch/lolcats-llama-3.1-8b-distill). ### More! 
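The hybrid attention that the linearizing layer (and the fused ThunderKittens kernel) computes is easiest to see in plain PyTorch. The sketch below is a simplified, non-authoritative reading of the `superlinear_attention` reference code removed later in this patch: softmax attention inside a 64-token sliding window, feature-map (linear) attention outside it, both combined under a single normalizer. The function name, the toy feature map, and the tensor shapes are illustrative assumptions, not part of the repository's API.

```python
# Minimal sketch of the window + linear hybrid attention (simplified from the
# superlinear_attention code deleted later in this patch; shapes, the toy
# feature map, and the sigmoid-gated window factor are illustrative).
import math
import torch


def hybrid_attention(q, k, v, feature_map, window_size=64, mask_value=-1e8):
    """q, k, v: (batch, heads, seq_len, head_dim); feature_map: head_dim -> feature_dim."""
    b, h, n, d = q.shape

    # Causal mask split into a sliding-window band (softmax) and the rest (linear).
    blocks = math.ceil(n / window_size)
    band = torch.block_diag(*[torch.ones(window_size, window_size)] * blocks)
    band = (band + torch.roll(band, -window_size, -1))[:n, :n].to(q.device)
    causal = torch.tril(torch.ones(n, n, device=q.device)).bool()
    mask_window = causal & band.bool()
    mask_linear = causal & ~band.bool()

    # 1. Sliding-window softmax scores (max-subtracted, normalized at the end).
    window_factor = torch.sigmoid(torch.ones(1, h, 1, 1, device=q.device))
    a_sm = torch.einsum('bhmd,bhnd->bhmn', q, k) * d ** -0.5
    a_sm = a_sm.masked_fill(~mask_window, mask_value)
    a_sm = window_factor * torch.exp(a_sm - a_sm.amax(-1, keepdim=True))

    # 2. Linear-attention scores through the feature map, outside the window.
    f_q, f_k = feature_map(q), feature_map(k)
    a_ln = torch.einsum('bhmf,bhnf->bhmn', f_q, f_k).masked_fill(~mask_linear, 0)

    # 3. Combine both attention terms under one shared normalizer.
    y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v)
    return y / (a_sm.sum(-1, keepdim=True) + a_ln.sum(-1, keepdim=True))


if __name__ == "__main__":
    q = k = v = torch.randn(1, 8, 256, 64)
    fmap = lambda x: torch.cat([x.softmax(-1), (-x).softmax(-1)], -1).clamp(min=1e-12)
    print(hybrid_attention(q, k, v, fmap).shape)  # torch.Size([1, 8, 256, 64])
```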
diff --git a/configs/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml b/configs/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml index 29e1db6..e71170d 100644 --- a/configs/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml +++ b/configs/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml @@ -8,7 +8,7 @@ dataset: cache_dir: 'data/alpaca' # Change this to where you want to save pretrained_model_config: # will be updated based on model_config pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3-8B' - cache_dir: '/scr-ssd/mzhang/models/llama3' + cache_dir: '/scratch/' preprocess_config: null dataloader: diff --git a/configs/experiment/eval_alpaca_clean.yaml b/configs/experiment/eval_alpaca_clean.yaml index 5836813..af1005a 100644 --- a/configs/experiment/eval_alpaca_clean.yaml +++ b/configs/experiment/eval_alpaca_clean.yaml @@ -8,7 +8,7 @@ dataset: cache_dir: 'data/alpaca' # Change this to where you want to save pretrained_model_config: pretrained_model_name_or_path: 'mistralai/Mistral-7B-v0.1' # will be updated based on model_config - cache_dir: '/scr-ssd/mzhang/models/mistral-v0.1' + cache_dir: '/scratch/' preprocess_config: null dataloader: diff --git a/configs/experiment/finetune_lora_qkvo_alpaca_clean.yaml b/configs/experiment/finetune_lora_qkvo_alpaca_clean.yaml index 8c53d34..10323a0 100644 --- a/configs/experiment/finetune_lora_qkvo_alpaca_clean.yaml +++ b/configs/experiment/finetune_lora_qkvo_alpaca_clean.yaml @@ -8,7 +8,7 @@ dataset: cache_dir: "data/alpaca" pretrained_model_config: pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" # will be updated based on model_config - cache_dir: "/data_persistent2/sim_data/" + cache_dir: "/scratch/" preprocess_config: null dataloader: diff --git a/configs/model/distill_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml b/configs/model/distill_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml index 8d7d44c..0313f3a 100644 --- a/configs/model/distill_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml +++ b/configs/model/distill_llama3_1_8b_lk_smd_wsw64_fd64_w01.yaml @@ -1,7 +1,7 @@ name: llama model: pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B" - cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights + cache_dir: "/scratch/" # Set this to where you want to save checkpoint weights return_dict: true load_in_8bit: false load_in_4bit: false diff --git a/configs/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml b/configs/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml index f3d04b5..7bc9a4d 100644 --- a/configs/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml +++ b/configs/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml @@ -1,7 +1,7 @@ name: llama model: pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B" - cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights + cache_dir: "/scratch/" # Set this to where you want to save checkpoint weights return_dict: true load_in_8bit: false load_in_4bit: false diff --git a/demos/README.md b/demos/README.md index c88d4a3..d64226a 100644 --- a/demos/README.md +++ b/demos/README.md @@ -32,6 +32,7 @@ As a quick end-to-end compare the prefill speed of the linearized LoLCATS 8B vs. ```bash bash benchmark_8b.sh ``` +Our benchmarking implementation is currently restricted to prefill lengths that are multiples of 64. The code will print out the inference tokens per second per method. 
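Because the benchmarking path above only supports prefill lengths that are multiples of 64, it can help to pad prompts before timing them. Below is a minimal, illustrative sketch; the tokenizer checkpoint is the one named in the configs in this patch, while the padding helper and choice of pad token are assumptions, not part of the repository.

```python
# Illustrative helper (assumption): round a tokenized prompt up to a multiple
# of 64 so it satisfies the benchmark's prefill-length restriction.
from transformers import AutoTokenizer

def pad_to_multiple(token_ids, multiple=64, pad_id=0):
    remainder = len(token_ids) % multiple
    if remainder:
        token_ids = token_ids + [pad_id] * (multiple - remainder)
    return token_ids

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
ids = tokenizer("The quick brown fox jumps over the lazy dog.")["input_ids"]
padded = pad_to_multiple(ids, multiple=64, pad_id=tokenizer.eos_token_id)
assert len(padded) % 64 == 0
```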
diff --git a/demos/benchmark_8b.sh b/demos/benchmark_8b.sh index a62369d..4bb4104 100644 --- a/demos/benchmark_8b.sh +++ b/demos/benchmark_8b.sh @@ -1,14 +1,14 @@ -CONFIG_DIR='/home/bfs/simran/clean4/lolcats/configs/' # update to your path +CONFIG_DIR='/home/bfs/simran/attention/lolcats/configs/' # update to your path -""" Benchmarking the 8b model on the LOLCATS dataset """ +# """ Benchmarking the 8b model on the LOLCATS dataset """ # Run the linearized model with the ThunderKittens kernel CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \ - --model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ - --distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \ - --finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \ + --model_config_path ${CONFIG_DIR}/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ + --distill_config_path ${CONFIG_DIR}/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \ + --finetune_config_path ${CONFIG_DIR}/experiment/finetune_lora_qkvo_alpaca_clean.yaml \ --attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \ --finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \ --num_generations 1 \ @@ -18,9 +18,9 @@ CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \ # Run the linearized model *without* the ThunderKittens kernel CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \ - --model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ - --distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \ - --finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \ + --model_config_path ${CONFIG_DIR}/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ + --distill_config_path ${CONFIG_DIR}/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \ + --finetune_config_path ${CONFIG_DIR}/experiment/finetune_lora_qkvo_alpaca_clean.yaml \ --attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \ --finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \ --num_generations 1 \ @@ -30,9 +30,9 @@ CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \ # Run the base Llama model with Transformers SDPA attention CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \ - --model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ - --distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \ - --finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \ + --model_config_path ${CONFIG_DIR}/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ + --distill_config_path ${CONFIG_DIR}/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \ + --finetune_config_path ${CONFIG_DIR}/experiment/finetune_lora_qkvo_alpaca_clean.yaml \ --attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \ --finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \ --num_generations 1 \ diff --git a/demos/demo_8b.sh b/demos/demo_8b.sh index a1e1b9a..c278519 100644 --- a/demos/demo_8b.sh +++ b/demos/demo_8b.sh @@ -1,16 +1,19 @@ -CONFIG_DIR='/home/bfs/simran/clean4/lolcats/configs/' # update to your path +CONFIG_DIR='/home/bfs/simran/attention/lolcats/configs/' # update to your path -# using huggingface checkpoints 
+# Using huggingface checkpoints CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \ - --model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ - --distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \ - --finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \ + --model_config_path ${CONFIG_DIR}/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ + --distill_config_path ${CONFIG_DIR}/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \ + --finetune_config_path ${CONFIG_DIR}/experiment/finetune_lora_qkvo_alpaca_clean.yaml \ --attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \ --finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \ - --num_generations 1 + --num_generations 1 \ + --max_new_tokens 50 -# if you train your own LoLCATS weights, you can use the following command to run inference: + +# Reference script: +# if you train your own LoLCATS weights, you can use the following command to run inference with your local checkpoints: # CHECKPOINT_DIR='/home/mzhang/projects/lolcats/checkpoints/' # CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \ # --model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ diff --git a/demos/vllm_integration/README.md b/demos/vllm_integration/README.md deleted file mode 100644 index 56afa27..0000000 --- a/demos/vllm_integration/README.md +++ /dev/null @@ -1,82 +0,0 @@ -## Coming Soon! VLLM Integration - - -#### 1. Clone VLLM -Also run VLLM installations. -```bash -git clone https://github.com/vllm-project/vllm -``` - -#### 2. Copy the following LoLCATS specific files into vllm. - -``` -bash -cp lolcats/inference/vllm_files/lolcats.py vllm/model_executor/models/lolcats.py -``` - -And add the new LoLCATS models from: -```bash -lolcats/inference/vllm_files/__init__.py -> vllm/model_executor/models/__init__.py -``` - -#### 3. Set the model checkpoint paths. - -Given your local download of the 405B weights, go to the ```Meta-Llama-3.1-405B/config.py``` file and modify the architecture list from ```LlamaForCausalLM``` to ```LlamaLolcatsForCausalLM```. - -In ```vllm/model_executor/models/lolcats_inference_paged.py``` set the ```PATH=....pt``` to the name of your copy of the linearized weights (feature maps and LoRA). - -#### 4. Run VLLM. - -These instructions assume you have 2 nodes of $8 \times 80$GB to fit the FP16 405B model. You are okay with 1 node for 70B parameters. -```bash - -# Step 1. Follow the VLLM installation quick start to install it in your environment. - -# Step 2. Set up a 2 node ray cluster. On the respective nodes, run: -ray start --head # on node 1 -ray start --address='ip from above' # on node 2 - -# Step 3. Load the model on the 2 nodes, creating an OpenAI endpoint. Remember to hard code the ckpt paths in lolcats.py PATH (cant use env variable on multinode). Set tensor-parallel-size to 8 if using 1 node. Run this on the head node (node 1). -vllm serve /path/to/hf/model/Meta-Llama-3.1-405B --tensor-parallel-size 16 --enforce-eager # on node 1 -``` - -#### 5. Clone LM-Eval harness and run inference evaluations: -```bash -git clone https://github.com/EleutherAI/lm-evaluation-harness -git checkout b281b092 -pip install -e .[api] -``` - -Note that if ```datasets.load_datasets``` gives an issue, it helps to ```pip install -U datasets```. 
- -Launch the evaluation commands on node 1 (the head node of the ray cluster). -```bash -lm_eval --model local-completions --tasks piqa,hellaswag,winogrande,arc_challenge,arc_easy --model_args model='/path/to/hf/model/Meta-Llama-3.1-405B',base_url=http://localhost:8000/v1/completions,num_concurrent=1,max_retries=3,tokenized_requests=False --batch_size 1 --output save/ - -lm_eval --model local-completions --tasks mmlu --num_fewshot 5 --model_args model='/path/to/hf/model/Meta-Llama-3.1-405B',base_url=http://localhost:8000/v1/completions,num_concurrent=1,max_retries=3,tokenized_requests=False --batch_size 1 --output save/ -``` - -#### References -Please cite the following if you use their work: - -``` -@misc{eval-harness, - author = {Gao, Leo and Tow, Jonathan and Abbasi, Baber and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and Le Noac'h, Alain and Li, Haonan and McDonell, Kyle and Muennighoff, Niklas and Ociepa, Chris and Phang, Jason and Reynolds, Laria and Schoelkopf, Hailey and Skowron, Aviya and Sutawika, Lintang and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy}, - title = {A framework for few-shot language model evaluation}, - month = 07, - year = 2024, - publisher = {Zenodo}, - version = {v0.4.3}, - doi = {10.5281/zenodo.12608602}, - url = {https://zenodo.org/records/12608602} -} -``` - -``` -@inproceedings{kwon2023efficient, - title={Efficient Memory Management for Large Language Model Serving with PagedAttention}, - author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica}, - booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles}, - year={2023} -} -``` diff --git a/demos/vllm_integration/vllm_files/__init__.py b/demos/vllm_integration/vllm_files/__init__.py deleted file mode 100644 index 6953dab..0000000 --- a/demos/vllm_integration/vllm_files/__init__.py +++ /dev/null @@ -1,205 +0,0 @@ -import functools -import importlib -from typing import Dict, List, Optional, Tuple, Type - -import torch.nn as nn - -from vllm.logger import init_logger -from vllm.utils import is_hip - -logger = init_logger(__name__) - -_GENERATION_MODELS = { - "AquilaModel": ("llama", "LlamaForCausalLM"), - "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 - "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b - "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b - "BloomForCausalLM": ("bloom", "BloomForCausalLM"), - "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), - "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), - "CohereForCausalLM": ("commandr", "CohereForCausalLM"), - "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"), - "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), - "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), - "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), - "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), - "FalconForCausalLM": ("falcon", "FalconForCausalLM"), - "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), - "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), - "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), - "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), - "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), - "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), - "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), - 
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), - "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), - "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), - # For decapoda-research/llama-* - "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), - "MistralForCausalLM": ("llama", "LlamaForCausalLM"), - "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), - "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), - # transformers's mpt class has lower case - "MptForCausalLM": ("mpt", "MPTForCausalLM"), - "MPTForCausalLM": ("mpt", "MPTForCausalLM"), - "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), - "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), - "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), - "OPTForCausalLM": ("opt", "OPTForCausalLM"), - "OrionForCausalLM": ("orion", "OrionForCausalLM"), - "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), - "PhiForCausalLM": ("phi", "PhiForCausalLM"), - "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), - "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), - "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), - "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), - "RWForCausalLM": ("falcon", "FalconForCausalLM"), - "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), - "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), - "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), - "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), - "XverseForCausalLM": ("xverse", "XverseForCausalLM"), - "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), - "MedusaModel": ("medusa", "Medusa"), - "EAGLEModel": ("eagle", "EAGLE"), - "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), - "JambaForCausalLM": ("jamba", "JambaForCausalLM"), - "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), - "LlamaLolcatsForCausalLM": ("lolcats", "LlamaLolcatsForCausalLM") -} - -_EMBEDDING_MODELS = { - "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), -} - -_MULTIMODAL_MODELS = { - "Blip2ForConditionalGeneration": - ("blip2", "Blip2ForConditionalGeneration"), - "ChameleonForConditionalGeneration": - ("chameleon", "ChameleonForConditionalGeneration"), - "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), - "InternVLChatModel": ("internvl", "InternVLChatModel"), - "LlavaForConditionalGeneration": - ("llava", "LlavaForConditionalGeneration"), - "LlavaNextForConditionalGeneration": - ("llava_next", "LlavaNextForConditionalGeneration"), - "MiniCPMV": ("minicpmv", "MiniCPMV"), - "PaliGemmaForConditionalGeneration": ("paligemma", - "PaliGemmaForConditionalGeneration"), - "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), - "UltravoxModel": ("ultravox", "UltravoxModel"), - "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), -} -_CONDITIONAL_GENERATION_MODELS = { - "BartModel": ("bart", "BartForConditionalGeneration"), - "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), -} - -_MODELS = { - **_GENERATION_MODELS, - **_EMBEDDING_MODELS, - **_MULTIMODAL_MODELS, - **_CONDITIONAL_GENERATION_MODELS, -} - -# Architecture -> type. -# out of tree models -_OOT_MODELS: Dict[str, Type[nn.Module]] = {} - -# Models not supported by ROCm. -_ROCM_UNSUPPORTED_MODELS: List[str] = [] - -# Models partially supported by ROCm. -# Architecture -> Reason. -_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in " - "Triton flash attention. 
For half-precision SWA support, " - "please use CK flash attention by setting " - "`VLLM_USE_TRITON_FLASH_ATTN=0`") -_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { - "Qwen2ForCausalLM": - _ROCM_SWA_REASON, - "MistralForCausalLM": - _ROCM_SWA_REASON, - "MixtralForCausalLM": - _ROCM_SWA_REASON, - "PaliGemmaForConditionalGeneration": - ("ROCm flash attention does not yet " - "fully support 32-bit precision on PaliGemma"), - "Phi3VForCausalLM": - ("ROCm Triton flash attention may run into compilation errors due to " - "excessive use of shared memory. If this happens, disable Triton FA " - "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") -} - - -class ModelRegistry: - - @staticmethod - @functools.lru_cache(maxsize=128) - def _get_model(model_arch: str): - module_name, model_cls_name = _MODELS[model_arch] - module = importlib.import_module( - f"vllm.model_executor.models.{module_name}") - return getattr(module, model_cls_name, None) - - @staticmethod - def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: - if model_arch in _OOT_MODELS: - return _OOT_MODELS[model_arch] - if model_arch not in _MODELS: - return None - if is_hip(): - if model_arch in _ROCM_UNSUPPORTED_MODELS: - raise ValueError( - f"Model architecture {model_arch} is not supported by " - "ROCm for now.") - if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: - logger.warning( - "Model architecture %s is partially supported by ROCm: %s", - model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) - - return ModelRegistry._get_model(model_arch) - - @staticmethod - def resolve_model_cls( - architectures: List[str]) -> Tuple[Type[nn.Module], str]: - for arch in architectures: - model_cls = ModelRegistry._try_load_model_cls(arch) - if model_cls is not None: - return (model_cls, arch) - - raise ValueError( - f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}") - - @staticmethod - def get_supported_archs() -> List[str]: - return list(_MODELS.keys()) + list(_OOT_MODELS.keys()) - - @staticmethod - def register_model(model_arch: str, model_cls: Type[nn.Module]): - if model_arch in _MODELS: - logger.warning( - "Model architecture %s is already registered, and will be " - "overwritten by the new model class %s.", model_arch, - model_cls.__name__) - global _OOT_MODELS - _OOT_MODELS[model_arch] = model_cls - - @staticmethod - def is_embedding_model(model_arch: str) -> bool: - return model_arch in _EMBEDDING_MODELS - - @staticmethod - def is_multimodal_model(model_arch: str) -> bool: - - # TODO: find a way to avoid initializing CUDA prematurely to - # use `supports_multimodal` to determine if a model is multimodal - # model_cls = ModelRegistry._try_load_model_cls(model_arch) - # from vllm.model_executor.models.interfaces import supports_multimodal - return model_arch in _MULTIMODAL_MODELS - - -__all__ = [ - "ModelRegistry", -] diff --git a/demos/vllm_integration/vllm_files/lolcats.py b/demos/vllm_integration/vllm_files/lolcats.py deleted file mode 100644 index 1a9612c..0000000 --- a/demos/vllm_integration/vllm_files/lolcats.py +++ /dev/null @@ -1,792 +0,0 @@ -# coding=utf-8 -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. 
It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union - -import math -import os -import torch -from collections import OrderedDict -from torch import nn -from torch.nn.parameter import Parameter -from transformers import LlamaConfig - -from vllm.attention import Attention, AttentionMetadata, AttentionType -from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) -from vllm.distributed import (divide, - # split_tensor_along_last_dim, - # tensor_model_parallel_all_gather, - # tensor_model_parallel_all_reduce - ) -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - get_compressed_tensors_cache_scale) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors -from vllm.utils import is_hip - - - -from vllm.model_executor.models.interfaces import SupportsLoRA -from vllm.model_executor.models.utils import PPMissingLayer, is_pp_missing_parameter, make_layers -# from .interfaces import SupportsLoRA -# from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers -from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaAttention, LlamaMLP - -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import ReplicatedLinear - -logger = init_logger(__name__) - - -### OURS for Linear attention implementation -# from peft import get_peft_model, LoraConfig, TaskType - -# PEFT_KWARGS = { -# 'r': 8, -# 'lora_alpha': 16, # 32 -# 'lora_dropout': 0.05, -# 'target_modules': ["q_proj", "v_proj", "k_proj", "o_proj"] -# } - -### Hybrid Attention - - -from vllm.attention import Attention, AttentionMetadata - -class LlamaLoraAttention(LlamaAttention): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - _device = self.qkv_proj.weight.device - _dtype = self.qkv_proj.weight.dtype - 
print("Hello from Llama Lora Attention") - - def merge_lora_to_qkv_parallel(self, # param: Parameter, - loaded_delta: torch.Tensor, - loaded_shard_id: str = 'q', - total_num_heads: int = 32, - total_num_kv_heads: int = 4, - head_size: int = 128): - """ - Merge computed delta_AB into QKV parallel weights - - Based off of vLLM linear layer: https://github.com/vllm-project/vllm/blob/bc6e42a9b19364e07da9f279edd81796541d147d/vllm/model_executor/layers/linear.py#L762 - then Rahul, then Claude 3.5 Sonnet - - model.layers.1.self_attn.qkv_proj.weight torch.Size([1280, 8192]) - --> output_dim 0 - model.layers.1.self_attn.o_proj.weight torch.Size([8192, 1024]) - --> output_dim 0 - - apply this three times for q, k, and v LoRA deltas to the same layer - """ - - param = self.qkv_proj.weight - param_data = param.data - output_dim = getattr(param, "output_dim", None) - - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - # num_heads = divide(total_num_heads, tp_size) - # if tp_size >= total_num_kv_heads: - # num_kv_heads = 1 - # # num_kv_head_replicas = divide(tp_size, total_num_kv_heads) - # else: - # num_kv_heads = divide(total_num_kv_heads, tp_size) - # # num_kv_head_replicas = 1 - num_heads = total_num_heads - num_kv_heads = total_num_kv_heads - - num_original_kv_heads = 8 # all Llama 3.1 models have 8 kv heads in total - - num_kv_head_replicas = tp_size // num_original_kv_heads - - if loaded_shard_id == "q": - shard_offset = 0 - shard_size = num_heads * head_size - elif loaded_shard_id == "k": - shard_offset = num_heads * head_size - shard_size = num_kv_heads * head_size - elif loaded_shard_id == "v": - shard_offset = (num_heads + num_kv_heads) * head_size - shard_size = num_kv_heads * head_size - - # print(f"{tp_rank=}, {tp_size=}") - if loaded_shard_id == "q": - start_idx = tp_rank * shard_size - else: - start_idx = (tp_rank // num_kv_head_replicas) * shard_size - - device = param_data.device - - param_data = param_data.narrow(output_dim, shard_offset, shard_size) - # print(f'{loaded_shard_id=}') - # print(f'{shard_offset=}, {shard_size=}, {shard_offset+shard_size=}') - # print(f'{output_dim=}, {start_idx=}, {param_data.shape=}') - # print('-' * 10) - - # self.qkv_proj.weight.data[shard_offset:shard_offset+shard_size, :] += ( - # loaded_delta.narrow(output_dim, start_idx, shard_size).to(device) - # ) - # print(f'Loading {loaded_shard_id} {start_idx}-{start_idx+shard_size} into {param_data.shape}, which is slice({shard_offset}, {shard_offset+shard_size})') - try: - param_data.copy_(param_data + loaded_delta.narrow(output_dim, start_idx, shard_size).to(device)) - # print(f"Loaded {loaded_shard_id} into {param_data.shape}") - except Exception as e: - print(f"Error: {e}") - print(f"{loaded_shard_id=}") - print(f"{output_dim=}") - print(f"{start_idx=}") - print(f"{shard_size=}") - print(f"{param_data.shape=}") - print(f"{loaded_delta.shape=}") - print(f"{tp_rank=}") - print(f"{tp_size=}") - - def merge_lora_to_o_parallel(self, - loaded_delta: torch.Tensor): - """ - Merge computed delta_AB into output projection (RowParallel linear) - """ - param = self.o_proj.weight - param_data = param.data - input_dim = getattr(param, "input_dim", None) - device = param_data.device - - # print('o_proj {input_dim=}') - - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - - if input_dim is not None: - shard_size = param_data.shape[input_dim] - start_idx = tp_rank * shard_size - loaded_delta = loaded_delta.narrow(input_dim, 
start_idx, shard_size) - - # Special case for loading scales off disk, which often do not - # have a shape (such as in the case of AutoFP8). - if len(loaded_delta.shape) == 0: - loaded_delta = loaded_delta.reshape(1) - - # print('{param_data.shape=} | {loaded_delta.shape=}') - # assert param_data.shape == loaded_delta.shape - param_data.copy_(param_data + loaded_delta.to(device)) - - -### VLLM Llama Model - - -class FeatureMap(nn.Module): - """ - Learnable MLP in feature map. - - Full feature map is like f(xW + b) - -> This is the `W` and (optional) `b` part - """ - def __init__(self, - num_heads: int, - head_dim: int, - feature_dim: int, - dtype: torch.dtype, - device: torch.device, - eps: float = 1e-12, - **kwargs): - super().__init__() - self.num_heads = num_heads - self.head_dim = head_dim - self.feature_dim = feature_dim - self.dtype = dtype - self.device = device - self.eps = eps - self.init_weights_() - - def activation(self, x: torch.Tensor): - return torch.cat([ - torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1) - ], dim=-1).clamp(min=self.eps) - - def init_weights_(self): - self.layer = nn.Parameter(torch.zeros( - (self.num_heads, self.head_dim, self.feature_dim), - dtype=self.dtype, device=self.device, - )) - - def forward(self, x: torch.Tensor): - return self.activation( - torch.einsum('hdf,bhld->bhlf', self.layer, x)) - - -class LlamaLolcatsAttentionActual(nn.Module): - """Attention layer. - - This class takes query, key, and value tensors as input. The input tensors - can either contain prompt tokens or generation tokens. - The class does the following: - - 1. Store the input key and value tensors in the KV cache. - 2. Perform (multi-head/multi-query/grouped-query) attention. - 3. Return the output tensor. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - num_kv_heads: int, - ) -> None: - super().__init__() - self.num_heads = num_heads - self.head_size = head_size - self.num_kv_heads = num_kv_heads - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = num_heads // num_kv_heads - - max_seq_len = 2048 - window_size = 64 - - self.register_buffer('mask_window', self._create_mask(max_seq_len, window_size, True)) - self.register_buffer('mask_linear', self._create_mask(max_seq_len, window_size, False)) - - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - fmap_q: FeatureMap, - fmap_k: FeatureMap, - window_factors: torch.Tensor, - ) -> torch.Tensor: - # num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. 
- query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, dim=1) - value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - - query = query.movedim(0, query.dim() - 2) - key = key.movedim(0, key.dim() - 2) - value = value.movedim(0, value.dim() - 2) - - if query.dim() == 3: - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - - f_q = fmap_q(query) - f_k = fmap_k(key) - - window_size = 64 - window_factors = torch.nn.functional.sigmoid(window_factors) - linear_factors = 1 - # linear_factors = 1 - window_factors - - return self.superlinear_attention(query, key, f_q, f_k, - value, - window_factors, - linear_factors, - window_size) - - def _create_mask(self, max_seq_len: int, window_size: int, is_window: bool) -> torch.Tensor: - l = window_size - m = math.ceil(max_seq_len / window_size) - mask = torch.block_diag(*[torch.ones((l, l))] * m) - mask += torch.roll(mask, -l, -1) - mask = mask[:max_seq_len, :max_seq_len] - mask = mask[None, None, ...] # b, h, q_len, k_len - mask = torch.tril(mask if is_window else 1 - mask).to(dtype=torch.bool) - return mask - - def get_masks(self, window_size: int, q_len: int, k_len: int, device: torch.device) -> tuple[torch.Tensor]: - return self.mask_window[:, :, :q_len, :k_len], self.mask_linear[:, :, :q_len, :k_len] - - def superlinear_attention(self, q: torch.Tensor, k: torch.Tensor, - f_q: torch.Tensor, f_k: torch.Tensor, - v: torch.Tensor, - window_factor: torch.Tensor, - linear_factor: torch.Tensor, - window_size: int, - kv_state: torch.Tensor = None, - k_state: torch.Tensor = None, - eps: float = 1e-12, - mask_value: float=-1e8): - """ - Hybrid attention combining sliding window and linear attentions - """ - - mask_window, mask_linear = self.get_masks( - window_size, q.shape[-2], k.shape[-2], q.device) - - # 1. Sliding window (softmax attention) - # a_sm = torch.einsum( - # 'bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5) - a_sm = torch.einsum( - 'bhmd,bhnd->bhmn', q, k) * (k.shape[-1] ** -0.5) - a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value) - # torch.softmax(a_sm, dim=-1), but we account for the max when combining - a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) - a_sm = window_factor * torch.exp(a_sm - a_sm_max) - sum_sm = a_sm.sum(dim=-1, keepdim=True) - - # 2. Under window (linear attention) - # a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float()) - a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q, f_k) - a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0) - sum_ln = a_ln.sum(dim=-1, keepdim=True) - - # 3. 
Combine - # Allow outputs to also depend on prior kv_state and k_state - # y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v.float()) - # y = (y / (sum_sm + sum_ln)).to(q.dtype) - y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v) - y = (y / (sum_sm + sum_ln)) - # # logger.info(f"splattn {y.shape=}") - return y # attention weights only for the last chunk - - -class LlamaLolcatsAttention(LlamaAttention): - def __init__(self, *args, **kwargs): - - super().__init__(*args, **kwargs) - self.attn = LlamaLolcatsAttentionActual(self.num_heads, - self.head_dim, - self.num_kv_heads) - - _device = self.qkv_proj.weight.device - _dtype = self.qkv_proj.weight.dtype - - _feature_dim = 64 - - _feature_map_kwargs = { - "num_heads": self.num_heads, - "head_dim": self.head_dim, - "feature_dim": _feature_dim, - "dtype": _dtype, - "device": _device, - } - - self.feature_map_q = FeatureMap(**_feature_map_kwargs) - self.feature_map_k = FeatureMap(**_feature_map_kwargs) - self.window_factors = nn.Parameter( - torch.ones(1, self.num_heads, 1, 1, device=_device, dtype=_dtype)) - - def load_window_factors(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - - if tp_size > 1: - - num_heads_per_rank = self.num_heads - start_idx = tp_rank * num_heads_per_rank - end_idx = start_idx + num_heads_per_rank - - sharded_weight = loaded_weight[:, start_idx:end_idx, :, :] - - else: - - sharded_weight = loaded_weight - - assert self.window_factors.shape == sharded_weight.shape, \ - f"Shape mismatch: {self.window_factors.shape} vs {sharded_weight.shape}" - - with torch.no_grad(): - self.window_factors.copy_(sharded_weight) - - def load_feature_map_q(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - - # print(f"{tp_size}") - - if tp_size > 1: - - num_heads_per_rank = self.feature_map_q.num_heads - start_idx = tp_rank * num_heads_per_rank - end_idx = start_idx + num_heads_per_rank - - sharded_weight = loaded_weight[start_idx:end_idx, :, :] - - if sharded_weight.shape[-1] != self.feature_map_q.layer.shape[-1]: - sharded_weight = sharded_weight[:, :, :self.feature_map_q.layer.shape[-1]] - - else: - - sharded_weight = loaded_weight - - assert self.feature_map_q.layer.shape == sharded_weight.shape, \ - f"Shape mismatch: {self.feature_map_q.layer.shape} vs {sharded_weight.shape}" - - with torch.no_grad(): - self.feature_map_q.layer.copy_(sharded_weight) - - def load_feature_map_k(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - - if tp_size > 1: - - num_heads_per_rank = self.feature_map_k.num_heads - start_idx = tp_rank * num_heads_per_rank - end_idx = start_idx + num_heads_per_rank - - sharded_weight = loaded_weight[start_idx:end_idx, :, :] - - if sharded_weight.shape[-1] != self.feature_map_k.layer.shape[-1]: - sharded_weight = sharded_weight[:, :, :self.feature_map_k.layer.shape[-1]] - - else: - - sharded_weight = loaded_weight - - assert self.feature_map_k.layer.shape == sharded_weight.shape, \ - f"Shape mismatch: {self.feature_map_k.layer.shape} vs {sharded_weight.shape}" - - with torch.no_grad(): - self.feature_map_k.layer.copy_(sharded_weight) - # self.feature_map_k.layer.normal_(std=1) - - def merge_lora_to_qkv_parallel(self, # param: Parameter, - loaded_delta: torch.Tensor, - loaded_shard_id: str = 'q', - total_num_heads: int = 32, - total_num_kv_heads: int = 4, - head_size: int = 
128): - """ - Merge computed delta_AB into QKV parallel weights - - Based off of vLLM linear layer: https://github.com/vllm-project/vllm/blob/bc6e42a9b19364e07da9f279edd81796541d147d/vllm/model_executor/layers/linear.py#L762 - then Rahul, then Claude 3.5 Sonnet - - model.layers.1.self_attn.qkv_proj.weight torch.Size([1280, 8192]) - --> output_dim 0 - model.layers.1.self_attn.o_proj.weight torch.Size([8192, 1024]) - --> output_dim 0 - - apply this three times for q, k, and v LoRA deltas to the same layer - """ - - param = self.qkv_proj.weight - param_data = param.data - output_dim = getattr(param, "output_dim", None) - - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - # num_heads = divide(total_num_heads, tp_size) - # if tp_size >= total_num_kv_heads: - # num_kv_heads = 1 - # # num_kv_head_replicas = divide(tp_size, total_num_kv_heads) - # else: - # num_kv_heads = divide(total_num_kv_heads, tp_size) - # # num_kv_head_replicas = 1 - num_heads = total_num_heads - num_kv_heads = total_num_kv_heads - - num_original_kv_heads = 8 # all Llama 3.1 models have 8 kv heads in total - - num_kv_head_replicas = tp_size // num_original_kv_heads - - if loaded_shard_id == "q": - shard_offset = 0 - shard_size = num_heads * head_size - elif loaded_shard_id == "k": - shard_offset = num_heads * head_size - shard_size = num_kv_heads * head_size - elif loaded_shard_id == "v": - shard_offset = (num_heads + num_kv_heads) * head_size - shard_size = num_kv_heads * head_size - - # print(f"{tp_rank=}, {tp_size=}") - if loaded_shard_id == "q": - start_idx = tp_rank * shard_size - else: - start_idx = (tp_rank // num_kv_head_replicas) * shard_size - - device = param_data.device - - param_data = param_data.narrow(output_dim, shard_offset, shard_size) - # print(f'{loaded_shard_id=}') - # print(f'{shard_offset=}, {shard_size=}, {shard_offset+shard_size=}') - # print(f'{output_dim=}, {start_idx=}, {param_data.shape=}') - # print('-' * 10) - - # self.qkv_proj.weight.data[shard_offset:shard_offset+shard_size, :] += ( - # loaded_delta.narrow(output_dim, start_idx, shard_size).to(device) - # ) - # print(f'Loading {loaded_shard_id} {start_idx}-{start_idx+shard_size} into {param_data.shape}, which is slice({shard_offset}, {shard_offset+shard_size})') - try: - param_data.copy_(param_data + loaded_delta.narrow(output_dim, start_idx, shard_size).to(device)) - # print(f"Loaded {loaded_shard_id} into {param_data.shape}") - except Exception as e: - print(f"Error: {e}") - print(f"{loaded_shard_id=}") - print(f"{output_dim=}") - print(f"{start_idx=}") - print(f"{shard_size=}") - print(f"{param_data.shape=}") - print(f"{loaded_delta.shape=}") - print(f"{tp_rank=}") - print(f"{tp_size=}") - - def merge_lora_to_o_parallel(self, - loaded_delta: torch.Tensor): - """ - Merge computed delta_AB into output projection (RowParallel linear) - """ - param = self.o_proj.weight - param_data = param.data - input_dim = getattr(param, "input_dim", None) - device = param_data.device - - # print('o_proj {input_dim=}') - - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - - if input_dim is not None: - shard_size = param_data.shape[input_dim] - start_idx = tp_rank * shard_size - loaded_delta = loaded_delta.narrow(input_dim, start_idx, shard_size) - - # Special case for loading scales off disk, which often do not - # have a shape (such as in the case of AutoFP8). 
- if len(loaded_delta.shape) == 0: - loaded_delta = loaded_delta.reshape(1) - - # print('{param_data.shape=} | {loaded_delta.shape=}') - # assert param_data.shape == loaded_delta.shape - param_data.copy_(param_data + loaded_delta.to(device)) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - ndim = hidden_states.dim() - qkv, _ = self.qkv_proj(hidden_states) - seq_len = hidden_states.shape[-2] - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, - fmap_q=self.feature_map_q, - fmap_k=self.feature_map_k, - window_factors=self.window_factors) - attn_output = attn_output.transpose(1, 2).contiguous().view(-1, seq_len, self.num_heads * self.head_dim) - output, _ = self.o_proj(attn_output) - if output.dim() > ndim: - output = output.squeeze(0) - return output - - -class LlamaLolcatsForCausalLM(LlamaForCausalLM): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - print(f"LOLCATS!!!: Loading model with config: {self.config}") - - softmax_attentions = getattr(self.config, 'softmax_attentions', []) - - for i in range(len(self.model.layers)): - if i in softmax_attentions: - print(f"Using Lora Llama Attention at Layer {i}") - self.model.layers[i].self_attn = LlamaLoraAttention( - config=self.config, - hidden_size=self.config.hidden_size, - num_heads=self.config.num_attention_heads, - num_kv_heads=getattr(self.config, "num_key_value_heads", - self.config.num_attention_heads), - rope_theta=self.config.rope_theta, - rope_scaling=self.config.rope_scaling, - ) - else: - self.model.layers[i].self_attn = LlamaLolcatsAttention( - config=self.config, - hidden_size=self.config.hidden_size, - num_heads=self.config.num_attention_heads, - num_kv_heads=getattr(self.config, "num_key_value_heads", - self.config.num_attention_heads), - rope_theta=self.config.rope_theta, - rope_scaling=self.config.rope_scaling, - ) - - def get_device(self): - device = next(self.parameters()).device - return str(device) - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - super().load_weights(weights) - - # model_size = 8 - # model_size = 70 - model_size = 405 - - # PATH = f'/data/rahul/checkpoints/{model_size}b.pt' - # PATH = f'/home/rahul/code/lolcats/ckpt_lora-dl-d=distill_redpajama_xent1_mse1000_lr1e-2-m=distill_llama3_1_70b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_redpajama-dcs=512-se=0-re=4-lzi=1-dcs=512-se=0-re=4.pt' - - # Trenchcoats v1 - # PATH = f'/home/rahul/code/lolcats/ckpt_lora-dl-d=distill_llama_405b_xent1_mse1000_lr1e-2-m=distill_llama3_1_405b_lk_smd_wtk64_fd64_w01-f=finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-ef=finetune_llama_405b_qkvo-ft_lora=0-se=0-re=0-ef=finetune_llama_405b_qkvo-ft_lora=0.pt' - - # No distill - # PATH = f'/home/rahul/code/lolcats/ckpt_lora-dl-d=no_distill_alpaca_clean-m=distill_llama3_1_405b_lk_smd_wtk64_fd64_w01-f=finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-ef=no_distill_finetune_405b-ft_lora=0-se=0-re=0-ef=no_distill_finetune_405b-ft_lora=0-no_distill.pt' - - # Hybrid (last cria attention) - PATH = f'/home/rahul/code/lolcats/ckpt_lora-dl-d=distill_llama_405b_xent1_mse1000_lr1e-2-m=distill_llama3_1_405b_lk_smd_wtk64_fd64_w01_h117-f=finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-ef=finetune_llama_405b_qkvo_cos-ft_lora=0-se=0-re=0-ef=finetune_llama_405b_qkvo_cos-ft_lora=0.pt' - - print(f"PATH: {PATH}!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 
- - adapter_weights_path = os.getenv("LOLCATS_ADAPTER_PATH", PATH) - - adapter_weights = torch.load(adapter_weights_path, weights_only=True) - - adapter_weights_copy = OrderedDict({}) - - for key, value in adapter_weights.items(): - key_suffix = key[key.rindex("model.")+6:] - adapter_weights_copy[key_suffix] = value - - adapter_weights = adapter_weights_copy - updated_keys = [] - - print("\n") - for layer_idx, layer in enumerate(self.model.layers): - # if layer_idx == 0: - # print(f'Weight factors before checkpoint load, {layer.self_attn.window_factors.shape}, {layer.self_attn.window_factors.flatten()}') - - window_factors_key = f'layers.{layer_idx}.self_attn.window_factors' - if window_factors_key in adapter_weights: - layer.self_attn.load_window_factors(adapter_weights[window_factors_key]) - updated_keys.append(window_factors_key) - - # if layer_idx == 0: - # print(f'Weight factors after checkpoint load, {layer.self_attn.window_factors.shape}, {layer.self_attn.window_factors.flatten()}') - if layer_idx == 0: - print("\n") - print(f'FMAP Q before checkpoint load, {layer.self_attn.feature_map_q.layer.shape}, {layer.self_attn.feature_map_q.layer[0,0,:4]}') - - fm_q_key = f'layers.{layer_idx}.self_attn.feature_map_q.mlp.layer' - if fm_q_key in adapter_weights: - layer.self_attn.load_feature_map_q(adapter_weights[fm_q_key]) - updated_keys.append(fm_q_key) - - if layer_idx == 0: - print(f'FMAP Q after checkpoint load; {layer.self_attn.feature_map_q.layer.shape},{layer.self_attn.feature_map_q.layer[0,0,:4]}') - - fm_k_key = f'layers.{layer_idx}.self_attn.feature_map_k.mlp.layer' - if fm_k_key in adapter_weights: - layer.self_attn.load_feature_map_k(adapter_weights[fm_k_key]) - updated_keys.append(fm_k_key) - - weight_name = 'layers.{layer_idx}.self_attn.{proj}.lora_{a_or_b}.default.weight' - target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"] - # target_modules = ["q_proj", "k_proj", "v_proj"] - # target_modules = ["k_proj", "v_proj"] - # target_modules = ["q_proj", "k_proj"] - - r = 8 - lora_alpha = 16 - lora_dropout = 0 - - for proj in target_modules: - lora_A_key = weight_name.format(layer_idx=layer_idx, proj=proj, a_or_b='A') - lora_B_key = weight_name.format(layer_idx=layer_idx, proj=proj, a_or_b='B') - if lora_A_key in adapter_weights: - weight_A = adapter_weights[lora_A_key] - weight_B = adapter_weights[lora_B_key] - delta_AB = get_delta_weight(weight_A, weight_B, r=r, lora_alpha=lora_alpha, - lora_dropout=lora_dropout) - - # if layer_idx == 0: - # print(f'layer {layer_idx} weight_A.shape: {weight_A.shape} | weight_B.shape: {weight_B.shape} | delta_AB.shape: {delta_AB.shape}') - # print(f'layer {layer_idx} proj {proj} delta_AB', delta_AB.shape) - - if proj == 'o_proj': - if layer_idx == 0: - print("\n") - print(f'Layer {layer_idx} {proj} weight before checkpoint load, {layer.self_attn.o_proj.weight.shape}, {layer.self_attn.o_proj.weight[0,:4]}') - - layer.self_attn.merge_lora_to_o_parallel(delta_AB) - - if layer_idx == 0: - print(f'Layer {layer_idx} {proj} weight after checkpoint load, {layer.self_attn.o_proj.weight.shape}, {layer.self_attn.o_proj.weight[0,:4]}') - else: - # if layer_idx == 0 and proj in ['q_proj']: - # print(f'Layer {layer_idx} {proj} weight before checkpoint load, {layer.self_attn.qkv_proj.weight.shape}, {layer.self_attn.qkv_proj.weight[0,:4]}') - - layer.self_attn.merge_lora_to_qkv_parallel(delta_AB, - loaded_shard_id=proj.split('_')[0], - total_num_heads=layer.self_attn.num_heads, - total_num_kv_heads=layer.self_attn.num_kv_heads, - 
head_size=layer.self_attn.head_dim) - # if layer_idx == 0 and proj in ['q_proj']: - # print(f'Layer {layer_idx} {proj} weight after checkpoint load, {layer.self_attn.qkv_proj.weight.shape}, {layer.self_attn.qkv_proj.weight[0,:4]}') - updated_keys.append(lora_A_key) - updated_keys.append(lora_B_key) - - assert len(set(adapter_weights_copy.keys()) - set(updated_keys)) == 0, \ - f"UNUPDATED KEYS: {set(adapter_weights_copy.keys()) - set(updated_keys)}" - - -def transpose(weight, fan_in_fan_out): - if not fan_in_fan_out: - return weight - - if isinstance(weight, torch.nn.Parameter): - return torch.nn.Parameter(weight.T) - return weight.T - - -def get_delta_weight(weight_A: torch.Tensor, weight_B: torch.Tensor, - r: int = 8, lora_alpha: float = 16, lora_dropout: float = 0, - fan_in_fan_out: bool = False,): - - device = weight_B.device - dtype = weight_B.dtype - # From https://github.com/huggingface/peft/blob/850eeb5c3a5cf692f5612c7c733b13fde184e05d/src/peft/tuners/lora/layer.py#L512 - cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) - if cast_to_fp32: - weight_A = weight_A.float() - weight_B = weight_B.float() - scaling = lora_alpha / r - output_tensor = transpose(weight_B @ weight_A, fan_in_fan_out) * scaling - if cast_to_fp32: - output_tensor = output_tensor.to(dtype=dtype) - return output_tensor diff --git a/demos/vllm_integration/vllm_files/lolcats_inference.py b/demos/vllm_integration/vllm_files/lolcats_inference.py deleted file mode 100644 index 9865052..0000000 --- a/demos/vllm_integration/vllm_files/lolcats_inference.py +++ /dev/null @@ -1,870 +0,0 @@ -# coding=utf-8 -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union - -import math -import os -import torch -import time -from collections import OrderedDict -from torch import nn -from torch.nn.parameter import Parameter -from transformers import LlamaConfig - -from vllm.attention import Attention, AttentionMetadata, AttentionType -from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) -from vllm.distributed import (divide, - # split_tensor_along_last_dim, - # tensor_model_parallel_all_gather, - # tensor_model_parallel_all_reduce - ) -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - get_compressed_tensors_cache_scale) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors -from vllm.utils import is_hip - - - -from vllm.model_executor.models.interfaces import SupportsLoRA -from vllm.model_executor.models.utils import PPMissingLayer, is_pp_missing_parameter, make_layers -# from .interfaces import SupportsLoRA -# from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers -from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaAttention, LlamaMLP - -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import ReplicatedLinear - -logger = init_logger(__name__) - - -### OURS for Linear attention implementation -# from peft import get_peft_model, LoraConfig, TaskType - -# PEFT_KWARGS = { -# 'r': 8, -# 'lora_alpha': 16, # 32 -# 'lora_dropout': 0.05, -# 'target_modules': ["q_proj", "v_proj", "k_proj", "o_proj"] -# } - -### Hybrid Attention - - -from vllm.attention import Attention, AttentionMetadata - - -### VLLM Llama Model - - -class FeatureMap(nn.Module): - """ - Learnable MLP in feature map. 
- - Full feature map is like f(xW + b) - -> This is the `W` and (optional) `b` part - """ - def __init__(self, - num_heads: int, - head_dim: int, - feature_dim: int, - dtype: torch.dtype, - device: torch.device, - eps: float = 1e-12, - **kwargs): - super().__init__() - self.num_heads = num_heads - self.head_dim = head_dim - self.feature_dim = feature_dim - self.dtype = dtype - self.device = device - self.eps = eps - self.init_weights_() - - def activation(self, x: torch.Tensor): - return torch.cat([ - torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1) - ], dim=-1).clamp(min=self.eps) - - def init_weights_(self): - self.layer = nn.Parameter(torch.zeros( - (self.num_heads, self.head_dim, self.feature_dim), - dtype=self.dtype, device=self.device, - )) - - def forward(self, x: torch.Tensor): - return self.activation( - torch.einsum('hdf,bhld->bhlf', self.layer, x.to(self.dtype))) - - -from dataclasses import dataclass -@dataclass -class LoLCacheParams: - is_prompt: bool = False - kv_state: torch.Tensor = torch.Tensor() - k_state: torch.Tensor = torch.Tensor() - kv_cache: torch.Tensor = torch.Tensor() - - -class LlamaLolcatsAttentionActual(nn.Module): - """Attention layer. - - This class takes query, key, and value tensors as input. The input tensors - can either contain prompt tokens or generation tokens. - The class does the following: - - 1. Store the input key and value tensors in the KV cache. - 2. Perform (multi-head/multi-query/grouped-query) attention. - 3. Return the output tensor. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - num_kv_heads: int, - layer_idx: int, - ) -> None: - super().__init__() - self.num_heads = num_heads - self.head_size = head_size - self.num_kv_heads = num_kv_heads - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = num_heads // num_kv_heads - - max_seq_len = 2048 - window_size = 64 - self.window_size = window_size - - self.register_buffer('mask_window', self._create_mask(max_seq_len, window_size, True)) - self.register_buffer('mask_linear', self._create_mask(max_seq_len, window_size, False)) - - # SA: inference cache - self.lolcats_cache = None - tp_rank = get_tensor_model_parallel_rank() - - print(f"{layer_idx=}") - self.layer_idx = layer_idx - self.tp_rank = tp_rank - - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - fmap_q: FeatureMap, - fmap_k: FeatureMap, - window_factors: torch.Tensor, - state=None - ) -> torch.Tensor: - - if self.lolcats_cache is None: - self._prepare_lolcats_cache() - - # num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. 
- query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, dim=1) - value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - - query = query.movedim(0, query.dim() - 2) - key = key.movedim(0, key.dim() - 2) - value = value.movedim(0, value.dim() - 2) - - if query.dim() == 3: - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - - f_q = fmap_q(query) - f_k = fmap_k(key) - - window_size = 64 - window_factors = torch.nn.functional.sigmoid(window_factors) - linear_factors = 1 - - # if self.tp_rank == 0 and self.layer_idx == 0: print(f"{query.shape=}") - seqlen = query.shape[2] - # if self.tp_rank == 0 and self.layer_idx == 0: print(f"{seqlen=}") - if seqlen == 1: - return self.recurrent_attention( - query, key, f_q, f_k, - value, window_factors, - linear_factors, - window_size, - fmap_q, fmap_k - ) - else: - return self.superlinear_attention( - query, key, f_q, f_k, - value, - window_factors, linear_factors, - window_size - ) - - - def _create_mask(self, max_seq_len: int, window_size: int, is_window: bool) -> torch.Tensor: - l = window_size - m = math.ceil(max_seq_len / window_size) - mask = torch.block_diag(*[torch.ones((l, l))] * m) - mask += torch.roll(mask, -l, -1) - mask = mask[:max_seq_len, :max_seq_len] - mask = mask[None, None, ...] # b, h, q_len, k_len - mask = torch.tril(mask if is_window else 1 - mask).to(dtype=torch.bool) - return mask - - - def get_masks(self, window_size: int, q_len: int, k_len: int, device: torch.device) -> tuple[torch.Tensor]: - return self.mask_window[:, :, :q_len, :k_len], self.mask_linear[:, :, :q_len, :k_len] - - - def _prepare_lolcats_cache(self): - if self.tp_rank == 0 and self.layer_idx == 0: print("heyo 5 -- hello prepare kv cache") - dtype = torch.bfloat16 - bs = 1 - self.lolcats_cache = LoLCacheParams() - if self.tp_rank == 0 and self.layer_idx == 0: print("heyo 6 -- bye prepare kv cache") - - - def _init_kv_cache(self, keys, values, f_k): - if self.tp_rank == 0 and self.layer_idx == 0: print("heyo 3 -- hello init kv cache") - dtype = keys.dtype - if self.tp_rank == 0 and self.layer_idx == 0: print(f"{f_k.shape=}") - - # decoding KV state (KV terms up to last window_size) - decode_kv_state = torch.einsum('bhlf,bhld->bhfd', - f_k[:, :, :-self.window_size], - values[:, :, :-self.window_size] - ) - - if self.tp_rank == 0 and self.layer_idx == 0: - print(decode_kv_state[0][0]) - if self.tp_rank == 0 and self.layer_idx == 0: print(f"{decode_kv_state.shape=}") - if self.tp_rank == 0 and self.layer_idx == 0: print(f"{f_k.shape=}") - # shape is b, h, 1, f; note the 1 - decode_k_state = f_k[:, :, :-self.window_size].sum(dim=-2,keepdim=True) - self.lolcats_cache.kv_state = decode_kv_state - self.lolcats_cache.k_state = decode_k_state - if self.tp_rank == 0 and self.layer_idx == 0: print(f"{decode_k_state.shape=}") - - # update the cache - kv_cache = torch.stack([ - keys[:, :, -self.window_size:, :].float(), - values[:, :, -self.window_size:, :].float() - ], dim=1) - if self.tp_rank == 0 and self.layer_idx == 0: print(f"{kv_cache.shape=}") - self.lolcats_cache.kv_cache = kv_cache - if self.tp_rank == 0 and self.layer_idx == 0: print("heyo 4 -- bye init kv cache") - - - def superlinear_attention( - self, q: torch.Tensor, k: torch.Tensor, - f_q: torch.Tensor, f_k: torch.Tensor, - v: torch.Tensor, - window_factor: 
torch.Tensor, linear_factor: torch.Tensor, - window_size: int, - kv_state: torch.Tensor = None, k_state: torch.Tensor = None, - eps: float = 1e-12, - mask_value: float=-1e8 - ): - """ - Hybrid attention combining sliding window and linear attentions - """ - - mask_window, mask_linear = self.get_masks( - window_size, q.shape[-2], k.shape[-2], q.device) - - # 1. Sliding window (softmax attention) - # a_sm = torch.einsum( - # 'bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5) - a_sm = torch.einsum( - 'bhmd,bhnd->bhmn', q, k) * (k.shape[-1] ** -0.5) - a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value) - # torch.softmax(a_sm, dim=-1), but we account for the max when combining - a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) - a_sm = window_factor * torch.exp(a_sm - a_sm_max) - sum_sm = a_sm.sum(dim=-1, keepdim=True) - - # 2. Under window (linear attention) - # a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float()) - a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q, f_k) - a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0) - sum_ln = a_ln.sum(dim=-1, keepdim=True) - - # 3. Combine - # Allow outputs to also depend on prior kv_state and k_state - # y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v.float()) - # y = (y / (sum_sm + sum_ln)).to(q.dtype) - y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v) - y = (y / (sum_sm + sum_ln)) - # # logger.info(f"splattn {y.shape=}") - - self._init_kv_cache(k, v, f_k) - return y # attention weights only for the last chunk - - - def _update_kv_cache(self, keys, values, fmap_k): - # if self.tp_rank == 0 and self.layer_idx == 0: print("heyo 1 - hello update kv cache") - # get state from before - kv_state = self.lolcats_cache.kv_state - k_state = self.lolcats_cache.k_state - kv_cache_swa = self.lolcats_cache.kv_cache - k_cache = kv_cache_swa[:, 0] - v_cache = kv_cache_swa[:, 1] - - dtype = kv_state.dtype - - # update the linear attention states - # since we ignore the diag blocks, just grab last tokens of kv cache - cur_seq_len = k_cache.shape[-2] - if self.tp_rank == 0 and self.layer_idx == 0: print(f"{cur_seq_len=}") - if cur_seq_len >= self.window_size: - if self.tp_rank == 0 and self.layer_idx == 0: print(f"Updating the kv_state and k_state...") - # if self.tp_rank == 0 and self.layer_idx == 0: - # print(f"{fmap_k.layer=}") - # print(f"{k_cache[0, 0, 0, 0:8]=}") - # print(f"{k_cache[:, :, :1, :]=}") - # print(f"{fmap_k(k_cache[:, :, :1, :])=}") - k_state = fmap_k(k_cache[:, :, :1, :]) - v_state = v_cache[:, :, :1, :] - kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d - self.lolcats_cache.kv_state += kv_state.to(kv_state.dtype) - self.lolcats_cache.k_state += k_state - - # update swa states - if cur_seq_len < self.window_size: - # only add to cache - k_cache = torch.cat([k_cache, keys], dim=-2) - v_cache = torch.cat([v_cache, values], dim=-2) - else: - # remove oldest key and value and append - k_cache = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2) - v_cache = torch.cat([v_cache[:, :, 1:, :], values], dim=-2) - kv_cache_swa = torch.stack([k_cache, v_cache], dim=1) - self.lolcats_cache.kv_cache = kv_cache_swa - - # if self.tp_rank == 0 and self.layer_idx == 0: print("heyo 2 - bye update kv cache") - return self.lolcats_cache.kv_state, self.lolcats_cache.k_state, k_cache, v_cache - - - def recurrent_attention( - self, q: torch.Tensor, k: torch.Tensor, - f_q: torch.Tensor, f_k: torch.Tensor, - v: torch.Tensor, - window_factor: torch.Tensor, - linear_factor: 
torch.Tensor, - window_size: int, - fmap_q, fmap_k, - kv_state: torch.Tensor = None, - k_state: torch.Tensor = None, - eps: float = 1e-12, mask_value: float=-1e8 - ): - dtype = torch.float32 - # if self.tp_rank == 0 and self.layer_idx == 0: print(f"hello recurrent step!") - kv_state, k_state, k_cache, v_cache = self._update_kv_cache(k, v, fmap_k) - - # Softmax attention terms - a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5) - a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) - a_sm = window_factor * torch.exp(a_sm - a_sm_max) - sum_sm = a_sm.sum(dim=-1, keepdim=True) - y_sm = torch.einsum('bhmn,bhnd->bhmd', a_sm.float(), v_cache.float()) - - # Combine with linear attention terms - f_q = fmap_q(q) - y_ln = linear_factor * torch.einsum('bhlf,bhfd->bhld', f_q.float(), kv_state.float()) - sum_ln = linear_factor * torch.einsum('bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None] - - # if self.tp_rank == 0 and self.layer_idx == 0: - # print(f"{y_ln[0][0][0][:4]=}") - # print(f"{sum_ln[0][0]=}") - - y = y_sm + y_ln - attn_output = (y / (sum_sm + sum_ln)).to(q.dtype) - - # if self.tp_rank == 0 and self.layer_idx == 0: print(f"bye recurrent step!") - return attn_output - - -class LlamaLolcatsAttention(LlamaAttention): - def __init__(self, layer_idx, *args, **kwargs): - - super().__init__(*args, **kwargs) - print(f"{layer_idx=}") - self.attn = LlamaLolcatsAttentionActual(self.num_heads, - self.head_dim, - self.num_kv_heads, - layer_idx) - self.head_size = self.head_dim - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - _device = self.qkv_proj.weight.device - _dtype = self.qkv_proj.weight.dtype - - _feature_dim = 64 - _feature_map_kwargs = { - "num_heads": self.num_heads, - "head_dim": self.head_dim, - "feature_dim": _feature_dim, - "dtype": _dtype, - "device": _device, - } - self.feature_dim = _feature_dim - self.window_size = 64 - - tp_rank = get_tensor_model_parallel_rank() - self.tp_rank = tp_rank - self.layer_idx = layer_idx - - self.feature_map_q = FeatureMap(**_feature_map_kwargs) - self.feature_map_k = FeatureMap(**_feature_map_kwargs) - self.window_factors = nn.Parameter( - torch.ones(1, self.num_heads, 1, 1, device=_device, dtype=_dtype)) - - def load_window_factors(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - - # print(f"{tp_size=}") - # assert 0, "ahhhh window factors" - - if tp_size > 1: - - num_heads_per_rank = self.num_heads - start_idx = tp_rank * num_heads_per_rank - end_idx = start_idx + num_heads_per_rank - - if self.layer_idx == 0 and tp_rank == 0: - print(loaded_weight) - - if self.layer_idx < 2: - print(f"{num_heads_per_rank=}") - print(f"{tp_rank=}; {loaded_weight.shape=}; {start_idx=}; {end_idx=}") - - sharded_weight = loaded_weight[:, start_idx:end_idx, :, :] - - else: - - sharded_weight = loaded_weight - - assert self.window_factors.shape == sharded_weight.shape, \ - f"Shape mismatch: {self.window_factors.shape} vs {sharded_weight.shape}" - - with torch.no_grad(): - self.window_factors.copy_(sharded_weight) - - def load_feature_map_q(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - - # print(f"{tp_size=}") - # assert 0, "ahhhh feature map q" - - if tp_size > 1: - - num_heads_per_rank = self.feature_map_q.num_heads - start_idx = tp_rank * num_heads_per_rank - end_idx = start_idx + num_heads_per_rank - - sharded_weight = 
loaded_weight[start_idx:end_idx, :, :] - - if sharded_weight.shape[-1] != self.feature_map_q.layer.shape[-1]: - sharded_weight = sharded_weight[:, :, :self.feature_map_q.layer.shape[-1]] - - else: - - sharded_weight = loaded_weight - - assert self.feature_map_q.layer.shape == sharded_weight.shape, \ - f"Shape mismatch: {self.feature_map_q.layer.shape} vs {sharded_weight.shape}" - - with torch.no_grad(): - self.feature_map_q.layer.copy_(sharded_weight) - - def load_feature_map_k(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - - # print(f"{tp_size=}") - # assert 0, "ahhhh" - - if tp_size > 1: - - num_heads_per_rank = self.feature_map_k.num_heads - start_idx = tp_rank * num_heads_per_rank - end_idx = start_idx + num_heads_per_rank - - sharded_weight = loaded_weight[start_idx:end_idx, :, :] - - if sharded_weight.shape[-1] != self.feature_map_k.layer.shape[-1]: - sharded_weight = sharded_weight[:, :, :self.feature_map_k.layer.shape[-1]] - - else: - - sharded_weight = loaded_weight - - assert self.feature_map_k.layer.shape == sharded_weight.shape, \ - f"Shape mismatch: {self.feature_map_k.layer.shape} vs {sharded_weight.shape}" - - with torch.no_grad(): - self.feature_map_k.layer.copy_(sharded_weight) - # self.feature_map_k.layer.normal_(std=1) - - def merge_lora_to_qkv_parallel(self, # param: Parameter, - loaded_delta: torch.Tensor, - loaded_shard_id: str = 'q', - total_num_heads: int = 32, - total_num_kv_heads: int = 4, - head_size: int = 128): - """ - Merge computed delta_AB into QKV parallel weights - - Based off of vLLM linear layer: https://github.com/vllm-project/vllm/blob/bc6e42a9b19364e07da9f279edd81796541d147d/vllm/model_executor/layers/linear.py#L762 - then Rahul, then Claude 3.5 Sonnet - - model.layers.1.self_attn.qkv_proj.weight torch.Size([1280, 8192]) - --> output_dim 0 - model.layers.1.self_attn.o_proj.weight torch.Size([8192, 1024]) - --> output_dim 0 - - apply this three times for q, k, and v LoRA deltas to the same layer - """ - - param = self.qkv_proj.weight - param_data = param.data - output_dim = getattr(param, "output_dim", None) - - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - # num_heads = divide(total_num_heads, tp_size) - # if tp_size >= total_num_kv_heads: - # num_kv_heads = 1 - # # num_kv_head_replicas = divide(tp_size, total_num_kv_heads) - # else: - # num_kv_heads = divide(total_num_kv_heads, tp_size) - # # num_kv_head_replicas = 1 - num_heads = total_num_heads - num_kv_heads = total_num_kv_heads - - num_original_kv_heads = 8 # all Llama 3.1 models have 8 kv heads in total - - num_kv_head_replicas = tp_size // num_original_kv_heads - - if loaded_shard_id == "q": - shard_offset = 0 - shard_size = num_heads * head_size - elif loaded_shard_id == "k": - shard_offset = num_heads * head_size - shard_size = num_kv_heads * head_size - elif loaded_shard_id == "v": - shard_offset = (num_heads + num_kv_heads) * head_size - shard_size = num_kv_heads * head_size - - # print(f"{tp_rank=}, {tp_size=}") - if loaded_shard_id == "q": - start_idx = tp_rank * shard_size - else: - start_idx = (tp_rank // num_kv_head_replicas) * shard_size - - device = param_data.device - - param_data = param_data.narrow(output_dim, shard_offset, shard_size) - # print(f'{loaded_shard_id=}') - # print(f'{shard_offset=}, {shard_size=}, {shard_offset+shard_size=}') - # print(f'{output_dim=}, {start_idx=}, {param_data.shape=}') - # print('-' * 10) - - # 
self.qkv_proj.weight.data[shard_offset:shard_offset+shard_size, :] += ( - # loaded_delta.narrow(output_dim, start_idx, shard_size).to(device) - # ) - # print(f'Loading {loaded_shard_id} {start_idx}-{start_idx+shard_size} into {param_data.shape}, which is slice({shard_offset}, {shard_offset+shard_size})') - try: - param_data.copy_(param_data + loaded_delta.narrow(output_dim, start_idx, shard_size).to(device)) - # print(f"Loaded {loaded_shard_id} into {param_data.shape}") - except Exception as e: - print(f"Error: {e}") - print(f"{loaded_shard_id=}") - print(f"{output_dim=}") - print(f"{start_idx=}") - print(f"{shard_size=}") - print(f"{param_data.shape=}") - print(f"{loaded_delta.shape=}") - print(f"{tp_rank=}") - print(f"{tp_size=}") - - def merge_lora_to_o_parallel(self, - loaded_delta: torch.Tensor): - """ - Merge computed delta_AB into output projection (RowParallel linear) - """ - param = self.o_proj.weight - param_data = param.data - input_dim = getattr(param, "input_dim", None) - device = param_data.device - - # print('o_proj {input_dim=}') - - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - - if input_dim is not None: - shard_size = param_data.shape[input_dim] - start_idx = tp_rank * shard_size - loaded_delta = loaded_delta.narrow(input_dim, start_idx, shard_size) - - # Special case for loading scales off disk, which often do not - # have a shape (such as in the case of AutoFP8). - if len(loaded_delta.shape) == 0: - loaded_delta = loaded_delta.reshape(1) - - # print('{param_data.shape=} | {loaded_delta.shape=}') - # assert param_data.shape == loaded_delta.shape - param_data.copy_(param_data + loaded_delta.to(device)) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - ndim = hidden_states.dim() - qkv, _ = self.qkv_proj(hidden_states) - seq_len = hidden_states.shape[-2] - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - - attn_output = self.attn( - q, k, v, - fmap_q=self.feature_map_q, - fmap_k=self.feature_map_k, - window_factors=self.window_factors, - state=None - ) - - # outputs - attn_output = attn_output.transpose(1, 2).contiguous().view(-1, seq_len, self.num_heads * self.head_dim) - output, _ = self.o_proj(attn_output) - if output.dim() > ndim: - output = output.squeeze(0) - return output - - -class LlamaLolcatsForCausalLM(LlamaForCausalLM): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - print(f"LOLCATS!!!: Loading model with config: {self.config}") - - softmax_attentions = getattr(self.config, 'softmax_attentions', []) - print(f"{softmax_attentions=}") - - tp_rank = get_tensor_model_parallel_rank() - self.tp_rank = tp_rank - - for i in range(len(self.model.layers)): - if i in softmax_attentions: - print(f"Using Lora Llama Attention at Layer {i}") - self.model.layers[i].self_attn = LlamaSdpaAttention( - config=self.config, - hidden_size=self.config.hidden_size, - num_heads=self.config.num_attention_heads, - num_kv_heads=getattr(self.config, "num_key_value_heads", - self.config.num_attention_heads), - rope_theta=self.config.rope_theta, - rope_scaling=self.config.rope_scaling, - ) - else: - self.model.layers[i].self_attn = LlamaLolcatsAttention( - i, - config=self.config, - hidden_size=self.config.hidden_size, - num_heads=self.config.num_attention_heads, - num_kv_heads=getattr(self.config, "num_key_value_heads", - 
self.config.num_attention_heads), - rope_theta=self.config.rope_theta, - rope_scaling=self.config.rope_scaling, - ) - print(self.model) - - def get_device(self): - device = next(self.parameters()).device - return str(device) - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - super().load_weights(weights) - - # model_size = 8 - # FINETUNE_PATH = '/home/rahul/code/lolcats/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-ms=2500-se=0-re=100-lzi=1-bs=1-gas=8-nte=2-ms=2500-se=0-re=100_ft.pt' - # MLP_PATH = '/home/rahul/code/lolcats/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-ms=2500-se=0-re=100-lzi=1_distill.pt' - # # merge the MLP and FINETUNE weights as adapter weights - # adapter_weights = torch.load(FINETUNE_PATH, weights_only=True) - # adapter_weights.update(torch.load(MLP_PATH, weights_only=True)) - # print(adapter_weights.keys()) - # # only keep any weight with 'feature' or 'window' or 'lora' in the key - # adapter_weights = {k: v for k, v in adapter_weights.items() if 'feature' in k or 'window' in k or 'lora' in k} - - model_size = 70 - PATH = f'/data/rahul/checkpoints/{model_size}b.pt' - PATH = f'/home/rahul/code/lolcats/ckpt_lora-dl-d=distill_redpajama_xent1_mse1000_lr1e-2-m=distill_llama3_1_70b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_redpajama-dcs=512-se=0-re=4-lzi=1-dcs=512-se=0-re=4.pt' - print(f"PATH INFERENCE: {PATH}!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") - adapter_weights_path = os.getenv("LOLCATS_ADAPTER_PATH", PATH) - adapter_weights = torch.load(adapter_weights_path, weights_only=True) - - adapter_weights_copy = OrderedDict({}) - - for key, value in adapter_weights.items(): - key_suffix = key[key.rindex("model.")+6:] - adapter_weights_copy[key_suffix] = value - - adapter_weights = adapter_weights_copy - updated_keys = [] - - print("\n") - num_layers = len(self.model.layers) - for layer_idx, layer in enumerate(self.model.layers): - if layer_idx == 0: - print(f'Weight factors before checkpoint load {self.tp_rank=}, {layer.self_attn.window_factors.shape}, {layer.self_attn.window_factors.flatten()}') - - window_factors_key = f'layers.{layer_idx}.self_attn.window_factors' - if window_factors_key in adapter_weights: - layer.self_attn.load_window_factors(adapter_weights[window_factors_key]) - updated_keys.append(window_factors_key) - - if layer_idx == 0: - print(f'Weight factors after checkpoint load {self.tp_rank=}, {layer.self_attn.window_factors.shape}, {layer.self_attn.window_factors.flatten()}') - - fm_q_key = f'layers.{layer_idx}.self_attn.feature_map_q.mlp.layer' - if fm_q_key in adapter_weights: - # if layer_idx in [0, num_layers-1]: - # # print("\n") - # # print(f'FMAP Q before checkpoint load {self.tp_rank=}, {layer.self_attn.feature_map_q.layer.shape}, {layer.self_attn.feature_map_q.layer[0,0,:4]}') - - layer.self_attn.load_feature_map_q(adapter_weights[fm_q_key]) - updated_keys.append(fm_q_key) - - # if layer_idx in [0, num_layers-1]: - # print(f'FMAP Q after checkpoint load; {layer.self_attn.feature_map_q.layer.shape},{layer.self_attn.feature_map_q.layer[0,0,:4]}') - - fm_k_key = f'layers.{layer_idx}.self_attn.feature_map_k.mlp.layer' - if fm_k_key in adapter_weights: - layer.self_attn.load_feature_map_k(adapter_weights[fm_k_key]) - updated_keys.append(fm_k_key) - - weight_name = 'layers.{layer_idx}.self_attn.{proj}.lora_{a_or_b}.default.weight' - target_modules = ["q_proj", "k_proj", 
"v_proj", "o_proj"] - # target_modules = ["q_proj", "k_proj", "v_proj"] - # target_modules = ["k_proj", "v_proj"] - # target_modules = ["q_proj", "k_proj"] - - r = 8 - lora_alpha = 16 - lora_dropout = 0 - - for proj in target_modules: - lora_A_key = weight_name.format(layer_idx=layer_idx, proj=proj, a_or_b='A') - lora_B_key = weight_name.format(layer_idx=layer_idx, proj=proj, a_or_b='B') - if lora_A_key in adapter_weights: - weight_A = adapter_weights[lora_A_key] - weight_B = adapter_weights[lora_B_key] - delta_AB = get_delta_weight(weight_A, weight_B, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) - - # if layer_idx in [0, num_layers-1]: - # print(f'layer {layer_idx} weight_A.shape: {weight_A.shape} | weight_B.shape: {weight_B.shape} | delta_AB.shape: {delta_AB.shape}') - # print(f'layer {layer_idx} proj {proj} delta_AB', delta_AB.shape) - - if proj == 'o_proj': - # if layer_idx in [0, num_layers-1]: - # print("\n") - # print(f'Layer {layer_idx} {proj} weight before checkpoint load, {layer.self_attn.o_proj.weight.shape}, {layer.self_attn.o_proj.weight[0,:4]}') - - layer.self_attn.merge_lora_to_o_parallel(delta_AB) - - # if layer_idx in [0, num_layers-1]: - # print(f'Layer {layer_idx} {proj} weight after checkpoint load, {layer.self_attn.o_proj.weight.shape}, {layer.self_attn.o_proj.weight[0,:4]}') - else: - # if layer_idx in [0, num_layers-1] and proj in ['q_proj']: - # print(f'Layer {layer_idx} {proj} weight before checkpoint load, {layer.self_attn.qkv_proj.weight.shape}, {layer.self_attn.qkv_proj.weight[0,:4]}') - - layer.self_attn.merge_lora_to_qkv_parallel( - delta_AB, - loaded_shard_id=proj.split('_')[0], - total_num_heads=layer.self_attn.num_heads, - total_num_kv_heads=layer.self_attn.num_kv_heads,head_size=layer.self_attn.head_dim) - - # if layer_idx in [0, num_layers-1] and proj in ['q_proj']: - # print(f'Layer {layer_idx} {proj} weight after checkpoint load, {layer.self_attn.qkv_proj.weight.shape}, {layer.self_attn.qkv_proj.weight[0,:4]}') - updated_keys.append(lora_A_key) - updated_keys.append(lora_B_key) - - assert len(set(adapter_weights_copy.keys()) - set(updated_keys)) == 0, \ - f"UNUPDATED KEYS: {set(adapter_weights_copy.keys()) - set(updated_keys)}" - - -def transpose(weight, fan_in_fan_out): - if not fan_in_fan_out: - return weight - - if isinstance(weight, torch.nn.Parameter): - return torch.nn.Parameter(weight.T) - return weight.T - - -def get_delta_weight(weight_A: torch.Tensor, weight_B: torch.Tensor, - r: int = 8, lora_alpha: float = 16, lora_dropout: float = 0, - fan_in_fan_out: bool = False,): - - device = weight_B.device - dtype = weight_B.dtype - # From https://github.com/huggingface/peft/blob/850eeb5c3a5cf692f5612c7c733b13fde184e05d/src/peft/tuners/lora/layer.py#L512 - cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) - if cast_to_fp32: - weight_A = weight_A.float() - weight_B = weight_B.float() - scaling = lora_alpha / r - output_tensor = transpose(weight_B @ weight_A, fan_in_fan_out) * scaling - if cast_to_fp32: - output_tensor = output_tensor.to(dtype=dtype) - return output_tensor diff --git a/demos/vllm_integration/vllm_files/lolcats_inference_paged.py b/demos/vllm_integration/vllm_files/lolcats_inference_paged.py deleted file mode 100644 index 5c92304..0000000 --- a/demos/vllm_integration/vllm_files/lolcats_inference_paged.py +++ /dev/null @@ -1,912 +0,0 @@ -# coding=utf-8 -# Adapted from -# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 
2023 The vLLM team. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union - -import math -import os -import torch -import time -from collections import OrderedDict -from torch import nn -from torch.nn.parameter import Parameter -from transformers import LlamaConfig - -from vllm.attention import Attention, AttentionMetadata, AttentionType -from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) -from vllm.distributed import (divide, - # split_tensor_along_last_dim, - # tensor_model_parallel_all_gather, - # tensor_model_parallel_all_reduce - ) -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - get_compressed_tensors_cache_scale) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors -from vllm.utils import is_hip - - - -from vllm.model_executor.models.interfaces import SupportsLoRA -from vllm.model_executor.models.utils import PPMissingLayer, is_pp_missing_parameter, make_layers -# from .interfaces import SupportsLoRA -# from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers -from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaAttention, LlamaMLP - -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import ReplicatedLinear - -logger = init_logger(__name__) - - -### OURS for Linear attention implementation -# from peft import get_peft_model, LoraConfig, TaskType - -# PEFT_KWARGS = { -# 'r': 8, -# 'lora_alpha': 16, # 32 -# 'lora_dropout': 0.05, -# 'target_modules': ["q_proj", "v_proj", "k_proj", "o_proj"] -# } - -### Hybrid Attention -from vllm.attention import 
Attention, AttentionMetadata - - -### VLLM Llama Model -class FeatureMap(nn.Module): - """ - Learnable MLP in feature map. - - Full feature map is like f(xW + b) - -> This is the `W` and (optional) `b` part - """ - def __init__(self, - num_heads: int, - head_dim: int, - feature_dim: int, - dtype: torch.dtype, - device: torch.device, - eps: float = 1e-12, - **kwargs): - super().__init__() - self.num_heads = num_heads - self.head_dim = head_dim - self.feature_dim = feature_dim - self.dtype = dtype - self.device = device - self.eps = eps - self.init_weights_() - - def activation(self, x: torch.Tensor): - return torch.cat([ - torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1) - ], dim=-1).clamp(min=self.eps) - - def init_weights_(self): - self.layer = nn.Parameter(torch.zeros( - (self.num_heads, self.head_dim, self.feature_dim), - dtype=self.dtype, device=self.device, - )) - - def forward(self, x: torch.Tensor): - return self.activation( - torch.einsum('hdf,bhld->bhlf', self.layer, x.to(self.dtype))) - - -from dataclasses import dataclass -@dataclass -class LoLCacheParams: - is_prompt: bool = False - kv_state: torch.Tensor = torch.Tensor() - k_state: torch.Tensor = torch.Tensor() - kv_cache: torch.Tensor = torch.Tensor() - -@dataclass -class PageCache: - kv_cache: torch.Tensor = None - q_cache: torch.Tensor = None - - -class LlamaLolcatsAttentionActual(nn.Module): - """Attention layer. - - This class takes query, key, and value tensors as input. The input tensors - can either contain prompt tokens or generation tokens. - The class does the following: - - 1. Store the input key and value tensors in the KV cache. - 2. Perform (multi-head/multi-query/grouped-query) attention. - 3. Return the output tensor. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - num_kv_heads: int, - layer_idx: int, - ) -> None: - super().__init__() - self.num_heads = num_heads - self.head_size = head_size - self.num_kv_heads = num_kv_heads - - assert self.num_heads % self.num_kv_heads == 0 - self.num_queries_per_kv = num_heads // num_kv_heads - self.window_size = 64 - - # SA: inference cache - self.lolcats_cache = None - self.layer_idx = layer_idx - self.tp_rank = get_tensor_model_parallel_rank() - - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - fmap_q: FeatureMap, - fmap_k: FeatureMap, - window_factors: torch.Tensor, - state=None, - attn_metadata: AttentionMetadata = None - ) -> torch.Tensor: - # if self.layer_idx == 0: - # print(f"Initially: {query.shape=}, {key.shape=}, {value.shape=}") - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - positions = attn_metadata.seq_start_loc.tolist() - start, end = positions[0], positions[1] - - if self.lolcats_cache is None or end == num_prefill_tokens: - # reset cache - self._prepare_lolcats_cache() - if self.layer_idx == 0 and self.tp_rank == 0: - print("Resetting cache") - print(f"-- {num_prefill_tokens=}, {num_decode_tokens=}, {start=}, {end=}") - # print(self.page_cache.kv_cache) - - # num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. 
- query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, dim=1) - value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - - query = query.movedim(0, query.dim() - 2) - key = key.movedim(0, key.dim() - 2) - value = value.movedim(0, value.dim() - 2) - - if query.dim() == 3: - query = query.unsqueeze(0) - key = key.unsqueeze(0) - value = value.unsqueeze(0) - - query, key, value = self.get_full_key_value(query, key, value) - if self.layer_idx == 0 and self.tp_rank == 0: - print(f"-- after update: {query.shape=}, {key.shape=}, {value.shape=}") - - f_q = fmap_q(query) - f_k = fmap_k(key) - - window_size = 64 - window_factors = torch.nn.functional.sigmoid(window_factors) - linear_factors = 1 - - seq_len = query.shape[-2] - if num_decode_tokens >= 1 or seq_len == 1: - return self.recurrent_attention( - query, key, f_q, f_k, - value, window_factors, - linear_factors, - window_size, - fmap_q, fmap_k - ) - else: - out = self.superlinear_attention( - query, key, f_q, f_k, - value, - window_factors, linear_factors, - window_size - ) - return out - - - def get_full_key_value(self, query, key, value): - # add the current key and value to the cache - if self.page_cache.kv_cache is not None: - key = torch.cat([self.page_cache.kv_cache[:, 0], key], dim=-2) - value = torch.cat([self.page_cache.kv_cache[:, 1], value], dim=-2) - query = torch.cat([self.page_cache.q_cache, query], dim=-2) - else: - key = key - value = value - query = query - - # update the cache - self.page_cache.kv_cache = torch.stack([key, value], dim=1) - self.page_cache.q_cache = query - return query, key, value - - - def _create_mask(self, max_seq_len: int, window_size: int, is_window: bool) -> torch.Tensor: - l = window_size - m = math.ceil(max_seq_len / window_size) - mask = torch.block_diag(*[torch.ones((l, l))] * m) - mask += torch.roll(mask, -l, -1) - mask = mask[:max_seq_len, :max_seq_len] - mask = mask[None, None, ...] 
# b, h, q_len, k_len - mask = torch.tril(mask if is_window else 1 - mask).to(dtype=torch.bool) - return mask - - - def get_masks(self, window_size: int, q_len: int, k_len: int, device: torch.device) -> tuple[torch.Tensor]: - mask_window = self._create_mask(q_len, window_size, True).to(device) - mask_linear = self._create_mask(q_len, window_size, False).to(device) - return mask_window[:, :, :q_len, :k_len], mask_linear[:, :, :q_len, :k_len] - - - def _prepare_lolcats_cache(self): - self.lolcats_cache = LoLCacheParams() - self.page_cache = PageCache() - - - def _init_kv_cache(self, keys, values, f_k): - dtype = keys.dtype - - # decoding KV state (KV terms up to last window_size) - decode_kv_state = torch.einsum('bhlf,bhld->bhfd', - f_k[:, :, :-self.window_size], - values[:, :, :-self.window_size] - ) - - # shape is b, h, 1, f; note the 1 - decode_k_state = f_k[:, :, :-self.window_size].sum(dim=-2,keepdim=True) - self.lolcats_cache.kv_state = decode_kv_state - self.lolcats_cache.k_state = decode_k_state - - # update the cache - kv_cache = torch.stack([ - keys[:, :, -self.window_size:, :].float(), - values[:, :, -self.window_size:, :].float() - ], dim=1) - self.lolcats_cache.kv_cache = kv_cache - - - def superlinear_attention( - self, q: torch.Tensor, k: torch.Tensor, - f_q: torch.Tensor, f_k: torch.Tensor, - v: torch.Tensor, - window_factor: torch.Tensor, linear_factor: torch.Tensor, - window_size: int, - kv_state: torch.Tensor = None, k_state: torch.Tensor = None, - eps: float = 1e-12, - mask_value: float=-1e8 - ): - """ - Hybrid attention combining sliding window and linear attentions - """ - - mask_window, mask_linear = self.get_masks( - window_size, q.shape[-2], k.shape[-2], q.device) - - # 1. Sliding window (softmax attention) - a_sm = torch.einsum( - 'bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5) - a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value) - # torch.softmax(a_sm, dim=-1), but we account for the max when combining - a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) - a_sm = window_factor * torch.exp(a_sm - a_sm_max) - sum_sm = a_sm.sum(dim=-1, keepdim=True) - - # 2. Under window (linear attention) - a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float()) - a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0).to(q.dtype) - sum_ln = a_ln.sum(dim=-1, keepdim=True) - - # 3. 
Combine - # Allow outputs to also depend on prior kv_state and k_state - y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v) - y = (y / (sum_sm + sum_ln)) - - self._init_kv_cache(k, v, f_k) - return y.to(q.dtype) # attention weights only for the last chunk - - - def _update_kv_cache(self, keys, values, fmap_k): - # if self.tp_rank == 0 and self.layer_idx == 0: print("heyo 1 - hello update kv cache") - # get state from before - kv_state = self.lolcats_cache.kv_state - k_state = self.lolcats_cache.k_state - kv_cache_swa = self.lolcats_cache.kv_cache - k_cache = kv_cache_swa[:, 0] - v_cache = kv_cache_swa[:, 1] - - dtype = kv_state.dtype - - # update the linear attention states - # since we ignore the diag blocks, just grab last tokens of kv cache - cur_seq_len = k_cache.shape[-2] - # if self.tp_rank == 0 and self.layer_idx == 0: print(f"{cur_seq_len=}") - if cur_seq_len >= self.window_size: - # if self.tp_rank == 0 and self.layer_idx == 0: print(f"Updating the kv_state and k_state...") - k_state = fmap_k(k_cache[:, :, :1, :]) - v_state = v_cache[:, :, :1, :] - kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d - self.lolcats_cache.kv_state += kv_state.to(kv_state.dtype) - self.lolcats_cache.k_state += k_state - - # update swa states - if cur_seq_len < self.window_size: - # only add to cache - k_cache = torch.cat([k_cache, keys], dim=-2) - v_cache = torch.cat([v_cache, values], dim=-2) - else: - # remove oldest key and value and append - k_cache = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2) - v_cache = torch.cat([v_cache[:, :, 1:, :], values], dim=-2) - kv_cache_swa = torch.stack([k_cache, v_cache], dim=1) - self.lolcats_cache.kv_cache = kv_cache_swa - - # if self.tp_rank == 0 and self.layer_idx == 0: print("heyo 2 - bye update kv cache") - return self.lolcats_cache.kv_state, self.lolcats_cache.k_state, k_cache, v_cache - - - def recurrent_attention( - self, q: torch.Tensor, k: torch.Tensor, - f_q: torch.Tensor, f_k: torch.Tensor, - v: torch.Tensor, - window_factor: torch.Tensor, - linear_factor: torch.Tensor, - window_size: int, - fmap_q, fmap_k, - kv_state: torch.Tensor = None, - k_state: torch.Tensor = None, - eps: float = 1e-12, mask_value: float=-1e8 - ): - dtype = torch.float32 - kv_state, k_state, k_cache, v_cache = self._update_kv_cache(k, v, fmap_k) - - # Softmax attention terms - a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5) - a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) - a_sm = window_factor * torch.exp(a_sm - a_sm_max) - sum_sm = a_sm.sum(dim=-1, keepdim=True) - y_sm = torch.einsum('bhmn,bhnd->bhmd', a_sm.float(), v_cache.float()) - - # Combine with linear attention terms - f_q = fmap_q(q) - y_ln = linear_factor * torch.einsum('bhlf,bhfd->bhld', f_q.float(), kv_state.float()) - sum_ln = linear_factor * torch.einsum('bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None] - - y = y_sm + y_ln - attn_output = (y / (sum_sm + sum_ln)).to(q.dtype) - return attn_output - - -class LlamaLolcatsAttention(LlamaAttention): - def __init__(self, layer_idx, use_base_attn, *args, **kwargs): - - super().__init__(*args, **kwargs) - self.use_base_attn = use_base_attn - if self.use_base_attn: - # coppy the original self.attn into self.base_attn before we override - # use deepcopy to avoid any shared references - import copy - self.base_attn = copy.deepcopy(self.attn) - - self.attn = LlamaLolcatsAttentionActual(self.num_heads, - self.head_dim, - self.num_kv_heads, - layer_idx) - self.head_size = 
self.head_dim - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - _device = self.qkv_proj.weight.device - _dtype = self.qkv_proj.weight.dtype - - _feature_dim = 64 - _feature_map_kwargs = { - "num_heads": self.num_heads, - "head_dim": self.head_dim, - "feature_dim": _feature_dim, - "dtype": _dtype, - "device": _device, - } - self.feature_dim = _feature_dim - self.window_size = 64 - - tp_rank = get_tensor_model_parallel_rank() - self.tp_rank = tp_rank - self.layer_idx = layer_idx - - self.feature_map_q = FeatureMap(**_feature_map_kwargs) - self.feature_map_k = FeatureMap(**_feature_map_kwargs) - self.window_factors = nn.Parameter( - torch.ones(1, self.num_heads, 1, 1, device=_device, dtype=_dtype)) - - - def load_window_factors(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - if tp_size > 1: - - num_heads_per_rank = self.num_heads - start_idx = tp_rank * num_heads_per_rank - end_idx = start_idx + num_heads_per_rank - - if self.layer_idx == 0 and tp_rank == 0: - print(loaded_weight) - - if self.layer_idx < 2: - print(f"{num_heads_per_rank=}") - print(f"{tp_rank=}; {loaded_weight.shape=}; {start_idx=}; {end_idx=}") - - sharded_weight = loaded_weight[:, start_idx:end_idx, :, :] - - else: - - sharded_weight = loaded_weight - - assert self.window_factors.shape == sharded_weight.shape, \ - f"Shape mismatch: {self.window_factors.shape} vs {sharded_weight.shape}" - - with torch.no_grad(): - self.window_factors.copy_(sharded_weight) - - def load_feature_map_q(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - - if tp_size > 1: - - num_heads_per_rank = self.feature_map_q.num_heads - start_idx = tp_rank * num_heads_per_rank - end_idx = start_idx + num_heads_per_rank - - sharded_weight = loaded_weight[start_idx:end_idx, :, :] - - if sharded_weight.shape[-1] != self.feature_map_q.layer.shape[-1]: - sharded_weight = sharded_weight[:, :, :self.feature_map_q.layer.shape[-1]] - - else: - - sharded_weight = loaded_weight - - assert self.feature_map_q.layer.shape == sharded_weight.shape, \ - f"Shape mismatch: {self.feature_map_q.layer.shape} vs {sharded_weight.shape}" - - with torch.no_grad(): - self.feature_map_q.layer.copy_(sharded_weight) - - def load_feature_map_k(self, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - - if tp_size > 1: - - num_heads_per_rank = self.feature_map_k.num_heads - start_idx = tp_rank * num_heads_per_rank - end_idx = start_idx + num_heads_per_rank - - sharded_weight = loaded_weight[start_idx:end_idx, :, :] - - if sharded_weight.shape[-1] != self.feature_map_k.layer.shape[-1]: - sharded_weight = sharded_weight[:, :, :self.feature_map_k.layer.shape[-1]] - - else: - - sharded_weight = loaded_weight - - assert self.feature_map_k.layer.shape == sharded_weight.shape, \ - f"Shape mismatch: {self.feature_map_k.layer.shape} vs {sharded_weight.shape}" - - with torch.no_grad(): - self.feature_map_k.layer.copy_(sharded_weight) - # self.feature_map_k.layer.normal_(std=1) - - def merge_lora_to_qkv_parallel(self, # param: Parameter, - loaded_delta: torch.Tensor, - loaded_shard_id: str = 'q', - total_num_heads: int = 32, - total_num_kv_heads: int = 4, - head_size: int = 128): - """ - Merge computed delta_AB into QKV parallel weights - - Based off of vLLM linear layer: 
https://github.com/vllm-project/vllm/blob/bc6e42a9b19364e07da9f279edd81796541d147d/vllm/model_executor/layers/linear.py#L762 - then Rahul, then Claude 3.5 Sonnet - - model.layers.1.self_attn.qkv_proj.weight torch.Size([1280, 8192]) - --> output_dim 0 - model.layers.1.self_attn.o_proj.weight torch.Size([8192, 1024]) - --> output_dim 0 - - apply this three times for q, k, and v LoRA deltas to the same layer - """ - - param = self.qkv_proj.weight - param_data = param.data - output_dim = getattr(param, "output_dim", None) - - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - # num_heads = divide(total_num_heads, tp_size) - # if tp_size >= total_num_kv_heads: - # num_kv_heads = 1 - # # num_kv_head_replicas = divide(tp_size, total_num_kv_heads) - # else: - # num_kv_heads = divide(total_num_kv_heads, tp_size) - # # num_kv_head_replicas = 1 - num_heads = total_num_heads - num_kv_heads = total_num_kv_heads - - num_original_kv_heads = 8 # all Llama 3.1 models have 8 kv heads in total - - num_kv_head_replicas = tp_size // num_original_kv_heads - - if loaded_shard_id == "q": - shard_offset = 0 - shard_size = num_heads * head_size - elif loaded_shard_id == "k": - shard_offset = num_heads * head_size - shard_size = num_kv_heads * head_size - elif loaded_shard_id == "v": - shard_offset = (num_heads + num_kv_heads) * head_size - shard_size = num_kv_heads * head_size - - # print(f"{tp_rank=}, {tp_size=}") - if loaded_shard_id == "q": - start_idx = tp_rank * shard_size - else: - start_idx = (tp_rank // num_kv_head_replicas) * shard_size - - device = param_data.device - - param_data = param_data.narrow(output_dim, shard_offset, shard_size) - # print(f'{loaded_shard_id=}') - # print(f'{shard_offset=}, {shard_size=}, {shard_offset+shard_size=}') - # print(f'{output_dim=}, {start_idx=}, {param_data.shape=}') - # print('-' * 10) - - # self.qkv_proj.weight.data[shard_offset:shard_offset+shard_size, :] += ( - # loaded_delta.narrow(output_dim, start_idx, shard_size).to(device) - # ) - # print(f'Loading {loaded_shard_id} {start_idx}-{start_idx+shard_size} into {param_data.shape}, which is slice({shard_offset}, {shard_offset+shard_size})') - try: - param_data.copy_(param_data + loaded_delta.narrow(output_dim, start_idx, shard_size).to(device)) - # print(f"Loaded {loaded_shard_id} into {param_data.shape}") - except Exception as e: - print(f"Error: {e}") - print(f"{loaded_shard_id=}") - print(f"{output_dim=}") - print(f"{start_idx=}") - print(f"{shard_size=}") - print(f"{param_data.shape=}") - print(f"{loaded_delta.shape=}") - print(f"{tp_rank=}") - print(f"{tp_size=}") - - def merge_lora_to_o_parallel(self, - loaded_delta: torch.Tensor): - """ - Merge computed delta_AB into output projection (RowParallel linear) - """ - param = self.o_proj.weight - param_data = param.data - input_dim = getattr(param, "input_dim", None) - device = param_data.device - - # print('o_proj {input_dim=}') - - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - - if input_dim is not None: - shard_size = param_data.shape[input_dim] - start_idx = tp_rank * shard_size - loaded_delta = loaded_delta.narrow(input_dim, start_idx, shard_size) - - # Special case for loading scales off disk, which often do not - # have a shape (such as in the case of AutoFP8). 
- if len(loaded_delta.shape) == 0: - loaded_delta = loaded_delta.reshape(1) - - # print('{param_data.shape=} | {loaded_delta.shape=}') - # assert param_data.shape == loaded_delta.shape - param_data.copy_(param_data + loaded_delta.to(device)) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - problem_idx: int - ) -> torch.Tensor: - ndim = hidden_states.dim() - qkv, _ = self.qkv_proj(hidden_states) - seq_len = hidden_states.shape[-2] - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - - attn_output = self.attn( - q, k, v, - fmap_q=self.feature_map_q, - fmap_k=self.feature_map_k, - window_factors=self.window_factors, - state=None, - attn_metadata=attn_metadata - ) - - ref_output = None - expt_tag = '_cria_alpaca_final' - if self.use_base_attn and self.layer_idx % 9 == 0: - ref_output = self.base_attn( - q, k, v, - attn_metadata=attn_metadata, - kv_cache=kv_cache, - ) - - dir_path = f"/data/simran/mmlu_hybrid_outputs_{expt_tag}/" - if not os.path.exists(dir_path): os.makedirs(dir_path, exist_ok=True) - fpath = f"{dir_path}/our_attn_output_problem{problem_idx}_rank{self.tp_rank}_layer{self.layer_idx}.pt" - torch.save(attn_output, fpath) - fpath = f"{dir_path}/ref_attn_output_problem{problem_idx}_rank{self.tp_rank}_layer{self.layer_idx}.pt" - torch.save(ref_output, fpath) - # print(f"Saved!") - # end save stuff - - # outputs - full_seq_len = attn_output.shape[-2] # in case we updated the length - attn_output = attn_output.transpose(1, 2).contiguous().view( - -1, full_seq_len, self.num_heads * self.head_dim - ) - output, _ = self.o_proj(attn_output) - if output.dim() > ndim: - output = output.squeeze(0) - output = output[-seq_len:, ...] 
# put back the original seq_len - - if self.use_base_attn and self.layer_idx % 9 == 0: - ref_y, _ = self.o_proj(ref_output) - dir_path = f"/data/simran/mmlu_hybrid_y_outs_{expt_tag}/" - if not os.path.exists(dir_path): os.makedirs(dir_path, exist_ok=True) - fpath = f"{dir_path}/our_y_out_problem{problem_idx}_rank{self.tp_rank}_layer{self.layer_idx}.pt" - torch.save(output, fpath) - fpath = f"{dir_path}/ref_y_out_problem{problem_idx}_rank{self.tp_rank}_layer{self.layer_idx}.pt" - torch.save(ref_y, fpath) - - return output - - -class LlamaLolcatsForCausalLM(LlamaForCausalLM): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - print(f"LOLCATS!!!: Loading model with config: {self.config}") - - tp_rank = get_tensor_model_parallel_rank() - self.tp_rank = tp_rank - - softmax_attentions = getattr(self.config, 'softmax_attentions', []) - print(f"{softmax_attentions=}") - - use_base_attn = getattr(self.config, 'use_base_attn', False) - - for i in range(len(self.model.layers)): - if i in softmax_attentions: - pass - else: - self.model.layers[i].self_attn = LlamaLolcatsAttention( - i, - use_base_attn=use_base_attn, - config=self.config, - hidden_size=self.config.hidden_size, - num_heads=self.config.num_attention_heads, - num_kv_heads=getattr(self.config, "num_key_value_heads", - self.config.num_attention_heads), - rope_theta=self.config.rope_theta, - rope_scaling=self.config.rope_scaling, - ) - print(self.model) - - - def get_device(self): - device = next(self.parameters()).device - return str(device) - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - super().load_weights(weights) - - # r = 8 - # lora_alpha = 16 - # lora_dropout = 0 - - # model_size = 8 - # FINETUNE_PATH = '/home/rahul/code/lolcats/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-ms=2500-se=0-re=100-lzi=1-bs=1-gas=8-nte=2-ms=2500-se=0-re=100_ft.pt' - # MLP_PATH = '/home/rahul/code/lolcats/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-ms=2500-se=0-re=100-lzi=1_distill.pt' - # # merge the MLP and FINETUNE weights as adapter weights - # adapter_weights = torch.load(FINETUNE_PATH, weights_only=True) - # adapter_weights.update(torch.load(MLP_PATH, weights_only=True)) - # print(adapter_weights.keys()) - # # only keep any weight with 'feature' or 'window' or 'lora' in the key - # adapter_weights = {k: v for k, v in adapter_weights.items() if 'feature' in k or 'window' in k or 'lora' in k} - - # model_size = 70 - # # PATH = f'/data/rahul/checkpoints/{model_size}b.pt' - # PATH = f'/home/rahul/code/lolcats/ckpt_lora-dl-d=distill_redpajama_xent1_mse1000_lr1e-2-m=distill_llama3_1_70b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_redpajama-dcs=512-se=0-re=4-lzi=1-dcs=512-se=0-re=4.pt' - - ########### 405 B ############ - - # PATH = '/home/rahul/code/lolcats/ckpts/seqlen768.pt' # 405B at 768 seqlen - - # 1. Alpaca Cria QV Rank 4 -- with hybridization - PATH = '/home/rahul/code/lolcats/ckpt_lora-dl-d=distill_llama_405b_xent1_mse1000_lr1e-2-m=distill_llama3_1_405b_lk_smd_wtk64_fd64_w01_h72_80_117_125-f=finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-ef=finetune_llama_405b_qkvo_e2_h72_80_117_125-ft_lora=0-se=0-re=0-alpaca.pt' - - # 2. 
Alpaca Cria QV Rank 4 -- pure - # PATH = '/home/rahul/code/lolcats/ckpt_lora-dl-d=distill_llama_405b_xent1_mse1000_lr1e-2-m=distill_llama3_1_405b_lk_smd_wtk64_fd64_w01-f=finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-ef=finetune_llama_405b_qkvo_e2-ft_lora=0-se=0-re=0-ef=finetune_llama_405b_qkvo_e2-ft_lora=0_epoch2.pt' - - # 3. RP Cria QV Rank 4 -- pure - # PATH = '/home/rahul/code/lolcats/ckpts/cria_rp.pt' # 780.pt step - # PATH = '/home/rahul/code/lolcats/ckpt_lora-dl-d=rp_distill_llama_405b_xent1_mse1000_lr1e-2-m=distill_llama3_1_405b_lk_smd_wtk64_fd64_w01-f=rp_finetune_llama_40b_qv_hparams-s=0-se=0-re=0-ef=finetune_llama_405b_qkvo_e2_rp-ft_lora=0-se=0-re=0-s=1670.pt' - - print(f"PATH INFERENCE: {PATH}!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") - adapter_weights_path = os.getenv("LOLCATS_ADAPTER_PATH", PATH) - adapter_weights = torch.load(adapter_weights_path, weights_only=True) - - adapter_weights_copy = OrderedDict({}) - - for key, value in adapter_weights.items(): - key_suffix = key[key.rindex("model.")+6:] - adapter_weights_copy[key_suffix] = value - - adapter_weights = adapter_weights_copy - updated_keys = [] - - print("\n") - num_layers = len(self.model.layers) - for layer_idx, layer in enumerate(self.model.layers): - if layer_idx == 0: - print(f'Weight factors before checkpoint load {self.tp_rank=}, {layer.self_attn.window_factors.shape}, {layer.self_attn.window_factors.flatten()}') - - window_factors_key = f'layers.{layer_idx}.self_attn.window_factors' - if window_factors_key in adapter_weights: - layer.self_attn.load_window_factors(adapter_weights[window_factors_key]) - updated_keys.append(window_factors_key) - - if layer_idx == 0: - print(f'Weight factors after checkpoint load {self.tp_rank=}, {layer.self_attn.window_factors.shape}, {layer.self_attn.window_factors.flatten()}') - - fm_q_key = f'layers.{layer_idx}.self_attn.feature_map_q.mlp.layer' - if fm_q_key in adapter_weights: - # if layer_idx in [0, num_layers-1]: - # # print("\n") - # # print(f'FMAP Q before checkpoint load {self.tp_rank=}, {layer.self_attn.feature_map_q.layer.shape}, {layer.self_attn.feature_map_q.layer[0,0,:4]}') - - layer.self_attn.load_feature_map_q(adapter_weights[fm_q_key]) - updated_keys.append(fm_q_key) - - # if layer_idx in [0, num_layers-1]: - # print(f'FMAP Q after checkpoint load; {layer.self_attn.feature_map_q.layer.shape},{layer.self_attn.feature_map_q.layer[0,0,:4]}') - - fm_k_key = f'layers.{layer_idx}.self_attn.feature_map_k.mlp.layer' - if fm_k_key in adapter_weights: - layer.self_attn.load_feature_map_k(adapter_weights[fm_k_key]) - updated_keys.append(fm_k_key) - - weight_name = 'layers.{layer_idx}.self_attn.{proj}.lora_{a_or_b}.default.weight' - target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"] - # target_modules = ["q_proj", "k_proj", "v_proj"] - # target_modules = ["k_proj", "v_proj"] - # target_modules = ["q_proj", "k_proj"] - - r = 8 - lora_alpha = 16 - lora_dropout = 0 - - for proj in target_modules: - lora_A_key = weight_name.format(layer_idx=layer_idx, proj=proj, a_or_b='A') - lora_B_key = weight_name.format(layer_idx=layer_idx, proj=proj, a_or_b='B') - if lora_A_key in adapter_weights: - weight_A = adapter_weights[lora_A_key] - weight_B = adapter_weights[lora_B_key] - delta_AB = get_delta_weight(weight_A, weight_B, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) - - # if layer_idx in [0, num_layers-1]: - # print(f'layer {layer_idx} weight_A.shape: {weight_A.shape} | weight_B.shape: {weight_B.shape} | delta_AB.shape: {delta_AB.shape}') - # print(f'layer {layer_idx} proj 
{proj} delta_AB', delta_AB.shape) - - if proj == 'o_proj': - # if layer_idx in [0, num_layers-1]: - # print("\n") - # print(f'Layer {layer_idx} {proj} weight before checkpoint load, {layer.self_attn.o_proj.weight.shape}, {layer.self_attn.o_proj.weight[0,:4]}') - - layer.self_attn.merge_lora_to_o_parallel(delta_AB) - - # if layer_idx in [0, num_layers-1]: - # print(f'Layer {layer_idx} {proj} weight after checkpoint load, {layer.self_attn.o_proj.weight.shape}, {layer.self_attn.o_proj.weight[0,:4]}') - else: - # if layer_idx in [0, num_layers-1] and proj in ['q_proj']: - # print(f'Layer {layer_idx} {proj} weight before checkpoint load, {layer.self_attn.qkv_proj.weight.shape}, {layer.self_attn.qkv_proj.weight[0,:4]}') - - layer.self_attn.merge_lora_to_qkv_parallel( - delta_AB, - loaded_shard_id=proj.split('_')[0], - total_num_heads=layer.self_attn.num_heads, - total_num_kv_heads=layer.self_attn.num_kv_heads,head_size=layer.self_attn.head_dim) - - # if layer_idx in [0, num_layers-1] and proj in ['q_proj']: - # print(f'Layer {layer_idx} {proj} weight after checkpoint load, {layer.self_attn.qkv_proj.weight.shape}, {layer.self_attn.qkv_proj.weight[0,:4]}') - updated_keys.append(lora_A_key) - updated_keys.append(lora_B_key) - - assert len(set(adapter_weights_copy.keys()) - set(updated_keys)) == 0, \ - f"UNUPDATED KEYS: {set(adapter_weights_copy.keys()) - set(updated_keys)}" - - -def transpose(weight, fan_in_fan_out): - if not fan_in_fan_out: - return weight - - if isinstance(weight, torch.nn.Parameter): - return torch.nn.Parameter(weight.T) - return weight.T - - -def get_delta_weight(weight_A: torch.Tensor, weight_B: torch.Tensor, - r: int = 8, lora_alpha: float = 16, lora_dropout: float = 0, - fan_in_fan_out: bool = False,): - - device = weight_B.device - dtype = weight_B.dtype - # From https://github.com/huggingface/peft/blob/850eeb5c3a5cf692f5612c7c733b13fde184e05d/src/peft/tuners/lora/layer.py#L512 - cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) - if cast_to_fp32: - weight_A = weight_A.float() - weight_B = weight_B.float() - scaling = lora_alpha / r - output_tensor = transpose(weight_B @ weight_A, fan_in_fan_out) * scaling - if cast_to_fp32: - output_tensor = output_tensor.to(dtype=dtype) - return output_tensor - diff --git a/demos/vllm_integration/vllm_files/test_vllm_aw.py b/demos/vllm_integration/vllm_files/test_vllm_aw.py deleted file mode 100644 index 9738367..0000000 --- a/demos/vllm_integration/vllm_files/test_vllm_aw.py +++ /dev/null @@ -1,82 +0,0 @@ -import os -import math -from openai import OpenAI - -def calculate_perplexity(logprobs): - total_log_prob = 0 - token_count = 0 - - for token_logprobs in logprobs[1:]: - if token_logprobs: - total_log_prob += list(token_logprobs.values())[0].logprob - token_count += 1 - - if token_count == 0: - return float('inf') - - print(token_count) - perplexity = math.exp(-total_log_prob / token_count) - return perplexity - -def calc_perplexity_serve(logprobs, trim=1): - logprobs = logprobs[:-trim] - logprobs = [x for x in logprobs if x is not None] - print(f"{len(logprobs)=}") - return math.exp(-sum(logprobs) / len(logprobs)) - -if __name__ == '__main__': - use_served_model = True - model_size = 70 # [8, 70] - PATH = f"/data/rahul/models/Meta-Llama-3.1-{model_size}B/" - CKPT_PATH = f'/data/rahul/checkpoints/{model_size}b.pt' - openai_api_base = "http://0.0.0.0:8000/v1" - - os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7" - os.environ["LOLCATS_ADAPTER_PATH"] = CKPT_PATH - - prompts = [ - "I'm 
Michael:- 3rd-year Computer Science PhD student advised by Chris Ré.- Labmate at HazyResearch, Stanford AI Lab, Stanford Machine Learning Group. I currently work on deep learning architectures for expressive + efficient long sequence modeling, and using these advances to enable learning from new tasks and data types and I also care about deep learning robustness. I received my A.B. in", - # Statistics and Computer Science at Harvard in 2020. I'm grateful to have" - # "The 2024 Summer Paralympics (French: Jeux paralympiques d'été de 2024), also known as the Paris 2024 Paralympic Games, and branded as Paris 2024, is the 17th Summer Paralympic Games, an international multi-sport parasports event governed by the International Paralympic Committee, being held in Paris, France, from 28 August to 8 September 2024. These games mark the first time Paris is hosting the Summer Paralympics and the second time that France is hosting the new ", - # "Manchester United Football Club, commonly referred to as Man United (often Man United (often stylised as Man Utd), or simply United, is a Man United (often stylised as Man Utd), or simply United, is a professional football club based in Old Trafford, Greater Manchester, England. They compete in the Premier League, the top tier of English football. Nicknamed the Red Devils, they were founded as Newton Heath LYR Football Club in 1878, but changed their name to Manchester United in 1902. After a spell playing in Clayton, Manchester, the club moved to their current stadium, Old Trafford, in 1910. " - ] - - if use_served_model: - client = OpenAI(base_url=openai_api_base, api_key="EMPTY") - models = client.models.list() - model = models.data[0].id - tokens = 3 - outputs = client.completions.create( - model=model, - prompt=prompts, - temperature=0, - logprobs=1, - max_tokens=tokens, - seed=0, - echo=True, - ) - for prompt, choice in zip(prompts, outputs.choices): - logprobs = choice.logprobs.token_logprobs - print(f"Prompt: {len(prompt.split())}\n{prompt}") - print(f"Completion: {choice.text.replace(prompt, '')}") - print(f'Perplexity: {calc_perplexity_serve(logprobs, trim=tokens)}') - print("\n") - else: - - os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG" - - from vllm import ModelRegistry, LLM, SamplingParams - - from src.model.modeling_llama_vllm import LlamaLolcatsForCausalLM - ModelRegistry.register_model("LlamaLolcatsForCausalLM", LlamaLolcatsForCausalLM) - sampling_params = SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1, min_tokens=1, max_tokens=1) - llm = LLM(model=PATH, tensor_parallel_size=8, enforce_eager=True) - outputs = llm.generate( - prompts, - sampling_params, - ) - logprobs = output.prompt_logprobs - for output in outputs: - print(f"Perplexity: {calculate_perplexity(output.prompt_logprobs):.4f}") - - # Print the outputs. 
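For reference, the perplexity that the deleted `test_vllm_aw.py` script reports is simply the exponential of the negative mean token log-probability, with unscored tokens skipped. A minimal self-contained sketch of that calculation (the helper name and the example log-probabilities are illustrative, not taken from the repository):

```python
import math

def perplexity_from_logprobs(token_logprobs):
    """Perplexity = exp(-mean(logprob)); skip None entries (e.g. the first
    prompt token, which is returned without a log-probability)."""
    scored = [lp for lp in token_logprobs if lp is not None]
    if not scored:
        return float('inf')
    return math.exp(-sum(scored) / len(scored))

# Average logprob is -1.0, so perplexity is e ≈ 2.718
print(perplexity_from_logprobs([None, -0.5, -1.0, -1.5]))
```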
diff --git a/src/model/convert_model.py b/src/model/convert_model.py
index f4b334a..41dd1dc 100644
--- a/src/model/convert_model.py
+++ b/src/model/convert_model.py
@@ -128,7 +128,7 @@ def get_attention(attention_type: str, **kwargs: any):
 
     ## TK generation build (requires Thunderkittens)
     elif attention_type == 'lolcats_llama_window_tk_gen':
-        from .linear_attention.linear_window_attention_tk_gen import LolcatsWindowAttentionTKGen
+        from .linear_attention import LolcatsWindowAttentionTKGen
         return partial(LolcatsWindowAttentionTKGen, **kwargs)
 
     else:
@@ -144,6 +144,10 @@ def get_attention_cache(attention_type: str, past_key_values: any = None):
         return past_key_values
 
     # print(f'Returning attention cache based on attention_type == {attention_type}')
+    elif 'lolcats_llama_window_tk_gen' in attention_type:
+        from .linear_attention import LinearAttentionTKWindowGenerationCache
+        return LinearAttentionTKWindowGenerationCache()
+
     elif 'llama_window_tk' in attention_type:
         from .linear_attention import LinearAttentionTKWindowCache
         return LinearAttentionTKWindowCache()
diff --git a/src/model/linear_attention/__init__.py b/src/model/linear_attention/__init__.py
index dd3e49f..2482d28 100644
--- a/src/model/linear_attention/__init__.py
+++ b/src/model/linear_attention/__init__.py
@@ -17,3 +17,7 @@ from .linear_window_attention_sw_long import (
     LolcatsSlidingWindowLongAttention,
 )
 
+from .linear_window_attention_tk_gen import (
+    LolcatsWindowAttentionTKGen,
+    LinearAttentionTKWindowGenerationCache
+)
diff --git a/src/model/linear_attention/linear_window_attention_tk_gen.py b/src/model/linear_attention/linear_window_attention_tk_gen.py
index 9dcbb27..ffc2abf 100644
--- a/src/model/linear_attention/linear_window_attention_tk_gen.py
+++ b/src/model/linear_attention/linear_window_attention_tk_gen.py
@@ -5,12 +5,15 @@
 import torch
 import torch.nn.functional as F
 
-from thunderkittens import hedgehog as tk_window_hedgehog_attention
+try:
+    from thunderkittens import hedgehog as tk_window_hedgehog_attention
+    print("Successfully imported ThunderKittens for TK window attention")
+except ImportError:
+    print("Failed to import ThunderKittens for TK window attention")
 
 from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
 from .linear_attention import LinearAttentionState
-
 
 class LolcatsWindowAttentionTKGen(LolcatsTKWindowLongAttention):
     def __init__(self, *args, window_size: int = 64, **kwargs):
         super().__init__(*args, **kwargs)
@@ -18,6 +21,11 @@ def __init__(self, *args, window_size: int = 64, **kwargs):
         self.base_inference = False
         self.window_size = 64 # hard-coded support for TK kernel
         self.decode_window_size = 64
+
+        b, h, l, d = 1, 32, 8192, 128
+        self.y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device='cuda')
+        self.kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device='cuda')
+        self.k_state = torch.zeros(b, h, d, dtype=torch.float32, device='cuda')
 
     def forward(self,
                 hidden_states: torch.Tensor,
@@ -61,16 +69,16 @@ def forward(self,
                           + linear_factors * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float()))
                 sum_ln = linear_factors * torch.einsum(
                     'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None]
-                y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
+                self.y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
 
         else: # Process prefill
             # Use TK-implemented linear + terrace window attention
             b, h, l, d = q.shape
             device = q.device
             # tk.hedgehog arguments
-            y_pred = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device=device)
-            kv_state = torch.zeros(b, h, d, d,
dtype=torch.float32, device=device) - k_state = torch.zeros(b, h, d, dtype=torch.float32, device=device) + # y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device=device) + # kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device=device) + # k_state = torch.zeros(b, h, d, dtype=torch.float32, device=device) betas = F.sigmoid(self.window_factors[0, :, 0, 0].to(dtype=torch.float32)) alphas = (1 - betas if self.affine_attention_factors else torch.ones(betas.shape, dtype=torch.float32, device=device)) @@ -83,13 +91,15 @@ def forward(self, # f_k[:, :, :-self.window_size], # v[:, :, :-self.window_size]) # b, h, f, d # 4. k_state = f_k[:, :, :-self.window_size].sum(dim=-2) # b, h, d + tk_window_hedgehog_attention(q.contiguous(), k.contiguous(), v.contiguous(), - y_pred, k_state, kv_state, + self.y_true, self.k_state, self.kv_state, q_map, k_map, alphas, betas) - past_key_value.update_with_kv(kv_state, k_state.unsqueeze(-2), k, v, self.layer_idx) + + past_key_value.update_with_kv(self.kv_state, self.k_state.unsqueeze(-2), k, v, self.layer_idx) # Concatenate heads and apply output projection - y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) + y_true = self.y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) y_true = self.o_proj(y_true) return y_true, None, past_key_value
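A note on the `get_attention_cache` change above: the new `lolcats_llama_window_tk_gen` branch is placed before the existing `llama_window_tk` branch because the latter string is a substring of the former; reversing the order would shadow the generation cache. A minimal usage sketch of the new dispatch (assumes the repository root is on `PYTHONPATH`; only names touched by this patch are used):

```python
from src.model.convert_model import get_attention_cache

# The TK-generation attention type now gets its dedicated cache class,
# while the other lolcats window variants keep their existing caches.
cache = get_attention_cache('lolcats_llama_window_tk_gen')
print(type(cache).__name__)  # LinearAttentionTKWindowGenerationCache
```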