50 lines
1.5 KiB
Python
50 lines
1.5 KiB
Python
import cv2
|
|
import paddlehub as hub
|
|
import gradio as gr
|
|
import torch
|
|
from gradio.themes.utils import sizes
|
|
|
|
|
|
# Accent colour shared by block labels, block titles and the primary button.
_ACCENT_COLOR = '#4D63FF'

# Overrides applied on top of the square-cornered default theme: the primary
# button becomes white-filled with accent-coloured text and border, and a
# light accent tint on hover.
_THEME_OVERRIDES = dict(
    block_label_text_color=_ACCENT_COLOR,
    block_title_text_color=_ACCENT_COLOR,
    button_primary_text_color=_ACCENT_COLOR,
    button_primary_background_fill='#FFFFFF',
    button_primary_border_color=_ACCENT_COLOR,
    button_primary_background_fill_hover='#EDEFFF',
)
base_theme = gr.themes.Default(radius_size=sizes.radius_none)
theme = base_theme.set(**_THEME_OVERRIDES)

# PaddleHub U2Net module, loaded once at import time and reused by every
# call to infer().
model = hub.Module(name='U2Net')
def infer(img):
    """Run U2Net salient-object segmentation on a single image file.

    Args:
        img: Path to the uploaded image (the Gradio input component uses
            ``type="filepath"``), or ``None`` when nothing was uploaded.

    Returns:
        A ``(front, mask)`` tuple. ``front`` is the model's ``'front'``
        output with its channel axis reversed (OpenCV's BGR order to the RGB
        order Gradio displays); ``mask`` is the model's ``'mask'`` output,
        returned unchanged.

    Raises:
        gr.Error: If no image was supplied or the file could not be decoded,
            so the user sees a readable message instead of an internal
            traceback.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    # cv2.imread does not raise on unreadable/undecodable files; it returns None.
    image = cv2.imread(img)
    if image is None:
        raise gr.Error("Could not read the uploaded image.")

    # NOTE(review): visualization=True presumably also makes PaddleHub write
    # result images into output_dir on every request -- confirm that is wanted.
    result = model.Segmentation(images=[image],
                                paths=None,
                                batch_size=1,
                                input_size=320,
                                output_dir='output',
                                visualization=True)
    first = result[0]
    # Reverse the last (channel) axis: BGR -> RGB for display.
    return first['front'][:, :, ::-1], first['mask']
# Page layout: input image and submit button on the left, the two model
# outputs on the right, example images below; the Gradio footer is hidden
# via CSS.
with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as demo:
    # Page title ("显著性目标检测" = "Salient Object Detection").
    gr.Markdown("""
<div align='center' ><font size='60'>显著性目标检测</font></div>
""")
    with gr.Row():
        with gr.Column():
            # Use gr.Image directly: the gr.inputs / gr.outputs wrappers are
            # deprecated in Gradio 3 and removed in Gradio 4.
            # type="filepath" hands infer() a path it can pass to cv2.imread.
            input_image = gr.Image(label="Original Image", type="filepath")
            with gr.Row():
                # "提交" = "Submit".
                button = gr.Button("提交", variant="primary")
        with gr.Column():
            # Order must match infer()'s (front, mask) return tuple.
            output_image = [
                gr.Image(type="numpy", label="Front"),
                gr.Image(type="numpy", label="Mask"),
            ]

    button.click(fn=infer, inputs=input_image, outputs=output_image)

    # "例子" = "Examples"; the example images are referenced by relative path.
    examples = gr.Examples(examples=[['fox.jpg'], ['parrot.jpg']], inputs=input_image, label="例子")
if __name__ == "__main__":
    # Enable request queuing, then listen on all network interfaces (0.0.0.0)
    # rather than only localhost.
    queued_app = demo.queue()
    queued_app.launch(server_name="0.0.0.0")