Commit e15a204 (parent d0f0e67): update some demo stuff from other branch
Showing 19 changed files with 3,849 additions and 115 deletions.
Binary files removed (not shown):
- 13.2 MB: ...une_lora_qkvo_alpaca_clean-s=0-se=0-re=420-lzi=1-bs=1-gas=8-nte=2-ms=-1-se=0-re=420_ft.pt
- 32.1 MB: ..._lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=420-lzi=1_distill.pt
- 13.2 MB: ...-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=12-lzi=1-bs=1-gas=8-nte=2-se=0-re=12_ft.pt
- 32.1 MB: ...b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=12-lzi=1_distill.pt
@@ -0,0 +1,58 @@
## Demos

We describe how to use LoLCATS checkpoints. We include:
1. A demo script to talk to our models using Hugging Face checkpoints
2. A demo script to benchmark the pretrained 8B linearized model against the base softmax-attention model
3. Code to reproduce the MMLU numbers for the 70B and 405B models using our uploaded Hugging Face checkpoints
4. Coming soon: vLLM integration with custom LoLCATS CUDA kernels!

### Talk to pre-trained LoLCATS LLMs

Use the commands provided in `demo_8b.sh` to run inference with our LoLCATS Llama 3.1 8B checkpoint, which will be downloaded from Hugging Face. The downloaded checkpoints take up under 1 GB and are inserted into your local Meta Llama 3.1 model in 16-bit precision. Please ensure you have downloaded the base model and specify your path to it in the configs used by `demo_8b.sh` (a download sketch follows the command below). To run the demo:
```bash
bash demo_8b.sh
```
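If you have not yet downloaded the base model, the sketch below shows one way to fetch it with the `huggingface_hub` Python API. This is not part of the LoLCATS scripts; the repo id `meta-llama/Meta-Llama-3.1-8B`, the access token, and the local target directory are assumptions to adapt to your setup (Meta's license must be accepted on Hugging Face first).
```python
# Sketch: download the base Meta Llama 3.1 8B weights so the demo configs can point to them.
# Repo id, token, and local_dir are placeholders; adapt them to your environment.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="meta-llama/Meta-Llama-3.1-8B",                   # assumed base-model repo id
    token="hf_...",                                            # your Hugging Face access token (placeholder)
    local_dir="./checkpoints/meta-llama/Meta-Llama-3.1-8B",   # hypothetical local path
)
print(f"Base model downloaded to: {local_dir}")
```
Afterwards, set the base-model path entry in the configs referenced by `demo_8b.sh` to this directory.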
### Fast inference with custom CUDA kernels

We provide a custom CUDA prefill kernel written in the [ThunderKittens framework](https://github.com/HazyResearch/ThunderKittens).

To install the kernel:
```bash
# Clone the repo
git clone https://github.com/HazyResearch/ThunderKittens
cd ThunderKittens
# In config.py, select 'hedgehog', then run:
source env.src
python setup.py install
```

For a quick end-to-end comparison of the prefill speed of the linearized LoLCATS 8B model against the base Llama 8B model, we provide a script:
```bash
bash benchmark_8b.sh
```

The code prints the inference tokens per second for each method.
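For reference, the reported figure has the usual form: tokens processed divided by wall-clock time. The sketch below is not the repo's benchmark code; it only illustrates how a prefill tokens-per-second number can be measured with Hugging Face Transformers, assuming a locally available Llama 3.1 8B checkpoint (model path and prompt are placeholders).
```python
# Sketch: measure prefill throughput (prompt tokens per second) for a causal LM.
# Not the LoLCATS benchmark itself; model path and prompt length are placeholders.
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "meta-llama/Meta-Llama-3.1-8B"   # assumed; point at your local base model
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="auto"
)

# Use a long prompt so the measurement is dominated by prefill, not overhead.
prompt = "lorem ipsum " * 1000
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
num_prompt_tokens = inputs["input_ids"].shape[-1]

with torch.no_grad():
    model(**inputs)                      # warm-up pass
    torch.cuda.synchronize()
    start = time.perf_counter()
    model(**inputs)                      # timed prefill: one forward pass over the prompt
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

print(f"Prefill: {num_prompt_tokens / elapsed:.0f} tokens/sec")
```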
### 5-shot MMLU Eval

First, get the 5-shot MMLU data. We directly saved the tokenized examples produced by the [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) codebase to a pickle file:
```bash
cd lolcats/inference/
unzip mmlu.pkl.zip
```

We provide scripts to evaluate our 70B and 405B LoLCATS linearized checkpoints from Hugging Face on MMLU:
```bash
cd lolcats/
bash demos/llm_mmlu_eval/demo_70b.sh    # runs on 1 8x80GB H100 node
sbatch demos/llm_mmlu_eval/demo_405b.sh # set to use 2 8x80GB H100 nodes
```

These scripts call `demos/llm_mmlu_eval/eval_mmlu.py`, which loops through `mmlu.pkl` and uses the last-token model logits to get the predictions.
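A minimal sketch of that last-token scoring loop is below. The structure of `mmlu.pkl` is not documented here, so the field names (`input_ids`, `choice_token_ids`, `label`) are assumptions for illustration, and the stand-in model load would be replaced by the linearized LoLCATS model; see `demos/llm_mmlu_eval/eval_mmlu.py` for the actual implementation.
```python
# Sketch of a last-token MMLU scoring loop; not the repo's eval_mmlu.py.
# Pickle field names (input_ids, choice_token_ids, label) are assumed for illustration.
import pickle
import torch
from transformers import AutoModelForCausalLM

# Stand-in model load; the real script loads the linearized LoLCATS checkpoint instead.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B", torch_dtype=torch.bfloat16, device_map="auto"
)

with open("mmlu.pkl", "rb") as f:
    examples = pickle.load(f)   # assumed: a list of tokenized 5-shot MMLU prompts

correct = 0
for ex in examples:
    input_ids = torch.tensor(ex["input_ids"]).unsqueeze(0).to(model.device)
    with torch.no_grad():
        logits = model(input_ids).logits           # (1, seq_len, vocab_size)
    last_logits = logits[0, -1]                     # logits for the token after the prompt
    # Score only the answer-choice tokens (e.g. " A", " B", " C", " D") and take the argmax.
    choice_scores = last_logits[ex["choice_token_ids"]]
    pred = int(torch.argmax(choice_scores))
    correct += int(pred == ex["label"])

print(f"MMLU accuracy: {correct / len(examples):.3f}")
```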
### vLLM Integration

Coming soon!
@@ -0,0 +1,41 @@
CONFIG_DIR='/home/bfs/simran/clean4/lolcats/configs/'   # update to your path

# Benchmarking the 8B model

# Run the linearized model with the ThunderKittens kernel
CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \
--model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
--distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \
--finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \
--attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \
--finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \
--num_generations 1 \
--use_cuda_kernels 1 \
--benchmark \
--max_new_tokens 1

# Run the linearized model *without* the ThunderKittens kernel
CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \
--model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
--distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \
--finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \
--attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \
--finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \
--num_generations 1 \
--use_cuda_kernels 0 \
--benchmark \
--max_new_tokens 1

# Run the base Llama model with Transformers SDPA attention
CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \
--model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
--distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \
--finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \
--attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \
--finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \
--num_generations 1 \
--use_attention \
--benchmark \
--max_new_tokens 1
@@ -0,0 +1,24 @@
CONFIG_DIR='/home/bfs/simran/clean4/lolcats/configs/'   # update to your path

# Using Hugging Face checkpoints
CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \
--model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
--distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \
--finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \
--attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \
--finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \
--num_generations 1

# If you train your own LoLCATS weights, you can use the following command to run inference:
# CHECKPOINT_DIR='/home/mzhang/projects/lolcats/checkpoints/'
# CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \
# --model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
# --distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \
# --finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \
# --attn_mlp_checkpoint_path ${CHECKPOINT_DIR}/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-ms=2500-se=0-re=100-lzi=1_distill.pt \
# --finetune_checkpoint_path ${CHECKPOINT_DIR}/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-ms=2500-se=0-re=100-lzi=1-bs=1-gas=8-nte=2-ms=2500-se=0-re=100_ft.pt \
# --num_generations 1