-
Notifications
You must be signed in to change notification settings - Fork 69
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
11 changed files
with
459 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
[settings] | ||
known_third_party = PIL,cog,cv2,detectron2,grade,graphviz,h5py,matplotlib,mmcv,mmdet,numpy,packaging,panopticapi,pycocotools,setuptools,six,terminaltables,torch,torchvision,tqdm,xmltodict | ||
known_third_party = PIL,cog,cv2,dataset,detectron2,evaluator,grade,graphviz,h5py,matplotlib,mmcv,mmdet,numpy,packaging,panopticapi,pycocotools,setuptools,six,terminaltables,torch,torchvision,tqdm,trainer,xmltodict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# CE7454 Assignment 1: PSG Classification | ||
|
||
## Get Started | ||
|
||
We specially build the tiny codebase here (in this directory) to help our students to quickly get started. | ||
|
||
First off, let's clone or fork the codebase and enter in the `ce7454` directory. Don't forget to star the repo if you find the assignment is interesting and instructive. | ||
Then we download the data [here](https://entuedu-my.sharepoint.com/:f:/g/personal/jingkang001_e_ntu_edu_sg/El0TuNPJWyVJqw6agAzd7l0BU_Tr9gJEgUzLbZggaAKyHg?e=cFPqNF) and unzip the data at the correct place. Eventually, your `ce7454` folder should looks like this: | ||
``` | ||
ce7454 | ||
├── checkpoints | ||
├── data | ||
│ ├── coco | ||
│ │ ├── panoptic_train2017 | ||
│ │ ├── panoptic_val2017 | ||
│ │ ├── train2017 | ||
│ │ └── val2017 | ||
│ └── psg | ||
│ ├── psg_cls_basic.json | ||
│ └── psg_val_advanced.json | ||
├── results | ||
├── dataset.py | ||
├── evaluator.py | ||
├── ... | ||
``` | ||
|
||
|
||
We provide 4500 training data, 500 validation data, and 500 test data. | ||
Notice that there might not be exactly 4500 training images (so are val/test images) as some images are annotated twice, and we consider one annotation as one sample. | ||
|
||
Then, we need to setup the environment. We use `conda` to manage our dependencies. The code is heavily dependent on PyTorch. | ||
|
||
```bash | ||
conda install python=3.7 pytorch=1.7.0 torchvision=0.8.0 torchaudio==0.7.0 cudatoolkit=10.1 | ||
pip install tqdm | ||
``` | ||
|
||
Finally, make sure your working directory is `ce7454`, and let's train the model! | ||
```bash | ||
python main.py | ||
``` | ||
|
||
You can explore the project by reading from the `main.py` and dive in. Good Luck!! |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import io | ||
import json | ||
import logging | ||
import os | ||
|
||
import torch | ||
import torchvision.transforms as trn | ||
from PIL import Image, ImageFile | ||
from torch.utils.data import Dataset | ||
|
||
# to fix "OSError: image file is truncated" | ||
ImageFile.LOAD_TRUNCATED_IMAGES = True | ||
|
||
|
||
class Convert: | ||
def __init__(self, mode='RGB'): | ||
self.mode = mode | ||
|
||
def __call__(self, image): | ||
return image.convert(self.mode) | ||
|
||
|
||
def get_transforms(stage: str): | ||
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] | ||
if stage == 'train': | ||
return trn.Compose([ | ||
Convert('RGB'), | ||
trn.Resize((1333, 800)), | ||
trn.RandomHorizontalFlip(), | ||
trn.RandomCrop((1333, 800), padding=4), | ||
trn.ToTensor(), | ||
trn.Normalize(mean, std), | ||
]) | ||
|
||
elif stage in ['val', 'test']: | ||
return trn.Compose([ | ||
Convert('RGB'), | ||
trn.Resize((1333, 800)), | ||
trn.ToTensor(), | ||
trn.Normalize(mean, std), | ||
]) | ||
|
||
|
||
class PSGClsDataset(Dataset): | ||
def __init__( | ||
self, | ||
stage, | ||
root='./data/coco/', | ||
num_classes=56, | ||
): | ||
super(PSGClsDataset, self).__init__() | ||
with open('./data/psg/psg_cls_basic.json') as f: | ||
dataset = json.load(f) | ||
self.imglist = [ | ||
d for d in dataset['data'] | ||
if d['image_id'] in dataset[f'{stage}_image_ids'] | ||
] | ||
self.root = root | ||
self.transform_image = get_transforms(stage) | ||
self.num_classes = num_classes | ||
|
||
def __len__(self): | ||
return len(self.imglist) | ||
|
||
def __getitem__(self, index): | ||
sample = self.imglist[index] | ||
path = os.path.join(self.root, sample['file_name']) | ||
try: | ||
with open(path, 'rb') as f: | ||
content = f.read() | ||
filebytes = content | ||
buff = io.BytesIO(filebytes) | ||
image = Image.open(buff).convert('RGB') | ||
sample['data'] = self.transform_image(image) | ||
except Exception as e: | ||
logging.error('Error, cannot read [{}]'.format(path)) | ||
raise e | ||
# Generate Soft Label | ||
soft_label = torch.Tensor(self.num_classes) | ||
soft_label.fill_(0) | ||
soft_label[sample['relations']] = 1 | ||
sample['soft_label'] = soft_label | ||
del sample['relations'] | ||
return sample |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.utils.data import DataLoader | ||
|
||
|
||
class Evaluator: | ||
def __init__( | ||
self, | ||
net: nn.Module, | ||
k: int, | ||
): | ||
self.net = net | ||
self.k = k | ||
|
||
def eval_recall( | ||
self, | ||
data_loader: DataLoader, | ||
): | ||
self.net.eval() | ||
loss_avg = 0.0 | ||
pred_list, gt_list = [], [] | ||
with torch.no_grad(): | ||
for batch in data_loader: | ||
data = batch['data'].cuda() | ||
logits = self.net(data) | ||
prob = torch.sigmoid(logits) | ||
target = batch['soft_label'].cuda() | ||
loss = F.binary_cross_entropy(prob, target, reduction='sum') | ||
loss_avg += float(loss.data) | ||
# gather prediction and gt | ||
pred = torch.topk(prob.data, self.k)[1] | ||
pred = pred.cpu().detach().tolist() | ||
pred_list.extend(pred) | ||
for soft_label in batch['soft_label']: | ||
gt_label = (soft_label == 1).nonzero(as_tuple=True)[0]\ | ||
.cpu().detach().tolist() | ||
gt_list.append(gt_label) | ||
|
||
# compute mean recall | ||
score_list = np.zeros([56, 2], dtype=int) | ||
for gt, pred in zip(gt_list, pred_list): | ||
for gt_id in gt: | ||
# pos 0 for counting all existing relations | ||
score_list[gt_id][0] += 1 | ||
if gt_id in pred: | ||
# pos 1 for counting relations that is recalled | ||
score_list[gt_id][1] += 1 | ||
score_list = score_list[6:] | ||
# to avoid nan | ||
score_list[:, 0][score_list[:, 0] == 0] = 1 | ||
meanrecall = np.mean(score_list[:, 1] / score_list[:, 0]) | ||
|
||
metrics = {} | ||
metrics['test_loss'] = loss_avg / len(data_loader) | ||
metrics['mean_recall'] = meanrecall | ||
|
||
return metrics | ||
|
||
def submit( | ||
self, | ||
data_loader: DataLoader, | ||
): | ||
self.net.eval() | ||
|
||
pred_list = [] | ||
with torch.no_grad(): | ||
for batch in data_loader: | ||
data = batch['data'].cuda() | ||
logits = self.net(data) | ||
prob = torch.sigmoid(logits) | ||
pred = torch.topk(prob.data, self.k)[1] | ||
pred = pred.cpu().detach().tolist() | ||
pred_list.extend(pred) | ||
return pred_list |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import argparse | ||
import os | ||
|
||
import numpy as np | ||
|
||
|
||
def compute_recall(gt_list, pred_list): | ||
score_list = np.zeros([56, 2], dtype=int) | ||
for gt, pred in zip(gt_list, pred_list): | ||
for gt_id in gt: | ||
# pos 0 for counting all existing relations | ||
score_list[gt_id][0] += 1 | ||
if gt_id in pred: | ||
# pos 1 for counting relations that is recalled | ||
score_list[gt_id][1] += 1 | ||
score_list = score_list[6:] | ||
# to avoid nan, but test set does not have empty predict | ||
# score_list[:,0][score_list[:,0] == 0] = 1 | ||
meanrecall = np.mean(score_list[:, 1] / score_list[:, 0]) | ||
|
||
return meanrecall | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(description='MMDet eval a model') | ||
parser.add_argument('input_path', help='input file path') | ||
parser.add_argument('output_path', help='output file path') | ||
|
||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
submit_dir = os.path.join(args.input_path, 'res') | ||
groundtruth_dir = os.path.join(args.input_path, 'ref') | ||
|
||
gt_list = [] | ||
with open(os.path.join(groundtruth_dir, 'psg_cls_gt.txt'), 'r') as reader: | ||
for line in reader.readlines(): | ||
gt_list.append( | ||
[int(label) for label in line.strip('/n').split(' ')]) | ||
|
||
pred_list = [] | ||
with open(os.path.join(submit_dir, 'result.txt'), 'r') as reader: | ||
for line in reader.readlines(): | ||
pred_list.append( | ||
[int(label) for label in line.strip('/n').split(' ')]) | ||
|
||
assert np.array(pred_list).shape == ( | ||
500, 3), 'make sure the submitted file is 500 x 3' | ||
result = compute_recall(gt_list, pred_list) | ||
output_filename = os.path.join(args.output_path, 'scores.txt') | ||
|
||
with open(output_filename, 'w') as f3: | ||
f3.write('score: {}\n'.format(result)) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Oops, something went wrong.