forked from 767342809/image_quality_assessment
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
executable file
·83 lines (59 loc) · 2.7 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import csv
import glob
import json
import argparse
from src.utils.utils import calc_mean_score, save_json
from src.handlers.model_builder import Nima
from src.handlers.data_generator import TestDataGenerator
def image_file_to_json(img_path):
    """Describe a single image file as a one-element sample list.

    Args:
        img_path: path of the image file.

    Returns:
        A ``(img_dir, samples)`` tuple: the directory part of ``img_path`` and
        ``[{'image_id': <file name without its extension>}]``.
    """
    img_dir, file_name = os.path.split(img_path)
    # Strip only the final extension so names containing extra dots
    # (e.g. 'cat.v2.jpg' -> 'cat.v2') keep their full stem.
    img_id = os.path.splitext(file_name)[0]
    return img_dir, [{'image_id': img_id}]
def image_dir_to_json(img_dir, img_type='jpg'):
    """List the images of one type in a directory as sample dicts.

    Args:
        img_dir: directory to search (not recursive).
        img_type: file extension to match, without the leading dot.

    Returns:
        A list of ``{'image_id': <file name without its extension>}`` dicts,
        one per matching file, in the order returned by ``glob.glob``.
    """
    pattern = os.path.join(img_dir, '*.' + img_type)
    # Strip only the final extension so names containing extra dots
    # (e.g. 'cat.v2.jpg' -> 'cat.v2') keep their full stem.
    return [
        {'image_id': os.path.splitext(os.path.basename(img_path))[0]}
        for img_path in glob.glob(pattern)
    ]
def predict(model, data_generator):
    """Run ``model.predict_generator`` over ``data_generator`` and return its result.

    The call always uses 8 workers with multiprocessing enabled and
    ``verbose=1``.
    """
    run_options = {
        'workers': 8,
        'use_multiprocessing': True,
        'verbose': 1,
    }
    return model.predict_generator(data_generator, **run_options)
def main(base_model_name, weights_file, image_source, predictions_file, img_format='jpg'):
    """Score images with a ``Nima`` model and write the results ranked by mean score.

    Args:
        base_model_name: CNN base model name passed to ``Nima``.
        weights_file: path of the weights file loaded into the built model.
        image_source: path of a single image file, or of a directory of images.
        predictions_file: optional path; when not ``None`` the ranked samples
            are also written there with ``save_json``.
        img_format: image file extension (without the dot), used both to list
            a directory and by the data generator.

    Side effects:
        Prints the samples (unsorted, then sorted) and writes
        ``img_quality.csv`` in the current working directory.
    """
    # load samples
    if os.path.isfile(image_source):
        image_dir, samples = image_file_to_json(image_source)
    else:
        image_dir = image_source
        # Use the requested format for the listing too; it was previously
        # hard-coded to 'jpg' and so disagreed with the data generator.
        samples = image_dir_to_json(image_dir, img_type=img_format)

    # build model and load weights
    nima = Nima(base_model_name, weights=None)
    nima.build()
    nima.nima_model.load_weights(weights_file)

    # initialize data generator
    # NOTE(review): 64 and 10 are presumably the batch size and the number of
    # score classes -- confirm against TestDataGenerator's signature.
    data_generator = TestDataGenerator(samples, image_dir, 64, 10, nima.preprocessing_function(),
                                       img_format=img_format)

    # get predictions
    predictions = predict(nima.nima_model, data_generator)

    # calc mean scores and add to samples
    for i, sample in enumerate(samples):
        sample['mean_score_prediction'] = calc_mean_score(predictions[i])

    print(samples)

    # Highest mean score first.
    new_sample = sorted(samples, key=lambda x: x["mean_score_prediction"], reverse=True)
    print("new_sample: ", new_sample)

    # newline='' lets the csv module control line endings; without it,
    # blank rows appear between records on Windows.
    with open("img_quality.csv", "w", encoding="utf-8", newline="") as f:
        csv_writer = csv.writer(f)
        csv_writer.writerow(['id', 'mean'])
        for d in new_sample:
            csv_writer.writerow([d["image_id"], d["mean_score_prediction"]])

    if predictions_file is not None:
        save_json(new_sample, predictions_file)
if __name__ == '__main__':
    # Command-line entry point: parse the options and forward them to main()
    # as keyword arguments (option names map onto main's parameter names).
    cli = argparse.ArgumentParser()
    cli.add_argument('-b', '--base-model-name', required=True,
                     help='CNN base model name')
    cli.add_argument('-w', '--weights-file', required=True,
                     help='path of weights file')
    cli.add_argument('-is', '--image-source', required=True,
                     help='image directory or file')
    cli.add_argument('-pf', '--predictions-file', required=False, default=None,
                     help='file with predictions')

    options = vars(cli.parse_args())
    main(**options)