Add new flag to entrypoint
pablomlago committed Jan 21, 2025
1 parent c6614a0 commit 1b90234
Showing 4 changed files with 33 additions and 19 deletions.
13 changes: 8 additions & 5 deletions src/brevitas/graph/equalize.py
@@ -1531,7 +1531,7 @@ def __init__(
orphan_sink: bool = False,
sdpa_regions: bool = False,
rotate_matmul: bool = False,
fuse_rotations: bool = True,
use_parametrized_rotations: bool = False,
full_rotation_method: str = 'had',
return_rewriters: bool = False) -> None:
super(GraphRotationEqualization, self).__init__()
@@ -1549,8 +1549,8 @@ def __init__(
self.full_rotation_method = full_rotation_method
self.return_rewriters = return_rewriters
self.sdpa_regions = sdpa_regions
if not fuse_rotations:
# NOTE: When fuse_rotations=False, parametrized rotations are applied. This changes the attribute __class__
if use_parametrized_rotations:
# NOTE: When use_parametrized_rotations=True, parametrized rotations are applied. This changes the attribute __class__
# of the parametrized module, e.g. to"<class 'torch.nn.utils.parametrize.ParametrizedLinear'>".
# Therefore, algorithms that do type checking might need to use type_before_parametrizations(module),
# instead of only type(module) (see layerwise_layer_handler). Algorithms that rely on in-place modifications
@@ -1559,7 +1559,7 @@ def __init__(
warnings.warn(
"Using parametrized results might break type-checking, which could lead to unexpected behaviour."
)
self.fuse_rotations = fuse_rotations
self.use_parametrized_rotations = use_parametrized_rotations

def rotate_matmuls(self, graph_module):
matmul_nodes = list(graph_module.graph.nodes)
@@ -1650,7 +1650,10 @@ def apply(self,
self.rotate_matmuls(graph_model)
if len(regions) > 0:
rewriters = _apply_rotate(
graph_model, regions, self.full_rotation_method, fuse_rotations=self.fuse_rotations)
graph_model,
regions,
self.full_rotation_method,
fuse_rotations=not self.use_parametrized_rotations)
if self.return_rewriters:
return graph_model, rewriters
else:
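
As an aside on the NOTE above: the following is a minimal sketch, not part of this commit, of why unfused (parametrized) rotations can break type checks and how `type_before_parametrizations` recovers the original class. The `Rotation` parametrization and the layer sizes are illustrative only; everything else is stock PyTorch.

```python
# Minimal sketch, assuming only stock PyTorch: registering a rotation as a weight
# parametrization changes the module's reported class, which is what the NOTE above warns about.
import torch
import torch.nn as nn
from torch.nn.utils import parametrize


class Rotation(nn.Module):
    """Illustrative parametrization that right-multiplies the weight by a fixed orthogonal matrix."""

    def __init__(self, rot: torch.Tensor) -> None:
        super().__init__()
        self.register_buffer("rot", rot)

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight @ self.rot


linear = nn.Linear(4, 4)
rot = torch.linalg.qr(torch.randn(4, 4)).Q  # random orthogonal matrix
parametrize.register_parametrization(linear, "weight", Rotation(rot))

print(type(linear))  # <class 'torch.nn.utils.parametrize.ParametrizedLinear'>
print(parametrize.type_before_parametrizations(linear))  # <class 'torch.nn.modules.linear.Linear'>
```

Algorithms that key off `type(module)` would miss the parametrized linear above, hence the suggestion to use `type_before_parametrizations(module)` instead.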
2 changes: 1 addition & 1 deletion src/brevitas_examples/llm/README.md
@@ -12,7 +12,7 @@

Set the env variable `BREVITAS_JIT=1` to speed up the quantization process. Currently unsupported whenever export is also toggled or with MSE based scales/zero-points.

When using `--rotation fused_no_fx_optimize`, the rotation training procedure relies on the Trainer class (https://huggingface.co/docs/transformers/en/main_classes/trainer). Therefore, training can be further configured by passing arguments accepted by the dataclass TrainingArguments (https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments), e.g. `--learning_rate`, `--weight_decay`, `per_device_train_batch_size`.
When using `--optimize-rotations`, the rotation training procedure relies on the Trainer class (https://huggingface.co/docs/transformers/en/main_classes/trainer). Training can therefore be further configured by passing any argument accepted by the dataclass TrainingArguments (https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments), e.g. `--learning_rate`, `--weight_decay`, `--per_device_train_batch_size`.

```bash
usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
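
For illustration of the README statement above: extra flags such as `--learning_rate` can be mapped onto `transformers.TrainingArguments` roughly as sketched below. This is a hedged sketch only; the entrypoint's `parse_rotation_optimization_args` may parse these arguments differently, and the `--output_dir` value shown is an assumption.

```python
# Hedged sketch: mapping extra CLI flags onto transformers.TrainingArguments.
# Illustrative only; not necessarily how the LLM entrypoint parses its extra arguments.
from transformers import HfArgumentParser, TrainingArguments

extra_args = [
    "--output_dir", "rotation-optimization",  # TrainingArguments requires an output directory
    "--learning_rate", "1e-3",
    "--weight_decay", "0.0",
    "--per_device_train_batch_size", "2",
]
parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(args=extra_args)
print(training_args.learning_rate)  # 0.001
```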
25 changes: 16 additions & 9 deletions src/brevitas_examples/llm/main.py
@@ -85,7 +85,7 @@ def set_seed(seed):
torch.random.manual_seed(seed)


def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool = True):
def fused_rotation_no_fx(model, calibration_loader, args):
with torch.no_grad():
new_model, guards = torch._dynamo.export(model)(**calibration_loader[0])
if hasattr(model, str(torch.nn.functional.scaled_dot_product_attention)):
@@ -105,7 +105,7 @@ def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool =
full_rotation_method=args.rotation_mode,
return_rewriters=True,
sdpa_regions=args.rotation_sdpa_regions,
fuse_rotations=fuse_rotations)
use_parametrized_rotations=args.optimize_rotations)
new_model, rewriters = eq.apply(new_model)
rewriters = fix_rewriter(rewriters, model, 'weight')
for r in rewriters:
@@ -149,7 +149,9 @@ def model_export(model, ref_input, args):


def validate(args, extra_args: Optional[List[str]] = None):
if args.rotation != "fused_no_fx_optimize":
if args.optimize_rotations:
assert args.rotation in ['fx', 'fused_no_fx'], f"Rotations can only be optimized if --rotation=fx or --rotation=fused_no_fx"
else:
assert extra_args is None or len(extra_args) == 0, f"The following unknown arguments were passed: {[extra_arg for extra_arg in extra_args if extra_arg.startswith('--')]}"
if args.functional_sdpa_quant:
assert args.input_scale_type == 'dynamic' or args.input_bit_width is None, "Functional SDPA Quant requires dynamic activation quantization"
@@ -279,7 +281,7 @@ def quantize_llm(args, extra_args=None):
device=None,
fuse_sequences=args.fuse_sequences)

if args.rotation in ["fused_no_fx_optimize"]:
if args.optimize_rotations:
# Extra arguments should be used as training arguments for rotation optimization
rot_optimization_args = parse_rotation_optimization_args(extra_args=extra_args)
# Load the data for rotation optimization
@@ -353,16 +355,15 @@ def quantize_llm(args, extra_args=None):
eq = GraphRotationEqualization(
orphan_sink=args.rotation_orphan_sink,
full_rotation_method=args.rotation_mode,
sdpa_regions=args.rotation_sdpa_regions)
sdpa_regions=args.rotation_sdpa_regions,
use_parametrized_rotations=args.optimize_rotations)
model = eq.apply(model)
remove_hooks(model)
elif args.rotation == 'layerwise':
eq = LayerwiseActivationRotation()
model = eq.apply(model)
elif args.rotation == 'fused_no_fx':
fused_rotation_no_fx(model, calibration_loader, args)
elif args.rotation == 'fused_no_fx_optimize':
fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations=False)

if args.weight_equalization:
print("Apply weight equalization...")
@@ -484,7 +485,7 @@ def quantize_llm(args, extra_args=None):
for k, v in dict_hooks.items():
k._hf_hook.post_forward = v

if args.rotation in ['fused_no_fx_optimize']:
if args.optimize_rotations:
apply_rotation_optimization(
model=model,
tokenizer=tokenizer,
@@ -855,8 +856,14 @@ def parse_args(args, override_defaults={}):
'--rotation',
type=str,
default=None,
choices=['fx', 'layerwise', 'fused_no_fx', 'fused_no_fx_optimize'],
choices=['fx', 'layerwise', 'fused_no_fx'],
help='Apply graph rotation equalization')
parser.add_argument(
"--optimize-rotations",
action="store_true",
default=False,
help="Whether to optimize the rotations (default: %(default)s).",
)
parser.add_argument(
'--rotation-mode',
default='had',
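
Taken together, the new `--optimize-rotations` flag and the check added to `validate()` behave roughly as in the self-contained sketch below (argument handling is simplified relative to main.py, which has many more options):

```python
# Self-contained sketch of the new flag's behaviour; simplified relative to main.py.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--rotation",
    type=str,
    default=None,
    choices=["fx", "layerwise", "fused_no_fx"],
    help="Apply graph rotation equalization")
parser.add_argument(
    "--optimize-rotations",
    action="store_true",
    default=False,
    help="Whether to optimize the rotations (default: %(default)s).")

args = parser.parse_args(["--rotation", "fused_no_fx", "--optimize-rotations"])
if args.optimize_rotations:
    # Mirrors the assertion added to validate(): optimization requires an fx-based rotation mode.
    assert args.rotation in ("fx", "fused_no_fx"), \
        "Rotations can only be optimized if --rotation=fx or --rotation=fused_no_fx"
print(args.rotation, args.optimize_rotations)  # fused_no_fx True
```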
12 changes: 8 additions & 4 deletions tests/brevitas_examples/test_llm.py
@@ -838,7 +838,8 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"weight_bit_width": 4,
"input_bit_width": None,
"replace_rmsnorm": True,
"rotation": "fused_no_fx_optimize",
"rotation": "fused_no_fx",
"optimize_rotations": True,
"rotation_orphan_sink": True,
"rotation_mode": "ort",
"nsamples_rot_calibration": 2,
@@ -865,7 +866,8 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"weight_bit_width": 4,
"input_bit_width": None,
"replace_rmsnorm": True,
"rotation": "fused_no_fx_optimize",
"rotation": "fused_no_fx",
"optimize_rotations": True,
"rotation_orphan_sink": False,
"rotation_mode": "ort",
"nsamples_rot_calibration": 2,
@@ -892,7 +894,8 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"weight_bit_width": 4,
"input_bit_width": None,
"replace_rmsnorm": True,
"rotation": "fused_no_fx_optimize",
"rotation": "fused_no_fx",
"optimize_rotations": True,
"rotation_orphan_sink": True,
"rotation_mode": "had",
"nsamples_rot_calibration": 2,
@@ -919,7 +922,8 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"weight_bit_width": 4,
"input_bit_width": None,
"replace_rmsnorm": True,
"rotation": "fused_no_fx_optimize",
"rotation": "fused_no_fx",
"optimize_rotations": True,
"rotation_orphan_sink": False,
"rotation_mode": "had",
"nsamples_rot_calibration": 2,
