Commit

minimal docstrings

IamShubhamGupto committed Apr 14, 2023
1 parent 3671661 commit 418502f
Showing 4 changed files with 64 additions and 46 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -1,7 +1,7 @@
<br />
<p align="center">

<h1 align="center">Augmentations and Activations: Pushing The Performance Of Isotropic ConvNets</h1>
<h1 align="center">Activations And Augmentations: Pushing The Performance Of Isotropic ConvNets</h1>
<h4 align="center"><a href="https://github.com/datacrisis">Keifer Lee</a>, <a href="https://github.com/iamshubhamgupto">Shubham Gupta</a>, <a href="">Karan Sharma</a></h4>

</p>
34 changes: 18 additions & 16 deletions networks.py
@@ -1,5 +1,5 @@
import torch.nn as nn

# residual skip connection from ResNets
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
@@ -8,7 +8,7 @@ def __init__(self, fn):
def forward(self, x):
return self.fn(x) + x
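
As a quick illustration of the Residual wrapper above, a minimal sketch (assuming networks.py is importable and a reasonably recent PyTorch with padding="same" support):

```python
import torch
import torch.nn as nn
from networks import Residual  # the wrapper defined above

# wrap any module whose output shape matches its input shape
block = Residual(nn.Conv2d(16, 16, kernel_size=3, padding="same"))
x = torch.randn(2, 16, 8, 8)
print(block(x).shape)  # torch.Size([2, 16, 8, 8]) -- computes fn(x) + x
```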


# Vanilla ConvMixer architecture
def ConvMixer(dim, depth, kernel_size=5, patch_size=2, n_classes=10, activation='GELU'):

#Det activation func
@@ -26,21 +26,22 @@ def ConvMixer(dim, depth, kernel_size=5, patch_size=2, n_classes=10, activation=
act_fx,
nn.BatchNorm2d(dim),
*[nn.Sequential(
Residual(nn.Sequential(
nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
act_fx,
nn.BatchNorm2d(dim)
)),
nn.Conv2d(dim, dim, kernel_size=1),
act_fx,
nn.BatchNorm2d(dim)
) for i in range(depth)],
Residual(nn.Sequential(
nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"), #depthwise convolution
act_fx,
nn.BatchNorm2d(dim)
)
),
nn.Conv2d(dim, dim, kernel_size=1), #pointwise convolution
act_fx,
nn.BatchNorm2d(dim)
) for i in range(depth)],
nn.AdaptiveAvgPool2d((1,1)),
nn.Flatten(),
nn.Linear(dim, n_classes)
nn.Linear(dim, n_classes) #fully connected layer
)
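
For reference, a minimal sketch of instantiating the ConvMixer defined above and checking its output shape on CIFAR-10-sized inputs (a sanity check only; the hyperparameters mirror the train.py defaults, and networks.py is assumed to be importable):

```python
import torch
from networks import ConvMixer  # the function defined above

# defaults mirror train.py: --hdim 256 --depth 8 --psize 2 --conv-ks 5
model = ConvMixer(dim=256, depth=8, kernel_size=5, patch_size=2,
                  n_classes=10, activation='GELU')

x = torch.randn(4, 3, 32, 32)  # dummy CIFAR-10 batch
logits = model(x)
print(logits.shape)            # expected: torch.Size([4, 10])
```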


# modified ConvMixer architecture, with inter block skip connections and deeper layers.
def ConvMixerXL(dim, depth, kernel_size=5, patch_size=2, n_classes=10, skip_period=3, activation='GELU'):

#Det activation func
@@ -58,15 +59,16 @@ def ConvMixerXL(dim, depth, kernel_size=5, patch_size=2, n_classes=10, skip_peri
act_fx,
nn.BatchNorm2d(dim),
*[nn.Sequential(
Residual(nn.Sequential(*[nn.Sequential(
Residual(nn.Sequential(*[nn.Sequential( #inter block skip connections
Residual(
nn.Sequential(
# depthwise convolution
nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
act_fx,
nn.BatchNorm2d(dim)
)
),
nn.Conv2d(dim, dim, kernel_size=1),
nn.Conv2d(dim, dim, kernel_size=1), #pointwise convolution
act_fx,
nn.BatchNorm2d(dim)
) for i in range(depth//skip_period)]
@@ -76,5 +78,5 @@ def ConvMixerXL(dim, depth, kernel_size=5, patch_size=2, n_classes=10, skip_peri
],
nn.AdaptiveAvgPool2d((1,1)),
nn.Flatten(),
nn.Linear(dim, n_classes)
nn.Linear(dim, n_classes) #fully connected layer
)
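
Similarly, a minimal sketch for the XL variant; the grouping of mixer blocks is controlled by skip_period via the depth//skip_period comprehension above, so depth is assumed to be a multiple of skip_period here:

```python
import torch
from networks import ConvMixerXL  # the function defined above

# each group of mixer blocks is additionally wrapped in a Residual connection
model = ConvMixerXL(dim=256, depth=9, kernel_size=5, patch_size=2,
                    n_classes=10, skip_period=3, activation='GELU')

x = torch.randn(4, 3, 32, 32)
print(model(x).shape)  # expected: torch.Size([4, 10])
```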
43 changes: 24 additions & 19 deletions train.py
@@ -30,30 +30,30 @@
parser.add_argument('--skip_period', default=3,
help='Denominator in extra skip connection periodicity computation; only used in ConvMixer-XL',
type=int)
parser.add_argument('--activation', default='GELU', choices=['GELU','ReLU','SiLU'])
parser.add_argument('--batch-size', default=64, type=int)
parser.add_argument('--scale', default=0.75, type=float)
parser.add_argument('--reprob', default=0.2, type=float)
parser.add_argument('--ra-m', default=12, type=int)
parser.add_argument('--ra-n', default=2, type=int)
parser.add_argument('--jitter', default=0.2, type=float)
parser.add_argument('--activation', default='GELU', choices=['GELU','ReLU','SiLU'], help='Activation function')
parser.add_argument('--batch-size', default=64, type=int, help='Batch size')
parser.add_argument('--scale', default=0.75, type=float, help='Scale factor for resizing images')
parser.add_argument('--reprob', default=0.2, type=float, help='Random erase probability')
parser.add_argument('--ra-m', default=12, type=int, help='Magnitude of random augmentation')
parser.add_argument('--ra-n', default=2, type=int, help='Number of random augmentations')
parser.add_argument('--jitter', default=0.2, type=float, help='Jittering factor')
parser.add_argument('--no_aug',action='store_true',help="Enable flag to remove augmentations")

parser.add_argument('--use_cutmix',action='store_true',help="Enable CutMix regularizer")
parser.add_argument('--cutmix_alpha', type=float, default=1.0)
parser.add_argument('--cutmix_alpha', type=float, default=1.0, help="CutMix alpha parameter")
parser.add_argument('--use_mixup',action='store_true',help="Enable MixUp regularizer")
parser.add_argument('--mixup_alpha', type=float, default=1.0)
parser.add_argument('--mixup_alpha', type=float, default=1.0, help="MixUp alpha parameter")

parser.add_argument('--hdim', default=256, type=int)
parser.add_argument('--depth', default=8, type=int)
parser.add_argument('--psize', default=2, type=int)
parser.add_argument('--conv-ks', default=5, type=int)
parser.add_argument('--hdim', default=256, type=int, help='Hidden dimension')
parser.add_argument('--depth', default=8, type=int, help='Depth of network')
parser.add_argument('--psize', default=2, type=int, help='Patch size')
parser.add_argument('--conv-ks', default=5, type=int, help='Kernel size of convolutions')

parser.add_argument('--wd', default=0.01, type=float)
parser.add_argument('--clip-norm', action='store_true')
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--lr-max', default=0.005, type=float)
parser.add_argument('--workers', default=8, type=int)
parser.add_argument('--wd', default=0.01, type=float, help='Weight decay')
parser.add_argument('--clip-norm', action='store_true', help='Enable gradient clipping')
parser.add_argument('--epochs', default=100, type=int, help='Number of epochs')
parser.add_argument('--lr-max', default=0.005, type=float, help='Max learning rate')
parser.add_argument('--workers', default=8, type=int, help='Number of workers for dataloader')

parser.add_argument('--save_dir',default='./',help='Directory to save outputs to')

@@ -90,6 +90,7 @@
transforms.ToTensor(),
transforms.Normalize(cifar10_mean, cifar10_std)])

# No augmentations for test set
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(cifar10_mean, cifar10_std)
@@ -136,13 +137,17 @@
model = ConvMixerXL(args.hdim, args.depth, patch_size=args.psize, kernel_size=args.conv_ks, n_classes=10, skip_period=args.skip_period,
activation=args.activation)

# move the model to the GPU for faster training (DataParallel handles multi-GPU)
model = nn.DataParallel(model).cuda()

# triangular learning rate scheduler, increases then decreases
lr_schedule = lambda t: np.interp([t], [0, args.epochs*2//5, args.epochs*4//5, args.epochs],
[0, args.lr_max, args.lr_max/20.0, 0])[0]
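
To make the shape of this schedule concrete, a short sketch evaluating the same interpolation with the defaults (--epochs 100, --lr-max 0.005); the values follow directly from np.interp, not from a training run:

```python
import numpy as np

epochs, lr_max = 100, 0.005
lr_schedule = lambda t: np.interp([t], [0, epochs*2//5, epochs*4//5, epochs],
                                  [0, lr_max, lr_max/20.0, 0])[0]

print(lr_schedule(0))    # 0.0      -> start at zero
print(lr_schedule(40))   # 0.005    -> peak at 2/5 of training
print(lr_schedule(80))   # 0.00025  -> lr_max/20 at 4/5 of training
print(lr_schedule(100))  # 0.0      -> back to zero at the end
```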

# AdamW optimizer for isotropic architecture
opt = optim.AdamW(model.parameters(), lr=args.lr_max, weight_decay=args.wd) #optimizer

# CutMix and MixUp criteria disabled due to implementation issues
# if args.use_cutmix:
# train_criterion = CutMixCriterion(reduction='mean')
# else:
@@ -253,7 +258,7 @@
test_loss_ls = [{'Value':round(test_loss/m,5)}]
test_acc_ls = [{'Value':round(test_acc/m,5)}]


# log to terminal
print(f'[{args.name}-{args.model}] | Test Acc: {test_acc/m:.4f}, Time: {time.time() - start:.1f}')


31 changes: 21 additions & 10 deletions utils.py
@@ -2,6 +2,10 @@
Implementation of CutMix data augmentation technique and MixUp data augmentation technique.
https://github.com/hysts/pytorch_cutmix/blob/master/cutmix.py
https://github.com/hysts/pytorch_mixup/blob/master/utils.py
We tried to use the code below but ran into several issues while training.
Due to time constraints, we decided to disable these function calls for now.
The original ConvMixer is trained with the timm library, which provides these augmentations by default.
'''
import numpy as np
import torch
@@ -12,27 +16,30 @@
def cutmix(batch, alpha):
data, targets = batch

# generate mixed sample
indices = torch.randperm(data.size(0))
shuffled_data = data[indices]
shuffled_targets = targets[indices]

lam = np.random.beta(alpha, alpha)

image_h, image_w = data.shape[2:]
cx = np.random.uniform(0, image_w)
cy = np.random.uniform(0, image_h)
w = image_w * np.sqrt(1 - lam)
h = image_h * np.sqrt(1 - lam)
# calculate bbox
x0 = int(np.round(max(cx - w / 2, 0)))
x1 = int(np.round(min(cx + w / 2, image_w)))
y0 = int(np.round(max(cy - h / 2, 0)))
y1 = int(np.round(min(cy + h / 2, image_h)))

# paste shuffled image on top of original image
data[:, :, y0:y1, x0:x1] = shuffled_data[:, :, y0:y1, x0:x1]
targets = (targets, shuffled_targets, lam)

return data, targets
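
Although train.py currently bypasses this path, a minimal sketch of calling cutmix on a dummy batch (utils.py is assumed to be importable; shapes mimic CIFAR-10):

```python
import torch
from utils import cutmix  # the function defined above

images = torch.randn(8, 3, 32, 32)   # dummy CIFAR-10-sized batch
labels = torch.randint(0, 10, (8,))  # integer class labels

mixed, (targets_a, targets_b, lam) = cutmix((images, labels), alpha=1.0)
print(mixed.shape)  # torch.Size([8, 3, 32, 32]); the same patch region is swapped across the batch
print(lam)          # mixing weight drawn from Beta(alpha, alpha)
```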

# one hot encoding function
def onehot(label, n_classes):
return torch.zeros(label.size(0), n_classes).scatter_(
1, label.view(-1, 1), 1)
@@ -47,10 +54,13 @@ def mixup(data, targets, alpha, n_classes):
targets2 = onehot(targets2, n_classes)

lam = torch.FloatTensor([np.random.beta(alpha, alpha)])
# create a linear combination of the two samples
data = data * lam + data2 * (1 - lam)
targets = (targets, targets2, lam)
return data, targets
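
The core of mixup is the convex combination above; a self-contained sketch of just that step (the full helper also one-hot encodes the labels via onehot):

```python
import numpy as np
import torch

alpha = 1.0
data = torch.randn(8, 3, 32, 32)            # original batch
data2 = data[torch.randperm(data.size(0))]  # shuffled copy to mix with

lam = torch.FloatTensor([np.random.beta(alpha, alpha)])
mixed = data * lam + data2 * (1 - lam)      # element-wise convex combination
print(mixed.shape)                          # torch.Size([8, 3, 32, 32])
```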

# randomly apply either cutmix or mixup
# requires both cutmix and mixup flags to be enabled
class CustomCollator:
def __init__(self, cutmix_alpha, mixup_alpha, num_classes):
self.cutmix_alpha = cutmix_alpha
@@ -64,18 +74,19 @@ def __call__(self, batch):
else:
batch = mixup(*batch, self.mixup_alpha, self.num_classes)
return batch
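
A hypothetical sketch of plugging this collator into a DataLoader (the current train.py disables CutMix/MixUp; torchvision's CIFAR10 is assumed for the dataset, and the elided part of __call__ is assumed to collate the list of samples into tensors, as in the reference implementations linked in the module docstring):

```python
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from utils import CustomCollator  # the class defined above

train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                          transform=transforms.ToTensor())
collator = CustomCollator(cutmix_alpha=1.0, mixup_alpha=1.0, num_classes=10)

# the collator builds each batch and then randomly applies cutmix or mixup to it
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, collate_fn=collator)
images, (targets_a, targets_b, lam) = next(iter(train_loader))
```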

# deprecated

# class CutMixCriterion:
# def __init__(self, reduction):
# self.criterion = nn.CrossEntropyLoss(reduction=reduction)

class CutMixCriterion:
def __init__(self, reduction):
self.criterion = nn.CrossEntropyLoss(reduction=reduction)

def __call__(self, preds, targets):
targets1, targets2, lam = targets
return lam * self.criterion(
preds, targets1) + (1 - lam) * self.criterion(preds, targets2)
# def __call__(self, preds, targets):
# targets1, targets2, lam = targets
# return lam * self.criterion(
# preds, targets1) + (1 - lam) * self.criterion(preds, targets2)

# deprecated

# def cross_entropy_loss(input, target, size_average=True):
# input = F.log_softmax(input, dim=1)
# loss = -torch.sum(input * target)
