-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
132 lines (99 loc) · 4.81 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import logging
import os
import random
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from data.dataset import LPRDataSet
from data.ccpd import load_dataset
from model.lprnet import LPRNet, CHARS
from utils.general import decode, sparse_tuple_for_ctc, set_logging
# Module-level logger for this script, followed by the project's logging
# setup helper imported from utils.general.
logger = logging.getLogger(name=__name__)
set_logging()
def test(model, data_loader, dataset, device, ctc_loss, lpr_max_len, float_test=False):
    """Evaluate ``model`` and return its mean CTC loss and exact-match accuracy.

    A sample counts as correct only when its decoded prediction equals the
    ground-truth label sequence element for element.

    Args:
        model: network to evaluate. When running on a non-CPU device and
            ``float_test`` is False it is switched to half precision for the
            run; ``model.float()`` is always called afterwards, even if an
            exception is raised part-way.
        data_loader: sized iterable yielding ``(imgs, labels, lengths)``
            batches. ``labels`` holds the labels of the whole batch
            concatenated; ``lengths`` gives each sample's label length.
        dataset: dataset behind ``data_loader``; only ``len(dataset)`` is used.
        device: torch device to run on.
        ctc_loss: CTC loss callable, invoked as
            ``ctc_loss(log_probs, targets, input_lengths=..., target_lengths=...)``.
        lpr_max_len: maximum plate length, forwarded to ``sparse_tuple_for_ctc``.
        float_test: if True, never use half precision, even on a GPU.

    Returns:
        tuple ``(mloss, acc)``: the running mean of the per-batch loss, and the
        number of exactly matched samples divided by ``len(dataset)``
        (0.0 for an empty dataset).
    """
    correct_count = 0
    process_count = 0
    mloss = 0.0
    # Half precision only on an accelerator and only if float testing was not requested.
    half = not float_test and (device.type != 'cpu')
    if half:
        model.half()
    try:
        pbar = tqdm(enumerate(data_loader), total=len(data_loader), desc='Test')
        for i, (imgs, labels, lengths) in pbar:
            imgs = imgs.to(device, non_blocking=True)
            imgs = imgs.half() if half else imgs.float()
            # Labels are class indices: casting them to half gains nothing, and the
            # loss below is computed on float32 targets anyway.
            labels = labels.to(device, non_blocking=True).float()
            # Prepare the length arguments for the CTC loss.
            input_lengths, target_lengths = sparse_tuple_for_ctc(lpr_max_len, lengths)
            with torch.no_grad():
                x = model(imgs)
                # [batch_size, chars, width] -> [width, batch_size, chars]
                y = x.permute(2, 0, 1).log_softmax(2)
                loss = ctc_loss(y.float(), labels, input_lengths=input_lengths,
                                target_lengths=target_lengths)
            _, pred_labels = decode(x.cpu().detach().numpy())
            # Copy the labels to the host once per batch instead of once per sample.
            labels_np = labels.cpu().numpy()
            start = 0
            for j, length in enumerate(lengths):
                length = int(length)
                label = labels_np[start:start + length]
                start += length
                if np.array_equal(np.array(pred_labels[j]), label):
                    correct_count += 1
            # Update the running mean of the per-batch loss.
            mloss = (mloss * i + loss.item()) / (i + 1)
            process_count += len(lengths)
            acc = float(correct_count) / float(process_count) if process_count else 0.0
            pbar.set_description('Test mloss: %.5f, macc: %.5f' % (mloss, acc))
    finally:
        # Restore float weights even if evaluation failed part-way.
        model.float()
    # Final accuracy is taken over the whole dataset; guard against an empty one.
    acc = float(correct_count) / float(len(dataset)) if len(dataset) else 0.0
    return mloss, acc
def main(opts):
    """Build LPRNet, load the checkpoint at ``opts.weights`` and evaluate it.

    Every setting is read from ``opts``, never from the module-level ``args``,
    so the function also works when called with a namespace built elsewhere.

    Args:
        opts: namespace providing ``cpu``, ``dropout_rate``, ``weights``,
            ``source_dir``, ``cache_dir``, ``img_size``, ``batch_size``,
            ``workers``, ``lpr_max_len`` and ``float_test``.

    Returns:
        tuple ``(mloss, acc)`` as returned by ``test``.
    """
    # Select the device: cuda:0 unless --cpu is given or CUDA is unavailable.
    device = torch.device("cuda:0" if (not opts.cpu and torch.cuda.is_available()) else "cpu")
    cuda = device.type != 'cpu'
    logger.info('Use device %s.' % device)

    # Build the network.
    model = LPRNet(class_num=len(CHARS), dropout_rate=opts.dropout_rate).to(device)
    logger.info("Build network is successful.")

    # Loss function: the last entry of CHARS is used as the CTC blank index.
    ctc_loss = torch.nn.CTCLoss(blank=len(CHARS) - 1, reduction='mean')  # reduction: 'none' | 'mean' | 'sum'

    # Load weights. NOTE: torch.load unpickles the file - only load trusted checkpoints.
    ckpt = torch.load(opts.weights, map_location=device)
    model.load_state_dict(ckpt["model"])
    # Release the checkpoint memory once the weights are copied into the model.
    del ckpt
    logger.info('Load weights completed.')

    # Load data; only the second dataset returned by load_dataset is used
    # (presumably the test split - confirm in data.ccpd).
    _, test_dataset = load_dataset(opts.source_dir, opts.cache_dir, opts.img_size)
    test_loader = DataLoader(test_dataset, batch_size=opts.batch_size, shuffle=False,
                             num_workers=opts.workers, pin_memory=cuda,
                             collate_fn=test_dataset.collate_fn)
    logger.info('Image sizes %d test' % (len(test_dataset)))
    logger.info('Using %d dataloader workers' % opts.workers)

    model.eval()
    mloss, acc = test(model, test_loader, test_dataset, device, ctc_loss,
                      opts.lpr_max_len, opts.float_test)
    logger.info('Test mloss: %.5f, acc: %.5f' % (mloss, acc))
    return mloss, acc
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='STNet & LPRNet Testing')
    parser.add_argument('--source-dir', type=str, default="/home/hejingyi/dataset/licence_plate/", help='train images source dir.')
    parser.add_argument('--weights', type=str, default="runs/exp6/weights/best.pt", help='initial weights path.')
    # Without nargs/type a user-supplied size arrived as a single string;
    # parse it as two ints (e.g. --img-size 96 48).
    parser.add_argument('--img-size', default=(96, 48), nargs=2, type=int, help='the image size')
    parser.add_argument('--cpu', action='store_true', help='force use cpu.')
    parser.add_argument('--batch-size', type=int, default=1024, help='train batch size.')
    # Explicit types so user-supplied values are numbers rather than strings.
    parser.add_argument('--dropout_rate', type=float, default=0.5, help='dropout rate.')
    parser.add_argument('--lpr-max-len', type=int, default=18, help='license plate number max length.')
    parser.add_argument('--float-test', action='store_true', help='use float model run test.')
    parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers.')
    parser.add_argument('--worker-dir', type=str, default='runs', help='worker dir.')
    args = parser.parse_args()
    # nargs=2 yields a list; keep the same tuple type as the default.
    args.img_size = tuple(args.img_size)
    del parser
    # Print the arguments.
    logger.info("args: %s" % args)
    # Derived arguments (not printed).
    args.cache_dir = os.path.join(args.worker_dir, 'cache')
    # Initialization after argument processing.
    os.makedirs(args.cache_dir, exist_ok=True)
    # Start testing.
    main(args)