update demos
simran-arora committed Oct 14, 2024
1 parent e15a204 commit 6580604
Showing 18 changed files with 56 additions and 2,977 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -113,7 +113,7 @@ We also implemented a fused linear attention + sliding window kernel with the [T

For the linearizing layer, see [`./src/model/linear_attention/linear_window_attention_tk_gen.py`](https://github.com/HazyResearch/lolcats/blob/main/src/model/linear_attention/linear_window_attention_tk_gen.py)

-You can install the kernel and benchmark 8B models (LoLCATS linearized and Llama Transformer) with and without our ThunderKittens CUDA kernel using the details [in this README.md](). Our 8B model will auto-download from our [HuggingFace checkpoint](https://huggingface.co/hazyresearch/lolcats-llama-3.1-8b-distill).
+You can install the kernel and benchmark 8B models (LoLCATS linearized and Llama Transformer) with and without our ThunderKittens CUDA kernel using the details [in this README.md](). Our 8B model will auto-download from our [HuggingFace checkpoint](https://huggingface.co/hazyresearch/lolcats-llama-3.1-8b-distill).
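
If you want to fetch the checkpoint ahead of time (for example, on a machine with restricted network access at run time), here is a minimal sketch using `huggingface_hub`; the `local_dir` below is an arbitrary example path, and the demo otherwise handles the download automatically:

```python
# Optional pre-download of the LoLCATS checkpoint; the demo scripts
# will otherwise fetch it automatically on first run.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="hazyresearch/lolcats-llama-3.1-8b-distill",
    local_dir="checkpoints/lolcats-llama-3.1-8b-distill",  # example path
)
```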

### More!

@@ -8,7 +8,7 @@ dataset:
cache_dir: 'data/alpaca' # Change this to where you want to save
pretrained_model_config: # will be updated based on model_config
pretrained_model_name_or_path: 'meta-llama/Meta-Llama-3-8B'
-  cache_dir: '/scr-ssd/mzhang/models/llama3'
+  cache_dir: '/scratch/'
preprocess_config: null

dataloader:
2 changes: 1 addition & 1 deletion configs/experiment/eval_alpaca_clean.yaml
@@ -8,7 +8,7 @@ dataset:
cache_dir: 'data/alpaca' # Change this to where you want to save
pretrained_model_config:
pretrained_model_name_or_path: 'mistralai/Mistral-7B-v0.1' # will be updated based on model_config
-  cache_dir: '/scr-ssd/mzhang/models/mistral-v0.1'
+  cache_dir: '/scratch/'
preprocess_config: null

dataloader:
2 changes: 1 addition & 1 deletion configs/experiment/finetune_lora_qkvo_alpaca_clean.yaml
@@ -8,7 +8,7 @@ dataset:
cache_dir: "data/alpaca"
pretrained_model_config:
pretrained_model_name_or_path: "mistralai/Mistral-7B-v0.1" # will be updated based on model_config
-  cache_dir: "/data_persistent2/sim_data/"
+  cache_dir: "/scratch/"
preprocess_config: null

dataloader:
@@ -1,7 +1,7 @@
name: llama
model:
pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
-  cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights
+  cache_dir: "/scratch/" # Set this to where you want to save checkpoint weights
return_dict: true
load_in_8bit: false
load_in_4bit: false
@@ -1,7 +1,7 @@
name: llama
model:
pretrained_model_name_or_path: "meta-llama/Meta-Llama-3.1-8B"
-  cache_dir: "/scr-ssd/mzhang/models/llama-3_1-8b" # Set this to where you want to save checkpoint weights
+  cache_dir: "/scratch/" # Set this to where you want to save checkpoint weights
return_dict: true
load_in_8bit: false
load_in_4bit: false
1 change: 1 addition & 0 deletions demos/README.md
@@ -32,6 +32,7 @@ As a quick end-to-end compare the prefill speed of the linearized LoLCATS 8B vs.
```bash
bash benchmark_8b.sh
```
+Our benchmarking implementation is currently restricted to prefill lengths that are multiples of 64.

The code will print out the inference tokens per second per method.
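
The demo script constructs prompts itself, but if you adapt it, one simple way to satisfy the multiple-of-64 restriction above is to left-pad the tokenized prompt. A minimal sketch, assuming a standard Hugging Face tokenizer; this helper is illustrative and not part of the demo code:

```python
# Illustrative helper: pad a tokenized prompt up to the next multiple
# of 64 so the prefill length satisfies the kernel's restriction.
def pad_to_multiple(input_ids, pad_token_id, multiple=64):
    remainder = len(input_ids) % multiple
    if remainder == 0:
        return input_ids
    # Left-pad so the prompt's real tokens stay adjacent to generation.
    return [pad_token_id] * (multiple - remainder) + input_ids
```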

22 changes: 11 additions & 11 deletions demos/benchmark_8b.sh
@@ -1,14 +1,14 @@


-CONFIG_DIR='/home/bfs/simran/clean4/lolcats/configs/' # update to your path
+CONFIG_DIR='/home/bfs/simran/attention/lolcats/configs/' # update to your path

-""" Benchmarking the 8b model on the LOLCATS dataset """
+# """ Benchmarking the 8b model on the LOLCATS dataset """

# Run the linearized model with the ThunderKittens kernel
CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \
-  --model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
-  --distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \
-  --finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \
+  --model_config_path ${CONFIG_DIR}/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
+  --distill_config_path ${CONFIG_DIR}/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \
+  --finetune_config_path ${CONFIG_DIR}/experiment/finetune_lora_qkvo_alpaca_clean.yaml \
--attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \
--finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \
--num_generations 1 \
@@ -18,9 +18,9 @@ CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \

# Run the linearized model *without* the ThunderKittens kernel
CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \
-  --model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
-  --distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \
-  --finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \
+  --model_config_path ${CONFIG_DIR}/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
+  --distill_config_path ${CONFIG_DIR}/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \
+  --finetune_config_path ${CONFIG_DIR}/experiment/finetune_lora_qkvo_alpaca_clean.yaml \
--attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \
--finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \
--num_generations 1 \
@@ -30,9 +30,9 @@ CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \

# Run the base Llama model with Transformers SDPA attention
CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \
-  --model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
-  --distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \
-  --finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \
+  --model_config_path ${CONFIG_DIR}/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
+  --distill_config_path ${CONFIG_DIR}/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \
+  --finetune_config_path ${CONFIG_DIR}/experiment/finetune_lora_qkvo_alpaca_clean.yaml \
--attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \
--finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \
--num_generations 1 \
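
For reference, the tokens-per-second number these runs print can be approximated with a simple timed generation. A rough sketch, not the demo's actual implementation; it assumes a standard Hugging Face `model` and pre-tokenized `input_ids` already on the GPU:

```python
import time
import torch

def tokens_per_second(model, input_ids, max_new_tokens=50):
    """Time one generation call and return decode throughput (tokens/sec)."""
    torch.cuda.synchronize()  # make sure prior GPU work doesn't skew timing
    start = time.time()
    output = model.generate(input_ids, max_new_tokens=max_new_tokens, do_sample=False)
    torch.cuda.synchronize()
    elapsed = time.time() - start
    return (output.shape[-1] - input_ids.shape[-1]) / elapsed
```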
17 changes: 10 additions & 7 deletions demos/demo_8b.sh
@@ -1,16 +1,19 @@

-CONFIG_DIR='/home/bfs/simran/clean4/lolcats/configs/' # update to your path
+CONFIG_DIR='/home/bfs/simran/attention/lolcats/configs/' # update to your path

-# using huggingface checkpoints
+# Using huggingface checkpoints
CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \
-  --model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
-  --distill_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \
-  --finetune_config_path ${CONFIG_DIR}/experiment/llama3_1_8b/finetune_qkvo_alpaca_clean.yaml \
+  --model_config_path ${CONFIG_DIR}/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
+  --distill_config_path ${CONFIG_DIR}/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml \
+  --finetune_config_path ${CONFIG_DIR}/experiment/finetune_lora_qkvo_alpaca_clean.yaml \
--attn_mlp_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-distill' \
--finetune_checkpoint_path 'hazyresearch/lolcats-llama-3.1-8b-ft-lora' \
-  --num_generations 1
+  --num_generations 1 \
+  --max_new_tokens 50

-# if you train your own LoLCATS weights, you can use the following command to run inference:

+# Reference script:
+# if you train your own LoLCATS weights, you can use the following command to run inference with your local checkpoints:
# CHECKPOINT_DIR='/home/mzhang/projects/lolcats/checkpoints/'
# CUDA_VISIBLE_DEVICES=0 python -Wignore demo_lolcats_hf.py \
# --model_config_path ${CONFIG_DIR}/model/llama3_1_8b/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml \
82 changes: 0 additions & 82 deletions demos/vllm_integration/README.md

This file was deleted.
