-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvisualizer.py
32 lines (26 loc) · 938 Bytes
/
visualizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import matplotlib.pyplot as plt
import numpy as np
def plot_mnist_sample(mnist_sample):
    """Show one image in its own 7x7-inch grayscale figure.

    Parameters
    ----------
    mnist_sample : array-like
        A single image. Length-1 axes are removed with ``np.squeeze`` before
        plotting (e.g. a trailing channel axis — TODO confirm the shapes
        callers actually pass).
    """
    pixels = np.squeeze(mnist_sample)
    display_options = {'cmap': 'gray', 'interpolation': 'none'}

    plt.figure(figsize=(7, 7))
    plt.imshow(pixels, **display_options)
    plt.title('MNIST sample', fontsize=20)
    # Hide ticks and frame: only the pixels matter for a sample preview.
    plt.axis('off')
    plt.show()
def plot_mnist_grid(image_batch, function=None, grid_size=3):
    """Plot the first images of a batch in a square grid (3x3 by default).

    Parameters
    ----------
    image_batch : array-like
        Batch of images indexed along its first axis. Only the first
        ``grid_size ** 2`` entries are used.
    function : callable, optional
        If given, it is called as ``function([image_batch[:n]])`` and element
        ``[0]`` of its return value is taken as the images to display
        (presumably a Keras backend function returning a list of outputs —
        TODO confirm against callers). If omitted, the raw images are shown.
    grid_size : int, optional
        Number of rows and columns in the grid. Defaults to 3, which matches
        the original fixed 3x3 layout.
    """
    num_cells = grid_size * grid_size
    fig = plt.figure()
    selected = image_batch[:num_cells]
    if function is not None:
        image_result = function([selected])
    else:
        # Add a leading axis so both branches expose the images at index [0].
        image_result = np.expand_dims(selected, 0)
    images = image_result[0]
    plt.clf()
    # Never index past the available images. A short batch used to raise
    # IndexError because the loop always ran over all grid cells.
    for cell in range(min(num_cells, len(images))):
        plt.subplot(grid_size, grid_size, cell + 1)
        plt.imshow(np.squeeze(images[cell]), cmap='gray')
        plt.axis('off')
    fig.canvas.draw()
    plt.show()
def print_evaluation(epoch_arg, val_score, test_score):
    """Print one line of the form ``Epoch: E | Val: V | Test: T``.

    Each value is rendered with its default string formatting.
    """
    labelled = (('Epoch', epoch_arg), ('Val', val_score), ('Test', test_score))
    print(' | '.join(f'{label}: {value}' for label, value in labelled))