Commit 549fbec
added a bunch of scripts
maxjcohen committed Jan 20, 2018
1 parent 7b83957 commit 549fbec
Showing 4 changed files with 468 additions and 26 deletions.
105 changes: 79 additions & 26 deletions run.ipynb

Large diffs are not rendered by default.

35 changes: 35 additions & 0 deletions run.py
@@ -0,0 +1,35 @@
import numpy as np

import capsulenet

# Stand-in for the argparse namespace that capsulenet.train/test expect.
class args:
    save_dir = "weights/"
    debug = True

    # model
    routings = 1

    # hyperparameters
    batch_size = 32
    lr = 0.001
    lr_decay = 1.0
    lam_recon = 0.392

    # training
    epochs = 3
    shift_fraction = 0.1
    digit = 5


# Load MNIST, already reshaped and one-hot encoded by the helper.
(x_train, y_train), (x_test, y_test) = capsulenet.load_mnist()

# Build the three CapsNet variants: training, evaluation, and manipulation.
model, eval_model, manipulate_model = capsulenet.CapsNet(input_shape=x_train.shape[1:],
                                                         n_class=len(np.unique(np.argmax(y_train, 1))),
                                                         routings=args.routings)

capsulenet.train(model=model, data=((x_train, y_train), (x_test, y_test)), args=args)

capsulenet.test(eval_model, data=(x_test, y_test), args=args)

model.save_weights("weights.h5")
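
For reference, the `class args` block above is a stand-in for the argparse namespace that the upstream `capsulenet.train` and `capsulenet.test` helpers expect. A minimal sketch of an equivalent command-line interface, with hypothetical flag names mirroring the attribute names (not part of this commit):

import argparse

# Hypothetical argparse equivalent of the class-based namespace in run.py;
# defaults match the values above.
parser = argparse.ArgumentParser(description="Train a CapsNet on MNIST")
parser.add_argument("--save_dir", default="weights/")
parser.add_argument("--routings", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--lr_decay", type=float, default=1.0)
parser.add_argument("--lam_recon", type=float, default=0.392)
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--shift_fraction", type=float, default=0.1)
parser.add_argument("--digit", type=int, default=5)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()  # drop-in replacement for the class attributes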
225 changes: 225 additions & 0 deletions simple.ipynb
@@ -0,0 +1,225 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n",
"/home/watiz/.virtualenvs/siamese_reid/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n"
]
}
],
"source": [
"import numpy as np\n",
"from keras import layers, models, optimizers\n",
"from keras import backend as K\n",
"from keras.utils import to_categorical\n",
"from keras.preprocessing.image import ImageDataGenerator\n",
"from keras import callbacks\n",
"import matplotlib.pyplot as plt\n",
"from utils import combine_images\n",
"from PIL import Image\n",
"from capsulelayers import CapsuleLayer, PrimaryCap, Length, Mask\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class args:\n",
" save_dir = \"weights/\"\n",
" debug = True\n",
"\n",
" # model\n",
" routings = 1\n",
"\n",
" # hp\n",
" batch_size = 32\n",
" lr = 0.001\n",
" lr_decay = 1.0\n",
" lam_recon = 0.392\n",
"\n",
" # training\n",
" epochs = 3\n",
" shift_fraction = 0.1\n",
" digit = 5"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def CapsNet(input_shape, n_class, routings):\n",
" x = layers.Input(shape=input_shape)\n",
"\n",
" # Layer 1: Just a conventional Conv2D layer\n",
" conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu', name='conv1')(x)\n",
"\n",
" # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_capsule, dim_capsule]\n",
" primarycaps = PrimaryCap(conv1, dim_capsule=8, n_channels=32, kernel_size=9, strides=2, padding='valid')\n",
"\n",
" # Layer 3: Capsule layer. Routing algorithm works here.\n",
" digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings,\n",
" name='digitcaps')(primarycaps)\n",
"\n",
" # Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape.\n",
" # If using tensorflow, this will not be necessary. :)\n",
" out_caps = Length(name='capsnet')(digitcaps)\n",
" \n",
" model = models.Model(x, out_caps)\n",
" \n",
" return model\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# the data, shuffled and split between train and test sets\n",
"from keras.datasets import mnist\n",
"(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
"\n",
"x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.\n",
"x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.\n",
"y_train = to_categorical(y_train.astype('float32'))\n",
"y_test = to_categorical(y_test.astype('float32'))\n",
"\n",
"x_train = x_train[:1000]\n",
"y_train = y_train[:1000]\n",
"x_test = x_test[:1000]\n",
"y_test = y_test[:1000]\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def margin_loss(y_true, y_pred):\n",
" L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \\\n",
" 0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))\n",
"\n",
" return K.mean(K.sum(L, 1))\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"model = CapsNet(input_shape=x_train.shape[1:], n_class=len(np.unique(np.argmax(y_train, 1))), routings=args.routings)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"model.compile(optimizer=optimizers.Adam(lr=args.lr), loss=[margin_loss])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 1000 samples, validate on 1000 samples\n",
"Epoch 1/3\n",
"1000/1000 [==============================] - 20s - loss: 0.4622 - val_loss: 0.2460\n",
"Epoch 2/3\n",
"1000/1000 [==============================] - 20s - loss: 0.1756 - val_loss: 0.1639\n",
"Epoch 3/3\n",
"1000/1000 [==============================] - 26s - loss: 0.1070 - val_loss: 0.1050\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f57a75547f0>"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(x_train, y_train, validation_data=[x_test, y_test], epochs=3)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAABbCAYAAABEQP/sAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAADpxJREFUeJzt3XmQ3VP6x/H3Q2JIJGhLyBRimRAUU0KV8kPsxNhjH2tRIkEsyc8ee1HWQg0RpGzzKww/SxA/+86PoEjEElMIhiDEZEEEZ/64/eTb9+be7tvdt+85376fV1Wqt+/tfvLtvuc+3+d7znMshICIiMS3ROwARESkQAOyiEgiNCCLiCRCA7KISCI0IIuIJEIDsohIIjQgi4gkIskB2cwGmdkzZvZvM/unme0TO6YUmFmTmT1gZvPNbIaZHRI7ptjMbICZTTKz2WY208z+ZmY9YscVi5n9wcwmNP99zDWzt81saOy4UpCHcSW5Abn5yfQQ8AjQBBwL/N3MBkYNLA3XA78A/YC/AuPMbMO4IUV3A/ANsBrwZ2AIMDJqRHH1AD6ncB6WA84B/mFmAyLGFF1exhVLbaWemW0E/D/QJzQHZ2ZPAK+FEMZGDS4iM+sNzAY2CiFMb/7cncC/QghnRA0uIjN7HxgdQpjU/PEVQN8QwvC4kaXDzKYAF4QQ/jd2LLHkZVxJLkOuwICNYgcR2UDgVx+Mm70DNHqGfA1wkJn1MrM/AkOB/4scUzLMrB+Fv51psWNJUHLjSooD8ocULkH/28x6mtnOFC6/esUNK7plgTkln/s30CdCLCl5gcKL0hzgC+AN4MGoESXCzHoC/wPcHkL4IHY8keViXEluQA4hLAT2Bv4CzARGA/+g8GRrZPOAviWf6wvMjRBLEsxsCQrZ8P1Ab2AlYAXgsphxpaD53NxJ4Z7DCZHDiS4v40pyAzJACGFKCGFICGHFEMIuwNrA67Hjimw60MPM/tTic5vQ2JeiTcAawN9CCAtCCN8BtwK7xQ0rLjMzYAKFm7/DmgejhpeHcSXJAdnMNjazpZvrgmMo3EG/LXJYUYUQ5lPIBC80s95m9l/AXhSyoIYUQpgFfAKMMLMeZrY8cAQwJW5k0Y0DBgF7hBB+ih1MKvIwriQ5IAOHAV9RqPnsAOwUQlgQN6QkjASWoXBe7gJGhBAaOUMG2BfYFfgW+CewEDglakQRmdmawHAKUwBnmtm85n9/jRxaCpIfV5Kb9iYi0qhSzZBFRBqOBmQRkURoQBYRSYQGZBGRRGhAFhFJRLvaFJpZQ0zJCCFYtcc2yjkBZoUQVq7mQJ2T8hrlvOj5U1ZVfyvKkKVaM2IHkCCdE6lWVX8rGpBFRBKhAVlEJBEakEVEEtGwe4+JdEeDBw8G4MADDwRgxIgRAFx00UUA3HzzzQDMnj07QnTSFmXIIiKJaFdzoUaZoqJpO2W9GULYrJoDdU7K68rzst122wFwwQUXALDlllsCMHXqVABWX311AH76qdCNc6+99gLgrbfeqnksev6UVdXfijJkEZFEJFdDLmx2ACussAIAJ598MpC94v/yyy8AjB8/HoCJEycC0J3biG677bZFb88777yyx3l29NxzzxW97Y6WWWYZAE4//XQAevfuDcCuu+4KwAYbbFD2cRdeeCGQnau8W3vttQG45557AGhqagLgyiuvBOCMMwobknsG7cc99NBDAGy88caAasqpUIYsIpKIZGrI/ko9duxYAIYNG1bV46655hoARo8eDdQmU06lBnb++ecDlTPitnhWVKNMOWoN2TPinXbaCYAxY8YA2ZWTX1n573/evHlAlvl5DXXatMIGK5tsskktwopeQ/ZM2K8kH3nkEQD2228/AH799dei4w8++GAA7ryzsPPXmWeeCcAVV1xRs5hSef4kRjVkEZE8iV5D3n///QG46aabAOjbt7DT/eTJk4Eso/Ga8p577glkGdGJJ54IwCWXXALArFmz6hF2l3r22WeBrGbcUZ5Zd4dasmdyZ511VlXHX3fddQCMGzcOgCOPPBKAjz76qPbBRbTyysX9an744Qdg8czYvfnmm0D2PPEr0lpmyKk44YQTADj77LMBWHXVVYHsKuqpp54Cspkofl9hzpw5dY2zJWXIIiKJ0IAsIpKIaCWLU04p7NR+7rnnAtmEdb80vfHGG8s+7vjjjweyy3G/5OoOpQq/idfZUkV3ss022wAwcuTIdj1u1KhRQLbw4dJLL61tYIm4/PLLATjggAMAeP7551s9fvr06UA2/c0v6/0mqd8kzKPVVlsNgNtuuw3IbmovueSSQFaq8Lc77LBD0dt9990XyBbNTJkypQ5RF1OGLCKSiLpPe1t++eUBeP3114Gs0O7T3j799NOqvs8qq6xS9PE333zT2dAWqfe0Hc+I/WZeJf6K39YCEec38/xxnRRl2tvcuXOBbNpbKz8TWHzao195eQb52GOP1So0SGDam/Msz68I2lro4QtK3njjDSC7qe4LSTqj3s+f9ddfH4DXXnsNgGWXXbbo6y+++CKQTQn0qwR3xx13ANCnTx8ga8TkV6w1omlvIiJ5UpcMeYklsnH/vvvuA2DvvfcGsnaAw4cP78i3ZrnllgPgoIMOArL60YIFCzr0/aD+r/CVfgfVZrhtTZOr0QKRumbIQ4cOBWDSpEkA/P7772WP+/HHH4s+7tWrV9njbrjhBiCbJlkjyWTIHfXll18Wfdy/f/9Of896P388M95ss8Kv4osvvgDg2GOPBeCZZ54BYOHChWUf71fXK664IgCvvPIKAFtvvXVnQ2tJGbKISJ7UZZaF3ymHLDP2WREdvfu97rrrAvDEE08AMGDAACB7tXvggQc69H3rpa16MVRf+/XjKmXK/nmvs6bKF/8A3H777UCWGZdeRXiNdJ999gGyJdS+QKiUN2r3jLvGtWSps913333R+5tvvjmweGvRt99+u9XvcfTRRwPZYjT36KOP1izO9lKGLCKSiLpkyOXu3N5yyy1A9bMq3NJLLw1kGbBnxu67775rf4B1VNpKs5yO1nrbypRb3jWu8R3kmvDfLRTfdyjn+uuvB+Cll14CsqX2O+64IwDbb7992cf5vPdGz5A33HBDIKu5X3311THDqZrHe8QRRyz6nF/5+bZVbWXG/n/3+1elvv32207H2VHKkEVEEhFtpV57M1nPhH3OoL/KlYqxuqY9qsmMOztv2Juk5G3F31dffbXofd+IoBKfreN8Vo3fUa+UIa+55pqdCbHb2HTTTYFs7m3pbJVUedOktdZaa9Hn/P7CZ599VnRsjx6F4c3XPviV+qGHHlr0uFKVZvTUgzJkEZFE1CVDLtfOzmdXlGY2pbznxWGHHQbAUkstVfY4f/XzlV2pGjJkSMWvtdWHoFqlWzh5ptzaz84LP0cffPBB2a97bfjiiy8u+3Xf6sm3eHrvvfdqHWLSPCP22SmeJcasm7aHXzl9+OGHiz7n2b433fe/kUGDBgHZ1VKl1ZwpUYYsIpKIuqzUa9l34tVXXwWKa0DV8Izon
XfeAbI7qs4b1/t69c7oypVGrZ3vWs8Tbm0LqA78rC5fqTd48OBF73t2730JvK735JNPAtlmppU8/vjjQDbrwvk8de8M1knRVuoNHDgQWHyWka+8e/fdd8s+bueddway+dgzZswAYJ111qlVaHVZqddyRo7/n33VbiVeJ/f5x08//TSQzbzxznfHHHMMALfeemtHQqtEK/VERPKkLjXklp3YvI45YcIEINu00nlm5PWfjz/+GMg2Z6zUJ9nXs+dRd9hiqRZaZq09e/YEFl+p19YV3RprrAFkNeLS41OuH1bD5996H2Tvv+B8Cyefp+8rWX27otNOO63oeL+yzJuff/550ft77LEHAFdddRWQrdxz3t3Nt4srvXrwFX4pUIYsIpKIus9D9k5M/mq1xRZbAFlNyFdeef9bv/vrd1e996nzvsrff/99V4bdpWo1uyLvWtb//ffZr1+/omO8zly6KtF5P4wa1YiT413wVlppJSDru/D+++8DWQZ91FFHAdnmrl579xk3PrvEa8h59vLLLwPZWOL3DWbOnAlUrqc77xDpVw9trRLtSsqQRUQSEW2lns9N9hpXR91///0A/Pbbb52OqbvJ87xjz+BKM+SmpiYA7r33XgDWW289IFv56bXj7sp/p57FeYY8fvx4IFvJ6jVj35Fnl112AbI+Dz5LZd68efUIu678/95efn9BK/VERCRehtxevsKo5fxDKO6hmwelq+eg9plsNR3lUuc9bcut8oSsP4HXDT1TbGt+ct75jCWf23/SSScB8PDDDwNZ/wafWeDHefbnK1rzsjIvhpj9TpQhi4gkIjcZ8lZbbQVkNTFXuidY6nxGRcvs1d/3lXWd7VVcaTdq7wKXB/PnzwfghRdeAIp3nWlp4sSJQDY/3Tt5VeI7CueV92u47LLLgGzFnu+sXCm784z4k08+6eoQc+/rr7+O9rOVIYuIJCI3GbLvAJuHjk2t8Rpyy7qxZ8iVMttKGXPp49qqGedxReDYsWMBePDBB4GsdlzquOOOAyr/XXjvhnHjxtU6xLryvQZ9NdqwYcOAxTNjn6Xiu2KMGjUKyPo2HH744V0fbE61XAVYb8qQRUQSoQFZRCQRuSlZuLyWKly5skFpqaG0dFGplFGtWrf1rCdfSu/NdMaMGQNkC0Ta4ttC+aV63vnNuUMOOQRYvLmQ81aTvvDDH7fbbrsBjdugvzX+PNHSaRERyV+G3F20zJS9UU5bN/fa+73zNM2tLZ4hf/7550A2/asSz4x9gUlbW8PnjbcKaNnatjV333130VtZnF99V9pAuR6UIYuIJCL3GXKeG9O70k1JSxeIlC6trtSus7MLSvLAm0l5a8VzzjkHyOp/kydPBuCuu+4Cul9mLF2vtMF9PSlDFhFJRF02Oc2bemzSmENdvslpDkXb5DRleXv++KYX06ZNA+Daa68F4NRTT63lj9EmpyIieZL7GrKISGf4xrDeVGjq1KnRYlGGLCKSCGXIItLQfMZO//79I0eiDFlEJBntzZBnAfnfN7x17d2/pRHOCbTvvOiclNcI50XnpLyqzku7pr2JiEjXUclCRCQRGpBFRBKhAVlEJBEakEVEEqEBWUQkERqQRUQSoQFZRCQRGpBFRBKhAVlEJBH/AbPwhv4YTiclAAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f57a743f860>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"n_images = 5\n",
"ids = np.random.choice(x_test.shape[0], n_images, replace=False)\n",
"images = x_test[ids]\n",
"id_class = model.predict(images)\n",
"\n",
"for index, image in enumerate(images):\n",
" ax = plt.subplot(1, n_images, index+1)\n",
" plt.imshow(image.reshape(28, 28), cmap=\"gray\")\n",
" ax.get_xaxis().set_visible(False)\n",
" ax.get_yaxis().set_visible(False)\n",
" \n",
" ax.set_title(\"{0}\".format(np.argmax(id_class[index])))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
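
For reference, the `margin_loss` cell above implements the per-class margin loss from Sabour et al., "Dynamic Routing Between Capsules" (2017), with the paper's constants $m^+ = 0.9$, $m^- = 0.1$, and $\lambda = 0.5$ hard-coded:

$$L_k = T_k \, \max(0, m^+ - \lVert \mathbf{v}_k \rVert)^2 + \lambda \, (1 - T_k) \, \max(0, \lVert \mathbf{v}_k \rVert - m^-)^2$$

where $T_k = 1$ exactly when class $k$ is present and $\lVert \mathbf{v}_k \rVert$ is the length of the $k$-th output capsule (the model's `y_pred`). The final `K.mean(K.sum(L, 1))` sums the loss over the 10 classes and averages over the batch.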
129 changes: 129 additions & 0 deletions simple.py
@@ -0,0 +1,129 @@

# coding: utf-8

# In[1]:


import numpy as np
from keras import layers, models, optimizers
from keras import backend as K
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
from keras import callbacks
import matplotlib.pyplot as plt
from utils import combine_images
from PIL import Image
from capsulelayers import CapsuleLayer, PrimaryCap, Length, Mask

try:
    get_ipython().magic(u'matplotlib inline')  # nbconvert artifact; only meaningful inside Jupyter
except NameError:
    pass


# In[2]:


# Stand-in for the argparse namespace used by the upstream scripts.
class args:
    save_dir = "weights/"
    debug = True

    # model
    routings = 1

    # hyperparameters
    batch_size = 32
    lr = 0.001
    lr_decay = 1.0
    lam_recon = 0.392

    # training
    epochs = 3
    shift_fraction = 0.1
    digit = 5


# In[3]:



def CapsNet(input_shape, n_class, routings):
    x = layers.Input(shape=input_shape)

    # Layer 1: just a conventional Conv2D layer.
    conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu', name='conv1')(x)

    # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_capsule, dim_capsule].
    primarycaps = PrimaryCap(conv1, dim_capsule=8, n_channels=32, kernel_size=9, strides=2, padding='valid')

    # Layer 3: capsule layer; the routing algorithm works here.
    digitcaps = CapsuleLayer(num_capsule=n_class, dim_capsule=16, routings=routings,
                             name='digitcaps')(primarycaps)

    # Layer 4: auxiliary layer that replaces each capsule with its length,
    # to match the true label's shape.
    out_caps = Length(name='capsnet')(digitcaps)

    model = models.Model(x, out_caps)

    return model
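
Assuming the standard 28x28x1 MNIST input, the tensor shapes through this network work out as follows (a sketch for orientation; the script itself never prints a summary):

# x:           (None, 28, 28, 1)
# conv1:       (None, 20, 20, 256)   9x9 kernel, stride 1, 'valid' padding
# primarycaps: (None, 1152, 8)       6*6*32 = 1152 primary capsules, 8-D each
# digitcaps:   (None, 10, 16)        one 16-D capsule per digit class
# out_caps:    (None, 10)            capsule lengths, matched against one-hot labels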


# In[4]:


# the data, shuffled and split between train and test sets
from keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.
y_train = to_categorical(y_train.astype('float32'))
y_test = to_categorical(y_test.astype('float32'))

x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]


# In[5]:


def margin_loss(y_true, y_pred):
    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
        0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))

    return K.mean(K.sum(L, 1))


# In[6]:


model = CapsNet(input_shape=x_train.shape[1:], n_class=len(np.unique(np.argmax(y_train, 1))), routings=args.routings)


# In[7]:


model.compile(optimizer=optimizers.Adam(lr=args.lr), loss=[margin_loss])


# In[8]:


model.fit(x_train, y_train, validation_data=[x_test, y_test], epochs=3)


# In[9]:


n_images = 5
ids = np.random.choice(x_test.shape[0], n_images, replace=False)
images = x_test[ids]
id_class = model.predict(images)

for index, image in enumerate(images):
    ax = plt.subplot(1, n_images, index+1)
    plt.imshow(image.reshape(28, 28), cmap="gray")
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    ax.set_title("{0}".format(np.argmax(id_class[index])))
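
The script plots five sample predictions but never reports an aggregate metric. A minimal sketch for computing test accuracy from the predicted capsule lengths, reusing the `model`, `x_test`, and `y_test` defined above (an illustration, not part of the original commit):

# The predicted class is the digit capsule with the greatest length.
y_pred = model.predict(x_test, batch_size=args.batch_size)
test_acc = np.mean(np.argmax(y_pred, 1) == np.argmax(y_test, 1))
print("Test accuracy: {:.4f}".format(test_acc))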
