From d2bbe8bd4e3d0b596af1f61e65643e3b40911ba8 Mon Sep 17 00:00:00 2001 From: patrick-le-shopify <156460083+patrick-le-shopify@users.noreply.github.com> Date: Wed, 31 Jan 2024 00:14:54 -0800 Subject: [PATCH] Issue 481 - implement support for missing values with XGBoost (#482) * sparse feature vector, naive decision tree * add integration test * turn back on testLog tests * minor edits * add default value to explanation * lint * add tests --- .../java/com/o19s/es/ltr/NodeSettingsIT.java | 6 +- .../o19s/es/ltr/query/StoredLtrQueryIT.java | 78 ++++++++++++++++++- .../com/o19s/es/ltr/query/RankerQuery.java | 3 +- .../es/ltr/ranker/ArrayFeatureVector.java | 48 ++++++++++++ .../es/ltr/ranker/DenseFeatureVector.java | 29 +------ .../com/o19s/es/ltr/ranker/LogLtrRanker.java | 4 + .../com/o19s/es/ltr/ranker/LtrRanker.java | 8 +- .../es/ltr/ranker/SparseFeatureVector.java | 25 ++++++ .../o19s/es/ltr/ranker/SparseLtrRanker.java | 47 +++++++++++ .../dectree/NaiveAdditiveDecisionTree.java | 8 +- .../ranklib/DenseProgramaticDataPoint.java | 6 +- .../ltr/ranker/DenseFeatureVectorTests.java | 49 ++++++++++++ .../ltr/ranker/SparseFeatureVectorTests.java | 49 ++++++++++++ .../NaiveAdditiveDecisionTreeTests.java | 14 +++- .../ranker/parser/XGBoostJsonParserTests.java | 4 +- 15 files changed, 331 insertions(+), 47 deletions(-) create mode 100644 src/main/java/com/o19s/es/ltr/ranker/ArrayFeatureVector.java create mode 100644 src/main/java/com/o19s/es/ltr/ranker/SparseFeatureVector.java create mode 100644 src/main/java/com/o19s/es/ltr/ranker/SparseLtrRanker.java create mode 100644 src/test/java/com/o19s/es/ltr/ranker/DenseFeatureVectorTests.java create mode 100644 src/test/java/com/o19s/es/ltr/ranker/SparseFeatureVectorTests.java diff --git a/src/javaRestTest/java/com/o19s/es/ltr/NodeSettingsIT.java b/src/javaRestTest/java/com/o19s/es/ltr/NodeSettingsIT.java index 291aed1c..f3a17570 100644 --- a/src/javaRestTest/java/com/o19s/es/ltr/NodeSettingsIT.java +++ 
b/src/javaRestTest/java/com/o19s/es/ltr/NodeSettingsIT.java @@ -81,14 +81,14 @@ public void testCacheSettings() throws IOException, InterruptedException { public static class DummyModel extends CompiledLtrModel { public DummyModel(String name, long size) throws IOException { - super(name, LtrTestUtils.randomFeatureSet(1), new DummryRanker(size)); + super(name, LtrTestUtils.randomFeatureSet(1), new DummyRanker(size)); } } - public static class DummryRanker implements LtrRanker, Accountable { + public static class DummyRanker implements LtrRanker, Accountable { private final long ramUsed; - public DummryRanker(long ramUsed) { + public DummyRanker(long ramUsed) { this.ramUsed = ramUsed; } diff --git a/src/javaRestTest/java/com/o19s/es/ltr/query/StoredLtrQueryIT.java b/src/javaRestTest/java/com/o19s/es/ltr/query/StoredLtrQueryIT.java index 2f095cea..21077057 100644 --- a/src/javaRestTest/java/com/o19s/es/ltr/query/StoredLtrQueryIT.java +++ b/src/javaRestTest/java/com/o19s/es/ltr/query/StoredLtrQueryIT.java @@ -26,23 +26,32 @@ import com.o19s.es.ltr.action.CreateModelFromSetAction.CreateModelFromSetRequestBuilder; import com.o19s.es.ltr.feature.store.ScriptFeature; import com.o19s.es.ltr.feature.store.StoredFeature; +import com.o19s.es.ltr.feature.store.StoredFeatureSet; import com.o19s.es.ltr.feature.store.StoredLtrModel; import com.o19s.es.ltr.feature.store.index.IndexFeatureStore; +import com.o19s.es.ltr.logging.LoggingSearchExtBuilder; import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.WrapperQueryBuilder; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rescore.QueryRescoreMode; import 
org.elasticsearch.search.rescore.QueryRescorerBuilder; import org.hamcrest.CoreMatchers; import org.hamcrest.Matchers; +import java.util.concurrent.ExecutionException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; -import java.util.concurrent.ExecutionException; + +import static org.hamcrest.CoreMatchers.containsString; /** * Created by doug on 12/29/16. @@ -63,6 +72,69 @@ public class StoredLtrQueryIT extends BaseIntegrationTest { "\"feature6\": 1" + "}"; + private static final String SIMPLE_MODEL_XGB = "[{" + + "\"nodeid\": 0," + + "\"split\":\"text_feature1\"," + + "\"depth\":0," + + "\"split_condition\":100.0," + + "\"yes\":1," + + "\"no\":2," + + "\"missing\":2," + + "\"children\": [" + + " {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," + + " {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" + + "]}]"; + + + public void testScriptFeatureUseCaseMissingFeatureNaiveAdditiveDecisionTree() throws Exception { + List features = new ArrayList<>(1); + features.add(new StoredFeature("text_feature1", Collections.singletonList("query"), "mustache", + QueryBuilders.matchQuery("field1", "{{query}}").toString())); + + StoredFeatureSet set = new StoredFeatureSet("my_set", features); + addElement(set); + StoredLtrModel model = new StoredLtrModel("my_model", set, + new StoredLtrModel.LtrModelDefinition("model/xgboost+json", + SIMPLE_MODEL_XGB, true)); + addElement(model); + + buildIndex(); + + Map params = new HashMap<>(); + params.put("query", "bonjour"); + StoredLtrQueryBuilder sbuilder = new StoredLtrQueryBuilder(LtrTestUtils.nullLoader()) + .featureSetName("my_set") + .modelName("my_model") + .params(params) + .queryName("test") + .boost(1); + + QueryBuilder query = QueryBuilders.boolQuery().must(new WrapperQueryBuilder(sbuilder.toString())); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(query) + .explain(true) + .fetchSource(true) + .size(10) + 
.ext(Collections.singletonList( + new LoggingSearchExtBuilder() + .addQueryLogging("log", "test", false))); + + SearchResponse resp = client().prepareSearch("test_index").setSource(sourceBuilder).get(); + SearchHit hit = resp.getHits().getAt(0); + assertTrue(hit.getFields().containsKey("_ltrlog")); + Map>> logs = hit.getFields().get("_ltrlog").getValue(); + assertTrue(logs.containsKey("log")); + List> log = logs.get("log"); + + // verify that text_feature1 has a missing value, and that the reported score results from the model taking the + // corresponding branch, along with the explanation + String explanation = hit.getExplanation().getDetails()[0].getDescription(); + assertThat(explanation, containsString("default value of NaN used")); + + assertEquals("text_feature1", log.get(0).get("name")); + assertEquals(null, log.get(0).get("value")); + + assertEquals(0.2F, hit.getScore(), Math.ulp(0.2F)); + } public void testScriptFeatureUseCase() throws Exception { addElement(new StoredFeature("feature1", Collections.singletonList("query"), "mustache", @@ -260,7 +332,7 @@ public void testFullUsecase() throws Exception { StoredLtrModel model = getElement(StoredLtrModel.class, StoredLtrModel.TYPE, "my_model"); CachesStatsNodesResponse stats = client().execute(CachesStatsAction.INSTANCE, - new CachesStatsAction.CachesStatsNodesRequest()).get(); + new CachesStatsAction.CachesStatsNodesRequest()).get(); assertEquals(1, stats.getAll().getTotal().getCount()); assertEquals(model.compile(parserFactory()).ramBytesUsed(), stats.getAll().getTotal().getRam()); assertEquals(1, stats.getAll().getModels().getCount()); @@ -301,6 +373,4 @@ public void buildIndex() { .setSource("field1", "hello world", "field2", "bonjour world") .get(); } - - } diff --git a/src/main/java/com/o19s/es/ltr/query/RankerQuery.java b/src/main/java/com/o19s/es/ltr/query/RankerQuery.java index f9938071..957a0d84 100644 --- a/src/main/java/com/o19s/es/ltr/query/RankerQuery.java +++ 
b/src/main/java/com/o19s/es/ltr/query/RankerQuery.java @@ -259,7 +259,8 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio } featureString += ":"; if (!explain.isMatch()) { - subs.add(Explanation.noMatch(featureString + " [no match, default value 0.0 used]")); + subs.add(Explanation.noMatch(featureString + + String.format(" [no match, default value of %.2f used]", d.getDefaultScore()))); } else { subs.add(Explanation.match(explain.getValue(), featureString, explain)); d.setFeatureScore(ordinal, explain.getValue().floatValue()); diff --git a/src/main/java/com/o19s/es/ltr/ranker/ArrayFeatureVector.java b/src/main/java/com/o19s/es/ltr/ranker/ArrayFeatureVector.java new file mode 100644 index 00000000..4477c647 --- /dev/null +++ b/src/main/java/com/o19s/es/ltr/ranker/ArrayFeatureVector.java @@ -0,0 +1,48 @@ +/* + * Copyright [2017] Wikimedia Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.o19s.es.ltr.ranker; + +import java.util.Arrays; + +public class ArrayFeatureVector implements LtrRanker.FeatureVector { + public final float[] scores; + public final float defaultScore; + + public ArrayFeatureVector(int size, float value) { + scores = new float[size]; + defaultScore = value; + } + + @Override + public void setFeatureScore(int featureIdx, float score) { + scores[featureIdx] = score; + } + + @Override + public float getFeatureScore(int featureIdx) { + return scores[featureIdx]; + } + + public void reset() { + Arrays.fill(scores, defaultScore); + } + + @Override + public float getDefaultScore() { + return defaultScore; + } +} diff --git a/src/main/java/com/o19s/es/ltr/ranker/DenseFeatureVector.java b/src/main/java/com/o19s/es/ltr/ranker/DenseFeatureVector.java index 153b2887..bb8e6b1f 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/DenseFeatureVector.java +++ b/src/main/java/com/o19s/es/ltr/ranker/DenseFeatureVector.java @@ -16,34 +16,9 @@ package com.o19s.es.ltr.ranker; -import java.util.Arrays; +public class DenseFeatureVector extends ArrayFeatureVector { -/** - * Simple array-backed feature vector - */ -public class DenseFeatureVector implements LtrRanker.FeatureVector { - public final float[] scores; - - /** - * New simple array-backed datapoint - * - * @param size size of the internal array - */ public DenseFeatureVector(int size) { - this.scores = new float[size]; - } - - @Override - public void setFeatureScore(int featureIdx, float score) { - scores[featureIdx] = score; - } - - @Override - public float getFeatureScore(int featureIdx) { - return scores[featureIdx]; - } - - public void reset() { - Arrays.fill(scores, 0F); + super(size, 0F); } } diff --git a/src/main/java/com/o19s/es/ltr/ranker/LogLtrRanker.java b/src/main/java/com/o19s/es/ltr/ranker/LogLtrRanker.java index 1df48672..b4efe32a 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/LogLtrRanker.java +++ b/src/main/java/com/o19s/es/ltr/ranker/LogLtrRanker.java @@ 
-80,6 +80,10 @@ void reset(LtrRanker ranker) { this.inner = ranker.newFeatureVector(inner); logger.reset(); } + + public float getDefaultScore() { + return inner.getDefaultScore(); + } } public LogConsumer getLogConsumer() { diff --git a/src/main/java/com/o19s/es/ltr/ranker/LtrRanker.java b/src/main/java/com/o19s/es/ltr/ranker/LtrRanker.java index 3c56c856..3fb3e6f8 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/LtrRanker.java +++ b/src/main/java/com/o19s/es/ltr/ranker/LtrRanker.java @@ -46,7 +46,7 @@ public interface LtrRanker { /** * Score the data point. * At this point all feature scores are set. - * features that did not match are set with a score to 0 + * features that did not match are set with a default score * * @param point the feature vector point to compute the score for * @return the score computed for the given point @@ -73,5 +73,11 @@ interface FeatureVector { */ float getFeatureScore(int featureId); + /** + * Retrieve the default score + * @return the default score used for features that did not match + */ + float getDefaultScore(); + } } diff --git a/src/main/java/com/o19s/es/ltr/ranker/SparseFeatureVector.java b/src/main/java/com/o19s/es/ltr/ranker/SparseFeatureVector.java new file mode 100644 index 00000000..72c86767 --- /dev/null +++ b/src/main/java/com/o19s/es/ltr/ranker/SparseFeatureVector.java @@ -0,0 +1,25 @@ +/* + * Copyright [2017] Wikimedia Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.o19s.es.ltr.ranker; + +public class SparseFeatureVector extends ArrayFeatureVector { + + public SparseFeatureVector(int size) { + super(size, Float.NaN); + reset(); + } +} diff --git a/src/main/java/com/o19s/es/ltr/ranker/SparseLtrRanker.java b/src/main/java/com/o19s/es/ltr/ranker/SparseLtrRanker.java new file mode 100644 index 00000000..8af06541 --- /dev/null +++ b/src/main/java/com/o19s/es/ltr/ranker/SparseLtrRanker.java @@ -0,0 +1,47 @@ +/* + * Copyright [2017] Wikimedia Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.o19s.es.ltr.ranker; + +/** + * A sparse ranker base class to work with {@link SparseFeatureVector} + * where missing feature scores are set to {@code Float.NaN}. 
+ */ +public abstract class SparseLtrRanker implements LtrRanker { + @Override + public SparseFeatureVector newFeatureVector(FeatureVector reuse) { + if (reuse != null) { + assert reuse instanceof SparseFeatureVector; + SparseFeatureVector vector = (SparseFeatureVector) reuse; + vector.reset(); + return vector; + } + return new SparseFeatureVector(size()); + } + + @Override + public float score(FeatureVector vector) { + assert vector instanceof SparseFeatureVector; + return this.score((SparseFeatureVector) vector); + } + + protected abstract float score(SparseFeatureVector vector); + + /** + * @return the number of features supported by this ranker + */ + protected abstract int size(); +} diff --git a/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java b/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java index 3548bfd8..0a7829ac 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java +++ b/src/main/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTree.java @@ -16,8 +16,8 @@ package com.o19s.es.ltr.ranker.dectree; -import com.o19s.es.ltr.ranker.DenseFeatureVector; -import com.o19s.es.ltr.ranker.DenseLtrRanker; +import com.o19s.es.ltr.ranker.SparseFeatureVector; +import com.o19s.es.ltr.ranker.SparseLtrRanker; import com.o19s.es.ltr.ranker.normalizer.Normalizer; import org.apache.lucene.util.Accountable; import org.apache.lucene.util.RamUsageEstimator; @@ -28,7 +28,7 @@ * Naive implementation of additive decision tree. * May be slow when the number of trees and tree complexity if high comparatively to the number of features. 
*/ -public class NaiveAdditiveDecisionTree extends DenseLtrRanker implements Accountable { +public class NaiveAdditiveDecisionTree extends SparseLtrRanker implements Accountable { private static final long BASE_RAM_USED = RamUsageEstimator.shallowSizeOfInstance(Split.class); private final Node[] trees; @@ -60,7 +60,7 @@ public String name() { } @Override - protected float score(DenseFeatureVector vector) { + protected float score(SparseFeatureVector vector) { float sum = 0; float[] scores = vector.scores; for (int i = 0; i < trees.length; i++) { diff --git a/src/main/java/com/o19s/es/ltr/ranker/ranklib/DenseProgramaticDataPoint.java b/src/main/java/com/o19s/es/ltr/ranker/ranklib/DenseProgramaticDataPoint.java index 9dd5eb78..3c6b36bb 100644 --- a/src/main/java/com/o19s/es/ltr/ranker/ranklib/DenseProgramaticDataPoint.java +++ b/src/main/java/com/o19s/es/ltr/ranker/ranklib/DenseProgramaticDataPoint.java @@ -23,7 +23,7 @@ import java.util.Arrays; /** - * Implements FeatureVector but without needing to pass in a stirng + * Implements FeatureVector but without needing to pass in a string * to be parsed */ public class DenseProgramaticDataPoint extends DataPoint implements LtrRanker.FeatureVector { @@ -71,4 +71,8 @@ public float getFeatureScore(int featureIdx) { public void reset() { Arrays.fill(fVals, 0F); } + + public float getDefaultScore() { + return 0.0F; + } } diff --git a/src/test/java/com/o19s/es/ltr/ranker/DenseFeatureVectorTests.java b/src/test/java/com/o19s/es/ltr/ranker/DenseFeatureVectorTests.java new file mode 100644 index 00000000..4f8dbd45 --- /dev/null +++ b/src/test/java/com/o19s/es/ltr/ranker/DenseFeatureVectorTests.java @@ -0,0 +1,49 @@ +/* + * Copyright [2017] Wikimedia Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.o19s.es.ltr.ranker; + +import org.apache.lucene.tests.util.LuceneTestCase; + +public class DenseFeatureVectorTests extends LuceneTestCase { + public void testConstructor() { + int size = 10; + DenseFeatureVector featureVector = new DenseFeatureVector(size); + for (float score : featureVector.scores) { + assertEquals(0F, score, Math.ulp(0F)); + } + } + + public void testSetGetReset() { + int size = 10; + DenseFeatureVector featureVector = new DenseFeatureVector(size); + featureVector.setFeatureScore(5, 3.15F); + + assertEquals(3.15F, featureVector.getFeatureScore(5), Math.ulp(3.15F)); + assertEquals(0F, featureVector.getFeatureScore(0), Math.ulp(0F)); + + featureVector.reset(); + + for (int featureId = 0; featureId < size; featureId++) { + assertEquals(0F, featureVector.getFeatureScore(featureId), Math.ulp(0F)); + } + } + + public void testGetDefaultValue() { + assertEquals(0F, new DenseFeatureVector(10).getDefaultScore(), Math.ulp(0F)); + } + +} \ No newline at end of file diff --git a/src/test/java/com/o19s/es/ltr/ranker/SparseFeatureVectorTests.java b/src/test/java/com/o19s/es/ltr/ranker/SparseFeatureVectorTests.java new file mode 100644 index 00000000..7833abee --- /dev/null +++ b/src/test/java/com/o19s/es/ltr/ranker/SparseFeatureVectorTests.java @@ -0,0 +1,49 @@ +/* + * Copyright [2017] Wikimedia Foundation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.o19s.es.ltr.ranker; + +import org.apache.lucene.tests.util.LuceneTestCase; + +public class SparseFeatureVectorTests extends LuceneTestCase { + public void testConstructor() { + int size = 10; + SparseFeatureVector featureVector = new SparseFeatureVector(size); + for (float score : featureVector.scores) { + assertTrue(Float.isNaN(score)); + } + } + + public void testSetGetReset() { + int size = 10; + SparseFeatureVector featureVector = new SparseFeatureVector(size); + featureVector.setFeatureScore(5, 3.15F); + + assertEquals(3.15F, featureVector.getFeatureScore(5), Math.ulp(3.15F)); + assertTrue(Float.isNaN(featureVector.getFeatureScore(0))); + + featureVector.reset(); + + for (int featureId = 0; featureId < size; featureId++) { + assertTrue(Float.isNaN(featureVector.getFeatureScore(featureId))); + } + } + + public void testGetDefaultValue() { + assertTrue(Float.isNaN(new SparseFeatureVector(10).getDefaultScore())); + } + +} \ No newline at end of file diff --git a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java index a4f17fed..d6479356 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/dectree/NaiveAdditiveDecisionTreeTests.java @@ -19,8 +19,8 @@ import com.o19s.es.ltr.feature.FeatureSet; import com.o19s.es.ltr.feature.PrebuiltFeature; import com.o19s.es.ltr.feature.PrebuiltFeatureSet; -import 
com.o19s.es.ltr.ranker.DenseFeatureVector; import com.o19s.es.ltr.ranker.LtrRanker; +import com.o19s.es.ltr.ranker.SparseFeatureVector; import com.o19s.es.ltr.ranker.linear.LinearRankerTests; import com.o19s.es.ltr.ranker.normalizer.Normalizer; import com.o19s.es.ltr.ranker.normalizer.Normalizers; @@ -100,15 +100,15 @@ public void testPerfAndRobustness() { 100, 1000, 5, 50, counts); - DenseFeatureVector vector = ranker.newFeatureVector(null); + SparseFeatureVector vector = ranker.newFeatureVector(null); int nPass = TestUtil.nextInt(random(), 10, 8916); - LinearRankerTests.fillRandomWeights(vector.scores); + fillRandomWeights(vector.scores); ranker.score(vector); // warmup long time = -System.currentTimeMillis(); for (int i = 0; i < nPass; i++) { vector = ranker.newFeatureVector(vector); - LinearRankerTests.fillRandomWeights(vector.scores); + fillRandomWeights(vector.scores); ranker.score(vector); } time += System.currentTimeMillis(); @@ -340,4 +340,10 @@ public void newTree() { trees.incrementAndGet(); } } + public static void fillRandomWeights(float[] weights) { + for (int i = 0; i < weights.length; i++) { + if (random().nextBoolean()) + weights[i] = (float) nextInt(random(),1, 100000) / (float) nextInt(random(), 1, 100000); + } + } } \ No newline at end of file diff --git a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java index 0324033e..c5498387 100644 --- a/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java +++ b/src/test/java/com/o19s/es/ltr/ranker/parser/XGBoostJsonParserTests.java @@ -20,7 +20,7 @@ import com.o19s.es.ltr.feature.FeatureSet; import com.o19s.es.ltr.feature.store.StoredFeature; import com.o19s.es.ltr.feature.store.StoredFeatureSet; -import com.o19s.es.ltr.ranker.DenseFeatureVector; +import com.o19s.es.ltr.ranker.SparseFeatureVector; import com.o19s.es.ltr.ranker.LtrRanker.FeatureVector; import 
com.o19s.es.ltr.ranker.dectree.NaiveAdditiveDecisionTree; import com.o19s.es.ltr.ranker.linear.LinearRankerTests; @@ -268,7 +268,7 @@ public void testComplexModel() throws Exception { StoredFeatureSet set = new StoredFeatureSet("set", features); NaiveAdditiveDecisionTree tree = parser.parse(set, model); - DenseFeatureVector v = tree.newFeatureVector(null); + SparseFeatureVector v = tree.newFeatureVector(null); assertEquals(v.scores.length, features.size()); for (int i = random().nextInt(5000) + 1000; i > 0; i--) {