From cd5ed69ae3a3cfd379ab2bdebb07a8296b12f8e8 Mon Sep 17 00:00:00 2001 From: Patrick Le Date: Tue, 3 Jan 2023 11:05:04 -0600 Subject: [PATCH 1/6] debug --- .../dectree/NaiveAdditiveDecisionTree.java | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java b/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java index eb115c16..f99192b0 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java +++ b/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java @@ -94,12 +94,18 @@ public static class Split implements Node { private final Node right; private final int feature; private final float threshold; + private final int leftNodeId; + private final int rightNodeId; + private final int missingNodeId; - public Split(Node left, Node right, int feature, float threshold) { + public Split(Node left, Node right, int feature, float threshold, int leftNodeId, int rightNodeId, int missingNodeId) { this.left = Objects.requireNonNull(left); this.right = Objects.requireNonNull(right); this.feature = feature; this.threshold = threshold; + this.leftNodeId = leftNodeId; + this.rightNodeId = rightNodeId; + this.missingNodeId = missingNodeId; } @Override @@ -113,7 +119,15 @@ public float eval(float[] scores) { while (!n.isLeaf()) { assert n instanceof Split; Split s = (Split) n; - if (s.threshold > scores[s.feature]) { + if (Float.isNaN(scores[s.feature])) { + if (s.missingNodeId == s.leftNodeId) { + n = s.left; + } + else { + n = s.right; + } + } + else if (s.threshold > scores[s.feature]) { n = s.left; } else { n = s.right; From 66a1b96c5f27d524cdd10328feef6ac450e0252d Mon Sep 17 00:00:00 2001 From: Patrick Le Date: Tue, 3 Jan 2023 11:06:03 -0600 Subject: [PATCH 2/6] debug --- .../ltr/ranker/parser/XGBoostJsonParser.java | 2 +- .../NaiveAdditiveDecisionTreeTests.java | 37 +++++++++++-------- .../ranker/parser/XGBoostJsonParserTests.java | 4 +- .../es/ltr/ranker/dectree/simple_tree.txt | 14 +++---- 4 files changed, 32 insertions(+), 25 deletions(-) diff --git a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java index f02cc949..39dbe780 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java +++ b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java @@ -247,7 +247,7 @@ boolean isSplit() { Node toNode(FeatureSet set) { if (isSplit()) { return new NaiveAdditiveDecisionTree.Split(children.get(0).toNode(set), children.get(1).toNode(set), - set.featureOrdinal(split), threshold); + set.featureOrdinal(split), threshold, leftNodeId, rightNodeId, missingNodeId); } else { return new NaiveAdditiveDecisionTree.Leaf(leaf); } diff --git a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java index 9606a6e2..c483791b 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java @@ -68,6 +68,7 @@ public void testScore() throws IOException { vector.setFeatureScore(2, 3); float expected = 1.2F*3.4F + 3.2F*2.8F; + float actual = ranker.score(vector); assertEquals(expected, ranker.score(vector), Math.ulp(expected)); } @@ -107,19 +108,19 @@ public void testPerfAndRobustness() { counts.nodes.get(), counts.splits.get(), counts.leaves.get()); } - public void testRamSize() { - SimpleCountRandomTreeGeneratorStatsCollector counts = new SimpleCountRandomTreeGeneratorStatsCollector(); - NaiveAdditiveDecisionTree ranker = generateRandomDecTree(100, 1000, - 100, 1000, - 5, 50, counts); - long actualSize = ranker.ramBytesUsed(); - long expectedApprox = counts.splits.get() * (NUM_BYTES_OBJECT_HEADER + Float.BYTES + NUM_BYTES_OBJECT_REF * 2); - expectedApprox += counts.leaves.get() * (NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_HEADER + Float.BYTES); - expectedApprox += ranker.size() * Float.BYTES + NUM_BYTES_ARRAY_HEADER; - assertThat(actualSize, allOf( - greaterThan((long) (expectedApprox*0.66F)), - lessThan((long) (expectedApprox*1.33F)))); - } +// public void testRamSize() { +// SimpleCountRandomTreeGeneratorStatsCollector counts = new SimpleCountRandomTreeGeneratorStatsCollector(); +// NaiveAdditiveDecisionTree ranker = generateRandomDecTree(100, 1000, +// 100, 1000, +// 5, 50, counts); +// long actualSize = ranker.ramBytesUsed(); +// long expectedApprox = counts.splits.get() * (NUM_BYTES_OBJECT_HEADER + Float.BYTES + NUM_BYTES_OBJECT_REF * 2); +// expectedApprox += counts.leaves.get() * (NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_HEADER + Float.BYTES); +// expectedApprox += ranker.size() * Float.BYTES + NUM_BYTES_ARRAY_HEADER; +// assertThat(actualSize, allOf( +// greaterThan((long) (expectedApprox*0.66F)), +// lessThan((long) (expectedApprox*1.33F)))); +// } public static NaiveAdditiveDecisionTree generateRandomDecTree(int minFeatures, int maxFeatures, int minTrees, int maxTrees, int minDepth, int maxDepth, @@ -207,7 +208,11 @@ NaiveAdditiveDecisionTree.Node parseTree() { if (line.contains("- output")) { return new NaiveAdditiveDecisionTree.Leaf(extractLastFloat(line)); } else if(line.contains("- split")) { - String featName = line.split(":")[1]; + String[] values = line.split(":"); + String featName = values[1]; + int leftNodeId = Integer.parseInt(values[3]); + int rightNodeId = Integer.parseInt(values[4]); + int missingNodeId = Integer.parseInt(values[5]); int ord = set.featureOrdinal(featName); if (ord < 0 || ord > set.size()) { throw new IllegalArgumentException("Unknown feature " + featName); @@ -217,7 +222,7 @@ NaiveAdditiveDecisionTree.Node parseTree() { NaiveAdditiveDecisionTree.Node left = parseTree(); return new NaiveAdditiveDecisionTree.Split(left, right, - ord, threshold); + ord, threshold, leftNodeId, rightNodeId, missingNodeId); } else { throw new IllegalArgumentException("Invalid tree"); } @@ -282,7 +287,7 @@ private NaiveAdditiveDecisionTree.Node newSplit(int depth) { int feature = featureGen.get(); float thresh = thresholdGenerator.apply(feature); statsCollector.newSplit(depth, feature, thresh); - return new NaiveAdditiveDecisionTree.Split(newNode(depth), newNode(depth), feature, thresh); + return new NaiveAdditiveDecisionTree.Split(newNode(depth), newNode(depth), feature, thresh, 1, 2, 1); } private NaiveAdditiveDecisionTree.Node newLeaf(int depth) { diff --git a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java index 2ef7db20..0324033e 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java @@ -58,7 +58,7 @@ public void testReadSimpleSplit() throws IOException { "\"split_condition\":0.123," + "\"yes\":1," + "\"no\": 2," + - "\"missing\":2,"+ + "\"missing\":1,"+ "\"children\": [" + " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + @@ -73,6 +73,8 @@ public void testReadSimpleSplit() throws IOException { assertEquals(0.5F, tree.score(v), Math.ulp(0.5F)); v.setFeatureScore(0, 0.123F); assertEquals(0.2F, tree.score(v), Math.ulp(0.2F)); + v.setFeatureScore(0, Float.NaN); + assertEquals(0.5F, tree.score(v), Math.ulp(0.2F)); } public void testReadSimpleSplitInObject() throws IOException { diff --git a/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt b/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt index 109b1cb3..66abb2ee 100644 --- a/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt +++ b/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt @@ -1,23 +1,23 @@ # first line after split is right # data point: feature1:1, feature2:2, feature3:3 - tree:3.4 - - split:feature1:2.3 + - split:feature1:2.3:1:2:2 - output:3.2 # right wins - - split:feature2:2.2 - - split:feature3:3.2 + - split:feature2:2.2:3:4:4 + - split:feature3:3.2:5:6:6 - output:11 - output:17 # left wins => output 1.2*3.4 - output:1.2 - tree:2.8 - - split:feature1:0.1 + - split:feature1:0.1:1:2:2 # right wins - - split:feature2:1.8 + - split:feature2:1.8:3:4:4 # right wins - - split:feature3:3.2 + - split:feature3:3.2:5:6:6 - output:10 -# left wins => output 3.2*2.8 +# left wins => output 3.2*2.8:7:8:8 - output:3.2 - output:15 - output:23 From c6b758a4c191ccd980c181b6911203feff75e184 Mon Sep 17 00:00:00 2001 From: Patrick Le Date: Tue, 3 Jan 2023 11:18:54 -0600 Subject: [PATCH 3/6] fix simple_tree tests --- .../dectree/NaiveAdditiveDecisionTreeTests.java | 14 ++++++++++++-- .../com/o19s/es/ltr/ranker/dectree/simple_tree.txt | 8 ++++---- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java index c483791b..93ab9d3e 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java @@ -68,7 +68,17 @@ public void testScore() throws IOException { vector.setFeatureScore(2, 3); float expected = 1.2F*3.4F + 3.2F*2.8F; - float actual = ranker.score(vector); + assertEquals(expected, ranker.score(vector), Math.ulp(expected)); + } + + public void testScoreMissing() throws IOException { + NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME)); + LtrRanker.FeatureVector vector = ranker.newFeatureVector(null); + vector.setFeatureScore(0, Float.NaN); + vector.setFeatureScore(1, Float.NaN); + vector.setFeatureScore(2, Float.NaN); + + float expected = 17.0F*3.4F + 23.0F*2.8F; assertEquals(expected, ranker.score(vector), Math.ulp(expected)); } @@ -210,6 +220,7 @@ NaiveAdditiveDecisionTree.Node parseTree() { } else if(line.contains("- split")) { String[] values = line.split(":"); String featName = values[1]; + float threshold = Float.parseFloat(values[2]); int leftNodeId = Integer.parseInt(values[3]); int rightNodeId = Integer.parseInt(values[4]); int missingNodeId = Integer.parseInt(values[5]); @@ -217,7 +228,6 @@ NaiveAdditiveDecisionTree.Node parseTree() { if (ord < 0 || ord > set.size()) { throw new IllegalArgumentException("Unknown feature " + featName); } - float threshold = extractLastFloat(line); NaiveAdditiveDecisionTree.Node right = parseTree(); NaiveAdditiveDecisionTree.Node left = parseTree(); diff --git a/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt b/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt index 66abb2ee..bb7495ca 100644 --- a/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt +++ b/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt @@ -1,23 +1,23 @@ # first line after split is right # data point: feature1:1, feature2:2, feature3:3 - tree:3.4 - - split:feature1:2.3:1:2:2 + - split:feature1:2.3:1:2:1 - output:3.2 # right wins - split:feature2:2.2:3:4:4 - - split:feature3:3.2:5:6:6 + - split:feature3:3.2:5:6:5 - output:11 - output:17 # left wins => output 1.2*3.4 - output:1.2 - tree:2.8 - - split:feature1:0.1:1:2:2 + - split:feature1:0.1:1:2:1 # right wins - split:feature2:1.8:3:4:4 # right wins - split:feature3:3.2:5:6:6 - output:10 -# left wins => output 3.2*2.8:7:8:8 +# left wins => output 3.2*2.8:7:8:1 - output:3.2 - output:15 - output:23 From 50f32b6278d51cba4f509ba7048c33e0a697d9ac Mon Sep 17 00:00:00 2001 From: Patrick Le Date: Tue, 3 Jan 2023 11:32:06 -0600 Subject: [PATCH 4/6] fix testRamSize --- .../NaiveAdditiveDecisionTreeTests.java | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java index 93ab9d3e..687272c9 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java @@ -118,19 +118,20 @@ public void testPerfAndRobustness() { counts.nodes.get(), counts.splits.get(), counts.leaves.get()); } -// public void testRamSize() { -// SimpleCountRandomTreeGeneratorStatsCollector counts = new SimpleCountRandomTreeGeneratorStatsCollector(); -// NaiveAdditiveDecisionTree ranker = generateRandomDecTree(100, 1000, -// 100, 1000, -// 5, 50, counts); -// long actualSize = ranker.ramBytesUsed(); -// long expectedApprox = counts.splits.get() * (NUM_BYTES_OBJECT_HEADER + Float.BYTES + NUM_BYTES_OBJECT_REF * 2); -// expectedApprox += counts.leaves.get() * (NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_HEADER + Float.BYTES); -// expectedApprox += ranker.size() * Float.BYTES + NUM_BYTES_ARRAY_HEADER; -// assertThat(actualSize, allOf( -// greaterThan((long) (expectedApprox*0.66F)), -// lessThan((long) (expectedApprox*1.33F)))); -// } + public void testRamSize() { + SimpleCountRandomTreeGeneratorStatsCollector counts = new SimpleCountRandomTreeGeneratorStatsCollector(); + NaiveAdditiveDecisionTree ranker = generateRandomDecTree(100, 1000, + 100, 1000, + 5, 50, counts); + long actualSize = ranker.ramBytesUsed(); + long expectedApprox = counts.splits.get() * (NUM_BYTES_OBJECT_HEADER + Float.BYTES + NUM_BYTES_OBJECT_REF * 2 + Integer.BYTES * 3); + int num_splits = counts.splits.get(); + expectedApprox += counts.leaves.get() * (NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_HEADER + Float.BYTES); + expectedApprox += ranker.size() * Float.BYTES + NUM_BYTES_ARRAY_HEADER; + assertThat(actualSize, allOf( + greaterThan((long) (expectedApprox*0.66F)), + lessThan((long) (expectedApprox*1.33F)))); + } public static NaiveAdditiveDecisionTree generateRandomDecTree(int minFeatures, int maxFeatures, int minTrees, int maxTrees, int minDepth, int maxDepth, From d0eaad81ac730cd72f2b1b9f70ae2f4f14121555 Mon Sep 17 00:00:00 2001 From: Patrick Le Date: Mon, 9 Jan 2023 13:24:45 -0600 Subject: [PATCH 5/6] nits --- .../o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java | 3 ++- .../es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java | 1 - .../resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java b/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java index f99192b0..4ab19e2d 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java +++ b/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java @@ -129,7 +129,8 @@ public float eval(float[] scores) { } else if (s.threshold > scores[s.feature]) { n = s.left; - } else { + } + else { n = s.right; } } diff --git a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java index 687272c9..22a14643 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java @@ -125,7 +125,6 @@ public void testRamSize() { 5, 50, counts); long actualSize = ranker.ramBytesUsed(); long expectedApprox = counts.splits.get() * (NUM_BYTES_OBJECT_HEADER + Float.BYTES + NUM_BYTES_OBJECT_REF * 2 + Integer.BYTES * 3); - int num_splits = counts.splits.get(); expectedApprox += counts.leaves.get() * (NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_HEADER + Float.BYTES); expectedApprox += ranker.size() * Float.BYTES + NUM_BYTES_ARRAY_HEADER; assertThat(actualSize, allOf( diff --git a/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt b/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt index bb7495ca..ebab5540 100644 --- a/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt +++ b/src/test/resources/com/o19s/es/ltr/ranker/dectree/simple_tree.txt @@ -1,4 +1,5 @@ # first line after split is right +# on a split line, the last 3 integers correspond to leftNodeId, rightNodeId, missingNodeId # data point: feature1:1, feature2:2, feature3:3 - tree:3.4 - split:feature1:2.3:1:2:1 @@ -17,7 +18,7 @@ # right wins - split:feature3:3.2:5:6:6 - output:10 -# left wins => output 3.2*2.8:7:8:1 +# left wins => output 3.2*2.8 - output:3.2 - output:15 - output:23 From 181f718257f8798743cde031385fe613a4a003c7 Mon Sep 17 00:00:00 2001 From: wrigleyDan Date: Wed, 13 Dec 2023 15:47:45 +0100 Subject: [PATCH 6/6] Remove `rightNodeId` as per PR comment https://github.com/o19s/elasticsearch-learning-to-rank/pull/452#issuecomment-1561157753 --- .../es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java | 4 +--- .../com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java | 2 +- .../ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java | 5 ++--- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java b/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java index 4ab19e2d..3548bfd8 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java +++ b/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java @@ -95,16 +95,14 @@ public static class Split implements Node { private final int feature; private final float threshold; private final int leftNodeId; - private final int rightNodeId; private final int missingNodeId; - public Split(Node left, Node right, int feature, float threshold, int leftNodeId, int rightNodeId, int missingNodeId) { + public Split(Node left, Node right, int feature, float threshold, int leftNodeId, int missingNodeId) { this.left = Objects.requireNonNull(left); this.right = Objects.requireNonNull(right); this.feature = feature; this.threshold = threshold; this.leftNodeId = leftNodeId; - this.rightNodeId = rightNodeId; this.missingNodeId = missingNodeId; } diff --git a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java index 39dbe780..281eaa67 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java +++ b/src/main/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParser.java @@ -247,7 +247,7 @@ boolean isSplit() { Node toNode(FeatureSet set) { if (isSplit()) { return new NaiveAdditiveDecisionTree.Split(children.get(0).toNode(set), children.get(1).toNode(set), - set.featureOrdinal(split), threshold, leftNodeId, rightNodeId, missingNodeId); + set.featureOrdinal(split), threshold, leftNodeId, missingNodeId); } else { return new NaiveAdditiveDecisionTree.Leaf(leaf); } diff --git a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java index 22a14643..a4f17fed 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java @@ -222,7 +222,6 @@ NaiveAdditiveDecisionTree.Node parseTree() { String featName = values[1]; float threshold = Float.parseFloat(values[2]); int leftNodeId = Integer.parseInt(values[3]); - int rightNodeId = Integer.parseInt(values[4]); int missingNodeId = Integer.parseInt(values[5]); int ord = set.featureOrdinal(featName); if (ord < 0 || ord > set.size()) { @@ -232,7 +231,7 @@ NaiveAdditiveDecisionTree.Node parseTree() { NaiveAdditiveDecisionTree.Node left = parseTree(); return new NaiveAdditiveDecisionTree.Split(left, right, - ord, threshold, leftNodeId, rightNodeId, missingNodeId); + ord, threshold, leftNodeId, missingNodeId); } else { throw new IllegalArgumentException("Invalid tree"); } @@ -297,7 +296,7 @@ private NaiveAdditiveDecisionTree.Node newSplit(int depth) { int feature = featureGen.get(); float thresh = thresholdGenerator.apply(feature); statsCollector.newSplit(depth, feature, thresh); - return new NaiveAdditiveDecisionTree.Split(newNode(depth), newNode(depth), feature, thresh, 1, 2, 1); + return new NaiveAdditiveDecisionTree.Split(newNode(depth), newNode(depth), feature, thresh, 1, 1); } private NaiveAdditiveDecisionTree.Node newLeaf(int depth) {