diff --git a/src/brevitas_examples/llm/llm_quant/mha_layers.py b/src/brevitas_examples/llm/llm_quant/mha_layers.py index b73622c01..9b56c0e63 100644 --- a/src/brevitas_examples/llm/llm_quant/mha_layers.py +++ b/src/brevitas_examples/llm/llm_quant/mha_layers.py @@ -296,11 +296,5 @@ def set_weight(value): del state_dict[name] state_dict[prefix + 'mha.out_proj.weight'] = torch.eye(self.mha.out_proj.weight.shape[0]) state_dict[prefix + 'mha.out_proj.bias'] = torch.zeros(self.mha.out_proj.bias.shape) - # elif prefix + 'self.output.dense.weight' in name: - # state_dict[prefix + 'mha.out_proj.weight'] = value - # del state_dict[name] - # elif prefix + 'self.output.dense.bias' in name: - # state_dict[prefix + 'mha.out_proj.bias'] = value - # del state_dict[name] return super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)