From a8c4c193d8f83a0196e5dc6472d28d26e7704fb9 Mon Sep 17 00:00:00 2001 From: Hao Date: Fri, 1 Jun 2018 21:56:47 +1200 Subject: [PATCH] use tensorboardx to show graphs --- prune_alexnet.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/prune_alexnet.py b/prune_alexnet.py index 32687e6e..8eee8f07 100644 --- a/prune_alexnet.py +++ b/prune_alexnet.py @@ -6,6 +6,7 @@ import argparse import logging import sys +from tensorboardX import SummaryWriter from vision.prunning.prunner import ModelPrunner from vision.utils.misc import str2bool @@ -191,6 +192,7 @@ def make_prunner_loader(dataset): train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=len(val_dataset), shuffle=False, num_workers=1) + writer = SummaryWriter() if args.train: logging.info("Start training.") train(net, train_loader, val_loader, args.num_epochs, args.learning_rate) @@ -225,7 +227,11 @@ def make_prunner_loader(dataset): val_loss, val_accuracy = eval(prunner.model, val_loader) logging.info(f"Prune: {i}/{prune_num}, After Pruning Evaluation Accuracy:{val_accuracy:.4f}.") val_loss, val_accuracy = train_epoch(prunner.model, train_data_iter, args.num_recovery_batches, optimizer) + for name, param in net.named_parameters(): + writer.add_histogram(name, param.clone().cpu().data.numpy(), 10) if iteration % 10 == 0: + dummy_input = torch.rand(1, 3, 224, 224) + writer.add_graph(net, dummy_input) val_loss, val_accuracy = eval(prunner.model, val_loader) logging.info(f"Prune: {i}/{prune_num}, After Recovery Evaluation Accuracy:{val_accuracy:.4f}.") logging.info(f"Prune: {i}/{prune_num}, Iteration: {iteration}, Save model.") @@ -235,3 +241,5 @@ def make_prunner_loader(dataset): iteration += 1 else: logging.fatal("You should specify --prune_conv, --prune_linear or --train.") + + writer.close()