Update model_chatglm.py #2766

Open

wants to merge 6 commits into base: main

Changes from 3 commits
37 changes: 37 additions & 0 deletions fastchat/model/model_chatglm.py
@@ -37,6 +37,35 @@ def process_response(response):
    return response


def apply_stopping_string(reply, stop_strings):
    if isinstance(stop_strings, str):
        stop_strings = [stop_strings]

    stop_found = False

    for string in stop_strings[:4]:
        if isinstance(string, str):
            idx = reply.find(string)
            if idx != -1:
                reply = reply[:idx]
                stop_found = True

    if not stop_found:
        # If something like "\nYo" is generated just before "\nYou:" is completed, trim it
        for string in stop_strings[:4]:
            if isinstance(string, str):
                for j in range(len(string) - 1, 0, -1):
                    if reply[-j:] == string[:j]:
                        reply = reply[:-j]
                        break
                else:
                    continue

                break

    return stop_found, reply


@torch.inference_mode()
def generate_stream_chatglm(
    model,
@@ -53,6 +82,7 @@ def generate_stream_chatglm(
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 256))
    echo = params.get("echo", True)
    stop = params.get("stop", [])

    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    input_echo_len = len(inputs["input_ids"][0])
@@ -78,6 +108,10 @@
        response = tokenizer.decode(output_ids)
        response = process_response(response)

        stop_found, response = (
            apply_stopping_string(response, stop) if response else (False, response)
        )

        yield {
            "text": response,
            "usage": {
@@ -88,6 +122,9 @@
            "finish_reason": None,
        }

        if stop_found:
            break

    # TODO: ChatGLM stop when it reach max length
    # Only last stream result contains finish_reason, we set finish_reason as stop
    ret = {
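For reference, a minimal sketch of how the added apply_stopping_string helper behaves once this change is applied. The reply text and stop strings below are illustrative examples, not taken from the PR:

# Hypothetical inputs, for illustration only.
reply = "Sure, I can help.\nYo"
stop_found, trimmed = apply_stopping_string(reply, ["\nYou:"])
# No full stop string is present, but the trailing partial match "\nYo"
# is trimmed so it is not streamed to the client:
# stop_found == False, trimmed == "Sure, I can help."

stop_found, trimmed = apply_stopping_string("Answer.\nYou: next question", "\nYou:")
# A single stop string may also be passed as a plain str; the reply is cut
# at the first occurrence and generation stops:
# stop_found == True, trimmed == "Answer."

In generate_stream_chatglm, the stop value is read from params.get("stop", []), so a client can supply either one string or a list (only the first four entries are checked); once stop_found is True, the streaming loop breaks after yielding the trimmed text.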