# U-2-Net/app.py
#
# 50 lines
# 1.5 KiB
# Python

import cv2
import paddlehub as hub
import gradio as gr
import torch
from gradio.themes.utils import sizes
# Square-cornered default theme: block labels/titles use #4D63FF, and the
# primary button is white with #4D63FF border/text (#EDEFFF on hover).
theme = gr.themes.Default(radius_size=sizes.radius_none)
theme = theme.set(
    block_title_text_color='#4D63FF',
    block_label_text_color='#4D63FF',
    button_primary_background_fill='#FFFFFF',
    button_primary_background_fill_hover='#EDEFFF',
    button_primary_border_color='#4D63FF',
    button_primary_text_color='#4D63FF',
)

# PaddleHub U2Net module, created once at import time and reused by infer()
# for every request.
model = hub.Module(name='U2Net')
def infer(img):
    """Run U2Net salient-object segmentation on a single image file.

    Args:
        img: Path to the uploaded image (the input component is created with
            ``type="filepath"``), or ``None`` when the user presses submit
            without uploading anything.

    Returns:
        A ``(front, mask)`` tuple: the model's ``'front'`` array with its last
        axis reversed (``[:, :, ::-1]``, OpenCV BGR -> RGB for display), and
        the model's ``'mask'`` array unchanged.

    Raises:
        gr.Error: if no image was provided or OpenCV cannot decode the file.
            Gradio shows this message to the user instead of a generic
            "Error" toast.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    # cv2.imread does not raise on missing/unsupported files; it returns None,
    # which would otherwise fail later inside the model with an opaque error.
    image = cv2.imread(img)
    if image is None:
        raise gr.Error(f"Could not read image: {img}")
    result = model.Segmentation(
        images=[image],
        paths=None,
        batch_size=1,
        input_size=320,
        output_dir='output',
        # NOTE(review): presumably this makes the module save result images
        # under output_dir on every request; the UI only uses the returned
        # arrays — confirm whether these files are needed.
        visualization=True,
    )
    return result[0]['front'][:, :, ::-1], result[0]['mask']
# Gradio UI: image upload + submit button on the left, the two segmentation
# outputs on the right. The footer is hidden via custom CSS.
with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as demo:
    # Centred page title ("Salient Object Detection").
    gr.Markdown("\n<div align='center' ><font size='60'>显著性目标检测</font></div>\n")
    with gr.Row():
        with gr.Column():
            # gr.Image replaces the deprecated gr.inputs.Image (removed in
            # Gradio 4); type="filepath" hands infer() a path on disk.
            input_image = gr.Image(label="Original Image", type="filepath")
            with gr.Row():
                # Button label "提交" means "Submit".
                button = gr.Button("提交", variant="primary")
        with gr.Column():
            # One component per value in infer()'s (front, mask) return tuple;
            # gr.Image replaces the deprecated gr.outputs.Image.
            output_image = [
                gr.Image(type="numpy", label="Front"),
                gr.Image(type="numpy", label="Mask"),
            ]
    button.click(fn=infer, inputs=input_image, outputs=output_image)
    # Example images ("例子" = "Examples"); selecting one fills input_image.
    examples = gr.Examples(examples=[['fox.jpg'], ['parrot.jpg']], inputs=input_image, label="例子")
if __name__ == "__main__":
    # Turn on Gradio's request queue, then serve on all network interfaces
    # (0.0.0.0) so the demo is reachable from outside the host.
    app = demo.queue()
    app.launch(server_name="0.0.0.0")