Skip to content

Commit

Permalink
attention ocr
Browse files Browse the repository at this point in the history
  • Loading branch information
zhang0jhon committed Nov 6, 2019
0 parents commit 31e81c0
Show file tree
Hide file tree
Showing 26 changed files with 7,720 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
/checkpoint/
/pretrain/inception_v4.ckpt
/icdar_datasets.npy
82 changes: 82 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# AttentionOCR for Arbitrary-Shaped Scene Text Recognition

## Introduction

This is the **ranked No.1** tensorflow based scene text spotting algorithm on [__ICDAR2019 Robust Reading Challenge on Arbitrary-Shaped Text__](https://rrc.cvc.uab.es/?ch=14) (Latin Only, Latin and Chinese), futhermore, the algorithm is also adopted in [__ICDAR2019 Robust Reading Challenge on Large-scale Street View Text with Partial Labeling__](https://rrc.cvc.uab.es/?ch=16) and [__ICDAR2019 Robust Reading Challenge on Reading Chinese Text on Signboard__](https://rrc.cvc.uab.es/?ch=12).
Scene text detection algorithm is modified from [__Tensorpack FasterRCNN__](https://github.com/tensorpack/tensorpack/tree/master/examples/FasterRCNN), and we only open source code in this repository for scene text recognition.

Note that our text recognition algorithm not only recognize Latin and Non-Latin characters, but also support horizontal and vertical text recognition in one model. It is convenient for multi-lingual arbitrary-shaped text recognition.

## Dependencies

```
python 3
tensorflow-gpu 1.14
tensorpack 0.9.8
```

## Usage

<!-- It is recommended to get familiar the relevant papers listed below:
+ [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473)
+ [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044) -->

First download and extract multiple text datasets in base text dir, please refer to dataset.py for dataset preprocess and multiple datasets.

### Multiple Datasets

```
$(base_dir)/lsvt
$(base_dir)/art
$(base_dir)/rects
$(base_dir)/icdar2017rctw
```

### Train

You can modify your gpu lists in config.py for specified gpus and then run:
```
python train.py
```
Use ICDAR2019-LSVT, ICDAR2019-ArT, ICDAR2019-ReCTS for default training, you can change it with your own training data.

### Evaluation

```
python eval.py --checkpoint_path=$(Your model path)
```

Use ICDAR2017RCTW for default evaluation with Normalized Edit Distance metric(1-N.E.D specifically), you can change it with your own evaluation data.

### Export

Export checkpoint to tensorflow pb model for inference.

```
python export.py --pb_path=$(Your tensorflow pb model save path) --checkpoint_path=$(Your trained model path)
```

### Test

Load tensorflow pb model for text recognition.
```
python test.py --pb_path=$(Your tensorflow pb model save path) --img_folder=$(Your test img folder)
```
Default use ICDAR2019-ArT for test, you can change it with your own test data.

## Visualization

Scene text detection and recognition result:

![](imgs/viz.png)

Scene text recognition attention maps:

![](imgs/attention_maps_gt_1.jpg)
![](imgs/attention_maps_gt_8454.jpg)
![](imgs/attention_maps_gt_8459.jpg)
![](imgs/attention_maps_gt_8473.jpg)
![](imgs/attention_maps_gt_8601.jpg)
![](imgs/attention_maps_gt_8622.jpg)
![](imgs/attention_maps_gt_918.jpg)
![](imgs/attention_maps_gt_94.jpg)
164 changes: 164 additions & 0 deletions common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# -*- coding: utf-8 -*-
# File: common.py

import numpy as np
import cv2

from tensorpack.dataflow import RNGDataFlow
from tensorpack.dataflow.imgaug import ImageAugmentor, ResizeTransform


class DataFromListOfDict(RNGDataFlow):
def __init__(self, lst, keys, shuffle=False):
self._lst = lst
self._keys = keys
self._shuffle = shuffle
self._size = len(lst)

def __len__(self):
return self._size

def __iter__(self):
if self._shuffle:
self.rng.shuffle(self._lst)
for dic in self._lst:
dp = [dic[k] for k in self._keys]
yield dp


class CustomResize(ImageAugmentor):
"""
Try resizing the shortest edge to a certain number
while avoiding the longest edge to exceed max_size.
"""

def __init__(self, short_edge_length, max_size, interp=cv2.INTER_LINEAR):
"""
Args:
short_edge_length ([int, int]): a [min, max] interval from which to sample the
shortest edge length.
max_size (int): maximum allowed longest edge length.
"""
super(CustomResize, self).__init__()
if isinstance(short_edge_length, int):
short_edge_length = (short_edge_length, short_edge_length)
self._init(locals())

def get_transform(self, img):
h, w = img.shape[:2]
size = self.rng.randint(
self.short_edge_length[0], self.short_edge_length[1] + 1)
scale = size * 1.0 / min(h, w)
if h < w:
newh, neww = size, scale * w
else:
newh, neww = scale * h, size
if max(newh, neww) > self.max_size:
scale = self.max_size * 1.0 / max(newh, neww)
newh = newh * scale
neww = neww * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)
return ResizeTransform(h, w, newh, neww, self.interp)


def box_to_point8(boxes):
"""
Args:
boxes: nx4
Returns:
(nx4)x2
"""
b = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]]
b = b.reshape((-1, 2))
return b


def point8_to_box(points):
"""
Args:
points: (nx4)x2
Returns:
nx4 boxes (x1y1x2y2)
"""
p = points.reshape((-1, 4, 2))
minxy = p.min(axis=1) # nx2
maxxy = p.max(axis=1) # nx2
return np.concatenate((minxy, maxxy), axis=1)


def polygons_to_mask(polys, height, width):
"""
Convert polygons to binary masks.
Args:
polys: a list of nx2 float array. Each array contains many (x, y) coordinates.
Returns:
a binary matrix of (height, width)
"""
polys = [p.flatten().tolist() for p in polys]
assert len(polys) > 0, "Polygons are empty!"

import pycocotools.mask as cocomask
rles = cocomask.frPyObjects(polys, height, width)
rle = cocomask.merge(rles)
return cocomask.decode(rle)


def clip_boxes(boxes, shape):
"""
Args:
boxes: (...)x4, float
shape: h, w
"""
orig_shape = boxes.shape
boxes = boxes.reshape([-1, 4])
h, w = shape
boxes[:, [0, 1]] = np.maximum(boxes[:, [0, 1]], 0)
boxes[:, 2] = np.minimum(boxes[:, 2], w)
boxes[:, 3] = np.minimum(boxes[:, 3], h)
return boxes.reshape(orig_shape)


def filter_boxes_inside_shape(boxes, shape):
"""
Args:
boxes: (nx4), float
shape: (h, w)
Returns:
indices: (k, )
selection: (kx4)
"""
assert boxes.ndim == 2, boxes.shape
assert len(shape) == 2, shape
h, w = shape
indices = np.where(
(boxes[:, 0] >= 0) &
(boxes[:, 1] >= 0) &
(boxes[:, 2] <= w) &
(boxes[:, 3] <= h))[0]
return indices, boxes[indices, :]


try:
import pycocotools.mask as cocomask

# Much faster than utils/np_box_ops
def np_iou(A, B):
def to_xywh(box):
box = box.copy()
box[:, 2] -= box[:, 0]
box[:, 3] -= box[:, 1]
return box

ret = cocomask.iou(
to_xywh(A), to_xywh(B),
np.zeros((len(B),), dtype=np.bool))
# can accelerate even more, if using float32
return ret.astype('float32')

except ImportError:
from utils.np_box_ops import iou as np_iou # noqa
91 changes: 91 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-

import os
from parse_dict import get_dict

# base dir for multiple text datasets
base_dir = '/opt/data/nfs/zhangjinjin/data/text/'

# font path for visualization
font_path = './fonts/cn/SourceHanSans-Normal.ttf'

# 'ocr' for inception model with padding image.
# 'ocr_with_normalized_bbox' for inception model with cropped text region for attention lstm.
model_name = 'ocr' # 'ocr_with_normalized_bbox'

# path for tensorboard summary and checkpoint path
summary_path = './checkpoint'

# tensorflow model name scope
name_scope = 'InceptionV4'

# path for numpy dict with processed image paths and labels used in dataset.py
dataset_name = 'icdar_datasets.npy'
# pb_path = './checkpoint/text_recognition_5435.pb'

# restore training parameters
restore_path = ''
starting_epoch = 1

# checkpoint_path = './checkpoint/model-10000'
# imagenet pretrain model path
pretrain_path = './pretrain/inception_v4.ckpt'

# label dict for text recognition
label_dict = get_dict()
reverse_label_dict = dict((v,k) for k,v in label_dict.items())

# gpu lists
gpus = [6, 7, 8, 9]

num_gpus = len(gpus)
num_classes = len(label_dict)

# max sequence length without EOS
seq_len = 32

# embedding size
wemb_size = 256

# lstm size
lstm_size = 512

# minimum cropped image size for data augment
crop_min_size = 224

# input image size
image_size = 256

# max random image offset for data augment
offset = 16

# CNN endpoint stride
stride = 8

# resize parameters for data augment
TRAIN_SHORT_EDGE_SIZE = 8
MAX_SIZE = image_size - 32

# training batch size
batch_size = 10 #12

# steps per training epoch in tensorpack
steps_per_epoch = 500

# max epochs
num_epochs = 1000

# model weight decay factor
weight_decay = 1e-5

# base learning rate
learning_rate = 1e-4 * num_gpus

# minimun learning rate for cosine decay learning rate
min_lr = learning_rate / 100

# warm up steps
warmup_steps = 1000

# thread for multi-thread data loading
num_threads = 16
Loading

0 comments on commit 31e81c0

Please sign in to comment.