initial commit: inference pipeline done
0 parents · commit 930a343
Showing 10 changed files with 497 additions and 0 deletions.
@@ -0,0 +1,11 @@
__pycache__
runs
wandb
_clones
logs
ke-t5-base-finetuned-en-to-ko
ke-t5-base-finetuned-ko-to-en
results
*.wet
*.txt
*.ipynb
@@ -0,0 +1,8 @@
Implementation of the paper ["Extracting Training Data from Large Language Models" (Carlini et al., 2020)](https://arxiv.org/abs/2012.07805)

### How to Run

### References

- [Official Implementation](https://github.com/ftramer/LM_Memorization)
- [Implementation with Sampling Method](https://github.com/shreyansh26/Extracting-Training-Data-from-Large-Langauge-Models)
@@ -0,0 +1,29 @@
CFG:
  # gpt2, gpt2-medium, gpt2-large, gpt2-xl
  # t5-small, t5-base, t5-large, t5-3b
  # t5-v1_1-small, t5-v1_1-base, t5-v1_1-large, t5-v1_1-xl
  # t5-small-lm-adapt, t5-base-lm-adapt, t5-large-lm-adapt, t5-xl-lm-adapt
  model_type: gpt2 # gpt2, t5, or flax-community/t5-base-openwebtext

  # configuration reconstructed according to the paper
  data_path: snoop2head/common_crawl # one of snoop2head/common_crawl, c4, openwebtext
  inference_batch_size: 32
  min_prefix_length: 5 # prefix length ranges from 5 to 10 tokens according to the paper
  max_prefix_length: 10
  generate_token_length: 256 # generate +256 tokens after the given prefix
  num_return_sequences: 5 # n = 40 (or k = 40) in the paper
  num_inference_samples: 1000 # 200K samplings (generations) per scheme in the paper

  # configuration for the model's generation
  device_ids: # -1 uses cpu (not recommended), 0 uses a single gpu, multiple ids use multiple gpus
    - 0
  fp16: false
  num_beams: 5
  repetition_penalty: 1.3
  no_repeat_ngram_size: 3
  num_return_sequences: 1 # note: this key also appears above; most YAML loaders keep this later value

  # other configs
  seed: 42
  inference_result_path: ./results/result.csv
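The scripts below read this file through a `load_config` helper imported from `utils`, which is not reproduced in this view. The following is a minimal sketch of what that helper likely does, assuming the YAML sits in a `config.yaml` next to the scripts and that the nested `CFG` block is exposed with attribute access via `EasyDict`; both the path and the return shape are assumptions inferred from call sites such as `CFG.model_type` and `CFG.inference_batch_size`.

```python
# Hypothetical sketch of utils.load_config; the real helper is not shown in this commit view.
import yaml  # PyYAML
from easydict import EasyDict


def load_config(path: str = "./config.yaml") -> EasyDict:
    """Read the YAML config above and return its CFG block with attribute-style access."""
    with open(path) as f:
        raw = yaml.safe_load(f)  # for duplicate keys, PyYAML keeps the last value
    return EasyDict(raw["CFG"])  # e.g. CFG.model_type, CFG.seed, CFG.device_ids


# usage, mirroring the scripts below: CFG = load_config(); seed = CFG.seed
```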
@@ -0,0 +1,93 @@
import pandas as pd
from transformers import PreTrainedTokenizerFast
from datasets import load_dataset, Dataset, DatasetDict
from utils import load_config


def parse_commoncrawl(wet_file):
    """
    Collects prefix candidates for the GPT-2 model generations.
    Parses a WET file and ports its English records to a huggingface dataset.
    Tested on the May 2021 crawl.
    @shreyansh26
    """
    dset_list = []
    with open(wet_file) as f:
        lines = f.readlines()

    start_idxs = [i for i in range(len(lines)) if "WARC/1.0" in lines[i]]

    count_eng = 0
    for i in range(len(start_idxs) - 1):
        start = start_idxs[i]
        end = start_idxs[i + 1]
        if "WARC-Identified-Content-Language: eng" in lines[start + 7]:
            count_eng += 1
            for j in range(start + 10, end):
                dset_list.append(lines[j])

    return dset_list


def remove_line_break(input_list: list):
    """
    Removes newline characters from all the items in a list.
    """
    return [item.replace("\n", "") for item in input_list]


def remove_duplicates(input_list: list):
    """
    Deletes duplicates from a list.
    """
    return list(set(input_list))


def remove_blank_items(input_list: list):
    """
    Deletes blank items from a list.
    """
    return [item for item in input_list if item != ""]


def remove_short_items(input_list: list, min_length: int = 5):
    """
    Drops items shorter than min_length tokens.
    """
    CFG = load_config()
    tokenizer = PreTrainedTokenizerFast.from_pretrained(CFG.model_type)
    return [item for item in input_list if len(tokenizer.tokenize(item)) >= min_length]


def upload_huggingface_hub(dset: list):
    # package into a huggingface dataset and push to the hub
    df = pd.DataFrame(dset, columns=["text"])
    dataset = Dataset.from_pandas(df)
    dataset_dict = DatasetDict({"train": dataset})
    dataset_dict.push_to_hub("snoop2head/common_crawl")


def package_openwebtext():
    """ GPT train dataset: https://huggingface.co/datasets/openwebtext """
    pass


def package_c4():
    """ T5 train dataset: https://huggingface.co/datasets/c4 """
    pass


def main():
    # read and parse
    dset = parse_commoncrawl("./commoncrawl.warc.wet")
    print(len(dset))

    # preprocess
    dset = remove_duplicates(dset)
    dset = remove_line_break(dset)
    dset = remove_blank_items(dset)
    dset = remove_short_items(dset)
    print(len(dset))

    # upload to huggingface hub
    upload_huggingface_hub(dset)


if __name__ == "__main__":
    main()
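For context, the dataset pushed by `upload_huggingface_hub` is the same `snoop2head/common_crawl` path that the inference configuration points at, so a quick sanity check could look like the sketch below. It assumes the push succeeded and the Hub repository is readable; authenticating with `huggingface-cli login` is required for the push itself, not for reading a public dataset.

```python
# Quick sanity check of the pushed prefix dataset; this mirrors how the
# inference script below loads it via CFG.data_path.
from datasets import load_dataset

internet_data = load_dataset("snoop2head/common_crawl", split="train")
print(len(internet_data))        # number of prefix candidates
print(internet_data[0]["text"])  # one cleaned Common Crawl line per row
```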
@@ -0,0 +1,114 @@
# 3rd-party libraries
import multiprocessing
import pandas as pd
import numpy as np
from tqdm import tqdm
from easydict import EasyDict
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset


# custom modules
from models import load_tokenizer, load_generation_model, tokenize_fn
from utils import (
    load_config,
    load_devices,
    remove_lengthy_texts,
    restrict_token_length_fn,
    get_token_sequence_length,
    collate_fn,
    seed_everything,
)
from metric import (
    calculate_batch_perplexity,
    calculate_batch_window,
    calculate_batch_zlib,
    Summary,
    AverageMeter,
)

# load config
CFG = load_config()
seed_everything(CFG.seed)
CPU_COUNT = multiprocessing.cpu_count() // 2

# load models to the designated device(s)
devices = load_devices()
tokenizer = load_tokenizer()
baseline_model = load_generation_model("baseline").to(devices[0])  # largest model
middle_model = load_generation_model("middle").to(devices[0])
small_model = load_generation_model("small").to(devices[0])

# load and tokenize dataset
internet_data = load_dataset(CFG.data_path, split="train")
internet_data = internet_data.filter(remove_lengthy_texts, num_proc=CPU_COUNT)
random_numbers_train = np.random.randint(
    0, len(internet_data["text"]), int(CFG.num_inference_samples)
)
internet_data = internet_data.select(random_numbers_train)
tokenized_datasets = internet_data.map(
    tokenize_fn, batched=True, num_proc=CPU_COUNT, remove_columns=["text"]
)
tokenized_datasets = tokenized_datasets.filter(restrict_token_length_fn, num_proc=CPU_COUNT)
print("text data tokenization done")

# make dataloaders with uniform-length batches
list_prefix_loaders = []
tokenized_datasets = tokenized_datasets.map(get_token_sequence_length, num_proc=CPU_COUNT)
min_len = min(tokenized_datasets["sequence_length"])
max_len = max(tokenized_datasets["sequence_length"])
for prefix_len in range(min_len, max_len + 1):
    prefix_uniform_len = tokenized_datasets.filter(
        lambda example: example["sequence_length"] == prefix_len
    )  # group prefixes with uniform lengths, because GPT-2 has no padding token
    if len(prefix_uniform_len) == 0:
        continue
    prefix_loader = DataLoader(
        prefix_uniform_len, collate_fn=collate_fn, batch_size=CFG.inference_batch_size, shuffle=True
    )  # batching with collation
    list_prefix_loaders.append(prefix_loader)
    print("dataloader created with token length of", prefix_len)

# inference per dataloader
print("inference start")
list_prefix_texts = []
list_generated_texts = []

for prefix_loader in list_prefix_loaders:
    for idx, (prefix_batch, attention_mask) in enumerate(tqdm(prefix_loader)):
        if idx == 0:
            prefix_length = len(prefix_batch[0])
            print("inferencing per dataloader with prefix length of:", prefix_length)
        prefix_batch = prefix_batch.to(devices[0])  # load on the designated device
        attention_mask = attention_mask.to(devices[0])  # load on the designated device
        with torch.no_grad():
            generated = baseline_model.generate(
                input_ids=prefix_batch,
                attention_mask=attention_mask,
                max_length=CFG.max_prefix_length + CFG.generate_token_length,
                num_return_sequences=CFG.num_return_sequences,
                repetition_penalty=CFG.repetition_penalty,
                no_repeat_ngram_size=CFG.no_repeat_ngram_size,
            )
        del attention_mask

        prefix_texts = tokenizer.batch_decode(
            prefix_batch.cpu().detach().numpy(), skip_special_tokens=True
        )
        del prefix_batch

        generated_texts = tokenizer.batch_decode(
            generated.cpu().detach().numpy(), skip_special_tokens=True
        )
        del generated
        torch.cuda.empty_cache()

        list_prefix_texts.extend(prefix_texts)
        list_generated_texts.extend(generated_texts)
        print(
            f"generation/sampling completed | {prefix_length} prefix length | {(idx + 1) * CFG.inference_batch_size} samples"
        )

df = pd.DataFrame({"prefix": list_prefix_texts, "generated": list_generated_texts})
df.to_csv(CFG.inference_result_path, index=False)
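The helpers imported from `utils` and `models` (device and tokenizer loading, length filtering, collation) are part of the commit but not reproduced in this view. Below is a hedged sketch of what the three dataset helpers used above might look like, with signatures inferred purely from their call sites; the actual implementations, the default length bounds, and whether they read values from `CFG` are assumptions.

```python
# Hypothetical reconstructions of three utils helpers, inferred from how the
# inference script calls them; not the actual code from this commit.
import torch


def get_token_sequence_length(example):
    # used with datasets.map(): records each prefix's tokenized length
    example["sequence_length"] = len(example["input_ids"])
    return example


def restrict_token_length_fn(example, min_len=5, max_len=10):
    # used with datasets.filter(): keep prefixes in the paper's 5-10 token range
    # (the bounds presumably come from CFG.min_prefix_length / CFG.max_prefix_length)
    return min_len <= len(example["input_ids"]) <= max_len


def collate_fn(batch):
    # every example in a given loader already has the same sequence_length,
    # so stacking works without padding; returns (input_ids, attention_mask)
    input_ids = torch.tensor([example["input_ids"] for example in batch])
    attention_mask = torch.tensor([example["attention_mask"] for example in batch])
    return input_ids, attention_mask
```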
Empty file.
@@ -0,0 +1,112 @@
import zlib
import torch
import numpy as np
from enum import Enum


def calculate_batch_perplexity(input_ids, model):
    """ Perplexity is defined as the exponential of the model's loss. """
    model.eval()
    with torch.no_grad():
        output = model(input_ids, labels=input_ids)
        perplexity = torch.exp(output.loss)
    del output, input_ids
    return perplexity


def calculate_batch_window(input_ids, model, window_size=50):
    """
    Sometimes a model is not confident when the sample
    contains one memorized substring surrounded by a
    block of non-memorized (and high-perplexity) text.
    To handle this, we use the minimum perplexity when
    averaged over a sliding window of 50 tokens.
    """
    model.eval()

    # if input_ids is a nested sequence, lower the dimension
    if len(input_ids.size()) != 1:
        input_ids = input_ids.squeeze()

    # if the sequence is shorter than the window, fall back to plain perplexity
    if input_ids.size(0) < window_size:
        return calculate_batch_perplexity(input_ids, model)

    # make tensors for the sliding windows
    sliding_windows = input_ids.unfold(0, window_size, 1)
    min_perplexity = np.inf

    # yield the lowest perplexity score out of the given sliding windows
    with torch.no_grad():
        for tensor in sliding_windows:
            perplexity = calculate_batch_perplexity(tensor, model)
            min_perplexity = min(min_perplexity, perplexity.item())
            del tensor

    del input_ids
    return min_perplexity


def calculate_batch_zlib(text):
    """
    As a simple baseline method, we compute the zlib entropy of the text:
    the number of bits of entropy when the sequence is compressed with zlib compression.
    Although text compressors are simple, they can identify many of the
    examples of trivial memorization and repeated patterns described above
    (e.g., they are excellent at modeling repeated substrings).
    """
    return len(zlib.compress(bytes(text, "utf-8")))


class Summary(Enum):
    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3


class AverageMeter(object):
    """Computes and stores the average across the given batches"""

    def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)

    def summary(self):
        fmtstr = ""
        if self.summary_type is Summary.NONE:
            fmtstr = ""
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = "{name} {avg:.3f}"
        elif self.summary_type is Summary.SUM:
            fmtstr = "{name} {sum:.3f}"
        elif self.summary_type is Summary.COUNT:
            fmtstr = "{name} {count:.3f}"
        else:
            raise ValueError("invalid summary type %r" % self.summary_type)

        return fmtstr.format(**self.__dict__)
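The metric functions above are imported by the inference script but not yet wired into a ranking step in this initial commit. As a hedged illustration only, the paper's membership-inference schemes combine them roughly as follows; the function name `score_generation`, the model handles, and the exact ratios shown here are assumptions, not code from this repository.

```python
# Hedged sketch of the paper's scoring schemes built on the metrics above;
# baseline_model / small_model correspond to the handles created in the
# inference script, and score_generation itself is hypothetical.
import torch


def score_generation(text, tokenizer, baseline_model, small_model, device):
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    ppl_large = calculate_batch_perplexity(input_ids, baseline_model)
    ppl_small = calculate_batch_perplexity(input_ids, small_model)
    zlib_entropy = calculate_batch_zlib(text)
    return {
        # low perplexity under the largest model suggests possible memorization
        "log_ppl_large": torch.log(ppl_large).item(),
        # memorized text is disproportionately likely under the larger model
        "small_to_large_ratio": (torch.log(ppl_small) / torch.log(ppl_large)).item(),
        # high compression entropy but low model perplexity flags non-trivial strings
        "zlib_to_ppl_ratio": zlib_entropy / torch.log(ppl_large).item(),
        # minimum perplexity over a 50-token sliding window
        "window_ppl": calculate_batch_window(input_ids, baseline_model),
    }
```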