Merge pull request #7 from Xilinx/mobilenetv1_example
Mobilenetv1 example
volcacius authored Oct 15, 2019
2 parents 5db353c + 8435430 commit a94daa8
Showing 6 changed files with 427 additions and 64 deletions.
60 changes: 0 additions & 60 deletions examples/common.py

This file was deleted.

138 changes: 138 additions & 0 deletions examples/imagenet_val.py
@@ -0,0 +1,138 @@
import argparse
import os
import random

import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from models import quant_mobilenet_v1

SEED = 123456

models = {'quant_mobilenet_v1': quant_mobilenet_v1}


parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('--imagenet-dir', help='path to folder containing Imagenet val folder')
parser.add_argument('--resume', type=str, help='Path to pretrained model')
parser.add_argument('--arch', choices=models.keys(), default='quant_mobilenet_v1',
                    help='model architecture: ' + ' | '.join(models.keys()))
parser.add_argument('--workers', default=4, type=int, help='number of data loading workers')
parser.add_argument('--batch-size', default=256, type=int, help='Minibatch size')
parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
parser.add_argument('--bit-width', default=4, type=int, help='Model bit-width')


def main():
    args = parser.parse_args()
    random.seed(SEED)
    torch.manual_seed(SEED)

    model = models[args.arch](bit_width=args.bit_width)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        cudnn.benchmark = True

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            model.load_state_dict(checkpoint['state_dict'], strict=False)

    valdir = os.path.join(args.imagenet_dir, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True)

    validate(val_loader, model, args)
    return


def validate(val_loader, model, args):
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    def print_accuracy(top1, top5, prefix=''):
        print('{}Avg acc@1 {top1.avg:.3f} Avg acc@5 {top5.avg:.3f}'
              .format(prefix, top1=top1, top5=top5))

    model.eval()
    with torch.no_grad():
        num_batches = len(val_loader)
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)
            output = model(images)
            # measure accuracy
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))
            print_accuracy(top1, top5, '{}/{}: '.format(i, num_batches))
        print_accuracy(top1, top5, 'Total:')
    return


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul(100.0 / batch_size))
        return res


if __name__ == '__main__':
    main()
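
For reference, a minimal sketch of how the script above might be driven; the dataset root, checkpoint filename and GPU id below are placeholders and are not part of this commit:

# Hypothetical driver, assuming it is appended to examples/imagenet_val.py
# (equivalent to: python imagenet_val.py --imagenet-dir /data/imagenet
#                   --resume quant_mobilenet_v1_4b.pth --bit-width 4 --gpu 0).
import sys

sys.argv = ['imagenet_val.py',
            '--imagenet-dir', '/data/imagenet',        # must contain a 'val' subfolder
            '--resume', 'quant_mobilenet_v1_4b.pth',   # pretrained quantized checkpoint
            '--arch', 'quant_mobilenet_v1',
            '--bit-width', '4',
            '--gpu', '0']
main()  # prints per-batch and total top-1 / top-5 accuracy
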
2 changes: 2 additions & 0 deletions examples/models/__init__.py
@@ -0,0 +1,2 @@
from .mobilenetv1 import *
from .vgg import *
109 changes: 109 additions & 0 deletions examples/models/common.py
@@ -0,0 +1,109 @@
import brevitas.nn as qnn
from brevitas.core.quant import QuantType
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType
from brevitas.core.stats import StatsOp

QUANT_TYPE = QuantType.INT  # integer quantization for weights and activations
SCALING_MIN_VAL = 2e-16  # lower bound on scale factors

# Activation (QuantReLU) defaults
ACT_SCALING_IMPL_TYPE = ScalingImplType.PARAMETER  # activation scale is a learned parameter
ACT_SCALING_PER_CHANNEL = False  # single scale shared across channels
ACT_SCALING_RESTRICT_SCALING_TYPE = RestrictValueType.LOG_FP
ACT_MAX_VAL = 6.0  # ReLU clipping value (cf. ReLU6)
ACT_RETURN_QUANT_TENSOR = False  # return a plain torch.Tensor rather than a QuantTensor
ACT_PER_CHANNEL_BROADCASTABLE_SHAPE = None

# Weight quantization defaults
WEIGHT_SCALING_IMPL_TYPE = ScalingImplType.STATS  # weight scale derived from weight statistics
WEIGHT_SCALING_PER_OUTPUT_CHANNEL = True  # one scale per output channel
WEIGHT_SCALING_STATS_OP = StatsOp.MAX  # statistic used for the scale: per-channel max
WEIGHT_RESTRICT_SCALING_TYPE = RestrictValueType.LOG_FP
WEIGHT_NARROW_RANGE = True  # symmetric range, e.g. [-7, 7] instead of [-8, 7] at 4 bits


def make_quant_conv2d(in_channels,
                      out_channels,
                      kernel_size,
                      stride,
                      padding,
                      groups,
                      bias,
                      bit_width,
                      weight_quant_type=QUANT_TYPE,
                      weight_scaling_impl_type=WEIGHT_SCALING_IMPL_TYPE,
                      weight_scaling_stats_op=WEIGHT_SCALING_STATS_OP,
                      weight_scaling_per_output_channel=WEIGHT_SCALING_PER_OUTPUT_CHANNEL,
                      weight_restrict_scaling_type=WEIGHT_RESTRICT_SCALING_TYPE,
                      weight_narrow_range=WEIGHT_NARROW_RANGE,
                      weight_scaling_min_val=SCALING_MIN_VAL):
    return qnn.QuantConv2d(in_channels,
                           out_channels,
                           groups=groups,
                           kernel_size=kernel_size,
                           padding=padding,
                           stride=stride,
                           bias=bias,
                           weight_bit_width=bit_width,
                           weight_quant_type=weight_quant_type,
                           weight_scaling_impl_type=weight_scaling_impl_type,
                           weight_scaling_stats_op=weight_scaling_stats_op,
                           weight_scaling_per_output_channel=weight_scaling_per_output_channel,
                           weight_restrict_scaling_type=weight_restrict_scaling_type,
                           weight_narrow_range=weight_narrow_range,
                           weight_scaling_min_val=weight_scaling_min_val)


def make_quant_linear(in_channels,
                      out_channels,
                      bias,
                      bit_width,
                      weight_quant_type=QUANT_TYPE,
                      weight_scaling_impl_type=WEIGHT_SCALING_IMPL_TYPE,
                      weight_scaling_stats_op=WEIGHT_SCALING_STATS_OP,
                      weight_scaling_per_output_channel=WEIGHT_SCALING_PER_OUTPUT_CHANNEL,
                      weight_restrict_scaling_type=WEIGHT_RESTRICT_SCALING_TYPE,
                      weight_narrow_range=WEIGHT_NARROW_RANGE,
                      weight_scaling_min_val=SCALING_MIN_VAL):
    return qnn.QuantLinear(in_channels, out_channels,
                           bias=bias,
                           weight_bit_width=bit_width,
                           weight_quant_type=weight_quant_type,
                           weight_scaling_impl_type=weight_scaling_impl_type,
                           weight_scaling_stats_op=weight_scaling_stats_op,
                           weight_scaling_per_output_channel=weight_scaling_per_output_channel,
                           weight_restrict_scaling_type=weight_restrict_scaling_type,
                           weight_narrow_range=weight_narrow_range,
                           weight_scaling_min_val=weight_scaling_min_val)


def make_quant_relu(bit_width,
                    quant_type=QUANT_TYPE,
                    scaling_impl_type=ACT_SCALING_IMPL_TYPE,
                    scaling_per_channel=ACT_SCALING_PER_CHANNEL,
                    restrict_scaling_type=ACT_SCALING_RESTRICT_SCALING_TYPE,
                    scaling_min_val=SCALING_MIN_VAL,
                    max_val=ACT_MAX_VAL,
                    return_quant_tensor=ACT_RETURN_QUANT_TENSOR,
                    per_channel_broadcastable_shape=ACT_PER_CHANNEL_BROADCASTABLE_SHAPE):
    return qnn.QuantReLU(bit_width=bit_width,
                         quant_type=quant_type,
                         scaling_impl_type=scaling_impl_type,
                         scaling_per_channel=scaling_per_channel,
                         restrict_scaling_type=restrict_scaling_type,
                         scaling_min_val=scaling_min_val,
                         max_val=max_val,
                         return_quant_tensor=return_quant_tensor,
                         per_channel_broadcastable_shape=per_channel_broadcastable_shape)


def make_quant_avg_pool(bit_width,
                        kernel_size,
                        stride,
                        signed,
                        quant_type=QUANT_TYPE):
    return qnn.QuantAvgPool2d(kernel_size=kernel_size,
                              quant_type=quant_type,
                              signed=signed,
                              stride=stride,
                              min_overall_bit_width=bit_width,
                              max_overall_bit_width=bit_width)
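
The factories above keep the Brevitas quantization settings in one place so the model definitions stay compact. As a rough illustration only (the actual examples/models/mobilenetv1.py added by this commit is not shown on this page), a quantized depthwise-separable block could be assembled from them roughly as follows:

# Illustrative sketch, not the file added by this commit; assumes it lives next to
# common.py inside examples/models/.
import torch.nn as nn

from .common import make_quant_conv2d, make_quant_relu


class QuantDwsConvBlock(nn.Sequential):
    """Depthwise 3x3 followed by pointwise 1x1, both quantized to the same bit width."""

    def __init__(self, in_channels, out_channels, stride, bit_width):
        super(QuantDwsConvBlock, self).__init__(
            # depthwise: one filter per input channel (groups == in_channels)
            make_quant_conv2d(in_channels, in_channels, kernel_size=3, stride=stride,
                              padding=1, groups=in_channels, bias=False, bit_width=bit_width),
            nn.BatchNorm2d(in_channels),
            make_quant_relu(bit_width),
            # pointwise: 1x1 convolution mixing channels
            make_quant_conv2d(in_channels, out_channels, kernel_size=1, stride=1,
                              padding=0, groups=1, bias=False, bit_width=bit_width),
            nn.BatchNorm2d(out_channels),
            make_quant_relu(bit_width))

Stacking blocks of this shape with growing channel counts and occasional stride-2 stages gives the standard MobileNetV1 topology.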

[Diffs for the remaining changed files were not loaded on this page.]
