[WIP] Overfeat layers #13

Open · wants to merge 1 commit into master
1 change: 1 addition & 0 deletions beacon8/layers/Dropout.py
@@ -4,6 +4,7 @@
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
_srng = RandomStreams()


class Dropout(Module):
    def __init__(self, dropout):
        Module.__init__(self)
14 changes: 14 additions & 0 deletions beacon8/layers/DuringTesting.py
@@ -0,0 +1,14 @@
from .Module import Module


class DuringTesting(Module):
    def __init__(self, module):
        Module.__init__(self)

        self.module = module

    def symb_forward(self, symb_input):
        if self.training_mode:
            return symb_input
        else:
            return self.module.symb_forward(symb_input)
14 changes: 14 additions & 0 deletions beacon8/layers/DuringTraining.py
@@ -0,0 +1,14 @@
from .Module import Module


class DuringTraining(Module):
    def __init__(self, module):
        Module.__init__(self)

        self.module = module

    def symb_forward(self, symb_input):
        if self.training_mode:
            return self.module.symb_forward(symb_input)
        else:
            return symb_input
246 changes: 246 additions & 0 deletions beacon8/layers/Overfeat.py
@@ -0,0 +1,246 @@
from . import Module


import theano as _th
from theano.sandbox.cuda import CudaNdarrayType, GpuOp
from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable, gpu_contiguous)


class PyCudaOp(GpuOp):
    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return hash(type(self))

    def __str__(self):
        return self.__class__.__name__

    def make_node(self, inp):
        inp = as_cuda_ndarray_variable(inp)
        return _th.Apply(self, [inp], [inp.type()])


class RollOpBase(PyCudaOp):
    def c_support_code(self):
        c_support_code = """
        __global__ void maxpool_roll(float *input, float *output, int batch_size, int feature_size, int height_size, int width_size)
        {
            int x = blockIdx.x * blockDim.x + threadIdx.x;
            int batch = blockIdx.y * blockDim.y + threadIdx.y;
            int map_size = height_size * width_size;
            int feature = x / map_size;
            int height = (x % map_size) / width_size;
            int width = x % width_size;
            int height_out = height / 2;
            int width_out = width / 2;
            int batch_out = batch * 4;
            if (height % 2 == 0 && width % 2 == 1)
            {
                batch_out += 1;
            }
            else if (height % 2 == 1 && width % 2 == 0)
            {
                batch_out += 2;
            }
            else if (height % 2 == 1 && width % 2 == 1)
            {
                batch_out += 3;
            }
Member comment:
I think we could write the three above as batch_out += 2*(height % 2) + (width % 2), which would avoid conditionals and thus, in theory, speed up this kernel quite a bit.
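            // A sketch of the branchless form suggested above (not part of this diff):
            //     batch_out += 2 * (height % 2) + (width % 2);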

            if (batch < batch_size && feature < feature_size && height_out * 2 < height_size && width_out * 2 < width_size)
            {
                output[batch_out * (feature_size * ((height_size + 1) / 2) * ((width_size + 1) / 2)) +
                       feature * (((height_size + 1) / 2) * ((width_size + 1) / 2)) +
                       height_out * ((width_size + 1) / 2) +
                       width_out] = input[batch * (feature_size * height_size * width_size) +
                                          feature * (height_size * width_size) +
                                          height * width_size +
                                          width];
            }
        }
        """
        return c_support_code

    def c_code(self, node, name, inputs, outputs, sub):
        fail = sub['fail']

        inp, = inputs
        out, = outputs
Member comment:
Should we put assert len(inputs) == 1, str(type(self)) + " only takes one input." and analogously for outputs here in order to prevent silent mistakes?
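        # A sketch of the guards suggested above (not part of this diff):
        #     assert len(inputs) == 1, str(type(self)) + " only takes one input."
        #     assert len(outputs) == 1, str(type(self)) + " only produces one output."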


c_code = """
{
int batch_size = CudaNdarray_HOST_DIMS(%(inp)s)[0];
int n_features = CudaNdarray_HOST_DIMS(%(inp)s)[1];
int height = CudaNdarray_HOST_DIMS(%(inp)s)[2];
int width = CudaNdarray_HOST_DIMS(%(inp)s)[3];

int out_shape[] = {batch_size * 4, n_features, (height + 1) / 2, (width + 1) / 2};
if (NULL == %(out)s || CudaNdarray_NDIM(%(inp)s) != CudaNdarray_NDIM(%(out)s) ||
!(CudaNdarray_HOST_DIMS(%(out)s)[0] == out_shape[0] &&
CudaNdarray_HOST_DIMS(%(out)s)[1] == out_shape[1] &&
CudaNdarray_HOST_DIMS(%(out)s)[2] == out_shape[2] &&
CudaNdarray_HOST_DIMS(%(out)s)[3] == out_shape[3]))
{
Py_XDECREF(%(out)s);
%(out)s = (CudaNdarray*)CudaNdarray_ZEROS(CudaNdarray_NDIM(%(inp)s), out_shape);
}

if (!%(out)s)
{
PyErr_SetString(PyExc_MemoryError, "failed to alloc output");
%(fail)s;
}

dim3 block(16, 16, 1);
dim3 grid((int)(ceil(((float)n_features * height * width) / block.x)),
(int)(ceil(((float)batch_size) / block.y)),
1);

maxpool_roll<<<grid, block>>>(CudaNdarray_DEV_DATA(%(inp)s),
CudaNdarray_DEV_DATA(%(out)s),
batch_size, n_features, height, width);

CNDA_THREAD_SYNC;
cudaError_t sts = cudaGetLastError();
if (cudaSuccess != sts)
{
PyErr_Format(PyExc_RuntimeError, cudaGetErrorString(sts));
%(fail)s;
}
}
"""
return c_code % locals()


class RollOp(RollOpBase):
    def grad(self, inp, grads):
        top, = grads
        top = gpu_contiguous(top)
        return [RollOpGrad()(top)]


class UnRollOpBase(PyCudaOp):
    def c_support_code(self):
        c_support_code = """
        __global__ void maxpool_unroll(float *input, float *output, int batch_size, int feature_size, int height_size, int width_size)
        {
            int x = blockIdx.x * blockDim.x + threadIdx.x;
            int batch = blockIdx.y * blockDim.y + threadIdx.y;
            int map_size = height_size * width_size;
            int feature = x / map_size;
            int height = (x % map_size) / width_size;
            int width = x % width_size;
            int height_out = height * 2;
            int width_out = width * 2;
            int batch_out = batch / 4;
            if (batch % 4 == 1)
            {
                width_out += 1;
            }
            else if (batch % 4 == 2)
            {
                height_out += 1;
            }
            else if (batch % 4 == 3)
            {
                height_out += 1;
                width_out += 1;
            }
Member comment:
Same as above, with width_out += batch % 2 and height_out += (batch/2) % 2?
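            // A sketch of the branchless form suggested above (not part of this diff):
            //     width_out += batch % 2;
            //     height_out += (batch / 2) % 2;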

            if (batch < batch_size && feature < feature_size)
            {
                output[batch_out * (feature_size * height_size * 2 * width_size * 2) +
                       feature * (height_size * 2 * width_size * 2) +
                       height_out * width_size * 2 +
                       width_out] = input[batch * (feature_size * height_size * width_size) +
                                          feature * (height_size * width_size) +
                                          height * width_size +
                                          width];
            }
        }
        """

        return c_support_code

    def c_code(self, node, name, inputs, outputs, sub):
        fail = sub['fail']

        inp, = inputs
        out, = outputs
Member comment:
Same assert as above.


c_code = """
{
int batch_size = CudaNdarray_HOST_DIMS(%(inp)s)[0];
int n_features = CudaNdarray_HOST_DIMS(%(inp)s)[1];
int height = CudaNdarray_HOST_DIMS(%(inp)s)[2];
int width = CudaNdarray_HOST_DIMS(%(inp)s)[3];

int out_shape[] = {batch_size / 4, n_features, height * 2, width * 2};
if (NULL == %(out)s || CudaNdarray_NDIM(%(inp)s) != CudaNdarray_NDIM(%(out)s) ||
!(CudaNdarray_HOST_DIMS(%(out)s)[0] == out_shape[0] &&
CudaNdarray_HOST_DIMS(%(out)s)[1] == out_shape[1] &&
CudaNdarray_HOST_DIMS(%(out)s)[2] == out_shape[2] &&
CudaNdarray_HOST_DIMS(%(out)s)[3] == out_shape[3]))
{
Py_XDECREF(%(out)s);
%(out)s = (CudaNdarray*)CudaNdarray_NewDims(CudaNdarray_NDIM(%(inp)s), out_shape);
}

if (!%(out)s)
{
PyErr_SetString(PyExc_MemoryError, "failed to alloc output");
%(fail)s;
}

dim3 block(16, 16, 1);
dim3 grid((int)(ceil(((float)n_features * height * width) / block.x)),
(int)(ceil(((float)batch_size) / block.y)),
1);

maxpool_unroll<<<grid, block>>>(CudaNdarray_DEV_DATA(%(inp)s),
CudaNdarray_DEV_DATA(%(out)s),
batch_size, n_features, height, width);

CNDA_THREAD_SYNC;
cudaError_t sts = cudaGetLastError();
if (cudaSuccess != sts)
{
PyErr_Format(PyExc_RuntimeError, cudaGetErrorString(sts));
%(fail)s;
}
}
"""
return c_code % locals()


class UnRollOp(UnRollOpBase):
    def grad(self, inp, grads):
        top, = grads
        top = gpu_contiguous(top)
        return [UnRollOpGrad()(top)]


class RollOpGrad(UnRollOpBase):
    pass


class UnRollOpGrad(RollOpBase):
    pass

unroll = UnRollOp()
roll = RollOp()


class OverfeatRoll(Module):
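    """
    Stacks the four 2x2 offsets of the input along the batch dimension:
    (B, C, H, W) -> (4B, C, ceil(H/2), ceil(W/2)).
    """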
    def __init__(self):
        Module.__init__(self)

    def symb_forward(self, symb_input):
        return roll(symb_input)


class OverfeatUnroll(Module):
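    """
    Inverse of OverfeatRoll: interleaves groups of four batch entries back into
    the spatial dimensions, (4B, C, H, W) -> (B, C, 2H, 2W).
    """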
    def __init__(self):
        Module.__init__(self)

    def symb_forward(self, symb_input):
        return unroll(symb_input)
4 changes: 2 additions & 2 deletions beacon8/layers/SpatialConvolutionCUDNN.py
@@ -1,9 +1,9 @@
from .Module import Module

import theano as _th
import numpy as _np
import theano.sandbox.cuda.dnn as _dnn

from .Module import Module


class SpatialConvolutionCUDNN(Module):
    def __init__(self, n_input_plane, n_output_plane, k_w, k_h, d_w=1, d_h=1, pad_w=0, pad_h=0, with_bias=True):
4 changes: 2 additions & 2 deletions beacon8/layers/SpatialMaxPoolingCUDNN.py
@@ -1,7 +1,7 @@
import theano.sandbox.cuda.dnn as _dnn

from .Module import Module

import theano.sandbox.cuda.dnn as _dnn


class SpatialMaxPoolingCUDNN(Module):
    def __init__(self, k_w, k_h, d_w=None, d_h=None, pad_w=0, pad_h=0):
17 changes: 17 additions & 0 deletions beacon8/layers/SpatialSoftMax.py
@@ -0,0 +1,17 @@
from .Module import Module

import theano.sandbox.cuda.dnn as dnn
from theano.sandbox.cuda.basic_ops import gpu_contiguous


def spatial_softmax(img):
    img = gpu_contiguous(img)
Member comment:
This little bastard 😄

    return dnn.GpuDnnSoftmax(tensor_format='bc01', algo='accurate', mode='channel')(img)


class SpatialSoftMax(Module):
    def __init__(self):
        Module.__init__(self)

    def symb_forward(self, symb_input):
        return spatial_softmax(symb_input)
14 changes: 14 additions & 0 deletions beacon8/layers/SpatialSubSampling.py
@@ -0,0 +1,14 @@
from .Module import Module


class SpatialSubSampling(Module):
"""
note that it behaves very differently from Torch!
"""
def __init__(self, scale):
self.scale = scale

def symb_forward(self, symb_input):
if symb_input.ndim != 4:
raise NotImplementedError
return symb_input[:, :, ::self.scale, ::self.scale]
Member comment:
As soon as Theano/Theano#2487 lands, we can make this generic.

If you want, using this trick, scale could either stay an integer, as it is now, or be a tuple giving a different scale for each spatial dimension, which would then be used as symb_input[:, :, ::self.scale[0], ::self.scale[1]].
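
A minimal sketch of that suggestion (an illustration, not part of this diff), assuming scale may be either an int or a (scale_h, scale_w) tuple, and also calling Module.__init__ as the other layers do:

from .Module import Module


class SpatialSubSampling(Module):
    """
    Subsamples a 4D input by strided slicing; note that it behaves very
    differently from Torch!
    """
    def __init__(self, scale):
        Module.__init__(self)
        # Accept a single integer or a (scale_h, scale_w) tuple.
        self.scale = scale if isinstance(scale, tuple) else (scale, scale)

    def symb_forward(self, symb_input):
        if symb_input.ndim != 4:
            raise NotImplementedError
        # Keep every scale-th row and column.
        return symb_input[:, :, ::self.scale[0], ::self.scale[1]]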

5 changes: 5 additions & 0 deletions beacon8/layers/__init__.py
@@ -12,3 +12,8 @@
from .SpatialMaxPooling import *
from .SpatialConvolutionCUDNN import *
from .SpatialMaxPoolingCUDNN import *
from .Overfeat import OverfeatRoll, OverfeatUnroll
from .DuringTraining import *
from .DuringTesting import *
from .SpatialSoftMax import SpatialSoftMax
from .SpatialSubSampling import *
1 change: 1 addition & 0 deletions examples/MNIST/test.py
@@ -2,6 +2,7 @@
from progress_bar import *
import theano as _th


def validate(dataset_x, dataset_y, model, epoch, batch_size):
    progress = make_progressbar('Testing', epoch, len(dataset_x))
    progress.start()
38 changes: 38 additions & 0 deletions examples/Segmentation/data.py
@@ -0,0 +1,38 @@
import glob
import os
import numpy as np
import scipy as sp
import scipy.misc  # make sp.misc.imread available; scipy does not import misc by default
import tarfile

# Python 2/3 compatibility.
try:
    from urllib.request import urlretrieve
except ImportError:
    from urllib import urlretrieve



def load_data():
    data_folder = os.path.join(os.path.dirname(__file__), 'iccv09Data')
    if not os.path.isdir(data_folder):
        tar_file = os.path.join(os.path.dirname(__file__), 'data.tar.gz')
        origin = 'http://dags.stanford.edu/data/iccv09Data.tar.gz'
        print('Downloading data from {}'.format(origin))
        urlretrieve(origin, tar_file)
        tar = tarfile.open(tar_file)
        tar.extractall()
        tar.close()
        os.remove(tar_file)

    image_files = glob.glob(os.path.join(data_folder, 'images', '*.jpg'))

    set_x = list()
    set_y = list()

    for image_file in image_files:
        file_id = os.path.splitext(os.path.split(image_file)[1])[0]
        labels = np.loadtxt(os.path.join(data_folder, 'labels', file_id + '.regions.txt'))
        set_x.append(sp.misc.imread(image_file).transpose(2, 0, 1))
        set_y.append(labels)

    return set_x, set_y