diff --git a/README.md b/README.md
index 1025af1..b588f90 100644
--- a/README.md
+++ b/README.md
@@ -249,7 +249,35 @@ python distill_llama.py --model_config distill_llama3_1_8b_lk_t2r \
 
 ### Demoing linear attention 7B+ models
 
-The above scripts will save two checkpoints: (1) for the learned attention feature maps (denoted by a `_distill` suffix), (2) for the LoRA finetuning weights (denoted by a `_ft` suffix). We uploaded a couple starter checkpoints in `./checkpoints/`, where for any linearized LLM we only need to save these layers (~0.2% of a 7B LLM's parameters). We have provided [sample checkpoints on HuggingFace](https://huggingface.co/collections/hazyresearch/lolcats-670ca4341699355b61238c37).
+
+#### Trained from `distill_llama.py`
+The above scripts will save two checkpoints: (1) the learned attention feature maps (denoted by a `_distill` suffix), and (2) the LoRA finetuning weights (denoted by a `_ft` suffix).
+
+For example, the saved filepaths might look like:
+
+1. Trained linear attention feature maps:
+```
+./checkpoints/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=0-lzi=1_distill.pt
+```
+
+2. Trained attention LoRA weights:
+```
+./checkpoints/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=0-lzi=1-bs=1-gas=8-nte=2-ms=-1-se=0-re=0_ft.pt
+```
+
+To chat with these models, you can run the demo script like so (albeit with slower PyTorch implementations):
+
+```bash
+python -Wignore demo_lolcats_llm.py \
+--attn_mlp_checkpoint_path './checkpoints/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=0-lzi=1_distill.pt' \
+--finetune_checkpoint_path './checkpoints/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=0-lzi=1-bs=1-gas=8-nte=2-ms=-1-se=0-re=0_ft.pt' \
+--num_generations 1 --benchmark
+```
+
+
+#### Hugging Face checkpoints
+
+We also provide [sample checkpoints on HuggingFace](https://huggingface.co/collections/hazyresearch/lolcats-670ca4341699355b61238c37). Use the commands in `demos/demo_8b.sh` to run inference with our LoLCATS Llama 3.1 8B checkpoint, which will be downloaded from HuggingFace. The downloaded checkpoints take up less than 1GB and are inserted into your local Meta Llama 3.1 model in 16-bit precision -- please ensure you have downloaded the base model and specified the path to it in the configs used by `demos/demo_8b.sh`.
 
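Either way, the `_distill` and `_ft` files only contain the linearizing layers, so they are small enough to inspect directly. Below is a minimal sketch of checking what a saved checkpoint holds, assuming each file is a `torch.save`d state dict (possibly nested under a `model_state_dict` key); the paths are the example filepaths from the README text above:

```python
import torch

# Example paths from the README above; substitute the *_distill.pt / *_ft.pt files from your own run.
checkpoint_paths = {
    'attention feature maps (_distill)': './checkpoints/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=0-lzi=1_distill.pt',
    'attention LoRA weights (_ft)': './checkpoints/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01/dl-d=distill_alpaca_clean_xent0_mse1000_lr1e-2-m=distill_llama3_1_8b_lk_smd_wtk64_fd64_w01-f=finetune_lora_qkvo_alpaca_clean-s=0-se=0-re=0-lzi=1-bs=1-gas=8-nte=2-ms=-1-se=0-re=0_ft.pt',
}

for name, path in checkpoint_paths.items():
    ckpt = torch.load(path, map_location='cpu')
    # Unwrap a nested state dict if the training script saved one (assumption).
    state_dict = ckpt.get('model_state_dict', ckpt) if isinstance(ckpt, dict) else ckpt
    num_params = sum(p.numel() for p in state_dict.values() if torch.is_tensor(p))
    print(f'{name}: {len(state_dict)} tensors, {num_params / 1e6:.2f}M parameters')
```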
 To run the demo:
 ```bash
diff --git a/demo_lolcats_llm.py b/demo_lolcats_llm.py
new file mode 100644
index 0000000..0db489e
--- /dev/null
+++ b/demo_lolcats_llm.py
@@ -0,0 +1,345 @@
+"""
+Quick demo of linearized LLM generations
+"""
+from typing import Optional, List
+from os.path import join
+import copy  # used by get_lm_eval_lolcats_model
+import sys   # used by get_lm_eval_lolcats_model
+import time
+import argparse
+import torch
+
+from omegaconf import OmegaConf
+
+from transformers import TextStreamer, TextIteratorStreamer, AutoTokenizer
+
+from src.utils.setup import seed_everything
+from src.utils.logging import print_header
+from src.model.pretrained import get_pretrained_loader
+from src.model.load_model import load_and_convert_attns, load_and_convert_finetune
+
+
+system_prompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
+
+### Instruction:
+{prompt}
+
+### Response:
+"""
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    # Model load + setup
+    parser.add_argument("--attn_mlp_checkpoint_path", type=str, default=None)
+    parser.add_argument("--finetune_checkpoint_path", type=str, default=None)
+    parser.add_argument("--config_dir", type=str, default='configs')
+    parser.add_argument("--seed", type=int, default=42)
+
+    # Generation
+    parser.add_argument("--num_generations", type=int, default=1)
+    parser.add_argument("--top_k", type=int, default=50)
+    parser.add_argument("--top_p", type=float, default=0.95)
+    parser.add_argument("--max_new_tokens", type=int, default=1024)
+
+    # Miscellaneous
+    parser.add_argument("--benchmark", action='store_true', default=False)
+    parser.add_argument("--print_model", action='store_true', default=False)
+    parser.add_argument("--debug", action='store_true', default=False)
+    parser.add_argument("--huggingface_token", type=str, default=None)
+
+    # Alternate names for the checkpoint arguments above
+    parser.add_argument("--attn_checkpoint_path", type=str, default=None)
+    parser.add_argument("--peft_checkpoint_path", type=str, default=None)
+
+    args = parser.parse_args()
+    if args.attn_mlp_checkpoint_path is None and args.attn_checkpoint_path is not None:
+        args.attn_mlp_checkpoint_path = args.attn_checkpoint_path
+    if args.finetune_checkpoint_path is None and args.peft_checkpoint_path is not None:
+        args.finetune_checkpoint_path = args.peft_checkpoint_path
+    return args
+
+
+def get_lm_eval_lolcats_model(model_kwargs: dict, lolcats_model: bool = True):
+    lm_kwargs = copy.deepcopy(model_kwargs)
+    lm_kwargs['pretrained'] = lm_kwargs['pretrained_model_name_or_path']
+    lm_kwargs['dtype'] = str(lm_kwargs['torch_dtype']).split('.')[-1]
+    del lm_kwargs['torch_dtype']
+
+    if 'Llama' in lm_kwargs['pretrained_model_name_or_path']:  # and lolcats_model:
+        lm_kwargs['device_map'] = None
+        from lm_eval_harness.models import ShardedLolcatsLlamaForCausalLM
+        lm = ShardedLolcatsLlamaForCausalLM.create_from_arg_string(
+            '', lm_kwargs,
+        )
+    else:
+        # LM_EVALUATION_HARNESS_PATH is expected to point to a local lm-evaluation-harness checkout
+        sys.path.append(LM_EVALUATION_HARNESS_PATH)
+        from lm_eval.models import get_model
+
+        lm = get_model('hf-causal-experimental').create_from_arg_string(
+            '', lm_kwargs,
+        )
+    # model = lm.model
+    return lm
+
+
+class BatchTextIteratorStreamer(TextIteratorStreamer):
+    """
+    Copied from https://discuss.huggingface.co/t/textiteratorstreamer-compatibility-with-batch-processing/46763/2
+    """
+    def __init__(self,
+                 tokenizer: AutoTokenizer,
+                 batch_size: int,
+                 skip_prompt: bool = False,
+                 timeout: Optional[float] = None,
+                 **decode_kwargs: any):
+        super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
+        self.batch_size = batch_size
+        self.token_cache = [[] for _ in range(batch_size)]
+        self.print_len = [0 for _ in range(batch_size)]
+        self.generate_exception = None
+        self.go_up = 0 + batch_size
+        self.stop_signal = tokenizer.eos_token
+
+    def put(self, value):
+        if len(value.shape) != 2:
+            value = torch.reshape(value, (self.batch_size, value.shape[0] // self.batch_size))
+
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+            return
+
+        printable_texts = list()
+        for idx in range(self.batch_size):
+            self.token_cache[idx].extend(value[idx].tolist())
+            text = self.tokenizer.decode(self.token_cache[idx], **self.decode_kwargs)
+
+            if text.endswith("\n"):
+                printable_text = text[self.print_len[idx] :]
+                self.token_cache[idx] = []
+                self.print_len[idx] = 0
+                self.go_up += 1
+            # If the last token is a CJK character, we print the characters.
+            elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
+                printable_text = text[self.print_len[idx] :]
+                self.print_len[idx] += len(printable_text)
+            else:
+                printable_text = text[self.print_len[idx] : text.rfind(" ") + 1]
+                # printable_text = text[self.print_len[idx] : self.print_len[idx] + 1]
+                # if printable_text == '':
+                #     printable_text = self.stop_signal
+                self.print_len[idx] += len(printable_text)
+            printable_texts.append(printable_text)
+
+        self.on_finalized_text(printable_texts)
+
+    def end(self):
+        printable_texts = list()
+        for idx in range(self.batch_size):
+            if len(self.token_cache[idx]) > 0:
+                text = self.tokenizer.decode(self.token_cache[idx], **self.decode_kwargs)
+                printable_text = text[self.print_len[idx] :]
+                self.token_cache[idx] = []
+                self.print_len[idx] = 0
+            else:
+                printable_text = ""
+                # printable_text = self.stop_signal
+            printable_texts.append(printable_text)
+
+        self.next_tokens_are_prompt = True
+        self.on_finalized_text(printable_texts, stream_end=True)
+
+    def on_finalized_text(self, texts: List[str], stream_end: bool = False):
+        self.text_queue.put(texts, timeout=self.timeout)
+        if stream_end:
+            self.text_queue.put(self.stop_signal, timeout=self.timeout)
+
+        try:
+            text = [
+                ''.join([x[i] if i < len(x) else self.stop_signal
+                         for x in self.text_queue.queue])
+                for i in range(len(self.text_queue.queue[0]))
+            ]
+            # text = '\n\n'.join(self.text_queue.queue[0])
+            text = '\n------------\n'.join(text)
+            go_up = "\033[F" * self.go_up  # len(text)  # Goes up this many lines
+            # go_down = "\n" * self.go_up  # len(text)  # Goes up this many lines
+            print(f'{text}', flush=True, end="" if not stream_end else None)
+            # print(f'{go_up}{text}', end="" if not stream_end else None)
+        except Exception as e:
+            print(self.stop_signal)
+
+
+def count_params(module) -> int:
+    return sum(p.numel() for p in module.parameters())
+
+
+def setup_fsdp_config(config, args, checkpoint_name: str = 'finetune'):
+    """
+    Hacky arguments for llama-recipes training function
+    """
+    config.seed = args.seed
+    config.enable_fsdp = args.enable_fsdp
+    config.low_cpu_fsdp = args.low_cpu_fsdp
+    config.dist_checkpoint_root_folder = args.checkpoint_dir
+    config.dist_checkpoint_folder = checkpoint_name
+
+    config.model_name = args.run_name
+    config.use_peft = False  # We have custom logic for saving PEFT modules
+    config.save_model = True
+    config.run_validation = True
+    config.use_fp16 = False
+    config.save_model = True
+    config.save_optimizer = False
+    config.output_dir = args.checkpoint_dir
+    config.save_metrics = not args.no_wandb
+    config.gradient_clipping = False
+    config.gradient_clipping_threshold = 1.0
+    config.num_epochs = getattr(config.trainer, 'num_train_epochs', None)
+    config.num_train_steps = getattr(args, 'num_train_steps', None)  # exit training loop early for debugging
+    config.eval_steps = getattr(config.trainer, 'eval_steps', None)  # how many gradient updates before evaluating
+    return config
+
+
+def load_model_from_checkpoint(attn_mlp_checkpoint_path: str,
+                               finetune_checkpoint_path: str,
+                               config_dir: str = 'configs',
+                               print_model: bool = False,
+                               debug: bool = False,
+                               huggingface_token: str = None):
+    rank = 0
+    # Get configs from checkpoint paths
+    try:
+        model_config = attn_mlp_checkpoint_path.split('-m=')[-1].split('-f=')[0]
+        distill_config = attn_mlp_checkpoint_path.split('-d=')[-1].split('-m=')[0]
+    except Exception as e:
+        model_config = finetune_checkpoint_path.split('-m=')[-1].split('-f=')[0]
+        distill_config = None
+
+    model_config = join(config_dir, 'model', f'{model_config}.yaml')
+    model_config = OmegaConf.load(model_config)
+
+    if distill_config is not None:
+        distill_config = join(config_dir, 'experiment', f'{distill_config}.yaml')
+        distill_config = OmegaConf.load(distill_config)
+    else:
+        distill_config = {}
+
+    finetune_config = finetune_checkpoint_path.split('-f=')[-1].split('-')[0]
+    finetune_config = join(config_dir, 'experiment', f'{finetune_config}.yaml')
+    finetune_config = OmegaConf.load(finetune_config)
+
+    # Load initial model
+    model_loader = get_pretrained_loader(**model_config.model,
+                                         huggingface_token=huggingface_token)
+    tokenizer = model_loader.load_tokenizer()
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+    tokenizer.padding_side = 'left'
+
+    model = model_loader.load(model_config['attention']['attention_type'])
+    if debug:
+        print_header('Pretrained Model')
+        print(model)
+
+    # Add subquadratic attentions
+    model, distill_peft_config = load_and_convert_attns(model, model_config,
+                                                        attention_type=None,  # in model_config
+                                                        checkpoint_path=attn_mlp_checkpoint_path,
+                                                        print_model=debug,
+                                                        merge_loras=False,
+                                                        peft_gradient_checkpointing=False,
+                                                        train_attention=False)
+
+    # Add PEFT parameters
+    model, ft_peft_config = load_and_convert_finetune(model, finetune_config,
+                                                      checkpoint_path=finetune_checkpoint_path,
+                                                      print_model=debug,
+                                                      merge_loras=False,
+                                                      peft_gradient_checkpointing=False)
+    if print_model:
+        print_header('*** Model after checkpoint load ***')
+        print(model)
+
+    return model, model_config, tokenizer
+
+
+def get_model_name(attn_mlp_checkpoint_path: str, finetune_checkpoint_path: str,
+                   model_config: str = None):
+    model_name = '😺 ' if attn_mlp_checkpoint_path is not None else ''
+    if 'llama3_8b_' in finetune_checkpoint_path:
+        model_name += f'Llama-3-8B'
+    elif 'llama3_1_8b_' in finetune_checkpoint_path:
+        model_name += f'Llama-3.1-8B'
+    elif 'llama2_7b_' in finetune_checkpoint_path:
+        model_name += f'Llama-2-7B'
+    elif 'mistral_7b_' in finetune_checkpoint_path:
+        model_name += f'Mistral-7B'
+
+    if attn_mlp_checkpoint_path is not None:
+        model_name += f'-LoLCATs'
+
+    if 'alpaca_clean' in finetune_checkpoint_path:
+        model_name += f'-Alpaca'
+
+    elif model_config is not None:
+        if 'llama3_8b_' in model_config:
+            model_name += f'Llama-3-8B'
+        elif 'llama2_7b_' in model_config:
+            model_name += f'Llama-2-7B'
+        elif 'mistral_7b_' in model_config:
+            model_name += f'Mistral-7B'
+
+    return model_name
+
+
+def main():
+    args = get_args()
+    seed_everything(args.seed)
+    model, model_config, tokenizer = load_model_from_checkpoint(
+        args.attn_mlp_checkpoint_path, args.finetune_checkpoint_path,
+        config_dir=args.config_dir, print_model=args.print_model, debug=args.debug,
+    )
+    model.eval()
+    input_len = len(tokenizer(system_prompt)['input_ids'])
+
+    model_name = get_model_name(args.attn_mlp_checkpoint_path,
+                                args.finetune_checkpoint_path,
+                                model_config)
+    while True:
+        print(f'\n>> Generating {args.num_generations} responses in parallel')
+        prompt = input(f'>> Message {model_name} (or cmd-c to quit)... ')
+        all_prompts = [system_prompt.format(prompt=prompt)] * args.num_generations
+
+        if args.num_generations == 1:
+            streamer = TextStreamer(tokenizer, skip_prompt=True,
+                                    skip_special_tokens=True)
+        else:
+            streamer = BatchTextIteratorStreamer(tokenizer=tokenizer,
+                                                 batch_size=args.num_generations,
+                                                 skip_prompt=True,)
+
+        with torch.no_grad():
+            model_input = tokenizer(all_prompts, return_tensors="pt").to(model.device)
+
+            if args.benchmark:
+                torch.cuda.synchronize()
+                start_time = time.time()
+            model_output = model.generate(**model_input, use_cache=True,
+                                          max_new_tokens=args.max_new_tokens,
+                                          do_sample=True,
+                                          top_k=args.top_k,
+                                          top_p=args.top_p,
+                                          num_return_sequences=1,
+                                          pad_token_id=tokenizer.eos_token_id,
+                                          streamer=streamer)
+            if args.benchmark:
+                torch.cuda.synchronize()
+                elapsed = time.time() - start_time
+                total_tokens = (model_output != tokenizer.eos_token_id).sum().item()
+                print_header('(Coarse) stats for nerds')
+                print(f'├── Model data type: {model.dtype}')
+                print(f'├── Time of longest response: {elapsed:.3f} sec')
+                print(f'├── Total tokens processed + generated: {total_tokens}')
+                print(f'└── Throughput (lagged by last response): {total_tokens / elapsed:.3f} tokens/sec')
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
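For scripted (non-interactive) generation, the helpers defined in `demo_lolcats_llm.py` above can be reused directly. A minimal sketch, assuming the script is importable from the working directory; the checkpoint paths are placeholders for the `_distill.pt` / `_ft.pt` files from your run (the helper parses config names out of these filenames, so the real paths are required):

```python
# Minimal non-interactive sketch reusing the helpers from demo_lolcats_llm.py (paths are placeholders).
import torch
from demo_lolcats_llm import load_model_from_checkpoint, system_prompt

model, model_config, tokenizer = load_model_from_checkpoint(
    attn_mlp_checkpoint_path='<path to ..._distill.pt>',   # substitute a real *_distill.pt path
    finetune_checkpoint_path='<path to ..._ft.pt>',        # substitute a real *_ft.pt path
    config_dir='configs',
)
model.eval()

prompt = system_prompt.format(prompt='Explain linear attention in one paragraph.')
inputs = tokenizer([prompt], return_tensors='pt').to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, use_cache=True, max_new_tokens=256,
                            do_sample=True, top_k=50, top_p=0.95,
                            pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```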