Skip to content

Commit

Permalink
finish training
Browse files Browse the repository at this point in the history
  • Loading branch information
namdvt committed Jul 22, 2020
1 parent c2f274c commit 53d5bcf
Show file tree
Hide file tree
Showing 10 changed files with 164 additions and 36 deletions.
Empty file modified .gitignore
100644 → 100755
Empty file.
39 changes: 39 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# CRNN for captcha recognition
## General
This is a simple PyTorch implementation of an OCR system that uses a CNN + RNN with CTC loss for captcha recognition.
## Dataset
I used the CAPTCHA Images dataset, which was downloaded from https://www.kaggle.com/fournierp/captcha-version-2-images
## Files

```
.
├── data
│   └── CAPTCHA Images
│   ├── test
│   ├── train
│   └── val
├── dataset.py
├── model.py
├── output
│   ├── log.txt
│   ├── loss.png
│   └── weight.pth
├── predict.py
├── README.md
├── split_train_val_test.py
├── train.py
└── utils.py
```
### Training
```
python train.py
```
Training and validation loss:

![Image description](output/loss.png)
### Testing
```
python predict.py
```
Test-set accuracy: 0.897
Empty file modified data/CAPTCHA Images/.gitignore
100644 → 100755
Empty file.
13 changes: 2 additions & 11 deletions dataset.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
import os
import matplotlib.pyplot as plt

from utils import get_dict


def split_image(image):
output = torch.Tensor([])
Expand All @@ -25,7 +23,6 @@ def __init__(self, root, augment=False):
super(CaptchaImagesDataset, self).__init__()
self.root = root
self.augment = augment
_, self.char2int = get_dict()

self.image_list = []
for ext in ('*.png', '*.jpg'):
def __getitem__(self, index):
    """Return ``(image_tensor, label_text)`` for the sample at *index*.

    The ground-truth text is encoded in the file name (e.g. ``2b827.png``
    -> ``"2b827"``); conversion of the string to integer class labels is
    deferred to the CTC-loss path, so the raw string is returned here.
    """
    path = self.image_list[index]
    # os.path.basename is separator-agnostic (works with '\\' on Windows),
    # unlike splitting on '/'. split('.') keeps the text before the first dot.
    text = os.path.basename(path).split('.')[0]

    image = Image.open(path).convert('RGB')  # 3 channels, as the CNN expects
    image = F.to_tensor(image)
    image = split_image(image)

    return image, text


def get_loader(root, batch_size):
Expand Down
19 changes: 17 additions & 2 deletions model.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch
import torch.nn as nn
from torchvision.models import resnet18
import torch.nn.functional as F


class Conv2d(nn.Module):
Expand All @@ -18,9 +20,22 @@ def forward(self, x):


class CRNN(nn.Module):
def __init__(self):
def __init__(self, vocab_size):
super(CRNN, self).__init__()
self.conv = Conv2d(3,3,3)
resnet = resnet18(pretrained=True)
modules = list(resnet.children())[:-3]
self.resnet = nn.Sequential(*modules)
self.fc1 = nn.Linear(1024, 256)
self.fc2 = nn.Linear(256, vocab_size)
self.gru1 = nn.GRU(input_size=256, hidden_size=256)

def forward(self, x):
    """Run the CNN -> RNN -> classifier stack.

    Input: image batch, presumably (batch, 3, H, W) — TODO confirm against
    the dataset loader. Output: per-timestep class scores shaped
    (seq_len, batch, vocab_size), time-major as ``nn.CTCLoss`` expects.
    """
    x = self.resnet(x)                      # CNN feature map (B, C, H', W')
    # Use the width axis as the time/sequence axis for the RNN.
    x = x.permute(0, 3, 1, 2).contiguous()  # (B, W', C, H')
    x = x.view(x.shape[0], x.shape[1], -1)  # (B, W', C*H')
    # BUG FIX: F.dropout defaults to training=True, so dropout kept firing
    # even after model.eval(); passing self.training lets eval() disable it.
    x = F.dropout(self.fc1(x), p=0.5, training=self.training)
    output, _ = self.gru1(x)
    x = self.fc2(output)
    x = x.permute(1, 0, 2)                  # time-major (W', B, vocab)

    return x
Empty file modified output/.gitignore
100644 → 100755
Empty file.
40 changes: 40 additions & 0 deletions predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import string

import numpy as np
from PIL import Image
import torch
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt
from model import CRNN
import os
from tqdm import tqdm
import glob
from dataset import CaptchaImagesDataset
from utils import LabelConverter
from tqdm import tqdm


if __name__ == '__main__':
    # Evaluate the trained CRNN on the test split: greedy CTC decoding,
    # exact-match accuracy against the label encoded in each filename.
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
    label_converter = LabelConverter(char_set=string.ascii_lowercase + string.digits)
    vocab_size = label_converter.get_vocab_size()

    model = CRNN(vocab_size=vocab_size).to(device)
    model.load_state_dict(torch.load('output/weight.pth', map_location=device))
    model.eval()

    correct = 0
    image_list = glob.glob('data/CAPTCHA Images/test/*')
    for path in tqdm(image_list):
        # Label is the filename stem, e.g. "2b827.png" -> "2b827".
        # os.path.basename is portable across '/' and '\\' separators.
        ground_truth = os.path.basename(path).split('.')[0]
        image = Image.open(path).convert('RGB')
        image = F.to_tensor(image).unsqueeze(0).to(device)

        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            output = model(image)

        # Greedy (best-path) decoding: per-timestep argmax, then the
        # converter collapses repeats and blanks.
        encoded_text = output.squeeze().argmax(1)
        decoded_text = label_converter.decode(encoded_text)

        if ground_truth == decoded_text:
            correct += 1

    # Guard against an empty test folder (would raise ZeroDivisionError).
    accuracy = correct / len(image_list) if image_list else 0.0
    print('accuracy =', accuracy)
Empty file modified split_train_val_test.py
100644 → 100755
Empty file.
42 changes: 28 additions & 14 deletions train.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
import torch
import torch.optim as optim
from utils import write_log, write_figure
from utils import write_log, write_figure, LabelConverter
import numpy as np
from dataset import get_loader
from tqdm import tqdm
from model import CRNN
import math
import string
import torch.nn as nn


def fit(epoch, model, optimizer, criterion, device, data_loader, phase='training'):
def calculate_loss(inputs, texts, label_converter, device):
    """Compute the CTC loss between model outputs and ground-truth texts.

    ``inputs`` holds raw per-timestep class scores shaped
    (seq_len, batch, classes); the log-softmax over classes is taken here.
    ``label_converter.encode`` supplies the flat target tensor plus the
    per-sample target lengths. Index 0 is the CTC blank.
    """
    ctc = nn.CTCLoss(blank=0)

    log_probs = inputs.log_softmax(2)
    seq_len, batch, _ = log_probs.size()
    # Every sample in the batch shares the same output sequence length.
    prob_sizes = torch.full(size=(batch,), fill_value=seq_len, dtype=torch.int32)

    targets, target_sizes = label_converter.encode(texts)
    return ctc(log_probs, targets.to(device), prob_sizes.to(device), target_sizes.to(device))


def fit(epoch, model, optimizer, label_converter, device, data_loader, phase='training'):
if phase == 'training':
model.train()
else:
Expand All @@ -26,7 +39,7 @@ def fit(epoch, model, optimizer, criterion, device, data_loader, phase='training
with torch.no_grad():
outputs = model(images)

loss = criterion(outputs, labels)
loss = calculate_loss(outputs, labels, label_converter, device)
running_loss += loss.item()

if phase == 'training':
Expand All @@ -40,39 +53,40 @@ def fit(epoch, model, optimizer, criterion, device, data_loader, phase='training

def train():
print('start training ...........')
batch_size = 2
num_epochs = 200
learning_rate = 0.01
batch_size = 16
num_epochs = 50
learning_rate = 0.1

label_converter = LabelConverter(char_set=string.ascii_lowercase + string.digits)
vocab_size = label_converter.get_vocab_size()

device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
model = CRNN()
model = CRNN(vocab_size=vocab_size).to(device)
# model.load_state_dict(torch.load('output/weight.pth', map_location=device))

train_loader, val_loader = get_loader('data/CAPTCHA Images/', batch_size=batch_size)

optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, nesterov=True)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
criterion = None
# scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 10, 2)

train_losses, val_losses = [], []
for epoch in range(num_epochs):
train_epoch_loss = fit(epoch, model, optimizer, criterion, device, train_loader, phase='training')
val_epoch_loss = fit(epoch, model, optimizer, criterion, device, val_loader, phase='validation')
train_epoch_loss = fit(epoch, model, optimizer, label_converter, device, train_loader, phase='training')
val_epoch_loss = fit(epoch, model, optimizer, label_converter, device, val_loader, phase='validation')
print('-----------------------------------------')

if epoch == 0 or val_epoch_loss <= np.min(val_losses):
torch.save(model.state_dict(), 'output/weight.pth')

# if epoch == 0 or train_epoch_loss <= np.min(train_losses):
# torch.save(model.state_dict(), 'output/weight.pth')

train_losses.append(train_epoch_loss)
val_losses.append(val_epoch_loss)

write_figure('output', train_losses, val_losses)
write_log('output', epoch, train_epoch_loss, val_epoch_loss)

scheduler.step(val_epoch_loss)
# scheduler.step(epoch)


if __name__ == "__main__":
Expand Down
47 changes: 38 additions & 9 deletions utils.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,18 +1,47 @@
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
import collections


def get_dict():
NUMBER = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
ALPHABET = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u',
'v', 'w', 'x', 'y', 'z']
ALL_CHAR_SET = NUMBER + ALPHABET
class LabelConverter:
    """Bidirectional mapping between characters and integer class indices
    for CTC training and greedy decoding.

    Index 0 is reserved for the CTC blank symbol, rendered as ``'-'``;
    the remaining indices cover the sorted, de-duplicated character set.
    """

    def __init__(self, char_set):
        # '-' (blank) must come first so that blank == index 0 everywhere.
        char = ['-'] + sorted(set(''.join(char_set)))
        self.vocab_size = len(char)
        self.int2char = dict(enumerate(char))
        self.char2int = {char: ind for ind, char in self.int2char.items()}

    def get_vocab_size(self):
        """Return the number of classes, including the blank."""
        return self.vocab_size

    def encode(self, texts):
        """Encode a batch of strings for ``nn.CTCLoss``.

        Returns ``(targets, target_lengths)``: a flat 1-D tensor of class
        indices for all texts concatenated, and one length per text.
        Raises ``KeyError`` on a character outside the vocabulary
        (fail fast instead of silently emitting ``None``).
        """
        text_length = torch.tensor([len(t) for t in texts])
        encoded_texts = torch.tensor([self.char2int[c] for t in texts for c in t.lower()])
        return encoded_texts, text_length

    def decode(self, encoded_text):
        """Greedy CTC decode of a 1-D tensor of per-timestep class indices:
        map indices to characters, then drop blanks and collapse repeats."""
        text = []
        for i in encoded_text:
            text.append(self.int2char.get(i.item()))

        decoded_text = ''
        for i, t in enumerate(text):
            if t == '-':           # blank
                continue
            if i > 0 and t == text[i - 1]:   # repeated symbol -> one char
                continue
            decoded_text = decoded_text + t

        return decoded_text


def write_figure(location, train_losses, val_losses):
Expand Down

0 comments on commit 53d5bcf

Please sign in to comment.