diff --git a/.gitignore b/.gitignore
index 6ef135b..ff3359d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -148,3 +148,8 @@ dmypy.json
tools/custom_tools
scripts/_depreciated
# configs/detseg
+
+ce7454.zip
+ce7454_tools/
+ce7454/checkpoints
+ce7454/results
diff --git a/.isort.cfg b/.isort.cfg
index c027277..fb68330 100644
--- a/.isort.cfg
+++ b/.isort.cfg
@@ -1,2 +1,2 @@
[settings]
-known_third_party = PIL,cog,cv2,detectron2,grade,graphviz,h5py,matplotlib,mmcv,mmdet,numpy,packaging,panopticapi,pycocotools,setuptools,six,terminaltables,torch,torchvision,tqdm,xmltodict
+known_third_party = PIL,cog,cv2,dataset,detectron2,evaluator,grade,graphviz,h5py,matplotlib,mmcv,mmdet,numpy,packaging,panopticapi,pycocotools,setuptools,six,terminaltables,torch,torchvision,tqdm,trainer,xmltodict
diff --git a/README.md b/README.md
index 8adbe9c..9b22b60 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@
-
+
@@ -59,11 +59,12 @@
---
## Updates
+- **Sep 4, 2022**: We introduce the PSG Classification Task for NTU CE7454 Coursework, as described [here](https://github.com/Jingkang50/OpenPSG/blob/main/ce7454).
- **Aug 21, 2022**: We provide guidance on PSG challenge registration [here](https://github.com/Jingkang50/OpenPSG/blob/main/psg_challenge.md).
- **Aug 12, 2022**: Replicate demo and Cloud API is added, try it [here](https://replicate.com/cjwbw/openpsg)!
- **Aug 10, 2022**: We launched [Hugging Face demo 🤗](https://huggingface.co/spaces/mmlab-ntu/OpenPSG). Try it with your scene!
- **Aug 5, 2022**: The PSG Challenge will be available on [International Algorithm Case Competition ](https://iacc.pazhoulab-huangpu.com/)! All the data will be available there then! Stay tuned!
-- **July 25, 2022**: :boom: We are preparing a PSG competition with [ECCV'22 SenseHuman Workshop](https://sense-human.github.io) and [International Algorithm Case Competition](https://iacc.pazhoulab-huangpu.com/), starting from Aug 6, with a prize pool of :money_mouth_face: **US$150K** :money_mouth_face:. Join us on our [Slack](https://join.slack.com/t/psgdataset/shared_invite/zt-1d14sdkw3-59pJdrp6gLAHuBObPL91qw) to stay updated!
+- **July 25, 2022**: :boom: We are preparing a PSG competition with [ECCV'22 SenseHuman Workshop](https://sense-human.github.io) and [International Algorithm Case Competition](https://iacc.pazhoulab-huangpu.com/), starting from Aug 6, with a prize pool of :money_mouth_face: **US$150K** :money_mouth_face:. Join us on our [Slack](https://join.slack.com/t/psgdataset/shared_invite/zt-1f8wkjfky-~uikum1YA1giLGZphFZdAQ) to stay updated!
- **July 25, 2022**: PSG paper is available on [arXiv](https://arxiv.org/abs/2207.11247).
- **July 3, 2022**: PSG is accepted by ECCV'22.
## What is PSG Task?
@@ -249,14 +250,14 @@ python tools/test.py \
## Model Zoo
-Method | Backbone | #Epoch | R/mR@20 | R/mR@50 | R/mR@100 | ckpt
---- | --- | --- | --- | --- |--- |--- |
-IMP | ResNet-50 | 12 | 16.5 / 6.52 | 18.2 / 7.05 | 18.6 / 7.23 | [link](https://entuedu-my.sharepoint.com/:f:/g/personal/jingkang001_e_ntu_edu_sg/EiTgJ9q2h3hDpyXSdu6BtlQBHAZNwNaYmcO7SElxhkIFXw?e=8fytHc) |
-MOTIFS | ResNet-50 | 12 | 20.0 / 9.10 | 21.7 / 9.57 | 22.0 / 9.69 | [link](https://entuedu-my.sharepoint.com/:f:/g/personal/jingkang001_e_ntu_edu_sg/Eh4hvXIspUFKpNa_75qwDoEBJTCIozTLzm49Ste6HaoPow?e=ZdAs6z) |
-VCTree | ResNet-50 | 12 | 20.6 / 9.70 | 22.1 / 10.2 | 22.5 / 10.2 | [link](https://entuedu-my.sharepoint.com/:f:/g/personal/jingkang001_e_ntu_edu_sg/EhKfi9kqAd9CnSoHztQIChABeBjBD3hF7DflrNCjlHfh9A?e=lWa1bd) |
-GPSNet | ResNet-50 | 12 | 17.8 / 7.03 | 19.6 / 7.49 | 20.1 / 7.67 | [link](https://entuedu-my.sharepoint.com/:f:/g/personal/jingkang001_e_ntu_edu_sg/EipIhZgVgx1LuK2RUmjRg2sB8JqxMIS5GnPDHeaYy5GF6A?e=5j53VF) |
-PSGTR | ResNet-50 | 60 | 28.4 / 16.6 | 34.4 / 20.8 | 36.3 / 22.1 | [link](https://entuedu-my.sharepoint.com/:f:/g/personal/jingkang001_e_ntu_edu_sg/Eonc-KwOxg9EmdtGDX6ss-gB35QpKDnN_1KSWOj6U8sZwQ?e=zdqwqP) |
-PSGFormer | ResNet-50 | 60 | 18.0 / 14.8 | 19.6 / 17.0 | 20.1 / 17.6 | [link](https://entuedu-my.sharepoint.com/:f:/g/personal/jingkang001_e_ntu_edu_sg/EnaJchJzJPtGrkl4k09evPIB5JUkkDZ2tSS9F-Hd-1KYzA?e=9QA8Nc) |
+Method | Backbone | #Epoch | R/mR@20 | R/mR@50 | R/mR@100 | ckpt | SHA256
+--- | --- | --- | --- | --- |--- |--- |--- |
+IMP | ResNet-50 | 12 | 16.5 / 6.52 | 18.2 / 7.05 | 18.6 / 7.23 | [link](https://entuedu-my.sharepoint.com/:f:/g/personal/jingkang001_e_ntu_edu_sg/EiTgJ9q2h3hDpyXSdu6BtlQBHAZNwNaYmcO7SElxhkIFXw?e=8fytHc) |7be2842b6664e2b9ef6c7c05d27fde521e2401ffe67dbb936438c69e98f9783c |
+MOTIFS | ResNet-50 | 12 | 20.0 / 9.10 | 21.7 / 9.57 | 22.0 / 9.69 | [link](https://entuedu-my.sharepoint.com/:f:/g/personal/jingkang001_e_ntu_edu_sg/Eh4hvXIspUFKpNa_75qwDoEBJTCIozTLzm49Ste6HaoPow?e=ZdAs6z) | 956471959ca89acae45c9533fb9f9a6544e650b8ea18fe62cdead495b38751b8 |
+VCTree | ResNet-50 | 12 | 20.6 / 9.70 | 22.1 / 10.2 | 22.5 / 10.2 | [link](https://entuedu-my.sharepoint.com/:f:/g/personal/jingkang001_e_ntu_edu_sg/EhKfi9kqAd9CnSoHztQIChABeBjBD3hF7DflrNCjlHfh9A?e=lWa1bd) |e5fdac7e6cc8d9af7ae7027f6d0948bf414a4a605ed5db4d82c5d72de55c9b58 |
+GPSNet | ResNet-50 | 12 | 17.8 / 7.03 | 19.6 / 7.49 | 20.1 / 7.67 | [link](https://entuedu-my.sharepoint.com/:f:/g/personal/jingkang001_e_ntu_edu_sg/EipIhZgVgx1LuK2RUmjRg2sB8JqxMIS5GnPDHeaYy5GF6A?e=5j53VF) | 98cd7450925eb88fa311a20fce74c96f712e45b7f29857c5cdf9b9dd57f59c51 |
+PSGTR | ResNet-50 | 60 | 28.4 / 16.6 | 34.4 / 20.8 | 36.3 / 22.1 | [link](https://entuedu-my.sharepoint.com/:f:/g/personal/jingkang001_e_ntu_edu_sg/Eonc-KwOxg9EmdtGDX6ss-gB35QpKDnN_1KSWOj6U8sZwQ?e=zdqwqP) | 1c4ddcbda74686568b7e6b8145f7f33030407e27e390c37c23206f95c51829ed |
+PSGFormer | ResNet-50 | 60 | 18.0 / 14.8 | 19.6 / 17.0 | 20.1 / 17.6 | [link](https://entuedu-my.sharepoint.com/:f:/g/personal/jingkang001_e_ntu_edu_sg/EnaJchJzJPtGrkl4k09evPIB5JUkkDZ2tSS9F-Hd-1KYzA?e=9QA8Nc) | 2f0015ce67040fa00b65986f6ce457c4f8cc34720f7e47a656b462b696a013b7 |
---
## Contributing
diff --git a/ce7454/README.md b/ce7454/README.md
new file mode 100644
index 0000000..e8915de
--- /dev/null
+++ b/ce7454/README.md
@@ -0,0 +1,43 @@
+# CE7454 Assignment 1: PSG Classification
+
+## Get Started
+
+We specially build the tiny codebase here (in this directory) to help our students to quickly get started.
+
+First off, let's clone or fork the codebase and enter in the `ce7454` directory. Don't forget to star the repo if you find the assignment is interesting and instructive.
+Then we download the data [here](https://entuedu-my.sharepoint.com/:f:/g/personal/jingkang001_e_ntu_edu_sg/El0TuNPJWyVJqw6agAzd7l0BU_Tr9gJEgUzLbZggaAKyHg?e=cFPqNF) and unzip the data at the correct place. Eventually, your `ce7454` folder should looks like this:
+```
+ce7454
+├── checkpoints
+├── data
+│ ├── coco
+│ │ ├── panoptic_train2017
+│ │ ├── panoptic_val2017
+│ │ ├── train2017
+│ │ └── val2017
+│ └── psg
+│ ├── psg_cls_basic.json
+│ └── psg_val_advanced.json
+├── results
+├── dataset.py
+├── evaluator.py
+├── ...
+```
+
+
+We provide 4500 training data, 500 validation data, and 500 test data.
+Notice that there might not be exactly 4500 training images (so are val/test images) as some images are annotated twice, and we consider one annotation as one sample.
+
+Then, we need to setup the environment. We use `conda` to manage our dependencies. The code is heavily dependent on PyTorch.
+
+```bash
+conda install python=3.7 pytorch=1.7.0 torchvision=0.8.0 torchaudio==0.7.0 cudatoolkit=10.1
+pip install tqdm
+```
+
+Finally, make sure your working directory is `ce7454`, and let's train the model!
+```bash
+python main.py
+```
+
+You can explore the project by reading from the `main.py` and dive in. Good Luck!!
diff --git a/ce7454/dataset.py b/ce7454/dataset.py
new file mode 100644
index 0000000..6800b44
--- /dev/null
+++ b/ce7454/dataset.py
@@ -0,0 +1,84 @@
+import io
+import json
+import logging
+import os
+
+import torch
+import torchvision.transforms as trn
+from PIL import Image, ImageFile
+from torch.utils.data import Dataset
+
+# to fix "OSError: image file is truncated"
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+class Convert:
+ def __init__(self, mode='RGB'):
+ self.mode = mode
+
+ def __call__(self, image):
+ return image.convert(self.mode)
+
+
+def get_transforms(stage: str):
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
+ if stage == 'train':
+ return trn.Compose([
+ Convert('RGB'),
+ trn.Resize((1333, 800)),
+ trn.RandomHorizontalFlip(),
+ trn.RandomCrop((1333, 800), padding=4),
+ trn.ToTensor(),
+ trn.Normalize(mean, std),
+ ])
+
+ elif stage in ['val', 'test']:
+ return trn.Compose([
+ Convert('RGB'),
+ trn.Resize((1333, 800)),
+ trn.ToTensor(),
+ trn.Normalize(mean, std),
+ ])
+
+
+class PSGClsDataset(Dataset):
+ def __init__(
+ self,
+ stage,
+ root='./data/coco/',
+ num_classes=56,
+ ):
+ super(PSGClsDataset, self).__init__()
+ with open('./data/psg/psg_cls_basic.json') as f:
+ dataset = json.load(f)
+ self.imglist = [
+ d for d in dataset['data']
+ if d['image_id'] in dataset[f'{stage}_image_ids']
+ ]
+ self.root = root
+ self.transform_image = get_transforms(stage)
+ self.num_classes = num_classes
+
+ def __len__(self):
+ return len(self.imglist)
+
+ def __getitem__(self, index):
+ sample = self.imglist[index]
+ path = os.path.join(self.root, sample['file_name'])
+ try:
+ with open(path, 'rb') as f:
+ content = f.read()
+ filebytes = content
+ buff = io.BytesIO(filebytes)
+ image = Image.open(buff).convert('RGB')
+ sample['data'] = self.transform_image(image)
+ except Exception as e:
+ logging.error('Error, cannot read [{}]'.format(path))
+ raise e
+ # Generate Soft Label
+ soft_label = torch.Tensor(self.num_classes)
+ soft_label.fill_(0)
+ soft_label[sample['relations']] = 1
+ sample['soft_label'] = soft_label
+ del sample['relations']
+ return sample
diff --git a/ce7454/evaluator.py b/ce7454/evaluator.py
new file mode 100644
index 0000000..f4fa37d
--- /dev/null
+++ b/ce7454/evaluator.py
@@ -0,0 +1,76 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+
+
+class Evaluator:
+ def __init__(
+ self,
+ net: nn.Module,
+ k: int,
+ ):
+ self.net = net
+ self.k = k
+
+ def eval_recall(
+ self,
+ data_loader: DataLoader,
+ ):
+ self.net.eval()
+ loss_avg = 0.0
+ pred_list, gt_list = [], []
+ with torch.no_grad():
+ for batch in data_loader:
+ data = batch['data'].cuda()
+ logits = self.net(data)
+ prob = torch.sigmoid(logits)
+ target = batch['soft_label'].cuda()
+ loss = F.binary_cross_entropy(prob, target, reduction='sum')
+ loss_avg += float(loss.data)
+ # gather prediction and gt
+ pred = torch.topk(prob.data, self.k)[1]
+ pred = pred.cpu().detach().tolist()
+ pred_list.extend(pred)
+ for soft_label in batch['soft_label']:
+ gt_label = (soft_label == 1).nonzero(as_tuple=True)[0]\
+ .cpu().detach().tolist()
+ gt_list.append(gt_label)
+
+ # compute mean recall
+ score_list = np.zeros([56, 2], dtype=int)
+ for gt, pred in zip(gt_list, pred_list):
+ for gt_id in gt:
+ # pos 0 for counting all existing relations
+ score_list[gt_id][0] += 1
+ if gt_id in pred:
+ # pos 1 for counting relations that is recalled
+ score_list[gt_id][1] += 1
+ score_list = score_list[6:]
+ # to avoid nan
+ score_list[:, 0][score_list[:, 0] == 0] = 1
+ meanrecall = np.mean(score_list[:, 1] / score_list[:, 0])
+
+ metrics = {}
+ metrics['test_loss'] = loss_avg / len(data_loader)
+ metrics['mean_recall'] = meanrecall
+
+ return metrics
+
+ def submit(
+ self,
+ data_loader: DataLoader,
+ ):
+ self.net.eval()
+
+ pred_list = []
+ with torch.no_grad():
+ for batch in data_loader:
+ data = batch['data'].cuda()
+ logits = self.net(data)
+ prob = torch.sigmoid(logits)
+ pred = torch.topk(prob.data, self.k)[1]
+ pred = pred.cpu().detach().tolist()
+ pred_list.extend(pred)
+ return pred_list
diff --git a/ce7454/grade.py b/ce7454/grade.py
new file mode 100644
index 0000000..316ac34
--- /dev/null
+++ b/ce7454/grade.py
@@ -0,0 +1,60 @@
+import argparse
+import os
+
+import numpy as np
+
+
+def compute_recall(gt_list, pred_list):
+ score_list = np.zeros([56, 2], dtype=int)
+ for gt, pred in zip(gt_list, pred_list):
+ for gt_id in gt:
+ # pos 0 for counting all existing relations
+ score_list[gt_id][0] += 1
+ if gt_id in pred:
+ # pos 1 for counting relations that is recalled
+ score_list[gt_id][1] += 1
+ score_list = score_list[6:]
+ # to avoid nan, but test set does not have empty predict
+ # score_list[:,0][score_list[:,0] == 0] = 1
+ meanrecall = np.mean(score_list[:, 1] / score_list[:, 0])
+
+ return meanrecall
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='MMDet eval a model')
+ parser.add_argument('input_path', help='input file path')
+ parser.add_argument('output_path', help='output file path')
+
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+ submit_dir = os.path.join(args.input_path, 'res')
+ groundtruth_dir = os.path.join(args.input_path, 'ref')
+
+ gt_list = []
+ with open(os.path.join(groundtruth_dir, 'psg_cls_gt.txt'), 'r') as reader:
+ for line in reader.readlines():
+ gt_list.append(
+ [int(label) for label in line.strip('/n').split(' ')])
+
+ pred_list = []
+ with open(os.path.join(submit_dir, 'result.txt'), 'r') as reader:
+ for line in reader.readlines():
+ pred_list.append(
+ [int(label) for label in line.strip('/n').split(' ')])
+
+ assert np.array(pred_list).shape == (
+ 500, 3), 'make sure the submitted file is 500 x 3'
+ result = compute_recall(gt_list, pred_list)
+ output_filename = os.path.join(args.output_path, 'scores.txt')
+
+ with open(output_filename, 'w') as f3:
+ f3.write('score: {}\n'.format(result))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/ce7454/main.py b/ce7454/main.py
new file mode 100644
index 0000000..b2b3447
--- /dev/null
+++ b/ce7454/main.py
@@ -0,0 +1,104 @@
+import argparse
+import os
+import time
+
+import torch
+from dataset import PSGClsDataset
+from evaluator import Evaluator
+from torch.utils.data import DataLoader
+from torchvision.models import resnet50
+from trainer import BaseTrainer
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--model_name', type=str, default='res50')
+parser.add_argument('--epoch', type=int, default=36)
+parser.add_argument('--lr', type=float, default=0.001)
+parser.add_argument('--batch_size', type=int, default=16)
+parser.add_argument('--momentum', type=float, default=0.9)
+parser.add_argument('--weight_decay', type=float, default=0.0005)
+
+args = parser.parse_args()
+
+savename = f'{args.model_name}_e{args.epoch}_lr{args.lr}_bs{args.batch_size}_m{args.momentum}_wd{args.weight_decay}'
+os.makedirs('./checkpoints', exist_ok=True)
+os.makedirs('./results', exist_ok=True)
+
+# loading dataset
+train_dataset = PSGClsDataset(stage='train')
+train_dataloader = DataLoader(train_dataset,
+ batch_size=args.batch_size,
+ shuffle=True,
+ num_workers=8)
+
+val_dataset = PSGClsDataset(stage='val')
+val_dataloader = DataLoader(val_dataset,
+ batch_size=32,
+ shuffle=False,
+ num_workers=8)
+
+test_dataset = PSGClsDataset(stage='test')
+test_dataloader = DataLoader(test_dataset,
+ batch_size=32,
+ shuffle=False,
+ num_workers=8)
+print('Data Loaded...', flush=True)
+
+# loading model
+model = resnet50(pretrained=True)
+model.fc = torch.nn.Linear(2048, 56)
+model.cuda()
+print('Model Loaded...', flush=True)
+
+# loading trainer
+trainer = BaseTrainer(model,
+ train_dataloader,
+ learning_rate=args.lr,
+ momentum=args.momentum,
+ weight_decay=args.weight_decay,
+ epochs=args.epoch)
+evaluator = Evaluator(model, k=3)
+
+# train!
+print('Start Training...', flush=True)
+begin_epoch = time.time()
+best_val_recall = 0.0
+for epoch in range(0, args.epoch):
+ train_metrics = trainer.train_epoch()
+ val_metrics = evaluator.eval_recall(val_dataloader)
+
+ # show log
+ print(
+ '{} | Epoch {:3d} | Time {:5d}s | Train Loss {:.4f} | Test Loss {:.3f} | mR {:.2f}'
+ .format(savename, (epoch + 1), int(time.time() - begin_epoch),
+ train_metrics['train_loss'], val_metrics['test_loss'],
+ 100.0 * val_metrics['mean_recall']),
+ flush=True)
+
+ # save model
+ if val_metrics['mean_recall'] >= best_val_recall:
+ torch.save(model.state_dict(), f'./checkpoints/{savename}_best.ckpt')
+ best_val_recall = val_metrics['mean_recall']
+
+print('Training Completed...', flush=True)
+
+# saving result!
+print('Loading Best Ckpt...', flush=True)
+checkpoint = torch.load(f'checkpoints/{savename}_best.ckpt')
+model.load_state_dict(checkpoint)
+test_evaluator = Evaluator(model, k=3)
+check_metrics = test_evaluator.eval_recall(val_dataloader)
+if best_val_recall == check_metrics['mean_recall']:
+ print('Successfully load best checkpoint with acc {:.2f}'.format(
+ 100 * best_val_recall),
+ flush=True)
+else:
+ print('Fail to load best checkpoint')
+result = test_evaluator.submit(test_dataloader)
+
+# save into the file
+with open(f'results/{savename}_{best_val_recall}.txt', 'w') as writer:
+ for label_list in result:
+ a = [str(x) for x in label_list]
+ save_str = ' '.join(a)
+ writer.writelines(save_str + '\n')
+print('Result Saved!', flush=True)
diff --git a/ce7454/trainer.py b/ce7454/trainer.py
new file mode 100644
index 0000000..8f68bae
--- /dev/null
+++ b/ce7454/trainer.py
@@ -0,0 +1,72 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+
+def cosine_annealing(step, total_steps, lr_max, lr_min):
+ return lr_min + (lr_max -
+ lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))
+
+
+class BaseTrainer:
+ def __init__(self,
+ net: nn.Module,
+ train_loader: DataLoader,
+ learning_rate: float = 0.1,
+ momentum: float = 0.9,
+ weight_decay: float = 0.0005,
+ epochs: int = 100) -> None:
+ self.net = net
+ self.train_loader = train_loader
+
+ self.optimizer = torch.optim.SGD(
+ net.parameters(),
+ learning_rate,
+ momentum=momentum,
+ weight_decay=weight_decay,
+ nesterov=True,
+ )
+
+ self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+ self.optimizer,
+ lr_lambda=lambda step: cosine_annealing(
+ step,
+ epochs * len(train_loader),
+ 1, # since lr_lambda computes multiplicative factor
+ 1e-6 / learning_rate,
+ ),
+ )
+
+ def train_epoch(self):
+ self.net.train() # enter train mode
+
+ loss_avg = 0.0
+ train_dataiter = iter(self.train_loader)
+
+ for train_step in tqdm(range(1, len(train_dataiter) + 1)):
+ # for train_step in tqdm(range(1, 5)):
+ batch = next(train_dataiter)
+ data = batch['data'].cuda()
+ target = batch['soft_label'].cuda()
+ # forward
+ logits = self.net(data)
+ loss = F.binary_cross_entropy_with_logits(logits,
+ target,
+ reduction='sum')
+ # backward
+ self.optimizer.zero_grad()
+ loss.backward()
+ self.optimizer.step()
+ self.scheduler.step()
+
+ # exponential moving average, show smooth values
+ with torch.no_grad():
+ loss_avg = loss_avg * 0.8 + float(loss) * 0.2
+
+ metrics = {}
+ metrics['train_loss'] = loss_avg
+
+ return metrics
diff --git a/configs/_base_/datasets/psg.py b/configs/_base_/datasets/psg.py
index d8b31db..ff77bd0 100644
--- a/configs/_base_/datasets/psg.py
+++ b/configs/_base_/datasets/psg.py
@@ -1,7 +1,7 @@
# dataset settings
dataset_type = 'PanopticSceneGraphDataset'
# ann_file = './data/psg/psg.json' # full data, available after PSG challenge
-ann_file = './data/psg/psg_train_val.json' # for PSG challenge development
+ann_file = './data/psg/psg.json' # './data/psg/psg_train_val.json' for PSG challenge development
# ann_file = './data/psg/psg_val_test.json' # for PSG challenge submission
coco_root = './data/coco'
diff --git a/tools/Visualize_Dataset.ipynb b/tools/Visualize_Dataset.ipynb
index e28d5e4..354b70b 100644
--- a/tools/Visualize_Dataset.ipynb
+++ b/tools/Visualize_Dataset.ipynb
@@ -19,6 +19,7 @@
"%autoreload 2\n",
"\n",
"import os\n",
+ "os.chdir('..')\n",
"\n",
"from pathlib import Path\n",
"import matplotlib.pyplot as plt\n",
@@ -72,7 +73,6 @@
],
"source": [
"# set working path as home dir to easy access data\n",
- "os.chdir('..')\n",
"psg_dataset_file = load_json(Path(\"data/psg/psg.json\"))\n",
"print('keys: ', list(psg_dataset_file.keys()))"
]
@@ -384,7 +384,7 @@
}
],
"source": [
- "from vis_tools.detectron_viz import Visualizer\n",
+ "from openpsg.utils.vis_tools.detectron_viz import Visualizer\n",
"viz = Visualizer(img)\n",
"viz.overlay_instances(\n",
" labels=labels_coco,\n",