-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathload_save.py
36 lines (30 loc) · 1.1 KB
/
load_save.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import sys, os
import numpy as np
import tensorflow as tf
def path_name(dataset, alpha, num_samples, backward_pass, extra_string=None):
    """Return the checkpoint directory for one experiment configuration.

    The result has the form ``'ckpts/<dataset>/<folder>/'``, relative to the
    current working directory, and always ends with ``'/'``.

    Args:
        dataset: Dataset name, used as the second path component.
        alpha: Value formatted with two decimals into the folder name; it is
            not used when ``backward_pass == 'max'``.
        num_samples: Integer formatted into the folder name as ``k<N>``.
        backward_pass: ``'max'`` selects the ``max_k<N>`` folder name; any
            other value selects ``alpha<alpha>_k<N>``.
        extra_string: Optional suffix appended to the folder name after ``'_'``.
    """
    base = 'ckpts/' + dataset + '/'
    if backward_pass == 'max':
        stem = 'max_k%d' % num_samples
    else:
        stem = 'alpha%.2f_k%d' % (alpha, num_samples)
    pieces = [stem] if extra_string is None else [stem, extra_string]
    return base + '_'.join(pieces) + '/'
def save_checkpoint(sess, path, checkpoint=1, var_list=None):
    """Save the session's variables to a numbered TensorFlow checkpoint.

    The checkpoint prefix is ``path + 'checkpoint<N>.ckpt'``. ``path`` is
    concatenated, not joined, so it is normally a directory ending in a
    separator (as the paths returned by ``path_name`` are).

    Args:
        sess: TensorFlow session whose variables are saved.
        path: Checkpoint location prefix, normally a directory ending in '/'.
        checkpoint: Integer index embedded in the checkpoint file name.
        var_list: Optional variables to save, forwarded to
            ``tf.train.Saver``; ``None`` lets the Saver use its default set.

    Returns:
        The path reported by ``Saver.save``.
    """
    # Concatenate (not os.path.join) so the name matches the one that
    # load_checkpoint builds with the same expression.
    fname = path + 'checkpoint%d.ckpt' % checkpoint
    # Create the directory that will actually contain the checkpoint files.
    # For a path ending in '/', this is ``path`` itself.
    directory = os.path.dirname(fname)
    if directory and not os.path.isdir(directory):
        try:
            os.makedirs(directory)
        except OSError:
            # Another process may have created it between the check and the
            # call. Only fail if the directory still does not exist.
            if not os.path.isdir(directory):
                raise
    saver = tf.train.Saver(var_list)
    save_path = saver.save(sess, fname)
    print("Model saved in %s" % save_path)
    return save_path
def load_checkpoint(sess, path, checkpoint=1, var_list=None):
    """Restore a session from a numbered TensorFlow checkpoint, best-effort.

    Reads ``path + 'checkpoint<N>.ckpt'``, the same name ``save_checkpoint``
    builds. A failed restore is reported on stdout instead of being raised,
    so callers can fall back to freshly initialised variables.

    Args:
        sess: TensorFlow session to restore variables into.
        path: Checkpoint location prefix, normally a directory ending in '/'.
        checkpoint: Integer index embedded in the checkpoint file name.
        var_list: Optional variables to restore, forwarded to
            ``tf.train.Saver``; ``None`` lets the Saver use its default set.

    Returns:
        True if the checkpoint was restored, False if restoring failed.
    """
    # Computed outside the try. Previously, an error here left ``fname``
    # unbound inside the handler and surfaced as a confusing NameError.
    fname = path + 'checkpoint%d.ckpt' % checkpoint
    try:
        saver = tf.train.Saver(var_list)
        saver.restore(sess, fname)
    except Exception as err:
        # Still best-effort, but KeyboardInterrupt/SystemExit are no longer
        # swallowed, and the cause of the failure is reported.
        print("Failed to load model from %s (%s: %s)"
              % (fname, type(err).__name__, err))
        return False
    print("Model restored from %s" % fname)
    return True