diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index 744130664..31b2d4f72 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -34,6 +34,7 @@
 from brevitas.graph.hadamard import random_hadamard_matrix
 from brevitas.graph.utils import get_module
 from brevitas.graph.utils import get_node
+from brevitas.nn import ScaledDotProductAttention
 from brevitas.nn.equalized_layer import EqualizedModule
 from brevitas.nn.equalized_layer import functional_rotate_input
 from brevitas.nn.equalized_layer import INPUT_NAMES
@@ -1584,6 +1585,10 @@ def find_sink(node):
                 name_to_module={
                     'src0': src_module, 'sink0': sink_module})
             regions.append(region)
+        for m in graph_module.modules():
+            if isinstance(m, ScaledDotProductAttention):
+                m.pre_process_q = functional_rotate_input
+                m.pre_process_k = functional_rotate_input
         return regions
 
     def apply(self,
diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py
index 7724e8f9d..6dbdae4cb 100644
--- a/src/brevitas/graph/quantize.py
+++ b/src/brevitas/graph/quantize.py
@@ -50,7 +50,7 @@ def __init__(self, model: torch.nn.Module, quant_map: Dict, enabled: bool = True
         self.enabled = enabled
         for stateless_function, stateless_module in quant_map.items():
             if not hasattr(model, str(stateless_function)):
-                setattr(model, str(stateless_function), stateless_module())
+                model.add_module(str(stateless_function), stateless_module())
 
     def __torch_function__(self, func, types, args=(), kwargs=None):
         if kwargs is None:
diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py
index 8413a8208..e3d930a50 100644
--- a/src/brevitas/nn/equalized_layer.py
+++ b/src/brevitas/nn/equalized_layer.py
@@ -81,7 +81,7 @@ def forward(self, inp, **kwargs):
 def functional_rotate_input(inp, transpose=False):
     is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None
     if transpose:
-        inp = inp.t()
+        inp = inp.transpose(-2, -1)
     if is_cuda and fast_hadamard_transform is not None:
         had_K, K = get_hadK(inp.shape[-1])
         inp = matmul_hadU_cuda(inp, had_K, K)
@@ -89,5 +89,5 @@ def functional_rotate_input(inp, transpose=False):
         inp = matmul_hadU(inp)
 
     if transpose:
-        inp = inp.t()
+        inp = inp.transpose(-2, -1)
     return inp
diff --git a/src/brevitas/nn/quant_sdpa.py b/src/brevitas/nn/quant_sdpa.py
index 43f99e827..9f82ccc0c 100644
--- a/src/brevitas/nn/quant_sdpa.py
+++ b/src/brevitas/nn/quant_sdpa.py
@@ -49,6 +49,8 @@
 from torch.nn import Parameter
 import torch.nn.functional as F
 
+from brevitas.core.function_wrapper.misc import Identity
+from brevitas.function import identity
 from brevitas.quant.scaled_int import Int8ActPerTensorFloat
 from brevitas.quant.scaled_int import Uint8ActPerTensorFloat
 
@@ -57,6 +59,12 @@ class ScaledDotProductAttention(Module):
 
+    def __init__(self, pre_process_q=identity, pre_process_k=identity, pre_process_v=identity):
+        super().__init__()
+        self.pre_process_q = pre_process_q
+        self.pre_process_k = pre_process_k
+        self.pre_process_v = pre_process_v
+
     def forward(
             self,
             query: Tensor,
             key: Tensor,
             value: Tensor,
@@ -103,9 +111,9 @@ def forward(
         if enable_gqa:
             kwargs["enable_gqa"] = enable_gqa
         return F.scaled_dot_product_attention(
-            query=query,
-            key=key,
-            value=value,
+            query=self.pre_process_q(query),
+            key=self.pre_process_k(key),
+            value=self.pre_process_v(value),
             attn_mask=attn_mask,
             dropout_p=dropout_p,
             is_causal=is_causal,
@@ -116,6 +124,9 @@ class QuantScaledDotProductAttention(Module):
 
     def __init__(
             self,
+            pre_process_q=identity,
+            pre_process_k=identity,
+            pre_process_v=identity,
             softmax_input_quant=None,
             attn_output_weights_quant=Uint8ActPerTensorFloat,
             q_scaled_quant=Int8ActPerTensorFloat,
@@ -125,6 +136,10 @@ def __init__(
             **kwargs) -> None:
         super(QuantScaledDotProductAttention, self).__init__()
+        self.pre_process_q = pre_process_q
+        self.pre_process_k = pre_process_k
+        self.pre_process_v = pre_process_v
+
         def filter_kwargs(prefix):
             return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
 
@@ -196,6 +211,7 @@ def forward(
                 attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
             else:
                 attn_bias += attn_mask
+        query, key, value = self.pre_process_q(query), self.pre_process_k(key), self.pre_process_v(value)
         q_scaled = self.q_scaled_quant(query * scale_factor)
         k_transpose = self.k_transposed_quant(key.transpose(-2, -1))
         attn_weight = q_scaled @ k_transpose
diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py
index 778955285..d845f58a6 100644
--- a/src/brevitas_examples/common/generative/quantize.py
+++ b/src/brevitas_examples/common/generative/quantize.py
@@ -4,6 +4,7 @@
 """
 
 import re
 
+from dependencies import this
 import torch
 from torch import nn
 
@@ -11,8 +12,10 @@
 from brevitas.core.function_wrapper import CeilSte
 from brevitas.core.function_wrapper import FloorSte
 from brevitas.core.restrict_val import RoundSte
+from brevitas.core.stats import NegativeMinOrZero
 from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
 from brevitas.graph.quantize import layerwise_quantize
+from brevitas.quant.base import ParameterFromRuntimeZeroPoint
 from brevitas.quant.experimental.float import Fp8e4m3Act
 from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
 from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
@@ -68,6 +71,7 @@
 from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFloat
 from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
 from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant
+from brevitas_examples.common.generative.quantizers import RuntimeDynamicStatsZeroPoint
 from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerRowFloat
 from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat
 
@@ -388,10 +392,10 @@ def generate_quantizers(
         elif input_quant_granularity == 'per_group':
             q_scaled_quant = sym_input_quant.let(
                 **{
-                    'group_dim': 2, 'group_size': input_group_size})
+                    'group_dim': -1, 'group_size': input_group_size})
             k_transposed_quant = sym_input_quant.let(
                 **{
-                    'group_dim': 1, 'group_size': input_group_size})
+                    'group_dim': -1, 'group_size': input_group_size})
             v_quant = q_scaled_quant
             attn_output_weights_quant = q_scaled_quant
         else:
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 98c37027c..8c4fe1968 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -84,6 +84,10 @@ def set_seed(seed):
 def fused_rotation_no_fx(model, calibration_loader, args):
     with torch.no_grad():
         new_model, guards = torch._dynamo.export(model)(**calibration_loader[0])
+    if hasattr(model, str(torch.nn.functional.scaled_dot_product_attention)):
+        m_to_add = getattr(model, str(torch.nn.functional.scaled_dot_product_attention))
+        new_model.add_module(str(torch.nn.functional.scaled_dot_product_attention), m_to_add)
+
     apply_layernorm_affine_merge(new_model)
     new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True)
     rewriters = fix_rewriter(rewriters, model, 'weight')
@@ -303,20 +307,6 @@ def quantize_llm(args):
         apply_layernorm_to_rmsnorm(model)
         print("Layernorm To RMSNorm applied.")
 
-    if args.rotation == 'fx':
-        model = offload_model(model)
-        eq = GraphRotationEqualization(
-            orphan_sink=args.rotation_orphan_sink,
-            full_rotation_method=args.rotation_mode,
-            sdpa_regions=args.rotation_sdpa_regions)
-        model = eq.apply(model)
-        remove_hooks(model)
-    elif args.rotation == 'layerwise':
-        eq = LayerwiseActivationRotation()
-        model = eq.apply(model)
-    elif args.rotation == 'fused_no_fx':
-        fused_rotation_no_fx(model, calibration_loader, args)
-
     # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing
     # with all the variability in HF implementations
     if args.replace_mha:
@@ -333,6 +323,21 @@
         with torch.no_grad(), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}):
             model(**calibration_loader[0])
         remove_hooks(model)
+
+    if args.rotation == 'fx':
+        model = offload_model(model)
+        eq = GraphRotationEqualization(
+            orphan_sink=args.rotation_orphan_sink,
+            full_rotation_method=args.rotation_mode,
+            sdpa_regions=args.rotation_sdpa_regions)
+        model = eq.apply(model)
+        remove_hooks(model)
+    elif args.rotation == 'layerwise':
+        eq = LayerwiseActivationRotation()
+        model = eq.apply(model)
+    elif args.rotation == 'fused_no_fx':
+        fused_rotation_no_fx(model, calibration_loader, args)
+
     if args.weight_equalization:
         print("Apply weight equalization...")
         # In case of float16 model, we need to offload to account for missing ops
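
Note (illustrative, not part of the patch): a minimal sketch of how the new pre_process hooks on ScaledDotProductAttention compose with functional_rotate_input, mirroring what GraphRotationEqualization now assigns in equalize.py. The tensor shapes and head dimension below are assumptions for the example only.

import torch

from brevitas.nn import ScaledDotProductAttention
from brevitas.nn.equalized_layer import functional_rotate_input

# Rotate q and k with the same (orthonormal) Hadamard transform before SDPA;
# q @ k^T is preserved up to numerical error, while the activation distributions
# become easier to quantize.
sdpa = ScaledDotProductAttention(
    pre_process_q=functional_rotate_input, pre_process_k=functional_rotate_input)

q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, head_dim); head_dim=64 is an assumption
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)
out = sdpa(q, k, v)  # q and k pass through functional_rotate_input, v is untouched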