Feat: Added ONNX export to BNN-PYNQ example #916

Open · wants to merge 3 commits into base: dev
src/brevitas_examples/bnn_pynq/bnn_pynq_train.py (4 changes: 3 additions & 1 deletion)
@@ -73,6 +73,8 @@ def parse_args(args):
         "--state_dict_to_pth",
         action='store_true',
         help="Saves a model state_dict into a pth and then exits")
+    parser.add_argument("--export_qonnx", action='store_true', help="Export QONNX Model")
+    parser.add_argument("--export_qcdq_onnx", action='store_true', help="Export QCDQ ONNX Model")
     return parser.parse_args(args)


@@ -110,7 +112,7 @@ def launch(cmd_args):
 
     # Avoid creating new folders etc.
     if args.evaluate:
-        args.dry_run = True
+        args.dry_run = True # Comment out to export ONNX models from pre-trained
 
     # Init trainer
     trainer = Trainer(args)
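With these flags in place, the example can be driven programmatically through launch() from bnn_pynq_train.py. The sketch below is a hedged illustration only: launch() and the two export flags come from this diff, while the --network/--evaluate options are the example's existing ones and the network name "CNV_2W2A" is an assumption; per the comment added above, args.dry_run may need to be disabled to export from a pre-trained checkpoint.

# Minimal usage sketch (assumptions noted above), not part of the PR itself.
from brevitas_examples.bnn_pynq.bnn_pynq_train import launch

launch([
    "--network", "CNV_2W2A",   # hypothetical choice of BNN-PYNQ topology
    "--evaluate",              # skip training; see the dry_run note in the diff above
    "--export_qonnx",          # writes <network>-qonnx-<hash>.onnx under the checkpoints dir
    "--export_qcdq_onnx",      # writes <network>-qcdq-<hash>.onnx under the checkpoints dir
])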
src/brevitas_examples/bnn_pynq/models/resnet.py (6 changes: 3 additions & 3 deletions)
@@ -9,7 +9,7 @@
 import brevitas.nn as qnn
 from brevitas.quant import Int8WeightPerChannelFloat
 from brevitas.quant import Int8WeightPerTensorFloat
-from brevitas.quant import Int32Bias
+from brevitas.quant import Int24Bias
 from brevitas.quant import TruncTo8bit
 from brevitas.quant_tensor import QuantTensor

@@ -120,8 +120,8 @@ def __init__(
             num_classes=10,
             act_bit_width=8,
             weight_bit_width=8,
-            round_average_pool=False,
-            last_layer_bias_quant=Int32Bias,
+            round_average_pool=True,
+            last_layer_bias_quant=Int24Bias,
             weight_quant=Int8WeightPerChannelFloat,
             first_layer_weight_quant=Int8WeightPerChannelFloat,
             last_layer_weight_quant=Int8WeightPerTensorFloat):
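For context on the quantizer swap above, the sketch below shows how a bias quantizer such as Int24Bias is attached to a Brevitas layer. It is a minimal illustration, not code from this example: the QuantIdentity/QuantLinear pairing and the layer sizes are assumptions. Like Int32Bias, Int24Bias derives its scale from the input and weight scales, so the layer needs a quantized input.

# Hedged sketch: attaching a 24-bit bias quantizer to a quantized linear layer.
# The layer shapes and the QuantIdentity front-end are illustrative assumptions.
import torch
import brevitas.nn as qnn
from brevitas.quant import Int24Bias
from brevitas.quant import Int8WeightPerTensorFloat

quant_inp = qnn.QuantIdentity(return_quant_tensor=True)  # provides the input scale the bias quantizer needs
fc = qnn.QuantLinear(
    64, 10,
    bias=True,
    weight_quant=Int8WeightPerTensorFloat,
    bias_quant=Int24Bias)  # narrower than the previous Int32Bias default

out = fc(quant_inp(torch.randn(1, 64)))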
src/brevitas_examples/bnn_pynq/trainer.py (26 changes: 26 additions & 0 deletions)
@@ -18,6 +18,9 @@
 from torchvision.datasets import CIFAR10
 from torchvision.datasets import MNIST
 
+from brevitas.export import export_onnx_qcdq
+from brevitas.export import export_qonnx
+
 from .logger import EvalEpochMeters
 from .logger import Logger
 from .logger import TrainingEpochMeters
@@ -149,6 +152,29 @@ def __init__(self, args):
             self.logger.info("Saving checkpoint model to {}".format(new_path))
             exit(0)
 
+        if args.export_qonnx:
+            name = args.network.lower()
+            path = os.path.join(self.checkpoints_dir_path, name)
+            export_qonnx(model, self.train_loader.dataset[0][0].unsqueeze(0), path)
+            with open(path, "rb") as f:
+                bytes = f.read()
+            readable_hash = sha256(bytes).hexdigest()[:8]
+            new_path = os.path.join(
+                self.checkpoints_dir_path, "{}-qonnx-{}.onnx".format(name, readable_hash))
+            os.rename(path, new_path)
+            self.logger.info("Exporting QONNX to {}".format(new_path))
+        if args.export_qcdq_onnx:
+            name = args.network.lower()
+            path = os.path.join(self.checkpoints_dir_path, name)
+            export_onnx_qcdq(model, self.train_loader.dataset[0][0].unsqueeze(0), path)
+            with open(path, "rb") as f:
+                bytes = f.read()
+            readable_hash = sha256(bytes).hexdigest()[:8]
+            new_path = os.path.join(
+                self.checkpoints_dir_path, "{}-qcdq-{}.onnx".format(name, readable_hash))
+            os.rename(path, new_path)
+            self.logger.info("Exporting QCDQ ONNX to {}".format(new_path))
+
         if args.gpus is not None and len(args.gpus) == 1:
             model = model.to(device=self.device)
         if args.gpus is not None and len(args.gpus) > 1:
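Outside the Trainer, the same two export entry points can be exercised directly. The sketch below is a minimal, hedged example: only export_qonnx, export_onnx_qcdq, and their (model, input, path) calling convention come from the diff above; the one-layer QuantConv2d stands in for the trained BNN-PYNQ network and the output file names are arbitrary.

# Standalone sketch of the export calls added in trainer.py.
# The stand-in model and output paths are illustrative assumptions.
import torch
import brevitas.nn as qnn
from brevitas.export import export_onnx_qcdq
from brevitas.export import export_qonnx

model = qnn.QuantConv2d(3, 8, kernel_size=3, weight_bit_width=4)  # stand-in for the trained network
model.eval()

dummy_input = torch.randn(1, 3, 32, 32)  # CIFAR10-shaped, like train_loader.dataset[0][0].unsqueeze(0)

export_qonnx(model, dummy_input, "model-qonnx.onnx")      # quantization kept as QONNX Quant nodes
export_onnx_qcdq(model, dummy_input, "model-qcdq.onnx")   # quantization lowered to QuantizeLinear/DequantizeLinear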