diff --git a/vit-base-patch16-224/Dockerfile b/vit-base-patch16-224/Dockerfile new file mode 100644 index 0000000..0fb6f30 --- /dev/null +++ b/vit-base-patch16-224/Dockerfile @@ -0,0 +1,13 @@ +FROM python:3.7.4-slim + +WORKDIR /app + +COPY requirements.txt /app + +RUN pip config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple/ + +RUN pip3 install --trusted-host pypi.python.org -r requirements.txt + +COPY . /app + +CMD ["python", "vit.py"] diff --git a/vit-base-patch16-224/cat.jpeg b/vit-base-patch16-224/cat.jpeg new file mode 100644 index 0000000..a260caf Binary files /dev/null and b/vit-base-patch16-224/cat.jpeg differ diff --git a/vit-base-patch16-224/dog.jpeg b/vit-base-patch16-224/dog.jpeg new file mode 100644 index 0000000..f83318c Binary files /dev/null and b/vit-base-patch16-224/dog.jpeg differ diff --git a/vit-base-patch16-224/requirements.txt b/vit-base-patch16-224/requirements.txt new file mode 100644 index 0000000..4f0806d --- /dev/null +++ b/vit-base-patch16-224/requirements.txt @@ -0,0 +1,4 @@ +gradio +huggingface +torch +transformers diff --git a/vit-base-patch16-224/vit.py b/vit-base-patch16-224/vit.py new file mode 100644 index 0000000..0eb80ed --- /dev/null +++ b/vit-base-patch16-224/vit.py @@ -0,0 +1,24 @@ +#图像分类 +import gradio as gr +from transformers import ViTFeatureExtractor, ViTForImageClassification + +feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224') +model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224') + +def image_classification(image): + inputs = feature_extractor(images=image, return_tensors="pt") + outputs = model(**inputs) + logits = outputs.logits + predicted_class_idx = logits.argmax(-1).item() + return model.config.id2label[predicted_class_idx] + +demo = gr.Interface(fn=image_classification, + inputs=gr.Image(), + outputs=gr.Label(num_top_classes=1), + title = "图像分类", + allow_flagging="never", + examples = ['cat.jpeg', 'dog.jpeg', 'zebra.jpeg']) + + +if __name__ == "__main__": + demo.queue(concurrency_count=3).launch(server_name = "0.0.0.0", server_port = 7000, max_threads=40) diff --git a/vit-base-patch16-224/zebra.jpeg b/vit-base-patch16-224/zebra.jpeg new file mode 100644 index 0000000..82efa7d Binary files /dev/null and b/vit-base-patch16-224/zebra.jpeg differ