Skip to content

Commit

Permalink
Doc freq issue (#380)
Browse files Browse the repository at this point in the history
* workaround for issue #375

* fixes issue #381

* Restore nonVocab test, provide 0 scores for nonVocab terms

* Patch up the TermSupplier

* Cleanup whitespace change

* Fix params

* Add some tests to verify new logic, add shard param for explorer

* Fix failing IT

* Cleanup unneeded param

* More cleanup

* Move TermStates building to createWeight to satisfy DFS logic

Co-authored-by: = <[email protected]>
  • Loading branch information
ndkmath1 and worleydl authored Jul 6, 2021
1 parent 7f1be48 commit 74cde49
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 18 deletions.
123 changes: 123 additions & 0 deletions src/javaRestTest/java/com/o19s/es/ltr/ShardStatsIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package com.o19s.es.ltr;

import com.o19s.es.explore.ExplorerQueryBuilder;
import com.o19s.es.termstat.TermStatQueryBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.test.ESIntegTestCase;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
import static org.hamcrest.Matchers.equalTo;

/*
 Integration tests verifying that shard-local vs collection-wide statistic
 counting behaves as expected for the explorer and term-stat queries.

 Four docs containing the term "zzz" are routed evenly onto two shards
 (two docs per shard). With DFS_QUERY_THEN_FETCH the doc frequency is the
 collection-wide value (4); with plain QUERY_THEN_FETCH each shard reports
 only its local value (2).
*/
public class ShardStatsIT extends ESIntegTestCase {
    @Override
    protected int numberOfShards() {
        return 2;
    }

    /** Creates the "idx" index with a single text field and four routed docs. */
    protected void createIdx() {
        prepareCreate("idx")
                .addMapping("type", "s", "type=text");

        for (int docId = 0; docId < 4; docId++) {
            indexDoc(docId);
        }
        refreshIndex();
    }

    /** Indexes one doc; even ids route to "a", odd ids to "b", splitting docs across shards. */
    protected void indexDoc(int id) {
        client().prepareIndex("idx", "type", Integer.toString(id))
                .setRouting( ((id % 2) == 0 ) ? "a" : "b" )
                .setSource("s", "zzz").get();
    }

    protected void refreshIndex() {
        client().admin().indices().prepareRefresh("idx").get();
    }

    // Builds the explorer query used by both explorer tests.
    private ExplorerQueryBuilder minRawDfExplorer() {
        QueryBuilder termQuery = new TermQueryBuilder("s", "zzz");
        return new ExplorerQueryBuilder()
                .query(termQuery)
                .statsType("min_raw_df");
    }

    // Builds the term-stat query used by both TSQ tests.
    private TermStatQueryBuilder minDfTermStatQuery() {
        return new TermStatQueryBuilder()
                .expr("df")
                .aggr("min")
                .posAggr("min")
                .terms(new String[]{"zzz"})
                .fields(new String[]{"s"});
    }

    // Executes a search against "idx" with the given search type and query.
    private SearchResponse runSearch(SearchType searchType, QueryBuilder query) {
        return client().prepareSearch("idx")
                .setSearchType(searchType)
                .setQuery(query)
                .get();
    }

    // Asserts the search succeeded and the top hit scored as expected.
    private void assertTopScore(SearchResponse response, float expectedScore) {
        assertSearchResponse(response);
        SearchHits searchHits = response.getHits();
        assertThat(searchHits.getAt(0).getScore(), equalTo(expectedScore));
    }

    public void testDfsExplorer() throws Exception {
        createIdx();

        // DFS gathers collection-wide stats: df across both shards is 4.
        SearchResponse response = runSearch(SearchType.DFS_QUERY_THEN_FETCH, minRawDfExplorer());
        assertTopScore(response, 4.0f);
    }

    public void testNonDfsExplorer() throws Exception {
        createIdx();

        // Without DFS each shard sees only its local df of 2.
        SearchResponse response = runSearch(SearchType.QUERY_THEN_FETCH, minRawDfExplorer());
        assertTopScore(response, 2.0f);
    }

    public void testDfsTSQ() throws Exception {
        createIdx();

        // DFS gathers collection-wide stats: df across both shards is 4.
        SearchResponse response = runSearch(SearchType.DFS_QUERY_THEN_FETCH, minDfTermStatQuery());
        assertTopScore(response, 4.0f);
    }

    public void testNonDfsTSQ() throws Exception {
        createIdx();

        // Without DFS each shard sees only its local df of 2.
        SearchResponse response = runSearch(SearchType.QUERY_THEN_FETCH, minDfTermStatQuery());
        assertTopScore(response, 2.0f);
    }
}
15 changes: 10 additions & 5 deletions src/main/java/com/o19s/es/explore/ExplorerQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Scorer;
Expand Down Expand Up @@ -104,11 +105,15 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo

for (Term term : terms) {
TermStates ctx = TermStates.build(searcher.getTopReaderContext(), term, scoreMode.needsScores());

if(ctx != null){
df_stats.add(ctx.docFreq());
idf_stats.add(sim.idf(ctx.docFreq(), searcher.collectionStatistics(term.field()).docCount()));
ttf_stats.add(ctx.totalTermFreq());
if(ctx != null && ctx.docFreq() > 0){
TermStatistics tStats = searcher.termStatistics(term, ctx.docFreq(), ctx.totalTermFreq());
df_stats.add(tStats.docFreq());
idf_stats.add(sim.idf(tStats.docFreq(), searcher.collectionStatistics(term.field()).docCount()));
ttf_stats.add(tStats.totalTermFreq());
} else {
df_stats.add(0.0f);
idf_stats.add(0.0f);
ttf_stats.add(0.0f);
}
}

Expand Down
18 changes: 16 additions & 2 deletions src/main/java/com/o19s/es/ltr/feature/store/ScriptFeature.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.apache.lucene.analysis.tokenattributes.TermToBytesRefAttribute;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
Expand Down Expand Up @@ -286,18 +287,31 @@ static class LtrScriptWeight extends Weight {
private final ScriptScoreFunction function;
private final TermStatSupplier termStatSupplier;
private final Set<Term> terms;
private final HashMap<Term, TermStates> termContexts;

/**
 * Builds the weight and eagerly resolves {@link TermStates} for every term the
 * script may ask stats for.
 *
 * Building the per-term contexts here (weight-creation time) rather than at
 * scoring time is what makes DFS_QUERY_THEN_FETCH work: per the commit notes,
 * TermStates construction was moved to createWeight "to satisfy DFS logic".
 *
 * @throws IOException if term lookup against the index fails
 */
LtrScriptWeight(Query query, ScriptScoreFunction function,
TermStatSupplier termStatSupplier,
Set<Term> terms,
IndexSearcher searcher,
ScoreMode scoreMode) throws IOException {
super(query);
this.function = function;
this.termStatSupplier = termStatSupplier;
this.terms = terms;
this.searcher = searcher;
this.scoreMode = scoreMode;
this.termContexts = new HashMap<>();

// Only resolve stats when scores are actually needed.
if (scoreMode.needsScores()) {
for (Term t : terms) {
TermStates ctx = TermStates.build(searcher.getTopReaderContext(), t, true);
// Skip stats registration for terms absent from the index (df == 0).
if (ctx != null && ctx.docFreq() > 0) {
// Return values are intentionally discarded: these calls are made for
// their side effect on the searcher — presumably registering the
// collection/term statistics needed for the DFS phase. TODO confirm
// against the DFS-aware IndexSearcher implementation.
searcher.collectionStatistics(t.field());
searcher.termStatistics(t, ctx.docFreq(), ctx.totalTermFreq());
}
// NOTE: ctx may be null here; consumers (TermStatSupplier.bump) must
// handle a null/empty context by emitting zeroed stats.
termContexts.put(t, ctx);
}
}
}

@Override
Expand All @@ -319,7 +333,7 @@ public int docID() {
public float score() throws IOException {
// Do the terms magic if the user asked for it
if (terms.size() > 0) {
termStatSupplier.bump(searcher, context, docID(), terms, scoreMode);
termStatSupplier.bump(searcher, context, docID(), terms, scoreMode, termContexts);
}

return (float) leafScoreFunction.score(iterator.docID(), 0F);
Expand Down
30 changes: 27 additions & 3 deletions src/main/java/com/o19s/es/termstat/TermStatQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
Expand All @@ -14,6 +15,8 @@
import org.apache.lucene.search.Weight;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

Expand Down Expand Up @@ -81,15 +84,36 @@ static class TermStatWeight extends Weight {
private final AggrType aggr;
private final AggrType posAggr;
private final Set<Term> terms;

TermStatWeight(IndexSearcher searcher, TermStatQuery tsq, Set<Term> terms, ScoreMode scoreMode, AggrType aggr, AggrType posAggr) {
private final Map<Term, TermStates> termContexts;

/**
 * Builds the weight for a term-stat query and eagerly resolves
 * {@link TermStates} for every requested term.
 *
 * @param searcher  searcher used to build term contexts and fetch statistics
 * @param tsq       the owning query; supplies the stat expression to evaluate
 * @param terms     terms whose statistics feed the expression
 * @param scoreMode whether scores are required; contexts are only built when they are
 * @param aggr      aggregation applied across terms
 * @param posAggr   aggregation applied across term positions
 * @throws IOException if term lookup against the index fails
 */
TermStatWeight(IndexSearcher searcher,
TermStatQuery tsq,
Set<Term> terms,
ScoreMode scoreMode,
AggrType aggr,
AggrType posAggr) throws IOException {
super(tsq);
this.searcher = searcher;
this.expression = tsq.expr;
this.terms = terms;
this.scoreMode = scoreMode;
this.aggr = aggr;
this.posAggr = posAggr;
this.termContexts = new HashMap<>();

// This is needed for proper DFS_QUERY_THEN_FETCH support
if (scoreMode.needsScores()) {
for (Term t : terms) {
TermStates ctx = TermStates.build(searcher.getTopReaderContext(), t, true);

// Only touch searcher stats for terms that actually occur in the index.
if (ctx != null && ctx.docFreq() > 0) {
// Return values intentionally discarded: called for their side effect
// on the (DFS-aware) searcher — presumably registering the global
// collection/term statistics. TODO confirm against the searcher impl.
searcher.collectionStatistics(t.field());
searcher.termStatistics(t, ctx.docFreq(), ctx.totalTermFreq());
}

// ctx may be null for unknown terms; TermStatSupplier zeroes those out.
termContexts.put(t, ctx);
}
}
}

@Override
Expand All @@ -110,7 +134,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio

@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
return new TermStatScorer(this, searcher, context, expression, terms, scoreMode, aggr, posAggr);
return new TermStatScorer(this, searcher, context, expression, terms, scoreMode, aggr, posAggr, termContexts);
}

@Override
Expand Down
9 changes: 7 additions & 2 deletions src/main/java/com/o19s/es/termstat/TermStatScorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;

import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.DoubleValues;
import org.apache.lucene.search.DoubleValuesSource;
Expand All @@ -16,6 +17,7 @@

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

public class TermStatScorer extends Scorer {
Expand All @@ -29,6 +31,7 @@ public class TermStatScorer extends Scorer {
private final IndexSearcher searcher;
private final Set<Term> terms;
private final ScoreMode scoreMode;
private final Map<Term, TermStates> termContexts;

public TermStatScorer(TermStatQuery.TermStatWeight weight,
IndexSearcher searcher,
Expand All @@ -37,7 +40,8 @@ public TermStatScorer(TermStatQuery.TermStatWeight weight,
Set<Term> terms,
ScoreMode scoreMode,
AggrType aggr,
AggrType posAggr) {
AggrType posAggr,
Map<Term, TermStates> termContexts) {
super(weight);
this.context = context;
this.compiledExpression = compiledExpression;
Expand All @@ -46,6 +50,7 @@ public TermStatScorer(TermStatQuery.TermStatWeight weight,
this.scoreMode = scoreMode;
this.aggr = aggr;
this.posAggr = posAggr;
this.termContexts = termContexts;

this.iter = DocIdSetIterator.all(context.reader().maxDoc());
}
Expand All @@ -65,7 +70,7 @@ public float score() throws IOException {

// Refresh the term stats
tsq.setPosAggr(posAggr);
tsq.bump(searcher, context, docID(), terms, scoreMode);
tsq.bump(searcher, context, docID(), terms, scoreMode, termContexts);

// Prepare computed statistics
StatisticsHelper computed = new StatisticsHelper();
Expand Down
16 changes: 10 additions & 6 deletions src/main/java/com/o19s/es/termstat/TermStatSupplier.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.ClassicSimilarity;

import java.io.IOException;
Expand All @@ -22,6 +23,7 @@
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;


Expand All @@ -47,7 +49,7 @@ public TermStatSupplier() {

public void bump (IndexSearcher searcher, LeafReaderContext context,
int docID, Set<Term> terms,
ScoreMode scoreMode) throws IOException {
ScoreMode scoreMode, Map<Term, TermStates> termContexts) throws IOException {
df_stats.getData().clear();
idf_stats.getData().clear();
tf_stats.getData().clear();
Expand All @@ -61,22 +63,24 @@ public void bump (IndexSearcher searcher, LeafReaderContext context,
break;
}

TermStates termStates = TermStates.build(searcher.getTopReaderContext(), term, scoreMode.needsScores());
TermStates termStates = termContexts.get(term);

assert termStates != null && termStates
.wasBuiltFor(ReaderUtil.getTopLevelContext(context));

TermState state = termStates.get(context);

if (state == null) {
if (state == null || termStates.docFreq() == 0) {
insertZeroes(); // Zero out stats for terms we don't know about in the index
continue;
}

TermStatistics indexStats = searcher.termStatistics(term, termStates.docFreq(), termStates.totalTermFreq());

// Collection Statistics
df_stats.add(termStates.docFreq());
idf_stats.add(sim.idf(termStates.docFreq(), searcher.collectionStatistics(term.field()).docCount()));
ttf_stats.add(termStates.totalTermFreq());
df_stats.add(indexStats.docFreq());
idf_stats.add(sim.idf(indexStats.docFreq(), searcher.collectionStatistics(term.field()).docCount()));
ttf_stats.add(indexStats.totalTermFreq());

// Doc specifics
TermsEnum termsEnum = context.reader().terms(term.field()).iterator();
Expand Down

0 comments on commit 74cde49

Please sign in to comment.