From c1d0c271de7f45b3328e1fb00626296707eb5e4f Mon Sep 17 00:00:00 2001 From: mzio Date: Wed, 21 Aug 2024 21:34:58 -0700 Subject: [PATCH] Update 7B readme and add running 70B w llama-recipes --- README.md | 173 ++++++++++++++---- ...ill_alpaca_clean_xent1_mse1000_lr1e-2.yaml | 52 ++++++ ...l_llama3_1_405b_lk_smd_wtk64_fd64_w01.yaml | 39 ++++ distill_llama.py | 2 +- environment.yaml | 2 +- llama_recipes/configs/fsdp.py | 2 +- llama_recipes/dev_scripts.md | 35 ++++ llama_recipes/distill_llama.py | 6 +- .../distill_checkpoint_handler.py | 10 +- llama_recipes/trainer_attention.py | 16 +- llama_recipes/trainer_finetune.py | 39 ++-- 11 files changed, 313 insertions(+), 63 deletions(-) create mode 100644 configs/experiment/distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml create mode 100644 configs/model/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01.yaml create mode 100644 llama_recipes/dev_scripts.md diff --git a/README.md b/README.md index 4d619bf..17a1191 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ In this README: - Getting started with dependencies, installation, and experiment configs -- Sample commands (Mistral-7B-v0.1, Llama-2-7B, Llama-3-8B, Mixtral-8x7B, Llama-2-70B, Llama-3-70B) +- Sample commands (Mistral-7B-v0.1, Llama-3-8B, Llama-3.1-8B, Llama-3.1-70B) --- @@ -19,7 +19,7 @@ Please see `environment.yaml` for dependencies. We can set them up with conda: ``` conda env create -f environment.yaml -conda activate hedgehog +conda activate lolcats ``` --- @@ -45,7 +45,7 @@ pretrained_config: low_cpu_mem_usage: true torch_dtype: bfloat16 rope_theta: 10000.0 - attn_implementation: eager # so we can supervise with attention weights + attn_implementation: eager # if supervising with attention weights ``` --- @@ -72,7 +72,13 @@ For now, we implement the causal linear attention with the CUDA kernel from [htt To build the kernel (`causal_dot_product`), first activate the conda environment (`conda activate hedgehog`). Then navigate to `./csrc/` and run `python setup.py install` within `./csrc/`. It's worth checking the arguments in `./csrc/setup.py` to match your GPU setup and C++ versions. -TODO: we're very excited to integrate additional developments like Songlin and friends' `flash-linear-attention` [repo](https://github.com/sustcsonglin/flash-linear-attention), as well as [ThunderKittens](https://github.com/HazyResearch/ThunderKittens). Please let us know if you're interested in scaling up these efficient linear attention implementations into 7B to 70B models. +### ThunderKittens linear attention + sliding window kernel + +TODO. [ThunderKittens](https://github.com/HazyResearch/ThunderKittens) + +### More + +We're very excited to integrate additional developments like Songlin and friends' [flash-linear-attention](https://github.com/sustcsonglin/flash-linear-attention) ### Flash Attention 2 install @@ -92,6 +98,8 @@ For any of these commands, you may need to provide a Hugging Face token to downl ### Demoing linear attention 7B models +**_Note: Stale_** + We upload a couple checkpoints in `./checkpoints/`, where for any linearized 7B model we only need to save the linear attention layers and the LoRA weights (in two separate `.pt` checkpoints). To chat with these models, you can run: ``` @@ -101,36 +109,70 @@ python -Wignore demo_hedgehog_llm.py \ --num_generations 1 --benchmark ``` -### Distilling + finetuning 7B models +--- + +### Linearizing 7B models + +
-Any of the below commands will convert a 7B Mistral or Llama LLM into a linear attention instruction-following variant. Despite only using LoRA and training on these 50K instruction-tuning samples, we're able to ``unlock'' a good amount of the base model performance when measured on LM Eval tasks. +Any of the below commands will convert a 7B Mistral or Llama LLM into a subquadratic attention instruction-following variant. Despite only using LoRA and training on these 50K instruction-tuning samples, we're able to ``unlock'' a good amount of the base model performance when measured on LM Eval tasks. + +See `configs/model/` for model configs used in the below commands, and `configs/experiments/` for attention transfer and finetuning configs. #### Mistral-7B-v0.1, Hedgehog Feature Map, using Alpaca-Clean ``` -python distill_llama.py --model_config distill_mistral_7b_lk_smd_zi \ +python distill_llama.py --model_config distill_mistral_7b_lk_smd_fd64 \ +--distill_config distill_alpaca_clean_lr1e-2 \ +--finetune_config finetune_lora_qkvo_alpaca_clean \ +--eval_config eval_alpaca_clean \ +--lk_zero_init \ +--verbose --seed 0 --replicate 0 \ +--huggingface_token hf_ +``` + +#### Mistral-7B-v0.1, Hedgehog + ThunderKittens Sliding Window, using Alpaca-Clean + +``` +python distill_llama.py --model_config distill_mistral_7b_lk_smd_wtk64_fd64_w01 \ +--distill_config distill_alpaca_clean_lr1e-2 \ +--finetune_config finetune_lora_qkvo_alpaca_clean \ +--eval_config eval_alpaca_clean \ +--lk_zero_init \ +--verbose --seed 0 --replicate 0 \ +--huggingface_token hf_ +``` + +#### Llama 3 8B, Hedgehog Feature Map, using Alpaca-Clean + +``` +python distill_llama.py --model_config distill_llama3_8b_lk_smd_fd64 \ --distill_config distill_alpaca_clean_lr1e-2 \ --finetune_config finetune_lora_qkvo_alpaca_clean \ --eval_config eval_alpaca_clean \ --lk_skip_connection --lk_zero_init \ ---verbose --seed 0 --replicate 0 +--verbose --seed 0 --replicate 0 \ +--huggingface_token hf_ ``` -#### Mistral-7B-v0.1, Hedgehog + Sliding Window, using Alpaca-Clean +#### Llama 3 8B, Hedgehog + ThunderKittens Sliding Window, using Alpaca-Clean ``` -python distill_llama.py --model_config distill_mistral_7b_lk_smd_zi_swa16_hh \ +python distill_llama.py --model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \ --distill_config distill_alpaca_clean_lr1e-2 \ --finetune_config finetune_lora_qkvo_alpaca_clean \ --eval_config eval_alpaca_clean \ --lk_skip_connection --lk_zero_init \ ---verbose --seed 0 --replicate 0 +--verbose --seed 0 --replicate 0 \ +--huggingface_token hf_ ``` -#### Llama-3-8B, Hedgehog Feature Map, using Alpaca-Clean +#### Llama 3.1 8B, Hedgehog Feature Map, using Alpaca-Clean ``` -python distill_llama.py --model_config distill_llama3_8b_lk_smd_zi \ +python distill_llama.py --model_config distill_llama3_1_8b_lk_smd_fd64 \ --distill_config distill_alpaca_clean_lr1e-2 \ --finetune_config finetune_lora_qkvo_alpaca_clean \ --eval_config eval_alpaca_clean \ @@ -139,10 +181,10 @@ python distill_llama.py --model_config distill_llama3_8b_lk_smd_zi \ --huggingface_token hf_ ``` -#### Llama-2-7B, Hedgehog Feature Map, using Alpaca-Clean +#### Llama 3.1 8B, Hedgehog + ThunderKittens Sliding Window, using Alpaca-Clean ``` -python distill_llama.py --model_config distill_llama2_7b_lk_smd_zi \ +python distill_llama.py --model_config distill_llama3_1_8b_lk_smd_wtk64_fd64_w01 \ --distill_config distill_alpaca_clean_lr1e-2 \ --finetune_config finetune_lora_qkvo_alpaca_clean \ --eval_config eval_alpaca_clean \ @@ -151,45 +193,110 @@ python 
distill_llama.py --model_config distill_llama2_7b_lk_smd_zi \ --huggingface_token hf_ ``` +--- + ### Evaluation The above scripts will save two checkpoints: (1) for the learned attention layer weights (denoted by a `_distill` suffix), (2) for the LoRA finetuning weights (denoted by a `_ft` suffix). To evaluate linearized models from these checkpoints, we can add the `--load_distill_checkpoint` and `--load_finetune_checkpoint` args. For example: ``` -python distill_llama.py --model_config distill_mistral_7b_lk_smd_zi_swa16_hh \ +python distill_llama.py --model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \ --distill_config distill_alpaca_clean_lr1e-2 \ --finetune_config finetune_lora_qkvo_alpaca_clean \ --eval_config eval_alpaca_clean \ --lk_skip_connection --lk_zero_init \ --verbose --seed 0 --replicate 0 \ ---load_distill_checkpoint ./checkpoints/distill_mistral_7b_lk_smd_zi_swa16_hh/dl-d=distill_alpaca_clean_kld_mse_mistral_lr1e-3-m=distill_mistral_7b_lk_smd_swa_hh-f=finetune_lora_qkvo_alpaca_clean_mistral-s=0-se=0-re=0-lk=untied_head_einsum-lsc=1_distill.pt \ ---load_finetune_checkpoint ./checkpoints/distill_mistral_7b_lk_smd_zi_swa16_hh/dl-d=distill_alpaca_clean_kld_mse_mistral_lr1e-3-m=distill_mistral_7b_lk_smd_swa_hh-f=finetune_lora_qkvo_alpaca_clean_mistral-s=0-se=0-re=0-lk=untied_head_einsum-lsc=1-se=0-re=0_ft.pt +--load_distill_checkpoint \ +--load_finetune_checkpoint ``` -#### LM Eval +#### LM Evaluation Harness For sample LM Eval scripts, please see `./lm_eval_harness/README.md`. In particular, this will involve cloning the Language Model Evaluation Harness from [here](https://github.com/EleutherAI/lm-evaluation-harness/tree/b281b0921b636bc36ad05c0b0b0763bd6dd43463). Note we use the `b281b09` branch following Hugging Face's [Open LLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard). --- -## Updated example commands +### Linearizing 70B models and up [WIP] + +
+ +We also support linearizing larger LLMs (Llama 3.1 70B, Llama 3.1 405B) using the great [llama-recipes](https://github.com/meta-llama/llama-recipes/tree/main/src/llama_recipes) repository. + +See `llama_recipes/README.md` for more details. At a high-level, we borrow the Fully Sharded Data Parallel (FSDP) pipeline, and split the two stages of LoLCATs linearizing into two scripts: + +1. `distill_llama.py`: where we first train subquadratic attentions to mimic the softmax attentions (saving the learned attention feature map checkpoints) +2. `distill_llama_finetune.py`: where we swap in the learned attentions and finetune the rest of the model with LoRA (saving the LoRA checkpoints) + +By passing in the same configurations files and arguments to both scripts, `distill_llama_finetune.py` should automatically load the saved checkpoints and pick up from where `distill_llama.py` left off. + +#### Sample Commands + +**_Script 1: Attention Transfer_** ```bash -python distill_llama.py \ ---model_config distill_long_llama3_8b_lk_smd_wtk64_fd64_w01 \ ---distill_config distill_long_alpaca_8k_xent0_mse1000_lr1e-2_bs1 \ ---finetune_config finetune_long_lora_qkvo_alpaca_clean_8192 \ ---eval_config eval_alpaca_clean \ ---lk_zero_init --verbose --seed 0 --replicate 614 --state_chunk_len 1024 \ ---num_train_epochs 2 +torchrun --nnodes 1 --nproc_per_node 8 \ +llama_recipes/distill_llama.py \ +--model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ +--distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \ +--finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \ +--eval_config eval_alpaca_clean \ +--verbose --replicate 0 --seed 0 \ +--enable_fsdp --low_cpu_fsdp ``` +**_Script 2: Low-rank Adaptation_** + ```bash -python distill_llama.py \ ---model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \ +torchrun --nnodes 1 --nproc_per_node 9 \ +llama_recipes/distill_llama_finetune.py \ +--model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ --distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \ ---finetune_config finetune_lora_qkvo_alpaca_clean \ ---eval_config eval_alpaca_clean \ ---lk_zero_init --verbose --seed 0 --replicate 614 \ ---num_train_epochs 2 +--finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \ +--eval_config eval_alpaca_clean \ +--verbose --replicate 0 --seed 0 \ +--enable_fsdp --low_cpu_fsdp +``` + +#### GPU Memory Training Requirements + +See https://huggingface.co/blog/llama31#training-memory-requirements + +--- + +## Setup Debugging + +### Huggingface datasets errors + +If you come across an error like the following: + +``` + File "/root/miniconda3/envs/hedgehog/lib/python3.12/site-packages/fsspec/spec.py", line 606, in glob + pattern = glob_translate(path + ("/" if ends_with_sep else "")) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/root/miniconda3/envs/hedgehog/lib/python3.12/site-packages/fsspec/utils.py", line 734, in glob_translate + raise ValueError( +ValueError: Invalid pattern: '**' can only be an entire path component +``` + +Try reinstalling the Hugging Face `datasets` package with the version specified, e.g., via `pip install datasets==2.15.0`. + +Sometimes setting up the virtual environment from `environment.yaml` results in `datasets==2.11.0` being installed instead. 
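+
+As a quick sanity check (illustrative snippet, not from the repo), you can print the version that actually got installed in the active environment:
+
+```python
+import datasets  # Hugging Face datasets
+
+print(datasets.__version__)  # expect 2.15.0; seeing 2.11.0 means the environment.yaml fallback was installed
+```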
+ +Similarly, you may need to run the following installs: + +```bash +pip install nltk +pip install rouge-score +``` + +### Miniconda installation + +```bash +mkdir -p ~/miniconda3 +wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh +bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 +rm -rf ~/miniconda3/miniconda.sh +~/miniconda3/bin/conda init bash ``` diff --git a/configs/experiment/distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml b/configs/experiment/distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml new file mode 100644 index 0000000..530ba68 --- /dev/null +++ b/configs/experiment/distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml @@ -0,0 +1,52 @@ +dataset: + name: alpaca_clean + dataset_config: + name: alpaca + path: yahma/alpaca-cleaned + chunk_size: 1024 # sequence length for distilling + concat_data: true + cache_dir: 'data/alpaca' # Change this to where you want to save + pretrained_model_config: # will be updated based on model_config + pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3-8B' + cache_dir: '/scr-ssd/mzhang/models/llama3' + preprocess_config: null + +dataloader: + batch_size: 1 + num_workers: 2 + drop_last: false + pin_memory: true + +optimizer: + optim: adamw_torch_fused + lr: 0.01 + weight_decay: 0.0 + +lr_scheduler: + lr_scheduler_type: reduce_lr_on_plateau + mode: min + factor: 0.1 + patience: 10 + min_lr: 0.00001 + +trainer: # HuggingFace Trainer-like arguments + name: distill_attention_xent_mse + reverse_kl: false + mse_factor: 1000 + xent_factor: 1 + + bf16: true + train_split: train + val_split: validation + num_train_epochs: 2 + gradient_accumulation_steps: 8 + seed: 42 + batch_size: 1 + load_best_model_at_end: true + greater_is_better: false + metric_for_best_model: distill/eval/loss + logging_steps: 100 + evaluation_strategy: steps + max_steps: -1 + eval_steps: 100 + max_eval_batches: null diff --git a/configs/model/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01.yaml b/configs/model/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01.yaml new file mode 100644 index 0000000..dd9f87f --- /dev/null +++ b/configs/model/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01.yaml @@ -0,0 +1,39 @@ +name: llama +model: + pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-405B" + cache_dir: "/scr-ssd/mzhang/models/llama-3_1-405b" # Set this to where you want to save checkpoint weights + return_dict: true + load_in_8bit: false + load_in_4bit: false + device_map: auto + low_cpu_mem_usage: true + torch_dtype: bfloat16 + attn_implementation: flash_attention_2 + rope_theta: 500000.0 + rope_scaling: + factor: 8.0 + low_freq_factor: 1.0 + high_freq_factor: 4.0 + original_max_position_embeddings: 8192 + rope_type: llama3 + +attention: + attention_type: lolcats_llama_window_tk + state_chunk_len: 1024 + window_size: 64 + affine_attention_factors: false + init_window_factor: -2.1972245773362196 + feature_map: softmax_dim + feature_map_kwargs: + eps: 1e-12 + # mlp: null # to set + fullspace: true + layer_idx: null # to set + learned_kernel: untied_head_einsum + learned_kernel_kwargs: + feature_dim: 64 + skip_connection: false + bias: false + zero_init: false + tie_qk_kernels: false + train_qk: false diff --git a/distill_llama.py b/distill_llama.py index 22c543a..eb0b5bb 100644 --- a/distill_llama.py +++ b/distill_llama.py @@ -147,7 +147,7 @@ def main(): wandb.config.update(_flattened) # Get pretrained model - model_loader = get_pretrained_loader(**model_config.model, + model_loader = get_pretrained_loader(**model_config.model, 
huggingface_token=args.huggingface_token) tokenizer = model_loader.load_tokenizer() tokenizer.pad_token_id = tokenizer.eos_token_id diff --git a/environment.yaml b/environment.yaml index 1c16ae6..71bd267 100644 --- a/environment.yaml +++ b/environment.yaml @@ -1,4 +1,4 @@ -name: hedgehog +name: lolcats channels: - conda-forge - pytorch diff --git a/llama_recipes/configs/fsdp.py b/llama_recipes/configs/fsdp.py index aec168c..4d754c0 100644 --- a/llama_recipes/configs/fsdp.py +++ b/llama_recipes/configs/fsdp.py @@ -15,7 +15,7 @@ class fsdp_config: sharding_group_size : int=0 # requires hsdp to be set. This specifies the sharding group size, number of GPUs that you model can fit into to form a replica of a model. replica_group_size: int=0 #requires hsdp to be set. This specifies the replica group size, which is world_size/sharding_group_size. checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size. - fsdp_activation_checkpointing: bool=True + fsdp_activation_checkpointing: bool=False # True fsdp_cpu_offload: bool=False pure_bf16: bool = False optimizer: str= "AdamW" diff --git a/llama_recipes/dev_scripts.md b/llama_recipes/dev_scripts.md new file mode 100644 index 0000000..289020f --- /dev/null +++ b/llama_recipes/dev_scripts.md @@ -0,0 +1,35 @@ +### Various development scripts + +For your viewing pleasure. + +**_Initial testing of the distillation recipe_** + +May need to prepend with `NCCL_CUMEM_ENABLE=0` + +```bash +torchrun --nnodes 1 --nproc_per_node 7 \ +llama_recipes/distill_llama.py \ +--model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ +--distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \ +--finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \ +--eval_config eval_alpaca_clean \ +--verbose --replicate 0 --seed 0 \ +--enable_fsdp --low_cpu_fsdp \ +--dataset_chunk_size 1024 \ +--attention_type lolcats_llama_window_tk_bf16 \ +--eval_steps 1 --dataset_chunk_size 256 +``` + +```bash +torchrun --nnodes 1 --nproc_per_node 7 \ +llama_recipes/distill_llama_finetune.py \ +--model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \ +--distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \ +--finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \ +--eval_config eval_alpaca_clean \ +--verbose --replicate 0 --seed 0 \ +--enable_fsdp --low_cpu_fsdp \ +--dataset_chunk_size 1024 \ +--attention_type lolcats_llama_window_tk_bf16 \ +--eval_steps 1 --dataset_chunk_size 256 +``` diff --git a/llama_recipes/distill_llama.py b/llama_recipes/distill_llama.py index 9426017..5c3a0fd 100644 --- a/llama_recipes/distill_llama.py +++ b/llama_recipes/distill_llama.py @@ -190,9 +190,9 @@ def get_args(): args.run_name += f'-npgc={args.no_peft_grad_ckpt}' if args.fsdp_activation_checkpointing is not None: args.run_name += f'-fac={args.fsdp_activation_checkpointing}' - if args.dataset_chunk_size is not None: - args.run_name += f'-dcs={args.dataset_chunk_size}' - args.run_name += f'-s={args.seed}' + # if args.dataset_chunk_size is not None: + # args.run_name += f'-dcs={args.dataset_chunk_size}' + # args.run_name += f'-s={args.seed}' if args.debug: args.run_name += '-debug' diff --git a/llama_recipes/model_checkpointing/distill_checkpoint_handler.py b/llama_recipes/model_checkpointing/distill_checkpoint_handler.py index cc028d2..09aa677 100644 --- a/llama_recipes/model_checkpointing/distill_checkpoint_handler.py +++ b/llama_recipes/model_checkpointing/distill_checkpoint_handler.py 
@@ -10,8 +10,8 @@ FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig, # general model non-sharded, non-flattened params - # LocalStateDictConfig, # flattened params, usable only by FSDP - # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes. + LocalStateDictConfig, # flattened params, usable only by FSDP + ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes. ) from torch.distributed._shard.checkpoint import ( @@ -186,9 +186,13 @@ def save_model_and_optimizer_sharded(model, rank, cfg,optim=None): ) t0 = time.perf_counter() - with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): + with FSDP.state_dict_type(model, + StateDictType.SHARDED_STATE_DICT, + ShardedStateDictConfig(offload_to_cpu=True), + ): # state_dict = {"model": model.state_dict()} state_dict = model.state_dict() + # state_dict = model.state_dict(state_dict_device='cpu') ignore_params = [ _rename_sharded(n) # n.replace('_fsdp_wrapped_module.','').replace('._checkpoint_wrapped_module', '').replace('.mlp._flat_param', '.mlp.layer').replace('._flat_param', '.weight') diff --git a/llama_recipes/trainer_attention.py b/llama_recipes/trainer_attention.py index 4ad06c4..7bc8bf0 100644 --- a/llama_recipes/trainer_attention.py +++ b/llama_recipes/trainer_attention.py @@ -301,11 +301,11 @@ def train(model, train_dataloader, eval_dataloader, tokenizer, return results, best_checkpoint_path -def eval_loop(model, evaluate_func, optimizer, lr_scheduler, - train_config, fsdp_config, rank, eval_dataloader, +def eval_loop(model, evaluate_func, optimizer, lr_scheduler, + train_config, fsdp_config, rank, eval_dataloader, local_rank, tokenizer, wandb_run, - val_step_loss, val_loss, best_val_loss, checkpoint_times, - epoch, step): # extra globals + val_step_loss, val_loss, best_val_loss, + checkpoint_times, epoch, step): # extra globals """ Evaluate model and save checkpoints - see `evaluate_func` for evaluation logic @@ -389,7 +389,7 @@ def evaluate_attn(model, train_config, eval_dataloader, val_step_loss = [] eval_loss = 0.0 # Initialize evaluation loss - _epoch = epoch if epoch is not None else '' + _epoch = f' {epoch}' if epoch is not None else '' pbar = tqdm(eval_dataloader,colour="green", desc=f"Evaluating epoch{_epoch}", dynamic_ncols=True) for step, batch in enumerate(pbar): for key in batch.keys(): @@ -409,6 +409,8 @@ def evaluate_attn(model, train_config, eval_dataloader, eval_loss += loss.detach().float() + pbar.set_description(f"Evaluating epoch{_epoch} | step_loss: {loss.item():.5f} | avg_loss: {eval_loss.item()/(step+1):.5f}") + # If there's more than one CUDA device, reduce evaluation loss across all devices if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp): dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM) @@ -429,11 +431,11 @@ def evaluate_attn(model, train_config, eval_dataloader, if local_rank == 0 or not train_config.enable_fsdp: print(f" {eval_epoch_loss=}") - if wandb_run: + if wandb_run: wandb_run.log({'eval/loss': eval_epoch_loss,}, commit=False) del loss; del eval_loss; del batch - torch.cuda.empty_cache() + clear_gpu_cache() return eval_epoch_loss, val_step_loss diff --git a/llama_recipes/trainer_finetune.py b/llama_recipes/trainer_finetune.py index 97ad052..04c90c7 100644 --- a/llama_recipes/trainer_finetune.py +++ b/llama_recipes/trainer_finetune.py @@ -113,6 +113,7 @@ def train(model, train_dataloader, eval_dataloader, tokenizer, results = {} best_val_loss = float("inf") 
best_checkpoint_path = None + total_step = 0 for epoch in range(train_config.num_epochs): epoch_start_time = time.perf_counter() @@ -149,7 +150,7 @@ def train(model, train_dataloader, eval_dataloader, tokenizer, if train_config.use_fp16: # if fp16 is enabled, use gradient scaler to handle gradient update scaler.scale(loss).backward() - if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + if (total_step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0: scaler.unscale_(optimizer) if train_config.enable_fsdp: @@ -163,7 +164,7 @@ def train(model, train_dataloader, eval_dataloader, tokenizer, else: # regular backpropagation when fp16 is not used loss.backward() - if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + if (total_step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0: if train_config.enable_fsdp: model.clip_grad_norm_(train_config.gradient_clipping_threshold) @@ -177,7 +178,7 @@ def train(model, train_dataloader, eval_dataloader, tokenizer, if not train_config.enable_fsdp or rank==0: wandb_run.log({ 'train/epoch': epoch + 1, - 'train/step': epoch * len(train_dataloader) + step, + 'train/step': total_step, # epoch * len(train_dataloader) + step, 'train/loss': train_step_loss[-1], 'train/ppl': train_step_perplexity[-1] }) @@ -189,21 +190,24 @@ def train(model, train_dataloader, eval_dataloader, tokenizer, if train_config.save_metrics: save_to_json(metrics_filename, train_step_loss, train_loss, val_step_loss, val_loss) - if step == getattr(train_config, 'num_train_steps', -1): + if total_step == getattr(train_config, 'num_train_steps', -1): break # Early exit for debugging later logic if (train_config.run_validation and ( - (step + 1) % (train_config.eval_steps * gradient_accumulation_steps) == 0)): # or step == len(train_dataloader) - 1)): + (total_step + 1) % (train_config.eval_steps * gradient_accumulation_steps) == 0)): # or step == len(train_dataloader) - 1)): dist.barrier() - eval_outputs = eval_loop(model, optimizer, lr_scheduler, train_config, fsdp_config, rank, - eval_dataloader, local_rank, tokenizer, wandb_run, - val_step_loss, val_loss, best_val_loss, checkpoint_times, epoch, step) + eval_outputs = eval_loop(model, evaluate_lm, optimizer, lr_scheduler, + train_config, fsdp_config, rank, eval_dataloader, + local_rank, tokenizer, wandb_run, + val_step_loss, val_loss, best_val_loss, + checkpoint_times, epoch, total_step) dist.barrier() save_path, val_step_loss, val_loss, best_val_loss, checkpoint_times = eval_outputs if save_path is not None: best_checkpoint_path = save_path model.train() print(f'-> Model is training on rank {rank}') + total_step += 1 pbar.close() epoch_end_time = time.perf_counter()-epoch_start_time @@ -226,9 +230,11 @@ def train(model, train_dataloader, eval_dataloader, tokenizer, # Update the learning rate as needed # lr_scheduler.step() dist.barrier() - eval_outputs = eval_loop(model, evaluate_lm, optimizer, lr_scheduler, train_config, fsdp_config, rank, - eval_dataloader, local_rank, tokenizer, wandb_run, - val_step_loss, val_loss, best_val_loss, checkpoint_times, epoch, step) + eval_outputs = eval_loop(model, evaluate_lm, optimizer, lr_scheduler, + train_config, fsdp_config, rank, eval_dataloader, + local_rank, tokenizer, wandb_run, + val_step_loss, val_loss, 
best_val_loss, + checkpoint_times, epoch, total_step) dist.barrier() save_path, val_step_loss, val_loss, best_val_loss, checkpoint_times = eval_outputs if save_path is not None: @@ -242,7 +248,8 @@ def train(model, train_dataloader, eval_dataloader, tokenizer, return results, best_checkpoint_path -def evaluate_lm(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run): +def evaluate_lm(model, train_config, eval_dataloader, + local_rank, tokenizer, wandb_run, epoch: int = None): """ Evaluates the model on the given dataloader @@ -261,7 +268,9 @@ def evaluate_lm(model, train_config, eval_dataloader, local_rank, tokenizer, wan val_step_loss = [] eval_loss = 0.0 # Initialize evaluation loss - for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)): + _epoch = f' {epoch}' if epoch is not None else '' + pbar = tqdm(eval_dataloader,colour="green", desc=f"Evaluating epoch{_epoch}", dynamic_ncols=True) + for step, batch in enumerate(pbar): for key in batch.keys(): if train_config.enable_fsdp: batch[key] = batch[key].to(local_rank) @@ -278,6 +287,8 @@ def evaluate_lm(model, train_config, eval_dataloader, local_rank, tokenizer, wan val_step_loss.append(loss.detach().cpu().float().item()) eval_loss += loss.detach().float() + _ppl = torch.exp(eval_loss/(step+1)).item() + pbar.set_description(f"Evaluating epoch{_epoch} | step_loss: {loss.item():.5f} | avg_loss: {eval_loss.item()/(step+1):.5f} | avg_ppl: {_ppl:.5f}") # If there's more than one CUDA device, reduce evaluation loss across all devices if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp): @@ -301,7 +312,7 @@ def evaluate_lm(model, train_config, eval_dataloader, local_rank, tokenizer, wan if local_rank == 0 or not train_config.enable_fsdp: print(f" eval_epoch_loss={eval_epoch_loss}, eval_epoch_ppl={eval_epoch_ppl}") - if wandb_run: + if wandb_run: wandb_run.log({'eval/loss': eval_epoch_loss, 'eval/ppl': eval_epoch_ppl}, commit=False) return eval_epoch_loss, val_step_loss \ No newline at end of file
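
A note on the evaluation-loop changes above: the progress bars now report a running average loss, and `evaluate_lm` additionally reports its exponential as perplexity. A minimal, self-contained sketch of that bookkeeping (illustrative only; the per-batch losses below are placeholders, not values from any run):

```python
import torch

# Mirrors the running metrics added to the eval loops:
# avg_loss is the mean loss over batches seen so far; ppl = exp(avg_loss).
step_losses = [2.31, 2.05, 1.98]  # placeholder per-batch eval losses
eval_loss = torch.tensor(0.0)
for step, loss in enumerate(step_losses):
    eval_loss += loss
    avg_loss = eval_loss.item() / (step + 1)
    avg_ppl = torch.exp(eval_loss / (step + 1)).item()  # perplexity from the average loss
    print(f"step_loss: {loss:.5f} | avg_loss: {avg_loss:.5f} | avg_ppl: {avg_ppl:.5f}")
```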