Feat (equalize): enable parametrized rotations #1148
Conversation
Force-pushed from 7d27050 to 3888df4
src/brevitas/graph/base.py (Outdated)
    return model


class RotationWeightParametrization(torch.nn.Module):
Not sure if this should live here, maybe pytorch utils?
It seems too specific to me to live in pytorch utils. Maybe it's worth having a graph/rotation_utils.py?
Let's do utils/rotation_utils. Maybe at some point we will also have a utils/equalize_utils, and we'll leave only a few core classes in graph/equalize.py.
src/brevitas/graph/base.py (Outdated)
if old_module is self.old_module_instance:
    # register the parametrization in the old_module
    parametrize.register_parametrization(
        old_module, self.tensor_name, self.parametrization_module, unsafe=True)
What's the unsafe flag for?
With unsafe=False (the default), it checks that the dtype and shape of the parametrized tensor are the same as the original one's. Since this means computing parametrization(self.weight) and then running the checks on the result, the parametrizations take some time to register. I had disabled the checks for faster experimentation, but it's probably worth doing them.
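For reference, a minimal standalone sketch of the trade-off being discussed, using the public torch.nn.utils.parametrize API (module and parametrization names here are illustrative, not taken from the PR):

    import torch
    from torch.nn.utils import parametrize

    class Identity(torch.nn.Module):
        def forward(self, X):
            return X

    linear = torch.nn.Linear(4, 4)
    # Default (unsafe=False): PyTorch evaluates parametrization(weight) once up
    # front and verifies the result keeps the original dtype and shape.
    parametrize.register_parametrization(linear, "weight", Identity())
    # unsafe=True skips that upfront consistency check, so registration is
    # faster, but shape/dtype mismatches would only surface at use time.
    parametrize.register_parametrization(linear, "weight", Identity(), unsafe=True)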
src/brevitas/graph/base.py (Outdated)
def apply(self, model: GraphModule) -> GraphModule:
    for old_module in model.modules():
        if old_module is self.old_module_instance:
            if hasattr(old_module, 'allocate_params'):
Add a comment, maybe a TODO, about whether this should live here or outside the apply function. I'm not sure about it.
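For context, the guard under discussion follows roughly this pattern (a sketch only; it assumes allocate_params/offload_params are hooks attached by an offloading wrapper, which the diff implies but does not show):

    # Sketch: if the module's weights may be offloaded (e.g. by an
    # accelerate-style wrapper that attaches these hooks), materialize
    # them before mutating, then hand them back afterwards.
    if hasattr(old_module, 'allocate_params'):
        old_module.allocate_params(old_module)
    # ... mutate old_module.weight in place ...
    if hasattr(old_module, 'offload_params'):
        old_module.offload_params(old_module)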
src/brevitas/graph/base.py (Outdated)
    return weight


class ModuleInstanceFuseRotationWeights(Transform):
This seems like a very specific function; add some comments.
Is this equivalent to the old behaviour?
Yes. The point of adding this function is to avoid duplicating code between the parametrization module and the in-place transformation. Also, by having a specific rewriter for the in-place fusing of the rotation, we avoid making any modifications at all to the FX model, thus preventing potential inconsistencies in fused_no_fx.
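The deduplication argument can be made concrete with a sketch (the helper and class names below are hypothetical stand-ins, not the PR's actual signatures):

    import torch

    def _rotate(weight: torch.Tensor, rot_mat: torch.Tensor, axis: int) -> torch.Tensor:
        # Hypothetical shared helper: apply the rotation along the given axis.
        if axis == 0:
            return rot_mat.t() @ weight
        elif axis == 1:
            return weight @ rot_mat
        else:
            raise RuntimeError("Not supported yet")

    # Used by the parametrization module (unfused path)...
    class RotationParametrization(torch.nn.Module):
        def __init__(self, rot_mat: torch.Tensor, axis: int):
            super().__init__()
            self.rot_mat = rot_mat
            self.axis = axis

        def forward(self, weight: torch.Tensor) -> torch.Tensor:
            return _rotate(weight, self.rot_mat, self.axis)

    # ...and by the in-place rewriter (fused path), so the axis-handling
    # logic exists in exactly one place.
    def fuse_rotation(module: torch.nn.Module, rot_mat: torch.Tensor, axis: int) -> None:
        module.weight.data = _rotate(module.weight.data, rot_mat, axis)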
src/brevitas/graph/base.py (Outdated)
def __init__(
        self, old_module_instance: Module, tensor_name: str,
        parametrization_module: Module) -> None:
    self.old_module_instance = old_module_instance
I believe the names don't correctly represent what these variables do.
src/brevitas/graph/base.py (Outdated)
    self.axis = axis
    self.K = K

def forward(self, weight: torch.Tensor) -> torch.Tensor:
Change weight to tensor.
src/brevitas/graph/equalize.py (Outdated)
    raise RuntimeError("Not supported yet")
module.weight.data = weight
if fuse_rotations:
    rewriter = ModuleInstanceFuseRotationWeights(
If this is equivalent to the old behavior, it seems a bit over-complicated. Do we need it?
I did it mainly to prevent duplication. Before, we had the same logic for sinks/sources, and when we do not fuse the rotations I would have had to duplicate that logic again in the parametrization module.
@pytest_cases.parametrize('N', [1, 2, 3], ids=lambda x: f"N={x}")
def test_composition_unfused_rotations(N):
Thanks for the tests!
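For background on what the test name refers to: parametrizations registered on the same tensor compose in registration order, so stacking N unfused rotations is just N register_parametrization calls. A minimal sketch (the rotation module and sizes are illustrative, not the test's actual code):

    import torch
    from torch.nn.utils import parametrize

    class Rotation(torch.nn.Module):
        def __init__(self, rot_mat: torch.Tensor):
            super().__init__()
            self.rot_mat = rot_mat

        def forward(self, weight: torch.Tensor) -> torch.Tensor:
            return weight @ self.rot_mat

    linear = torch.nn.Linear(4, 4)
    N = 3
    for _ in range(N):
        # Random orthogonal matrix via QR decomposition.
        rot_mat, _ = torch.linalg.qr(torch.randn(4, 4))
        # Each call appends to the chain: weight -> R1 -> R2 -> ... -> RN.
        parametrize.register_parametrization(linear, "weight", Rotation(rot_mat))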
Force-pushed from 20cffcc to e6fb34f
tied_param_name_split = tied_param_name.split(".")
# Check if the tied parameter is the original parameter in the module
if len(tied_param_name_split) >= 3 and tied_param_name_split[
        -3] == "parametrizations" and tied_param_name_split[-1] == "original":
Why [-3]? Seems pretty arbitrary. What if the hierarchy is smaller than 3?
Nevermind
I'll add a comment to explain so it's clearer
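For readers of the thread, the index stops looking arbitrary once you see how torch.nn.utils.parametrize renames tensors; a quick illustration:

    import torch
    from torch.nn.utils import parametrize

    class Identity(torch.nn.Module):
        def forward(self, X):
            return X

    model = torch.nn.Sequential(torch.nn.Linear(4, 4))
    parametrize.register_parametrization(model[0], "weight", Identity())
    for name, _ in model.named_parameters():
        print(name)
    # 0.parametrizations.weight.original
    # 0.bias
    # The original tensor always ends up under
    # <prefix>.parametrizations.<name>.original, so split(".")[-3] ==
    # "parametrizations" and split(".")[-1] == "original" identify it,
    # and the len >= 3 check guards against shorter, non-parametrized names.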
Force-pushed from f9ae9ae to 82287b0
Force-pushed from beb2cbc to b1b59a2