Fix tokenization for llamacpp, address torch to numpy bug
slundberg committed Nov 23, 2023
1 parent 72e9658 commit b2c729e
Showing 4 changed files with 59 additions and 7 deletions.
guidance/_grammar.py (5 changes: 4 additions & 1 deletion)
@@ -532,8 +532,11 @@ def token_limit(value, max_tokens):

def _rec_token_limit(grammar, max_tokens):
if grammar.max_tokens > max_tokens and not isinstance(grammar, Terminal):
- if getattr(grammar, "recursive", False): # only restrict recursive selects, otherwise we would block all way to complete the grammar
+ if getattr(grammar, "recursive", False): # only restrict recursive selects, otherwise we would block all ways to complete the grammar
grammar.max_tokens = max_tokens
+ for value in getattr(grammar, "values", []): # restrict recursive selects recursive nodes
+ if not isinstance(value, Terminal):
+ value.max_tokens = max_tokens
if hasattr(grammar, "values"):
for g in grammar.values:
_rec_token_limit(g, max_tokens)
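
The change above propagates a token limit into a recursive node's non-terminal children, not just the node itself, so a recursive select cannot keep expanding past its budget. A minimal standalone sketch of that idea, using toy stand-in classes rather than the library's real grammar nodes:

class Node:
    def __init__(self, name, values=(), recursive=False, terminal=False, max_tokens=10**9):
        self.name = name
        self.values = list(values)
        self.recursive = recursive
        self.terminal = terminal      # stands in for isinstance(node, Terminal)
        self.max_tokens = max_tokens

def rec_token_limit(node, max_tokens):
    # cap only recursive nodes (and their non-terminal children); capping
    # everything would block every path that could complete the grammar
    if node.max_tokens > max_tokens and not node.terminal:
        if node.recursive:
            node.max_tokens = max_tokens
            for v in node.values:
                if not v.terminal:
                    v.max_tokens = max_tokens
    for v in node.values:
        rec_token_limit(v, max_tokens)

word = Node("word", terminal=True)
body = Node("body", values=[word])
loop = Node("loop", values=[body, word], recursive=True)  # e.g. a recursive select
rec_token_limit(loop, 16)
print(loop.max_tokens, body.max_tokens, word.max_tokens)  # 16 16 1000000000
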
guidance/_parser.py (14 changes: 10 additions & 4 deletions)
@@ -306,7 +306,10 @@ def valid_next_bytes(self):
valid_items = set()
next_state_set = self.state_sets[self.state_set_pos + 1]
for item in next_state_set:
- if item.pos > 0 and isinstance(item.values[item.pos - 1], Terminal):
+ token_span = self.token_counts[-1] - self.token_counts[item.start]
+ if item.node.max_tokens <= token_span:
+ continue
+ elif item.pos > 0 and isinstance(item.values[item.pos - 1], Terminal):
v = item.values[item.pos - 1]
if v not in valid_items:
valid_items.add(v)
@@ -367,9 +370,12 @@ def __repr__(self, state_sets=None) -> str:
rs = ""
if state.pos == 0:
rs += "•"
- rs += state.values[0].name
- if state.pos == 1:
- rs += "•"
+ if len(state.values) == 0:
+ rs += "NO_VALUES!"
+ else:
+ rs += state.values[0].name
+ if state.pos == 1:
+ rs += "•"
else:
assert False
s += f"{rs:40} ({state.start}) {'nullable' if state.node.nullable else ''}\n"
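
The new check in valid_next_bytes skips any Earley item whose grammar node has already used up its token budget: the item's span in tokens is the difference between the current token count and the count recorded when the item started. A small sketch of just that bookkeeping, with hypothetical Item and token_counts structures rather than the real parser state:

from dataclasses import dataclass

@dataclass
class Item:
    start: int       # index of the state set where this item began
    max_tokens: int  # token budget attached to the item's grammar node

def still_in_budget(item, token_counts):
    # token_counts[i] = number of tokens consumed when state set i was created,
    # so this difference is how many tokens the item has spanned so far
    token_span = token_counts[-1] - token_counts[item.start]
    return item.max_tokens > token_span

token_counts = [0, 1, 2, 4]  # four state sets; four tokens consumed so far
print(still_in_budget(Item(start=0, max_tokens=3), token_counts))  # False: already spans 4 tokens
print(still_in_budget(Item(start=2, max_tokens=3), token_counts))  # True: spans only 2 tokens
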
guidance/models/_llama_cpp.py (4 changes: 4 additions & 0 deletions)
@@ -69,6 +69,10 @@ def __init__(self, model=None, tokenizer=None, echo=True, caching=True, temperat

self._cache_state["cache_token_ids"] = []

+ def _joint_tokenize(self, token_ids):
+ byte_string = b"".join([self.tokens[t] for t in token_ids])
+ return self.model_obj.tokenize(byte_string, add_bos=False, special=True)

def _get_logits(self, token_ids, forced_bytes):
'''Computes the logits for the given token state.
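
_joint_tokenize hands the concatenated bytes of the tokens seen so far back to llama.cpp's own tokenizer, because tokenizing the whole string at once can differ from the piecewise tokenization built up during generation. A rough usage sketch, assuming the llama-cpp-python bindings and a locally downloaded GGUF file (the path below is made up):

from llama_cpp import Llama

llm = Llama(model_path="models/example-7b.Q4_K_M.gguf", verbose=False)  # hypothetical path

# token ids built up piece by piece during constrained generation ...
piecewise = llm.tokenize(b"Hello", add_bos=False, special=True) \
          + llm.tokenize(b" world", add_bos=False, special=True)
# ... versus tokenizing the joined bytes in one call, which is closer to what
# the model saw at training time and is what _joint_tokenize returns
joint = llm.tokenize(b"Hello world", add_bos=False, special=True)
print(piecewise, joint)  # the two lists may disagree; the joint form is the canonical one
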
guidance/models/_local.py (43 changes: 41 additions & 2 deletions)
@@ -30,6 +30,10 @@ def _get_logits(self, token_ids, forced_bytes):

# pretend to extend the KV cache and update the log probs
return np.randn(len(self.tokens))

+ def _joint_tokenize(self, token_ids):
+ # an abstract method. Should return what a full joint tokenizer would give for a given byte string
+ return token_ids

def _tokenize_prefix(self, byte_string):
'''This is used to speed up the tokenization of long prompts without using the parser.'''
@@ -70,6 +74,34 @@ def _tokenize_prefix(self, byte_string):
pos = valid_pos

return token_ids,token_byte_positions

+ def _cleanup_tokens(self, token_ids, token_byte_positions):

+ # compute a joint tokenization
+ joint_token_ids = self._joint_tokenize(token_ids)

+ # see if we need to redo the tokenization
+ redo = False
+ if len(joint_token_ids) != len(token_ids):
+ redo = True
+ else:
+ for i,id in enumerate(joint_token_ids):
+ if token_ids[i] != id:
+ redo = True
+ break

+ if redo:
+ token_ids = joint_token_ids
+ last_pos = token_byte_positions[-1]
+ token_byte_positions = []
+ pos = 0
+ for i,id in enumerate(joint_token_ids):
+ pos += len(self.tokens[id])
+ token_byte_positions.append(pos)
+ assert token_byte_positions[-1] == last_pos

+ return token_ids, token_byte_positions


def __call__(self, grammar, max_tokens=100, n=1, top_p=1, temperature=0.0, ensure_bos_token=True, log_probs=False):
assert n == 1, "Still need to add support for n > 1!"
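
_cleanup_tokens compares the greedily built token ids against a joint re-tokenization and, when they disagree, swaps in the joint ids and rebuilds the cumulative byte positions from the token byte lengths (the final offset must still land on the same byte). A toy illustration of that bookkeeping with a made-up token table, not a real tokenizer:

# made-up token table: token id -> bytes
tokens = {0: b"Hel", 1: b"lo", 2: b"Hello", 3: b" world"}

def rebuild_positions(token_ids):
    # cumulative end-of-token byte offsets, recomputed the same way _cleanup_tokens does
    positions, pos = [], 0
    for tid in token_ids:
        pos += len(tokens[tid])
        positions.append(pos)
    return positions

greedy_ids = [0, 1, 3]   # what a prefix/greedy tokenization might have produced
joint_ids = [2, 3]       # what tokenizing b"Hello world" in one call might give
positions = rebuild_positions(greedy_ids)       # [3, 5, 11]

if joint_ids != greedy_ids:                     # the "redo" condition, written with !=
    last_pos = positions[-1]
    positions = rebuild_positions(joint_ids)    # [5, 11]
    assert positions[-1] == last_pos            # both tokenizations cover the same bytes

print(joint_ids, positions)  # [2, 3] [5, 11]
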
@@ -84,6 +116,7 @@ def __call__(self, grammar, max_tokens=100, n=1, top_p=1, temperature=0.0, ensur

# run a simple tokenizer (that does not use a grammar) on the prefix for better performance
token_ids,token_byte_positions = self._tokenize_prefix(prompt)
+ token_ids,token_byte_positions = self._cleanup_tokens(token_ids,token_byte_positions)
if len(token_byte_positions) > 0:
pre_parser_bytes = token_byte_positions[-1]
prompt = prompt[token_byte_positions[-1]:]
@@ -99,6 +132,7 @@ def __call__(self, grammar, max_tokens=100, n=1, top_p=1, temperature=0.0, ensur
sampled_token_ind = None
token_count = 0
last_token_count = 0
+ was_forced = False
while True: # each iteration generates one more token (and some of the associated bytes)

# enforce the token limit
@@ -181,7 +215,6 @@ def __call__(self, grammar, max_tokens=100, n=1, top_p=1, temperature=0.0, ensur
break

trie = trie.children[next_byte]


forced_pos = parser.pos # record how far the bytes are forced

@@ -202,6 +235,7 @@ def __call__(self, grammar, max_tokens=100, n=1, top_p=1, temperature=0.0, ensur
sampled_token_ind = trie.value
sampled_token = self.tokens[sampled_token_ind]
new_bytes_log_prob = 0.0
+ was_forced = True

# we are at the end of the grammar
elif next_byte_mask_sum == 0:
@@ -215,6 +249,11 @@ def __call__(self, grammar, max_tokens=100, n=1, top_p=1, temperature=0.0, ensur

# otherwise we need to compute the logits and sample a valid token
else:

+ # if we were forced we might need to clean up the greedy tokenization to match the global tokenization behavior as seen in training
+ if was_forced:
+ token_ids,token_byte_positions = self._cleanup_tokens(token_ids, token_byte_positions)
+ was_forced = False
logits = self._get_logits(token_ids, parser.bytes[start_pos:forced_pos])

# if requested we compute the log probabilities so we can track the probabilities of each node
@@ -230,7 +269,7 @@ def __call__(self, grammar, max_tokens=100, n=1, top_p=1, temperature=0.0, ensur
else:
assert top_p == 1, "Still need to add support for top_p!"
probs = scipy.special.softmax(logits / current_temp, axis=-1)
- sampling_order = np.random.choice(len(probs), size=len(probs), p=probs)
+ sampling_order = np.random.choice(len(probs), size=len(probs), p=probs+1e-10, replace=False) # the 1e-10 is ensure we have no zero probs, which numpy does not like

# loop over the tokens looking for a valid one
for i,sampled_token_ind in enumerate(sampling_order):
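
The last change rewrites the sampling order as a draw without replacement, weighted by the softmax probabilities. numpy refuses to draw a full no-replacement permutation when some probabilities are exactly zero (there are fewer non-zero entries than requested samples), which is what the added 1e-10 offset works around, per the comment in the diff. A small sketch of that behavior, independent of the model code:

import numpy as np

probs = np.array([0.5, 0.5, 0.0, 0.0])

# a full ordering without replacement fails when some entries of p are exactly zero
try:
    np.random.choice(len(probs), size=len(probs), p=probs, replace=False)
except ValueError as err:
    print("raises:", err)

# nudging every entry above zero keeps all indices eligible, so a complete
# permutation can be drawn, with high-probability tokens almost always first
order = np.random.choice(len(probs), size=len(probs), p=probs + 1e-10, replace=False)
print(order)  # some permutation of [0, 1, 2, 3]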
