Skip to content

Commit

Permalink
Add latent code manipulation
Browse files Browse the repository at this point in the history
  • Loading branch information
XifengGuo committed Dec 4, 2017
1 parent 8bb1b93 commit a90bfb4
Show file tree
Hide file tree
Showing 14 changed files with 84 additions and 22 deletions.
30 changes: 26 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,33 @@ python capsulenet.py --is_training 0 --weights result/trained_model.h5
Digits at top 5 rows are real images from MNIST and
digits at bottom are corresponding reconstructed images.

![](real_and_recon.png)
![](result/real_and_recon.png)

**Manipulate latent code:**

```
python capsulenet.py -t --digit 5 -w result/trained_model.h5
```
For each digit, the *i*th row corresponds to the *i*th dimension of the capsule, and columns from left to
right correspond to adding `[-0.25, -0.2, -0.15, -0.1, -0.05, 0, 0.05, 0.1, 0.15, 0.2, 0.25]` to
the value of one dimension of the capsule.

As we can see, each dimension has caught some characteristics of a digit. The same dimension of
different digit capsules may represent different characteristics. This is because that different
digits are reconstructed from different feature vectors (digit capsules). These vectors are mutually
independent during reconstruction.

![](result/manipulate-0.png)
![](result/manipulate-1.png)
![](result/manipulate-2.png)
![](result/manipulate-3.png)
![](result/manipulate-4.png)
![](result/manipulate-5.png)
![](result/manipulate-6.png)
![](result/manipulate-7.png)
![](result/manipulate-8.png)
![](result/manipulate-9.png)

**The model structure:**

![](result/model.png)

## Other Implementations

Expand Down
64 changes: 49 additions & 15 deletions capsulenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from keras import layers, models, optimizers
from keras import backend as K
from keras.utils import to_categorical
import matplotlib.pyplot as plt
from utils import combine_images
from PIL import Image
from capsulelayers import CapsuleLayer, PrimaryCap, Length, Mask

K.set_image_data_format('channels_last')
Expand Down Expand Up @@ -65,7 +68,13 @@ def CapsNet(input_shape, n_class, num_routing):
# Models for training and evaluation (prediction)
train_model = models.Model([x, y], [out_caps, decoder(masked_by_y)])
eval_model = models.Model(x, [out_caps, decoder(masked)])
return train_model, eval_model

# manipulate model
noise = layers.Input(shape=(n_class, 16))
noised_digitcaps = layers.Add()([digitcaps, noise])
masked_noised_y = Mask()([noised_digitcaps, y])
manipulate_model = models.Model([x, y, noise], decoder(masked_noised_y))
return train_model, eval_model, manipulate_model


def margin_loss(y_true, y_pred):
Expand Down Expand Up @@ -138,26 +147,47 @@ def train_generator(x, y, batch_size, shift_fraction=0.):
return model


def test(model, data):
def test(model, data, args):
x_test, y_test = data
y_pred, x_recon = model.predict(x_test, batch_size=100)
print('-'*50)
print('-'*30 + 'Begin: test' + '-'*30)
print('Test acc:', np.sum(np.argmax(y_pred, 1) == np.argmax(y_test, 1))/y_test.shape[0])

import matplotlib.pyplot as plt
from utils import combine_images
from PIL import Image

img = combine_images(np.concatenate([x_test[:50],x_recon[:50]]))
image = img * 255
Image.fromarray(image.astype(np.uint8)).save("real_and_recon.png")
Image.fromarray(image.astype(np.uint8)).save(args.save_dir + "/real_and_recon.png")
print()
print('Reconstructed images are saved to ./real_and_recon.png')
print('-'*50)
plt.imshow(plt.imread("real_and_recon.png", ))
print('Reconstructed images are saved to %s/real_and_recon.png' % args.save_dir)
print('-' * 30 + 'End: test' + '-' * 30)
plt.imshow(plt.imread(args.save_dir + "/real_and_recon.png"))
plt.show()


def manipulate_latent(model, data, args):
print('-'*30 + 'Begin: manipulate' + '-'*30)
x_test, y_test = data
index = np.argmax(y_test, 1) == args.digit
number = np.random.randint(low=0, high=sum(index) - 1)
x, y = x_test[index][number], y_test[index][number]
x, y = np.expand_dims(x, 0), np.expand_dims(y, 0)
noise = np.zeros([1, 10, 16])
x_recons = []
for dim in range(16):
for r in [-0.25, -0.2, -0.15, -0.1, -0.05, 0, 0.05, 0.1, 0.15, 0.2, 0.25]:
tmp = np.copy(noise)
tmp[:,:,dim] = r
x_recon = model.predict([x, y, tmp])
x_recons.append(x_recon)

x_recons = np.concatenate(x_recons)

img = combine_images(x_recons, height=16)
image = img*255
Image.fromarray(image.astype(np.uint8)).save(args.save_dir + '/manipulate-%d.png' % args.digit)
print('manipulated result saved to %s/manipulate-%d.png' % (args.save_dir, args.digit))
print('-' * 30 + 'End: manipulate' + '-' * 30)


def load_mnist():
# the data, shuffled and split between train and test sets
from keras.datasets import mnist
Expand Down Expand Up @@ -195,6 +225,8 @@ def load_mnist():
parser.add_argument('--save_dir', default='./result')
parser.add_argument('-t', '--testing', action='store_true',
help="Test the trained model on testing dataset")
parser.add_argument('--digit', default=5, type=int,
help="Digit to manipulate")
parser.add_argument('-w', '--weights', default=None,
help="The path of the saved weights. Should be specified when testing")
args = parser.parse_args()
Expand All @@ -207,11 +239,12 @@ def load_mnist():
(x_train, y_train), (x_test, y_test) = load_mnist()

# define model
model, eval_model = CapsNet(input_shape=x_train.shape[1:],
n_class=len(np.unique(np.argmax(y_train, 1))),
num_routing=args.routings)
model, eval_model, manipulate_model = CapsNet(input_shape=x_train.shape[1:],
n_class=len(np.unique(np.argmax(y_train, 1))),
num_routing=args.routings)
model.summary()


# train or test
if args.weights is not None: # init the model weights with provided one
model.load_weights(args.weights)
Expand All @@ -220,4 +253,5 @@ def load_mnist():
else: # as long as weights are given, will run testing
if args.weights is None:
print('No weights are provided. Will test using random initialized weights.')
test(model=eval_model, data=(x_test, y_test))
manipulate_latent(manipulate_model, (x_test, y_test), args)
test(model=eval_model, data=(x_test, y_test), args=args)
Binary file added result/manipulate-0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added result/manipulate-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added result/manipulate-2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added result/manipulate-3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added result/manipulate-4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added result/manipulate-5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added result/manipulate-6.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added result/manipulate-7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added result/manipulate-8.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added result/manipulate-9.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
File renamed without changes
12 changes: 9 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,16 @@ def plot_log(filename, show=True):
plt.show()


def combine_images(generated_images):
def combine_images(generated_images, height=None, width=None):
num = generated_images.shape[0]
width = int(math.sqrt(num))
height = int(math.ceil(float(num)/width))
if width is None and height is None:
width = int(math.sqrt(num))
height = int(math.ceil(float(num)/width))
elif width is not None and height is None: # height not given
height = int(math.ceil(float(num)/width))
elif height is not None and width is None: # width not given
width = int(math.ceil(float(num)/height))

shape = generated_images.shape[1:3]
image = np.zeros((height*shape[0], width*shape[1]),
dtype=generated_images.dtype)
Expand Down

0 comments on commit a90bfb4

Please sign in to comment.