From a5ac0561b857cabf60903536ceb147e2abb5bc15 Mon Sep 17 00:00:00 2001 From: Nils Reimers Date: Wed, 23 Dec 2020 22:19:15 +0100 Subject: [PATCH] update batch hard example --- .../multilingual/make_multilingual.py | 4 +- ...raining.py => training_batch_hard_trec.py} | 60 ++--- sentence_transformers/__init__.py | 2 +- .../datasets/SentenceLabelDataset.py | 241 ++++++------------ sentence_transformers/datasets/__init__.py | 4 +- .../datasets/sampler/LabelSampler.py | 76 ------ .../datasets/sampler/__init__.py | 1 - tests/test_compute_embeddings.py | 49 ++++ 8 files changed, 151 insertions(+), 286 deletions(-) rename examples/training/other/{training_batch_hard_trec_continue_training.py => training_batch_hard_trec.py} (83%) delete mode 100644 sentence_transformers/datasets/sampler/LabelSampler.py delete mode 100644 sentence_transformers/datasets/sampler/__init__.py create mode 100644 tests/test_compute_embeddings.py diff --git a/examples/training/multilingual/make_multilingual.py b/examples/training/multilingual/make_multilingual.py index 0e0fdfe74..53be0e3a5 100644 --- a/examples/training/multilingual/make_multilingual.py +++ b/examples/training/multilingual/make_multilingual.py @@ -77,8 +77,8 @@ def download_corpora(filepaths): # Here we define train train and dev corpora -train_corpus = "../datasets/ted2020.tsv.gz" # Transcripts of TED talks, crawled 2020 -sts_corpus = "../datasets/STS2017-extended.zip" # Extended STS2017 dataset for more languages +train_corpus = "datasets/ted2020.tsv.gz" # Transcripts of TED talks, crawled 2020 +sts_corpus = "datasets/STS2017-extended.zip" # Extended STS2017 dataset for more languages parallel_sentences_folder = "parallel-sentences/" # Check if the file exists. If not, they are downloaded diff --git a/examples/training/other/training_batch_hard_trec_continue_training.py b/examples/training/other/training_batch_hard_trec.py similarity index 83% rename from examples/training/other/training_batch_hard_trec_continue_training.py rename to examples/training/other/training_batch_hard_trec.py index d1b2a3fec..bd5c2ffb5 100644 --- a/examples/training/other/training_batch_hard_trec_continue_training.py +++ b/examples/training/other/training_batch_hard_trec.py @@ -9,21 +9,15 @@ too easy and the network fails to learn good representations. Batch hard triplet loss (https://arxiv.org/abs/1703.07737) creates triplets on the fly. It requires that the -data is labeled (e.g. labels A, B, C) and we assume that samples with the same label are similar: -A sent1; A sent2; B sent3; B sent4 -... +data is labeled (e.g. labels 1, 2, 3) and we assume that samples with the same label are similar: -In a batch, it checks for sent1 with label A what is the other sentence with label A that is the furthest (hard positive) +In a batch, it checks for sent1 with label 1 what is the other sentence with label 1 that is the furthest (hard positive) which sentence with another label is the closest (hard negative example). It then tries to optimize this, i.e. all sentences with the same label should be close and sentences for different labels should be clearly seperated. 
""" -from sentence_transformers import ( - SentenceTransformer, - SentenceLabelDataset, - LoggingHandler, - losses, -) +from sentence_transformers import SentenceTransformer, LoggingHandler, losses, util +from sentence_transformers.datasets import SentenceLabelDataset from torch.utils.data import DataLoader from sentence_transformers.readers import InputExample from sentence_transformers.evaluation import TripletEvaluator @@ -32,10 +26,16 @@ import logging import os -import urllib.request import random from collections import defaultdict +logging.basicConfig( + format="%(asctime)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO, + handlers=[LoggingHandler()], +) + # Inspired from torchnlp def trec_dataset( directory="datasets/trec/", @@ -43,17 +43,17 @@ def trec_dataset( test_filename="TREC_10.label", validation_dataset_nb=500, urls=[ - "http://cogcomp.org/Data/QA/QC/train_5500.label", - "http://cogcomp.org/Data/QA/QC/TREC_10.label", + "https://cogcomp.seas.upenn.edu/Data/QA/QC/train_5500.label", + "https://cogcomp.seas.upenn.edu/Data/QA/QC/TREC_10.label", ], ): - os.makedirs(directory, exist_ok=True) ret = [] for url, filename in zip(urls, [train_filename, test_filename]): full_path = os.path.join(directory, filename) - urllib.request.urlretrieve(url, filename=full_path) + if not os.path.exists(full_path): + util.http_get(url, full_path) examples = [] label_map = {} @@ -61,9 +61,6 @@ def trec_dataset( for line in open(full_path, "rb"): # there is one non-ASCII byte: sisterBADBYTEcity; replaced with space label, _, text = line.replace(b"\xf0", b" ").strip().decode().partition(" ") - - # We extract the upper category (e.g. DESC from DESC:def) - label, _, _ = label.partition(":") if label not in label_map: label_map[label] = len(label_map) @@ -117,12 +114,6 @@ def triplets_from_labeled_dataset(input_examples): return triplets -logging.basicConfig( - format="%(asctime)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - level=logging.INFO, - handlers=[LoggingHandler()], -) # You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base model_name = 'paraphrase-distilroberta-base-v1' @@ -140,19 +131,16 @@ def triplets_from_labeled_dataset(input_examples): logging.info("Loading TREC dataset") train_set, dev_set, test_set = trec_dataset() +# We create a special dataset "SentenceLabelDataset" to wrap out train_set +# It will yield batches that contain at least two samples with the same label +train_data_sampler = SentenceLabelDataset(train_set) +train_dataloader = DataLoader(train_data_sampler, batch_size=32, drop_last=True) # Load pretrained model +logging.info("Load model") model = SentenceTransformer(model_name) -logging.info("Read TREC train dataset") -train_dataset = SentenceLabelDataset( - examples=train_set, - model=model, - provide_positive=False, #For BatchHardTripletLoss, we must set provide_positive and provide_negative to False - provide_negative=False, -) -train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size) ### Triplet losses #################### ### There are 4 triplet loss variants: @@ -169,14 +157,12 @@ def triplets_from_labeled_dataset(input_examples): logging.info("Read TREC val dataset") -dev_evaluator = TripletEvaluator.from_input_examples(dev_set, name='dev') +dev_evaluator = TripletEvaluator.from_input_examples(dev_set, name='trec-dev') logging.info("Performance before fine-tuning:") dev_evaluator(model) -warmup_steps = int( - len(train_dataset) * 
num_epochs / train_batch_size * 0.1 -) # 10% of train data +warmup_steps = int(len(train_dataloader) * num_epochs * 0.1) # 10% of train data # Train the model model.fit( @@ -195,5 +181,5 @@ def triplets_from_labeled_dataset(input_examples): ############################################################################## logging.info("Evaluating model on test set") -test_evaluator = TripletEvaluator.from_input_examples(test_set, name='test') +test_evaluator = TripletEvaluator.from_input_examples(test_set, name='trec-test') model.evaluate(test_evaluator) diff --git a/sentence_transformers/__init__.py b/sentence_transformers/__init__.py index 0ca87027b..0555bb850 100644 --- a/sentence_transformers/__init__.py +++ b/sentence_transformers/__init__.py @@ -1,6 +1,6 @@ __version__ = "0.4.0" __DOWNLOAD_SERVER__ = 'https://sbert.net/models/' -from .datasets import SentencesDataset, SentenceLabelDataset, ParallelSentencesDataset +from .datasets import SentencesDataset, ParallelSentencesDataset from .LoggingHandler import LoggingHandler from .SentenceTransformer import SentenceTransformer from .readers import InputExample diff --git a/sentence_transformers/datasets/SentenceLabelDataset.py b/sentence_transformers/datasets/SentenceLabelDataset.py index f72c6c148..4b0385f1d 100644 --- a/sentence_transformers/datasets/SentenceLabelDataset.py +++ b/sentence_transformers/datasets/SentenceLabelDataset.py @@ -1,185 +1,94 @@ -from torch.utils.data import Dataset +""" + +""" +from torch.utils.data import IterableDataset +import numpy as np from typing import List -import bisect -import torch +from ..readers import InputExample import logging -import numpy as np -from tqdm import tqdm -from .. import SentenceTransformer -from ..readers.InputExample import InputExample -from multiprocessing import Pool, cpu_count -import multiprocessing -class SentenceLabelDataset(Dataset): +class SentenceLabelDataset(IterableDataset): """ - Dataset for training with triplet loss. - This dataset takes a list of sentences grouped by their label and uses this grouping to dynamically select a - positive example from the same group and a negative example from the other sentences for a selected anchor sentence. + This dataset can be used for some specific Triplet Losses like BATCH_HARD_TRIPLET_LOSS which requires + multiple examples with the same label in a batch. - This dataset should be used in combination with dataset_reader.LabelSentenceReader + It draws n consecutive, random and unique samples from one label at a time. This is repeated for each label. - One iteration over this dataset selects every sentence as anchor once. + Labels with fewer than n unique samples are ignored. + This also applied to drawing without replacement, once less than n samples remain for a label, it is skipped. - This also uses smart batching like SentenceDataset. + This *DOES NOT* check if there are more labels than the batch is large or if the batch size is divisible + by the samples drawn per label. 
""" - - def __init__(self, examples: List[InputExample], model: SentenceTransformer, provide_positive: bool = True, - provide_negative: bool = True, - parallel_tokenization: bool = True, - max_processes: int = 4, - chunk_size: int = 5000): + def __init__(self, examples: List[InputExample], samples_per_label: int = 2, with_replacement: bool = False): """ - Converts input examples to a SentenceLabelDataset usable to train the model with - SentenceTransformer.smart_batching_collate as the collate_fn for the DataLoader - - Assumes only one sentence per InputExample and labels as integers from 0 to max_num_labels - and should be used in combination with dataset_reader.LabelSentenceReader. - - Labels with only one example are ignored. - - smart_batching_collate as collate_fn is required because it transforms the tokenized texts to the tensors. + Creates a LabelSampler for a SentenceLabelDataset. :param examples: - the input examples for the training - :param model - the Sentence BERT model for the conversion - :param provide_positive: - set this to False, if you don't need a positive example (e.g. for BATCH_HARD_TRIPLET_LOSS). - :param provide_negative: - set this to False, if you don't need a negative example (e.g. for BATCH_HARD_TRIPLET_LOSS - or MULTIPLE_NEGATIVES_RANKING_LOSS). - :param parallel_tokenization - If true, multiple processes will be started for the tokenization - :param max_processes - Maximum number of processes started for tokenization. Cannot be larger can cpu_count() - :param chunk_size - #chunk_size number of examples are send to each process. Larger values increase overall tokenization speed + a list with InputExamples + :param samples_per_label: + the number of consecutive, random and unique samples drawn per label + :param with_replacement: + if this is True, then each sample is drawn at most once (depending on the total number of samples per label). + if this is False, then one sample can be drawn in multiple draws, but still not multiple times in the same + drawing. """ - self.model = model - self.groups_right_border = [] - self.grouped_inputs = [] - self.grouped_labels = [] - self.num_labels = 0 - self.max_processes = min(max_processes, cpu_count()) - self.chunk_size = chunk_size - self.parallel_tokenization = parallel_tokenization - - if self.parallel_tokenization: - if multiprocessing.get_start_method() != 'fork': - logging.info("Parallel tokenization is only available on Unix systems which allow to fork processes. Fall back to sequential tokenization") - self.parallel_tokenization = False - - self.convert_input_examples(examples, model) - - self.idxs = np.arange(len(self.grouped_inputs)) - - self.provide_positive = provide_positive - self.provide_negative = provide_negative - - - def convert_input_examples(self, examples: List[InputExample], model: SentenceTransformer): - """ - Converts input examples to a SentenceLabelDataset. - - Assumes only one sentence per InputExample and labels as integers from 0 to max_num_labels - and should be used in combination with dataset_reader.LabelSentenceReader. + super().__init__() - Labels with only one example are ignored. + self.samples_per_label = samples_per_label - :param examples: - the input examples for the training - :param model - the Sentence Transformer model for the conversion - :param is_pretokenized - If set to true, no tokenization will be applied. 
It is expected that the input is tokenized via model.tokenize - """ + #Group examples by label + label2ex = {} + for example in examples: + if example.label not in label2ex: + label2ex[example.label] = [] + label2ex[example.label].append(example) - inputs = [] - labels = [] - - label_sent_mapping = {} - too_long = 0 - label_type = None - - logging.info("Start tokenization") - if not self.parallel_tokenization or self.max_processes == 1 or len(examples) <= self.chunk_size: - tokenized_texts = [self.tokenize_example(example) for example in examples] - else: - logging.info("Use multi-process tokenization with {} processes".format(self.max_processes)) - self.model.to('cpu') - with Pool(self.max_processes) as p: - tokenized_texts = list(p.imap(self.tokenize_example, examples, chunksize=self.chunk_size)) - - # Group examples and labels - # Add examples with the same label to the same dict - for ex_index, example in enumerate(tqdm(examples, desc="Convert dataset")): - if label_type is None: - if isinstance(example.label, int): - label_type = torch.long - elif isinstance(example.label, float): - label_type = torch.float - tokenized_text = tokenized_texts[ex_index][0] - - if hasattr(model, 'max_seq_length') and model.max_seq_length is not None and model.max_seq_length > 0 and len(tokenized_text) > model.max_seq_length: - too_long += 1 - - if example.label in label_sent_mapping: - label_sent_mapping[example.label].append(ex_index) + #Include only labels with at least 2 examples + self.grouped_inputs = [] + self.groups_right_border = [] + num_labels = 0 + + for label, label_examples in label2ex.items(): + if len(label_examples) >= self.samples_per_label: + self.grouped_inputs.extend(label_examples) + self.groups_right_border.append(len(self.grouped_inputs)) # At which position does this label group / bucket end? + num_labels += 1 + + self.label_range = np.arange(num_labels) + self.with_replacement = with_replacement + np.random.shuffle(self.label_range) + + logging.info("SentenceLabelDataset: {} examples, from which {} examples could be used (those labels appeared at least {} times). {} different labels found.".format(len(examples), len(self.grouped_inputs), self.samples_per_label, num_labels )) + + def __iter__(self): + label_idx = 0 + count = 0 + already_seen = {} + while count < len(self.grouped_inputs): + label = self.label_range[label_idx] + if label not in already_seen: + already_seen[label] = set() + + left_border = 0 if label == 0 else self.groups_right_border[label-1] + right_border = self.groups_right_border[label] + + if self.with_replacement: + selection = np.arange(left_border, right_border) else: - label_sent_mapping[example.label] = [ex_index] - - inputs.append(tokenized_text) - labels.append(example.label) - - # Group sentences, such that sentences with the same label - # are besides each other. Only take labels with at least 2 examples - distinct_labels = list(label_sent_mapping.keys()) - for i in range(len(distinct_labels)): - label = distinct_labels[i] - if len(label_sent_mapping[label]) >= 2: - self.grouped_inputs.extend([inputs[j] for j in label_sent_mapping[label]]) - self.grouped_labels.extend([labels[j] for j in label_sent_mapping[label]]) - self.groups_right_border.append(len(self.grouped_inputs)) #At which position does this label group / bucket end? 
- self.num_labels += 1 - - self.grouped_labels = torch.tensor(self.grouped_labels, dtype=label_type) - logging.info("Num sentences: %d" % (len(self.grouped_inputs))) - logging.info("Sentences longer than max_seqence_length: {}".format(too_long)) - logging.info("Number of labels with >1 examples: {}".format(len(distinct_labels))) - - - def tokenize_example(self, example): - if example.texts_tokenized is not None: - return example.texts_tokenized - - return [self.model.tokenize(text) for text in example.texts] - - def __getitem__(self, item): - if not self.provide_positive and not self.provide_negative: - return [self.grouped_inputs[item]], self.grouped_labels[item] - - # Anchor element - anchor = self.grouped_inputs[item] - - # Check start and end position for this label in our list of grouped sentences - group_idx = bisect.bisect_right(self.groups_right_border, item) - left_border = 0 if group_idx == 0 else self.groups_right_border[group_idx - 1] - right_border = self.groups_right_border[group_idx] - - if self.provide_positive: - positive_item_idx = np.random.choice(np.concatenate([self.idxs[left_border:item], self.idxs[item + 1:right_border]])) - positive = self.grouped_inputs[positive_item_idx] - else: - positive = [] - - if self.provide_negative: - negative_item_idx = np.random.choice(np.concatenate([self.idxs[0:left_border], self.idxs[right_border:]])) - negative = self.grouped_inputs[negative_item_idx] - else: - negative = [] - - return [anchor, positive, negative], self.grouped_labels[item] - + selection = [i for i in np.arange(left_border, right_border) if i not in already_seen[label]] + + if len(selection) >= self.samples_per_label: + for element_idx in np.random.choice(selection, self.samples_per_label, replace=False): + count += 1 + already_seen[label].add(element_idx) + yield self.grouped_inputs[element_idx] + + label_idx += 1 + if label_idx >= len(self.label_range): + label_idx = 0 + already_seen = {} + np.random.shuffle(self.label_range) def __len__(self): return len(self.grouped_inputs) \ No newline at end of file diff --git a/sentence_transformers/datasets/__init__.py b/sentence_transformers/datasets/__init__.py index f67b35d1b..3ed73dcc4 100644 --- a/sentence_transformers/datasets/__init__.py +++ b/sentence_transformers/datasets/__init__.py @@ -1,5 +1,3 @@ -from .sampler import * from .ParallelSentencesDataset import ParallelSentencesDataset -from .SentenceLabelDataset import SentenceLabelDataset from .SentencesDataset import SentencesDataset - +from .SentenceLabelDataset import SentenceLabelDataset diff --git a/sentence_transformers/datasets/sampler/LabelSampler.py b/sentence_transformers/datasets/sampler/LabelSampler.py deleted file mode 100644 index 394bec5b4..000000000 --- a/sentence_transformers/datasets/sampler/LabelSampler.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -This file contains sampler functions, that can be used to sample mini-batches with specific properties. -""" -from torch.utils.data import Sampler -import numpy as np -from ...datasets import SentenceLabelDataset - - -class LabelSampler(Sampler): - """ - This sampler is used for some specific Triplet Losses like BATCH_HARD_TRIPLET_LOSS - or MULTIPLE_NEGATIVES_RANKING_LOSS which require multiple or only one sample from one label per batch. - - It draws n consecutive, random and unique samples from one label at a time. This is repeated for each label. - - Labels with fewer than n unique samples are ignored. 
- This also applied to drawing without replacement, once less than n samples remain for a label, it is skipped. - - This *DOES NOT* check if there are more labels than the batch is large or if the batch size is divisible - by the samples drawn per label. - - - """ - def __init__(self, data_source: SentenceLabelDataset, samples_per_label: int = 5, - with_replacement: bool = False): - """ - Creates a LabelSampler for a SentenceLabelDataset. - - :param data_source: - the dataset from which samples are drawn - :param samples_per_label: - the number of consecutive, random and unique samples drawn per label - :param with_replacement: - if this is True, then each sample is drawn at most once (depending on the total number of samples per label). - if this is False, then one sample can be drawn in multiple draws, but still not multiple times in the same - drawing. - """ - super().__init__(data_source) - self.data_source = data_source - self.samples_per_label = samples_per_label - self.label_range = np.arange(data_source.num_labels) - self.borders = data_source.groups_right_border - self.with_replacement = with_replacement - np.random.shuffle(self.label_range) - - def __iter__(self): - label_idx = 0 - count = 0 - already_seen = {} - while count < len(self.data_source): - label = self.label_range[label_idx] - if label not in already_seen: - already_seen[label] = set() - - left_border = 0 if label == 0 else self.borders[label-1] - right_border = self.borders[label] - - if self.with_replacement: - selection = np.arange(left_border, right_border) - else: - selection = [i for i in np.arange(left_border, right_border) if i not in already_seen[label]] - - if len(selection) >= self.samples_per_label: - for element_idx in np.random.choice(selection, self.samples_per_label, replace=False): - count += 1 - already_seen[label].add(element_idx) - yield element_idx - - label_idx += 1 - if label_idx >= len(self.label_range): - label_idx = 0 - already_seen = {} - np.random.shuffle(self.label_range) - - def __len__(self): - return len(self.data_source) \ No newline at end of file diff --git a/sentence_transformers/datasets/sampler/__init__.py b/sentence_transformers/datasets/sampler/__init__.py deleted file mode 100644 index 42f696134..000000000 --- a/sentence_transformers/datasets/sampler/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .LabelSampler import * \ No newline at end of file diff --git a/tests/test_compute_embeddings.py b/tests/test_compute_embeddings.py new file mode 100644 index 000000000..00ee71cb1 --- /dev/null +++ b/tests/test_compute_embeddings.py @@ -0,0 +1,49 @@ +""" +Computes embeddings +""" + +import csv +import gzip +import os +import unittest + +from torch.utils.data import DataLoader + +from sentence_transformers import SentenceTransformer, SentencesDataset, losses, models, util +from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator +from sentence_transformers.readers import InputExample +import numpy as np + +class ComputeEmbeddingsTest(unittest.TestCase): + def setUp(self): + self.model = SentenceTransformer('paraphrase-distilroberta-base-v1') + + def test_encode_single_sentences(self): + #Single sentence + emb = self.model.encode("Hello Word, a test sentence") + assert emb.shape == (768,) + assert abs(np.sum(emb) - 7.9811716) < 0.001 + + # Single sentence as list + emb = self.model.encode(["Hello Word, a test sentence"]) + assert emb.shape == (1, 768) + assert abs(np.sum(emb) - 7.9811716) < 0.001 + + # Sentence list + emb = self.model.encode(["Hello Word, a test 
sentence", "Here comes another sentence", "My final sentence"]) + assert emb.shape == (3, 768) + print(np.sum(emb)) + assert abs(np.sum(emb) - 22.968266) < 0.001 + + def test_encode_tuple_sentences(self): + # Input a sentence tuple + emb = self.model.encode([("Hello Word, a test sentence", "Second input for model")]) + assert emb.shape == (1, 768) + assert abs(np.sum(emb) - 9.503508) < 0.001 + + # List of sentence tuples + emb = self.model.encode([("Hello Word, a test sentence", "Second input for model"), ("My second tuple", "With two inputs"), ("Final tuple", "final test")]) + assert emb.shape == (3, 768) + assert abs(np.sum(emb) - 32.14627) < 0.001 + +