72 lines
2.2 KiB
Python
72 lines
2.2 KiB
Python
import hashlib
|
|
import os
|
|
import sys
|
|
from contextlib import redirect_stdout
|
|
from pathlib import Path
|
|
from typing import Type
|
|
|
|
import onnxruntime as ort
|
|
import pooch
|
|
|
|
from .session_base import BaseSession
|
|
from .session_cloth import ClothSession
|
|
from .session_simple import SimpleSession
|
|
|
|
|
|
def new_session(model_name: str = "u2net") -> BaseSession:
|
|
session_class: Type[BaseSession]
|
|
md5 = "60024c5c889badc19c04ad937298a77b"
|
|
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx"
|
|
session_class = SimpleSession
|
|
|
|
if model_name == "u2netp":
|
|
md5 = "8e83ca70e441ab06c318d82300c84806"
|
|
url = (
|
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx"
|
|
)
|
|
session_class = SimpleSession
|
|
elif model_name == "u2net_human_seg":
|
|
md5 = "c09ddc2e0104f800e3e1bb4652583d1f"
|
|
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx"
|
|
session_class = SimpleSession
|
|
elif model_name == "u2net_cloth_seg":
|
|
md5 = "2434d1f3cb744e0e49386c906e5a08bb"
|
|
url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx"
|
|
session_class = ClothSession
|
|
elif model_name == "silueta":
|
|
md5 = "55e59e0d8062d2f5d013f4725ee84782"
|
|
url = (
|
|
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx"
|
|
)
|
|
session_class = SimpleSession
|
|
|
|
u2net_home = os.getenv(
|
|
"U2NET_HOME", os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
|
|
)
|
|
|
|
fname = f"{model_name}.onnx"
|
|
path = Path(u2net_home).expanduser()
|
|
full_path = Path(u2net_home).expanduser() / fname
|
|
|
|
pooch.retrieve(
|
|
url,
|
|
f"md5:{md5}",
|
|
fname=fname,
|
|
path=Path(u2net_home).expanduser(),
|
|
progressbar=True,
|
|
)
|
|
|
|
sess_opts = ort.SessionOptions()
|
|
|
|
if "OMP_NUM_THREADS" in os.environ:
|
|
sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
|
|
|
|
return session_class(
|
|
model_name,
|
|
ort.InferenceSession(
|
|
str(full_path),
|
|
providers=ort.get_available_providers(),
|
|
sess_options=sess_opts,
|
|
),
|
|
)
|