Merge pull request #4 from ariG23498/aritra/gradio
[Add] Gradio Demo
simran-arora authored Dec 24, 2024
2 parents d2d47d1 + db565bd commit 375df84
Showing 2 changed files with 175 additions and 0 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -30,6 +30,10 @@ In the [lolcats-scaled branch](https://github.com/HazyResearch/lolcats/tree/lolc

## Getting started

### Use the model in a Hugging Face Space

A demo is hosted as a Hugging Face Space. You can try the model out [here](https://huggingface.co/spaces/ariG23498/lolcats).
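
You can also query the Space from Python. Below is a minimal sketch using the `gradio_client` package; it assumes the Space exposes Gradio's default `/predict` endpoint for a single-function `Interface`:

```python
# Minimal sketch: call the hosted Space programmatically.
# Assumes the default /predict endpoint of a gr.Interface app.
from gradio_client import Client

client = Client("ariG23498/lolcats")        # Space id from the link above
result = client.predict(
    "What is a linear attention model?",    # prompt forwarded to the demo
    api_name="/predict",
)
print(result)
```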

### Setup dependencies

Please see `environment.yaml` for dependencies and adjust the PyTorch CUDA version if needed. We can set them up with conda:
171 changes: 171 additions & 0 deletions app.py
@@ -0,0 +1,171 @@
import sys
sys.path.append("../")

import torch
import gradio as gr
from omegaconf import OmegaConf
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download

from src.utils.setup import seed_everything
from src.utils.logging import print_header
from src.model.pretrained import get_pretrained_loader
from src.model.load_model import load_and_convert_attns, load_and_convert_finetune

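# Build the demo model: load the pretrained base model, convert its attention
# layers using the distilled attention/MLP checkpoint, then apply the LoRA
# fine-tuning checkpoint.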
def load_model_from_checkpoint(
attn_mlp_checkpoint_path: str = None,
finetune_checkpoint_path: str = None,
model_config_path: str = None,
distill_config_path: str = None,
finetune_config_path: str = None,
config_dir: str = 'configs',
print_model: bool = False,
debug: bool = False,
huggingface_token: str = None,
use_cuda_kernels: bool = False,
use_attention: bool = False
):

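    # Paths ending in .pt are treated as local checkpoints; anything else is
    # treated as a Hugging Face Hub repo id and fetched in load_hf_weights.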
is_local = attn_mlp_checkpoint_path.endswith(".pt")

model_config = OmegaConf.load(model_config_path)
distill_config = OmegaConf.load(distill_config_path)
finetune_config = OmegaConf.load(finetune_config_path)

model_loader = get_pretrained_loader(**model_config.model,
huggingface_token=huggingface_token)
tokenizer = model_loader.load_tokenizer()
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = 'left'
if use_attention:
model = model_loader.load('softmax')
return model, model_config, tokenizer

model = model_loader.load(model_config['attention']['attention_type'])
if use_cuda_kernels:
        print('*** Using TK CUDA kernels ***')
model_config['attention']['attention_type'] = 'lolcats_llama_window_tk_gen'

if is_local:
checkpoint_path = attn_mlp_checkpoint_path
else:
checkpoint_path = None
model, distill_peft_config = load_and_convert_attns(
model, model_config,
attention_type=None,
checkpoint_path=checkpoint_path,
print_model=debug,
merge_loras=False,
peft_gradient_checkpointing=False,
train_attention=False)

if is_local:
checkpoint_path = attn_mlp_checkpoint_path
else:
checkpoint_path = None
model, ft_peft_config = load_and_convert_finetune(
model, finetune_config,
checkpoint_path=checkpoint_path,
print_model=debug,
merge_loras=False,
peft_gradient_checkpointing=False)

if not is_local:
model = load_hf_weights(
model,
attn_mlp_checkpoint_path, finetune_checkpoint_path,
filename="model.pt"
)
if use_cuda_kernels:
print('*** Using TK CUDA kernels ***')

if print_model:
print('*** Model after checkpoint load ***')
print(model)

return model, model_config, tokenizer

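# Download distill and fine-tune checkpoints from the Hugging Face Hub and load
# them into the model, re-prefixing state-dict keys if a load attempt leaves
# unexpected keys.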
def load_hf_weights(model, distill_repo_id, ft_repo_id, filename="model.pt"):
for repo_id in [distill_repo_id, ft_repo_id]:
if repo_id is None: continue

print(f"Loading weights from {repo_id}")

local_file_path = hf_hub_download(repo_id=repo_id, filename=filename)
state_dict = torch.load(local_file_path)
        if 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']
_keys = model.load_state_dict(state_dict, strict=False)
if len(_keys.unexpected_keys) > 0:
new_state_dict = {k.replace('model.', 'model.model.'): v for k, v in state_dict.items()}
_keys = model.load_state_dict(new_state_dict, strict=False)
if len(_keys.unexpected_keys) > 0:
new_state_dict = {k.replace('model.', 'base_model.model.model.'): v for k, v in state_dict.items()}
_keys = model.load_state_dict(new_state_dict, strict=False)

        if len(_keys.unexpected_keys) == 0:
            print('*** All expected keys matched successfully ***')
        else:
            print('*** Error: unexpected keys in checkpoint - please fix ***')
            print('Unexpected keys:')
            for k in _keys.unexpected_keys:
                print(k)
            sys.exit(1)

return model

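# Fixed configs and Hub checkpoints for the Llama 3.1 8B LoLCATs demo.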
def load_model_and_tokenizer():
CONFIG_DIR = 'configs' # Update to your path

model_config_path = f"{CONFIG_DIR}/model/distill_llama3_1_8b_lk_smd_wtk64_fd64_w01.yaml"
distill_config_path = f"{CONFIG_DIR}/experiment/distill_alpaca_clean_xent0_mse1000_lr1e-2.yaml"
finetune_config_path = f"{CONFIG_DIR}/experiment/finetune_lora_qkvo_alpaca_clean.yaml"
attn_mlp_checkpoint_path = 'hazyresearch/lolcats-llama-3.1-8b-distill'
finetune_checkpoint_path = 'hazyresearch/lolcats-llama-3.1-8b-ft-lora'

model, model_config, tokenizer = load_model_from_checkpoint(
attn_mlp_checkpoint_path=attn_mlp_checkpoint_path,
finetune_checkpoint_path=finetune_checkpoint_path,
model_config_path=model_config_path,
distill_config_path=distill_config_path,
finetune_config_path=finetune_config_path,
config_dir=CONFIG_DIR,
print_model=False,
debug=False,
huggingface_token=None,
use_cuda_kernels=False,
use_attention=False
)
model = model.to('cuda')
model.eval()
return model, tokenizer

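# Load the model once at startup so every Gradio request reuses it.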
model, tokenizer = load_model_and_tokenizer()

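# Gradio callback: greedily decode up to 50 new tokens for a single prompt.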
def generate_response(prompt):
all_prompts = [prompt]

with torch.no_grad():
model_input = tokenizer(all_prompts, return_tensors="pt").to(model.device)
model_output = model.generate(
**model_input, use_cache=True,
max_new_tokens=50,
do_sample=False,
top_k=1,
top_p=1.0,
num_return_sequences=1,
pad_token_id=tokenizer.eos_token_id)
generated_tokens = model_output[0]
input_len = model_input['input_ids'].shape[1]
generated_tokens = generated_tokens[input_len:]
generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

return generated_text

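# Simple text-in/text-out UI; launch() serves on http://localhost:7860 by default.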
iface = gr.Interface(fn=generate_response, inputs="text", outputs="text", title="LOLcats Model Demo")

iface.launch()
