forked from dustinstansbury/statistical-rethinking-2023
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
391 lines (313 loc) · 12 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
import os
import pandas as pd
import numpy as np
import pymc as pm
import arviz as az
import graphviz as gr
import networkx as nx
from matplotlib import pyplot as plt
from pathlib import Path
from typing import List, Union, Callable
HERE = Path(".")
def load_data(dataset, delimiter=";"):
fname = f"{dataset}.csv"
data_path = HERE / "data"
data_file = data_path / fname
return pd.read_csv(data_file, sep=delimiter)
def crosstab(x: np.array, y: np.array, labels: list[str] = None):
"""Simple cross tabulation of two discrete vectors x and y"""
ct = pd.crosstab(x, y)
if labels:
ct.index = labels
ct.columns = labels
return ct
def center(vals: np.ndarray) -> np.ndarray:
return vals - np.nanmean(vals)
def standardize(vals: np.ndarray) -> np.ndarray:
centered_vals = center(vals)
return centered_vals / np.nanstd(centered_vals)
def convert_to_categorical(vals):
return vals.astype("category").cat.codes.values
def logit(p: float) -> float:
return np.log(p / (1 - p))
def invlogit(x: float) -> float:
return 1 / (1 + np.exp(-x))
def draw_causal_graph(
edge_list, node_props=None, edge_props=None, graph_direction="UD"
):
"""Utility to draw a causal (directed) graph"""
g = gr.Digraph(graph_attr={"rankdir": graph_direction})
edge_props = {} if edge_props is None else edge_props
for e in edge_list:
props = edge_props[e] if e in edge_props else {}
g.edge(e[0], e[1], **props)
if node_props is not None:
for name, props in node_props.items():
g.node(name=name, **props)
return g
def plot_scatter(xs, ys, **scatter_kwargs):
"""Draw scatter plot with consistent style (e.g. unfilled points)"""
defaults = {"alpha": 0.6, "lw": 3, "s": 80, "color": "C0", "facecolors": "none"}
for k, v in defaults.items():
val = scatter_kwargs.get(k, v)
scatter_kwargs[k] = val
plt.scatter(xs, ys, **scatter_kwargs)
def plot_line(xs, ys, **plot_kwargs):
"""Plot line with consistent style (e.g. bordered lines)"""
linewidth = plot_kwargs.get("linewidth", 3)
plot_kwargs["linewidth"] = linewidth
# Copy settings for background
background_plot_kwargs = {k: v for k, v in plot_kwargs.items()}
background_plot_kwargs["linewidth"] = linewidth + 2
background_plot_kwargs["color"] = "white"
del background_plot_kwargs["label"] # no legend label for background
plt.plot(xs, ys, **background_plot_kwargs, zorder=30)
plt.plot(xs, ys, **plot_kwargs, zorder=31)
def plot_errorbar(
xs, ys, error_lower, error_upper, colors="C0", error_width=12, alpha=0.3
):
if isinstance(colors, str):
colors = [colors] * len(xs)
"""Draw thick error bars with consistent style"""
for ii, (x, y, err_l, err_u) in enumerate(zip(xs, ys, error_lower, error_upper)):
marker, _, bar = plt.errorbar(
x=x,
y=y,
yerr=np.array((err_l, err_u))[:, None],
ls="none",
color=colors[ii],
zorder=1,
)
plt.setp(bar[0], capstyle="round")
marker.set_fillstyle("none")
bar[0].set_alpha(alpha)
bar[0].set_linewidth(error_width)
def plot_x_errorbar(
xs, ys, error_lower, error_upper, colors="C0", error_width=12, alpha=0.3
):
if isinstance(colors, str):
colors = [colors] * len(xs)
"""Draw thick error bars with consistent style"""
for ii, (x, y, err_l, err_u) in enumerate(zip(xs, ys, error_lower, error_upper)):
marker, _, bar = plt.errorbar(
x=x,
y=y,
xerr=np.array((err_l, err_u))[:, None],
ls="none",
color=colors[ii],
zorder=1,
)
plt.setp(bar[0], capstyle="round")
marker.set_fillstyle("none")
bar[0].set_alpha(alpha)
bar[0].set_linewidth(error_width)
def plot_graph(graph, **graph_kwargs):
"""Draw a network graph.
graph: Union[networkx.DiGraph, np.ndarray]
if ndarray, assume `graph` is an adjacency matrix defining
a directed graph.
"""
# convert to networkx.DiGraph, if needed
G = (
nx.from_numpy_array(graph, create_using=nx.DiGraph)
if isinstance(graph, np.ndarray)
else graph
)
# Set default styling
np.random.seed(123) # for consistent spring-layout
if "layout" in graph_kwargs:
graph_kwargs["pos"] = graph_kwargs["layout"](G)
default_graph_kwargs = {
"node_color": "C0",
"node_size": 500,
"arrowsize": 30,
"width": 3,
"alpha": 0.7,
"connectionstyle": "arc3,rad=0.1",
"pos": nx.kamada_kawai_layout(G),
}
for k, v in default_graph_kwargs.items():
if k not in graph_kwargs:
graph_kwargs[k] = v
nx.draw(G, **graph_kwargs)
# return the node layout for consistent graphing
return graph_kwargs["pos"]
def plot_2d_function(xrange, yrange, func, ax=None, **countour_kwargs):
"""Evaluate the function `func` over the values of xrange and yrange and
plot the resulting value contour over that range.
Parameters
----------
xrange : np.ndarray
The horizontal values to evaluate/plot
yrange : p.ndarray
The horizontal values to evaluate/plot
func : Callable
function of two arguments, xs and ys. Should return a single value at
each point.
ax : matplotlib.Axis, optional
An optional axis to plot the function, by default None
Returns
-------
contour : matplotlib.contour.QuadContourSet
"""
resolution = len(xrange)
xs, ys = np.meshgrid(xrange, yrange)
xs = xs.ravel()
ys = ys.ravel()
value = func(xs, ys)
if ax is not None:
plt.sca(ax)
return plt.contour(
xs.reshape(resolution, resolution),
ys.reshape(resolution, resolution),
value.reshape(resolution, resolution),
**countour_kwargs,
)
def create_variables_dataframe(*variables: List[np.ndarray]) -> pd.DataFrame:
"""Converts a list of numpy arrays to a dataframe; infers column names from
variable names
"""
column_names = [get_variable_name(v) for v in variables]
return pd.DataFrame(np.vstack(variables).T, columns=column_names)
def plot_pymc_distribution(distribution: pm.Distribution, **distribution_params):
"""Plot a PyMC Distribution with specific distrubution parameters
Parameters
----------
distribution : pymc.Distribution
The class of distribution to
**distribution_params : dict
Distribution-specific parameters.
Returns
-------
ax : matplotlib.Axes
The axes object associated with the plot.
"""
with pm.Model() as _:
d = distribution(name=distribution.__name__, **distribution_params)
draws = pm.draw(d, draws=10_000)
return az.plot_dist(draws)
def savefig(filename):
"""Save a figure to the `./images` directory"""
image_path = HERE / "images"
if not image_path.exists():
print(f"creating image directory: {image_path}")
os.makedirs(image_path)
figure_path = image_path / filename
print(f"saving figure to {figure_path}")
plt.savefig(figure_path, dpi=300, bbox_inches="tight")
def display_image(filename, width=600):
"""Display an image saved to the `./images` directory"""
from IPython.display import Image, display
return display(Image(filename=f"images/{filename}", width=width))
def simulate_2_parameter_bayesian_learning_grid_approximation(
x_obs,
y_obs,
param_a_grid,
param_b_grid,
true_param_a,
true_param_b,
model_func,
posterior_func,
n_posterior_samples=3,
param_labels=None,
data_range_x=None,
data_range_y=None,
):
"""General function for simulating Bayesian learning in a 2-parameter model
using grid approximation.
Parameters
----------
x_obs : np.ndarray
The observed x values
y_obs : np.ndarray
The observed y values
param_a_grid: np.ndarray
The range of values the first model parameter in the model can take.
Note: should have same length as param_b_grid.
param_b_grid: np.ndarray
The range of values the second model parameter in the model can take.
Note: should have same length as param_a_grid.
true_param_a: float
The true value of the first model parameter, used for visualizing ground
truth
true_param_b: float
The true value of the second model parameter, used for visualizing ground
truth
model_func: Callable
A function `f` of the form `f(x, param_a, param_b)`. Evaluates the model
given at data points x, given the current state of parameters, `param_a`
and `param_b`. Returns a scalar output for the `y` associated with input
`x`.
posterior_func: Callable
A function `f` of the form `f(x_obs, y_obs, param_grid_a, param_grid_b)
that returns the posterior probability given the observed data and the
range of parameters defined by `param_grid_a` and `param_grid_b`.
n_posterior_samples: int
The number of model functions sampled from the 2D posterior
param_labels: Optional[list[str, str]]
For visualization, the names of `param_a` and `param_b`, respectively
data_range_x: Optional len-2 float sequence
For visualization, the upper and lower bounds of the domain used for model
evaluation
data_range_y: Optional len-2 float sequence
For visualization, the upper and lower bounds of the range used for model
evaluation.
"""
param_labels = param_labels if param_labels is not None else ["param_a", "param_b"]
data_range_x = (x_obs.min(), x_obs.max()) if data_range_x is None else data_range_x
data_range_y = (y_obs.min(), y_obs.max()) if data_range_y is None else data_range_y
# NOTE: assume square parameter grid
resolution = len(param_a_grid)
param_a_grid, param_b_grid = np.meshgrid(param_a_grid, param_b_grid)
param_a_grid = param_a_grid.ravel()
param_b_grid = param_b_grid.ravel()
posterior = posterior_func(x_obs, y_obs, param_a_grid, param_b_grid)
# Visualization
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
# Plot Posterior over intercept and slope params
plt.sca(axs[0])
plt.contour(
param_a_grid.reshape(resolution, resolution),
param_b_grid.reshape(resolution, resolution),
posterior.reshape(resolution, resolution),
cmap="gray_r",
)
# Sample locations in parameter space according to posterior
sample_idx = np.random.choice(
np.arange(len(posterior)),
p=posterior / posterior.sum(),
size=n_posterior_samples,
)
param_a_list = []
param_b_list = []
for ii, idx in enumerate(sample_idx):
param_a = param_a_grid[idx]
param_b = param_b_grid[idx]
param_a_list.append(param_a)
param_b_list.append(param_b)
# Add sampled parameters to posterior
plt.scatter(param_a, param_b, s=60, c=f"C{ii}", alpha=0.75, zorder=20)
# Add the true params to the plot for reference
plt.scatter(
true_param_a, true_param_b, color="k", marker="x", s=60, label="true parameters"
)
plt.xlabel(param_labels[0])
plt.ylabel(param_labels[1])
# Plot the current training data and model trends sampled from posterior
plt.sca(axs[1])
plt.scatter(x_obs, y_obs, s=60, c="k", alpha=0.5)
# Plot the resulting model functions sampled from posterior
xs = np.linspace(data_range_x[0], data_range_x[1], 100)
for ii, (param_a, param_b) in enumerate(zip(param_a_list, param_b_list)):
ys = model_func(xs, param_a, param_b)
plt.plot(xs, ys, color=f"C{ii}", linewidth=4, alpha=0.5)
groundtruth_ys = model_func(xs, true_param_a, true_param_b)
plt.plot(
xs, groundtruth_ys, color="k", linestyle="--", alpha=0.5, label="true trend"
)
plt.xlim([data_range_x[0], data_range_x[1]])
plt.xlabel("x value")
plt.ylim([data_range_y[0], data_range_y[1]])
plt.ylabel("y value")
plt.title(f"N={len(y_obs)}")
plt.legend(loc="upper left")