Skip to content

Commit

Permalink
MSD balanced dataset and plots
Browse files Browse the repository at this point in the history
  • Loading branch information
drkostas committed May 17, 2022
1 parent e5959a1 commit 1a15596
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 37 deletions.
2 changes: 1 addition & 1 deletion SegFormer/demo/class_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def voc_palette():

def msd_palette():
"""MSD palette for external use."""
return [[128, 128, 128], [0, 64, 0]]
return [[75, 0, 130], [255, 255, 0]]


dataset_aliases = {
Expand Down
50 changes: 44 additions & 6 deletions SegFormer/demo/image_demo.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,50 @@
from argparse import ArgumentParser

from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
import matplotlib.pyplot as plt
from glob import glob
import mmcv
from mmseg.apis import inference_segmentor, init_segmentor
from class_names import get_palette
from imageio import imread
import matplotlib.cm as cm
my_cmap = cm.Reds
my_cmap.set_under('k', alpha=0)


def show_result_pyplot(model, img, result, palette=None, fig_size=(15, 10), train_or_test='train'):
"""Visualize the segmentation results on the image.
Args:
model (nn.Module): The loaded segmentor.
img (str or np.ndarray): Image filename or loaded image.
result (list): The segmentation result.
palette (list[list[int]]] | None): The palette of segmentation
map. If None is given, random palette will be generated.
Default: None
fig_size (tuple): Figure size of the pyplot figure.
train_or_test (str): 'train' or 'test'.
"""

if hasattr(model, 'module'):
model = model.module
img = model.show_result(img, result, palette=palette, show=False)
plt.figure(figsize=fig_size)
fig, ax = plt.subplots(1, 1, figsize=fig_size)
ax.imshow(mmcv.bgr2rgb(img), alpha=1.0)
if train_or_test == 'train':
img_annot = img.replace("images", "annotations")
img_annot = imread(img_annot)
ax.imshow(img_annot, cmap=my_cmap, interpolation='none',
clim=[0.9, 1], alpha=.4)
plt.show()


def main():
parser = ArgumentParser()
parser.add_argument('img', help='Image file')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument('--num', help='Number of images to show', type=int)
parser.add_argument('--train_or_test', help='Number of images to show', default='train', type=str)
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
Expand All @@ -19,10 +55,12 @@ def main():

# build the model from a config file and a checkpoint file
model = init_segmentor(args.config, args.checkpoint, device=args.device)
# test a single image
result = inference_segmentor(model, args.img)
# show the results
show_result_pyplot(model, args.img, result, get_palette(args.palette))
for img in glob(f'{args.img}/*.png')[:args.num]:
print("Plotting: ", img)
# test a single image
result = inference_segmentor(model, img)
# show the results
show_result_pyplot(model, img, result, get_palette(args.palette), args.train_or_test)


if __name__ == '__main__':
Expand Down
67 changes: 67 additions & 0 deletions SegFormer/local_configs/_base_/datasets/msd_balanced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# dataset settings
dataset_type = 'MSDBalancedDataset'
data_root = 'data/MSD/Task09_Spleen_RGB_2D_512_Balanced'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(512, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(512, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
data_root='../../data/MSD/Task09_Spleen_RGB_2D_512_Balanced',
img_dir='images/training',
ann_dir='annotations/training',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_root='../../data/MSD/Task09_Spleen_RGB_2D_512_Balanced',
img_dir='images/training',
ann_dir='annotations/training',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root='../../data/MSD/Task09_Spleen_RGB_2D_512_Balanced',
img_dir='images/training',
ann_dir='annotations/training',
pipeline=test_pipeline)
# val=dict(
# type=dataset_type,
# data_root='../../data/MSD/Task09_Spleen_RGB_2D_512_Balanced',
# img_dir='images/validation',
# ann_dir='annotations/validation',
# pipeline=test_pipeline),
# test=dict(
# type=dataset_type,
# data_root='../../data/MSD/Task09_Spleen_RGB_2D_512_Balanced',
# img_dir='images/validation',
# ann_dir='annotations/validation',
# pipeline=test_pipeline)
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
_base_ = [
'../../_base_/models/segformer.py',
'../../_base_/datasets/msd_balanced.py',
'../../_base_/default_runtime.py',
'../../_base_/schedules/schedule_20k.py'
]

# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
find_unused_parameters = True
model = dict(
type='EncoderDecoder',
pretrained='../../pretrained/ImageNet-1K/mit_b0.pth',
backbone=dict(
type='mit_b0',
style='pytorch'),
decode_head=dict(
type='SegFormerHead',
in_channels=[32, 64, 160, 256],
in_index=[0, 1, 2, 3],
feature_strides=[4, 8, 16, 32],
channels=128,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False,
decoder_params=dict(embed_dim=256),
loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

# optimizer
optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01,
paramwise_cfg=dict(custom_keys={'pos_block': dict(decay_mult=0.),
'norm': dict(decay_mult=0.),
'head': dict(lr_mult=10.)
}))

lr_config = dict(_delete_=True, policy='poly',
warmup='linear',
warmup_iters=1500,
warmup_ratio=1e-6,
power=1.0, min_lr=0.0, by_epoch=False)

data = dict(samples_per_gpu=2)
evaluation = dict(interval=16000, metric='mIoU')
3 changes: 2 additions & 1 deletion SegFormer/mmseg/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .ade import ADE20KDataset
from .msd import MSDDataset
from .msd_marked import MSDMarkedDataset
from .msd_balanced import MSDBalancedDataset
from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
Expand All @@ -19,5 +20,5 @@
'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset', 'STAREDataset', 'MapillaryDataset', 'CocoStuff',
'MSDDataset', 'MSDMarkedDataset'
'MSDDataset', 'MSDMarkedDataset', 'MSDBalancedDataset'
]
24 changes: 24 additions & 0 deletions SegFormer/mmseg/datasets/msd_balanced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class MSDBalancedDataset(CustomDataset):
"""ADE20K dataset.
In segmentation map annotation for ADE20K, 0 stands for background, which
is not included in 150 categories. ``reduce_zero_label`` is fixed to True.
The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
'.png'.
"""
CLASSES = (
"no ailment", "ailment")

PALETTE = [[120, 120, 120], [92, 0, 255]]

def __init__(self, **kwargs):
super(MSDBalancedDataset, self).__init__(
img_suffix='.png',
seg_map_suffix='.png',
reduce_zero_label=True,
**kwargs)
22 changes: 0 additions & 22 deletions data/MSD/Task09_Spleen/dataset.json

This file was deleted.

1 change: 0 additions & 1 deletion models/20220515_165920.log.json

This file was deleted.

6 changes: 0 additions & 6 deletions models/20220515_170407.log.json

This file was deleted.

0 comments on commit 1a15596

Please sign in to comment.