"""Gradio demo: classify an image with BEiT (beit-base-patch16-224-pt22k-ft22k)."""
import gradio as gr
from PIL import Image  # noqa: F401  (kept: original file imported it; PIL images flow through Gradio)
from transformers import BeitImageProcessor, BeitForImageClassification

# Load the processor and model ONCE at startup. The original reloaded both
# from disk on every request, which made each prediction pay the full
# model-initialization cost.
_PRETRAINED_MODEL_PATH = "beit-base-patch16-224-pt22k-ft22k"
_processor = BeitImageProcessor.from_pretrained(_PRETRAINED_MODEL_PATH)
_model = BeitForImageClassification.from_pretrained(_PRETRAINED_MODEL_PATH)


def inference(img):
    """Classify *img* into one of the 21,841 ImageNet-22k classes.

    Parameters
    ----------
    img : PIL.Image.Image
        Image supplied by the Gradio UI (``type="pil"``).

    Returns
    -------
    str
        Human-readable label of the highest-scoring class.
    """
    inputs = _processor(images=img, return_tensors="pt")
    outputs = _model(**inputs)
    logits = outputs.logits
    # argmax over the class dimension picks the single best label.
    predicted_class_idx = logits.argmax(-1).item()
    return _model.config.id2label[predicted_class_idx]


title = "beit-base-patch16-224-pt22k-ft22k"
description = (
    "Gradio Demo for beit-base-patch16-224-pt22k-ft22k. To use it, simply "
    "upload your image, or click one of the examples to load them."
)
# NOTE(review): the original `article` HTML was garbled in this copy; only the
# link texts ("Github Repo Pytorch", "visitor badge") survived. Restore the
# intended anchor/badge markup from the upstream Space if available.
article = "Github Repo Pytorch | visitor badge"

examples = [["example_cat.jpg"], ["Masahiro.png"]]

demo = gr.Interface(
    fn=inference,
    # gr.inputs.* / gr.outputs.* were deprecated in Gradio 2.x and removed in
    # 3.x; the top-level components are the supported spelling.
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    title=title,
    description=description,
    article=article,
    examples=examples,
)

if __name__ == "__main__":
    demo.launch()