Commit
update some demo stuff from other branch
simran-arora committed Oct 14, 2024
1 parent d0f0e67 commit e15a204
Showing 19 changed files with 3,849 additions and 115 deletions.
27 changes: 8 additions & 19 deletions README.md
@@ -19,6 +19,8 @@ In this README:
- Getting started with dependencies, installation, and experiment configs
- Sample commands for 7B+ LLMs (e.g., Mistral-7B-v0.1, Llama-3-8B, Llama-3.1-8B; anything you can run on a single GPU)

In the [lolcats-scaled branch](https://github.com/HazyResearch/lolcats/tree/lolcats-scaled), we provide details for larger 70B and 405B LLMs.

---

## Getting started
@@ -107,11 +109,11 @@ python setup.py install

### ThunderKittens linear attention + sliding window kernel

We also implemented a fused linear attention + sliding window kernel with the [ThunderKittens CUDA framework](https://github.com/HazyResearch/ThunderKittens).

For the linearizing layer, see [`./src/model/linear_attention/linear_window_attention_tk_gen.py`](https://github.com/HazyResearch/lolcats/blob/main/src/model/linear_attention/linear_window_attention_tk_gen.py).

You can install the kernel and benchmark 8B models (LoLCATS-linearized and Llama Transformer) with and without our ThunderKittens CUDA kernel using the details in [demos/README.md](demos/README.md). Our 8B model will auto-download from our [HuggingFace checkpoint](https://huggingface.co/hazyresearch/lolcats-llama-3.1-8b-distill).
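For intuition, here is a toy, non-causal sketch of the linear attention half of that kernel. The `elu`-based feature map below is a common stand-in, not the one LoLCATS learns, and the real ThunderKittens kernel additionally fuses in a sliding window of exact softmax attention over recent tokens:

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v):
    """Toy, non-causal linear attention: O(n) in sequence length.

    q, k, v: (batch, seqlen, dim). A fixed elu-based feature map stands in
    for the learned LoLCATS feature map.
    """
    q, k = F.elu(q) + 1, F.elu(k) + 1            # phi(q), phi(k) > 0
    kv = torch.einsum('bnd,bne->bde', k, v)      # sum_n phi(k_n) v_n^T
    z = torch.einsum('bnd,bd->bn', q, k.sum(1))  # per-query normalizer
    return torch.einsum('bnd,bde->bne', q, kv) / (z.unsqueeze(-1) + 1e-6)

q = k = v = torch.randn(1, 16, 64)
print(linear_attention(q, k, v).shape)  # torch.Size([1, 16, 64])
```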

### More!

@@ -247,27 +249,14 @@ python distill_llama.py --model_config distill_llama3_1_8b_lk_t2r \

### Demoing linear attention 7B+ models

The above scripts will save two checkpoints: (1) the learned attention feature maps (denoted by a `_distill` suffix), and (2) the LoRA finetuning weights (denoted by a `_ft` suffix). We uploaded a couple of starter checkpoints in `./checkpoints/`; for any linearized LLM, we only need to save these layers (~0.2% of a 7B LLM's parameters). We have also provided [sample checkpoints on HuggingFace](https://huggingface.co/collections/hazyresearch/lolcats-670ca4341699355b61238c37).
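To make that checkpoint structure concrete, here is a toy sketch of inserting two small state dicts (each covering only the trained layers) into a larger model via `strict=False`; the module below is illustrative, not the repo's actual classes:

```python
import torch
import torch.nn as nn

# Toy stand-in for a linearized LLM: the base weights stay frozen, and only
# a small subset (feature maps + LoRA factors) is learned and checkpointed.
class ToyLinearizedBlock(nn.Module):
    def __init__(self, d=64, r=8):
        super().__init__()
        self.base_proj = nn.Linear(d, d)           # from the base model
        self.feature_map = nn.Linear(d, d)         # learned in distillation (_distill)
        self.lora_a = nn.Linear(d, r, bias=False)  # learned in finetuning (_ft)
        self.lora_b = nn.Linear(r, d, bias=False)

model = ToyLinearizedBlock()

# Simulate the two small checkpoints: each holds only a parameter subset.
distill_ckpt = {k: v for k, v in model.state_dict().items() if 'feature_map' in k}
ft_ckpt = {k: v for k, v in model.state_dict().items() if 'lora' in k}

# strict=False inserts each subset, leaving the larger base weights untouched.
for ckpt in (distill_ckpt, ft_ckpt):
    missing, unexpected = model.load_state_dict(ckpt, strict=False)
```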

Use the commands provided in `demos/demo_8b.sh` to run inference with our LoLCATS - Llama 3.1 8B checkpoint, which will be downloaded from HuggingFace. The downloaded checkpoints require less than 1GB of storage and are inserted into your local Meta Llama 3.1 model in 16-bit precision -- please ensure you have downloaded the base model and specify your path to it in the configs referenced in `demos/demo_8b.sh`. To run the demo:
```bash
cd lolcats/
bash demos/demo_8b.sh
```


---

### LM Evaluation Harness Evaluation
58 changes: 58 additions & 0 deletions demos/README.md
@@ -0,0 +1,58 @@

## Demos

We describe how to use LoLCATS checkpoints. We include:
1. Demo script to talk to our models using Hugging Face checkpoints
2. Demo script to benchmark the pretrained 8B linearized versus base softmax attention models
3. Code to reproduce the MMLU numbers for the 70B and 405B models using our uploaded HuggingFace checkpoints
4. Coming soon: VLLM integration with custom LoLCATS CUDA kernels!

### Talk to pre-trained LoLCATS LLMs

Use the commands provided in `demo_8b.sh` to run inference with our LoLCATS - Llama 3.1 8B checkpoint, which will be downloaded from Hugging Face. The downloaded checkpoints require less than 1GB of storage and are inserted into your local Meta Llama 3.1 model in 16-bit precision -- please ensure you have downloaded the base model and specify your path to it in the configs in `demo_8b.sh`. To run the demo:
```bash
bash demo_8b.sh
```
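Under the hood, pulling a checkpoint from the Hub boils down to something like the following sketch (the filename `model.pt` is an assumption; inspect the HuggingFace repo for the actual artifact names):

```python
import torch
from huggingface_hub import hf_hub_download

# Hypothetical filename -- check the HF repo for the actual artifact name.
ckpt_path = hf_hub_download(
    repo_id='hazyresearch/lolcats-llama-3.1-8b-distill',
    filename='model.pt',
)
state_dict = torch.load(ckpt_path, map_location='cpu')
```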

### Fast inference with custom CUDA kernels

We provide a custom CUDA prefill kernel written in the [ThunderKittens framework](https://github.com/HazyResearch/ThunderKittens).

To install the kernel:
```bash
# Clone the repo
git clone https://github.com/HazyResearch/ThunderKittens
cd ThunderKittens
# In config.py, select 'hedgehog', then run:
source env.src
python setup.py install
```

For a quick end-to-end comparison of the prefill speed of the linearized LoLCATS 8B model vs. the base Llama 8B model, we provide a script:
```bash
bash benchmark_8b.sh
```

The code will print out the inference tokens per second for each method.
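If you want to reproduce the measurement outside the script, tokens per second is typically computed as in this minimal sketch (assuming a HuggingFace-style `generate` API; this is not the exact timing code in `demo_lolcats_hf.py`):

```python
import time
import torch

@torch.no_grad()
def tokens_per_second(model, input_ids, max_new_tokens=128):
    """Time one generate() call and report decoded tokens per second."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure prior GPU work is done
    start = time.time()
    output_ids = model.generate(input_ids, max_new_tokens=max_new_tokens)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for generation to finish
    num_new_tokens = output_ids.shape[-1] - input_ids.shape[-1]
    return num_new_tokens / (time.time() - start)
```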

### 5-shot MMLU Eval

First get the 5-shot MMLU data. We directly saved the tokenized examples produced by the [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) codebase to a pickle file:
```bash
cd lolcats/inference/
unzip mmlu.pkl.zip
```

We provide scripts to evaluate our 70B and 405B LoLCATS linearized checkpoints (hosted on HuggingFace) on MMLU:
```bash
cd lolcats/
bash demos/llm_mmlu_eval/demo_70b.sh # runs on 1 8x80GB H100 node
sbatch demos/llm_mmlu_eval/demo_405b.sh # set to use 2 8x80GB H100 nodes
```

These scripts call the `demos/llm_mmlu_eval/eval_mmlu.py` file, which loops through `mmlu.pkl` and uses the last-token model logits to get the predictions.
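As a rough illustration of that last-token scoring (the record layout inside `mmlu.pkl` is an assumption here, not its actual schema):

```python
import torch

@torch.no_grad()
def mmlu_accuracy(model, examples, answer_token_ids):
    """Score each 5-shot example by the model's last-token logits,
    restricted to the answer-choice token ids (e.g., for 'A'..'D')."""
    correct = 0
    for ex in examples:  # assumed: {'input_ids': LongTensor, 'label': int}
        logits = model(ex['input_ids'].unsqueeze(0)).logits  # (1, seq, vocab)
        choice_logits = logits[0, -1, answer_token_ids]      # last position
        correct += int(choice_logits.argmax().item() == ex['label'])
    return correct / len(examples)
```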


### VLLM Integration

Coming Soon!
41 changes: 41 additions & 0 deletions demos/benchmark_8b.sh
@@ -0,0 +1,41 @@


CONFIG_DIR='/home/bfs/simran/clean4/lolcats/configs/' # update to your path

""" Benchmarking the 8b model on the LOLCATS dataset """

# Run the linearized model with the ThunderKittens kernel
CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \
--model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
--distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \
--finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \
--attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \
--finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \
--num_generations 1 \
--use_cuda_kernels 1 \
--benchmark \
--max_new_tokens 1

# Run the linearized model *without* the ThunderKittens kernel
CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \
--model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
--distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \
--finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \
--attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \
--finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \
--num_generations 1 \
--use_cuda_kernels 0 \
--benchmark \
--max_new_tokens 1

# Run the base Llama model with Transformers SDPA attention
CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \
--model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
--distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \
--finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \
--attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \
--finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \
--num_generations 1 \
--use_attention \
--benchmark \
--max_new_tokens 1
24 changes: 24 additions & 0 deletions demos/demo_8b.sh
@@ -0,0 +1,24 @@

CONFIG_DIR='/home/bfs/simran/clean4/lolcats/configs/' # update to your path

# using huggingface checkpoints
CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \
--model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
--distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \
--finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \
--attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \
--finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \
--num_generations 1

# if you train your own LoLCATS weights, you can use the following command to run inference:
# CHECKPOINT_DIR='/home/mzhang/projects/lolcats/checkpoints/'
# CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \
# --model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
# --distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \
# --finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \
# --attn_mlp_checkpoint_path ${CHECKPOINT_DIR}/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-ms=2500-se=0-re=100-lzi=1_distill.pt \
# --finetune_checkpoint_path ${CHECKPOINT_DIR}/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-ms=2500-se=0-re=100-lzi=1-bs=1-gas=8-nte=2-ms=2500-se=0-re=100_ft.pt \
# --num_generations 1


