91 lines
2.3 KiB
Python
91 lines
2.3 KiB
Python
import argparse
|
|
|
|
import torch
|
|
import cv2
|
|
import numpy as np
|
|
import os
|
|
|
|
from model import Generator
|
|
|
|
torch.backends.cudnn.enabled = False
|
|
torch.backends.cudnn.benchmark = False
|
|
torch.backends.cudnn.deterministic = True
|
|
|
|
def load_image(image_path, x32=False):
|
|
img = cv2.imread(image_path).astype(np.float32)
|
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
h, w = img.shape[:2]
|
|
|
|
if x32: # resize image to multiple of 32s
|
|
def to_32s(x):
|
|
return 256 if x < 256 else x - x%32
|
|
img = cv2.resize(img, (to_32s(w), to_32s(h)))
|
|
|
|
img = torch.from_numpy(img)
|
|
img = img/127.5 - 1.0
|
|
return img
|
|
|
|
|
|
def test(args):
|
|
device = args.device
|
|
|
|
net = Generator()
|
|
net.load_state_dict(torch.load(args.checkpoint, map_location="cpu"))
|
|
net.to(device).eval()
|
|
print(f"model loaded: {args.checkpoint}")
|
|
|
|
os.makedirs(args.output_dir, exist_ok=True)
|
|
|
|
for image_name in sorted(os.listdir(args.input_dir)):
|
|
if os.path.splitext(image_name)[-1].lower() not in [".jpg", ".png", ".bmp", ".tiff"]:
|
|
continue
|
|
|
|
image = load_image(os.path.join(args.input_dir, image_name), args.x32)
|
|
|
|
with torch.no_grad():
|
|
input = image.permute(2, 0, 1).unsqueeze(0).to(device)
|
|
out = net(input, args.upsample_align).squeeze(0).permute(1, 2, 0).cpu().numpy()
|
|
out = (out + 1)*127.5
|
|
out = np.clip(out, 0, 255).astype(np.uint8)
|
|
|
|
cv2.imwrite(os.path.join(args.output_dir, image_name), cv2.cvtColor(out, cv2.COLOR_BGR2RGB))
|
|
print(f"image saved: {image_name}")
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument(
|
|
'--checkpoint',
|
|
type=str,
|
|
default='./pytorch_generator_Paprika.pt',
|
|
)
|
|
parser.add_argument(
|
|
'--input_dir',
|
|
type=str,
|
|
default='./samples/inputs',
|
|
)
|
|
parser.add_argument(
|
|
'--output_dir',
|
|
type=str,
|
|
default='./samples/results',
|
|
)
|
|
parser.add_argument(
|
|
'--device',
|
|
type=str,
|
|
default='cuda:0',
|
|
)
|
|
parser.add_argument(
|
|
'--upsample_align',
|
|
type=bool,
|
|
default=False,
|
|
)
|
|
parser.add_argument(
|
|
'--x32',
|
|
action="store_true",
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
test(args)
|
|
|