support of T5 model in models.Transformer
nreimers committed Jan 19, 2022
1 parent d5b0115 commit 06a38f6
Showing 2 changed files with 17 additions and 94 deletions.
91 changes: 0 additions & 91 deletions sentence_transformers/models/T5.py

This file was deleted.

20 changes: 17 additions & 3 deletions sentence_transformers/models/Transformer.py
@@ -1,5 +1,5 @@
 from torch import nn
-from transformers import AutoModel, AutoTokenizer, AutoConfig
+from transformers import AutoModel, AutoTokenizer, AutoConfig, T5Config
 import json
 from typing import List, Dict, Optional, Union, Tuple
 import os
@@ -26,7 +26,8 @@ def __init__(self, model_name_or_path: str, max_seq_length: Optional[int] = None
         self.do_lower_case = do_lower_case

         config = AutoConfig.from_pretrained(model_name_or_path, **model_args, cache_dir=cache_dir)
-        self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir)
+        self._load_model(model_name_or_path, config, cache_dir)
+
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path, cache_dir=cache_dir, **tokenizer_args)

         #No max_seq_length set. Try to infer from model
@@ -39,6 +40,20 @@ def __init__(self, model_name_or_path: str, max_seq_length: Optional[int] = None
         if tokenizer_name_or_path is not None:
             self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__

+
+    def _load_model(self, model_name_or_path, config, cache_dir):
+        """Loads the transformer model"""
+        if isinstance(config, T5Config):
+            self._load_t5_model(model_name_or_path, config, cache_dir)
+        else:
+            self.auto_model = AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir)
+
+    def _load_t5_model(self, model_name_or_path, config, cache_dir):
+        """Loads the encoder model from T5"""
+        from transformers import T5EncoderModel
+        T5EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]
+        self.auto_model = T5EncoderModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir)
+
     def __repr__(self):
         return "Transformer({}) with Transformer model: {} ".format(self.get_config_dict(), self.auto_model.__class__.__name__)

@@ -95,7 +110,6 @@ def tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]]):
         if self.do_lower_case:
             to_tokenize = [[s.lower() for s in col] for col in to_tokenize]

-
         output.update(self.tokenizer(*to_tokenize, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_seq_length))
         return output

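With this change, a T5 checkpoint passed to models.Transformer is loaded with T5EncoderModel, so only the encoder weights are kept and the decoder weights are skipped. A minimal usage sketch; the checkpoint name t5-small, max_seq_length=256 and mean pooling below are illustrative choices, not part of this commit:

from sentence_transformers import SentenceTransformer, models

# models.Transformer now dispatches on T5Config and loads only the T5 encoder,
# ignoring the unexpected "decoder.*" weights.
word_embedding_model = models.Transformer("t5-small", max_seq_length=256)

# T5 provides no pooled output, so pool the token embeddings explicitly.
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode="mean")

model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
embeddings = model.encode(["T5 encoders can now be used for sentence embeddings."])
print(embeddings.shape)  # (1, hidden_size), e.g. (1, 512) for t5-small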
