# Source code for hello.mmseg.infer

"""pip install mmsegmentation"""
import shutil
import sys
import warnings
from pathlib import Path

import cv2 as cv
import numpy as np
from tqdm import tqdm

from mmseg.apis import inference_segmentor, init_segmentor

# File suffixes that identify video inputs. Lower- and upper-case forms are
# both listed because Path.suffix preserves the file name's original case.
suffix_set = {
    ".avi", ".mp4", ".mov", ".mkv",
    ".AVI", ".MP4", ".MOV", ".MKV",
}
# Silence all Python warnings process-wide; intended to keep segmentor
# inference from flooding the console with warnings.
warnings.filterwarnings('ignore')


def _imwrite(path, image):
    """Write ``image`` to ``path`` with OpenCV, raising if the write fails."""
    # cv.imwrite reports failure (e.g. missing directory) by returning False
    # rather than raising, so check it explicitly.
    if not cv.imwrite(str(path), image):
        raise OSError(f"Failed to write image: {path}")


def test_image(model, img, palette, out_dir, out_name, add_zero_label=False):
    """Inference image and save the input, visualizations and label map.

    Writes three files under ``out_dir`` (whose ``data``, ``results`` and
    ``predictions`` subdirectories must already exist):

    * ``data/{out_name}.jpg`` -- the input image.
    * ``results/{out_name}.jpg`` -- the input image and the two
      ``model.show_result`` renderings (opacity 0.5 and 1.0) stacked
      vertically.
    * ``predictions/{out_name}.png`` -- the first entry of the inference
      result, clipped to [0, 255] and cast to uint8.

    Args:
        model (nn.Module): The loaded segmentor.
        img (str | pathlib.Path | ndarray): Image file or loaded image.
        palette (list[list[int]]): The palette of segmentation map.
        out_dir (str | pathlib.Path): The directory to save.
        out_name (str): The file name (without suffix) to save.
        add_zero_label (bool, optional): If True, add 1 to every predicted
            label before writing the prediction map. Defaults to False.

    Raises:
        FileNotFoundError: If ``img`` is a path OpenCV cannot read.
        OSError: If any output file cannot be written.
    """
    if isinstance(img, (str, Path)):
        img_path = str(img)
        img = cv.imread(img_path, flags=cv.IMREAD_COLOR)
        if img is None:
            # cv.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Cannot read image: {img_path}")
    out_dir = Path(out_dir)

    result = inference_segmentor(model, img)
    blended = model.show_result(img, result, palette=palette, show=False, out_file=None, opacity=0.5)
    colored = model.show_result(img, result, palette=palette, show=False, out_file=None, opacity=1.0)

    _imwrite(out_dir / "data" / f"{out_name}.jpg", img)
    _imwrite(
        out_dir / "results" / f"{out_name}.jpg",
        np.concatenate((img, blended, colored), axis=0),
    )

    seg_map = result[0]
    if add_zero_label:
        seg_map = seg_map + 1
    _imwrite(
        out_dir / "predictions" / f"{out_name}.png",
        seg_map.clip(min=0, max=255).astype("uint8"),
    )
def test_images(model, image_paths, palette, out_dir, add_zero_label=False):
    """Run ``test_image`` on every image file, showing a tqdm progress bar.

    Outputs for each image are named after its file stem, so two images with
    the same stem (e.g. in different subdirectories) overwrite each other.

    Args:
        model (nn.Module): The loaded segmentor.
        image_paths (list[str]): Image files to inference.
        palette (list[list[int]]): The palette of segmentation map.
        out_dir (pathlib.Path): The directory to save.
        add_zero_label (bool, optional): Defaults to False.
    """
    for image_path in tqdm(image_paths):
        stem = Path(image_path).stem
        test_image(model, image_path, palette, out_dir, stem, add_zero_label)
def test_video(model, video_path, palette, out_dir, add_zero_label=False):
    """Inference a sample of video frames (every ``int(FPS)``-th frame).

    Sampled frames are passed to ``test_image`` under the name
    ``{video stem}_{frame index:06d}``. A video that cannot be opened is
    skipped with a message.

    Args:
        model (nn.Module): The loaded segmentor.
        video_path (str | pathlib.Path): Video file to inference.
        palette (list[list[int]]): The palette of segmentation map.
        out_dir (pathlib.Path): The directory to save.
        add_zero_label (bool, optional): Defaults to False.
    """
    cap = cv.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            print(f"Can't open video: {video_path}. Skipping ...")
            return
        # CAP_PROP_FPS may be reported as 0 (or a value below 1) when the
        # container lacks that metadata; clamp so `index % step` never
        # divides by zero (falls back to sampling every frame).
        step = max(int(cap.get(cv.CAP_PROP_FPS)), 1)
        frame_count = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
        prefix = Path(video_path).stem
        for index in tqdm(range(frame_count)):
            ret, img = cap.read()
            if not ret:
                print("Can't receive frame (stream end?). Exiting ...")
                break
            if index % step != 0:
                continue
            out_name = f"{prefix}_{index:06d}"
            test_image(model, img, palette, out_dir, out_name, add_zero_label)
    finally:
        # Release the capture even if inference raises part-way through.
        cap.release()
def test_videos(model, video_paths, palette, out_dir, add_zero_label=False):
    """Run ``test_video`` on each video file in the given order.

    Args:
        model (nn.Module): The loaded segmentor.
        video_paths (list[str]): Video files to inference.
        palette (list[list[int]]): The palette of segmentation map.
        out_dir (pathlib.Path): The directory to save.
        add_zero_label (bool, optional): Defaults to False.
    """
    for path in video_paths:
        test_video(model, path, palette, out_dir, add_zero_label=add_zero_label)
def _is_within(path, directory):
    """Return True if ``path`` is ``directory`` itself or lies beneath it."""
    path = Path(path).resolve()
    directory = Path(directory).resolve()
    return path == directory or directory in path.parents


def func(root, config_file, checkpoint_file, testdata, out_dir, add_zero_label=False, device="cuda:0"):
    """Inference test data (images and/or videos) and save the outputs.

    Args:
        root (str | pathlib.Path): Base directory that ``config_file`` and
            ``checkpoint_file`` are resolved against.
        config_file (str): Config file path, relative to ``root``.
        checkpoint_file (str): Checkpoint file path, relative to ``root``.
        testdata (str | pathlib.Path): An image/video file, or a directory
            that is searched recursively for image/video files.
        out_dir (str | pathlib.Path): Output directory. It is DELETED and
            recreated with ``data``, ``results`` and ``predictions`` subdirs.
        add_zero_label (bool, optional): Add 1 to predicted labels before
            writing the prediction maps. Defaults to False.
        device (str, optional): Device passed to ``init_segmentor``.
            Defaults to "cuda:0".

    Raises:
        ValueError: If ``out_dir`` is None, or if ``out_dir`` is (or contains)
            ``testdata``, the config file or the checkpoint file.
        FileNotFoundError: If ``testdata`` does not exist.
    """
    # Validate everything before the (expensive) model load.
    if out_dir is None:
        raise ValueError("out_dir is required (pass -o on the command line)")
    root = Path(root)
    config_file = str(root / config_file)
    checkpoint_file = str(root / checkpoint_file)
    testdata = Path(testdata)
    if not (testdata.is_file() or testdata.is_dir()):
        raise FileNotFoundError(f"testdata does not exist: {testdata}")
    out_dir = Path(out_dir)
    # out_dir is wiped below; refuse when that would destroy the inputs.
    for protected in (testdata, config_file, checkpoint_file):
        if _is_within(protected, out_dir):
            raise ValueError(
                f"out_dir {out_dir} contains input {protected}; refusing to delete it"
            )

    model = init_segmentor(config_file, checkpoint_file, device=device)

    shutil.rmtree(out_dir, ignore_errors=True)
    for sub_dir in ("data", "results", "predictions"):
        (out_dir / sub_dir).mkdir(parents=True, exist_ok=False)

    if testdata.is_file():
        files = [testdata]
    else:
        # Skip directories and sort for a deterministic processing order.
        files = sorted(p for p in testdata.glob("**/*") if p.is_file())

    # Suffix matching is case-insensitive (e.g. ".JPG", ".mov").
    image_suffixes = {".jpg", ".jpeg", ".png", ".bmp"}
    video_suffixes = {s.lower() for s in suffix_set}
    image_paths = [str(f) for f in files if f.suffix.lower() in image_suffixes]
    video_paths = [str(f) for f in files if f.suffix.lower() in video_suffixes]

    if not image_paths and not video_paths:
        print(f"No images or videos found in {testdata}")
    if image_paths:
        test_images(model, image_paths, model.PALETTE, out_dir, add_zero_label)
    if video_paths:
        test_videos(model, video_paths, model.PALETTE, out_dir, add_zero_label)
def parse_args(args=None):
    """Parse command-line arguments into keyword arguments for ``func``.

    Args:
        args (list[str] | None): Argument strings; ``None`` lets argparse
            read ``sys.argv``.

    Returns:
        dict: Keys ``root``, ``config_file``, ``checkpoint_file``,
        ``testdata``, ``out_dir`` and ``add_zero_label``.
    """
    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument("root", type=str, help="the base dir")
    parser.add_argument("config_file", type=str, help="config file path")
    parser.add_argument("checkpoint_file", type=str, help="checkpoint file path")
    parser.add_argument("testdata", type=str, help="image/video(s) path or dir")
    # Required: func() deletes and recreates this directory and cannot run
    # without it, so fail fast with a usage error instead of passing None.
    parser.add_argument("-o", dest="out_dir", type=str, required=True, help="save results")
    parser.add_argument("-y", dest="add_zero_label", action="store_true", help="add zero label as background")
    return vars(parser.parse_args(args=args))
def main(args=None):
    """Command-line entry point: parse ``args``, echo them, and run ``func``.

    Args:
        args (list[str] | None): Argument strings; ``None`` reads ``sys.argv``.

    Returns:
        int: Exit status 0 (errors propagate as exceptions).
    """
    kwargs = parse_args(args)
    print(f"{__file__}: {kwargs}")
    # func() returns None, so its result is not printed.
    func(**kwargs)
    return 0
# Script entry point: exit with the status code returned by main().
if __name__ == "__main__":
    raise SystemExit(main())