Skip to content

Commit

Permalink
Pickling
Browse files Browse the repository at this point in the history
  • Loading branch information
ajtulloch committed May 15, 2015
1 parent d0b74a8 commit 6c4038c
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 43 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@
build/
dist/
MANIFEST
TARGETS
*.pyc
*.so
7 changes: 2 additions & 5 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@ install:
- if [ "${COVERAGE}" == "--with-coverage" ]; then sudo pip install coverage; fi
- if [ "${COVERAGE}" == "--with-coverage" ]; then sudo pip install coveralls; fi
script:
- if [ "${COVERAGE}" == "--with-coverage" ]; then
- make test-coverage;
- else
- make test;
- fi
- if [ "${COVERAGE}" == "--with-coverage" ]; then make test-coverage; fi
- if [ "${COVERAGE}" != "--with-coverage" ]; then make test; fi
after_success:
- if [ "${COVERAGE}" == "--with-coverage" ]; then coveralls; fi
25 changes: 17 additions & 8 deletions benchmarks/bench_compiled_tree.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

from collections import namedtuple
from datetime import datetime
from functools import partial
from sklearn import ensemble, datasets
import compiledtrees
from compiledtrees.compiled import CompiledRegressionPredictor
from sklearn.tree.tree import DTYPE
from sklearn.utils import array2d
from sklearn.utils.bench import total_seconds
import argparse
import gc
Expand Down Expand Up @@ -46,7 +48,7 @@ def uniform_dataset(args):

ENSEMBLE_REGRESSORS = [
("GB-D1", with_depth(ensemble.GradientBoostingRegressor, 1)),
("GB-D3", with_depth(ensemble.GradientBoostingRegressor, 3)),
("GB-D3", with_depth(ensemble.GradientBoostingRegressor, 3)),
("GB-B10", with_best_first(ensemble.GradientBoostingRegressor, 10)),
("RF-D1", with_depth(ensemble.RandomForestRegressor, 1)),
("RF-D3", with_depth(ensemble.RandomForestRegressor, 3)),
Expand Down Expand Up @@ -74,7 +76,7 @@ def run_ensemble(args, name, cls, X, y):
def run(n_estimators):
clf = cls(n_estimators=n_estimators)
clf.fit(X, y)
compiled = compiledtrees.CompiledRegressionPredictor(clf)
compiled = CompiledRegressionPredictor(clf)
relative_timing = RelativeTiming(
compiled=timing(compiled), normal=timing(clf))
print(n_estimators, relative_timing)
Expand All @@ -95,8 +97,14 @@ def plot(args, timings):
for name, cls_timings in timings:
xs, relative_timings = zip(*cls_timings)
ys = [r.normal / r.compiled for r in relative_timings]
plt.plot(xs, ys, '-o', label=name)
plt.hlines(1.0, np.min(xs), np.max(xs), 'k')
for x, y in zip(xs, ys):
print(xs, ys)
if args.show_plot:
plt.plot(xs, ys, '-o', label=name)
plt.hlines(1.0, np.min(xs), np.max(xs), 'k')

if not args.show_plot:
return

plt.xlabel('Number of weak learners')
plt.ylabel('Relative speedup')
Expand All @@ -107,14 +115,15 @@ def plot(args, timings):
plt.title(title)
plt.suptitle(suptitle, fontsize=3)
filename = "timings{0}.png".format(hash(str(args)))
plt.savefig(filename, dpi=300)
plt.savefig(filename, dpi=72)


def run_simulation(args):
X, y = DATASETS[args.dataset](args)
X = array2d(X, dtype=DTYPE)
X = X.astype(dtype=DTYPE)
timings = [(name, run_ensemble(args, name, cls, X, y))
for name, cls in ENSEMBLE_REGRESSORS]
print(timings)
plot(args, timings)


Expand Down
17 changes: 10 additions & 7 deletions compiledtrees/code_gen.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

from distutils import sysconfig
Expand Down Expand Up @@ -61,10 +64,10 @@ def code_gen_tree(tree, evaluate_fn=EVALUATE_FN_NAME, gen=None):
def recur(node):
if tree.children_left[node] == -1:
assert tree.value[node].size == 1
gen.write("return {0};".format(tree.value[node].item()))
gen.write("return {0}f;".format(tree.value[node].item()))
return

branch = "if (f[{feature}] <= {threshold}) {{".format(
branch = "if (f[{feature}] <= {threshold}f) {{".format(
feature=tree.feature[node],
threshold=tree.threshold[node])
with gen.bracketed(branch, "}"):
Expand Down Expand Up @@ -142,9 +145,9 @@ def code_gen_ensemble(trees, individual_learner_weight, initial_value,
with gen.bracketed('extern "C" {', "}"):
fn_decl = "float {name}(float* f) {{".format(name=EVALUATE_FN_NAME)
with gen.bracketed(fn_decl, "}"):
gen.write("float result = {0};".format(initial_value))
gen.write("float result = {0}f;".format(initial_value))
for i, _ in enumerate(trees):
increment = "result += {name}_{index}(f) * {weight};".format(
increment = "result += {name}_{index}(f) * {weight}f;".format(
name=EVALUATE_FN_NAME,
index=i,
weight=individual_learner_weight)
Expand All @@ -169,7 +172,7 @@ def call(args):
with open(cpp_f, 'w') as f:
f.write(code)

call([CXX_COMPILER, cpp_f, "-c", "-o", o_f, "-O3"])
call([CXX_COMPILER, "-shared", o_f, "-dynamiclib",
"-fpic", "-flto", "-o", so_f, "-O3"])
call([CXX_COMPILER, cpp_f, "-c", "-fPIC", "-o", o_f, "-O3"])
call([CXX_COMPILER, "-shared", o_f,
"-fPIC", "-flto", "-o", so_f, "-O3"])
return so_f
34 changes: 27 additions & 7 deletions compiledtrees/compiled.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

from sklearn.utils import array2d
from sklearn.tree.tree import DecisionTreeRegressor, DTYPE
from sklearn.ensemble.gradient_boosting import GradientBoostingRegressor
from sklearn.ensemble.forest import ForestRegressor
Expand All @@ -27,7 +29,20 @@ class CompiledRegressionPredictor(object):
http://crsouza.blogspot.com/2012/01/decision-trees-in-c.html
"""
def __init__(self, clf):
self._n_features, self._evaluator = self._build(clf)
self._n_features, self._evaluator, self._so_f = self._build(clf)

def __getstate__(self):
return dict(n_features=self._n_features, so_f=open(self._so_f).read())

def __setstate__(self, state):
import tempfile
with tempfile.NamedTemporaryFile(delete=False) as tf:
tf.write(state["so_f"])
self._n_features = state["n_features"]
self._so_f = tf.name
self._evaluator = _compiled.CompiledPredictor(
tf.name.encode("ascii"),
cg.EVALUATE_FN_NAME.encode("ascii"))

@classmethod
def _build(cls, clf):
Expand Down Expand Up @@ -64,8 +79,10 @@ def _build(cls, clf):
assert lines is not None

so_f = cg.compile_code_to_object("\n".join(lines))
return n_features, _compiled.CompiledPredictor(
so_f.encode("ascii"), cg.EVALUATE_FN_NAME.encode("ascii"))
evaluator = _compiled.CompiledPredictor(
so_f.encode("ascii"),
cg.EVALUATE_FN_NAME.encode("ascii"))
return n_features, evaluator, so_f

@classmethod
def compilable(cls, clf):
Expand Down Expand Up @@ -110,10 +127,13 @@ def predict(self, X):
y: array of shape = [n_samples]
The predicted values.
"""
if getattr(X, "dtype", None) != DTYPE or X.ndim != 2:
X = array2d(X, dtype=DTYPE)
if X.dtype != DTYPE:
raise ValueError("X.dtype is {}, not {}".format(X.dtype, DTYPE))
if X.ndim != 2:
raise ValueError(
"Input must be 2-dimensional (n_samples, n_features), "
"not {}".format(X.shape))

# TODO - validate n_features is correct?
n_samples, n_features = X.shape
if self._n_features != n_features:
raise ValueError("Number of features of the model must "
Expand Down
82 changes: 66 additions & 16 deletions compiledtrees/tests/test_compiled.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from sklearn import ensemble, tree
from compiledtrees.compiled import CompiledRegressionPredictor
from sklearn.utils.testing import \
assert_array_almost_equal, assert_raises, assert_equal
import numpy as np
import unittest
import tempfile
import pickle
import cPickle

REGRESSORS = {
ensemble.GradientBoostingRegressor,
Expand All @@ -17,29 +25,71 @@
}


def pairwise(iterable):
import itertools
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable)
next(b, None)
return itertools.izip(a, b)


def assert_equal_predictions(cls, X, y):
clf = cls()
clf.fit(X, y)
compiled = CompiledRegressionPredictor(clf)
assert_array_almost_equal(clf.predict(X), compiled.predict(X))

with tempfile.NamedTemporaryFile(delete=False) as tf:
pickle.dump(compiled, tf)
depickled = pickle.load(open(tf.name))

with tempfile.NamedTemporaryFile(delete=False) as tf:
pickle.dump(depickled, tf)
dedepickled = pickle.load(open(tf.name))

with tempfile.NamedTemporaryFile(delete=False) as tf:
cPickle.dump(compiled, tf)
decpickled = cPickle.load(open(tf.name))

predictors = [clf, compiled, depickled, decpickled, dedepickled]
predictions = [p.predict(X) for p in predictors]
for (p1, p2) in pairwise(predictions):
assert_array_almost_equal(p1, p2)


def test_rejects_unfitted_regressors_as_compilable():
for cls in REGRESSORS:
assert_equal(CompiledRegressionPredictor.compilable(cls()), False)
assert_raises(ValueError, CompiledRegressionPredictor, cls())
class TestCompiledTrees(unittest.TestCase):
def test_rejects_unfitted_regressors_as_compilable(self):
for cls in REGRESSORS:
assert_equal(CompiledRegressionPredictor.compilable(cls()), False)
assert_raises(ValueError, CompiledRegressionPredictor, cls())

def test_rejects_classifiers_as_compilable(self):
for cls in CLASSIFIERS:
assert_equal(CompiledRegressionPredictor.compilable(cls()), False)
assert_raises(ValueError, CompiledRegressionPredictor, cls())

def test_rejects_classifiers_as_compilable():
for cls in CLASSIFIERS:
assert_equal(CompiledRegressionPredictor.compilable(cls()), False)
assert_raises(ValueError, CompiledRegressionPredictor, cls())
def test_correct_predictions(self):
num_features = 20
num_examples = 1000
X = np.random.normal(size=(num_examples, num_features))
X = X.astype(np.float32)
y = np.random.normal(size=num_examples)
for cls in REGRESSORS:
assert_equal_predictions(cls, X, y)
y = np.random.choice([-1, 1], size=num_examples)
for cls in REGRESSORS:
assert_equal_predictions(cls, X, y)

def test_predictions_with_invalid_input(self):
num_features = 100
num_examples = 100
X = np.random.normal(size=(num_examples, num_features))
X = X.astype(np.float32)
y = np.random.choice([-1, 1], size=num_examples)

def test_correct_predictions():
num_features = 100
num_examples = 100
X = np.random.normal(size=(num_examples, num_features))
y = np.random.choice([-1, 1], size=num_examples)
for cls in REGRESSORS:
assert_equal_predictions(cls, X, y)
for cls in REGRESSORS:
clf = cls()
clf.fit(X, y)
compiled = CompiledRegressionPredictor(clf)
assert_raises(ValueError, compiled.predict, X.astype(np.float64))
assert_raises(ValueError, compiled.predict,
np.resize(X, (1, num_features, num_features)))

0 comments on commit 6c4038c

Please sign in to comment.