Commit
Added cutmix / mixup augment
Keifer Lee committed Apr 12, 2023
1 parent 1b613d5 commit d4d4988
Showing 5 changed files with 380 additions and 12 deletions.
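
For context on the regularizers being added: MixUp (Zhang et al., 2018) trains on convex combinations of example pairs, x̃ = λ·x_i + (1−λ)·x_j with λ ~ Beta(α, α), scoring predictions against both labels with weights λ and 1−λ. CutMix (Yun et al., 2019) instead pastes a random rectangle from one image into another and weights the two labels by the pasted-area ratio. Both need a batch-level hook to pair samples up (added below as a custom collate function) and a loss that accepts two targets plus λ (added as a dedicated training criterion).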
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
+*.txt
+*.pkl
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
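
The two new ignore patterns keep run artifacts out of version control: the training script dumps its argparse config and metric logs as .txt files and its checkpoint as a .pkl (see the save block at the end of train.py below).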
57 changes: 45 additions & 12 deletions train.py
@@ -10,6 +10,7 @@
 import json
 import os
 
+from utils import *
 from networks import ConvMixer, ConvMixerXL
 
 ###========================================================================
@@ -38,6 +39,11 @@
 parser.add_argument('--jitter', default=0.2, type=float)
 parser.add_argument('--no_aug',action='store_true',help="Enable flag to remove augmentations")
 
+parser.add_argument('--use_cutmix',action='store_true',help="Enable CutMix regularizer")
+parser.add_argument('--cutmix_alpha', type=float, default=1.0)
+parser.add_argument('--use_mixup',action='store_true',help="Enable MixUp regularizer")
+parser.add_argument('--mixup_alpha', type=float, default=1.0)
+
 parser.add_argument('--hdim', default=256, type=int)
 parser.add_argument('--depth', default=8, type=int)
 parser.add_argument('--psize', default=2, type=int)
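
With the defaults cutmix_alpha = mixup_alpha = 1.0, the mixing weight λ ~ Beta(1, 1) is uniform on [0, 1]; raising α concentrates λ near 0.5 (more aggressive mixing), while lowering it pushes λ toward 0 or 1 (mostly-unmixed batches).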
@@ -95,33 +101,54 @@
 testvalset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=test_transform)
 
+
+if args.use_cutmix:
+    collator = CustomCollator(args.cutmix_alpha, args.mixup_alpha, 10)
+    trainloader = torch.utils.data.DataLoader(trainset,
+                                              batch_size=args.batch_size,
+                                              shuffle=True,
+                                              num_workers=args.workers,
+                                              collate_fn=collator)
+else:
+    trainloader = torch.utils.data.DataLoader(trainset,
+                                              batch_size=args.batch_size,
+                                              shuffle=True,
+                                              num_workers=args.workers)
+
 #Split test-val set
 ln = len(testvalset)
 valset,testset = torch.utils.data.random_split(testvalset,[ln//2,ln//2])
 
 #Dataloaders
-trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
-                                          shuffle=True, num_workers=args.workers)
-valloader = torch.utils.data.DataLoader(valset, batch_size=args.batch_size,
-                                        shuffle=False, num_workers=args.workers)
-testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
-                                         shuffle=False, num_workers=args.workers)
+valloader = torch.utils.data.DataLoader(valset,
+                                        batch_size=args.batch_size,
+                                        shuffle=False,
+                                        num_workers=args.workers)
+testloader = torch.utils.data.DataLoader(testset,
+                                         batch_size=args.batch_size,
+                                         shuffle=False,
+                                         num_workers=args.workers)
 
 
 ###========================================================================
 #Setup model, optim and scheduler
 if args.model == 'CM':
     model = ConvMixer(args.hdim, args.depth, patch_size=args.psize, kernel_size=args.conv_ks, n_classes=10, activation=args.activation)
 elif args.model == 'CM-XL':
-    model = ConvMixerXL(args.hdim, args.depth, patch_size=args.psize, kernel_size=args.conv_ks, n_classes=10, skip_period=args.skip_period, activation=args.activation)
+    model = ConvMixerXL(args.hdim, args.depth, patch_size=args.psize, kernel_size=args.conv_ks, n_classes=10, skip_period=args.skip_period,
+                        activation=args.activation)
 
 model = nn.DataParallel(model).cuda()
 
 lr_schedule = lambda t: np.interp([t], [0, args.epochs*2//5, args.epochs*4//5, args.epochs],
                                   [0, args.lr_max, args.lr_max/20.0, 0])[0]
 
 opt = optim.AdamW(model.parameters(), lr=args.lr_max, weight_decay=args.wd) #optimizer
-criterion = nn.CrossEntropyLoss() #loss function
+if args.use_cutmix:
+    train_criterion = CutMixCriterion(reduction='mean')
+else:
+    train_criterion = nn.CrossEntropyLoss(reduction='mean')
+test_criterion = nn.CrossEntropyLoss() #loss function
 scaler = torch.cuda.amp.GradScaler() #grad scaler
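
Both CustomCollator and CutMixCriterion come in through the wildcard utils import, and the utils.py hunk is not among the files rendered on this page, so the following is only an illustrative sketch: the names and signatures are taken from the call sites above, while the internals (including the 50/50 MixUp/CutMix coin flip) are assumptions, not the repository's actual implementation.

import numpy as np
import torch
import torch.nn as nn

class CustomCollator:
    """Batch-level CutMix/MixUp: mixes each batch with a shuffled copy of itself."""
    def __init__(self, cutmix_alpha, mixup_alpha, num_classes):
        self.cutmix_alpha = cutmix_alpha
        self.mixup_alpha = mixup_alpha
        self.num_classes = num_classes  # kept for parity with the call site; unused in this sketch

    def __call__(self, batch):
        X = torch.stack([img for img, _ in batch])
        y = torch.tensor([label for _, label in batch])
        perm = torch.randperm(X.size(0))
        if self.mixup_alpha > 0 and np.random.rand() < 0.5:  # assumed 50/50 branch choice
            # MixUp: blend whole images with the shuffled copy
            lam = float(np.random.beta(self.mixup_alpha, self.mixup_alpha))
            X = lam * X + (1 - lam) * X[perm]
        else:
            # CutMix: paste a random rectangle from the shuffled copy
            lam = float(np.random.beta(self.cutmix_alpha, self.cutmix_alpha))
            H, W = X.shape[2], X.shape[3]
            rh, rw = int(H * np.sqrt(1 - lam)), int(W * np.sqrt(1 - lam))
            cy, cx = np.random.randint(H), np.random.randint(W)
            top, bot = np.clip(cy - rh // 2, 0, H), np.clip(cy + rh // 2, 0, H)
            lft, rgt = np.clip(cx - rw // 2, 0, W), np.clip(cx + rw // 2, 0, W)
            X[:, :, top:bot, lft:rgt] = X[perm, :, top:bot, lft:rgt]
            lam = 1 - (bot - top) * (rgt - lft) / (H * W)  # exact pasted-area ratio
        return X, (y, y[perm], lam)

class CutMixCriterion:
    """Cross-entropy against the (y_a, y_b, lam) tuple targets the collator emits."""
    def __init__(self, reduction='mean'):
        self.ce = nn.CrossEntropyLoss(reduction=reduction)

    def __call__(self, output, targets):
        y_a, y_b, lam = targets
        return lam * self.ce(output, y_a) + (1 - lam) * self.ce(output, y_b)

One consequence of this target format: when the collator is active, each batch's targets are a (y_a, y_b, lam) tuple rather than a single tensor, so the loop's X, y = X.cuda(), y.cuda() step would need to move the two target tensors individually; the val/test loaders are untouched and keep integer labels, which is why test_criterion remains a plain nn.CrossEntropyLoss.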


@@ -144,6 +171,12 @@
     #Go through training steps
     for i, (X, y) in enumerate(trainloader):
 
+        print(type(X), type(y))
+        print(X)
+        print(y)
+        print(X.shape,y.shape)
+
+
         #Set train mode and port sample to cuda
         model.train()
         X, y = X.cuda(), y.cuda()
@@ -156,7 +189,7 @@
         #FP and compute loss with amp
         with torch.cuda.amp.autocast():
             output = model(X)
-            loss = criterion(output, y)
+            loss = train_criterion(output, y)
 
         #Scale gradient and clip norm
         scaler.scale(loss).backward()
@@ -182,7 +215,7 @@
             #FP and compute result and log
             with torch.cuda.amp.autocast():
                 output = model(X)
-                loss = criterion(output,y)
+                loss = test_criterion(output,y)
 
             val_loss += loss.item() * y.size(0)
             val_acc += (output.max(1)[1] == y).sum().item()
@@ -215,7 +248,7 @@
         #FP and compute result and log
         with torch.cuda.amp.autocast():
             output = model(X)
-            loss = criterion(output,y)
+            loss = test_criterion(output,y)
 
         test_loss += loss.item() * y.size(0)
         test_acc += (output.max(1)[1] == y).sum().item()
@@ -252,4 +285,4 @@
 torch.save({
     'model_state_dict': model.state_dict(),
     'optimizer_state_dict': opt.state_dict(),
-}, os.path.join(dst,'{}.pkl'.format(args.name)))
\ No newline at end of file
+}, os.path.join(dst,'{}.pkl'.format(args.name)))
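
With this diff applied, CutMix training is opted into per run, e.g. python train.py --name cm_cutmix --use_cutmix --cutmix_alpha 1.0 (run name illustrative). Note that in the hunks shown, only --use_cutmix gates the collator and criterion; --mixup_alpha is forwarded to the collator, while --use_mixup is parsed but not otherwise consulted.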
File renamed without changes.
255 changes: 255 additions & 0 deletions train_old.py
@@ -0,0 +1,255 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import time
import random
import argparse
import json
import os

from networks import ConvMixer, ConvMixerXL

###========================================================================
#Setup seeds
seed = 1204
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

###========================================================================
#Setup args
parser = argparse.ArgumentParser()

parser.add_argument('--name', type=str, default="ConvMixer")

parser.add_argument('--model', default='CM', choices=['CM','CM-XL'])
parser.add_argument('--skip_period', default=3,
                    help='Denominator in extra skip connection periodicity computation; only used in ConvMixer-XL',
                    type=int)
parser.add_argument('--activation', default='GELU', choices=['GELU','ReLU','SiLU'])
parser.add_argument('--batch-size', default=64, type=int)
parser.add_argument('--scale', default=0.75, type=float)
parser.add_argument('--reprob', default=0.2, type=float)
parser.add_argument('--ra-m', default=12, type=int)
parser.add_argument('--ra-n', default=2, type=int)
parser.add_argument('--jitter', default=0.2, type=float)
parser.add_argument('--no_aug',action='store_true',help="Enable flag to remove augmentations")

parser.add_argument('--hdim', default=256, type=int)
parser.add_argument('--depth', default=8, type=int)
parser.add_argument('--psize', default=2, type=int)
parser.add_argument('--conv-ks', default=5, type=int)

parser.add_argument('--wd', default=0.01, type=float)
parser.add_argument('--clip-norm', action='store_true')
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--lr-max', default=0.005, type=float)
parser.add_argument('--workers', default=8, type=int)

parser.add_argument('--save_dir',default='./',help='Directory to save outputs to')

args = parser.parse_args()

#Check dir exists; if not, create
if not os.path.isdir(args.save_dir):
    os.makedirs(args.save_dir)
dst = os.path.join(os.getcwd(),args.save_dir)

#Save args
with open(os.path.join(dst,'args_{}.txt'.format(args.name)), 'w') as f:
    json.dump(args.__dict__, f, indent=4)


###========================================================================
#Setup dataset
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

#Transforms
if not args.no_aug:
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(32, scale=(args.scale, 1.0), ratio=(1.0, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandAugment(num_ops=args.ra_n, magnitude=args.ra_m),
        transforms.ColorJitter(args.jitter, args.jitter, args.jitter),
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
        transforms.RandomErasing(p=args.reprob)
    ])
else:
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std)])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std)
])

#Dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=train_transform)
testvalset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                          download=True, transform=test_transform)

#Split test-val set
ln = len(testvalset)
valset,testset = torch.utils.data.random_split(testvalset,[ln//2,ln//2])

#Dataloaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                          shuffle=True, num_workers=args.workers)
valloader = torch.utils.data.DataLoader(valset, batch_size=args.batch_size,
                                        shuffle=False, num_workers=args.workers)
testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size,
                                         shuffle=False, num_workers=args.workers)


###========================================================================
#Setup model, optim and scheduler
if args.model == 'CM':
    model = ConvMixer(args.hdim, args.depth, patch_size=args.psize, kernel_size=args.conv_ks, n_classes=10, activation=args.activation)
elif args.model == 'CM-XL':
    model = ConvMixerXL(args.hdim, args.depth, patch_size=args.psize, kernel_size=args.conv_ks, n_classes=10, skip_period=args.skip_period, activation=args.activation)

model = nn.DataParallel(model).cuda()

lr_schedule = lambda t: np.interp([t], [0, args.epochs*2//5, args.epochs*4//5, args.epochs],
                                  [0, args.lr_max, args.lr_max/20.0, 0])[0]
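
For a quick sense of the schedule above, assuming the default epochs=100 and lr_max=0.005: the learning rate warms up linearly to its peak at epoch 40 (2/5 of training), decays to lr_max/20 by epoch 80 (4/5), then anneals to zero. A standalone check:

import numpy as np

lr_max = 0.005
sched = lambda t: np.interp([t], [0, 40, 80, 100], [0, lr_max, lr_max/20.0, 0])[0]
for t in [0, 20, 40, 80, 100]:
    print(t, round(sched(t), 6))  # 0.0, 0.0025, 0.005, 0.00025, 0.0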

opt = optim.AdamW(model.parameters(), lr=args.lr_max, weight_decay=args.wd) #optimizer
criterion = nn.CrossEntropyLoss() #loss function
scaler = torch.cuda.amp.GradScaler() #grad scaler


###========================================================================
#Training and validation

#Setup vars
train_loss_ls = []
train_acc_ls = []

val_loss_ls = []
val_acc_ls = []


#Train loop
for epoch in range(args.epochs):
    start = time.time()
    train_loss, train_acc, n = 0, 0, 0

    #Go through training steps
    for i, (X, y) in enumerate(trainloader):

        #Set train mode and port sample to cuda
        model.train()
        X, y = X.cuda(), y.cuda()

        #Step lr scheduler and zero grad
        lr = lr_schedule(epoch + (i + 1)/len(trainloader))
        opt.param_groups[0].update(lr=lr)
        opt.zero_grad()

        #FP and compute loss with amp
        with torch.cuda.amp.autocast():
            output = model(X)
            loss = criterion(output, y)

        #Scale gradient and clip norm
        scaler.scale(loss).backward()
        if args.clip_norm:
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt)
        scaler.update()

        #Compute and log loss and acc
        train_loss += loss.item() * y.size(0)
        train_acc += (output.max(1)[1] == y).sum().item()
        n += y.size(0)
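
Note the ordering in the clip-norm branch above: scaler.unscale_(opt) converts the gradients back to their true scale before nn.utils.clip_grad_norm_ runs, so the 1.0 threshold applies to real gradient norms rather than loss-scaled ones, and scaler.step(opt) knows not to unscale them a second time.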

    #Go through eval steps
    model.eval() #to evaluation mode first
    val_acc, val_loss, m = 0, 0, 0

    with torch.no_grad(): #no grad needed
        for i, (X, y) in enumerate(valloader):
            X, y = X.cuda(), y.cuda() #port to cuda

            #FP and compute result and log
            with torch.cuda.amp.autocast():
                output = model(X)
                loss = criterion(output,y)

            val_loss += loss.item() * y.size(0)
            val_acc += (output.max(1)[1] == y).sum().item()
            m += y.size(0)

    #Log
    train_loss_ls.append({'Epoch':epoch, 'Value': round(train_loss/n,5)})
    val_loss_ls.append({'Epoch':epoch, 'Value': round(val_loss/m,5)})
    train_acc_ls.append({'Epoch':epoch, 'Value': round(train_acc/n,5)})
    val_acc_ls.append({'Epoch':epoch, 'Value': round(val_acc/m,5)})

    print(f'[{args.name}-{args.model}] Epoch: {epoch} | Train Acc: {train_acc/n:.4f}, Test Acc: {val_acc/m:.4f}, Time: {time.time() - start:.1f}, lr: {lr:.6f}')


###========================================================================
#Final test
#Go through with final testing
print("==="*25)
print("[Training Complete. Evaluation with testset initiated.]")


#To evaluation mode first
model.eval()
test_acc, test_loss, m = 0, 0, 0

with torch.no_grad(): #no grad needed
    for i, (X, y) in enumerate(testloader):
        X, y = X.cuda(), y.cuda() #port to cuda

        #FP and compute result and log
        with torch.cuda.amp.autocast():
            output = model(X)
            loss = criterion(output,y)

        test_loss += loss.item() * y.size(0)
        test_acc += (output.max(1)[1] == y).sum().item()
        m += y.size(0)

#Log
test_loss_ls = [{'Value':round(test_loss/m,5)}]
test_acc_ls = [{'Value':round(test_acc/m,5)}]


print(f'[{args.name}-{args.model}] | Test Acc: {test_acc/m:.4f}, Time: {time.time() - start:.1f}')


#Save everything
with open(os.path.join(dst,'train_loss_{}.txt'.format(args.name)), 'w') as f:
    json.dump(train_loss_ls, f, indent=4)

with open(os.path.join(dst,'train_acc_{}.txt'.format(args.name)), 'w') as f:
    json.dump(train_acc_ls, f, indent=4)

with open(os.path.join(dst,'val_loss_{}.txt'.format(args.name)), 'w') as f:
    json.dump(val_loss_ls, f, indent=4)

with open(os.path.join(dst,'val_acc_{}.txt'.format(args.name)), 'w') as f:
    json.dump(val_acc_ls, f, indent=4)

with open(os.path.join(dst,'test_loss_{}.txt'.format(args.name)), 'w') as f:
    json.dump(test_loss_ls, f, indent=4)

with open(os.path.join(dst,'test_acc_{}.txt'.format(args.name)), 'w') as f:
    json.dump(test_acc_ls, f, indent=4)

#Save model
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': opt.state_dict(),
}, os.path.join(dst,'{}.pkl'.format(args.name)))