diff --git a/src/javaRestTest/java/com/o19s/es/ltr/ShardStatsIT.java b/src/javaRestTest/java/com/o19s/es/ltr/ShardStatsIT.java
new file mode 100644
index 00000000..304aca04
--- /dev/null
+++ b/src/javaRestTest/java/com/o19s/es/ltr/ShardStatsIT.java
@@ -0,0 +1,136 @@
+package com.o19s.es.ltr;
+
+import com.o19s.es.explore.ExplorerQueryBuilder;
+import com.o19s.es.termstat.TermStatQueryBuilder;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.search.SearchType;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.TermQueryBuilder;
+import org.elasticsearch.search.SearchHits;
+import org.elasticsearch.test.ESIntegTestCase;
+
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
+import static org.hamcrest.Matchers.equalTo;
+
+/**
+ * These tests mostly verify that shard vs collection stat counting is working as expected.
+ * <p>
+ * Four identical documents are indexed into a two-shard index under two routing keys. The
+ * expected scores assume the two keys land on different shards, so that a single shard sees a
+ * document frequency of 2 while the whole collection sees 4.
+ */
+public class ShardStatsIT extends ESIntegTestCase {
+    @Override
+    protected int numberOfShards() {
+        return 2;
+    }
+
+    /** Creates the "idx" index with a text field "s", indexes four documents and refreshes. */
+    protected void createIdx() {
+        // The builder has to be executed; without get() the create-index request is never sent.
+        prepareCreate("idx")
+                .addMapping("type", "s", "type=text")
+                .get();
+
+        for (int i = 0; i < 4; i++) {
+            indexDoc(i);
+        }
+        refreshIndex();
+    }
+
+    /** Indexes one document with s="zzz"; even ids use routing "a", odd ids use routing "b". */
+    protected void indexDoc(int id) {
+        client().prepareIndex("idx", "type", Integer.toString(id))
+                .setRouting( ((id % 2) == 0 ) ? "a" : "b" )
+                .setSource("s", "zzz").get();
+    }
+
+    /** Issues a refresh request on the "idx" index. */
+    protected void refreshIndex() {
+        client().admin().indices().prepareRefresh("idx").get();
+    }
+
+    /** With DFS_QUERY_THEN_FETCH the explorer's min_raw_df is expected to be 4.0. */
+    public void testDfsExplorer() throws Exception {
+        createIdx();
+
+        QueryBuilder q = new TermQueryBuilder("s", "zzz");
+
+        ExplorerQueryBuilder eq = new ExplorerQueryBuilder()
+                .query(q)
+                .statsType("min_raw_df");
+
+        final SearchResponse r = client().prepareSearch("idx")
+                .setSearchType(SearchType.DFS_QUERY_THEN_FETCH)
+                .setQuery(eq).get();
+
+        assertSearchResponse(r);
+
+        SearchHits hits = r.getHits();
+        assertThat(hits.getAt(0).getScore(), equalTo(4.0f));
+    }
+
+    /** With QUERY_THEN_FETCH the explorer's min_raw_df is expected to be 2.0. */
+    public void testNonDfsExplorer() throws Exception {
+        createIdx();
+
+        QueryBuilder q = new TermQueryBuilder("s", "zzz");
+
+        ExplorerQueryBuilder eq = new ExplorerQueryBuilder()
+                .query(q)
+                .statsType("min_raw_df");
+
+        final SearchResponse r = client().prepareSearch("idx")
+                .setSearchType(SearchType.QUERY_THEN_FETCH)
+                .setQuery(eq).get();
+
+        assertSearchResponse(r);
+
+        SearchHits hits = r.getHits();
+        assertThat(hits.getAt(0).getScore(), equalTo(2.0f));
+    }
+
+    /** With DFS_QUERY_THEN_FETCH the term stat query's min df is expected to be 4.0. */
+    public void testDfsTSQ() throws Exception {
+        createIdx();
+
+        TermStatQueryBuilder tsq = new TermStatQueryBuilder()
+                .expr("df")
+                .aggr("min")
+                .posAggr("min")
+                .terms(new String[]{"zzz"})
+                .fields(new String[]{"s"});
+
+        final SearchResponse r = client().prepareSearch("idx")
+                .setSearchType(SearchType.DFS_QUERY_THEN_FETCH)
+                .setQuery(tsq)
+                .get();
+
+        assertSearchResponse(r);
+
+        SearchHits hits = r.getHits();
+        assertThat(hits.getAt(0).getScore(), equalTo(4.0f));
+    }
+
+    /** With QUERY_THEN_FETCH the term stat query's min df is expected to be 2.0. */
+    public void testNonDfsTSQ() throws Exception {
+        createIdx();
+
+        TermStatQueryBuilder tsq = new TermStatQueryBuilder()
+                .expr("df")
+                .aggr("min")
+                .posAggr("min")
+                .terms(new String[]{"zzz"})
+                .fields(new String[]{"s"});
+
+        final SearchResponse r = client().prepareSearch("idx")
+                .setSearchType(SearchType.QUERY_THEN_FETCH)
+                .setQuery(tsq)
+                .get();
+
+        assertSearchResponse(r);
+
+        SearchHits hits = r.getHits();
+        assertThat(hits.getAt(0).getScore(), equalTo(2.0f));
+    }
+}
diff --git a/src/main/java/com/o19s/es/explore/ExplorerQuery.java b/src/main/java/com/o19s/es/explore/ExplorerQuery.java
index a64fb3e5..d0ef4324 100644
--- a/src/main/java/com/o19s/es/explore/ExplorerQuery.java
+++ b/src/main/java/com/o19s/es/explore/ExplorerQuery.java
@@ -24,6 +24,7 @@
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.Weight;
+import org.apache.lucene.search.TermStatistics;
 import org.apache.lucene.search.ConstantScoreWeight;
 import org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.Scorer;
@@ -104,11 +105,15 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
 
             for (Term term : terms) {
                 TermStates ctx = TermStates.build(searcher.getTopReaderContext(), term, scoreMode.needsScores());
-
-                if(ctx != null){
-                    df_stats.add(ctx.docFreq());
-                    idf_stats.add(sim.idf(ctx.docFreq(), searcher.collectionStatistics(term.field()).docCount()));
-                    ttf_stats.add(ctx.totalTermFreq());
+                if(ctx != null && ctx.docFreq() > 0){
+                    TermStatistics tStats = searcher.termStatistics(term, ctx.docFreq(), ctx.totalTermFreq());
+                    df_stats.add(tStats.docFreq());
+                    idf_stats.add(sim.idf(tStats.docFreq(), searcher.collectionStatistics(term.field()).docCount()));
+                    ttf_stats.add(tStats.totalTermFreq());
+                } else {
+                    df_stats.add(0.0f);
+                    idf_stats.add(0.0f);
+                    ttf_stats.add(0.0f);
                 }
             }
 
diff --git a/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java b/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java
index ba995d38..77eb2823 100644
--- a/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java
+++ b/src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java
@@ -12,6 +12,7 @@
 import org.apache.lucene.analysis.tokenattributes.TermToBytesRefAttribute;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.Term;
+import
org.apache.lucene.index.TermStates;
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.IndexSearcher;
@@ -286,18 +287,39 @@ static class LtrScriptWeight extends Weight {
         private final ScriptScoreFunction function;
         private final TermStatSupplier termStatSupplier;
         private final Set<Term> terms;
+        private final HashMap<Term, TermStates> termContexts;
 
+        /**
+         * Stores the scoring inputs and, when scores are needed, records the {@link TermStates}
+         * of every term in {@code termContexts} for later use by the scorer.
+         *
+         * @throws IOException if building the term states or reading the statistics fails
+         */
         LtrScriptWeight(Query query, ScriptScoreFunction function,
                         TermStatSupplier termStatSupplier,
                         Set<Term> terms,
                         IndexSearcher searcher,
-                        ScoreMode scoreMode) {
+                        ScoreMode scoreMode) throws IOException {
             super(query);
             this.function = function;
             this.termStatSupplier = termStatSupplier;
             this.terms = terms;
             this.searcher = searcher;
             this.scoreMode = scoreMode;
+            this.termContexts = new HashMap<>();
+
+            // Build the per-term states once and touch the searcher statistics up front;
+            // TermStatWeight does the same for proper DFS_QUERY_THEN_FETCH support.
+            if (scoreMode.needsScores()) {
+                for (Term t : terms) {
+                    TermStates ctx = TermStates.build(searcher.getTopReaderContext(), t, true);
+                    if (ctx != null && ctx.docFreq() > 0) {
+                        searcher.collectionStatistics(t.field());
+                        searcher.termStatistics(t, ctx.docFreq(), ctx.totalTermFreq());
+                    }
+                    termContexts.put(t, ctx);
+                }
+            }
         }
 
         @Override
@@ -319,7 +341,7 @@ public int docID() {
                 public float score() throws IOException {
                     // Do the terms magic if the user asked for it
                     if (terms.size() > 0) {
-                        termStatSupplier.bump(searcher, context, docID(), terms, scoreMode);
+                        termStatSupplier.bump(searcher, context, docID(), terms, scoreMode, termContexts);
                     }
 
                     return (float) leafScoreFunction.score(iterator.docID(), 0F);
diff --git a/src/main/java/com/o19s/es/termstat/TermStatQuery.java b/src/main/java/com/o19s/es/termstat/TermStatQuery.java
index 926826dd..07c854d2 100644
--- a/src/main/java/com/o19s/es/termstat/TermStatQuery.java
+++ b/src/main/java/com/o19s/es/termstat/TermStatQuery.java
@@ -6,6 +6,7 @@
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.Term;
+import org.apache.lucene.index.TermStates;
 import
org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.Query;
@@ -14,6 +15,8 @@
 import org.apache.lucene.search.Weight;
 
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 
@@ -81,8 +84,20 @@ static class TermStatWeight extends Weight {
         private final AggrType aggr;
         private final AggrType posAggr;
         private final Set<Term> terms;
-
-        TermStatWeight(IndexSearcher searcher, TermStatQuery tsq, Set<Term> terms, ScoreMode scoreMode, AggrType aggr, AggrType posAggr) {
+        private final Map<Term, TermStates> termContexts;
+
+        /**
+         * Captures the query settings and, when scores are needed, records the {@link TermStates}
+         * of every term in {@code termContexts} so they can be handed to the scorers.
+         *
+         * @throws IOException if building the term states or reading the statistics fails
+         */
+        TermStatWeight(IndexSearcher searcher,
+                       TermStatQuery tsq,
+                       Set<Term> terms,
+                       ScoreMode scoreMode,
+                       AggrType aggr,
+                       AggrType posAggr) throws IOException {
             super(tsq);
             this.searcher = searcher;
             this.expression = tsq.expr;
@@ -90,6 +105,22 @@ static class TermStatWeight extends Weight {
             this.scoreMode = scoreMode;
             this.aggr = aggr;
             this.posAggr = posAggr;
+            this.termContexts = new HashMap<>();
+
+            // Build the per-term states once and touch the searcher statistics up front.
+            // This is needed for proper DFS_QUERY_THEN_FETCH support
+            if (scoreMode.needsScores()) {
+                for (Term t : terms) {
+                    TermStates ctx = TermStates.build(searcher.getTopReaderContext(), t, true);
+
+                    if (ctx != null && ctx.docFreq() > 0) {
+                        searcher.collectionStatistics(t.field());
+                        searcher.termStatistics(t, ctx.docFreq(), ctx.totalTermFreq());
+                    }
+
+                    termContexts.put(t, ctx);
+                }
+            }
         }
 
         @Override
@@ -110,7 +141,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
 
         @Override
         public Scorer scorer(LeafReaderContext context) throws IOException {
-            return new TermStatScorer(this, searcher, context, expression, terms, scoreMode, aggr, posAggr);
+            return new TermStatScorer(this, searcher, context, expression, terms, scoreMode, aggr, posAggr, termContexts);
         }
 
         @Override
diff --git a/src/main/java/com/o19s/es/termstat/TermStatScorer.java b/src/main/java/com/o19s/es/termstat/TermStatScorer.java
index 3669e0b5..c8ddd3fc 100644
---
a/src/main/java/com/o19s/es/termstat/TermStatScorer.java +++ b/src/main/java/com/o19s/es/termstat/TermStatScorer.java @@ -7,6 +7,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; +import org.apache.lucene.index.TermStates; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DoubleValues; import org.apache.lucene.search.DoubleValuesSource; @@ -16,6 +17,7 @@ import java.io.IOException; import java.util.HashMap; +import java.util.Map; import java.util.Set; public class TermStatScorer extends Scorer { @@ -29,6 +31,7 @@ public class TermStatScorer extends Scorer { private final IndexSearcher searcher; private final Set terms; private final ScoreMode scoreMode; + private final Map termContexts; public TermStatScorer(TermStatQuery.TermStatWeight weight, IndexSearcher searcher, @@ -37,7 +40,8 @@ public TermStatScorer(TermStatQuery.TermStatWeight weight, Set terms, ScoreMode scoreMode, AggrType aggr, - AggrType posAggr) { + AggrType posAggr, + Map termContexts) { super(weight); this.context = context; this.compiledExpression = compiledExpression; @@ -46,6 +50,7 @@ public TermStatScorer(TermStatQuery.TermStatWeight weight, this.scoreMode = scoreMode; this.aggr = aggr; this.posAggr = posAggr; + this.termContexts = termContexts; this.iter = DocIdSetIterator.all(context.reader().maxDoc()); } @@ -65,7 +70,7 @@ public float score() throws IOException { // Refresh the term stats tsq.setPosAggr(posAggr); - tsq.bump(searcher, context, docID(), terms, scoreMode); + tsq.bump(searcher, context, docID(), terms, scoreMode, termContexts); // Prepare computed statistics StatisticsHelper computed = new StatisticsHelper(); diff --git a/src/main/java/com/o19s/es/termstat/TermStatSupplier.java b/src/main/java/com/o19s/es/termstat/TermStatSupplier.java index 04de9342..50d87bac 100644 --- a/src/main/java/com/o19s/es/termstat/TermStatSupplier.java +++ b/src/main/java/com/o19s/es/termstat/TermStatSupplier.java @@ -13,6 +13,7 
@@ import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.TermStatistics; import org.apache.lucene.search.similarities.ClassicSimilarity; import java.io.IOException; @@ -22,6 +23,7 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.Set; @@ -47,7 +49,7 @@ public TermStatSupplier() { public void bump (IndexSearcher searcher, LeafReaderContext context, int docID, Set terms, - ScoreMode scoreMode) throws IOException { + ScoreMode scoreMode, Map termContexts) throws IOException { df_stats.getData().clear(); idf_stats.getData().clear(); tf_stats.getData().clear(); @@ -61,22 +63,24 @@ public void bump (IndexSearcher searcher, LeafReaderContext context, break; } - TermStates termStates = TermStates.build(searcher.getTopReaderContext(), term, scoreMode.needsScores()); + TermStates termStates = termContexts.get(term); assert termStates != null && termStates .wasBuiltFor(ReaderUtil.getTopLevelContext(context)); TermState state = termStates.get(context); - if (state == null) { + if (state == null || termStates.docFreq() == 0) { insertZeroes(); // Zero out stats for terms we don't know about in the index continue; } + TermStatistics indexStats = searcher.termStatistics(term, termStates.docFreq(), termStates.totalTermFreq()); + // Collection Statistics - df_stats.add(termStates.docFreq()); - idf_stats.add(sim.idf(termStates.docFreq(), searcher.collectionStatistics(term.field()).docCount())); - ttf_stats.add(termStates.totalTermFreq()); + df_stats.add(indexStats.docFreq()); + idf_stats.add(sim.idf(indexStats.docFreq(), searcher.collectionStatistics(term.field()).docCount())); + ttf_stats.add(indexStats.totalTermFreq()); // Doc specifics TermsEnum termsEnum = context.reader().terms(term.field()).iterator();