diff --git a/README.md b/README.md
index 4d619bf..17a1191 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@
In this README:
- Getting started with dependencies, installation, and experiment configs
-- Sample commands (Mistral-7B-v0.1, Llama-2-7B, Llama-3-8B, Mixtral-8x7B, Llama-2-70B, Llama-3-70B)
+- Sample commands (Mistral-7B-v0.1, Llama-3-8B, Llama-3.1-8B, Llama-3.1-70B)
---
@@ -19,7 +19,7 @@ Please see `environment.yaml` for dependencies. We can set them up with conda:
```
conda env create -f environment.yaml
-conda activate hedgehog
+conda activate lolcats
```
---
@@ -45,7 +45,7 @@ pretrained_config:
low_cpu_mem_usage: true
torch_dtype: bfloat16
rope_theta: 10000.0
- attn_implementation: eager # so we can supervise with attention weights
+ attn_implementation: eager # if supervising with attention weights
```
---
@@ -72,7 +72,13 @@ For now, we implement the causal linear attention with the CUDA kernel from [htt
-To build the kernel (`causal_dot_product`), first activate the conda environment (`conda activate hedgehog`). Then navigate to `./csrc/` and run `python setup.py install` within `./csrc/`. It's worth checking the arguments in `./csrc/setup.py` to match your GPU setup and C++ versions.
+To build the kernel (`causal_dot_product`), first activate the conda environment (`conda activate lolcats`). Then navigate to `./csrc/` and run `python setup.py install`. It's worth checking the arguments in `./csrc/setup.py` to match your GPU setup and C++ versions.
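+
+For reference, here is a minimal pure-PyTorch sketch of the causal linear attention computation that kernels like `causal_dot_product` accelerate. It assumes queries and keys have already been passed through the feature map, and is meant purely to illustrate the math (it is not the repo's actual implementation); the fused kernel avoids materializing the large prefix-sum tensor below.
+
+```python
+import torch
+
+def causal_linear_attention_reference(q, k, v, eps=1e-12):
+    """Naive reference for causal linear attention.
+    q, k: feature-mapped queries / keys, shape (batch, heads, seq_len, feat_dim)
+    v:    values,                        shape (batch, heads, seq_len, head_dim)
+    """
+    # Prefix sums over positions s <= t of k_s v_s^T and of k_s.
+    kv = torch.einsum('bhld,bhlm->bhldm', k, v).cumsum(dim=2)     # (B, H, L, D, M)
+    k_sum = k.cumsum(dim=2)                                       # (B, H, L, D)
+    num = torch.einsum('bhld,bhldm->bhlm', q, kv)                 # numerator
+    den = torch.einsum('bhld,bhld->bhl', q, k_sum).unsqueeze(-1)  # normalizer
+    return num / (den + eps)
+```
+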
-TODO: we're very excited to integrate additional developments like Songlin and friends' `flash-linear-attention` [repo](https://github.com/sustcsonglin/flash-linear-attention), as well as [ThunderKittens](https://github.com/HazyResearch/ThunderKittens). Please let us know if you're interested in scaling up these efficient linear attention implementations into 7B to 70B models.
+### ThunderKittens linear attention + sliding window kernel
+
+TODO: instructions for the [ThunderKittens](https://github.com/HazyResearch/ThunderKittens) linear attention + sliding window kernel are coming soon.
+
+### More
+
+We're very excited to integrate additional developments like Songlin and friends' [flash-linear-attention](https://github.com/sustcsonglin/flash-linear-attention) repo. Please let us know if you're interested in scaling up these efficient linear attention implementations into 7B to 70B models.
### Flash Attention 2 install
@@ -92,6 +98,8 @@ For any of these commands, you may need to provide a Hugging Face token to downl
### Demoing linear attention 7B models
+**_Note: this demo is stale and may not reflect the latest code and checkpoints._**
+
We upload a couple checkpoints in `./checkpoints/`, where for any linearized 7B model we only need to save the linear attention layers and the LoRA weights (in two separate `.pt` checkpoints). To chat with these models, you can run:
```
@@ -101,36 +109,70 @@ python -Wignore demo_hedgehog_llm.py \
--num_generations 1 --benchmark
```
-### Distilling + finetuning 7B models
+---
+
+### Linearizing 7B models
+
-Any of the below commands will convert a 7B Mistral or Llama LLM into a linear attention instruction-following variant. Despite only using LoRA and training on these 50K instruction-tuning samples, we're able to ``unlock'' a good amount of the base model performance when measured on LM Eval tasks.
+Any of the below commands will convert a 7B Mistral or Llama LLM into a subquadratic-attention, instruction-following variant. Despite only using LoRA and training on the 50K Alpaca-Clean instruction-tuning samples, we're able to "unlock" a good amount of the base model's performance as measured on LM Eval tasks.
+
+See `configs/model/` for the model configs used in the commands below, and `configs/experiment/` for the attention transfer and finetuning configs.
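+
+As a rough illustration of the "Hedgehog Feature Map" referenced in the headings below: it is a small learned map applied to queries and keys before computing linear attention. The sketch here assumes a per-head learned projection followed by a softmax over the feature dimension (suggested by `feature_map: softmax_dim` and `feature_dim: 64` in the model configs); the exact parameterization in this repo (tied vs. untied heads, skip connections, etc.) is set by the `--lk_*` flags and the configs' `learned_kernel_kwargs`.
+
+```python
+import torch
+import torch.nn as nn
+
+class HedgehogStyleFeatureMap(nn.Module):
+    """Illustrative (not the repo's exact) learned feature map: a per-head
+    linear projection followed by a softmax over the feature dimension."""
+    def __init__(self, num_heads: int, head_dim: int, feature_dim: int = 64):
+        super().__init__()
+        # One projection per head ("untied heads"), no bias.
+        self.proj = nn.Parameter(torch.randn(num_heads, head_dim, feature_dim) * head_dim ** -0.5)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (batch, heads, seq_len, head_dim) -> (batch, heads, seq_len, feature_dim)
+        z = torch.einsum('bhld,hdf->bhlf', x, self.proj)
+        return torch.softmax(z, dim=-1)  # non-negative features that sum to 1
+```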
#### Mistral-7B-v0.1, Hedgehog Feature Map, using Alpaca-Clean
```
-python distill_llama.py --model_config distill_mistral_7b_lk_smd_zi \
+python distill_llama.py --model_config distill_mistral_7b_lk_smd_fd64 \
+--distill_config distill_alpaca_clean_lr1e-2 \
+--finetune_config finetune_lora_qkvo_alpaca_clean \
+--eval_config eval_alpaca_clean \
+--lk_zero_init \
+--verbose --seed 0 --replicate 0 \
+--huggingface_token hf_
+```
+
+#### Mistral-7B-v0.1, Hedgehog + ThunderKittens Sliding Window, using Alpaca-Clean
+
+```
+python distill_llama.py --model_config distill_mistral_7b_lk_smd_wtk64_fd64_w01 \
+--distill_config distill_alpaca_clean_lr1e-2 \
+--finetune_config finetune_lora_qkvo_alpaca_clean \
+--eval_config eval_alpaca_clean \
+--lk_zero_init \
+--verbose --seed 0 --replicate 0 \
+--huggingface_token hf_
+```
+
+#### Llama 3 8B, Hedgehog Feature Map, using Alpaca-Clean
+
+```
+python distill_llama.py --model_config distill_llama3_8b_lk_smd_fd64 \
--distill_config distill_alpaca_clean_lr1e-2 \
--finetune_config finetune_lora_qkvo_alpaca_clean \
--eval_config eval_alpaca_clean \
--lk_skip_connection --lk_zero_init \
---verbose --seed 0 --replicate 0
+--verbose --seed 0 --replicate 0 \
+--huggingface_token hf_
```
-#### Mistral-7B-v0.1, Hedgehog + Sliding Window, using Alpaca-Clean
+#### Llama 3 8B, Hedgehog + ThunderKittens Sliding Window, using Alpaca-Clean
```
-python distill_llama.py --model_config distill_mistral_7b_lk_smd_zi_swa16_hh \
+python distill_llama.py --model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \
--distill_config distill_alpaca_clean_lr1e-2 \
--finetune_config finetune_lora_qkvo_alpaca_clean \
--eval_config eval_alpaca_clean \
--lk_skip_connection --lk_zero_init \
---verbose --seed 0 --replicate 0
+--verbose --seed 0 --replicate 0 \
+--huggingface_token hf_
```
-#### Llama-3-8B, Hedgehog Feature Map, using Alpaca-Clean
+#### Llama 3.1 8B, Hedgehog Feature Map, using Alpaca-Clean
```
-python distill_llama.py --model_config distill_llama3_8b_lk_smd_zi \
+python distill_llama.py --model_config distill_llama3_1_8b_lk_smd_fd64 \
--distill_config distill_alpaca_clean_lr1e-2 \
--finetune_config finetune_lora_qkvo_alpaca_clean \
--eval_config eval_alpaca_clean \
@@ -139,10 +181,10 @@ python distill_llama.py --model_config distill_llama3_8b_lk_smd_zi \
--huggingface_token hf_
```
-#### Llama-2-7B, Hedgehog Feature Map, using Alpaca-Clean
+#### Llama 3.1 8B, Hedgehog + ThunderKittens Sliding Window, using Alpaca-Clean
```
-python distill_llama.py --model_config distill_llama2_7b_lk_smd_zi \
+python distill_llama.py --model_config distill_llama3_1_8b_lk_smd_wtk64_fd64_w01 \
--distill_config distill_alpaca_clean_lr1e-2 \
--finetune_config finetune_lora_qkvo_alpaca_clean \
--eval_config eval_alpaca_clean \
@@ -151,45 +193,110 @@ python distill_llama.py --model_config distill_llama2_7b_lk_smd_zi \
--huggingface_token hf_
```
+---
+
### Evaluation
The above scripts will save two checkpoints: (1) for the learned attention layer weights (denoted by a `_distill` suffix), (2) for the LoRA finetuning weights (denoted by a `_ft` suffix). To evaluate linearized models from these checkpoints, we can add the `--load_distill_checkpoint` and `--load_finetune_checkpoint` args. For example:
```
-python distill_llama.py --model_config distill_mistral_7b_lk_smd_zi_swa16_hh \
+python distill_llama.py --model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \
--distill_config distill_alpaca_clean_lr1e-2 \
--finetune_config finetune_lora_qkvo_alpaca_clean \
--eval_config eval_alpaca_clean \
--lk_skip_connection --lk_zero_init \
--verbose --seed 0 --replicate 0 \
---load_distill_checkpoint ./checkpoints/distill_mistral_7b_lk_smd_zi_swa16_hh/dl-d=distill_alpaca_clean_kld_mse_mistral_lr1e-3-m=distill_mistral_7b_lk_smd_swa_hh-f=finetune_lora_qkvo_alpaca_clean_mistral-s=0-se=0-re=0-lk=untied_head_einsum-lsc=1_distill.pt \
---load_finetune_checkpoint ./checkpoints/distill_mistral_7b_lk_smd_zi_swa16_hh/dl-d=distill_alpaca_clean_kld_mse_mistral_lr1e-3-m=distill_mistral_7b_lk_smd_swa_hh-f=finetune_lora_qkvo_alpaca_clean_mistral-s=0-se=0-re=0-lk=untied_head_einsum-lsc=1-se=0-re=0_ft.pt
+--load_distill_checkpoint ./checkpoints/<model_config>/<run_name>_distill.pt \
+--load_finetune_checkpoint ./checkpoints/<model_config>/<run_name>_ft.pt
```
-#### LM Eval
+#### LM Evaluation Harness
For sample LM Eval scripts, please see `./lm_eval_harness/README.md`. In particular, this will involve cloning the Language Model Evaluation Harness from [here](https://github.com/EleutherAI/lm-evaluation-harness/tree/b281b0921b636bc36ad05c0b0b0763bd6dd43463). Note we use the `b281b09` branch following Hugging Face's [Open LLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard).
---
-## Updated example commands
+### Linearizing 70B models and up [WIP]
+
+We also support linearizing larger LLMs (Llama 3.1 70B, Llama 3.1 405B) using the great [llama-recipes](https://github.com/meta-llama/llama-recipes/tree/main/src/llama_recipes) repository.
+
+See `llama_recipes/README.md` for more details. At a high level, we borrow llama-recipes' Fully Sharded Data Parallel (FSDP) training pipeline and split LoLCATs linearizing into two stages, one script each:
+
+1. `distill_llama.py`: where we first train subquadratic attentions to mimic the softmax attentions (saving the learned attention feature map checkpoints)
+2. `distill_llama_finetune.py`: where we swap in the learned attentions and finetune the rest of the model with LoRA (saving the LoRA checkpoints)
+
+By passing the same configuration files and arguments to both scripts, `distill_llama_finetune.py` should automatically load the checkpoints saved by `distill_llama.py` and pick up where it left off.
+
+#### Sample Commands
+
+**_Script 1: Attention Transfer_**
```bash
-python distill_llama.py \
---model_config distill_long_llama3_8b_lk_smd_wtk64_fd64_w01 \
---distill_config distill_long_alpaca_8k_xent0_mse1000_lr1e-2_bs1 \
---finetune_config finetune_long_lora_qkvo_alpaca_clean_8192 \
---eval_config eval_alpaca_clean \
---lk_zero_init --verbose --seed 0 --replicate 614 --state_chunk_len 1024 \
---num_train_epochs 2
+torchrun --nnodes 1 --nproc_per_node 8 \
+llama_recipes/distill_llama.py \
+--model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \
+--distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \
+--finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \
+--eval_config eval_alpaca_clean \
+--verbose --replicate 0 --seed 0 \
+--enable_fsdp --low_cpu_fsdp
```
+**_Script 2: Low-rank Adaptation_**
+
```bash
-python distill_llama.py \
---model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \
+torchrun --nnodes 1 --nproc_per_node 8 \
+llama_recipes/distill_llama_finetune.py \
+--model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \
--distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \
---finetune_config finetune_lora_qkvo_alpaca_clean \
---eval_config eval_alpaca_clean \
---lk_zero_init --verbose --seed 0 --replicate 614 \
---num_train_epochs 2
+--finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \
+--eval_config eval_alpaca_clean \
+--verbose --replicate 0 --seed 0 \
+--enable_fsdp --low_cpu_fsdp
+```
+
+#### GPU Memory Training Requirements
+
+See Hugging Face's notes on [Llama 3.1 training memory requirements](https://huggingface.co/blog/llama31#training-memory-requirements).
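+
+As a very rough back-of-envelope (ignoring activations, optimizer state for the small set of trainable parameters, temporary all-gather buffers, and framework overhead), the frozen bf16 base weights alone take about `num_params * 2 bytes / num_gpus` per GPU under full sharding. A hedged sketch of that arithmetic:
+
+```python
+def per_gpu_weight_memory_gb(num_params: float, num_gpus: int, bytes_per_param: int = 2) -> float:
+    """Rough per-GPU memory for frozen bf16 weights under full sharding (FSDP).
+    Ignores activations, trainable-parameter optimizer state, and other overhead."""
+    return num_params * bytes_per_param / num_gpus / 1e9
+
+print(per_gpu_weight_memory_gb(70e9, 8))   # ~17.5 GB/GPU: Llama 3.1 70B weights fit on one 8-GPU node
+print(per_gpu_weight_memory_gb(405e9, 8))  # ~101 GB/GPU: 405B needs more GPUs (multi-node) or quantization
+```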
+
+---
+
+## Setup Debugging
+
+### Hugging Face `datasets` errors
+
+If you come across an error like the following:
+
+```
+ File "/root/miniconda3/envs/hedgehog/lib/python3.12/site-packages/fsspec/spec.py", line 606, in glob
+ pattern = glob_translate(path + ("/" if ends_with_sep else ""))
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ File "/root/miniconda3/envs/hedgehog/lib/python3.12/site-packages/fsspec/utils.py", line 734, in glob_translate
+ raise ValueError(
+ValueError: Invalid pattern: '**' can only be an entire path component
+```
+
+Try reinstalling the Hugging Face `datasets` package at a newer version, e.g., via `pip install datasets==2.15.0`.
+
+Setting up the virtual environment from `environment.yaml` sometimes installs `datasets==2.11.0` instead, which can trigger this error.
+
+Similarly, you may need to run the following installs:
+
+```bash
+pip install nltk
+pip install rouge-score
+```
+
+### Miniconda installation
+
+```bash
+mkdir -p ~/miniconda3
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
+bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
+rm -rf ~/miniconda3/miniconda.sh
+~/miniconda3/bin/conda init bash
```
diff --git a/configs/experiment/distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml b/configs/experiment/distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml
new file mode 100644
index 0000000..530ba68
--- /dev/null
+++ b/configs/experiment/distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml
@@ -0,0 +1,52 @@
+dataset:
+ name: alpaca_clean
+ dataset_config:
+ name: alpaca
+ path: yahma/alpaca-cleaned
+ chunk_size: 1024 # sequence length for distilling
+ concat_data: true
+ cache_dir: 'data/alpaca' # Change this to where you want to save
+ pretrained_model_config: # will be updated based on model_config
+ pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3-8B'
+ cache_dir: '/scr-ssd/mzhang/models/llama3' # placeholder; updated based on model_config
+ preprocess_config: null
+
+dataloader:
+ batch_size: 1
+ num_workers: 2
+ drop_last: false
+ pin_memory: true
+
+optimizer:
+ optim: adamw_torch_fused
+ lr: 0.01
+ weight_decay: 0.0
+
+lr_scheduler:
+ lr_scheduler_type: reduce_lr_on_plateau
+ mode: min
+ factor: 0.1
+ patience: 10
+ min_lr: 0.00001
+
+trainer: # HuggingFace Trainer-like arguments
+ name: distill_attention_xent_mse
+ reverse_kl: false
+ mse_factor: 1000 # weight on the MSE attention-matching term of the distillation loss
+ xent_factor: 1 # weight on the cross-entropy attention-matching term
+
+ bf16: true
+ train_split: train
+ val_split: validation
+ num_train_epochs: 2
+ gradient_accumulation_steps: 8
+ seed: 42
+ batch_size: 1
+ load_best_model_at_end: true
+ greater_is_better: false
+ metric_for_best_model: distill/eval/loss
+ logging_steps: 100
+ evaluation_strategy: steps
+ max_steps: -1
+ eval_steps: 100
+ max_eval_batches: null
diff --git a/configs/model/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01.yaml b/configs/model/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01.yaml
new file mode 100644
index 0000000..dd9f87f
--- /dev/null
+++ b/configs/model/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01.yaml
@@ -0,0 +1,39 @@
+name: llama
+model:
+ pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-405B"
+ cache_dir: "/scr-ssd/mzhang/models/llama-3_1-405b" # Set this to where you want to save checkpoint weights
+ return_dict: true
+ load_in_8bit: false
+ load_in_4bit: false
+ device_map: auto
+ low_cpu_mem_usage: true
+ torch_dtype: bfloat16
+ attn_implementation: flash_attention_2
+ rope_theta: 500000.0
+ rope_scaling:
+ factor: 8.0
+ low_freq_factor: 1.0
+ high_freq_factor: 4.0
+ original_max_position_embeddings: 8192
+ rope_type: llama3
+
+attention:
+ attention_type: lolcats_llama_window_tk
+ state_chunk_len: 1024
+ window_size: 64
+ affine_attention_factors: false
+ init_window_factor: -2.1972245773362196
+ feature_map: softmax_dim
+ feature_map_kwargs:
+ eps: 1e-12
+ # mlp: null # to set
+ fullspace: true
+ layer_idx: null # to set
+ learned_kernel: untied_head_einsum
+ learned_kernel_kwargs:
+ feature_dim: 64
+ skip_connection: false
+ bias: false
+ zero_init: false
+ tie_qk_kernels: false
+ train_qk: false
diff --git a/distill_llama.py b/distill_llama.py
index 22c543a..eb0b5bb 100644
--- a/distill_llama.py
+++ b/distill_llama.py
@@ -147,7 +147,7 @@ def main():
wandb.config.update(_flattened)
# Get pretrained model
- model_loader = get_pretrained_loader(**model_config.model,
+ model_loader = get_pretrained_loader(**model_config.model,
huggingface_token=args.huggingface_token)
tokenizer = model_loader.load_tokenizer()
tokenizer.pad_token_id = tokenizer.eos_token_id
diff --git a/environment.yaml b/environment.yaml
index 1c16ae6..71bd267 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -1,4 +1,4 @@
-name: hedgehog
+name: lolcats
channels:
- conda-forge
- pytorch
diff --git a/llama_recipes/configs/fsdp.py b/llama_recipes/configs/fsdp.py
index aec168c..4d754c0 100644
--- a/llama_recipes/configs/fsdp.py
+++ b/llama_recipes/configs/fsdp.py
@@ -15,7 +15,7 @@ class fsdp_config:
sharding_group_size : int=0 # requires hsdp to be set. This specifies the sharding group size, number of GPUs that you model can fit into to form a replica of a model.
replica_group_size: int=0 #requires hsdp to be set. This specifies the replica group size, which is world_size/sharding_group_size.
checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
- fsdp_activation_checkpointing: bool=True
+ fsdp_activation_checkpointing: bool=False # set True to re-enable FSDP activation checkpointing (previous default)
fsdp_cpu_offload: bool=False
pure_bf16: bool = False
optimizer: str= "AdamW"
diff --git a/llama_recipes/dev_scripts.md b/llama_recipes/dev_scripts.md
new file mode 100644
index 0000000..289020f
--- /dev/null
+++ b/llama_recipes/dev_scripts.md
@@ -0,0 +1,35 @@
+### Various development scripts
+
+For your viewing pleasure.
+
+**_Initial testing of the distillation recipe_**
+
+You may need to prepend the commands below with `NCCL_CUMEM_ENABLE=0`.
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 7 \
+llama_recipes/distill_llama.py \
+--model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \
+--distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \
+--finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \
+--eval_config eval_alpaca_clean \
+--verbose --replicate 0 --seed 0 \
+--enable_fsdp --low_cpu_fsdp \
+--dataset_chunk_size 1024 \
+--attention_type lolcats_llama_window_tk_bf16 \
+--eval_steps 1 --dataset_chunk_size 256
+```
+
+```bash
+torchrun --nnodes 1 --nproc_per_node 7 \
+llama_recipes/distill_llama_finetune.py \
+--model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \
+--distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \
+--finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \
+--eval_config eval_alpaca_clean \
+--verbose --replicate 0 --seed 0 \
+--enable_fsdp --low_cpu_fsdp \
+--dataset_chunk_size 1024 \
+--attention_type lolcats_llama_window_tk_bf16 \
+--eval_steps 1 --dataset_chunk_size 256
+```
diff --git a/llama_recipes/distill_llama.py b/llama_recipes/distill_llama.py
index 9426017..5c3a0fd 100644
--- a/llama_recipes/distill_llama.py
+++ b/llama_recipes/distill_llama.py
@@ -190,9 +190,9 @@ def get_args():
args.run_name += f'-npgc={args.no_peft_grad_ckpt}'
if args.fsdp_activation_checkpointing is not None:
args.run_name += f'-fac={args.fsdp_activation_checkpointing}'
- if args.dataset_chunk_size is not None:
- args.run_name += f'-dcs={args.dataset_chunk_size}'
- args.run_name += f'-s={args.seed}'
+ # if args.dataset_chunk_size is not None:
+ # args.run_name += f'-dcs={args.dataset_chunk_size}'
+ # args.run_name += f'-s={args.seed}'
if args.debug:
args.run_name += '-debug'
diff --git a/llama_recipes/model_checkpointing/distill_checkpoint_handler.py b/llama_recipes/model_checkpointing/distill_checkpoint_handler.py
index cc028d2..09aa677 100644
--- a/llama_recipes/model_checkpointing/distill_checkpoint_handler.py
+++ b/llama_recipes/model_checkpointing/distill_checkpoint_handler.py
@@ -10,8 +10,8 @@
FullyShardedDataParallel as FSDP,
StateDictType,
FullStateDictConfig, # general model non-sharded, non-flattened params
- # LocalStateDictConfig, # flattened params, usable only by FSDP
- # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes.
+ LocalStateDictConfig, # flattened params, usable only by FSDP
+ ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes.
)
from torch.distributed._shard.checkpoint import (
@@ -186,9 +186,13 @@ def save_model_and_optimizer_sharded(model, rank, cfg,optim=None):
)
t0 = time.perf_counter()
- with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
+ with FSDP.state_dict_type(model,
+ StateDictType.SHARDED_STATE_DICT,
+ ShardedStateDictConfig(offload_to_cpu=True), # offload gathered shards to CPU to reduce GPU memory while saving
+ ):
# state_dict = {"model": model.state_dict()}
state_dict = model.state_dict()
+ # state_dict = model.state_dict(state_dict_device='cpu')
ignore_params = [
_rename_sharded(n)
# n.replace('_fsdp_wrapped_module.','').replace('._checkpoint_wrapped_module', '').replace('.mlp._flat_param', '.mlp.layer').replace('._flat_param', '.weight')
diff --git a/llama_recipes/trainer_attention.py b/llama_recipes/trainer_attention.py
index 4ad06c4..7bc8bf0 100644
--- a/llama_recipes/trainer_attention.py
+++ b/llama_recipes/trainer_attention.py
@@ -301,11 +301,11 @@ def train(model, train_dataloader, eval_dataloader, tokenizer,
return results, best_checkpoint_path
-def eval_loop(model, evaluate_func, optimizer, lr_scheduler,
- train_config, fsdp_config, rank, eval_dataloader,
+def eval_loop(model, evaluate_func, optimizer, lr_scheduler,
+ train_config, fsdp_config, rank, eval_dataloader,
local_rank, tokenizer, wandb_run,
- val_step_loss, val_loss, best_val_loss, checkpoint_times,
- epoch, step): # extra globals
+ val_step_loss, val_loss, best_val_loss,
+ checkpoint_times, epoch, step): # extra globals
"""
Evaluate model and save checkpoints
- see `evaluate_func` for evaluation logic
@@ -389,7 +389,7 @@ def evaluate_attn(model, train_config, eval_dataloader,
val_step_loss = []
eval_loss = 0.0 # Initialize evaluation loss
- _epoch = epoch if epoch is not None else ''
+ _epoch = f' {epoch}' if epoch is not None else ''
pbar = tqdm(eval_dataloader,colour="green", desc=f"Evaluating epoch{_epoch}", dynamic_ncols=True)
for step, batch in enumerate(pbar):
for key in batch.keys():
@@ -409,6 +409,8 @@ def evaluate_attn(model, train_config, eval_dataloader,
eval_loss += loss.detach().float()
+ pbar.set_description(f"Evaluating epoch{_epoch} | step_loss: {loss.item():.5f} | avg_loss: {eval_loss.item()/(step+1):.5f}")
+
# If there's more than one CUDA device, reduce evaluation loss across all devices
if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
@@ -429,11 +431,11 @@ def evaluate_attn(model, train_config, eval_dataloader,
if local_rank == 0 or not train_config.enable_fsdp:
print(f" {eval_epoch_loss=}")
- if wandb_run:
+ if wandb_run:
wandb_run.log({'eval/loss': eval_epoch_loss,}, commit=False)
del loss; del eval_loss; del batch
- torch.cuda.empty_cache()
+ clear_gpu_cache()
return eval_epoch_loss, val_step_loss
diff --git a/llama_recipes/trainer_finetune.py b/llama_recipes/trainer_finetune.py
index 97ad052..04c90c7 100644
--- a/llama_recipes/trainer_finetune.py
+++ b/llama_recipes/trainer_finetune.py
@@ -113,6 +113,7 @@ def train(model, train_dataloader, eval_dataloader, tokenizer,
results = {}
best_val_loss = float("inf")
best_checkpoint_path = None
+ total_step = 0  # count batches across epochs; drives gradient accumulation, logging, and eval cadence
for epoch in range(train_config.num_epochs):
epoch_start_time = time.perf_counter()
@@ -149,7 +150,7 @@ def train(model, train_dataloader, eval_dataloader, tokenizer,
if train_config.use_fp16:
# if fp16 is enabled, use gradient scaler to handle gradient update
scaler.scale(loss).backward()
- if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+ if (total_step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
scaler.unscale_(optimizer)
if train_config.enable_fsdp:
@@ -163,7 +164,7 @@ def train(model, train_dataloader, eval_dataloader, tokenizer,
else:
# regular backpropagation when fp16 is not used
loss.backward()
- if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+ if (total_step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
if train_config.enable_fsdp:
model.clip_grad_norm_(train_config.gradient_clipping_threshold)
@@ -177,7 +178,7 @@ def train(model, train_dataloader, eval_dataloader, tokenizer,
if not train_config.enable_fsdp or rank==0:
wandb_run.log({
'train/epoch': epoch + 1,
- 'train/step': epoch * len(train_dataloader) + step,
+ 'train/step': total_step, # epoch * len(train_dataloader) + step,
'train/loss': train_step_loss[-1],
'train/ppl': train_step_perplexity[-1]
})
@@ -189,21 +190,24 @@ def train(model, train_dataloader, eval_dataloader, tokenizer,
if train_config.save_metrics:
save_to_json(metrics_filename, train_step_loss, train_loss, val_step_loss, val_loss)
- if step == getattr(train_config, 'num_train_steps', -1):
+ if total_step == getattr(train_config, 'num_train_steps', -1):
break # Early exit for debugging later logic
if (train_config.run_validation and (
- (step + 1) % (train_config.eval_steps * gradient_accumulation_steps) == 0)): # or step == len(train_dataloader) - 1)):
+ (total_step + 1) % (train_config.eval_steps * gradient_accumulation_steps) == 0)): # or step == len(train_dataloader) - 1)):
dist.barrier()
- eval_outputs = eval_loop(model, optimizer, lr_scheduler, train_config, fsdp_config, rank,
- eval_dataloader, local_rank, tokenizer, wandb_run,
- val_step_loss, val_loss, best_val_loss, checkpoint_times, epoch, step)
+ eval_outputs = eval_loop(model, evaluate_lm, optimizer, lr_scheduler,
+ train_config, fsdp_config, rank, eval_dataloader,
+ local_rank, tokenizer, wandb_run,
+ val_step_loss, val_loss, best_val_loss,
+ checkpoint_times, epoch, total_step)
dist.barrier()
save_path, val_step_loss, val_loss, best_val_loss, checkpoint_times = eval_outputs
if save_path is not None:
best_checkpoint_path = save_path
model.train()
print(f'-> Model is training on rank {rank}')
+ total_step += 1
pbar.close()
epoch_end_time = time.perf_counter()-epoch_start_time
@@ -226,9 +230,11 @@ def train(model, train_dataloader, eval_dataloader, tokenizer,
# Update the learning rate as needed
# lr_scheduler.step()
dist.barrier()
- eval_outputs = eval_loop(model, evaluate_lm, optimizer, lr_scheduler, train_config, fsdp_config, rank,
- eval_dataloader, local_rank, tokenizer, wandb_run,
- val_step_loss, val_loss, best_val_loss, checkpoint_times, epoch, step)
+ eval_outputs = eval_loop(model, evaluate_lm, optimizer, lr_scheduler,
+ train_config, fsdp_config, rank, eval_dataloader,
+ local_rank, tokenizer, wandb_run,
+ val_step_loss, val_loss, best_val_loss,
+ checkpoint_times, epoch, total_step)
dist.barrier()
save_path, val_step_loss, val_loss, best_val_loss, checkpoint_times = eval_outputs
if save_path is not None:
@@ -242,7 +248,8 @@ def train(model, train_dataloader, eval_dataloader, tokenizer,
return results, best_checkpoint_path
-def evaluate_lm(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run):
+def evaluate_lm(model, train_config, eval_dataloader,
+ local_rank, tokenizer, wandb_run, epoch: int = None):
"""
Evaluates the model on the given dataloader
@@ -261,7 +268,9 @@ def evaluate_lm(model, train_config, eval_dataloader, local_rank, tokenizer, wan
val_step_loss = []
eval_loss = 0.0 # Initialize evaluation loss
- for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
+ _epoch = f' {epoch}' if epoch is not None else ''
+ pbar = tqdm(eval_dataloader,colour="green", desc=f"Evaluating epoch{_epoch}", dynamic_ncols=True)
+ for step, batch in enumerate(pbar):
for key in batch.keys():
if train_config.enable_fsdp:
batch[key] = batch[key].to(local_rank)
@@ -278,6 +287,8 @@ def evaluate_lm(model, train_config, eval_dataloader, local_rank, tokenizer, wan
val_step_loss.append(loss.detach().cpu().float().item())
eval_loss += loss.detach().float()
+ _ppl = torch.exp(eval_loss/(step+1)).item()
+ pbar.set_description(f"Evaluating epoch{_epoch} | step_loss: {loss.item():.5f} | avg_loss: {eval_loss.item()/(step+1):.5f} | avg_ppl: {_ppl:.5f}")
# If there's more than one CUDA device, reduce evaluation loss across all devices
if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
@@ -301,7 +312,7 @@ def evaluate_lm(model, train_config, eval_dataloader, local_rank, tokenizer, wan
if local_rank == 0 or not train_config.enable_fsdp:
print(f" eval_epoch_loss={eval_epoch_loss}, eval_epoch_ppl={eval_epoch_ppl}")
- if wandb_run:
+ if wandb_run:
wandb_run.log({'eval/loss': eval_epoch_loss, 'eval/ppl': eval_epoch_ppl}, commit=False)
return eval_epoch_loss, val_step_loss
\ No newline at end of file