diff --git a/src/brevitas_examples/bnn_pynq/bnn_pynq_train.py b/src/brevitas_examples/bnn_pynq/bnn_pynq_train.py index 5aa316aea..505617745 100644 --- a/src/brevitas_examples/bnn_pynq/bnn_pynq_train.py +++ b/src/brevitas_examples/bnn_pynq/bnn_pynq_train.py @@ -73,6 +73,8 @@ def parse_args(args): "--state_dict_to_pth", action='store_true', help="Saves a model state_dict into a pth and then exits") + parser.add_argument("--export_qonnx", action='store_true', help="Export QONNX Model") + parser.add_argument("--export_qcdq_onnx", action='store_true', help="Export QCDQ ONNX Model") return parser.parse_args(args) @@ -110,7 +112,7 @@ def launch(cmd_args): # Avoid creating new folders etc. if args.evaluate: - args.dry_run = True + args.dry_run = True # Comment out to export ONNX models from pre-trained # Init trainer trainer = Trainer(args) diff --git a/src/brevitas_examples/bnn_pynq/models/resnet.py b/src/brevitas_examples/bnn_pynq/models/resnet.py index 14efdf498..366cc4de9 100644 --- a/src/brevitas_examples/bnn_pynq/models/resnet.py +++ b/src/brevitas_examples/bnn_pynq/models/resnet.py @@ -9,7 +9,7 @@ import brevitas.nn as qnn from brevitas.quant import Int8WeightPerChannelFloat from brevitas.quant import Int8WeightPerTensorFloat -from brevitas.quant import Int32Bias +from brevitas.quant import Int24Bias from brevitas.quant import TruncTo8bit from brevitas.quant_tensor import QuantTensor @@ -120,8 +120,8 @@ def __init__( num_classes=10, act_bit_width=8, weight_bit_width=8, - round_average_pool=False, - last_layer_bias_quant=Int32Bias, + round_average_pool=True, + last_layer_bias_quant=Int24Bias, weight_quant=Int8WeightPerChannelFloat, first_layer_weight_quant=Int8WeightPerChannelFloat, last_layer_weight_quant=Int8WeightPerTensorFloat): diff --git a/src/brevitas_examples/bnn_pynq/trainer.py b/src/brevitas_examples/bnn_pynq/trainer.py index 78c1db97e..90ff99f5c 100644 --- a/src/brevitas_examples/bnn_pynq/trainer.py +++ b/src/brevitas_examples/bnn_pynq/trainer.py @@ -18,6 +18,9 @@ from torchvision.datasets import CIFAR10 from torchvision.datasets import MNIST +from brevitas.export import export_onnx_qcdq +from brevitas.export import export_qonnx + from .logger import EvalEpochMeters from .logger import Logger from .logger import TrainingEpochMeters @@ -149,6 +152,29 @@ def __init__(self, args): self.logger.info("Saving checkpoint model to {}".format(new_path)) exit(0) + if args.export_qonnx: + name = args.network.lower() + path = os.path.join(self.checkpoints_dir_path, name) + export_qonnx(model, self.train_loader.dataset[0][0].unsqueeze(0), path) + with open(path, "rb") as f: + bytes = f.read() + readable_hash = sha256(bytes).hexdigest()[:8] + new_path = os.path.join( + self.checkpoints_dir_path, "{}-qonnx-{}.onnx".format(name, readable_hash)) + os.rename(path, new_path) + self.logger.info("Exporting QONNX to {}".format(new_path)) + if args.export_qcdq_onnx: + name = args.network.lower() + path = os.path.join(self.checkpoints_dir_path, name) + export_onnx_qcdq(model, self.train_loader.dataset[0][0].unsqueeze(0), path) + with open(path, "rb") as f: + bytes = f.read() + readable_hash = sha256(bytes).hexdigest()[:8] + new_path = os.path.join( + self.checkpoints_dir_path, "{}-qcdq-{}.onnx".format(name, readable_hash)) + os.rename(path, new_path) + self.logger.info("Exporting QCDQ ONNX to {}".format(new_path)) + if args.gpus is not None and len(args.gpus) == 1: model = model.to(device=self.device) if args.gpus is not None and len(args.gpus) > 1: