-
Notifications
You must be signed in to change notification settings - Fork 29
/
mnist.py
115 lines (93 loc) · 4.42 KB
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import argparse
import numpy as np
from DCN import DCN
from torchvision import datasets, transforms
from sklearn.metrics import adjusted_rand_score
from sklearn.metrics import normalized_mutual_info_score
def evaluate(model, test_loader):
    """Score the model's latent-space clustering on a held-out loader.

    Each batch is flattened, encoded to its latent representation by the
    autoencoder, and assigned to a cluster by the model's k-means
    component; assignments are compared against the true labels.

    Args:
        model: DCN instance exposing ``autoencoder``, ``kmeans`` and
            ``device`` (project type — see DCN.py).
        test_loader: iterable of ``(data, target)`` batches.

    Returns:
        tuple: ``(NMI, ARI)`` computed over all batches.
    """
    y_test = []
    y_pred = []
    # Evaluation needs no gradients: without no_grad() the forward pass
    # still builds the autograd graph for every batch (detach() alone
    # only cuts the graph after the fact), wasting time and memory.
    with torch.no_grad():
        for data, target in test_loader:
            batch_size = data.size()[0]
            data = data.view(batch_size, -1).to(model.device)
            latent_X = model.autoencoder(data, latent=True)
            latent_X = latent_X.detach().cpu().numpy()
            y_test.append(target.view(-1, 1).numpy())
            y_pred.append(
                model.kmeans.update_assign(latent_X).reshape(-1, 1))
    y_test = np.vstack(y_test).reshape(-1)
    y_pred = np.vstack(y_pred).reshape(-1)
    return (normalized_mutual_info_score(y_test, y_pred),
            adjusted_rand_score(y_test, y_pred))
def solver(args, model, train_loader, test_loader):
    """Pre-train the autoencoder, then alternate DCN training epochs with
    evaluation on the held-out loader.

    Returns:
        tuple: ``(rec_loss_list, nmi_list, ari_list)`` — pre-training
        reconstruction losses and per-epoch NMI/ARI histories.
    """
    rec_loss_list = model.pretrain(train_loader, args.pre_epoch)

    nmi_list, ari_list = [], []
    for epoch in range(args.epoch):
        # One joint reconstruction + clustering training pass.
        model.train()
        model.fit(epoch, train_loader)

        # Score the current latent clustering on the test loader.
        model.eval()
        NMI, ARI = evaluate(model, test_loader)
        nmi_list.append(NMI)
        ari_list.append(ARI)

        print('\nEpoch: {:02d} | NMI: {:.3f} | ARI: {:.3f}\n'.format(
            epoch, NMI, ARI))

    return rec_loss_list, nmi_list, ari_list
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Deep Clustering Network')
# Dataset parameters
parser.add_argument('--dir', default='../Dataset/mnist',
help='dataset directory')
parser.add_argument('--input-dim', type=int, default=28*28,
help='input dimension')
parser.add_argument('--n-classes', type=int, default=10,
help='output dimension')
# Training parameters
parser.add_argument('--lr', type=float, default=1e-4,
help='learning rate (default: 1e-4)')
parser.add_argument('--wd', type=float, default=5e-4,
help='weight decay (default: 5e-4)')
parser.add_argument('--batch-size', type=int, default=256,
help='input batch size for training')
parser.add_argument('--epoch', type=int, default=100,
help='number of epochs to train')
parser.add_argument('--pre-epoch', type=int, default=50,
help='number of pre-train epochs')
parser.add_argument('--pretrain', type=bool, default=True,
help='whether use pre-training')
# Model parameters
parser.add_argument('--lamda', type=float, default=1,
help='coefficient of the reconstruction loss')
parser.add_argument('--beta', type=float, default=1,
help=('coefficient of the regularization term on '
'clustering'))
parser.add_argument('--hidden-dims', default=[500, 500, 2000],
help='learning rate (default: 1e-4)')
parser.add_argument('--latent_dim', type=int, default=10,
help='latent space dimension')
parser.add_argument('--n-clusters', type=int, default=10,
help='number of clusters in the latent space')
# Utility parameters
parser.add_argument('--n-jobs', type=int, default=1,
help='number of jobs to run in parallel')
parser.add_argument('--cuda', type=bool, default=True,
help='whether to use GPU')
parser.add_argument('--log-interval', type=int, default=100,
help=('how many batches to wait before logging the '
'training status'))
args = parser.parse_args()
# Load data
transformer = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,),
(0.3081,))])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(args.dir, train=True, download=True,
transform=transformer),
batch_size=args.batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(args.dir, train=False, transform=transformer),
batch_size=args.batch_size, shuffle=True)
# Main body
model = DCN(args)
rec_loss_list, nmi_list, ari_list = solver(
args, model, train_loader, test_loader)