Fix (llm): small fixes to LLM (#1035)
Giuseppe5 authored Oct 8, 2024
1 parent 686beb7 commit db6c560
Showing 8 changed files with 79 additions and 31 deletions.
5 changes: 5 additions & 0 deletions src/brevitas/core/function_wrapper/shape.py
@@ -165,6 +165,11 @@ def __init__(self, expanded_groupwise_shape, group_size, group_dim) -> None:

     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
+        # This case is a bit tricky, but we can end up here in two situations:
+        # - When quantizing the zero point, whose shape has already been expanded to match the scale (unpadded, but the padding is not needed)
+        # - During groupwise HQO quantization, where the weight has already been padded and expanded
+        if len(x.shape) == len(self.expanded_groupwise_shape):
+            return x
         y = torch.nn.functional.pad(
             x, padding(x, self.group_size, self.group_dim), mode='constant', value=0.)
         y = y.view(self.expanded_groupwise_shape)
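To make the pad-and-view step above concrete, here is a minimal standalone sketch of the same idea (a hypothetical `to_groupwise` helper, not the Brevitas implementation, which precomputes `expanded_groupwise_shape` and uses its own `padding` utility):

```python
import torch

def to_groupwise(x: torch.Tensor, group_size: int, group_dim: int) -> torch.Tensor:
    # Pad the grouped dimension up to a multiple of group_size.
    pad = (group_size - x.shape[group_dim] % group_size) % group_size
    # F.pad takes (left, right) pairs starting from the last dimension.
    pad_spec = [0, 0] * (x.dim() - 1 - group_dim) + [0, pad]
    y = torch.nn.functional.pad(x, pad_spec, mode='constant', value=0.)
    # View: split the grouped dimension into (num_groups, group_size).
    new_shape = list(y.shape)
    new_shape[group_dim:group_dim + 1] = [y.shape[group_dim] // group_size, group_size]
    return y.view(new_shape)

w = torch.randn(64, 100)             # 100 is not a multiple of the group size
print(to_groupwise(w, 32, 1).shape)  # torch.Size([64, 4, 32]) after padding to 128
```

The early return added by this commit handles inputs that already arrive in this expanded layout, such as a quantized zero point or an HQO-padded weight.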
6 changes: 5 additions & 1 deletion src/brevitas/core/stats/stats_op.py
@@ -692,14 +692,18 @@ def parameter_search(self, xl, x):
         self.set_local_loss_mode(False)
         qt_value = self.input_view_shape_impl(quant_tensor.value)
         qt_scale = self.input_view_shape_impl(quant_tensor.scale)
-        qt_int = self.input_view_shape_impl(quant_tensor.int())
+        qt_zp = self.input_view_shape_impl(quant_tensor.zero_point)
+        qt_int = qt_value / qt_scale + qt_zp
         loss = torch.abs(qt_value - x).mean()
         best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
         if loss >= best_loss:
             break
         best_loss = torch.min(loss, best_loss)
         W_e = shrink_lp_op(x - qt_value, self.beta, self.lp_norm)

+        # Compared to the original formulation, the value we are looking for is:
+        # - scaled by qt_scale
+        # - of opposite sign
         val = self.input_view_shape_impl((x - W_e) - qt_int * qt_scale)

         if self.stats_reduce_dim is None:
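The recomputation above follows from the affine identity `qt_value = (qt_int - qt_zp) * qt_scale`, so the integer representation can be recovered as `qt_int = qt_value / qt_scale + qt_zp` without another call to `quant_tensor.int()`. For context, `shrink_lp_op` performs the shrinkage step of this HQO-style search; a minimal sketch of the L1 case follows (an assumption for illustration; the actual operator also supports other Lp norms):

```python
import torch

def shrink_l1(x: torch.Tensor, beta: float) -> torch.Tensor:
    # Soft-thresholding: pull each residual toward zero by 1/beta,
    # zeroing out anything that would cross the origin.
    return torch.sign(x) * torch.clamp(torch.abs(x) - 1.0 / beta, min=0.0)

residual = torch.tensor([-0.5, -0.01, 0.02, 0.8])
print(shrink_l1(residual, beta=10.0))  # tensor([-0.4000, -0.0000, 0.0000, 0.7000])
```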
53 changes: 36 additions & 17 deletions src/brevitas/quant/base.py
@@ -429,41 +429,31 @@ class AccumulatorAwareZeroCenterWeightQuant(AccumulatorAwareWeightQuant):
     pre_zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl


-class MSESubInjectorBase(ExtendedInjector):
-
-    @value
-    def inner_stats_input_view_shape_impl(scaling_per_output):
-        if scaling_per_output == ScalingPerOutputType.CHANNEL:
-            return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
-        elif scaling_per_output == ScalingPerOutputType.TENSOR:
-            return StatsInputViewShapeImpl.OVER_TENSOR
-        elif scaling_per_output == ScalingPerOutputType.GROUP:
-            raise RuntimeError("Not implemented yet")
-
-    permute_dims = (this << 1).permute_dims
-
-
-class MSESymmetricScaleSubInjector(MSESubInjectorBase):
+class MSESymmetricScaleSubInjector(ExtendedInjector):
     scaling_per_output = (this << 1).scaling_per_output
     proxy_module = (this << 1).proxy_module
     mse_init_op = AbsMax
     stats_impl = MSE
     stats_reduce_dim = (this << 1).stats_reduce_dim
     device = (this << 1).device
     dtype = (this << 1).dtype
+    permute_dims = (this << 1).permute_dims
+    inner_stats_input_view_shape_impl = (this << 1).inner_stats_input_view_shape_impl


-class MSEAsymmetricScaleSubInjector(MSESubInjectorBase):
+class MSEAsymmetricScaleSubInjector(ExtendedInjector):
     scaling_per_output = (this << 1).scaling_per_output
     proxy_module = (this << 1).proxy_module
     mse_init_op = AbsMinMax
     stats_impl = MSE
     stats_reduce_dim = (this << 1).stats_reduce_dim
     device = (this << 1).device
     dtype = (this << 1).dtype
+    permute_dims = (this << 1).permute_dims
+    inner_stats_input_view_shape_impl = (this << 1).inner_stats_input_view_shape_impl


-class MSEZeroPointSubInjector(MSESubInjectorBase):
+class MSEZeroPointSubInjector(ExtendedInjector):
     # zp is per channel when scaling is per channel
     scaling_per_output = (this << 1).scaling_per_output
     proxy_module = (this << 1).proxy_module
@@ -473,6 +463,8 @@ class MSEZeroPointSubInjector(MSESubInjectorBase):
     stats_reduce_dim = (this << 1).stats_reduce_dim
     device = (this << 1).device
     dtype = (this << 1).dtype
+    permute_dims = (this << 1).permute_dims
+    inner_stats_input_view_shape_impl = (this << 1).inner_stats_input_view_shape_impl


 class MSEAsymmetricScale(ExtendedInjector):
@@ -484,6 +476,15 @@ class MSEAsymmetricScale(ExtendedInjector):
     scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
     scaling_stats_input_view_shape_impl = nn.Identity()

+    @value
+    def inner_stats_input_view_shape_impl(scaling_per_output):
+        if scaling_per_output == ScalingPerOutputType.CHANNEL:
+            return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
+        elif scaling_per_output == ScalingPerOutputType.TENSOR:
+            return StatsInputViewShapeImpl.OVER_TENSOR
+        elif scaling_per_output == ScalingPerOutputType.GROUP:
+            return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK
+
     @value
     def scaling_stats_impl():
         return this.mse_scale.stats_impl
@@ -498,6 +499,15 @@ class MSESymmetricScale(ExtendedInjector):
     scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
     scaling_stats_input_view_shape_impl = nn.Identity()

+    @value
+    def inner_stats_input_view_shape_impl(scaling_per_output):
+        if scaling_per_output == ScalingPerOutputType.CHANNEL:
+            return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
+        elif scaling_per_output == ScalingPerOutputType.TENSOR:
+            return StatsInputViewShapeImpl.OVER_TENSOR
+        elif scaling_per_output == ScalingPerOutputType.GROUP:
+            return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK
+
     @value
     def scaling_stats_impl():
         return this.mse_scale.stats_impl
@@ -511,6 +521,15 @@ class MSEZeroPoint(ExtendedInjector):
     mse_zero_point = MSEZeroPointSubInjector
     zero_point_stats_input_view_shape_impl = nn.Identity()

+    @value
+    def inner_stats_input_view_shape_impl(scaling_per_output):
+        if scaling_per_output == ScalingPerOutputType.CHANNEL:
+            return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
+        elif scaling_per_output == ScalingPerOutputType.TENSOR:
+            return StatsInputViewShapeImpl.OVER_TENSOR
+        elif scaling_per_output == ScalingPerOutputType.GROUP:
+            return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK
+
     @value
     def zero_point_stats_impl():
         return this.mse_zero_point.stats_impl
4 changes: 2 additions & 2 deletions src/brevitas/quant/experimental/mx_quant_ocp.py
@@ -123,14 +123,14 @@ class MXInt8Act(MXActMixin, GroupwiseActProxyMixin, IntQuant, MaxStatsScaling, A
     bit_width = 8


-class MXInt8WeightMSE(MXInt8Weight, MSESymmetricScale):
+class MXInt8WeightMSE(MSESymmetricScale, MXInt8Weight):
     """
     MX Int signed weight quantizer with per-channel MSE-based scaling.
     """
     pass


-class ShiftedMXUInt8WeightMSE(ShiftedMXUInt8Weight, MSEAsymmetricScale):
+class ShiftedMXUInt8WeightMSE(MSEAsymmetricScale, ShiftedMXUInt8Weight):
     """
     MX Int unsigned weight quantizer with per-channel MSE-based scaling.
     """
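The base-class swap above matters because of Python's method resolution order: attribute lookup favors the leftmost base, so the MSE scaling mixin must precede the weight quantizer for its settings to win. A toy illustration (these classes are not the Brevitas hierarchy):

```python
class Base:
    scaling_stats_impl = "absmax"

class MSEMixin:
    scaling_stats_impl = "mse"

class Before(Base, MSEMixin):   # old order: mixin is shadowed
    pass

class After(MSEMixin, Base):    # new order: mixin takes precedence
    pass

print(Before.scaling_stats_impl)  # absmax
print(After.scaling_stats_impl)   # mse
```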
6 changes: 3 additions & 3 deletions src/brevitas/quant/shifted_scaled_int.py
@@ -5,7 +5,7 @@
 from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector
 from brevitas.quant.base import *
 from brevitas.quant.base import HQOActZeroPoint
-from brevitas.quant.base import HQOZeroPoint
+from brevitas.quant.base import HQOWeightZeroPoint
 from brevitas.quant.solver.act import ActQuantSolver
 from brevitas.quant.solver.weight import WeightQuantSolver

@@ -145,7 +145,7 @@ class ShiftedUint8WeightPerChannelFloatMSE(MSEAsymmetricScale,
     pass


-class ShiftedUint8WeightPerTensorFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerTensorFloat):
+class ShiftedUint8WeightPerTensorFloatHQO(HQOWeightZeroPoint, ShiftedUint8WeightPerTensorFloat):
     """
     8-bit per-tensor unsigned int weight quantizer with floating-point per-tensor scale factor and integer
     zero point. Zero-point is initialized from HQO local loss.
@@ -157,7 +157,7 @@ class ShiftedUint8WeightPerTensorFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerTen
     quantize_zero_point = False


-class ShiftedUint8WeightPerChannelFloatHQO(HQOZeroPoint, ShiftedUint8WeightPerChannelFloat):
+class ShiftedUint8WeightPerChannelFloatHQO(HQOWeightZeroPoint, ShiftedUint8WeightPerChannelFloat):
     """
     8-bit per-channel unsigned int weight quantizer with floating-point per-channel scale factor and integer
     zero point. Zero-point is initialized from HQO local loss.
16 changes: 10 additions & 6 deletions src/brevitas_examples/common/generative/quantize.py
@@ -102,9 +102,9 @@
             'per_tensor': {
                 'sym': Int8WeightPerTensorFixedPointMSE},
             'per_channel': {
-                'sym': Int8WeightPerChannelFixedPointMSE}},
-        'per_group': {
-            'sym': MXInt8WeightMSE, 'asym': ShiftedMXUInt8WeightMSE}}},
+                'sym': Int8WeightPerChannelFixedPointMSE},
+            'per_group': {
+                'sym': MXInt8WeightMSE, 'asym': ShiftedMXUInt8WeightMSE}}}},
     'float': {
         'float_scale': {
             'stats': {
@@ -210,6 +210,7 @@ def generate_quantizers(
         weight_group_size,
         quantize_weight_zero_point,
         weight_quant_format='int',
+        weight_group_dim=None,
         input_bit_width=None,
         input_quant_format='',
         input_scale_precision=None,
@@ -276,6 +277,10 @@
             'narrow_range': False,
             'quantize_zero_point': quantize_weight_zero_point},
         **weight_float_format)
+
+    if weight_group_dim is not None:
+        weight_quant = weight_quant.let(**{'group_dim': weight_group_dim})
+
     if dtype == torch.float16:
         weight_quant = weight_quant.let(**{'scaling_min_val': 1e-4})
     if weight_kwargs is not None:
@@ -285,9 +290,8 @@
     if weight_quant_granularity == 'per_group':
         weight_quant = weight_quant.let(**{'group_size': weight_group_size})
     # weight scale is converted to a standalone parameter
-    # This is done already by default in the per_group quantizer
-    if weight_quant_granularity != 'per_group':
-        weight_quant = weight_quant.let(scaling_impl_type='parameter_from_stats')

+    weight_quant = weight_quant.let(scaling_impl_type='parameter_from_stats')
     # weight zero-point is converted to a standalone parameter
     # This is done already by default in the per_group quantizer
     if weight_quant_type == 'asym' and weight_quant_granularity != 'per_group':
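The `.let(...)` calls above use the injector pattern: each call returns a new quantizer class with the given dependency rebound, leaving the original untouched. A minimal sketch of the new `group_dim` override (values illustrative):

```python
from brevitas.quant.experimental.mx_quant_ocp import MXInt8WeightMSE

weight_group_dim = 0  # e.g. parsed from --weight-group-dim

weight_quant = MXInt8WeightMSE
if weight_group_dim is not None:
    # Returns a specialized copy; MXInt8WeightMSE itself is unchanged.
    weight_quant = weight_quant.let(**{'group_dim': weight_group_dim})
```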
13 changes: 11 additions & 2 deletions src/brevitas_examples/llm/README.md
@@ -17,11 +17,12 @@ Set the env variable BREVITAS_JIT=1 to speed up the quantization process. Curren
 usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
                [--seqlen SEQLEN] [--eval] [--dataset {wikitext2,c4}]
                [--weight-bit-width WEIGHT_BIT_WIDTH]
-               [--weight-param-method {stats,mse}]
+               [--weight-param-method {stats,mse,hqo}]
                [--weight-scale-precision {float_scale,po2_scale}]
                [--weight-quant-type {sym,asym}]
                [--weight-quant-format WEIGHT_QUANT_FORMAT]
                [--weight-quant-granularity {per_channel,per_tensor,per_group}]
+               [--weight-group-dim {1,0}]
                [--weight-group-size WEIGHT_GROUP_SIZE]
                [--quantize-weight-zero-point]
                [--input-bit-width INPUT_BIT_WIDTH]
@@ -38,6 +39,7 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
                [--weight-equalization]
                [--act-equalization {None,layerwise,fx}] [--load-awq LOAD_AWQ]
                [--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}]
+               [--export-prefix EXPORT_PREFIX]
                [--checkpoint-name CHECKPOINT_NAME]

 options:
@@ -51,7 +53,7 @@
                         Dataset to use for quantization (default: wikitext2)
   --weight-bit-width WEIGHT_BIT_WIDTH
                         Weight bit width. Default: 8.
-  --weight-param-method {stats,mse}
+  --weight-param-method {stats,mse,hqo}
                         How scales/zero-point are determined. Default: stats.
   --weight-scale-precision {float_scale,po2_scale}
                         Whether scale is a float value or a po2. Default: po2.
@@ -65,6 +67,9 @@
   --weight-quant-granularity {per_channel,per_tensor,per_group}
                         Granularity for scales/zero-point of weights. Default:
                         per_group.
+  --weight-group-dim {1,0}
+                        Override the default group_dim for groupwise
+                        quantization. Default: layer-dependent.
   --weight-group-size WEIGHT_GROUP_SIZE
                         Group size for per_group weight quantization. Default:
                         128.
@@ -119,6 +124,10 @@ options:
   --load-awq LOAD_AWQ   Load the awq search results.
   --export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}
                         Model export.
+  --export-prefix EXPORT_PREFIX
+                        Path prefix to use for the various export flows. If
+                        None, a path will be derived from the model name
+                        (default: None)
   --checkpoint-name CHECKPOINT_NAME
                         Filename to save checkpoint. If `None`, no checkpoint
                         is saved (default: None)
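As an illustration, a hypothetical invocation exercising the new options (model name and values are examples only):

```sh
python main.py --model facebook/opt-125m \
    --weight-bit-width 4 \
    --weight-param-method hqo \
    --weight-quant-type asym \
    --weight-quant-granularity per_group \
    --weight-group-size 128 \
    --weight-group-dim 1 \
    --export-prefix opt125m-w4
```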
7 changes: 7 additions & 0 deletions src/brevitas_examples/llm/main.py
@@ -226,6 +226,7 @@ def main(args):
         weight_quant_type=args.weight_quant_type,
         weight_quant_granularity=args.weight_quant_granularity,
         weight_group_size=args.weight_group_size,
+        weight_group_dim=args.weight_group_dim,
         quantize_weight_zero_point=args.quantize_weight_zero_point,
         weight_quant_format=args.weight_quant_format,
         input_bit_width=args.input_bit_width,
@@ -358,6 +359,12 @@ def parse_args(args):
         default='per_group',
         choices=['per_channel', 'per_tensor', 'per_group'],
         help='Granularity for scales/zero-point of weights. Default: per_group.')
+    parser.add_argument(
+        '--weight-group-dim',
+        type=int,
+        default=None,
+        choices=[1, 0],
+        help='Override the default group_dim for groupwise quantization. Default: layer-dependent.')
     parser.add_argument(
         '--weight-group-size',
         type=int,
