diff --git a/cfg/runner/pl.yaml b/cfg/runner/pl.yaml
index fd2aadc..15120ad 100644
--- a/cfg/runner/pl.yaml
+++ b/cfg/runner/pl.yaml
@@ -1,4 +1,5 @@
 runner:
   name: pl
+  ntrials: 1
 # Currently we don't set seed when run with pytorch-lightning
 seed:
diff --git a/pl_runner.py b/pl_runner.py
index 3f0ed45..58898a1 100644
--- a/pl_runner.py
+++ b/pl_runner.py
@@ -9,10 +9,12 @@ def pl_train(cfg, pl_model_class):
         torch.cuda.manual_seed(cfg.seed)
     model = pl_model_class(cfg.model, cfg.dataset, cfg.train)
     if 'pl' in cfg and 'profile' in cfg.pl and cfg.pl.profile:
+        # profiler=pl.profiler.AdvancedProfiler(output_filename=cfg.train.profiler),
        profiler_args = { 'profiler': pl.profiler.AdvancedProfiler(), }
     else:
        profiler_args = {}
     if 'pl' in cfg and 'wandb' in cfg.pl and cfg.pl.wandb:
+        # kwargs['logger'] = WandbLogger(name=config['pl_wandb'], project='ops-memory-pl')
        logger = WandbLogger(project='ops-memory-pl')
        logger.log_hyperparams(cfg.model)
        logger.log_hyperparams(cfg.dataset)
@@ -20,15 +22,16 @@ def pl_train(cfg, pl_model_class):
         profiler_args['logger'] = logger
     print("profiler args", profiler_args)
     trainer = pl.Trainer(
+        # gpus=1 if config['gpu'] else None,
         gpus=1,
         gradient_clip_val=cfg.train.gradient_clip_val,
         max_epochs=1 if cfg.smoke_test else cfg.train.epochs,
-        early_stop_callback=False, progress_bar_refresh_rate=1,
+        progress_bar_refresh_rate=1,
         limit_train_batches=cfg.train.limit_train_batches,
         track_grad_norm=2,
         **profiler_args,
     )
     trainer.fit(model)
-    trainer.test(model)
+    # trainer.test(model)
     return trainer, model
 
diff --git a/train.py b/train.py
index 58c8ad2..16735de 100644
--- a/train.py
+++ b/train.py
@@ -26,58 +26,44 @@ def __init__(self, model_args, dataset_cfg, train_args):
         self.dataset = DatasetBase.registry[dataset_cfg.name](dataset_cfg)
         self.train_args = train_args
         self.model_args = model_args
+        # self.model_args.cell_args.max_length = self.dataset.N # TODO fix datasets
+        # cell_args = model_args.cell_args
+        # other_args = {k: v for k, v in model_args.items() if k not in ['cell', 'cell_args', 'dropout']}
         self.model = Model(
             self.dataset.input_size,
             self.dataset.output_size,
+            # model_args.cell,
+            # cell_args=cell_args,
             output_len=self.dataset.output_len,
+            # dropout=model_args.dropout,
+            # max_length=self.dataset.N,
             **model_args,
         )
 
     def forward(self, input):
         self.model.forward(input)
 
-    def training_step(self, batch, batch_idx):
+    def _shared_step(self, batch, batch_idx, prefix='train'):
         batch_x, batch_y, *len_batch = batch
         # Either fixed length sequence or variable length
         len_batch = None if not len_batch else len_batch[0]
         out = self.model(batch_x, len_batch)
         loss = self.dataset.loss(out, batch_y, len_batch)
         metrics = self.dataset.metrics(out, batch_y)
-        return {'loss': loss, 'size': batch_x.shape[0], 'out': out, 'target': batch_y,
-                'progress_bar': metrics, 'log': metrics}
-
-    def training_epoch_end(self, outputs, prefix='train'):
-        losses = torch.stack([output['loss'] for output in outputs])
-        sizes = torch.tensor([output['size'] for output in outputs], device=losses.device)
-        loss_mean = (losses * sizes).sum() / sizes.sum()
-        outs = [output['out'] for output in outputs]
-        targets = [output['target'] for output in outputs]
-        metrics = self.dataset.metrics_epoch(outs, targets)
         metrics = {f'{prefix}_{k}': v for k, v in metrics.items()}
-        results = {f'{prefix}_loss': loss_mean, **metrics}
-        results_scalar = {k: to_scalar(v) for k, v in results.items()} # PL prefers torch.Tensor while we prefer float
-        setattr(self, f'_{prefix}_results', results_scalar)
-        if getattr(self.train_args, 'verbose', False):
-            print(f'{prefix} set results:', results_scalar)
-        return {f'{prefix}_loss': loss_mean, 'log': results}
-
-    def validation_step(self, batch, batch_idx):
-        batch_x, batch_y, *len_batch = batch
-        # Either fixed length sequence or variable length
-        len_batch = None if not len_batch else len_batch[0]
-        out = self.model(batch_x, len_batch)
-        loss = self.dataset.loss(out, batch_y, len_batch)
-        metrics = self.dataset.metrics(out, batch_y)
-        return {'size': batch_x.shape[0], 'loss': loss, 'out': out, 'target': batch_y, **metrics}
+        self.log(f'{prefix}_loss', loss, on_epoch=True, prog_bar=False)
+        self.log_dict(metrics, on_epoch=True, prog_bar=True)
+        return loss
 
-    def validation_epoch_end(self, outputs):
-        return self.training_epoch_end(outputs, prefix='val')
+    def training_step(self, batch, batch_idx):
+        return self._shared_step(batch, batch_idx, prefix='train')
 
-    def test_step(self, batch, batch_idx):
-        return self.validation_step(batch, batch_idx)
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
+        return (self._shared_step(batch, batch_idx, prefix='val') if dataloader_idx == 0 else
+                self._shared_step(batch, batch_idx, prefix='test'))
 
-    def test_epoch_end(self, outputs):
-        return self.training_epoch_end(outputs, prefix='test')
+    def test_step(self, batch, batch_idx):
+        return self._shared_step(batch, batch_idx, prefix='test')
 
     def configure_optimizers(self):
         name_to_opt = {'adam': torch.optim.Adam, 'rmsprop': torch.optim.RMSprop}
@@ -86,6 +72,7 @@ def configure_optimizers(self):
             non_orth_params, log_orth_params = get_parameters(self.model)
             return optimizer([
                 {'params': non_orth_params, 'lr': self.train_args.lr, 'weight_decay': self.train_args.wd},
+                # {'params': log_orth_params, 'lr': self.train_args.lr_orth},
                 {'params': log_orth_params, 'lr': self.train_args.lr/10.0},
             ])
         else:
@@ -100,15 +87,17 @@ def train_dataloader(self):
         return self.dataset.train_loader
 
     def val_dataloader(self):
-        return self.dataset.val_loader
+        return [self.dataset.val_loader, self.dataset.test_loader]
 
     def test_dataloader(self):
         return self.dataset.test_loader
 
 
-@hydra.main(config_path="cfg/config.yaml", strict=False)
+@hydra.main(config_path="cfg", config_name="config.yaml")
 def main(cfg: OmegaConf):
-    print(cfg.pretty())
+    # We want to add fields to cfg so need to call OmegaConf.set_struct
+    OmegaConf.set_struct(cfg, False)
+    print(OmegaConf.to_yaml(cfg))
     if cfg.runner.name == 'pl':
         from pl_runner import pl_train
         trainer, model = pl_train(cfg, RNNTraining)
diff --git a/utils.py b/utils.py
index 93a0298..0059d8b 100644
--- a/utils.py
+++ b/utils.py
@@ -4,6 +4,12 @@
 from munch import Munch
 
 
+def remove_postfix(text, postfix):
+    if text.endswith(postfix):
+        return text[:-len(postfix)]
+    return text
+
+
 # pytorch-lightning returns pytorch 0-dim tensor instead of python scalar
 def to_scalar(x):
     return x.item() if isinstance(x, torch.Tensor) else x
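
For reference, below is a minimal, self-contained sketch (not part of the patch, and not this repo's code) of the train.py pattern the diff adopts: one _shared_step that computes loss/metrics and reports them through self.log / self.log_dict so PyTorch Lightning does the per-epoch aggregation, with val_dataloader returning two loaders so validation_step routes dataloader_idx == 1 to a 'test' prefix. The module name, linear layer, random data, and hyperparameters here are fabricated for illustration; only the Lightning calls mirror the diff, and pytorch-lightning >= 1.0 is assumed.

# sketch.py -- illustrative only; names below are made up, not from this repo
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ToyModule(pl.LightningModule):
    """Stand-in for RNNTraining; the layer and data are fabricated."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(8, 2)

    def _shared_step(self, batch, batch_idx, prefix='train'):
        x, y = batch
        logits = self.net(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(dim=-1) == y).float().mean()
        # on_epoch=True lets Lightning average these over the epoch,
        # replacing hand-written *_epoch_end aggregation.
        self.log(f'{prefix}_loss', loss, on_epoch=True, prog_bar=False)
        self.log_dict({f'{prefix}_acc': acc}, on_epoch=True, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, batch_idx, prefix='train')

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx 0 is the val loader, 1 is the test loader returned below.
        return self._shared_step(batch, batch_idx,
                                 prefix='val' if dataloader_idx == 0 else 'test')

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def _loader(self):
        x, y = torch.randn(64, 8), torch.randint(0, 2, (64,))
        return DataLoader(TensorDataset(x, y), batch_size=16)

    def train_dataloader(self):
        return self._loader()

    def val_dataloader(self):
        # Two loaders -> validation_step is called with dataloader_idx 0 and 1,
        # so the test set is evaluated every epoch without calling trainer.test().
        return [self._loader(), self._loader()]


if __name__ == '__main__':
    pl.Trainer(max_epochs=1).fit(ToyModule())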