Skip to content

Commit

Permalink
Update ConstrativeLoss examples
Browse files Browse the repository at this point in the history
  • Loading branch information
nreimers committed Dec 13, 2021
1 parent 10e1599 commit d2f0293
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
15 changes: 9 additions & 6 deletions sentence_transformers/losses/ContrastiveLoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,19 @@ class ContrastiveLoss(nn.Module):
Example::
from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses
from sentence_transformers.readers import InputExample
from sentence_transformers import SentenceTransformer, LoggingHandler, losses, InputExample
from torch.utils.data import DataLoader
model = SentenceTransformer('distilbert-base-nli-mean-tokens')
train_examples = [InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1),
model = SentenceTransformer('all-MiniLM-L6-v2')
train_examples = [
InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1),
InputExample(texts=['This is a negative pair', 'Their distance will be increased'], label=0)]
train_dataset = SentencesDataset(train_examples, model)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
train_loss = losses.ContrastiveLoss(model=model)
model.fit([(train_dataloader, train_loss)], show_progress_bar=True)
"""

def __init__(self, model: SentenceTransformer, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin: float = 0.5, size_average:bool = True):
Expand Down
15 changes: 9 additions & 6 deletions sentence_transformers/losses/OnlineContrastiveLoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@ class OnlineContrastiveLoss(nn.Module):
Example::
from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses
from sentence_transformers.readers import InputExample
from sentence_transformers import SentenceTransformer, LoggingHandler, losses, InputExample
from torch.utils.data import DataLoader
model = SentenceTransformer('distilbert-base-nli-mean-tokens')
train_examples = [InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1),
model = SentenceTransformer('all-MiniLM-L6-v2')
train_examples = [
InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1),
InputExample(texts=['This is a negative pair', 'Their distance will be increased'], label=0)]
train_dataset = SentencesDataset(train_examples, model)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
train_loss = losses.OnlineContrastiveLoss(model=model)
model.fit([(train_dataloader, train_loss)], show_progress_bar=True)
"""

def __init__(self, model: SentenceTransformer, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin: float = 0.5):
Expand Down

0 comments on commit d2f0293

Please sign in to comment.