# run.py
# Forked from XifengGuo/CapsNet-Keras.
import os

import numpy as np
from keras import callbacks

import capsulenet
class args:
    """Hyper-parameters and run settings handed to ``capsulenet``.

    Used as a plain attribute namespace and never instantiated: the class
    object itself is passed as ``args=args`` to ``capsulenet.train`` and
    ``capsulenet.test``. Attribute names presumably must match what
    ``capsulenet`` reads, so rename with care (TODO confirm against
    capsulenet.py).
    """

    # Passed to capsulenet via args; presumably where logs/plots are written.
    save_dir = "weights/"
    # NOTE(review): likely toggles debug/TensorBoard output inside
    # capsulenet -- confirm.
    debug = True

    # model
    # Also passed directly to capsulenet.CapsNet; presumably the number of
    # dynamic-routing iterations.
    routings = 1

    # hp
    batch_size = 32
    lr = 0.001
    # Presumably a per-epoch multiplicative LR decay (1.0 => constant LR);
    # verify against capsulenet.train.
    lr_decay = 1.0
    # Reconstruction-loss weight. 0.392 == 0.0005 * 784; presumably the
    # CapsNet paper's 0.0005 scaled by the 28x28 pixel count -- verify.
    lam_recon = 0.392

    # training
    epochs = 3
    # Presumably the max fraction of the image randomly shifted for data
    # augmentation -- verify.
    shift_fraction = 0.1
    # NOTE(review): not read by this script; presumably consumed by
    # capsulenet's latent-manipulation routine -- confirm.
    digit = 5

def main():
    """Train a CapsNet on MNIST, evaluate it, and save the trained weights.

    Steps:
      1. Load MNIST via ``capsulenet.load_mnist``.
      2. Build the models with ``capsulenet.CapsNet``.
      3. Train the training model with the settings in ``args``.
      4. Evaluate the evaluation model on the test split.
      5. Write the training model's weights to ``weights.h5`` in the current
         working directory.
    """
    # args.save_dir is handed to capsulenet.train/test, which presumably write
    # files under it; create it up front so a fresh checkout does not fail on
    # a missing directory (TODO confirm against capsulenet.train).
    os.makedirs(args.save_dir, exist_ok=True)

    (x_train, y_train), (x_test, y_test) = capsulenet.load_mnist()

    # y_train is presumably one-hot; argmax along axis 1 recovers integer
    # labels, and the number of distinct labels is the class count.
    n_class = len(np.unique(np.argmax(y_train, 1)))

    # CapsNet returns three models; the third (manipulation model) is not
    # used by this script.
    model, eval_model, _ = capsulenet.CapsNet(
        input_shape=x_train.shape[1:],
        n_class=n_class,
        routings=args.routings,
    )

    capsulenet.train(model=model, data=((x_train, y_train), (x_test, y_test)), args=args)
    capsulenet.test(eval_model, data=(x_test, y_test), args=args)

    # NOTE(review): weights are written to the CWD rather than args.save_dir;
    # kept as-is so anything expecting ./weights.h5 still finds it.
    model.save_weights("weights.h5")


if __name__ == "__main__":
    main()