imdb.py
from argparse import ArgumentParser
from typing import Optional

import torch
from argparse_utils.mapping import mapping_action
from datasets import load_dataset
from pytorch_lightning import LightningDataModule
from torch.utils.data.dataloader import DataLoader
from transformers import BertTokenizer
from transformers.tokenization_utils_base import (
    PaddingStrategy,
    TensorType,
    TruncationStrategy,
)
class IMDBDataModule(LightningDataModule):
    """LightningDataModule wrapping the HuggingFace IMDB sentiment dataset."""

    @staticmethod
    def add_argparse_args(parent_parser: ArgumentParser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--max_input_length", default=512, type=int)
        parser.add_argument("--batch_size", default=8, type=int, help="batch size")
        parser.add_argument("--val_batch_size", type=int, help="validation batch size")
        parser.add_argument(
            "--tokenizer_cls",
            action=mapping_action(dict(bert=BertTokenizer)),
            default="bert",
        )
        return parser

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

        # Default the validation batch size to a quarter of the training batch size.
        if self.hparams.val_batch_size is None:
            self.hparams.val_batch_size = max(self.hparams.batch_size // 4, 1)

        self.tokenizer = self.hparams.tokenizer_cls.from_pretrained(
            self.hparams.pretrained_model_name
        )

    def batch_collate(self, batch):
        # Tokenize the raw review texts, padding/truncating to max_input_length.
        input = self.tokenizer(
            [value["text"] for value in batch],
            max_length=self.hparams.max_input_length,
            padding=PaddingStrategy.MAX_LENGTH,
            truncation=TruncationStrategy.LONGEST_FIRST,
            return_tensors=TensorType.PYTORCH,
        )
        output = torch.tensor([value["label"] for value in batch], dtype=torch.long)
        # Return (model inputs, labels, raw examples) for each batch.
        return (
            dict(
                input_ids=input["input_ids"],
                attention_mask=input["attention_mask"],
            ),
            output,
            batch,
        )

    def setup(self, stage: Optional[str] = None):
        self.dataset = load_dataset("imdb")
        if stage == "fit" or stage is None:
            self.imdb_train = self.dataset["train"]
            self.imdb_val = self.dataset["test"]
        if stage == "test" or stage is None:
            self.imdb_test = self.dataset["test"]

    def prepare_data(self):
        # Download the dataset and tokenizer once before training starts.
        load_dataset("imdb")
        self.hparams.tokenizer_cls.from_pretrained(self.hparams.pretrained_model_name)

    def train_dataloader(self):
        return DataLoader(
            self.imdb_train,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=8,
            pin_memory=True,
            collate_fn=self.batch_collate,
        )

    def val_dataloader(self):
        return DataLoader(
            self.imdb_val,
            batch_size=self.hparams.val_batch_size,
            shuffle=False,
            num_workers=8,
            pin_memory=True,
            collate_fn=self.batch_collate,
        )

    def test_dataloader(self):
        return DataLoader(
            self.imdb_test,
            batch_size=self.hparams.val_batch_size,
            shuffle=False,
            num_workers=8,
            pin_memory=True,
            collate_fn=self.batch_collate,
        )
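

# Minimal usage sketch: one way to wire the data module up from the command line
# and inspect a single training batch. This is an illustration, not part of the
# module's API. It assumes `--pretrained_model_name` is registered by the model's
# own argparse args in the real training script (the data module reads
# hparams.pretrained_model_name but does not add that argument itself), so the
# sketch registers it here with an assumed default purely for demonstration.
if __name__ == "__main__":
    parser = ArgumentParser()
    # Assumed argument; normally supplied elsewhere in the training script.
    parser.add_argument("--pretrained_model_name", default="bert-base-uncased")
    parser = IMDBDataModule.add_argparse_args(parser)
    args = parser.parse_args()

    datamodule = IMDBDataModule(args)
    datamodule.prepare_data()
    datamodule.setup("fit")

    # Each batch is (model inputs, labels, raw examples), as built by batch_collate.
    inputs, labels, raw = next(iter(datamodule.train_dataloader()))
    print(inputs["input_ids"].shape, labels.shape)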