diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index 05d9bcb47..955f2a749 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -382,19 +382,23 @@ def main(args):
     dict_hooks = dict()
 
-    def update_params_post_init(module):
-        update_internal_dict(module)
-
+    # When offloading to CPU + GPU, the CPU scale factors must be updated
+    # before the weights are moved back to the meta device. Otherwise the
+    # new values are lost while the internal flag "init_done" is already
+    # True, so the wrong scale would be used. To do this, we wrap each
+    # hook's post_forward so that it first updates the dict with the initialized scales.
     for m in model.modules():
         if hasattr(m, '_hf_hook'):
             if m._hf_hook.weights_map is not None:
+                # Store the original function so it can be restored later.
                 dict_hooks[m] = m._hf_hook.post_forward
-                new_funct = functools.partial(update_params_post_init, m)
+                new_funct = functools.partial(update_internal_dict, m)
                 m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct)
 
     with torch.no_grad():
         model(**calibration_loader[0])
 
+    # Restore the original post_forward behaviour.
     for k, v in dict_hooks.items():
         k._hf_hook.post_forward = v
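
For context, the hunk relies on a helper `hooked_on_a_function(function, prefunction)` that is defined elsewhere in the file. Based on the comments above (the update must run before post_forward moves the weights back to the meta device), the wrapper presumably calls the extra function first and then delegates to the original post_forward. A minimal sketch under that assumption (illustrative, not the verbatim brevitas implementation):

    import functools

    def hooked_on_a_function(function, prefunction):
        # Call `prefunction` first (here: snapshot the freshly initialized
        # scales into the hook's weights map), then delegate to the original
        # post_forward, which moves the weights back to the meta device.
        @functools.wraps(function)
        def run(*args, **kwargs):
            prefunction()
            return function(*args, **kwargs)
        return run

Note that `new_funct = functools.partial(update_internal_dict, m)` already binds the module, so `prefunction()` can be invoked with no arguments.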
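The other helper, `update_internal_dict(module)`, is what actually persists the new scale values. A sketch of what such an update might look like, assuming accelerate's `weights_map` is a PrefixedDataset over an offloaded state dict (the `prefix`, `dataset`, and `state_dict` attribute names are assumptions based on accelerate's hook internals, not confirmed by this hunk):

    def update_internal_dict(module):
        # Copy the module's current (initialized) tensors back into the
        # backing store that accelerate reads from when it reloads the
        # weights, so the new scale values survive the offload round trip.
        prefix = module._hf_hook.weights_map.prefix  # assumed PrefixedDataset layout
        for key, value in module.state_dict().items():
            module._hf_hook.weights_map.dataset.state_dict[prefix + key] = value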