Skip to content

Commit

Permalink
GH-16319 Implement KNN backend [nocheck] (#16405)
Browse files Browse the repository at this point in the history
* Implement KNN model
* Implement KNN distance calculation
* Implement KNN distance task and KNN driver
* Implement scoring
* Add JUnit tests
* Add docs
  • Loading branch information
maurever authored Dec 18, 2024
1 parent cca2428 commit 8ff06da
Show file tree
Hide file tree
Showing 22 changed files with 1,265 additions and 9 deletions.
1 change: 1 addition & 0 deletions h2o-algos/src/main/java/hex/api/RegisterAlgos.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public void registerEndPoints(RestApiContext context) {
new hex.tree.dt .DT (true),
new hex.hglm .HGLM (true),
new hex.adaboost. AdaBoost (true)
//new hex.knn .KNN (true) will be implement in different PR
};

// "Word2Vec", "Example", "Grep"
Expand Down
4 changes: 4 additions & 0 deletions h2o-algos/src/main/java/hex/generic/GenericModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ private void predict(EasyPredictModelWrapper wrapper, AdaptFrameParameters adapt
final String weightsColumn = adaptFrameParameters.getWeightsColumn();
final String responseColumn = adaptFrameParameters.getResponseColumn();
final String treatmentColumn = adaptFrameParameters.getTreatmentColumn();
final String idColumn = adaptFrameParameters.getIdColumn();
final boolean isClassifier = wrapper.getModel().isClassifier();
final boolean isUplift = treatmentColumn != null;
final float[] yact;
Expand Down Expand Up @@ -355,6 +356,9 @@ public String getResponseColumn() {
}
@Override
public String getTreatmentColumn() {return descriptor != null ? descriptor.treatmentColumn() : null;}
@Override
public String getIdColumn() { return descriptor != null ? descriptor.idColumn() : null;}

@Override
public double missingColumnsType() {
return Double.NaN;
Expand Down
25 changes: 25 additions & 0 deletions h2o-algos/src/main/java/hex/knn/CosineDistance.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package hex.knn;

public class CosineDistance extends KNNDistance {

public CosineDistance(){
super.valuesLength = 3;
}

@Override
public double nom(double v1, double v2) {
return v1*v2;
}

@Override
public void calculateValues(double v1, double v2) {
this.values[0] += nom(v1, v2);
this.values[1] += nom(v1, v1);
this.values[2] += nom(v2, v2);
}

@Override
public double result() {
return 1 - (this.values[0] / (Math.sqrt(this.values[1]) * Math.sqrt(this.values[2])));
}
}
21 changes: 21 additions & 0 deletions h2o-algos/src/main/java/hex/knn/EuclideanDistance.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package hex.knn;

public class EuclideanDistance extends KNNDistance {

@Override
public double nom(double v1, double v2) {
return (v1-v2)*(v1-v2);
}

@Override
public void calculateValues(double v1, double v2) {
this.values[0] += nom(v1, v2);
}

@Override
public double result() {
return Math.sqrt(this.values[0]);
}


}
101 changes: 101 additions & 0 deletions h2o-algos/src/main/java/hex/knn/KNN.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package hex.knn;

import hex.*;
import water.DKV;
import water.Key;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Frame;

public class KNN extends ModelBuilder<KNNModel,KNNModel.KNNParameters,KNNModel.KNNOutput> {

public KNN(KNNModel.KNNParameters parms) {
super(parms);
init(false);
}

public KNN(boolean startup_once) {
super(new KNNModel.KNNParameters(), startup_once);
}

@Override
protected KNNDriver trainModelImpl() {
return new KNNDriver();
}

@Override
public ModelCategory[] can_build() {
return new ModelCategory[]{ModelCategory.Binomial, ModelCategory.Multinomial};
}

@Override
public boolean isSupervised() {
return true;
}

@Override public void init(boolean expensive) {
super.init(expensive);
if( null == _parms._id_column) {
error("_id_column", "ID column parameter not set.");
}
if( null == _parms._distance) {
error("_distance", "Distance parameter not set.");
}
}

class KNNDriver extends Driver {

@Override
public void computeImpl() {
KNNModel model = null;
Frame result = new Frame(Key.make("KNN_distances"));
Frame tmpResult = null;
try {
init(true); // Initialize parameters
if (error_count() > 0) {
throw new IllegalArgumentException("Found validation errors: " + validationErrors());
}
model = new KNNModel(dest(), _parms, new KNNModel.KNNOutput(KNN.this));
model.delete_and_lock(_job);
Frame train = _parms.train();
String idColumn = _parms._id_column;
int idColumnIndex = train.find(idColumn);
byte idType = train.vec(idColumnIndex).get_type();
String responseColumn = _parms._response_column;
int responseColumnIndex = train.find(responseColumn);
int nChunks = train.anyVec().nChunks();
int nCols = train.numCols();
// split data into chunks to calculate distances in parallel task
for (int i = 0; i < nChunks; i++) {
Chunk[] query = new Chunk[nCols];
for (int j = 0; j < nCols; j++) {
query[j] = train.vec(j).chunkForChunkIdx(i).deepCopy();
}
KNNDistanceTask task = new KNNDistanceTask(_parms._k, query, _parms._distance, idColumnIndex, idColumn, idType, responseColumnIndex, responseColumn);
tmpResult = task.doAll(train).outputFrame();
// merge result from a chunk
result = result.add(tmpResult);
}
DKV.put(result._key, result);
model._output.setDistancesKey(result._key);
Scope.untrack(result);

model.update(_job);

model.score(_parms.train()).delete();
model._output._training_metrics = ModelMetrics.getFromDKV(model, _parms.train());
} catch (Exception e) {
throw new RuntimeException(e);
} finally {
if (model != null) {
model.unlock(_job);
}
if (tmpResult != null) {
tmpResult.remove();
}
}
}
}
}


45 changes: 45 additions & 0 deletions h2o-algos/src/main/java/hex/knn/KNNDistance.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package hex.knn;

import water.Iced;

/**
* Template for various distance calculation.
*/
public abstract class KNNDistance extends Iced<KNNDistance> {

// Lenght of values for calculation of the distance
// For example for calculation euclidean and manhattan distance we need only one value,
// for calculation cosine distance we need tree values.
public int valuesLength = 1;

// Array to cumulate partial calculations of distance
public double[] values;

/**
* Method to calculate the distance between two points from two vectors.
* @param v1 value of an item in the first vector
* @param v2 value of an item in the second vector
*/
public abstract double nom(double v1, double v2);

/**
* Initialize values array to store partial calculation of distance.
*/
public void initializeValues(){
this.values = new double[valuesLength];
}

/**
* Method to cumulate partial calculations of distance between two vectors and save it to values array.
* @param v1 value of an item in the first vector
* @param v2 value of an item in the second vector
*/
public abstract void calculateValues(double v1, double v2);

/**
* Calculate the result from cumulated values.
* @return Final distance calculation.
*/
public abstract double result();

}
133 changes: 133 additions & 0 deletions h2o-algos/src/main/java/hex/knn/KNNDistanceTask.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package hex.knn;

import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;

import java.util.Iterator;

public class KNNDistanceTask extends MRTask<KNNDistanceTask> {

public int _k;
public Chunk[] _queryData;
public KNNDistance _distance;
public KNNHashMap<String, TopNTreeMap<KNNKey, Object>> _topNNeighboursMaps;
public String _idColumn;
public String _responseColumn;
public int _idIndex;
public int _responseIndex;
public byte _idColumnType;

/**
* Calculate distances dor a particular chunk
*/
public KNNDistanceTask(int k, Chunk[] query, KNNDistance distance, int idIndex, String idColumn, byte idType, int responseIndex, String responseColumn){
this._k = k;
this._queryData = query;
this._distance = distance;
this._topNNeighboursMaps = new KNNHashMap<>();
this._idColumn = idColumn;
this._responseColumn = responseColumn;
this._idIndex = idIndex;
this._responseIndex = responseIndex;
this._idColumnType = idType;
}

@Override
public void map(Chunk[] cs) {
int queryColNum = _queryData.length;
long queryRowNum = _queryData[0]._len;
int inputColNum = cs.length;
int inputRowNum = cs[0]._len;
assert queryColNum == inputColNum: "Query data frame and input data frame should have the same columns number.";
for (int i = 0; i < queryRowNum; i++) { // go over all query data rows
TopNTreeMap<KNNKey, Object> distancesMap = new TopNTreeMap<>(_k);
String queryDataId = _idColumnType == Vec.T_STR ? _queryData[_idIndex].stringAt(i) : String.valueOf(_queryData[_idIndex].at8(i));
for (int j = 0; j < inputRowNum; j++) { // go over all input data rows
String inputDataId = _idColumnType == Vec.T_STR ? cs[_idIndex].stringAt(j) : String.valueOf(cs[_idIndex].at8(j));
long inputDataCategory = cs[_responseIndex].at8(j);
// if(queryDataId.equals(inputDataId)) continue; // the same id included or not?
_distance.initializeValues();
for (int k = 0; k < inputColNum; k++) { // go over all columns
if (k == _idIndex || k == _responseIndex) continue;
double queryColData = _queryData[k].atd(i);
double inputColData = cs[k].atd(j);
_distance.calculateValues(queryColData, inputColData);
}
double dist = _distance.result();

distancesMap.put(new KNNKey(inputDataId, dist), inputDataCategory);
}
_topNNeighboursMaps.put(queryDataId, distancesMap);
}
}

@Override
public void reduce(KNNDistanceTask mrt) {
KNNHashMap<String, TopNTreeMap<KNNKey, Object>> inputMap = mrt._topNNeighboursMaps;
this._topNNeighboursMaps.reduce(inputMap);
}

/**
* Get data from maps to Frame
* @param vecs
* @return filled array of vecs with calculated data
*/
public Vec[] fillVecs(Vec[] vecs){
for (int i = 0; i < vecs[0].length(); i++) {
String id = _idColumnType == Vec.T_STR ? vecs[0].stringAt(i) : String.valueOf(vecs[0].at8(i));
TopNTreeMap<KNNKey, Object> topNMap = _topNNeighboursMaps.get(id);
Iterator<KNNKey> distances = topNMap.keySet().stream().iterator();
Iterator<Object> responses = topNMap.values().iterator();
for (int j = 1; j < _k+1; j++) {
KNNKey key = distances.next();
String keyString = key.key.toString();
vecs[j].set(i, key.value);
if(_idColumnType == Vec.T_STR){
vecs[_k + j].set(i, keyString);
} else {
vecs[_k + j].set(i, Integer.parseInt(keyString));
}
vecs[2 * _k + j].set(i, (long) responses.next());
}
}
return vecs;
}

/**
* Generate output frame with calculated distances.
* @return
*/
public Frame outputFrame() {
int newVecsSize = _k*3+1;
Vec[] vecs = new Vec[newVecsSize];
String[] names = new String[newVecsSize];
boolean isStringId = _idColumnType == Vec.T_STR;
Vec id = Vec.makeCon(0, _queryData[0].len(), false);
for (int i = 0; i < _queryData[_idIndex].len(); i++) {
if(isStringId) {
id.set(i, _queryData[_idIndex].stringAt(i));
} else {
id.set(i, _queryData[_idIndex].atd(i));
}
}
vecs[0] = id;
names[0] = _idColumn;
for (int i = 1; i < _k+1; i++) {
// names of columns
names[i] = "dist_"+i; // this could be customized
names[_k+i] = _idColumn+"_"+i; // this could be customized
names[2*_k+i] = _responseColumn+"_"+i; // this could be customized
vecs[i] = id.makeZero();
vecs[i] = vecs[i].toNumericVec();
vecs[_k+i] = id.makeZero();
if (isStringId) vecs[_k+i].toStringVec();
vecs[2*_k+i] = id.makeZero();
vecs[2*_k+i] = vecs[2*_k+i].toNumericVec();
}
vecs = fillVecs(vecs);
return new Frame(Key.make("KNN_distances_tmp"), names, vecs);
}
}
19 changes: 19 additions & 0 deletions h2o-algos/src/main/java/hex/knn/KNNHashMap.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package hex.knn;

import java.util.HashMap;
import java.util.Map;

public class KNNHashMap<K, V extends TopNTreeMap<KNNKey, Object>> extends HashMap<K, V> {

public void reduce(KNNHashMap<K, V> map){
for (Map.Entry<K, V> entry: map.entrySet()) {
K key = entry.getKey();
V valueMap = entry.getValue();
if (this.containsKey(key)){
V currentKeyMap = this.get(key);
currentKeyMap.putAll(valueMap);
this.put(key, currentKeyMap);
}
}
}
}
Loading

0 comments on commit 8ff06da

Please sign in to comment.