Skip to content

Commit

Permalink
Fix rescore nested hits (#367)
Browse files Browse the repository at this point in the history
* quick-and-dirty fix for feature score logging against rescore query without named query

* Workaround nested hits

When nested hits are returned, make sure that we do not try to log them.

closes #357

Co-authored-by: Tomohiro Manabe <[email protected]>
  • Loading branch information
nomoa and tmanabe authored Apr 30, 2021
1 parent 5430a60 commit b03c45e
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 21 deletions.
14 changes: 14 additions & 0 deletions src/javaRestTest/java/com/o19s/es/ltr/logging/LoggingIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,20 @@ public void testLog() throws Exception {
.addRescoreLogging("second_log", 0, true)));
SearchResponse resp3 = client().prepareSearch("test_index").setTypes("test").setSource(sourceBuilder).get();
assertSearchHits(docs, resp3);

query = QueryBuilders.boolQuery().filter(QueryBuilders.idsQuery("test").addIds(ids));
sourceBuilder = new SearchSourceBuilder().query(query)
.fetchSource(false)
.size(10)
.addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder.toString())))
.addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder_rescore.toString())))
.ext(Collections.singletonList(
new LoggingSearchExtBuilder()
.addRescoreLogging("first_log", 0, false)
.addRescoreLogging("second_log", 1, true)));

SearchResponse resp4 = client().prepareSearch("test_index").setTypes("test").setSource(sourceBuilder).get();
assertSearchHits(docs, resp4);
}

public void testLogExtraLogging() throws Exception {
Expand Down
46 changes: 28 additions & 18 deletions src/main/java/com/o19s/es/ltr/logging/LoggingFetchSubPhase.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.o19s.es.ltr.feature.FeatureSet;
import com.o19s.es.ltr.query.RankerQuery;
import com.o19s.es.ltr.ranker.LogLtrRanker;
import com.o19s.es.ltr.utils.Suppliers;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
Expand All @@ -27,6 +28,7 @@
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.search.SearchHit;
Expand All @@ -52,30 +54,29 @@ public FetchSubPhaseProcessor getProcessor(FetchContext context) throws IOExcept
return null;
}

BooleanQuery.Builder builder = new BooleanQuery.Builder();
List<HitLogConsumer> loggers = new ArrayList<>();
Map<String, Query> namedQueries = context.parsedQuery().namedFilters();


if (namedQueries.size() > 0) {
// NOTE: logging is not supported on nested hits, but at this point we cannot yet know
// whether we will run on top-level hits or on nested hits.
// Defer creating the loggers until the hits are seen, checking SearchHit#getNestedIdentity.
CheckedSupplier<Tuple<Weight, List<HitLogConsumer>>, IOException> weigthtAndLogSpecsSupplier = () -> {
List<HitLogConsumer> loggers = new ArrayList<>();
Map<String, Query> namedQueries = context.parsedQuery().namedFilters();
BooleanQuery.Builder builder = new BooleanQuery.Builder();
ext.logSpecsStream().filter((l) -> l.getNamedQuery() != null).forEach((l) -> {
Tuple<RankerQuery, HitLogConsumer> query = extractQuery(l, namedQueries);
builder.add(new BooleanClause(query.v1(), BooleanClause.Occur.MUST));
loggers.add(query.v2());
});

ext.logSpecsStream().filter((l) -> l.getRescoreIndex() != null).forEach((l) -> {
Tuple<RankerQuery, HitLogConsumer> query = extractRescore(l, context.rescore());
builder.add(new BooleanClause(query.v1(), BooleanClause.Occur.MUST));
loggers.add(query.v2());
});
}
Weight w = context.searcher().rewrite(builder.build()).createWeight(context.searcher(), ScoreMode.COMPLETE, 1.0F);
return new Tuple<>(w, loggers);
};



Weight w = context.searcher().rewrite(builder.build()).createWeight(context.searcher(), ScoreMode.COMPLETE, 1.0F);

return new LoggingFetchSubPhaseProcessor(w, loggers);
return new LoggingFetchSubPhaseProcessor(Suppliers.memoizeCheckedSupplier(weigthtAndLogSpecsSupplier));
}

private Tuple<RankerQuery, HitLogConsumer> extractQuery(LoggingSearchExtBuilder.LogSpec
Expand Down Expand Up @@ -127,23 +128,32 @@ private Tuple<RankerQuery, HitLogConsumer> toLogger(LoggingSearchExtBuilder.LogS
return new Tuple<>(query, consumer);
}
static class LoggingFetchSubPhaseProcessor implements FetchSubPhaseProcessor {
private final Weight weight;
private final List<HitLogConsumer> loggers;
private final CheckedSupplier<Tuple<Weight, List<HitLogConsumer>>, IOException> loggersSupplier;
private Scorer scorer;
private LeafReaderContext currentContext;

LoggingFetchSubPhaseProcessor(Weight weight, List<HitLogConsumer> loggers) {
this.weight = weight;
this.loggers = loggers;
LoggingFetchSubPhaseProcessor(CheckedSupplier<Tuple<Weight, List<HitLogConsumer>>, IOException> loggersSupplier) {
this.loggersSupplier = loggersSupplier;
}


@Override
public void setNextReader(LeafReaderContext readerContext) throws IOException {
scorer = weight.scorer(readerContext);
currentContext = readerContext;
scorer = null;
}

@Override
public void process(HitContext hitContext) throws IOException {
if (hitContext.hit().getNestedIdentity() != null) {
// we do not support logging nested docs
return;
}
Tuple<Weight, List<HitLogConsumer>> weightAndLoggers = loggersSupplier.get();
if (scorer == null) {
scorer = weightAndLoggers.v1().scorer(currentContext);
}
List<HitLogConsumer> loggers = weightAndLoggers.v2();
if (scorer != null && scorer.iterator().advance(hitContext.docId()) == hitContext.docId()) {
loggers.forEach((l) -> l.nextDoc(hitContext.hit()));
// Scoring will trigger log collection
Expand Down
25 changes: 25 additions & 0 deletions src/main/java/com/o19s/es/ltr/utils/Suppliers.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.o19s.es.ltr.ranker.LtrRanker;
import org.elasticsearch.Assertions;
import org.elasticsearch.common.CheckedSupplier;

import java.util.Objects;
import java.util.function.Supplier;
Expand Down Expand Up @@ -95,4 +96,28 @@ public void set(LtrRanker.FeatureVector obj) {
super.set(obj);
}
}

/**
 * Wraps {@code supplier} so that its first successfully computed value is cached
 * and returned on every subsequent call. Not safe for concurrent use: no
 * synchronization is performed, so the delegate may run more than once under contention.
 *
 * @param supplier the delegate whose result should be memoized
 * @return a caching view of {@code supplier}
 */
public static <R, E extends Exception> CheckedSupplier<R, E> memoizeCheckedSupplier(CheckedSupplier<R, E> supplier) {
    return new CheckedMemoizeSupplier<>(supplier);
}

/**
 * Lazily computes a value from a delegate {@link CheckedSupplier} and caches it.
 * A {@code null} result counts as "not yet computed", so the delegate is re-invoked
 * until it yields a non-null value. Thread unsafe.
 */
private static class CheckedMemoizeSupplier<R, E extends Exception> implements CheckedSupplier<R, E> {
    // Delegate invoked at most once on the happy path (first non-null result wins).
    private final CheckedSupplier<R, E> delegate;
    // Cached result; null means the delegate has not produced a value yet.
    private R cached;

    private CheckedMemoizeSupplier(CheckedSupplier<R, E> delegate) {
        this.delegate = delegate;
    }

    @Override
    public R get() throws E {
        R result = cached;
        if (result == null) {
            result = delegate.get();
            cached = result;
        }
        return result;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.TestUtil;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.lucene.search.function.CombineFunction;
import org.elasticsearch.common.lucene.search.function.FieldValueFactorFunction;
import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery;
Expand Down Expand Up @@ -114,10 +115,9 @@ public void testLogging() throws IOException {
.add(new BooleanClause(query1, BooleanClause.Occur.MUST))
.add(new BooleanClause(query2, BooleanClause.Occur.MUST))
.build();
LoggingFetchSubPhase subPhase = new LoggingFetchSubPhase();
Weight weight = searcher.createWeight(query, ScoreMode.COMPLETE, 1.0F);
List<LoggingFetchSubPhase.HitLogConsumer> loggers = Arrays.asList(logger1, logger2);
LoggingFetchSubPhaseProcessor processor = new LoggingFetchSubPhaseProcessor(weight, loggers);
LoggingFetchSubPhaseProcessor processor = new LoggingFetchSubPhaseProcessor(() -> new Tuple<>(weight, loggers));

SearchHit[] hits = preprocessRandomHits(processor);
for (SearchHit hit : hits) {
Expand Down Expand Up @@ -195,7 +195,7 @@ public void collect(int doc) throws IOException {
return hits.toArray(new SearchHit[0]);
}

public static Document buildDoc(String text, float value) throws IOException {
public static Document buildDoc(String text, float value) {
String id = UUID.randomUUID().toString();
Document d = new Document();
d.add(newStringField("id", id, Field.Store.YES));
Expand Down

0 comments on commit b03c45e

Please sign in to comment.