Handle char level tokenization #11

Merged: 4 commits, Feb 18, 2024
34 changes: 23 additions & 11 deletions src/generation.py
@@ -7,6 +7,7 @@
import torch

from src.sample import sample
from src.tokenizer import CharLevelTokenizer
from src.utils import print_rank_0


@@ -17,7 +18,7 @@
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.untils = ["\n\n"]

[GitHub Actions / codespell check failure on line 21 in src/generation.py: untils ==> until, utils]

def generate(
self,
@@ -32,11 +33,20 @@
stop_at_eos=True,
max_seqlen=None,
):
eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)
if isinstance(self.tokenizer.eos, int):
eos_token_ids = torch.LongTensor([self.tokenizer.eos]).to(device)
else:
# eos is not an int here; tokenize it into a tensor of ids
eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)

if input_ids is None:
input = self.tokenizer.tokenize(input_string)
input = input.unsqueeze(0).to(device)
if isinstance(input, list):
input = torch.LongTensor(input).unsqueeze(0).to(device)
# otherwise, input is already a tensor
else:
input = input.unsqueeze(0).to(device)

else:
input = input_ids
x = input
@@ -96,10 +106,12 @@
inference_params_dict_out["mha"].seqlen_offset += 1
inference_params_dict_out["hyena"].seqlen_offset += 1

logits, inference_params_dict_out = self.model(
x,
inference_params_dict=inference_params_dict_out,
)
# do forward pass with no gradient
with torch.no_grad():
logits, inference_params_dict_out = self.model(
x,
inference_params_dict=inference_params_dict_out,
)

last_logits = logits[:, -1]

@@ -113,7 +125,7 @@
if stop_at_eos and (generation[0, -2:] == eos_token_ids).all():
print_rank_0("Stopping generation at EOS")

if print_generation and verbose:
if print_generation and verbose and batch_size == 1:
print_rank_0(
f"{self.tokenizer.detokenize([new_idx.item()])}",
end=" ",
@@ -128,12 +140,12 @@
x = torch.cat([x, new_idx[:, None]], dim=-1)

if verbose:
y = self.tokenizer.detokenize_batch(
generation[:, : i + 1],
skip_special_tokens=skip_special_tokens,
)
kwargs = {}
if not isinstance(self.tokenizer, CharLevelTokenizer):
kwargs["skip_special_tokens"] = skip_special_tokens
y = self.tokenizer.detokenize_batch(generation[:, : i + 1], **kwargs)

for until in self.untils:

[GitHub Actions / codespell check failure on line 148 in src/generation.py: untils ==> until, utils]
if until in y:
y = y.split(until)[0]
break
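For reference, a minimal sketch of the branching this change introduces in generate(): CharLevelTokenizer.tokenize returns a plain Python list and its eos is an integer id, while the sentencepiece-backed tokenizer returns a tensor and a string EOS, and only the latter's detokenize_batch accepts skip_special_tokens. The helper names below (eos_token_tensor, normalize_input, detokenize_generation) are illustrative, not part of the diff.

```python
import torch

from src.tokenizer import CharLevelTokenizer


def eos_token_tensor(tokenizer, device="cpu"):
    # CharLevelTokenizer exposes eos as an int id; other tokenizers expose a
    # string that must be tokenized into a tensor of ids.
    if isinstance(tokenizer.eos, int):
        return torch.LongTensor([tokenizer.eos]).to(device)
    return tokenizer.tokenize(tokenizer.eos).to(device)


def normalize_input(tokenizer, input_string, device="cpu"):
    # tokenize() may return a Python list (char-level) or a tensor
    # (sentencepiece); either way we want a (1, seq_len) LongTensor.
    ids = tokenizer.tokenize(input_string)
    if isinstance(ids, list):
        return torch.LongTensor(ids).unsqueeze(0).to(device)
    return ids.unsqueeze(0).to(device)


def detokenize_generation(tokenizer, generation, skip_special_tokens=True):
    # CharLevelTokenizer.detokenize_batch takes no skip_special_tokens kwarg,
    # so only pass it for other tokenizers.
    kwargs = {}
    if not isinstance(tokenizer, CharLevelTokenizer):
        kwargs["skip_special_tokens"] = skip_special_tokens
    return tokenizer.detokenize_batch(generation, **kwargs)
```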
115 changes: 115 additions & 0 deletions src/tokenizer.py
@@ -1,7 +1,10 @@
# based on https://github.com/EleutherAI/gpt-neox/blob/main/megatron/tokenizer/tokenizer.py
import json
import pathlib
from abc import ABC, abstractmethod
from typing import List, Union

import numpy as np
import torch
import tqdm
from tokenizers import Tokenizer
@@ -67,3 +70,115 @@ def eod(self):
@property
def vocab_size(self):
return 32000


class AbstractTokenizer(ABC):
"""Abstract class for tokenizer."""

def __init__(self, name):
self.name = name
super().__init__()

@property
@abstractmethod
def vocab_size(self):
pass

@property
@abstractmethod
def vocab(self):
"""Dictionary from vocab text token to id token."""
pass

@property
@abstractmethod
def inv_vocab(self):
"""Dictionary from vocab id token to text token."""
pass

@abstractmethod
def tokenize(self, text):
pass

def detokenize(self, token_ids):
raise NotImplementedError("detokenizer is not implemented for {} " "tokenizer".format(self.name))

@property
def cls(self):
raise NotImplementedError("CLS is not provided for {} " "tokenizer".format(self.name))

@property
def sep(self):
raise NotImplementedError("SEP is not provided for {} " "tokenizer".format(self.name))

@property
def pad(self):
raise NotImplementedError("PAD is not provided for {} " "tokenizer".format(self.name))

@property
def eod(self):
raise NotImplementedError("EOD is not provided for {} " "tokenizer".format(self.name))

@property
def mask(self):
raise NotImplementedError("MASK is not provided for {} " "tokenizer".format(self.name))


class CharLevelTokenizer(AbstractTokenizer):
"""Character Level Tokenizer"""

def __init__(self, vocab_size):
name = "CharLevelTokenizer"
super().__init__(name)
self._vocab_size = vocab_size
self.eod_id = 0
self.eos_id = 0
self.pad_id = 1

def clamp(self, n):
return max(32, min(n, self.vocab_size))

@property
def vocab_size(self):
return self._vocab_size

@property
def vocab(self):
raise NotImplementedError

@property
def inv_vocab(self):
raise NotImplementedError

def decode_token(self, token: int):
return str(chr(self.clamp(token)))

def tokenize(self, text: str):
return list(np.fromstring(text, dtype=np.uint8))

def tokenize_batch(self, text_batch: Union[List[str], str]):
if isinstance(text_batch, list):
return [self.tokenize(s) for s in text_batch]
else:
return self.tokenize(text_batch)

def detokenize(self, token_ids):
return "".join(list(map(self.decode_token, token_ids)))

def detokenize_batch(self, token_ids: Union[List[str], str]):
if isinstance(token_ids, list):
return [self.detokenize(s) for s in token_ids]
# if a tensor, convert to a list first
elif isinstance(token_ids, torch.Tensor):
return [self.detokenize(s) for s in token_ids.tolist()]
else:
return self.detokenize(token_ids)

@property
def eod(self):
return self.eod_id

# duplicated to support both names, eos and eod
@property
def eos(self):
return self.eod_id
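
A quick usage sketch of the new CharLevelTokenizer (not part of the diff; the vocab_size below is illustrative). Since decode_token clamps ids to [32, vocab_size], characters below code point 32 do not round-trip exactly. Note also that np.fromstring is deprecated in recent NumPy releases; np.frombuffer(text.encode(), dtype=np.uint8) is the usual drop-in replacement.

```python
import torch

from src.tokenizer import CharLevelTokenizer

tok = CharLevelTokenizer(vocab_size=512)  # vocab_size here is illustrative

# tokenize returns a plain Python list of byte values
ids = tok.tokenize("ACGT")
print(ids)                  # e.g. 65, 67, 71, 84

# detokenize maps ids back to characters (clamped to [32, vocab_size])
print(tok.detokenize(ids))  # "ACGT"

# detokenize_batch also accepts a 2D tensor of ids
batch = torch.LongTensor([list(map(int, ids)), list(map(int, ids))])
print(tok.detokenize_batch(batch))  # ["ACGT", "ACGT"]

# eos and eod are integer ids, unlike the string EOS of the sentencepiece tokenizer
assert tok.eos == tok.eod == 0
```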