Skip to content

Commit

Permalink
update batch hard example
Browse files Browse the repository at this point in the history
  • Loading branch information
nreimers committed Dec 23, 2020
1 parent 5c3d34f commit a5ac056
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 286 deletions.
4 changes: 2 additions & 2 deletions examples/training/multilingual/make_multilingual.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def download_corpora(filepaths):


# Here we define the train and dev corpora
train_corpus = "../datasets/ted2020.tsv.gz" # Transcripts of TED talks, crawled 2020
sts_corpus = "../datasets/STS2017-extended.zip" # Extended STS2017 dataset for more languages
train_corpus = "datasets/ted2020.tsv.gz" # Transcripts of TED talks, crawled 2020
sts_corpus = "datasets/STS2017-extended.zip" # Extended STS2017 dataset for more languages
parallel_sentences_folder = "parallel-sentences/"

# Check if the files exist. If not, they are downloaded
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,15 @@
too easy and the network fails to learn good representations.
Batch hard triplet loss (https://arxiv.org/abs/1703.07737) creates triplets on the fly. It requires that the
data is labeled (e.g. labels A, B, C) and we assume that samples with the same label are similar:
A sent1; A sent2; B sent3; B sent4
...
data is labeled (e.g. labels 1, 2, 3) and we assume that samples with the same label are similar:
In a batch, it checks for sent1 with label A what is the other sentence with label A that is the furthest (hard positive)
In a batch, it checks for sent1 with label 1 what is the other sentence with label 1 that is the furthest (hard positive)
and which sentence with another label is the closest (hard negative example). It then tries to optimize this, i.e.
all sentences with the same label should be close and sentences for different labels should be clearly separated.
"""

from sentence_transformers import (
SentenceTransformer,
SentenceLabelDataset,
LoggingHandler,
losses,
)
from sentence_transformers import SentenceTransformer, LoggingHandler, losses, util
from sentence_transformers.datasets import SentenceLabelDataset
from torch.utils.data import DataLoader
from sentence_transformers.readers import InputExample
from sentence_transformers.evaluation import TripletEvaluator
Expand All @@ -32,38 +26,41 @@

import logging
import os
import urllib.request
import random
from collections import defaultdict

logging.basicConfig(
format="%(asctime)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO,
handlers=[LoggingHandler()],
)

# Inspired from torchnlp
def trec_dataset(
directory="datasets/trec/",
train_filename="train_5500.label",
test_filename="TREC_10.label",
validation_dataset_nb=500,
urls=[
"http://cogcomp.org/Data/QA/QC/train_5500.label",
"http://cogcomp.org/Data/QA/QC/TREC_10.label",
"https://cogcomp.seas.upenn.edu/Data/QA/QC/train_5500.label",
"https://cogcomp.seas.upenn.edu/Data/QA/QC/TREC_10.label",
],
):

os.makedirs(directory, exist_ok=True)

ret = []
for url, filename in zip(urls, [train_filename, test_filename]):
full_path = os.path.join(directory, filename)
urllib.request.urlretrieve(url, filename=full_path)
if not os.path.exists(full_path):
util.http_get(url, full_path)

examples = []
label_map = {}
guid=1
for line in open(full_path, "rb"):
# there is one non-ASCII byte: sisterBADBYTEcity; replaced with space
label, _, text = line.replace(b"\xf0", b" ").strip().decode().partition(" ")

# We extract the upper category (e.g. DESC from DESC:def)
label, _, _ = label.partition(":")

if label not in label_map:
label_map[label] = len(label_map)
Expand Down Expand Up @@ -117,12 +114,6 @@ def triplets_from_labeled_dataset(input_examples):
return triplets


logging.basicConfig(
format="%(asctime)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO,
handlers=[LoggingHandler()],
)

# You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
model_name = 'paraphrase-distilroberta-base-v1'
Expand All @@ -140,19 +131,16 @@ def triplets_from_labeled_dataset(input_examples):
logging.info("Loading TREC dataset")
train_set, dev_set, test_set = trec_dataset()

# We create a special dataset "SentenceLabelDataset" to wrap our train_set
# It will yield batches that contain at least two samples with the same label
train_data_sampler = SentenceLabelDataset(train_set)
train_dataloader = DataLoader(train_data_sampler, batch_size=32, drop_last=True)


# Load pretrained model
logging.info("Load model")
model = SentenceTransformer(model_name)

logging.info("Read TREC train dataset")
train_dataset = SentenceLabelDataset(
examples=train_set,
model=model,
provide_positive=False, #For BatchHardTripletLoss, we must set provide_positive and provide_negative to False
provide_negative=False,
)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)

### Triplet losses ####################
### There are 4 triplet loss variants:
Expand All @@ -169,14 +157,12 @@ def triplets_from_labeled_dataset(input_examples):


logging.info("Read TREC val dataset")
dev_evaluator = TripletEvaluator.from_input_examples(dev_set, name='dev')
dev_evaluator = TripletEvaluator.from_input_examples(dev_set, name='trec-dev')

logging.info("Performance before fine-tuning:")
dev_evaluator(model)

warmup_steps = int(
len(train_dataset) * num_epochs / train_batch_size * 0.1
) # 10% of train data
warmup_steps = int(len(train_dataloader) * num_epochs * 0.1) # 10% of train data

# Train the model
model.fit(
Expand All @@ -195,5 +181,5 @@ def triplets_from_labeled_dataset(input_examples):
##############################################################################

logging.info("Evaluating model on test set")
test_evaluator = TripletEvaluator.from_input_examples(test_set, name='test')
test_evaluator = TripletEvaluator.from_input_examples(test_set, name='trec-test')
model.evaluate(test_evaluator)
2 changes: 1 addition & 1 deletion sentence_transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__version__ = "0.4.0"
__DOWNLOAD_SERVER__ = 'https://sbert.net/models/'
from .datasets import SentencesDataset, SentenceLabelDataset, ParallelSentencesDataset
from .datasets import SentencesDataset, ParallelSentencesDataset
from .LoggingHandler import LoggingHandler
from .SentenceTransformer import SentenceTransformer
from .readers import InputExample
Expand Down
Loading

0 comments on commit a5ac056

Please sign in to comment.