import json
import shutil
import sys
from collections import defaultdict
from pathlib import Path
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
[docs]
def simplify(data):
    """Collapse every list of logged values in ``data`` to a single number.

    The ``"iter"`` and ``"epoch"`` entries are reduced with ``np.max``; every
    other entry is reduced with ``np.mean``.

    Args:
        data: Mapping from a log key to the values collected for that key.

    Returns:
        A new dict with the same keys, in the same order, each mapped to its
        reduced value.
    """
    counters = ("iter", "epoch")
    return {
        key: (np.max(values) if key in counters else np.mean(values))
        for key, values in data.items()
    }
[docs]
def load_json_log(
    json_log,
    schedules=("iter", "lr", "loss_cls", "loss_bbox"),
    metrics=("bbox_mAP", "bbox_mAP_50"),
):
    """Parse a line-delimited JSON training log into per-validation series.

    Blank lines and lines starting with ``#`` are skipped. Every remaining
    line is decoded as a JSON object whose ``"mode"`` key selects how it is
    handled:

    * ``"train"`` rows: values whose key is in ``schedules`` (or is ``"iter"``
      / ``"epoch"``) are buffered.
    * ``"val"`` rows: values whose key is in ``metrics`` are appended to the
      result, followed by one summary value per buffered key (computed by
      ``simplify``); the buffer is then reset.
    * any other mode is ignored.

    Args:
        json_log: Path of the log file to read.
        schedules: Keys to collect from training rows.
        metrics: Keys to collect from validation rows.

    Returns:
        A ``defaultdict(list)`` mapping each collected key to one value per
        validation row. It is empty when the log has no validation rows.
    """
    with open(json_log, "r") as f:
        rows = [line.strip() for line in f.read().splitlines()]
    rows = [line for line in rows if line and not line.startswith("#")]

    x_labels = {"iter", "epoch"}
    log_dict = defaultdict(list)
    cache = defaultdict(list)
    for row in rows:
        log = json.loads(row)
        mode = log.pop("mode", None)
        if mode == "train":
            for k, v in log.items():
                if k in schedules or k in x_labels:
                    cache[k].append(v)
        elif mode == "val":
            for k, v in log.items():
                if k in metrics:
                    log_dict[k].append(v)
            # Flush the training values seen since the previous validation row.
            for k, v in simplify(cache).items():
                log_dict[k].append(v)
            cache = defaultdict(list)

    # Use .get() so that probing does not create empty "iter"/"epoch" entries
    # in the returned defaultdict.
    iters = log_dict.get("iter", [])
    epochs = log_dict.get("epoch", [])
    smallest = sorted(iters)[:2]
    # Equal smallest values mean the iteration counter repeats between
    # validation points (presumably it restarts every epoch), so scale it by
    # the epoch to get an increasing x axis. This needs at least two points
    # and an epoch value for every point; otherwise "iter" is left untouched
    # instead of crashing or being truncated by zip().
    if (
        len(smallest) == 2
        and smallest[0] == smallest[1]
        and len(epochs) == len(iters)
    ):
        log_dict["iter"] = [a * b for a, b in zip(iters, epochs)]
    return log_dict
[docs]
def plotting_log_dicts(log_dicts, out_dir, schedules, metrics, format):
    """Merge the logs of several experiments and write all plots.

    Two families of figures are written below ``out_dir / "images"``:

    * one figure per y label (``schedules[1:]`` followed by ``metrics``),
      produced by ``plotting_metrics`` and named ``<y_label><format>``;
    * one figure per experiment, produced by ``plotting_schedules`` and
      named ``<exp_name><format>``.

    Args:
        log_dicts: Iterable of ``(exp_name, log_dict)`` pairs, where
            ``log_dict`` maps a column name to a list of values.
        out_dir: Output directory (``str`` or ``Path``). This function does
            not create the ``images`` sub-directory.
        schedules: Sequence whose first item is the x-axis column and whose
            remaining items are plotted as y columns.
        metrics: Additional y columns to plot.
        format: File suffix appended to every output name (e.g. ``".html"``).

    Raises:
        AssertionError: If the merged table lacks an ``"iter"`` or
            ``"epoch"`` column.
    """
    cache = defaultdict(list)
    for exp_name, log_dict in log_dicts:
        if not log_dict:
            # Nothing was logged for this experiment; it contributes no rows.
            continue
        for k, v in log_dict.items():
            cache[k].extend(v)
        # Label each row of this experiment exactly once. The row count is
        # taken from the experiment's own columns and not from a loop
        # variable left over from the iteration above.
        n_rows = max(len(v) for v in log_dict.values())
        cache["exp_name"].extend([exp_name] * n_rows)
    cache = pd.DataFrame(cache)
    columns = set(cache.columns)
    assert "iter" in columns and "epoch" in columns, (
        "merged logs must contain 'iter' and 'epoch' columns"
    )

    out_dir = Path(out_dir)
    x_label, schedules = schedules[0], list(schedules[1:])
    y_labels = schedules + list(metrics)
    for y_label in y_labels:
        out_file = str(out_dir / f"images/{y_label}{format}")
        plotting_metrics(cache, x_label, y_label, out_file)
    for exp_name, data in cache.groupby(by="exp_name"):
        out_file = str(out_dir / f"images/{exp_name}{format}")
        plotting_schedules(data, x_label, y_labels, out_file)
[docs]
def plotting_metrics(cache, x_label, y_label, out_file):
    """Plot one column for every experiment in a single figure.

    ``cache`` is grouped by its ``"exp_name"`` column and each group becomes
    one line trace of ``y_label`` against ``x_label``. The legend entry of a
    trace is the part of the experiment name before its first underscore.

    Args:
        cache: Table holding ``x_label``, ``y_label`` and ``"exp_name"``.
        x_label: Column used for the x axis.
        y_label: Column used for the y axis.
        out_file: Output path; a ``.html`` suffix writes an HTML page, any
            other suffix goes through ``write_image``.

    Returns:
        The figure that was written.

    Raises:
        IndexError: If ``cache`` contains no experiment groups.
    """
    fig = go.Figure()
    seen = []
    for run_name, frame in cache.groupby(by="exp_name"):
        legend_name = run_name.split("_", maxsplit=1)[0]
        curve = go.Scatter(
            x=frame[x_label],
            y=frame[y_label],
            showlegend=True,
            mode="lines",
            name=legend_name,
        )
        fig.add_trace(curve)
        seen.append(run_name)

    fig.update_yaxes(title_text=y_label)
    fig.update_xaxes(title_text=x_label)

    # The title names only the first experiment, followed by an ellipsis.
    heading = "<br>".join(
        ("Analyze MMDetection Training Json Log", f"{seen[0]},...")
    )
    fig.update_layout(title_text=heading)

    is_html = Path(out_file).suffix == ".html"
    save = fig.write_html if is_html else fig.write_image
    save(out_file)
    return fig
[docs]
def plotting_schedules(cache, x_label, y_labels, out_file):
    """Plot several columns of one table as vertically stacked subplots.

    One subplot row is created per entry of ``y_labels``; all rows share the
    x axis, whose title is set on the bottom row only.

    Args:
        cache: Table holding ``x_label`` and every column in ``y_labels``.
        x_label: Column used for the shared x axis.
        y_labels: Columns to plot, one subplot each, top to bottom.
        out_file: Output path; a ``.html`` suffix writes an HTML page, any
            other suffix goes through ``write_image``.

    Returns:
        The figure that was written.
    """
    row_count = len(y_labels)
    fig = make_subplots(
        rows=row_count,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.03,
        specs=[[{"type": "scatter"}] for _ in range(row_count)],
    )

    for row_idx, label in enumerate(y_labels, start=1):
        curve = go.Scatter(
            x=cache[x_label],
            y=cache[label],
            showlegend=False,
            mode="lines",
            name=label,
        )
        fig.add_trace(curve, row=row_idx, col=1)
        fig.update_yaxes(title_text=label, row=row_idx, col=1)
    fig.update_xaxes(title_text=x_label, row=row_count, col=1)

    heading = "<br>".join(
        ("Analyze MMDetection Training Json Log", str(y_labels))
    )
    # 300 units of height per subplot row.
    fig.update_layout(height=300 * row_count, title_text=heading)

    if Path(out_file).suffix != ".html":
        fig.write_image(out_file)
    else:
        fig.write_html(out_file)
    return fig
[docs]
def func(json_logs, out_dir, schedules, metrics, format):
    """Load the given logs and write every plot into a fresh ``out_dir``.

    WARNING: ``out_dir`` is removed recursively (errors ignored) before it is
    recreated with an ``images`` sub-directory, so anything already stored
    there is lost.

    If the first entry of ``json_logs`` is a directory, the logs are found
    by globbing ``*/*.log.json`` below it and the remaining entries are
    ignored. Each log is labelled with the name of its parent directory.

    Args:
        json_logs: Log file paths, or a single runs directory as first item.
        out_dir: Directory that receives the plots.
        schedules: Passed to ``load_json_log`` and ``plotting_log_dicts``.
        metrics: Passed to ``load_json_log`` and ``plotting_log_dicts``.
        format: Output file suffix, passed to ``plotting_log_dicts``.

    Returns:
        ``out_dir`` as a string.
    """
    target = Path(out_dir)
    shutil.rmtree(target, ignore_errors=True)
    images_dir = target / "images"
    images_dir.mkdir(parents=True, exist_ok=False)

    first_entry = Path(json_logs[0])
    if first_entry.is_dir():
        json_logs = [str(found) for found in first_entry.glob("*/*.log.json")]

    runs = []
    for log_path in json_logs:
        parsed = load_json_log(log_path, schedules, metrics)
        run_name = Path(log_path).parent.name
        runs.append([run_name, parsed])

    plotting_log_dicts(runs, target, schedules, metrics, format)
    return str(target)
[docs]
def parse_args(args=None):
    """Parse the command-line options of this script.

    Args:
        args: Argument strings to parse; ``None`` lets ``argparse`` read
            ``sys.argv``.

    Returns:
        A dict with the keys ``json_logs``, ``out_dir``, ``schedules``,
        ``metrics`` and ``format``. ``out_dir`` is ``None`` when ``-o`` is
        not given, because the option has no default.
    """
    import argparse

    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # (flags, keyword settings) for every argument; all of them are strings.
    # The table is rebuilt on each call so the default lists are never shared.
    option_table = (
        (
            ("json_logs",),
            dict(
                nargs="+",
                help="path of train log in json format or runs dir",
            ),
        ),
        (
            ("-o",),
            dict(dest="out_dir", help="save plotting curves to the dir"),
        ),
        (
            ("-s",),
            dict(
                dest="schedules",
                nargs="+",
                default=["iter", "lr", "loss_cls", "loss_bbox"],
                help="the schedule that you want to plot",
            ),
        ),
        (
            ("-m",),
            dict(
                dest="metrics",
                nargs="+",
                default=["bbox_mAP", "bbox_mAP_50"],
                help="the metric that you want to plot",
            ),
        ),
        (
            ("-f",),
            dict(
                dest="format",
                default=".html",
                choices=[".png", ".svg", ".pdf", ".html"],
                help="image save format",
            ),
        ),
    )
    for flags, settings in option_table:
        cli.add_argument(*flags, type=str, **settings)

    namespace = cli.parse_args(args)
    return vars(namespace)
[docs]
def main(args=None):
    """Command-line entry point: parse the options, run ``func``, report.

    Prints this file's path together with the parsed options, then prints
    the value returned by ``func``.

    Args:
        args: Argument strings forwarded to ``parse_args``.

    Returns:
        Always ``0``.
    """
    options = parse_args(args)
    print(f"{__file__}: {options}")
    result_dir = func(**options)
    print(result_dir)
    return 0
# Run the command-line interface only when executed as a script; the value
# returned by main() becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())