
Feat (llm/learned_round): fast block update #1110

Merged
12 commits merged on Dec 5, 2024
101 changes: 82 additions & 19 deletions src/brevitas_examples/common/learned_round/learned_round_optimizer.py
@@ -586,7 +586,8 @@ def apply_learned_round(
get_blocks_fn: Callable,
model_prepare_fn: Optional[Callable] = None,
model_finish_fn: Optional[Callable] = None,
keep_gpu: bool = True) -> None:
keep_gpu: bool = True,
fast_update: bool = False) -> None:

# Perform any needed preprocessing before rounding optimisation, e.g. disabling caching in LLMs
model_dict = None if model_prepare_fn is None else model_prepare_fn(model)
@@ -602,26 +603,28 @@ def apply_learned_round(

# Initialize cache to store partial inputs and outputs for each block
cache.initialize_cache()

floating_point_datasets = []
Review comment (Collaborator): floating_point_datasets is no longer used after the changes, right?

# Iterate over blocks and optimise the rounding parameters within each of them
for block_idx, block in enumerate(blocks):
# Distribute the model across devices to run a forward pass to capture
# inputs/outputs to the given block
model = offload_model(model)
# Cache needs to be cleared before populating it with the inputs and outputs
# to the block under optimization.
self._populate_cache(
cache,
model,
model_forward,
block,
data_loader,
keep_gpu=keep_gpu,
capture_quant_input=True,
capture_quant_output=False,
)
# Remove hooks needed to offload the model blocks to cpu
remove_hooks(model)
if block_idx == 0 or not fast_update:
cache.clear_cache()
model = offload_model(model)
# Cache needs to be cleared before populating it with the inputs and outputs
# to the block under optimization.
self._populate_cache(
cache,
model,
model_forward,
block,
data_loader,
keep_gpu=keep_gpu,
capture_quant_input=True,
capture_quant_output=False,
)
# Remove hooks needed to offload the model blocks to cpu
remove_hooks(model)

# Retrieve scales
scale_params = return_scale_parameters(block)
@@ -678,9 +681,69 @@ def apply_learned_round(
# Move the block back to CPU
block.cpu()

# Reset cache after optimisation
cache.clear_cache()
if block_idx + 1 < len(blocks) and fast_update:
cache, floating_point_datasets = self.skip_full_execution(block, blocks[block_idx+1], floating_point_datasets, block_forward, cache)

# The original configuration of the model is restored after finishing the optimization
if model_finish_fn is not None:
model_finish_fn(model, model_dict)
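
For context, the point of the `fast_update` branch above is to avoid re-running the whole model prefix once per block: only block 0 has its cache populated via a full offloaded forward pass, and every later block receives its inputs/outputs from `skip_full_execution` below. A minimal, self-contained sketch of the saving, using toy `nn.Linear` blocks and a plain list as the cache (names are illustrative, not the Brevitas API, and quantization is ignored here):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
data = [torch.randn(2, 8) for _ in range(3)]  # toy calibration batches

def inputs_via_full_prefix(block_idx):
    # Slow path: re-run every block before `block_idx` for each calibration batch.
    outs = []
    for x in data:
        for blk in blocks[:block_idx]:
            x = blk(x)
        outs.append(x)
    return outs

# Fast path: populate the cache once (inputs to block 0), then push it forward
# one block at a time instead of re-running the prefix for every block.
cache = list(data)
with torch.no_grad():
    for block_idx, block in enumerate(blocks):
        # ... rounding parameters of `block` would be optimised here ...
        assert all(
            torch.allclose(c, r)
            for c, r in zip(cache, inputs_via_full_prefix(block_idx)))
        # Outputs of the current block become the inputs of the next one.
        cache = [block(x) for x in cache]
```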

def skip_full_execution(self, block, next_block, floating_point_datasets, block_forward, cache):

# We need to compute two inputs: a floating-point one, used to compute the float output
# of the next block, and a quantized one, used to create the quantized input of the next block

# If we don't have floating_point_datasets yet, we retrieve them from the cache.
# The idea is that the cache contains the input to the very first block, and there is nothing
# quantized before that. This is a moderately strong assumption.
if len(floating_point_datasets) <= 0:
for i in range(len(cache)):
(args, kwargs), _ = cache.sample_batch([i])
floating_point_datasets.append((args, kwargs))

# We use the cached outputs to generate a new temporary dataloader for the next block
# and to update our floating_point_datasets
new_data_loader = []
for i in range(len(cache)):
(args, kwargs), output = cache.sample_batch([i])

new_data_loader.append(((output,), kwargs))
floating_point_datasets[i] = ((output,), kwargs)

# Temporary cache
tmp_cache = type(cache)()

# We compute the floating point output of the upcoming block
if torch.cuda.is_available():
next_block.cuda()
save_inputs_output(
next_block,
block_forward,
next_block,
new_data_loader,
tmp_cache,
store_inputs=False,
store_output=True,
keep_gpu=False,
disable_quant=True,
)
next_block.cpu()

cache['output'] = tmp_cache['output']

# Finally (!), we compute the quantized input of the next block
block.eval()
block.cuda()
next_quant_input = []
pbar = tqdm(range(len(cache)), desc='', leave=False)
with torch.no_grad():
for i in pbar:
(args, kwargs), _ = cache.sample_batch([i])
out = block_forward(block, (args, kwargs))
out = send_to_device(out, 'cpu')
next_quant_input.append((out,))
cache['args'] = next_quant_input
block.cpu()
pbar.close()

return cache, floating_point_datasets
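
In other words, `skip_full_execution` keeps two activation streams in sync: a quantized stream (the output of the just-optimised block, which becomes `cache['args']` for the next block) and a floating-point stream (run through the next block with quantization disabled to produce the reference `cache['output']` it will be trained against). A rough standalone sketch of that bookkeeping, using a toy fake-quantised block so the two streams actually diverge (all names are illustrative, not the Brevitas API):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FakeQuantBlock(nn.Module):
    """Toy block that can run with rounded ('quantized') or float weights."""

    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 8)
        self.quant_enabled = True

    def forward(self, x):
        w = self.lin.weight
        if self.quant_enabled:
            w = torch.round(w * 16) / 16  # crude stand-in for weight quantization
        return F.linear(x, w, self.lin.bias)

torch.manual_seed(0)
blocks = [FakeQuantBlock() for _ in range(3)]
float_stream = [torch.randn(2, 8) for _ in range(4)]  # float inputs to block 0
quant_stream = [x.clone() for x in float_stream]      # quantized-path inputs to block 0

with torch.no_grad():
    for idx, block in enumerate(blocks):
        # Float reference outputs of this block, computed with quantization disabled
        # (the analogue of capturing the float outputs in the cache).
        block.quant_enabled = False
        float_refs = [block(x) for x in float_stream]
        block.quant_enabled = True
        # ... here the rounding of `block` would be optimised so that, given the
        # quantized inputs in `quant_stream`, it reproduces `float_refs` ...
        if idx + 1 < len(blocks):
            # Propagate both streams without re-running the whole model prefix:
            quant_stream = [block(x) for x in quant_stream]  # -> next quantized inputs
            float_stream = float_refs                        # -> next float inputs
```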
@@ -69,6 +69,7 @@ class CacheVision(Cache, dict):
def __init__(self) -> None:
super().__init__()
self.batch_dim = 0
self.initialize_cache()

def store_inputs(self, args, kwargs) -> None:
input_batch = args[0]
4 changes: 4 additions & 0 deletions src/brevitas_examples/llm/README.md
@@ -54,6 +54,7 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
[--export-prefix EXPORT_PREFIX]
[--checkpoint-name CHECKPOINT_NAME] [--fuse-sequences]
[--learned-round {None,linear_round}]
[--learned-round-fast-update]

options:
-h, --help show this help message and exit
@@ -196,5 +197,8 @@ options:
--learned-round {None,linear_round}
Whether to use learned round. If `None`, RTN is used
(default: None)
--learned-round-fast-update
Whether to use fast update with learned round.
Prototype (default: False)

```
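
For reference, a typical invocation of the example entry point with the new prototype flag could look like the following (the model id is only a placeholder; any checkpoint accepted by --model works):

```sh
# Learned round with the prototype fast block update enabled.
python main.py \
    --model facebook/opt-125m \
    --learned-round linear_round \
    --learned-round-fast-update
```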
41 changes: 21 additions & 20 deletions src/brevitas_examples/llm/llm_quant/learned_round_utils.py
@@ -23,6 +23,7 @@ class CacheLLM(Cache, dict):

def __init__(self) -> None:
super().__init__()
self.initialize_cache()

def store_inputs(self, args, kwargs) -> None:
self["args"].append(args)
@@ -107,25 +108,25 @@ def get_blocks(model: nn.Module, block_name_attribute: str) -> List[nn.Module]:


def apply_learned_round(
model: nn.Module,
calibration_loader: DataLoader,
iters: int = 200,
learned_round: str = "linear_round",
learned_round_loss: str = "mse",
block_name_attribute: str = "layers",
optimizer: str = "sign_sgd",
batch_size: int = 8,
learn_scale: bool = False,
use_best_model: bool = True,
amp_dtype: torch.dtype = torch.float16,
loss_scaling_factor: float = 1000,
lr_scheduler: Optional[str] = "linear",
optimizer_kwargs: Optional[Dict] = None,
lr_scheduler_kwargs: Optional[Dict] = None,
learned_round_loss_kwargs: Optional[Dict] = None,
scale_optimizer_class: Optional[str] = None,
scale_optimizer_kwargs: Optional[Dict] = None,
) -> None:
model: nn.Module,
calibration_loader: DataLoader,
iters: int = 200,
learned_round: str = "linear_round",
learned_round_loss: str = "mse",
block_name_attribute: str = "layers",
optimizer: str = "sign_sgd",
batch_size: int = 8,
learn_scale: bool = False,
use_best_model: bool = True,
amp_dtype: torch.dtype = torch.float16,
loss_scaling_factor: float = 1000,
lr_scheduler: Optional[str] = "linear",
optimizer_kwargs: Optional[Dict] = None,
lr_scheduler_kwargs: Optional[Dict] = None,
learned_round_loss_kwargs: Optional[Dict] = None,
scale_optimizer_class: Optional[str] = None,
scale_optimizer_kwargs: Optional[Dict] = None,
fast_update: bool = False) -> None:
# Parse strings to obtain the arguments for the optimizer
learned_round = parse_learned_round(learned_round)
learned_round_loss_class = parse_learned_round_loss_class(learned_round_loss)
@@ -166,4 +167,4 @@ def apply_learned_round(
model_prepare_fn=llm_learned_round_prepare_fn,
model_finish_fn=llm_learned_round_finish_fn,
keep_gpu=False,
)
fast_update=fast_update)
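
On the Python side, the new keyword simply threads through the example API. A hedged sketch of a call site for the updated signature (the wrapper function and its arguments are illustrative; a prepared model and calibration DataLoader are assumed, as in main.py below):

```python
from torch import nn
from torch.utils.data import DataLoader

from brevitas_examples.llm.llm_quant.learned_round_utils import apply_learned_round

def run_learned_round(model: nn.Module, calibration_loader: DataLoader) -> None:
    # Other arguments keep the defaults shown in the signature above; only the
    # prototype fast-update path introduced by this PR is switched on.
    apply_learned_round(
        model,
        calibration_loader,
        learned_round="linear_round",
        fast_update=True,
    )
```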
8 changes: 7 additions & 1 deletion src/brevitas_examples/llm/main.py
@@ -385,7 +385,8 @@ def main(args):
scale_optimizer_class='sgd',
optimizer_kwargs={'lr': args.learned_round_lr},
scale_optimizer_kwargs={
'lr': args.learned_round_scale_lr, 'momentum': args.learned_round_scale_momentum})
'lr': args.learned_round_scale_lr, 'momentum': args.learned_round_scale_momentum},
fast_update=args.learned_round_fast_update)
print("Learned round applied.")

model = offload_model(model)
@@ -705,6 +706,11 @@ def parse_args(args):
default=None,
choices=[None, 'linear_round'],
help='Whether to use learned round. If `None`, RTN is used (default: %(default)s)')
parser.add_argument(
'--learned-round-fast-update',
default=False,
action="store_true",
help='Whether to use fast update with learned round. Prototype (default: %(default)s)')
return parser.parse_args(args)

