Monte Carlo dropout benchmark
ahmedmalaa committed Feb 21, 2020
1 parent 52503dc commit 9d1f9fd
Showing 10 changed files with 284 additions and 0 deletions.
Binary file added data/__pycache__/make_data.cpython-36.pyc
77 changes: 77 additions & 0 deletions models/DNN_uncertainty.py
@@ -0,0 +1,77 @@

# Copyright (c) 2020, Ahmed M. Alaa
# Licensed under the BSD 3-clause license (see LICENSE.txt)

# ---------------------------------------------------------
# Monte Carlo dropout (MCDP) uncertainty wrapper around the
# feedforward neural network (DNN) base model in pytorch
# ---------------------------------------------------------

from __future__ import absolute_import, division, print_function

import numpy as np
import sys

if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
from torch import nn
import scipy.stats as st

from models.base_models import DNN

torch.manual_seed(1)


class MCDP_DNN(DNN):

    def __init__(self,
                 dropout_prob=0.5,
                 dropout_active=True,
                 n_dim=1,
                 num_layers=2,
                 num_hidden=200,
                 output_size=1,
                 activation="ReLU",
                 mode="Regression"):

        # Forward the architecture arguments to the DNN base class so that
        # none of the constructor parameters are silently ignored.
        super(MCDP_DNN, self).__init__(n_dim=n_dim,
                                       dropout_prob=dropout_prob,
                                       dropout_active=dropout_active,
                                       num_layers=num_layers,
                                       num_hidden=num_hidden,
                                       output_size=output_size,
                                       activation=activation,
                                       mode=mode)

        self.dropout_prob = dropout_prob
        self.dropout = nn.Dropout(p=dropout_prob)
        self.dropout_active = dropout_active

    def forward(self, X):

        # Dropout stays active at prediction time, so repeated forward
        # passes yield distinct Monte Carlo dropout samples.
        _out = self.dropout(self.model(X))

        return _out

    def predict(self, X, alpha=0.1, MC_samples=100):

        # Two-sided standard normal quantile for a (1 - alpha) interval
        z_c = st.norm.ppf(1 - alpha / 2)

        X = torch.tensor(X.reshape((-1, self.n_dim))).float()

        # MC_samples stochastic forward passes through the network
        samples_ = [self.forward(X).detach().numpy() for _ in range(MC_samples)]
        pred_sample = np.concatenate(samples_, axis=1)

        pred_mean = np.mean(pred_sample, axis=1)

        # Half-width of the (1 - alpha) interval: z_c times the sample std
        pred_std = z_c * np.std(pred_sample, axis=1)

        return pred_mean, pred_std
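
A minimal usage sketch (illustrative, not part of the commit): fitting MCDP_DNN on a hypothetical toy 1-D regression problem and turning the Monte Carlo dropout samples into approximate 90% intervals. The toy data and hyperparameters below are assumptions for the example only.

import numpy as np
from models.DNN_uncertainty import MCDP_DNN

# Hypothetical toy data: y = sin(x) + noise
X = np.random.uniform(-3, 3, size=1000)
y = np.sin(X) + 0.1 * np.random.randn(1000)

model = MCDP_DNN(dropout_prob=0.1, n_dim=1)
model.fit(X, y, num_iter=500)    # fit() is inherited from the DNN base class

X_test = np.linspace(-3, 3, 100)
pred_mean, half_width = model.predict(X_test, alpha=0.1, MC_samples=100)

# (1 - alpha) interval around the MC-dropout mean
lower, upper = pred_mean - half_width, pred_mean + half_width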
Binary file added models/__pycache__/DNN_uncertainty.cpython-36.pyc
Binary file added models/__pycache__/__init__.cpython-36.pyc
Binary file added models/__pycache__/base_models.cpython-36.pyc
133 changes: 133 additions & 0 deletions models/base_models.py
@@ -0,0 +1,133 @@

# Copyright (c) 2020, Ahmed M. Alaa
# Licensed under the BSD 3-clause license (see LICENSE.txt)

# ---------------------------------------------------------
# Base classes for feedforward, convolutional and recurrent
# neural network (DNN, CNN, RNN) models in pytorch
# ---------------------------------------------------------

# -------------------------------------
# | TO DO: |
# | ------ |
# | Loss functions file |
# | ADD EPOCHS |
# | argument explanation for the DNN |
# | Exception handling |
# | Multiple architectures in RNN |
# | cmd arguments |
# | logger, misc and config files |
# -------------------------------------

from __future__ import absolute_import, division, print_function

import numpy as np
import sys

if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

import torch
from torch import nn

from utils.parameters import build_architecture

torch.manual_seed(1)


class DNN(nn.Module):

    def __init__(self,
                 n_dim=1,
                 dropout_prob=0.0,
                 dropout_active=False,
                 num_layers=2,
                 num_hidden=200,
                 output_size=1,
                 activation="Tanh",
                 mode="Regression"):

        super(DNN, self).__init__()

        self.n_dim = n_dim
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.mode = mode
        self.activation = activation
        self.device = torch.device('cpu')  # TODO: make the device an option
        self.output_size = output_size
        self.dropout_prob = dropout_prob
        self.dropout_active = dropout_active
        self.model = build_architecture(self)

    def fit(self, X, y, learning_rate=1e-3, loss_type="MSE", batch_size=100,
            num_iter=500, verbosity=False):

        self.X = torch.tensor(X.reshape((-1, self.n_dim))).float()
        self.y = torch.tensor(y).float()

        loss_dict = {"MSE": torch.nn.MSELoss}

        self.loss_fn = loss_dict[loss_type](reduction='mean')
        self.loss_trace = []

        batch_size = np.min((batch_size, X.shape[0]))

        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)

        for it in range(num_iter):

            # Sample a random mini-batch for this iteration
            batch_idx = np.random.choice(X.shape[0], batch_size)

            y_pred = self.model(self.X[batch_idx, :])

            # Reshape targets to match the network output dimension
            self.loss = self.loss_fn(y_pred.reshape((batch_size, self.output_size)),
                                     self.y[batch_idx].reshape((batch_size, self.output_size)))

            self.loss_trace.append(self.loss.detach().numpy())

            if verbosity:
                print("--- Iteration: %d \t--- Loss: %.3f" % (it, self.loss.item()))

            optimizer.zero_grad()  # clear gradients for this training step
            self.loss.backward()   # backpropagation, compute gradients
            optimizer.step()       # apply the gradients

    def predict(self, X, numpy_output=True):

        X = torch.tensor(X.reshape((-1, self.n_dim))).float()

        prediction = self.model(X)

        if numpy_output:
            prediction = prediction.detach().numpy()

        return prediction
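
A brief usage sketch for the base class (illustrative only; the toy data and settings below are assumptions of this example, not part of the commit):

import numpy as np
from models.base_models import DNN

# Hypothetical toy regression data
X = np.random.randn(500)
y = 2.0 * X + 0.1 * np.random.randn(500)

model = DNN(n_dim=1, num_layers=2, num_hidden=200, activation="Tanh")
model.fit(X, y, num_iter=300)

y_hat = model.predict(np.random.randn(10))  # returns a (10, 1) numpy array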



Empty file added utils/__init__.py
Empty file.
Binary file added utils/__pycache__/__init__.cpython-36.pyc
Binary file added utils/__pycache__/parameters.cpython-36.pyc
74 changes: 74 additions & 0 deletions utils/parameters.py
@@ -0,0 +1,74 @@

# Copyright (c) 2020, Ahmed M. Alaa
# Licensed under the BSD 3-clause license (see LICENSE.txt)

# ---------------------------------------------------------
# Helper functions and utilities for deep learning models
# ---------------------------------------------------------


from __future__ import absolute_import, division, print_function

import torch
from torch import nn

torch.manual_seed(1)


ACTIVATION_DICT = {"ReLU": torch.nn.ReLU(), "Hardtanh": torch.nn.Hardtanh(),
"ReLU6": torch.nn.ReLU6(), "Sigmoid": torch.nn.Sigmoid(),
"Tanh": torch.nn.Tanh(), "ELU": torch.nn.ELU(),
"CELU": torch.nn.CELU(), "SELU": torch.nn.SELU(),
"GLU": torch.nn.GLU(), "LeakyReLU": torch.nn.LeakyReLU(),
"LogSigmoid": torch.nn.LogSigmoid(), "Softplus": torch.nn.Softplus()}


def build_architecture(base_model):

    # NOTE: ACTIVATION_DICT maps names to shared (stateless) module
    # instances, which is safe for the activations listed above.

    modules = []

    # Optional dropout before the input layer
    if base_model.dropout_active:
        modules.append(torch.nn.Dropout(p=base_model.dropout_prob))

    # Input layer
    modules.append(torch.nn.Linear(base_model.n_dim, base_model.num_hidden))
    modules.append(ACTIVATION_DICT[base_model.activation])

    # Hidden layers, each with an optional dropout in front of it
    for _ in range(base_model.num_layers - 1):

        if base_model.dropout_active:
            modules.append(torch.nn.Dropout(p=base_model.dropout_prob))

        modules.append(torch.nn.Linear(base_model.num_hidden, base_model.num_hidden))
        modules.append(ACTIVATION_DICT[base_model.activation])

    # Linear output layer
    modules.append(torch.nn.Linear(base_model.num_hidden, base_model.output_size))

    _architecture = nn.Sequential(*modules)

    return _architecture
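
For illustration (not part of the commit), the sketch below shows the Sequential module build_architecture produces for a small configuration; the SimpleNamespace is a hypothetical stand-in for any object exposing the attributes the function reads.

from types import SimpleNamespace
from utils.parameters import build_architecture

# Hypothetical configuration object with the attributes read above
cfg = SimpleNamespace(n_dim=1, num_hidden=4, num_layers=2, output_size=1,
                      activation="ReLU", dropout_active=True, dropout_prob=0.5)

print(build_architecture(cfg))
# Sequential(
#   (0): Dropout(p=0.5, inplace=False)
#   (1): Linear(in_features=1, out_features=4, bias=True)
#   (2): ReLU()
#   (3): Dropout(p=0.5, inplace=False)
#   (4): Linear(in_features=4, out_features=4, bias=True)
#   (5): ReLU()
#   (6): Linear(in_features=4, out_features=1, bias=True)
# )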

