add video demo for segmentation
czczup committed Dec 16, 2022
1 parent e8a5f21 commit ef394ed
Showing 2 changed files with 136 additions and 0 deletions.
16 changes: 16 additions & 0 deletions segmentation/README.md
@@ -137,3 +137,19 @@ CUDA_VISIBLE_DEVICES=0 python image_demo.py \

The result will be saved at `demo/ADE_val_00000591.jpg`.
![image](https://s3.bmp.ovh/imgs/2022/06/05/3c7d0cb18e9f45eb.jpg)

## Video Demo
To run inference on a single video:

```
CUDA_VISIBLE_DEVICES=0 python video_demo.py demo.mp4 \
  configs/ade20k/mask2former_beit_adapter_large_896_80k_ade20k_ss.py \
  released/mask2former_beit_adapter_large_896_80k_ade20k.pth.tar \
  --output-file results.mp4 \
  --palette ade20k
```


https://user-images.githubusercontent.com/23737120/208140362-f2029060-eb16-4280-b85f-074006547a12.mp4
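
`video_demo.py` also exposes flags for shaping the output stream: `--output-width`/`--output-height` resize the rendered frames, `--output-fps` and `--output-fourcc` control the video writer, and `--opacity` sets the blend weight of the segmentation map. As a sketch with the same config and checkpoint as above (`results_small.avi` is just an example output name):

```
CUDA_VISIBLE_DEVICES=0 python video_demo.py demo.mp4 \
  configs/ade20k/mask2former_beit_adapter_large_896_80k_ade20k_ss.py \
  released/mask2former_beit_adapter_large_896_80k_ade20k.pth.tar \
  --output-file results_small.avi \
  --output-width 960 --output-height 540 \
  --output-fps 15 \
  --opacity 0.6 \
  --palette ade20k
```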


120 changes: 120 additions & 0 deletions segmentation/video_demo.py
@@ -0,0 +1,120 @@
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser

import cv2
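# the next two imports are needed only for their side effects: importing them
# registers this repo's custom modules with the mmcv/mmseg registries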
import mmcv_custom # noqa: F401,F403
import mmseg_custom # noqa: F401,F403
from mmcv.runner import load_checkpoint
from mmseg.apis import inference_segmentor, init_segmentor
from mmseg.core import get_classes
from mmseg.core.evaluation import get_palette


def main():
    parser = ArgumentParser()
    parser.add_argument('video', help='Video file or webcam id')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--palette',
        default='cityscapes',
        help='Color palette used for segmentation map')
    parser.add_argument(
        '--show', action='store_true', help='Whether to show draw result')
    parser.add_argument(
        '--show-wait-time', default=1, type=int, help='Wait time after imshow')
    parser.add_argument(
        '--output-file', default=None, type=str, help='Output video file path')
    parser.add_argument(
        '--output-fourcc',
        default='MJPG',
        type=str,
        help='Fourcc of the output video')
    parser.add_argument(
        '--output-fps', default=-1, type=int, help='FPS of the output video')
    parser.add_argument(
        '--output-height',
        default=-1,
        type=int,
        help='Frame height of the output video')
    parser.add_argument(
        '--output-width',
        default=-1,
        type=int,
        help='Frame width of the output video')
    parser.add_argument(
        '--opacity',
        type=float,
        default=0.5,
        help='Opacity of painted segmentation map. In (0, 1] range.')
    args = parser.parse_args()

    assert args.show or args.output_file, \
        'At least one output should be enabled.'

    # build the model from the config file; the checkpoint is loaded
    # separately so the class names stored in its metadata can be recovered
    model = init_segmentor(args.config, None, device=args.device)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
    if 'CLASSES' in checkpoint.get('meta', {}):
        model.CLASSES = checkpoint['meta']['CLASSES']
    else:
        model.CLASSES = get_classes(args.palette)

    # open the input video stream (a file path or a webcam id)
    cap = cv2.VideoCapture(args.video)
    assert cap.isOpened()
    input_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
    input_width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
    input_fps = cap.get(cv2.CAP_PROP_FPS)

    # init the output video writer; fall back to the input stream's
    # fps/size when no explicit values are given
    writer = None
    output_height = None
    output_width = None
    if args.output_file is not None:
        # note: the fourcc must be compatible with the output container
        # (the default 'MJPG' is typically paired with .avi files)
        fourcc = cv2.VideoWriter_fourcc(*args.output_fourcc)
        output_fps = args.output_fps if args.output_fps > 0 else input_fps
        output_height = args.output_height if args.output_height > 0 else int(
            input_height)
        output_width = args.output_width if args.output_width > 0 else int(
            input_width)
        writer = cv2.VideoWriter(args.output_file, fourcc, output_fps,
                                 (output_width, output_height), True)

    # main loop: read a frame, run the segmentor, blend, then show/write
    try:
        while True:
            flag, frame = cap.read()
            if not flag:
                break

            # run inference on the current frame
            result = inference_segmentor(model, frame)

            # blend raw image and prediction
            draw_img = model.show_result(
                frame,
                result,
                palette=get_palette(args.palette),
                show=False,
                opacity=args.opacity)

            if args.show:
                cv2.imshow('video_demo', draw_img)
                cv2.waitKey(args.show_wait_time)
            if writer:
                # resize the rendered frame if it does not match the
                # requested output size
                if draw_img.shape[0] != output_height or \
                        draw_img.shape[1] != output_width:
                    draw_img = cv2.resize(draw_img,
                                          (output_width, output_height))
                writer.write(draw_img)
    finally:
        # release resources even if inference raises
        if writer:
            writer.release()
        cap.release()


if __name__ == '__main__':
    main()
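
For programmatic use, the demo boils down to a handful of `mmseg` API calls. A minimal sketch for a single image, assuming the config/checkpoint paths from the README above (`frame.jpg` is a hypothetical input):

```
import mmcv_custom  # noqa: F401,F403
import mmseg_custom  # noqa: F401,F403
from mmseg.apis import inference_segmentor, init_segmentor
from mmseg.core.evaluation import get_palette

config = 'configs/ade20k/mask2former_beit_adapter_large_896_80k_ade20k_ss.py'
checkpoint = 'released/mask2former_beit_adapter_large_896_80k_ade20k.pth.tar'

# passing the checkpoint path here loads the weights (and class names) directly
model = init_segmentor(config, checkpoint, device='cuda:0')
result = inference_segmentor(model, 'frame.jpg')  # accepts a path or an ndarray
blended = model.show_result('frame.jpg', result,
                            palette=get_palette('ade20k'), opacity=0.5)
```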
