Skip to content

Commit

Permalink
Feat (equalize): better equalization across SDPA (#1159)
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Jan 16, 2025
1 parent 0d30ab1 commit 41ace8a
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 22 deletions.
5 changes: 5 additions & 0 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from brevitas.graph.hadamard import random_hadamard_matrix
from brevitas.graph.utils import get_module
from brevitas.graph.utils import get_node
from brevitas.nn import ScaledDotProductAttention
from brevitas.nn.equalized_layer import EqualizedModule
from brevitas.nn.equalized_layer import functional_rotate_input
from brevitas.nn.equalized_layer import INPUT_NAMES
Expand Down Expand Up @@ -1584,6 +1585,10 @@ def find_sink(node):
name_to_module={
'src0': src_module, 'sink0': sink_module})
regions.append(region)
for m in graph_module.modules():
if isinstance(m, ScaledDotProductAttention):
m.pre_process_q = functional_rotate_input
m.pre_process_k = functional_rotate_input
return regions

def apply(self,
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/graph/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, model: torch.nn.Module, quant_map: Dict, enabled: bool = True
self.enabled = enabled
for stateless_function, stateless_module in quant_map.items():
if not hasattr(model, str(stateless_function)):
setattr(model, str(stateless_function), stateless_module())
model.add_module(str(stateless_function), stateless_module())

def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/nn/equalized_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ def forward(self, inp, **kwargs):
def functional_rotate_input(inp, transpose=False):
is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None
if transpose:
inp = inp.t()
inp = inp.transpose(-2, -1)
if is_cuda and fast_hadamard_transform is not None:
had_K, K = get_hadK(inp.shape[-1])
inp = matmul_hadU_cuda(inp, had_K, K)
else:
inp = matmul_hadU(inp)

if transpose:
inp = inp.t()
inp = inp.transpose(-2, -1)
return inp
23 changes: 20 additions & 3 deletions src/brevitas/nn/quant_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
from torch.nn import Parameter
import torch.nn.functional as F

from brevitas.core.function_wrapper.misc import Identity
from brevitas.function import identity
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Uint8ActPerTensorFloat

Expand All @@ -57,6 +59,12 @@

class ScaledDotProductAttention(Module):

def __init__(self, pre_process_q=identity, pre_process_k=identity, pre_process_v=identity):
super().__init__()
self.pre_process_q = pre_process_q
self.pre_process_k = pre_process_k
self.pre_process_v = pre_process_v

def forward(
self,
query: Tensor,
Expand Down Expand Up @@ -103,9 +111,9 @@ def forward(
if enable_gqa:
kwargs["enable_gqa"] = enable_gqa
return F.scaled_dot_product_attention(
query=query,
key=key,
value=value,
query=self.pre_process_q(query),
key=self.pre_process_k(key),
value=self.pre_process_v(value),
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
Expand All @@ -116,6 +124,9 @@ class QuantScaledDotProductAttention(Module):

def __init__(
self,
pre_process_q=identity,
pre_process_k=identity,
pre_process_v=identity,
softmax_input_quant=None,
attn_output_weights_quant=Uint8ActPerTensorFloat,
q_scaled_quant=Int8ActPerTensorFloat,
Expand All @@ -125,6 +136,11 @@ def __init__(
**kwargs) -> None:
super(QuantScaledDotProductAttention, self).__init__()

self.pre_process_q = pre_process_q
self.pre_process_k = pre_process_k
self.pre_process_v = pre_process_v
print(self.pre_process_q)

def filter_kwargs(prefix):
return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}

Expand Down Expand Up @@ -196,6 +212,7 @@ def forward(
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
query, key, value = self.pre_process_q(query), self.pre_process_k(key), self.pre_process_v(value)
q_scaled = self.q_scaled_quant(query * scale_factor)
k_transpose = self.k_transposed_quant(key.transpose(-2, -1))
attn_weight = q_scaled @ k_transpose
Expand Down
8 changes: 6 additions & 2 deletions src/brevitas_examples/common/generative/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@
"""
import re

from dependencies import this
import torch
from torch import nn

from brevitas import nn as qnn
from brevitas.core.function_wrapper import CeilSte
from brevitas.core.function_wrapper import FloorSte
from brevitas.core.restrict_val import RoundSte
from brevitas.core.stats import NegativeMinOrZero
from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
from brevitas.graph.quantize import layerwise_quantize
from brevitas.quant.base import ParameterFromRuntimeZeroPoint
from brevitas.quant.experimental.float import Fp8e4m3Act
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
Expand Down Expand Up @@ -68,6 +71,7 @@
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant
from brevitas_examples.common.generative.quantizers import RuntimeDynamicStatsZeroPoint
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat

Expand Down Expand Up @@ -388,10 +392,10 @@ def generate_quantizers(
elif input_quant_granularity == 'per_group':
q_scaled_quant = sym_input_quant.let(
**{
'group_dim': 2, 'group_size': input_group_size})
'group_dim': -1, 'group_size': input_group_size})
k_transposed_quant = sym_input_quant.let(
**{
'group_dim': 1, 'group_size': input_group_size})
'group_dim': -1, 'group_size': input_group_size})
v_quant = q_scaled_quant
attn_output_weights_quant = q_scaled_quant
else:
Expand Down
33 changes: 19 additions & 14 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def set_seed(seed):
def fused_rotation_no_fx(model, calibration_loader, args):
with torch.no_grad():
new_model, guards = torch._dynamo.export(model)(**calibration_loader[0])
if hasattr(model, str(torch.nn.functional.scaled_dot_product_attention)):
m_to_add = getattr(model, str(torch.nn.functional.scaled_dot_product_attention))
new_model.add_module(str(torch.nn.functional.scaled_dot_product_attention), m_to_add)

apply_layernorm_affine_merge(new_model)
new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True)
rewriters = fix_rewriter(rewriters, model, 'weight')
Expand Down Expand Up @@ -303,20 +307,6 @@ def quantize_llm(args):
apply_layernorm_to_rmsnorm(model)
print("Layernorm To RMSNorm applied.")

if args.rotation == 'fx':
model = offload_model(model)
eq = GraphRotationEqualization(
orphan_sink=args.rotation_orphan_sink,
full_rotation_method=args.rotation_mode,
sdpa_regions=args.rotation_sdpa_regions)
model = eq.apply(model)
remove_hooks(model)
elif args.rotation == 'layerwise':
eq = LayerwiseActivationRotation()
model = eq.apply(model)
elif args.rotation == 'fused_no_fx':
fused_rotation_no_fx(model, calibration_loader, args)

# Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing
# with all the variability in HF implementations
if args.replace_mha:
Expand All @@ -333,6 +323,21 @@ def quantize_llm(args):
with torch.no_grad(), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}):
model(**calibration_loader[0])
remove_hooks(model)

if args.rotation == 'fx':
model = offload_model(model)
eq = GraphRotationEqualization(
orphan_sink=args.rotation_orphan_sink,
full_rotation_method=args.rotation_mode,
sdpa_regions=args.rotation_sdpa_regions)
model = eq.apply(model)
remove_hooks(model)
elif args.rotation == 'layerwise':
eq = LayerwiseActivationRotation()
model = eq.apply(model)
elif args.rotation == 'fused_no_fx':
fused_rotation_no_fx(model, calibration_loader, args)

if args.weight_equalization:
print("Apply weight equalization...")
# In case of float16 model, we need to offload to account for missing ops
Expand Down

0 comments on commit 41ace8a

Please sign in to comment.