
Example for adaptive bit width of quantizer #1156

Closed
karli262 opened this issue Jan 15, 2025 · 4 comments
Labels
enhancement New feature or request

Comments

@karli262

Hello,

I am new to the topic, but I have already gone through the documentation. It shows how to make the bit width a learnable parameter, but how do you actually update it? I have a simple network with one quantized layer, and I add that layer's bit width to the loss function to minimize it. However, it always stays the same. Can you provide a working example or correct mine?

The quantizer is defined as in the documentation:

class LearnedIntWeightPerChannelFloat(Int8WeightPerChannelFloat):
    scaling_impl_type = ScalingImplType.PARAMETER_FROM_STATS
    restrict_scaling_type = RestrictValueType.LOG_FP
    weight_bit_width_impl_type = BitWidthImplType.PARAMETER

My layer:

self.lin = qnn.QuantLinear(32, 6, weight_quant=LearnedIntWeightPerChannelFloat)

and the loss:

loss = task_loss + model.lin.quant_weight().bit_width

Thank you in advance!

karli262 added the enhancement (New feature or request) label on Jan 15, 2025
@karli262
Author

Printing the model parameter gives me this:

+-------------------------------------------------------------------------+------------+
|                                 Modules                                 | Parameters |
+-------------------------------------------------------------------------+------------+
|                                lin.weight                               |    192     |
|                                 lin.bias                                |     6      |
|             lin.weight_quant.tensor_quant.scaling_impl.value            |     6      |
| lin.weight_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset |     1      |
+-------------------------------------------------------------------------+------------+
Total Trainable Params: 205

@nickfraser
Collaborator

I'll give you a quick response, and I'll take a deeper look if that doesn't help you. My initial impression is that you don't want to use the bit width in your cost function directly*. You'll likely want to scale it**, and you might benefit from using our helpers.

* That shouldn't be the cause of your problem though.

**

loss = task_loss + some_small_value*model.lin.quant_weight().bit_width

However, it always stays the same.

Can you share your training loop? Effectively, lin.weight_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset is a weight that should have gradients associated with it. Running loss.backward() and optimizer.step() should update it, so long as everything has been set up correctly.
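A minimal PyTorch-only sketch of the mechanism described above, using a plain nn.Parameter as a stand-in for the Brevitas bit-width parameter (this is not the Brevitas API; the layer sizes, penalty_weight, and learning rate are all illustrative):

```python
import torch

# Toy stand-in for the Brevitas setup: a plain learnable scalar plays the
# role of the bit-width parameter, and a scaled penalty term pushes it down.
torch.manual_seed(0)
lin = torch.nn.Linear(32, 6)
bit_width = torch.nn.Parameter(torch.tensor(8.0))

optimizer = torch.optim.SGD(list(lin.parameters()) + [bit_width], lr=0.1)

x = torch.randn(16, 32)
target = torch.randn(16, 6)
penalty_weight = 0.1  # scales the bit-width term relative to the task loss

before = bit_width.item()
for _ in range(10):
    optimizer.zero_grad()
    task_loss = torch.nn.functional.mse_loss(lin(x), target)
    loss = task_loss + penalty_weight * bit_width
    loss.backward()
    # The bit-width parameter must receive a gradient for step() to move it.
    assert bit_width.grad is not None
    optimizer.step()

after = bit_width.item()
print(before, after)  # the bit width decreases once gradients flow
```

If the learning rate or penalty weight is too small, the parameter still updates, but the change per step can be small enough to look like it "always stays the same" when printed.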

@Giuseppe5
Collaborator

Another quick sanity check is to look at the grad attribute of the bit-width parameter after backward.
If it is None, no gradients are being propagated; if it is anything other than None (even all zeros), then it might be a scaling issue, as Nick suggested.

Let us know what you find out and we can take a closer look at this if it is still an issue!
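A tiny PyTorch-only illustration of this check (used and unused are illustrative names, not Brevitas parameters):

```python
import torch

# A parameter that contributes to the loss gets a gradient after backward();
# one that never enters the computation graph keeps grad == None.
used = torch.nn.Parameter(torch.tensor(4.0))
unused = torch.nn.Parameter(torch.tensor(4.0))

loss = 2.0 * used
loss.backward()

print(used.grad)    # tensor(2.)
print(unused.grad)  # None
```

In the original setup, the same check would be applied to lin.weight_quant.tensor_quant.msb_clamp_bit_width_impl.bit_width_offset after calling loss.backward().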

@karli262
Author

Thank you for your quick reply and your tips.
I tried your suggestions and just got it working. I had chosen too small a step size for the optimizer; after increasing it, the bit width now updates. Sorry for the inconvenience.
