diff --git a/src/javaRestTest/java/com/o19s/es/ltr/logging/LoggingIT.java b/src/javaRestTest/java/com/o19s/es/ltr/logging/LoggingIT.java index e213a997..5c8b62dd 100644 --- a/src/javaRestTest/java/com/o19s/es/ltr/logging/LoggingIT.java +++ b/src/javaRestTest/java/com/o19s/es/ltr/logging/LoggingIT.java @@ -261,6 +261,20 @@ public void testLog() throws Exception { .addRescoreLogging("second_log", 0, true))); SearchResponse resp3 = client().prepareSearch("test_index").setTypes("test").setSource(sourceBuilder).get(); assertSearchHits(docs, resp3); + + query = QueryBuilders.boolQuery().filter(QueryBuilders.idsQuery("test").addIds(ids)); + sourceBuilder = new SearchSourceBuilder().query(query) + .fetchSource(false) + .size(10) + .addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder.toString()))) + .addRescorer(new QueryRescorerBuilder(new WrapperQueryBuilder(sbuilder_rescore.toString()))) + .ext(Collections.singletonList( + new LoggingSearchExtBuilder() + .addRescoreLogging("first_log", 0, false) + .addRescoreLogging("second_log", 1, true))); + + SearchResponse resp4 = client().prepareSearch("test_index").setTypes("test").setSource(sourceBuilder).get(); + assertSearchHits(docs, resp4); } public void testLogExtraLogging() throws Exception { diff --git a/src/main/java/com/o19s/es/ltr/logging/LoggingFetchSubPhase.java b/src/main/java/com/o19s/es/ltr/logging/LoggingFetchSubPhase.java index 48d61a5e..c402dcd8 100644 --- a/src/main/java/com/o19s/es/ltr/logging/LoggingFetchSubPhase.java +++ b/src/main/java/com/o19s/es/ltr/logging/LoggingFetchSubPhase.java @@ -19,6 +19,7 @@ import com.o19s.es.ltr.feature.FeatureSet; import com.o19s.es.ltr.query.RankerQuery; import com.o19s.es.ltr.ranker.LogLtrRanker; +import com.o19s.es.ltr.utils.Suppliers; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; @@ -27,6 +28,7 @@ import org.apache.lucene.search.ScoreMode; import 
org.apache.lucene.search.Scorer; import org.apache.lucene.search.Weight; +import org.elasticsearch.common.CheckedSupplier; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.document.DocumentField; import org.elasticsearch.search.SearchHit; @@ -52,30 +54,29 @@ public FetchSubPhaseProcessor getProcessor(FetchContext context) throws IOExcept return null; } - BooleanQuery.Builder builder = new BooleanQuery.Builder(); - List loggers = new ArrayList<>(); - Map namedQueries = context.parsedQuery().namedFilters(); - - - if (namedQueries.size() > 0) { + // NOTE: we do not support logging on nested hits but sadly at this point we cannot know + // if we are going to run on top level hits or nested hits. + // Defer creation of the loggers until we see the hits and can check SearchHit#getNestedIdentity + CheckedSupplier>, IOException> weightAndLogSpecsSupplier = () -> { + List loggers = new ArrayList<>(); + Map namedQueries = context.parsedQuery().namedFilters(); + BooleanQuery.Builder builder = new BooleanQuery.Builder(); ext.logSpecsStream().filter((l) -> l.getNamedQuery() != null).forEach((l) -> { Tuple query = extractQuery(l, namedQueries); builder.add(new BooleanClause(query.v1(), BooleanClause.Occur.MUST)); loggers.add(query.v2()); }); - ext.logSpecsStream().filter((l) -> l.getRescoreIndex() != null).forEach((l) -> { Tuple query = extractRescore(l, context.rescore()); builder.add(new BooleanClause(query.v1(), BooleanClause.Occur.MUST)); loggers.add(query.v2()); }); - } + Weight w = context.searcher().rewrite(builder.build()).createWeight(context.searcher(), ScoreMode.COMPLETE, 1.0F); + return new Tuple<>(w, loggers); + }; - - Weight w = context.searcher().rewrite(builder.build()).createWeight(context.searcher(), ScoreMode.COMPLETE, 1.0F); - - return new LoggingFetchSubPhaseProcessor(w, loggers); + return new LoggingFetchSubPhaseProcessor(Suppliers.memoizeCheckedSupplier(weightAndLogSpecsSupplier)); } private Tuple
extractQuery(LoggingSearchExtBuilder.LogSpec @@ -127,23 +128,32 @@ private Tuple toLogger(LoggingSearchExtBuilder.LogS return new Tuple<>(query, consumer); } static class LoggingFetchSubPhaseProcessor implements FetchSubPhaseProcessor { - private final Weight weight; - private final List loggers; + private final CheckedSupplier>, IOException> loggersSupplier; private Scorer scorer; + private LeafReaderContext currentContext; - LoggingFetchSubPhaseProcessor(Weight weight, List loggers) { - this.weight = weight; - this.loggers = loggers; + LoggingFetchSubPhaseProcessor(CheckedSupplier>, IOException> loggersSupplier) { + this.loggersSupplier = loggersSupplier; } @Override public void setNextReader(LeafReaderContext readerContext) throws IOException { - scorer = weight.scorer(readerContext); + currentContext = readerContext; + scorer = null; } @Override public void process(HitContext hitContext) throws IOException { + if (hitContext.hit().getNestedIdentity() != null) { + // we do not support logging nested docs + return; + } + Tuple> weightAndLoggers = loggersSupplier.get(); + if (scorer == null) { + scorer = weightAndLoggers.v1().scorer(currentContext); + } + List loggers = weightAndLoggers.v2(); if (scorer != null && scorer.iterator().advance(hitContext.docId()) == hitContext.docId()) { loggers.forEach((l) -> l.nextDoc(hitContext.hit())); // Scoring will trigger log collection diff --git a/src/main/java/com/o19s/es/ltr/utils/Suppliers.java b/src/main/java/com/o19s/es/ltr/utils/Suppliers.java index 25fbdfa3..371680f9 100644 --- a/src/main/java/com/o19s/es/ltr/utils/Suppliers.java +++ b/src/main/java/com/o19s/es/ltr/utils/Suppliers.java @@ -18,6 +18,7 @@ import com.o19s.es.ltr.ranker.LtrRanker; import org.elasticsearch.Assertions; +import org.elasticsearch.common.CheckedSupplier; import java.util.Objects; import java.util.function.Supplier; @@ -95,4 +96,31 @@ public void set(LtrRanker.FeatureVector obj) { super.set(obj); } } + + /** + * Memoize the return value of the
checked supplier, including a {@code null} result (not thread-safe); if the delegate throws, nothing is cached and the next call retries + */ + public static <R, E extends Exception> CheckedSupplier<R, E> memoizeCheckedSupplier(CheckedSupplier<R, E> supplier) { + return new CheckedMemoizeSupplier<>(supplier); + } + + private static class CheckedMemoizeSupplier<R, E extends Exception> implements CheckedSupplier<R, E> { + private final CheckedSupplier<R, E> supplier; + private R value; + // explicit flag rather than a null check on value, so that a null result is cached too + private boolean computed; + + private CheckedMemoizeSupplier(CheckedSupplier<R, E> supplier) { + this.supplier = supplier; + } + + @Override + public R get() throws E { + if (!computed) { + value = supplier.get(); + computed = true; + } + return value; + } + } } diff --git a/src/test/java/com/o19s/es/ltr/logging/LoggingFetchSubPhaseTests.java b/src/test/java/com/o19s/es/ltr/logging/LoggingFetchSubPhaseTests.java index 996a8b74..4e953160 100644 --- a/src/test/java/com/o19s/es/ltr/logging/LoggingFetchSubPhaseTests.java +++ b/src/test/java/com/o19s/es/ltr/logging/LoggingFetchSubPhaseTests.java @@ -44,6 +44,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.TestUtil; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.lucene.search.function.CombineFunction; import org.elasticsearch.common.lucene.search.function.FieldValueFactorFunction; import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery; @@ -114,10 +115,9 @@ public void testLogging() throws IOException { .add(new BooleanClause(query1, BooleanClause.Occur.MUST)) .add(new BooleanClause(query2, BooleanClause.Occur.MUST)) .build(); - LoggingFetchSubPhase subPhase = new LoggingFetchSubPhase(); Weight weight = searcher.createWeight(query, ScoreMode.COMPLETE, 1.0F); List loggers = Arrays.asList(logger1, logger2); - LoggingFetchSubPhaseProcessor processor = new LoggingFetchSubPhaseProcessor(weight, loggers); + LoggingFetchSubPhaseProcessor processor = new LoggingFetchSubPhaseProcessor(() -> new Tuple<>(weight, loggers)); SearchHit[] hits = preprocessRandomHits(processor); for (SearchHit hit : hits) { @@ -195,7 +195,7 @@ public void collect(int doc)
throws IOException { return hits.toArray(new SearchHit[0]); } - public static Document buildDoc(String text, float value) throws IOException { + public static Document buildDoc(String text, float value) { String id = UUID.randomUUID().toString(); Document d = new Document(); d.add(newStringField("id", id, Field.Store.YES));