diff --git a/README.md b/README.md index 5fda3e1..2bde72c 100644 --- a/README.md +++ b/README.md @@ -31,24 +31,21 @@ fine-tuned versions on a task that interests you. ### How to use -Here is how to use this model to classify an image of CIFAR-100 into one of the 1,000 ImageNet classes: +Here is how to use this model to classify an image of the COCO 2017 dataset into one of the 1,000 ImageNet classes: ```python from transformers import ViTFeatureExtractor, ViTForImageClassification -from datasets import load_dataset -import numpy as np - +from PIL import Image +import requests +url = 'http://images.cocodataset.org/val2017/000000039769.jpg' +image = Image.open(requests.get(url, stream=True).raw) feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224') -model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224") - -dataset = load_dataset("cifar100", split='test') -image = np.asarray(dataset[2]['img'], dtype=np.uint8) -image = np.moveaxis(image, source=-1, destination=0) # change from (H, W, C) to (C, H, W) - -pixel_values = feature_extractor(image) -outputs = model(pixel_values) +model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224') +inputs = feature_extractor(images=image) +outputs = model(**inputs) logits = outputs.logits -predicted_class = logits.argmax(-1) +# model predicts one of the 1000 ImageNet classes +predicted_class = logits.argmax(-1).item() ``` Currently, both the feature extractor and model support PyTorch. Tensorflow and JAX/FLAX are coming soon, and the API of ViTFeatureExtractor might change.