diff --git a/sentence_transformers/losses/ContrastiveLoss.py b/sentence_transformers/losses/ContrastiveLoss.py index 87fe3a44d..f8d034059 100644 --- a/sentence_transformers/losses/ContrastiveLoss.py +++ b/sentence_transformers/losses/ContrastiveLoss.py @@ -28,16 +28,19 @@ class ContrastiveLoss(nn.Module): Example:: - from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses - from sentence_transformers.readers import InputExample + from sentence_transformers import SentenceTransformer, LoggingHandler, losses, InputExample + from torch.utils.data import DataLoader - model = SentenceTransformer('distilbert-base-nli-mean-tokens') - train_examples = [InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1), + model = SentenceTransformer('all-MiniLM-L6-v2') + train_examples = [ + InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1), InputExample(texts=['This is a negative pair', 'Their distance will be increased'], label=0)] - train_dataset = SentencesDataset(train_examples, model) - train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size) + + train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2) train_loss = losses.ContrastiveLoss(model=model) + model.fit([(train_dataloader, train_loss)], show_progress_bar=True) + """ def __init__(self, model: SentenceTransformer, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin: float = 0.5, size_average:bool = True): diff --git a/sentence_transformers/losses/OnlineContrastiveLoss.py b/sentence_transformers/losses/OnlineContrastiveLoss.py index f03483e0e..285e63219 100644 --- a/sentence_transformers/losses/OnlineContrastiveLoss.py +++ b/sentence_transformers/losses/OnlineContrastiveLoss.py @@ -18,15 +18,18 @@ class OnlineContrastiveLoss(nn.Module): Example:: - from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses - from sentence_transformers.readers import InputExample + from sentence_transformers import SentenceTransformer, LoggingHandler, losses, InputExample + from torch.utils.data import DataLoader - model = SentenceTransformer('distilbert-base-nli-mean-tokens') - train_examples = [InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1), + model = SentenceTransformer('all-MiniLM-L6-v2') + train_examples = [ + InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1), InputExample(texts=['This is a negative pair', 'Their distance will be increased'], label=0)] - train_dataset = SentencesDataset(train_examples, model) - train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size) + + train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2) train_loss = losses.OnlineContrastiveLoss(model=model) + + model.fit([(train_dataloader, train_loss)], show_progress_bar=True) """ def __init__(self, model: SentenceTransformer, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin: float = 0.5):