forked from doomdagadiggiedahdah/SimpleStories
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgenerate_stories.py
286 lines (249 loc) · 14.4 KB
/
generate_stories.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
import random
import textwrap
import itertools
import json
import os
from openai import OpenAI
import hashlib
import time
import textstat
import anthropic
import concurrent.futures
from tqdm import tqdm
from datetime import datetime
from datasets import load_dataset
from perturbations import perturbations
# Upper bound on stories requested per API completion; also sizes the
# Anthropic max_tokens budget in generate_content.
MAX_STORIES_PER_COMPLETION = 40
# Delimiter the model is told to emit between stories; process_completion
# splits completions on this exact string.
END_STRING = "THE END."
def load_personas():
    """Fetch the PersonaHub 'persona' split from the Hugging Face Hub.

    Returns the train split on success; on any failure, prints the error
    and returns an empty list so module import still succeeds.
    """
    try:
        hub = load_dataset("proj-persona/PersonaHub", "persona")
        return hub['train']
    except Exception as err:
        print(f"Error loading personas: {err}")
        return []
class RateLimitException(Exception):
    """Signals that story generation failed and the worker should back off and retry."""
    pass
# PersonaHub records loaded once at import time; may be an empty list if the
# download failed (see load_personas), which downstream samplers should tolerate.
personas = load_personas()
# Prompt-parameter pools, keyed by language code. Only English ("en") exists
# and each dict is unwrapped immediately; the wrapper presumably leaves room
# for other languages later — confirm before removing.
themes = {"en": ["Friendship", "Courage", "Coming of age", "Kindness", "Adventure", "Imagination", "Family", "Perseverance", "Curiosity", "Honesty", "Romance", "Teamwork", "Responsibility", "Strategy", "Magic", "Discovery", "Betrayal", "Deception", "Generosity", "Creativity", "Self-Acceptance", "Helping Others", "Hardship", "Agency", "Power", "Revenge", "Independence", "Problem-Solving", "Resourcefulness", "Long-Term Thinking", "Optimism", "Humor", "Love", "The Five Senses", "Tradition", "Innovation", "Hope", "Dreams", "Belonging", "Travel", "Overcoming", "Trust", "Morality", "Happiness", "Consciousness", "Failure", "Conflict", "Cooperation", "Growth", "Loss", "Celebration", "Transformation", "Scheming", "Challenge", "Planning", "Wonder", "Surprises", "Conscience", "Intelligence", "Logic", "Resilience"]}["en"]
topics = {"en": ["talking animals", "fantasy worlds", "time travel", "space exploration", "mystical creatures", "underwater adventures", "dinosaurs", "pirates", "superheroes", "fairy tales", "outer space", "hidden treasures", "magical lands", "enchanted forests", "secret societies", "robots and technology", "sports", "school life", "holidays", "cultural traditions", "magical objects", "lost civilizations", "subterranean worlds", "bygone eras", "invisibility", "giant creatures", "miniature worlds", "alien encounters", "haunted places", "shape-shifting", "island adventures", "unusual vehicles", "undercover missions", "dream worlds", "virtual worlds", "riddles", "sibling rivalry", "treasure hunts", "snowy adventures", "seasonal changes", "mysterious maps", "royal kingdoms", "living objects", "gardens", "lost cities", "the arts", "the sky"]}["en"]
styles = {"en": ["whimsical", "playful", "epic", "fairy tale-like", "modern", "classic", "lyric", "mythological", "lighthearted", "adventurous", "heartwarming", "humorous", "mystical", "action-packed", "fable-like", "surreal", "philosophical", "melancholic", "noir", "romantic", "tragic", "minimalist", "suspenseful"]}["en"]
# NOTE(review): "Checkhov's gun" is misspelled (should be "Chekhov's") and is
# sent to the model verbatim — fixing it would change generated prompts.
features = {"en": ["dialogue", "a moral lesson", "a twist ending", "foreshadowing", "irony", "inner monologue", "symbolism", "a MacGuffin", "a non-linear timeline", "a flashback", "a nested structure", "a story within a story", "multiple perspectives", "Checkhov's gun", "the fourth wall", "a cliffhanger", "an anti-hero", "juxtaposition", "climactic structure"]}["en"]
grammars = {"en": ["present tense", "past tense", "future tense", "progressive aspect", "perfect aspect", "passive voice", "conditional mood", "imperative mood", "indicative mood", "relative clauses", "prepositional phrases", "indirect speech", "exclamative sentences", "comparative forms", "superlative forms", "subordinate clauses", "ellipsis", "anaphora", "cataphora", "wh-questions", "yes-no questions", "gerunds", "participle phrases", "inverted sentences", "non-finite clauses", "determiners", "quantifiers", "adjective order", "parallel structure", "discourse markers", "appositive phrases"]}["en"]
def get_random_params(seed=42):
    """Sample one random prompt-parameter dict from the module-level pools.

    Args:
        seed: Unused; kept only for backward compatibility. Seeding happens
            once in main() so successive calls stay independent.

    Returns:
        Dict with keys theme, topic, style, feature, personas, grammar,
        perturbation, and num_paragraphs (uniform in [1, 9]). The grammar
        and perturbation constraints are each independently blanked with
        probability 0.5.
    """
    grammar = random.choice(grammars)
    perturbation = random.choice(perturbations)
    # Drop each optional constraint half the time so most prompts stay simple.
    if random.random() < 0.5:
        grammar = ""
    if random.random() < 0.5:
        perturbation = ""
    # personas may be empty if PersonaHub failed to load; fall back to no
    # persona instead of crashing with IndexError.
    persona_text = random.choice(personas)['persona'] if len(personas) > 0 else ""
    return {
        "theme": random.choice(themes),
        "topic": random.choice(topics),
        "style": random.choice(styles),
        "feature": random.choice(features),
        "personas": persona_text,
        "grammar": grammar,
        "perturbation": perturbation,
        "num_paragraphs": random.randint(1, 9),
    }
def iterate_params(seed=42):
    """Yield prompt-parameter dicts by cycling each pool independently.

    Because the pool lengths differ, the independent cycles drift apart and
    produce varied combinations deterministically (after random.seed(seed)).

    Yields dicts carrying both the historical 'persona' key and the
    'personas'/'perturbation' keys that create_simple_story_prompt's
    template formats with, so its output can be fed there directly.
    """
    random.seed(seed)
    # Independent iterators that cycle through each parameter list.
    theme_iter = itertools.cycle(themes)
    topic_iter = itertools.cycle(topics)
    style_iter = itertools.cycle(styles)
    feature_iter = itertools.cycle(features)
    persona_iter = itertools.cycle(personas)
    grammar_iter = itertools.cycle(grammars)
    k = 0
    while True:
        persona = next(persona_iter)
        # Grammar constraint only on odd iterations; the grammar cycle is
        # deliberately not advanced on even ones (matches original behavior).
        grammar = next(grammar_iter) if k % 2 != 0 else ""
        persona_text = persona['persona']
        yield {
            "theme": next(theme_iter),
            "topic": next(topic_iter),
            "style": next(style_iter),
            "feature": next(feature_iter),
            "persona": persona_text,   # historical key, kept for back-compat
            "personas": persona_text,  # key expected by the prompt template
            "grammar": grammar,
            "perturbation": "",        # not cycled here; template needs the key
            "num_paragraphs": 1 + (k % 9),
        }
        k += 1
## this is used in the openai_batch script
# A slightly hacky way to yield all combinations of parameters, but having empty "grammar" value half of the time.
# Assumes the lengths of (themes * topics * styles * features), grammars, the range of num_paragraphs and 2 are coprime.
# random.seed(seed)
# This stores all combinations in memory at the moment, inelegant but not a big problem at the moment. Can be easily refactored if all parameter list lengths are coprime.
# combinations = list(itertools.product(themes, topics, styles, features, perturbations))
# random.shuffle(combinations)
# for k, combination in enumerate(combinations):
# theme, topic, style, feature = combination
# grammar = grammars[k % len(grammars)]
# if k % 2 == 0:
# grammar = ""
# yield {
# "theme": theme,
# "topic": topic,
# "style": style,
# "feature": feature,
# "grammar": grammar,
# "num_paragraphs": 1+(k%9),
# }
def create_simple_story_prompt(params):
    """Build the (system_prompt, user_prompt, expected_story_count) triple.

    Args:
        params: dict with keys theme, topic, style, feature, personas,
            grammar, perturbation, num_paragraphs (as produced by
            get_random_params). Not mutated; it is copied before edits.

    Returns:
        Tuple of (system_prompt, formatted user prompt,
        num_stories_per_completion).
    """
    # NOTE(review): this computed value is immediately overwritten by the
    # hard-coded 3 below — looks like a debugging leftover; confirm intent.
    num_stories_per_completion = MAX_STORIES_PER_COMPLETION // max(3, params['num_paragraphs'])
    num_stories_per_completion = 3
    system_prompt = textwrap.dedent(f"""\
        You are a children's storyteller specializing in creating engaging stories for young children aged 3-4 years old.
        You will be given elements to build your story around. Use these to the best of your ability to generate a story using experiences, details, and vocabulary suitable for 3-4 year olds.
        But, if they are not suitable for children, creatively reinterpret them to be suitable to the story.
        If you need to use proper names, make them from space-separated common words.
        Either don't give characters a name, or select from Mia, Alex, Jean, Samuel, Lily, Leo, Jose, Kim, Alice, Lena, Rita, Emmanuel, Anne, Peter, Maria or Luis.
        Complex story structure is great, but please remember to only use simple words.
        Similarly, try not to start the story with "Once upon a time,". Other beginnings are preferred.
        Please do your best, even if you have to remove some details (if you have to remove details, don't mention it).
        Only return the story.
        """)
    singular = params['num_paragraphs'] == 1
    template_singular = f"Write a short story ({params['num_paragraphs']} paragraphs, with ~85 words per paragraph) using very basic words that a preschool child could understand. \nThe story "
    # NOTE(review): template_plural is only used when singular is False, so the
    # conditional plural-'s' below is effectively always 's'.
    template_plural = textwrap.dedent(f"""\
        Write {num_stories_per_completion} short stories ({params['num_paragraphs']} paragraph{'' if singular else 's'} each) using very basic words that a young child could understand.
        Do not number each story or write a headline. Make the stories diverse by fully exploring the theme, but each story should be self-contained. Separate the stories by putting {END_STRING} in between.
        Each story
        """)
    # Shared tail: a plain (non-f) string whose {placeholders} are filled by
    # str.format with the params dict at the end.
    template = textwrap.dedent("""\
        should be about {theme}, include {topic}, be {style} in its writing style and ideally feature {feature},
        and include characters with the following personalities: {personas}.
        Make sure to use character details to steer the story; again, if they are not suitable for children, creatively reinterpret them to be suitable to the story.
        {perturbation} {grammar}
        """)
    if singular:
        template = template_singular + template
    else:
        template = template_plural + template
    # Copy so the grammar rewrite below does not mutate the caller's dict.
    params = params.copy()
    if params['grammar']:
        params['grammar'] = f" The most important thing is to write an engaging easy story, but where it makes sense, demonstrate the use of {params['grammar']}."
    prompt = template.format(**params)
    return system_prompt, prompt, num_stories_per_completion
def generate_content(gen_model, system_prompt, prompt):
    """Request one completion from OpenAI ("gpt*") or Anthropic ("claude*").

    Reads the API key from the OPENAI_API_KEY_SIMPLESTORIES or
    ANTHROPIC_API_KEY_SIMPLESTORIES environment variable respectively.

    Args:
        gen_model: model name; must contain "gpt" or "claude".
        system_prompt: system-role instructions.
        prompt: user-role message.

    Returns:
        The completion text.

    Raises:
        ValueError: if gen_model names neither provider. (Previously an
            assert, which is silently stripped under ``python -O``.)
    """
    if "gpt" in gen_model:  # OpenAI
        client = OpenAI(api_key=os.environ["OPENAI_API_KEY_SIMPLESTORIES"])
        completion = client.chat.completions.create(
            model=gen_model,
            top_p=0.7,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt}
            ]
        )
        return completion.choices[0].message.content
    elif "claude" in gen_model:  # Anthropic
        client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY_SIMPLESTORIES"])
        completion = client.messages.create(
            model=gen_model,
            # Budget roughly 1k tokens per requested story, capped at the API max.
            max_tokens=min(1024*MAX_STORIES_PER_COMPLETION, 8192),
            top_p=0.7,
            system=system_prompt,
            messages=[
                {"role": "user", "content": prompt}
            ],
        )
        return completion.content[0].text
    raise ValueError("Invalid model name")
def process_completion(gen_model, completion, params, expected_num_stories=None,
                       spache_score=None, flesch_kincaid_score=None, flesch_reading_ease=None):
    """Split one raw completion into per-story record dicts.

    Stories are separated on END_STRING; fragments of length <= 1 after
    stripping are discarded. Typographic quotes/dashes/ellipses are
    normalized to ASCII. If the story count differs from
    expected_num_stories (when given), a warning is printed but all parsed
    stories are still returned.

    Returns:
        List of dicts, one per story, each tagged with a generation_id of
        the form "<md5-of-completion>-<index>", the model name, score
        fields, and the original params merged in.
    """
    # Renamed from `id` to avoid shadowing the builtin.
    completion_id = hashlib.md5(completion.encode()).hexdigest()
    stories = [x.strip() for x in completion.split(END_STRING) if len(x.strip()) > 1]
    # Normalize curly quotes, em-dash and ellipsis in one C-level pass.
    table = str.maketrans({
        "\u201d": "\"",
        "\u201c": "\"",
        "\u2019": "'",
        "\u2018": "'",
        "\u2014": "-",
        "\u2026": "..."
    })
    stories = [x.translate(table) for x in stories]
    if (len(stories) != expected_num_stories and expected_num_stories):
        print(f"Completion did not include expected number of stories, actual={len(stories)} != expected={expected_num_stories}\nend of completion: {completion[-100:]}")
    return [{
        'generation_id': completion_id + "-" + str(k),
        'story': story,
        'model': gen_model,
        'num_stories_in_completion': len(stories),
        "expected_num_stories_in_completion": expected_num_stories,
        'spache_score': spache_score,
        'flesch_kincaid_score': flesch_kincaid_score,
        'flesch_reading_ease': flesch_reading_ease,
        **params
    } for k, story in enumerate(stories)]
def evaluate_story(story):
    """Score a story's readability.

    Returns a (spache, flesch_kincaid_grade, flesch_reading_ease) tuple
    computed by textstat.
    """
    scores = (
        textstat.spache_readability(story),
        textstat.flesch_kincaid_grade(story),
        textstat.flesch_reading_ease(story),
    )
    return scores
def generate_simple_story(gen_model, params: dict):
    """Generate one completion, score it, and split it into story records.

    Any failure during generation/scoring/parsing is re-raised as
    RateLimitException so the worker loop backs off and retries.
    """
    system_prompt, prompt, expected_num_stories = create_simple_story_prompt(params.copy())
    try:
        completion = generate_content(gen_model, system_prompt, prompt)
        scores = evaluate_story(completion)
        return process_completion(gen_model, completion, params, expected_num_stories, *scores)
    except Exception as e:
        # TODO Implement Rate Limit Logic for different APIs
        raise RateLimitException(e)
def generate_and_log_simple_stories(gen_model: str, params: dict, formatted_time: str):
    """Generate stories for one parameter set, append them as JSONL, and echo them."""
    records = generate_simple_story(gen_model, params)
    filename = f'data/stories-{gen_model}-{formatted_time}.jsonl'
    serialized = [json.dumps(rec) for rec in records if 'story' in rec]
    with open(filename, "a") as f:
        f.write("\n".join(serialized) + "\n")
    for rec in records:
        print("##############################")
        print(rec['story'])
        print(params)
        # The three names below appear verbatim in the f-string's `=` debug
        # output, so they must stay as-is.
        spache_score, fk_score, fl_reading = evaluate_story(rec['story'])
        print(f"{spache_score , fk_score, fl_reading = }")
        print()
def worker_thread(gen_model: str, params: dict, formatted_time: str):
    """Run one generation job, retrying forever with a fixed 5s backoff on rate limits."""
    while True:
        try:
            return generate_and_log_simple_stories(gen_model, params, formatted_time)
        except RateLimitException as e:
            print(f"Rate limit hit: {e}, backing off for 5 seconds...")
            time.sleep(5)
def main(num_completions: int, num_threads: int = 20, model="gpt-4o-mini"):
    """Fan out num_completions story generations across a thread pool.

    Each submitted job gets an independently sampled parameter set; results
    are appended to data/stories-<model>-<timestamp>.jsonl by the workers.

    Args:
        num_completions: number of completions (jobs) to request.
        num_threads: thread-pool size.
        model: generation model name (must contain "gpt" or "claude").
    """
    # exist_ok avoids the check-then-create race of the previous version.
    os.makedirs("data", exist_ok=True)
    formatted_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    # Seed once, before any get_random_params() call, so runs are reproducible.
    random.seed(42)
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = [
            executor.submit(worker_thread, model, get_random_params(), formatted_time)
            for _ in range(num_completions)
        ]
        for future in tqdm(concurrent.futures.as_completed(futures), total=num_completions, desc="Generating stories"):
            try:
                # Result is only fetched to surface worker exceptions.
                future.result()
            except Exception as e:
                print(f"Story generation failed with exception: {e}")
# Reference models: ["gpt-4o", "gpt-4o-mini", "claude-sonnet-3.5-20240620"]
if __name__ == '__main__':
    # Small default batch; bump for real dataset generation runs.
    NUM_COMPLETIONS = 3
    main(NUM_COMPLETIONS, num_threads=10, model="gpt-4o-mini")
    # main(NUM_COMPLETIONS, num_threads=10, model="claude-3-5-sonnet-20240620")