Skip to content

Commit

Permalink
added inference scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
LIN Yun committed Jun 30, 2020
1 parent 80974bc commit 715221e
Show file tree
Hide file tree
Showing 8 changed files with 6,664 additions and 34 deletions.
10 changes: 6 additions & 4 deletions configs/cascade_mask_rcnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ MODEL:
DATASETS:
TRAIN: ("benign_train",)
TEST: ("benign_test",)
DATALOADER:
NUM_WORKERS: 0
SOLVER:
IMS_PER_BATCH: 16 # Batch size; Default 16
BASE_LR: 0.0015
STEPS: (60000, 80000) # The iteration number to decrease learning rate by GAMMA.
MAX_ITER: 90000 # Number of training steps
IMS_PER_BATCH: 10 # Batch size; Default 16
BASE_LR: 0.0001
STEPS: (17000, 22000) # The iteration number to decrease learning rate by GAMMA.
MAX_ITER: 25000 # Number of training steps
CHECKPOINT_PERIOD: 2500 # Saves checkpoint every number of steps
INPUT:
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) # Image input sizes
Expand Down
68 changes: 68 additions & 0 deletions inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Inference on a single image.
"""

from argparse import ArgumentParser

import cv2
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from PIL import Image

import detectron2_1


def main(args):
# Configure weights and confidence threshold
cfg = get_cfg()
cfg.merge_from_file(args.config_path)
cfg.MODEL.WEIGHTS = args.weights_path
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.conf_threshold

# Initialize model
predictor = DefaultPredictor(cfg)

# Load image as numpy array
im = cv2.imread(args.img_path)

# Perform inference
outputs = predictor(im)

# Set dataset categories
# FIXME: Specifc to this task
MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes = ["box", "logo"]

# Draw instance predictions
v = Visualizer(im[:, :, ::-1], MetadataCatalog.get(cfg.DATASETS.TRAIN[0]))
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))

# Image with instance predictions as numpy array
pred = out.get_image()

# Save image with instance predictions
Image.fromarray(pred).save(args.output_path)


def get_args():
parser = ArgumentParser()

parser.add_argument("--img-path", help="Path to image to perform inference on")
parser.add_argument("--config-path", help="Path to config file of model")
parser.add_argument("--weights-path", help="Path to model weights")
parser.add_argument(
"--output-path", help="Path to save image with instance predictions"
)
parser.add_argument(
"--conf-threshold",
type=float,
default="0.05",
help="Confidence threshold of predictions, default 0.05",
)

return parser.parse_args()


if __name__ == "__main__":
args = get_args()
main(args)
238 changes: 238 additions & 0 deletions notebooks/inference.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit 715221e

Please sign in to comment.