Skip to content

Commit

Permalink
epam#1826 - Pagination and sorting for ElasticStream
Browse files Browse the repository at this point in the history
  • Loading branch information
uladkaminski committed Mar 15, 2024
1 parent fb2f17e commit 55acb51
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
import com.epam.indigo.model.Helpers;
import com.epam.indigo.model.IndigoRecord;
import com.epam.indigo.model.NamingConstants;
import com.epam.indigo.predicate.*;
import com.epam.indigo.predicate.BaseMatch;
import com.epam.indigo.predicate.ExactMatch;
import com.epam.indigo.predicate.FilterPredicate;
import com.epam.indigo.predicate.IndigoPredicate;
import com.epam.indigo.predicate.SubstructureMatch;
import com.epam.indigo.sort.IndigoComparator;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.RequestOptions;
Expand All @@ -15,13 +20,29 @@
import org.elasticsearch.script.Script;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.SortOrder;

import java.io.IOException;
import java.util.*;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Spliterator;
import java.util.function.*;
import java.util.stream.*;
import java.util.stream.Collector;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import java.util.stream.Stream;

import static com.epam.indigo.model.NamingConstants.*;
import static com.epam.indigo.model.NamingConstants.SIM_FINGERPRINT;
import static com.epam.indigo.model.NamingConstants.SIM_FINGERPRINT_LEN;
import static com.epam.indigo.model.NamingConstants.SUB_FINGERPRINT;
import static com.epam.indigo.model.NamingConstants.SUB_FINGERPRINT_LEN;

/**
* Implementation of JDK Stream API
Expand All @@ -32,8 +53,10 @@ public class ElasticStream<T extends IndigoRecord> implements Stream<T> {
private final RestHighLevelClient elasticClient;
private final List<IndigoPredicate<? super T>> predicates = new ArrayList<>();
private final String indexName;
private int size = 10;
private final int MAX_ALLOWED_SIZE = 1000;
private int limit = Integer.MAX_VALUE;
private final List<IndigoComparator<? super T>> comparators = new ArrayList<>();

private static final int BATCH_SIZE = 10000;

public ElasticStream(RestHighLevelClient elasticClient, String indexName) {
this.elasticClient = elasticClient;
Expand All @@ -51,9 +74,9 @@ public Stream<T> filter(Predicate<? super T> predicate) {

/**
 * Sets the maximum number of records this stream returns and hands back the same stream.
 * <p>
 * NOTE(review): this span is a rendered diff with old and new lines interleaved. The
 * MAX_ALLOWED_SIZE / size lines look like the removed version and the Integer.MAX_VALUE / limit
 * lines like the added one; confirm against the real file before relying on either.
 * <p>
 * NOTE(review): only the upper bound is checked here; a negative maxSize is not rejected by
 * this method and reaches the narrowing cast unchanged.
 *
 * @param maxSize upper bound on the number of records to return
 * @return this stream
 * @throws IllegalArgumentException if maxSize is larger than the permitted maximum
 */
@Override
public Stream<T> limit(long maxSize) {
// NOTE(review): the "%1" in the message below carries no conversion character; the "%1$d"
// spelling used a few lines further down is the well-formed one.
if (maxSize > MAX_ALLOWED_SIZE)
throw new IllegalArgumentException(String.format("Bingo Elastic max page size should be less than or equal to %1", MAX_ALLOWED_SIZE));
this.size = (int) maxSize;
if (maxSize > Integer.MAX_VALUE)
throw new IllegalArgumentException(String.format("Bingo Elastic max page size should be less than or equal to %1$d", Integer.MAX_VALUE));
// The cast cannot overflow upwards: values above Integer.MAX_VALUE were rejected just above.
this.limit = (int) maxSize;
return this;
}

Expand All @@ -65,26 +88,51 @@ public boolean isParallel() {
/**
 * Runs the compiled search against Elasticsearch, converts every hit into a record and feeds
 * it to the collector's accumulator, then returns the collector's finished result.
 * <p>
 * NOTE(review): this span is a rendered diff with old and new lines interleaved. The leading
 * part (single compileRequest() call, per-index for-loops) looks like the removed
 * implementation and the while-loop with searchAfterParameters like the added one; confirm
 * against the real file.
 *
 * @param collector collector whose supplier, accumulator and finisher are used
 * @return the value produced by the collector's finisher
 * @throws BingoElasticException if the Elasticsearch request fails with an IOException
 */
@Override
public <R, A> R collect(Collector<? super T, A, R> collector) {
A container = collector.supplier().get();
SearchRequest searchRequest = compileRequest();
SearchHit[] hits;
try {
SearchResponse searchResponse = this.elasticClient.search(searchRequest, RequestOptions.DEFAULT);
hits = searchResponse.getHits().getHits();
if (NamingConstants.BINGO_REACTIONS.equals(this.indexName)) {
for (SearchHit hit : hits) {
collector.accumulator().accept(container, (T) Helpers.reactionFromElastic(hit.getId(), hit.getSourceAsMap(), hit.getScore()));
}
} else if (NamingConstants.BINGO_MOLECULES.equals(this.indexName)) {
for (SearchHit hit : hits) {
collector.accumulator().accept(container, (T) Helpers.moleculeFromElastic(hit.getId(), hit.getSourceAsMap(), hit.getScore()));
// Sort values of the last hit of the previous page; null for the first request.
Object[] searchAfterParameters = null;

long processedRecords = 0;
boolean continueSearch = true;

while (continueSearch) {
// Request a full batch unless fewer records remain under the configured limit.
int currentBatchSize = (int) Math.min(BATCH_SIZE, limit - processedRecords);
SearchRequest searchRequest = compileRequest(searchAfterParameters, currentBatchSize);
SearchResponse searchResponse;
try {
searchResponse = elasticClient.search(searchRequest, RequestOptions.DEFAULT);
} catch (IOException e) {
throw new BingoElasticException("Couldn't complete search in Elasticsearch", e);
}

SearchHit[] hits = searchResponse.getHits().getHits();
// An empty page means there is nothing left to read.
if (hits.length == 0) {
break;
}

for (SearchHit hit : hits) {
// Stop accumulating as soon as the limit has been reached.
if (processedRecords >= limit) {
break;
}
} else {
throw new BingoElasticException("Unsupported index " + this.indexName);
T record = convertHitToRecord(hit);
collector.accumulator().accept(container, record);
processedRecords++;
}
} catch (IOException e) {
throw new BingoElasticException("Couldn't complete search in Elasticsearch", e.getCause());

// The last hit's sort values are passed to the next compileRequest call as the paging cursor.
searchAfterParameters = hits[hits.length - 1].getSortValues();
// Another page is requested only when a comparator was registered and this page came back full.
continueSearch = !this.comparators.isEmpty() && hits.length == BATCH_SIZE;
}

return collector.finisher().apply(container);
}


/**
 * Converts a single search hit into a record, choosing the conversion helper by the index
 * this stream reads from: Helpers.reactionFromElastic for BINGO_REACTIONS and
 * Helpers.moleculeFromElastic for BINGO_MOLECULES. The hit's id, source map and score are
 * passed to the helper.
 *
 * @param hit search hit to convert
 * @return the helper's result cast to T
 * @throws BingoElasticException if the index name matches neither supported index
 */
private T convertHitToRecord(SearchHit hit) {
if (NamingConstants.BINGO_REACTIONS.equals(this.indexName)) {
return (T) Helpers.reactionFromElastic(hit.getId(), hit.getSourceAsMap(), hit.getScore());
} else if (NamingConstants.BINGO_MOLECULES.equals(this.indexName)) {
return (T) Helpers.moleculeFromElastic(hit.getId(), hit.getSourceAsMap(), hit.getScore());
} else {
throw new BingoElasticException("Unsupported index " + this.indexName);
}
// NOTE(review): the next line looks like a removed diff line left over from the previous
// collect() implementation (there is no R or container in this method); confirm.
return (R) container;
}

private QueryBuilder[] generateClauses(List<Integer> fingerprint, String field) {
Expand All @@ -95,15 +143,26 @@ private QueryBuilder[] generateClauses(List<Integer> fingerprint, String field)
return bits;
}

private SearchRequest compileRequest() {
private SearchRequest compileRequest(Object[] searchAfterParameters, int batchSize) {
SearchRequest searchRequest = new SearchRequest(this.indexName);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();

boolean similarityRequested = false;
boolean isEmptyFingerprint = false;

searchSourceBuilder.size(batchSize);

if (!comparators.isEmpty()) {
comparators.stream().map(IndigoComparator::toSortBuilder).forEach(searchSourceBuilder::sort);
searchSourceBuilder.sort(new FieldSortBuilder("_doc").order(SortOrder.ASC));
}


if (this.predicates.isEmpty()) {
searchSourceBuilder.query(QueryBuilders.matchAllQuery());
} else {
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();

Script script = null;
float threshold = 0.0f;
for (IndigoPredicate<? super T> predicate : this.predicates) {
Expand Down Expand Up @@ -141,13 +200,26 @@ private SearchRequest compileRequest() {
}
searchSourceBuilder.fetchSource(new String[]{"*"}, new String[]{SIM_FINGERPRINT, SIM_FINGERPRINT_LEN, SUB_FINGERPRINT_LEN, SUB_FINGERPRINT});
searchSourceBuilder.minScore(threshold);
searchSourceBuilder.size(this.size);
searchSourceBuilder.query(QueryBuilders.scriptScoreQuery(boolQueryBuilder, script));
}

if (searchAfterParameters != null) {
searchSourceBuilder.searchAfter(searchAfterParameters);
}

searchRequest.source(searchSourceBuilder);
return searchRequest;
}

/**
 * Registers a sort criterion for this stream and returns the same stream.
 * <p>
 * Only {@link IndigoComparator} instances are accepted: when the search request is compiled,
 * each registered comparator is turned into a sort builder via
 * {@link IndigoComparator#toSortBuilder()}. Comparators are kept in the order in which they
 * were added.
 *
 * @param comparator sort criterion; must be an {@link IndigoComparator}
 * @return this stream
 * @throws IllegalArgumentException if {@code comparator} is {@code null} or is not an
 *                                  {@link IndigoComparator}
 */
@Override
public ElasticStream<T> sorted(Comparator<? super T> comparator) {
    // instanceof is false for null, so null is rejected here as well.
    if (!(comparator instanceof IndigoComparator)) {
        throw new IllegalArgumentException("Comparator used isn't an IndigoComparator");
    }
    // Parameterized cast instead of a raw-type cast: the instanceof check above guarantees the
    // runtime class, and the comparator's element type is already constrained to "? super T".
    @SuppressWarnings("unchecked")
    IndigoComparator<? super T> indigoComparator = (IndigoComparator<? super T>) comparator;
    comparators.add(indigoComparator);
    return this;
}

private Script generateIdentityScore() {
Map<String, Object> map = new HashMap<>();
map.put("source", "_score");
Expand Down Expand Up @@ -255,11 +327,6 @@ public Stream<T> sorted() {
throw new BingoElasticException("sorted() operation on this stream isn't implemented");
}

@Override
public Stream<T> sorted(Comparator<? super T> comparator) {
throw new BingoElasticException("sorted() operation on this stream isn't implemented");
}

@Override
public Stream<T> peek(Consumer<? super T> action) {
throw new BingoElasticException("peek() operation on this stream isn't implemented");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.epam.indigo.sort;

import com.epam.indigo.model.IndigoRecord;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.SortBuilder;
import org.elasticsearch.search.sort.SortOrder;

/**
 * Sort criterion that orders search results by a single named field.
 * <p>
 * The comparator is a description of a sort, not an in-memory ordering: it is converted into
 * a {@link FieldSortBuilder}, while {@link #compare} is a placeholder that always returns 0.
 *
 * @param <T> record type the criterion applies to
 */
public class FieldComparator<T extends IndigoRecord> extends IndigoComparator<T> {

    /** Name of the field handed to the {@link FieldSortBuilder}. */
    protected String fieldName;

    /**
     * @param fieldName name of the field to sort by
     * @param sortOrder direction of the sort
     */
    public FieldComparator(String fieldName, SortOrder sortOrder) {
        super(sortOrder);
        this.fieldName = fieldName;
    }

    /**
     * Builds the field sort for the configured field name and order.
     *
     * @return a new field sort builder with the configured order applied
     */
    @Override
    public SortBuilder<FieldSortBuilder> toSortBuilder() {
        FieldSortBuilder fieldSort = new FieldSortBuilder(fieldName);
        return fieldSort.order(sortOrder);
    }

    /**
     * Placeholder required by the Comparator interface; not expected to be called.
     *
     * @return always 0
     */
    @Override
    public int compare(T o1, T o2) {
        return 0;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.epam.indigo.sort;

import com.epam.indigo.model.IndigoRecord;
import org.elasticsearch.search.sort.SortBuilder;
import org.elasticsearch.search.sort.SortOrder;

import java.util.Comparator;


/**
 * Base class for comparators that describe a search sort criterion.
 * <p>
 * Subclasses supply the actual sort definition through {@link #toSortBuilder()}; the sort
 * direction given at construction time is available to them via {@link #sortOrder}.
 *
 * @param <T> record type the criterion applies to
 */
public abstract class IndigoComparator<T extends IndigoRecord> implements Comparator<T> {

    /** Direction of the sort, as passed to the constructor. */
    protected SortOrder sortOrder;

    /**
     * @param sortOrder direction of the sort
     */
    public IndigoComparator(SortOrder sortOrder) {
        this.sortOrder = sortOrder;
    }

    /**
     * Produces the sort definition represented by this comparator.
     * <p>
     * Declared with a wildcard instead of the raw type so callers are not pushed into
     * unchecked operations; subclasses may still narrow the return type to a concrete
     * {@code SortBuilder<...>} parameterization.
     *
     * @return the sort builder for this criterion
     */
    public abstract SortBuilder<?> toSortBuilder();

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.epam.indigo.sort;

import com.epam.indigo.model.IndigoRecord;
import org.elasticsearch.search.sort.ScoreSortBuilder;
import org.elasticsearch.search.sort.SortBuilder;
import org.elasticsearch.search.sort.SortOrder;

/**
 * Sort criterion that orders search results by score.
 * <p>
 * The comparator is a description of a sort, not an in-memory ordering: it is converted into
 * a {@link ScoreSortBuilder}, while {@link #compare} is a placeholder that always returns 0.
 *
 * @param <T> record type the criterion applies to
 */
public class ScoreComparator<T extends IndigoRecord> extends IndigoComparator<T> {

    /** Creates a score sort in descending order. */
    public ScoreComparator() {
        this(SortOrder.DESC);
    }

    /**
     * @param sortOrder direction of the score sort
     */
    public ScoreComparator(SortOrder sortOrder) {
        super(sortOrder);
    }

    /**
     * Placeholder required by the Comparator interface; not expected to be called.
     *
     * @return always 0
     */
    @Override
    public int compare(T o1, T o2) {
        return 0;
    }

    /**
     * Builds the score sort with the configured order.
     *
     * @return a new score sort builder with the configured order applied
     */
    @Override
    public SortBuilder<ScoreSortBuilder> toSortBuilder() {
        ScoreSortBuilder scoreSort = new ScoreSortBuilder();
        return scoreSort.order(sortOrder);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ public void rangeQueryWithTanimoto() {
.filter(new RangeQuery<>(fieldName, 10, 100))
.collect(Collectors.toList());

assertEquals(Math.min(10, cnt), similarRecords.size());
assertEquals(cnt, similarRecords.size());
} catch (Exception exception) {
Assertions.fail("Exception happened during test " + exception.getMessage());
}
Expand Down Expand Up @@ -272,11 +272,11 @@ public void reactionTanimoto() {
}

@Test
@DisplayName("Page size of 2000 should throw exception")
@DisplayName("Page size of Integer.MAX_VALUE should throw exception")
public void pageSizeOverLimit() {
assertThrows(IllegalArgumentException.class, () -> repository.stream()
.filter(new KeywordQuery<>("test", "test"))
.limit(2000)
.limit((long) Integer.MAX_VALUE + 1)
.collect(Collectors.toList()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ public void saveFromSdfFile() {
Helpers.iterateSdf("src/test/resources/rand_queries_small.sdf").forEach(indigoRecordList::add);
repository.indexRecords(indigoRecordList, indigoRecordList.size());
TimeUnit.SECONDS.sleep(5);
List<IndigoRecordMolecule> collect = repository.stream().collect(Collectors.toList());
assertEquals(10, collect.size());
List<IndigoRecordMolecule> fullCollection = repository.stream().collect(Collectors.toList());
List<IndigoRecordMolecule> limitCollection = repository.stream().limit(20).collect(Collectors.toList());
assertEquals(20, limitCollection.size());
assertEquals(371, fullCollection.size());
} catch (Exception exception) {
Assertions.fail(exception);
}
Expand All @@ -93,7 +95,7 @@ public void saveFromCmlFile() {
repository.indexRecords(indigoRecordList, indigoRecordList.size());
TimeUnit.SECONDS.sleep(5);
List<IndigoRecord> collect = repository.stream().collect(Collectors.toList());
assertEquals(10, collect.size());
assertEquals(163, collect.size());
} catch (Exception exception) {
Assertions.fail(exception);
}
Expand Down

0 comments on commit 55acb51

Please sign in to comment.