44 lines
1.7 KiB
Python
44 lines
1.7 KiB
Python
import gradio as gr
|
|
import torch
|
|
from gradio.themes.utils import sizes
|
|
|
|
|
|
# Square-cornered variant of the default Gradio theme. Block labels, block
# titles and the primary button's text/border all share one accent colour;
# the primary button itself is white and turns pale blue on hover.
accent_color = '#4D63FF'
theme_overrides = {
    'block_label_text_color': accent_color,
    'block_title_text_color': accent_color,
    'button_primary_text_color': accent_color,
    'button_primary_background_fill': '#FFFFFF',
    'button_primary_border_color': accent_color,
    'button_primary_background_fill_hover': '#EDEFFF',
}
theme = gr.themes.Default(radius_size=sizes.radius_none).set(**theme_overrides)
|
|
|
|
# All three hub entry points come from the same repository/branch.
hub_repo = "AK391/animegan2-pytorch:main"

# Generator with the explicitly named "face_paint_512_v1" weights; used for
# the "version 1" choice in the UI.
model1 = torch.hub.load(hub_repo, "generator", pretrained="face_paint_512_v1")

# Generator with the hub entry point's default pretrained weights; used for
# the "version 2" choice in the UI.
# NOTE(review): presumably the default resolves to the v2 face-paint
# checkpoint -- confirm against the repository's hubconf.
model2 = torch.hub.load(hub_repo, "generator", pretrained=True, progress=False)

# Helper that applies a generator to an image, configured for 512px
# processing and to return only the stylized result (no side-by-side view).
face2paint = torch.hub.load(hub_repo, 'face2paint', size=512, side_by_side=False)
|
|
|
|
def inference(img, ver):
    """Stylize ``img`` with the generator selected by ``ver``.

    Args:
        img: Input image handed unchanged to ``face2paint``. The UI wires
            this to a ``gr.Image(type='pil')`` component, so it is ``None``
            when the user submits without providing an image.
        ver: Label chosen in the version radio. The exact "version 2" label
            selects ``model2``; every other value (including ``None`` when
            nothing is selected) selects ``model1``.

    Returns:
        Whatever ``face2paint`` returns for the chosen model, or ``None``
        when no image was supplied.
    """
    # Nothing to stylize: return early instead of letting face2paint fail
    # on a missing image.
    if img is None:
        return None

    v2_label = 'version 2 (🔺 robustness,🔻 stylization)'
    generator = model2 if ver == v2_label else model1
    return face2paint(generator, img)
|
|
|
|
|
|
# Labels offered by the version selector; `inference` compares against the
# first one verbatim, so these strings must not be altered.
VERSION_CHOICES = [
    'version 2 (🔺 robustness,🔻 stylization)',
    'version 1 (🔺 stylization, 🔻 robustness)',
]

# Demo page: a title, an input image with a version selector and a submit
# button, and an output image that receives the result of `inference`.
# NOTE(review): the nesting below is reconstructed -- confirm that the output
# image is meant to sit beside the input column rather than inside it.
with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as demo:
    gr.Markdown("""
    <div align='center' ><font size='60'>动漫风格迁移</font></div>
    """)
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="图片", type='pil')
            # `gr.Radio` replaces the deprecated `gr.inputs.Radio` alias,
            # matching the top-level component API used everywhere else here.
            radio = gr.Radio(choices=VERSION_CHOICES, label='版本')
            with gr.Row():
                button = gr.Button("提交", variant="primary")
        box2 = gr.Image(label="图片")

    # Submit runs the selected generator and shows the result in `box2`.
    button.click(fn=inference, inputs=[image, radio], outputs=box2)

    # Clickable sample inputs, one per model version; the image files are
    # referenced by relative path.
    examples = gr.Examples(
        examples=[
            ['groot.jpeg', VERSION_CHOICES[0]],
            ['gongyoo.jpeg', VERSION_CHOICES[1]],
        ],
        inputs=[image, radio],
        label="例子",
    )
|
|
|
|
|
|
# Start the web server. Binding to "0.0.0.0" makes it listen on all network
# interfaces, so the app is reachable from other machines and not only from
# localhost.
demo.launch(server_name="0.0.0.0")
|