diff --git a/trainer.py b/trainer.py index 49689d7..a3db0e5 100644 --- a/trainer.py +++ b/trainer.py @@ -50,8 +50,7 @@ ds, batch_size=args.batch_size, shuffle=True, - num_workers=args.num_workers, - pin_memory=True + num_workers=args.num_workers ) feature_extractor = transform(**kwargs).to(device) @@ -92,7 +91,7 @@ with torch.no_grad(): x = x.to(device) x = feature_extractor(x) + 1 - x = x.log() + x = x.log().detach() target = target.to(device) # forward pass @@ -134,4 +133,4 @@ if args.save_checkpoint: save_checkpoint(args.save_path, model, criterion, optimizer, epoch) -save_checkpoint(model, criterion, optimizer, epoch) +save_checkpoint(args.save_path, model, criterion, optimizer, epoch)