From c6e925d5ed32531c24cabea63ccef8dacbaaf96b Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 28 Nov 2024 15:17:35 +0000 Subject: [PATCH 01/12] Cailey SGD --- src/brevitas/optim/sgdg.py | 197 ++++++++++++++++++++++++ tests/brevitas/optim/test_cailey_sgd.py | 128 +++++++++++++++ 2 files changed, 325 insertions(+) create mode 100644 src/brevitas/optim/sgdg.py create mode 100644 tests/brevitas/optim/test_cailey_sgd.py diff --git a/src/brevitas/optim/sgdg.py b/src/brevitas/optim/sgdg.py new file mode 100644 index 000000000..d1e89e39b --- /dev/null +++ b/src/brevitas/optim/sgdg.py @@ -0,0 +1,197 @@ +# coding=utf-8 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# This code is originally from: https://github.com/JunLi-Galios/Optimization-on-Stiefel-Manifold-via-Cayley-Transform/blob/master/stiefel_optimizer.py + +import random + +import torch +from torch.optim.optimizer import Optimizer + + +def unit(v, dim: int = 1, eps: float = 1e-8): + vnorm = norm(v, dim) + return v / vnorm.add(eps), vnorm + + +def norm(v, dim: int = 1): + assert len(v.size()) == 2 + return v.norm(p=2, dim=dim, keepdim=True) + + +def matrix_norm_one(W): + out = torch.abs(W) + out = torch.sum(out, dim=0) + out = torch.max(out) + return out + + +def Cayley_loop(X, W, tan_vec, t): # + [n, p] = X.size() + Y = X + t * tan_vec + for i in range(5): + Y = X + t * torch.matmul(W, 0.5 * (X + Y)) + + return Y.t() + + +def qr_retraction(tan_vec): # tan_vec, p-by-n, p <= n + [p, n] = tan_vec.size() + tan_vec.t_() + q, r = torch.linalg.qr(tan_vec) + d = torch.diag(r, 0) + ph = d.sign() + q *= ph.expand_as(q) + q.t_() + + return q + + +episilon = 1e-8 + + +class SGDG(Optimizer): + r"""This optimizer updates variables with two different routines + based on the boolean variable 'stiefel'. + + If stiefel is True, the variables will be updated by SGD-G proposed + as decorrelated weight matrix. + + If stiefel is False, the variables will be updated by SGD. + This routine was taken from https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py. 
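+
+    If stiefel is True, the update follows the Cayley transform: for a
+    point X on the Stiefel manifold and a skew-symmetric direction W, the
+    curve Y(t) = (I - t/2 W)^{-1} (I + t/2 W) X stays on the manifold.
+    Cayley_loop above avoids the explicit inverse by iterating the fixed
+    point Y <- X + t * W @ ((X + Y) / 2) a few times.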
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + + -- common parameters + lr (float): learning rate + momentum (float, optional): momentum factor (default: 0) + stiefel (bool, optional): whether to use SGD-G (default: False) + + -- parameters in case stiefel is False + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + dampening (float, optional): dampening for momentum (default: 0) + nesterov (bool, optional): enables Nesterov momentum (default: False) + + -- parameters in case stiefel is True + omega (float, optional): orthogonality regularization factor (default: 0) + grad_clip (float, optional): threshold for gradient norm clipping (default: None) + """ + + def __init__( + self, + params, + lr: float = 1e-3, + momentum: int = 0, + dampening: int = 0, + weight_decay: int = 0, + nesterov: bool = False, + stiefel: bool = False, + omega: int = 0, + grad_clip=None, + ) -> None: + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + stiefel=stiefel, + omega=0, + grad_clip=grad_clip, + ) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super(SGDG, self).__init__(params, defaults) + + def __setstate__(self, state) -> None: + super(SGDG, self).__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + momentum = group["momentum"] + stiefel = group["stiefel"] + + for p in group["params"]: + if p.grad is None: + continue + + unity, _ = unit(p.data.view(p.size()[0], -1)) + if stiefel and unity.size()[0] <= unity.size()[1]: + weight_decay = group["weight_decay"] + dampening = group["dampening"] + nesterov = group["nesterov"] + + rand_num = random.randint(1, 101) + if rand_num == 1: + unity = qr_retraction(unity) + + g = p.grad.data.view(p.size()[0], -1) + + lr = group["lr"] + + param_state = self.state[p] + if "momentum_buffer" not in param_state: + param_state["momentum_buffer"] = torch.zeros(g.t().size()) + if p.is_cuda: + param_state["momentum_buffer"] = param_state["momentum_buffer"].cuda() + + V = param_state["momentum_buffer"] + V = momentum * V - g.t() + MX = torch.mm(V, unity) + XMX = torch.mm(unity, MX) + XXMX = torch.mm(unity.t(), XMX) + W_hat = MX - 0.5 * XXMX + W = W_hat - W_hat.t() + t = 0.5 * 2 / (matrix_norm_one(W) + episilon) + alpha = min(t, lr) + + p_new = Cayley_loop(unity.t(), W, V, alpha) + V_new = torch.mm(W, unity.t()) # n-by-p + # check_identity(p_new.t()) + p.data.copy_(p_new.view(p.size())) + V.copy_(V_new) + + else: + d_p = p.grad.data + # defined. + try: + if weight_decay != 0: + # defined. + d_p.add_(weight_decay, p.data) + except: + pass + if momentum != 0: + param_state = self.state[p] + if "momentum_buffer" not in param_state: + buf = param_state["momentum_buffer"] = d_p.clone() + else: + buf = param_state["momentum_buffer"] + # always defined. + buf.mul_(momentum).add_(1 - dampening, d_p) + # defined. 
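+                        # NOTE: add_(Number, Tensor) is the legacy overload; on current
+                        # PyTorch this reads buf.mul_(momentum).add_(d_p, alpha=1 - dampening),
+                        # and likewise d_p.add(buf, alpha=momentum) / p.data.add_(d_p, alpha=-lr) below.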
+ if nesterov: + d_p = d_p.add(momentum, buf) + else: + d_p = buf + + p.data.add_(-group["lr"], d_p) + + return loss diff --git a/tests/brevitas/optim/test_cailey_sgd.py b/tests/brevitas/optim/test_cailey_sgd.py new file mode 100644 index 000000000..4774ec830 --- /dev/null +++ b/tests/brevitas/optim/test_cailey_sgd.py @@ -0,0 +1,128 @@ +""" +Copyright (C) 2024, Advanced Micro Devices, Inc. +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of AMD, Facebook, Deepmind Technologies, NYU, + NEC Laboratories America and IDIAP Research Institute nor the names + of its contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. 
+""" + +from copy import deepcopy +from itertools import product +import math +import sys +from typing import List, Union +import unittest + +from hypothesis import given +import numpy as np +import pytest +import pytest_cases +from pytest_cases import fixture +from scipy.stats import ortho_group +import torch +from torch.nn import Parameter +import torch.nn as nn +from torch.optim.lr_scheduler import LinearLR + +from brevitas.optim.sgdg import SGDG +from tests.conftest import SEED + +torch.manual_seed(SEED) + +from torch.testing._internal.common_optimizers import OptimizerInput + +OPTIMIZER_KWARGS = [{ + "stiefel": True}, { + "stiefel": True, "lr": 1e-2}, { + "stiefel": True, "lr": torch.tensor(0.001)}] +LR_SCHEDULER_ARGS = [ + None, + (LinearLR, { + "start_factor": 1.0, "end_factor": 0.0, "total_iters": 20}),] +DEVICES = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] +DTYPES = [torch.float32] + +device_dtype_parametrize = pytest_cases.parametrize("device, dtype", list(product(DEVICES, DTYPES))) + + +class TestCaileySGD: + + @device_dtype_parametrize + @pytest_cases.parametrize("optimizer_kwargs", OPTIMIZER_KWARGS) + @pytest_cases.parametrize("lr_scheduler_args", LR_SCHEDULER_ARGS) + def test_forloop_goes_right_direction(self, device, dtype, optimizer_kwargs, lr_scheduler_args): + optim_cls = SGDG + # Generate a random orthogonal matrix of size NxN. Columns represent orthonormal vector in R^{N} + N = 5 + P = 3 + weight_orthogonal = ortho_group(dim=N, seed=SEED).rvs() + weight_orthonormal = weight_orthogonal / np.linalg.norm(weight_orthogonal, ord=2, axis=0) + # Verify that the matrix is orthonormal + assert np.allclose(np.matmul(weight_orthonormal.T, weight_orthonormal), np.eye(N)) + # Initialize weights, the Cailey SGD optimizer expects a matrix of size PxN, given the + # condition unity.size()[0] <= unity.size()[1] + weight = Parameter( + torch.from_numpy(weight_orthonormal[:, :P].T).to(device=device, dtype=dtype)) + + optimizer = optim_cls([weight], **deepcopy(optimizer_kwargs)) + scheduler = None if lr_scheduler_args is None else lr_scheduler_args[0]( + optimizer, **lr_scheduler_args[1]) + + def closure(): + optimizer.zero_grad() + loss = (weight - torch.eye(N, P, device=device, dtype=dtype).t()).pow(2).sum() + loss.backward() + return loss + + initial_value = closure().item() + for _ in range(20): + closure() + optimizer.step() + if scheduler is not None: + scheduler.step() + + # Verify that iterates stay within the Stiefel manifold + assert torch.allclose( + weight.detach().cpu() @ weight.detach().cpu().t(), + torch.eye(P, P, device=device, dtype=dtype).detach().cpu(), + atol=1e-5, + rtol=1e-6) + + if optimizer_kwargs.get("maximize", False): + assert closure().item() > initial_value + else: + assert closure().item() < initial_value From 67087d1e1a75a1a1c7e9dcb6149c03cd31d9f54e Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 6 Dec 2024 14:01:11 +0000 Subject: [PATCH 02/12] Draft implementation unfused rotation --- src/brevitas/graph/equalize.py | 95 ++++++++- src/brevitas/nn/equalized_layer.py | 95 +++++++++ tests/brevitas/graph/equalization_fixtures.py | 27 +++ tests/brevitas/graph/test_equalization.py | 196 ++++++++++++++++++ 4 files changed, 409 insertions(+), 4 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 4e5c1a162..aad00ca99 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -3,6 +3,7 @@ from abc import ABC from abc import abstractmethod +from collections import 
defaultdict from dataclasses import dataclass from dataclasses import field from functools import partial @@ -34,6 +35,7 @@ from brevitas.nn.equalized_layer import functional_rotate_input from brevitas.nn.equalized_layer import INPUT_NAMES from brevitas.nn.equalized_layer import RotatedModule +from brevitas.nn.equalized_layer import UnfusedRotatedModule from brevitas.nn.quant_scale_bias import ScaleBias from brevitas.utils.torch_utils import KwargsForwardHook @@ -339,6 +341,8 @@ def _get_input_axis(module: nn.Module) -> Optional[int]: return 0 else: return None + elif isinstance(module, (UnfusedRotatedModule)): + return _get_input_axis(module.module) else: return None @@ -367,6 +371,8 @@ def _get_output_axis(module: nn.Module) -> Optional[int]: return 0 else: return None + elif isinstance(module, (UnfusedRotatedModule)): + return _get_output_axis(module.module) else: return None @@ -1307,7 +1313,7 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= if not insert_rotation_module and not region.is_valid: continue hidden_dim = region.max_shape_sinks - if not insert_rotation_module and full_rotation_method == 'ort': + if full_rotation_method == 'ort': rot_mat = random_orthogonal_matrix(hidden_dim) K = None rot_func = _apply_ort_device @@ -1374,6 +1380,82 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= return rewriters +@dataclass +class UnfusedRotation: + rot_mat: torch.Tensor + is_sink: bool + is_source: bool + is_orphan: bool + + +def _apply_unfused_rotate(model: nn.Module, regions: List[Region], full_rotation_method='ort'): + rewriters = [] + fused_rotated_modules = defaultdict(list) + rot_func = _apply_ort_device + + for region in regions: + insert_rotation_module = len(region.srcs) == 0 + + if not insert_rotation_module and not region.is_valid: + continue + hidden_dim = region.max_shape_sinks + + rot_mat = random_orthogonal_matrix(hidden_dim) + + for name, indexes in region.srcs.items(): + module = region.get_module_from_name(name) + + fused_rotated_modules[module].append( + UnfusedRotation( + rot_mat=rot_mat, + is_sink=False, + is_source=True, + is_orphan=False, + )) + + for name, indexes in region.sinks.items(): + module = region.get_module_from_name(name) + + if insert_rotation_module and len(region.srcs) == 0: + fused_rotated_modules[module].append( + UnfusedRotation( + rot_mat=rot_mat, + is_sink=False, + is_source=False, + is_orphan=True, + )) + else: + fused_rotated_modules[module].append( + UnfusedRotation( + rot_mat=rot_mat, + is_sink=True, + is_source=False, + is_orphan=False, + )) + + for module, rotation_modules in fused_rotated_modules.items(): + rotation_module = module + for rotation_module_dataclass in rotation_modules: + rotation_module = UnfusedRotatedModule( + module=rotation_module, + rot_func=rot_func, + rot_mat=rotation_module_dataclass.rot_mat, + _get_input_axis=_get_input_axis, + _get_output_axis=_get_output_axis, + is_source=rotation_module_dataclass.is_source, + is_sink=rotation_module_dataclass.is_sink, + is_orphan=rotation_module_dataclass.is_orphan, + ) + rewriter = ModuleInstanceToModuleInstance( + module, + rotation_module, + ) + rewriters.append(rewriter) + for r in rewriters: + model = r.apply(model) + return rewriters + + def _replace_bias(next_module, new_bias): new_bias = new_bias.view(-1) if next_module.bias is not None: @@ -1463,8 +1545,10 @@ def rotate_matmuls(self, graph_module): graph_module.recompile() graph_module.graph.lint() - def apply(self, - graph_model: GraphModule) -> 
Union[Tuple[GraphModule, List[Transform]], GraphModule]: + def apply( + self, + graph_model: GraphModule, + fuse_rotations: bool = True) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: rewriters = [] regions = _extract_regions( graph_model, @@ -1488,7 +1572,10 @@ def apply(self, if self.rotate_matmul: self.rotate_matmuls(graph_model) if len(regions) > 0: - rewriters = _apply_rotate(graph_model, regions, self.full_rotation_method) + if fuse_rotations: + rewriters = _apply_rotate(graph_model, regions, self.full_rotation_method) + else: + rewriters = _apply_unfused_rotate(graph_model, regions, self.full_rotation_method) if self.return_rewriters: return graph_model, rewriters else: diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index 8413a8208..09c6dd70e 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -1,4 +1,5 @@ from inspect import signature +from typing import Callable, Optional import torch @@ -78,6 +79,100 @@ def forward(self, inp, **kwargs): return o +class UnfusedRotatedModule(torch.nn.Module): + + def __init__( + self, + module: torch.nn.Module, + rot_func: Callable, + rot_mat: torch.Tensor, + _get_input_axis: Callable, + _get_output_axis: Callable, + is_source: bool = False, + is_sink: bool = False, + is_orphan: bool = False, + ) -> None: + super().__init__() + self.module = module + self.rot_func = rot_func + self.rot_mat = torch.nn.Parameter(rot_mat).cpu() + + # TODO: This were included to prevent circular imports. + self._get_input_axis = _get_input_axis + self._get_output_axis = _get_output_axis + + self.is_source = is_source + self.is_sink = is_sink + self.is_orphan = is_orphan + + # These properties enable propagating the fusing to the module weights + @property + def weight(self) -> Optional[torch.Tensor]: + weight = getattr(self.module, 'weight', None) + # Add rotation and let these being propagated till the parent + # unfused rotated module + if self.is_sink or self.is_orphan: + axis = self._get_input_axis(self.module) + if axis == 1: + weight = self.rot_func(weight, self.rot_mat) + elif axis == 0: + weight = self.rot_func(weight.t(), self.rot_mat).t() + else: + raise RuntimeError("Not supported yet") + + if self.is_source: + axis = self._get_output_axis(self.module) + if axis == 0: + weight = self.rot_func(weight.t(), self.rot_mat).t() + elif axis == 1: + weight = self.rot_func(weight, self.rot_mat) + else: + raise RuntimeError("Not supported yet") + + return weight + + @property + def bias(self) -> Optional[torch.Tensor]: + bias = getattr(self.module, 'bias', None) + # Propagate bias adding the rotations incrementally + if self.is_source: + if bias is not None: + bias = self.rot_func(bias, self.rot_mat) + + return bias + + def forward(self, inp, **kwargs): + # Rotated matrices + weight = self.weight.data + bias = self.bias.data if self.bias is not None else None + + # Propagate calls till getting to the original module being rotated + child_module = self.module + # Iterate until the original module is reached, keeping the rotations that need to be performed on the input + while isinstance(child_module, UnfusedRotatedModule): + child_module = child_module.module + # child_module contains the original module in the network. Before applying its forward method, we need to + # rotate the inpute appropiately + if self.is_orphan: + # Rotate the input for an orphan sink + inp = self.rot_func(inp, self.rot_mat) + # Modify the weights, and run the original model forward. 
After that, restore the previous values. + if weight is not None: + orig_weight = child_module.weight.data + child_module.weight.data = weight + if bias is not None: + orig_bias = child_module.bias.data + child_module.bias.data = bias + # Call forward of the original module + o = child_module(inp) + # Restore un-rotated weights + child_module.weight.data = orig_weight + if bias is not None: + child_module.bias.data = orig_bias + # Return rotated output + return o + + def functional_rotate_input(inp, transpose=False): is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None if transpose: diff --git a/tests/brevitas/graph/equalization_fixtures.py b/tests/brevitas/graph/equalization_fixtures.py index 035cdaadd..53b68cf1e 100644 --- a/tests/brevitas/graph/equalization_fixtures.py +++ b/tests/brevitas/graph/equalization_fixtures.py @@ -35,6 +35,8 @@ IN_SIZE_LINEAR = (1, 224, 3) IN_SIZE_CONV_SMALL = (1, 3, 32, 32) +IN_FEATURES_LINEAR = 5 + def equalize_test(regions, merge_bias, bias_shrinkage, scale_computation_type): scale_factors_regions = [] @@ -352,6 +354,24 @@ def forward(self, x): return ConvTransposeModel +@pytest_cases.fixture +def linear_model(): + + class LinearModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.linear_0 = nn.Linear(in_features=5, out_features=5) + self.linear_1 = nn.Linear(in_features=5, out_features=5) + + def forward(self, x): + x = self.linear_0(x) + x = self.linear_1(x) + return x + + return LinearModel + + list_of_fixtures = [ 'residual_model', 'srcsinkconflict_model', @@ -528,3 +548,10 @@ def forward(self, x): rotation_fixtures = fixture_union( 'rotation_fixtures', list_of_rotation_mixtures, ids=list_of_rotation_mixtures) + +list_of_rotation_unfused_mixtures = ['linear_model'] + +rotation_unfused_fixtures = fixture_union( + 'rotation_unfused_fixtures', + list_of_rotation_unfused_mixtures, + ids=list_of_rotation_unfused_mixtures) diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index afb8636e4..2c4b4ecee 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -2,21 +2,33 @@ # SPDX-License-Identifier: BSD-3-Clause import copy +from functools import partial +import itertools +from typing import List, Tuple +from unittest.mock import patch +import pytest import torch from torchvision import models from brevitas.fx import symbolic_trace +# TODO: Refactor to prevent circular import +from brevitas.graph.equalize import _apply_ort_device from brevitas.graph.equalize import _batch_norm from brevitas.graph.equalize import _extract_regions +from brevitas.graph.equalize import _get_input_axis +from brevitas.graph.equalize import _get_output_axis from brevitas.graph.equalize import _is_supported_module from brevitas.graph.equalize import _supported_layers from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import MergeLnAffine +from brevitas.graph.equalize import random_orthogonal_matrix from brevitas.graph.standardize import DuplicateSharedStatelessModule from brevitas.graph.standardize import TorchFunctionalToModule from brevitas.graph.utils import get_module +from brevitas.nn.equalized_layer import RotatedModule +from brevitas.nn.equalized_layer import UnfusedRotatedModule from tests.marker import requires_pt_ge from .equalization_fixtures import * @@ -276,3 +288,187 @@ def test_models(rotation_fixtures, partial_had): if 
partial_had: last_weight_new = model.linear_2.layer.weight.data assert not torch.allclose(last_weight, last_weight_new) + + +def _rotate_input_output(is_source: bool, is_sink: bool, is_orphan: bool) -> Tuple[bool, bool]: + # Verify that only one flag is enabled at the same time + assert sum([is_source, is_sink, is_orphan]) <= 1, "Only one flag can be enabled." + + rotate_input, rotate_output = False, False + if is_source: + rotate_output = True + if is_sink: + rotate_input = True + + return rotate_input, rotate_output + + +def _compute_rotated_ouptut_from_matrices( + module: nn.Module, inp: torch.Tensor, rot_mat_input: torch.Tensor, + rot_mat_output: torch.Tensor): + # If the node is a sink, the input is multiplied by the inverse of the rotation matrix x <- xQ^{-1} + inp = inp @ rot_mat_input.t() + # If the node is a source, the output is multiplied by the rotation matrix o <- oQ + out = module(inp) @ rot_mat_output + # Return rotated output + return out + + +# NOTE: The assumption is that only one flag can be true simultaneously +# NOTE: Orphans need to be taken care of. A module can only be orphan once. +def _generate_rotation_flags(N: int) -> List[bool]: + return [ + rotation_flags for rotation_flags in itertools.product([False, True], repeat=3 * N) if ( + all([sum(rotation_flags[i * 3:(i + 1) * 3]) <= 1 for i in range(N)]) and + # Only outermost rotation can be orphan + all([not rotation_flags[i * 3 + 2] for i in range(N - 1)]))] + + +@requires_pt_ge('2.4') +@pytest_cases.parametrize('N', [1, 2, 3], ids=lambda x: f"N={x}") +def test_composition_unfused_rotation_layer(N): + torch.manual_seed(SEED) + + for rotation_flags in _generate_rotation_flags(N): + + in_features = IN_FEATURES_LINEAR + module = nn.Linear(in_features=in_features, out_features=in_features) + + # Sample input to pass through the block + sample_input = torch.rand((1, in_features),) + + # Compose rotation modules + rotated_module = module + + # Composite rotation matrices + rot_mat_input = torch.eye(in_features) + rot_mat_output = torch.eye(in_features) + + for i in range(N): + module_rotation_flags = rotation_flags[i * 3:(i + 1) * 3] + is_source, is_sink, is_orphan = module_rotation_flags + rotate_input, rotate_output = _rotate_input_output(is_source, is_sink, is_orphan) + + # Generate a random matrix + rot_mat = random_orthogonal_matrix(in_features).to(dtype=torch.float32) + + # Aggregate rotation matrices + if rotate_input: + rot_mat_input = rot_mat_input @ rot_mat + if rotate_output: + rot_mat_output = rot_mat_output @ rot_mat + + # Compose rotation modules + rotated_module = UnfusedRotatedModule( + module=rotated_module, + rot_func=_apply_ort_device, + _get_input_axis=_get_input_axis, + _get_output_axis=_get_output_axis, + rot_mat=rot_mat, + is_source=is_source, + is_sink=is_sink, + is_orphan=is_orphan, + ) + + # Compute outputs to compare + gt_output = _compute_rotated_ouptut_from_matrices( + module, sample_input, rot_mat_input, rot_mat_output) + rot_output = rotated_module(sample_input) + + # Verify that the rotation operations were computed correctly + assert torch.allclose(gt_output, rot_output, atol=ATOL) + + +# Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/eval_utils/rotation_utils.py#L26 +def _random_orthogonal_matrix(size, generator): + """ + Generate a random orthogonal matrix of the specified size. + First, we generate a random matrix with entries from a standard distribution. + Then, we use QR decomposition to obtain an orthogonal matrix. 
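+    (QR is unique only up to the signs of diag(r), so fixing those signs in
+    the next step is what makes the sampled matrix Haar-uniform.)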
+ Finally, we multiply by a diagonal matrix with diag r to adjust the signs. + + Args: + size (int): The size of the matrix (size x size). + + Returns: + torch.Tensor: An orthogonal matrix of the specified size. + """ + torch.cuda.empty_cache() + random_matrix = torch.randn(size, size, dtype=torch.float64, generator=generator) + q, r = torch.linalg.qr(random_matrix) + q *= torch.sign(torch.diag(r)).unsqueeze(0).float() + return q + + +# This test verifies that the weights returned by the unfused rotate modules +# match those when fusing +@requires_pt_ge('2.4') +@pytest_cases.parametrize('partial_had', [False, True]) +def test_models_unfused_rotations(rotation_fixtures, partial_had): + + in_shape = IN_SIZE_LINEAR + + model_class = rotation_fixtures + model = model_class() + + model.eval() + inp = torch.rand(in_shape) + with torch.no_grad(): + expected_out = model(inp) + + model = symbolic_trace(model) + merge = MergeLnAffine() + model = merge.apply(model) + eq = GraphRotationEqualization(orphan_sink=partial_had, full_rotation_method='ort') + + # Save a copy to apply graph rotation equalization on + model_copy = copy.deepcopy(model) + + # We need to make sure that the same random matrices are being generated + generator = torch.Generator() + generator.manual_seed(SEED) + # Clone generator to make sure we can use the same rotation matrices + generator_clone = generator.clone_state() + + # We pass the generator to make sure that we can reproduce the random orthogonal matrices that are generated + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator)) as mock_ort_generator: + # Apply rotation equalization while controlling the random matrices that are generated + model = eq.apply(model) + + # Now rotate but without fusing the rotation matrices + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator_clone)) as mock_ort_generator: + # Apply rotation equalization while controlling the random matrices that are generated + model_copy = eq.apply(model_copy, fuse_rotations=False) + + with torch.no_grad(): + out = model_copy(inp) + + # Verify that the output of the model does not change after incorporating the rotations + assert torch.allclose(expected_out, out) + + # Verify that weight matrices + for model_node, model_copy_node in zip(model.graph.nodes, model_copy.graph.nodes): + if model_node.op == 'call_module': + module = get_module(model, model_node.target) + module_copy = get_module(model_copy, model_copy_node.target) + if isinstance(module, (nn.Linear, RotatedModule)): + weight = module.weight if isinstance(module, nn.Linear) else module.layer.weight + bias = module.bias if isinstance(module, nn.Linear) else module.layer.bias + weight_copy = module_copy.weight + bias_copy = module_copy.bias + assert torch.allclose(weight, weight_copy, atol=ATOL) + if bias is not None: + assert torch.allclose(bias, bias_copy, atol=ATOL) + # For a RotatedModule, corresponding to an orphan node, additional checks need to be done + if isinstance(module, RotatedModule): + # The outermost should be an orphan + rotated_module = module_copy + assert rotated_module.is_orphan, "Unfused rotated module needs to be an orphan." + # Check that the inner UnfusedRotatedModules are not orphans + while isinstance(rotated_module.module, UnfusedRotatedModule): + assert not rotated_module.module.is_orphan, "Inner unfused rotated modules cannot be orphans." 
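+                    # Walk one wrapper inward and repeat the check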
+ rotated_module = rotated_module.module + # Verify that the rotation matrices match + assert torch.allclose(module.had_mat, module_copy.rot_mat) From ff0d1332af4d0986523257a4fa021944ebc71afd Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 6 Dec 2024 17:55:22 +0000 Subject: [PATCH 03/12] Enable rotation matrix fusing --- src/brevitas/graph/equalize.py | 34 ++++++++ src/brevitas/nn/equalized_layer.py | 30 +++++-- .../llm/llm_quant/rotation_optimization.py | 81 +++++++++++++++++++ src/brevitas_examples/llm/main.py | 17 ++-- tests/brevitas/graph/test_equalization.py | 79 +++++++++++++++++- 5 files changed, 225 insertions(+), 16 deletions(-) create mode 100644 src/brevitas_examples/llm/llm_quant/rotation_optimization.py diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index aad00ca99..b0e02cc78 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1380,6 +1380,40 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= return rewriters +def _fuse_rotations(model: nn.Module): + rewriters = [] + + def _fuse_rotations_aux(module: nn.Module): + if isinstance(module, UnfusedRotatedModule): + unrotated_module = module.unrotated_module + rot_weight = module.weight.data + + # Fuse rotations with weights + unrotated_module.weight.data = rot_weight + # Fuse rotations with bias if existent + if module.bias is not None: + rot_bias = module.bias.data + unrotated_module.bias.data = rot_bias + + # Use rotated module if orphan + if module.is_orphan: + rewriter = ModuleInstanceToModuleInstance( + module, RotatedModule(had_mat=module.rot_mat, k=None, layer=unrotated_module)) + else: + rewriter = ModuleInstanceToModuleInstance(module, unrotated_module) + # Save rewriter + rewriters.append(rewriter) + else: + for child_module in module.children(): + _fuse_rotations_aux(child_module) + + # Populate rewriters + _fuse_rotations_aux(model) + # Apply rewriter to fuse the weights + for r in rewriters: + model = r.apply(model) + + @dataclass class UnfusedRotation: rot_mat: torch.Tensor diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index 09c6dd70e..899612796 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -52,6 +52,11 @@ def forward(self, *args, **kwargs): return out +def _apply_ort_device(tensor, ort, *args): + ort = ort.type_as(tensor) + return torch.matmul(tensor, ort) + + class RotatedModule(torch.nn.Module): def __init__(self, layer, had_mat=None, k=None) -> None: @@ -65,15 +70,19 @@ def __init__(self, layer, had_mat=None, k=None) -> None: def forward(self, inp, **kwargs): is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None - if is_cuda and fast_hadamard_transform is not None: - if self.had_mat is None or self.k is None: - had_K, K = get_hadK(inp.shape[-1]) - else: - had_K = self.had_mat - K = self.k - inp = matmul_hadU_cuda(inp, had_K, K) + # If k is None, we assume that an orthogonal matrix is used + if self.k is None: + inp = _apply_ort_device(inp, self.had_mat) else: - inp = matmul_hadU(inp) + if is_cuda and fast_hadamard_transform is not None: + if self.had_mat is None or self.k is None: + had_K, K = get_hadK(inp.shape[-1]) + else: + had_K = self.had_mat + K = self.k + inp = matmul_hadU_cuda(inp, had_K, K) + else: + inp = matmul_hadU(inp) o = self.layer(inp) return o @@ -141,6 +150,11 @@ def bias(self) -> Optional[torch.Tensor]: return bias + @property + def unrotated_module(self) -> torch.nn.Module: + return 
self.module.unrotated_module if isinstance( + self.module, UnfusedRotatedModule) else self.module + def forward(self, inp, **kwargs): # Rotated matrices weight = self.weight.data diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py new file mode 100644 index 000000000..809d07bbe --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py @@ -0,0 +1,81 @@ +""" +Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +""" + +from dataclasses import dataclass +from dataclasses import field +from typing import Optional, Tuple + +import torch +from torch.utils.data import Dataset +from tqdm import tqdm +import transformers +from transformers import default_data_collator +from transformers import Trainer +from transformers.tokenization_utils import PreTrainedTokenizerBase + +from brevitas.nn.equalized_layer import UnfusedRotatedModule +from brevitas.optim.sgdg import SGDG + + +@dataclass +class ModelArguments: + input_model: Optional[str] = field( + default="hf-internal-testing/tiny-random-LlamaForCausalLM", + metadata={"help": "Input model"}) + output_rotation_path: Optional[str] = field( + default="test-output", metadata={"help": "Output rotation checkpoint path"}) + optimized_rotation_path: Optional[str] = field( + default=None, metadata={"help": "Optimized rotation checkpoint path"}) + access_token: Optional[str] = field( + default="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", + metadata={"help": "Huggingface access token to access gated repo like Llama"}, + ) + + +@dataclass +class TrainingArguments(transformers.TrainingArguments): + cache_dir: Optional[str] = field(default=None) + output_dir: Optional[str] = field(default="/tmp/output/") + model_max_length: Optional[int] = field( + default=2048, + metadata={ + "help": + "Maximum sequence length. 
Sequences will be right padded (and possibly truncated)"}, + ) + + +def parse_optimization_rotation_args(unknown_args=None) -> None: + parser = transformers.HfArgumentParser(( + ModelArguments, + TrainingArguments, + )) + _, training_args = parser.parse_args_into_dataclasses(args=unknown_args) + return training_args + + +def apply_rotation_optimization( + graph_model: torch.fx.GraphModule, + tokenizer: PreTrainedTokenizerBase, + train_dataset: Dataset, + unknown_args=None) -> None: + # Get training arguments + training_args = parse_optimization_rotation_args(unknown_args) + # Collect trainable matrices + trainable_parameters = [] + for module in graph_model.modules(): + if isinstance(module, UnfusedRotatedModule): + if not module.is_sink: + trainable_parameters.append(module.rot_mat) + # Initialize optimizer + optimizer = SGDG(trainable_parameters, lr=training_args.learning_rate, stiefel=True) + trainer = Trainer( + model=graph_model, + tokenizer=tokenizer, + args=training_args, + train_dataset=train_dataset, + eval_dataset=None, + data_collator=default_data_collator, + optimizers=(optimizer, None)) + trainer.train() diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index ed2ebc2c8..bf5b9685e 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -47,6 +47,7 @@ from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch from brevitas_examples.llm.llm_quant.prepare_for_quantize import add_zero_bias_to_linear from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers +from brevitas_examples.llm.llm_quant.rotation_optimization import apply_rotation_optimization from brevitas_examples.llm.llm_quant.prepare_for_quantize import \ replace_sdpa_with_quantizable_layers from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 @@ -195,7 +196,7 @@ def validate(args): "or decreasing the sequence length (seqlen)") -def quantize_llm(args): +def quantize_llm(args, unknown_args=None): validate(args) set_seed(args.seed) if args.export_prefix is None: @@ -297,7 +298,7 @@ def quantize_llm(args): model = offload_model(model) eq = GraphRotationEqualization( orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode) - model = eq.apply(model) + model = eq.apply(model, fuse_rotations=not args.rotation_optimize) remove_hooks(model) elif args.rotation == 'layerwise': eq = LayerwiseActivationRotation() @@ -372,6 +373,7 @@ def quantize_llm(args): quantize_embedding=False) if not args.quantize_last_layer: if require_fx: + # TODO: Fix when using UnfusedRotation, layer_map[type(last_module)][1] crashes last_node = [node for node in model.graph.nodes if node.op == 'call_module'][-1] last_module = get_module(model, last_node.target) last_layer_kwargs = layer_map[type(last_module)][1] @@ -766,6 +768,11 @@ def parse_args(args, override_defaults={}): help= 'If GraphRotation is enabled, decide how to compute the random rotation matrix that is fully fused. Online or partial rotation will always be Hadamard' ) + # TODO: Make sure in argument validator that + parser.add_argument( + '--rotation-optimize', + action='store_true', + help='Whether to optimize the rotation matrices.') parser.add_argument( '--rotation-orphan-sink', action="store_true", @@ -847,13 +854,13 @@ def parse_args(args, override_defaults={}): help='A list of tasks for zero_shot evaluation. 
Default: %(default)s') parser.set_defaults(**override_defaults) - return parser.parse_args(args) + return parser.parse_known_args(args) def main(): overrides = override_defaults(sys.argv[1:]) - args = parse_args(sys.argv[1:], override_defaults=overrides) - quantize_llm(args) + args, unknown_args = parse_args(sys.argv[1:], override_defaults=overrides) + quantize_llm(args, unknown_args) if __name__ == '__main__': diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index 2c4b4ecee..d8ea0bcb2 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -16,6 +16,7 @@ from brevitas.graph.equalize import _apply_ort_device from brevitas.graph.equalize import _batch_norm from brevitas.graph.equalize import _extract_regions +from brevitas.graph.equalize import _fuse_rotations from brevitas.graph.equalize import _get_input_axis from brevitas.graph.equalize import _get_output_axis from brevitas.graph.equalize import _is_supported_module @@ -413,8 +414,6 @@ def test_models_unfused_rotations(rotation_fixtures, partial_had): model.eval() inp = torch.rand(in_shape) - with torch.no_grad(): - expected_out = model(inp) model = symbolic_trace(model) merge = MergeLnAffine() @@ -446,7 +445,8 @@ def test_models_unfused_rotations(rotation_fixtures, partial_had): out = model_copy(inp) # Verify that the output of the model does not change after incorporating the rotations - assert torch.allclose(expected_out, out) + with torch.no_grad(): + expected_out = model(inp) # Verify that weight matrices for model_node, model_copy_node in zip(model.graph.nodes, model_copy.graph.nodes): @@ -472,3 +472,76 @@ def test_models_unfused_rotations(rotation_fixtures, partial_had): rotated_module = rotated_module.module # Verify that the rotation matrices match assert torch.allclose(module.had_mat, module_copy.rot_mat) + + +# This test verifies that the weights returned by the unfused rotate modules +# match those when fusing +@requires_pt_ge('2.4') +@pytest_cases.parametrize('partial_had', [False, True]) +def test_models_fused_rotations(rotation_fixtures, partial_had): + + in_shape = IN_SIZE_LINEAR + + model_class = rotation_fixtures + model = model_class() + + model.eval() + inp = torch.rand(in_shape) + + model = symbolic_trace(model) + merge = MergeLnAffine() + model = merge.apply(model) + eq = GraphRotationEqualization(orphan_sink=partial_had, full_rotation_method='ort') + + # Save a copy to apply graph rotation equalization on + model_copy = copy.deepcopy(model) + + # We need to make sure that the same random matrices are being generated + generator = torch.Generator() + generator.manual_seed(SEED) + # Clone generator to make sure we can use the same rotation matrices + generator_clone = generator.clone_state() + + # We pass the generator to make sure that we can reproduce the random orthogonal matrices that are generated + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator)) as mock_ort_generator: + # Apply rotation equalization while controlling the random matrices that are generated + model = eq.apply(model) + + with torch.no_grad(): + expected_out = model(inp) + + # Now rotate but without fusing the rotation matrices + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator_clone)) as mock_ort_generator: + # Apply rotation equalization while controlling the random matrices that are generated + model_copy 
= eq.apply(model_copy, fuse_rotations=False) + + # Fuse the rotations and make sure the behaviour is the same + _fuse_rotations(model_copy) + + with torch.no_grad(): + out = model_copy(inp) + + # Verify that the output of the model does not change after incorporating the rotations + assert torch.allclose(expected_out, out) + + # Verify that weight matrices + for model_node, model_copy_node in zip(model.graph.nodes, model_copy.graph.nodes): + if model_node.op == 'call_module': + module = get_module(model, model_node.target) + module_copy = get_module(model_copy, model_copy_node.target) + if isinstance(module, (nn.Linear, RotatedModule)): + weight = module.weight if isinstance(module, nn.Linear) else module.layer.weight + bias = module.bias if isinstance(module, nn.Linear) else module.layer.bias + weight_copy = module_copy.weight if isinstance( + module_copy, nn.Linear) else module_copy.layer.weight + bias_copy = module_copy.bias if isinstance( + module_copy, nn.Linear) else module_copy.layer.bias + assert torch.allclose(weight, weight_copy, atol=ATOL) + if bias is not None: + assert torch.allclose(bias, bias_copy, atol=ATOL) + # For a RotatedModule, corresponding to an orphan node, additional checks need to be done + if isinstance(module, RotatedModule): + # Verify that the rotation matrices match + assert torch.allclose(module.had_mat, module_copy.had_mat) From e6bb2a76aa3711dd773523bb129626aaa0f4dd19 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 6 Dec 2024 17:56:51 +0000 Subject: [PATCH 04/12] Remove default --- src/brevitas_examples/llm/llm_quant/rotation_optimization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py index 809d07bbe..8baf07596 100644 --- a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py +++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py @@ -29,7 +29,7 @@ class ModelArguments: optimized_rotation_path: Optional[str] = field( default=None, metadata={"help": "Optimized rotation checkpoint path"}) access_token: Optional[str] = field( - default="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", + default="", metadata={"help": "Huggingface access token to access gated repo like Llama"}, ) From 694f80be487db3225bc6dbfdceb037b77c951d46 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 12 Dec 2024 14:24:33 +0000 Subject: [PATCH 05/12] New tests --- src/brevitas/graph/equalize.py | 266 +++++++++++- src/brevitas/graph/hadamard.py | 100 +++++ src/brevitas/nn/equalized_layer.py | 123 +++++- .../llm/llm_quant/rotation_optimization.py | 32 +- src/brevitas_examples/llm/main.py | 68 ++- tests/brevitas/graph/test_equalization.py | 390 ++++++++++++++---- tests/brevitas_examples/test_llm.py | 2 +- 7 files changed, 880 insertions(+), 101 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index b0e02cc78..0cd8ea47c 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -29,6 +29,7 @@ from brevitas.graph.hadamard import get_hadK from brevitas.graph.hadamard import matmul_hadU from brevitas.graph.hadamard import matmul_hadU_cuda +from brevitas.graph.hadamard import random_hadamard_matrix from brevitas.graph.utils import get_module from brevitas.graph.utils import get_node from brevitas.nn.equalized_layer import EqualizedModule @@ -341,8 +342,11 @@ def _get_input_axis(module: nn.Module) -> Optional[int]: return 0 else: return None + # 
TODO: Remove with parametrizations elif isinstance(module, (UnfusedRotatedModule)): return _get_input_axis(module.module) + elif isinstance(module, (RotatedModule,)): + return _get_input_axis(module.layer) else: return None @@ -371,8 +375,11 @@ def _get_output_axis(module: nn.Module) -> Optional[int]: return 0 else: return None + # TODO: Remove with parametrizations elif isinstance(module, (UnfusedRotatedModule)): return _get_output_axis(module.module) + elif isinstance(module, (RotatedModule,)): + return _get_output_axis(module.layer) else: return None @@ -1281,6 +1288,10 @@ def _apply_had_device(tensor, had_K, K): def _apply_ort_device(tensor, ort, *args): ort = ort.type_as(tensor) + if tensor.shape[-1] != ort.shape[0]: + tensor_shape = tensor.shape + return torch.matmul(tensor.view(-1, tensor_shape[-1] // ort.shape[0], ort.shape[0]), + ort).view(tensor_shape) return torch.matmul(tensor, ort) @@ -1313,6 +1324,7 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= if not insert_rotation_module and not region.is_valid: continue hidden_dim = region.max_shape_sinks + # TODO: Include again not insert_rotation_module if full_rotation_method == 'ort': rot_mat = random_orthogonal_matrix(hidden_dim) K = None @@ -1606,10 +1618,8 @@ def apply( if self.rotate_matmul: self.rotate_matmuls(graph_model) if len(regions) > 0: - if fuse_rotations: - rewriters = _apply_rotate(graph_model, regions, self.full_rotation_method) - else: - rewriters = _apply_unfused_rotate(graph_model, regions, self.full_rotation_method) + _apply_rotate_fn = _apply_rotate if fuse_rotations else _apply_unfused_rotate + rewriters = _apply_rotate_fn(graph_model, regions, self.full_rotation_method) if self.return_rewriters: return graph_model, rewriters else: @@ -1702,9 +1712,253 @@ def __init__(self, blacklist_layer=None): self.supported_sinks = (nn.Linear) self.blacklist_layers = blacklist_layer - def apply(self, model: nn.Module) -> nn.Module: + def apply(self, model: nn.Module, fuse_rotations: bool = True) -> nn.Module: regions: List[Region] = [] self.find_module(model, regions) if len(regions) > 0: - _apply_rotate(model, regions) + _apply_rotate_fn = _apply_rotate if fuse_rotations else _apply_unfused_rotate + _apply_rotate_fn(model, regions) return model + + +def find_missing_rotation_regions(graph_model: GraphModule, + head_dim: int, + state_impl_kwargs=None) -> List[Region]: + import re + + regions = [] + # Add R2 regions, this should be innermost + for src_name, src_module in graph_model.named_modules(): + if "attn_v_proj" in src_name: + if state_impl_kwargs is not None: + state = WalkRegionState(**state_impl_kwargs) + else: + state = WalkRegionState() + + block_number_matches_src = re.findall(r'\d+', src_name) + assert len(block_number_matches_src) == 2, "Could not identify block" + block_number_src = int(block_number_matches_src[1]) + + eq_indexes = EqualizationIndexes(0, head_dim, 0) + state.add_srcs(src_name, src_module, eq_indexes) + + # Now the corresponding sink + for sink_name, sink_module in graph_model.named_modules(): + if "attn_o_proj" in sink_name: + block_number_matches_sink = re.findall(r'\d+', sink_name) + assert len(block_number_matches_sink) == 2, "Could not identify block" + block_number_sink = int(block_number_matches_sink[1]) + # If the blocks match, we identified the region + if block_number_src == block_number_sink: + eq_indexes = EqualizationIndexes(0, head_dim, state.offset) + state.add_sinks(sink_name, sink_module, eq_indexes) + # Instantiate region and add to list + 
region = Region( + srcs=dict(sorted(state.srcs.items())), + sinks=dict(sorted(state.sinks.items())), + name_to_module=state.name_to_module, + ) + if region not in regions: + regions.append(region) + + return regions + + +def _apply_rotate_fused_rotations( + model: nn.Module, + regions: List[Region], + full_rotation_method: str = 'had', + fuse_rotations: bool = True): + rewriters = [] + # Dictionary to append the unfused rotated modules for optimization + unfused_rotated_modules = defaultdict(list) + # Dictionary to keep track of the modules that are assigned to a RotatedModule + fused_rotated_modules = {} + # List to keep track of the rotation matrices added to the + rotation_matrices = [] + + for region in regions: + insert_rotation_module = len(region.srcs) == 0 + + if not insert_rotation_module and not region.is_valid: + continue + hidden_dim = region.max_shape_sinks + if not insert_rotation_module and full_rotation_method == 'ort': + rot_mat = random_orthogonal_matrix( + hidden_dim) if fuse_rotations else torch.nn.Parameter( + random_orthogonal_matrix(hidden_dim)) + K = None + rot_func = _apply_ort_device + # Store rotation matrix for optimization + rotation_matrices.append(rot_mat) + elif not insert_rotation_module and not fuse_rotations: + # TODO: Make it more general + rot_mat = torch.nn.Parameter(random_hadamard_matrix(hidden_dim, torch.device('cpu'))) + K = None + rot_func = _apply_ort_device + # Store rotation matrix for optimization + rotation_matrices.append(rot_mat) + else: + try: + # Build hadamard rotation matrix + rot_mat, K = get_hadK(hidden_dim) + rot_func = _apply_had_device + except AssertionError as e: + print(f"Incomptible shapes {hidden_dim}") + if not insert_rotation_module: + print("Falling back to orthogonal matrices") + rot_mat = random_orthogonal_matrix(hidden_dim) + K = None + rot_func = _apply_ort_device + print("Skipping layers") + continue + + for name, indexes in region.srcs.items(): + module = region.get_module_from_name(name) + + if not insert_rotation_module and fuse_rotations: + if hasattr(module, 'allocate_params'): + module.allocate_params(module) + + axis = _get_output_axis(module) + weight = module.weight.data + + if axis == 0: + weight = rot_func(weight.t(), rot_mat, K).t() + elif axis == 1: + weight = rot_func(weight, rot_mat, K) + else: + raise RuntimeError("Not supported yet") + module.weight.data = weight + + if getattr(module, 'bias', None) is not None: + bias = module.bias.data + bias = rot_func(bias, rot_mat, K) + module.bias.data = bias + + if hasattr(module, 'offload_params'): + module.offload_params(module) + else: + unfused_rotated_modules[module].append( + UnfusedRotation( + rot_mat=rot_mat, + is_sink=False, + is_source=True, + is_orphan=False, + )) + + for name, indexes in region.sinks.items(): + module = region.get_module_from_name(name) + + if not insert_rotation_module and not fuse_rotations: + unfused_rotated_modules[module].append( + UnfusedRotation( + rot_mat=rot_mat, + is_sink=True, + is_source=False, + is_orphan=False, + )) + else: + if hasattr(module, 'allocate_params'): + module.allocate_params(module) + + axis = _get_input_axis(module) + weight = module.weight.data + + if axis == 1: + _update_weights(module, rot_func(weight, rot_mat, K), 'weight') + elif axis == 0: + _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight') + else: + raise RuntimeError("Not supported yet") + + if hasattr(module, 'offload_params'): + module.offload_params(module) + + if insert_rotation_module: + if module not in 
fused_rotated_modules: + fused_rotated_modules[module] = RotatedModule( + had_mat=rot_mat, k=K, layer=module) + else: + raise RuntimeError( + "Only one RotatedModule at most can be assigned to a module.") + # For this to work, we need to have the following hierarchy UnfusedRotatedModule -> (RotatedModule) -> Linear + for module, rotation_modules in unfused_rotated_modules.items(): + # Verify that at most one RotatedModule is available + rotation_module = module if module not in fused_rotated_modules else fused_rotated_modules[ + module] + for rotation_module_dataclass in rotation_modules: + rotation_module = UnfusedRotatedModule( + module=rotation_module, + rot_func=rot_func, + rot_mat=rotation_module_dataclass.rot_mat, + _get_input_axis=_get_input_axis, + _get_output_axis=_get_output_axis, + is_source=rotation_module_dataclass.is_source, + is_sink=rotation_module_dataclass.is_sink, + is_orphan=rotation_module_dataclass.is_orphan, + ) + # Instantiate rewriters + rewriter = ModuleInstanceToModuleInstance(module, rotation_module) + rewriters.append(rewriter) + # Add missing RotatedModules, in case there are any + for module, rotation_module in fused_rotated_modules.items(): + if module not in unfused_rotated_modules: + rewriter = ModuleInstanceToModuleInstance(module, rotation_module) + rewriters.append(rewriter) + for r in rewriters: + model = r.apply(model) + return rewriters, rotation_matrices + + +class GraphRotationEqualizationOptimization(GraphRotationEqualization): + + def __init__( + self, + blacklist_layers: Optional[List[str]] = None, + orphan_sink: bool = False, + rotate_matmul: bool = False, + full_rotation_method: str = 'had', + ) -> None: + super(GraphRotationEqualizationOptimization, self).__init__( + blacklist_layers=blacklist_layers, + orphan_sink=orphan_sink, + rotate_matmul=rotate_matmul, + full_rotation_method=full_rotation_method, + return_rewriters=True, + ) + + def apply( + self, + graph_model: GraphModule, + fuse_rotations: bool = True, + additional_regions: Optional[List] = None + ) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: + rewriters = [] + regions = _extract_regions( + graph_model, + state_impl_kwargs={ + 'supported_srcs': self.supported_srcs, + 'supported_sinks': self.supported_sinks, + 'scale_invariant_layers': self.scale_invariant_layers, + 'scale_invariant_function': self.scale_invariant_function}) + if additional_regions is not None: + regions.extend(additional_regions) + eq_layers = set() + orphan_regions = [] + self.find_module(graph_model, orphan_regions) + for r in regions: + id_list = [id(r.name_to_module[sink_name]) for sink_name in r.sinks_names] + eq_layers.update(id_list) + if self.orphan_sink: + for o_r in orphan_regions: + # Layerwise have only a single sink named 'sinks0' + id_sink = id(o_r.get_module_from_name('sinks0')) + if id_sink not in eq_layers: + # TODO: Change data structure to insert in the beginning with O(1) + regions = [o_r] + regions + if self.rotate_matmul: + self.rotate_matmuls(graph_model) + if len(regions) > 0: + rewriters, rotation_matrices = _apply_rotate_fused_rotations(graph_model, regions, self.full_rotation_method, fuse_rotations) + return graph_model, rewriters, rotation_matrices diff --git a/src/brevitas/graph/hadamard.py b/src/brevitas/graph/hadamard.py index 27bf1b4ae..d3773c759 100644 --- a/src/brevitas/graph/hadamard.py +++ b/src/brevitas/graph/hadamard.py @@ -66,6 +66,23 @@ def get_hadK(n, transpose=False): assert (is_pow2(n // 12)) K = 12 hadK = tensors['get_had12'].T if transpose else 
tensors['get_had12'] + # TODO: Add this matrix along with the others + elif n % 64 == 0: + assert (is_pow2(n // 64)) + K = 64 + hadK = torch.tensor([[1 if char == '+' else -1 + for char in line] + for line in hadamard_string_64.strip().split('\n')], + dtype=torch.float32, + requires_grad=False) + elif n % 16 == 0: + assert (is_pow2(n // 16)) + K = 16 + hadK = torch.tensor([[1 if char == '+' else -1 + for char in line] + for line in hadamard_string_16.strip().split('\n')], + dtype=torch.float32, + requires_grad=False) else: assert (is_pow2(n)) K = 1 @@ -166,3 +183,86 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False): def is_pow2(n): return (n & (n - 1) == 0) and (n > 0) + + +hadamard_string_16 = """++++++++++++++++ ++-+-+-+-+-+-+-+- +++--++--++--++-- ++--++--++--++--+ +++++----++++---- ++-+--+-++-+--+-+ +++----++++----++ ++--+-++-+--+-++- +++++++++-------- ++-+-+-+--+-+-+-+ +++--++----++--++ ++--++--+-++--++- +++++--------++++ ++-+--+-+-+-++-+- +++----++--++++-- ++--+-++--++-+--+""" + +hadamard_string_64 = """++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+- +++--++--++--++--++--++--++--++--++--++--++--++--++--++--++--++-- ++--++--++--++--++--++--++--++--++--++--++--++--++--++--++--++--+ +++++----++++----++++----++++----++++----++++----++++----++++---- ++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-+ +++----++++----++++----++++----++++----++++----++++----++++----++ ++--+-++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-++-+--+-++- +++++++++--------++++++++--------++++++++--------++++++++-------- ++-+-+-+--+-+-+-++-+-+-+--+-+-+-++-+-+-+--+-+-+-++-+-+-+--+-+-+-+ +++--++----++--++++--++----++--++++--++----++--++++--++----++--++ ++--++--+-++--++-+--++--+-++--++-+--++--+-++--++-+--++--+-++--++- +++++--------++++++++--------++++++++--------++++++++--------++++ ++-+--+-+-+-++-+-+-+--+-+-+-++-+-+-+--+-+-+-++-+-+-+--+-+-+-++-+- +++----++--++++--++----++--++++--++----++--++++--++----++--++++-- ++--+-++--++-+--++--+-++--++-+--++--+-++--++-+--++--+-++--++-+--+ +++++++++++++++++----------------++++++++++++++++---------------- ++-+-+-+-+-+-+-+--+-+-+-+-+-+-+-++-+-+-+-+-+-+-+--+-+-+-+-+-+-+-+ +++--++--++--++----++--++--++--++++--++--++--++----++--++--++--++ ++--++--++--++--+-++--++--++--++-+--++--++--++--+-++--++--++--++- +++++----++++--------++++----++++++++----++++--------++++----++++ ++-+--+-++-+--+-+-+-++-+--+-++-+-+-+--+-++-+--+-+-+-++-+--+-++-+- +++----++++----++--++++----++++--++----++++----++--++++----++++-- ++--+-++-+--+-++--++-+--+-++-+--++--+-++-+--+-++--++-+--+-++-+--+ +++++++++----------------++++++++++++++++----------------++++++++ ++-+-+-+--+-+-+-+-+-+-+-++-+-+-+-+-+-+-+--+-+-+-+-+-+-+-++-+-+-+- +++--++----++--++--++--++++--++--++--++----++--++--++--++++--++-- ++--++--+-++--++--++--++-+--++--++--++--+-++--++--++--++-+--++--+ +++++--------++++----++++++++----++++--------++++----++++++++---- ++-+--+-+-+-++-+--+-++-+-+-+--+-++-+--+-+-+-++-+--+-++-+-+-+--+-+ +++----++--++++----++++--++----++++----++--++++----++++--++----++ ++--+-++--++-+--+-++-+--++--+-++-+--+-++--++-+--+-++-+--++--+-++- +++++++++++++++++++++++++++++++++-------------------------------- ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+--+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +++--++--++--++--++--++--++--++----++--++--++--++--++--++--++--++ ++--++--++--++--++--++--++--++--+-++--++--++--++--++--++--++--++- +++++----++++----++++----++++--------++++----++++----++++----++++ ++-+--+-++-+--+-++-+--+-++-+--+-+-+-++-+--+-++-+--+-++-+--+-++-+- 
+++----++++----++++----++++----++--++++----++++----++++----++++-- ++--+-++-+--+-++-+--+-++-+--+-++--++-+--+-++-+--+-++-+--+-++-+--+ +++++++++--------++++++++----------------++++++++--------++++++++ ++-+-+-+--+-+-+-++-+-+-+--+-+-+-+-+-+-+-++-+-+-+--+-+-+-++-+-+-+- +++--++----++--++++--++----++--++--++--++++--++----++--++++--++-- ++--++--+-++--++-+--++--+-++--++--++--++-+--++--+-++--++-+--++--+ +++++--------++++++++--------++++----++++++++--------++++++++---- ++-+--+-+-+-++-+-+-+--+-+-+-++-+--+-++-+-+-+--+-+-+-++-+-+-+--+-+ +++----++--++++--++----++--++++----++++--++----++--++++--++----++ ++--+-++--++-+--++--+-++--++-+--+-++-+--++--+-++--++-+--++--+-++- +++++++++++++++++--------------------------------++++++++++++++++ ++-+-+-+-+-+-+-+--+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-++-+-+-+-+-+-+-+- +++--++--++--++----++--++--++--++--++--++--++--++++--++--++--++-- ++--++--++--++--+-++--++--++--++--++--++--++--++-+--++--++--++--+ +++++----++++--------++++----++++----++++----++++++++----++++---- ++-+--+-++-+--+-+-+-++-+--+-++-+--+-++-+--+-++-+-+-+--+-++-+--+-+ +++----++++----++--++++----++++----++++----++++--++----++++----++ ++--+-++-+--+-++--++-+--+-++-+--+-++-+--+-++-+--++--+-++-+--+-++- +++++++++----------------++++++++--------++++++++++++++++-------- ++-+-+-+--+-+-+-+-+-+-+-++-+-+-+--+-+-+-++-+-+-+-+-+-+-+--+-+-+-+ +++--++----++--++--++--++++--++----++--++++--++--++--++----++--++ ++--++--+-++--++--++--++-+--++--+-++--++-+--++--++--++--+-++--++- +++++--------++++----++++++++--------++++++++----++++--------++++ ++-+--+-+-+-++-+--+-++-+-+-+--+-+-+-++-+-+-+--+-++-+--+-+-+-++-+- +++----++--++++----++++--++----++--++++--++----++++----++--++++-- ++--+-++--++-+--+-++-+--++--+-++--++-+--++--+-++-+--+-++--++-+--+""" diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index 899612796..fbe045124 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -1,3 +1,4 @@ +import functools from inspect import signature from typing import Callable, Optional @@ -68,6 +69,14 @@ def __init__(self, layer, had_mat=None, k=None) -> None: self.layer = layer self.k = k + @property + def weight(self) -> Optional[torch.Tensor]: + return getattr(self.layer, 'weight', None) + + @property + def bias(self) -> Optional[torch.Tensor]: + return getattr(self.layer, 'bias', None) + def forward(self, inp, **kwargs): is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None # If k is None, we assume that an orthogonal matrix is used @@ -88,13 +97,91 @@ def forward(self, inp, **kwargs): return o +def rot_func_wrapper(weight: torch.Tensor, rot_mat: torch.Tensor, rotation_function: Callable): + weight_shape = weight.shape + rot_mat_dim = rot_mat.shape[0] + return rotation_function(weight.view(-1, weight_shape.shape[1] // rot_mat_dim, + rot_mat_dim)).view(weight_shape) + + +class RotationWeightParametrization(torch.nn.Module): + + def __init__( + self, + rot_mat: torch.nn.Parameter, + rot_func: Callable, + input_axis: Optional[int] = None, + output_axis: Optional[int] = None, + is_source: bool = False, + is_sink: bool = False, + is_orphan: bool = False, + ) -> None: + super().__init__() + self.rot_mat = rot_mat + self.rot_func = rot_func + self.input_axis = input_axis + self.output_axis = output_axis + self.is_source = is_source + self.is_sink = is_sink + self.is_orphan = is_orphan + self.K = None + + def forward(self, weight: torch.Tensor) -> torch.Tensor: + if self.is_sink or self.is_orphan: + if self.input_axis == 1: + weight = self.rot_func(weight, self.rot_mat, 
self.K) + elif self.input_axis == 0: + weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() + else: + raise RuntimeError("Not supported yet") + + if self.is_source: + if self.output_axis == 0: + weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() + elif self.output_axis == 1: + weight = self.rot_func(weight, self.rot_mat, self.K) + else: + raise RuntimeError("Not supported yet") + + return weight + + +class RotationBiasParametrization(torch.nn.Module): + + def __init__( + self, + rot_mat: torch.nn.Parameter, + rot_func: Callable, + input_axis: Optional[int] = None, + output_axis: Optional[int] = None, + is_source: bool = False, + is_sink: bool = False, + is_orphan: bool = False, + ) -> None: + super().__init__() + self.rot_mat = rot_mat + self.rot_func = rot_func + self.input_axis = input_axis + self.output_axis = output_axis + self.is_source = is_source + self.is_sink = is_sink + self.is_orphan = is_orphan + self.K = None + + def forward(self, bias: torch.Tensor) -> torch.Tensor: + if self.is_source: + bias = self.rot_func(bias, self.rot_mat, self.K) + + return bias + + class UnfusedRotatedModule(torch.nn.Module): def __init__( self, module: torch.nn.Module, rot_func: Callable, - rot_mat: torch.Tensor, + rot_mat: torch.nn.Parameter, _get_input_axis: Callable, _get_output_axis: Callable, is_source: bool = False, @@ -104,7 +191,8 @@ def __init__( super().__init__() self.module = module self.rot_func = rot_func - self.rot_mat = torch.nn.Parameter(rot_mat).cpu() + self.rot_mat = rot_mat + self.K = None # TODO: This were included to prevent circular imports. self._get_input_axis = _get_input_axis @@ -114,6 +202,25 @@ def __init__( self.is_sink = is_sink self.is_orphan = is_orphan + # TODO: Does it make sense the extra complexity just to prevent the view operation? 
+ # Probably if no reshaping needs to be done, no change is required + def _wrap_rot(self) -> bool: + weight_shape = self.module.weight.shape + rot_dim = self.rot_mat.shape[0] + if self.is_sink or self.is_orphan: + weight_shape_dim = weight_shape[self._get_input_axis(self.module)] + elif self.is_source: + weight_shape_dim = weight_shape[self._get_output_axis(self.module)] + else: + weight_shape_dim = None + + if weight_shape_dim is not None: + if rot_dim != weight_shape_dim: + assert weight_shape_dim % rot_dim == 0, "Sizes need to be divisibile" + return True + # No need to incorporate additional view operations + return False + # These properties enable propagating the fusing to the module weights @property def weight(self) -> Optional[torch.Tensor]: @@ -123,18 +230,18 @@ def weight(self) -> Optional[torch.Tensor]: if self.is_sink or self.is_orphan: axis = self._get_input_axis(self.module) if axis == 1: - weight = self.rot_func(weight, self.rot_mat) + weight = self.rot_func(weight, self.rot_mat, self.K) elif axis == 0: - weight = self.rot_func(weight.t(), self.rot_mat).t() + weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() else: raise RuntimeError("Not supported yet") if self.is_source: axis = self._get_output_axis(self.module) if axis == 0: - weight = self.rot_func(weight.t(), self.rot_mat).t() + weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() elif axis == 1: - weight = self.rot_func(weight, self.rot_mat) + weight = self.rot_func(weight, self.rot_mat, self.K) else: raise RuntimeError("Not supported yet") @@ -146,7 +253,7 @@ def bias(self) -> Optional[torch.Tensor]: # Propagate bias adding the rotations incrementally if self.is_source: if bias is not None: - bias = self.rot_func(bias, self.rot_mat) + bias = self.rot_func(bias, self.rot_mat, self.K) return bias @@ -169,7 +276,7 @@ def forward(self, inp, **kwargs): # rotate the inpute appropiately if self.is_orphan: # Rotate the input for an orphan sink - inp = self.rot_func(inp, self.rot_mat) + inp = self.rot_func(inp, self.rot_mat, self.K) # Modify the weights, and run the original model forward. After that, restore the previous values. 
        if weight is not None:
            orig_weight = child_module.weight.data
diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py
index 8baf07596..7b2baecae 100644
--- a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py
+++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py
@@ -38,6 +38,7 @@ class ModelArguments:
 class TrainingArguments(transformers.TrainingArguments):
     cache_dir: Optional[str] = field(default=None)
     output_dir: Optional[str] = field(default="/tmp/output/")
+    use_cpu: Optional[bool] = field(default=False)
     model_max_length: Optional[int] = field(
         default=2048,
         metadata={
@@ -55,6 +56,25 @@ def parse_optimization_rotation_args(unknown_args=None) -> None:
     return training_args
 
 
+def collate_fn(kwargs_list, return_tensors="pt"):
+    # Collate tensor values per key; non-tensor values keep their first occurrence
+    kwargs = {}
+    for curr_dict in kwargs_list:
+        for key, value in curr_dict.items():
+            if isinstance(value, torch.Tensor):
+                if key not in kwargs:
+                    kwargs[key] = []
+                kwargs[key].append(value)
+            else:
+                if key not in kwargs:
+                    kwargs[key] = value
+    # Concatenate the accumulated tensors along the batch dimension
+    for key, value in kwargs.items():
+        if isinstance(value, list) and len(value) > 0:
+            kwargs[key] = torch.cat(kwargs[key], dim=0)
+    return kwargs
+
+
 def apply_rotation_optimization(
         graph_model: torch.fx.GraphModule,
         tokenizer: PreTrainedTokenizerBase,
@@ -62,12 +82,20 @@
         unknown_args=None) -> None:
     # Get training arguments
     training_args = parse_optimization_rotation_args(unknown_args)
+    # Freeze all of the model parameters
+    for param in graph_model.parameters():
+        param.requires_grad = False
     # Collect trainable matrices
     trainable_parameters = []
+    ids_rot = set()
     for module in graph_model.modules():
         if isinstance(module, UnfusedRotatedModule):
-            if not module.is_sink:
+            if id(module.rot_mat) not in ids_rot:
+                ids_rot.add(id(module.rot_mat))
                 trainable_parameters.append(module.rot_mat)
+    # Enable gradients only for the rotation matrices
+    for rot_mat in trainable_parameters:
+        rot_mat.requires_grad = True
     # Initialize optimizer
     optimizer = SGDG(trainable_parameters, lr=training_args.learning_rate, stiefel=True)
     trainer = Trainer(
@@ -76,6 +104,6 @@
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=None,
-        data_collator=default_data_collator,
+        data_collator=collate_fn,
         optimizers=(optimizer, None))
     trainer.train()
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index bf5b9685e..2c79101ef 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -20,7 +20,9 @@
 from brevitas.export import export_torch_qcdq
 from brevitas.export.inference.manager import quant_inference_mode
 from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
+from brevitas.graph.equalize import find_missing_rotation_regions
 from brevitas.graph.equalize import GraphRotationEqualization
+from brevitas.graph.equalize import GraphRotationEqualizationOptimization
 from brevitas.graph.equalize import LayerwiseActivationRotation
 from brevitas.graph.quantize import layerwise_quantize
 from brevitas.graph.utils import get_module
@@ -77,7 +79,7 @@ def set_seed(seed):
     torch.random.manual_seed(seed)
 
 
-def fused_rotation_no_fx(model, calibration_loader, args):
+def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool = False):
     with torch.no_grad():
         new_model, guards = torch._dynamo.export(model)(**calibration_loader[0])
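
The custom `collate_fn` above replaces `transformers.default_data_collator` because the calibration samples are dictionaries of forward-pass kwargs that mix tensors with plain Python values. A minimal sketch of its behaviour on toy inputs (the sample contents below are illustrative, not real calibration data):

```python
import torch

from brevitas_examples.llm.llm_quant.rotation_optimization import collate_fn

# Two kwargs-style calibration samples, as fed to the rotation-optimization Trainer
samples = [
    {"input_ids": torch.tensor([[1, 2]]), "labels": torch.tensor([[1, 2]]), "use_cache": False},
    {"input_ids": torch.tensor([[3, 4]]), "labels": torch.tensor([[3, 4]]), "use_cache": False},
]

batch = collate_fn(samples)
# Tensor entries are concatenated along the batch dimension ...
assert batch["input_ids"].shape == (2, 2)
# ... while non-tensor entries keep their first occurrence
assert batch["use_cache"] is False
```
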
apply_layernorm_affine_merge(new_model) @@ -91,7 +93,7 @@ def fused_rotation_no_fx(model, calibration_loader, args): orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, return_rewriters=True) - new_model, rewriters = eq.apply(new_model) + new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations) rewriters = fix_rewriter(rewriters, model, 'weight') for r in rewriters: @@ -99,6 +101,36 @@ def fused_rotation_no_fx(model, calibration_loader, args): remove_hooks(new_model) +def fused_optimized_rotation_no_fx( + model, + calibration_loader, + args, + fuse_rotations: bool = False, + add_additional_regions: bool = False): + with torch.no_grad(): + new_model, guards = torch._dynamo.export(model)(**calibration_loader[0]) + apply_layernorm_affine_merge(new_model) + new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True) + rewriters = fix_rewriter(rewriters, model, 'weight') + + for r in rewriters: + r.apply(model) + #new_model = offload_model(new_model) + additional_regions = find_missing_rotation_regions( + new_model, model.config.hidden_size // + model.config.num_attention_heads) if add_additional_regions else None + eq = GraphRotationEqualizationOptimization( + orphan_sink=args.rotation_orphan_sink, + full_rotation_method=args.rotation_mode, + ) + new_model, rewriters, rotation_matrices = eq.apply(new_model, fuse_rotations=fuse_rotations, additional_regions=additional_regions) + rewriters = fix_rewriter(rewriters, model, 'weight') + + for r in rewriters: + r.apply(model) + #remove_hooks(new_model) + + def set_seed(seed): np.random.seed(seed) torch.random.manual_seed(seed) @@ -273,6 +305,15 @@ def quantize_llm(args, unknown_args=None): if args.replace_rmsnorm: model = replace_rmsnorm_with_torch(model, model.config) + # TODO: Refactor + if args.rotation == 'fused_no_fx_optimize': + for i in range(len(calibration_loader)): + del calibration_loader[i]["attention_mask"] + calibration_loader[i]["labels"] = calibration_loader[i]["input_ids"] + + model.config.use_cache = False + model.config.loss_type = "ForCausalLM" + if require_fx: if model.__class__.__name__ in _SUPPORTED_MODELS and not args.replace_rmsnorm: model = get_fx(model, is_export=args.export_target is not None) @@ -298,13 +339,16 @@ def quantize_llm(args, unknown_args=None): model = offload_model(model) eq = GraphRotationEqualization( orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode) - model = eq.apply(model, fuse_rotations=not args.rotation_optimize) + model = eq.apply(model) remove_hooks(model) elif args.rotation == 'layerwise': eq = LayerwiseActivationRotation() model = eq.apply(model) elif args.rotation == 'fused_no_fx': fused_rotation_no_fx(model, calibration_loader, args) + elif args.rotation == 'fused_no_fx_optimize': + fused_optimized_rotation_no_fx( + model, calibration_loader, args, fuse_rotations=False, add_additional_regions=True) # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing # with all the variability in HF implementations @@ -425,6 +469,17 @@ def quantize_llm(args, unknown_args=None): for k, v in dict_hooks.items(): k._hf_hook.post_forward = v + # TODO: Refactor + remove_hooks(model) + + if args.rotation == 'fused_no_fx_optimize': + apply_rotation_optimization( + graph_model=model, + tokenizer=tokenizer, + train_dataset=calibration_loader, + unknown_args=unknown_args, + ) + if args.act_calibration: print("Apply act calibration...") apply_calibration(model, calibration_loader) @@ 
-759,7 +814,7 @@ def parse_args(args, override_defaults={}): '--rotation', type=str, default=None, - choices=['fx', 'layerwise', 'fused_no_fx'], + choices=['fx', 'layerwise', 'fused_no_fx', 'fused_no_fx_optimize'], help='Apply graph rotation equalization') parser.add_argument( '--rotation-mode', @@ -768,11 +823,6 @@ def parse_args(args, override_defaults={}): help= 'If GraphRotation is enabled, decide how to compute the random rotation matrix that is fully fused. Online or partial rotation will always be Hadamard' ) - # TODO: Make sure in argument validator that - parser.add_argument( - '--rotation-optimize', - action='store_true', - help='Whether to optimize the rotation matrices.') parser.add_argument( '--rotation-orphan-sink', action="store_true", diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index d8ea0bcb2..0f65db168 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -25,6 +25,7 @@ from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import MergeLnAffine from brevitas.graph.equalize import random_orthogonal_matrix +from brevitas.graph.hadamard import matmul_hadU from brevitas.graph.standardize import DuplicateSharedStatelessModule from brevitas.graph.standardize import TorchFunctionalToModule from brevitas.graph.utils import get_module @@ -401,11 +402,47 @@ def _random_orthogonal_matrix(size, generator): return q +def _random_hadamard_matrix(size, device, generator): + # See https://github.com/Cornell-RelaxML/quip-sharp , Section "Randomized Hadamard Transformation" + Q = torch.randint(low=0, high=2, size=(size,), generator=generator).to(torch.float64) + Q = Q * 2 - 1 + Q = torch.diag(Q) + return matmul_hadU(Q).to(device) + + +def _compare_module_weights_fused_unfused(gt_module, rot_module, fused_rotations=False): + gt_weight = gt_module.weight if isinstance(gt_module, nn.Linear) else gt_module.layer.weight + gt_bias = gt_module.bias if isinstance(gt_module, nn.Linear) else gt_module.layer.bias + if fused_rotations: + rot_weight = rot_module.weight if isinstance( + rot_module, nn.Linear) else rot_module.layer.weight + rot_bias = rot_module.bias if isinstance(rot_module, nn.Linear) else rot_module.layer.bias + else: + rot_weight = rot_module.weight + rot_bias = rot_module.bias + assert torch.allclose(gt_weight, rot_weight, rtol=0.0, atol=0.0) + if gt_bias is not None: + assert torch.allclose(gt_bias, rot_bias, rtol=0.0, atol=0.0) + # For a RotatedModule, corresponding to an orphan node, additional checks need to be done + if isinstance(gt_module, RotatedModule): + if not fused_rotations: + # The outermost should be an orphan + child_rot_module = rot_module + assert child_rot_module.is_orphan, "Unfused rotated module needs to be an orphan." + # Check that the inner UnfusedRotatedModules are not orphans + while isinstance(child_rot_module.module, UnfusedRotatedModule): + assert not child_rot_module.module.is_orphan, "Inner unfused rotated modules cannot be orphans." 
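
Both Hadamard construction paths above rely on the same defining property: a Hadamard matrix H of size n has mutually orthogonal ±1 rows, so `H @ H.T == n * I`, and `H / sqrt(n)` is an exact rotation. A small self-contained check of the `'+'`/`'-'` string encoding used by `get_hadK` (the 4x4 string below is illustrative, not one of the matrices shipped with the patch):

```python
import torch

# A 4x4 Hadamard matrix written in the same '+'/'-' string format used above
hadamard_string_4 = """++++
+-+-
++--
+--+"""

H = torch.tensor([[1 if c == '+' else -1 for c in line]
                  for line in hadamard_string_4.strip().split('\n')],
                 dtype=torch.float32)

n = H.shape[0]
# Defining property: rows are mutually orthogonal, so H @ H.T == n * I
assert torch.equal(H @ H.T, n * torch.eye(n))

# Normalizing by sqrt(n) yields an orthogonal matrix, i.e. an exact rotation
Q = H / n ** 0.5
assert torch.allclose(Q @ Q.T, torch.eye(n), atol=1e-6)
```
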
+ child_rot_module = child_rot_module.module + # Verify that the rotation matrices match + assert torch.allclose(gt_module.had_mat, rot_module.rot_mat) + + # This test verifies that the weights returned by the unfused rotate modules # match those when fusing @requires_pt_ge('2.4') @pytest_cases.parametrize('partial_had', [False, True]) -def test_models_unfused_rotations(rotation_fixtures, partial_had): +@pytest_cases.parametrize('fused_rotations', [False, True]) +def test_models_rotations(rotation_fixtures, partial_had, fused_rotations): in_shape = IN_SIZE_LINEAR @@ -431,22 +468,28 @@ def test_models_unfused_rotations(rotation_fixtures, partial_had): # We pass the generator to make sure that we can reproduce the random orthogonal matrices that are generated with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator)) as mock_ort_generator: + partial(_random_orthogonal_matrix, generator=generator)): # Apply rotation equalization while controlling the random matrices that are generated model = eq.apply(model) + with torch.no_grad(): + expected_out = model(inp) + # Now rotate but without fusing the rotation matrices with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator_clone)) as mock_ort_generator: + partial(_random_orthogonal_matrix, generator=generator_clone)): # Apply rotation equalization while controlling the random matrices that are generated model_copy = eq.apply(model_copy, fuse_rotations=False) + # Fuse the rotations and make sure the behaviour is the same + if fused_rotations: + _fuse_rotations(model_copy) + with torch.no_grad(): out = model_copy(inp) # Verify that the output of the model does not change after incorporating the rotations - with torch.no_grad(): - expected_out = model(inp) + assert torch.allclose(expected_out, out, rtol=0.0, atol=0.0) # Verify that weight matrices for model_node, model_copy_node in zip(model.graph.nodes, model_copy.graph.nodes): @@ -454,47 +497,85 @@ def test_models_unfused_rotations(rotation_fixtures, partial_had): module = get_module(model, model_node.target) module_copy = get_module(model_copy, model_copy_node.target) if isinstance(module, (nn.Linear, RotatedModule)): - weight = module.weight if isinstance(module, nn.Linear) else module.layer.weight - bias = module.bias if isinstance(module, nn.Linear) else module.layer.bias - weight_copy = module_copy.weight - bias_copy = module_copy.bias - assert torch.allclose(weight, weight_copy, atol=ATOL) - if bias is not None: - assert torch.allclose(bias, bias_copy, atol=ATOL) - # For a RotatedModule, corresponding to an orphan node, additional checks need to be done - if isinstance(module, RotatedModule): - # The outermost should be an orphan - rotated_module = module_copy - assert rotated_module.is_orphan, "Unfused rotated module needs to be an orphan." - # Check that the inner UnfusedRotatedModules are not orphans - while isinstance(rotated_module.module, UnfusedRotatedModule): - assert not rotated_module.module.is_orphan, "Inner unfused rotated modules cannot be orphans." 
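
The weight comparisons in this helper hinge on the invariance that rotation equalization exploits: for an orthogonal R, rotating a Linear sink's input is exactly undone by rotating the weight along its input axis, so the fused and unfused paths must agree up to float error. A minimal numeric sketch of that identity (hypothetical shapes, unrelated to the test fixtures):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 16)   # activations
W = torch.randn(32, 16)  # nn.Linear weight, shape (out_features, in_features)

# Random orthogonal matrix via QR, mirroring _random_orthogonal_matrix above
R, _ = torch.linalg.qr(torch.randn(16, 16, dtype=torch.float64))
R = R.to(torch.float32)

# Sink update on the input axis (axis == 1): W' = W @ R
y_ref = F.linear(x, W)
y_rot = F.linear(x @ R, W @ R)  # rotated input, rotated weight

# R @ R.T == I, so the rotation cancels out (up to float error)
assert torch.allclose(y_ref, y_rot, atol=1e-5)
```
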
- rotated_module = rotated_module.module - # Verify that the rotation matrices match - assert torch.allclose(module.had_mat, module_copy.rot_mat) - - -# This test verifies that the weights returned by the unfused rotate modules -# match those when fusing + _compare_module_weights_fused_unfused(module, module_copy, fused_rotations) + + +def _compare_module_weights(module, module_copy): + weight = module.weight if isinstance(module, nn.Linear) else module.layer.weight + bias = module.bias if isinstance(module, nn.Linear) else module.layer.bias + weight_copy = module_copy.weight + bias_copy = module_copy.bias + assert torch.allclose(weight, weight_copy, rtol=0.0, atol=0.0) + if bias is not None: + assert torch.allclose(bias, bias_copy, rtol=0.0, atol=0.0) + + +import logging + +from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer + +from brevitas.graph.equalize import find_missing_rotation_regions +from brevitas_examples.common.accelerate_utils.accelerate import offload_model +from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks +from brevitas_examples.llm.llm_quant.data_utils import get_dataset_for_model +from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge +from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_to_rmsnorm +from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch +from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter +from brevitas_examples.llm.main import fused_optimized_rotation_no_fx +from brevitas_examples.llm.main import fused_rotation_no_fx +from tests.brevitas_examples.test_llm import default_run_args + + +@pytest_cases.fixture( + ids=[ + "llama",], + params=[ + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "input_bit_width": None, + "fuse_sequences": False, + "act_calibration": False,},]) +def equalize_args(default_run_args, request): + args = default_run_args + export_dict = request.param + args.update(**export_dict) + yield args + + +@pytest.mark.llm @requires_pt_ge('2.4') @pytest_cases.parametrize('partial_had', [False, True]) -def test_models_fused_rotations(rotation_fixtures, partial_had): - - in_shape = IN_SIZE_LINEAR - - model_class = rotation_fixtures - model = model_class() - +@pytest_cases.parametrize('fused_rotations', [False, True]) +def test_small_models_equalize_legacy_rotation_orthogonal( + caplog, partial_had, fused_rotations, equalize_args): + import os + os.environ["HF_HUB_CACHE"] = "/scratch/hf_models/" + caplog.set_level(logging.INFO) + args = equalize_args + args.rotation_orphan_sink = partial_had + args.rotation_mode = 'ort' + + kwargs = {"torch_dtype": torch.float16} + model = AutoModelForCausalLM.from_pretrained(args.model, **kwargs) + model = replace_rmsnorm_with_torch(model, model.config) + model.config.use_cache = False + print("Model loaded.") model.eval() - inp = torch.rand(in_shape) - - model = symbolic_trace(model) - merge = MergeLnAffine() - model = merge.apply(model) - eq = GraphRotationEqualization(orphan_sink=partial_had, full_rotation_method='ort') - - # Save a copy to apply graph rotation equalization on - model_copy = copy.deepcopy(model) + tokenizer = AutoTokenizer.from_pretrained(args.model) + # Load the data for calibration and evaluation. 
+ calibration_loader = get_dataset_for_model( + args.model, + dataset_name=args.dataset, + tokenizer=tokenizer, + nsamples=args.nsamples, + seqlen=args.seqlen, + split="train", + seed=args.seed, + require_fx=False, + device=None, + fuse_sequences=args.fuse_sequences) # We need to make sure that the same random matrices are being generated generator = torch.Generator() @@ -502,46 +583,205 @@ def test_models_fused_rotations(rotation_fixtures, partial_had): # Clone generator to make sure we can use the same rotation matrices generator_clone = generator.clone_state() + # Save a copy to apply graph rotation equalization on + model_copy = copy.deepcopy(model) + # We pass the generator to make sure that we can reproduce the random orthogonal matrices that are generated with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator)) as mock_ort_generator: - # Apply rotation equalization while controlling the random matrices that are generated - model = eq.apply(model) + partial(_random_orthogonal_matrix, generator=generator)): + with patch('brevitas.graph.hadamard.random_hadamard_matrix', + partial(_random_hadamard_matrix, generator=generator)): + fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations=True) + # Run model and save outputs with torch.no_grad(): - expected_out = model(inp) + expected_logits = model(**calibration_loader[0]).logits - # Now rotate but without fusing the rotation matrices + # We pass the generator to make sure that we can reproduce the random orthogonal matrices that are generated with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator_clone)) as mock_ort_generator: - # Apply rotation equalization while controlling the random matrices that are generated - model_copy = eq.apply(model_copy, fuse_rotations=False) + partial(_random_orthogonal_matrix, generator=generator_clone)): + with patch('brevitas.graph.hadamard.random_hadamard_matrix', + partial(_random_hadamard_matrix, generator=generator_clone)): + fused_rotation_no_fx(model_copy, calibration_loader, args, fuse_rotations=False) - # Fuse the rotations and make sure the behaviour is the same - _fuse_rotations(model_copy) + if fused_rotations: + _fuse_rotations(model_copy) + # Run model and save outputs with torch.no_grad(): - out = model_copy(inp) + logits = model_copy(**calibration_loader[0]).logits + + # Verify that the output is the same + assert torch.allclose(expected_logits, logits) + + # Verify that the weights after fusing match + for name_fused_module, fused_module in model.named_modules(): + # For linear modules verify that the weights match + if isinstance(fused_module, (nn.Linear, RotatedModule)): + for name_unfused_Module, unfused_module in model_copy.named_modules(): + if name_fused_module == name_unfused_Module: + _compare_module_weights(fused_module, unfused_module) + # For a RotatedModule, corresponding to an orphan node, additional checks need to be done + if isinstance(fused_module, RotatedModule): + # Verify that the outer module is an orphan + if fused_rotations: + assert isinstance(unfused_module, RotatedModule) + assert torch.allclose(fused_module.had_mat, unfused_module.had_mat) + else: + assert unfused_module.is_orphan + # Verify that the rotation matrices match + assert torch.allclose(fused_module.had_mat, unfused_module.rot_mat) + + +from itertools import product + +from brevitas.graph.equalize import _apply_had_device +from brevitas.graph.hadamard import 
get_hadK + + +# NOTE: This test works because in R2 we patch the rotation method, so the appropiate matrix is not effectively used. This is because when the fast_hadamard_transform is not avai +@pytest.mark.llm +@requires_pt_ge('2.4') +@pytest_cases.parametrize( + 'partial_had, fused_rotations, add_additional_regions', + list(product([False, True], repeat=3)), + ids=[("fused-R1" if fused_rotations else "R1") + ("-R2" if add_additional_regions else "") + + ("-R3" if partial_had else "") for partial_had, + fused_rotations, + add_additional_regions in list(product([False, True], repeat=3))], +) +@pytest_cases.parametrize('rotation_mode', ['ort', 'had']) +def test_small_models_equalize_mixed_fused_unfused( + caplog, partial_had, fused_rotations, add_additional_regions, rotation_mode, equalize_args): + import os + os.environ["HF_HUB_CACHE"] = "/scratch/hf_models/" + caplog.set_level(logging.INFO) + args = equalize_args + args.rotation_orphan_sink = partial_had + args.rotation_mode = rotation_mode + + kwargs = {"torch_dtype": torch.float16} + model = AutoModelForCausalLM.from_pretrained(args.model, **kwargs) + model = replace_rmsnorm_with_torch(model, model.config) + model.config.use_cache = False + print("Model loaded.") + model.eval() + tokenizer = AutoTokenizer.from_pretrained(args.model) + # Load the data for calibration and evaluation. + calibration_loader = get_dataset_for_model( + args.model, + dataset_name=args.dataset, + tokenizer=tokenizer, + nsamples=args.nsamples, + seqlen=args.seqlen, + split="train", + seed=args.seed, + require_fx=False, + device=None, + fuse_sequences=args.fuse_sequences) - # Verify that the output of the model does not change after incorporating the rotations - assert torch.allclose(expected_out, out) + # We need to make sure that the same random matrices are being generated + generator = torch.Generator() + generator.manual_seed(SEED) + # Clone generator to make sure we can use the same rotation matrices + generator_clone = generator.clone_state() - # Verify that weight matrices - for model_node, model_copy_node in zip(model.graph.nodes, model_copy.graph.nodes): - if model_node.op == 'call_module': - module = get_module(model, model_node.target) - module_copy = get_module(model_copy, model_copy_node.target) - if isinstance(module, (nn.Linear, RotatedModule)): - weight = module.weight if isinstance(module, nn.Linear) else module.layer.weight - bias = module.bias if isinstance(module, nn.Linear) else module.layer.bias - weight_copy = module_copy.weight if isinstance( - module_copy, nn.Linear) else module_copy.layer.weight - bias_copy = module_copy.bias if isinstance( - module_copy, nn.Linear) else module_copy.layer.bias - assert torch.allclose(weight, weight_copy, atol=ATOL) - if bias is not None: - assert torch.allclose(bias, bias_copy, atol=ATOL) - # For a RotatedModule, corresponding to an orphan node, additional checks need to be done - if isinstance(module, RotatedModule): - # Verify that the rotation matrices match - assert torch.allclose(module.had_mat, module_copy.had_mat) + # Run model and save outputs + with torch.no_grad(): + original_logits = model(**calibration_loader[0]).logits + + # Save a copy to apply graph rotation equalization on + model_copy = copy.deepcopy(model) + + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator)): + fused_optimized_rotation_no_fx( + model, + calibration_loader, + args, + fuse_rotations=True, + add_additional_regions=add_additional_regions) + + # Run 
model and save outputs + with torch.no_grad(): + expected_logits = model(**calibration_loader[0]).logits + + # Instead of random orthogonal matrices, we want to use the same ones as when the activations are not fused. + if rotation_mode == 'had': + with patch('brevitas.graph.equalize._apply_ort_device', _apply_had_device): + fused_optimized_rotation_no_fx( + model_copy, + calibration_loader, + args, + fuse_rotations=False, + add_additional_regions=add_additional_regions) + else: + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator_clone)): + fused_optimized_rotation_no_fx( + model_copy, + calibration_loader, + args, + fuse_rotations=False, + add_additional_regions=add_additional_regions) + + # Fuse matrices with module weights + if fused_rotations: + _fuse_rotations(model_copy) + + ids_rot = set() + num_rotation_matrices = 0 + # Count the number of unique rotation matrices + for module in model_copy.modules(): + if isinstance(module, UnfusedRotatedModule): + if id(module.rot_mat) not in ids_rot: + num_rotation_matrices += 1 + ids_rot.add(id(module.rot_mat)) + + num_rotated_modules = 0 + # Count the number of RotatedModules + for module in model_copy.modules(): + if isinstance(module, RotatedModule): + num_rotated_modules += 1 + + # Run model and save outputs + with torch.no_grad(): + logits = model_copy(**calibration_loader[0]).logits + + # Verify that the number of learnable rotation matrices is the expected (R1 + one R2 per block) + expected_number_rotation_matrices = 0 if fused_rotations else ( + 1 + (model.config.num_hidden_layers if add_additional_regions else 0)) + assert num_rotation_matrices == expected_number_rotation_matrices, f"Expected {expected_number_rotation_matrices} learnable rotations, found {num_rotation_matrices}." + + # Verify that the number of rotated modules is correct + expected_number_rotated_modules = 0 if not partial_had else ( + model.config.num_hidden_layers if add_additional_regions else 2 * + model.config.num_hidden_layers) + assert num_rotated_modules == expected_number_rotated_modules, f"Expected {expected_number_rotated_modules} learnable rotations, found {num_rotated_modules}." + + # Verify that the rotated module output is similar to the original FP + assert torch.allclose(original_logits, logits, atol=ATOL) + # Verify that the output is the same + assert torch.allclose(expected_logits, logits, atol=0.0, rtol=0.0) + + # Verify that the weights after fusing match + for name_fused_module, fused_module in model.named_modules(): + # For linear modules verify that the weights match + if isinstance(fused_module, (nn.Linear, RotatedModule)): + for name_unfused_Module, unfused_module in model_copy.named_modules(): + if name_fused_module == name_unfused_Module: + _compare_module_weights(fused_module, unfused_module) + # In case a RotatedModule is found, additional checks need to be done. + if isinstance(fused_module, RotatedModule): + if fused_rotations: + assert isinstance(unfused_module, RotatedModule) + assert torch.allclose(fused_module.had_mat, unfused_module.had_mat, rtol=0.0, atol=0.0), "The rotation matrices do not match." + else: + # Iterate over child nodes until finding the innermost RotatedModule + child_module = unfused_module + while isinstance(child_module, UnfusedRotatedModule): + assert not child_module.is_orphan, "UnfusedRotatedModule should not be an orphan." 
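
The unique-rotation count in this test works because rotation matrices are deduplicated by tensor identity: a single shared R1 parametrizes many sources and sinks, so counting rotated modules directly would overcount. A standalone sketch of the `id()`-based pattern used both here and in `extract_trainable_rotation_matrices` (plain parameters stand in for the parametrized modules):

```python
import torch

# One shared rotation (e.g. R1) referenced by three modules, plus one private R2
shared = torch.nn.Parameter(torch.eye(4))
private = torch.nn.Parameter(torch.eye(4))
module_rotations = [shared, shared, shared, private]

ids_rot = set()
unique_rotations = []
for rot_mat in module_rotations:
    # id() identifies the tensor object itself, not its value, so a shared
    # rotation is only collected once
    if id(rot_mat) not in ids_rot:
        ids_rot.add(id(rot_mat))
        unique_rotations.append(rot_mat)

assert len(unique_rotations) == 2  # one shared R1 + one R2
```
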
+ child_module = child_module.module + # After finding the inner Rotated Module, they need to be compared + assert isinstance(child_module, RotatedModule), "Inner module should be RotatedModule." + assert torch.allclose(fused_module.had_mat, child_module.had_mat, rtol=0.0, atol=0.0), "The rotation matrices do not match." diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index c02a3e320..6a42c0895 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -131,7 +131,7 @@ def small_models_with_ppl(request): @pytest_cases.fixture() def default_run_args(request): - args = UpdatableNamespace(**vars(parse_args([]))) + args = UpdatableNamespace(**vars(parse_args([])[0])) args.nsamples = 2 args.seqlen = 2 args.model = "hf-internal-testing/tiny-random-MistralForCausalLM" From 06a77bc5d1b8af1eabd5a16c0d44325ce204fbe9 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 13 Dec 2024 11:10:35 +0000 Subject: [PATCH 06/12] Switch to reparametrizations and refactor --- src/brevitas/graph/equalize.py | 296 +++-------- src/brevitas/nn/equalized_layer.py | 119 ----- .../llm/llm_quant/rotation_optimization.py | 17 +- .../llm/llm_quant/rotation_utils.py | 92 ++++ .../llm/llm_quant/run_utils.py | 15 +- src/brevitas_examples/llm/main.py | 14 +- tests/brevitas/graph/test_equalization.py | 475 ++---------------- tests/brevitas_examples/test_llm.py | 207 +++++++- 8 files changed, 424 insertions(+), 811 deletions(-) create mode 100644 src/brevitas_examples/llm/llm_quant/rotation_utils.py diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 0cd8ea47c..7056f18cd 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -16,6 +16,7 @@ import torch from torch.fx import GraphModule as TorchGraphModule import torch.nn as nn +import torch.nn.utils.parametrize as parametrize from brevitas import torch_version from brevitas.fx import GraphModule @@ -36,7 +37,8 @@ from brevitas.nn.equalized_layer import functional_rotate_input from brevitas.nn.equalized_layer import INPUT_NAMES from brevitas.nn.equalized_layer import RotatedModule -from brevitas.nn.equalized_layer import UnfusedRotatedModule +from brevitas.nn.equalized_layer import RotationBiasParametrization +from brevitas.nn.equalized_layer import RotationWeightParametrization from brevitas.nn.quant_scale_bias import ScaleBias from brevitas.utils.torch_utils import KwargsForwardHook @@ -342,9 +344,6 @@ def _get_input_axis(module: nn.Module) -> Optional[int]: return 0 else: return None - # TODO: Remove with parametrizations - elif isinstance(module, (UnfusedRotatedModule)): - return _get_input_axis(module.module) elif isinstance(module, (RotatedModule,)): return _get_input_axis(module.layer) else: @@ -375,9 +374,6 @@ def _get_output_axis(module: nn.Module) -> Optional[int]: return 0 else: return None - # TODO: Remove with parametrizations - elif isinstance(module, (UnfusedRotatedModule)): - return _get_output_axis(module.module) elif isinstance(module, (RotatedModule,)): return _get_output_axis(module.layer) else: @@ -1324,8 +1320,7 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= if not insert_rotation_module and not region.is_valid: continue hidden_dim = region.max_shape_sinks - # TODO: Include again not insert_rotation_module - if full_rotation_method == 'ort': + if not insert_rotation_module and full_rotation_method == 'ort': rot_mat = random_orthogonal_matrix(hidden_dim) K = None rot_func = 
_apply_ort_device @@ -1392,116 +1387,6 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= return rewriters -def _fuse_rotations(model: nn.Module): - rewriters = [] - - def _fuse_rotations_aux(module: nn.Module): - if isinstance(module, UnfusedRotatedModule): - unrotated_module = module.unrotated_module - rot_weight = module.weight.data - - # Fuse rotations with weights - unrotated_module.weight.data = rot_weight - # Fuse rotations with bias if existent - if module.bias is not None: - rot_bias = module.bias.data - unrotated_module.bias.data = rot_bias - - # Use rotated module if orphan - if module.is_orphan: - rewriter = ModuleInstanceToModuleInstance( - module, RotatedModule(had_mat=module.rot_mat, k=None, layer=unrotated_module)) - else: - rewriter = ModuleInstanceToModuleInstance(module, unrotated_module) - # Save rewriter - rewriters.append(rewriter) - else: - for child_module in module.children(): - _fuse_rotations_aux(child_module) - - # Populate rewriters - _fuse_rotations_aux(model) - # Apply rewriter to fuse the weights - for r in rewriters: - model = r.apply(model) - - -@dataclass -class UnfusedRotation: - rot_mat: torch.Tensor - is_sink: bool - is_source: bool - is_orphan: bool - - -def _apply_unfused_rotate(model: nn.Module, regions: List[Region], full_rotation_method='ort'): - rewriters = [] - fused_rotated_modules = defaultdict(list) - rot_func = _apply_ort_device - - for region in regions: - insert_rotation_module = len(region.srcs) == 0 - - if not insert_rotation_module and not region.is_valid: - continue - hidden_dim = region.max_shape_sinks - - rot_mat = random_orthogonal_matrix(hidden_dim) - - for name, indexes in region.srcs.items(): - module = region.get_module_from_name(name) - - fused_rotated_modules[module].append( - UnfusedRotation( - rot_mat=rot_mat, - is_sink=False, - is_source=True, - is_orphan=False, - )) - - for name, indexes in region.sinks.items(): - module = region.get_module_from_name(name) - - if insert_rotation_module and len(region.srcs) == 0: - fused_rotated_modules[module].append( - UnfusedRotation( - rot_mat=rot_mat, - is_sink=False, - is_source=False, - is_orphan=True, - )) - else: - fused_rotated_modules[module].append( - UnfusedRotation( - rot_mat=rot_mat, - is_sink=True, - is_source=False, - is_orphan=False, - )) - - for module, rotation_modules in fused_rotated_modules.items(): - rotation_module = module - for rotation_module_dataclass in rotation_modules: - rotation_module = UnfusedRotatedModule( - module=rotation_module, - rot_func=rot_func, - rot_mat=rotation_module_dataclass.rot_mat, - _get_input_axis=_get_input_axis, - _get_output_axis=_get_output_axis, - is_source=rotation_module_dataclass.is_source, - is_sink=rotation_module_dataclass.is_sink, - is_orphan=rotation_module_dataclass.is_orphan, - ) - rewriter = ModuleInstanceToModuleInstance( - module, - rotation_module, - ) - rewriters.append(rewriter) - for r in rewriters: - model = r.apply(model) - return rewriters - - def _replace_bias(next_module, new_bias): new_bias = new_bias.view(-1) if next_module.bias is not None: @@ -1591,10 +1476,8 @@ def rotate_matmuls(self, graph_module): graph_module.recompile() graph_module.graph.lint() - def apply( - self, - graph_model: GraphModule, - fuse_rotations: bool = True) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: + def apply(self, + graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: rewriters = [] regions = _extract_regions( graph_model, @@ -1618,8 +1501,7 @@ def 
apply( if self.rotate_matmul: self.rotate_matmuls(graph_model) if len(regions) > 0: - _apply_rotate_fn = _apply_rotate if fuse_rotations else _apply_unfused_rotate - rewriters = _apply_rotate_fn(graph_model, regions, self.full_rotation_method) + rewriters = _apply_rotate(graph_model, regions, self.full_rotation_method) if self.return_rewriters: return graph_model, rewriters else: @@ -1716,67 +1598,16 @@ def apply(self, model: nn.Module, fuse_rotations: bool = True) -> nn.Module: regions: List[Region] = [] self.find_module(model, regions) if len(regions) > 0: - _apply_rotate_fn = _apply_rotate if fuse_rotations else _apply_unfused_rotate - _apply_rotate_fn(model, regions) + _apply_rotate(model, regions) return model -def find_missing_rotation_regions(graph_model: GraphModule, - head_dim: int, - state_impl_kwargs=None) -> List[Region]: - import re - - regions = [] - # Add R2 regions, this should be innermost - for src_name, src_module in graph_model.named_modules(): - if "attn_v_proj" in src_name: - if state_impl_kwargs is not None: - state = WalkRegionState(**state_impl_kwargs) - else: - state = WalkRegionState() - - block_number_matches_src = re.findall(r'\d+', src_name) - assert len(block_number_matches_src) == 2, "Could not identify block" - block_number_src = int(block_number_matches_src[1]) - - eq_indexes = EqualizationIndexes(0, head_dim, 0) - state.add_srcs(src_name, src_module, eq_indexes) - - # Now the corresponding sink - for sink_name, sink_module in graph_model.named_modules(): - if "attn_o_proj" in sink_name: - block_number_matches_sink = re.findall(r'\d+', sink_name) - assert len(block_number_matches_sink) == 2, "Could not identify block" - block_number_sink = int(block_number_matches_sink[1]) - # If the blocks match, we identified the region - if block_number_src == block_number_sink: - eq_indexes = EqualizationIndexes(0, head_dim, state.offset) - state.add_sinks(sink_name, sink_module, eq_indexes) - # Instantiate region and add to list - region = Region( - srcs=dict(sorted(state.srcs.items())), - sinks=dict(sorted(state.sinks.items())), - name_to_module=state.name_to_module, - ) - if region not in regions: - regions.append(region) - - return regions - - def _apply_rotate_fused_rotations( model: nn.Module, regions: List[Region], - full_rotation_method: str = 'had', + full_rotation_method='had', fuse_rotations: bool = True): rewriters = [] - # Dictionary to append the unfused rotated modules for optimization - unfused_rotated_modules = defaultdict(list) - # Dictionary to keep track of the modules that are assigned to a RotatedModule - fused_rotated_modules = {} - # List to keep track of the rotation matrices added to the - rotation_matrices = [] - for region in regions: insert_rotation_module = len(region.srcs) == 0 @@ -1784,20 +1615,18 @@ def _apply_rotate_fused_rotations( continue hidden_dim = region.max_shape_sinks if not insert_rotation_module and full_rotation_method == 'ort': - rot_mat = random_orthogonal_matrix( - hidden_dim) if fuse_rotations else torch.nn.Parameter( - random_orthogonal_matrix(hidden_dim)) + rot_mat = random_orthogonal_matrix(hidden_dim) + # If the rotations are not fused, redefine as parameter + if not fuse_rotations: + rot_mat = torch.nn.Parameter(rot_mat) K = None rot_func = _apply_ort_device - # Store rotation matrix for optimization - rotation_matrices.append(rot_mat) elif not insert_rotation_module and not fuse_rotations: - # TODO: Make it more general - rot_mat = torch.nn.Parameter(random_hadamard_matrix(hidden_dim, torch.device('cpu'))) + # 
TODO: Generalize + device = next(model.parameters()).device + rot_mat = torch.nn.Parameter(random_hadamard_matrix(hidden_dim, device)) K = None rot_func = _apply_ort_device - # Store rotation matrix for optimization - rotation_matrices.append(rot_mat) else: try: # Build hadamard rotation matrix @@ -1815,12 +1644,17 @@ def _apply_rotate_fused_rotations( for name, indexes in region.srcs.items(): module = region.get_module_from_name(name) + axis = _get_output_axis(module) + + assert not insert_rotation_module, "Orphan regions must not have sources." if not insert_rotation_module and fuse_rotations: + # Verify that there are no parametrizations, as otherwise the underlying data will not be updated + assert not hasattr(module, "parametrizations"), "Fused rotations need to be incorporated before the parametrized rotations." + if hasattr(module, 'allocate_params'): module.allocate_params(module) - axis = _get_output_axis(module) weight = module.weight.data if axis == 0: @@ -1838,31 +1672,48 @@ def _apply_rotate_fused_rotations( if hasattr(module, 'offload_params'): module.offload_params(module) - else: - unfused_rotated_modules[module].append( - UnfusedRotation( + elif not insert_rotation_module and not fuse_rotations: + # Parametrize weights and possibly bias with unfused rotations + parametrize.register_parametrization( + module, + "weight", + RotationWeightParametrization( rot_mat=rot_mat, - is_sink=False, + rot_func=rot_func, + output_axis=axis, is_source=True, - is_orphan=False, )) + if getattr(module, 'bias', None) is not None: + parametrize.register_parametrization( + module, + "bias", + RotationBiasParametrization( + rot_mat=rot_mat, + rot_func=rot_func, + output_axis=axis, + is_source=True, + )) for name, indexes in region.sinks.items(): module = region.get_module_from_name(name) + axis = _get_input_axis(module) if not insert_rotation_module and not fuse_rotations: - unfused_rotated_modules[module].append( - UnfusedRotation( + parametrize.register_parametrization( + module, + "weight", + RotationWeightParametrization( rot_mat=rot_mat, + rot_func=rot_func, + input_axis=axis, is_sink=True, - is_source=False, - is_orphan=False, )) else: + # Verify that there are no parametrizations, as otherwise the underlying data will not be updated + assert not hasattr(module, "parametrizations"), "Fused rotations need to be incorporated before the parametrized rotations." 
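
The branches above implement the two life-cycle stages of a rotation: either it is fused immediately into `weight.data`, or it is registered through `torch.nn.utils.parametrize` so that the rotated weight is recomputed from the original tensor on every access, can be trained, and can later be fused with `leave_parametrized=True`. A minimal sketch of that mechanism on a plain Linear, with a simplified stand-in for `RotationWeightParametrization`:

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class SinkRotation(nn.Module):
    """Simplified stand-in for RotationWeightParametrization (sink, input_axis == 1)."""

    def __init__(self, rot_mat: nn.Parameter) -> None:
        super().__init__()
        self.rot_mat = rot_mat

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight @ self.rot_mat


layer = nn.Linear(4, 4, bias=False)
rot_mat = nn.Parameter(torch.linalg.qr(torch.randn(4, 4))[0])  # orthogonal Q factor

parametrize.register_parametrization(layer, "weight", SinkRotation(rot_mat))
# layer.weight is now computed on the fly; the unrotated tensor lives under
# layer.parametrizations.weight.original, which is what _get_tensor_weight_id relies on
assert hasattr(layer, "parametrizations")

# Fusing: bake the rotated weight into the module and drop the parametrization
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)
assert not hasattr(layer, "parametrizations")
```

This is also why the asserts above require fusing to happen before any parametrization is registered: once `weight` is parametrized, writing to `module.weight.data` would no longer update the tensor the forward pass actually sees.
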
+ if hasattr(module, 'allocate_params'): module.allocate_params(module) - - axis = _get_input_axis(module) weight = module.weight.data if axis == 1: @@ -1875,42 +1726,16 @@ def _apply_rotate_fused_rotations( if hasattr(module, 'offload_params'): module.offload_params(module) - if insert_rotation_module: - if module not in fused_rotated_modules: - fused_rotated_modules[module] = RotatedModule( - had_mat=rot_mat, k=K, layer=module) - else: - raise RuntimeError( - "Only one RotatedModule at most can be assigned to a module.") - # For this to work, we need to have the following hierarchy UnfusedRotatedModule -> (RotatedModule) -> Linear - for module, rotation_modules in unfused_rotated_modules.items(): - # Verify that at most one RotatedModule is available - rotation_module = module if module not in fused_rotated_modules else fused_rotated_modules[ - module] - for rotation_module_dataclass in rotation_modules: - rotation_module = UnfusedRotatedModule( - module=rotation_module, - rot_func=rot_func, - rot_mat=rotation_module_dataclass.rot_mat, - _get_input_axis=_get_input_axis, - _get_output_axis=_get_output_axis, - is_source=rotation_module_dataclass.is_source, - is_sink=rotation_module_dataclass.is_sink, - is_orphan=rotation_module_dataclass.is_orphan, - ) - # Instantiate rewriters - rewriter = ModuleInstanceToModuleInstance(module, rotation_module) - rewriters.append(rewriter) - # Add missing RotatedModules, in case there are any - for module, rotation_module in fused_rotated_modules.items(): - if module not in unfused_rotated_modules: - rewriter = ModuleInstanceToModuleInstance(module, rotation_module) - rewriters.append(rewriter) + if insert_rotation_module and len(region.srcs) == 0: + rewriter = ModuleInstanceToModuleInstance( + module, RotatedModule(had_mat=rot_mat, k=K, layer=module)) + rewriters.append(rewriter) for r in rewriters: model = r.apply(model) - return rewriters, rotation_matrices + return rewriters +# TODO: Consolidate with GraphRotationEqualization class GraphRotationEqualizationOptimization(GraphRotationEqualization): def __init__( @@ -1955,10 +1780,13 @@ def apply( # Layerwise have only a single sink named 'sinks0' id_sink = id(o_r.get_module_from_name('sinks0')) if id_sink not in eq_layers: - # TODO: Change data structure to insert in the beginning with O(1) regions = [o_r] + regions if self.rotate_matmul: self.rotate_matmuls(graph_model) if len(regions) > 0: - rewriters, rotation_matrices = _apply_rotate_fused_rotations(graph_model, regions, self.full_rotation_method, fuse_rotations) - return graph_model, rewriters, rotation_matrices + rewriters = _apply_rotate_fused_rotations( + graph_model, regions, self.full_rotation_method, fuse_rotations) + if self.return_rewriters: + return graph_model, rewriters + else: + return graph_model diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index fbe045124..ccd812713 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -175,125 +175,6 @@ def forward(self, bias: torch.Tensor) -> torch.Tensor: return bias -class UnfusedRotatedModule(torch.nn.Module): - - def __init__( - self, - module: torch.nn.Module, - rot_func: Callable, - rot_mat: torch.nn.Parameter, - _get_input_axis: Callable, - _get_output_axis: Callable, - is_source: bool = False, - is_sink: bool = False, - is_orphan: bool = False, - ) -> None: - super().__init__() - self.module = module - self.rot_func = rot_func - self.rot_mat = rot_mat - self.K = None - - # TODO: This were included to prevent 
circular imports. - self._get_input_axis = _get_input_axis - self._get_output_axis = _get_output_axis - - self.is_source = is_source - self.is_sink = is_sink - self.is_orphan = is_orphan - - # TODO: Does it make sense the extra complexity just to prevent the view operation? - # Probably if no reshaping needs to be done, no change is required - def _wrap_rot(self) -> bool: - weight_shape = self.module.weight.shape - rot_dim = self.rot_mat.shape[0] - if self.is_sink or self.is_orphan: - weight_shape_dim = weight_shape[self._get_input_axis(self.module)] - elif self.is_source: - weight_shape_dim = weight_shape[self._get_output_axis(self.module)] - else: - weight_shape_dim = None - - if weight_shape_dim is not None: - if rot_dim != weight_shape_dim: - assert weight_shape_dim % rot_dim == 0, "Sizes need to be divisibile" - return True - # No need to incorporate additional view operations - return False - - # These properties enable propagating the fusing to the module weights - @property - def weight(self) -> Optional[torch.Tensor]: - weight = getattr(self.module, 'weight', None) - # Add rotation and let these being propagated till the parent - # unfused rotated module - if self.is_sink or self.is_orphan: - axis = self._get_input_axis(self.module) - if axis == 1: - weight = self.rot_func(weight, self.rot_mat, self.K) - elif axis == 0: - weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() - else: - raise RuntimeError("Not supported yet") - - if self.is_source: - axis = self._get_output_axis(self.module) - if axis == 0: - weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() - elif axis == 1: - weight = self.rot_func(weight, self.rot_mat, self.K) - else: - raise RuntimeError("Not supported yet") - - return weight - - @property - def bias(self) -> Optional[torch.Tensor]: - bias = getattr(self.module, 'bias', None) - # Propagate bias adding the rotations incrementally - if self.is_source: - if bias is not None: - bias = self.rot_func(bias, self.rot_mat, self.K) - - return bias - - @property - def unrotated_module(self) -> torch.nn.Module: - return self.module.unrotated_module if isinstance( - self.module, UnfusedRotatedModule) else self.module - - def forward(self, inp, **kwargs): - # Rotated matrices - weight = self.weight.data - bias = self.bias.data if self.bias is not None else None - - # Propagate calls till getting to the original module being rotated - child_module = self.module - # Iterate until the original module is reached, keeping the rotations that need to be performed on the input - while isinstance(child_module, UnfusedRotatedModule): - child_module = child_module.module - # child_module contains the original module in the network. Before applying its forward method, we need to - # rotate the inpute appropiately - if self.is_orphan: - # Rotate the input for an orphan sink - inp = self.rot_func(inp, self.rot_mat, self.K) - # Modify the weights, and run the original model forward. After that, restore the previous values. 
- if weight is not None: - orig_weight = child_module.weight.data - child_module.weight.data = weight - if bias is not None: - orig_bias = child_module.bias.data - child_module.bias.data = bias - # Call forward of the original module - o = child_module(inp) - # Restore un-rotated weights - child_module.weight.data = orig_weight - if bias is not None: - child_module.bias.data = orig_bias - # Return rotated output - return o - - def functional_rotate_input(inp, transpose=False): is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None if transpose: diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py index 7b2baecae..93b7edf7e 100644 --- a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py +++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py @@ -9,14 +9,12 @@ import torch from torch.utils.data import Dataset -from tqdm import tqdm import transformers -from transformers import default_data_collator from transformers import Trainer from transformers.tokenization_utils import PreTrainedTokenizerBase -from brevitas.nn.equalized_layer import UnfusedRotatedModule from brevitas.optim.sgdg import SGDG +from brevitas_examples.llm.llm_quant.rotation_utils import extract_trainable_rotation_matrices @dataclass @@ -86,18 +84,11 @@ def apply_rotation_optimization( for param in graph_model.parameters(): param.requires_grad = False # Collect trainable matrices - trainable_parameters = [] - ids_rot = set() - for module in graph_model.modules(): - if isinstance(module, UnfusedRotatedModule): - if id(module.rot_mat) not in ids_rot: - ids_rot.add(id(module.rot_mat)) - trainable_parameters.append(module.rot_mat) - # Collect parameters for the rotation matrices - for rot_mat in trainable_parameters: + trainable_rotations = extract_trainable_rotation_matrices(graph_model) + for rot_mat in trainable_rotations: rot_mat.requires_grad = True # Initialize optimizer - optimizer = SGDG(trainable_parameters, lr=training_args.learning_rate, stiefel=True) + optimizer = SGDG(trainable_rotations, lr=training_args.learning_rate, stiefel=True) trainer = Trainer( model=graph_model, tokenizer=tokenizer, diff --git a/src/brevitas_examples/llm/llm_quant/rotation_utils.py b/src/brevitas_examples/llm/llm_quant/rotation_utils.py new file mode 100644 index 000000000..037de166d --- /dev/null +++ b/src/brevitas_examples/llm/llm_quant/rotation_utils.py @@ -0,0 +1,92 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import re +from typing import List + +from torch import nn +from torch.fx import GraphModule +import torch.nn.utils.parametrize as parametrize + +from brevitas.graph.base import ModuleInstanceToModuleInstance +from brevitas.graph.base import Transform +from brevitas.graph.equalize import EqualizationIndexes +from brevitas.graph.equalize import Region +from brevitas.graph.equalize import WalkRegionState +from brevitas.nn.equalized_layer import RotationWeightParametrization + + +def find_self_attention_rotation_regions( + graph_model: GraphModule, head_dim: int, state_impl_kwargs=None) -> List[Region]: + regions = [] + # See R2 rotation matrices in https://arxiv.org/pdf/2405.16406. 
+    for src_name, src_module in graph_model.named_modules():
+        if "attn_v_proj" in src_name:
+            if state_impl_kwargs is not None:
+                state = WalkRegionState(**state_impl_kwargs)
+            else:
+                state = WalkRegionState()
+
+            block_number_matches_src = re.findall(r'\d+', src_name)
+            assert len(block_number_matches_src) == 2, "Could not identify block"
+            block_number_src = int(block_number_matches_src[1])
+
+            eq_indexes = EqualizationIndexes(0, head_dim, 0)
+            state.add_srcs(src_name, src_module, eq_indexes)
+
+            # Now the corresponding sink
+            for sink_name, sink_module in graph_model.named_modules():
+                if "attn_o_proj" in sink_name:
+                    block_number_matches_sink = re.findall(r'\d+', sink_name)
+                    assert len(block_number_matches_sink) == 2, "Could not identify block"
+                    block_number_sink = int(block_number_matches_sink[1])
+                    # If the blocks match, the region was identified
+                    if block_number_src == block_number_sink:
+                        eq_indexes = EqualizationIndexes(0, head_dim, state.offset)
+                        state.add_sinks(sink_name, sink_module, eq_indexes)
+            region = Region(
+                srcs=dict(sorted(state.srcs.items())),
+                sinks=dict(sorted(state.sinks.items())),
+                name_to_module=state.name_to_module,
+            )
+            if region not in regions:
+                regions.append(region)
+
+    return regions
+
+
+def fuse_rotations(model: nn.Module) -> None:
+    for module in model.modules():
+        # Check if the module has any parametrizations
+        if hasattr(module, "parametrizations"):
+            # Remove weight parametrizations
+            parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)
+            # We need to check again, in case the weight parametrizations were the only ones
+            if hasattr(module, "parametrizations") and hasattr(module.parametrizations, "bias"):
+                parametrize.remove_parametrizations(module, "bias", leave_parametrized=True)
+
+
+def extract_rewriters_unfused_rotations(model: nn.Module,
+                                        rewriters: List[Transform]) -> List[Transform]:
+    extra_rewriters = []
+    for module in model.modules():
+        if hasattr(module, "parametrizations"):
+            # Verify that the current module does not already have an associated RotatedModule
+            if len([r for r in rewriters if r.old_module_instance is module]) == 0:
+                # Identity rewriter, only useful externally
+                rewriter = ModuleInstanceToModuleInstance(module, module)
+                extra_rewriters.append(rewriter)
+    return extra_rewriters
+
+
+def extract_trainable_rotation_matrices(model: nn.Module) -> List[nn.Parameter]:
+    trainable_rotations = []
+    # We need to keep track of the IDs of the rotation matrices, as several modules
+    # can share the same parametrized rotation.
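
The deduplication in the body below relies on id() identity: when one nn.Parameter is shared by several parametrizations, every reference reports the same id. A minimal illustration with toy modules (not the Brevitas classes):

    import torch
    import torch.nn as nn

    shared = nn.Parameter(torch.eye(4))
    holders = [nn.Module(), nn.Module()]
    for h in holders:
        h.rot_mat = shared  # both modules reference the same Parameter

    seen, unique = set(), []
    for h in holders:
        if id(h.rot_mat) not in seen:
            seen.add(id(h.rot_mat))
            unique.append(h.rot_mat)
    assert len(unique) == 1  # the shared rotation is collected only once
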
+ ids_rot = set() + for module in model.modules(): + if isinstance(module, RotationWeightParametrization): + if id(module.rot_mat) not in ids_rot: + ids_rot.add(id(module.rot_mat)) + trainable_rotations.append(module.rot_mat) + return trainable_rotations diff --git a/src/brevitas_examples/llm/llm_quant/run_utils.py b/src/brevitas_examples/llm/llm_quant/run_utils.py index 44ba711a5..4acbd87ed 100644 --- a/src/brevitas_examples/llm/llm_quant/run_utils.py +++ b/src/brevitas_examples/llm/llm_quant/run_utils.py @@ -110,15 +110,24 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): return out +def _get_tensor_weight_id(module, tensor_name): + if hasattr(module, "parametrizations") and tensor_name in module.parametrizations: + return id(module.parametrizations[tensor_name].original) + elif hasattr(module, tensor_name): + return id(getattr(module, tensor_name)) + return None + + # This functions remap rewriters so match modules in a potentially different model that shares the same underlying tensors # We rely on the fact that two versions of the same model (eager vs FX) might have different modules id (id(fx_module) != id (eager_module)) # However, the underlying tensors are still shared, so we can recostruct the mapping between the two # modules. def fix_rewriter(rewriters, old_model_ref, tensor_name): + # We need to account for reparametrizations, to make sure the underlying tensors are accessed for r in rewriters: - tensor_id = id(r.old_module_instance.weight) + tensor_id = _get_tensor_weight_id(r.old_module_instance, tensor_name) module = [ - m for m in old_model_ref.modules() - if hasattr(m, tensor_name) and id(m.weight) == tensor_id] + m for m in old_model_ref.modules() if _get_tensor_weight_id(m, tensor_name) == tensor_id + ] r.old_module_instance = module[0] return rewriters diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 2c79101ef..4244055f6 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -20,7 +20,6 @@ from brevitas.export import export_torch_qcdq from brevitas.export.inference.manager import quant_inference_mode from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager -from brevitas.graph.equalize import find_missing_rotation_regions from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import GraphRotationEqualizationOptimization from brevitas.graph.equalize import LayerwiseActivationRotation @@ -50,6 +49,8 @@ from brevitas_examples.llm.llm_quant.prepare_for_quantize import add_zero_bias_to_linear from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers from brevitas_examples.llm.llm_quant.rotation_optimization import apply_rotation_optimization +from brevitas_examples.llm.llm_quant.rotation_utils import extract_rewriters_unfused_rotations +from brevitas_examples.llm.llm_quant.rotation_utils import find_self_attention_rotation_regions from brevitas_examples.llm.llm_quant.prepare_for_quantize import \ replace_sdpa_with_quantizable_layers from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 @@ -116,14 +117,21 @@ def fused_optimized_rotation_no_fx( for r in rewriters: r.apply(model) #new_model = offload_model(new_model) - additional_regions = find_missing_rotation_regions( + + # Regions with source o_proj and sink down_proj + self_attention_regions = find_self_attention_rotation_regions( new_model, model.config.hidden_size // model.config.num_attention_heads) if 
add_additional_regions else None eq = GraphRotationEqualizationOptimization( orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, ) - new_model, rewriters, rotation_matrices = eq.apply(new_model, fuse_rotations=fuse_rotations, additional_regions=additional_regions) + new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations, additional_regions=self_attention_regions) + + # Retrieve additional rewriters for unfused rotations + rewriters_unfused_rotations = extract_rewriters_unfused_rotations(new_model, rewriters) + rewriters.extend(rewriters_unfused_rotations) + rewriters = fix_rewriter(rewriters, model, 'weight') for r in rewriters: diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index 0f65db168..1e913845c 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -2,21 +2,19 @@ # SPDX-License-Identifier: BSD-3-Clause import copy -from functools import partial import itertools from typing import List, Tuple from unittest.mock import patch import pytest import torch +import torch.nn.utils.parametrize as parametrize from torchvision import models from brevitas.fx import symbolic_trace -# TODO: Refactor to prevent circular import from brevitas.graph.equalize import _apply_ort_device from brevitas.graph.equalize import _batch_norm from brevitas.graph.equalize import _extract_regions -from brevitas.graph.equalize import _fuse_rotations from brevitas.graph.equalize import _get_input_axis from brevitas.graph.equalize import _get_output_axis from brevitas.graph.equalize import _is_supported_module @@ -25,12 +23,11 @@ from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import MergeLnAffine from brevitas.graph.equalize import random_orthogonal_matrix -from brevitas.graph.hadamard import matmul_hadU from brevitas.graph.standardize import DuplicateSharedStatelessModule from brevitas.graph.standardize import TorchFunctionalToModule from brevitas.graph.utils import get_module -from brevitas.nn.equalized_layer import RotatedModule -from brevitas.nn.equalized_layer import UnfusedRotatedModule +from brevitas.nn.equalized_layer import RotationBiasParametrization +from brevitas.nn.equalized_layer import RotationWeightParametrization from tests.marker import requires_pt_ge from .equalization_fixtures import * @@ -293,13 +290,13 @@ def test_models(rotation_fixtures, partial_had): def _rotate_input_output(is_source: bool, is_sink: bool, is_orphan: bool) -> Tuple[bool, bool]: - # Verify that only one flag is enabled at the same time + # Verify that only one flag is enabled simultaneously assert sum([is_source, is_sink, is_orphan]) <= 1, "Only one flag can be enabled." rotate_input, rotate_output = False, False if is_source: rotate_output = True - if is_sink: + if is_sink or is_orphan: rotate_input = True return rotate_input, rotate_output @@ -316,32 +313,29 @@ def _compute_rotated_ouptut_from_matrices( return out -# NOTE: The assumption is that only one flag can be true simultaneously -# NOTE: Orphans need to be taken care of. A module can only be orphan once. +# RotationParametrizations can only have one type flag enabled simultaneously (is_source, is_sink, is_orphan). +# Moreover, orphan rotations need to be the outermost rotation, as this cancels out when rotating the input. 
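
For context on the tests below: register_parametrization calls on the same tensor compose, with each newly registered parametrization applied after the existing ones, which is what lets a chain of rotations be emulated; remove_parametrizations(..., leave_parametrized=True) then bakes the composed result into the stored tensor. A minimal sketch with a toy parametrization (not the Brevitas classes):

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    class MatmulRight(nn.Module):
        # Toy parametrization: right-multiply the weight by a fixed orthogonal matrix
        def __init__(self, rot_mat: torch.Tensor):
            super().__init__()
            self.rot_mat = rot_mat

        def forward(self, weight: torch.Tensor) -> torch.Tensor:
            return weight @ self.rot_mat

    layer = nn.Linear(4, 4, bias=False)
    w_orig = layer.weight.detach().clone()
    r1, _ = torch.linalg.qr(torch.randn(4, 4))
    r2, _ = torch.linalg.qr(torch.randn(4, 4))

    # Parametrizations compose in registration order: weight -> (weight @ r1) @ r2
    parametrize.register_parametrization(layer, "weight", MatmulRight(r1))
    parametrize.register_parametrization(layer, "weight", MatmulRight(r2))
    assert torch.allclose(layer.weight, w_orig @ r1 @ r2, atol=1e-6)

    # Fusing bakes the composition into the stored tensor
    parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)
    assert torch.allclose(layer.weight, w_orig @ r1 @ r2, atol=1e-6)
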
def _generate_rotation_flags(N: int) -> List[bool]: return [ rotation_flags for rotation_flags in itertools.product([False, True], repeat=3 * N) if ( - all([sum(rotation_flags[i * 3:(i + 1) * 3]) <= 1 for i in range(N)]) and - # Only outermost rotation can be orphan - all([not rotation_flags[i * 3 + 2] for i in range(N - 1)]))] + all([sum(rotation_flags[i * 3:(i + 1) * 3]) <= 1 + for i in range(N)]) and all([not rotation_flags[i * 3 + 2] for i in range(N - 1)])) + ] @requires_pt_ge('2.4') @pytest_cases.parametrize('N', [1, 2, 3], ids=lambda x: f"N={x}") -def test_composition_unfused_rotation_layer(N): +def test_composition_unfused_rotations(N): torch.manual_seed(SEED) for rotation_flags in _generate_rotation_flags(N): in_features = IN_FEATURES_LINEAR module = nn.Linear(in_features=in_features, out_features=in_features) + rot_module = copy.deepcopy(module) # Sample input to pass through the block sample_input = torch.rand((1, in_features),) - - # Compose rotation modules - rotated_module = module - # Composite rotation matrices rot_mat_input = torch.eye(in_features) rot_mat_output = torch.eye(in_features) @@ -361,427 +355,34 @@ def test_composition_unfused_rotation_layer(N): rot_mat_output = rot_mat_output @ rot_mat # Compose rotation modules - rotated_module = UnfusedRotatedModule( - module=rotated_module, - rot_func=_apply_ort_device, - _get_input_axis=_get_input_axis, - _get_output_axis=_get_output_axis, - rot_mat=rot_mat, - is_source=is_source, - is_sink=is_sink, - is_orphan=is_orphan, - ) - - # Compute outputs to compare + parametrize.register_parametrization( + rot_module, + "weight", + RotationWeightParametrization( + rot_mat=rot_mat, + rot_func=_apply_ort_device, + input_axis=_get_input_axis(rot_module), + output_axis=_get_output_axis(rot_module), + is_source=is_source, + is_sink=is_sink, + is_orphan=is_orphan, + )) + parametrize.register_parametrization( + rot_module, + "bias", + RotationBiasParametrization( + rot_mat=rot_mat, + rot_func=_apply_ort_device, + input_axis=_get_input_axis(rot_module), + output_axis=_get_output_axis(rot_module), + is_source=is_source, + is_sink=is_sink, + is_orphan=is_orphan, + )) + gt_output = _compute_rotated_ouptut_from_matrices( module, sample_input, rot_mat_input, rot_mat_output) - rot_output = rotated_module(sample_input) + rot_output = rot_module(sample_input) # Verify that the rotation operations were computed correctly assert torch.allclose(gt_output, rot_output, atol=ATOL) - - -# Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/eval_utils/rotation_utils.py#L26 -def _random_orthogonal_matrix(size, generator): - """ - Generate a random orthogonal matrix of the specified size. - First, we generate a random matrix with entries from a standard distribution. - Then, we use QR decomposition to obtain an orthogonal matrix. - Finally, we multiply by a diagonal matrix with diag r to adjust the signs. - - Args: - size (int): The size of the matrix (size x size). - - Returns: - torch.Tensor: An orthogonal matrix of the specified size. 
- """ - torch.cuda.empty_cache() - random_matrix = torch.randn(size, size, dtype=torch.float64, generator=generator) - q, r = torch.linalg.qr(random_matrix) - q *= torch.sign(torch.diag(r)).unsqueeze(0).float() - return q - - -def _random_hadamard_matrix(size, device, generator): - # See https://github.com/Cornell-RelaxML/quip-sharp , Section "Randomized Hadamard Transformation" - Q = torch.randint(low=0, high=2, size=(size,), generator=generator).to(torch.float64) - Q = Q * 2 - 1 - Q = torch.diag(Q) - return matmul_hadU(Q).to(device) - - -def _compare_module_weights_fused_unfused(gt_module, rot_module, fused_rotations=False): - gt_weight = gt_module.weight if isinstance(gt_module, nn.Linear) else gt_module.layer.weight - gt_bias = gt_module.bias if isinstance(gt_module, nn.Linear) else gt_module.layer.bias - if fused_rotations: - rot_weight = rot_module.weight if isinstance( - rot_module, nn.Linear) else rot_module.layer.weight - rot_bias = rot_module.bias if isinstance(rot_module, nn.Linear) else rot_module.layer.bias - else: - rot_weight = rot_module.weight - rot_bias = rot_module.bias - assert torch.allclose(gt_weight, rot_weight, rtol=0.0, atol=0.0) - if gt_bias is not None: - assert torch.allclose(gt_bias, rot_bias, rtol=0.0, atol=0.0) - # For a RotatedModule, corresponding to an orphan node, additional checks need to be done - if isinstance(gt_module, RotatedModule): - if not fused_rotations: - # The outermost should be an orphan - child_rot_module = rot_module - assert child_rot_module.is_orphan, "Unfused rotated module needs to be an orphan." - # Check that the inner UnfusedRotatedModules are not orphans - while isinstance(child_rot_module.module, UnfusedRotatedModule): - assert not child_rot_module.module.is_orphan, "Inner unfused rotated modules cannot be orphans." 
- child_rot_module = child_rot_module.module - # Verify that the rotation matrices match - assert torch.allclose(gt_module.had_mat, rot_module.rot_mat) - - -# This test verifies that the weights returned by the unfused rotate modules -# match those when fusing -@requires_pt_ge('2.4') -@pytest_cases.parametrize('partial_had', [False, True]) -@pytest_cases.parametrize('fused_rotations', [False, True]) -def test_models_rotations(rotation_fixtures, partial_had, fused_rotations): - - in_shape = IN_SIZE_LINEAR - - model_class = rotation_fixtures - model = model_class() - - model.eval() - inp = torch.rand(in_shape) - - model = symbolic_trace(model) - merge = MergeLnAffine() - model = merge.apply(model) - eq = GraphRotationEqualization(orphan_sink=partial_had, full_rotation_method='ort') - - # Save a copy to apply graph rotation equalization on - model_copy = copy.deepcopy(model) - - # We need to make sure that the same random matrices are being generated - generator = torch.Generator() - generator.manual_seed(SEED) - # Clone generator to make sure we can use the same rotation matrices - generator_clone = generator.clone_state() - - # We pass the generator to make sure that we can reproduce the random orthogonal matrices that are generated - with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator)): - # Apply rotation equalization while controlling the random matrices that are generated - model = eq.apply(model) - - with torch.no_grad(): - expected_out = model(inp) - - # Now rotate but without fusing the rotation matrices - with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator_clone)): - # Apply rotation equalization while controlling the random matrices that are generated - model_copy = eq.apply(model_copy, fuse_rotations=False) - - # Fuse the rotations and make sure the behaviour is the same - if fused_rotations: - _fuse_rotations(model_copy) - - with torch.no_grad(): - out = model_copy(inp) - - # Verify that the output of the model does not change after incorporating the rotations - assert torch.allclose(expected_out, out, rtol=0.0, atol=0.0) - - # Verify that weight matrices - for model_node, model_copy_node in zip(model.graph.nodes, model_copy.graph.nodes): - if model_node.op == 'call_module': - module = get_module(model, model_node.target) - module_copy = get_module(model_copy, model_copy_node.target) - if isinstance(module, (nn.Linear, RotatedModule)): - _compare_module_weights_fused_unfused(module, module_copy, fused_rotations) - - -def _compare_module_weights(module, module_copy): - weight = module.weight if isinstance(module, nn.Linear) else module.layer.weight - bias = module.bias if isinstance(module, nn.Linear) else module.layer.bias - weight_copy = module_copy.weight - bias_copy = module_copy.bias - assert torch.allclose(weight, weight_copy, rtol=0.0, atol=0.0) - if bias is not None: - assert torch.allclose(bias, bias_copy, rtol=0.0, atol=0.0) - - -import logging - -from transformers import AutoModelForCausalLM -from transformers import AutoTokenizer - -from brevitas.graph.equalize import find_missing_rotation_regions -from brevitas_examples.common.accelerate_utils.accelerate import offload_model -from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks -from brevitas_examples.llm.llm_quant.data_utils import get_dataset_for_model -from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge -from 
brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_to_rmsnorm -from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch -from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter -from brevitas_examples.llm.main import fused_optimized_rotation_no_fx -from brevitas_examples.llm.main import fused_rotation_no_fx -from tests.brevitas_examples.test_llm import default_run_args - - -@pytest_cases.fixture( - ids=[ - "llama",], - params=[ - { - "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", - "input_bit_width": None, - "fuse_sequences": False, - "act_calibration": False,},]) -def equalize_args(default_run_args, request): - args = default_run_args - export_dict = request.param - args.update(**export_dict) - yield args - - -@pytest.mark.llm -@requires_pt_ge('2.4') -@pytest_cases.parametrize('partial_had', [False, True]) -@pytest_cases.parametrize('fused_rotations', [False, True]) -def test_small_models_equalize_legacy_rotation_orthogonal( - caplog, partial_had, fused_rotations, equalize_args): - import os - os.environ["HF_HUB_CACHE"] = "/scratch/hf_models/" - caplog.set_level(logging.INFO) - args = equalize_args - args.rotation_orphan_sink = partial_had - args.rotation_mode = 'ort' - - kwargs = {"torch_dtype": torch.float16} - model = AutoModelForCausalLM.from_pretrained(args.model, **kwargs) - model = replace_rmsnorm_with_torch(model, model.config) - model.config.use_cache = False - print("Model loaded.") - model.eval() - tokenizer = AutoTokenizer.from_pretrained(args.model) - # Load the data for calibration and evaluation. - calibration_loader = get_dataset_for_model( - args.model, - dataset_name=args.dataset, - tokenizer=tokenizer, - nsamples=args.nsamples, - seqlen=args.seqlen, - split="train", - seed=args.seed, - require_fx=False, - device=None, - fuse_sequences=args.fuse_sequences) - - # We need to make sure that the same random matrices are being generated - generator = torch.Generator() - generator.manual_seed(SEED) - # Clone generator to make sure we can use the same rotation matrices - generator_clone = generator.clone_state() - - # Save a copy to apply graph rotation equalization on - model_copy = copy.deepcopy(model) - - # We pass the generator to make sure that we can reproduce the random orthogonal matrices that are generated - with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator)): - with patch('brevitas.graph.hadamard.random_hadamard_matrix', - partial(_random_hadamard_matrix, generator=generator)): - fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations=True) - - # Run model and save outputs - with torch.no_grad(): - expected_logits = model(**calibration_loader[0]).logits - - # We pass the generator to make sure that we can reproduce the random orthogonal matrices that are generated - with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator_clone)): - with patch('brevitas.graph.hadamard.random_hadamard_matrix', - partial(_random_hadamard_matrix, generator=generator_clone)): - fused_rotation_no_fx(model_copy, calibration_loader, args, fuse_rotations=False) - - if fused_rotations: - _fuse_rotations(model_copy) - - # Run model and save outputs - with torch.no_grad(): - logits = model_copy(**calibration_loader[0]).logits - - # Verify that the output is the same - assert torch.allclose(expected_logits, logits) - - # Verify that the weights after fusing match - for 
name_fused_module, fused_module in model.named_modules(): - # For linear modules verify that the weights match - if isinstance(fused_module, (nn.Linear, RotatedModule)): - for name_unfused_Module, unfused_module in model_copy.named_modules(): - if name_fused_module == name_unfused_Module: - _compare_module_weights(fused_module, unfused_module) - # For a RotatedModule, corresponding to an orphan node, additional checks need to be done - if isinstance(fused_module, RotatedModule): - # Verify that the outer module is an orphan - if fused_rotations: - assert isinstance(unfused_module, RotatedModule) - assert torch.allclose(fused_module.had_mat, unfused_module.had_mat) - else: - assert unfused_module.is_orphan - # Verify that the rotation matrices match - assert torch.allclose(fused_module.had_mat, unfused_module.rot_mat) - - -from itertools import product - -from brevitas.graph.equalize import _apply_had_device -from brevitas.graph.hadamard import get_hadK - - -# NOTE: This test works because in R2 we patch the rotation method, so the appropiate matrix is not effectively used. This is because when the fast_hadamard_transform is not avai -@pytest.mark.llm -@requires_pt_ge('2.4') -@pytest_cases.parametrize( - 'partial_had, fused_rotations, add_additional_regions', - list(product([False, True], repeat=3)), - ids=[("fused-R1" if fused_rotations else "R1") + ("-R2" if add_additional_regions else "") + - ("-R3" if partial_had else "") for partial_had, - fused_rotations, - add_additional_regions in list(product([False, True], repeat=3))], -) -@pytest_cases.parametrize('rotation_mode', ['ort', 'had']) -def test_small_models_equalize_mixed_fused_unfused( - caplog, partial_had, fused_rotations, add_additional_regions, rotation_mode, equalize_args): - import os - os.environ["HF_HUB_CACHE"] = "/scratch/hf_models/" - caplog.set_level(logging.INFO) - args = equalize_args - args.rotation_orphan_sink = partial_had - args.rotation_mode = rotation_mode - - kwargs = {"torch_dtype": torch.float16} - model = AutoModelForCausalLM.from_pretrained(args.model, **kwargs) - model = replace_rmsnorm_with_torch(model, model.config) - model.config.use_cache = False - print("Model loaded.") - model.eval() - tokenizer = AutoTokenizer.from_pretrained(args.model) - # Load the data for calibration and evaluation. - calibration_loader = get_dataset_for_model( - args.model, - dataset_name=args.dataset, - tokenizer=tokenizer, - nsamples=args.nsamples, - seqlen=args.seqlen, - split="train", - seed=args.seed, - require_fx=False, - device=None, - fuse_sequences=args.fuse_sequences) - - # We need to make sure that the same random matrices are being generated - generator = torch.Generator() - generator.manual_seed(SEED) - # Clone generator to make sure we can use the same rotation matrices - generator_clone = generator.clone_state() - - # Run model and save outputs - with torch.no_grad(): - original_logits = model(**calibration_loader[0]).logits - - # Save a copy to apply graph rotation equalization on - model_copy = copy.deepcopy(model) - - with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator)): - fused_optimized_rotation_no_fx( - model, - calibration_loader, - args, - fuse_rotations=True, - add_additional_regions=add_additional_regions) - - # Run model and save outputs - with torch.no_grad(): - expected_logits = model(**calibration_loader[0]).logits - - # Instead of random orthogonal matrices, we want to use the same ones as when the activations are not fused. 
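
The patching above is sound because two generators with identical internal state produce identical draws, so the fused and unfused runs see exactly the same random rotation matrices. A minimal check of the mechanism (clone_state is the same torch.Generator method used in the tests):

    import torch

    generator = torch.Generator()
    generator.manual_seed(1234)
    # Snapshot the state so a second run can replay the exact same draws
    generator_clone = generator.clone_state()

    a = torch.randn(3, 3, generator=generator)
    b = torch.randn(3, 3, generator=generator_clone)
    assert torch.equal(a, b)
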
- if rotation_mode == 'had': - with patch('brevitas.graph.equalize._apply_ort_device', _apply_had_device): - fused_optimized_rotation_no_fx( - model_copy, - calibration_loader, - args, - fuse_rotations=False, - add_additional_regions=add_additional_regions) - else: - with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator_clone)): - fused_optimized_rotation_no_fx( - model_copy, - calibration_loader, - args, - fuse_rotations=False, - add_additional_regions=add_additional_regions) - - # Fuse matrices with module weights - if fused_rotations: - _fuse_rotations(model_copy) - - ids_rot = set() - num_rotation_matrices = 0 - # Count the number of unique rotation matrices - for module in model_copy.modules(): - if isinstance(module, UnfusedRotatedModule): - if id(module.rot_mat) not in ids_rot: - num_rotation_matrices += 1 - ids_rot.add(id(module.rot_mat)) - - num_rotated_modules = 0 - # Count the number of RotatedModules - for module in model_copy.modules(): - if isinstance(module, RotatedModule): - num_rotated_modules += 1 - - # Run model and save outputs - with torch.no_grad(): - logits = model_copy(**calibration_loader[0]).logits - - # Verify that the number of learnable rotation matrices is the expected (R1 + one R2 per block) - expected_number_rotation_matrices = 0 if fused_rotations else ( - 1 + (model.config.num_hidden_layers if add_additional_regions else 0)) - assert num_rotation_matrices == expected_number_rotation_matrices, f"Expected {expected_number_rotation_matrices} learnable rotations, found {num_rotation_matrices}." - - # Verify that the number of rotated modules is correct - expected_number_rotated_modules = 0 if not partial_had else ( - model.config.num_hidden_layers if add_additional_regions else 2 * - model.config.num_hidden_layers) - assert num_rotated_modules == expected_number_rotated_modules, f"Expected {expected_number_rotated_modules} learnable rotations, found {num_rotated_modules}." - - # Verify that the rotated module output is similar to the original FP - assert torch.allclose(original_logits, logits, atol=ATOL) - # Verify that the output is the same - assert torch.allclose(expected_logits, logits, atol=0.0, rtol=0.0) - - # Verify that the weights after fusing match - for name_fused_module, fused_module in model.named_modules(): - # For linear modules verify that the weights match - if isinstance(fused_module, (nn.Linear, RotatedModule)): - for name_unfused_Module, unfused_module in model_copy.named_modules(): - if name_fused_module == name_unfused_Module: - _compare_module_weights(fused_module, unfused_module) - # In case a RotatedModule is found, additional checks need to be done. - if isinstance(fused_module, RotatedModule): - if fused_rotations: - assert isinstance(unfused_module, RotatedModule) - assert torch.allclose(fused_module.had_mat, unfused_module.had_mat, rtol=0.0, atol=0.0), "The rotation matrices do not match." - else: - # Iterate over child nodes until finding the innermost RotatedModule - child_module = unfused_module - while isinstance(child_module, UnfusedRotatedModule): - assert not child_module.is_orphan, "UnfusedRotatedModule should not be an orphan." - child_module = child_module.module - # After finding the inner Rotated Module, they need to be compared - assert isinstance(child_module, RotatedModule), "Inner module should be RotatedModule." - assert torch.allclose(fused_module.had_mat, child_module.had_mat, rtol=0.0, atol=0.0), "The rotation matrices do not match." 
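
As background for the orphan checks above: an orphan rotation has no source layer to absorb the inverse transform, so a wrapper must rotate the activation at runtime while the same rotation is folded into the weight's input axis. A minimal stand-in for that behaviour (not the actual RotatedModule implementation):

    import torch
    import torch.nn as nn

    class InputRotationWrapper(nn.Module):

        def __init__(self, layer: nn.Linear, rot_mat: torch.Tensor):
            super().__init__()
            self.layer = layer
            self.rot_mat = rot_mat
            with torch.no_grad():
                # Fold the rotation into the weight along its input axis
                layer.weight.copy_(layer.weight @ rot_mat)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Rotate the incoming activation; the two rotations cancel
            return self.layer(x @ self.rot_mat)

    layer = nn.Linear(4, 4, bias=False)
    x = torch.randn(2, 4)
    out_ref = layer(x)
    rot_mat, _ = torch.linalg.qr(torch.randn(4, 4))
    assert torch.allclose(out_ref, InputRotationWrapper(layer, rot_mat)(x), atol=1e-6)
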
diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py
index 6a42c0895..18b791639 100644
--- a/tests/brevitas_examples/test_llm.py
+++ b/tests/brevitas_examples/test_llm.py
@@ -2,11 +2,15 @@
 # SPDX-License-Identifier: BSD-3-Clause

 from argparse import Namespace
+import copy
 from dataclasses import dataclass
+from functools import partial
+from itertools import product
 import logging
 import os
 import platform
 import shutil
+from unittest.mock import patch

 import numpy as np
 import onnx
@@ -14,15 +18,30 @@
 import pytest
 import pytest_cases
 import torch
-import transformers
+from torch import nn
+from transformers import AutoModelForCausalLM
+from transformers import AutoTokenizer

 from brevitas import config
 from brevitas import torch_version
-from brevitas_examples.llm.main import parse_args
+# LLM example depends on optimum-amd, which requires PyTorch>=2.2
 from brevitas_examples.llm.main import quantize_llm
+from brevitas_examples.llm.main import parse_args
+
+from brevitas.graph.equalize import _apply_had_device
+from brevitas.nn.equalized_layer import RotatedModule
+from brevitas_examples.llm.llm_quant.data_utils import get_dataset_for_model
+from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch
+from brevitas_examples.llm.llm_quant.rotation_utils import extract_trainable_rotation_matrices
+from brevitas_examples.llm.llm_quant.rotation_utils import fuse_rotations
+from brevitas_examples.llm.main import fused_optimized_rotation_no_fx
+
+from tests.conftest import SEED
 from tests.marker import jit_disabled_for_export
 from tests.marker import requires_pt_ge

+ATOL = 1e-3
+

 def ptid2pathname(string):
     return string.replace("/", "-").replace(":", "-")
@@ -723,3 +742,187 @@ def test_small_models_learned_round_ppl(caplog, learned_round_ppl_args_and_ppl):
     quant_ppl = quant_ppl.detach().cpu().numpy()
     assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}"
     assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}"
+
+
+# Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/eval_utils/rotation_utils.py#L26
+# This function needs to be patched to enable passing the generator and ensuring that the
+# generated orthogonal matrices are the same.
+def _random_orthogonal_matrix(size, generator):
+    """
+    Generate a random orthogonal matrix of the specified size.
+    First, we generate a random matrix with entries from a standard distribution.
+    Then, we use QR decomposition to obtain an orthogonal matrix.
+    Finally, we multiply by a diagonal matrix with diag r to adjust the signs.
+
+    Args:
+        size (int): The size of the matrix (size x size).
+
+    Returns:
+        torch.Tensor: An orthogonal matrix of the specified size.
+    """
+    torch.cuda.empty_cache()
+    random_matrix = torch.randn(size, size, dtype=torch.float64, generator=generator)
+    q, r = torch.linalg.qr(random_matrix)
+    q *= torch.sign(torch.diag(r)).unsqueeze(0).float()
+    return q
+
+
+@pytest_cases.fixture(
+    ids=[
+        "llama",],
+    params=[
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "input_bit_width": None,
+            "fuse_sequences": False,
+            "act_calibration": False,},])
+def equalize_args(default_run_args, request):
+    args = default_run_args
+    export_dict = request.param
+    args.update(**export_dict)
+    yield args
+
+
+# Auxiliary method to compare the weights in rotated modules.
+def _compare_fused_unfused_rotation_modules(module_name, fused_rot_module, unfused_rot_module):
+    fused_weight = fused_rot_module.weight if isinstance(
+        fused_rot_module, nn.Linear) else fused_rot_module.layer.weight
+    fused_bias = fused_rot_module.bias if isinstance(
+        fused_rot_module, nn.Linear) else fused_rot_module.layer.bias
+    unfused_weight = unfused_rot_module.weight if isinstance(
+        unfused_rot_module, nn.Linear) else unfused_rot_module.layer.weight
+    unfused_bias = unfused_rot_module.bias if isinstance(
+        unfused_rot_module, nn.Linear) else unfused_rot_module.layer.bias
+    assert torch.allclose(fused_weight, unfused_weight, rtol=0.0, atol=0.0), f"The weights after rotation do not match for module {module_name}."
+    if fused_bias is not None:
+        assert torch.allclose(fused_bias, unfused_bias, rtol=0.0, atol=0.0), f"The bias after rotation does not match for module {module_name}."
+    # In case a RotatedModule is found, additional checks need to be done.
+    if isinstance(fused_rot_module, RotatedModule):
+        assert isinstance(unfused_rot_module, RotatedModule), f"Expected an instance of RotatedModule for module {module_name}."
+        assert torch.allclose(fused_rot_module.had_mat, unfused_rot_module.had_mat, rtol=0.0, atol=0.0), f"The rotation matrices of RotatedModule {module_name} do not match."
+
+
+@pytest.mark.llm
+@requires_pt_ge('2.4')
+@pytest_cases.parametrize(
+    'partial_had, fused_rotations, add_additional_regions',
+    list(product([False, True], repeat=3)),
+    ids=[("fused-R1" if fused_rotations else "R1") + ("-R2" if add_additional_regions else "") +
+         ("-R3" if partial_had else "") for partial_had,
+         fused_rotations,
+         add_additional_regions in list(product([False, True], repeat=3))],
+)
+@pytest_cases.parametrize('rotation_mode', ['ort', 'had'])
+def test_small_models_rotations(
+        caplog, partial_had, fused_rotations, add_additional_regions, rotation_mode, equalize_args):
+    caplog.set_level(logging.INFO)
+    args = equalize_args
+    args.rotation_orphan_sink = partial_had
+    args.rotation_mode = rotation_mode
+
+    kwargs = {"torch_dtype": torch.float16}
+    model = AutoModelForCausalLM.from_pretrained(args.model, **kwargs)
+    model = replace_rmsnorm_with_torch(model, model.config)
+    model.config.use_cache = False
+    print("Model loaded.")
+    model.eval()
+    tokenizer = AutoTokenizer.from_pretrained(args.model)
+    # Load the data for calibration and evaluation.
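
For reference, the sign correction in _random_orthogonal_matrix defined earlier (multiplying q by sign(diag(r))) makes the QR factorization unique while preserving orthogonality, which can be checked directly:

    import torch

    g = torch.Generator()
    g.manual_seed(0)
    m = torch.randn(8, 8, dtype=torch.float64, generator=g)
    q, r = torch.linalg.qr(m)
    q *= torch.sign(torch.diag(r)).unsqueeze(0)
    # Flipping column signs keeps q orthogonal: q.T @ q == I up to numerical error
    assert torch.allclose(q.t() @ q, torch.eye(8, dtype=torch.float64), atol=1e-12)
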
+ calibration_loader = get_dataset_for_model( + args.model, + dataset_name=args.dataset, + tokenizer=tokenizer, + nsamples=args.nsamples, + seqlen=args.seqlen, + split="train", + seed=args.seed, + require_fx=False, + device=None, + fuse_sequences=args.fuse_sequences) + + # We need to make sure that the same random matrices are being generated + generator = torch.Generator() + generator.manual_seed(SEED) + # Clone generator to make sure we can use the same rotation matrices + generator_clone = generator.clone_state() + + # Run model and save outputs + with torch.no_grad(): + original_logits = model(**calibration_loader[0]).logits + + # Save a copy to apply graph rotation equalization on + model_copy = copy.deepcopy(model) + + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator)): + fused_optimized_rotation_no_fx( + model, + calibration_loader, + args, + fuse_rotations=True, + add_additional_regions=add_additional_regions) + + # Run model and save outputs + with torch.no_grad(): + expected_logits = model(**calibration_loader[0]).logits + + # Instead of random orthogonal matrices, we want to use the same ones as when the activations are not fused. + if rotation_mode == 'had': + with patch('brevitas.graph.equalize._apply_ort_device', _apply_had_device): + fused_optimized_rotation_no_fx( + model_copy, + calibration_loader, + args, + fuse_rotations=False, + add_additional_regions=add_additional_regions) + else: + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator_clone)): + fused_optimized_rotation_no_fx( + model_copy, + calibration_loader, + args, + fuse_rotations=False, + add_additional_regions=add_additional_regions) + + # Fuse matrices with module weights + if fused_rotations: + fuse_rotations(model_copy) + + # Run model and save outputs + with torch.no_grad(): + logits = model_copy(**calibration_loader[0]).logits + + # Verify that the rotated module output is similar to the original FP + assert torch.allclose(original_logits, logits, atol=ATOL), "Output of rotated network does not approximately match that of the original network." + # Verify that the output is the same + assert torch.allclose(expected_logits, logits, atol=0.0, rtol=0.0), "Outputs of fused/unfused rotated networks do not match exactly." + + num_rotation_matrices = len(extract_trainable_rotation_matrices(model_copy)) + + num_rotated_modules = 0 + # Count the number of RotatedModules + for module in model_copy.modules(): + if isinstance(module, RotatedModule): + num_rotated_modules += 1 + + # Verify that the number of learnable rotation matrices is the expected (R1 + one R2 per block) + expected_number_rotation_matrices = 0 if fused_rotations else ( + 1 + (model.config.num_hidden_layers if add_additional_regions else 0)) + assert num_rotation_matrices == expected_number_rotation_matrices, f"Expected {expected_number_rotation_matrices} learnable rotations, found {num_rotation_matrices}." + + # Verify that the number of rotated modules is correct + expected_number_rotated_modules = 0 if not partial_had else ( + model.config.num_hidden_layers if add_additional_regions else 2 * + model.config.num_hidden_layers) + assert num_rotated_modules == expected_number_rotated_modules, f"Expected {expected_number_rotated_modules} RotatedModules found {num_rotated_modules}." 
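
The two expected counts asserted above follow from the region structure: R1 contributes one global matrix and, when the self-attention regions are added, R2 contributes one matrix per block, while fusing leaves no trainable matrices at all. For orphan sinks, adding R2 turns each o_proj into a regular sink, leaving one orphan (down_proj) per block instead of two. A worked instance of the arithmetic (block count assumed for illustration):

    num_hidden_layers = 2  # assumed tiny test model

    for fused_rotations in (False, True):
        for add_r2 in (False, True):
            n_trainable = 0 if fused_rotations else 1 + (num_hidden_layers if add_r2 else 0)
            print(f"fused={fused_rotations} R2={add_r2}: {n_trainable} trainable rotations")

    # With orphan sinks enabled (R3), each remaining orphan gets a RotatedModule
    for add_r2 in (False, True):
        n_rotated = num_hidden_layers if add_r2 else 2 * num_hidden_layers
        print(f"R3 with R2={add_r2}: {n_rotated} RotatedModules")
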
+ + # Verify that the weights after fusing match + for name_fused_module, fused_module in model.named_modules(): + # For linear modules verify that the weights match + if isinstance(fused_module, (nn.Linear, RotatedModule)): + for name_unfused_Module, unfused_module in model_copy.named_modules(): + if name_fused_module == name_unfused_Module: + # Verify that everything matches between the fused and unfused rotation modules + _compare_fused_unfused_rotation_modules( + name_fused_module, fused_module, unfused_module) From 9366d11e8e486d38b6cd4c93cb4627929ad78b4a Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 13 Dec 2024 11:12:04 +0000 Subject: [PATCH 07/12] Unsaved changes in llm main --- src/brevitas_examples/llm/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 4244055f6..fa0e358c8 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -118,7 +118,7 @@ def fused_optimized_rotation_no_fx( r.apply(model) #new_model = offload_model(new_model) - # Regions with source o_proj and sink down_proj + # Regions with source v_proj and sink o_proj self_attention_regions = find_self_attention_rotation_regions( new_model, model.config.hidden_size // model.config.num_attention_heads) if add_additional_regions else None From 8d2b416bc2f7374ce7be531d23fc8e6a44bd26c5 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Fri, 13 Dec 2024 17:20:51 +0000 Subject: [PATCH 08/12] Consolidate fused/unfused rotations --- src/brevitas/graph/equalize.py | 317 ++++++------------ src/brevitas/graph/hadamard.py | 2 +- .../llm/llm_quant/rotation_optimization.py | 1 - src/brevitas_examples/llm/main.py | 71 ++-- tests/brevitas_examples/test_llm.py | 56 ++-- 5 files changed, 155 insertions(+), 292 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 7056f18cd..faae03a62 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1312,7 +1312,11 @@ def random_orthogonal_matrix(size): return q -def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method='had'): +def _apply_rotate( + model: nn.Module, + regions: List[Region], + full_rotation_method: str = 'had', + fuse_rotations: bool = True): rewriters = [] for region in regions: insert_rotation_module = len(region.srcs) == 0 @@ -1324,6 +1328,13 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= rot_mat = random_orthogonal_matrix(hidden_dim) K = None rot_func = _apply_ort_device + elif not insert_rotation_module and not fuse_rotations: + # TODO: This might be problematic if the parameters are distributed + # across devices. Generalize this logic for safety. 
+ device = next(model.parameters()).device + rot_mat = random_hadamard_matrix(hidden_dim, device) + K = None + rot_func = _apply_ort_device else: try: # Build hadamard rotation matrix @@ -1339,44 +1350,85 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= print("Skipping layers") continue + # If the rotation is not fused, redefine as a Parameter, to enable its optimization + if not fuse_rotations: + rot_mat = torch.nn.Parameter(rot_mat) + for name, indexes in region.srcs.items(): module = region.get_module_from_name(name) - if hasattr(module, 'allocate_params'): - module.allocate_params(module) axis = _get_output_axis(module) - weight = module.weight.data - if axis == 0: - weight = rot_func(weight.t(), rot_mat, K).t() - elif axis == 1: - weight = rot_func(weight, rot_mat, K) - else: - raise RuntimeError("Not supported yet") - module.weight.data = weight + if fuse_rotations: + if hasattr(module, 'allocate_params'): + module.allocate_params(module) + weight = module.weight.data - if getattr(module, 'bias', None) is not None: - bias = module.bias.data - bias = rot_func(bias, rot_mat, K) - module.bias.data = bias - if hasattr(module, 'offload_params'): - module.offload_params(module) + if axis == 0: + weight = rot_func(weight.t(), rot_mat, K).t() + elif axis == 1: + weight = rot_func(weight, rot_mat, K) + else: + raise RuntimeError("Not supported yet") + module.weight.data = weight + + if getattr(module, 'bias', None) is not None: + bias = module.bias.data + bias = rot_func(bias, rot_mat, K) + module.bias.data = bias + if hasattr(module, 'offload_params'): + module.offload_params(module) + else: + parametrize.register_parametrization( + module, + "weight", + RotationWeightParametrization( + rot_mat=rot_mat, + rot_func=rot_func, + output_axis=axis, + is_source=True, + )) + if getattr(module, 'bias', None) is not None: + parametrize.register_parametrization( + module, + "bias", + RotationBiasParametrization( + rot_mat=rot_mat, + rot_func=rot_func, + output_axis=axis, + is_source=True, + )) for name, indexes in region.sinks.items(): module = region.get_module_from_name(name) - if hasattr(module, 'allocate_params'): - module.allocate_params(module) axis = _get_input_axis(module) - weight = module.weight.data - if axis == 1: - _update_weights(module, rot_func(weight, rot_mat, K), 'weight') - elif axis == 0: - _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight') + if not insert_rotation_module and not fuse_rotations: + parametrize.register_parametrization( + module, + "weight", + RotationWeightParametrization( + rot_mat=rot_mat, + rot_func=rot_func, + input_axis=axis, + is_sink=True, + )) else: - raise RuntimeError("Not supported yet") + # Verify that there are no parametrizations, as otherwise the underlying weights will not be updated + assert not hasattr(module, "parametrizations"), "Fused rotations need to be incorporated before the parametrized rotations." 
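
Promoting the rotation to an nn.Parameter, as done earlier in this function for the unfused path, is what later makes it discoverable and trainable: once a parametrization module holds it, the host layer exposes it through parameters(). A minimal sketch with a toy parametrization (not the Brevitas classes):

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    class HoldsRotation(nn.Module):
        # Toy parametrization carrying a shared trainable rotation
        def __init__(self, rot_mat: nn.Parameter):
            super().__init__()
            self.rot_mat = rot_mat

        def forward(self, weight: torch.Tensor) -> torch.Tensor:
            return weight @ self.rot_mat

    hidden_dim = 8
    rot_mat = nn.Parameter(torch.linalg.qr(torch.randn(hidden_dim, hidden_dim))[0])
    layer = nn.Linear(hidden_dim, hidden_dim, bias=False)
    parametrize.register_parametrization(layer, "weight", HoldsRotation(rot_mat))
    # The rotation now shows up among the layer's trainable parameters
    assert any(p is rot_mat for p in layer.parameters())
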
+ + if hasattr(module, 'allocate_params'): + module.allocate_params(module) + weight = module.weight.data + + if axis == 1: + _update_weights(module, rot_func(weight, rot_mat, K), 'weight') + elif axis == 0: + _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight') + else: + raise RuntimeError("Not supported yet") - if hasattr(module, 'offload_params'): - module.offload_params(module) + if hasattr(module, 'offload_params'): + module.offload_params(module) if insert_rotation_module and len(region.srcs) == 0: rewriter = ModuleInstanceToModuleInstance( @@ -1476,8 +1528,12 @@ def rotate_matmuls(self, graph_module): graph_module.recompile() graph_module.graph.lint() - def apply(self, - graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: + def apply( + self, + graph_model: GraphModule, + fuse_rotations: bool = True, + additional_regions: Optional[List[Region]] = None + ) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: rewriters = [] regions = _extract_regions( graph_model, @@ -1486,6 +1542,8 @@ def apply(self, 'supported_sinks': self.supported_sinks, 'scale_invariant_layers': self.scale_invariant_layers, 'scale_invariant_function': self.scale_invariant_function}) + if additional_regions is not None: + regions.extend(additional_regions) eq_layers = set() orphan_regions = [] self.find_module(graph_model, orphan_regions) @@ -1497,11 +1555,18 @@ def apply(self, # Layerwise have only a single sink named 'sinks0' id_sink = id(o_r.get_module_from_name('sinks0')) if id_sink not in eq_layers: - regions.append(o_r) + # Orphan regions result in an in-place update of the weights, so these are applied before + # the rest of the rotations, to simplify the logic when fuse_rotations = False, as, + # otherwise, additional checks need to be incorporated to verify if the module weights + # have any parametrizations already, since in that case, the in-place update needs to + # be performed in module.parametrizations.weight.original. 
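
The constraint spelled out in the comment above can be seen directly: once a tensor is parametrized, the stored data moves to module.parametrizations.weight.original, and any in-place update must target that tensor rather than the computed module.weight. A minimal sketch:

    import torch
    import torch.nn as nn
    import torch.nn.utils.parametrize as parametrize

    class Double(nn.Module):
        def forward(self, w: torch.Tensor) -> torch.Tensor:
            return 2.0 * w

    layer = nn.Linear(4, 4, bias=False)
    parametrize.register_parametrization(layer, "weight", Double())

    # layer.weight is now computed on access; the stored tensor lives at .original
    original = layer.parametrizations.weight.original
    with torch.no_grad():
        original.mul_(0.5)
    assert torch.allclose(layer.weight, original * 2.0)
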
+ # TODO: Use deque to perform this operation in O(1) + regions = [o_r] + regions if self.rotate_matmul: self.rotate_matmuls(graph_model) if len(regions) > 0: - rewriters = _apply_rotate(graph_model, regions, self.full_rotation_method) + rewriters = _apply_rotate( + graph_model, regions, self.full_rotation_method, fuse_rotations) if self.return_rewriters: return graph_model, rewriters else: @@ -1600,193 +1665,3 @@ def apply(self, model: nn.Module, fuse_rotations: bool = True) -> nn.Module: if len(regions) > 0: _apply_rotate(model, regions) return model - - -def _apply_rotate_fused_rotations( - model: nn.Module, - regions: List[Region], - full_rotation_method='had', - fuse_rotations: bool = True): - rewriters = [] - for region in regions: - insert_rotation_module = len(region.srcs) == 0 - - if not insert_rotation_module and not region.is_valid: - continue - hidden_dim = region.max_shape_sinks - if not insert_rotation_module and full_rotation_method == 'ort': - rot_mat = random_orthogonal_matrix(hidden_dim) - # If the rotations are not fused, redefine as parameter - if not fuse_rotations: - rot_mat = torch.nn.Parameter(rot_mat) - K = None - rot_func = _apply_ort_device - elif not insert_rotation_module and not fuse_rotations: - # TODO: Generalize - device = next(model.parameters()).device - rot_mat = torch.nn.Parameter(random_hadamard_matrix(hidden_dim, device)) - K = None - rot_func = _apply_ort_device - else: - try: - # Build hadamard rotation matrix - rot_mat, K = get_hadK(hidden_dim) - rot_func = _apply_had_device - except AssertionError as e: - print(f"Incomptible shapes {hidden_dim}") - if not insert_rotation_module: - print("Falling back to orthogonal matrices") - rot_mat = random_orthogonal_matrix(hidden_dim) - K = None - rot_func = _apply_ort_device - print("Skipping layers") - continue - - for name, indexes in region.srcs.items(): - module = region.get_module_from_name(name) - axis = _get_output_axis(module) - - assert not insert_rotation_module, "Orphan regions must not have sources." - - if not insert_rotation_module and fuse_rotations: - # Verify that there are no parametrizations, as otherwise the underlying data will not be updated - assert not hasattr(module, "parametrizations"), "Fused rotations need to be incorporated before the parametrized rotations." 
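
The deque alternative suggested in the TODO earlier in this hunk avoids the O(n) list copy when orphan regions are prepended:

    from collections import deque

    regions = deque(["r1", "r2"])
    orphan = "orphan_region"
    # regions = [orphan] + regions copies the list, which is O(n);
    # deque.appendleft is O(1) and still processes orphan regions first
    regions.appendleft(orphan)
    assert list(regions) == ["orphan_region", "r1", "r2"]
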
- - if hasattr(module, 'allocate_params'): - module.allocate_params(module) - - weight = module.weight.data - - if axis == 0: - weight = rot_func(weight.t(), rot_mat, K).t() - elif axis == 1: - weight = rot_func(weight, rot_mat, K) - else: - raise RuntimeError("Not supported yet") - module.weight.data = weight - - if getattr(module, 'bias', None) is not None: - bias = module.bias.data - bias = rot_func(bias, rot_mat, K) - module.bias.data = bias - - if hasattr(module, 'offload_params'): - module.offload_params(module) - elif not insert_rotation_module and not fuse_rotations: - # Parametrize weights and possibly bias with unfused rotations - parametrize.register_parametrization( - module, - "weight", - RotationWeightParametrization( - rot_mat=rot_mat, - rot_func=rot_func, - output_axis=axis, - is_source=True, - )) - if getattr(module, 'bias', None) is not None: - parametrize.register_parametrization( - module, - "bias", - RotationBiasParametrization( - rot_mat=rot_mat, - rot_func=rot_func, - output_axis=axis, - is_source=True, - )) - - for name, indexes in region.sinks.items(): - module = region.get_module_from_name(name) - axis = _get_input_axis(module) - - if not insert_rotation_module and not fuse_rotations: - parametrize.register_parametrization( - module, - "weight", - RotationWeightParametrization( - rot_mat=rot_mat, - rot_func=rot_func, - input_axis=axis, - is_sink=True, - )) - else: - # Verify that there are no parametrizations, as otherwise the underlying data will not be updated - assert not hasattr(module, "parametrizations"), "Fused rotations need to be incorporated before the parametrized rotations." - - if hasattr(module, 'allocate_params'): - module.allocate_params(module) - weight = module.weight.data - - if axis == 1: - _update_weights(module, rot_func(weight, rot_mat, K), 'weight') - elif axis == 0: - _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight') - else: - raise RuntimeError("Not supported yet") - - if hasattr(module, 'offload_params'): - module.offload_params(module) - - if insert_rotation_module and len(region.srcs) == 0: - rewriter = ModuleInstanceToModuleInstance( - module, RotatedModule(had_mat=rot_mat, k=K, layer=module)) - rewriters.append(rewriter) - for r in rewriters: - model = r.apply(model) - return rewriters - - -# TODO: Consolidate with GraphRotationEqualization -class GraphRotationEqualizationOptimization(GraphRotationEqualization): - - def __init__( - self, - blacklist_layers: Optional[List[str]] = None, - orphan_sink: bool = False, - rotate_matmul: bool = False, - full_rotation_method: str = 'had', - ) -> None: - super(GraphRotationEqualizationOptimization, self).__init__( - blacklist_layers=blacklist_layers, - orphan_sink=orphan_sink, - rotate_matmul=rotate_matmul, - full_rotation_method=full_rotation_method, - return_rewriters=True, - ) - - def apply( - self, - graph_model: GraphModule, - fuse_rotations: bool = True, - additional_regions: Optional[List] = None - ) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: - rewriters = [] - regions = _extract_regions( - graph_model, - state_impl_kwargs={ - 'supported_srcs': self.supported_srcs, - 'supported_sinks': self.supported_sinks, - 'scale_invariant_layers': self.scale_invariant_layers, - 'scale_invariant_function': self.scale_invariant_function}) - if additional_regions is not None: - regions.extend(additional_regions) - eq_layers = set() - orphan_regions = [] - self.find_module(graph_model, orphan_regions) - for r in regions: - id_list = [id(r.name_to_module[sink_name]) 
for sink_name in r.sinks_names] - eq_layers.update(id_list) - if self.orphan_sink: - for o_r in orphan_regions: - # Layerwise have only a single sink named 'sinks0' - id_sink = id(o_r.get_module_from_name('sinks0')) - if id_sink not in eq_layers: - regions = [o_r] + regions - if self.rotate_matmul: - self.rotate_matmuls(graph_model) - if len(regions) > 0: - rewriters = _apply_rotate_fused_rotations( - graph_model, regions, self.full_rotation_method, fuse_rotations) - if self.return_rewriters: - return graph_model, rewriters - else: - return graph_model diff --git a/src/brevitas/graph/hadamard.py b/src/brevitas/graph/hadamard.py index d3773c759..aa05546b1 100644 --- a/src/brevitas/graph/hadamard.py +++ b/src/brevitas/graph/hadamard.py @@ -123,7 +123,7 @@ def random_hadamard_matrix(size, device): Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) Q = Q * 2 - 1 Q = torch.diag(Q) - return matmul_hadU(Q).to(device) + return matmul_hadU(Q).to(device).float() def matmul_hadU_cuda(X, hadK, K): diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py index 93b7edf7e..31cf00051 100644 --- a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py +++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py @@ -87,7 +87,6 @@ def apply_rotation_optimization( trainable_rotations = extract_trainable_rotation_matrices(graph_model) for rot_mat in trainable_rotations: rot_mat.requires_grad = True - # Initialize optimizer optimizer = SGDG(trainable_rotations, lr=training_args.learning_rate, stiefel=True) trainer = Trainer( model=graph_model, diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index fa0e358c8..45a2124d1 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -21,7 +21,6 @@ from brevitas.export.inference.manager import quant_inference_mode from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager from brevitas.graph.equalize import GraphRotationEqualization -from brevitas.graph.equalize import GraphRotationEqualizationOptimization from brevitas.graph.equalize import LayerwiseActivationRotation from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.utils import get_module @@ -80,7 +79,12 @@ def set_seed(seed): torch.random.manual_seed(seed) -def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool = False): +def fused_rotation_no_fx( + model, + calibration_loader, + args, + fuse_rotations: bool = True, + add_self_attention_regions: bool = False): with torch.no_grad(): new_model, guards = torch._dynamo.export(model)(**calibration_loader[0]) apply_layernorm_affine_merge(new_model) @@ -94,49 +98,22 @@ def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool = orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, return_rewriters=True) - new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations) - rewriters = fix_rewriter(rewriters, model, 'weight') - - for r in rewriters: - r.apply(model) - remove_hooks(new_model) - - -def fused_optimized_rotation_no_fx( - model, - calibration_loader, - args, - fuse_rotations: bool = False, - add_additional_regions: bool = False): - with torch.no_grad(): - new_model, guards = torch._dynamo.export(model)(**calibration_loader[0]) - apply_layernorm_affine_merge(new_model) - new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True) - rewriters = 
fix_rewriter(rewriters, model, 'weight') - - for r in rewriters: - r.apply(model) - #new_model = offload_model(new_model) - # Regions with source v_proj and sink o_proj - self_attention_regions = find_self_attention_rotation_regions( - new_model, model.config.hidden_size // - model.config.num_attention_heads) if add_additional_regions else None - eq = GraphRotationEqualizationOptimization( - orphan_sink=args.rotation_orphan_sink, - full_rotation_method=args.rotation_mode, - ) + self_attention_regions = ( + find_self_attention_rotation_regions( + new_model, model.config.hidden_size // + model.config.num_attention_heads) if add_self_attention_regions else None) new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations, additional_regions=self_attention_regions) - - # Retrieve additional rewriters for unfused rotations - rewriters_unfused_rotations = extract_rewriters_unfused_rotations(new_model, rewriters) - rewriters.extend(rewriters_unfused_rotations) + # Additional rewriters need to be added if rotations are not fused + if not fuse_rotations: + rewriters_unfused_rotations = extract_rewriters_unfused_rotations(new_model, rewriters) + rewriters.extend(rewriters_unfused_rotations) rewriters = fix_rewriter(rewriters, model, 'weight') for r in rewriters: r.apply(model) - #remove_hooks(new_model) + remove_hooks(new_model) def set_seed(seed): @@ -314,7 +291,7 @@ def quantize_llm(args, unknown_args=None): model = replace_rmsnorm_with_torch(model, model.config) # TODO: Refactor - if args.rotation == 'fused_no_fx_optimize': + if args.rotation in ['fused_no_fx_optimize', 'fused_no_fx_optimize_self_attn_region']: for i in range(len(calibration_loader)): del calibration_loader[i]["attention_mask"] calibration_loader[i]["labels"] = calibration_loader[i]["input_ids"] @@ -355,8 +332,11 @@ def quantize_llm(args, unknown_args=None): elif args.rotation == 'fused_no_fx': fused_rotation_no_fx(model, calibration_loader, args) elif args.rotation == 'fused_no_fx_optimize': - fused_optimized_rotation_no_fx( - model, calibration_loader, args, fuse_rotations=False, add_additional_regions=True) + fused_rotation_no_fx( + model, calibration_loader, args, fuse_rotations=False, add_self_attention_regions=False) + elif args.rotation == 'fused_no_fx_optimize_self_attn_region': + fused_rotation_no_fx( + model, calibration_loader, args, fuse_rotations=False, add_self_attention_regions=True) # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing # with all the variability in HF implementations @@ -480,7 +460,7 @@ def quantize_llm(args, unknown_args=None): # TODO: Refactor remove_hooks(model) - if args.rotation == 'fused_no_fx_optimize': + if args.rotation in ['fused_no_fx_optimize', 'fused_no_fx_optimize_self_attn_region']: apply_rotation_optimization( graph_model=model, tokenizer=tokenizer, @@ -822,7 +802,12 @@ def parse_args(args, override_defaults={}): '--rotation', type=str, default=None, - choices=['fx', 'layerwise', 'fused_no_fx', 'fused_no_fx_optimize'], + choices=[ + 'fx', + 'layerwise', + 'fused_no_fx', + 'fused_no_fx_optimize', + 'fused_no_fx_optimize_self_attn_region'], help='Apply graph rotation equalization') parser.add_argument( '--rotation-mode', diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 18b791639..84d329430 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -34,7 +34,7 @@ from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch from 
brevitas_examples.llm.llm_quant.rotation_utils import extract_trainable_rotation_matrices from brevitas_examples.llm.llm_quant.rotation_utils import fuse_rotations -from brevitas_examples.llm.main import fused_optimized_rotation_no_fx +from brevitas_examples.llm.main import fused_rotation_no_fx from tests.conftest import SEED from tests.marker import jit_disabled_for_export @@ -853,37 +853,41 @@ def test_small_models_rotations( # Save a copy to apply graph rotation equalization on model_copy = copy.deepcopy(model) - with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator)): - fused_optimized_rotation_no_fx( - model, - calibration_loader, - args, - fuse_rotations=True, - add_additional_regions=add_additional_regions) + # offload_model is patched to behave as an identity, thus making sure that the operations + # are deterministic, making it possible to test that the tensors match exactly. + with patch('brevitas_examples.llm.main.offload_model', lambda m: m): + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator)): + fused_rotation_no_fx( + model, + calibration_loader, + args, + fuse_rotations=True, + add_self_attention_regions=add_additional_regions) # Run model and save outputs with torch.no_grad(): expected_logits = model(**calibration_loader[0]).logits # Instead of random orthogonal matrices, we want to use the same ones as when the activations are not fused. - if rotation_mode == 'had': - with patch('brevitas.graph.equalize._apply_ort_device', _apply_had_device): - fused_optimized_rotation_no_fx( - model_copy, - calibration_loader, - args, - fuse_rotations=False, - add_additional_regions=add_additional_regions) - else: - with patch('brevitas.graph.equalize.random_orthogonal_matrix', - partial(_random_orthogonal_matrix, generator=generator_clone)): - fused_optimized_rotation_no_fx( - model_copy, - calibration_loader, - args, - fuse_rotations=False, - add_additional_regions=add_additional_regions) + with patch('brevitas_examples.llm.main.offload_model', lambda m: m): + if rotation_mode == 'had': + with patch('brevitas.graph.equalize._apply_ort_device', _apply_had_device): + fused_rotation_no_fx( + model_copy, + calibration_loader, + args, + fuse_rotations=False, + add_self_attention_regions=add_additional_regions) + else: + with patch('brevitas.graph.equalize.random_orthogonal_matrix', + partial(_random_orthogonal_matrix, generator=generator_clone)): + fused_rotation_no_fx( + model_copy, + calibration_loader, + args, + fuse_rotations=False, + add_self_attention_regions=add_additional_regions) # Fuse matrices with module weights if fused_rotations: From 59efe9b4c0defd96e56837b5144243837fdaf33e Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Mon, 16 Dec 2024 11:17:32 +0000 Subject: [PATCH 09/12] Enable quantization of parametrized modules --- src/brevitas/graph/base.py | 15 ++++++++++++++- src/brevitas/graph/quantize_impl.py | 8 +++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index def3f7070..dae5160ce 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -3,11 +3,13 @@ from abc import ABC from abc import abstractmethod +from collections import OrderedDict import inspect from inspect import getcallargs import torch from torch.nn import Module +import torch.nn.utils.parametrize as parametrize from torch.overrides import get_testing_overrides from brevitas.fx
import GraphModule @@ -154,7 +156,18 @@ def _init_new_module(self, old_module: Module, name=None): def _replace_old_module(self, model, old_module, new_module, load_state_dict=True): replace_module(model, old_module, new_module) if load_state_dict: - new_module.load_state_dict(old_module.state_dict()) + # The dictionary entries relative to parametrizations need to be ignored, as these are passed + # when invoking transfer_parametrizations_and_params. + old_module_state_dict = OrderedDict({ + k: v for k, + v in old_module.state_dict().items() if not k.startswith("parametrizations")}) + # If the old module is parametrized, these need to be transferred to the new module. Strict needs to be set to False, + # as there will be missing keys for those parameters which have any parametrizations attached. + if parametrize.is_parametrized(old_module): + new_module.load_state_dict(old_module_state_dict, strict=False) + parametrize.transfer_parametrizations_and_params(old_module, new_module) + else: + new_module.load_state_dict(old_module_state_dict) class InsertModuleCallAfter(GraphTransform): diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index 535f9a8f9..a4d348ab5 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn +import torch.nn.utils.parametrize as parametrize import brevitas from brevitas.graph.base import InsertModuleCallAfter @@ -511,7 +512,7 @@ def find_module( Specifically, it allows to map nn.MultiheadAttetion to its quantized counterpart and not its Linear submodules. """ - if _module_class_name(type(model)) in layer_map.keys(): + if _module_class_name(parametrize.type_before_parametrizations(model)) in layer_map.keys(): module_to_replace.append(model) else: for name, module in model.named_children(): @@ -532,8 +533,9 @@ def layerwise_layer_handler( find_module(model, layer_map, module_to_replace, name_blacklist) rewriters = [] for module in module_to_replace: - if layer_map[_module_class_name(type(module))] is not None: - quant_module_class, quant_module_kwargs = layer_map[_module_class_name(type(module))] + if layer_map[_module_class_name( + parametrize.type_before_parametrizations(module))] is not None: + quant_module_class, quant_module_kwargs = layer_map[_module_class_name(parametrize.type_before_parametrizations(module))] rewriter = ModuleToModuleByInstance(module, quant_module_class, **quant_module_kwargs) rewriters.append(rewriter) for rewriter in rewriters: From 3274aa8b7e7399c3d0538299a3018314581b0519 Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Tue, 31 Dec 2024 11:04:44 +0000 Subject: [PATCH 10/12] Fix logic registering parametrizations --- src/brevitas/graph/base.py | 98 +++++++++++++++++ src/brevitas/graph/equalize.py | 104 +++++++++++------- src/brevitas/graph/quantize_impl.py | 3 +- src/brevitas/nn/equalized_layer.py | 22 +--- .../llm/llm_quant/rotation_optimization.py | 3 +- .../llm/llm_quant/rotation_utils.py | 4 +- src/brevitas_examples/llm/main.py | 64 +++++++++-- tests/brevitas/graph/test_equalization.py | 3 +- 8 files changed, 226 insertions(+), 75 deletions(-) diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index dae5160ce..1546ecb67 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -6,9 +6,12 @@ from collections import OrderedDict import inspect from inspect import getcallargs +from typing import Any, Callable, Dict, Type, Union import torch +from torch import Tensor from 
torch.nn import Module +from torch.nn import Parameter import torch.nn.utils.parametrize as parametrize from torch.overrides import get_testing_overrides @@ -187,6 +190,76 @@ def apply(self, graph_model: GraphModule) -> GraphModule: return graph_model +class ModuleInstanceRegisterParametrization(Transform): + + def __init__( + self, old_module_instance: Module, tensor_name: str, + parametrization_module: Module) -> None: + self.old_module_instance = old_module_instance + self.tensor_name = tensor_name + self.parametrization_module = parametrization_module + + def apply(self, model: GraphModule) -> GraphModule: + for old_module in model.modules(): + if old_module is self.old_module_instance: + # register the parametrization in the old_module + parametrize.register_parametrization( + old_module, self.tensor_name, self.parametrization_module) + break + return model + + +class ModuleInstanceFuseRotationWeights(Transform): + + def __init__( + self, + old_module_instance: Module, + rot_mat: Union[Parameter, Tensor], + rot_func: Callable, + K: int, + tensor_name: str, + axis: int, + is_source: bool, + ): + self.old_module_instance = old_module_instance + self.rot_mat = rot_mat + self.rot_func = rot_func + self.K = K + self.tensor_name = tensor_name + self.axis = axis + self.is_source = is_source + + def apply(self, model: GraphModule) -> GraphModule: + for old_module in model.modules(): + if old_module is self.old_module_instance: + if hasattr(old_module, 'allocate_params'): + old_module.allocate_params(old_module) + weight = getattr(old_module, self.tensor_name).data + + if self.is_source: + if self.axis == 0: + weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() + elif self.axis == 1: + weight = self.rot_func(weight, self.rot_mat, self.K) + else: + raise RuntimeError("Not supported yet") + # If not a source, the module is either a sink or an orphan + else: + if self.axis == 1: + weight = self.rot_func(weight, self.rot_mat, self.K) + elif self.axis == 0: + weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() + else: + raise RuntimeError("Not supported yet") + # Modify the weights in-place + getattr(old_module, self.tensor_name).data = weight + + if hasattr(old_module, 'offload_params'): + old_module.offload_params(old_module) + break + return model + + class ModuleInstanceToModuleInstance(Transform): def __init__(self, old_module_instance, new_module_instance): @@ -202,6 +275,31 @@ def apply(self, model: GraphModule) -> GraphModule: return model +class ModuleInstanceWrapModule(Transform): + + def __init__( + self, + old_module_instance: Module, + wrapper_class: Type[Module], + module_attribute: str, + kwargs_wrapper: Dict[str, Any]): + self.old_module_instance = old_module_instance + self.wrapper_class = wrapper_class + self.module_attribute = module_attribute + self.kwargs_wrapper = kwargs_wrapper + + def apply(self, model: GraphModule) -> GraphModule: + for old_module in model.modules(): + if old_module is self.old_module_instance: + kwargs = {self.module_attribute: self.old_module_instance} + kwargs.update(self.kwargs_wrapper) + new_module_instance = self.wrapper_class(**kwargs) + # init the new module based on the old one + replace_module(model, old_module, new_module_instance) + break + return model + + class ModuleToModuleByName(ModuleToModule): def __init__(self, old_module_name, new_module_class, **kwargs): diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index faae03a62..0da49233c 100644 --- a/src/brevitas/graph/equalize.py +++ 
b/src/brevitas/graph/equalize.py @@ -21,11 +21,13 @@ from brevitas import torch_version from brevitas.fx import GraphModule from brevitas.fx import Node -from brevitas.graph import ModuleToModuleByClass from brevitas.graph import ModuleToModuleByInstance from brevitas.graph.base import GraphTransform from brevitas.graph.base import InsertModuleCallAfter +from brevitas.graph.base import ModuleInstanceFuseRotationWeights +from brevitas.graph.base import ModuleInstanceRegisterParametrization from brevitas.graph.base import ModuleInstanceToModuleInstance +from brevitas.graph.base import ModuleInstanceWrapModule from brevitas.graph.base import Transform from brevitas.graph.hadamard import get_hadK from brevitas.graph.hadamard import matmul_hadU @@ -1316,7 +1318,8 @@ def _apply_rotate( model: nn.Module, regions: List[Region], full_rotation_method: str = 'had', - fuse_rotations: bool = True): + fuse_rotations: bool = True, + apply_inplace_rotations: bool = True): rewriters = [] for region in regions: insert_rotation_module = len(region.srcs) == 0 @@ -1351,7 +1354,7 @@ def _apply_rotate( continue # If the rotation is not fused, redefine as a Parameter, to enable its optimization - if not fuse_rotations: + if not insert_rotation_module and not fuse_rotations: rot_mat = torch.nn.Parameter(rot_mat) for name, indexes in region.srcs.items(): @@ -1359,36 +1362,44 @@ def _apply_rotate( axis = _get_output_axis(module) if fuse_rotations: - if hasattr(module, 'allocate_params'): - module.allocate_params(module) - weight = module.weight.data - - if axis == 0: - weight = rot_func(weight.t(), rot_mat, K).t() - elif axis == 1: - weight = rot_func(weight, rot_mat, K) - else: - raise RuntimeError("Not supported yet") - module.weight.data = weight + rewriter = ModuleInstanceFuseRotationWeights( + old_module_instance=module, + rot_mat=rot_mat, + rot_func=rot_func, + K=K, + tensor_name="weight", + axis=axis, + is_source=True, + ) + rewriters.append(rewriter) if getattr(module, 'bias', None) is not None: - bias = module.bias.data - bias = rot_func(bias, rot_mat, K) - module.bias.data = bias - if hasattr(module, 'offload_params'): - module.offload_params(module) + rewriter = ModuleInstanceFuseRotationWeights( + old_module_instance=module, + rot_mat=rot_mat, + rot_func=rot_func, + K=K, + tensor_name="bias", + axis=1, + is_source=True, + ) + rewriters.append(rewriter) else: - parametrize.register_parametrization( + rewriter = ModuleInstanceRegisterParametrization( module, "weight", RotationWeightParametrization( rot_mat=rot_mat, rot_func=rot_func, - output_axis=axis, + axis=axis, is_source=True, )) + rewriters.append(rewriter) if getattr(module, 'bias', None) is not None: - parametrize.register_parametrization( + # TODO: Consolidate RotationBiasParametrization into a single + # class, by setting output_axis = 1. 
Also, could use a single + # axis, as input_axis and output_axis are not used simultaneously + rewriter = ModuleInstanceRegisterParametrization( module, "bias", RotationBiasParametrization( @@ -1397,45 +1408,49 @@ def _apply_rotate( output_axis=axis, is_source=True, )) + rewriters.append(rewriter) for name, indexes in region.sinks.items(): module = region.get_module_from_name(name) axis = _get_input_axis(module) if not insert_rotation_module and not fuse_rotations: - parametrize.register_parametrization( + rewriter = ModuleInstanceRegisterParametrization( module, "weight", RotationWeightParametrization( rot_mat=rot_mat, rot_func=rot_func, - input_axis=axis, + axis=axis, is_sink=True, )) + rewriters.append(rewriter) else: # Verify that there are no parametrizations, as otherwise the underlying weights will not be updated assert not hasattr(module, "parametrizations"), "Fused rotations need to be incorporated before the parametrized rotations." - if hasattr(module, 'allocate_params'): - module.allocate_params(module) - weight = module.weight.data - - if axis == 1: - _update_weights(module, rot_func(weight, rot_mat, K), 'weight') - elif axis == 0: - _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight') - else: - raise RuntimeError("Not supported yet") - - if hasattr(module, 'offload_params'): - module.offload_params(module) + rewriter = ModuleInstanceFuseRotationWeights( + old_module_instance=module, + rot_mat=rot_mat, + rot_func=rot_func, + K=K, + tensor_name="weight", + axis=axis, + is_source=False, + ) + rewriters.append(rewriter) if insert_rotation_module and len(region.srcs) == 0: - rewriter = ModuleInstanceToModuleInstance( - module, RotatedModule(had_mat=rot_mat, k=K, layer=module)) + rewriter = ModuleInstanceWrapModule( + module, RotatedModule, "layer", { + "had_mat": rot_mat, "k": K}) rewriters.append(rewriter) for r in rewriters: - model = r.apply(model) + # The parametrizations need to be registered after the potential HF hooks have been + # removed, as otherwise the device maps will not match the structure of the + # model's state_dict after the registration of the parametrizations. 
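For context on the ordering constraint noted above: registering a parametrization rewrites the module's state_dict keys, which is what invalidates accelerate's device maps. A minimal, self-contained sketch of this behaviour (illustrative only, not part of the patch; the Rotation class here is a hypothetical stand-in for RotationWeightParametrization):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class Rotation(nn.Module):
    # Toy parametrization: rotate a weight along its input axis (axis == 1)
    def __init__(self, rot_mat: torch.Tensor) -> None:
        super().__init__()
        self.rot_mat = rot_mat

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight @ self.rot_mat

layer = nn.Linear(4, 4, bias=False)
print(sorted(layer.state_dict().keys()))  # ['weight']
parametrize.register_parametrization(layer, "weight", Rotation(torch.eye(4)))
# 'weight' is now computed on the fly; the stored tensor moves to a new key,
# so any device map built from the old keys no longer matches.
print(sorted(layer.state_dict().keys()))  # ['parametrizations.weight.original']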
+ if apply_inplace_rotations and not isinstance(r, ModuleInstanceRegisterParametrization): + model = r.apply(model) return rewriters @@ -1532,7 +1547,8 @@ def apply( self, graph_model: GraphModule, fuse_rotations: bool = True, - additional_regions: Optional[List[Region]] = None + additional_regions: Optional[List[Region]] = None, + apply_inplace_rotations: bool = True, ) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]: rewriters = [] regions = _extract_regions( @@ -1566,7 +1582,11 @@ def apply( self.rotate_matmuls(graph_model) if len(regions) > 0: rewriters = _apply_rotate( - graph_model, regions, self.full_rotation_method, fuse_rotations) + graph_model, + regions, + self.full_rotation_method, + fuse_rotations, + apply_inplace_rotations) if self.return_rewriters: return graph_model, rewriters else: diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index a4d348ab5..538ce5717 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn import torch.nn.utils.parametrize as parametrize +from tqdm import tqdm import brevitas from brevitas.graph.base import InsertModuleCallAfter @@ -538,6 +539,6 @@ def layerwise_layer_handler( quant_module_class, quant_module_kwargs = layer_map[_module_class_name(parametrize.type_before_parametrizations(module))] rewriter = ModuleToModuleByInstance(module, quant_module_class, **quant_module_kwargs) rewriters.append(rewriter) - for rewriter in rewriters: + for rewriter in tqdm(rewriters, leave=False): model = rewriter.apply(model) return model diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index ccd812713..2c48f9da3 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -69,14 +69,6 @@ def __init__(self, layer, had_mat=None, k=None) -> None: self.layer = layer self.k = k - @property - def weight(self) -> Optional[torch.Tensor]: - return getattr(self.layer, 'weight', None) - - @property - def bias(self) -> Optional[torch.Tensor]: - return getattr(self.layer, 'bias', None) - def forward(self, inp, **kwargs): is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None # If k is None, we assume that an orthogonal matrix is used @@ -110,8 +102,7 @@ def __init__( self, rot_mat: torch.nn.Parameter, rot_func: Callable, - input_axis: Optional[int] = None, - output_axis: Optional[int] = None, + axis: int, is_source: bool = False, is_sink: bool = False, is_orphan: bool = False, @@ -119,8 +110,7 @@ def __init__( super().__init__() self.rot_mat = rot_mat self.rot_func = rot_func - self.input_axis = input_axis - self.output_axis = output_axis + self.axis = axis self.is_source = is_source self.is_sink = is_sink self.is_orphan = is_orphan @@ -128,17 +118,17 @@ def __init__( def forward(self, weight: torch.Tensor) -> torch.Tensor: if self.is_sink or self.is_orphan: - if self.input_axis == 1: + if self.axis == 1: weight = self.rot_func(weight, self.rot_mat, self.K) - elif self.input_axis == 0: + elif self.axis == 0: weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() else: raise RuntimeError("Not supported yet") if self.is_source: - if self.output_axis == 0: + if self.axis == 0: weight = self.rot_func(weight.t(), self.rot_mat, self.K).t() - elif self.output_axis == 1: + elif self.axis == 1: weight = self.rot_func(weight, self.rot_mat, self.K) else: raise RuntimeError("Not supported yet") diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py 
b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py index 31cf00051..618763498 100644 --- a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py +++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py @@ -20,8 +20,7 @@ @dataclass class ModelArguments: input_model: Optional[str] = field( - default="hf-internal-testing/tiny-random-LlamaForCausalLM", - metadata={"help": "Input model"}) + default="meta-llama/Llama-3.2-1B", metadata={"help": "Input model"}) output_rotation_path: Optional[str] = field( default="test-output", metadata={"help": "Output rotation checkpoint path"}) optimized_rotation_path: Optional[str] = field( diff --git a/src/brevitas_examples/llm/llm_quant/rotation_utils.py b/src/brevitas_examples/llm/llm_quant/rotation_utils.py index 037de166d..9c84aeff7 100644 --- a/src/brevitas_examples/llm/llm_quant/rotation_utils.py +++ b/src/brevitas_examples/llm/llm_quant/rotation_utils.py @@ -66,13 +66,15 @@ def fuse_rotations(model: nn.Module) -> None: parametrize.remove_parametrizations(module, "bias", leave_parametrized=True) +# TODO: Remove? We rely on ModuleInstanceRegisterParametrization def extract_rewriters_unfused_rotations(model: nn.Module, rewriters: List[Transform]) -> List[Transform]: extra_rewriters = [] for module in model.modules(): if hasattr(module, "parametrizations"): # Verify that the current module does not already have a RotatedModule associated - if len([r for r in rewriters if r.old_module_instance is module]) == 0: + if len([r for r in rewriters if r.old_module_instance is module and + isinstance(r, ModuleInstanceToModuleInstance)]) == 0: # Identity rewriter, only useful externally rewriter = ModuleInstanceToModuleInstance(module, module) extra_rewriters.append(rewriter) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 45a2124d1..20cdae543 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -4,7 +4,9 @@ import argparse from copy import deepcopy import functools +import os import sys +from typing import Callable, List from warnings import warn from lm_eval import evaluator @@ -20,6 +22,7 @@ from brevitas.export import export_torch_qcdq from brevitas.export.inference.manager import quant_inference_mode from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager +from brevitas.graph.base import ModuleInstanceFuseRotationWeights from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import LayerwiseActivationRotation from brevitas.graph.quantize import layerwise_quantize @@ -79,6 +82,37 @@ def set_seed(seed): torch.random.manual_seed(seed) +def on_process(process_index: int): + + def decorator(func: Callable): + + @functools.wraps(func) + def _wrapper(model, *args, **kwargs): + curr_process_index = int(os.environ.get('LOCAL_RANK', -1)) + + if curr_process_index == -1 or (process_index == curr_process_index): + print(f"Applying {func.__name__} on process index {curr_process_index}") + return func(model, *args, **kwargs) + else: + print(f"Skipping function {func.__name__} on process index {curr_process_index}") + return model + + return _wrapper + + return decorator + + +def apply_fused_rotations(model: torch.nn.Module, rewriters: List) -> torch.nn.Module: + model = offload_model(model) + for r in rewriters: + if isinstance(r, ModuleInstanceFuseRotationWeights): + model = r.apply(model) + remove_hooks(model) + return model + + +# TODO: Use no_grad?
The result of fusing the rotations would yield tensors with requires_grad set to False, +# which might not be a problem, as that flag is set in the appropriate QAT/PTQ algorithms. def fused_rotation_no_fx( model, calibration_loader, @@ -93,7 +127,6 @@ for r in rewriters: r.apply(model) - new_model = offload_model(new_model) eq = GraphRotationEqualization( orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode, @@ -103,17 +136,17 @@ find_self_attention_rotation_regions( new_model, model.config.hidden_size // model.config.num_attention_heads) if add_self_attention_regions else None) - new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations, additional_regions=self_attention_regions) - # Additional rewriters need to be added if rotations are not fused - if not fuse_rotations: - rewriters_unfused_rotations = extract_rewriters_unfused_rotations(new_model, rewriters) - rewriters.extend(rewriters_unfused_rotations) - + new_model, rewriters = eq.apply(new_model, fuse_rotations=fuse_rotations, additional_regions=self_attention_regions, apply_inplace_rotations=False) + # Rewriters need to be fixed to point to the module instances of the original model rewriters = fix_rewriter(rewriters, model, 'weight') - + # The weights of the FX model and the original model are tied, so the rotation fusing has already been applied. + # Note that the parametrization registration cannot be done in a model that has been offloaded using + # offload_model, as the change in the state dictionary when registering the parametrization causes the removal + # of the hooks to crash. This is due to the fact that the device_map in the AlignDevicesHook is no longer valid. + model = apply_fused_rotations(model, rewriters) for r in rewriters: - r.apply(model) - remove_hooks(new_model) + if not isinstance(r, ModuleInstanceFuseRotationWeights): + model = r.apply(model) def set_seed(seed): @@ -296,6 +329,14 @@ def quantize_llm(args, unknown_args=None): del calibration_loader[i]["attention_mask"] calibration_loader[i]["labels"] = calibration_loader[i]["input_ids"] + def mock_save_pretrained_fn(*args, **kwargs): + pass + + # For a PretrainedModel, the Trainer in accelerate calls save_pretrained after + # finishing the optimization. However, this method no longer works after + # registering parametrizations/quantizing, so this method is mocked to prevent + # a crash.
+ model.save_pretrained = mock_save_pretrained_fn model.config.use_cache = False model.config.loss_type = "ForCausalLM" @@ -405,7 +446,6 @@ def quantize_llm(args, unknown_args=None): quantize_embedding=False) if not args.quantize_last_layer: if require_fx: - # TODO: Fix when using UnfusedRotation, layer_map[type(last_module)][1] crashes last_node = [node for node in model.graph.nodes if node.op == 'call_module'][-1] last_module = get_module(model, last_node.target) last_layer_kwargs = layer_map[type(last_module)][1] @@ -468,6 +508,8 @@ def quantize_llm(args, unknown_args=None): unknown_args=unknown_args, ) + remove_hooks(model) + if args.act_calibration: print("Apply act calibration...") apply_calibration(model, calibration_loader) diff --git a/tests/brevitas/graph/test_equalization.py b/tests/brevitas/graph/test_equalization.py index 1e913845c..2acf8287b 100644 --- a/tests/brevitas/graph/test_equalization.py +++ b/tests/brevitas/graph/test_equalization.py @@ -361,8 +361,7 @@ def test_composition_unfused_rotations(N): RotationWeightParametrization( rot_mat=rot_mat, rot_func=_apply_ort_device, - input_axis=_get_input_axis(rot_module), - output_axis=_get_output_axis(rot_module), + axis=_get_output_axis(rot_module) if is_source else _get_input_axis(rot_module), is_source=is_source, is_sink=is_sink, is_orphan=is_orphan, From e3d853f583dfa5cbb29892a8cba93539b576296c Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Thu, 2 Jan 2025 12:39:40 +0000 Subject: [PATCH 11/12] Fix multi-GPU setup --- src/brevitas_examples/llm/main.py | 110 ++++++++++++++++++------------ 1 file changed, 66 insertions(+), 44 deletions(-) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 20cdae543..8abee55e9 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -23,6 +23,7 @@ from brevitas.export.inference.manager import quant_inference_mode from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager from brevitas.graph.base import ModuleInstanceFuseRotationWeights +from brevitas.graph.base import Transform from brevitas.graph.equalize import GraphRotationEqualization from brevitas.graph.equalize import LayerwiseActivationRotation from brevitas.graph.quantize import layerwise_quantize @@ -82,27 +83,30 @@ def set_seed(seed): torch.random.manual_seed(seed) -def on_process(process_index: int): +def is_main_process(): + return int(os.environ.get('LOCAL_RANK', -1)) in [-1, 0] - def decorator(func: Callable): - @functools.wraps(func) - def _wrapper(model, *args, **kwargs): - curr_process_index = int(os.environ.get('LOCAL_RANK', -1)) - if curr_process_index == -1 or (process_index == curr_process_index): - print(f"Applying {func.__name__} on process index {curr_process_index}") - return func(model, *args, **kwargs) - else: - print(f"Skipping function {func.__name__} on process index {curr_process_index}") - return model +def on_process(func: Callable, process_index: int): + + @functools.wraps(func) + def _wrapper(model, *args, **kwargs): + curr_process_index = int(os.environ.get('LOCAL_RANK', -1)) - return _wrapper + if curr_process_index == -1 or (process_index == curr_process_index): + print(f"Applying {func.__name__} on process index {curr_process_index}") + return func(model, *args, **kwargs) + else: + print(f"Skipping function {func.__name__} on process index {curr_process_index}") + return model - return decorator + return _wrapper +on_main_process = functools.partial(on_process, process_index=0) -def 
apply_fused_rotations(model: torch.nn.Module, rewriters: List) -> torch.nn.Module: +@on_main_process +def apply_fused_rotations(model: torch.nn.Module, rewriters: List[Transform]) -> torch.nn.Module: model = offload_model(model) for r in rewriters: if isinstance(r, ModuleInstanceFuseRotationWeights): @@ -111,6 +115,15 @@ def apply_fused_rotations(model: torch.nn.Module, rewriters: List) -> torch.nn.M return model +@on_main_process +def evaluate_model(model: torch.nn.Module, validation_loader, args, tokenizer): + model = offload_model(model) + quant_ppl = compute_perplexity( + model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) + print(f"Perplexity ({args.dataset}): {quant_ppl:.3f}") + remove_hooks(model) + + # TODO: Use no_grad? The result of fusing the rotations would yield tensors with requires_grad set to False, # which might not be a problem, as that flag is set in the appropriate QAT/PTQ algorithms. def fused_rotation_no_fx( @@ -313,12 +326,9 @@ def quantize_llm(args, unknown_args=None): if args.eval: assert args.export_target != 'torch_qcdq', "TorchScript QCDQ export and Evaluation simultaneously" - print("Float model eval...") - model = offload_model(model) - float_ppl = compute_perplexity( - model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) - remove_hooks(model) - print(f"Float perplexity ({args.dataset}): {float_ppl:.3f}") + print("Evaluating float model...") + evaluate_model(model, validation_loader, args, tokenizer) + print("Float evaluation done.") if args.replace_rmsnorm: model = replace_rmsnorm_with_torch(model, model.config) @@ -473,32 +483,43 @@ def mock_save_pretrained_fn(*args, **kwargs): if args.bias_corr: model = add_zero_bias_to_linear(model) - model = offload_model(model) - - dict_hooks = dict() - - # When offloading to CPU + GPU, the CPU scale factors must be updated - # before we move them back to the meta device. - # If we don't, we lose the new value but the internal flag "init_done" is True, thus we will use the wrong scale. - # To do this, we attach a "hook" to the post_forward function, called before the post_forward - # The function will update the dict with the initialized scales - for m in model.modules(): - if hasattr(m, '_hf_hook'): - if m._hf_hook.weights_map is not None: - # We store the original function to be restored later - dict_hooks[m] = m._hf_hook.post_forward - new_funct = functools.partial(update_internal_dict, m) - m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct) + # We need to run a calibration forward pass to initialize quantization-related parameters, + # e.g. scales. In DDP, as parameters are synchronized across replicas before optimization, + # it is not necessary to run this pass for every process, as the parameters of the main + # process will be broadcasted to each replica. + if is_main_process(): + model = offload_model(model) - with torch.no_grad(): - model(**calibration_loader[0]) + dict_hooks = dict() + + # When offloading to CPU + GPU, the CPU scale factors must be updated + # before we move them back to the meta device. + # If we don't, we lose the new value but the internal flag "init_done" is True, thus we will use the wrong scale.
+ # To do this, we attach a "hook" to the post_forward function, called before the post_forward + # The function will update the dict with the initialized scales + for m in model.modules(): + if hasattr(m, '_hf_hook'): + if m._hf_hook.weights_map is not None: + # We store the original function to be restored later + dict_hooks[m] = m._hf_hook.post_forward + new_funct = functools.partial(update_internal_dict, m) + m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct) + + with torch.no_grad(): + model(**calibration_loader[0]) - # We restore the original behaviour of the post-forward. - for k, v in dict_hooks.items(): - k._hf_hook.post_forward = v + # We restore the original behaviour of the post-forward. + for k, v in dict_hooks.items(): + k._hf_hook.post_forward = v - # TODO: Refactor - remove_hooks(model) + # TODO: Refactor + remove_hooks(model) + else: + # TODO: Generalize this logic. Currently, only ParameterFromStatsFromParameterZeroPoint + # and ParameterFromStatsFromParameterScaling have the attribute init_done + for module in model.modules(): + if hasattr(module, "init_done"): + module.init_done = True if args.rotation in ['fused_no_fx_optimize', 'fused_no_fx_optimize_self_attn_region']: apply_rotation_optimization( @@ -509,6 +530,7 @@ def mock_save_pretrained_fn(*args, **kwargs): ) remove_hooks(model) + torch.cuda.empty_cache() if args.act_calibration: print("Apply act calibration...") @@ -562,7 +584,7 @@ def mock_save_pretrained_fn(*args, **kwargs): apply_bias_correction(model, calibration_loader) print("Bias correction applied.") - if args.eval and not args.no_quantize: + if args.eval and not args.no_quantize and is_main_process(): print("Model eval...") with torch.no_grad(), quant_inference_mode(model): model(**calibration_loader[0]) From 11a6bb32f67a3aa550dc87b72f2b239183507ecd Mon Sep 17 00:00:00 2001 From: Pablo Monteagudo Lago Date: Tue, 7 Jan 2025 10:46:56 +0000 Subject: [PATCH 12/12] Partial fix to integration scales/parametrizations --- src/brevitas/graph/base.py | 46 +- src/brevitas/graph/equalize.py | 2 +- .../common/accelerate_utils/modeling.py | 449 ++++++++++++++++++ .../llm/llm_quant/rotation_optimization.py | 31 +- src/brevitas_examples/llm/main.py | 26 +- tests/brevitas_examples/test_modeling.py | 56 +++ 6 files changed, 603 insertions(+), 7 deletions(-) create mode 100644 src/brevitas_examples/common/accelerate_utils/modeling.py create mode 100644 tests/brevitas_examples/test_modeling.py diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index 1546ecb67..744eccfba 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -157,6 +157,50 @@ def _init_new_module(self, old_module: Module, name=None): return new_module def _replace_old_module(self, model, old_module, new_module, load_state_dict=True): + replace_module(model, old_module, new_module) + if load_state_dict: + # The dictionary entries relative to parametrizations need to be ignored, as these are passed + # when invoking transfer_parametrizations_and_params. 
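Before the key-remapping logic below, it helps to see why the raw state_dict of a parametrized module cannot be loaded strictly into a fresh replacement module; a minimal sketch of the failure mode (illustrative only, not part of the patch):

import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

old = nn.Linear(4, 4, bias=False)
parametrize.register_parametrization(old, "weight", nn.Identity())
new = nn.Linear(4, 4, bias=False)
try:
    # Fails: 'weight' is missing, 'parametrizations.weight.original' is unexpected
    new.load_state_dict(old.state_dict())
except RuntimeError as exc:
    print(exc)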
+ old_module_state_dict = old_module.state_dict() + + # If the module is parametrized, filter the state_dict appropriately + if parametrize.is_parametrized(old_module): + # Map the keys "parametrizations.tensor_name.original" to "tensor_name" + keys_to_remove = [] + keys_value_to_add = [] + for key, value in old_module_state_dict.items(): + split_key = key.split(".") + if len(split_key) >= 3 and split_key[-3] == "parametrizations" and split_key[ -1] == "original": + tensor_name = split_key[-2] + keys_value_to_add.append((".".join(split_key[:-3] + [tensor_name]), value)) + # We need to remove all the keys corresponding to the parametrizations added to the model + # to make sure the dictionary can be loaded with no missing/unused keys + # NOTE: For safety, an additional check could be added as this would not work if a model + # without parametrizations has any key containing "parametrizations" + if "parametrizations" in split_key: + keys_to_remove.append(key) + # The modifications need to be reflected in old_module_state_dict + for key in keys_to_remove: + del old_module_state_dict[key] + for key, value in keys_value_to_add: + old_module_state_dict[key] = value + + # Note that strict is set to True, as all the adaptations to the state dict were performed + new_module.load_state_dict(old_module_state_dict) + # If the old module is parametrized, its parametrizations need to be transferred to the new module + # We do not rely on the method transfer_parametrizations_and_params as using it can result + # in parameter ties being broken + # Note that unsafe is set to True for efficiency, as the checks should have been done + # when first registering the parametrization to old_module + if parametrize.is_parametrized(old_module): + for tensor_name in old_module.parametrizations: + for param_func in old_module.parametrizations[tensor_name]: + parametrize.register_parametrization( + new_module, tensor_name, param_func, unsafe=True) + + # TODO: Remove after debugging + def _replace_old_module_legacy(self, model, old_module, new_module, load_state_dict=True): replace_module(model, old_module, new_module) if load_state_dict: # The dictionary entries relative to parametrizations need to be ignored, as these are passed @@ -204,7 +248,7 @@ def apply(self, model: GraphModule) -> GraphModule: if old_module is self.old_module_instance: # register the parametrization in the old_module parametrize.register_parametrization( - old_module, self.tensor_name, self.parametrization_module) + old_module, self.tensor_name, self.parametrization_module, unsafe=True) break return model diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 0da49233c..4b9afdafc 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1443,7 +1443,7 @@ def _apply_rotate( if insert_rotation_module and len(region.srcs) == 0: rewriter = ModuleInstanceWrapModule( module, RotatedModule, "layer", { - "had_mat": rot_mat, "k": K}) + "had_mat": None, "k": K}) rewriters.append(rewriter) for r in rewriters: # The parametrizations need to be registered after the potential HF hooks have been diff --git a/src/brevitas_examples/common/accelerate_utils/modeling.py b/src/brevitas_examples/common/accelerate_utils/modeling.py new file mode 100644 index 000000000..11e1d0ffb --- /dev/null +++ b/src/brevitas_examples/common/accelerate_utils/modeling.py @@ -0,0 +1,449 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause + +from collections import defaultdict +from collections import OrderedDict +import logging +from typing import Dict, List, Optional, Union +import warnings + +from accelerate.utils.modeling import _get_proper_dtype +from accelerate.utils.modeling import check_tied_parameters_in_config +from accelerate.utils.modeling import clean_device_map +from accelerate.utils.modeling import compute_module_total_buffer_size +from accelerate.utils.modeling import dtype_byte_size +from accelerate.utils.modeling import find_tied_parameters +from accelerate.utils.modeling import get_max_layer_size +from accelerate.utils.modeling import get_max_memory +from accelerate.utils.modeling import get_non_persistent_buffers +import torch +from torch import nn + +logger = logging.getLogger(__name__) + + +def named_module_tensors( + module: nn.Module, + include_buffers: bool = True, + recurse: bool = False, + remove_non_persistent: bool = False, + remove_duplicate: bool = False, +): + """ + A helper function that gathers all the tensors (parameters + buffers) of a given module. If `include_buffers=True` + it's the same as doing `module.named_parameters(recurse=recurse) + module.named_buffers(recurse=recurse)`. + + Args: + module (`torch.nn.Module`): + The module we want the tensors on. + include_buffer (`bool`, *optional*, defaults to `True`): + Whether or not to include the buffers in the result. + recurse (`bool`, *optional`, defaults to `False`): + Whether or not to go look in every submodule or just return the direct parameters and buffers. + remove_non_persistent (`bool`, *optional*, defaults to `False`): + Whether or not to remove the non persistent buffer from the buffers. Useful only when include_buffers = + True + """ + yield from module.named_parameters(recurse=recurse, remove_duplicate=remove_duplicate) + + if include_buffers: + non_persistent_buffers = set() + if remove_non_persistent: + non_persistent_buffers = get_non_persistent_buffers(module, recurse=recurse) + for named_buffer in module.named_buffers(recurse=recurse): + name, _ = named_buffer + if name not in non_persistent_buffers: + yield named_buffer + + +def compute_module_sizes( + model: nn.Module, + dtype: Optional[Union[str, torch.device]] = None, + special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, + buffers_only: bool = False, +): + """ + Compute the size of each submodule of a given model. 
+ """ + if dtype is not None: + dtype = _get_proper_dtype(dtype) + dtype_size = dtype_byte_size(dtype) + if special_dtypes is not None: + special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()} + special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()} + module_sizes = defaultdict(int) + + module_list = [] + + if not buffers_only: + module_list = named_module_tensors(model, recurse=True) + else: + module_list = model.named_buffers(recurse=True) + + for name, tensor in module_list: + if special_dtypes is not None and name in special_dtypes: + size = tensor.numel() * special_dtypes_size[name] + elif dtype is None: + size = tensor.numel() * dtype_byte_size(tensor.dtype) + elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + # According to the code in set_module_tensor_to_device, these types won't be converted + # so use their original size here + size = tensor.numel() * dtype_byte_size(tensor.dtype) + else: + size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype)) + name_parts = name.split(".") + for idx in range(len(name_parts) + 1): + module_sizes[".".join(name_parts[:idx])] += size + + return module_sizes + + +def infer_auto_device_map( + model: nn.Module, + max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, + no_split_module_classes: Optional[List[str]] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None, + verbose: bool = False, + clean_result: bool = True, + offload_buffers: bool = False, +): + """ + Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk, + such that: + - we don't exceed the memory available of any of the GPU. + - if offload to the CPU is needed, there is always room left on GPU 0 to put back the layer offloaded on CPU that + has the largest size. + - if offload to the CPU is needed,we don't exceed the RAM available on the CPU. + - if offload to the disk is needed, there is always room left on the CPU to put back the layer offloaded on disk + that has the largest size. + + + + All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the + meta device (as it would if initialized within the `init_empty_weights` context manager). + + + + Args: + model (`torch.nn.Module`): + The model to analyze. + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset. + Example: `max_memory={0: "1GB"}`. + no_split_module_classes (`List[str]`, *optional*): + A list of layer class names that should never be split across device (for instance any layer that has a + residual connection). + dtype (`str` or `torch.dtype`, *optional*): + If provided, the weights will be converted to that type when loaded. + special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*): + If provided, special dtypes to consider for some specific weights (will override dtype used as default for + all weights). + verbose (`bool`, *optional*, defaults to `False`): + Whether or not to provide debugging statements as the function builds the device_map. + clean_result (`bool`, *optional*, defaults to `True`): + Clean the resulting device_map by grouping all submodules that go on the same device together. 
+ offload_buffers (`bool`, *optional*, defaults to `False`): + In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as + well as the parameters. + """ + # Get default / clean up max_memory + max_memory = get_max_memory(max_memory) + print(max_memory) + if no_split_module_classes is None: + no_split_module_classes = [] + elif not isinstance(no_split_module_classes, (list, tuple)): + no_split_module_classes = [no_split_module_classes] + + devices = list(max_memory.keys()) + if "disk" not in devices: + devices.append("disk") + gpus = [device for device in devices if device not in ["cpu", "disk"]] + + # Devices that need to keep space for a potential offloaded layer. + if "mps" in gpus: + main_devices = ["mps"] + elif len(gpus) > 0: + main_devices = [gpus[0], "cpu"] + else: + main_devices = ["cpu"] + + module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes) + print({k: v for k, v in module_sizes.items() if len(k.split(".")) <= 3}) + tied_parameters = find_tied_parameters(model) + + if check_tied_parameters_in_config(model) and len(tied_parameters) == 0: + logger.warn( + "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function." + ) + + device_map = OrderedDict() + current_device = 0 + current_memory_used = 0 + device_memory_used = {} + device_buffer_sizes = {} + + # Direct submodules and parameters + modules_to_treat = ( + list(model.named_parameters(recurse=False)) + list(model.named_children()) + + list(model.named_buffers(recurse=False))) + # Initialize maximum largest layer, to know which space to keep in memory + max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes) + + # Ready ? This is going to be a bit messy. + while len(modules_to_treat) > 0: + name, module = modules_to_treat.pop(0) + if verbose: + print(f"\nTreating module {name}.") + # Max size in the remaining layers may have changed since we took one, so we maybe update it. + max_layer_names = [n for n in max_layer_names if n != name and not n.startswith(name + ".")] + if len(max_layer_names) == 0: + max_layer_size, max_layer_names = get_max_layer_size( + [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)], + module_sizes, + no_split_module_classes, + ) + # Assess size needed + module_size = module_sizes[name] + + # We keep relevant tied parameters only: one of the tied parameters in the group is inside the current module + # and the other is not. + # Note: If we are currently processing the name `compute.weight`, an other parameter named e.g. `compute.weight_submodule.parameter` + # needs to be considered outside the current module, hence the check with additional dots. + tied_param_goups = [ + tied_group for tied_group in tied_parameters + if any(name + "." in k + "." + for k in tied_group) and not all(name + "." in k + "." for k in tied_group)] + + if verbose and len(tied_param_goups) > 0: + print(f" Found the relevant tied param groups {tied_param_goups}") + + # Then we keep track of all the parameters that are tied to the current module, but not in the current module + tied_params = sum([[ + p + for p in tied_group + if name + "." not in p + "." and any( + name_tied + "." in p + "." 
+ for name_tied, _ in modules_to_treat)] + for tied_group in tied_param_goups], []) + + if verbose and len(tied_params) > 0: + print(f" So those parameters need to be taken into account {tied_params}") + + device = devices[current_device] + current_max_size = max_memory[device] if device != "disk" else None + current_memory_reserved = 0 + # Reduce max size available by the largest layer. + if devices[current_device] in main_devices: + current_max_size = current_max_size - max_layer_size + current_memory_reserved = max_layer_size + # Case 1 -> We're too big! + if current_max_size is not None and current_memory_used + module_size > current_max_size: + # Split or not split? + modules_children = ([] if isinstance(module, nn.Parameter) or + isinstance(module, torch.Tensor) else list(module.named_children())) + if verbose: + print( + f"Not enough space on {devices[current_device]} to put {name} (space available " + f"{current_max_size - current_memory_used}, module size {module_size}).") + if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes: + # -> no split, we go to the next device + if verbose: + print("This module cannot be split, going to the next device.") + + device_memory_used[device] = current_memory_used + current_memory_reserved + current_device += 1 + modules_to_treat = [(name, module)] + modules_to_treat + current_memory_used = 0 + else: + # -> split, we replace the module studied by its children + parameters + if verbose: + print(f"Splitting {name}.") + modules_children = list(module.named_parameters(recurse=False)) + modules_children + modules_to_treat = [ + (f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat + # Update the max layer size. + max_layer_size, max_layer_names = get_max_layer_size( + [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)], + module_sizes, + no_split_module_classes, + ) + + # Case 2, it fits! We're not entirely out of the wood though, because we may have some tied parameters. + elif len(tied_params) > 0: + # First locate all tied modules + tied_module_names = [] + tied_modules = [] + for tied_param in tied_params: + tied_module_index = [ + i for i, (n, _) in enumerate(modules_to_treat) if (n + ".") in tied_param][0] + tied_module_names.append(modules_to_treat[tied_module_index][0]) + tied_modules.append(modules_to_treat[tied_module_index][1]) + if verbose: + print( + f" It looks like {name} is going to fit on {devices[current_device]} but we have tied " + f"parameters to account for.\n - Names {tied_params}\n - Module names {tied_module_names}" + ) + + # Let's see if it all fits first + module_size_with_ties = module_size + for tied_param, tied_module_name in zip(tied_params, tied_module_names): + module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param] + + if current_max_size is None or current_memory_used + module_size_with_ties <= current_max_size: + # We really really fit! + if verbose: + print(f"Putting {name} and {tied_module_names} on {devices[current_device]}.") + current_memory_used += module_size_with_ties + device_map[name] = devices[current_device] + for tied_module_name in tied_module_names: + if tied_module_name in [m[0] for m in modules_to_treat]: + # The module may have been removed by a previous iteration of this loop. 
+ tied_module_index = [ + i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name + ][0] + modules_to_treat.pop(tied_module_index) + device_map[tied_module_name] = devices[current_device] + + if not offload_buffers and isinstance(module, nn.Module): + current_buffer_size = compute_module_total_buffer_size( + module, dtype=dtype, special_dtypes=special_dtypes) + device_buffer_sizes[device] = device_buffer_sizes.get( + device, 0) + current_buffer_size + + else: + # We don't fit with the tied modules. Next question is: can we split one of the tied modules to make it + # smaller or do we need to go on the next device? + if verbose: + print( + f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space " + f"available {current_max_size - current_memory_used}, needed size {module_size_with_ties})." + ) + split_happened = False + for tied_module_name, tied_module in zip(tied_module_names, tied_modules): + tied_module_children = list(tied_module.named_children()) + if len(tied_module_children + ) == 0 or tied_module.__class__.__name__ in no_split_module_classes: + # can't break this one. + continue + + if verbose: + print(f"Splitting {tied_module_name}.") + tied_module_children = list( + tied_module.named_parameters(recurse=False)) + tied_module_children + tied_module_children = [ + (f"{tied_module_name}.{n}", v) for n, v in tied_module_children] + tied_module_index = [ + i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0] + + modules_to_treat = ([(name, module)] + modules_to_treat[:tied_module_index] + + tied_module_children + + modules_to_treat[tied_module_index + 1:]) + # Update the max layer size. + max_layer_size, max_layer_names = get_max_layer_size( + [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)], + module_sizes, + no_split_module_classes, + ) + split_happened = True + break + + # Before going to the next device, we check if we can allocate a subset of the modules + # with tied parameters to the current device + split_tied = False + # This is the simplest heuristic for partitioning the graph and being able to fill the + # device, while separating the modules that contain tied parameters. However, it is + # optimal in a fully connected topology where all the modules take the same memory and + # all of them share the same tied parameters. + for i in reversed(range(len(tied_params))): + subset_tied_params = tied_params[:i] + subset_tied_tied_module_names = tied_module_names[:i] + + module_size_with_ties = module_size + for tied_param, tied_module_name in zip(subset_tied_params, subset_tied_tied_module_names): + module_size_with_ties += module_sizes[tied_module_name] - module_sizes[ + tied_param] + + # Does this subset of modules fit? + if current_max_size is None or current_memory_used + module_size_with_ties <= current_max_size: + # We really really fit! + if verbose: + print( + f"Putting {name} and {subset_tied_tied_module_names} on {devices[current_device]}." + ) + current_memory_used += module_size_with_ties + device_map[name] = devices[current_device] + for tied_module_name in subset_tied_tied_module_names: + if tied_module_name in [m[0] for m in modules_to_treat]: + # The module may have been removed by a previous iteration of this loop. 
+ tied_module_index = [ + i for i, (n, _) in enumerate(modules_to_treat) + if n == tied_module_name][0] + modules_to_treat.pop(tied_module_index) + device_map[tied_module_name] = devices[current_device] + + if not offload_buffers and isinstance(module, nn.Module): + current_buffer_size = compute_module_total_buffer_size( + module, dtype=dtype, special_dtypes=special_dtypes) + device_buffer_sizes[device] = device_buffer_sizes.get( + device, 0) + current_buffer_size + + # A maximal subset of modules was identified and set to the appropriate devices + split_tied = True + break + + # Continue the processing of modules + if split_tied: + continue + + if not split_happened: + # If the tied module is not split, we go to the next device + if verbose: + print("None of the tied modules can be split, going to the next device.") + + device_memory_used[device] = current_memory_used + current_memory_reserved + current_device += 1 + modules_to_treat = [(name, module)] + modules_to_treat + current_memory_used = 0 + + else: + if verbose: + if current_max_size is None: + print(f"Putting {name} (size={module_size}) on {devices[current_device]}.") + else: + print( + f"Putting {name} (size={module_size}) on {devices[current_device]} " + f"(available={current_max_size - current_memory_used}).") + current_memory_used += module_size + device_memory_used[device] = current_memory_used + current_memory_reserved + device_map[name] = devices[current_device] + + if not offload_buffers and isinstance(module, nn.Module): + current_buffer_size = compute_module_total_buffer_size( + module, dtype=dtype, special_dtypes=special_dtypes) + device_buffer_sizes[device] = device_buffer_sizes.get( + device, 0) + current_buffer_size + + if clean_result: + device_map = clean_device_map(device_map) + + non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0) + if non_gpu_buffer_size > 0 and not offload_buffers: + is_buffer_fit_any_gpu = False + for gpu_device, gpu_max_memory in max_memory.items(): + if gpu_device == "cpu" or gpu_device == "disk": + continue + + if not is_buffer_fit_any_gpu: + gpu_memory_used = device_memory_used.get(gpu_device, 0) + + if gpu_max_memory >= non_gpu_buffer_size + gpu_memory_used: + is_buffer_fit_any_gpu = True + + if len(gpus) > 0 and not is_buffer_fit_any_gpu: + warnings.warn( + f"Current model requires {non_gpu_buffer_size} bytes of buffer for offloaded layers, which does " + f"not seem to fit any GPU's remaining memory.
If you are experiencing an OOM later, please consider using " + f"offload_buffers=True.") + + return device_map diff --git a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py index 618763498..1d1cb8c46 100644 --- a/src/brevitas_examples/llm/llm_quant/rotation_optimization.py +++ b/src/brevitas_examples/llm/llm_quant/rotation_optimization.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from dataclasses import field -from typing import Optional, Tuple +from typing import Optional, override, Tuple import torch from torch.utils.data import Dataset @@ -20,7 +20,7 @@ @dataclass class ModelArguments: input_model: Optional[str] = field( - default="meta-llama/Llama-3.2-1B", metadata={"help": "Input model"}) + default="meta-llama/Llama-3.2-3B", metadata={"help": "Input model"}) output_rotation_path: Optional[str] = field( default="test-output", metadata={"help": "Output rotation checkpoint path"}) optimized_rotation_path: Optional[str] = field( @@ -72,6 +72,30 @@ def collate_fn(kwargs_list, return_tensors="pt"): return kwargs +class FSDPTrainer(Trainer): + + def __init__(self, optimizers, **kwargs): + super().__init__(**kwargs) + self.optimizer, self.lr_scheduler = optimizers + + @override + def create_optimizer_and_scheduler(self, num_training_steps: int): + """ + Setup the optimizer and the learning rate scheduler. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or + `create_scheduler`) in a subclass. + """ + + # Overwrite optimizer creation because optimizer is already created + optimizer = self.optimizer + self.create_scheduler( + num_training_steps=num_training_steps, + optimizer=optimizer, + ) + + def apply_rotation_optimization( graph_model: torch.fx.GraphModule, tokenizer: PreTrainedTokenizerBase, @@ -87,6 +111,9 @@ def apply_rotation_optimization( for rot_mat in trainable_rotations: rot_mat.requires_grad = True optimizer = SGDG(trainable_rotations, lr=training_args.learning_rate, stiefel=True) + # TODO: Enable for multiprocess + # torch.distributed.barrier() + torch.cuda.empty_cache() trainer = Trainer( model=graph_model, tokenizer=tokenizer, diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 8abee55e9..3a0371c28 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -14,6 +14,7 @@ import numpy as np from optimum.exporters.onnx import onnx_export_from_model import torch +import torch.distributed as dist from transformers import AutoModelForCausalLM from transformers import AutoTokenizer from transformers.utils.fx import _SUPPORTED_MODELS @@ -56,6 +57,7 @@ from brevitas_examples.llm.llm_quant.rotation_utils import find_self_attention_rotation_regions from brevitas_examples.llm.llm_quant.prepare_for_quantize import \ replace_sdpa_with_quantizable_layers +from brevitas_examples.llm.llm_quant.rotation_utils import fuse_rotations from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter from brevitas_examples.llm.llm_quant.run_utils import get_fx @@ -87,6 +89,9 @@ def is_main_process(): return int(os.environ.get('LOCAL_RANK', -1)) in [-1, 0] +def is_multi_process(): + return int(os.environ.get('LOCAL_RANK', -1)) != -1 + def on_process(func: Callable, process_index: int): @@ -135,6 +140,7 @@ def
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 8abee55e9..3a0371c28 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -14,6 +14,7 @@
 import numpy as np
 from optimum.exporters.onnx import onnx_export_from_model
 import torch
+import torch.distributed as dist
 from transformers import AutoModelForCausalLM
 from transformers import AutoTokenizer
 from transformers.utils.fx import _SUPPORTED_MODELS
@@ -56,6 +57,7 @@
 from brevitas_examples.llm.llm_quant.rotation_utils import find_self_attention_rotation_regions
 from brevitas_examples.llm.llm_quant.prepare_for_quantize import \
     replace_sdpa_with_quantizable_layers
+from brevitas_examples.llm.llm_quant.rotation_utils import fuse_rotations
 from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32
 from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter
 from brevitas_examples.llm.llm_quant.run_utils import get_fx
@@ -87,6 +89,9 @@ def is_main_process():
     return int(os.environ.get('LOCAL_RANK', -1)) in [-1, 0]
 
 
+def is_multi_process():
+    return int(os.environ.get('LOCAL_RANK', -1)) != -1
+
 
 def on_process(func: Callable, process_index: int):
@@ -135,6 +140,7 @@ def fused_rotation_no_fx(
     with torch.no_grad():
         new_model, guards = torch._dynamo.export(model)(**calibration_loader[0])
     apply_layernorm_affine_merge(new_model)
+    # NOTE: This call breaks the ties between the lm_head and the embedding layer
     new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True)
     rewriters = fix_rewriter(rewriters, model, 'weight')
@@ -262,6 +268,10 @@ def validate(args):
 def quantize_llm(args, unknown_args=None):
     validate(args)
     set_seed(args.seed)
+    # TODO: Validate
+    if is_multi_process():
+        dist.init_process_group(backend="nccl")
+
     if args.export_prefix is None:
         args.export_prefix = f"{args.model.replace('/', '--')}"
@@ -389,6 +399,9 @@ def mock_save_pretrained_fn(*args, **kwargs):
         fused_rotation_no_fx(
             model, calibration_loader, args, fuse_rotations=False, add_self_attention_regions=True)
+        # TODO: Validate
+        if is_multi_process():
+            torch.distributed.barrier()
     # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing
     # with all the variability in HF implementations
     if args.replace_mha:
@@ -528,9 +541,16 @@ def mock_save_pretrained_fn(*args, **kwargs):
             train_dataset=calibration_loader,
             unknown_args=unknown_args,
         )
-
-        remove_hooks(model)
-        torch.cuda.empty_cache()
+        remove_hooks(model)
+        # In the main process, rotations can be fused for evaluation
+        # TODO: Validate
+        if is_main_process():
+            # Offload the model before fusing the rotations
+            model = offload_model(model)
+            # TODO: Make sure that ties are kept
+            # Fuse rotations with the weights
+            fuse_rotations(model)
+            remove_hooks(model)
 
     if args.act_calibration:
         print("Apply act calibration...")
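The LOCAL_RANK-based helpers above encode the torchrun contract; a minimal standalone sketch of the gating pattern they enable (hypothetical worker script):

    import os

    import torch.distributed as dist


    def main():
        # Under `torchrun --nproc_per_node=N worker.py`, LOCAL_RANK is "0".."N-1";
        # in a plain single-process run it is unset, so the helpers fall back to -1.
        if int(os.environ.get('LOCAL_RANK', -1)) != -1:
            dist.init_process_group(backend="nccl")
            dist.barrier()  # keep all ranks in lockstep before rank-dependent work
        # ... per-rank work ...
        if dist.is_initialized():
            dist.destroy_process_group()


    if __name__ == "__main__":
        main()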
diff --git a/tests/brevitas_examples/test_modeling.py b/tests/brevitas_examples/test_modeling.py
new file mode 100644
index 000000000..ba7cd96b7
--- /dev/null
+++ b/tests/brevitas_examples/test_modeling.py
@@ -0,0 +1,56 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict
+from typing import Dict, Optional
+import unittest
+
+import torch
+import torch.nn as nn
+
+from brevitas_examples.common.accelerate_utils.modeling import infer_auto_device_map
+
+
+class ModelForTest(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.linear1 = nn.Linear(3, 4)
+        self.batchnorm = nn.BatchNorm1d(4)
+        self.linear2 = nn.Linear(4, 5)
+
+    def forward(self, x):
+        return self.linear2(self.batchnorm(self.linear1(x)))
+
+
+class ModelingUtilsTester(unittest.TestCase):
+
+    def test_infer_auto_device_map_tied_weights_split(self):
+        model = nn.Sequential(OrderedDict([("layer1", ModelForTest()), ("layer2", ModelForTest())]))
+        expected_sizes = {"": 236, "linear1": 64, "linear1.weight": 48, "linear1.bias": 16}
+        expected_sizes.update({"linear2": 100, "linear2.weight": 80, "linear2.bias": 20})
+        expected_sizes.update({"batchnorm": 72, "batchnorm.weight": 16, "batchnorm.bias": 16})
+        expected_sizes.update({
+            "batchnorm.running_mean": 16,
+            "batchnorm.running_var": 16,
+            "batchnorm.num_batches_tracked": 8})
+        # model has size 236: linear1 64, batchnorm 72, linear2 100
+        model.layer1.linear1.weight = model.layer2.linear1.weight
+        device_map = infer_auto_device_map(
+            model,
+            max_memory={0: 236, 1: 236, 2: 236},
+            verbose=True,
+            no_split_module_classes=[ModelForTest.__name__])
+        assert device_map == {"layer1": 1, "layer2": 2}
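The byte counts asserted in expected_sizes can be reproduced directly; a small helper (hypothetical name, not part of the patch) that sums parameter and buffer sizes gives the same numbers:

    import torch.nn as nn


    def module_nbytes(module: nn.Module) -> int:
        # numel * element_size over all parameters and buffers: float32 tensors
        # contribute 4 bytes per element, the int64 num_batches_tracked counter 8.
        tensors = list(module.parameters()) + list(module.buffers())
        return sum(t.numel() * t.element_size() for t in tensors)


    m = ModelForTest()
    assert module_nbytes(m.linear1) == 64    # 4*3 weights (48) + 4 biases (16)
    assert module_nbytes(m.batchnorm) == 72  # weight/bias/mean/var (16 each) + counter (8)
    assert module_nbytes(m.linear2) == 100   # 5*4 weights (80) + 5 biases (20)
    assert module_nbytes(m) == 236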