diff --git a/README.md b/README.md
index 089bd55..83c97c5 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,8 @@ In this README:
- Getting started with dependencies, installation, and experiment configs
- Sample commands for 7B+ LLMs (e.g., Mistral-7B-v0.1, Llama-3-8B, Llama-3.1-8B; anything you can run on a single GPU)

+In the [lolcats-scaled branch](https://github.com/HazyResearch/lolcats/tree/lolcats-scaled), we provide details for the larger 70B and 405B LLMs.
+
---

## Getting started
@@ -107,11 +109,11 @@ python setup.py install

### ThunderKittens linear attention + sliding window kernel

-We also implemented a fused linear attention + sliding window kernel with [ThunderKittens](https://github.com/HazyResearch/ThunderKittens).
+We also implemented a fused linear attention + sliding window kernel with the [ThunderKittens CUDA framework](https://github.com/HazyResearch/ThunderKittens).

For the linearizing layer, see [`./src/model/linear_attention/linear_window_attention_tk_gen.py`](https://github.com/HazyResearch/lolcats/blob/main/src/model/linear_attention/linear_window_attention_tk_gen.py)

-But full repository support coming soon! (requires an import)
+You can install the kernel and benchmark the 8B models (LoLCATS-linearized and Llama Transformer) with and without our ThunderKittens CUDA kernel using the details in [the demos README](demos/README.md). Our 8B model will auto-download from our [HuggingFace checkpoint](https://huggingface.co/hazyresearch/lolcats-llama-3.1-8b-distill).

### More!
@@ -247,27 +249,14 @@ python distill_llama.py --model_config distill_llama3_1_8b_lk_t2r \

### Demoing linear attention 7B+ models

-The above scripts will save two checkpoints: (1) for the learned attention feature maps (denoted by a `_distill` suffix), (2) for the LoRA finetuning weights (denoted by a `_ft` suffix). We uploaded a couple starter checkpoints in `./checkpoints/`, where for any linearized LLM we only need to save these layers (~0.2% of a 7B LLM's parameters). (We also provide additional checkpoints on HuggingFace).
-
-To chat with these models (albeit in an unoptimized PyTorch implementation), you can run:
-
-**Llama 3.1 8B**
-```bash
-python -Wignore demo_lolcats_llm.py \
---attn_mlp_checkpoint_path './checkpoints/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=420-lzi=1_distill.pt' \
---finetune_checkpoint_path './checkpoints/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=420-lzi=1-bs=1-gas=8-nte=2-ms=-1-se=0-re=420_ft.pt' \
---num_generations 1 --benchmark
-```
+The above scripts will save two checkpoints: (1) the learned attention feature maps (denoted by a `_distill` suffix), and (2) the LoRA finetuning weights (denoted by a `_ft` suffix). For any linearized LLM, we only need to save these layers (~0.2% of a 7B LLM's parameters). We have provided [sample checkpoints on HuggingFace](https://huggingface.co/collections/hazyresearch/lolcats-670ca4341699355b61238c37).

-**Llama 3 8B**
+Use the commands provided at `demos/demo_8b.sh` to run inference with our LoLCATS Llama 3.1 8B checkpoint, which will be downloaded from HuggingFace.
The downloaded checkpoints require under <1GB, and are inserted into your local Meta Llama 3.1 model in 16-bit precision -- please ensure you have downloaded the base model and specify your path to it in the configs in demo_8b.sh. To run the demo: ```bash -python -Wignore demo_lolcats_llm.py \ ---attn_mlp_checkpoint_path './checkpoints/distill_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=12-lzi=1_distill.pt' \ ---finetune_checkpoint_path './checkpoints/distill_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=12-lzi=1-bs=1-gas=8-nte=2-se=0-re=12_ft.pt' \ ---num_generations 1 --benchmark +cd lolcats/ +bash demos/demo_8b.sh ``` - --- ### LM Evaluation Harness Evaluation diff --git a/checkpoints/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=420-lzi=1-bs=1-gas=8-nte=2-ms=-1-se=0-re=420_ft.pt b/checkpoints/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=420-lzi=1-bs=1-gas=8-nte=2-ms=-1-se=0-re=420_ft.pt deleted file mode 100644 index 306f97f..0000000 Binary files a/checkpoints/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=420-lzi=1-bs=1-gas=8-nte=2-ms=-1-se=0-re=420_ft.pt and /dev/null differ diff --git a/checkpoints/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=420-lzi=1_distill.pt b/checkpoints/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=420-lzi=1_distill.pt deleted file mode 100644 index 26282c7..0000000 Binary files a/checkpoints/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=420-lzi=1_distill.pt and /dev/null differ diff --git a/checkpoints/distill_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=12-lzi=1-bs=1-gas=8-nte=2-se=0-re=12_ft.pt b/checkpoints/distill_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=12-lzi=1-bs=1-gas=8-nte=2-se=0-re=12_ft.pt deleted file mode 100644 index 55f27b1..0000000 Binary files a/checkpoints/distill_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=12-lzi=1-bs=1-gas=8-nte=2-se=0-re=12_ft.pt and /dev/null differ diff --git a/checkpoints/distill_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=12-lzi=1_distill.pt 
b/checkpoints/distill_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=12-lzi=1_distill.pt
deleted file mode 100644
index 39a5533..0000000
Binary files a/checkpoints/distill_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=12-lzi=1_distill.pt and /dev/null differ
diff --git a/demos/README.md b/demos/README.md
new file mode 100644
index 0000000..c88d4a3
--- /dev/null
+++ b/demos/README.md
@@ -0,0 +1,58 @@
+
+## Demos
+
+We describe how to use LoLCATS checkpoints. We include:
+1. A demo script to talk to our models using Hugging Face checkpoints
+2. A demo script to benchmark the pretrained 8B linearized model versus the base softmax attention model
+3. Code to reproduce our 70B and 405B MMLU numbers using our uploaded HuggingFace checkpoints
+4. Coming soon: VLLM integration with custom LoLCATS CUDA kernels!
+
+### Talk to pre-trained LoLCATS LLMs
+
+Use the commands provided at `demo_8b.sh` to run inference with our LoLCATS Llama 3.1 8B checkpoint, which will be downloaded from Hugging Face. The downloaded checkpoints require less than 1GB of storage, and are inserted into your local Meta Llama 3.1 model in 16-bit precision -- please ensure you have downloaded the base model and specify your path to it in the configs in `demo_8b.sh`. To run the demo:
+```bash
+bash demo_8b.sh
+```
+
+### Fast inference with custom CUDA kernels
+
+We provide a custom CUDA prefill kernel written in the [ThunderKittens framework](https://github.com/HazyResearch/ThunderKittens).
+
+To install the kernel:
+```bash
+# Clone the repo
+git clone https://github.com/HazyResearch/ThunderKittens
+cd ThunderKittens
+# In config.py, select 'hedgehog', then run:
+source env.src
+python setup.py install
+```
+
+For a quick end-to-end comparison of the prefill speed of the linearized LoLCATS 8B model versus the base Llama 3.1 8B model, we provide a script:
+```bash
+bash benchmark_8b.sh
+```
+
+The script will print the inference tokens per second for each method.
+
+### 5-shot MMLU Eval
+
+First, get the 5-shot MMLU data. We directly saved the tokenized examples produced by the [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) codebase to a pickle file:
+```
+cd lolcats/demos/llm_mmlu_eval/
+unzip mmlu.pkl.zip
+```
+
+We provide scripts to evaluate our 70B and 405B LoLCATS linearized checkpoints (hosted on HuggingFace) on MMLU:
+```bash
+cd lolcats/
+bash demos/llm_mmlu_eval/demo_70b.sh    # runs on 1 8x80GB H100 node
+sbatch demos/llm_mmlu_eval/demo_405b.sh # set to use 2 8x80GB H100 nodes
+```
+
+These scripts call `demos/llm_mmlu_eval/eval_mmlu.py`, which loops through mmlu.pkl and uses the last-token model logits to get the predictions.
+
+
+### VLLM Integration
+
+Coming Soon!
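For reference, the checkpoint handling that `demo_8b.sh` and `benchmark_8b.sh` rely on is small: the HuggingFace repos only store the learned feature maps and LoRA weights, which get merged into your locally loaded base model. Below is a minimal sketch of that logic, mirroring `load_hf_weights` in `demos/demo_lolcats_hf.py`; the `load_lolcats_weights` helper name and the exact error handling are ours for illustration, and `model` is assumed to be the already-converted LoLCATS model returned by the repo's loaders.

```python
# Condensed sketch (not the full demo script) of pulling the LoLCATS attention-transfer
# and LoRA weights from HuggingFace and inserting them into the locally loaded model.
import torch
from huggingface_hub import hf_hub_download

def load_lolcats_weights(model,
                         repo_ids=('hazyresearch/lolcats-llama-3.1-8b-distill',
                                   'hazyresearch/lolcats-llama-3.1-8b-ft-lora'),
                         filename='model.pt'):
    for repo_id in repo_ids:
        local_path = hf_hub_download(repo_id=repo_id, filename=filename)
        state_dict = torch.load(local_path, map_location='cpu')
        # Checkpoints may wrap the weights under 'model_state_dict'
        state_dict = state_dict.get('model_state_dict', state_dict)
        # Only ~0.2% of parameters (feature maps + LoRA) are stored, so load non-strictly;
        # the real script additionally retries with remapped key prefixes if keys don't match.
        keys = model.load_state_dict(state_dict, strict=False)
        assert not keys.unexpected_keys, f'Unexpected keys from {repo_id}: {keys.unexpected_keys[:5]}'
    return model
```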
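Likewise, the MMLU evaluation above reduces to a small scoring step per question. The following is a simplified sketch of how `demos/llm_mmlu_eval/eval_mmlu.py` turns last-token logits into predictions; `score_question` is an illustrative helper (not the script's exact interface), and it assumes HF-style model outputs with a `.logits` field. The real script batches the four choices through a DataLoader and runs under FSDP.

```python
# Each MMLU question is expanded into four prompts (one per answer choice A-D); the
# prediction is the choice whose answer token has the lowest cross-entropy under the
# model's final-position logits.
import torch
import torch.nn.functional as F

@torch.no_grad()
def score_question(model, input_ids: torch.Tensor, answer_token_ids: torch.Tensor) -> int:
    """input_ids: (4, seq_len) left-padded prompts; answer_token_ids: (4,) answer tokens."""
    logits = model(input_ids=input_ids, use_cache=False).logits[:, -1, :]           # (4, vocab)
    losses = F.cross_entropy(logits, answer_token_ids.view(-1), reduction='none')   # (4,)
    return int(losses.argmin())  # index of the predicted answer choice
```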
diff --git a/demos/benchmark_8b.sh b/demos/benchmark_8b.sh new file mode 100644 index 0000000..a62369d --- /dev/null +++ b/demos/benchmark_8b.sh @@ -0,0 +1,41 @@ + + +CONFIG_DIR='/home/bfs/simran/clean4/lolcats/configs/' # update to your path + +""" Benchmarking the 8b model on the LOLCATS dataset """ + +# Run the linearized model with the ThunderKittens kernel +CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \ + --model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ + --distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \ + --finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \ + --attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \ + --finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \ + --num_generations 1 \ + --use_cuda_kernels 1 \ + --benchmark \ + --max_new_tokens 1 + +# Run the linearized model *without* the ThunderKittens kernel +CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \ + --model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ + --distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \ + --finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \ + --attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \ + --finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \ + --num_generations 1 \ + --use_cuda_kernels 0 \ + --benchmark \ + --max_new_tokens 1 + +# Run the base Llama model with Transformers SDPA attention +CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \ + --model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ + --distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \ + --finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \ + --attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \ + --finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \ + --num_generations 1 \ + --use_attention \ + --benchmark \ + --max_new_tokens 1 diff --git a/demos/demo_8b.sh b/demos/demo_8b.sh new file mode 100644 index 0000000..a1e1b9a --- /dev/null +++ b/demos/demo_8b.sh @@ -0,0 +1,24 @@ + +CONFIG_DIR='/home/bfs/simran/clean4/lolcats/configs/' # update to your path + +# using huggingface checkpoints +CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \ + --model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ + --distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \ + --finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \ + --attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \ + --finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \ + --num_generations 1 + +# if you train your own LoLCATS weights, you can use the following command to run inference: +# CHECKPOINT_DIR='/home/mzhang/projects/lolcats/checkpoints/' +# CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \ +# --model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \ +# --distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \ +# 
--finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \ +# --attn_mlp_checkpoint_path ${CHECKPOINT_DIR}/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-ms=2500-se=0-re=100-lzi=1_distill.pt \ +# --finetune_checkpoint_path ${CHECKPOINT_DIR}/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-ms=2500-se=0-re=100-lzi=1-bs=1-gas=8-nte=2-ms=2500-se=0-re=100_ft.pt \ +# --num_generations 1 + + + diff --git a/demo_lolcats_llm.py b/demos/demo_lolcats_hf.py similarity index 64% rename from demo_lolcats_llm.py rename to demos/demo_lolcats_hf.py index c0900a9..abeca52 100644 --- a/demo_lolcats_llm.py +++ b/demos/demo_lolcats_hf.py @@ -11,34 +11,43 @@ from transformers import TextStreamer, TextIteratorStreamer, AutoTokenizer +import sys +sys.path.append("../") + from src.utils.setup import seed_everything from src.utils.logging import print_header from src.model.pretrained import get_pretrained_loader from src.model.load_model import load_and_convert_attns, load_and_convert_finetune +try: + from huggingface_hub import hf_hub_download +except ImportError: + print("Please pip install huggingface-hub") -system_prompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request. - -### Instruction: -{prompt} -### Response: -""" +system_prompt = """{prompt}""" def get_args(): parser = argparse.ArgumentParser() # Model load + setup + parser.add_argument("--model_config_path", type=str) + parser.add_argument("--finetune_config_path", type=str) + parser.add_argument("--distill_config_path", type=str) parser.add_argument("--attn_mlp_checkpoint_path", type=str, default=None) parser.add_argument("--finetune_checkpoint_path", type=str, default=None) parser.add_argument("--config_dir", type=str, default='configs') parser.add_argument("--seed", type=int, default=42) + # Inference + parser.add_argument("--use_cuda_kernels", type=int, default=0) + parser.add_argument("--use_attention", action='store_true', default=False) + # Generation parser.add_argument("--num_generations", type=int, default=1) - parser.add_argument("--top_k", type=int, default=50) - parser.add_argument("--top_p", type=float, default=0.95) - parser.add_argument("--max_new_tokens", type=int, default=1024) + parser.add_argument("--top_k", type=int, default=1.0) + parser.add_argument("--top_p", type=float, default=1.0) + parser.add_argument("--max_new_tokens", type=int, default=2) # Miscellaneous parser.add_argument("--benchmark", action='store_true', default=False) @@ -46,15 +55,7 @@ def get_args(): parser.add_argument("--debug", action='store_true', default=False) parser.add_argument("--huggingface_token", type=str, default=None) - # Alt - parser.add_argument("--attn_checkpoint_path", type=str, default=None) - parser.add_argument("--peft_checkpoint_path", type=str, default=None) - args = parser.parse_args() - if args.attn_mlp_checkpoint_path is None and args.attn_checkpoint_path is not None: - args.attn_mlp_checkpoint_path = args.attn_checkpoint_path - if args.finetune_checkpoint_path is None and args.peft_checkpoint_path is not None: - args.finetune_checkpoint_path = args.peft_checkpoint_path return args @@ -77,7 +78,6 @@ def get_lm_eval_lolcats_model(model_kwargs: dict, lolcats_model: bool = True): lm = 
get_model('hf-causal-experimental').create_from_arg_string( '', lm_kwargs, ) - # model = lm.model return lm @@ -123,9 +123,6 @@ def put(self, value): self.print_len[idx] += len(printable_text) else: printable_text = text[self.print_len[idx] : text.rfind(" ") + 1] - # printable_text = text[self.print_len[idx] : self.print_len[idx] + 1] - # if printable_text == '': - # printable_text = self.stop_signal self.print_len[idx] += len(printable_text) printable_texts.append(printable_text) @@ -141,7 +138,6 @@ def end(self): self.print_len[idx] = 0 else: printable_text = "" - # printable_text = self.stop_signal printable_texts.append(printable_text) self.next_tokens_are_prompt = True @@ -158,19 +154,12 @@ def on_finalized_text(self, texts: List[str], stream_end: bool = False): for x in self.text_queue.queue ]) for i in range(len(self.text_queue.queue[0])) ] - # text = '\n\n'.join(self.text_queue.queue[0]) text = '\n------------\n'.join(text) go_up = "\033[F" * self.go_up # len(text) # Goes up this many lines - # go_down = "\n" * self.go_up # len(text) # Goes up this many lines print(f'{text}', flush=True, end="" if not stream_end else None) - # print(f'{go_up}{text}', end="" if not stream_end else None) - # self.go_up = self.batch_size except Exception as e: print(self.stop_signal) - # print(text) - # print(e) - # return def count_params(module) -> int: return sum(p.numel() for p in module.parameters()) @@ -203,61 +192,108 @@ def setup_fsdp_config(config, args, checkpoint_name: str = 'finetune'): return config -def load_model_from_checkpoint(attn_mlp_checkpoint_path: str, - finetune_checkpoint_path: str, +def load_hf_weights(model, distill_repo_id, ft_repo_id, filename="model.pt"): + for repo_id in [distill_repo_id, ft_repo_id]: + if repo_id is None: continue + + print(f"Loading weights from {repo_id}") + + local_file_path = hf_hub_download(repo_id=repo_id, filename=filename) + state_dict = torch.load(local_file_path) + if 'model_state_dict' in state_dict: + state_dict = state_dict['model_state_dict'] + else: + pass + _keys = model.load_state_dict(state_dict, strict=False) + if len(_keys.unexpected_keys) > 0: + new_state_dict = {k.replace('model.', 'model.model.'): v for k, v in state_dict.items()} + _keys = model.load_state_dict(new_state_dict, strict=False) + if len(_keys.unexpected_keys) > 0: + new_state_dict = {k.replace('model.', 'base_model.model.model.'): v for k, v in state_dict.items()} + _keys = model.load_state_dict(new_state_dict, strict=False) + + try: + assert len(_keys.unexpected_keys) == 0 + print_header('*** All expected keys matched successfully ***') + except Exception as e: + print(e) + print_header('*** Error: unexpected keys in checkpoint - please fix ***') + print('Unexpected keys:') + for k in _keys.unexpected_keys: + print(k) + exit() + + return model + + +def load_model_from_checkpoint(attn_mlp_checkpoint_path: str = None, + finetune_checkpoint_path: str = None, + model_config_path: str = None, + distill_config_path: str = None, + finetune_config_path: str = None, config_dir: str = 'configs', print_model: bool = False, debug: bool = False, - huggingface_token: str = None): - rank = 0 - # Get configs from checkpoint paths - try: - model_config = attn_mlp_checkpoint_path.split('-m=')[-1].split('-f=')[0] - distill_config = attn_mlp_checkpoint_path.split('-d=')[-1].split('-m=')[0] - except Exception as e: - model_config = finetune_checkpoint_path.split('-m=')[-1].split('-f=')[0] - distill_config = None - - model_config = join(config_dir, 'model', f'{model_config}.yaml') - 
model_config = OmegaConf.load(model_config) - - if distill_config is not None: - distill_config = join(config_dir, 'experiment', f'{distill_config}.yaml') - distill_config = OmegaConf.load(distill_config) - else: - distill_config = {} + huggingface_token: str = None, + use_cuda_kernels: bool = False, + use_attention: bool = False): - finetune_config = finetune_checkpoint_path.split('-f=')[-1].split('-')[0] - finetune_config = join(config_dir, 'experiment', f'{finetune_config}.yaml') - finetune_config = OmegaConf.load(finetune_config) + is_local = attn_mlp_checkpoint_path.endswith(".pt") + + rank = 0 + model_config = OmegaConf.load(model_config_path) + distill_config = OmegaConf.load(distill_config_path) + finetune_config = OmegaConf.load(finetune_config_path) - # Load initial model + # Load initial transformer model model_loader = get_pretrained_loader(**model_config.model, huggingface_token=huggingface_token) tokenizer = model_loader.load_tokenizer() tokenizer.pad_token_id = tokenizer.eos_token_id tokenizer.padding_side = 'left' + if use_attention: + model = model_loader.load('softmax') + return model, model_config, tokenizer model = model_loader.load(model_config['attention']['attention_type']) - if debug: - print_header('Pretrained Model') - print(model) - - # Add subquadratic attentions + if use_cuda_kernels: + print('*** Using TK CUDA kernels **') + model_config['attention']['attention_type'] = 'lolcats_llama_window_tk_gen' + + # Swap the softmax to linear attention + if is_local: + checkpoint_path = attn_mlp_checkpoint_path + else: + checkpoint_path = None model, distill_peft_config = load_and_convert_attns(model, model_config, - attention_type=None, # in model_config - checkpoint_path=attn_mlp_checkpoint_path, + attention_type=None, + checkpoint_path=checkpoint_path, print_model=debug, merge_loras=False, peft_gradient_checkpointing=False, train_attention=False) # Add PEFT parameters + if is_local: + checkpoint_path = attn_mlp_checkpoint_path + else: + checkpoint_path = None model, ft_peft_config = load_and_convert_finetune(model, finetune_config, - checkpoint_path=finetune_checkpoint_path, + checkpoint_path=checkpoint_path, print_model=debug, merge_loras=False, peft_gradient_checkpointing=False) + + # Load from huggingface checkpoints and insert into the model dict + if not is_local: + model = load_hf_weights( + model, + attn_mlp_checkpoint_path, finetune_checkpoint_path, + filename="model.pt" + ) + if use_cuda_kernels: + print('*** Using TK CUDA kernels ***') + if print_model: print_header('*** Model after checkpoint load ***') print(model) @@ -265,51 +301,28 @@ def load_model_from_checkpoint(attn_mlp_checkpoint_path: str, return model, model_config, tokenizer -def get_model_name(attn_mlp_checkpoint_path: str, finetune_checkpoint_path: str, - model_config: str = None): - model_name = 'πŸ¦” ' if attn_mlp_checkpoint_path is not None else '' - if 'llama3_8b_' in finetune_checkpoint_path: - model_name += f'Llama-3-8B' - elif 'llama3_1_8b_' in finetune_checkpoint_path: - model_name += f'Llama-3.1-8B' - elif 'llama2_7b_' in finetune_checkpoint_path: - model_name += f'Llama-2-7B' - elif 'mistral_7b_' in finetune_checkpoint_path: - model_name += f'Mistral-7B' - - if attn_mlp_checkpoint_path is not None: - model_name += f'-Hedgehog' - - if 'alpaca_clean' in finetune_checkpoint_path: - model_name += f'-Alpaca' - - elif model_config is not None: - if 'llama3_8b_' in model_config: - model_name += f'Llama-3-8B' - elif 'llama2_7b_' in model_config: - model_name += f'Llama-2-7B' - elif 
'mistral_7b_' in model_config: - model_name += f'Mistral-7B' - - return model_name - - def main(): args = get_args() seed_everything(args.seed) model, model_config, tokenizer = load_model_from_checkpoint( args.attn_mlp_checkpoint_path, args.finetune_checkpoint_path, + args.model_config_path, args.distill_config_path, args.finetune_config_path, config_dir=args.config_dir, print_model = args.print_model, debug = args.debug, + use_cuda_kernels = args.use_cuda_kernels, + use_attention = args.use_attention, ) + model = model.to('cuda') model.eval() input_len = len(tokenizer(system_prompt)['input_ids']) - model_name = get_model_name(args.attn_mlp_checkpoint_path, - args.finetune_checkpoint_path, - model_config) while True: print(f'\n>> Generating {args.num_generations} responses in parallel') - prompt = input(f'>> Message {model_name} (or cmd-c to quit)... ') + if args.use_cuda_kernels or args.benchmark: + # 101424: need the prompt to be a multiple of 64 for our current kernel + prompt = "Create a summary of the following passage: London is the capital city of England and the United Kingdom. It is a leading global city with strengths in the arts, commerce, education, entertainment, fashion, finance, healthcare, media, professional services, research and development, tourism, and transport all contributing to its prominence. It is one of the most populous cities in the world, with an estimated population of 8.9 million in 2019. At its centre stand the imposing Houses of Parliament, the iconic β€˜Big Ben’ clock tower and Westminster Abbey, site of British monarch coronations. Across the Thames River, the London Eye observation wheel provides panoramic views of the South Bank cultural complex, and the entire city. London exerts a strong influence on world art, entertainment, fashion, commerce, finance, education, healthcare, media, science, technology, tourism, transport, and communications. London's cultures encompass over 300 languages. The 2023 population of Greater London is just under 10 million people. The Greater London Built-up Area is the fourth-most populous in current Europe. Until 1889, the name 'London' applied officially only to the City of London, but since then it has also referred to the County of London and to the Greater London area" + prompt = " ".join([prompt]*32)[:-5] #+ "area where we are all living. Now, let's see what the" + else: + prompt = input(f'>> Your prompt: (or cmd-c to quit)... 
') all_prompts = [system_prompt.format(prompt=prompt)] * args.num_generations @@ -323,18 +336,20 @@ def main(): with torch.no_grad(): model_input = tokenizer(all_prompts, return_tensors="pt").to(model.device) + print(model_input['input_ids'].shape) if args.benchmark: torch.cuda.synchronize() start_time = time.time() model_output = model.generate(**model_input, use_cache=True, max_new_tokens=args.max_new_tokens, - do_sample=True, + do_sample=False, top_k=args.top_k, top_p=args.top_p, num_return_sequences=1, pad_token_id=tokenizer.eos_token_id, streamer=streamer) + if args.benchmark: torch.cuda.synchronize() elapsed = time.time() - start_time @@ -345,6 +360,10 @@ def main(): print(f'β”œβ”€β”€ Total tokens processed + generated: {total_tokens}') print(f'β”œβ”€β”€ Throughput (lagged by last response): {total_tokens / elapsed:.3f} tokens/sec') + if 1: #args.use_cuda_kernels: + break if __name__ == '__main__': - main() \ No newline at end of file + main() + + diff --git a/demos/llm_mmlu_eval/demo_405b.sh b/demos/llm_mmlu_eval/demo_405b.sh new file mode 100644 index 0000000..d7783d6 --- /dev/null +++ b/demos/llm_mmlu_eval/demo_405b.sh @@ -0,0 +1,50 @@ +#!/bin/bash +#SBATCH --job-name=llama-405b +#SBATCH --partition=sixhour +#SBATCH --nodes=2 +#SBATCH --nodelist=mk-xii-05,mk-xii-06 # TODO: set to your nodenames +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=19 +#SBATCH --time=5:59:00 +#SBATCH --output=/home/simarora/utils/slurm_logs/slurm-%j.out # TODO: make your own directory +#SBATCH --error=/home/simarora/utils/slurm_logs/slurm-%j.err +#SBATCH --ntasks=2 # Add this line +#SBATCH --ntasks-per-node=1 # Add this line + +# Initialize HPC-X toolkit for high-performance computing +. /opt/hpcx/hpcx-init.sh +hpcx_load + +export NCCL_IGNORE_CPU_AFFINITY=1 # Ignore CPU affinity settings +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 # Enable asynchronous error handling for PyTorch NCCL +export CUDA_DEVICE_ORDER=PCI_BUS_ID # Set CUDA device order to PCI bus ID +export NCCL_IB_DISABLE=0 # Enable InfiniBand if available +export NCCL_NET_GDR_LEVEL=5 # Enable GPUDirect RDMA for faster GPU-to-GPU communication +export NCCL_P2P_DISABLE=0 # Enable peer-to-peer communication between GPUs +export NCCL_BUFFSIZE=2097152 # Set 2MB buffer size for NCCL operations +export NCCL_IB_HCA=mlx5 # Specify the InfiniBand Host Channel Adapter to use + +export MASTER_HOSTNAME="mk-xii-05" # # TODO change to your nodenames +export MASTER_ADDR=$(host $MASTER_HOSTNAME | awk '/has address/ { print $4 }') +export MASTER_PORT=29500 + +export PYTHONPATH=/home/simarora/code/lolcats/ # TODO change to your folder + +# Save the model outputs +srun torchrun --nnodes 2 --node_rank $SLURM_NODEID --rdzv_id $RANDOM --rdzv_backend c10d --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT --nproc_per_node 8 \ + demos/llm_mmlu_eval/eval_mmlu.py \ + --model_config llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01 \ + --distill_config llama3_1_405b/rp_distill_llama_405b_xent1_mse1000_lr1e-2 \ + --finetune_config llama3_1_405b/finetune_rp_llama_405b_qkvo_e2 \ + --verbose --replicate 0 --seed 0 --lk_zero_init \ + --eval_steps 100 --dataset_chunk_size 1024 \ + --enable_fsdp --low_cpu_fsdp --fsdp_activation_checkpointing \ + --tag hf_405b_mmlu \ + --finetune_checkpoint_path hazyresearch/lolcats-llama-3.1-405b + + +# Alternatively, you can run with your own locally trained paths by passing in the the checkpoint_path like follows: +# --finetune_checkpoint_path 
/home/simarora/code/lolcats/checkpoints/ckpt_lora-dl-d=rp_distill_llama_405b_xent1_mse1000_lr1e-2-m=distill_llama3_1_405b_lk_smd_wtk64_fd64_w01-f=rp_finetune_llama_40b_qv_hparams-s=0-se=0-re=0-ef=finetune_llama_405b_qkvo_e2_rp-ft_lora=0-se=0-re=0-s=3550.pt + + + diff --git a/demos/llm_mmlu_eval/demo_70b.sh b/demos/llm_mmlu_eval/demo_70b.sh new file mode 100644 index 0000000..43ddb51 --- /dev/null +++ b/demos/llm_mmlu_eval/demo_70b.sh @@ -0,0 +1,30 @@ +export PYTHONPATH=/home/simarora/code/lolcats/ + +# Use HF checkpoint paths (can also prob get away with 2 GPUs - longer contexts may not fit tho) +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nnodes 1 --nproc_per_nodes 8 \ + demos/llm_mmlu_eval/eval_mmlu.py \ + --model_config llama3_1_70b/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ + --distill_config llama3_1_70b/distill_rp_llama_70b_xent0_mse1000_lr1e-2 \ + --finetune_config llama3_1_70b/finetune_rp_llama_70b_qkvo \ + --eval_config eval_alpaca_clean \ + --verbose --replicate 0 --seed 0 \ + --lk_zero_init \ + --eval_steps 100 --dataset_chunk_size 1024 \ + --enable_fsdp --low_cpu_fsdp --fsdp_activation_checkpointing \ + --experiment_tag lolcats_hf_70b \ + --finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-70b' + +# Example using local paths, in case you train your own model. +# CUDA_VISIBLE_DEVICES=6,7 torchrun --nnodes 1 --nproc_per_node 2 \ +# demos/llm_mmlu_eval/eval_mmlu.py \ +# --model_config llama3_1_70b/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ +# --distill_config llama3_1_70b/distill_rp_llama_70b_xent0_mse1000_lr1e-2 \ +# --finetune_config llama3_1_70b/finetune_rp_llama_70b_qkvo \ +# --eval_config eval_alpaca_clean \ +# --verbose --replicate 0 --seed 0 \ +# --lk_zero_init \ +# --eval_steps 100 --dataset_chunk_size 1024 \ +# --enable_fsdp --low_cpu_fsdp --fsdp_activation_checkpointing \ +# --experiment_tag my_lolcats_70b \ +# --finetune_checkpoint_path ckpt_lora-dl-d=distill_rp_llama_70b_xent0_mse1000_lr1e-2-m=distill_llama3_1_70b_lk_smd_wtk64_fd64_w01-f=finetune_rp_llama_70b_qkvo-fac=1-se=0-re=0-se=0-re=0.pt + diff --git a/demos/llm_mmlu_eval/eval_mmlu.py b/demos/llm_mmlu_eval/eval_mmlu.py new file mode 100644 index 0000000..442ee7b --- /dev/null +++ b/demos/llm_mmlu_eval/eval_mmlu.py @@ -0,0 +1,580 @@ + +import os +from os.path import join +import dataclasses +import random +from tqdm import tqdm +import pickle + +from accelerate.utils import is_xpu_available +from torch.utils.data import DataLoader, Dataset +from collections import defaultdict +from omegaconf import OmegaConf + +import sys +sys.path.append('./llama_recipes/') +sys.path.append('./src') +sys.path.append('./../src') + +import torch +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload + +from llama_recipes.configs import fsdp_config as FSDP_CONFIG +from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy +from llama_recipes.utils.config_utils import update_config +from llama_recipes.trainer_finetune import ( + setup, + setup_environ_flags, + clear_gpu_cache, + print_model_size, + get_policies, +) +from llama_recipes.model_checkpointing.distill_checkpoint_handler import load_sharded_model_single_gpu +from llama_recipes.distill_llama import setup_wandb, setup_fsdp_config + +from src.utils.setup import update_config_from_args, update_model_config_from_args +from src.utils.logging import print_header, print_config +from src.finetune import prepare_finetune_configs +from src.model.pretrained 
import get_pretrained_loader +from src.model.load_model import load_and_convert_attns, load_and_convert_finetune +import argparse + +try: + from huggingface_hub import hf_hub_download +except ImportError: + print("Please pip install huggingface-hub") + + +def get_args(): + """Get attention transfer args""" + parser = argparse.ArgumentParser() + parser.add_argument("--project_name", type=str, default='lolcats') + parser.add_argument("--model_config", type=str, default=None) + parser.add_argument("--distill_config", type=str, default=None) + parser.add_argument("--finetune_config", type=str, default=None) + parser.add_argument("--eval_config", type=str, default=None) + + parser.add_argument("--layers_per_model", type=int, default=None) + parser.add_argument("--layers_limit", type=int, default=None) + parser.add_argument("--layers_min_limit", type=int, default=None) + + parser.add_argument("--pretrained_model_name_or_path", type=str, default=None) + parser.add_argument("--load_distill_checkpoint", type=str, default=None) + parser.add_argument("--load_finetune_checkpoint", type=str, default=None) + parser.add_argument("--finetune_checkpoint_path", type=str, default=None) + parser.add_argument("--resume_distill", action='store_true', default=None) + parser.add_argument("--resume_finetune", action='store_true', default=None) + + # Override default configs + # Feature map / model + parser.add_argument("--attention_type", type=str, default=None) + parser.add_argument("--learned_kernel", type=str, default=None) + parser.add_argument("--lk_skip_connection", action='store_true', default=None) + parser.add_argument("--lk_zero_init", action='store_true', default=None) + parser.add_argument("--tie_qk_kernels", action='store_true', default=None) + parser.add_argument("--train_qk", action='store_true', default=None) + parser.add_argument("--state_chunk_len", type=int, default=None) + + # Training + ## Distributed training / Llama recipes + parser.add_argument("--enable_fsdp", action='store_true', default=None) + parser.add_argument("--low_cpu_fsdp", action='store_true', default=None) + parser.add_argument("--pure_bf16", action='store_true', default=None) + parser.add_argument("--fsdp_activation_checkpointing", action='store_true', default=None) + parser.add_argument("--fsdp_cpu_offload", action='store_true', default=None) + + ## Hyperparameters + parser.add_argument("--lr", type=float, default=None) + parser.add_argument("--weight_decay", type=float, default=None) + parser.add_argument("--optim", type=str, default=None) + parser.add_argument("--scheduler", type=str, default=None) + parser.add_argument("--gradient_accumulation_steps", type=int, default=None) + parser.add_argument("--num_train_epochs", type=int, default=None) + parser.add_argument("--max_steps", type=int, default=None) + parser.add_argument("--no_peft_grad_ckpt", action='store_true', default=None) + + # Finetuning + parser.add_argument("--finetune_lr", type=float, default=None) + parser.add_argument("--finetune_attn_mlps", action='store_true', default=None) + + # Dataloading + parser.add_argument("--dataset_chunk_size", type=int, default=None) + parser.add_argument("--batch_size", type=int, default=None) + parser.add_argument("--num_workers", type=int, default=None) + + # Evaluation + parser.add_argument("--no_init_eval", action='store_true', default=False) + parser.add_argument("--eval_steps", type=int, default=None) + + # Experimental tag for saving mmlu preds + parser.add_argument("--experiment_tag", type=str, default=None) + + # 
Miscellaneous + parser.add_argument("--huggingface_token", type=str, default=None) + parser.add_argument("--checkpoint_dir", type=str, default='./checkpoints') + parser.add_argument("--replicate", type=int, default=0) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--verbose", action='store_true', default=None) + parser.add_argument("--no_cuda", action='store_true', default=None) + parser.add_argument("--no_wandb", action='store_true', default=None) + parser.add_argument("--wandb_entity", type=str, default='hazy-research') + parser.add_argument("--debug", action='store_true', default=None) + parser.add_argument("--num_train_steps", type=int, default=-1) + + # DEMO + ## Generation + parser.add_argument("--num_generations", type=int, default=1) + parser.add_argument("--top_k", type=int, default=50) + parser.add_argument("--top_p", type=float, default=0.95) + parser.add_argument("--max_new_tokens", type=int, default=1024) + + ## Miscellaneous + parser.add_argument("--benchmark", action='store_true', default=False) + parser.add_argument("--print_model", action='store_true', default=False) + + args = parser.parse_args() + + distill_name = args.distill_config + finetune_name = args.finetune_config + + args.run_name = f'dl-d={distill_name}-m={args.model_config}-f={finetune_name}' + if args.no_peft_grad_ckpt is not None: + args.run_name += f'-npgc={args.no_peft_grad_ckpt}' + if args.fsdp_activation_checkpointing is not None: + args.run_name += f'-fac={args.fsdp_activation_checkpointing}' + + if args.debug: + args.run_name += '-debug' + + args.run_name = args.run_name.replace('True', '1').replace('False', '0') # concise hacks + return args + + +class InputDataset(Dataset): + def __init__(self, data): + self.samples = data + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + x = self.samples[idx] + entry = x[-1] + input_tokens = torch.tensor(entry['token_ids']) + ans_tokens = torch.tensor(entry['answer_token_ids']) + text = entry['text'] + text, answer_choice = text + + if 'A' in answer_choice: + answer_choice = 0 + elif 'B' in answer_choice: + answer_choice = 1 + elif 'C' in answer_choice: + answer_choice = 2 + elif 'D' in answer_choice: + answer_choice = 3 + else: + raise ValueError(f"Invalid answer choice: {answer_choice}") + + gold_answer = x[2]['gold'] + task_name = x[1] + query = x[2]['query'] + doc_id = f"{task_name}-{query}" + + return { + 'input_ids': input_tokens, + 'target_ids': ans_tokens, + 'doc_id': doc_id, + 'answer_choice': x[0], + 'answer': gold_answer + } + + +def evaluate_lm(model, train_config, eval_dataloader, + local_rank, rank: int = 0, TAG=None): + + if TAG is None: + print(f"TAG is None, Will not save the predictions out.") + + if rank == 0: + for n, p in model.named_parameters(): + if ('layers.0.' in n and 'base_attn' not in n and + '.0.mlp.' 
not in n and '.block_sparse_moe' not in n): + print(f'-> {n}:\n', p) + + if train_config.enable_fsdp: + world_size = int(os.environ["WORLD_SIZE"]) + model.eval() + + log_probabilities = defaultdict(dict) + answers = {} + pbar = tqdm(eval_dataloader,colour="green", desc=f"Rank {rank}", dynamic_ncols=True) + correctness = {} + skipped = {} + criterion = torch.nn.CrossEntropyLoss(reduction='mean') + for step, batch in enumerate(pbar): + for key in batch.keys(): + if (type(batch[key]) == torch.Tensor): + if train_config.enable_fsdp: + batch[key] = batch[key].to(local_rank) + else: + if is_xpu_available(): + batch[key] = batch[key].to('xpu:0') + else: + batch[key] = batch[key].to('cuda:0') + + # Ensure no gradients are computed for this scope to save memory + with torch.no_grad(): + input_keys = {'input_ids'} + inputs = {k: v.to(model.device) for k, v in batch.items() if k in input_keys} + bs = inputs['input_ids'].shape[0] + seq_len = inputs['input_ids'].shape[1] + + # model call + outputs = model(**inputs, output_attentions=False, use_cache=False) + outputs = outputs.get('logits')[..., -1, :].contiguous() + target_ids = batch['target_ids'].to(model.device) + + # find the answer with the highest probability + losses = [] + for choice_idx in range(outputs.shape[0]): + output = outputs[choice_idx].unsqueeze(0) + target = target_ids[choice_idx].view(-1) + losses.append(criterion(output, target)) + losses = torch.stack(losses).cpu() # b, 1 + pred = torch.argmin(losses, dim=0) + answer = batch['answer'] + if type(pred) == torch.Tensor: # Flagging this logic. + pred = pred.item() + if rank == 0: + print(f"--> {step=}: {answer=}, {pred=}") + correct = (answer[0].cpu() == pred) + correctness[step] = correct + skipped[step] = False + + # free up memory + del outputs; del inputs; del target_ids; del losses; del pred; del correct + + if step % 100 == 0 and rank == 0: + total_correct = sum(correctness.values()) + print(f"--> at step {step}, Total Correct: {total_correct}/{len(correctness)}") + if TAG is not None: + with open(f"{TAG}_{step}_mmlu_predictions_{rank}.pkl", 'wb') as f: + pickle.dump(correctness, f) + with open(f"{TAG}_{step}_mmlu_skipped_{rank}.pkl", 'wb') as f: + pickle.dump(skipped, f) + + if rank == 0: + total_correct = sum(correctness.values()) + print(f"--> at step {step}, Total Correct: {total_correct}/{len(correctness)}") + if TAG is not None: + with open(f"{TAG}_{step}_mmlu_predictions_{rank}.pkl", 'wb') as f: + pickle.dump(correctness, f) + with open(f"{TAG}_{step}_mmlu_skipped_{rank}.pkl", 'wb') as f: + pickle.dump(skipped, f) + + del log_probabilities; del batch + clear_gpu_cache() + + total_correct = sum(correctness.values()) + mmlu_score = total_correct / len(correctness) + return mmlu_score + + +def load_hf_weights(model, distill_repo_id, ft_repo_id, filename="model.pt"): + for repo_id in [distill_repo_id, ft_repo_id]: + if repo_id is None: continue + print(f"Loading weights from {repo_id}") + local_file_path = hf_hub_download(repo_id=repo_id, filename=filename) + state_dict = torch.load(local_file_path) + if 'model_state_dict' in state_dict: + state_dict = state_dict['model_state_dict'] + else: + pass + _keys = model.load_state_dict(state_dict, strict=False) + if len(_keys.unexpected_keys) > 0: + new_state_dict = {k.replace('model.', 'model.model.'): v for k, v in state_dict.items()} + _keys = model.load_state_dict(new_state_dict, strict=False) + if len(_keys.unexpected_keys) > 0: + new_state_dict = {k.replace('model.', 'base_model.model.model.'): v for k, v in 
state_dict.items()} + _keys = model.load_state_dict(new_state_dict, strict=False) + try: + assert len(_keys.unexpected_keys) == 0 + print_header('*** All expected keys matched successfully ***') + except Exception as e: + print(e) + print_header('*** Error: unexpected keys in checkpoint - please fix ***') + print('Unexpected keys:') + for k in _keys.unexpected_keys: + print(k) + exit() + return model + + +def main(): + # --------- + # 1. SET UP + # --------- + args = get_args() + args.checkpoint_dir = join(args.checkpoint_dir, args.model_config) + if not os.path.isdir(args.checkpoint_dir): + os.makedirs(args.checkpoint_dir) + kwargs = vars(args) + + # Load distillation + attention configs + distill_config_path = join('./configs/experiment', f'{args.distill_config}.yaml') + distill_config = OmegaConf.load(distill_config_path) + distill_config = update_config_from_args(distill_config, args) + + model_config_path = join('./configs/model', f'{args.model_config}.yaml') + model_config = OmegaConf.load(model_config_path) + model_config = update_model_config_from_args(model_config, args) + if args.enable_fsdp: + if getattr(model_config.model, 'load_in_4bit', False): + model_config.model.device_map = 'auto' + elif getattr(model_config.model, 'load_in_8bit', False): + model_config.model.device_map = 'auto' + else: + model_config.model.device_map = None # FSDP will complain about device placement o.w. + + # Update dataset pretrained model config + for k in distill_config.dataset.pretrained_model_config: + print(f"{k=}") + distill_config.dataset.pretrained_model_config[k] = getattr(model_config.model, k) + + args.run_name = args.run_name.replace('True', '1').replace('False', '0') # concise hacks + + # Update the configuration for the training and sharding process + distill_config = setup_fsdp_config(distill_config, args, 'distill') # patch llama-recipes args + fsdp_config = FSDP_CONFIG() + update_config((fsdp_config), **vars(args)) + # Set the seeds for reproducibility + if is_xpu_available(): + torch.xpu.manual_seed(args.seed) + torch.manual_seed(args.seed) + random.seed(args.seed) + + if args.enable_fsdp: + setup() + # torchrun specific + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + if rank == 0: + print(f"{distill_config.dataset.pretrained_model_config=}") + print("*****"*10) + print(f"{model_config.model=}") + + if rank == 0 or not args.enable_fsdp: + print_header('Distillation Config') + print_config(distill_config) + print_header('Model Config') + print_config(model_config) + print_header('FSDP Config') + print_config(dataclasses.asdict(fsdp_config)) + + if torch.distributed.is_initialized(): + if is_xpu_available(): + torch.xpu.set_device(local_rank) + elif torch.cuda.is_available(): + torch.cuda.set_device(local_rank) + clear_gpu_cache(local_rank) + setup_environ_flags(rank) + + wandb_run = None + if not args.no_wandb: + if not args.enable_fsdp or rank==0: + wandb_run = setup_wandb(distill_config, fsdp_config, **kwargs) + + + finetune_checkpoint_path = args.finetune_checkpoint_path + if finetune_checkpoint_path == "None": + finetune_checkpoint_path = None + + # ------------------------ + # 2. 
LOAD PRETRAINED MODEL + # ------------------------ + # Load the pre-trained model and setup its configuration + # Initialize tokenizer and model loader + model_loader = get_pretrained_loader(**model_config.model) + tokenizer = model_loader.load_tokenizer() + tokenizer.pad_token_id = tokenizer.eos_token_id + tokenizer.padding_side = 'left' + + model_type = "softmax" + if 'lama' in model_config.model.pretrained_model_name_or_path: + from transformers.models.llama.modeling_llama import LlamaDecoderLayer as DecoderLayer + from src.model.modeling_llama import LolcatsLlamaForCausalLM as ModelClass + if finetune_checkpoint_path is None: + model_type = 'softmax' + else: + model_type = 'llama' + if rank == 0: + print(f"{model_type=}") + + # Convert model + if finetune_checkpoint_path is not None: + try: + args.attention_type = model_config['attention']['attention_type'] + except AttributeError: + args.attention_type = 'lolcats_llama' + else: + args.attention_type = "softmax" + + model = model_loader.load(args.attention_type) + model.state_chunk_len = model_config['attention']['state_chunk_len'] + model_config.model_name = model_config.model.pretrained_model_name_or_path + print_model_size(model, model_config, rank if args.enable_fsdp else 0) + if args.enable_fsdp and fsdp_config.pure_bf16: + model.to(torch.bfloat16) + + # ------------------------------- + # 3. CONVERT DISTILLED ATTENTIONS + # ------------------------------- + print(f"Before convert attns") + if finetune_checkpoint_path is not None: + model, distill_peft_config = load_and_convert_attns(model, model_config, + attention_type=args.attention_type, + checkpoint_path=None, + print_model=args.verbose, + merge_loras=False, + peft_gradient_checkpointing=not args.no_peft_grad_ckpt, + train_attention=False, + rank=rank) + else: + distill_peft_config = None + + if rank == 0: + print(model) + if rank == 0: + print_header('** Sanity check model weights **') + for n, p in model.named_parameters(): + if ('layers.0.' in n and ('feature_map' in n or 'lora' in n)): + print(f'-> {n}:\n', p) + + if wandb_run and distill_peft_config is not None: + wandb_run.config.update(distill_peft_config) + + # ---------------------------- + # 4. 
ADD FINETUNING PARAMETERS + # ---------------------------- + finetune_config, args = prepare_finetune_configs(args, model_config, args.finetune_config) + finetune_config = setup_fsdp_config(finetune_config, args, 'finetune') + if args.finetune_lr is not None: + finetune_config.model_name += f'=flr={args.finetune_lr}' + + if rank == 0: + print(f"{args.load_finetune_checkpoint=}") + if finetune_checkpoint_path is not None: + model, _ = load_and_convert_finetune(model, finetune_config, + # checkpoint_path=args.load_finetune_checkpoint, + print_model=args.verbose, + merge_loras=False, + peft_gradient_checkpointing=not args.no_peft_grad_ckpt, + rank=rank) + + # Load in our trained weights, assumes path has weights from both attention transfer and LoRA + if finetune_checkpoint_path is not None: + if '.pt' in args.finetune_checkpoint_path: + with torch.no_grad(): + _keys = model.load_state_dict(torch.load(args.finetune_checkpoint_path), strict=False) + if rank == 0: + print(f"Found {len(_keys.unexpected_keys)} unexpected keys.") + elif 'hazyresearch' in args.finetune_checkpoint_path: + print(f"Loading from huggingface.") + model = load_hf_weights(model, args.finetune_checkpoint_path, None, filename="model.pt") + else: + model = load_sharded_model_single_gpu(model, model_path=args.finetune_checkpoint_path, + cfg=finetune_config, rank=rank) + + if rank == 0: + print(f"{args.enable_fsdp=}") + if rank == 0 or not args.enable_fsdp: # debugging + print_header('** Sanity check model weights **') + for n, p in model.named_parameters(): + if ('layers.0.' in n and 'base_attn' not in n and + '.0.mlp.' not in n and '.block_sparse_moe' not in n): + print(f'-> {n}:\n', p) + + # ------------------------------------------------------ + # 5. SETUP FSDP AND LOAD DISTILLED ATTENTION CHECKPOINTS + # ------------------------------------------------------ + if args.enable_fsdp: + + mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank, model=model_type) + my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, DecoderLayer) + + device_id = 0 + if is_xpu_available(): + device_id = torch.xpu.current_device() + elif torch.cuda.is_available(): + device_id = torch.cuda.current_device() + print('-> device_id:', device_id) + print(f"Model") + if rank == 0: + print(model) + + model = FSDP( + model, + auto_wrap_policy=my_auto_wrapping_policy, + cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None, + mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None, + sharding_strategy=fsdp_config.sharding_strategy, + device_id=device_id, + limit_all_gathers=True, + sync_module_states=args.low_cpu_fsdp, + param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False) + if args.low_cpu_fsdp and rank != 0 else None, + ) + + # Load distilled checkpoints + if args.verbose and rank == 0: + print_header('*** FSDP Model ***') + print(model) + print('Loading checkpoints from:', distill_config.model_name) + + if rank == 0 or not args.enable_fsdp: # debugging + print_header('** Sanity check model weights **') + for n, p in model.named_parameters(): + if ('layers.0.' in n and 'base_attn' not in n and + '.0.mlp.' 
not in n and '.block_sparse_moe' not in n): + print(f'-> {n}:\n', p) + + if args.verbose and (rank == 0 or not args.enable_fsdp): + print_header('*** FSDP MODEL ***') + print(model) + print_header('*** Trainable Parameters ***') + for n, p in model.named_parameters(): + if p.requires_grad: + print(f'β”œβ”€β”€ {n} (dtype = {p.dtype})') + + if rank == 0: + print(f"Getting the MMLU dataset:") + if not os.path.exists("demos/llm_mmlu_eval/mmlu.pkl"): + print(f"Please make sure the paths are set corretly for mmlu.pkl") + with open("demos/llm_mmlu_eval/mmlu.pkl", 'rb') as f: + data = pickle.load(f) + dataset = InputDataset(data) + dataloader = DataLoader(dataset, shuffle=False, batch_size=4) + if not args.enable_fsdp or rank == 0: + print(f"--> Validation Set Length = {len(dataloader.dataset)}") + + if rank == 0: + print(f"Running evaluation:") + + TAG = args.experiment_tag + mmlu_score = evaluate_lm(model, finetune_config, + dataloader, + local_rank if args.enable_fsdp else None, + rank = rank if args.enable_fsdp else None, + TAG=TAG + ) + if rank == 0: + print(f"--> Final MMLU Score: {mmlu_score}") + +if __name__ == "__main__": + main() + diff --git a/demos/llm_mmlu_eval/mmlu.pkl.zip b/demos/llm_mmlu_eval/mmlu.pkl.zip new file mode 100644 index 0000000..c83c81a Binary files /dev/null and b/demos/llm_mmlu_eval/mmlu.pkl.zip differ diff --git a/demos/vllm_integration/README.md b/demos/vllm_integration/README.md new file mode 100644 index 0000000..56afa27 --- /dev/null +++ b/demos/vllm_integration/README.md @@ -0,0 +1,82 @@ +## Coming Soon! VLLM Integration + + +#### 1. Clone VLLM +Also run VLLM installations. +```bash +git clone https://github.com/vllm-project/vllm +``` + +#### 2. Copy the following LoLCATS specific files into vllm. + +``` +bash +cp lolcats/inference/vllm_files/lolcats.py vllm/model_executor/models/lolcats.py +``` + +And add the new LoLCATS models from: +```bash +lolcats/inference/vllm_files/__init__.py -> vllm/model_executor/models/__init__.py +``` + +#### 3. Set the model checkpoint paths. + +Given your local download of the 405B weights, go to the ```Meta-Llama-3.1-405B/config.py``` file and modify the architecture list from ```LlamaForCausalLM``` to ```LlamaLolcatsForCausalLM```. + +In ```vllm/model_executor/models/lolcats_inference_paged.py``` set the ```PATH=....pt``` to the name of your copy of the linearized weights (feature maps and LoRA). + +#### 4. Run VLLM. + +These instructions assume you have 2 nodes of $8 \times 80$GB to fit the FP16 405B model. You are okay with 1 node for 70B parameters. +```bash + +# Step 1. Follow the VLLM installation quick start to install it in your environment. + +# Step 2. Set up a 2 node ray cluster. On the respective nodes, run: +ray start --head # on node 1 +ray start --address='ip from above' # on node 2 + +# Step 3. Load the model on the 2 nodes, creating an OpenAI endpoint. Remember to hard code the ckpt paths in lolcats.py PATH (cant use env variable on multinode). Set tensor-parallel-size to 8 if using 1 node. Run this on the head node (node 1). +vllm serve /path/to/hf/model/Meta-Llama-3.1-405B --tensor-parallel-size 16 --enforce-eager # on node 1 +``` + +#### 5. Clone LM-Eval harness and run inference evaluations: +```bash +git clone https://github.com/EleutherAI/lm-evaluation-harness +git checkout b281b092 +pip install -e .[api] +``` + +Note that if ```datasets.load_datasets``` gives an issue, it helps to ```pip install -U datasets```. + +Launch the evaluation commands on node 1 (the head node of the ray cluster). 
+```bash +lm_eval --model local-completions --tasks piqa,hellaswag,winogrande,arc_challenge,arc_easy --model_args model='/path/to/hf/model/Meta-Llama-3.1-405B',base_url=http://localhost:8000/v1/completions,num_concurrent=1,max_retries=3,tokenized_requests=False --batch_size 1 --output save/ + +lm_eval --model local-completions --tasks mmlu --num_fewshot 5 --model_args model='/path/to/hf/model/Meta-Llama-3.1-405B',base_url=http://localhost:8000/v1/completions,num_concurrent=1,max_retries=3,tokenized_requests=False --batch_size 1 --output save/ +``` + +#### References +Please cite the following if you use their work: + +``` +@misc{eval-harness, + author = {Gao, Leo and Tow, Jonathan and Abbasi, Baber and Biderman, Stella and Black, Sid and DiPofi, Anthony and Foster, Charles and Golding, Laurence and Hsu, Jeffrey and Le Noac'h, Alain and Li, Haonan and McDonell, Kyle and Muennighoff, Niklas and Ociepa, Chris and Phang, Jason and Reynolds, Laria and Schoelkopf, Hailey and Skowron, Aviya and Sutawika, Lintang and Tang, Eric and Thite, Anish and Wang, Ben and Wang, Kevin and Zou, Andy}, + title = {A framework for few-shot language model evaluation}, + month = 07, + year = 2024, + publisher = {Zenodo}, + version = {v0.4.3}, + doi = {10.5281/zenodo.12608602}, + url = {https://zenodo.org/records/12608602} +} +``` + +``` +@inproceedings{kwon2023efficient, + title={Efficient Memory Management for Large Language Model Serving with PagedAttention}, + author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica}, + booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles}, + year={2023} +} +``` diff --git a/demos/vllm_integration/vllm_files/__init__.py b/demos/vllm_integration/vllm_files/__init__.py new file mode 100644 index 0000000..6953dab --- /dev/null +++ b/demos/vllm_integration/vllm_files/__init__.py @@ -0,0 +1,205 @@ +import functools +import importlib +from typing import Dict, List, Optional, Tuple, Type + +import torch.nn as nn + +from vllm.logger import init_logger +from vllm.utils import is_hip + +logger = init_logger(__name__) + +_GENERATION_MODELS = { + "AquilaModel": ("llama", "LlamaForCausalLM"), + "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 + "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b + "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b + "BloomForCausalLM": ("bloom", "BloomForCausalLM"), + "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), + "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), + "CohereForCausalLM": ("commandr", "CohereForCausalLM"), + "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"), + "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), + "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), + "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), + "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), + "FalconForCausalLM": ("falcon", "FalconForCausalLM"), + "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), + "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), + "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), + "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), + "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), + "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), + "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), + "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), + "JAISLMHeadModel": 
("jais", "JAISLMHeadModel"), + "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), + # For decapoda-research/llama-* + "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), + "MistralForCausalLM": ("llama", "LlamaForCausalLM"), + "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), + "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), + # transformers's mpt class has lower case + "MptForCausalLM": ("mpt", "MPTForCausalLM"), + "MPTForCausalLM": ("mpt", "MPTForCausalLM"), + "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), + "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), + "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), + "OPTForCausalLM": ("opt", "OPTForCausalLM"), + "OrionForCausalLM": ("orion", "OrionForCausalLM"), + "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), + "PhiForCausalLM": ("phi", "PhiForCausalLM"), + "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), + "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), + "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), + "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), + "RWForCausalLM": ("falcon", "FalconForCausalLM"), + "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), + "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), + "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), + "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), + "XverseForCausalLM": ("xverse", "XverseForCausalLM"), + "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), + "MedusaModel": ("medusa", "Medusa"), + "EAGLEModel": ("eagle", "EAGLE"), + "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), + "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), + "LlamaLolcatsForCausalLM": ("lolcats", "LlamaLolcatsForCausalLM") +} + +_EMBEDDING_MODELS = { + "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), +} + +_MULTIMODAL_MODELS = { + "Blip2ForConditionalGeneration": + ("blip2", "Blip2ForConditionalGeneration"), + "ChameleonForConditionalGeneration": + ("chameleon", "ChameleonForConditionalGeneration"), + "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), + "InternVLChatModel": ("internvl", "InternVLChatModel"), + "LlavaForConditionalGeneration": + ("llava", "LlavaForConditionalGeneration"), + "LlavaNextForConditionalGeneration": + ("llava_next", "LlavaNextForConditionalGeneration"), + "MiniCPMV": ("minicpmv", "MiniCPMV"), + "PaliGemmaForConditionalGeneration": ("paligemma", + "PaliGemmaForConditionalGeneration"), + "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), + "UltravoxModel": ("ultravox", "UltravoxModel"), + "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), +} +_CONDITIONAL_GENERATION_MODELS = { + "BartModel": ("bart", "BartForConditionalGeneration"), + "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), +} + +_MODELS = { + **_GENERATION_MODELS, + **_EMBEDDING_MODELS, + **_MULTIMODAL_MODELS, + **_CONDITIONAL_GENERATION_MODELS, +} + +# Architecture -> type. +# out of tree models +_OOT_MODELS: Dict[str, Type[nn.Module]] = {} + +# Models not supported by ROCm. +_ROCM_UNSUPPORTED_MODELS: List[str] = [] + +# Models partially supported by ROCm. +# Architecture -> Reason. +_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in " + "Triton flash attention. 
For half-precision SWA support, " + "please use CK flash attention by setting " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") +_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { + "Qwen2ForCausalLM": + _ROCM_SWA_REASON, + "MistralForCausalLM": + _ROCM_SWA_REASON, + "MixtralForCausalLM": + _ROCM_SWA_REASON, + "PaliGemmaForConditionalGeneration": + ("ROCm flash attention does not yet " + "fully support 32-bit precision on PaliGemma"), + "Phi3VForCausalLM": + ("ROCm Triton flash attention may run into compilation errors due to " + "excessive use of shared memory. If this happens, disable Triton FA " + "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") +} + + +class ModelRegistry: + + @staticmethod + @functools.lru_cache(maxsize=128) + def _get_model(model_arch: str): + module_name, model_cls_name = _MODELS[model_arch] + module = importlib.import_module( + f"vllm.model_executor.models.{module_name}") + return getattr(module, model_cls_name, None) + + @staticmethod + def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch in _OOT_MODELS: + return _OOT_MODELS[model_arch] + if model_arch not in _MODELS: + return None + if is_hip(): + if model_arch in _ROCM_UNSUPPORTED_MODELS: + raise ValueError( + f"Model architecture {model_arch} is not supported by " + "ROCm for now.") + if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: + logger.warning( + "Model architecture %s is partially supported by ROCm: %s", + model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) + + return ModelRegistry._get_model(model_arch) + + @staticmethod + def resolve_model_cls( + architectures: List[str]) -> Tuple[Type[nn.Module], str]: + for arch in architectures: + model_cls = ModelRegistry._try_load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) + + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + @staticmethod + def get_supported_archs() -> List[str]: + return list(_MODELS.keys()) + list(_OOT_MODELS.keys()) + + @staticmethod + def register_model(model_arch: str, model_cls: Type[nn.Module]): + if model_arch in _MODELS: + logger.warning( + "Model architecture %s is already registered, and will be " + "overwritten by the new model class %s.", model_arch, + model_cls.__name__) + global _OOT_MODELS + _OOT_MODELS[model_arch] = model_cls + + @staticmethod + def is_embedding_model(model_arch: str) -> bool: + return model_arch in _EMBEDDING_MODELS + + @staticmethod + def is_multimodal_model(model_arch: str) -> bool: + + # TODO: find a way to avoid initializing CUDA prematurely to + # use `supports_multimodal` to determine if a model is multimodal + # model_cls = ModelRegistry._try_load_model_cls(model_arch) + # from vllm.model_executor.models.interfaces import supports_multimodal + return model_arch in _MULTIMODAL_MODELS + + +__all__ = [ + "ModelRegistry", +] diff --git a/demos/vllm_integration/vllm_files/lolcats.py b/demos/vllm_integration/vllm_files/lolcats.py new file mode 100644 index 0000000..1a9612c --- /dev/null +++ b/demos/vllm_integration/vllm_files/lolcats.py @@ -0,0 +1,792 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. 
It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only LLaMA model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import math +import os +import torch +from collections import OrderedDict +from torch import nn +from torch.nn.parameter import Parameter +from transformers import LlamaConfig + +from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.distributed import (divide, + # split_tensor_along_last_dim, + # tensor_model_parallel_all_gather, + # tensor_model_parallel_all_reduce + ) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import is_hip + + + +from vllm.model_executor.models.interfaces import SupportsLoRA +from vllm.model_executor.models.utils import PPMissingLayer, is_pp_missing_parameter, make_layers +# from .interfaces import SupportsLoRA +# from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaAttention, LlamaMLP + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ReplicatedLinear + +logger = init_logger(__name__) + + +### OURS for Linear attention implementation +# from peft import get_peft_model, LoraConfig, TaskType + +# PEFT_KWARGS = { +# 'r': 8, +# 'lora_alpha': 16, # 32 +# 'lora_dropout': 0.05, +# 'target_modules': ["q_proj", "v_proj", "k_proj", "o_proj"] +# } + +### Hybrid Attention + + +from vllm.attention import Attention, AttentionMetadata + +class LlamaLoraAttention(LlamaAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + _device = self.qkv_proj.weight.device + _dtype = self.qkv_proj.weight.dtype + 
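+        # NOTE: this subclass keeps vLLM's standard softmax LlamaAttention forward pass.
+        # It only adds the helpers below for merging LoRA deltas into the fused QKV and
+        # output projections, and is used for any layers kept as softmax attention in
+        # hybrid models (see `softmax_attentions` in LlamaLolcatsForCausalLM below).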
print("Hello from Llama Lora Attention") + + def merge_lora_to_qkv_parallel(self, # param: Parameter, + loaded_delta: torch.Tensor, + loaded_shard_id: str = 'q', + total_num_heads: int = 32, + total_num_kv_heads: int = 4, + head_size: int = 128): + """ + Merge computed delta_AB into QKV parallel weights + + Based off of vLLM linear layer: https://github.com/vllm-project/vllm/blob/bc6e42a9b19364e07da9f279edd81796541d147d/vllm/model_executor/layers/linear.py#L762 + then Rahul, then Claude 3.5 Sonnet + + model.layers.1.self_attn.qkv_proj.weight torch.Size([1280, 8192]) + --> output_dim 0 + model.layers.1.self_attn.o_proj.weight torch.Size([8192, 1024]) + --> output_dim 0 + + apply this three times for q, k, and v LoRA deltas to the same layer + """ + + param = self.qkv_proj.weight + param_data = param.data + output_dim = getattr(param, "output_dim", None) + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + # num_heads = divide(total_num_heads, tp_size) + # if tp_size >= total_num_kv_heads: + # num_kv_heads = 1 + # # num_kv_head_replicas = divide(tp_size, total_num_kv_heads) + # else: + # num_kv_heads = divide(total_num_kv_heads, tp_size) + # # num_kv_head_replicas = 1 + num_heads = total_num_heads + num_kv_heads = total_num_kv_heads + + num_original_kv_heads = 8 # all Llama 3.1 models have 8 kv heads in total + + num_kv_head_replicas = tp_size // num_original_kv_heads + + if loaded_shard_id == "q": + shard_offset = 0 + shard_size = num_heads * head_size + elif loaded_shard_id == "k": + shard_offset = num_heads * head_size + shard_size = num_kv_heads * head_size + elif loaded_shard_id == "v": + shard_offset = (num_heads + num_kv_heads) * head_size + shard_size = num_kv_heads * head_size + + # print(f"{tp_rank=}, {tp_size=}") + if loaded_shard_id == "q": + start_idx = tp_rank * shard_size + else: + start_idx = (tp_rank // num_kv_head_replicas) * shard_size + + device = param_data.device + + param_data = param_data.narrow(output_dim, shard_offset, shard_size) + # print(f'{loaded_shard_id=}') + # print(f'{shard_offset=}, {shard_size=}, {shard_offset+shard_size=}') + # print(f'{output_dim=}, {start_idx=}, {param_data.shape=}') + # print('-' * 10) + + # self.qkv_proj.weight.data[shard_offset:shard_offset+shard_size, :] += ( + # loaded_delta.narrow(output_dim, start_idx, shard_size).to(device) + # ) + # print(f'Loading {loaded_shard_id} {start_idx}-{start_idx+shard_size} into {param_data.shape}, which is slice({shard_offset}, {shard_offset+shard_size})') + try: + param_data.copy_(param_data + loaded_delta.narrow(output_dim, start_idx, shard_size).to(device)) + # print(f"Loaded {loaded_shard_id} into {param_data.shape}") + except Exception as e: + print(f"Error: {e}") + print(f"{loaded_shard_id=}") + print(f"{output_dim=}") + print(f"{start_idx=}") + print(f"{shard_size=}") + print(f"{param_data.shape=}") + print(f"{loaded_delta.shape=}") + print(f"{tp_rank=}") + print(f"{tp_size=}") + + def merge_lora_to_o_parallel(self, + loaded_delta: torch.Tensor): + """ + Merge computed delta_AB into output projection (RowParallel linear) + """ + param = self.o_proj.weight + param_data = param.data + input_dim = getattr(param, "input_dim", None) + device = param_data.device + + # print('o_proj {input_dim=}') + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + if input_dim is not None: + shard_size = param_data.shape[input_dim] + start_idx = tp_rank * shard_size + loaded_delta = loaded_delta.narrow(input_dim, 
start_idx, shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_delta.shape) == 0: + loaded_delta = loaded_delta.reshape(1) + + # print('{param_data.shape=} | {loaded_delta.shape=}') + # assert param_data.shape == loaded_delta.shape + param_data.copy_(param_data + loaded_delta.to(device)) + + +### VLLM Llama Model + + +class FeatureMap(nn.Module): + """ + Learnable MLP in feature map. + + Full feature map is like f(xW + b) + -> This is the `W` and (optional) `b` part + """ + def __init__(self, + num_heads: int, + head_dim: int, + feature_dim: int, + dtype: torch.dtype, + device: torch.device, + eps: float = 1e-12, + **kwargs): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.feature_dim = feature_dim + self.dtype = dtype + self.device = device + self.eps = eps + self.init_weights_() + + def activation(self, x: torch.Tensor): + return torch.cat([ + torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1) + ], dim=-1).clamp(min=self.eps) + + def init_weights_(self): + self.layer = nn.Parameter(torch.zeros( + (self.num_heads, self.head_dim, self.feature_dim), + dtype=self.dtype, device=self.device, + )) + + def forward(self, x: torch.Tensor): + return self.activation( + torch.einsum('hdf,bhld->bhlf', self.layer, x)) + + +class LlamaLolcatsAttentionActual(nn.Module): + """Attention layer. + + This class takes query, key, and value tensors as input. The input tensors + can either contain prompt tokens or generation tokens. + The class does the following: + + 1. Store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention. + 3. Return the output tensor. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + num_kv_heads: int, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = num_heads // num_kv_heads + + max_seq_len = 2048 + window_size = 64 + + self.register_buffer('mask_window', self._create_mask(max_seq_len, window_size, True)) + self.register_buffer('mask_linear', self._create_mask(max_seq_len, window_size, False)) + + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + fmap_q: FeatureMap, + fmap_k: FeatureMap, + window_factors: torch.Tensor, + ) -> torch.Tensor: + # num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. 
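+        # Overall flow: reshape to (batch, heads, seq_len, head_dim), apply the learned
+        # feature maps to the queries/keys, then combine sliding-window softmax attention
+        # (inside the window) with linear attention (outside it) in superlinear_attention.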
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + if query.dim() == 3: + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + f_q = fmap_q(query) + f_k = fmap_k(key) + + window_size = 64 + window_factors = torch.nn.functional.sigmoid(window_factors) + linear_factors = 1 + # linear_factors = 1 - window_factors + + return self.superlinear_attention(query, key, f_q, f_k, + value, + window_factors, + linear_factors, + window_size) + + def _create_mask(self, max_seq_len: int, window_size: int, is_window: bool) -> torch.Tensor: + l = window_size + m = math.ceil(max_seq_len / window_size) + mask = torch.block_diag(*[torch.ones((l, l))] * m) + mask += torch.roll(mask, -l, -1) + mask = mask[:max_seq_len, :max_seq_len] + mask = mask[None, None, ...] # b, h, q_len, k_len + mask = torch.tril(mask if is_window else 1 - mask).to(dtype=torch.bool) + return mask + + def get_masks(self, window_size: int, q_len: int, k_len: int, device: torch.device) -> tuple[torch.Tensor]: + return self.mask_window[:, :, :q_len, :k_len], self.mask_linear[:, :, :q_len, :k_len] + + def superlinear_attention(self, q: torch.Tensor, k: torch.Tensor, + f_q: torch.Tensor, f_k: torch.Tensor, + v: torch.Tensor, + window_factor: torch.Tensor, + linear_factor: torch.Tensor, + window_size: int, + kv_state: torch.Tensor = None, + k_state: torch.Tensor = None, + eps: float = 1e-12, + mask_value: float=-1e8): + """ + Hybrid attention combining sliding window and linear attentions + """ + + mask_window, mask_linear = self.get_masks( + window_size, q.shape[-2], k.shape[-2], q.device) + + # 1. Sliding window (softmax attention) + # a_sm = torch.einsum( + # 'bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5) + a_sm = torch.einsum( + 'bhmd,bhnd->bhmn', q, k) * (k.shape[-1] ** -0.5) + a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value) + # torch.softmax(a_sm, dim=-1), but we account for the max when combining + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factor * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + # 2. Under window (linear attention) + # a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float()) + a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q, f_k) + a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0) + sum_ln = a_ln.sum(dim=-1, keepdim=True) + + # 3. 
Combine + # Allow outputs to also depend on prior kv_state and k_state + # y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v.float()) + # y = (y / (sum_sm + sum_ln)).to(q.dtype) + y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v) + y = (y / (sum_sm + sum_ln)) + # # logger.info(f"splattn {y.shape=}") + return y # attention weights only for the last chunk + + +class LlamaLolcatsAttention(LlamaAttention): + def __init__(self, *args, **kwargs): + + super().__init__(*args, **kwargs) + self.attn = LlamaLolcatsAttentionActual(self.num_heads, + self.head_dim, + self.num_kv_heads) + + _device = self.qkv_proj.weight.device + _dtype = self.qkv_proj.weight.dtype + + _feature_dim = 64 + + _feature_map_kwargs = { + "num_heads": self.num_heads, + "head_dim": self.head_dim, + "feature_dim": _feature_dim, + "dtype": _dtype, + "device": _device, + } + + self.feature_map_q = FeatureMap(**_feature_map_kwargs) + self.feature_map_k = FeatureMap(**_feature_map_kwargs) + self.window_factors = nn.Parameter( + torch.ones(1, self.num_heads, 1, 1, device=_device, dtype=_dtype)) + + def load_window_factors(self, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + if tp_size > 1: + + num_heads_per_rank = self.num_heads + start_idx = tp_rank * num_heads_per_rank + end_idx = start_idx + num_heads_per_rank + + sharded_weight = loaded_weight[:, start_idx:end_idx, :, :] + + else: + + sharded_weight = loaded_weight + + assert self.window_factors.shape == sharded_weight.shape, \ + f"Shape mismatch: {self.window_factors.shape} vs {sharded_weight.shape}" + + with torch.no_grad(): + self.window_factors.copy_(sharded_weight) + + def load_feature_map_q(self, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + # print(f"{tp_size}") + + if tp_size > 1: + + num_heads_per_rank = self.feature_map_q.num_heads + start_idx = tp_rank * num_heads_per_rank + end_idx = start_idx + num_heads_per_rank + + sharded_weight = loaded_weight[start_idx:end_idx, :, :] + + if sharded_weight.shape[-1] != self.feature_map_q.layer.shape[-1]: + sharded_weight = sharded_weight[:, :, :self.feature_map_q.layer.shape[-1]] + + else: + + sharded_weight = loaded_weight + + assert self.feature_map_q.layer.shape == sharded_weight.shape, \ + f"Shape mismatch: {self.feature_map_q.layer.shape} vs {sharded_weight.shape}" + + with torch.no_grad(): + self.feature_map_q.layer.copy_(sharded_weight) + + def load_feature_map_k(self, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + if tp_size > 1: + + num_heads_per_rank = self.feature_map_k.num_heads + start_idx = tp_rank * num_heads_per_rank + end_idx = start_idx + num_heads_per_rank + + sharded_weight = loaded_weight[start_idx:end_idx, :, :] + + if sharded_weight.shape[-1] != self.feature_map_k.layer.shape[-1]: + sharded_weight = sharded_weight[:, :, :self.feature_map_k.layer.shape[-1]] + + else: + + sharded_weight = loaded_weight + + assert self.feature_map_k.layer.shape == sharded_weight.shape, \ + f"Shape mismatch: {self.feature_map_k.layer.shape} vs {sharded_weight.shape}" + + with torch.no_grad(): + self.feature_map_k.layer.copy_(sharded_weight) + # self.feature_map_k.layer.normal_(std=1) + + def merge_lora_to_qkv_parallel(self, # param: Parameter, + loaded_delta: torch.Tensor, + loaded_shard_id: str = 'q', + total_num_heads: int = 32, + total_num_kv_heads: int = 4, + head_size: int = 
128): + """ + Merge computed delta_AB into QKV parallel weights + + Based off of vLLM linear layer: https://github.com/vllm-project/vllm/blob/bc6e42a9b19364e07da9f279edd81796541d147d/vllm/model_executor/layers/linear.py#L762 + then Rahul, then Claude 3.5 Sonnet + + model.layers.1.self_attn.qkv_proj.weight torch.Size([1280, 8192]) + --> output_dim 0 + model.layers.1.self_attn.o_proj.weight torch.Size([8192, 1024]) + --> output_dim 0 + + apply this three times for q, k, and v LoRA deltas to the same layer + """ + + param = self.qkv_proj.weight + param_data = param.data + output_dim = getattr(param, "output_dim", None) + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + # num_heads = divide(total_num_heads, tp_size) + # if tp_size >= total_num_kv_heads: + # num_kv_heads = 1 + # # num_kv_head_replicas = divide(tp_size, total_num_kv_heads) + # else: + # num_kv_heads = divide(total_num_kv_heads, tp_size) + # # num_kv_head_replicas = 1 + num_heads = total_num_heads + num_kv_heads = total_num_kv_heads + + num_original_kv_heads = 8 # all Llama 3.1 models have 8 kv heads in total + + num_kv_head_replicas = tp_size // num_original_kv_heads + + if loaded_shard_id == "q": + shard_offset = 0 + shard_size = num_heads * head_size + elif loaded_shard_id == "k": + shard_offset = num_heads * head_size + shard_size = num_kv_heads * head_size + elif loaded_shard_id == "v": + shard_offset = (num_heads + num_kv_heads) * head_size + shard_size = num_kv_heads * head_size + + # print(f"{tp_rank=}, {tp_size=}") + if loaded_shard_id == "q": + start_idx = tp_rank * shard_size + else: + start_idx = (tp_rank // num_kv_head_replicas) * shard_size + + device = param_data.device + + param_data = param_data.narrow(output_dim, shard_offset, shard_size) + # print(f'{loaded_shard_id=}') + # print(f'{shard_offset=}, {shard_size=}, {shard_offset+shard_size=}') + # print(f'{output_dim=}, {start_idx=}, {param_data.shape=}') + # print('-' * 10) + + # self.qkv_proj.weight.data[shard_offset:shard_offset+shard_size, :] += ( + # loaded_delta.narrow(output_dim, start_idx, shard_size).to(device) + # ) + # print(f'Loading {loaded_shard_id} {start_idx}-{start_idx+shard_size} into {param_data.shape}, which is slice({shard_offset}, {shard_offset+shard_size})') + try: + param_data.copy_(param_data + loaded_delta.narrow(output_dim, start_idx, shard_size).to(device)) + # print(f"Loaded {loaded_shard_id} into {param_data.shape}") + except Exception as e: + print(f"Error: {e}") + print(f"{loaded_shard_id=}") + print(f"{output_dim=}") + print(f"{start_idx=}") + print(f"{shard_size=}") + print(f"{param_data.shape=}") + print(f"{loaded_delta.shape=}") + print(f"{tp_rank=}") + print(f"{tp_size=}") + + def merge_lora_to_o_parallel(self, + loaded_delta: torch.Tensor): + """ + Merge computed delta_AB into output projection (RowParallel linear) + """ + param = self.o_proj.weight + param_data = param.data + input_dim = getattr(param, "input_dim", None) + device = param_data.device + + # print('o_proj {input_dim=}') + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + if input_dim is not None: + shard_size = param_data.shape[input_dim] + start_idx = tp_rank * shard_size + loaded_delta = loaded_delta.narrow(input_dim, start_idx, shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). 
+ if len(loaded_delta.shape) == 0: + loaded_delta = loaded_delta.reshape(1) + + # print('{param_data.shape=} | {loaded_delta.shape=}') + # assert param_data.shape == loaded_delta.shape + param_data.copy_(param_data + loaded_delta.to(device)) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + ndim = hidden_states.dim() + qkv, _ = self.qkv_proj(hidden_states) + seq_len = hidden_states.shape[-2] + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, + fmap_q=self.feature_map_q, + fmap_k=self.feature_map_k, + window_factors=self.window_factors) + attn_output = attn_output.transpose(1, 2).contiguous().view(-1, seq_len, self.num_heads * self.head_dim) + output, _ = self.o_proj(attn_output) + if output.dim() > ndim: + output = output.squeeze(0) + return output + + +class LlamaLolcatsForCausalLM(LlamaForCausalLM): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + print(f"LOLCATS!!!: Loading model with config: {self.config}") + + softmax_attentions = getattr(self.config, 'softmax_attentions', []) + + for i in range(len(self.model.layers)): + if i in softmax_attentions: + print(f"Using Lora Llama Attention at Layer {i}") + self.model.layers[i].self_attn = LlamaLoraAttention( + config=self.config, + hidden_size=self.config.hidden_size, + num_heads=self.config.num_attention_heads, + num_kv_heads=getattr(self.config, "num_key_value_heads", + self.config.num_attention_heads), + rope_theta=self.config.rope_theta, + rope_scaling=self.config.rope_scaling, + ) + else: + self.model.layers[i].self_attn = LlamaLolcatsAttention( + config=self.config, + hidden_size=self.config.hidden_size, + num_heads=self.config.num_attention_heads, + num_kv_heads=getattr(self.config, "num_key_value_heads", + self.config.num_attention_heads), + rope_theta=self.config.rope_theta, + rope_scaling=self.config.rope_scaling, + ) + + def get_device(self): + device = next(self.parameters()).device + return str(device) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + super().load_weights(weights) + + # model_size = 8 + # model_size = 70 + model_size = 405 + + # PATH = f'/data/rahul/checkpoints/{model_size}b.pt' + # PATH = f'/home/rahul/code/lolcats/ckpt_lora-dl-d=distill_redpajama_xent1_mse1000_lr1e-2-m=distill_llama3_1_70b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_redpajama-dcs=512-se=0-re=4-lzi=1-dcs=512-se=0-re=4.pt' + + # Trenchcoats v1 + # PATH = f'/home/rahul/code/lolcats/ckpt_lora-dl-d=distill_llama_405b_xent1_mse1000_lr1e-2-m=distill_llama3_1_405b_lk_smd_wtk64_fd64_w01-f=finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-ef=finetune_llama_405b_qkvo-ft_lora=0-se=0-re=0-ef=finetune_llama_405b_qkvo-ft_lora=0.pt' + + # No distill + # PATH = f'/home/rahul/code/lolcats/ckpt_lora-dl-d=no_distill_alpaca_clean-m=distill_llama3_1_405b_lk_smd_wtk64_fd64_w01-f=finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-ef=no_distill_finetune_405b-ft_lora=0-se=0-re=0-ef=no_distill_finetune_405b-ft_lora=0-no_distill.pt' + + # Hybrid (last cria attention) + PATH = f'/home/rahul/code/lolcats/ckpt_lora-dl-d=distill_llama_405b_xent1_mse1000_lr1e-2-m=distill_llama3_1_405b_lk_smd_wtk64_fd64_w01_h117-f=finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-ef=finetune_llama_405b_qkvo_cos-ft_lora=0-se=0-re=0-ef=finetune_llama_405b_qkvo_cos-ft_lora=0.pt' + + print(f"PATH: {PATH}!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") 
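+        # The hard-coded PATH above is only a fallback: LOLCATS_ADAPTER_PATH (read below
+        # via os.getenv) can point at your own linearized checkpoint (feature maps, window
+        # factors, and LoRA deltas). Note that on multi-node Ray setups the environment
+        # variable may not propagate to worker nodes, hence the hard-coded default.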
+ + adapter_weights_path = os.getenv("LOLCATS_ADAPTER_PATH", PATH) + + adapter_weights = torch.load(adapter_weights_path, weights_only=True) + + adapter_weights_copy = OrderedDict({}) + + for key, value in adapter_weights.items(): + key_suffix = key[key.rindex("model.")+6:] + adapter_weights_copy[key_suffix] = value + + adapter_weights = adapter_weights_copy + updated_keys = [] + + print("\n") + for layer_idx, layer in enumerate(self.model.layers): + # if layer_idx == 0: + # print(f'Weight factors before checkpoint load, {layer.self_attn.window_factors.shape}, {layer.self_attn.window_factors.flatten()}') + + window_factors_key = f'layers.{layer_idx}.self_attn.window_factors' + if window_factors_key in adapter_weights: + layer.self_attn.load_window_factors(adapter_weights[window_factors_key]) + updated_keys.append(window_factors_key) + + # if layer_idx == 0: + # print(f'Weight factors after checkpoint load, {layer.self_attn.window_factors.shape}, {layer.self_attn.window_factors.flatten()}') + if layer_idx == 0: + print("\n") + print(f'FMAP Q before checkpoint load, {layer.self_attn.feature_map_q.layer.shape}, {layer.self_attn.feature_map_q.layer[0,0,:4]}') + + fm_q_key = f'layers.{layer_idx}.self_attn.feature_map_q.mlp.layer' + if fm_q_key in adapter_weights: + layer.self_attn.load_feature_map_q(adapter_weights[fm_q_key]) + updated_keys.append(fm_q_key) + + if layer_idx == 0: + print(f'FMAP Q after checkpoint load; {layer.self_attn.feature_map_q.layer.shape},{layer.self_attn.feature_map_q.layer[0,0,:4]}') + + fm_k_key = f'layers.{layer_idx}.self_attn.feature_map_k.mlp.layer' + if fm_k_key in adapter_weights: + layer.self_attn.load_feature_map_k(adapter_weights[fm_k_key]) + updated_keys.append(fm_k_key) + + weight_name = 'layers.{layer_idx}.self_attn.{proj}.lora_{a_or_b}.default.weight' + target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"] + # target_modules = ["q_proj", "k_proj", "v_proj"] + # target_modules = ["k_proj", "v_proj"] + # target_modules = ["q_proj", "k_proj"] + + r = 8 + lora_alpha = 16 + lora_dropout = 0 + + for proj in target_modules: + lora_A_key = weight_name.format(layer_idx=layer_idx, proj=proj, a_or_b='A') + lora_B_key = weight_name.format(layer_idx=layer_idx, proj=proj, a_or_b='B') + if lora_A_key in adapter_weights: + weight_A = adapter_weights[lora_A_key] + weight_B = adapter_weights[lora_B_key] + delta_AB = get_delta_weight(weight_A, weight_B, r=r, lora_alpha=lora_alpha, + lora_dropout=lora_dropout) + + # if layer_idx == 0: + # print(f'layer {layer_idx} weight_A.shape: {weight_A.shape} | weight_B.shape: {weight_B.shape} | delta_AB.shape: {delta_AB.shape}') + # print(f'layer {layer_idx} proj {proj} delta_AB', delta_AB.shape) + + if proj == 'o_proj': + if layer_idx == 0: + print("\n") + print(f'Layer {layer_idx} {proj} weight before checkpoint load, {layer.self_attn.o_proj.weight.shape}, {layer.self_attn.o_proj.weight[0,:4]}') + + layer.self_attn.merge_lora_to_o_parallel(delta_AB) + + if layer_idx == 0: + print(f'Layer {layer_idx} {proj} weight after checkpoint load, {layer.self_attn.o_proj.weight.shape}, {layer.self_attn.o_proj.weight[0,:4]}') + else: + # if layer_idx == 0 and proj in ['q_proj']: + # print(f'Layer {layer_idx} {proj} weight before checkpoint load, {layer.self_attn.qkv_proj.weight.shape}, {layer.self_attn.qkv_proj.weight[0,:4]}') + + layer.self_attn.merge_lora_to_qkv_parallel(delta_AB, + loaded_shard_id=proj.split('_')[0], + total_num_heads=layer.self_attn.num_heads, + total_num_kv_heads=layer.self_attn.num_kv_heads, + 
head_size=layer.self_attn.head_dim) + # if layer_idx == 0 and proj in ['q_proj']: + # print(f'Layer {layer_idx} {proj} weight after checkpoint load, {layer.self_attn.qkv_proj.weight.shape}, {layer.self_attn.qkv_proj.weight[0,:4]}') + updated_keys.append(lora_A_key) + updated_keys.append(lora_B_key) + + assert len(set(adapter_weights_copy.keys()) - set(updated_keys)) == 0, \ + f"UNUPDATED KEYS: {set(adapter_weights_copy.keys()) - set(updated_keys)}" + + +def transpose(weight, fan_in_fan_out): + if not fan_in_fan_out: + return weight + + if isinstance(weight, torch.nn.Parameter): + return torch.nn.Parameter(weight.T) + return weight.T + + +def get_delta_weight(weight_A: torch.Tensor, weight_B: torch.Tensor, + r: int = 8, lora_alpha: float = 16, lora_dropout: float = 0, + fan_in_fan_out: bool = False,): + + device = weight_B.device + dtype = weight_B.dtype + # From https://github.com/huggingface/peft/blob/850eeb5c3a5cf692f5612c7c733b13fde184e05d/src/peft/tuners/lora/layer.py#L512 + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + if cast_to_fp32: + weight_A = weight_A.float() + weight_B = weight_B.float() + scaling = lora_alpha / r + output_tensor = transpose(weight_B @ weight_A, fan_in_fan_out) * scaling + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + return output_tensor diff --git a/demos/vllm_integration/vllm_files/lolcats_inference.py b/demos/vllm_integration/vllm_files/lolcats_inference.py new file mode 100644 index 0000000..9865052 --- /dev/null +++ b/demos/vllm_integration/vllm_files/lolcats_inference.py @@ -0,0 +1,870 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
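+# LoLCATS inference variant: prefill uses the hybrid sliding-window + linear attention
+# (superlinear_attention), while single-token decoding takes a recurrent path that keeps
+# running (kv_state, k_state) linear-attention statistics plus a small sliding-window
+# KV cache (see LoLCacheParams, _update_kv_cache, and recurrent_attention below).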
+"""Inference-only LLaMA model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import math +import os +import torch +import time +from collections import OrderedDict +from torch import nn +from torch.nn.parameter import Parameter +from transformers import LlamaConfig + +from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.distributed import (divide, + # split_tensor_along_last_dim, + # tensor_model_parallel_all_gather, + # tensor_model_parallel_all_reduce + ) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import is_hip + + + +from vllm.model_executor.models.interfaces import SupportsLoRA +from vllm.model_executor.models.utils import PPMissingLayer, is_pp_missing_parameter, make_layers +# from .interfaces import SupportsLoRA +# from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaAttention, LlamaMLP + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ReplicatedLinear + +logger = init_logger(__name__) + + +### OURS for Linear attention implementation +# from peft import get_peft_model, LoraConfig, TaskType + +# PEFT_KWARGS = { +# 'r': 8, +# 'lora_alpha': 16, # 32 +# 'lora_dropout': 0.05, +# 'target_modules': ["q_proj", "v_proj", "k_proj", "o_proj"] +# } + +### Hybrid Attention + + +from vllm.attention import Attention, AttentionMetadata + + +### VLLM Llama Model + + +class FeatureMap(nn.Module): + """ + Learnable MLP in feature map. 
+ + Full feature map is like f(xW + b) + -> This is the `W` and (optional) `b` part + """ + def __init__(self, + num_heads: int, + head_dim: int, + feature_dim: int, + dtype: torch.dtype, + device: torch.device, + eps: float = 1e-12, + **kwargs): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.feature_dim = feature_dim + self.dtype = dtype + self.device = device + self.eps = eps + self.init_weights_() + + def activation(self, x: torch.Tensor): + return torch.cat([ + torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1) + ], dim=-1).clamp(min=self.eps) + + def init_weights_(self): + self.layer = nn.Parameter(torch.zeros( + (self.num_heads, self.head_dim, self.feature_dim), + dtype=self.dtype, device=self.device, + )) + + def forward(self, x: torch.Tensor): + return self.activation( + torch.einsum('hdf,bhld->bhlf', self.layer, x.to(self.dtype))) + + +from dataclasses import dataclass +@dataclass +class LoLCacheParams: + is_prompt: bool = False + kv_state: torch.Tensor = torch.Tensor() + k_state: torch.Tensor = torch.Tensor() + kv_cache: torch.Tensor = torch.Tensor() + + +class LlamaLolcatsAttentionActual(nn.Module): + """Attention layer. + + This class takes query, key, and value tensors as input. The input tensors + can either contain prompt tokens or generation tokens. + The class does the following: + + 1. Store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention. + 3. Return the output tensor. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + num_kv_heads: int, + layer_idx: int, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = num_heads // num_kv_heads + + max_seq_len = 2048 + window_size = 64 + self.window_size = window_size + + self.register_buffer('mask_window', self._create_mask(max_seq_len, window_size, True)) + self.register_buffer('mask_linear', self._create_mask(max_seq_len, window_size, False)) + + # SA: inference cache + self.lolcats_cache = None + tp_rank = get_tensor_model_parallel_rank() + + print(f"{layer_idx=}") + self.layer_idx = layer_idx + self.tp_rank = tp_rank + + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + fmap_q: FeatureMap, + fmap_k: FeatureMap, + window_factors: torch.Tensor, + state=None + ) -> torch.Tensor: + + if self.lolcats_cache is None: + self._prepare_lolcats_cache() + + # num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. 
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + if query.dim() == 3: + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + f_q = fmap_q(query) + f_k = fmap_k(key) + + window_size = 64 + window_factors = torch.nn.functional.sigmoid(window_factors) + linear_factors = 1 + + # if self.tp_rank == 0 and self.layer_idx == 0: print(f"{query.shape=}") + seqlen = query.shape[2] + # if self.tp_rank == 0 and self.layer_idx == 0: print(f"{seqlen=}") + if seqlen == 1: + return self.recurrent_attention( + query, key, f_q, f_k, + value, window_factors, + linear_factors, + window_size, + fmap_q, fmap_k + ) + else: + return self.superlinear_attention( + query, key, f_q, f_k, + value, + window_factors, linear_factors, + window_size + ) + + + def _create_mask(self, max_seq_len: int, window_size: int, is_window: bool) -> torch.Tensor: + l = window_size + m = math.ceil(max_seq_len / window_size) + mask = torch.block_diag(*[torch.ones((l, l))] * m) + mask += torch.roll(mask, -l, -1) + mask = mask[:max_seq_len, :max_seq_len] + mask = mask[None, None, ...] # b, h, q_len, k_len + mask = torch.tril(mask if is_window else 1 - mask).to(dtype=torch.bool) + return mask + + + def get_masks(self, window_size: int, q_len: int, k_len: int, device: torch.device) -> tuple[torch.Tensor]: + return self.mask_window[:, :, :q_len, :k_len], self.mask_linear[:, :, :q_len, :k_len] + + + def _prepare_lolcats_cache(self): + if self.tp_rank == 0 and self.layer_idx == 0: print("heyo 5 -- hello prepare kv cache") + dtype = torch.bfloat16 + bs = 1 + self.lolcats_cache = LoLCacheParams() + if self.tp_rank == 0 and self.layer_idx == 0: print("heyo 6 -- bye prepare kv cache") + + + def _init_kv_cache(self, keys, values, f_k): + if self.tp_rank == 0 and self.layer_idx == 0: print("heyo 3 -- hello init kv cache") + dtype = keys.dtype + if self.tp_rank == 0 and self.layer_idx == 0: print(f"{f_k.shape=}") + + # decoding KV state (KV terms up to last window_size) + decode_kv_state = torch.einsum('bhlf,bhld->bhfd', + f_k[:, :, :-self.window_size], + values[:, :, :-self.window_size] + ) + + if self.tp_rank == 0 and self.layer_idx == 0: + print(decode_kv_state[0][0]) + if self.tp_rank == 0 and self.layer_idx == 0: print(f"{decode_kv_state.shape=}") + if self.tp_rank == 0 and self.layer_idx == 0: print(f"{f_k.shape=}") + # shape is b, h, 1, f; note the 1 + decode_k_state = f_k[:, :, :-self.window_size].sum(dim=-2,keepdim=True) + self.lolcats_cache.kv_state = decode_kv_state + self.lolcats_cache.k_state = decode_k_state + if self.tp_rank == 0 and self.layer_idx == 0: print(f"{decode_k_state.shape=}") + + # update the cache + kv_cache = torch.stack([ + keys[:, :, -self.window_size:, :].float(), + values[:, :, -self.window_size:, :].float() + ], dim=1) + if self.tp_rank == 0 and self.layer_idx == 0: print(f"{kv_cache.shape=}") + self.lolcats_cache.kv_cache = kv_cache + if self.tp_rank == 0 and self.layer_idx == 0: print("heyo 4 -- bye init kv cache") + + + def superlinear_attention( + self, q: torch.Tensor, k: torch.Tensor, + f_q: torch.Tensor, f_k: torch.Tensor, + v: torch.Tensor, + window_factor: 
torch.Tensor, linear_factor: torch.Tensor, + window_size: int, + kv_state: torch.Tensor = None, k_state: torch.Tensor = None, + eps: float = 1e-12, + mask_value: float=-1e8 + ): + """ + Hybrid attention combining sliding window and linear attentions + """ + + mask_window, mask_linear = self.get_masks( + window_size, q.shape[-2], k.shape[-2], q.device) + + # 1. Sliding window (softmax attention) + # a_sm = torch.einsum( + # 'bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5) + a_sm = torch.einsum( + 'bhmd,bhnd->bhmn', q, k) * (k.shape[-1] ** -0.5) + a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value) + # torch.softmax(a_sm, dim=-1), but we account for the max when combining + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factor * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + # 2. Under window (linear attention) + # a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float()) + a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q, f_k) + a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0) + sum_ln = a_ln.sum(dim=-1, keepdim=True) + + # 3. Combine + # Allow outputs to also depend on prior kv_state and k_state + # y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v.float()) + # y = (y / (sum_sm + sum_ln)).to(q.dtype) + y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v) + y = (y / (sum_sm + sum_ln)) + # # logger.info(f"splattn {y.shape=}") + + self._init_kv_cache(k, v, f_k) + return y # attention weights only for the last chunk + + + def _update_kv_cache(self, keys, values, fmap_k): + # if self.tp_rank == 0 and self.layer_idx == 0: print("heyo 1 - hello update kv cache") + # get state from before + kv_state = self.lolcats_cache.kv_state + k_state = self.lolcats_cache.k_state + kv_cache_swa = self.lolcats_cache.kv_cache + k_cache = kv_cache_swa[:, 0] + v_cache = kv_cache_swa[:, 1] + + dtype = kv_state.dtype + + # update the linear attention states + # since we ignore the diag blocks, just grab last tokens of kv cache + cur_seq_len = k_cache.shape[-2] + if self.tp_rank == 0 and self.layer_idx == 0: print(f"{cur_seq_len=}") + if cur_seq_len >= self.window_size: + if self.tp_rank == 0 and self.layer_idx == 0: print(f"Updating the kv_state and k_state...") + # if self.tp_rank == 0 and self.layer_idx == 0: + # print(f"{fmap_k.layer=}") + # print(f"{k_cache[0, 0, 0, 0:8]=}") + # print(f"{k_cache[:, :, :1, :]=}") + # print(f"{fmap_k(k_cache[:, :, :1, :])=}") + k_state = fmap_k(k_cache[:, :, :1, :]) + v_state = v_cache[:, :, :1, :] + kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d + self.lolcats_cache.kv_state += kv_state.to(kv_state.dtype) + self.lolcats_cache.k_state += k_state + + # update swa states + if cur_seq_len < self.window_size: + # only add to cache + k_cache = torch.cat([k_cache, keys], dim=-2) + v_cache = torch.cat([v_cache, values], dim=-2) + else: + # remove oldest key and value and append + k_cache = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2) + v_cache = torch.cat([v_cache[:, :, 1:, :], values], dim=-2) + kv_cache_swa = torch.stack([k_cache, v_cache], dim=1) + self.lolcats_cache.kv_cache = kv_cache_swa + + # if self.tp_rank == 0 and self.layer_idx == 0: print("heyo 2 - bye update kv cache") + return self.lolcats_cache.kv_state, self.lolcats_cache.k_state, k_cache, v_cache + + + def recurrent_attention( + self, q: torch.Tensor, k: torch.Tensor, + f_q: torch.Tensor, f_k: torch.Tensor, + v: torch.Tensor, + window_factor: torch.Tensor, + linear_factor: 
torch.Tensor, + window_size: int, + fmap_q, fmap_k, + kv_state: torch.Tensor = None, + k_state: torch.Tensor = None, + eps: float = 1e-12, mask_value: float=-1e8 + ): + dtype = torch.float32 + # if self.tp_rank == 0 and self.layer_idx == 0: print(f"hello recurrent step!") + kv_state, k_state, k_cache, v_cache = self._update_kv_cache(k, v, fmap_k) + + # Softmax attention terms + a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5) + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factor * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + y_sm = torch.einsum('bhmn,bhnd->bhmd', a_sm.float(), v_cache.float()) + + # Combine with linear attention terms + f_q = fmap_q(q) + y_ln = linear_factor * torch.einsum('bhlf,bhfd->bhld', f_q.float(), kv_state.float()) + sum_ln = linear_factor * torch.einsum('bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None] + + # if self.tp_rank == 0 and self.layer_idx == 0: + # print(f"{y_ln[0][0][0][:4]=}") + # print(f"{sum_ln[0][0]=}") + + y = y_sm + y_ln + attn_output = (y / (sum_sm + sum_ln)).to(q.dtype) + + # if self.tp_rank == 0 and self.layer_idx == 0: print(f"bye recurrent step!") + return attn_output + + +class LlamaLolcatsAttention(LlamaAttention): + def __init__(self, layer_idx, *args, **kwargs): + + super().__init__(*args, **kwargs) + print(f"{layer_idx=}") + self.attn = LlamaLolcatsAttentionActual(self.num_heads, + self.head_dim, + self.num_kv_heads, + layer_idx) + self.head_size = self.head_dim + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + _device = self.qkv_proj.weight.device + _dtype = self.qkv_proj.weight.dtype + + _feature_dim = 64 + _feature_map_kwargs = { + "num_heads": self.num_heads, + "head_dim": self.head_dim, + "feature_dim": _feature_dim, + "dtype": _dtype, + "device": _device, + } + self.feature_dim = _feature_dim + self.window_size = 64 + + tp_rank = get_tensor_model_parallel_rank() + self.tp_rank = tp_rank + self.layer_idx = layer_idx + + self.feature_map_q = FeatureMap(**_feature_map_kwargs) + self.feature_map_k = FeatureMap(**_feature_map_kwargs) + self.window_factors = nn.Parameter( + torch.ones(1, self.num_heads, 1, 1, device=_device, dtype=_dtype)) + + def load_window_factors(self, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + # print(f"{tp_size=}") + # assert 0, "ahhhh window factors" + + if tp_size > 1: + + num_heads_per_rank = self.num_heads + start_idx = tp_rank * num_heads_per_rank + end_idx = start_idx + num_heads_per_rank + + if self.layer_idx == 0 and tp_rank == 0: + print(loaded_weight) + + if self.layer_idx < 2: + print(f"{num_heads_per_rank=}") + print(f"{tp_rank=}; {loaded_weight.shape=}; {start_idx=}; {end_idx=}") + + sharded_weight = loaded_weight[:, start_idx:end_idx, :, :] + + else: + + sharded_weight = loaded_weight + + assert self.window_factors.shape == sharded_weight.shape, \ + f"Shape mismatch: {self.window_factors.shape} vs {sharded_weight.shape}" + + with torch.no_grad(): + self.window_factors.copy_(sharded_weight) + + def load_feature_map_q(self, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + # print(f"{tp_size=}") + # assert 0, "ahhhh feature map q" + + if tp_size > 1: + + num_heads_per_rank = self.feature_map_q.num_heads + start_idx = tp_rank * num_heads_per_rank + end_idx = start_idx + num_heads_per_rank + + sharded_weight = 
loaded_weight[start_idx:end_idx, :, :] + + if sharded_weight.shape[-1] != self.feature_map_q.layer.shape[-1]: + sharded_weight = sharded_weight[:, :, :self.feature_map_q.layer.shape[-1]] + + else: + + sharded_weight = loaded_weight + + assert self.feature_map_q.layer.shape == sharded_weight.shape, \ + f"Shape mismatch: {self.feature_map_q.layer.shape} vs {sharded_weight.shape}" + + with torch.no_grad(): + self.feature_map_q.layer.copy_(sharded_weight) + + def load_feature_map_k(self, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + # print(f"{tp_size=}") + # assert 0, "ahhhh" + + if tp_size > 1: + + num_heads_per_rank = self.feature_map_k.num_heads + start_idx = tp_rank * num_heads_per_rank + end_idx = start_idx + num_heads_per_rank + + sharded_weight = loaded_weight[start_idx:end_idx, :, :] + + if sharded_weight.shape[-1] != self.feature_map_k.layer.shape[-1]: + sharded_weight = sharded_weight[:, :, :self.feature_map_k.layer.shape[-1]] + + else: + + sharded_weight = loaded_weight + + assert self.feature_map_k.layer.shape == sharded_weight.shape, \ + f"Shape mismatch: {self.feature_map_k.layer.shape} vs {sharded_weight.shape}" + + with torch.no_grad(): + self.feature_map_k.layer.copy_(sharded_weight) + # self.feature_map_k.layer.normal_(std=1) + + def merge_lora_to_qkv_parallel(self, # param: Parameter, + loaded_delta: torch.Tensor, + loaded_shard_id: str = 'q', + total_num_heads: int = 32, + total_num_kv_heads: int = 4, + head_size: int = 128): + """ + Merge computed delta_AB into QKV parallel weights + + Based off of vLLM linear layer: https://github.com/vllm-project/vllm/blob/bc6e42a9b19364e07da9f279edd81796541d147d/vllm/model_executor/layers/linear.py#L762 + then Rahul, then Claude 3.5 Sonnet + + model.layers.1.self_attn.qkv_proj.weight torch.Size([1280, 8192]) + --> output_dim 0 + model.layers.1.self_attn.o_proj.weight torch.Size([8192, 1024]) + --> output_dim 0 + + apply this three times for q, k, and v LoRA deltas to the same layer + """ + + param = self.qkv_proj.weight + param_data = param.data + output_dim = getattr(param, "output_dim", None) + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + # num_heads = divide(total_num_heads, tp_size) + # if tp_size >= total_num_kv_heads: + # num_kv_heads = 1 + # # num_kv_head_replicas = divide(tp_size, total_num_kv_heads) + # else: + # num_kv_heads = divide(total_num_kv_heads, tp_size) + # # num_kv_head_replicas = 1 + num_heads = total_num_heads + num_kv_heads = total_num_kv_heads + + num_original_kv_heads = 8 # all Llama 3.1 models have 8 kv heads in total + + num_kv_head_replicas = tp_size // num_original_kv_heads + + if loaded_shard_id == "q": + shard_offset = 0 + shard_size = num_heads * head_size + elif loaded_shard_id == "k": + shard_offset = num_heads * head_size + shard_size = num_kv_heads * head_size + elif loaded_shard_id == "v": + shard_offset = (num_heads + num_kv_heads) * head_size + shard_size = num_kv_heads * head_size + + # print(f"{tp_rank=}, {tp_size=}") + if loaded_shard_id == "q": + start_idx = tp_rank * shard_size + else: + start_idx = (tp_rank // num_kv_head_replicas) * shard_size + + device = param_data.device + + param_data = param_data.narrow(output_dim, shard_offset, shard_size) + # print(f'{loaded_shard_id=}') + # print(f'{shard_offset=}, {shard_size=}, {shard_offset+shard_size=}') + # print(f'{output_dim=}, {start_idx=}, {param_data.shape=}') + # print('-' * 10) + + # 
self.qkv_proj.weight.data[shard_offset:shard_offset+shard_size, :] += ( + # loaded_delta.narrow(output_dim, start_idx, shard_size).to(device) + # ) + # print(f'Loading {loaded_shard_id} {start_idx}-{start_idx+shard_size} into {param_data.shape}, which is slice({shard_offset}, {shard_offset+shard_size})') + try: + param_data.copy_(param_data + loaded_delta.narrow(output_dim, start_idx, shard_size).to(device)) + # print(f"Loaded {loaded_shard_id} into {param_data.shape}") + except Exception as e: + print(f"Error: {e}") + print(f"{loaded_shard_id=}") + print(f"{output_dim=}") + print(f"{start_idx=}") + print(f"{shard_size=}") + print(f"{param_data.shape=}") + print(f"{loaded_delta.shape=}") + print(f"{tp_rank=}") + print(f"{tp_size=}") + + def merge_lora_to_o_parallel(self, + loaded_delta: torch.Tensor): + """ + Merge computed delta_AB into output projection (RowParallel linear) + """ + param = self.o_proj.weight + param_data = param.data + input_dim = getattr(param, "input_dim", None) + device = param_data.device + + # print('o_proj {input_dim=}') + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + if input_dim is not None: + shard_size = param_data.shape[input_dim] + start_idx = tp_rank * shard_size + loaded_delta = loaded_delta.narrow(input_dim, start_idx, shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_delta.shape) == 0: + loaded_delta = loaded_delta.reshape(1) + + # print('{param_data.shape=} | {loaded_delta.shape=}') + # assert param_data.shape == loaded_delta.shape + param_data.copy_(param_data + loaded_delta.to(device)) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + ndim = hidden_states.dim() + qkv, _ = self.qkv_proj(hidden_states) + seq_len = hidden_states.shape[-2] + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + + attn_output = self.attn( + q, k, v, + fmap_q=self.feature_map_q, + fmap_k=self.feature_map_k, + window_factors=self.window_factors, + state=None + ) + + # outputs + attn_output = attn_output.transpose(1, 2).contiguous().view(-1, seq_len, self.num_heads * self.head_dim) + output, _ = self.o_proj(attn_output) + if output.dim() > ndim: + output = output.squeeze(0) + return output + + +class LlamaLolcatsForCausalLM(LlamaForCausalLM): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + print(f"LOLCATS!!!: Loading model with config: {self.config}") + + softmax_attentions = getattr(self.config, 'softmax_attentions', []) + print(f"{softmax_attentions=}") + + tp_rank = get_tensor_model_parallel_rank() + self.tp_rank = tp_rank + + for i in range(len(self.model.layers)): + if i in softmax_attentions: + print(f"Using Lora Llama Attention at Layer {i}") + self.model.layers[i].self_attn = LlamaSdpaAttention( + config=self.config, + hidden_size=self.config.hidden_size, + num_heads=self.config.num_attention_heads, + num_kv_heads=getattr(self.config, "num_key_value_heads", + self.config.num_attention_heads), + rope_theta=self.config.rope_theta, + rope_scaling=self.config.rope_scaling, + ) + else: + self.model.layers[i].self_attn = LlamaLolcatsAttention( + i, + config=self.config, + hidden_size=self.config.hidden_size, + num_heads=self.config.num_attention_heads, + num_kv_heads=getattr(self.config, "num_key_value_heads", + 
self.config.num_attention_heads), + rope_theta=self.config.rope_theta, + rope_scaling=self.config.rope_scaling, + ) + print(self.model) + + def get_device(self): + device = next(self.parameters()).device + return str(device) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + super().load_weights(weights) + + # model_size = 8 + # FINETUNE_PATH = '/home/rahul/code/lolcats/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-ms=2500-se=0-re=100-lzi=1-bs=1-gas=8-nte=2-ms=2500-se=0-re=100_ft.pt' + # MLP_PATH = '/home/rahul/code/lolcats/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-ms=2500-se=0-re=100-lzi=1_distill.pt' + # # merge the MLP and FINETUNE weights as adapter weights + # adapter_weights = torch.load(FINETUNE_PATH, weights_only=True) + # adapter_weights.update(torch.load(MLP_PATH, weights_only=True)) + # print(adapter_weights.keys()) + # # only keep any weight with 'feature' or 'window' or 'lora' in the key + # adapter_weights = {k: v for k, v in adapter_weights.items() if 'feature' in k or 'window' in k or 'lora' in k} + + model_size = 70 + PATH = f'/data/rahul/checkpoints/{model_size}b.pt' + PATH = f'/home/rahul/code/lolcats/ckpt_lora-dl-d=distill_redpajama_xent1_mse1000_lr1e-2-m=distill_llama3_1_70b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_redpajama-dcs=512-se=0-re=4-lzi=1-dcs=512-se=0-re=4.pt' + print(f"PATH INFERENCE: {PATH}!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + adapter_weights_path = os.getenv("LOLCATS_ADAPTER_PATH", PATH) + adapter_weights = torch.load(adapter_weights_path, weights_only=True) + + adapter_weights_copy = OrderedDict({}) + + for key, value in adapter_weights.items(): + key_suffix = key[key.rindex("model.")+6:] + adapter_weights_copy[key_suffix] = value + + adapter_weights = adapter_weights_copy + updated_keys = [] + + print("\n") + num_layers = len(self.model.layers) + for layer_idx, layer in enumerate(self.model.layers): + if layer_idx == 0: + print(f'Weight factors before checkpoint load {self.tp_rank=}, {layer.self_attn.window_factors.shape}, {layer.self_attn.window_factors.flatten()}') + + window_factors_key = f'layers.{layer_idx}.self_attn.window_factors' + if window_factors_key in adapter_weights: + layer.self_attn.load_window_factors(adapter_weights[window_factors_key]) + updated_keys.append(window_factors_key) + + if layer_idx == 0: + print(f'Weight factors after checkpoint load {self.tp_rank=}, {layer.self_attn.window_factors.shape}, {layer.self_attn.window_factors.flatten()}') + + fm_q_key = f'layers.{layer_idx}.self_attn.feature_map_q.mlp.layer' + if fm_q_key in adapter_weights: + # if layer_idx in [0, num_layers-1]: + # # print("\n") + # # print(f'FMAP Q before checkpoint load {self.tp_rank=}, {layer.self_attn.feature_map_q.layer.shape}, {layer.self_attn.feature_map_q.layer[0,0,:4]}') + + layer.self_attn.load_feature_map_q(adapter_weights[fm_q_key]) + updated_keys.append(fm_q_key) + + # if layer_idx in [0, num_layers-1]: + # print(f'FMAP Q after checkpoint load; {layer.self_attn.feature_map_q.layer.shape},{layer.self_attn.feature_map_q.layer[0,0,:4]}') + + fm_k_key = f'layers.{layer_idx}.self_attn.feature_map_k.mlp.layer' + if fm_k_key in adapter_weights: + layer.self_attn.load_feature_map_k(adapter_weights[fm_k_key]) + updated_keys.append(fm_k_key) + + weight_name = 'layers.{layer_idx}.self_attn.{proj}.lora_{a_or_b}.default.weight' + target_modules = ["q_proj", "k_proj", 
"v_proj", "o_proj"] + # target_modules = ["q_proj", "k_proj", "v_proj"] + # target_modules = ["k_proj", "v_proj"] + # target_modules = ["q_proj", "k_proj"] + + r = 8 + lora_alpha = 16 + lora_dropout = 0 + + for proj in target_modules: + lora_A_key = weight_name.format(layer_idx=layer_idx, proj=proj, a_or_b='A') + lora_B_key = weight_name.format(layer_idx=layer_idx, proj=proj, a_or_b='B') + if lora_A_key in adapter_weights: + weight_A = adapter_weights[lora_A_key] + weight_B = adapter_weights[lora_B_key] + delta_AB = get_delta_weight(weight_A, weight_B, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) + + # if layer_idx in [0, num_layers-1]: + # print(f'layer {layer_idx} weight_A.shape: {weight_A.shape} | weight_B.shape: {weight_B.shape} | delta_AB.shape: {delta_AB.shape}') + # print(f'layer {layer_idx} proj {proj} delta_AB', delta_AB.shape) + + if proj == 'o_proj': + # if layer_idx in [0, num_layers-1]: + # print("\n") + # print(f'Layer {layer_idx} {proj} weight before checkpoint load, {layer.self_attn.o_proj.weight.shape}, {layer.self_attn.o_proj.weight[0,:4]}') + + layer.self_attn.merge_lora_to_o_parallel(delta_AB) + + # if layer_idx in [0, num_layers-1]: + # print(f'Layer {layer_idx} {proj} weight after checkpoint load, {layer.self_attn.o_proj.weight.shape}, {layer.self_attn.o_proj.weight[0,:4]}') + else: + # if layer_idx in [0, num_layers-1] and proj in ['q_proj']: + # print(f'Layer {layer_idx} {proj} weight before checkpoint load, {layer.self_attn.qkv_proj.weight.shape}, {layer.self_attn.qkv_proj.weight[0,:4]}') + + layer.self_attn.merge_lora_to_qkv_parallel( + delta_AB, + loaded_shard_id=proj.split('_')[0], + total_num_heads=layer.self_attn.num_heads, + total_num_kv_heads=layer.self_attn.num_kv_heads,head_size=layer.self_attn.head_dim) + + # if layer_idx in [0, num_layers-1] and proj in ['q_proj']: + # print(f'Layer {layer_idx} {proj} weight after checkpoint load, {layer.self_attn.qkv_proj.weight.shape}, {layer.self_attn.qkv_proj.weight[0,:4]}') + updated_keys.append(lora_A_key) + updated_keys.append(lora_B_key) + + assert len(set(adapter_weights_copy.keys()) - set(updated_keys)) == 0, \ + f"UNUPDATED KEYS: {set(adapter_weights_copy.keys()) - set(updated_keys)}" + + +def transpose(weight, fan_in_fan_out): + if not fan_in_fan_out: + return weight + + if isinstance(weight, torch.nn.Parameter): + return torch.nn.Parameter(weight.T) + return weight.T + + +def get_delta_weight(weight_A: torch.Tensor, weight_B: torch.Tensor, + r: int = 8, lora_alpha: float = 16, lora_dropout: float = 0, + fan_in_fan_out: bool = False,): + + device = weight_B.device + dtype = weight_B.dtype + # From https://github.com/huggingface/peft/blob/850eeb5c3a5cf692f5612c7c733b13fde184e05d/src/peft/tuners/lora/layer.py#L512 + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + if cast_to_fp32: + weight_A = weight_A.float() + weight_B = weight_B.float() + scaling = lora_alpha / r + output_tensor = transpose(weight_B @ weight_A, fan_in_fan_out) * scaling + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + return output_tensor diff --git a/demos/vllm_integration/vllm_files/lolcats_inference_paged.py b/demos/vllm_integration/vllm_files/lolcats_inference_paged.py new file mode 100644 index 0000000..5c92304 --- /dev/null +++ b/demos/vllm_integration/vllm_files/lolcats_inference_paged.py @@ -0,0 +1,912 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 
The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only LLaMA model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import math +import os +import torch +import time +from collections import OrderedDict +from torch import nn +from torch.nn.parameter import Parameter +from transformers import LlamaConfig + +from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.distributed import (divide, + # split_tensor_along_last_dim, + # tensor_model_parallel_all_gather, + # tensor_model_parallel_all_reduce + ) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import is_hip + + + +from vllm.model_executor.models.interfaces import SupportsLoRA +from vllm.model_executor.models.utils import PPMissingLayer, is_pp_missing_parameter, make_layers +# from .interfaces import SupportsLoRA +# from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaAttention, LlamaMLP + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ReplicatedLinear + +logger = init_logger(__name__) + + +### OURS for Linear attention implementation +# from peft import get_peft_model, LoraConfig, TaskType + +# PEFT_KWARGS = { +# 'r': 8, +# 'lora_alpha': 16, # 32 +# 'lora_dropout': 0.05, +# 'target_modules': ["q_proj", "v_proj", "k_proj", "o_proj"] +# } + +### Hybrid Attention +from vllm.attention import 
Attention, AttentionMetadata + + +### VLLM Llama Model +class FeatureMap(nn.Module): + """ + Learnable MLP in feature map. + + Full feature map is like f(xW + b) + -> This is the `W` and (optional) `b` part + """ + def __init__(self, + num_heads: int, + head_dim: int, + feature_dim: int, + dtype: torch.dtype, + device: torch.device, + eps: float = 1e-12, + **kwargs): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.feature_dim = feature_dim + self.dtype = dtype + self.device = device + self.eps = eps + self.init_weights_() + + def activation(self, x: torch.Tensor): + return torch.cat([ + torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1) + ], dim=-1).clamp(min=self.eps) + + def init_weights_(self): + self.layer = nn.Parameter(torch.zeros( + (self.num_heads, self.head_dim, self.feature_dim), + dtype=self.dtype, device=self.device, + )) + + def forward(self, x: torch.Tensor): + return self.activation( + torch.einsum('hdf,bhld->bhlf', self.layer, x.to(self.dtype))) + + +from dataclasses import dataclass +@dataclass +class LoLCacheParams: + is_prompt: bool = False + kv_state: torch.Tensor = torch.Tensor() + k_state: torch.Tensor = torch.Tensor() + kv_cache: torch.Tensor = torch.Tensor() + +@dataclass +class PageCache: + kv_cache: torch.Tensor = None + q_cache: torch.Tensor = None + + +class LlamaLolcatsAttentionActual(nn.Module): + """Attention layer. + + This class takes query, key, and value tensors as input. The input tensors + can either contain prompt tokens or generation tokens. + The class does the following: + + 1. Store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention. + 3. Return the output tensor. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + num_kv_heads: int, + layer_idx: int, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = num_heads // num_kv_heads + self.window_size = 64 + + # SA: inference cache + self.lolcats_cache = None + self.layer_idx = layer_idx + self.tp_rank = get_tensor_model_parallel_rank() + + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + fmap_q: FeatureMap, + fmap_k: FeatureMap, + window_factors: torch.Tensor, + state=None, + attn_metadata: AttentionMetadata = None + ) -> torch.Tensor: + # if self.layer_idx == 0: + # print(f"Initially: {query.shape=}, {key.shape=}, {value.shape=}") + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + positions = attn_metadata.seq_start_loc.tolist() + start, end = positions[0], positions[1] + + if self.lolcats_cache is None or end == num_prefill_tokens: + # reset cache + self._prepare_lolcats_cache() + if self.layer_idx == 0 and self.tp_rank == 0: + print("Resetting cache") + print(f"-- {num_prefill_tokens=}, {num_decode_tokens=}, {start=}, {end=}") + # print(self.page_cache.kv_cache) + + # num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. 
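+        # Added note (descriptive only): q/k/v arrive from vLLM packed as
+        # [num_tokens, num_heads * head_size]. The view / repeat_interleave /
+        # movedim / unsqueeze steps below rebuild the [1, num_heads, seq_len, head_size]
+        # layout that the feature maps and the hybrid linear + sliding-window attention
+        # expect; KV heads are repeated to match the query heads (grouped-query attention).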
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + if query.dim() == 3: + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + query, key, value = self.get_full_key_value(query, key, value) + if self.layer_idx == 0 and self.tp_rank == 0: + print(f"-- after update: {query.shape=}, {key.shape=}, {value.shape=}") + + f_q = fmap_q(query) + f_k = fmap_k(key) + + window_size = 64 + window_factors = torch.nn.functional.sigmoid(window_factors) + linear_factors = 1 + + seq_len = query.shape[-2] + if num_decode_tokens >= 1 or seq_len == 1: + return self.recurrent_attention( + query, key, f_q, f_k, + value, window_factors, + linear_factors, + window_size, + fmap_q, fmap_k + ) + else: + out = self.superlinear_attention( + query, key, f_q, f_k, + value, + window_factors, linear_factors, + window_size + ) + return out + + + def get_full_key_value(self, query, key, value): + # add the current key and value to the cache + if self.page_cache.kv_cache is not None: + key = torch.cat([self.page_cache.kv_cache[:, 0], key], dim=-2) + value = torch.cat([self.page_cache.kv_cache[:, 1], value], dim=-2) + query = torch.cat([self.page_cache.q_cache, query], dim=-2) + else: + key = key + value = value + query = query + + # update the cache + self.page_cache.kv_cache = torch.stack([key, value], dim=1) + self.page_cache.q_cache = query + return query, key, value + + + def _create_mask(self, max_seq_len: int, window_size: int, is_window: bool) -> torch.Tensor: + l = window_size + m = math.ceil(max_seq_len / window_size) + mask = torch.block_diag(*[torch.ones((l, l))] * m) + mask += torch.roll(mask, -l, -1) + mask = mask[:max_seq_len, :max_seq_len] + mask = mask[None, None, ...] 
# b, h, q_len, k_len + mask = torch.tril(mask if is_window else 1 - mask).to(dtype=torch.bool) + return mask + + + def get_masks(self, window_size: int, q_len: int, k_len: int, device: torch.device) -> tuple[torch.Tensor]: + mask_window = self._create_mask(q_len, window_size, True).to(device) + mask_linear = self._create_mask(q_len, window_size, False).to(device) + return mask_window[:, :, :q_len, :k_len], mask_linear[:, :, :q_len, :k_len] + + + def _prepare_lolcats_cache(self): + self.lolcats_cache = LoLCacheParams() + self.page_cache = PageCache() + + + def _init_kv_cache(self, keys, values, f_k): + dtype = keys.dtype + + # decoding KV state (KV terms up to last window_size) + decode_kv_state = torch.einsum('bhlf,bhld->bhfd', + f_k[:, :, :-self.window_size], + values[:, :, :-self.window_size] + ) + + # shape is b, h, 1, f; note the 1 + decode_k_state = f_k[:, :, :-self.window_size].sum(dim=-2,keepdim=True) + self.lolcats_cache.kv_state = decode_kv_state + self.lolcats_cache.k_state = decode_k_state + + # update the cache + kv_cache = torch.stack([ + keys[:, :, -self.window_size:, :].float(), + values[:, :, -self.window_size:, :].float() + ], dim=1) + self.lolcats_cache.kv_cache = kv_cache + + + def superlinear_attention( + self, q: torch.Tensor, k: torch.Tensor, + f_q: torch.Tensor, f_k: torch.Tensor, + v: torch.Tensor, + window_factor: torch.Tensor, linear_factor: torch.Tensor, + window_size: int, + kv_state: torch.Tensor = None, k_state: torch.Tensor = None, + eps: float = 1e-12, + mask_value: float=-1e8 + ): + """ + Hybrid attention combining sliding window and linear attentions + """ + + mask_window, mask_linear = self.get_masks( + window_size, q.shape[-2], k.shape[-2], q.device) + + # 1. Sliding window (softmax attention) + a_sm = torch.einsum( + 'bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5) + a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value) + # torch.softmax(a_sm, dim=-1), but we account for the max when combining + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factor * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + # 2. Under window (linear attention) + a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float()) + a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0).to(q.dtype) + sum_ln = a_ln.sum(dim=-1, keepdim=True) + + # 3. 
Combine + # Allow outputs to also depend on prior kv_state and k_state + y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v) + y = (y / (sum_sm + sum_ln)) + + self._init_kv_cache(k, v, f_k) + return y.to(q.dtype) # attention weights only for the last chunk + + + def _update_kv_cache(self, keys, values, fmap_k): + # if self.tp_rank == 0 and self.layer_idx == 0: print("heyo 1 - hello update kv cache") + # get state from before + kv_state = self.lolcats_cache.kv_state + k_state = self.lolcats_cache.k_state + kv_cache_swa = self.lolcats_cache.kv_cache + k_cache = kv_cache_swa[:, 0] + v_cache = kv_cache_swa[:, 1] + + dtype = kv_state.dtype + + # update the linear attention states + # since we ignore the diag blocks, just grab last tokens of kv cache + cur_seq_len = k_cache.shape[-2] + # if self.tp_rank == 0 and self.layer_idx == 0: print(f"{cur_seq_len=}") + if cur_seq_len >= self.window_size: + # if self.tp_rank == 0 and self.layer_idx == 0: print(f"Updating the kv_state and k_state...") + k_state = fmap_k(k_cache[:, :, :1, :]) + v_state = v_cache[:, :, :1, :] + kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d + self.lolcats_cache.kv_state += kv_state.to(kv_state.dtype) + self.lolcats_cache.k_state += k_state + + # update swa states + if cur_seq_len < self.window_size: + # only add to cache + k_cache = torch.cat([k_cache, keys], dim=-2) + v_cache = torch.cat([v_cache, values], dim=-2) + else: + # remove oldest key and value and append + k_cache = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2) + v_cache = torch.cat([v_cache[:, :, 1:, :], values], dim=-2) + kv_cache_swa = torch.stack([k_cache, v_cache], dim=1) + self.lolcats_cache.kv_cache = kv_cache_swa + + # if self.tp_rank == 0 and self.layer_idx == 0: print("heyo 2 - bye update kv cache") + return self.lolcats_cache.kv_state, self.lolcats_cache.k_state, k_cache, v_cache + + + def recurrent_attention( + self, q: torch.Tensor, k: torch.Tensor, + f_q: torch.Tensor, f_k: torch.Tensor, + v: torch.Tensor, + window_factor: torch.Tensor, + linear_factor: torch.Tensor, + window_size: int, + fmap_q, fmap_k, + kv_state: torch.Tensor = None, + k_state: torch.Tensor = None, + eps: float = 1e-12, mask_value: float=-1e8 + ): + dtype = torch.float32 + kv_state, k_state, k_cache, v_cache = self._update_kv_cache(k, v, fmap_k) + + # Softmax attention terms + a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5) + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factor * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + y_sm = torch.einsum('bhmn,bhnd->bhmd', a_sm.float(), v_cache.float()) + + # Combine with linear attention terms + f_q = fmap_q(q) + y_ln = linear_factor * torch.einsum('bhlf,bhfd->bhld', f_q.float(), kv_state.float()) + sum_ln = linear_factor * torch.einsum('bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None] + + y = y_sm + y_ln + attn_output = (y / (sum_sm + sum_ln)).to(q.dtype) + return attn_output + + +class LlamaLolcatsAttention(LlamaAttention): + def __init__(self, layer_idx, use_base_attn, *args, **kwargs): + + super().__init__(*args, **kwargs) + self.use_base_attn = use_base_attn + if self.use_base_attn: + # coppy the original self.attn into self.base_attn before we override + # use deepcopy to avoid any shared references + import copy + self.base_attn = copy.deepcopy(self.attn) + + self.attn = LlamaLolcatsAttentionActual(self.num_heads, + self.head_dim, + self.num_kv_heads, + layer_idx) + self.head_size = 
self.head_dim + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + _device = self.qkv_proj.weight.device + _dtype = self.qkv_proj.weight.dtype + + _feature_dim = 64 + _feature_map_kwargs = { + "num_heads": self.num_heads, + "head_dim": self.head_dim, + "feature_dim": _feature_dim, + "dtype": _dtype, + "device": _device, + } + self.feature_dim = _feature_dim + self.window_size = 64 + + tp_rank = get_tensor_model_parallel_rank() + self.tp_rank = tp_rank + self.layer_idx = layer_idx + + self.feature_map_q = FeatureMap(**_feature_map_kwargs) + self.feature_map_k = FeatureMap(**_feature_map_kwargs) + self.window_factors = nn.Parameter( + torch.ones(1, self.num_heads, 1, 1, device=_device, dtype=_dtype)) + + + def load_window_factors(self, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + if tp_size > 1: + + num_heads_per_rank = self.num_heads + start_idx = tp_rank * num_heads_per_rank + end_idx = start_idx + num_heads_per_rank + + if self.layer_idx == 0 and tp_rank == 0: + print(loaded_weight) + + if self.layer_idx < 2: + print(f"{num_heads_per_rank=}") + print(f"{tp_rank=}; {loaded_weight.shape=}; {start_idx=}; {end_idx=}") + + sharded_weight = loaded_weight[:, start_idx:end_idx, :, :] + + else: + + sharded_weight = loaded_weight + + assert self.window_factors.shape == sharded_weight.shape, \ + f"Shape mismatch: {self.window_factors.shape} vs {sharded_weight.shape}" + + with torch.no_grad(): + self.window_factors.copy_(sharded_weight) + + def load_feature_map_q(self, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + if tp_size > 1: + + num_heads_per_rank = self.feature_map_q.num_heads + start_idx = tp_rank * num_heads_per_rank + end_idx = start_idx + num_heads_per_rank + + sharded_weight = loaded_weight[start_idx:end_idx, :, :] + + if sharded_weight.shape[-1] != self.feature_map_q.layer.shape[-1]: + sharded_weight = sharded_weight[:, :, :self.feature_map_q.layer.shape[-1]] + + else: + + sharded_weight = loaded_weight + + assert self.feature_map_q.layer.shape == sharded_weight.shape, \ + f"Shape mismatch: {self.feature_map_q.layer.shape} vs {sharded_weight.shape}" + + with torch.no_grad(): + self.feature_map_q.layer.copy_(sharded_weight) + + def load_feature_map_k(self, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + if tp_size > 1: + + num_heads_per_rank = self.feature_map_k.num_heads + start_idx = tp_rank * num_heads_per_rank + end_idx = start_idx + num_heads_per_rank + + sharded_weight = loaded_weight[start_idx:end_idx, :, :] + + if sharded_weight.shape[-1] != self.feature_map_k.layer.shape[-1]: + sharded_weight = sharded_weight[:, :, :self.feature_map_k.layer.shape[-1]] + + else: + + sharded_weight = loaded_weight + + assert self.feature_map_k.layer.shape == sharded_weight.shape, \ + f"Shape mismatch: {self.feature_map_k.layer.shape} vs {sharded_weight.shape}" + + with torch.no_grad(): + self.feature_map_k.layer.copy_(sharded_weight) + # self.feature_map_k.layer.normal_(std=1) + + def merge_lora_to_qkv_parallel(self, # param: Parameter, + loaded_delta: torch.Tensor, + loaded_shard_id: str = 'q', + total_num_heads: int = 32, + total_num_kv_heads: int = 4, + head_size: int = 128): + """ + Merge computed delta_AB into QKV parallel weights + + Based off of vLLM linear layer: 
https://github.com/vllm-project/vllm/blob/bc6e42a9b19364e07da9f279edd81796541d147d/vllm/model_executor/layers/linear.py#L762 + then Rahul, then Claude 3.5 Sonnet + + model.layers.1.self_attn.qkv_proj.weight torch.Size([1280, 8192]) + --> output_dim 0 + model.layers.1.self_attn.o_proj.weight torch.Size([8192, 1024]) + --> output_dim 0 + + apply this three times for q, k, and v LoRA deltas to the same layer + """ + + param = self.qkv_proj.weight + param_data = param.data + output_dim = getattr(param, "output_dim", None) + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + # num_heads = divide(total_num_heads, tp_size) + # if tp_size >= total_num_kv_heads: + # num_kv_heads = 1 + # # num_kv_head_replicas = divide(tp_size, total_num_kv_heads) + # else: + # num_kv_heads = divide(total_num_kv_heads, tp_size) + # # num_kv_head_replicas = 1 + num_heads = total_num_heads + num_kv_heads = total_num_kv_heads + + num_original_kv_heads = 8 # all Llama 3.1 models have 8 kv heads in total + + num_kv_head_replicas = tp_size // num_original_kv_heads + + if loaded_shard_id == "q": + shard_offset = 0 + shard_size = num_heads * head_size + elif loaded_shard_id == "k": + shard_offset = num_heads * head_size + shard_size = num_kv_heads * head_size + elif loaded_shard_id == "v": + shard_offset = (num_heads + num_kv_heads) * head_size + shard_size = num_kv_heads * head_size + + # print(f"{tp_rank=}, {tp_size=}") + if loaded_shard_id == "q": + start_idx = tp_rank * shard_size + else: + start_idx = (tp_rank // num_kv_head_replicas) * shard_size + + device = param_data.device + + param_data = param_data.narrow(output_dim, shard_offset, shard_size) + # print(f'{loaded_shard_id=}') + # print(f'{shard_offset=}, {shard_size=}, {shard_offset+shard_size=}') + # print(f'{output_dim=}, {start_idx=}, {param_data.shape=}') + # print('-' * 10) + + # self.qkv_proj.weight.data[shard_offset:shard_offset+shard_size, :] += ( + # loaded_delta.narrow(output_dim, start_idx, shard_size).to(device) + # ) + # print(f'Loading {loaded_shard_id} {start_idx}-{start_idx+shard_size} into {param_data.shape}, which is slice({shard_offset}, {shard_offset+shard_size})') + try: + param_data.copy_(param_data + loaded_delta.narrow(output_dim, start_idx, shard_size).to(device)) + # print(f"Loaded {loaded_shard_id} into {param_data.shape}") + except Exception as e: + print(f"Error: {e}") + print(f"{loaded_shard_id=}") + print(f"{output_dim=}") + print(f"{start_idx=}") + print(f"{shard_size=}") + print(f"{param_data.shape=}") + print(f"{loaded_delta.shape=}") + print(f"{tp_rank=}") + print(f"{tp_size=}") + + def merge_lora_to_o_parallel(self, + loaded_delta: torch.Tensor): + """ + Merge computed delta_AB into output projection (RowParallel linear) + """ + param = self.o_proj.weight + param_data = param.data + input_dim = getattr(param, "input_dim", None) + device = param_data.device + + # print('o_proj {input_dim=}') + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + + if input_dim is not None: + shard_size = param_data.shape[input_dim] + start_idx = tp_rank * shard_size + loaded_delta = loaded_delta.narrow(input_dim, start_idx, shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). 
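+        # Added note (descriptive only): o_proj is a RowParallelLinear, so its weight is
+        # sharded along the input dimension across tensor-parallel ranks; the narrow()
+        # above keeps only this rank's slice of the full LoRA delta (delta_AB = B @ A
+        # scaled by lora_alpha / r, see get_delta_weight) before it is added in place below.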
+ if len(loaded_delta.shape) == 0: + loaded_delta = loaded_delta.reshape(1) + + # print('{param_data.shape=} | {loaded_delta.shape=}') + # assert param_data.shape == loaded_delta.shape + param_data.copy_(param_data + loaded_delta.to(device)) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + problem_idx: int + ) -> torch.Tensor: + ndim = hidden_states.dim() + qkv, _ = self.qkv_proj(hidden_states) + seq_len = hidden_states.shape[-2] + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + + attn_output = self.attn( + q, k, v, + fmap_q=self.feature_map_q, + fmap_k=self.feature_map_k, + window_factors=self.window_factors, + state=None, + attn_metadata=attn_metadata + ) + + ref_output = None + expt_tag = '_cria_alpaca_final' + if self.use_base_attn and self.layer_idx % 9 == 0: + ref_output = self.base_attn( + q, k, v, + attn_metadata=attn_metadata, + kv_cache=kv_cache, + ) + + dir_path = f"/data/simran/mmlu_hybrid_outputs_{expt_tag}/" + if not os.path.exists(dir_path): os.makedirs(dir_path, exist_ok=True) + fpath = f"{dir_path}/our_attn_output_problem{problem_idx}_rank{self.tp_rank}_layer{self.layer_idx}.pt" + torch.save(attn_output, fpath) + fpath = f"{dir_path}/ref_attn_output_problem{problem_idx}_rank{self.tp_rank}_layer{self.layer_idx}.pt" + torch.save(ref_output, fpath) + # print(f"Saved!") + # end save stuff + + # outputs + full_seq_len = attn_output.shape[-2] # in case we updated the length + attn_output = attn_output.transpose(1, 2).contiguous().view( + -1, full_seq_len, self.num_heads * self.head_dim + ) + output, _ = self.o_proj(attn_output) + if output.dim() > ndim: + output = output.squeeze(0) + output = output[-seq_len:, ...] 
# put back the original seq_len + + if self.use_base_attn and self.layer_idx % 9 == 0: + ref_y, _ = self.o_proj(ref_output) + dir_path = f"/data/simran/mmlu_hybrid_y_outs_{expt_tag}/" + if not os.path.exists(dir_path): os.makedirs(dir_path, exist_ok=True) + fpath = f"{dir_path}/our_y_out_problem{problem_idx}_rank{self.tp_rank}_layer{self.layer_idx}.pt" + torch.save(output, fpath) + fpath = f"{dir_path}/ref_y_out_problem{problem_idx}_rank{self.tp_rank}_layer{self.layer_idx}.pt" + torch.save(ref_y, fpath) + + return output + + +class LlamaLolcatsForCausalLM(LlamaForCausalLM): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + print(f"LOLCATS!!!: Loading model with config: {self.config}") + + tp_rank = get_tensor_model_parallel_rank() + self.tp_rank = tp_rank + + softmax_attentions = getattr(self.config, 'softmax_attentions', []) + print(f"{softmax_attentions=}") + + use_base_attn = getattr(self.config, 'use_base_attn', False) + + for i in range(len(self.model.layers)): + if i in softmax_attentions: + pass + else: + self.model.layers[i].self_attn = LlamaLolcatsAttention( + i, + use_base_attn=use_base_attn, + config=self.config, + hidden_size=self.config.hidden_size, + num_heads=self.config.num_attention_heads, + num_kv_heads=getattr(self.config, "num_key_value_heads", + self.config.num_attention_heads), + rope_theta=self.config.rope_theta, + rope_scaling=self.config.rope_scaling, + ) + print(self.model) + + + def get_device(self): + device = next(self.parameters()).device + return str(device) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + super().load_weights(weights) + + # r = 8 + # lora_alpha = 16 + # lora_dropout = 0 + + # model_size = 8 + # FINETUNE_PATH = '/home/rahul/code/lolcats/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-ms=2500-se=0-re=100-lzi=1-bs=1-gas=8-nte=2-ms=2500-se=0-re=100_ft.pt' + # MLP_PATH = '/home/rahul/code/lolcats/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-ms=2500-se=0-re=100-lzi=1_distill.pt' + # # merge the MLP and FINETUNE weights as adapter weights + # adapter_weights = torch.load(FINETUNE_PATH, weights_only=True) + # adapter_weights.update(torch.load(MLP_PATH, weights_only=True)) + # print(adapter_weights.keys()) + # # only keep any weight with 'feature' or 'window' or 'lora' in the key + # adapter_weights = {k: v for k, v in adapter_weights.items() if 'feature' in k or 'window' in k or 'lora' in k} + + # model_size = 70 + # # PATH = f'/data/rahul/checkpoints/{model_size}b.pt' + # PATH = f'/home/rahul/code/lolcats/ckpt_lora-dl-d=distill_redpajama_xent1_mse1000_lr1e-2-m=distill_llama3_1_70b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_redpajama-dcs=512-se=0-re=4-lzi=1-dcs=512-se=0-re=4.pt' + + ########### 405 B ############ + + # PATH = '/home/rahul/code/lolcats/ckpts/seqlen768.pt' # 405B at 768 seqlen + + # 1. Alpaca Cria QV Rank 4 -- with hybridization + PATH = '/home/rahul/code/lolcats/ckpt_lora-dl-d=distill_llama_405b_xent1_mse1000_lr1e-2-m=distill_llama3_1_405b_lk_smd_wtk64_fd64_w01_h72_80_117_125-f=finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-ef=finetune_llama_405b_qkvo_e2_h72_80_117_125-ft_lora=0-se=0-re=0-alpaca.pt' + + # 2. 
Alpaca Cria QV Rank 4 -- pure + # PATH = '/home/rahul/code/lolcats/ckpt_lora-dl-d=distill_llama_405b_xent1_mse1000_lr1e-2-m=distill_llama3_1_405b_lk_smd_wtk64_fd64_w01-f=finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-ef=finetune_llama_405b_qkvo_e2-ft_lora=0-se=0-re=0-ef=finetune_llama_405b_qkvo_e2-ft_lora=0_epoch2.pt' + + # 3. RP Cria QV Rank 4 -- pure + # PATH = '/home/rahul/code/lolcats/ckpts/cria_rp.pt' # 780.pt step + # PATH = '/home/rahul/code/lolcats/ckpt_lora-dl-d=rp_distill_llama_405b_xent1_mse1000_lr1e-2-m=distill_llama3_1_405b_lk_smd_wtk64_fd64_w01-f=rp_finetune_llama_40b_qv_hparams-s=0-se=0-re=0-ef=finetune_llama_405b_qkvo_e2_rp-ft_lora=0-se=0-re=0-s=1670.pt' + + print(f"PATH INFERENCE: {PATH}!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + adapter_weights_path = os.getenv("LOLCATS_ADAPTER_PATH", PATH) + adapter_weights = torch.load(adapter_weights_path, weights_only=True) + + adapter_weights_copy = OrderedDict({}) + + for key, value in adapter_weights.items(): + key_suffix = key[key.rindex("model.")+6:] + adapter_weights_copy[key_suffix] = value + + adapter_weights = adapter_weights_copy + updated_keys = [] + + print("\n") + num_layers = len(self.model.layers) + for layer_idx, layer in enumerate(self.model.layers): + if layer_idx == 0: + print(f'Weight factors before checkpoint load {self.tp_rank=}, {layer.self_attn.window_factors.shape}, {layer.self_attn.window_factors.flatten()}') + + window_factors_key = f'layers.{layer_idx}.self_attn.window_factors' + if window_factors_key in adapter_weights: + layer.self_attn.load_window_factors(adapter_weights[window_factors_key]) + updated_keys.append(window_factors_key) + + if layer_idx == 0: + print(f'Weight factors after checkpoint load {self.tp_rank=}, {layer.self_attn.window_factors.shape}, {layer.self_attn.window_factors.flatten()}') + + fm_q_key = f'layers.{layer_idx}.self_attn.feature_map_q.mlp.layer' + if fm_q_key in adapter_weights: + # if layer_idx in [0, num_layers-1]: + # # print("\n") + # # print(f'FMAP Q before checkpoint load {self.tp_rank=}, {layer.self_attn.feature_map_q.layer.shape}, {layer.self_attn.feature_map_q.layer[0,0,:4]}') + + layer.self_attn.load_feature_map_q(adapter_weights[fm_q_key]) + updated_keys.append(fm_q_key) + + # if layer_idx in [0, num_layers-1]: + # print(f'FMAP Q after checkpoint load; {layer.self_attn.feature_map_q.layer.shape},{layer.self_attn.feature_map_q.layer[0,0,:4]}') + + fm_k_key = f'layers.{layer_idx}.self_attn.feature_map_k.mlp.layer' + if fm_k_key in adapter_weights: + layer.self_attn.load_feature_map_k(adapter_weights[fm_k_key]) + updated_keys.append(fm_k_key) + + weight_name = 'layers.{layer_idx}.self_attn.{proj}.lora_{a_or_b}.default.weight' + target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"] + # target_modules = ["q_proj", "k_proj", "v_proj"] + # target_modules = ["k_proj", "v_proj"] + # target_modules = ["q_proj", "k_proj"] + + r = 8 + lora_alpha = 16 + lora_dropout = 0 + + for proj in target_modules: + lora_A_key = weight_name.format(layer_idx=layer_idx, proj=proj, a_or_b='A') + lora_B_key = weight_name.format(layer_idx=layer_idx, proj=proj, a_or_b='B') + if lora_A_key in adapter_weights: + weight_A = adapter_weights[lora_A_key] + weight_B = adapter_weights[lora_B_key] + delta_AB = get_delta_weight(weight_A, weight_B, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) + + # if layer_idx in [0, num_layers-1]: + # print(f'layer {layer_idx} weight_A.shape: {weight_A.shape} | weight_B.shape: {weight_B.shape} | delta_AB.shape: {delta_AB.shape}') + # print(f'layer {layer_idx} proj 
{proj} delta_AB', delta_AB.shape) + + if proj == 'o_proj': + # if layer_idx in [0, num_layers-1]: + # print("\n") + # print(f'Layer {layer_idx} {proj} weight before checkpoint load, {layer.self_attn.o_proj.weight.shape}, {layer.self_attn.o_proj.weight[0,:4]}') + + layer.self_attn.merge_lora_to_o_parallel(delta_AB) + + # if layer_idx in [0, num_layers-1]: + # print(f'Layer {layer_idx} {proj} weight after checkpoint load, {layer.self_attn.o_proj.weight.shape}, {layer.self_attn.o_proj.weight[0,:4]}') + else: + # if layer_idx in [0, num_layers-1] and proj in ['q_proj']: + # print(f'Layer {layer_idx} {proj} weight before checkpoint load, {layer.self_attn.qkv_proj.weight.shape}, {layer.self_attn.qkv_proj.weight[0,:4]}') + + layer.self_attn.merge_lora_to_qkv_parallel( + delta_AB, + loaded_shard_id=proj.split('_')[0], + total_num_heads=layer.self_attn.num_heads, + total_num_kv_heads=layer.self_attn.num_kv_heads,head_size=layer.self_attn.head_dim) + + # if layer_idx in [0, num_layers-1] and proj in ['q_proj']: + # print(f'Layer {layer_idx} {proj} weight after checkpoint load, {layer.self_attn.qkv_proj.weight.shape}, {layer.self_attn.qkv_proj.weight[0,:4]}') + updated_keys.append(lora_A_key) + updated_keys.append(lora_B_key) + + assert len(set(adapter_weights_copy.keys()) - set(updated_keys)) == 0, \ + f"UNUPDATED KEYS: {set(adapter_weights_copy.keys()) - set(updated_keys)}" + + +def transpose(weight, fan_in_fan_out): + if not fan_in_fan_out: + return weight + + if isinstance(weight, torch.nn.Parameter): + return torch.nn.Parameter(weight.T) + return weight.T + + +def get_delta_weight(weight_A: torch.Tensor, weight_B: torch.Tensor, + r: int = 8, lora_alpha: float = 16, lora_dropout: float = 0, + fan_in_fan_out: bool = False,): + + device = weight_B.device + dtype = weight_B.dtype + # From https://github.com/huggingface/peft/blob/850eeb5c3a5cf692f5612c7c733b13fde184e05d/src/peft/tuners/lora/layer.py#L512 + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + if cast_to_fp32: + weight_A = weight_A.float() + weight_B = weight_B.float() + scaling = lora_alpha / r + output_tensor = transpose(weight_B @ weight_A, fan_in_fan_out) * scaling + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + return output_tensor + diff --git a/demos/vllm_integration/vllm_files/test_vllm_aw.py b/demos/vllm_integration/vllm_files/test_vllm_aw.py new file mode 100644 index 0000000..9738367 --- /dev/null +++ b/demos/vllm_integration/vllm_files/test_vllm_aw.py @@ -0,0 +1,82 @@ +import os +import math +from openai import OpenAI + +def calculate_perplexity(logprobs): + total_log_prob = 0 + token_count = 0 + + for token_logprobs in logprobs[1:]: + if token_logprobs: + total_log_prob += list(token_logprobs.values())[0].logprob + token_count += 1 + + if token_count == 0: + return float('inf') + + print(token_count) + perplexity = math.exp(-total_log_prob / token_count) + return perplexity + +def calc_perplexity_serve(logprobs, trim=1): + logprobs = logprobs[:-trim] + logprobs = [x for x in logprobs if x is not None] + print(f"{len(logprobs)=}") + return math.exp(-sum(logprobs) / len(logprobs)) + +if __name__ == '__main__': + use_served_model = True + model_size = 70 # [8, 70] + PATH = f"/data/rahul/models/Meta-Llama-3.1-{model_size}B/" + CKPT_PATH = f'/data/rahul/checkpoints/{model_size}b.pt' + openai_api_base = "http://0.0.0.0:8000/v1" + + os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7" + os.environ["LOLCATS_ADAPTER_PATH"] = CKPT_PATH + + prompts = [ + "I'm Michael:- 
3rd-year Computer Science PhD student advised by Chris Ré.- Labmate at HazyResearch, Stanford AI Lab, Stanford Machine Learning Group. I currently work on deep learning architectures for expressive + efficient long sequence modeling, and using these advances to enable learning from new tasks and data types and I also care about deep learning robustness. I received my A.B. in",
+        # Statistics and Computer Science at Harvard in 2020. I'm grateful to have"
+        # "The 2024 Summer Paralympics (French: Jeux paralympiques d'été de 2024), also known as the Paris 2024 Paralympic Games, and branded as Paris 2024, is the 17th Summer Paralympic Games, an international multi-sport parasports event governed by the International Paralympic Committee, being held in Paris, France, from 28 August to 8 September 2024. These games mark the first time Paris is hosting the Summer Paralympics and the second time that France is hosting the new ",
+        # "Manchester United Football Club, commonly referred to as Man United (often Man United (often stylised as Man Utd), or simply United, is a Man United (often stylised as Man Utd), or simply United, is a professional football club based in Old Trafford, Greater Manchester, England. They compete in the Premier League, the top tier of English football. Nicknamed the Red Devils, they were founded as Newton Heath LYR Football Club in 1878, but changed their name to Manchester United in 1902. After a spell playing in Clayton, Manchester, the club moved to their current stadium, Old Trafford, in 1910. "
+    ]
+
+    if use_served_model:
+        client = OpenAI(base_url=openai_api_base, api_key="EMPTY")
+        models = client.models.list()
+        model = models.data[0].id
+        tokens = 3
+        outputs = client.completions.create(
+            model=model,
+            prompt=prompts,
+            temperature=0,
+            logprobs=1,
+            max_tokens=tokens,
+            seed=0,
+            echo=True,
+        )
+        for prompt, choice in zip(prompts, outputs.choices):
+            logprobs = choice.logprobs.token_logprobs
+            print(f"Prompt: {len(prompt.split())}\n{prompt}")
+            print(f"Completion: {choice.text.replace(prompt, '')}")
+            print(f'Perplexity: {calc_perplexity_serve(logprobs, trim=tokens)}')
+            print("\n")
+    else:
+
+        os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"
+
+        from vllm import ModelRegistry, LLM, SamplingParams
+
+        from src.model.modeling_llama_vllm import LlamaLolcatsForCausalLM
+        ModelRegistry.register_model("LlamaLolcatsForCausalLM", LlamaLolcatsForCausalLM)
+        sampling_params = SamplingParams(temperature=0.0, logprobs=1, prompt_logprobs=1, min_tokens=1, max_tokens=1)
+        llm = LLM(model=PATH, tensor_parallel_size=8, enforce_eager=True)
+        outputs = llm.generate(
+            prompts,
+            sampling_params,
+        )
+        # Compute perplexity over the prompt tokens for each output.
+        for output in outputs:
+            print(f"Perplexity: {calculate_perplexity(output.prompt_logprobs):.4f}")
+
+    # Print the outputs.
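+
+# Added usage note (not part of the original script; server address, ports, and paths are
+# taken from the constants above and may need to be adapted):
+# - With use_served_model=True, this script expects an OpenAI-compatible vLLM server to
+#   already be listening at http://0.0.0.0:8000/v1, serving the LoLCATS model (i.e., with
+#   LlamaLolcatsForCausalLM registered via vLLM's ModelRegistry in the serving process) and
+#   with the LOLCATS_ADAPTER_PATH environment variable visible to that server process.
+# - With use_served_model=False, the model is built in-process via vllm.LLM with
+#   tensor_parallel_size=8, so 8 visible GPUs are expected (see CUDA_VISIBLE_DEVICES above).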