import data
import torch
import gradio as gr
from models import imagebind_model
from models.imagebind_model import ModalityType
from gradio.themes.utils import sizes

# Custom theme: square corners with a blue/white accent palette.
theme = gr.themes.Default(radius_size=sizes.radius_none).set(
    block_label_text_color='#4D63FF',
    block_title_text_color='#4D63FF',
    button_primary_text_color='#4D63FF',
    button_primary_background_fill='#FFFFFF',
    button_primary_border_color='#4D63FF',
    button_primary_background_fill_hover='#EDEFFF',
)

# Hide the default Gradio footer.
css = "footer {visibility: hidden}"

# Load the pretrained ImageBind (huge) model once at startup.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)


def audio_text(audio, text_list):
    """Score an audio clip against a '|'-separated list of text labels."""
    audio_paths = [audio]
    labels = [label.strip() for label in text_list.strip().split("|")]
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(labels, device),
        ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
    }
    with torch.no_grad():
        embeddings = model(inputs)
    # Softmax over the audio-text similarities yields one probability per label.
    scores = torch.softmax(
        embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1
    ).squeeze(0).tolist()
    return {label: score for label, score in zip(labels, scores)}


with gr.Blocks(theme=theme, css=css) as demo:
    gr.Markdown("""
Audio Classification
""") with gr.Row(): with gr.Column(): audio = gr.inputs.Audio(type='filepath',label="音频输入") text = gr.inputs.Textbox(lines=1,label="类别") with gr.Row(): button = gr.Button("提交", variant="primary") outputs = gr.Label(label="类别") button.click(fn=audio_text, inputs=[audio, text], outputs=outputs) examples = gr.Examples(examples=[[".assets/dog_audio.wav", "A dog|A car|A bird"],[".assets/car_audio.wav", "A dog|A car|A bird"], [".assets/bird_audio.wav", "A dog|A car|A bird"]],inputs=[audio, text], label="例子") if __name__ == "__main__": demo.queue().launch(server_name = "0.0.0.0")