Update 7B README and add instructions for running 70B with llama-recipes
mzio committed Aug 22, 2024
1 parent 4b6eec1 commit c1d0c27
Showing 11 changed files with 313 additions and 63 deletions.
README.md (140 additions, 33 deletions)
In this README:

- Getting started with dependencies, installation, and experiment configs
- Sample commands (Mistral-7B-v0.1, Llama-3-8B, Llama-3.1-8B, Llama-3.1-70B)

---

Please see `environment.yaml` for dependencies. We can set them up with conda:

```
conda env create -f environment.yaml
conda activate lolcats
```

---
```
pretrained_config:
  low_cpu_mem_usage: true
  torch_dtype: bfloat16
  rope_theta: 10000.0
  attn_implementation: eager # if supervising with attention weights
```
---
For now, we implement the causal linear attention with the CUDA kernel from [htt

To build the kernel (`causal_dot_product`), first activate the conda environment (`conda activate lolcats`). Then navigate to `./csrc/` and run `python setup.py install` from within that directory. It's worth checking the arguments in `./csrc/setup.py` to match your GPU setup and C++ version.
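
As a sanity check after building, it can help to compare the kernel against a pure PyTorch reference of the same computation. A minimal sketch of what the (un-normalized) causal linear attention product computes (the reference below is ours, not the kernel's code):

```python
import torch

def causal_dot_product_reference(q, k, v):
    # q, k: feature-mapped queries/keys; v: values; all (batch, heads, seq_len, dim)
    # out_t = phi(q_t) @ sum_{s <= t} phi(k_s) v_s^T, computed here with prefix sums
    kv = torch.einsum('bhld,bhlm->bhldm', k, v).cumsum(dim=2)  # running sums of k_s v_s^T
    return torch.einsum('bhld,bhldm->bhlm', q, kv)

q, k, v = (torch.randn(1, 2, 16, 8) for _ in range(3))
q, k = q.softmax(dim=-1), k.softmax(dim=-1)  # stand-in positive feature map
print(causal_dot_product_reference(q, k, v).shape)  # torch.Size([1, 2, 16, 8])
```

Note this is only the numerator of linear attention; the normalizer (phi(q_t) dotted with the running sum of phi(k_s)) is applied separately in the attention module.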

### ThunderKittens linear attention + sliding window kernel

TODO. See [ThunderKittens](https://github.com/HazyResearch/ThunderKittens).

### More

We're very excited to integrate additional developments like Songlin and friends' [flash-linear-attention](https://github.com/sustcsonglin/flash-linear-attention).

### Flash Attention 2 install

For any of these commands, you may need to provide a Hugging Face token to download model checkpoints.

### Demoing linear attention 7B models

**_Note: Stale_**

We upload a couple of checkpoints in `./checkpoints/`; for any linearized 7B model we only need to save the linear attention layers and the LoRA weights (as two separate `.pt` checkpoints). To chat with these models, you can run:

```
python -Wignore demo_hedgehog_llm.py \
--num_generations 1 --benchmark
```

---

### Linearizing 7B models

<p align="center">
<img src="assets/hedgehog_llamas.png" align='center' width=80% height=80%>
</p>

Any of the commands below will convert a 7B Mistral or Llama LLM into a subquadratic-attention instruction-following variant. Despite only using LoRA and training on just 50K instruction-tuning samples, we're able to "unlock" a good amount of the base model's performance, as measured on LM Eval tasks.

See `configs/model/` for the model configs used in the commands below, and `configs/experiment/` for attention transfer and finetuning configs.
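
For intuition on the config names: `lk_smd_fd64` corresponds to a Hedgehog-style learned feature map, i.e., an untied per-head projection to `feature_dim: 64` followed by a softmax over the feature dimension (`softmax_dim`). A hedged sketch of such a map (the real module lives in this repo's model code; the class name and shapes here are illustrative):

```python
import torch
import torch.nn as nn

class HedgehogFeatureMap(nn.Module):
    """Illustrative Hedgehog-style feature map: phi(x) = softmax(x @ W_h), per head."""
    def __init__(self, num_heads: int = 32, head_dim: int = 128, feature_dim: int = 64):
        super().__init__()
        # Untied weights: a separate (head_dim, feature_dim) projection per head
        self.weight = nn.Parameter(torch.randn(num_heads, head_dim, feature_dim) * head_dim ** -0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, heads, seq_len, head_dim) -> (batch, heads, seq_len, feature_dim)
        return torch.einsum('bhld,hdf->bhlf', x, self.weight).softmax(dim=-1)
```

Queries and keys each pass through such a map before the causal linear-attention product; the softmax keeps the features positive, so the resulting attention weights stay non-negative.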

#### Mistral-7B-v0.1, Hedgehog Feature Map, using Alpaca-Clean

```
python distill_llama.py --model_config distill_mistral_7b_lk_smd_fd64 \
--distill_config distill_alpaca_clean_lr1e-2 \
--finetune_config finetune_lora_qkvo_alpaca_clean \
--eval_config eval_alpaca_clean \
--lk_zero_init \
--verbose --seed 0 --replicate 0 \
--huggingface_token hf_<insert your token here>
```

#### Mistral-7B-v0.1, Hedgehog + ThunderKittens Sliding Window, using Alpaca-Clean

```
python distill_llama.py --model_config distill_mistral_7b_lk_smd_wtk64_fd64_w01 \
--distill_config distill_alpaca_clean_lr1e-2 \
--finetune_config finetune_lora_qkvo_alpaca_clean \
--eval_config eval_alpaca_clean \
--lk_zero_init \
--verbose --seed 0 --replicate 0 \
--huggingface_token hf_<insert your token here>
```

#### Llama 3 8B, Hedgehog Feature Map, using Alpaca-Clean

```
python distill_llama.py --model_config distill_llama3_8b_lk_smd_fd64 \
--distill_config distill_alpaca_clean_lr1e-2 \
--finetune_config finetune_lora_qkvo_alpaca_clean \
--eval_config eval_alpaca_clean \
--lk_skip_connection --lk_zero_init \
--verbose --seed 0 --replicate 0 \
--huggingface_token hf_<insert your token here>
```

#### Llama 3 8B, Hedgehog + ThunderKittens Sliding Window, using Alpaca-Clean

```
python distill_llama.py --model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \
--distill_config distill_alpaca_clean_lr1e-2 \
--finetune_config finetune_lora_qkvo_alpaca_clean \
--eval_config eval_alpaca_clean \
--lk_skip_connection --lk_zero_init \
--verbose --seed 0 --replicate 0 \
--huggingface_token hf_<insert your token here>
```

#### Llama 3.1 8B, Hedgehog Feature Map, using Alpaca-Clean

```
python distill_llama.py --model_config distill_llama3_1_8b_lk_smd_fd64 \
--distill_config distill_alpaca_clean_lr1e-2 \
--finetune_config finetune_lora_qkvo_alpaca_clean \
--eval_config eval_alpaca_clean \
--huggingface_token hf_<insert your token here>
```

#### Llama 3.1 8B, Hedgehog + ThunderKittens Sliding Window, using Alpaca-Clean

```
python distill_llama.py --model_config distill_llama3_1_8b_lk_smd_wtk64_fd64_w01 \
--distill_config distill_alpaca_clean_lr1e-2 \
--finetune_config finetune_lora_qkvo_alpaca_clean \
--eval_config eval_alpaca_clean \
--huggingface_token hf_<insert your token here>
```

---

### Evaluation

The above scripts will save two checkpoints: (1) for the learned attention layer weights (denoted by a `_distill` suffix), (2) for the LoRA finetuning weights (denoted by a `_ft` suffix). To evaluate linearized models from these checkpoints, we can add the `--load_distill_checkpoint` and `--load_finetune_checkpoint` args. For example:

```
python distill_llama.py --model_config distill_llama3_8b_lk_smd_wtk64_fd64_w01 \
--distill_config distill_alpaca_clean_lr1e-2 \
--finetune_config finetune_lora_qkvo_alpaca_clean \
--eval_config eval_alpaca_clean \
--lk_skip_connection --lk_zero_init \
--verbose --seed 0 --replicate 0 \
--load_distill_checkpoint <path-to-distill-checkpoint> \
--load_finetune_checkpoint <path-to-finetune-checkpoint>
```
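
Conceptually, evaluation reassembles the model from the frozen base weights plus the two small checkpoints. A hedged sketch of the load order (the helper functions here are hypothetical stand-ins for this repo's actual setup code):

```python
import torch

# Hypothetical helpers standing in for this repo's setup functions
model = load_base_model('meta-llama/Meta-Llama-3-8B')               # frozen pretrained weights
model = convert_attentions(model, attention_config)                 # swap in linear attentions
model.load_state_dict(torch.load('ckpt_distill.pt'), strict=False)  # learned feature maps (*_distill.pt)
model = attach_lora(model, lora_config)                             # add LoRA adapters to q/k/v/o projections
model.load_state_dict(torch.load('ckpt_ft.pt'), strict=False)       # LoRA weights (*_ft.pt)
```

`strict=False` reflects that each checkpoint covers only a small subset of the full state dict.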

#### LM Evaluation Harness

For sample LM Eval scripts, please see `./lm_eval_harness/README.md`. In particular, this involves cloning the Language Model Evaluation Harness from [here](https://github.com/EleutherAI/lm-evaluation-harness/tree/b281b0921b636bc36ad05c0b0b0763bd6dd43463). Note that we use the `b281b09` branch, following Hugging Face's [Open LLM Leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard).

---

### Linearizing 70B models and up [WIP]

<p align="center">
<img src="assets/hedgehog_llamas_big.png" align='center' width=80% height=80%>
</p>

We also support linearizing larger LLMs (Llama 3.1 70B, Llama 3.1 405B) using the great [llama-recipes](https://github.com/meta-llama/llama-recipes/tree/main/src/llama_recipes) repository.

See `llama_recipes/README.md` for more details. At a high level, we borrow the Fully Sharded Data Parallel (FSDP) pipeline and split the two stages of LoLCATs linearizing into two scripts:

1. `distill_llama.py`: where we first train subquadratic attentions to mimic the softmax attentions (saving the learned attention feature map checkpoints)
2. `distill_llama_finetune.py`: where we swap in the learned attentions and finetune the rest of the model with LoRA (saving the LoRA checkpoints)

By passing the same configuration files and arguments to both scripts, `distill_llama_finetune.py` should automatically load the saved checkpoints and pick up where `distill_llama.py` left off.
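
On the stage-1 objective: the config names encode the supervision weights (e.g., `xent0_mse1000` reads as cross-entropy factor 0, MSE factor 1000, matching the `mse_factor`/`xent_factor` fields in the trainer configs in this commit). A hedged sketch of the implied per-layer loss, as a reading of those names rather than this repo's exact code:

```python
import torch
import torch.nn.functional as F

def attention_transfer_loss(pred_out, true_out, pred_attn, true_attn,
                            mse_factor=1000.0, xent_factor=0.0):
    # Match the softmax attention's outputs with MSE
    loss = mse_factor * F.mse_loss(pred_out, true_out)
    # Optionally also match its attention distributions with cross-entropy;
    # exposing true_attn is why the 7B configs set `attn_implementation: eager`
    if xent_factor > 0:
        xent = -(true_attn * torch.log(pred_attn + 1e-12)).sum(dim=-1).mean()
        loss = loss + xent_factor * xent
    return loss
```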

#### Sample Commands

**_Script 1: Attention Transfer_**

```bash
torchrun --nnodes 1 --nproc_per_node 8 \
llama_recipes/distill_llama.py \
--model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \
--distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \
--finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \
--eval_config eval_alpaca_clean \
--verbose --replicate 0 --seed 0 \
--enable_fsdp --low_cpu_fsdp
```

**_Script 2: Low-rank Adaptation_**

```bash
torchrun --nnodes 1 --nproc_per_node 8 \
llama_recipes/distill_llama_finetune.py \
--model_config distill_llama3_1_70b_lk_smd_wtk64_fd64_w01 \
--distill_config distill_alpaca_clean_xent0_mse1000_lr1e-2 \
--finetune_config finetune_lora_qkvo_alpaca_clean_llama3_1_70b \
--eval_config eval_alpaca_clean \
--verbose --replicate 0 --seed 0 \
--enable_fsdp --low_cpu_fsdp
```

#### GPU Memory Requirements for Training

See https://huggingface.co/blog/llama31#training-memory-requirements
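
As a rough back-of-the-envelope (our arithmetic, not the blog's), the frozen weights alone dominate, since LoRA keeps the trainable parameter count small:

```python
# Rough memory estimate for frozen Llama 3.1 70B weights under FSDP.
# Back-of-envelope only: activations, gradients and optimizer states for
# the trainable parameters, and buffers all come on top of this.
params = 70e9
bytes_per_param = 2                        # bfloat16
total_gb = params * bytes_per_param / 1e9  # ~140 GB
print(f'{total_gb:.0f} GB of weights, ~{total_gb / 8:.1f} GB/GPU sharded across 8 GPUs')
```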

---

## Setup Debugging

### Hugging Face `datasets` errors

If you come across an error like the following:

```
File "/root/miniconda3/envs/hedgehog/lib/python3.12/site-packages/fsspec/spec.py", line 606, in glob
pattern = glob_translate(path + ("/" if ends_with_sep else ""))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/hedgehog/lib/python3.12/site-packages/fsspec/utils.py", line 734, in glob_translate
raise ValueError(
ValueError: Invalid pattern: '**' can only be an entire path component
```

Try reinstalling the Hugging Face `datasets` package with the version pinned, e.g., `pip install datasets==2.15.0`.

Sometimes setting up the virtual environment from `environment.yaml` results in `datasets==2.11.0` being installed instead.

Similarly, you may need to run the following installs:

```bash
pip install nltk
pip install rouge-score
```

### Miniconda installation

```bash
mkdir -p ~/miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm -rf ~/miniconda3/miniconda.sh
~/miniconda3/bin/conda init bash
```
configs/experiment/distill_alpaca_clean_xent1_mse1000_lr1e-2.yaml (52 additions, 0 deletions)
```yaml
dataset:
  name: alpaca_clean
  dataset_config:
    name: alpaca
    path: yahma/alpaca-cleaned
    chunk_size: 1024 # sequence length for distilling
    concat_data: true
    cache_dir: 'data/alpaca' # Change this to where you want to save
    pretrained_model_config: # will be updated based on model_config
      pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3-8B'
      cache_dir: '/scr-ssd/mzhang/models/llama3'
    preprocess_config: null

dataloader:
  batch_size: 1
  num_workers: 2
  drop_last: false
  pin_memory: true

optimizer:
  optim: adamw_torch_fused
  lr: 0.01
  weight_decay: 0.0

lr_scheduler:
  lr_scheduler_type: reduce_lr_on_plateau
  mode: min
  factor: 0.1
  patience: 10
  min_lr: 0.00001

trainer: # HuggingFace Trainer-like arguments
  name: distill_attention_xent_mse
  reverse_kl: false
  mse_factor: 1000
  xent_factor: 1

  bf16: true
  train_split: train
  val_split: validation
  num_train_epochs: 2
  gradient_accumulation_steps: 8
  seed: 42
  batch_size: 1
  load_best_model_at_end: true
  greater_is_better: false
  metric_for_best_model: distill/eval/loss
  logging_steps: 100
  evaluation_strategy: steps
  max_steps: -1
  eval_steps: 100
  max_eval_batches: null
```
configs/model/distill_llama3_1_405b_lk_smd_wtk64_fd64_w01.yaml (39 additions, 0 deletions)
```yaml
name: llama
model:
  pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-405B"
  cache_dir: "/scr-ssd/mzhang/models/llama-3_1-405b" # Set this to where you want to save checkpoint weights
  return_dict: true
  load_in_8bit: false
  load_in_4bit: false
  device_map: auto
  low_cpu_mem_usage: true
  torch_dtype: bfloat16
  attn_implementation: flash_attention_2
  rope_theta: 500000.0
  rope_scaling:
    factor: 8.0
    low_freq_factor: 1.0
    high_freq_factor: 4.0
    original_max_position_embeddings: 8192
    rope_type: llama3

attention:
  attention_type: lolcats_llama_window_tk
  state_chunk_len: 1024
  window_size: 64
  affine_attention_factors: false
  init_window_factor: -2.1972245773362196
  feature_map: softmax_dim
  feature_map_kwargs:
    eps: 1e-12
    # mlp: null # to set
    fullspace: true
  layer_idx: null # to set
  learned_kernel: untied_head_einsum
  learned_kernel_kwargs:
    feature_dim: 64
    skip_connection: false
    bias: false
    zero_init: false
  tie_qk_kernels: false
  train_qk: false
```
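
(A detail worth noting: `init_window_factor` is exactly ln(1/9), so a sigmoid over it yields 0.1, presumably the window mixing factor behind the `_w01` config suffix.)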
distill_llama.py (1 addition, 1 deletion)
```python
def main():
    ...
    wandb.config.update(_flattened)

    # Get pretrained model
    model_loader = get_pretrained_loader(**model_config.model,
                                         huggingface_token=args.huggingface_token)
    tokenizer = model_loader.load_tokenizer()
    tokenizer.pad_token_id = tokenizer.eos_token_id
```
environment.yaml (1 addition, 1 deletion)
```yaml
name: lolcats
channels:
  - conda-forge
  - pytorch
```
llama_recipes/configs/fsdp.py (1 addition, 1 deletion)
```python
class fsdp_config:
    ...
    sharding_group_size: int = 0  # requires hsdp; number of GPUs your model fits into, forming one replica of the model
    replica_group_size: int = 0  # requires hsdp; the replica group size, i.e., world_size / sharding_group_size
    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT  # SHARDED_STATE_DICT saves one file per rank and allows resizing the world size
    fsdp_activation_checkpointing: bool = False  # True
    fsdp_cpu_offload: bool = False
    pure_bf16: bool = False
    optimizer: str = "AdamW"
```