# mmseg1.x
# pip install openmim
# mim install mmengine
# mim install mmcv==2.0.0
# mim install mmsegmentation==1.0.0
import shutil
import sys
import warnings
from pathlib import Path
import cv2 as cv
import numpy as np
import torch
from tqdm import tqdm
from mmseg.apis import inference_model, init_model
# Video file suffixes recognised when scanning the test data.
suffix_set = {".avi", ".mp4", ".MOV", ".mkv"}
# Silence all warnings emitted while the segmentor runs inference.
warnings.filterwarnings(action="ignore")
[docs]
def tensor2ndarray(value):
    """Convert a ``torch.Tensor`` to a NumPy array; pass anything else through.

    Args:
        value: A tensor, or any other object.

    Returns:
        The detached, CPU-resident NumPy view of ``value`` when it is a
        tensor, otherwise ``value`` itself, unchanged.
    """
    if not isinstance(value, torch.Tensor):
        return value
    return value.detach().cpu().numpy()
[docs]
def draw_sem_seg(sem_seg, classes, palette):
    """Paint a label map as a colour mask.

    Args:
        sem_seg (ndarray): Array of class indices.
        classes (Sequence): Class names; only ``len(classes)`` is used.
        palette (Sequence): Colour for each class, indexed by class id.

    Returns:
        ndarray: ``uint8`` array of shape ``sem_seg.shape + (3,)``. Pixels
        whose label is ``>= len(classes)`` stay zero.
    """
    # Labels present in the map, highest first, restricted to known classes.
    present = np.unique(sem_seg)[::-1]
    present = np.array(present[present < len(classes)], dtype=np.int64)
    fills = [palette[class_id] for class_id in present]

    canvas = np.zeros(sem_seg.shape + (3,), dtype="uint8")
    for class_id, fill in zip(present, fills):
        canvas[sem_seg == class_id] = fill
    return canvas
[docs]
def test_image(model, classes, palette, img, out_name, out_dir, add_zero_label=False):
    """Run inference on one image and save the input, an overlay and the labels.

    Three files are written below ``out_dir``:

    * ``data/<out_name>.jpg`` -- the input image.
    * ``results/<out_name>.jpg`` -- input, 50/50 blended overlay and colour
      mask, stacked vertically.
    * ``predictions/<out_name>.png`` -- the predicted label map, clipped to
      ``[0, 255]`` and stored as ``uint8``.

    Args:
        model (nn.Module): The loaded segmentor.
        classes (Sequence): Class names, forwarded to ``draw_sem_seg``.
        palette (Sequence): Colour per class id, forwarded to ``draw_sem_seg``.
        img (str/ndarray): Image file or loaded image.
        out_name (str): The file name (without suffix) to save.
        out_dir (pathlib.Path): The directory to save.
        add_zero_label (bool, optional): Add 1 to every predicted label before
            saving the label map. Defaults to False.

    Raises:
        AssertionError: If ``img`` is not an ndarray after the optional load.
    """
    if isinstance(img, str):
        img = cv.imread(img, flags=cv.IMREAD_COLOR)
    assert isinstance(img, np.ndarray), "a loaded image"

    result = inference_model(model, img)
    # Only the label map is needed. The per-class logits were previously
    # copied to the host and converted for every image without being used.
    pred_sem_seg = tensor2ndarray(result.pred_sem_seg.cpu().data)  # 1xHxW

    # NOTE(review): palette colours are written unchanged into an image that
    # OpenCV loads and saves as BGR -- confirm the palette's channel order.
    rgb_mask = draw_sem_seg(pred_sem_seg[0], classes, palette)
    mixed = cv.addWeighted(img, 0.5, rgb_mask, 0.5, 0)

    if add_zero_label:
        pred_sem_seg = pred_sem_seg + 1

    cv.imwrite(str(out_dir / "data" / f"{out_name}.jpg"), img)
    cv.imwrite(
        str(out_dir / "results" / f"{out_name}.jpg"),
        np.concatenate((img, mixed, rgb_mask), axis=0),
    )
    cv.imwrite(
        str(out_dir / "predictions" / f"{out_name}.png"),
        pred_sem_seg[0].clip(min=0, max=255).astype("uint8"),
    )
[docs]
def test_images(model, image_paths, out_dir, add_zero_label=False):
    """Run ``test_image`` on every file in ``image_paths``.

    Each output is named after the stem of its input file.

    Args:
        model (nn.Module): The loaded segmentor.
        image_paths (list[str]): Image files to inference.
        out_dir (pathlib.Path): The directory to save.
        add_zero_label (bool, optional): Defaults to False.
    """
    # A model exposing ``module`` keeps its dataset metadata on that inner object.
    meta = getattr(model, "module", model).dataset_meta
    classes, palette = meta["classes"], meta["palette"]

    for image_path in tqdm(image_paths):
        img = cv.imread(image_path, 1)
        out_name = Path(image_path).stem
        test_image(model, classes, palette, img, out_name, out_dir, add_zero_label)
[docs]
def test_videos(model, video_paths, out_dir, add_zero_label=False):
    """Placeholder for video inference; currently processes nothing.

    It reads the class names and palette from the model, prints a notice and
    iterates over ``video_paths`` without doing any work.

    Args:
        model (nn.Module): The loaded segmentor.
        video_paths (list[str]): Video files to inference.
        out_dir (pathlib.Path): The directory to save.
        add_zero_label (bool, optional): Defaults to False.
    """
    # A model exposing ``module`` keeps its dataset metadata on that inner object.
    meta = getattr(model, "module", model).dataset_meta
    # Not used yet; looked up so that missing metadata still fails here.
    classes, palette = meta["classes"], meta["palette"]

    print("[W] in development ..")
    for _ in tqdm(video_paths):
        continue
[docs]
def func(root, config_file, checkpoint_file, cfg_options, testdata, out_dir, add_zero_label=False):
    """Load a segmentor and run it over a test file or directory tree.

    WARNING: ``out_dir`` is deleted (if present) and recreated with the
    ``data``, ``results`` and ``predictions`` sub-directories.

    Args:
        root (str/pathlib.Path): Base directory that ``config_file`` and
            ``checkpoint_file`` are joined onto.
        config_file (str): Config file path, relative to ``root``.
        checkpoint_file (str): Checkpoint file path, relative to ``root``.
        cfg_options (dict, optional): Config overrides passed to ``init_model``.
        testdata (str/pathlib.Path): A single file, or a directory that is
            searched recursively. Files with a ``.jpg`` suffix are treated as
            images and files whose suffix is in ``suffix_set`` as videos; both
            comparisons ignore case.
        out_dir (str/pathlib.Path): The directory to save.
        add_zero_label (bool, optional): Defaults to False.

    Raises:
        AssertionError: If ``testdata`` is neither a file nor a directory.
    """
    # Resolve and validate every path first, so that a bad argument fails
    # before the (expensive) model load instead of after it.
    root = Path(root)
    testdata = Path(testdata)
    out_dir = Path(out_dir)
    assert testdata.is_file() or testdata.is_dir(), f"test data not found: {testdata}"

    config_file = str(root / config_file)
    checkpoint_file = str(root / checkpoint_file)
    model = init_model(config_file, checkpoint_file, cfg_options=cfg_options)

    # Start from a clean output tree; exist_ok=False surfaces a failed removal.
    shutil.rmtree(out_dir, ignore_errors=True)
    for sub_dir in ("data", "results", "predictions"):
        (out_dir / sub_dir).mkdir(parents=True, exist_ok=False)

    candidates = [testdata] if testdata.is_file() else sorted(testdata.glob("**/*"))

    # Compare suffixes case-insensitively so e.g. ".JPG" and ".mov" are not
    # silently skipped.
    video_suffixes = {suffix.lower() for suffix in suffix_set}
    image_paths = [str(f) for f in candidates if f.suffix.lower() == ".jpg"]
    video_paths = [str(f) for f in candidates if f.suffix.lower() in video_suffixes]

    if image_paths:
        test_images(model, image_paths, out_dir, add_zero_label)
    if video_paths:
        test_videos(model, video_paths, out_dir, add_zero_label)
[docs]
def parse_args(args=None):
    """Parse the command line into a keyword dict.

    Args:
        args (list[str], optional): Argument strings to parse. ``None`` makes
            argparse read ``sys.argv[1:]``.

    Returns:
        dict: Keys ``root``, ``config_file``, ``checkpoint_file``,
        ``testdata``, ``out_dir``, ``add_zero_label`` and ``cfg_options``.
    """
    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument("root", type=str,
                        help="the base dir")
    parser.add_argument("config_file", type=str,
                        help="config file path")
    parser.add_argument("checkpoint_file", type=str,
                        help="checkpoint file path")
    parser.add_argument("testdata", type=str,
                        help="image/video(s) path or dir")
    # Required: without it out_dir would be None and the run would only fail
    # later, when the output path is built.
    parser.add_argument("-o", dest="out_dir", type=str, required=True,
                        help="save results")
    parser.add_argument("-y", dest="add_zero_label", action="store_true",
                        help="add zero label as background")
    parser.add_argument("-e", dest="cfg_options", type=str, default=None,
                        help="to override some settings, string of a python dict")
    return vars(parser.parse_args(args=args))
[docs]
def main(args=None):
    """Command-line entry point: parse arguments, then run ``func``.

    Args:
        args (list[str], optional): Argument strings forwarded to
            ``parse_args``.

    Returns:
        int: Always 0.
    """
    kwargs = parse_args(args)
    print(f"{__file__}: {kwargs}")

    cfg_options = kwargs["cfg_options"]
    # NOTE(security): eval() executes arbitrary Python taken from the command
    # line; only pass strings you trust.
    kwargs["cfg_options"] = None if cfg_options is None else eval(cfg_options)

    print(func(**kwargs))
    return 0
# Script entry point: exit with the status code returned by main().
if __name__ == "__main__":
    raise SystemExit(main())