diff --git a/torchao/_models/llama/bench_autoround.sh b/torchao/_models/llama/bench_autoround.sh
new file mode 100644
index 0000000000..9a1665d60e
--- /dev/null
+++ b/torchao/_models/llama/bench_autoround.sh
@@ -0,0 +1,2 @@
+# auto-round requires 200 iters at bs 1 to get good output
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround &> auto_round_logs/quant_lm_head_iters200-bs1
diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh
index 6dd9c10d94..bdac0f6bb8 100644
--- a/torchao/_models/llama/benchmarks.sh
+++ b/torchao/_models/llama/benchmarks.sh
@@ -5,6 +5,7 @@ export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoround --write_result benchmark_results.txt
 # in readme
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index bf1d870b52..501da70f7f 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -30,7 +30,7 @@ def device_sync(device):
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
-from torchao._models.llama.model import Transformer, prepare_inputs_for_model
+from torchao._models.llama.model import Transformer, prepare_inputs_for_model, TransformerBlock
 from torchao._models.llama.tokenizer import get_tokenizer
 
 def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
@@ -220,6 +220,30 @@ def main(
             groupsize=int(quantization.split("-")[-1])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
             quantize_(model, int4_weight_only(group_size=groupsize))
+        if "autoround" == quantization:
+            from torchao.prototype.autoround.autoround_demo import quantize_model_with_autoround
+            from torchao.prototype.autoround.core import auto_round_config
+            import torchao.prototype.autoround.utils as ar_utils
+            from transformers import AutoTokenizer
+            from torchao.prototype.autoround.multi_tensor import multi_tensor_config
+            multi_tensor_config.offload_device = "cpu"
+            # TODO(Yi): Load the tokenizer without the HF tokenizer
+            _tokenizer = AutoTokenizer.from_pretrained(checkpoint_path.parent)
+            model = model.to(multi_tensor_config.offload_device)
+            # TODO: Enable other `train_bs` values
+            auto_round_config.train_bs = 1
+            auto_round_config.iters = 200
+            auto_round_config.nsamples = 128
+            auto_round_config.quant_lm_head = True
+            print(auto_round_config)
+            model.set_caches_for_calib(max_seq_length=auto_round_config.seqlen, max_batch_size=auto_round_config.train_bs)
+            is_decoder = (
+                lambda mod, fqn: isinstance(mod, TransformerBlock) or "output" in fqn
+            )
+            quantize_model_with_autoround(model, tokenizer=_tokenizer, decoder_cls=TransformerBlock, auto_round_config=auto_round_config, device="cuda", gen_text=False, is_decoder=is_decoder)
+            model.clean_caches_for_calib()
+            model = model.to(device)
+
         if "autoquant" == quantization:
             model = autoquant(model, manual=True)
 
@@ -322,7 +346,7 @@ def callback(x):
             tokens_generated = y.size(0) - prompt_length
             tokens_sec = tokens_generated / t
             aggregate_metrics['tokens_per_sec'].append(tokens_sec)
-            print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
+            print(f"Time for inference {i + 1}: {tokens_generated} tokens, {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
             print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s")
     print("==========")
 
diff --git a/torchao/_models/llama/model.py b/torchao/_models/llama/model.py
index 58a1709642..2111dd453a 100644
--- a/torchao/_models/llama/model.py
+++ b/torchao/_models/llama/model.py
@@ -169,9 +169,35 @@ def setup_caches(self, max_batch_size, max_seq_length):
         self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
 
+    def set_caches_for_calib(self, max_batch_size: int, max_seq_length: int):
+        # Compared with the `setup_caches` method, this skips the `KVCache` initialization
+        if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
+            return
+        head_dim = self.config.dim // self.config.n_head
+        max_seq_length = find_multiple(max_seq_length, 8)
+        self.max_seq_length = max_seq_length
+        self.max_batch_size = max_batch_size
+        dtype = self.output.weight.dtype
+        # For quantized layers, dtype is encoded in scales
+        if hasattr(self.output, "scales"):
+            dtype = self.output.scales.dtype
+        elif hasattr(self.output, "scales_and_zeros"):
+            dtype = self.output.scales_and_zeros.dtype
+
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
+        self.freqs_cis = self.freqs_cis.to("cuda")
+        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
+        self.causal_mask = self.causal_mask.to("cuda")
+
+    def clean_caches_for_calib(self):
+        self.max_batch_size = -1
+        self.max_seq_length = -1
+        self.freqs_cis: Optional[Tensor] = None
+        self.mask_cache: Optional[Tensor] = None
+
     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
-        mask = self.causal_mask[None, None, input_pos]
+        mask = self.causal_mask[None, input_pos]
         freqs_cis = self.freqs_cis[input_pos]
         x = self.tok_embeddings(idx)
 
@@ -244,6 +270,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
 
         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+
         y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
 
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
@@ -290,8 +317,10 @@ def precompute_freqs_cis(
 
 
 def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+    bsz, seqlen, num_heads, head_dim = x.shape
     xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
-    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+
+    freqs_cis = freqs_cis.view(bsz, xshaped.size(1), 1, xshaped.size(3), 2)
     x_out2 = torch.stack(
         [
             xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
diff --git a/torchao/prototype/autoround/autoround_demo.py b/torchao/prototype/autoround/autoround_demo.py
index eb36b74a3f..f36b3ce342 100644
--- a/torchao/prototype/autoround/autoround_demo.py
+++ b/torchao/prototype/autoround/autoround_demo.py
@@ -13,7 +13,7 @@
 
 
 def quantize_model_with_autoround(
-    model, tokenizer, decoder_cls, auto_round_config=auto_round_config, device="cuda"
+    model, tokenizer, decoder_cls, auto_round_config=auto_round_config, device="cuda", gen_text=True, is_decoder=None
 ):
     with torch.no_grad():
         # 0. Get the model, tokenizer, and decoder_cls
@@ -23,12 +23,13 @@ def quantize_model_with_autoround(
         # User should provide the `is_decoder` function for identifying the decoder block
         # It can be extended to other modules, such as `lm_head`, the function like:
         # is_target_module = lambda mod, fqn: isinstance(mod, decoder_cls) or "lm_head" in fqn
-        if auto_round_config.quant_lm_head:
-            is_decoder = (
-                lambda mod, fqn: isinstance(mod, decoder_cls) or "lm_head" in fqn
-            )
-        else:
-            is_decoder = lambda mod, fqn: isinstance(mod, decoder_cls)
+        if is_decoder is None:
+            if auto_round_config.quant_lm_head:
+                is_decoder = (
+                    lambda mod, fqn: isinstance(mod, decoder_cls) or "lm_head" in fqn
+                )
+            else:
+                is_decoder = lambda mod, fqn: isinstance(mod, decoder_cls)
 
         prepare_model_for_applying_auto_round_(model, is_decoder)
 
@@ -63,11 +64,13 @@ def quantize_model_with_autoround(
         model, torchao.dtypes.AffineQuantizedTensor
     )
     print(f"Number of quantized weight: {num_quantized_weight}")
+
     # 4(Optional). Generate text using the optimized model
-    ar_utils.gen_text(
-        model, tokenizer, "Quantized model", device="cuda", max_length=50
-    )
+    if gen_text:
+        ar_utils.gen_text(
+            model, tokenizer, "Quantized model", device="cuda", max_length=50
+        )
 
     return model
 
@@ -114,7 +117,7 @@ def main(args):
         "--iters", default=200, type=int, help="Number of iterations for optimization"
     )
     parser.add_argument(
-        "--bits", default=3, type=int, help="Number of bits for quantization"
+        "--bits", default=4, type=int, help="Number of bits for quantization"
     )
     parser.add_argument(
         "--train_bs", default=4, type=int, help="Batch size for training"
diff --git a/torchao/prototype/autoround/core.py b/torchao/prototype/autoround/core.py
index 37f84dc0d5..4cde94951d 100644
--- a/torchao/prototype/autoround/core.py
+++ b/torchao/prototype/autoround/core.py
@@ -11,7 +11,7 @@
 from torchao.prototype.autoround.multi_tensor import multi_tensor_config, MultiTensor
 from torchao.quantization.quant_primitives import ZeroPointDomain
 from torchao.utils import find_multiple
-
+import logging
 
 # TODO: remove it before merge
 ar_utils.freeze_random()
@@ -156,6 +156,7 @@ def _is_observed_linear(mod: torch.nn.Module, fqn: str):
 def apply_auto_round(block, grouped_args, spec, block_outputs):
     # Call the auto-round to execute the optimization process
     import auto_round
+    block = block.to(multi_tensor_config.accelerator_device)
 
     global layer_idx
 
diff --git a/torchao/prototype/autoround/utils.py b/torchao/prototype/autoround/utils.py
index 4fbf754ec2..bf714f58f4 100644
--- a/torchao/prototype/autoround/utils.py
+++ b/torchao/prototype/autoround/utils.py
@@ -4,7 +4,7 @@
 import random
 
 import auto_round
-
+import logging
 import numpy as np
 import torch
 
@@ -85,6 +85,12 @@ def gen_example_inputs(tokenizer, device, max_length=20):
     return (input_ids,)
 
 
+def auto_detect_decoder_cls(model):
+    for name, module in model.named_modules():
+        if isinstance(module, torch.nn.ModuleList):
+            first_module = module[0]
+            return type(first_module)
+
 def get_float_model_info(model_name_or_path, torch_dtype=torch.float32):
     import transformers
 
@@ -97,7 +103,10 @@ def get_float_model_info(model_name_or_path, torch_dtype=torch.float32):
     elif "opt" in model_name_or_path:
         decoder_cls = transformers.models.opt.modeling_opt.OPTDecoderLayer
     else:
-        raise ValueError(f"Unsupported model: {model_name_or_path}")
+        decoder_cls = auto_detect_decoder_cls(model)
+        logging.warning(f"Auto-detected decoder_cls: {decoder_cls}")
+        if decoder_cls is None:
+            raise ValueError(f"Unsupported model: {model_name_or_path}")
     return model, tokenizer, decoder_cls