diff --git a/src/generation.py b/src/generation.py index 106ddb2..49d7290 100644 --- a/src/generation.py +++ b/src/generation.py @@ -8,7 +8,7 @@ from src.sample import sample from src.utils import print_rank_0 -from src.tokenizer import CharLevelTokenizer # need to add a check for this type of tokenizer +from src.tokenizer import CharLevelTokenizer class Generator: @@ -33,7 +33,6 @@ def generate( stop_at_eos=True, max_seqlen=None, ): - # check dtype if self.tokenizer.eos is int if isinstance(self.tokenizer.eos, int): eos_token_ids = torch.LongTensor([self.tokenizer.eos]).to(device) else: