forked from XifengGuo/CapsNet-Keras
Commit
Showing 4 changed files with 468 additions and 26 deletions.
@@ -0,0 +1,35 @@
import numpy as np
from keras import callbacks

import capsulenet


class args:
    save_dir = "weights/"
    debug = True

    # model
    routings = 1

    # hp
    batch_size = 32
    lr = 0.001
    lr_decay = 1.0
    lam_recon = 0.392

    # training
    epochs = 3
    shift_fraction = 0.1
    digit = 5


(x_train, y_train), (x_test, y_test) = capsulenet.load_mnist()

model, eval_model, manipulate_model = capsulenet.CapsNet(input_shape=x_train.shape[1:],
                                                         n_class=len(np.unique(np.argmax(y_train, 1))),
                                                         routings=args.routings)

capsulenet.train(model=model, data=((x_train, y_train), (x_test, y_test)), args=args)

capsulenet.test(eval_model, data=(x_test, y_test), args=args)

model.save_weights("weights.h5")
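For context, a minimal sketch (not part of this commit) of reloading the weights saved by the training script above for inference. It assumes capsulenet.load_mnist() and capsulenet.CapsNet() behave as used in the script, and that eval_model outputs the capsule lengths followed by the decoder reconstruction, as in the upstream CapsNet-Keras capsulenet.py:

import numpy as np

import capsulenet

(x_train, y_train), (x_test, y_test) = capsulenet.load_mnist()
model, eval_model, _ = capsulenet.CapsNet(input_shape=x_train.shape[1:],
                                          n_class=len(np.unique(np.argmax(y_train, 1))),
                                          routings=1)
# The returned models share layers, so loading the training model's
# weights also updates eval_model.
model.load_weights("weights.h5")
y_pred, x_recon = eval_model.predict(x_test[:10], batch_size=10)
print(np.argmax(y_pred, 1))  # predicted digit for each of the 10 samples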
@@ -0,0 +1,225 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/home/watiz/.virtualenvs/siamese_reid/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from keras import layers, models, optimizers\n",
    "from keras import backend as K\n",
    "from keras.utils import to_categorical\n",
    "from keras.preprocessing.image import ImageDataGenerator\n",
    "from keras import callbacks\n",
    "import matplotlib.pyplot as plt\n",
    "from utils import combine_images\n",
    "from PIL import Image\n",
    "from capsulelayers import CapsuleLayer, PrimaryCap, Length, Mask\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class args:\n",
    "    save_dir = \"weights/\"\n",
    "    debug = True\n",
    "\n",
    "    # model\n",
    "    routings = 1\n",
    "\n",
    "    # hp\n",
    "    batch_size = 32\n",
    "    lr = 0.001\n",
    "    lr_decay = 1.0\n",
    "    lam_recon = 0.392\n",
    "\n",
    "    # training\n",
    "    epochs = 3\n",
    "    shift_fraction = 0.1\n",
    "    digit = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def CapsNet(input_shape, n_class, routings):\n",
    "    x = layers.Input(shape=input_shape)\n",
    "\n",
    "    # Layer 1: Just a conventional Conv2D layer\n",
    "    conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu', name='conv1')(x)\n",
    "\n",
    "    # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_capsule, dim_capsule]\n",
    "    primarycaps = PrimaryCap(conv1, dim_capsule=8, n_channels=32, kernel_size=9, strides=2, padding='valid')\n",
    "\n",
    "    # Layer 3: Capsule layer. Routing algorithm works here.\n",
    "    digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings,\n",
    "                             name='digitcaps')(primarycaps)\n",
    "\n",
    "    # Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape.\n",
    "    # If using tensorflow, this will not be necessary. :)\n",
    "    out_caps = Length(name='capsnet')(digitcaps)\n",
    "\n",
    "    model = models.Model(x, out_caps)\n",
    "\n",
    "    return model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the data, shuffled and split between train and test sets\n",
    "from keras.datasets import mnist\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "\n",
    "x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.\n",
    "x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.\n",
    "y_train = to_categorical(y_train.astype('float32'))\n",
    "y_test = to_categorical(y_test.astype('float32'))\n",
    "\n",
    "x_train = x_train[:1000]\n",
    "y_train = y_train[:1000]\n",
    "x_test = x_test[:1000]\n",
    "y_test = y_test[:1000]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def margin_loss(y_true, y_pred):\n",
    "    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \\\n",
    "        0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))\n",
    "\n",
    "    return K.mean(K.sum(L, 1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = CapsNet(input_shape=x_train.shape[1:], n_class=len(np.unique(np.argmax(y_train, 1))), routings=args.routings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(optimizer=optimizers.Adam(lr=args.lr), loss=[margin_loss])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 1000 samples, validate on 1000 samples\n",
      "Epoch 1/3\n",
      "1000/1000 [==============================] - 20s - loss: 0.4622 - val_loss: 0.2460\n",
      "Epoch 2/3\n",
      "1000/1000 [==============================] - 20s - loss: 0.1756 - val_loss: 0.1639\n",
      "Epoch 3/3\n",
      "1000/1000 [==============================] - 26s - loss: 0.1070 - val_loss: 0.1050\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f57a75547f0>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(x_train, y_train, validation_data=[x_test, y_test], epochs=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAABbCAYAAABEQP/sAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAADpxJREFUeJzt3XmQ3VP6x/H3Q2JIJGhLyBRimRAUU0KV8kPsxNhjH2tRIkEsyc8ee1HWQg0RpGzzKww/SxA/+86PoEjEElMIhiDEZEEEZ/64/eTb9+be7tvdt+85376fV1Wqt+/tfvLtvuc+3+d7znMshICIiMS3ROwARESkQAOyiEgiNCCLiCRCA7KISCI0IIuIJEIDsohIIjQgi4gkIskB2cwGmdkzZvZvM/unme0TO6YUmFmTmT1gZvPNbIaZHRI7ptjMbICZTTKz2WY208z+ZmY9YscVi5n9wcwmNP99zDWzt81saOy4UpCHcSW5Abn5yfQQ8AjQBBwL/N3MBkYNLA3XA78A/YC/AuPMbMO4IUV3A/ANsBrwZ2AIMDJqRHH1AD6ncB6WA84B/mFmAyLGFF1exhVLbaWemW0E/D/QJzQHZ2ZPAK+FEMZGDS4iM+sNzAY2CiFMb/7cncC/QghnRA0uIjN7HxgdQpjU/PEVQN8QwvC4kaXDzKYAF4QQ/jd2LLHkZVxJLkOuwICNYgcR2UDgVx+Mm70DNHqGfA1wkJn1MrM/AkOB/4scUzLMrB+Fv51psWNJUHLjSooD8ocULkH/28x6mtnOFC6/esUNK7plgTkln/s30CdCLCl5gcKL0hzgC+AN4MGoESXCzHoC/wPcHkL4IHY8keViXEluQA4hLAT2Bv4CzARGA/+g8GRrZPOAviWf6wvMjRBLEsxsCQrZ8P1Ab2AlYAXgsphxpaD53NxJ4Z7DCZHDiS4v40pyAzJACGFKCGFICGHFEMIuwNrA67Hjimw60MPM/tTic5vQ2JeiTcAawN9CCAtCCN8BtwK7xQ0rLjMzYAKFm7/DmgejhpeHcSXJAdnMNjazpZvrgmMo3EG/LXJYUYUQ5lPIBC80s95m9l/AXhSyoIYUQpgFfAKMMLMeZrY8cAQwJW5k0Y0DBgF7hBB+ih1MKvIwriQ5IAOHAV9RqPnsAOwUQlgQN6QkjASWoXBe7gJGhBAaOUMG2BfYFfgW+CewEDglakQRmdmawHAKUwBnmtm85n9/jRxaCpIfV5Kb9iYi0qhSzZBFRBqOBmQRkURoQBYRSYQGZBGRRGhAFhFJRLvaFJpZQ0zJCCFYtcc2yjkBZoUQVq7mQJ2T8hrlvOj5U1ZVfyvKkKVaM2IHkCCdE6lWVX8rGpBFRBKhAVlEJBEakEVEEtGwe4+JdEeDBw8G4MADDwRgxIgRAFx00UUA3HzzzQDMnj07QnTSFmXIIiKJaFdzoUaZoqJpO2W9GULYrJoDdU7K68rzst122wFwwQUXALDlllsCMHXqVABWX311AH76qdCNc6+99gLgrbfeqnksev6UVdXfijJkEZFEJFdDLmx2ACussAIAJ598MpC94v/yyy8AjB8/HoCJEycC0J3biG677bZFb88777yyx3l29NxzzxW97Y6WWWYZAE4//XQAevfuDcCuu+4KwAYbbFD2cRdeeCGQnau8W3vttQG45557AGhqagLgyiuvBOCMMwobknsG7cc99NBDAGy88caAasqpUIYsIpKIZGrI/ko9duxYAIYNG1bV46655hoARo8eDdQmU06lBnb++ecDlTPitnhWVKNMOWoN2TPinXbaCYAxY8YA2ZWTX1n573/evHlAlvl5DXXatMIGK5tsskktwopeQ/ZM2K8kH3nkEQD2228/AH799dei4w8++GAA7ryzsPPXmWeeCcAVV1xRs5hSef4kRjVkEZE8iV5D3n///QG46aabAOjbt7DT/eTJk4Eso/Ga8p577glkGdGJJ54IwCWXXALArFmz6hF2l3r22WeBrGbcUZ5Zd4dasmdyZ511VlXHX3fddQCMGzcOgCOPPBKAjz76qPbBRbTyysX9an744Qdg8czYvfnmm0D2PPEr0lpmyKk44YQTADj77LMBWHXVVYHsKuqpp54Cspkofl9hzpw5dY2zJWXIIiKJ0IAsIpKIaCWLU04p7NR+7rnnAtmEdb80vfHGG8s+7vjjjweyy3G/5OoOpQq/idfZUkV3ss022wAwcuTIdj1u1KhRQLbw4dJLL61tYIm4/PLLATjggAMAeP7551s9fvr06UA2/c0v6/0mqd8kzKPVVlsNgNtuuw3IbmovueSSQFaq8Lc77LBD0dt9990XyBbNTJkypQ5RF1OGLCKSiLpPe1t++eUBeP3114Gs0O7T3j799NOqvs8qq6xS9PE333zT2dAWqfe0Hc+I/WZeJf6K39YCEec38/xxnRRl2tvcuXOBbNpbKz8TWHzao195eQb52GOP1So0SGDam/Msz68I2lro4QtK3njjDSC7qe4LSTqj3s+f9ddfH4DXXnsNgGWXXbbo6y+++CKQTQn0qwR3xx13ANCnTx8ga8TkV6w1omlvIiJ5UpcMeYklsnH/vvvuA2DvvfcGsnaAw4cP78i3ZrnllgPgoIMOArL60YIFCzr0/aD+r/CVfgfVZrhtTZOr0QKRumbIQ4cOBWDSpEkA/P7772WP+/HHH4s+7tWrV9njbrjhBiCbJlkjyWTIHfXll18Wfdy/f/9Of896P388M95ss8Kv4osvvgDg2GOPBeCZZ54BYOHChWUf71fXK664IgCvvPIKAFtvvXVnQ2tJGbKISJ7UZZaF3ymHLDP2WREdvfu97rrrAvDEE08AMGDAACB7tXvggQc69H3rpa16MVRf+/XjKmXK/nmvs6bKF/8A3H777UCWGZdeRXiNdJ999gGyJdS+QKiUN2r3jLvGtWSps913333R+5tvvjmweGvRt99+u9XvcfTRRwPZYjT36KOP1izO9lKGLCKSiLpkyOXu3N5yyy1A9bMq3NJLLw1kGbBnxu67775rf4B1VNpKs5yO1nrbypRb3jWu8R3kmvDfLRTfdyjn+uuvB+Cll14CsqX2O+64IwDbb7992cf5vPdGz5A33HBDIKu5X3311THDqZrHe8QRRyz6nF/5+bZVbWXG/n/3+1elvv32207H2VHKkEVEEhFtpV57M1nPhH3OoL/KlYqxuqY9qsmMOztv2Juk5G3F31dffbXofd+IoBKfreN8Vo3fUa+UIa+55pqdCbHb2HTTTYFs7m3pbJVUedOktdZaa9Hn/P7CZ599VnRsjx6F4c3XPviV+qGHHlr0uFKVZvTUgzJkEZFE1CVDLtfOzmdXlGY2pbznxWGHHQbAUkstVfY4f/XzlV2pGjJkSMWvtdWHoFqlWzh5ptzaz84LP0cffPBB2a97bfjiiy8u+3Xf6sm3eHrvvfdqHWLSPCP22SmeJcasm7aHXzl9+OGHiz7n2b433fe/kUGDBgHZ1VKl1ZwpUYYsIpKIuqzUa9l34tVXXwWKa0DV8Izon
XfeAbI7qs4b1/t69c7oypVGrZ3vWs8Tbm0LqA78rC5fqTd48OBF73t2730JvK735JNPAtlmppU8/vjjQDbrwvk8de8M1knRVuoNHDgQWHyWka+8e/fdd8s+bueddway+dgzZswAYJ111qlVaHVZqddyRo7/n33VbiVeJ/f5x08//TSQzbzxznfHHHMMALfeemtHQqtEK/VERPKkLjXklp3YvI45YcIEINu00nlm5PWfjz/+GMg2Z6zUJ9nXs+dRd9hiqRZaZq09e/YEFl+p19YV3RprrAFkNeLS41OuH1bD5996H2Tvv+B8Cyefp+8rWX27otNOO63oeL+yzJuff/550ft77LEHAFdddRWQrdxz3t3Nt4srvXrwFX4pUIYsIpKIus9D9k5M/mq1xRZbAFlNyFdeef9bv/vrd1e996nzvsrff/99V4bdpWo1uyLvWtb//ffZr1+/omO8zly6KtF5P4wa1YiT413wVlppJSDru/D+++8DWQZ91FFHAdnmrl579xk3PrvEa8h59vLLLwPZWOL3DWbOnAlUrqc77xDpVw9trRLtSsqQRUQSEW2lns9N9hpXR91///0A/Pbbb52OqbvJ87xjz+BKM+SmpiYA7r33XgDWW289IFv56bXj7sp/p57FeYY8fvx4IFvJ6jVj35Fnl112AbI+Dz5LZd68efUIu678/95efn9BK/VERCRehtxevsKo5fxDKO6hmwelq+eg9plsNR3lUuc9bcut8oSsP4HXDT1TbGt+ct75jCWf23/SSScB8PDDDwNZ/wafWeDHefbnK1rzsjIvhpj9TpQhi4gkIjcZ8lZbbQVkNTFXuidY6nxGRcvs1d/3lXWd7VVcaTdq7wKXB/PnzwfghRdeAIp3nWlp4sSJQDY/3Tt5VeI7CueV92u47LLLgGzFnu+sXCm784z4k08+6eoQc+/rr7+O9rOVIYuIJCI3GbLvAJuHjk2t8Rpyy7qxZ8iVMttKGXPp49qqGedxReDYsWMBePDBB4GsdlzquOOOAyr/XXjvhnHjxtU6xLryvQZ9NdqwYcOAxTNjn6Xiu2KMGjUKyPo2HH744V0fbE61XAVYb8qQRUQSoQFZRCQRuSlZuLyWKly5skFpqaG0dFGplFGtWrf1rCdfSu/NdMaMGQNkC0Ta4ttC+aV63vnNuUMOOQRYvLmQ81aTvvDDH7fbbrsBjdugvzX+PNHSaRERyV+G3F20zJS9UU5bN/fa+73zNM2tLZ4hf/7550A2/asSz4x9gUlbW8PnjbcKaNnatjV333130VtZnF99V9pAuR6UIYuIJCL3GXKeG9O70k1JSxeIlC6trtSus7MLSvLAm0l5a8VzzjkHyOp/kydPBuCuu+4Cul9mLF2vtMF9PSlDFhFJRF02Oc2bemzSmENdvslpDkXb5DRleXv++KYX06ZNA+Daa68F4NRTT63lj9EmpyIieZL7GrKISGf4xrDeVGjq1KnRYlGGLCKSCGXIItLQfMZO//79I0eiDFlEJBntzZBnAfnfN7x17d2/pRHOCbTvvOiclNcI50XnpLyqzku7pr2JiEjXUclCRCQRGpBFRBKhAVlEJBEakEVEEqEBWUQkERqQRUQSoQFZRCQRGpBFRBKhAVlEJBH/AbPwhv4YTiclAAAAAElFTkSuQmCC\n", | ||
"text/plain": [ | ||
"<matplotlib.figure.Figure at 0x7f57a743f860>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"n_images = 5\n", | ||
"ids = np.random.choice(x_test.shape[0], n_images, replace=False)\n", | ||
"images = x_test[ids]\n", | ||
"id_class = model.predict(images)\n", | ||
"\n", | ||
"for index, image in enumerate(images):\n", | ||
" ax = plt.subplot(1, n_images, index+1)\n", | ||
" plt.imshow(image.reshape(28, 28), cmap=\"gray\")\n", | ||
" ax.get_xaxis().set_visible(False)\n", | ||
" ax.get_yaxis().set_visible(False)\n", | ||
" \n", | ||
" ax.set_title(\"{0}\".format(np.argmax(id_class[index])))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.5.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
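Aside: the margin_loss cell above (repeated in the script below) implements the per-class margin loss of Sabour et al. (2017), "Dynamic Routing Between Capsules", with the paper's constants. In that notation:

L_k = T_k \max(0,\, m^{+} - \lVert \mathbf{v}_k \rVert)^2 + \lambda (1 - T_k) \max(0,\, \lVert \mathbf{v}_k \rVert - m^{-})^2, \qquad m^{+} = 0.9,\; m^{-} = 0.1,\; \lambda = 0.5

where T_k = 1 exactly when class k is present and \lVert \mathbf{v}_k \rVert is the length of digit capsule k (the model's output). K.sum(L, 1) sums the per-class terms and K.mean averages over the batch.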
@@ -0,0 +1,129 @@

# coding: utf-8

# In[1]:


import numpy as np
from keras import layers, models, optimizers
from keras import backend as K
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
from keras import callbacks
import matplotlib.pyplot as plt
from utils import combine_images
from PIL import Image
from capsulelayers import CapsuleLayer, PrimaryCap, Length, Mask

get_ipython().magic(u'matplotlib inline')


# In[2]:


class args:
    save_dir = "weights/"
    debug = True

    # model
    routings = 1

    # hp
    batch_size = 32
    lr = 0.001
    lr_decay = 1.0
    lam_recon = 0.392

    # training
    epochs = 3
    shift_fraction = 0.1
    digit = 5


# In[3]:


def CapsNet(input_shape, n_class, routings):
    x = layers.Input(shape=input_shape)

    # Layer 1: Just a conventional Conv2D layer
    conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu', name='conv1')(x)

    # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_capsule, dim_capsule]
    primarycaps = PrimaryCap(conv1, dim_capsule=8, n_channels=32, kernel_size=9, strides=2, padding='valid')

    # Layer 3: Capsule layer. Routing algorithm works here.
    digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings,
                             name='digitcaps')(primarycaps)

    # Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape.
    # If using tensorflow, this will not be necessary. :)
    out_caps = Length(name='capsnet')(digitcaps)

    model = models.Model(x, out_caps)

    return model


# In[4]:


# the data, shuffled and split between train and test sets
from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.
y_train = to_categorical(y_train.astype('float32'))
y_test = to_categorical(y_test.astype('float32'))

x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]


# In[5]:


def margin_loss(y_true, y_pred):
    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + 0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))

    return K.mean(K.sum(L, 1))


# In[6]:


model = CapsNet(input_shape=x_train.shape[1:], n_class=len(np.unique(np.argmax(y_train, 1))), routings=args.routings)


# In[7]:


model.compile(optimizer=optimizers.Adam(lr=args.lr), loss=[margin_loss])


# In[8]:


model.fit(x_train, y_train, validation_data=[x_test, y_test], epochs=3)


# In[9]:


n_images = 5
ids = np.random.choice(x_test.shape[0], n_images, replace=False)
images = x_test[ids]
id_class = model.predict(images)

for index, image in enumerate(images):
    ax = plt.subplot(1, n_images, index+1)
    plt.imshow(image.reshape(28, 28), cmap="gray")
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    ax.set_title("{0}".format(np.argmax(id_class[index])))
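Note: the `squash` activation mentioned in the PrimaryCap comment lives in capsulelayers.py, which this commit does not touch. For reference, a sketch matching the upstream CapsNet-Keras definition — it rescales a capsule vector s by ||s||^2 / (1 + ||s||^2) / ||s||, so short vectors shrink toward zero and long ones approach (but never reach) unit length:

from keras import backend as K

def squash(vectors, axis=-1):
    # Norm-dependent rescaling: preserves direction, bounds length in [0, 1).
    s_squared_norm = K.sum(K.square(vectors), axis, keepdims=True)
    scale = s_squared_norm / (1 + s_squared_norm) / K.sqrt(s_squared_norm + K.epsilon())
    return scale * vectors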