
Commit

add a batched mode to RerankingEvaluator.py
nreimers committed Dec 15, 2021
1 parent d2f0293 commit d2f2e99
Showing 1 changed file with 73 additions and 1 deletion.
74 changes: 73 additions & 1 deletion sentence_transformers/evaluation/RerankingEvaluator.py
@@ -20,13 +20,14 @@ class RerankingEvaluator(SentenceEvaluator):
:param samples: Must be a list and each element is of the form: {'query': '', 'positive': [], 'negative': []}. Query is the search query,
positive is a list of positive (relevant) documents, negative is a list of negative (irrelevant) documents.
"""
def __init__(self, samples, mrr_at_k: int = 10, name: str = '', write_csv: bool = True, similarity_fct=cos_sim, batch_size: int = 64, show_progress_bar: bool = False):
def __init__(self, samples, mrr_at_k: int = 10, name: str = '', write_csv: bool = True, similarity_fct=cos_sim, batch_size: int = 64, show_progress_bar: bool = False, use_batched_encoding: bool = True):
self.samples = samples
self.name = name
self.mrr_at_k = mrr_at_k
self.similarity_fct = similarity_fct
self.batch_size = batch_size
self.show_progress_bar = show_progress_bar
self.use_batched_encoding = use_batched_encoding

if isinstance(self.samples, dict):
self.samples = list(self.samples.values())
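
For context, here is a minimal usage sketch of the new flag. This example is not part of the commit; the model name and sample contents are purely illustrative.

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import RerankingEvaluator

# One sample in the format the docstring describes (illustrative data).
samples = [{
    'query': 'What is the capital of France?',
    'positive': ['Paris is the capital of France.'],
    'negative': ['Berlin is the capital of Germany.'],
}]

model = SentenceTransformer('all-MiniLM-L6-v2')  # any SentenceTransformer model works here

# use_batched_encoding=True (the default) encodes all queries and documents together;
# set it to False to encode one sample at a time and save memory.
evaluator = RerankingEvaluator(samples, mrr_at_k=10, use_batched_encoding=True)
mean_ap = evaluator(model)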
@@ -79,6 +80,77 @@ def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int =
return mean_ap

def compute_metrices(self, model):
return self.compute_metrices_batched(model) if self.use_batched_encoding else self.compute_metrices_individual(model)

def compute_metrices_batched(self, model):
"""
Computes the metrics in a batched fashion by encoding all queries and
all documents together.
"""
all_mrr_scores = []
all_ap_scores = []

all_query_embs = model.encode([sample['query'] for sample in self.samples],
convert_to_tensor=True,
batch_size=self.batch_size,
show_progress_bar=self.show_progress_bar)

all_docs = []

for sample in self.samples:
all_docs.extend(sample['positive'])
all_docs.extend(sample['negative'])

all_docs_embs = model.encode(all_docs,
convert_to_tensor=True,
batch_size=self.batch_size,
show_progress_bar=self.show_progress_bar)

# Compute scores
query_idx, docs_idx = 0, 0
for instance in self.samples:
query_emb = all_query_embs[query_idx]
query_idx += 1

num_pos = len(instance['positive'])
num_neg = len(instance['negative'])
docs_emb = all_docs_embs[docs_idx:docs_idx+num_pos+num_neg]
docs_idx += num_pos+num_neg

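# Skip samples that lack positives or negatives; the metrics are undefined or trivial for them.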
if num_pos == 0 or num_neg == 0:
continue

pred_scores = self.similarity_fct(query_emb, docs_emb)
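# similarity_fct may return a [1, num_docs] matrix for a single query; keep only the first row.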
if len(pred_scores.shape) > 1:
pred_scores = pred_scores[0]

pred_scores_argsort = torch.argsort(-pred_scores)  # Sort in decreasing order

# Compute MRR@k: reciprocal rank of the first relevant document within the top k (0 if none)
is_relevant = [True]*num_pos + [False]*num_neg
mrr_score = 0
for rank, index in enumerate(pred_scores_argsort[0:self.mrr_at_k]):
if is_relevant[index]:
mrr_score = 1 / (rank+1)
break
all_mrr_scores.append(mrr_score)

# Compute AP
all_ap_scores.append(average_precision_score(is_relevant, pred_scores.cpu().tolist()))

mean_ap = np.mean(all_ap_scores)
mean_mrr = np.mean(all_mrr_scores)

return {'map': mean_ap, 'mrr': mean_mrr}


def compute_metrices_individual(self, model):
"""
Embeds every (query, positive, negative) tuple individually.
Slower than the batched version, but saves memory because only the
embeddings for one tuple are needed at a time. Useful for very large
test sets.
"""
all_mrr_scores = []
all_ap_scores = []

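The remainder of compute_metrices_individual is collapsed in this diff view. As a rough sketch only (assumed for illustration, not the commit's actual code), a per-sample loop could mirror the batched logic above while encoding one tuple at a time:

# Hypothetical sketch (not from this commit): encode and score one sample at a time.
for instance in self.samples:
    num_pos = len(instance['positive'])
    num_neg = len(instance['negative'])
    if num_pos == 0 or num_neg == 0:
        continue

    query_emb = model.encode([instance['query']], convert_to_tensor=True,
                             batch_size=self.batch_size, show_progress_bar=False)
    docs_emb = model.encode(instance['positive'] + instance['negative'],
                            convert_to_tensor=True, batch_size=self.batch_size,
                             show_progress_bar=False)

    pred_scores = self.similarity_fct(query_emb, docs_emb)
    if len(pred_scores.shape) > 1:
        pred_scores = pred_scores[0]
    # MRR@k and AP would then be computed exactly as in the batched method above.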
