Commit 37caf4d

Merge pull request #11 from brianhie/char_tokenizer1

feat: handle byte-level tokenization

Zymrael authored Feb 18, 2024
2 parents 0661f6e + 047503a commit 37caf4d
Showing 2 changed files with 138 additions and 11 deletions.
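
In short, this commit lets generation run with a byte-level (character) tokenizer whose special tokens are plain integer ids rather than strings. A minimal round-trip sketch of the CharLevelTokenizer introduced below; the vocab_size of 512 is an illustrative assumption, not a value from the diff:

    # Hedged sketch: byte-level round trip with the new CharLevelTokenizer.
    from src.tokenizer import CharLevelTokenizer

    tokenizer = CharLevelTokenizer(vocab_size=512)  # vocab_size is illustrative
    ids = tokenizer.tokenize("ACGT")   # one id per byte: [65, 67, 71, 84]
    text = tokenizer.detokenize(ids)   # back to "ACGT"
    assert text == "ACGT"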
34 changes: 23 additions & 11 deletions src/generation.py
@@ -7,6 +7,7 @@
import torch

from src.sample import sample
from src.tokenizer import CharLevelTokenizer
from src.utils import print_rank_0


@@ -32,11 +33,20 @@ def generate(
stop_at_eos=True,
max_seqlen=None,
):
eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)
if isinstance(self.tokenizer.eos, int):
eos_token_ids = torch.LongTensor([self.tokenizer.eos]).to(device)
else:
            # eos is a string; tokenize it into a tensor of ids
eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)

if input_ids is None:
input = self.tokenizer.tokenize(input_string)
input = input.unsqueeze(0).to(device)
if isinstance(input, list):
input = torch.LongTensor(input).unsqueeze(0).to(device)
            else:
                # already a tensor
                input = input.unsqueeze(0).to(device)

else:
input = input_ids
x = input
@@ -96,10 +106,12 @@ def generate(
inference_params_dict_out["mha"].seqlen_offset += 1
inference_params_dict_out["hyena"].seqlen_offset += 1

logits, inference_params_dict_out = self.model(
x,
inference_params_dict=inference_params_dict_out,
)
# do forward pass with no gradient
with torch.no_grad():
logits, inference_params_dict_out = self.model(
x,
inference_params_dict=inference_params_dict_out,
)

last_logits = logits[:, -1]

@@ -113,7 +125,7 @@ def generate(
if stop_at_eos and (generation[0, -2:] == eos_token_ids).all():
print_rank_0("Stopping generation at EOS")

if print_generation and verbose:
if print_generation and verbose and batch_size == 1:
print_rank_0(
f"{self.tokenizer.detokenize([new_idx.item()])}",
end=" ",
@@ -128,10 +140,10 @@ def generate(
x = torch.cat([x, new_idx[:, None]], dim=-1)

if verbose:
y = self.tokenizer.detokenize_batch(
generation[:, : i + 1],
skip_special_tokens=skip_special_tokens,
)
kwargs = {}
if not isinstance(self.tokenizer, CharLevelTokenizer):
kwargs["skip_special_tokens"] = skip_special_tokens
y = self.tokenizer.detokenize_batch(generation[:, : i + 1], **kwargs)

for until in self.untils:

[GitHub Actions / codespell check failure on line 148 in src/generation.py: untils ==> until, utils]
if until in y:
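
Taken together, the generation changes normalize two tokenizer interfaces: a char-level tokenizer exposes eos as a plain int id, while the sentencepiece-backed tokenizer exposes a string that tokenize() converts to a tensor. A hedged sketch of that dispatch; the helper name normalize_eos is hypothetical, not part of this diff:

    import torch

    # Hypothetical helper mirroring the eos handling added above.
    def normalize_eos(tokenizer, device="cpu"):
        if isinstance(tokenizer.eos, int):
            # char-level: eos is already a token id
            return torch.LongTensor([tokenizer.eos]).to(device)
        # otherwise eos is a string; tokenize() returns a tensor of ids
        return tokenizer.tokenize(tokenizer.eos).to(device)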
115 changes: 115 additions & 0 deletions src/tokenizer.py
@@ -1,7 +1,10 @@
# based on https://github.com/EleutherAI/gpt-neox/blob/main/megatron/tokenizer/tokenizer.py
import json
import pathlib
from abc import ABC, abstractmethod
from typing import List, Union

import numpy as np
import torch
import tqdm
from tokenizers import Tokenizer
@@ -67,3 +70,115 @@ def eod(self):
@property
def vocab_size(self):
return 32000


class AbstractTokenizer(ABC):
"""Abstract class for tokenizer."""

def __init__(self, name):
self.name = name
super().__init__()

@property
@abstractmethod
def vocab_size(self):
pass

@property
@abstractmethod
def vocab(self):
"""Dictionary from vocab text token to id token."""
pass

@property
@abstractmethod
def inv_vocab(self):
"""Dictionary from vocab id token to text token."""
pass

@abstractmethod
def tokenize(self, text):
pass

    def detokenize(self, token_ids):
        raise NotImplementedError("detokenizer is not implemented for {} tokenizer".format(self.name))

    @property
    def cls(self):
        raise NotImplementedError("CLS is not provided for {} tokenizer".format(self.name))

    @property
    def sep(self):
        raise NotImplementedError("SEP is not provided for {} tokenizer".format(self.name))

    @property
    def pad(self):
        raise NotImplementedError("PAD is not provided for {} tokenizer".format(self.name))

    @property
    def eod(self):
        raise NotImplementedError("EOD is not provided for {} tokenizer".format(self.name))

    @property
    def mask(self):
        raise NotImplementedError("MASK is not provided for {} tokenizer".format(self.name))


class CharLevelTokenizer(AbstractTokenizer):
"""Character Level Tokenizer"""

def __init__(self, vocab_size):
name = "CharLevelTokenizer"
super().__init__(name)
self._vocab_size = vocab_size
self.eod_id = 0
self.eos_id = 0
self.pad_id = 1

def clamp(self, n):
return max(32, min(n, self.vocab_size))

@property
def vocab_size(self):
return self._vocab_size

@property
def vocab(self):
raise NotImplementedError

@property
def inv_vocab(self):
raise NotImplementedError

def decode_token(self, token: int):
return str(chr(self.clamp(token)))

    def tokenize(self, text: str):
        # np.fromstring is deprecated; np.frombuffer over the UTF-8 bytes is equivalent
        return list(np.frombuffer(text.encode("utf-8"), dtype=np.uint8))

def tokenize_batch(self, text_batch: Union[List[str], str]):
if isinstance(text_batch, list):
return [self.tokenize(s) for s in text_batch]
else:
return self.tokenize(text_batch)

def detokenize(self, token_ids):
return "".join(list(map(self.decode_token, token_ids)))

    def detokenize_batch(self, token_ids: Union[List[List[int]], torch.Tensor]):
        if isinstance(token_ids, list):
            return [self.detokenize(s) for s in token_ids]
        # if a tensor, convert it to a list first
        elif isinstance(token_ids, torch.Tensor):
            return [self.detokenize(s) for s in token_ids.tolist()]
        else:
            return self.detokenize(token_ids)

@property
def eod(self):
return self.eod_id

    # duplicated to support both names, eos and eod
@property
def eos(self):
return self.eod_id
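
A short usage sketch of the batch helpers above; the tensor input exercises the torch.Tensor branch of detokenize_batch, and the values are illustrative:

    import torch
    from src.tokenizer import CharLevelTokenizer

    tok = CharLevelTokenizer(vocab_size=512)  # vocab_size is illustrative
    batch = tok.tokenize_batch(["ACGT", "TTAA"])  # [[65, 67, 71, 84], [84, 84, 65, 65]]
    print(tok.detokenize_batch(torch.LongTensor(batch)))  # ['ACGT', 'TTAA']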
