From 25a4ed9306284fdb4f4fe5def2da1639542ea032 Mon Sep 17 00:00:00 2001 From: Alex Hedges Date: Fri, 13 Aug 2021 20:49:57 -0400 Subject: [PATCH] Display correct pooling mode in model card comment --- sentence_transformers/SentenceTransformer.py | 2 +- sentence_transformers/model_card_templates.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index 07fc2a6ca..c12a67d63 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -386,7 +386,7 @@ def _create_model_card(self, path: str, model_name: Optional[str] = None): pooling_mode = pooling_module.get_pooling_mode_str() model_card = model_card.replace("{USAGE_TRANSFORMERS_SECTION}", ModelCardTemplate.__USAGE_TRANSFORMERS__) pooling_fct_name, pooling_fct = ModelCardTemplate.model_card_get_pooling_function(pooling_mode) - model_card = model_card.replace("{POOLING_FUNCTION}", pooling_fct).replace("{POOLING_FUNCTION_NAME}", pooling_fct_name) + model_card = model_card.replace("{POOLING_FUNCTION}", pooling_fct).replace("{POOLING_FUNCTION_NAME}", pooling_fct_name).replace("{POOLING_MODE}", pooling_mode) tags.append('transformers') # Print full model diff --git a/sentence_transformers/model_card_templates.py b/sentence_transformers/model_card_templates.py index ac96c22b7..c8a15d28a 100644 --- a/sentence_transformers/model_card_templates.py +++ b/sentence_transformers/model_card_templates.py @@ -105,7 +105,7 @@ class ModelCardTemplate: with torch.no_grad(): model_output = model(**encoded_input) -# Perform pooling. In this case, max pooling. +# Perform pooling. In this case, {POOLING_MODE} pooling. sentence_embeddings = {POOLING_FUNCTION_NAME}(model_output, encoded_input['attention_mask']) print("Sentence embeddings:")