Commit: Fix po2 for float quant

Giuseppe5 committed Sep 25, 2024
1 parent 6398a6b commit ddb09a5
Showing 4 changed files with 32 additions and 28 deletions.
src/brevitas/core/quant/float.py (5 changes: 3 additions & 2 deletions)

@@ -67,12 +67,13 @@ def __init__(

     @brevitas.jit.script_method
     def quantize(self, x: torch.Tensor):
-        scale = self.scaling_impl(x)
 
         if self.float_scaling_impl is not None:
             float_scaling_impl_value = self.float_scaling_impl(
                 self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
-            scale = scale / float_scaling_impl_value
+        else:
+            float_scaling_impl_value = torch.tensor(1.).type_as(x)
+        scale = self.scaling_impl(x, float_scaling_impl_value)
         x = self.input_view_impl(x)
         scaled_x = x / scale
         internal_scale = float_internal_scale(
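The motivation for routing the threshold into scaling_impl: with a power-of-two (po2) scale restriction, the restriction used to be applied before the division by the float-format threshold, so the final scale was generally not a power of two. A minimal standalone sketch of the difference (not Brevitas code; po2_restrict and the FP8 e4m3 threshold of 448 are illustrative assumptions):

import torch

def po2_restrict(scale: torch.Tensor) -> torch.Tensor:
    # Round a scale to the nearest power of two (illustrative po2 restriction).
    return torch.exp2(torch.round(torch.log2(scale)))

raw_scale = torch.tensor(3.7)
threshold = torch.tensor(448.)  # assumed: max representable value of FP8 e4m3

old = po2_restrict(raw_scale) / threshold  # restrict first: 4/448, not a power of two
new = po2_restrict(raw_scale / threshold)  # divide first: 2**-7, a true power of two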
src/brevitas/core/quant/int.py (3 changes: 1 addition & 2 deletions)

@@ -149,9 +149,8 @@ def __init__(
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         bit_width = self.msb_clamp_bit_width_impl()
-        threshold = self.scaling_impl(x)
         int_threshold = self.int_scaling_impl(bit_width)
-        scale = threshold / int_threshold
+        scale = self.scaling_impl(x, int_threshold)
         zero_point = self.zero_point_impl(x, scale, bit_width)
         y = self.int_quant(scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width
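The integer path adopts the same convention: the caller hands the integer threshold (e.g. 2**(bit_width - 1) for signed quantization) to scaling_impl instead of dividing its output. A toy sketch of the convention (assumed names, not Brevitas code):

import torch

class ToyConstScaling(torch.nn.Module):
    # Constant scaling impl following the new (x, threshold) convention.

    def __init__(self, scaling_init: float):
        super().__init__()
        self.value = torch.tensor(scaling_init)

    def forward(self, x: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
        # Divide by the threshold first, then restrict, so a po2
        # restriction still yields a true power of two.
        scale = self.value / threshold
        return torch.exp2(torch.round(torch.log2(scale)))

scaling_impl = ToyConstScaling(6.0)
int_threshold = torch.tensor(2. ** (8 - 1))          # 128 for signed 8-bit
scale = scaling_impl(torch.randn(4), int_threshold)  # 2**round(log2(6/128)) = 2**-4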
src/brevitas/core/scaling/runtime.py (16 changes: 8 additions & 8 deletions)

@@ -51,9 +51,9 @@ def __init__(
             device)
 
     @brevitas.jit.script_method
-    def forward(self, ignored: torch.Tensor) -> torch.Tensor:
+    def forward(self, ignored: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
         stats = self.parameter_list_stats()
-        return self.stats_scaling_impl(stats)
+        return self.stats_scaling_impl(stats, threshold)
 
 
 class _StatsScaling(brevitas.jit.ScriptModule):
@@ -80,8 +80,8 @@ def __init__(
         self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
 
     @brevitas.jit.script_method
-    def forward(self, stats: torch.Tensor) -> torch.Tensor:
-        stats = self.restrict_scaling_pre(stats)
+    def forward(self, stats: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
+        stats = self.restrict_scaling_pre(stats / threshold)
         stats = self.affine_rescaling(stats)
         stats = self.restrict_clamp_scaling(stats)
         return stats
@@ -120,9 +120,9 @@ def __init__(
             device)
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
         stats = self.runtime_stats(x)
-        return self.stats_scaling_impl(stats)
+        return self.stats_scaling_impl(stats, threshold)
 
 
 class _AffineRescaling(brevitas.jit.ScriptModule):
@@ -179,9 +179,9 @@ def __init__(
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
 
     @brevitas.jit.script_method
-    def forward(self, stats_input) -> torch.Tensor:
+    def forward(self, stats_input: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
         stats_input_reshaped = self.input_view_impl(stats_input)
-        out = self.scaling_stats_impl(stats_input_reshaped)
+        out = self.scaling_stats_impl(stats_input_reshaped) / threshold
         # Scaling min val
         out = self.restrict_clamp_scaling(out)
         return out
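All four runtime modules follow one ordering: divide the statistic by the threshold, then restrict, then clamp. A self-contained sketch of that ordering (toy names; the po2 rounding and min-val clamp stand in for restrict_clamp_scaling):

import torch

def toy_stats_scaling(stats: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
    scale = stats / threshold                           # divide first...
    scale = torch.exp2(torch.round(torch.log2(scale)))  # ...then restrict (po2)
    return torch.clamp_min(scale, 1e-8)                 # ...then clamp (scaling_min_val)

abs_max = torch.randn(32).abs().max()  # a runtime statistic, e.g. absmax
scale = toy_stats_scaling(abs_max, torch.tensor(448.))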
src/brevitas/core/scaling/standalone.py (36 changes: 20 additions & 16 deletions)

@@ -77,8 +77,8 @@ def __init__(
         self.value = StatelessBuffer(torch.tensor(scaling_init, dtype=dtype, device=device))
 
     @brevitas.jit.script_method
-    def forward(self, placeholder: Tensor) -> Tensor:
-        value = self.value()
+    def forward(self, placeholder: Tensor, threshold: torch.Tensor) -> Tensor:
+        value = self.value() / threshold
         restricted_value = self.restrict_clamp_scaling(value)
         return restricted_value

@@ -149,8 +149,8 @@ def __init__(
         self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
 
     @brevitas.jit.script_method
-    def forward(self, placeholder: Tensor) -> Tensor:
-        value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
+    def forward(self, placeholder: Tensor, threshold: torch.Tensor) -> Tensor:
+        value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value / threshold))
         return value
 
     def _load_from_state_dict(
@@ -201,19 +201,21 @@ def __init__(
         self.value = Parameter(torch.full(scaling_shape, 1.0, dtype=dtype, device=device))
 
     @brevitas.jit.script_method
-    def forward(self, ignored: torch.Tensor) -> torch.Tensor:
+    def forward(self, ignored: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
         if self.init_done:
-            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
+            value = abs_binary_sign_grad(
+                self.stats_scaling_impl.restrict_clamp_scaling(self.value / threshold))
             return value
         else:
             stats = self.parameter_list_stats()
             # workaround to avoid find_ununsed_parameter=True in DDP
             stats = stats + 0. * self.value
             if self.local_loss_mode:
-                return self.stats_scaling_impl(stats)
+                return self.stats_scaling_impl(stats, threshold)
             stats = self.restrict_inplace_preprocess(stats)
             inplace_tensor_mul(self.value.detach(), stats)
-            value = abs_binary_sign_grad(self.stats_scaling_impl.restrict_clamp_scaling(self.value))
+            value = abs_binary_sign_grad(
+                self.stats_scaling_impl.restrict_clamp_scaling(self.value / threshold))
             self.init_done = True
             return value

@@ -317,7 +319,7 @@ def __init__(
         self.restrict_preprocess = Identity()
 
     @brevitas.jit.script_method
-    def training_forward(self, stats_input: Tensor) -> Tensor:
+    def training_forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor:
         if self.counter < self.collect_stats_steps:
             stats_input = self.stats_input_view_shape_impl(stats_input)
             stats = self.stats(stats_input)
@@ -334,25 +336,27 @@ def training_forward(self, stats_input: Tensor) -> Tensor:
             inplace_momentum_update(
                 self.buffer, clamped_stats.detach(), self.momentum, self.counter, new_counter)
             self.counter = new_counter
-            return abs_binary_sign_grad(clamped_stats)
+            return abs_binary_sign_grad(clamped_stats) / threshold
         elif self.counter == self.collect_stats_steps:
             self.restrict_inplace_preprocess(self.buffer)
             inplace_tensor_mul(self.value.detach(), self.buffer)
             self.counter = self.counter + 1
-            return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(self.value)))
+            return abs_binary_sign_grad(
+                self.clamp_scaling(self.restrict_scaling(self.value / threshold)))
         else:
-            return abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(self.value)))
+            return abs_binary_sign_grad(
+                self.clamp_scaling(self.restrict_scaling(self.value / threshold)))
 
     @brevitas.jit.script_method
-    def forward(self, stats_input: Tensor) -> Tensor:
+    def forward(self, stats_input: Tensor, threshold: torch.Tensor) -> Tensor:
         if self.training:
-            return self.training_forward(stats_input)
+            return self.training_forward(stats_input, threshold)
         else:
             if self.counter <= self.collect_stats_steps:
-                out = self.buffer
+                out = self.buffer / threshold
                 out = self.restrict_preprocess(out)
             else:
-                out = self.value
+                out = self.value / threshold
             out = abs_binary_sign_grad(self.clamp_scaling(self.restrict_scaling(out)))
             return out

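One consequence for callers: every scaling forward in this commit now takes a threshold tensor, and call sites without a meaningful threshold pass a neutral value of one, as the float quantizer does when float_scaling_impl is None. A hedged sketch of both call styles (toy scaling_impl standing in for any module touched here):

import torch

def scaling_impl(x: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for any of the (input, threshold) forwards in this commit.
    return x.abs().max() / threshold

x = torch.randn(64)
neutral = torch.tensor(1.).type_as(x)                # reproduces old behaviour
scale_plain = scaling_impl(x, neutral)
scale_int8 = scaling_impl(x, torch.tensor(2. ** 7))  # signed 8-bit threshold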
