api.py
import os
import numpy as np
import json
import pickle
import gzip
import matplotlib.pyplot as plt
class Benchmark:
"""API for TabularBench."""
def __init__(self, data_dir, cache=False, cache_dir="cached/"):
"""Initialize dataset (will take a few seconds-minutes).
Keyword arguments:
bench_data -- str, the raw benchmark data directory
"""
if not os.path.isfile(data_dir) or not data_dir.endswith(".json"):
raise ValueError("Please specify path to the bench json file.")
self.data_dir = data_dir
self.cache_dir = cache_dir
self.cache = cache
print("==> Loading data...")
self.data = self._read_data(data_dir)
self.dataset_names = list(self.data.keys())
print("==> Done.")
def query(self, dataset_name, tag, config_id):
"""Query a run.
Keyword arguments:
dataset_name -- str, the name of the dataset in the benchmark
tag -- str, the tag you want to query
        config_id -- int, identifier of the run you want to query; a ValueError is raised if the id does not exist for this dataset
"""
config_id = str(config_id)
if dataset_name not in self.dataset_names:
raise ValueError("Dataset name not found.")
if config_id not in self.data[dataset_name].keys():
raise ValueError("Config nr %s not found for dataset %s." % (config_id, dataset_name))
if tag in self.data[dataset_name][config_id]["log"].keys():
return self.data[dataset_name][config_id]["log"][tag]
if tag in self.data[dataset_name][config_id]["results"].keys():
return self.data[dataset_name][config_id]["results"][tag]
if tag in self.data[dataset_name][config_id]["config"].keys():
return self.data[dataset_name][config_id]["config"][tag]
if tag == "config":
return self.data[dataset_name][config_id]["config"]
raise ValueError("Tag %s not found for config %s for dataset %s" % (tag, config_id, dataset_name))
def query_best(self, dataset_name, tag, criterion, position=0):
"""Query the n-th best run. "Best" here means achieving the largest value at any epoch/step,
Keyword arguments:
dataset_name -- str, the name of the dataset in the benchmark
tag -- str, the tag you want to query
criterion -- str, the tag you want to use for the ranking
position -- int, an identifier for which position in the ranking you want to query
"""
performances = []
for config_id in self.data[dataset_name].keys():
performances.append((config_id, max(self.query(dataset_name, criterion, config_id))))
        performances.sort(key=lambda x: x[1], reverse=True)
desired_position = performances[position][0]
return self.query(dataset_name, tag, desired_position)
def get_queriable_tags(self, dataset_name=None, config_id=None):
"""Returns a list of all queriable tags"""
if dataset_name is None or config_id is None:
dataset_name = list(self.data.keys())[0]
config_id = list(self.data[dataset_name].keys())[0]
else:
config_id = str(config_id)
log_tags = list(self.data[dataset_name][config_id]["log"].keys())
result_tags = list(self.data[dataset_name][config_id]["results"].keys())
config_tags = list(self.data[dataset_name][config_id]["config"].keys())
additional = ["config"]
return log_tags+result_tags+config_tags+additional
def get_dataset_names(self):
"""Returns a list of all availabe dataset names like defined on openml"""
return self.dataset_names
def get_openml_task_ids(self):
"""Returns a list of openml task ids"""
task_ids = []
for dataset_name in self.dataset_names:
task_ids.append(self.query(dataset_name, "OpenML_task_id", 1))
return task_ids
def get_number_of_configs(self, dataset_name):
"""Returns the number of configurations for a dataset"""
if dataset_name not in self.dataset_names:
raise ValueError("Dataset name not found.")
return len(self.data[dataset_name].keys())
def get_config(self, dataset_name, config_id):
"""Returns the configuration of a run specified by dataset name and config id"""
if dataset_name not in self.dataset_names:
raise ValueError("Dataset name not found.")
        return self.data[dataset_name][str(config_id)]["config"]
def plot_by_name(self, dataset_names, x_col, y_col, n_configs=10, show_best=False, xscale='linear', yscale='linear', criterion=None):
"""Plot multiple datasets and multiple runs.
Keyword arguments:
dataset_names -- list
x_col -- str, tag to plot on x-axis
y_col -- str, tag to plot on y-axis
n_configs -- int, number of configs to plot for each dataset
        show_best -- bool, whether to show the n_configs best runs (according to query_best())
xscale -- str, set xscale, options as in matplotlib: "linear", "log", "symlog", "logit", ...
yscale -- str, set yscale, options as in matplotlib: "linear", "log", "symlog", "logit", ...
criterion -- str, tag used as criterion for query_best()
"""
if isinstance(dataset_names, str):
dataset_names = [dataset_names]
if not isinstance(dataset_names, (list, np.ndarray)):
raise ValueError("Please specify a dataset name or a list list of dataset names.")
n_rows = len(dataset_names)
fig, axes = plt.subplots(n_rows, 1, sharex=False, sharey=False, figsize=(10,7*n_rows))
if criterion is None:
criterion = y_col
loop_arg = enumerate(axes.flatten()) if len(dataset_names)>1 else [(0,axes)]
for ind_ax, ax in loop_arg:
for ind in range(n_configs):
try:
if ind==0:
instances = int(self.query(dataset_names[ind_ax], "instances", 0))
classes = int(self.query(dataset_names[ind_ax], "classes", 0))
features = int(self.query(dataset_names[ind_ax], "features", 0))
if show_best:
x = self.query_best(dataset_names[ind_ax], x_col, criterion, ind)
y = self.query_best(dataset_names[ind_ax], y_col, criterion, ind)
else:
x = self.query(dataset_names[ind_ax], x_col, ind+1)
y = self.query(dataset_names[ind_ax], y_col, ind+1)
ax.plot(x, y, 'p-')
ax.set_xscale(xscale)
ax.set_yscale(yscale)
                    ax.set(xlabel=x_col, ylabel=y_col)
title_str = ", ".join([dataset_names[ind_ax],
"features: " + str(features),
"classes: " + str(classes),
"instances: " + str(instances)])
ax.title.set_text(title_str)
except ValueError:
print("Run %i not found for dataset %s" %(ind, dataset_names[ind_ax]))
except Exception as e:
raise e
def _cache_data(self, data, cache_file):
os.makedirs(self.cache_dir, exist_ok=True)
with gzip.open(cache_file, 'wb') as f:
pickle.dump(data, f)
def _read_cached_data(self, cache_file):
with gzip.open(cache_file, 'rb') as f:
data = pickle.load(f)
return data
def _read_file_string(self, path):
"""Reads a large json string from path. Python file handler has issues with large files so it has to be chunked."""
# Shoutout to https://stackoverflow.com/questions/48122798/oserror-errno-22-invalid-argument-when-reading-a-huge-file
file_str = ''
with open(path, 'r') as f:
while True:
block = f.read(64 * (1 << 20)) # Read 64 MB at a time
if not block: # Reached EOF
break
file_str += block
return file_str
def _read_data(self, path):
"""Reads cached data if available. If not, reads json and caches the data as .pkl.gz"""
cache_file = os.path.join(self.cache_dir, os.path.basename(self.data_dir).replace(".json", ".pkl.gz"))
if os.path.exists(cache_file) and self.cache:
print("==> Found cached data, loading...")
data = self._read_cached_data(cache_file)
else:
print("==> No cached data found or cache set to False.")
print("==> Reading json data...")
data = json.loads(self._read_file_string(path))
if self.cache:
print("==> Caching data...")
self._cache_data(data, cache_file)
return data
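# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the API). The path
# "bench_data.json" below is a placeholder; point it at the benchmark json
# you downloaded. Use get_dataset_names() and get_queriable_tags() to
# discover the dataset names and tags available in your data.
if __name__ == "__main__":
    bench_path = "bench_data.json"  # placeholder path, adjust to your download
    if os.path.isfile(bench_path):
        bench = Benchmark(bench_path, cache=True)
        datasets = bench.get_dataset_names()
        print("Datasets:", datasets)
        print("Queriable tags:", bench.get_queriable_tags())
        # Inspect the first dataset: number of runs and the config of run "1".
        print("Configs for %s: %d" % (datasets[0], bench.get_number_of_configs(datasets[0])))
        print(bench.get_config(datasets[0], 1))
    else:
        print("Benchmark file %s not found; skipping usage example." % bench_path)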