Skip to content

Commit

Permalink
Add Asymmetric Model
Browse files Browse the repository at this point in the history
  • Loading branch information
nreimers committed Dec 26, 2020
1 parent 555fa6d commit fdc8bfa
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 83 deletions.
68 changes: 4 additions & 64 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,81 +391,21 @@ def smart_batching_collate(self, batch):
sentence_features = []
for idx in range(num_texts):
tokenized = self.tokenize(texts[idx])
for name in tokenized:
tokenized[name] = tokenized[name].to(self._target_device)
batch_to_device(tokenized, self._target_device)
sentence_features.append(tokenized)

return sentence_features, labels

"""
num_texts = len(batch[0][0])
labels = []
paired_texts = [[] for _ in range(num_texts)]
max_seq_len = [0] * num_texts
for tokens, label in batch:
labels.append(label)
for i in range(num_texts):
paired_texts[i].append(tokens[i])
max_seq_len[i] = max(max_seq_len[i], self._text_length(tokens[i]))
features = []
for idx in range(num_texts):
max_len = max_seq_len[idx]
feature_lists = {}
for text in paired_texts[idx]:
sentence_features = self.get_sentence_features(text, max_len)
for feature_name in sentence_features:
if feature_name not in feature_lists:
feature_lists[feature_name] = []
feature_lists[feature_name].append(sentence_features[feature_name])
for feature_name in feature_lists:
feature_lists[feature_name] = torch.cat(feature_lists[feature_name])
features.append(feature_lists)
return {'features': features, 'labels': torch.stack(labels)}
"""

def smart_batching_collate_text_only(self, batch):
    """
    Collates a batch of texts from a SmartBatchingDataset into model features.

    The longest text in the batch (as measured by ``self._text_length``) sets
    the length handed to ``self.get_sentence_features`` for every text. The
    per-text features are then grouped by feature name and joined with
    ``torch.cat``.

    :param batch:
        a batch from a SmartBatchingDataset; here, a list of texts
    :return:
        a dict mapping each feature name to the concatenated tensor for the batch
    """
    longest = max(self._text_length(entry) for entry in batch)

    # Gather the per-text features under their feature name, in batch order.
    collected = {}
    for entry in batch:
        per_text = self.get_sentence_features(entry, longest)
        for key in per_text:
            collected.setdefault(key, []).append(per_text[key])

    return {key: torch.cat(chunks) for key, chunks in collected.items()}

def _text_length(self, text: Union[List[int], List[List[int]]]):
"""
Help function to get the length for the input text. Text can be either
a list of ints (which means a single text as input), or a tuple of list of ints
(representing several text inputs to the model).
"""
if len(text) == 0 or isinstance(text[0], int):
if isinstance(text, dict):
return len(next(iter(text.values())))
elif len(text) == 0 or isinstance(text[0], int):
return len(text)
else:
return sum([len(t) for t in text])
Expand Down
107 changes: 107 additions & 0 deletions sentence_transformers/models/Asym.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import torch
from torch import Tensor
from torch import nn
from torch import functional as F
from typing import Union, Tuple, List, Iterable, Dict
import os
import json
from ..util import fullname, import_from_string
from collections import OrderedDict

class Asym(nn.Sequential):
    def __init__(self, sub_modules: Dict[str, List[nn.Module]], allow_empty_key: bool = True):
        """
        This model allows to create asymmetric SentenceTransformer models, that apply different models depending on the specified input key.

        In the below example, we create two different Dense models for 'query' and 'doc'. Text that is passed as {'query': 'My query'} will
        be passed along the first Dense model, and text that will be passed as {'doc': 'My document'} will use the other Dense model.

        Note, that when you call encode(), that only inputs of the same type can be encoded. Mixed-Types cannot be encoded.

        Example::
            word_embedding_model = models.Transformer(model_name)
            pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
            asym_model = models.Asym({'query': [models.Dense(word_embedding_model.get_word_embedding_dimension(), 128)], 'doc': [models.Dense(word_embedding_model.get_word_embedding_dimension(), 128)]})
            model = SentenceTransformer(modules=[word_embedding_model, pooling_model, asym_model])

            model.encode([{'query': 'Q1'}, {'query': 'Q2'}])
            model.encode([{'doc': 'Doc1'}, {'doc': 'Doc2'}])

            #You can train it with InputExample like this. Each text is its own single-key dict. Note, that the order must always be the same:
            train_example = InputExample(texts=[{'query': 'Train query'}, {'doc': 'Document'}], label=1)

        :param sub_modules: Dict in the format str -> List[models]. The models in the specified list will be applied for input marked with the respective key. A single model (not wrapped in a list) is accepted as well.
        :param allow_empty_key: If true, inputs without a key can be processed. If false, an exception will be thrown if no key is specified.
        """
        # Normalise every entry to a list once, and keep the normalised mapping:
        # forward() and save() iterate over self.sub_modules[key], so a bare
        # module stored here would break them.
        normalized = {}
        for name, models in sub_modules.items():
            normalized[name] = list(models) if isinstance(models, (list, tuple)) else [models]

        # Register every model under "<key>-<index>" so that nn.Sequential
        # tracks its parameters.
        ordered_dict = OrderedDict()
        for name, models in normalized.items():
            for idx, model in enumerate(models):
                ordered_dict[name+"-"+str(idx)] = model
        super(Asym, self).__init__(ordered_dict)

        self.sub_modules = normalized
        self.allow_empty_key = allow_empty_key

    def forward(self, features: Dict[str, Tensor]):
        """
        Applies the models registered for the key of this batch.

        Only the first entry of features['text_keys'] is inspected; the whole
        batch is routed through the models of that key.

        :param features: feature dict; may contain a 'text_keys' list
        :return: the features after the key-specific models were applied, or the
            unchanged features if no key was given and allow_empty_key is True
        :raises ValueError: if no key was given and allow_empty_key is False
        """
        if 'text_keys' in features and len(features['text_keys']) > 0:
            text_key = features['text_keys'][0]
            for model in self.sub_modules[text_key]:
                features = model(features)
        elif not self.allow_empty_key:
            raise ValueError('Input did not specify any keys and allow_empty_key is False')

        return features

    def get_sentence_embedding_dimension(self) -> int:
        """Not implemented for this module; always raises NotImplementedError."""
        raise NotImplementedError()

    def save(self, output_path):
        """
        Stores all sub-models and a config.json describing the key structure.

        Each distinct model object is saved once into the sub-folder
        "<id(model)>_<ClassName>" of output_path, even if it is used for
        several keys.

        :param output_path: folder to write to
        """
        model_lookup = {}
        model_types = {}
        model_structure = {}

        for name, models in self.sub_modules.items():
            model_structure[name] = []
            for model in models:
                model_id = str(id(model))+'_'+type(model).__name__
                model_lookup[model_id] = model
                model_types[model_id] = type(model).__module__
                model_structure[name].append(model_id)

        for model_id, model in model_lookup.items():
            model_path = os.path.join(output_path, str(model_id))
            os.makedirs(model_path, exist_ok=True)
            model.save(model_path)

        with open(os.path.join(output_path, 'config.json'), 'w', encoding='utf8') as fOut:
            json.dump({'types': model_types, 'structure': model_structure,
                       'parameters': {'allow_empty_key': self.allow_empty_key}},
                      fOut, indent=2)

    @staticmethod
    def load(input_path):
        """
        Restores an Asym model from a folder written by save().

        :param input_path: folder that contains config.json and the sub-model folders
        :return: the reconstructed Asym model
        """
        # Read with the same encoding that save() used for writing.
        with open(os.path.join(input_path, 'config.json'), encoding='utf8') as fIn:
            config = json.load(fIn)

        # Load every stored model once, keyed by its id from the config.
        modules = {}
        for model_id, model_type in config['types'].items():
            module_class = import_from_string(model_type)
            module = module_class.load(os.path.join(input_path, model_id))
            modules[model_id] = module

        # Rebuild the key -> [models] mapping from the stored ids.
        model_structure = {}
        for key_name, models_list in config['structure'].items():
            model_structure[key_name] = []
            for model_id in models_list:
                model_structure[key_name].append(modules[model_id])

        model = Asym(model_structure, **config['parameters'])
        return model
10 changes: 9 additions & 1 deletion sentence_transformers/models/Dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,23 @@ class Dense(nn.Module):
:param out_features: Output size
:param bias: Add a bias vector
:param activation_function: Pytorch activation function applied on output
:param init_weight: Initial value for the matrix of the linear layer
:param init_bias: Initial value for the bias of the linear layer
"""
def __init__(self, in_features: int, out_features: int, bias: bool = True, activation_function=nn.Tanh(), init_weight: Tensor = None, init_bias: Tensor = None):
    # Only the extended signature is kept; the superseded header without
    # init_weight / init_bias is dropped.
    super(Dense, self).__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.bias = bias
    self.activation_function = activation_function
    self.linear = nn.Linear(in_features, out_features, bias=bias)

    # Optionally replace the freshly initialised parameters of the linear
    # layer with caller-supplied values.
    if init_weight is not None:
        self.linear.weight = nn.Parameter(init_weight)

    # NOTE(review): init_bias is applied even when bias=False, which leaves
    # self.bias out of sync with the layer - confirm whether that is intended.
    if init_bias is not None:
        self.linear.bias = nn.Parameter(init_bias)

def forward(self, features: Dict[str, Tensor]):
    """
    Runs the sentence embedding through the linear layer and the activation.

    The result overwrites features['sentence_embedding'] in place; all other
    entries of the dict are left untouched.

    :param features: feature dict that contains 'sentence_embedding'
    :return: the same dict with the updated 'sentence_embedding'
    """
    projected = self.linear(features['sentence_embedding'])
    features.update({'sentence_embedding': self.activation_function(projected)})
    return features
Expand Down
22 changes: 18 additions & 4 deletions sentence_transformers/models/Transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ def __init__(self, model_name_or_path: str, max_seq_length: int = 128,

def forward(self, features):
"""Returns token_embeddings, cls_token"""
output_states = self.auto_model(**features, return_dict=False)
trans_features = {'input_ids': features['input_ids'], 'attention_mask': features['attention_mask']}
if 'token_type_ids' in features:
trans_features['token_type_ids'] = features['token_type_ids']

output_states = self.auto_model(**trans_features, return_dict=False)
output_tokens = output_states[0]

cls_tokens = output_tokens[:, 0, :] # CLS token is first token
Expand All @@ -56,16 +60,26 @@ def tokenize(self, texts: Union[List[str], List[Tuple[str, str]]]):
"""
Tokenizes a text and maps tokens to token-ids
"""
output = {}
if isinstance(texts[0], str):
texts = [texts]
to_tokenize = [texts]
elif isinstance(texts[0], dict):
to_tokenize = []
output['text_keys'] = []
for lookup in texts:
text_key, text = next(iter(lookup.items()))
to_tokenize.append(text)
output['text_keys'].append(text_key)
to_tokenize = [to_tokenize]
else:
batch1, batch2 = [], []
for text_tuple in texts:
batch1.append(text_tuple[0])
batch2.append(text_tuple[1])
texts = [batch1, batch2]
to_tokenize = [batch1, batch2]

return self.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_seq_length)
output.update(self.tokenizer(*to_tokenize, padding=True, truncation='longest_first', return_tensors="pt", max_length=self.max_seq_length))
return output


def get_config_dict(self):
Expand Down
1 change: 1 addition & 0 deletions sentence_transformers/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .Transformer import Transformer
from .ALBERT import ALBERT
from .Asym import Asym
from .BERT import BERT
from .BoW import BoW
from .CNN import CNN
Expand Down
7 changes: 2 additions & 5 deletions sentence_transformers/readers/InputExample.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ class InputExample:
"""
Structure for one input example with texts, the label and a unique id
"""
def __init__(self, guid: str = '', texts: List[str] = None, texts_tokenized: List[List[int]] = None, label: Union[int, float] = 0):
def __init__(self, guid: str = '', texts: List[str] = None, label: Union[int, float] = 0):
"""
Creates one InputExample with the given texts, guid and label
Expand All @@ -14,14 +14,11 @@ def __init__(self, guid: str = '', texts: List[str] = None, texts_tokenized: Lis
id for the example
:param texts
the texts for the example. Note, str.strip() is called on the texts
:param texts_tokenized
Optional: Texts that are already tokenized. If texts_tokenized is passed, texts must not be passed.
:param label
the label for the example
"""
self.guid = guid
self.texts = [text.strip() for text in texts] if texts is not None else texts
self.texts_tokenized = texts_tokenized
self.texts = texts
self.label = label

def __str__(self):
Expand Down
11 changes: 2 additions & 9 deletions sentence_transformers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,18 +216,11 @@ def batch_to_device(batch, target_device: device):
send a pytorch batch to a device (CPU/GPU)
"""
for key in batch:
batch[key] = batch[key].to(target_device)
if isinstance(batch[key], Tensor):
batch[key] = batch[key].to(target_device)
return batch

"""
features = batch['features']
for paired_sentence_idx in range(len(features)):
for feature_name in features[paired_sentence_idx]:
features[paired_sentence_idx][feature_name] = features[paired_sentence_idx][feature_name].to(target_device)

labels = batch['labels'].to(target_device)
return features, labels
"""

def fullname(o):
"""
Expand Down

0 comments on commit fdc8bfa

Please sign in to comment.