Feat: rotation-based equalization
Giuseppe5 committed Nov 6, 2024
1 parent 52e0059 commit 5850118
Showing 8 changed files with 96,809 additions and 97 deletions.
481 changes: 401 additions & 80 deletions src/brevitas/graph/equalize.py

Large diffs are not rendered by default.

96,198 changes: 96,198 additions & 0 deletions src/brevitas/graph/hadamard.py

Large diffs are not rendered by default.
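
The body of `hadamard.py` is not shown, but from its use in `equalized_layer.py` below it exposes `get_hadK`, `matmul_hadU`, and `matmul_hadU_cuda` (the bulk of the ~96k added lines is presumably precomputed Hadamard matrices for sizes that are not powers of two). As a rough illustration of the underlying operation — not the Brevitas implementation — here is a minimal normalized Walsh-Hadamard transform for power-of-two sizes:

```python
import torch

def walsh_hadamard(x: torch.Tensor) -> torch.Tensor:
    """Normalized Walsh-Hadamard transform over the last dimension (power-of-two sizes only)."""
    n = x.shape[-1]
    assert n & (n - 1) == 0, "last dimension must be a power of two"
    h = 1
    while h < n:
        # Butterfly step: pair up blocks of width h and form their sums and differences.
        x = x.reshape(*x.shape[:-1], n // (2 * h), 2, h)
        a, b = x[..., 0, :], x[..., 1, :]
        x = torch.stack((a + b, a - b), dim=-2).reshape(*x.shape[:-3], n)
        h *= 2
    return x / n ** 0.5
```

Because the normalized Hadamard matrix is orthogonal, rotating an activation with it and rotating the weights of the following layer with the matching matrix leaves the layer output unchanged, which is what rotation-based equalization relies on while reshaping the activation distribution.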

50 changes: 50 additions & 0 deletions src/brevitas/nn/equalized_layer.py
@@ -2,8 +2,16 @@

import torch

from brevitas.graph.hadamard import get_hadK
from brevitas.graph.hadamard import matmul_hadU
from brevitas.graph.hadamard import matmul_hadU_cuda
from brevitas.nn.quant_mha import QuantMultiheadAttention

try:
import fast_hadamard_transform
except ImportError:
fast_hadamard_transform = None

INPUT_NAMES = ['input', 'inp', 'query', 'x', 'hidden_states']


@@ -41,3 +49,45 @@ def forward(self, *args, **kwargs):
# We convert everything to args so that hooks can work correctly
out = self.layer(*kwargs.values())
return out


class RotatedModule(torch.nn.Module):

def __init__(self, layer, had_mat=None, k=None) -> None:
super().__init__()
if had_mat is not None:
self.had_mat = torch.nn.Parameter(had_mat).cpu()
else:
self.had_mat = None
self.layer = layer
self.k = k

def forward(self, inp, **kwargs):
is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None
if is_cuda and fast_hadamard_transform is not None:
if self.had_mat is None or self.k is None:
had_K, K = get_hadK(inp.shape[-1])
else:
had_K = self.had_mat
K = self.k
inp = matmul_hadU_cuda(inp, had_K, K)
else:
inp = matmul_hadU(inp)
o = self.layer(inp)

return o


def functional_rotate_input(inp, transpose=False):
is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None
if transpose:
inp = inp.t()
if is_cuda and fast_hadamard_transform is not None:
had_K, K = get_hadK(inp.shape[-1])
inp = matmul_hadU_cuda(inp, had_K, K)
else:
inp = matmul_hadU(inp)

if transpose:
inp = inp.t()
return inp
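
A brief usage sketch for the new wrapper (sizes are illustrative; on CPU, or when `fast_hadamard_transform` is missing, the forward pass falls back to `matmul_hadU`):

```python
import torch

from brevitas.nn.equalized_layer import RotatedModule
from brevitas.nn.equalized_layer import functional_rotate_input

linear = torch.nn.Linear(64, 64, bias=False)

# No fused matrix is passed, so the input is rotated online with a Hadamard transform.
rotated = RotatedModule(layer=linear)

x = torch.randn(2, 64)
y = rotated(x)                      # linear applied to the rotated input
x_rot = functional_rotate_input(x)  # the same rotation as a standalone helper
```

Note that wrapping a layer only rotates its input; for the overall output to be preserved, the weights of the wrapped layer have to be rotated with the matching matrix, which is presumably handled by the `GraphRotationEqualization` pass added in `equalize.py`.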
13 changes: 12 additions & 1 deletion src/brevitas_examples/llm/README.md
@@ -16,6 +16,7 @@ Set the env variable BREVITAS_JIT=1 to speed up the quantization process. Curren
```bash
usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
[--seqlen SEQLEN] [--eval] [--dataset {wikitext2,c4}]
[--gpxq-block-name GPXQ_BLOCK_NAME]
[--weight-bit-width WEIGHT_BIT_WIDTH]
[--weight-param-method {stats,mse,hqo}]
[--weight-scale-precision {float_scale,po2_scale}]
@@ -53,7 +54,10 @@ options:
--seqlen SEQLEN Sequence length. Default: 2048.
--eval Eval model PPL on the chosen Dataset.
--dataset {wikitext2,c4}
Dataset to use for quantization (default: wikitext2)
Dataset to use for quantization (default: c4)
--gpxq-block-name GPXQ_BLOCK_NAME
Block name for faster GPxQ optimization. It works only
if FX is not needed (default: None)
--weight-bit-width WEIGHT_BIT_WIDTH
Weight bit width. Default: 8.
--weight-param-method {stats,mse,hqo}
@@ -121,6 +125,7 @@ options:
--act-calibration Apply activation calibration.
--bias-corr Apply bias correction.
--ln-affine-merge Merge LN affine params.
--replace-rmsnorm     Replace HF RMSNorms with the Torch one.
--no-quantize Disable quantization.
--no-float16 Disable float16 as base datatype and switch to
float32.
@@ -129,6 +134,12 @@ options:
--weight-equalization
Apply weight equalization. Relevant to ReLU based
models (e.g. OPT).
--graph-rotation Apply graph rotation equalization
--graph-rotation-mode {had,ort}
If GraphRotation is enabled, decide how to compute the
random rotation matrix that is fully fused. Online or
partial rotation will always be Hadamard
--layerwise-rotation Apply layerwise rotation equalization
--act-equalization {None,layerwise,fx}
Apply activation equalization (SmoothQuant). Layerwise
introduces standalone mul nodes, while fx merges them
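
Putting the new options together, a plausible invocation might look like the following (the model name is illustrative, and, per the `main.py` changes below, `--graph-rotation` expects `--replace-rmsnorm` and `--ln-affine-merge` to be enabled as well):

```bash
python main.py \
  --model meta-llama/Llama-2-7b-hf \
  --replace-rmsnorm \
  --ln-affine-merge \
  --graph-rotation \
  --graph-rotation-mode had \
  --weight-bit-width 4 \
  --eval
```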
28 changes: 20 additions & 8 deletions src/brevitas_examples/llm/llm_quant/ln_affine_merge.py
@@ -6,10 +6,23 @@
import torch
from torch import nn

from brevitas.graph.equalize import _is_reshaping_op
from brevitas.graph.base import ModuleToModuleByClass
from brevitas.graph.equalize import _is_scale_invariant_module
from brevitas.graph.equalize import MergeLnAffine
from brevitas.graph.utils import get_module
from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32


def replace_rmsnorm_with_torch(model, config):
set_of_layers = set(type(x) for x in model.modules() if 'RMS' in type(x).__name__)
rewriters = [
ModuleToModuleByClass(
rms_cls, torch.nn.RMSNorm, normalized_shape=config.hidden_size, eps=config.rms_norm_eps)
for rms_cls in set_of_layers]
dtype = next(iter(model.parameters())).dtype
for r in rewriters:
model = r.apply(model)
model = model.to(dtype)
return model


def replace_bias(next_module, new_bias):
@@ -49,7 +62,7 @@ def merge_layernorm_affine_params(graph_model):
module = get_module(graph_model, node.target)
if isinstance(module, nn.LayerNorm):
for next in node.users:
while (_is_reshaping_op(next) or _is_scale_invariant_module(graph_model, next)):
while (_is_scale_invariant_module(graph_model, next)):
next = node.next
if next.op == 'call_module':
next_module = get_module(graph_model, next.target)
@@ -83,8 +96,7 @@ def merge_layernorm_affine_params(graph_model):


@torch.no_grad()
def apply_layernorm_affine_merge(graph_model, dtype):
# We can't do fp16 tracing on CPU as many kernels are not implemented
# So we have to cast to fp32 first, trace, apply merging, and then cast back
with cast_to_float32(graph_model, dtype):
merge_layernorm_affine_params(graph_model)
def apply_layernorm_affine_merge(graph_model):
eq = MergeLnAffine()
graph_model = eq.apply(graph_model)
return graph_model
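
To make the intent of `replace_rmsnorm_with_torch` concrete, here is a hedged, self-contained sketch with a toy stand-in for an HF model: the helper matches modules whose class name contains "RMS" and rebuilds them as `torch.nn.RMSNorm` from the config's `hidden_size` and `rms_norm_eps`. It assumes PyTorch >= 2.4 (for `torch.nn.RMSNorm`) and that `ModuleToModuleByClass` rewrites plain, untraced modules, as it is used here.

```python
import torch
from types import SimpleNamespace

from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch


class ToyRMSNorm(torch.nn.Module):
    """Minimal RMSNorm in the style of HF implementations (class name contains 'RMS')."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(dim))
        self.variance_epsilon = eps

    def forward(self, x):
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.variance_epsilon)


model = torch.nn.Sequential(torch.nn.Linear(8, 8), ToyRMSNorm(8))
config = SimpleNamespace(hidden_size=8, rms_norm_eps=1e-6)  # mimics the HF config fields used above
model = replace_rmsnorm_with_torch(model, config)
print(type(model[1]))  # expected: torch.nn.RMSNorm
```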
36 changes: 34 additions & 2 deletions src/brevitas_examples/llm/main.py
@@ -14,6 +14,8 @@

from brevitas.export import export_torch_qcdq
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas.graph.equalize import GraphRotationEqualization
from brevitas.graph.equalize import LayerwiseActivationRotation
from brevitas.graph.quantize import layerwise_quantize
from brevitas_examples.common.accelerate_utils.accelerate import offload_model
from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks
@@ -30,6 +32,7 @@
from brevitas_examples.llm.llm_quant.gpxq import apply_gpfq
from brevitas_examples.llm.llm_quant.gpxq import apply_gptq
from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge
from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch
from brevitas_examples.llm.llm_quant.prepare_for_quantize import add_zero_bias_to_linear
from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers
from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32
@@ -196,18 +199,34 @@ def main(args):
remove_hooks(model)
print(f"Float perplexity ({args.dataset}): {float_ppl:.3f}")

if args.replace_rmsnorm:
model = replace_rmsnorm_with_torch(model, model.config)

if require_fx:
model = get_fx(model)
with torch.no_grad():
model, guards = torch._dynamo.export(model)(**calibration_loader[0])
# Blockwise optimization does not work with FX at the moment
args.gpxq_block_name = None

# Apply LN affine merging before inserting MHA layers
# since currently there is support only for merging into Linear
if args.ln_affine_merge:
print("Apply LN affine merge...")
apply_layernorm_affine_merge(model, dtype)
apply_layernorm_affine_merge(model)
print("LN affine merge applied.")

if args.graph_rotation:
assert args.ln_affine_merge
assert args.replace_rmsnorm
model = offload_model(model)
eq = GraphRotationEqualization(
orphan_sink=True, full_rotation_method=args.graph_rotation_mode)
model = eq.apply(model)
remove_hooks(model)
elif args.layerwise_rotation:
eq = LayerwiseActivationRotation()
model = eq.apply(model)

# Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing
# with all the variability in HF implementations
if args.replace_mha:
@@ -497,6 +516,8 @@ def parse_args(args):
'--act-calibration', action='store_true', help='Apply activation calibration.')
parser.add_argument('--bias-corr', action='store_true', help='Apply bias correction.')
parser.add_argument('--ln-affine-merge', action='store_true', help='Merge LN affine params.')
parser.add_argument(
'--replace-rmsnorm', action='store_true', help='Replace HF RMSNorms with the Torch one.')
parser.add_argument('--no-quantize', action='store_true', help='Disable quantization.')
parser.add_argument(
'--no-float16',
@@ -510,6 +531,17 @@
'--weight-equalization',
action='store_true',
help='Apply weight equalization. Relevant to ReLU based models (e.g. OPT).')
parser.add_argument(
'--graph-rotation', action='store_true', help='Apply graph rotation equalization')
parser.add_argument(
'--graph-rotation-mode',
default='had',
choices=['had', 'ort'],
help=
'If GraphRotation is enabled, decide how to compute the random rotation matrix that is fully fused. Online or partial rotation will always be Hadamard'
)
parser.add_argument(
'--layerwise-rotation', action='store_true', help='Apply layerwise rotation equalization')
parser.add_argument(
'--act-equalization',
default=None,
37 changes: 37 additions & 0 deletions tests/brevitas/graph/equalization_fixtures.py
@@ -491,3 +491,40 @@ def forward(self, x):

toy_quant_model = fixture_union(
'toy_quant_model', list_of_quant_fixtures, ids=list_of_quant_fixtures)

## List of Rotation fixtures


@pytest_cases.fixture
def linear_rms():

class LinearRMSModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(3, 4, bias=True)
self.linear.weight.data.fill_(2.)
self.linear.bias.data.fill_(1.)
self.rms = nn.RMSNorm(4)
self.rms.weight.data = torch.randn_like(
self.rms.weight.data) # Change learned parameters
self.linear_1 = nn.Linear(4, 8, bias=False)
self.linear_1.weight.data.fill_(2.)
self.linear_2 = nn.Linear(8, 8, bias=False)

def forward(self, x):
x = self.linear(x)
x = self.rms(x)
x = self.linear_1(x)
x = self.linear_2(x) * x
x = torch.matmul(x.flatten(1), x.flatten(1).t())

return x

return LinearRMSModel


list_of_rotation_mixtures = ['linear_rms']

rotation_fixtures = fixture_union(
'rotation_fixtures', list_of_rotation_mixtures, ids=list_of_rotation_mixtures)
63 changes: 57 additions & 6 deletions tests/brevitas/graph/test_equalization.py
@@ -10,10 +10,14 @@
from brevitas.graph.equalize import _batch_norm
from brevitas.graph.equalize import _extract_regions
from brevitas.graph.equalize import _is_supported_module
from brevitas.graph.equalize import _supported_layers
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.equalize import GraphRotationEqualization
from brevitas.graph.equalize import MergeLnAffine
from brevitas.graph.standardize import DuplicateSharedStatelessModule
from brevitas.graph.standardize import TorchFunctionalToModule
from brevitas.graph.utils import get_module
from tests.marker import requires_pt_ge

from .equalization_fixtures import *

@@ -28,14 +32,14 @@ def test_resnet18_equalization():
expected_out = model(inp)

model_orig = copy.deepcopy(model)
regions = _extract_regions(model)
supported_sinks = list(_supported_layers)
supported_sinks = tuple([
x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)])
regions = _extract_regions(model, state_impl_kwargs={'supported_sinks': supported_sinks})
_ = equalize_test(
regions, merge_bias=True, bias_shrinkage='vaiq', scale_computation_type='maxabs')
out = model(inp)

# Check that equalization is not introducing FP variations
assert torch.allclose(expected_out, out, atol=ATOL)

regions = sorted(regions, key=lambda region: sorted([r for r in region.srcs_names]))
resnet_18_regions = sorted(RESNET_18_REGIONS, key=lambda region: region[0][0])
equalized_layers = set()
@@ -58,6 +62,9 @@ def test_resnet18_equalization():
orig_module = get_module(model_orig, layer)
assert not torch.allclose(eq_module.weight, orig_module.weight)

# Check that equalization is not introducing FP variations
assert torch.allclose(expected_out, out, atol=ATOL)


@pytest_cases.parametrize("merge_bias", [True, False])
def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool):
@@ -73,7 +80,10 @@ def test_equalization_torchvision_models(model_coverage: tuple, merge_bias: bool

expected_out = model(inp)

regions = _extract_regions(model)
supported_sinks = list(_supported_layers)
supported_sinks = tuple([
x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)])
regions = _extract_regions(model, state_impl_kwargs={'supported_sinks': supported_sinks})
scale_factor_regions = equalize_test(
regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs')
shape_scale_regions = [scale.shape for scale in scale_factor_regions]
@@ -126,7 +136,10 @@ def test_models(toy_model, merge_bias, request):
expected_out = model(inp)

model = symbolic_trace(model)
regions = _extract_regions(model)
supported_sinks = list(_supported_layers)
supported_sinks = tuple([
x for x in _supported_layers if x not in (torch.nn.LayerNorm, *_batch_norm)])
regions = _extract_regions(model, state_impl_kwargs={'supported_sinks': supported_sinks})
scale_factor_regions = equalize_test(
regions, merge_bias=merge_bias, bias_shrinkage='vaiq', scale_computation_type='maxabs')
shape_scale_regions = [scale.shape for scale in scale_factor_regions]
@@ -225,3 +238,41 @@ def test_act_equalization_torchvision_models(model_dict: dict, layerwise: bool):
# Check that at least one region performs "true" equalization
# If all shapes are scalar, no equalization has been performed
assert any([shape != () for shape in shape_scale_regions])


@requires_pt_ge('2.4')
@pytest_cases.parametrize('partial_had', [True, False])
def test_models(rotation_fixtures, partial_had):

in_shape = IN_SIZE_LINEAR

model_class = rotation_fixtures
model = model_class()
inp = torch.ones(in_shape)

model.eval()
penultimate_weight = model.linear_1.weight.data
last_weight = model.linear_2.weight.data
with torch.no_grad():
expected_out = model(inp)

model = symbolic_trace(model)
merge = MergeLnAffine()
model = merge.apply(model)
eq = GraphRotationEqualization(orphan_sink=partial_had)
model = eq.apply(model)

with torch.no_grad():
out = model(inp)

penultimate_weight_new = model.linear_1.weight.data

# Invariance of the output
assert torch.allclose(out, expected_out, atol=ATOL)
# Rotated weights must be different
assert not torch.allclose(penultimate_weight, penultimate_weight_new)
# Merging affine parameters of RMS
assert torch.allclose(model.rms.weight.data, torch.ones_like(model.rms.weight.data))
if partial_had:
last_weight_new = model.linear_2.layer.weight.data
assert not torch.allclose(last_weight, last_weight_new)
