train_ddp.py
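"""Single-node, multi-GPU training entry point using DistributedDataParallel.

One worker process is spawned per GPU via torch.multiprocessing.spawn; each
worker joins the process group, builds its own data loaders, wraps the model
in DDP, and hands training off to the shared Trainer. The config keys and
helper modules (parse_config, trainer, data_loader, model) follow the
PyTorch-template-style layout this script imports from.
"""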
import argparse
import collections
import random

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torchinfo import summary

import data_loader.data_loaders as module_data
import model as module_arch
import model.loss as module_loss
import model.metric as module_metric
from parse_config import ConfigParser
from trainer import Trainer
from utils import get_logger

def main(config):
    logger = get_logger(name=__name__, log_dir=config.log_dir,
                        verbosity=config['trainer']['verbosity'])

    # fill in distributed defaults when they are not set in the config
    if config['n_gpu'] == -1:
        config.config['n_gpu'] = torch.cuda.device_count()
    if config['dist_backend'] is None:
        config.config['dist_backend'] = 'nccl'
    if config['dist_url'] is None:
        config.config['dist_url'] = 'tcp://127.0.0.1:34567'

    torch.backends.cudnn.benchmark = True
    if config['seed'] is not None:
        torch.manual_seed(config['seed'])
        torch.backends.cudnn.deterministic = True
        np.random.seed(config['seed'])
        random.seed(config['seed'])
        logger.warning('You seeded the training. '
                       'This turns on the CUDNN deterministic setting, '
                       'which can slow down your training. '
                       'You may see unexpected behavior when restarting '
                       'from checkpoints.')

    # spawn one worker process per GPU; each receives its GPU index as the
    # first positional argument of main_worker
    mp.spawn(main_worker, nprocs=config['n_gpu'], args=(config,))

def main_worker(gpu, config):
    logger = get_logger('Worker{}'.format(gpu),
                        log_dir=config.log_dir,
                        verbosity=config['trainer']['verbosity'])
    logger.info('Using GPU: {} for training'.format(gpu))

    # Rank here is the process rank among all processes on all nodes;
    # this needs modification in the multi-node case.
    dist.init_process_group(
        backend=config['dist_backend'], init_method=config['dist_url'],
        world_size=config['n_gpu'], rank=gpu)

    # setup data_loader instances; the configured batch size and worker count
    # are global, so split them evenly across the worker processes
    # (needs modification to support multi-node training)
    config.config['data_loader']['args']['batch_size'] //= config['n_gpu']
    config.config['data_loader']['args']['num_workers'] //= config['n_gpu']
    data_loader_obj = config.init_obj('data_loader', module_data)
    data_loader = data_loader_obj.get_train_loader()
    valid_data_loader = data_loader_obj.get_valid_loader()

    # build model architecture, then print to console
    model = config.init_obj('arch', module_arch)

    # prepare for (multi-device) GPU training: bind this process to its GPU
    # and wrap the model in DistributedDataParallel
    torch.cuda.set_device(gpu)
    model.cuda(gpu)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[gpu], find_unused_parameters=True)
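    # log the experiment name, model summary, and parameter count from the
    # rank-0 process only, to avoid duplicated output across workers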
    if gpu == 0:
        logger.info(config['name'])
        trainable_params = filter(
            lambda p: p.requires_grad, model.parameters())
        logger.info(summary(model,
                            input_size=[config['data_loader']['args']['batch_size']]
                            + config['input_size'],
                            verbose=0))
        logger.info('Trainable parameters: {}'.format(
            sum([p.numel() for p in trainable_params])))

    # get function handles of loss and metrics
    criterion = getattr(module_loss, config['loss'])
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer, learning rate scheduler
    optimizer = config.init_obj('optimizer', torch.optim, model.parameters())
    lr_scheduler = config.init_obj(
        'lr_scheduler', torch.optim.lr_scheduler, optimizer)

    trainer = Trainer(model, criterion, metrics, optimizer,
                      config=config,
                      device=gpu,
                      data_loader=data_loader,
                      valid_data_loader=valid_data_loader,
                      lr_scheduler=lr_scheduler,
                      train_sampler=data_loader_obj.train_sampler)
    trainer.train()

if __name__ == '__main__':
    args = argparse.ArgumentParser(description='PyTorch Template')
    args.add_argument('-c', '--config', default=None, type=str,
                      help='config file path (default: None)')
    args.add_argument('-r', '--resume', default=None, type=str,
                      help='path to latest checkpoint (default: None)')
    args.add_argument('-d', '--device', default=None, type=str,
                      help='indices of GPUs to enable (default: all)')

    # custom cli options to modify configuration from default values given in json file
    CustomArgs = collections.namedtuple('CustomArgs', 'flags type target help')
    options = [
        CustomArgs(['--lr', '--learning_rate'], type=float,
                   target='optimizer;args;lr', help=""),
        CustomArgs(['--bs', '--batch_size'], type=int,
                   target='data_loader;args;batch_size', help="")
    ]
    config = ConfigParser.from_args(args, options)
    main(config)
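
# Example invocation (the config filename and option values below are only
# illustrative; any JSON config providing the keys read above should work):
#   python train_ddp.py -c config.json
#   python train_ddp.py -c config.json --lr 1e-4 --bs 64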