Skip to content

Commit

Permalink
Issue 481 - implement support for missing values with XGBoost (#482)
Browse files Browse the repository at this point in the history
* sparse feature vector, naive decision tree

* add integration test

* turn back on testLog tests

* minor edits

* add default value to explanation

* lint

* add tests
  • Loading branch information
patrick-le-shopify authored Jan 31, 2024
1 parent 05ad18c commit d2bbe8b
Show file tree
Hide file tree
Showing 15 changed files with 331 additions and 47 deletions.
6 changes: 3 additions & 3 deletions src/javaRestTest/java/com/o19s/es/ltr/NodeSettingsIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ public void testCacheSettings() throws IOException, InterruptedException {

public static class DummyModel extends CompiledLtrModel {
public DummyModel(String name, long size) throws IOException {
super(name, LtrTestUtils.randomFeatureSet(1), new DummryRanker(size));
super(name, LtrTestUtils.randomFeatureSet(1), new DummyRanker(size));
}
}

public static class DummryRanker implements LtrRanker, Accountable {
public static class DummyRanker implements LtrRanker, Accountable {
private final long ramUsed;

public DummryRanker(long ramUsed) {
public DummyRanker(long ramUsed) {
this.ramUsed = ramUsed;
}

Expand Down
78 changes: 74 additions & 4 deletions src/javaRestTest/java/com/o19s/es/ltr/query/StoredLtrQueryIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,32 @@
import com.o19s.es.ltr.action.CreateModelFromSetAction.CreateModelFromSetRequestBuilder;
import com.o19s.es.ltr.feature.store.ScriptFeature;
import com.o19s.es.ltr.feature.store.StoredFeature;
import com.o19s.es.ltr.feature.store.StoredFeatureSet;
import com.o19s.es.ltr.feature.store.StoredLtrModel;
import com.o19s.es.ltr.feature.store.index.IndexFeatureStore;
import com.o19s.es.ltr.logging.LoggingSearchExtBuilder;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.WrapperQueryBuilder;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rescore.QueryRescoreMode;
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
import org.hamcrest.CoreMatchers;
import org.hamcrest.Matchers;

import java.util.concurrent.ExecutionException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;

import static org.hamcrest.CoreMatchers.containsString;

/**
* Created by doug on 12/29/16.
Expand All @@ -63,6 +72,69 @@ public class StoredLtrQueryIT extends BaseIntegrationTest {
"\"feature6\": 1" +
"}";

/**
 * XGBoost JSON model made of a single tree with one split on {@code text_feature1}
 * (split condition 100.0). The "yes" branch leads to node 1 (leaf 0.5); both the "no"
 * and the "missing" branches lead to node 2 (leaf 0.2).
 */
private static final String SIMPLE_MODEL_XGB = "[{" +
"\"nodeid\": 0," +
"\"split\":\"text_feature1\"," +
"\"depth\":0," +
"\"split_condition\":100.0," +
"\"yes\":1," +
"\"no\":2," +
"\"missing\":2," +
"\"children\": [" +
" {\"nodeid\": 1, \"depth\": 1, \"leaf\": 0.5}," +
" {\"nodeid\": 2, \"depth\": 1, \"leaf\": 0.2}" +
"]}]";


/**
 * Runs a stored XGBoost model whose only feature ({@code text_feature1}) is expected
 * not to match the first hit, then checks three things: the explanation reports the NaN
 * default, the logged feature value is null, and the score is 0.2, the leaf that
 * {@code SIMPLE_MODEL_XGB} assigns to its "missing" branch.
 */
public void testScriptFeatureUseCaseMissingFeatureNaiveAdditiveDecisionTree() throws Exception {
    // Feature set holding a single mustache-templated match query on field1.
    StoredFeature textFeature = new StoredFeature("text_feature1", Collections.singletonList("query"), "mustache",
            QueryBuilders.matchQuery("field1", "{{query}}").toString());
    List<StoredFeature> featureList = new ArrayList<>(1);
    featureList.add(textFeature);

    StoredFeatureSet featureSet = new StoredFeatureSet("my_set", featureList);
    addElement(featureSet);

    StoredLtrModel.LtrModelDefinition definition =
            new StoredLtrModel.LtrModelDefinition("model/xgboost+json", SIMPLE_MODEL_XGB, true);
    StoredLtrModel storedModel = new StoredLtrModel("my_model", featureSet, definition);
    addElement(storedModel);

    buildIndex();

    Map<String, Object> templateParams = new HashMap<>();
    templateParams.put("query", "bonjour");

    StoredLtrQueryBuilder ltrQuery = new StoredLtrQueryBuilder(LtrTestUtils.nullLoader())
            .featureSetName("my_set")
            .modelName("my_model")
            .params(templateParams)
            .queryName("test")
            .boost(1);

    // Wrap the LTR query in a bool query and ask for both explanations and feature logging.
    QueryBuilder wrappedQuery = QueryBuilders.boolQuery().must(new WrapperQueryBuilder(ltrQuery.toString()));
    SearchSourceBuilder searchSource = new SearchSourceBuilder()
            .query(wrappedQuery)
            .explain(true)
            .fetchSource(true)
            .size(10)
            .ext(Collections.singletonList(
                    new LoggingSearchExtBuilder().addQueryLogging("log", "test", false)));

    SearchResponse response = client().prepareSearch("test_index").setSource(searchSource).get();
    SearchHit firstHit = response.getHits().getAt(0);

    assertTrue(firstHit.getFields().containsKey("_ltrlog"));
    Map<String, List<Map<String, Object>>> loggedByName = firstHit.getFields().get("_ltrlog").getValue();
    assertTrue(loggedByName.containsKey("log"));
    List<Map<String, Object>> loggedFeatures = loggedByName.get("log");

    // The feature has no value for this hit, so the explanation must mention the NaN default ...
    String featureExplanation = firstHit.getExplanation().getDetails()[0].getDescription();
    assertThat(featureExplanation, containsString("default value of NaN used"));

    // ... the logged entry must carry the feature name but no value ...
    Map<String, Object> firstLogged = loggedFeatures.get(0);
    assertEquals("text_feature1", firstLogged.get("name"));
    assertEquals(null, firstLogged.get("value"));

    // ... and the score must be the leaf the model assigns to its "missing" branch.
    assertEquals(0.2F, firstHit.getScore(), Math.ulp(0.2F));
}

public void testScriptFeatureUseCase() throws Exception {
addElement(new StoredFeature("feature1", Collections.singletonList("query"), "mustache",
Expand Down Expand Up @@ -260,7 +332,7 @@ public void testFullUsecase() throws Exception {

StoredLtrModel model = getElement(StoredLtrModel.class, StoredLtrModel.TYPE, "my_model");
CachesStatsNodesResponse stats = client().execute(CachesStatsAction.INSTANCE,
new CachesStatsAction.CachesStatsNodesRequest()).get();
new CachesStatsAction.CachesStatsNodesRequest()).get();
assertEquals(1, stats.getAll().getTotal().getCount());
assertEquals(model.compile(parserFactory()).ramBytesUsed(), stats.getAll().getTotal().getRam());
assertEquals(1, stats.getAll().getModels().getCount());
Expand Down Expand Up @@ -301,6 +373,4 @@ public void buildIndex() {
.setSource("field1", "hello world", "field2", "bonjour world")
.get();
}


}
3 changes: 2 additions & 1 deletion src/main/java/com/o19s/es/ltr/query/RankerQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
}
featureString += ":";
if (!explain.isMatch()) {
subs.add(Explanation.noMatch(featureString + " [no match, default value 0.0 used]"));
subs.add(Explanation.noMatch(featureString +
String.format(" [no match, default value of %.2f used]", d.getDefaultScore())));
} else {
subs.add(Explanation.match(explain.getValue(), featureString, explain));
d.setFeatureScore(ordinal, explain.getValue().floatValue());
Expand Down
48 changes: 48 additions & 0 deletions src/main/java/com/o19s/es/ltr/ranker/ArrayFeatureVector.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright [2017] Wikimedia Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.o19s.es.ltr.ranker;

import java.util.Arrays;

/**
 * Array-backed {@link LtrRanker.FeatureVector} whose slots hold a configurable default
 * score until a feature score is explicitly set.
 */
public class ArrayFeatureVector implements LtrRanker.FeatureVector {
    /** Feature scores, indexed by feature index. */
    public final float[] scores;
    /** Value every slot holds right after construction and after each {@link #reset()}. */
    public final float defaultScore;

    /**
     * Creates a vector with {@code size} slots, each initialised to {@code value}.
     *
     * @param size  number of feature slots
     * @param value default score stored in every slot
     */
    public ArrayFeatureVector(int size, float value) {
        scores = new float[size];
        defaultScore = value;
        // A fresh float[] is zero-filled; without this fill a non-zero default (e.g. NaN)
        // would only take effect after the first reset(), contradicting getDefaultScore().
        // Arrays.fill is used directly rather than reset() so that no overridable method
        // is invoked from the constructor.
        Arrays.fill(scores, defaultScore);
    }

    /**
     * Stores {@code score} in the slot of the given feature.
     *
     * @param featureIdx index of the feature
     * @param score      value to store
     */
    @Override
    public void setFeatureScore(int featureIdx, float score) {
        scores[featureIdx] = score;
    }

    /**
     * Reads the slot of the given feature.
     *
     * @param featureIdx index of the feature
     * @return the value currently stored for that feature
     */
    @Override
    public float getFeatureScore(int featureIdx) {
        return scores[featureIdx];
    }

    /**
     * Overwrites every slot with the default score so the vector can be reused.
     */
    public void reset() {
        Arrays.fill(scores, defaultScore);
    }

    /**
     * @return the default score this vector was created with
     */
    @Override
    public float getDefaultScore() {
        return defaultScore;
    }
}
29 changes: 2 additions & 27 deletions src/main/java/com/o19s/es/ltr/ranker/DenseFeatureVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,9 @@

package com.o19s.es.ltr.ranker;

import java.util.Arrays;
public class DenseFeatureVector extends ArrayFeatureVector {

/**
* Simple array-backed feature vector
*/
public class DenseFeatureVector implements LtrRanker.FeatureVector {
public final float[] scores;

/**
* New simple array-backed datapoint
*
* @param size size of the internal array
*/
public DenseFeatureVector(int size) {
this.scores = new float[size];
}

@Override
public void setFeatureScore(int featureIdx, float score) {
scores[featureIdx] = score;
}

@Override
public float getFeatureScore(int featureIdx) {
return scores[featureIdx];
}

public void reset() {
Arrays.fill(scores, 0F);
super(size, 0F);
}
}
4 changes: 4 additions & 0 deletions src/main/java/com/o19s/es/ltr/ranker/LogLtrRanker.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ void reset(LtrRanker ranker) {
this.inner = ranker.newFeatureVector(inner);
logger.reset();
}

/**
 * Returns the default score reported by the wrapped {@code inner} vector.
 * NOTE(review): if the enclosing class implements LtrRanker.FeatureVector this method
 * should carry {@code @Override} -- the class header is not visible here, confirm.
 *
 * @return the default score of the wrapped vector
 */
public float getDefaultScore() {
return inner.getDefaultScore();
}
}

public LogConsumer getLogConsumer() {
Expand Down
8 changes: 7 additions & 1 deletion src/main/java/com/o19s/es/ltr/ranker/LtrRanker.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public interface LtrRanker {
/**
* Score the data point.
* At this point all feature scores are set.
* features that did not match are set with a score to 0
* features that did not match are set with a default score
*
* @param point the feature vector point to compute the score for
* @return the score computed for the given point
Expand All @@ -73,5 +73,11 @@ interface FeatureVector {
*/
float getFeatureScore(int featureId);

/**
* Retrieve the default score, i.e. the score used for features that did not match.
* @return the default score of this feature vector
*/
float getDefaultScore();

}
}
25 changes: 25 additions & 0 deletions src/main/java/com/o19s/es/ltr/ranker/SparseFeatureVector.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright [2017] Wikimedia Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.o19s.es.ltr.ranker;

/**
 * Array-backed feature vector whose default score is {@link Float#NaN}, so a feature
 * that was never set can be told apart from one that scored 0.
 */
public class SparseFeatureVector extends ArrayFeatureVector {

    /**
     * Creates a vector in which every slot starts out as NaN.
     *
     * @param size number of feature slots to allocate
     */
    public SparseFeatureVector(final int size) {
        super(size, Float.NaN);
        // Make sure every slot holds the NaN default rather than the 0.0 of a fresh float array.
        this.reset();
    }
}
47 changes: 47 additions & 0 deletions src/main/java/com/o19s/es/ltr/ranker/SparseLtrRanker.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright [2017] Wikimedia Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.o19s.es.ltr.ranker;

/**
 * Base class for rankers that work with a {@link SparseFeatureVector},
 * in which feature scores that were not set hold {@link Float#NaN}.
 */
public abstract class SparseLtrRanker implements LtrRanker {
    /**
     * Supplies a vector ready to receive feature scores: a new one when nothing can be
     * reused, otherwise the given vector after it has been reset.
     *
     * @param reuse a vector previously returned by this method, or {@code null}
     * @return a {@link SparseFeatureVector} with every slot at its default score
     */
    @Override
    public SparseFeatureVector newFeatureVector(FeatureVector reuse) {
        if (reuse == null) {
            return new SparseFeatureVector(size());
        }
        assert reuse instanceof SparseFeatureVector;
        final SparseFeatureVector recycled = (SparseFeatureVector) reuse;
        recycled.reset();
        return recycled;
    }

    /**
     * Casts the generic vector to {@link SparseFeatureVector} and delegates to
     * {@link #score(SparseFeatureVector)}.
     */
    @Override
    public float score(FeatureVector vector) {
        assert vector instanceof SparseFeatureVector;
        final SparseFeatureVector sparse = (SparseFeatureVector) vector;
        return score(sparse);
    }

    /**
     * Computes the score for the given sparse vector.
     *
     * @param vector the feature vector to score
     * @return the computed score
     */
    protected abstract float score(SparseFeatureVector vector);

    /**
     * @return the number of features supported by this ranker
     */
    protected abstract int size();
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

package com.o19s.es.ltr.ranker.dectree;

import com.o19s.es.ltr.ranker.DenseFeatureVector;
import com.o19s.es.ltr.ranker.DenseLtrRanker;
import com.o19s.es.ltr.ranker.SparseFeatureVector;
import com.o19s.es.ltr.ranker.SparseLtrRanker;
import com.o19s.es.ltr.ranker.normalizer.Normalizer;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
Expand All @@ -28,7 +28,7 @@
* Naive implementation of additive decision tree.
* May be slow when the number of trees and the tree complexity are high compared to the number of features.
*/
public class NaiveAdditiveDecisionTree extends DenseLtrRanker implements Accountable {
public class NaiveAdditiveDecisionTree extends SparseLtrRanker implements Accountable {
private static final long BASE_RAM_USED = RamUsageEstimator.shallowSizeOfInstance(Split.class);

private final Node[] trees;
Expand Down Expand Up @@ -60,7 +60,7 @@ public String name() {
}

@Override
protected float score(DenseFeatureVector vector) {
protected float score(SparseFeatureVector vector) {
float sum = 0;
float[] scores = vector.scores;
for (int i = 0; i < trees.length; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import java.util.Arrays;

/**
* Implements FeatureVector but without needing to pass in a stirng
* Implements FeatureVector but without needing to pass in a string
* to be parsed
*/
public class DenseProgramaticDataPoint extends DataPoint implements LtrRanker.FeatureVector {
Expand Down Expand Up @@ -71,4 +71,8 @@ public float getFeatureScore(int featureIdx) {
public void reset() {
Arrays.fill(fVals, 0F);
}

/**
 * Returns the default score of this data point, which is always 0 -- the same value
 * {@code reset()} writes into every feature slot.
 *
 * @return {@code 0.0F}
 */
@Override
public float getDefaultScore() {
    return 0.0F;
}
}
Loading

0 comments on commit d2bbe8b

Please sign in to comment.