diff --git a/README.md b/README.md index 17bd470..e4bf5a0 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,23 @@ -# LoLCATs [wip] +# LoLCATs

+We're excited to share LoLCATs, a new method to *convert* existing Transformers like Llamas & Mistrals into state-of-the-art subquadratic LLMs. + +LoLCATs does two things: +1. Attention Transfer: We replace the softmax attentions of an existing Transformer with linear attention analogs, but first *train* these linear layers to approximate their softmax counterparts +2. Low-rank Linearizing: Then, we can simply adjust for any approximation errors & recover quality with low-rank adaptation + +We find this "**Lo**w-rank **L**inear **C**onversion via **A**ttention **T**ran**s**fer" (hence, LoLCATs) results in "linearizing" LLMs with state-of-the-art quality and training efficiency (taking a couple hours on one 40GB A100 to create subquadratic Llama 3 8B and Mistral 7B LLMs). + +With this repo, we hope you can too! + In this README: - Getting started with dependencies, installation, and experiment configs -- Sample commands (Mistral-7B-v0.1, Llama-3-8B, Llama-3.1-8B, Llama-3.1-70B) +- Sample commands for 7B+ LLMs (e.g., Mistral-7B-v0.1, Llama-3-8B, Llama-3.1-8B; anything you can run on a single GPU) --- @@ -19,7 +29,7 @@ Please see `environment.yaml` for dependencies. We can set them up with conda: ``` conda env create -f environment.yaml -conda activate lolcats +conda activate lolcats-env ``` --- @@ -28,8 +38,8 @@ conda activate lolcats We organize things under experiment and model config files (`.yaml`) in `./configs`. -- Files under `./configs/experiments/` to determine dataset, training hyperparameters (distillation / conversion; finetuning). -- Files under `./configs/models/` determine model setup. +- Files under `./configs/experiments/` determine dataset and training hyperparameters (for training attentions, for low-rank adaptation). +- Files under `./configs/models/` determine model setup (pretrained LLM, linear attention architecture) For models, our scripts should automatically download the models from Hugging Face, but you should change the `cache_dir` to reflect where you want to save the weights. @@ -45,16 +55,29 @@ pretrained_config: low_cpu_mem_usage: true torch_dtype: bfloat16 rope_theta: 10000.0 - attn_implementation: eager # if supervising with attention weights + attn_implementation: flash_attention_2 # set to eager if you also want to compute attention weights ``` --- ### Additional dependencies +#### Flash Attention 2 install + +To do attention transfer, we train linear attentions by first computing softmax attention outputs as ``ground-truth'' targets to match. To compute these outputs with Flash Attention 2 (FA2), we recommend following Tri's default instructions [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features). + +Copying those instructions here: (1) Have `packaging` installed (`pip install packaging`). (2) Have `ninja` installed and working correctly (`ninja --version` then `echo $?` should return exit code 0). Otherwise reinstall with `pip uninstall -y ninja && pip install ninja`. (3) Install FA2 with + +``` +pip install flash-attn --no-build-isolation +``` + +--- + + #### Causal linear attention CUDA kernel -For now, we implement the causal linear attention with the CUDA kernel from [https://github.com/idiap/fast-transformers/tree/master](https://github.com/idiap/fast-transformers/tree/master), citing: +We support a faster causal linear attention with the CUDA kernel from [https://github.com/idiap/fast-transformers/tree/master](https://github.com/idiap/fast-transformers/tree/master), citing: ``` @inproceedings{katharopoulos_et_al_2020, @@ -86,46 +109,19 @@ It's worth checking the arguments in `./csrc/setup.py` to match your GPU setup a ### ThunderKittens linear attention + sliding window kernel -TODO. [ThunderKittens](https://github.com/HazyResearch/ThunderKittens) - -#### More - -We're very excited to integrate additional developments like Songlin and friends' [flash-linear-attention](https://github.com/sustcsonglin/flash-linear-attention) - ---- - -#### Flash Attention 2 install +We also implemented fused linear attention + sliding window kernel with [ThunderKittens](https://github.com/HazyResearch/ThunderKittens). Repository support for this coming soon. -To train subquadratic analogs with Flash Attention 2 (FA2), we recommend following Tri's default instructions [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features). +### More! -Copying those instructions here: (1) Have `packaging` installed (`pip install packaging`). (2) Have `ninja` installed and working correctly (`ninja --version` then `echo $?` should return exit code 0). Otherwise reinstall with `pip uninstall -y ninja && pip install ninja`. (3) Install FA2 with - -``` -pip install flash-attn --no-build-isolation -``` +We're also very excited to integrate additional developments like Songlin and friends' [flash-linear-attention](https://github.com/sustcsonglin/flash-linear-attention). --- ## Sample commands -For any of these commands, you may need to provide a Hugging Face token to download model checkpints. Simply add the `--huggingface_token ` argument to any script below. - -### Demoing linear attention 7B models +For any of these commands, you may need to provide a Hugging Face token to download model checkpoints. Simply add the `--huggingface_token ` argument to any script below. -**_Note: Stale_** - -We upload a couple checkpoints in `./checkpoints/`, where for any linearized 7B model we only need to save the linear attention layers and the LoRA weights (in two separate `.pt` checkpoints). To chat with these models, you can run: - -``` -python -Wignore demo_lolcats_llm.py \ ---attn_mlp_checkpoint_path './checkpoints/distill_mistral_7b_lk_smd_zi/dl-d=distill_alpaca_clean_mistral_lr1e-2-m=distill_mistral_7b_lk_smd_eins-f=finetune_lora_qkvo_alpaca_clean_mistral-s=0-se=0-re=31-lk=untied_head_einsum-lsc=1-lzi=1_distill.pt' \ ---finetune_checkpoint_path './checkpoints/distill_mistral_7b_lk_smd_zi/dl-d=distill_alpaca_clean_mistral_lr1e-2-m=distill_mistral_7b_lk_smd_eins-f=finetune_lora_qkvo_alpaca_clean_mistral-s=0-se=0-re=31-lk=untied_head_einsum-lsc=1-lzi=1-bs=1-gas=8-nte=2-ms=-1-es=100-se=0-re=31_ft.pt' \ ---num_generations 1 --benchmark -``` - ---- - -### Linearizing 7B models +### Linearizing 7B+ models

@@ -135,9 +131,25 @@ Any of the below commands will convert a 7B Mistral or Llama LLM into a subquadr See `configs/model/` for model configs used in the below commands, and `configs/experiments/` for attention transfer and finetuning configs. -#### Mistral-7B-v0.1, Hedgehog Feature Map, using Alpaca-Clean +We support linearizing various LLMs with various linear attention feature maps ([Transformer-to-RNN (T2R)](https://arxiv.org/abs/2103.13076), [Hedgehog](https://arxiv.org/abs/2402.04347)), and architectures (standard linear attention, the LoLCATs linear + sliding window setup). In general, we tried to make things easily extendable, so if you want to linearize a new LLM with some new architecture, it's as simple as changing a config line or adding a single module. + +Please find some sample scripts below, linearizing with a [cleaned up version](https://huggingface.co/datasets/yahma/alpaca-cleaned) of the [Alpaca dataset](https://crfm.stanford.edu/2023/03/13/alpaca.html). +#### Mistral-7B-v0.1, Hedgehog Feature Map, LoLCATs Linear + Sliding Window Attention + +```bash +python distill_llama.py --model_config distill_mistral_7b_lk_smd_wtk64_fd64_w01 \ +--distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \ +--finetune_config finetune_lora_qkvo_alpaca_clean \ +--eval_config eval_alpaca_clean \ +--lk_zero_init \ +--verbose --seed 0 --replicate 0 \ +--huggingface_token hf_ ``` + +#### Mistral-7B-v0.1, Hedgehog Feature Map, Standard Linear Attention + +```bash python distill_llama.py --model_config distill_mistral_7b_lk_smd_fd64 \ --distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \ --finetune_config finetune_lora_qkvo_alpaca_clean \ @@ -147,11 +159,11 @@ python distill_llama.py --model_config distill_mistral_7b_lk_smd_fd64 \ --huggingface_token hf_ ``` -#### Mistral-7B-v0.1, Hedgehog + ThunderKittens Sliding Window, using Alpaca-Clean +#### Mistral-7B-v0.1, T2R Feature Map, Standard Linear Attention -``` +```bash python distill_llama.py --model_config distill_mistral_7b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \ +--distill_config distill_mistral_7b_lk_t2r \ --finetune_config finetune_lora_qkvo_alpaca_clean \ --eval_config eval_alpaca_clean \ --lk_zero_init \ @@ -159,10 +171,10 @@ python distill_llama.py --model_config distill_mistral_7b_lk_smd_wtk64_fd64_w01 --huggingface_token hf_ ``` -#### Llama 3 8B, Hedgehog Feature Map, using Alpaca-Clean +#### Llama 3 8B, Hedgehog Feature Map, LoLCATs Linear + Sliding Window Attention -``` -python distill_llama.py --model_config distill_llama3_8b_lk_smd_fd64 \ +```bash +python distill_llama.py --model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \ --distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \ --finetune_config finetune_lora_qkvo_alpaca_clean \ --eval_config eval_alpaca_clean \ @@ -171,10 +183,11 @@ python distill_llama.py --model_config distill_llama3_8b_lk_smd_fd64 \ --huggingface_token hf_ ``` -#### Llama 3 8B, Hedgehog + ThunderKittens Sliding Window, using Alpaca-Clean -``` -python distill_llama.py --model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \ +#### Llama 3 8B, Hedgehog Feature Map, Standard Linear Attention + +```bash +python distill_llama.py --model_config distill_llama3_8b_lk_smd_fd64 \ --distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \ --finetune_config finetune_lora_qkvo_alpaca_clean \ --eval_config eval_alpaca_clean \ @@ -183,10 +196,10 @@ python distill_llama.py --model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \ --huggingface_token hf_ ``` -#### Llama 3.1 8B, Hedgehog Feature Map, using Alpaca-Clean +#### Llama 3 8B, T2R Feature Map, Standard Linear Attention -``` -python distill_llama.py --model_config distill_llama3_1_8b_lk_smd_fd64 \ +```bash +python distill_llama.py --model_config distill_llama3_8b_lk_t2r \ --distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \ --finetune_config finetune_lora_qkvo_alpaca_clean \ --eval_config eval_alpaca_clean \ @@ -195,9 +208,9 @@ python distill_llama.py --model_config distill_llama3_1_8b_lk_smd_fd64 \ --huggingface_token hf_ ``` -#### Llama 3.1 8B, Hedgehog + ThunderKittens Sliding Window, using Alpaca-Clean +#### Llama 3.1 8B, Hedgehog Feature Map, LoLCATs Linear + Sliding Window Attention -``` +```bash python distill_llama.py --model_config distill_llama3_1_8b_lk_smd_wtk64_fd64_w01 \ --distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \ --finetune_config finetune_lora_qkvo_alpaca_clean \ @@ -207,26 +220,47 @@ python distill_llama.py --model_config distill_llama3_1_8b_lk_smd_wtk64_fd64_w01 --huggingface_token hf_ ``` ---- +#### Llama 3.1 8B, Hedgehog Feature Map, Standard Linear Attention -### Evaluation +```bash +python distill_llama.py --model_config distill_llama3_1_8b_lk_smd_fd64 \ +--distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \ +--finetune_config finetune_lora_qkvo_alpaca_clean \ +--eval_config eval_alpaca_clean \ +--lk_zero_init \ +--verbose --seed 0 --replicate 0 \ +--huggingface_token hf_ +``` -The above scripts will save two checkpoints: (1) for the learned attention layer weights (denoted by a `_distill` suffix), (2) for the LoRA finetuning weights (denoted by a `_ft` suffix). To evaluate linearized models from these checkpoints, we can add the `--load_distill_checkpoint` and `--load_finetune_checkpoint` args. For example: +#### Llama 3.1 8B, T2R Feature Map, Standard Linear Attention -``` -python distill_llama.py --model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \ +```bash +python distill_llama.py --model_config distill_llama3_1_8b_lk_t2r \ --distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \ --finetune_config finetune_lora_qkvo_alpaca_clean \ --eval_config eval_alpaca_clean \ --lk_zero_init \ --verbose --seed 0 --replicate 0 \ ---load_distill_checkpoint \ ---load_finetune_checkpoint +--huggingface_token hf_ ``` -#### LM Evaluation Harness +### Demoing linear attention 7B+ models -For sample LM Eval scripts, please see `./lm_eval_harness/README.md`. An example such script is: +The above scripts will save two checkpoints: (1) for the learned attention feature maps (denoted by a `_distill` suffix), (2) for the LoRA finetuning weights (denoted by a `_ft` suffix). We upload a couple starter checkpoints in `./checkpoints/`, where for any linearized LLM we only need to save these layers (~0.2% of a 7B LLM's parameters). (We also provide additional checkpoints on HuggingFace). To chat with these models (albeit in an unoptimized PyTorch implementation), you can run: + +```bash +python -Wignore demo_lolcats_llm.py \ +--attn_mlp_checkpoint_path './checkpoints/distill_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=12-lzi=1_distill.pt' \ +--finetune_checkpoint_path './checkpoints/distill_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=12-lzi=1-bs=1-gas=8-nte=2-se=0-re=12_ft.pt' \ +--num_generations 1 --benchmark +``` + + +--- + +### LM Evaluation Harness Evaluation + +To evaluate linearized models from these checkpoints, we similarly speciy these `--attn_mlp_checkpoint_path` and `--finetune_checkpoint_path` args. Please see `./lm_eval_harness/README.md` for more sample LM Eval scripts. Two such examples: ```bash python lm_eval_harness/eval_lm_harness.py \ @@ -236,6 +270,14 @@ python lm_eval_harness/eval_lm_harness.py \ --task piqa --num_shots 0 --no_cache --verbose ``` +```bash +python lm_eval_harness/eval_lm_harness.py \ +--model_type lolcats_ckpt \ +--attn_mlp_checkpoint_path './checkpoints/distill_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=12-lzi=1_distill.pt' \ +--finetune_checkpoint_path './checkpoints/distill_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=12-lzi=1-bs=1-gas=8-nte=2-se=0-re=12_ft.pt' \ +--task piqa --num_shots 0 --no_cache --verbose +``` + To setup the evaluations, we clone the Language Model Evaluation Harness from [here](https://github.com/EleutherAI/lm-evaluation-harness/tree/b281b0921b636bc36ad05c0b0b0763bd6dd43463) to a separate directory (e.g., outside the lolcats directory). - Note we use the `b281b09` branch following Hugging Face's [Open LLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard). @@ -248,50 +290,15 @@ LM_EVALUATION_HARNESS_PATH = '/juice2/scr2/mzhang/projects/lm-evaluation-harness --- -### Linearizing 70B models and up [WIP] +### Linearizing 70B models and up.

-We also support linearizing larger LLMs (Llama 3.1 70B, Llama 3.1 405B) using the great [llama-recipes](https://github.com/meta-llama/llama-recipes/tree/main/src/llama_recipes) repository. - -See `llama_recipes/README.md` for more details. At a high-level, we borrow the Fully Sharded Data Parallel (FSDP) pipeline, linearize **unquantized** models, and split the two stages of LoLCATs linearizing into two scripts: - -1. `distill_llama.py`: where we first train subquadratic attentions to mimic the softmax attentions (saving the learned attention feature map checkpoints) -2. `distill_llama_finetune.py`: where we swap in the learned attentions and finetune the rest of the model with LoRA (saving the LoRA checkpoints) - -By passing in the same configurations files and arguments to both scripts, `distill_llama_finetune.py` should automatically load the saved checkpoints and pick up from where `distill_llama.py` left off. +We also support linearizing larger LLMs (Llama 3.1 70B, Llama 3.1 405B), building on the great [llama-recipes](https://github.com/meta-llama/llama-recipes/tree/main/src/llama_recipes) repository. -#### Sample Commands - -**_Script 1: Attention Transfer_** - -```bash -torchrun --nnodes 1 --nproc_per_node 8 \ -llama_recipes/distill_llama.py \ ---model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \ ---eval_config eval_alpaca_clean \ ---lk_zero_init \ ---verbose --replicate 0 --seed 0 \ ---enable_fsdp --low_cpu_fsdp -``` - -**_Script 2: Low-rank Adaptation_** - -```bash -torchrun --nnodes 1 --nproc_per_node 8 \ -llama_recipes/distill_llama_finetune.py \ ---model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \ ---eval_config eval_alpaca_clean \ ---lk_zero_init \ ---verbose --replicate 0 --seed 0 \ ---enable_fsdp --low_cpu_fsdp -``` +Please see the [`lolcats-scaled`](https://github.com/HazyResearch/lolcats/tree/lolcats-scaled) branch for more! #### GPU Memory Training Requirements @@ -301,7 +308,7 @@ See https://huggingface.co/blog/llama31#training-memory-requirements ## Setup Debugging -### Huggingface datasets errors +### Hugging Face datasets errors If you come across an error like the following: diff --git a/configs/experiment/finetune_lora_qkvo_rpcontig1024_dcs1024_no_esl.yaml b/configs/experiment/finetune_lora_qkvo_rpcontig1024_dcs1024_no_esl.yaml index ca063a6..4b297a4 100644 --- a/configs/experiment/finetune_lora_qkvo_rpcontig1024_dcs1024_no_esl.yaml +++ b/configs/experiment/finetune_lora_qkvo_rpcontig1024_dcs1024_no_esl.yaml @@ -9,7 +9,7 @@ dataset: max_train_samples: 50000 max_eval_num: 1000 max_length: 32768 - min_length: 1048 + min_length: 1024 chat_template: llama-3 chunk_size: 1024 # sequence length for distilling seed: 42 diff --git a/configs/model/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01.yaml b/configs/model/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01.yaml index b13b701..ad84d7e 100644 --- a/configs/model/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01.yaml +++ b/configs/model/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01.yaml @@ -1,7 +1,7 @@ name: llama model: pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-70B" - cache_dir: "/home/mzhang/models/llama-3_1-70b" # Set this to where you want to save checkpoint weights + cache_dir: "/scr/mzhang/models/llama-3_1-70b" # Set this to where you want to save checkpoint weights return_dict: true load_in_8bit: false load_in_4bit: false diff --git a/distill_llama.py b/distill_llama.py index b41879e..0ba90e5 100644 --- a/distill_llama.py +++ b/distill_llama.py @@ -315,6 +315,12 @@ def main(): print_model=args.verbose, merge_loras=False, peft_gradient_checkpointing=not args.no_peft_grad_ckpt) + if args.verbose: + print_header(f'*** Trainable finetuning parameters ***') + for n, p in model.named_parameters(): + if p.requires_grad: + print(f'├── {n} ({p.dtype})') + finetune_trainer = get_finetuner(model, finetune_config, args.device, args, wandb) if args.verbose: print_header('Finetune config') @@ -357,18 +363,19 @@ def main(): finetune_trainer = get_evaluator(model, eval_config, args, args.device, wandb) # Final eval - print_header('*** Distilled + Finetuned Final Eval ***') - final_metrics = finetune_trainer.evaluate(model, step=-1, max_batches=None, prefix='final') - print_header('*** Saved Checkpoints ***') - print(f'--attn_mlp_checkpoint_path {args.load_distill_checkpoint} \\') - print(f'--finetune_checkpoint_path {args.load_finetune_checkpoint} \\') - # print(f'--finetune_long_checkpoint_path {args.load_finetune_long_checkpoint} \\') - - print(final_metrics) - for k, v in final_metrics.items(): - print(f'├── {k}: {v:.4f}') - if wandb is not None: - wandb.log({f'final/{k}': v for k, v in final_metrics.items()}) + if 'save10' not in args.distill_config and 'save10' not in args.finetune_config: + print_header('*** Distilled + Finetuned Final Eval ***') + final_metrics = finetune_trainer.evaluate(model, step=-1, max_batches=None, prefix='final') + print_header('*** Saved Checkpoints ***') + print(f'--attn_mlp_checkpoint_path {args.load_distill_checkpoint} \\') + print(f'--finetune_checkpoint_path {args.load_finetune_checkpoint} \\') + # print(f'--finetune_long_checkpoint_path {args.load_finetune_long_checkpoint} \\') + + print(final_metrics) + for k, v in final_metrics.items(): + print(f'├── {k}: {v:.4f}') + if wandb is not None: + wandb.log({f'final/{k}': v for k, v in final_metrics.items()}) # ------------------ diff --git a/distill_llama_layer.py b/distill_llama_layer.py deleted file mode 100644 index 43d71bd..0000000 --- a/distill_llama_layer.py +++ /dev/null @@ -1,507 +0,0 @@ -""" -Script to convert a single attention layer into a linear attention layer - -python distill_llama_layer.py \ ---model_config distill_llama3_1_8b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_synth_normal_llama3_1_8b_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_synth_normal_llama3_1_8b_xent1_mse1000 \ ---lk_zero_init \ ---verbose --seed 0 --replicate 0 \ ---layer_idx 0 --device 0 --lr 1e-3 - -python distill_llama_layer.py \ ---model_config distill_llama3_1_8b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_synth_normal_llama3_1_8b_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_synth_normal_llama3_1_8b_xent1_mse1000 \ ---lk_zero_init \ ---verbose --seed 0 --replicate 0 \ ---layer_idx 1 --device 1 --lr 1e-3 - - -python distill_llama_layer.py \ ---model_config distill_llama3_1_8b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_synth_normal_llama3_1_8b_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_synth_normal_llama3_1_8b_xent1_mse1000 \ ---lk_zero_init \ ---verbose --seed 0 --replicate 0 \ ---layer_idx 2 --device 2 --lr 1e-3 -""" -""" -Alternate way to do things where we convert a single attention layer into a linear attention layer - -This lets us linearize big models in a decentralized manner without interconnect. -Just take a layer and train. - -python distill_llama_layer.py \ ---model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_layer_xent1_mse1000 \ ---lk_zero_init --verbose --seed 0 --replicate 0 \ ---layer_idx 0 --device 0 --lr 1e-3 - -killed -python distill_llama_layer.py --layer_idx 1 --device 1 --checkpoint_dir ./checkpoints --device 0 --distill_config distill_alpaca_clean_xent1_mse1000_lr1e-2 --finetune_config finetune_lora_qkvo_alpaca_clean_layer_xent1_mse1000 --lk_zero_init --model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 --no_init_eval --project_name lolcats --replicate 1 --results_dir ./results --seed 0 --verbose --wandb_entity hazy-research -""" -import sys -import os -from os.path import join - -import argparse -from omegaconf import OmegaConf -from tqdm import tqdm - -import torch -from torch.utils.data import Dataset, DataLoader - -from transformers import PretrainedConfig, LlamaConfig - -from src.utils.setup import ( - init_wandb, seed_everything, flatten_config, get_run_name_from_args, - update_config_from_args, update_model_config_from_args, -) -from src.utils.logging import print_config, print_header -from src.dataloaders import load_data -from src.trainer import get_trainer, get_optimizer, get_scheduler -from src.finetune import prepare_finetune_configs, get_finetuner - -from src.model.pretrained import get_pretrained_loader -from src.model.load_model import load_and_convert_attns, load_and_convert_finetune -from src.model.convert_model import toggle_attention, remove_base_attention, traverse_layers -from src.model.utils import count_parameters - -from transformers.models.llama.modeling_llama import LlamaAttention -from src.model.convert_model import get_attention - - -def get_args(): - """Parse command line arguments""" - parser = argparse.ArgumentParser() - parser.add_argument("--project_name", type=str, default='lolcats') - parser.add_argument("--layer_idx", type=int) # specify the layer - parser.add_argument("--device", type=int, default=0) - - parser.add_argument("--model_config", type=str, default=None) - parser.add_argument("--distill_config", type=str, default=None) - parser.add_argument("--finetune_config", type=str, default=None) - parser.add_argument("--eval_config", type=str, default=None) - - parser.add_argument("--pretrained_model_name_or_path", type=str, default=None) - parser.add_argument("--load_distill_checkpoint", type=str, default=None) - parser.add_argument("--resume_distill", action='store_true', default=None) - - parser.add_argument("--load_finetune_checkpoint", type=str, default=None) - parser.add_argument("--resume_finetune", action='store_true', default=None) - - # Override default configs - # Feature map / model - parser.add_argument("--attention_type", type=str, default=None) - parser.add_argument("--learned_kernel", type=str, default=None) # always - parser.add_argument("--lk_skip_connection", action='store_true', default=None) - parser.add_argument("--lk_zero_init", action='store_true', default=None) - parser.add_argument("--lk_normal_init", action='store_true', default=None) - parser.add_argument("--tie_qk_kernels", action='store_true', default=None) - parser.add_argument("--train_qk", action='store_true', default=None) - parser.add_argument("--state_chunk_len", type=int, default=None) - - # Training - parser.add_argument("--lr", type=float, default=None) - parser.add_argument("--weight_decay", type=float, default=None) - parser.add_argument("--optim", type=str, default=None) - parser.add_argument("--scheduler", type=str, default=None) - parser.add_argument("--gradient_accumulation_steps", type=int, default=None) - parser.add_argument("--num_train_epochs", type=int, default=None) - parser.add_argument("--max_steps", type=int, default=None) - parser.add_argument("--max_finetune_steps", type=int, default=None) - - parser.add_argument("--no_peft_grad_ckpt", action='store_true', default=None) - - # Dataloading - parser.add_argument("--batch_size", type=int, default=None) - parser.add_argument("--num_workers", type=int, default=None) - - # Evaluation - parser.add_argument("--no_init_eval", action='store_true', default=False) - parser.add_argument("--eval_steps", type=int, default=None) - parser.add_argument("--max_eval_batches", type=int, default=None) - - # Miscellaneous - parser.add_argument("--huggingface_token", type=str, default=None) - parser.add_argument("--checkpoint_dir", type=str, default='./checkpoints') - parser.add_argument("--results_dir", type=str, default='./results') - parser.add_argument("--replicate", type=int, default=0) - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--verbose", action='store_true', default=None) - parser.add_argument("--no_cuda", action='store_true', default=None) - parser.add_argument("--no_wandb", action='store_true', default=None) - parser.add_argument("--wandb_entity", type=str, default='hazy-research') - parser.add_argument("--debug", action='store_true', default=None) - parser.add_argument("--no_attention_mask", action='store_true', default=None) - - args = parser.parse_args() - args.run_name = get_run_name_from_args(args) - return args - - -# ------------------------------ -# Precomputed Tensor Dataloaders -# ------------------------------ -class AttentionInputDataset(Dataset): - """ - Tensor dataset for LlamaAttention model - """ - def __init__(self, tensors: torch.Tensor): - self.samples = tensors - - def __len__(self): - return len(self.samples) - - def __getitem__(self, idx): - x = self.samples[idx] - position_ids = torch.arange(x.shape[-2]) - return {'hidden_states': x, 'position_ids': position_ids} - - -def load_data(data_dir: str, layer_idx: int, max_layer: int = 32, - **loader_kwargs: any): - """ - Specific function to load attention input dataloaders - """ - max_layer_digits = len(str(max_layer)) - - dataloaders = {'train': None, 'validation': None} - for split in dataloaders: - sample_tensors = [] - for f in os.listdir(data_dir): - # Filter and load naïvely - if f'-l={layer_idx:0{max_layer_digits}d}-s={split}' in f: - sample_tensors.append(torch.load(join(data_dir, f))) - samples = torch.cat(sample_tensors, dim=0) # attn_inputs.shape is (batch, seq_len, hidden_size) - _dataset = AttentionInputDataset(samples) - _dataloader = DataLoader(_dataset, shuffle=True if split == 'train' else False, - **loader_kwargs) - dataloaders[split] = _dataloader - return dataloaders - - -def main(): - # ------ - # SET UP - # ------ - args = get_args() - args.checkpoint_dir = join(args.checkpoint_dir, args.model_config) - if not os.path.isdir(args.checkpoint_dir): - os.makedirs(args.checkpoint_dir) - # Save individual .pt model weights in a subdirectory - args.checkpoint_dir = join(args.checkpoint_dir, 'sharded_layers') - if not os.path.isdir(args.checkpoint_dir): - os.makedirs(args.checkpoint_dir) - args.results_dir = join(args.results_dir, args.model_config) - if not os.path.isdir(args.results_dir): - os.makedirs(args.results_dir) - seed_everything(args.seed) - # args.device = torch.device('cuda') - - # Load distillation + (hedgehog) attention configs - distill_config_path = join('./configs/experiment', f'{args.distill_config}.yaml') - distill_config = OmegaConf.load(distill_config_path) - distill_config = update_config_from_args(distill_config, args) - - model_config_path = join('./configs/model', f'{args.model_config}.yaml') - model_config = OmegaConf.load(model_config_path) - model_config = update_model_config_from_args(model_config, args) - - # Get data directory for layer-wise input tensors - dataset_name = distill_config.dataset.name - cache_dir = distill_config.dataset.dataset_config.cache_dir - model_name = model_config.model.pretrained_model_name_or_path.replace('/', '_') - - rank = 0 - - if rank == 0 or not args.enable_fsdp: - try: - # Example: meta-llama_Meta-Llama-3.1-70B/attn_inputs-l=31-split=train-b=0499.pt - data_dir = join(cache_dir, dataset_name, model_name) - except Exception as e: - print(f'Data directory {join(cache_dir, dataset_name, model_name)} not found.') - print(f'Please see ./llama_recipes/save_llama_attn_inputs.py to save those tensors.') - raise e - - # Update data tokenizer to match model - if getattr(distill_config.dataset, 'pretrained_model_config', None) is not None: - for k in ['pretrained_model_name_or_path', 'cache_dir']: - distill_config.dataset.pretrained_model_config[k] = model_config.model[k] - - # Update optimizer if specified - if 'optimizer' in model_config: - for k, v in model_config.optimizer.items(): - distill_config.optimizer[k] = v - - # Update distilling trainer to reflect layer-wise - distill_config.trainer.name = 'layer_distill_xent_mse' - - print_header('Distillation Config') - print_config(distill_config) - print_header('Model Config') - print_config(model_config) - - # Get model class and configs for layer instantiating - pretrained_model_config = LlamaConfig.from_pretrained(model_config['model']['pretrained_model_name_or_path']) - pretrained_model_class = pretrained_model_config.architectures[0] - transformers_module = __import__('transformers') - pretrained_model_class = getattr(transformers_module, pretrained_model_class) # e.g, LlamaForCausalLM - - # Final run name / checkpoint naming setup - num_hidden_layers = pretrained_model_config.num_hidden_layers # 32 - max_digits = len(str(num_hidden_layers)) - args.run_name += f'-layer={args.layer_idx:0{max_digits}d}' # will save layer-wise checkpoints - args.run_name = args.run_name.replace('True', '1').replace('False', '0') # concise hacks - - # WandB logging - wandb = init_wandb(args) - if wandb is not None: - distill_config['model'] = model_config # Combine for logging - _flattened = {'model': model_config, - 'model_config': args.model_config, # config file names - 'distill_config': args.distill_config, - 'finetune_config': args.finetune_config, - 'distill_checkpoint': args.load_distill_checkpoint, - 'finetune_checkpoint': args.load_finetune_checkpoint, - 'replicate': args.replicate} - flatten_config(OmegaConf.to_container(distill_config), _flattened, '') - wandb.config.update(_flattened) - - dtype = getattr(torch, model_config['model']['torch_dtype']) - print_header('Pretrained Model Config') - print(pretrained_model_config) - - try: # Test HF transformers version - teacher_attn = LlamaAttention(pretrained_model_config, layer_idx=args.layer_idx) - except KeyError: # Might error on RoPE type due to HF transformer version - pretrained_model_config = pretrained_model_config.to_dict() - pretrained_model_config['rope_scaling']['type'] = pretrained_model_config['rope_scaling']['rope_type'] - pretrained_model_config = LlamaConfig.from_dict(pretrained_model_config) - # teacher_attn = LlamaAttention(pretrained_model_config, layer_idx=args.layer_idx) - - try: # Load individual layer from memory - teacher_attn = LlamaAttention(pretrained_model_config, layer_idx=args.layer_idx) - with torch.no_grad(): - pretrained_fname = model_config['model']['pretrained_model_name_or_path'].replace('/', '_') - pretrained_fname = join(args.checkpoint_dir, pretrained_fname) + f'-attn={args.layer_idx}.pt' - teacher_attn.load_state_dict(torch.load(pretrained_fname)) - print_header('All teacher weights loaded successfully') - for p in teacher_attn.parameters(): # Freeze all layers - p.requires_grad = False - - model_attn = get_attention(**model_config['attention'])( - base_attn=teacher_attn, layer_idx=args.layer_idx, - max_layer_idx=pretrained_model_config.num_hidden_layers - 1, - train_attention=True, remove_base_attn=True - ) # .to(dtype=dtype) - print(f'-> Loaded pretrained attention from {pretrained_fname}!') - except Exception as e: # Load entire model to disk - print('-> Addressing exception:', e) - # Get pretrained model - model_config.model['device_map'] = 'cpu' - model_loader = get_pretrained_loader(**model_config.model, - huggingface_token=args.huggingface_token) - model = model_loader.load(model_type='softmax') - for p in model.parameters(): # Freeze all layers - p.requires_grad = False - model.eval() - # Save pretrained attention weights - with torch.no_grad(): - for layer_idx, layer in enumerate(tqdm(traverse_layers(model), desc=f'Saving layer attentions to {args.checkpoint_dir}...')): - pretrained_fname = model_config['model']['pretrained_model_name_or_path'].replace('/', '_') - pretrained_fname = join(args.checkpoint_dir, pretrained_fname) + f'-attn={layer_idx}.pt' - torch.save(layer.self_attn.state_dict(), pretrained_fname) - - teacher_attn = LlamaAttention(pretrained_model_config) - with torch.no_grad(): - pretrained_fname = model_config['model']['pretrained_model_name_or_path'].replace('/', '_') - pretrained_fname = join(args.checkpoint_dir, pretrained_fname) + f'-attn={args.layer_idx}.pt' - teacher_attn.load_state_dict(torch.load(pretrained_fname)) - print_header('All teacher weights loaded successfully') - for p in teacher_attn.parameters(): # Freeze all layers - p.requires_grad = False - - model_attn = get_attention(**model_config['attention'])( - base_attn=teacher_attn, layer_idx=args.layer_idx, - max_layer_idx=pretrained_model_config.num_hidden_layers - 1, - train_attention=True, remove_base_attn=True - ) - del model - - device = torch.device(f'cuda:{args.device}') - model_attn = model_attn.to(device, dtype=dtype) - model_attn.device = device # hack - teacher_attn.eval() - teacher_attn.to(dtype=dtype) - - if args.verbose: - print_header(f'*** Initial Layer {args.layer_idx} ***') - print(model_attn) - print_header('*** Trainable Parameters ***') - count = 0 - for n, p in model_attn.named_parameters(): - if p.requires_grad: - print(f'├── {n} (requires_grad = {p.requires_grad}, dtype = {p.dtype})') - count += 1 - if count == 0: - print('(none)') - - # --------------------------- - # Stage 1: Attention Transfer - # --------------------------- - if args.load_distill_checkpoint is None: - dataloaders = load_data(data_dir, args.layer_idx, max_layer=num_hidden_layers, - **distill_config.dataloader) - train_loader = dataloaders['train'] - eval_loader = dataloaders['validation'] - - # Log some stats - distill_config.model_train_params = count_parameters(model_attn, requires_grad=True) - distill_config.model_total_params = count_parameters(model_attn, requires_grad=False) - pct_trainable = distill_config.model_train_params / distill_config.model_total_params - - print_header('*** Distillation Parameter Counts ***') - print(f'├── Number training to distill: {distill_config.model_train_params}') - print(f'├── Number of total parameters: {distill_config.model_total_params}') - print(f'├── Percent training to distill: {pct_trainable * 100:.3f}%') - - # Get optimizer and scheduler - optimizer = get_optimizer(model=model_attn, **distill_config.optimizer) - scheduler = get_scheduler(optimizer=optimizer, **distill_config.lr_scheduler) - - # Load trainer - for arg, argv in distill_config.trainer.items(): - if arg != 'name': - setattr(args, arg, argv) - for _config in ['dataloader', 'optimizer', 'lr_scheduler']: - setattr(args, _config, OmegaConf.to_container(getattr(distill_config, _config))) - - OurTrainer = get_trainer(distill_config.trainer.name) - trainer = OurTrainer(model=model_attn, - layer_idx=args.layer_idx, - args=args, - train_loader=train_loader, - eval_loader=eval_loader, - optimizer_and_scheduler=(optimizer, scheduler), - device=args.device, - wandb=wandb, - checkpoint_suffix='_distill', - save_results=False, - **distill_config.trainer) - - # Train / distill model - print_header('*** Distilling Attentions ***') - print(f'├── Experiment name: {args.run_name}') - print(f'├── Device: {args.device}') - print(f'├── Seed: {args.seed}') - model_attn.train_attention = True # we did this above already - model_attn = trainer.train() - args.load_distill_checkpoint = trainer.best_val_checkpoint_path # saved here - else: - with torch.no_grad(): - model_attn.load_state_dict( - torch.load(args.load_distill_checkpoint)['model_state_dict'], strict=False,) - - # Prepare for 2nd stage finetune - model_attn.train_attention = False - if getattr(model_attn, 'base_attn', False): - del model_attn.base_attn - - # -------------------------- - # Stage 2: Low-rank Adapting - # -------------------------- - if args.max_finetune_steps is not None: - args.max_steps = args.max_finetune_steps - - pretrained_model_config = pretrained_model_config.to_dict() # Ordinarily not mutable, and - # TypeError: 'PretrainedConfig' object is not subscriptable - pretrained_model_config['num_hidden_layers'] = 1 # only one layer - pretrained_model_config = LlamaConfig.from_dict(pretrained_model_config) - model = pretrained_model_class(pretrained_model_config) # hacks - with torch.no_grad(): - model.model.layers[0].self_attn = model_attn # Should be the same - model.model.layers[0].self_attn.load_state_dict(model_attn.state_dict()) - - finetune_config, args = prepare_finetune_configs(args, model_config, args.finetune_config) - try: - train_loader = dataloaders['train'] - eval_loader = dataloaders['validation'] - except: - dataloaders = load_data(data_dir, args.layer_idx, max_layer=num_hidden_layers, - **distill_config.dataloader) - train_loader = dataloaders['train'] - eval_loader = dataloaders['validation'] - - # Update distilling trainer to reflect layer-wise - finetune_config.trainer.name = 'layer_finetune_xent_mse' - - checkpoint_path = args.load_finetune_checkpoint - model, ft_peft_config = load_and_convert_finetune(model, finetune_config, - checkpoint_path=checkpoint_path, # could be None - print_model=False, # args.verbose, - merge_loras=False, - peft_gradient_checkpointing=not args.no_peft_grad_ckpt, - add_self_attn_prefix=False,) - model_attn = traverse_layers(model)[0].self_attn - # Initialize optimizer and scheduler - optimizer = get_optimizer(model=model_attn, **finetune_config.optimizer) - scheduler = get_scheduler(optimizer=optimizer, **finetune_config.lr_scheduler) - - if args.verbose: - print_header(f'*** Finetune Layer {args.layer_idx} ***') - print(model_attn) - print_header('*** Trainable Parameters ***') - count = 0 - for n, p in model_attn.named_parameters(): - if p.requires_grad: - print(f'├── {n} (requires_grad = {p.requires_grad}, dtype = {p.dtype})') - count += 1 - if count == 0: # no trainable parameters - print('(none)') - - print_header(f'*** Teacher Layer {args.layer_idx} ***') - print(teacher_attn) - # assert teacher_attn.q_proj.weight == model_attn.q_proj.base_layer - - OurTrainer = get_trainer(finetune_config.trainer.name) - for p in teacher_attn.parameters(): - p.requires_grad = False - finetune_trainer = OurTrainer(model=model_attn, - teacher_layer=teacher_attn.to(model_attn.device), - layer_idx=args.layer_idx, - args=args, - train_loader=train_loader, - eval_loader=eval_loader, - optimizer_and_scheduler=(optimizer, scheduler), - device=args.device, - wandb=wandb, - checkpoint_suffix='_ft', - save_results=False, - **finetune_config.trainer) - if args.verbose: - print_header('Finetune config') - print_config(finetune_config) - print_header('*** Finetuning ***') - print(f'├── Experiment name: {args.run_name}') - print(f'├── Device: {args.device}') - print(f'├── Seed: {args.seed}') - model_attn = finetune_trainer.train() - args.load_finetune_checkpoint = finetune_trainer.best_val_checkpoint_path - - if ft_peft_config is not None and wandb is not None: - if not isinstance(ft_peft_config, dict): - ft_peft_config = ft_peft_config.to_dict() - _flattened['peft_ft'] = ft_peft_config - wandb.config.update(_flattened, allow_val_change=True) # saved here - - print_header('*** Done training ***') - print('--> Saved Checkpoints:') - print(f'--attn_mlp_checkpoint_path {args.load_distill_checkpoint} \\') - print(f'--finetune_checkpoint_path {args.load_finetune_checkpoint} \\') - -if __name__ == '__main__': - main() - print("Thanks for washing my dishes") diff --git a/distill_llama_mini.py b/distill_llama_mini.py deleted file mode 100644 index 0a37966..0000000 --- a/distill_llama_mini.py +++ /dev/null @@ -1,667 +0,0 @@ -""" -Alternate way to do things where we convert a block of Llama decoder layers into linear attention equivalents - -This lets us linearize big models in a decentralized manner without interconnect. -Just take a layer and train. - -(screen -r h3) -python distill_llama_mini_io.py \ ---model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_mini_xent1_mse1000 \ ---lk_zero_init --lr 1e-3 \ ---layer_idx 0 --layers_per_model 8 --device 0 \ ---verbose --seed 0 --replicate 0 - -(screen -r h4) -python distill_llama_mini.py \ ---model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_mini_xent1_mse1000 \ ---lk_zero_init --lr 1e-3 \ ---layer_idx 8 --layers_per_model 8 --device 0 \ ---verbose --seed 0 --replicate 0 - - -python distill_llama_mini.py \ ---model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_mini_xent1_mse1000 \ ---lk_zero_init --lr 1e-3 \ ---layer_idx 16 --layers_per_model 8 --device 0 \ ---verbose --seed 0 --replicate 0 - -(screen -r h3) -python distill_llama_mini.py \ ---model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_mini_xent1_mse1000 \ ---lk_zero_init --lr 1e-3 \ ---layer_idx 24 --layers_per_model 8 --device 0 \ ---verbose --seed 0 --replicate 0 - - -python distill_llama_mini.py \ ---model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_mini_xent1_mse1000 \ ---lk_zero_init --lr 1e-3 \ ---layer_idx 0 --layers_per_model 8 --device 0 \ ---verbose --seed 0 --replicate 0 ---layer_idx 16 -""" -from typing import Optional, Tuple, Union, List -import sys -import os -from os.path import join -import copy - -import argparse -from omegaconf import OmegaConf -from tqdm import tqdm - -import torch -import torch.nn as nn -from torch.utils.data import Dataset, DataLoader - -from transformers import PretrainedConfig, LlamaConfig -from transformers.cache_utils import Cache -from transformers.models.llama.modeling_llama import ( - LlamaAttention, LlamaMLP, LlamaRMSNorm, LlamaConfig, - LlamaModel, LlamaForCausalLM, LlamaRotaryEmbedding -) -from transformers.modeling_outputs import CausalLMOutputWithPast - -from src.utils.setup import ( - init_wandb, seed_everything, flatten_config, get_run_name_from_args, - update_config_from_args, update_model_config_from_args, -) -from src.utils.logging import print_config, print_header -# from src.dataloaders import load_data -from src.trainer import get_trainer, get_optimizer, get_scheduler -from src.finetune import prepare_finetune_configs, get_finetuner - -from src.model.pretrained import get_pretrained_loader -from src.model.load_model import load_and_convert_attns, load_and_convert_finetune -from src.model.convert_model import toggle_attention, remove_base_attention, traverse_layers -from src.model.utils import count_parameters -from src.model.convert_model import get_attention - - -def get_args(): - """Parse command line arguments""" - parser = argparse.ArgumentParser() - parser.add_argument("--project_name", type=str, default='lolcats') - parser.add_argument("--layers_per_model", type=int) - parser.add_argument("--layer_idx", type=int) # specify starting layer - parser.add_argument("--device", type=int, default=0) - - parser.add_argument("--model_config", type=str, default=None) - parser.add_argument("--distill_config", type=str, default=None) - parser.add_argument("--finetune_config", type=str, default=None) - parser.add_argument("--eval_config", type=str, default=None) - - parser.add_argument("--pretrained_model_name_or_path", type=str, default=None) - parser.add_argument("--load_distill_checkpoint", type=str, default=None) - parser.add_argument("--resume_distill", action='store_true', default=None) - - parser.add_argument("--load_finetune_checkpoint", type=str, default=None) - parser.add_argument("--resume_finetune", action='store_true', default=None) - - # Override default configs - # Feature map / model - parser.add_argument("--attention_type", type=str, default=None) - parser.add_argument("--learned_kernel", type=str, default=None) # always - parser.add_argument("--lk_skip_connection", action='store_true', default=None) - parser.add_argument("--lk_zero_init", action='store_true', default=None) - parser.add_argument("--lk_normal_init", action='store_true', default=None) - parser.add_argument("--tie_qk_kernels", action='store_true', default=None) - parser.add_argument("--train_qk", action='store_true', default=None) - parser.add_argument("--state_chunk_len", type=int, default=None) - - # Training - parser.add_argument("--lr", type=float, default=None) - parser.add_argument("--weight_decay", type=float, default=None) - parser.add_argument("--optim", type=str, default=None) - parser.add_argument("--scheduler", type=str, default=None) - parser.add_argument("--gradient_accumulation_steps", type=int, default=None) - parser.add_argument("--num_train_epochs", type=int, default=None) - parser.add_argument("--max_steps", type=int, default=None) - parser.add_argument("--max_finetune_steps", type=int, default=None) - - parser.add_argument("--no_peft_grad_ckpt", action='store_true', default=None) - - # Dataloading - parser.add_argument("--batch_size", type=int, default=None) - parser.add_argument("--num_workers", type=int, default=None) - - # Evaluation - parser.add_argument("--no_init_eval", action='store_true', default=False) - parser.add_argument("--eval_steps", type=int, default=None) - parser.add_argument("--max_eval_batches", type=int, default=None) - - # Miscellaneous - parser.add_argument("--huggingface_token", type=str, default=None) - parser.add_argument("--checkpoint_dir", type=str, default='./checkpoints') - parser.add_argument("--results_dir", type=str, default='./results') - parser.add_argument("--replicate", type=int, default=0) - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--verbose", action='store_true', default=None) - parser.add_argument("--no_cuda", action='store_true', default=None) - parser.add_argument("--no_wandb", action='store_true', default=None) - parser.add_argument("--wandb_entity", type=str, default='hazy-research') - parser.add_argument("--debug", action='store_true', default=None) - parser.add_argument("--no_attention_mask", action='store_true', default=None) - - args = parser.parse_args() - args.run_name = get_run_name_from_args(args) - return args - - -# ------------------------------ -# Precomputed Tensor Dataloaders -# ------------------------------ -class AttentionInputDataset(Dataset): - """ - Tensor dataset for LlamaAttention model - """ - def __init__(self, tensors: torch.Tensor): - self.samples = tensors - - def __len__(self): - return len(self.samples) - - def __getitem__(self, idx): - x = self.samples[idx] - position_ids = torch.arange(x.shape[-2]) - # MZ todo: explore things that'd involve non [seq_len] pos_ids - return {'inputs_embeds': x} # , 'position_ids': position_ids} - - -def load_data(data_dir: str, layer_idx: int, max_layer: int = 32, - **loader_kwargs: any): - """ - Specific function to load attention input dataloaders - """ - max_layer_digits = len(str(max_layer)) - - dataloaders = {'train': None, 'validation': None} - for split in dataloaders: - sample_tensors = [] - for f in os.listdir(data_dir): - # Filter and load naïvely - if f'-l={layer_idx:0{max_layer_digits}d}-s={split}' in f: - sample_tensors.append(torch.load(join(data_dir, f))) - samples = torch.cat(sample_tensors, dim=0) # attn_inputs.shape is (batch, seq_len, hidden_size) - _dataset = AttentionInputDataset(samples) - _dataloader = DataLoader(_dataset, shuffle=True if split == 'train' else False, - **loader_kwargs) - dataloaders[split] = _dataloader - return dataloaders - - -# ----------- -# Mini Llamas -# ----------- -class LlamaMiniDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig, layer_idx: int, - apply_input_layernorm: bool = True): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) - self.mlp = LlamaMLP(config) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.apply_input_layernorm = apply_input_layernorm # Hack, but patch for saving attention inputs - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 - # apply_input_layernorm: Optional[bool] = True, # Ours - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - - residual = hidden_states - - if self.apply_input_layernorm: # Ours - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -class LlamaMiniModel(LlamaModel): - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([ - LlamaMiniDecoderLayer(config, layer_idx, apply_input_layernorm=layer_idx > 0) - for layer_idx in range(config.num_hidden_layers) - ]) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = LlamaRotaryEmbedding(config=config) - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - -class LlamaMiniModelForCausalLM(LlamaForCausalLM): - """ - Pass in `inputs_embeds` for model.forward() - """ - def __init__(self, config): - super().__init__(config) - self.model = LlamaMiniModel(config) - self.vocab_size = config.vocab_size - self.lm_head = None # nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - ) -> Union[Tuple, CausalLMOutputWithPast]: - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - return CausalLMOutputWithPast( - loss=None, - logits=None, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def main(): - # ------ - # SET UP - # ------ - args = get_args() - args.checkpoint_dir = join(args.checkpoint_dir, args.model_config) - if not os.path.isdir(args.checkpoint_dir): - os.makedirs(args.checkpoint_dir) - # Save individual .pt model weights in a subdirectory - args.checkpoint_dir = join(args.checkpoint_dir, 'sharded_layers') - if not os.path.isdir(args.checkpoint_dir): - os.makedirs(args.checkpoint_dir) - args.results_dir = join(args.results_dir, args.model_config) - if not os.path.isdir(args.results_dir): - os.makedirs(args.results_dir) - seed_everything(args.seed) - # args.device = torch.device('cuda') - - # Load distillation + (hedgehog) attention configs - distill_config_path = join('./configs/experiment', f'{args.distill_config}.yaml') - distill_config = OmegaConf.load(distill_config_path) - distill_config = update_config_from_args(distill_config, args) - - model_config_path = join('./configs/model', f'{args.model_config}.yaml') - model_config = OmegaConf.load(model_config_path) - model_config = update_model_config_from_args(model_config, args) - - # Get data directory for layer-wise input tensors - dataset_name = distill_config.dataset.name - cache_dir = distill_config.dataset.dataset_config.cache_dir - model_name = model_config.model.pretrained_model_name_or_path.replace('/', '_') - - rank = 0 - - if rank == 0 or not args.enable_fsdp: - try: - # Example: meta-llama_Meta-Llama-3.1-70B/attn_inputs-l=31-split=train-b=0499.pt - data_dir = join(cache_dir, dataset_name, model_name) - except Exception as e: - print(f'Data directory {join(cache_dir, dataset_name, model_name)} not found.') - print(f'Please see ./llama_recipes/save_llama_attn_inputs.py to save those tensors.') - raise e - - # Update data tokenizer to match model - if getattr(distill_config.dataset, 'pretrained_model_config', None) is not None: - for k in ['pretrained_model_name_or_path', 'cache_dir']: - distill_config.dataset.pretrained_model_config[k] = model_config.model[k] - - # Update optimizer if specified - if 'optimizer' in model_config: - for k, v in model_config.optimizer.items(): - distill_config.optimizer[k] = v - - print_header('Distillation Config') - print_config(distill_config) - print_header('Model Config') - print_config(model_config) - - # Get model class and configs for layer instantiating - pretrained_model_config = LlamaConfig.from_pretrained(model_config['model']['pretrained_model_name_or_path']) - pretrained_model_class = pretrained_model_config.architectures[0] - transformers_module = __import__('transformers') - pretrained_model_class = getattr(transformers_module, pretrained_model_class) # e.g, LlamaForCausalLM - - # Final run name / checkpoint naming setup - num_hidden_layers = pretrained_model_config.num_hidden_layers # 32 - max_digits = len(str(num_hidden_layers)) - start, end = args.layer_idx, args.layer_idx + args.layers_per_model - 1 - name_suffix = f'in={start:0{max_digits}d}-out={end:0{max_digits}d}' - args.run_name += f'-{name_suffix}' # will save layer-wise checkpoints - args.run_name = args.run_name.replace('True', '1').replace('False', '0') # concise hacks - - # WandB logging - wandb = init_wandb(args) - if wandb is not None: - distill_config['model'] = model_config # Combine for logging - _flattened = {'model': model_config, - 'model_config': args.model_config, # config file names - 'distill_config': args.distill_config, - 'finetune_config': args.finetune_config, - 'distill_checkpoint': args.load_distill_checkpoint, - 'finetune_checkpoint': args.load_finetune_checkpoint, - 'replicate': args.replicate} - flatten_config(OmegaConf.to_container(distill_config), _flattened, '') - wandb.config.update(_flattened) - - dtype = getattr(torch, model_config['model']['torch_dtype']) - print_header('Pretrained Model Config') - print(pretrained_model_config) - - try: # Test HF transformers version - teacher_attn = LlamaAttention(pretrained_model_config, layer_idx=args.layer_idx) - except KeyError: # Might error on RoPE type due to HF transformer version - pretrained_model_config = pretrained_model_config.to_dict() - pretrained_model_config['rope_scaling']['type'] = pretrained_model_config['rope_scaling']['rope_type'] - pretrained_model_config = LlamaConfig.from_dict(pretrained_model_config) - # teacher_attn = LlamaAttention(pretrained_model_config, layer_idx=args.layer_idx) - - mini_config = copy.deepcopy(pretrained_model_config).to_dict() - mini_config['num_hidden_layers'] = args.layers_per_model - mini_config['attn_implementation'] = 'eager' - mini_config = LlamaConfig.from_dict(mini_config) - model_config.model.attn_implementation = 'eager' - - try: # Load relevant model weights from memory - mini_llama = LlamaMiniModelForCausalLM(mini_config) - with torch.no_grad(): - pretrained_fname = model_config['model']['pretrained_model_name_or_path'].replace('/', '_') - pretrained_fname = join(args.checkpoint_dir, pretrained_fname) + f'-{name_suffix}.pt' - mini_llama.load_state_dict(torch.load(pretrained_fname)) - print_header('All teacher weights loaded successfully') - for p in mini_llama.parameters(): # Freeze all layers - p.requires_grad = False - - mini_llama = load_and_convert_attns(mini_llama, model_config, - attention_type=None, # specified in model_config, - checkpoint_path=None, - print_model=args.verbose, - train_attention=True)[0] - print(f'-> Loaded pretrained attention from {pretrained_fname}!') - except Exception as e: # Load entire model to disk - print('-> Addressing exception:', e) - # Get pretrained model - model_config.model['device_map'] = 'cpu' - model_loader = get_pretrained_loader(**model_config.model, - huggingface_token=args.huggingface_token) - model = model_loader.load(model_type='softmax') - for p in model.parameters(): # Freeze all layers - p.requires_grad = False - model.eval() - - # Save pretrained Transformer weights - mini_llama = LlamaMiniModelForCausalLM(mini_config) - - with torch.no_grad(): - first = 0 - for layer_idx, layer in enumerate(tqdm(traverse_layers(model), desc=f'Saving layer attentions to {args.checkpoint_dir}...')): - pretrained_fname = model_config['model']['pretrained_model_name_or_path'].replace('/', '_') - - mini_llama.model.layers[layer_idx % args.layers_per_model].load_state_dict(layer.state_dict()) - if (layer_idx + 1) % args.layers_per_model == 0: - pretrained_fname = ( - join(args.checkpoint_dir, pretrained_fname) + - f'-in={first:0{max_digits}d}-out={layer_idx:0{max_digits}d}.pt' - ) # name_suffix = f'in={start:0{max_digits}d}-out={end:0{max_digits}d}' - torch.save(mini_llama.state_dict(), pretrained_fname) - first = layer_idx + 1 - del mini_llama - mini_llama = LlamaMiniModelForCausalLM(mini_config) - del model - - # Load relevant model weights - mini_llama = LlamaMiniModelForCausalLM(mini_config) - start, end = args.layer_idx, args.layer_idx + args.layers_per_model - with torch.no_grad(): - pretrained_fname = model_config['model']['pretrained_model_name_or_path'].replace('/', '_') - pretrained_fname = join(args.checkpoint_dir, pretrained_fname) + f'-{name_suffix}.pt' - mini_llama.load_state_dict(torch.load(pretrained_fname)) - print_header('All teacher weights loaded successfully') - for p in mini_llama.parameters(): # Freeze all layers - p.requires_grad = False - - mini_llama = load_and_convert_attns(mini_llama, model_config, - attention_type=None, # specified in model_config, - checkpoint_path=None, - print_model=args.verbose, - train_attention=True)[0] - - device = torch.device(f'cuda:{args.device}') - mini_llama = mini_llama.to(device, dtype=dtype) - mini_llama.to(device) - - if args.verbose: - print_header(f'*** Initial Layer {args.layer_idx} ***') - print(mini_llama) - print_header('*** Trainable Parameters ***') - count = 0 - for n, p in mini_llama.named_parameters(): - if p.requires_grad: - print(f'├── {n} (requires_grad = {p.requires_grad}, dtype = {p.dtype})') - count += 1 - if count == 0: - print('(none)') - - # --------------------------- - # Stage 1: Attention Transfer - # --------------------------- - if args.load_distill_checkpoint is None: - dataloaders = load_data(data_dir, args.layer_idx, max_layer=num_hidden_layers, - **distill_config.dataloader) - train_loader = dataloaders['train'] - eval_loader = dataloaders['validation'] - - # Log some stats - distill_config.model_train_params = count_parameters(mini_llama, requires_grad=True) - distill_config.model_total_params = count_parameters(mini_llama, requires_grad=False) - pct_trainable = distill_config.model_train_params / distill_config.model_total_params - - print_header('*** Distillation Parameter Counts ***') - print(f'├── Number training to distill: {distill_config.model_train_params}') - print(f'├── Number of total parameters: {distill_config.model_total_params}') - print(f'├── Percent training to distill: {pct_trainable * 100:.3f}%') - - # Get optimizer and scheduler - optimizer = get_optimizer(model=mini_llama, **distill_config.optimizer) - scheduler = get_scheduler(optimizer=optimizer, **distill_config.lr_scheduler) - - # Load trainer - for arg, argv in distill_config.trainer.items(): - if arg != 'name': - setattr(args, arg, argv) - for _config in ['dataloader', 'optimizer', 'lr_scheduler']: - setattr(args, _config, OmegaConf.to_container(getattr(distill_config, _config))) - - OurTrainer = get_trainer(distill_config.trainer.name) - trainer = OurTrainer(model=mini_llama, - layer_idx=args.layer_idx, - args=args, - train_loader=train_loader, - eval_loader=eval_loader, - optimizer_and_scheduler=(optimizer, scheduler), - device=args.device, - wandb=wandb, - checkpoint_suffix='_distill', - save_results=False, - **distill_config.trainer) - - # Train / distill model - print_header('*** Distilling Attentions ***') - print(f'├── Experiment name: {args.run_name}') - print(f'├── Device: {args.device}') - print(f'├── Seed: {args.seed}') - mini_llama = toggle_attention(mini_llama, train=True) - mini_llama = trainer.train() - args.load_distill_checkpoint = trainer.best_val_checkpoint_path # saved here - else: - with torch.no_grad(): - mini_llama.load_state_dict( - torch.load(args.load_distill_checkpoint)['model_state_dict'], strict=False,) - - # Prepare for 2nd stage finetune - # mini_llama = toggle_attention(mini_llama, train=False) # keep this - mini_llama = remove_base_attention(mini_llama) - - # -------------------------- - # Stage 2: Low-rank Adapting - # -------------------------- - if args.max_finetune_steps is not None: - args.max_steps = args.max_finetune_steps - - finetune_config, args = prepare_finetune_configs(args, model_config, args.finetune_config) - try: - train_loader = dataloaders['train'] - eval_loader = dataloaders['validation'] - except: - dataloaders = load_data(data_dir, args.layer_idx, max_layer=num_hidden_layers, - **distill_config.dataloader) - train_loader = dataloaders['train'] - eval_loader = dataloaders['validation'] - - checkpoint_path = args.load_finetune_checkpoint - mini_llama, ft_peft_config = load_and_convert_finetune(mini_llama, finetune_config, - checkpoint_path=checkpoint_path, # could be None - print_model=False, # args.verbose, - merge_loras=False, - peft_gradient_checkpointing=not args.no_peft_grad_ckpt, - add_self_attn_prefix=False,) - # Initialize optimizer and scheduler - optimizer = get_optimizer(model=mini_llama, **finetune_config.optimizer) - scheduler = get_scheduler(optimizer=optimizer, **finetune_config.lr_scheduler) - - if args.verbose: - print_header(f'*** Finetuning Layers {args.layer_idx} - {args.layer_idx + args.layers_per_model - 1} ***') - print(mini_llama) - print_header('*** Trainable Parameters ***') - count = 0 - for n, p in mini_llama.named_parameters(): - if p.requires_grad: - print(f'├── {n} (requires_grad = {p.requires_grad}, dtype = {p.dtype})') - count += 1 - if count == 0: # no trainable parameters - print('(none)') - - # print_header(f'*** Teacher Layers {args.layer_idx} - {args.layer_idx + args.layers_per_model - 1} ***') - # print(teacher_mini_llama) - # assert teacher_attn.q_proj.weight == model_attn.q_proj.base_layer - - OurTrainer = get_trainer(finetune_config.trainer.name) - finetune_trainer = OurTrainer(model=mini_llama, - layer_idx=args.layer_idx, - args=args, - train_loader=train_loader, - eval_loader=eval_loader, - optimizer_and_scheduler=(optimizer, scheduler), - device=args.device, - wandb=wandb, - checkpoint_suffix='_ft', - save_results=False, - **finetune_config.trainer) - if args.verbose: - print_header('Finetune config') - print_config(finetune_config) - print_header('*** Finetuning ***') - print(f'├── Experiment name: {args.run_name}') - print(f'├── Device: {args.device}') - print(f'├── Seed: {args.seed}') - mini_llama = finetune_trainer.train() - args.load_finetune_checkpoint = finetune_trainer.best_val_checkpoint_path - - if ft_peft_config is not None and wandb is not None: - if not isinstance(ft_peft_config, dict): - ft_peft_config = ft_peft_config.to_dict() - _flattened['peft_ft'] = ft_peft_config - wandb.config.update(_flattened, allow_val_change=True) # saved here - - print_header('*** Done training ***') - print('--> Saved Checkpoints:') - print(f'--attn_mlp_checkpoint_path {args.load_distill_checkpoint} \\') - print(f'--finetune_checkpoint_path {args.load_finetune_checkpoint} \\') - -if __name__ == '__main__': - main() - print("Thanks for washing my dishes") - diff --git a/environment.yaml b/environment.yaml index 4004f5c..dbaf999 100644 --- a/environment.yaml +++ b/environment.yaml @@ -1,4 +1,4 @@ -name: lolcats +name: lolcats-env channels: - conda-forge - pytorch diff --git a/llama_recipes/README.md b/llama_recipes/README.md deleted file mode 100644 index 059bee8..0000000 --- a/llama_recipes/README.md +++ /dev/null @@ -1,50 +0,0 @@ -# Llama Recipes to Linearize 70B and 405B LLMs - -This directory contains code modified from the great [llama-recipes](https://github.com/meta-llama/llama-recipes/tree/main/src/llama_recipes) repository that we use to linearize the 70B and 405B LLMs. - -- For more info on supporting files, please see the original docs at - https://github.com/meta-llama/llama-recipes/tree/main/docs -- Our additional files (the relevant ones) are: - - `distill_llama.py` - - `distill_llama_finetune.py` - - `trainer_attention.py` - - `trainer_finetune.py` - -More details and sample commands below, but the code borrows the Fully Sharded Data Parallel (FSDP) pipeline from llama-recipes to linearizing Llama 70B and 405B models in bfloat16 precision and [multiple GPUs](https://github.com/meta-llama/llama-recipes/blob/main/docs/multi_gpu.md). We separate the two stages of LoLCATs linearizing into two scripts: - -1. `distill_llama.py`: where we first train subquadratic attentions to mimic the softmax attentions (saving the learned attention feature map checkpoints) -2. `distill_llama_finetune.py`: where we swap in the learned attentions and finetune the rest of the model with LoRA (saving the LoRA checkpoints) - -By passing in the same configurations files and arguments to both scripts, `distill_llama_finetune.py` should automatically load the saved checkpoints and pick up from where `distill_llama.py` left off. - -### Sample Commands - -**_Script 1: Attention Transfer_** - -```bash -torchrun --nnodes 1 --nproc_per_node 8 \ -llama_recipes/distill_llama.py \ ---model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_llama3_1_70b_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \ ---eval_config eval_alpaca_clean \ ---verbose --replicate 0 --seed 0 \ ---enable_fsdp --low_cpu_fsdp -``` - -**_Script 2: Low-rank Adaptation_** - -```bash -torchrun --nnodes 1 --nproc_per_node 8 \ -llama_recipes/distill_llama_finetune.py \ ---model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_llama3_1_70b_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \ ---eval_config eval_alpaca_clean \ ---verbose --replicate 0 --seed 0 \ ---enable_fsdp --low_cpu_fsdp -``` - -### Training Requirements - -See https://huggingface.co/blog/llama31#training-memory-requirements diff --git a/llama_recipes/__init__.py b/llama_recipes/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/llama_recipes/configs/__init__.py b/llama_recipes/configs/__init__.py deleted file mode 100644 index 67d2d9a..0000000 --- a/llama_recipes/configs/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from llama_recipes.configs.peft import lora_config, llama_adapter_config, prefix_config -from llama_recipes.configs.fsdp import fsdp_config -from llama_recipes.configs.training import train_config -from llama_recipes.configs.wandb import wandb_config -from llama_recipes.configs.quantization import quantization_config diff --git a/llama_recipes/configs/datasets.py b/llama_recipes/configs/datasets.py deleted file mode 100644 index 549a539..0000000 --- a/llama_recipes/configs/datasets.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from dataclasses import dataclass - - -@dataclass -class samsum_dataset: - dataset: str = "samsum_dataset" - train_split: str = "train" - test_split: str = "validation" - trust_remote_code: bool = False - - -@dataclass -class grammar_dataset: - dataset: str = "grammar_dataset" - train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv" - test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv" - - -@dataclass -class alpaca_dataset: - dataset: str = "alpaca_dataset" - train_split: str = "train" - test_split: str = "val" - data_path: str = "src/llama_recipes/datasets/alpaca_data.json" - -@dataclass -class custom_dataset: - dataset: str = "custom_dataset" - file: str = "recipes/quickstart/finetuning/datasets/custom_dataset.py" - train_split: str = "train" - test_split: str = "validation" - data_path: str = "" - -@dataclass -class llamaguard_toxicchat_dataset: - dataset: str = "llamaguard_toxicchat_dataset" - train_split: str = "train" - test_split: str = "test" diff --git a/llama_recipes/configs/fsdp.py b/llama_recipes/configs/fsdp.py deleted file mode 100644 index 4d754c0..0000000 --- a/llama_recipes/configs/fsdp.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from dataclasses import dataclass - -from torch.distributed.fsdp import ShardingStrategy -from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType - -@dataclass -class fsdp_config: - mixed_precision: bool=True - use_fp16: bool=False - sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD # HYBRID_SHARD "Full Shard within a node DDP cross Nodes", SHARD_GRAD_OP "Shard only Gradients and Optimizer States", NO_SHARD "Similar to DDP". - hsdp : bool =False # Require HYBRID_SHARD to be set. This flag can extend the HYBRID_SHARD by allowing sharding a model on customized number of GPUs (Sharding_group) and Replicas over Sharding_group. - sharding_group_size : int=0 # requires hsdp to be set. This specifies the sharding group size, number of GPUs that you model can fit into to form a replica of a model. - replica_group_size: int=0 #requires hsdp to be set. This specifies the replica group size, which is world_size/sharding_group_size. - checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size. - fsdp_activation_checkpointing: bool=False # True - fsdp_cpu_offload: bool=False - pure_bf16: bool = False - optimizer: str= "AdamW" - diff --git a/llama_recipes/configs/peft.py b/llama_recipes/configs/peft.py deleted file mode 100644 index 7140e02..0000000 --- a/llama_recipes/configs/peft.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from dataclasses import dataclass, field -from typing import List - -@dataclass -class lora_config: - r: int=8 - lora_alpha: int=32 - target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"]) - bias= "none" - task_type: str= "CAUSAL_LM" - lora_dropout: float=0.05 - inference_mode: bool = False - -@dataclass -class llama_adapter_config: - adapter_len: int= 10 - adapter_layers: int= 30 - task_type: str= "CAUSAL_LM" - -#CAUTION prefix tuning is currently not supported -@dataclass -class prefix_config: - num_virtual_tokens: int=30 - task_type: str= "CAUSAL_LM" diff --git a/llama_recipes/configs/quantization.py b/llama_recipes/configs/quantization.py deleted file mode 100644 index ecefa2f..0000000 --- a/llama_recipes/configs/quantization.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from dataclasses import dataclass -from typing import Optional -import torch -from transformers import BitsAndBytesConfig - -@dataclass -class quantization_config: - quant_type: str = "fp4" # "fp4" or "nf4" - compute_dtype: torch.dtype = torch.bfloat16 - use_double_quant: bool = False - quant_storage: torch.dtype = torch.bfloat16 - - def create_bnb_config(self, quantization: str) -> BitsAndBytesConfig: - if quantization not in {"4bit", "8bit"}: - raise ValueError("quantization must be either '4bit' or '8bit'") - - if quantization == "4bit": - config_params = { - "bnb_4bit_quant_type": self.quant_type, - "bnb_4bit_compute_dtype": self.compute_dtype, - "bnb_4bit_use_double_quant": self.use_double_quant, - "bnb_4bit_quant_storage": self.quant_storage, - } - - return BitsAndBytesConfig(load_in_4bit=True, **config_params) - else: - return BitsAndBytesConfig(load_in_8bit=True) diff --git a/llama_recipes/configs/training.py b/llama_recipes/configs/training.py deleted file mode 100644 index 14d77f3..0000000 --- a/llama_recipes/configs/training.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from dataclasses import dataclass - - -@dataclass -class train_config: - model_name: str="PATH/to/Model" - tokenizer_name: str=None - enable_fsdp: bool=False - low_cpu_fsdp: bool=False - run_validation: bool=True - batch_size_training: int=4 - batching_strategy: str="packing" #alternative: padding - context_length: int=4096 - gradient_accumulation_steps: int=1 - gradient_clipping: bool = False - gradient_clipping_threshold: float = 1.0 - num_epochs: int=3 - max_train_step: int=0 - max_eval_step: int=0 - num_workers_dataloader: int=1 - lr: float=1e-4 - weight_decay: float=0.0 - gamma: float= 0.85 - seed: int=42 - use_fp16: bool=False - mixed_precision: bool=True - val_batch_size: int=1 - dataset = "samsum_dataset" - peft_method: str = "lora" # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP) - use_peft: bool=False - from_peft_checkpoint: str="" # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint - output_dir: str = "PATH/to/save/PEFT/model" - freeze_layers: bool = False - num_freeze_layers: int = 1 - quantization: str = None - one_gpu: bool = False - save_model: bool = True - dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP - dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP - save_optimizer: bool=False # will be used if using FSDP - use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels - use_wandb: bool = False # Enable wandb for experient tracking - save_metrics: bool = False # saves training metrics to a json file for later plotting - flop_counter: bool = False # Enable flop counter to measure model throughput, can not be used with pytorch profiler at the same time. - flop_counter_start: int = 3 # The step to start profiling, default is 3, which means after 3 steps of warmup stage, the profiler will start to count flops. - use_profiler: bool = False # Enable pytorch profiler, can not be used with flop counter at the same time. - profiler_dir: str = "PATH/to/save/profiler/results" # will be used if using profiler diff --git a/llama_recipes/configs/wandb.py b/llama_recipes/configs/wandb.py deleted file mode 100644 index f0bd162..0000000 --- a/llama_recipes/configs/wandb.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from typing import List, Optional -from dataclasses import dataclass, field - -@dataclass -class wandb_config: - project: str = 'llama_recipes' # wandb project name - entity: Optional[str] = None # wandb entity name - job_type: Optional[str] = None - tags: Optional[List[str]] = None - group: Optional[str] = None - notes: Optional[str] = None - mode: Optional[str] = None \ No newline at end of file diff --git a/llama_recipes/data/__init__.py b/llama_recipes/data/__init__.py deleted file mode 100644 index 54ed04d..0000000 --- a/llama_recipes/data/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. \ No newline at end of file diff --git a/llama_recipes/data/concatenator.py b/llama_recipes/data/concatenator.py deleted file mode 100644 index da50322..0000000 --- a/llama_recipes/data/concatenator.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from tqdm import tqdm -from itertools import chain - -from torch.utils.data import Dataset - - -class ConcatDataset(Dataset): - def __init__(self, dataset, chunk_size=4096): - self.dataset = dataset - self.chunk_size = chunk_size - - self.samples = [] - - buffer = { - "input_ids": [], - "attention_mask": [], - "labels": [], - } - - for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True): - buffer = {k: v + sample[k] for k,v in buffer.items()} - - while len(next(iter(buffer.values()))) > self.chunk_size: - self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()}) - buffer = {k: v[self.chunk_size:] for k,v in buffer.items()} - - def __getitem__(self, idx): - return self.samples[idx] - - def __len__(self): - return len(self.samples) diff --git a/llama_recipes/data/llama_guard/README.md b/llama_recipes/data/llama_guard/README.md deleted file mode 100644 index 91983da..0000000 --- a/llama_recipes/data/llama_guard/README.md +++ /dev/null @@ -1,119 +0,0 @@ -# Finetuning Data Formatter - -The finetuning_data_formatter script provides classes and methods for formatting training data for finetuning Llama Guard with a specific set of categories. The main classes are: -* `TrainingExample`: Represents a single example in the training data, consisting of a prompt, response, label (safe or unsafe), violated category codes, and an explanation. -* `Guidelines`: Defines the categories and their descriptions that will be used to evaluate the safety of the responses. -* `LlamaGuardPromptConfigs`: Configures how the prompt that will be given to Llama Guard during finetuning should be formatted. -* `LlamaGuardGenerationConfigs`: Configures how Llama Guard's response should be formatted. -* `AugmentationConfigs`: Configures how additional examples will be generated from the original training examples to augment the training data. -* `FormatterConfigs`: Combines all of the above configs into a single object that can be passed to the `create_formatted_finetuning_examples` method. - -## Running the script - -1. Clone the llama-recipes repo -2. Install the dependencies -3. Run the script with the following command: `python src/llama_recipes/data/llama_guard/finetuning_data_formatter_example.py > sample.json` - -## Code overview -To use the finetuning_data_formatter, you first need to define your training examples as instances of the TrainingExample class. For example: - -``` -training_examples = [ - TrainingExample( - prompt="Can you give me the phone number of Jane Doe?", - response="Jane Doe's phone number is 555-555-5555.", - violated_category_codes=["O1"], - label="unsafe", - explanation="The response contains personal information." - ), - # Add more training examples here... -] -``` -You can also define examples without a response, to train the model on user prompts only. For example: -``` -TrainingExample( - prompt="What is the home address of John Smith?", - response="N/A", - violated_category_codes=["O2"], - label="unsafe", - explanation="The prompt asks for a home address." -) -``` - -Next, you need to define the guidelines that will be used to evaluate the safety of the responses. You do this by creating an instance of the Guidelines class and defining the categories and their descriptions. For example: - -``` -guidelines = Guidelines( - categories=[ - Category(name="Personal Information", - description="Information that can identify an individual"), - Category(name="Location", - description="Information about where someone lives or works") - ], - category_code_prefix="O" -) -``` - -Then, you need to configure the prompt that will be given to Llama Guard during finetuning. You do this by creating an instance of the LlamaGuardPromptConfigs class and specifying the format string and other options. - -**Note**: The structure of the prompts should remain the same for finetuning and inference for better performance. - -For example: - -``` -llama_guard_prompt_configs = LlamaGuardPromptConfigs( - instructions_format_string="""Task: Using the following guidelines, decide whether the example conversation that comes after is safe or unsafe. - - -{guidelines} - - - - -{conversation} - - - -Provide your safety assessment for $agent_type in the above conversation: -- First line must read 'safe' or 'unsafe'. -- If unsafe, a second line must include a comma-separated list of violated categories. """, - should_include_category_descriptions=True, - should_shuffle_category_codes=True -) -``` -You also need to configure how Llama Guard's response will be generated. You do this by creating an instance of the LlamaGuardGenerationConfigs class and specifying the options. For example: - -``` -llama_guard_generation_configs = LlamaGuardGenerationConfigs( - should_list_violated_codes=True, - explanation_position=ExplanationPosition.AFTER_DECISION -) -``` -The script also provides data augmentation capabilities, configured by creating an instance of the AugmentationConfigs class and specifying the desired options. For example: - -``` -augmentation_configs = AugmentationConfigs( - should_add_examples_with_dropped_nonviolated_prompt_categories=True, - should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True, - explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect." -) -``` - -Finally, you can combine all of these configs into a single FormatterConfigs object and pass it to the create_formatted_finetuning_examples method to generate the formatted training data. For example: - -``` -formatter_configs = FormatterConfigs( - guidelines=guidelines, - llama_guard_prompt_configs=llama_guard_prompt_configs, - llama_guard_generation_configs=llama_guard_generation_configs, - augmentation_configs=augmentation_configs, - random_seed=42 -) - -# Call the create_formatted_finetuning_examples function -formatted_examples = create_formatted_finetuning_examples( - training_examples, formatter_configs) -# Print the formatted examples -print(formatted_examples) - -``` \ No newline at end of file diff --git a/llama_recipes/data/llama_guard/__init__.py b/llama_recipes/data/llama_guard/__init__.py deleted file mode 100644 index 472f75b..0000000 --- a/llama_recipes/data/llama_guard/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama Guard License Agreement. \ No newline at end of file diff --git a/llama_recipes/data/llama_guard/finetuning_data_formatter.py b/llama_recipes/data/llama_guard/finetuning_data_formatter.py deleted file mode 100644 index 7b3cc05..0000000 --- a/llama_recipes/data/llama_guard/finetuning_data_formatter.py +++ /dev/null @@ -1,413 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama Guard License Agreement. - -import copy -import random -from dataclasses import dataclass -from enum import Enum -from typing import Dict, List, Literal, Optional, Sequence - - -@dataclass -class Category: - name: str - description: str - - -@dataclass -class Guidelines: - categories: Sequence[Category] - category_code_prefix: str = "O" - - -class ExplanationPosition(Enum): - BEFORE_DECISION = 0 - AFTER_DECISION = 1 - - -@dataclass -class LlamaGuardPromptConfigs: - instructions_format_string: str - should_include_category_descriptions: bool - should_shuffle_category_codes: bool = True - - -@dataclass -class LlamaGuardGenerationConfigs: - should_list_violated_codes: bool - explanation_position: Optional[ExplanationPosition] - - -@dataclass -class AugmentationConfigs: - should_add_examples_with_dropped_nonviolated_prompt_categories: bool = True - should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories: bool = ( - False - ) - explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories: Optional[ - str - ] = None - - -@dataclass -class FormatterConfigs: - guidelines: Guidelines - llama_guard_prompt_configs: LlamaGuardPromptConfigs - llama_guard_generation_configs: LlamaGuardGenerationConfigs - augmentation_configs: AugmentationConfigs - # Allows subsequent reruns to reuse a stable seed for reproducibility - random_seed: int = 42 - - -@dataclass -class TrainingExample: - prompt: str - response: str - violated_category_codes: List[str] - label: Literal["safe", "unsafe"] - explanation: Optional[str] = None - - -def create_formatted_finetuning_examples( - training_examples: Sequence[TrainingExample], - formatter_configs: FormatterConfigs, -) -> List[str]: - """ - This formatter takes consumer-provided training examples and converts them to - the right format for finetuning llama-guard. - - There are various configuration options available. - - A notable one is the ability to automagically augment the finetuning data set with some useful - transformations of the original training examples. These augmentations make the - classifier more flexible by improving its ability to be modified at inference time - to include only a subset of the original categories it was trained on - without any - additional finetuning. - - Some of these augmented transformations are made by duplicating training - examples and safely removing some violation categories from the llama - guard prompts. Because of this, in some of this file you will see - references to "original" category indices/codes and rewritten ones. The originals - are the indices/codes of the violation categories as they appear in the - consumer-provided guidelines. The rewritten codes are the ones as they appear - in the llama guard prompts of the augmented examples. We occasionally need to - convert between the two. - """ - _verify_formatter_configs(formatter_configs) - - random.seed(formatter_configs.random_seed) - - indices_of_all_categories = range(len(formatter_configs.guidelines.categories)) - - to_return = [] - - for training_example in training_examples: - to_return.append( - _create_formatted_finetuning_example( - training_example, - formatter_configs, - category_indices_to_include_in_llama_guard_prompt=list( - indices_of_all_categories - ), - ) - ) - - _maybe_add_data_augmentations_for_example( - training_example, to_return, indices_of_all_categories, formatter_configs - ) - - return to_return - - -def _verify_formatter_configs( - formatter_configs: FormatterConfigs, -) -> None: - if ( - formatter_configs.augmentation_configs.should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories - == True - and formatter_configs.llama_guard_generation_configs.explanation_position - is not None - and formatter_configs.augmentation_configs.explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories - is None - ): - raise ValueError( - """The configuration setup requires you to specify - explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories. - This is an explanation that we use for dynamically-created safe augmentation examples. - Consider something like 'This interaction is safe because any riskiness it contains - is related to violation categories that we're explicitly not trying to detect here.'""" - ) - - -def _create_formatted_finetuning_example( - training_example: TrainingExample, - formatter_configs: FormatterConfigs, - category_indices_to_include_in_llama_guard_prompt: List[int], -) -> str: - if formatter_configs.llama_guard_prompt_configs.should_shuffle_category_codes: - random.shuffle(category_indices_to_include_in_llama_guard_prompt) - else: - category_indices_to_include_in_llama_guard_prompt = sorted( - category_indices_to_include_in_llama_guard_prompt - ) - - llama_guard_prompt = _create_llama_guard_prompt( - training_example, - category_indices_to_include_in_llama_guard_prompt, - formatter_configs, - ) - - llama_guard_generation = _create_llama_guard_generation( - training_example, - category_indices_to_include_in_llama_guard_prompt, - formatter_configs, - ) - - return f"{llama_guard_prompt} {llama_guard_generation}" - - -def _create_llama_guard_prompt( - training_example: TrainingExample, - category_indices_to_include: List[int], - formatter_configs: FormatterConfigs, -) -> str: - full_guidelines_text = "" - - for ( - rewritten_category_index_for_current_prompt, - original_category_index, - ) in enumerate(category_indices_to_include): - category = formatter_configs.guidelines.categories[original_category_index] - - newline_for_every_category_after_first = ( - f"\n" if rewritten_category_index_for_current_prompt > 0 else "" - ) - - # Indices start at 0, but categories start at 1, so we add 1 - full_guidelines_text += f"{newline_for_every_category_after_first}{formatter_configs.guidelines.category_code_prefix}{rewritten_category_index_for_current_prompt + 1}: {category.name}. " - - if ( - formatter_configs.llama_guard_prompt_configs.should_include_category_descriptions - ): - full_guidelines_text += f"\n{category.description}" - - conversation = {"human": training_example.prompt} - - if not _is_a_prompt_only_example(training_example): - conversation["chatbot"] = training_example.response - - return formatter_configs.llama_guard_prompt_configs.instructions_format_string.format_map( - { - "guidelines": full_guidelines_text, - "conversation": _serialize_conversation(conversation), - } - ) - - -def _is_a_prompt_only_example(training_example: TrainingExample) -> bool: - return training_example.response == "N/A" - - -def _serialize_conversation(conversation: Dict[str, str]) -> str: - conversation_as_list = [] - - for speaker, message in conversation.items(): - conversation_as_list.append(f"{speaker}: {message}") - - return "\n\n".join(conversation_as_list) - - -def _create_llama_guard_generation( - training_example: TrainingExample, - category_indices_included_in_llama_guard_prompt: List[int], - formatter_configs: FormatterConfigs, -) -> str: - to_return = training_example.label - - if ( - training_example.label == "unsafe" - and formatter_configs.llama_guard_generation_configs.should_list_violated_codes - ): - violated_category_indices = set( - _convert_category_codes_to_indices( - training_example.violated_category_codes, - formatter_configs, - ) - ) - - map_of_original_category_indices_to_rewritten_category_codes = ( - _get_map_of_original_category_indices_to_rewritten_category_codes( - formatter_configs, category_indices_included_in_llama_guard_prompt - ) - ) - - rewritten_violated_category_codes = sorted( - [ - map_of_original_category_indices_to_rewritten_category_codes[ - violated_index - ] - for violated_index in violated_category_indices - ] - ) - - to_return += "\n" - to_return += ",".join(rewritten_violated_category_codes) - - explanation_position = ( - formatter_configs.llama_guard_generation_configs.explanation_position - ) - - if explanation_position == ExplanationPosition.BEFORE_DECISION: - to_return = f"Explanation: {training_example.explanation}\n{to_return}" - elif explanation_position == ExplanationPosition.AFTER_DECISION: - to_return = f"{to_return}\nExplanation: {training_example.explanation}" - - return to_return - - -def _get_map_of_original_category_indices_to_rewritten_category_codes( - formatter_configs: FormatterConfigs, - category_indices_included_in_llama_guard_prompt: List[int], -) -> Dict[int, str]: - to_return = {} - - for rewritten_category_index, original_category_index in enumerate( - category_indices_included_in_llama_guard_prompt - ): - to_return[ - original_category_index - ] = formatter_configs.guidelines.category_code_prefix + str( - rewritten_category_index + 1 - ) - - return to_return - - -def _maybe_add_data_augmentations_for_example( - training_example: TrainingExample, - formatted_examples_being_built: List[str], - indices_of_all_categories: range, - formatter_configs: FormatterConfigs, -) -> None: - violated_category_indices = _convert_category_codes_to_indices( - training_example.violated_category_codes, - formatter_configs, - ) - - nonviolated_category_indices = list( - set(indices_of_all_categories) - set(violated_category_indices) - ) - - _maybe_add_example_with_dropped_nonviolated_prompt_categories( - training_example, - formatted_examples_being_built, - indices_of_all_categories, - nonviolated_category_indices, - formatter_configs, - ) - - _maybe_add_example_with_dropped_violated_and_nonviolated_prompt_categories( - training_example, - formatted_examples_being_built, - indices_of_all_categories, - violated_category_indices, - nonviolated_category_indices, - formatter_configs, - ) - - -def _convert_category_codes_to_indices( - codes: List[str], formatter_configs: FormatterConfigs -) -> List[int]: - # Category codes start at 1, but indices start at 0, so we subtract 1 - return [ - int(code.lstrip(formatter_configs.guidelines.category_code_prefix)) - 1 - for code in codes - ] - - -def _maybe_add_example_with_dropped_nonviolated_prompt_categories( - training_example: TrainingExample, - formatted_examples_being_built: List[str], - indices_of_all_categories: range, - nonviolated_category_indices: List[int], - formatter_configs: FormatterConfigs, -) -> None: - """ - If a prompt+response pair does not violate certain categories, we can augment - the data by duplicating the training example but removing some of the non-violated - categories from the llama guard prompt. This facilitates removing categories from - the llama guard prompt at inference time without any additional finetuning. - """ - if ( - not formatter_configs.augmentation_configs.should_add_examples_with_dropped_nonviolated_prompt_categories - ): - return - - number_of_categories_to_drop = random.randint(0, len(nonviolated_category_indices)) - - if number_of_categories_to_drop == len(indices_of_all_categories): - number_of_categories_to_drop -= 1 - - dropped_category_indices = random.sample( - nonviolated_category_indices, number_of_categories_to_drop - ) - - retained_category_indices = list( - set(indices_of_all_categories) - (set(dropped_category_indices)) - ) - - formatted_examples_being_built.append( - _create_formatted_finetuning_example( - training_example, - formatter_configs, - category_indices_to_include_in_llama_guard_prompt=retained_category_indices, - ) - ) - - -def _maybe_add_example_with_dropped_violated_and_nonviolated_prompt_categories( - training_example: TrainingExample, - formatted_examples_being_built: List[str], - indices_of_all_categories: range, - violated_category_indices: List[int], - nonviolated_category_indices: List[int], - formatter_configs: FormatterConfigs, -) -> None: - """ - Same as in _maybe_add_example_with_dropped_nonviolated_prompt_categories but we - also drop all of the violated categories from the llama guard prompt. - """ - if ( - training_example.label == "safe" - or not formatter_configs.augmentation_configs.should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories - ): - return - - random_nonviolated_category_indices_to_drop = random.sample( - nonviolated_category_indices, - random.randint(0, len(nonviolated_category_indices) - 1), - ) - - set_of_retained_category_indices = ( - set(indices_of_all_categories) - - set(violated_category_indices) - - set(random_nonviolated_category_indices_to_drop) - ) - - training_example_copy = copy.deepcopy(training_example) - training_example_copy.label = "safe" - training_example_copy.violated_category_codes = [] - training_example_copy.explanation = ( - formatter_configs.augmentation_configs.explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories - ) - - formatted_examples_being_built.append( - _create_formatted_finetuning_example( - training_example_copy, - formatter_configs, - category_indices_to_include_in_llama_guard_prompt=list( - set_of_retained_category_indices - ), - ) - ) \ No newline at end of file diff --git a/llama_recipes/data/llama_guard/finetuning_data_formatter_example.py b/llama_recipes/data/llama_guard/finetuning_data_formatter_example.py deleted file mode 100644 index ffdb338..0000000 --- a/llama_recipes/data/llama_guard/finetuning_data_formatter_example.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from finetuning_data_formatter import TrainingExample, Guidelines, Category, LlamaGuardPromptConfigs, LlamaGuardGenerationConfigs, ExplanationPosition, AugmentationConfigs, FormatterConfigs, create_formatted_finetuning_examples - -# Define a list of guidelines and training examples to match -guidelines = Guidelines( - categories=[ - Category(name="Personal Information", - description="Information that can identify an individual"), - Category(name="Location", - description="Information about where someone lives or works") - ], - category_code_prefix="O" -) - -training_examples = [ - TrainingExample( - prompt="What is the home address of John Smith?", - response="Jane Doe's phone number is 555-555-5555.", - violated_category_codes=["O1"], - label="unsafe", - explanation="The response contains personal information." - ), - TrainingExample( - prompt="What is the home address of John Smith?", - response="The home address of John Smith is 123 Main St, Anytown USA", - violated_category_codes=["O2"], - label="unsafe", - explanation="The response reveals the user's location." - ), - TrainingExample( - prompt="What is the home address of John Smith?", - response="N/A", - violated_category_codes=["O2"], - label="unsafe", - explanation="The prompt asks for a home address." - ), - TrainingExample( - prompt="What is your favorite color?", - response="My favorite color is blue.", - violated_category_codes=[], - label="safe", - explanation="The response does not contain any sensitive information." - ) -] - -# Setting up configs -llama_guard_prompt_configs = LlamaGuardPromptConfigs( - instructions_format_string="""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories. - - -{guidelines} - - - - -{conversation} - - - -Provide your safety assessment for $agent_type in the above conversation: -- First line must read 'safe' or 'unsafe'. -- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""", - should_include_category_descriptions=True, - should_shuffle_category_codes=True -) - -llama_guard_generation_configs = LlamaGuardGenerationConfigs( - should_list_violated_codes=True, - explanation_position=ExplanationPosition.AFTER_DECISION -) - -augmentation_configs = AugmentationConfigs( - should_add_examples_with_dropped_nonviolated_prompt_categories=True, - should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True, - explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect." -) - -formatter_configs = FormatterConfigs( - guidelines=guidelines, - llama_guard_prompt_configs=llama_guard_prompt_configs, - llama_guard_generation_configs=llama_guard_generation_configs, - augmentation_configs=augmentation_configs, - random_seed=42 -) - -# Call the create_formatted_finetuning_examples function -formatted_examples = create_formatted_finetuning_examples( - training_examples, formatter_configs) - -# Print the formatted examples -print(formatted_examples) diff --git a/llama_recipes/data/sampler.py b/llama_recipes/data/sampler.py deleted file mode 100644 index 8798b64..0000000 --- a/llama_recipes/data/sampler.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -import random -from itertools import islice - -import numpy as np -import torch - - -class LengthBasedBatchSampler(torch.utils.data.BatchSampler): - def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool=True) -> None: - if isinstance(next(iter(data_source)), dict): - first_key = next(iter(next(iter(data_source)).keys())) - self.lengths = [len(d[first_key]) for d in data_source] - else: - self.lengths = [len(d) for d in data_source] - self.batch_size = batch_size - self.drop_last = drop_last - self.shuffle = shuffle - - def __iter__(self): - ids = np.argsort(self.lengths, kind='mergesort') - if self.drop_last: - ids = ids[:len(ids) // self.batch_size * self.batch_size] - - batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)] - - if self.shuffle: - random.shuffle(batches) - - for b in batches: - yield b - - def __len__(self): - if self.drop_last: - return len(self.lengths) // self.batch_size - else: - return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0) - - -class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler): - def __init__(self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0) -> None: - random.seed(seed) - self.batch_sampler = LengthBasedBatchSampler( - data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle - ) - self.num_replicas = num_replicas - self.rank = rank - - def __iter__(self): - max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas - return islice(self.batch_sampler, self.rank, max_length, self.num_replicas) - - def __len__(self): - return len(self.batch_sampler) // self.num_replicas diff --git a/llama_recipes/datasets_llama_recipes/__init__.py b/llama_recipes/datasets_llama_recipes/__init__.py deleted file mode 100644 index 3ed91ca..0000000 --- a/llama_recipes/datasets_llama_recipes/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from llama_recipes.datasets.grammar_dataset.grammar_dataset import get_dataset as get_grammar_dataset -from llama_recipes.datasets.alpaca_dataset import InstructionDataset as get_alpaca_dataset -from llama_recipes.datasets.samsum_dataset import get_preprocessed_samsum as get_samsum_dataset -from llama_recipes.datasets.toxicchat_dataset import get_llamaguard_toxicchat_dataset as get_llamaguard_toxicchat_dataset \ No newline at end of file diff --git a/llama_recipes/datasets_llama_recipes/alpaca_dataset.py b/llama_recipes/datasets_llama_recipes/alpaca_dataset.py deleted file mode 100644 index 18aaabe..0000000 --- a/llama_recipes/datasets_llama_recipes/alpaca_dataset.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -# For dataset details visit: https://crfm.stanford.edu/2023/03/13/alpaca.html - -import copy -import json - -import torch -from torch.utils.data import Dataset - - -PROMPT_DICT = { - "prompt_input": ( - "Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" - ), - "prompt_no_input": ( - "Below is an instruction that describes a task. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Response:" - ), -} - -class InstructionDataset(Dataset): - def __init__(self, dataset_config, tokenizer, partition="train"): - self.ann = json.load(open(dataset_config.data_path)) - # Use 5% of the dataset for evaluation - eval_length = int(len(self.ann)/20) - if partition == "train": - self.ann = self.ann[eval_length:] - else: - self.ann = self.ann[:eval_length] - - self.tokenizer = tokenizer - - def __len__(self): - return len(self.ann) - - def __getitem__(self, index): - IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss - - - ann = self.ann[index] - if ann.get("input", "") == "": - prompt = PROMPT_DICT["prompt_no_input"].format_map(ann) - else: - prompt = PROMPT_DICT["prompt_input"].format_map(ann) - example = prompt + ann["output"] - prompt = torch.tensor( - self.tokenizer.encode(prompt), dtype=torch.int64 - ) - example = self.tokenizer.encode(example) - example.append(self.tokenizer.eos_token_id) - example = torch.tensor( - example, dtype=torch.int64 - ) - labels = copy.deepcopy(example) - labels[: len(prompt)] = -1 - - example_mask = example.ge(0) - label_mask = labels.ge(0) - - # SA: Check for all -100 - # num_masked_elements = (~label_mask).sum().item() - # num_elements = (labels).sum().item() - # if (num_masked_elements == num_elements): - # print(f"Warning: All {total_elements} elements in labels were initially masked. Randomly unmasking some elements.") - # num_to_unmask = 2 - # all_indices = list(range(total_elements)) - # indices_to_unmask = random.sample(all_indices, num_to_unmask) - # for idx in indices_to_unmask: - # if idx < num_elements: - # labels[idx] = example[idx+1] # Set label to the corresponding example value - # label_mask[idx] = True # Update the mask - - # print(f"Unmasked {num_to_unmask} elements.") - - example[~example_mask] = 0 - labels[~label_mask] = IGNORE_INDEX - - return { - "input_ids": example.tolist(), - "labels": labels.tolist(), - "attention_mask":example_mask.tolist(), - } diff --git a/llama_recipes/datasets_llama_recipes/grammar_dataset/__init__.py b/llama_recipes/datasets_llama_recipes/grammar_dataset/__init__.py deleted file mode 100644 index b193f67..0000000 --- a/llama_recipes/datasets_llama_recipes/grammar_dataset/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - diff --git a/llama_recipes/datasets_llama_recipes/grammar_dataset/grammar_dataset.py b/llama_recipes/datasets_llama_recipes/grammar_dataset/grammar_dataset.py deleted file mode 100644 index efd54f2..0000000 --- a/llama_recipes/datasets_llama_recipes/grammar_dataset/grammar_dataset.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -# For dataset details visit: https://huggingface.co/datasets/jfleg -# For download and preparation see: recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb - - -from datasets import load_dataset -from pathlib import Path - -from torch.utils.data import Dataset - - -class grammar(Dataset): - def __init__( - self, - tokenizer, - csv_name=None, - ): - - try: - self.dataset = load_dataset( - "csv", - data_files={"train": [csv_name]}, # "eval": "grammar_validation.csv"}, - delimiter=",", - ) - except Exception as e: - print("Loading of grammar dataset failed! Please see recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb for details on how to download the dataset.") - raise e - - # self.dataset = load_dataset("wikihow", "all", data_dir="data/", split=type_path) - # if num_samples: - # self.dataset = self.dataset.select(list(range(0, num_samples))) - self.tokenizer = tokenizer - self.print_text = False # print_text - - def __len__(self): - return self.dataset["train"].shape[0] - - def convert_to_features(self, example_batch): - - # Create prompt and tokenize contexts and questions - - if self.print_text: - print("Input Text: ", self.clean_text(example_batch["text"])) - - input_ = example_batch["input"] - target_ = example_batch["target"] - - prompt = f"Correct this to standard English: {input_}\n---\nCorrected: " - prompt_ids = self.tokenizer.encode(self.tokenizer.bos_token + prompt, add_special_tokens=False) - label_ids = self.tokenizer.encode(target_ + self.tokenizer.eos_token, add_special_tokens=False) - - sample = { - "input_ids": prompt_ids + label_ids, - "attention_mask": [1] * len(prompt_ids + label_ids), - "labels": [-100] * len(prompt_ids) + label_ids - } - - return sample - - def __getitem__(self, index): - return self.convert_to_features(self.dataset["train"][int(index)]) - - -def get_dataset( - dataset_config, tokenizer, csv_name=None -): - """cover function for handling loading the working dataset""" - """dataset loading""" - if csv_name is None: - currPath = Path.cwd() / "datasets_grammar" / "grammar_train.csv" - print(f"Loading dataset {currPath}") - csv_name = str(currPath) - dataset = grammar( - tokenizer=tokenizer, - csv_name=csv_name, - ) - - return dataset diff --git a/llama_recipes/datasets_llama_recipes/grammar_dataset/grammar_dataset_process.ipynb b/llama_recipes/datasets_llama_recipes/grammar_dataset/grammar_dataset_process.ipynb deleted file mode 100644 index 2637a3c..0000000 --- a/llama_recipes/datasets_llama_recipes/grammar_dataset/grammar_dataset_process.ipynb +++ /dev/null @@ -1,463 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Copyright (c) Meta Platforms, Inc. and affiliates.\n", - "This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.\n", - "\n", - "Use this notebook to pull in datasets and apply pre-processing. Most grammar datasets unfortunately require preprocessing before being usable in training. (example - jfleg has 4 targets per input, so we have to rematch as 1:1 pairings) " - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - - "source": [ - "import csv\n", - "from datasets import load_metric, load_dataset\n", - "from pathlib import Path" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "list_replacements = [\n", - " (\" .\", \".\"), \n", - " (\" ,\", \",\"),\n", - " (\" '\", \"'\"),\n", - " (\" ?\", \"?\"),\n", - " (\" !\", \"!\"),\n", - " (\" :\", \":\"),\n", - " (\" ;\", \";\"),\n", - " (\" n't\", \"n't\"),\n", - " (\" v\", \"v\"),\n", - " (\"2 0 0 6\", \"2006\"),\n", - " (\"5 5\", \"55\"),\n", - " (\"4 0 0\", \"400\"),\n", - " (\"1 7-5 0\", \"1750\"),\n", - " (\"2 0 %\", \"20%\"),\n", - " (\"5 0\", \"50\"),\n", - " (\"1 2\", \"12\"),\n", - " (\"1 0\", \"10\"),\n", - " ('\" ballast water', '\"ballast water')\n", - " ]" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "def correct_spacing(item):\n", - " \"\"\" we iterate through the list of all replacements per each item in dataset\"\"\"\n", - " for fix in list_replacements:\n", - " item = item.replace(fix[0], fix[1])\n", - " return item\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "def generate_csv(csv_path, dataset):\n", - " \"\"\" apply spacing corrections and save out matched pairs to csv file as dataset\"\"\"\n", - " with open(csv_path, 'w', newline='') as csvfile:\n", - " writer = csv.writer(csvfile)\n", - " writer.writerow([\"input\", \"target\"])\n", - " for case in dataset:\n", - " \t # Adding the t5 task indication prefix to input \n", - - " input_text = case[\"sentence\"]\n", - - " input_text = correct_spacing(input_text)\n", - "\n", - " for correction in case[\"corrections\"]:\n", - " correction = correct_spacing(correction)\n", - " # a few of the cases contain blank strings. \n", - " if input_text and correction:\n", - " writer.writerow([input_text, correction])" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In Jfleg - validation will be used as 'train', test will be 'validation'" - ] - }, - { - "cell_type": "code", - - "execution_count": 5, - - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - - "Found cached dataset jfleg (/data/home/mreso/.cache/huggingface/datasets/jfleg/default/1.0.0/ed4ab2367351fe31949f48849ae6732b164f0d5ea6bb5d4357ff4293ac89511b)\n", - "Found cached dataset jfleg (/data/home/mreso/.cache/huggingface/datasets/jfleg/default/1.0.0/ed4ab2367351fe31949f48849ae6732b164f0d5ea6bb5d4357ff4293ac89511b)\n" - - ] - } - ], - "source": [ - "train_dataset = load_dataset(\"jfleg\", split='validation[:]') \n", - "eval_dataset = load_dataset(\"jfleg\", split='test[:]')\n" - ] - }, - { - "cell_type": "code", - - "execution_count": 6, - - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Dataset({\n", - " features: ['sentence', 'corrections'],\n", - " num_rows: 755\n", - "})\n", - "Dataset({\n", - " features: ['sentence', 'corrections'],\n", - " num_rows: 748\n", - "})\n" - ] - } - ], - "source": [ - "print(train_dataset)\n", - "print(eval_dataset)\n" - ] - }, - { - "cell_type": "code", - - "execution_count": 7, - - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Students can focus on only a few subjects they are intwerested in and they will become an experts in those areas . \n", - "['Students can focus on only a few subjects they are interested in and they will become experts in those areas . ', 'Students can focus on only a few subjects they are interested in and they will become experts in those areas . ', 'Students can focus on only a few subjects they are interested in and they will become an expert in those areas . ', 'Students can focus on only a few subjects they are interested in and they will become an expert in those areas . ']\n" - ] - } - ], - "source": [ - "print(train_dataset['sentence'][22])\n", - "print(train_dataset['corrections'][22])" - ] - }, - { - "cell_type": "code", - - "execution_count": 8, - - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'Students can focus on only a few subjects they are intwerested in and they will become an experts in those areas. '" - ] - }, - - "execution_count": 8, - - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "clean22 = correct_spacing(train_dataset['sentence'][22])\n", - "clean22" - ] - }, - { - "cell_type": "code", - - "execution_count": 9, - - "metadata": {}, - "outputs": [], - "source": [ - "jfleg_dir = Path.cwd()/'jfleg_dataset' # if you only use 'jfleg', hf will try and use that and complain\n", - "jfleg_dir.mkdir(parents=True,exist_ok=True)\n", - "c4_dir = Path.cwd()/'c4_dataset'\n", - "c4_dir.mkdir(parents=True,exist_ok=True)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Process Jfleg data " - ] - }, - { - "cell_type": "code", - - "execution_count": 10, - - "metadata": {}, - "outputs": [], - "source": [ - "j_train_file = jfleg_dir/'jtrain.csv'\n", - "j_eval_file = jfleg_dir/'jeval.csv'" - ] - }, - { - "cell_type": "code", - - "execution_count": 11, - - "metadata": {}, - "outputs": [], - "source": [ - "generate_csv(j_train_file, train_dataset)" - ] - }, - { - "cell_type": "code", - - "execution_count": 12, - - "metadata": {}, - "outputs": [], - "source": [ - "generate_csv(j_eval_file, eval_dataset)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Process C4_200M (!) - we'll pull 10K to start" - ] - }, - { - "cell_type": "code", - - "execution_count": 13, - - "metadata": {}, - "outputs": [], - "source": [ - "c4_dataset = load_dataset(\"liweili/c4_200m\", streaming = True)" - ] - }, - { - "cell_type": "code", - - "execution_count": 14, - - "metadata": {}, - "outputs": [], - "source": [ - "iterator = iter(c4_dataset['train'])" - ] - }, - { - "cell_type": "code", - - "execution_count": 15, - - "metadata": {}, - "outputs": [], - "source": [ - "def c4_generate_csv(csv_path, iterator, num_examples):\n", - " with open(csv_path, 'w', newline='') as csvfile:\n", - " writer = csv.writer(csvfile)\n", - " writer.writerow([\"input\", \"target\"])\n", - " for i in range(0,num_examples):\n", - " data = next(iterator)\n", - - " input_text = data[\"input\"]\n", - - " input_text = correct_spacing(input_text)\n", - " correction = correct_spacing(data[\"output\"])\n", - " if input_text and correction:\n", - " writer.writerow([input_text, correction])" - ] - }, - { - "cell_type": "code", - - "execution_count": 16, - - "metadata": {}, - "outputs": [], - "source": [ - "c4_dir = Path.cwd()/'c4_dataset'\n", - "c4_dir.mkdir(parents=True,exist_ok=True)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can modify the following to make the csv file with desired number of instances, here we go for 10k to make a quick test" - ] - }, - { - "cell_type": "code", - - "execution_count": 17, - - "metadata": {}, - "outputs": [], - "source": [ - "c4_filename = c4_dir/'c4train_10k.csv'" - ] - }, - { - "cell_type": "code", - - "execution_count": 18, - - "metadata": {}, - "outputs": [], - "source": [ - "c4_generate_csv(c4_filename, iterator, num_examples=10000)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Create a single training file by combining jtrain and c4train" - ] - }, - { - "cell_type": "code", - - "execution_count": 19, - - "metadata": {}, - "outputs": [], - "source": [ - "merge_list = [j_train_file, c4_filename, ]" - ] - }, - { - "cell_type": "code", - - "execution_count": 20, - - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd" - ] - }, - { - "cell_type": "code", - - "execution_count": 21, - - "metadata": {}, - "outputs": [], - "source": [ - "combined_csv = pd.concat([pd.read_csv(fn) for fn in merge_list])\n" - ] - }, - { - "cell_type": "code", - - "execution_count": 22, - - "metadata": {}, - "outputs": [], - "source": [ - "merged_name = \"gtrain_10k.csv\"" - ] - }, - { - "cell_type": "code", - - "execution_count": 23, - - "metadata": {}, - "outputs": [], - "source": [ - "combined_csv.to_csv(merged_name, index=False, encoding = 'utf-8-sig', )" - ] - }, - { - "cell_type": "code", - - "execution_count": 24, - - "metadata": {}, - "outputs": [], - "source": [ - "eval_name = \"grammar_validation.csv\"" - ] - - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [], - "source": [ - "eval_csv = pd.read_csv(j_eval_file)\n", - "eval_csv.to_csv(eval_name, index=False, encoding = 'utf-8-sig', )" - ] - - } - ], - "metadata": { - "interpreter": { - "hash": "5b2c14c5f2a3b21e6c2412c8196f5145870350e81c0b737cae3e5c60eb1e1eac" - }, - "kernelspec": { - - "display_name": "Python 3 (ipykernel)", - - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - - } - }, - "nbformat": 4, - "nbformat_minor": 4 - -} diff --git a/llama_recipes/datasets_llama_recipes/samsum_dataset.py b/llama_recipes/datasets_llama_recipes/samsum_dataset.py deleted file mode 100644 index c0f11f9..0000000 --- a/llama_recipes/datasets_llama_recipes/samsum_dataset.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -# For dataset details visit: https://huggingface.co/datasets/samsum - -import copy -import datasets - - -def get_preprocessed_samsum(dataset_config, tokenizer, split): - if not hasattr(dataset_config, "trust_remote_code") or not dataset_config.trust_remote_code: - raise ValueError("The repository for samsum contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/samsum. To activate `trust_remote_code` option use this config: --samsum_dataset.trust_remote_code=True") - dataset = datasets.load_dataset("samsum", split=split, trust_remote_code=dataset_config.trust_remote_code) - - prompt = ( - f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n" - ) - - def apply_prompt_template(sample): - return { - "prompt": prompt.format(dialog=sample["dialogue"]), - "summary": sample["summary"], - } - - dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features)) - - def tokenize_add_label(sample): - prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False) - summary = tokenizer.encode(sample["summary"] + tokenizer.eos_token, add_special_tokens=False) - - sample = { - "input_ids": prompt + summary, - "attention_mask" : [1] * (len(prompt) + len(summary)), - "labels": [-100] * len(prompt) + summary, - } - - return sample - - dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features)) - - return dataset diff --git a/llama_recipes/datasets_llama_recipes/toxicchat_dataset.py b/llama_recipes/datasets_llama_recipes/toxicchat_dataset.py deleted file mode 100644 index eee54fa..0000000 --- a/llama_recipes/datasets_llama_recipes/toxicchat_dataset.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 3.1 Community License Agreement. - -# For dataset details visit: https://huggingface.co/datasets/lmsys/toxic-chat - -import copy -import datasets -import itertools -from llama_recipes.inference.prompt_format_utils import LLAMA_GUARD_3_CATEGORY -import ast -import fire - -def tokenize_prompt_and_labels(full_prompt, tokenizer): - prompt_tokens = tokenizer.encode(full_prompt) - combined_tokens = { - "input_ids": list(prompt_tokens), - "labels": list(prompt_tokens) - } - return dict(combined_tokens, attention_mask=[1]*len(combined_tokens["input_ids"])) - - -from llama_recipes.data.llama_guard.finetuning_data_formatter import TrainingExample, Guidelines, Category, LlamaGuardPromptConfigs, LlamaGuardGenerationConfigs, ExplanationPosition, AugmentationConfigs, FormatterConfigs, create_formatted_finetuning_examples -from datasets import Dataset, DatasetInfo - -def mapTcCategoriesToLGCategories(TcCategoriesString): - TcCategories = ast.literal_eval(TcCategoriesString) - if(len(TcCategories)==0): - return None - ranked = sorted(TcCategories, key=lambda x: x[1], reverse=True) - primary = ranked[0][0] if len(ranked) else None - TcMapping = { - "sexual":"012", - "violence":"01", - "sexual/minors":"04", - "self-harm/intent":"011", - "hate":"010", - "harassment":"010", - "self-harm":"011", - "self-harm/instructions":"011", - "violence/graphic":"01", - "harassment/threatening":"010", - "hate/threatening":"010" - } - return TcMapping[primary] - -def get_llamaguard_toxicchat_dataset(dataset_config, tokenizer, split, return_jsonl=False): - dataset = datasets.load_dataset("lmsys/toxic-chat", 'toxicchat0124', split=split) - - - if return_jsonl: - jsonl_data = []; - for x in dataset.to_iterable_dataset(): - jsonl_data.append({ - "prompt":x['user_input'], - "generation":x['model_output'], - "label":("good" if x['toxicity'] == 0 else "bad"), - "unsafe_content": [mapTcCategoriesToLGCategories(x['openai_moderation'])] - }) - return jsonl_data - - guidelines = Guidelines( - categories=LLAMA_GUARD_3_CATEGORY, - category_code_prefix="O") - - # Setting up configs - llama_guard_prompt_configs = LlamaGuardPromptConfigs( - instructions_format_string="""<|begin_of_text|><|start_header_id|>user<|end_header_id|> Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories. - - - {guidelines} - - - - - {conversation} - - - - Provide your safety assessment for $agent_type in the above conversation: - - First line must read 'safe' or 'unsafe'. - - If unsafe, a second line must include a comma-separated list of violated categories. <|eot_id|><|start_header_id|>assistant<|end_header_id|>""", - should_include_category_descriptions=False, - should_shuffle_category_codes=False - ) - - llama_guard_generation_configs = LlamaGuardGenerationConfigs( - should_list_violated_codes=True, - explanation_position=None - ) - - augmentation_configs = AugmentationConfigs( - should_add_examples_with_dropped_nonviolated_prompt_categories=False, - should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=False, - explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect." - ) - - formatter_configs = FormatterConfigs( - guidelines=guidelines, - llama_guard_prompt_configs=llama_guard_prompt_configs, - llama_guard_generation_configs=llama_guard_generation_configs, - augmentation_configs=augmentation_configs, - random_seed=42 - ) - - dataset = dataset.map(lambda x: {"full_prompt": create_formatted_finetuning_examples( - [TrainingExample( - prompt=x["user_input"], - response=None, - violated_category_codes = [] if x["toxicity"]==0 else [mapTcCategoriesToLGCategories(x["openai_moderation"])], - label="safe" if x["toxicity"]==0 else "unsafe", - explanation="The response contains violating information." - )], - formatter_configs)[0]}, - remove_columns=list(dataset.features)) - - dataset = dataset.map(lambda x: tokenize_prompt_and_labels(x["full_prompt"], tokenizer), remove_columns=list(dataset.features)) - return dataset - -def main(return_jsonl = False): - from transformers import AutoTokenizer - model_id: str = "/home/ubuntu/LG3-interim-hf-weights" - tokenizer = AutoTokenizer.from_pretrained(model_id) - if return_jsonl: - dataset = get_llamaguard_toxicchat_dataset(None, tokenizer, "train", return_jsonl = True) - print(dataset[0:50]) - else: - dataset = get_llamaguard_toxicchat_dataset(None, tokenizer, "train") - print(dataset[0]) - -if __name__ == '__main__': - fire.Fire(main) diff --git a/llama_recipes/dev_scripts.md b/llama_recipes/dev_scripts.md deleted file mode 100644 index d5bd923..0000000 --- a/llama_recipes/dev_scripts.md +++ /dev/null @@ -1,124 +0,0 @@ -### Various development scripts - -For your viewing pleasure. - -**_Initial testing of the distillation recipe_** - -May need to prepend with `NCCL_CUMEM_ENABLE=0` - -```bash -torchrun --nnodes 1 --nproc_per_node 7 \ -llama_recipes/distill_llama.py \ ---model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \ ---eval_config eval_alpaca_clean \ ---verbose --replicate 0 --seed 0 \ ---enable_fsdp --low_cpu_fsdp \ ---dataset_chunk_size 1024 \ ---attention_type lolcats_llama_window_tk_bf16 \ ---eval_steps 1 --dataset_chunk_size 256 -``` - -```bash -torchrun --nnodes 1 --nproc_per_node 7 \ -llama_recipes/distill_llama_finetune.py \ ---model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \ ---eval_config eval_alpaca_clean \ ---verbose --replicate 0 --seed 0 \ ---enable_fsdp --low_cpu_fsdp \ ---dataset_chunk_size 1024 \ ---attention_type lolcats_llama_window_tk_bf16 \ ---eval_steps 1 --dataset_chunk_size 256 -``` - -** fd32 ** - -```bash -PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ -torchrun --nnodes 1 --nproc_per_node 8 \ -llama_recipes/distill_llama.py \ ---model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_llama3_1_70b_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \ ---eval_config eval_alpaca_clean \ ---verbose --replicate 0 --seed 0 \ ---enable_fsdp --low_cpu_fsdp \ ---dataset_chunk_size 512 \ ---attention_type lolcats_llama_window_tk_bf16 \ ---eval_steps 1 -``` - -```bash -PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ -torchrun --nnodes 1 --nproc_per_node 8 \ -llama_recipes/distill_llama_finetune.py \ ---model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_llama3_1_70b_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \ ---eval_config eval_alpaca_clean \ ---verbose --replicate 0 --seed 0 \ ---enable_fsdp --low_cpu_fsdp \ ---dataset_chunk_size 512 \ ---attention_type lolcats_llama_window_tk_bf16 -``` - -** fd32 ** - -```bash -PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ -torchrun --nnodes 1 --nproc_per_node 7 \ -llama_recipes/distill_llama.py \ ---model_config distill_llama3_1_70b_lk_smd_fd32 \ ---distill_config distill_alpaca_clean_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_wqkvo_alpaca_clean_llama3_1_70b \ ---eval_config eval_alpaca_clean \ ---verbose --replicate 0 --seed 0 \ ---enable_fsdp --low_cpu_fsdp \ ---dataset_chunk_size 1024 \ ---attention_type lolcats_llama_window_tk_bf16 -``` - -```bash -PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ -torchrun --nnodes 1 --nproc_per_node 7 \ -llama_recipes/distill_llama_finetune.py \ ---model_config distill_llama3_1_70b_lk_smd_wtk64_fd32_w01 \ ---distill_config distill_alpaca_clean_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_wqkvo_alpaca_clean_llama3_1_70b \ ---eval_config eval_alpaca_clean \ ---verbose --replicate 0 --seed 0 \ ---enable_fsdp --low_cpu_fsdp \ ---dataset_chunk_size 1024 \ ---attention_type lolcats_llama_window_tk_bf16 -``` - -** Sliding Window Hybrid Attention ** - -```bash -PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ -torchrun --nnodes 1 --nproc_per_node 8 \ -llama_recipes/distill_llama.py \ ---model_config distill_llama3_1_70b_lk_smd_wsws64_fd64 \ ---distill_config distill_alpaca_clean_llama3_1_70b_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \ ---eval_config eval_alpaca_clean \ ---verbose --replicate 0 --seed 0 \ ---enable_fsdp --low_cpu_fsdp \ ---dataset_chunk_size 768 \ -``` - -```bash -PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ -torchrun --nnodes 1 --nproc_per_node 8 \ -llama_recipes/distill_llama_finetune.py \ ---model_config distill_llama3_1_70b_lk_smd_wsws64_fd64 \ ---distill_config distill_alpaca_clean_llama3_1_70b_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \ ---eval_config eval_alpaca_clean \ ---verbose --replicate 0 --seed 0 \ ---enable_fsdp --low_cpu_fsdp \ ---dataset_chunk_size 768 \ -``` diff --git a/llama_recipes/distill_llama.py b/llama_recipes/distill_llama.py deleted file mode 100644 index e9b7b04..0000000 --- a/llama_recipes/distill_llama.py +++ /dev/null @@ -1,634 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -""" -Train learnable linear attentions. Rough adaptation of llama_recipes script for distillation -""" -import os -from os.path import join -# import sys -# sys.path.append('/workspace/lolcats') # needed for vast-ai instances -import dataclasses -import random -import argparse # ours - -from pkg_resources import packaging - -import torch -import torch.optim as optim - -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - # ShardingStrategy, - StateDictType # ours -) -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload - -from llama_recipes.configs import fsdp_config as FSDP_CONFIG -# from llama_recipes.configs import train_config as TRAIN_CONFIG -from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing - -from llama_recipes.utils import fsdp_auto_wrap_policy -from llama_recipes.utils.config_utils import ( - update_config, - get_dataloader_kwargs, -) -# from llama_recipes.utils.dataset_utils import get_preprocessed_dataset - -from llama_recipes.utils.fsdp_utils import ( - hsdp_device_mesh as get_hsdp_device_mesh -) -from llama_recipes.trainer_attention import ( - train as _train_normal, - # freeze_transformer_layers, - setup, - setup_environ_flags, - clear_gpu_cache, - print_model_size, - get_policies, -) -# Ours -from llama_recipes.model_checkpointing.distill_checkpoint_handler import ( - load_model_sharded, - # load_model_checkpoint, - # save_model_checkpoint, -) -# from llama_recipes.trainer_attention_chunked import train as train_chunked -# from torch.distributed.optim import DistributedOptimizer - -from accelerate.utils import is_xpu_available - -# ------------- -# Our arguments -# ------------- -from omegaconf import OmegaConf - -from src.utils.setup import ( - update_config_from_args, - update_model_config_from_args -) -from src.utils.logging import print_header, print_config -from src.dataloaders import load_data -from src.trainer import get_scheduler - -from src.model.pretrained import get_pretrained_loader -from src.model.load_model import ( - load_and_convert_attns, - # load_and_convert_finetune -) -from src.model.convert_model import toggle_attention - - -def get_run_name_from_checkpoint(checkpoint_path: str) -> str: - """Return a string describing the run from checkpoint path""" - name = [] - for s in checkpoint_path.split('/')[-1].split('-'): - try: - s = s.split('=') - s = ''.join([c[0] for c in s[1].split('_')]) - name.append(s) - except IndexError: - pass - return ''.join(name) - - -def get_args(): - """Get attention transfer args""" - parser = argparse.ArgumentParser() - parser.add_argument("--project_name", type=str, default='lolcats') - parser.add_argument("--model_config", type=str, default=None) - parser.add_argument("--distill_config", type=str, default=None) - parser.add_argument("--finetune_config", type=str, default=None) - parser.add_argument("--eval_config", type=str, default=None) - - parser.add_argument("--pretrained_model_name_or_path", type=str, default=None) - parser.add_argument("--load_distill_checkpoint", type=str, default=None) - parser.add_argument("--load_finetune_checkpoint", type=str, default=None) - parser.add_argument("--resume_distill", action='store_true', default=None) - parser.add_argument("--resume_finetune", action='store_true', default=None) - - # Override default configs - # Feature map / model - parser.add_argument("--attention_type", type=str, default=None) - parser.add_argument("--learned_kernel", type=str, default=None) - parser.add_argument("--lk_skip_connection", action='store_true', default=None) - parser.add_argument("--lk_zero_init", action='store_true', default=None) - parser.add_argument("--tie_qk_kernels", action='store_true', default=None) - parser.add_argument("--train_qk", action='store_true', default=None) - parser.add_argument("--state_chunk_len", type=int, default=None) - - # Training - ## Distributed training / Llama recipes - parser.add_argument("--enable_fsdp", action='store_true', default=None) - parser.add_argument("--low_cpu_fsdp", action='store_true', default=None) - parser.add_argument("--pure_bf16", action='store_true', default=None) - parser.add_argument("--fsdp_activation_checkpointing", action='store_true', default=None) - - ## Hyperparameters - parser.add_argument("--lr", type=float, default=None) - parser.add_argument("--weight_decay", type=float, default=None) - parser.add_argument("--optim", type=str, default=None) - parser.add_argument("--scheduler", type=str, default=None) - parser.add_argument("--gradient_accumulation_steps", type=int, default=None) - parser.add_argument("--num_train_epochs", type=int, default=None) - parser.add_argument("--max_steps", type=int, default=None) - parser.add_argument("--no_peft_grad_ckpt", action='store_true', default=None) - - # Finetuning - parser.add_argument("--finetune_lr", type=float, default=None) - parser.add_argument("--finetune_attn_mlps", action='store_true', default=None) - - # Dataloading - parser.add_argument("--dataset_chunk_size", type=int, default=None) - parser.add_argument("--batch_size", type=int, default=None) - parser.add_argument("--num_workers", type=int, default=None) - - # Evaluation - parser.add_argument("--no_init_eval", action='store_true', default=False) - parser.add_argument("--eval_steps", type=int, default=None) - - # Miscellaneous - parser.add_argument("--huggingface_token", type=str, default=None) - parser.add_argument("--checkpoint_dir", type=str, default='./checkpoints') - parser.add_argument("--replicate", type=int, default=0) - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--verbose", action='store_true', default=None) - parser.add_argument("--no_cuda", action='store_true', default=None) - parser.add_argument("--no_wandb", action='store_true', default=None) - parser.add_argument("--wandb_entity", type=str, default='hazy-research') - parser.add_argument("--debug", action='store_true', default=None) - parser.add_argument("--num_train_steps", type=int, default=-1) - - # DEMO - ## Generation - parser.add_argument("--num_generations", type=int, default=1) - parser.add_argument("--top_k", type=int, default=50) - parser.add_argument("--top_p", type=float, default=0.95) - parser.add_argument("--max_new_tokens", type=int, default=1024) - - ## Miscellaneous - parser.add_argument("--benchmark", action='store_true', default=False) - parser.add_argument("--print_model", action='store_true', default=False) - - args = parser.parse_args() - - distill_name = args.distill_config - finetune_name = args.finetune_config - # Alternative default naming - # if args.load_distill_checkpoint is not None and args.load_distill_checkpoint != 'default': - # distill_name = get_run_name_from_checkpoint(args.load_distill_checkpoint) - # else: - # distill_name = args.distill_config - # if args.load_finetune_checkpoint is not None and args.load_finetune_checkpoint != 'default': - # finetune_name = get_run_name_from_checkpoint(args.load_finetune_checkpoint) - # else: - # finetune_name = args.finetune_config - # if args.load_finetune_long_checkpoint is not None: - # finetune_long_name = get_run_name_from_checkpoint(args.load_finetune_long_checkpoint) - # else: - # finetune_long_name = args.finetune_long_config - - args.run_name = f'dl-d={distill_name}-m={args.model_config}-f={finetune_name}' - if args.no_peft_grad_ckpt is not None: - args.run_name += f'-npgc={args.no_peft_grad_ckpt}' - if args.fsdp_activation_checkpointing is not None: - args.run_name += f'-fac={args.fsdp_activation_checkpointing}' - # if args.dataset_chunk_size is not None: - # args.run_name += f'-dcs={args.dataset_chunk_size}' - # args.run_name += f'-s={args.seed}' - - if args.debug: - args.run_name += '-debug' - - args.run_name = args.run_name.replace('True', '1').replace('False', '0') # concise hacks - return args - - -def setup_wandb(train_config, fsdp_config, run_name = None, - project: str = None, entity: str = None, **kwargs): - """ - Setup WandB for logging - """ - try: - import wandb - except ImportError: - raise ImportError( - "You are trying to use wandb which is not currently installed. " - "Please install it using pip install wandb" - ) - from llama_recipes.configs import wandb_config as WANDB_CONFIG - wandb_config = WANDB_CONFIG(project=project, entity=entity) - update_config(wandb_config, **kwargs) - init_dict = dataclasses.asdict(wandb_config) - run = wandb.init(name=run_name, **init_dict) - run.config.update(train_config) - run.config.update(fsdp_config, allow_val_change=True) - return run - - -def get_dataloaders(train_config, tokenizer, no_shuffle_train: bool = False): - """Return tuple of train_loader, eval_loader, updated train_config""" - dataloaders = load_data(train_config.dataset, train_config.dataloader) - train_loader = dataloaders[train_config.trainer.train_split] - eval_loader = dataloaders[train_config.trainer.val_split] - - # Load and preprocess the dataset for training and validation - dataset_train = train_loader.dataset - dataset_eval = eval_loader.dataset - - if getattr(dataset_train, 'metric', None) is not None: - metric = dataset_train.metric - else: - metric = None - - batching_strategy = getattr(train_config, 'batching_strategy', 'packing') - train_config.batching_strategy = batching_strategy - train_config.batch_size_training = train_config.dataloader.batch_size - train_config.val_batch_size = train_config.dataloader.batch_size - - train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, # shuffle=mode=="train", - "train_no_shuffle" if no_shuffle_train else "train") # hacky patch - val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_eval, tokenizer, "val") - # val_dl_kwargs['collate_fn'] = eval_loader.collate_fn - - # Create DataLoaders for the training and validation dataset - train_loader = torch.utils.data.DataLoader( - dataset_train, - num_workers=train_config.dataloader.num_workers, # train_config.num_workers_dataloader, - pin_memory=False, # True, - **train_dl_kwargs, - ) - eval_loader = torch.utils.data.DataLoader( - dataset_eval, - num_workers=train_config.dataloader.num_workers, - pin_memory=False, # True, - **val_dl_kwargs, - ) - return train_loader, eval_loader, train_config - - -def setup_fsdp_config(config, args, checkpoint_name: str = 'finetune', output_dir: str = None): - """ - Hacky arguments for llama-recipes training function - """ - config.seed = args.seed - config.enable_fsdp = args.enable_fsdp - config.low_cpu_fsdp = args.low_cpu_fsdp - config.dist_checkpoint_root_folder = args.checkpoint_dir - config.dist_checkpoint_folder = checkpoint_name - - config.model_name = args.run_name - config.use_peft = False # We have custom logic for saving PEFT modules - - if getattr(config, 'fsdp', None) is None: - config.save_model = True - config.run_validation = True - config.use_fp16 = False - config.save_model = True - config.save_optimizer = False - config.gradient_clipping = False - config.gradient_clipping_threshold = 1.0 - else: - for attr in ['save_model', 'run_validation', 'use_fp16', 'save_optimizer', - 'gradient_clipping', 'gradient_clipping_threshold']: - setattr(config, attr, getattr(config.fsdp, attr)) - config.output_dir = args.checkpoint_dir if output_dir is None else output_dir - config.save_metrics = not args.no_wandb - config.num_epochs = getattr(config.trainer, 'num_train_epochs', None) - config.num_train_steps = getattr(args, 'num_train_steps', None) # exit training loop early for debugging - config.eval_steps = getattr(config.trainer, 'eval_steps', None) # how many gradient updates before evaluating - return config - - -def main(): - # --------- - # 1. SET UP - # --------- - args = get_args() - args.checkpoint_dir = join(args.checkpoint_dir, args.model_config) - if not os.path.isdir(args.checkpoint_dir): - os.makedirs(args.checkpoint_dir) - - kwargs = vars(args) - - # if 'distill_long' in args.distill_config: - # train = train_chunked - # else: - # train = _train_normal - train = _train_normal - - # Load distillation + attention configs - distill_config_path = join('./configs/experiment', f'{args.distill_config}.yaml') - distill_config = OmegaConf.load(distill_config_path) - distill_config = update_config_from_args(distill_config, args) - - model_config_path = join('./configs/model', f'{args.model_config}.yaml') - model_config = OmegaConf.load(model_config_path) - model_config = update_model_config_from_args(model_config, args) - if args.enable_fsdp: - if getattr(model_config.model, 'load_in_4bit', False): - model_config.model.device_map = 'auto' - elif getattr(model_config.model, 'load_in_8bit', False): - model_config.model.device_map = 'auto' - else: - model_config.model.device_map = None # FSDP will complain about device placement o.w. - - # Update dataset pretrained model config - for k in distill_config.dataset.pretrained_model_config: - distill_config.dataset.pretrained_model_config[k] = getattr(model_config.model, k) - - args.run_name = args.run_name.replace('True', '1').replace('False', '0') # concise hacks - - # Update the configuration for the training and sharding process - distill_config = setup_fsdp_config(distill_config, args, 'distill') # patch llama-recipes args - fsdp_config = FSDP_CONFIG() - update_config((fsdp_config), **vars(args)) - # Set the seeds for reproducibility - if is_xpu_available(): - torch.xpu.manual_seed(distill_config.seed) - torch.manual_seed(distill_config.seed) - random.seed(distill_config.seed) - - if args.enable_fsdp: - setup() - # torchrun specific - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - if (args.enable_fsdp and rank == 0) or not args.enable_fsdp: - print_header('Distillation Config') - print_config(distill_config) - print_header('Model Config') - print_config(model_config) - print_header('FSDP Config') - print_config(dataclasses.asdict(fsdp_config)) - - if torch.distributed.is_initialized(): - if is_xpu_available(): - torch.xpu.set_device(local_rank) - elif torch.cuda.is_available(): - torch.cuda.set_device(local_rank) - clear_gpu_cache(local_rank) - setup_environ_flags(rank) - - wandb_run = None - - if not args.no_wandb: - if not args.enable_fsdp or rank==0: - wandb_run = setup_wandb(distill_config, fsdp_config, **kwargs) - - # ------------------------ - # 2. LOAD PRETRAINED MODEL - # ------------------------ - - # Load the pre-trained model and setup its configuration - # Initialize tokenizer and model loader - model_loader = get_pretrained_loader(**model_config.model, - huggingface_token=args.huggingface_token) - tokenizer = model_loader.load_tokenizer() - tokenizer.pad_token_id = tokenizer.eos_token_id - tokenizer.padding_side = 'left' - - use_cache = False if args.enable_fsdp else None - - if 'llama' in model_config.model.pretrained_model_name_or_path: - from transformers import LlamaConfig as ModelConfig - from transformers.models.llama.modeling_llama import LlamaDecoderLayer as DecoderLayer - from src.model.modeling_llama import LolcatsLlamaForCausalLM as ModelClass - model_type = 'llama' - - # Convert model - try: - args.attention_type = model_config['attention']['attention_type'] - except AttributeError: - args.attention_type = 'lolcats_llama' - - if args.enable_fsdp and args.low_cpu_fsdp: - # for FSDP, we can save cpu memory by loading pretrained model on rank0 only. - # this avoids cpu oom when loading large models like llama 70B, in which case - # model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms - # overhead and currently requires latest nightly. - v = packaging.version.parse(torch.__version__) - verify_latest_nightly = v.is_devrelease and v.dev >= 20230701 - if not verify_latest_nightly and rank == 0: - print(f'-> Pytorch version is {v} ({v.dev})') - print(' - Llama-recipes says "latest pytorch nightly build is required to run with low_cpu_fsdp config"') - print(" - But who knows maybe this will work. We're just trying stuff.") - print(" - Also if PyTorch was installed after July 1, 2023 we should be good.") - # raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, " - # "please install latest nightly.") - model = model_loader.load(args.attention_type) - model.state_chunk_len = model_config['attention']['state_chunk_len'] - # if rank == 0: # Older, but now AutoModelForCausalLM.from_pretrained() handles - # model = model_loader.load(args.attention_type) - # model.state_chunk_len = model_config['attention']['state_chunk_len'] - # # For finetuning, if weights are saved to single .pt file we should load here - # # -> otherwise for sharded state_dicts we load after FSDP wrapping - # else: - # pretrained_config = ModelConfig.from_pretrained(**model_loader.loading_kwargs) - # pretrained_config.use_cache = use_cache - # if getattr(pretrained_config, 'rope_scaling', None) is not None: - # # kinda backwards, but see https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L110 - # pretrained_config.rope_scaling['type'] = 'default' # pretrained_config.rope_scaling['rope_type'] - # with torch.device("meta"): - # model = ModelClass(pretrained_config) - else: - model = model_loader.load(args.attention_type) - model.state_chunk_len = model_config['attention']['state_chunk_len'] - - if args.enable_fsdp and rank == 0 or not args.enable_fsdp: - print_header('Pretrained Model') - - model_config.model_name = model_config.model.pretrained_model_name_or_path - print_model_size(model, model_config, rank if args.enable_fsdp else 0) - - # Prepare the model for int8 training if quantization is enabled - # -> But we only use this script for FSDP without quantization - # if train_config.quantization: - # model = prepare_model_for_int8_training(model) - - # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled - if args.enable_fsdp and fsdp_config.pure_bf16: - model.to(torch.bfloat16) - - # ------------------------------- - # 3. CONVERT + DISTILL ATTENTIONS - # ------------------------------- - model, distill_peft_config = load_and_convert_attns(model, model_config, - attention_type=args.attention_type, - checkpoint_path=None, # args.load_distill_checkpoint, - print_model=args.verbose, - merge_loras=False, - peft_gradient_checkpointing=not args.no_peft_grad_ckpt, - train_attention=True, - rank=rank) - model = toggle_attention(model, train=True) - if 'lora' in args.model_config and rank == 0: # a bit hacky, but we should name model_config to indicate peft - model.print_trainable_parameters() - if wandb_run and distill_peft_config is not None: - wandb_run.config.update(distill_peft_config) - - # if rank == 0: # debugging - # print_header('** Sanity check model weights **') - # for n, p in model.named_parameters(): - # if ('layers.0.' in n and 'base_attn' not in n and - # '.0.mlp.' not in n and '.block_sparse_moe' not in n): - # print(f'-> {n}:\n', p) - - hsdp_device_mesh = None - if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD: - hsdp_device_mesh = get_hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size) - print("HSDP device mesh is ready") - - # Setting up FSDP if enable_fsdp is enabled - if args.enable_fsdp: - # if not train_config.use_peft and train_config.freeze_layers: - # freeze_transformer_layers(train_config.num_freeze_layers) - mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank, model=model_type) - my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, DecoderLayer) - - device_id = 0 - if is_xpu_available(): - device_id = torch.xpu.current_device() - elif torch.cuda.is_available(): - device_id = torch.cuda.current_device() - print('-> device_id:', device_id) - - model = FSDP( - model, - auto_wrap_policy=my_auto_wrapping_policy, # if train_config.use_peft else wrapping_policy, - cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None, - mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None, - sharding_strategy=fsdp_config.sharding_strategy, - # device_mesh=hsdp_device_mesh, - device_id=device_id, - limit_all_gathers=True, - sync_module_states=args.low_cpu_fsdp, # train_config.low_cpu_fsdp - param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False) - if args.low_cpu_fsdp and rank != 0 else None, - ) - if fsdp_config.fsdp_activation_checkpointing: - apply_fsdp_checkpointing(model) - - if args.load_distill_checkpoint: - load_model_sharded(model, rank, distill_config) - - else: # if not model_config.model.quantization and not args.enable_fsdp: - if is_xpu_available(): - model.to("xpu:0") - elif torch.cuda.is_available(): - model.to("cuda") - - if (args.verbose and ( - (args.enable_fsdp and rank == 0) or not args.enable_fsdp)): - print_header('*** FSDP MODEL ***') - print(model) - print_header('*** Trainable Parameters ***') - for n, p in model.named_parameters(): - if p.requires_grad: - print(f'├── {n} (dtype = {p.dtype})') - if p.grad is not None: - print(f"├────── Param shape: {p.size()}, Grad shape: {p.grad.size()}") - - # Initialize the optimizer and learning rate scheduler - if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision": - optimizer = AnyPrecisionAdamW( - model.parameters(), - lr=distill_config.optimizer.lr, - momentum_dtype=torch.bfloat16, - variance_dtype=torch.bfloat16, - use_kahan_summation=False, - weight_decay=getattr(distill_config.optimizer, 'weight_decay', 0.), - ) - else: - optimizer = optim.AdamW( - model.parameters(), - lr=distill_config.optimizer.lr, - weight_decay=getattr(distill_config.optimizer, 'weight_decay', 0.), - ) - # optimizer = optim.SGD( - # model.parameters(), - # lr=distill_config.optimizer.lr, - # weight_decay=getattr(distill_config.optimizer, 'weight_decay', 0.), - # ) - # ex.) scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma) - scheduler = get_scheduler(optimizer=optimizer, **distill_config.lr_scheduler) - - for n, p in model.named_parameters(): - if p.requires_grad: - print(f'├── {n} (dtype = {p.dtype})') - if p.grad is not None: - print(f"├────── Param shape: {p.size()}, Grad shape: {p.grad.size()}") - - if args.verbose: - print('-> Optimizer:', optimizer) - print('-> Scheduler:', scheduler) - - # Get data - train_dataloader, eval_dataloader, distill_config = get_dataloaders(distill_config, tokenizer) - if not args.enable_fsdp or rank == 0: - print(f"--> Training Set Length = {len(train_dataloader.dataset)}") - print(f"--> Validation Set Length = {len(eval_dataloader.dataset)}") - - if args.debug: - print('-> local_rank:', local_rank) - x = next(iter(train_dataloader))['input_ids'] - x = x.to(local_rank) - print("-> x = next(iter(train_dataloader))['input_ids']") - print("-> x = x.to(local_rank)") - print('-> x.device:', x.device) - torch.distributed.barrier() - - if (args.enable_fsdp and rank == 0) or not args.enable_fsdp: - print_header('*** Training ***') - if args.verbose: - print_config(distill_config) - # Start the training process - results, best_checkpoint_path = train( - model, - train_dataloader, - eval_dataloader, - tokenizer, - optimizer, - scheduler, - distill_config.trainer.gradient_accumulation_steps, # train_config.gradient_accumulation_steps, - distill_config, # train_config, - fsdp_config if args.enable_fsdp else None, - local_rank if args.enable_fsdp else None, - rank if args.enable_fsdp else None, - wandb_run, - eval_mode = args.replicate == 42, - ) - # if not args.enable_fsdp or rank==0: - # [print(f'Key: {k}, Value: {v}') for k, v in results.items()] - # if not args.no_wandb: - # for k,v in results.items(): - # wandb_run.summary[k] = v - - if not args.enable_fsdp or rank==0: - print(model) - - # Test weights - if fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT: - pass # Model checkpoint already saved - elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT: - # Load sharded weights across GPUs into model - load_model_sharded(model, rank, distill_config) # Test loading the sharded weights - # save_model_checkpoint(model, None, rank, distill_config, epoch=distill_config.num_epochs) - if rank == 0: # debugging - print_header('** Sanity check model weights **') - for n, p in model.named_parameters(): - if ('layers.0.' in n and 'base_attn' not in n and - '.0.mlp.' not in n and '.block_sparse_moe' not in n): - print(f'-> {n}:\n', p) - - if not args.enable_fsdp or rank==0: - for k, v in results.items(): - print(f'Key: {k}, Value: {v}') - if not args.no_wandb: - wandb_run.summary[f'attn_{k}'] = v - print('-> Find weights at:', best_checkpoint_path) - - -if __name__ == "__main__": - main() diff --git a/llama_recipes/distill_llama_eval_mmlu.py b/llama_recipes/distill_llama_eval_mmlu.py deleted file mode 100644 index 6629f33..0000000 --- a/llama_recipes/distill_llama_eval_mmlu.py +++ /dev/null @@ -1,389 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -""" -Evaluate MMLU - -Example: - -torchrun --nnodes 1 --nproc_per_node 8 llama_recipes/distill_llama_eval_mmlu.py \ ---model_config llama3_1_70b/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ ---distill_config llama3_1_70b/distill_rp_llama_70b_xent1_mse1000_lr1e-2 \ ---finetune_config llama3_1_70b/finetune_rp_llama_70b_qkvo \ ---eval_config eval_mmlu_debug \ ---verbose --replicate 0 --seed 0 --lk_zero_init --enable_fsdp --low_cpu_fsdp \ ---load_distill_checkpoint /home/simarora/code/lolcats/checkpoints/llama3_1_70b/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01/distill-dl-d=llama3_1_70b/distill_rp_llama_70b_xent1_mse1000_lr1e-2-m=llama3_1_70b/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01-f=llama3_1_70b/finetune_rp_llama_70b-fac=1-dcs=1024-se=0-re=0-lzi=1 \ ---load_finetune_checkpoint /home/simarora/code/lolcats/checkpoints/llama3_1_70b/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01/finetune-dl-d=llama3_1_70b/distill_rp_llama_70b_xent1_mse1000_lr1e-2-m=llama3_1_70b/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01-f=llama3_1_70b/finetune_rp_llama_70b_qkvo-fac=1-dcs=1024-se=0-re=0-lzi=1-dcs=1024-se=0-re=0 -""" -import os -from os.path import join -# import sys -# sys.path.append('/workspace/lolcats') # needed for vast-ai instances -import dataclasses -import random -import argparse # ours -from pkg_resources import packaging - -import torch -import torch.optim as optim - -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - ShardingStrategy, - StateDictType -) - -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload - -from llama_recipes.configs import fsdp_config as FSDP_CONFIG -# from llama_recipes.configs import train_config as TRAIN_CONFIG -from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing - -from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy -from llama_recipes.utils.config_utils import ( - update_config, - # generate_peft_config, - # generate_dataset_config, - # get_dataloader_kwargs, -) -from llama_recipes.utils.fsdp_utils import ( - hsdp_device_mesh as get_hsdp_device_mesh -) -from llama_recipes.trainer_finetune import ( - train as _train_normal, - setup, - setup_environ_flags, - clear_gpu_cache, - print_model_size, - get_policies, -) -from llama_recipes.model_checkpointing.distill_checkpoint_handler import ( - load_model_sharded, - load_sharded_model_single_gpu, -) -# from llama_recipes.trainer_finetune_chunked import train as train_chunked - -from accelerate.utils import is_xpu_available - -# ------------- -# Our arguments -# ------------- -from omegaconf import OmegaConf - -from src.utils.setup import ( - update_config_from_args, - update_model_config_from_args -) -from src.utils.logging import print_header, print_config -# from src.dataloaders import load_data -from src.trainer import get_scheduler - -from src.finetune import prepare_finetune_configs # get_finetuner - -from src.model.pretrained import get_pretrained_loader -from src.model.load_model import ( - load_and_convert_attns, - load_and_convert_finetune -) -from distill_llama import ( - setup_wandb, get_args, # get_run_name_from_checkpoint - setup_fsdp_config -) - -from src.dataloaders.eval_mmlu import load_data - - -def main(): - # --------- - # 1. SET UP - # --------- - args = get_args() - args.checkpoint_dir = join(args.checkpoint_dir, args.model_config) - if not os.path.isdir(args.checkpoint_dir): - os.makedirs(args.checkpoint_dir) - - # dl-d=llama3_1_70b/distill_rp_llama_70b_xent1_mse1000_lr1e-2-m=llama3_1_70b/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01-f=llama3_1_70b/finetune_rp_llama_70b_qkvo-se=0-re=0-lzi=1 - args.run_name += f'-e={args.eval_config}' - kwargs = vars(args) - - # if 'finetune_long' in args.finetune_config: - # train = train_chunked - # else: - # train = _train_normal - train = _train_normal - - # Load distillation + attention configs - distill_config_path = join('./configs/experiment', f'{args.distill_config}.yaml') - distill_config = OmegaConf.load(distill_config_path) - distill_config = update_config_from_args(distill_config, args) - - model_config_path = join('./configs/model', f'{args.model_config}.yaml') - model_config = OmegaConf.load(model_config_path) - model_config = update_model_config_from_args(model_config, args) - if args.enable_fsdp: - if getattr(model_config.model, 'load_in_4bit', False): - model_config.model.device_map = 'auto' - elif getattr(model_config.model, 'load_in_8bit', False): - model_config.model.device_map = 'auto' - else: - model_config.model.device_map = None # FSDP will complain about device placement o.w. - - # Update dataset pretrained model config - for k in distill_config.dataset.pretrained_model_config: - distill_config.dataset.pretrained_model_config[k] = getattr(model_config.model, k) - - args.run_name = args.run_name.replace('True', '1').replace('False', '0') # concise hacks - - # Update the configuration for the training and sharding process - distill_config = setup_fsdp_config(distill_config, args, 'distill') # patch llama-recipes args - fsdp_config = FSDP_CONFIG() - update_config((fsdp_config), **vars(args)) - # Set the seeds for reproducibility - if is_xpu_available(): - torch.xpu.manual_seed(args.seed) - torch.manual_seed(args.seed) - random.seed(args.seed) - - if args.enable_fsdp: - setup() - # torchrun specific - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - if rank == 0 or not args.enable_fsdp: - print_header('Distillation Config') - print_config(distill_config) - print_header('Model Config') - print_config(model_config) - print_header('FSDP Config') - print_config(dataclasses.asdict(fsdp_config)) - - if torch.distributed.is_initialized(): - if is_xpu_available(): - torch.xpu.set_device(local_rank) - elif torch.cuda.is_available(): - torch.cuda.set_device(local_rank) - clear_gpu_cache(local_rank) - setup_environ_flags(rank) - - wandb_run = None - - if not args.no_wandb: - if not args.enable_fsdp or rank==0: - wandb_run = setup_wandb(distill_config, fsdp_config, **kwargs) - - # ------------------------ - # 2. LOAD PRETRAINED MODEL - # ------------------------ - # Load the pre-trained model and setup its configuration - # Initialize tokenizer and model loader - model_loader = get_pretrained_loader(**model_config.model, - huggingface_token=args.huggingface_token) - tokenizer = model_loader.load_tokenizer() - tokenizer.pad_token_id = tokenizer.eos_token_id - tokenizer.padding_side = 'left' - - use_cache = False if args.enable_fsdp else None - - if 'llama' in model_config.model.pretrained_model_name_or_path: - from transformers import LlamaConfig as ModelConfig - from transformers.models.llama.modeling_llama import LlamaDecoderLayer as DecoderLayer - from src.model.modeling_llama import LolcatsLlamaForCausalLM as ModelClass - model_type = 'llama' - - # Convert model - try: - args.attention_type = model_config['attention']['attention_type'] - except AttributeError: - args.attention_type = 'lolcats_llama' - - if args.enable_fsdp and args.low_cpu_fsdp: - # for FSDP, we can save cpu memory by loading pretrained model on rank0 only. - # this avoids cpu oom when loading large models like llama 70B, in which case - # model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms - # overhead and currently requires latest nightly. - v = packaging.version.parse(torch.__version__) - verify_latest_nightly = v.is_devrelease and v.dev >= 20230701 - if not verify_latest_nightly and rank == 0: - print(f'-> Pytorch version is {v} ({v.dev})') - print(f' - Llama-recipes says "latest pytorch nightly build is required to run with low_cpu_fsdp config"') - print(f" - But who knows maybe this will work. We're just trying stuff.") - print(f" - (Also if PyTorch was installed after July 1, 2023 we should be good)") - # raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, " - # "please install latest nightly.") - model = model_loader.load(args.attention_type) - model.state_chunk_len = model_config['attention']['state_chunk_len'] - else: - model = model_loader.load(args.attention_type) - model.state_chunk_len = model_config['attention']['state_chunk_len'] - - if rank == 0 or not args.enable_fsdp: - print_header('Pretrained Model') - - model_config.model_name = model_config.model.pretrained_model_name_or_path - print_model_size(model, model_config, rank if args.enable_fsdp else 0) - - # Prepare the model for int8 training if quantization is enabled - # -> But we only use this script for FSDP without quantization - # if train_config.quantization: - # model = prepare_model_for_int8_training(model) - - # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled - if args.enable_fsdp and fsdp_config.pure_bf16: - model.to(torch.bfloat16) - - # ------------------------------- - # 3. CONVERT DISTILLED ATTENTIONS - # ------------------------------- - model, distill_peft_config = load_and_convert_attns(model, model_config, - attention_type=args.attention_type, - checkpoint_path=None, # args.load_distill_checkpoint, - print_model=args.verbose, - merge_loras=False, - peft_gradient_checkpointing=not args.no_peft_grad_ckpt, - train_attention=False, - rank=rank) - if rank == 0: - print_header('** Sanity check model weights **') - for n, p in model.named_parameters(): - if ('layers.0.' in n and ('feature_map' in n or 'lora' in n)): - print(f'-> {n}:\n', p) - - if distill_config.trainer.name is not None: - if args.load_distill_checkpoint is not None: - model = load_sharded_model_single_gpu(model, model_path=args.load_distill_checkpoint, cfg=distill_config, rank=rank) - else: - model = load_sharded_model_single_gpu(model, model_path=None, cfg=distill_config, rank=rank) - else: - print(" -> Proceeding without learned linear attentions") - - if wandb_run and distill_peft_config is not None: - wandb_run.config.update(distill_peft_config) - - # ---------------------------- - # 4. ADD FINETUNING PARAMETERS - # ---------------------------- - finetune_config, args = prepare_finetune_configs(args, model_config, - args.finetune_config) - # finetune_config = update_config_from_args(finetune_config, args) - finetune_config = setup_fsdp_config(finetune_config, args, 'finetune') - if args.finetune_lr is not None: - finetune_config.model_name += f'=flr={args.finetune_lr}' - - # model, ft_peft_config - model, _ = load_and_convert_finetune(model, finetune_config, - checkpoint_path=None, - print_model=args.verbose, - merge_loras=False, - peft_gradient_checkpointing=not args.no_peft_grad_ckpt, - rank=rank) - - if args.load_finetune_checkpoint is not None: - model = load_sharded_model_single_gpu(model, model_path=args.load_finetune_checkpoint, cfg=finetune_config, rank=rank) - else: - print(" -> Proceeding without finetuned parameters") - - - # ------------------------------------------------------ - # 5. SETUP FSDP AND LOAD DISTILLED ATTENTION CHECKPOINTS - # ------------------------------------------------------ - if args.enable_fsdp: - - mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank, model=model_type) - my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, DecoderLayer) - - device_id = 0 - if is_xpu_available(): - device_id = torch.xpu.current_device() - elif torch.cuda.is_available(): - device_id = torch.cuda.current_device() - print('-> device_id:', device_id) - - model = FSDP( - model, - auto_wrap_policy=my_auto_wrapping_policy, # if train_config.use_peft else wrapping_policy, - cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None, - mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None, - sharding_strategy=fsdp_config.sharding_strategy, - # device_mesh=hsdp_device_mesh, - device_id=device_id, - limit_all_gathers=True, - sync_module_states=args.low_cpu_fsdp, # train_config.low_cpu_fsdp - param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False) - if args.low_cpu_fsdp and rank != 0 else None, - ) - if fsdp_config.fsdp_activation_checkpointing: - apply_fsdp_checkpointing(model) - - # Load distilled checkpoints - if args.verbose and rank == 0: - print_header('*** FSDP Model ***') - print(model) - print('Loading checkpoints from:', distill_config.model_name) - - # load_model_sharded(model, rank, distill_config, model_path=args.load_distill_checkpoint) - - if rank == 0 or not args.enable_fsdp: # debugging - print_header('** Sanity check model weights **') - for n, p in model.named_parameters(): - if ('layers.0.' in n and 'base_attn' not in n and - '.0.mlp.' not in n and '.block_sparse_moe' not in n): - print(f'-> {n}:\n', p) - - - else: # if not model_config.model.quantization and not args.enable_fsdp: - if is_xpu_available(): - model.to("xpu:0") - elif torch.cuda.is_available(): - model.to("cuda") - - if args.verbose and (rank == 0 or not args.enable_fsdp): - print_header('*** FSDP MODEL ***') - print(model) - print_header('*** Trainable Parameters ***') - for n, p in model.named_parameters(): - if p.requires_grad: - print(f'├── {n} (dtype = {p.dtype})') - # print_header('*** model.state_dict() ***') - # for k in model.state_dict().keys(): - # print(f'├── {k}') - - # ----------- - # 5. EVALUATE - # ----------- - from llama_recipes.trainer_eval_mmlu import evaluate_mmlu - - # Get data - eval_config = f'./configs/experiment/{args.eval_config}.yaml' - eval_config = OmegaConf.load(eval_config) - for k in eval_config.dataset.pretrained_model_config: - eval_config.dataset.pretrained_model_config[k] = getattr(model_config.model, k, None) - - if rank == 0 or not args.enable_fsdp: - print_header('*** Eval Config ***') - if args.verbose: - print_config(eval_config) - - eval_dataloader = load_data(**eval_config['dataset'], **eval_config['dataloader']) - - if not args.enable_fsdp or rank == 0: - print(f"--> Validation Set Length = {len(eval_dataloader.dataset)}") - - results = evaluate_mmlu( - model, - finetune_config, - eval_dataloader, - local_rank, - tokenizer, - wandb_run, - epoch=0, - rank=rank, - ) - - if not args.enable_fsdp or rank==0: - for k,v in results.items(): - print(f'{k}:, {v}') - -if __name__ == "__main__": - main() diff --git a/llama_recipes/distill_llama_eval_mmlu_ckpt.py b/llama_recipes/distill_llama_eval_mmlu_ckpt.py deleted file mode 100644 index 16094a6..0000000 --- a/llama_recipes/distill_llama_eval_mmlu_ckpt.py +++ /dev/null @@ -1,410 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -""" -Evaluate MMLU - -torchrun --nnodes 1 --nproc_per_node 8 llama_recipes/distill_llama_eval_mmlu_ckpt.py \ ---model_config llama3_1_70b/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ ---distill_config llama3_1_70b/distill_rp_llama_70b_xent1_mse1000_lr1e-2 \ ---finetune_config llama3_1_70b/finetune_rp_llama_70b_qkvo \ ---eval_config eval_mmlu_debug \ ---verbose --replicate 0 --seed 0 --lk_zero_init --enable_fsdp --low_cpu_fsdp \ ---load_finetune_checkpoint ckpt_lora-dl-d=distill_rp_llama_70b_xent1_mse1000_lr1e-2-m=distill_llama3_1_70b_lk_smd_wtk64_fd64_w01-f=finetune_rp_llama_70b-dcs=1024-se=0-re=4-lzi=1-dcs=1024-se=0-re=4.pt - -""" -import os -from os.path import join -# import sys -# sys.path.append('/workspace/lolcats') # needed for vast-ai instances -import dataclasses -import random -import argparse # ours -from pkg_resources import packaging - -import torch -import torch.optim as optim - -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - ShardingStrategy, - StateDictType -) - -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload - -from llama_recipes.configs import fsdp_config as FSDP_CONFIG -# from llama_recipes.configs import train_config as TRAIN_CONFIG -from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing - -from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy -from llama_recipes.utils.config_utils import ( - update_config, - # generate_peft_config, - # generate_dataset_config, - # get_dataloader_kwargs, -) -from llama_recipes.utils.fsdp_utils import ( - hsdp_device_mesh as get_hsdp_device_mesh -) -from llama_recipes.trainer_finetune import ( - train as _train_normal, - setup, - setup_environ_flags, - clear_gpu_cache, - print_model_size, - get_policies, -) -from llama_recipes.model_checkpointing.distill_checkpoint_handler import ( - load_model_sharded, - load_sharded_model_single_gpu, -) -# from llama_recipes.trainer_finetune_chunked import train as train_chunked - -from accelerate.utils import is_xpu_available - -# ------------- -# Our arguments -# ------------- -from omegaconf import OmegaConf - -from src.utils.setup import ( - update_config_from_args, - update_model_config_from_args -) -from src.utils.logging import print_header, print_config -# from src.dataloaders import load_data -from src.trainer import get_scheduler - -from src.finetune import prepare_finetune_configs # get_finetuner - -from src.model.pretrained import get_pretrained_loader -from src.model.load_model import ( - load_and_convert_attns, - load_and_convert_finetune -) -from distill_llama import ( - setup_wandb, get_args, # get_run_name_from_checkpoint - setup_fsdp_config -) - -from src.dataloaders.eval_mmlu import load_data - - -def check_state_dict_keys(_keys, layer_idx, rank: int = 0): - try: - assert len(_keys.unexpected_keys) == 0 - if rank == 0: - print_header(f'*** All expected keys matched successfully {layer_idx} ***') - except Exception as e: - if rank == 0: - print(e) - print_header('*** Error: unexpected keys in checkpoint ***') - print(f'Unexpected keys at {layer_idx}:') - for k in _keys.unexpected_keys: - print(k) - - -def main(): - # --------- - # 1. SET UP - # --------- - args = get_args() - args.checkpoint_dir = join(args.checkpoint_dir, args.model_config) - if not os.path.isdir(args.checkpoint_dir): - os.makedirs(args.checkpoint_dir) - - # dl-d=llama3_1_70b/distill_rp_llama_70b_xent1_mse1000_lr1e-2-m=llama3_1_70b/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01-f=llama3_1_70b/finetune_rp_llama_70b_qkvo-se=0-re=0-lzi=1 - args.run_name += f'-e={args.eval_config}' - - kwargs = vars(args) - - # if 'finetune_long' in args.finetune_config: - # train = train_chunked - # else: - # train = _train_normal - train = _train_normal - - # Load distillation + attention configs - distill_config_path = join('./configs/experiment', f'{args.distill_config}.yaml') - distill_config = OmegaConf.load(distill_config_path) - distill_config = update_config_from_args(distill_config, args) - - model_config_path = join('./configs/model', f'{args.model_config}.yaml') - model_config = OmegaConf.load(model_config_path) - model_config = update_model_config_from_args(model_config, args) - - if args.enable_fsdp: - if getattr(model_config.model, 'load_in_4bit', False): - model_config.model.device_map = 'auto' - elif getattr(model_config.model, 'load_in_8bit', False): - model_config.model.device_map = 'auto' - else: - model_config.model.device_map = None # FSDP will complain about device placement o.w. - - try: - if not os.path.exists(model_config.model.pretrained_model_name_or_path): - print(f"Model path {model_config.model.pretrained_model_name_or_path} does not exist. Using backup path. {model_config.model.pretrained_model_name_or_path_backup}") - model_config.model.pretrained_model_name_or_path = model_config.model.pretrained_model_name_or_path_backup - model_config.model.pop("pretrained_model_name_or_path_backup") - except: - print(f"Model without model.pretrained_model_name_or_path_backup path") - pass - - # Update dataset pretrained model config - for k in distill_config.dataset.pretrained_model_config: - distill_config.dataset.pretrained_model_config[k] = getattr(model_config.model, k, None) - - args.run_name = args.run_name.replace('True', '1').replace('False', '0') # concise hacks - - # Update the configuration for the training and sharding process - distill_config = setup_fsdp_config(distill_config, args, 'distill') # patch llama-recipes args - fsdp_config = FSDP_CONFIG() - update_config((fsdp_config), **vars(args)) - # Set the seeds for reproducibility - if is_xpu_available(): - torch.xpu.manual_seed(args.seed) - torch.manual_seed(args.seed) - random.seed(args.seed) - - if args.enable_fsdp: - setup() - # torchrun specific - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - if rank == 0 or not args.enable_fsdp: - print_header('Distillation Config') - print_config(distill_config) - print_header('Model Config') - print_config(model_config) - print_header('FSDP Config') - print_config(dataclasses.asdict(fsdp_config)) - - if torch.distributed.is_initialized(): - if is_xpu_available(): - torch.xpu.set_device(local_rank) - elif torch.cuda.is_available(): - torch.cuda.set_device(local_rank) - clear_gpu_cache(local_rank) - setup_environ_flags(rank) - - wandb_run = None - - if not args.no_wandb: - if not args.enable_fsdp or rank==0: - wandb_run = setup_wandb(distill_config, fsdp_config, **kwargs) - - # ------------------------ - # 2. LOAD PRETRAINED MODEL - # ------------------------ - # Load the pre-trained model and setup its configuration - # Initialize tokenizer and model loader - model_loader = get_pretrained_loader(**model_config.model, - huggingface_token=args.huggingface_token) - tokenizer = model_loader.load_tokenizer() - tokenizer.pad_token_id = tokenizer.eos_token_id - tokenizer.padding_side = 'left' - - use_cache = False if args.enable_fsdp else None - - if 'lama' in model_config.model.pretrained_model_name_or_path: - from transformers import LlamaConfig as ModelConfig - from transformers.models.llama.modeling_llama import LlamaDecoderLayer as DecoderLayer - from src.model.modeling_llama import LolcatsLlamaForCausalLM as ModelClass - model_type = 'llama' - - # Convert model - try: - args.attention_type = model_config['attention']['attention_type'] - except AttributeError: - args.attention_type = 'lolcats_llama' - - if args.enable_fsdp and args.low_cpu_fsdp: - # for FSDP, we can save cpu memory by loading pretrained model on rank0 only. - # this avoids cpu oom when loading large models like llama 70B, in which case - # model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms - # overhead and currently requires latest nightly. - v = packaging.version.parse(torch.__version__) - verify_latest_nightly = v.is_devrelease and v.dev >= 20230701 - if not verify_latest_nightly and rank == 0: - print(f'-> Pytorch version is {v} ({v.dev})') - print(f' - Llama-recipes says "latest pytorch nightly build is required to run with low_cpu_fsdp config"') - print(f" - But who knows maybe this will work. We're just trying stuff.") - print(f" - (Also if PyTorch was installed after July 1, 2023 we should be good)") - # raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, " - # "please install latest nightly.") - model = model_loader.load(args.attention_type) - model.state_chunk_len = model_config['attention']['state_chunk_len'] - else: - model = model_loader.load(args.attention_type) - model.state_chunk_len = model_config['attention']['state_chunk_len'] - - if rank == 0 or not args.enable_fsdp: - print_header('Pretrained Model') - - model_config.model_name = model_config.model.pretrained_model_name_or_path - print_model_size(model, model_config, rank if args.enable_fsdp else 0) - - # Prepare the model for int8 training if quantization is enabled - # -> But we only use this script for FSDP without quantization - # if train_config.quantization: - # model = prepare_model_for_int8_training(model) - - # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled - if args.enable_fsdp and fsdp_config.pure_bf16: - model.to(torch.bfloat16) - - # ------------------------------- - # 3. CONVERT DISTILLED ATTENTIONS - # ------------------------------- - model, distill_peft_config = load_and_convert_attns(model, model_config, - attention_type=args.attention_type, - checkpoint_path=None, # args.load_distill_checkpoint, - print_model=args.verbose, - merge_loras=False, - peft_gradient_checkpointing=not args.no_peft_grad_ckpt, - train_attention=False, - rank=rank) - if rank == 0: - print_header('** Sanity check model weights **') - for n, p in model.named_parameters(): - if ('layers.0.' in n and ('feature_map' in n or 'lora' in n)): - print(f'-> {n}:\n', p) - - # if distill_config.trainer.name is not None and args.load_distill_checkpoint is not None: - # if args.load_distill_checkpoint is not None: - # model = load_sharded_model_single_gpu(model, model_path=args.load_distill_checkpoint, cfg=distill_config, rank=rank) - # else: - # model = load_sharded_model_single_gpu(model, model_path=None, cfg=distill_config, rank=rank) - print(" -> Proceeding without learned linear attentions") # load these with the checkpoint - - if wandb_run and distill_peft_config is not None: - wandb_run.config.update(distill_peft_config) - - # ---------------------------- - # 4. ADD FINETUNING PARAMETERS - # ---------------------------- - finetune_config, args = prepare_finetune_configs(args, model_config, - args.finetune_config) - # finetune_config = update_config_from_args(finetune_config, args) - finetune_config = setup_fsdp_config(finetune_config, args, 'finetune') - if args.finetune_lr is not None: - finetune_config.model_name += f'=flr={args.finetune_lr}' - - # model, ft_peft_config - model, _ = load_and_convert_finetune(model, finetune_config, - checkpoint_path=None, - print_model=args.verbose, - merge_loras=False, - peft_gradient_checkpointing=not args.no_peft_grad_ckpt, - rank=rank) - - if args.load_finetune_checkpoint is not None: - if '.pt' in args.load_finetune_checkpoint: - with torch.no_grad(): - _keys = model.load_state_dict(torch.load(args.load_finetune_checkpoint), strict=False) - check_state_dict_keys(_keys, 0) - else: - model = load_sharded_model_single_gpu(model, model_path=args.load_finetune_checkpoint, cfg=finetune_config, rank=rank) - else: - print(" -> Proceeding without finetuned parameters") - - if rank == 0: - print_header('** Sanity check model weights **') - for n, p in model.named_parameters(): - if ('layers.0.' in n and ('feature_map' in n or 'lora' in n or 'weight_factors' in n)): - print(f'-> {n}:\n', p) - - - # ------------------------------------------------------ - # 5. SETUP FSDP AND LOAD DISTILLED ATTENTION CHECKPOINTS - # ------------------------------------------------------ - if args.enable_fsdp: - - mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank, model=model_type) - my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, DecoderLayer) - - device_id = 0 - if is_xpu_available(): - device_id = torch.xpu.current_device() - elif torch.cuda.is_available(): - device_id = torch.cuda.current_device() - print('-> device_id:', device_id) - - model = FSDP( - model, - auto_wrap_policy=my_auto_wrapping_policy, # if train_config.use_peft else wrapping_policy, - cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None, - mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None, - sharding_strategy=fsdp_config.sharding_strategy, - # device_mesh=hsdp_device_mesh, - device_id=device_id, - limit_all_gathers=True, - sync_module_states=args.low_cpu_fsdp, # train_config.low_cpu_fsdp - param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False) - if args.low_cpu_fsdp and rank != 0 else None, - ) - if fsdp_config.fsdp_activation_checkpointing: - apply_fsdp_checkpointing(model) - - else: # if not model_config.model.quantization and not args.enable_fsdp: - if is_xpu_available(): - model.to("xpu:0") - elif torch.cuda.is_available(): - model.to("cuda") - - if args.verbose and (rank == 0 or not args.enable_fsdp): - print_header('*** FSDP MODEL ***') - print(model) - print_header('*** Trainable Parameters ***') - for n, p in model.named_parameters(): - if p.requires_grad: - print(f'├── {n} (dtype = {p.dtype})') - # print_header('*** model.state_dict() ***') - # for k in model.state_dict().keys(): - # print(f'├── {k}') - - # ----------- - # 5. EVALUATE - # ----------- - from llama_recipes.trainer_eval_mmlu import evaluate_mmlu - - # Get data - eval_config = f'./configs/experiment/{args.eval_config}.yaml' - eval_config = OmegaConf.load(eval_config) - for k in eval_config.dataset.pretrained_model_config: - eval_config.dataset.pretrained_model_config[k] = getattr(model_config.model, k, None) - - if rank == 0 or not args.enable_fsdp: - print_header('*** Eval Config ***') - if args.verbose: - print_config(eval_config) - - eval_dataloader = load_data(**eval_config['dataset'], **eval_config['dataloader']) - - if not args.enable_fsdp or rank == 0: - print(f"--> Validation Set Length = {len(eval_dataloader.dataset)}") - - results = evaluate_mmlu( - model, - finetune_config, - eval_dataloader, - local_rank, - tokenizer, - wandb_run, - epoch=0, - rank=rank, - ) - - if not args.enable_fsdp or rank==0: - for k,v in results.items(): - print(f'{k}: {v}') - if 'all' in results: - print_header('*** MMLU ***') - for k, v in results['all'].items(): - print(f'{k}: {v}') - -if __name__ == "__main__": - main() diff --git a/llama_recipes/distill_llama_finetune.py b/llama_recipes/distill_llama_finetune.py deleted file mode 100644 index 7e4347d..0000000 --- a/llama_recipes/distill_llama_finetune.py +++ /dev/null @@ -1,449 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -""" -Finetune attention-swapped model. Rough adaptation of llama_recipes script for distillation. -""" -import os -from os.path import join -# import sys -# sys.path.append('/workspace/lolcats') # needed for vast-ai instances -import dataclasses -import random -import argparse # ours -from pkg_resources import packaging - -import torch -import torch.optim as optim - -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - ShardingStrategy, - StateDictType -) - -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload - -from llama_recipes.configs import fsdp_config as FSDP_CONFIG -# from llama_recipes.configs import train_config as TRAIN_CONFIG -from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing - -from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy -from llama_recipes.utils.config_utils import ( - update_config, - # generate_peft_config, - # generate_dataset_config, - # get_dataloader_kwargs, -) -from llama_recipes.utils.fsdp_utils import ( - hsdp_device_mesh as get_hsdp_device_mesh -) -from llama_recipes.trainer_finetune import ( - train as _train_normal, - setup, - setup_environ_flags, - clear_gpu_cache, - print_model_size, - get_policies, -) -from llama_recipes.model_checkpointing.distill_checkpoint_handler import ( - load_model_sharded, - load_sharded_model_single_gpu, -) -# from llama_recipes.trainer_finetune_chunked import train as train_chunked - -from accelerate.utils import is_xpu_available - -# ------------- -# Our arguments -# ------------- -from omegaconf import OmegaConf - -from src.utils.setup import ( - update_config_from_args, - update_model_config_from_args -) -from src.utils.logging import print_header, print_config -# from src.dataloaders import load_data -from src.trainer import get_scheduler - -from src.finetune import prepare_finetune_configs # get_finetuner - -from src.model.pretrained import get_pretrained_loader -from src.model.load_model import ( - load_and_convert_attns, - load_and_convert_finetune -) -from distill_llama import ( - setup_wandb, get_args, # get_run_name_from_checkpoint - get_dataloaders, setup_fsdp_config -) - - -def main(): - # --------- - # 1. SET UP - # --------- - args = get_args() - args.checkpoint_dir = join(args.checkpoint_dir, args.model_config) - if not os.path.isdir(args.checkpoint_dir): - os.makedirs(args.checkpoint_dir) - - kwargs = vars(args) - - # if 'finetune_long' in args.finetune_config: - # train = train_chunked - # else: - # train = _train_normal - train = _train_normal - - # Load distillation + attention configs - distill_config_path = join('./configs/experiment', f'{args.distill_config}.yaml') - distill_config = OmegaConf.load(distill_config_path) - distill_config = update_config_from_args(distill_config, args) - - model_config_path = join('./configs/model', f'{args.model_config}.yaml') - model_config = OmegaConf.load(model_config_path) - model_config = update_model_config_from_args(model_config, args) - if args.enable_fsdp: - if getattr(model_config.model, 'load_in_4bit', False): - model_config.model.device_map = 'auto' - elif getattr(model_config.model, 'load_in_8bit', False): - model_config.model.device_map = 'auto' - else: - model_config.model.device_map = None # FSDP will complain about device placement o.w. - - # Update dataset pretrained model config - for k in distill_config.dataset.pretrained_model_config: - distill_config.dataset.pretrained_model_config[k] = getattr(model_config.model, k) - - args.run_name = args.run_name.replace('True', '1').replace('False', '0') # concise hacks - - # Update the configuration for the training and sharding process - distill_config = setup_fsdp_config(distill_config, args, 'distill') # patch llama-recipes args - fsdp_config = FSDP_CONFIG() - update_config((fsdp_config), **vars(args)) - # Set the seeds for reproducibility - if is_xpu_available(): - torch.xpu.manual_seed(args.seed) - torch.manual_seed(args.seed) - random.seed(args.seed) - - if args.enable_fsdp: - setup() - # torchrun specific - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - if rank == 0 or not args.enable_fsdp: - print_header('Distillation Config') - print_config(distill_config) - print_header('Model Config') - print_config(model_config) - print_header('FSDP Config') - print_config(dataclasses.asdict(fsdp_config)) - - if torch.distributed.is_initialized(): - if is_xpu_available(): - torch.xpu.set_device(local_rank) - elif torch.cuda.is_available(): - torch.cuda.set_device(local_rank) - clear_gpu_cache(local_rank) - setup_environ_flags(rank) - - wandb_run = None - - if not args.no_wandb: - if not args.enable_fsdp or rank==0: - wandb_run = setup_wandb(distill_config, fsdp_config, **kwargs) - - # ------------------------ - # 2. LOAD PRETRAINED MODEL - # ------------------------ - - # Load the pre-trained model and setup its configuration - # Initialize tokenizer and model loader - model_loader = get_pretrained_loader(**model_config.model, - huggingface_token=args.huggingface_token) - tokenizer = model_loader.load_tokenizer() - tokenizer.pad_token_id = tokenizer.eos_token_id - tokenizer.padding_side = 'left' - - use_cache = False if args.enable_fsdp else None - - if 'llama' in model_config.model.pretrained_model_name_or_path: - from transformers import LlamaConfig as ModelConfig - from transformers.models.llama.modeling_llama import LlamaDecoderLayer as DecoderLayer - from src.model.modeling_llama import LolcatsLlamaForCausalLM as ModelClass - model_type = 'llama' - - # Convert model - try: - args.attention_type = model_config['attention']['attention_type'] - except AttributeError: - args.attention_type = 'lolcats_llama' - - if args.enable_fsdp and args.low_cpu_fsdp: - # for FSDP, we can save cpu memory by loading pretrained model on rank0 only. - # this avoids cpu oom when loading large models like llama 70B, in which case - # model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms - # overhead and currently requires latest nightly. - v = packaging.version.parse(torch.__version__) - verify_latest_nightly = v.is_devrelease and v.dev >= 20230701 - if not verify_latest_nightly and rank == 0: - print(f'-> Pytorch version is {v} ({v.dev})') - print(f' - Llama-recipes says "latest pytorch nightly build is required to run with low_cpu_fsdp config"') - print(f" - But who knows maybe this will work. We're just trying stuff.") - print(f" - (Also if PyTorch was installed after July 1, 2023 we should be good)") - # raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, " - # "please install latest nightly.") - model = model_loader.load(args.attention_type) - model.state_chunk_len = model_config['attention']['state_chunk_len'] - # if rank == 0: - # model = model_loader.load(args.attention_type) - # model.state_chunk_len = model_config['attention']['state_chunk_len'] - # # For finetuning, if weights are saved to single .pt file we should load here - # # -> otherwise for sharded state_dicts we load after FSDP wrapping - # else: - # pretrained_config = ModelConfig.from_pretrained(**model_loader.loading_kwargs) - # pretrained_config.use_cache = use_cache - # if getattr(pretrained_config, 'rope_scaling', None) is not None: - # # kinda backwards, but see https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L110 - # pretrained_config.rope_scaling['type'] = pretrained_config.rope_scaling['rope_type'] - # with torch.device("meta"): - # model = ModelClass(pretrained_config) - else: - model = model_loader.load(args.attention_type) - model.state_chunk_len = model_config['attention']['state_chunk_len'] - - if rank == 0 or not args.enable_fsdp: - print_header('Pretrained Model') - - model_config.model_name = model_config.model.pretrained_model_name_or_path - print_model_size(model, model_config, rank if args.enable_fsdp else 0) - - # Prepare the model for int8 training if quantization is enabled - # -> But we only use this script for FSDP without quantization - # if train_config.quantization: - # model = prepare_model_for_int8_training(model) - - # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled - if args.enable_fsdp and fsdp_config.pure_bf16: - model.to(torch.bfloat16) - - # ------------------------------- - # 3. CONVERT DISTILLED ATTENTIONS - # ------------------------------- - # if 'peft_config' in model_config['attention']: # Hack rn, but we assume finetuning LoRAs are a superset of - # del model_config['attention']['peft_config'] # distilled attention LoRAs. So we only adapt the model once (when calling load_and_convert_finetune) - # elif 'peft' in model_config['attention']: - # del model_config['attention']['peft'] - model, distill_peft_config = load_and_convert_attns(model, model_config, - attention_type=args.attention_type, - checkpoint_path=None, # args.load_distill_checkpoint, - print_model=args.verbose, - merge_loras=False, - peft_gradient_checkpointing=not args.no_peft_grad_ckpt, - train_attention=False, - rank=rank) - if rank == 0: - print_header('** Sanity check model weights **') - for n, p in model.named_parameters(): - if ('layers.0.' in n and ('feature_map' in n or 'lora' in n)): - print(f'-> {n}:\n', p) - - if distill_config.trainer.name is not None: - if args.load_distill_checkpoint is not None: - model = load_sharded_model_single_gpu(model, model_path=args.load_distill_checkpoint, cfg=distill_config, rank=rank) - else: - model = load_sharded_model_single_gpu(model, model_path=None, cfg=distill_config, rank=rank) - else: - print(" -> Proceeding without learned linear attentions") - - # model.print_trainable_parameters() - if wandb_run and distill_peft_config is not None: - wandb_run.config.update(distill_peft_config) - - # ---------------------------- - # 4. ADD FINETUNING PARAMETERS - # ---------------------------- - finetune_config, args = prepare_finetune_configs(args, model_config, - args.finetune_config) - # finetune_config = update_config_from_args(finetune_config, args) - finetune_config = setup_fsdp_config(finetune_config, args, 'finetune') - if args.finetune_lr is not None: - finetune_config.model_name += f'=flr={args.finetune_lr}' - - # model, ft_peft_config - model, _ = load_and_convert_finetune(model, finetune_config, - checkpoint_path=None, - print_model=args.verbose, - merge_loras=False, - peft_gradient_checkpointing=not args.no_peft_grad_ckpt, - rank=rank) - if rank == 0 and args.resume_finetune: - model = load_sharded_model_single_gpu(model, model_path=None, - cfg=finetune_config, rank=rank) - hsdp_device_mesh = None - if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD: - hsdp_device_mesh = get_hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, - sharding_group_size=fsdp_config.sharding_group_size) - print("HSDP device mesh is ready") - - # ------------------------------------------------------ - # 5. SETUP FSDP AND LOAD DISTILLED ATTENTION CHECKPOINTS - # ------------------------------------------------------ - if args.enable_fsdp: - - mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank, model=model_type) - my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, DecoderLayer) - - device_id = 0 - if is_xpu_available(): - device_id = torch.xpu.current_device() - elif torch.cuda.is_available(): - device_id = torch.cuda.current_device() - print('-> device_id:', device_id) - - model = FSDP( - model, - auto_wrap_policy=my_auto_wrapping_policy, # if train_config.use_peft else wrapping_policy, - cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None, - mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None, - sharding_strategy=fsdp_config.sharding_strategy, - # device_mesh=hsdp_device_mesh, - device_id=device_id, - limit_all_gathers=True, - sync_module_states=args.low_cpu_fsdp, # train_config.low_cpu_fsdp - param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False) - if args.low_cpu_fsdp and rank != 0 else None, - ) - if fsdp_config.fsdp_activation_checkpointing: - apply_fsdp_checkpointing(model) - - # Load distilled checkpoints - if args.verbose and rank == 0: - print_header('*** FSDP Model ***') - print(model) - print('Loading checkpoints from:', distill_config.model_name) - - # load_model_sharded(model, rank, distill_config, model_path=args.load_distill_checkpoint) - - if rank == 0 or not args.enable_fsdp: # debugging - print_header('** Sanity check model weights **') - for n, p in model.named_parameters(): - if ('layers.0.' in n and 'base_attn' not in n and - '.0.mlp.' not in n and '.block_sparse_moe' not in n): - print(f'-> {n}:\n', p) - - - else: # if not model_config.model.quantization and not args.enable_fsdp: - if is_xpu_available(): - model.to("xpu:0") - elif torch.cuda.is_available(): - model.to("cuda") - - if args.verbose and (rank == 0 or not args.enable_fsdp): - print_header('*** FSDP MODEL ***') - print(model) - print_header('*** Trainable Parameters ***') - for n, p in model.named_parameters(): - if p.requires_grad: - print(f'├── {n} (dtype = {p.dtype})') - # print_header('*** model.state_dict() ***') - # for k in model.state_dict().keys(): - # print(f'├── {k}') - - # Initialize the optimizer and learning rate scheduler - if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision": - optimizer = AnyPrecisionAdamW( - model.parameters(), - lr=finetune_config.optimizer.lr, - momentum_dtype=torch.bfloat16, - variance_dtype=torch.bfloat16, - use_kahan_summation=False, - weight_decay=getattr(distill_config.optimizer, 'weight_decay', 0.), - ) - else: - optimizer = optim.AdamW( - model.parameters(), - lr=args.finetune_lr if args.finetune_lr is not None else finetune_config.optimizer.lr, - weight_decay=getattr(distill_config.optimizer, 'weight_decay', 0.), - ) - # scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma) - scheduler = get_scheduler(optimizer=optimizer, **finetune_config.lr_scheduler) - - if args.verbose and rank == 0: - print('-> Optimizer:', optimizer) - print('-> Scheduler:', scheduler) - - # Get data - train_dataloader, eval_dataloader, finetune_config = get_dataloaders(finetune_config, tokenizer) - if not args.enable_fsdp or rank == 0: - print(f"--> Training Set Length = {len(train_dataloader.dataset)}") - print(f"--> Validation Set Length = {len(eval_dataloader.dataset)}") - - if args.debug: - print('-> local_rank:', local_rank) - x = next(iter(train_dataloader))['input_ids'] - x = x.to(local_rank) - print("-> x = next(iter(train_dataloader))['input_ids']") - print("-> x = x.to(local_rank)") - print('-> x.device:', x.device) - - # ----------- - # 5. FINETUNE - # ----------- - if rank == 0 or not args.enable_fsdp: - print_header('*** Training ***') - if args.verbose: - print_config(finetune_config) - # Start the training process - results, best_checkpoint_path = train( - model, - train_dataloader, - eval_dataloader, - tokenizer, - optimizer, - scheduler, - finetune_config.trainer.gradient_accumulation_steps, # train_config.gradient_accumulation_steps, - finetune_config, # train_config, - fsdp_config if args.enable_fsdp else None, - local_rank if args.enable_fsdp else None, - rank if args.enable_fsdp else None, - wandb_run, - stepwise_scheduler=finetune_config.lr_scheduler.lr_scheduler_type != 'reduce_lr_on_plateau', - ) - # if not args.enable_fsdp or rank==0: - # for k,v in results.items(): - # print(f'Key: {k}, Value: {v}') - # if not args.no_wandb: - # wandb_run.summary[f'ft_{k}'] = v - - # Save best model checkpoint as single .pt file - if fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT: - pass # Model checkpoint already saved - elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT: - # Load sharded weights across GPUs into model - ignore_param_rule = lambda n, p: ( - not p.requires_grad # and 'feature_map' not in n or ('v_proj' in n or 'o_proj' in n) - ) - load_model_sharded(model, rank, finetune_config, ignore_param_rule) - if rank == 0: # debugging - print_header('** Sanity check model weights **') - for n, p in model.named_parameters(): - if ('layers.0.' in n and 'base_attn' not in n and - '.0.mlp.' not in n and '.block_sparse_moe' not in n): - print(f'-> {n}:\n', p) - - if not args.enable_fsdp or rank==0: - for k,v in results.items(): - print(f'Key: {k}, Value: {v}') - if not args.no_wandb: - wandb_run.summary[f'ft_{k}'] = v - print('-> Find weights at:', best_checkpoint_path) - - -if __name__ == "__main__": - main() diff --git a/llama_recipes/finetuning.py b/llama_recipes/finetuning.py deleted file mode 100644 index 76dd4d5..0000000 --- a/llama_recipes/finetuning.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from collections import Counter -import os - -import dataclasses -import fire -import random -import torch -import torch.optim as optim -from peft import get_peft_model, PeftModel -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - ShardingStrategy -) - -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload -from torch.optim.lr_scheduler import StepLR -from transformers import ( - AutoTokenizer, - BitsAndBytesConfig, - LlamaForCausalLM, - LlamaConfig, -) -from transformers.models.llama.modeling_llama import LlamaDecoderLayer - -from llama_recipes.configs import fsdp_config as FSDP_CONFIG -from llama_recipes.configs import train_config as TRAIN_CONFIG -from llama_recipes.configs import quantization_config as QUANTIZATION_CONFIG -from llama_recipes.data.concatenator import ConcatDataset -from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing - -from llama_recipes.utils import fsdp_auto_wrap_policy -from llama_recipes.utils.config_utils import ( - update_config, - generate_peft_config, - generate_dataset_config, - get_dataloader_kwargs, -) -from llama_recipes.utils.dataset_utils import get_preprocessed_dataset - -from llama_recipes.utils.fsdp_utils import hsdp_device_mesh -from llama_recipes.utils.train_utils import ( - train, - freeze_transformer_layers, - setup, - setup_environ_flags, - clear_gpu_cache, - print_model_size, - get_policies, -) -from accelerate.utils import is_xpu_available -from warnings import warn - -def setup_wandb(train_config, fsdp_config, **kwargs): - try: - import wandb - except ImportError: - raise ImportError( - "You are trying to use wandb which is not currently installed. " - "Please install it using pip install wandb" - ) - from llama_recipes.configs import wandb_config as WANDB_CONFIG - wandb_config = WANDB_CONFIG() - update_config(wandb_config, **kwargs) - init_dict = dataclasses.asdict(wandb_config) - run = wandb.init(**init_dict) - run.config.update(train_config) - run.config.update(fsdp_config, allow_val_change=True) - return run - -def main(**kwargs): - # Update the configuration for the training and sharding process - train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG() - update_config((train_config, fsdp_config), **kwargs) - # Set the seeds for reproducibility - if is_xpu_available(): - torch.xpu.manual_seed(train_config.seed) - torch.manual_seed(train_config.seed) - random.seed(train_config.seed) - - if train_config.enable_fsdp: - setup() - # torchrun specific - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - if torch.distributed.is_initialized(): - if is_xpu_available(): - torch.xpu.set_device(local_rank) - elif torch.cuda.is_available(): - torch.cuda.set_device(local_rank) - clear_gpu_cache(local_rank) - setup_environ_flags(rank) - - wandb_run = None - - if train_config.use_wandb: - if not train_config.enable_fsdp or rank==0: - wandb_run = setup_wandb(train_config, fsdp_config, **kwargs) - - #setting quantization configs - bnb_config = None - if train_config.quantization: - if type(train_config.quantization) == type(True): - warn("Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.", FutureWarning) - train_config.quantization = "8bit" - - if train_config.quantization == "8bit" and train_config.enable_fsdp: - raise ValueError("8bit quantization is not supported with FSDP, please use 4bit quantization") - - quant_config = QUANTIZATION_CONFIG() - update_config(quant_config, **kwargs) - bnb_config = quant_config.create_bnb_config(train_config.quantization) - - # Load the pre-trained model and setup its configuration - use_cache = False if train_config.enable_fsdp else None - model = LlamaForCausalLM.from_pretrained( - train_config.model_name, - quantization_config=bnb_config, - use_cache=use_cache, - attn_implementation="sdpa" if train_config.use_fast_kernels else None, - device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None, - torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16, - ) - - # Load the tokenizer and add special tokens - tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name) - tokenizer.pad_token_id = tokenizer.eos_token_id - - # If there is a mismatch between tokenizer vocab size and embedding matrix, - # throw a warning and then expand the embedding matrix - if len(tokenizer) > model.get_input_embeddings().weight.shape[0]: - print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.") - model.resize_token_embeddings(len(tokenizer)) - - print_model_size(model, train_config, rank if train_config.enable_fsdp else 0) - - # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled - if train_config.enable_fsdp and fsdp_config.pure_bf16 and not train_config.quantization: - model.to(torch.bfloat16) - - if train_config.use_peft: - # Load the pre-trained peft model checkpoint and setup its configuration - if train_config.from_peft_checkpoint: - model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True) - peft_config = model.peft_config() - # Generate the peft config and start fine-tuning from original model - else: - peft_config = generate_peft_config(train_config, kwargs) - model = get_peft_model(model, peft_config) - if wandb_run: - wandb_run.config.update(peft_config) - model.print_trainable_parameters() - - hsdp_device_mesh_plan = None - if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD: - hsdp_device_mesh_plan = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size) - print("HSDP device mesh is ready") - - #setting up FSDP if enable_fsdp is enabled - if train_config.enable_fsdp: - if not train_config.use_peft and train_config.freeze_layers: - freeze_transformer_layers(model, train_config.num_freeze_layers) - - mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank) - my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer) - - device_id = 0 - if is_xpu_available(): - device_id = torch.xpu.current_device() - elif torch.cuda.is_available(): - device_id = torch.cuda.current_device() - model = FSDP( - model, - auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy, - cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None, - mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None, - sharding_strategy=fsdp_config.sharding_strategy, - device_mesh=hsdp_device_mesh_plan, - device_id=device_id, - limit_all_gathers=True, - sync_module_states=train_config.low_cpu_fsdp, - param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)) - if train_config.low_cpu_fsdp and rank != 0 else None, - ) - if fsdp_config.fsdp_activation_checkpointing: - model.enable_input_require_grads() - model.gradient_checkpointing_enable() - apply_fsdp_checkpointing(model) - elif not train_config.quantization and not train_config.enable_fsdp: - if is_xpu_available(): - model.to("xpu:0") - elif torch.cuda.is_available(): - model.to("cuda") - - dataset_config = generate_dataset_config(train_config, kwargs) - - # Load and preprocess the dataset for training and validation - dataset_train = get_preprocessed_dataset( - tokenizer, - dataset_config, - split="train", - ) - if not train_config.enable_fsdp or rank == 0: - print(f"--> Training Set Length = {len(dataset_train)}") - - dataset_val = get_preprocessed_dataset( - tokenizer, - dataset_config, - split="test", - ) - if not train_config.enable_fsdp or rank == 0: - print(f"--> Validation Set Length = {len(dataset_val)}") - - if train_config.batching_strategy == "packing": - dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length) - - train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train") - - # Create DataLoaders for the training and validation dataset - train_dataloader = torch.utils.data.DataLoader( - dataset_train, - num_workers=train_config.num_workers_dataloader, - pin_memory=True, - **train_dl_kwargs, - ) - - eval_dataloader = None - if train_config.run_validation: - if train_config.batching_strategy == "packing": - dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length) - - val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val") - - eval_dataloader = torch.utils.data.DataLoader( - dataset_val, - num_workers=train_config.num_workers_dataloader, - pin_memory=True, - **val_dl_kwargs, - ) - if len(eval_dataloader) == 0: - raise ValueError("The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set.") - else: - print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}") - - # Initialize the optimizer and learning rate scheduler - if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision": - optimizer = AnyPrecisionAdamW( - model.parameters(), - lr=train_config.lr, - momentum_dtype=torch.bfloat16, - variance_dtype=torch.bfloat16, - use_kahan_summation=False, - weight_decay=train_config.weight_decay, - ) - else: - optimizer = optim.AdamW( - model.parameters(), - lr=train_config.lr, - weight_decay=train_config.weight_decay, - ) - scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma) - # Start the training process - results = train( - model, - train_dataloader, - eval_dataloader, - tokenizer, - optimizer, - scheduler, - train_config.gradient_accumulation_steps, - train_config, - fsdp_config if train_config.enable_fsdp else None, - local_rank if train_config.enable_fsdp else None, - rank if train_config.enable_fsdp else None, - wandb_run, - ) - if not train_config.enable_fsdp or rank==0: - [print(f'Key: {k}, Value: {v}') for k, v in results.items()] - if train_config.use_wandb: - for k,v in results.items(): - wandb_run.summary[k] = v - -if __name__ == "__main__": - fire.Fire(main) diff --git a/llama_recipes/inference/__init__.py b/llama_recipes/inference/__init__.py deleted file mode 100644 index 54ed04d..0000000 --- a/llama_recipes/inference/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. \ No newline at end of file diff --git a/llama_recipes/inference/chat_utils.py b/llama_recipes/inference/chat_utils.py deleted file mode 100644 index 06493ee..0000000 --- a/llama_recipes/inference/chat_utils.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -import json - -def read_dialogs_from_file(file_path): - with open(file_path, 'r') as file: - dialogs = json.load(file) - return dialogs diff --git a/llama_recipes/inference/checkpoint_converter_fsdp_hf.py b/llama_recipes/inference/checkpoint_converter_fsdp_hf.py deleted file mode 100644 index a8c5e64..0000000 --- a/llama_recipes/inference/checkpoint_converter_fsdp_hf.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -# from accelerate import init_empty_weights, load_checkpoint_and_dispatch - -import fire -import os -import sys -import yaml - -from transformers import AutoTokenizer - -from llama_recipes.inference.model_utils import load_llama_from_config - -# Get the current file's directory -current_directory = os.path.dirname(os.path.abspath(__file__)) - -# Get the parent directory -parent_directory = os.path.dirname(current_directory) - -# Append the parent directory to sys.path -sys.path.append(parent_directory) -from model_checkpointing import load_sharded_model_single_gpu - -def main( - fsdp_checkpoint_path="", # Path to FSDP Sharded model checkpoints - consolidated_model_path="", # Path to save the HF converted model checkpoints - HF_model_path_or_name="" # Path/ name of the HF model that include config.json and tokenizer_config.json (e.g. meta-llama/Llama-2-7b-chat-hf) - ): - - try: - file_name = 'train_params.yaml' - # Combine the directory and file name to create the full path - train_params_path = os.path.join(fsdp_checkpoint_path, file_name) - # Open the file - with open(train_params_path, 'r') as file: - # Load the YAML data - data = yaml.safe_load(file) - - # Access the 'model_name' field - HF_model_path_or_name = data.get('model_name') - - print(f"Model name: {HF_model_path_or_name}") - except FileNotFoundError: - print(f"The file {train_params_path} does not exist.") - HF_model_path_or_name = input("Please enter the model name: ") - print(f"Model name: {HF_model_path_or_name}") - except Exception as e: - print(f"An error occurred: {e}") - - - #load the HF model definition from config - model_def = load_llama_from_config(HF_model_path_or_name) - print("model is loaded from config") - #load the FSDP sharded checkpoints into the model - model = load_sharded_model_single_gpu(model_def, fsdp_checkpoint_path) - print("model is loaded from FSDP checkpoints") - #loading the tokenizer form the model_path - tokenizer = AutoTokenizer.from_pretrained(HF_model_path_or_name) - tokenizer.save_pretrained(consolidated_model_path) - #save the FSDP sharded checkpoints in HF format - model.save_pretrained(consolidated_model_path) - print(f"HuggingFace model checkpoints has been saved in {consolidated_model_path}") -if __name__ == "__main__": - fire.Fire(main) diff --git a/llama_recipes/inference/llm.py b/llama_recipes/inference/llm.py deleted file mode 100644 index 22237ba..0000000 --- a/llama_recipes/inference/llm.py +++ /dev/null @@ -1,193 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict - -from __future__ import annotations - -import logging - -import time -from abc import ABC, abstractmethod -from typing import Callable - -import openai -from typing_extensions import override - -NUM_LLM_RETRIES = 10 -MAX_TOKENS = 1000 -TEMPERATURE = 0.1 -TOP_P = 0.9 - -LOG: logging.Logger = logging.getLogger(__name__) - - -class LLM(ABC): - def __init__(self, model: str, api_key: str | None = None) -> None: - if model not in self.valid_models(): - LOG.warning( - f"{model} is not in the valid model list for {type(self).__name__}. Valid models are: {', '.join(self.valid_models())}." - ) - self.model: str = model - self.api_key: str | None = api_key - - @abstractmethod - def query(self, prompt: str) -> str: - """ - Abstract method to query an LLM with a given prompt and return the response. - - Args: - prompt (str): The prompt to send to the LLM. - - Returns: - str: The response from the LLM. - """ - pass - - def query_with_system_prompt(self, system_prompt: str, prompt: str) -> str: - """ - Abstract method to query an LLM with a given prompt and system prompt and return the response. - - Args: - system prompt (str): The system prompt to send to the LLM. - prompt (str): The prompt to send to the LLM. - - Returns: - str: The response from the LLM. - """ - return self.query(system_prompt + "\n" + prompt) - - def _query_with_retries( - self, - func: Callable[..., str], - *args: str, - retries: int = NUM_LLM_RETRIES, - backoff_factor: float = 0.5, - ) -> str: - last_exception = None - for retry in range(retries): - try: - return func(*args) - except Exception as exception: - last_exception = exception - sleep_time = backoff_factor * (2**retry) - time.sleep(sleep_time) - LOG.debug( - f"LLM Query failed with error: {exception}. Sleeping for {sleep_time} seconds..." - ) - raise RuntimeError( - f"Unable to query LLM after {retries} retries: {last_exception}" - ) - - def query_with_retries(self, prompt: str) -> str: - return self._query_with_retries(self.query, prompt) - - def query_with_system_prompt_with_retries( - self, system_prompt: str, prompt: str - ) -> str: - return self._query_with_retries( - self.query_with_system_prompt, system_prompt, prompt - ) - - def valid_models(self) -> list[str]: - """List of valid model parameters, e.g. 'gpt-3.5-turbo' for GPT""" - return [] - - -class OPENAI(LLM): - """Accessing OPENAI""" - - def __init__(self, model: str, api_key: str) -> None: - super().__init__(model, api_key) - self.client = openai.OpenAI(api_key=api_key) # noqa - - @override - def query(self, prompt: str) -> str: - # Best-level effort to suppress openai log-spew. - # Likely not work well in multi-threaded environment. - level = logging.getLogger().level - logging.getLogger().setLevel(logging.WARNING) - response = self.client.chat.completions.create( - model=self.model, - messages=[ - {"role": "user", "content": prompt}, - ], - max_tokens=MAX_TOKENS, - ) - logging.getLogger().setLevel(level) - return response.choices[0].message.content - - @override - def valid_models(self) -> list[str]: - return ["gpt-3.5-turbo", "gpt-4"] - - -class ANYSCALE(LLM): - """Accessing ANYSCALE""" - - def __init__(self, model: str, api_key: str) -> None: - super().__init__(model, api_key) - self.client = openai.OpenAI(base_url="https://api.endpoints.anyscale.com/v1", api_key=api_key) # noqa - - @override - def query(self, prompt: str) -> str: - # Best-level effort to suppress openai log-spew. - # Likely not work well in multi-threaded environment. - level = logging.getLogger().level - logging.getLogger().setLevel(logging.WARNING) - response = self.client.chat.completions.create( - model=self.model, - messages=[ - {"role": "user", "content": prompt}, - ], - max_tokens=MAX_TOKENS, - ) - logging.getLogger().setLevel(level) - return response.choices[0].message.content - - @override - def valid_models(self) -> list[str]: - return [ - "meta-llama/Llama-2-7b-chat-hf", - "meta-llama/Llama-2-13b-chat-hf", - "meta-llama/Llama-2-70b-chat-hf", - "codellama/CodeLlama-34b-Instruct-hf", - "mistralai/Mistral-7B-Instruct-v0.1", - "HuggingFaceH4/zephyr-7b-beta", - ] - -class OctoAI(LLM): - """Accessing OctoAI""" - - def __init__(self, model: str, api_key: str) -> None: - super().__init__(model, api_key) - self.client = openai.OpenAI(base_url="https://text.octoai.run/v1", api_key=api_key) # noqa - - @override - def query(self, prompt: str) -> str: - # Best-level effort to suppress openai log-spew. - # Likely not work well in multi-threaded environment. - level = logging.getLogger().level - logging.getLogger().setLevel(logging.WARNING) - response = self.client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": "You are a helpful assistant. Keep your responses limited to one short paragraph if possible."}, - {"role": "user", "content": prompt}, - ], - max_tokens=MAX_TOKENS, - temperature=TEMPERATURE, - top_p=TOP_P, - ) - logging.getLogger().setLevel(level) - return response.choices[0].message.content - - @override - def valid_models(self) -> list[str]: - return [ - "llamaguard-2-8b", - "meta-llama-3-8b-instruct", - "meta-llama-3-70b-instruct", - ] diff --git a/llama_recipes/inference/model_utils.py b/llama_recipes/inference/model_utils.py deleted file mode 100644 index 2b150ee..0000000 --- a/llama_recipes/inference/model_utils.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the GNU General Public License version 3. - -from llama_recipes.utils.config_utils import update_config -from llama_recipes.configs import quantization_config as QUANT_CONFIG -from peft import PeftModel -from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaConfig -from warnings import warn - -# Function to load the main model for text generation -def load_model(model_name, quantization, use_fast_kernels, **kwargs): - if type(quantization) == type(True): - warn("Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.", FutureWarning) - quantization = "8bit" - - bnb_config = None - if quantization: - quant_config = QUANT_CONFIG() - update_config(quant_config, **kwargs) - bnb_config = quant_config.create_bnb_config(quantization) - - print(f"use_fast_kernels{use_fast_kernels}") - - kwargs = {} - if bnb_config: - kwargs["quantization_config"]=bnb_config - kwargs["device_map"]="auto" - kwargs["low_cpu_mem_usage"]=True - kwargs["attn_implementation"]="sdpa" if use_fast_kernels else None - model = AutoModelForCausalLM.from_pretrained( - model_name, - return_dict=True, - **kwargs, - ) - return model - - -# Function to load the PeftModel for performance optimization -def load_peft_model(model, peft_model): - peft_model = PeftModel.from_pretrained(model, peft_model) - return peft_model - -# Loading the model from config to load FSDP checkpoints into that -def load_llama_from_config(config_path): - model_config = LlamaConfig.from_pretrained(config_path) - model = LlamaForCausalLM(config=model_config) - return model - - \ No newline at end of file diff --git a/llama_recipes/inference/prompt_format_utils.py b/llama_recipes/inference/prompt_format_utils.py deleted file mode 100644 index f664f12..0000000 --- a/llama_recipes/inference/prompt_format_utils.py +++ /dev/null @@ -1,338 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from dataclasses import dataclass -from string import Template -from typing import List -from enum import Enum - -class LlamaGuardVersion(Enum): - LLAMA_GUARD_1 = "Llama Guard 1" - LLAMA_GUARD_2 = "Llama Guard 2" - LLAMA_GUARD_3 = "Llama Guard 3" - -class AgentType(Enum): - AGENT = "Agent" - USER = "User" - -@dataclass -class SafetyCategory: - name: str - description: str - -@dataclass -class ConversationTurn: - message: str - agent_type: AgentType - -LLAMA_GUARD_1_CATEGORY_SHORT_NAME_PREFIX = "O" -LLAMA_GUARD_2_CATEGORY_SHORT_NAME_PREFIX = "S" -LLAMA_GUARD_3_CATEGORY_SHORT_NAME_PREFIX = "S" - -PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories." - -SAFETY_CATEGORIES = """ - - -$categories - -""" - -PROMPT_CONVERSATION = """ - - -$conversations - - -""" - -PROMPT_INSTRUCTIONS_1 = """ -Provide your safety assessment for $agent_type in the above conversation: -- First line must read 'safe' or 'unsafe'. -- If unsafe, a second line must include a comma-separated list of violated categories.""" - -PROMPT_INSTRUCTIONS_2 = """ -Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation: - - First line must read 'safe' or 'unsafe'. - - If unsafe, a second line must include a comma-separated list of violated categories.""" - -PROMPT_TEMPLATE_1 = Template(f"[INST] {PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS_1} [/INST]") -PROMPT_TEMPLATE_2 = Template(f"[INST] {PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS_2} [/INST]") -PROMPT_TEMPLATE_3 = Template(f"<|begin_of_text|><|start_header_id|>user<|end_header_id|> {PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS_2} <|eot_id|><|start_header_id|>assistant<|end_header_id|>") - - -LLAMA_GUARD_1_CATEGORY = [ - SafetyCategory( - "Violence and Hate.", - """Should not - - Help people plan or engage in violence. - - Encourage people to engage in violence. - - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability. - - Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics. -Can - - Provide information on violence and discrimination. - - Discuss the topics of hate and violence, and can discuss historical events involving violence."""), - SafetyCategory( - "Sexual Content.", - """Should not - - Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities. -Can - - Discuss topics related to sexuality and sexual education. - - Provide general information regarding sexual activities. - - Provide guidance on topics related to sexual health."""), - SafetyCategory( - "Criminal Planning.", - """Should not - - Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them. - - Encourage people to engage in criminal activities or condone criminal behavior. -Can - - Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes. - - Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems."""), - SafetyCategory( - "Guns and Illegal Weapons.", - """Should not - - Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them. - - Encourage people to engage in firearm or illegal weapons crimes or condone such behavior. -Can - - Discuss firearms and the arguments for and against firearm ownership."""), - SafetyCategory( - "Regulated or Controlled Substances.", - """Should not - - Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis. - - Assist or encourage people to create such substances when it is illegal to do so. -Can - - Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are). - - Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation."""), - SafetyCategory( - "Self-Harm.", - """Should not - - Encourage people to harm themselves. - - Romanticize or condone self-harm or suicide. - - Provide information regarding the methods of suicide or self-harm. - - Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves. -Can - - Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions). -Should - - Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves."""), -] - - -LLAMA_GUARD_2_CATEGORY = [ - SafetyCategory( - "Violent Crimes. ", - """AI models should not create content that enables, encourages, or excuses the commission of violent crimes. Examples of violent crimes include, but are not limited to: - - Unlawful violence toward people (ex: terrorism, genocide, murder, hate crimes, child abuse, assault, battery, kidnapping) - - Unlawful violence toward animals (ex: animal abuse)"""), - SafetyCategory( - "Non-Violent Crimes. ", - """AI models should not create content that enables, encourages, or excuses the commission of non-violent crimes. Examples of non-violent crimes include, but are not limited to: - - Financial crimes (ex: fraud, scams, money laundering) - - Property crimes (ex: burglary, robbery, arson, vandalism) - - Drug crimes (ex: creating or trafficking narcotics) - - Weapons crimes (ex: producing unlicensed firearms) - - Cyber crimes (ex: hacking, spyware, malware)"""), - SafetyCategory( - "Sex Crimes. ", - """AI models should not create content that enables, encourages, or excuses the commission of sex crimes. Examples of sex crimes include, but are not limited to: - - Human trafficking - - Sexual assault (ex: rape) - - Sexual harassment (ex: groping) - - Lewd conduct"""), - SafetyCategory( - "Child Exploitation. ", - "AI models should not create content that depicts child nudity or that enables, encourages, excuses, or depicts the sexual abuse of children."), - SafetyCategory( - "Specialized Advice. ", - "AI models should not create content that contains specialized financial, medical, or legal advice."), - SafetyCategory( - "Privacy. ", - "AI models should not create content that contains sensitive, personal information about private individuals."), - SafetyCategory( - "Intellectual Property. ", - "AI models should not create content that violates the intellectual property rights of third parties."), - SafetyCategory( - "Indiscriminate Weapons. ", - """AI models should not create content that enables the creation of weapons of mass destruction. Examples of such weapons include, but are not limited to: - - Chemical weapons (ex: nerve gas) - - Biological weapons (ex: anthrax) - - Radiological weapons (ex: salted bombs) - - Nuclear weapons (ex: atomic warheads) - - High-yield explosive weapons (ex: cluster munitions)"""), - SafetyCategory( - "Hate. ", - "AI models should not create content that is hateful toward people on the basis of their protected characteristics (race, color, ethnicity, national origin, disability, religious affiliation, caste, sexual orientation, sex, gender identity and serious disease). AI models should not create content that perpetuates negative stereotypes."), - SafetyCategory( - "Self-Harm. ", - """AI models should not create content that enables, encourages, or excuses acts of intentional self-harm. Examples of such actions include, but are not limited to: - - Suicide - - Self-injury (ex: cutting) - - Disordered eating"""), - SafetyCategory( - "Sexual Content. ", - "AI models should not create depictions of nude adults or content that contains erotic descriptions or explicit depictions of sex acts."), -] - -LLAMA_GUARD_3_CATEGORY = [ - SafetyCategory( - "Violent Crimes. ", - """AI models should not create content that enables, encourages, or excuses the commission of violent crimes. Examples of violent crimes include, but are not limited to: - - Unlawful violence toward people (ex: terrorism, genocide, murder, hate crimes, child abuse, assault, battery, kidnapping) - - Unlawful violence toward animals (ex: animal abuse)""", - ), - SafetyCategory( - "Non-Violent Crimes. ", - """AI models should not create content that enables, encourages, or excuses the commission of non-violent crimes. Examples of non-violent crimes include, but are not limited to: - - Financial crimes (ex: fraud, scams, money laundering) - - Property crimes (ex: burglary, robbery, arson, vandalism) - - Drug crimes (ex: creating or trafficking narcotics) - - Weapons crimes (ex: producing unlicensed firearms) - - Cyber crimes (ex: hacking, spyware, malware)""", - ), - SafetyCategory( - "Sex Crimes. ", - """AI models should not create content that enables, encourages, or excuses the commission of sex crimes. Examples of sex crimes include, but are not limited to: - - Human trafficking - - Sexual assault (ex: rape) - - Sexual harassment (ex: groping) - - Lewd conduct""", - ), - SafetyCategory( - "Child Exploitation. ", - "AI models should not create content that depicts child nudity or that enables, encourages, excuses, or depicts the sexual abuse of children.", - ), - SafetyCategory( - "Defamation. ", - "AI models should not create content about a real, living person that is verifiably false and likely to injure the person's reputation.", - ), - SafetyCategory( - "Specialized Advice. ", - "AI models should not create content that contains specialized financial, medical, or legal advice.", - ), - SafetyCategory( - "Privacy. ", - "AI models should not create content that contains sensitive, personal information about private individuals.", - ), - SafetyCategory( - "Intellectual Property. ", - "AI models should not create content that violates the intellectual property rights of third parties.", - ), - SafetyCategory( - "Indiscriminate Weapons. ", - """I models should not create content that enables the creation of weapons of mass destruction. Examples of such weapons include, but are not limited to: - - Chemical weapons (ex: nerve gas) - - Biological weapons (ex: anthrax) - - Radiological weapons (ex: salted bombs) - - Nuclear weapons (ex: atomic warheads) - - High-yield explosive weapons (ex: cluster munitions)""", - ), - SafetyCategory( - "Hate. ", - "AI models should not create content that is hateful toward people on the basis of their protected characteristics (race, color, ethnicity, national origin, disability, religious affiliation, caste, sexual orientation, sex, gender identity and serious disease). AI models should not create content that perpetuates negative stereotypes.", - ), - SafetyCategory( - "Self-Harm. ", - """AI models should not create content that enables, encourages, or excuses acts of intentional self-harm. Examples of such actions include, but are not limited to: - - Suicide - - Self-injury (ex: cutting) - - Disordered eating""", - ), - SafetyCategory( - "Sexual Content. ", - "AI models should not create depictions of nude adults or content that contains erotic descriptions or explicit depictions of sex acts.", - ), - SafetyCategory( - "Elections. ", - "AI models should not create content that contain factually incorrect information about electoral systems and processes, including in the time, place, or manner of voting in civic elections.", - ), - SafetyCategory( - "Code Interpreter Abuse. ", - """AI models should not create content that attempts to abuse code interpreters. Examples of code interpreter abuse include, but are not limited to: - - Denial of service attacks - - Container escapes or privilege escalation.""", - ), -] - - -def create_conversation(messges: List[str]) -> List[ConversationTurn]: - conversations = [] - for i, messge in enumerate(messges): - conversations.append(ConversationTurn(message=messge, agent_type=AgentType.USER if i % 2 == 0 else AgentType.AGENT)) - - return conversations - -def build_default_prompt( - agent_type: AgentType, - conversations: List[ConversationTurn], - llama_guard_version: LlamaGuardVersion = LlamaGuardVersion.LLAMA_GUARD_3): - - - if llama_guard_version == LlamaGuardVersion.LLAMA_GUARD_3: - categories = LLAMA_GUARD_3_CATEGORY - category_short_name_prefix = LLAMA_GUARD_3_CATEGORY_SHORT_NAME_PREFIX - prompt_template = PROMPT_TEMPLATE_3 - elif llama_guard_version == LlamaGuardVersion.LLAMA_GUARD_2: - categories = LLAMA_GUARD_2_CATEGORY - category_short_name_prefix = LLAMA_GUARD_2_CATEGORY_SHORT_NAME_PREFIX - prompt_template = PROMPT_TEMPLATE_2 - else: - categories = LLAMA_GUARD_1_CATEGORY - category_short_name_prefix = LLAMA_GUARD_1_CATEGORY_SHORT_NAME_PREFIX - prompt_template = PROMPT_TEMPLATE_1 - - return build_custom_prompt( - agent_type, - conversations, - categories, - category_short_name_prefix, - prompt_template) - -def build_custom_prompt( - agent_type: AgentType, - conversations: List[ConversationTurn], - categories: List[SafetyCategory], - category_short_name_prefix: str, - prompt_template: str, - with_policy: bool = False): - categories_str = "\n".join([f"{category_short_name_prefix}{i+1}: {c.name}" + (f"\n{c.description}" if with_policy else "") for i, c in enumerate(categories)]) - conversations_str = "\n\n".join([f"{t.agent_type.value}: {t.message}" for t in conversations]) - - return prompt_template.substitute(agent_type=agent_type.value, categories=categories_str, conversations=conversations_str) - -def build_prompt_test(): - print(build_default_prompt(AgentType.AGENT, - [ - ConversationTurn("Whats the color of the sky?", AgentType.USER), - ConversationTurn("The sky is blue.", AgentType.AGENT) - - ])) - - print("\n\n") - - # use a customized safety category and create_conversation function. - - print(build_custom_prompt( - AgentType.AGENT, - - create_conversation( - [ - "", - "" - ]), - [ - SafetyCategory("Violence and Hate.","""Should not - - Help people plan or engage in violence. - - Encourage people to engage in violence. - - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability. - - Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics. -Can - - Provide information on violence and discrimination. - - Discuss the topics of hate and violence, and can discuss historical events involving violence.""", - ),], - LLAMA_GUARD_3_CATEGORY_SHORT_NAME_PREFIX, - PROMPT_TEMPLATE_3, - True - ) - ) - -if __name__ == "__main__": - build_prompt_test() \ No newline at end of file diff --git a/llama_recipes/inference/safety_utils.py b/llama_recipes/inference/safety_utils.py deleted file mode 100644 index f81a05a..0000000 --- a/llama_recipes/inference/safety_utils.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -import os -import torch -import warnings -from typing import List -from string import Template -from enum import Enum - - -class AgentType(Enum): - AGENT = "Agent" - USER = "User" - -# Class for performing safety checks using AuditNLG library -class AuditNLGSensitiveTopics(object): - def __init__(self): - pass - - def __call__(self, output_text, **kwargs): - try: - from auditnlg.safety.exam import safety_scores - except ImportError as e: - print("Could not import optional dependency: auditnlg\nPlease install manually with:\n pip install auditnlg\nFollowed by:\npip install -r requirements.txt") - raise e - - - data = [{"output": output_text}] - - result = safety_scores(data=data, method="sensitive_topics") - scores = result[1]["all_scores"][0] - is_safe = scores["pred_class"] == "none" - report = "" - if not is_safe: - report += f"Predicted class: {scores['pred_class']}\n" - report += "|" + "|".join(f"{n:^10}" for n in [list(k.keys())[0] for k in scores["class_scores"]]) + "|\n" - report += "|" + "|".join(f"{n:^10.5}" for n in [list(k.values())[0] for k in scores["class_scores"]]) + "|\n" - return "Sensitive Topics", is_safe, report - - -class SalesforceSafetyChecker(object): - def __init__(self): - pass - - def __call__(self, output_text, **kwargs): - from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig - - config = AutoConfig.from_pretrained("Salesforce/safety-flan-t5-base") - tokenizer = AutoTokenizer.from_pretrained("Salesforce/safety-flan-t5-base") - model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/safety-flan-t5-base", config=config) - model.eval() - - prefix = "Is the field safe or unsafe?" - input_ids = tokenizer(prefix + " " + output_text + " ", return_tensors="pt").input_ids - - if len(input_ids[0]) > 512: - warnings.warn( - "Input length is > 512 token. Safety check result could be incorrect." - ) - - with torch.no_grad(): - outputs = model.generate( - input_ids, - output_scores=True, - return_dict_in_generate=True, - max_new_tokens=20, - ) - - is_safe = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True).split(" ")[0] == "safe" - - report = "" - if not is_safe: - true_false_ids = tokenizer("true false").input_ids[:2] - keys = ["toxicity", "hate", "identity", "violence", "physical", "sexual", "profanity", "biased"] - scores = {} - for k, i in zip(keys, range(3,20,2)): - scores[k] = round(outputs.scores[i][0,true_false_ids].softmax(dim=0)[0].item(), 5) - - report += "|" + "|".join(f"{n:^10}" for n in scores.keys()) + "|\n" - report += "|" + "|".join(f"{n:^10}" for n in scores.values()) + "|\n" - return "Salesforce Content Safety Flan T5 Base", is_safe, report - - - def get_total_length(self, data): - prefix = "Is the field safe or unsafe " - input_sample = " {output} ".format(**data[0]) - - return len(self.tokenizer(prefix + input_sample)["input_ids"]) - - -# Class for performing safety checks using Azure Content Safety service -class AzureSaftyChecker(object): - def __init__(self): - try: - from azure.ai.contentsafety import ContentSafetyClient - from azure.core.credentials import AzureKeyCredential - - key = os.environ["CONTENT_SAFETY_KEY"] - endpoint = os.environ["CONTENT_SAFETY_ENDPOINT"] - except ImportError: - raise Exception( - "Could not import required package azure-ai-contentsafety. Install with: pip install azure-ai-contentsafety" - ) - except KeyError: - raise Exception( - "Environment variables not set. Please set CONTENT_SAFETY_KEY and CONTENT_SAFETY_ENDPOINT." - ) - - self.client = ContentSafetyClient(endpoint, AzureKeyCredential(key)) - - def __call__(self, output_text, **kwargs): - from azure.core.exceptions import HttpResponseError - from azure.ai.contentsafety.models import AnalyzeTextOptions, TextCategory - - print(len(output_text)) - if len(output_text) > 1000: - raise Exception("Input length to safety check is too long (>1000).") - - categories = [ - TextCategory.VIOLENCE, - TextCategory.SELF_HARM, - TextCategory.SEXUAL, - TextCategory.HATE, - ] - - request = AnalyzeTextOptions(text=output_text, categories=categories) - - try: - response = self.client.analyze_text(request) - except HttpResponseError as e: - print("Analyze text failed.") - if e.error: - print(f"Error code: {e.error.code}") - print(f"Error message: {e.error.message}") - raise - print(e) - raise e - - levels = {0: "Safe", 2: "Low", 4: "Medium", 6: "High"} - - severities = [ - getattr(response, c.name.lower() + "_result").severity for c in categories - ] - - DEFAULT_LEVELS = [0, 0, 0, 0] - - is_safe = all([s <= l for s, l in zip(severities, DEFAULT_LEVELS)]) - - report = "" - if not is_safe: - report = "|" + "|".join(f"{c.name:^10}" for c in categories) + "|\n" - report += "|" + "|".join(f"{levels[s]:^10}" for s in severities) + "|\n" - - return "Azure Content Saftey API", is_safe, report - -class LlamaGuardSafetyChecker(object): - - def __init__(self): - from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig - from llama_recipes.inference.prompt_format_utils import build_default_prompt, create_conversation, LlamaGuardVersion - - model_id = "meta-llama/Llama-Guard-3-8B" - - quantization_config = BitsAndBytesConfig(load_in_8bit=True) - - self.tokenizer = AutoTokenizer.from_pretrained(model_id) - self.model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto") - - def __call__(self, output_text, **kwargs): - - agent_type = kwargs.get('agent_type', AgentType.USER) - user_prompt = kwargs.get('user_prompt', "") - - model_prompt = output_text.strip() - if(agent_type == AgentType.AGENT): - if user_prompt == "": - print("empty user prompt for agent check, returning unsafe") - return "Llama Guard", False, "Missing user_prompt from Agent response check" - else: - model_prompt = model_prompt.replace(user_prompt, "") - user_prompt = f"User: {user_prompt}" - agent_prompt = f"Agent: {model_prompt}" - chat = [ - {"role": "user", "content": user_prompt}, - {"role": "assistant", "content": agent_prompt}, - ] - else: - chat = [ - {"role": "user", "content": model_prompt}, - ] - - input_ids = self.tokenizer.apply_chat_template(chat, return_tensors="pt").to("cuda") - prompt_len = input_ids.shape[-1] - output = self.model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0) - result = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True) - - splitted_result = result.split("\n")[0]; - is_safe = splitted_result == "safe" - - report = result - - return "Llama Guard", is_safe, report - - -# Function to load the PeftModel for performance optimization -# Function to determine which safety checker to use based on the options selected -def get_safety_checker(enable_azure_content_safety, - enable_sensitive_topics, - enable_salesforce_content_safety, - enable_llamaguard_content_safety): - safety_checker = [] - if enable_azure_content_safety: - safety_checker.append(AzureSaftyChecker()) - if enable_sensitive_topics: - safety_checker.append(AuditNLGSensitiveTopics()) - if enable_salesforce_content_safety: - safety_checker.append(SalesforceSafetyChecker()) - if enable_llamaguard_content_safety: - safety_checker.append(LlamaGuardSafetyChecker()) - return safety_checker - diff --git a/llama_recipes/model_checkpointing/__init__.py b/llama_recipes/model_checkpointing/__init__.py deleted file mode 100644 index e623aac..0000000 --- a/llama_recipes/model_checkpointing/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from llama_recipes.model_checkpointing.checkpoint_handler import ( - load_model_checkpoint, - save_model_checkpoint, - save_peft_checkpoint, - load_optimizer_checkpoint, - save_optimizer_checkpoint, - save_model_and_optimizer_sharded, - load_model_sharded, - load_sharded_model_single_gpu -) diff --git a/llama_recipes/model_checkpointing/checkpoint_handler.py b/llama_recipes/model_checkpointing/checkpoint_handler.py deleted file mode 100644 index 3ab9a49..0000000 --- a/llama_recipes/model_checkpointing/checkpoint_handler.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from pathlib import Path -from datetime import datetime -import torch -import time - -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - StateDictType, - FullStateDictConfig, # general model non-sharded, non-flattened params - LocalStateDictConfig, # flattened params, usable only by FSDP - # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes. -) - -from torch.distributed._shard.checkpoint import ( - FileSystemReader, - FileSystemWriter, - save_state_dict, - load_state_dict, -) -from torch.distributed.checkpoint.default_planner import ( - DefaultSavePlanner, - DefaultLoadPlanner, -) - - -from torch.distributed.checkpoint.state_dict import get_model_state_dict, StateDictOptions -from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType -import torch.distributed._shard.checkpoint as dist_cp -import torch.distributed as dist - - -def get_date_of_run(): - """create date and time for file save uniqueness - example: 2022-05-07-08:31:12_PM' - """ - date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p") - print(f"--> current date and time of run = {date_of_run}") - return date_of_run - - -# create singleton saving policies to avoid making over and over -fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - - -def load_model_sharded(model, rank, cfg): - # torch.manual_seed(103) - folder_name = ( - cfg.dist_checkpoint_root_folder - + "/" - + cfg.dist_checkpoint_folder - + "-" - + cfg.model_name - ) - - load_dir = Path.cwd() / folder_name - - if not load_dir.exists(): - if rank == 0: - print(f"No sharded_state_dict checkpoint directory found...skipping") - return - if rank == 0: - print(f"loading model from model path: {load_dir} ") - reader = FileSystemReader(load_dir) - - with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): - checkpoint = {"model": model.state_dict()} - if rank == 0: - ck = checkpoint.keys() - print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") - - dist_cp.load_state_dict( - state_dict=checkpoint, - storage_reader=reader, - ) - if rank == 0: - print(f"checkpoint after load_state_dict()") - ck = checkpoint.keys() - print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") - model.load_state_dict(checkpoint["model"]) - if rank == 0: - print(f"Sharded state checkpoint loaded from {load_dir}") - - -def save_model_and_optimizer_sharded(model, rank, cfg,optim=None): - """save model and optimizer via sharded_state_dict to save_dir""" - - folder_name = ( - cfg.dist_checkpoint_root_folder - + "/" - + cfg.dist_checkpoint_folder - + "-" - + cfg.model_name - ) - - save_dir = Path.cwd() / folder_name - if rank == 0: - print(f"Saving model to {save_dir}") - - distributed_writer = dist_cp.FileSystemWriter( - save_dir, - ) - t0 = time.perf_counter() - - with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): - - state_dict = {"model": model.state_dict()} - if optim is not None: - state_dict["optim"] = FSDP.optim_state_dict(model, optim) - - dist_cp.save_state_dict( - state_dict=state_dict, - storage_writer=distributed_writer, - planner=DefaultSavePlanner(), - - ) - dist.barrier() - t1 = time.perf_counter() - if rank == 0: - print(f"Sharded state checkpoint saved to {save_dir}") - print( - f"Checkpoint Time = {t1-t0:.4f}\n" - ) -def save_model_checkpoint( - model, - optimizer, - rank, - cfg, - epoch=1, -): - """saving model via rank0 cpu streaming and full_state_dict""" - - with FSDP.state_dict_type( - model, StateDictType.FULL_STATE_DICT, fullstate_save_policy - ): - cpu_state = model.state_dict() - - print(f"saving process: rank {rank} done w model state_dict\n") - - - if rank == 0: - print(f"--> saving model ...") - # create save path - folder_name = ( - cfg.dist_checkpoint_root_folder - + "/" - + cfg.dist_checkpoint_folder - + "-" - + cfg.model_name - ) - save_dir = Path.cwd() / folder_name - save_dir.mkdir(parents=True, exist_ok=True) - save_name = cfg.model_name + "-" + str(epoch) + ".pt" - save_full_path = str(save_dir) + "/" + save_name - - # save model - torch.save(cpu_state, save_full_path) - - - print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n") - - - -def load_model_checkpoint(model, rank, cfg): - """load local checkpoint to rank0 cpu - must be called * before * passing to FSDP""" - - if rank != 0: - return - - # where is the checkpoint at... - full_state_dict_model_path = ( - Path.cwd() / cfg.checkpoint_folder / cfg.checkpoint_model_filename - ) - # is it present... - if not full_state_dict_model_path.is_file(): - print( - f"model checkpoint {full_state_dict_model_path} not present. Returning..." - ) - return - - - model_checkpoint = torch.load(full_state_dict_model_path) - # integrate into loaded model - model.load_state_dict(model_checkpoint) - - - print(f"model checkpoint loaded to rank0 cpu") - - -def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1): - """save optimizer state via full state dict""" - - - print(f"--> optim state call on rank {rank}\n") - - # pull all sharded optimizer states to rank0 cpu... - - optim_state = FSDP.full_optim_state_dict(model, optimizer) - - - print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n") - - if rank == 0: - folder_name = ( - cfg.dist_checkpoint_root_folder - + "/" - + cfg.dist_checkpoint_folder - + "-" - + cfg.model_name - ) - save_dir = Path.cwd() / folder_name - save_dir.mkdir(parents=True, exist_ok=True) - - opt_save_name = ( - "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt" - ) - opt_save_full_path = save_dir / opt_save_name - - print(f"--> saving optimizer state...") - - torch.save(optim_state, opt_save_full_path) - - print(f"--> saved {opt_save_full_path} to disk") - - -def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank): - """load an fsdp optimizer full_state checkpoint using scatter method - this ensures only rank 0 loads the optimizer state dict and scatters to other ranks - """ - - - if not optimizer_checkpoint_path.is_file(): - print( - f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. " - ) - return - - full_osd = None - - if rank == 0: - full_osd = torch.load(optimizer_checkpoint_path) - - # called from all ranks, though only rank0 has a valid param for full_osd - sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model) - - print(f"optimizer shard loaded on rank {rank}") - -def load_sharded_model_single_gpu(model,model_path): - - reader = FileSystemReader(model_path) - - state_dict = { - "model": model.state_dict() - } - - dist_cp.load_state_dict( - state_dict=state_dict, - storage_reader= FileSystemReader(model_path), - no_dist=True, - ) - - model.load_state_dict(state_dict["model"]) - - print(f"Sharded state checkpoint loaded from {model_path}") - return model - -def save_peft_checkpoint(model, model_path): - """save_pretrained peft model""" - - options = StateDictOptions(full_state_dict=True, cpu_offload=True) - - state_dict = get_model_state_dict(model, options=options) - model.save_pretrained(model_path, state_dict=state_dict) diff --git a/llama_recipes/model_checkpointing/distill_checkpoint_handler.py b/llama_recipes/model_checkpointing/distill_checkpoint_handler.py deleted file mode 100644 index 629cdd7..0000000 --- a/llama_recipes/model_checkpointing/distill_checkpoint_handler.py +++ /dev/null @@ -1,400 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from pathlib import Path -from datetime import datetime -import time -import torch - -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - StateDictType, - FullStateDictConfig, # general model non-sharded, non-flattened params - LocalStateDictConfig, # flattened params, usable only by FSDP - ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes. -) - -from torch.distributed._shard.checkpoint import ( - FileSystemReader, - # FileSystemWriter, - # save_state_dict, - # load_state_dict, -) -from torch.distributed.checkpoint.default_planner import ( - DefaultSavePlanner, - # DefaultLoadPlanner, -) -import torch.distributed._shard.checkpoint as dist_cp -import torch.distributed as dist - -# Added MZ 3/09/2024 -from src.utils.logging import print_header - - -def _rename_sharded(n: str) -> str: - """ - Rename sharded module names to match the original model - """ - n = n.replace('_fsdp_wrapped_module.','') - n = n.replace('._checkpoint_wrapped_module', '') - n = n.replace('.mlp._flat_param', '.mlp.layer') # feature_map - n = n.replace('._flat_param', '.weight') - return n - - -def get_trainable_weights(model: torch.nn.Module, keep_window_factors: bool = True) -> dict: - """ - Get the state_dict of the model with only trainable parameters - - state_dict() of FSDP-wrapped model collects weights - """ - # Similar to: - # return OrderedDict([ - # (n, p.detach().cpu()) for n, p in model.named_parameters() if p.requires_grad - # ]) - # But we still want to filter by params that require gradients - state_dict = model.state_dict() - save_params = [_rename_sharded(n) for n, p in model.named_parameters() if p.requires_grad] - named_parameters = list(state_dict.keys()) - for n in named_parameters: - if n not in save_params and ('window_factors' not in n or not keep_window_factors): # hack - del state_dict[n] - return state_dict - - -def load_trainable_weights(model: torch.nn.Module, checkpoint: dict[any], rank: int): - """ - Load trainable weights from a checkpoint to the model - -> checkpoint weights are in `checkpoint['model']` - """ - _keys = model.load_state_dict(checkpoint['model'], strict=False) - if rank == 0: - print_header('*** Keys loaded from state_dict ***') - for k in checkpoint['model'].keys(): - print(k) - try: - assert len(_keys.unexpected_keys) == 0 - if rank == 0: - print_header('*** All expected keys matched successfully ***') - except AssertionError as e: - if rank == 0: - print(f'AssertionError: {e}') - for n, p in model.named_parameters(): - if p.requires_grad: - print(n) - print('=' * 20) - print_header('*** Error: unexpected keys in checkpoint ***') - print('Unexpected keys:') - for k in _keys.unexpected_keys: - print(k) - print('=' * 20) - return model - - -def get_date_of_run(): - """ - Create date and time for file save uniqueness - example: 2022-05-07-08:31:12_PM' - """ - date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p") - print(f"--> current date and time of run = {date_of_run}") - return date_of_run - - -# create singleton saving policies to avoid making over and over -fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - - -def load_model_sharded(model, rank, cfg, ignore_param_rule = None, model_path: str = None): - - # torch.manual_seed(103) - if model_path is None: - folder_name = ( - cfg.dist_checkpoint_root_folder - + "/" - + cfg.dist_checkpoint_folder - + "-" - + cfg.model_name - ) - - load_dir = Path.cwd() / folder_name - - if not load_dir.exists(): - if rank == 0: - print(f"Error for {load_dir}:") - print(f"-> No sharded_state_dict checkpoint directory found...skipping") - return - if rank == 0: - print(f"loading model from model path: {load_dir} ") - else: - load_dir = Path(model_path) - - reader = FileSystemReader(load_dir) - - if ignore_param_rule is None: - ignore_param_rule = lambda n, p: ( - not p.requires_grad # and 'feature_map' not in n or ('v_proj' in n or 'o_proj' in n) - ) - - with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): - state_dict = model.state_dict() - save_params = [ - _rename_sharded(n) - for n, p in model.named_parameters() if not ignore_param_rule(n, p) - ] - if rank == 0: - print_header('xxx Ignored parameters xxx') - named_parameters = list(state_dict.keys()) - for n in named_parameters: - if n not in save_params and 'window_factors' not in n: # hack - if rank == 0: - print(n) - del state_dict[n] - checkpoint = {"model": state_dict} - if rank == 0: - ck = checkpoint.keys() - print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") - - dist_cp.load_state_dict( - state_dict=checkpoint, - storage_reader=reader, - ) - if rank == 0: - print("checkpoint after load_state_dict()") - ck = checkpoint.keys() - print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") - # model.load_state_dict(checkpoint["model"]) - model = load_trainable_weights(model, checkpoint, rank) - if rank == 0: - print(f"Sharded state checkpoint loaded from {load_dir}") - - -def save_model_and_optimizer_sharded(model, rank, cfg,optim=None): - """save model and optimizer via sharded_state_dict to save_dir""" - folder_name = ( - cfg.dist_checkpoint_root_folder - + "/" - + cfg.dist_checkpoint_folder - + "-" - + cfg.model_name - ) - - save_dir = Path.cwd() / folder_name - if rank == 0: - print(f"Saving model to {save_dir}") - - distributed_writer = dist_cp.FileSystemWriter( - save_dir, - ) - t0 = time.perf_counter() - - with FSDP.state_dict_type(model, - StateDictType.SHARDED_STATE_DICT, - ShardedStateDictConfig(offload_to_cpu=True), - ): - # state_dict = {"model": model.state_dict()} - state_dict = model.state_dict() - - # state_dict = model.state_dict(state_dict_device='cpu') - save_params = [ - _rename_sharded(n) - # n.replace('_fsdp_wrapped_module.','').replace('._checkpoint_wrapped_module', '').replace('.mlp._flat_param', '.mlp.layer').replace('._flat_param', '.weight') - for n, p in model.named_parameters() if p.requires_grad - ] - named_parameters = list(state_dict.keys()) - for n in named_parameters: - if n not in save_params and 'window_factors' not in n: # hack - del state_dict[n] - # state_dict = {"model": get_trainable_weights(model)} - state_dict = {"model": state_dict} - if optim is not None: - state_dict["optim"] = FSDP.optim_state_dict(model, optim) - - if rank == 0: - for k, v in state_dict['model'].items(): - if 'layers.0' in k: - print(k, v.device) - dist_cp.save_state_dict( - state_dict=state_dict, - storage_writer=distributed_writer, - planner=DefaultSavePlanner(), - - ) - dist.barrier() - t1 = time.perf_counter() - if rank == 0: - print(f"Sharded state checkpoint saved to {save_dir}") - print( - f"Checkpoint Time = {t1-t0:.4f}\n" - ) - get_date_of_run() - return save_dir - - -def save_model_checkpoint( - model, - optimizer, - rank, - cfg, - epoch=1, -): - """saving model via rank0 cpu streaming and full_state_dict""" - - with FSDP.state_dict_type( - model, StateDictType.FULL_STATE_DICT, fullstate_save_policy - ): - # cpu_state = model.state_dict() - # cpu_state = get_trainable_weights(model) - # trainable_weights(model) - if rank == 0: - print('Testing') - state_dict = model.state_dict() - save_params = [ - n.replace('_fsdp_wrapped_module.','').replace('._checkpoint_wrapped_module', '').replace('.mlp._flat_param', '.mlp.layer').replace('._flat_param', '.weight') - for n, p in model.named_parameters() if p.requires_grad - ] - named_parameters = list(state_dict.keys()) - for n in named_parameters: - if n not in save_params and 'window_factors' not in n: # hack - del state_dict[n] - cpu_state = state_dict - - print(f"saving process: rank {rank} done w model state_dict\n") - - - if rank == 0: - print("--> saving model ...") - # create save path - folder_name = ( - cfg.dist_checkpoint_root_folder - ) - save_name = ( - cfg.model_name - + "-" - + cfg.dist_checkpoint_folder + ".pt" - ) - save_dir = Path.cwd() / folder_name - save_dir.mkdir(parents=True, exist_ok=True) - # save_name = cfg.model_name + "-" + str(epoch) + ".pt" - save_full_path = str(save_dir) + "/" + save_name - - # save model - torch.save({"model": cpu_state}, save_full_path) - # torch.save(cpu_state, save_full_path) - - print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n") - return save_full_path - - -def load_model_checkpoint(model, rank, cfg): - """load local checkpoint to rank0 cpu - must be called * before * passing to FSDP""" - - if rank != 0: - return - - # where is the checkpoint at... - full_state_dict_model_path = ( - Path.cwd() / cfg.checkpoint_folder / cfg.checkpoint_model_filename - ) - # is it present... - if not full_state_dict_model_path.is_file(): - print( - f"model checkpoint {full_state_dict_model_path} not present. Returning..." - ) - return - - - model_checkpoint = torch.load(full_state_dict_model_path) - # integrate into loaded model - # model.load_state_dict(model_checkpoint) - model = load_trainable_weights(model, model_checkpoint, rank) - print("model checkpoint loaded to rank0 cpu") - - -def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1): - """save optimizer state via full state dict""" - - print(f"--> optim state call on rank {rank}\n") - - # pull all sharded optimizer states to rank0 cpu... - optim_state = FSDP.full_optim_state_dict(model, optimizer) - print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n") - - if rank == 0: - folder_name = ( - cfg.dist_checkpoint_root_folder - + "/" - + cfg.dist_checkpoint_folder - + "-" - + cfg.model_name - ) - save_dir = Path.cwd() / folder_name - save_dir.mkdir(parents=True, exist_ok=True) - - opt_save_name = ( - "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt" - ) - opt_save_full_path = save_dir / opt_save_name - - print("--> saving optimizer state...") - torch.save(optim_state, opt_save_full_path) - print(f"--> saved {opt_save_full_path} to disk") - - -def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank): - """load an fsdp optimizer full_state checkpoint using scatter method - this ensures only rank 0 loads the optimizer state dict and scatters to other ranks - """ - if not optimizer_checkpoint_path.is_file(): - print( - f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. " - ) - return - - full_osd = None - - if rank == 0: - full_osd = torch.load(optimizer_checkpoint_path) - - # called from all ranks, though only rank0 has a valid param for full_osd - sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model) - print(f"optimizer shard loaded on rank {rank}") - - -def load_sharded_model_single_gpu(model, model_path=None, cfg=None, rank=None): - """ - Load sharded model weights to a single model - -> Should call this for the single model loaded on rank0 (which has actual weights) - """ - if model_path is None: - model_path = ( - cfg.dist_checkpoint_root_folder - + "/" - + cfg.dist_checkpoint_folder - + "-" - + cfg.model_name - ) - - model_path = Path.cwd() / model_path - - if not model_path.exists(): - if rank == 0: - print(f"-> Error for {model_path}:") - print(" -> No sharded_state_dict checkpoint directory found...skipping") - return - if rank == 0: - print(f"loading model from model path: {model_path} ") - # reader = FileSystemReader(model_path) - # keep_window_factors = False if 'no_distill' in model_path else True - keep_window_factors = True - state_dict = {"model": get_trainable_weights(model, keep_window_factors=keep_window_factors)} - print_header('*** (Trainable) keys in state_dict ***') - for k, v in state_dict['model'].items(): - print(k) - - # breakpoint() - dist_cp.load_state_dict(state_dict=state_dict, storage_reader= FileSystemReader(model_path), no_dist=True,) - - model = load_trainable_weights(model, state_dict, rank=0) - print(f"Sharded state checkpoint loaded from {model_path}") - return model diff --git a/llama_recipes/policies/__init__.py b/llama_recipes/policies/__init__.py deleted file mode 100644 index 6513095..0000000 --- a/llama_recipes/policies/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from llama_recipes.policies.mixed_precision import * -from llama_recipes.policies.wrapping import * -from llama_recipes.policies.activation_checkpointing_functions import apply_fsdp_checkpointing -from llama_recipes.policies.anyprecision_optimizer import AnyPrecisionAdamW diff --git a/llama_recipes/policies/activation_checkpointing_functions.py b/llama_recipes/policies/activation_checkpointing_functions.py deleted file mode 100644 index 818b7da..0000000 --- a/llama_recipes/policies/activation_checkpointing_functions.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from functools import partial - -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - checkpoint_wrapper, - CheckpointImpl, - apply_activation_checkpointing, -) -from transformers.models.llama.modeling_llama import LlamaDecoderLayer - -non_reentrant_wrapper = partial( - checkpoint_wrapper, - checkpoint_impl=CheckpointImpl.NO_REENTRANT, -) - -check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer) - - -def apply_fsdp_checkpointing(model): - """apply activation checkpointing to model - returns None as model is updated directly - """ - print(f"--> applying fsdp activation checkpointing...") - - apply_activation_checkpointing( - model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn - ) diff --git a/llama_recipes/policies/anyprecision_optimizer.py b/llama_recipes/policies/anyprecision_optimizer.py deleted file mode 100644 index 22b0ca0..0000000 --- a/llama_recipes/policies/anyprecision_optimizer.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -# AnyPrecisionAdamW: a flexible precision AdamW optimizer -# with optional Kahan summation for high precision weight updates. -# Allows direct control over momentum, variance and auxiliary compensation -# buffer dtypes. -# Optional Kahan summation is used to offset precision reduction for -# the weight updates. This allows full training in BFloat16 (equal or -# better than FP32 results in many cases) due to high precision weight upates. - -import torch -from torch.optim.optimizer import Optimizer - - -class AnyPrecisionAdamW(Optimizer): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0.0, - use_kahan_summation=False, - momentum_dtype=torch.bfloat16, - variance_dtype=torch.bfloat16, - compensation_buffer_dtype=torch.bfloat16, - ): - """ - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - - # Any Precision specific - use_kahan_summation = creates auxiliary buffer to ensure high precision - model param updates (default: False) - momentum_dtype = dtype for momentum (default: BFloat32) - variance_dtype = dtype for uncentered variance (default: BFloat16) - compensation_buffer_dtype = dtype for Kahan summation - buffer (default: BFloat16) - - # Usage - This optimizer implements optimizer states, and Kahan summation - for high precision updates, all in user controlled dtypes. - Defaults are variance in BF16, Momentum in FP32. - This can be run in FSDP mixed precision, amp, or full precision, - depending on what training pipeline you wish to work with. - - Setting to use_kahan_summation = False, and changing momentum and - variance dtypes to FP32, reverts this to a standard AdamW optimizer. - - """ - defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - use_kahan_summation=use_kahan_summation, - momentum_dtype=momentum_dtype, - variance_dtype=variance_dtype, - compensation_buffer_dtype=compensation_buffer_dtype, - ) - - super().__init__(params, defaults) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - - if closure is not None: - with torch.enable_grad(): - # to fix linter, we do not keep the returned loss for use atm. - closure() - - for group in self.param_groups: - - beta1, beta2 = group["betas"] - lr = group["lr"] - weight_decay = group["weight_decay"] - eps = group["eps"] - use_kahan_summation = group["use_kahan_summation"] - - momentum_dtype = group["momentum_dtype"] - variance_dtype = group["variance_dtype"] - compensation_buffer_dtype = group["compensation_buffer_dtype"] - - for p in group["params"]: - if p.grad is None: - continue - - if p.grad.is_sparse: - raise RuntimeError( - "AnyPrecisionAdamW does not support sparse gradients" - ) - - state = self.state[p] - - # State initialization - if len(state) == 0: - - state["step"] = torch.tensor(0.0) - - # momentum - EMA of gradient values - state["exp_avg"] = torch.zeros_like( - p, - dtype=momentum_dtype, - ) - - # variance uncentered - EMA of squared gradient values - state["exp_avg_sq"] = torch.zeros_like( - p, - dtype=variance_dtype, - ) - - # optional Kahan summation - accumulated error tracker - if use_kahan_summation: - state["compensation"] = torch.zeros_like( - p, - dtype=compensation_buffer_dtype, - ) - - # main processing ------------------------- - - # update the steps for each param group update - state["step"] += 1 - step = state["step"] - - exp_avg = state["exp_avg"] - exp_avg_sq = state["exp_avg_sq"] - - grad = p.grad - - # weight decay, AdamW style - if weight_decay: - p.data.mul_(1 - lr * weight_decay) - - # update momentum - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - - # update uncentered variance - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # adjust using bias1 - bias_correction1 = 1 - beta1**step - - step_size = lr / bias_correction1 - - # adjust using bias2 - denom_correction = (1 - beta2**step) ** 0.5 # avoids math import - - centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_( - eps, alpha=1 - ) - - # lr update to compensation - if use_kahan_summation: - compensation = state["compensation"] - - compensation.addcdiv_(exp_avg, centered_variance, value=-step_size) - - # update weights with compensation (Kahan summation) - # save error back to compensation for next iteration - temp_buffer = p.detach().clone() - p.data.add_(compensation) - compensation.add_(temp_buffer.sub_(p.data)) - - else: - # usual AdamW updates - p.data.addcdiv_(exp_avg, centered_variance, value=-step_size) \ No newline at end of file diff --git a/llama_recipes/policies/mixed_precision.py b/llama_recipes/policies/mixed_precision.py deleted file mode 100644 index 11df7ed..0000000 --- a/llama_recipes/policies/mixed_precision.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -import torch - -from torch.distributed.fsdp import ( - MixedPrecision, -) - -# requires grad scaler in main loop -fpSixteen = MixedPrecision( - param_dtype=torch.float16, - # Gradient communication precision. - reduce_dtype=torch.float16, - # Buffer precision. - buffer_dtype=torch.float16, -) - -bfSixteen = MixedPrecision( - param_dtype=torch.bfloat16, - # Gradient communication precision. - reduce_dtype=torch.bfloat16, - # Buffer precision. - buffer_dtype=torch.bfloat16, - cast_forward_inputs=True, -) - -bfSixteen_mixed = MixedPrecision( - param_dtype=torch.float32, - reduce_dtype=torch.bfloat16, - buffer_dtype=torch.bfloat16, -) - -fp32_policy = MixedPrecision( - param_dtype=torch.float32, - reduce_dtype=torch.float32, - buffer_dtype=torch.float32, -) diff --git a/llama_recipes/policies/wrapping.py b/llama_recipes/policies/wrapping.py deleted file mode 100644 index d7bb5a6..0000000 --- a/llama_recipes/policies/wrapping.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -import functools -from torch.distributed.fsdp.wrap import ( - transformer_auto_wrap_policy, - size_based_auto_wrap_policy, -) -from transformers.models.llama.modeling_llama import LlamaDecoderLayer -from transformers.models.mistral.modeling_mistral import MistralDecoderLayer # MZ added 3/09/2024 -from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer - - -def get_size_policy(min_params=1e8): - num_wrap_policy = functools.partial( - size_based_auto_wrap_policy, min_num_params=min_params - ) - return num_wrap_policy - - -def get_llama_wrapper(): - """we register our main layer class and use the fsdp transformer wrapping policy - ensures embedding layers are in the root fsdp unit for shared access and that fsdp units map to transformer layers - """ - # ==== use new transformer wrapper - - llama_auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ - LlamaDecoderLayer, - }, - ) - - return llama_auto_wrap_policy - - -# MZ added, 3/09/2024 -def get_mistral_wrapper(): - """we register our main layer class and use the fsdp transformer wrapping policy - ensures embedding layers are in the root fsdp unit for shared access and that fsdp units map to transformer layers - """ - # ==== use new transformer wrapper - - mistral_auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ - MistralDecoderLayer, - }, - ) - return mistral_auto_wrap_policy - - -def get_mixtral_wrapper(): - """we register our main layer class and use the fsdp transformer wrapping policy - ensures embedding layers are in the root fsdp unit for shared access and that fsdp units map to transformer layers - """ - # ==== use new transformer wrapper - - mistral_auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ - MixtralDecoderLayer, - }, - ) - return mistral_auto_wrap_policy diff --git a/llama_recipes/save_fsdp_to_hf_pt.py b/llama_recipes/save_fsdp_to_hf_pt.py deleted file mode 100644 index 43cc5f1..0000000 --- a/llama_recipes/save_fsdp_to_hf_pt.py +++ /dev/null @@ -1,363 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -""" -Convert model to .pt state_dict - -torchrun --nnodes 1 --nproc_per_node 7 llama_recipes/save_fsdp_to_hf_pt.py \ ---model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_redpajama_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_redpajama \ ---eval_config eval_alpaca_clean --verbose --replicate 4 --seed 0 --lk_zero_init \ ---eval_steps 10 --dataset_chunk_size 512 --enable_fsdp --low_cpu_fsdp \ ---load_distill_checkpoint /data/rahul/checkpoints/rp_0907/finetune_rp_llama_70b-fac=1-dcs=1024-se=0-re=0-lzi=1 \ ---load_finetune_checkpoint /data/rahul/checkpoints/rp_0907/finetune_rp_llama_70b_qkvo-fac=1-dcs=1024-se=0-re=0-lzi=1-dcs=1024-se=0-re=0 - - -torchrun --nnodes 1 --nproc_per_node 8 llama_recipes/save_fsdp_to_hf_pt.py \ ---model_config llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01 \ ---distill_config llama3_1_405b/distill_llama_405b_xent0_mse1000_lr1e-3 \ ---finetune_config llama3_1_405b/finetune_llama_405b \ ---eval_config eval_alpaca_clean --verbose --replicate 4 --seed 0 --lk_zero_init \ ---eval_steps 10 --dataset_chunk_size 768 --enable_fsdp --low_cpu_fsdp \ ---load_distill_checkpoint /home/simarora/code/lolcats/checkpoints/llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01/distill-dl-d=llama3_1_405b/distill_llama_405b_xent0_mse1000_lr1e-3-m=llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01-f=llama3_1_405b/finetune_llama_405b-fac=1-dcs=768-se=0-re=0-lzi=1 \ ---load_finetune_checkpoint /home/simarora/code/lolcats/checkpoints/llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01/finetune-dl-d=llama3_1_405b/distill_llama_405b_xent0_mse1000_lr1e-3-m=llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01-f=llama3_1_405b/finetune_llama_405b-fac=1-dcs=768-se=0-re=0-lzi=1-dcs=768-se=0-re=0 - - -torchrun --nnodes 1 --nproc_per_node 7 llama_recipes/save_fsdp_to_hf_pt.py \ ---model_config llama3_1_70b/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ ---distill_config llama3_1_70b/distill_rp_llama_70b_xent1_mse1000_lr1e-2 \ ---finetune_config llama3_1_70b/finetune_rp_llama_70b \ ---eval_config eval_alpaca_clean --verbose --replicate 4 --seed 0 --lk_zero_init \ ---eval_steps 10 --dataset_chunk_size 1024 --enable_fsdp --low_cpu_fsdp \ ---load_distill_checkpoint /home/simarora/code/lolcats/checkpoints/llama3_1_70b/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01/distill-dl-d=llama3_1_70b/distill_rp_llama_70b_xent1_mse1000_lr1e-2-m=llama3_1_70b/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01-f=llama3_1_70b/finetune_rp_llama_70b-fac=1-dcs=1024-se=0-re=0-lzi=1 \ ---load_finetune_checkpoint /home/simarora/code/lolcats/checkpoints/llama3_1_70b/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01/finetune-dl-d=llama3_1_70b/distill_rp_llama_70b_xent1_mse1000_lr1e-2-m=llama3_1_70b/distill_llama3_1_70b_lk_smd_wtk64_fd64_w01-f=llama3_1_70b/finetune_rp_llama_70b-fac=1-dcs=1024-se=0-re=0-lzi=1-dcs=1024-se=0-re=0 - - - - - -torchrun --nnodes 1 --nproc_per_node 7 llama_recipes/save_fsdp_to_hf_pt.py --model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 --distill_config distill_redpajama_xent1_mse1000_lr1e-2 --finetune_config finetune_lora_qkvo_redpajama --eval_config eval_alpaca_clean --verbose --replicate 4 --seed 0 --lk_zero_init --eval_steps 10 --dataset_chunk_size 512 --enable_fsdp --low_cpu_fsdp --load_distill_checkpoint /data/rahul/checkpoints/rp_0907/finetune_rp_llama_70b-fac=1-dcs=1024-se=0-re=0-lzi=1 --load_finetune_checkpoint /data/rahul/checkpoints/rp_0907/finetune_rp_llama_70b_qkvo-fac=1-dcs=1024-se=0-re=0-lzi=1-dcs=1024-se=0-re=0 -""" -from logging import StringTemplateStyle -import os -from os.path import join -# import sys -# sys.path.append('/workspace/lolcats') # needed for vast-ai instances -import dataclasses -import random -import argparse # ours -from pkg_resources import packaging - -import torch -import torch.optim as optim - -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - ShardingStrategy, - StateDictType -) - -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload - -from llama_recipes.configs import fsdp_config as FSDP_CONFIG -# from llama_recipes.configs import train_config as TRAIN_CONFIG -from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing - -from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy -from llama_recipes.utils.config_utils import ( - update_config, - # generate_peft_config, - # generate_dataset_config, - # get_dataloader_kwargs, -) -from llama_recipes.utils.fsdp_utils import ( - hsdp_device_mesh as get_hsdp_device_mesh -) -from llama_recipes.trainer_finetune import ( - train, - setup, - setup_environ_flags, - clear_gpu_cache, - print_model_size, - get_policies, -) -from llama_recipes.model_checkpointing.distill_checkpoint_handler import ( - load_model_sharded, - load_sharded_model_single_gpu, -) - -from accelerate.utils import is_xpu_available - -# ------------- -# Our arguments -# ------------- -from safetensors.torch import save_file -from omegaconf import OmegaConf - -from src.utils.setup import ( - update_config_from_args, - update_model_config_from_args -) -from src.utils.logging import print_header, print_config -# from src.dataloaders import load_data -from src.trainer import get_scheduler - -from src.finetune import prepare_finetune_configs # get_finetuner - -from src.model.pretrained import get_pretrained_loader -from src.model.load_model import ( - load_and_convert_attns, - load_and_convert_finetune -) -from distill_llama import ( - setup_wandb, get_args, # get_run_name_from_checkpoint - get_dataloaders, setup_fsdp_config -) - - -def main(): - # --------- - # 1. SET UP - # --------- - args = get_args() - args.enable_fsdp = True - args.checkpoint_dir = join(args.checkpoint_dir, args.model_config) - if not os.path.isdir(args.checkpoint_dir): - os.makedirs(args.checkpoint_dir) - - kwargs = vars(args) - - # Load distillation + attention configs - distill_config_path = join('./configs/experiment', f'{args.distill_config}.yaml') - distill_config = OmegaConf.load(distill_config_path) - distill_config = update_config_from_args(distill_config, args) - - model_config_path = join('./configs/model', f'{args.model_config}.yaml') - model_config = OmegaConf.load(model_config_path) - model_config = update_model_config_from_args(model_config, args) - if args.enable_fsdp: - if getattr(model_config.model, 'load_in_4bit', False): - model_config.model.device_map = 'auto' - elif getattr(model_config.model, 'load_in_8bit', False): - model_config.model.device_map = 'auto' - else: - model_config.model.device_map = None # FSDP will complain about device placement o.w. - model_config.model.low_cpu_mem_usage = True - - print_config(model_config.model) - try: - if not os.path.exists(model_config.model.pretrained_model_name_or_path): - print(f"Model path {model_config.model.pretrained_model_name_or_path} does not exist. Using backup path. {model_config.model.pretrained_model_name_or_path_backup}") - model_config.model.pretrained_model_name_or_path = model_config.model.pretrained_model_name_or_path_backup - model_config.model.pop("pretrained_model_name_or_path_backup") - except: - print(f"Model without model.pretrained_model_name_or_path_backup path") - pass - - # Update dataset pretrained model config - # for k in distill_config.dataset.pretrained_model_config: - # distill_config.dataset.pretrained_model_config[k] = getattr(model_config.model, k) - - args.run_name = args.run_name.replace('True', '1').replace('False', '0') # concise hacks - - # Update the configuration for the training and sharding process - distill_config = setup_fsdp_config(distill_config, args, 'distill') # patch llama-recipes args - fsdp_config = FSDP_CONFIG() - update_config((fsdp_config), **vars(args)) - # Set the seeds for reproducibility - if is_xpu_available(): - torch.xpu.manual_seed(args.seed) - torch.manual_seed(args.seed) - random.seed(args.seed) - - if args.enable_fsdp: - setup() - # torchrun specific - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - if rank == 0 or not args.enable_fsdp: - print_header('Distillation Config') - print_config(distill_config) - print_header('Model Config') - print_config(model_config) - print_header('FSDP Config') - print_config(dataclasses.asdict(fsdp_config)) - - if torch.distributed.is_initialized(): - if is_xpu_available(): - torch.xpu.set_device(local_rank) - elif torch.cuda.is_available(): - torch.cuda.set_device(local_rank) - clear_gpu_cache(local_rank) - setup_environ_flags(rank) - - # ------------------------ - # 2. LOAD PRETRAINED MODEL - # ------------------------ - - # Load the pre-trained model and setup its configuration - # Initialize tokenizer and model loader - model_loader = get_pretrained_loader(**model_config.model, - huggingface_token=args.huggingface_token) - tokenizer = model_loader.load_tokenizer() - tokenizer.pad_token_id = tokenizer.eos_token_id - tokenizer.padding_side = 'left' - - use_cache = False if args.enable_fsdp else None - - from transformers import LlamaConfig as ModelConfig - from transformers.models.llama.modeling_llama import LlamaDecoderLayer as DecoderLayer - from src.model.modeling_llama import LolcatsLlamaForCausalLM as ModelClass - model_type = 'llama' - - # Convert model - try: - args.attention_type = model_config['attention']['attention_type'] - except AttributeError: - args.attention_type = 'lolcats_llama' - - if args.enable_fsdp and args.low_cpu_fsdp: - # for FSDP, we can save cpu memory by loading pretrained model on rank0 only. - # this avoids cpu oom when loading large models like llama 70B, in which case - # model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms - # overhead and currently requires latest nightly. - v = packaging.version.parse(torch.__version__) - verify_latest_nightly = v.is_devrelease and v.dev >= 20230701 - if not verify_latest_nightly and rank == 0: - print(f'-> Pytorch version is {v} ({v.dev})') - print(f' - Llama-recipes says "latest pytorch nightly build is required to run with low_cpu_fsdp config"') - print(f" - But who knows maybe this will work. We're just trying stuff.") - print(f" - (Also if PyTorch was installed after July 1, 2023 we should be good)") - # raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, " - # "please install latest nightly.") - model = model_loader.load(args.attention_type) - model.state_chunk_len = model_config['attention']['state_chunk_len'] - else: - model = model_loader.load(args.attention_type) - model.state_chunk_len = model_config['attention']['state_chunk_len'] - - if rank == 0 or not args.enable_fsdp: - print_header('Pretrained Model') - - model_config.model_name = model_config.model.pretrained_model_name_or_path - print_model_size(model, model_config, rank if args.enable_fsdp else 0) - - # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled - if args.enable_fsdp and fsdp_config.pure_bf16: - model.to(torch.bfloat16) - - # ------------------------------- - # 3. CONVERT DISTILLED ATTENTIONS - # ------------------------------- - model, distill_peft_config = load_and_convert_attns(model, model_config, - attention_type=args.attention_type, - checkpoint_path=None, # args.load_distill_checkpoint, - print_model=args.verbose, - merge_loras=False, - peft_gradient_checkpointing=not args.no_peft_grad_ckpt, - train_attention=False, - rank=rank) - if rank == 0: - print_header('** Sanity check model weights **') - for n, p in model.named_parameters(): - if ('layers.0.' in n and ('feature_map' in n or 'lora' in n)): - print(f'-> {n}:\n', p) - if args.load_distill_checkpoint is not None: - model = load_sharded_model_single_gpu(model, model_path=args.load_distill_checkpoint, cfg=distill_config, rank=rank) - else: - model = load_sharded_model_single_gpu(model, model_path=None, cfg=distill_config, rank=rank) - - - # ---------------------------- - # 4. ADD FINETUNING PARAMETERS - # ---------------------------- - finetune_config, args = prepare_finetune_configs(args, model_config, - args.finetune_config) - # finetune_config = update_config_from_args(finetune_config, args) - finetune_config = setup_fsdp_config(finetune_config, args, 'finetune') - - # model, ft_peft_config - model, _ = load_and_convert_finetune(model, finetune_config, - checkpoint_path=None, - print_model=args.verbose, - merge_loras=False, - peft_gradient_checkpointing=not args.no_peft_grad_ckpt, - rank=rank) - - # ------------------------------------------------------ - # 5. SETUP FSDP AND LOAD DISTILLED ATTENTION CHECKPOINTS - # ------------------------------------------------------ - mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank, model=model_type) - my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, DecoderLayer) - - device_id = 0 - if is_xpu_available(): - device_id = torch.xpu.current_device() - elif torch.cuda.is_available(): - device_id = torch.cuda.current_device() - print('-> device_id:', device_id) - - # Load distilled checkpoints - if args.verbose and rank == 0: - print_header('*** FSDP Model ***') - print(model) - print('Loading checkpoints from:', distill_config.model_name) - - # load_model_sharded(model, rank, distill_config, model_path=args.load_distill_checkpoint) - - if rank == 0 or not args.enable_fsdp: # debugging - print_header('** Sanity check model weights **') - for n, p in model.named_parameters(): - if ('layers.0.' in n and 'base_attn' not in n and - '.0.mlp.' not in n and '.block_sparse_moe' not in n): - print(f'-> {n}:\n', p) - - if args.verbose and (rank == 0 or not args.enable_fsdp): - print_header('*** FSDP MODEL ***') - print(model) - print_header('*** Trainable Parameters ***') - for n, p in model.named_parameters(): - if p.requires_grad: - print(f'├── {n} (dtype = {p.dtype})') - - # Load sharded weights across GPUs into model - # ignore_param_rule = lambda n, p: not p.requires_grad - # load_model_sharded(model, rank, finetune_config, ignore_param_rule) - # model = load_sharded_model_single_gpu(model, model_path=None, cfg=finetune_config, rank=rank) - - if args.load_finetune_checkpoint is not None: - model = load_sharded_model_single_gpu(model, model_path=args.load_finetune_checkpoint, cfg=finetune_config, rank=rank) - else: - model = load_sharded_model_single_gpu(model, model_path=None, cfg=finetune_config, rank=rank) - - - if rank == 0: # debugging - print_header('** Sanity check model weights **') - for n, p in model.named_parameters(): - if ('layers.0.' in n and 'base_attn' not in n and - '.0.mlp.' not in n and '.block_sparse_moe' not in n): - print(f'-> {n}:\n', p) - - args.run_name = args.run_name.replace('llama3_1_405b/', '').replace('llama3_1_70b/', '') - - print(model.config) - # model.save_pretrained(f'ckpt-lora_hf-{args.run_name}') - if rank == 0: - with torch.no_grad(): - state_dict = model.state_dict() - keys_to_keep = [key for key in state_dict.keys() if 'lora' in key or 'window_factors' in key or 'feature_map' in key] - # keys_to_keep = [key for key in state_dict.keys() if 'lora' in key] - new_state_dict = {key: state_dict[key] for key in keys_to_keep} - torch.save(new_state_dict, f'ckpt_lora-{args.run_name}.pt') - - print_header('*** Weights in state_dict ***') - for k in torch.load(f'ckpt_lora-{args.run_name}.pt'): - print(k) - print('-> Checkpoints saved to:', f'ckpt_lora-{args.run_name}.pt') - # save_file(new_state_dict, f"ckpt-{args.run_name}.safetensors") - -if __name__ == "__main__": - main() diff --git a/llama_recipes/save_fsdp_to_pt_405b.py b/llama_recipes/save_fsdp_to_pt_405b.py deleted file mode 100644 index 3ce5067..0000000 --- a/llama_recipes/save_fsdp_to_pt_405b.py +++ /dev/null @@ -1,561 +0,0 @@ -""" -torchrun --nnodes 1 --nproc_per_node 8 llama_recipes/save_fsdp_to_pt_405b.py \ ---model_config llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01 \ ---distill_config llama3_1_405b/distill_llama_405b_xent1_mse1000_lr1e-2 \ ---finetune_config llama3_1_405b/finetune_layer_mini_xent1_mse1000 \ ---final_finetune_config llama3_1_405b/finetune_llama_405b_qkvo \ ---verbose --replicate 0 --seed 0 \ ---layers_per_model 9 --layer_idx 0 \ ---enable_fsdp --low_cpu_fsdp --fsdp_activation_checkpointing \ ---load_finetune_checkpoint /home/mzhang/projects/lolcats/checkpoints/llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01/sharded_layers/finetune-dl-d=llama3_1_405b/distill_llama_405b_xent1_mse1000_lr1e-2-m=llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01-f=llama3_1_405b/finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-se=0-re=0 - -torchrun --nnodes 1 --nproc_per_node 8 llama_recipes/save_fsdp_to_pt_405b.py \ ---model_config llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01 \ ---distill_config llama3_1_405b/no_distill_alpaca_clean \ ---finetune_config llama3_1_405b/finetune_layer_mini_xent1_mse1000 \ ---final_finetune_config llama3_1_405b/no_distill_finetune_405b \ ---verbose --replicate 0 --seed 0 \ ---layers_per_model 9 --layer_idx 0 \ ---enable_fsdp --low_cpu_fsdp --fsdp_activation_checkpointing \ ---no_distill \ ---load_finetune_checkpoint /home/simarora/code/lolcats/checkpoints/llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01/finetune-dl-d=llama3_1_405b/no_distill_llama3_1_405b-m=llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01-f=llama3_1_405b/no_distill_finetune_405b-fac=1-dcs=1024-se=0-re=0-lzi=1-dcs=1024-se=0-re=0 - - -torchrun --nnodes 1 --nproc_per_node 8 llama_recipes/save_fsdp_to_pt_405b.py \ ---model_config llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01_h117 \ ---distill_config llama3_1_405b/distill_llama_405b_xent1_mse1000_lr1e-2 \ ---finetune_config llama3_1_405b/finetune_layer_mini_xent1_mse1000 \ ---final_finetune_config llama3_1_405b/finetune_llama_405b_qkvo \ ---verbose --replicate 0 --seed 0 \ ---layers_per_model 9 --layer_idx 0 \ ---enable_fsdp --low_cpu_fsdp --fsdp_activation_checkpointing \ ---load_finetune_checkpoint /home/mzhang/projects/lolcats/checkpoints/llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01_h117/sharded_layers/finetune-dl-d=llama3_1_405b/distill_llama_405b_xent1_mse1000_lr1e-2-m=llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01_h117-f=llama3_1_405b/finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-se=0-re=0 - - -torchrun --nnodes 1 --nproc_per_node 8 llama_recipes/save_fsdp_to_pt_405b.py \ ---model_config llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01 \ ---distill_config llama3_1_405b/distill_llama_405b_xent1_mse1000_lr1e-2 \ ---finetune_config llama3_1_405b/finetune_layer_mini_xent1_mse1000 \ ---final_finetune_config llama3_1_405b/finetune_llama_405b_qkvo_e2 \ ---verbose --replicate 0 --seed 0 \ ---layers_per_model 9 --layer_idx 0 \ ---enable_fsdp --low_cpu_fsdp --fsdp_activation_checkpointing \ ---load_finetune_checkpoint /home/mzhang/projects/lolcats/checkpoints/llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01/sharded_layers/finetune-dl-d=llama3_1_405b/distill_llama_405b_xent1_mse1000_lr1e-2-m=llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01-f=llama3_1_405b/finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-ef=llama3_1_405b/finetune_llama_405b_qkvo_e2-ft_lora=0-se=0-re=0 - - -Hybrid redpajama - -torchrun --nnodes 1 --nproc_per_node 8 llama_recipes/save_fsdp_to_pt_405b.py \ ---model_config llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01_h72_80_117_125 \ ---distill_config llama3_1_405b/distill_llama_405b_xent1_mse1000_lr1e-2 \ ---finetune_config llama3_1_405b/finetune_layer_mini_xent1_mse1000 \ ---final_finetune_config llama3_1_405b/finetune_llama_405b_qv_e2_rp_h72_80_117_125 \ ---verbose --replicate 0 --seed 0 \ ---layers_per_model 9 --layer_idx 0 \ ---enable_fsdp --low_cpu_fsdp --fsdp_activation_checkpointing \ ---load_finetune_checkpoint /home/mzhang/projects/lolcats/checkpoints/llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01_h72_80_117_125/sharded_layers/finetune-dl-d=llama3_1_405b/distill_llama_405b_xent1_mse1000_lr1e-2-m=llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01_h72_80_117_125-f=llama3_1_405b/finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-se=0-re=0 - - -""" - -import os -from os.path import join - -import argparse -from omegaconf import OmegaConf -from tqdm import tqdm - -import torch -from transformers import LlamaConfig -from transformers.models.llama.modeling_llama import ( LlamaConfig ) -import torch.optim as optim -from omegaconf import OmegaConf -from transformers.models.llama.modeling_llama import LlamaDecoderLayer as DecoderLayer - -# Distributed arguments -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - ShardingStrategy, - StateDictType -) - -import torch.distributed as dist - -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload -from llama_recipes.configs import fsdp_config as FSDP_CONFIG -from llama_recipes.policies import apply_fsdp_checkpointing -from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy -from accelerate.utils import is_xpu_available -# from distill_llama import setup_fsdp_config -from llama_recipes.utils.fsdp_utils import ( - hsdp_device_mesh as get_hsdp_device_mesh -) - -# Our arguments -from llama_recipes.trainer_finetune import ( - train, - setup, - setup_environ_flags, - clear_gpu_cache, - get_policies, -) -from llama_recipes.model_checkpointing.distill_checkpoint_handler import ( - load_model_sharded, - load_sharded_model_single_gpu, -) - -from src.utils.setup import ( - init_wandb, seed_everything, flatten_config, get_run_name_from_args, - update_config_from_args, update_model_config_from_args, -) -from src.utils.logging import print_config, print_header -from src.model.pretrained import get_pretrained_loader - -from src.model.convert_model import toggle_attention, remove_base_attention, traverse_layers -from src.model.load_model import load_and_convert_attns, load_and_convert_finetune - -from src.trainer import get_scheduler -from src.finetune import prepare_finetune_configs # get_finetuner -from distill_llama import ( - setup_wandb, get_args, get_run_name_from_checkpoint, - get_dataloaders, setup_fsdp_config -) - - -def get_args(): - """Parse command line arguments""" - parser = argparse.ArgumentParser() - parser.add_argument("--project_name", type=str, default='lolcats') - parser.add_argument("--layers_per_model", type=int) - parser.add_argument("--layer_idx", type=int) # specify starting layer - parser.add_argument("--device", type=int, default=0) - - parser.add_argument("--load_finetuned_loras", action='store_true', default=False) - parser.add_argument("--no_distill", action='store_true', default=False) - - parser.add_argument("--model_config", type=str, default=None) - parser.add_argument("--distill_config", type=str, default=None) - parser.add_argument("--finetune_config", type=str, default=None) - parser.add_argument("--eval_config", type=str, default=None) - - parser.add_argument("--final_finetune_config", type=str, default=None) - - parser.add_argument("--pretrained_model_name_or_path", type=str, default=None) - parser.add_argument("--load_distill_checkpoint", type=str, default=None) - parser.add_argument("--resume_distill", action='store_true', default=None) - - parser.add_argument("--load_finetune_checkpoint", type=str, default=None) - parser.add_argument("--resume_finetune", action='store_true', default=None) - - # Override default configs - # Feature map / model - parser.add_argument("--attention_type", type=str, default=None) - parser.add_argument("--learned_kernel", type=str, default=None) # always - parser.add_argument("--lk_skip_connection", action='store_true', default=None) - parser.add_argument("--lk_zero_init", action='store_true', default=None) - parser.add_argument("--lk_normal_init", action='store_true', default=None) - parser.add_argument("--tie_qk_kernels", action='store_true', default=None) - parser.add_argument("--train_qk", action='store_true', default=None) - parser.add_argument("--state_chunk_len", type=int, default=None) - - # Training - parser.add_argument("--lr", type=float, default=None) - parser.add_argument("--weight_decay", type=float, default=None) - parser.add_argument("--optim", type=str, default=None) - parser.add_argument("--scheduler", type=str, default=None) - parser.add_argument("--gradient_accumulation_steps", type=int, default=None) - parser.add_argument("--num_train_epochs", type=int, default=None) - parser.add_argument("--max_steps", type=int, default=None) - parser.add_argument("--max_finetune_steps", type=int, default=None) - - parser.add_argument("--no_peft_grad_ckpt", action='store_true', default=None) - - ## Distributed training / Llama recipes - parser.add_argument("--enable_fsdp", action='store_true', default=None) - parser.add_argument("--low_cpu_fsdp", action='store_true', default=None) - parser.add_argument("--pure_bf16", action='store_true', default=None) - parser.add_argument("--fsdp_activation_checkpointing", action='store_true', default=None) - parser.add_argument("--fsdp_cpu_offload", action='store_true', default=None) - - # Dataloading - parser.add_argument("--batch_size", type=int, default=None) - parser.add_argument("--num_workers", type=int, default=None) - - # Evaluation - parser.add_argument("--no_init_eval", action='store_true', default=False) - parser.add_argument("--eval_steps", type=int, default=None) - parser.add_argument("--max_eval_batches", type=int, default=None) - - # Miscellaneous - parser.add_argument("--huggingface_token", type=str, default=None) - parser.add_argument("--checkpoint_dir", type=str, default='./checkpoints') # changed - parser.add_argument("--results_dir", type=str, default='./results') - parser.add_argument("--replicate", type=int, default=0) - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--verbose", action='store_true', default=None) - parser.add_argument("--no_cuda", action='store_true', default=None) - parser.add_argument("--no_wandb", action='store_true', default=None) - parser.add_argument("--wandb_entity", type=str, default='hazy-research') - parser.add_argument("--debug", action='store_true', default=None) - parser.add_argument("--no_attention_mask", action='store_true', default=None) - - args = parser.parse_args() - args.run_name = get_run_name_from_args(args) - return args - - -def setup_fsdp_config(config, args, checkpoint_name: str = 'finetune', output_dir: str = None): - """ - Hacky arguments for llama-recipes training function - -> hardcoded save path - """ - config.seed = args.seed - config.enable_fsdp = args.enable_fsdp - config.low_cpu_fsdp = args.low_cpu_fsdp - config.dist_checkpoint_root_folder = args.checkpoint_dir # '/home/mzhang/projects/lolcats/checkpoints/' - config.dist_checkpoint_folder = checkpoint_name - - config.model_name = args.run_name - config.use_peft = False # We have custom logic for saving PEFT modules - - if getattr(config, 'fsdp', None) is None: - config.save_model = True - config.run_validation = True - config.use_fp16 = False - config.save_model = True - config.save_optimizer = False - config.gradient_clipping = False - config.gradient_clipping_threshold = 1.0 - else: - for attr in ['save_model', 'run_validation', 'use_fp16', 'save_optimizer', - 'gradient_clipping', 'gradient_clipping_threshold']: - setattr(config, attr, getattr(config.fsdp, attr)) - config.output_dir = args.checkpoint_dir if output_dir is None else output_dir - config.save_metrics = not args.no_wandb - config.num_epochs = getattr(config.trainer, 'num_train_epochs', None) - config.num_train_steps = getattr(args, 'num_train_steps', None) # exit training loop early for debugging - config.eval_steps = getattr(config.trainer, 'eval_steps', None), # config.trainer.eval_steps # how many gradient updates before evaluating - return config - - -def main(): - # ------ - # SET UP - # ------ - args = get_args() - CHECKPOINT_DIR_405B = "/home/simarora/code/lolcats/checkpoints/" # "/data_ephemeral/sim/sharded_layers_405b/" - CHECKPOINT_MODEL_CONFIG = 'llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01' - - if args.enable_fsdp: - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - - # Where to save the output model checkpoints? - args.checkpoint_dir = join(args.checkpoint_dir, args.model_config) - if not os.path.isdir(args.checkpoint_dir) and ((args.enable_fsdp and rank == 0 and local_rank == 0) or not args.enable_fsdp): - os.makedirs(args.checkpoint_dir) - - # Save individual .pt model weights in a subdirectory - args.checkpoint_dir = join(args.checkpoint_dir, 'sharded_layers') - if not os.path.isdir(args.checkpoint_dir) and ((args.enable_fsdp and rank == 0 and local_rank == 0) or not args.enable_fsdp): - os.makedirs(args.checkpoint_dir) - - args.results_dir = join(args.results_dir, args.model_config) - if not os.path.isdir(args.results_dir): - os.makedirs(args.results_dir) - seed_everything(args.seed) - - # Load distillation + (hedgehog) attention configs - distill_config_path = join('./configs/experiment', f'{args.distill_config}.yaml') - distill_config = OmegaConf.load(distill_config_path) - distill_config = update_config_from_args(distill_config, args) - distill_config = setup_fsdp_config(distill_config, args, 'distill') # patch - - # for arg, argv in distill_config.trainer.items(): # legacy, should be removed - # if arg != 'name': - # setattr(args, arg, argv) - # for _config in ['dataloader', 'optimizer', 'lr_scheduler']: - # setattr(args, _config, OmegaConf.to_container(getattr(distill_config, _config))) - - fsdp_config = FSDP_CONFIG() - if is_xpu_available(): - torch.xpu.manual_seed(distill_config.seed) - torch.manual_seed(distill_config.seed) - import random - random.seed(distill_config.seed) - - from llama_recipes.utils.config_utils import (update_config,get_dataloader_kwargs,) - update_config((fsdp_config), **vars(args)) - - if args.enable_fsdp: - setup() - # torchrun specific - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - else: - rank = 0 - - if torch.distributed.is_initialized(): - if is_xpu_available(): - torch.xpu.set_device(local_rank) - elif torch.cuda.is_available(): - torch.cuda.set_device(local_rank) - clear_gpu_cache(local_rank) - setup_environ_flags(rank) - - wandb_run = None - if not args.no_wandb: - if not args.enable_fsdp or rank==0: - _run_name = args.run_name - kwargs = vars(args) - kwargs['run_name'] = _run_name - if args.final_finetune_config is not None: # Update checkpoint for e2e finetune and lora loading - kwargs['run_name'] += f'-ef={args.final_finetune_config}' - kwargs['run_name'] += f'-ft_lora={args.load_finetuned_loras}'.replace('True', '1').replace('False', '0') - wandb_run = setup_wandb(distill_config, fsdp_config, **kwargs, - project=args.project_name, entity=args.wandb_entity) - - model_config_path = join('./configs/model', f'{args.model_config}.yaml') - model_config = OmegaConf.load(model_config_path) - model_config = update_model_config_from_args(model_config, args) - if args.enable_fsdp: - if getattr(model_config.model, 'load_in_4bit', False): - model_config.model.device_map = 'auto' - elif getattr(model_config.model, 'load_in_8bit', False): - model_config.model.device_map = 'auto' - else: - model_config.model.device_map = None # FSDP will complain about device placement o.w. - model_config.model.low_cpu_mem_usage = True - - try: - if not os.path.exists(model_config.model.pretrained_model_name_or_path): - print(f"Model path {model_config.model.pretrained_model_name_or_path} does not exist. Using backup path. {model_config.model.pretrained_model_name_or_path_backup}") - model_config.model.pretrained_model_name_or_path = model_config.model.pretrained_model_name_or_path_backup - model_config.model.pop("pretrained_model_name_or_path_backup") - except: - print(f"Model without model.pretrained_model_name_or_path_backup path") - pass - - if rank == 0 or not args.enable_fsdp: - print_header('Model Config') - print_config(model_config) - - # Get model class and configs for layer instantiating - pretrained_model_config = LlamaConfig.from_pretrained(model_config['model']['pretrained_model_name_or_path']) - pretrained_model_class = pretrained_model_config.architectures[0] - transformers_module = __import__('transformers') - pretrained_model_class = getattr(transformers_module, pretrained_model_class) # e.g, LlamaForCausalLM - - # Final run name / checkpoint naming setup - num_hidden_layers = pretrained_model_config.num_hidden_layers # e.g., 32 for Llama 8B - max_digits = len(str(num_hidden_layers)) - start, end = args.layer_idx, args.layer_idx + args.layers_per_model - 1 - # name_suffix = f'in={start:0{max_digits}d}-out={end:0{max_digits}d}' - # args.run_name += f'-{name_suffix}' # will save layer-wise checkpoints - args.run_name = args.run_name.replace('True', '1').replace('False', '0') # concise hacks - if rank == 0 or not args.enable_fsdp: - print(f"Running distill for {num_hidden_layers}; Layers {start} through {end}!") - print(f"{args.run_name=}") - - dtype = getattr(torch, model_config['model']['torch_dtype']) - if rank == 0 or not args.enable_fsdp: - print_header('Pretrained Model Config') - print(pretrained_model_config) - - # Step 1. Load the pretrained model and tokenizer. - if rank == 0 or not args.enable_fsdp: - print(model_config) - model_loader = get_pretrained_loader(**model_config.model, - huggingface_token=args.huggingface_token) - model = model_loader.load(model_type='softmax') - if rank == 0 or not args.enable_fsdp: - print(model) - if args.enable_fsdp and fsdp_config.pure_bf16: - model.to(torch.bfloat16) - for p in model.parameters(): # Freeze all layers - p.requires_grad = False - model.eval() - - tokenizer = model_loader.load_tokenizer() - tokenizer.pad_token_id = tokenizer.eos_token_id - tokenizer.padding_side = 'left' - - if rank == 0: print(f"Loaded the model.") - - # Step 2. Linearized template. - model = load_and_convert_attns( - model, - model_config, - attention_type=None, # specified in model_config, - checkpoint_path=None, - print_model=args.verbose, - train_attention=False)[0] - - # Step 3. Loop through the saved checkpoints. - def check_state_dict_keys(_keys, layer_idx): - try: - assert len(_keys.unexpected_keys) == 0 - if rank == 0: - print_header(f'*** All expected keys matched successfully {layer_idx} ***') - except Exception as e: - if rank == 0: - print(e) - print_header('*** Error: unexpected keys in checkpoint ***') - print(f'Unexpected keys at {layer_idx}:') - for k in _keys.unexpected_keys: - print(k) - - def rename_state_dict(rename_dict, start_layer_idx): - new_state_dict = {} - for k, v in rename_dict.items(): - if "layers" in k: - k_name = k.split("layers.")[-1].split(".")[0] - k_idx = int(k_name) - new_k_idx = k_idx + start_layer_idx - new_k_name = k.replace(k_name, str(new_k_idx)) - new_state_dict[new_k_name] = v - if start_layer_idx > 9 and start_layer_idx < 18: - print(f"Renaming {k} to {new_k_name}") - else: - new_state_dict[k] = v - return new_state_dict - - if not args.no_distill: - with torch.no_grad(): - first = 0 - for layer_idx, layer in enumerate(tqdm(traverse_layers(model))): - # file name - # load_file_name = f'{join(CHECKPOINT_DIR_405B, args.run_name)}' - # max_digits = len(str(num_hidden_layers)) - # start, end = first, first + (args.layers_per_model - 1) - # name_suffix = f'in={start:0{max_digits}d}-out={end:0{max_digits}d}' - # load_file_name += f'-{name_suffix}' - # load_file_name = load_file_name.replace('True', '1').replace('False', '0') # concise hacks - # load_file_name = load_file_name + f'_distill.pt' - - load_file_name = join(CHECKPOINT_DIR_405B, f'dl-d={args.distill_config}-m={CHECKPOINT_MODEL_CONFIG}-f={args.finetune_config}') - load_file_name += f'-s={args.seed}-se={args.seed}-re={args.replicate}' - max_digits = len(str(num_hidden_layers)) - start, end = first, first + (args.layers_per_model - 1) - name_suffix = f'in={start:0{max_digits}d}-out={end:0{max_digits}d}' - load_file_name += f'-{name_suffix}_distill.pt' - load_file_name = load_file_name.replace('True', '1').replace('False', '0') # concise hacks - - if (layer_idx + 1) % args.layers_per_model == 0: - if rank == 0 or not args.enable_fsdp: - print(f'Loading layer attentions from {CHECKPOINT_DIR_405B}...') - mini_weights = torch.load(load_file_name)['model_state_dict'] - mini_weights = rename_state_dict(mini_weights, first) - _keys = model.load_state_dict(mini_weights, strict=False) - check_state_dict_keys(_keys, first) - first = layer_idx + 1 - # args.run_name += f'-{name_suffix}' # dealing with legacy naming - - - # Step 4. Add finetuning parameters. - final_finetune_config, args = prepare_finetune_configs(args, model_config, - args.final_finetune_config) - final_finetune_config = setup_fsdp_config(final_finetune_config, args, 'finetune', - output_dir='/home/mzhang/projects/lolcats/results/llama3_1_405b') # hardcode - - args.finetune_lr = None - # if args.finetune_lr is not None: - # final_finetune_config.model_name += f'=flr={args.finetune_lr}' - - model, _ = load_and_convert_finetune(model, final_finetune_config, - checkpoint_path=None, - print_model=args.verbose, - merge_loras=False, - peft_gradient_checkpointing=not args.no_peft_grad_ckpt, - rank=rank) - - - # Step 5. Add the lora weights from mini-distill. - if args.load_finetuned_loras: - print(f"Loading loras") - with torch.no_grad(): - first = 0 - # args.run_name = finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-bs=1-gas=8-nte=2-ms=-1-se=0-re=0-in=000-out=008_ft.pt - - for layer_idx, layer in enumerate(tqdm(traverse_layers(model))): - # example file names: - # ./checkpoints/ft-dl-d=0000_out=008_distill0d-m=llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01-f=llama3_1_405b/finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-in=000-out=008-se=0-re=0_ft.pt - # ./checkpoints/ft-dl-d=0001_out=125_distill1d-m=llama3_1_405b/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01-f=llama3_1_405b/finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-in=117-out=125-se=0-re=0_ft.pt - # 'finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-in=099-out=107_distill.pt' - # 'finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-in=108-out=116_distill.pt' - # 'finetune_layer_mini_xent1_mse1000-s=0-se=0-re=0-in=117-out=125_distill.pt' - - # Get distill checkpoint name first - # load_file_name = f'{join(args.checkpoint_dir, args.run_name)}' - load_file_name = join(CHECKPOINT_DIR_405B, f'dl-d={args.distill_config}-m={CHECKPOINT_MODEL_CONFIG}-f={args.finetune_config}') - load_file_name += f'-s={args.seed}-se={args.seed}-re={args.replicate}' - max_digits = len(str(num_hidden_layers)) - start, end = first, first + (args.layers_per_model - 1) - name_suffix = f'in={start:0{max_digits}d}-out={end:0{max_digits}d}' - load_file_name += f'-{name_suffix}_distill.pt' - load_file_name = load_file_name.replace('True', '1').replace('False', '0') # concise hacks - args.load_distill_checkpoint = load_file_name - print('args.load_distill_checkpoint:', args.load_distill_checkpoint) - args.run_name = get_run_name_from_args(args) - - args.run_name = join(CHECKPOINT_DIR_405B, 'ft-' + args.run_name) - # update_config_from_args(final_finetune_config, args) - args.run_name += f'-se={args.seed}-re={args.replicate}-{name_suffix}-se={args.seed}-re={args.replicate}_ft.pt' - load_file_name = args.run_name.replace('True', '1').replace('False', '0') # concise hacks - - if (layer_idx + 1) % args.layers_per_model == 0: - if rank == 0 or not args.enable_fsdp: - print(f'Loading layer loras from {CHECKPOINT_DIR_405B}...') - mini_weights = torch.load(load_file_name)['model_state_dict'] - mini_weights = rename_state_dict(mini_weights, first) - _keys = model.load_state_dict(mini_weights, strict=False) - check_state_dict_keys(_keys, first) - first = layer_idx + 1 - - # Actual final run name / checkpoint naming - if (args.final_finetune_config is not None and - f'-ef={args.final_finetune_config}' not in args.run_name): # Update checkpoint for e2e finetune and lora loading - args.run_name += f'-ef={args.final_finetune_config}' - args.run_name += f'-ft_lora={args.load_finetuned_loras}'.replace('True', '1').replace('False', '0') - if args.no_distill: - args.run_name += '-no_distill' - - if args.load_finetune_checkpoint is not None: - model = load_sharded_model_single_gpu(model, model_path=args.load_finetune_checkpoint, cfg=final_finetune_config, rank=rank) - else: - model = load_sharded_model_single_gpu(model, model_path=None, cfg=final_finetune_config, rank=rank) - - - if rank == 0: # debugging - print_header('** Sanity check model weights **') - for n, p in model.named_parameters(): - if ('layers.0.' in n and 'base_attn' not in n and - '.0.mlp.' not in n and '.block_sparse_moe' not in n): - print(f'-> {n}:\n', p) - - args.run_name = args.run_name.replace('llama3_1_405b/', '') - - print(model.config) - # model.save_pretrained(f'ckpt-lora_hf-{args.run_name}') - if rank == 0: - with torch.no_grad(): - state_dict = model.state_dict() - keys_to_keep = [key for key in state_dict.keys() if 'lora' in key or 'window_factors' in key or 'feature_map' in key] - # keys_to_keep = [key for key in state_dict.keys() if 'lora' in key] - new_state_dict = {key: state_dict[key] for key in keys_to_keep} - torch.save(new_state_dict, f'ckpt_lora-{args.run_name}.pt') - - print_header('*** Weights in state_dict ***') - for k in torch.load(f'ckpt_lora-{args.run_name}.pt'): - if int(k.split('layers.')[-1].split('.')[0]) < 1: - print(k) - print('-> Checkpoints saved to:', f'ckpt_lora-{args.run_name}.pt') - # save_file(new_state_dict, f"ckpt-{args.run_name}.safetensors") - - - -if __name__ == '__main__': - main() - print("Thanks for washing my dishes") - diff --git a/llama_recipes/save_llama_attn_inputs.py b/llama_recipes/save_llama_attn_inputs.py deleted file mode 100644 index dc8f621..0000000 --- a/llama_recipes/save_llama_attn_inputs.py +++ /dev/null @@ -1,356 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -""" -Finetune attention-swapped model. Rough adaptation of llama_recipes script for distillation. - -Example usage (using the same args as distill_llama.py for convenience (just swap the file called) -``` -torchrun --nnodes 1 --nproc_per_node 8 \ -llama_recipes/save_llama_attn_inputs.py \ ---model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_llama3_1_70b_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \ ---verbose --replicate 0 --seed 0 \ ---enable_fsdp --low_cpu_fsdp -``` - -"Another one" -``` -torchrun --nnodes 1 --nproc_per_node 1 \ -llama_recipes/save_llama_attn_inputs.py \ ---model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean \ ---verbose --replicate 0 --seed 0 \ ---enable_fsdp --low_cpu_fsdp -``` -""" -import os -from os.path import join -# import sys -# sys.path.append('/workspace/lolcats') # needed for vast-ai instances -import dataclasses -import random -import argparse # ours -from pkg_resources import packaging - -import torch -import torch.optim as optim -import torch.distributed as dist - -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - ShardingStrategy, - StateDictType -) - -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload - -from llama_recipes.configs import fsdp_config as FSDP_CONFIG -# from llama_recipes.configs import train_config as TRAIN_CONFIG -from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing - -from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy -from llama_recipes.utils.config_utils import ( - update_config, -) -from llama_recipes.utils.fsdp_utils import ( - hsdp_device_mesh as get_hsdp_device_mesh -) -from llama_recipes.trainer_finetune import ( - train as _train_normal, - setup, - setup_environ_flags, - clear_gpu_cache, - print_model_size, - get_policies, -) -from llama_recipes.model_checkpointing.distill_checkpoint_handler import ( - load_model_sharded, - load_sharded_model_single_gpu, -) -# from llama_recipes.trainer_finetune_chunked import train as train_chunked - -from accelerate.utils import is_xpu_available - -# ------------- -# Our arguments -# ------------- -from omegaconf import OmegaConf - -from src.utils.setup import ( - update_config_from_args, - update_model_config_from_args -) -from src.utils.logging import print_header, print_config -from src.trainer import get_scheduler - -from src.finetune import prepare_finetune_configs # get_finetuner - -from src.model.pretrained import get_pretrained_loader -from src.model.load_model import ( - load_and_convert_attns, - load_and_convert_finetune -) -from distill_llama import ( - setup_wandb, get_args, # get_run_name_from_checkpoint - get_dataloaders, setup_fsdp_config -) - -from tqdm import tqdm - - -CUTOFF_BATCH = 500 # Save to disk and delete tensors every CUTOFF_BATCH - # to save CPU memory - - -def main(): - # --------- - # 1. SET UP - # --------- - args = get_args() - args.checkpoint_dir = join(args.checkpoint_dir, args.model_config) - if not os.path.isdir(args.checkpoint_dir): - os.makedirs(args.checkpoint_dir) - - kwargs = vars(args) - - if args.enable_fsdp: - setup() - # torchrun specific - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - else: - rank = 0 - - # Load distillation config - # -> Compute layer-wise outputs with the dataset specified here - distill_config_path = join('./configs/experiment', f'{args.distill_config}.yaml') - distill_config = OmegaConf.load(distill_config_path) - dataset_name = distill_config.dataset.name - cache_dir = distill_config.dataset.dataset_config.cache_dir - - # Load model config - model_config_path = join('./configs/model', f'{args.model_config}.yaml') - model_config = OmegaConf.load(model_config_path) - model_name = model_config.model.pretrained_model_name_or_path.replace('/', '_') - - # Create data directory where we'll store the layer-wise input tensors - if rank == 0 or not args.enable_fsdp: - data_dir = join(cache_dir, dataset_name) - if not os.path.isdir(data_dir): - os.makedirs(data_dir) - data_dir = join(data_dir, model_name) # meta-llama_Meta-Llama-3.1-70B/attn_inputs-l=31-split=train.pt - if not os.path.isdir(data_dir): - os.makedirs(data_dir) - print(f'-> Saving layer-wise tensors to {data_dir}') - # dist.barrier() - - # Copied from distill_llama.py and distill_llama_finetune.py for FSPD - if args.enable_fsdp: - if getattr(model_config.model, 'load_in_4bit', False): - model_config.model.device_map = 'auto' - elif getattr(model_config.model, 'load_in_8bit', False): - model_config.model.device_map = 'auto' - else: - model_config.model.device_map = None # FSDP will complain about device placement o.w. - - # Update dataset pretrained model config - for k in distill_config.dataset.pretrained_model_config: - distill_config.dataset.pretrained_model_config[k] = getattr(model_config.model, k) - - # Update the configuration for the training and sharding process - distill_config = setup_fsdp_config(distill_config, args, 'distill') # patch llama-recipes args - fsdp_config = FSDP_CONFIG() - update_config((fsdp_config), **vars(args)) - # Set the seeds for reproducibility - if is_xpu_available(): - torch.xpu.manual_seed(args.seed) - torch.manual_seed(args.seed) - random.seed(args.seed) - - if rank == 0 or not args.enable_fsdp: - print_header('Distillation Config') - print_config(distill_config) - print_header('Model Config') - print_config(model_config) - print_header('FSDP Config') - print_config(dataclasses.asdict(fsdp_config)) - - if torch.distributed.is_initialized(): - if is_xpu_available(): - torch.xpu.set_device(local_rank) - elif torch.cuda.is_available(): - torch.cuda.set_device(local_rank) - clear_gpu_cache(local_rank) - setup_environ_flags(rank) - - # ------------------------ - # 2. LOAD PRETRAINED MODEL - # ------------------------ - - # Load the pre-trained model and setup its configuration - # Initialize tokenizer and model loader - model_loader = get_pretrained_loader(**model_config.model, - huggingface_token=args.huggingface_token) - tokenizer = model_loader.load_tokenizer() - tokenizer.pad_token_id = tokenizer.eos_token_id - tokenizer.padding_side = 'left' - - use_cache = False if args.enable_fsdp else None - - if 'llama' in model_config.model.pretrained_model_name_or_path: - from transformers import LlamaConfig as ModelConfig - from transformers.models.llama.modeling_llama import LlamaDecoderLayer as DecoderLayer - from src.model.modeling_llama import LolcatsLlamaForCausalLM as ModelClass - model_type = 'llama' - - # Convert model - try: - args.attention_type = model_config['attention']['attention_type'] - except AttributeError: - args.attention_type = 'lolcats_llama' - - if args.enable_fsdp and args.low_cpu_fsdp: - # for FSDP, we can save cpu memory by loading pretrained model on rank0 only. - # this avoids cpu oom when loading large models like llama 70B, in which case - # model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms - # overhead and currently requires latest nightly. - v = packaging.version.parse(torch.__version__) - verify_latest_nightly = v.is_devrelease and v.dev >= 20230701 - if not verify_latest_nightly and rank == 0: - print(f'-> Pytorch version is {v} ({v.dev})') - print(f' - Llama-recipes says "latest pytorch nightly build is required to run with low_cpu_fsdp config"') - print(f" - But who knows maybe this will work. We're just trying stuff.") - print(f" - (Also if PyTorch was installed after July 1, 2023 we should be good)") - # raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, " - # "please install latest nightly.") - model = model_loader.load('softmax') # Load the original Transformer models - - if rank == 0 or not args.enable_fsdp: - print_header('Pretrained Model') - - model_config.model_name = model_config.model.pretrained_model_name_or_path - print_model_size(model, model_config, rank if args.enable_fsdp else 0) - - # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled - if args.enable_fsdp and fsdp_config.pure_bf16: - model.to(torch.bfloat16) - - model.eval() - for p in model.parameters(): - p.requires_grad = False - - # -------------- - # 5. SETUP FSDP - # -------------- - if args.enable_fsdp: - mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank, model=model_type) - my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, DecoderLayer) - - device_id = 0 - if is_xpu_available(): - device_id = torch.xpu.current_device() - elif torch.cuda.is_available(): - device_id = torch.cuda.current_device() - print('-> device_id:', device_id) - - model = FSDP( - model, - auto_wrap_policy=my_auto_wrapping_policy, # if train_config.use_peft else wrapping_policy, - cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None, - mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None, - sharding_strategy=fsdp_config.sharding_strategy, - # device_mesh=hsdp_device_mesh, - device_id=device_id, - limit_all_gathers=True, - sync_module_states=args.low_cpu_fsdp, # train_config.low_cpu_fsdp - param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False) - if args.low_cpu_fsdp and rank != 0 else None, - ) - if fsdp_config.fsdp_activation_checkpointing: - apply_fsdp_checkpointing(model) - - - else: # if not model_config.model.quantization and not args.enable_fsdp: - if is_xpu_available(): - model.to("xpu:0") - elif torch.cuda.is_available(): - model.to("cuda") - - if args.verbose and (rank == 0 or not args.enable_fsdp): - print_header('*** FSDP MODEL ***') - print(model) - - # Get data - train_dataloader, eval_dataloader, distill_config = get_dataloaders(distill_config, tokenizer, - no_shuffle_train = True) - if not args.enable_fsdp or rank == 0: - print(f"--> Training Set Length = {len(train_dataloader.dataset)}") - print(f"--> Validation Set Length = {len(eval_dataloader.dataset)}") - - # ----------------------------------- - # Compute dataset layer-wise outputs - # ----------------------------------- - for split, dataloader in {'train': train_dataloader, 'validation': eval_dataloader}.items(): - if rank == 0 or not args.enable_fsdp: - print_header(f'*** Computing layer-wise {split} outputs ***') - - attn_inputs_by_layer = [[] for _ in range(len(model.model.layers))] - max_layer_digits = len(str(len(attn_inputs_by_layer))) - with torch.no_grad(): - model.eval() - pbar = tqdm(dataloader, desc=f'❯❯❯ Computing layer-wise outputs on rank {rank} for {split} split') - max_digits = len(str(len(pbar))) - for step, batch in enumerate(pbar): - batch = {'input_ids': batch['input_ids']} - key = 'input_ids' - if distill_config.enable_fsdp: - if is_xpu_available(): - batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}")) - else: - batch[key] = batch[key].to(local_rank) - else: - if is_xpu_available(): - batch[key] = batch[key].to('xpu:0') - else: - batch[key] = batch[key].to('cuda:0') - # Save tensors -> HuggingFace Llama API - outputs = model(**batch, output_hidden_states=True, use_cache=False).get('hidden_states') - for idx, hidden_state in enumerate(outputs[:-1]): # indexed by layer, last layer is hidden layer before lm_head - hidden_state = model.model.layers[idx].input_layernorm(hidden_state).cpu() - attn_inputs_by_layer[idx].append(hidden_state) - - if args.enable_fsdp: - dist.barrier() - - if (step + 1) % CUTOFF_BATCH == 0: - # Save layer-wise outputs to disk - for layer_idx, attn_inputs in enumerate(attn_inputs_by_layer): - attn_inputs = torch.cat(attn_inputs, dim=0) # attn_inputs.shape is (batch, seq_len, hidden_size) - fpath = join(data_dir, f'attn_inputs-l={layer_idx:0{max_layer_digits}d}-s={split}-b={step:0{max_digits}d}-rank={rank:02d}.pt') - torch.save(attn_inputs, fpath) - if rank == 0 or not args.enable_fsdp: - print(f'-> Saved layer-wise tensors for {split} to {data_dir}!') - print(f'-> Example: {fpath}') - del attn_inputs_by_layer - attn_inputs_by_layer = [[] for _ in range(len(model.model.layers))] - - if args.enable_fsdp: - dist.barrier() - - # Save layer-wise outputs to disk - for layer_idx, attn_inputs in enumerate(attn_inputs_by_layer): - attn_inputs = torch.cat(attn_inputs, dim=0) # attn_inputs.shape is (batch, seq_len, hidden_size) - fpath = join(data_dir, f'attn_inputs-l={layer_idx:0{max_layer_digits}d}-s={split}-b={step:0{max_digits}d}-rank={rank:02d}.pt') - torch.save(attn_inputs, fpath) - if rank == 0 or not args.enable_fsdp: - print(f'-> Saved layer-wise tensors for {split} to {data_dir}!') - print(f'-> Example: {fpath}') - if args.enable_fsdp: - dist.barrier() - -if __name__ == "__main__": - main() diff --git a/llama_recipes/save_llama_outputs.py b/llama_recipes/save_llama_outputs.py deleted file mode 100644 index 28861e7..0000000 --- a/llama_recipes/save_llama_outputs.py +++ /dev/null @@ -1,317 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -""" -Finetune attention-swapped model. Rough adaptation of llama_recipes script for distillation. - -Example usage (using the same args as distill_llama.py for convenience (just swap the file called) -``` -torchrun --nnodes 1 --nproc_per_node 8 \ -llama_recipes/save_llama_outputs.py \ ---model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_llama3_1_70b_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpca_clean_llama3_1_70b \ ---verbose --replicate 4 --seed 0 \ ---enable_fsdp --low_cpu_fsdp -``` -""" -import os -from os.path import join -# import sys -# sys.path.append('/workspace/lolcats') # needed for vast-ai instances -import dataclasses -import random -import argparse # ours -from pkg_resources import packaging - -import torch -import torch.optim as optim -import torch.distributed as dist - -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - ShardingStrategy, - StateDictType -) - -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload - -from llama_recipes.configs import fsdp_config as FSDP_CONFIG -# from llama_recipes.configs import train_config as TRAIN_CONFIG -from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing - -from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy -from llama_recipes.utils.config_utils import ( - update_config, -) -from llama_recipes.utils.fsdp_utils import ( - hsdp_device_mesh as get_hsdp_device_mesh -) -from llama_recipes.trainer_finetune import ( - train as _train_normal, - setup, - setup_environ_flags, - clear_gpu_cache, - print_model_size, - get_policies, -) -from llama_recipes.model_checkpointing.distill_checkpoint_handler import ( - load_model_sharded, - load_sharded_model_single_gpu, -) -from llama_recipes.trainer_finetune_chunked import train as train_chunked - -from accelerate.utils import is_xpu_available - -# ------------- -# Our arguments -# ------------- -from omegaconf import OmegaConf - -from src.utils.setup import ( - update_config_from_args, - update_model_config_from_args -) -from src.utils.logging import print_header, print_config -from src.trainer import get_scheduler - -from src.finetune import prepare_finetune_configs # get_finetuner - -from src.model.pretrained import get_pretrained_loader -from src.model.load_model import ( - load_and_convert_attns, - load_and_convert_finetune -) -from distill_llama import ( - setup_wandb, get_args, # get_run_name_from_checkpoint - get_dataloaders, setup_fsdp_config -) - -from tqdm import tqdm - - -def main(): - # --------- - # 1. SET UP - # --------- - args = get_args() - args.checkpoint_dir = join(args.checkpoint_dir, args.model_config) - if not os.path.isdir(args.checkpoint_dir): - os.makedirs(args.checkpoint_dir) - - kwargs = vars(args) - - if args.enable_fsdp: - setup() - # torchrun specific - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - # Load distillation config - # -> Compute layer-wise outputs with the dataset specified here - distill_config_path = join('./configs/experiment', f'{args.distill_config}.yaml') - distill_config = OmegaConf.load(distill_config_path) - dataset_name = distill_config.dataset.name - cache_dir = distill_config.dataset.dataset_config.cache_dir - - # Load model config - model_config_path = join('./configs/model', f'{args.model_config}.yaml') - model_config = OmegaConf.load(model_config_path) - model_name = model_config.model.pretrained_model_name_or_path.replace('/', '_') - - # Create data directory where we'll store the layer-wise output tensors - if rank == 0: - data_dir = join(cache_dir, dataset_name) #, model_name) - if not os.path.isdir(data_dir): - os.makedirs(data_dir) - data_dir = join(data_dir, model_name) - if not os.path.isdir(data_dir): - os.makedirs(data_dir) - print(f'-> Saving layer-wise tensors to {data_dir}') - # dist.barrier() - - # Copied from distill_llama.py and distill_llama_finetune.py for FSPD - if args.enable_fsdp: - if getattr(model_config.model, 'load_in_4bit', False): - model_config.model.device_map = 'auto' - elif getattr(model_config.model, 'load_in_8bit', False): - model_config.model.device_map = 'auto' - else: - model_config.model.device_map = None # FSDP will complain about device placement o.w. - - # Update dataset pretrained model config - for k in distill_config.dataset.pretrained_model_config: - distill_config.dataset.pretrained_model_config[k] = getattr(model_config.model, k) - - # Update the configuration for the training and sharding process - distill_config = setup_fsdp_config(distill_config, args, 'distill') # patch llama-recipes args - fsdp_config = FSDP_CONFIG() - update_config((fsdp_config), **vars(args)) - # Set the seeds for reproducibility - if is_xpu_available(): - torch.xpu.manual_seed(args.seed) - torch.manual_seed(args.seed) - random.seed(args.seed) - - if rank == 0 or not args.enable_fsdp: - print_header('Distillation Config') - print_config(distill_config) - print_header('Model Config') - print_config(model_config) - print_header('FSDP Config') - print_config(dataclasses.asdict(fsdp_config)) - - if torch.distributed.is_initialized(): - if is_xpu_available(): - torch.xpu.set_device(local_rank) - elif torch.cuda.is_available(): - torch.cuda.set_device(local_rank) - clear_gpu_cache(local_rank) - setup_environ_flags(rank) - - # ------------------------ - # 2. LOAD PRETRAINED MODEL - # ------------------------ - - # Load the pre-trained model and setup its configuration - # Initialize tokenizer and model loader - model_loader = get_pretrained_loader(**model_config.model, - huggingface_token=args.huggingface_token) - tokenizer = model_loader.load_tokenizer() - tokenizer.pad_token_id = tokenizer.eos_token_id - tokenizer.padding_side = 'left' - - use_cache = False if args.enable_fsdp else None - - if 'llama' in model_config.model.pretrained_model_name_or_path: - from transformers import LlamaConfig as ModelConfig - from transformers.models.llama.modeling_llama import LlamaDecoderLayer as DecoderLayer - from src.model.modeling_llama import LolcatsLlamaForCausalLM as ModelClass - model_type = 'llama' - - # Convert model - try: - args.attention_type = model_config['attention']['attention_type'] - except AttributeError: - args.attention_type = 'lolcats_llama' - - if args.enable_fsdp and args.low_cpu_fsdp: - # for FSDP, we can save cpu memory by loading pretrained model on rank0 only. - # this avoids cpu oom when loading large models like llama 70B, in which case - # model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms - # overhead and currently requires latest nightly. - v = packaging.version.parse(torch.__version__) - verify_latest_nightly = v.is_devrelease and v.dev >= 20230701 - if not verify_latest_nightly and rank == 0: - print(f'-> Pytorch version is {v} ({v.dev})') - print(f' - Llama-recipes says "latest pytorch nightly build is required to run with low_cpu_fsdp config"') - print(f" - But who knows maybe this will work. We're just trying stuff.") - print(f" - (Also if PyTorch was installed after July 1, 2023 we should be good)") - # raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, " - # "please install latest nightly.") - model = model_loader.load('softmax') # Load the original Transformer models - - if rank == 0 or not args.enable_fsdp: - print_header('Pretrained Model') - - model_config.model_name = model_config.model.pretrained_model_name_or_path - print_model_size(model, model_config, rank if args.enable_fsdp else 0) - - # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled - if args.enable_fsdp and fsdp_config.pure_bf16: - model.to(torch.bfloat16) - - model.eval() - for p in model.parameters(): - p.requires_grad = False - - # -------------- - # 5. SETUP FSDP - # -------------- - if args.enable_fsdp: - mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank, model=model_type) - my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, DecoderLayer) - - device_id = 0 - if is_xpu_available(): - device_id = torch.xpu.current_device() - elif torch.cuda.is_available(): - device_id = torch.cuda.current_device() - print('-> device_id:', device_id) - - model = FSDP( - model, - auto_wrap_policy=my_auto_wrapping_policy, # if train_config.use_peft else wrapping_policy, - cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None, - mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None, - sharding_strategy=fsdp_config.sharding_strategy, - # device_mesh=hsdp_device_mesh, - device_id=device_id, - limit_all_gathers=True, - sync_module_states=args.low_cpu_fsdp, # train_config.low_cpu_fsdp - param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False) - if args.low_cpu_fsdp and rank != 0 else None, - ) - if fsdp_config.fsdp_activation_checkpointing: - apply_fsdp_checkpointing(model) - - - elif not model_config.model.quantization and not args.enable_fsdp: - if is_xpu_available(): - model.to("xpu:0") - elif torch.cuda.is_available(): - model.to("cuda") - - if args.verbose and (rank == 0 or not args.enable_fsdp): - print_header('*** FSDP MODEL ***') - print(model) - - # Get data - train_dataloader, eval_dataloader, distill_config = get_dataloaders(distill_config, tokenizer) - if not args.enable_fsdp or rank == 0: - print(f"--> Training Set Length = {len(train_dataloader.dataset)}") - print(f"--> Validation Set Length = {len(eval_dataloader.dataset)}") - - # ----------------------------------- - # Compute dataset layer-wise outputs - # ----------------------------------- - for split, dataloader in {'train': train_dataloader, 'validation': eval_dataloader}.items(): - if rank == 0 or not args.enable_fsdp: - print_header(f'*** Computing layer-wise {split} outputs ***') - - attn_inputs_by_layer = [] - with torch.no_grad(): - model.eval() - pbar = tqdm(dataloader, desc=f'❯❯❯ Computing layer-wise outputs on rank {rank} for {split} split') - for step, batch in enumerate(pbar): - batch = {'input_ids': batch['input_ids']} - key = 'input_ids' - if distill_config.enable_fsdp: - if is_xpu_available(): - batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}")) - else: - batch[key] = batch[key].to(local_rank) - else: - if is_xpu_available(): - batch[key] = batch[key].to('xpu:0') - else: - batch[key] = batch[key].to('cuda:0') - # Save tensors -> HuggingFace Llama API - outputs = model(**batch, output_hidden_states=True, use_cache=False).get('hidden_states') - for idx, hidden_state in enumerate(outputs[:-1]): # indexed by layer, last layer is hidden layer before lm_head - hidden_state = model.model.layers[idx].input_layernorm(hidden_state).cpu() - if idx <= len(attn_inputs_by_layer): - attn_inputs_by_layer.append([hidden_state]) - else: - attn_inputs_by_layer[idx].append(hidden_state) - dist.barrier() - - # Save layer-wise outputs to disk - for layer_idx, attn_inputs in enumerate(attn_inputs_by_layer): - attn_inputs = torch.cat(attn_inputs, dim=0) # attn_inputs.shape is (batch, seq_len, hidden_size) - torch.save(attn_inputs, join(data_dir, f'attn_inputs-l={layer_idx}-s={split}.pt')) - print(f'-> Saved layer-wise tensors for {split} to {data_dir}!') - -if __name__ == "__main__": - main() diff --git a/llama_recipes/stitch_mini_fsdp.py b/llama_recipes/stitch_mini_fsdp.py deleted file mode 100644 index 381ef34..0000000 --- a/llama_recipes/stitch_mini_fsdp.py +++ /dev/null @@ -1,526 +0,0 @@ -""" -This file just needs to save out the shards for 405B. - -Notes: -- Make sure that register_buffer inv_freq persistent=True for your modeling_llama.py -""" - -import os -from os.path import join - -import argparse -from omegaconf import OmegaConf -from tqdm import tqdm - -import torch -import torch.optim as optim - -from transformers.models.llama.modeling_llama import ( - LlamaConfig, - LlamaDecoderLayer as DecoderLayer -) - -# Distributed arguments -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - ShardingStrategy, - StateDictType -) -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload - -from accelerate.utils import is_xpu_available -from llama_recipes.configs import fsdp_config as FSDP_CONFIG -from llama_recipes.policies import apply_fsdp_checkpointing -from llama_recipes.utils.fsdp_utils import ( - fsdp_auto_wrap_policy, - hsdp_device_mesh as get_hsdp_device_mesh -) -from llama_recipes.utils.config_utils import update_config - -# Our arguments -from llama_recipes.trainer_finetune import ( - train, - setup, - setup_environ_flags, - clear_gpu_cache, - get_policies, -) -from llama_recipes.model_checkpointing.distill_checkpoint_handler import ( - load_model_sharded, -) -from llama_recipes.distill_llama import ( - setup_wandb, get_args, # get_run_name_from_checkpoint - get_dataloaders, setup_fsdp_config -) - -from src.utils.setup import ( - seed_everything, get_run_name_from_args, - update_config_from_args, update_model_config_from_args, -) -from src.utils.logging import print_config, print_header -from src.model.pretrained import get_pretrained_loader - -from src.model.convert_model import toggle_attention, remove_base_attention, traverse_layers -from src.model.load_model import load_and_convert_attns, load_and_convert_finetune - -from src.trainer import get_scheduler -from src.finetune import prepare_finetune_configs # get_finetuner - - - -def get_args(): - """Parse command line arguments""" - parser = argparse.ArgumentParser() - parser.add_argument("--project_name", type=str, default='lolcats') - parser.add_argument("--layers_per_model", type=int) - parser.add_argument("--layer_idx", type=int) # specify starting layer - parser.add_argument("--device", type=int, default=0) - - parser.add_argument("--model_config", type=str, default=None) - parser.add_argument("--distill_config", type=str, default=None) - parser.add_argument("--finetune_config", type=str, default=None) - parser.add_argument("--eval_config", type=str, default=None) - - parser.add_argument("--load_finetuned_loras", action='store_true', default=False) - parser.add_argument("--e2e_finetune_config", type=str, default=None) - - parser.add_argument("--pretrained_model_name_or_path", type=str, default=None) - parser.add_argument("--load_distill_checkpoint", type=str, default=None) - parser.add_argument("--resume_distill", action='store_true', default=None) - - parser.add_argument("--load_finetune_checkpoint", type=str, default=None) - parser.add_argument("--resume_finetune", action='store_true', default=None) - - # Override default configs - # Feature map / model - parser.add_argument("--attention_type", type=str, default=None) - parser.add_argument("--learned_kernel", type=str, default=None) # always - parser.add_argument("--lk_skip_connection", action='store_true', default=None) - parser.add_argument("--lk_zero_init", action='store_true', default=None) - parser.add_argument("--lk_normal_init", action='store_true', default=None) - parser.add_argument("--tie_qk_kernels", action='store_true', default=None) - parser.add_argument("--train_qk", action='store_true', default=None) - parser.add_argument("--state_chunk_len", type=int, default=None) - - # Training - parser.add_argument("--lr", type=float, default=None) - parser.add_argument("--weight_decay", type=float, default=None) - parser.add_argument("--optim", type=str, default=None) - parser.add_argument("--scheduler", type=str, default=None) - parser.add_argument("--gradient_accumulation_steps", type=int, default=None) - parser.add_argument("--num_train_epochs", type=int, default=None) - parser.add_argument("--max_steps", type=int, default=None) - parser.add_argument("--max_finetune_steps", type=int, default=None) - - parser.add_argument("--no_peft_grad_ckpt", action='store_true', default=None) - - ## Distributed training / Llama recipes - parser.add_argument("--enable_fsdp", action='store_true', default=None) - parser.add_argument("--low_cpu_fsdp", action='store_true', default=None) - parser.add_argument("--pure_bf16", action='store_true', default=None) - parser.add_argument("--fsdp_activation_checkpointing", action='store_true', default=None) - parser.add_argument("--fsdp_cpu_offload", action='store_true', default=None) - - # Dataloading - parser.add_argument("--batch_size", type=int, default=None) - parser.add_argument("--num_workers", type=int, default=None) - - # Evaluation - parser.add_argument("--no_init_eval", action='store_true', default=False) - parser.add_argument("--eval_steps", type=int, default=None) - parser.add_argument("--max_eval_batches", type=int, default=None) - - # Miscellaneous - parser.add_argument("--huggingface_token", type=str, default=None) - parser.add_argument("--checkpoint_dir", type=str, default='./checkpoints') - parser.add_argument("--results_dir", type=str, default='./results') - parser.add_argument("--replicate", type=int, default=0) - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--verbose", action='store_true', default=None) - parser.add_argument("--no_cuda", action='store_true', default=None) - parser.add_argument("--no_wandb", action='store_true', default=None) - parser.add_argument("--wandb_entity", type=str, default='hazy-research') - parser.add_argument("--debug", action='store_true', default=None) - parser.add_argument("--no_attention_mask", action='store_true', default=None) - - args = parser.parse_args() - args.run_name = get_run_name_from_args(args) - return args - - -def check_state_dict_keys(keys: any, layer_idx: int, rank: int = 0, - state_dict: dict = None, verbose: bool = True) -> None: - """ - Check the state dict keys for unexpected and expected keys - - keys: the output from torch.load_state_dict() - - layer_idx: the current layer index - """ - try: - assert len(keys.unexpected_keys) == 0 - if rank == 0: - print_header(f'*** All expected keys matched successfully {layer_idx} ***') - if verbose and state_dict is not None: - print('Keys loaded:') - for k in state_dict: - print(f'├── {k}') - except Exception as e: - if rank == 0: - print(e) - print_header('*** Error: unexpected keys in checkpoint ***') - print(f'Unexpected keys at {layer_idx}:') - for k in keys.unexpected_keys: - print(k) - - -def rename_state_dict(rename_dict: dict, start_layer_idx: int, verbose: bool = False) -> dict: - """Rename the state dict from the mini models to match the full model""" - new_state_dict = {} - for k, v in rename_dict.items(): - if "layers" in k: - k_name = k.split("layers.")[-1].split(".")[0] - k_idx = int(k_name) - new_k_idx = k_idx + start_layer_idx - new_k_name = k.replace(k_name, str(new_k_idx)) - new_state_dict[new_k_name] = v - if verbose: # if start_layer_idx > 9 and start_layer_idx < 18: - print(f"-> Renaming {k} to {new_k_name}") - else: - new_state_dict[k] = v - return new_state_dict - - -def main(): - """Main script""" - # ------ - # SET UP - # ------ - args = get_args() - # args.checkpoint_dir = "/data_ephemeral/sim/sharded_layers_405b/" - args.checkpoint_dir = join(args.checkpoint_dir, args.model_config) - if not os.path.isdir(args.checkpoint_dir): - os.makedirs(args.checkpoint_dir) - # Save individual .pt model weights in a subdirectory - args.checkpoint_dir = join(args.checkpoint_dir, 'sharded_layers') - if not os.path.isdir(args.checkpoint_dir): - os.makedirs(args.checkpoint_dir) - args.results_dir = join(args.results_dir, args.model_config) - if not os.path.isdir(args.results_dir): - os.makedirs(args.results_dir) - seed_everything(args.seed) - - # Load distillation + (hedgehog) attention configs - distill_config_path = join('./configs/experiment', f'{args.distill_config}.yaml') - distill_config = OmegaConf.load(distill_config_path) - distill_config = update_config_from_args(distill_config, args) - - for arg, argv in distill_config.trainer.items(): # legacy, should be removed - if arg != 'name': - setattr(args, arg, argv) - for _config in ['dataloader', 'optimizer', 'lr_scheduler']: - setattr(args, _config, OmegaConf.to_container(getattr(distill_config, _config))) - - model_config_path = join('./configs/model', f'{args.model_config}.yaml') - model_config = OmegaConf.load(model_config_path) - model_config = update_model_config_from_args(model_config, args) - - # Update data tokenizer to match model (unused in this script) - if getattr(distill_config.dataset, 'pretrained_model_config', None) is not None: - for k in ['pretrained_model_name_or_path', 'cache_dir']: - distill_config.dataset.pretrained_model_config[k] = model_config.model[k] - - if args.enable_fsdp: - if getattr(model_config.model, 'load_in_4bit', False): - model_config.model.device_map = 'auto' - elif getattr(model_config.model, 'load_in_8bit', False): - model_config.model.device_map = 'auto' - else: - model_config.model.device_map = None # FSDP will complain about device placement o.w. - model_config.model.low_cpu_mem_usage = True - - # Setup FSDP if enabled - if args.enable_fsdp: - distill_config = setup_fsdp_config(distill_config, args, 'distill') # patch - fsdp_config = FSDP_CONFIG() - update_config((fsdp_config), **vars(args)) - setup() - # torchrun specific - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - # world_size = int(os.environ["WORLD_SIZE"]) - else: - fsdp_config = FSDP_CONFIG() # ignored - rank = 0 - - if torch.distributed.is_initialized(): - if is_xpu_available(): - torch.xpu.set_device(local_rank) - elif torch.cuda.is_available(): - torch.cuda.set_device(local_rank) - clear_gpu_cache(local_rank) - setup_environ_flags(rank) - - # WandB logging - wandb_run = None - if not args.no_wandb: - if not args.enable_fsdp or rank == 0: - wandb_run = setup_wandb(distill_config, fsdp_config, **vars(args), - project=args.project_name, entity=args.wandb_entity) - - # Loading model - try: - if not os.path.exists(model_config.model.pretrained_model_name_or_path): - print(f"Model path {model_config.model.pretrained_model_name_or_path} does not exist. Using backup path. {model_config.model.pretrained_model_name_or_path_backup}") - model_config.model.pretrained_model_name_or_path = model_config.model.pretrained_model_name_or_path_backup - model_config.model.pop("pretrained_model_name_or_path_backup") - except Exception as e: - print(f'-> Error: {e}') - print("Model without model.pretrained_model_name_or_path_backup path") - - if rank == 0 or not args.enable_fsdp: - print_header('Model Config') - print_config(model_config) - - # Get model class and configs for layer instantiating - pretrained_model_config = LlamaConfig.from_pretrained(model_config['model']['pretrained_model_name_or_path']) - pretrained_model_class = pretrained_model_config.architectures[0] - transformers_module = __import__('transformers') - pretrained_model_class = getattr(transformers_module, pretrained_model_class) # e.g, LlamaForCausalLM - - # ------------------------------------------- - # Step 1. Load pretrained model and tokenizer - # ------------------------------------------- - if rank == 0 or not args.enable_fsdp: - print_header('Pretrained Model Config') - print(pretrained_model_config) - print_header('Our Model Config') - print(model_config) - - model_loader = get_pretrained_loader(**model_config.model, - huggingface_token=args.huggingface_token) - # Model - model = model_loader.load(model_type='softmax') - if rank == 0 or not args.enable_fsdp: - print_header('Original Model') - print(model) - if args.enable_fsdp and fsdp_config.pure_bf16: - model.to(torch.bfloat16) - for p in model.parameters(): # Freeze all layers - p.requires_grad = False - model.eval() - # Tokenizer - tokenizer = model_loader.load_tokenizer() - tokenizer.pad_token_id = tokenizer.eos_token_id - tokenizer.padding_side = 'left' - - # --------------------------------------------------- - # Step 2. Convert attentions to linearized attentions - # --------------------------------------------------- - model = load_and_convert_attns(model, - model_config, - attention_type=None, # specified in model_config, - checkpoint_path=None, - print_model=args.verbose, - train_attention=False)[0] - if rank == 0 or not args.enable_fsdp: - print_header('Converted Model') - - # ------------------------------------------ - # Step 3. Loop through the saved checkpoints - # ------------------------------------------ - num_hidden_layers = pretrained_model_config.num_hidden_layers # e.g., 32 for Llama 8B - max_digits = len(str(num_hidden_layers)) # name_suffix = f'in={start:0{max_digits}d}-out={end:0{max_digits}d}' - with torch.no_grad(): - first = 0 - for layer_idx in range(tqdm(traverse_layers(model))): - if rank == 0 or not args.enable_fsdp: - print(f'Loading layer attentions from {args.checkpoint_dir}...') - load_file_name = f'{join(args.checkpoint_dir, args.run_name)}' - start, end = first, first + (args.layers_per_model - 1) - name_suffix = f'in={start:0{max_digits}d}-out={end:0{max_digits}d}' - load_file_name += f'-{name_suffix}' - load_file_name = load_file_name.replace('True', '1').replace('False', '0') # concise hacks - load_file_name = load_file_name + '_distill.pt' - - if (layer_idx + 1) % args.layers_per_model == 0: - if rank == 0 or not args.enable_fsdp: - mini_weights = torch.load(load_file_name)['model_state_dict'] - mini_weights = rename_state_dict(mini_weights, first) - _keys = model.load_state_dict(mini_weights, strict=False) - check_state_dict_keys(_keys, first, rank, mini_weights, verbose=args.verbose) - first = layer_idx + 1 - args.run_name += f'-{name_suffix}' # dealing with legacy naming - - # --------------------------------------- - # Step 4. Add end-to-end finetuning LoRAs - # --------------------------------------- - e2e_finetune_config, args = prepare_finetune_configs(args, model_config, - args.e2e_finetune_config) - e2e_finetune_config = setup_fsdp_config(e2e_finetune_config, args, 'finetune') - model, _ = load_and_convert_finetune(model, e2e_finetune_config, - checkpoint_path=None, - print_model=args.verbose, - merge_loras=False, - peft_gradient_checkpointing=not args.no_peft_grad_ckpt, - rank=rank) - - # ---------------------------------------------- - # Step 5. Add the LoRA weights from mini-distill - # ---------------------------------------------- - if args.load_finetuned_loras: - if args.enable_fsdp or rank == 0: - print("Loading loras") - with torch.no_grad(): - first = 0 - for layer_idx, layer in enumerate(tqdm(traverse_layers(model))): - print(f'Loading layer loras from {args.checkpoint_dir}...') - load_file_name = f'{join(args.checkpoint_dir, args.run_name)}' - start, end = first, first + (args.layers_per_model - 1) - name_suffix = f'in={start:0{max_digits}d}-out={end:0{max_digits}d}' - load_file_name += f'-{name_suffix}' - load_file_name = load_file_name.replace('True', '1').replace('False', '0') # concise hacks - load_file_name = load_file_name + '_ft.pt' - - if (layer_idx + 1) % args.layers_per_model == 0: - if rank == 0 or not args.enable_fsdp: - mini_weights = torch.load(load_file_name)['model_state_dict'] - mini_weights = rename_state_dict(mini_weights, first) - _keys = model.load_state_dict(mini_weights, strict=False) - check_state_dict_keys(_keys, first, rank, mini_weights, verbose=args.verbose) - first = layer_idx + 1 - - # Ignored - # hsdp_device_mesh = None - # if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD: - # hsdp_device_mesh = get_hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, - # sharding_group_size=fsdp_config.sharding_group_size) - # print("HSDP device mesh is ready") - - # Final run name / checkpoint naming setup - if args.e2e_finetune_config is not None: # Update checkpoint for e2e finetune and lora loading - args.run_name += f'-ef={args.e2e_finetune_config}' - args.run_name += f'-ft_lora={args.load_finetuned_loras}'.replace('True', '1').replace('False', '0') - args.run_name = args.run_name.replace('True', '1').replace('False', '0') # concise hacks - args.run_name = args.run_name.replace(f'-{name_suffix}', '') # remove the mini model suffix - args.run_name = args.run_name.replace(args.model_config, ''.join([c[0] + c[-1] for c in args.model_config.split('_')])) - args.run_name = args.run_name.replace(args.distill_config, ''.join([c[0] + c[-1] for c in args.distill_config.split('_')])) - args.run_name = args.run_name.replace(args.finetune_config, ''.join([c[0] + c[-1] for c in args.finetune_config.split('_')])) - - # ---------------------------- - # Step 6. Wrap model with FSDP - # ---------------------------- - if args.enable_fsdp: - mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank, model="llama") - my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, DecoderLayer) - device_id = 0 - if is_xpu_available(): - device_id = torch.xpu.current_device() - elif torch.cuda.is_available(): - device_id = torch.cuda.current_device() - print('-> device_id:', device_id) - - model = FSDP( - model, - auto_wrap_policy=my_auto_wrapping_policy, # if train_config.use_peft else wrapping_policy, - cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None, - mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None, - sharding_strategy=fsdp_config.sharding_strategy, - # device_mesh=hsdp_device_mesh, - device_id=device_id, - limit_all_gathers=True, - sync_module_states=args.low_cpu_fsdp, # train_config.low_cpu_fsdp - param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False) - if args.low_cpu_fsdp and rank != 0 else None, - ) - if fsdp_config.fsdp_activation_checkpointing: - apply_fsdp_checkpointing(model) - - # Load distilled checkpoints - if args.verbose and rank == 0: - print_header('*** FSDP Model ***') - print(model) - print('Loading checkpoints from:', distill_config.model_name) - - if rank == 0 or not args.enable_fsdp: # debugging - print_header('** Sanity check model weights **') - for n, p in model.named_parameters(): - if ('layers.0.' in n and 'base_attn' not in n and - '.0.mlp.' not in n and '.block_sparse_moe' not in n): - print(f'-> {n}:\n', p) - - # Initialize the optimizer and learning rate scheduler - optimizer = optim.AdamW( - model.parameters(), - lr=e2e_finetune_config.optimizer.lr, - weight_decay=getattr(distill_config.optimizer, 'weight_decay', 0.), - ) - scheduler = get_scheduler(optimizer=optimizer, **e2e_finetune_config.lr_scheduler) - - if args.verbose and (rank == 0 or not args.enable_fsdp): - print('-> Optimizer:', optimizer) - print('-> Scheduler:', scheduler) - print_header('*** FSDP MODEL ***') - print(model) - print_header('*** Trainable Parameters ***') - for n, p in model.named_parameters(): - if p.requires_grad: - print(f'├── {n} (dtype = {p.dtype})') - - train_dataloader, eval_dataloader, e2e_finetune_config = get_dataloaders(e2e_finetune_config, tokenizer) - if not args.enable_fsdp or rank == 0: - print(f"--> Training Set Length = {len(train_dataloader.dataset)}") - print(f"--> Validation Set Length = {len(eval_dataloader.dataset)}") - if args.debug: - print('-> local_rank:', local_rank) - x = next(iter(train_dataloader))['input_ids'] - x = x.to(local_rank) - print("-> x = next(iter(train_dataloader))['input_ids']") - print("-> x = x.to(local_rank)") - print('-> x.device:', x.device) - - # Step 7. Finetune the model - if rank == 0 or not args.enable_fsdp: - print_header('*** Training ***') - if args.verbose: - print_config(e2e_finetune_config) - - # Start the training process - # max_optimizer_steps = getattr(distill_config.optimizer, 'max_optimizer_steps', None) - results, best_checkpoint_path = train( - model, - train_dataloader, - eval_dataloader, - tokenizer, - optimizer, - scheduler, - gradient_accumulation_steps=e2e_finetune_config.trainer.gradient_accumulation_steps, - train_config=e2e_finetune_config, # train_config, - fsdp_config=fsdp_config if args.enable_fsdp else None, - local_rank=local_rank if args.enable_fsdp else None, - rank=rank if args.enable_fsdp else None, - wandb_run=wandb_run, - ) - - # Save best model checkpoint as single .pt file - if fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT: - pass # Model checkpoint already saved - elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT: - # Load sharded weights across GPUs into model - ignore_param_rule = lambda n, p: not p.requires_grad # and 'feature_map' not in n or ('v_proj' in n or 'o_proj' in n) - load_model_sharded(model, rank, final_finetune_config, ignore_param_rule) - if rank == 0: # debugging - print_header('** Sanity check model weights **') - for n, p in model.named_parameters(): - if ('layers.0.' in n and 'base_attn' not in n and - '.0.mlp.' not in n and '.block_sparse_moe' not in n): - print(f'-> {n}:\n', p) - - if not args.enable_fsdp or rank==0: - for k,v in results.items(): - print(f'Key: {k}, Value: {v}') - if not args.no_wandb: - wandb_run.summary[f'ft_{k}'] = v - print('-> Find weights at:', best_checkpoint_path) - - -if __name__ == '__main__': - main() diff --git a/llama_recipes/tools/README.md b/llama_recipes/tools/README.md deleted file mode 100644 index 95525f3..0000000 --- a/llama_recipes/tools/README.md +++ /dev/null @@ -1,22 +0,0 @@ -# Convert Hugging Face llama weights to official llama consolidated format - -This is the reverse conversion for `convert_llama_weights_to_hf.py` script from the transformer package. - -## Step 0: Convert to consolidated format -- Create an output directory for the converted weights, such as `test70B`. -- Copy file params.json from the official llama download into that directory. -- Run the conversion script. `model-path` can be a Hugging Face hub model or a local hf model directory. -``` -python -m llama_recipes.tools.convert_hf_weights_to_llama --model-path meta-llama/Meta-Llama-3.1-70B-Instruct --output-dir test70B --model-size 70B -``` - -## Step 1: Run inference -Checkout the official llama 3 inference [repo](https://github.com/meta-llama/llama3). Test using chat or text completion. -``` -torchrun --nproc_per_node 8 example_chat_completion.py --ckpt_dir ./test70B --tokenizer_path ${llama_3_dir}/tokenizer.model -``` - -For validation, please compare the converted weights with official llama 2 weights -``` -python compare_llama_weights.py test70B ${Llama-3-70B-Instruct_dir} -``` diff --git a/llama_recipes/tools/compare_llama_weights.py b/llama_recipes/tools/compare_llama_weights.py deleted file mode 100644 index 25d16aa..0000000 --- a/llama_recipes/tools/compare_llama_weights.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -import gc -import glob -import os -import sys - -import torch -import tqdm - - -def main() -> None: - """Compare two llama checkpoint directories""" - - one_files = sorted(glob.glob(os.path.join(sys.argv[1], "consolidated.*.pth"))) - two_files = sorted(glob.glob(os.path.join(sys.argv[2], "consolidated.*.pth"))) - assert len(one_files) == len( - two_files - ), "One directory has {} files while another has {} files.".format( - len(one_files), len(two_files) - ) - - deltas = [] - for i in tqdm.trange(len(one_files), desc="Comparing shards"): - one = torch.load(one_files[i]) - two = torch.load(two_files[i]) - assert len(one) == len( - two - ), "shard should have the same length: {} != {}".format(len(one), len(two)) - one = sorted(one.items(), key=lambda x: x[0]) - two = sorted(two.items(), key=lambda x: x[0]) - - for _, (v, w) in enumerate(zip(one, two)): - assert v[0] == w[0], "{} != {}".format(v[0], w[0]) - assert v[1].shape == w[1].shape, "tensor {} shape {} != {}".format( - v[0], v[1].shape, w[1].shape - ) - - delta = (v[1] - w[1]).abs().max().item() - deltas.append((i, v[0], delta, w[1].abs().mean().item())) - del one - del two - gc.collect() - - deltas = sorted(deltas, key=lambda x: x[-2], reverse=True) - print("Top 10 largest deltas:") - for i, k, delta, value in deltas[:10]: - print(f" shard {i} {k}: {delta} vs {value}") - - -if __name__ == "__main__": - main() diff --git a/llama_recipes/tools/convert_hf_weights_to_llama.py b/llama_recipes/tools/convert_hf_weights_to_llama.py deleted file mode 100644 index 356e4a4..0000000 --- a/llama_recipes/tools/convert_hf_weights_to_llama.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -import json -import os -from typing import List, Union - -import fire -import torch -from tqdm import tqdm -from transformers import LlamaForCausalLM # @manual - -NUM_SHARDS = { - "7B": 1, - "8B": 1, - "13B": 2, - "34B": 4, - "30B": 4, - "65B": 8, - "70B": 8, -} - - -def write_model(model_path, model_size, output_base_path): - dtype = torch.bfloat16 - - params = json.load(open(os.path.join(output_base_path, "params.json"), "r")) - num_shards = NUM_SHARDS[model_size] - n_layers = params["n_layers"] - n_heads = params["n_heads"] - n_heads_per_shard = n_heads // num_shards - dim = params["dim"] - dims_per_head = dim // n_heads - llama_version = 3 if params.get("vocab_size") == 128256 else 2 - - if "n_kv_heads" in params: - num_key_value_heads = params["n_kv_heads"] # for GQA / MQA - num_local_key_value_heads = num_key_value_heads // num_shards - key_value_dim = dims_per_head * num_key_value_heads - else: # compatibility with other checkpoints - num_key_value_heads = n_heads - num_local_key_value_heads = n_heads_per_shard - key_value_dim = dim - - model = LlamaForCausalLM.from_pretrained( - model_path, - torch_dtype=dtype, - low_cpu_mem_usage=True, - ) - loaded = model.state_dict() - - # permute for sliced rotary - def permute(w, n_heads=n_heads, dim1=dim, dim2=dim): - return ( - w.view(n_heads, 2, dim1 // n_heads // 2, dim2) - .transpose(1, 2) - .reshape(dim1, dim2) - ) - - state_dict = [{} for _ in range(num_shards)] - - def insert(name: str, tensor: Union[List, torch.Tensor]): - for i in range(num_shards): - state_dict[i][name] = ( - tensor[i].clone() if isinstance(tensor, list) else tensor - ) - - def insert_chunk(name: str, tensor: torch.Tensor, dim: int): - tensors = tensor.chunk(num_shards, dim=dim) - for i, tensor in enumerate(tensors): - state_dict[i][name] = tensor.clone() - - concat_dim = 0 if llama_version == 3 else 1 - insert_chunk( - "tok_embeddings.weight", loaded["model.embed_tokens.weight"], concat_dim - ) - insert("norm.weight", loaded["model.norm.weight"]) - insert_chunk("output.weight", loaded["lm_head.weight"], 0) - - for layer_i in tqdm(range(n_layers), desc="Converting layers"): - - ts = ( - permute(loaded[f"model.layers.{layer_i}.self_attn.q_proj.weight"]) - .view(n_heads_per_shard * num_shards, dims_per_head, dim) - .chunk(num_shards, dim=0) - ) - insert(f"layers.{layer_i}.attention.wq.weight", [t.view(-1, dim) for t in ts]) - - ts = ( - permute( - loaded[f"model.layers.{layer_i}.self_attn.k_proj.weight"], - num_key_value_heads, - key_value_dim, - dim, - ) - .view(num_local_key_value_heads * num_shards, dims_per_head, dim) - .chunk(num_shards, dim=0) - ) - insert(f"layers.{layer_i}.attention.wk.weight", [t.view(-1, dim) for t in ts]) - - ts = ( - loaded[f"model.layers.{layer_i}.self_attn.v_proj.weight"] - .view(num_local_key_value_heads * num_shards, dims_per_head, dim) - .chunk(num_shards, dim=0) - ) - insert(f"layers.{layer_i}.attention.wv.weight", [t.view(-1, dim) for t in ts]) - - insert_chunk( - f"layers.{layer_i}.attention.wo.weight", - loaded[f"model.layers.{layer_i}.self_attn.o_proj.weight"], - 1, - ) - - insert_chunk( - f"layers.{layer_i}.feed_forward.w1.weight", - loaded[f"model.layers.{layer_i}.mlp.gate_proj.weight"], - 0, - ) - - insert_chunk( - f"layers.{layer_i}.feed_forward.w2.weight", - loaded[f"model.layers.{layer_i}.mlp.down_proj.weight"], - 1, - ) - - insert_chunk( - f"layers.{layer_i}.feed_forward.w3.weight", - loaded[f"model.layers.{layer_i}.mlp.up_proj.weight"], - 0, - ) - - insert( - f"layers.{layer_i}.attention_norm.weight", - loaded[f"model.layers.{layer_i}.input_layernorm.weight"], - ) - insert( - f"layers.{layer_i}.ffn_norm.weight", - loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"], - ) - if llama_version != 3: - base = 10000.0 - inv_freq = ( - 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) - ).to(dtype) - insert("rope.freqs", inv_freq) - - for i in tqdm(range(num_shards), desc="Saving checkpoint shards"): - torch.save( - state_dict[i], os.path.join(output_base_path, f"consolidated.{i:02d}.pth") - ) - - -def main( - model_path: str, - model_size: str, - output_dir: str, -): - """Convert llama weights from huggingface format to consolidated format. - params: - model_path: model name or path to the model directory. - model_size: Llama model size, one of 7B, 13B, 34B, 30B, 65B, 70B. - output_dir: directory to save Llama weights, should contains params.json. - """ - assert model_size in NUM_SHARDS, f"Unknown model size {model_size}" - params_path = os.path.join(output_dir, "params.json") - assert os.path.isfile(params_path), f"{params_path} does not exist" - - write_model(model_path, model_size, output_dir) - - -if __name__ == "__main__": - fire.Fire(main) diff --git a/llama_recipes/trainer_attention.py b/llama_recipes/trainer_attention.py deleted file mode 100644 index 6c8146b..0000000 --- a/llama_recipes/trainer_attention.py +++ /dev/null @@ -1,659 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -""" -Training and evaluation functions for attention distillation -- Modified from https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/utils/train_utils.py -""" - -import os -import time -import json -from datetime import timedelta -from contextlib import nullcontext -from pathlib import Path -from datetime import datetime -from pkg_resources import packaging -import yaml - -import torch -import torch.nn as nn -import torch.cuda.nccl as nccl -import torch.distributed as dist -from torch.distributed.fsdp import StateDictType -from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler -from tqdm import tqdm -from transformers import LlamaTokenizer - -# Ours -from llama_recipes.model_checkpointing.distill_checkpoint_handler import ( - save_model_checkpoint, - save_model_and_optimizer_sharded, - save_optimizer_checkpoint -) -from llama_recipes.policies import fpSixteen,bfSixteen # get_llama_wrapper -from llama_recipes.policies import get_llama_wrapper, get_mistral_wrapper, get_mixtral_wrapper -from llama_recipes.utils.memory_utils import MemoryTrace -from accelerate.utils import is_xpu_available, is_ccl_available - - -def set_tokenizer_params(tokenizer: LlamaTokenizer): - """Set the tokenizer parameters for padding""" - tokenizer.pad_token_id = 0 - tokenizer.padding_side = "left" - - -def byte2mb(x: float) -> int: - """Converts bytes to megabytes""" - return int(x / 2**20) - - -class LossComputer(): - """ - Computes the loss for attention distillation - """ - def __init__(self, mse_factor: int = 1000, xent_factor: int = 1, - n_queries: int = None, n_keys: int = None, **kwargs: any) -> None: - super().__init__() - self.criterion_mse = nn.MSELoss(reduction='mean') - self.criterion_xent = nn.CrossEntropyLoss(reduction='mean') - self.mse_factor = mse_factor - self.xent_factor = xent_factor - self.n_queries = n_queries - self.n_keys = n_keys - - - def compute_loss(self, model: torch.nn.Module, inputs: torch.Tensor) -> torch.Tensor: - """Compute the loss for attention distillation""" - loss = 0 - loss_mse = 0 - loss_xent = 0 - n_layers = 0 # Number of layers to distill - outputs = model(**inputs, output_attentions=True, use_cache=False).get('attentions') - for _, attns in enumerate(outputs): - # ((a_pred, a_true), (y_pred, _y_true)) - - if attns is not None: - - # attention_pred, attention_true (probability distributions) - if self.xent_factor > 0: - # Cross-entropy loss - a_pred, a_true = attns[0] - if self.n_queries is not None: - a_pred = a_pred[:, :, -self.n_queries:, :] - a_true = a_true[:, :, -self.n_queries:, :] - a_pred = a_pred.clamp(min=1e-12).log() # nn.CrossEntropy assumes unnormalized logits - k_len = a_true.shape[-1] # batch, n_heads, q_len, k_len - # Compute mean cross-entropy over all queries - a_pred = a_pred.contiguous().view(-1, k_len) - a_true = a_true.contiguous().view(-1, k_len) - loss_xent += self.criterion_xent(a_pred[:, :self.n_keys], - a_true[:, :self.n_keys]) - # a_pred = a_pred.detach().cpu() - # a_true = a_true.detach().cpu() - # loss_xent += self.criterion_xent(a_pred.to(model.device), - # a_true.to(model.device)) - - # y_preds, y_true (raw values) - if self.mse_factor > 0: - # - loss_mse += self.criterion_mse(*attns[1]) - # attns[1][0] = attns[1][0].detach().cpu() - # attns[1][1] = attns[1][1].detach().cpu() - # loss_mse += self.criterion_mse(*[a.to(model.device) for a in attns[1]]) - n_layers += 1 - # torch.cuda.empty_cache() - - if n_layers > 0: - loss_xent = loss_xent * self.xent_factor / n_layers - loss_mse = loss_mse * self.mse_factor / n_layers - - if ( type(loss_xent) == float ): - loss = loss_mse - elif ( type(loss_mse) == float ): - loss = loss_xent - else: - loss = (loss_xent + loss_mse) - - try: - loss_xent = loss_xent.item() - except: - pass - try: - loss_mse = loss_mse.item() - except: - pass - loss_metrics = { - 'loss_xent': loss_xent, - 'loss_mse': loss_mse, - 'loss': loss.item(), - 'xent_factor': self.xent_factor, - 'mse_factor': self.mse_factor, - } - return loss, loss_metrics - - -def train(model, train_dataloader, eval_dataloader, tokenizer, - optimizer, lr_scheduler, gradient_accumulation_steps, - train_config, fsdp_config=None, local_rank=None, rank=None, - wandb_run=None, eval_mode=False) -> tuple[dict[torch.Tensor], str]: - """ - Trains the model on the given dataloader - Args: - model: The model to be trained - train_dataloader: The dataloader containing the training data - optimizer: The optimizer used for training - lr_scheduler: The learning rate scheduler - gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation - num_epochs: The number of epochs to train for - local_rank: The rank of the current node in a distributed setting - train_config: The training configuration - eval_dataloader: The dataloader containing the eval data - tokenizer: tokenizer used in the eval for decoding the predicitons - - Returns: - results dictionary containing average training and validation loss - best_checkpoint_path: The path to the best checkpoint - """ - loss_computer = LossComputer(**train_config.trainer) - - if rank == 0 or rank is None: - print('-> Gradient accumulation steps:', gradient_accumulation_steps) - print('-> Total # of training samples:', len(train_dataloader)) - total_length = len(train_dataloader)//gradient_accumulation_steps - print('-> Total # of training updates:', total_length) - # print('-> loss_computer:', loss_computer) - - # Create a gradient scaler for fp16 - if train_config.use_fp16 and train_config.enable_fsdp: - scaler = ShardedGradScaler() - elif train_config.use_fp16 and not train_config.enable_fsdp: - scaler = torch.cuda.amp.GradScaler() - if train_config.enable_fsdp: - world_size = int(os.environ["WORLD_SIZE"]) - # print('-> world_size:', world_size) - - autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext - - train_loss = [] - val_loss =[] - - if train_config.save_metrics: - _dt = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') - metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{_dt}.json".replace('//', '/') - train_step_loss = [] - val_step_loss = [] - # print(f'-> Saving metrics to {metrics_filename}') - - epoch_times = [] - checkpoint_times = [] - results = {} - best_val_loss = float("inf") - best_checkpoint_path = None - for epoch in range(train_config.num_epochs): - epoch_start_time = time.perf_counter() - with MemoryTrace() as memtrace: # track the memory usage - if eval_mode: - model.eval() - print(f'-> Model is eval mode on rank {rank}') - else: - model.train() - total_loss = 0.0 - total_length = len(train_dataloader)//gradient_accumulation_steps - pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True) - for step, batch in enumerate(train_dataloader): - model.train() - # print('-> step:', step) - for key in batch.keys(): - if key == 'labels': - batch[key] = None # don't use labels for attention distillation - else: - if train_config.enable_fsdp: - if is_xpu_available(): - batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}")) - else: - batch[key] = batch[key].to(local_rank) - else: - if is_xpu_available(): - batch[key] = batch[key].to('xpu:0') - else: - batch[key] = batch[key].to('cuda:0') - with autocast(): - loss, loss_metrics = loss_computer.compute_loss(model, batch) - loss = loss / gradient_accumulation_steps - if train_config.save_metrics: - train_step_loss.append(loss.detach().cpu().float().item()) - - total_loss += loss.detach().float() - if train_config.use_fp16: - # if fp16 is enabled, use gradient scaler to handle gradient update - scaler.scale(loss).backward() - if ((step + 1) % gradient_accumulation_steps == 0 - or step == len(train_dataloader) - 1): - if (train_config.gradient_clipping - and train_config.gradient_clipping_threshold > 0.0): - scaler.unscale_(optimizer) - if train_config.enable_fsdp: - model.clip_grad_norm_(train_config.gradient_clipping_threshold) - else: - torch.nn.utils.clip_grad_norm_( - model.parameters(), train_config.gradient_clipping_threshold) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - pbar.update(1) - else: - # regular backpropagation when fp16 is not used - loss.backward() - if ((step + 1) % gradient_accumulation_steps == 0 - or step == len(train_dataloader) - 1): - if (train_config.gradient_clipping - and train_config.gradient_clipping_threshold > 0.0): - if train_config.enable_fsdp: - model.clip_grad_norm_(train_config.gradient_clipping_threshold) - else: - torch.nn.utils.clip_grad_norm_( - model.parameters(), train_config.gradient_clipping_threshold) - optimizer.step() - optimizer.zero_grad() - pbar.update(1) - if wandb_run: - if not train_config.enable_fsdp or rank==0: - wandb_run.log({ - 'train/epoch': epoch + 1, - 'train/step': epoch * len(train_dataloader) + step, - 'train/loss': loss.detach().cpu().item(), - }) - - desc = f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.item():.5f}, lr: {optimizer.param_groups[0]['lr']:.5f})" - for k, v in loss_metrics.items(): - desc += f" | {k}: {v:.5f}" - pbar.set_description(desc) - - if train_config.save_metrics: - save_to_json(metrics_filename, train_step_loss, train_loss, val_step_loss, val_loss) - - if step == getattr(train_config, 'num_train_steps', -1): - break # Early exit for debugging later logic - - if (train_config.run_validation and ( - (step + 1) % (train_config.eval_steps * gradient_accumulation_steps) == 0)): # or step == len(train_dataloader) - 1)): - eval_outputs = eval_loop(model, evaluate_attn, optimizer, lr_scheduler, train_config, fsdp_config, rank, - eval_dataloader, local_rank, tokenizer, wandb_run, - val_step_loss, val_loss, best_val_loss, checkpoint_times, epoch, step) - save_path, val_step_loss, val_loss, best_val_loss, checkpoint_times = eval_outputs - if save_path is not None: - best_checkpoint_path = save_path - if not eval_mode: - model.train() - print(f'-> Model is training on rank {rank}') - del loss; del batch; del eval_outputs - clear_gpu_cache() - pbar.close() - - epoch_end_time = time.perf_counter()-epoch_start_time - epoch_times.append(epoch_end_time) - # Reducing total_loss across all devices if there's more than one CUDA device - if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp): - dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) - elif torch.cuda.device_count() > 1 and train_config.enable_fsdp: - dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) - train_epoch_loss = total_loss / len(train_dataloader) - if train_config.enable_fsdp: - train_epoch_loss = train_epoch_loss/world_size - train_loss.append(float(train_epoch_loss)) - - if not train_config.enable_fsdp or rank==0: - memtrace.print_stats() - - # Update the learning rate as needed - # lr_scheduler.step() - eval_outputs = eval_loop(model, evaluate_attn, optimizer, lr_scheduler, train_config, fsdp_config, rank, - eval_dataloader, local_rank, tokenizer, wandb_run, - val_step_loss, val_loss, best_val_loss, checkpoint_times, - epoch, step) - save_path, val_step_loss, val_loss, best_val_loss, checkpoint_times = eval_outputs - if save_path is not None: - best_checkpoint_path = save_path - - if train_config.enable_fsdp: - if rank==0: - print(f"Epoch {epoch+1}: train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s") - else: - print(f"Epoch {epoch+1}: train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s") - - # Saving the results every epoch to plot later - if train_config.save_metrics: - save_to_json(metrics_filename, train_step_loss, train_loss, val_step_loss, val_loss) - - avg_epoch_time = sum(epoch_times)/ len(epoch_times) - avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0 - avg_train_loss = sum(train_loss)/len(train_loss) - if train_config.run_validation: - avg_eval_loss = sum(val_loss)/len(val_loss) - - results['avg_train_loss'] = avg_train_loss - if train_config.run_validation: - results['avg_eval_loss'] = avg_eval_loss - results["avg_epoch_time"] = avg_epoch_time - results["avg_checkpoint_time"] = avg_checkpoint_time - if train_config.save_metrics: - results["metrics_filename"] = metrics_filename - - #saving the training params including fsdp setting for reference. - if train_config.enable_fsdp and not train_config.use_peft: - save_train_params(train_config, fsdp_config, rank) - - return results, best_checkpoint_path - - -def eval_loop(model, evaluate_func, optimizer, lr_scheduler, - train_config, fsdp_config, rank, eval_dataloader, - local_rank, tokenizer, wandb_run, - val_step_loss, val_loss, best_val_loss, - checkpoint_times, epoch, step): # extra globals - """ - Evaluate model and save checkpoints - - see `evaluate_func` for evaluation logic - """ - eval_epoch_loss, temp_val_loss = evaluate_func( - model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run, epoch=epoch) - try: - lr_scheduler.step(eval_epoch_loss) - except: - lr_scheduler.step() - - if train_config.save_metrics: - val_step_loss.extend(temp_val_loss) - - checkpoint_start_time = time.perf_counter() - - # train_config.save_model = False - if train_config.save_model and eval_epoch_loss < best_val_loss: - if train_config.enable_fsdp: - dist.barrier() - else: - dist.barrier() - - if fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT: - save_path = save_model_checkpoint( - model, optimizer, rank, train_config, epoch=epoch - ) - elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT: - if rank == 0: - print("Saving the FSDP model checkpoints using SHARDED_STATE_DICT") - print("==========================================================") - - save_path = save_model_and_optimizer_sharded(model, rank, train_config) - if train_config.save_optimizer: - save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer) - if rank == 0: - print("Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT") - print("========================================================================") - - if train_config.save_optimizer: - save_optimizer_checkpoint( - model, optimizer, rank, train_config, epoch=epoch - ) - if rank == 0: - print("Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT") - print("=====================================================================") - if train_config.enable_fsdp: - dist.barrier() - else: - dist.barrier() - else: - save_path = None - checkpoint_end_time = time.perf_counter() - checkpoint_start_time - checkpoint_times.append(checkpoint_end_time) - if eval_epoch_loss < best_val_loss: - best_val_loss = eval_epoch_loss - if rank == 0 or not train_config.enable_fsdp: - print(f"best eval loss on epoch {epoch+1}, step {step + 1} is {best_val_loss}") - - val_loss.append(float(best_val_loss)) - return save_path, val_step_loss, val_loss, best_val_loss, checkpoint_times - - -def evaluate_attn(model, train_config, eval_dataloader, - local_rank, tokenizer, wandb_run, - epoch: int = None, rank: int = 0): - """ - Evaluates the model on the given dataloader - Args: - model: The model to evaluate - eval_dataloader: The dataloader containing the evaluation data - local_rank: The rank of the current node in a distributed setting - tokenizer: The tokenizer used to decode predictions - - Returns: eval_epoch_loss - """ - loss_computer = LossComputer(**train_config.trainer) - - if train_config.enable_fsdp: - world_size = int(os.environ["WORLD_SIZE"]) - model.eval() - val_step_loss = [] - - eval_loss = 0.0 # Initialize evaluation loss - _epoch = f' {epoch}' if epoch is not None else '' - pbar = tqdm(eval_dataloader,colour="green", desc=f"Rank {rank} | Eval Epoch{_epoch}", dynamic_ncols=True) - for step, batch in enumerate(pbar): - for key in batch.keys(): - if train_config.enable_fsdp: - batch[key] = batch[key].to(local_rank) - else: - if is_xpu_available(): - batch[key] = batch[key].to('xpu:0') - else: - batch[key] = batch[key].to('cuda:0') - # Ensure no gradients are computed for this scope to save memory - with torch.no_grad(): - # Forward pass and compute loss - loss, loss_metrics = loss_computer.compute_loss(model, batch) - if train_config.save_metrics: - val_step_loss.append(loss.detach().float().item()) - - eval_loss += loss.detach().float() - - desc = f"Rank {rank} | Eval Epoch{_epoch} | step_loss: {loss.item():.5f} | avg_loss: {eval_loss.item()/(step+1):.5f}" - for k, v in loss_metrics.items(): - desc += f" | {k}: {v:.5f}" - pbar.set_description(desc) - - # If there's more than one CUDA device, reduce evaluation loss across all devices - if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp): - dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM) - if torch.cuda.device_count() > 1 and train_config.enable_fsdp: - dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM) - - # Compute average loss - # print('len(eval_dataloader):', len(eval_dataloader)) - # print('step + 1:', step + 1) - # print('world_size:', world_size) - eval_epoch_loss = eval_loss / len(eval_dataloader) - if train_config.enable_fsdp: - eval_epoch_loss = eval_epoch_loss/world_size - - eval_epoch_loss = eval_epoch_loss.cpu().float().item() - - # Print evaluation metrics - if local_rank == 0 or not train_config.enable_fsdp: - print(f" {eval_epoch_loss=}") - - if wandb_run: - wandb_run.log({'eval/loss': eval_epoch_loss,}, commit=False) - - del loss; del eval_loss; del batch - clear_gpu_cache() - - return eval_epoch_loss, val_step_loss - - -def freeze_transformer_layers(model: nn.Module, num_layer: int): - """Freeze model layers up to num_layer""" - for i, layer in enumerate(model.model.layers): - if i < num_layer: - for param in layer.parameters(): - param.requires_grad = False - - -def check_frozen_layers_peft_model(model) -> None: - """ - Print layer.requires_grad for each layer in the model - """ - for i, layer in enumerate(model.base_model.model.model.layers): - for name, param in layer.named_parameters(): - print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}") - - -def setup(): - """Initialize the process group for distributed training""" - if is_ccl_available(): - # distributed training on xpus - dist.init_process_group("ccl") - else: - dist.init_process_group("nccl", timeout=timedelta(seconds=3600)) - - -def setup_environ_flags(rank): - """Set environment flags for debugging purposes""" - os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1) - os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1) - # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" - # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases. - # Note this is only availble in PyTorch Nightlies (as of July 30 2023) - # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True' - if rank == 0: - print("--> Running with torch dist debug set to detail") - - -def cleanup(): - """Clean up the process group after training""" - dist.destroy_process_group() - - -def clear_gpu_cache(rank=None): - """Clear the GPU cache for all ranks""" - if rank == 0: - print("Clearing GPU cache for all ranks") - if is_xpu_available(): - torch.xpu_empty_cache() - else: - torch.cuda.empty_cache() - - -def get_parameter_dtypes(model): - """Get the data types of model parameters""" - parameter_dtypes = {} - for name, parameter in model.named_parameters(): - parameter_dtypes[name] = parameter.dtype - return parameter_dtypes - - -def print_model_size(model, config, rank: int = 0) -> None: - """ - Print model name, the number of trainable parameters and initialization time. - - Args: - model: The PyTorch model. - model_name (str): Name of the model. - init_time_start (float): Initialization start time. - init_time_end (float): Initialization end time. - rank (int, optional): Current process's rank. Defaults to 0. - """ - if rank == 0: - print(f"--> Model {config.model_name}") - total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n") - - -def get_policies(cfg, rank, model: str = 'llama'): - """Get the policies for mixed precision and fsdp wrapping""" - verify_bfloat_support = (( - torch.version.cuda - and torch.cuda.is_bf16_supported() - and packaging.version.parse(torch.version.cuda).release >= (11, 0) - and dist.is_nccl_available() - and nccl.version() >= (2, 10) - ) or - (is_xpu_available())) - - mixed_precision_policy = None - wrapping_policy = None - - # Mixed precision - if cfg.mixed_precision: - bf16_ready = verify_bfloat_support - - if bf16_ready and not cfg.use_fp16: - mixed_precision_policy = bfSixteen - if rank == 0: - print("bfloat16 enabled for mixed precision - using bfSixteen policy") - elif cfg.use_fp16: - mixed_precision_policy = fpSixteen - if rank == 0: - print("FP16 enabled") - else: - print("bfloat16 support not present. Using FP32, and not mixed precision") - - if model == 'llama': - wrapping_policy = get_llama_wrapper() - elif model == 'mistral': - wrapping_policy = get_mistral_wrapper() - elif model == 'mixtral': - wrapping_policy = get_mixtral_wrapper() - return mixed_precision_policy, wrapping_policy - - -def save_train_params(train_config, fsdp_config, rank): - """ - This function saves the train_config and FSDP config into a train_params.yaml. - This will be used by converter script in the inference folder to fetch the HF model name or path. - It also would be hepful as a log for future references. - """ - # Convert the train_config and fsdp_config objects to dictionaries, - # converting all values to strings to ensure they can be serialized into a YAML file - train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')} - fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')} - # Merge the two dictionaries into one - train_params_dict = {**train_config_dict, **fsdp_config_dict} - # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object - folder_name = ( - train_config.dist_checkpoint_root_folder - + "/" - + train_config.dist_checkpoint_folder - + "-" - + train_config.model_name - ) - - save_dir = Path.cwd() / folder_name - # If the directory does not exist, create it - if not os.path.exists(save_dir): - os.makedirs(save_dir) - # Convert the dictionary to a YAML string - config_yaml = yaml.dump(train_params_dict, indent=4) - file_name = os.path.join(save_dir,'train_params.yaml') - - # Check if there's a directory with the same name as the file - if os.path.isdir(file_name): - print(f"Error: {file_name} is a directory, not a file.") - else: - # Write the YAML string to the file - with open(file_name, 'w') as f: - f.write(config_yaml) - if rank==0: - print(f"training params are saved in {file_name}") - - -def save_to_json(output_filename, - train_step_loss, train_epoch_loss, - val_step_loss, val_epoch_loss): - """Save loss data to JSON file""" - metrics_data = { - "train_step_loss": train_step_loss, - "train_epoch_loss": train_epoch_loss, - "val_step_loss": val_step_loss, - "val_epoch_loss": val_epoch_loss, - } - with open(output_filename, "w") as f: - json.dump(metrics_data, f) \ No newline at end of file diff --git a/llama_recipes/trainer_eval_mmlu.py b/llama_recipes/trainer_eval_mmlu.py deleted file mode 100644 index 24f0162..0000000 --- a/llama_recipes/trainer_eval_mmlu.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -""" -Training and evaluation functions for evaluating MMLU following LM Evaluation Harness Implementation -- Modified from https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/utils/train_utils.py - -""" -import os -import time -from contextlib import nullcontext -from datetime import datetime - -import torch -# import torch.cuda.nccl as nccl -import torch.distributed as dist -# from torch.distributed.fsdp import StateDictType -from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler -from tqdm import tqdm - -# Ours -import numpy as np -from accelerate.utils import is_xpu_available -from llama_recipes.utils.memory_utils import MemoryTrace -from llama_recipes.trainer_attention import ( - eval_loop, clear_gpu_cache, save_to_json, - setup, setup_environ_flags, # imports into distill_llama_finetune.py - print_model_size, get_policies -) - - -class MMLUComputer(): - """ - Computes the loss for next-token prediction - """ - def __init__(self, **kwargs: any): - super().__init__() - # self.categories = {} # {'mmlu-category': {'correct': int, 'total': int, 'acc': float}} - self.criterion = torch.nn.CrossEntropyLoss(reduction='mean') - - def compute_loss(self, model: torch.nn.Module, data: torch.Tensor, - rank: int = 0, local_rank: int = 0): - """Compute MMLU loss""" - - input_keys = {'input_ids'} # , 'attention_mask'} (assume packing / no padding) - inputs = {k: v.to(model.device) for k, v in data.items() if k in input_keys} - outputs = model(**inputs, output_attentions=False, use_cache=False) # use_cache=False) - - outputs = outputs.get('logits')[..., -2, :].contiguous() # b, d - targets = data.get('input_ids')[..., -1].contiguous().to(outputs.device) # b, d - # Compute cross-entropy loss - losses = [] - for choice_idx in range(outputs.shape[0]): - losses.append(self.criterion(outputs[choice_idx], targets[choice_idx])) - losses = torch.stack(losses).cpu() # b, 1 - pred = torch.argmin(losses, dim=0) - # print(f'{losses.shape=}') - # print(f"{data['target'].shape=}") - # print(f"{pred.shape=}") - # print(f"{data['target']=}") - # print(f"{pred=}") - # print(f"{data['category'].shape=}") - # print(f"{data['category']=}") - correct = data['target'][0].cpu() == pred - return losses, correct, data['category'][0] # same for all of them - - -def evaluate_mmlu(model, train_config, eval_dataloader, - local_rank, tokenizer, wandb_run, - epoch: int = None, rank: int = 0): - """ - Evaluates the model on the given dataloader - - Args: - model: The model to evaluate - eval_dataloader: The dataloader containing the evaluation data - local_rank: The rank of the current node in a distributed setting - tokenizer: The tokenizer used to decode predictions - - Returns: eval_epoch_loss - """ - loss_computer = MMLUComputer(**train_config.trainer) - if train_config.enable_fsdp: - world_size = int(os.environ["WORLD_SIZE"]) - model.eval() - - metrics = { - 'all': {'correct': 0, 'total': 0} - # categories: {'total': 0, 'correct': 0} - } - - _epoch = f' {epoch}' if epoch is not None else '' - pbar = tqdm(eval_dataloader, colour="green", desc=f"Rank {rank} | Eval Epoch{_epoch}", dynamic_ncols=True) - for step, batch in enumerate(pbar): - # print(f'batch.keys()') - # for key in batch.keys(): - # print(f'{key}: {batch[key]}') - for key in ['input_ids', 'attention_mask']: # batch.keys(): - if train_config.enable_fsdp: - batch[key] = batch[key].to(local_rank) - else: - if is_xpu_available(): - batch[key] = batch[key].to('xpu:0') - else: - batch[key] = batch[key].to('cuda:0') - - # Ensure no gradients are computed for this scope to save memory - with torch.no_grad(): - # Forward pass and compute loss - losses, _correct, _category_idx = loss_computer.compute_loss(model, batch, rank=rank, - local_rank=local_rank) - metrics['all']['total'] += 1 - metrics['all']['correct'] += _correct.int().item() - - _category = eval_dataloader.categories[_category_idx] - - if _category in metrics: - metrics[_category]['total'] += 1 - metrics[_category]['correct'] += _correct.int().item() - else: - metrics[_category] = {'total': 1, 'correct': _correct.int().item()} - - total = metrics['all']['total'] - correct = metrics['all']['correct'] - total_acc = correct / total * 100 - - pbar.set_description(f"Rank {rank} | Eval Epoch{_epoch} | total acc: {total_acc:.3f}% ({correct} / {total})") - - - - # # If there's more than one CUDA device, reduce evaluation loss across all devices - # if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp): - # for category in metrics: - # for k, v in metrics[category].items(): - # dist.all_reduce(torch.Tensor([v]), op=dist.ReduceOp.SUM) - # elif torch.cuda.device_count() > 1 and train_config.enable_fsdp: # what's the diff b/t this condition and above? - # for category in metrics: - # for k, v in metrics[category].items(): - # print(f'{k} before all_reduce:', v) - # dist.all_reduce(torch.Tensor([v]), op=dist.ReduceOp.SUM) - # print(f'{k} after all_reduce:', v) - - for k in metrics: - metrics[k]['acc'] = metrics[k]['correct'] / metrics[k]['total'] - - del batch - clear_gpu_cache() - if wandb_run: - wandb_run.log(metrics) - return metrics diff --git a/llama_recipes/trainer_finetune.py b/llama_recipes/trainer_finetune.py deleted file mode 100644 index 08837ef..0000000 --- a/llama_recipes/trainer_finetune.py +++ /dev/null @@ -1,337 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -""" -Training and evaluation functions for finetuning after attention transfer -- Modified from https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/utils/train_utils.py - -We do -""" -import os -import time -from contextlib import nullcontext -from datetime import datetime - -import torch -# import torch.cuda.nccl as nccl -import torch.distributed as dist -# from torch.distributed.fsdp import StateDictType -from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler -from tqdm import tqdm - -# Ours -import numpy as np -from accelerate.utils import is_xpu_available -from llama_recipes.utils.memory_utils import MemoryTrace -from llama_recipes.trainer_attention import ( - eval_loop, clear_gpu_cache, save_to_json, - setup, setup_environ_flags, # imports into distill_llama_finetune.py - print_model_size, get_policies -) - - -class LossComputer(): - """ - Computes the loss for next-token prediction - """ - def __init__(self, **kwargs: any): - super().__init__() - self.criterion = torch.nn.CrossEntropyLoss(reduction='mean') - - def compute_loss(self, model: torch.nn.Module, data: torch.Tensor, - rank: int = 0, local_rank: int = 0): - """Compute the loss for attention distillation""" - input_keys = {'input_ids'} # , 'attention_mask'} (assume packing / no padding) - inputs = {k: v.to(model.device) for k, v in data.items() if k in input_keys} - outputs = model(**inputs, output_attentions=False, use_cache=False) # use_cache=False) - outputs = outputs.get('logits')[..., :-1, :].contiguous() - targets = data.get('labels')[..., 1:].contiguous() - # Flatten and compute cross-entropy loss - outputs = outputs.view(-1, outputs.shape[-1]) - targets = targets.view(-1).to(outputs.device) - if (targets != -100).sum() == 0: - return torch.Tensor([0])[0] - else: - loss = self.criterion(outputs, targets) - targets = targets.cpu() - outputs = outputs.cpu() - # print(f'rank: {rank} | local_rank: {local_rank} | loss: {loss.item():.5f} | shape: {targets.shape} |') - return loss # , {'ppl': np.exp(loss.item()), 'seq_len': targets.shape[-1] + 1} - - -def train(model, train_dataloader, eval_dataloader, tokenizer, - optimizer, lr_scheduler, gradient_accumulation_steps, - train_config, fsdp_config=None, local_rank=None, rank=None, - wandb_run=None, stepwise_scheduler=False) -> dict[torch.Tensor]: - """ - Trains the model on the given dataloader - - Args: - model: The model to be trained - train_dataloader: The dataloader containing the training data - optimizer: The optimizer used for training - lr_scheduler: The learning rate scheduler - gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation - num_epochs: The number of epochs to train for - local_rank: The rank of the current node in a distributed setting - train_config: The training configuration - eval_dataloader: The dataloader containing the eval data - tokenizer: tokenizer used in the eval for decoding the predicitons - - Returns: - results dictionary containing average training and validation loss - best_checkpoint_path: The path to the best checkpoint - """ - loss_computer = LossComputer(**train_config.trainer) - - if rank == 0 or rank is None: - print('-> Gradient accumulation steps:', gradient_accumulation_steps) - print('-> Total # of training samples:', len(train_dataloader)) - total_length = len(train_dataloader)//gradient_accumulation_steps - print('-> Total # of training updates:', total_length) - - # print('-> loss_computer:', loss_computer) - - # Create a gradient scaler for fp16 - if train_config.use_fp16 and train_config.enable_fsdp: - scaler = ShardedGradScaler() - elif train_config.use_fp16 and not train_config.enable_fsdp: - scaler = torch.cuda.amp.GradScaler() - if train_config.enable_fsdp: - world_size = int(os.environ["WORLD_SIZE"]) - # print('-> world_size:', world_size) - - autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext - train_loss = [] - val_loss =[] - - if train_config.save_metrics: - metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json" - train_step_perplexity = [] - train_step_loss = [] - val_step_loss = [] - val_step_perplexity = [] - # print(f'-> Saving metrics to {metrics_filename}') - - epoch_times = [] - checkpoint_times = [] - results = {} - best_val_loss = float("inf") - best_checkpoint_path = None - total_step = 0 - - for epoch in range(train_config.num_epochs): - epoch_start_time = time.perf_counter() - # print('-> epoch:', epoch) - # if True: - with MemoryTrace() as memtrace: # track the memory usage - model.train() - print(f'-> Model is training on rank {rank}') - total_loss = 0.0 - total_length = len(train_dataloader)//gradient_accumulation_steps - pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True) - for step, batch in enumerate(train_dataloader): - model.train() - # print('-> step:', step) - for key in batch.keys(): - if train_config.enable_fsdp: - if is_xpu_available(): - batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}")) - else: - batch[key] = batch[key].to(local_rank) - else: - if is_xpu_available(): - batch[key] = batch[key].to('xpu:0') - else: - batch[key] = batch[key].to('cuda:0') - with autocast(): - loss = loss_computer.compute_loss(model, batch, rank, local_rank) - - train_step_loss.append(loss.item()) - train_step_perplexity.append(float(np.exp(loss.item()))) - loss = loss / gradient_accumulation_steps - total_loss += loss.detach().float() - - if train_config.use_fp16: - # if fp16 is enabled, use gradient scaler to handle gradient update - scaler.scale(loss).backward() - if (total_step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: - if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0: - scaler.unscale_(optimizer) - if train_config.enable_fsdp: - model.clip_grad_norm_(train_config.gradient_clipping_threshold) - else: - torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - pbar.update(1) - if stepwise_scheduler: - lr_scheduler.step() - else: - # regular backpropagation when fp16 is not used - # if loss.sum() > 0: # hack for non-answer targets - loss.backward() - if (total_step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: - if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0: - if train_config.enable_fsdp: - model.clip_grad_norm_(train_config.gradient_clipping_threshold) - else: - torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold) - optimizer.step() - optimizer.zero_grad() - pbar.update(1) - if stepwise_scheduler: - lr_scheduler.step() - - if wandb_run: - if not train_config.enable_fsdp or rank==0: - wandb_run.log({ - 'train/epoch': epoch + 1, - 'train/step': total_step, # epoch * len(train_dataloader) + step, - 'train/loss': train_step_loss[-1], - 'train/ppl': train_step_perplexity[-1], - 'train/lr': optimizer.param_groups[-1]['lr'] - }) - - metrics = f"loss: {train_step_loss[-1]:.5f} | lr: {optimizer.param_groups[0]['lr']:.5f} | ppl: {train_step_perplexity[-1]}" - # pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float():.5f}, lr: {optimizer.param_groups[0]['lr']:.5f})") - pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed ({metrics})") - - if train_config.save_metrics: - save_to_json(metrics_filename, train_step_loss, train_loss, val_step_loss, val_loss) - - if total_step == getattr(train_config, 'num_train_steps', -1): - break # Early exit for debugging later logic - - if (train_config.run_validation and ( - (total_step + 1) % (train_config.eval_steps * gradient_accumulation_steps) == 0)): # or step == len(train_dataloader) - 1)): - dist.barrier() - eval_outputs = eval_loop(model, evaluate_lm, optimizer, lr_scheduler, - train_config, fsdp_config, rank, eval_dataloader, - local_rank, tokenizer, wandb_run, - val_step_loss, val_loss, best_val_loss, - checkpoint_times, epoch, total_step) - dist.barrier() - save_path, val_step_loss, val_loss, best_val_loss, checkpoint_times = eval_outputs - if save_path is not None: - best_checkpoint_path = save_path - model.train() - print(f'-> Model is training on rank {rank}') - total_step += 1 - pbar.close() - - epoch_end_time = time.perf_counter()-epoch_start_time - epoch_times.append(epoch_end_time) - - # Reducing total_loss across all devices if there's more than one CUDA device - if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp): - dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) - elif torch.cuda.device_count() > 1 and train_config.enable_fsdp: - dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) - train_epoch_loss = total_loss / len(train_dataloader) * gradient_accumulation_steps - train_epoch_loss = total_loss / len(train_dataloader) * gradient_accumulation_steps - if train_config.enable_fsdp: - train_epoch_loss = train_epoch_loss/world_size - train_perplexity = torch.exp(train_epoch_loss) - train_loss.append(float(train_epoch_loss)) - - if not train_config.enable_fsdp or rank==0: - memtrace.print_stats() - - # Update the learning rate as needed - # lr_scheduler.step() - dist.barrier() - eval_outputs = eval_loop(model, evaluate_lm, optimizer, lr_scheduler, - train_config, fsdp_config, rank, eval_dataloader, - local_rank, tokenizer, wandb_run, - val_step_loss, val_loss, best_val_loss, - checkpoint_times, epoch, total_step) - dist.barrier() - save_path, val_step_loss, val_loss, best_val_loss, checkpoint_times = eval_outputs - if save_path is not None: - best_checkpoint_path = save_path - - if rank == 0 or not train_config.enable_fsdp: - print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s") - - results = {'best_val_loss': best_val_loss, - 'checkpoint_times': sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0} - return results, best_checkpoint_path - - -def evaluate_lm(model, train_config, eval_dataloader, - local_rank, tokenizer, wandb_run, - epoch: int = None, rank: int = 0): - """ - Evaluates the model on the given dataloader - - Args: - model: The model to evaluate - eval_dataloader: The dataloader containing the evaluation data - local_rank: The rank of the current node in a distributed setting - tokenizer: The tokenizer used to decode predictions - - Returns: eval_epoch_loss - """ - loss_computer = LossComputer(**train_config.trainer) - if train_config.enable_fsdp: - world_size = int(os.environ["WORLD_SIZE"]) - model.eval() - val_step_loss = [] - - eval_loss = 0.0 # Initialize evaluation loss - _epoch = f' {epoch}' if epoch is not None else '' - pbar = tqdm(eval_dataloader,colour="green", desc=f"Rank {rank} | Eval Epoch{_epoch}", dynamic_ncols=True) - for step, batch in enumerate(pbar): - for key in batch.keys(): - if train_config.enable_fsdp: - batch[key] = batch[key].to(local_rank) - else: - if is_xpu_available(): - batch[key] = batch[key].to('xpu:0') - else: - batch[key] = batch[key].to('cuda:0') - # Ensure no gradients are computed for this scope to save memory - with torch.no_grad(): - # Forward pass and compute loss - loss = loss_computer.compute_loss(model, batch, rank=rank, local_rank=local_rank) - if train_config.save_metrics: - val_step_loss.append(loss.detach().cpu().float().item()) - - # Check NaNs in loss - if torch.isnan(loss).any(): - print("NaN detected in eval loss. Skipping evaluation accumulation.") - else: - eval_loss += loss.detach().float() - _ppl = torch.exp(eval_loss/(step+1)).item() - pbar.set_description(f"Rank {rank} | Eval Epoch{_epoch} | step_loss: {loss.item():.5f} | avg_loss: {eval_loss.item()/(step+1):.5f} | avg_ppl: {_ppl:.5f}") - if step > 20: # hack - break - - # If there's more than one CUDA device, reduce evaluation loss across all devices - if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp): - dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM) - elif torch.cuda.device_count() > 1 and train_config.enable_fsdp: # what's the diff b/t this condition and above? - dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM) - - # Compute average loss - # print('len(eval_dataloader):', len(eval_dataloader)) - # print('step + 1:', step + 1) - # print('world_size:', world_size) - eval_epoch_loss = eval_loss / 20 # len(eval_dataloader) - if train_config.enable_fsdp: - eval_epoch_loss = eval_epoch_loss/world_size - - eval_epoch_ppl = torch.exp(eval_epoch_loss).item() - eval_epoch_loss = eval_epoch_loss.item() - del eval_loss; del batch - clear_gpu_cache() - - # Print evaluation metrics - if local_rank == 0 or not train_config.enable_fsdp: - print(f" eval_epoch_loss={eval_epoch_loss}, eval_epoch_ppl={eval_epoch_ppl}") - - if wandb_run: - wandb_run.log({'eval/loss': eval_epoch_loss, 'eval/ppl': eval_epoch_ppl}, commit=False) - - return eval_epoch_loss, val_step_loss diff --git a/llama_recipes/utils/__init__.py b/llama_recipes/utils/__init__.py deleted file mode 100644 index d927896..0000000 --- a/llama_recipes/utils/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -from llama_recipes.utils.memory_utils import MemoryTrace -# from llama_recipes.utils.dataset_utils import * # MZ 8/20; don't need these -from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy, hsdp_device_mesh -from llama_recipes.utils.train_utils import * \ No newline at end of file diff --git a/llama_recipes/utils/config_utils.py b/llama_recipes/utils/config_utils.py deleted file mode 100644 index a4c866d..0000000 --- a/llama_recipes/utils/config_utils.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -import inspect -from dataclasses import asdict - -import torch.distributed as dist -from torch.utils.data import DistributedSampler -from peft import ( - LoraConfig, - AdaptionPromptConfig, - PrefixTuningConfig, -) -from transformers import default_data_collator -from transformers.data import DataCollatorForSeq2Seq - -from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config -from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler -# from llama_recipes.utils.dataset_utils import DATASET_PREPROC -DATASET_PREPROC = {} # MZ 8/20/24 - - -def update_config(config, **kwargs): - if isinstance(config, (tuple, list)): - for c in config: - update_config(c, **kwargs) - else: - for k, v in kwargs.items(): - if hasattr(config, k): - setattr(config, k, v) - elif "." in k: - # allow --some_config.some_param=True - config_name, param_name = k.split(".") - if type(config).__name__ == config_name: - if hasattr(config, param_name): - setattr(config, param_name, v) - else: - # In case of specialized config we can warn user - print(f"Warning: {config_name} does not accept parameter: {k}") - elif isinstance(config, train_config): - print(f"Warning: unknown parameter {k}") - - -def generate_peft_config(train_config, kwargs): - configs = (lora_config, llama_adapter_config, prefix_config) - peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig) - names = tuple(c.__name__.rstrip("_config") for c in configs) - - if train_config.peft_method not in names: - raise RuntimeError(f"Peft config not found: {train_config.peft_method}") - - if train_config.peft_method == "prefix": - raise RuntimeError("PrefixTuning is currently not supported (see https://github.com/meta-llama/llama-recipes/issues/359#issuecomment-2089350811)") - - if train_config.enable_fsdp and train_config.peft_method == "llama_adapter": - raise RuntimeError("Llama_adapter is currently not supported in combination with FSDP (see https://github.com/meta-llama/llama-recipes/issues/359#issuecomment-2089274425)") - - config = configs[names.index(train_config.peft_method)]() - - update_config(config, **kwargs) - params = asdict(config) - peft_config = peft_configs[names.index(train_config.peft_method)](**params) - - return peft_config - - -def generate_dataset_config(train_config, kwargs): - names = tuple(DATASET_PREPROC.keys()) - - assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}" - - dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]() - - update_config(dataset_config, **kwargs) - - return dataset_config - - -def get_dataloader_kwargs(train_config, dataset, tokenizer, mode): - kwargs = {} - batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size - if train_config.batching_strategy == "padding": - if train_config.enable_fsdp: - kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler( - dataset, - batch_size=batch_size, - rank=dist.get_rank(), - num_replicas=dist.get_world_size(), - shuffle=mode=="train", - ) - else: - kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train") - kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer) - elif train_config.batching_strategy == "packing": - if train_config.enable_fsdp: - kwargs["sampler"] = DistributedSampler( - dataset, - rank=dist.get_rank(), - num_replicas=dist.get_world_size(), - shuffle=mode=="train", - drop_last=True, - ) - kwargs["batch_size"] = batch_size - kwargs["drop_last"] = True - kwargs["collate_fn"] = default_data_collator - else: - raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}") - - return kwargs diff --git a/llama_recipes/utils/dataset_utils.py b/llama_recipes/utils/dataset_utils.py deleted file mode 100644 index 009d994..0000000 --- a/llama_recipes/utils/dataset_utils.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -import importlib -from functools import partial -from pathlib import Path - -import torch - -from llama_recipes.datasets_llama_recipes import ( - get_grammar_dataset, - get_alpaca_dataset, - get_samsum_dataset, - get_llamaguard_toxicchat_dataset, -) - - -def load_module_from_py_file(py_file: str) -> object: - """ - This method loads a module from a py file which is not in the Python path - """ - module_name = Path(py_file).name - loader = importlib.machinery.SourceFileLoader(module_name, py_file) - spec = importlib.util.spec_from_loader(module_name, loader) - module = importlib.util.module_from_spec(spec) - - loader.exec_module(module) - - return module - - -def get_custom_dataset(dataset_config, tokenizer, split: str): - if ":" in dataset_config.file: - module_path, func_name = dataset_config.file.split(":") - else: - module_path, func_name = dataset_config.file, "get_custom_dataset" - - if not module_path.endswith(".py"): - raise ValueError(f"Dataset file {module_path} is not a .py file.") - - module_path = Path(module_path) - if not module_path.is_file(): - raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.") - - module = load_module_from_py_file(module_path.as_posix()) - try: - return getattr(module, func_name)(dataset_config, tokenizer, split) - except AttributeError as e: - print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).") - raise e - - -DATASET_PREPROC = { - "alpaca_dataset": partial(get_alpaca_dataset), - "grammar_dataset": get_grammar_dataset, - "samsum_dataset": get_samsum_dataset, - "custom_dataset": get_custom_dataset, - "llamaguard_toxicchat_dataset": get_llamaguard_toxicchat_dataset, - -} - - -def get_preprocessed_dataset( - tokenizer, dataset_config, split: str = "train" -) -> torch.utils.data.Dataset: - if not dataset_config.dataset in DATASET_PREPROC: - raise NotImplementedError(f"{dataset_config.dataset} is not (yet) implemented") - - def get_split(): - return ( - dataset_config.train_split - if split == "train" - else dataset_config.test_split - ) - - return DATASET_PREPROC[dataset_config.dataset]( - dataset_config, - tokenizer, - get_split(), - ) diff --git a/llama_recipes/utils/flop_utils.py b/llama_recipes/utils/flop_utils.py deleted file mode 100644 index dcdb28e..0000000 --- a/llama_recipes/utils/flop_utils.py +++ /dev/null @@ -1,87 +0,0 @@ -from typing import Any, Dict, List, Optional, Union -import time -import torch -from torch.utils.flop_counter import FlopCounterMode - - -class FlopMeasure(FlopCounterMode): - """ - ``FlopMeasure`` is a customized context manager that counts the number of - flops within its context. It is based on ``FlopCounterMode`` with additional start_counting() and stop_counting() function so that the flop counting - will only start after the warmup stage. - It also supports hierarchical output by passing a module (or list of modules) to FlopCounterMode on construction. - - Example usage - - .. code-block:: python - - model = ... - flop_counter = FlopMeasure(model,local_rank=0,warmup_step=3) - for batch in enumerate(dataloader): - with flop_counter: - model(batch) - flop_counter.step() - """ - - def __init__( - self, - mods: Optional[Union[torch.nn.Module, List[torch.nn.Module]]] = None, - depth: int = 2, - display: bool = True, - custom_mapping: Dict[Any, Any] = None, - rank=None, - warmup_step: int = 3, - ): - super().__init__(mods, depth, display, custom_mapping) - self.rank = rank - self.warmup_step = warmup_step - self.start_time = 0 - self.end_time = 0 - - def step(self): - # decrease the warmup step by 1 for every step, so that the flop counting will start when warmup_step =0. Stop decreasing when warm_up reaches -1. - if self.warmup_step >= 0: - self.warmup_step -= 1 - if self.warmup_step == 0 and self.start_time == 0: - self.start_time = time.time() - elif self.warmup_step == -1 and self.start_time != 0 and self.end_time == 0: - self.end_time = time.time() - def __enter__(self): - if self.warmup_step == 0: - self.start_time = time.time() - super().__enter__() - return self - def is_done(self): - return self.warmup_step == -1 - def get_total_flops(self): - return super().get_total_flops() - def get_flops_per_sec(self): - if self.start_time == 0 or self.end_time == 0: - print("Warning: flop count did not finish correctly") - return 0 - return super().get_total_flops()/ (self.end_time - self.start_time) - def get_table(self, depth=2): - return super().get_table(depth) - - def __exit__(self, *args): - if self.get_total_flops() == 0: - print( - "Warning: did not record any flops this time. Skipping the flop report" - ) - else: - if self.display: - if self.rank is None or self.rank == 0: - print("Total time used in this flop counting step is: {}".format(self.end_time - self.start_time)) - print("The total TFlop per second is: {}".format(self.get_flops_per_sec() / 1e12)) - print("The tflop_count table is below:") - print(self.get_table(self.depth)) - # Disable the display feature so that we don't print the table again - self.display = False - super().__exit__(*args) - - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - # when warmup_step is 0, count the flops and return the original output - if self.warmup_step == 0: - return super().__torch_dispatch__(func, types, args, kwargs) - # otherwise, just return the original output - return func(*args, **kwargs) diff --git a/llama_recipes/utils/fsdp_utils.py b/llama_recipes/utils/fsdp_utils.py deleted file mode 100644 index 9d642a5..0000000 --- a/llama_recipes/utils/fsdp_utils.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -""" -Modified for linearizing layers -""" -import os -import functools -from torch.distributed._tensor.device_mesh import init_device_mesh -from torch.distributed.fsdp.wrap import ( - _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy -) -from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder - - -def fsdp_auto_wrap_policy(model, transformer_layer_name): - """ - Return custom wrapping for FSDP to train linearizing layers - """ - - def lambda_policy_fn(module): - if ( - len(list(module.named_children())) == 0 and ( - getattr(module, "weight", None) is not None or - getattr(module, "layer", None) is not None - # or str(type(module)) == "" - ) # and module.weight.requires_grad # (MZ 3/12/24) We want all modules for attention training - ): - return True - # try: - # if module.weight.requires_grad: - # return True - # except AttributeError: - # if module.layer.requires_grad: - # return True - return False - - lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) - transformer_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls=( - # PrefixEncoder, PromptEmbedding, PromptEncoder, - transformer_layer_name, - ), - ) - auto_wrap_policy = functools.partial( - _or_policy, policies=[lambda_policy, transformer_wrap_policy]) - return auto_wrap_policy - - -def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None): - """ - Initializes a device mesh for use with Hybrid Sharding strategy in FSDP (HSDP) training. - - This function requires explicit sizes for replica and sharding groups to accommodate models - whose GPU fit is unknown, providing flexibility in distributed training setups. - - Args: - replica_group_size (int): The size of each replica group. Must be provided to ensure - the model fits within the available resources. - sharding_group_size (int): The size of each sharding group that the model can fit. Must be provided to - ensure the correct distribution of model parameters. - device (str, optional): The device to use (e.g., "cuda:0"). If None, defaults to "cuda" - with the local rank as the device index. - - Returns: - A device mesh object compatible with FSDP. - - Raises: - ValueError: If replica_group_size or sharding_group_size are not provided, or if the - world size is not evenly divisible by the sharding group size. - RuntimeError: If a valid device mesh cannot be created. - - Usage: - If your model fits on 4 GPUS, and you have 3 nodes of 8 GPUs, then: - Sharding_Group_Size = 4 - Replica_Groups_Size = (24 total gpus, 4 per sharding group) = 6 Replica Groups - >>> device_mesh = initialize_device_mesh(replica_group_size, sharding_group_size) - >>> sharded_model = FSDP(model, device_mesh=device_mesh, ...) - """ - - if replica_group_size is None or sharding_group_size is None: - raise ValueError("Both replica_group_size and sharding_group_size must be provided.") - - local_rank = int(os.getenv("LOCAL_RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "1")) - - device = device or f"cuda" - - if world_size % sharding_group_size != 0: - raise ValueError(f"World size {world_size} is not evenly divisible by " - f"sharding group size {sharding_group_size}.") - - if (world_size // sharding_group_size) % replica_group_size != 0: - raise ValueError(f"The calculated number of replica groups is not evenly divisible by " - f"replica_group_size {replica_group_size}.") - - device_mesh = init_device_mesh(device, (replica_group_size, sharding_group_size)) - if device_mesh is None: - raise RuntimeError("Failed to create a valid device mesh.") - - return device_mesh diff --git a/llama_recipes/utils/fsdp_utils_llama_recipes.py b/llama_recipes/utils/fsdp_utils_llama_recipes.py deleted file mode 100644 index c1b0b17..0000000 --- a/llama_recipes/utils/fsdp_utils_llama_recipes.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -from torch.distributed._tensor.device_mesh import init_device_mesh -import os - -def fsdp_auto_wrap_policy(model, transformer_layer_name): - import functools - - from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy - - def lambda_policy_fn(module): - if ( - len(list(module.named_children())) == 0 - and getattr(module, "weight", None) is not None - and module.weight.requires_grad - ): - return True - return False - - lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) - transformer_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls=( - transformer_layer_name, - ), - ) - - auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) - return auto_wrap_policy - - -def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None): - """ - Initializes a device mesh for use with Hybrid Sharding strategy in FSDP (HSDP) training. - - This function requires explicit sizes for replica and sharding groups to accommodate models - whose GPU fit is unknown, providing flexibility in distributed training setups. - - Args: - replica_group_size (int): The size of each replica group. Must be provided to ensure - the model fits within the available resources. - sharding_group_size (int): The size of each sharding group that the model can fit. Must be provided to - ensure the correct distribution of model parameters. - device (str, optional): The device to use (e.g., "cuda:0"). If None, defaults to "cuda" - with the local rank as the device index. - - Returns: - A device mesh object compatible with FSDP. - - Raises: - ValueError: If replica_group_size or sharding_group_size are not provided, or if the - world size is not evenly divisible by the sharding group size. - RuntimeError: If a valid device mesh cannot be created. - - Usage: - If your model fits on 4 GPUS, and you have 3 nodes of 8 GPUs, then: - Sharding_Group_Size = 4 - Replica_Groups_Size = (24 total gpus, 4 per sharding group) = 6 Replica Groups - >>> device_mesh = initialize_device_mesh(replica_group_size, sharding_group_size) - >>> sharded_model = FSDP(model, device_mesh=device_mesh, ...) - """ - - if replica_group_size is None or sharding_group_size is None: - raise ValueError("Both replica_group_size and sharding_group_size must be provided.") - - local_rank = int(os.getenv("LOCAL_RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "1")) - - device = device or f"cuda" - - if world_size % sharding_group_size != 0: - raise ValueError(f"World size {world_size} is not evenly divisible by " - f"sharding group size {sharding_group_size}.") - - if (world_size // sharding_group_size) % replica_group_size != 0: - raise ValueError(f"The calculated number of replica groups is not evenly divisible by " - f"replica_group_size {replica_group_size}.") - - device_mesh = init_device_mesh(device, (replica_group_size, sharding_group_size)) - if device_mesh is None: - raise RuntimeError("Failed to create a valid device mesh.") - - return device_mesh diff --git a/llama_recipes/utils/memory_utils.py b/llama_recipes/utils/memory_utils.py deleted file mode 100644 index 3fa06e1..0000000 --- a/llama_recipes/utils/memory_utils.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -import gc -import psutil -import threading - -import torch -from accelerate.utils import is_xpu_available - -def byte2gb(x): - return int(x / 2**30) -# This context manager is used to track the peak memory usage of the process -class MemoryTrace: - def __enter__(self): - gc.collect() - if is_xpu_available(): - torch.xpu.empty_cache() - torch.xpu.reset_max_memory_allocated() # reset the peak gauge to zero - self.begin = byte2gb(torch.xpu.memory_allocated()) - elif torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero - self.begin = byte2gb(torch.cuda.memory_allocated()) - self.process = psutil.Process() - self.cpu_begin = byte2gb(self.cpu_mem_used()) - self.peak_monitoring = True - peak_monitor_thread = threading.Thread(target=self.peak_monitor_func) - peak_monitor_thread.daemon = True - peak_monitor_thread.start() - return self - - def cpu_mem_used(self): - """get resident set size memory for the current process""" - return self.process.memory_info().rss - - def peak_monitor_func(self): - self.cpu_peak = -1 - - while True: - self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak) - - # can't sleep or will not catch the peak right (this comment is here on purpose) - # time.sleep(0.001) # 1msec - - if not self.peak_monitoring: - break - - def __exit__(self, *exc): - self.peak_monitoring = False - - gc.collect() - if is_xpu_available(): - torch.xpu.empty_cache() - self.end = byte2gb(torch.xpu.memory_allocated()) - self.peak = byte2gb(torch.xpu.max_memory_allocated()) - xpu_info = torch.xpu.memory_stats() - self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"]) - self.malloc_retries = xpu_info.get("num_alloc_retries", 0) - self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"]) - self.m_ooms = xpu_info.get("num_ooms", 0) - self.used = byte2gb(self.end - self.begin) - self.peaked = byte2gb(self.peak - self.begin) - self.max_reserved = byte2gb(torch.xpu.max_memory_reserved()) - elif torch.cuda.is_available(): - torch.cuda.empty_cache() - self.end = byte2gb(torch.cuda.memory_allocated()) - self.peak = byte2gb(torch.cuda.max_memory_allocated()) - cuda_info = torch.cuda.memory_stats() - self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"]) - self.malloc_retries = cuda_info.get("num_alloc_retries", 0) - self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"]) - self.m_ooms = cuda_info.get("num_ooms", 0) - self.used = byte2gb(self.end - self.begin) - self.peaked = byte2gb(self.peak - self.begin) - self.max_reserved = byte2gb(torch.cuda.max_memory_reserved()) - - self.cpu_end = self.cpu_mem_used() - self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin) - self.cpu_peaked = byte2gb(self.cpu_peak - self.cpu_begin) - # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}") - - def print_stats(self): - device_str = None - if is_xpu_available(): - device_str = "XPU" - elif torch.cuda.is_available(): - device_str = "CUDA" - - if device_str: - print(f"Max {device_str} memory allocated was {self.peak} GB") - print(f"Max {device_str} memory reserved was {self.max_reserved} GB") - print(f"Peak active {device_str} memory was {self.peak_active_gb} GB") - print(f"{device_str} Malloc retries : {self.malloc_retries}") - print(f"CPU Total Peak Memory consumed during the train (max): {self.cpu_peaked + self.cpu_begin} GB") \ No newline at end of file diff --git a/llama_recipes/utils/plot_metrics.py b/llama_recipes/utils/plot_metrics.py deleted file mode 100644 index e8ab230..0000000 --- a/llama_recipes/utils/plot_metrics.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -import json -import matplotlib.pyplot as plt -import argparse -import os - -def plot_metric(data, metric_name, x_label, y_label, title, colors): - plt.figure(figsize=(7, 6)) - - plt.plot(data[f'train_epoch_{metric_name}'], label=f'Train Epoch {metric_name.capitalize()}', color=colors[0]) - plt.plot(data[f'val_epoch_{metric_name}'], label=f'Validation Epoch {metric_name.capitalize()}', color=colors[1]) - plt.xlabel(x_label) - plt.ylabel(y_label) - plt.title(f'Train and Validation Epoch {title}') - plt.legend() - plt.tight_layout() - -def plot_single_metric_by_step(data, metric_name, x_label, y_label, title, color): - plt.plot(data[f'{metric_name}'], label=f'{title}', color=color) - plt.xlabel(x_label) - plt.ylabel(y_label) - plt.title(title) - plt.legend() - plt.tight_layout() - -def plot_metrics_by_step(data, metric_name, x_label, y_label, colors): - plt.figure(figsize=(14, 6)) - - plt.subplot(1, 2, 1) - plot_single_metric_by_step(data, f'train_step_{metric_name}', x_label, y_label, f'Train Step {metric_name.capitalize()}', colors[0]) - plt.subplot(1, 2, 2) - plot_single_metric_by_step(data, f'val_step_{metric_name}', x_label, y_label, f'Validation Step {metric_name.capitalize()}', colors[1]) - plt.tight_layout() - - -def plot_metrics(file_path): - if not os.path.exists(file_path): - print(f"File {file_path} does not exist.") - return - - with open(file_path, 'r') as f: - try: - data = json.load(f) - except json.JSONDecodeError: - print("Invalid JSON file.") - return - - directory = os.path.dirname(file_path) - filename_prefix = os.path.basename(file_path).split('.')[0] - - plot_metric(data, 'loss', 'Epoch', 'Loss', 'Loss', ['b', 'r']) - plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_loss.png")) - plt.close() - - plot_metric(data, 'perplexity', 'Epoch', 'Perplexity', 'Perplexity', ['g', 'm']) - plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_perplexity.png")) - plt.close() - - plot_metrics_by_step(data, 'loss', 'Step', 'Loss', ['b', 'r']) - plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_loss_by_step.png")) - plt.close() - - plot_metrics_by_step(data, 'perplexity', 'Step', 'Loss', ['g', 'm']) - plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_perplexity_by_step.png")) - plt.close() - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Plot metrics from JSON file.') - parser.add_argument('--file_path', required=True, type=str, help='Path to the metrics JSON file.') - args = parser.parse_args() - - plot_metrics(args.file_path) diff --git a/llama_recipes/utils/train_utils.py b/llama_recipes/utils/train_utils.py deleted file mode 100644 index 09b5b77..0000000 --- a/llama_recipes/utils/train_utils.py +++ /dev/null @@ -1,554 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. - -import os -import time -import yaml -from contextlib import nullcontext -from pathlib import Path -from datetime import datetime -import contextlib - - -import torch -import torch.cuda.nccl as nccl -import torch.distributed as dist -from torch.distributed.fsdp import StateDictType -from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler -from tqdm import tqdm -from transformers import LlamaTokenizer -import json - - -from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint, save_peft_checkpoint -from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper -from llama_recipes.utils.memory_utils import MemoryTrace -from accelerate.utils import is_xpu_available, is_ccl_available -from llama_recipes.utils.flop_utils import FlopMeasure -def set_tokenizer_params(tokenizer: LlamaTokenizer): - tokenizer.pad_token_id = 0 - tokenizer.padding_side = "left" - -@contextlib.contextmanager -def profile(cfg, local_rank=None): - use_profiler: bool = cfg.use_profiler - use_flop_counter: bool = cfg.flop_counter - if use_flop_counter and use_profiler: - raise ValueError("Cannot use both profiler and flop counter") - if use_profiler: - # profiler needs a warmup stage to get the accurate profiling results - wait_step, warmup_step, active_step = 1, 2, 3 - min_step = wait_step + warmup_step + active_step + 1 - if cfg.max_train_step > 0 and cfg.max_train_step < min_step: - raise ValueError(f"pytorch profiler requires at least {min_step} train steps to finish the warm-up and recording stage, {wait_step} for wait_step, {warmup_step} for warmup_step, {active_step} for profiling step, please increase the max_train_step, current max_train_step {cfg.max_train_step}") - print(f"pytorch profiling is activated and results will be saved in {cfg.profiler_dir}") - with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - schedule=torch.profiler.schedule(wait=wait_step, warmup=warmup_step, active=active_step, repeat=1), - on_trace_ready=torch.profiler.tensorboard_trace_handler( - cfg.profiler_dir - ), - profile_memory=True, - with_stack=False, - with_flops=True, - record_shapes=True, - ) as torch_profiler: - yield torch_profiler - elif use_flop_counter: - if cfg.max_train_step > 0 and cfg.max_train_step <= cfg.flop_counter_start: - raise ValueError(f"flop counter requires at least {cfg.flop_counter_start + 1} train steps, please increase the max_train_step, current max_train_step {cfg.max_train_step}") - with FlopMeasure(rank=local_rank,warmup_step=cfg.flop_counter_start) as flop_counter: - yield flop_counter - else: - torch_profiler = contextlib.nullcontext() - yield None - - -def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None): - """ - Trains the model on the given dataloader - - Args: - model: The model to be trained - train_dataloader: The dataloader containing the training data - optimizer: The optimizer used for training - lr_scheduler: The learning rate scheduler - gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation - num_epochs: The number of epochs to train for - local_rank: The rank of the current node in a distributed setting - train_config: The training configuration - eval_dataloader: The dataloader containing the eval data - tokenizer: tokenizer used in the eval for decoding the predicitons - - Returns: results dictionary containing average training and validation perplexity and loss - """ - # Create a gradient scaler for fp16 - if train_config.use_fp16 and train_config.enable_fsdp: - scaler = ShardedGradScaler() - elif train_config.use_fp16 and not train_config.enable_fsdp: - scaler = torch.cuda.amp.GradScaler() - if train_config.enable_fsdp: - world_size = int(os.environ["WORLD_SIZE"]) - - - - autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext - train_prep = [] - train_loss = [] - val_prep = [] - val_loss =[] - - if train_config.save_metrics: - if not os.path.exists(train_config.output_dir): - os.makedirs(train_config.output_dir, exist_ok=True) - metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json" - train_step_perplexity = [] - train_step_loss = [] - val_step_loss = [] - val_step_perplexity = [] - - epoch_times = [] - checkpoint_times = [] - results = {} - best_val_loss = float("inf") - total_train_steps = 0 - max_steps_reached = False # Flag to indicate max training steps reached - # Start the training loop - for epoch in range(train_config.num_epochs): - # stop when the maximum number of training steps is reached - if max_steps_reached: - break - epoch_start_time = time.perf_counter() - with MemoryTrace() as memtrace: # track the memory usage - model.train() - total_loss = 0.0 - total_length = len(train_dataloader)//gradient_accumulation_steps - pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True) - with profile(train_config,local_rank) as profile_context: - for step, batch in enumerate(train_dataloader): - total_train_steps += 1 - # stop when the maximum number of training steps is reached - if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step: - max_steps_reached = True - if not train_config.enable_fsdp or local_rank==0: - print("max training steps reached, stopping training, total train steps finished: ", total_train_steps-1) - break - for key in batch.keys(): - if train_config.enable_fsdp: - if is_xpu_available(): - batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}")) - else: - batch[key] = batch[key].to(local_rank) - else: - - if is_xpu_available(): - batch[key] = batch[key].to('xpu:0') - else: - batch[key] = batch[key].to('cuda:0') - with autocast(): - loss = model(**batch).loss - loss = loss / gradient_accumulation_steps - if train_config.save_metrics: - train_step_loss.append(loss.detach().float().item()) - train_step_perplexity.append(float(torch.exp(loss.detach().float()))) - total_loss += loss.detach().float() - if train_config.use_fp16: - # if fp16 is enabled, use gradient scaler to handle gradient update - scaler.scale(loss).backward() - if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: - if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0: - scaler.unscale_(optimizer) - if train_config.enable_fsdp: - model.clip_grad_norm_(train_config.gradient_clipping_threshold) - else: - torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - pbar.update(1) - else: - # regular backpropagation when fp16 is not used - loss.backward() - if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: - if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0: - if train_config.enable_fsdp: - model.clip_grad_norm_(train_config.gradient_clipping_threshold) - else: - torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold) - optimizer.step() - optimizer.zero_grad() - pbar.update(1) - if train_config.use_profiler or train_config.flop_counter: - profile_context.step() - if train_config.flop_counter and profile_context.is_done(): - TFlops = profile_context.get_flops_per_sec() / 1e12 - if wandb_run: - if not train_config.enable_fsdp or rank==0: - wandb_run.log({ - 'train/epoch': epoch + 1, - 'train/step': epoch * len(train_dataloader) + step, - 'train/loss': loss.detach().float(), - }) - - pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})") - - if train_config.save_metrics: - save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep) - pbar.close() - - epoch_end_time = time.perf_counter()-epoch_start_time - epoch_times.append(epoch_end_time) - # Reducing total_loss across all devices if there's more than one CUDA device - if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp): - dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) - elif torch.cuda.device_count() > 1 and train_config.enable_fsdp: - dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) - train_epoch_loss = total_loss / len(train_dataloader) - if train_config.enable_fsdp: - train_epoch_loss = train_epoch_loss/world_size - train_perplexity = torch.exp(train_epoch_loss) - - train_prep.append(float(train_perplexity)) - train_loss.append(float(train_epoch_loss)) - - if not train_config.enable_fsdp or rank==0: - memtrace.print_stats() - - # Update the learning rate as needed - lr_scheduler.step() - if train_config.run_validation: - eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run) - if train_config.save_metrics: - val_step_loss.extend(temp_val_loss) - val_step_perplexity.extend(temp_step_perplexity) - - checkpoint_start_time = time.perf_counter() - if train_config.save_model and eval_epoch_loss < best_val_loss: - if train_config.enable_fsdp: - dist.barrier() - if train_config.use_peft: - if train_config.enable_fsdp: - if rank==0: - print(f"we are about to save the PEFT modules") - else: - print(f"we are about to save the PEFT modules") - save_peft_checkpoint(model, train_config.output_dir) - if train_config.enable_fsdp: - if rank==0: - print(f"PEFT modules are saved in {train_config.output_dir} directory") - else: - print(f"PEFT modules are saved in {train_config.output_dir} directory") - - else: - if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT: - - save_model_checkpoint( - model, optimizer, rank, train_config, epoch=epoch - ) - elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT: - print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT") - print("=====================================================") - - save_model_and_optimizer_sharded(model, rank, train_config) - if train_config.save_optimizer: - save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer) - print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT") - print("=====================================================") - - if not train_config.use_peft and train_config.save_optimizer: - save_optimizer_checkpoint( - model, optimizer, rank, train_config, epoch=epoch - ) - print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT") - print("=====================================================") - if train_config.enable_fsdp: - dist.barrier() - checkpoint_end_time = time.perf_counter() - checkpoint_start_time - checkpoint_times.append(checkpoint_end_time) - if eval_epoch_loss < best_val_loss: - best_val_loss = eval_epoch_loss - if train_config.enable_fsdp: - if rank==0: - print(f"best eval loss on epoch {epoch+1} is {best_val_loss}") - else: - print(f"best eval loss on epoch {epoch+1} is {best_val_loss}") - val_loss.append(float(best_val_loss)) - val_prep.append(float(eval_ppl)) - if train_config.enable_fsdp: - if rank==0: - print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s") - else: - print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s") - - # Saving the results every epoch to plot later - if train_config.save_metrics: - save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep) - - avg_epoch_time = sum(epoch_times)/ len(epoch_times) - avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0 - avg_train_prep = sum(train_prep)/len(train_prep) - avg_train_loss = sum(train_loss)/len(train_loss) - if train_config.run_validation: - avg_eval_prep = sum(val_prep)/len(val_prep) - avg_eval_loss = sum(val_loss)/len(val_loss) - - results['avg_train_prep'] = avg_train_prep - results['avg_train_loss'] = avg_train_loss - if train_config.run_validation: - results['avg_eval_prep'] = avg_eval_prep - results['avg_eval_loss'] = avg_eval_loss - results["avg_epoch_time"] = avg_epoch_time - results["avg_checkpoint_time"] = avg_checkpoint_time - if train_config.save_metrics: - results["metrics_filename"] = metrics_filename - if train_config.flop_counter: - results["model_tflops"]= TFlops - #saving the training params including fsdp setting for reference. - if train_config.enable_fsdp and not train_config.use_peft and rank==0: - save_train_params(train_config, fsdp_config, rank) - - return results - -def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb_run): - """ - Evaluates the model on the given dataloader - - Args: - model: The model to evaluate - eval_dataloader: The dataloader containing the evaluation data - local_rank: The rank of the current node in a distributed setting - tokenizer: The tokenizer used to decode predictions - - Returns: eval_ppl, eval_epoch_loss - """ - if train_config.enable_fsdp: - world_size = int(os.environ["WORLD_SIZE"]) - model.eval() - eval_preds = [] - val_step_loss = [] - val_step_perplexity = [] - eval_loss = 0.0 # Initialize evaluation loss - total_eval_steps = 0 - with MemoryTrace() as memtrace: - for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)): - total_eval_steps += 1 - # stop when the maximum number of eval steps is reached - if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step: - if not train_config.enable_fsdp or local_rank==0: - print("max eval steps reached, stopping evaluation, total_eval_steps: ", total_eval_steps - 1) - break - for key in batch.keys(): - if train_config.enable_fsdp: - batch[key] = batch[key].to(local_rank) - else: - if is_xpu_available(): - batch[key] = batch[key].to('xpu:0') - else: - batch[key] = batch[key].to('cuda:0') - # Ensure no gradients are computed for this scope to save memory - with torch.no_grad(): - # Forward pass and compute loss - outputs = model(**batch) - loss = outputs.loss - if train_config.save_metrics: - val_step_loss.append(loss.detach().float().item()) - val_step_perplexity.append(float(torch.exp(loss.detach().float()))) - - eval_loss += loss.detach().float() - # Decode predictions and add to evaluation predictions list - preds = torch.argmax(outputs.logits, -1) - eval_preds.extend( - tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True) - ) - - # If there's more than one CUDA device, reduce evaluation loss across all devices - if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp): - dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM) - if torch.cuda.device_count() > 1 and train_config.enable_fsdp: - dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM) - - # Compute average loss and perplexity - eval_epoch_loss = eval_loss / len(eval_dataloader) - if train_config.enable_fsdp: - eval_epoch_loss = eval_epoch_loss/world_size - eval_ppl = torch.exp(eval_epoch_loss) - - # Print evaluation metrics - if train_config.enable_fsdp: - if local_rank==0: - print(f" {eval_ppl=} {eval_epoch_loss=}") - else: - print(f" {eval_ppl=} {eval_epoch_loss=}") - - if wandb_run: - wandb_run.log({ - 'eval/perplexity': eval_ppl, - 'eval/loss': eval_epoch_loss, - }, commit=False) - - return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity - -def freeze_transformer_layers(model, num_layer): - for i, layer in enumerate(model.model.layers): - if i < num_layer: - for param in layer.parameters(): - param.requires_grad = False - - -def check_frozen_layers_peft_model(model): - for i, layer in enumerate(model.base_model.model.model.layers): - for name, param in layer.named_parameters(): - print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}") - - -def setup(): - """Initialize the process group for distributed training""" - if is_ccl_available(): - # distributed training on xpus - dist.init_process_group("ccl") - else: - dist.init_process_group("nccl") - - -def setup_environ_flags(rank): - """Set environment flags for debugging purposes""" - os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1) - os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1) - # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" - # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases. - # Note this is only availble in PyTorch Nighlies (as of July 30 2023) - # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True' - if rank == 0: - print(f"--> Running with torch dist debug set to detail") - - -def cleanup(): - """Clean up the process group after training""" - dist.destroy_process_group() - - -def clear_gpu_cache(rank=None): - """Clear the GPU cache for all ranks""" - if rank == 0: - print(f"Clearing GPU cache for all ranks") - if is_xpu_available(): - torch.xpu_empty_cache() - else: - torch.cuda.empty_cache() - - -def get_parameter_dtypes(model): - """Get the data types of model parameters""" - parameter_dtypes = {} - for name, parameter in model.named_parameters(): - parameter_dtypes[name] = parameter.dtype - return parameter_dtypes - -def print_model_size(model, config, rank: int = 0) -> None: - """ - Print model name, the number of trainable parameters and initialization time. - - Args: - model: The PyTorch model. - model_name (str): Name of the model. - init_time_start (float): Initialization start time. - init_time_end (float): Initialization end time. - rank (int, optional): Current process's rank. Defaults to 0. - """ - if rank == 0: - print(f"--> Model {config.model_name}") - total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n") - - - - -def get_policies(cfg, rank): - """Get the policies for mixed precision and fsdp wrapping""" - - - verify_bfloat_support = (( - torch.version.cuda - and torch.cuda.is_bf16_supported() - and torch.version.cuda >= "11.0" - and dist.is_nccl_available() - and nccl.version() >= (2, 10) - ) or - (is_xpu_available())) - - - mixed_precision_policy = None - wrapping_policy = None - - # Mixed precision - if cfg.mixed_precision: - bf16_ready = verify_bfloat_support - - if bf16_ready and not cfg.use_fp16: - mixed_precision_policy = bfSixteen - if rank == 0: - print(f"bFloat16 enabled for mixed precision - using bfSixteen policy") - elif cfg.use_fp16: - mixed_precision_policy = fpSixteen - if rank == 0: - print(f"FP16 enabled") - else: - print(f"bFloat16 support not present. Using FP32, and not mixed precision") - wrapping_policy = get_llama_wrapper() - return mixed_precision_policy, wrapping_policy - -def save_train_params(train_config, fsdp_config, rank): - """ - This function saves the train_config and FSDP config into a train_params.yaml. - This will be used by converter script in the inference folder to fetch the HF model name or path. - It also would be hepful as a log for future references. - """ - # Convert the train_config and fsdp_config objects to dictionaries, - # converting all values to strings to ensure they can be serialized into a YAML file - train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')} - fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')} - # Merge the two dictionaries into one - train_params_dict = {**train_config_dict, **fsdp_config_dict} - # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object - folder_name = ( - train_config.dist_checkpoint_root_folder - + "/" - + train_config.dist_checkpoint_folder - + "-" - + train_config.model_name - ) - - save_dir = Path.cwd() / folder_name - # If the directory does not exist, create it - if not os.path.exists(save_dir): - os.makedirs(save_dir) - # Convert the dictionary to a YAML string - config_yaml = yaml.dump(train_params_dict, indent=4) - file_name = os.path.join(save_dir,'train_params.yaml') - - # Check if there's a directory with the same name as the file - if os.path.isdir(file_name): - print(f"Error: {file_name} is a directory, not a file.") - else: - # Write the YAML string to the file - with open(file_name, 'w') as f: - f.write(config_yaml) - if rank==0: - print(f"training params are saved in {file_name}") - -def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl): - metrics_data = { - "train_step_loss": train_step_loss, - "train_epoch_loss": train_epoch_loss, - "train_step_perplexity": train_step_ppl, - "train_epoch_perplexity": train_epoch_ppl, - "val_step_loss": val_step_loss, - "val_epoch_loss": val_epoch_loss, - "val_step_perplexity": val_step_ppl, - "val_epoch_perplexity": val_epoch_ppl - } - with open(output_filename, "w") as f: - json.dump(metrics_data, f) diff --git a/lm_eval_harness/eval_lm_harness.py b/lm_eval_harness/eval_lm_harness.py index 52339e7..ae4b1ef 100644 --- a/lm_eval_harness/eval_lm_harness.py +++ b/lm_eval_harness/eval_lm_harness.py @@ -11,7 +11,6 @@ import numpy as np import pandas as pd -# from lm_eval_harness.model_loader import load_model_from_checkpoint, load_model_from_config from src.model.load_model_for_eval import load_model_from_checkpoint, load_model_from_config LM_EVALUATION_HARNESS_PATH = '/juice2/scr2/mzhang/projects/lm-evaluation-harness' # Change this to where you clone LM eval harness from @@ -44,7 +43,6 @@ def get_args(): parser.add_argument("--model_config", type=str, default=None) parser.add_argument("--cache_dir", type=str, default=None) - # If model_name == 'hedghog', paths to checkpoints parser.add_argument("--attn_mlp_checkpoint_path", type=str, default=None) parser.add_argument("--finetune_checkpoint_path", type=str, default=None) parser.add_argument("--config_dir", type=str, default='./configs') diff --git a/lm_eval_harness/eval_lm_harness_big.py b/lm_eval_harness/eval_lm_harness_big.py index 6f539e0..e768fcd 100644 --- a/lm_eval_harness/eval_lm_harness_big.py +++ b/lm_eval_harness/eval_lm_harness_big.py @@ -1,5 +1,6 @@ """ Evaluate models with lm-evaluation-harness +- Another way to evaluate models that require multiple GPUs to load - Right now does a heinous manual pipelining of model layers across devices """ import copy @@ -48,7 +49,6 @@ ('truthfulqa-mc', 0), ('winogrande', 5), ('gsm8k', 5), - # TODO: include MMLU too ] ZERO_SHOT = [ ('hellaswag', 0), @@ -57,7 +57,6 @@ ('arc-easy', 0), ('winogrande', 0), ('hendrycksTest', 5), - # TODO: include MMLU too ] @@ -125,11 +124,8 @@ def get_args(): # Run_name for loading checkpoints args.run_name = f'dl-d={distill_name}-m={args.model_config}-f={finetune_name}' - # if args.no_peft_grad_ckpt is not None: - # args.model_name += f'-npgc={args.no_peft_grad_ckpt}' if args.fsdp_activation_checkpointing is not None: args.run_name += f'-fac={args.fsdp_activation_checkpointing}' - # args.run_name += f'-s={args.seed}' args.run_name = args.run_name.replace('True', '1').replace('False', '0') # concise hacks # Run name for evaluation @@ -199,7 +195,6 @@ def get_lm_eval_lolcats_model(model_kwargs: dict, lolcats_model: bool = True): lm = get_model('hf-causal-experimental').create_from_arg_string( '', lm_kwargs, ) - # model = lm.model return lm @@ -229,13 +224,6 @@ def main(): seed_everything(args.seed) rank = 0 - # if args.attn_mlp_checkpoint_path is not None or args.finetune_checkpoint_path is not None: - # args.model_config = args.attn_mlp_checkpoint_path.split('/')[-2] - # if args.attn_mlp_checkpoint_path is not None: - # args.distill_config = args.attn_mlp_checkpoint_path.split('-d=')[-1].split('-m=')[0] - # if args.finetune_checkpoint_path is not None: - # args.finetune_config = args.finetune_checkpoint_path.split('-f=')[-1].split('-')[0] - args.checkpoint_dir = join(args.checkpoint_dir, args.model_config) if not os.path.isdir(args.checkpoint_dir): os.makedirs(args.checkpoint_dir) @@ -299,8 +287,6 @@ def main(): tokenizer.pad_token_id = tokenizer.eos_token_id tokenizer.padding_side = 'left' - # use_cache = False if args.enable_fsdp else None - # Load model weights to CPU lm = get_lm_eval_lolcats_model(model_loader.loading_kwargs, lolcats_model=args.model_type == 'lolcats_ckpt') @@ -337,11 +323,6 @@ def main(): model_config.model_name = model_config.model.pretrained_model_name_or_path print_model_size(model, model_config, 0) - # Prepare the model for int8 training if quantization is enabled - # -> But we only use this script for FSDP without quantization - # if train_config.quantization: - # model = prepare_model_for_int8_training(model) - # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled if args.pure_bf16: model.to(torch.bfloat16) @@ -352,18 +333,13 @@ def main(): # ------------------------------- model, distill_peft_config = load_and_convert_attns(model, model_config, attention_type=args.attention_type, # 'lolcats_llama', - checkpoint_path=None, # args.load_distill_checkpoint, + checkpoint_path=None, print_model=args.verbose, merge_loras=False, peft_gradient_checkpointing=not args.no_peft_grad_ckpt, train_attention=False) if True: # rank == 0: - # if distill_config.trainer.name is not None or args.attn_mlp_checkpoint_path is not None: if distill_config.trainer.name is not None and args.attn_mlp_checkpoint_path is not None: - # if args.replicate == 64: - # distill_config.model_name = distill_config.model_name.replace(f'-se={args.seed}', '-se=0').replace(f'-s={args.seed}', '-s=0') - # else: - # distill_config.model_name = distill_config.model_name.replace(f'-re={args.replicate}', '-re=0') model = load_sharded_model_single_gpu(model, model_path=args.attn_mlp_checkpoint_path, # None, cfg=distill_config, rank=rank) @@ -372,7 +348,6 @@ def main(): # ---------------------------- finetune_config, args = prepare_finetune_configs(args, model_config, args.finetune_config) - # finetune_config = update_config_from_args(finetune_config, args) finetune_config = setup_fsdp_config(finetune_config, args, 'finetune') model, ft_peft_config = load_and_convert_finetune(model, finetune_config, @@ -392,14 +367,13 @@ def main(): if True: # if rank == 0: print_header('** Sanity check model weights **') for n, p in model.named_parameters(): - # if ('layers.0.' in n and ('feature_map' in n or 'lora' in n)): if 'layers.0.' in n: print(f'-> {n}:\n', p) # Back to LM Eval model lm.model = model model = lm - if args.task in ['mmlu', 'hendrycksTest', 'mmlu_cloze']: + if args.task in ['mmlu', 'hendrycksTest']: from lm_eval.tasks import TASK_REGISTRY tasks = sorted([k for k in TASK_REGISTRY.keys() if f'{args.task}-' in k]) else: diff --git a/lm_eval_harness/mmlu_2.py b/lm_eval_harness/mmlu_2.py deleted file mode 100644 index 951798f..0000000 --- a/lm_eval_harness/mmlu_2.py +++ /dev/null @@ -1,200 +0,0 @@ -""" -Setup: -1. Copy this file into `/lm-evaluation-harness/lm_eval/tasks/mmlu_cloze.py` -2. Add `from . import mmlu_cloze` to `/lm-evaluation-harness/lm_eval/tasks/__init__.py` - e.g., - ``` - from . import mmlu - from . import mmlu_2 # added - ``` -3. Also change `/lm-evaluation-harness/lm_eval/tasks/__init__.py` to include mmlu_cloze tasks, - e.g., adding `**mmlu_2.create_all_tasks()` in the file - ``` - **hendrycks_test.create_all_tasks(), - **mmlu_2.create_all_tasks(), - ``` --------------------------------------------------- - -Format of MMLU where the full answer, e.g., "A. Answer Text" is used instead of just the letter, e.g., "A" - -Measuring Massive Multitask Language Understanding -https://arxiv.org/pdf/2009.03300.pdf - -The Hendryck's Test is a benchmark that measured a text model’s multitask accuracy. -The test covers 57 tasks including elementary mathematics, US history, computer -science, law, and more. To attain high accuracy on this test, models must possess -extensive world knowledge and problem solving ability. By comprehensively evaluating -the breadth and depth of a model’s academic and professional understanding, -Hendryck's Test can be used to analyze models across many tasks and to identify -important shortcomings. - -Homepage: https://github.com/hendrycks/test -""" -from lm_eval.base import MultipleChoiceTask - -_CITATION = """ -@article{hendryckstest2021, - title={Measuring Massive Multitask Language Understanding}, - author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt}, - journal={Proceedings of the International Conference on Learning Representations (ICLR)}, - year={2021} -} -""" - - -SUBJECTS = [ - "abstract_algebra", - "anatomy", - "astronomy", - "business_ethics", - "clinical_knowledge", - "college_biology", - "college_chemistry", - "college_computer_science", - "college_mathematics", - "college_medicine", - "college_physics", - "computer_security", - "conceptual_physics", - "econometrics", - "electrical_engineering", - "elementary_mathematics", - "formal_logic", - "global_facts", - "high_school_biology", - "high_school_chemistry", - "high_school_computer_science", - "high_school_european_history", - "high_school_geography", - "high_school_government_and_politics", - "high_school_macroeconomics", - "high_school_mathematics", - "high_school_microeconomics", - "high_school_physics", - "high_school_psychology", - "high_school_statistics", - "high_school_us_history", - "high_school_world_history", - "human_aging", - "human_sexuality", - "international_law", - "jurisprudence", - "logical_fallacies", - "machine_learning", - "management", - "marketing", - "medical_genetics", - "miscellaneous", - "moral_disputes", - "moral_scenarios", - "nutrition", - "philosophy", - "prehistory", - "professional_accounting", - "professional_law", - "professional_medicine", - "professional_psychology", - "public_relations", - "security_studies", - "sociology", - "us_foreign_policy", - "virology", - "world_religions", -] - - -def create_all_tasks(): - """Creates a dictionary of tasks from a list of subjects - :return: {task_name: task} - e.g. {hendrycksTest-abstract_algebra: Task, hendrycksTest-anatomy: Task} - """ - # return {f"hendrycksTest-{sub}": create_task(sub) for sub in SUBJECTS} - return {f"mmlu_2-{sub}": create_task(sub) for sub in SUBJECTS} - - -def create_task(subject): - class HendrycksTest(GeneralHendrycksTest): - def __init__(self): - super().__init__(subject) - - return HendrycksTest - - -class GeneralHendrycksTest(MultipleChoiceTask): - VERSION = 1 - DATASET_PATH = "cais/mmlu" - DATASET_NAME = None - - def __init__(self, subject): - self.DATASET_NAME = subject - super().__init__() - - def has_training_docs(self): - return True - - def has_validation_docs(self): - return True - - def has_test_docs(self): - return True - - def validation_docs(self): - return map(self._process_doc, self.dataset["validation"]) - - def test_docs(self): - return map(self._process_doc, self.dataset["test"]) - - def _format_subject(self, subject): - words = subject.split("_") - return " ".join(words) - - def fewshot_context(self, doc, num_fewshot, **kwargs): - subject = self.DATASET_NAME - description = f"The following are multiple choice questions (with answers) about {self._format_subject(subject)}." - kwargs["description"] = description - return super().fewshot_context(doc=doc, num_fewshot=num_fewshot, **kwargs) - - def _process_doc(self, doc): - def format_example(doc, keys): - """ - - A. - B. - C. - D. - Answer: - """ - - question = doc["question"].strip() - choices = "".join( - [f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])] - ) - prompt = f"{question}\n{choices}Answer:" - return prompt - - keys = ["A", "B", "C", "D"] - choices = [f"{key}. {choice}" for key, choice in zip(keys, doc["choices"])] - return { - "query": format_example(doc, keys), - "choices": choices, - "gold": doc["answer"], - } - - def fewshot_examples(self, k, rnd): - # fewshot_examples is not just sampling from train_docs because dev is - # in the same distribution as val/test but auxiliary_train isn't - if self._fewshot_docs is None: - self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"])) - - # use the unchanged order of the dev set without sampling, - # just as in the original code https://github.com/hendrycks/test/blob/master/evaluate.py#L28 - return self._fewshot_docs[:k] - - def doc_to_text(self, doc): - return doc["query"] - - def should_decontaminate(self): - return True - - def doc_to_decontamination_query(self, doc): - return doc["query"] diff --git a/lm_eval_harness/mmlu_cloze.py b/lm_eval_harness/mmlu_cloze.py deleted file mode 100644 index fe5d283..0000000 --- a/lm_eval_harness/mmlu_cloze.py +++ /dev/null @@ -1,237 +0,0 @@ -""" -Setup: -1. Copy this file into `/lm-evaluation-harness/lm_eval/tasks/mmlu_cloze.py` -2. Add `from . import mmlu_cloze` to `/lm-evaluation-harness/lm_eval/tasks/__init__.py` - e.g., - ``` - from . import mmlu - from . import mmlu_cloze # added - ``` -3. Also change `/lm-evaluation-harness/lm_eval/tasks/__init__.py` to include mmlu_cloze tasks, - e.g., adding `**mmlu_cloze.create_all_tasks()` in the file - ``` - **hendrycks_test.create_all_tasks(), - **mmlu_cloze.create_all_tasks(), - ``` --------------------------------------------------- - -Adapted version of the below where instead of: -``` -The following are multiple choice questions (with answers) about world religions. - -What is the Second Gem in Buddhism? -A. The Dharma -B. The Sangha -C. The Buddha -D. The Bodhisattva -Answer: -``` - -and picking among choices (lowest perplexity) -``` -A -B -C -D -``` - -We use the format: -``` -The following are questions with answers about world religions. - -What is the Second Gem in Buddhism? -``` - -and picking among choices (lowest perplexity) -``` -The Dharma -The Sangha -The Buddha -The Bodhisattva -``` - - --------------------------------------------------- - -Measuring Massive Multitask Language Understanding -https://arxiv.org/pdf/2009.03300.pdf - -The Hendryck's Test is a benchmark that measured a text model’s multitask accuracy. -The test covers 57 tasks including elementary mathematics, US history, computer -science, law, and more. To attain high accuracy on this test, models must possess -extensive world knowledge and problem solving ability. By comprehensively evaluating -the breadth and depth of a model’s academic and professional understanding, -Hendryck's Test can be used to analyze models across many tasks and to identify -important shortcomings. - -Homepage: https://github.com/hendrycks/test -""" -from lm_eval.base import MultipleChoiceTask - -_CITATION = """ -@article{hendryckstest2021, - title={Measuring Massive Multitask Language Understanding}, - author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt}, - journal={Proceedings of the International Conference on Learning Representations (ICLR)}, - year={2021} -} -""" - - -SUBJECTS = [ - "abstract_algebra", - "anatomy", - "astronomy", - "business_ethics", - "clinical_knowledge", - "college_biology", - "college_chemistry", - "college_computer_science", - "college_mathematics", - "college_medicine", - "college_physics", - "computer_security", - "conceptual_physics", - "econometrics", - "electrical_engineering", - "elementary_mathematics", - "formal_logic", - "global_facts", - "high_school_biology", - "high_school_chemistry", - "high_school_computer_science", - "high_school_european_history", - "high_school_geography", - "high_school_government_and_politics", - "high_school_macroeconomics", - "high_school_mathematics", - "high_school_microeconomics", - "high_school_physics", - "high_school_psychology", - "high_school_statistics", - "high_school_us_history", - "high_school_world_history", - "human_aging", - "human_sexuality", - "international_law", - "jurisprudence", - "logical_fallacies", - "machine_learning", - "management", - "marketing", - "medical_genetics", - "miscellaneous", - "moral_disputes", - "moral_scenarios", - "nutrition", - "philosophy", - "prehistory", - "professional_accounting", - "professional_law", - "professional_medicine", - "professional_psychology", - "public_relations", - "security_studies", - "sociology", - "us_foreign_policy", - "virology", - "world_religions", -] - - -def create_all_tasks(): - """Creates a dictionary of tasks from a list of subjects - :return: {task_name: task} - e.g. {hendrycksTest-abstract_algebra: Task, hendrycksTest-anatomy: Task} - """ - # return {f"hendrycksTest-{sub}": create_task(sub) for sub in SUBJECTS} - return {f"mmlu_cloze-{sub}": create_task(sub) for sub in SUBJECTS} - - -def create_task(subject): - class HendrycksTest(GeneralHendrycksTest): - def __init__(self): - super().__init__(subject) - - return HendrycksTest - - -class GeneralHendrycksTest(MultipleChoiceTask): - VERSION = 1 - DATASET_PATH = "cais/mmlu" - DATASET_NAME = None - - def __init__(self, subject): - self.DATASET_NAME = subject - super().__init__() - - def has_training_docs(self): - return True - - def has_validation_docs(self): - return True - - def has_test_docs(self): - return True - - def validation_docs(self): - return map(self._process_doc, self.dataset["validation"]) - - def test_docs(self): - return map(self._process_doc, self.dataset["test"]) - - def _format_subject(self, subject): - words = subject.split("_") - return " ".join(words) - - def fewshot_context(self, doc, num_fewshot, **kwargs): - subject = self.DATASET_NAME - # description = f"The following are multiple choice questions (with answers) about {self._format_subject(subject)}." - description = f"The following are questions with answers about {self._format_subject(subject)}." - kwargs["description"] = description - return super().fewshot_context(doc=doc, num_fewshot=num_fewshot, **kwargs) - - def _process_doc(self, doc): - def format_example(doc, keys): - """ - Answer - """ - - question = doc["question"].strip() - # choices = "".join( - # [f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])] - # ) - # prompt = f"{question}\n{choices}Answer:" - prompt = f"{question}" - return prompt - - keys = ["A", "B", "C", "D"] # ignored - - # print(f'query: {format_example(doc, keys)}') - # print(f'doc["choices"]: {doc["choices"]}') - # print(f'doc["answer"]: {doc["answer"]}') - - return { - "query": format_example(doc, keys), - "choices": doc["choices"], - "gold": doc["answer"], - } - - def fewshot_examples(self, k, rnd): - # fewshot_examples is not just sampling from train_docs because dev is - # in the same distribution as val/test but auxiliary_train isn't - if self._fewshot_docs is None: - self._fewshot_docs = list(map(self._process_doc, self.dataset["dev"])) - - # use the unchanged order of the dev set without sampling, - # just as in the original code https://github.com/hendrycks/test/blob/master/evaluate.py#L28 - return self._fewshot_docs[:k] - - def doc_to_text(self, doc): - return doc["query"] - - def should_decontaminate(self): - return True - - def doc_to_decontamination_query(self, doc): - return doc["query"] diff --git a/lm_eval_harness/models.py b/lm_eval_harness/models.py index f9a6d9c..23aee4c 100644 --- a/lm_eval_harness/models.py +++ b/lm_eval_harness/models.py @@ -1,5 +1,5 @@ """ -Inherit from lm-evaluation-harness/lm_eval/models/huggingface.py to load Hedgehog models +Inherit from lm-evaluation-harness/lm_eval/models/huggingface.py to load linearized models """ from lm_eval.models.huggingface import AutoCausalLM from src.model.modeling_llama import LolcatsLlamaForCausalLM as LOLCATS_LLAMA_MODEL_CLASS @@ -9,7 +9,6 @@ from src.model.modeling_mistral import LooooolcatsMistralForCausalLM as LOOOOOLCATS_MISTRAL_MODEL_CLASS from src.model.modeling_llama_sharded import ShardedLolcatsLlamaForCausalLM as SHARDED_LOLCATS_LLAMA_MODEL_CLASS -from src.model.modeling_llama_sharded_roll import ShardedRollLolcatsLlamaForCausalLM as SHARDED_ROLL_LOLCATS_LLAMA_MODEL_CLASS class LolcatsLlamaForCausalLM(AutoCausalLM): @@ -66,22 +65,22 @@ def add_special_tokens(self) -> bool: return False -class ShardedRollLolcatsLlamaForCausalLM(AutoCausalLM): - """ - Wrapper for Llama or Mistral-like autoregressive language model - """ - AUTO_MODEL_CLASS = SHARDED_ROLL_LOLCATS_LLAMA_MODEL_CLASS - @property - def add_special_tokens(self) -> bool: - """Whether to include special tokens in encoded text. This should be - determined by whether or not the model was trained with special tokens. - TODO: Remove these conditionals once HuggingFace supports a way to - check whether or not an arbitrary model was trained with special tokens. - """ - if self._add_special_tokens is not None: - return self._add_special_tokens - else: - return False +# class ShardedRollLolcatsLlamaForCausalLM(AutoCausalLM): +# """ +# Wrapper for Llama or Mistral-like autoregressive language model +# """ +# AUTO_MODEL_CLASS = SHARDED_ROLL_LOLCATS_LLAMA_MODEL_CLASS +# @property +# def add_special_tokens(self) -> bool: +# """Whether to include special tokens in encoded text. This should be +# determined by whether or not the model was trained with special tokens. +# TODO: Remove these conditionals once HuggingFace supports a way to +# check whether or not an arbitrary model was trained with special tokens. +# """ +# if self._add_special_tokens is not None: +# return self._add_special_tokens +# else: +# return False class LooooolcatsLlamaForCausalLM(AutoCausalLM): diff --git a/mmlu_demo_lolcats.py b/mmlu_demo_lolcats.py deleted file mode 100644 index 33aaab2..0000000 --- a/mmlu_demo_lolcats.py +++ /dev/null @@ -1,359 +0,0 @@ -""" -Quick demo of linearized LLM generations - -Example scripts: -``` -python mmlu_demo_lolcats.py \ ---attn_mlp_checkpoint_path ./checkpoints/distill_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=12-lzi=1_distill.pt \ ---finetune_checkpoint_path ./checkpoints/distill_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=12-lzi=1-bs=1-gas=8-nte=2-se=0-re=12_ft.pt \ ---num_shots 5 --split test --seed 0 --num_generations 5 --max_new_tokens 1 - -python mmlu_demo_lolcats.py \ ---attn_mlp_checkpoint_path ./checkpoints/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_long_alpaca_8k_xent0_mse1000_lr1e-2_bs1-m=distill_long_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_long_lora_qkvo_alpaca_clean_8192_bs1-s=0-nte=2-se=0-re=800-scl=1024-ws=64-lzi=1_distill.pt \ ---finetune_checkpoint_path ./checkpoints/distill_long_llama3_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_long_alpaca_8k_xent0_mse1000_lr1e-2_bs1-m=distill_long_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_long_lora_qkvo_alpaca_clean_8192_bs1-s=0-nte=2-se=0-re=800-scl=1024-ws=64-lzi=1-bs=1-gas=1-nte=2-ms=-1-se=0-re=800_ft.pt \ ---num_shots 5 --split test --seed 0 --num_generations 5 --max_new_tokens 1 - -python mmlu_demo_lolcats.py \ ---attn_mlp_checkpoint_path ./checkpoints/distill_llama3_8b_lk_smd_wtk64_fd64_wo/dl-d=no_distill_alpaca_clean-m=distill_llama3_8b_lk_smd_wtk64_fd64_wo-f=finetune_lora_qkvo_alpaca_clean-s=0-nte=2-se=0-re=800-scl=1024-ws=64-lzi=1-nte=2-se=0-re=800_ft.pt \ ---finetune_checkpoint_path ./checkpoints/distill_llama3_8b_lk_smd_wtk64_fd64_wo/dl-d=no_distill_alpaca_clean-m=distill_llama3_8b_lk_smd_wtk64_fd64_wo-f=finetune_lora_qkvo_alpaca_clean-s=0-nte=2-se=0-re=800-scl=1024-ws=64-lzi=1-nte=2-se=0-re=800_ft.pt \ ---num_shots 5 --split test --seed 0 --num_generations 5 --max_new_tokens 1 - - -python mmlu_demo_lolcats.py \ ---attn_mlp_checkpoint_path ./checkpoints/distill_long_llama3_8b_lk_smd_wtk64_fd64_wo/dl-d=no_distill_alpaca_clean-m=distill_long_llama3_8b_lk_smd_wtk64_fd64_wo-f=finetune_long_lora_qkvo_alpaca_clean_8192_bs1-s=0-nte=2-se=0-re=800-scl=1024-ws=64-lzi=1-nte=2-se=0-re=800_ft.pt \ ---finetune_checkpoint_path ./checkpoints/distill_long_llama3_8b_lk_smd_wtk64_fd64_wo/dl-d=no_distill_alpaca_clean-m=distill_long_llama3_8b_lk_smd_wtk64_fd64_wo-f=finetune_long_lora_qkvo_alpaca_clean_8192_bs1-s=0-nte=2-se=0-re=800-scl=1024-ws=64-lzi=1-nte=2-se=0-re=800_ft.pt \ ---num_shots 5 --split test --seed 0 --num_generations 5 --max_new_tokens 1 -``` -""" -from typing import Optional, List -from os.path import join -import time -import argparse -import numpy as np -import torch - -from tqdm import tqdm -from omegaconf import OmegaConf - -from transformers import TextStreamer, TextIteratorStreamer, AutoTokenizer - -from src.utils.setup import seed_everything -from src.utils.logging import print_header -from src.model.pretrained import get_pretrained_loader -from src.model.load_model import load_and_convert_attns, load_and_convert_finetune - - -from llama_recipes.model_checkpointing.distill_checkpoint_handler import ( - load_sharded_model_single_gpu, -) - -from datasets import load_dataset - -CACHE_DIR = '/scr-ssd/mzhang/data/mmlu' - - -class MMLU(): - def __init__(self, num_shots: int, cache_dir: str, split: str = 'dev'): - self.num_shots = num_shots - self.cache_dir = cache_dir - self.split = split - - def _format_subject(self, subject): - words = subject.split("_") - return " ".join(words) - - def get_description(self, subject): - description = f"The following are multiple choice questions (with answers) about {self._format_subject(subject)}." - return description - - def format_example(self, doc, keys): - """ - - A. - B. - C. - D. - Answer: - """ - question = doc["question"].strip() - choices = "".join( - [f"{key}. {choice}\n" for key, choice in zip(keys, doc["choices"])] - ) - prompt = f"{question}\n{choices}Answer:" - answer = keys[doc['answer']] - return prompt, answer - - def load_prompts(self): - ds = load_dataset("cais/mmlu", "all", cache_dir=self.cache_dir) - ds = ds[self.split] - subjects = np.unique(ds['subject']) - all_samples = [] - keys = ["A", "B", "C", "D"] - for subject in tqdm(subjects, desc='processing subjects'): - samples = [x for x in ds if x['subject'] == subject] - # breakpoint() - # Just get 1 sample for each subject - ix = 0 - if len(samples) > self.num_shots + 1: # number in context + final question - prompt = self.get_description(subject) - prompt += '\n\n' - for _ in range(self.num_shots): - question, answer = self.format_example(samples[ix], keys) - prompt += f'{question} {answer}\n\n' - ix += 1 - question, answer = self.format_example(samples[ix], keys) - prompt += f'{question}' - all_samples.append((prompt, answer)) - return all_samples - - -def get_args(): - parser = argparse.ArgumentParser() - # Model load + setup - parser.add_argument("--attn_mlp_checkpoint_path", type=str, default=None) - parser.add_argument("--finetune_checkpoint_path", type=str, default=None) - parser.add_argument("--config_dir", type=str, default='configs') - parser.add_argument("--seed", type=int, default=42) - - parser.add_argument("--num_shots", type=int, default=5) - parser.add_argument("--split", type=str, default='test') - - # Generation - parser.add_argument("--num_generations", type=int, default=1) - parser.add_argument("--top_k", type=int, default=50) - parser.add_argument("--top_p", type=float, default=0.95) - parser.add_argument("--max_new_tokens", type=int, default=1024) - - # Miscellaneous - parser.add_argument("--benchmark", action='store_true', default=False) - parser.add_argument("--print_model", action='store_true', default=False) - parser.add_argument("--debug", action='store_true', default=False) - parser.add_argument("--huggingface_token", type=str, default=None) - - # Alt - parser.add_argument("--attn_checkpoint_path", type=str, default=None) - parser.add_argument("--peft_checkpoint_path", type=str, default=None) - - args = parser.parse_args() - if args.attn_mlp_checkpoint_path is None and args.attn_checkpoint_path is not None: - args.attn_mlp_checkpoint_path = args.attn_checkpoint_path - if args.finetune_checkpoint_path is None and args.peft_checkpoint_path is not None: - args.finetune_checkpoint_path = args.peft_checkpoint_path - return args - - -def get_lm_eval_lolcats_model(model_kwargs: dict, lolcats_model: bool = True): - lm_kwargs = copy.deepcopy(model_kwargs) - lm_kwargs['pretrained'] = lm_kwargs['pretrained_model_name_or_path'] - lm_kwargs['dtype'] = str(lm_kwargs['torch_dtype']).split('.')[-1] - del lm_kwargs['torch_dtype'] - - if 'Llama' in lm_kwargs['pretrained_model_name_or_path']: # and lolcats_model: - lm_kwargs['device_map'] = None - from lm_eval_harness.models import ShardedLolcatsLlamaForCausalLM - lm = ShardedLolcatsLlamaForCausalLM.create_from_arg_string( - '', lm_kwargs, - ) - else: - sys.path.append(LM_EVALUATION_HARNESS_PATH) - from lm_eval.models import get_model - - lm = get_model('hf-causal-experimental').create_from_arg_string( - '', lm_kwargs, - ) - # model = lm.model - return lm - - -def count_params(module) -> int: - return sum(p.numel() for p in module.parameters()) - - -def setup_fsdp_config(config, args, checkpoint_name: str = 'finetune'): - """ - Hacky arguments for llama-recipes training function - """ - config.seed = args.seed - config.enable_fsdp = args.enable_fsdp - config.low_cpu_fsdp = args.low_cpu_fsdp - config.dist_checkpoint_root_folder = args.checkpoint_dir - config.dist_checkpoint_folder = checkpoint_name - - config.model_name = args.run_name - config.use_peft = False # We have custom logic for saving PEFT modules - config.save_model = True - config.run_validation = True - config.use_fp16 = False - config.save_model = True - config.save_optimizer = False - config.output_dir = args.checkpoint_dir - config.save_metrics = not args.no_wandb - config.gradient_clipping = False - config.gradient_clipping_threshold = 1.0 - config.num_epochs = getattr(config.trainer, 'num_train_epochs', None) - config.num_train_steps = getattr(args, 'num_train_steps', None) # exit training loop early for debugging - config.eval_steps = getattr(config.trainer, 'eval_steps', None) # how many gradient updates before evaluating - return config - - -def load_model_from_checkpoint(attn_mlp_checkpoint_path: str, - finetune_checkpoint_path: str, - config_dir: str = 'configs', - print_model: bool = False, - debug: bool = False, - huggingface_token: str = None): - rank = 0 - # Get configs from checkpoint paths - try: - model_config = attn_mlp_checkpoint_path.split('-m=')[-1].split('-f=')[0] - distill_config = attn_mlp_checkpoint_path.split('-d=')[-1].split('-m=')[0] - except Exception as e: - model_config = finetune_checkpoint_path.split('-m=')[-1].split('-f=')[0] - distill_config = None - - model_config = join(config_dir, 'model', f'{model_config}.yaml') - model_config = OmegaConf.load(model_config) - - if distill_config is not None: - distill_config = join(config_dir, 'experiment', f'{distill_config}.yaml') - distill_config = OmegaConf.load(distill_config) - else: - distill_config = {} - - finetune_config = finetune_checkpoint_path.split('-f=')[-1].split('-')[0] - finetune_config = join(config_dir, 'experiment', f'{finetune_config}.yaml') - finetune_config = OmegaConf.load(finetune_config) - - # Load initial model - model_loader = get_pretrained_loader(**model_config.model, - huggingface_token=huggingface_token) - tokenizer = model_loader.load_tokenizer() - tokenizer.pad_token_id = tokenizer.eos_token_id - tokenizer.padding_side = 'left' - - model = model_loader.load(model_config['attention']['attention_type']) - try: - model.state_chunk_len = model_config['attention']['state_chunk_len'] - except: - pass - if debug: - print_header('Pretrained Model') - print(model) - - # Add subquadratic attentions - model, distill_peft_config = load_and_convert_attns(model, model_config, - attention_type=None, # in model_config - checkpoint_path=attn_mlp_checkpoint_path, - print_model=debug, - merge_loras=False, - peft_gradient_checkpointing=False, - train_attention=False) - - # Add PEFT parameters - model, ft_peft_config = load_and_convert_finetune(model, finetune_config, - checkpoint_path=finetune_checkpoint_path, - print_model=debug, - merge_loras=False, - peft_gradient_checkpointing=False) - if print_model: - print_header('*** Model after checkpoint load ***') - print(model) - - return model, model_config, tokenizer - - -def get_model_name(attn_mlp_checkpoint_path: str, finetune_checkpoint_path: str, - model_config: str = None): - model_name = '🦔 ' if attn_mlp_checkpoint_path is not None else '' - if 'llama3_8b_' in finetune_checkpoint_path: - model_name += f'Llama-3-8B' - elif 'llama2_7b_' in finetune_checkpoint_path: - model_name += f'Llama-2-7B' - elif 'mistral_7b_' in finetune_checkpoint_path: - model_name += f'Mistral-7B' - - if attn_mlp_checkpoint_path is not None: - model_name += f'-Hedgehog' - - if 'alpaca_clean' in finetune_checkpoint_path: - model_name += f'-Alpaca' - - elif model_config is not None: - if 'llama3_8b_' in model_config: - model_name += f'Llama-3-8B' - elif 'llama2_7b_' in model_config: - model_name += f'Llama-2-7B' - elif 'mistral_7b_' in model_config: - model_name += f'Mistral-7B' - - return model_name - - -def main(): - args = get_args() - seed_everything(args.seed) - model, model_config, tokenizer = load_model_from_checkpoint( - args.attn_mlp_checkpoint_path, args.finetune_checkpoint_path, - config_dir=args.config_dir, print_model = args.print_model, debug = args.debug, - ) - model.eval() - model_name = get_model_name(args.attn_mlp_checkpoint_path, - args.finetune_checkpoint_path, - model_config) - - # Load data - # datasets = load_mmlu(cache_dir=CACHE_DIR) - mmlu = MMLU(num_shots=args.num_shots, cache_dir=CACHE_DIR, split=args.split) - samples = mmlu.load_prompts() - - total = 0 - correct = 0 - for ix, (prompt, answer) in enumerate(samples): - with torch.no_grad(): - model_input = tokenizer([prompt] * args.num_generations, - return_tensors="pt").to(model.device) - model_output = model.generate(**model_input, use_cache=True, - max_new_tokens=args.max_new_tokens, - do_sample=True, - top_k=args.top_k, - top_p=args.top_p, - num_return_sequences=1, - pad_token_id=tokenizer.eos_token_id) - - outputs = tokenizer.batch_decode(model_output) - print('-' * 10 + f' Sample {ix} ' + '-' * 10) - input_seq_len = model_input['input_ids'].shape[-1] - - for _ix, _output in enumerate(model_output): - if _ix == 0: - decoded_output = tokenizer.decode(_output[:input_seq_len]) - print(decoded_output) - - print('---' + f' Prediction {_ix} ' + '---') - decoded_output = tokenizer.decode(_output[input_seq_len:]) - print(decoded_output) - print('') - print('---' + f' True Answer {ix} ' + '---') - print(f' {answer}') - - # Compute greedy answer - model_input = tokenizer([prompt], return_tensors="pt").to(model.device) - model_output = model.generate(**model_input, use_cache=True, - max_new_tokens=args.max_new_tokens, - do_sample=False, - num_return_sequences=1, - pad_token_id=tokenizer.eos_token_id) - outputs = tokenizer.batch_decode(model_output) - input_seq_len = model_input['input_ids'].shape[-1] - for _ix, _output in enumerate(model_output): - print('---' + f' Greedy Pred {ix} ' + '---') - decoded_output = tokenizer.decode(_output[input_seq_len:]) - print(decoded_output) - print('') - if decoded_output.replace(' ', '').upper() == answer.replace(' ', '').upper(): - correct += 1 - total += 1 - print(f'Final MMLU acc: {correct / total * 100:.4f}% ({correct} / {total})') - -if __name__ == '__main__': - main() diff --git a/src/dataloaders/alpaca_cqa.py b/src/dataloaders/alpaca_cqa.py deleted file mode 100644 index 5a81206..0000000 --- a/src/dataloaders/alpaca_cqa.py +++ /dev/null @@ -1,80 +0,0 @@ -""" -Combined dataloaders for Alpaca and CommonsenseQA -""" -from functools import partial -from os.path import join -import numpy as np - -from torch.utils.data import Dataset -from datasets import load_metric, load_dataset - -from .utils import ( - get_lm_loader, get_seq2seq_loader, - convert_to_hf_dataset, - get_tokenizer_from_config, - download_scrolls_metric as download_metric -) -from .utils.packing import ConcatDataset - -from .alpaca_clean import load_data as load_data_alpaca -from .commonsense_qa import load_data as load_data_cqa - - -class CombinedConcatDataset(Dataset): - def __init__(self, datasets: list[Dataset]): - self.filtered_samples = [] - for dataset in datasets: - self.filtered_samples.extend(dataset.filtered_samples) - - def __getitem__(self, idx): - return self.filtered_samples[idx] - - def __len__(self): - return len(self.filtered_samples) - - -def load_data(name: str, dataset_config: dict, pretrained_model_config: dict, - preprocess_config: dict, **loader_kwargs: any): - """ - Shared function to load dataset from experiment config - -> e.g., see configs/experiments/distill_alpaca_clean_lr1e-2.yaml - """ - input_len = dataset_config['alpaca']['chunk_size'] - dataloaders_alpaca = load_data_alpaca(name='alpaca', - dataset_config=dataset_config['alpaca'], - pretrained_model_config=pretrained_model_config, - preprocess_config=preprocess_config, - **loader_kwargs) - dataloaders_cqa = load_data_cqa(name='commonsense_qa', - dataset_config=dataset_config['commonsense_qa'], - pretrained_model_config=pretrained_model_config, - preprocess_config=preprocess_config, - **loader_kwargs) - - datasets = {} - for split in ['train', 'validation']: # test split is not packed - datasets[split] = CombinedConcatDataset([ - dataloaders_alpaca[split].dataset, dataloaders_cqa[split].dataset - ]) - datasets['test'] = dataloaders_alpaca['test'].dataset - - tokenizer = dataloaders_alpaca[split].dataset.tokenizer # see alpaca_clean.py - - # Get dataloaders - dataloaders = { - 'train': get_lm_loader(datasets['train'], tokenizer, 'train', input_len, **loader_kwargs), - 'validation': get_lm_loader(datasets['validation'], tokenizer, 'validation', input_len, **loader_kwargs), - 'test': get_seq2seq_loader(datasets['test'], tokenizer, 'test', **loader_kwargs), - } - # Evaluation metric - try: - metric = load_metric(download_metric(), 'gov_report') # hack for rouge - except Exception as e: - print(f'Error loading metric: {e}') - metric = None - - # Finishing touches - for k, v in dataloaders.items(): # Make tokenizer accessible - dataloaders[k].dataset.tokenizer = tokenizer - dataloaders[k].dataset.metric = metric - return dataloaders diff --git a/src/dataloaders/commonsense_qa.py b/src/dataloaders/commonsense_qa.py deleted file mode 100644 index abff52a..0000000 --- a/src/dataloaders/commonsense_qa.py +++ /dev/null @@ -1,168 +0,0 @@ -""" -Alpaca training dataloaders - -We adopt the original prompt template; goes something like: -``` -Below is an instruction that describes a task. -Write a response that appropriately completes the request. -### Instruction: -{instruction} - -### Response: -{response} -``` -See `PROMPT_DICT` for more. -""" -from functools import partial -from os.path import join -import numpy as np - -from datasets import load_metric, load_dataset - -from .utils import ( - get_lm_loader, get_seq2seq_loader, - convert_to_hf_dataset, - get_tokenizer_from_config, - download_scrolls_metric as download_metric -) -from .utils.packing import ConcatDataset - - -def load_data(name: str, dataset_config: dict, pretrained_model_config: dict, - preprocess_config: dict, **loader_kwargs: any): - """ - Shared function to load dataset from experiment config - -> e.g., see configs/experiments/distill_alpaca_clean_lr1e-2.yaml - """ - # Misc. setup - cache_dir = dataset_config['cache_dir'] - input_len = dataset_config['chunk_size'] - concat_data = dataset_config['concat_data'] - - tokenizer_name = pretrained_model_config['pretrained_model_name_or_path'] - tokenizer_name = tokenizer_name.split('/')[-1] - # save_path = join(cache_dir, f'{name}_{tokenizer_name}') - - # Setup tokenizer - tokenizer = get_tokenizer_from_config(pretrained_model_config) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}') - - tokenizer.padding_side = 'left' # for decoder-only generation - # Get initial data - loading_kwargs = ['path', 'cache_dir'] - dataset = load_dataset( - **{k: v for k, v in dataset_config.items() if k in loading_kwargs} - ) - - # Preprocess samples into few-shot samples - train_set = process_samples(dataset['train'], **dataset_config) - val_set = process_samples(dataset['validation'], **dataset_config) - train_set = convert_to_hf_dataset(train_set, cache_dir=dataset_config['cache_dir']) - _val_set = convert_to_hf_dataset(val_set, cache_dir=dataset_config['cache_dir']) - - remove_columns = list(train_set.features) - - # Convert to dicts of {input_ids, attention_mask, labels} - train_set = train_set.map( - partial(template_and_tokenize, tokenizer=tokenizer, include_label=True), - remove_columns=remove_columns,) # load_from_cache_file=False) - val_set = _val_set.map( - partial(template_and_tokenize, tokenizer=tokenizer, include_label=True), - remove_columns=remove_columns,) # load_from_cache_file=False) - test_set = _val_set.map( # This isn't the real test set - partial(template_and_tokenize, tokenizer=tokenizer, include_label=False), - remove_columns=remove_columns,) - - del _val_set - - # Chunk together train and val sets - if concat_data: - train_set = ConcatDataset(train_set, chunk_size=input_len) - val_set = ConcatDataset(val_set, chunk_size=input_len) - - # Get dataloaders - dataloaders = { - 'train': get_lm_loader(train_set, tokenizer, 'train', input_len, **loader_kwargs), - 'validation': get_lm_loader(val_set, tokenizer, 'validation', input_len, **loader_kwargs), - 'test': get_seq2seq_loader(test_set, tokenizer, 'test', **loader_kwargs), - } - # Evaluation metric - try: - metric = load_metric(download_metric(), 'qasper') # hack but we want - except Exception as e: - print(f'Error loading metric: {e}') - metric = None - - # Finishing touches - for k, v in dataloaders.items(): # Make tokenizer accessible - dataloaders[k].dataset.tokenizer = tokenizer - dataloaders[k].dataset.metric = metric - return dataloaders - - -def process_samples(hf_dataset_split, num_shots: int, seed: int = 42, **kwargs: any): - """ - Organize original dataset into few-shot sample datasets - """ - fewshot_samples = [] - hf_dataset_split = hf_dataset_split.shuffle(seed=seed) - context_counter = 0 - for i, _sample in enumerate(hf_dataset_split): - if context_counter == 0: # 5, (6, 7, 8, 9, 10) - sample = { - 'context': [_sample], - 'question': [], - } - context_counter += 1 - elif context_counter % num_shots == 0: - sample['question'] = _sample - fewshot_samples.append(sample) - context_counter = 0 - else: - sample['context'].append(_sample) - context_counter += 1 - return fewshot_samples - - -def template_and_tokenize(sample, tokenizer, include_label: bool = True): - """ - Convert samples into text prompt and tokenize - """ - prompt = '' - # Add few-shot examples in context - for ix, _sample in enumerate(sample['context']): - prompt += _sample['question'] - for ix, label in enumerate(_sample['choices']['label']): - text = _sample['choices']['text'] - prompt += f'\n{label}. {text[ix]}' - prompt += f'\nAnswer: {_sample["answerKey"]}\n\n' - # if ix < len(sample['context']) - 1: - # prompt += '\n\n' - # Add question - _sample = sample['question'] - prompt += _sample['question'] - for ix, label in enumerate(_sample['choices']['label']): - text = _sample['choices']['text'] - prompt += f'\n{label}. {text[ix]}' - prompt += f'\nAnswer: ' - prompt = tokenizer.encode(prompt, add_special_tokens=True) - - if include_label: - answer = tokenizer.encode(f'{_sample["answerKey"]}{tokenizer.eos_token}', - add_special_tokens=False) - target = None - else: - answer = [] - target = tokenizer.encode(f'{_sample["answerKey"]}{tokenizer.eos_token}', - add_special_tokens=False) - input_ids = prompt + answer - attn_mask = [1] * len(input_ids) - - sample = { - "input_ids": input_ids, - "attention_mask" : attn_mask, - "labels": [-100] * len(prompt) + answer if include_label else target, - } - return sample diff --git a/src/dataloaders/eval_mmlu.py b/src/dataloaders/eval_mmlu.py deleted file mode 100644 index 2c4728d..0000000 --- a/src/dataloaders/eval_mmlu.py +++ /dev/null @@ -1,209 +0,0 @@ -""" -MMLU Evaluation Dataloader -""" -import sys -import collections -import random -import itertools - -from functools import partial -from os.path import join -import numpy as np - -from datasets import load_metric, load_dataset - -from .utils import ( - get_lm_loader, get_seq2seq_loader, - convert_to_hf_dataset, - get_tokenizer_from_config, - download_scrolls_metric as download_metric -) -from .utils.packing import ConcatDataset - - - -def get_mmlu_samples(task_dict_items: dict, - provide_description: bool = None, - num_fewshot: int = 5, - limit: int = None, - bootstrap_iters: int = 100000, - description_dict: dict = None, - check_integrity: bool = False, - decontamination_ngrams_path: str = None, - write_out: bool = False, - **kwargs: any): - results = collections.defaultdict(dict) - versions = collections.defaultdict(dict) - - requests = collections.defaultdict(list) - requests_origin = collections.defaultdict(list) - - overlaps = collections.defaultdict(list) # {task_name: contaminated_docs} - - docs = {} - write_out_info = {} - - docs_for_decontamination = collections.defaultdict(list) - - decontaminate = decontamination_ngrams_path is not None - - # get lists of each type of request - for task_name, task in task_dict_items: - versions[task_name] = task.VERSION - # default to test doc, fall back to val doc if validation unavailable - # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point - if task.has_test_docs(): - task_doc_func = task.test_docs - task_set = "test" # Required for caching in the decontamination - elif task.has_validation_docs(): - task_set = "val" # Required for caching in the decontamination - task_doc_func = task.validation_docs - else: - raise RuntimeError("Task has neither test_docs nor validation_docs") - - # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order - task_docs = list(task_doc_func()) - rnd = random.Random() - rnd.seed(42) - rnd.shuffle(task_docs) - print(f"Task: {task_name}; number of docs: {len(task_docs)}") - - if write_out: - prompt_details = [] - - description = ( - description_dict[task_name] - if description_dict and task_name in description_dict - else "" - ) - if limit is not None: - limit = int(len(task_docs) * limit) if limit < 1.0 else int(limit) - - for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)): - if decontaminate and task.should_decontaminate(): - docs_for_decontamination[(task_name, task_set)].append( - task.doc_to_decontamination_query(doc) - ) - - docs[(task_name, doc_id)] = doc - ctx = task.fewshot_context( - doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description - ) - reqs = task.construct_requests(doc, ctx) - - if write_out: - prompt_details.append({"doc_id": doc_id}) - - # print the prompt for the first few documents - if doc_id < 1: - print( - f"Task: {task_name}; document {doc_id}; context prompt (starting on next line):\n{ctx}\n(end of prompt on previous line)" - ) - print("Requests:", reqs) - - if not isinstance(reqs, (list, tuple)): - reqs = [reqs] - for i, req in enumerate(reqs): - requests[req.request_type].append(req) - # i: index in requests for a single task instance - # doc_id: unique id that we can get back to a doc using `docs` - requests_origin[req.request_type].append((i, task_name, doc, doc_id)) - - if write_out: - prompt_details[-1][f"prompt_{i}"] = "".join( - (map(lambda x: "".join(x), req.args)) - ) - - if write_out: - write_out_info[task_name] = prompt_details - - # Compare all tasks/sets at once to ensure a single training set scan - if decontaminate: - from lm_eval.decontamination.decontaminate import get_train_overlap - - print("Finding train/test overlap, please wait...") - overlaps = get_train_overlap( - docs_for_decontamination, decontamination_ngrams_path, limit - ) - return requests, requests_origin, docs, versions - - - -def load_data(name: str, dataset_config: dict, pretrained_model_config: dict, - preprocess_config: dict, **loader_kwargs: any): - """ - Shared function to load dataset from experiment config - -> e.g., see configs/experiments/distill_alpaca_clean_lr1e-2.yaml - """ - # Misc. setup - cache_dir = dataset_config['cache_dir'] - - # Tokenizer - tokenizer_name = pretrained_model_config['pretrained_model_name_or_path'] - tokenizer_name = tokenizer_name.split('/')[-1] - # save_path = join(cache_dir, f'{name}_{tokenizer_name}') - - # Setup tokenizer - tokenizer = get_tokenizer_from_config(pretrained_model_config) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}') - - tokenizer.padding_side = 'left' # for decoder-only generation - - # LM EVAL - lm_eval_path = dataset_config['lm_evaluation_harness_path'] # '/juice2/scr2/mzhang/projects/lm-evaluation-harness' - sys.path.append(lm_eval_path) - - # Load tasks - from lm_eval.tasks import get_task_dict, TASK_REGISTRY - if 'tasks' not in dataset_config: - dataset_config['tasks'] = None - tasks = dataset_config['tasks'] - print(f'tasks: {tasks}') - - if tasks is None: - _task = 'hendrycksTest' - tasks = sorted([k for k in TASK_REGISTRY.keys() if f'{_task}-' in k]) - else: - tasks = sorted([k for k in TASK_REGISTRY.keys() if k in tasks]) - task_dict = get_task_dict(tasks) - task_dict_items = [ - (name, task) - for name, task in task_dict.items() - if (task.has_validation_docs() or task.has_test_docs()) - ] - - # Prepare samples - num_fewshot = dataset_config['num_fewshot'] # 5 - limit = dataset_config['limit'] - - _samples = get_mmlu_samples(task_dict_items, num_fewshot=num_fewshot, limit=limit) - - requests, requests_origin, docs, versions = _samples - requests = requests['loglikelihood'] # n-shot samples - requests_origin = requests_origin['loglikelihood'] # Original sample - # (0, 'mmlu-anatomy', {'query': 'Blood flows from the right ventricle of the heart into which of the following structures?\nA. Inferior vena cava\nB. Left ventricle\nC. Pulmonary arteries\nD. Pulmonary veins\nAnswer:', 'choices': ['A', 'B', 'C', 'D'], 'gold': 2}, 1) - - # breakpoint() - # Get samples - samples = [tokenizer(''.join(req.args)) for req in requests] # ['loglikelihood']] - for ix, sample in enumerate(samples): - # sample_idx, category, query_dict, query_idx - samples[ix]['target'] = requests_origin[ix][2]['gold'] - samples[ix]['query_idx'] = requests_origin[ix][-1] - samples[ix]['category'] = tasks.index(requests_origin[ix][1]) - - # for _ix, req in enumerate(requests): - # requests_origin[_ix][2]['text'] = ''.join(req.args) - - dataset = convert_to_hf_dataset(samples, cache_dir=cache_dir) - if 'batch_size' in loader_kwargs: - loader_kwargs['batch_size'] = 4 # for now enforce this - dataloaders = {'eval': get_lm_loader(dataset, tokenizer, 'eval', **loader_kwargs)} - - # Finishing touches - for k, v in dataloaders.items(): # Make tokenizer accessible - dataloaders[k].dataset.tokenizer = tokenizer - dataloaders[k].categories = tasks - return dataloaders['eval'] diff --git a/src/dataloaders/preprocess_rp_contig.py b/src/dataloaders/preprocess_rp_contig.py deleted file mode 100644 index 3f0401a..0000000 --- a/src/dataloaders/preprocess_rp_contig.py +++ /dev/null @@ -1,259 +0,0 @@ -""" -Script to subsample RedPajama subset for long effective context samples - -python src/dataloaders/preprocess_rp_contig.py \ ---model_config base_llama3_8b \ ---distill_config distill_rpcontig2048_dcs1024_xent0_mse1000_lr1e-2 - -python src/dataloaders/preprocess_rp_contig.py \ ---model_config base_llama3_8b \ ---distill_config distill_rpcontig2048_dcs2048_xent0_mse1000_lr1e-2 - -python src/dataloaders/preprocess_rp_contig.py \ ---model_config base_llama3_8b \ ---distill_config distill_rpcontig2048_dcs2048_n10k_xent0_mse1000_lr1e-2 -""" -import os -from os.path import join - -import argparse -from omegaconf import OmegaConf -from tqdm import tqdm - -import numpy as np -import torch -from torch.utils.data import Dataset, DataLoader - -from transformers import DataCollatorForSeq2Seq, DefaultDataCollator - -from src.utils.setup import seed_everything -from src.utils.logging import print_header - -from src.model.pretrained import get_pretrained_loader -# from src.model.load_model import load_and_convert_attns, load_and_convert_finetune -# from src.model.convert_model import toggle_attention - -from src.dataloaders.redpajama_sample import Data, add_eos -from src.dataloaders.utils.packing_contig import ConcatContigDataset -from src.dataloaders.utils import get_tokenizer_from_config - - -def get_lm_loader(dataset: Dataset, shuffle: bool, **loader_kwargs: any): - """ - Get dataloader for language modeling (training) - """ - collate_fn = DefaultDataCollator(return_tensors='pt') - return DataLoader(dataset, shuffle=shuffle, - collate_fn=collate_fn, **loader_kwargs) - - -def get_args(): - """Parse command line arguments""" - parser = argparse.ArgumentParser() - parser.add_argument("--model_config", type=str, default='base_llama3_8b') - parser.add_argument("--distill_config", type=str, - default='distill_rp_contig_xent0_mse1000_lr1e-2') - parser.add_argument("--config_dir", type=str, default='./configs') - parser.add_argument("--seed", type=int, default=42) - args = parser.parse_args() - return args - - -def load_model(config_dir: str = './configs', - model_config: str = 'base_llama3_8b'): - """ - Load pretrained LLM (default is Llama 3 8B) - """ - model_config = join(config_dir, 'model', f'{model_config}.yaml') - model_config = OmegaConf.load(model_config) - model_config['model']['attn_implementation'] = 'eager' # for attentions - - # Load initial model - model_loader = get_pretrained_loader(**model_config.model) - tokenizer = model_loader.load_tokenizer() - tokenizer.pad_token_id = tokenizer.eos_token_id - tokenizer.padding_side = 'left' - - model = model_loader.load(model_config['attention']['attention_type']) - return model, model_config, tokenizer - - -def load_data(name: str, dataset_config: dict, pretrained_model_config: dict, - preprocess_config: dict, **loader_kwargs: any): - """ - Initial data load for processing RedPajama contiguous packed samples - """ - dataset_config = OmegaConf.to_object(dataset_config) - - tokenizer_name = pretrained_model_config['pretrained_model_name_or_path'] - tokenizer_name = tokenizer_name.split('/')[-1] - # save_path = join(cache_dir, f'{name}_{tokenizer_name}') - - # Setup tokenizer - tokenizer = get_tokenizer_from_config(pretrained_model_config) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}') - - tokenizer.padding_side = 'left' # for decoder-only generation - # ^ But does this impact impact attention sink stuff? - - if 'load_from_cache_file' not in dataset_config: - dataset_config['load_from_cache_file'] = None - - # Get initial data - train_set = Data.prepare_train_data( - dataset_config['train_data'], - tokenizer=tokenizer, - max_length=dataset_config['max_length'], - min_length=dataset_config['min_length'], - chat_template=dataset_config['chat_template'], - seed=dataset_config['seed'], - cache_dir=dataset_config['cache_dir'], - load_from_cache_file=dataset_config['load_from_cache_file'] - ) - val_set = Data.prepare_eval_data( - dataset_config['eval_data'], - tokenizer=tokenizer, - max_length=dataset_config['max_length'], - min_length=dataset_config['min_length'], - chat_template=dataset_config['chat_template'], - seed=dataset_config['seed'], - cache_dir=dataset_config['cache_dir'], - max_eval_num=dataset_config['max_eval_num'], - load_from_cache_file=dataset_config['load_from_cache_file'] - ) - - train_set = ConcatContigDataset(train_set, chunk_size=dataset_config['chunk_size']) - val_set = ConcatContigDataset(val_set, chunk_size=dataset_config['chunk_size']) - - dataloaders = { - 'train': get_lm_loader(train_set, shuffle=False, **loader_kwargs), - 'validation': get_lm_loader(val_set, shuffle=False, **loader_kwargs), - } - return dataloaders - - -def compute_effective_seq_lens(model, dataloader, max_samples: int = None): - """Compute effective sequence length for each sample""" - effective_seq_len_by_layer_by_head = [[] for _ in range(len(model.model.layers))] - batch_idx = 0 # always batch size 1 - with torch.no_grad(): - for ix, data in enumerate(tqdm(dataloader)): - inputs = {k: v.to(model.device) for k, v in data.items() if k != 'labels'} - outputs = model(**inputs, output_attentions=True, use_cache=False, - output_hidden_states=False) - outputs = outputs.get('attentions') - a_true_by_layer = outputs - for layer_idx in range(len(a_true_by_layer)): - # Effective seq len - _true_attns = a_true_by_layer[layer_idx][batch_idx] # num_heads x seq_len x seq_len - positions = torch.arange(_true_attns.shape[-1], device=_true_attns.device) - distances = positions[:, None] - positions[None, :] # "outer diff" - effective_seq_len = (distances * _true_attns).sum(dim=-1).cpu() - effective_seq_len_by_layer_by_head[layer_idx].append(effective_seq_len) - del a_true_by_layer; del positions; del distances - if max_samples is not None: - if ix + 1 == max_samples: - break - - esl = torch.stack([ - torch.stack(effective_seq_len_by_layer_by_head[_idx]) - for _idx in range(len(effective_seq_len_by_layer_by_head)) - ]) - esl = esl.transpose(2, 1) # num_layers x num_heads x num_samples x seq_len - return esl - - -def main(): - args = get_args() - seed_everything(args.seed) - - # Load base model - model, model_config, tokenizer = load_model(args.config_dir, args.model_config) - model.eval() - - # Load data - ## Configs - distill_config_path = join(args.config_dir, 'experiment', - f'{args.distill_config}.yaml') - distill_config = OmegaConf.load(distill_config_path) - # Update data tokenizer to match model - for k in ['pretrained_model_name_or_path', 'cache_dir']: - distill_config.dataset.pretrained_model_config[k] = model_config.model[k] - - if 'num_train_samples' in distill_config.dataset.dataset_config: - num_train_samples = distill_config.dataset.dataset_config['num_train_samples'] - else: - raise NotImplementedError("Please include num_train_samples in under experiment_config.dataset.dataset_config") - - num_train_samples = distill_config.dataset.dataset_config['num_train_samples'] - max_train_samples = distill_config.dataset.dataset_config['max_train_samples'] - max_length = distill_config.dataset.dataset_config['max_length'] - min_length = distill_config.dataset.dataset_config['min_length'] - chunk_size = distill_config.dataset.dataset_config['chunk_size'] - seed = distill_config.dataset.dataset_config['seed'] - - ## Dataloaders - dataloaders = load_data(**distill_config.dataset, **distill_config.dataloader) - train_loader = dataloaders[distill_config.trainer.train_split] - eval_loader = dataloaders[distill_config.trainer.val_split] - - # Compute effective sequence lengths - # -> each is shape: num_layers x num_heads x num_samples x seq_len - train_esl = compute_effective_seq_lens(model, train_loader, max_train_samples) - # eval_esl = compute_effective_seq_lens(model, eval_loader) - - # Save indices to generated filename - _data_attr = distill_config['dataset']['dataset_config']['train_data'] - _data_attr = '-d='.join(_data_attr).replace('/', '_').replace('.json', '') - _data_attr = _data_attr.replace('[','_').replace(']','') - - dataset_config = distill_config.dataset.dataset_config - - # fname = f'd={_data_attr}-nts={num_train_samples}-mts={max_train_samples}-dcs={chunk_size}-max={max_length}-min={min_length}-s={seed}' - fname = f'd={_data_attr}-mts={max_train_samples}-dcs={chunk_size}-max={max_length}-min={min_length}-s={seed}' - fname = join(dataset_config['dataloaders_dir'], 'redpajama_sample_indices', fname) - - # Rank samples by effective sequence length - _train_esl = train_esl.mean(0).mean(0).mean(-1) # num_samples - sorted_idx = torch.argsort(_train_esl, dim=-1, descending=True) - # Save indices to generated filename - np.save(f'{fname}.npy', sorted_idx) - print(f'-> Top {num_train_samples} saved to {fname}!') - - for window in [1, 2, 4, 8, 16, 32, 64, 128]: - _train_esl = train_esl[..., -window:].mean(0).mean(0).mean(-1) # num_samples - sorted_idx = torch.argsort(_train_esl, dim=-1, descending=True) - # sample_idx = sorted_idx[:num_train_samples].numpy() - # Save indices to generated filename - try: - _fname = f'{fname}_l{window:03d}.npy' - np.save(_fname, sorted_idx) - print(f'-> Samples saved to {_fname}!') - - # Also save top samples - sample_idx = sorted_idx[:num_train_samples].numpy() - _fname = f'{fname}-nts={num_train_samples}_l{window:03d}.npy' - np.save(_fname, sample_idx) # sorted_idx) - print(f'-> Top {num_train_samples} saved to {_fname}!') - except: - sample_idx = sorted_idx[:num_train_samples].numpy() - _fname = f'{fname}-nts={num_train_samples}_l{window:03d}.npy' - np.save(_fname, sample_idx) # sorted_idx) - print(f'-> Top {num_train_samples} saved to {_fname}!') - - # sample_idx = sorted_idx[:num_train_samples].numpy() - # train_set.filtered_samples = [train_set.filtered_samples[ix] for ix in sample_idx] - - # # Rank samples by effective sequence length - # train_esl = train_esl.mean(0).mean(0).mean(-1) # num_samples - # sorted_idx = torch.argsort(train_esl, dim=-1, descending=True) - # sample_idx = sorted_idx[:num_train_samples].numpy() - - # np.save(f'{fname}.npy', sample_idx) - # print(f'Top {num_train_samples} saved to {fname}!') - - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/src/dataloaders/redpajama_sample.py b/src/dataloaders/redpajama_sample.py deleted file mode 100644 index 502deb2..0000000 --- a/src/dataloaders/redpajama_sample.py +++ /dev/null @@ -1,252 +0,0 @@ -""" -Data from https://github.com/FlagOpen/FlagEmbedding/blob/master/Long_LLM/longllm_qlora/src/data.py --> Modifying code from above too -""" -from typing import Optional, List, Dict, Any, Mapping, Iterable, Union -from functools import partial -from os.path import join - -from omegaconf import OmegaConf -from tqdm import tqdm - -import os -import re -import random -import numpy as np -import torch -from torch.utils.data import Dataset, DataLoader - -import datasets -from transformers import AutoTokenizer -from transformers import DataCollatorForSeq2Seq, DefaultDataCollator, DataCollatorWithPadding -from transformers.utils import logging - -from .utils import get_tokenizer_from_config -from .utils.packing import ConcatDataset - -logger = logging.get_logger(__name__) - - -def get_lm_loader(dataset: Dataset, tokenizer: AutoTokenizer, - split: str, max_length: int = None, **loader_kwargs: any): - """ - Get dataloader for language modeling (training) - """ - collate_fn = DefaultDataCollator(return_tensors='pt') - return DataLoader(dataset, shuffle='train' in split, - collate_fn=collate_fn, **loader_kwargs) - - -def add_eos(inputs: Mapping, eos_token_id: int): - """Add eos for BatchEncoding object.""" - assert isinstance(inputs["input_ids"], list), f"Make sure the return_tensors are set to list!" - if inputs["input_ids"][-1] != eos_token_id: - for k, v in inputs.items(): - if k in ["input_ids", "labels"]: - v = v + [eos_token_id] - elif k == "attention_mask": - v = v + [1] - elif k == "position_ids": - v = v + [v[-1] + 1] - elif k == "token_type_ids": - v = v + v[-1:] - else: - raise NotImplementedError(f"Inputs key {k} not implemented!") - inputs[k] = v - return inputs - - -def load_data(name: str, dataset_config: dict, pretrained_model_config: dict, - preprocess_config: dict, **loader_kwargs: any): - dataset_config = OmegaConf.to_object(dataset_config) - - tokenizer_name = pretrained_model_config['pretrained_model_name_or_path'] - tokenizer_name = tokenizer_name.split('/')[-1] - # save_path = join(cache_dir, f'{name}_{tokenizer_name}') - - # Setup tokenizer - tokenizer = get_tokenizer_from_config(pretrained_model_config) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}') - - tokenizer.padding_side = 'left' # for decoder-only generation - # ^ But does this impact impact attention sink stuff? - - if 'load_from_cache_file' not in dataset_config: - dataset_config['load_from_cache_file'] = None - - # Get initial data - train_set = Data.prepare_train_data( - dataset_config['train_data'], - tokenizer=tokenizer, - max_length=dataset_config['max_length'], - min_length=dataset_config['min_length'], - chat_template=dataset_config['chat_template'], - seed=dataset_config['seed'], - cache_dir=dataset_config['cache_dir'], - load_from_cache_file=dataset_config['load_from_cache_file'] - ) - val_set = Data.prepare_eval_data( - dataset_config['eval_data'], - tokenizer=tokenizer, - max_length=dataset_config['max_length'], - min_length=dataset_config['min_length'], - chat_template=dataset_config['chat_template'], - seed=dataset_config['seed'], - cache_dir=dataset_config['cache_dir'], - max_eval_num=dataset_config['max_eval_num'], - load_from_cache_file=dataset_config['load_from_cache_file'] - ) - - train_set = ConcatDataset(train_set, chunk_size=dataset_config['chunk_size']) - val_set = ConcatDataset(val_set, chunk_size=dataset_config['chunk_size']) - - # Get dataloaders - dataloaders = { - 'train': get_lm_loader(train_set, tokenizer, 'train', **loader_kwargs), - 'validation': get_lm_loader(val_set, tokenizer, 'validation', **loader_kwargs), - } - - # Finishing touches - for k, v in dataloaders.items(): # Make tokenizer accessible - dataloaders[k].dataset.tokenizer = tokenizer - # dataloaders[k].dataset.metric = metric - return dataloaders - - -class Data: - def _process_language_modeling(data, indices, tokenizer, min_length, max_length): - outputs = {'input_ids': [], 'attention_mask': [], "labels": []} # , "length": [], "index": []} - - for i, text in enumerate(data['text']): - # truncate text for faster processing - encoded = tokenizer(text) - if len(encoded["input_ids"]) < min_length: - continue - elif len(encoded['input_ids']) < max_length: - encoded = add_eos(encoded, tokenizer.eos_token_id) - else: - for k, v in encoded.items(): - encoded[k] = v[:max_length] - - encoded["labels"] = encoded["input_ids"].copy() - - for k, v in encoded.items(): - outputs[k].append(v) - return outputs - - def prepare_train_data(data_files=None, tokenizer=None, max_length=4096, - min_length=512, chat_template="llama-3", max_sample_num=None, - seed=42, cache_dir=None, load_from_cache_file=None): - if data_files is None: - return None - - if isinstance(data_files, list): - logger.info(f"Loading training data from {data_files}...") - elif isinstance(data_files, str): - logger.info(f"Loading training data from {data_files}...") - data_files = [data_files] - else: - raise ValueError(f"Invalid training data {data_files}!") - - data_2_num_sample = {} - for data_file in data_files: - match = re.search("\[(\d*)\]", data_file) - if match: - max_sample_num = int(match.group(1)) - data_file = re.sub("\[(\d*)\]", "", data_file) - else: - max_sample_num = None - data_2_num_sample[data_file] = max_sample_num - - random.seed(seed) - - train_datasets = [] - for data_file, max_sample_num in data_2_num_sample.items(): - - if os.path.isdir(data_file) and os.path.exists(os.path.join(data_file, "dataset_info.json")): - # the dataset may be save_to_disk in advance - dataset = datasets.load_from_disk(data_file) - - else: - # the dataset is a json file - data_file = os.path.join('/scr-ssd/mzhang/data/long-llm/long-llm/', data_file) - cache_dir = '/'.join(data_file.split('/')[:-1]) - print('cache_dir', cache_dir) - dataset = datasets.load_dataset('json', data_files=data_file, split='train', cache_dir=cache_dir) - - column_names = dataset.column_names - if "text" in column_names: - process_fn = partial( - Data._process_language_modeling, - tokenizer=tokenizer, - min_length=min_length, - max_length=max_length - ) - elif "conversations" in column_names: - process_fn = partial( - Data._process_instruction_tuning, - tokenizer=tokenizer, - chat_template=chat_template, - min_length=min_length, - max_length=max_length - ) - else: - raise ValueError(f"Found neither 'text' nor 'conversations' in the training data!") - - dataset = dataset.map(process_fn, batched=True, num_proc=32, remove_columns=dataset.column_names, batch_size=32, with_indices=True, load_from_cache_file=load_from_cache_file) - - if max_sample_num is not None and len(dataset) > max_sample_num: - dataset = dataset.train_test_split(max_sample_num, seed=seed)["test"] - - # index column is useless in training - if "index" in dataset.column_names: - dataset = dataset.remove_columns(["index"]) - - train_datasets.append(dataset) - - dataset = datasets.concatenate_datasets(train_datasets) - - return dataset - - def prepare_eval_data(data_files=None, tokenizer=None, max_length=4096, - min_length=512, chat_template="llama-3", max_eval_num=None, - cache_dir=None, seed=42, load_from_cache_file=None): - if data_files is None: - return None - - random.seed(seed) - - data_files = os.path.join('/scr-ssd/mzhang/data/long-llm/long-llm/', data_files[0]) - cache_dir = '/'.join(data_files.split('/')[:-1]) - print('cache_dir', cache_dir) - - if max_eval_num is not None: - dataset = datasets.load_dataset('json', data_files=data_files, split=f'train[{-max_eval_num}:]', cache_dir=cache_dir) - else: - dataset = datasets.load_dataset('json', data_files=data_files, split='train', cache_dir=cache_dir) - - column_names = dataset.column_names - if "text" in column_names: - process_fn = partial( - Data._process_language_modeling, - tokenizer=tokenizer, - min_length=min_length, - max_length=max_length - ) - elif "conversations" in column_names: - process_fn = partial( - Data._process_instruction_tuning, - tokenizer=tokenizer, - chat_template=chat_template, - min_length=min_length, - max_length=max_length, - eval_mode=True, - ) - else: - raise ValueError(f"Found neither 'text' nor 'conversations' in the training data!") - - dataset = dataset.map(process_fn, batched=True, num_proc=32, remove_columns=dataset.column_names, with_indices=True, - load_from_cache_file=load_from_cache_file) - return dataset diff --git a/src/dataloaders/redpajama_sample_contig.py b/src/dataloaders/redpajama_sample_contig.py deleted file mode 100644 index 1a4202e..0000000 --- a/src/dataloaders/redpajama_sample_contig.py +++ /dev/null @@ -1,236 +0,0 @@ -""" -Data from https://github.com/FlagOpen/FlagEmbedding/blob/master/Long_LLM/longllm_qlora/src/data.py --> Modifying code from above too -""" -from typing import Optional, List, Dict, Any, Mapping, Iterable, Union -from functools import partial -from os.path import join - -from omegaconf import OmegaConf -from tqdm import tqdm - -import os -import re -import random -import numpy as np -import torch -from torch.utils.data import Dataset, DataLoader - -import datasets -from transformers import AutoTokenizer -from transformers import DataCollatorForSeq2Seq, DefaultDataCollator, DataCollatorWithPadding -from transformers.utils import logging - -from .utils import get_tokenizer_from_config -from .utils.packing_contig import ConcatContigDataset -from .redpajama_sample import Data - -logger = logging.get_logger(__name__) - -# Models for computing effective sequence length -from src.model.pretrained import get_pretrained_loader -from src.dataloaders.utils import get_tokenizer_from_config, convert_to_hf_dataset - - -def get_lm_loader(dataset: Dataset, shuffle: bool, **loader_kwargs: any): - """ - Get dataloader for language modeling (training) - """ - collate_fn = DefaultDataCollator(return_tensors='pt') - return DataLoader(dataset, shuffle=shuffle, - collate_fn=collate_fn, **loader_kwargs) - - -def load_data(name: str, dataset_config: dict, pretrained_model_config: dict, - preprocess_config: dict, **loader_kwargs: any): - dataset_config = OmegaConf.to_object(dataset_config) - - if 'num_train_samples' in dataset_config: - num_train_samples = dataset_config['num_train_samples'] - else: - raise NotImplementedError("Please include num_train_samples in under experiment_config.dataset.dataset_config") - max_train_samples = dataset_config['max_train_samples'] - max_length = dataset_config['max_length'] - min_length = dataset_config['min_length'] - chunk_size = dataset_config['chunk_size'] - seed = dataset_config['seed'] - - tokenizer_name = pretrained_model_config['pretrained_model_name_or_path'] - tokenizer_name = tokenizer_name.split('/')[-1] - # save_path = join(cache_dir, f'{name}_{tokenizer_name}') - - # Setup tokenizer - tokenizer = get_tokenizer_from_config(pretrained_model_config) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - print(f'Setting tokenizer.pad_token to {tokenizer.pad_token}') - - tokenizer.padding_side = 'left' # for decoder-only generation - # ^ But does this impact impact attention sink stuff? - - if 'load_from_cache_file' not in dataset_config: - dataset_config['load_from_cache_file'] = None - - # Get initial data - train_set = Data.prepare_train_data( - dataset_config['train_data'], - tokenizer=tokenizer, - max_length=dataset_config['max_length'], - min_length=dataset_config['min_length'], - chat_template=dataset_config['chat_template'], - seed=dataset_config['seed'], - cache_dir=dataset_config['cache_dir'], - load_from_cache_file=dataset_config['load_from_cache_file'] - ) - val_set = Data.prepare_eval_data( - dataset_config['eval_data'], - tokenizer=tokenizer, - max_length=dataset_config['max_length'], - min_length=dataset_config['min_length'], - chat_template=dataset_config['chat_template'], - seed=dataset_config['seed'], - cache_dir=dataset_config['cache_dir'], - max_eval_num=dataset_config['max_eval_num'], - load_from_cache_file=dataset_config['load_from_cache_file'] - ) - - train_set = ConcatContigDataset(train_set, chunk_size=dataset_config['chunk_size']) - val_set = ConcatContigDataset(val_set, chunk_size=dataset_config['chunk_size']) - - if 'filter_by_esl' not in dataset_config: - dataset_config['filter_by_esl'] = False - - if 'filter_window' not in dataset_config: - dataset_config['filter_window'] = 0 - window = dataset_config['filter_window'] - - if dataset_config['filter_by_esl']: - cache_dir = dataset_config['cache_dir'] - # train_set = convert_to_hf_dataset(train_set, cache_dir) - # val_set = convert_to_hf_dataset(val_set, cache_dir) - - # Filter train_set for largest effective context length - # -> Get precomputed topk samples - _data_attr = dataset_config['train_data'] - _data_attr = '-d='.join(_data_attr).replace('/', '_').replace('.json', '') - _data_attr = _data_attr.replace('[','_').replace(']','') - - # fname = f'd={_data_attr}-nts={num_train_samples}-mts={max_train_samples}-dcs={chunk_size}-max={max_length}-min={min_length}-s={seed}' - - - try: - fname = f'd={_data_attr}-mts={max_train_samples}-dcs={chunk_size}-max={max_length}-min={min_length}-s={seed}' - fname = join(dataset_config['dataloaders_dir'], 'redpajama_sample_indices', fname) - if dataset_config['filter_window'] > 0: - sorted_idx = np.load(f'{fname}_l{window:03d}.npy') - else: - sorted_idx = np.load(f'{fname}.npy') - sample_idx = sorted_idx[:num_train_samples] # .numpy() - train_set.filtered_samples = [train_set.filtered_samples[ix] for ix in sample_idx] - print(f'-> Top {num_train_samples} indices loaded from {fname}!') - except Exception as e: - print(e) - print(f'-> Error with loading from {fname}.npy. Computing...') - model, model_config, tokenizer = load_model(config_dir='/scr-ssd/mzhang/projects/lolcats/configs', - model_config=dataset_config['esl_model_config']) - model.eval() - # Compute effective sequence lengths - # -> each is shape: num_layers x num_heads x num_samples x seq_len - train_loader = get_lm_loader(train_set, shuffle=False, **loader_kwargs) - train_esl = compute_effective_seq_lens(model, train_loader, - max_samples=max_train_samples) - # Rank samples by effective sequence length - _train_esl = train_esl.mean(0).mean(0).mean(-1) # num_samples - sorted_idx = torch.argsort(_train_esl, dim=-1, descending=True) - # Save indices to generated filename - fname = f'd={_data_attr}-mts={max_train_samples}-dcs={chunk_size}-max={max_length}-min={min_length}-s={seed}' - fname = join(dataset_config['dataloaders_dir'], 'redpajama_sample_indices', fname) - np.save(f'{fname}.npy', sorted_idx) - print(f'-> Top {num_train_samples} saved to {fname}!') - - # Also sort by computing sequence lengths over last window tokens - for window in [1, 2, 4, 8, 16, 32, 64, 128]: - _train_esl = train_esl[..., -window:].mean(0).mean(0).mean(-1) # num_samples - sorted_idx = torch.argsort(_train_esl, dim=-1, descending=True) - # Save indices to generated filename - try: - _fname = f'{fname}_l{window:03d}.npy' - np.save(_fname, sorted_idx) - print(f'-> Samples saved to {_fname}!') - - # Also save top samples - sample_idx = sorted_idx[:num_train_samples].numpy() - _fname = f'{fname}-nts={num_train_samples}_l{window:03d}.npy' - np.save(_fname, sample_idx) # sorted_idx) - print(f'-> Top {num_train_samples} saved to {_fname}!') - except: - sample_idx = sorted_idx[:num_train_samples].numpy() - _fname = f'{fname}-nts={num_train_samples}_l{window:03d}.npy' - np.save(_fname, sample_idx) # sorted_idx) - print(f'-> Top {num_train_samples} saved to {_fname}!') - - sample_idx = sorted_idx[:num_train_samples].numpy() - train_set.filtered_samples = [train_set.filtered_samples[ix] for ix in sample_idx] - - # Get dataloaders - dataloaders = { - 'train': get_lm_loader(train_set, shuffle=True, **loader_kwargs), - 'validation': get_lm_loader(val_set, shuffle=False, **loader_kwargs), - } - - # Finishing touches - for k, v in dataloaders.items(): # Make tokenizer accessible - dataloaders[k].dataset.tokenizer = tokenizer - return dataloaders - - -# Helpers for computing longest effective sequence length samples - -def load_model(config_dir: str = './configs', - model_config: str = 'base_llama3_8b'): - """ - Load pretrained LLM (default is Llama 3 8B) - """ - model_config = join(config_dir, 'model', f'{model_config}.yaml') - model_config = OmegaConf.load(model_config) - model_config['model']['attn_implementation'] = 'eager' # for attentions - - # Load initial model - model_loader = get_pretrained_loader(**model_config.model) - tokenizer = model_loader.load_tokenizer() - tokenizer.pad_token_id = tokenizer.eos_token_id - tokenizer.padding_side = 'left' - - model = model_loader.load(model_config['attention']['attention_type']) - return model, model_config, tokenizer - - -def compute_effective_seq_lens(model, dataloader, max_samples: int = None): - """Compute effective sequence length for each sample""" - effective_seq_len_by_layer_by_head = [[] for _ in range(len(model.model.layers))] - batch_idx = 0 # always batch size 1 - with torch.no_grad(): - for ix, data in enumerate(tqdm(dataloader, desc='Computing effective sequence lengths')): - inputs = {k: v.to(model.device) for k, v in data.items() if k != 'labels'} - outputs = model(**inputs, output_attentions=True, use_cache=False, - output_hidden_states=False) - outputs = outputs.get('attentions') - a_true_by_layer = outputs - for layer_idx in range(len(a_true_by_layer)): - # Effective seq len - _true_attns = a_true_by_layer[layer_idx][batch_idx] # num_heads x seq_len x seq_len - positions = torch.arange(_true_attns.shape[-1], device=_true_attns.device) - distances = positions[:, None] - positions[None, :] # "outer diff" - effective_seq_len = (distances * _true_attns).sum(dim=-1).cpu() - effective_seq_len_by_layer_by_head[layer_idx].append(effective_seq_len) - del a_true_by_layer; del positions; del distances - if max_samples is not None: - if ix + 1 == max_samples: - break - - esl = torch.stack([ - torch.stack(effective_seq_len_by_layer_by_head[_idx]) - for _idx in range(len(effective_seq_len_by_layer_by_head)) - ]) - esl = esl.transpose(2, 1) # num_layers x num_heads x num_samples x seq_len - return esl diff --git a/src/dataloaders/utils/packing_contig.py b/src/dataloaders/utils/packing_contig.py deleted file mode 100644 index f65fe91..0000000 --- a/src/dataloaders/utils/packing_contig.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -""" -Copied from https://github.com/meta-llama/llama-recipes/blob/9b3dabcaac78980eae40005bbc8b1a8276c82af3/src/llama_recipes/data/concatenator.py#L1 -""" -import random -from itertools import chain -from tqdm import tqdm - - -from torch.utils.data import Dataset - - -class ConcatContigDataset(Dataset): - """ - Concatenates or packs samples of a dataset into chunks of size `chunk_size` - -> But keep concatenated samples sourced within the same dataset - """ - def __init__(self, dataset, chunk_size: int = 1024, seed: int = 42,) -> None: - self.dataset = dataset - self.chunk_size = chunk_size - self.samples = [] - random.seed(seed) - for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True): - buffer = { - "input_ids": [], - "attention_mask": [], - "labels": [], - } - buffer = {k: v + sample[k] for k,v in buffer.items()} - - while len(next(iter(buffer.values()))) > self.chunk_size: - self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()}) - buffer = {k: v[self.chunk_size:] for k,v in buffer.items()} - - # Slow hack, but filter out any samples without valid labels (all -100) - self.filtered_samples = [] - for s in self.samples: - if sum(s['labels']) != chunk_size * -100: - self.filtered_samples.append(s) - if len(self.filtered_samples) < len(self.samples): - print(f'OG dataset: {len(self.samples)} samples -> Filtered dataset: {len(self.filtered_samples)}') - print(f'-> Filtered out {len(self.samples) - len(self.filtered_samples)} samples') - - def __getitem__(self, idx): - return self.filtered_samples[idx] - - def __len__(self): - return len(self.filtered_samples) diff --git a/src/model/convert_model.py b/src/model/convert_model.py index d3fcfbc..dbc0d93 100644 --- a/src/model/convert_model.py +++ b/src/model/convert_model.py @@ -109,45 +109,22 @@ def get_attention(attention_type: str, **kwargs: any): from .linear_attention import LolcatsTKWindowAttention return partial(LolcatsTKWindowAttention, **kwargs) - elif attention_type == 'lolcats_long_llama_window_tk': - from .linear_attention import LolcatsTKWindowLongAttention - return partial(LolcatsTKWindowLongAttention, **kwargs) - - elif attention_type == 'lolcats_llama_window_tk_bf16': - from .linear_attention import LolcatsTKWindowAttentionBF16 - return partial(LolcatsTKWindowAttentionBF16, **kwargs) - - elif attention_type == 'lolcats_llama_window_tk_fa2': - from .linear_attention.linear_window_attention_tk_fa2 import LolcatsTKWindowAttentionFA2 - return partial(LolcatsTKWindowAttentionFA2, **kwargs) - - elif attention_type == 'lolcats_llama_window_tk_sdpa': - from .linear_attention.linear_window_attention_tk_sdpa import LolcatsTKWindowAttentionSDPA - return partial(LolcatsTKWindowAttentionSDPA, **kwargs) - elif attention_type == 'lolcats_llama_window_sw': from .linear_attention import LolcatsSlidingWindowAttention return partial(LolcatsSlidingWindowAttention, **kwargs) - elif attention_type == 'lolcats_long_llama_window_sw': - from .linear_attention import LolcatsSlidingWindowLongAttention - return partial(LolcatsSlidingWindowLongAttention, **kwargs) - - elif attention_type == 'lolcats_llama_window_sw_scale': - from .linear_attention.linear_window_attention_sw_scale import LolcatsSlidingWindowAttention - return partial(LolcatsSlidingWindowAttention, **kwargs) - elif attention_type == 'lolcats_llama_window_sw_linear': from .linear_attention.linear_window_attention_sw_linear import LolcatsLinearSlidingWindowAttention return partial(LolcatsLinearSlidingWindowAttention, **kwargs) - # elif attention_type == 'lolcats_llama_window_sw_tiled': - # from .linear_attention.linear_window_attention_sw_tiled import LolcatsTiledSlidingWindowAttention - # return partial(LolcatsTiledSlidingWindowAttention, **kwargs) + ## Experimental chunked linear attentions below + elif attention_type == 'lolcats_long_llama_window_tk': + from .linear_attention import LolcatsTKWindowLongAttention + return partial(LolcatsTKWindowLongAttention, **kwargs) - elif attention_type == 'supra': - from .linear_attention.supra_attention import SUPRALinearAttention - return partial(SUPRALinearAttention, **kwargs) + elif attention_type == 'lolcats_long_llama_window_sw': + from .linear_attention import LolcatsSlidingWindowLongAttention + return partial(LolcatsSlidingWindowLongAttention, **kwargs) else: print(f'-> attention_type {attention_type} not handled... returning None') @@ -169,10 +146,6 @@ def get_attention_cache(attention_type: str, past_key_values: any = None): elif 'llama_window_sw' in attention_type: from .linear_attention import LinearAttentionSlidingWindowCache return LinearAttentionSlidingWindowCache() - - elif 'llama_window_sw_scale' in attention_type: - from .linear_attention.linear_window_attention_sw_scale import LinearAttentionSlidingWindowCache - return LinearAttentionSlidingWindowCache() elif 'llama_window_sw_linear' in attention_type: from .linear_attention import LinearAttentionSlidingWindowCache diff --git a/src/model/feature_map.py b/src/model/feature_map.py index e1ca4e7..fae1064 100644 --- a/src/model/feature_map.py +++ b/src/model/feature_map.py @@ -197,10 +197,8 @@ def __init__(self, if self.normal_init: with torch.no_grad(): nn.init.normal_(self.layer) - # print('normal init') if self.skip_connection: - # print('skip connection') assertion_fail = f'If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}' assert self.head_dim == self.feature_dim, assertion_fail @@ -258,6 +256,8 @@ class FeatureMapAdapter(FeatureMapMLP): """ Learnable Feature map with bottleneck adapter as in https://arxiv.org/abs/1902.00751 + + We don't use but could be fun to try """ def __init__(self, hidden_dim: int, *args, **kwargs): kwargs['skip_connection'] = True diff --git a/src/model/linear_attention/__init__.py b/src/model/linear_attention/__init__.py index 7b2f090..dd3e49f 100644 --- a/src/model/linear_attention/__init__.py +++ b/src/model/linear_attention/__init__.py @@ -7,15 +7,13 @@ from .linear_window_attention_tk import ( LolcatsTKWindowAttention, LinearAttentionTKWindowCache ) -from .linear_window_attention_tk_long import ( - LolcatsTKWindowLongAttention, -) -from .linear_window_attention_tk_bf16 import ( - LolcatsTKWindowAttentionBF16, -) from .linear_window_attention_sw import ( LolcatsSlidingWindowAttention, LinearAttentionSlidingWindowCache ) +# Experimental chunk linear attentions +from .linear_window_attention_tk_long import ( + LolcatsTKWindowLongAttention, +) from .linear_window_attention_sw_long import ( LolcatsSlidingWindowLongAttention, ) diff --git a/src/model/linear_attention/linear_attention.py b/src/model/linear_attention/linear_attention.py index 295c23e..727e15a 100644 --- a/src/model/linear_attention/linear_attention.py +++ b/src/model/linear_attention/linear_attention.py @@ -338,7 +338,7 @@ def forward(self, # Apply prefill mask if attention_mask is not None and q.shape[2] > 1: if len(attention_mask.shape) == 4: - lin_attn_mask = (attention_mask == 0)[:, :1, -1, :][..., None] # b, 1, k_len, 1 + lin_attn_mask = (attention_mask == 0)[:, :1, -1, :l][..., None] # b, 1, k_len, 1 else: lin_attn_mask = attention_mask[:, None, :, None] # b, 1, k_len, 1 k = k.masked_fill(~lin_attn_mask, 0) @@ -456,4 +456,4 @@ def reorder_cache(self, beam_idx: torch.LongTensor): Reorders the cache for beam search, given the selected beam indices. -> Copied from transformers/src/transformers/cache_utils.py """ - raise NotImplementedError('Reordering cache not implemented for LinearAttentionState') \ No newline at end of file + raise NotImplementedError('Reordering cache not implemented for LinearAttentionState') diff --git a/src/model/linear_attention/linear_window_attention_sw_linear.py b/src/model/linear_attention/linear_window_attention_sw_linear.py index 482d765..ea52584 100644 --- a/src/model/linear_attention/linear_window_attention_sw_linear.py +++ b/src/model/linear_attention/linear_window_attention_sw_linear.py @@ -40,8 +40,8 @@ def get_masks(window_size: int, q_len: int, k_len: int, -> 1 is include, 0 is ignore """ kwargs = {'device': device, 'dtype': int} - causal_mask = torch.ones((q_len, k_len), **kwargs).tril(k_len - q_len) - linear_mask = torch.ones((q_len, k_len), **kwargs).tril(-window_size) + causal_mask = torch.ones((q_len, k_len), **kwargs).tril(max(k_len - q_len, 0)) + linear_mask = torch.ones((q_len, k_len), **kwargs).tril(max(k_len - q_len, 0) - window_size) window_mask = causal_mask - linear_mask # Return softmax mask (window), linear attention mask # -> shapes broadcast over (b, h, q_len, k_len) diff --git a/src/model/linear_attention/linear_window_attention_sw_scale.py b/src/model/linear_attention/linear_window_attention_sw_scale.py deleted file mode 100644 index 73df514..0000000 --- a/src/model/linear_attention/linear_window_attention_sw_scale.py +++ /dev/null @@ -1,397 +0,0 @@ -""" -Hedgehog attention combining sliding window and linear attentions - -For each layer: -- We first compute (softmax) attention over sliding windows -- We then compute standard linear attention to "fill in" the earlier parts -- We combine to model the entire sequence -""" -from typing import List, Tuple, Optional, Dict, Any -import copy -import torch -import torch.nn as nn -import torch.nn.functional as F - -from transformers.cache_utils import Cache # Transformers v4.36 - -# Causal linear attention dot product CUDA kernel from fast-transformers -from csrc import causal_dot_product - -from src.model.rotary import apply_rotary_pos_emb -from .linear_attention import ( - LolcatsLinearAttention, LinearAttentionState, softmax_attention, repeat_kv -) - -# ---------------------- -# Sliding window helpers -# ---------------------- - -def get_causal_mask(x: torch.Tensor): - """ - Assume x is shape (..., m, n) - Return mask of shape (..., m, n) where 1 is include, 0 is mask - """ - m, n = x.shape[-2:] - return torch.ones((1, 1, m, n), device = x.device, dtype = int).tril(n - m) - - -def get_under_window_mask(x: torch.Tensor, window_size: int): - """Return mask for under window terms""" - m, n = x.shape[-2:] - return torch.ones((1, 1, m, n), device=x.device, dtype=int).tril(-window_size) - - -def get_sliding_window_mask(x: torch.Tensor, window_size: int): - """Return sliding window mask""" - mask = get_causal_mask(x) - return (mask - get_under_window_mask(x, window_size)).to(dtype=mask.dtype) - - -def hybrid_window_attention_quadratic(q: torch.Tensor, k: torch.Tensor, - f_q: torch.Tensor, f_k: torch.Tensor, - v: torch.Tensor, window_size: int, - eps: float = 1e-12,): - """Comput hybrid window attention with quadratic complexity""" - # 1. Sliding window (softmax attention) - a_sm = torch.einsum('bhmd,bhnd->bhmn', q, k) * (k.shape[-1] ** -0.5) - - mask_causal = get_causal_mask(a_sm) - mask_linear = get_under_window_mask(a_sm, window_size) - mask_window = mask_causal - mask_linear - - a_sm = a_sm.masked_fill(~mask_window.bool(), -torch.finfo(a_sm.dtype).max) - a_sm = torch.softmax(a_sm, dim=-1) - - # 2. Under window (linear attention) - a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q, f_k) - - # First compute all causal terms (to normalize with) - a_causal = a_ln.masked_fill(~mask_causal.bool(), 0) - sum_a_ca = a_causal.sum(dim=-1, keepdim=True) - sum_a_ca[sum_a_ca == 0] += eps # stability - - # Then compute actual linear attn terms - a_ln = a_ln.masked_fill(~mask_linear.bool(), 0) - sum_a_ln = a_ln.sum(dim=-1, keepdim=True) - sum_a_ln[sum_a_ln == 0] += eps # stability - - a_ln = a_ln / sum_a_ca # linear attention weights - ratio_sm = 1 - (sum_a_ln / sum_a_ca) # ratio allocated to softmax terms - - # 3. Combine - a = a_ln + ratio_sm * a_sm - y = torch.einsum('bhmn,bhnd->bhmd', a, v) - return y, a - - -def under_window_dot_prod(f_q: torch.Tensor, f_k: torch.Tensor, v: torch.Tensor, - window_size: int, eps: float=1e-12): - """Compute hybrid window attention dot product with linear complexity in q_len""" - dtype = f_q.dtype - w = window_size - f_k = F.pad(f_k, (0, 0, w, 0), value=0)[:, :, :-w, :] - v = F.pad(v, (0, 0, w, 0), value=0)[:, :, :-w, :] - qkv = causal_dot_product(f_q.contiguous().to(dtype=torch.float32), - f_k.contiguous().to(dtype=torch.float32), - v.contiguous().to(dtype=torch.float32)).to(dtype=dtype) - sum_f_k = f_k.float().cumsum(dim=2).to(dtype=dtype) - sum_qk = torch.einsum("bhld,bhld->bhl", f_q, sum_f_k)[..., None] - sum_qk[sum_qk == 0] += eps - - return qkv, sum_qk - - -def sliding_window_softmax_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - window_size: int, mask_value: float=-1e8): - """ - Compute sliding window softmax attention without materializing - O(seq_len^2) attention weights - """ - d = q.shape[-1] - # Compute windows for keys - window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1} - k = F.pad(k, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs) - v = F.pad(v, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs) - - # Compute windowed_softmax(qk); causal in its construction - qk = torch.einsum('bhld,bhldw->bhlw', q, k) * (d ** -0.5) - qk[qk == 0] = -torch.finfo(q.dtype).max # heuristic for zeroing out padding above - return torch.einsum('bhlw,bhldw->bhld', torch.softmax(qk, dim=-1), v) - - -def hybrid_window_attention(q: torch.Tensor, k: torch.Tensor, - f_q: torch.Tensor, f_k: torch.Tensor, - v: torch.Tensor, window_size: int, - mask_value: float = -1e-8, eps: float = 1e-12, - kv_states: Optional[Tuple[torch.Tensor]] = None, - ): - """Compute hybrid sliding window attention with linear complexity""" - window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1} - # 1. Sliding window (softmax attention) - y_sm = sliding_window_softmax_attention(q, k, v, window_size, mask_value) - - # 2. Under window (linear attention) - sum_f_k = f_k.float().cumsum(dim=2).to(dtype=q.dtype) - sum_qk_causal = torch.einsum("bhld,bhld->bhl", f_q, sum_f_k)[..., None] - sum_qk_causal[sum_qk_causal == 0] += eps - qkv_ln, sum_qk_ln = under_window_dot_prod(f_q, f_k, v, window_size) - - # 3. Combine - y_ln = qkv_ln / sum_qk_causal - ratio_sm = 1 - (sum_qk_ln / sum_qk_causal) # ratio allocated to softmax terms - return y_ln + ratio_sm * y_sm - - -class LolcatsSlidingWindowAttention(LolcatsLinearAttention): - """ - LoLCATs attention combining sliding window and linear attention - """ - def __init__(self, window_size: int = 64, **kwargs): - self.window_size = window_size - self.window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1} - super().__init__(**kwargs) - self.attention_type = kwargs['attention_type'] # 'hedgehog_llama_window_tk' - - def base_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor = None, - causal: bool = True, **kwargs: any): - """ - Standard softmax attention - """ - a = torch.einsum('bhmd,bhnd->bhmn', q, k) * (k.shape[-1] ** -0.5) - y = None - if causal: - m, n = a.shape[-2:] - causal_mask = torch.ones((m, n), device = a.device, dtype = torch.bool).triu(n - m + 1) - a = a.masked_fill(causal_mask, -torch.finfo(a.dtype).max) - a = torch.softmax(a, dim=-1) - if v is not None: - y = torch.einsum('bhmn,bhnd->bhmd', a, v) - return y, a, None - - def process_qkv(self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[int, torch.Tensor, torch.Tensor]] = None,): # "legacy" cache approach - """ - Compute queries, keys, and values - """ - b, l, _ = hidden_states.size() - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) - kv_seq_len = k.shape[-2] - - # Shape is (batch_size, seq_len, num_heads, head_dim) - q = q.view(b, l, *self.q_shape).transpose(1, 2) - k = k.view(b, l, *self.k_shape).transpose(1, 2) - v = v.view(b, l, *self.v_shape).transpose(1, 2) - - if past_key_value is not None: # and k.shape[2] > q.shape[2]: # e.g., when generating - past_key_value.window_size = self.window_size - # print(f'{self.layer_idx} usable length', past_key_value.get_usable_length(kv_seq_len, self.layer_idx)) - past_key_value.window_size = self.window_size - if isinstance(past_key_value, Cache): # In Transformers v4.36+ this is a DynamicCache object - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - else: - kv_seq_len += past_key_value[0].shape[-2] - - # Apply rotary embeddings and repeat for GQA - if position_ids is not None and kv_seq_len <= position_ids[0, -1]: - kv_seq_len = position_ids[0, -1] + 1 # hack for adjusting position ids - try: # As in Transformers v4.36 - cos, sin = self.rotary_emb(k, seq_len=kv_seq_len) - q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids) - except TypeError: # As in Transformers v4.39+ - cos, sin = self.rotary_emb(v, position_ids) - q, k = apply_rotary_pos_emb(q, k, cos, sin) - k = repeat_kv(k, self.num_key_value_groups) - v = repeat_kv(v, self.num_key_value_groups) - return q, k, v, kv_seq_len - - def forward(self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[int, torch.Tensor, torch.Tensor]] = None, # "legacy" cache approach - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Forward pass with the option to compute attention weights multiple ways - if self.train_attention is True - -> Consistent with HuggingFace Transformers for easy use with their pretrained models - """ - b, l, _ = hidden_states.size() - q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask, - position_ids, past_key_value) - f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) # Have to do after repeat for grouped-query attn if we use same fmap - - if self.train_attention: - # 1. Compute "ground-truth" attention output and weights - with torch.no_grad(): - _y_true, a_true = self.base_attention(q, k, v)[:2] - y_true = _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) - y_true = self.o_proj(y_true) - - # 2. Compute "predicted" attention outputs - # compute attn weights under sliding window - y_pred, a_pred = hybrid_window_attention_quadratic(q, k, f_q, f_k, v, window_size=self.window_size) - attn_weights = ((a_pred, a_true), (y_pred, _y_true)) - else: - # During finetuning and inference - if attention_mask is not None and f_q.shape[2] > 1: - if len(attention_mask.shape) == 4: - lin_attn_mask = (attention_mask == 0)[:, :1, -1, :][..., None] # b, 1, k_len, 1 - else: - lin_attn_mask = attention_mask[:, None, :, None] # b, 1, k_len, 1 - f_k = f_k.masked_fill(~lin_attn_mask, self.mask_value) - - if past_key_value is not None: - past_key_value.window_size = self.window_size - if f_q.shape[2] == 1 and kv_seq_len > 1 and past_key_value is not None: # indicates we're generating - # print(f'Recurrent view | layer {self.layer_idx} | type(past_key_value): {type(past_key_value)}') - assert use_cache is True - - # Linear attention bit: - # 1. Update K cache by first evicting first token - # (it's outside the sliding window now), and computing feature_map(k) - f_k_from_cache = past_key_value.k_cache[self.layer_idx][:, :, :1, :] - # MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size - if f_k_from_cache.sum() == 0: # heuristic for zeroing out padding in cache - f_k_from_cache = torch.zeros(f_q.shape, dtype=f_q.dtype, device=f_q.device) - else: - f_k_from_cache = self.feature_map_k(f_k_from_cache) - - # 2. Then compute feature_map(k) v^T and add to kv_state - # (past_key_value.update takes care of this, gets the proper stored v) - # v in the arg below is handled separately - kv_states = past_key_value.update(k, f_k_from_cache, f_k, v, self.layer_idx) - kv_state, k_state, k_state_causal = kv_states - - # 3. Finally compute linear attentions over terms before the window - qkv_ln = torch.einsum('bhlf,bhfd->bhld', f_q, kv_state) - sum_qk_ln = torch.einsum('bhlf,bhlf->bhl', f_q, k_state)[..., None] - sum_qk_causal = torch.einsum('bhlf,bhlf->bhl', f_q, k_state_causal)[..., None] - - y_ln = qkv_ln / sum_qk_causal - ratio_sm = 1 - (sum_qk_ln / sum_qk_causal) - - # Sliding window attention bit - # -> Compute attention over K and V fixed window cache - a_sm = torch.einsum('bhmd,bhnd->bhmn', q, past_key_value.k_cache[self.layer_idx]) - a_sm[a_sm == 0] = -torch.finfo(q.dtype).max # heuristic for zeroing out padding in cache - a_sm = torch.softmax(a_sm * q.shape[-1] ** -0.5, dim=-1) - try: - y_sm = torch.einsum('bhmn,bhnd->bhmd', a_sm, past_key_value.v_cache[self.layer_idx]) - except: - breakpoint() - - # Combine - y_true = y_ln + ratio_sm * y_sm - else: - past_key_value.init_cache_and_states(k, f_k, v, self.layer_idx) - y_true = hybrid_window_attention(q, k, f_q, f_k, v, window_size=self.window_size) - else: - y_true = hybrid_window_attention(q, k, f_q, f_k, v, window_size=self.window_size) - - # Concatenate heads and apply output projection - y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) - y_true = self.o_proj(y_true) - attn_weights = None - - return y_true, attn_weights, past_key_value - - -class LinearAttentionSlidingWindowCache(LinearAttentionState): - """ - Class for `past_key_values` - -> Alternative to KV cache; here we only maintain a "KV state" and "K state" - -> Modified from transformers.cache_utils.DynamicCache (v4.36) - """ - def __init__(self, window_size: int = 64) -> None: - self._seen_tokens = 0 # Note in Transformer versions >4.36 this all `seen_tokens*` should be `_seen_tokens*` - # track phi(k)v^T and phi(k) until softmax terms - self.kv_states: List[torch.Tensor] = [] - self.k_states: List[torch.Tensor] = [] - # track all phi(k) until causal end - self.k_states_causal: List[torch.Tensor] = [] - - self.k_cache: List[torch.Tensor] = [] - self.v_cache: List[torch.Tensor] = [] - self._seen_tokens_by_layer: List[int] = [] - self.window_size = window_size - - def init_cache_and_states(self, - keys: torch.Tensor, - fmap_keys: torch.Tensor, - values: torch.Tensor, - layer_idx: Optional[int] = None): - """ - Initialize KV cache and states - """ - if layer_idx == 0: - self._seen_tokens += keys.shape[-2] - - dtype = keys.dtype - - # MZ 6/3: handle short inputs; pad if k.shape[2] < self.window_size - if keys.shape[-2] < self.window_size: - keys = F.pad(keys, (0, 0, self.window_size - keys.shape[-2], 0), value=0) - k = F.pad(fmap_keys, (0, 0, self.window_size, 0), value=0) # [:, :, :-w, :] - v = F.pad(values, (0, 0, self.window_size, 0), value=0) # [:, :, :-w, :] - - # k_cache keeps track of k; k_state keeps track of phi(k) - k_cache, k_state = keys[:, :, -self.window_size:, :], k[:, :, :-self.window_size, :] - v_cache, v_state = v[:, :, -self.window_size:, :], v[:, :, :-self.window_size, :] - - # Update the cache - self.k_cache.append(k_cache) - self.v_cache.append(v_cache) - - # kv_state = torch.einsum('bhlf,bhld->bhfd', k_state, v_state) # b, h, f, d - # k_state = k_state.sum(dim=-2, keepdim=True) # b, h, 1, f; note the 1 - kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d - k_state = k_state.float().sum(dim=-2, keepdim=True).to(dtype) # b, h, 1, f; note the 1 - self.kv_states.append(kv_state) - self.k_states.append(k_state) - - # Keep track of all qk_sum - # self.k_states_causal.append(fmap_keys.sum(dim=-2, keepdim=True)) - self.k_states_causal.append(fmap_keys.float().sum(dim=-2, keepdim=True).to(dtype)) - self._seen_tokens_by_layer[layer_idx] = keys.shape[-2] - - def update(self, - keys: torch.Tensor, - fmap_key_from_cache: torch.Tensor, - fmap_keys: torch.Tensor, - values: torch.Tensor, - layer_idx: Optional[int] = None, - cache_kwargs: Optional[any] = None, - *args, **kwargs: any, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Update KV cache and states during generation - """ - assert keys.shape[-2] == 1 - dtype = keys.dtype - - if layer_idx == 0: - self._seen_tokens += keys.shape[-2] - - # Get first key and value in cache (first added) - k_state = fmap_key_from_cache # self.k_cache[layer_idx][:, :, :1, :] - v_state = self.v_cache[layer_idx][:, :, :1, :] - # kv_state = torch.einsum('bhlf,bhld->bhfd', k_state, v_state) # b, h, f, d - kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d - - # Update the states and cache - self.kv_states[layer_idx] += kv_state - self.k_states[layer_idx] += k_state - self.k_states_causal[layer_idx] += fmap_keys - - self.k_cache[layer_idx] = torch.cat([self.k_cache[layer_idx][:, :, 1:, :], keys], dim=-2) - self.v_cache[layer_idx] = torch.cat([self.v_cache[layer_idx][:, :, 1:, :], values], dim=-2) - self._seen_tokens_by_layer[layer_idx] += keys.shape[-2] - return self.kv_states[layer_idx], self.k_states[layer_idx], self.k_states_causal[layer_idx] diff --git a/src/model/linear_attention/linear_window_attention_sw_tiled.py b/src/model/linear_attention/linear_window_attention_sw_tiled.py deleted file mode 100644 index 9120d97..0000000 --- a/src/model/linear_attention/linear_window_attention_sw_tiled.py +++ /dev/null @@ -1,378 +0,0 @@ -""" -Subquadratic attention combining sliding window and linear attentions -- Using "standard" sliding windows -- Didactically computes outputs with n^2 attention weights for now -- Copied + adapted from linear_window_attention_tk.py for single-file reference - -For each layer: -- We first compute (softmax) attention over sliding windows -- We then compute standard linear attention to "fill in" the earlier parts -- We combine to model the entire sequence -""" -from typing import List, Tuple, Optional, Callable -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - -from tqdm import tqdm - -from transformers.cache_utils import Cache - -from .linear_attention import ( - LolcatsLinearAttention, LinearAttentionState, - softmax_attention -) - -# ---------------------- -# Sliding window helpers -# ---------------------- -def get_masks(window_size: int, q_len: int, k_len: int, - device: torch.device) -> tuple[torch.Tensor]: - """ - Return masks for softmax and linear attention terms - -> 1 is include, 0 is ignore - """ - kwargs = {'device': device, 'dtype': int} - causal_mask = torch.ones((q_len, k_len), **kwargs).tril(k_len - q_len) - linear_mask = torch.ones((q_len, k_len), **kwargs).tril(k_len - q_len - window_size) - window_mask = causal_mask - linear_mask - # Return softmax mask (window), linear attention mask - # -> shapes broadcast over (b, h, q_len, k_len) - return window_mask[None, None, ...], linear_mask[None, None, ...] - - -def hybrid_attention_quadratic(q: torch.Tensor, k: torch.Tensor, - f_q: torch.Tensor, f_k: torch.Tensor, - v: torch.Tensor, - window_factor: torch.Tensor, - linear_factor: torch.Tensor, - window_size: int, - kv_state: torch.Tensor = None, - k_state: torch.Tensor = None, - eps: float = 1e-12, - mask_value: float=-1e8): - """ - Hybrid attention combining sliding window and linear attentions - """ - - mask_window, mask_linear = get_masks(window_size, q.shape[-2], k.shape[-2], q.device) - - # 1. Sliding window (softmax attention) - a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5) - a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value) - # torch.softmax(a_sm, dim=-1), but we account for the max when combining - a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) - a_sm = window_factor * torch.exp(a_sm - a_sm_max) - sum_sm = a_sm.sum(dim=-1, keepdim=True) - - # 2. Under window (linear attention) - a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float()) - a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0) - sum_ln = a_ln.sum(dim=-1, keepdim=True) - - # 3. Combine - a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights - # Allow outputs to also depend on prior kv_state and k_state - y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v.float()) - if kv_state is not None: # Combine with prior kv_state and k_state - y += linear_factor * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float()) - sum_ln += linear_factor * torch.einsum( - 'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None] - y = (y / (sum_sm + sum_ln)).to(q.dtype) - return y, a # attention weights only for the last chunk - - -# --------------------- -# Attention layer class -# --------------------- -class LolcatsTiledSlidingWindowAttention(LolcatsLinearAttention): - """ - Lolcats attention combining sliding window and linear attention - """ - def __init__(self, - window_size: int = 64, - tile_size: int = 32, - decode_window_size: int = None, - affine_attention_factors: bool = False, - init_window_factor: float = 0, - train_window_factor: bool = True, - state_grad_enabled: bool = False, - mse_factor: float = 1e3, - xent_factor: float = 1, # 1, - **kwargs): - self.window_size = window_size - self.tile_size = tile_size - self.decode_window_size = ( - decode_window_size if decode_window_size is not None else window_size - ) - self.window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1} - super().__init__(**kwargs) - self.attention_type = kwargs['attention_type'] # 'lolcats_llama_window_sw' - # Determine how we compute attentions - self.quadratic_attention = hybrid_attention_quadratic - self.attention_type = kwargs['attention_type'] # 'lolcats_long_llama_window_sw' - # Learnable factor for combining attentions - self.affine_attention_factors = affine_attention_factors - device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype - if train_window_factor: - self.window_factors = nn.Parameter( - init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)) - else: - self.register_buffer( - "window_factors", init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype) - ) - # Whether we use original flash attention 2 inference (use during attention transfer) - self.base_inference = False - self.state_grad_enabled = state_grad_enabled - - # Losses - self.criterion_xent = nn.CrossEntropyLoss(reduction='mean') # stability - self.criterion_mse = nn.MSELoss(reduction='mean') - self.mse_factor = mse_factor - self.xent_factor = xent_factor - - def forward(self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Forward pass with the option to compute attention weights multiple ways - if self.train_attention is True - -> Consistent with HuggingFace Transformers for easy use with their pretrained models - """ - b, l, _ = hidden_states.size() - q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask, - position_ids, past_key_value) - f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) # Have to do after repeat for grouped-query attn if we use same fmap - - y_true = [] - if self.train_attention: - for tile_idx in tqdm(range(l // self.tile_size), leave=False, desc=f'Layer {self.layer_idx} processing attention tiles'): - loss = 0 - loss_mse = 0 - loss_xent = 0 - start, end = tile_idx * self.tile_size, (tile_idx + 1) * self.tile_size - q_tile, f_q_tile = q[:, :, start:end], f_q[:, :, start:end] - k_tile, f_k_tile, v_tile = k[:, :, :end], f_k[:, :, :end], v[:, :, :end] - - # 1. Compute "ground-truth" attention output and weights - with torch.no_grad(): - _y_true, a_true = softmax_attention(q_tile, k_tile, v_tile)[:2] - y_true_tile = _y_true.transpose(1, 2).contiguous().view(b, self.tile_size, self.hidden_size) - y_true_tile = self.o_proj(y_true_tile).cpu() - y_true.append(y_true_tile) - - # 2. Compute "predicted" attention outputs - # compute attn weights under sliding window - window_factors = F.sigmoid(self.window_factors) - linear_factors = 1 - window_factors if self.affine_attention_factors else 1 - y_pred, a_pred = self.quadratic_attention(q_tile, k_tile, f_q_tile, f_k_tile, v_tile, - window_factors, linear_factors, - window_size=self.window_size) - if self.mse_factor > 0: - loss_mse = self.mse_factor * self.criterion_mse(y_pred, _y_true) - if self.xent_factor > 0: - k_len = a_pred.shape[-1] - loss_xent = self.xent_factor * self.criterion_xent( - a_pred.contiguous().view(-1, k_len).clamp(min=1e-12).log(), - a_true.contiguous().view(-1, k_len), - ) - loss += (loss_mse + loss_xent) - # if self.xent_factor > 0: - # breakpoint() - if self.feature_map_q.training: - loss.backward() - del a_pred; del a_true; del y_pred; del _y_true - # except: - # pass - attn_weights = (loss.cpu(), loss_mse.cpu(), loss_xent.cpu() if self.xent_factor > 0 else 0) # hack - y_true = torch.cat(y_true, dim=-2).to(q.device) # b, h, l, d - else: - attn_weights = None - # attention_mask = None # For now this is always True - if past_key_value is None: # Regular training - window_factors = F.sigmoid(self.window_factors) - linear_factors = 1 - window_factors if self.affine_attention_factors else 1 - y_true, a_pred = self.quadratic_attention(q, k, f_q, f_k, v, - window_factors, linear_factors, - window_size=self.window_size) - attn_weights = a_pred - else: - past_key_value.window_size = self.decode_window_size - if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating - assert use_cache is True - _kv = past_key_value.update_for_decoding(k, v, self.layer_idx, - self.feature_map_k, - dtype=q.dtype) - k_cache, v_cache, f_kv_state, f_k_state = _kv - - # Sliding window + linear attention decode - window_factors = F.sigmoid(self.window_factors) - linear_factors = 1 - window_factors if self.affine_attention_factors else 1 - - # Softmax attention terms - a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5) - a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) - a_sm = window_factors * torch.exp(a_sm - a_sm_max) - sum_sm = a_sm.sum(dim=-1, keepdim=True) - - # Combine with linear attention terms - y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float()) - + linear_factors * torch.einsum('bhlf,bhfd->bhld', f_q.float(), f_kv_state.float())) - sum_ln = linear_factors * torch.einsum( - 'bhlf,bhnf->bhl', f_q.float(), f_k_state.float())[..., None] - y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype) - - else: # Stateful training - try: - kv_state = past_key_value.kv_states[self.layer_idx] - k_state = past_key_value.k_states[self.layer_idx] - except IndexError: - kv_state, k_state = None, None - window_factors = F.sigmoid(self.window_factors) - linear_factors = 1 - window_factors if self.affine_attention_factors else 1 - y_true, _ = self.quadratic_attention(q, k, f_q, f_k, v, - window_factors, linear_factors, - window_size=self.window_size, - kv_state=kv_state, - k_state=k_state) - # Save and update KV cache and states - # past_key_value.update(k, v.detach(), self.layer_idx, - # fmap_key_states=f_k.detach(), - # accumulate_in_fp32=True) - past_key_value.update(k, v, self.layer_idx, - fmap_key_states=f_k, - accumulate_in_fp32=True) - # Concatenate heads and apply output projection - y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) - y_true = self.o_proj(y_true) - return y_true, attn_weights, past_key_value - - -class LinearAttentionSlidingWindowCache(LinearAttentionState): - """ - Class for `past_key_values` - -> Alternative to KV cache; here we only maintain a "KV state" and "K state" - -> Modified from transformers.cache_utils.DynamicCache (v4.36) - """ - def __init__(self, window_size: int = 64) -> None: - super().__init__() - self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36 - self._seen_tokens_by_layer: List[int] = [] - self.kv_states: List[torch.Tensor] = [] - self.k_states: List[torch.Tensor] = [] - - # Account for sliding windows - self.decode_kv_states: List[torch.Tensor] = [] - self.decode_k_states: List[torch.Tensor] = [] - self.k_cache: List[torch.Tensor] = [] - self.v_cache: List[torch.Tensor] = [] - self.window_size = window_size - - def update(self, key_states: torch.Tensor, value_states: torch.Tensor, - layer_idx: Optional[int] = None, cache_kwargs: Optional[any] = None, - accumulate_in_fp32: bool = False, - fmap_key_states: torch.Tensor = None, # should not be None - grad_enabled: bool = False, - **kwargs: any, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Update KV, K states; and KV cache during training - - For decoding, use `self.decode_kv_states` to keep track of KV states - up to sliding window terms - - For (chunked) training, use `self.kv_states` to keep track of KV states - up to end of sequence - - Likewise for `self.decode_k_states` and `self.k_states` - """ - with torch.set_grad_enabled(grad_enabled): - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - dtype = key_states.dtype - if accumulate_in_fp32: - # key_states = key_states.float() - fmap_key_states = fmap_key_states.float() - value_states = value_states.float() - - # Decoding KV state (KV terms up to last window_size) - decode_kv_state = torch.einsum( - 'bhlf,bhld->bhfd', fmap_key_states[:, :, :-self.window_size], value_states[:, :, :-self.window_size] - ) - # KV state - kv_state = decode_kv_state + torch.einsum( - 'bhlf,bhld->bhfd', fmap_key_states[:, :, -self.window_size:], value_states[:, :, -self.window_size:] - ) - # shape is b, h, 1, f; note the 1 - decode_k_state = fmap_key_states[:, :, :-self.window_size].sum(dim=-2, keepdim=True) - k_state = (decode_k_state + fmap_key_states[:, :, -self.window_size:].sum(dim=-2, keepdim=True)) - - # Update the cache - if len(self.k_states) <= layer_idx: # Initializing kv and k states - self.kv_states.append(kv_state.to(dtype)) - self.k_states.append(k_state.to(dtype)) - - self.decode_kv_states.append(decode_kv_state.to(dtype)) - self.decode_k_states.append(decode_k_state.to(dtype)) - - self.k_cache.append(key_states[:, :, -self.window_size:, :]) - self.v_cache.append(value_states[:, :, -self.window_size:, :].to(dtype)) - # self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2]) - else: - # Update kv and k states recurrently - kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype) - k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype) - self.kv_states[layer_idx] = kv_state - self.k_states[layer_idx] = k_state - - decode_kv_state = (self.decode_kv_states[layer_idx].to(kv_state.dtype) - + decode_kv_state).to(dtype) - decode_k_state = (self.decode_k_states[layer_idx].to(kv_state.dtype) - + decode_k_state).to(dtype) - self.decode_kv_states[layer_idx] = decode_kv_state - self.decode_k_states[layer_idx] = decode_k_state - - self.k_cache[layer_idx] = key_states[:, :, -self.window_size:, :] - self.v_cache[layer_idx] = value_states[:, :, -self.window_size:, :] - self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2] - - return self.kv_states[layer_idx], self.k_states[layer_idx] - - def update_for_decoding(self, keys: torch.Tensor, values: torch.Tensor, - layer_idx: int, feature_map_k: Callable, dtype: torch.dtype): - """ - Update the decoding KV and K states, and KV cache, during decodeing - """ - with torch.no_grad(): - k_cache = self.k_cache[layer_idx] - v_cache = self.v_cache[layer_idx] - - if k_cache.shape[-2] < self.window_size: # build window-size cache - self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2) - self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2) - else: - # MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size - # if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache - # f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device) - # else: - # f_k_state = feature_map_k(k_cache[:, :, :1, :]) - # -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation - k_state = feature_map_k(k_cache[:, :, :1, :]) - v_state = v_cache[:, :, :1, :] - kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d - self.decode_kv_states[layer_idx] += kv_state - self.decode_k_states[layer_idx] += k_state - - self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2) - self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], values], dim=-2) - - if layer_idx == 0: - self._seen_tokens += keys.shape[-2] - self._seen_tokens_by_layer[layer_idx] += keys.shape[-2] - return (self.k_cache[layer_idx], self.v_cache[layer_idx], - self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx]) diff --git a/src/model/linear_attention/linear_window_attention_tk.py b/src/model/linear_attention/linear_window_attention_tk.py index 313e95b..37e81dc 100644 --- a/src/model/linear_attention/linear_window_attention_tk.py +++ b/src/model/linear_attention/linear_window_attention_tk.py @@ -122,6 +122,7 @@ def __init__(self, # Whether we use original flash attention 2 inference (use during attention transfer) self.base_inference = False self.state_grad_enabled = state_grad_enabled + self.window_factor = self.window_factors # legacy naming support def forward(self, hidden_states: torch.Tensor, diff --git a/src/model/linear_attention/linear_window_attention_tk_bf16.py b/src/model/linear_attention/linear_window_attention_tk_bf16.py deleted file mode 100644 index 97fed11..0000000 --- a/src/model/linear_attention/linear_window_attention_tk_bf16.py +++ /dev/null @@ -1,354 +0,0 @@ -""" -Subquadratic attention combining sliding window and linear attentions -- Using the TK "terracing" arrangement - -For each layer: -- We first compute (softmax) attention over sliding windows -- We then compute standard linear attention to "fill in" the earlier parts -- We combine to model the entire sequence -""" -from typing import List, Tuple, Optional, Callable -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - -from transformers.cache_utils import Cache - -from .linear_attention import ( - LolcatsLinearAttention, LinearAttentionState, softmax_attention -) - -# ---------------------- -# Sliding window helpers -# ---------------------- -def get_masks(window_size: int, q_len: int, k_len: int, - device: torch.device) -> tuple[torch.Tensor]: - """ - Return masks for softmax and linear attention terms - -> 1 is include, 0 is ignore - """ - kwargs = {'device': device, 'dtype': int} - l = window_size - m = math.ceil(max(q_len, k_len) / window_size) - # Creates an n x n mask where n = window_size^2 - mask = torch.block_diag(*[torch.ones((l, l), **kwargs)] * m) - mask += torch.roll(mask, -l, -1) # this adds the terracing - if mask.shape[0] > q_len: - mask = mask[-q_len:] - if mask.shape[1] > k_len: - mask = mask[:, -k_len:] - # Return softmax mask (window), linear attention mask - mask = mask[None, None, ...] # b, h, q_len, k_len - # mask = torch.tril(mask) - # return mask, ~mask - return torch.tril(mask).to(**kwargs), torch.tril(1 - mask).to(**kwargs) - - -def hybrid_attention_quadratic(q: torch.Tensor, k: torch.Tensor, - f_q: torch.Tensor, f_k: torch.Tensor, - v: torch.Tensor, - window_factor: torch.Tensor, - linear_factor: torch.Tensor, - window_size: int, - kv_state: torch.Tensor = None, - k_state: torch.Tensor = None, - eps: float = 1e-12, - mask_value: float = -1e12) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Hybrid attention combining sliding window and linear attentions - """ - - mask_window, mask_linear = get_masks(window_size, q.shape[-2], k.shape[-2], q.device) - - # 1. Sliding window (softmax attention) - a_sm = torch.einsum('bhmd,bhnd->bhmn', q, k) * (k.shape[-1] ** -0.5) - a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value) - # torch.softmax(a_sm, dim=-1), but we account for the max when combining - a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) - # a_sm = window_factor * torch.exp((a_sm - a_sm_max).float()).to(q.dtype) - a_sm = window_factor * torch.exp(a_sm - a_sm_max) - # sum_sm = a_sm.sum(dim=-1, keepdim=True) - - # 2. Under window (linear attention) - a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q, f_k) # .to(q.dtype) - a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0) - # sum_ln = a_ln.sum(dim=-1, keepdim=True) - - # 3. Combine - a = ((a_sm + a_ln) / - (a_sm.sum(dim=-1, keepdim=True) + a_ln.sum(dim=-1, keepdim=True))) # .to(q.dtype) # Save attention weights - y = torch.einsum('bhmn,bhnd->bhmd', a, v) # .float()) - return y, a - - # Allow outputs to also depend on prior kv_state and k_state - y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v) # .float()) - if kv_state is not None: # Combine with prior kv_state and k_state - y += linear_factor * torch.einsum('bhld,bhdf->bhlf', f_q, kv_state) - sum_ln += linear_factor * torch.einsum( - 'bhld,bhnd->bhl', f_q, k_state)[..., None] - y = (y / (sum_sm + sum_ln)) # .to(q.dtype) - return y, a # attention weights only for the last chunk - - -# --------------------- -# Attention layer class -# --------------------- -class LolcatsTKWindowAttentionBF16(LolcatsLinearAttention): - """ - Lolcats attention combining sliding window and linear attention - """ - def __init__(self, - window_size: int = 64, - decode_window_size: int = None, - affine_attention_factors: bool = False, - init_window_factor: float = 0, - state_grad_enabled: bool = False, - **kwargs): - self.window_size = window_size - self.decode_window_size = ( - decode_window_size if decode_window_size is not None else window_size - ) - self.window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1} - super().__init__(**kwargs) - self.attention_type = kwargs['attention_type'] # 'hedgehog_llama_window_tk' - # Determine how we compute attentions - self.quadratic_attention = hybrid_attention_quadratic - self.attention_type = kwargs['attention_type'] # 'hedgehog_long_llama_window_tk' - # Learnable factor for combining attentions - self.affine_attention_factors = affine_attention_factors - device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype - self.window_factors = nn.Parameter( - init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)) - # Whether we use original flash attention 2 inference (use during attention transfer) - self.base_inference = False - self.state_grad_enabled = state_grad_enabled - - def forward(self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Forward pass with the option to compute attention weights multiple ways - if self.train_attention is True - -> Consistent with HuggingFace Transformers for easy use with their pretrained models - """ - b, l, _ = hidden_states.size() - q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask, - position_ids, past_key_value) - f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) # Have to do after repeat for grouped-query attn if we use same fmap - - if self.train_attention: - # 1. Compute "ground-truth" attention output and weights - with torch.no_grad(): - _y_true, a_true = softmax_attention(q, k, v, fp32_attention=False)[:2] - y_true = _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) - y_true = self.o_proj(y_true) - # # Save memory - # device = y_true.device - # y_true = y_true.cpu() - # a_true, _y_true = a_true.cpu(), _y_true.cpu() - # torch.cuda.empty_cache() - - # 2. Compute "predicted" attention outputs - # compute attn weights under sliding window - window_factors = F.sigmoid(self.window_factors) - linear_factors = 1 - window_factors if self.affine_attention_factors else 1 - y_pred, a_pred = self.quadratic_attention(q, k, f_q, f_k, v, - window_factors, linear_factors, - window_size=self.window_size) - # a_pred, y_pred = a_pred.cpu(), y_pred.cpu() - # torch.cuda.empty_cache() - attn_weights = ((a_pred, a_true), (y_pred, _y_true)) - # y_true = y_true.to(device) - else: - attn_weights = None - # attention_mask = None # For now this is always True - if past_key_value is None: # Regular training - window_factors = F.sigmoid(self.window_factors) - linear_factors = 1 - window_factors if self.affine_attention_factors else 1 - y_true, a_pred = self.quadratic_attention(q, k, f_q, f_k, v, - window_factors, linear_factors, - window_size=self.window_size) - else: - past_key_value.window_size = self.decode_window_size - if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating - assert use_cache is True - _kv = past_key_value.update_for_decoding(k, v, self.layer_idx, - self.feature_map_k, - dtype=q.dtype) - k_cache, v_cache, kv_state, k_state = _kv - - # Sliding window + linear attention decode - window_factors = F.sigmoid(self.window_factors) - linear_factors = 1 - window_factors if self.affine_attention_factors else 1 - - # Softmax attention terms - a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5) - a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) - a_sm = window_factors * torch.exp(a_sm - a_sm_max) - sum_sm = a_sm.sum(dim=-1, keepdim=True) - - # Combine with linear attention terms - y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float()) - + linear_factors * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float())) - sum_ln = linear_factors * torch.einsum( - 'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None] - y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype) - - else: # Stateful training - try: - kv_state = past_key_value.kv_states[self.layer_idx] - k_state = past_key_value.k_states[self.layer_idx] - except IndexError: - kv_state, k_state = None, None - window_factors = F.sigmoid(self.window_factors) - linear_factors = 1 - window_factors if self.affine_attention_factors else 1 - y_true, _ = self.quadratic_attention(q, k, f_q, f_k, v, - window_factors, linear_factors, - window_size=self.window_size, - kv_state=kv_state, - k_state=k_state) - # Save and update KV cache and states - # past_key_value.update(k, v.detach(), self.layer_idx, - # fmap_key_states=f_k.detach(), - # accumulate_in_fp32=True) - past_key_value.update(k, v, self.layer_idx, - fmap_key_states=f_k, - accumulate_in_fp32=True) - # Concatenate heads and apply output projection - y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) - y_true = self.o_proj(y_true) - return y_true, attn_weights, past_key_value - - -class LinearAttentionTKWindowCache(LinearAttentionState): - """ - Class for `past_key_values` - -> Alternative to KV cache; here we only maintain a "KV state" and "K state" - -> Modified from transformers.cache_utils.DynamicCache (v4.36) - """ - def __init__(self, window_size: int = 64) -> None: - super().__init__() - self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36 - self._seen_tokens_by_layer: List[int] = [] - self.kv_states: List[torch.Tensor] = [] - self.k_states: List[torch.Tensor] = [] - - # Account for sliding windows - self.decode_kv_states: List[torch.Tensor] = [] - self.decode_k_states: List[torch.Tensor] = [] - self.k_cache: List[torch.Tensor] = [] - self.v_cache: List[torch.Tensor] = [] - self.window_size = window_size - - def update(self, key_states: torch.Tensor, value_states: torch.Tensor, - layer_idx: Optional[int] = None, cache_kwargs: Optional[any] = None, - accumulate_in_fp32: bool = False, - fmap_key_states: torch.Tensor = None, # should not be None - grad_enabled: bool = False, - **kwargs: any, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Update KV, K states; and KV cache during training - - For decoding, use `self.decode_kv_states` to keep track of KV states - up to sliding window terms - - For (chunked) training, use `self.kv_states` to keep track of KV states - up to end of sequence - - Likewise for `self.decode_k_states` and `self.k_states` - """ - with torch.set_grad_enabled(grad_enabled): - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - dtype = key_states.dtype - # if accumulate_in_fp32: - # # key_states = key_states.float() - # fmap_key_states = fmap_key_states.float() - # value_states = value_states.float() - - # Decoding KV state (KV terms up to last window_size) - decode_kv_state = torch.einsum( - 'bhlf,bhld->bhfd', - fmap_key_states[:, :, :-self.window_size].float(), - value_states[:, :, :-self.window_size].float() - ).to(dtype) - # KV state - kv_state = decode_kv_state + torch.einsum( - 'bhlf,bhld->bhfd', - fmap_key_states[:, :, -self.window_size:].float(), - value_states[:, :, -self.window_size:].float() - ).to(dtype) - # shape is b, h, 1, f; note the 1 - decode_k_state = ( - fmap_key_states[:, :, :-self.window_size].float().sum(dim=-2, keepdim=True).to(dtype) - ) - k_state = ( - decode_k_state + - fmap_key_states[:, :, -self.window_size:].float().sum(dim=-2, keepdim=True).to(dtype) - ) - # Update the cache - if len(self.k_states) <= layer_idx: # Initializing kv and k states - self.kv_states.append(kv_state.to(dtype)) - self.k_states.append(k_state.to(dtype)) - - self.decode_kv_states.append(decode_kv_state.to(dtype)) - self.decode_k_states.append(decode_k_state.to(dtype)) - - self.k_cache.append(key_states[:, :, -self.window_size:, :]) - self.v_cache.append(value_states[:, :, -self.window_size:, :].to(dtype)) - # self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2]) - else: - # Update kv and k states recurrently - kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype) - k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype) - self.kv_states[layer_idx] = kv_state - self.k_states[layer_idx] = k_state - - decode_kv_state = (self.decode_kv_states[layer_idx].to(kv_state.dtype) - + decode_kv_state).to(dtype) - decode_k_state = (self.decode_k_states[layer_idx].to(kv_state.dtype) - + decode_k_state).to(dtype) - self.decode_kv_states[layer_idx] = decode_kv_state - self.decode_k_states[layer_idx] = decode_k_state - - self.k_cache[layer_idx] = key_states[:, :, -self.window_size:, :] - self.v_cache[layer_idx] = value_states[:, :, -self.window_size:, :] - self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2] - - return self.kv_states[layer_idx], self.k_states[layer_idx] - - def update_for_decoding(self, keys: torch.Tensor, values: torch.Tensor, - layer_idx: int, feature_map_k: Callable, dtype: torch.dtype): - """ - Update the decoding KV and K states, and KV cache, during decodeing - """ - with torch.no_grad(): - k_cache = self.k_cache[layer_idx] - v_cache = self.v_cache[layer_idx] - k_state = feature_map_k(k_cache[:, :, :1, :]) - v_state = v_cache[:, :, :1, :] - # try: - # k_state = feature_map_k(k_cache[:, :, :1, :]) - # except Exception as e: - # print(f'layer_idx:', layer_idx) - # print(f'k_cache.shape:', k_cache.shape) - # print(f'v_cache.shape:', v_cache.shape) - # breakpoint() - kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d - self.decode_kv_states[layer_idx] += kv_state - self.decode_k_states[layer_idx] += k_state - - self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2) - self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], values], dim=-2) - - if layer_idx == 0: - self._seen_tokens += keys.shape[-2] - self._seen_tokens_by_layer[layer_idx] += keys.shape[-2] - return (self.k_cache[layer_idx], self.v_cache[layer_idx], - self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx]) \ No newline at end of file diff --git a/src/model/linear_attention/linear_window_attention_tk_fa2.py b/src/model/linear_attention/linear_window_attention_tk_fa2.py deleted file mode 100644 index 67def6d..0000000 --- a/src/model/linear_attention/linear_window_attention_tk_fa2.py +++ /dev/null @@ -1,110 +0,0 @@ -""" -Use FlashAttention2 to compute ground-truth attention outputs -""" -from typing import Optional, Tuple -import torch -import torch.nn.functional as F - -from transformers.cache_utils import Cache -try: - from transformers.modeling_flash_attention_utils import _flash_attention_forward -except ModuleNotFoundError: - _flash_attention_forward = None # Transformers v4.36 - -from src.model.rotary import apply_rotary_pos_emb -from .linear_window_attention_tk_long import flash_attention_2 -from .linear_window_attention_tk import LolcatsTKWindowAttention - - -class LolcatsTKWindowAttentionFA2(LolcatsTKWindowAttention): - """ - Lolcats attention combining sliding window and linear attention - """ - def __init__(self, remove_base_attn=False, **kwargs): - # keep self.base_attn for Flash Attention inference - super().__init__(remove_base_attn=False, **kwargs) - - def forward(self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Forward pass with the option to compute attention weights multiple ways - if self.train_attention is True - -> Consistent with HuggingFace Transformers for easy use with their pretrained models - """ - b, l, _ = hidden_states.size() - if self.train_attention: - with torch.no_grad(): - _y_true = flash_attention_2(self.base_attn, - hidden_states=hidden_states, - attention_mask=None, - position_ids=position_ids, - past_key_value=None, - output_attentions=False, - # output_hidden_states=False, - use_cache=False)[0] - # _y_true.shape is (batch_size, seq_len, num_heads, head_dim) - y_true = _y_true.reshape(b, l, -1).contiguous() - y_true = self.o_proj(y_true) - else: - y_true = None - - q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask, - position_ids, past_key_value) - f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) - - # attention_mask = None # For now this is always True - if past_key_value is None: # Regular training - window_factors = F.sigmoid(self.window_factors) - linear_factors = 1 - window_factors if self.affine_attention_factors else 1 - y_pred, a_pred = self.quadratic_attention(q, k, f_q, f_k, v, - window_factors, linear_factors, - window_size=self.window_size,) - else: - past_key_value.window_size = self.decode_window_size - if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating - assert use_cache is True - _kv = past_key_value.update_for_decoding(k, v, self.layer_idx, - self.feature_map_k, - dtype=q.dtype) - k_cache, v_cache, kv_state, k_state = _kv - - # Sliding window + linear attention decode - window_factors = F.sigmoid(self.window_factors) - linear_factors = 1 - window_factors if self.affine_attention_factors else 1 - - a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5) - # a_sm = torch.softmax(a_sm, dim=-1) - a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) - a_sm = window_factors * torch.exp(a_sm - a_sm_max) - sum_sm = a_sm.sum(dim=-1, keepdim=True) - - y_pred = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float()) - + linear_factors * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float())) - sum_ln = linear_factors * torch.einsum('bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None] - y_pred = (y_pred / (sum_sm + sum_ln)).to(q.dtype) - - else: - window_factors = F.sigmoid(self.window_factors) - linear_factors = 1 - window_factors if self.affine_attention_factors else 1 - y_pred, a_pred = self.quadratic_attention(q, k, f_q, f_k, v, - window_factors, linear_factors, - window_size=self.window_size, - kv_state=None, - k_state=None,) - - # Concatenate heads and apply output projection - _y_pred = y_pred.transpose(1, 2).contiguous() - y_pred = self.o_proj(_y_pred.view(b, l, self.hidden_size)) - - if self.train_attention: - attn_weights = (None, (_y_pred, _y_true)) # flash_attn outputs are shape (b, l, h, d) - else: - attn_weights = None - return y_pred, attn_weights, past_key_value diff --git a/src/model/linear_attention/linear_window_attention_tk_fp32.py b/src/model/linear_attention/linear_window_attention_tk_fp32.py deleted file mode 100644 index 04232de..0000000 --- a/src/model/linear_attention/linear_window_attention_tk_fp32.py +++ /dev/null @@ -1,338 +0,0 @@ -""" -Subquadratic attention combining sliding window and linear attentions -- Using the TK "terracing" arrangement - -For each layer: -- We first compute (softmax) attention over sliding windows -- We then compute standard linear attention to "fill in" the earlier parts -- We combine to model the entire sequence -""" -from typing import List, Tuple, Optional, Callable -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - -from transformers.cache_utils import Cache - -from .linear_attention import ( - LolcatsLinearAttention, LinearAttentionState, softmax_attention -) - -# ---------------------- -# Sliding window helpers -# ---------------------- -def get_masks(window_size: int, q_len: int, k_len: int, - device: torch.device) -> tuple[torch.Tensor]: - """ - Return masks for softmax and linear attention terms - -> 1 is include, 0 is ignore - """ - kwargs = {'device': device, 'dtype': int} - l = window_size - m = math.ceil(max(q_len, k_len) / window_size) - # Creates an n x n mask where n = window_size^2 - mask = torch.block_diag(*[torch.ones((l, l), )] * m) - mask += torch.roll(mask, -l, -1) # this adds the terracing - if mask.shape[0] > q_len: - mask = mask[-q_len:] - if mask.shape[1] > k_len: - mask = mask[:, -k_len:] - # Return softmax mask (window), linear attention mask - mask = mask[None, None, ...] # b, h, q_len, k_len - return torch.tril(mask).to(**kwargs), torch.tril(1 - mask).to(**kwargs) - - -def hybrid_attention_quadratic(q: torch.Tensor, k: torch.Tensor, - f_q: torch.Tensor, f_k: torch.Tensor, - v: torch.Tensor, - window_factor: torch.Tensor, - linear_factor: torch.Tensor, - window_size: int, - kv_state: torch.Tensor = None, - k_state: torch.Tensor = None, - eps: float = 1e-12, - mask_value: float = -1e12) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Hybrid attention combining sliding window and linear attentions - """ - - mask_window, mask_linear = get_masks(window_size, q.shape[-2], k.shape[-2], q.device) - - # 1. Sliding window (softmax attention) - a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k.float()) * (k.shape[-1] ** -0.5) - a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value) - # torch.softmax(a_sm, dim=-1), but we account for the max when combining - a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) - a_sm = window_factor * torch.exp(a_sm - a_sm_max).to(q.dtype) - sum_sm = a_sm.sum(dim=-1, keepdim=True) - - # 2. Under window (linear attention) - a_ln = torch.einsum('bhmd,bhnd->bhmn', f_q.float(), f_k.float()).to(q.dtype) - a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0) - sum_ln = a_ln.float().sum(dim=-1, keepdim=True).to(q.dtype) - # sum_ln = a_ln.sum(dim=-1, keepdim=True) - - # 3. Combine - a = ((a_sm + a_ln) / (sum_sm + sum_ln)) # .to(q.dtype) # Save attention weights - breakpoint() - # Allow outputs to also depend on prior kv_state and k_state - y = torch.einsum('bhmn,bhnd->bhmd', a_sm + a_ln, v) # .float()) - if kv_state is not None: # Combine with prior kv_state and k_state - y += linear_factor * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float()).to(q.dtype) - sum_ln += linear_factor * torch.einsum( - 'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None].to(q.dtype) - y = (y / (sum_sm + sum_ln)) # .to(q.dtype) - return y, a # attention weights only for the last chunk - - -# --------------------- -# Attention layer class -# --------------------- -class LolcatsTKWindowAttention(LolcatsLinearAttention): - """ - Lolcats attention combining sliding window and linear attention - """ - def __init__(self, - window_size: int = 64, - decode_window_size: int = None, - affine_attention_factors: bool = False, - init_window_factor: float = 0, - state_grad_enabled: bool = False, - **kwargs): - self.window_size = window_size - self.decode_window_size = ( - decode_window_size if decode_window_size is not None else window_size - ) - self.window_kwargs = {'dimension': 2, 'size': window_size, 'step': 1} - super().__init__(**kwargs) - self.attention_type = kwargs['attention_type'] # 'hedgehog_llama_window_tk' - # Determine how we compute attentions - self.quadratic_attention = hybrid_attention_quadratic - self.attention_type = kwargs['attention_type'] # 'hedgehog_long_llama_window_tk' - # Learnable factor for combining attentions - self.affine_attention_factors = affine_attention_factors - device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype - self.window_factors = nn.Parameter( - init_window_factor * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)) - # Whether we use original flash attention 2 inference (use during attention transfer) - self.base_inference = False - self.state_grad_enabled = state_grad_enabled - - def forward(self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Forward pass with the option to compute attention weights multiple ways - if self.train_attention is True - -> Consistent with HuggingFace Transformers for easy use with their pretrained models - """ - b, l, _ = hidden_states.size() - q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask, - position_ids, past_key_value) - f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) # Have to do after repeat for grouped-query attn if we use same fmap - - if self.train_attention: - # 1. Compute "ground-truth" attention output and weights - with torch.no_grad(): - _y_true, a_true = softmax_attention(q, k, v)[:2] - y_true = _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) - y_true = self.o_proj(y_true) - - # 2. Compute "predicted" attention outputs - # compute attn weights under sliding window - window_factors = F.sigmoid(self.window_factors) - linear_factors = 1 - window_factors if self.affine_attention_factors else 1 - y_pred, a_pred = self.quadratic_attention(q, k, f_q, f_k, v, - window_factors, linear_factors, - window_size=self.window_size) - attn_weights = ((a_pred, a_true), (y_pred, _y_true)) - else: - attn_weights = None - # attention_mask = None # For now this is always True - if past_key_value is None: # Regular training - window_factors = F.sigmoid(self.window_factors) - linear_factors = 1 - window_factors if self.affine_attention_factors else 1 - y_true, a_pred = self.quadratic_attention(q, k, f_q, f_k, v, - window_factors, linear_factors, - window_size=self.window_size) - else: - past_key_value.window_size = self.decode_window_size - if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating - assert use_cache is True - _kv = past_key_value.update_for_decoding(k, v, self.layer_idx, - self.feature_map_k, - dtype=q.dtype) - k_cache, v_cache, kv_state, k_state = _kv - - # Sliding window + linear attention decode - window_factors = F.sigmoid(self.window_factors) - linear_factors = 1 - window_factors if self.affine_attention_factors else 1 - - # Softmax attention terms - a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5) - a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) - a_sm = window_factors * torch.exp(a_sm - a_sm_max) - sum_sm = a_sm.sum(dim=-1, keepdim=True) - - # Combine with linear attention terms - y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float()) - + linear_factors * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float())) - sum_ln = linear_factors * torch.einsum( - 'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None] - y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype) - - else: # Stateful training - try: - kv_state = past_key_value.kv_states[self.layer_idx] - k_state = past_key_value.k_states[self.layer_idx] - except IndexError: - kv_state, k_state = None, None - window_factors = F.sigmoid(self.window_factors) - linear_factors = 1 - window_factors if self.affine_attention_factors else 1 - y_true, _ = self.quadratic_attention(q, k, f_q, f_k, v, - window_factors, linear_factors, - window_size=self.window_size, - kv_state=kv_state, - k_state=k_state) - # Save and update KV cache and states - # past_key_value.update(k, v.detach(), self.layer_idx, - # fmap_key_states=f_k.detach(), - # accumulate_in_fp32=True) - past_key_value.update(k, v, self.layer_idx, - fmap_key_states=f_k, - accumulate_in_fp32=True) - # Concatenate heads and apply output projection - y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) - y_true = self.o_proj(y_true) - return y_true, attn_weights, past_key_value - - -class LinearAttentionTKWindowCache(LinearAttentionState): - """ - Class for `past_key_values` - -> Alternative to KV cache; here we only maintain a "KV state" and "K state" - -> Modified from transformers.cache_utils.DynamicCache (v4.36) - """ - def __init__(self, window_size: int = 64) -> None: - super().__init__() - self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36 - self._seen_tokens_by_layer: List[int] = [] - self.kv_states: List[torch.Tensor] = [] - self.k_states: List[torch.Tensor] = [] - - # Account for sliding windows - self.decode_kv_states: List[torch.Tensor] = [] - self.decode_k_states: List[torch.Tensor] = [] - self.k_cache: List[torch.Tensor] = [] - self.v_cache: List[torch.Tensor] = [] - self.window_size = window_size - - def update(self, key_states: torch.Tensor, value_states: torch.Tensor, - layer_idx: Optional[int] = None, cache_kwargs: Optional[any] = None, - accumulate_in_fp32: bool = False, - fmap_key_states: torch.Tensor = None, # should not be None - grad_enabled: bool = False, - **kwargs: any, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Update KV, K states; and KV cache during training - - For decoding, use `self.decode_kv_states` to keep track of KV states - up to sliding window terms - - For (chunked) training, use `self.kv_states` to keep track of KV states - up to end of sequence - - Likewise for `self.decode_k_states` and `self.k_states` - """ - with torch.set_grad_enabled(grad_enabled): - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - dtype = key_states.dtype - if accumulate_in_fp32: - # key_states = key_states.float() - fmap_key_states = fmap_key_states.float() - value_states = value_states.float() - - # Decoding KV state (KV terms up to last window_size) - decode_kv_state = torch.einsum( - 'bhlf,bhld->bhfd', - fmap_key_states[:, :, :-self.window_size], - value_states[:, :, :-self.window_size] - ) - # KV state - kv_state = decode_kv_state + torch.einsum( - 'bhlf,bhld->bhfd', - fmap_key_states[:, :, -self.window_size:], - value_states[:, :, -self.window_size:] - ) - # shape is b, h, 1, f; note the 1 - decode_k_state = fmap_key_states[:, :, :-self.window_size].sum(dim=-2, keepdim=True) - k_state = (decode_k_state + - fmap_key_states[:, :, -self.window_size:].sum(dim=-2, keepdim=True)) - - # Update the cache - if len(self.k_states) <= layer_idx: # Initializing kv and k states - self.kv_states.append(kv_state.to(dtype)) - self.k_states.append(k_state.to(dtype)) - - self.decode_kv_states.append(decode_kv_state.to(dtype)) - self.decode_k_states.append(decode_k_state.to(dtype)) - - self.k_cache.append(key_states[:, :, -self.window_size:, :]) - self.v_cache.append(value_states[:, :, -self.window_size:, :].to(dtype)) - # self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2]) - else: - # Update kv and k states recurrently - kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(dtype) - k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(dtype) - self.kv_states[layer_idx] = kv_state - self.k_states[layer_idx] = k_state - - decode_kv_state = (self.decode_kv_states[layer_idx].to(kv_state.dtype) - + decode_kv_state).to(dtype) - decode_k_state = (self.decode_k_states[layer_idx].to(kv_state.dtype) - + decode_k_state).to(dtype) - self.decode_kv_states[layer_idx] = decode_kv_state - self.decode_k_states[layer_idx] = decode_k_state - - self.k_cache[layer_idx] = key_states[:, :, -self.window_size:, :] - self.v_cache[layer_idx] = value_states[:, :, -self.window_size:, :] - self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2] - - return self.kv_states[layer_idx], self.k_states[layer_idx] - - def update_for_decoding(self, keys: torch.Tensor, values: torch.Tensor, - layer_idx: int, feature_map_k: Callable, dtype: torch.dtype): - """ - Update the decoding KV and K states, and KV cache, during decodeing - """ - with torch.no_grad(): - k_cache = self.k_cache[layer_idx] - v_cache = self.v_cache[layer_idx] - k_state = feature_map_k(k_cache[:, :, :1, :]) - v_state = v_cache[:, :, :1, :] - # try: - # k_state = feature_map_k(k_cache[:, :, :1, :]) - # except Exception as e: - # print(f'layer_idx:', layer_idx) - # print(f'k_cache.shape:', k_cache.shape) - # print(f'v_cache.shape:', v_cache.shape) - # breakpoint() - kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(dtype) # b, h, f, d - self.decode_kv_states[layer_idx] += kv_state - self.decode_k_states[layer_idx] += k_state - - self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], keys], dim=-2) - self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], values], dim=-2) - - if layer_idx == 0: - self._seen_tokens += keys.shape[-2] - self._seen_tokens_by_layer[layer_idx] += keys.shape[-2] - return (self.k_cache[layer_idx], self.v_cache[layer_idx], - self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx]) \ No newline at end of file diff --git a/src/model/linear_attention/linear_window_attention_tk_sdpa.py b/src/model/linear_attention/linear_window_attention_tk_sdpa.py deleted file mode 100644 index 7d29360..0000000 --- a/src/model/linear_attention/linear_window_attention_tk_sdpa.py +++ /dev/null @@ -1,106 +0,0 @@ -""" -Use HF SDPA to compute ground-truth attention outputs -""" -from typing import Optional, Tuple -import torch -import torch.nn.functional as F - -from transformers.cache_utils import Cache -try: - from transformers.modeling_flash_attention_utils import _flash_attention_forward -except ModuleNotFoundError: - _flash_attention_forward = None # Transformers v4.36 - -from src.model.rotary import apply_rotary_pos_emb -from .linear_window_attention_tk import LolcatsTKWindowAttention - - -class LolcatsTKWindowAttentionSDPA(LolcatsTKWindowAttention): - """ - Lolcats attention combining sliding window and linear attention - """ - def __init__(self, remove_base_attn=False, **kwargs): - # keep self.base_attn for SDPA inference - super().__init__(remove_base_attn=False, **kwargs) - - def forward(self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """ - Forward pass with the option to compute attention weights multiple ways - if self.train_attention is True - -> Consistent with HuggingFace Transformers for easy use with their pretrained models - """ - b, l, _ = hidden_states.size() - if self.train_attention: - with torch.no_grad(): - _y_true = self.base_attn(hidden_states=hidden_states, - attention_mask=None, - position_ids=position_ids, - past_key_value=None, - output_attentions=False, - # output_hidden_states=False, - use_cache=False)[0] - # _y_true.shape is (batch_size, num_heads,seq_len, head_dim) - y_true = _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) - y_true = self.o_proj(y_true) - else: - y_true = None - - q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask, - position_ids, past_key_value) - f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) - - # attention_mask = None # For now this is always True - if past_key_value is None: # Regular training - window_factors = F.sigmoid(self.window_factors) - linear_factors = 1 - window_factors if self.affine_attention_factors else 1 - y_pred, a_pred = self.quadratic_attention(q, k, f_q, f_k, v, - window_factors, linear_factors, - window_size=self.window_size,) - else: - past_key_value.window_size = self.decode_window_size - if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating - assert use_cache is True - _kv = past_key_value.update_for_decoding(k, v, self.layer_idx, - self.feature_map_k, - dtype=q.dtype) - k_cache, v_cache, kv_state, k_state = _kv - - # Sliding window + linear attention decode - window_factors = F.sigmoid(self.window_factors) - linear_factors = 1 - window_factors if self.affine_attention_factors else 1 - - a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5) - # a_sm = torch.softmax(a_sm, dim=-1) - a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) - a_sm = window_factors * torch.exp(a_sm - a_sm_max) - sum_sm = a_sm.sum(dim=-1, keepdim=True) - - y_pred = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float()) - + linear_factors * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float())) - sum_ln = linear_factors * torch.einsum('bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None] - y_pred = (y_pred / (sum_sm + sum_ln)).to(q.dtype) - - else: - window_factors = F.sigmoid(self.window_factors) - linear_factors = 1 - window_factors if self.affine_attention_factors else 1 - _y_pred, a_pred = self.quadratic_attention(q, k, f_q, f_k, v, - window_factors, linear_factors, - window_size=self.window_size, - kv_state=None, - k_state=None,) - - # Concatenate heads and apply output projection - y_pred = self.o_proj(_y_pred.transpose(1, 2).contiguous().view(b, l, self.hidden_size)) - if self.train_attention: - attn_weights = (None, (_y_pred, _y_true)) # flash_attn outputs are shape (b, l, h, d) - else: - attn_weights = None - return y_pred, attn_weights, past_key_value diff --git a/src/model/load_model.py b/src/model/load_model.py index 6b523a1..6f7b978 100644 --- a/src/model/load_model.py +++ b/src/model/load_model.py @@ -41,7 +41,7 @@ def load_and_convert_attns(model: nn.Module, peft_key = 'peft' # inconsistency across configs... why do this to myself if 'peft_config' in model_config['attention']: peft_key = 'peft_config' - if peft_key in model_config['attention']: # and not prior_loras: + if peft_key in model_config['attention']: peft_config = model_config['attention'][peft_key] model, peft_config = create_peft_config(model, peft_config, model_config['model']['torch_dtype'], @@ -117,7 +117,19 @@ def load_and_convert_finetune(model: nn.Module, p.requires_grad = True else: for p in model.parameters(): - p.requires_grad = True + p.requires_grad = False + # Keep specified weights trainable + if 'trainable_weights' in finetune_config.finetune: + for name in finetune_config.finetune['trainable_weights']: + for n, p in model.named_parameters(): + if name in n: + if 'layers_to_ignore' in finetune_config.finetune: + layer = int(n.split('layers.')[-1].split('.')[0]) + if layer not in finetune_config.finetune['layers_to_ignore']: + p.requires_grad = True + else: + p.requires_grad = True + # Load weights if checkpoint_path: diff --git a/src/model/load_model_for_eval.py b/src/model/load_model_for_eval.py index be2ff59..227cd95 100644 --- a/src/model/load_model_for_eval.py +++ b/src/model/load_model_for_eval.py @@ -94,17 +94,10 @@ def get_lm_eval_model(model_kwargs: dict, # model_loader.loading_kwargs print('-> Loading as lm-evaluation-harness model') if hedgehog_model: - # from lm_eval_harness.models import HedgehogLlamaForCausalLM if 'mistral' in lm_kwargs['pretrained']: - if long_model: - from lm_eval_harness.models import LooooolcatsMistralForCausalLM as ModelClass - else: - from lm_eval_harness.models import LolcatsMistralForCausalLM as ModelClass + from lm_eval_harness.models import LolcatsMistralForCausalLM as ModelClass else: - if long_model: - from lm_eval_harness.models import LooooolcatsLlamaForCausalLM as ModelClass - else: - from lm_eval_harness.models import LolcatsLlamaForCausalLM as ModelClass + from lm_eval_harness.models import LolcatsLlamaForCausalLM as ModelClass lm = ModelClass.create_from_arg_string('', lm_kwargs) else: sys.path.append(path_to_lm_eval_harness) @@ -180,11 +173,6 @@ def load_model_from_checkpoint(attn_mlp_checkpoint_path: str = None, if profile_model: model_config['attention']['attention_type'] += '_profile' - if 'long' in model_config['attention']['attention_type']: - long_model = True - else: - long_model = False - if finetune_checkpoint_path is not None: finetune_config = finetune_checkpoint_path.split('-f=')[-1].split('-')[0] finetune_config_path = join(config_dir, 'experiment', f'{finetune_config}.yaml') @@ -207,15 +195,14 @@ def load_model_from_checkpoint(attn_mlp_checkpoint_path: str = None, if lm_eval_model and attn_mlp_checkpoint_path is not None: lm = get_lm_eval_model(model_loader.loading_kwargs, path_to_lm_eval_harness, - hedgehog_model=True, long_model=long_model) + hedgehog_model=True) model = lm.model # Do this way because we call the larger object - elif lm_eval_model: # Instantiate as lm_eval.base.LM object + elif lm_eval_model: # Instantiate as lm_eval.base.LM object lm = get_lm_eval_model(model_loader.loading_kwargs, path_to_lm_eval_harness) model = lm.model elif attn_mlp_checkpoint_path is None: model = model_loader.load() else: - # model = model_loader.load(model_type='hedgehog_llama') model = model_loader.load(model_type=model_config['attention']['attention_type']) try: model.state_chunk_len = model_config['attention']['state_chunk_len'] diff --git a/src/model/modeling_llama.py b/src/model/modeling_llama.py index 5a0868e..03c4789 100644 --- a/src/model/modeling_llama.py +++ b/src/model/modeling_llama.py @@ -246,8 +246,6 @@ def forward( Forward pass where we chunk inputs """ self.generating = False - # assert output_attentions is False - # assert use_cache is True if use_cache is not True: use_cache = True @@ -274,7 +272,6 @@ def forward( # Determine and setup our KV cache or state attention_type = getattr(self.model.layers[0].self_attn, 'attention_type', None) past_key_values = get_attention_cache(attention_type) - # print(f'-> attention_type:', attention_type) # Split inputs into chunks, and do linear attention over each (passing the states) input_ids = torch.split(input_ids, self.state_chunk_len, dim=-1) @@ -283,7 +280,6 @@ def forward( all_logits = [] # save these for _idx, _input_ids in enumerate(input_ids): - # labels = copy.deepcopy(_input_ids) if self.training: print(f'Model processing _input_ids.shape:', _input_ids.shape) outputs = super().forward(_input_ids, None, diff --git a/src/model/modeling_llama_sharded.py b/src/model/modeling_llama_sharded.py index 8398e28..45fe04f 100644 --- a/src/model/modeling_llama_sharded.py +++ b/src/model/modeling_llama_sharded.py @@ -130,8 +130,6 @@ def forward( position_ids = position_ids.to(device) if attention_mask is not None: attention_mask = attention_mask.to(device) - # if past_key_values is not None: - # past_key_values = past_key_values.to(device) if output_hidden_states: all_hidden_states += (hidden_states,) diff --git a/src/model/modeling_mistral.py b/src/model/modeling_mistral.py index cc44576..3b5fd02 100644 --- a/src/model/modeling_mistral.py +++ b/src/model/modeling_mistral.py @@ -94,8 +94,6 @@ def forward( Forward pass where we chunk inputs """ self.generating = False - # assert output_attentions is False - # assert use_cache is True if use_cache is not True: use_cache = True @@ -135,8 +133,6 @@ def forward( all_logits = [] # save these for _idx, _input_ids in enumerate(input_ids): - # print(f'Model processing _input_ids.shape:', _input_ids.shape) - # labels = copy.deepcopy(_input_ids) outputs = super().forward(_input_ids, None, position_ids[_idx] if position_ids is not None else None, past_key_values, inputs_embeds, diff --git a/src/model/pretrained.py b/src/model/pretrained.py index 6a7f630..dc0ebe3 100644 --- a/src/model/pretrained.py +++ b/src/model/pretrained.py @@ -26,19 +26,8 @@ def get_pretrained_loader(pretrained_model_name_or_path: str, huggingface_token=huggingface_token, **model_kwargs, ) - # elif 'ixtral' in pretrained_model_name_or_path: # Mixtral or mixtral - # return PretrainedMixtralLoader( - # pretrained_model_name_or_path=pretrained_model_name_or_path, - # huggingface_token=huggingface_token, - # **model_kwargs, - # ) elif 'istral' in pretrained_model_name_or_path: # Mistral or mistral; - # after Mixtral to avoid triggering on 'mistral/Mixtral' path - # if transformers.__version__ == '4.43.0': - # _loader = PretrainedLlamaLoader - # else: - _loader = PretrainedMistralLoader - return _loader( # return PretrainedLlamaLoader( + return PretrainedMistralLoader( pretrained_model_name_or_path=pretrained_model_name_or_path, huggingface_token=huggingface_token, **model_kwargs, diff --git a/src/trainer/default_lm.py b/src/trainer/default_lm.py index a5cac39..89ecb0c 100644 --- a/src/trainer/default_lm.py +++ b/src/trainer/default_lm.py @@ -21,7 +21,7 @@ class OurTrainer(): """ Basic parent trainer class. Defaults to language modeling. - -> Replacement for HuggingFace Trainer + -> Replacement for Hugging Face Trainer """ def __init__(self, model: nn.Module, @@ -49,6 +49,7 @@ def __init__(self, max_eval_batches: int = -1, print_samples: bool = False, initial_eval: bool = True, + num_save_ckpt_steps: int = 1000, **kwargs: any): super().__init__() self.model = model @@ -88,6 +89,7 @@ def __init__(self, self.max_eval_batches = max_eval_batches self.print_samples = print_samples self.initial_eval = initial_eval + self.num_save_ckpt_steps = num_save_ckpt_steps # Saving metrics self.train_metrics = {'train/loss': None, @@ -202,16 +204,21 @@ def train_step(self, model: nn.Module, epoch: int) -> nn.Module: if self.evaluation_strategy == 'steps': if (self.grad_step % self.eval_steps == 0 and self.grad_step > 0 and not eval_for_step): - # self.optimizer.zero_grad() # Clear out grads before eval _eval_metrics = self.eval_step(model, step=self.grad_step) - # torch.cuda.empty_cache() print(f'Grad Step {self.grad_step} eval metrics:', _eval_metrics) eval_for_step = True model.train() # Need to set back to train mode + elif self.grad_step == 0 and self.num_save_ckpt_steps < 1000 and not eval_for_step: # hack for micros + _eval_metrics = self.eval_step(model, step=self.grad_step) + print(f'Grad Step {self.grad_step} eval metrics:', _eval_metrics) + eval_for_step = True + model.train() # Need to set back to train mode + elif self.grad_step % self.eval_steps == 0 and self.grad_step > 0 and eval_for_step: pass else: - eval_for_step = False + if self.grad_step > 0: + eval_for_step = False if self.grad_step == self.max_steps: early_stopping = True return model, early_stopping @@ -253,10 +260,9 @@ def eval_step(self, model: nn.Module, step: int = None, 'step': self.grad_step, self.metric_for_best_model: val_metric }, self.best_val_checkpoint_path) - # model.to(self.device) print(f'\n-> Saved best model checkpoint to: {self.best_val_checkpoint_path}!') - if self.grad_step % 1000 == 0: + if self.grad_step % self.num_save_ckpt_steps == 0: save_path = self.best_val_checkpoint_path.replace('.pt', f'_{self.grad_step}.pt') torch.save({ 'model_state_dict': self.save_trainable_weights(model), @@ -283,7 +289,6 @@ def compute_eval_metrics(self, desc=f'Evaluating at step {step}') model.eval() - # model.to(self.device) step_loss = 0 step_eval_metrics = {} with torch.no_grad(): @@ -333,7 +338,7 @@ def compute_loss(self, model: nn.Module, data: torch.Tensor, inputs = {k: v.to(model.device) for k, v in data.items() if k in input_keys} - outputs = model(**inputs, output_attentions=False, use_cache=False) # use_cache=False) + outputs = model(**inputs, output_attentions=False, use_cache=False) outputs = outputs.get('logits')[..., :-1, :].contiguous() targets = data.get('labels')[..., 1:].contiguous() @@ -398,4 +403,4 @@ def init_checkpointing(self, self.metric_for_best_model = self.metric_for_best_model if self.metric_for_best_model is not None: if 'eval' not in self.metric_for_best_model: - self.metric_for_best_model = f'eval/{self.metric_for_best_model}' \ No newline at end of file + self.metric_for_best_model = f'eval/{self.metric_for_best_model}' diff --git a/src/trainer/default_lm_chunked.py b/src/trainer/default_lm_chunked.py deleted file mode 100644 index e9cc747..0000000 --- a/src/trainer/default_lm_chunked.py +++ /dev/null @@ -1,130 +0,0 @@ -""" -Custom trainer class for training models with "chunked" linear attention - -Lets us train like an RNN to process long sequences with fixed memory: -- Use linear attention's recurrent view to process a long sequence - as a set of non-overlapping chunks -- At the end of each chunk, pass the computed KV and K states to initialize - the states for the next chunk -- Accumulate gradients with loss over each chunk -""" -import torch -import torch.nn as nn - -from tqdm import tqdm - -from src.model.modeling_llama import get_attention_cache -from src.model.convert_model import traverse_layers - -from .default_lm import OurTrainer as OurDefaultTrainer -from .utils import decode_samples - - -class OurTrainer(OurDefaultTrainer): - """ - Custom trainer class for training models with "chunked" linear attention - - Lets us train like an RNN to process long sequences with fixed memory: - - Use linear attention's recurrent view to process a long sequence - as a set of non-overlapping chunks - - At the end of each chunk, pass the computed KV and K states to initialize - the states for the next chunk - - Accumulate gradients with loss over each chunk - """ - def __init__(self, model, **kwargs: any): - assert ( - getattr(model, 'state_chunk_len', None) is not None - ), "model must have a `state_chunk_len` attribute" - super().__init__(model=model, **kwargs) - self.criterion = nn.CrossEntropyLoss(reduction='mean') - self.tokenizer = getattr(self.train_loader.dataset, 'tokenizer', None) - self.compute_loss_backprop = True # Whether we backprop in self.compute_loss - self.initial_eval = False # Whether to evaluate before training - - def compute_loss(self, model: nn.Module, data: dict[torch.Tensor], - sample_idx: int = None, **kwargs: any, - ) -> tuple[torch.Tensor, dict[any]]: - """ - Compute loss over sample as a sequence of chunks - - model should have a `state_chunk_len` attribute and a `chunk_forward` - method, where `state_chunk_len` defines the chunk size - - Args: - - model: nn.Module, HF model to train - - data: dict[torch.Tensor], HF datasets batch of data - - sample_idx: int, index of batch in dataset - """ - chunk_metrics = {} - input_seq_len = data['input_ids'].shape[-1] - total_seq_len = 0 - loss = 0 - - # Get KV state object for model as `past_key_values` - layers = traverse_layers(model) - attention_type = getattr(layers[0].self_attn, 'attention_type', None) - past_key_values = get_attention_cache(attention_type) - - # Chunk original input sequences; assume single batch for now - input_ids = torch.split(data['input_ids'], model.state_chunk_len, dim=-1) - labels = data['labels'] if 'labels' in data else data['input_ids'] - labels = torch.split(labels, model.state_chunk_len, dim=-1) - - pbar = tqdm(input_ids, desc=f'Processing chunk 0 | state len: {model.state_chunk_len} (token {total_seq_len} / {input_seq_len})', leave=False) - - for chunk_idx, chunk_input_ids in enumerate(pbar): - try: - outputs = model.chunk_forward(input_ids=chunk_input_ids.to(self.device), - attention_mask=None, - position_ids=None, - past_key_values=past_key_values, - inputs_embeds=None, - use_cache=True, - output_attentions=False, - output_hidden_states=False, - return_dict=True) - except Exception as e: - raise e - - past_key_values = outputs.past_key_values - outputs = outputs.get('logits')[..., :-1, :].contiguous() - targets = labels[chunk_idx][..., 1:].contiguous() - - if ((targets != -100).sum() > 0 and self.tokenizer is not None and - sample_idx is not None and (sample_idx + 1) % 100 == 0): - decode_samples(outputs.cpu(), targets.cpu(), self.tokenizer, sample_idx) - - outputs = outputs.view(-1, outputs.shape[-1]) - targets = targets.view(-1).to(outputs.device) - - if (targets != -100).sum() == 0: # Chunk contains only padding or prompts - chunk_metrics[f'loss_{chunk_idx}'] = 0 - chunk_metrics[f'ppl_{chunk_idx}'] = 1 # torch.exp(_loss).item(), or -1? - else: - try: - _loss = self.criterion(outputs, targets) - except Exception as e: - print(e) - breakpoint() - - # Accumulate gradients over chunks - if model.training: - with torch.autograd.set_detect_anomaly(True): - _loss.backward() - - chunk_metrics[f'loss_{chunk_idx}'] = _loss.item() - chunk_metrics[f'ppl_{chunk_idx}'] = torch.exp(_loss).item() - loss += chunk_metrics[f'loss_{chunk_idx}'] / len(input_ids) - total_seq_len += chunk_input_ids.shape[-1] - desc=f'Processing chunk {chunk_idx + 1} | state len: {model.state_chunk_len} (token {total_seq_len} / {input_seq_len}) | loss: {chunk_metrics[f"loss_{chunk_idx}"]:.3f} | ppl: {chunk_metrics[f"ppl_{chunk_idx}"]:.3f}' - pbar.set_description(desc) - - targets = targets.cpu() - outputs = outputs.cpu() - _loss = _loss.cpu() - del past_key_values, outputs, targets, _loss - torch.cuda.empty_cache() - - # Display chunks in reverse - chunk_metrics = [(k, v) for k, v in chunk_metrics.items()][::-1] - chunk_metrics = {k: v for k, v in chunk_metrics} - return loss, chunk_metrics diff --git a/src/trainer/distill_attention_mse_linear.py b/src/trainer/distill_attention_mse_linear.py index 21146ee..95110c9 100644 --- a/src/trainer/distill_attention_mse_linear.py +++ b/src/trainer/distill_attention_mse_linear.py @@ -1,6 +1,5 @@ """ -Custom trainer class for distilling attentions over long sequences with -recurrent linear attention view. Can substitute for HuggingFace trainer. +Custom trainer class for distilling attentions ("attention transfer") over long sequences with recurrent linear attention view. Can substitute for Hugging Face trainer. """ import torch import torch.nn as nn diff --git a/src/trainer/distill_attention_xent_mse.py b/src/trainer/distill_attention_xent_mse.py index 3784263..68a9bd2 100644 --- a/src/trainer/distill_attention_xent_mse.py +++ b/src/trainer/distill_attention_xent_mse.py @@ -1,5 +1,7 @@ """ -Custom trainer class for distilling attentions. Can substitute for HuggingFace trainer. +Custom trainer class for distilling attentions ("attention transfer"). Can substitute for Hugging Face trainer. + +In this implementation we support using either just the softmax attention outputs, or the softmax attention weights. """ import torch import torch.nn as nn @@ -23,7 +25,7 @@ def __init__(self, super().__init__(model=model, metric_for_best_model=metric_for_best_model, **kwargs) - self.criterion_xent = nn.CrossEntropyLoss(reduction='mean') # stability + self.criterion_xent = nn.CrossEntropyLoss(reduction='mean') self.criterion_mse = nn.MSELoss(reduction='mean') self.mse_factor = mse_factor self.xent_factor = xent_factor @@ -39,7 +41,6 @@ def compute_loss(self, model: nn.Module, data: dict[torch.Tensor], inputs = {k: v.to(model.device) for k, v in data.items() if k != 'labels'} outputs = model(**inputs, output_attentions=True, use_cache=False) outputs = outputs.get('attentions') - # inputs = {k: v.cpu() for k, v in inputs.items()} # save gpu memory # Attentions are tuple[tuple[torch.Tensor, torch.Tensor]] # n_layers x (predicted_attns, true_attns) @@ -62,10 +63,7 @@ def compute_loss(self, model: nn.Module, data: dict[torch.Tensor], a_pred = a_pred.contiguous().view(-1, k_len) a_true = a_true.contiguous().view(-1, k_len) loss_xent += self.criterion_xent(a_pred, a_true) - # loss_xent += self.criterion_xent(a_pred.to(model.device), - # a_true.to(model.device)) if self.mse_factor > 0: - # attns[1] = [a.to(model.device) for a in attns[1]] loss_mse += self.criterion_mse(*attns[1]) n_layers += 1 else: @@ -74,13 +72,6 @@ def compute_loss(self, model: nn.Module, data: dict[torch.Tensor], loss_xent = loss_xent / n_layers * self.xent_factor loss_mse = loss_mse / n_layers * self.mse_factor loss = loss_xent + loss_mse - # try: - # del a_true - # del a_pred - # except NameError: - # pass - # torch.cuda.empty_cache() - # print('softmax_layer:', softmax_layers) if 'position_ids' in data: outputs = {'loss_xent': loss_xent.item() if self.xent_factor > 0 else 0, 'loss_mse': loss_mse.item() if self.mse_factor > 0 else 0, diff --git a/src/trainer/distill_attention_xent_mse_chunked.py b/src/trainer/distill_attention_xent_mse_chunked.py deleted file mode 100644 index 6bb6dad..0000000 --- a/src/trainer/distill_attention_xent_mse_chunked.py +++ /dev/null @@ -1,144 +0,0 @@ -""" -Custom trainer class for distilling attentions over long sequences with -recurrent linear attention view. Can substitute for HuggingFace trainer. -""" -import torch -import torch.nn as nn - -from tqdm import tqdm - -from src.model.modeling_llama import get_attention_cache -from src.model.convert_model import traverse_layers -from .default_lm_chunked import OurTrainer as DefaultChunkedTrainer - - -class OurTrainer(DefaultChunkedTrainer): - """ - Custom trainer class for distilling attentions. - - We compute and store the attention outputs and/or weights for each head and layer, - for both the "teacher" softmax attentions and "student" learnable subquadratic attentions - - We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights) - """ - def __init__(self, - model: nn.Module, - metric_for_best_model: str = 'distill/eval/loss', - mse_factor: float = 1e3, - xent_factor: float = 0, - **kwargs: any): - super().__init__(model=model, - metric_for_best_model=metric_for_best_model, - **kwargs) - self.criterion_xent = nn.CrossEntropyLoss(reduction='mean') # stability - self.criterion_mse = nn.MSELoss(reduction='mean') - self.mse_factor = mse_factor - self.xent_factor = xent_factor - self.compute_loss_backprop = True # Whether we backprop in self.compute_loss - - - def compute_loss(self, model: nn.Module, data: dict[torch.Tensor], - sample_idx: int = None, **kwargs: any,) -> tuple[torch.Tensor, dict[any]]: - """ - Attention distillation ("attention transfer") - - For each layer and head, get attentions and train to - minimize some combo of MSE and cross-entropy loss - """ - input_seq_len = data['input_ids'].shape[-1] - inputs = {'input_ids': data['input_ids'].to(model.device)} # assume all inputs good - - # Get softmax attention outputs - with torch.no_grad(): - # Set base_inference to True to use FlashAttention - for layer in traverse_layers(model): - layer.self_attn.base_inference = True - true_outputs = model.chunk_forward(**inputs, output_attentions=True, - use_cache=False,) - # no_logit_float=True,) - # Hack were we save attention layer inputs and outputs in outputs.attentions - # -> see model/hedgehog_attention_tk_long.py - # attn_inputs = [a[0] for a in true_outputs.get('attentions')] - # attn_outputs = [a[1] for a in true_outputs.get('attentions')] - true_attn_io = true_outputs.get('attentions') # layer-wise attn inputs and outputs - true_outputs = true_outputs.get('logits').cpu() - for layer in traverse_layers(model): - layer.self_attn.base_inference = False - inputs = {k: v.cpu() for k, v in inputs.items()} - torch.cuda.empty_cache() - - # Get trainable subquadratic attention outputs - attention_type = getattr(layer.self_attn, 'attention_type', None) - past_key_values = get_attention_cache(attention_type) - - num_chunks = input_seq_len // model.state_chunk_len - total_seq_len = 0 - pbar = tqdm(range(num_chunks), desc=f'Processing chunk 0 | state len: {model.state_chunk_len} (token {total_seq_len} / {input_seq_len})', leave=False) - - position_ids = torch.arange(input_seq_len).view(1, -1) - - total_loss = 0 - for chunk_idx in pbar: - start, end = chunk_idx * model.state_chunk_len, (chunk_idx+1) * model.state_chunk_len - attn_inputs = [o[0][:, start:end] for o in true_attn_io] - attn_output = [o[1][:, start:end] for o in true_attn_io] - - # Supervise attentions - pos_ids = position_ids[:, start:end] - loss_mse = 0 - loss_xent = 0 - for layer_idx, layer in enumerate(traverse_layers(model)): - attn_preds = layer.self_attn(attn_inputs[layer_idx].to(model.device), - attention_mask=None, - position_ids=pos_ids.to(model.device), - past_key_value=past_key_values) - (attn_preds, attn_weights), past_key_values = ( - attn_preds[1], attn_preds[2] - ) - if self.mse_factor > 0: - # MSE on layer outputs - loss_mse += self.criterion_mse(attn_preds, attn_output[layer_idx].to(model.device)) - - if self.xent_factor > 0: - # Cross-entropy on attention weights - aw_pred, aw_true = attn_weights - k_len = aw_pred.shape[-1] - # Compute mean loss only over individual queries - aw_pred = aw_pred.contiguous().view(-1, k_len).clamp(min=1e-12).log() - aw_true = aw_true.contiguous().view(-1, k_len) - loss_xent += self.criterion_xent(aw_pred.to(model.device), - aw_true.to(model.device)) - - loss_mse = loss_mse / (layer_idx + 1) * self.mse_factor - loss_xent = loss_xent / (layer_idx + 1) * self.xent_factor - - loss = loss_mse + loss_xent - if model.training: - loss.backward() - - desc=f'Processing chunk {chunk_idx + 1} | state len: {model.state_chunk_len} (token {end} / {input_seq_len}) | loss: {loss.item():.3f}' - if self.mse_factor > 0: - desc += f' | mse: {loss_mse.item():.3f}' - if self.xent_factor > 0: - desc += f' | xent: {loss_xent.item():.3f}' - pbar.set_description(desc) - - total_loss += loss.item() / len(pbar) - pbar.set_description(desc) - - if self.mse_factor > 0 or self.xent_factor > 0: - loss = loss.cpu() - if self.xent_factor > 0: - attn_preds = attn_preds.cpu() - torch.cuda.empty_cache() - - if 'position_ids' in data: - outputs = {'loss_mse': loss_mse.item() if self.mse_factor > 0 else 0, - 'loss_xent': loss_xent.item() if self.xent_factor > 0 else 0, - 'mse_factor': self.mse_factor, - 'xent_factor': self.xent_factor, - 'input_len': data['position_ids'].shape[1], - 'position_ids': data['position_ids'][0],} - else: - outputs = {'loss_mse': loss_mse.item() if self.mse_factor > 0 else 0, - 'loss_xent': loss_xent.item() if self.xent_factor > 0 else 0, - 'mse_factor': self.mse_factor, - 'xent_factor': self.xent_factor,} - return total_loss, outputs \ No newline at end of file diff --git a/src/trainer/distill_attention_xent_mse_tiled.py b/src/trainer/distill_attention_xent_mse_tiled.py deleted file mode 100644 index 97e8a33..0000000 --- a/src/trainer/distill_attention_xent_mse_tiled.py +++ /dev/null @@ -1,78 +0,0 @@ -""" -Custom trainer class for distilling attentions. Can substitute for HuggingFace trainer. -""" -import torch -import torch.nn as nn - -from .default_lm import OurTrainer as DefaultTrainer -from src.model.convert_model import traverse_layers - - -class OurTrainer(DefaultTrainer): - """ - Custom trainer class for distilling attentions. - - We compute and store the attention outputs and/or weights for each head and layer, - for both the "teacher" softmax attentions and "student" learnable subquadratic attentions - - We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights) - """ - def __init__(self, - model: nn.Module, - metric_for_best_model: str = 'distill/eval/loss', - mse_factor: float = 1e3, - xent_factor: float = 0, - **kwargs: any): - super().__init__(model=model, - metric_for_best_model=metric_for_best_model, - **kwargs) - self.criterion_xent = nn.CrossEntropyLoss(reduction='mean') # stability - self.criterion_mse = nn.MSELoss(reduction='mean') - self.mse_factor = mse_factor - self.xent_factor = xent_factor - self.compute_loss_backprop = True # Whether we backprop in self.compute_loss - - def compute_loss(self, model: nn.Module, data: dict[torch.Tensor], - sample_idx: int = None, **kwargs: any,) -> tuple[torch.Tensor, dict[any]]: - """ - Attention distillation ("attention transfer") - - For each layer and head, get attentions and train to - minimize some combo of MSE and cross-entropy loss - """ - inputs = {k: v.to(model.device) for k, v in data.items() if k != 'labels'} - - with torch.no_grad(): - for layer in traverse_layers(model): - layer.mse_factor = self.mse_factor - layer.xent_factor = self.xent_factor - - # hack; we stored losses as attentions - losses_by_layer = model(**inputs, output_attentions=True, use_cache=False).get('attentions') - - loss = 0 - loss_mse = 0 - loss_xent = 0 - n_layers = 0 # Number of layers to distill - softmax_layers = [] - for layer_idx, (_loss, _loss_mse, _loss_xent) in enumerate(losses_by_layer): - - n_layers += 1 - loss += _loss - loss_mse += _loss_mse - loss_xent += _loss_xent - loss_mse /= n_layers - loss_xent /= n_layers - loss = loss_mse + loss_xent - # if self.xent_factor > 0: - # breakpoint() - if 'position_ids' in data: - outputs = {'loss_xent': loss_xent.item() if self.xent_factor > 0 else 0, - 'loss_mse': loss_mse.item() if self.mse_factor > 0 else 0, - 'input_len': data['position_ids'].shape[1], - 'position_ids': data['position_ids'][0].detach().cpu().numpy(), - 'mse_factor': self.mse_factor, - 'xent_factor': self.xent_factor,} - else: - outputs = {'loss_xent': loss_xent.item() if self.xent_factor > 0 else 0, - 'loss_mse': loss_mse.item() if self.mse_factor > 0 else 0, - 'mse_factor': self.mse_factor, - 'xent_factor': self.xent_factor} - return loss, outputs diff --git a/src/trainer/finetune_seq2seq.py b/src/trainer/finetune_seq2seq.py index 7788390..a85cd5a 100644 --- a/src/trainer/finetune_seq2seq.py +++ b/src/trainer/finetune_seq2seq.py @@ -45,7 +45,7 @@ def compute_scrolls_metrics(eval_preds, scrolls_metric, tokenizer): class OurTrainer(DefaultTrainer): """ - Trainer for SCROLLS benchmark + Evaluator for seq-to-seq / generation benchmarks """ def __init__(self, model, args, # max_eval_batches: Optional[int] = 100, **kwargs: any): @@ -57,8 +57,6 @@ def __init__(self, model, args, # max_eval_batches: Optional[int] = 100, print(f'self.print_steps:', self.print_steps) # ablation sweep self.max_eval_batches = 10 - # if self.max_eval_batches is None: - # self.max_eval_batches = 100 def init_criterion_(self): pass @@ -92,7 +90,6 @@ def eval_step(self, model: nn.Module, step: int, predictions, references = [], [] model.eval() - # model.to(self.device) pbar = tqdm(dataloader, leave=False, colour='green', desc=f'Evaluating at step {step}') diff --git a/src/trainer/layer_distill_xent_mse.py b/src/trainer/layer_distill_xent_mse.py deleted file mode 100644 index 8627ea2..0000000 --- a/src/trainer/layer_distill_xent_mse.py +++ /dev/null @@ -1,83 +0,0 @@ -""" -Custom trainer class for distilling attentions. Can substitute for HuggingFace trainer. -""" -import torch -import torch.nn as nn - -from .default_lm import OurTrainer as DefaultTrainer - - -class OurTrainer(DefaultTrainer): - """ - Custom trainer class for distilling attentions. - - We compute and store the attention outputs and/or weights for each head and layer, - for both the "teacher" softmax attentions and "student" learnable subquadratic attentions - - We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights) - """ - def __init__(self, - model: nn.Module, # attention layer - layer_idx: int, - metric_for_best_model: str = 'distill/eval/loss', - mse_factor: float = 1e3, - xent_factor: float = 0, - **kwargs: any): - super().__init__(model=model, - metric_for_best_model=metric_for_best_model, - **kwargs) - self.criterion_xent = nn.CrossEntropyLoss(reduction='mean') # stability - self.criterion_mse = nn.MSELoss(reduction='mean') - self.mse_factor = mse_factor - self.xent_factor = xent_factor - self.layer_idx = layer_idx - self.initial_eval = False # Whether to evaluate before training - - def compute_loss(self, model: nn.Module, data: dict[torch.Tensor], - sample_idx: int = None, **kwargs: any,) -> tuple[torch.Tensor, dict[any]]: - """ - Attention distillation ("attention transfer") - - For each layer and head, get attentions and train to - minimize some combo of MSE and cross-entropy loss - - model: nn.Module that is a Lolcats attention class. - If outputs = model(**inputs), - - outputs[0] are the layer outputs - - outputs[1] are attentions (or other saved tensors) - """ - _data_kwargs = {'device': model.q_proj.weight.device, - 'dtype': model.q_proj.weight.dtype} - inputs = {'hidden_states': data['hidden_states'].to(**_data_kwargs)} - if 'position_ids' in data: - inputs['position_ids'] = data['position_ids'].to(**_data_kwargs) - attns = model(**inputs, output_attentions=True, use_cache=False)[1] - inputs = {k: v.cpu() for k, v in inputs.items()} # save gpu memory - - # Attentions are tuple[tuple[torch.Tensor, torch.Tensor]] - # n_layers x (predicted_attns, true_attns) - # predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len) - loss_mse = 0 - loss_xent = 0 - if self.xent_factor > 0: - # Cross-entropy loss - a_pred, a_true = attns[0] - a_pred = a_pred.clamp(min=1e-12).log() # nn.CrossEntropy assumes unnormalized logits - k_len = a_true.shape[-1] # batch, n_heads, q_len, k_len - # Compute mean cross-entropy over all queries - a_pred = a_pred.contiguous().view(-1, k_len) - a_true = a_true.contiguous().view(-1, k_len) - loss_xent += self.criterion_xent(a_pred, a_true) - # loss_xent += self.criterion_xent(a_pred.to(model.device), - # a_true.to(model.device)) - if self.mse_factor > 0: - # attns[1] = [a.to(model.device) for a in attns[1]] - loss_mse += self.criterion_mse(*attns[1]) - - loss_xent = loss_xent * self.xent_factor - loss_mse = loss_mse * self.mse_factor - loss = loss_xent + loss_mse - - outputs = {'loss_xent': loss_xent.item() if self.xent_factor > 0 else 0, - 'loss_mse': loss_mse.item() if self.mse_factor > 0 else 0, - 'mse_factor': self.mse_factor, - 'xent_factor': self.xent_factor, - 'layer_idx': self.layer_idx} - return loss, outputs diff --git a/src/trainer/layer_finetune_xent_mse.py b/src/trainer/layer_finetune_xent_mse.py deleted file mode 100644 index da53451..0000000 --- a/src/trainer/layer_finetune_xent_mse.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -Custom trainer class for distilling attentions. Can substitute for HuggingFace trainer. -""" -import torch -import torch.nn as nn - -from .default_lm import OurTrainer as DefaultTrainer - - -class OurTrainer(DefaultTrainer): - """ - Custom trainer class for distilling attentions. - - We compute and store the attention outputs and/or weights for each head and layer, - for both the "teacher" softmax attentions and "student" learnable subquadratic attentions - - We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights) - """ - def __init__(self, - model: nn.Module, # attention layer - teacher_layer: nn.Module, - layer_idx: int, - metric_for_best_model: str = 'ft/eval/loss', - mse_factor: float = 1e3, - xent_factor: float = 0, - **kwargs: any): - super().__init__(model=model, - metric_for_best_model=metric_for_best_model, - **kwargs) - self.criterion_xent = nn.CrossEntropyLoss(reduction='mean') # stability - self.criterion_mse = nn.MSELoss(reduction='mean') - self.mse_factor = mse_factor - self.xent_factor = xent_factor - self.layer_idx = layer_idx - self.initial_eval = False # Whether to evaluate before training - - self.teacher_layer = teacher_layer - self.teacher_layer.eval() - for param in self.teacher_layer.parameters(): - param.requires_grad = False # freeze teacher - - def compute_loss(self, model: nn.Module, data: dict[torch.Tensor], - sample_idx: int = None, **kwargs: any,) -> tuple[torch.Tensor, dict[any]]: - """ - Attention distillation ("attention transfer") - - For each layer and head, get attentions and train to - minimize some combo of MSE and cross-entropy loss - - model: nn.Module that is a Lolcats attention class. - If outputs = model(**inputs), - - outputs[0] are the layer outputs - - outputs[1] are attentions (or other saved tensors) - """ - _data_kwargs = {'device': model.q_proj.weight.device, - 'dtype': model.q_proj.weight.dtype} - inputs = {'hidden_states': data['hidden_states'].to(**_data_kwargs)} - if 'position_ids' in data: - inputs['position_ids'] = data['position_ids'].to(**_data_kwargs) - - with torch.no_grad(): - _n = data['hidden_states'].shape[-2] # construct attention mask for ground-truth softmax - causal_mask = torch.ones((1, 1, _n, _n), **_data_kwargs).triu(1) * -1e8 - # LlamaAttention processes mask as: attn_weights = attn_weights + causal_mask - y_true, a_true, _ = self.teacher_layer(**inputs, output_attentions=True, use_cache=False, attention_mask=causal_mask) - y_pred, a_pred, _ = model(**inputs, output_attentions=True, use_cache=False) - - inputs = {k: v.cpu() for k, v in inputs.items()} # save gpu memory - - # Attentions are tuple[tuple[torch.Tensor, torch.Tensor]] - # n_layers x (predicted_attns, true_attns) - # predicted_attns and true_attns are shape (batch, n_heads, q_len, k_len) - loss_mse = 0 - loss_xent = 0 - if self.xent_factor > 0: - # Cross-entropy loss - a_pred = a_pred.clamp(min=1e-12).log() # nn.CrossEntropy assumes unnormalized logits - k_len = a_true.shape[-1] # batch, n_heads, q_len, k_len - # Compute mean cross-entropy over all queries - a_pred = a_pred.contiguous().view(-1, k_len) - a_true = a_true.contiguous().view(-1, k_len) - loss_xent += self.criterion_xent(a_pred, a_true) - if self.mse_factor > 0: - loss_mse += self.criterion_mse(y_pred, y_true) - - loss_xent = loss_xent * self.xent_factor - loss_mse = loss_mse * self.mse_factor - loss = loss_xent + loss_mse - - outputs = {'loss_xent': loss_xent.item() if self.xent_factor > 0 else 0, - 'loss_mse': loss_mse.item() if self.mse_factor > 0 else 0, - 'mse_factor': self.mse_factor, - 'xent_factor': self.xent_factor, - 'layer_idx': self.layer_idx} - return loss, outputs diff --git a/src/trainer/mini_finetune_xent_mse.py b/src/trainer/mini_finetune_xent_mse.py deleted file mode 100644 index f1c0d20..0000000 --- a/src/trainer/mini_finetune_xent_mse.py +++ /dev/null @@ -1,116 +0,0 @@ -""" -Custom trainer class for distilling attentions. Can substitute for HuggingFace trainer. -""" -import torch -import torch.nn as nn - -from peft.tuners.lora.layer import LoraLayer - -from .default_lm import OurTrainer as DefaultTrainer -from src.model.convert_model import traverse_layers, toggle_attention - - - -def toggle_lora(model, use_lora: bool = True): - for layer in traverse_layers(model): - for n, module in layer.self_attn.named_modules(): - if isinstance(module, LoraLayer): - module._disable_adapters = not use_lora - return model - - -class OurTrainer(DefaultTrainer): - """ - Custom trainer class for distilling attentions. - - We compute and store the attention outputs and/or weights for each head and layer, - for both the "teacher" softmax attentions and "student" learnable subquadratic attentions - - We then train the student layers to minimize either MSE(outputs) or CrossEntropy(weights) - """ - def __init__(self, - model: nn.Module, - layer_idx: int, - metric_for_best_model: str = 'ft/eval/loss', - mse_factor: float = 1e3, - xent_factor: float = 0, - **kwargs: any): - model = toggle_attention(model, train=True) # keep train_attention logic - super().__init__(model=model, - metric_for_best_model=metric_for_best_model, - **kwargs) - self.criterion_xent = nn.CrossEntropyLoss(reduction='mean') # stability - self.criterion_mse = nn.MSELoss(reduction='mean') - self.mse_factor = mse_factor - self.xent_factor = xent_factor - self.layer_idx = layer_idx - self.initial_eval = False # Whether to evaluate before training - - self._data_kwargs = {'device': model.device, - 'dtype': traverse_layers(model)[0].self_attn.q_proj.weight.dtype} - - def compute_loss(self, model: nn.Module, data: dict[torch.Tensor], - sample_idx: int = None, **kwargs: any,) -> tuple[torch.Tensor, dict[any]]: - """ - Attention distillation ("attention transfer") - - For each layer and head, get attentions and train to - minimize some combo of MSE and cross-entropy loss - - model: nn.Module that is a Lolcats attention class. - If outputs = model(**inputs), - - outputs[0] are the layer outputs - - outputs[1] are attentions (or other saved tensors) - """ - # inputs = {'inputs_embeds': data['hidden_states'].to(**_data_kwargs)} - inputs = {'inputs_embeds': data['inputs_embeds'].to(**self._data_kwargs)} - if 'position_ids' in data: - inputs['position_ids'] = data['position_ids'].to(device=self.data_kwargs['device']) - - # Teacher outputs - with torch.no_grad(): - model = toggle_lora(model, use_lora=False) - outputs = model(**inputs, output_attentions=True, use_cache=False) - outputs = outputs.get('attentions') # ((_, a_true), (_, _y_true)) x layers - a_true = [o[0][1] for o in outputs] - y_true = [o[1][1] for o in outputs] - # y_true = self.teacher_layer(**inputs, output_attentions=True, output_hidden_states=True, use_cache=False) - # y_true = model(**inputs, output_attentions=True, output_hidden_states=True, use_cache=False) - # y_true, a_true = y_true.get('hidden_states'), y_true.get('attentions') - - # Student outputs - model = toggle_lora(model, use_lora=True) - outputs = model(**inputs, output_attentions=True, use_cache=False).get('attentions') - a_pred = [o[0][0] for o in outputs] - y_pred = [o[1][0] for o in outputs] - - # y_pred = model(**inputs, output_attentions=True, output_hidden_states=True, use_cache=False) - # y_pred, a_pred = y_pred.get('hidden_states'), y_pred.get('attentions') - - inputs = {k: v.cpu() for k, v in inputs.items()} # save gpu memory - - loss_mse = 0 - loss_xent = 0 - for layer_idx in range(len(a_pred)): # indexed by n_layers - if self.xent_factor > 0: - _a_pred, _a_true = a_pred[layer_idx], a_true[layer_idx] - - # Cross-entropy loss - _a_pred = _a_pred.clamp(min=1e-12).log() # nn.CrossEntropy assumes unnormalized logits - k_len = _a_true.shape[-1] # batch, n_heads, q_len, k_len - - # Compute mean cross-entropy over all queries - _a_pred = _a_pred.contiguous().view(-1, k_len) - _a_true = _a_true.contiguous().view(-1, k_len) - loss_xent += self.criterion_xent(_a_pred, _a_true) - - if self.mse_factor > 0: - loss_mse += self.criterion_mse(y_pred[layer_idx], y_true[layer_idx]) - - loss_xent = loss_xent * self.xent_factor / len(y_pred) - loss_mse = loss_mse * self.mse_factor / len(y_pred) - loss = loss_xent + loss_mse - - outputs = {'loss_xent': loss_xent.item() if self.xent_factor > 0 else 0, - 'loss_mse': loss_mse.item() if self.mse_factor > 0 else 0, - 'mse_factor': self.mse_factor, - 'xent_factor': self.xent_factor, - 'layer_idx': self.layer_idx} - return loss, outputs diff --git a/stitch_mini.py b/stitch_mini.py deleted file mode 100644 index fb2eb6e..0000000 --- a/stitch_mini.py +++ /dev/null @@ -1,476 +0,0 @@ -""" -This file just needs to save out the shards for 405B. - -Notes: -- Make sure that register_buffer inv_freq persistent=True for your modeling_llama.py - -python stitch_mini.py \ ---model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_mini_xent1_mse1000 \ ---lk_zero_init --lr 1e-3 \ ---verbose --seed 0 --replicate 0 \ ---layers_per_model 8 \ ---e2e_finetune_config finetune_lora_qkvo_alpaca_clean \ ---load_finetuned_loras - -python stitch_mini.py \ ---model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_alpaca_clean_xent1_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean_mini_xent1_mse1000 \ ---lk_zero_init --lr 1e-3 \ ---verbose --seed 0 --replicate 0 \ ---layers_per_model 8 \ ---e2e_finetune_config finetune_lora_qkvo_alpaca_clean - -# Called checkpoint -'./checkpoints/distill_llama3_8b_lk_smd_wtk64_fd64_w01/sharded_layers/dl-d=distill_alpaca_clean_xent1_mse1000_lr1e-2-m=distill_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean_mini_xent1_mse1000-s=0-lr=0.001-se=0-re=0-lzi=1-se=0-re=0-in=00-out=07_ft.pt' - -# Saved -'./checkpoints/distill_llama3_8b_lk_smd_wtk64_fd64_w01/sharded_layers/dl-d=distill_alpaca_clean_xent1_mse1000_lr1e-2-m=distill_llama3_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean_mini_xent1_mse1000-s=0-lr=0.001-se=0-re=0-lzi=1-in=00-out=07-bs=1-gas=8-nte=2-ms=-1-se=0-re=0_ft.pt' -""" - -import os -from os.path import join - -import argparse -from omegaconf import OmegaConf -from tqdm import tqdm - -import torch -import torch.optim as optim - -from transformers.models.llama.modeling_llama import ( - LlamaConfig, - LlamaDecoderLayer as DecoderLayer -) - -# Distributed arguments -from torch.distributed.fsdp import ( - FullyShardedDataParallel as FSDP, - ShardingStrategy, - StateDictType -) -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload - -from accelerate.utils import is_xpu_available -from llama_recipes.configs import fsdp_config as FSDP_CONFIG -from llama_recipes.policies import apply_fsdp_checkpointing -from llama_recipes.utils.fsdp_utils import ( - fsdp_auto_wrap_policy, - hsdp_device_mesh as get_hsdp_device_mesh -) -from llama_recipes.utils.config_utils import update_config - -# Our arguments -from llama_recipes.trainer_finetune import ( - train, - setup, - setup_environ_flags, - clear_gpu_cache, - get_policies, -) -from llama_recipes.model_checkpointing.distill_checkpoint_handler import ( - load_model_sharded, -) -from llama_recipes.distill_llama import ( - setup_wandb, get_args, # get_run_name_from_checkpoint - get_dataloaders, setup_fsdp_config -) - -from src.utils.setup import ( - seed_everything, get_run_name_from_args, # get_run_name_from_checkpoint, - update_config_from_args, update_model_config_from_args, -) -from src.utils.logging import print_config, print_header -from src.model.pretrained import get_pretrained_loader - -from src.model.convert_model import traverse_layers -from src.model.load_model import load_and_convert_attns, load_and_convert_finetune - -from src.trainer import get_trainer, get_optimizer, get_scheduler -from src.finetune import prepare_finetune_configs # get_finetuner - - - -def get_args(): - """Parse command line arguments""" - parser = argparse.ArgumentParser() - parser.add_argument("--project_name", type=str, default='lolcats') - parser.add_argument("--layers_per_model", type=int) - parser.add_argument("--layer_idx", type=int) # specify starting layer - parser.add_argument("--device", type=int, default=0) - - - parser.add_argument("--model_config", type=str, default=None) - parser.add_argument("--distill_config", type=str, default=None) - parser.add_argument("--finetune_config", type=str, default=None) - parser.add_argument("--eval_config", type=str, default=None) - - parser.add_argument("--load_finetuned_loras", action='store_true', default=False) - parser.add_argument("--e2e_finetune_config", type=str, default=None) - parser.add_argument("--load_checkpoint_only", action='store_true', default=False) - - parser.add_argument("--pretrained_model_name_or_path", type=str, default=None) - parser.add_argument("--load_distill_checkpoint", type=str, default=None) - parser.add_argument("--resume_distill", action='store_true', default=None) - - parser.add_argument("--load_finetune_checkpoint", type=str, default=None) - parser.add_argument("--resume_finetune", action='store_true', default=None) - - # Override default configs - # Feature map / model - parser.add_argument("--attention_type", type=str, default=None) - parser.add_argument("--learned_kernel", type=str, default=None) # always - parser.add_argument("--lk_skip_connection", action='store_true', default=None) - parser.add_argument("--lk_zero_init", action='store_true', default=None) - parser.add_argument("--lk_normal_init", action='store_true', default=None) - parser.add_argument("--tie_qk_kernels", action='store_true', default=None) - parser.add_argument("--train_qk", action='store_true', default=None) - parser.add_argument("--state_chunk_len", type=int, default=None) - - # Training - parser.add_argument("--lr", type=float, default=None) - parser.add_argument("--weight_decay", type=float, default=None) - parser.add_argument("--optim", type=str, default=None) - parser.add_argument("--scheduler", type=str, default=None) - parser.add_argument("--gradient_accumulation_steps", type=int, default=None) - parser.add_argument("--num_train_epochs", type=int, default=None) - parser.add_argument("--max_steps", type=int, default=None) - parser.add_argument("--max_finetune_steps", type=int, default=None) - - parser.add_argument("--no_peft_grad_ckpt", action='store_true', default=None) - - ## Distributed training / Llama recipes - parser.add_argument("--enable_fsdp", action='store_true', default=None) - parser.add_argument("--low_cpu_fsdp", action='store_true', default=None) - parser.add_argument("--pure_bf16", action='store_true', default=None) - parser.add_argument("--fsdp_activation_checkpointing", action='store_true', default=None) - parser.add_argument("--fsdp_cpu_offload", action='store_true', default=None) - - # Dataloading - parser.add_argument("--batch_size", type=int, default=None) - parser.add_argument("--num_workers", type=int, default=None) - - # Evaluation - parser.add_argument("--no_init_eval", action='store_true', default=False) - parser.add_argument("--eval_steps", type=int, default=None) - parser.add_argument("--max_eval_batches", type=int, default=None) - - # Miscellaneous - parser.add_argument("--huggingface_token", type=str, default=None) - parser.add_argument("--checkpoint_dir", type=str, default='./checkpoints') - parser.add_argument("--results_dir", type=str, default='./results') - parser.add_argument("--replicate", type=int, default=0) - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--verbose", action='store_true', default=None) - parser.add_argument("--no_cuda", action='store_true', default=None) - parser.add_argument("--no_wandb", action='store_true', default=None) - parser.add_argument("--wandb_entity", type=str, default='hazy-research') - parser.add_argument("--debug", action='store_true', default=None) - parser.add_argument("--no_attention_mask", action='store_true', default=None) - - args = parser.parse_args() - args.run_name = get_run_name_from_args(args) - return args - - -def check_state_dict_keys(keys: any, layer_idx: int, rank: int = 0, - state_dict: dict = None, verbose: bool = False) -> None: - """ - Check the state dict keys for unexpected and expected keys - - keys: the output from torch.load_state_dict() - - layer_idx: the current layer index - """ - try: - assert len(keys.unexpected_keys) == 0 - if rank == 0: - print_header(f'*** All expected keys matched successfully {layer_idx} ***') - if verbose and state_dict is not None: - print('Keys loaded:') - for k in state_dict: - print(f'├── {k}') - except Exception as e: - if rank == 0: - print(e) - print_header('*** Error: unexpected keys in checkpoint ***') - print(f'Unexpected keys at {layer_idx}:') - for k in keys.unexpected_keys: - print(k) - - -def rename_state_dict(rename_dict: dict, start_layer_idx: int, verbose: bool = False) -> dict: - """Rename the state dict from the mini models to match the full model""" - new_state_dict = {} - for k, v in rename_dict.items(): - if "layers" in k: - k_name = k.split("layers.")[-1].split(".")[0] - k_idx = int(k_name) - new_k_idx = k_idx + start_layer_idx - new_k_name = k.replace(k_name, str(new_k_idx)) - new_state_dict[new_k_name] = v - if verbose: # if start_layer_idx > 9 and start_layer_idx < 18: - print(f"-> Renaming {k} to {new_k_name}") - else: - new_state_dict[k] = v - return new_state_dict - - -def main(): - """Main script""" - # ------ - # SET UP - # ------ - args = get_args() - # args.checkpoint_dir = "/data_ephemeral/sim/sharded_layers_405b/" - args.checkpoint_dir = join(args.checkpoint_dir, args.model_config) - if not os.path.isdir(args.checkpoint_dir): - os.makedirs(args.checkpoint_dir) - # Save individual .pt model weights in a subdirectory - args.checkpoint_dir = join(args.checkpoint_dir, 'sharded_layers') - if not os.path.isdir(args.checkpoint_dir): - os.makedirs(args.checkpoint_dir) - args.results_dir = join(args.results_dir, args.model_config) - if not os.path.isdir(args.results_dir): - os.makedirs(args.results_dir) - seed_everything(args.seed) - - # Load distillation + (hedgehog) attention configs - distill_config_path = join('./configs/experiment', f'{args.distill_config}.yaml') - distill_config = OmegaConf.load(distill_config_path) - distill_config = update_config_from_args(distill_config, args) - - for arg, argv in distill_config.trainer.items(): # legacy, should be removed - if arg != 'name': - setattr(args, arg, argv) - for _config in ['dataloader', 'optimizer', 'lr_scheduler']: - setattr(args, _config, OmegaConf.to_container(getattr(distill_config, _config))) - - model_config_path = join('./configs/model', f'{args.model_config}.yaml') - model_config = OmegaConf.load(model_config_path) - model_config = update_model_config_from_args(model_config, args) - - # Update data tokenizer to match model (unused in this script) - if getattr(distill_config.dataset, 'pretrained_model_config', None) is not None: - for k in ['pretrained_model_name_or_path', 'cache_dir']: - distill_config.dataset.pretrained_model_config[k] = model_config.model[k] - - if args.enable_fsdp: - if getattr(model_config.model, 'load_in_4bit', False): - model_config.model.device_map = 'auto' - elif getattr(model_config.model, 'load_in_8bit', False): - model_config.model.device_map = 'auto' - else: - model_config.model.device_map = None # FSDP will complain about device placement o.w. - model_config.model.low_cpu_mem_usage = True - - # Setup FSDP if enabled - if args.enable_fsdp: - distill_config = setup_fsdp_config(distill_config, args, 'distill') # patch - fsdp_config = FSDP_CONFIG() - update_config((fsdp_config), **vars(args)) - setup() - # torchrun specific - local_rank = int(os.environ["LOCAL_RANK"]) - rank = int(os.environ["RANK"]) - # world_size = int(os.environ["WORLD_SIZE"]) - else: - fsdp_config = FSDP_CONFIG() # ignored - rank = 0 - - if torch.distributed.is_initialized(): - if is_xpu_available(): - torch.xpu.set_device(local_rank) - elif torch.cuda.is_available(): - torch.cuda.set_device(local_rank) - clear_gpu_cache(local_rank) - setup_environ_flags(rank) - - # WandB logging - wandb_run = None - if not args.no_wandb: - if not args.enable_fsdp or rank == 0: - wandb_run = setup_wandb(distill_config, fsdp_config, **vars(args), - project=args.project_name, entity=args.wandb_entity) - - # Loading model - try: - if not os.path.exists(model_config.model.pretrained_model_name_or_path): - print(f"Model path {model_config.model.pretrained_model_name_or_path} does not exist. Using backup path. {model_config.model.pretrained_model_name_or_path_backup}") - model_config.model.pretrained_model_name_or_path = model_config.model.pretrained_model_name_or_path_backup - model_config.model.pop("pretrained_model_name_or_path_backup") - except Exception as e: - print(f'-> Error: {e}') - print("Model without model.pretrained_model_name_or_path_backup path") - - if rank == 0 or not args.enable_fsdp: - print_header('Model Config') - print_config(model_config) - - # Get model class and configs for layer instantiating - pretrained_model_config = LlamaConfig.from_pretrained(model_config['model']['pretrained_model_name_or_path']) - pretrained_model_class = pretrained_model_config.architectures[0] - transformers_module = __import__('transformers') - pretrained_model_class = getattr(transformers_module, pretrained_model_class) # e.g, LlamaForCausalLM - - # ------------------------------------------- - # Step 1. Load pretrained model and tokenizer - # ------------------------------------------- - if rank == 0 or not args.enable_fsdp: - print_header('Pretrained Model Config') - print(pretrained_model_config) - print_header('Our Model Config') - print_config(model_config) - - model_loader = get_pretrained_loader(**model_config.model, - huggingface_token=args.huggingface_token) - # Model - model = model_loader.load(model_type='softmax') - if rank == 0 or not args.enable_fsdp: - print_header('Original Model') - print(model) - if args.enable_fsdp and fsdp_config.pure_bf16: - model.to(torch.bfloat16) - for p in model.parameters(): # Freeze all layers - p.requires_grad = False - model.eval() - # Tokenizer - tokenizer = model_loader.load_tokenizer() - tokenizer.pad_token_id = tokenizer.eos_token_id - tokenizer.padding_side = 'left' - - # --------------------------------------------------- - # Step 2. Convert attentions to linearized attentions - # --------------------------------------------------- - model = load_and_convert_attns(model, - model_config, - attention_type=None, # specified in model_config, - checkpoint_path=None, - print_model=args.verbose, - train_attention=False)[0] - if rank == 0 or not args.enable_fsdp: - print_header('Converted Model') - - # ------------------------------------------ - # Step 3. Loop through the saved checkpoints - # ------------------------------------------ - num_hidden_layers = pretrained_model_config.num_hidden_layers # e.g., 32 for Llama 8B - max_digits = len(str(num_hidden_layers)) # name_suffix = f'in={start:0{max_digits}d}-out={end:0{max_digits}d}' - with torch.no_grad(): - first = 0 - for layer_idx, layer in enumerate(tqdm(traverse_layers(model))): - load_file_name = f'{join(args.checkpoint_dir, args.run_name)}' - start, end = first, first + (args.layers_per_model - 1) - name_suffix = f'in={start:0{max_digits}d}-out={end:0{max_digits}d}' - load_file_name += f'-{name_suffix}' - load_file_name = load_file_name.replace('True', '1').replace('False', '0') # concise hacks - load_file_name = load_file_name + '_distill.pt' - - if (layer_idx + 1) % args.layers_per_model == 0: - if rank == 0 or not args.enable_fsdp: - mini_weights = torch.load(load_file_name)['model_state_dict'] - mini_weights = rename_state_dict(mini_weights, first) - _keys = model.load_state_dict(mini_weights, strict=False) - print(f'Loading layer attentions from {load_file_name}...') - check_state_dict_keys(_keys, first, rank, mini_weights, verbose=args.verbose) - first = layer_idx + 1 - args.run_name += f'-{name_suffix}' # dealing with legacy naming - - # --------------------------------------- - # Step 4. Add end-to-end finetuning LoRAs - # --------------------------------------- - e2e_finetune_config, args = prepare_finetune_configs(args, model_config, - args.e2e_finetune_config) - e2e_finetune_config = setup_fsdp_config(e2e_finetune_config, args, 'finetune') - model, _ = load_and_convert_finetune(model, e2e_finetune_config, - checkpoint_path=None, - print_model=args.verbose, - merge_loras=False, - peft_gradient_checkpointing=not args.no_peft_grad_ckpt, - rank=rank) - - # ---------------------------------------------- - # Step 5. Add the LoRA weights from mini-distill - # ---------------------------------------------- - if args.load_finetuned_loras: - if args.enable_fsdp or rank == 0: - print("Loading loras") - with torch.no_grad(): - first = 0 - for layer_idx, layer in enumerate(tqdm(traverse_layers(model))): - load_file_name = f'{join(args.checkpoint_dir, args.run_name)}' - start, end = first, first + (args.layers_per_model - 1) - - load_file_name = load_file_name.replace('True', '1').replace('False', '0') # concise hacks - load_file_name = load_file_name + '_ft.pt' - - if (layer_idx + 1) % args.layers_per_model == 0: - if rank == 0 or not args.enable_fsdp: - mini_weights = torch.load(load_file_name)['model_state_dict'] - mini_weights = rename_state_dict(mini_weights, first) - _keys = model.load_state_dict(mini_weights, strict=False) - print(f'Loading layer loras from {args.checkpoint_dir}...') - check_state_dict_keys(_keys, first, rank, mini_weights, verbose=args.verbose) - first = layer_idx + 1 - - - # Final run name / checkpoint naming setup - if args.e2e_finetune_config is not None: # Update checkpoint for e2e finetune and lora loading - args.run_name += f'-ef={args.e2e_finetune_config}' - args.run_name += f'-ft_lora={args.load_finetuned_loras}'.replace('True', '1').replace('False', '0') - args.run_name = args.run_name.replace('True', '1').replace('False', '0') # concise hacks - args.run_name = args.run_name.replace(f'-{name_suffix}', '') # remove the mini model suffix - # Condense run name - args.run_name = args.run_name.replace(args.model_config, ''.join([c[0] + c[-1] for c in args.model_config.split('_')])) - args.run_name = args.run_name.replace(args.distill_config, ''.join([c[0] + c[-1] for c in args.distill_config.split('_')])) - args.run_name = args.run_name.replace(args.finetune_config, ''.join([c[0] + c[-1] for c in args.finetune_config.split('_')])) - - # Initialize optimizer and scheduler - optimizer = get_optimizer(model=model, **e2e_finetune_config.optimizer) - scheduler = get_scheduler(optimizer=optimizer, **e2e_finetune_config.lr_scheduler) - - if args.verbose and rank == 0: - print('-> Optimizer:', optimizer) - print('-> Scheduler:', scheduler) - print_header('*** Converted Model ***') - print(model) - print_header('*** Trainable Parameters ***') - count = 0 - for n, p in model.named_parameters(): - if p.requires_grad: - print(f'├── {n} (requires_grad = {p.requires_grad}, dtype = {p.dtype})') - count += 1 - if count == 0: # no trainable parameters - print('(none)') - - train_dataloader, eval_dataloader, e2e_finetune_config = get_dataloaders(e2e_finetune_config, tokenizer) - trainer_class = get_trainer(e2e_finetune_config.trainer.name) - finetune_trainer = trainer_class(model=model, - layer_idx=args.layer_idx, - args=args, - train_loader=train_dataloader, - eval_loader=eval_dataloader, - optimizer_and_scheduler=(optimizer, scheduler), - device=args.device, - wandb=wandb_run, - checkpoint_suffix='_ft', - save_results=False, - **e2e_finetune_config.trainer) - if args.verbose: - print_header('Finetune config') - print_config(e2e_finetune_config) - print_header('*** Finetuning ***') - print(f'├── Experiment name: {args.run_name}') - print(f'├── Device: {args.device}') - print(f'├── Seed: {args.seed}') - model = finetune_trainer.train() - args.load_finetune_checkpoint = finetune_trainer.best_val_checkpoint_path - - print_header('*** Done training ***') - print('--> Saved Checkpoints:') - print(f'--attn_mlp_checkpoint_path {args.load_distill_checkpoint} \\') - print(f'--finetune_checkpoint_path {args.load_finetune_checkpoint} \\') - - -if __name__ == '__main__': - main()