black and isort
brianhie committed Feb 18, 2024
1 parent 171fb71 commit 047503a
Showing 2 changed files with 15 additions and 27 deletions.
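
Both tools are purely mechanical formatters, so nothing below changes behavior: isort reorders and groups imports, while black normalizes whitespace and line wrapping. As a rough, hedged illustration (the commit was presumably produced by running the command-line tools, not this snippet), the Python APIs of both packages can reproduce the two kinds of change visible in this diff; reasonably recent black and isort releases are assumed.

import black
import isort

# isort alphabetizes imports within a section, which is why print_rank_0
# moves below CharLevelTokenizer in src/generation.py.
imports = (
    "from src.sample import sample\n"
    "from src.utils import print_rank_0\n"
    "from src.tokenizer import CharLevelTokenizer\n"
)
print(isort.code(imports))

# black treats a non-trivial slice bound like a binary expression, which is
# why generation[:, :i+1] becomes generation[:, : i + 1].
src = "y = detokenize_batch(generation[:, :i+1], **kwargs)\n"
print(black.format_str(src, mode=black.Mode()))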
8 changes: 4 additions & 4 deletions src/generation.py
@@ -7,8 +7,8 @@
 import torch
 
 from src.sample import sample
-from src.utils import print_rank_0
 from src.tokenizer import CharLevelTokenizer
+from src.utils import print_rank_0
 
 
 class Generator:
@@ -38,15 +38,15 @@ def generate(
         else:
             # is a tensor
             eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)
-
+
         if input_ids is None:
             input = self.tokenizer.tokenize(input_string)
             if isinstance(input, list):
                 input = torch.LongTensor(input).unsqueeze(0).to(device)
             # is a tensor
             else:
                 input = input.unsqueeze(0).to(device)
-
+
         else:
             input = input_ids
         x = input
@@ -143,7 +143,7 @@ def generate(
             kwargs = {}
             if not isinstance(self.tokenizer, CharLevelTokenizer):
                 kwargs["skip_special_tokens"] = skip_special_tokens
-            y = self.tokenizer.detokenize_batch(generation[:, :i+1], **kwargs)
+            y = self.tokenizer.detokenize_batch(generation[:, : i + 1], **kwargs)
 
             for until in self.untils:
                 if until in y:

Check failure on line 148 in src/generation.py (GitHub Actions / codespell): untils ==> until, utils
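
The codespell failure above flags untils, but self.untils evidently holds the stop strings that end a generation early, so the name looks intentional rather than a typo (the kind of finding usually silenced via codespell's ignore list). How the Generator truncates once a stop string appears is not visible in this diff; the helper below is only a standalone sketch of that general pattern, and every name in it is hypothetical.

from typing import List


def truncate_at_stop_strings(text: str, untils: List[str]) -> str:
    # Cut the decoded generation at the earliest occurrence of any stop string.
    cut = len(text)
    for until in untils:
        idx = text.find(until)
        if idx != -1:
            cut = min(cut, idx)
    return text[:cut]


print(truncate_at_stop_strings("ACGTACGT|end|junk", ["|end|"]))  # -> "ACGTACGT"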
34 changes: 11 additions & 23 deletions src/tokenizer.py
@@ -1,14 +1,13 @@
 # based on https://github.com/EleutherAI/gpt-neox/blob/main/megatron/tokenizer/tokenizer.py
-from abc import ABC
 import json
 import pathlib
+from abc import ABC, abstractmethod
+from typing import List, Union
+
+import numpy as np
 import torch
 import tqdm
 from tokenizers import Tokenizer
-from abc import abstractmethod
-from typing import List, Union
-import numpy as np
-
 
 
 class HFAutoTokenizer:
@@ -72,6 +71,7 @@ def eod(self):
     def vocab_size(self):
         return 32000
 
+
 class AbstractTokenizer(ABC):
     """Abstract class for tokenizer."""
 
@@ -101,39 +101,27 @@ def tokenize(self, text):
         pass
 
     def detokenize(self, token_ids):
-        raise NotImplementedError(
-            "detokenizer is not implemented for {} " "tokenizer".format(self.name)
-        )
+        raise NotImplementedError("detokenizer is not implemented for {} " "tokenizer".format(self.name))
 
     @property
     def cls(self):
-        raise NotImplementedError(
-            "CLS is not provided for {} " "tokenizer".format(self.name)
-        )
+        raise NotImplementedError("CLS is not provided for {} " "tokenizer".format(self.name))
 
     @property
     def sep(self):
-        raise NotImplementedError(
-            "SEP is not provided for {} " "tokenizer".format(self.name)
-        )
+        raise NotImplementedError("SEP is not provided for {} " "tokenizer".format(self.name))
 
     @property
     def pad(self):
-        raise NotImplementedError(
-            "PAD is not provided for {} " "tokenizer".format(self.name)
-        )
+        raise NotImplementedError("PAD is not provided for {} " "tokenizer".format(self.name))
 
     @property
     def eod(self):
-        raise NotImplementedError(
-            "EOD is not provided for {} " "tokenizer".format(self.name)
-        )
+        raise NotImplementedError("EOD is not provided for {} " "tokenizer".format(self.name))
 
     @property
     def mask(self):
-        raise NotImplementedError(
-            "MASK is not provided for {} " "tokenizer".format(self.name)
-        )
+        raise NotImplementedError("MASK is not provided for {} " "tokenizer".format(self.name))
 
 
 class CharLevelTokenizer(AbstractTokenizer):
@@ -193,4 +181,4 @@ def eod(self):
     # duplicate to suppose both names, eos and eod
     @property
     def eos(self):
-        return self.eod_id
\ No newline at end of file
+        return self.eod_id
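
Taken together, the tokenizer.py hunks show the contract: tokenize is abstract, every special-token accessor defaults to NotImplementedError keyed on self.name, and CharLevelTokenizer aliases eos to eod_id. The sketch below is a self-contained illustration of a minimal concrete subclass under that contract; the stand-in base class and the byte-level tokenizer are assumptions for illustration, not code from this repository.

from abc import ABC, abstractmethod

import torch


class SketchTokenizer(ABC):
    # Stand-in that mirrors only what the diff shows; the real
    # AbstractTokenizer may define more than this.
    def __init__(self, name):
        self.name = name

    @abstractmethod
    def tokenize(self, text):
        pass

    def detokenize(self, token_ids):
        raise NotImplementedError("detokenizer is not implemented for {} tokenizer".format(self.name))

    @property
    def eod(self):
        raise NotImplementedError("EOD is not provided for {} tokenizer".format(self.name))


class SketchByteTokenizer(SketchTokenizer):
    def __init__(self):
        super().__init__("byte")
        self.eod_id = 0  # hypothetical special-token id

    def tokenize(self, text):
        # One token id per UTF-8 byte.
        return torch.LongTensor(list(text.encode("utf-8")))

    def detokenize(self, token_ids):
        return bytes(int(t) for t in token_ids).decode("utf-8", errors="ignore")

    @property
    def eod(self):
        return self.eod_id

    # Duplicate to support both names, eos and eod, mirroring the diff above.
    @property
    def eos(self):
        return self.eod_id


tok = SketchByteTokenizer()
print(tok.detokenize(tok.tokenize("ACGT")))  # -> "ACGT"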
