-
Notifications
You must be signed in to change notification settings - Fork 2
/
demo.py
109 lines (90 loc) · 4.68 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from __future__ import print_function
import tensorflow as tf
import numpy as np
import os, sys, cv2
import glob
import shutil
sys.path.append(os.getcwd())
from ctpn.lib.networks.factory import get_network
from ctpn.lib.fast_rcnn.config import cfg,cfg_from_file
from ctpn.lib.fast_rcnn.test import test_ctpn
from ctpn.lib.fast_rcnn.nms_wrapper import nms
from ctpn.lib.utils.timer import Timer
from ctpn.lib.text_connector.detectors import TextDetector
from ctpn.lib.text_connector.text_connect_cfg import Config as TextLineCfg
class CTPN(object):
    """CTPN text detector loaded into its own TensorFlow graph and session.

    Boxes handled by this class are indexable sequences of nine values:
    four (x, y) corner pairs at indices 0-7 followed by a score at index 8.
    The corners are connected as 0-1, 0-2, 3-1 and 2-3 when drawn.
    """

    @staticmethod
    def _resize_factor(height, width, scale, max_scale=None):
        """Return the factor that brings the shorter side to ``scale``.

        If ``max_scale`` is given and the longer side would exceed it after
        scaling, the factor is reduced so the longer side equals ``max_scale``.
        """
        shorter, longer = min(height, width), max(height, width)
        factor = float(scale) / shorter
        if max_scale is not None and factor * longer > max_scale:
            factor = float(max_scale) / longer
        return factor

    @staticmethod
    def _box_color(score):
        """Return the drawing colour (a 3-tuple passed to cv2.line) for a score."""
        if score >= 0.9:
            return (0, 255, 0)
        if score >= 0.8:
            return (255, 0, 0)
        # Previously no colour was assigned below 0.8, which raised a
        # NameError on the first such box or silently reused the colour of
        # the previous box. Give low scores a colour of their own instead.
        return (0, 0, 255)

    @staticmethod
    def _is_degenerate(box, min_side=5):
        """Return True when a box edge at corner 0 is shorter than ``min_side``.

        The edges checked are corner 0 -> corner 1 and corner 0 -> corner 2,
        i.e. two of the edges that ``draw_boxes_ex`` draws.
        """
        # The previous check subtracted box[0] - box[1], i.e. the x and y of
        # the SAME corner, so boxes were dropped based on where they sat in
        # the image rather than on their size. Compare corner points instead.
        corners = np.asarray(box[:8], dtype=float).reshape(4, 2)
        first_edge = np.linalg.norm(corners[0] - corners[1])
        second_edge = np.linalg.norm(corners[2] - corners[0])
        return bool(first_edge < min_side or second_edge < min_side)

    def __init__(self):
        """Read the config, build the network and restore its checkpoint.

        Raises:
            RuntimeError: if no checkpoint is found under
                ``cfg.TEST.checkpoints_path`` or restoring it fails.
        """
        cfg_from_file(os.path.join(os.getcwd(), 'ctpn', 'text.yml'))
        # init session
        self.config = tf.ConfigProto(device_count={'GPU': 1, 'CPU': 4},
                                     log_device_placement=False,
                                     allow_soft_placement=True)
        g_ctpn = tf.Graph()
        self.sess = tf.Session(graph=g_ctpn, config=self.config)
        with g_ctpn.as_default():
            # load network
            self.net = get_network("VGGnet_test")
            # load model
            print(('Loading network {:s}... '.format("VGGnet_test")), end=' ')
            saver = tf.train.Saver()
            ckpt = tf.train.get_checkpoint_state(cfg.TEST.checkpoints_path)
            # get_checkpoint_state yields None when nothing usable is found;
            # report that directly instead of failing on an attribute lookup.
            if ckpt is None or not ckpt.model_checkpoint_path:
                raise RuntimeError(
                    'Check your pretrained checkpoint in {}'.format(
                        cfg.TEST.checkpoints_path))
            print('Restoring from {}...'.format(ckpt.model_checkpoint_path),
                  end=' ')
            try:
                saver.restore(self.sess, ckpt.model_checkpoint_path)
            except Exception as err:
                # Raising a plain string (as before) is itself a TypeError;
                # raise a real exception that names the checkpoint.
                raise RuntimeError('Check your pretrained {}: {}'.format(
                    ckpt.model_checkpoint_path, err))
            print('done')
            # Two passes over a uniform grey dummy image, presumably to warm
            # the session up before real images are processed.
            im = 128 * np.ones((300, 300, 3), dtype=np.uint8)
            for _ in range(2):
                test_ctpn(self.sess, self.net, im)

    def resize_im(self, im, scale, max_scale=None):
        """Resize ``im`` so its shorter side is ``scale``, capped by ``max_scale``.

        Returns:
            A ``(resized_image, factor)`` tuple, where ``factor`` is the
            scaling applied to both axes.
        """
        f = self._resize_factor(im.shape[0], im.shape[1], scale, max_scale)
        resized = cv2.resize(im, None, None, fx=f, fy=f,
                             interpolation=cv2.INTER_LINEAR)
        return resized, f

    def draw_boxes_ex(self, img, image_name, boxes, scale):
        """Draw ``boxes`` on ``img`` and save the image and a coordinates file.

        For every non-degenerate box the outline is drawn on ``img`` in place,
        and one ``min_x,min_y,max_x,max_y`` line (coordinates divided by
        ``scale``) is written to ``data/results/res_<name>.txt``. The drawn
        image, resized by ``1/scale``, is written to ``data/results/``.
        """
        base_name = os.path.basename(image_name)
        txt_path = 'data/results/' + 'res_{}.txt'.format(base_name.split('.')[0])
        with open(txt_path, 'w') as f:
            for box in boxes:
                if self._is_degenerate(box):
                    continue
                color = self._box_color(box[8])
                p0, p1, p2, p3 = [(int(box[i]), int(box[i + 1]))
                                  for i in range(0, 8, 2)]
                for start, end in ((p0, p1), (p0, p2), (p3, p1), (p2, p3)):
                    cv2.line(img, start, end, color, 2)
                # Map the corners back to the coordinates of the unscaled
                # image before taking the axis-aligned bounding rectangle.
                xs = [int(box[i] / scale) for i in (0, 2, 4, 6)]
                ys = [int(box[i] / scale) for i in (1, 3, 5, 7)]
                bounds = (min(xs), min(ys), max(xs), max(ys))
                f.write(','.join(str(v) for v in bounds) + '\r\n')
        img = cv2.resize(img, None, None, fx=1.0 / scale, fy=1.0 / scale,
                         interpolation=cv2.INTER_LINEAR)
        cv2.imwrite(os.path.join("data/results", base_name), img)

    def get_text_box(self, img, image_name):
        """Detect text lines in ``img`` and write the results to disk.

        Returns:
            A ``(image, boxes)`` tuple: a copy of the resized image that was
            passed to ``draw_boxes_ex``, and the boxes returned by the text
            detector.
        """
        timer = Timer()
        timer.tic()
        im, scale = self.resize_im(img, scale=TextLineCfg.SCALE,
                                   max_scale=TextLineCfg.MAX_SCALE)
        scores, boxes = test_ctpn(self.sess, self.net, im)
        textdetector = TextDetector()
        boxes = textdetector.detect(boxes, scores[:, np.newaxis], im.shape[:2])
        self.draw_boxes_ex(im, image_name, boxes, scale)
        timer.toc()
        print(('Detection took {:.3f}s for '
               '{:d} object proposals').format(timer.total_time, boxes.shape[0]))
        return im.copy(), boxes
if __name__ == '__main__':
    # Start every run from an empty results directory.
    if os.path.exists("data/results/"):
        shutil.rmtree("data/results/")
    os.makedirs("data/results/")
    ctpnObj = CTPN()
    # Run the detector over every PNG and JPG in the demo data directory.
    im_names = glob.glob(os.path.join(cfg.DATA_DIR, 'demo', '*.png')) + \
               glob.glob(os.path.join(cfg.DATA_DIR, 'demo', '*.jpg'))
    for im_name in im_names:
        print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
        print(('Demo for {:s}'.format(im_name)))
        im = cv2.imread(im_name)
        if im is None:
            # cv2.imread signals an unreadable file by returning None;
            # skip it instead of crashing inside the detector.
            print('Skipping {:s}: could not be read as an image'.format(im_name))
            continue
        ctpnObj.get_text_box(im, im_name)