
add char level tokenizer, modify generation code to handle batched inference with char tok
brianhie committed Feb 18, 2024
1 parent 53c3b23 commit 4a1d35f
Showing 2 changed files with 159 additions and 12 deletions.
44 changes: 32 additions & 12 deletions src/generation.py
@@ -8,6 +8,7 @@

from src.sample import sample
from src.utils import print_rank_0
from src.tokenizer import CharLevelTokenizer # need to add a check for this type of tokenizer


class Generator:
@@ -32,11 +33,21 @@ def generate(
stop_at_eos=True,
max_seqlen=None,
):
eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)

        # self.tokenizer.eos may be an int id (as in CharLevelTokenizer) or a string token
        if isinstance(self.tokenizer.eos, int):
            eos_token_ids = torch.LongTensor([self.tokenizer.eos]).to(device)
        else:
            # eos is a string token: tokenize it into a tensor of ids
            eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)

if input_ids is None:
input = self.tokenizer.tokenize(input_string)
input = input.unsqueeze(0).to(device)
            if isinstance(input, list):
                input = torch.LongTensor(input).unsqueeze(0).to(device)
            else:
                # already a tensor
                input = input.unsqueeze(0).to(device)

else:
input = input_ids
x = input
@@ -76,6 +87,7 @@ def generate(
mem_after_tok = torch.cuda.memory_allocated(device=x.device) / 1e9
print_rank_0(f"Memory after tokenization: {mem_after_tok} GB")
print_rank_0("Starting generation...")
        torch.cuda.memory._record_memory_history(enabled=True)  # debug: record CUDA allocator history (private API)
if input_string is not None:
print_rank_0("Prompt: " + input_string)
else:
@@ -96,10 +108,12 @@ def generate(
inference_params_dict_out["mha"].seqlen_offset += 1
inference_params_dict_out["hyena"].seqlen_offset += 1

logits, inference_params_dict_out = self.model(
x,
inference_params_dict=inference_params_dict_out,
)
            # run the forward pass without tracking gradients (saves memory during inference)
with torch.no_grad():
logits, inference_params_dict_out = self.model(
x,
inference_params_dict=inference_params_dict_out,
)

last_logits = logits[:, -1]

@@ -113,7 +127,7 @@
if stop_at_eos and (generation[0, -2:] == eos_token_ids).all():
print_rank_0("Stopping generation at EOS")

if print_generation and verbose:
if print_generation and verbose and batch_size == 1:
print_rank_0(
f"{self.tokenizer.detokenize([new_idx.item()])}",
end=" ",
@@ -128,10 +142,16 @@
x = torch.cat([x, new_idx[:, None]], dim=-1)

if verbose:
y = self.tokenizer.detokenize_batch(
generation[:, : i + 1],
skip_special_tokens=skip_special_tokens,
)
                if isinstance(self.tokenizer, CharLevelTokenizer):
                    y = self.tokenizer.detokenize_batch(
                        generation[:, : i + 1],
                        # skip_special_tokens is not supported by CharLevelTokenizer
                    )
                else:
                    y = self.tokenizer.detokenize_batch(
                        generation[:, : i + 1],
                        skip_special_tokens=skip_special_tokens,
                    )

for until in self.untils:
if until in y:
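
The eos handling above can be checked in isolation. Below is a minimal sketch, not part of the diff, assuming only the two files in this commit; the vocab_size of 512 is an arbitrary placeholder:

import torch

from src.tokenizer import CharLevelTokenizer

tok = CharLevelTokenizer(vocab_size=512)

# CharLevelTokenizer.eos is an int id (0), so generate() takes the int branch
assert isinstance(tok.eos, int)
eos_token_ids = torch.LongTensor([tok.eos])  # tensor([0])

# An HF-style tokenizer exposes eos as a string instead, so generate()
# falls through to tokenizer.tokenize(tokenizer.eos) in the else branch.
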
127 changes: 127 additions & 0 deletions src/tokenizer.py
@@ -1,10 +1,14 @@
# based on https://github.com/EleutherAI/gpt-neox/blob/main/megatron/tokenizer/tokenizer.py
from abc import ABC
import json
import pathlib

import torch
import tqdm
from tokenizers import Tokenizer
from abc import abstractmethod
from typing import List, Union
import numpy as np


class HFAutoTokenizer:
@@ -67,3 +71,126 @@ def eod(self):
@property
def vocab_size(self):
return 32000

class AbstractTokenizer(ABC):
"""Abstract class for tokenizer."""

def __init__(self, name):
self.name = name
super().__init__()

@property
@abstractmethod
def vocab_size(self):
pass

@property
@abstractmethod
def vocab(self):
"""Dictionary from vocab text token to id token."""
pass

@property
@abstractmethod
def inv_vocab(self):
"""Dictionary from vocab id token to text token."""
pass

@abstractmethod
def tokenize(self, text):
pass

def detokenize(self, token_ids):
raise NotImplementedError(
"detokenizer is not implemented for {} " "tokenizer".format(self.name)
)

@property
def cls(self):
raise NotImplementedError(
"CLS is not provided for {} " "tokenizer".format(self.name)
)

@property
def sep(self):
raise NotImplementedError(
"SEP is not provided for {} " "tokenizer".format(self.name)
)

@property
def pad(self):
raise NotImplementedError(
"PAD is not provided for {} " "tokenizer".format(self.name)
)

@property
def eod(self):
raise NotImplementedError(
"EOD is not provided for {} " "tokenizer".format(self.name)
)

@property
def mask(self):
raise NotImplementedError(
"MASK is not provided for {} " "tokenizer".format(self.name)
)


class CharLevelTokenizer(AbstractTokenizer):
"""Character Level Tokenizer"""

def __init__(self, vocab_size):
name = "CharLevelTokenizer"
super().__init__(name)
self._vocab_size = vocab_size
self.eod_id = 0
self.eos_id = 0
self.pad_id = 1

    def clamp(self, n):
        # clamp a token id into [32, vocab_size] so decode_token yields a printable character
        return max(32, min(n, self.vocab_size))

@property
def vocab_size(self):
return self._vocab_size

@property
def vocab(self):
raise NotImplementedError

@property
def inv_vocab(self):
raise NotImplementedError

def decode_token(self, token: int):
return str(chr(self.clamp(token)))

    def tokenize(self, text: str):
        # map each character to its byte value; note np.fromstring is
        # deprecated (np.frombuffer(text.encode()) is the modern equivalent)
        return list(np.fromstring(text, dtype=np.uint8))

def tokenize_batch(self, text_batch: Union[List[str], str]):
if isinstance(text_batch, list):
return [self.tokenize(s) for s in text_batch]
else:
return self.tokenize(text_batch)

def detokenize(self, token_ids):
return "".join(list(map(self.decode_token, token_ids)))

    def detokenize_batch(self, token_ids: Union[List[List[int]], torch.Tensor]):
        if isinstance(token_ids, list):
            return [self.detokenize(s) for s in token_ids]
        elif isinstance(token_ids, torch.Tensor):
            # a tensor: convert to nested lists of ids first
            return [self.detokenize(s) for s in token_ids.tolist()]
        else:
            return self.detokenize(token_ids)

@property
def eod(self):
return self.eod_id

    # alias so both names, eos and eod, are supported
@property
def eos(self):
return self.eod_id
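
A short round trip through the new tokenizer, as a sketch based on the code above (the vocab_size of 512 is again an arbitrary placeholder):

import torch

from src.tokenizer import CharLevelTokenizer

tok = CharLevelTokenizer(vocab_size=512)

ids = tok.tokenize("ACGT")             # byte values [65, 67, 71, 84]
assert tok.detokenize(ids) == "ACGT"   # decoded back through chr()

# detokenize_batch accepts a 2D tensor and decodes it row by row,
# which is what the batched path in generation.py relies on
batch = torch.LongTensor([ids, ids])
assert tok.detokenize_batch(batch) == ["ACGT", "ACGT"]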
