Skip to content

Commit

Permalink
Prevent assignments of FX modules to model
Browse files Browse the repository at this point in the history
  • Loading branch information
pablomlago committed Dec 30, 2024
1 parent 39ce837 commit 1dde589
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
26 changes: 26 additions & 0 deletions src/brevitas/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from abc import abstractmethod
import inspect
from inspect import getcallargs
from typing import Any, Dict, Type

import torch
from torch.nn import Module
Expand Down Expand Up @@ -189,6 +190,31 @@ def apply(self, model: GraphModule) -> GraphModule:
return model


class ModuleInstanceWrapModule(Transform):

def __init__(
self,
old_module_instance: Module,
wrapper_class: Type[Module],
module_attribute: str,
kwargs_wrapper: Dict[str, Any]):
self.old_module_instance = old_module_instance
self.wrapper_class = wrapper_class
self.module_attribute = module_attribute
self.kwargs_wrapper = kwargs_wrapper

def apply(self, model: GraphModule) -> GraphModule:
for old_module in model.modules():
if old_module is self.old_module_instance:
kwargs = {self.module_attribute: self.old_module_instance}
kwargs.update(self.kwargs_wrapper)
new_module_instance = self.wrapper_class(**kwargs)
# init the new module based on the old one
replace_module(model, old_module, new_module_instance)
break
return model


class ModuleToModuleByName(ModuleToModule):

def __init__(self, old_module_name, new_module_class, **kwargs):
Expand Down
6 changes: 4 additions & 2 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from brevitas.graph.base import GraphTransform
from brevitas.graph.base import InsertModuleCallAfter
from brevitas.graph.base import ModuleInstanceToModuleInstance
from brevitas.graph.base import ModuleInstanceWrapModule
from brevitas.graph.base import Transform
from brevitas.graph.hadamard import get_hadK
from brevitas.graph.hadamard import matmul_hadU
Expand Down Expand Up @@ -1366,8 +1367,9 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
module.offload_params(module)

if insert_rotation_module and len(region.srcs) == 0:
rewriter = ModuleInstanceToModuleInstance(
module, RotatedModule(had_mat=rot_mat, k=K, layer=module))
rewriter = ModuleInstanceWrapModule(
module, RotatedModule, "layer", {
"had_mat": rot_mat, "k": K})
rewriters.append(rewriter)
for r in rewriters:
model = r.apply(model)
Expand Down

0 comments on commit 1dde589

Please sign in to comment.