diff --git a/h2o-algos/src/main/java/hex/api/RegisterAlgos.java b/h2o-algos/src/main/java/hex/api/RegisterAlgos.java index 5135a8911347..54a22dd24a70 100644 --- a/h2o-algos/src/main/java/hex/api/RegisterAlgos.java +++ b/h2o-algos/src/main/java/hex/api/RegisterAlgos.java @@ -40,6 +40,7 @@ public void registerEndPoints(RestApiContext context) { new hex.tree.dt .DT (true), new hex.hglm .HGLM (true), new hex.adaboost. AdaBoost (true) + //new hex.knn .KNN (true) will be implement in different PR }; // "Word2Vec", "Example", "Grep" diff --git a/h2o-algos/src/main/java/hex/generic/GenericModel.java b/h2o-algos/src/main/java/hex/generic/GenericModel.java index 06f60b8d6f9f..07e072946a48 100644 --- a/h2o-algos/src/main/java/hex/generic/GenericModel.java +++ b/h2o-algos/src/main/java/hex/generic/GenericModel.java @@ -241,6 +241,7 @@ private void predict(EasyPredictModelWrapper wrapper, AdaptFrameParameters adapt final String weightsColumn = adaptFrameParameters.getWeightsColumn(); final String responseColumn = adaptFrameParameters.getResponseColumn(); final String treatmentColumn = adaptFrameParameters.getTreatmentColumn(); + final String idColumn = adaptFrameParameters.getIdColumn(); final boolean isClassifier = wrapper.getModel().isClassifier(); final boolean isUplift = treatmentColumn != null; final float[] yact; @@ -355,6 +356,9 @@ public String getResponseColumn() { } @Override public String getTreatmentColumn() {return descriptor != null ? descriptor.treatmentColumn() : null;} + @Override + public String getIdColumn() { return descriptor != null ? descriptor.idColumn() : null;} + @Override public double missingColumnsType() { return Double.NaN; diff --git a/h2o-algos/src/main/java/hex/knn/CosineDistance.java b/h2o-algos/src/main/java/hex/knn/CosineDistance.java new file mode 100644 index 000000000000..92684b332e9f --- /dev/null +++ b/h2o-algos/src/main/java/hex/knn/CosineDistance.java @@ -0,0 +1,25 @@ +package hex.knn; + +public class CosineDistance extends KNNDistance { + + public CosineDistance(){ + super.valuesLength = 3; + } + + @Override + public double nom(double v1, double v2) { + return v1*v2; + } + + @Override + public void calculateValues(double v1, double v2) { + this.values[0] += nom(v1, v2); + this.values[1] += nom(v1, v1); + this.values[2] += nom(v2, v2); + } + + @Override + public double result() { + return 1 - (this.values[0] / (Math.sqrt(this.values[1]) * Math.sqrt(this.values[2]))); + } +} diff --git a/h2o-algos/src/main/java/hex/knn/EuclideanDistance.java b/h2o-algos/src/main/java/hex/knn/EuclideanDistance.java new file mode 100644 index 000000000000..f3786b64e45f --- /dev/null +++ b/h2o-algos/src/main/java/hex/knn/EuclideanDistance.java @@ -0,0 +1,21 @@ +package hex.knn; + +public class EuclideanDistance extends KNNDistance { + + @Override + public double nom(double v1, double v2) { + return (v1-v2)*(v1-v2); + } + + @Override + public void calculateValues(double v1, double v2) { + this.values[0] += nom(v1, v2); + } + + @Override + public double result() { + return Math.sqrt(this.values[0]); + } + + +} diff --git a/h2o-algos/src/main/java/hex/knn/KNN.java b/h2o-algos/src/main/java/hex/knn/KNN.java new file mode 100644 index 000000000000..ae2ede5def64 --- /dev/null +++ b/h2o-algos/src/main/java/hex/knn/KNN.java @@ -0,0 +1,101 @@ +package hex.knn; + +import hex.*; +import water.DKV; +import water.Key; +import water.Scope; +import water.fvec.Chunk; +import water.fvec.Frame; + +public class KNN extends ModelBuilder { + + public KNN(KNNModel.KNNParameters parms) { + super(parms); + init(false); + } + + public KNN(boolean startup_once) { + super(new KNNModel.KNNParameters(), startup_once); + } + + @Override + protected KNNDriver trainModelImpl() { + return new KNNDriver(); + } + + @Override + public ModelCategory[] can_build() { + return new ModelCategory[]{ModelCategory.Binomial, ModelCategory.Multinomial}; + } + + @Override + public boolean isSupervised() { + return true; + } + + @Override public void init(boolean expensive) { + super.init(expensive); + if( null == _parms._id_column) { + error("_id_column", "ID column parameter not set."); + } + if( null == _parms._distance) { + error("_distance", "Distance parameter not set."); + } + } + + class KNNDriver extends Driver { + + @Override + public void computeImpl() { + KNNModel model = null; + Frame result = new Frame(Key.make("KNN_distances")); + Frame tmpResult = null; + try { + init(true); // Initialize parameters + if (error_count() > 0) { + throw new IllegalArgumentException("Found validation errors: " + validationErrors()); + } + model = new KNNModel(dest(), _parms, new KNNModel.KNNOutput(KNN.this)); + model.delete_and_lock(_job); + Frame train = _parms.train(); + String idColumn = _parms._id_column; + int idColumnIndex = train.find(idColumn); + byte idType = train.vec(idColumnIndex).get_type(); + String responseColumn = _parms._response_column; + int responseColumnIndex = train.find(responseColumn); + int nChunks = train.anyVec().nChunks(); + int nCols = train.numCols(); + // split data into chunks to calculate distances in parallel task + for (int i = 0; i < nChunks; i++) { + Chunk[] query = new Chunk[nCols]; + for (int j = 0; j < nCols; j++) { + query[j] = train.vec(j).chunkForChunkIdx(i).deepCopy(); + } + KNNDistanceTask task = new KNNDistanceTask(_parms._k, query, _parms._distance, idColumnIndex, idColumn, idType, responseColumnIndex, responseColumn); + tmpResult = task.doAll(train).outputFrame(); + // merge result from a chunk + result = result.add(tmpResult); + } + DKV.put(result._key, result); + model._output.setDistancesKey(result._key); + Scope.untrack(result); + + model.update(_job); + + model.score(_parms.train()).delete(); + model._output._training_metrics = ModelMetrics.getFromDKV(model, _parms.train()); + } catch (Exception e) { + throw new RuntimeException(e); + } finally { + if (model != null) { + model.unlock(_job); + } + if (tmpResult != null) { + tmpResult.remove(); + } + } + } + } +} + + diff --git a/h2o-algos/src/main/java/hex/knn/KNNDistance.java b/h2o-algos/src/main/java/hex/knn/KNNDistance.java new file mode 100644 index 000000000000..d006bded8689 --- /dev/null +++ b/h2o-algos/src/main/java/hex/knn/KNNDistance.java @@ -0,0 +1,45 @@ +package hex.knn; + +import water.Iced; + +/** + * Template for various distance calculation. + */ +public abstract class KNNDistance extends Iced { + + // Lenght of values for calculation of the distance + // For example for calculation euclidean and manhattan distance we need only one value, + // for calculation cosine distance we need tree values. + public int valuesLength = 1; + + // Array to cumulate partial calculations of distance + public double[] values; + + /** + * Method to calculate the distance between two points from two vectors. + * @param v1 value of an item in the first vector + * @param v2 value of an item in the second vector + */ + public abstract double nom(double v1, double v2); + + /** + * Initialize values array to store partial calculation of distance. + */ + public void initializeValues(){ + this.values = new double[valuesLength]; + } + + /** + * Method to cumulate partial calculations of distance between two vectors and save it to values array. + * @param v1 value of an item in the first vector + * @param v2 value of an item in the second vector + */ + public abstract void calculateValues(double v1, double v2); + + /** + * Calculate the result from cumulated values. + * @return Final distance calculation. + */ + public abstract double result(); + +} diff --git a/h2o-algos/src/main/java/hex/knn/KNNDistanceTask.java b/h2o-algos/src/main/java/hex/knn/KNNDistanceTask.java new file mode 100644 index 000000000000..31bca772ede3 --- /dev/null +++ b/h2o-algos/src/main/java/hex/knn/KNNDistanceTask.java @@ -0,0 +1,133 @@ +package hex.knn; + +import water.Key; +import water.MRTask; +import water.fvec.Chunk; +import water.fvec.Frame; +import water.fvec.Vec; + +import java.util.Iterator; + +public class KNNDistanceTask extends MRTask { + + public int _k; + public Chunk[] _queryData; + public KNNDistance _distance; + public KNNHashMap> _topNNeighboursMaps; + public String _idColumn; + public String _responseColumn; + public int _idIndex; + public int _responseIndex; + public byte _idColumnType; + + /** + * Calculate distances dor a particular chunk + */ + public KNNDistanceTask(int k, Chunk[] query, KNNDistance distance, int idIndex, String idColumn, byte idType, int responseIndex, String responseColumn){ + this._k = k; + this._queryData = query; + this._distance = distance; + this._topNNeighboursMaps = new KNNHashMap<>(); + this._idColumn = idColumn; + this._responseColumn = responseColumn; + this._idIndex = idIndex; + this._responseIndex = responseIndex; + this._idColumnType = idType; + } + + @Override + public void map(Chunk[] cs) { + int queryColNum = _queryData.length; + long queryRowNum = _queryData[0]._len; + int inputColNum = cs.length; + int inputRowNum = cs[0]._len; + assert queryColNum == inputColNum: "Query data frame and input data frame should have the same columns number."; + for (int i = 0; i < queryRowNum; i++) { // go over all query data rows + TopNTreeMap distancesMap = new TopNTreeMap<>(_k); + String queryDataId = _idColumnType == Vec.T_STR ? _queryData[_idIndex].stringAt(i) : String.valueOf(_queryData[_idIndex].at8(i)); + for (int j = 0; j < inputRowNum; j++) { // go over all input data rows + String inputDataId = _idColumnType == Vec.T_STR ? cs[_idIndex].stringAt(j) : String.valueOf(cs[_idIndex].at8(j)); + long inputDataCategory = cs[_responseIndex].at8(j); + // if(queryDataId.equals(inputDataId)) continue; // the same id included or not? + _distance.initializeValues(); + for (int k = 0; k < inputColNum; k++) { // go over all columns + if (k == _idIndex || k == _responseIndex) continue; + double queryColData = _queryData[k].atd(i); + double inputColData = cs[k].atd(j); + _distance.calculateValues(queryColData, inputColData); + } + double dist = _distance.result(); + + distancesMap.put(new KNNKey(inputDataId, dist), inputDataCategory); + } + _topNNeighboursMaps.put(queryDataId, distancesMap); + } + } + + @Override + public void reduce(KNNDistanceTask mrt) { + KNNHashMap> inputMap = mrt._topNNeighboursMaps; + this._topNNeighboursMaps.reduce(inputMap); + } + + /** + * Get data from maps to Frame + * @param vecs + * @return filled array of vecs with calculated data + */ + public Vec[] fillVecs(Vec[] vecs){ + for (int i = 0; i < vecs[0].length(); i++) { + String id = _idColumnType == Vec.T_STR ? vecs[0].stringAt(i) : String.valueOf(vecs[0].at8(i)); + TopNTreeMap topNMap = _topNNeighboursMaps.get(id); + Iterator distances = topNMap.keySet().stream().iterator(); + Iterator responses = topNMap.values().iterator(); + for (int j = 1; j < _k+1; j++) { + KNNKey key = distances.next(); + String keyString = key.key.toString(); + vecs[j].set(i, key.value); + if(_idColumnType == Vec.T_STR){ + vecs[_k + j].set(i, keyString); + } else { + vecs[_k + j].set(i, Integer.parseInt(keyString)); + } + vecs[2 * _k + j].set(i, (long) responses.next()); + } + } + return vecs; + } + + /** + * Generate output frame with calculated distances. + * @return + */ + public Frame outputFrame() { + int newVecsSize = _k*3+1; + Vec[] vecs = new Vec[newVecsSize]; + String[] names = new String[newVecsSize]; + boolean isStringId = _idColumnType == Vec.T_STR; + Vec id = Vec.makeCon(0, _queryData[0].len(), false); + for (int i = 0; i < _queryData[_idIndex].len(); i++) { + if(isStringId) { + id.set(i, _queryData[_idIndex].stringAt(i)); + } else { + id.set(i, _queryData[_idIndex].atd(i)); + } + } + vecs[0] = id; + names[0] = _idColumn; + for (int i = 1; i < _k+1; i++) { + // names of columns + names[i] = "dist_"+i; // this could be customized + names[_k+i] = _idColumn+"_"+i; // this could be customized + names[2*_k+i] = _responseColumn+"_"+i; // this could be customized + vecs[i] = id.makeZero(); + vecs[i] = vecs[i].toNumericVec(); + vecs[_k+i] = id.makeZero(); + if (isStringId) vecs[_k+i].toStringVec(); + vecs[2*_k+i] = id.makeZero(); + vecs[2*_k+i] = vecs[2*_k+i].toNumericVec(); + } + vecs = fillVecs(vecs); + return new Frame(Key.make("KNN_distances_tmp"), names, vecs); + } +} diff --git a/h2o-algos/src/main/java/hex/knn/KNNHashMap.java b/h2o-algos/src/main/java/hex/knn/KNNHashMap.java new file mode 100644 index 000000000000..cd7227175297 --- /dev/null +++ b/h2o-algos/src/main/java/hex/knn/KNNHashMap.java @@ -0,0 +1,19 @@ +package hex.knn; + +import java.util.HashMap; +import java.util.Map; + +public class KNNHashMap> extends HashMap { + + public void reduce(KNNHashMap map){ + for (Map.Entry entry: map.entrySet()) { + K key = entry.getKey(); + V valueMap = entry.getValue(); + if (this.containsKey(key)){ + V currentKeyMap = this.get(key); + currentKeyMap.putAll(valueMap); + this.put(key, currentKeyMap); + } + } + } +} diff --git a/h2o-algos/src/main/java/hex/knn/KNNKey.java b/h2o-algos/src/main/java/hex/knn/KNNKey.java new file mode 100644 index 000000000000..8a2a5869ab44 --- /dev/null +++ b/h2o-algos/src/main/java/hex/knn/KNNKey.java @@ -0,0 +1,45 @@ +package hex.knn; + +import water.Iced; + +import java.util.Objects; + +/** + * Class to save id and distance value for KNN calculation. + * The key can be String or Integer. The value should be Double or class extends Double. + * @param String of Integer + * @param Double or class extends Double + */ +public class KNNKey, V extends Double> extends Iced implements Comparable> { + + K key; + V value; + + KNNKey(K key, V value){ + this.key = key; + this.value = value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + KNNKey knnKey = (KNNKey) o; + return Objects.equals(key, knnKey.key) && Objects.equals(value, knnKey.value); + } + + @Override + public int hashCode() { + return Objects.hash(key, value); + } + + @Override + public int compareTo(KNNKey o) { + if (o == null) return 1; + int sameValue = this.value.compareTo(o.value); + if (sameValue == 0){ + return this.key.compareTo(o.key); + } + return sameValue; + } +} diff --git a/h2o-algos/src/main/java/hex/knn/KNNModel.java b/h2o-algos/src/main/java/hex/knn/KNNModel.java new file mode 100644 index 000000000000..e2d4c194ce4d --- /dev/null +++ b/h2o-algos/src/main/java/hex/knn/KNNModel.java @@ -0,0 +1,84 @@ +package hex.knn; + +import hex.*; +import water.DKV; +import water.H2O; +import water.Key; +import water.Scope; +import water.fvec.Frame; + +public class KNNModel extends Model { + + public static class KNNParameters extends Model.Parameters { + public String algoName() { + return "KNN"; + } + public String fullName() { + return "K-nearest neighbors"; + } + public String javaName() { + return KNNModel.class.getName(); + } + + public int _k = 3; + public KNNDistance _distance; + public boolean _compute_metrics; + + @Override + public long progressUnits() { + return 0; + } + } + + public static class KNNOutput extends Model.Output { + + public KNNOutput(KNN b) { + super(b); + } + Key _distances_key; + + @Override + public ModelCategory getModelCategory() { + if (nclasses() > 2) { + return ModelCategory.Multinomial; + } else { + return ModelCategory.Binomial; + } + } + + public void setDistancesKey(Key _distances_key) { + this._distances_key = _distances_key; + } + + public Frame getDistances(){ + return DKV.get(_distances_key).get(); + } + } + + public KNNModel(Key selfKey, KNNModel.KNNParameters parms, KNNModel.KNNOutput output) { + super(selfKey, parms, output); + } + + @Override + public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) { + switch(_output.getModelCategory()) { + case Binomial: + return new ModelMetricsBinomial.MetricBuilderBinomial(domain); + case Multinomial: + return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(), domain, _parms._auc_type); + default: throw H2O.unimpl("Invalid ModelCategory " + _output.getModelCategory()); + } + } + + @Override + protected double[] score0(double[] data, double[] preds) { + Frame train = _parms._train.get(); + int idIndex = train.find(_parms._id_column); + int responseIndex = train.find(_parms._response_column); + byte idType = train.types()[idIndex]; + preds = new KNNScoringTask(data, _parms._k, _output.nclasses(), _parms._distance, idIndex, idType, + responseIndex).doAll(train).score(); + Scope.untrack(train); + return preds; + } +} diff --git a/h2o-algos/src/main/java/hex/knn/KNNScoringTask.java b/h2o-algos/src/main/java/hex/knn/KNNScoringTask.java new file mode 100644 index 000000000000..8838e95c288c --- /dev/null +++ b/h2o-algos/src/main/java/hex/knn/KNNScoringTask.java @@ -0,0 +1,75 @@ +package hex.knn; + + +import water.MRTask; +import water.fvec.Chunk; +import water.fvec.Vec; +import water.util.ArrayUtils; + +public class KNNScoringTask extends MRTask { + + public int _k; + public double[] _queryData; + public KNNDistance _distance; + public TopNTreeMap _distancesMap; + public int _idIndex; + public int _responseIndex; + public byte _idColumnType; + public int _domainSize; + + /** + * Go through the whole input frame to find the k near distances and score based on them. + */ + public KNNScoringTask(double[] query, int k, int domainSize, KNNDistance distance, int idIndex, byte idType, int responseIndex){ + this._k = k; + this._queryData = query; + this._distance = distance; + this._responseIndex = responseIndex; + this._idIndex = idIndex; + this._idColumnType = idType; + this._distancesMap = new TopNTreeMap<>(_k); + this._domainSize = domainSize; + } + + @Override + public void map(Chunk[] cs) { + int inputColNum = cs.length; + int inputRowNum = cs[0]._len; + for (int i = 0; i < inputRowNum; i++) { // go over all input data rows + String inputDataId = _idColumnType == Vec.T_STR ? cs[_idIndex].stringAt(i) : String.valueOf(cs[_idIndex].at8(i)); + int inputDataCategory = (int) cs[_responseIndex].at8(i); + _distance.initializeValues(); + int j = 0; + for (int k = 0; k < inputColNum; k++) { // go over all columns + if(k == _idIndex || k == _responseIndex) continue; + double queryColData = _queryData[j++]; + double inputColData = cs[k].atd(i); + _distance.calculateValues(queryColData, inputColData); + } + double dist = _distance.result(); + _distancesMap.put(new KNNKey(inputDataId, dist), inputDataCategory); + } + } + + @Override + public void reduce(KNNScoringTask mrt) { + this._distancesMap.putAll(mrt._distancesMap); + } + + public double[] score(){ + double[] scores = new double[_domainSize+1]; + assert _distancesMap.size() <= _k: "Distances map size should be <= _k"; + for (int value: _distancesMap.values()){ + scores[value+1]++; + } + // normalize the result score by _k + for (int i = 1; i < _domainSize+1; i++) { + if(scores[i] != 0) { + scores[i] = scores[i]/_k; + } + } + // decide the class by the max score + scores[0] = ArrayUtils.maxIndex(scores)-1; + return scores; + } +} diff --git a/h2o-algos/src/main/java/hex/knn/ManhattanDistance.java b/h2o-algos/src/main/java/hex/knn/ManhattanDistance.java new file mode 100644 index 000000000000..748bf9c47d87 --- /dev/null +++ b/h2o-algos/src/main/java/hex/knn/ManhattanDistance.java @@ -0,0 +1,20 @@ +package hex.knn; + +public class ManhattanDistance extends KNNDistance { + + @Override + public double nom(double v1, double v2) { + return Math.abs(v1-v2); + } + + @Override + public void calculateValues(double v1, double v2) { + this.values[0] += nom(v1, v2); + } + + @Override + public double result() { + return this.values[0]; + } + +} diff --git a/h2o-algos/src/main/java/hex/knn/TopNTreeMap.java b/h2o-algos/src/main/java/hex/knn/TopNTreeMap.java new file mode 100644 index 000000000000..658080fc69a1 --- /dev/null +++ b/h2o-algos/src/main/java/hex/knn/TopNTreeMap.java @@ -0,0 +1,48 @@ +package hex.knn; + +import java.util.Collection; +import java.util.Comparator; +import java.util.Map; +import java.util.TreeMap; + +/** + * The map for saving distances. + * @param Key is composed of data id and distance value. + * @param Value is the class of the data point. + */ +public class TopNTreeMap extends TreeMap { + + public int n; + + TopNTreeMap(int n){ + this.n = n; + } + + @Override + public V put(K key, V value) { + if(size() < n) { + return super.put(key, value); + } + K lastKey = lastEntry().getKey(); + int compare = comparator().compare(lastKey, key); + if(compare > 0 ) { + V returnValue = super.put(key, value); + if (size() > n){ + remove(lastKey); + } + return returnValue; + } else { + return null; + } + } + + @Override + public Comparator comparator() { + return new Comparator() { + @Override + public int compare(K o1, K o2) { + return o1.compareTo(o2); + } + }; + } +} diff --git a/h2o-algos/src/test/java/hex/knn/KNNDistanceTaskTest.java b/h2o-algos/src/test/java/hex/knn/KNNDistanceTaskTest.java new file mode 100644 index 000000000000..749f05b8687f --- /dev/null +++ b/h2o-algos/src/test/java/hex/knn/KNNDistanceTaskTest.java @@ -0,0 +1,97 @@ +package hex.knn; + +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import water.Scope; +import water.TestUtil; +import water.fvec.Chunk; +import water.fvec.Frame; +import water.fvec.Vec; +import water.util.TwoDimTable; + +public class KNNDistanceTaskTest extends TestUtil { + + @BeforeClass() public static void setup() { stall_till_cloudsize(1); } + + @Test + public void testIrisEuclidean(){ + try { + Scope.enter(); + Frame fr = parseTestFile("smalldata/iris/iris_wheader.csv"); + Scope.track(fr); + String idColumn = "id"; + String response = "class"; + fr.add(idColumn, createIdVec(fr.numRows(), Vec.T_NUM)); + Scope.track(fr); + int k = 3; + int nCols = fr.numCols(); + Chunk[] query = new Chunk[nCols]; + for (int j = 0; j < nCols; j++) { + query[j] = fr.vec(j).chunkForChunkIdx(0); + } + KNNDistanceTask mrt = new KNNDistanceTask(k, query, new EuclideanDistance(), fr.find(idColumn), idColumn, fr.vec(idColumn).get_type(), fr.find(response), response); + mrt.doAll(fr); + Frame result = mrt.outputFrame(); + Scope.track(result); + Assert.assertNotNull(result); + } + finally { + Scope.exit(); + } + } + + @Test + public void testIrisManhattan(){ + try { + Scope.enter(); + Frame fr = parseTestFile("smalldata/iris/iris_wheader.csv"); + Scope.track(fr); + String idColumn = "id"; + String response = "class"; + fr.add(idColumn, createIdVec(fr.numRows(), Vec.T_NUM)); + Scope.track(fr); + int k = 3; + int nCols = fr.numCols(); + Chunk[] query = new Chunk[nCols]; + for (int j = 0; j < nCols; j++) { + query[j] = fr.vec(j).chunkForChunkIdx(0); + } + KNNDistanceTask mrt = new KNNDistanceTask(k, query, new ManhattanDistance(), fr.find(idColumn), idColumn, fr.vec(idColumn).get_type(), fr.find(response), response); + mrt.doAll(fr); + Frame result = mrt.outputFrame(); + Scope.track(result); + Assert.assertNotNull(result); + } + finally { + Scope.exit(); + } + } + + @Test + public void testIrisCosine(){ + try { + Scope.enter(); + Frame fr = parseTestFile("smalldata/iris/iris_wheader.csv"); + Scope.track(fr); + String idColumn = "id"; + String response = "class"; + fr.add(idColumn, createIdVec(fr.numRows(), Vec.T_NUM)); + Scope.track(fr); + int k = 3; + int nCols = fr.numCols(); + Chunk[] query = new Chunk[nCols]; + for (int j = 0; j < nCols; j++) { + query[j] = fr.vec(j).chunkForChunkIdx(0); + } + KNNDistanceTask mrt = new KNNDistanceTask(k, query, new CosineDistance(), fr.find(idColumn), idColumn, fr.vec(idColumn).get_type(), fr.find(response), response); + mrt.doAll(fr); + Frame result = mrt.outputFrame(); + Scope.track(result); + Assert.assertNotNull(result); + } + finally { + Scope.exit(); + } + } +} diff --git a/h2o-algos/src/test/java/hex/knn/KNNTest.java b/h2o-algos/src/test/java/hex/knn/KNNTest.java new file mode 100644 index 000000000000..421f4f47a4ab --- /dev/null +++ b/h2o-algos/src/test/java/hex/knn/KNNTest.java @@ -0,0 +1,380 @@ +package hex.knn; + +import hex.*; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import water.DKV; +import water.TestUtil; +import water.exceptions.H2OModelBuilderIllegalArgumentException; +import water.fvec.Frame; +import water.fvec.TestFrameBuilder; +import water.fvec.Vec; +import water.runner.CloudSize; +import water.runner.H2ORunner; + +@CloudSize(1) +@RunWith(H2ORunner.class) + +public class KNNTest extends TestUtil { + + @Test + public void testIris() { + KNNModel knn = null; + Frame fr = null; + Frame preds = null; + Frame distances = null; + try { + fr = parseTestFile("smalldata/iris/iris_wheader.csv"); + + String idColumn = "id"; + String response = "class"; + + fr.add(idColumn, createIdVec(fr.numRows(), Vec.T_NUM)); + DKV.put(fr); + KNNModel.KNNParameters parms = new KNNModel.KNNParameters(); + parms._train = fr._key; + parms._k = 3; + parms._distance = new EuclideanDistance(); + parms._response_column = response; + parms._id_column = idColumn; + parms._auc_type = MultinomialAucType.MACRO_OVR; + + parms._seed = 42; + KNN job = new KNN(parms); + knn = job.trainModel().get(); + Assert.assertNotNull(knn); + + distances = knn._output.getDistances(); + Assert.assertNotNull(distances); + + preds = knn.score(fr); + Assert.assertNotNull(preds); + + ModelMetricsMultinomial mm = (ModelMetricsMultinomial) ModelMetrics.getFromDKV(knn, parms.train()); + ModelMetricsMultinomial mm1 = (ModelMetricsMultinomial) knn._output._training_metrics; + Assert.assertEquals(mm.auc(), mm1.auc(), 0); + + // test after KNN API will be ready + //knn.testJavaScoring(fr, preds, 0); + + } finally { + if (knn != null){ + knn.delete(); + } + if (distances != null){ + distances.delete(); + } + if(fr != null) { + fr.delete(); + } + if(preds != null){ + preds.delete(); + } + } + } + + @Test + public void testSimpleFrameEuclidean() { + KNNModel knn = null; + Frame fr = null; + Frame preds = null; + Frame distances = null; + try { + fr = generateSimpleFrame(); + + String idColumn = "id"; + String response = "class"; + + DKV.put(fr); + KNNModel.KNNParameters parms = new KNNModel.KNNParameters(); + parms._train = fr._key; + parms._k = 2; + parms._distance = new EuclideanDistance(); + parms._response_column = response; + parms._id_column = idColumn; + parms._auc_type = MultinomialAucType.MACRO_OVR; + + parms._seed = 42; + KNN job = new KNN(parms); + knn = job.trainModel().get(); + Assert.assertNotNull(knn); + + distances = knn._output.getDistances(); + Assert.assertNotNull(distances); + + Assert.assertEquals(distances.vec(0).at8(0), 1); + Assert.assertEquals(distances.vec(1).at(0), 0.0, 0); + Assert.assertEquals(distances.vec(2).at(0), 1.414, 10e-3); + Assert.assertEquals(distances.vec(3).at8(0), 1); + Assert.assertEquals(distances.vec(4).at8(0), 2); + Assert.assertEquals(distances.vec(5).at8(0), 1); + Assert.assertEquals(distances.vec(6).at8(0),1); + + Assert.assertEquals(distances.vec(0).at8(1), 2); + Assert.assertEquals(distances.vec(1).at(1), 0.0, 0); + Assert.assertEquals(distances.vec(2).at(1), 1.414, 10e-3); + Assert.assertEquals(distances.vec(3).at8(1), 2); + Assert.assertEquals(distances.vec(4).at8(1), 1); + Assert.assertEquals(distances.vec(5).at8(1), 1); + Assert.assertEquals(distances.vec(6).at8(1), 1); + + preds = knn.score(fr); + Assert.assertNotNull(preds); + + Assert.assertEquals(preds.vec(0).at8(0), 1); + Assert.assertEquals(preds.vec(1).at(0), 0.0, 0); + Assert.assertEquals(preds.vec(2).at(0), 1.0, 0); + + Assert.assertEquals(preds.vec(0).at8(3), 0); + Assert.assertEquals(preds.vec(1).at(3), 1.0, 0); + Assert.assertEquals(preds.vec(2).at(3), 0.0, 0); + + ModelMetricsBinomial mm = (ModelMetricsBinomial) ModelMetrics.getFromDKV(knn, parms.train()); + Assert.assertNotNull(mm); + Assert.assertEquals(mm.auc(), 1.0, 0); + } finally { + if (knn != null){ + knn.delete(); + } + if (distances != null){ + distances.delete(); + } + if(fr != null) { + fr.delete(); + } + if(preds != null){ + preds.delete(); + } + } + } + + @Test + public void testSimpleFrameManhattan() { + KNNModel knn = null; + Frame fr = null; + Frame preds = null; + Frame distances = null; + try { + fr = generateSimpleFrame(); + + String idColumn = "id"; + String response = "class"; + + DKV.put(fr); + KNNModel.KNNParameters parms = new KNNModel.KNNParameters(); + parms._train = fr._key; + parms._k = 2; + parms._distance = new ManhattanDistance(); + parms._response_column = response; + parms._id_column = idColumn; + parms._auc_type = MultinomialAucType.MACRO_OVR; + + parms._seed = 42; + KNN job = new KNN(parms); + knn = job.trainModel().get(); + Assert.assertNotNull(knn); + + distances = knn._output.getDistances(); + Assert.assertNotNull(distances); + + Assert.assertEquals(distances.vec(0).at8(0), 1); + Assert.assertEquals(distances.vec(1).at(0), 0.0, 0); + Assert.assertEquals(distances.vec(2).at(0), 2.0, 0); + Assert.assertEquals(distances.vec(3).at8(0), 1); + Assert.assertEquals(distances.vec(4).at8(0), 2); + Assert.assertEquals(distances.vec(5).at8(0), 1); + Assert.assertEquals(distances.vec(6).at8(0),1); + + Assert.assertEquals(distances.vec(0).at8(1), 2); + Assert.assertEquals(distances.vec(1).at(1), 0.0, 0); + Assert.assertEquals(distances.vec(2).at(1), 2.0, 0); + Assert.assertEquals(distances.vec(3).at8(1), 2); + Assert.assertEquals(distances.vec(4).at8(1), 1); + Assert.assertEquals(distances.vec(5).at8(1), 1); + Assert.assertEquals(distances.vec(6).at8(1), 1); + + preds = knn.score(fr); + Assert.assertNotNull(preds); + + Assert.assertEquals(preds.vec(0).at8(0), 1); + Assert.assertEquals(preds.vec(1).at(0), 0.0, 0); + Assert.assertEquals(preds.vec(2).at(0), 1.0, 0); + + Assert.assertEquals(preds.vec(0).at8(3), 0); + Assert.assertEquals(preds.vec(1).at(3), 0.5, 0); + Assert.assertEquals(preds.vec(2).at(3), 0.5, 0); + + ModelMetricsBinomial mm = (ModelMetricsBinomial) ModelMetrics.getFromDKV(knn, parms.train()); + Assert.assertNotNull(mm); + Assert.assertEquals(mm.auc(), 1.0, 0); + } finally { + if (knn != null){ + knn.delete(); + } + if (distances != null){ + distances.delete(); + } + if(fr != null) { + fr.delete(); + } + if(preds != null){ + preds.delete(); + } + } + } + + @Test + public void testSimpleFrameCosine() { + KNNModel knn = null; + Frame fr = null; + Frame preds = null; + Frame distances = null; + try { + fr = generateSimpleFrameForCosine(); + + String idColumn = "id"; + String response = "class"; + + DKV.put(fr); + KNNModel.KNNParameters parms = new KNNModel.KNNParameters(); + parms._train = fr._key; + parms._k = 2; + parms._distance = new CosineDistance(); + parms._response_column = response; + parms._id_column = idColumn; + parms._auc_type = MultinomialAucType.MACRO_OVR; + + parms._seed = 42; + KNN job = new KNN(parms); + knn = job.trainModel().get(); + Assert.assertNotNull(knn); + + distances = knn._output.getDistances(); + Assert.assertNotNull(distances); + + Assert.assertEquals(distances.vec(0).at8(0), 1); + Assert.assertEquals(distances.vec(1).at(0), 0.0, 10e-5); + Assert.assertEquals(distances.vec(2).at(0), 1.0, 10e-5); + Assert.assertEquals(distances.vec(3).at8(0), 1); + Assert.assertEquals(distances.vec(4).at8(0), 3); + Assert.assertEquals(distances.vec(5).at8(0), 1); + Assert.assertEquals(distances.vec(6).at8(0),0); + + Assert.assertEquals(distances.vec(0).at8(1), 2); + Assert.assertEquals(distances.vec(1).at(1), 0.0, 10e-5); + Assert.assertEquals(distances.vec(2).at(1), 0.105573, 10e-5); + Assert.assertEquals(distances.vec(3).at8(1), 2); + Assert.assertEquals(distances.vec(4).at8(1), 4); + Assert.assertEquals(distances.vec(5).at8(1), 1); + Assert.assertEquals(distances.vec(6).at8(1), 0); + + preds = knn.score(fr); + Assert.assertNotNull(preds); + + Assert.assertEquals(preds.vec(0).at8(0), 1); + Assert.assertEquals(preds.vec(1).at(0), 0.5, 0); + Assert.assertEquals(preds.vec(2).at(0), 0.5, 0); + + Assert.assertEquals(preds.vec(0).at8(3), 0); + Assert.assertEquals(preds.vec(1).at(3), 1.0, 0); + Assert.assertEquals(preds.vec(2).at(3), 0.0, 0); + + ModelMetricsBinomial mm = (ModelMetricsBinomial) ModelMetrics.getFromDKV(knn, parms.train()); + Assert.assertNotNull(mm); + Assert.assertEquals(mm.auc(), 1.0, 0); + } finally { + if (knn != null){ + knn.delete(); + } + if (distances != null){ + distances.delete(); + } + if(fr != null) { + fr.delete(); + } + if(preds != null){ + preds.delete(); + } + } + } + + private Frame generateSimpleFrame(){ + return new TestFrameBuilder() + .withColNames("id", "C0", "C1", "class") + .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_NUM, Vec.T_CAT) + .withDataForCol(0, ari(1, 2, 3, 4)) + .withDataForCol(1, ard(0.0, 1.0, 2.0, 3.0)) + .withDataForCol(2, ard(0.0, 1.0, 0.0, 1.0)) + .withDataForCol(3, ar("1", "1", "0", "0")) + .build(); + } + + private Frame generateSimpleFrameForCosine(){ + return new TestFrameBuilder() + .withColNames("id", "C0", "C1", "class") + .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_NUM, Vec.T_CAT) + .withDataForCol(0, ari(1, 2, 3, 4)) + .withDataForCol(1, ard(0.0, 1.0, 2.0, 3.0)) + .withDataForCol(2, ard(-1.0, 1.0, 0.0, 1.0)) + .withDataForCol(3, ar("1", "1", "0", "0")) + .build(); + } + + @Test(expected = H2OModelBuilderIllegalArgumentException.class) + public void testIdColumnIsNotDefined() { + KNNModel knn = null; + Frame fr = null; + try { + fr = generateSimpleFrame(); + DKV.put(fr); + + KNNModel.KNNParameters parms = new KNNModel.KNNParameters(); + parms._train = fr._key; + parms._k = 2; + parms._distance = new EuclideanDistance(); + parms._response_column = "class"; + parms._id_column = null; + + parms._seed = 42; + KNN job = new KNN(parms); + knn = job.trainModel().get(); + + } finally { + if (knn != null){ + knn.delete(); + } + if (fr != null) { + fr.delete(); + } + } + } + + @Test(expected = H2OModelBuilderIllegalArgumentException.class) + public void testDistanceIsNotDefined() { + KNNModel knn = null; + Frame fr = null; + try { + fr = generateSimpleFrame(); + DKV.put(fr); + + KNNModel.KNNParameters parms = new KNNModel.KNNParameters(); + parms._train = fr._key; + parms._k = 2; + parms._response_column = "class"; + parms._id_column = "id"; + + parms._seed = 42; + KNN job = new KNN(parms); + knn = job.trainModel().get(); + + } finally { + if (knn != null){ + knn.delete(); + } + if (fr != null) { + fr.delete(); + } + } + } +} diff --git a/h2o-algos/src/test/java/hex/knn/TopNTreeMapTest.java b/h2o-algos/src/test/java/hex/knn/TopNTreeMapTest.java new file mode 100644 index 000000000000..74770c115b92 --- /dev/null +++ b/h2o-algos/src/test/java/hex/knn/TopNTreeMapTest.java @@ -0,0 +1,80 @@ +package hex.knn; + +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import water.TestUtil; + +public class TopNTreeMapTest extends TestUtil { + + @BeforeClass() public static void setup() { stall_till_cloudsize(1); } + + @Test + public void testPut(){ + int k = 3; + TopNTreeMap, Double> map = new TopNTreeMap<>(k); + + // test all items was added and correctly sorted (1.0, 2.0, 3.0) + map.put(new KNNKey("a", 3.0), 1.0); + map.put(new KNNKey("b", 2.0), 0.0); + map.put(new KNNKey("c", 1.0), 1.0); + Assert.assertEquals(k, map.size()); + Assert.assertEquals(map.lastKey().value, 3.0, 0); + + // test the new item 4. 0 is not added + map.put(new KNNKey("d", 4.0), 0.0); + Assert.assertEquals(k, map.size()); + Assert.assertEquals(map.lastKey().value, 3.0, 0); + + // test the new item 0.0 should be added and be the first item of the map and the last item 3.0 is removed + map.put(new KNNKey("e", 0.0), 0.0); + Assert.assertEquals(k, map.size()); + Assert.assertEquals(map.lastKey().value, 2.0, 0); + + // test the new item with the key("e", 0.0) and value "0.0" is not added + // the item is not added, put returns value which is associated with this key + Double value = map.put(new KNNKey("e", 0.0), 0.0); + Assert.assertEquals(value, 0.0, 0); + Assert.assertEquals(map.lastKey().value, 2.0, 0); + + // test the new item with the key("ee", 0.0) and value "0.0" is added + // the item is added, put returns null + value = map.put(new KNNKey("ee", 0.0), 0.0); + Assert.assertNull(value); + Assert.assertEquals(map.firstKey().value, 0.0, 0); + //Assert.assertEquals("", "e", map.firstKey().key); + + // test put new item with the key ("ee", "1.0") and value "0.0" is added + // the item is added, put return null + value = map.put(new KNNKey("ee", 0.5), 0.0); + Assert.assertEquals(map.lastKey().value, 0.5, 0); + + } + + @Test + public void testKNNKeyCompareTo(){ + KNNKey k1 = new KNNKey<>(1, 1.0); + KNNKey k2 = new KNNKey<>(2, 1.0); + KNNKey k3 = new KNNKey<>(2, 1.0); + KNNKey k4 = new KNNKey<>(2, 2.0); + + // different key same value -> the first is less + Assert.assertEquals(k1.compareTo(k2), -1); + + // same key same value -> both object are the same + Assert.assertEquals(k2.compareTo(k3), 0); + + // different key same value -> depends on key comparator + Assert.assertEquals(k2.compareTo(k1), 1); + + // same key different value -> the item with less value is less + Assert.assertEquals(k3.compareTo(k4), -1); + + // different key different value -> the item with less value is less + Assert.assertEquals(k1.compareTo(k4), -1); + + // different key different value -> the item with less value is less + Assert.assertEquals(k4.compareTo(k1), 1); + } + +} diff --git a/h2o-core/src/main/java/hex/Model.java b/h2o-core/src/main/java/hex/Model.java index c246e5675f05..f0f8aa96bd2b 100755 --- a/h2o-core/src/main/java/hex/Model.java +++ b/h2o-core/src/main/java/hex/Model.java @@ -449,6 +449,7 @@ public long getOrMakeRealSeed(){ public String _offset_column; public String _fold_column; public String _treatment_column; + public String _id_column; // Check for constant response public boolean _check_constant_response = true; @@ -573,7 +574,7 @@ public long getOrMakeRealSeed(){ public final Frame valid() { return _valid==null ? null : _valid.get(); } public String[] getNonPredictors() { - return Arrays.stream(new String[]{_weights_column, _offset_column, _fold_column, _response_column, _treatment_column}) + return Arrays.stream(new String[]{_weights_column, _offset_column, _fold_column, _response_column, _treatment_column, _id_column}) .filter(Objects::nonNull) .toArray(String[]::new); } @@ -800,6 +801,11 @@ public final String getTreatmentColumn(){ return _treatment_column; } + @Override + public String getIdColumn() { + return _id_column; + } + @Override public final int getMaxCategoricalLevels() { return _max_categorical_levels; @@ -1078,6 +1084,7 @@ protected Output(ModelBuilder b, Frame train) { _hasWeights = b.hasWeightCol(); _hasFold = b.hasFoldCol(); _hasTreatment = b.hasTreatmentCol(); + _hasId = b.hasIdCol(); _distribution = b._distribution; _priorClassDist = b._priorClassDist; _reproducibility_information_table = createReproducibilityInformationTable(b); @@ -1087,7 +1094,7 @@ protected Output(ModelBuilder b, Frame train) { /** Returns number of input features (OK for most supervised methods, need to override for unsupervised!) */ public int nfeatures() { - return _names.length - (_hasOffset?1:0) - (_hasWeights?1:0) - (_hasFold?1:0) - (_hasTreatment ?1:0) - (isSupervised()?1:0); + return _names.length - (_hasOffset?1:0) - (_hasWeights?1:0) - (_hasFold?1:0) - (_hasTreatment ?1:0) - (_hasId?1:0) - (isSupervised()?1:0); } /** Returns features used by the model */ public String[] features() { @@ -1160,16 +1167,20 @@ public String[] features() { protected boolean _hasWeights;// only need to know if we have them protected boolean _hasFold;// only need to know if we have them protected boolean _hasTreatment; + protected boolean _hasId; public boolean hasOffset () { return _hasOffset;} public boolean hasWeights () { return _hasWeights;} public boolean hasFold () { return _hasFold;} public boolean hasTreatment() { return _hasTreatment;} public boolean hasResponse() { return isSupervised(); } + public boolean hasId() {return _hasId;} public String responseName() { return isSupervised()?_names[responseIdx()]:null;} public String weightsName () { return _hasWeights ?_names[weightsIdx()]:null;} public String offsetName () { return _hasOffset ?_names[offsetIdx()]:null;} public String foldName () { return _hasFold ?_names[foldIdx()]:null;} public String treatmentName() { return _hasTreatment ? _names[treatmentIdx()]: null;} + public String idName() {return _hasId ? _names[idIdx()] : null;} + public InteractionBuilder interactionBuilder() { return null; } // Vec layout is [c1,c2,...,cn, w?, o?, f?, u?, r] // cn are predictor cols, r is response, w is weights, o is offset, f is fold and t is treatment - these are optional @@ -1198,6 +1209,11 @@ public int treatmentIdx() { return _names.length - (isSupervised()?1:0) - 1; } + public int idIdx() { + if(!_hasId) return -1; + return _names.length - (isSupervised()?1:0) - 1; + } + /** Names of levels for a categorical response column. */ public String[] classNames() { if (_domains == null || _domains.length == 0 || !isSupervised()) return null; @@ -1667,6 +1683,7 @@ public interface AdaptFrameParameters { String getFoldColumn(); String getResponseColumn(); String getTreatmentColumn(); + String getIdColumn(); double missingColumnsType(); int getMaxCategoricalLevels(); default String[] getNonPredictors() { @@ -1711,6 +1728,7 @@ public static String[] adaptTestForTrain(final Frame test, final String[] origNa final String fold = parms.getFoldColumn(); final String response = parms.getResponseColumn(); final String treatment = parms.getTreatmentColumn(); + final String id = parms.getIdColumn(); // whether we need to be careful with categorical encoding - the test frame could be either in original state or in encoded state @@ -1730,7 +1748,7 @@ public static String[] adaptTestForTrain(final Frame test, final String[] origNa // As soon as the test frame contains at least one original pre-encoding predictor, // then we consider the frame as valid for predictions, and we'll later fill missing columns with NA Set required = new HashSet<>(Arrays.asList(origNames)); - required.removeAll(Arrays.asList(response, weights, fold, treatment)); + required.removeAll(Arrays.asList(response, weights, fold, treatment, id)); for (String name : test.names()) { if (required.contains(name)) { match = true; @@ -3481,6 +3499,9 @@ protected class H2OModelDescriptor implements ModelDescriptor { @Override public String foldColumn() { return _output.foldName(); } @Override + public String idColumn() { return _output.idName();} + + @Override public ModelCategory getModelCategory() { return _output.getModelCategory(); } @Override public boolean isSupervised() { return _output.isSupervised(); } diff --git a/h2o-core/src/main/java/hex/ModelBuilder.java b/h2o-core/src/main/java/hex/ModelBuilder.java index cbf1e301542d..79ee0c9a3731 100644 --- a/h2o-core/src/main/java/hex/ModelBuilder.java +++ b/h2o-core/src/main/java/hex/ModelBuilder.java @@ -1094,6 +1094,7 @@ public boolean isResponseOptional() { protected transient Vec _weights; // observation weight column protected transient Vec _fold; // fold id column protected transient Vec _treatment; + protected transient Vec _id; protected transient String[] _origNames; // only set if ModelBuilder.encodeFrameCategoricals() changes the training frame protected transient String[][] _origDomains; // only set if ModelBuilder.encodeFrameCategoricals() changes the training frame protected transient double[] _orig_projection_array; // only set if ModelBuilder.encodeFrameCategoricals() changes the training frame @@ -1102,7 +1103,8 @@ public boolean isResponseOptional() { public boolean hasWeightCol(){ return _parms._weights_column != null;} // don't look at transient Vec public boolean hasFoldCol() { return _parms._fold_column != null;} // don't look at transient Vec public boolean hasTreatmentCol() { return _parms._treatment_column != null;} - public int numSpecialCols() { return (hasOffsetCol() ? 1 : 0) + (hasWeightCol() ? 1 : 0) + (hasFoldCol() ? 1 : 0) + (hasTreatmentCol() ? 1 : 0); } + public boolean hasIdCol() { return _parms._id_column != null; } + public int numSpecialCols() { return (hasOffsetCol() ? 1 : 0) + (hasWeightCol() ? 1 : 0) + (hasFoldCol() ? 1 : 0) + (hasTreatmentCol() ? 1 : 0) + (hasIdCol() ? 1 : 0); } public boolean havePojo() { return false; } public boolean haveMojo() { return false; } @@ -1220,6 +1222,23 @@ public int separateFeatureVecs() { _treatment = null; assert(!hasTreatmentCol()); } + if(_parms._id_column!= null) { + Vec id = _train.remove(_parms._id_column); + if (id == null) + error("_id_column","Id column '" + _parms._id_column + "' not found in the training frame"); + else { + if(id.naCnt() > 0) + error("_id_column","Id column cannot have missing values."); + if(id.isCategorical()) + error("_id_column","Id column cannot be categorical."); + _id = id; + _train.add(_parms._id_column, id); + ++res; + } + } else { + _id = null; + assert(!hasIdCol()); + } if(isSupervised() && _parms._response_column != null) { _response = _train.remove(_parms._response_column); if (_response == null) { diff --git a/h2o-genmodel/src/main/java/hex/genmodel/MojoPipelineWriter.java b/h2o-genmodel/src/main/java/hex/genmodel/MojoPipelineWriter.java index 94cae181cca5..6104e4f34b33 100644 --- a/h2o-genmodel/src/main/java/hex/genmodel/MojoPipelineWriter.java +++ b/h2o-genmodel/src/main/java/hex/genmodel/MojoPipelineWriter.java @@ -148,6 +148,9 @@ public String foldColumn() { @Override public String treatmentColumn() { return null; } + @Override + public String idColumn() { return null; } + @Override public ModelCategory getModelCategory() { return _finalModel._category; diff --git a/h2o-genmodel/src/main/java/hex/genmodel/descriptor/ModelDescriptor.java b/h2o-genmodel/src/main/java/hex/genmodel/descriptor/ModelDescriptor.java index 2c4d09948cad..435e9894f7dc 100644 --- a/h2o-genmodel/src/main/java/hex/genmodel/descriptor/ModelDescriptor.java +++ b/h2o-genmodel/src/main/java/hex/genmodel/descriptor/ModelDescriptor.java @@ -53,6 +53,11 @@ public interface ModelDescriptor { */ String treatmentColumn(); + /** + * @return A {@link String} with the name of the id column used. Null of there was no id used during training. + */ + String idColumn(); + /** * Model's category. * diff --git a/h2o-genmodel/src/main/java/hex/genmodel/descriptor/ModelDescriptorBuilder.java b/h2o-genmodel/src/main/java/hex/genmodel/descriptor/ModelDescriptorBuilder.java index 0ee4d5855e06..89f7a2e8e5e0 100644 --- a/h2o-genmodel/src/main/java/hex/genmodel/descriptor/ModelDescriptorBuilder.java +++ b/h2o-genmodel/src/main/java/hex/genmodel/descriptor/ModelDescriptorBuilder.java @@ -46,6 +46,7 @@ public static class MojoModelDescriptor implements ModelDescriptor, Serializable private final String _foldColumn; private final String _weightsColumn; private final String _treatmentColumn; + private final String _idColumn; private final String[][] _domains; private final String[][] _origDomains; private final String[] _names; @@ -75,14 +76,15 @@ private MojoModelDescriptor(final MojoModel mojoModel, final String fullAlgorith _fullAlgoName = fullAlgorithmName; if (modelAttributes != null) { ColumnSpecifier weightsColSpec = (ColumnSpecifier) modelAttributes.getParameterValueByName("weights_column"); - _weightsColumn = weightsColSpec != null ? weightsColSpec.getColumnName() : null; + _weightsColumn = weightsColSpec != null ? weightsColSpec.getColumnName() : null; + // the treatment column should be ColumnSpecifier not String - this should be fixed in different PR + _treatmentColumn = (String) modelAttributes.getParameterValueByName("treatment_column"); + ColumnSpecifier idColSpec = (ColumnSpecifier) modelAttributes.getParameterValueByName("id_column"); + _idColumn = idColSpec != null ? idColSpec.getColumnName() : null; } else { _weightsColumn = null; - } - if (modelAttributes != null) { - _treatmentColumn = (String) modelAttributes.getParameterValueByName("treatment_column");; - } else { _treatmentColumn = null; + _idColumn = null; } } @@ -126,6 +128,9 @@ public String treatmentColumn() { return _treatmentColumn; } + @Override + public String idColumn() { return _idColumn;} + @Override public ModelCategory getModelCategory() { return _category; @@ -264,6 +269,9 @@ public String foldColumn() { @Override public String treatmentColumn() { return null; } + @Override + public String idColumn() { return null; } + @Override public ModelCategory getModelCategory() { return _category; diff --git a/h2o-test-support/src/main/java/water/TestUtil.java b/h2o-test-support/src/main/java/water/TestUtil.java index 348ca6186c25..b896482f0bfb 100644 --- a/h2o-test-support/src/main/java/water/TestUtil.java +++ b/h2o-test-support/src/main/java/water/TestUtil.java @@ -2050,6 +2050,28 @@ public static Vec createRandomCategoricalVec(final long len, final long randomSe return vec; } + + /** + * @param len Length of the resulting vector + * @param type Type of column. Possible options are Vec.T_NUM, Vec.T_STR, Vec.T_UUID + * @return id column vec + */ + public static Vec createIdVec(final long len, byte type) { + assert type == Vec.T_UUID || type == Vec.T_STR || type == Vec.T_NUM: "Unsupported type for id vec creation: "+type; + final Vec vec = Vec.makeZero(len, type); + for (int i = 0; i < vec.length(); i++) { + switch (type) { + case Vec.T_STR: + vec.set(i, String.valueOf(i)); + case Vec.T_UUID: + vec.set(i, UUID.fromString(String.valueOf(i))); + default: + vec.set(i, i); + } + } + return vec; + } + @SuppressWarnings("rawtypes") public static GenModel toMojo(Model model, String testName, boolean readModelMetaData) { final String filename = testName + ".zip";