update dockerfile and animeganv2.py

SOULOFCINDER 2023-03-28 16:24:55 +08:00
parent 0e41cf546e
commit f64a3c6ed4
40 changed files with 42 additions and 840 deletions


@@ -1,30 +1,43 @@
# please visit https://github.com/xfyun/aiges/releases to get stable and suitable images.
FROM public.ecr.aws/iflytek-open/aiges-gpu:11.6-1.17-3.9.13-ubuntu1804-v2.0.0-rc6
FROM docker.io/library/python:3.8.9
RUN apt-get install -y unzip
RUN mkdir /app
RUN mkdir /app/hub/
RUN mkdir /app/hub/checkpoints
WORKDIR /app
RUN sed -i 's@//.*archive.ubuntu.com@//mirrors.ustc.edu.cn@g' /etc/apt/sources.list
RUN sed -i 's/security.ubuntu.com/mirrors.ustc.edu.cn/g' /etc/apt/sources.list
# do this if you are on the chinese server.
RUN pip3 config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple/
RUN --mount=target=/root/packages.txt,source=packages.txt \
apt-get update && xargs -r -a /root/packages.txt apt-get install -y && \
rm -rf /var/lib/apt/lists/*
# Install packages
RUN pip install gradio torch torchvision Pillow gdown numpy scipy cmake onnxruntime-gpu opencv-python-headless
# Download images
ADD https://github.com/gradio-app/gradio/raw/main/demo/animeganv2/gongyoo.jpeg /app
ADD https://github.com/gradio-app/gradio/raw/main/demo/animeganv2/groot.jpeg /app
ADD https://github.com/AK391/animegan2-pytorch/archive/main.zip /app/hub
ADD https://github.com/bryandlee/animegan2-pytorch/raw/main/weights/face_paint_512_v2.pt /app/hub/checkpoints/
ADD https://github.com/bryandlee/animegan2-pytorch/raw/main/weights/face_paint_512_v1.pt /app/hub/checkpoints/
WORKDIR /home/user/app
RUN useradd -m -u 1000 user
RUN chown -R 1000:1000 /home/user
RUN unzip main.zip
RUN mkdir /home/user/app/hub/
RUN mkdir /home/user/app/hub/checkpoints
COPY animeganv2.py /app
RUN pip config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple/
RUN pip install --no-cache-dir gradio==3.0.9
RUN pip install --no-cache-dir pip==22.3.1
RUN --mount=target=requirements.txt,source=requirements.txt pip install --no-cache-dir -r requirements.txt
ADD gongyoo.jpeg /home/user/app
ADD groot.jpeg /home/user/app
ADD main.zip /home/user/app/hub
ADD hub/checkpoints/face_paint_512_v1.pt /home/user/app/hub/checkpoints/
ADD hub/checkpoints/face_paint_512_v2.pt /home/user/app/hub/checkpoints/
COPY --link --chown=1000 ./ /home/user/app
RUN unzip /home/user/app/hub/main.zip -d /home/user/app/hub
COPY animeganv2.py /home/user/app
CMD ["python3", "animeganv2.py"]
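For reference, the rebuilt image can be exercised with a plain build-and-run; a minimal sketch, assuming Gradio's default port 7860 and an arbitrary tag name:
```
docker build -t animeganv2 .
docker run --rm -p 7860:7860 animeganv2
```
Binding to `server_name="0.0.0.0"` (set in `animeganv2.py` below) is what makes the published port reachable from outside the container.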


@@ -2,8 +2,8 @@ import gradio as gr
from PIL import Image
import torch
model_dir="./app/hub/animegan2-pytorch-main"
model_dir_weight="./app/hub/checkpoints/face_paint_512_v1.pt"
model_dir = "hub/animegan2-pytorch-main"
model_dir_weight = "hub/checkpoints/face_paint_512_v1.pt"
model2 = torch.hub.load(
model_dir,
@@ -40,4 +40,4 @@ demo = gr.Interface(
article=article,
examples=examples)
demo.launch()
demo.launch(server_name="0.0.0.0")
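Since `model_dir` now points at a vendored copy of the repo rather than the GitHub source, the elided `torch.hub.load` call presumably uses the local-source pattern. A sketch under that assumption (`source="local"` and the explicit weight load are inferred from the hub paths above, not shown in the diff):
```python
import torch

# Load the generator definition from the unzipped local repo,
# then apply the locally downloaded face-paint weights.
model2 = torch.hub.load(
    "hub/animegan2-pytorch-main",  # created by unzipping hub/main.zip
    "generator",
    source="local",                # treat the first argument as a local directory
    pretrained=False,              # skip the URL download; weights are on disk
)
state_dict = torch.load("hub/checkpoints/face_paint_512_v1.pt", map_location="cpu")
model2.load_state_dict(state_dict)
```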


@@ -1,129 +0,0 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/


@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2021 Bryan Lee
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -1,140 +0,0 @@
## PyTorch Implementation of [AnimeGANv2](https://github.com/TachibanaYoshino/AnimeGANv2)
**Updates**
* `2021-10-17` Add weights for [FacePortraitV2](#additional-model-weights)
* `2021-11-07` Thanks to [ak92501](https://twitter.com/ak92501), a web demo is integrated into [Hugging Face Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio).
See demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/akhaliq/AnimeGANv2)
* `2021-11-07` Thanks to [xhlulu](https://github.com/xhlulu), the `torch.hub` model is now available. See [Torch Hub Usage](#torch-hub-usage).
* `2021-11-07` Add a FacePortraitV2-style demo to a Telegram bot. See [@face2stickerbot](https://t.me/face2stickerbot) by [sxela](https://github.com/sxela)
## Basic Usage
**Weight Conversion from the Original Repo (Requires TensorFlow 1.x)**
```
git clone https://github.com/TachibanaYoshino/AnimeGANv2
python convert_weights.py
```
**Inference**
```
python test.py --input_dir [image_folder_path] --device [cpu/cuda]
```
**Results from converted [[Paprika]](https://drive.google.com/file/d/1K_xN32uoQKI8XmNYNLTX5gDn1UnQVe5I/view?usp=sharing) style model**
(input image, original tensorflow result, pytorch result from left to right)
<img src="./samples/compare/1.jpg" width="960"> &nbsp;
<img src="./samples/compare/2.jpg" width="960"> &nbsp;
<img src="./samples/compare/3.jpg" width="960"> &nbsp;
**Note:** Training code not included / Results from converted weights slightly different due to the [bilinear upsample issue](https://github.com/pytorch/pytorch/issues/10604)
## Additional Model Weights
**Webtoon Face** [[ckpt]](https://drive.google.com/file/d/10T6F3-_RFOCJn6lMb-6mRmcISuYWJXGc)
<details>
<summary>samples</summary>
Trained on <b>256x256</b> face images. Distilled from [webtoon face model](https://github.com/bryandlee/naver-webtoon-faces/blob/master/README.md#face2webtoon) with L2 + VGG + GAN Loss and CelebA-HQ images. See `test_faces.ipynb` for details.
<img src="./samples/face_results.jpg" width="512"> &nbsp;
</details>
**Face Portrait v1** [[ckpt]](https://drive.google.com/file/d/1WK5Mdt6mwlcsqCZMHkCUSDJxN1UyFi0-)
<details>
<summary>samples</summary>
Trained on <b>512x512</b> face images.
[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jCqcKekdtKzW7cxiw_bjbbfLsPh-dEds?usp=sharing)
![samples](https://user-images.githubusercontent.com/26464535/127134790-93595da2-4f8b-4aca-a9d7-98699c5e6914.jpg)
[📺](https://youtu.be/CbMfI-HNCzw?t=317)
![sample](https://user-images.githubusercontent.com/26464535/129888683-98bb6283-7bb8-4d1a-a04a-e795f5858dcf.gif)
</details>
**Face Portrait v2** [[ckpt]](https://drive.google.com/uc?id=18H3iK09_d54qEDoWIc82SyWB2xun4gjU)
<details>
<summary>samples</summary>
Trained on <b>512x512</b> face images. Compared to v1, `🔻beautify` `🔺robustness`
[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jCqcKekdtKzW7cxiw_bjbbfLsPh-dEds?usp=sharing)
![face_portrait_v2_0](https://user-images.githubusercontent.com/26464535/137619176-59620b59-4e20-4d98-9559-a424f86b7f24.jpg)
![face_portrait_v2_1](https://user-images.githubusercontent.com/26464535/137619181-a45c9230-f5e7-4f3c-8002-7c266f89de45.jpg)
🦑 🎮 🔥
![face_portrait_v2_squid_game](https://user-images.githubusercontent.com/26464535/137619183-20e94f11-7a8e-4c3e-9b45-378ab63827ca.jpg)
</details>
## Torch Hub Usage
You can load AnimeGANv2 via `torch.hub`:
```python
import torch
model = torch.hub.load('bryandlee/animegan2-pytorch', 'generator').eval()
# convert your image into tensor here
out = model(img_tensor)
```
You can load with various configs (more details in [the torch docs](https://pytorch.org/docs/stable/hub.html)):
```python
model = torch.hub.load(
"bryandlee/animegan2-pytorch:main",
"generator",
pretrained=True, # or give URL to a pretrained model
device="cuda", # or "cpu" if you don't have a GPU
progress=True, # show progress
)
```
Currently, the following `pretrained` shorthands are available:
```python
model = torch.hub.load("bryandlee/animegan2-pytorch:main", "generator", pretrained="celeba_distill")
model = torch.hub.load("bryandlee/animegan2-pytorch:main", "generator", pretrained="face_paint_512_v1")
model = torch.hub.load("bryandlee/animegan2-pytorch:main", "generator", pretrained="face_paint_512_v2")
model = torch.hub.load("bryandlee/animegan2-pytorch:main", "generator", pretrained="paprika")
```
You can also load the `face2paint` util function. First, install dependencies:
```
pip install torchvision Pillow numpy
```
Then, import the function using `torch.hub`:
```python
face2paint = torch.hub.load(
'bryandlee/animegan2-pytorch:main', 'face2paint',
size=512, device="cpu"
)
img = Image.open(...).convert("RGB")
out = face2paint(model, img)
```


@@ -1,140 +0,0 @@
import argparse
import numpy as np
import os
import tensorflow as tf
from AnimeGANv2.net import generator as tf_generator
import torch
from model import Generator
def load_tf_weights(tf_path):
test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')
with tf.variable_scope("generator", reuse=False):
test_generated = tf_generator.G_net(test_real).fake
saver = tf.train.Saver()
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, device_count = {'GPU': 0})) as sess:
ckpt = tf.train.get_checkpoint_state(tf_path)
assert ckpt is not None and ckpt.model_checkpoint_path is not None, f"Failed to load checkpoint {tf_path}"
saver.restore(sess, ckpt.model_checkpoint_path)
print(f"Tensorflow model checkpoint {ckpt.model_checkpoint_path} loaded")
tf_weights = {}
for v in tf.trainable_variables():
tf_weights[v.name] = v.eval()
return tf_weights
def convert_keys(k):
# 1. divide tf weight name in three parts [block_idx, layer_idx, weight/bias]
# 2. handle each part & merge into a pytorch model keys
k = k.replace("Conv/", "Conv_0/").replace("LayerNorm/", "LayerNorm_0/")
keys = k.split("/")[2:]
is_dconv = False
# handle C block..
if keys[0] == "C":
if keys[1] in ["Conv_1", "LayerNorm_1"]:
keys[1] = keys[1].replace("1", "5")
if len(keys) == 4:
assert "r" in keys[1]
if keys[1] == keys[2]:
is_dconv = True
keys[2] = "1.1"
block_c_maps = {
"1": "1.2",
"Conv_1": "2",
"2": "3",
}
if keys[2] in block_c_maps:
keys[2] = block_c_maps[keys[2]]
keys[1] = keys[1].replace("r", "") + ".layers." + keys[2]
keys[2] = keys[3]
keys.pop(-1)
assert len(keys) == 3
# handle output block
if "out" in keys[0]:
keys[1] = "0"
# first part
if keys[0] in ["A", "B", "C", "D", "E"]:
keys[0] = "block_" + keys[0].lower()
# second part
if "LayerNorm_" in keys[1]:
keys[1] = keys[1].replace("LayerNorm_", "") + ".2"
if "Conv_" in keys[1]:
keys[1] = keys[1].replace("Conv_", "") + ".1"
# third part
keys[2] = {
"weights:0": "weight",
"w:0": "weight",
"bias:0": "bias",
"gamma:0": "weight",
"beta:0": "bias",
}[keys[2]]
return ".".join(keys), is_dconv
def convert_and_save(tf_checkpoint_path, save_name):
tf_weights = load_tf_weights(tf_checkpoint_path)
torch_net = Generator()
torch_weights = torch_net.state_dict()
torch_converted_weights = {}
for k, v in tf_weights.items():
torch_k, is_dconv = convert_keys(k)
assert torch_k in torch_weights, f"weight name mismatch: {k}"
converted_weight = torch.from_numpy(v)
if len(converted_weight.shape) == 4:
if is_dconv:
converted_weight = converted_weight.permute(2, 3, 0, 1)
else:
converted_weight = converted_weight.permute(3, 2, 0, 1)
assert torch_weights[torch_k].shape == converted_weight.shape, f"shape mismatch: {k}"
torch_converted_weights[torch_k] = converted_weight
assert sorted(list(torch_converted_weights)) == sorted(list(torch_weights)), f"some weights are missing"
torch_net.load_state_dict(torch_converted_weights)
torch.save(torch_net.state_dict(), save_name)
print(f"PyTorch model saved at {save_name}")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--tf_checkpoint_path',
type=str,
default='AnimeGANv2/checkpoint/generator_Paprika_weight',
)
parser.add_argument(
'--save_name',
type=str,
default='pytorch_generator_Paprika.pt',
)
args = parser.parse_args()
convert_and_save(args.tf_checkpoint_path, args.save_name)
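Putting the defaults together, a typical invocation (after cloning the TF repo and checkpoint as described in the README) looks like:
```
python convert_weights.py --tf_checkpoint_path AnimeGANv2/checkpoint/generator_Paprika_weight --save_name pytorch_generator_Paprika.pt
```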


@@ -1,63 +0,0 @@
import torch
def generator(pretrained=True, device="cpu", progress=True, check_hash=True):
from model import Generator
release_url = "https://github.com/bryandlee/animegan2-pytorch/raw/main/weights"
known = {
name: f"{release_url}/{name}.pt"
for name in [
'celeba_distill', 'face_paint_512_v1', 'face_paint_512_v2', 'paprika'
]
}
device = torch.device(device)
model = Generator().to(device)
if isinstance(pretrained, str):
# Look if a known name is passed, otherwise assume it's a URL
ckpt_url = known.get(pretrained, pretrained)
pretrained = True
else:
ckpt_url = known.get('face_paint_512_v2')
if pretrained is True:
state_dict = torch.hub.load_state_dict_from_url(
ckpt_url,
map_location=device,
progress=progress,
check_hash=check_hash,
)
model.load_state_dict(state_dict)
return model
def face2paint(device="cpu", size=512, side_by_side=False):
from PIL import Image
from torchvision.transforms.functional import to_tensor, to_pil_image
def face2paint(
model: torch.nn.Module,
img: Image.Image,
size: int = size,
side_by_side: bool = side_by_side,
device: str = device,
) -> Image.Image:
w, h = img.size
s = min(w, h)
img = img.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
img = img.resize((size, size), Image.LANCZOS)
with torch.no_grad():
input = to_tensor(img).unsqueeze(0) * 2 - 1
output = model(input.to(device)).cpu()[0]
if side_by_side:
output = torch.cat([input[0], output], dim=2)
output = (output * 0.5 + 0.5).clip(0, 1)
return to_pil_image(output)
return face2paint


@@ -1,110 +0,0 @@
import torch
from torch import nn
import torch.nn.functional as F
class ConvNormLReLU(nn.Sequential):
def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, pad_mode="reflect", groups=1, bias=False):
pad_layer = {
"zero": nn.ZeroPad2d,
"same": nn.ReplicationPad2d,
"reflect": nn.ReflectionPad2d,
}
if pad_mode not in pad_layer:
raise NotImplementedError
super(ConvNormLReLU, self).__init__(
pad_layer[pad_mode](padding),
nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=bias),
nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True),
nn.LeakyReLU(0.2, inplace=True)
)
class InvertedResBlock(nn.Module):
def __init__(self, in_ch, out_ch, expansion_ratio=2):
super(InvertedResBlock, self).__init__()
self.use_res_connect = in_ch == out_ch
bottleneck = int(round(in_ch*expansion_ratio))
layers = []
if expansion_ratio != 1:
layers.append(ConvNormLReLU(in_ch, bottleneck, kernel_size=1, padding=0))
# dw
layers.append(ConvNormLReLU(bottleneck, bottleneck, groups=bottleneck, bias=True))
# pw
layers.append(nn.Conv2d(bottleneck, out_ch, kernel_size=1, padding=0, bias=False))
layers.append(nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True))
self.layers = nn.Sequential(*layers)
def forward(self, input):
out = self.layers(input)
if self.use_res_connect:
out = input + out
return out
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.block_a = nn.Sequential(
ConvNormLReLU(3, 32, kernel_size=7, padding=3),
ConvNormLReLU(32, 64, stride=2, padding=(0,1,0,1)),
ConvNormLReLU(64, 64)
)
self.block_b = nn.Sequential(
ConvNormLReLU(64, 128, stride=2, padding=(0,1,0,1)),
ConvNormLReLU(128, 128)
)
self.block_c = nn.Sequential(
ConvNormLReLU(128, 128),
InvertedResBlock(128, 256, 2),
InvertedResBlock(256, 256, 2),
InvertedResBlock(256, 256, 2),
InvertedResBlock(256, 256, 2),
ConvNormLReLU(256, 128),
)
self.block_d = nn.Sequential(
ConvNormLReLU(128, 128),
ConvNormLReLU(128, 128)
)
self.block_e = nn.Sequential(
ConvNormLReLU(128, 64),
ConvNormLReLU(64, 64),
ConvNormLReLU(64, 32, kernel_size=7, padding=3)
)
self.out_layer = nn.Sequential(
nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0, bias=False),
nn.Tanh()
)
def forward(self, input, align_corners=True):
out = self.block_a(input)
half_size = out.size()[-2:]
out = self.block_b(out)
out = self.block_c(out)
if align_corners:
out = F.interpolate(out, half_size, mode="bilinear", align_corners=True)
else:
out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
out = self.block_d(out)
if align_corners:
out = F.interpolate(out, input.size()[-2:], mode="bilinear", align_corners=True)
else:
out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
out = self.block_e(out)
out = self.out_layer(out)
return out
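The generator downsamples twice (blocks A and B) and interpolates back up, so the output keeps the input's spatial size. A minimal shape check, assuming a square 256x256 input with values roughly in [-1, 1]:
```python
import torch

net = Generator()
x = torch.randn(1, 3, 256, 256)  # NCHW input tensor
with torch.no_grad():
    y = net(x)
assert y.shape == x.shape  # Tanh output, same 3x256x256 shape as the input
```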

23 binary image files not shown (deleted sample inputs, comparison images, and face-paint result images, 15 KiB to 1.4 MiB each).


@@ -1,90 +0,0 @@
import argparse
import torch
import cv2
import numpy as np
import os
from model import Generator
torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def load_image(image_path, x32=False):
img = cv2.imread(image_path).astype(np.float32)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
h, w = img.shape[:2]
if x32: # resize image to multiple of 32s
def to_32s(x):
return 256 if x < 256 else x - x%32
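# e.g. to_32s(1000) -> 992 (largest multiple of 32 <= 1000)
# e.g. to_32s(200) -> 256 (small images are clamped up to 256)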
img = cv2.resize(img, (to_32s(w), to_32s(h)))
img = torch.from_numpy(img)
img = img/127.5 - 1.0
return img
def test(args):
device = args.device
net = Generator()
net.load_state_dict(torch.load(args.checkpoint, map_location="cpu"))
net.to(device).eval()
print(f"model loaded: {args.checkpoint}")
os.makedirs(args.output_dir, exist_ok=True)
for image_name in sorted(os.listdir(args.input_dir)):
if os.path.splitext(image_name)[-1].lower() not in [".jpg", ".png", ".bmp", ".tiff"]:
continue
image = load_image(os.path.join(args.input_dir, image_name), args.x32)
with torch.no_grad():
input = image.permute(2, 0, 1).unsqueeze(0).to(device)
out = net(input, args.upsample_align).squeeze(0).permute(1, 2, 0).cpu().numpy()
out = (out + 1)*127.5
out = np.clip(out, 0, 255).astype(np.uint8)
cv2.imwrite(os.path.join(args.output_dir, image_name), cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
print(f"image saved: {image_name}")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--checkpoint',
type=str,
default='./pytorch_generator_Paprika.pt',
)
parser.add_argument(
'--input_dir',
type=str,
default='./samples/inputs',
)
parser.add_argument(
'--output_dir',
type=str,
default='./samples/results',
)
parser.add_argument(
'--device',
type=str,
default='cuda:0',
)
parser.add_argument(
'--upsample_align',
action="store_true",
)
parser.add_argument(
'--x32',
action="store_true",
)
args = parser.parse_args()
test(args)
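Mirroring the defaults above, an example run over the bundled samples on CPU (assuming the Paprika checkpoint produced by convert_weights.py):
```
python test.py --checkpoint ./pytorch_generator_Paprika.pt --input_dir ./samples/inputs --output_dir ./samples/results --device cpu --x32
```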

File diff suppressed because one or more lines are too long

hub/main.zip (new binary file, not shown)

packages.txt (new file)

@@ -0,0 +1 @@
unzip


@@ -1,6 +1,9 @@
# please keep aiges at the latest version
aiges
transformers
torch
diffusers
accelerate
torchvision
Pillow
gdown
numpy
scipy
cmake
onnxruntime-gpu
opencv-python-headless