41 lines
1.2 KiB
Python
41 lines
1.2 KiB
Python
from typing import Dict, List, Tuple
|
|
|
|
import numpy as np
|
|
import onnxruntime as ort
|
|
from PIL import Image
|
|
from PIL.Image import Image as PILImage
|
|
|
|
|
|
class BaseSession:
    """Pair a model name with the ONNX Runtime session that executes it.

    The base class supplies the shared image-to-tensor preprocessing in
    :meth:`normalize`. :meth:`predict` raises ``NotImplementedError`` here
    and is meant to be overridden.
    """

    def __init__(self, model_name: str, inner_session: ort.InferenceSession):
        """Store the model name and the inference session.

        Args:
            model_name: Identifier of the model this session wraps.
            inner_session: Session whose first input receives the tensor
                built by :meth:`normalize`.
        """
        self.model_name = model_name
        self.inner_session = inner_session

    def normalize(
        self,
        img: PILImage,
        mean: Tuple[float, float, float],
        std: Tuple[float, float, float],
        size: Tuple[int, int],
    ) -> Dict[str, np.ndarray]:
        """Convert an image into the input feed for the inner session.

        The image is converted to RGB, resized to ``size`` with Lanczos
        resampling, scaled by its own maximum pixel value, standardised per
        channel with ``mean`` and ``std``, and rearranged from
        height-width-channel to channel-first order with a leading batch
        axis of length 1.

        Args:
            img: Source image.
            mean: Per-channel value subtracted after scaling (R, G, B).
            std: Per-channel divisor applied after the mean is subtracted.
            size: Target size handed to ``PIL.Image.Image.resize``.

        Returns:
            A single-entry dict mapping the name of the session's first
            input to a ``float32`` array.
        """
        im = img.convert("RGB").resize(size, Image.LANCZOS)

        im_ary = np.array(im)
        # An all-black image has a maximum of 0; dividing by it would fill the
        # tensor with NaN. Clamp the divisor so such images scale to zeros.
        im_ary = im_ary / max(np.max(im_ary), 1e-6)

        # Broadcasting over the trailing channel axis standardises all three
        # channels in one expression.
        channel_mean = np.asarray(mean, dtype=np.float64)
        channel_std = np.asarray(std, dtype=np.float64)
        standardized = (im_ary - channel_mean) / channel_std

        # Move the channel axis to the front.
        standardized = standardized.transpose((2, 0, 1))

        input_name = self.inner_session.get_inputs()[0].name
        return {input_name: np.expand_dims(standardized, 0).astype(np.float32)}

    def predict(self, img: PILImage) -> List[PILImage]:
        """Run the model on ``img``; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|