import functools

import gradio as gr
import torch
from transformers import BeitImageProcessor, BeitForImageClassification

# Checkpoint passed to `from_pretrained` (a local directory or hub id).
PRETRAINED_MODEL_PATH = "beit-base-patch16-224-pt22k-ft22k"


@functools.lru_cache(maxsize=1)
def _load_model():
    """Load the BEiT image processor and classifier once, then reuse them.

    Loading a checkpoint is slow; caching means only the first request pays
    that cost instead of every request.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = BeitImageProcessor.from_pretrained(PRETRAINED_MODEL_PATH)
    model = BeitForImageClassification.from_pretrained(PRETRAINED_MODEL_PATH)
    return processor, model


def inference(img):
    """Classify an image and return the predicted label.

    Args:
        img: The image from the Gradio ``Image(type="pil")`` input
            (a PIL image).

    Returns:
        The label string ``model.config.id2label`` maps the highest-scoring
        class index to.
    """
    processor, model = _load_model()
    inputs = processor(images=img, return_tensors="pt")
    # Pure inference: skip autograd graph construction to save time/memory.
    with torch.no_grad():
        logits = model(**inputs).logits
    # model predicts one of the 21,841 ImageNet-22k classes
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


title = "Image classification:beit-base-patch16-224-pt22k-ft22k"
description = "Gradio Demo for beit-base-patch16-224-pt22k-ft22k. To use it, simply upload your image, or click one of the examples to load them."
# Triple-quoted so the multi-line text is a valid literal.
# NOTE(review): this looks like it originally held HTML links (a GitHub repo
# link and a visitor badge) whose markup was stripped; restore it if available.
article = """

Github Repo Pytorch

visitor badge

"""
examples = [['example_cat.jpg'], ['Masahiro.png']]

# `gr.inputs` / `gr.outputs` were deprecated in Gradio 3 and removed in
# Gradio 4; the top-level components are the supported equivalents.
demo = gr.Interface(
    fn=inference,
    inputs=[gr.Image(type="pil")],
    outputs=gr.Textbox(),
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()