forked from yeyupiaoling/Tensorflow-FaceRecognition
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfcn_detector.py
44 lines (39 loc) · 1.98 KB
/
fcn_detector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import tensorflow as tf
import sys
sys.path.append("../")
class FcnDetector(object):
    """Run a fully-convolutional detection network on one image at a time.

    The network is built inside its own ``tf.Graph`` and its weights are
    restored from a checkpoint into a session stored on ``self.sess``.

    Args:
        net_factory: Callable that builds the network. It is invoked as
            ``net_factory(image, training=False)`` and must return three
            values: class probabilities, bounding-box predictions and a third
            output (landmarks, per the original comment) which is discarded.
        model_path: Checkpoint path to restore. The directory that contains
            it must hold a valid checkpoint state.

    Raises:
        AssertionError: If no checkpoint state is found in the directory of
            ``model_path``.
    """

    def __init__(self, net_factory, model_path):
        # Local import so this block does not depend on the file-level imports.
        import os

        graph = tf.Graph()
        with graph.as_default():
            # Height and width are fed at run time, so a single graph can be
            # reused for images of any size.
            # NOTE(review): the original comment mentions "(-1,1)", presumably
            # the expected input value range - confirm against the caller.
            self.image_op = tf.placeholder(tf.float32, name='input_image')
            self.width_op = tf.placeholder(tf.int32, name='image_width')
            self.height_op = tf.placeholder(tf.int32, name='image_height')
            image_reshape = tf.reshape(
                self.image_op, [1, self.height_op, self.width_op, 3])

            # The original comments describe cls_prob as batch*2 and bbox_pred
            # as batch*4; the actual shapes are decided by net_factory.
            self.cls_prob, self.bbox_pred, _ = net_factory(
                image_reshape, training=False)

            # Validate the checkpoint directory BEFORE opening a session so a
            # bad path does not leave an open session behind.
            model_dict = os.path.dirname(model_path)
            ckpt = tf.train.get_checkpoint_state(model_dict)
            print(model_path)
            if not (ckpt and ckpt.model_checkpoint_path):
                # Explicit raise instead of `assert`: assertions are stripped
                # under `python -O`, which would silently skip this check.
                # AssertionError is kept so existing callers see the same type.
                raise AssertionError("the params dictionary is not valid")

            # Soft placement lets ops fall back to another device; allow_growth
            # makes TensorFlow claim GPU memory on demand.
            self.sess = tf.Session(
                graph=graph,
                config=tf.ConfigProto(
                    allow_soft_placement=True,
                    gpu_options=tf.GPUOptions(allow_growth=True)))
            saver = tf.train.Saver()
            print("restore models' param")
            saver.restore(self.sess, model_path)

    def predict(self, databatch):
        """Run the network on a single image.

        Args:
            databatch: Array-like whose ``shape`` unpacks to
                ``(height, width, channels)``; the graph reshapes it to
                ``[1, height, width, 3]``.

        Returns:
            A ``(cls_prob, bbox_pred)`` pair as produced by the session for
            the two network outputs.
        """
        height, width, _ = databatch.shape
        cls_prob, bbox_pred = self.sess.run(
            [self.cls_prob, self.bbox_pred],
            feed_dict={
                self.image_op: databatch,
                self.width_op: width,
                self.height_op: height,
            })
        return cls_prob, bbox_pred