import json
import re
import shutil
import sys
import time
from pathlib import Path
from string import Template
import hello.fiftyone.dataset as hod
# Markdown skeleton for the README that `func` writes into its output directory.
# Placeholders (filled with Template.safe_substitute):
#   $date              -- timestamp of the run
#   $aggregate_metrics -- newline-joined lines, one per entry of results.metrics()
#   $report            -- json.dumps(results.report(), indent=4)
tmpl_readme = Template(
    "# README\n"
    "- `$date`\n"
    "---\n"
    "[TOC]\n"
    "## Metrics\n"
    "$aggregate_metrics\n"
    "```json\n"
    "$report\n"
    "```\n"
)
def make_dataset(dataset_dir, info_py="info.py", data_path="data", preds_path="predictions/", labels_path="labels/"):
    """Build the segmentation dataset for *dataset_dir* with images, predictions and ground truth.

    ``info_py`` (joined onto *dataset_dir*) must contain a top-level line starting with
    ``info =``; the expression assigned there is evaluated. Its ``"classes"`` and
    ``"mask_targets"`` entries (defaulting to ``[]`` / ``{}``) configure the dataset.

    The dataset is named after *dataset_dir*'s final path component.
    ``hod.delete_datasets`` is called on that name before ``hod.create_dataset``.
    Images are added from *data_path*. PNG masks from *preds_path* and *labels_path*
    are loaded into the ``"predictions"`` and ``"ground_truth"`` fields respectively.

    Args:
        dataset_dir: dataset root directory; ``None``/empty means the current directory.
        info_py: info file, relative to *dataset_dir*.
        data_path: image directory, relative to *dataset_dir*.
        preds_path: predicted masks, relative to *dataset_dir*.
        labels_path: ground-truth masks, relative to *dataset_dir*.

    Returns:
        The object returned by ``hod.create_dataset``, with ``info["version"]`` set.

    Raises:
        ValueError: if *info_py* has no top-level ``info = ...`` line.
    """
    dataset_dir = Path(dataset_dir or ".")
    info_file = dataset_dir / info_py
    # utf-8-sig tolerates a leading BOM, which would otherwise hide the first `info =` line.
    with open(info_file, "r", encoding="utf-8-sig") as f:
        codestr = f.read()

    # Anchor at line start so commented-out "# info = ..." lines or names like
    # "sub_info =" are not mistaken for the assignment.
    parts = re.split(r"^info\s*=\s*", codestr, flags=re.MULTILINE)
    if len(parts) < 2:
        raise ValueError(f"no top-level `info = ...` assignment found in {info_file}")
    # SECURITY(review): eval() executes arbitrary code from info.py -- only use this on
    # trusted dataset directories. parts[1] is the text after the first `info =` line,
    # up to the next top-level `info =` (if any).
    info = eval(parts[1])

    dataset_name = dataset_dir.name
    if dataset_name in ("", ".", ".."):
        # Path(".").name is "" and Path("..").name is ".."; use the real directory name.
        dataset_name = dataset_dir.resolve().name
    dataset_type = "segmentation"
    version = "001"

    classes = info.get("classes", [])
    mask_targets = info.get("mask_targets", {})

    hod.delete_datasets([dataset_name])
    dataset = hod.create_dataset(dataset_name, dataset_type, classes, mask_targets)
    dataset.info["version"] = version

    hod.add_images_dir(dataset, dataset_dir / data_path, None)
    hod.add_segmentation_labels(dataset, "predictions", dataset_dir / preds_path, mask_targets, mode="png")
    hod.add_segmentation_labels(dataset, "ground_truth", dataset_dir / labels_path, mask_targets, mode="png")
    return dataset
def save_plot(plot, html_file):
    """Export a plot object to an HTML file.

    If *plot* has a ``_widget`` attribute, that wrapped object is exported instead.
    The first writer method found, in the order ``write_html`` then ``save``, is
    called with *html_file*. If neither exists, nothing is written.
    """
    target = plot._widget if hasattr(plot, "_widget") else plot
    for writer_name in ("write_html", "save"):
        if hasattr(target, writer_name):
            getattr(target, writer_name)(html_file)
            break
def _format_kv(key, value):
    """Render one aggregate metric as a Markdown list item for the README."""
    return f"- **{key}:** {value}"


def _check_output_dir(output_dir, dataset_dir):
    """Raise ValueError if wiping *output_dir* would also delete *dataset_dir*."""
    out = output_dir.resolve()
    src = dataset_dir.resolve()
    if out == src or out in src.parents:
        raise ValueError(
            f"refusing to use output_dir={str(output_dir)!r}: "
            f"removing it would delete dataset_dir={str(dataset_dir)!r}"
        )


def _export_results(results, output_dir):
    """Recreate *output_dir* and write README.md plus plot_confusion_matrix.html into it."""
    # Existing contents are discarded. If removal fails, mkdir(exist_ok=False) raises
    # FileExistsError instead of silently mixing old and new files.
    shutil.rmtree(output_dir, ignore_errors=True)
    output_dir.mkdir(parents=True, exist_ok=False)

    tmpl_mapping = {
        "date": time.strftime(r"%Y-%m-%d %H:%M"),
        "aggregate_metrics": "\n".join(_format_kv(k, v) for k, v in results.metrics().items()),
        "report": json.dumps(results.report(), indent=4),
    }
    readme_str = tmpl_readme.safe_substitute(tmpl_mapping)
    with open(output_dir / "README.md", "w", encoding="utf-8") as f:
        f.write(readme_str)

    html_file = str(output_dir / "plot_confusion_matrix.html")
    plot = results.plot_confusion_matrix()
    save_plot(plot, html_file)


def func(dataset_dir, info_py="info.py", data_path="data", preds_path="predictions/", labels_path="labels/", output_dir=None, **kwargs):
    """Evaluate predicted segmentation masks against ground truth for one dataset directory.

    Builds the dataset with ``make_dataset`` and calls ``evaluate_segmentations`` on the
    ``"predictions"`` field. The defaults are:

    - ``gt_field="ground_truth"``
    - ``eval_key="eval"``
    - ``mask_targets=dataset.default_mask_targets``
    - ``method="simple"``
    - ``bandwidth=None``
    - ``average="micro"``

    Any of these can be overridden through **kwargs. The report is printed.

    If *output_dir* is given, it is deleted and recreated. Then ``README.md`` (aggregate
    metrics plus the JSON report) and ``plot_confusion_matrix.html`` are written into it.

    Returns:
        The marker string ``"\\n[END]"``.

    Raises:
        ValueError: if *output_dir* equals *dataset_dir* or is one of its ancestors,
            because wiping it would destroy the input data.
    """
    if output_dir is not None:
        output_dir = Path(output_dir)
        # Validate before any work: exporting does rmtree(output_dir).
        _check_output_dir(output_dir, Path(dataset_dir or "."))

    dataset = make_dataset(dataset_dir, info_py, data_path, preds_path, labels_path)
    params = dict(
        gt_field="ground_truth",
        eval_key="eval",
        mask_targets=dataset.default_mask_targets,
        method="simple",
        bandwidth=None,
        average="micro",
    )
    params.update(kwargs)
    results = dataset.evaluate_segmentations("predictions", **params)
    results.print_report()

    if output_dir is not None:
        _export_results(results, output_dir)
    return "\n[END]"
def parse_args(args=None):
    """Parse command-line options for the segmentation evaluation.

    Args:
        args: list of argument strings; ``None`` means ``sys.argv[1:]`` (argparse default).

    Returns:
        dict keyed by the option destinations (``dataset_dir``, ``info_py``, ``data_path``,
        ``preds_path``, ``labels_path``, ``output_dir``, ``bandwidth``, ``average``).
        It can be passed straight to ``func(**...)``.
    """
    from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument("dataset_dir", type=str,
                        help="dataset root directory; the paths below are joined onto it")
    parser.add_argument("--info", dest="info_py", type=str, default="info.py",
                        help="Python file containing a top-level `info = {...}` with classes and mask_targets")
    parser.add_argument("--data", dest="data_path", type=str, default="data",
                        help="directory containing the images")
    parser.add_argument("--preds", dest="preds_path", type=str, default="predictions/",
                        help="predicted PNG masks (file or directory)")
    parser.add_argument("--labels", dest="labels_path", type=str, default="labels/",
                        help="ground-truth PNG masks (file or directory)")
    parser.add_argument("--out", dest="output_dir", type=str, default=None,
                        help="directory to write README.md and the confusion-matrix plot to (wiped first)")
    parser.add_argument("--bandwidth", dest="bandwidth", type=int, default=None,
                        help="if set, evaluate only within this many pixels of the ground-truth mask contours")
    parser.add_argument("--average", dest="average", type=str, default="micro",
                        choices=["micro", "macro", "weighted", "samples"],
                        help="averaging strategy for precision/recall/F1, see "
                             "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html")
    parsed = parser.parse_args(args=args)
    return vars(parsed)
def main(args=None):
    """Command-line entry point.

    Parses *args* (``None`` means ``sys.argv[1:]``) and echoes the resulting options.
    Then it runs ``func`` with them, prints its return value, and returns exit status 0.
    """
    options = parse_args(args)
    print(f"{__file__}: {options}")
    outcome = func(**options)
    print(outcome)
    return 0
# Script entry: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())