forked from x4nth055/pythoncode-tutorials
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdialogpt.py
169 lines (163 loc) · 6.69 KB
/
dialogpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# -*- coding: utf-8 -*-
"""DialoGPT.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1KAg6X8RFHE0KSvFSZ__w7KGZrSqT4cZ3
"""
# !pip install transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Checkpoint selection: exactly one `model_name` assignment is active; the
# commented-out lines are the alternative DialoGPT checkpoint names.
# model_name = "microsoft/DialoGPT-large"
model_name = "microsoft/DialoGPT-medium"
# model_name = "microsoft/DialoGPT-small"
# Load the tokenizer and the causal language model for the chosen checkpoint.
# Both objects are reused by every chat section below.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
print("====Greedy search chat====")
# Five-turn conversation; generate() is called with no sampling or beam
# arguments, which this section labels as greedy search.
for step in range(5):
    # Read the user's message for this turn.
    text = input(">> You:")
    # Tokenize the message, terminated by the end-of-string token.
    input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")
    # The first turn has no history yet; later turns append the new message
    # to everything generated so far.
    if step == 0:
        bot_input_ids = input_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1)
    prompt_length = bot_input_ids.shape[-1]
    # The generated sequence becomes the history for the next turn.
    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Only the tokens past the prompt are decoded and shown as the reply.
    reply_ids = chat_history_ids[0, prompt_length:]
    output = tokenizer.decode(reply_ids, skip_special_tokens=True)
    print(f"DialoGPT: {output}")
print("====Beam search chat====")
# Five-turn conversation decoded with beam search (3 beams, early stopping).
for step in range(5):
    # Read the user's message for this turn.
    text = input(">> You:")
    # Tokenize the message, terminated by the end-of-string token.
    input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")
    # The first turn has no history yet; later turns append the new message
    # to everything generated so far.
    if step == 0:
        bot_input_ids = input_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1)
    prompt_length = bot_input_ids.shape[-1]
    # The generated sequence becomes the history for the next turn.
    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=1000,
        num_beams=3,
        early_stopping=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Only the tokens past the prompt are decoded and shown as the reply.
    reply_ids = chat_history_ids[0, prompt_length:]
    output = tokenizer.decode(reply_ids, skip_special_tokens=True)
    print(f"DialoGPT: {output}")
print("====Sampling chat====")
# Five-turn conversation where replies are sampled (do_sample=True) with
# top_k set to 0.
for step in range(5):
    # Read the user's message for this turn.
    text = input(">> You:")
    # Tokenize the message, terminated by the end-of-string token.
    input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")
    # The first turn has no history yet; later turns append the new message
    # to everything generated so far.
    if step == 0:
        bot_input_ids = input_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1)
    prompt_length = bot_input_ids.shape[-1]
    # The generated sequence becomes the history for the next turn.
    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=1000,
        do_sample=True,
        top_k=0,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Only the tokens past the prompt are decoded and shown as the reply.
    reply_ids = chat_history_ids[0, prompt_length:]
    output = tokenizer.decode(reply_ids, skip_special_tokens=True)
    print(f"DialoGPT: {output}")
print("====Sampling chat with tweaking temperature====")
# Five-turn conversation where replies are sampled with top_k=0 and a
# temperature of 0.75.
for step in range(5):
    # Read the user's message for this turn.
    text = input(">> You:")
    # Tokenize the message, terminated by the end-of-string token.
    input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")
    # The first turn has no history yet; later turns append the new message
    # to everything generated so far.
    if step == 0:
        bot_input_ids = input_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1)
    prompt_length = bot_input_ids.shape[-1]
    # The generated sequence becomes the history for the next turn.
    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=1000,
        do_sample=True,
        top_k=0,
        temperature=0.75,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Only the tokens past the prompt are decoded and shown as the reply.
    reply_ids = chat_history_ids[0, prompt_length:]
    output = tokenizer.decode(reply_ids, skip_special_tokens=True)
    print(f"DialoGPT: {output}")
print("====Top-K sampling chat with tweaking temperature====")
# Five-turn conversation where replies are sampled with top_k=100 and a
# temperature of 0.75.
for step in range(5):
    # Read the user's message for this turn.
    text = input(">> You:")
    # Tokenize the message, terminated by the end-of-string token.
    input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")
    # The first turn has no history yet; later turns append the new message
    # to everything generated so far.
    if step == 0:
        bot_input_ids = input_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1)
    prompt_length = bot_input_ids.shape[-1]
    # The generated sequence becomes the history for the next turn.
    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=1000,
        do_sample=True,
        top_k=100,
        temperature=0.75,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Only the tokens past the prompt are decoded and shown as the reply.
    reply_ids = chat_history_ids[0, prompt_length:]
    output = tokenizer.decode(reply_ids, skip_special_tokens=True)
    print(f"DialoGPT: {output}")
print("====Nucleus sampling (top-p) chat with tweaking temperature====")
# Five-turn conversation where replies are sampled with top_p=0.95, top_k=0
# and a temperature of 0.75.
for step in range(5):
    # Read the user's message for this turn.
    text = input(">> You:")
    # Tokenize the message, terminated by the end-of-string token.
    input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")
    # The first turn has no history yet; later turns append the new message
    # to everything generated so far.
    if step == 0:
        bot_input_ids = input_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1)
    prompt_length = bot_input_ids.shape[-1]
    # The generated sequence becomes the history for the next turn.
    chat_history_ids = model.generate(
        bot_input_ids,
        max_length=1000,
        do_sample=True,
        top_p=0.95,
        top_k=0,
        temperature=0.75,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Only the tokens past the prompt are decoded and shown as the reply.
    reply_ids = chat_history_ids[0, prompt_length:]
    output = tokenizer.decode(reply_ids, skip_special_tokens=True)
    print(f"DialoGPT: {output}")
print("====chatting 5 times with nucleus & top-k sampling & tweaking temperature & multiple sentences====")
# Five-turn conversation where each turn samples several candidate replies
# (top_p=0.95, top_k=50, temperature=0.75, num_return_sequences=5). The user
# picks which candidate is kept as the chat history for the next turn.
for step in range(5):
    # Read the user's message for this turn.
    text = input(">> You:")
    # Tokenize the message, terminated by the end-of-string token.
    input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")
    # The first turn has no history yet; later turns append the new message
    # to the previously chosen candidate.
    bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1) if step > 0 else input_ids
    # Generate the candidate continuations for this prompt.
    chat_history_ids_list = model.generate(
        bot_input_ids,
        max_length=1000,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.75,
        num_return_sequences=5,
        pad_token_id=tokenizer.eos_token_id
    )
    # Show every candidate, labelled with the index used to select it. Only
    # the tokens past the prompt are decoded.
    prompt_length = bot_input_ids.shape[-1]
    for i, candidate_ids in enumerate(chat_history_ids_list):
        output = tokenizer.decode(candidate_ids[prompt_length:], skip_special_tokens=True)
        print(f"DialoGPT {i}: {output}")
    # Keep asking until the answer is one of the printed indices. A typo or
    # an out-of-range number would otherwise raise ValueError/IndexError and
    # end the conversation, and a negative number would silently pick a
    # candidate that matches no printed label.
    num_candidates = len(chat_history_ids_list)
    while True:
        raw_choice = input("Choose the response you want for the next input: ")
        try:
            choice_index = int(raw_choice)
        except ValueError:
            print(f"Please enter a whole number between 0 and {num_candidates - 1}.")
            continue
        if 0 <= choice_index < num_candidates:
            break
        print(f"Please enter a whole number between 0 and {num_candidates - 1}.")
    # Add a leading dimension back to the chosen candidate so it can be
    # concatenated with the next turn's input_ids.
    chat_history_ids = torch.unsqueeze(chat_history_ids_list[choice_index], dim=0)