Update for latest package versions
albertfgu committed Dec 16, 2020
1 parent 6f2758f commit af46c6d
Showing 4 changed files with 36 additions and 37 deletions.
1 change: 1 addition & 0 deletions cfg/runner/pl.yaml
@@ -1,4 +1,5 @@
runner:
  name: pl
  ntrials: 1
  # Currently we don't set seed when run with pytorch-lightning
  seed:
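As context for the empty `seed:` entry: an empty value in YAML loads as None under OmegaConf, which is how the runner can tell that no seed was configured. A minimal sketch (not part of the commit), assuming the file is read from the repository root:

```python
# Sketch only: load the runner config and inspect the unset seed.
from omegaconf import OmegaConf

cfg = OmegaConf.load("cfg/runner/pl.yaml")
print(cfg.runner.name)   # 'pl'
print(cfg.runner.seed)   # None -> seeding is skipped for pytorch-lightning runs
```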
7 changes: 5 additions & 2 deletions pl_runner.py
@@ -9,26 +9,29 @@ def pl_train(cfg, pl_model_class):
    torch.cuda.manual_seed(cfg.seed)
    model = pl_model_class(cfg.model, cfg.dataset, cfg.train)
    if 'pl' in cfg and 'profile' in cfg.pl and cfg.pl.profile:
        # profiler=pl.profiler.AdvancedProfiler(output_filename=cfg.train.profiler),
        profiler_args = { 'profiler': pl.profiler.AdvancedProfiler(), }
    else:
        profiler_args = {}
    if 'pl' in cfg and 'wandb' in cfg.pl and cfg.pl.wandb:
        # kwargs['logger'] = WandbLogger(name=config['pl_wandb'], project='ops-memory-pl')
        logger = WandbLogger(project='ops-memory-pl')
        logger.log_hyperparams(cfg.model)
        logger.log_hyperparams(cfg.dataset)
        logger.log_hyperparams(cfg.train)
        profiler_args['logger'] = logger
    print("profiler args", profiler_args)
    trainer = pl.Trainer(
        # gpus=1 if config['gpu'] else None,
        gpus=1,
        gradient_clip_val=cfg.train.gradient_clip_val,
        max_epochs=1 if cfg.smoke_test else cfg.train.epochs,
        early_stop_callback=False, progress_bar_refresh_rate=1,
        progress_bar_refresh_rate=1,
        limit_train_batches=cfg.train.limit_train_batches,
        track_grad_norm=2,
        **profiler_args,
    )

    trainer.fit(model)
    trainer.test(model)
    # trainer.test(model)
    return trainer, model
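The removed `early_stop_callback=False` argument tracks an API change: newer pytorch-lightning releases drop that Trainer flag in favor of an explicit `EarlyStopping` callback. A hedged sketch of how early stopping would be re-enabled under the new API if it were wanted (the monitored key `val_loss` is an assumption based on the metrics logged in train.py):

```python
# Sketch only, not part of this commit: early stopping via the callback API
# that replaces the old early_stop_callback Trainer argument.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping

trainer = pl.Trainer(
    gpus=1,
    max_epochs=200,
    callbacks=[EarlyStopping(monitor='val_loss', patience=10)],
)
```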
59 changes: 24 additions & 35 deletions train.py
@@ -26,58 +26,44 @@ def __init__(self, model_args, dataset_cfg, train_args):
        self.dataset = DatasetBase.registry[dataset_cfg.name](dataset_cfg)
        self.train_args = train_args
        self.model_args = model_args
        # self.model_args.cell_args.max_length = self.dataset.N # TODO fix datasets
        # cell_args = model_args.cell_args
        # other_args = {k: v for k, v in model_args.items() if k not in ['cell', 'cell_args', 'dropout']}
        self.model = Model(
            self.dataset.input_size,
            self.dataset.output_size,
            # model_args.cell,
            # cell_args=cell_args,
            output_len=self.dataset.output_len,
            # dropout=model_args.dropout,
            # max_length=self.dataset.N,
            **model_args,
        )

    def forward(self, input):
        return self.model.forward(input)

    def training_step(self, batch, batch_idx):
    def _shared_step(self, batch, batch_idx, prefix='train'):
        batch_x, batch_y, *len_batch = batch
        # Either fixed length sequence or variable length
        len_batch = None if not len_batch else len_batch[0]
        out = self.model(batch_x, len_batch)
        loss = self.dataset.loss(out, batch_y, len_batch)
        metrics = self.dataset.metrics(out, batch_y)
        return {'loss': loss, 'size': batch_x.shape[0], 'out': out, 'target': batch_y,
                'progress_bar': metrics, 'log': metrics}

    def training_epoch_end(self, outputs, prefix='train'):
        losses = torch.stack([output['loss'] for output in outputs])
        sizes = torch.tensor([output['size'] for output in outputs], device=losses.device)
        loss_mean = (losses * sizes).sum() / sizes.sum()
        outs = [output['out'] for output in outputs]
        targets = [output['target'] for output in outputs]
        metrics = self.dataset.metrics_epoch(outs, targets)
        metrics = {f'{prefix}_{k}': v for k, v in metrics.items()}
        results = {f'{prefix}_loss': loss_mean, **metrics}
        results_scalar = {k: to_scalar(v) for k, v in results.items()} # PL prefers torch.Tensor while we prefer float
        setattr(self, f'_{prefix}_results', results_scalar)
        if getattr(self.train_args, 'verbose', False):
            print(f'{prefix} set results:', results_scalar)
        return {f'{prefix}_loss': loss_mean, 'log': results}

    def validation_step(self, batch, batch_idx):
        batch_x, batch_y, *len_batch = batch
        # Either fixed length sequence or variable length
        len_batch = None if not len_batch else len_batch[0]
        out = self.model(batch_x, len_batch)
        loss = self.dataset.loss(out, batch_y, len_batch)
        metrics = self.dataset.metrics(out, batch_y)
        return {'size': batch_x.shape[0], 'loss': loss, 'out': out, 'target': batch_y, **metrics}
        self.log(f'{prefix}_loss', loss, on_epoch=True, prog_bar=False)
        self.log_dict(metrics, on_epoch=True, prog_bar=True)
        return loss

    def validation_epoch_end(self, outputs):
        return self.training_epoch_end(outputs, prefix='val')
    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, prefix='train')

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        return (self._shared_step(batch, batch_idx, prefix='val') if dataloader_idx == 0 else
                self._shared_step(batch, batch_idx, prefix='test'))

    def test_epoch_end(self, outputs):
        return self.training_epoch_end(outputs, prefix='test')
    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, prefix='test')

    def configure_optimizers(self):
        name_to_opt = {'adam': torch.optim.Adam, 'rmsprop': torch.optim.RMSprop}
@@ -86,6 +72,7 @@ def configure_optimizers(self):
            non_orth_params, log_orth_params = get_parameters(self.model)
            return optimizer([
                {'params': non_orth_params, 'lr': self.train_args.lr, 'weight_decay': self.train_args.wd},
                # {'params': log_orth_params, 'lr': self.train_args.lr_orth},
                {'params': log_orth_params, 'lr': self.train_args.lr/10.0},
            ])
        else:
@@ -100,15 +87,17 @@ def train_dataloader(self):
        return self.dataset.train_loader

    def val_dataloader(self):
        return self.dataset.val_loader
        return [self.dataset.val_loader, self.dataset.test_loader]

    def test_dataloader(self):
        return self.dataset.test_loader


@hydra.main(config_path="cfg/config.yaml", strict=False)
@hydra.main(config_path="cfg", config_name="config.yaml")
def main(cfg: OmegaConf):
print(cfg.pretty())
# We want to add fields to cfg so need to call OmegaConf.set_struct
OmegaConf.set_struct(cfg, False)
print(OmegaConf.to_yaml(cfg))
if cfg.runner.name == 'pl':
from pl_runner import pl_train
trainer, model = pl_train(cfg, RNNTraining)
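The decorator change follows Hydra 1.0's API: `config_path` now names the config directory while `config_name` names the file inside it, and `cfg.pretty()` gives way to `OmegaConf.to_yaml(cfg)`; `OmegaConf.set_struct(cfg, False)` keeps the config open so new fields can still be added at runtime. A minimal sketch of the same entry-point pattern as a hypothetical standalone script (not this repository's actual config):

```python
# Sketch of the Hydra 1.0 entry point adopted above.
# Assumes a cfg/ directory containing config.yaml next to this script.
import hydra
from omegaconf import OmegaConf

@hydra.main(config_path="cfg", config_name="config.yaml")
def main(cfg):
    OmegaConf.set_struct(cfg, False)   # allow adding fields not declared in the config
    print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
    main()
```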
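Stepping back to the LightningModule changes above: the per-step dict returns and the manual averaging in `training_epoch_end` are replaced by `self.log` / `self.log_dict` with `on_epoch=True`, which lets Lightning accumulate epoch-level metrics itself, and returning a list from `val_dataloader` makes Lightning pass a `dataloader_idx` into `validation_step`, which is how the val and test loaders get separate prefixes. A self-contained sketch of both patterns with toy data (illustration only, not this repository's model):

```python
# Sketch: self.log with on_epoch=True plus two validation dataloaders
# dispatched on dataloader_idx, mirroring the refactor in train.py.
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(4, 1)

    def _shared_step(self, batch, prefix):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.net(x), y)
        # Logged per step; Lightning also aggregates an epoch-level value.
        self.log(f'{prefix}_loss', loss, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        # Loader 0 is treated as validation, loader 1 as test.
        return self._shared_step(batch, 'val' if dataloader_idx == 0 else 'test')

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8)

    def val_dataloader(self):
        return [
            DataLoader(TensorDataset(torch.randn(16, 4), torch.randn(16, 1)), batch_size=8),
            DataLoader(TensorDataset(torch.randn(16, 4), torch.randn(16, 1)), batch_size=8),
        ]

# Example invocation: pl.Trainer(max_epochs=1).fit(ToyModule())
```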
6 changes: 6 additions & 0 deletions utils.py
@@ -4,6 +4,12 @@
from munch import Munch


def remove_postfix(text, postfix):
    if text.endswith(postfix):
        return text[:-len(postfix)]
    return text


# pytorch-lightning returns pytorch 0-dim tensor instead of python scalar
def to_scalar(x):
    return x.item() if isinstance(x, torch.Tensor) else x
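Hypothetical usage of the two helpers (illustration only; the `_epoch` suffix is an assumed example, not taken from this repository):

```python
# Illustrative only: trim a suffix from a metric name and unwrap a 0-dim tensor.
import torch

print(remove_postfix('val_loss_epoch', '_epoch'))  # -> 'val_loss'
print(remove_postfix('train_acc', '_epoch'))       # -> 'train_acc' (unchanged)
print(to_scalar(torch.tensor(0.25)))               # -> 0.25
print(to_scalar(0.25))                             # -> 0.25 (already a plain float)
```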
