Skip to content

Commit

Permalink
Display correct pooling mode in model card comment
Browse files Browse the repository at this point in the history
  • Loading branch information
aphedges committed Aug 14, 2021
1 parent 61e1d82 commit 25a4ed9
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def _create_model_card(self, path: str, model_name: Optional[str] = None):
pooling_mode = pooling_module.get_pooling_mode_str()
model_card = model_card.replace("{USAGE_TRANSFORMERS_SECTION}", ModelCardTemplate.__USAGE_TRANSFORMERS__)
pooling_fct_name, pooling_fct = ModelCardTemplate.model_card_get_pooling_function(pooling_mode)
model_card = model_card.replace("{POOLING_FUNCTION}", pooling_fct).replace("{POOLING_FUNCTION_NAME}", pooling_fct_name)
model_card = model_card.replace("{POOLING_FUNCTION}", pooling_fct).replace("{POOLING_FUNCTION_NAME}", pooling_fct_name).replace("{POOLING_MODE}", pooling_mode)
tags.append('transformers')

# Print full model
Expand Down
2 changes: 1 addition & 1 deletion sentence_transformers/model_card_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class ModelCardTemplate:
with torch.no_grad():
model_output = model(**encoded_input)
# Perform pooling. In this case, max pooling.
# Perform pooling. In this case, {POOLING_MODE} pooling.
sentence_embeddings = {POOLING_FUNCTION_NAME}(model_output, encoded_input['attention_mask'])
print("Sentence embeddings:")
Expand Down

0 comments on commit 25a4ed9

Please sign in to comment.