Source code for hello.onnx.infer
import logging
from pathlib import Path
import onnxruntime
import torch
import torch.nn as nn
from hello.utils import importer
[docs]
def set_logging(name=None, verbose=True):
    """Configure the logger called ``name`` and return it.

    The logger level is set to ``logging.INFO`` when ``verbose`` is true and
    to ``logging.ERROR`` otherwise. A single stream handler that prints only
    the bare message (``"%(message)s"``) is attached to the logger.

    Calling this function again for the same logger reuses the handler it
    installed earlier instead of stacking another one, so records are not
    emitted multiple times; the handler level is refreshed on every call.

    Args:
        name: Logger name passed to ``logging.getLogger``; ``None`` selects
            the root logger.
        verbose: Choose INFO (true) or ERROR (false) as the threshold.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    marker = "_set_logging_handler"  # tags the handler owned by this function
    level = logging.INFO if verbose else logging.ERROR

    log = logging.getLogger(name)
    log.setLevel(level)

    # Look for a handler installed by a previous call before creating one.
    handler = next((h for h in log.handlers if getattr(h, marker, False)), None)
    if handler is None:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(message)s"))
        setattr(handler, marker, True)
        log.addHandler(handler)
    handler.setLevel(level)
    return log
# Import-time logging bootstrap.
# Calling set_logging() with its defaults (name=None, verbose=True) configures
# the *root* logger: level INFO plus a message-only stream handler.
# LOGGER itself gets no handler here; with the logging module's default
# propagation its records are emitted through the root logger's handlers.
set_logging() # run before defining LOGGER
LOGGER = logging.getLogger("onnx-inference") # define globally
[docs]
class DetectBackend(nn.Module):
    """ONNX Runtime inference wrapper with pluggable pre/post-processing.

    The model file and a companion "libprocess" Python file are both resolved
    relative to ``root``. The libprocess module is loaded dynamically and is
    used here through four callables: ``default_mask_targets()``,
    ``default_classes()``, ``pre_process(filepath, **kwargs)`` and
    ``post_process(y, **kwargs)``.
    """

    def __init__(self, root, f="best.onnx", libpath="libprocess.py"):
        """Create the inference session and load the processing module.

        Args:
            root: Base directory that ``f`` and ``libpath`` are joined onto.
            f: Model file name/path relative to ``root``.
            libpath: Processing-module file name/path relative to ``root``.
        """
        base_dir = Path(root)
        model_path = (base_dir / f).as_posix()
        process_path = (base_dir / libpath).as_posix()

        super().__init__()
        LOGGER.info(f"Loading {model_path} for ONNX Runtime inference...")

        # CPU is always requested; CUDA is put first when torch reports a GPU.
        providers = ["CPUExecutionProvider"]
        if torch.cuda.is_available():
            providers.insert(0, "CUDAExecutionProvider")

        self.session = onnxruntime.InferenceSession(model_path, providers=providers)

        # Only the first declared output and the first declared input are used.
        first_output = self.session.get_outputs()[0]
        self.output_names = [first_output.name]
        first_input = self.session.get_inputs()[0]
        self.input_name = first_input.name

        self.libprocess = importer.load_from_file("libprocess", process_path)
        self.mask_targets = self.libprocess.default_mask_targets()
        self.classes = self.libprocess.default_classes()

    def forward(self, filepath, **kwargs):
        """Run pre-process -> ONNX session -> post-process for one input.

        Args:
            filepath: Passed unchanged to ``libprocess.pre_process``.
            **kwargs: Forwarded to both ``pre_process`` and ``post_process``.

        Returns:
            Whatever ``libprocess.post_process`` returns for the first
            output of the session run.
        """
        model_input = self.libprocess.pre_process(filepath, **kwargs)
        feed = {self.input_name: model_input}
        raw_outputs = self.session.run(self.output_names, feed)
        return self.libprocess.post_process(raw_outputs[0], **kwargs)