From b2c729eb15b87559981bb08f2928f5c9839ef699 Mon Sep 17 00:00:00 2001
From: Scott Lundberg
Date: Thu, 23 Nov 2023 02:26:11 +0000
Subject: [PATCH] Fix tokenization for llamacpp, address torch to numpy bug

---
 guidance/_grammar.py          |  5 +++-
 guidance/_parser.py           | 14 ++++++++----
 guidance/models/_llama_cpp.py |  4 ++++
 guidance/models/_local.py     | 43 +++++++++++++++++++++++++++++++++--
 4 files changed, 59 insertions(+), 7 deletions(-)

diff --git a/guidance/_grammar.py b/guidance/_grammar.py
index a6538faf3..03dc90734 100644
--- a/guidance/_grammar.py
+++ b/guidance/_grammar.py
@@ -532,8 +532,11 @@ def token_limit(value, max_tokens):
 
 def _rec_token_limit(grammar, max_tokens):
     if grammar.max_tokens > max_tokens and not isinstance(grammar, Terminal):
-        if getattr(grammar, "recursive", False): # only restrict recursive selects, otherwise we would block all way to complete the grammar
+        if getattr(grammar, "recursive", False): # only restrict recursive selects, otherwise we would block all ways to complete the grammar
             grammar.max_tokens = max_tokens
+            for value in getattr(grammar, "values", []): # also restrict the recursive nodes of recursive selects
+                if not isinstance(value, Terminal):
+                    value.max_tokens = max_tokens
     if hasattr(grammar, "values"):
         for g in grammar.values:
             _rec_token_limit(g, max_tokens)
diff --git a/guidance/_parser.py b/guidance/_parser.py
index 7e91ade9c..05da69cb3 100644
--- a/guidance/_parser.py
+++ b/guidance/_parser.py
@@ -306,7 +306,10 @@ def valid_next_bytes(self):
         valid_items = set()
         next_state_set = self.state_sets[self.state_set_pos + 1]
         for item in next_state_set:
-            if item.pos > 0 and isinstance(item.values[item.pos - 1], Terminal):
+            token_span = self.token_counts[-1] - self.token_counts[item.start]
+            if item.node.max_tokens <= token_span:
+                continue
+            elif item.pos > 0 and isinstance(item.values[item.pos - 1], Terminal):
                 v = item.values[item.pos - 1]
                 if v not in valid_items:
                     valid_items.add(v)
@@ -367,9 +370,12 @@ def __repr__(self, state_sets=None) -> str:
                 rs = ""
                 if state.pos == 0:
                     rs += "•"
-                rs += state.values[0].name
-                if state.pos == 1:
-                    rs += "•"
+                if len(state.values) == 0:
+                    rs += "NO_VALUES!"
+                else:
+                    rs += state.values[0].name
+                    if state.pos == 1:
+                        rs += "•"
             else:
                 assert False
             s += f"{rs:40} ({state.start}) {'nullable' if state.node.nullable else ''}\n"
diff --git a/guidance/models/_llama_cpp.py b/guidance/models/_llama_cpp.py
index b002b489e..1059aee4c 100644
--- a/guidance/models/_llama_cpp.py
+++ b/guidance/models/_llama_cpp.py
@@ -69,6 +69,10 @@ def __init__(self, model=None, tokenizer=None, echo=True, caching=True, temperat
 
         self._cache_state["cache_token_ids"] = []
 
+    def _joint_tokenize(self, token_ids):
+        byte_string = b"".join([self.tokens[t] for t in token_ids])
+        return self.model_obj.tokenize(byte_string, add_bos=False, special=True)
+
     def _get_logits(self, token_ids, forced_bytes):
         '''Computes the logits for the given token state.
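The `_joint_tokenize` override above exists because forcing bytes token-by-token can tokenize differently than tokenizing the full concatenated byte string at once, and models only ever see the joint tokenization during training. A toy sketch of that mismatch (the vocabulary and `greedy_tokenize` below are hypothetical stand-ins, not part of guidance):

```python
# Toy illustration of why a joint re-tokenization pass is needed: forcing text
# piece by piece can tokenize differently than tokenizing the concatenated
# bytes once. VOCAB and greedy_tokenize are illustrative stand-ins for a real
# BPE tokenizer.
VOCAB = {b"fortune": 2, b"for": 0, b"tune": 1}

def greedy_tokenize(byte_string):
    # longest-prefix-match tokenization, standing in for a real tokenizer
    ids, pos = [], 0
    while pos < len(byte_string):
        for end in range(len(byte_string), pos, -1):  # try the longest piece first
            piece = byte_string[pos:end]
            if piece in VOCAB:
                ids.append(VOCAB[piece])
                pos = end
                break
        else:
            raise ValueError("byte not in vocabulary")
    return ids

# forcing b"for" and then b"tune" in two steps yields two tokens...
assert greedy_tokenize(b"for") + greedy_tokenize(b"tune") == [0, 1]
# ...but the joint tokenization of the same bytes merges them into one
assert greedy_tokenize(b"fortune") == [2]
```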
diff --git a/guidance/models/_local.py b/guidance/models/_local.py
index dea7a8a0d..16a0f8b52 100644
--- a/guidance/models/_local.py
+++ b/guidance/models/_local.py
@@ -30,6 +30,10 @@ def _get_logits(self, token_ids, forced_bytes):
         # pretend to extend the KV cache and update the log probs
         return np.randn(len(self.tokens))
 
+
+    def _joint_tokenize(self, token_ids):
+        # an abstract method; should return what a full joint tokenizer would give for the bytes of these token ids
+        return token_ids
 
     def _tokenize_prefix(self, byte_string):
         '''This is used to speed up the tokenization of long prompts without using the parser.'''
@@ -70,6 +74,34 @@ def _tokenize_prefix(self, byte_string):
             pos = valid_pos
 
         return token_ids,token_byte_positions
+
+    def _cleanup_tokens(self, token_ids, token_byte_positions):
+
+        # compute a joint tokenization
+        joint_token_ids = self._joint_tokenize(token_ids)
+
+        # see if we need to redo the tokenization
+        redo = False
+        if len(joint_token_ids) != len(token_ids):
+            redo = True
+        else:
+            for i,id in enumerate(joint_token_ids):
+                if token_ids[i] != id:
+                    redo = True
+                    break
+
+        if redo:
+            token_ids = joint_token_ids
+            last_pos = token_byte_positions[-1]
+            token_byte_positions = []
+            pos = 0
+            for i,id in enumerate(joint_token_ids):
+                pos += len(self.tokens[id])
+                token_byte_positions.append(pos)
+            assert token_byte_positions[-1] == last_pos
+
+        return token_ids, token_byte_positions
+
 
     def __call__(self, grammar, max_tokens=100, n=1, top_p=1, temperature=0.0, ensure_bos_token=True, log_probs=False):
         assert n == 1, "Still need to add support for n > 1!"
@@ -84,6 +116,7 @@ def __call__(self, grammar, max_tokens=100, n=1, top_p=1, temperature=0.0, ensur
 
         # run a simple tokenizer (that does not use a grammar) on the prefix for better performance
         token_ids,token_byte_positions = self._tokenize_prefix(prompt)
+        token_ids,token_byte_positions = self._cleanup_tokens(token_ids,token_byte_positions)
         if len(token_byte_positions) > 0:
             pre_parser_bytes = token_byte_positions[-1]
             prompt = prompt[token_byte_positions[-1]:]
@@ -99,6 +132,7 @@ def __call__(self, grammar, max_tokens=100, n=1, top_p=1, temperature=0.0, ensur
         sampled_token_ind = None
         token_count = 0
         last_token_count = 0
+        was_forced = False
         while True: # each iteration generates one more token (and some of the associated bytes)
 
             # enforce the token limit
@@ -181,7 +215,6 @@ def __call__(self, grammar, max_tokens=100, n=1, top_p=1, temperature=0.0, ensur
 
                         break
                     trie = trie.children[next_byte]
 
-
             forced_pos = parser.pos # record how far the bytes are forced
@@ -202,6 +235,7 @@ def __call__(self, grammar, max_tokens=100, n=1, top_p=1, temperature=0.0, ensur
                 sampled_token_ind = trie.value
                 sampled_token = self.tokens[sampled_token_ind]
                 new_bytes_log_prob = 0.0
+                was_forced = True
 
             # we are at the end of the grammar
             elif next_byte_mask_sum == 0:
@@ -215,6 +249,11 @@ def __call__(self, grammar, max_tokens=100, n=1, top_p=1, temperature=0.0, ensur
 
             # otherwise we need to compute the logits and sample a valid token
             else:
+
+                # if we were forced, we might need to clean up the greedy tokenization to match the global tokenization behavior seen in training
+                if was_forced:
+                    token_ids,token_byte_positions = self._cleanup_tokens(token_ids, token_byte_positions)
+                    was_forced = False
                 logits = self._get_logits(token_ids, parser.bytes[start_pos:forced_pos])
 
                 # if requested we compute the log probabilities so we can track the probabilities of each node
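`_cleanup_tokens` re-tokenizes the whole sequence jointly and, when the result differs from the greedy sequence, adopts it and rebuilds the cumulative byte positions. A standalone sketch of the same logic (the `tokens` map and stub joint tokenizer here are hypothetical, for illustration only):

```python
# Standalone sketch of the _cleanup_tokens logic: if the joint tokenization
# disagrees with the greedy one, adopt it and rebuild the byte positions.
def cleanup_tokens(token_ids, token_byte_positions, tokens, joint_tokenize):
    joint_token_ids = joint_tokenize(token_ids)
    if joint_token_ids != token_ids:  # the tokenizations disagree, so redo
        last_pos = token_byte_positions[-1]
        token_byte_positions = []
        pos = 0
        for tid in joint_token_ids:  # rebuild cumulative end-of-token byte offsets
            pos += len(tokens[tid])
            token_byte_positions.append(pos)
        # the joint tokenization must still cover exactly the same bytes
        assert token_byte_positions[-1] == last_pos
        token_ids = joint_token_ids
    return token_ids, token_byte_positions

# usage with a stub joint tokenizer that merges [for, tune] -> [fortune]
tokens = {0: b"for", 1: b"tune", 2: b"fortune"}
ids, positions = cleanup_tokens(
    [0, 1], [3, 7], tokens,
    joint_tokenize=lambda ids: [2] if ids == [0, 1] else ids,
)
assert ids == [2] and positions == [7]
```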
@@ -230,7 +269,7 @@ def __call__(self, grammar, max_tokens=100, n=1, top_p=1, temperature=0.0, ensur
                     else:
                         assert top_p == 1, "Still need to add support for top_p!"
                         probs = scipy.special.softmax(logits / current_temp, axis=-1)
-                        sampling_order = np.random.choice(len(probs), size=len(probs), p=probs)
+                        sampling_order = np.random.choice(len(probs), size=len(probs), p=(probs+1e-10)/np.sum(probs+1e-10), replace=False) # the 1e-10 is to ensure we have no zero probs, and renormalizing keeps the sum at 1 as numpy requires
 
                 # loop over the tokens looking for a valid one
                 for i,sampled_token_ind in enumerate(sampling_order):
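The sampling change draws a full weighted permutation of the vocabulary (`replace=False`) so the loop that follows can try tokens in sampled-probability order until one is valid under the grammar. Note that `np.random.choice` rejects a `p` that does not sum to 1 within a tight tolerance, so the epsilon guarding against exact-zero probabilities has to be folded back in by renormalizing. A minimal sketch with illustrative logits:

```python
import numpy as np
import scipy.special

# Sketch of the sampling-order change: draw a full permutation of the
# vocabulary, weighted by the softmax probabilities. The logits are made up.
logits = np.array([2.0, 1.0, -1.0, 0.5])
probs = scipy.special.softmax(logits, axis=-1)

# guard against exact-zero probs, then renormalize so p sums to exactly 1,
# which np.random.choice requires
p = probs + 1e-10
p /= p.sum()
sampling_order = np.random.choice(len(p), size=len(p), p=p, replace=False)
assert sorted(sampling_order.tolist()) == [0, 1, 2, 3]  # a permutation of every token id
```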