From ae7da1835307b0ab37e8df7fad234cd23df843a3 Mon Sep 17 00:00:00 2001 From: Alessandro Pappalardo <1934033+volcacius@users.noreply.github.com> Date: Mon, 17 Jul 2023 18:22:26 +0200 Subject: [PATCH] Feat (core): add permute_dims to all reshape fns (#671) --- src/brevitas/core/function_wrapper/shape.py | 24 +++++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/brevitas/core/function_wrapper/shape.py b/src/brevitas/core/function_wrapper/shape.py index 4cda8da20..49f359a4a 100644 --- a/src/brevitas/core/function_wrapper/shape.py +++ b/src/brevitas/core/function_wrapper/shape.py @@ -60,7 +60,7 @@ class OverOutputChannelView(brevitas.jit.ScriptModule): torch.Size([16, 200]) """ - def __init__(self, permute_dims: Optional[Tuple[int, ...]]) -> None: + def __init__(self, permute_dims: Optional[Tuple[int, ...]] = None) -> None: super(OverOutputChannelView, self).__init__() if permute_dims is not None: self.permute_impl = PermuteDims(permute_dims) @@ -86,13 +86,18 @@ class OverBatchOverTensorView(brevitas.jit.ScriptModule): torch.Size([8, 250]) """ - def __init__(self) -> None: + def __init__(self, permute_dims: Optional[Tuple[int, ...]] = None) -> None: super(OverBatchOverTensorView, self).__init__() + if permute_dims is not None: + self.permute_impl = PermuteDims(permute_dims) + else: + self.permute_impl = Identity() @brevitas.jit.script_method def forward(self, x: torch.Tensor): - shape = over_batch_over_tensor(x) - return x.reshape(shape) + y = self.permute_impl(x) + shape = over_batch_over_tensor(y) + return y.reshape(shape) class OverBatchOverOutputChannelView(brevitas.jit.ScriptModule): @@ -107,13 +112,18 @@ class OverBatchOverOutputChannelView(brevitas.jit.ScriptModule): torch.Size([8, 10, 25]) """ - def __init__(self) -> None: + def __init__(self, permute_dims: Optional[Tuple[int, ...]] = None) -> None: super(OverBatchOverOutputChannelView, self).__init__() + if permute_dims is not None: + self.permute_impl = PermuteDims(permute_dims) + else: + self.permute_impl = Identity() @brevitas.jit.script_method def forward(self, x: torch.Tensor): - shape = over_batch_over_output_channels(x) - return x.reshape(shape) + y = self.permute_impl(x) + shape = over_batch_over_output_channels(y) + return y.reshape(shape) class StatsInputViewShapeImpl(object):