Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update - xgboost to handle missing values #480

Merged
merged 10 commits into from
Dec 14, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,16 @@ public static class Split implements Node {
private final Node right;
private final int feature;
private final float threshold;
private final int leftNodeId;
private final int missingNodeId;

public Split(Node left, Node right, int feature, float threshold) {
public Split(Node left, Node right, int feature, float threshold, int leftNodeId, int missingNodeId) {
this.left = Objects.requireNonNull(left);
this.right = Objects.requireNonNull(right);
this.feature = feature;
this.threshold = threshold;
this.leftNodeId = leftNodeId;
this.missingNodeId = missingNodeId;
}

@Override
Expand All @@ -113,9 +117,18 @@ public float eval(float[] scores) {
while (!n.isLeaf()) {
assert n instanceof Split;
Split s = (Split) n;
if (s.threshold > scores[s.feature]) {
if (Float.isNaN(scores[s.feature])) {
if (s.missingNodeId == s.leftNodeId) {
n = s.left;
}
else {
n = s.right;
}
}
else if (s.threshold > scores[s.feature]) {
n = s.left;
} else {
}
else {
n = s.right;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ boolean isSplit() {
Node toNode(FeatureSet set) {
if (isSplit()) {
return new NaiveAdditiveDecisionTree.Split(children.get(0).toNode(set), children.get(1).toNode(set),
set.featureOrdinal(split), threshold);
set.featureOrdinal(split), threshold, leftNodeId, missingNodeId);
} else {
return new NaiveAdditiveDecisionTree.Leaf(leaf);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,17 @@ public void testScore() throws IOException {
assertEquals(expected, ranker.score(vector), Math.ulp(expected));
}

public void testScoreMissing() throws IOException {
NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME));
LtrRanker.FeatureVector vector = ranker.newFeatureVector(null);
vector.setFeatureScore(0, Float.NaN);
vector.setFeatureScore(1, Float.NaN);
vector.setFeatureScore(2, Float.NaN);

float expected = 17.0F*3.4F + 23.0F*2.8F;
assertEquals(expected, ranker.score(vector), Math.ulp(expected));
}

public void testSigmoidScore() throws IOException {
NaiveAdditiveDecisionTree ranker = parseTreeModel("simple_tree.txt", Normalizers.get(Normalizers.SIGMOID_NORMALIZER_NAME));
LtrRanker.FeatureVector vector = ranker.newFeatureVector(null);
Expand Down Expand Up @@ -113,7 +124,7 @@ public void testRamSize() {
100, 1000,
5, 50, counts);
long actualSize = ranker.ramBytesUsed();
long expectedApprox = counts.splits.get() * (NUM_BYTES_OBJECT_HEADER + Float.BYTES + NUM_BYTES_OBJECT_REF * 2);
long expectedApprox = counts.splits.get() * (NUM_BYTES_OBJECT_HEADER + Float.BYTES + NUM_BYTES_OBJECT_REF * 2 + Integer.BYTES * 3);
expectedApprox += counts.leaves.get() * (NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_HEADER + Float.BYTES);
expectedApprox += ranker.size() * Float.BYTES + NUM_BYTES_ARRAY_HEADER;
assertThat(actualSize, allOf(
Expand Down Expand Up @@ -207,17 +218,20 @@ NaiveAdditiveDecisionTree.Node parseTree() {
if (line.contains("- output")) {
return new NaiveAdditiveDecisionTree.Leaf(extractLastFloat(line));
} else if(line.contains("- split")) {
String featName = line.split(":")[1];
String[] values = line.split(":");
String featName = values[1];
float threshold = Float.parseFloat(values[2]);
int leftNodeId = Integer.parseInt(values[3]);
int missingNodeId = Integer.parseInt(values[5]);
int ord = set.featureOrdinal(featName);
if (ord < 0 || ord > set.size()) {
throw new IllegalArgumentException("Unknown feature " + featName);
}
float threshold = extractLastFloat(line);
NaiveAdditiveDecisionTree.Node right = parseTree();
NaiveAdditiveDecisionTree.Node left = parseTree();

return new NaiveAdditiveDecisionTree.Split(left, right,
ord, threshold);
ord, threshold, leftNodeId, missingNodeId);
} else {
throw new IllegalArgumentException("Invalid tree");
}
Expand Down Expand Up @@ -282,7 +296,7 @@ private NaiveAdditiveDecisionTree.Node newSplit(int depth) {
int feature = featureGen.get();
float thresh = thresholdGenerator.apply(feature);
statsCollector.newSplit(depth, feature, thresh);
return new NaiveAdditiveDecisionTree.Split(newNode(depth), newNode(depth), feature, thresh);
return new NaiveAdditiveDecisionTree.Split(newNode(depth), newNode(depth), feature, thresh, 1, 1);
}

private NaiveAdditiveDecisionTree.Node newLeaf(int depth) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public void testReadSimpleSplit() throws IOException {
"\"split_condition\":0.123," +
"\"yes\":1," +
"\"no\": 2," +
"\"missing\":2,"+
"\"missing\":1,"+
"\"children\": [" +
" {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," +
" {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" +
Expand All @@ -73,6 +73,8 @@ public void testReadSimpleSplit() throws IOException {
assertEquals(0.5F, tree.score(v), Math.ulp(0.5F));
v.setFeatureScore(0, 0.123F);
assertEquals(0.2F, tree.score(v), Math.ulp(0.2F));
v.setFeatureScore(0, Float.NaN);
assertEquals(0.5F, tree.score(v), Math.ulp(0.2F));
}

public void testReadSimpleSplitInObject() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
# first line after split is right
# on a split line, the last 3 integers correspond to leftNodeId, rightNodeId, missingNodeId
# data point: feature1:1, feature2:2, feature3:3
- tree:3.4
- split:feature1:2.3
- split:feature1:2.3:1:2:1
- output:3.2
# right wins
- split:feature2:2.2
- split:feature3:3.2
- split:feature2:2.2:3:4:4
- split:feature3:3.2:5:6:5
- output:11
- output:17
# left wins => output 1.2*3.4
- output:1.2
- tree:2.8
- split:feature1:0.1
- split:feature1:0.1:1:2:1
# right wins
- split:feature2:1.8
- split:feature2:1.8:3:4:4
# right wins
- split:feature3:3.2
- split:feature3:3.2:5:6:6
- output:10
# left wins => output 3.2*2.8
- output:3.2
Expand Down
Loading