diff --git a/src/generation.py b/src/generation.py
index 32e1bdc..a998c3f 100644
--- a/src/generation.py
+++ b/src/generation.py
@@ -7,6 +7,7 @@
 import torch
 
 from src.sample import sample
+from src.tokenizer import CharLevelTokenizer
 from src.utils import print_rank_0
 
 
@@ -32,11 +33,20 @@ def generate(
         stop_at_eos=True,
         max_seqlen=None,
     ):
-        eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)
+        if isinstance(self.tokenizer.eos, int):
+            eos_token_ids = torch.LongTensor([self.tokenizer.eos]).to(device)
+        else:
+            # is a tensor
+            eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)
 
         if input_ids is None:
             input = self.tokenizer.tokenize(input_string)
-            input = input.unsqueeze(0).to(device)
+            if isinstance(input, list):
+                input = torch.LongTensor(input).unsqueeze(0).to(device)
+            # is a tensor
+            else:
+                input = input.unsqueeze(0).to(device)
+
         else:
             input = input_ids
         x = input
@@ -96,10 +106,12 @@ def generate(
                     inference_params_dict_out["mha"].seqlen_offset += 1
                     inference_params_dict_out["hyena"].seqlen_offset += 1
 
-            logits, inference_params_dict_out = self.model(
-                x,
-                inference_params_dict=inference_params_dict_out,
-            )
+            # do forward pass with no gradient
+            with torch.no_grad():
+                logits, inference_params_dict_out = self.model(
+                    x,
+                    inference_params_dict=inference_params_dict_out,
+                )
 
             last_logits = logits[:, -1]
 
@@ -113,7 +125,7 @@ def generate(
             if stop_at_eos and (generation[0, -2:] == eos_token_ids).all():
                 print_rank_0("Stopping generation at EOS")
 
-            if print_generation and verbose:
+            if print_generation and verbose and batch_size == 1:
                 print_rank_0(
                     f"{self.tokenizer.detokenize([new_idx.item()])}",
                     end=" ",
@@ -128,10 +140,10 @@ def generate(
             x = torch.cat([x, new_idx[:, None]], dim=-1)
 
             if verbose:
-                y = self.tokenizer.detokenize_batch(
-                    generation[:, : i + 1],
-                    skip_special_tokens=skip_special_tokens,
-                )
+                kwargs = {}
+                if not isinstance(self.tokenizer, CharLevelTokenizer):
+                    kwargs["skip_special_tokens"] = skip_special_tokens
+                y = self.tokenizer.detokenize_batch(generation[:, : i + 1], **kwargs)
 
                 for until in self.untils:
                     if until in y:
diff --git a/src/tokenizer.py b/src/tokenizer.py
index 6792c94..ed25d72 100644
--- a/src/tokenizer.py
+++ b/src/tokenizer.py
@@ -1,7 +1,10 @@
 # based on https://github.com/EleutherAI/gpt-neox/blob/main/megatron/tokenizer/tokenizer.py
 import json
 import pathlib
+from abc import ABC, abstractmethod
+from typing import List, Union
 
+import numpy as np
 import torch
 import tqdm
 from tokenizers import Tokenizer
@@ -67,3 +70,115 @@ def eod(self):
     @property
     def vocab_size(self):
         return 32000
+
+
+class AbstractTokenizer(ABC):
+    """Abstract class for tokenizer."""
+
+    def __init__(self, name):
+        self.name = name
+        super().__init__()
+
+    @property
+    @abstractmethod
+    def vocab_size(self):
+        pass
+
+    @property
+    @abstractmethod
+    def vocab(self):
+        """Dictionary from vocab text token to id token."""
+        pass
+
+    @property
+    @abstractmethod
+    def inv_vocab(self):
+        """Dictionary from vocab id token to text token."""
+        pass
+
+    @abstractmethod
+    def tokenize(self, text):
+        pass
+
+    def detokenize(self, token_ids):
+        raise NotImplementedError("detokenizer is not implemented for {} " "tokenizer".format(self.name))
+
+    @property
+    def cls(self):
+        raise NotImplementedError("CLS is not provided for {} " "tokenizer".format(self.name))
+
+    @property
+    def sep(self):
+        raise NotImplementedError("SEP is not provided for {} " "tokenizer".format(self.name))
+
+    @property
+    def pad(self):
+        raise NotImplementedError("PAD is not provided for {} " "tokenizer".format(self.name))
+
+    @property
+    def eod(self):
+        raise NotImplementedError("EOD is not provided for {} " "tokenizer".format(self.name))
+
+    @property
+    def mask(self):
+        raise NotImplementedError("MASK is not provided for {} " "tokenizer".format(self.name))
+
+
+class CharLevelTokenizer(AbstractTokenizer):
+    """Character Level Tokenizer"""
+
+    def __init__(self, vocab_size):
+        name = "CharLevelTokenizer"
+        super().__init__(name)
+        self._vocab_size = vocab_size
+        self.eod_id = 0
+        self.eos_id = 0
+        self.pad_id = 1
+
+    def clamp(self, n):
+        return max(32, min(n, self.vocab_size))
+
+    @property
+    def vocab_size(self):
+        return self._vocab_size
+
+    @property
+    def vocab(self):
+        raise NotImplementedError
+
+    @property
+    def inv_vocab(self):
+        raise NotImplementedError
+
+    def decode_token(self, token: int):
+        return str(chr(self.clamp(token)))
+
+    def tokenize(self, text: str):
+        return list(np.fromstring(text, dtype=np.uint8))
+
+    def tokenize_batch(self, text_batch: Union[List[str], str]):
+        if isinstance(text_batch, list):
+            return [self.tokenize(s) for s in text_batch]
+        else:
+            return self.tokenize(text_batch)
+
+    def detokenize(self, token_ids):
+        return "".join(list(map(self.decode_token, token_ids)))
+
+    def detokenize_batch(self, token_ids: Union[List[str], str]):
+        if isinstance(token_ids, list):
+            return [self.detokenize(s) for s in token_ids]
+        # elif if tensor, convert to list first
+        elif isinstance(token_ids, torch.Tensor):
+            return [self.detokenize(s) for s in token_ids.tolist()]
+        else:
+            return self.detokenize(token_ids)
+
+    @property
+    def eod(self):
+        return self.eod_id
+
+    # duplicate to suppose both names, eos and eod
+    @property
+    def eos(self):
+        return self.eod_id