From 4e39eb83c22ebfd894e4e41bd424d69f4239ab4c Mon Sep 17 00:00:00 2001
From: Platon Bibik
Date: Mon, 4 Nov 2024 09:15:59 +0100
Subject: [PATCH] Fix irreproducible inference due to the imprecise
 floating-point values in xgboost's get_dump (#500)

Add a new parser, `model/xgboost+json+raw`, for XGBoost's raw JSON
serialization format (the output of `save_model`/`save_raw` with JSON
format). Unlike the truncated decimal strings produced by `get_dump`,
the raw format stores exact floating-point values, so inference on the
uploaded model is reproducible.

Fixes #497
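A quick usage sketch for reviewers (not part of the diff): once registered,
the parser turns a raw `save_model` JSON dump into a scorable tree ensemble
like any other LTR model parser. This is a minimal sketch mirroring the tests
in this patch; it assumes the existing test helpers `StoredFeatureSet` and
`randomFeature` from LtrTestUtils, and the "..." in the model string is a
placeholder for a complete raw-JSON dump like the ones in the tests below.

    // Sketch only; assumes imports of FeatureSet, StoredFeatureSet,
    // NaiveAdditiveDecisionTree, LtrRanker.FeatureVector and the static
    // helpers singletonList/randomFeature, as in the tests in this patch.
    FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1")));
    String rawModel = "{\"learner\": { ... }, \"version\": [2, 1, 0]}"; // full save_model dump
    NaiveAdditiveDecisionTree tree = new XGBoostRawJsonParser().parse(set, rawModel);
    LtrRanker.FeatureVector v = tree.newFeatureVector(null);
    v.setFeatureScore(0, 2.0f);
    // The raw dump keeps exact float values, so this score matches the booster.
    float score = tree.score(v);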
e); + } + + NaiveAdditiveDecisionTree.Node[] trees = modelDefinition.getLearner().getTrees(set); + float[] weights = new float[trees.length]; + Arrays.fill(weights, 1F); + return new NaiveAdditiveDecisionTree(trees, weights, set.size(), modelDefinition.getLearner().getObjective().getNormalizer()); + } + + private static class XGBoostDefinition { + private static final ObjectParser PARSER; + + static { + PARSER = new ObjectParser<>("xgboost_definition", true, XGBoostRawJsonParser.XGBoostDefinition::new); + PARSER.declareObject( + XGBoostRawJsonParser.XGBoostDefinition::setLearner, + XGBoostRawJsonParser.XGBoostLearner::parse, + new ParseField("learner") + ); + PARSER.declareIntArray(XGBoostRawJsonParser.XGBoostDefinition::setVersion, new ParseField("version")); + } + + public static XGBoostRawJsonParser.XGBoostDefinition parse(XContentParser parser, FeatureSet set) throws IOException { + XGBoostRawJsonParser.XGBoostDefinition definition; + XContentParser.Token startToken = parser.nextToken(); + + if (startToken == XContentParser.Token.START_OBJECT) { + try { + definition = PARSER.apply(parser, set); + } catch (XContentParseException e) { + throw new ParsingException(parser.getTokenLocation(), "Unable to parse XGBoost object", e); + } + if (definition.learner == null) { + throw new ParsingException(parser.getTokenLocation(), "XGBoost model missing required field [learner]"); + } + List unknownFeatures = new ArrayList<>(); + for (String modelFeatureName : definition.learner.featureNames) { + if (!set.hasFeature(modelFeatureName)) { + unknownFeatures.add(modelFeatureName); + } + } + if (!unknownFeatures.isEmpty()) { + throw new ParsingException(parser.getTokenLocation(), "Unknown features in model: [" + + String.join(", ", unknownFeatures) + "]"); + } + if (definition.learner.featureNames.size() != definition.learner.featureTypes.size()) { + throw new ParsingException(parser.getTokenLocation(), + "Feature names list and feature types list must have the same length"); + } + Optional firstUnsupportedType = definition.learner.featureTypes.stream() + .filter(typeStr -> !typeStr.equals("float")) + .findFirst(); + if (firstUnsupportedType.isPresent()) { + throw new ParsingException(parser.getTokenLocation(), + "The LTR plugin only supports float feature types " + + "because Elasticsearch scores are always float32. 
" + + "Found feature type [" + firstUnsupportedType.get() + "] in model" + ); + } + } else { + throw new ParsingException(parser.getTokenLocation(), "Expected [START_OBJECT] but got [" + startToken + "]"); + } + return definition; + } + + private XGBoostLearner learner; + + public XGBoostLearner getLearner() { + return learner; + } + + public void setLearner(XGBoostLearner learner) { + this.learner = learner; + } + + private List version; + + public List getVersion() { + return version; + } + + public void setVersion(List version) { + this.version = version; + } + } + + static class XGBoostLearner { + + private List featureNames; + private List featureTypes; + private XGBoostGradientBooster gradientBooster; + private XGBoostObjective objective; + + private static final ObjectParser PARSER; + + static { + PARSER = new ObjectParser<>("xgboost_learner", true, XGBoostRawJsonParser.XGBoostLearner::new); + PARSER.declareObject( + XGBoostRawJsonParser.XGBoostLearner::setObjective, + XGBoostRawJsonParser.XGBoostObjective::parse, + new ParseField("objective") + ); + PARSER.declareObject( + XGBoostRawJsonParser.XGBoostLearner::setGradientBooster, + XGBoostRawJsonParser.XGBoostGradientBooster::parse, + new ParseField("gradient_booster") + ); + PARSER.declareStringArray(XGBoostRawJsonParser.XGBoostLearner::setFeatureNames, new ParseField("feature_names")); + PARSER.declareStringArray(XGBoostRawJsonParser.XGBoostLearner::setFeatureTypes, new ParseField("feature_types")); + } + + private void setFeatureTypes(List featureTypes) { + this.featureTypes = featureTypes; + } + + private void setFeatureNames(List featureNames) { + this.featureNames = featureNames; + } + + public static XGBoostRawJsonParser.XGBoostLearner parse(XContentParser parser, FeatureSet set) throws IOException { + return PARSER.apply(parser, set); + } + + XGBoostLearner() { + } + + NaiveAdditiveDecisionTree.Node[] getTrees(FeatureSet set) { + return this.getGradientBooster().getModel().getTrees(); + } + + public XGBoostObjective getObjective() { + return objective; + } + + public void setObjective(XGBoostObjective objective) { + this.objective = objective; + } + + public XGBoostGradientBooster getGradientBooster() { + return gradientBooster; + } + + public void setGradientBooster(XGBoostGradientBooster gradientBooster) { + this.gradientBooster = gradientBooster; + } + } + + static class XGBoostGradientBooster { + private XGBoostModel model; + + private static final ObjectParser PARSER; + + static { + PARSER = new ObjectParser<>("xgboost_gradient_booster", true, XGBoostRawJsonParser.XGBoostGradientBooster::new); + PARSER.declareObject( + XGBoostRawJsonParser.XGBoostGradientBooster::setModel, + XGBoostRawJsonParser.XGBoostModel::parse, + new ParseField("model") + ); + } + + static XGBoostRawJsonParser.XGBoostGradientBooster parse(XContentParser parser, FeatureSet set) throws IOException { + return PARSER.apply(parser, set); + } + + XGBoostGradientBooster() { + } + + public XGBoostModel getModel() { + return model; + } + + public void setModel(XGBoostModel model) { + this.model = model; + } + } + + static class XGBoostModel { + private NaiveAdditiveDecisionTree.Node[] trees; + private List treeInfo; + + private static final ObjectParser PARSER; + + static { + PARSER = new ObjectParser<>("xgboost_model", true, XGBoostRawJsonParser.XGBoostModel::new); + PARSER.declareObjectArray( + XGBoostRawJsonParser.XGBoostModel::setTrees, + XGBoostRawJsonParser.XGBoostTree::parse, + new ParseField("trees") + ); + 
PARSER.declareIntArray(XGBoostRawJsonParser.XGBoostModel::setTreeInfo, new ParseField("tree_info")); + } + + public List getTreeInfo() { + return treeInfo; + } + + public void setTreeInfo(List treeInfo) { + this.treeInfo = treeInfo; + } + + public static XGBoostRawJsonParser.XGBoostModel parse(XContentParser parser, FeatureSet set) throws IOException { + try { + return PARSER.apply(parser, set); + } catch (IllegalArgumentException e) { + throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e); + } + } + + XGBoostModel() { + } + + public NaiveAdditiveDecisionTree.Node[] getTrees() { + return trees; + } + + public void setTrees(List parsedTrees) { + NaiveAdditiveDecisionTree.Node[] trees = new NaiveAdditiveDecisionTree.Node[parsedTrees.size()]; + ListIterator it = parsedTrees.listIterator(); + while (it.hasNext()) { + trees[it.nextIndex()] = it.next().getRootNode(); + } + this.trees = trees; + } + } + + static class XGBoostObjective { + private Normalizer normalizer; + + private static final ObjectParser PARSER; + + static { + PARSER = new ObjectParser<>("xgboost_objective", true, XGBoostRawJsonParser.XGBoostObjective::new); + PARSER.declareString(XGBoostRawJsonParser.XGBoostObjective::setName, new ParseField("name")); + } + + public static XGBoostRawJsonParser.XGBoostObjective parse(XContentParser parser, FeatureSet set) throws IOException { + return PARSER.apply(parser, set); + } + + XGBoostObjective() { + } + + public void setName(String name) { + switch (name) { + case "binary:logitraw", "rank:ndcg", "rank:map", "rank:pairwise", "reg:linear" -> + this.normalizer = Normalizers.get(Normalizers.NOOP_NORMALIZER_NAME); + case "binary:logistic", "reg:logistic" -> + this.normalizer = Normalizers.get(Normalizers.SIGMOID_NORMALIZER_NAME); + default -> + throw new IllegalArgumentException("Objective [" + name + "] is not a valid XGBoost objective"); + } + } + + Normalizer getNormalizer() { + return this.normalizer; + } + } + + static class XGBoostTree { + private Integer treeId; + private List leftChildren; + private List rightChildren; + private List parents; + private List splitConditions; + private List splitIndices; + private List defaultLeft; + private List splitTypes; + private List baseWeights; + + private NaiveAdditiveDecisionTree.Node rootNode; + + private static final ObjectParser PARSER; + + static { + PARSER = new ObjectParser<>("xgboost_tree", true, XGBoostRawJsonParser.XGBoostTree::new); + PARSER.declareInt(XGBoostRawJsonParser.XGBoostTree::setTreeId, new ParseField("id")); + PARSER.declareIntArray(XGBoostRawJsonParser.XGBoostTree::setLeftChildren, new ParseField("left_children")); + PARSER.declareIntArray(XGBoostRawJsonParser.XGBoostTree::setRightChildren, new ParseField("right_children")); + PARSER.declareIntArray(XGBoostRawJsonParser.XGBoostTree::setParents, new ParseField("parents")); + PARSER.declareFloatArray(XGBoostRawJsonParser.XGBoostTree::setSplitConditions, new ParseField("split_conditions")); + PARSER.declareIntArray(XGBoostRawJsonParser.XGBoostTree::setSplitIndices, new ParseField("split_indices")); + PARSER.declareIntArray(XGBoostRawJsonParser.XGBoostTree::setDefaultLeft, new ParseField("default_left")); + PARSER.declareIntArray(XGBoostRawJsonParser.XGBoostTree::setSplitTypes, new ParseField("split_type")); + PARSER.declareFloatArray(XGBoostRawJsonParser.XGBoostTree::setBaseWeights, new ParseField("base_weights")); + } + + public static XGBoostRawJsonParser.XGBoostTree parse(XContentParser parser, FeatureSet set) throws IOException { + 
XGBoostRawJsonParser.XGBoostTree tree = PARSER.apply(parser, set); + tree.rootNode = tree.asLibTree(0); + return tree; + } + + public Integer getTreeId() { + return treeId; + } + + public void setTreeId(Integer treeId) { + this.treeId = treeId; + } + + public List getLeftChildren() { + return leftChildren; + } + + public void setLeftChildren(List leftChildren) { + this.leftChildren = leftChildren; + } + + public List getRightChildren() { + return rightChildren; + } + + public void setRightChildren(List rightChildren) { + this.rightChildren = rightChildren; + } + + public List getParents() { + return parents; + } + + public void setParents(List parents) { + this.parents = parents; + } + + public List getSplitConditions() { + return splitConditions; + } + + public void setSplitConditions(List splitConditions) { + this.splitConditions = splitConditions; + } + + public List getSplitIndices() { + return splitIndices; + } + + public void setSplitIndices(List splitIndices) { + this.splitIndices = splitIndices; + } + + public List getDefaultLeft() { + return defaultLeft; + } + + public void setDefaultLeft(List defaultLeft) { + this.defaultLeft = defaultLeft; + } + + public List getSplitTypes() { + return splitTypes; + } + + public void setSplitTypes(List splitTypes) { + this.splitTypes = splitTypes; + } + + private boolean isSplit(Integer nodeId) { + return leftChildren.get(nodeId) != -1 && rightChildren.get(nodeId) != -1; + } + + private NaiveAdditiveDecisionTree.Node asLibTree(Integer nodeId) { + if (nodeId >= leftChildren.size()) { + throw new IllegalArgumentException("Child node reference ID [" + nodeId + "] is invalid"); + } + if (nodeId >= rightChildren.size()) { + throw new IllegalArgumentException("Child node reference ID [" + nodeId + "] is invalid"); + } + + if (isSplit(nodeId)) { + return new NaiveAdditiveDecisionTree.Split(asLibTree(leftChildren.get(nodeId)), asLibTree(rightChildren.get(nodeId)), + splitIndices.get(nodeId), splitConditions.get(nodeId), splitIndices.get(nodeId), MISSING_NODE_ID); + } else { + return new NaiveAdditiveDecisionTree.Leaf(baseWeights.get(nodeId)); + } + } + + public List getBaseWeights() { + return baseWeights; + } + + public void setBaseWeights(List baseWeights) { + this.baseWeights = baseWeights; + } + + public NaiveAdditiveDecisionTree.Node getRootNode() { + return rootNode; + } + } +} diff --git a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java new file mode 100644 index 00000000..857bee46 --- /dev/null +++ b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java @@ -0,0 +1,578 @@ +package com.o19s.es.ltr.ranker.parser; + +import com.o19s.es.ltr.feature.FeatureSet; +import com.o19s.es.ltr.feature.store.StoredFeatureSet; +import com.o19s.es.ltr.ranker.LtrRanker.FeatureVector; +import com.o19s.es.ltr.ranker.SparseFeatureVector; +import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTree; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.elasticsearch.common.ParsingException; +import org.hamcrest.CoreMatchers; +import org.junit.Rule; +import org.junit.rules.ExpectedException; + +import java.io.IOException; +import java.util.List; + +import static com.o19s.es.ltr.LtrTestUtils.randomFeature; +import static java.util.Collections.singletonList; + +public class XGBoostRawJsonParserTests extends LuceneTestCase { + private final XGBoostRawJsonParser parser = new XGBoostRawJsonParser(); + + @Rule + public ExpectedException 
diff --git a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java
new file mode 100644
index 00000000..857bee46
--- /dev/null
+++ b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostRawJsonParserTests.java
@@ -0,0 +1,578 @@
+package com.o19s.es.ltr.ranker.parser;
+
+import com.o19s.es.ltr.feature.FeatureSet;
+import com.o19s.es.ltr.feature.store.StoredFeatureSet;
+import com.o19s.es.ltr.ranker.LtrRanker.FeatureVector;
+import com.o19s.es.ltr.ranker.SparseFeatureVector;
+import com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTree;
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.elasticsearch.common.ParsingException;
+import org.hamcrest.CoreMatchers;
+import org.junit.Rule;
+import org.junit.rules.ExpectedException;
+
+import java.io.IOException;
+import java.util.List;
+
+import static com.o19s.es.ltr.LtrTestUtils.randomFeature;
+import static java.util.Collections.singletonList;
+
+public class XGBoostRawJsonParserTests extends LuceneTestCase {
+    private final XGBoostRawJsonParser parser = new XGBoostRawJsonParser();
+
+    @Rule
+    public ExpectedException expectedException = ExpectedException.none();
+
+    public void testSimpleSplit() throws IOException {
+        String model =
+                "{" +
+                "  \"learner\":{" +
+                "    \"attributes\":{}," +
+                "    \"feature_names\":[\"feat1\"]," +
+                "    \"feature_types\":[\"float\"]," +
+                "    \"gradient_booster\":{" +
+                "      \"model\":{" +
+                "        \"gbtree_model_param\":{" +
+                "          \"num_parallel_tree\":\"1\"," +
+                "          \"num_trees\":\"1\"}," +
+                "        \"iteration_indptr\":[0,1]," +
+                "        \"tree_info\":[0]," +
+                "        \"trees\":[{" +
+                "          \"base_weights\":[1E0, 10E0, 0E0]," +
+                "          \"categories\":[]," +
+                "          \"categories_nodes\":[]," +
+                "          \"categories_segments\":[]," +
+                "          \"categories_sizes\":[]," +
+                "          \"default_left\":[0, 0, 0]," +
+                "          \"id\":0," +
+                "          \"left_children\":[2, -1, -1]," +
+                "          \"loss_changes\":[0E0, 0E0, 0E0]," +
+                "          \"parents\":[2147483647, 0, 0]," +
+                "          \"right_children\":[1, -1, -1]," +
+                "          \"split_conditions\":[3E0, -1E0, -1E0]," +
+                "          \"split_indices\":[0, 0, 0]," +
+                "          \"split_type\":[0, 0, 0]," +
+                "          \"sum_hessian\":[1E0, 1E0, 1E0]," +
+                "          \"tree_param\":{" +
+                "            \"num_deleted\":\"0\"," +
+                "            \"num_feature\":\"1\"," +
+                "            \"num_nodes\":\"3\"," +
+                "            \"size_leaf_vector\":\"1\"}" +
+                "        }" +
+                "      ]}," +
+                "      \"name\":\"gbtree\"" +
+                "    }," +
+                "    \"learner_model_param\":{" +
+                "      \"base_score\":\"5E-1\"," +
+                "      \"boost_from_average\":\"1\"," +
+                "      \"num_class\":\"0\"," +
+                "      \"num_feature\":\"1\"," +
+                "      \"num_target\":\"1\"" +
+                "    }," +
+                "    \"objective\":{" +
+                "      \"name\":\"reg:linear\"," +
+                "      \"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" +
+                "    }" +
+                "  }," +
+                "  \"version\":[2,1,0]" +
+                "}";
+
+        FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1")));
+        NaiveAdditiveDecisionTree tree = parser.parse(set, model);
+        FeatureVector featureVector = new SparseFeatureVector(1);
+        featureVector.setFeatureScore(0, 2);
+        assertEquals(0.0, tree.score(featureVector), Math.ulp(0.1F));
+
+        featureVector.setFeatureScore(0, 4);
+        assertEquals(10.0, tree.score(featureVector), Math.ulp(0.1F));
+    }
+
+    public void testReadWithLogisticObjective() throws IOException {
+        String model =
+                "{" +
+                "  \"learner\":{" +
+                "    \"attributes\":{}," +
+                "    \"feature_names\":[\"feat1\"]," +
+                "    \"feature_types\":[\"float\"]," +
+                "    \"gradient_booster\":{" +
+                "      \"model\":{" +
+                "        \"gbtree_model_param\":{" +
+                "          \"num_parallel_tree\":\"1\"," +
+                "          \"num_trees\":\"1\"}," +
+                "        \"iteration_indptr\":[0,1]," +
+                "        \"tree_info\":[0]," +
+                "        \"trees\":[{" +
+                "          \"base_weights\":[1E0, -2E-1, 5E-1]," +
+                "          \"categories\":[]," +
+                "          \"categories_nodes\":[]," +
+                "          \"categories_segments\":[]," +
+                "          \"categories_sizes\":[]," +
+                "          \"default_left\":[0, 0, 0]," +
+                "          \"id\":0," +
+                "          \"left_children\":[2, -1, -1]," +
+                "          \"loss_changes\":[0E0, 0E0, 0E0]," +
+                "          \"parents\":[2147483647, 0, 0]," +
+                "          \"right_children\":[1, -1, -1]," +
+                "          \"split_conditions\":[3E0, -1E0, -1E0]," +
+                "          \"split_indices\":[0, 0, 0]," +
+                "          \"split_type\":[0, 0, 0]," +
+                "          \"sum_hessian\":[1E0, 1E0, 1E0]," +
+                "          \"tree_param\":{" +
+                "            \"num_deleted\":\"0\"," +
+                "            \"num_feature\":\"1\"," +
+                "            \"num_nodes\":\"3\"," +
+                "            \"size_leaf_vector\":\"1\"}" +
+                "        }" +
+                "      ]}," +
+                "      \"name\":\"gbtree\"" +
+                "    }," +
+                "    \"learner_model_param\":{" +
+                "      \"base_score\":\"5E-1\"," +
+                "      \"boost_from_average\":\"1\"," +
+                "      \"num_class\":\"0\"," +
+                "      \"num_feature\":\"1\"," +
+                "      \"num_target\":\"1\"" +
+                "    }," +
+                "    \"objective\":{" +
+                "      \"name\":\"reg:logistic\"," +
+                "      \"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" +
+                "    }" +
+                "  }," +
+                "  \"version\":[2,1,0]" +
+                "}";
+
+        FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1")));
+        NaiveAdditiveDecisionTree tree = parser.parse(set, model);
+        FeatureVector v = tree.newFeatureVector(null);
+        v.setFeatureScore(0, 2);
+        assertEquals(0.62245935F, tree.score(v), Math.ulp(0.62245935F));
+        v.setFeatureScore(0, 4);
+        assertEquals(0.45016602F, tree.score(v), Math.ulp(0.45016602F));
+    }
+
+    public void testBadObjectiveParam() throws IOException {
+        String model =
+                "{" +
+                "  \"learner\":{" +
+                "    \"attributes\":{}," +
+                "    \"feature_names\":[\"feat1\", \"feat2\"]," +
+                "    \"feature_types\":[\"float\", \"float\"]," +
+                "    \"gradient_booster\":{" +
+                "      \"model\":{" +
+                "        \"gbtree_model_param\":{" +
+                "          \"num_parallel_tree\":\"1\"," +
+                "          \"num_trees\":\"1\"}," +
+                "        \"iteration_indptr\":[0,1]," +
+                "        \"tree_info\":[0]," +
+                "        \"trees\":[{" +
+                "          \"base_weights\":[1E0, 10E0, 0E0]," +
+                "          \"categories\":[]," +
+                "          \"categories_nodes\":[]," +
+                "          \"categories_segments\":[]," +
+                "          \"categories_sizes\":[]," +
+                "          \"default_left\":[0, 0, 0]," +
+                "          \"id\":0," +
+                "          \"left_children\":[2, -1, -1]," +
+                "          \"loss_changes\":[0E0, 0E0, 0E0]," +
+                "          \"parents\":[2147483647, 0, 0]," +
+                "          \"right_children\":[1, -1, -1]," +
+                "          \"split_conditions\":[3E0, -1E0, -1E0]," +
+                "          \"split_indices\":[0, 0, 0]," +
+                "          \"split_type\":[0, 0, 0]," +
+                "          \"sum_hessian\":[1E0, 1E0, 1E0]," +
+                "          \"tree_param\":{" +
+                "            \"num_deleted\":\"0\"," +
+                "            \"num_feature\":\"2\"," +
+                "            \"num_nodes\":\"3\"," +
+                "            \"size_leaf_vector\":\"1\"}" +
+                "        }" +
+                "      ]}," +
+                "      \"name\":\"gbtree\"" +
+                "    }," +
+                "    \"learner_model_param\":{" +
+                "      \"base_score\":\"5E-1\"," +
+                "      \"boost_from_average\":\"1\"," +
+                "      \"num_class\":\"0\"," +
+                "      \"num_feature\":\"2\"," +
+                "      \"num_target\":\"1\"" +
+                "    }," +
+                "    \"objective\":{" +
+                "      \"name\":\"reg:invalid\"," +
+                "      \"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" +
+                "    }" +
+                "  }," +
+                "  \"version\":[2,1,0]" +
+                "}";
+
+        FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1")));
+        assertThat(expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(),
+                CoreMatchers.containsString("Unable to parse XGBoost object"));
+    }
+
+    public void testBadFeatureTypeParam() throws IOException {
+        String model =
+                "{" +
+                "  \"learner\":{" +
+                "    \"attributes\":{}," +
+                "    \"feature_names\":[\"feat1\"]," +
+                "    \"feature_types\":[\"int\"]," +
+                "    \"gradient_booster\":{" +
+                "      \"model\":{" +
+                "        \"gbtree_model_param\":{" +
+                "          \"num_parallel_tree\":\"1\"," +
+                "          \"num_trees\":\"1\"}," +
+                "        \"iteration_indptr\":[0,1]," +
+                "        \"tree_info\":[0]," +
+                "        \"trees\":[{" +
+                "          \"base_weights\":[1E0, 10E0, 0E0]," +
+                "          \"categories\":[]," +
+                "          \"categories_nodes\":[]," +
+                "          \"categories_segments\":[]," +
+                "          \"categories_sizes\":[]," +
+                "          \"default_left\":[0, 0, 0]," +
+                "          \"id\":0," +
+                "          \"left_children\":[2, -1, -1]," +
+                "          \"loss_changes\":[0E0, 0E0, 0E0]," +
+                "          \"parents\":[2147483647, 0, 0]," +
+                "          \"right_children\":[1, -1, -1]," +
+                "          \"split_conditions\":[3E0, -1E0, -1E0]," +
+                "          \"split_indices\":[0, 0, 0]," +
+                "          \"split_type\":[0, 0, 0]," +
+                "          \"sum_hessian\":[1E0, 1E0, 1E0]," +
+                "          \"tree_param\":{" +
+                "            \"num_deleted\":\"0\"," +
+                "            \"num_feature\":\"1\"," +
+                "            \"num_nodes\":\"3\"," +
+                "            \"size_leaf_vector\":\"1\"}" +
+                "        }" +
+                "      ]}," +
+                "      \"name\":\"gbtree\"" +
+                "    }," +
+                "    \"learner_model_param\":{" +
+                "      \"base_score\":\"5E-1\"," +
+                "      \"boost_from_average\":\"1\"," +
+                "      \"num_class\":\"0\"," +
+                "      \"num_feature\":\"1\"," +
+                "      \"num_target\":\"1\"" +
+                "    }," +
+                "    \"objective\":{" +
+                "      \"name\":\"reg:linear\"," +
+                "      \"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" +
+                "    }" +
+                "  }," +
+                "  \"version\":[2,1,0]" +
+                "}";
+
+        FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1")));
+        assertThat(expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(),
+                CoreMatchers.containsString("The LTR plugin only supports float feature types because " +
+                        "Elasticsearch scores are always float32. Found feature type [int] in model"));
+    }
+
+    public void testMismatchingFeatureList() throws IOException {
+        String model =
+                "{" +
+                "  \"learner\":{" +
+                "    \"attributes\":{}," +
+                "    \"feature_names\":[\"feat1\", \"feat2\"]," +
+                "    \"feature_types\":[\"float\"]," +
+                "    \"gradient_booster\":{" +
+                "      \"model\":{" +
+                "        \"gbtree_model_param\":{" +
+                "          \"num_parallel_tree\":\"1\"," +
+                "          \"num_trees\":\"1\"}," +
+                "        \"iteration_indptr\":[0,1]," +
+                "        \"tree_info\":[0]," +
+                "        \"trees\":[{" +
+                "          \"base_weights\":[1E0, 10E0, 0E0]," +
+                "          \"categories\":[]," +
+                "          \"categories_nodes\":[]," +
+                "          \"categories_segments\":[]," +
+                "          \"categories_sizes\":[]," +
+                "          \"default_left\":[0, 0, 0]," +
+                "          \"id\":0," +
+                "          \"left_children\":[2, -1, -1]," +
+                "          \"loss_changes\":[0E0, 0E0, 0E0]," +
+                "          \"parents\":[2147483647, 0, 0]," +
+                "          \"right_children\":[1, -1, -1]," +
+                "          \"split_conditions\":[3E0, -1E0, -1E0]," +
+                "          \"split_indices\":[0, 0, 0]," +
+                "          \"split_type\":[0, 0, 0]," +
+                "          \"sum_hessian\":[1E0, 1E0, 1E0]," +
+                "          \"tree_param\":{" +
+                "            \"num_deleted\":\"0\"," +
+                "            \"num_feature\":\"2\"," +
+                "            \"num_nodes\":\"3\"," +
+                "            \"size_leaf_vector\":\"1\"}" +
+                "        }" +
+                "      ]}," +
+                "      \"name\":\"gbtree\"" +
+                "    }," +
+                "    \"learner_model_param\":{" +
+                "      \"base_score\":\"5E-1\"," +
+                "      \"boost_from_average\":\"1\"," +
+                "      \"num_class\":\"0\"," +
+                "      \"num_feature\":\"2\"," +
+                "      \"num_target\":\"1\"" +
+                "    }," +
+                "    \"objective\":{" +
+                "      \"name\":\"reg:logistic\"," +
+                "      \"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" +
+                "    }" +
+                "  }," +
+                "  \"version\":[2,1,0]" +
+                "}";
+
+        FeatureSet set = new StoredFeatureSet("set", List.of(randomFeature("feat1"), randomFeature("feat2")));
+        assertThat(expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(),
+                CoreMatchers.containsString("Feature names list and feature types list must have the same length"));
+    }
+
+    public void testSplitMissingLeftChild() throws IOException {
+        String model =
+                "{" +
+                "  \"learner\":{" +
+                "    \"attributes\":{}," +
+                "    \"feature_names\":[\"feat1\"]," +
+                "    \"feature_types\":[\"float\"]," +
+                "    \"gradient_booster\":{" +
+                "      \"model\":{" +
+                "        \"gbtree_model_param\":{" +
+                "          \"num_parallel_tree\":\"1\"," +
+                "          \"num_trees\":\"1\"}," +
+                "        \"iteration_indptr\":[0,1]," +
+                "        \"tree_info\":[0]," +
+                "        \"trees\":[{" +
+                "          \"base_weights\":[1E0, 10E0, 0E0]," +
+                "          \"categories\":[]," +
+                "          \"categories_nodes\":[]," +
+                "          \"categories_segments\":[]," +
+                "          \"categories_sizes\":[]," +
+                "          \"default_left\":[0, 0, 0]," +
+                "          \"id\":0," +
+                "          \"left_children\":[100, -1, -1]," +
+                "          \"loss_changes\":[0E0, 0E0, 0E0]," +
+                "          \"parents\":[2147483647, 0, 0]," +
+                "          \"right_children\":[1, -1, -1]," +
+                "          \"split_conditions\":[3E0, -1E0, -1E0]," +
+                "          \"split_indices\":[0, 0, 0]," +
+                "          \"split_type\":[0, 0, 0]," +
+                "          \"sum_hessian\":[1E0, 1E0, 1E0]," +
+                "          \"tree_param\":{" +
+                "            \"num_deleted\":\"0\"," +
+                "            \"num_feature\":\"1\"," +
+                "            \"num_nodes\":\"3\"," +
+                "            \"size_leaf_vector\":\"1\"}" +
+                "        }" +
+                "      ]}," +
+                "      \"name\":\"gbtree\"" +
+                "    }," +
+                "    \"learner_model_param\":{" +
+                "      \"base_score\":\"5E-1\"," +
+                "      \"boost_from_average\":\"1\"," +
+                "      \"num_class\":\"0\"," +
+                "      \"num_feature\":\"1\"," +
+                "      \"num_target\":\"1\"" +
+                "    }," +
+                "    \"objective\":{" +
+                "      \"name\":\"reg:linear\"," +
+                "      \"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" +
+                "    }" +
+                "  }," +
+                "  \"version\":[2,1,0]" +
+                "}";
+
+        try {
+            FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1")));
+            parser.parse(set, model);
+            fail("Expected an exception");
+        } catch (ParsingException e) {
+            assertThat(e.getMessage(), CoreMatchers.containsString("Unable to parse XGBoost object"));
+            Throwable rootCause = e.getCause().getCause().getCause().getCause().getCause().getCause();
+            assertThat(rootCause, CoreMatchers.instanceOf(IllegalArgumentException.class));
+            assertThat(rootCause.getMessage(), CoreMatchers.containsString("Child node reference ID [100] is invalid"));
+        }
+    }
+
+    public void testSplitMissingRightChild() throws IOException {
+        String model =
+                "{" +
+                "  \"learner\":{" +
+                "    \"attributes\":{}," +
+                "    \"feature_names\":[\"feat1\"]," +
+                "    \"feature_types\":[\"float\"]," +
+                "    \"gradient_booster\":{" +
+                "      \"model\":{" +
+                "        \"gbtree_model_param\":{" +
+                "          \"num_parallel_tree\":\"1\"," +
+                "          \"num_trees\":\"1\"}," +
+                "        \"iteration_indptr\":[0,1]," +
+                "        \"tree_info\":[0]," +
+                "        \"trees\":[{" +
+                "          \"base_weights\":[1E0, 10E0, 0E0]," +
+                "          \"categories\":[]," +
+                "          \"categories_nodes\":[]," +
+                "          \"categories_segments\":[]," +
+                "          \"categories_sizes\":[]," +
+                "          \"default_left\":[0, 0, 0]," +
+                "          \"id\":0," +
+                "          \"left_children\":[1, -1, -1]," +
+                "          \"loss_changes\":[0E0, 0E0, 0E0]," +
+                "          \"parents\":[2147483647, 0, 0]," +
+                "          \"right_children\":[100, -1, -1]," +
+                "          \"split_conditions\":[3E0, -1E0, -1E0]," +
+                "          \"split_indices\":[0, 0, 0]," +
+                "          \"split_type\":[0, 0, 0]," +
+                "          \"sum_hessian\":[1E0, 1E0, 1E0]," +
+                "          \"tree_param\":{" +
+                "            \"num_deleted\":\"0\"," +
+                "            \"num_feature\":\"1\"," +
+                "            \"num_nodes\":\"3\"," +
+                "            \"size_leaf_vector\":\"1\"}" +
+                "        }" +
+                "      ]}," +
+                "      \"name\":\"gbtree\"" +
+                "    }," +
+                "    \"learner_model_param\":{" +
+                "      \"base_score\":\"5E-1\"," +
+                "      \"boost_from_average\":\"1\"," +
+                "      \"num_class\":\"0\"," +
+                "      \"num_feature\":\"1\"," +
+                "      \"num_target\":\"1\"" +
+                "    }," +
+                "    \"objective\":{" +
+                "      \"name\":\"reg:linear\"," +
+                "      \"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" +
+                "    }" +
+                "  }," +
+                "  \"version\":[2,1,0]" +
+                "}";
+
+        try {
+            FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1")));
+            parser.parse(set, model);
+            fail("Expected an exception");
+        } catch (ParsingException e) {
+            assertThat(e.getMessage(), CoreMatchers.containsString("Unable to parse XGBoost object"));
+            Throwable rootCause = e.getCause().getCause().getCause().getCause().getCause().getCause();
+            assertThat(rootCause, CoreMatchers.instanceOf(IllegalArgumentException.class));
+            assertThat(rootCause.getMessage(), CoreMatchers.containsString("Child node reference ID [100] is invalid"));
+        }
+    }
+
+    public void testBadStruct() throws IOException {
+        String model =
+                "[{" +
+                "  \"learner\":{" +
+                "    \"attributes\":{}," +
+                "    \"feature_names\":[\"feat1\", \"feat2\"]," +
+                "    \"feature_types\":[\"float\", \"float\"]," +
+                "    \"gradient_booster\":{" +
+                "      \"model\":{" +
+                "        \"gbtree_model_param\":{" +
+                "          \"num_parallel_tree\":\"1\"," +
+                "          \"num_trees\":\"1\"}," +
+                "        \"iteration_indptr\":[0,1]," +
+                "        \"tree_info\":[0]," +
+                "        \"trees\":[{" +
+                "          \"base_weights\":[1E0, 10E0, 0E0]," +
+                "          \"categories\":[]," +
+                "          \"categories_nodes\":[]," +
+                "          \"categories_segments\":[]," +
+                "          \"categories_sizes\":[]," +
+                "          \"default_left\":[0, 0, 0]," +
+                "          \"id\":0," +
+                "          \"left_children\":[2, -1, -1]," +
+                "          \"loss_changes\":[0E0, 0E0, 0E0]," +
+                "          \"parents\":[2147483647, 0, 0]," +
+                "          \"right_children\":[1, -1, -1]," +
+                "          \"split_conditions\":[3E0, -1E0, -1E0]," +
+                "          \"split_indices\":[0, 0, 0]," +
+                "          \"split_type\":[0, 0, 0]," +
+                "          \"sum_hessian\":[1E0, 1E0, 1E0]," +
+                "          \"tree_param\":{" +
+                "            \"num_deleted\":\"0\"," +
+                "            \"num_feature\":\"2\"," +
+                "            \"num_nodes\":\"3\"," +
+                "            \"size_leaf_vector\":\"1\"}" +
+                "        }" +
+                "      ]}," +
+                "      \"name\":\"gbtree\"" +
+                "    }," +
+                "    \"learner_model_param\":{" +
+                "      \"base_score\":\"5E-1\"," +
+                "      \"boost_from_average\":\"1\"," +
+                "      \"num_class\":\"0\"," +
+                "      \"num_feature\":\"2\"," +
+                "      \"num_target\":\"1\"" +
+                "    }," +
+                "    \"objective\":{" +
+                "      \"name\":\"reg:linear\"," +
+                "      \"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" +
+                "    }" +
+                "  }," +
+                "  \"version\":[2,1,0]" +
+                "}]";
+        FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1")));
+        assertThat(expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(),
+                CoreMatchers.containsString("Expected [START_OBJECT] but got"));
+    }
+
+    public void testMissingFeat() throws IOException {
+        String model =
+                "{" +
+                "  \"learner\":{" +
+                "    \"attributes\":{}," +
+                "    \"feature_names\":[\"feat1\", \"feat2\"]," +
+                "    \"feature_types\":[\"float\",\"float\"]," +
+                "    \"gradient_booster\":{" +
+                "      \"model\":{" +
+                "        \"gbtree_model_param\":{" +
+                "          \"num_parallel_tree\":\"1\"," +
+                "          \"num_trees\":\"1\"}," +
+                "        \"iteration_indptr\":[0,1]," +
+                "        \"tree_info\":[0]," +
+                "        \"trees\":[{" +
+                "          \"base_weights\":[1E0, 10E0, 0E0]," +
+                "          \"categories\":[]," +
+                "          \"categories_nodes\":[]," +
+                "          \"categories_segments\":[]," +
+                "          \"categories_sizes\":[]," +
+                "          \"default_left\":[0, 0, 0]," +
+                "          \"id\":0," +
+                "          \"left_children\":[2, -1, -1]," +
+                "          \"loss_changes\":[0E0, 0E0, 0E0]," +
+                "          \"parents\":[2147483647, 0, 0]," +
+                "          \"right_children\":[1, -1, -1]," +
+                "          \"split_conditions\":[3E0, -1E0, -1E0]," +
+                "          \"split_indices\":[0, 0, 100]," +
+                "          \"split_type\":[0, 0, 0]," +
+                "          \"sum_hessian\":[1E0, 1E0, 1E0]," +
+                "          \"tree_param\":{" +
+                "            \"num_deleted\":\"0\"," +
+                "            \"num_feature\":\"2\"," +
+                "            \"num_nodes\":\"3\"," +
+                "            \"size_leaf_vector\":\"1\"}" +
+                "        }" +
+                "      ]}," +
+                "      \"name\":\"gbtree\"" +
+                "    }," +
+                "    \"learner_model_param\":{" +
+                "      \"base_score\":\"5E-1\"," +
+                "      \"boost_from_average\":\"1\"," +
+                "      \"num_class\":\"0\"," +
+                "      \"num_feature\":\"2\"," +
+                "      \"num_target\":\"1\"" +
+                "    }," +
+                "    \"objective\":{" +
+                "      \"name\":\"reg:linear\"," +
+                "      \"reg_loss_param\":{\"scale_pos_weight\":\"1\"}" +
+                "    }" +
+                "  }," +
+                "  \"version\":[2,1,0]" +
+                "}";
+        FeatureSet set = new StoredFeatureSet("set", singletonList(randomFeature("feat1234")));
+        assertThat(expectThrows(ParsingException.class, () -> parser.parse(set, model)).getMessage(),
+                CoreMatchers.containsString("Unknown features in model: [feat1, feat2]"));
+    }
+}