update dockerfile and animeganv2.py
47
Dockerfile
|
@ -1,30 +1,43 @@
|
||||||
# please visit https://github.com/xfyun/aiges/releases to get stable and suitable iamges.
|
# please visit https://github.com/xfyun/aiges/releases to get stable and suitable iamges.
|
||||||
|
|
||||||
FROM public.ecr.aws/iflytek-open/aiges-gpu:11.6-1.17-3.9.13-ubuntu1804-v2.0.0-rc6
|
FROM docker.io/library/python:3.8.9
|
||||||
|
|
||||||
RUN apt-get install -y unzip
|
RUN apt-get install -y unzip
|
||||||
|
|
||||||
RUN mkdir /app
|
|
||||||
RUN mkdir /app/hub/
|
|
||||||
RUN mkdir /app/hub/checkpoints
|
|
||||||
|
|
||||||
WORKDIR /app
|
RUN sed -i 's@//.*archive.ubuntu.com@//mirrors.ustc.edu.cn@g' /etc/apt/sources.list
|
||||||
|
RUN sed -i 's/security.ubuntu.com/mirrors.ustc.edu.cn/g' /etc/apt/sources.list
|
||||||
|
|
||||||
# do this if you are on the chinese server.
|
RUN --mount=target=/root/packages.txt,source=packages.txt && \
|
||||||
RUN pip3 config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple/
|
apt-get update && xargs -r -a /root/packages.txt apt-get install -y && \
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# Install packages
|
|
||||||
RUN pip install gradio torch torchvision Pillow gdown numpy scipy cmake onnxruntime-gpu opencv-python-headless
|
|
||||||
|
|
||||||
# Download images
|
WORKDIR /home/user/app
|
||||||
ADD https://github.com/gradio-app/gradio/raw/main/demo/animeganv2/gongyoo.jpeg /app
|
RUN useradd -m -u 1000 user
|
||||||
ADD https://github.com/gradio-app/gradio/raw/main/demo/animeganv2/groot.jpeg /app
|
RUN chown -R 1000.1000 /home/user
|
||||||
ADD https://github.com/AK391/animegan2-pytorch/archive/main.zip /app/hub
|
|
||||||
ADD https://github.com/bryandlee/animegan2-pytorch/raw/main/weights/face_paint_512_v2.pt /app/hub/checkpoints/
|
|
||||||
ADD https://github.com/bryandlee/animegan2-pytorch/raw/main/weights/face_paint_512_v1.pt /app/hub/checkpoints/
|
|
||||||
|
|
||||||
RUN unzip main.zip
|
RUN mkdir /home/user/app/hub/
|
||||||
|
RUN mkdir /home/user/app/hub/checkpoints
|
||||||
|
|
||||||
COPY animeganv2.py /app
|
|
||||||
|
RUN pip config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple/
|
||||||
|
|
||||||
|
|
||||||
|
RUN pip install --no-cache-dir gradio==3.0.9
|
||||||
|
RUN pip install --no-cache-dir pip==22.3.1
|
||||||
|
RUN --mount=target=requirements.txt,source=requirements.txt pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
ADD gongyoo.jpeg /home/user/app
|
||||||
|
ADD groot.jpeg /home/user/app
|
||||||
|
ADD main.zip /home/user/app/hub
|
||||||
|
ADD hub/checkpoints/face_paint_512_v1.pt /home/user/app/hub/checkpoints/
|
||||||
|
ADD hub/checkpoints/face_paint_512_v2.pt /home/user/app/hub/checkpoints/
|
||||||
|
|
||||||
|
COPY --link --chown=1000 ./ /home/user/app
|
||||||
|
|
||||||
|
RUN unzip /home/user/app/hub/main.zip
|
||||||
|
|
||||||
|
COPY animeganv2.py /home/user/app
|
||||||
|
|
||||||
CMD ["python3", "animeganv2.py"]
|
CMD ["python3", "animeganv2.py"]
|
|
@ -2,8 +2,8 @@ import gradio as gr
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
model_dir="./app/hub/animegan2-pytorch-main"
|
model_dir = "hub/animegan2-pytorch-main"
|
||||||
model_dir_weight="./app/hub/checkpoints/face_paint_512_v1.pt"
|
model_dir_weight = "hub/checkpoints/face_paint_512_v1.pt"
|
||||||
|
|
||||||
model2 = torch.hub.load(
|
model2 = torch.hub.load(
|
||||||
model_dir,
|
model_dir,
|
||||||
|
@ -40,4 +40,4 @@ demo = gr.Interface(
|
||||||
article=article,
|
article=article,
|
||||||
examples=examples)
|
examples=examples)
|
||||||
|
|
||||||
demo.launch()
|
demo.launch(server_name="0.0.0.0")
|
|
@ -1,129 +0,0 @@
|
||||||
# Byte-compiled / optimized / DLL files
|
|
||||||
__pycache__/
|
|
||||||
*.py[cod]
|
|
||||||
*$py.class
|
|
||||||
|
|
||||||
# C extensions
|
|
||||||
*.so
|
|
||||||
|
|
||||||
# Distribution / packaging
|
|
||||||
.Python
|
|
||||||
build/
|
|
||||||
develop-eggs/
|
|
||||||
dist/
|
|
||||||
downloads/
|
|
||||||
eggs/
|
|
||||||
.eggs/
|
|
||||||
lib/
|
|
||||||
lib64/
|
|
||||||
parts/
|
|
||||||
sdist/
|
|
||||||
var/
|
|
||||||
wheels/
|
|
||||||
pip-wheel-metadata/
|
|
||||||
share/python-wheels/
|
|
||||||
*.egg-info/
|
|
||||||
.installed.cfg
|
|
||||||
*.egg
|
|
||||||
MANIFEST
|
|
||||||
|
|
||||||
# PyInstaller
|
|
||||||
# Usually these files are written by a python script from a template
|
|
||||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
||||||
*.manifest
|
|
||||||
*.spec
|
|
||||||
|
|
||||||
# Installer logs
|
|
||||||
pip-log.txt
|
|
||||||
pip-delete-this-directory.txt
|
|
||||||
|
|
||||||
# Unit test / coverage reports
|
|
||||||
htmlcov/
|
|
||||||
.tox/
|
|
||||||
.nox/
|
|
||||||
.coverage
|
|
||||||
.coverage.*
|
|
||||||
.cache
|
|
||||||
nosetests.xml
|
|
||||||
coverage.xml
|
|
||||||
*.cover
|
|
||||||
*.py,cover
|
|
||||||
.hypothesis/
|
|
||||||
.pytest_cache/
|
|
||||||
|
|
||||||
# Translations
|
|
||||||
*.mo
|
|
||||||
*.pot
|
|
||||||
|
|
||||||
# Django stuff:
|
|
||||||
*.log
|
|
||||||
local_settings.py
|
|
||||||
db.sqlite3
|
|
||||||
db.sqlite3-journal
|
|
||||||
|
|
||||||
# Flask stuff:
|
|
||||||
instance/
|
|
||||||
.webassets-cache
|
|
||||||
|
|
||||||
# Scrapy stuff:
|
|
||||||
.scrapy
|
|
||||||
|
|
||||||
# Sphinx documentation
|
|
||||||
docs/_build/
|
|
||||||
|
|
||||||
# PyBuilder
|
|
||||||
target/
|
|
||||||
|
|
||||||
# Jupyter Notebook
|
|
||||||
.ipynb_checkpoints
|
|
||||||
|
|
||||||
# IPython
|
|
||||||
profile_default/
|
|
||||||
ipython_config.py
|
|
||||||
|
|
||||||
# pyenv
|
|
||||||
.python-version
|
|
||||||
|
|
||||||
# pipenv
|
|
||||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
||||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
||||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
||||||
# install all needed dependencies.
|
|
||||||
#Pipfile.lock
|
|
||||||
|
|
||||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
|
||||||
__pypackages__/
|
|
||||||
|
|
||||||
# Celery stuff
|
|
||||||
celerybeat-schedule
|
|
||||||
celerybeat.pid
|
|
||||||
|
|
||||||
# SageMath parsed files
|
|
||||||
*.sage.py
|
|
||||||
|
|
||||||
# Environments
|
|
||||||
.env
|
|
||||||
.venv
|
|
||||||
env/
|
|
||||||
venv/
|
|
||||||
ENV/
|
|
||||||
env.bak/
|
|
||||||
venv.bak/
|
|
||||||
|
|
||||||
# Spyder project settings
|
|
||||||
.spyderproject
|
|
||||||
.spyproject
|
|
||||||
|
|
||||||
# Rope project settings
|
|
||||||
.ropeproject
|
|
||||||
|
|
||||||
# mkdocs documentation
|
|
||||||
/site
|
|
||||||
|
|
||||||
# mypy
|
|
||||||
.mypy_cache/
|
|
||||||
.dmypy.json
|
|
||||||
dmypy.json
|
|
||||||
|
|
||||||
# Pyre type checker
|
|
||||||
.pyre/
|
|
|
@ -1,21 +0,0 @@
|
||||||
MIT License
|
|
||||||
|
|
||||||
Copyright (c) 2021 Bryan Lee
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
|
@ -1,140 +0,0 @@
|
||||||
## PyTorch Implementation of [AnimeGANv2](https://github.com/TachibanaYoshino/AnimeGANv2)
|
|
||||||
|
|
||||||
|
|
||||||
**Updates**
|
|
||||||
|
|
||||||
* `2021-10-17` Add weights for [FacePortraitV2](#additional-model-weights)
|
|
||||||
* `2021-11-07` Thanks to [ak92501](https://twitter.com/ak92501), a web demo is integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio).
|
|
||||||
|
|
||||||
See demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/akhaliq/AnimeGANv2)
|
|
||||||
|
|
||||||
* `2021-11-07` Thanks to [xhlulu](https://github.com/xhlulu), the `torch.hub` model is now available. See [Torch Hub Usage](#torch-hub-usage).
|
|
||||||
* `2021-11-07` Add FacePortraitV2 style demo to a telegram bot. See [@face2stickerbot](https://t.me/face2stickerbot) by [sxela](https://github.com/sxela)
|
|
||||||
|
|
||||||
|
|
||||||
## Basic Usage
|
|
||||||
|
|
||||||
**Weight Conversion from the Original Repo (Requires TensorFlow 1.x)**
|
|
||||||
```
|
|
||||||
git clone https://github.com/TachibanaYoshino/AnimeGANv2
|
|
||||||
python convert_weights.py
|
|
||||||
```
|
|
||||||
|
|
||||||
**Inference**
|
|
||||||
```
|
|
||||||
python test.py --input_dir [image_folder_path] --device [cpu/cuda]
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
**Results from converted [[Paprika]](https://drive.google.com/file/d/1K_xN32uoQKI8XmNYNLTX5gDn1UnQVe5I/view?usp=sharing) style model**
|
|
||||||
|
|
||||||
(input image, original tensorflow result, pytorch result from left to right)
|
|
||||||
|
|
||||||
<img src="./samples/compare/1.jpg" width="960">
|
|
||||||
<img src="./samples/compare/2.jpg" width="960">
|
|
||||||
<img src="./samples/compare/3.jpg" width="960">
|
|
||||||
|
|
||||||
**Note:** Training code not included / Results from converted weights slightly different due to the [bilinear upsample issue](https://github.com/pytorch/pytorch/issues/10604)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Additional Model Weights
|
|
||||||
|
|
||||||
**Webtoon Face** [[ckpt]](https://drive.google.com/file/d/10T6F3-_RFOCJn6lMb-6mRmcISuYWJXGc)
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary>samples</summary>
|
|
||||||
|
|
||||||
Trained on <b>256x256</b> face images. Distilled from [webtoon face model](https://github.com/bryandlee/naver-webtoon-faces/blob/master/README.md#face2webtoon) with L2 + VGG + GAN Loss and CelebA-HQ images. See `test_faces.ipynb` for details.
|
|
||||||
|
|
||||||
<img src="./samples/face_results.jpg" width="512">
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
|
|
||||||
**Face Portrait v1** [[ckpt]](https://drive.google.com/file/d/1WK5Mdt6mwlcsqCZMHkCUSDJxN1UyFi0-)
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary>samples</summary>
|
|
||||||
|
|
||||||
Trained on <b>512x512</b> face images.
|
|
||||||
|
|
||||||
[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jCqcKekdtKzW7cxiw_bjbbfLsPh-dEds?usp=sharing)
|
|
||||||
|
|
||||||
![samples](https://user-images.githubusercontent.com/26464535/127134790-93595da2-4f8b-4aca-a9d7-98699c5e6914.jpg)
|
|
||||||
|
|
||||||
[📺](https://youtu.be/CbMfI-HNCzw?t=317)
|
|
||||||
|
|
||||||
![sample](https://user-images.githubusercontent.com/26464535/129888683-98bb6283-7bb8-4d1a-a04a-e795f5858dcf.gif)
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
|
|
||||||
**Face Portrait v2** [[ckpt]](https://drive.google.com/uc?id=18H3iK09_d54qEDoWIc82SyWB2xun4gjU)
|
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary>samples</summary>
|
|
||||||
|
|
||||||
Trained on <b>512x512</b> face images. Compared to v1, `🔻beautify` `🔺robustness`
|
|
||||||
|
|
||||||
[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jCqcKekdtKzW7cxiw_bjbbfLsPh-dEds?usp=sharing)
|
|
||||||
|
|
||||||
![face_portrait_v2_0](https://user-images.githubusercontent.com/26464535/137619176-59620b59-4e20-4d98-9559-a424f86b7f24.jpg)
|
|
||||||
|
|
||||||
![face_portrait_v2_1](https://user-images.githubusercontent.com/26464535/137619181-a45c9230-f5e7-4f3c-8002-7c266f89de45.jpg)
|
|
||||||
|
|
||||||
🦑 🎮 🔥
|
|
||||||
|
|
||||||
![face_portrait_v2_squid_game](https://user-images.githubusercontent.com/26464535/137619183-20e94f11-7a8e-4c3e-9b45-378ab63827ca.jpg)
|
|
||||||
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
|
|
||||||
## Torch Hub Usage
|
|
||||||
|
|
||||||
You can load Animegan v2 via `torch.hub`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
import torch
|
|
||||||
model = torch.hub.load('bryandlee/animegan2-pytorch', 'generator').eval()
|
|
||||||
# convert your image into tensor here
|
|
||||||
out = model(img_tensor)
|
|
||||||
```
|
|
||||||
|
|
||||||
You can load with various configs (more details in [the torch docs](https://pytorch.org/docs/stable/hub.html)):
|
|
||||||
```python
|
|
||||||
model = torch.hub.load(
|
|
||||||
"bryandlee/animegan2-pytorch:main",
|
|
||||||
"generator",
|
|
||||||
pretrained=True, # or give URL to a pretrained model
|
|
||||||
device="cuda", # or "cpu" if you don't have a GPU
|
|
||||||
progress=True, # show progress
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
Currently, the following `pretrained` shorthands are available:
|
|
||||||
```python
|
|
||||||
model = torch.hub.load("bryandlee/animegan2-pytorch:main", "generator", pretrained="celeba_distill")
|
|
||||||
model = torch.hub.load("bryandlee/animegan2-pytorch:main", "generator", pretrained="face_paint_512_v1")
|
|
||||||
model = torch.hub.load("bryandlee/animegan2-pytorch:main", "generator", pretrained="face_paint_512_v2")
|
|
||||||
model = torch.hub.load("bryandlee/animegan2-pytorch:main", "generator", pretrained="paprika")
|
|
||||||
```
|
|
||||||
|
|
||||||
You can also load the `face2paint` util function. First, install dependencies:
|
|
||||||
|
|
||||||
```
|
|
||||||
pip install torchvision Pillow numpy
|
|
||||||
```
|
|
||||||
|
|
||||||
Then, import the function using `torch.hub`:
|
|
||||||
```python
|
|
||||||
face2paint = torch.hub.load(
|
|
||||||
'bryandlee/animegan2-pytorch:main', 'face2paint',
|
|
||||||
size=512, device="cpu"
|
|
||||||
)
|
|
||||||
|
|
||||||
img = Image.open(...).convert("RGB")
|
|
||||||
out = face2paint(model, img)
|
|
||||||
```
|
|
|
@ -1,140 +0,0 @@
|
||||||
import argparse
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
|
|
||||||
import tensorflow as tf
|
|
||||||
from AnimeGANv2.net import generator as tf_generator
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from model import Generator
|
|
||||||
|
|
||||||
|
|
||||||
def load_tf_weights(tf_path):
|
|
||||||
test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')
|
|
||||||
with tf.variable_scope("generator", reuse=False):
|
|
||||||
test_generated = tf_generator.G_net(test_real).fake
|
|
||||||
|
|
||||||
saver = tf.train.Saver()
|
|
||||||
|
|
||||||
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, device_count = {'GPU': 0})) as sess:
|
|
||||||
ckpt = tf.train.get_checkpoint_state(tf_path)
|
|
||||||
|
|
||||||
assert ckpt is not None and ckpt.model_checkpoint_path is not None, f"Failed to load checkpoint {tf_path}"
|
|
||||||
|
|
||||||
saver.restore(sess, ckpt.model_checkpoint_path)
|
|
||||||
print(f"Tensorflow model checkpoint {ckpt.model_checkpoint_path} loaded")
|
|
||||||
|
|
||||||
tf_weights = {}
|
|
||||||
for v in tf.trainable_variables():
|
|
||||||
tf_weights[v.name] = v.eval()
|
|
||||||
|
|
||||||
return tf_weights
|
|
||||||
|
|
||||||
|
|
||||||
def convert_keys(k):
|
|
||||||
|
|
||||||
# 1. divide tf weight name in three parts [block_idx, layer_idx, weight/bias]
|
|
||||||
# 2. handle each part & merge into a pytorch model keys
|
|
||||||
|
|
||||||
k = k.replace("Conv/", "Conv_0/").replace("LayerNorm/", "LayerNorm_0/")
|
|
||||||
keys = k.split("/")[2:]
|
|
||||||
|
|
||||||
is_dconv = False
|
|
||||||
|
|
||||||
# handle C block..
|
|
||||||
if keys[0] == "C":
|
|
||||||
if keys[1] in ["Conv_1", "LayerNorm_1"]:
|
|
||||||
keys[1] = keys[1].replace("1", "5")
|
|
||||||
|
|
||||||
if len(keys) == 4:
|
|
||||||
assert "r" in keys[1]
|
|
||||||
|
|
||||||
if keys[1] == keys[2]:
|
|
||||||
is_dconv = True
|
|
||||||
keys[2] = "1.1"
|
|
||||||
|
|
||||||
block_c_maps = {
|
|
||||||
"1": "1.2",
|
|
||||||
"Conv_1": "2",
|
|
||||||
"2": "3",
|
|
||||||
}
|
|
||||||
if keys[2] in block_c_maps:
|
|
||||||
keys[2] = block_c_maps[keys[2]]
|
|
||||||
|
|
||||||
keys[1] = keys[1].replace("r", "") + ".layers." + keys[2]
|
|
||||||
keys[2] = keys[3]
|
|
||||||
keys.pop(-1)
|
|
||||||
assert len(keys) == 3
|
|
||||||
|
|
||||||
# handle output block
|
|
||||||
if "out" in keys[0]:
|
|
||||||
keys[1] = "0"
|
|
||||||
|
|
||||||
# first part
|
|
||||||
if keys[0] in ["A", "B", "C", "D", "E"]:
|
|
||||||
keys[0] = "block_" + keys[0].lower()
|
|
||||||
|
|
||||||
# second part
|
|
||||||
if "LayerNorm_" in keys[1]:
|
|
||||||
keys[1] = keys[1].replace("LayerNorm_", "") + ".2"
|
|
||||||
if "Conv_" in keys[1]:
|
|
||||||
keys[1] = keys[1].replace("Conv_", "") + ".1"
|
|
||||||
|
|
||||||
# third part
|
|
||||||
keys[2] = {
|
|
||||||
"weights:0": "weight",
|
|
||||||
"w:0": "weight",
|
|
||||||
"bias:0": "bias",
|
|
||||||
"gamma:0": "weight",
|
|
||||||
"beta:0": "bias",
|
|
||||||
}[keys[2]]
|
|
||||||
|
|
||||||
return ".".join(keys), is_dconv
|
|
||||||
|
|
||||||
|
|
||||||
def convert_and_save(tf_checkpoint_path, save_name):
|
|
||||||
|
|
||||||
tf_weights = load_tf_weights(tf_checkpoint_path)
|
|
||||||
|
|
||||||
torch_net = Generator()
|
|
||||||
torch_weights = torch_net.state_dict()
|
|
||||||
|
|
||||||
torch_converted_weights = {}
|
|
||||||
for k, v in tf_weights.items():
|
|
||||||
torch_k, is_dconv = convert_keys(k)
|
|
||||||
assert torch_k in torch_weights, f"weight name mismatch: {k}"
|
|
||||||
|
|
||||||
converted_weight = torch.from_numpy(v)
|
|
||||||
if len(converted_weight.shape) == 4:
|
|
||||||
if is_dconv:
|
|
||||||
converted_weight = converted_weight.permute(2, 3, 0, 1)
|
|
||||||
else:
|
|
||||||
converted_weight = converted_weight.permute(3, 2, 0, 1)
|
|
||||||
|
|
||||||
assert torch_weights[torch_k].shape == converted_weight.shape, f"shape mismatch: {k}"
|
|
||||||
|
|
||||||
torch_converted_weights[torch_k] = converted_weight
|
|
||||||
|
|
||||||
assert sorted(list(torch_converted_weights)) == sorted(list(torch_weights)), f"some weights are missing"
|
|
||||||
torch_net.load_state_dict(torch_converted_weights)
|
|
||||||
torch.save(torch_net.state_dict(), save_name)
|
|
||||||
print(f"PyTorch model saved at {save_name}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
'--tf_checkpoint_path',
|
|
||||||
type=str,
|
|
||||||
default='AnimeGANv2/checkpoint/generator_Paprika_weight',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--save_name',
|
|
||||||
type=str,
|
|
||||||
default='pytorch_generator_Paprika.pt',
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
convert_and_save(args.tf_checkpoint_path, args.save_name)
|
|
|
@ -1,63 +0,0 @@
|
||||||
import torch
|
|
||||||
|
|
||||||
def generator(pretrained=True, device="cpu", progress=True, check_hash=True):
|
|
||||||
from model import Generator
|
|
||||||
|
|
||||||
release_url = "https://github.com/bryandlee/animegan2-pytorch/raw/main/weights"
|
|
||||||
known = {
|
|
||||||
name: f"{release_url}/{name}.pt"
|
|
||||||
for name in [
|
|
||||||
'celeba_distill', 'face_paint_512_v1', 'face_paint_512_v2', 'paprika'
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
device = torch.device(device)
|
|
||||||
model = Generator().to(device)
|
|
||||||
|
|
||||||
if type(pretrained) == str:
|
|
||||||
# Look if a known name is passed, otherwise assume it's a URL
|
|
||||||
ckpt_url = known.get(pretrained, pretrained)
|
|
||||||
pretrained = True
|
|
||||||
else:
|
|
||||||
ckpt_url = known.get('face_paint_512_v2')
|
|
||||||
|
|
||||||
if pretrained is True:
|
|
||||||
state_dict = torch.hub.load_state_dict_from_url(
|
|
||||||
ckpt_url,
|
|
||||||
map_location=device,
|
|
||||||
progress=progress,
|
|
||||||
check_hash=check_hash,
|
|
||||||
)
|
|
||||||
model.load_state_dict(state_dict)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def face2paint(device="cpu", size=512, side_by_side=False):
|
|
||||||
from PIL import Image
|
|
||||||
from torchvision.transforms.functional import to_tensor, to_pil_image
|
|
||||||
|
|
||||||
def face2paint(
|
|
||||||
model: torch.nn.Module,
|
|
||||||
img: Image.Image,
|
|
||||||
size: int = size,
|
|
||||||
side_by_side: bool = side_by_side,
|
|
||||||
device: str = device,
|
|
||||||
) -> Image.Image:
|
|
||||||
w, h = img.size
|
|
||||||
s = min(w, h)
|
|
||||||
img = img.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
|
|
||||||
img = img.resize((size, size), Image.LANCZOS)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
input = to_tensor(img).unsqueeze(0) * 2 - 1
|
|
||||||
output = model(input.to(device)).cpu()[0]
|
|
||||||
|
|
||||||
if side_by_side:
|
|
||||||
output = torch.cat([input[0], output], dim=2)
|
|
||||||
|
|
||||||
output = (output * 0.5 + 0.5).clip(0, 1)
|
|
||||||
|
|
||||||
return to_pil_image(output)
|
|
||||||
|
|
||||||
return face2paint
|
|
|
@ -1,110 +0,0 @@
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
class ConvNormLReLU(nn.Sequential):
|
|
||||||
def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, pad_mode="reflect", groups=1, bias=False):
|
|
||||||
|
|
||||||
pad_layer = {
|
|
||||||
"zero": nn.ZeroPad2d,
|
|
||||||
"same": nn.ReplicationPad2d,
|
|
||||||
"reflect": nn.ReflectionPad2d,
|
|
||||||
}
|
|
||||||
if pad_mode not in pad_layer:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
super(ConvNormLReLU, self).__init__(
|
|
||||||
pad_layer[pad_mode](padding),
|
|
||||||
nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=bias),
|
|
||||||
nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True),
|
|
||||||
nn.LeakyReLU(0.2, inplace=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class InvertedResBlock(nn.Module):
|
|
||||||
def __init__(self, in_ch, out_ch, expansion_ratio=2):
|
|
||||||
super(InvertedResBlock, self).__init__()
|
|
||||||
|
|
||||||
self.use_res_connect = in_ch == out_ch
|
|
||||||
bottleneck = int(round(in_ch*expansion_ratio))
|
|
||||||
layers = []
|
|
||||||
if expansion_ratio != 1:
|
|
||||||
layers.append(ConvNormLReLU(in_ch, bottleneck, kernel_size=1, padding=0))
|
|
||||||
|
|
||||||
# dw
|
|
||||||
layers.append(ConvNormLReLU(bottleneck, bottleneck, groups=bottleneck, bias=True))
|
|
||||||
# pw
|
|
||||||
layers.append(nn.Conv2d(bottleneck, out_ch, kernel_size=1, padding=0, bias=False))
|
|
||||||
layers.append(nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True))
|
|
||||||
|
|
||||||
self.layers = nn.Sequential(*layers)
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
out = self.layers(input)
|
|
||||||
if self.use_res_connect:
|
|
||||||
out = input + out
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class Generator(nn.Module):
|
|
||||||
def __init__(self, ):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.block_a = nn.Sequential(
|
|
||||||
ConvNormLReLU(3, 32, kernel_size=7, padding=3),
|
|
||||||
ConvNormLReLU(32, 64, stride=2, padding=(0,1,0,1)),
|
|
||||||
ConvNormLReLU(64, 64)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.block_b = nn.Sequential(
|
|
||||||
ConvNormLReLU(64, 128, stride=2, padding=(0,1,0,1)),
|
|
||||||
ConvNormLReLU(128, 128)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.block_c = nn.Sequential(
|
|
||||||
ConvNormLReLU(128, 128),
|
|
||||||
InvertedResBlock(128, 256, 2),
|
|
||||||
InvertedResBlock(256, 256, 2),
|
|
||||||
InvertedResBlock(256, 256, 2),
|
|
||||||
InvertedResBlock(256, 256, 2),
|
|
||||||
ConvNormLReLU(256, 128),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.block_d = nn.Sequential(
|
|
||||||
ConvNormLReLU(128, 128),
|
|
||||||
ConvNormLReLU(128, 128)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.block_e = nn.Sequential(
|
|
||||||
ConvNormLReLU(128, 64),
|
|
||||||
ConvNormLReLU(64, 64),
|
|
||||||
ConvNormLReLU(64, 32, kernel_size=7, padding=3)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.out_layer = nn.Sequential(
|
|
||||||
nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0, bias=False),
|
|
||||||
nn.Tanh()
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, input, align_corners=True):
|
|
||||||
out = self.block_a(input)
|
|
||||||
half_size = out.size()[-2:]
|
|
||||||
out = self.block_b(out)
|
|
||||||
out = self.block_c(out)
|
|
||||||
|
|
||||||
if align_corners:
|
|
||||||
out = F.interpolate(out, half_size, mode="bilinear", align_corners=True)
|
|
||||||
else:
|
|
||||||
out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
|
|
||||||
out = self.block_d(out)
|
|
||||||
|
|
||||||
if align_corners:
|
|
||||||
out = F.interpolate(out, input.size()[-2:], mode="bilinear", align_corners=True)
|
|
||||||
else:
|
|
||||||
out = F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
|
|
||||||
out = self.block_e(out)
|
|
||||||
|
|
||||||
out = self.out_layer(out)
|
|
||||||
return out
|
|
||||||
|
|
Before Width: | Height: | Size: 863 KiB |
Before Width: | Height: | Size: 1.3 MiB |
Before Width: | Height: | Size: 1.4 MiB |
Before Width: | Height: | Size: 843 KiB |
Before Width: | Height: | Size: 36 KiB |
Before Width: | Height: | Size: 26 KiB |
Before Width: | Height: | Size: 37 KiB |
Before Width: | Height: | Size: 30 KiB |
Before Width: | Height: | Size: 62 KiB |
Before Width: | Height: | Size: 37 KiB |
Before Width: | Height: | Size: 15 KiB |
Before Width: | Height: | Size: 39 KiB |
Before Width: | Height: | Size: 50 KiB |
Before Width: | Height: | Size: 22 KiB |
Before Width: | Height: | Size: 35 KiB |
Before Width: | Height: | Size: 131 KiB |
Before Width: | Height: | Size: 50 KiB |
Before Width: | Height: | Size: 37 KiB |
Before Width: | Height: | Size: 21 KiB |
Before Width: | Height: | Size: 35 KiB |
Before Width: | Height: | Size: 719 KiB |
Before Width: | Height: | Size: 536 KiB |
Before Width: | Height: | Size: 227 KiB |
|
@ -1,90 +0,0 @@
|
||||||
import argparse
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
|
|
||||||
from model import Generator
|
|
||||||
|
|
||||||
torch.backends.cudnn.enabled = False
|
|
||||||
torch.backends.cudnn.benchmark = False
|
|
||||||
torch.backends.cudnn.deterministic = True
|
|
||||||
|
|
||||||
def load_image(image_path, x32=False):
|
|
||||||
img = cv2.imread(image_path).astype(np.float32)
|
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
||||||
h, w = img.shape[:2]
|
|
||||||
|
|
||||||
if x32: # resize image to multiple of 32s
|
|
||||||
def to_32s(x):
|
|
||||||
return 256 if x < 256 else x - x%32
|
|
||||||
img = cv2.resize(img, (to_32s(w), to_32s(h)))
|
|
||||||
|
|
||||||
img = torch.from_numpy(img)
|
|
||||||
img = img/127.5 - 1.0
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def test(args):
|
|
||||||
device = args.device
|
|
||||||
|
|
||||||
net = Generator()
|
|
||||||
net.load_state_dict(torch.load(args.checkpoint, map_location="cpu"))
|
|
||||||
net.to(device).eval()
|
|
||||||
print(f"model loaded: {args.checkpoint}")
|
|
||||||
|
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
for image_name in sorted(os.listdir(args.input_dir)):
|
|
||||||
if os.path.splitext(image_name)[-1].lower() not in [".jpg", ".png", ".bmp", ".tiff"]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
image = load_image(os.path.join(args.input_dir, image_name), args.x32)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
input = image.permute(2, 0, 1).unsqueeze(0).to(device)
|
|
||||||
out = net(input, args.upsample_align).squeeze(0).permute(1, 2, 0).cpu().numpy()
|
|
||||||
out = (out + 1)*127.5
|
|
||||||
out = np.clip(out, 0, 255).astype(np.uint8)
|
|
||||||
|
|
||||||
cv2.imwrite(os.path.join(args.output_dir, image_name), cv2.cvtColor(out, cv2.COLOR_BGR2RGB))
|
|
||||||
print(f"image saved: {image_name}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
'--checkpoint',
|
|
||||||
type=str,
|
|
||||||
default='./pytorch_generator_Paprika.pt',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--input_dir',
|
|
||||||
type=str,
|
|
||||||
default='./samples/inputs',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--output_dir',
|
|
||||||
type=str,
|
|
||||||
default='./samples/results',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--device',
|
|
||||||
type=str,
|
|
||||||
default='cuda:0',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--upsample_align',
|
|
||||||
type=bool,
|
|
||||||
default=False,
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--x32',
|
|
||||||
action="store_true",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
test(args)
|
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
unzip
|
|
@ -1,6 +1,9 @@
|
||||||
# please keep the aiges for the latest
|
|
||||||
aiges
|
|
||||||
transformers
|
|
||||||
torch
|
torch
|
||||||
diffusers
|
torchvision
|
||||||
accelerate
|
Pillow
|
||||||
|
gdown
|
||||||
|
numpy
|
||||||
|
scipy
|
||||||
|
cmake
|
||||||
|
onnxruntime-gpu
|
||||||
|
opencv-python-headless
|