Enable ao/_models #7

Draft · wants to merge 5 commits into base: auto_round_support-3
2 changes: 2 additions & 0 deletions torchao/_models/llama/bench_autoround.sh
@@ -0,0 +1,2 @@
# auto-round requires 200 iters at bs 1 to get good output
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround &> auto_round_logs/quant_lm_head_iters200-bs1
1 change: 1 addition & 0 deletions torchao/_models/llama/benchmarks.sh
@@ -5,6 +5,7 @@ export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround --write_result benchmark_results.txt
# in readme
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
28 changes: 26 additions & 2 deletions torchao/_models/llama/generate.py
@@ -30,7 +30,7 @@ def device_sync(device):
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from torchao._models.llama.model import Transformer, prepare_inputs_for_model
from torchao._models.llama.model import Transformer, prepare_inputs_for_model, TransformerBlock
from torchao._models.llama.tokenizer import get_tokenizer

def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
@@ -220,6 +220,30 @@ def main(
groupsize=int(quantization.split("-")[-1])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
quantize_(model, int4_weight_only(group_size=groupsize))
if "autoround" == quantization:
from torchao.prototype.autoround.autoround_demo import quantize_model_with_autoround
from torchao.prototype.autoround.core import auto_round_config
import torchao.prototype.autoround.utils as ar_utils
from transformers import AutoTokenizer
from torchao.prototype.autoround.multi_tensor import multi_tensor_config
multi_tensor_config.offload_device = "cpu"
# TODO(Yi): Load the tokenizer without relying on the HF AutoTokenizer
_tokenizer = AutoTokenizer.from_pretrained(checkpoint_path.parent)
model = model.to(multi_tensor_config.offload_device)
# TODO: Support `train_bs` values other than 1
auto_round_config.train_bs = 1
auto_round_config.iters = 200
auto_round_config.nsamples = 128
auto_round_config.quant_lm_head = True
print(auto_round_config)
model.set_caches_for_calib(max_seq_length=auto_round_config.seqlen, max_batch_size=auto_round_config.train_bs)
is_decoder = (
lambda mod, fqn: isinstance(mod, TransformerBlock) or "output" in fqn
)
quantize_model_with_autoround(model, tokenizer=_tokenizer, decoder_cls=TransformerBlock, auto_round_config=auto_round_config, device="cuda", gen_text=False, is_decoder=is_decoder)
model.clean_caches_for_calib()
model = model.to(device)

if "autoquant" == quantization:
model = autoquant(model, manual=True)

@@ -322,7 +346,7 @@ def callback(x):
tokens_generated = y.size(0) - prompt_length
tokens_sec = tokens_generated / t
aggregate_metrics['tokens_per_sec'].append(tokens_sec)
print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
print(f"Time for inference {i + 1}: {tokens_generated} tokens, {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s")
print("==========")

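Taken together, the new `autoround` branch in `generate.py` reads as the following self-contained helper. This is a minimal sketch that mirrors the hunk above; the `quantize_with_autoround` wrapper name is hypothetical and not part of the PR.

```python
from torchao._models.llama.model import TransformerBlock
from torchao.prototype.autoround.autoround_demo import quantize_model_with_autoround
from torchao.prototype.autoround.core import auto_round_config
from torchao.prototype.autoround.multi_tensor import multi_tensor_config

def quantize_with_autoround(model, tokenizer, device="cuda"):
    # Keep the full-precision model on the offload device; blocks are moved
    # to the accelerator one at a time during optimization.
    multi_tensor_config.offload_device = "cpu"
    model = model.to(multi_tensor_config.offload_device)

    auto_round_config.train_bs = 1  # only bs=1 is exercised for now
    auto_round_config.iters = 200   # 200 iters needed for good output at bs=1
    auto_round_config.nsamples = 128
    auto_round_config.quant_lm_head = True

    # Calibration needs freqs_cis and the causal mask, but not the KV cache.
    model.set_caches_for_calib(
        max_seq_length=auto_round_config.seqlen,
        max_batch_size=auto_round_config.train_bs,
    )
    # torchao's llama names its head module `output`, not `lm_head`.
    is_decoder = lambda mod, fqn: isinstance(mod, TransformerBlock) or "output" in fqn
    quantize_model_with_autoround(
        model,
        tokenizer=tokenizer,
        decoder_cls=TransformerBlock,
        auto_round_config=auto_round_config,
        device=device,
        gen_text=False,
        is_decoder=is_decoder,
    )
    model.clean_caches_for_calib()
    return model.to(device)
```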
33 changes: 31 additions & 2 deletions torchao/_models/llama/model.py
@@ -169,9 +169,35 @@ def setup_caches(self, max_batch_size, max_seq_length):
self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

def set_caches_for_calib(self, max_batch_size: int, max_seq_length: int):
# Compared with the `setup_caches` method, this skips the `KVCache` initialization
if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
return
head_dim = self.config.dim // self.config.n_head
max_seq_length = find_multiple(max_seq_length, 8)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
dtype = self.output.weight.dtype
# For quantized layers, dtype is encoded in scales
if hasattr(self.output, "scales"):
dtype = self.output.scales.dtype
elif hasattr(self.output, "scales_and_zeros"):
dtype = self.output.scales_and_zeros.dtype

self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
self.freqs_cis = self.freqs_cis.to("cuda")
self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
self.causal_mask = self.causal_mask.to("cuda")

def clean_caches_for_calib(self):
# Reset the calibration caches so that `setup_caches` can rebuild them later
self.max_batch_size = -1
self.max_seq_length = -1
self.freqs_cis: Optional[Tensor] = None
self.causal_mask: Optional[Tensor] = None

def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
assert self.freqs_cis is not None, "Caches must be initialized first"
mask = self.causal_mask[None, None, input_pos]
mask = self.causal_mask[None, input_pos]
freqs_cis = self.freqs_cis[input_pos]
x = self.tok_embeddings(idx)

@@ -244,6 +270,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona

k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)

y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
@@ -290,8 +317,10 @@ def precompute_freqs_cis(


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
bsz, seqlen, num_heads, head_dim = x.shape
xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)

freqs_cis = freqs_cis.view(bsz, xshaped.size(1), 1, xshaped.size(3), 2)
x_out2 = torch.stack(
[
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
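The `apply_rotary_emb` change swaps the hard-coded broadcast dimension `1` for `bsz`, which assumes `freqs_cis` now arrives with a per-sample batch dimension during calibration (i.e. `self.freqs_cis[input_pos]` is indexed with a batched `input_pos`). A standalone shape check under that assumption, with illustrative sizes:

```python
import torch

bsz, seqlen, num_heads, head_dim = 2, 8, 4, 16
x = torch.randn(bsz, seqlen, num_heads, head_dim)
# With a batched input_pos, freqs_cis has shape (bsz, seqlen, head_dim // 2, 2)
# instead of the unbatched (seqlen, head_dim // 2, 2).
freqs_cis = torch.randn(bsz, seqlen, head_dim // 2, 2)

xshaped = x.float().reshape(*x.shape[:-1], -1, 2)  # (bsz, seqlen, num_heads, head_dim // 2, 2)
freqs_cis = freqs_cis.view(bsz, xshaped.size(1), 1, xshaped.size(3), 2)
out_real = xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1]
assert out_real.shape == (bsz, seqlen, num_heads, head_dim // 2)
```

With `bsz == 1` the two views coincide, so single-batch decoding is unaffected; larger batches now require a batched `freqs_cis`.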
25 changes: 14 additions & 11 deletions torchao/prototype/autoround/autoround_demo.py
@@ -13,7 +13,7 @@


def quantize_model_with_autoround(
model, tokenizer, decoder_cls, auto_round_config=auto_round_config, device="cuda"
model, tokenizer, decoder_cls, auto_round_config=auto_round_config, device="cuda", gen_text=True, is_decoder=None
):
with torch.no_grad():
# 0. Get the model, tokenizer, and decoder_cls
@@ -23,12 +23,13 @@ def quantize_model_with_autoround(
# The user should provide an `is_decoder` function to identify the decoder blocks.
# It can be extended to cover other modules, such as `lm_head`, e.g.:
# is_target_module = lambda mod, fqn: isinstance(mod, decoder_cls) or "lm_head" in fqn
if auto_round_config.quant_lm_head:
is_decoder = (
lambda mod, fqn: isinstance(mod, decoder_cls) or "lm_head" in fqn
)
else:
is_decoder = lambda mod, fqn: isinstance(mod, decoder_cls)
if is_decoder is None:
if auto_round_config.quant_lm_head:
is_decoder = (
lambda mod, fqn: isinstance(mod, decoder_cls) or "lm_head" in fqn
)
else:
is_decoder = lambda mod, fqn: isinstance(mod, decoder_cls)

prepare_model_for_applying_auto_round_(model, is_decoder)

@@ -63,11 +64,13 @@ def quantize_model_with_autoround(
model, torchao.dtypes.AffineQuantizedTensor
)
print(f"Number of quantized weight: {num_quantized_weight}")


# 4(Optional). Generate text using the optimized model
ar_utils.gen_text(
model, tokenizer, "Quantized model", device="cuda", max_length=50
)
if gen_text:
ar_utils.gen_text(
model, tokenizer, "Quantized model", device="cuda", max_length=50
)
return model


@@ -114,7 +117,7 @@ def main(args):
"--iters", default=200, type=int, help="Number of iterations for optimization"
)
parser.add_argument(
"--bits", default=3, type=int, help="Number of bits for quantization"
"--bits", default=4, type=int, help="Number of bits for quantization"
)
parser.add_argument(
"--train_bs", default=4, type=int, help="Batch size for training"
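For the common HF path, leaving the new `is_decoder` parameter as `None` keeps the previous behavior: blocks are matched by `decoder_cls`, plus `lm_head` when `auto_round_config.quant_lm_head` is set. A usage sketch, assuming `transformers` is installed (the model name is illustrative):

```python
import transformers
from torchao.prototype.autoround.autoround_demo import quantize_model_with_autoround
from torchao.prototype.autoround.core import auto_round_config

model_name = "facebook/opt-125m"  # illustrative choice of a supported model
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
decoder_cls = transformers.models.opt.modeling_opt.OPTDecoderLayer

# is_decoder=None (the default) falls back to matching decoder_cls,
# plus lm_head when auto_round_config.quant_lm_head is set.
model = quantize_model_with_autoround(
    model, tokenizer, decoder_cls, auto_round_config=auto_round_config, device="cuda"
)
```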
3 changes: 2 additions & 1 deletion torchao/prototype/autoround/core.py
@@ -11,7 +11,7 @@
from torchao.prototype.autoround.multi_tensor import multi_tensor_config, MultiTensor
from torchao.quantization.quant_primitives import ZeroPointDomain
from torchao.utils import find_multiple

import logging
# TODO: remove it before merge
ar_utils.freeze_random()

@@ -156,6 +156,7 @@ def _is_observed_linear(mod: torch.nn.Module, fqn: str):
def apply_auto_round(block, grouped_args, spec, block_outputs):
# Call auto-round to run the optimization process
import auto_round


block = block.to(multi_tensor_config.accelerator_device)
global layer_idx
13 changes: 11 additions & 2 deletions torchao/prototype/autoround/utils.py
@@ -4,7 +4,7 @@
import random

import auto_round

import logging
import numpy as np
import torch

@@ -85,6 +85,12 @@ def gen_example_inputs(tokenizer, device, max_length=20):
return (input_ids,)


def auto_detect_decoder_cls(model):
# Heuristic: assume the first `nn.ModuleList` holds the decoder blocks and
# return the class of its first entry; returns None when nothing matches.
for name, module in model.named_modules():
if isinstance(module, torch.nn.ModuleList):
first_module = module[0]
return type(first_module)
return None

def get_float_model_info(model_name_or_path, torch_dtype=torch.float32):
import transformers

@@ -97,7 +103,10 @@ def get_float_model_info(model_name_or_path, torch_dtype=torch.float32):
elif "opt" in model_name_or_path:
decoder_cls = transformers.models.opt.modeling_opt.OPTDecoderLayer
else:
raise ValueError(f"Unsupported model: {model_name_or_path}")
decoder_cls = auto_detect_decoder_cls(model)
logging.warning(f"Auto-detected decoder_cls: {decoder_cls}")
if decoder_cls is None:
raise ValueError(f"Unsupported model: {model_name_or_path}")
return model, tokenizer, decoder_cls


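The new fallback in `get_float_model_info` relies on `auto_detect_decoder_cls`, which assumes the first `nn.ModuleList` in the model holds the decoder blocks. A toy sanity check of that heuristic (`Block` and `ToyModel` are illustrative names):

```python
import torch
from torchao.prototype.autoround.utils import auto_detect_decoder_cls

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # The first ModuleList encountered is taken to be the decoder stack.
        self.layers = torch.nn.ModuleList([Block() for _ in range(2)])

assert auto_detect_decoder_cls(ToyModel()) is Block
```

If a model contains no `ModuleList`, the function returns `None` and `get_float_model_info` still raises `ValueError` as before.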