Skip to content

Commit

Permalink
Lower PyTorch version to ensure test is run
Browse files Browse the repository at this point in the history
  • Loading branch information
pablomlago committed Jan 11, 2025
1 parent de80207 commit 1a077ca
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tests/brevitas/graph/test_equalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch.nn.utils.parametrize as parametrize
from torchvision import models

from brevitas import torch_version
from brevitas.fx import symbolic_trace
from brevitas.graph.base import ModuleInstanceRegisterParametrization
from brevitas.graph.base import RotationWeightParametrization
Expand Down Expand Up @@ -405,7 +406,7 @@ def compare_model_weights(model_fused, model_unfused, classes_to_compare=(nn.Lin
assert torch.allclose(getattr(module_fused, tensor_name), getattr(module_unfused, tensor_name), atol=0.0, rtol=0.0), f"Tensor {tensor_name} does not match for module {name_module_fused}"


@requires_pt_ge('2.4')
@requires_pt_ge('2.3.1')
@pytest_cases.parametrize(
'mask',
itertools.product([False, True], repeat=3),
Expand Down

0 comments on commit 1a077ca

Please sign in to comment.