diff --git a/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java b/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java
index eb115c16..3548bfd8 100644
--- a/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java
+++ b/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java
@@ -94,12 +94,25 @@ public static class Split implements Node {
         private final Node right;
         private final int feature;
         private final float threshold;
+        private final int leftNodeId;
+        private final int missingNodeId;
 
-        public Split(Node left, Node right, int feature, float threshold) {
+        /**
+         * @param left          child taken when {@code threshold > score}
+         * @param right         child taken otherwise
+         * @param feature       index of the tested feature in the score array
+         * @param threshold     split threshold
+         * @param leftNodeId    identifier of the left child
+         * @param missingNodeId identifier of the child to follow when the feature score is NaN;
+         *                      equal to {@code leftNodeId} to go left, any other value goes right
+         */
+        public Split(Node left, Node right, int feature, float threshold, int leftNodeId, int missingNodeId) {
             this.left = Objects.requireNonNull(left);
             this.right = Objects.requireNonNull(right);
             this.feature = feature;
             this.threshold = threshold;
+            this.leftNodeId = leftNodeId;
+            this.missingNodeId = missingNodeId;
         }
 
         @Override
@@ -113,9 +126,13 @@ public float eval(float[] scores) {
             while (!n.isLeaf()) {
                 assert n instanceof Split;
                 Split s = (Split) n;
-                if (s.threshold > scores[s.feature]) {
+                float score = scores[s.feature];
+                if (Float.isNaN(score)) {
+                    // NaN marks a missing feature: take the branch designated by missingNodeId
+                    n = s.missingNodeId == s.leftNodeId ? s.left : s.right;
+                } else if (s.threshold > score) {
                     n = s.left;
                 } else {
                     n = s.right;
                 }
             }
diff --git a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java
index 7b53fa7f..41fc8197 100644
--- a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java
+++ b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java
@@ -249,7 +249,7 @@ boolean isSplit() {
         Node toNode(FeatureSet set) {
             if (isSplit()) {
                 return new NaiveAdditiveDecisionTree.Split(children.get(0).toNode(set), children.get(1).toNode(set),
-                        set.featureOrdinal(split), threshold);
+                        set.featureOrdinal(split), threshold, leftNodeId, missingNodeId);
             } else {
                 return new NaiveAdditiveDecisionTree.Leaf(leaf);
             }
diff --git a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java
index 9606a6e2..a4f17fed 100644
--- a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java
+++ b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java
@@ -71,6 +71,17 @@ public void testScore() throws IOException {
         assertEquals(expected, ranker.score(vector), Math.ulp(expected));
     }
 
+    public void testScoreMissing() throws IOException {
+        NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME));
+        LtrRanker.FeatureVector vector = ranker.newFeatureVector(null);
+        vector.setFeatureScore(0, Float.NaN);
+        vector.setFeatureScore(1, Float.NaN);
+        vector.setFeatureScore(2, Float.NaN);
+
+        float expected = 17.0F*3.4F + 23.0F*2.8F;
+        assertEquals(expected, ranker.score(vector), Math.ulp(expected));
+    }
+
     public void testSigmoidScore() throws IOException {
         NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.SIGMOID_NORMALIZER_NAME));
         LtrRanker.FeatureVector vector = ranker.newFeatureVector(null);
@@ -113,7 +124,7 @@ public void testRamSize() {
                 100, 1000, 5, 50, counts);
 
         long actualSize = ranker.ramBytesUsed();
-        long expectedApprox = counts.splits.get() * (NUM_BYTES_OBJECT_HEADER + Float.BYTES + NUM_BYTES_OBJECT_REF * 2);
+        long expectedApprox = counts.splits.get() * (NUM_BYTES_OBJECT_HEADER + Float.BYTES + NUM_BYTES_OBJECT_REF * 2 + Integer.BYTES * 3);
         expectedApprox += counts.leaves.get() * (NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_HEADER + Float.BYTES);
         expectedApprox += ranker.size() * Float.BYTES + NUM_BYTES_ARRAY_HEADER;
         assertThat(actualSize, allOf(
@@ -207,17 +218,20 @@ NaiveAdditiveDecisionTree.Node parseTree() {
             if (line.contains("- output")) {
                 return new NaiveAdditiveDecisionTree.Leaf(extractLastFloat(line));
             } else if(line.contains("- split")) {
-                String featName = line.split(":")[1];
+                String[] values = line.split(":");
+                String featName = values[1];
+                float threshold = Float.parseFloat(values[2]);
+                int leftNodeId = Integer.parseInt(values[3]);
+                int missingNodeId = Integer.parseInt(values[5]);
                 int ord = set.featureOrdinal(featName);
                 if (ord < 0 || ord > set.size()) {
                     throw new IllegalArgumentException("Unknown feature " + featName);
                 }
-                float threshold = extractLastFloat(line);
                 NaiveAdditiveDecisionTree.Node right = parseTree();
                 NaiveAdditiveDecisionTree.Node left = parseTree();
 
                 return new NaiveAdditiveDecisionTree.Split(left, right,
-                        ord, threshold);
+                        ord, threshold, leftNodeId, missingNodeId);
             } else {
                 throw new IllegalArgumentException("Invalid tree");
             }
@@ -282,7 +296,7 @@ private NaiveAdditiveDecisionTree.Node newSplit(int depth) {
             int feature = featureGen.get();
             float thresh = thresholdGenerator.apply(feature);
             statsCollector.newSplit(depth, feature, thresh);
-            return new NaiveAdditiveDecisionTree.Split(newNode(depth), newNode(depth), feature, thresh);
+            return new NaiveAdditiveDecisionTree.Split(newNode(depth), newNode(depth), feature, thresh, 1, 1);
         }
 
         private NaiveAdditiveDecisionTree.Node newLeaf(int depth) {
diff --git a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java
index 2ef7db20..0324033e 100644
--- a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java
+++ b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java
@@ -58,7 +58,7 @@ public void testReadSimpleSplit() throws IOException {
                 "\"split_condition\":0.123," +
                 "\"yes\":1," +
                 "\"no\": 2," +
-                "\"missing\":2,"+
+                "\"missing\":1,"+
                 "\"children\": [" +
                 " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," +
                 " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" +
@@ -73,6 +73,9 @@ public void testReadSimpleSplit() throws IOException {
         assertEquals(0.5F, tree.score(v), Math.ulp(0.5F));
         v.setFeatureScore(0, 0.123F);
         assertEquals(0.2F, tree.score(v), Math.ulp(0.2F));
+        // "missing" points at node 1, whose leaf value is 0.5
+        v.setFeatureScore(0, Float.NaN);
+        assertEquals(0.5F, tree.score(v), Math.ulp(0.5F));
     }
 
     public void testReadSimpleSplitInObject() throws IOException {
diff --git a/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt b/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt
index 109b1cb3..ebab5540 100644
--- a/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt
+++ b/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt
@@ -1,21 +1,22 @@
 # first line after split is right
+# on a split line, the last 3 integers correspond to leftNodeId, rightNodeId, missingNodeId
 # data point: feature1:1, feature2:2, feature3:3
 - tree:3.4
-  - split:feature1:2.3
+  - split:feature1:2.3:1:2:1
     - output:3.2
     # right wins
-    - split:feature2:2.2
-      - split:feature3:3.2
+    - split:feature2:2.2:3:4:4
+      - split:feature3:3.2:5:6:5
         - output:11
         - output:17
       # left wins => output 1.2*3.4
       - output:1.2
 - tree:2.8
-  - split:feature1:0.1
+  - split:feature1:0.1:1:2:1
     # right wins
-    - split:feature2:1.8
+    - split:feature2:1.8:3:4:4
       # right wins
-      - split:feature3:3.2
+      - split:feature3:3.2:5:6:6
         - output:10
         # left wins => output 3.2*2.8
         - output:3.2