main.py
import chainlit as cl
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from openai import OpenAI

# OpenAI client for the English-language chat model (reads OPENAI_API_KEY from the environment).
client = OpenAI()

# NLLB tokenizer/model for English -> Akan translation (NLLB's default src_lang is eng_Latn).
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

# Second NLLB tokenizer/model configured for Akan -> English translation.
tokenizer_2 = AutoTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M", src_lang="aka_Latn"
)
model_2 = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")


def generate_response(prompt):
    """Send the English prompt to GPT-4 and return its reply."""
    completion = client.chat.completions.create(
        model="gpt-4",
        messages=[{"role": "user", "content": prompt}],
    )
    return completion.choices[0].message.content


def eng_to_target(article):
    """Translate English text to Akan (aka_Latn)."""
    inputs = tokenizer(article, return_tensors="pt")
    translated_tokens = model.generate(
        **inputs, forced_bos_token_id=tokenizer.lang_code_to_id["aka_Latn"], max_length=3000
    )
    return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]


def target_to_eng(article):
    """Translate Akan (aka_Latn) text to English."""
    inputs = tokenizer_2(article, return_tensors="pt")
    translated_tokens = model_2.generate(
        **inputs, forced_bos_token_id=tokenizer_2.lang_code_to_id["eng_Latn"], max_length=3000
    )
    return tokenizer_2.batch_decode(translated_tokens, skip_special_tokens=True)[0]


@cl.on_message
async def main(message: cl.Message):
    # Translate the incoming Akan message to English, ask GPT-4 for a reply,
    # then translate that reply back to Akan before sending it to the user.
    message_in = target_to_eng(message.content)
    print(message_in)
    chat_response = generate_response(message_in)
    print(chat_response)
    message_out = eng_to_target(chat_response)
    print(message_out)
    await cl.Message(content=message_out).send()
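
The app is served with `chainlit run main.py`. To sanity-check the translation and chat helpers outside the Chainlit UI, a minimal sketch could be appended to the file; it assumes the models above have finished loading and that OPENAI_API_KEY is set, and the sample sentence is only illustrative.

# Quick smoke test of the round-trip pipeline without the Chainlit UI.
# Run with `python main.py`; note this downloads the NLLB weights and calls the OpenAI API.
if __name__ == "__main__":
    sample_akan = eng_to_target("Hello, how are you today?")  # English -> Akan
    print("Akan prompt:", sample_akan)
    english_back = target_to_eng(sample_akan)  # Akan -> English
    print("Round-tripped English:", english_back)
    print("GPT-4 reply:", generate_response(english_back))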