diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5746c2b4a..e58bf1c93 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -exclude: ^detection/configs +exclude: ^(detection/configs|segmentation/configs) repos: - repo: https://gitlab.com/pycqa/flake8.git rev: 3.8.3 diff --git a/segmentation/configs/_base_/datasets/ade20k.py b/segmentation/configs/_base_/datasets/ade20k.py new file mode 100644 index 000000000..efc8b4bb2 --- /dev/null +++ b/segmentation/configs/_base_/datasets/ade20k.py @@ -0,0 +1,54 @@ +# dataset settings +dataset_type = 'ADE20KDataset' +data_root = 'data/ade/ADEChallengeData2016' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 512), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/chase_db1.py b/segmentation/configs/_base_/datasets/chase_db1.py new file mode 100644 index 000000000..298594ea9 --- /dev/null +++ b/segmentation/configs/_base_/datasets/chase_db1.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'ChaseDB1Dataset' +data_root = 'data/CHASE_DB1' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +img_scale = (960, 999) +crop_size = (128, 128) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect',
keys=['img']) + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline)), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/cityscapes.py b/segmentation/configs/_base_/datasets/cityscapes.py new file mode 100644 index 000000000..f21867c63 --- /dev/null +++ b/segmentation/configs/_base_/datasets/cityscapes.py @@ -0,0 +1,54 @@ +# dataset settings +dataset_type = 'CityscapesDataset' +data_root = 'data/cityscapes/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (512, 1024) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 1024), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='leftImg8bit/train', + ann_dir='gtFine/train', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='leftImg8bit/val', + ann_dir='gtFine/val', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='leftImg8bit/val', + ann_dir='gtFine/val', + pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/cityscapes_1024x1024.py b/segmentation/configs/_base_/datasets/cityscapes_1024x1024.py new file mode 100644 index 000000000..f98d92972 --- /dev/null +++ b/segmentation/configs/_base_/datasets/cityscapes_1024x1024.py @@ -0,0 +1,35 @@ +_base_ = './cityscapes.py' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (1024, 1024) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 1024), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + 
dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/cityscapes_768x768.py b/segmentation/configs/_base_/datasets/cityscapes_768x768.py new file mode 100644 index 000000000..fde9d7c7d --- /dev/null +++ b/segmentation/configs/_base_/datasets/cityscapes_768x768.py @@ -0,0 +1,35 @@ +_base_ = './cityscapes.py' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (768, 768) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2049, 1025), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/cityscapes_769x769.py b/segmentation/configs/_base_/datasets/cityscapes_769x769.py new file mode 100644 index 000000000..336c7b254 --- /dev/null +++ b/segmentation/configs/_base_/datasets/cityscapes_769x769.py @@ -0,0 +1,35 @@ +_base_ = './cityscapes.py' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (769, 769) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2049, 1025), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/cityscapes_832x832.py b/segmentation/configs/_base_/datasets/cityscapes_832x832.py new file mode 100644 index 000000000..b9325cc00 --- /dev/null +++ b/segmentation/configs/_base_/datasets/cityscapes_832x832.py @@ -0,0 +1,35 @@ +_base_ = './cityscapes.py' +img_norm_cfg = dict( + 
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (832, 832) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 1024), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/cityscapes_896x896.py b/segmentation/configs/_base_/datasets/cityscapes_896x896.py new file mode 100644 index 000000000..608e05b7c --- /dev/null +++ b/segmentation/configs/_base_/datasets/cityscapes_896x896.py @@ -0,0 +1,35 @@ +_base_ = './cityscapes.py' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (896, 896) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 1024), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/coco-stuff10k.py b/segmentation/configs/_base_/datasets/coco-stuff10k.py new file mode 100644 index 000000000..ec0496928 --- /dev/null +++ b/segmentation/configs/_base_/datasets/coco-stuff10k.py @@ -0,0 +1,57 @@ +# dataset settings +dataset_type = 'COCOStuffDataset' +data_root = 'data/coco_stuff10k' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 
'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 512), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + reduce_zero_label=True, + img_dir='images/train2014', + ann_dir='annotations/train2014', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + reduce_zero_label=True, + img_dir='images/test2014', + ann_dir='annotations/test2014', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + reduce_zero_label=True, + img_dir='images/test2014', + ann_dir='annotations/test2014', + pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/coco-stuff164k.py b/segmentation/configs/_base_/datasets/coco-stuff164k.py new file mode 100644 index 000000000..a6a38f2ac --- /dev/null +++ b/segmentation/configs/_base_/datasets/coco-stuff164k.py @@ -0,0 +1,54 @@ +# dataset settings +dataset_type = 'COCOStuffDataset' +data_root = 'data/coco_stuff164k' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 512), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/train2017', + ann_dir='annotations/train2017', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/val2017', + ann_dir='annotations/val2017', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/val2017', + ann_dir='annotations/val2017', + pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/drive.py b/segmentation/configs/_base_/datasets/drive.py new file mode 100644 index 000000000..06e8ff606 --- /dev/null +++ b/segmentation/configs/_base_/datasets/drive.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'DRIVEDataset' +data_root = 'data/DRIVE' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +img_scale = (584, 565) +crop_size = (64, 64) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + 
dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline)), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/hrf.py b/segmentation/configs/_base_/datasets/hrf.py new file mode 100644 index 000000000..242d790eb --- /dev/null +++ b/segmentation/configs/_base_/datasets/hrf.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'HRFDataset' +data_root = 'data/HRF' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +img_scale = (2336, 3504) +crop_size = (256, 256) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline)), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/loveda.py b/segmentation/configs/_base_/datasets/loveda.py new file mode 100644 index 000000000..e55335695 --- /dev/null +++ b/segmentation/configs/_base_/datasets/loveda.py @@ -0,0 +1,54 @@ +# dataset settings +dataset_type = 'LoveDADataset' +data_root = 'data/loveDA' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], 
to_rgb=True) +crop_size = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1024, 1024), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='img_dir/train', + ann_dir='ann_dir/train', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='img_dir/val', + ann_dir='ann_dir/val', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='img_dir/val', + ann_dir='ann_dir/val', + pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/mapillary_896x896.py b/segmentation/configs/_base_/datasets/mapillary_896x896.py new file mode 100644 index 000000000..48657319f --- /dev/null +++ b/segmentation/configs/_base_/datasets/mapillary_896x896.py @@ -0,0 +1,55 @@ +# dataset settings +dataset_type = 'MapillaryDataset' +data_root = 'data/Mapillary/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (896, 896) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='MapillaryHack'), + dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 1.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 1024), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + data_root='data/Mapillary/', + img_dir=['training/images', 'validation/images'], + ann_dir=['training/labels', 'validation/labels'], + pipeline=train_pipeline), + val=dict( + type='CityscapesDataset', + data_root='data/cityscapes/', + img_dir='leftImg8bit/val', + ann_dir='gtFine/val', + pipeline=test_pipeline), + test=dict( + type='CityscapesDataset', + data_root='data/cityscapes/', + img_dir='leftImg8bit/val', + ann_dir='gtFine/val', + pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/nyu_depth_v2.py b/segmentation/configs/_base_/datasets/nyu_depth_v2.py new file mode 100644 index 000000000..bac9443c2 
--- /dev/null +++ b/segmentation/configs/_base_/datasets/nyu_depth_v2.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'NYUDepthV2Dataset' +data_root = 'data/nyu_depth_v2/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +crop_size = (480, 480) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=(640, 480), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(640, 480), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='image', + ann_dir='label40', + split='train.txt', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='image', + ann_dir='label40', + split='test.txt', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='image', + ann_dir='label40', + split='test.txt', + pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/pascal_context.py b/segmentation/configs/_base_/datasets/pascal_context.py new file mode 100644 index 000000000..ff65bad1b --- /dev/null +++ b/segmentation/configs/_base_/datasets/pascal_context.py @@ -0,0 +1,60 @@ +# dataset settings +dataset_type = 'PascalContextDataset' +data_root = 'data/VOCdevkit/VOC2010/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +img_scale = (520, 520) +crop_size = (480, 480) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/train.txt', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/val.txt', + 
pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/val.txt', + pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/pascal_context_59.py b/segmentation/configs/_base_/datasets/pascal_context_59.py new file mode 100644 index 000000000..37585abab --- /dev/null +++ b/segmentation/configs/_base_/datasets/pascal_context_59.py @@ -0,0 +1,60 @@ +# dataset settings +dataset_type = 'PascalContextDataset59' +data_root = 'data/VOCdevkit/VOC2010/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +img_scale = (520, 520) +crop_size = (480, 480) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/train.txt', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/val.txt', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/val.txt', + pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/pascal_voc12.py b/segmentation/configs/_base_/datasets/pascal_voc12.py new file mode 100644 index 000000000..ba1d42d0c --- /dev/null +++ b/segmentation/configs/_base_/datasets/pascal_voc12.py @@ -0,0 +1,57 @@ +# dataset settings +dataset_type = 'PascalVOCDataset' +data_root = 'data/VOCdevkit/VOC2012' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 512), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + 
dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClass', + split='ImageSets/Segmentation/train.txt', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClass', + split='ImageSets/Segmentation/val.txt', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClass', + split='ImageSets/Segmentation/val.txt', + pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/pascal_voc12_aug.py b/segmentation/configs/_base_/datasets/pascal_voc12_aug.py new file mode 100644 index 000000000..3f23b6717 --- /dev/null +++ b/segmentation/configs/_base_/datasets/pascal_voc12_aug.py @@ -0,0 +1,9 @@ +_base_ = './pascal_voc12.py' +# dataset settings +data = dict( + train=dict( + ann_dir=['SegmentationClass', 'SegmentationClassAug'], + split=[ + 'ImageSets/Segmentation/train.txt', + 'ImageSets/Segmentation/aug.txt' + ])) diff --git a/segmentation/configs/_base_/datasets/potsdam.py b/segmentation/configs/_base_/datasets/potsdam.py new file mode 100644 index 000000000..f74c4a56c --- /dev/null +++ b/segmentation/configs/_base_/datasets/potsdam.py @@ -0,0 +1,54 @@ +# dataset settings +dataset_type = 'PotsdamDataset' +data_root = 'data/potsdam' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=(512, 512), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(512, 512), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='img_dir/train', + ann_dir='ann_dir/train', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='img_dir/val', + ann_dir='ann_dir/val', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='img_dir/val', + ann_dir='ann_dir/val', + pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/datasets/stare.py b/segmentation/configs/_base_/datasets/stare.py new file mode 100644 index 000000000..3f71b2548 --- /dev/null +++ b/segmentation/configs/_base_/datasets/stare.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'STAREDataset' +data_root = 'data/STARE' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +img_scale = (605, 700) +crop_size = (128, 128) +train_pipeline = [ + 
dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline)), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/segmentation/configs/_base_/default_runtime.py b/segmentation/configs/_base_/default_runtime.py new file mode 100644 index 000000000..b564cc4e7 --- /dev/null +++ b/segmentation/configs/_base_/default_runtime.py @@ -0,0 +1,14 @@ +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook', by_epoch=False), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True diff --git a/segmentation/configs/_base_/models/ann_r50-d8.py b/segmentation/configs/_base_/models/ann_r50-d8.py new file mode 100644 index 000000000..a2cb65382 --- /dev/null +++ b/segmentation/configs/_base_/models/ann_r50-d8.py @@ -0,0 +1,46 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='ANNHead', + in_channels=[1024, 2048], + in_index=[2, 3], + channels=512, + project_channels=256, + query_scales=(1, ), + key_pool_scales=(1, 3, 6, 8), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/apcnet_r50-d8.py b/segmentation/configs/_base_/models/apcnet_r50-d8.py new file mode 100644 index 000000000..c8f5316cb --- /dev/null +++ 
b/segmentation/configs/_base_/models/apcnet_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='APCHead', + in_channels=2048, + in_index=3, + channels=512, + pool_scales=(1, 2, 3, 6), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/bisenetv1_r18-d32.py b/segmentation/configs/_base_/models/bisenetv1_r18-d32.py new file mode 100644 index 000000000..40698644b --- /dev/null +++ b/segmentation/configs/_base_/models/bisenetv1_r18-d32.py @@ -0,0 +1,68 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='BiSeNetV1', + in_channels=3, + context_channels=(128, 256, 512), + spatial_channels=(64, 64, 64, 128), + out_indices=(0, 1, 2), + out_channels=256, + backbone_cfg=dict( + type='ResNet', + in_channels=3, + depth=18, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 1, 1), + strides=(1, 2, 2, 2), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + norm_cfg=norm_cfg, + align_corners=False, + init_cfg=None), + decode_head=dict( + type='FCNHead', + in_channels=256, + in_index=0, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=[ + dict( + type='FCNHead', + in_channels=128, + channels=64, + num_convs=1, + num_classes=19, + in_index=1, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='FCNHead', + in_channels=128, + channels=64, + num_convs=1, + num_classes=19, + in_index=2, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + ], + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/bisenetv2.py b/segmentation/configs/_base_/models/bisenetv2.py new file mode 100644 index 000000000..f8fffeeca --- /dev/null +++ b/segmentation/configs/_base_/models/bisenetv2.py @@ -0,0 +1,80 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='BiSeNetV2', + detail_channels=(64, 64, 128), + semantic_channels=(16, 32, 64, 128), + semantic_expansion_ratio=6, + bga_channels=128, + out_indices=(0, 1, 2, 3, 4), + init_cfg=None, + align_corners=False), + 
decode_head=dict( + type='FCNHead', + in_channels=128, + in_index=0, + channels=1024, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=[ + dict( + type='FCNHead', + in_channels=16, + channels=16, + num_convs=2, + num_classes=19, + in_index=1, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='FCNHead', + in_channels=32, + channels=64, + num_convs=2, + num_classes=19, + in_index=2, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='FCNHead', + in_channels=64, + channels=256, + num_convs=2, + num_classes=19, + in_index=3, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='FCNHead', + in_channels=128, + channels=1024, + num_convs=2, + num_classes=19, + in_index=4, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + ], + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/ccnet_r50-d8.py b/segmentation/configs/_base_/models/ccnet_r50-d8.py new file mode 100644 index 000000000..794148f57 --- /dev/null +++ b/segmentation/configs/_base_/models/ccnet_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='CCHead', + in_channels=2048, + in_index=3, + channels=512, + recurrence=2, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/cgnet.py b/segmentation/configs/_base_/models/cgnet.py new file mode 100644 index 000000000..eff8d9458 --- /dev/null +++ b/segmentation/configs/_base_/models/cgnet.py @@ -0,0 +1,35 @@ +# model settings +norm_cfg = dict(type='SyncBN', eps=1e-03, requires_grad=True) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='CGNet', + norm_cfg=norm_cfg, + in_channels=3, + num_channels=(32, 64, 128), + num_blocks=(3, 21), + dilations=(2, 4), + reductions=(8, 16)), + decode_head=dict( + type='FCNHead', + in_channels=256, + in_index=2, + channels=256, + num_convs=0, + concat_input=False, + dropout_ratio=0, + num_classes=19, + norm_cfg=norm_cfg, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0, + class_weight=[ + 2.5959933, 
6.7415504, 3.5354059, 9.8663225, 9.690899, 9.369352, + 10.289121, 9.953208, 4.3097677, 9.490387, 7.674431, 9.396905, + 10.347791, 6.3927646, 10.226669, 10.241062, 10.280587, + 10.396974, 10.055647 + ])), + # model training and testing settings + train_cfg=dict(sampler=None), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/danet_r50-d8.py b/segmentation/configs/_base_/models/danet_r50-d8.py new file mode 100644 index 000000000..2c934939f --- /dev/null +++ b/segmentation/configs/_base_/models/danet_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='DAHead', + in_channels=2048, + in_index=3, + channels=512, + pam_channels=64, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/deeplabv3_r50-d8.py b/segmentation/configs/_base_/models/deeplabv3_r50-d8.py new file mode 100644 index 000000000..d7a43bee0 --- /dev/null +++ b/segmentation/configs/_base_/models/deeplabv3_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='ASPPHead', + in_channels=2048, + in_index=3, + channels=512, + dilations=(1, 12, 24, 36), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/deeplabv3_unet_s5-d16.py b/segmentation/configs/_base_/models/deeplabv3_unet_s5-d16.py new file mode 100644 index 000000000..0cd262999 --- /dev/null +++ b/segmentation/configs/_base_/models/deeplabv3_unet_s5-d16.py @@ -0,0 +1,50 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='UNet', + in_channels=3, + base_channels=64, + num_stages=5, + strides=(1, 1, 1, 1, 1), + enc_num_convs=(2, 2, 2, 2, 2), + dec_num_convs=(2, 2, 2, 2), + downsamples=(True, True, True, True), + 
enc_dilations=(1, 1, 1, 1, 1), + dec_dilations=(1, 1, 1, 1), + with_cp=False, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + norm_eval=False), + decode_head=dict( + type='ASPPHead', + in_channels=64, + in_index=4, + channels=16, + dilations=(1, 12, 24, 36), + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=128, + in_index=3, + channels=64, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=256, stride=170)) diff --git a/segmentation/configs/_base_/models/deeplabv3plus_r50-d8.py b/segmentation/configs/_base_/models/deeplabv3plus_r50-d8.py new file mode 100644 index 000000000..050e39e09 --- /dev/null +++ b/segmentation/configs/_base_/models/deeplabv3plus_r50-d8.py @@ -0,0 +1,46 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='DepthwiseSeparableASPPHead', + in_channels=2048, + in_index=3, + channels=512, + dilations=(1, 12, 24, 36), + c1_in_channels=256, + c1_channels=48, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/dmnet_r50-d8.py b/segmentation/configs/_base_/models/dmnet_r50-d8.py new file mode 100644 index 000000000..d22ba5264 --- /dev/null +++ b/segmentation/configs/_base_/models/dmnet_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='DMHead', + in_channels=2048, + in_index=3, + channels=512, + filter_sizes=(1, 3, 5, 7), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and 
testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/dnl_r50-d8.py b/segmentation/configs/_base_/models/dnl_r50-d8.py new file mode 100644 index 000000000..edb4c174c --- /dev/null +++ b/segmentation/configs/_base_/models/dnl_r50-d8.py @@ -0,0 +1,46 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='DNLHead', + in_channels=2048, + in_index=3, + channels=512, + dropout_ratio=0.1, + reduction=2, + use_scale=True, + mode='embedded_gaussian', + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/dpt_vit-b16.py b/segmentation/configs/_base_/models/dpt_vit-b16.py new file mode 100644 index 000000000..dfd48a95f --- /dev/null +++ b/segmentation/configs/_base_/models/dpt_vit-b16.py @@ -0,0 +1,31 @@ +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='pretrain/vit-b16_p16_224-80ecf9dd.pth', # noqa + backbone=dict( + type='VisionTransformer', + img_size=224, + embed_dims=768, + num_layers=12, + num_heads=12, + out_indices=(2, 5, 8, 11), + final_norm=False, + with_cls_token=True, + output_cls_token=True), + decode_head=dict( + type='DPTHead', + in_channels=(768, 768, 768, 768), + channels=256, + embed_dims=768, + post_process_channels=[96, 192, 384, 768], + num_classes=150, + readout_type='project', + input_transform='multiple_select', + in_index=(0, 1, 2, 3), + norm_cfg=norm_cfg, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=None, + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) # yapf: disable diff --git a/segmentation/configs/_base_/models/emanet_r50-d8.py b/segmentation/configs/_base_/models/emanet_r50-d8.py new file mode 100644 index 000000000..26adcd430 --- /dev/null +++ b/segmentation/configs/_base_/models/emanet_r50-d8.py @@ -0,0 +1,47 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='EMAHead', + in_channels=2048, + in_index=3, + channels=256, + ema_channels=512, + num_bases=64, + num_stages=3, + momentum=0.1, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + 
channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/encnet_r50-d8.py b/segmentation/configs/_base_/models/encnet_r50-d8.py new file mode 100644 index 000000000..be777123a --- /dev/null +++ b/segmentation/configs/_base_/models/encnet_r50-d8.py @@ -0,0 +1,48 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='EncHead', + in_channels=[512, 1024, 2048], + in_index=(1, 2, 3), + channels=512, + num_codes=32, + use_se_loss=True, + add_lateral=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_se_decode=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/erfnet_fcn.py b/segmentation/configs/_base_/models/erfnet_fcn.py new file mode 100644 index 000000000..7f2e9bff8 --- /dev/null +++ b/segmentation/configs/_base_/models/erfnet_fcn.py @@ -0,0 +1,32 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='ERFNet', + in_channels=3, + enc_downsample_channels=(16, 64, 128), + enc_stage_non_bottlenecks=(5, 8), + enc_non_bottleneck_dilations=(2, 4, 8, 16), + enc_non_bottleneck_channels=(64, 128), + dec_upsample_channels=(64, 16), + dec_stages_non_bottleneck=(2, 2), + dec_non_bottleneck_channels=(64, 16), + dropout_ratio=0.1, + init_cfg=None), + decode_head=dict( + type='FCNHead', + in_channels=16, + channels=128, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/fast_scnn.py b/segmentation/configs/_base_/models/fast_scnn.py new file mode 100644 index 000000000..8e89d911d --- /dev/null +++ b/segmentation/configs/_base_/models/fast_scnn.py @@ -0,0 +1,57 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='FastSCNN', + downsample_dw_channels=(32, 48), + global_in_channels=64, + global_block_channels=(64, 96, 128), + global_block_strides=(2, 2, 1), + global_out_channels=128, + higher_in_channels=64, + lower_in_channels=128, + fusion_out_channels=128, + out_indices=(0, 1, 2), + norm_cfg=norm_cfg, + 
align_corners=False), + decode_head=dict( + type='DepthwiseSeparableFCNHead', + in_channels=128, + channels=128, + concat_input=False, + num_classes=19, + in_index=-1, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1)), + auxiliary_head=[ + dict( + type='FCNHead', + in_channels=128, + channels=32, + num_convs=1, + num_classes=19, + in_index=-2, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)), + dict( + type='FCNHead', + in_channels=64, + channels=32, + num_convs=1, + num_classes=19, + in_index=-3, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)), + ], + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/fastfcn_r50-d32_jpu_psp.py b/segmentation/configs/_base_/models/fastfcn_r50-d32_jpu_psp.py new file mode 100644 index 000000000..9dc8609ae --- /dev/null +++ b/segmentation/configs/_base_/models/fastfcn_r50-d32_jpu_psp.py @@ -0,0 +1,53 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + dilations=(1, 1, 2, 4), + strides=(1, 2, 2, 2), + out_indices=(1, 2, 3), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + neck=dict( + type='JPU', + in_channels=(512, 1024, 2048), + mid_channels=512, + start_level=0, + end_level=-1, + dilations=(1, 2, 4, 8), + align_corners=False, + norm_cfg=norm_cfg), + decode_head=dict( + type='PSPHead', + in_channels=2048, + in_index=2, + channels=512, + pool_scales=(1, 2, 3, 6), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=1, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/fcn_hr18.py b/segmentation/configs/_base_/models/fcn_hr18.py new file mode 100644 index 000000000..c3e299bc8 --- /dev/null +++ b/segmentation/configs/_base_/models/fcn_hr18.py @@ -0,0 +1,52 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://msra/hrnetv2_w18', + backbone=dict( + type='HRNet', + norm_cfg=norm_cfg, + norm_eval=False, + extra=dict( + stage1=dict( + num_modules=1, + num_branches=1, + block='BOTTLENECK', + num_blocks=(4, ), + num_channels=(64, )), + stage2=dict( + num_modules=1, + num_branches=2, + block='BASIC', + num_blocks=(4, 4), + num_channels=(18, 36)), + stage3=dict( + num_modules=4, + num_branches=3, + block='BASIC', + num_blocks=(4, 4, 4), + num_channels=(18, 36, 72)), + stage4=dict( + num_modules=3, + num_branches=4, + block='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(18, 36, 72, 144)))), + decode_head=dict( + type='FCNHead', + in_channels=[18, 36, 72, 144], + in_index=(0, 1, 2, 3), + channels=sum([18, 36, 72, 144]), + 
input_transform='resize_concat', + kernel_size=1, + num_convs=1, + concat_input=False, + dropout_ratio=-1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/fcn_r50-d8.py b/segmentation/configs/_base_/models/fcn_r50-d8.py new file mode 100644 index 000000000..5e98f6cc9 --- /dev/null +++ b/segmentation/configs/_base_/models/fcn_r50-d8.py @@ -0,0 +1,45 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='FCNHead', + in_channels=2048, + in_index=3, + channels=512, + num_convs=2, + concat_input=True, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/fcn_unet_s5-d16.py b/segmentation/configs/_base_/models/fcn_unet_s5-d16.py new file mode 100644 index 000000000..a33e79728 --- /dev/null +++ b/segmentation/configs/_base_/models/fcn_unet_s5-d16.py @@ -0,0 +1,51 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='UNet', + in_channels=3, + base_channels=64, + num_stages=5, + strides=(1, 1, 1, 1, 1), + enc_num_convs=(2, 2, 2, 2, 2), + dec_num_convs=(2, 2, 2, 2), + downsamples=(True, True, True, True), + enc_dilations=(1, 1, 1, 1, 1), + dec_dilations=(1, 1, 1, 1), + with_cp=False, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + norm_eval=False), + decode_head=dict( + type='FCNHead', + in_channels=64, + in_index=4, + channels=64, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=128, + in_index=3, + channels=64, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=256, stride=170)) diff --git a/segmentation/configs/_base_/models/fpn_r50.py b/segmentation/configs/_base_/models/fpn_r50.py new file mode 100644 index 000000000..86ab327db --- /dev/null +++ b/segmentation/configs/_base_/models/fpn_r50.py @@ -0,0 +1,36 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', 
+ backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 1, 1), + strides=(1, 2, 2, 2), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=4), + decode_head=dict( + type='FPNHead', + in_channels=[256, 256, 256, 256], + in_index=[0, 1, 2, 3], + feature_strides=[4, 8, 16, 32], + channels=128, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/gcnet_r50-d8.py b/segmentation/configs/_base_/models/gcnet_r50-d8.py new file mode 100644 index 000000000..3d2ad69f5 --- /dev/null +++ b/segmentation/configs/_base_/models/gcnet_r50-d8.py @@ -0,0 +1,46 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='GCHead', + in_channels=2048, + in_index=3, + channels=512, + ratio=1 / 4., + pooling_type='att', + fusion_types=('channel_add', ), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/icnet_r50-d8.py b/segmentation/configs/_base_/models/icnet_r50-d8.py new file mode 100644 index 000000000..d7273cd28 --- /dev/null +++ b/segmentation/configs/_base_/models/icnet_r50-d8.py @@ -0,0 +1,74 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='ICNet', + backbone_cfg=dict( + type='ResNetV1c', + in_channels=3, + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + in_channels=3, + layer_channels=(512, 2048), + light_branch_middle_channels=32, + psp_out_channels=512, + out_channels=(64, 256, 256), + norm_cfg=norm_cfg, + align_corners=False, + ), + neck=dict( + type='ICNeck', + in_channels=(64, 256, 256), + out_channels=128, + norm_cfg=norm_cfg, + align_corners=False), + decode_head=dict( + type='FCNHead', + in_channels=128, + channels=128, + num_convs=1, + in_index=2, + dropout_ratio=0, + num_classes=19, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=[ + dict( + type='FCNHead', + in_channels=128, + channels=128, + num_convs=1, + num_classes=19, + in_index=0, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + 
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + dict( + type='FCNHead', + in_channels=128, + channels=128, + num_convs=1, + num_classes=19, + in_index=1, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + ], + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/isanet_r50-d8.py b/segmentation/configs/_base_/models/isanet_r50-d8.py new file mode 100644 index 000000000..c0221a371 --- /dev/null +++ b/segmentation/configs/_base_/models/isanet_r50-d8.py @@ -0,0 +1,45 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='ISAHead', + in_channels=2048, + in_index=3, + channels=512, + isa_channels=256, + down_factor=(8, 8), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/lraspp_m-v3-d8.py b/segmentation/configs/_base_/models/lraspp_m-v3-d8.py new file mode 100644 index 000000000..93258242a --- /dev/null +++ b/segmentation/configs/_base_/models/lraspp_m-v3-d8.py @@ -0,0 +1,25 @@ +# model settings +norm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='MobileNetV3', + arch='large', + out_indices=(1, 3, 16), + norm_cfg=norm_cfg), + decode_head=dict( + type='LRASPPHead', + in_channels=(16, 24, 960), + in_index=(0, 1, 2), + channels=128, + input_transform='multiple_select', + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/mask2former_beit.py b/segmentation/configs/_base_/models/mask2former_beit.py new file mode 100644 index 000000000..4990be100 --- /dev/null +++ b/segmentation/configs/_base_/models/mask2former_beit.py @@ -0,0 +1,138 @@ +# model_cfg +num_things_classes = 100 +num_stuff_classes = 50 +num_classes = num_things_classes + num_stuff_classes +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoderMask2Former', + pretrained=None, + backbone=dict( + type='XCiT', + patch_size=16, + embed_dim=384, + depth=12, + num_heads=8, + mlp_ratio=4, + qkv_bias=True, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + ), + decode_head=dict( + type='Mask2FormerHead', + in_channels=[256, 512, 1024, 2048], # pass to pixel_decoder inside + # strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + in_index=[0, 1, 2, 3], + 
num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=100, + num_transformer_feat_level=3, + pixel_decoder=dict( + type='MSDeformAttnPixelDecoder', + num_outs=3, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None), + ffn_cfgs=dict( + type='FFN', + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True)), + operation_order=('self_attn', 'norm', 'ffn', 'norm')), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=128, normalize=True), + init_cfg=None), + enforce_decoder_input_project=False, + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=128, normalize=True), + transformer_decoder=dict( + type='DetrTransformerDecoder', + return_intermediate=True, + num_layers=9, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=256, + num_heads=8, + attn_drop=0.0, + proj_drop=0.0, + dropout_layer=None, + batch_first=False), + ffn_cfgs=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True), + feedforward_channels=2048, + operation_order=('cross_attn', 'norm', 'self_attn', 'norm', + 'ffn', 'norm')), + init_cfg=None), + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * num_classes + [0.1]), + loss_mask=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type='DiceLoss', + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0)), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type='MaskHungarianAssigner', + cls_cost=dict(type='ClassificationCost', weight=2.0), + mask_cost=dict( + type='CrossEntropyLossCost', weight=5.0, use_sigmoid=True), + dice_cost=dict( + type='DiceCost', weight=5.0, pred_act=True, eps=1.0)), + sampler=dict(type='MaskPseudoSampler')), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . 
+ filter_low_score=True), + init_cfg=None) + +find_unused_parameters = True \ No newline at end of file diff --git a/segmentation/configs/_base_/models/mask2former_beit_cityscapes.py b/segmentation/configs/_base_/models/mask2former_beit_cityscapes.py new file mode 100644 index 000000000..94cf90ac1 --- /dev/null +++ b/segmentation/configs/_base_/models/mask2former_beit_cityscapes.py @@ -0,0 +1,138 @@ +# model_cfg +num_things_classes = 8 +num_stuff_classes = 11 +num_classes = num_things_classes + num_stuff_classes +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoderMask2Former', + pretrained=None, + backbone=dict( + type='XCiT', + patch_size=16, + embed_dim=384, + depth=12, + num_heads=8, + mlp_ratio=4, + qkv_bias=True, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + ), + decode_head=dict( + type='Mask2FormerHead', + in_channels=[256, 512, 1024, 2048], # pass to pixel_decoder inside + # strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + in_index=[0, 1, 2, 3], + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=100, + num_transformer_feat_level=3, + pixel_decoder=dict( + type='MSDeformAttnPixelDecoder', + num_outs=3, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None), + ffn_cfgs=dict( + type='FFN', + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True)), + operation_order=('self_attn', 'norm', 'ffn', 'norm')), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=128, normalize=True), + init_cfg=None), + enforce_decoder_input_project=False, + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=128, normalize=True), + transformer_decoder=dict( + type='DetrTransformerDecoder', + return_intermediate=True, + num_layers=9, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=256, + num_heads=8, + attn_drop=0.0, + proj_drop=0.0, + dropout_layer=None, + batch_first=False), + ffn_cfgs=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True), + feedforward_channels=2048, + operation_order=('cross_attn', 'norm', 'self_attn', 'norm', + 'ffn', 'norm')), + init_cfg=None), + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * num_classes + [0.1]), + loss_mask=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type='DiceLoss', + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0)), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type='MaskHungarianAssigner', + cls_cost=dict(type='ClassificationCost', weight=2.0), + mask_cost=dict( + type='CrossEntropyLossCost', weight=5.0, use_sigmoid=True), + dice_cost=dict( + type='DiceCost', weight=5.0, pred_act=True, eps=1.0)), + sampler=dict(type='MaskPseudoSampler')), + test_cfg=dict( 
+ panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=None) + +find_unused_parameters = True \ No newline at end of file diff --git a/segmentation/configs/_base_/models/mask2former_beit_cocostuff.py b/segmentation/configs/_base_/models/mask2former_beit_cocostuff.py new file mode 100644 index 000000000..540b5a2ff --- /dev/null +++ b/segmentation/configs/_base_/models/mask2former_beit_cocostuff.py @@ -0,0 +1,138 @@ +# model_cfg +num_things_classes = 80 +num_stuff_classes = 91 +num_classes = num_things_classes + num_stuff_classes +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoderMask2Former', + pretrained=None, + backbone=dict( + type='XCiT', + patch_size=16, + embed_dim=384, + depth=12, + num_heads=8, + mlp_ratio=4, + qkv_bias=True, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + ), + decode_head=dict( + type='Mask2FormerHead', + in_channels=[256, 512, 1024, 2048], # pass to pixel_decoder inside + # strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + in_index=[0, 1, 2, 3], + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=100, + num_transformer_feat_level=3, + pixel_decoder=dict( + type='MSDeformAttnPixelDecoder', + num_outs=3, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None), + ffn_cfgs=dict( + type='FFN', + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True)), + operation_order=('self_attn', 'norm', 'ffn', 'norm')), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=128, normalize=True), + init_cfg=None), + enforce_decoder_input_project=False, + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=128, normalize=True), + transformer_decoder=dict( + type='DetrTransformerDecoder', + return_intermediate=True, + num_layers=9, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=256, + num_heads=8, + attn_drop=0.0, + proj_drop=0.0, + dropout_layer=None, + batch_first=False), + ffn_cfgs=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True), + feedforward_channels=2048, + operation_order=('cross_attn', 'norm', 'self_attn', 'norm', + 'ffn', 'norm')), + init_cfg=None), + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * num_classes + [0.1]), + loss_mask=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type='DiceLoss', + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0)), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + 
importance_sample_ratio=0.75, + assigner=dict( + type='MaskHungarianAssigner', + cls_cost=dict(type='ClassificationCost', weight=2.0), + mask_cost=dict( + type='CrossEntropyLossCost', weight=5.0, use_sigmoid=True), + dice_cost=dict( + type='DiceCost', weight=5.0, pred_act=True, eps=1.0)), + sampler=dict(type='MaskPseudoSampler')), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=None) + +find_unused_parameters = True \ No newline at end of file diff --git a/segmentation/configs/_base_/models/mask2former_beit_pascal.py b/segmentation/configs/_base_/models/mask2former_beit_pascal.py new file mode 100644 index 000000000..c9a1e87d9 --- /dev/null +++ b/segmentation/configs/_base_/models/mask2former_beit_pascal.py @@ -0,0 +1,138 @@ +# model_cfg +num_things_classes = 29 +num_stuff_classes = 30 +num_classes = num_things_classes + num_stuff_classes +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoderMask2Former', + pretrained=None, + backbone=dict( + type='XCiT', + patch_size=16, + embed_dim=384, + depth=12, + num_heads=8, + mlp_ratio=4, + qkv_bias=True, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + ), + decode_head=dict( + type='Mask2FormerHead', + in_channels=[256, 512, 1024, 2048], # pass to pixel_decoder inside + # strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + in_index=[0, 1, 2, 3], + num_things_classes=num_things_classes, + num_stuff_classes=num_stuff_classes, + num_queries=100, + num_transformer_feat_level=3, + pixel_decoder=dict( + type='MSDeformAttnPixelDecoder', + num_outs=3, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None), + ffn_cfgs=dict( + type='FFN', + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True)), + operation_order=('self_attn', 'norm', 'ffn', 'norm')), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=128, normalize=True), + init_cfg=None), + enforce_decoder_input_project=False, + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=128, normalize=True), + transformer_decoder=dict( + type='DetrTransformerDecoder', + return_intermediate=True, + num_layers=9, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=256, + num_heads=8, + attn_drop=0.0, + proj_drop=0.0, + dropout_layer=None, + batch_first=False), + ffn_cfgs=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True), + feedforward_channels=2048, + operation_order=('cross_attn', 'norm', 'self_attn', 'norm', + 'ffn', 'norm')), + init_cfg=None), + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=2.0, + reduction='mean', + class_weight=[1.0] * 
num_classes + [0.1]), + loss_mask=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='mean', + loss_weight=5.0), + loss_dice=dict( + type='DiceLoss', + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=5.0)), + train_cfg=dict( + num_points=12544, + oversample_ratio=3.0, + importance_sample_ratio=0.75, + assigner=dict( + type='MaskHungarianAssigner', + cls_cost=dict(type='ClassificationCost', weight=2.0), + mask_cost=dict( + type='CrossEntropyLossCost', weight=5.0, use_sigmoid=True), + dice_cost=dict( + type='DiceCost', weight=5.0, pred_act=True, eps=1.0)), + sampler=dict(type='MaskPseudoSampler')), + test_cfg=dict( + panoptic_on=True, + # For now, the dataset does not support + # evaluating semantic segmentation metric. + semantic_on=False, + instance_on=True, + # max_per_image is for instance segmentation. + max_per_image=100, + iou_thr=0.8, + # In Mask2Former's panoptic postprocessing, + # it will filter mask area where score is less than 0.5 . + filter_low_score=True), + init_cfg=None) + +find_unused_parameters = True \ No newline at end of file diff --git a/segmentation/configs/_base_/models/maskformer_beit.py b/segmentation/configs/_base_/models/maskformer_beit.py new file mode 100644 index 000000000..20a4e5c89 --- /dev/null +++ b/segmentation/configs/_base_/models/maskformer_beit.py @@ -0,0 +1,92 @@ +# model_cfg +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='XCiT', + patch_size=16, + embed_dim=384, + depth=12, + num_heads=8, + mlp_ratio=4, + qkv_bias=True, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + ), + decode_head=dict( + type='MaskFormerHead', + in_channels=[384, 384, 384, 384], # pass to pixel_decoder inside + channels=256, + in_index=[0, 1, 2, 3], + num_classes=150, + out_channels=256, + num_queries=100, + pixel_decoder=dict( + type='PixelDecoder', + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU')), + enforce_decoder_input_project=False, + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=128, normalize=True), + transformer_decoder=dict( + type='DetrTransformerDecoder', + return_intermediate=True, + num_layers=6, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=256, + num_heads=8, + attn_drop=0.1, + proj_drop=0.1, + dropout_layer=None, + batch_first=False), + ffn_cfgs=dict( + embed_dims=256, + feedforward_channels=2048, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.1, + dropout_layer=None, + add_identity=True), + # the following parameter was not used, + # just make current api happy + feedforward_channels=2048, + operation_order=('self_attn', 'norm', 'cross_attn', 'norm', + 'ffn', 'norm')), + init_cfg=None), + loss_cls=dict( + type='CrossEntropyLoss', + bg_cls_weight=0.1, + use_sigmoid=False, + loss_weight=1.0, + reduction='mean', + class_weight=1.0), + loss_mask=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + reduction='mean', + loss_weight=20.0), + loss_dice=dict( + type='DiceLoss', + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=True, + eps=1.0, + loss_weight=1.0), + assigner=dict( + type='MaskHungarianAssigner', + cls_cost=dict(type='ClassificationCost', weight=1.0), + mask_cost=dict(type='MaskFocalLossCost', weight=20.0), + dice_cost=dict( + type='DiceCost', weight=1.0, pred_act=True, eps=1.0))), + # training and testing settings + 
train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=(512, 512), stride=(341, 341)), +) +find_unused_parameters = True \ No newline at end of file diff --git a/segmentation/configs/_base_/models/nonlocal_r50-d8.py b/segmentation/configs/_base_/models/nonlocal_r50-d8.py new file mode 100644 index 000000000..5674a3985 --- /dev/null +++ b/segmentation/configs/_base_/models/nonlocal_r50-d8.py @@ -0,0 +1,46 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='NLHead', + in_channels=2048, + in_index=3, + channels=512, + dropout_ratio=0.1, + reduction=2, + use_scale=True, + mode='embedded_gaussian', + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/ocrnet_hr18.py b/segmentation/configs/_base_/models/ocrnet_hr18.py new file mode 100644 index 000000000..c60f62a7c --- /dev/null +++ b/segmentation/configs/_base_/models/ocrnet_hr18.py @@ -0,0 +1,68 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='CascadeEncoderDecoder', + num_stages=2, + pretrained='open-mmlab://msra/hrnetv2_w18', + backbone=dict( + type='HRNet', + norm_cfg=norm_cfg, + norm_eval=False, + extra=dict( + stage1=dict( + num_modules=1, + num_branches=1, + block='BOTTLENECK', + num_blocks=(4, ), + num_channels=(64, )), + stage2=dict( + num_modules=1, + num_branches=2, + block='BASIC', + num_blocks=(4, 4), + num_channels=(18, 36)), + stage3=dict( + num_modules=4, + num_branches=3, + block='BASIC', + num_blocks=(4, 4, 4), + num_channels=(18, 36, 72)), + stage4=dict( + num_modules=3, + num_branches=4, + block='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(18, 36, 72, 144)))), + decode_head=[ + dict( + type='FCNHead', + in_channels=[18, 36, 72, 144], + channels=sum([18, 36, 72, 144]), + in_index=(0, 1, 2, 3), + input_transform='resize_concat', + kernel_size=1, + num_convs=1, + concat_input=False, + dropout_ratio=-1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + dict( + type='OCRHead', + in_channels=[18, 36, 72, 144], + in_index=(0, 1, 2, 3), + input_transform='resize_concat', + channels=512, + ocr_channels=256, + dropout_ratio=-1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + ], + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/ocrnet_r50-d8.py b/segmentation/configs/_base_/models/ocrnet_r50-d8.py new file mode 100644 index 000000000..615aa3ff7 --- /dev/null +++ b/segmentation/configs/_base_/models/ocrnet_r50-d8.py @@ -0,0 
+1,47 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='CascadeEncoderDecoder', + num_stages=2, + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=[ + dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + dict( + type='OCRHead', + in_channels=2048, + in_index=3, + channels=512, + ocr_channels=256, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) + ], + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/pointrend_r50.py b/segmentation/configs/_base_/models/pointrend_r50.py new file mode 100644 index 000000000..9d323dbf9 --- /dev/null +++ b/segmentation/configs/_base_/models/pointrend_r50.py @@ -0,0 +1,56 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='CascadeEncoderDecoder', + num_stages=2, + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 1, 1), + strides=(1, 2, 2, 2), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=4), + decode_head=[ + dict( + type='FPNHead', + in_channels=[256, 256, 256, 256], + in_index=[0, 1, 2, 3], + feature_strides=[4, 8, 16, 32], + channels=128, + dropout_ratio=-1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='PointHead', + in_channels=[256], + in_index=[0], + channels=256, + num_fcs=3, + coarse_pred_each_layer=True, + dropout_ratio=-1, + num_classes=19, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) + ], + # model training and testing settings + train_cfg=dict( + num_points=2048, oversample_ratio=3, importance_sample_ratio=0.75), + test_cfg=dict( + mode='whole', + subdivision_steps=2, + subdivision_num_points=8196, + scale_factor=2)) diff --git a/segmentation/configs/_base_/models/psanet_r50-d8.py b/segmentation/configs/_base_/models/psanet_r50-d8.py new file mode 100644 index 000000000..689513fa9 --- /dev/null +++ b/segmentation/configs/_base_/models/psanet_r50-d8.py @@ -0,0 +1,49 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='PSAHead', + in_channels=2048, + in_index=3, + channels=512, + mask_size=(97, 97), + psa_type='bi-direction', + compact=False, + shrink_factor=2, + normalization_factor=1.0, + psa_softmax=True, + dropout_ratio=0.1, + num_classes=19, 
+ norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/pspnet_r50-d8.py b/segmentation/configs/_base_/models/pspnet_r50-d8.py new file mode 100644 index 000000000..f451e08ad --- /dev/null +++ b/segmentation/configs/_base_/models/pspnet_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='PSPHead', + in_channels=2048, + in_index=3, + channels=512, + pool_scales=(1, 2, 3, 6), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/pspnet_unet_s5-d16.py b/segmentation/configs/_base_/models/pspnet_unet_s5-d16.py new file mode 100644 index 000000000..fcff9ec4f --- /dev/null +++ b/segmentation/configs/_base_/models/pspnet_unet_s5-d16.py @@ -0,0 +1,50 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='UNet', + in_channels=3, + base_channels=64, + num_stages=5, + strides=(1, 1, 1, 1, 1), + enc_num_convs=(2, 2, 2, 2, 2), + dec_num_convs=(2, 2, 2, 2), + downsamples=(True, True, True, True), + enc_dilations=(1, 1, 1, 1, 1), + dec_dilations=(1, 1, 1, 1), + with_cp=False, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + norm_eval=False), + decode_head=dict( + type='PSPHead', + in_channels=64, + in_index=4, + channels=16, + pool_scales=(1, 2, 3, 6), + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=128, + in_index=3, + channels=64, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=256, stride=170)) diff --git a/segmentation/configs/_base_/models/segformer_mit-b0.py b/segmentation/configs/_base_/models/segformer_mit-b0.py new file mode 100644 index 000000000..5b3e07331 --- /dev/null +++ 
b/segmentation/configs/_base_/models/segformer_mit-b0.py @@ -0,0 +1,34 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='MixVisionTransformer', + in_channels=3, + embed_dims=32, + num_stages=4, + num_layers=[2, 2, 2, 2], + num_heads=[1, 2, 5, 8], + patch_sizes=[7, 3, 3, 3], + sr_ratios=[8, 4, 2, 1], + out_indices=(0, 1, 2, 3), + mlp_ratio=4, + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1), + decode_head=dict( + type='SegformerHead', + in_channels=[32, 64, 160, 256], + in_index=[0, 1, 2, 3], + channels=256, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/setr_mla.py b/segmentation/configs/_base_/models/setr_mla.py new file mode 100644 index 000000000..af4ba2492 --- /dev/null +++ b/segmentation/configs/_base_/models/setr_mla.py @@ -0,0 +1,95 @@ +# model settings +backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True) +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='pretrain/jx_vit_large_p16_384-b3be5167.pth', + backbone=dict( + type='VisionTransformer', + img_size=(768, 768), + patch_size=16, + in_channels=3, + embed_dims=1024, + num_layers=24, + num_heads=16, + out_indices=(5, 11, 17, 23), + drop_rate=0.1, + norm_cfg=backbone_norm_cfg, + with_cls_token=False, + interpolate_mode='bilinear', + ), + neck=dict( + type='MLANeck', + in_channels=[1024, 1024, 1024, 1024], + out_channels=256, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + ), + decode_head=dict( + type='SETRMLAHead', + in_channels=(256, 256, 256, 256), + channels=512, + in_index=(0, 1, 2, 3), + dropout_ratio=0, + mla_channels=128, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=[ + dict( + type='FCNHead', + in_channels=256, + channels=256, + in_index=0, + dropout_ratio=0, + num_convs=0, + kernel_size=1, + concat_input=False, + num_classes=19, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + dict( + type='FCNHead', + in_channels=256, + channels=256, + in_index=1, + dropout_ratio=0, + num_convs=0, + kernel_size=1, + concat_input=False, + num_classes=19, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + dict( + type='FCNHead', + in_channels=256, + channels=256, + in_index=2, + dropout_ratio=0, + num_convs=0, + kernel_size=1, + concat_input=False, + num_classes=19, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + dict( + type='FCNHead', + in_channels=256, + channels=256, + in_index=3, + dropout_ratio=0, + num_convs=0, + kernel_size=1, + concat_input=False, + num_classes=19, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + ], + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/setr_naive.py b/segmentation/configs/_base_/models/setr_naive.py new file mode 100644 index 000000000..0c330ea2d --- /dev/null +++ b/segmentation/configs/_base_/models/setr_naive.py @@ -0,0 +1,80 @@ +# model 
settings +backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True) +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='pretrain/jx_vit_large_p16_384-b3be5167.pth', + backbone=dict( + type='VisionTransformer', + img_size=(768, 768), + patch_size=16, + in_channels=3, + embed_dims=1024, + num_layers=24, + num_heads=16, + out_indices=(9, 14, 19, 23), + drop_rate=0.1, + norm_cfg=backbone_norm_cfg, + with_cls_token=True, + interpolate_mode='bilinear', + ), + decode_head=dict( + type='SETRUPHead', + in_channels=1024, + channels=256, + in_index=3, + num_classes=19, + dropout_ratio=0, + norm_cfg=norm_cfg, + num_convs=1, + up_scale=4, + kernel_size=1, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=[ + dict( + type='SETRUPHead', + in_channels=1024, + channels=256, + in_index=0, + num_classes=19, + dropout_ratio=0, + norm_cfg=norm_cfg, + num_convs=1, + up_scale=4, + kernel_size=1, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + dict( + type='SETRUPHead', + in_channels=1024, + channels=256, + in_index=1, + num_classes=19, + dropout_ratio=0, + norm_cfg=norm_cfg, + num_convs=1, + up_scale=4, + kernel_size=1, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + dict( + type='SETRUPHead', + in_channels=1024, + channels=256, + in_index=2, + num_classes=19, + dropout_ratio=0, + norm_cfg=norm_cfg, + num_convs=1, + up_scale=4, + kernel_size=1, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)) + ], + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/setr_pup.py b/segmentation/configs/_base_/models/setr_pup.py new file mode 100644 index 000000000..8e5f23b9c --- /dev/null +++ b/segmentation/configs/_base_/models/setr_pup.py @@ -0,0 +1,80 @@ +# model settings +backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True) +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='pretrain/jx_vit_large_p16_384-b3be5167.pth', + backbone=dict( + type='VisionTransformer', + img_size=(768, 768), + patch_size=16, + in_channels=3, + embed_dims=1024, + num_layers=24, + num_heads=16, + out_indices=(9, 14, 19, 23), + drop_rate=0.1, + norm_cfg=backbone_norm_cfg, + with_cls_token=True, + interpolate_mode='bilinear', + ), + decode_head=dict( + type='SETRUPHead', + in_channels=1024, + channels=256, + in_index=3, + num_classes=19, + dropout_ratio=0, + norm_cfg=norm_cfg, + num_convs=4, + up_scale=2, + kernel_size=3, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=[ + dict( + type='SETRUPHead', + in_channels=1024, + channels=256, + in_index=0, + num_classes=19, + dropout_ratio=0, + norm_cfg=norm_cfg, + num_convs=1, + up_scale=4, + kernel_size=3, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + dict( + type='SETRUPHead', + in_channels=1024, + channels=256, + in_index=1, + num_classes=19, + dropout_ratio=0, + norm_cfg=norm_cfg, + num_convs=1, + up_scale=4, + kernel_size=3, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + dict( + type='SETRUPHead', + in_channels=1024, + channels=256, + in_index=2, + num_classes=19, + 
dropout_ratio=0, + norm_cfg=norm_cfg, + num_convs=1, + up_scale=4, + kernel_size=3, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + ], + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/stdc.py b/segmentation/configs/_base_/models/stdc.py new file mode 100644 index 000000000..e313f0443 --- /dev/null +++ b/segmentation/configs/_base_/models/stdc.py @@ -0,0 +1,83 @@ +norm_cfg = dict(type='BN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='STDCContextPathNet', + backbone_cfg=dict( + type='STDCNet', + stdc_type='STDCNet1', + in_channels=3, + channels=(32, 64, 256, 512, 1024), + bottleneck_type='cat', + num_convs=4, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + with_final_conv=False), + last_in_channels=(1024, 512), + out_channels=128, + ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4)), + decode_head=dict( + type='FCNHead', + in_channels=256, + channels=256, + num_convs=1, + num_classes=19, + in_index=3, + concat_input=False, + dropout_ratio=0.1, + norm_cfg=norm_cfg, + align_corners=True, + sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000), + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=[ + dict( + type='FCNHead', + in_channels=128, + channels=64, + num_convs=1, + num_classes=19, + in_index=2, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000), + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='FCNHead', + in_channels=128, + channels=64, + num_convs=1, + num_classes=19, + in_index=1, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000), + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='STDCHead', + in_channels=256, + channels=64, + num_convs=1, + num_classes=2, + boundary_threshold=0.1, + in_index=0, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=[ + dict( + type='CrossEntropyLoss', + loss_name='loss_ce', + use_sigmoid=True, + loss_weight=1.0), + dict(type='DiceLoss', loss_name='loss_dice', loss_weight=1.0) + ]), + ], + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/twins_pcpvt-s_fpn.py b/segmentation/configs/_base_/models/twins_pcpvt-s_fpn.py new file mode 100644 index 000000000..e7722759b --- /dev/null +++ b/segmentation/configs/_base_/models/twins_pcpvt-s_fpn.py @@ -0,0 +1,44 @@ +# model settings +backbone_norm_cfg = dict(type='LN') +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='PCPVT', + init_cfg=dict( + type='Pretrained', checkpoint='pretrained/pcpvt_small.pth'), + in_channels=3, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + patch_sizes=[4, 2, 2, 2], + strides=[4, 2, 2, 2], + mlp_ratios=[8, 8, 4, 4], + out_indices=(0, 1, 2, 3), + qkv_bias=True, + norm_cfg=backbone_norm_cfg, + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + norm_after_stage=False, + drop_rate=0.0, + attn_drop_rate=0., + drop_path_rate=0.2), + neck=dict( + type='FPN', + in_channels=[64, 128, 320, 512], + out_channels=256, + num_outs=4), + decode_head=dict( + type='FPNHead', + in_channels=[256, 256, 256, 
256], + in_index=[0, 1, 2, 3], + feature_strides=[4, 8, 16, 32], + channels=128, + dropout_ratio=0.1, + num_classes=150, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/twins_pcpvt-s_upernet.py b/segmentation/configs/_base_/models/twins_pcpvt-s_upernet.py new file mode 100644 index 000000000..a48e1a953 --- /dev/null +++ b/segmentation/configs/_base_/models/twins_pcpvt-s_upernet.py @@ -0,0 +1,52 @@ +# model settings +backbone_norm_cfg = dict(type='LN') +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='PCPVT', + init_cfg=dict( + type='Pretrained', checkpoint='pretrained/pcpvt_small.pth'), + in_channels=3, + embed_dims=[64, 128, 320, 512], + num_heads=[1, 2, 5, 8], + patch_sizes=[4, 2, 2, 2], + strides=[4, 2, 2, 2], + mlp_ratios=[8, 8, 4, 4], + out_indices=(0, 1, 2, 3), + qkv_bias=True, + norm_cfg=backbone_norm_cfg, + depths=[3, 4, 6, 3], + sr_ratios=[8, 4, 2, 1], + norm_after_stage=False, + drop_rate=0.0, + attn_drop_rate=0., + drop_path_rate=0.2), + decode_head=dict( + type='UPerHead', + in_channels=[64, 128, 320, 512], + in_index=[0, 1, 2, 3], + pool_scales=(1, 2, 3, 6), + channels=512, + dropout_ratio=0.1, + num_classes=150, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=320, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=150, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/upernet_beit.py b/segmentation/configs/_base_/models/upernet_beit.py new file mode 100644 index 000000000..aa7b25ec4 --- /dev/null +++ b/segmentation/configs/_base_/models/upernet_beit.py @@ -0,0 +1,55 @@ +# -------------------------------------------------------- +# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) +# Github source: https://github.com/microsoft/unilm/tree/master/beit +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# By Hangbo Bao +# Based on timm, mmseg, setr, xcit and swin code bases +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/fudan-zvg/SETR +# https://github.com/facebookresearch/xcit/ +# https://github.com/microsoft/Swin-Transformer +# --------------------------------------------------------' +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='XCiT', + patch_size=16, + embed_dim=384, + depth=12, + num_heads=8, + mlp_ratio=4, + qkv_bias=True, + use_abs_pos_emb=True, + use_rel_pos_bias=False, + ), + decode_head=dict( + type='UPerHead', + in_channels=[384, 384, 384, 384], + in_index=[0, 1, 2, 3], + pool_scales=(1, 2, 3, 6), + channels=512, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=384, + in_index=2, + channels=256, + num_convs=1, + 
concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) \ No newline at end of file diff --git a/segmentation/configs/_base_/models/upernet_r50.py b/segmentation/configs/_base_/models/upernet_r50.py new file mode 100644 index 000000000..10974962f --- /dev/null +++ b/segmentation/configs/_base_/models/upernet_r50.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 1, 1), + strides=(1, 2, 2, 2), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='UPerHead', + in_channels=[256, 512, 1024, 2048], + in_index=[0, 1, 2, 3], + pool_scales=(1, 2, 3, 6), + channels=512, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/upernet_swin.py b/segmentation/configs/_base_/models/upernet_swin.py new file mode 100644 index 000000000..71b51629e --- /dev/null +++ b/segmentation/configs/_base_/models/upernet_swin.py @@ -0,0 +1,54 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +backbone_norm_cfg = dict(type='LN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='SwinTransformer', + pretrain_img_size=224, + embed_dims=96, + patch_size=4, + window_size=7, + mlp_ratio=4, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + strides=(4, 2, 2, 2), + out_indices=(0, 1, 2, 3), + qkv_bias=True, + qk_scale=None, + patch_norm=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.3, + use_abs_pos_embed=False, + act_cfg=dict(type='GELU'), + norm_cfg=backbone_norm_cfg), + decode_head=dict( + type='UPerHead', + in_channels=[96, 192, 384, 768], + in_index=[0, 1, 2, 3], + pool_scales=(1, 2, 3, 6), + channels=512, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=384, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/segmentation/configs/_base_/models/upernet_vit-b16_ln_mln.py b/segmentation/configs/_base_/models/upernet_vit-b16_ln_mln.py new file mode 100644 index 000000000..cd6587dfe --- /dev/null +++ b/segmentation/configs/_base_/models/upernet_vit-b16_ln_mln.py @@ -0,0 +1,57 @@ +# model settings +norm_cfg = dict(type='SyncBN', 
requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='pretrain/jx_vit_base_p16_224-80ecf9dd.pth', + backbone=dict( + type='VisionTransformer', + img_size=(512, 512), + patch_size=16, + in_channels=3, + embed_dims=768, + num_layers=12, + num_heads=12, + mlp_ratio=4, + out_indices=(2, 5, 8, 11), + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + with_cls_token=True, + norm_cfg=dict(type='LN', eps=1e-6), + act_cfg=dict(type='GELU'), + norm_eval=False, + interpolate_mode='bicubic'), + neck=dict( + type='MultiLevelNeck', + in_channels=[768, 768, 768, 768], + out_channels=768, + scales=[4, 2, 1, 0.5]), + decode_head=dict( + type='UPerHead', + in_channels=[768, 768, 768, 768], + in_index=[0, 1, 2, 3], + pool_scales=(1, 2, 3, 6), + channels=512, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=768, + in_index=3, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) # yapf: disable diff --git a/segmentation/configs/_base_/schedules/schedule_160k.py b/segmentation/configs/_base_/schedules/schedule_160k.py new file mode 100644 index 000000000..39630f215 --- /dev/null +++ b/segmentation/configs/_base_/schedules/schedule_160k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=160000) +checkpoint_config = dict(by_epoch=False, interval=16000) +evaluation = dict(interval=16000, metric='mIoU', pre_eval=True) diff --git a/segmentation/configs/_base_/schedules/schedule_20k.py b/segmentation/configs/_base_/schedules/schedule_20k.py new file mode 100644 index 000000000..73c702197 --- /dev/null +++ b/segmentation/configs/_base_/schedules/schedule_20k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=20000) +checkpoint_config = dict(by_epoch=False, interval=2000) +evaluation = dict(interval=2000, metric='mIoU', pre_eval=True) diff --git a/segmentation/configs/_base_/schedules/schedule_320k.py b/segmentation/configs/_base_/schedules/schedule_320k.py new file mode 100644 index 000000000..a0b230626 --- /dev/null +++ b/segmentation/configs/_base_/schedules/schedule_320k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=320000) +checkpoint_config = dict(by_epoch=False, interval=32000) +evaluation = dict(interval=32000, metric='mIoU') diff --git a/segmentation/configs/_base_/schedules/schedule_40k.py b/segmentation/configs/_base_/schedules/schedule_40k.py new file mode 100644 index 000000000..d2c502325 --- /dev/null +++ 
b/segmentation/configs/_base_/schedules/schedule_40k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=40000) +checkpoint_config = dict(by_epoch=False, interval=4000) +evaluation = dict(interval=4000, metric='mIoU', pre_eval=True) diff --git a/segmentation/configs/_base_/schedules/schedule_80k.py b/segmentation/configs/_base_/schedules/schedule_80k.py new file mode 100644 index 000000000..8365a878e --- /dev/null +++ b/segmentation/configs/_base_/schedules/schedule_80k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=80000) +checkpoint_config = dict(by_epoch=False, interval=8000) +evaluation = dict(interval=8000, metric='mIoU', pre_eval=True) diff --git a/segmentation/configs/ade20k/README.md b/segmentation/configs/ade20k/README.md new file mode 100644 index 000000000..7a80dda66 --- /dev/null +++ b/segmentation/configs/ade20k/README.md @@ -0,0 +1,15 @@ +# ADE20K + + + +## Introduction + +The ADE20K semantic segmentation dataset contains more than 20K scene-centric images exhaustively annotated with pixel-level object and object-part labels. There are 150 semantic categories in total, including stuff classes such as sky, road, and grass, and discrete objects such as person, car, and bed. + +## Results and Models + +| Method | Backbone | Pre-train | Batch Size | Lr schd | Crop Size | mIoU (SS) | mIoU (MS) | #Param | Config | Download | +|:-----------:|:-------------:|:---------------------------------------------------------------------------------------------------------------------:|:----------:|:-------:|:---------:|:------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------:|:------:|:----------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| +| UperNet | ViT-Adapter-L | [BEiT-L](https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth) | 8x2 | 160k | 640 | [58.0](https://drive.google.com/file/d/1KsV4QPfoRi5cj2hjCzy8VfWih8xCTrE3/view?usp=sharing) | [58.4](https://drive.google.com/file/d/1haeTUvQhKCM7hunVdK60yxULbRH7YYBK/view?usp=sharing) | 451M | [config](./upernet_beit_adapter_large_640_160k_ade20k_ss.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.1/upernet_beit_adapter_large_640_160k_ade20k.pth.tar) \| [log](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.1/20220313_233147.log) | +| Mask2Former | ViT-Adapter-L | [BEiT-L](https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth) | 8x2 | 160k | 640 | [58.3](https://drive.google.com/file/d/1jj56lSbc2s4ZNc-Hi-w6o-OSS99oi-_g/view?usp=sharing) | [59.0](https://drive.google.com/file/d/1hgpZB5gsyd7LTS7Aay2CbHmlY10nafCw/view?usp=sharing) | 568M |
[config](./mask2former_beit_adapter_large_640_160k_ade20k_ss.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.2/mask2former_beit_adapter_large_640_160k_ade20k.zip) \| [log](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.2/20220426_003454.log) | +| Mask2Former | ViT-Adapter-L | COCO-Stuff-164K | 16x1 | 80k | 896 | [59.4](https://drive.google.com/file/d/1B_1XSwdnLhjJeUmn1g_nxfvGJpYmYWHa/view?usp=sharing) | [60.5](https://drive.google.com/file/d/1UtjmgcYKR-2h116oQXklUYOVcTw15woM/view?usp=sharing) | 571M | [config](./mask2former_beit_adapter_large_896_80k_ade20k_ss.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.0/mask2former_beit_adapter_large_896_80k_ade20k.zip) \| [log](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.0/20220430_154104.log) | diff --git a/segmentation/configs/ade20k/mask2former_beit_adapter_large_640_160k_ade20k_ms.py b/segmentation/configs/ade20k/mask2former_beit_adapter_large_640_160k_ade20k_ms.py new file mode 100644 index 000000000..1ea239e9a --- /dev/null +++ b/segmentation/configs/ade20k/mask2former_beit_adapter_large_640_160k_ade20k_ms.py @@ -0,0 +1,149 @@ +# Copyright (c) Shanghai AI Lab. All rights reserved. +_base_ = [ + '../_base_/models/mask2former_beit.py', + '../_base_/datasets/ade20k.py', + '../_base_/default_runtime.py', + '../_base_/schedules/schedule_160k.py' +] +crop_size = (640, 640) +# pretrained = 'https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth' +pretrained = 'pretrained/beit_large_patch16_224_pt22k_ft22k.pth' +model = dict( + pretrained=pretrained, + backbone=dict( + type='BEiTAdapter', + img_size=640, + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + use_abs_pos_emb=False, + use_rel_pos_bias=True, + init_values=1e-6, + drop_path_rate=0.3, + conv_inplane=64, + n_points=4, + deform_num_heads=16, + cffn_ratio=0.25, + deform_ratio=0.5, + interaction_indexes=[[0, 5], [6, 11], [12, 17], [18, 23]], + ), + decode_head=dict( + in_channels=[1024, 1024, 1024, 1024], + feat_channels=1024, + out_channels=1024, + num_queries=100, + pixel_decoder=dict( + type='MSDeformAttnPixelDecoder', + num_outs=3, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', + embed_dims=1024, + num_heads=32, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None), + ffn_cfgs=dict( + type='FFN', + embed_dims=1024, + feedforward_channels=4096, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True)), + operation_order=('self_attn', 'norm', 'ffn', 'norm')), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=512, normalize=True), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=512, normalize=True), + transformer_decoder=dict( + type='DetrTransformerDecoder', + return_intermediate=True, + num_layers=9, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=1024, + num_heads=32, + attn_drop=0.0, + proj_drop=0.0, + dropout_layer=None, + batch_first=False), + ffn_cfgs=dict( + embed_dims=1024, + feedforward_channels=4096, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.0, + 
dropout_layer=None, + add_identity=True), + feedforward_channels=4096, + operation_order=('cross_attn', 'norm', 'self_attn', 'norm', + 'ffn', 'norm')), + init_cfg=None) + ), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(426, 426)) +) +# dataset settings +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=(2048, 640), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='ToMask'), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg', 'gt_masks', 'gt_labels']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 640), + img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=True, + transforms=[ + dict(type='SETR_Resize', keep_ratio=True, + crop_size=crop_size, setr_multi_scale=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +optimizer = dict(_delete_=True, type='AdamW', lr=2e-5, betas=(0.9, 0.999), weight_decay=0.05, + constructor='LayerDecayOptimizerConstructor', + paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.90)) +lr_config = dict(_delete_=True, + policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-6, + power=1.0, min_lr=0.0, by_epoch=False) +data = dict(samples_per_gpu=2, + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) +runner = dict(type='IterBasedRunner') +checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=1) +evaluation = dict(interval=16000, metric='mIoU', save_best='mIoU') \ No newline at end of file diff --git a/segmentation/configs/ade20k/mask2former_beit_adapter_large_640_160k_ade20k_ss.py b/segmentation/configs/ade20k/mask2former_beit_adapter_large_640_160k_ade20k_ss.py new file mode 100644 index 000000000..036b91a5a --- /dev/null +++ b/segmentation/configs/ade20k/mask2former_beit_adapter_large_640_160k_ade20k_ss.py @@ -0,0 +1,149 @@ +# Copyright (c) Shanghai AI Lab. All rights reserved. 
+_base_ = [ + '../_base_/models/mask2former_beit.py', + '../_base_/datasets/ade20k.py', + '../_base_/default_runtime.py', + '../_base_/schedules/schedule_160k.py' +] +crop_size = (640, 640) +# pretrained = 'https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth' +pretrained = 'pretrained/beit_large_patch16_224_pt22k_ft22k.pth' +model = dict( + pretrained=pretrained, + backbone=dict( + type='BEiTAdapter', + img_size=640, + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + use_abs_pos_emb=False, + use_rel_pos_bias=True, + init_values=1e-6, + drop_path_rate=0.3, + conv_inplane=64, + n_points=4, + deform_num_heads=16, + cffn_ratio=0.25, + deform_ratio=0.5, + interaction_indexes=[[0, 5], [6, 11], [12, 17], [18, 23]], + ), + decode_head=dict( + in_channels=[1024, 1024, 1024, 1024], + feat_channels=1024, + out_channels=1024, + num_queries=100, + pixel_decoder=dict( + type='MSDeformAttnPixelDecoder', + num_outs=3, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', + embed_dims=1024, + num_heads=32, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None), + ffn_cfgs=dict( + type='FFN', + embed_dims=1024, + feedforward_channels=4096, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True)), + operation_order=('self_attn', 'norm', 'ffn', 'norm')), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=512, normalize=True), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=512, normalize=True), + transformer_decoder=dict( + type='DetrTransformerDecoder', + return_intermediate=True, + num_layers=9, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=1024, + num_heads=32, + attn_drop=0.0, + proj_drop=0.0, + dropout_layer=None, + batch_first=False), + ffn_cfgs=dict( + embed_dims=1024, + feedforward_channels=4096, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True), + feedforward_channels=4096, + operation_order=('cross_attn', 'norm', 'self_attn', 'norm', + 'ffn', 'norm')), + init_cfg=None) + ), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(426, 426)) +) +# dataset settings +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=(2048, 640), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='ToMask'), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg', 'gt_masks', 'gt_labels']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 640), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='ResizeToMultiple', size_divisor=32), + 
dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +optimizer = dict(_delete_=True, type='AdamW', lr=2e-5, betas=(0.9, 0.999), weight_decay=0.05, + constructor='LayerDecayOptimizerConstructor', + paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.90)) +lr_config = dict(_delete_=True, + policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-6, + power=1.0, min_lr=0.0, by_epoch=False) +data = dict(samples_per_gpu=2, + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) +runner = dict(type='IterBasedRunner') +checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=1) +evaluation = dict(interval=16000, metric='mIoU', save_best='mIoU') diff --git a/segmentation/configs/ade20k/mask2former_beit_adapter_large_896_80k_ade20k_ms.py b/segmentation/configs/ade20k/mask2former_beit_adapter_large_896_80k_ade20k_ms.py new file mode 100644 index 000000000..a72a9e185 --- /dev/null +++ b/segmentation/configs/ade20k/mask2former_beit_adapter_large_896_80k_ade20k_ms.py @@ -0,0 +1,151 @@ +# Copyright (c) Shanghai AI Lab. All rights reserved. +_base_ = [ + '../_base_/models/mask2former_beit.py', + '../_base_/datasets/ade20k.py', + '../_base_/default_runtime.py', + '../_base_/schedules/schedule_80k.py' +] +crop_size = (896, 896) +# pretrained = 'https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth' +pretrained = 'pretrained/beit_large_patch16_224_pt22k_ft22k.pth' +model = dict( + type='EncoderDecoderMask2FormerAug', + pretrained=pretrained, + backbone=dict( + type='BEiTAdapter', + img_size=896, + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + use_abs_pos_emb=False, + use_rel_pos_bias=True, + init_values=1e-6, + drop_path_rate=0.3, + conv_inplane=64, + n_points=4, + deform_num_heads=16, + cffn_ratio=0.25, + deform_ratio=0.5, + interaction_indexes=[[0, 5], [6, 11], [12, 17], [18, 23]], + ), + decode_head=dict( + in_channels=[1024, 1024, 1024, 1024], + feat_channels=1024, + out_channels=1024, + num_queries=200, + pixel_decoder=dict( + type='MSDeformAttnPixelDecoder', + num_outs=3, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', + embed_dims=1024, + num_heads=32, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None), + ffn_cfgs=dict( + type='FFN', + embed_dims=1024, + feedforward_channels=4096, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True)), + operation_order=('self_attn', 'norm', 'ffn', 'norm')), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=512, normalize=True), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=512, normalize=True), + transformer_decoder=dict( + type='DetrTransformerDecoder', + return_intermediate=True, + num_layers=9, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=1024, + num_heads=32, + attn_drop=0.0, + proj_drop=0.0, + dropout_layer=None, + batch_first=False), + ffn_cfgs=dict( + embed_dims=1024, + feedforward_channels=4096, + num_fcs=2, + 
act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True), + feedforward_channels=4096, + operation_order=('cross_attn', 'norm', 'self_attn', 'norm', + 'ffn', 'norm')), + init_cfg=None) + ), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(512, 512)) +) +# dataset settings +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=(3584, 896), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='ToMask'), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg', 'gt_masks', 'gt_labels']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(3584, 896), + img_ratios=[640./896., 768./896., 1.0], + flip=True, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='ResizeToMultiple', size_divisor=32), + dict(type='RandomFlip'), + dict(type='PadShortSide', size=896, pad_val=0, seg_pad_val=255), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +optimizer = dict(_delete_=True, type='AdamW', lr=2e-5, betas=(0.9, 0.999), weight_decay=0.05, + constructor='LayerDecayOptimizerConstructor', + paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.90)) +lr_config = dict(_delete_=True, + policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-6, + power=1.0, min_lr=0.0, by_epoch=False) +data = dict(samples_per_gpu=1, + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) +runner = dict(type='IterBasedRunner') +checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=1) +evaluation = dict(interval=8000, metric='mIoU', save_best='mIoU') diff --git a/segmentation/configs/ade20k/mask2former_beit_adapter_large_896_80k_ade20k_ss.py b/segmentation/configs/ade20k/mask2former_beit_adapter_large_896_80k_ade20k_ss.py new file mode 100644 index 000000000..27e271a4f --- /dev/null +++ b/segmentation/configs/ade20k/mask2former_beit_adapter_large_896_80k_ade20k_ss.py @@ -0,0 +1,150 @@ +# Copyright (c) Shanghai AI Lab. All rights reserved. 
+_base_ = [ + '../_base_/models/mask2former_beit.py', + '../_base_/datasets/ade20k.py', + '../_base_/default_runtime.py', + '../_base_/schedules/schedule_80k.py' +] +crop_size = (896, 896) +# pretrained = 'https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth' +pretrained = 'pretrained/beit_large_patch16_224_pt22k_ft22k.pth' +model = dict( + type='EncoderDecoderMask2Former', + pretrained=pretrained, + backbone=dict( + type='BEiTAdapter', + img_size=896, + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + use_abs_pos_emb=False, + use_rel_pos_bias=True, + init_values=1e-6, + drop_path_rate=0.3, + conv_inplane=64, + n_points=4, + deform_num_heads=16, + cffn_ratio=0.25, + deform_ratio=0.5, + interaction_indexes=[[0, 5], [6, 11], [12, 17], [18, 23]], + ), + decode_head=dict( + in_channels=[1024, 1024, 1024, 1024], + feat_channels=1024, + out_channels=1024, + num_queries=200, + pixel_decoder=dict( + type='MSDeformAttnPixelDecoder', + num_outs=3, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', + embed_dims=1024, + num_heads=32, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None), + ffn_cfgs=dict( + type='FFN', + embed_dims=1024, + feedforward_channels=4096, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True)), + operation_order=('self_attn', 'norm', 'ffn', 'norm')), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=512, normalize=True), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=512, normalize=True), + transformer_decoder=dict( + type='DetrTransformerDecoder', + return_intermediate=True, + num_layers=9, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=1024, + num_heads=32, + attn_drop=0.0, + proj_drop=0.0, + dropout_layer=None, + batch_first=False), + ffn_cfgs=dict( + embed_dims=1024, + feedforward_channels=4096, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True), + feedforward_channels=4096, + operation_order=('cross_attn', 'norm', 'self_attn', 'norm', + 'ffn', 'norm')), + init_cfg=None) + ), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(512, 512)) +) +# dataset settings +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=(3584, 896), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='ToMask'), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg', 'gt_masks', 'gt_labels']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(3584, 896), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + 
dict(type='ResizeToMultiple', size_divisor=32), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +optimizer = dict(_delete_=True, type='AdamW', lr=2e-5, betas=(0.9, 0.999), weight_decay=0.05, + constructor='LayerDecayOptimizerConstructor', + paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.90)) +lr_config = dict(_delete_=True, + policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-6, + power=1.0, min_lr=0.0, by_epoch=False) +data = dict(samples_per_gpu=1, + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) +runner = dict(type='IterBasedRunner') +checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=1) +evaluation = dict(interval=8000, metric='mIoU', save_best='mIoU') \ No newline at end of file diff --git a/segmentation/configs/ade20k/upernet_beit_adapter_large_640_160k_ade20k_ms.py b/segmentation/configs/ade20k/upernet_beit_adapter_large_640_160k_ade20k_ms.py new file mode 100644 index 000000000..e8ad4efbc --- /dev/null +++ b/segmentation/configs/ade20k/upernet_beit_adapter_large_640_160k_ade20k_ms.py @@ -0,0 +1,88 @@ +# Copyright (c) Shanghai AI Lab. All rights reserved. +_base_ = [ + '../_base_/models/upernet_beit.py', + '../_base_/datasets/ade20k.py', + '../_base_/default_runtime.py', + '../_base_/schedules/schedule_160k.py' +] +crop_size = (640, 640) +# pretrained = 'https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth' +pretrained = 'pretrained/beit_large_patch16_224_pt22k_ft22k.pth' +model = dict( + pretrained=pretrained, + backbone=dict( + type='BEiTAdapter', + img_size=640, + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + use_abs_pos_emb=False, + use_rel_pos_bias=True, + init_values=1e-6, + drop_path_rate=0.3, + conv_inplane=64, + n_points=4, + deform_num_heads=16, + cffn_ratio=0.25, + deform_ratio=0.5, + interaction_indexes=[[0, 5], [6, 11], [12, 17], [18, 23]], + ), + decode_head=dict( + in_channels=[1024, 1024, 1024, 1024], + num_classes=150, + channels=1024, + ), + auxiliary_head=dict( + in_channels=1024, + num_classes=150 + ), + test_cfg = dict(mode='slide', crop_size=crop_size, stride=(426, 426)) +) +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=(2048, 640), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 640), + img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=True, + transforms=[ + dict(type='SETR_Resize', keep_ratio=True, + crop_size=crop_size, setr_multi_scale=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +optimizer = dict(_delete_=True, type='AdamW', lr=2e-5, betas=(0.9, 0.999), weight_decay=0.05, + 
constructor='LayerDecayOptimizerConstructor', + paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.90)) +lr_config = dict(_delete_=True, policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-6, + power=1.0, min_lr=0.0, by_epoch=False) +data=dict(samples_per_gpu=2, + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) +runner = dict(type='IterBasedRunner') +checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=1) +evaluation = dict(interval=16000, metric='mIoU', save_best='mIoU') \ No newline at end of file diff --git a/segmentation/configs/ade20k/upernet_beit_adapter_large_640_160k_ade20k_ss.py b/segmentation/configs/ade20k/upernet_beit_adapter_large_640_160k_ade20k_ss.py new file mode 100644 index 000000000..a79dd5829 --- /dev/null +++ b/segmentation/configs/ade20k/upernet_beit_adapter_large_640_160k_ade20k_ss.py @@ -0,0 +1,88 @@ +# Copyright (c) Shanghai AI Lab. All rights reserved. +_base_ = [ + '../_base_/models/upernet_beit.py', + '../_base_/datasets/ade20k.py', + '../_base_/default_runtime.py', + '../_base_/schedules/schedule_160k.py' +] +crop_size = (640, 640) +# pretrained = 'https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth' +pretrained = 'pretrained/beit_large_patch16_224_pt22k_ft22k.pth' +model = dict( + pretrained=pretrained, + backbone=dict( + type='BEiTAdapter', + img_size=640, + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + use_abs_pos_emb=False, + use_rel_pos_bias=True, + init_values=1e-6, + drop_path_rate=0.3, + conv_inplane=64, + n_points=4, + deform_num_heads=16, + cffn_ratio=0.25, + deform_ratio=0.5, + interaction_indexes=[[0, 5], [6, 11], [12, 17], [18, 23]], + ), + decode_head=dict( + in_channels=[1024, 1024, 1024, 1024], + num_classes=150, + channels=1024, + ), + auxiliary_head=dict( + in_channels=1024, + num_classes=150 + ), + test_cfg = dict(mode='slide', crop_size=crop_size, stride=(426, 426)) +) +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=(2048, 640), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 640), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='ResizeToMultiple', size_divisor=32), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +optimizer = dict(_delete_=True, type='AdamW', lr=2e-5, betas=(0.9, 0.999), weight_decay=0.05, + constructor='LayerDecayOptimizerConstructor', + paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.90)) +lr_config = dict(_delete_=True, policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-6, + power=1.0, min_lr=0.0, by_epoch=False) +data=dict(samples_per_gpu=2, + train=dict(pipeline=train_pipeline), + 
val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) +runner = dict(type='IterBasedRunner') +checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=1) +evaluation = dict(interval=16000, metric='mIoU', save_best='mIoU') \ No newline at end of file diff --git a/segmentation/configs/cityscapes/README.md b/segmentation/configs/cityscapes/README.md new file mode 100644 index 000000000..cc1c8ac0b --- /dev/null +++ b/segmentation/configs/cityscapes/README.md @@ -0,0 +1,21 @@ +# Cityscapes + + + +## Introduction + +Cityscapes is a large-scale database that focuses on semantic understanding of urban street scenes. It provides semantic, instance-wise, and dense pixel annotations for 30 classes grouped into 8 categories (flat surfaces, humans, vehicles, constructions, objects, nature, sky, and void). The dataset consists of around 5000 images with fine annotations and 20000 images with coarse annotations. Data was captured in 50 cities over several months, during daytime and in good weather conditions. It was originally recorded as video, so the frames were manually selected to feature a large number of dynamic objects, varying scene layouts, and varying backgrounds. + +## Results and Models + +Cityscapes val set + +| Method | Backbone | Pre-train | Batch Size | Lr schd | Crop Size | mIoU (SS) | mIoU (MS) | #Param | Config | Download | +|:-----------:|:-------------:|:--------------------------------------------------------------------------------------------------------------------------------:|:----------:|:-------:|:---------:|:------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------:|:------:|:-------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| +| Mask2Former | ViT-Adapter-L | [Mapillary](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.3/mask2former_beit_adapter_large_896_80k_mapillary.zip) | 16x1 | 80k | 896 | [84.9](https://drive.google.com/file/d/1LKy0zz-brCBbKGmUWquadILaBHdDLR6s/view?usp=sharing) | [85.8](https://drive.google.com/file/d/1LSJvK1BPSbzm9eWpKL8Xo7RmYBrd2xux/view?usp=sharing) | 571M | [config](./mask2former_beit_adapter_large_896_80k_cityscapes_ss.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.3/mask2former_beit_adapter_large_896_80k_cityscapes.zip) \| [log](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.3/log.txt) | + +Cityscapes test set + +| Method | Backbone | Pre-train | Batch Size | Lr schd | Crop Size | mIoU (SS) | mIoU (MS) | #Param | Config | Download | +|:-----------:|:-------------:|:--------------------------------------------------------------------------------------------------------------------------------:|:----------:|:-------:|:---------:|:---------:|:---------------------------------------------------------------------------------------------------------------------------------:|:------:|:-------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| +| Mask2Former | ViT-Adapter-L |
[Mapillary](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.3/mask2former_beit_adapter_large_896_80k_mapillary.zip) | 16x1 | 80k | 896 | - | [85.2](https://www.cityscapes-dataset.com/anonymous-results/?id=0ca6821dc3183ff970bd5266f812df2eaa4519ecb1973ca1308d65a3b546bf27) | 571M | [config](./mask2former_beit_adapter_large_896_80k_cityscapes_ss.py) | [model](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.3/mask2former_beit_adapter_large_896_80k_cityscapes.zip) \| [log](https://github.com/czczup/ViT-Adapter/releases/download/v0.2.3/log.txt) | diff --git a/segmentation/configs/cityscapes/mask2former_beit_adapter_large_896_80k_cityscapes_ms.py b/segmentation/configs/cityscapes/mask2former_beit_adapter_large_896_80k_cityscapes_ms.py new file mode 100644 index 000000000..8fad8cbc8 --- /dev/null +++ b/segmentation/configs/cityscapes/mask2former_beit_adapter_large_896_80k_cityscapes_ms.py @@ -0,0 +1,150 @@ +# Copyright (c) Shanghai AI Lab. All rights reserved. +_base_ = [ + '../_base_/models/mask2former_beit_cityscapes.py', + '../_base_/datasets/cityscapes_896x896.py', + '../_base_/default_runtime.py', + '../_base_/schedules/schedule_80k.py' +] +crop_size = (896, 896) +# pretrained = 'https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth' +pretrained = 'pretrained/beit_large_patch16_224_pt22k_ft22k.pth' +model = dict( + type='EncoderDecoderMask2FormerAug', + pretrained=pretrained, + backbone=dict( + type='BEiTAdapter', + img_size=896, + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + use_abs_pos_emb=False, + use_rel_pos_bias=True, + init_values=1e-6, + drop_path_rate=0.3, + conv_inplane=64, + n_points=4, + deform_num_heads=16, + cffn_ratio=0.25, + deform_ratio=0.5, + interaction_indexes=[[0, 5], [6, 11], [12, 17], [18, 23]], + ), + decode_head=dict( + in_channels=[1024, 1024, 1024, 1024], + feat_channels=1024, + out_channels=1024, + num_queries=100, + pixel_decoder=dict( + type='MSDeformAttnPixelDecoder', + num_outs=3, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', + embed_dims=1024, + num_heads=32, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None), + ffn_cfgs=dict( + type='FFN', + embed_dims=1024, + feedforward_channels=4096, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True)), + operation_order=('self_attn', 'norm', 'ffn', 'norm')), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=512, normalize=True), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=512, normalize=True), + transformer_decoder=dict( + type='DetrTransformerDecoder', + return_intermediate=True, + num_layers=9, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=1024, + num_heads=32, + attn_drop=0.0, + proj_drop=0.0, + dropout_layer=None, + batch_first=False), + ffn_cfgs=dict( + embed_dims=1024, + feedforward_channels=4096, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True), + feedforward_channels=4096, + operation_order=('cross_attn', 'norm', 'self_attn', 'norm', + 'ffn', 'norm')), + init_cfg=None) + ), + 
test_cfg=dict(mode='slide', crop_size=crop_size, stride=(512, 512)) +) +# dataset settings +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='ToMask'), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg', 'gt_masks', 'gt_labels']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 1024), + img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], + flip=True, + transforms=[ + dict(type='SETR_Resize', keep_ratio=True, + crop_size=crop_size, setr_multi_scale=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +optimizer = dict(_delete_=True, type='AdamW', lr=2e-5, betas=(0.9, 0.999), weight_decay=0.05, + constructor='LayerDecayOptimizerConstructor', + paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.90)) +lr_config = dict(_delete_=True, + policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-6, + power=1.0, min_lr=0.0, by_epoch=False) +data = dict(samples_per_gpu=1, + train=dict(pipeline=train_pipeline), + test=dict(pipeline=test_pipeline), + val=dict(pipeline=test_pipeline)) +runner = dict(type='IterBasedRunner') +checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=1) +evaluation = dict(interval=1000, metric='mIoU', save_best='mIoU') \ No newline at end of file diff --git a/segmentation/configs/cityscapes/mask2former_beit_adapter_large_896_80k_cityscapes_ss.py b/segmentation/configs/cityscapes/mask2former_beit_adapter_large_896_80k_cityscapes_ss.py new file mode 100644 index 000000000..d65e47e7a --- /dev/null +++ b/segmentation/configs/cityscapes/mask2former_beit_adapter_large_896_80k_cityscapes_ss.py @@ -0,0 +1,131 @@ +# Copyright (c) Shanghai AI Lab. All rights reserved. 
+_base_ = [ + '../_base_/models/mask2former_beit_cityscapes.py', + '../_base_/datasets/cityscapes_896x896.py', + '../_base_/default_runtime.py', + '../_base_/schedules/schedule_80k.py' +] +crop_size = (896, 896) +# pretrained = 'https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth' +pretrained = 'pretrained/beit_large_patch16_224_pt22k_ft22k.pth' +model = dict( + pretrained=pretrained, + backbone=dict( + type='BEiTAdapter', + img_size=896, + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + use_abs_pos_emb=False, + use_rel_pos_bias=True, + init_values=1e-6, + drop_path_rate=0.3, + conv_inplane=64, + n_points=4, + deform_num_heads=16, + cffn_ratio=0.25, + deform_ratio=0.5, + interaction_indexes=[[0, 5], [6, 11], [12, 17], [18, 23]], + ), + decode_head=dict( + in_channels=[1024, 1024, 1024, 1024], + feat_channels=1024, + out_channels=1024, + num_queries=100, + pixel_decoder=dict( + type='MSDeformAttnPixelDecoder', + num_outs=3, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', + embed_dims=1024, + num_heads=32, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None), + ffn_cfgs=dict( + type='FFN', + embed_dims=1024, + feedforward_channels=4096, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True)), + operation_order=('self_attn', 'norm', 'ffn', 'norm')), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=512, normalize=True), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=512, normalize=True), + transformer_decoder=dict( + type='DetrTransformerDecoder', + return_intermediate=True, + num_layers=9, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=1024, + num_heads=32, + attn_drop=0.0, + proj_drop=0.0, + dropout_layer=None, + batch_first=False), + ffn_cfgs=dict( + embed_dims=1024, + feedforward_channels=4096, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True), + feedforward_channels=4096, + operation_order=('cross_attn', 'norm', 'self_attn', 'norm', + 'ffn', 'norm')), + init_cfg=None) + ), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(512, 512)) +) +# dataset settings +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='ToMask'), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg', 'gt_masks', 'gt_labels']) +] +optimizer = dict(_delete_=True, type='AdamW', lr=2e-5, betas=(0.9, 0.999), weight_decay=0.05, + constructor='LayerDecayOptimizerConstructor', + paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.90)) +lr_config = dict(_delete_=True, + policy='poly', + warmup='linear', + warmup_iters=1500, 
+ warmup_ratio=1e-6, + power=1.0, min_lr=0.0, by_epoch=False) +data = dict(samples_per_gpu=1, + train=dict(pipeline=train_pipeline)) +runner = dict(type='IterBasedRunner') +checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=1) +evaluation = dict(interval=1000, metric='mIoU', save_best='mIoU') \ No newline at end of file diff --git a/segmentation/configs/cityscapes/mask2former_beit_adapter_large_896_80k_mapillary_ss.py b/segmentation/configs/cityscapes/mask2former_beit_adapter_large_896_80k_mapillary_ss.py new file mode 100644 index 000000000..e620a1b22 --- /dev/null +++ b/segmentation/configs/cityscapes/mask2former_beit_adapter_large_896_80k_mapillary_ss.py @@ -0,0 +1,152 @@ +# Copyright (c) Shanghai AI Lab. All rights reserved. +_base_ = [ + '../_base_/models/mask2former_beit_cityscapes.py', + '../_base_/datasets/mapillary_896x896.py', + '../_base_/default_runtime.py', + '../_base_/schedules/schedule_80k.py' +] +crop_size = (896, 896) +# pretrained = 'https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth' +pretrained = 'pretrained/beit_large_patch16_224_pt22k_ft22k.pth' +model = dict( + pretrained=pretrained, + backbone=dict( + type='BEiTAdapter', + img_size=896, + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + qkv_bias=True, + use_abs_pos_emb=False, + use_rel_pos_bias=True, + init_values=1e-6, + drop_path_rate=0.3, + conv_inplane=64, + n_points=4, + deform_num_heads=16, + cffn_ratio=0.25, + deform_ratio=0.5, + interaction_indexes=[[0, 5], [6, 11], [12, 17], [18, 23]], + ), + decode_head=dict( + in_channels=[1024, 1024, 1024, 1024], + feat_channels=1024, + out_channels=1024, + num_queries=100, + pixel_decoder=dict( + type='MSDeformAttnPixelDecoder', + num_outs=3, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', + embed_dims=1024, + num_heads=32, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None), + ffn_cfgs=dict( + type='FFN', + embed_dims=1024, + feedforward_channels=4096, + num_fcs=2, + ffn_drop=0.0, + act_cfg=dict(type='ReLU', inplace=True)), + operation_order=('self_attn', 'norm', 'ffn', 'norm')), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=512, normalize=True), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', num_feats=512, normalize=True), + transformer_decoder=dict( + type='DetrTransformerDecoder', + return_intermediate=True, + num_layers=9, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=dict( + type='MultiheadAttention', + embed_dims=1024, + num_heads=32, + attn_drop=0.0, + proj_drop=0.0, + dropout_layer=None, + batch_first=False), + ffn_cfgs=dict( + embed_dims=1024, + feedforward_channels=4096, + num_fcs=2, + act_cfg=dict(type='ReLU', inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True), + feedforward_channels=4096, + operation_order=('cross_attn', 'norm', 'self_attn', 'norm', + 'ffn', 'norm')), + init_cfg=None) + ), + test_cfg=dict(mode='slide', crop_size=crop_size, stride=(512, 512)) +) +# dataset settings +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + 
dict(type='LoadAnnotations'), + dict(type='MapillaryHack'), + dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='ToMask'), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg', 'gt_masks', 'gt_labels']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 1024), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='ResizeToMultiple', size_divisor=32), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +optimizer = dict(_delete_=True, type='AdamW', lr=2e-5, betas=(0.9, 0.999), weight_decay=0.05, + constructor='LayerDecayOptimizerConstructor', + paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.90)) + +lr_config = dict(_delete_=True, + policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-6, + power=1.0, min_lr=0.0, by_epoch=False) + +data = dict(samples_per_gpu=1, + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) +runner = dict(type='IterBasedRunner') +checkpoint_config = dict(by_epoch=False, interval=1000, max_keep_ckpts=1) +evaluation = dict(interval=1000, metric='mIoU', save_best='mIoU') \ No newline at end of file diff --git a/segmentation/dist_test.sh b/segmentation/dist_test.sh new file mode 100755 index 000000000..a84ed9baf --- /dev/null +++ b/segmentation/dist_test.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +CONFIG=$1 +CHECKPOINT=$2 +GPUS=$3 +PORT=${PORT:-29510} +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ + $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} diff --git a/segmentation/dist_train.sh b/segmentation/dist_train.sh new file mode 100755 index 000000000..c0bd5575d --- /dev/null +++ b/segmentation/dist_train.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +CONFIG=$1 +GPUS=$2 +PORT=${PORT:-29300} + +#PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ + $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} diff --git a/segmentation/mmcv_custom/__init__.py b/segmentation/mmcv_custom/__init__.py new file mode 100644 index 000000000..7187bdab6 --- /dev/null +++ b/segmentation/mmcv_custom/__init__.py @@ -0,0 +1,8 @@ +from .checkpoint import load_checkpoint +from .customized_text import CustomizedTextLoggerHook +from .layer_decay_optimizer_constructor import LayerDecayOptimizerConstructor + +__all__ = [ + 'LayerDecayOptimizerConstructor', 'CustomizedTextLoggerHook', + 'load_checkpoint' +] diff --git a/segmentation/mmcv_custom/checkpoint.py b/segmentation/mmcv_custom/checkpoint.py new file mode 100644 index 000000000..c7d13cbba --- /dev/null +++ b/segmentation/mmcv_custom/checkpoint.py @@ -0,0 +1,654 @@ +# Copyright (c) Open-MMLab. All rights reserved. 
+import io +import math +import os +import os.path as osp +import pkgutil +import time +import warnings +from collections import OrderedDict +from importlib import import_module +from tempfile import TemporaryDirectory + +import mmcv +import numpy as np +import torch +import torchvision +from mmcv.fileio import FileClient +from mmcv.fileio import load as load_file +from mmcv.parallel import is_module_wrapper +from mmcv.runner import get_dist_info +from mmcv.utils import mkdir_or_exist +from scipy import interpolate +from torch.nn import functional as F +from torch.optim import Optimizer +from torch.utils import model_zoo + +ENV_MMCV_HOME = 'MMCV_HOME' +ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' +DEFAULT_CACHE_DIR = '~/.cache' + + +def _get_mmcv_home(): + mmcv_home = os.path.expanduser( + os.getenv( + ENV_MMCV_HOME, + os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), + 'mmcv'))) + + mkdir_or_exist(mmcv_home) + return mmcv_home + + +def load_state_dict(module, state_dict, strict=False, logger=None): + """Load state_dict to a module. + + This method is modified from :meth:`torch.nn.Module.load_state_dict`. + Default value for ``strict`` is set to ``False`` and the message for + param mismatch will be shown even if strict is False. + Args: + module (Module): Module that receives the state_dict. + state_dict (OrderedDict): Weights. + strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. + logger (:obj:`logging.Logger`, optional): Logger to log the error + message. If not specified, print function will be used. + """ + unexpected_keys = [] + all_missing_keys = [] + err_msg = [] + + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + # use _load_from_state_dict to enable checkpoint version control + def load(module, prefix=''): + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + local_metadata = {} if metadata is None else metadata.get( + prefix[:-1], {}) + module._load_from_state_dict(state_dict, prefix, local_metadata, True, + all_missing_keys, unexpected_keys, + err_msg) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(module) + load = None # break load->load reference cycle + + # ignore "num_batches_tracked" of BN layers + missing_keys = [ + key for key in all_missing_keys if 'num_batches_tracked' not in key + ] + + if unexpected_keys: + err_msg.append('unexpected key in source ' + f'state_dict: {", ".join(unexpected_keys)}\n') + if missing_keys: + err_msg.append( + f'missing keys in source state_dict: {", ".join(missing_keys)}\n') + + rank, _ = get_dist_info() + if len(err_msg) > 0 and rank == 0: + err_msg.insert( + 0, 'The model and loaded state dict do not match exactly\n') + err_msg = '\n'.join(err_msg) + if strict: + raise RuntimeError(err_msg) + elif logger is not None: + logger.warning(err_msg) + else: + print(err_msg) + + +def load_url_dist(url, model_dir=None, map_location='cpu'): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + if rank == 0: + checkpoint = model_zoo.load_url(url, + model_dir=model_dir, + 
map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + checkpoint = model_zoo.load_url(url, + model_dir=model_dir, + map_location=map_location) + return checkpoint + + +def load_pavimodel_dist(model_path, map_location=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + try: + from pavi import modelcloud + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + if rank == 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load(downloaded_file, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load(downloaded_file, + map_location=map_location) + return checkpoint + + +def load_fileclient_dist(filename, backend, map_location): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + allowed_backends = ['ceph'] + if backend not in allowed_backends: + raise ValueError(f'Load from Backend {backend} is not supported.') + if rank == 0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + return checkpoint + + +def get_torchvision_models(): + model_urls = dict() + for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): + if ispkg: + continue + _zoo = import_module(f'torchvision.models.{name}') + if hasattr(_zoo, 'model_urls'): + _urls = getattr(_zoo, 'model_urls') + model_urls.update(_urls) + return model_urls + + +def get_external_models(): + mmcv_home = _get_mmcv_home() + default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json') + default_urls = load_file(default_json_path) + assert isinstance(default_urls, dict) + external_json_path = osp.join(mmcv_home, 'open_mmlab.json') + if osp.exists(external_json_path): + external_urls = load_file(external_json_path) + assert isinstance(external_urls, dict) + default_urls.update(external_urls) + + return default_urls + + +def get_mmcls_models(): + mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json') + mmcls_urls = load_file(mmcls_json_path) + + return mmcls_urls + + +def get_deprecated_model_names(): + deprecate_json_path = osp.join(mmcv.__path__[0], + 'model_zoo/deprecated.json') + deprecate_urls = load_file(deprecate_json_path) + assert isinstance(deprecate_urls, dict) + + return deprecate_urls + + +def _process_mmcls_checkpoint(checkpoint): + state_dict = checkpoint['state_dict'] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith('backbone.'): + new_state_dict[k[9:]] = v + new_checkpoint = dict(state_dict=new_state_dict) + + return new_checkpoint + + +def _load_checkpoint(filename, map_location=None): + """Load checkpoint from somewhere (modelzoo, file, url). 
+ + Args: + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str | None): Same as :func:`torch.load`. Default: None. + Returns: + dict | OrderedDict: The loaded checkpoint. It can be either an + OrderedDict storing model weights or a dict containing other + information, which depends on the checkpoint. + """ + if filename.startswith('modelzoo://'): + warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' + 'use "torchvision://" instead') + model_urls = get_torchvision_models() + model_name = filename[11:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith('torchvision://'): + model_urls = get_torchvision_models() + model_name = filename[14:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith('open-mmlab://'): + model_urls = get_external_models() + model_name = filename[13:] + deprecated_urls = get_deprecated_model_names() + if model_name in deprecated_urls: + warnings.warn(f'open-mmlab://{model_name} is deprecated in favor ' + f'of open-mmlab://{deprecated_urls[model_name]}') + model_name = deprecated_urls[model_name] + model_url = model_urls[model_name] + # check if is url + if model_url.startswith(('http://', 'https://')): + checkpoint = load_url_dist(model_url) + else: + filename = osp.join(_get_mmcv_home(), model_url) + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + elif filename.startswith('mmcls://'): + model_urls = get_mmcls_models() + model_name = filename[8:] + checkpoint = load_url_dist(model_urls[model_name]) + checkpoint = _process_mmcls_checkpoint(checkpoint) + elif filename.startswith(('http://', 'https://')): + checkpoint = load_url_dist(filename) + elif filename.startswith('pavi://'): + model_path = filename[7:] + checkpoint = load_pavimodel_dist(model_path, map_location=map_location) + elif filename.startswith('s3://'): + checkpoint = load_fileclient_dist(filename, + backend='ceph', + map_location=map_location) + else: + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +def cosine_scheduler(base_value, + final_value, + epochs, + niter_per_ep, + warmup_epochs=0, + start_warmup_value=0, + warmup_steps=-1): + warmup_schedule = np.array([]) + warmup_iters = warmup_epochs * niter_per_ep + if warmup_steps > 0: + warmup_iters = warmup_steps + print('Set warmup steps = %d' % warmup_iters) + if warmup_epochs > 0: + warmup_schedule = np.linspace(start_warmup_value, base_value, + warmup_iters) + + iters = np.arange(epochs * niter_per_ep - warmup_iters) + schedule = np.array([ + final_value + 0.5 * (base_value - final_value) * + (1 + math.cos(math.pi * i / (len(iters)))) for i in iters + ]) + + schedule = np.concatenate((warmup_schedule, schedule)) + + assert len(schedule) == epochs * niter_per_ep + return schedule + + +def load_checkpoint(model, + filename, + map_location='cpu', + strict=False, + logger=None): + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. 
+ logger (:mod:`logging.Logger` or None): The logger for error message. + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + checkpoint = _load_checkpoint(filename, map_location) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + # get state_dict from checkpoint + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + elif 'module' in checkpoint: + state_dict = checkpoint['module'] + else: + state_dict = checkpoint + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # for MoBY, load model of online branch + if sorted(list(state_dict.keys()))[0].startswith('encoder'): + state_dict = { + k.replace('encoder.', ''): v + for k, v in state_dict.items() if k.startswith('encoder.') + } + + # reshape absolute position embedding for Swin + if state_dict.get('absolute_pos_embed') is not None: + absolute_pos_embed = state_dict['absolute_pos_embed'] + N1, L, C1 = absolute_pos_embed.size() + N2, C2, H, W = model.absolute_pos_embed.size() + if N1 != N2 or C1 != C2 or L != H * W: + logger.warning('Error in loading absolute_pos_embed, pass') + else: + state_dict['absolute_pos_embed'] = absolute_pos_embed.view( + N2, H, W, C2).permute(0, 3, 1, 2) + + rank, _ = get_dist_info() + if 'rel_pos_bias.relative_position_bias_table' in state_dict: + if rank == 0: + print( + 'Expand the shared relative position embedding to each layers. ' + ) + num_layers = model.get_num_layers() + rel_pos_bias = state_dict[ + 'rel_pos_bias.relative_position_bias_table'] + for i in range(num_layers): + state_dict['blocks.%d.attn.relative_position_bias_table' % + i] = rel_pos_bias.clone() + + state_dict.pop('rel_pos_bias.relative_position_bias_table') + + all_keys = list(state_dict.keys()) + for key in all_keys: + if 'relative_position_index' in key: + state_dict.pop(key) + + if 'relative_position_bias_table' in key: + rel_pos_bias = state_dict[key] + src_num_pos, num_attn_heads = rel_pos_bias.size() + dst_num_pos, _ = model.state_dict()[key].size() + dst_patch_shape = model.patch_embed.patch_shape + if dst_patch_shape[0] != dst_patch_shape[1]: + raise NotImplementedError() + num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - + 1) * (dst_patch_shape[1] * 2 - 1) + src_size = int((src_num_pos - num_extra_tokens)**0.5) + dst_size = int((dst_num_pos - num_extra_tokens)**0.5) + if src_size != dst_size: + if rank == 0: + print('Position interpolate for %s from %dx%d to %dx%d' % + (key, src_size, src_size, dst_size, dst_size)) + extra_tokens = rel_pos_bias[-num_extra_tokens:, :] + rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] + + def geometric_progression(a, r, n): + return a * (1.0 - r**n) / (1.0 - r) + + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_size // 2) + if gp > dst_size // 2: + right = q + else: + left = q + + # if q > 1.13492: + # q = 1.13492 + + dis = [] + cur = 1 + for i in range(src_size // 2): + dis.append(cur) + cur += q**(i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + y = r_ids + [0] + dis + + t = dst_size // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + if rank == 0: + print('x = {}'.format(x)) + print('dx = {}'.format(dx)) + + all_rel_pos_bias = [] + + for i in range(num_attn_heads): + z = 
rel_pos_bias[:, i].view(src_size, + src_size).float().numpy() + f = interpolate.interp2d(x, y, z, kind='cubic') + all_rel_pos_bias.append( + torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to( + rel_pos_bias.device)) + + rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), + dim=0) + state_dict[key] = new_rel_pos_bias + + if 'pos_embed' in state_dict: + pos_embed_checkpoint = state_dict['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int( + (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5) + # height (== width) for the new position embedding + new_size = int(num_patches**0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + if rank == 0: + print('Position interpolate from %dx%d to %dx%d' % + (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, + embedding_size).permute( + 0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate(pos_tokens, + size=(new_size, + new_size), + mode='bicubic', + align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + state_dict['pos_embed'] = new_pos_embed + + # interpolate position bias table if needed + relative_position_bias_table_keys = [ + k for k in state_dict.keys() if 'relative_position_bias_table' in k + ] + for table_key in relative_position_bias_table_keys: + table_pretrained = state_dict[table_key] + table_current = model.state_dict()[table_key] + L1, nH1 = table_pretrained.size() + L2, nH2 = table_current.size() + if nH1 != nH2: + logger.warning(f'Error in loading {table_key}, pass') + else: + if L1 != L2: + S1 = int(L1**0.5) + S2 = int(L2**0.5) + table_pretrained_resized = F.interpolate( + table_pretrained.permute(1, 0).view(1, nH1, S1, S1), + size=(S2, S2), + mode='bicubic') + state_dict[table_key] = table_pretrained_resized.view( + nH2, L2).permute(1, 0) + + # load state_dict + load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +def weights_to_cpu(state_dict): + """Copy a model state_dict to cpu. + + Args: + state_dict (OrderedDict): Model weights on GPU. + Returns: + OrderedDict: Model weights on GPU. + """ + state_dict_cpu = OrderedDict() + for key, val in state_dict.items(): + state_dict_cpu[key] = val.cpu() + return state_dict_cpu + + +def _save_to_state_dict(module, destination, prefix, keep_vars): + """Saves module state to `destination` dictionary. + + This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. + Args: + module (nn.Module): The module to generate state_dict. + destination (dict): A dict where state will be stored. + prefix (str): The prefix for parameters and buffers used in this + module. 
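+        keep_vars (bool): Whether to keep the variable property of the
+            parameters. Default: False.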
+ """ + for name, param in module._parameters.items(): + if param is not None: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in module._buffers.items(): + # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d + if buf is not None: + destination[prefix + name] = buf if keep_vars else buf.detach() + + +def get_state_dict(module, destination=None, prefix='', keep_vars=False): + """Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + This method is modified from :meth:`torch.nn.Module.state_dict` to + recursively check parallel module in case that the model has a complicated + structure, e.g., nn.Module(nn.Module(DDP)). + Args: + module (nn.Module): The module to generate state_dict. + destination (OrderedDict): Returned dict for the state of the + module. + prefix (str): Prefix of the key. + keep_vars (bool): Whether to keep the variable property of the + parameters. Default: False. + Returns: + dict: A dictionary containing a whole state of the module. + """ + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + + # below is the same as torch.nn.Module.state_dict() + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict( + version=module._version) + _save_to_state_dict(module, destination, prefix, keep_vars) + for name, child in module._modules.items(): + if child is not None: + get_state_dict(child, + destination, + prefix + name + '.', + keep_vars=keep_vars) + for hook in module._state_dict_hooks.values(): + hook_result = hook(module, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + +def save_checkpoint(model, filename, optimizer=None, meta=None): + """Save checkpoint to file. + + The checkpoint will have 3 fields: ``meta``, ``state_dict`` and + ``optimizer``. By default ``meta`` will contain version and time info. + Args: + model (Module): Module whose params are to be saved. + filename (str): Checkpoint filename. + optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. + meta (dict, optional): Metadata to be saved in checkpoint. 
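+
+    Note:
+        If ``filename`` starts with ``pavi://``, the checkpoint is uploaded to
+        pavi modelcloud instead of being written to the local filesystem.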
+ """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError(f'meta must be a dict or None, but got {type(meta)}') + meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) + + if is_module_wrapper(model): + model = model.module + + if hasattr(model, 'CLASSES') and model.CLASSES is not None: + # save class name to the meta + meta.update(CLASSES=model.CLASSES) + + checkpoint = { + 'meta': meta, + 'state_dict': weights_to_cpu(get_state_dict(model)) + } + # save optimizer state dict in the checkpoint + if isinstance(optimizer, Optimizer): + checkpoint['optimizer'] = optimizer.state_dict() + elif isinstance(optimizer, dict): + checkpoint['optimizer'] = {} + for name, optim in optimizer.items(): + checkpoint['optimizer'][name] = optim.state_dict() + + if filename.startswith('pavi://'): + try: + from pavi import modelcloud + from pavi.exception import NodeNotFoundError + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + model_path = filename[7:] + root = modelcloud.Folder() + model_dir, model_name = osp.split(model_path) + try: + model = modelcloud.get(model_dir) + except NodeNotFoundError: + model = root.create_training_model(model_dir) + with TemporaryDirectory() as tmp_dir: + checkpoint_file = osp.join(tmp_dir, model_name) + with open(checkpoint_file, 'wb') as f: + torch.save(checkpoint, f) + f.flush() + model.create_file(checkpoint_file, name=model_name) + else: + mmcv.mkdir_or_exist(osp.dirname(filename)) + # immediately flush buffer + with open(filename, 'wb') as f: + torch.save(checkpoint, f) + f.flush() diff --git a/segmentation/mmcv_custom/customized_text.py b/segmentation/mmcv_custom/customized_text.py new file mode 100644 index 000000000..168984be7 --- /dev/null +++ b/segmentation/mmcv_custom/customized_text.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import datetime +from collections import OrderedDict + +import mmcv +import torch +from mmcv.runner import HOOKS, TextLoggerHook + + +@HOOKS.register_module() +class CustomizedTextLoggerHook(TextLoggerHook): + """Customized Text Logger hook. + + This logger prints out both lr and layer_0_lr. 
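+
+    ``layer_0_lr`` is the minimum learning rate over the parameter groups
+    (the most decayed group under layer-wise lr decay) and ``lr`` is the
+    maximum; both are collected in :meth:`log` and formatted in
+    :meth:`_log_info`.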
+    """
+    def _log_info(self, log_dict, runner):
+        # print exp name for users to distinguish experiments
+        # at every ``interval_exp_name`` iterations and the end of each epoch
+        if runner.meta is not None and 'exp_name' in runner.meta:
+            if (self.every_n_iters(runner, self.interval_exp_name)) or (
+                    self.by_epoch and self.end_of_epoch(runner)):
+                exp_info = f'Exp name: {runner.meta["exp_name"]}'
+                runner.logger.info(exp_info)
+
+        if log_dict['mode'] == 'train':
+            lr_str = {}
+            for lr_type in ['lr', 'layer_0_lr']:
+                if isinstance(log_dict[lr_type], dict):
+                    # per-group lrs: format each group and join into one string
+                    lr_str[lr_type] = []
+                    for k, val in log_dict[lr_type].items():
+                        lr_str[lr_type].append(f'{lr_type}_{k}: {val:.3e}')
+                    lr_str[lr_type] = ' '.join(lr_str[lr_type])
+                else:
+                    lr_str[lr_type] = f'{lr_type}: {log_dict[lr_type]:.3e}'
+
+            # by epoch: Epoch [4][100/1000]
+            # by iter: Iter [100/100000]
+            if self.by_epoch:
+                log_str = f'Epoch [{log_dict["epoch"]}]' \
+                    f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
+            else:
+                log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
+            log_str += f'{lr_str["lr"]}, {lr_str["layer_0_lr"]}, '
+
+            if 'time' in log_dict.keys():
+                self.time_sec_tot += (log_dict['time'] * self.interval)
+                time_sec_avg = self.time_sec_tot / (runner.iter -
+                                                    self.start_iter + 1)
+                eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
+                eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+                log_str += f'eta: {eta_str}, '
+                log_str += f'time: {log_dict["time"]:.3f}, ' \
+                    f'data_time: {log_dict["data_time"]:.3f}, '
+                # statistic memory
+                if torch.cuda.is_available():
+                    log_str += f'memory: {log_dict["memory"]}, '
+        else:
+            # val/test time
+            # here 1000 is the length of the val dataloader
+            # by epoch: Epoch[val] [4][1000]
+            # by iter: Iter[val] [1000]
+            if self.by_epoch:
+                log_str = f'Epoch({log_dict["mode"]}) ' \
+                    f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
+            else:
+                log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'
+
+        log_items = []
+        for name, val in log_dict.items():
+            # TODO: resolve this hack
+            # these items have been in log_str
+            if name in [
+                    'mode', 'Epoch', 'iter', 'lr', 'layer_0_lr', 'time',
+                    'data_time', 'memory', 'epoch'
+            ]:
+                continue
+            if isinstance(val, float):
+                val = f'{val:.4f}'
+            log_items.append(f'{name}: {val}')
+        log_str += ', '.join(log_items)
+
+        runner.logger.info(log_str)
+
+    def log(self, runner):
+        if 'eval_iter_num' in runner.log_buffer.output:
+            # this doesn't modify runner.iter and is regardless of by_epoch
+            cur_iter = runner.log_buffer.output.pop('eval_iter_num')
+        else:
+            cur_iter = self.get_iter(runner, inner_iter=True)
+
+        log_dict = OrderedDict(mode=self.get_mode(runner),
+                               epoch=self.get_epoch(runner),
+                               iter=cur_iter)
+
+        # record lr and layer_0_lr
+        cur_lr = runner.current_lr()
+        if isinstance(cur_lr, list):
+            log_dict['layer_0_lr'] = min(cur_lr)
+            log_dict['lr'] = max(cur_lr)
+        else:
+            assert isinstance(cur_lr, dict)
+            log_dict['lr'], log_dict['layer_0_lr'] = {}, {}
+            for k, lr_ in cur_lr.items():
+                assert isinstance(lr_, list)
+                log_dict['layer_0_lr'].update({k: min(lr_)})
+                log_dict['lr'].update({k: max(lr_)})
+
+        if 'time' in runner.log_buffer.output:
+            # statistic memory
+            if torch.cuda.is_available():
+                log_dict['memory'] = self._get_max_memory(runner)
+
+        log_dict = dict(log_dict, **runner.log_buffer.output)
+
+        self._log_info(log_dict, runner)
+        self._dump_log(log_dict, runner)
+        return log_dict
diff --git a/segmentation/mmcv_custom/layer_decay_optimizer_constructor.py b/segmentation/mmcv_custom/layer_decay_optimizer_constructor.py
new file mode 100644
index
000000000..1b3a7861a --- /dev/null +++ b/segmentation/mmcv_custom/layer_decay_optimizer_constructor.py @@ -0,0 +1,111 @@ +# Copyright (c) ByteDance, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +"""Mostly copy-paste from BEiT library: + +https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/layer_decay_optimizer_constructor.py +""" + +import json + +from mmcv.runner import (OPTIMIZER_BUILDERS, DefaultOptimizerConstructor, + get_dist_info) + + +def get_num_layer_for_vit(var_name, num_max_layer): + if var_name in ('backbone.cls_token', 'backbone.mask_token', + 'backbone.pos_embed'): + return 0 + elif var_name.startswith('backbone.patch_embed'): + return 0 + elif var_name.startswith('decode_head.mask_embed'): + return 0 + elif var_name.startswith('decode_head.cls_embed'): + return 0 + elif var_name.startswith('decode_head.level_embed'): + return 0 + elif var_name.startswith('decode_head.query_embed'): + return 0 + elif var_name.startswith('decode_head.query_feat'): + return 0 + elif var_name.startswith('backbone.blocks'): + layer_id = int(var_name.split('.')[2]) + return layer_id + 1 + + else: + return num_max_layer - 1 + + +@OPTIMIZER_BUILDERS.register_module() +class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor): + def add_params(self, params, module, prefix='', is_dcn_module=None): + """Add all parameters of module to the params list. + + The parameters of the given module will be added to the list of param + groups, with specific rules defined by paramwise_cfg. + Args: + params (list[dict]): A list of param groups, it will be modified + in place. + module (nn.Module): The module to be added. + prefix (str): The prefix of the module + is_dcn_module (int|float|None): If the current module is a + submodule of DCN, `is_dcn_module` will be passed to + control conv_offset layer's learning rate. Defaults to None. + """ + parameter_groups = {} + print(self.paramwise_cfg) + num_layers = self.paramwise_cfg.get('num_layers') + 2 + layer_decay_rate = self.paramwise_cfg.get('layer_decay_rate') + print('Build LayerDecayOptimizerConstructor %f - %d' % + (layer_decay_rate, num_layers)) + weight_decay = self.base_wd + + for name, param in module.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith('.bias') \ + or name in ('pos_embed', 'cls_token'): + # or "relative_position_bias_table" in name: + group_name = 'no_decay' + this_weight_decay = 0. 
+ else: + group_name = 'decay' + this_weight_decay = weight_decay + + layer_id = get_num_layer_for_vit(name, num_layers) + group_name = 'layer_%d_%s' % (layer_id, group_name) + + if group_name not in parameter_groups: + scale = layer_decay_rate**(num_layers - layer_id - 1) + + parameter_groups[group_name] = { + 'weight_decay': this_weight_decay, + 'params': [], + 'param_names': [], + 'lr_scale': scale, + 'group_name': group_name, + 'lr': scale * self.base_lr, + } + + parameter_groups[group_name]['params'].append(param) + parameter_groups[group_name]['param_names'].append(name) + rank, _ = get_dist_info() + if rank == 0: + to_display = {} + for key in parameter_groups: + to_display[key] = { + 'param_names': parameter_groups[key]['param_names'], + 'lr_scale': parameter_groups[key]['lr_scale'], + 'lr': parameter_groups[key]['lr'], + 'weight_decay': parameter_groups[key]['weight_decay'], + } + print('Param groups = %s' % json.dumps(to_display, indent=2)) + + # state_dict = module.state_dict() + # for group_name in parameter_groups: + # group = parameter_groups[group_name] + # for name in group["param_names"]: + # group["params"].append(state_dict[name]) + params.extend(parameter_groups.values()) diff --git a/segmentation/mmseg_custom/__init__.py b/segmentation/mmseg_custom/__init__.py new file mode 100644 index 000000000..e93445c76 --- /dev/null +++ b/segmentation/mmseg_custom/__init__.py @@ -0,0 +1,3 @@ +from .core import * # noqa: F401,F403 +from .datasets import * # noqa: F401,F403 +from .models import * # noqa: F401,F403 diff --git a/segmentation/mmseg_custom/core/__init__.py b/segmentation/mmseg_custom/core/__init__.py new file mode 100644 index 000000000..ea53aaf3b --- /dev/null +++ b/segmentation/mmseg_custom/core/__init__.py @@ -0,0 +1,8 @@ +from mmseg.core.evaluation import * # noqa: F401, F403 +from mmseg.core.seg import * # noqa: F401, F403 + +from .anchor import * # noqa: F401,F403 +from .box import * # noqa: F401,F403 +from .evaluation import * # noqa: F401,F403 +from .mask import * # noqa: F401,F403 +from .utils import * # noqa: F401, F403 diff --git a/segmentation/mmseg_custom/core/anchor/__init__.py b/segmentation/mmseg_custom/core/anchor/__init__.py new file mode 100644 index 000000000..9c8035f3c --- /dev/null +++ b/segmentation/mmseg_custom/core/anchor/__init__.py @@ -0,0 +1 @@ +from .point_generator import MlvlPointGenerator # noqa: F401,F403 diff --git a/segmentation/mmseg_custom/core/anchor/builder.py b/segmentation/mmseg_custom/core/anchor/builder.py new file mode 100644 index 000000000..ddb25ad37 --- /dev/null +++ b/segmentation/mmseg_custom/core/anchor/builder.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +from mmcv.utils import Registry, build_from_cfg + +PRIOR_GENERATORS = Registry('Generator for anchors and points') + +ANCHOR_GENERATORS = PRIOR_GENERATORS + + +def build_prior_generator(cfg, default_args=None): + return build_from_cfg(cfg, PRIOR_GENERATORS, default_args) + + +def build_anchor_generator(cfg, default_args=None): + warnings.warn( + '``build_anchor_generator`` would be deprecated soon, please use ' + '``build_prior_generator`` ') + return build_prior_generator(cfg, default_args=default_args) diff --git a/segmentation/mmseg_custom/core/anchor/point_generator.py b/segmentation/mmseg_custom/core/anchor/point_generator.py new file mode 100644 index 000000000..34dd51a95 --- /dev/null +++ b/segmentation/mmseg_custom/core/anchor/point_generator.py @@ -0,0 +1,260 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
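+# Point (prior) generators for single- and multi-level feature maps, adapted
+# from ``mmdet.core.anchor``; ``MlvlPointGenerator`` is the variant exported
+# from this package.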
+import numpy as np +import torch +from torch.nn.modules.utils import _pair + +from .builder import PRIOR_GENERATORS + + +@PRIOR_GENERATORS.register_module() +class PointGenerator: + def _meshgrid(self, x, y, row_major=True): + xx = x.repeat(len(y)) + yy = y.view(-1, 1).repeat(1, len(x)).view(-1) + if row_major: + return xx, yy + else: + return yy, xx + + def grid_points(self, featmap_size, stride=16, device='cuda'): + feat_h, feat_w = featmap_size + shift_x = torch.arange(0., feat_w, device=device) * stride + shift_y = torch.arange(0., feat_h, device=device) * stride + shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) + stride = shift_x.new_full((shift_xx.shape[0], ), stride) + shifts = torch.stack([shift_xx, shift_yy, stride], dim=-1) + all_points = shifts.to(device) + return all_points + + def valid_flags(self, featmap_size, valid_size, device='cuda'): + feat_h, feat_w = featmap_size + valid_h, valid_w = valid_size + assert valid_h <= feat_h and valid_w <= feat_w + valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device) + valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device) + valid_x[:valid_w] = 1 + valid_y[:valid_h] = 1 + valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) + valid = valid_xx & valid_yy + return valid + + +@PRIOR_GENERATORS.register_module() +class MlvlPointGenerator: + """Standard points generator for multi-level (Mlvl) feature maps in 2D + points-based detectors. + + Args: + strides (list[int] | list[tuple[int, int]]): Strides of anchors + in multiple feature levels in order (w, h). + offset (float): The offset of points, the value is normalized with + corresponding stride. Defaults to 0.5. + """ + def __init__(self, strides, offset=0.5): + self.strides = [_pair(stride) for stride in strides] + self.offset = offset + + @property + def num_levels(self): + """int: number of feature levels that the generator will be applied""" + return len(self.strides) + + @property + def num_base_priors(self): + """list[int]: The number of priors (points) at a point + on the feature grid""" + return [1 for _ in range(len(self.strides))] + + def _meshgrid(self, x, y, row_major=True): + yy, xx = torch.meshgrid(y, x) + if row_major: + # warning .flatten() would cause error in ONNX exporting + # have to use reshape here + return xx.reshape(-1), yy.reshape(-1) + + else: + return yy.reshape(-1), xx.reshape(-1) + + def grid_priors(self, + featmap_sizes, + dtype=torch.float32, + device='cuda', + with_stride=False): + """Generate grid points of multiple feature levels. + + Args: + featmap_sizes (list[tuple]): List of feature map sizes in + multiple feature levels, each size arrange as + as (h, w). + dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32. + device (str): The device where the anchors will be put on. + with_stride (bool): Whether to concatenate the stride to + the last dimension of points. + + Return: + list[torch.Tensor]: Points of multiple feature levels. + The sizes of each tensor should be (N, 2) when with stride is + ``False``, where N = width * height, width and height + are the sizes of the corresponding feature level, + and the last dimension 2 represent (coord_x, coord_y), + otherwise the shape should be (N, 4), + and the last dimension 4 represent + (coord_x, coord_y, stride_w, stride_h). 
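+
+        Example:
+            >>> # minimal sketch with hypothetical strides and feature sizes
+            >>> generator = MlvlPointGenerator(strides=[8, 16])
+            >>> priors = generator.grid_priors([(32, 32), (16, 16)],
+            ...                                device='cpu')
+            >>> [p.shape for p in priors]
+            [torch.Size([1024, 2]), torch.Size([256, 2])]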
+ """ + + assert self.num_levels == len(featmap_sizes) + multi_level_priors = [] + for i in range(self.num_levels): + priors = self.single_level_grid_priors(featmap_sizes[i], + level_idx=i, + dtype=dtype, + device=device, + with_stride=with_stride) + multi_level_priors.append(priors) + return multi_level_priors + + def single_level_grid_priors(self, + featmap_size, + level_idx, + dtype=torch.float32, + device='cuda', + with_stride=False): + """Generate grid Points of a single level. + + Note: + This function is usually called by method ``self.grid_priors``. + + Args: + featmap_size (tuple[int]): Size of the feature maps, arrange as + (h, w). + level_idx (int): The index of corresponding feature map level. + dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32. + device (str, optional): The device the tensor will be put on. + Defaults to 'cuda'. + with_stride (bool): Concatenate the stride to the last dimension + of points. + + Return: + Tensor: Points of single feature levels. + The shape of tensor should be (N, 2) when with stride is + ``False``, where N = width * height, width and height + are the sizes of the corresponding feature level, + and the last dimension 2 represent (coord_x, coord_y), + otherwise the shape should be (N, 4), + and the last dimension 4 represent + (coord_x, coord_y, stride_w, stride_h). + """ + feat_h, feat_w = featmap_size + stride_w, stride_h = self.strides[level_idx] + shift_x = (torch.arange(0, feat_w, device=device) + + self.offset) * stride_w + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + shift_x = shift_x.to(dtype) + + shift_y = (torch.arange(0, feat_h, device=device) + + self.offset) * stride_h + # keep featmap_size as Tensor instead of int, so that we + # can convert to ONNX correctly + shift_y = shift_y.to(dtype) + shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) + if not with_stride: + shifts = torch.stack([shift_xx, shift_yy], dim=-1) + else: + # use `shape[0]` instead of `len(shift_xx)` for ONNX export + stride_w = shift_xx.new_full((shift_xx.shape[0], ), + stride_w).to(dtype) + stride_h = shift_xx.new_full((shift_yy.shape[0], ), + stride_h).to(dtype) + shifts = torch.stack([shift_xx, shift_yy, stride_w, stride_h], + dim=-1) + all_points = shifts.to(device) + return all_points + + def valid_flags(self, featmap_sizes, pad_shape, device='cuda'): + """Generate valid flags of points of multiple feature levels. + + Args: + featmap_sizes (list(tuple)): List of feature map sizes in + multiple feature levels, each size arrange as + as (h, w). + pad_shape (tuple(int)): The padded shape of the image, + arrange as (h, w). + device (str): The device where the anchors will be put on. + + Return: + list(torch.Tensor): Valid flags of points of multiple levels. + """ + assert self.num_levels == len(featmap_sizes) + multi_level_flags = [] + for i in range(self.num_levels): + point_stride = self.strides[i] + feat_h, feat_w = featmap_sizes[i] + h, w = pad_shape[:2] + valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h) + valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w) + flags = self.single_level_valid_flags((feat_h, feat_w), + (valid_feat_h, valid_feat_w), + device=device) + multi_level_flags.append(flags) + return multi_level_flags + + def single_level_valid_flags(self, + featmap_size, + valid_size, + device='cuda'): + """Generate the valid flags of points of a single feature map. + + Args: + featmap_size (tuple[int]): The size of feature maps, arrange as + as (h, w). 
+ valid_size (tuple[int]): The valid size of the feature maps. + The size arrange as as (h, w). + device (str, optional): The device where the flags will be put on. + Defaults to 'cuda'. + + Returns: + torch.Tensor: The valid flags of each points in a single level \ + feature map. + """ + feat_h, feat_w = featmap_size + valid_h, valid_w = valid_size + assert valid_h <= feat_h and valid_w <= feat_w + valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device) + valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device) + valid_x[:valid_w] = 1 + valid_y[:valid_h] = 1 + valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) + valid = valid_xx & valid_yy + return valid + + def sparse_priors(self, + prior_idxs, + featmap_size, + level_idx, + dtype=torch.float32, + device='cuda'): + """Generate sparse points according to the ``prior_idxs``. + + Args: + prior_idxs (Tensor): The index of corresponding anchors + in the feature map. + featmap_size (tuple[int]): feature map size arrange as (w, h). + level_idx (int): The level index of corresponding feature + map. + dtype (obj:`torch.dtype`): Date type of points. Defaults to + ``torch.float32``. + device (obj:`torch.device`): The device where the points is + located. + Returns: + Tensor: Anchor with shape (N, 2), N should be equal to + the length of ``prior_idxs``. And last dimension + 2 represent (coord_x, coord_y). + """ + height, width = featmap_size + x = (prior_idxs % width + self.offset) * self.strides[level_idx][0] + y = ((prior_idxs // width) % height + + self.offset) * self.strides[level_idx][1] + prioris = torch.stack([x, y], 1).to(dtype) + prioris = prioris.to(device) + return prioris diff --git a/segmentation/mmseg_custom/core/box/__init__.py b/segmentation/mmseg_custom/core/box/__init__.py new file mode 100644 index 000000000..032b9e144 --- /dev/null +++ b/segmentation/mmseg_custom/core/box/__init__.py @@ -0,0 +1,2 @@ +from .builder import * # noqa: F401,F403 +from .samplers import MaskPseudoSampler # noqa: F401,F403 diff --git a/segmentation/mmseg_custom/core/box/builder.py b/segmentation/mmseg_custom/core/box/builder.py new file mode 100644 index 000000000..af4b8a835 --- /dev/null +++ b/segmentation/mmseg_custom/core/box/builder.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.utils import Registry, build_from_cfg + +BBOX_SAMPLERS = Registry('bbox_sampler') +BBOX_CODERS = Registry('bbox_coder') + + +def build_sampler(cfg, **default_args): + """Builder of box sampler.""" + return build_from_cfg(cfg, BBOX_SAMPLERS, default_args) + + +def build_bbox_coder(cfg, **default_args): + """Builder of box coder.""" + return build_from_cfg(cfg, BBOX_CODERS, default_args) diff --git a/segmentation/mmseg_custom/core/box/samplers/__init__.py b/segmentation/mmseg_custom/core/box/samplers/__init__.py new file mode 100644 index 000000000..0542ac820 --- /dev/null +++ b/segmentation/mmseg_custom/core/box/samplers/__init__.py @@ -0,0 +1 @@ +from .mask_pseudo_sampler import MaskPseudoSampler # noqa: F401,F403 diff --git a/segmentation/mmseg_custom/core/box/samplers/base_sampler.py b/segmentation/mmseg_custom/core/box/samplers/base_sampler.py new file mode 100644 index 000000000..dee649739 --- /dev/null +++ b/segmentation/mmseg_custom/core/box/samplers/base_sampler.py @@ -0,0 +1,105 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
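+# Abstract sampler interface (adapted from mmdet): subclasses implement
+# ``_sample_pos``/``_sample_neg`` while ``sample`` handles the proposal and
+# ground-truth bookkeeping.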
+from abc import ABCMeta, abstractmethod + +import torch + +from .sampling_result import SamplingResult + + +class BaseSampler(metaclass=ABCMeta): + """Base class of samplers.""" + def __init__(self, + num, + pos_fraction, + neg_pos_ub=-1, + add_gt_as_proposals=True, + **kwargs): + self.num = num + self.pos_fraction = pos_fraction + self.neg_pos_ub = neg_pos_ub + self.add_gt_as_proposals = add_gt_as_proposals + self.pos_sampler = self + self.neg_sampler = self + + @abstractmethod + def _sample_pos(self, assign_result, num_expected, **kwargs): + """Sample positive samples.""" + pass + + @abstractmethod + def _sample_neg(self, assign_result, num_expected, **kwargs): + """Sample negative samples.""" + pass + + def sample(self, + assign_result, + bboxes, + gt_bboxes, + gt_labels=None, + **kwargs): + """Sample positive and negative bboxes. + + This is a simple implementation of bbox sampling given candidates, + assigning results and ground truth bboxes. + + Args: + assign_result (:obj:`AssignResult`): Bbox assigning results. + bboxes (Tensor): Boxes to be sampled from. + gt_bboxes (Tensor): Ground truth bboxes. + gt_labels (Tensor, optional): Class labels of ground truth bboxes. + + Returns: + :obj:`SamplingResult`: Sampling result. + + Example: + >>> from mmdet.core.bbox import RandomSampler + >>> from mmdet.core.bbox import AssignResult + >>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes + >>> rng = ensure_rng(None) + >>> assign_result = AssignResult.random(rng=rng) + >>> bboxes = random_boxes(assign_result.num_preds, rng=rng) + >>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng) + >>> gt_labels = None + >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1, + >>> add_gt_as_proposals=False) + >>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels) + """ + if len(bboxes.shape) < 2: + bboxes = bboxes[None, :] + + bboxes = bboxes[:, :4] + + gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8) + if self.add_gt_as_proposals and len(gt_bboxes) > 0: + if gt_labels is None: + raise ValueError( + 'gt_labels must be given when add_gt_as_proposals is True') + bboxes = torch.cat([gt_bboxes, bboxes], dim=0) + assign_result.add_gt_(gt_labels) + gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8) + gt_flags = torch.cat([gt_ones, gt_flags]) + + num_expected_pos = int(self.num * self.pos_fraction) + pos_inds = self.pos_sampler._sample_pos(assign_result, + num_expected_pos, + bboxes=bboxes, + **kwargs) + # We found that sampled indices have duplicated items occasionally. + # (may be a bug of PyTorch) + pos_inds = pos_inds.unique() + num_sampled_pos = pos_inds.numel() + num_expected_neg = self.num - num_sampled_pos + if self.neg_pos_ub >= 0: + _pos = max(1, num_sampled_pos) + neg_upper_bound = int(self.neg_pos_ub * _pos) + if num_expected_neg > neg_upper_bound: + num_expected_neg = neg_upper_bound + neg_inds = self.neg_sampler._sample_neg(assign_result, + num_expected_neg, + bboxes=bboxes, + **kwargs) + neg_inds = neg_inds.unique() + + sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, + assign_result, gt_flags) + return sampling_result diff --git a/segmentation/mmseg_custom/core/box/samplers/mask_pseudo_sampler.py b/segmentation/mmseg_custom/core/box/samplers/mask_pseudo_sampler.py new file mode 100644 index 000000000..501a2010d --- /dev/null +++ b/segmentation/mmseg_custom/core/box/samplers/mask_pseudo_sampler.py @@ -0,0 +1,43 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+"""copy from +https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py.""" + +import torch + +from ..builder import BBOX_SAMPLERS +from .base_sampler import BaseSampler +from .mask_sampling_result import MaskSamplingResult + + +@BBOX_SAMPLERS.register_module() +class MaskPseudoSampler(BaseSampler): + """A pseudo sampler that does not do sampling actually.""" + def __init__(self, **kwargs): + pass + + def _sample_pos(self, **kwargs): + """Sample positive samples.""" + raise NotImplementedError + + def _sample_neg(self, **kwargs): + """Sample negative samples.""" + raise NotImplementedError + + def sample(self, assign_result, masks, gt_masks, **kwargs): + """Directly returns the positive and negative indices of samples. + + Args: + assign_result (:obj:`AssignResult`): Assigned results + masks (torch.Tensor): Bounding boxes + gt_masks (torch.Tensor): Ground truth boxes + Returns: + :obj:`SamplingResult`: sampler results + """ + pos_inds = torch.nonzero(assign_result.gt_inds > 0, + as_tuple=False).squeeze(-1).unique() + neg_inds = torch.nonzero(assign_result.gt_inds == 0, + as_tuple=False).squeeze(-1).unique() + gt_flags = masks.new_zeros(masks.shape[0], dtype=torch.uint8) + sampling_result = MaskSamplingResult(pos_inds, neg_inds, masks, + gt_masks, assign_result, gt_flags) + return sampling_result diff --git a/segmentation/mmseg_custom/core/box/samplers/mask_sampling_result.py b/segmentation/mmseg_custom/core/box/samplers/mask_sampling_result.py new file mode 100644 index 000000000..f6c500b5b --- /dev/null +++ b/segmentation/mmseg_custom/core/box/samplers/mask_sampling_result.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""copy from +https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py.""" + +import torch + +from .sampling_result import SamplingResult + + +class MaskSamplingResult(SamplingResult): + """Mask sampling result.""" + def __init__(self, pos_inds, neg_inds, masks, gt_masks, assign_result, + gt_flags): + self.pos_inds = pos_inds + self.neg_inds = neg_inds + self.pos_masks = masks[pos_inds] + self.neg_masks = masks[neg_inds] + self.pos_is_gt = gt_flags[pos_inds] + + self.num_gts = gt_masks.shape[0] + self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + + if gt_masks.numel() == 0: + # hack for index error case + assert self.pos_assigned_gt_inds.numel() == 0 + self.pos_gt_masks = torch.empty_like(gt_masks) + else: + self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :] + + if assign_result.labels is not None: + self.pos_gt_labels = assign_result.labels[pos_inds] + else: + self.pos_gt_labels = None + + @property + def masks(self): + """torch.Tensor: concatenated positive and negative boxes""" + return torch.cat([self.pos_masks, self.neg_masks]) + + def __nice__(self): + data = self.info.copy() + data['pos_masks'] = data.pop('pos_masks').shape + data['neg_masks'] = data.pop('neg_masks').shape + parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] + body = ' ' + ',\n '.join(parts) + return '{\n' + body + '\n}' + + @property + def info(self): + """Returns a dictionary of info about the object.""" + return { + 'pos_inds': self.pos_inds, + 'neg_inds': self.neg_inds, + 'pos_masks': self.pos_masks, + 'neg_masks': self.neg_masks, + 'pos_is_gt': self.pos_is_gt, + 'num_gts': self.num_gts, + 'pos_assigned_gt_inds': self.pos_assigned_gt_inds, + } diff --git a/segmentation/mmseg_custom/core/box/samplers/sampling_result.py b/segmentation/mmseg_custom/core/box/samplers/sampling_result.py new file mode 100644 
index 000000000..d1ac5785b --- /dev/null +++ b/segmentation/mmseg_custom/core/box/samplers/sampling_result.py @@ -0,0 +1,150 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmdet.utils import util_mixins + + +class SamplingResult(util_mixins.NiceRepr): + """Bbox sampling result. + + Example: + >>> # xdoctest: +IGNORE_WANT + >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA + >>> self = SamplingResult.random(rng=10) + >>> print(f'self = {self}') + self = + """ + def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, + gt_flags): + self.pos_inds = pos_inds + self.neg_inds = neg_inds + self.pos_bboxes = bboxes[pos_inds] + self.neg_bboxes = bboxes[neg_inds] + self.pos_is_gt = gt_flags[pos_inds] + + self.num_gts = gt_bboxes.shape[0] + self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + + if gt_bboxes.numel() == 0: + # hack for index error case + assert self.pos_assigned_gt_inds.numel() == 0 + self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4) + else: + if len(gt_bboxes.shape) < 2: + gt_bboxes = gt_bboxes.view(-1, 4) + + self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long(), :] + + if assign_result.labels is not None: + self.pos_gt_labels = assign_result.labels[pos_inds] + else: + self.pos_gt_labels = None + + @property + def bboxes(self): + """torch.Tensor: concatenated positive and negative boxes""" + return torch.cat([self.pos_bboxes, self.neg_bboxes]) + + def to(self, device): + """Change the device of the data inplace. + + Example: + >>> self = SamplingResult.random() + >>> print(f'self = {self.to(None)}') + >>> # xdoctest: +REQUIRES(--gpu) + >>> print(f'self = {self.to(0)}') + """ + _dict = self.__dict__ + for key, value in _dict.items(): + if isinstance(value, torch.Tensor): + _dict[key] = value.to(device) + return self + + def __nice__(self): + data = self.info.copy() + data['pos_bboxes'] = data.pop('pos_bboxes').shape + data['neg_bboxes'] = data.pop('neg_bboxes').shape + parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] + body = ' ' + ',\n '.join(parts) + return '{\n' + body + '\n}' + + @property + def info(self): + """Returns a dictionary of info about the object.""" + return { + 'pos_inds': self.pos_inds, + 'neg_inds': self.neg_inds, + 'pos_bboxes': self.pos_bboxes, + 'neg_bboxes': self.neg_bboxes, + 'pos_is_gt': self.pos_is_gt, + 'num_gts': self.num_gts, + 'pos_assigned_gt_inds': self.pos_assigned_gt_inds, + } + + @classmethod + def random(cls, rng=None, **kwargs): + """ + Args: + rng (None | int | numpy.random.RandomState): seed or state. + kwargs (keyword arguments): + - num_preds: number of predicted boxes + - num_gts: number of true boxes + - p_ignore (float): probability of a predicted box assigned to \ + an ignored truth. + - p_assigned (float): probability of a predicted box not being \ + assigned. + - p_use_label (float | bool): with labels or not. + + Returns: + :obj:`SamplingResult`: Randomly generated sampling result. + + Example: + >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA + >>> self = SamplingResult.random() + >>> print(self.__dict__) + """ + from mmdet.core.bbox import demodata + from mmdet.core.bbox.assigners.assign_result import AssignResult + from mmdet.core.bbox.samplers.random_sampler import RandomSampler + rng = demodata.ensure_rng(rng) + + # make probabalistic? 
+ num = 32 + pos_fraction = 0.5 + neg_pos_ub = -1 + + assign_result = AssignResult.random(rng=rng, **kwargs) + + # Note we could just compute an assignment + bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng) + gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng) + + if rng.rand() > 0.2: + # sometimes algorithms squeeze their data, be robust to that + gt_bboxes = gt_bboxes.squeeze() + bboxes = bboxes.squeeze() + + if assign_result.labels is None: + gt_labels = None + else: + gt_labels = None # todo + + if gt_labels is None: + add_gt_as_proposals = False + else: + add_gt_as_proposals = True # make probabalistic? + + sampler = RandomSampler(num, + pos_fraction, + neg_pos_ub=neg_pos_ub, + add_gt_as_proposals=add_gt_as_proposals, + rng=rng) + self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels) + return self diff --git a/segmentation/mmseg_custom/core/evaluation/__init__.py b/segmentation/mmseg_custom/core/evaluation/__init__.py new file mode 100644 index 000000000..3193f018d --- /dev/null +++ b/segmentation/mmseg_custom/core/evaluation/__init__.py @@ -0,0 +1 @@ +from .panoptic_utils import INSTANCE_OFFSET # noqa: F401,F403 diff --git a/segmentation/mmseg_custom/core/evaluation/panoptic_utils.py b/segmentation/mmseg_custom/core/evaluation/panoptic_utils.py new file mode 100644 index 000000000..10c9ad934 --- /dev/null +++ b/segmentation/mmseg_custom/core/evaluation/panoptic_utils.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# A custom value to distinguish instance ID and category ID; need to +# be greater than the number of categories. +# For a pixel in the panoptic result map: +# pan_id = ins_id * INSTANCE_OFFSET + cat_id +INSTANCE_OFFSET = 1000 diff --git a/segmentation/mmseg_custom/core/mask/__init__.py b/segmentation/mmseg_custom/core/mask/__init__.py new file mode 100644 index 000000000..46425d4ab --- /dev/null +++ b/segmentation/mmseg_custom/core/mask/__init__.py @@ -0,0 +1 @@ +from .utils import mask2bbox # noqa: F401,F403 diff --git a/segmentation/mmseg_custom/core/mask/utils.py b/segmentation/mmseg_custom/core/mask/utils.py new file mode 100644 index 000000000..90544b34f --- /dev/null +++ b/segmentation/mmseg_custom/core/mask/utils.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import mmcv +import numpy as np +import pycocotools.mask as mask_util +import torch + + +def split_combined_polys(polys, poly_lens, polys_per_mask): + """Split the combined 1-D polys into masks. + + A mask is represented as a list of polys, and a poly is represented as + a 1-D array. In dataset, all masks are concatenated into a single 1-D + tensor. Here we need to split the tensor into original representations. + + Args: + polys (list): a list (length = image num) of 1-D tensors + poly_lens (list): a list (length = image num) of poly length + polys_per_mask (list): a list (length = image num) of poly number + of each mask + + Returns: + list: a list (length = image num) of list (length = mask num) of \ + list (length = poly num) of numpy array. 
+ """ + mask_polys_list = [] + for img_id in range(len(polys)): + polys_single = polys[img_id] + polys_lens_single = poly_lens[img_id].tolist() + polys_per_mask_single = polys_per_mask[img_id].tolist() + + split_polys = mmcv.slice_list(polys_single, polys_lens_single) + mask_polys = mmcv.slice_list(split_polys, polys_per_mask_single) + mask_polys_list.append(mask_polys) + return mask_polys_list + + +# TODO: move this function to more proper place +def encode_mask_results(mask_results): + """Encode bitmap mask to RLE code. + + Args: + mask_results (list | tuple[list]): bitmap mask results. + In mask scoring rcnn, mask_results is a tuple of (segm_results, + segm_cls_score). + + Returns: + list | tuple: RLE encoded mask. + """ + if isinstance(mask_results, tuple): # mask scoring + cls_segms, cls_mask_scores = mask_results + else: + cls_segms = mask_results + num_classes = len(cls_segms) + encoded_mask_results = [[] for _ in range(num_classes)] + for i in range(len(cls_segms)): + for cls_segm in cls_segms[i]: + encoded_mask_results[i].append( + mask_util.encode( + np.array( + cls_segm[:, :, np.newaxis], order='F', + dtype='uint8'))[0]) # encoded with RLE + if isinstance(mask_results, tuple): + return encoded_mask_results, cls_mask_scores + else: + return encoded_mask_results + + +def mask2bbox(masks): + """Obtain tight bounding boxes of binary masks. + + Args: + masks (Tensor): Binary mask of shape (n, h, w). + + Returns: + Tensor: Bboxe with shape (n, 4) of \ + positive region in binary mask. + """ + N = masks.shape[0] + bboxes = masks.new_zeros((N, 4), dtype=torch.float32) + x_any = torch.any(masks, dim=1) + y_any = torch.any(masks, dim=2) + for i in range(N): + x = torch.where(x_any[i, :])[0] + y = torch.where(y_any[i, :])[0] + if len(x) > 0 and len(y) > 0: + bboxes[i, :] = bboxes.new_tensor( + [x[0], y[0], x[-1] + 1, y[-1] + 1]) + + return bboxes diff --git a/segmentation/mmseg_custom/core/utils/__init__.py b/segmentation/mmseg_custom/core/utils/__init__.py new file mode 100644 index 000000000..26ff24d35 --- /dev/null +++ b/segmentation/mmseg_custom/core/utils/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .dist_utils import (DistOptimizerHook, all_reduce_dict, allreduce_grads, + reduce_mean) +from .misc import add_prefix, multi_apply + +__all__ = [ + 'add_prefix', 'multi_apply', 'DistOptimizerHook', 'allreduce_grads', + 'all_reduce_dict', 'reduce_mean' +] diff --git a/segmentation/mmseg_custom/core/utils/dist_utils.py b/segmentation/mmseg_custom/core/utils/dist_utils.py new file mode 100644 index 000000000..88e519f72 --- /dev/null +++ b/segmentation/mmseg_custom/core/utils/dist_utils.py @@ -0,0 +1,148 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
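+# Distributed-training helpers: coalesced gradient allreduce, tensor mean
+# reduction and allreduce over python dicts of tensors.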
+import functools +import pickle +import warnings +from collections import OrderedDict + +import torch +import torch.distributed as dist +from mmcv.runner import OptimizerHook, get_dist_info +from torch._utils import (_flatten_dense_tensors, _take_tensors, + _unflatten_dense_tensors) + + +def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): + if bucket_size_mb > 0: + bucket_size_bytes = bucket_size_mb * 1024 * 1024 + buckets = _take_tensors(tensors, bucket_size_bytes) + else: + buckets = OrderedDict() + for tensor in tensors: + tp = tensor.type() + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(tensor) + buckets = buckets.values() + + for bucket in buckets: + flat_tensors = _flatten_dense_tensors(bucket) + dist.all_reduce(flat_tensors) + flat_tensors.div_(world_size) + for tensor, synced in zip( + bucket, _unflatten_dense_tensors(flat_tensors, bucket)): + tensor.copy_(synced) + + +def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): + """Allreduce gradients. + + Args: + params (list[torch.Parameters]): List of parameters of a model + coalesce (bool, optional): Whether allreduce parameters as a whole. + Defaults to True. + bucket_size_mb (int, optional): Size of bucket, the unit is MB. + Defaults to -1. + """ + grads = [ + param.grad.data for param in params + if param.requires_grad and param.grad is not None + ] + world_size = dist.get_world_size() + if coalesce: + _allreduce_coalesced(grads, world_size, bucket_size_mb) + else: + for tensor in grads: + dist.all_reduce(tensor.div_(world_size)) + + +class DistOptimizerHook(OptimizerHook): + """Deprecated optimizer hook for distributed training.""" + def __init__(self, *args, **kwargs): + warnings.warn('"DistOptimizerHook" is deprecated, please switch to' + '"mmcv.runner.OptimizerHook".') + super().__init__(*args, **kwargs) + + +def reduce_mean(tensor): + """"Obtain the mean of tensor on different GPUs.""" + if not (dist.is_available() and dist.is_initialized()): + return tensor + tensor = tensor.clone() + dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM) + return tensor + + +def obj2tensor(pyobj, device='cuda'): + """Serialize picklable python object to tensor.""" + storage = torch.ByteStorage.from_buffer(pickle.dumps(pyobj)) + return torch.ByteTensor(storage).to(device=device) + + +def tensor2obj(tensor): + """Deserialize tensor to picklable python object.""" + return pickle.loads(tensor.cpu().numpy().tobytes()) + + +@functools.lru_cache() +def _get_global_gloo_group(): + """Return a process group based on gloo backend, containing all the ranks + The result is cached.""" + if dist.get_backend() == 'nccl': + return dist.new_group(backend='gloo') + else: + return dist.group.WORLD + + +def all_reduce_dict(py_dict, op='sum', group=None, to_float=True): + """Apply all reduce function for python dict object. + + The code is modified from https://github.com/Megvii- + BaseDetection/YOLOX/blob/main/yolox/utils/allreduce_norm.py. + + NOTE: make sure that py_dict in different ranks has the same keys and + the values should be in the same shape. + + Args: + py_dict (dict): Dict to be applied all reduce op. + op (str): Operator, could be 'sum' or 'mean'. Default: 'sum' + group (:obj:`torch.distributed.group`, optional): Distributed group, + Default: None. + to_float (bool): Whether to convert all values of dict to float. + Default: True. + + Returns: + OrderedDict: reduced python dict object. 
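+
+    Example:
+        >>> # minimal sketch; with a single process the dict is returned as-is
+        >>> stats = {'loss': torch.ones(1), 'num_pos': torch.tensor([4.])}
+        >>> stats = all_reduce_dict(stats, op='mean')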
+ """ + _, world_size = get_dist_info() + if world_size == 1: + return py_dict + if group is None: + # TODO: May try not to use gloo in the future + group = _get_global_gloo_group() + if dist.get_world_size(group) == 1: + return py_dict + + # all reduce logic across different devices. + py_key = list(py_dict.keys()) + py_key_tensor = obj2tensor(py_key) + dist.broadcast(py_key_tensor, src=0) + py_key = tensor2obj(py_key_tensor) + + tensor_shapes = [py_dict[k].shape for k in py_key] + tensor_numels = [py_dict[k].numel() for k in py_key] + + if to_float: + flatten_tensor = torch.cat( + [py_dict[k].flatten().float() for k in py_key]) + else: + flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key]) + + dist.all_reduce(flatten_tensor, op=dist.ReduceOp.SUM) + if op == 'mean': + flatten_tensor /= world_size + + split_tensors = [ + x.reshape(shape) for x, shape in zip( + torch.split(flatten_tensor, tensor_numels), tensor_shapes) + ] + return OrderedDict({k: v for k, v in zip(py_key, split_tensors)}) diff --git a/segmentation/mmseg_custom/core/utils/misc.py b/segmentation/mmseg_custom/core/utils/misc.py new file mode 100644 index 000000000..9e161fba7 --- /dev/null +++ b/segmentation/mmseg_custom/core/utils/misc.py @@ -0,0 +1,40 @@ +# Copyright (c) OpenMMLab. All rights reserved. +def multi_apply(func, *args, **kwargs): + """Apply function to a list of arguments. + + Note: + This function applies the ``func`` to multiple inputs and + map the multiple outputs of the ``func`` into different + list. Each list contains the same type of outputs corresponding + to different inputs. + + Args: + func (Function): A function that will be applied to a list of + arguments + + Returns: + tuple(list): A tuple containing multiple list, each list contains \ + a kind of returned results by the function + """ + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + return tuple(map(list, zip(*map_results))) + + +def add_prefix(inputs, prefix): + """Add prefix for dict. + + Args: + inputs (dict): The input dict with str keys. + prefix (str): The prefix to add. + + Returns: + + dict: The dict with keys updated with ``prefix``. + """ + + outputs = dict() + for name, value in inputs.items(): + outputs[f'{prefix}.{name}'] = value + + return outputs diff --git a/segmentation/mmseg_custom/datasets/__init__.py b/segmentation/mmseg_custom/datasets/__init__.py new file mode 100644 index 000000000..1471c0cca --- /dev/null +++ b/segmentation/mmseg_custom/datasets/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .mapillary import MapillaryDataset # noqa: F401,F403 +from .pipelines import * # noqa: F401,F403 diff --git a/segmentation/mmseg_custom/datasets/mapillary.py b/segmentation/mmseg_custom/datasets/mapillary.py new file mode 100644 index 000000000..086be5ebb --- /dev/null +++ b/segmentation/mmseg_custom/datasets/mapillary.py @@ -0,0 +1,46 @@ +from mmseg.datasets.builder import DATASETS +from mmseg.datasets.custom import CustomDataset + + +@DATASETS.register_module() +class MapillaryDataset(CustomDataset): + """Mapillary dataset. 
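+
+    ``CLASSES`` and ``PALETTE`` follow the 66-label Mapillary Vistas taxonomy
+    (65 semantic classes plus ``Unlabeled``); images are read with the
+    ``.jpg`` suffix and segmentation maps with the ``.png`` suffix.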
+ """ + CLASSES = ('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier', + 'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking', 'Pedestrian Area', + 'Rail Track', 'Road', 'Service Lane', 'Sidewalk', 'Bridge', 'Building', 'Tunnel', + 'Person', 'Bicyclist', 'Motorcyclist', 'Other Rider', 'Lane Marking - Crosswalk', + 'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', + 'Water', 'Banner', 'Bench', 'Bike Rack', 'Billboard', 'Catch Basin', 'CCTV Camera', + 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', 'Phone Booth', 'Pothole', + 'Street Light', 'Pole', 'Traffic Sign Frame', 'Utility Pole', 'Traffic Light', + 'Traffic Sign (Back)', 'Traffic Sign (Front)', 'Trash Can', 'Bicycle', 'Boat', + 'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer', + 'Truck', 'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled') + + PALETTE = [[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153], + [180, 165, 180], [90, 120, 150], [ + 102, 102, 156], [128, 64, 255], + [140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96], + [230, 150, 140], [128, 64, 128], [ + 110, 110, 110], [244, 35, 232], + [150, 100, 100], [70, 70, 70], [150, 120, 90], [220, 20, 60], + [255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128], + [255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180], + [190, 255, 255], [152, 251, 152], [107, 142, 35], [0, 170, 30], + [255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220], + [220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40], + [33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150], + [210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80], + [250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20], + [119, 11, 32], [150, 0, 255], [ + 0, 60, 100], [0, 0, 142], [0, 0, 90], + [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110], [0, 0, 70], + [0, 0, 192], [32, 32, 32], [120, 10, 10], [0, 0, 0]] + + def __init__(self, **kwargs): + super(MapillaryDataset, self).__init__( + img_suffix='.jpg', + seg_map_suffix='.png', + reduce_zero_label=False, + **kwargs) \ No newline at end of file diff --git a/segmentation/mmseg_custom/datasets/pipelines/__init__.py b/segmentation/mmseg_custom/datasets/pipelines/__init__.py new file mode 100644 index 000000000..f61008719 --- /dev/null +++ b/segmentation/mmseg_custom/datasets/pipelines/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .formatting import DefaultFormatBundle, ToMask +from .transform import MapillaryHack, PadShortSide, SETR_Resize + +__all__ = [ + 'DefaultFormatBundle', 'ToMask', 'SETR_Resize', 'PadShortSide', + 'MapillaryHack' +] diff --git a/segmentation/mmseg_custom/datasets/pipelines/formatting.py b/segmentation/mmseg_custom/datasets/pipelines/formatting.py new file mode 100644 index 000000000..d1a41ad63 --- /dev/null +++ b/segmentation/mmseg_custom/datasets/pipelines/formatting.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +from mmcv.parallel import DataContainer as DC +from mmseg.datasets.builder import PIPELINES +from mmseg.datasets.pipelines.formatting import to_tensor + + +@PIPELINES.register_module(force=True) +class DefaultFormatBundle(object): + """Default formatting bundle. + + It simplifies the pipeline of formatting common fields, including "img" + and "gt_semantic_seg". These fields are formatted as follows. 
+ + - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) + - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, + (3)to DataContainer (stack=True) + """ + def __call__(self, results): + """Call function to transform and format common fields in results. + + Args: + results (dict): Result dict contains the data to convert. + + Returns: + dict: The result dict contains the data that is formatted with + default bundle. + """ + + if 'img' in results: + img = results['img'] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + img = np.ascontiguousarray(img.transpose(2, 0, 1)) + results['img'] = DC(to_tensor(img), stack=True) + if 'gt_semantic_seg' in results: + # convert to long + results['gt_semantic_seg'] = DC(to_tensor( + results['gt_semantic_seg'][None, ...].astype(np.int64)), + stack=True) + if 'gt_masks' in results: + results['gt_masks'] = DC(to_tensor(results['gt_masks'])) + if 'gt_labels' in results: + results['gt_labels'] = DC(to_tensor(results['gt_labels'])) + + return results + + def __repr__(self): + return self.__class__.__name__ + + +@PIPELINES.register_module() +class ToMask(object): + """Transfer gt_semantic_seg to binary mask and generate gt_labels.""" + def __init__(self, ignore_index=255): + self.ignore_index = ignore_index + + def __call__(self, results): + gt_semantic_seg = results['gt_semantic_seg'] + gt_labels = np.unique(gt_semantic_seg) + # remove ignored region + gt_labels = gt_labels[gt_labels != self.ignore_index] + + gt_masks = [] + for class_id in gt_labels: + gt_masks.append(gt_semantic_seg == class_id) + + if len(gt_masks) == 0: + # Some image does not have annotation (all ignored) + gt_masks = np.empty((0, ) + results['pad_shape'][:-1], dtype=np.int64) + gt_labels = np.empty((0, ), dtype=np.int64) + else: + gt_masks = np.asarray(gt_masks, dtype=np.int64) + gt_labels = np.asarray(gt_labels, dtype=np.int64) + + results['gt_labels'] = gt_labels + results['gt_masks'] = gt_masks + return results + + def __repr__(self): + return self.__class__.__name__ + \ + f'(ignore_index={self.ignore_index})' diff --git a/segmentation/mmseg_custom/datasets/pipelines/transform.py b/segmentation/mmseg_custom/datasets/pipelines/transform.py new file mode 100644 index 000000000..ce570dd38 --- /dev/null +++ b/segmentation/mmseg_custom/datasets/pipelines/transform.py @@ -0,0 +1,350 @@ +import mmcv +import numpy as np +import torch +from mmseg.datasets.builder import PIPELINES + + +@PIPELINES.register_module() +class SETR_Resize(object): + """Resize images & seg. + + This transform resizes the input image to some scale. If the input dict + contains the key "scale", then the scale in the input dict is used, + otherwise the specified scale in the init method is used. + + ``img_scale`` can either be a tuple (single-scale) or a list of tuple + (multi-scale). There are 3 multiscale modes: + + - ``ratio_range is not None``: randomly sample a ratio from the ratio range + and multiply it with the image scale. + + - ``ratio_range is None and multiscale_mode == "range"``: randomly sample a + scale from the a range. + + - ``ratio_range is None and multiscale_mode == "value"``: randomly sample a + scale from multiple scales. + + Args: + img_scale (tuple or list[tuple]): Images scales for resizing. + multiscale_mode (str): Either "range" or "value". + ratio_range (tuple[float]): (min_ratio, max_ratio) + keep_ratio (bool): Whether to keep the aspect ratio when resizing the + image. 
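+        crop_size (tuple, optional): Expected crop size; only used when
+            ``setr_multi_scale`` is True, in which case the short side of
+            the rescaled image is kept no smaller than ``crop_size[0]``.
+        setr_multi_scale (bool): Whether to use the SETR-style resize that
+            enforces the minimum short side described above. Default: False.
+
+    Example:
+        Illustrative pipeline entry (the scales are placeholders):
+
+            dict(type='SETR_Resize', img_scale=(2048, 512), keep_ratio=True,
+                 crop_size=(512, 512), setr_multi_scale=True)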
+ """ + def __init__(self, + img_scale=None, + multiscale_mode='range', + ratio_range=None, + keep_ratio=True, + crop_size=None, + setr_multi_scale=False): + + if img_scale is None: + self.img_scale = None + else: + if isinstance(img_scale, list): + self.img_scale = img_scale + else: + self.img_scale = [img_scale] + # assert mmcv.is_list_of(self.img_scale, tuple) + + if ratio_range is not None: + # mode 1: given a scale and a range of image ratio + assert len(self.img_scale) == 1 + else: + # mode 2: given multiple scales or a range of scales + assert multiscale_mode in ['value', 'range'] + + self.multiscale_mode = multiscale_mode + self.ratio_range = ratio_range + self.keep_ratio = keep_ratio + self.crop_size = crop_size + self.setr_multi_scale = setr_multi_scale + + @staticmethod + def random_select(img_scales): + """Randomly select an img_scale from given candidates. + + Args: + img_scales (list[tuple]): Images scales for selection. + + Returns: + (tuple, int): Returns a tuple ``(img_scale, scale_dix)``, + where ``img_scale`` is the selected image scale and + ``scale_idx`` is the selected index in the given candidates. + """ + + assert mmcv.is_list_of(img_scales, tuple) + scale_idx = np.random.randint(len(img_scales)) + img_scale = img_scales[scale_idx] + return img_scale, scale_idx + + @staticmethod + def random_sample(img_scales): + """Randomly sample an img_scale when ``multiscale_mode=='range'``. + + Args: + img_scales (list[tuple]): Images scale range for sampling. + There must be two tuples in img_scales, which specify the lower + and uper bound of image scales. + + Returns: + (tuple, None): Returns a tuple ``(img_scale, None)``, where + ``img_scale`` is sampled scale and None is just a placeholder + to be consistent with :func:`random_select`. + """ + + assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 + img_scale_long = [max(s) for s in img_scales] + img_scale_short = [min(s) for s in img_scales] + long_edge = np.random.randint( + min(img_scale_long), + max(img_scale_long) + 1) + short_edge = np.random.randint( + min(img_scale_short), + max(img_scale_short) + 1) + img_scale = (long_edge, short_edge) + return img_scale, None + + @staticmethod + def random_sample_ratio(img_scale, ratio_range): + """Randomly sample an img_scale when ``ratio_range`` is specified. + + A ratio will be randomly sampled from the range specified by + ``ratio_range``. Then it would be multiplied with ``img_scale`` to + generate sampled scale. + + Args: + img_scale (tuple): Images scale base to multiply with ratio. + ratio_range (tuple[float]): The minimum and maximum ratio to scale + the ``img_scale``. + + Returns: + (tuple, None): Returns a tuple ``(scale, None)``, where + ``scale`` is sampled ratio multiplied with ``img_scale`` and + None is just a placeholder to be consistent with + :func:`random_select`. + """ + + assert isinstance(img_scale, tuple) and len(img_scale) == 2 + min_ratio, max_ratio = ratio_range + assert min_ratio <= max_ratio + ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio + scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) + return scale, None + + def _random_scale(self, results): + """Randomly sample an img_scale according to ``ratio_range`` and + ``multiscale_mode``. + + If ``ratio_range`` is specified, a ratio will be sampled and be + multiplied with ``img_scale``. + If multiple scales are specified by ``img_scale``, a scale will be + sampled according to ``multiscale_mode``. + Otherwise, single scale will be used. 
+ + Args: + results (dict): Result dict from :obj:`dataset`. + + Returns: + dict: Two new keys 'scale` and 'scale_idx` are added into + ``results``, which would be used by subsequent pipelines. + """ + + if self.ratio_range is not None: + scale, scale_idx = self.random_sample_ratio( + self.img_scale[0], self.ratio_range) + elif len(self.img_scale) == 1: + scale, scale_idx = self.img_scale[0], 0 + elif self.multiscale_mode == 'range': + scale, scale_idx = self.random_sample(self.img_scale) + elif self.multiscale_mode == 'value': + scale, scale_idx = self.random_select(self.img_scale) + else: + raise NotImplementedError + + results['scale'] = scale + results['scale_idx'] = scale_idx + + def _resize_img(self, results): + """Resize images with ``results['scale']``.""" + + if self.keep_ratio: + if self.setr_multi_scale: + if min(results['scale']) < self.crop_size[0]: + new_short = self.crop_size[0] + else: + new_short = min(results['scale']) + + h, w = results['img'].shape[:2] + if h > w: + new_h, new_w = new_short * h / w, new_short + else: + new_h, new_w = new_short, new_short * w / h + results['scale'] = (new_h, new_w) + + img, scale_factor = mmcv.imrescale(results['img'], + results['scale'], + return_scale=True) + # the w_scale and h_scale has minor difference + # a real fix should be done in the mmcv.imrescale in the future + new_h, new_w = img.shape[:2] + h, w = results['img'].shape[:2] + w_scale = new_w / w + h_scale = new_h / h + else: + img, w_scale, h_scale = mmcv.imresize(results['img'], + results['scale'], + return_scale=True) + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], + dtype=np.float32) + results['img'] = img + results['img_shape'] = img.shape + results['pad_shape'] = img.shape # in case that there is no padding + results['scale_factor'] = scale_factor + results['keep_ratio'] = self.keep_ratio + + def _resize_seg(self, results): + """Resize semantic segmentation map with ``results['scale']``.""" + for key in results.get('seg_fields', []): + if self.keep_ratio: + gt_seg = mmcv.imrescale(results[key], + results['scale'], + interpolation='nearest') + else: + gt_seg = mmcv.imresize(results[key], + results['scale'], + interpolation='nearest') + results['gt_semantic_seg'] = gt_seg + + def __call__(self, results): + """Call function to resize images, bounding boxes, masks, semantic + segmentation map. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', + 'keep_ratio' keys are added into result dict. + """ + + if 'scale' not in results: + self._random_scale(results) + self._resize_img(results) + self._resize_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(img_scale={self.img_scale}, ' + f'multiscale_mode={self.multiscale_mode}, ' + f'ratio_range={self.ratio_range}, ' + f'keep_ratio={self.keep_ratio})') + return repr_str + + +@PIPELINES.register_module() +class PadShortSide(object): + """Pad the image & mask. + + Pad to the minimum size that is equal or larger than a number. + Added keys are "pad_shape", "pad_fixed_size", + + Args: + size (int, optional): Fixed padding size. + pad_val (float, optional): Padding value. Default: 0. + seg_pad_val (float, optional): Padding value of segmentation map. + Default: 255. 
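+
+    Example:
+        An illustrative sketch of the behaviour: with ``size=512`` a
+        480x640 image is padded to 512x640, while a 768x1024 image is left
+        untouched because both sides already reach 512. A typical pipeline
+        entry (values are placeholders) is
+        ``dict(type='PadShortSide', size=512, pad_val=0, seg_pad_val=255)``.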
+ """ + def __init__(self, size=None, pad_val=0, seg_pad_val=255): + self.size = size + self.pad_val = pad_val + self.seg_pad_val = seg_pad_val + # only one of size and size_divisor should be valid + assert size is not None + + def _pad_img(self, results): + """Pad images according to ``self.size``.""" + h, w = results['img'].shape[:2] + new_h = max(h, self.size) + new_w = max(w, self.size) + padded_img = mmcv.impad(results['img'], + shape=(new_h, new_w), + pad_val=self.pad_val) + + results['img'] = padded_img + results['pad_shape'] = padded_img.shape + # results['unpad_shape'] = (h, w) + + def _pad_seg(self, results): + """Pad masks according to ``results['pad_shape']``.""" + for key in results.get('seg_fields', []): + results[key] = mmcv.impad(results[key], + shape=results['pad_shape'][:2], + pad_val=self.seg_pad_val) + + def __call__(self, results): + """Call function to pad images, masks, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Updated result dict. + """ + h, w = results['img'].shape[:2] + if h >= self.size and w >= self.size: # 短边比窗口大,跳过 + pass + else: + self._pad_img(results) + self._pad_seg(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(size={self.size}, pad_val={self.pad_val})' + return repr_str + + +@PIPELINES.register_module() +class MapillaryHack(object): + """map MV 65 class to 19 class like Cityscapes.""" + def __init__(self): + self.map = [[13, 24, 41], [2, 15], [17], [6], [3], + [45, 47], [48], [50], [30], [29], [27], [19], [20, 21, 22], + [55], [61], [54], [58], [57], [52]] + + self.others = [i for i in range(66)] + for i in self.map: + for j in i: + if j in self.others: + self.others.remove(j) + + def __call__(self, results): + """Call function to process the image with gamma correction. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Processed results. + """ + gt_map = results['gt_semantic_seg'] + # others -> 255 + new_gt_map = np.zeros_like(gt_map) + + for value in self.others: + new_gt_map[gt_map == value] = 255 + + for index, map in enumerate(self.map): + for value in map: + new_gt_map[gt_map == value] = index + + results['gt_semantic_seg'] = new_gt_map + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str diff --git a/segmentation/mmseg_custom/models/__init__.py b/segmentation/mmseg_custom/models/__init__.py new file mode 100644 index 000000000..fa63eaff9 --- /dev/null +++ b/segmentation/mmseg_custom/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .backbones import * +from .builder import (MASK_ASSIGNERS, MATCH_COST, TRANSFORMER, build_assigner, + build_match_cost) +from .decode_heads import * # noqa: F401,F403 +from .losses import * # noqa: F401,F403 +from .plugins import * +from .segmentors import * # noqa: F401,F403 + +__all__ = [ + 'MASK_ASSIGNERS', 'MATCH_COST', 'TRANSFORMER', 'build_assigner', + 'build_match_cost' +] diff --git a/segmentation/mmseg_custom/models/backbones/__init__.py b/segmentation/mmseg_custom/models/backbones/__init__.py new file mode 100644 index 000000000..9664f38cd --- /dev/null +++ b/segmentation/mmseg_custom/models/backbones/__init__.py @@ -0,0 +1,6 @@ +from .beit_adapter import BEiTAdapter +from .beit_baseline import BEiTBaseline +from .vit_adapter import ViTAdapter +from .vit_baseline import ViTBaseline + +__all__ = ['ViTBaseline', 'ViTAdapter', 'BEiTAdapter', 'BEiTBaseline'] diff --git a/segmentation/mmseg_custom/models/backbones/base/beit.py b/segmentation/mmseg_custom/models/backbones/base/beit.py new file mode 100644 index 000000000..afbecd56c --- /dev/null +++ b/segmentation/mmseg_custom/models/backbones/base/beit.py @@ -0,0 +1,391 @@ +# -------------------------------------------------------- +# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) +# Github source: https://github.com/microsoft/unilm/tree/master/beit +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# By Hangbo Bao +# Based on timm, mmseg, setr, xcit and swin code bases +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/fudan-zvg/SETR +# https://github.com/facebookresearch/xcit/ +# https://github.com/microsoft/Swin-Transformer +# --------------------------------------------------------' +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as cp +from mmcv_custom import load_checkpoint +from mmseg.models.builder import BACKBONES +from mmseg.utils import get_root_logger +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of + residual blocks).""" + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commit this for the original BERT implement + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + window_size=None, + attn_head_dim=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + # NOTE scale factor was wrong in my original version, can 
set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + self.register_buffer("relative_position_index", relative_position_index) + + # trunc_normal_(self.relative_position_bias_table, std=.0) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rel_pos_bias=None): + B, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + # relative_position_bias = relative_position_bias[:, 1:, 1:] + attn = attn + relative_position_bias.unsqueeze(0) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., init_values=None, 
act_layer=nn.GELU, norm_layer=nn.LayerNorm, + window_size=None, attn_head_dim=None, with_cp=False): + super().__init__() + self.with_cp = with_cp + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if init_values is not None: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, H, W, rel_pos_bias=None): + def _inner_forward(x): + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + if self.with_cp and x.requires_grad: + x = cp.checkpoint(_inner_forward, x) + else: + x = _inner_forward(x) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + + x = x.flatten(2).transpose(1, 2) + return x, Hp, Wp + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature + # map for all networks, the feature metadata has reliable channel and stride info, but using + # stride to calc feature dim requires info about padding of each stage that isn't captured. 
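+                # temporarily switch the backbone to eval mode, run a dummy
+                # forward pass to infer the output feature map size and
+                # channel dim, then restore the original training mode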
+ training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class RelativePositionBias(nn.Module): + + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self): + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +@BACKBONES.register_module() +class BEiT(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__(self, img_size=512, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, + depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, + init_values=None, use_checkpoint=False, use_abs_pos_emb=False, use_rel_pos_bias=True, + use_shared_rel_pos_bias=False, pretrained=None, with_cp=False): + super().__init__() + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + self.norm_layer = norm_layer + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.drop_path_rate = drop_path_rate + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, 
embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) + else: + self.rel_pos_bias = None + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.use_checkpoint = use_checkpoint + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, with_cp=with_cp, + init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) + for i in range(depth)]) + + # if self.pos_embed is not None: + # trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + + self.apply(self._init_weights) + + self.init_weights(pretrained) + + # self.fix_init_weight() + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + # pretrained = 'pretrained/beit_large_patch16_512_pt22k_ft22kto1k.pth' + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def get_num_layers(self): + return len(self.blocks) diff --git a/segmentation/mmseg_custom/models/backbones/base/vit.py b/segmentation/mmseg_custom/models/backbones/base/vit.py new file mode 100644 index 000000000..dee3bd299 --- /dev/null +++ b/segmentation/mmseg_custom/models/backbones/base/vit.py @@ -0,0 +1,347 @@ +"""Vision Transformer (ViT) in PyTorch. + +A PyTorch implement of Vision Transformers as described in: + +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 + +`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.10270 + +The official jax code is released and available at https://github.com/google-research/vision_transformer + +DeiT model defs and weights from https://github.com/facebookresearch/deit, +paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +Acknowledgments: +* The paper authors for releasing code and weights, thanks! +* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out +for some einops/einsum fun +* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT +* Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +Hacked together by / Copyright 2021 Ross Wightman +""" +import logging +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.runner import BaseModule, load_checkpoint +from mmseg.utils import get_root_logger +from timm.models.layers import DropPath, Mlp, to_2tuple + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding.""" + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], + img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.shape + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x, H, W + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, H, W): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowedAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., window_size=14, + pad_mode="constant"): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = 
nn.Dropout(proj_drop) + self.window_size = window_size + self.pad_mode = pad_mode + + def forward(self, x, H, W): + B, N, C = x.shape + N_ = self.window_size * self.window_size + H_ = math.ceil(H / self.window_size) * self.window_size + W_ = math.ceil(W / self.window_size) * self.window_size + + qkv = self.qkv(x) # [B, N, C] + qkv = qkv.transpose(1, 2).reshape(B, C * 3, H, W) # [B, C, H, W] + qkv = F.pad(qkv, [0, W_ - W, 0, H_ - H], mode=self.pad_mode) + + qkv = F.unfold(qkv, kernel_size=(self.window_size, self.window_size), + stride=(self.window_size, self.window_size)) + B, C_kw_kw, L = qkv.shape # L - the num of windows + qkv = qkv.reshape(B, C * 3, N_, L).permute(0, 3, 2, 1) # [B, L, N_, C] + qkv = qkv.reshape(B, L, N_, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + # q,k,v [B, L, num_head, N_, C/num_head] + attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_] + # if self.mask: + # attn = attn * mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) # [B, L, num_head, N_, N_] + # attn @ v = [B, L, num_head, N_, C/num_head] + x = (attn @ v).permute(0, 2, 4, 3, 1).reshape(B, C_kw_kw // 3, L) + + x = F.fold(x, output_size=(H_, W_), kernel_size=(self.window_size, self.window_size), + stride=(self.window_size, self.window_size)) # [B, C, H_, W_] + x = x[:, :, :H, :W].reshape(B, C, N).transpose(-1, -2) + x = self.proj(x) + x = self.proj_drop(x) + return x + +# class WindowedAttention(nn.Module): +# def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., window_size=14, pad_mode="constant"): +# super().__init__() +# self.num_heads = num_heads +# head_dim = dim // num_heads +# self.scale = head_dim ** -0.5 +# +# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) +# self.attn_drop = nn.Dropout(attn_drop) +# self.proj = nn.Linear(dim, dim) +# self.proj_drop = nn.Dropout(proj_drop) +# self.window_size = window_size +# self.pad_mode = pad_mode +# +# def forward(self, x, H, W): +# B, N, C = x.shape +# +# N_ = self.window_size * self.window_size +# H_ = math.ceil(H / self.window_size) * self.window_size +# W_ = math.ceil(W / self.window_size) * self.window_size +# x = x.view(B, H, W, C) +# x = F.pad(x, [0, 0, 0, W_ - W, 0, H_- H], mode=self.pad_mode) +# +# x = window_partition(x, window_size=self.window_size)# nW*B, window_size, window_size, C +# x = x.view(-1, N_, C) +# +# qkv = self.qkv(x).view(-1, N_, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) +# q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) +# attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_] +# attn = attn.softmax(dim=-1) +# attn = self.attn_drop(attn) # [B, L, num_head, N_, N_] +# x = (attn @ v).transpose(1, 2).reshape(-1, self.window_size, self.window_size, C) +# +# x = window_reverse(x, self.window_size, H_, W_) +# x = x[:, :H, :W, :].reshape(B, N, C).contiguous() +# x = self.proj(x) +# x = self.proj_drop(x) +# return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, windowed=False, + window_size=14, pad_mode="constant", layer_scale=False): + super().__init__() + self.norm1 = norm_layer(dim) + if windowed: + self.attn = WindowedAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, + proj_drop=drop, window_size=window_size, pad_mode=pad_mode) + else: + self.attn 
= Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.layer_scale = layer_scale + if layer_scale: + self.gamma1 = nn.Parameter(torch.ones((dim)), requires_grad=True) + self.gamma2 = nn.Parameter(torch.ones((dim)), requires_grad=True) + + def forward(self, x, H, W): + if self.layer_scale: + x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.attn(self.norm1(x), H, W)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class TIMMVisionTransformer(BaseModule): + """Vision Transformer. + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + layer_scale=False, + embed_layer=PatchEmbed, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + window_attn=False, + window_size=14, + pretrained=None): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + pretrained: (str): pretrained path + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + self.norm_layer = norm_layer + self.act_layer = act_layer + self.pretrain_size = img_size + self.drop_path_rate = drop_path_rate + self.drop_rate = drop_rate + + window_attn = [window_attn] * depth if not isinstance(window_attn, list) else window_attn + window_size = [window_size] * depth if not isinstance(window_size, list) else window_size + logging.info("window attention:", window_attn) + logging.info("window size:", window_size) + logging.info("layer scale:", layer_scale) + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = 
nn.Sequential(*[ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, + windowed=window_attn[i], window_size=window_size[i], layer_scale=layer_scale) + for i in range(depth)]) + + self.init_weights(pretrained) + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + + def forward_features(self, x): + x, H, W = self.patch_embed(x) + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_token, x), dim=1) + x = self.pos_drop(x + self.pos_embed) + for blk in self.blocks: + x = blk(x, H, W) + x = self.norm(x) + return x + + def forward(self, x): + x = self.forward_features(x) + return x diff --git a/segmentation/mmseg_custom/models/backbones/beit_adapter.py b/segmentation/mmseg_custom/models/backbones/beit_adapter.py new file mode 100644 index 000000000..0d8b4902a --- /dev/null +++ b/segmentation/mmseg_custom/models/backbones/beit_adapter.py @@ -0,0 +1,504 @@ +# Copyright (c) Shanghai AI Lab. All rights reserved. +import logging +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmseg.models.builder import BACKBONES +from ops.modules import MSDeformAttn +from timm.models.layers import DropPath, trunc_normal_ +from torch.nn.init import normal_ + +from .base.beit import BEiT + +_logger = logging.getLogger(__name__) + + +class ConvFFN(nn.Module): + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + n = N // 21 + x1 = x[:, 0:16 * n, :].transpose(1, 2).view(B, C, H * 2, W * 2) + x2 = x[:, 16 * n:20 * n, :].transpose(1, 2).view(B, C, H, W) + x3 = x[:, 20 * n:, :].transpose(1, 2).view(B, C, H // 2, W // 2) + x1 = self.dwconv(x1).flatten(2).transpose(1, 2) + x2 = self.dwconv(x2).flatten(2).transpose(1, 2) + x3 = self.dwconv(x3).flatten(2).transpose(1, 2) + x = torch.cat([x1, x2, x3], dim=1) + return x + + +class Extractor(nn.Module): + def __init__(self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + deform_ratio=1.0, + with_cffn=True, + cffn_ratio=0.25, + drop=0., + drop_path=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6)): + super().__init__() + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MSDeformAttn(d_model=dim, + n_levels=n_levels, + n_heads=num_heads, + n_points=n_points, + ratio=deform_ratio) + self.with_cffn = with_cffn + if with_cffn: + self.ffn = ConvFFN(in_features=dim, + hidden_features=int(dim * cffn_ratio), + drop=drop) + self.ffn_norm = norm_layer(dim) + self.drop_path = DropPath( + drop_path) if 
drop_path > 0. else nn.Identity() + + def forward(self, query, reference_points, feat, spatial_shapes, + level_start_index, H, W): + attn = self.attn(self.query_norm(query), reference_points, + self.feat_norm(feat), spatial_shapes, + level_start_index, None) + query = query + attn + + if self.with_cffn: + query = query + self.drop_path(self.ffn(self.ffn_norm(query), H, W)) + return query + + +class Injector(nn.Module): + def __init__(self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + deform_ratio=1.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + init_values=0.): + super().__init__() + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MSDeformAttn(d_model=dim, + n_levels=n_levels, + n_heads=num_heads, + n_points=n_points, + ratio=deform_ratio) + self.gamma = nn.Parameter(init_values * torch.ones((dim)), + requires_grad=True) + + def forward(self, query, reference_points, feat, spatial_shapes, + level_start_index): + attn = self.attn(self.query_norm(query), reference_points, + self.feat_norm(feat), spatial_shapes, + level_start_index, None) + return query + self.gamma * attn + + +class InteractionBlock(nn.Module): + def __init__(self, + dim, + num_heads=6, + n_points=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0., + drop_path=0., + with_cffn=True, + cffn_ratio=0.25, + init_values=0., + deform_ratio=1.0, + extra_extractor=False): + super().__init__() + + self.injector = Injector(dim=dim, + n_levels=3, + num_heads=num_heads, + init_values=init_values, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio) + self.extractor = Extractor(dim=dim, + n_levels=1, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + deform_ratio=deform_ratio, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + drop=drop, + drop_path=drop_path) + if extra_extractor: + self.extra_extractors = nn.Sequential(*[ + Extractor(dim=dim, + num_heads=num_heads, + n_points=n_points, + norm_layer=norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio) for _ in range(2) + ]) + else: + self.extra_extractors = None + + def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H, W): + x = self.injector(query=x, + reference_points=deform_inputs1[0], + feat=c, + spatial_shapes=deform_inputs1[1], + level_start_index=deform_inputs1[2]) + x = torch.cat((cls, x), dim=1) + for idx, blk in enumerate(blocks): + x = blk(x, H, W) + cls, x = x[:, :1, ], x[:, 1:, ] + c = self.extractor(query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H, + W=W) + if self.extra_extractors is not None: + for extractor in self.extra_extractors: + c = extractor(query=c, + reference_points=deform_inputs2[0], + feat=x, + spatial_shapes=deform_inputs2[1], + level_start_index=deform_inputs2[2], + H=H, + W=W) + return x, c, cls + + +class SpatialPriorModule(nn.Module): + def __init__(self, inplanes=64, embed_dim=384): + super().__init__() + + self.stem = nn.Sequential(*[ + nn.Conv2d( + 3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.Conv2d(inplanes, + inplanes, + kernel_size=3, + stride=1, + padding=1, + bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.Conv2d(inplanes, + inplanes, + kernel_size=3, + stride=1, + padding=1, + bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + ]) + self.conv2 
= nn.Sequential(*[ + nn.Conv2d(inplanes, + 2 * inplanes, + kernel_size=3, + stride=2, + padding=1, + bias=False), + nn.SyncBatchNorm(2 * inplanes), + nn.ReLU(inplace=True) + ]) + self.conv3 = nn.Sequential(*[ + nn.Conv2d(2 * inplanes, + 4 * inplanes, + kernel_size=3, + stride=2, + padding=1, + bias=False), + nn.SyncBatchNorm(4 * inplanes), + nn.ReLU(inplace=True) + ]) + self.conv4 = nn.Sequential(*[ + nn.Conv2d(4 * inplanes, + 4 * inplanes, + kernel_size=3, + stride=2, + padding=1, + bias=False), + nn.SyncBatchNorm(4 * inplanes), + nn.ReLU(inplace=True) + ]) + self.fc1 = nn.Conv2d(inplanes, + embed_dim, + kernel_size=1, + stride=1, + padding=0, + bias=True) + self.fc2 = nn.Conv2d(2 * inplanes, + embed_dim, + kernel_size=1, + stride=1, + padding=0, + bias=True) + self.fc3 = nn.Conv2d(4 * inplanes, + embed_dim, + kernel_size=1, + stride=1, + padding=0, + bias=True) + self.fc4 = nn.Conv2d(4 * inplanes, + embed_dim, + kernel_size=1, + stride=1, + padding=0, + bias=True) + + def forward(self, x): + c1 = self.stem(x) + c2 = self.conv2(c1) + c3 = self.conv3(c2) + c4 = self.conv4(c3) + c1 = self.fc1(c1) + c2 = self.fc2(c2) + c3 = self.fc3(c3) + c4 = self.fc4(c4) + + bs, dim, _, _ = c1.shape + # c1 = c1.view(bs, dim, -1).transpose(1, 2) # 4s + c2 = c2.view(bs, dim, -1).transpose(1, 2) # 8s + c3 = c3.view(bs, dim, -1).transpose(1, 2) # 16s + c4 = c4.view(bs, dim, -1).transpose(1, 2) # 32s + + return c1, c2, c3, c4 + + +@BACKBONES.register_module() +class BEiTAdapter(BEiT): + def __init__(self, + pretrain_size=224, + conv_inplane=64, + n_points=4, + deform_num_heads=6, + init_values=0., + cffn_ratio=0.25, + deform_ratio=1.0, + with_cffn=True, + interaction_indexes=None, + add_vit_feature=True, + *args, + **kwargs): + + super().__init__(init_values=init_values, *args, **kwargs) + + # self.num_classes = 80 + # self.cls_token = None + self.num_block = len(self.blocks) + self.pretrain_size = (pretrain_size, pretrain_size) + self.flags = [ + i for i in range(-1, self.num_block, self.num_block // 4) + ][1:] + self.interaction_indexes = interaction_indexes + self.add_vit_feature = add_vit_feature + embed_dim = self.embed_dim + + self.level_embed = nn.Parameter(torch.zeros(3, embed_dim)) + self.spm = SpatialPriorModule(inplanes=conv_inplane, + embed_dim=embed_dim) + self.interactions = nn.Sequential(*[ + InteractionBlock(dim=embed_dim, + num_heads=deform_num_heads, + n_points=n_points, + init_values=init_values, + drop_path=self.drop_path_rate, + norm_layer=self.norm_layer, + with_cffn=with_cffn, + cffn_ratio=cffn_ratio, + deform_ratio=deform_ratio, + extra_extractor=True if i == + len(interaction_indexes) - 1 else False) + for i in range(len(interaction_indexes)) + ]) + + self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2) + self.norm1 = nn.SyncBatchNorm(embed_dim) + self.norm2 = nn.SyncBatchNorm(embed_dim) + self.norm3 = nn.SyncBatchNorm(embed_dim) + self.norm4 = nn.SyncBatchNorm(embed_dim) + + self.up.apply(self._init_weights) + self.spm.apply(self._init_weights) + self.interactions.apply(self._init_weights) + self.apply(self._init_deform_weights) + normal_(self.level_embed) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + fan_out = m.kernel_size[0] * 
m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def _get_pos_embed(self, pos_embed, H, W): + pos_embed = pos_embed.reshape(1, self.pretrain_size[0] // 16, + self.pretrain_size[1] // 16, + -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate( + pos_embed, size=(H, W), mode='bicubic', align_corners=False).\ + reshape(1, -1, H * W).permute(0, 2, 1) + return pos_embed + + def _init_deform_weights(self, m): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + + def _get_reference_points(self, spatial_shapes, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace(0.5, + H_ - 0.5, + H_, + dtype=torch.float32, + device=device), + torch.linspace(0.5, + W_ - 0.5, + W_, + dtype=torch.float32, + device=device)) + ref_y = ref_y.reshape(-1)[None] / H_ + ref_x = ref_x.reshape(-1)[None] / W_ + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] + return reference_points + + def _deform_inputs(self, x): + bs, c, h, w = x.shape + spatial_shapes = torch.as_tensor([(h // 8, w // 8), (h // 16, w // 16), + (h // 32, w // 32)], + dtype=torch.long, + device=x.device) + level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = self._get_reference_points([(h // 16, w // 16)], + x.device) + deform_inputs1 = [reference_points, spatial_shapes, level_start_index] + + spatial_shapes = torch.as_tensor([(h // 16, w // 16)], + dtype=torch.long, + device=x.device) + level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = self._get_reference_points([(h // 8, w // 8), + (h // 16, w // 16), + (h // 32, w // 32)], + x.device) + deform_inputs2 = [reference_points, spatial_shapes, level_start_index] + + return deform_inputs1, deform_inputs2 + + def _add_level_embed(self, c2, c3, c4): + c2 = c2 + self.level_embed[0] + c3 = c3 + self.level_embed[1] + c4 = c4 + self.level_embed[2] + return c2, c3, c4 + + def forward(self, x): + deform_inputs1, deform_inputs2 = self._deform_inputs(x) + + # SPM forward + c1, c2, c3, c4 = self.spm(x) + c2, c3, c4 = self._add_level_embed(c2, c3, c4) + c = torch.cat([c2, c3, c4], dim=1) + + # Patch Embedding forward + x, H, W = self.patch_embed(x) + bs, n, dim = x.shape + cls = self.cls_token.expand(bs, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + + if self.pos_embed is not None: + pos_embed = self._get_pos_embed(self.pos_embed, H, W) + x = x + pos_embed + x = self.pos_drop(x) + + # Interaction + outs = list() + for i, layer in enumerate(self.interactions): + indexes = self.interaction_indexes[i] + x, c, cls = layer(x, c, cls, self.blocks[indexes[0]:indexes[-1] + 1], + deform_inputs1, deform_inputs2, H, W) + outs.append(x.transpose(1, 2).view(bs, dim, H, W).contiguous()) + + # Split & Reshape + c2 = c[:, 0:c2.size(1), :] + c3 = c[:, c2.size(1):c2.size(1) + c3.size(1), :] + c4 = c[:, c2.size(1) + c3.size(1):, :] + + c2 = c2.transpose(1, 2).view(bs, dim, H * 2, W * 2).contiguous() + c3 = c3.transpose(1, 2).view(bs, dim, H, W).contiguous() + c4 = c4.transpose(1, 2).view(bs, dim, H // 2, W // 2).contiguous() + c1 = self.up(c2) + c1 + + if self.add_vit_feature: + x1, x2, x3, x4 = outs + x1 = F.interpolate(x1, + scale_factor=4, + 
mode='bilinear', + align_corners=False) + x2 = F.interpolate(x2, + scale_factor=2, + mode='bilinear', + align_corners=False) + x4 = F.interpolate(x4, + scale_factor=0.5, + mode='bilinear', + align_corners=False) + c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4 + + # Final Norm + f1 = self.norm1(c1) + f2 = self.norm2(c2) + f3 = self.norm3(c3) + f4 = self.norm4(c4) + return [f1, f2, f3, f4] diff --git a/segmentation/mmseg_custom/models/backbones/beit_baseline.py b/segmentation/mmseg_custom/models/backbones/beit_baseline.py new file mode 100644 index 000000000..88df44df0 --- /dev/null +++ b/segmentation/mmseg_custom/models/backbones/beit_baseline.py @@ -0,0 +1,460 @@ +# -------------------------------------------------------- +# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) +# Github source: https://github.com/microsoft/unilm/tree/master/beit +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# By Hangbo Bao +# Based on timm, mmseg, setr, xcit and swin code bases +# https://github.com/rwightman/pytorch-image-models/tree/master/timm +# https://github.com/fudan-zvg/SETR +# https://github.com/facebookresearch/xcit/ +# https://github.com/microsoft/Swin-Transformer +# --------------------------------------------------------' +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from mmcv_custom import load_checkpoint +from mmseg.models.builder import BACKBONES +from mmseg.utils import get_root_logger +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of + residual blocks).""" + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + # x = self.drop(x) + # commit this for the original BERT implement + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + window_size=None, + attn_head_dim=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + 
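+            # Illustrative note (not in the original BEiT code): for a 14x14 patch
+            # grid (img_size=224, patch_size=16) the table below holds
+            # (2*14 - 1) * (2*14 - 1) + 3 = 729 + 3 = 732 bias rows per head; the
+            # three extra rows cover the cls->token, token->cls and cls->cls biases.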
self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + # trunc_normal_(self.relative_position_bias_table, std=.0) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rel_pos_bias=None): + B, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + if self.relative_position_bias_table is not None: + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, + window_size=None, attn_head_dim=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if init_values is not None: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x, rel_pos_bias=None): + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, **kwargs): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + Hp, Wp = x.shape[2], x.shape[3] + + x = x.flatten(2).transpose(1, 2) + return x, (Hp, Wp) + + +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + + def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + self.img_size = img_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature + # map for all networks, the feature metadata has reliable channel and stride info, but using + # stride to calc feature dim requires info about padding of each stage that isn't captured. 
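+                # The dummy forward below runs with the backbone temporarily in
+                # eval() mode so BatchNorm statistics are not updated; the previous
+                # train/eval state is restored afterwards via backbone.train(training).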
+ training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + feature_dim = self.backbone.feature_info.channels()[-1] + self.num_patches = feature_size[0] * feature_size[1] + self.proj = nn.Linear(feature_dim, embed_dim) + + def forward(self, x): + x = self.backbone(x)[-1] + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class RelativePositionBias(nn.Module): + + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = \ + torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + # trunc_normal_(self.relative_position_bias_table, std=.02) + + def forward(self): + relative_position_bias = \ + self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +@BACKBONES.register_module() +class BEiTBaseline(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., hybrid_backbone=None, norm_layer=None, init_values=None, use_checkpoint=False, + use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, + out_indices=[3, 5, 7, 11]): + super().__init__() + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + if hybrid_backbone is not None: + self.patch_embed = HybridEmbed( + hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + else: + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + 
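+        # e.g. img_size=224, patch_size=16 -> num_patches = (224 // 16) ** 2 = 196;
+        # the absolute pos_embed below therefore has num_patches + 1 rows
+        # (one extra row for the cls token).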
self.out_indices = out_indices + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if use_abs_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + else: + self.pos_embed = None + self.pos_drop = nn.Dropout(p=drop_rate) + + if use_shared_rel_pos_bias: + self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) + else: + self.rel_pos_bias = None + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.use_rel_pos_bias = use_rel_pos_bias + self.use_checkpoint = use_checkpoint + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) + for i in range(depth)]) + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + # trunc_normal_(self.mask_token, std=.02) + self.out_indices = out_indices + + if patch_size == 16: + self.fpn1 = nn.Sequential( + nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), + nn.SyncBatchNorm(embed_dim), + nn.GELU(), + nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), + ) + + self.fpn2 = nn.Sequential( + nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), + ) + + self.fpn3 = nn.Identity() + + self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) + elif patch_size == 8: + self.fpn1 = nn.Sequential( + nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), + ) + + self.fpn2 = nn.Identity() + + self.fpn3 = nn.Sequential( + nn.MaxPool2d(kernel_size=2, stride=2), + ) + + self.fpn4 = nn.Sequential( + nn.MaxPool2d(kernel_size=4, stride=4), + ) + self.apply(self._init_weights) + self.fix_init_weight() + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. 
+ """ + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def get_num_layers(self): + return len(self.blocks) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward_features(self, x): + B, C, H, W = x.shape + x, (Hp, Wp) = self.patch_embed(x) + batch_size, seq_len, _ = x.size() + + cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + if self.pos_embed is not None: + x = x + self.pos_embed + x = self.pos_drop(x) + + rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None + features = [] + for i, blk in enumerate(self.blocks): + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, rel_pos_bias) + else: + x = blk(x, rel_pos_bias) + if i in self.out_indices: + xp = x[:, 1:, :].permute(0, 2, 1).reshape(B, -1, Hp, Wp) + features.append(xp.contiguous()) + + ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4] + for i in range(len(features)): + features[i] = ops[i](features[i]) + + return tuple(features) + + def forward(self, x): + x = self.forward_features(x) + return x diff --git a/segmentation/mmseg_custom/models/backbones/vit_adapter.py b/segmentation/mmseg_custom/models/backbones/vit_adapter.py new file mode 100644 index 000000000..09e98ea2c --- /dev/null +++ b/segmentation/mmseg_custom/models/backbones/vit_adapter.py @@ -0,0 +1,397 @@ +import logging +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from timm.models.layers import trunc_normal_, DropPath +from torch.nn.init import normal_ +from .base.vit import TIMMVisionTransformer +from mmseg.models.builder import BACKBONES +from ops.modules import MSDeformAttn + +_logger = logging.getLogger(__name__) + + +class ConvFFN(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x, H, W): + x = self.fc1(x) + x = self.dwconv(x, H, W) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x, H, W): + B, N, C = x.shape + n = N // 21 + x1 = x[:, 0:16 * n, :].transpose(1, 2).view(B, C, H * 2, W * 2) + x2 = x[:, 16 * n:20 * n, :].transpose(1, 2).view(B, C, H, W) + x3 = x[:, 20 * n:, :].transpose(1, 2).view(B, C, H // 2, W // 2) + x1 = self.dwconv(x1).flatten(2).transpose(1, 2) + x2 = self.dwconv(x2).flatten(2).transpose(1, 2) + x3 = self.dwconv(x3).flatten(2).transpose(1, 2) + x = torch.cat([x1, x2, x3], dim=1) + return x + + 
+class ExtractLayer(nn.Module): + + def __init__(self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + ratio=1.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6)): + super().__init__() + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MSDeformAttn(d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=ratio) + + def forward(self, query, reference_points, feat, spatial_shapes, level_start_index): + attn = self.attn(self.query_norm(query), reference_points, self.feat_norm(feat), + spatial_shapes, level_start_index, None) + query = query + attn + return query + + +class InsertLayer(nn.Module): + + def __init__(self, + dim, + num_heads=6, + n_points=4, + n_levels=1, + ratio=1.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + init_values=0.): + super().__init__() + self.query_norm = norm_layer(dim) + self.feat_norm = norm_layer(dim) + self.attn = MSDeformAttn(d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=ratio) + self.gamma = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + + def forward(self, query, reference_points, feat, spatial_shapes, level_start_index): + attn = self.attn(self.query_norm(query), reference_points, self.feat_norm(feat), + spatial_shapes, level_start_index, None) + return query + self.gamma * attn + + +class InteractBlock(nn.Module): + + def __init__(self, dim, + num_heads=6, + n_points=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0., + drop_path=0., + with_ffn=True, + ffn_ratio=0.25, + init_values=0., + extract_deform_ratio=1.0, + insert_deform_ratio=1.0): + super().__init__() + self.extract = ExtractLayer(dim=dim, n_levels=1, num_heads=num_heads, + n_points=n_points, norm_layer=norm_layer, ratio=extract_deform_ratio) + self.insert = InsertLayer(dim=dim, n_levels=3, num_heads=num_heads, init_values=init_values, + n_points=n_points, norm_layer=norm_layer, ratio=insert_deform_ratio) + + self.with_ffn = with_ffn + if with_ffn: + self.ffn = ConvFFN(in_features=dim, hidden_features=int(dim * ffn_ratio), drop=drop) + self.ffn_norm = norm_layer(dim) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x, c, blocks, deform_pkg1, deform_pkg2, H, W): + x = self.insert(x, deform_pkg1[0], c, deform_pkg1[1], deform_pkg1[2]) + for idx, blk in enumerate(blocks): + x = blk(x, H, W) + c = self.extract(c, deform_pkg2[0], x, deform_pkg2[1], deform_pkg2[2]) + if self.with_ffn: + c = c + self.drop_path(self.ffn(self.ffn_norm(c), H, W)) + return x, c + + +class ExtractBlock(nn.Module): + + def __init__(self, dim, + num_heads=6, + n_points=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + drop=0., + drop_path=0., + with_ffn=True, + ffn_ratio=0.25, + deform_ratio=1.0): + super().__init__() + self.extract = ExtractLayer(dim=dim, n_levels=1, num_heads=num_heads, + n_points=n_points, norm_layer=norm_layer, ratio=deform_ratio) + + self.with_ffn = with_ffn + if with_ffn: + self.ffn = ConvFFN(in_features=dim, hidden_features=int(dim * ffn_ratio), drop=drop) + self.ffn_norm = norm_layer(dim) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, c, x, deform_pkg, H, W): + + c = self.extract(c, deform_pkg[0], x, deform_pkg[1], deform_pkg[2]) + if self.with_ffn: + c = c + self.drop_path(self.ffn(self.ffn_norm(c), H, W)) + return c + + +class ConvBranch(nn.Module): + def __init__(self, inplanes=64, embed_dim=384): + super(ConvBranch, self).__init__() + + self.stem = nn.Sequential(*[ + nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False), + nn.SyncBatchNorm(inplanes), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + ]) + self.conv2 = nn.Sequential(*[ + nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(2 * inplanes), + nn.ReLU(inplace=True) + ]) + self.conv3 = nn.Sequential(*[ + nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(4 * inplanes), + nn.ReLU(inplace=True) + ]) + self.conv4 = nn.Sequential(*[ + nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False), + nn.SyncBatchNorm(4 * inplanes), + nn.ReLU(inplace=True) + ]) + self.fc1 = nn.Conv2d(inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + self.fc2 = nn.Conv2d(2 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + self.fc3 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + self.fc4 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) + + def forward(self, x): + c1 = self.stem(x) + c2 = self.conv2(c1) + c3 = self.conv3(c2) + c4 = self.conv4(c3) + c1 = self.fc1(c1) + c2 = self.fc2(c2) + c3 = self.fc3(c3) + c4 = self.fc4(c4) + + bs, dim, _, _ = c1.shape + # c1 = c1.view(bs, dim, -1).transpose(1, 2) # 4s + c2 = c2.view(bs, dim, -1).transpose(1, 2) # 8s + c3 = c3.view(bs, dim, -1).transpose(1, 2) # 16s + c4 = c4.view(bs, dim, -1).transpose(1, 2) # 32s + + return c1, c2, c3, c4 + + +@BACKBONES.register_module() +class ViTAdapter(TIMMVisionTransformer): + + def __init__(self, + pretrain_size=224, + num_heads=12, + conv_inplane=64, + n_points=4, + deform_num_heads=6, + init_values=0., + interaction_indexes=None, + interact_with_ffn=False, + cffn_ratio=0.25, + num_extract_block=3, + extract_with_ffn=True, + extract_ffn_ratio=0.25, + add_vit_feature=True, + extract_deform_ratio=1.0, + deform_ratio=1.0, + pretrained=None, + *args, + **kwargs): + + super().__init__(num_heads=num_heads, pretrained=pretrained, + *args, **kwargs) + + self.num_classes = 80 + self.cls_token = None + self.num_block = len(self.blocks) + self.pretrain_size = (pretrain_size, pretrain_size) + self.flags = [i for i in range(-1, self.num_block, self.num_block // 4)][1:] + self.interaction_indexes = interaction_indexes + self.add_vit_feature = add_vit_feature + embed_dim = self.embed_dim + + self.level_embed = nn.Parameter(torch.zeros(3, embed_dim)) + self.conv_branch = ConvBranch(inplanes=conv_inplane, embed_dim=embed_dim) + self.interact_blocks = nn.Sequential(*[ + InteractBlock(dim=embed_dim, + num_heads=deform_num_heads, + n_points=n_points, + init_values=init_values, + drop_path=self.drop_path_rate, + norm_layer=self.norm_layer, + with_ffn=interact_with_ffn, + ffn_ratio=cffn_ratio, + extract_deform_ratio=deform_ratio, + 
insert_deform_ratio=deform_ratio + ) for _ in range(len(interaction_indexes))]) + self.extract_blocks = nn.Sequential(*[ + ExtractBlock(dim=embed_dim, + num_heads=deform_num_heads, + n_points=n_points, + norm_layer=self.norm_layer, + with_ffn=extract_with_ffn, + ffn_ratio=extract_ffn_ratio, + deform_ratio=extract_deform_ratio + ) for _ in range(num_extract_block) + ]) + self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2) + self.norm1 = nn.SyncBatchNorm(embed_dim) + self.norm2 = nn.SyncBatchNorm(embed_dim) + self.norm3 = nn.SyncBatchNorm(embed_dim) + self.norm4 = nn.SyncBatchNorm(embed_dim) + + self.up.apply(self._init_weights) + self.conv_branch.apply(self._init_weights) + self.interact_blocks.apply(self._init_weights) + self.extract_blocks.apply(self._init_weights) + self.apply(self._init_deform_weights) + normal_(self.level_embed) + + self.init_weights(pretrained) + + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def _get_pos_embed(self, pos_embed, H, W): + pos_embed = pos_embed.reshape(1, self.pretrain_size[0] // 16, + self.pretrain_size[1] // 16, -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate(pos_embed, size=(H, W), mode="bicubic", + align_corners=False).reshape(1, -1, H * W).permute(0, 2, 1) + return pos_embed + + def _init_deform_weights(self, m): + if isinstance(m, MSDeformAttn): + m._reset_parameters() + + def _get_reference_points(self, spatial_shapes, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) + ref_y = ref_y.reshape(-1)[None] / H_ + ref_x = ref_x.reshape(-1)[None] / W_ + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] + return reference_points + + def forward_deform_pkgs(self, x): + bs, c, h, w = x.shape + spatial_shapes = torch.as_tensor([(h // 8, w // 8), + (h // 16, w // 16), + (h // 32, w // 32)], dtype=torch.long).cuda() + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = self._get_reference_points([(h // 16, w // 16)], "cuda").cuda() + deform_pkg1 = [reference_points, spatial_shapes, level_start_index] + + spatial_shapes = torch.as_tensor([(h // 16, w // 16)], dtype=torch.long).cuda() + level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = self._get_reference_points([(h // 8, w // 8), + (h // 16, w // 16), + (h // 32, w // 32)], "cuda").cuda() + deform_pkg2 = [reference_points, spatial_shapes, level_start_index] + + return deform_pkg1, deform_pkg2 + + def add_level_embed(self, c2, c3, c4): + c2 = c2 + self.level_embed[0] + c3 = c3 + self.level_embed[1] + c4 = c4 + self.level_embed[2] + return c2, c3, c4 + + def forward(self, x): + deform_pkg1, 
deform_pkg2 = self.forward_deform_pkgs(x) + + c1, c2, c3, c4 = self.conv_branch(x) + c2, c3, c4 = self.add_level_embed(c2, c3, c4) + c = torch.cat([c2, c3, c4], dim=1) + + x, H, W = self.patch_embed(x) + bs, n, dim = x.shape + pos_embed = self._get_pos_embed(self.pos_embed[:, 1:], H, W) + x = self.pos_drop(x + pos_embed) + + outs = [] + for i, layer in enumerate(self.interact_blocks): + indexes = self.interaction_indexes[i] + x, c = layer(x, c, self.blocks[indexes[0]:indexes[-1] + 1], + deform_pkg1, deform_pkg2, H, W) + outs.append(x.transpose(1, 2).view(bs, dim, H, W).contiguous()) + + for extract_block in self.extract_blocks: + c = extract_block(c, x, deform_pkg2, H, W) + + c2 = c[:, 0: c2.size(1), :] + c3 = c[:, c2.size(1): c2.size(1) + c3.size(1), :] + c4 = c[:, c2.size(1) + c3.size(1):, :] + + c2 = c2.transpose(1, 2).view(bs, dim, H * 2, W * 2).contiguous() + c3 = c3.transpose(1, 2).view(bs, dim, H, W).contiguous() + c4 = c4.transpose(1, 2).view(bs, dim, H // 2, W // 2).contiguous() + c1 = self.up(c2) + c1 + + if self.add_vit_feature: + x1, x2, x3, x4 = outs + x1 = F.interpolate(x1, scale_factor=4, mode='bilinear', align_corners=False) + x2 = F.interpolate(x2, scale_factor=2, mode='bilinear', align_corners=False) + x4 = F.interpolate(x4, scale_factor=0.5, mode='bilinear', align_corners=False) + c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4 + + f1 = self.norm1(c1) + f2 = self.norm2(c2) + f3 = self.norm3(c3) + f4 = self.norm4(c4) + return [f1, f2, f3, f4] diff --git a/segmentation/mmseg_custom/models/backbones/vit_baseline.py b/segmentation/mmseg_custom/models/backbones/vit_baseline.py new file mode 100644 index 000000000..8cc38b298 --- /dev/null +++ b/segmentation/mmseg_custom/models/backbones/vit_baseline.py @@ -0,0 +1,97 @@ +import logging +import torch.nn as nn +import torch.nn.functional as F +import math +from timm.models.layers import trunc_normal_ +from .base.vit import TIMMVisionTransformer +from mmseg.utils import get_root_logger +from mmcv.runner import load_checkpoint +from mmseg.models.builder import BACKBONES + +_logger = logging.getLogger(__name__) + +@BACKBONES.register_module() +class ViTBaseline(TIMMVisionTransformer): + + def __init__(self, pretrain_size=224, *args, **kwargs): + + super().__init__(*args, **kwargs) + self.num_classes = 80 + self.cls_token = None + self.num_block = len(self.blocks) + self.pretrain_size = (pretrain_size, pretrain_size) + self.flags = [i for i in range(-1, self.num_block, self.num_block // 4)][1:] + + embed_dim = self.embed_dim + self.norm1 = self.norm_layer(embed_dim) + self.norm2 = self.norm_layer(embed_dim) + self.norm3 = self.norm_layer(embed_dim) + self.norm4 = self.norm_layer(embed_dim) + + self.up1 = nn.Sequential(*[ + nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2), + nn.GroupNorm(32, embed_dim), + nn.GELU(), + nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2) + ]) + self.up2 = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2) + self.up3 = nn.Identity() + self.up4 = nn.MaxPool2d(kernel_size=2, stride=2) + + self.up1.apply(self._init_weights) + self.up2.apply(self._init_weights) + self.up3.apply(self._init_weights) + self.up4.apply(self._init_weights) + + def init_weights(self, pretrained=None): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif 
isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + def _get_pos_embed(self, pos_embed, H, W): + pos_embed = pos_embed.reshape(1, self.pretrain_size[0] // 16, + self.pretrain_size[1] // 16, -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False).reshape(1, -1, H * W).permute(0, 2, 1) + return pos_embed + + def forward_features(self, x): + outs = [] + x, H, W = self.patch_embed(x) + pos_embed = self._get_pos_embed(self.pos_embed[:, 1:], H, W) + x = self.pos_drop(x + pos_embed) + for index, blk in enumerate(self.blocks): + x = blk(x, H, W) + if index in self.flags: + outs.append(x) + return outs, H, W + + def forward(self, x): + outs, H, W = self.forward_features(x) + f1, f2, f3, f4 = outs + bs, n, dim = f1.shape + f1 = self.norm1(f1).transpose(1, 2).reshape(bs, dim, H, W) + f2 = self.norm2(f2).transpose(1, 2).reshape(bs, dim, H, W) + f3 = self.norm3(f3).transpose(1, 2).reshape(bs, dim, H, W) + f4 = self.norm4(f4).transpose(1, 2).reshape(bs, dim, H, W) + + f1 = self.up1(f1).contiguous() + f2 = self.up2(f2).contiguous() + f3 = self.up3(f3).contiguous() + f4 = self.up4(f4).contiguous() + + return [f1, f2, f3, f4] \ No newline at end of file diff --git a/segmentation/mmseg_custom/models/builder.py b/segmentation/mmseg_custom/models/builder.py new file mode 100644 index 000000000..07eb5eca6 --- /dev/null +++ b/segmentation/mmseg_custom/models/builder.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +from mmcv.utils import Registry + +TRANSFORMER = Registry('Transformer') +MASK_ASSIGNERS = Registry('mask_assigner') +MATCH_COST = Registry('match_cost') + + +def build_match_cost(cfg): + """Build Match Cost.""" + return MATCH_COST.build(cfg) + + +def build_assigner(cfg): + """Build Assigner.""" + return MASK_ASSIGNERS.build(cfg) + + +def build_transformer(cfg): + """Build Transformer.""" + return TRANSFORMER.build(cfg) diff --git a/segmentation/mmseg_custom/models/decode_heads/__init__.py b/segmentation/mmseg_custom/models/decode_heads/__init__.py new file mode 100644 index 000000000..52a7761c0 --- /dev/null +++ b/segmentation/mmseg_custom/models/decode_heads/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .mask2former_head import Mask2FormerHead +from .maskformer_head import MaskFormerHead + +__all__ = [ + 'MaskFormerHead', + 'Mask2FormerHead', +] diff --git a/segmentation/mmseg_custom/models/decode_heads/mask2former_head.py b/segmentation/mmseg_custom/models/decode_heads/mask2former_head.py new file mode 100644 index 000000000..6e78dfb06 --- /dev/null +++ b/segmentation/mmseg_custom/models/decode_heads/mask2former_head.py @@ -0,0 +1,578 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
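+# Editor's note: this head predicts a fixed set of per-query (class, mask) pairs;
+# at test time the per-query class probabilities and sigmoid mask logits are
+# combined (einsum 'bqc,bqhw->bchw' in forward_test) into per-class semantic
+# segmentation logits.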
+import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init +from mmcv.cnn.bricks.transformer import (build_positional_encoding, + build_transformer_layer_sequence) +from mmcv.ops import point_sample +from mmcv.runner import ModuleList, force_fp32 +from mmseg.models.builder import HEADS, build_loss +from mmseg.models.decode_heads.decode_head import BaseDecodeHead + +from ...core import build_sampler, multi_apply, reduce_mean +from ..builder import build_assigner +from ..utils import get_uncertain_point_coords_with_randomness + + +@HEADS.register_module() +class Mask2FormerHead(BaseDecodeHead): + """Implements the Mask2Former head. + + See `Masked-attention Mask Transformer for Universal Image + Segmentation `_ for details. + + Args: + in_channels (list[int]): Number of channels in the input feature map. + feat_channels (int): Number of channels for features. + out_channels (int): Number of channels for output. + num_things_classes (int): Number of things. + num_stuff_classes (int): Number of stuff. + num_queries (int): Number of query in Transformer decoder. + pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel + decoder. Defaults to None. + enforce_decoder_input_project (bool, optional): Whether to add + a layer to change the embed_dim of tranformer encoder in + pixel decoder to the embed_dim of transformer decoder. + Defaults to False. + transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for + transformer decoder. Defaults to None. + positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for + transformer decoder position encoding. Defaults to None. + loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification + loss. Defaults to None. + loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss. + Defaults to None. + loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss. + Defaults to None. + train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of + Mask2Former head. + test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of + Mask2Former head. + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + def __init__(self, + in_channels, + feat_channels, + out_channels, + num_things_classes=80, + num_stuff_classes=53, + num_queries=100, + num_transformer_feat_level=3, + pixel_decoder=None, + enforce_decoder_input_project=False, + transformer_decoder=None, + positional_encoding=None, + loss_cls=None, + loss_mask=None, + loss_dice=None, + train_cfg=None, + test_cfg=None, + init_cfg=None, + **kwargs): + super(Mask2FormerHead, self).__init__( + in_channels=in_channels, + channels=feat_channels, + num_classes=(num_things_classes + num_stuff_classes), + init_cfg=init_cfg, + input_transform='multiple_select', + **kwargs) + self.num_things_classes = num_things_classes + self.num_stuff_classes = num_stuff_classes + self.num_classes = self.num_things_classes + self.num_stuff_classes + self.num_queries = num_queries + self.num_transformer_feat_level = num_transformer_feat_level + self.num_heads = transformer_decoder.transformerlayers. \ + attn_cfgs.num_heads + self.num_transformer_decoder_layers = transformer_decoder.num_layers + assert pixel_decoder.encoder.transformerlayers. 
\ + attn_cfgs.num_levels == num_transformer_feat_level + pixel_decoder_ = copy.deepcopy(pixel_decoder) + pixel_decoder_.update( + in_channels=in_channels, + feat_channels=feat_channels, + out_channels=out_channels) + self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1] + self.transformer_decoder = build_transformer_layer_sequence( + transformer_decoder) + self.decoder_embed_dims = self.transformer_decoder.embed_dims + + self.decoder_input_projs = ModuleList() + # from low resolution to high resolution + for _ in range(num_transformer_feat_level): + if (self.decoder_embed_dims != feat_channels + or enforce_decoder_input_project): + self.decoder_input_projs.append( + Conv2d( + feat_channels, self.decoder_embed_dims, kernel_size=1)) + else: + self.decoder_input_projs.append(nn.Identity()) + self.decoder_positional_encoding = build_positional_encoding( + positional_encoding) + self.query_embed = nn.Embedding(self.num_queries, feat_channels) + self.query_feat = nn.Embedding(self.num_queries, feat_channels) + # from low resolution to high resolution + self.level_embed = nn.Embedding(self.num_transformer_feat_level, + feat_channels) + + self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) + self.mask_embed = nn.Sequential( + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), + nn.Linear(feat_channels, out_channels)) + + self.test_cfg = test_cfg + self.train_cfg = train_cfg + if train_cfg: + self.assigner = build_assigner(self.train_cfg.assigner) + self.sampler = build_sampler(self.train_cfg.sampler, context=self) + self.num_points = self.train_cfg.get('num_points', 12544) + self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0) + self.importance_sample_ratio = self.train_cfg.get( + 'importance_sample_ratio', 0.75) + + self.class_weight = loss_cls.class_weight + self.loss_cls = build_loss(loss_cls) + self.loss_mask = build_loss(loss_mask) + self.loss_dice = build_loss(loss_dice) + + def init_weights(self): + for m in self.decoder_input_projs: + if isinstance(m, Conv2d): + caffe2_xavier_init(m, bias=0) + + self.pixel_decoder.init_weights() + + for p in self.transformer_decoder.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list, + gt_masks_list, img_metas): + """Compute classification and mask targets for all images for a decoder + layer. + + Args: + cls_scores_list (list[Tensor]): Mask score logits from a single + decoder layer for all images. Each with shape [num_queries, + cls_out_channels]. + mask_preds_list (list[Tensor]): Mask logits from a single decoder + layer for all images. Each with shape [num_queries, h, w]. + gt_labels_list (list[Tensor]): Ground truth class indices for all + images. Each with shape (n, ), n is the sum of number of stuff + type and number of instance in a image. + gt_masks_list (list[Tensor]): Ground truth mask for each image, + each with shape (n, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + tuple[list[Tensor]]: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels of all images. + Each with shape [num_queries, ]. + - label_weights_list (list[Tensor]): Label weights of all + images.Each with shape [num_queries, ]. + - mask_targets_list (list[Tensor]): Mask targets of all images. + Each with shape [num_queries, h, w]. + - mask_weights_list (list[Tensor]): Mask weights of all images. + Each with shape [num_queries, ]. 
+ - num_total_pos (int): Number of positive samples in all + images. + - num_total_neg (int): Number of negative samples in all + images. + """ + (labels_list, label_weights_list, mask_targets_list, mask_weights_list, + pos_inds_list, + neg_inds_list) = multi_apply(self._get_target_single, cls_scores_list, + mask_preds_list, gt_labels_list, + gt_masks_list, img_metas) + + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + return (labels_list, label_weights_list, mask_targets_list, + mask_weights_list, num_total_pos, num_total_neg) + + def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, + img_metas): + """Compute classification and mask targets for one image. + + Args: + cls_score (Tensor): Mask score logits from a single decoder layer + for one image. Shape (num_queries, cls_out_channels). + mask_pred (Tensor): Mask logits for a single decoder layer for one + image. Shape (num_queries, h, w). + gt_labels (Tensor): Ground truth class indices for one image with + shape (num_gts, ). + gt_masks (Tensor): Ground truth mask for each image, each with + shape (num_gts, h, w). + img_metas (dict): Image informtation. + + Returns: + tuple[Tensor]: A tuple containing the following for one image. + + - labels (Tensor): Labels of each image. \ + shape (num_queries, ). + - label_weights (Tensor): Label weights of each image. \ + shape (num_queries, ). + - mask_targets (Tensor): Mask targets of each image. \ + shape (num_queries, h, w). + - mask_weights (Tensor): Mask weights of each image. \ + shape (num_queries, ). + - pos_inds (Tensor): Sampled positive indices for each \ + image. + - neg_inds (Tensor): Sampled negative indices for each \ + image. + """ + # sample points + num_queries = cls_score.shape[0] + num_gts = gt_labels.shape[0] + + point_coords = torch.rand((1, self.num_points, 2), + device=cls_score.device) + # shape (num_queries, num_points) + mask_points_pred = point_sample( + mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, + 1)).squeeze(1) + # shape (num_gts, num_points) + gt_points_masks = point_sample( + gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, + 1)).squeeze(1) + + # assign and sample + assign_result = self.assigner.assign(cls_score, mask_points_pred, + gt_labels, gt_points_masks, + img_metas) + sampling_result = self.sampler.sample(assign_result, mask_pred, + gt_masks) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label target + labels = gt_labels.new_full((self.num_queries, ), + self.num_classes, + dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_labels.new_ones((self.num_queries, )) + + # mask target + mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] + mask_weights = mask_pred.new_zeros((self.num_queries, )) + mask_weights[pos_inds] = 1.0 + + return (labels, label_weights, mask_targets, mask_weights, pos_inds, + neg_inds) + + def loss_single(self, cls_scores, mask_preds, gt_labels_list, + gt_masks_list, img_metas): + """Loss function for outputs from a single decoder layer. + + Args: + cls_scores (Tensor): Mask score logits from a single decoder layer + for all images. Shape (batch_size, num_queries, + cls_out_channels). Note `cls_out_channels` should includes + background. + mask_preds (Tensor): Mask logits for a pixel decoder for all + images. Shape (batch_size, num_queries, h, w). 
+ gt_labels_list (list[Tensor]): Ground truth class indices for each + image, each with shape (num_gts, ). + gt_masks_list (list[Tensor]): Ground truth mask for each image, + each with shape (num_gts, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + tuple[Tensor]: Loss components for outputs from a single \ + decoder layer. + """ + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + mask_preds_list = [mask_preds[i] for i in range(num_imgs)] + (labels_list, label_weights_list, mask_targets_list, mask_weights_list, + num_total_pos, + num_total_neg) = self.get_targets(cls_scores_list, mask_preds_list, + gt_labels_list, gt_masks_list, + img_metas) + # shape (batch_size, num_queries) + labels = torch.stack(labels_list, dim=0) + # shape (batch_size, num_queries) + label_weights = torch.stack(label_weights_list, dim=0) + # shape (num_total_gts, h, w) + mask_targets = torch.cat(mask_targets_list, dim=0) + # shape (batch_size, num_queries) + mask_weights = torch.stack(mask_weights_list, dim=0) + + # classfication loss + # shape (batch_size * num_queries, ) + cls_scores = cls_scores.flatten(0, 1) + labels = labels.flatten(0, 1) + label_weights = label_weights.flatten(0, 1) + + class_weight = cls_scores.new_tensor(self.class_weight) + loss_cls = self.loss_cls( + cls_scores, + labels, + label_weights, + avg_factor=class_weight[labels].sum()) + + num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos])) + num_total_masks = max(num_total_masks, 1) + + # extract positive ones + # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) + mask_preds = mask_preds[mask_weights > 0] + + if mask_targets.shape[0] == 0: + # zero match + loss_dice = mask_preds.sum() + loss_mask = mask_preds.sum() + return loss_cls, loss_mask, loss_dice + + with torch.no_grad(): + points_coords = get_uncertain_point_coords_with_randomness( + mask_preds.unsqueeze(1), None, self.num_points, + self.oversample_ratio, self.importance_sample_ratio) + # shape (num_total_gts, h, w) -> (num_total_gts, num_points) + mask_point_targets = point_sample( + mask_targets.unsqueeze(1).float(), points_coords).squeeze(1) + # shape (num_queries, h, w) -> (num_queries, num_points) + mask_point_preds = point_sample( + mask_preds.unsqueeze(1), points_coords).squeeze(1) + + # dice loss + loss_dice = self.loss_dice( + mask_point_preds, mask_point_targets, avg_factor=num_total_masks) + + # mask loss + # shape (num_queries, num_points) -> (num_queries * num_points, ) + mask_point_preds = mask_point_preds.reshape(-1,1) + # shape (num_total_gts, num_points) -> (num_total_gts * num_points, ) + mask_point_targets = mask_point_targets.reshape(-1) + loss_mask = self.loss_mask( + mask_point_preds, + mask_point_targets, + avg_factor=num_total_masks * self.num_points) + + return loss_cls, loss_mask, loss_dice + + @force_fp32(apply_to=('all_cls_scores', 'all_mask_preds')) + def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, + gt_masks_list, img_metas): + """Loss function. + + Args: + all_cls_scores (Tensor): Classification scores for all decoder + layers with shape [num_decoder, batch_size, num_queries, + cls_out_channels]. + all_mask_preds (Tensor): Mask scores for all decoder layers with + shape [num_decoder, batch_size, num_queries, h, w]. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (n, ). n is the sum of number of stuff type + and number of instance in a image. 
+ gt_masks_list (list[Tensor]): Ground truth mask for each image with + shape (n, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + num_dec_layers = len(all_cls_scores) + all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] + all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)] + img_metas_list = [img_metas for _ in range(num_dec_layers)] + losses_cls, losses_mask, losses_dice = multi_apply( + self.loss_single, all_cls_scores, all_mask_preds, + all_gt_labels_list, all_gt_masks_list, img_metas_list) + + loss_dict = dict() + # loss from the last decoder layer + loss_dict['loss_cls'] = losses_cls[-1] + loss_dict['loss_mask'] = losses_mask[-1] + loss_dict['loss_dice'] = losses_dice[-1] + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_mask_i, loss_dice_i in zip( + losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]): + loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i + loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i + loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i + num_dec_layer += 1 + return loss_dict + + def forward_head(self, decoder_out, mask_feature, attn_mask_target_size): + """Forward for head part which is called after every decoder layer. + + Args: + decoder_out (Tensor): in shape (num_queries, batch_size, c). + mask_feature (Tensor): in shape (batch_size, c, h, w). + attn_mask_target_size (tuple[int, int]): target attention + mask size. + + Returns: + tuple: A tuple contain three elements. + + - cls_pred (Tensor): Classification scores in shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred (Tensor): Mask scores in shape \ + (batch_size, num_queries,h, w). + - attn_mask (Tensor): Attention mask in shape \ + (batch_size * num_heads, num_queries, h, w). + """ + decoder_out = self.transformer_decoder.post_norm(decoder_out) + decoder_out = decoder_out.transpose(0, 1) + # shape (num_queries, batch_size, c) + cls_pred = self.cls_embed(decoder_out) + # shape (num_queries, batch_size, c) + mask_embed = self.mask_embed(decoder_out) + # shape (num_queries, batch_size, h, w) + mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature) + attn_mask = F.interpolate( + mask_pred, + attn_mask_target_size, + mode='bilinear', + align_corners=False) + # shape (num_queries, batch_size, h, w) -> + # (batch_size * num_head, num_queries, h, w) + attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat( + (1, self.num_heads, 1, 1)).flatten(0, 1) + attn_mask = attn_mask.sigmoid() < 0.5 + attn_mask = attn_mask.detach() + + return cls_pred, mask_pred, attn_mask + + def forward(self, feats, img_metas): + """Forward function. + + Args: + feats (list[Tensor]): Multi scale Features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + + Returns: + tuple: A tuple contains two elements. + + - cls_pred_list (list[Tensor)]: Classification logits \ + for each decoder layer. Each is a 3D-tensor with shape \ + (batch_size, num_queries, cls_out_channels). \ + Note `cls_out_channels` should includes background. + - mask_pred_list (list[Tensor]): Mask logits for each \ + decoder layer. Each with shape (batch_size, num_queries, \ + h, w). 
+ """ + batch_size = len(img_metas) + mask_features, multi_scale_memorys = self.pixel_decoder(feats) + # multi_scale_memorys (from low resolution to high resolution) + decoder_inputs = [] + decoder_positional_encodings = [] + for i in range(self.num_transformer_feat_level): + decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + decoder_input = decoder_input.flatten(2).permute(2, 0, 1) + level_embed = self.level_embed.weight[i].view(1, 1, -1) + decoder_input = decoder_input + level_embed + # shape (batch_size, c, h, w) -> (h*w, batch_size, c) + mask = decoder_input.new_zeros( + (batch_size, ) + multi_scale_memorys[i].shape[-2:], + dtype=torch.bool) + decoder_positional_encoding = self.decoder_positional_encoding( + mask) + decoder_positional_encoding = decoder_positional_encoding.flatten( + 2).permute(2, 0, 1) + decoder_inputs.append(decoder_input) + decoder_positional_encodings.append(decoder_positional_encoding) + # shape (num_queries, c) -> (num_queries, batch_size, c) + query_feat = self.query_feat.weight.unsqueeze(1).repeat( + (1, batch_size, 1)) + query_embed = self.query_embed.weight.unsqueeze(1).repeat( + (1, batch_size, 1)) + + cls_pred_list = [] + mask_pred_list = [] + cls_pred, mask_pred, attn_mask = self.forward_head( + query_feat, mask_features, multi_scale_memorys[0].shape[-2:]) + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + for i in range(self.num_transformer_decoder_layers): + level_idx = i % self.num_transformer_feat_level + # if a mask is all True(all background), then set it all False. + attn_mask[torch.where( + attn_mask.sum(-1) == attn_mask.shape[-1])] = False + + # cross_attn + self_attn + layer = self.transformer_decoder.layers[i] + attn_masks = [attn_mask, None] + query_feat = layer( + query=query_feat, + key=decoder_inputs[level_idx], + value=decoder_inputs[level_idx], + query_pos=query_embed, + key_pos=decoder_positional_encodings[level_idx], + attn_masks=attn_masks, + query_key_padding_mask=None, + # here we do not apply masking on padded region + key_padding_mask=None) + cls_pred, mask_pred, attn_mask = self.forward_head( + query_feat, mask_features, multi_scale_memorys[ + (i + 1) % self.num_transformer_feat_level].shape[-2:]) + + cls_pred_list.append(cls_pred) + mask_pred_list.append(mask_pred) + + return cls_pred_list, mask_pred_list + + def forward_train(self, x, img_metas, gt_semantic_seg, gt_labels, + gt_masks): + """Forward function for training mode. + + Args: + x (list[Tensor]): Multi-level features from the upstream network, + each is a 4D-tensor. + img_metas (list[Dict]): List of image information. + gt_semantic_seg (list[tensor]):Each element is the ground truth + of semantic segmentation with the shape (N, H, W). + train_cfg (dict): The training config, which not been used in + maskformer. + gt_labels (list[Tensor]): Each element is ground truth labels of + each box, shape (num_gts,). + gt_masks (list[BitmapMasks]): Each element is masks of instances + of a image, shape (num_gts, h, w). + + Returns: + losses (dict[str, Tensor]): a dictionary of loss components + """ + + # forward + all_cls_scores, all_mask_preds = self(x, img_metas) + + # loss + losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks, + img_metas) + + return losses + + def forward_test(self, inputs, img_metas, test_cfg): + """Test segment without test-time aumengtation. + + Only the output of last decoder layers was used. 
+
+ Args:
+ inputs (list[Tensor]): Multi-level features from the
+ upstream network, each is a 4D-tensor.
+ img_metas (list[dict]): List of image information.
+ test_cfg (dict): Testing config.
+
+ Returns:
+ seg_mask (Tensor): Predicted semantic segmentation logits.
+ """
+ all_cls_scores, all_mask_preds = self(inputs, img_metas)
+ cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1]
+ ori_h, ori_w, _ = img_metas[0]['ori_shape']
+
+ # semantic inference
+ cls_score = F.softmax(cls_score, dim=-1)[..., :-1]
+ mask_pred = mask_pred.sigmoid()
+ seg_mask = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred)
+ return seg_mask
diff --git a/segmentation/mmseg_custom/models/decode_heads/maskformer_head.py b/segmentation/mmseg_custom/models/decode_heads/maskformer_head.py
new file mode 100644
index 000000000..aa8509a6d
--- /dev/null
+++ b/segmentation/mmseg_custom/models/decode_heads/maskformer_head.py
@@ -0,0 +1,519 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import Conv2d, build_plugin_layer, kaiming_init
+from mmcv.cnn.bricks.transformer import (build_positional_encoding,
+ build_transformer_layer_sequence)
+from mmcv.runner import force_fp32
+from mmseg.models.builder import HEADS, build_loss
+from mmseg.models.decode_heads.decode_head import BaseDecodeHead
+
+from ...core import multi_apply, reduce_mean
+from ..builder import build_assigner, build_transformer
+
+
+@HEADS.register_module()
+class MaskFormerHead(BaseDecodeHead):
+ """Implements the MaskFormer head.
+
+ See the paper `Per-Pixel Classification is Not All You Need
+ for Semantic Segmentation` for details.
+
+ Args:
+ in_channels (list[int]): Number of channels in the input feature maps.
+ feat_channels (int): Number of channels for features.
+ out_channels (int): Number of channels for outputs.
+ num_things_classes (int): Number of things.
+ num_stuff_classes (int): Number of stuff.
+ num_queries (int): Number of queries in the Transformer decoder.
+ pixel_decoder (obj:`mmcv.ConfigDict`|dict): Config for pixel decoder.
+ Defaults to None.
+ enforce_decoder_input_project (bool, optional): Whether to add a layer
+ to change the embed_dim of the transformer encoder in the pixel
+ decoder to the embed_dim of the transformer decoder.
+ Defaults to False.
+ transformer_decoder (obj:`mmcv.ConfigDict`|dict): Config for
+ transformer decoder. Defaults to None.
+ positional_encoding (obj:`mmcv.ConfigDict`|dict): Config for
+ transformer decoder position encoding. Defaults to None.
+ loss_cls (obj:`mmcv.ConfigDict`|dict): Config of the classification
+ loss. Defaults to `CrossEntropyLoss`.
+ loss_mask (obj:`mmcv.ConfigDict`|dict): Config of the mask loss.
+ Defaults to `FocalLoss`.
+ loss_dice (obj:`mmcv.ConfigDict`|dict): Config of the dice loss.
+ Defaults to `DiceLoss`.
+ train_cfg (obj:`mmcv.ConfigDict`|dict): Training config of MaskFormer
+ head.
+ test_cfg (obj:`mmcv.ConfigDict`|dict): Testing config of MaskFormer
+ head.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Defaults to None.
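The semantic-segmentation inference in `forward_test` above is simply a weighted sum of per-query mask probabilities by per-query class probabilities, with the void class dropped. A toy version with invented sizes:

import torch
import torch.nn.functional as F

batch, num_queries, num_classes = 1, 100, 150
cls_score = torch.randn(batch, num_queries, num_classes + 1)  # +1 for void
mask_pred = torch.randn(batch, num_queries, 128, 128)

cls_prob = F.softmax(cls_score, dim=-1)[..., :-1]  # drop the void class
mask_prob = mask_pred.sigmoid()
# (b, q, c) x (b, q, h, w) -> (b, c, h, w)
seg_logits = torch.einsum('bqc,bqhw->bchw', cls_prob, mask_prob)
print(seg_logits.shape)  # torch.Size([1, 150, 128, 128])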
+ """ + def __init__(self, + out_channels, + num_queries=100, + pixel_decoder=None, + enforce_decoder_input_project=False, + transformer_decoder=None, + positional_encoding=None, + loss_cls=dict( + type='CrossEntropyLoss', + bg_cls_weight=0.1, + use_sigmoid=False, + loss_weight=1.0, + class_weight=1.0), + loss_mask=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=20.0), + loss_dice=dict( + type='DiceLoss', + use_sigmoid=True, + activate=True, + naive_dice=True, + loss_weight=1.0), + assigner=dict( + type='MaskHungarianAssigner', + cls_cost=dict(type='ClassificationCost', weight=1.), + dice_cost=dict(type='DiceCost', weight=1.0, pred_act=True, + eps=1.0), + mask_cost=dict(type='MaskFocalLossCost', weight=20.0)), + **kwargs): + super(MaskFormerHead, self).__init__(input_transform='multiple_select', + **kwargs) + self.num_queries = num_queries + + pixel_decoder.update( + in_channels=self.in_channels, + feat_channels=self.channels, + out_channels=out_channels) + self.pixel_decoder = build_plugin_layer(pixel_decoder)[1] + self.transformer_decoder = build_transformer_layer_sequence( + transformer_decoder) + self.decoder_embed_dims = self.transformer_decoder.embed_dims + pixel_decoder_type = pixel_decoder.get('type') + if pixel_decoder_type == 'PixelDecoder' and ( + self.decoder_embed_dims != self.in_channels[-1] + or enforce_decoder_input_project): + self.decoder_input_proj = Conv2d( + self.in_channels[-1], self.decoder_embed_dims, kernel_size=1) + else: + self.decoder_input_proj = nn.Identity() + self.decoder_pe = build_positional_encoding(positional_encoding) + self.query_embed = nn.Embedding(self.num_queries, out_channels) + + self.cls_embed = nn.Linear(self.channels, self.num_classes + 1) + self.mask_embed = nn.Sequential( + nn.Linear(self.channels, self.channels), nn.ReLU(inplace=True), + nn.Linear(self.channels, self.channels), nn.ReLU(inplace=True), + nn.Linear(self.channels, out_channels)) + + self.assigner = build_assigner(assigner) + + self.bg_cls_weight = 0 + class_weight = loss_cls.get('class_weight', None) + if class_weight is not None and (self.__class__ is MaskFormerHead): + assert isinstance(class_weight, float), 'Expected ' \ + 'class_weight to have type float. Found ' \ + f'{type(class_weight)}.' + # NOTE following the official MaskFormerHead repo, bg_cls_weight + # means relative classification weight of the VOID class. + bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight) + assert isinstance(bg_cls_weight, float), 'Expected ' \ + 'bg_cls_weight to have type float. Found ' \ + f'{type(bg_cls_weight)}.' + class_weight = (self.num_classes + 1) * [class_weight] + # set VOID class as the last indice + class_weight[self.num_classes] = bg_cls_weight + loss_cls.update({'class_weight': class_weight}) + if 'bg_cls_weight' in loss_cls: + loss_cls.pop('bg_cls_weight') + self.bg_cls_weight = bg_cls_weight + + assert loss_cls['loss_weight'] == assigner['cls_cost']['weight'], \ + 'The classification weight for loss and matcher should be' \ + 'exactly the same.' + assert loss_dice['loss_weight'] == assigner['dice_cost']['weight'], \ + f'The dice weight for loss and matcher' \ + f'should be exactly the same.' + assert loss_mask['loss_weight'] == assigner['mask_cost']['weight'], \ + 'The focal weight for loss and matcher should be' \ + 'exactly the same.' 
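A compact illustration of the `bg_cls_weight` bookkeeping above: a scalar `class_weight` is expanded into a per-class list of length `num_classes + 1`, and the last entry (the void class) receives the background weight before the config is passed to the loss builder. The numbers below are arbitrary and the config is a plain dict stand-in.

num_classes = 4
loss_cls = dict(type='CrossEntropyLoss', class_weight=1.0, bg_cls_weight=0.1)

class_weight = loss_cls['class_weight']           # scalar -> per-class list
bg_cls_weight = loss_cls.pop('bg_cls_weight', class_weight)
class_weight = (num_classes + 1) * [class_weight]
class_weight[num_classes] = bg_cls_weight         # void class sits at the last index
loss_cls.update(class_weight=class_weight)

print(loss_cls)
# {'type': 'CrossEntropyLoss', 'class_weight': [1.0, 1.0, 1.0, 1.0, 0.1]}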
+ self.loss_cls = build_loss(loss_cls) + self.loss_mask = build_loss(loss_mask) + self.loss_dice = build_loss(loss_dice) + + self.init_weights() + + def init_weights(self): + kaiming_init(self.decoder_input_proj, a=1) + + def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list, + gt_masks_list, img_metas): + """Compute classification and mask targets for all images for a decoder + layer. + + Args: + cls_scores_list (list[Tensor]): Mask score logits from a single + decoder layer for all images. Each with shape [num_queries, + cls_out_channels]. + mask_preds_list (list[Tensor]): Mask logits from a single decoder + layer for all images. Each with shape [num_queries, h, w]. + gt_labels_list (list[Tensor]): Ground truth class indices for all + images. Each with shape (n, ), n is the sum of number of stuff + type and number of instance in a image. + gt_masks_list (list[Tensor]): Ground truth mask for each image, + each with shape (n, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + tuple[list[Tensor]]: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels of all images. + Each with shape [num_queries, ]. + - label_weights_list (list[Tensor]): Label weights of all + images.Each with shape [num_queries, ]. + - mask_targets_list (list[Tensor]): Mask targets of all images. + Each with shape [num_queries, h, w]. + - mask_weights_list (list[Tensor]): Mask weights of all images. + Each with shape [num_queries, ]. + - num_total_pos (int): Number of positive samples in all + images. + - num_total_neg (int): Number of negative samples in all + images. + """ + (labels_list, label_weights_list, mask_targets_list, mask_weights_list, + pos_inds_list, + neg_inds_list) = multi_apply(self._get_target_single, cls_scores_list, + mask_preds_list, gt_labels_list, + gt_masks_list, img_metas) + + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + return (labels_list, label_weights_list, mask_targets_list, + mask_weights_list, num_total_pos, num_total_neg) + + def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, + img_metas): + """Compute classification and mask targets for one image. + + Args: + cls_score (Tensor): Mask score logits from a single decoder layer + for one image. Shape [num_queries, cls_out_channels]. + mask_pred (Tensor): Mask logits for a single decoder layer for one + image. Shape [num_queries, h, w]. + gt_labels (Tensor): Ground truth class indices for one image with + shape (n, ). n is the sum of number of stuff type and number + of instance in a image. + gt_masks (Tensor): Ground truth mask for each image, each with + shape (n, h, w). + img_metas (dict): Image informtation. + + Returns: + tuple[Tensor]: a tuple containing the following for one image. + + - labels (Tensor): Labels of each image. + shape [num_queries, ]. + - label_weights (Tensor): Label weights of each image. + shape [num_queries, ]. + - mask_targets (Tensor): Mask targets of each image. + shape [num_queries, h, w]. + - mask_weights (Tensor): Mask weights of each image. + shape [num_queries, ]. + - pos_inds (Tensor): Sampled positive indices for each image. + - neg_inds (Tensor): Sampled negative indices for each image. 
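To make the target layout described in `_get_target_single` concrete, here is a stripped-down version of what it produces for one image, given the assigner's `gt_inds` (0 means unmatched, k > 0 means matched to ground-truth k-1). All values are dummies; the real method also downsamples the GT masks before running the assigner.

import torch

num_queries, num_classes = 6, 3
gt_labels = torch.tensor([0, 2])                # two GT instances
gt_inds = torch.tensor([0, 2, 0, 0, 1, 0])      # assigner output per query

pos_inds = torch.nonzero(gt_inds > 0, as_tuple=False).squeeze(-1)
pos_assigned_gt_inds = gt_inds[pos_inds] - 1

# unmatched queries get the void/background index num_classes
labels = gt_labels.new_full((num_queries, ), num_classes)
labels[pos_inds] = gt_labels[pos_assigned_gt_inds]
# only matched queries contribute to the mask losses
mask_weights = torch.zeros(num_queries)
mask_weights[pos_inds] = 1.0

print(labels.tolist())        # [3, 2, 3, 3, 0, 3]
print(mask_weights.tolist())  # [0.0, 1.0, 0.0, 0.0, 1.0, 0.0]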
+ """ + target_shape = mask_pred.shape[-2:] + gt_masks_downsampled = F.interpolate( + gt_masks.unsqueeze(1).float(), target_shape, + mode='nearest').squeeze(1).long() + # assign and sample + assign_result = self.assigner.assign(cls_score, mask_pred, gt_labels, + gt_masks_downsampled, img_metas) + # pos_ind: range from 1 to (self.num_classes) + # which represents the positive index + pos_inds = torch.nonzero(assign_result.gt_inds > 0, + as_tuple=False).squeeze(-1).unique() + neg_inds = torch.nonzero(assign_result.gt_inds == 0, + as_tuple=False).squeeze(-1).unique() + pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + + # label target + labels = gt_labels.new_full((self.num_queries, ), + self.num_classes, + dtype=torch.long) + labels[pos_inds] = gt_labels[pos_assigned_gt_inds] + label_weights = gt_labels.new_ones(self.num_queries) + + # mask target + mask_targets = gt_masks[pos_assigned_gt_inds, :] + mask_weights = mask_pred.new_zeros((self.num_queries, )) + mask_weights[pos_inds] = 1.0 + + return (labels, label_weights, mask_targets, mask_weights, pos_inds, + neg_inds) + + @force_fp32(apply_to=('all_cls_scores', 'all_mask_preds')) + def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, + gt_masks_list, img_metas): + """Loss function. + + Args: + all_cls_scores (Tensor): Classification scores for all decoder + layers with shape [num_decoder, batch_size, num_queries, + cls_out_channels]. + all_mask_preds (Tensor): Mask scores for all decoder layers with + shape [num_decoder, batch_size, num_queries, h, w]. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (n, ). n is the sum of number of stuff type + and number of instance in a image. + gt_masks_list (list[Tensor]): Ground truth mask for each image with + shape (n, h, w). + img_metas (list[dict]): List of image meta information. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + num_dec_layers = len(all_cls_scores) + all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] + all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)] + img_metas_list = [img_metas for _ in range(num_dec_layers)] + losses_cls, losses_mask, losses_dice = multi_apply( + self.loss_single, all_cls_scores, all_mask_preds, + all_gt_labels_list, all_gt_masks_list, img_metas_list) + + loss_dict = dict() + # loss from the last decoder layer + loss_dict['loss_cls'] = losses_cls[-1] + loss_dict['loss_mask'] = losses_mask[-1] + loss_dict['loss_dice'] = losses_dice[-1] + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_mask_i, loss_dice_i in zip( + losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]): + loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i + loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i + loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i + num_dec_layer += 1 + return loss_dict + + def loss_single(self, cls_scores, mask_preds, gt_labels_list, + gt_masks_list, img_metas): + """Loss function for outputs from a single decoder layer. + + Args: + cls_scores (Tensor): Mask score logits from a single decoder layer + for all images. Shape [batch_size, num_queries, + cls_out_channels]. + mask_preds (Tensor): Mask logits for a pixel decoder for all + images. Shape [batch_size, num_queries, h, w]. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image, each with shape (n, ). n is the sum of number of stuff + types and number of instances in a image. 
+ gt_masks_list (list[Tensor]): Ground truth mask for each image,
+ each with shape (n, h, w).
+ img_metas (list[dict]): List of image meta information.
+
+ Returns:
+ tuple[Tensor]: Loss components for outputs from a single decoder
+ layer.
+ """
+ num_imgs = cls_scores.size(0)
+ cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
+ mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
+
+ (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
+ num_total_pos,
+ num_total_neg) = self.get_targets(cls_scores_list, mask_preds_list,
+ gt_labels_list, gt_masks_list,
+ img_metas)
+ # shape [batch_size, num_queries]
+ labels = torch.stack(labels_list, dim=0)
+ # shape [batch_size, num_queries]
+ label_weights = torch.stack(label_weights_list, dim=0)
+ # shape [num_gts, h, w]
+ mask_targets = torch.cat(mask_targets_list, dim=0)
+ # shape [batch_size, num_queries]
+ mask_weights = torch.stack(mask_weights_list, dim=0)
+
+ # classification loss
+ # shape [batch_size * num_queries, ]
+ cls_scores = cls_scores.flatten(0, 1)
+ # shape [batch_size * num_queries, ]
+ labels = labels.flatten(0, 1)
+ # shape [batch_size * num_queries, ]
+ label_weights = label_weights.flatten(0, 1)
+
+ class_weight = cls_scores.new_ones(self.num_classes + 1)
+ class_weight[-1] = self.bg_cls_weight
+
+ loss_cls = self.loss_cls(
+ cls_scores,
+ labels,
+ label_weights,
+ avg_factor=class_weight[labels].sum())
+
+ num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos]))
+ num_total_masks = max(num_total_masks, 1)
+
+ # extract positive ones
+ mask_preds = mask_preds[mask_weights > 0]
+ target_shape = mask_targets.shape[-2:]
+
+ if mask_targets.shape[0] == 0:
+ # zero match
+ loss_dice = mask_preds.sum()
+ loss_mask = mask_preds.sum()
+ return loss_cls, loss_mask, loss_dice
+
+ # upsample to shape of target
+ # shape [num_gts, h, w]
+ mask_preds = F.interpolate(
+ mask_preds.unsqueeze(1),
+ target_shape,
+ mode='bilinear',
+ align_corners=False).squeeze(1)
+
+ # dice loss
+ loss_dice = self.loss_dice(
+ mask_preds, mask_targets, avg_factor=num_total_masks)
+
+ # mask loss
+ # FocalLoss supports inputs of shape [n, num_class]
+ h, w = mask_preds.shape[-2:]
+ # shape [num_gts, h, w] -> [num_gts * h * w, 1]
+ mask_preds = mask_preds.reshape(-1, 1)
+ # shape [num_gts, h, w] -> [num_gts * h * w]
+ mask_targets = mask_targets.reshape(-1)
+ # NOTE: the focal loss target is (1 - mask_targets)
+ loss_mask = self.loss_mask(
+ mask_preds, 1 - mask_targets, avg_factor=num_total_masks * h * w)
+
+ return loss_cls, loss_mask, loss_dice
+
+ def forward(self, feats, img_metas):
+ """Forward function.
+
+ Args:
+ feats (list[Tensor]): Features from the upstream network, each
+ is a 4D-tensor.
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ all_cls_scores (Tensor): Classification scores for each
+ decoder layer. Each is a 4D-tensor with shape
+ [num_decoder, batch_size, num_queries, cls_out_channels].
+ Note `cls_out_channels` should include background.
+ all_mask_preds (Tensor): Mask scores for each decoder
+ layer. Each with shape [num_decoder, batch_size,
+ num_queries, h, w].
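The `1 - mask_targets` passed to the mask focal loss above is easy to misread. The mask logits are flattened to shape (num_pixels, 1) and treated as a one-class classification problem: under the focal-loss label convention used here, label 0 means "this pixel is positive for the single class" and label 1 (equal to the number of classes) means background, which is why the binary GT mask is inverted. A dummy sketch of the reshaping and label convention, without the actual loss call:

import torch

num_gts, h, w = 2, 4, 4
mask_preds = torch.randn(num_gts, h, w)            # upsampled positive-query logits
mask_targets = torch.randint(0, 2, (num_gts, h, w))

# every pixel becomes one sample of a 1-class focal loss
flat_preds = mask_preds.reshape(-1, 1)             # (num_gts * h * w, 1)
flat_targets = mask_targets.reshape(-1)            # (num_gts * h * w, )

# label 0 -> positive for class 0, label 1 -> background,
# hence the inversion of the 0/1 ground-truth mask
focal_labels = 1 - flat_targets
print(flat_preds.shape, focal_labels.shape)  # torch.Size([32, 1]) torch.Size([32])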
+ """ + batch_size = len(img_metas) + input_img_h, input_img_w = img_metas[0]['pad_shape'][:-1] + # input_img_h, input_img_w = img_metas[0]['batch_input_shape'] + padding_mask = feats[-1].new_ones( + (batch_size, input_img_h, input_img_w), dtype=torch.float32) + for i in range(batch_size): + img_h, img_w, _ = img_metas[i]['img_shape'] + padding_mask[i, :img_h, :img_w] = 0 + padding_mask = F.interpolate( + padding_mask.unsqueeze(1), + size=feats[-1].shape[-2:], + mode='nearest').to(torch.bool).squeeze(1) + # when backbone is swin, memory is output of last stage of swin. + # when backbone is r50, memory is output of tranformer encoder. + mask_features, memory = self.pixel_decoder(feats, img_metas) + pos_embed = self.decoder_pe(padding_mask) + memory = self.decoder_input_proj(memory) + # shape [batch_size, c, h, w] -> [h*w, batch_size, c] + memory = memory.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + # shape [batch_size, h * w] + padding_mask = padding_mask.flatten(1) + # shape = [num_queries, embed_dims] + query_embed = self.query_embed.weight + # shape = [num_queries, batch_size, embed_dims] + query_embed = query_embed.unsqueeze(1).repeat(1, batch_size, 1) + target = torch.zeros_like(query_embed) + # shape [num_decoder, num_queries, batch_size, embed_dims] + out_dec = self.transformer_decoder( + query=target, + key=memory, + value=memory, + key_pos=pos_embed, + query_pos=query_embed, + key_padding_mask=padding_mask) + # shape [num_decoder, batch_size, num_queries, embed_dims] + out_dec = out_dec.transpose(1, 2) + + # cls_scores + all_cls_scores = self.cls_embed(out_dec) + + # mask_preds + mask_embed = self.mask_embed(out_dec) + all_mask_preds = torch.einsum('lbqc,bchw->lbqhw', mask_embed, + mask_features) + + return all_cls_scores, all_mask_preds + + def forward_train(self, + x, + img_metas, + gt_semantic_seg, + gt_labels, + gt_masks): + """Forward function for training mode. + + Args: + x (list[Tensor]): Multi-level features from the upstream network, + each is a 4D-tensor. + img_metas (list[Dict]): List of image information. + gt_semantic_seg (list[tensor]):Each element is the ground truth + of semantic segmentation with the shape (N, H, W). + train_cfg (dict): The training config, which not been used in + maskformer. + gt_labels (list[Tensor]): Each element is ground truth labels of + each box, shape (num_gts,). + gt_masks (list[BitmapMasks]): Each element is masks of instances + of a image, shape (num_gts, h, w). + + Returns: + losses (dict[str, Tensor]): a dictionary of loss components + """ + + # forward + all_cls_scores, all_mask_preds = self(x, img_metas) + + # loss + losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks, + img_metas) + + return losses + + def forward_test(self, inputs, img_metas, test_cfg): + """Test segment without test-time aumengtation. + + Only the output of last decoder layers was used. + + Args: + inputs (list[Tensor]): Multi-level features from the + upstream network, each is a 4D-tensor. + img_metas (list[dict]): List of image information. + test_cfg (dict): Testing config. + + Returns: + seg_mask (Tensor): Predicted semantic segmentation logits. 
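A self-contained sketch of the padding mask built at the top of the `forward` above: ones mark padded pixels, the valid region given by `img_shape` is zeroed, and the mask is nearest-downsampled to the coarsest feature resolution before being flattened for the decoder. Shapes and image sizes below are invented.

import torch
import torch.nn.functional as F

batch_size = 2
pad_h, pad_w = 512, 512                     # padded input size
img_shapes = [(512, 384), (448, 512)]       # valid (h, w) per image
feat_hw = (16, 16)                          # spatial size of the last feature map

padding_mask = torch.ones(batch_size, pad_h, pad_w)
for i, (img_h, img_w) in enumerate(img_shapes):
    padding_mask[i, :img_h, :img_w] = 0     # 0 = real pixels, 1 = padding

padding_mask = F.interpolate(
    padding_mask.unsqueeze(1), size=feat_hw,
    mode='nearest').to(torch.bool).squeeze(1)
print(padding_mask.shape, padding_mask[0].sum().item())
# torch.Size([2, 16, 16]) 64  (the right quarter of image 0 is padding)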
+ """ + all_cls_scores, all_mask_preds = self(inputs, img_metas) + cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1] + ori_h, ori_w, _ = img_metas[0]['ori_shape'] + + # semantic inference + cls_score = F.softmax(cls_score, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + seg_mask = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred) + return seg_mask diff --git a/segmentation/mmseg_custom/models/losses/__init__.py b/segmentation/mmseg_custom/models/losses/__init__.py new file mode 100644 index 000000000..50ed88a44 --- /dev/null +++ b/segmentation/mmseg_custom/models/losses/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, + cross_entropy, mask_cross_entropy) +from .dice_loss import DiceLoss +from .focal_loss import FocalLoss +from .match_costs import (ClassificationCost, CrossEntropyLossCost, DiceCost, + MaskFocalLossCost) + +__all__ = [ + 'cross_entropy', 'binary_cross_entropy', 'mask_cross_entropy', + 'CrossEntropyLoss', 'DiceLoss', 'FocalLoss', 'ClassificationCost', + 'MaskFocalLossCost', 'DiceCost', 'CrossEntropyLossCost' +] diff --git a/segmentation/mmseg_custom/models/losses/cross_entropy_loss.py b/segmentation/mmseg_custom/models/losses/cross_entropy_loss.py new file mode 100644 index 000000000..5766ea6e5 --- /dev/null +++ b/segmentation/mmseg_custom/models/losses/cross_entropy_loss.py @@ -0,0 +1,291 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmseg.models.builder import LOSSES +from mmseg.models.losses.utils import get_class_weight, weight_reduce_loss + + +def cross_entropy(pred, + label, + weight=None, + class_weight=None, + reduction='mean', + avg_factor=None, + ignore_index=-100, + avg_non_ignore=False): + """cross_entropy. The wrapper function for :func:`F.cross_entropy` + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + Default: None. + class_weight (list[float], optional): The weight for each class. + Default: None. + reduction (str, optional): The method used to reduce the loss. + Options are 'none', 'mean' and 'sum'. Default: 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Default: None. + ignore_index (int): Specifies a target value that is ignored and + does not contribute to the input gradients. When + ``avg_non_ignore `` is ``True``, and the ``reduction`` is + ``''mean''``, the loss is averaged over non-ignored targets. + Defaults: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + """ + + # class_weight is a manual rescaling weight given to each class. 
+ # If given, has to be a Tensor of size C element-wise losses + loss = F.cross_entropy( + pred, + label, + weight=class_weight, + reduction='none', + ignore_index=ignore_index) + + # apply weights and do the reduction + # average loss over non-ignored elements + # pytorch's official cross_entropy average loss over non-ignored elements + # refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa + if (avg_factor is None) and avg_non_ignore and reduction == 'mean': + avg_factor = label.numel() - (label == ignore_index).sum().item() + if weight is not None: + weight = weight.float() + loss = weight_reduce_loss( + loss, weight=weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index): + """Expand onehot labels to match the size of prediction.""" + bin_labels = labels.new_zeros(target_shape) + valid_mask = (labels >= 0) & (labels != ignore_index) + inds = torch.nonzero(valid_mask, as_tuple=True) + + if inds[0].numel() > 0: + if labels.dim() == 3: + bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1 + else: + bin_labels[inds[0], labels[valid_mask]] = 1 + + valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float() + + if label_weights is None: + bin_label_weights = valid_mask + else: + bin_label_weights = label_weights.unsqueeze(1).expand(target_shape) + bin_label_weights = bin_label_weights * valid_mask + + return bin_labels, bin_label_weights, valid_mask + + +def binary_cross_entropy(pred, + label, + weight=None, + reduction='mean', + avg_factor=None, + class_weight=None, + ignore_index=-100, + avg_non_ignore=False, + **kwargs): + """Calculate the binary CrossEntropy loss. + + Args: + pred (torch.Tensor): The prediction with shape (N, 1). + label (torch.Tensor): The learning label of the prediction. + Note: In bce loss, label < 0 is invalid. + weight (torch.Tensor, optional): Sample-wise loss weight. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (int): The label index to be ignored. Default: -100. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. + `New in version 0.23.0.` + + Returns: + torch.Tensor: The calculated loss + """ + if pred.size(1) == 1: + # For binary class segmentation, the shape of pred is + # [N, 1, H, W] and that of label is [N, H, W]. 
+ assert label.max() <= 1, \ + 'For pred with shape [N, 1, H, W], its label must have at ' \ + 'most 2 classes' + pred = pred.squeeze() + if pred.dim() != label.dim(): + assert (pred.dim() == 2 and label.dim() == 1) or ( + pred.dim() == 4 and label.dim() == 3), \ + 'Only pred shape [N, C], label shape [N] or pred shape [N, C, ' \ + 'H, W], label shape [N, H, W] are supported' + # `weight` returned from `_expand_onehot_labels` + # has been treated for valid (non-ignore) pixels + label, weight, valid_mask = _expand_onehot_labels( + label, weight, pred.shape, ignore_index) + else: + # should mask out the ignored elements + valid_mask = ((label >= 0) & (label != ignore_index)).float() + if weight is not None: + weight = weight * valid_mask + else: + weight = valid_mask + # average loss over non-ignored and valid elements + if reduction == 'mean' and avg_factor is None and avg_non_ignore: + avg_factor = valid_mask.sum().item() + + loss = F.binary_cross_entropy_with_logits( + pred, label.float(), pos_weight=class_weight, reduction='none') + # do the reduction for the weighted loss + loss = weight_reduce_loss( + loss, weight, reduction=reduction, avg_factor=avg_factor) + + return loss + + +def mask_cross_entropy(pred, + target, + label, + reduction='mean', + avg_factor=None, + class_weight=None, + ignore_index=None, + **kwargs): + """Calculate the CrossEntropy loss for masks. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + target (torch.Tensor): The learning label of the prediction. + label (torch.Tensor): ``label`` indicates the class label of the mask' + corresponding object. This will be used to select the mask in the + of the class which the object belongs to when the mask prediction + if not class-agnostic. + reduction (str, optional): The method used to reduce the loss. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + class_weight (list[float], optional): The weight for each class. + ignore_index (None): Placeholder, to be consistent with other loss. + Default: None. + + Returns: + torch.Tensor: The calculated loss + """ + assert ignore_index is None, 'BCE loss does not support ignore_index' + # TODO: handle these two reserved arguments + assert reduction == 'mean' and avg_factor is None + num_rois = pred.size()[0] + inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device) + pred_slice = pred[inds, label].squeeze(1) + return F.binary_cross_entropy_with_logits( + pred_slice, target, weight=class_weight, reduction='mean')[None] + + +@LOSSES.register_module(force=True) +class CrossEntropyLoss(nn.Module): + """CrossEntropyLoss. + + Args: + use_sigmoid (bool, optional): Whether the prediction uses sigmoid + of softmax. Defaults to False. + use_mask (bool, optional): Whether to use mask cross entropy loss. + Defaults to False. + reduction (str, optional): . Defaults to 'mean'. + Options are "none", "mean" and "sum". + class_weight (list[float] | str, optional): Weight of each class. If in + str format, read them from a file. Defaults to None. + loss_weight (float, optional): Weight of the loss. Defaults to 1.0. + loss_name (str, optional): Name of the loss item. If you want this loss + item to be included into the backward graph, `loss_` must be the + prefix of the name. Defaults to 'loss_ce'. + avg_non_ignore (bool): The flag decides to whether the loss is + only averaged over non-ignored targets. Default: False. 
+ `New in version 0.23.0.` + """ + def __init__(self, + use_sigmoid=False, + use_mask=False, + reduction='mean', + class_weight=None, + loss_weight=1.0, + loss_name='loss_ce', + avg_non_ignore=False): + super(CrossEntropyLoss, self).__init__() + assert (use_sigmoid is False) or (use_mask is False) + self.use_sigmoid = use_sigmoid + self.use_mask = use_mask + self.reduction = reduction + self.loss_weight = loss_weight + self.class_weight = get_class_weight(class_weight) + self.avg_non_ignore = avg_non_ignore + if not self.avg_non_ignore and self.reduction == 'mean': + warnings.warn( + 'Default ``avg_non_ignore`` is False, if you would like to ' + 'ignore the certain label and average loss over non-ignore ' + 'labels, which is the same with PyTorch official ' + 'cross_entropy, set ``avg_non_ignore=True``.') + + if self.use_sigmoid: + self.cls_criterion = binary_cross_entropy + elif self.use_mask: + self.cls_criterion = mask_cross_entropy + else: + self.cls_criterion = cross_entropy + self._loss_name = loss_name + + def extra_repr(self): + """Extra repr.""" + s = f'avg_non_ignore={self.avg_non_ignore}' + return s + + def forward(self, + cls_score, + label, + weight=None, + avg_factor=None, + reduction_override=None, + ignore_index=-100, + **kwargs): + """Forward function.""" + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = (reduction_override + if reduction_override else self.reduction) + if self.class_weight is not None: + class_weight = cls_score.new_tensor(self.class_weight) + else: + class_weight = None + # Note: for BCE loss, label < 0 is invalid. + loss_cls = self.loss_weight * self.cls_criterion( + cls_score, + label, + weight, + class_weight=class_weight, + reduction=reduction, + avg_factor=avg_factor, + avg_non_ignore=self.avg_non_ignore, + ignore_index=ignore_index, + **kwargs) + return loss_cls + + @property + def loss_name(self): + """Loss Name. + + This function must be implemented and will return the name of this + loss function. This name will be used to combine different loss items + by simple sum operation. In addition, if you want this loss item to be + included into the backward graph, `loss_` must be the prefix of the + name. + + Returns: + str: The name of this loss item. + """ + return self._loss_name diff --git a/segmentation/mmseg_custom/models/losses/dice_loss.py b/segmentation/mmseg_custom/models/losses/dice_loss.py new file mode 100644 index 000000000..e84b07e6f --- /dev/null +++ b/segmentation/mmseg_custom/models/losses/dice_loss.py @@ -0,0 +1,179 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmseg.models.builder import LOSSES +from mmseg.models.losses.utils import weight_reduce_loss + + +def dice_loss(pred, + target, + weight=None, + eps=1e-3, + reduction='mean', + avg_factor=None): + """Calculate dice loss, which is proposed in + `V-Net: Fully Convolutional Neural Networks for Volumetric + Medical Image Segmentation `_. + + Args: + pred (torch.Tensor): The prediction, has a shape (n, *) + target (torch.Tensor): The learning label of the prediction, + shape (n, *), same shape of pred. + weight (torch.Tensor, optional): The weight of loss for each + prediction, has a shape (n,). Defaults to None. + eps (float): Avoid dividing by zero. Default: 1e-3. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. 
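For intuition, a toy comparison of the two dice formulations this file provides, whose bodies follow below: the V-Net form with squared terms in the denominator, and the "naive" form with plain sums. Inputs are tiny hand-made masks and the eps matches the defaults above.

import torch

def vnet_dice(pred, target, eps=1e-3):
    # denominator uses squared sums, as in the V-Net paper
    a = (pred * target).sum(1)
    b = (pred * pred).sum(1) + eps
    c = (target * target).sum(1) + eps
    return 1 - 2 * a / (b + c)

def naive_dice(pred, target, eps=1e-3):
    # denominator uses first powers
    a = (pred * target).sum(1)
    return 1 - (2 * a + eps) / (pred.sum(1) + target.sum(1) + eps)

pred = torch.tensor([[0.9, 0.8, 0.1, 0.0]])
target = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
print(vnet_dice(pred, target), naive_dice(pred, target))
# roughly tensor([0.0179]) and tensor([0.1052])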
+ """ + + input = pred.flatten(1) + target = target.flatten(1).float() + + a = torch.sum(input * target, 1) + b = torch.sum(input * input, 1) + eps + c = torch.sum(target * target, 1) + eps + d = (2 * a) / (b + c) + loss = 1 - d + if weight is not None: + assert weight.ndim == loss.ndim + assert len(weight) == len(pred) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +def naive_dice_loss(pred, + target, + weight=None, + eps=1e-3, + reduction='mean', + avg_factor=None): + """Calculate naive dice loss, the coefficient in the denominator is the + first power instead of the second power. + + Args: + pred (torch.Tensor): The prediction, has a shape (n, *) + target (torch.Tensor): The learning label of the prediction, + shape (n, *), same shape of pred. + weight (torch.Tensor, optional): The weight of loss for each + prediction, has a shape (n,). Defaults to None. + eps (float): Avoid dividing by zero. Default: 1e-3. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + input = pred.flatten(1) + target = target.flatten(1).float() + + a = torch.sum(input * target, 1) + b = torch.sum(input, 1) + c = torch.sum(target, 1) + d = (2 * a + eps) / (b + c + eps) + loss = 1 - d + if weight is not None: + assert weight.ndim == loss.ndim + assert len(weight) == len(pred) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@LOSSES.register_module(force=True) +class DiceLoss(nn.Module): + def __init__(self, + use_sigmoid=True, + activate=True, + reduction='mean', + naive_dice=False, + loss_weight=1.0, + eps=1e-3): + """Dice Loss, there are two forms of dice loss is supported: + + - the one proposed in `V-Net: Fully Convolutional Neural + Networks for Volumetric Medical Image Segmentation + `_. + - the dice loss in which the power of the number in the + denominator is the first power instead of the second + power. + + Args: + use_sigmoid (bool, optional): Whether to the prediction is + used for sigmoid or softmax. Defaults to True. + activate (bool): Whether to activate the predictions inside, + this will disable the inside sigmoid operation. + Defaults to True. + reduction (str, optional): The method used + to reduce the loss. Options are "none", + "mean" and "sum". Defaults to 'mean'. + naive_dice (bool, optional): If false, use the dice + loss defined in the V-Net paper, otherwise, use the + naive dice loss in which the power of the number in the + denominator is the first power instead of the second + power.Defaults to False. + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + eps (float): Avoid dividing by zero. Defaults to 1e-3. + """ + + super(DiceLoss, self).__init__() + self.use_sigmoid = use_sigmoid + self.reduction = reduction + self.naive_dice = naive_dice + self.loss_weight = loss_weight + self.eps = eps + self.activate = activate + + def forward(self, + pred, + target, + weight=None, + reduction_override=None, + avg_factor=None): + """Forward function. + + Args: + pred (torch.Tensor): The prediction, has a shape (n, *). + target (torch.Tensor): The label of the prediction, + shape (n, *), same shape of pred. + weight (torch.Tensor, optional): The weight of loss for each + prediction, has a shape (n,). Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. 
+ reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Options are "none", "mean" and "sum". + + Returns: + torch.Tensor: The calculated loss + """ + + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = (reduction_override + if reduction_override else self.reduction) + + if self.activate: + if self.use_sigmoid: + pred = pred.sigmoid() + else: + raise NotImplementedError + + if self.naive_dice: + loss = self.loss_weight * naive_dice_loss( + pred, + target, + weight, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor) + else: + loss = self.loss_weight * dice_loss( + pred, + target, + weight, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor) + + return loss \ No newline at end of file diff --git a/segmentation/mmseg_custom/models/losses/focal_loss.py b/segmentation/mmseg_custom/models/losses/focal_loss.py new file mode 100644 index 000000000..3d48a2bb6 --- /dev/null +++ b/segmentation/mmseg_custom/models/losses/focal_loss.py @@ -0,0 +1,180 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss +from mmseg.models.builder import LOSSES +from mmseg.models.losses.utils import weight_reduce_loss + + +# This method is only for debugging +def py_sigmoid_focal_loss(pred, + target, + weight=None, + gamma=2.0, + alpha=0.25, + reduction='mean', + avg_factor=None): + """PyTorch version of `Focal Loss `_. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the + number of classes + target (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float, optional): A balanced form for Focal Loss. + Defaults to 0.25. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + pred_sigmoid = pred.sigmoid() + target = target.type_as(pred) + pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) + focal_weight = (alpha * target + (1 - alpha) * + (1 - target)) * pt.pow(gamma) + loss = F.binary_cross_entropy_with_logits( + pred, target, reduction='none') * focal_weight + if weight is not None: + if weight.shape != loss.shape: + if weight.size(0) == loss.size(0): + # For most cases, weight is of shape (num_priors, ), + # which means it does not have the second axis num_class + weight = weight.view(-1, 1) + else: + # Sometimes, weight per anchor per class is also needed. e.g. + # in FSAF. But it may be flattened of shape + # (num_priors x num_class, ), while loss is still of shape + # (num_priors, num_class). + assert weight.numel() == loss.numel() + weight = weight.view(loss.size(0), -1) + assert weight.ndim == loss.ndim + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +def sigmoid_focal_loss(pred, + target, + weight=None, + gamma=2.0, + alpha=0.25, + reduction='mean', + avg_factor=None): + r"""A warpper of cuda version `Focal Loss + `_. + + Args: + pred (torch.Tensor): The prediction with shape (N, C), C is the number + of classes. + target (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): Sample-wise loss weight. 
+ gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float, optional): A balanced form for Focal Loss. + Defaults to 0.25. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum". + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + """ + # Function.apply does not accept keyword arguments, so the decorator + # "weighted_loss" is not applicable + loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(), gamma, + alpha, None, 'none') + if weight is not None: + if weight.shape != loss.shape: + if weight.size(0) == loss.size(0): + # For most cases, weight is of shape (num_priors, ), + # which means it does not have the second axis num_class + weight = weight.view(-1, 1) + else: + # Sometimes, weight per anchor per class is also needed. e.g. + # in FSAF. But it may be flattened of shape + # (num_priors x num_class, ), while loss is still of shape + # (num_priors, num_class). + assert weight.numel() == loss.numel() + weight = weight.view(loss.size(0), -1) + assert weight.ndim == loss.ndim + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + +@LOSSES.register_module(force=True) +class FocalLoss(nn.Module): + def __init__(self, + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + reduction='mean', + loss_weight=1.0): + """`Focal Loss `_ + + Args: + use_sigmoid (bool, optional): Whether to the prediction is + used for sigmoid or softmax. Defaults to True. + gamma (float, optional): The gamma for calculating the modulating + factor. Defaults to 2.0. + alpha (float, optional): A balanced form for Focal Loss. + Defaults to 0.25. + reduction (str, optional): The method used to reduce the loss into + a scalar. Defaults to 'mean'. Options are "none", "mean" and + "sum". + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + """ + super(FocalLoss, self).__init__() + assert use_sigmoid is True, 'Only sigmoid focal loss supported now.' + self.use_sigmoid = use_sigmoid + self.gamma = gamma + self.alpha = alpha + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None): + """Forward function. + + Args: + pred (torch.Tensor): The prediction. + target (torch.Tensor): The learning label of the prediction. + weight (torch.Tensor, optional): The weight of loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Options are "none", "mean" and "sum". 
+ + Returns: + torch.Tensor: The calculated loss + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if self.use_sigmoid: + if torch.cuda.is_available() and pred.is_cuda: + calculate_loss_func = sigmoid_focal_loss + else: + num_classes = pred.size(1) + target = F.one_hot(target, num_classes=num_classes + 1) + target = target[:, :num_classes] + calculate_loss_func = py_sigmoid_focal_loss + + loss_cls = self.loss_weight * calculate_loss_func( + pred, + target, + weight, + gamma=self.gamma, + alpha=self.alpha, + reduction=reduction, + avg_factor=avg_factor) + + else: + raise NotImplementedError + return loss_cls diff --git a/segmentation/mmseg_custom/models/losses/match_costs.py b/segmentation/mmseg_custom/models/losses/match_costs.py new file mode 100644 index 000000000..9c04862d4 --- /dev/null +++ b/segmentation/mmseg_custom/models/losses/match_costs.py @@ -0,0 +1,233 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import MATCH_COST + + +@MATCH_COST.register_module() +class FocalLossCost: + """FocalLossCost. + + Args: + weight (int | float, optional): loss_weight + alpha (int | float, optional): focal_loss alpha + gamma (int | float, optional): focal_loss gamma + eps (float, optional): default 1e-12 + + Examples: + >>> from mmdet.core.bbox.match_costs.match_cost import FocalLossCost + >>> import torch + >>> self = FocalLossCost() + >>> cls_pred = torch.rand(4, 3) + >>> gt_labels = torch.tensor([0, 1, 2]) + >>> factor = torch.tensor([10, 8, 10, 8]) + >>> self(cls_pred, gt_labels) + tensor([[-0.3236, -0.3364, -0.2699], + [-0.3439, -0.3209, -0.4807], + [-0.4099, -0.3795, -0.2929], + [-0.1950, -0.1207, -0.2626]]) + """ + def __init__(self, weight=1., alpha=0.25, gamma=2, eps=1e-12): + self.weight = weight + self.alpha = alpha + self.gamma = gamma + self.eps = eps + + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + + Returns: + torch.Tensor: cls_cost value with weight + """ + cls_pred = cls_pred.sigmoid() + neg_cost = -(1 - cls_pred + self.eps).log() * ( + 1 - self.alpha) * cls_pred.pow(self.gamma) + pos_cost = -(cls_pred + self.eps).log() * self.alpha * ( + 1 - cls_pred).pow(self.gamma) + cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels] + return cls_cost * self.weight + + +@MATCH_COST.register_module() +class MaskFocalLossCost(FocalLossCost): + """Cost of mask assignments based on focal losses. + + Args: + weight (int | float, optional): loss_weight. + alpha (int | float, optional): focal_loss alpha. + gamma (int | float, optional): focal_loss gamma. + eps (float, optional): default 1e-12. + """ + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classfication logits + in shape (N1, H, W), dtype=torch.float32. + gt_labels (Tensor): Ground truth in shape (N2, H, W), + dtype=torch.long. + + Returns: + Tensor: classification cost matrix in shape (N1, N2). 
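The pairwise mask focal cost defined here (its `__call__` body continues just below) splits into a "positive" and a "negative" per-pixel cost and combines them against the binary GT masks with two einsums, producing an (N1, N2) matrix normalized by the pixel count. A toy run with random masks and invented sizes:

import torch

alpha, gamma, eps = 0.25, 2.0, 1e-12
n1, n2, hw = 5, 3, 64                       # 5 predicted masks, 3 GT masks
cls_pred = torch.randn(n1, hw).sigmoid()    # flattened mask probabilities
gt = torch.randint(0, 2, (n2, hw)).float()  # flattened binary GT masks

neg_cost = -(1 - cls_pred + eps).log() * (1 - alpha) * cls_pred.pow(gamma)
pos_cost = -(cls_pred + eps).log() * alpha * (1 - cls_pred).pow(gamma)

cost = torch.einsum('nc,mc->nm', pos_cost, gt) + \
    torch.einsum('nc,mc->nm', neg_cost, 1 - gt)
cost = cost / hw
print(cost.shape)  # torch.Size([5, 3])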
+ """ + cls_pred = cls_pred.reshape((cls_pred.shape[0], -1)) + gt_labels = gt_labels.reshape((gt_labels.shape[0], -1)).float() + hw = cls_pred.shape[1] + cls_pred = cls_pred.sigmoid() + neg_cost = -(1 - cls_pred + self.eps).log() * ( + 1 - self.alpha) * cls_pred.pow(self.gamma) + pos_cost = -(cls_pred + self.eps).log() * self.alpha * ( + 1 - cls_pred).pow(self.gamma) + + cls_cost = torch.einsum('nc,mc->nm', pos_cost, gt_labels) + \ + torch.einsum('nc,mc->nm', neg_cost, (1 - gt_labels)) + return cls_cost / hw * self.weight + + +@MATCH_COST.register_module() +class ClassificationCost: + """ClsSoftmaxCost.Borrow from + mmdet.core.bbox.match_costs.match_cost.ClassificationCost. + + Args: + weight (int | float, optional): loss_weight + + Examples: + >>> import torch + >>> self = ClassificationCost() + >>> cls_pred = torch.rand(4, 3) + >>> gt_labels = torch.tensor([0, 1, 2]) + >>> factor = torch.tensor([10, 8, 10, 8]) + >>> self(cls_pred, gt_labels) + tensor([[-0.3430, -0.3525, -0.3045], + [-0.3077, -0.2931, -0.3992], + [-0.3664, -0.3455, -0.2881], + [-0.3343, -0.2701, -0.3956]]) + """ + def __init__(self, weight=1.): + self.weight = weight + + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + + Returns: + torch.Tensor: cls_cost value with weight + """ + # Following the official DETR repo, contrary to the loss that + # NLL is used, we approximate it in 1 - cls_score[gt_label]. + # The 1 is a constant that doesn't change the matching, + # so it can be omitted. + cls_score = cls_pred.softmax(-1) + cls_cost = -cls_score[:, gt_labels] + return cls_cost * self.weight + + +@MATCH_COST.register_module() +class DiceCost: + """Cost of mask assignments based on dice losses. + + Args: + weight (int | float, optional): loss_weight. Defaults to 1. + pred_act (bool, optional): Whether to apply sigmoid to mask_pred. + Defaults to False. + eps (float, optional): default 1e-12. + """ + def __init__(self, weight=1., pred_act=False, eps=1e-3): + self.weight = weight + self.pred_act = pred_act + self.eps = eps + + def binary_mask_dice_loss(self, mask_preds, gt_masks): + """ + Args: + mask_preds (Tensor): Mask prediction in shape (N1, H, W). + gt_masks (Tensor): Ground truth in shape (N2, H, W) + store 0 or 1, 0 for negative class and 1 for + positive class. + + Returns: + Tensor: Dice cost matrix in shape (N1, N2). + """ + mask_preds = mask_preds.reshape((mask_preds.shape[0], -1)) + gt_masks = gt_masks.reshape((gt_masks.shape[0], -1)).float() + numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks) + denominator = mask_preds.sum(-1)[:, None] + gt_masks.sum(-1)[None, :] + loss = 1 - (numerator + self.eps) / (denominator + self.eps) + return loss + + def __call__(self, mask_preds, gt_masks): + """ + Args: + mask_preds (Tensor): Mask prediction logits in shape (N1, H, W). + gt_masks (Tensor): Ground truth in shape (N2, H, W). + + Returns: + Tensor: Dice cost matrix in shape (N1, N2). + """ + if self.pred_act: + mask_preds = mask_preds.sigmoid() + dice_cost = self.binary_mask_dice_loss(mask_preds, gt_masks) + return dice_cost * self.weight + + +@MATCH_COST.register_module() +class CrossEntropyLossCost: + """CrossEntropyLossCost. + + Args: + weight (int | float, optional): loss weight. Defaults to 1. + use_sigmoid (bool, optional): Whether the prediction uses sigmoid + of softmax. Defaults to True. 
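The `DiceCost` above builds the full (N1, N2) dice matrix with a single einsum rather than a Python loop over prediction/GT pairs. A quick check on random data (eps as in the class default) that the vectorized form matches a per-pair computation:

import torch

eps = 1e-3
pred = torch.rand(4, 32)      # 4 predicted masks, already sigmoided and flattened
gt = torch.randint(0, 2, (2, 32)).float()

num = 2 * torch.einsum('nc,mc->nm', pred, gt)
den = pred.sum(-1)[:, None] + gt.sum(-1)[None, :]
cost = 1 - (num + eps) / (den + eps)

# reference: explicit double loop
ref = torch.zeros(4, 2)
for i in range(4):
    for j in range(2):
        inter = (pred[i] * gt[j]).sum()
        ref[i, j] = 1 - (2 * inter + eps) / (pred[i].sum() + gt[j].sum() + eps)

print(torch.allclose(cost, ref, atol=1e-6))  # True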
+ """ + def __init__(self, weight=1., use_sigmoid=True): + assert use_sigmoid, 'use_sigmoid = False is not supported yet.' + self.weight = weight + self.use_sigmoid = use_sigmoid + + def _binary_cross_entropy(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): The prediction with shape (num_query, 1, *) or + (num_query, *). + gt_labels (Tensor): The learning label of prediction with + shape (num_gt, *). + Returns: + Tensor: Cross entropy cost matrix in shape (num_query, num_gt). + """ + cls_pred = cls_pred.flatten(1).float() + gt_labels = gt_labels.flatten(1).float() + n = cls_pred.shape[1] + pos = F.binary_cross_entropy_with_logits( + cls_pred, torch.ones_like(cls_pred), reduction='none') + neg = F.binary_cross_entropy_with_logits( + cls_pred, torch.zeros_like(cls_pred), reduction='none') + cls_cost = torch.einsum('nc,mc->nm', pos, gt_labels) + \ + torch.einsum('nc,mc->nm', neg, 1 - gt_labels) + cls_cost = cls_cost / n + + return cls_cost + + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classification logits. + gt_labels (Tensor): Labels. + Returns: + Tensor: Cross entropy cost matrix with weight in + shape (num_query, num_gt). + """ + if self.use_sigmoid: + cls_cost = self._binary_cross_entropy(cls_pred, gt_labels) + else: + raise NotImplementedError + + return cls_cost * self.weight diff --git a/segmentation/mmseg_custom/models/losses/match_loss.py b/segmentation/mmseg_custom/models/losses/match_loss.py new file mode 100644 index 000000000..3c53839ac --- /dev/null +++ b/segmentation/mmseg_custom/models/losses/match_loss.py @@ -0,0 +1,179 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import MATCH_COST + + +@MATCH_COST.register_module() +class FocalLossCost: + """FocalLossCost. + + Args: + weight (int | float, optional): loss_weight + alpha (int | float, optional): focal_loss alpha + gamma (int | float, optional): focal_loss gamma + eps (float, optional): default 1e-12 + + Examples: + >>> from mmdet.core.bbox.match_costs.match_cost import FocalLossCost + >>> import torch + >>> self = FocalLossCost() + >>> cls_pred = torch.rand(4, 3) + >>> gt_labels = torch.tensor([0, 1, 2]) + >>> factor = torch.tensor([10, 8, 10, 8]) + >>> self(cls_pred, gt_labels) + tensor([[-0.3236, -0.3364, -0.2699], + [-0.3439, -0.3209, -0.4807], + [-0.4099, -0.3795, -0.2929], + [-0.1950, -0.1207, -0.2626]]) + """ + def __init__(self, weight=1., alpha=0.25, gamma=2, eps=1e-12): + self.weight = weight + self.alpha = alpha + self.gamma = gamma + self.eps = eps + + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + + Returns: + torch.Tensor: cls_cost value with weight + """ + cls_pred = cls_pred.sigmoid() + neg_cost = -(1 - cls_pred + self.eps).log() * ( + 1 - self.alpha) * cls_pred.pow(self.gamma) + pos_cost = -(cls_pred + self.eps).log() * self.alpha * ( + 1 - cls_pred).pow(self.gamma) + cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels] + return cls_cost * self.weight + + +@MATCH_COST.register_module() +class MaskFocalLossCost(FocalLossCost): + """Cost of mask assignments based on focal losses. + + Args: + weight (int | float, optional): loss_weight. + alpha (int | float, optional): focal_loss alpha. + gamma (int | float, optional): focal_loss gamma. + eps (float, optional): default 1e-12. 
+ """ + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classfication logits + in shape (N1, H, W), dtype=torch.float32. + gt_labels (Tensor): Ground truth in shape (N2, H, W), + dtype=torch.long. + + Returns: + Tensor: classification cost matrix in shape (N1, N2). + """ + cls_pred = cls_pred.reshape((cls_pred.shape[0], -1)) + gt_labels = gt_labels.reshape((gt_labels.shape[0], -1)).float() + hw = cls_pred.shape[1] + cls_pred = cls_pred.sigmoid() + neg_cost = -(1 - cls_pred + self.eps).log() * ( + 1 - self.alpha) * cls_pred.pow(self.gamma) + pos_cost = -(cls_pred + self.eps).log() * self.alpha * ( + 1 - cls_pred).pow(self.gamma) + + cls_cost = torch.einsum('nc,mc->nm', pos_cost, gt_labels) + \ + torch.einsum('nc,mc->nm', neg_cost, (1 - gt_labels)) + return cls_cost / hw * self.weight + + +@MATCH_COST.register_module() +class ClassificationCost: + """ClsSoftmaxCost.Borrow from + mmdet.core.bbox.match_costs.match_cost.ClassificationCost. + + Args: + weight (int | float, optional): loss_weight + + Examples: + >>> import torch + >>> self = ClassificationCost() + >>> cls_pred = torch.rand(4, 3) + >>> gt_labels = torch.tensor([0, 1, 2]) + >>> factor = torch.tensor([10, 8, 10, 8]) + >>> self(cls_pred, gt_labels) + tensor([[-0.3430, -0.3525, -0.3045], + [-0.3077, -0.2931, -0.3992], + [-0.3664, -0.3455, -0.2881], + [-0.3343, -0.2701, -0.3956]]) + """ + def __init__(self, weight=1.): + self.weight = weight + + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + + Returns: + torch.Tensor: cls_cost value with weight + """ + # Following the official DETR repo, contrary to the loss that + # NLL is used, we approximate it in 1 - cls_score[gt_label]. + # The 1 is a constant that doesn't change the matching, + # so it can be omitted. + cls_score = cls_pred.softmax(-1) + cls_cost = -cls_score[:, gt_labels] + return cls_cost * self.weight + + +@MATCH_COST.register_module() +class DiceCost: + """Cost of mask assignments based on dice losses. + + Args: + weight (int | float, optional): loss_weight. Defaults to 1. + pred_act (bool, optional): Whether to apply sigmoid to mask_pred. + Defaults to False. + eps (float, optional): default 1e-12. + """ + def __init__(self, weight=1., pred_act=False, eps=1e-3): + self.weight = weight + self.pred_act = pred_act + self.eps = eps + + def binary_mask_dice_loss(self, mask_preds, gt_masks): + """ + Args: + mask_preds (Tensor): Mask prediction in shape (N1, H, W). + gt_masks (Tensor): Ground truth in shape (N2, H, W) + store 0 or 1, 0 for negative class and 1 for + positive class. + + Returns: + Tensor: Dice cost matrix in shape (N1, N2). + """ + mask_preds = mask_preds.reshape((mask_preds.shape[0], -1)) + gt_masks = gt_masks.reshape((gt_masks.shape[0], -1)).float() + numerator = 2 * torch.einsum('nc,mc->nm', mask_preds, gt_masks) + denominator = mask_preds.sum(-1)[:, None] + gt_masks.sum(-1)[None, :] + loss = 1 - (numerator + self.eps) / (denominator + self.eps) + return loss + + def __call__(self, mask_preds, gt_masks): + """ + Args: + mask_preds (Tensor): Mask prediction logits in shape (N1, H, W). + gt_masks (Tensor): Ground truth in shape (N2, H, W). + + Returns: + Tensor: Dice cost matrix in shape (N1, N2). 
+ """ + if self.pred_act: + mask_preds = mask_preds.sigmoid() + dice_cost = self.binary_mask_dice_loss(mask_preds, gt_masks) + return dice_cost * self.weight diff --git a/segmentation/mmseg_custom/models/plugins/__init__.py b/segmentation/mmseg_custom/models/plugins/__init__.py new file mode 100644 index 000000000..c1e91ae63 --- /dev/null +++ b/segmentation/mmseg_custom/models/plugins/__init__.py @@ -0,0 +1,7 @@ +from .msdeformattn_pixel_decoder import MSDeformAttnPixelDecoder +from .pixel_decoder import PixelDecoder, TransformerEncoderPixelDecoder + +__all__ = [ + 'PixelDecoder', 'TransformerEncoderPixelDecoder', + 'MSDeformAttnPixelDecoder' +] diff --git a/segmentation/mmseg_custom/models/plugins/msdeformattn_pixel_decoder.py b/segmentation/mmseg_custom/models/plugins/msdeformattn_pixel_decoder.py new file mode 100644 index 000000000..3a4ea63ca --- /dev/null +++ b/segmentation/mmseg_custom/models/plugins/msdeformattn_pixel_decoder.py @@ -0,0 +1,268 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import (PLUGIN_LAYERS, Conv2d, ConvModule, caffe2_xavier_init, + normal_init, xavier_init) +from mmcv.cnn.bricks.transformer import (build_positional_encoding, + build_transformer_layer_sequence) +from mmcv.runner import BaseModule, ModuleList + +from ...core.anchor import MlvlPointGenerator +from ..utils.transformer import MultiScaleDeformableAttention + + +@PLUGIN_LAYERS.register_module() +class MSDeformAttnPixelDecoder(BaseModule): + """Pixel decoder with multi-scale deformable attention. + + Args: + in_channels (list[int] | tuple[int]): Number of channels in the + input feature maps. + strides (list[int] | tuple[int]): Output strides of feature from + backbone. + feat_channels (int): Number of channels for feature. + out_channels (int): Number of channels for output. + num_outs (int): Number of output scales. + norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization. + Defaults to dict(type='GN', num_groups=32). + act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation. + Defaults to dict(type='ReLU'). + encoder (:obj:`mmcv.ConfigDict` | dict): Config for transformer + encoder. Defaults to `DetrTransformerEncoder`. + positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for + transformer encoder position encoding. Defaults to + dict(type='SinePositionalEncoding', num_feats=128, + normalize=True). + init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict. 
+ """ + def __init__(self, + in_channels=[256, 512, 1024, 2048], + strides=[4, 8, 16, 32], + feat_channels=256, + out_channels=256, + num_outs=3, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', + embed_dims=256, + num_heads=8, + num_levels=3, + num_points=4, + im2col_step=64, + dropout=0.0, + batch_first=False, + norm_cfg=None, + init_cfg=None), + feedforward_channels=1024, + ffn_dropout=0.0, + operation_order=('self_attn', 'norm', 'ffn', 'norm')), + init_cfg=None), + positional_encoding=dict( + type='SinePositionalEncoding', + num_feats=128, + normalize=True), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.strides = strides + self.num_input_levels = len(in_channels) + self.num_encoder_levels = \ + encoder.transformerlayers.attn_cfgs.num_levels + assert self.num_encoder_levels >= 1, \ + 'num_levels in attn_cfgs must be at least one' + input_conv_list = [] + # from top to down (low to high resolution) + for i in range(self.num_input_levels - 1, + self.num_input_levels - self.num_encoder_levels - 1, + -1): + input_conv = ConvModule( + in_channels[i], + feat_channels, + kernel_size=1, + norm_cfg=norm_cfg, + act_cfg=None, + bias=True) + input_conv_list.append(input_conv) + self.input_convs = ModuleList(input_conv_list) + + self.encoder = build_transformer_layer_sequence(encoder) + self.postional_encoding = build_positional_encoding( + positional_encoding) + # high resolution to low resolution + self.level_encoding = nn.Embedding(self.num_encoder_levels, + feat_channels) + + # fpn-like structure + self.lateral_convs = ModuleList() + self.output_convs = ModuleList() + self.use_bias = norm_cfg is None + # from top to down (low to high resolution) + # fpn for the rest features that didn't pass in encoder + for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, + -1): + lateral_conv = ConvModule( + in_channels[i], + feat_channels, + kernel_size=1, + bias=self.use_bias, + norm_cfg=norm_cfg, + act_cfg=None) + output_conv = ConvModule( + feat_channels, + feat_channels, + kernel_size=3, + stride=1, + padding=1, + bias=self.use_bias, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.lateral_convs.append(lateral_conv) + self.output_convs.append(output_conv) + + self.mask_feature = Conv2d( + feat_channels, out_channels, kernel_size=1, stride=1, padding=0) + + self.num_outs = num_outs + self.point_generator = MlvlPointGenerator(strides) + + def init_weights(self): + """Initialize weights.""" + for i in range(0, self.num_encoder_levels): + xavier_init( + self.input_convs[i].conv, + gain=1, + bias=0, + distribution='uniform') + + for i in range(0, self.num_input_levels - self.num_encoder_levels): + caffe2_xavier_init(self.lateral_convs[i].conv, bias=0) + caffe2_xavier_init(self.output_convs[i].conv, bias=0) + + caffe2_xavier_init(self.mask_feature, bias=0) + + normal_init(self.level_encoding, mean=0, std=1) + for p in self.encoder.parameters(): + if p.dim() > 1: + nn.init.xavier_normal_(p) + + # init_weights defined in MultiScaleDeformableAttention + for layer in self.encoder.layers: + for attn in layer.attentions: + if isinstance(attn, MultiScaleDeformableAttention): + attn.init_weights() + + def forward(self, feats): + """ + Args: + feats (list[Tensor]): Feature maps of each level. Each has + shape of (batch_size, c, h, w). 
+ + Returns: + tuple: A tuple containing the following: + + - mask_feature (Tensor): shape (batch_size, c, h, w). + - multi_scale_features (list[Tensor]): Multi scale \ + features, each in shape (batch_size, c, h, w). + """ + # generate padding mask for each level, for each image + batch_size = feats[0].shape[0] + encoder_input_list = [] + padding_mask_list = [] + level_positional_encoding_list = [] + spatial_shapes = [] + reference_points_list = [] + for i in range(self.num_encoder_levels): + level_idx = self.num_input_levels - i - 1 + feat = feats[level_idx] + feat_projected = self.input_convs[i](feat) + h, w = feat.shape[-2:] + + # no padding + padding_mask_resized = feat.new_zeros( + (batch_size, ) + feat.shape[-2:], dtype=torch.bool) + pos_embed = self.postional_encoding(padding_mask_resized) + level_embed = self.level_encoding.weight[i] + level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed + # (h_i * w_i, 2) + reference_points = self.point_generator.single_level_grid_priors( + feat.shape[-2:], level_idx, device=feat.device) + # normalize + factor = feat.new_tensor([[w, h]]) * self.strides[level_idx] + reference_points = reference_points / factor + + # shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c) + feat_projected = feat_projected.flatten(2).permute(2, 0, 1) + level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1) + padding_mask_resized = padding_mask_resized.flatten(1) + + encoder_input_list.append(feat_projected) + padding_mask_list.append(padding_mask_resized) + level_positional_encoding_list.append(level_pos_embed) + spatial_shapes.append(feat.shape[-2:]) + reference_points_list.append(reference_points) + # shape (batch_size, total_num_query), + # total_num_query=sum([., h_i * w_i,.]) + padding_masks = torch.cat(padding_mask_list, dim=1) + # shape (total_num_query, batch_size, c) + encoder_inputs = torch.cat(encoder_input_list, dim=0) + level_positional_encodings = torch.cat( + level_positional_encoding_list, dim=0) + device = encoder_inputs.device + # shape (num_encoder_levels, 2), from low + # resolution to high resolution + spatial_shapes = torch.as_tensor( + spatial_shapes, dtype=torch.long, device=device) + # shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...) 
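+        # e.g. with spatial_shapes [[16, 16], [32, 32], [64, 64]] (purely
+        # illustrative numbers) this yields tensor([0, 256, 1280])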
+ level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + reference_points = torch.cat(reference_points_list, dim=0) + reference_points = reference_points[None, :, None].repeat( + batch_size, 1, self.num_encoder_levels, 1) + valid_radios = reference_points.new_ones( + (batch_size, self.num_encoder_levels, 2)) + # shape (num_total_query, batch_size, c) + memory = self.encoder( + query=encoder_inputs, + key=None, + value=None, + query_pos=level_positional_encodings, + key_pos=None, + attn_masks=None, + key_padding_mask=None, + query_key_padding_mask=padding_masks, + spatial_shapes=spatial_shapes, + reference_points=reference_points, + level_start_index=level_start_index, + valid_radios=valid_radios) + # (num_total_query, batch_size, c) -> (batch_size, c, num_total_query) + memory = memory.permute(1, 2, 0) + + # from low resolution to high resolution + num_query_per_level = [e[0] * e[1] for e in spatial_shapes] + outs = torch.split(memory, num_query_per_level, dim=-1) + outs = [ + x.reshape(batch_size, -1, spatial_shapes[i][0], + spatial_shapes[i][1]) for i, x in enumerate(outs) + ] + + for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, + -1): + x = feats[i] + cur_feat = self.lateral_convs[i](x) + y = cur_feat + F.interpolate( + outs[-1], + size=cur_feat.shape[-2:], + mode='bilinear', + align_corners=False) + y = self.output_convs[i](y) + outs.append(y) + multi_scale_features = outs[:self.num_outs] + + mask_feature = self.mask_feature(outs[-1]) + return mask_feature, multi_scale_features \ No newline at end of file diff --git a/segmentation/mmseg_custom/models/plugins/pixel_decoder.py b/segmentation/mmseg_custom/models/plugins/pixel_decoder.py new file mode 100644 index 000000000..62e488f68 --- /dev/null +++ b/segmentation/mmseg_custom/models/plugins/pixel_decoder.py @@ -0,0 +1,237 @@ +import torch +import torch.nn.functional as F +from mmcv.cnn import PLUGIN_LAYERS, Conv2d, ConvModule, kaiming_init +from mmcv.cnn.bricks.transformer import (build_positional_encoding, + build_transformer_layer_sequence) +from mmcv.runner import BaseModule, ModuleList + + +@PLUGIN_LAYERS.register_module() +class PixelDecoder(BaseModule): + """Pixel decoder with a structure like fpn. + + Args: + in_channels (list[int] | tuple[int]): Number of channels in the + input feature maps. + feat_channels (int): Number channels for feature. + out_channels (int): Number channels for output. + norm_cfg (obj:`mmcv.ConfigDict`|dict): Config for normalization. + Defaults to dict(type='GN', num_groups=32). + act_cfg (obj:`mmcv.ConfigDict`|dict): Config for activation. + Defaults to dict(type='ReLU'). + encoder (obj:`mmcv.ConfigDict`|dict): Config for transorformer + encoder.Defaults to None. + positional_encoding (obj:`mmcv.ConfigDict`|dict): Config for + transformer encoder position encoding. Defaults to + dict(type='SinePositionalEncoding', num_feats=128, + normalize=True). + init_cfg (obj:`mmcv.ConfigDict`|dict): Initialization config dict. 
+ Default: None + """ + def __init__(self, + in_channels, + feat_channels, + out_channels, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.num_inputs = len(in_channels) + self.lateral_convs = ModuleList() + self.output_convs = ModuleList() + self.use_bias = norm_cfg is None + for i in range(0, self.num_inputs - 1): + l_conv = ConvModule( + in_channels[i], + feat_channels, + kernel_size=1, + bias=self.use_bias, + norm_cfg=norm_cfg, + act_cfg=None) + o_conv = ConvModule( + feat_channels, + feat_channels, + kernel_size=3, + stride=1, + padding=1, + bias=self.use_bias, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.lateral_convs.append(l_conv) + self.output_convs.append(o_conv) + + self.last_feat_conv = ConvModule( + in_channels[-1], + feat_channels, + kernel_size=3, + padding=1, + stride=1, + bias=self.use_bias, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.mask_feature = Conv2d( + feat_channels, out_channels, kernel_size=3, stride=1, padding=1) + + def init_weights(self): + """Initialize weights.""" + for i in range(0, self.num_inputs - 2): + kaiming_init(self.lateral_convs[i].conv, a=1) + kaiming_init(self.output_convs[i].conv, a=1) + + kaiming_init(self.mask_feature, a=1) + kaiming_init(self.last_feat_conv, a=1) + + def forward(self, feats, img_metas): + """ + Args: + feats (list[Tensor]): Feature maps of each level. Each has + shape of [bs, c, h, w]. + img_metas (list[dict]): List of image information. Pass in + for creating more accurate padding mask. #! not used here. + + Returns: + tuple: a tuple containing the following: + + - mask_feature (Tensor): Shape [bs, c, h, w]. + - memory (Tensor): Output of last stage of backbone. + Shape [bs, c, h, w]. + """ + y = self.last_feat_conv(feats[-1]) + for i in range(self.num_inputs - 2, -1, -1): + x = feats[i] + cur_fpn = self.lateral_convs[i](x) + y = cur_fpn + \ + F.interpolate(y, size=cur_fpn.shape[-2:], mode='nearest') + y = self.output_convs[i](y) + + mask_feature = self.mask_feature(y) + memory = feats[-1] + return mask_feature, memory + + +@PLUGIN_LAYERS.register_module() +class TransformerEncoderPixelDecoder(PixelDecoder): + """Pixel decoder with transormer encoder inside. + + Args: + in_channels (list[int] | tuple[int]): Number of channels in the + input feature maps. + feat_channels (int): Number channels for feature. + out_channels (int): Number channels for output. + norm_cfg (obj:`mmcv.ConfigDict`|dict): Config for normalization. + Defaults to dict(type='GN', num_groups=32). + act_cfg (obj:`mmcv.ConfigDict`|dict): Config for activation. + Defaults to dict(type='ReLU'). + encoder (obj:`mmcv.ConfigDict`|dict): Config for transorformer + encoder.Defaults to None. + positional_encoding (obj:`mmcv.ConfigDict`|dict): Config for + transformer encoder position encoding. Defaults to + dict(type='SinePositionalEncoding', num_feats=128, + normalize=True). + init_cfg (obj:`mmcv.ConfigDict`|dict): Initialization config dict. 
+ Default: None + """ + def __init__(self, + in_channels, + feat_channels, + out_channels, + norm_cfg=dict(type='GN', num_groups=32), + act_cfg=dict(type='ReLU'), + encoder=None, + positional_encoding=dict( + type='SinePositionalEncoding', + num_feats=128, + normalize=True), + init_cfg=None): + super(TransformerEncoderPixelDecoder, self).__init__( + in_channels, + feat_channels, + out_channels, + norm_cfg, + act_cfg, + init_cfg=init_cfg) + self.last_feat_conv = None + + self.encoder = build_transformer_layer_sequence(encoder) + self.encoder_embed_dims = self.encoder.embed_dims + assert self.encoder_embed_dims == feat_channels, 'embed_dims({}) of ' \ + 'tranformer encoder must equal to feat_channels({})'.format( + feat_channels, self.encoder_embed_dims) + self.positional_encoding = build_positional_encoding( + positional_encoding) + self.encoder_in_proj = Conv2d( + in_channels[-1], feat_channels, kernel_size=1) + self.encoder_out_proj = ConvModule( + feat_channels, + feat_channels, + kernel_size=3, + stride=1, + padding=1, + bias=self.use_bias, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def init_weights(self): + """Initialize weights.""" + for i in range(0, self.num_inputs - 2): + kaiming_init(self.lateral_convs[i].conv, a=1) + kaiming_init(self.output_convs[i].conv, a=1) + + kaiming_init(self.mask_feature, a=1) + kaiming_init(self.encoder_in_proj, a=1) + kaiming_init(self.encoder_out_proj.conv, a=1) + + def forward(self, feats, img_metas): + """ + Args: + feats (list[Tensor]): Feature maps of each level. Each has + shape of [bs, c, h, w]. + img_metas (list[dict]): List of image information. Pass in + for creating more accurate padding mask. + + Returns: + tuple: a tuple containing the following: + + - mask_feature (Tensor): shape [bs, c, h, w]. + - memory (Tensor): shape [bs, c, h, w]. + """ + feat_last = feats[-1] + bs, c, h, w = feat_last.shape + input_img_h, input_img_w = img_metas[0]['pad_shape'][:-1] + # input_img_h, input_img_w = img_metas[0]['batch_input_shape'] + padding_mask = feat_last.new_ones((bs, input_img_h, input_img_w), + dtype=torch.float32) + for i in range(bs): + img_h, img_w, _ = img_metas[i]['img_shape'] + padding_mask[i, :img_h, :img_w] = 0 + padding_mask = F.interpolate( + padding_mask.unsqueeze(1), + size=feat_last.shape[-2:], + mode='nearest').to(torch.bool).squeeze(1) + + pos_embed = self.positional_encoding(padding_mask) + feat_last = self.encoder_in_proj(feat_last) + # [bs, c, h, w] -> [nq, bs, dim] + feat_last = feat_last.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + padding_mask = padding_mask.flatten(1) # [bs, h, w] -> [bs, h*w] + memory = self.encoder( + query=feat_last, + key=None, + value=None, + query_pos=pos_embed, + query_key_padding_mask=padding_mask) + # [nq, bs, em] -> [bs, c, h, w] + memory = memory.permute(1, 2, 0).view(bs, self.encoder_embed_dims, h, + w) + y = self.encoder_out_proj(memory) + for i in range(self.num_inputs - 2, -1, -1): + x = feats[i] + cur_fpn = self.lateral_convs[i](x) + y = cur_fpn + \ + F.interpolate(y, size=cur_fpn.shape[-2:], mode='nearest') + y = self.output_convs[i](y) + + mask_feature = self.mask_feature(y) + return mask_feature, memory diff --git a/segmentation/mmseg_custom/models/segmentors/__init__.py b/segmentation/mmseg_custom/models/segmentors/__init__.py new file mode 100644 index 000000000..f380c9ef7 --- /dev/null +++ b/segmentation/mmseg_custom/models/segmentors/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .encoder_decoder_mask2former import EncoderDecoderMask2Former +from .encoder_decoder_mask2former_aug import EncoderDecoderMask2FormerAug + +__all__ = ['EncoderDecoderMask2Former', 'EncoderDecoderMask2FormerAug'] diff --git a/segmentation/mmseg_custom/models/segmentors/encoder_decoder_mask2former.py b/segmentation/mmseg_custom/models/segmentors/encoder_decoder_mask2former.py new file mode 100644 index 000000000..190635844 --- /dev/null +++ b/segmentation/mmseg_custom/models/segmentors/encoder_decoder_mask2former.py @@ -0,0 +1,285 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmseg.core import add_prefix +from mmseg.models import builder +from mmseg.models.builder import SEGMENTORS +from mmseg.models.segmentors.base import BaseSegmentor +from mmseg.ops import resize + + +@SEGMENTORS.register_module() +class EncoderDecoderMask2Former(BaseSegmentor): + """Encoder Decoder segmentors. + + EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. + """ + def __init__(self, + backbone, + decode_head, + neck=None, + auxiliary_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None): + super(EncoderDecoderMask2Former, self).__init__(init_cfg) + if pretrained is not None: + assert backbone.get('pretrained') is None, \ + 'both backbone and segmentor set pretrained weight' + backbone.pretrained = pretrained + self.backbone = builder.build_backbone(backbone) + if neck is not None: + self.neck = builder.build_neck(neck) + decode_head.update(train_cfg=train_cfg) + decode_head.update(test_cfg=test_cfg) + self._init_decode_head(decode_head) + self._init_auxiliary_head(auxiliary_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + assert self.with_decode_head + + def _init_decode_head(self, decode_head): + """Initialize ``decode_head``""" + self.decode_head = builder.build_head(decode_head) + self.align_corners = self.decode_head.align_corners + self.num_classes = self.decode_head.num_classes + + def _init_auxiliary_head(self, auxiliary_head): + """Initialize ``auxiliary_head``""" + if auxiliary_head is not None: + if isinstance(auxiliary_head, list): + self.auxiliary_head = nn.ModuleList() + for head_cfg in auxiliary_head: + self.auxiliary_head.append(builder.build_head(head_cfg)) + else: + self.auxiliary_head = builder.build_head(auxiliary_head) + + def extract_feat(self, img): + """Extract features from images.""" + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def encode_decode(self, img, img_metas): + """Encode images with backbone and decode into a semantic segmentation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + out = resize( + input=out, + size=img.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg, + **kwargs): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = self.decode_head.forward_train(x, img_metas, + gt_semantic_seg, **kwargs) + + losses.update(add_prefix(loss_decode, 'decode')) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + seg_logits = 
self.decode_head.forward_test(x, img_metas, self.test_cfg) + return seg_logits + + def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg): + """Run forward function and calculate loss for auxiliary head in + training.""" + losses = dict() + if isinstance(self.auxiliary_head, nn.ModuleList): + for idx, aux_head in enumerate(self.auxiliary_head): + loss_aux = aux_head.forward_train(x, img_metas, + gt_semantic_seg, + self.train_cfg) + losses.update(add_prefix(loss_aux, f'aux_{idx}')) + else: + loss_aux = self.auxiliary_head.forward_train( + x, img_metas, gt_semantic_seg, self.train_cfg) + losses.update(add_prefix(loss_aux, 'aux')) + + return losses + + def forward_dummy(self, img): + """Dummy forward function.""" + seg_logit = self.encode_decode(img, None) + + return seg_logit + + def forward_train(self, img, img_metas, gt_semantic_seg, **kwargs): + """Forward function for training. + + Args: + img (Tensor): Input images. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + loss_decode = self._decode_head_forward_train(x, img_metas, + gt_semantic_seg, + **kwargs) + losses.update(loss_decode) + + if self.with_auxiliary_head: + loss_aux = self._auxiliary_head_forward_train( + x, img_metas, gt_semantic_seg) + losses.update(loss_aux) + + return losses + + # TODO refactor + def slide_inference(self, img, img_meta, rescale): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. 
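+
+        For example (illustrative numbers only), with ``h_img=1024``,
+        ``h_crop=512`` and ``h_stride=341`` there are
+        ``max(1024 - 512 + 341 - 1, 0) // 341 + 1 = 3`` window positions
+        along the vertical axis; logits of overlapping windows are summed
+        and normalized by ``count_mat``.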
+ """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = img.size() + num_classes = self.num_classes + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, num_classes, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + crop_seg_logit = self.encode_decode(crop_img, img_meta) + preds += F.pad(crop_seg_logit, + (int(x1), int(preds.shape[3] - x2), int(y1), + int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy( + count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + if rescale: + preds = resize( + preds, + size=img_meta[0]['ori_shape'][:2], + mode='bilinear', + align_corners=self.align_corners, + warning=False) + return preds + + def whole_inference(self, img, img_meta, rescale): + """Inference with full image.""" + + seg_logit = self.encode_decode(img, img_meta) + if rescale: + # support dynamic shape for onnx + if torch.onnx.is_in_onnx_export(): + size = img.shape[2:] + else: + size = img_meta[0]['ori_shape'][:2] + seg_logit = resize( + seg_logit, + size=size, + mode='bilinear', + align_corners=self.align_corners, + warning=False) + + return seg_logit + + def inference(self, img, img_meta, rescale): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output segmentation map. + """ + + assert self.test_cfg.mode in ['slide', 'whole'] + ori_shape = img_meta[0]['ori_shape'] + assert all(_['ori_shape'] == ori_shape for _ in img_meta) + if self.test_cfg.mode == 'slide': + seg_logit = self.slide_inference(img, img_meta, rescale) + else: + seg_logit = self.whole_inference(img, img_meta, rescale) + output = F.softmax(seg_logit, dim=1) + flip = img_meta[0]['flip'] + if flip: + flip_direction = img_meta[0]['flip_direction'] + assert flip_direction in ['horizontal', 'vertical'] + if flip_direction == 'horizontal': + output = output.flip(dims=(3,)) + elif flip_direction == 'vertical': + output = output.flip(dims=(2,)) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + seg_logit = self.inference(img, img_meta, rescale) + seg_pred = seg_logit.argmax(dim=1) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + seg_pred = seg_pred.unsqueeze(0) + return seg_pred + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. 
+ """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented seg logit inplace + seg_logit = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale) + seg_logit += cur_seg_logit + seg_logit /= len(imgs) + seg_pred = seg_logit.argmax(dim=1) + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred diff --git a/segmentation/mmseg_custom/models/segmentors/encoder_decoder_mask2former_aug.py b/segmentation/mmseg_custom/models/segmentors/encoder_decoder_mask2former_aug.py new file mode 100644 index 000000000..7dc54cf7e --- /dev/null +++ b/segmentation/mmseg_custom/models/segmentors/encoder_decoder_mask2former_aug.py @@ -0,0 +1,289 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmseg.core import add_prefix +from mmseg.models import builder +from mmseg.models.builder import SEGMENTORS +from mmseg.models.segmentors.base import BaseSegmentor +from mmseg.ops import resize + + +@SEGMENTORS.register_module() +class EncoderDecoderMask2FormerAug(BaseSegmentor): + """Encoder Decoder segmentors. + + EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. + Note that auxiliary_head is only used for deep supervision during training, + which could be dumped during inference. + """ + def __init__(self, + backbone, + decode_head, + neck=None, + auxiliary_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None): + super(EncoderDecoderMask2FormerAug, self).__init__(init_cfg) + if pretrained is not None: + assert backbone.get('pretrained') is None, \ + 'both backbone and segmentor set pretrained weight' + backbone.pretrained = pretrained + self.backbone = builder.build_backbone(backbone) + if neck is not None: + self.neck = builder.build_neck(neck) + decode_head.update(train_cfg=train_cfg) + decode_head.update(test_cfg=test_cfg) + self._init_decode_head(decode_head) + self._init_auxiliary_head(auxiliary_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + assert self.with_decode_head + + def _init_decode_head(self, decode_head): + """Initialize ``decode_head``""" + self.decode_head = builder.build_head(decode_head) + self.align_corners = self.decode_head.align_corners + self.num_classes = self.decode_head.num_classes + + def _init_auxiliary_head(self, auxiliary_head): + """Initialize ``auxiliary_head``""" + if auxiliary_head is not None: + if isinstance(auxiliary_head, list): + self.auxiliary_head = nn.ModuleList() + for head_cfg in auxiliary_head: + self.auxiliary_head.append(builder.build_head(head_cfg)) + else: + self.auxiliary_head = builder.build_head(auxiliary_head) + + def extract_feat(self, img): + """Extract features from images.""" + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def encode_decode(self, img, img_metas): + """Encode images with backbone and decode into a semantic segmentation + map of the same size as input.""" + x = self.extract_feat(img) + out = self._decode_head_forward_test(x, img_metas) + out = resize( + input=out, + size=img.shape[2:], + mode='bilinear', + align_corners=self.align_corners) + return out + + def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg, + **kwargs): + """Run forward function and calculate loss for decode head in + training.""" + losses = dict() + loss_decode = 
self.decode_head.forward_train(x, img_metas, + gt_semantic_seg, **kwargs) + + losses.update(add_prefix(loss_decode, 'decode')) + return losses + + def _decode_head_forward_test(self, x, img_metas): + """Run forward function and calculate loss for decode head in + inference.""" + seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg) + return seg_logits + + def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg): + """Run forward function and calculate loss for auxiliary head in + training.""" + losses = dict() + if isinstance(self.auxiliary_head, nn.ModuleList): + for idx, aux_head in enumerate(self.auxiliary_head): + loss_aux = aux_head.forward_train(x, img_metas, + gt_semantic_seg, + self.train_cfg) + losses.update(add_prefix(loss_aux, f'aux_{idx}')) + else: + loss_aux = self.auxiliary_head.forward_train( + x, img_metas, gt_semantic_seg, self.train_cfg) + losses.update(add_prefix(loss_aux, 'aux')) + + return losses + + def forward_dummy(self, img): + """Dummy forward function.""" + seg_logit = self.encode_decode(img, None) + + return seg_logit + + def forward_train(self, img, img_metas, gt_semantic_seg, **kwargs): + """Forward function for training. + + Args: + img (Tensor): Input images. + img_metas (list[dict]): List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + gt_semantic_seg (Tensor): Semantic segmentation masks + used if the architecture supports semantic segmentation task. + + Returns: + dict[str, Tensor]: a dictionary of loss components + """ + + x = self.extract_feat(img) + + losses = dict() + + loss_decode = self._decode_head_forward_train(x, img_metas, + gt_semantic_seg, + **kwargs) + losses.update(loss_decode) + + if self.with_auxiliary_head: + loss_aux = self._auxiliary_head_forward_train( + x, img_metas, gt_semantic_seg) + losses.update(loss_aux) + + return losses + + # TODO refactor + def slide_inference(self, img, img_meta, rescale, unpad=True): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. 
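+
+        Unlike ``EncoderDecoderMask2Former.slide_inference``, this variant
+        also crops the stitched logits back to ``img_meta[0]['img_shape']``
+        (``unpad=True``) before the optional rescale to the original shape.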
+ """ + + h_stride, w_stride = self.test_cfg.stride + h_crop, w_crop = self.test_cfg.crop_size + batch_size, _, h_img, w_img = img.size() + num_classes = self.num_classes + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = img.new_zeros((batch_size, num_classes, h_img, w_img)) + count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = img[:, :, y1:y2, x1:x2] + crop_seg_logit = self.encode_decode(crop_img, img_meta) + preds += F.pad(crop_seg_logit, + (int(x1), int(preds.shape[3] - x2), int(y1), + int(preds.shape[2] - y2))) + + count_mat[:, :, y1:y2, x1:x2] += 1 + assert (count_mat == 0).sum() == 0 + if torch.onnx.is_in_onnx_export(): + # cast count_mat to constant while exporting to ONNX + count_mat = torch.from_numpy( + count_mat.cpu().detach().numpy()).to(device=img.device) + preds = preds / count_mat + + if unpad: + unpad_h, unpad_w = img_meta[0]['img_shape'][:2] + # logging.info(preds.shape, img_meta[0]) + preds = preds[:, :, :unpad_h, :unpad_w] + if rescale: + preds = resize(preds, + size=img_meta[0]['ori_shape'][:2], + mode='bilinear', + align_corners=self.align_corners, + warning=False) + return preds + + def whole_inference(self, img, img_meta, rescale): + """Inference with full image.""" + + seg_logit = self.encode_decode(img, img_meta) + if rescale: + # support dynamic shape for onnx + if torch.onnx.is_in_onnx_export(): + size = img.shape[2:] + else: + size = img_meta[0]['ori_shape'][:2] + seg_logit = resize( + seg_logit, + size=size, + mode='bilinear', + align_corners=self.align_corners, + warning=False) + + return seg_logit + + def inference(self, img, img_meta, rescale): + """Inference with slide/whole style. + + Args: + img (Tensor): The input image of shape (N, 3, H, W). + img_meta (dict): Image info dict where each dict has: 'img_shape', + 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:Collect`. + rescale (bool): Whether rescale back to original shape. + + Returns: + Tensor: The output segmentation map. 
+ """ + + assert self.test_cfg.mode in ['slide', 'whole'] + ori_shape = img_meta[0]['ori_shape'] + assert all(_['ori_shape'] == ori_shape for _ in img_meta) + if self.test_cfg.mode == 'slide': + seg_logit = self.slide_inference(img, img_meta, rescale) + else: + seg_logit = self.whole_inference(img, img_meta, rescale) + output = F.softmax(seg_logit, dim=1) + flip = img_meta[0]['flip'] + if flip: + flip_direction = img_meta[0]['flip_direction'] + assert flip_direction in ['horizontal', 'vertical'] + if flip_direction == 'horizontal': + output = output.flip(dims=(3, )) + elif flip_direction == 'vertical': + output = output.flip(dims=(2, )) + + return output + + def simple_test(self, img, img_meta, rescale=True): + """Simple test with single image.""" + seg_logit = self.inference(img, img_meta, rescale) + seg_pred = seg_logit.argmax(dim=1) + if torch.onnx.is_in_onnx_export(): + # our inference backend only support 4D output + seg_pred = seg_pred.unsqueeze(0) + return seg_pred + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred + + def aug_test(self, imgs, img_metas, rescale=True): + """Test with augmentations. + + Only rescale=True is supported. + """ + # aug_test rescale all imgs back to ori_shape for now + assert rescale + # to save memory, we get augmented seg logit inplace + seg_logit = self.inference(imgs[0], img_metas[0], rescale) + for i in range(1, len(imgs)): + cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale) + seg_logit += cur_seg_logit + seg_logit /= len(imgs) + seg_pred = seg_logit.argmax(dim=1) + seg_pred = seg_pred.cpu().numpy() + # unravel batch dim + seg_pred = list(seg_pred) + return seg_pred diff --git a/segmentation/mmseg_custom/models/utils/__init__.py b/segmentation/mmseg_custom/models/utils/__init__.py new file mode 100644 index 000000000..bef9124b4 --- /dev/null +++ b/segmentation/mmseg_custom/models/utils/__init__.py @@ -0,0 +1,12 @@ +from .assigner import MaskHungarianAssigner +from .point_sample import get_uncertain_point_coords_with_randomness +from .positional_encoding import (LearnedPositionalEncoding, + SinePositionalEncoding) +from .transformer import (DetrTransformerDecoder, DetrTransformerDecoderLayer, + DynamicConv, Transformer) + +__all__ = [ + 'DetrTransformerDecoderLayer', 'DetrTransformerDecoder', 'DynamicConv', + 'Transformer', 'LearnedPositionalEncoding', 'SinePositionalEncoding', + 'MaskHungarianAssigner', 'get_uncertain_point_coords_with_randomness' +] diff --git a/segmentation/mmseg_custom/models/utils/assigner.py b/segmentation/mmseg_custom/models/utils/assigner.py new file mode 100644 index 000000000..1d6028940 --- /dev/null +++ b/segmentation/mmseg_custom/models/utils/assigner.py @@ -0,0 +1,165 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from abc import ABCMeta, abstractmethod + +import torch +import torch.nn.functional as F + +from ..builder import MASK_ASSIGNERS, build_match_cost + +try: + from scipy.optimize import linear_sum_assignment +except ImportError: + linear_sum_assignment = None + + +class AssignResult(metaclass=ABCMeta): + """Collection of assign results.""" + def __init__(self, num_gts, gt_inds, labels): + self.num_gts = num_gts + self.gt_inds = gt_inds + self.labels = labels + + @property + def info(self): + info = { + 'num_gts': self.num_gts, + 'gt_inds': self.gt_inds, + 'labels': self.labels, + } + return info + + +class BaseAssigner(metaclass=ABCMeta): + """Base assigner that assigns boxes to ground truth boxes.""" + @abstractmethod + def assign(self, masks, gt_masks, gt_masks_ignore=None, gt_labels=None): + """Assign boxes to either a ground truth boxes or a negative boxes.""" + pass + + +@MASK_ASSIGNERS.register_module() +class MaskHungarianAssigner(BaseAssigner): + """Computes one-to-one matching between predictions and ground truth for + mask. + + This class computes an assignment between the targets and the predictions + based on the costs. The costs are weighted sum of three components: + classification cost, regression L1 cost and regression iou cost. The + targets don't include the no_object, so generally there are more + predictions than targets. After the one-to-one matching, the un-matched + are treated as backgrounds. Thus each query prediction will be assigned + with `0` or a positive integer indicating the ground truth index: + + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + + Args: + cls_cost (obj:`mmcv.ConfigDict`|dict): Classification cost config. + mask_cost (obj:`mmcv.ConfigDict`|dict): Mask cost config. + dice_cost (obj:`mmcv.ConfigDict`|dict): Dice cost config. + """ + def __init__(self, + cls_cost=dict(type='ClassificationCost', weight=1.0), + dice_cost=dict(type='DiceCost', weight=1.0), + mask_cost=dict(type='MaskFocalCost', weight=1.0)): + self.cls_cost = build_match_cost(cls_cost) + self.dice_cost = build_match_cost(dice_cost) + self.mask_cost = build_match_cost(mask_cost) + + def assign(self, + cls_pred, + mask_pred, + gt_labels, + gt_masks, + img_meta, + gt_masks_ignore=None, + eps=1e-7): + """Computes one-to-one matching based on the weighted costs. + + This method assign each query prediction to a ground truth or + background. The `assigned_gt_inds` with -1 means don't care, + 0 means negative sample, and positive number is the index (1-based) + of assigned gt. + The assignment is done in the following steps, the order matters. + + 1. assign every prediction to -1 + 2. compute the weighted costs + 3. do Hungarian matching on CPU based on the costs + 4. assign all to 0 (background) first, then for each matched pair + between predictions and gts, treat this prediction as foreground + and assign the corresponding gt index (plus 1) to it. + + Args: + mask_pred (Tensor): Predicted mask, shape [num_query, h, w] + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + gt_masks (Tensor): Ground truth mask, shape [num_gt, h, w]. + gt_labels (Tensor): Label of `gt_masks`, shape (num_gt,). + img_meta (dict): Meta information for current image. + gt_masks_ignore (Tensor, optional): Ground truth masks that are + labelled as `ignored`. Default None. + eps (int | float, optional): A value added to the denominator for + numerical stability. Default 1e-7. 
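+
+        Note:
+            The matching cost is the sum ``cls_cost + mask_cost + dice_cost``
+            of the three weighted components; a component whose weight is 0
+            is not computed.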
+ + Returns: + :obj:`AssignResult`: The assigned result. + """ + assert gt_masks_ignore is None, \ + 'Only case when gt_masks_ignore is None is supported.' + num_gts, num_queries = gt_labels.shape[0], cls_pred.shape[0] + + # 1. assign -1 by default + assigned_gt_inds = cls_pred.new_full((num_queries, ), + -1, + dtype=torch.long) + assigned_labels = cls_pred.new_full((num_queries, ), + -1, + dtype=torch.long) + if num_gts == 0 or num_queries == 0: + # No ground truth or boxes, return empty assignment + if num_gts == 0: + # No ground truth, assign all to background + assigned_gt_inds[:] = 0 + return AssignResult( + num_gts, assigned_gt_inds, labels=assigned_labels) + + # 2. compute the weighted costs + # classification and maskcost. + if self.cls_cost.weight != 0 and cls_pred is not None: + cls_cost = self.cls_cost(cls_pred, gt_labels) + else: + cls_cost = 0 + + if self.mask_cost.weight != 0: + # mask_pred shape = [nq, h, w] + # gt_mask shape = [ng, h, w] + # mask_cost shape = [nq, ng] + mask_cost = self.mask_cost(mask_pred, gt_masks) + else: + mask_cost = 0 + + if self.dice_cost.weight != 0: + dice_cost = self.dice_cost(mask_pred, gt_masks) + else: + dice_cost = 0 + cost = cls_cost + mask_cost + dice_cost + + # 3. do Hungarian matching on CPU using linear_sum_assignment + cost = cost.detach().cpu() + if linear_sum_assignment is None: + raise ImportError('Please run "pip install scipy" ' + 'to install scipy first.') + + matched_row_inds, matched_col_inds = linear_sum_assignment(cost) + matched_row_inds = torch.from_numpy(matched_row_inds).to( + cls_pred.device) + matched_col_inds = torch.from_numpy(matched_col_inds).to( + cls_pred.device) + + # 4. assign backgrounds and foregrounds + # assign all indices to backgrounds first + assigned_gt_inds[:] = 0 + # assign foregrounds based on matching results + assigned_gt_inds[matched_row_inds] = matched_col_inds + 1 + assigned_labels[matched_row_inds] = gt_labels[matched_col_inds] + return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels) diff --git a/segmentation/mmseg_custom/models/utils/point_sample.py b/segmentation/mmseg_custom/models/utils/point_sample.py new file mode 100644 index 000000000..ac4b2daf7 --- /dev/null +++ b/segmentation/mmseg_custom/models/utils/point_sample.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcv.ops import point_sample + + +def get_uncertainty(mask_pred, labels): + """Estimate uncertainty based on pred logits. + + We estimate uncertainty as L1 distance between 0.0 and the logits + prediction in 'mask_pred' for the foreground class in `classes`. + + Args: + mask_pred (Tensor): mask predication logits, shape (num_rois, + num_classes, mask_height, mask_width). + + labels (list[Tensor]): Either predicted or ground truth label for + each predicted mask, of length num_rois. + + Returns: + scores (Tensor): Uncertainty scores with the most uncertain + locations having the highest uncertainty score, + shape (num_rois, 1, mask_height, mask_width) + """ + if mask_pred.shape[1] == 1: + gt_class_logits = mask_pred.clone() + else: + inds = torch.arange(mask_pred.shape[0], device=mask_pred.device) + gt_class_logits = mask_pred[inds, labels].unsqueeze(1) + return -torch.abs(gt_class_logits) + + +def get_uncertain_point_coords_with_randomness(mask_pred, labels, num_points, + oversample_ratio, + importance_sample_ratio): + """Get ``num_points`` most uncertain points with random points during + train. 
+ + Sample points in [0, 1] x [0, 1] coordinate space based on their + uncertainty. The uncertainties are calculated for each point using + 'get_uncertainty()' function that takes point's logit prediction as + input. + + Args: + mask_pred (Tensor): A tensor of shape (num_rois, num_classes, + mask_height, mask_width) for class-specific or class-agnostic + prediction. + labels (list): The ground truth class for each instance. + num_points (int): The number of points to sample. + oversample_ratio (int): Oversampling parameter. + importance_sample_ratio (float): Ratio of points that are sampled + via importnace sampling. + + Returns: + point_coords (Tensor): A tensor of shape (num_rois, num_points, 2) + that contains the coordinates sampled points. + """ + assert oversample_ratio >= 1 + assert 0 <= importance_sample_ratio <= 1 + batch_size = mask_pred.shape[0] + num_sampled = int(num_points * oversample_ratio) + point_coords = torch.rand( + batch_size, num_sampled, 2, device=mask_pred.device) + point_logits = point_sample(mask_pred, point_coords) + # It is crucial to calculate uncertainty based on the sampled + # prediction value for the points. Calculating uncertainties of the + # coarse predictions first and sampling them for points leads to + # incorrect results. To illustrate this: assume uncertainty func( + # logits)=-abs(logits), a sampled point between two coarse + # predictions with -1 and 1 logits has 0 logits, and therefore 0 + # uncertainty value. However, if we calculate uncertainties for the + # coarse predictions first, both will have -1 uncertainty, + # and sampled point will get -1 uncertainty. + point_uncertainties = get_uncertainty(point_logits, labels) + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + idx = torch.topk( + point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_sampled * torch.arange( + batch_size, dtype=torch.long, device=mask_pred.device) + idx += shift[:, None] + point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view( + batch_size, num_uncertain_points, 2) + if num_random_points > 0: + rand_roi_coords = torch.rand( + batch_size, num_random_points, 2, device=mask_pred.device) + point_coords = torch.cat((point_coords, rand_roi_coords), dim=1) + return point_coords \ No newline at end of file diff --git a/segmentation/mmseg_custom/models/utils/positional_encoding.py b/segmentation/mmseg_custom/models/utils/positional_encoding.py new file mode 100644 index 000000000..426e05836 --- /dev/null +++ b/segmentation/mmseg_custom/models/utils/positional_encoding.py @@ -0,0 +1,161 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.nn as nn +from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING +from mmcv.runner import BaseModule + + +@POSITIONAL_ENCODING.register_module() +class SinePositionalEncoding(BaseModule): + """Position encoding with sine and cosine functions. + + See `End-to-End Object Detection with Transformers + `_ for details. + + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. Note the final returned dimension + for each position is 2 times of this value. + temperature (int, optional): The temperature used for scaling + the position embedding. Defaults to 10000. + normalize (bool, optional): Whether to normalize the position + embedding. Defaults to False. + scale (float, optional): A scale factor that scales the position + embedding. 
The scale will be used only when `normalize` is True. + Defaults to 2*pi. + eps (float, optional): A value added to the denominator for + numerical stability. Defaults to 1e-6. + offset (float): offset add to embed when do the normalization. + Defaults to 0. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + def __init__(self, + num_feats, + temperature=10000, + normalize=False, + scale=2 * math.pi, + eps=1e-6, + offset=0., + init_cfg=None): + super(SinePositionalEncoding, self).__init__(init_cfg) + if normalize: + assert isinstance(scale, (float, int)), 'when normalize is set,' \ + 'scale should be provided and in float or int type, ' \ + f'found {type(scale)}' + self.num_feats = num_feats + self.temperature = temperature + self.normalize = normalize + self.scale = scale + self.eps = eps + self.offset = offset + + def forward(self, mask): + """Forward function for `SinePositionalEncoding`. + + Args: + mask (Tensor): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, h, w]. + + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. + """ + # For convenience of exporting to ONNX, it's required to convert + # `masks` from bool to int. + mask = mask.to(torch.int) + not_mask = 1 - mask # logical_not + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + y_embed = (y_embed + self.offset) / \ + (y_embed[:, -1:, :] + self.eps) * self.scale + x_embed = (x_embed + self.offset) / \ + (x_embed[:, :, -1:] + self.eps) * self.scale + dim_t = torch.arange( + self.num_feats, dtype=torch.float32, device=mask.device) + dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats) + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + # use `view` instead of `flatten` for dynamically exporting to ONNX + B, H, W = mask.size() + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), + dim=4).view(B, H, W, -1) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), + dim=4).view(B, H, W, -1) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f'(num_feats={self.num_feats}, ' + repr_str += f'temperature={self.temperature}, ' + repr_str += f'normalize={self.normalize}, ' + repr_str += f'scale={self.scale}, ' + repr_str += f'eps={self.eps})' + return repr_str + + +@POSITIONAL_ENCODING.register_module() +class LearnedPositionalEncoding(BaseModule): + """Position embedding with learnable embedding weights. + + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. The final returned dimension for + each position is 2 times of this value. + row_num_embed (int, optional): The dictionary size of row embeddings. + Default 50. + col_num_embed (int, optional): The dictionary size of col embeddings. + Default 50. + init_cfg (dict or list[dict], optional): Initialization config dict. 
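+
+    Example:
+        A minimal sketch (sizes are arbitrary, for illustration only) of the
+        returned embedding shape:
+
+        >>> import torch
+        >>> pe = LearnedPositionalEncoding(num_feats=128)
+        >>> mask = torch.zeros(1, 32, 32)
+        >>> pos = pe(mask)
+        >>> assert pos.shape == (1, 256, 32, 32)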
+ """ + def __init__(self, + num_feats, + row_num_embed=50, + col_num_embed=50, + init_cfg=dict(type='Uniform', layer='Embedding')): + super(LearnedPositionalEncoding, self).__init__(init_cfg) + self.row_embed = nn.Embedding(row_num_embed, num_feats) + self.col_embed = nn.Embedding(col_num_embed, num_feats) + self.num_feats = num_feats + self.row_num_embed = row_num_embed + self.col_num_embed = col_num_embed + + def forward(self, mask): + """Forward function for `LearnedPositionalEncoding`. + + Args: + mask (Tensor): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, h, w]. + + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. + """ + h, w = mask.shape[-2:] + x = torch.arange(w, device=mask.device) + y = torch.arange(h, device=mask.device) + x_embed = self.col_embed(x) + y_embed = self.row_embed(y) + pos = torch.cat( + (x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat( + 1, w, 1)), + dim=-1).permute(2, 0, + 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1) + return pos + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f'(num_feats={self.num_feats}, ' + repr_str += f'row_num_embed={self.row_num_embed}, ' + repr_str += f'col_num_embed={self.col_num_embed})' + return repr_str diff --git a/segmentation/mmseg_custom/models/utils/transformer.py b/segmentation/mmseg_custom/models/utils/transformer.py new file mode 100644 index 000000000..80d8a31f5 --- /dev/null +++ b/segmentation/mmseg_custom/models/utils/transformer.py @@ -0,0 +1,997 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +import warnings +from typing import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import (build_activation_layer, build_conv_layer, + build_norm_layer, xavier_init) +from mmcv.cnn.bricks.registry import (TRANSFORMER_LAYER, + TRANSFORMER_LAYER_SEQUENCE) +from mmcv.cnn.bricks.transformer import (BaseTransformerLayer, + TransformerLayerSequence, + build_transformer_layer_sequence) +from mmcv.runner.base_module import BaseModule +from mmcv.utils import to_2tuple +from torch.nn.init import normal_ + +from ..builder import TRANSFORMER + +try: + from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention + +except ImportError: + warnings.warn( + '`MultiScaleDeformableAttention` in MMCV has been moved to ' + '`mmcv.ops.multi_scale_deform_attn`, please update your MMCV') + from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention + + +class AdaptivePadding(nn.Module): + """Applies padding to input (if needed) so that input can get fully covered + by filter you specified. It support two modes "same" and "corner". The + "same" mode is same with "SAME" padding mode in TensorFlow, pad zero around + input. The "corner" mode would pad zero to bottom right. + + Args: + kernel_size (int | tuple): Size of the kernel: + stride (int | tuple): Stride of the filter. Default: 1: + dilation (int | tuple): Spacing between kernel elements. + Default: 1 + padding (str): Support "same" and "corner", "corner" mode + would pad zero to bottom right, and "same" mode would + pad zero around input. Default: "corner". 
+ Example: + >>> kernel_size = 16 + >>> stride = 16 + >>> dilation = 1 + >>> input = torch.rand(1, 1, 15, 17) + >>> adap_pad = AdaptivePadding( + >>> kernel_size=kernel_size, + >>> stride=stride, + >>> dilation=dilation, + >>> padding="corner") + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + >>> input = torch.rand(1, 1, 16, 17) + >>> out = adap_pad(input) + >>> assert (out.shape[2], out.shape[3]) == (16, 32) + """ + def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'): + + super(AdaptivePadding, self).__init__() + + assert padding in ('same', 'corner') + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + padding = to_2tuple(padding) + dilation = to_2tuple(dilation) + + self.padding = padding + self.kernel_size = kernel_size + self.stride = stride + self.dilation = dilation + + def get_pad_shape(self, input_shape): + input_h, input_w = input_shape + kernel_h, kernel_w = self.kernel_size + stride_h, stride_w = self.stride + output_h = math.ceil(input_h / stride_h) + output_w = math.ceil(input_w / stride_w) + pad_h = max((output_h - 1) * stride_h + + (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0) + pad_w = max((output_w - 1) * stride_w + + (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0) + return pad_h, pad_w + + def forward(self, x): + pad_h, pad_w = self.get_pad_shape(x.size()[-2:]) + if pad_h > 0 or pad_w > 0: + if self.padding == 'corner': + x = F.pad(x, [0, pad_w, 0, pad_h]) + elif self.padding == 'same': + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, + pad_h - pad_h // 2 + ]) + return x + + +class PatchMerging(BaseModule): + """Merge patch feature map. + + This layer groups feature map by kernel_size, and applies norm and linear + layers to the grouped feature map. Our implementation uses `nn.Unfold` to + merge patch, which is about 25% faster than original implementation. + Instead, we need to modify pretrained models for compatibility. + + Args: + in_channels (int): The num of input channels. + to gets fully covered by filter and stride you specified.. + Default: True. + out_channels (int): The num of output channels. + kernel_size (int | tuple, optional): the kernel size in the unfold + layer. Defaults to 2. + stride (int | tuple, optional): the stride of the sliding blocks in the + unfold layer. Default: None. (Would be set as `kernel_size`) + padding (int | tuple | string ): The padding length of + embedding conv. When it is a string, it means the mode + of adaptive padding, support "same" and "corner" now. + Default: "corner". + dilation (int | tuple, optional): dilation parameter in the unfold + layer. Default: 1. + bias (bool, optional): Whether to add bias in linear layer or not. + Defaults: False. + norm_cfg (dict, optional): Config dict for normalization layer. + Default: dict(type='LN'). + init_cfg (dict, optional): The extra config for initialization. + Default: None. 
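+
+    Example:
+        An illustrative sketch (channel and spatial sizes are arbitrary) of
+        merging a 56x56 token map into a 28x28 one:
+
+        >>> import torch
+        >>> merge = PatchMerging(in_channels=96, out_channels=192)
+        >>> x = torch.rand(2, 56 * 56, 96)
+        >>> out, out_size = merge(x, (56, 56))
+        >>> assert out.shape == (2, 28 * 28, 192)
+        >>> assert out_size == (28, 28)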
+ """ + def __init__(self, + in_channels, + out_channels, + kernel_size=2, + stride=None, + padding='corner', + dilation=1, + bias=False, + norm_cfg=dict(type='LN'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + if stride: + stride = stride + else: + stride = kernel_size + + kernel_size = to_2tuple(kernel_size) + stride = to_2tuple(stride) + dilation = to_2tuple(dilation) + + if isinstance(padding, str): + self.adap_padding = AdaptivePadding( + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + padding=padding) + # disable the padding of unfold + padding = 0 + else: + self.adap_padding = None + + padding = to_2tuple(padding) + self.sampler = nn.Unfold( + kernel_size=kernel_size, + dilation=dilation, + padding=padding, + stride=stride) + + sample_dim = kernel_size[0] * kernel_size[1] * in_channels + + if norm_cfg is not None: + self.norm = build_norm_layer(norm_cfg, sample_dim)[1] + else: + self.norm = None + + self.reduction = nn.Linear(sample_dim, out_channels, bias=bias) + + def forward(self, x, input_size): + """ + Args: + x (Tensor): Has shape (B, H*W, C_in). + input_size (tuple[int]): The spatial shape of x, arrange as (H, W). + Default: None. + + Returns: + tuple: Contains merged results and its spatial shape. + + - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out) + - out_size (tuple[int]): Spatial shape of x, arrange as + (Merged_H, Merged_W). + """ + B, L, C = x.shape + assert isinstance(input_size, Sequence), f'Expect ' \ + f'input_size is ' \ + f'`Sequence` ' \ + f'but get {input_size}' + + H, W = input_size + assert L == H * W, 'input feature has wrong size' + + x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W + # Use nn.Unfold to merge patch. About 25% faster than original method, + # but need to modify pretrained model for compatibility + + if self.adap_padding: + x = self.adap_padding(x) + H, W = x.shape[-2:] + + x = self.sampler(x) + # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2) + + out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * + (self.sampler.kernel_size[0] - 1) - + 1) // self.sampler.stride[0] + 1 + out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * + (self.sampler.kernel_size[1] - 1) - + 1) // self.sampler.stride[1] + 1 + + output_size = (out_h, out_w) + x = x.transpose(1, 2) # B, H/2*W/2, 4*C + x = self.norm(x) if self.norm else x + x = self.reduction(x) + return x, output_size + + +def inverse_sigmoid(x, eps=1e-5): + """Inverse function of sigmoid. + + Args: + x (Tensor): The tensor to do the + inverse. + eps (float): EPS avoid numerical + overflow. Defaults 1e-5. + Returns: + Tensor: The x has passed the inverse + function of sigmoid, has same + shape with input. + """ + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +@TRANSFORMER_LAYER.register_module() +class DetrTransformerDecoderLayer(BaseTransformerLayer): + """Implements decoder layer in DETR transformer. + + Args: + attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )): + Configs for self_attention or cross_attention, the order + should be consistent with it in `operation_order`. If it is + a dict, it would be expand to the number of attention in + `operation_order`. + feedforward_channels (int): The hidden dimension for FFNs. + ffn_dropout (float): Probability of an element to be zeroed + in ffn. Default 0.0. 
+ operation_order (tuple[str]): The execution order of operation + in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). + Default:None + act_cfg (dict): The activation config for FFNs. Default: `LN` + norm_cfg (dict): Config dict for normalization layer. + Default: `LN`. + ffn_num_fcs (int): The number of fully-connected layers in FFNs. + Default:2. + """ + def __init__(self, + attn_cfgs, + feedforward_channels, + ffn_dropout=0.0, + operation_order=None, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'), + ffn_num_fcs=2, + **kwargs): + super(DetrTransformerDecoderLayer, self).__init__( + attn_cfgs=attn_cfgs, + feedforward_channels=feedforward_channels, + ffn_dropout=ffn_dropout, + operation_order=operation_order, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + ffn_num_fcs=ffn_num_fcs, + **kwargs) + assert len(operation_order) == 6 + assert set(operation_order) == set( + ['self_attn', 'norm', 'cross_attn', 'ffn']) + + +@TRANSFORMER_LAYER_SEQUENCE.register_module() +class DetrTransformerEncoder(TransformerLayerSequence): + """TransformerEncoder of DETR. + + Args: + post_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. Only used when `self.pre_norm` is `True` + """ + def __init__(self, *args, post_norm_cfg=dict(type='LN'), **kwargs): + super(DetrTransformerEncoder, self).__init__(*args, **kwargs) + if post_norm_cfg is not None: + self.post_norm = build_norm_layer( + post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None + else: + assert not self.pre_norm, f'Use prenorm in ' \ + f'{self.__class__.__name__},' \ + f'Please specify post_norm_cfg' + self.post_norm = None + + def forward(self, *args, **kwargs): + """Forward function for `TransformerCoder`. + + Returns: + Tensor: forwarded results with shape [num_query, bs, embed_dims]. + """ + x = super(DetrTransformerEncoder, self).forward(*args, **kwargs) + if self.post_norm is not None: + x = self.post_norm(x) + return x + + +@TRANSFORMER_LAYER_SEQUENCE.register_module() +class DetrTransformerDecoder(TransformerLayerSequence): + """Implements the decoder in DETR transformer. + + Args: + return_intermediate (bool): Whether to return intermediate outputs. + post_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. + """ + def __init__(self, + *args, + post_norm_cfg=dict(type='LN'), + return_intermediate=False, + **kwargs): + + super(DetrTransformerDecoder, self).__init__(*args, **kwargs) + self.return_intermediate = return_intermediate + if post_norm_cfg is not None: + self.post_norm = build_norm_layer(post_norm_cfg, + self.embed_dims)[1] + else: + self.post_norm = None + + def forward(self, query, *args, **kwargs): + """Forward function for `TransformerDecoder`. + + Args: + query (Tensor): Input query with shape + `(num_query, bs, embed_dims)`. + + Returns: + Tensor: Results with shape [1, num_query, bs, embed_dims] when + return_intermediate is `False`, otherwise it has shape + [num_layers, num_query, bs, embed_dims]. + """ + if not self.return_intermediate: + x = super().forward(query, *args, **kwargs) + if self.post_norm: + x = self.post_norm(x)[None] + return x + + intermediate = [] + for layer in self.layers: + query = layer(query, *args, **kwargs) + if self.return_intermediate: + if self.post_norm is not None: + intermediate.append(self.post_norm(query)) + else: + intermediate.append(query) + return torch.stack(intermediate) + + +@TRANSFORMER.register_module() +class Transformer(BaseModule): + """Implements the DETR transformer. 
+ + Following the official DETR implementation, this module copy-paste + from torch.nn.Transformer with modifications: + + * positional encodings are passed in MultiheadAttention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers + + See `paper: End-to-End Object Detection with Transformers + `_ for details. + + Args: + encoder (`mmcv.ConfigDict` | Dict): Config of + TransformerEncoder. Defaults to None. + decoder ((`mmcv.ConfigDict` | Dict)): Config of + TransformerDecoder. Defaults to None + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Defaults to None. + """ + def __init__(self, encoder=None, decoder=None, init_cfg=None): + super(Transformer, self).__init__(init_cfg=init_cfg) + self.encoder = build_transformer_layer_sequence(encoder) + self.decoder = build_transformer_layer_sequence(decoder) + self.embed_dims = self.encoder.embed_dims + + def init_weights(self): + # follow the official DETR to init parameters + for m in self.modules(): + if hasattr(m, 'weight') and m.weight.dim() > 1: + xavier_init(m, distribution='uniform') + self._is_init = True + + def forward(self, x, mask, query_embed, pos_embed): + """Forward function for `Transformer`. + + Args: + x (Tensor): Input query with shape [bs, c, h, w] where + c = embed_dims. + mask (Tensor): The key_padding_mask used for encoder and decoder, + with shape [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, with shape + [num_query, c]. + pos_embed (Tensor): The positional encoding for encoder and + decoder, with the same shape as `x`. + + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + + - out_dec: Output from decoder. If return_intermediate_dec \ + is True output has shape [num_dec_layers, bs, + num_query, embed_dims], else has shape [1, bs, \ + num_query, embed_dims]. + - memory: Output results from encoder, with shape \ + [bs, embed_dims, h, w]. + """ + bs, c, h, w = x.shape + # use `view` instead of `flatten` for dynamically exporting to ONNX + x = x.view(bs, c, -1).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c] + pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat( + 1, bs, 1) # [num_query, dim] -> [num_query, bs, dim] + mask = mask.view(bs, -1) # [bs, h, w] -> [bs, h*w] + memory = self.encoder( + query=x, + key=None, + value=None, + query_pos=pos_embed, + query_key_padding_mask=mask) + target = torch.zeros_like(query_embed) + # out_dec: [num_layers, num_query, bs, dim] + out_dec = self.decoder( + query=target, + key=memory, + value=memory, + key_pos=pos_embed, + query_pos=query_embed, + key_padding_mask=mask) + out_dec = out_dec.transpose(1, 2) + memory = memory.permute(1, 2, 0).reshape(bs, c, h, w) + return out_dec, memory + + +@TRANSFORMER_LAYER_SEQUENCE.register_module() +class DeformableDetrTransformerDecoder(TransformerLayerSequence): + """Implements the decoder in DETR transformer. + + Args: + return_intermediate (bool): Whether to return intermediate outputs. + coder_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. + """ + def __init__(self, *args, return_intermediate=False, **kwargs): + + super(DeformableDetrTransformerDecoder, self).__init__(*args, **kwargs) + self.return_intermediate = return_intermediate + + def forward(self, + query, + *args, + reference_points=None, + valid_ratios=None, + reg_branches=None, + **kwargs): + """Forward function for `TransformerDecoder`. 
+
+        Args:
+            query (Tensor): Input query with shape
+                `(num_query, bs, embed_dims)`.
+            reference_points (Tensor): The reference points used for offset
+                refinement, with shape (bs, num_query, 4) when as_two_stage
+                is True, otherwise (bs, num_query, 2).
+            valid_ratios (Tensor): The ratios of valid
+                points on the feature map, has shape
+                (bs, num_levels, 2).
+            reg_branches (obj:`nn.ModuleList`): Used for
+                refining the regression results. Only passed
+                when with_box_refine is True, otherwise
+                `None` is passed.
+
+        Returns:
+            Tensor: Results with shape [1, num_query, bs, embed_dims] when
+                return_intermediate is `False`, otherwise it has shape
+                [num_layers, num_query, bs, embed_dims].
+        """
+        output = query
+        intermediate = []
+        intermediate_reference_points = []
+        for lid, layer in enumerate(self.layers):
+            if reference_points.shape[-1] == 4:
+                reference_points_input = reference_points[:, :, None] * \
+                    torch.cat([valid_ratios, valid_ratios],
+                              -1)[:, None]
+            else:
+                assert reference_points.shape[-1] == 2
+                reference_points_input = reference_points[:, :, None] * \
+                    valid_ratios[:, None]
+            output = layer(
+                output,
+                *args,
+                reference_points=reference_points_input,
+                **kwargs)
+            output = output.permute(1, 0, 2)
+
+            if reg_branches is not None:
+                tmp = reg_branches[lid](output)
+                if reference_points.shape[-1] == 4:
+                    new_reference_points = tmp + inverse_sigmoid(
+                        reference_points)
+                    new_reference_points = new_reference_points.sigmoid()
+                else:
+                    assert reference_points.shape[-1] == 2
+                    new_reference_points = tmp
+                    new_reference_points[..., :2] = tmp[
+                        ..., :2] + inverse_sigmoid(
+                            reference_points)
+                    new_reference_points = new_reference_points.sigmoid()
+                reference_points = new_reference_points.detach()
+
+            output = output.permute(1, 0, 2)
+            if self.return_intermediate:
+                intermediate.append(output)
+                intermediate_reference_points.append(reference_points)
+
+        if self.return_intermediate:
+            return torch.stack(intermediate), torch.stack(
+                intermediate_reference_points)
+
+        return output, reference_points
+
+
+@TRANSFORMER.register_module()
+class DeformableDetrTransformer(Transformer):
+    """Implements the DeformableDETR transformer.
+
+    Args:
+        as_two_stage (bool): Generate query from encoder features.
+            Default: False.
+        num_feature_levels (int): Number of feature maps from FPN.
+            Default: 4.
+        two_stage_num_proposals (int): Number of proposals when
+            `as_two_stage` is True. Default: 300.
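+
+    Example:
+        An illustrative construction sketch; the nested encoder/decoder
+        settings below follow the common Deformable DETR layout and are
+        assumptions rather than the only valid values:
+
+        >>> transformer = DeformableDetrTransformer(
+        ...     as_two_stage=False,
+        ...     num_feature_levels=4,
+        ...     two_stage_num_proposals=300,
+        ...     encoder=dict(
+        ...         type='DetrTransformerEncoder',
+        ...         num_layers=6,
+        ...         transformerlayers=dict(
+        ...             type='BaseTransformerLayer',
+        ...             attn_cfgs=dict(
+        ...                 type='MultiScaleDeformableAttention',
+        ...                 embed_dims=256),
+        ...             feedforward_channels=1024,
+        ...             ffn_dropout=0.1,
+        ...             operation_order=('self_attn', 'norm', 'ffn',
+        ...                              'norm'))),
+        ...     decoder=dict(
+        ...         type='DeformableDetrTransformerDecoder',
+        ...         num_layers=6,
+        ...         return_intermediate=True,
+        ...         transformerlayers=dict(
+        ...             type='DetrTransformerDecoderLayer',
+        ...             attn_cfgs=[
+        ...                 dict(
+        ...                     type='MultiheadAttention',
+        ...                     embed_dims=256,
+        ...                     num_heads=8,
+        ...                     dropout=0.1),
+        ...                 dict(
+        ...                     type='MultiScaleDeformableAttention',
+        ...                     embed_dims=256)
+        ...             ],
+        ...             feedforward_channels=1024,
+        ...             ffn_dropout=0.1,
+        ...             operation_order=('self_attn', 'norm', 'cross_attn',
+        ...                              'norm', 'ffn', 'norm'))))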
+ """ + + def __init__(self, + as_two_stage=False, + num_feature_levels=4, + two_stage_num_proposals=300, + **kwargs): + super(DeformableDetrTransformer, self).__init__(**kwargs) + self.as_two_stage = as_two_stage + self.num_feature_levels = num_feature_levels + self.two_stage_num_proposals = two_stage_num_proposals + self.embed_dims = self.encoder.embed_dims + self.init_layers() + + def init_layers(self): + """Initialize layers of the DeformableDetrTransformer.""" + self.level_embeds = nn.Parameter( + torch.Tensor(self.num_feature_levels, self.embed_dims)) + + if self.as_two_stage: + self.enc_output = nn.Linear(self.embed_dims, self.embed_dims) + self.enc_output_norm = nn.LayerNorm(self.embed_dims) + self.pos_trans = nn.Linear(self.embed_dims * 2, + self.embed_dims * 2) + self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2) + else: + self.reference_points = nn.Linear(self.embed_dims, 2) + + def init_weights(self): + """Initialize the transformer weights.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MultiScaleDeformableAttention): + m.init_weights() + if not self.as_two_stage: + xavier_init(self.reference_points, distribution='uniform', bias=0.) + normal_(self.level_embeds) + + def gen_encoder_output_proposals(self, memory, memory_padding_mask, + spatial_shapes): + """Generate proposals from encoded memory. + + Args: + memory (Tensor) : The output of encoder, + has shape (bs, num_key, embed_dim). num_key is + equal the number of points on feature map from + all level. + memory_padding_mask (Tensor): Padding mask for memory. + has shape (bs, num_key). + spatial_shapes (Tensor): The shape of all feature maps. + has shape (num_level, 2). + + Returns: + tuple: A tuple of feature map and bbox prediction. + + - output_memory (Tensor): The input of decoder, \ + has shape (bs, num_key, embed_dim). num_key is \ + equal the number of points on feature map from \ + all levels. + - output_proposals (Tensor): The normalized proposal \ + after a inverse sigmoid, has shape \ + (bs, num_keys, 4). 
+ """ + + N, S, C = memory.shape + proposals = [] + _cur = 0 + for lvl, (H, W) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H * W)].view( + N, H, W, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = torch.meshgrid( + torch.linspace( + 0, H - 1, H, dtype=torch.float32, device=memory.device), + torch.linspace( + 0, W - 1, W, dtype=torch.float32, device=memory.device)) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_W.unsqueeze(-1), + valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl) + proposal = torch.cat((grid, wh), -1).view(N, -1, 4) + proposals.append(proposal) + _cur += (H * W) + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & + (output_proposals < 0.99)).all( + -1, keepdim=True) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) + output_proposals = output_proposals.masked_fill( + memory_padding_mask.unsqueeze(-1), float('inf')) + output_proposals = output_proposals.masked_fill( + ~output_proposals_valid, float('inf')) + + output_memory = memory + output_memory = output_memory.masked_fill( + memory_padding_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, + float(0)) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + return output_memory, output_proposals + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + """Get the reference points used in decoder. + + Args: + spatial_shapes (Tensor): The shape of all + feature maps, has shape (num_level, 2). + valid_ratios (Tensor): The radios of valid + points on the feature map, has shape + (bs, num_levels, 2) + device (obj:`device`): The device where + reference_points should be. + + Returns: + Tensor: reference points used in decoder, has \ + shape (bs, num_keys, num_levels, 2). 
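+
+        Example:
+            A small numeric sketch; the single 2x3 feature level and the
+            fully valid (all-ones) ratios are assumptions made purely for
+            illustration:
+
+            >>> spatial_shapes = torch.tensor([[2, 3]])
+            >>> valid_ratios = torch.ones(1, 1, 2)
+            >>> ref = DeformableDetrTransformer.get_reference_points(
+            ...     spatial_shapes, valid_ratios, device='cpu')
+            >>> assert ref.shape == (1, 6, 1, 2)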
+ """ + reference_points_list = [] + for lvl, (H, W) in enumerate(spatial_shapes): + # TODO check this 0.5 + ref_y, ref_x = torch.meshgrid( + torch.linspace( + 0.5, H - 0.5, H, dtype=torch.float32, device=device), + torch.linspace( + 0.5, W - 0.5, W, dtype=torch.float32, device=device)) + ref_y = ref_y.reshape(-1)[None] / ( + valid_ratios[:, None, lvl, 1] * H) + ref_x = ref_x.reshape(-1)[None] / ( + valid_ratios[:, None, lvl, 0] * W) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def get_valid_ratio(self, mask): + """Get the valid radios of feature maps of all level.""" + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def get_proposal_pos_embed(self, + proposals, + num_pos_feats=128, + temperature=10000): + """Get the position embedding of proposal.""" + scale = 2 * math.pi + dim_t = torch.arange( + num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = proposals.sigmoid() * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), + dim=4).flatten(2) + return pos + + def forward(self, + mlvl_feats, + mlvl_masks, + query_embed, + mlvl_pos_embeds, + reg_branches=None, + cls_branches=None, + **kwargs): + """Forward function for `Transformer`. + + Args: + mlvl_feats (list(Tensor)): Input queries from + different level. Each element has shape + [bs, embed_dims, h, w]. + mlvl_masks (list(Tensor)): The key_padding_mask from + different level used for encoder and decoder, + each element has shape [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, + with shape [num_query, c]. + mlvl_pos_embeds (list(Tensor)): The positional encoding + of feats from different level, has the shape + [bs, embed_dims, h, w]. + reg_branches (obj:`nn.ModuleList`): Regression heads for + feature maps from each decoder layer. Only would + be passed when + `with_box_refine` is True. Default to None. + cls_branches (obj:`nn.ModuleList`): Classification heads + for feature maps from each decoder layer. Only would + be passed when `as_two_stage` + is True. Default to None. + + + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + + - inter_states: Outputs from decoder. If + return_intermediate_dec is True output has shape \ + (num_dec_layers, bs, num_query, embed_dims), else has \ + shape (1, bs, num_query, embed_dims). + - init_reference_out: The initial value of reference \ + points, has shape (bs, num_queries, 4). + - inter_references_out: The internal value of reference \ + points in decoder, has shape \ + (num_dec_layers, bs,num_query, embed_dims) + - enc_outputs_class: The classification score of \ + proposals generated from \ + encoder's feature maps, has shape \ + (batch, h*w, num_classes). \ + Only would be returned when `as_two_stage` is True, \ + otherwise None. + - enc_outputs_coord_unact: The regression results \ + generated from encoder's feature maps., has shape \ + (batch, h*w, 4). Only would \ + be returned when `as_two_stage` is True, \ + otherwise None. 
+ """ + assert self.as_two_stage or query_embed is not None + + feat_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed) in enumerate( + zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): + bs, c, h, w = feat.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + feat = feat.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + feat_flatten.append(feat) + mask_flatten.append(mask) + feat_flatten = torch.cat(feat_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor( + spatial_shapes, dtype=torch.long, device=feat_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros( + (1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack( + [self.get_valid_ratio(m) for m in mlvl_masks], 1) + + reference_points = \ + self.get_reference_points(spatial_shapes, + valid_ratios, + device=feat.device) + + feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) + lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute( + 1, 0, 2) # (H*W, bs, embed_dims) + memory = self.encoder( + query=feat_flatten, + key=None, + value=None, + query_pos=lvl_pos_embed_flatten, + query_key_padding_mask=mask_flatten, + spatial_shapes=spatial_shapes, + reference_points=reference_points, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + **kwargs) + + memory = memory.permute(1, 0, 2) + bs, _, c = memory.shape + if self.as_two_stage: + output_memory, output_proposals = \ + self.gen_encoder_output_proposals( + memory, mask_flatten, spatial_shapes) + enc_outputs_class = cls_branches[self.decoder.num_layers]( + output_memory) + enc_outputs_coord_unact = \ + reg_branches[ + self.decoder.num_layers](output_memory) + output_proposals + + topk = self.two_stage_num_proposals + topk_proposals = torch.topk( + enc_outputs_class[..., 0], topk, dim=1)[1] + topk_coords_unact = torch.gather( + enc_outputs_coord_unact, 1, + topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact.sigmoid() + init_reference_out = reference_points + pos_trans_out = self.pos_trans_norm( + self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))) + query_pos, query = torch.split(pos_trans_out, c, dim=2) + else: + query_pos, query = torch.split(query_embed, c, dim=1) + query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) + query = query.unsqueeze(0).expand(bs, -1, -1) + reference_points = self.reference_points(query_pos).sigmoid() + init_reference_out = reference_points + + # decoder + query = query.permute(1, 0, 2) + memory = memory.permute(1, 0, 2) + query_pos = query_pos.permute(1, 0, 2) + inter_states, inter_references = self.decoder( + query=query, + key=None, + value=memory, + query_pos=query_pos, + key_padding_mask=mask_flatten, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=reg_branches, + **kwargs) + + inter_references_out = inter_references + if self.as_two_stage: + return inter_states, init_reference_out, \ + inter_references_out, enc_outputs_class, \ + enc_outputs_coord_unact + return inter_states, init_reference_out, \ + inter_references_out, None, None + + 
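+# A minimal numeric sketch of how `inverse_sigmoid` supports the iterative
+# box refinement in `DeformableDetrTransformerDecoder` above (the 0.3 / 0.05
+# values are arbitrary and purely illustrative):
+#
+#     >>> ref = torch.tensor([[0.3, 0.7]])
+#     >>> delta = torch.tensor([[0.05, -0.05]])      # a predicted offset
+#     >>> new_ref = (delta + inverse_sigmoid(ref)).sigmoid()
+#     >>> torch.allclose(inverse_sigmoid(ref).sigmoid(), ref, atol=1e-4)
+#     True
+
+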
+@TRANSFORMER.register_module()
+class DynamicConv(BaseModule):
+    """Implements Dynamic Convolution.
+
+    This module generates parameters for each sample and uses bmm to
+    implement 1x1 convolution. Code is modified from the
+    `official github repo `_ .
+
+    Args:
+        in_channels (int): The input feature channel.
+            Defaults to 256.
+        feat_channels (int): The inner feature channel.
+            Defaults to 64.
+        out_channels (int, optional): The output feature channel.
+            When not specified, it will be set to `in_channels`
+            by default.
+        input_feat_shape (int): The shape of input feature.
+            Defaults to 7.
+        with_proj (bool): Project two-dimensional feature to
+            one-dimensional feature. Defaults to True.
+        act_cfg (dict): The activation config for DynamicConv.
+            Defaults to dict(type='ReLU', inplace=True).
+        norm_cfg (dict): Config dict for normalization layer. Defaults to
+            layer normalization, i.e. dict(type='LN').
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+    """
+    def __init__(self,
+                 in_channels=256,
+                 feat_channels=64,
+                 out_channels=None,
+                 input_feat_shape=7,
+                 with_proj=True,
+                 act_cfg=dict(type='ReLU', inplace=True),
+                 norm_cfg=dict(type='LN'),
+                 init_cfg=None):
+        super(DynamicConv, self).__init__(init_cfg)
+        self.in_channels = in_channels
+        self.feat_channels = feat_channels
+        self.out_channels_raw = out_channels
+        self.input_feat_shape = input_feat_shape
+        self.with_proj = with_proj
+        self.act_cfg = act_cfg
+        self.norm_cfg = norm_cfg
+        self.out_channels = out_channels if out_channels else in_channels
+
+        self.num_params_in = self.in_channels * self.feat_channels
+        self.num_params_out = self.out_channels * self.feat_channels
+        self.dynamic_layer = nn.Linear(
+            self.in_channels, self.num_params_in + self.num_params_out)
+
+        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
+        self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]
+
+        self.activation = build_activation_layer(act_cfg)
+
+        num_output = self.out_channels * input_feat_shape ** 2
+        if self.with_proj:
+            self.fc_layer = nn.Linear(num_output, self.out_channels)
+            self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]
+
+    def forward(self, param_feature, input_feature):
+        """Forward function for `DynamicConv`.
+
+        Args:
+            param_feature (Tensor): The feature used to generate the
+                parameters, has shape
+                (num_all_proposals, in_channels).
+            input_feature (Tensor): Feature that
+                interacts with the parameters, has shape
+                (num_all_proposals, in_channels, H, W).
+
+        Returns:
+            Tensor: The output feature has shape
+            (num_all_proposals, out_channels).
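+
+        Example:
+            An illustrative shape check; the 10 proposals and the default
+            channel sizes are assumptions:
+
+            >>> dynamic_conv = DynamicConv()
+            >>> param_feature = torch.rand(10, 256)
+            >>> input_feature = torch.rand(10, 256, 7, 7)
+            >>> out = dynamic_conv(param_feature, input_feature)
+            >>> assert out.shape == (10, 256)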
+ """ + input_feature = input_feature.flatten(2).permute(2, 0, 1) + + input_feature = input_feature.permute(1, 0, 2) + parameters = self.dynamic_layer(param_feature) + + param_in = parameters[:, :self.num_params_in].view( + -1, self.in_channels, self.feat_channels) + param_out = parameters[:, -self.num_params_out:].view( + -1, self.feat_channels, self.out_channels) + + # input_feature has shape (num_all_proposals, H*W, in_channels) + # param_in has shape (num_all_proposals, in_channels, feat_channels) + # feature has shape (num_all_proposals, H*W, feat_channels) + features = torch.bmm(input_feature, param_in) + features = self.norm_in(features) + features = self.activation(features) + + # param_out has shape (batch_size, feat_channels, out_channels) + features = torch.bmm(features, param_out) + features = self.norm_out(features) + features = self.activation(features) + + if self.with_proj: + features = features.flatten(1) + features = self.fc_layer(features) + features = self.fc_norm(features) + features = self.activation(features) + + return features diff --git a/segmentation/slurm_test.sh b/segmentation/slurm_test.sh new file mode 100644 index 000000000..24de10aa3 --- /dev/null +++ b/segmentation/slurm_test.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +set -x + +PARTITION=$1 +JOB_NAME=$2 +CONFIG=$3 +CHECKPOINT=$4 +GPUS=${GPUS:-8} +GPUS_PER_NODE=${GPUS_PER_NODE:-8} +CPUS_PER_TASK=${CPUS_PER_TASK:-5} +PY_ARGS=${@:5} +SRUN_ARGS=${SRUN_ARGS:-""} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + python -u test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} diff --git a/segmentation/slurm_train.sh b/segmentation/slurm_train.sh new file mode 100644 index 000000000..12c10aae1 --- /dev/null +++ b/segmentation/slurm_train.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +set -x + +PARTITION=$1 +JOB_NAME=$2 +CONFIG=$3 +GPUS=${GPUS:-8} +GPUS_PER_NODE=${GPUS_PER_NODE:-8} +CPUS_PER_TASK=${CPUS_PER_TASK:-5} +SRUN_ARGS=${SRUN_ARGS:-""} +PY_ARGS=${@:4} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + python -u train.py ${CONFIG} --launcher="slurm" ${PY_ARGS} diff --git a/segmentation/test.py b/segmentation/test.py new file mode 100644 index 000000000..371779c54 --- /dev/null +++ b/segmentation/test.py @@ -0,0 +1,274 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
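+# Example usage (illustrative; <config>.py and <checkpoint>.pth are
+# placeholders for your own config and checkpoint):
+#   single GPU: python test.py <config>.py <checkpoint>.pth --eval mIoU
+#   Slurm:      GPUS=8 sh slurm_test.sh <partition> <job_name> \
+#                   <config>.py <checkpoint>.pth --eval mIoU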
+import argparse +import os +import os.path as osp +import shutil +import time +import warnings + +import mmcv +import mmcv_custom +import mmseg_custom +import torch +from mmcv.parallel import MMDataParallel, MMDistributedDataParallel +from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, + wrap_fp16_model) +from mmcv.utils import DictAction +from mmseg.apis import multi_gpu_test, single_gpu_test +from mmseg.datasets import build_dataloader, build_dataset +from mmseg.models import build_segmentor + + +def parse_args(): + parser = argparse.ArgumentParser( + description='mmseg test (and eval) a model') + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument( + '--work-dir', + help=('if specified, the evaluation metric results will be dumped' + 'into the directory as json')) + parser.add_argument( + '--aug-test', action='store_true', help='Use Flip and Multi scale aug') + parser.add_argument('--out', help='output result file in pickle format') + parser.add_argument( + '--format-only', + action='store_true', + help='Format the output results without perform evaluation. It is' + 'useful when you want to format the result to a specific format and ' + 'submit it to the test server') + parser.add_argument( + '--eval', + type=str, + nargs='+', + help='evaluation metrics, which depends on the dataset, e.g., "mIoU"' + ' for generic datasets, and "cityscapes" for Cityscapes') + parser.add_argument('--show', action='store_true', help='show results') + parser.add_argument( + '--show-dir', help='directory where painted images will be saved') + parser.add_argument( + '--gpu-collect', + action='store_true', + help='whether to use gpu to collect results.') + parser.add_argument( + '--tmpdir', + help='tmp directory used for collecting results from multiple ' + 'workers, available when gpu_collect is not specified') + parser.add_argument( + '--options', + nargs='+', + action=DictAction, + help="--options is deprecated in favor of --cfg_options' and it will " + 'not be supported in version v0.22.0. Override some settings in the ' + 'used config, the key-value pair in xxx=yyy format will be merged ' + 'into config file. If the value to be overwritten is a list, it ' + 'should be like key="[a,b]" or key=a,b It also allows nested ' + 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation ' + 'marks are necessary and that no white space is allowed.') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--eval-options', + nargs='+', + action=DictAction, + help='custom options for evaluation') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument( + '--opacity', + type=float, + default=0.5, + help='Opacity of painted segmentation map. 
In (0, 1] range.') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + if args.options and args.cfg_options: + raise ValueError( + '--options and --cfg-options cannot be both ' + 'specified, --options is deprecated in favor of --cfg-options. ' + '--options will not be supported in version v0.22.0.') + if args.options: + warnings.warn('--options is deprecated in favor of --cfg-options. ' + '--options will not be supported in version v0.22.0.') + args.cfg_options = args.options + + return args + + +def main(): + args = parse_args() + assert args.out or args.eval or args.format_only or args.show \ + or args.show_dir, \ + ('Please specify at least one operation (save/eval/format/show the ' + 'results / save the results) with the argument "--out", "--eval"' + ', "--format-only", "--show" or "--show-dir"') + + if args.eval and args.format_only: + raise ValueError('--eval and --format_only cannot be both specified') + + if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): + raise ValueError('The output file must be a pkl file.') + + cfg = mmcv.Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + # set cudnn_benchmark + if cfg.get('cudnn_benchmark', False): + torch.backends.cudnn.benchmark = True + if args.aug_test: + # hard code index + cfg.data.test.pipeline[1].img_ratios = [ + 0.5, 0.75, 1.0, 1.25, 1.5, 1.75 + ] + cfg.data.test.pipeline[1].flip = True + cfg.model.pretrained = None + cfg.data.test.test_mode = True + + # init distributed env first, since logger depends on the dist info. + if args.launcher == 'none': + distributed = False + else: + distributed = True + init_dist(args.launcher, **cfg.dist_params) + + rank, _ = get_dist_info() + # allows not to create + if args.work_dir is not None and rank == 0: + mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + if args.aug_test: + json_file = osp.join(args.work_dir, + f'eval_multi_scale_{timestamp}.json') + else: + json_file = osp.join(args.work_dir, + f'eval_single_scale_{timestamp}.json') + elif rank == 0: + work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + mmcv.mkdir_or_exist(osp.abspath(work_dir)) + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + if args.aug_test: + json_file = osp.join(work_dir, + f'eval_multi_scale_{timestamp}.json') + else: + json_file = osp.join(work_dir, + f'eval_single_scale_{timestamp}.json') + + # build the dataloader + # TODO: support multiple images per gpu (only minor changes are needed) + dataset = build_dataset(cfg.data.test) + data_loader = build_dataloader( + dataset, + samples_per_gpu=1, + workers_per_gpu=cfg.data.workers_per_gpu, + dist=distributed, + shuffle=False) + + # build the model and load checkpoint + cfg.model.train_cfg = None + model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg')) + fp16_cfg = cfg.get('fp16', None) + if fp16_cfg is not None: + wrap_fp16_model(model) + checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') + if 'CLASSES' in checkpoint.get('meta', {}): + model.CLASSES = checkpoint['meta']['CLASSES'] + else: + print('"CLASSES" not found in meta, use dataset.CLASSES instead') + model.CLASSES = dataset.CLASSES + if 'PALETTE' in checkpoint.get('meta', {}): + model.PALETTE = checkpoint['meta']['PALETTE'] + else: + print('"PALETTE" not found in meta, use 
dataset.PALETTE instead') + model.PALETTE = dataset.PALETTE + + # clean gpu memory when starting a new evaluation. + torch.cuda.empty_cache() + eval_kwargs = {} if args.eval_options is None else args.eval_options + + # Deprecated + efficient_test = eval_kwargs.get('efficient_test', False) + if efficient_test: + warnings.warn( + '``efficient_test=True`` does not have effect in tools/test.py, ' + 'the evaluation and format results are CPU memory efficient by ' + 'default') + + eval_on_format_results = ( + args.eval is not None and 'cityscapes' in args.eval) + if eval_on_format_results: + assert len(args.eval) == 1, 'eval on format results is not ' \ + 'applicable for metrics other than ' \ + 'cityscapes' + if args.format_only or eval_on_format_results: + if 'imgfile_prefix' in eval_kwargs: + tmpdir = eval_kwargs['imgfile_prefix'] + else: + tmpdir = '.format_cityscapes' + eval_kwargs.setdefault('imgfile_prefix', tmpdir) + mmcv.mkdir_or_exist(tmpdir) + else: + tmpdir = None + + if not distributed: + model = MMDataParallel(model, device_ids=[0]) + results = single_gpu_test( + model, + data_loader, + args.show, + args.show_dir, + False, + args.opacity, + pre_eval=args.eval is not None and not eval_on_format_results, + format_only=args.format_only or eval_on_format_results, + format_args=eval_kwargs) + else: + model = MMDistributedDataParallel( + model.cuda(), + device_ids=[torch.cuda.current_device()], + broadcast_buffers=False) + results = multi_gpu_test( + model, + data_loader, + args.tmpdir, + args.gpu_collect, + False, + pre_eval=args.eval is not None and not eval_on_format_results, + format_only=args.format_only or eval_on_format_results, + format_args=eval_kwargs) + + rank, _ = get_dist_info() + if rank == 0: + if args.out: + warnings.warn( + 'The behavior of ``args.out`` has been changed since MMSeg ' + 'v0.16, the pickled outputs could be seg map as type of ' + 'np.array, pre-eval results or file paths for ' + '``dataset.format_results()``.') + print(f'\nwriting results to {args.out}') + mmcv.dump(results, args.out) + if args.eval: + eval_kwargs.update(metric=args.eval) + metric = dataset.evaluate(results, **eval_kwargs) + metric_dict = dict(config=args.config, metric=metric) + mmcv.dump(metric_dict, json_file, indent=4) + if tmpdir is not None and eval_on_format_results: + # remove tmp dir when cityscapes evaluation + shutil.rmtree(tmpdir) + + +if __name__ == '__main__': + main() diff --git a/segmentation/train.py b/segmentation/train.py new file mode 100644 index 000000000..86a5e443d --- /dev/null +++ b/segmentation/train.py @@ -0,0 +1,215 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
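+# Example usage (illustrative; <config>.py is a placeholder for your config):
+#   single GPU: python train.py <config>.py --work-dir <work_dir>
+#   Slurm:      GPUS=8 sh slurm_train.sh <partition> <job_name> <config>.py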
+import argparse +import copy +import os +import os.path as osp +import time +import warnings + +import mmcv +import mmcv_custom +import mmseg_custom +import torch +from mmcv.cnn.utils import revert_sync_batchnorm +from mmcv.runner import get_dist_info, init_dist +from mmcv.utils import Config, DictAction, get_git_hash +from mmseg import __version__ +from mmseg.apis import init_random_seed, set_random_seed, train_segmentor +from mmseg.datasets import build_dataset +from mmseg.models import build_segmentor +from mmseg.utils import collect_env, get_root_logger + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a segmentor') + parser.add_argument('config', help='train config file path') + parser.add_argument('--work-dir', help='the dir to save logs and models') + parser.add_argument( + '--load-from', help='the checkpoint file to load weights from') + parser.add_argument( + '--resume-from', help='the checkpoint file to resume from') + parser.add_argument( + '--no-validate', + action='store_true', + help='whether not to evaluate the checkpoint during training') + group_gpus = parser.add_mutually_exclusive_group() + group_gpus.add_argument( + '--gpus', + type=int, + help='number of gpus to use ' + '(only applicable to non-distributed training)') + group_gpus.add_argument( + '--gpu-ids', + type=int, + nargs='+', + help='ids of gpus to use ' + '(only applicable to non-distributed training)') + parser.add_argument('--seed', type=int, default=None, help='random seed') + parser.add_argument( + '--deterministic', + action='store_true', + help='whether to set deterministic options for CUDNN backend.') + parser.add_argument( + '--options', + nargs='+', + action=DictAction, + help="--options is deprecated in favor of --cfg_options' and it will " + 'not be supported in version v0.22.0. Override some settings in the ' + 'used config, the key-value pair in xxx=yyy format will be merged ' + 'into config file. If the value to be overwritten is a list, it ' + 'should be like key="[a,b]" or key=a,b It also allows nested ' + 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation ' + 'marks are necessary and that no white space is allowed.') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument( + '--auto-resume', + action='store_true', + help='resume from the latest checkpoint automatically.') + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + if args.options and args.cfg_options: + raise ValueError( + '--options and --cfg-options cannot be both ' + 'specified, --options is deprecated in favor of --cfg-options. ' + '--options will not be supported in version v0.22.0.') + if args.options: + warnings.warn('--options is deprecated in favor of --cfg-options. 
' + '--options will not be supported in version v0.22.0.') + args.cfg_options = args.options + + return args + + +def main(): + args = parse_args() + + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + # set cudnn_benchmark + if cfg.get('cudnn_benchmark', False): + torch.backends.cudnn.benchmark = True + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + if args.load_from is not None: + cfg.load_from = args.load_from + if args.resume_from is not None: + cfg.resume_from = args.resume_from + if args.gpu_ids is not None: + cfg.gpu_ids = args.gpu_ids + else: + cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) + cfg.auto_resume = args.auto_resume + + # init distributed env first, since logger depends on the dist info. + if args.launcher == 'none': + distributed = False + else: + distributed = True + init_dist(args.launcher, **cfg.dist_params) + # gpu_ids is used to calculate iter when resuming checkpoint + _, world_size = get_dist_info() + cfg.gpu_ids = range(world_size) + + # create work_dir + mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) + # dump config + cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) + # init the logger before other steps + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + log_file = osp.join(cfg.work_dir, f'{timestamp}.log') + logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) + + # init the meta dict to record some important information such as + # environment info and seed, which will be logged + meta = dict() + # log env info + env_info_dict = collect_env() + env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()]) + dash_line = '-' * 60 + '\n' + logger.info('Environment info:\n' + dash_line + env_info + '\n' + + dash_line) + meta['env_info'] = env_info + + # log some basic info + logger.info(f'Distributed training: {distributed}') + logger.info(f'Config:\n{cfg.pretty_text}') + + # set random seeds + seed = init_random_seed(args.seed) + logger.info(f'Set random seed to {seed}, ' + f'deterministic: {args.deterministic}') + set_random_seed(seed, deterministic=args.deterministic) + cfg.seed = seed + meta['seed'] = seed + meta['exp_name'] = osp.basename(args.config) + + model = build_segmentor( + cfg.model, + train_cfg=cfg.get('train_cfg'), + test_cfg=cfg.get('test_cfg')) + model.init_weights() + + # SyncBN is not support for DP + if not distributed: + warnings.warn( + 'SyncBN is only supported with DDP. To be compatible with DP, ' + 'we convert SyncBN to BN. 
Please use dist_train.sh which can ' + 'avoid this error.') + model = revert_sync_batchnorm(model) + + logger.info(model) + + datasets = [build_dataset(cfg.data.train)] + if len(cfg.workflow) == 2: + val_dataset = copy.deepcopy(cfg.data.val) + val_dataset.pipeline = cfg.data.train.pipeline + datasets.append(build_dataset(val_dataset)) + if cfg.checkpoint_config is not None: + # save mmseg version, config file content and class names in + # checkpoints as meta data + cfg.checkpoint_config.meta = dict( + mmseg_version=f'{__version__}+{get_git_hash()[:7]}', + config=cfg.pretty_text, + CLASSES=datasets[0].CLASSES, + PALETTE=datasets[0].PALETTE) + # add an attribute for visualization convenience + model.CLASSES = datasets[0].CLASSES + # passing checkpoint meta for saving best checkpoint + meta.update(cfg.checkpoint_config.meta) + train_segmentor( + model, + datasets, + cfg, + distributed=distributed, + validate=(not args.no_validate), + timestamp=timestamp, + meta=meta) + + +if __name__ == '__main__': + main()