Feat (brevitas_examples/llm): correct scale init with CPU offloading #1124

Merged: 3 commits, Dec 17, 2024
10 changes: 10 additions & 0 deletions src/brevitas/utils/python_utils.py
@@ -54,3 +54,13 @@ def _getattr(obj, attr):
        return getattr(obj, attr)

    return functools.reduce(_getattr, [obj] + attr.split("."))


def hooked_on_a_function(function, prefunction):

    @functools.wraps(function)
    def run(*args, **kwargs):
        prefunction()
        return function(*args, **kwargs)

    return run
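
The helper simply wraps a function so that a zero-argument callback runs immediately before it, while functools.wraps preserves the wrapped function's metadata. A minimal usage sketch (the names original_post_forward and refresh_scales are illustrative, not part of this PR):

from brevitas.utils.python_utils import hooked_on_a_function

def original_post_forward(module, output):
    # Stand-in for the function we want to hook.
    return output

def refresh_scales():
    # Stand-in for the callback that must run first.
    print("refreshing offloaded scale factors")

wrapped = hooked_on_a_function(original_post_forward, refresh_scales)
out = wrapped(None, 123)  # prints the message, then returns 123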
22 changes: 22 additions & 0 deletions src/brevitas_examples/llm/main.py
@@ -3,6 +3,7 @@

import argparse
from copy import deepcopy
import functools
import sys
from warnings import warn

@@ -20,8 +21,10 @@
from brevitas.graph.equalize import LayerwiseActivationRotation
from brevitas.graph.quantize import layerwise_quantize
from brevitas.graph.utils import get_module
from brevitas.utils.python_utils import hooked_on_a_function
from brevitas_examples.common.accelerate_utils.accelerate import offload_model
from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks
from brevitas_examples.common.accelerate_utils.accelerate import update_internal_dict
from brevitas_examples.common.generative.quantize import generate_quant_maps
from brevitas_examples.common.generative.quantize import generate_quantizers
from brevitas_examples.common.parse_utils import quant_format_validator
@@ -378,9 +381,28 @@ def main(args):

    model = offload_model(model)

    dict_hooks = dict()

    # When offloading to CPU + GPU, the scale factors initialized during the calibration
    # forward pass must be copied back into the offloading state dict before accelerate
    # moves the parameters back to the meta device.
    # If we don't, the new value is lost while the internal flag "init_done" is already True,
    # so subsequent forward passes would use the wrong scale.
    # To do this, we attach a pre-hook to post_forward that runs right before it and
    # updates the offloading dict with the initialized scales.
    for m in model.modules():
        if hasattr(m, '_hf_hook'):
            if m._hf_hook.weights_map is not None:
                # We store the original function so it can be restored later
                dict_hooks[m] = m._hf_hook.post_forward
                new_funct = functools.partial(update_internal_dict, m)
                m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct)

    with torch.no_grad():
        model(**calibration_loader[0])

    # Restore the original behaviour of post_forward.
    for k, v in dict_hooks.items():
        k._hf_hook.post_forward = v

    if args.act_calibration:
        print("Apply act calibration...")
        apply_calibration(model, calibration_loader)
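
In isolation, the attach/run/restore pattern above reads as the sketch below. It is not the exact brevitas_examples code: it only assumes the accelerate hook attributes already used in this diff (_hf_hook, weights_map, post_forward) and a callback with the same shape as update_internal_dict, and it adds a try/finally purely so the original hooks are restored even if the calibration forward raises.

import functools

import torch

from brevitas.utils.python_utils import hooked_on_a_function


def forward_with_scale_sync(model, sample, update_fn):
    # update_fn(module) is assumed to copy the freshly initialized scale factors of
    # `module` into the accelerate offloading dict, like update_internal_dict does.
    saved_post_forwards = {}
    for m in model.modules():
        hook = getattr(m, '_hf_hook', None)
        if hook is not None and hook.weights_map is not None:
            # Keep the original post_forward so it can be restored afterwards.
            saved_post_forwards[m] = hook.post_forward
            hook.post_forward = hooked_on_a_function(
                hook.post_forward, functools.partial(update_fn, m))
    try:
        with torch.no_grad():
            model(**sample)
    finally:
        # Restore the original accelerate post_forward behaviour.
        for m, fn in saved_post_forwards.items():
            m._hf_hook.post_forward = fn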