image_bind/app.py

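"""Gradio demo: zero-shot audio classification with ImageBind.

An audio clip and a set of '|'-separated text labels are embedded into
ImageBind's joint embedding space; softmax over the audio-text similarities
yields one score per label.
"""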

import data
import torch
import gradio as gr
from models import imagebind_model
from models.imagebind_model import ModalityType
from gradio.themes.utils import sizes

# Custom theme: square corners, blue accent for labels and the primary button.
theme = gr.themes.Default(radius_size=sizes.radius_none).set(
    block_label_text_color='#4D63FF',
    block_title_text_color='#4D63FF',
    button_primary_text_color='#4D63FF',
    button_primary_background_fill='#FFFFFF',
    button_primary_border_color='#4D63FF',
    button_primary_background_fill_hover='#EDEFFF',
)
css = "footer {visibility: hidden}"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)


def audio_text(audio, text_list):
    """Score an audio clip against a '|'-separated list of text labels."""
    audio_paths = [audio]
    labels = [label.strip() for label in text_list.strip().split("|")]
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(labels, device),
        ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
    }
    with torch.no_grad():
        embeddings = model(inputs)
    # Softmax over the audio-text similarity logits gives one score per label.
    scores = torch.softmax(
        embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1
    ).squeeze(0).tolist()
    return {label: score for label, score in zip(labels, scores)}
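

# Build the UI: audio clip and candidate labels on the left, scores on the right.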
with gr.Blocks(theme=theme, css=css) as demo:
    gr.Markdown("""
    <div align='center'><font size='60'>Audio Classification</font></div>
    """)
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(type="filepath", label="Audio input")
            text = gr.Textbox(lines=1, label="Categories")
            with gr.Row():
                button = gr.Button("Submit", variant="primary")
        outputs = gr.Label(label="Category")
    button.click(fn=audio_text, inputs=[audio, text], outputs=outputs)
    gr.Examples(
        examples=[
            [".assets/dog_audio.wav", "A dog|A car|A bird"],
            [".assets/car_audio.wav", "A dog|A car|A bird"],
            [".assets/bird_audio.wav", "A dog|A car|A bird"],
        ],
        inputs=[audio, text],
        label="Examples",
    )

if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0")