Commit 2a1fc6d

Add local train demo

mzio committed Oct 14, 2024
1 parent 80f8bd7 commit 2a1fc6d

Showing 2 changed files with 374 additions and 1 deletion.
30 changes: 29 additions & 1 deletion README.md

@@ -249,7 +249,35 @@ python distill_llama.py --model_config distill_llama3_1_8b_lk_t2r \

### Demoing linear attention 7B+ models

#### Trained from `distill_llama.py`
The above scripts will save two checkpoints: (1) the learned attention feature maps (denoted by a `_distill` suffix) and (2) the LoRA finetuning weights (denoted by a `_ft` suffix). For any linearized LLM, these layers are all we need to save (~0.2% of a 7B LLM's parameters). The demo script parses the config names back out of these filenames; see the sketch after the examples below.

For example, the saved filepaths might look like:

1. Trained linear attention feature maps:
```
./checkpoints/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=0-lzi=1_distill.pt
```

2. Trained attention LoRA weights:
```
./checkpoints/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=0-lzi=1-bs=1-gas=8-nte=2-ms=-1-se=0-re=0_ft.pt
```
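
A minimal sketch of that filename parsing (it mirrors `load_model_from_checkpoint` in `demo_lolcats_llm.py` below; the `-d=`, `-m=`, and `-f=` fields name the distill experiment, model, and finetune configs):

```python
# Sketch of the filename parsing done in demo_lolcats_llm.py
ckpt = ('dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2'
        '-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01'
        '-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=0-lzi=1_distill.pt')
distill_config = ckpt.split('-d=')[-1].split('-m=')[0]  # -> configs/experiment/<name>.yaml
model_config = ckpt.split('-m=')[-1].split('-f=')[0]    # -> configs/model/<name>.yaml
finetune_config = ckpt.split('-f=')[-1].split('-')[0]   # -> configs/experiment/<name>.yaml
print(model_config)  # distill_llama3_1_8b_lk_smd_wtk64_fd64_w01
```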

To chat with these models, you can run the demo script like so (albeit with slower PyTorch implementations):

```bash
python -Wignore demo_lolcats_llm.py \
--attn_mlp_checkpoint_path './checkpoints/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=0-lzi=1_distill.pt' \
--finetune_checkpoint_path './checkpoints/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=0-lzi=1-bs=1-gas=8-nte=2-ms=-1-se=0-re=0_ft.pt' \
--num_generations 1 --benchmark
```
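
You can also drive generation from your own script. The sketch below uses the helpers defined in `demo_lolcats_llm.py` (added in this commit) and assumes you pass the full checkpoint paths from the example above, since the config names are parsed out of them:

```python
# Sketch: programmatic use of the demo's loading helpers.
# Replace the <...> placeholders with the full checkpoint paths shown above.
import torch
from demo_lolcats_llm import load_model_from_checkpoint, system_prompt

model, model_config, tokenizer = load_model_from_checkpoint(
    attn_mlp_checkpoint_path='<path to ..._distill.pt>',
    finetune_checkpoint_path='<path to ..._ft.pt>',
)
model.eval()
with torch.no_grad():
    inputs = tokenizer(system_prompt.format(prompt='Tell me about LoLCATs.'),
                       return_tensors='pt').to(model.device)
    output = model.generate(**inputs, use_cache=True, max_new_tokens=128,
                            do_sample=True, top_k=50, top_p=0.95,
                            pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```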


#### Hugging Face checkpoints

We also provide some [sample checkpoints on HuggingFace](https://huggingface.co/collections/hazyresearch/lolcats-670ca4341699355b61238c37).

Use the commands provided in `demos/demo_8b.sh` to run inference with our LoLCATs Llama 3.1 8B checkpoint, which will be downloaded from HuggingFace. The downloaded checkpoints take up less than 1 GB and are inserted into your local Meta Llama 3.1 model in 16-bit precision. Please ensure you have downloaded the base model and that the configs in `demo_8b.sh` point to your local path for it.
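
A minimal invocation might look like the following (a sketch only; see `demos/demo_8b.sh` itself for the exact commands and the config paths you may need to edit first):

```bash
# Sketch: run the HuggingFace-checkpoint demo from the repo root
bash demos/demo_8b.sh
```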
345 changes: 345 additions & 0 deletions demo_lolcats_llm.py
@@ -0,0 +1,345 @@
"""
Quick demo of linearized LLM generations
"""
from typing import Optional, List
from os.path import join
import argparse
import copy  # needed by get_lm_eval_lolcats_model
import sys   # needed by get_lm_eval_lolcats_model
import time

import torch

from omegaconf import OmegaConf

from transformers import TextStreamer, TextIteratorStreamer, AutoTokenizer

from src.utils.setup import seed_everything
from src.utils.logging import print_header
from src.model.pretrained import get_pretrained_loader
from src.model.load_model import load_and_convert_attns, load_and_convert_finetune


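# Alpaca-style instruction-following template; every chat message is wrapped in it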
system_prompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{prompt}
### Response:
"""


def get_args():
parser = argparse.ArgumentParser()
# Model load + setup
parser.add_argument("--attn_mlp_checkpoint_path", type=str, default=None)
parser.add_argument("--finetune_checkpoint_path", type=str, default=None)
parser.add_argument("--config_dir", type=str, default='configs')
parser.add_argument("--seed", type=int, default=42)

# Generation
parser.add_argument("--num_generations", type=int, default=1)
parser.add_argument("--top_k", type=int, default=50)
parser.add_argument("--top_p", type=float, default=0.95)
parser.add_argument("--max_new_tokens", type=int, default=1024)

# Miscellaneous
parser.add_argument("--benchmark", action='store_true', default=False)
parser.add_argument("--print_model", action='store_true', default=False)
parser.add_argument("--debug", action='store_true', default=False)
parser.add_argument("--huggingface_token", type=str, default=None)

# Alt
parser.add_argument("--attn_checkpoint_path", type=str, default=None)
parser.add_argument("--peft_checkpoint_path", type=str, default=None)

args = parser.parse_args()
if args.attn_mlp_checkpoint_path is None and args.attn_checkpoint_path is not None:
args.attn_mlp_checkpoint_path = args.attn_checkpoint_path
if args.finetune_checkpoint_path is None and args.peft_checkpoint_path is not None:
args.finetune_checkpoint_path = args.peft_checkpoint_path
return args


def get_lm_eval_lolcats_model(model_kwargs: dict, lolcats_model: bool = True):
    """
    Load a LoLCATs model wrapped for the LM Evaluation Harness
    (helper kept for evaluation; not called by this demo's main())
    """
    lm_kwargs = copy.deepcopy(model_kwargs)
    lm_kwargs['pretrained'] = lm_kwargs['pretrained_model_name_or_path']
    lm_kwargs['dtype'] = str(lm_kwargs['torch_dtype']).split('.')[-1]
    del lm_kwargs['torch_dtype']

    if 'Llama' in lm_kwargs['pretrained_model_name_or_path']:
        lm_kwargs['device_map'] = None
        from lm_eval_harness.models import ShardedLolcatsLlamaForCausalLM
        lm = ShardedLolcatsLlamaForCausalLM.create_from_arg_string(
            '', lm_kwargs,
        )
    else:
        # NOTE: LM_EVALUATION_HARNESS_PATH must be defined to point at a local
        # clone of EleutherAI's lm-evaluation-harness before taking this branch.
        sys.path.append(LM_EVALUATION_HARNESS_PATH)
        from lm_eval.models import get_model

        lm = get_model('hf-causal-experimental').create_from_arg_string(
            '', lm_kwargs,
        )
    return lm


class BatchTextIteratorStreamer(TextIteratorStreamer):
"""
Copied from https://discuss.huggingface.co/t/textiteratorstreamer-compatibility-with-batch-processing/46763/2
"""
    def __init__(self,
                 tokenizer: AutoTokenizer,
                 batch_size: int,
                 skip_prompt: bool = False,
                 timeout: Optional[float] = None,
                 **decode_kwargs):
        super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
        self.batch_size = batch_size
        self.token_cache = [[] for _ in range(batch_size)]  # per-sequence undecoded tokens
        self.print_len = [0 for _ in range(batch_size)]     # chars already printed per sequence
        self.generate_exception = None
        self.go_up = batch_size  # tracks printed lines (for cursor-rewrite display variants)
        self.stop_signal = tokenizer.eos_token

def put(self, value):
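        # `generate` may hand us a flat tensor with the next token for every
        # sequence in the batch; reshape it to 2D so each row is decoded into
        # its own per-sequence cache below.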
if len(value.shape) != 2:
value = torch.reshape(value, (self.batch_size, value.shape[0] // self.batch_size))

if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
return

printable_texts = list()
for idx in range(self.batch_size):
self.token_cache[idx].extend(value[idx].tolist())
text = self.tokenizer.decode(self.token_cache[idx], **self.decode_kwargs)

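            # Decide how much decoded text is safe to print this step: flush on a
            # completed line, print CJK characters immediately, else stop at the
            # last word boundary so partially decoded words never appear.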
if text.endswith("\n"):
printable_text = text[self.print_len[idx] :]
self.token_cache[idx] = []
self.print_len[idx] = 0
self.go_up += 1
# If the last token is a CJK character, we print the characters.
elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
printable_text = text[self.print_len[idx] :]
self.print_len[idx] += len(printable_text)
            else:
                printable_text = text[self.print_len[idx] : text.rfind(" ") + 1]
                self.print_len[idx] += len(printable_text)
            printable_texts.append(printable_text)

self.on_finalized_text(printable_texts)

def end(self):
printable_texts = list()
for idx in range(self.batch_size):
if len(self.token_cache[idx]) > 0:
text = self.tokenizer.decode(self.token_cache[idx], **self.decode_kwargs)
printable_text = text[self.print_len[idx] :]
self.token_cache[idx] = []
self.print_len[idx] = 0
            else:
                printable_text = ""
            printable_texts.append(printable_text)

self.next_tokens_are_prompt = True
self.on_finalized_text(printable_texts, stream_end=True)

    def on_finalized_text(self, texts: List[str], stream_end: bool = False):
        self.text_queue.put(texts, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)

        try:
            # Zip the queued per-step chunks back into one string per sequence,
            # then print all sequences separated by a divider.
            text = [
                ''.join([x[i] if i < len(x) else self.stop_signal
                         for x in self.text_queue.queue])
                for i in range(len(self.text_queue.queue[0]))
            ]
            text = '\n------------\n'.join(text)
            print(f'{text}', flush=True, end="" if not stream_end else None)
        except Exception:
            print(self.stop_signal)

def count_params(module) -> int:
return sum(p.numel() for p in module.parameters())


def setup_fsdp_config(config, args, checkpoint_name: str = 'finetune'):
    """
    Hacky arguments for the llama-recipes training function
    (kept for parity with the training scripts; this demo's argument
    parser does not define all of these fields)
    """
    config.seed = args.seed
    config.enable_fsdp = args.enable_fsdp
    config.low_cpu_fsdp = args.low_cpu_fsdp
    config.dist_checkpoint_root_folder = args.checkpoint_dir
    config.dist_checkpoint_folder = checkpoint_name

    config.model_name = args.run_name
    config.use_peft = False  # We have custom logic for saving PEFT modules
    config.save_model = True
    config.run_validation = True
    config.use_fp16 = False
    config.save_optimizer = False
    config.output_dir = args.checkpoint_dir
    config.save_metrics = not args.no_wandb
    config.gradient_clipping = False
    config.gradient_clipping_threshold = 1.0
    config.num_epochs = getattr(config.trainer, 'num_train_epochs', None)
    config.num_train_steps = getattr(args, 'num_train_steps', None)  # exit training loop early for debugging
    config.eval_steps = getattr(config.trainer, 'eval_steps', None)  # how many gradient updates before evaluating
    return config


def load_model_from_checkpoint(attn_mlp_checkpoint_path: str,
finetune_checkpoint_path: str,
config_dir: str = 'configs',
print_model: bool = False,
debug: bool = False,
huggingface_token: str = None):
    rank = 0
    # Recover config names from the checkpoint path fields (-d=..., -m=..., -f=...);
    # fall back to the finetune checkpoint if no attention checkpoint is given.
    try:
        model_config = attn_mlp_checkpoint_path.split('-m=')[-1].split('-f=')[0]
        distill_config = attn_mlp_checkpoint_path.split('-d=')[-1].split('-m=')[0]
    except AttributeError:  # attn_mlp_checkpoint_path is None
        model_config = finetune_checkpoint_path.split('-m=')[-1].split('-f=')[0]
        distill_config = None

model_config = join(config_dir, 'model', f'{model_config}.yaml')
model_config = OmegaConf.load(model_config)

if distill_config is not None:
distill_config = join(config_dir, 'experiment', f'{distill_config}.yaml')
distill_config = OmegaConf.load(distill_config)
else:
distill_config = {}

finetune_config = finetune_checkpoint_path.split('-f=')[-1].split('-')[0]
finetune_config = join(config_dir, 'experiment', f'{finetune_config}.yaml')
finetune_config = OmegaConf.load(finetune_config)

# Load initial model
model_loader = get_pretrained_loader(**model_config.model,
huggingface_token=huggingface_token)
tokenizer = model_loader.load_tokenizer()
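    # Llama tokenizers ship without a pad token; reuse EOS and left-pad so
    # batched prompts end at the same position for generation.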
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'left'

model = model_loader.load(model_config['attention']['attention_type'])
if debug:
print_header('Pretrained Model')
print(model)

# Add subquadratic attentions
model, distill_peft_config = load_and_convert_attns(model, model_config,
attention_type=None, # in model_config
checkpoint_path=attn_mlp_checkpoint_path,
print_model=debug,
merge_loras=False,
peft_gradient_checkpointing=False,
train_attention=False)

# Add PEFT parameters
model, ft_peft_config = load_and_convert_finetune(model, finetune_config,
checkpoint_path=finetune_checkpoint_path,
print_model=debug,
merge_loras=False,
peft_gradient_checkpointing=False)
if print_model:
print_header('*** Model after checkpoint load ***')
print(model)

return model, model_config, tokenizer


def get_model_name(attn_mlp_checkpoint_path: str, finetune_checkpoint_path: str,
model_config: str = None):
    model_name = '😺 ' if attn_mlp_checkpoint_path is not None else ''
    if 'llama3_8b_' in finetune_checkpoint_path:
        model_name += 'Llama-3-8B'
    elif 'llama3_1_8b_' in finetune_checkpoint_path:
        model_name += 'Llama-3.1-8B'
    elif 'llama2_7b_' in finetune_checkpoint_path:
        model_name += 'Llama-2-7B'
    elif 'mistral_7b_' in finetune_checkpoint_path:
        model_name += 'Mistral-7B'

    if attn_mlp_checkpoint_path is not None:
        model_name += '-LoLCATs'

    if 'alpaca_clean' in finetune_checkpoint_path:
        model_name += '-Alpaca'

    elif model_config is not None:
        if 'llama3_8b_' in model_config:
            model_name += 'Llama-3-8B'
        elif 'llama2_7b_' in model_config:
            model_name += 'Llama-2-7B'
        elif 'mistral_7b_' in model_config:
            model_name += 'Mistral-7B'

    return model_name


def main():
args = get_args()
seed_everything(args.seed)
    model, model_config, tokenizer = load_model_from_checkpoint(
        args.attn_mlp_checkpoint_path, args.finetune_checkpoint_path,
        config_dir=args.config_dir, print_model=args.print_model, debug=args.debug,
    )
model.eval()
input_len = len(tokenizer(system_prompt)['input_ids'])

model_name = get_model_name(args.attn_mlp_checkpoint_path,
args.finetune_checkpoint_path,
model_config)
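    # Simple chat loop: each message is duplicated `num_generations` times and
    # the copies are generated (and streamed) in parallel.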
while True:
print(f'\n>> Generating {args.num_generations} responses in parallel')
prompt = input(f'>> Message {model_name} (or cmd-c to quit)... ')
        all_prompts = [system_prompt.format(prompt=prompt)] * args.num_generations

        if args.num_generations == 1:
            # TextStreamer forwards extra kwargs to tokenizer.decode()
            streamer = TextStreamer(tokenizer, skip_prompt=True,
                                    skip_special_tokens=True)
        else:
            streamer = BatchTextIteratorStreamer(tokenizer=tokenizer,
                                                 batch_size=args.num_generations,
                                                 skip_prompt=True)

with torch.no_grad():
model_input = tokenizer(all_prompts, return_tensors="pt").to(model.device)

if args.benchmark:
torch.cuda.synchronize()
start_time = time.time()
model_output = model.generate(**model_input, use_cache=True,
max_new_tokens=args.max_new_tokens,
do_sample=True,
top_k=args.top_k,
top_p=args.top_p,
num_return_sequences=1,
pad_token_id=tokenizer.eos_token_id,
streamer=streamer)
if args.benchmark:
torch.cuda.synchronize()
elapsed = time.time() - start_time
total_tokens = (model_output != tokenizer.eos_token_id).sum().item()
print_header('(Coarse) stats for nerds')
print(f'├── Model data type: {model.dtype}')
print(f'├── Time of longest response: {elapsed:.3f} sec')
print(f'├── Total tokens processed + generated: {total_tokens}')
print(f'├── Throughput (lagged by last response): {total_tokens / elapsed:.3f} tokens/sec')


if __name__ == '__main__':
main()
