-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
* Implement KNN model * Implement KNN distance calculation * Implement KNN distance task and KNN driver * Implement scoring * Add JUnit tests * Add docs
- Loading branch information
Showing
22 changed files
with
1,265 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
package hex.knn; | ||
|
||
public class CosineDistance extends KNNDistance { | ||
|
||
public CosineDistance(){ | ||
super.valuesLength = 3; | ||
} | ||
|
||
@Override | ||
public double nom(double v1, double v2) { | ||
return v1*v2; | ||
} | ||
|
||
@Override | ||
public void calculateValues(double v1, double v2) { | ||
this.values[0] += nom(v1, v2); | ||
this.values[1] += nom(v1, v1); | ||
this.values[2] += nom(v2, v2); | ||
} | ||
|
||
@Override | ||
public double result() { | ||
return 1 - (this.values[0] / (Math.sqrt(this.values[1]) * Math.sqrt(this.values[2]))); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
package hex.knn; | ||
|
||
public class EuclideanDistance extends KNNDistance { | ||
|
||
@Override | ||
public double nom(double v1, double v2) { | ||
return (v1-v2)*(v1-v2); | ||
} | ||
|
||
@Override | ||
public void calculateValues(double v1, double v2) { | ||
this.values[0] += nom(v1, v2); | ||
} | ||
|
||
@Override | ||
public double result() { | ||
return Math.sqrt(this.values[0]); | ||
} | ||
|
||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
package hex.knn; | ||
|
||
import hex.*; | ||
import water.DKV; | ||
import water.Key; | ||
import water.Scope; | ||
import water.fvec.Chunk; | ||
import water.fvec.Frame; | ||
|
||
public class KNN extends ModelBuilder<KNNModel,KNNModel.KNNParameters,KNNModel.KNNOutput> { | ||
|
||
public KNN(KNNModel.KNNParameters parms) { | ||
super(parms); | ||
init(false); | ||
} | ||
|
||
public KNN(boolean startup_once) { | ||
super(new KNNModel.KNNParameters(), startup_once); | ||
} | ||
|
||
@Override | ||
protected KNNDriver trainModelImpl() { | ||
return new KNNDriver(); | ||
} | ||
|
||
@Override | ||
public ModelCategory[] can_build() { | ||
return new ModelCategory[]{ModelCategory.Binomial, ModelCategory.Multinomial}; | ||
} | ||
|
||
@Override | ||
public boolean isSupervised() { | ||
return true; | ||
} | ||
|
||
@Override public void init(boolean expensive) { | ||
super.init(expensive); | ||
if( null == _parms._id_column) { | ||
error("_id_column", "ID column parameter not set."); | ||
} | ||
if( null == _parms._distance) { | ||
error("_distance", "Distance parameter not set."); | ||
} | ||
} | ||
|
||
class KNNDriver extends Driver { | ||
|
||
@Override | ||
public void computeImpl() { | ||
KNNModel model = null; | ||
Frame result = new Frame(Key.make("KNN_distances")); | ||
Frame tmpResult = null; | ||
try { | ||
init(true); // Initialize parameters | ||
if (error_count() > 0) { | ||
throw new IllegalArgumentException("Found validation errors: " + validationErrors()); | ||
} | ||
model = new KNNModel(dest(), _parms, new KNNModel.KNNOutput(KNN.this)); | ||
model.delete_and_lock(_job); | ||
Frame train = _parms.train(); | ||
String idColumn = _parms._id_column; | ||
int idColumnIndex = train.find(idColumn); | ||
byte idType = train.vec(idColumnIndex).get_type(); | ||
String responseColumn = _parms._response_column; | ||
int responseColumnIndex = train.find(responseColumn); | ||
int nChunks = train.anyVec().nChunks(); | ||
int nCols = train.numCols(); | ||
// split data into chunks to calculate distances in parallel task | ||
for (int i = 0; i < nChunks; i++) { | ||
Chunk[] query = new Chunk[nCols]; | ||
for (int j = 0; j < nCols; j++) { | ||
query[j] = train.vec(j).chunkForChunkIdx(i).deepCopy(); | ||
} | ||
KNNDistanceTask task = new KNNDistanceTask(_parms._k, query, _parms._distance, idColumnIndex, idColumn, idType, responseColumnIndex, responseColumn); | ||
tmpResult = task.doAll(train).outputFrame(); | ||
// merge result from a chunk | ||
result = result.add(tmpResult); | ||
} | ||
DKV.put(result._key, result); | ||
model._output.setDistancesKey(result._key); | ||
Scope.untrack(result); | ||
|
||
model.update(_job); | ||
|
||
model.score(_parms.train()).delete(); | ||
model._output._training_metrics = ModelMetrics.getFromDKV(model, _parms.train()); | ||
} catch (Exception e) { | ||
throw new RuntimeException(e); | ||
} finally { | ||
if (model != null) { | ||
model.unlock(_job); | ||
} | ||
if (tmpResult != null) { | ||
tmpResult.remove(); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
package hex.knn; | ||
|
||
import water.Iced; | ||
|
||
/** | ||
* Template for various distance calculation. | ||
*/ | ||
public abstract class KNNDistance extends Iced<KNNDistance> { | ||
|
||
// Lenght of values for calculation of the distance | ||
// For example for calculation euclidean and manhattan distance we need only one value, | ||
// for calculation cosine distance we need tree values. | ||
public int valuesLength = 1; | ||
|
||
// Array to cumulate partial calculations of distance | ||
public double[] values; | ||
|
||
/** | ||
* Method to calculate the distance between two points from two vectors. | ||
* @param v1 value of an item in the first vector | ||
* @param v2 value of an item in the second vector | ||
*/ | ||
public abstract double nom(double v1, double v2); | ||
|
||
/** | ||
* Initialize values array to store partial calculation of distance. | ||
*/ | ||
public void initializeValues(){ | ||
this.values = new double[valuesLength]; | ||
} | ||
|
||
/** | ||
* Method to cumulate partial calculations of distance between two vectors and save it to values array. | ||
* @param v1 value of an item in the first vector | ||
* @param v2 value of an item in the second vector | ||
*/ | ||
public abstract void calculateValues(double v1, double v2); | ||
|
||
/** | ||
* Calculate the result from cumulated values. | ||
* @return Final distance calculation. | ||
*/ | ||
public abstract double result(); | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
package hex.knn; | ||
|
||
import water.Key; | ||
import water.MRTask; | ||
import water.fvec.Chunk; | ||
import water.fvec.Frame; | ||
import water.fvec.Vec; | ||
|
||
import java.util.Iterator; | ||
|
||
public class KNNDistanceTask extends MRTask<KNNDistanceTask> { | ||
|
||
public int _k; | ||
public Chunk[] _queryData; | ||
public KNNDistance _distance; | ||
public KNNHashMap<String, TopNTreeMap<KNNKey, Object>> _topNNeighboursMaps; | ||
public String _idColumn; | ||
public String _responseColumn; | ||
public int _idIndex; | ||
public int _responseIndex; | ||
public byte _idColumnType; | ||
|
||
/** | ||
* Calculate distances dor a particular chunk | ||
*/ | ||
public KNNDistanceTask(int k, Chunk[] query, KNNDistance distance, int idIndex, String idColumn, byte idType, int responseIndex, String responseColumn){ | ||
this._k = k; | ||
this._queryData = query; | ||
this._distance = distance; | ||
this._topNNeighboursMaps = new KNNHashMap<>(); | ||
this._idColumn = idColumn; | ||
this._responseColumn = responseColumn; | ||
this._idIndex = idIndex; | ||
this._responseIndex = responseIndex; | ||
this._idColumnType = idType; | ||
} | ||
|
||
@Override | ||
public void map(Chunk[] cs) { | ||
int queryColNum = _queryData.length; | ||
long queryRowNum = _queryData[0]._len; | ||
int inputColNum = cs.length; | ||
int inputRowNum = cs[0]._len; | ||
assert queryColNum == inputColNum: "Query data frame and input data frame should have the same columns number."; | ||
for (int i = 0; i < queryRowNum; i++) { // go over all query data rows | ||
TopNTreeMap<KNNKey, Object> distancesMap = new TopNTreeMap<>(_k); | ||
String queryDataId = _idColumnType == Vec.T_STR ? _queryData[_idIndex].stringAt(i) : String.valueOf(_queryData[_idIndex].at8(i)); | ||
for (int j = 0; j < inputRowNum; j++) { // go over all input data rows | ||
String inputDataId = _idColumnType == Vec.T_STR ? cs[_idIndex].stringAt(j) : String.valueOf(cs[_idIndex].at8(j)); | ||
long inputDataCategory = cs[_responseIndex].at8(j); | ||
// if(queryDataId.equals(inputDataId)) continue; // the same id included or not? | ||
_distance.initializeValues(); | ||
for (int k = 0; k < inputColNum; k++) { // go over all columns | ||
if (k == _idIndex || k == _responseIndex) continue; | ||
double queryColData = _queryData[k].atd(i); | ||
double inputColData = cs[k].atd(j); | ||
_distance.calculateValues(queryColData, inputColData); | ||
} | ||
double dist = _distance.result(); | ||
|
||
distancesMap.put(new KNNKey(inputDataId, dist), inputDataCategory); | ||
} | ||
_topNNeighboursMaps.put(queryDataId, distancesMap); | ||
} | ||
} | ||
|
||
@Override | ||
public void reduce(KNNDistanceTask mrt) { | ||
KNNHashMap<String, TopNTreeMap<KNNKey, Object>> inputMap = mrt._topNNeighboursMaps; | ||
this._topNNeighboursMaps.reduce(inputMap); | ||
} | ||
|
||
/** | ||
* Get data from maps to Frame | ||
* @param vecs | ||
* @return filled array of vecs with calculated data | ||
*/ | ||
public Vec[] fillVecs(Vec[] vecs){ | ||
for (int i = 0; i < vecs[0].length(); i++) { | ||
String id = _idColumnType == Vec.T_STR ? vecs[0].stringAt(i) : String.valueOf(vecs[0].at8(i)); | ||
TopNTreeMap<KNNKey, Object> topNMap = _topNNeighboursMaps.get(id); | ||
Iterator<KNNKey> distances = topNMap.keySet().stream().iterator(); | ||
Iterator<Object> responses = topNMap.values().iterator(); | ||
for (int j = 1; j < _k+1; j++) { | ||
KNNKey key = distances.next(); | ||
String keyString = key.key.toString(); | ||
vecs[j].set(i, key.value); | ||
if(_idColumnType == Vec.T_STR){ | ||
vecs[_k + j].set(i, keyString); | ||
} else { | ||
vecs[_k + j].set(i, Integer.parseInt(keyString)); | ||
} | ||
vecs[2 * _k + j].set(i, (long) responses.next()); | ||
} | ||
} | ||
return vecs; | ||
} | ||
|
||
/** | ||
* Generate output frame with calculated distances. | ||
* @return | ||
*/ | ||
public Frame outputFrame() { | ||
int newVecsSize = _k*3+1; | ||
Vec[] vecs = new Vec[newVecsSize]; | ||
String[] names = new String[newVecsSize]; | ||
boolean isStringId = _idColumnType == Vec.T_STR; | ||
Vec id = Vec.makeCon(0, _queryData[0].len(), false); | ||
for (int i = 0; i < _queryData[_idIndex].len(); i++) { | ||
if(isStringId) { | ||
id.set(i, _queryData[_idIndex].stringAt(i)); | ||
} else { | ||
id.set(i, _queryData[_idIndex].atd(i)); | ||
} | ||
} | ||
vecs[0] = id; | ||
names[0] = _idColumn; | ||
for (int i = 1; i < _k+1; i++) { | ||
// names of columns | ||
names[i] = "dist_"+i; // this could be customized | ||
names[_k+i] = _idColumn+"_"+i; // this could be customized | ||
names[2*_k+i] = _responseColumn+"_"+i; // this could be customized | ||
vecs[i] = id.makeZero(); | ||
vecs[i] = vecs[i].toNumericVec(); | ||
vecs[_k+i] = id.makeZero(); | ||
if (isStringId) vecs[_k+i].toStringVec(); | ||
vecs[2*_k+i] = id.makeZero(); | ||
vecs[2*_k+i] = vecs[2*_k+i].toNumericVec(); | ||
} | ||
vecs = fillVecs(vecs); | ||
return new Frame(Key.make("KNN_distances_tmp"), names, vecs); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
package hex.knn; | ||
|
||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
public class KNNHashMap<K, V extends TopNTreeMap<KNNKey, Object>> extends HashMap<K, V> { | ||
|
||
public void reduce(KNNHashMap<K, V> map){ | ||
for (Map.Entry<K, V> entry: map.entrySet()) { | ||
K key = entry.getKey(); | ||
V valueMap = entry.getValue(); | ||
if (this.containsKey(key)){ | ||
V currentKeyMap = this.get(key); | ||
currentKeyMap.putAll(valueMap); | ||
this.put(key, currentKeyMap); | ||
} | ||
} | ||
} | ||
} |
Oops, something went wrong.