Handle char level tokenization #11

Merged
merged 4 commits on Feb 18, 2024
Changes from 2 commits
43 changes: 31 additions & 12 deletions src/generation.py
@@ -1,4 +1,4 @@
# Copyright (c) Together

[Check failure on line 1 in src/generation.py (GitHub Actions / isort): Imports are incorrectly sorted and/or formatted.]
# This software is distributed under the terms of the Apache License, Version 2.0
# Author: Michael Poli

@@ -8,6 +8,7 @@

from src.sample import sample
from src.utils import print_rank_0
from src.tokenizer import CharLevelTokenizer


class Generator:
@@ -17,7 +18,7 @@
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.untils = ["\n\n"]

[Check failure on line 21 in src/generation.py (GitHub Actions / codespell): untils ==> until, utils]

def generate(
self,
@@ -32,11 +33,20 @@
stop_at_eos=True,
max_seqlen=None,
):
eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)

if isinstance(self.tokenizer.eos, int):
eos_token_ids = torch.LongTensor([self.tokenizer.eos]).to(device)
else:
# not an int; tokenizing eos returns a tensor
eos_token_ids = self.tokenizer.tokenize(self.tokenizer.eos).to(device)

if input_ids is None:
input = self.tokenizer.tokenize(input_string)
input = input.unsqueeze(0).to(device)
if isinstance(input, list):
input = torch.LongTensor(input).unsqueeze(0).to(device)
# otherwise, input is already a tensor
else:
input = input.unsqueeze(0).to(device)

else:
input = input_ids
x = input
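A minimal sketch of what the two new branches above accommodate, using the classes from this PR (the vocab size and prompt are illustrative; note that CharLevelTokenizer.tokenize relies on np.fromstring, which newer NumPy deprecates):

```python
import torch

from src.tokenizer import CharLevelTokenizer

char_tok = CharLevelTokenizer(vocab_size=512)

# CharLevelTokenizer.eos is a plain int (0), so it must be wrapped into a
# tensor, unlike tokenizers whose eos is a string to be tokenized.
assert isinstance(char_tok.eos, int)
eos_token_ids = torch.LongTensor([char_tok.eos])

# CharLevelTokenizer.tokenize returns a Python list of byte values rather
# than a tensor, so the input branch wraps it as well.
ids = char_tok.tokenize("ACGT")
assert isinstance(ids, list)
x = torch.LongTensor(ids).unsqueeze(0)  # shape (1, seq_len)
```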
@@ -76,6 +86,7 @@
mem_after_tok = torch.cuda.memory_allocated(device=x.device) / 1e9
print_rank_0(f"Memory after tokenization: {mem_after_tok} GB")
print_rank_0("Starting generation...")
torch.cuda.memory._record_memory_history(enabled=True)
Contributor: This shouldn't be on by default

Contributor Author: Removed

if input_string is not None:
print_rank_0("Prompt: " + input_string)
else:
@@ -96,10 +107,12 @@
inference_params_dict_out["mha"].seqlen_offset += 1
inference_params_dict_out["hyena"].seqlen_offset += 1

logits, inference_params_dict_out = self.model(
x,
inference_params_dict=inference_params_dict_out,
)
# do forward pass with no gradient
with torch.no_grad():
logits, inference_params_dict_out = self.model(
x,
inference_params_dict=inference_params_dict_out,
)

last_logits = logits[:, -1]

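For context on the torch.no_grad() wrapper added above: it keeps autograd from recording a graph during generation, which saves memory at inference time. A standalone illustration with a toy model, not from this PR:

```python
import torch

model = torch.nn.Linear(16, 16)
x = torch.randn(1, 16)

with torch.no_grad():
    y = model(x)

# No autograd graph was recorded, so the output cannot backpropagate.
assert not y.requires_grad
```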
@@ -113,7 +126,7 @@
if stop_at_eos and (generation[0, -2:] == eos_token_ids).all():
print_rank_0("Stopping generation at EOS")

if print_generation and verbose:
if print_generation and verbose and batch_size == 1:
print_rank_0(
f"{self.tokenizer.detokenize([new_idx.item()])}",
end=" ",
@@ -128,12 +141,18 @@
x = torch.cat([x, new_idx[:, None]], dim=-1)

if verbose:
y = self.tokenizer.detokenize_batch(
generation[:, : i + 1],
skip_special_tokens=skip_special_tokens,
)
if isinstance(self.tokenizer, CharLevelTokenizer):
Contributor: Maybe could slightly reformat this to have different args depending on the Tokenizer class

Contributor Author: Reformatted

y = self.tokenizer.detokenize_batch(
generation[:, : i + 1],
# skip_special_tokens=skip_special_tokens, # this isn't supported in the Char level tokenizer
)
else:
y = self.tokenizer.detokenize_batch(
generation[:, : i + 1],
skip_special_tokens=skip_special_tokens,
)

for until in self.untils:

[Check failure on line 155 in src/generation.py (GitHub Actions / codespell): untils ==> until, utils]
if until in y:
y = y.split(until)[0]
break
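A hedged sketch of the reviewer's suggestion above, dispatching on the tokenizer class in one place instead of branching inline. The helper name is hypothetical and not part of this PR:

```python
def detokenize_for(tokenizer, token_ids, skip_special_tokens=True):
    # Hypothetical helper: CharLevelTokenizer.detokenize_batch does not
    # accept skip_special_tokens, so only pass it where it is supported.
    if isinstance(tokenizer, CharLevelTokenizer):
        return tokenizer.detokenize_batch(token_ids)
    return tokenizer.detokenize_batch(
        token_ids, skip_special_tokens=skip_special_tokens
    )
```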
127 changes: 127 additions & 0 deletions src/tokenizer.py
@@ -1,10 +1,14 @@
# based on https://github.com/EleutherAI/gpt-neox/blob/main/megatron/tokenizer/tokenizer.py

[Check failure on line 1 in src/tokenizer.py (GitHub Actions / isort): Imports are incorrectly sorted and/or formatted.]
from abc import ABC
import json
import pathlib

import torch
import tqdm
from tokenizers import Tokenizer
from abc import abstractmethod
from typing import List, Union
import numpy as np


class HFAutoTokenizer:
@@ -67,3 +71,126 @@
@property
def vocab_size(self):
return 32000

class AbstractTokenizer(ABC):
"""Abstract class for tokenizer."""

def __init__(self, name):
self.name = name
super().__init__()

@property
@abstractmethod
def vocab_size(self):
pass

@property
@abstractmethod
def vocab(self):
"""Dictionary from vocab text token to id token."""
pass

@property
@abstractmethod
def inv_vocab(self):
"""Dictionary from vocab id token to text token."""
pass

@abstractmethod
def tokenize(self, text):
pass

def detokenize(self, token_ids):
raise NotImplementedError(
"detokenizer is not implemented for {} " "tokenizer".format(self.name)
)

@property
def cls(self):
raise NotImplementedError(
"CLS is not provided for {} " "tokenizer".format(self.name)
)

@property
def sep(self):
raise NotImplementedError(
"SEP is not provided for {} " "tokenizer".format(self.name)
)

@property
def pad(self):
raise NotImplementedError(
"PAD is not provided for {} " "tokenizer".format(self.name)
)

@property
def eod(self):
raise NotImplementedError(
"EOD is not provided for {} " "tokenizer".format(self.name)
)

@property
def mask(self):
raise NotImplementedError(
"MASK is not provided for {} " "tokenizer".format(self.name)
)


class CharLevelTokenizer(AbstractTokenizer):
"""Character Level Tokenizer"""

def __init__(self, vocab_size):
name = "CharLevelTokenizer"
super().__init__(name)
self._vocab_size = vocab_size
self.eod_id = 0
self.eos_id = 0
self.pad_id = 1

def clamp(self, n):
return max(32, min(n, self.vocab_size))

@property
def vocab_size(self):
return self._vocab_size

@property
def vocab(self):
raise NotImplementedError

@property
def inv_vocab(self):
raise NotImplementedError

def decode_token(self, token: int):
return str(chr(self.clamp(token)))

def tokenize(self, text: str):
return list(np.fromstring(text, dtype=np.uint8))

def tokenize_batch(self, text_batch: Union[List[str], str]):
if isinstance(text_batch, list):
return [self.tokenize(s) for s in text_batch]
else:
return self.tokenize(text_batch)

def detokenize(self, token_ids):
return "".join(list(map(self.decode_token, token_ids)))

def detokenize_batch(self, token_ids: Union[List[str], str]):
if isinstance(token_ids, list):
return [self.detokenize(s) for s in token_ids]
# else, if a tensor, convert to a list first
elif isinstance(token_ids, torch.Tensor):
return [self.detokenize(s) for s in token_ids.tolist()]
else:
return self.detokenize(token_ids)

@property
def eod(self):
return self.eod_id

# duplicated to support both names, eos and eod
@property
def eos(self):
return self.eod_id
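A small round-trip example of the new tokenizer (illustrative; note that np.fromstring is deprecated in recent NumPy, where np.frombuffer(text.encode(), dtype=np.uint8) is the usual replacement):

```python
import torch

from src.tokenizer import CharLevelTokenizer

tok = CharLevelTokenizer(vocab_size=512)

ids = tok.tokenize("hello")   # byte values: [104, 101, 108, 108, 111]
assert tok.detokenize(ids) == "hello"

# detokenize_batch also accepts a 2-D tensor of token ids.
batch = torch.tensor([ids, ids])
assert tok.detokenize_batch(batch) == ["hello", "hello"]
```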