# Source code for hello.fiftyone.miou: segmentation metrics (precision, recall, IoU) computed from label PNG directories.

import re
import sys
from pathlib import Path

import cv2 as cv
import numpy as np
import torch
from prettytable import PrettyTable


class ConfusionMatrix:
    """Accumulate a confusion matrix for semantic segmentation metrics.

    Rows index ground-truth classes and columns index predicted classes, so
    ``mat[i, j]`` counts pixels labelled ``i`` and predicted as ``j``.

    Args:
        class_names (list[str]): the list of class label strings; its length
            fixes the number of classes ``n``.
        reduce_zero_label (bool, optional): if True, ground-truth label 0 is
            ignored and labels 1..254 are shifted down by one before counting
            (255 stays ignored). Defaults to True.
    """

    def __init__(self, class_names, reduce_zero_label=True):
        self.class_names = class_names
        self.num_classes = len(class_names)
        self.reduce_zero_label = reduce_zero_label
        # (n, n) int64 counts; created lazily on the first update() so it
        # lives on the same device as the incoming labels.
        self.mat = None

    def update(self, target, output):
        """Add one (ground truth, prediction) pair to the accumulated counts.

        Only pixels whose (possibly reduced) label AND prediction both lie in
        ``[0, n)`` are counted; everything else (e.g. 255 "ignore") is dropped.
        The caller's arrays/tensors are never modified.

        Args:
            target: ground-truth label map (numpy array or torch tensor).
            output: predicted label map with the same shape as ``target``.

        Raises:
            ValueError: if ``target`` and ``output`` have different shapes.
        """
        # as_tensor shares memory with numpy inputs, so nothing below may
        # write into these tensors in place.
        target = torch.as_tensor(target)
        output = torch.as_tensor(output, device=target.device)
        if target.shape != output.shape:
            raise ValueError(
                f"shape mismatch: target {tuple(target.shape)} "
                f"vs output {tuple(output.shape)}"
            )
        # int64 avoids uint8 wrap-around and is what bincount indexing needs.
        target = target.to(torch.int64)
        output = output.to(torch.int64)
        if self.reduce_zero_label:
            # `target - 1` allocates a fresh tensor, so the masked write below
            # cannot touch the caller's data. 0 -> -1 and 255 -> 254 both map
            # to the ignore value 255; labels 1..254 become 0..253.
            target = target - 1
            target[(target == -1) | (target == 254)] = 255
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=target.device)
        with torch.inference_mode():
            # Both coordinates must be in range: an out-of-range prediction
            # would otherwise alias onto another cell of the flattened n*n
            # histogram (e.g. (0, n) lands on (1, 0)).
            valid = (target >= 0) & (target < n) & (output >= 0) & (output < n)
            inds = n * target[valid] + output[valid]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

    def reset(self):
        """Zero the accumulated counts (no-op if nothing was accumulated yet)."""
        if self.mat is not None:
            self.mat.zero_()

    def _matrix(self):
        """Return the accumulated counts, or an all-zero matrix before any update."""
        if self.mat is None:
            n = self.num_classes
            return torch.zeros((n, n), dtype=torch.int64)
        return self.mat

    @property
    def confusion_matrix(self):
        """Table of each class's most frequent misclassifications.

        Each row is normalized by its ground-truth pixel count, the diagonal
        (correct predictions) is zeroed, and the largest remaining ratios
        (top-2, or fewer if there are fewer classes) are reported together
        with the predicted class names. Rows without ground truth show NaN.
        """
        mat = self._matrix().float()
        mat = mat / mat.sum(dim=1, keepdim=True)
        mat = mat - torch.diag_embed(mat.diag())
        n = self.num_classes
        # topk requires k <= number of columns.
        top_values, top_indices = mat.topk(min(2, n), dim=1)
        table_data = PrettyTable()
        table_data.add_column("class_name", self.class_names)
        table_data.add_column("index", [f"{v:02d}" for v in range(n)])
        for i in range(top_values.size(1)):
            table_data.add_column(
                f"top{i+1}_ratio",
                [f"{v:.2%}" for v in top_values[:, i].tolist()],
                align="r",
            )
            table_data.add_column(
                f"top{i+1}_class",
                [self.class_names[j] for j in top_indices[:, i].tolist()],
                align="r",
            )
        return table_data

    @property
    def metrics(self):
        """Per-class precision / recall / IoU table (percent) with averages.

        Precision divides by the predicted count (column sum), recall by the
        ground-truth count (row sum). A zero denominator yields NaN; NaNs are
        ignored by the macro average (nanmean) and weighted average (nansum).
        """
        h = self._matrix().float()
        tp = torch.diag(h)
        pred_total = h.sum(0)
        support = h.sum(1)  # ground-truth pixels per class
        precision = tp / pred_total * 100
        recall = tp / support * 100
        iou = tp / (pred_total + support - tp) * 100
        ratio = support / support.sum()
        n = self.num_classes
        table_data = PrettyTable()
        table_data.add_column("class_name", self.class_names)
        table_data.add_column("index", [f"{v:02d}" for v in range(n)])
        table_data.add_column("precision", [f"{v:.2f}" for v in precision.tolist()], align="r")
        table_data.add_column("recall", [f"{v:.2f}" for v in recall.tolist()], align="r")
        table_data.add_column("iou", [f"{v:.2f}" for v in iou.tolist()], align="r")
        table_data.add_column("support", [int(v) for v in support.tolist()], align="r")
        table_data.add_column("support_ratio", [f"{v:.2%}" for v in ratio.tolist()], align="r")
        scores = [precision, recall, iou]
        table_data.add_row(
            ["macro avg", "-"] + [f"{v.nanmean().item():.2f}" for v in scores] + ["-", "-"]
        )
        table_data.add_row(
            ["weighted avg", "-"]
            + [f"{(v * ratio).nansum().item():.2f}" for v in scores]
            + ["-", "-"]
        )
        return table_data
def _resolve_class_names(true_dir, num_classes, class_names):
    """Return the class names to report, taking the first available source.

    Order: explicit ``class_names``; else ``info["classes"][:num_classes]``
    parsed from an ``info.py`` file beside ``true_dir``; else generated names
    ``c000``, ``c001``, ... for ``num_classes`` classes.
    """
    if class_names is not None:
        return class_names
    info_py = Path(true_dir).with_name("info.py")
    if info_py.is_file():
        codestr = info_py.read_text(encoding="utf-8")
        # NOTE(security): eval() executes arbitrary code taken from info.py.
        # Only run this tool on dataset directories you trust.
        info = eval(re.split(r"info\s*=\s*", codestr)[1])
        return info["classes"][:num_classes]
    return [f"c{i:03d}" for i in range(num_classes)]


def func(true_dir, pred_dir, num_classes, class_names, reduce_zero_label=True):
    """Compute segmentation metrics between two directories of label PNGs.

    Files are paired by stem (``name.png`` in both dirs); unpaired files are
    ignored. The per-class metric table and the confusion table are printed
    to stdout; nothing is returned.

    Args:
        true_dir (str): directory of ground-truth ``*.png`` label maps.
        pred_dir (str): directory of predicted ``*.png`` label maps.
        num_classes (int | None): number of classes to evaluate.
        class_names (list[str] | None): class names; when None they are read
            from ``info.py`` beside ``true_dir`` or generated from ``num_classes``.
        reduce_zero_label (bool): forwarded to ``ConfusionMatrix``.

    Raises:
        ValueError: if both ``num_classes`` and ``class_names`` are None.
    """
    if num_classes is None and class_names is None:
        raise ValueError("either num_classes or class_names must be given")

    true_files = {f.stem: str(f) for f in Path(true_dir).glob("*.png")}
    pred_files = {f.stem: str(f) for f in Path(pred_dir).glob("*.png")}
    common_stems = sorted(true_files.keys() & pred_files.keys())
    print(f"[INFO] number of true files: {len(true_files)}")
    print(f"[INFO] number of pred files: {len(pred_files)}")
    print(f"[INFO] number of common files: {len(common_stems)}")
    if not common_stems:
        print("[WARN] no common files found; nothing to evaluate")
        return None

    class_names = _resolve_class_names(true_dir, num_classes, class_names)
    confmat = ConfusionMatrix(class_names, reduce_zero_label=reduce_zero_label)
    evaluated = 0
    for stem in common_stems:
        # NOTE(review): IMREAD_GRAYSCALE (== 0) converts color/palette PNGs to
        # luminance rather than keeping palette indices; confirm label maps
        # are stored as plain single-channel images.
        target = cv.imread(true_files[stem], cv.IMREAD_GRAYSCALE)
        output = cv.imread(pred_files[stem], cv.IMREAD_GRAYSCALE)
        if target is None or output is None:
            # cv.imread signals unreadable files by returning None.
            print(f"[WARN] could not read {stem}.png; skipped")
            continue
        confmat.update(target, output)
        evaluated += 1

    if evaluated == 0:
        print("[WARN] no readable image pairs; nothing to evaluate")
        return None
    print(confmat.metrics.get_string())
    print(confmat.confusion_matrix.get_string())
    return None
def parse_args(args=None):
    """Parse the command-line arguments of the metric script.

    Args:
        args (list[str], optional): argument list; None means ``sys.argv[1:]``.

    Returns:
        dict: keyword arguments named after ``func``'s parameters
        (true_dir, pred_dir, num_classes, class_names, reduce_zero_label).
    """
    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument("true_dir", type=str, help="the ground truth dir")
    parser.add_argument("pred_dir", type=str, help="the prediction dir")
    parser.add_argument(
        "-n",
        dest="num_classes",
        type=int,
        help="number of classes; only labels in [0, num_classes) are evaluated",
    )
    parser.add_argument(
        "-c",
        dest="class_names",
        type=str,
        nargs="+",
        help="class names; if omitted, read from info.py next to true_dir, "
        "otherwise generated as c000, c001, ...",
    )
    parser.add_argument(
        "-y",
        dest="reduce_zero_label",
        action="store_true",
        help="treat ground-truth label 0 as ignored and shift the other labels down by one",
    )
    return vars(parser.parse_args(args=args))
def main(args=None):
    """Command-line entry point: parse ``args`` and print the metric tables.

    Returns:
        int: process exit status (0 on success).
    """
    kwargs = parse_args(args)
    print(f"{__file__}: {kwargs}")
    # func() prints its tables itself and returns None; printing its return
    # value would only emit a stray "None" line.
    func(**kwargs)
    return 0
# Script entry: propagate main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())