Merge pull request #7 from Xilinx/mobilenetv1_example
Mobilenetv1 example
Showing 6 changed files with 427 additions and 64 deletions.
@@ -0,0 +1,138 @@
import argparse
import os
import random

import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from models import quant_mobilenet_v1

SEED = 123456

models = {'quant_mobilenet_v1': quant_mobilenet_v1}


parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('--imagenet-dir', help='path to folder containing Imagenet val folder')
parser.add_argument('--resume', type=str, help='Path to pretrained model')
parser.add_argument('--arch', choices=models.keys(), default='quant_mobilenet_v1',
                    help='model architecture: ' + ' | '.join(models.keys()))
parser.add_argument('--workers', default=4, type=int, help='number of data loading workers')
parser.add_argument('--batch-size', default=256, type=int, help='Minibatch size')
parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
parser.add_argument('--bit-width', default=4, type=int, help='Model bit-width')


def main():
    args = parser.parse_args()
    random.seed(SEED)
    torch.manual_seed(SEED)

    model = models[args.arch](bit_width=args.bit_width)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        cudnn.benchmark = True

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            model.load_state_dict(checkpoint['state_dict'], strict=False)

    valdir = os.path.join(args.imagenet_dir, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

    validate(val_loader, model, args)
    return


def validate(val_loader, model, args):
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    def print_accuracy(top1, top5, prefix=''):
        print('{}Avg acc@1 {top1.avg:.3f} Avg acc@5 {top5.avg:.3f}'
              .format(prefix, top1=top1, top5=top5))

    model.eval()
    with torch.no_grad():
        num_batches = len(val_loader)
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)
            output = model(images)
            # measure accuracy
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))
            print_accuracy(top1, top5, '{}/{}: '.format(i, num_batches))
        print_accuracy(top1, top5, 'Total:')
    return


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul(100.0 / batch_size))
        return res


if __name__ == '__main__':
    main()
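The script above follows the stock PyTorch ImageNet evaluation pattern: accuracy takes the k highest-scoring classes per sample and counts a hit whenever the true label appears among them, and AverageMeter keeps a running, sample-weighted average of those per-batch results. A minimal, self-contained sketch of that top-k logic on a toy batch (the tensors and printed values are illustrative only, not taken from this commit):

import torch

# Toy batch: 4 samples, 3 classes (values made up for illustration).
output = torch.tensor([[2.0, 1.0, 0.1],
                       [0.2, 3.0, 0.5],
                       [1.5, 0.3, 0.4],   # true class 2 only appears at top-2
                       [0.1, 0.2, 4.0]])
target = torch.tensor([0, 1, 2, 2])

maxk = 2
_, pred = output.topk(maxk, 1, True, True)   # (4, 2) class indices, best first
correct = pred.t().eq(target.view(1, -1))    # (maxk, 4) boolean hit matrix

for k in (1, 2):
    hits = correct[:k].reshape(-1).float().sum().item()
    print('acc@{}: {:.1f}%'.format(k, 100.0 * hits / target.size(0)))
# prints acc@1: 75.0% and acc@2: 100.0%

In the script these per-batch accuracies are fed into AverageMeter.update(acc, batch_size), so the final 'Total:' line is the average over the whole validation set rather than the last batch alone.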
@@ -0,0 +1,2 @@
from .mobilenetv1 import *
from .vgg import *
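These re-exports are what make the validation script's `from models import quant_mobilenet_v1` resolve. A quick sanity check, assuming the package is on the import path and `mobilenetv1.py` (not shown in this view) defines the `quant_mobilenet_v1` constructor used above:

from models import quant_mobilenet_v1

# Build the 4-bit variant, the default bit-width in the validation script.
model = quant_mobilenet_v1(bit_width=4)
print(sum(p.numel() for p in model.parameters()))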
@@ -0,0 +1,109 @@
import brevitas.nn as qnn
from brevitas.core.quant import QuantType
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType
from brevitas.core.stats import StatsOp

QUANT_TYPE = QuantType.INT
SCALING_MIN_VAL = 2e-16

ACT_SCALING_IMPL_TYPE = ScalingImplType.PARAMETER
ACT_SCALING_PER_CHANNEL = False
ACT_SCALING_RESTRICT_SCALING_TYPE = RestrictValueType.LOG_FP
ACT_MAX_VAL = 6.0
ACT_RETURN_QUANT_TENSOR = False
ACT_PER_CHANNEL_BROADCASTABLE_SHAPE = None

WEIGHT_SCALING_IMPL_TYPE = ScalingImplType.STATS
WEIGHT_SCALING_PER_OUTPUT_CHANNEL = True
WEIGHT_SCALING_STATS_OP = StatsOp.MAX
WEIGHT_RESTRICT_SCALING_TYPE = RestrictValueType.LOG_FP
WEIGHT_NARROW_RANGE = True


def make_quant_conv2d(in_channels,
                      out_channels,
                      kernel_size,
                      stride,
                      padding,
                      groups,
                      bias,
                      bit_width,
                      weight_quant_type=QUANT_TYPE,
                      weight_scaling_impl_type=WEIGHT_SCALING_IMPL_TYPE,
                      weight_scaling_stats_op=WEIGHT_SCALING_STATS_OP,
                      weight_scaling_per_output_channel=WEIGHT_SCALING_PER_OUTPUT_CHANNEL,
                      weight_restrict_scaling_type=WEIGHT_RESTRICT_SCALING_TYPE,
                      weight_narrow_range=WEIGHT_NARROW_RANGE,
                      weight_scaling_min_val=SCALING_MIN_VAL):
    return qnn.QuantConv2d(in_channels,
                           out_channels,
                           groups=groups,
                           kernel_size=kernel_size,
                           padding=padding,
                           stride=stride,
                           bias=bias,
                           weight_bit_width=bit_width,
                           weight_quant_type=weight_quant_type,
                           weight_scaling_impl_type=weight_scaling_impl_type,
                           weight_scaling_stats_op=weight_scaling_stats_op,
                           weight_scaling_per_output_channel=weight_scaling_per_output_channel,
                           weight_restrict_scaling_type=weight_restrict_scaling_type,
                           weight_narrow_range=weight_narrow_range,
                           weight_scaling_min_val=weight_scaling_min_val)


def make_quant_linear(in_channels,
                      out_channels,
                      bias,
                      bit_width,
                      weight_quant_type=QUANT_TYPE,
                      weight_scaling_impl_type=WEIGHT_SCALING_IMPL_TYPE,
                      weight_scaling_stats_op=WEIGHT_SCALING_STATS_OP,
                      weight_scaling_per_output_channel=WEIGHT_SCALING_PER_OUTPUT_CHANNEL,
                      weight_restrict_scaling_type=WEIGHT_RESTRICT_SCALING_TYPE,
                      weight_narrow_range=WEIGHT_NARROW_RANGE,
                      weight_scaling_min_val=SCALING_MIN_VAL):
    return qnn.QuantLinear(in_channels, out_channels,
                           bias=bias,
                           weight_bit_width=bit_width,
                           weight_quant_type=weight_quant_type,
                           weight_scaling_impl_type=weight_scaling_impl_type,
                           weight_scaling_stats_op=weight_scaling_stats_op,
                           weight_scaling_per_output_channel=weight_scaling_per_output_channel,
                           weight_restrict_scaling_type=weight_restrict_scaling_type,
                           weight_narrow_range=weight_narrow_range,
                           weight_scaling_min_val=weight_scaling_min_val)


def make_quant_relu(bit_width,
                    quant_type=QUANT_TYPE,
                    scaling_impl_type=ACT_SCALING_IMPL_TYPE,
                    scaling_per_channel=ACT_SCALING_PER_CHANNEL,
                    restrict_scaling_type=ACT_SCALING_RESTRICT_SCALING_TYPE,
                    scaling_min_val=SCALING_MIN_VAL,
                    max_val=ACT_MAX_VAL,
                    return_quant_tensor=ACT_RETURN_QUANT_TENSOR,
                    per_channel_broadcastable_shape=ACT_PER_CHANNEL_BROADCASTABLE_SHAPE):
    return qnn.QuantReLU(bit_width=bit_width,
                         quant_type=quant_type,
                         scaling_impl_type=scaling_impl_type,
                         scaling_per_channel=scaling_per_channel,
                         restrict_scaling_type=restrict_scaling_type,
                         scaling_min_val=scaling_min_val,
                         max_val=max_val,
                         return_quant_tensor=return_quant_tensor,
                         per_channel_broadcastable_shape=per_channel_broadcastable_shape)

def make_quant_avg_pool(bit_width,
                        kernel_size,
                        stride,
                        signed,
                        quant_type=QUANT_TYPE):
    return qnn.QuantAvgPool2d(kernel_size=kernel_size,
                              quant_type=quant_type,
                              signed=signed,
                              stride=stride,
                              min_overall_bit_width=bit_width,
                              max_overall_bit_width=bit_width)
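These factories pin the quantization policy down in one place: weights use per-output-channel max statistics with log-domain (LOG_FP) scaling and narrow-range INT quantization, while activations use a learned (PARAMETER) scale capped at 6.0. The quant_mobilenet_v1 definition that consumes them is not shown in this diff, but a minimal sketch of how they might compose into the depthwise-separable block at the heart of MobileNetV1 could look like this (the module path models.common and the block class are assumptions for illustration, not code from this commit):

import torch.nn as nn

# Hypothetical import path; the factory module's filename is not visible
# in this diff view.
from models.common import make_quant_conv2d, make_quant_relu


class QuantDwsConvBlock(nn.Module):
    """Depthwise-separable conv block: 3x3 depthwise conv then 1x1 pointwise
    conv, each followed by BatchNorm and a quantized ReLU."""

    def __init__(self, in_channels, out_channels, stride, bit_width):
        super().__init__()
        # Depthwise: groups == in_channels, one filter per input channel.
        self.dw_conv = make_quant_conv2d(in_channels, in_channels,
                                         kernel_size=3, stride=stride, padding=1,
                                         groups=in_channels, bias=False,
                                         bit_width=bit_width)
        self.dw_bn = nn.BatchNorm2d(in_channels)
        self.dw_act = make_quant_relu(bit_width)
        # Pointwise: 1x1 conv mixes channels and sets the output width.
        self.pw_conv = make_quant_conv2d(in_channels, out_channels,
                                         kernel_size=1, stride=1, padding=0,
                                         groups=1, bias=False,
                                         bit_width=bit_width)
        self.pw_bn = nn.BatchNorm2d(out_channels)
        self.pw_act = make_quant_relu(bit_width)

    def forward(self, x):
        x = self.dw_act(self.dw_bn(self.dw_conv(x)))
        return self.pw_act(self.pw_bn(self.pw_conv(x)))

Stacking such blocks with growing channel counts, and finishing with make_quant_avg_pool and make_quant_linear for the classifier head, would reproduce the usual MobileNetV1 layout.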