
Feat (equalize): enable parametrized rotations #1148

Merged
merged 20 commits into Xilinx:dev from feat-unfused-rotations on Jan 14, 2025

Conversation

pablomlago (Collaborator)

Reason for this PR

Changes Made in this PR

Testing Summary

Risk Highlight

  • This PR includes code from another work (please detail).
  • This PR contains API-breaking changes.
  • This PR depends on work in another PR (please provide links/details).
  • This PR introduces new dependencies (please detail).
  • There are coverage gaps not covered by tests.
  • Documentation updates required in subsequent PR.

Checklist

  • Code comments added to any hard-to-understand areas, if applicable.
  • Changes generate no new warnings.
  • Updated any relevant tests, if applicable.
  • No conflicts with destination dev branch.
  • I reviewed my own code changes.
  • Initial CI/CD passing.
  • 1+ reviews given, and any review issues addressed and approved.
  • Post-review full CI/CD passing.

@pablomlago pablomlago requested a review from Giuseppe5 January 9, 2025 20:06
@pablomlago pablomlago changed the title from "[DO NOT MERGE] Enable parametrized rotations" to "Enable parametrized rotations" on Jan 10, 2025
@pablomlago pablomlago force-pushed the feat-unfused-rotations branch from 7d27050 to 3888df4 on January 11, 2025 10:39
return model


class RotationWeightParametrization(torch.nn.Module):
Collaborator:

Not sure if this should live here, maybe pytorch utils?

Collaborator (Author):

It seems too specific to me to live in the PyTorch utils. Maybe it's worth having a graph/rotation_utils.py?

Collaborator:

Let's do utils/rotation_utils.

Maybe at some point we will also have a utils/equalize_utils, and we'll leave only a few core classes in graph/equalize.py.

if old_module is self.old_module_instance:
    # register the parametrization in the old_module
    parametrize.register_parametrization(
        old_module, self.tensor_name, self.parametrization_module, unsafe=True)
Collaborator:

What's the unsafe flag for?

Collaborator (Author):

The safety check verifies that the dtype and shape of the parametrized tensor are the same as the original one. Since it computes parametrization(self.weight) and then runs the checks on the result, the parametrizations take some time to register. I had disabled it for faster experimentation, but it's probably worth doing those checks.
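As a point of reference, a minimal sketch of the trade-off mentioned here, using torch.nn.utils.parametrize directly; the ScaleBy2 module and the Linear layer are placeholders, not code from this PR:

import torch
import torch.nn.utils.parametrize as parametrize

class ScaleBy2(torch.nn.Module):
    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        # Keeps the shape and dtype of the input unchanged
        return 2.0 * weight

linear = torch.nn.Linear(4, 4)

# Default (unsafe=False): the parametrization is evaluated at registration time
# so its output can be checked for consistency against the original tensor.
parametrize.register_parametrization(linear, "weight", ScaleBy2())

# unsafe=True skips those checks, making registration faster, but shape/dtype
# mismatches would only surface later, when the parametrized weight is used.
parametrize.register_parametrization(linear, "weight", ScaleBy2(), unsafe=True)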

def apply(self, model: GraphModule) -> GraphModule:
    for old_module in model.modules():
        if old_module is self.old_module_instance:
            if hasattr(old_module, 'allocate_params'):
Collaborator:

Add a comment, maybe a TODO, on whether this should live here or outside the apply function. I'm not sure about it.

return weight


class ModuleInstanceFuseRotationWeights(Transform):
Collaborator:

This seems like a very specific function; add some comments.

Collaborator:

Is this equivalent to the old behaviour?

Collaborator (Author):

Yes, the point of adding this function is to avoid duplicating the code between the parametrization module and the in-place transformation. Also, by having a specific rewriter for the in-place fusing of the rotation, we can avoid making any modifications at all to the FX model, thus preventing potential inconsistencies in fused_no_fx.
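A rough sketch of the idea described in this comment: one rotation function shared by the fused (in-place) path and the parametrized path. All names here are illustrative assumptions, not the actual implementation in this PR:

import torch
import torch.nn.utils.parametrize as parametrize

def rotate_weight(weight: torch.Tensor, rot_mat: torch.Tensor) -> torch.Tensor:
    # Single place where the rotation math lives, so the fused and unfused
    # paths cannot drift apart.
    return torch.matmul(weight, rot_mat)

class RotationParametrization(torch.nn.Module):
    def __init__(self, rot_mat: torch.Tensor) -> None:
        super().__init__()
        self.rot_mat = rot_mat

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return rotate_weight(weight, self.rot_mat)

def apply_rotation(module: torch.nn.Module, rot_mat: torch.Tensor, fuse_rotations: bool) -> None:
    if fuse_rotations:
        # Fused path: rewrite the weight once, leaving no parametrization behind.
        module.weight.data = rotate_weight(module.weight.data, rot_mat)
    else:
        # Unfused path: the rotation is recomputed whenever .weight is accessed.
        parametrize.register_parametrization(module, "weight", RotationParametrization(rot_mat))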

def __init__(
        self, old_module_instance: Module, tensor_name: str,
        parametrization_module: Module) -> None:
    self.old_module_instance = old_module_instance
Collaborator:

I believe the names do not correctly represent what the variables do.

        self.axis = axis
        self.K = K

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
Collaborator:

Change weight to tensor

raise RuntimeError("Not supported yet")
module.weight.data = weight
if fuse_rotations:
rewriter = ModuleInstanceFuseRotationWeights(
Collaborator:

If this is equivalent to the old behavior, it seems a bit over-complicated. Do we need it?

Collaborator (Author):

I did it mainly to prevent duplication. Before, we had the same logic for sinks/sources, and when the rotations are not fused, I would have had to duplicate that logic again in the parametrization module.



@pytest_cases.parametrize('N', [1, 2, 3], ids=lambda x: f"N={x}")
def test_composition_unfused_rotations(N):
Collaborator:

Thanks for the tests!

@pablomlago pablomlago force-pushed the feat-unfused-rotations branch from 20cffcc to e6fb34f on January 13, 2025 14:37
tied_param_name_split = tied_param_name.split(".")
# Check if the tied parameter is the original parameter in the module
if len(tied_param_name_split) >= 3 and tied_param_name_split[
-3] == "parametrizations" and tied_param_name_split[-1] == "original":
Collaborator:

Why [-3]? It seems pretty arbitrary. What if the hierarchy is smaller than 3?

Collaborator:

Nevermind

Collaborator (Author):

I'll add a comment to explain so it's clearer
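For reference, a small self-contained sketch of why the check on [-3] and [-1] works; the Identity parametrization and the toy model below are placeholders:

import torch
import torch.nn.utils.parametrize as parametrize

class Identity(torch.nn.Module):
    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
parametrize.register_parametrization(model[0], "weight", Identity())

# After registration, the original tensor is exposed under
# "<module path>.parametrizations.<tensor name>.original", so splitting the
# fully qualified name on "." gives "parametrizations" at index -3 and
# "original" at index -1.
for name, _ in model.named_parameters():
    parts = name.split(".")
    if len(parts) >= 3 and parts[-3] == "parametrizations" and parts[-1] == "original":
        print(name)  # e.g. "0.parametrizations.weight.original"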

@pablomlago pablomlago force-pushed the feat-unfused-rotations branch from f9ae9ae to 82287b0 on January 13, 2025 22:21
@Giuseppe5 Giuseppe5 self-requested a review January 13, 2025 22:46
@pablomlago pablomlago force-pushed the feat-unfused-rotations branch from beb2cbc to b1b59a2 on January 14, 2025 11:22
@Giuseppe5 Giuseppe5 requested review from Giuseppe5 and removed request for Giuseppe5 January 14, 2025 12:23
@Giuseppe5 Giuseppe5 requested review from Giuseppe5 and removed request for Giuseppe5 January 14, 2025 14:05
@Giuseppe5 Giuseppe5 merged commit 52cfffd into Xilinx:dev Jan 14, 2025
393 of 396 checks passed
@pablomlago pablomlago changed the title from "Enable parametrized rotations" to "Feat (equalize): enable parametrized rotations" on Jan 15, 2025