Support auto_round integration 2.x (#1806)
Signed-off-by: Kaihui-intel <[email protected]>
Kaihui-intel authored May 21, 2024
1 parent 24508d0 commit 4728fdc
Showing 5 changed files with 85 additions and 74 deletions.
2 changes: 1 addition & 1 deletion .azure-pipelines/scripts/ut/env_setup.sh
@@ -92,7 +92,7 @@ elif [[ $(echo "${test_case}" | grep -c "tf pruning") != 0 ]]; then
fi

if [[ $(echo "${test_case}" | grep -c "api") != 0 ]] || [[ $(echo "${test_case}" | grep -c "adaptor") != 0 ]]; then
pip install auto-round
pip install git+https://github.com/intel/auto-round.git@ecca5349981044e1278773a251b3fc5c0a11fe7b
fi

# test deps
26 changes: 15 additions & 11 deletions neural_compressor/adaptor/pytorch.py
@@ -4916,12 +4916,12 @@ def autoround_quantize(self, model, tune_cfg, dataloader):
weight_config[op_name]["sym"] = config["weight"]["scheme"] == "sym"

# auto round recipes

enable_full_range = self.recipes["autoround_args"].get("enable_full_range", False)
batch_size = self.recipes["autoround_args"].get("batch_size", 8)
lr_scheduler = self.recipes["autoround_args"].get("lr_scheduler", None)
dataset_name = self.recipes["autoround_args"].get("dataset_name", "NeelNanda/pile-10k")
dataset_split = self.recipes["autoround_args"].get("dataset_split", "train")
use_quant_input = self.recipes["autoround_args"].get("use_quant_input", True)
dataset = self.recipes["autoround_args"].get("dataset", "NeelNanda/pile-10k")
enable_quanted_input = self.recipes["autoround_args"].get("enable_quanted_input", True)
enable_minmax_tuning = self.recipes["autoround_args"].get("enable_minmax_tuning", True)
lr = self.recipes["autoround_args"].get("lr", None)
minmax_lr = self.recipes["autoround_args"].get("minmax_lr", None)
@@ -4938,22 +4938,26 @@ def autoround_quantize(self, model, tune_cfg, dataloader):
data_type = self.recipes["autoround_args"].get("data_type", "int") ##only support data_type
scale_dtype = self.recipes["autoround_args"].get("scale_dtype", "fp16")
amp = self.recipes["autoround_args"].get("amp", True)
device = self.recipes["autoround_args"].get("device", None)
bits = self.recipes["autoround_args"].get("bits", 4)
group_size = self.recipes["autoround_args"].get("group_size", 128)
sym = self.recipes["autoround_args"].get("scheme", "asym") == "sym"

if dataloader is not None:
dataset = dataloader
model, autoround_config = autoround_quantize(
model=model,
tokenizer=None,
bits=4,
group_size=128,
sym=False,
bits=bits,
group_size=group_size,
sym=sym,
weight_config=weight_config,
enable_full_range=enable_full_range,
batch_size=batch_size,
amp=amp,
device=device,
lr_scheduler=lr_scheduler,
dataloader=dataloader,
dataset_name=dataset_name,
dataset_split=dataset_split,
use_quant_input=use_quant_input,
dataset=dataset,
enable_quanted_input=enable_quanted_input,
enable_minmax_tuning=enable_minmax_tuning,
lr=lr,
minmax_lr=minmax_lr,
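For context, here is a minimal configuration sketch (not part of the commit) showing how the renamed recipe keys read by the adaptor above might be supplied through the 2.x `PostTrainingQuantConfig` API; the op-type values and tuning settings are illustrative placeholders mirroring the test change at the bottom of this diff.

```python
# Hypothetical 2.x configuration sketch: the adaptor now reads "dataset" and
# "enable_quanted_input" (formerly dataset_name/dataset_split and
# use_quant_input), plus bits/group_size/scheme from the op-type config.
from neural_compressor import PostTrainingQuantConfig

conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={
        ".*": {  # apply to all matched ops; values are illustrative
            "weight": {
                "dtype": "int",
                "bits": 4,
                "group_size": 128,
                "scheme": "asym",
                "algorithm": "AUTOROUND",
            },
        },
    },
    recipes={
        "autoround_args": {
            "dataset": "NeelNanda/pile-10k",  # replaces dataset_name/dataset_split
            "enable_quanted_input": True,     # replaces use_quant_input
            "seqlen": 32,
            "n_samples": 20,
            "iters": 10,
            "scale_dtype": "fp32",
        },
    },
)
```

If a calibration dataloader is also passed to `quantization.fit`, the adaptor forwards it as `dataset` (see the `if dataloader is not None` branch above); otherwise the dataset name from the recipe is used.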
25 changes: 19 additions & 6 deletions neural_compressor/adaptor/torch_utils/auto_round.py
@@ -12,14 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from auto_round.calib_dataset import CALIB_DATASETS # pylint: disable=E0401

def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=512):
"""Generate a DataLoader for calibration using specified parameters.
Args:
tokenizer (Tokenizer): The tokenizer to use for tokenization.
seqlen (int): The exact sequence length. Samples shorter than seqlen will be dropped,
samples longer than seqlen will be truncated.
dataset_name (str, optional): The name of the dataset or datasets separated by commas.
Defaults to "NeelNanda/pile-10k".
seed (int, optional): The random seed for reproducibility. Defaults to 42.
bs (int, optional): The batch size. Defaults to 8.
n_samples (int, optional): The total number of samples to include. Defaults to 512.
Returns:
DataLoader: The DataLoader for the calibration dataset.
"""
from auto_round.calib_dataset import get_dataloader # pylint: disable=E0401

def get_dataloader(
tokenizer, seqlen=2048, seed=42, train_bs=8, dataset_split="train", dataset_name="NeelNanda/pile-10k"
):
get_dataloader = CALIB_DATASETS.get(dataset_name, CALIB_DATASETS["NeelNanda/pile-10k"])
dataloader = get_dataloader(
tokenizer, seqlen=seqlen, seed=seed, bs=train_bs, split=dataset_split, dataset_name=dataset_name
tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=seed, bs=bs, n_samples=n_samples
)
return dataloader
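A short usage sketch of the updated wrapper, assuming a Hugging Face tokenizer; the tiny GPT-J checkpoint is a placeholder borrowed from the test at the end of this diff.

```python
# Hypothetical usage of the new get_dataloader signature: seqlen is now the
# second positional argument, and batch size / sample count are passed as
# bs / n_samples instead of train_bs / dataset_split.
import transformers

from neural_compressor.adaptor.torch_utils.auto_round import get_dataloader

tokenizer = transformers.AutoTokenizer.from_pretrained(
    "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
)
dataloader = get_dataloader(
    tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=20
)
batch = next(iter(dataloader))  # one tokenized calibration batch
```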
100 changes: 48 additions & 52 deletions neural_compressor/adaptor/torch_utils/weight_only.py
@@ -20,7 +20,7 @@

import math
from copy import deepcopy
from typing import OrderedDict
from typing import Optional, OrderedDict, Union

from ...utils import logger
from ...utils.utility import LazyImport
@@ -679,7 +679,7 @@ def quant_weight_w_scale(weight, scale, zp, group_size=-1):

def autoround_quantize(
model,
tokenizer,
tokenizer=None,
bits: int = 4,
group_size: int = 128,
sym: bool = False,
@@ -689,10 +689,8 @@
amp: bool = True,
device=None,
lr_scheduler=None,
dataloader=None, ## to support later
dataset_name: str = "NeelNanda/pile-10k",
dataset_split: str = "train",
use_quant_input: bool = True,
dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
enable_quanted_input: bool = True,
enable_minmax_tuning: bool = True,
lr: float = None,
minmax_lr: float = None,
@@ -706,52 +704,52 @@ def autoround_quantize(
gradient_accumulate_steps: int = 1,
not_use_best_mse: bool = False,
dynamic_max_gap: int = -1,
data_type: str = "int", ##only support data_type
scale_dtype="fp16",
data_type: str = "int", ##only support int for now
scale_dtype: str = "fp16",
**kwargs,
):
"""Run autoround weight-only quantization.
Args:
model: The PyTorch model to be quantized.
tokenizer: Tokenizer for processing input data. Temporarily set as a mandatory parameter.
bits (int): Number of bits for quantization (default is 4).
group_size (int): Size of the quantization group (default is 128).
sym (bool): Whether the symmetric quantization is to be used.
weight_config (dict): Configuration for weight quantization (default is an empty dictionary).
weight_config={
'layer1':##layer_name
{
'data_type': 'int',
'bits': 4,
'group_size': 32,
'scheme': "asym", ## or sym
}
...
}
enable_full_range (bool): Whether to enable full range quantization (default is False).
bs (int): Batch size for training (default is 8).
amp (bool): Whether to use automatic mixed precision (default is True). Automatically detect and set.
device: The device to be used for tuning (default is None). Automatically detect and set.
lr_scheduler: The learning rate scheduler to be used.
dataloader: The dataloader for input data (to be supported in future).
dataset_name (str): The default dataset name (default is "NeelNanda/pile-10k").
dataset_split (str): The split of the dataset to be used (default is "train").
use_quant_input (bool): Whether to use quantized input data (default is True).
enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True).
lr (float): The learning rate (default is 0.005).
minmax_lr (float): The learning rate for min-max tuning (default is None).
low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
iters (int): Number of iterations (default is 200).
seqlen (int): Length of the sequence.
n_samples (int): Number of samples (default is 512).
sampler (str): The sampling method (default is "rand").
seed (int): The random seed (default is 42).
n_blocks (int): Number of blocks (default is 1).
gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
not_use_best_mse (bool): Whether to use mean squared error (default is False).
dynamic_max_gap (int): The dynamic maximum gap (default is -1).
data_type (str): The data type to be used (default is "int").
**kwargs: Additional keyword arguments.
model: The PyTorch model to be quantized.
tokenizer: An optional tokenizer for processing input data. If none is provided, a dataloader must be supplied.
bits (int): Number of bits for quantization (default is 4).
group_size (int): Size of the quantization group (default is 128).
sym (bool): Whether symmetric quantization is to be used (default is False).
weight_config (dict): Configuration for weight quantization (default is an empty dictionary).
weight_config={
'layer1':##layer_name
{
'data_type': 'int',
'bits': 4,
'group_size': 32,
'sym': False
}
...
}
enable_full_range (bool): Whether to enable full range quantization (default is False).
batch_size (int): Batch size for training (default is 8).
amp (bool): Whether to use automatic mixed precision (default is True).
device: The device to be used for tuning (default is "auto").
lr_scheduler: The learning rate scheduler to be used.
dataset (str): The default dataset name (default is "NeelNanda/pile-10k").
enable_quanted_input (bool): Whether to use the output of the previous quantized block as
the input for the current block (default is True).
enable_minmax_tuning (bool): Whether to enable weight min-max tuning (default is True).
lr (float): The learning rate (default is None, will be set to 1.0/iters).
minmax_lr (float): The learning rate for min-max tuning (default is None, it will be set to lr automatically).
low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
iters (int): Number of iterations (default is 200).
seqlen (int): Data length of the sequence for tuning (default is 2048).
n_samples (int): Number of samples (default is 512).
sampler (str): The sampling method (default is "rand").
seed (int): The random seed (default is 42).
n_blocks (int): Number of blocks (default is 1).
gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
not_use_best_mse (bool): Whether to use mean squared error (default is False).
dynamic_max_gap (int): The dynamic maximum gap (default is -1).
data_type (str): The data type to be used (default is "int").
scale_dtype (str): The data type of quantization scale to be used (default is "fp16"); different kernels
have different choices.
Returns:
The quantized model.
@@ -770,10 +768,8 @@ def autoround_quantize(
amp=amp,
device=device,
lr_scheduler=lr_scheduler,
dataloader=dataloader, ## to support later
dataset_name=dataset_name,
dataset_split=dataset_split,
use_quant_input=use_quant_input,
dataset=dataset,
enable_quanted_input=enable_quanted_input,
enable_minmax_tuning=enable_minmax_tuning,
lr=lr,
minmax_lr=minmax_lr,
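A hedged sketch of calling the updated function directly; the tiny model and tokenizer are placeholders, and per the new signature `dataset` may be either a dataset name or an existing DataLoader.

```python
# Hypothetical direct call to the updated autoround_quantize: `dataset`
# replaces the old dataloader/dataset_name/dataset_split arguments and
# `enable_quanted_input` replaces use_quant_input.
import transformers

from neural_compressor.adaptor.torch_utils.weight_only import autoround_quantize

model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"  # placeholder
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

q_model, autoround_config = autoround_quantize(
    model=model,
    tokenizer=tokenizer,
    bits=4,
    group_size=128,
    sym=False,
    dataset="NeelNanda/pile-10k",  # or a torch.utils.data.DataLoader
    enable_quanted_input=True,
    seqlen=32,
    n_samples=20,
    iters=10,
    scale_dtype="fp16",
)
```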
@@ -752,9 +752,7 @@ def test_AutoRound_quant(self):
tokenizer = transformers.AutoTokenizer.from_pretrained(
"hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
)
dataloader = get_dataloader(
tokenizer, seqlen=10, seed=42, train_bs=8, dataset_split="train", dataset_name="NeelNanda/pile-10k"
)
dataloader = get_dataloader(tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=20)
fp32_model = copy.deepcopy(self.gptj)
conf = PostTrainingQuantConfig(
approach="weight_only",
Expand All @@ -777,7 +775,7 @@ def test_AutoRound_quant(self):
recipes={
"autoround_args": {
"n_samples": 20,
"seq_len": 10,
"seqlen": 32,
"iters": 10,
"scale_dtype": "fp32",
"amp": False,
