import functools

import gradio as gr
import torch
from transformers import BeitForImageClassification, BeitImageProcessor

# NOTE(review): presumably a local copy of the Hugging Face checkpoint
# "microsoft/beit-base-patch16-224-pt22k-ft22k"; from_pretrained() also accepts
# a Hub model id — confirm which one this deployment relies on.
PRETRAINED_MODEL_PATH = "beit-base-patch16-224-pt22k-ft22k"


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the BEiT image processor and classifier, caching them for reuse.

    Loading is deferred until the first request and then cached, so later
    requests do not re-read and re-initialise the model weights.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = BeitImageProcessor.from_pretrained(PRETRAINED_MODEL_PATH)
    model = BeitForImageClassification.from_pretrained(PRETRAINED_MODEL_PATH)
    model.eval()  # explicit: inference only, no dropout
    return processor, model


def inference(img):
    """Classify an image and return the predicted label.

    Args:
        img: The uploaded image as a PIL image (the Gradio input below is
            configured with ``type="pil"``), or ``None`` if nothing was
            uploaded.

    Returns:
        The label of the highest-scoring class, looked up in
        ``model.config.id2label``.
    """
    if img is None:
        # Gradio's Image component yields None when no image was provided;
        # the processor would otherwise fail with an opaque error.
        return "No image provided."
    processor, model = _load_model()
    inputs = processor(images=img, return_tensors="pt")
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits
    # The model predicts one of the 21,841 ImageNet-22k classes.
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


title = "Image classification:beit-base-patch16-224-pt22k-ft22k"
# Description (Chinese): "This is the Gradio demo of
# beit-base-patch16-224-pt22k-ft22k, for image classification. Upload an image
# of your choice or click an example below to load it."
description = "这是beit-base-patch16-224-pt22k-ft22k的Gradio Demo。用于图像分类。上传你想要的图像或者点击下面的示例来加载它。"
# NOTE(review): this text appears to be what survived of HTML markup (a link to
# the GitHub/PyTorch repo and a visitor-badge image); only the visible text is
# kept here. Restore the original markup if it is available.
article = "\n\nGithub Repo Pytorch\n\nvisitor badge\n\n"
examples = [["example_cat.jpg"], ["Masahiro.png"]]

demo = gr.Interface(
    fn=inference,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    title=title,
    description=description,
    article=article,
    examples=examples,
)
# Bind to all network interfaces so the demo is reachable from other hosts.
demo.launch(server_name="0.0.0.0")