Skip to content

Commit

Permalink
Update - xgboost to handle missing values (#480)
Browse files Browse the repository at this point in the history
With this change XGBoost can handle missing feature values in NaiveAdditiveDecisionTree.

* debug

* debug

* fix simple_tree tests

* fix testRamSize

* nits

* Remove `rightNodeId` as per PR comment #452 (comment)

---------

Co-authored-by: Patrick Le <[email protected]>
Co-authored-by: lechipatrick <[email protected]>
  • Loading branch information
3 people authored Dec 14, 2023
1 parent b03f570 commit 39024f6
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,16 @@ public static class Split implements Node {
private final Node right;
private final int feature;
private final float threshold;
private final int leftNodeId;
private final int missingNodeId;

public Split(Node left, Node right, int feature, float threshold) {
public Split(Node left, Node right, int feature, float threshold, int leftNodeId, int missingNodeId) {
this.left = Objects.requireNonNull(left);
this.right = Objects.requireNonNull(right);
this.feature = feature;
this.threshold = threshold;
this.leftNodeId = leftNodeId;
this.missingNodeId = missingNodeId;
}

@Override
Expand All @@ -113,9 +117,18 @@ public float eval(float[] scores) {
while (!n.isLeaf()) {
assert n instanceof Split;
Split s = (Split) n;
if (s.threshold > scores[s.feature]) {
if (Float.isNaN(scores[s.feature])) {
if (s.missingNodeId == s.leftNodeId) {
n = s.left;
}
else {
n = s.right;
}
}
else if (s.threshold > scores[s.feature]) {
n = s.left;
} else {
}
else {
n = s.right;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ boolean isSplit() {
Node toNode(FeatureSet set) {
if (isSplit()) {
return new NaiveAdditiveDecisionTree.Split(children.get(0).toNode(set), children.get(1).toNode(set),
set.featureOrdinal(split), threshold);
set.featureOrdinal(split), threshold, leftNodeId, missingNodeId);
} else {
return new NaiveAdditiveDecisionTree.Leaf(leaf);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,17 @@ public void testScore() throws IOException {
assertEquals(expected, ranker.score(vector), Math.ulp(expected));
}

public void testScoreMissing() throws IOException {
NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME));
LtrRanker.FeatureVector vector = ranker.newFeatureVector(null);
vector.setFeatureScore(0, Float.NaN);
vector.setFeatureScore(1, Float.NaN);
vector.setFeatureScore(2, Float.NaN);

float expected = 17.0F*3.4F + 23.0F*2.8F;
assertEquals(expected, ranker.score(vector), Math.ulp(expected));
}

public void testSigmoidScore() throws IOException {
NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.SIGMOID_NORMALIZER_NAME));
LtrRanker.FeatureVector vector = ranker.newFeatureVector(null);
Expand Down Expand Up @@ -113,7 +124,7 @@ public void testRamSize() {
100, 1000,
5, 50, counts);
long actualSize = ranker.ramBytesUsed();
long expectedApprox = counts.splits.get() * (NUM_BYTES_OBJECT_HEADER + Float.BYTES + NUM_BYTES_OBJECT_REF * 2);
long expectedApprox = counts.splits.get() * (NUM_BYTES_OBJECT_HEADER + Float.BYTES + NUM_BYTES_OBJECT_REF * 2 + Integer.BYTES * 3);
expectedApprox += counts.leaves.get() * (NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_HEADER + Float.BYTES);
expectedApprox += ranker.size() * Float.BYTES + NUM_BYTES_ARRAY_HEADER;
assertThat(actualSize, allOf(
Expand Down Expand Up @@ -207,17 +218,20 @@ NaiveAdditiveDecisionTree.Node parseTree() {
if (line.contains("- output")) {
return new NaiveAdditiveDecisionTree.Leaf(extractLastFloat(line));
} else if(line.contains("- split")) {
String featName = line.split(":")[1];
String[] values = line.split(":");
String featName = values[1];
float threshold = Float.parseFloat(values[2]);
int leftNodeId = Integer.parseInt(values[3]);
int missingNodeId = Integer.parseInt(values[5]);
int ord = set.featureOrdinal(featName);
if (ord < 0 || ord > set.size()) {
throw new IllegalArgumentException("Unknown feature " + featName);
}
float threshold = extractLastFloat(line);
NaiveAdditiveDecisionTree.Node right = parseTree();
NaiveAdditiveDecisionTree.Node left = parseTree();

return new NaiveAdditiveDecisionTree.Split(left, right,
ord, threshold);
ord, threshold, leftNodeId, missingNodeId);
} else {
throw new IllegalArgumentException("Invalid tree");
}
Expand Down Expand Up @@ -282,7 +296,7 @@ private NaiveAdditiveDecisionTree.Node newSplit(int depth) {
int feature = featureGen.get();
float thresh = thresholdGenerator.apply(feature);
statsCollector.newSplit(depth, feature, thresh);
return new NaiveAdditiveDecisionTree.Split(newNode(depth), newNode(depth), feature, thresh);
return new NaiveAdditiveDecisionTree.Split(newNode(depth), newNode(depth), feature, thresh, 1, 1);
}

private NaiveAdditiveDecisionTree.Node newLeaf(int depth) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public void testReadSimpleSplit() throws IOException {
"\"split_condition\":0.123," +
"\"yes\":1," +
"\"no\": 2," +
"\"missing\":2,"+
"\"missing\":1,"+
"\"children\": [" +
" {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," +
" {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" +
Expand All @@ -73,6 +73,8 @@ public void testReadSimpleSplit() throws IOException {
assertEquals(0.5F, tree.score(v), Math.ulp(0.5F));
v.setFeatureScore(0, 0.123F);
assertEquals(0.2F, tree.score(v), Math.ulp(0.2F));
v.setFeatureScore(0, Float.NaN);
assertEquals(0.5F, tree.score(v), Math.ulp(0.2F));
}

public void testReadSimpleSplitInObject() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
# first line after split is right
# on a split line, the last 3 integers correspond to leftNodeId, rightNodeId, missingNodeId
# data point: feature1:1, feature2:2, feature3:3
- tree:3.4
- split:feature1:2.3
- split:feature1:2.3:1:2:1
- output:3.2
# right wins
- split:feature2:2.2
- split:feature3:3.2
- split:feature2:2.2:3:4:4
- split:feature3:3.2:5:6:5
- output:11
- output:17
# left wins => output 1.2*3.4
- output:1.2
- tree:2.8
- split:feature1:0.1
- split:feature1:0.1:1:2:1
# right wins
- split:feature2:1.8
- split:feature2:1.8:3:4:4
# right wins
- split:feature3:3.2
- split:feature3:3.2:5:6:6
- output:10
# left wins => output 3.2*2.8
- output:3.2
Expand Down

0 comments on commit 39024f6

Please sign in to comment.