-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathevaluate.py
190 lines (173 loc) · 7.61 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import os
import pickle
import subprocess
import sys
import time
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Tuple, Union

from USACOBench.utils import get_code_from_solution
# Type aliases for the plain-dict records passed between the prompt, model and judge stages.

# A raw problem record; its keys are not constrained by this alias.
Problem = Dict[Any, Any]
# Input to prompt_fn; the functions in this module read its 'problem_id' key.
Query = Dict[str, str]

# One candidate solution (values may be None).
Solution = Dict[str, Union[str, None]]
SolutionSet = List[Solution]            # a list of solutions (grouped per problem in evaluate_model)
SolutionDict = Dict[str, SolutionSet]   # problem_id -> SolutionSet

# One judge outcome for one solution.
Result = Dict[str, str]
ResultSet = List[Result]                # a list of results (grouped per problem in evaluate_model)
ResultDict = Dict[str, ResultSet]       # problem_id -> ResultSet
def judge_fn_solve(responses: List[str],
                   queries: List[Query],
                   verbose=False) -> List[Result]:
    '''
    Judge each response against its query's problem and return one result per response.

    Result is the USACO judge outcome (ACCEPTED, WRONG_ANSWER, etc) based on final code.
    Expects code in the Markdown format as defined in USACOBench.utils.get_code_from_solution.

    responses: model responses; responses[i] is judged against queries[i]['problem_id']
    queries: queries aligned index-by-index with responses
    verbose: accepted so callers can pass it uniformly (evaluate_model does); currently unused
    Returns a flat list of results in the same order as responses.
    Raises RuntimeError if the judge produced no result sets.
    '''
    # Each solution set wraps exactly one solution so results map 1:1 back to responses.
    solution_sets = [[{
        'problem_id': query['problem_id'],
        'solution': response,
        'solution_code': get_code_from_solution(response),
        'language': 'Python3',
    }] for query, response in zip(queries, responses)]
    results = evaluate_ss(solution_sets, mode='eval_all')
    if results is None:
        # evaluate_ss signals judge failure with None; fail loudly instead of a TypeError below.
        raise RuntimeError('Judging failed: evaluate_ss returned no result sets')
    # flatten: one result per single-solution set
    return [result_set[0] for result_set in results]
def _group_results_by_problem_id(results):
    '''Group results into a dict problem_id -> list of results, preserving input order.'''
    grouped = {}
    for result in results:
        grouped.setdefault(result['problem_id'], []).append(result)
    return grouped


def evaluate_model(model_fn: Callable,
                   prompt_fn: Callable,
                   queries: List[Query],
                   attempts: int,
                   problem_ids: Union[List[str], None] = None,
                   verbose=False) -> Tuple[ResultDict, SolutionDict, List[ResultSet], List[SolutionSet]]:
    '''
    Generate `attempts` responses per query, judge them, and group everything by problem_id.

    model_fn: takes in list of string prompts and outputs list of string responses, supports verbose bool
    prompt_fn: returns a prompt string: takes in a query, which is a dict of strings
    queries: list of queries to evaluate containing information to be inputted into the prompt function
    attempts: number of times to run each query
    problem_ids: we evaluate only on queries with problem_ids in this list. If None, all problem_ids are valid.
    verbose: print progress and timing information

    Returns (rdict, sdict, rs, ss):
        rdict: problem_id -> list of judge results, one per attempt
        sdict: problem_id -> list of solution dicts (response, extracted code, its judge result,
               problem_id, prompt), one per attempt
        rs, ss: list(rdict.values()), list(sdict.values())
    '''
    if problem_ids is not None:
        wanted = set(problem_ids)
        filtered = [query for query in queries if query['problem_id'] in wanted]
        if verbose:
            print('Evaluating on a subset of {} out of {} available query-ground_truth pairs...'.format(len(filtered), len(queries)))
        queries = filtered
    if verbose:
        print('Evaluating on {} queries...'.format(len(queries)))

    # model and judge
    model_fn = partial(model_fn, verbose=verbose)
    judge_fn = partial(judge_fn_solve, verbose=verbose)

    # Attempt-major order [q0..qn-1, q0..qn-1, ...]: prompts, responses and all_queries
    # must stay index-aligned. prompt_fn is still called once per query, as before.
    all_queries = queries * attempts
    prompts = [prompt_fn(query) for query in queries] * attempts

    if verbose:
        print('Generating...')
    start_time = time.time()
    responses = model_fn(prompts)
    if verbose:
        print('Finished generation, took {} seconds'.format(time.time() - start_time))
        print('Judging...')
    start_time = time.time()
    results = judge_fn(responses, all_queries)
    if verbose:
        print('Finished judging, took {} seconds'.format(time.time() - start_time))

    # nicer result formats
    rdict = _group_results_by_problem_id(results)
    rs = list(rdict.values())

    # nicer solution formats
    # note: this sdict / ss includes the result for easier qualitative eval, so may be slightly bulkier
    sdict = {}
    # Iterate over ALL attempts (zipping with plain `queries` would stop after the first attempt).
    for solution, prompt, query in zip(responses, prompts, all_queries):
        problem_id = query['problem_id']
        problem_solutions = sdict.setdefault(problem_id, [])
        problem_results = rdict.get(problem_id, [])
        # The k-th solution seen for this problem pairs with the k-th result for it
        # (assumes the judge preserves per-problem order, as the flatten in judge_fn_solve does).
        k = len(problem_solutions)
        problem_solutions.append({
            'solution': solution,
            'solution_code': get_code_from_solution(solution),
            'result': problem_results[k] if k < len(problem_results) else None,
            'problem_id': problem_id,
            'prompt': prompt,
        })
    ss = list(sdict.values())
    return rdict, sdict, rs, ss
def evaluate_ss(ss, mode='eval_all') -> Union[List[ResultSet], None]:
    '''
    Returns result sets for the given solution sets, or None if they could not be loaded.

    For use inside Jupyter environments, where directly calling evaluate_solution_sets crashes
    the environment: the solution sets are pickled into judge_sandbox/, judged by running
    evaluate_solution_sets.py in a separate process, and the pickled result sets are read back.

    ss: list of solution sets to judge
    mode: forwarded to evaluate_solution_sets.py as -m
    '''
    timestamp_str = datetime.now().strftime("%m_%d_%Y_%H_%M_%S_%f")
    sandbox = Path('judge_sandbox')
    # Create the scratch directory instead of crashing when it is missing.
    sandbox.mkdir(parents=True, exist_ok=True)
    solution_path = sandbox / 'solution_sets_{}.pickle'.format(timestamp_str)
    result_path = sandbox / 'result_sets_{}.pickle'.format(timestamp_str)
    with open(solution_path, 'wb') as f:
        pickle.dump(ss, f)
    # Argument list (no shell) so `mode` cannot inject shell syntax; sys.executable runs the
    # judge with the same interpreter as this process rather than whatever `python` is on PATH.
    subprocess.run([sys.executable, 'evaluate_solution_sets.py',
                    '-s', str(solution_path),
                    '-r', str(result_path),
                    '-m', str(mode)])
    try:
        with open(result_path, 'rb') as f:
            rs = pickle.load(f)
    except Exception as error:
        # Best-effort: report and signal failure with None (e.g. the judge crashed before writing).
        print(error)
        return None
    return rs
def evaluate_code(problem_id, code) -> Result:
    '''
    Evaluates given code for problem problem_id on all test cases and returns results.
    For use inside Jupyter environments.

    The code is written to judge_sandbox/, judged by running usaco_judge_one.py in a
    separate process, and the pickled result is read back.
    Returns a single result.
    Raises FileNotFoundError if the judge did not write a result file.
    '''
    timestamp_str = datetime.now().strftime("%m_%d_%Y_%H_%M_%S_%f")
    sandbox = Path('judge_sandbox')
    sandbox.mkdir(parents=True, exist_ok=True)
    code_path = sandbox / 'code_{}.py'.format(timestamp_str)
    result_path = sandbox / 'result_{}.pickle'.format(timestamp_str)
    # Python decodes source files as UTF-8, so write them as UTF-8 regardless of locale.
    code_path.write_text(code, encoding='utf-8')
    # No shell: problem_id is passed as a plain argument.
    subprocess.run([sys.executable, 'usaco_judge_one.py', str(code_path),
                    '-i', str(problem_id),
                    '-r', '--result_path', str(result_path)])
    with open(result_path, 'rb') as f:
        result = pickle.load(f)
    return result
def run_code_on_input(problem_id, code, input) -> str:
    '''
    Runs the given code on the given input and returns everything it printed.
    For use inside Jupyter environments.

    problem_id: kept for interface compatibility; not used by this function
    code: Python source to run in a separate interpreter
    input: text fed to the program's stdin
    Returns the combined stdout/stderr of the program as a string. If the program fails
    before its output is redirected (e.g. a SyntaxError), the interpreter's own output
    (the error message) is returned instead.
    '''
    timestamp_str = datetime.now().strftime("%m_%d_%Y_%H_%M_%S_%f")
    sandbox = Path('judge_sandbox')
    sandbox.mkdir(parents=True, exist_ok=True)
    code_path = sandbox / 'code_{}.py'.format(timestamp_str)
    input_path = sandbox / 'input_{}.txt'.format(timestamp_str)
    output_path = sandbox / 'output_{}.txt'.format(timestamp_str)
    # Redirect both streams to one file object so prints and tracebacks keep their order.
    code_prefix = "import sys;sys.stdout = open({!r}, 'w', encoding='utf-8');sys.stderr = sys.stdout\n".format(
        output_path.as_posix())
    code_path.write_text(code_prefix + code, encoding='utf-8')
    # Default (locale) encoding on purpose: it matches how the child decodes piped stdin.
    with open(input_path, 'w') as f:
        f.write(input)
    with open(input_path, 'rb') as stdin_file:
        proc = subprocess.run([sys.executable, str(code_path)],
                              stdin=stdin_file,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.STDOUT)
    if not output_path.exists():
        # The prefix never ran (e.g. the file failed to compile); surface the interpreter's message.
        return proc.stdout.decode('utf-8', errors='replace')
    return output_path.read_text(encoding='utf-8')
def run_code_on_first_sample(problem, code, return_all=False, print_all=True):
    '''
    Runs the given code on the first sample of problem and returns the printed output.
    For use inside Jupyter environments.

    problem: problem dict with 'problem_id' and a non-empty 'samples' list whose entries
             have 'input' and 'output' strings
    code: Python source to run
    return_all: returns not just output but (output, input, expected_output), all stripped
    print_all: print the sample input, the program output and the expected output
    Returns the raw (unstripped) output string unless return_all is set.
    Raises AssertionError if the problem has no samples.
    '''
    problem_id = problem['problem_id']
    samples = problem.get('samples')
    if not samples:
        # Explicit raise rather than `assert` so the check survives `python -O`;
        # AssertionError is kept for compatibility with existing callers.
        raise AssertionError('No samples found')
    first_sample = samples[0]
    sample_input = first_sample['input']
    expected_output = first_sample['output']
    output = run_code_on_input(problem_id, code, sample_input)
    if print_all:
        for label, text in (('Input:', sample_input),
                            ('Output:', output),
                            ('Expected output:', expected_output)):
            print(label)
            print(text.strip())
            print()
    if return_all:
        return output.strip(), sample_input.strip(), expected_output.strip()
    return output