import gradio as gr from transformers import BeitImageProcessor, BeitForImageClassification def inference(img): pretrained_model_path = "beit-base-patch16-224-pt22k-ft22k" processor = BeitImageProcessor.from_pretrained(pretrained_model_path) model = BeitForImageClassification.from_pretrained(pretrained_model_path) inputs = processor(images=img, return_tensors="pt") outputs = model(**inputs) logits = outputs.logits # model predicts one of the 21,841 ImageNet-22k classes predicted_class_idx = logits.argmax(-1).item() # print("Predicted class:", model.config.id2label[predicted_class_idx]) return model.config.id2label[predicted_class_idx] title = "Image classification:beit-base-patch16-224-pt22k-ft22k" description = "这是beit-base-patch16-224-pt22k-ft22k的Gradio Demo。用于图像分类。上传你想要的图像或者点击下面的示例来加载它。" article = "