926 lines
3.4 MiB
Plaintext
926 lines
3.4 MiB
Plaintext
|
{
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 0,
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"name": "Interacting with CLIP.ipynb",
|
|||
|
"provenance": [],
|
|||
|
"collapsed_sections": []
|
|||
|
},
|
|||
|
"kernelspec": {
|
|||
|
"name": "python3",
|
|||
|
"display_name": "Python 3"
|
|||
|
},
|
|||
|
"accelerator": "GPU",
|
|||
|
"widgets": {
|
|||
|
"application/vnd.jupyter.widget-state+json": {
|
|||
|
"1369964d45004b5e95a058910b2a33e6": {
|
|||
|
"model_module": "@jupyter-widgets/controls",
|
|||
|
"model_name": "HBoxModel",
|
|||
|
"state": {
|
|||
|
"_view_name": "HBoxView",
|
|||
|
"_dom_classes": [],
|
|||
|
"_model_name": "HBoxModel",
|
|||
|
"_view_module": "@jupyter-widgets/controls",
|
|||
|
"_model_module_version": "1.5.0",
|
|||
|
"_view_count": null,
|
|||
|
"_view_module_version": "1.5.0",
|
|||
|
"box_style": "",
|
|||
|
"layout": "IPY_MODEL_12e23e2819094ee0a079d4eb77cfc4f9",
|
|||
|
"_model_module": "@jupyter-widgets/controls",
|
|||
|
"children": [
|
|||
|
"IPY_MODEL_7a5f52e56ede4ac3abe37a3ece007dc9",
|
|||
|
"IPY_MODEL_ce8b0faa1a1340b5a504d7b3546b3ccb"
|
|||
|
]
|
|||
|
}
|
|||
|
},
|
|||
|
"12e23e2819094ee0a079d4eb77cfc4f9": {
|
|||
|
"model_module": "@jupyter-widgets/base",
|
|||
|
"model_name": "LayoutModel",
|
|||
|
"state": {
|
|||
|
"_view_name": "LayoutView",
|
|||
|
"grid_template_rows": null,
|
|||
|
"right": null,
|
|||
|
"justify_content": null,
|
|||
|
"_view_module": "@jupyter-widgets/base",
|
|||
|
"overflow": null,
|
|||
|
"_model_module_version": "1.2.0",
|
|||
|
"_view_count": null,
|
|||
|
"flex_flow": null,
|
|||
|
"width": null,
|
|||
|
"min_width": null,
|
|||
|
"border": null,
|
|||
|
"align_items": null,
|
|||
|
"bottom": null,
|
|||
|
"_model_module": "@jupyter-widgets/base",
|
|||
|
"top": null,
|
|||
|
"grid_column": null,
|
|||
|
"overflow_y": null,
|
|||
|
"overflow_x": null,
|
|||
|
"grid_auto_flow": null,
|
|||
|
"grid_area": null,
|
|||
|
"grid_template_columns": null,
|
|||
|
"flex": null,
|
|||
|
"_model_name": "LayoutModel",
|
|||
|
"justify_items": null,
|
|||
|
"grid_row": null,
|
|||
|
"max_height": null,
|
|||
|
"align_content": null,
|
|||
|
"visibility": null,
|
|||
|
"align_self": null,
|
|||
|
"height": null,
|
|||
|
"min_height": null,
|
|||
|
"padding": null,
|
|||
|
"grid_auto_rows": null,
|
|||
|
"grid_gap": null,
|
|||
|
"max_width": null,
|
|||
|
"order": null,
|
|||
|
"_view_module_version": "1.2.0",
|
|||
|
"grid_template_areas": null,
|
|||
|
"object_position": null,
|
|||
|
"object_fit": null,
|
|||
|
"grid_auto_columns": null,
|
|||
|
"margin": null,
|
|||
|
"display": null,
|
|||
|
"left": null
|
|||
|
}
|
|||
|
},
|
|||
|
"7a5f52e56ede4ac3abe37a3ece007dc9": {
|
|||
|
"model_module": "@jupyter-widgets/controls",
|
|||
|
"model_name": "FloatProgressModel",
|
|||
|
"state": {
|
|||
|
"_view_name": "ProgressView",
|
|||
|
"style": "IPY_MODEL_5e6adc4592124a4581b85f4c1f3bab4d",
|
|||
|
"_dom_classes": [],
|
|||
|
"description": "",
|
|||
|
"_model_name": "FloatProgressModel",
|
|||
|
"bar_style": "success",
|
|||
|
"max": 169001437,
|
|||
|
"_view_module": "@jupyter-widgets/controls",
|
|||
|
"_model_module_version": "1.5.0",
|
|||
|
"value": 169001437,
|
|||
|
"_view_count": null,
|
|||
|
"_view_module_version": "1.5.0",
|
|||
|
"orientation": "horizontal",
|
|||
|
"min": 0,
|
|||
|
"description_tooltip": null,
|
|||
|
"_model_module": "@jupyter-widgets/controls",
|
|||
|
"layout": "IPY_MODEL_4a61c10fc00c4f04bb00b82e942da210"
|
|||
|
}
|
|||
|
},
|
|||
|
"ce8b0faa1a1340b5a504d7b3546b3ccb": {
|
|||
|
"model_module": "@jupyter-widgets/controls",
|
|||
|
"model_name": "HTMLModel",
|
|||
|
"state": {
|
|||
|
"_view_name": "HTMLView",
|
|||
|
"style": "IPY_MODEL_b597cd6f6cd443aba4bf4491ac7f957e",
|
|||
|
"_dom_classes": [],
|
|||
|
"description": "",
|
|||
|
"_model_name": "HTMLModel",
|
|||
|
"placeholder": "",
|
|||
|
"_view_module": "@jupyter-widgets/controls",
|
|||
|
"_model_module_version": "1.5.0",
|
|||
|
"value": " 169001984/? [00:06<00:00, 25734958.25it/s]",
|
|||
|
"_view_count": null,
|
|||
|
"_view_module_version": "1.5.0",
|
|||
|
"description_tooltip": null,
|
|||
|
"_model_module": "@jupyter-widgets/controls",
|
|||
|
"layout": "IPY_MODEL_161969cae25a49f38aacd1568d3cac6c"
|
|||
|
}
|
|||
|
},
|
|||
|
"5e6adc4592124a4581b85f4c1f3bab4d": {
|
|||
|
"model_module": "@jupyter-widgets/controls",
|
|||
|
"model_name": "ProgressStyleModel",
|
|||
|
"state": {
|
|||
|
"_view_name": "StyleView",
|
|||
|
"_model_name": "ProgressStyleModel",
|
|||
|
"description_width": "initial",
|
|||
|
"_view_module": "@jupyter-widgets/base",
|
|||
|
"_model_module_version": "1.5.0",
|
|||
|
"_view_count": null,
|
|||
|
"_view_module_version": "1.2.0",
|
|||
|
"bar_color": null,
|
|||
|
"_model_module": "@jupyter-widgets/controls"
|
|||
|
}
|
|||
|
},
|
|||
|
"4a61c10fc00c4f04bb00b82e942da210": {
|
|||
|
"model_module": "@jupyter-widgets/base",
|
|||
|
"model_name": "LayoutModel",
|
|||
|
"state": {
|
|||
|
"_view_name": "LayoutView",
|
|||
|
"grid_template_rows": null,
|
|||
|
"right": null,
|
|||
|
"justify_content": null,
|
|||
|
"_view_module": "@jupyter-widgets/base",
|
|||
|
"overflow": null,
|
|||
|
"_model_module_version": "1.2.0",
|
|||
|
"_view_count": null,
|
|||
|
"flex_flow": null,
|
|||
|
"width": null,
|
|||
|
"min_width": null,
|
|||
|
"border": null,
|
|||
|
"align_items": null,
|
|||
|
"bottom": null,
|
|||
|
"_model_module": "@jupyter-widgets/base",
|
|||
|
"top": null,
|
|||
|
"grid_column": null,
|
|||
|
"overflow_y": null,
|
|||
|
"overflow_x": null,
|
|||
|
"grid_auto_flow": null,
|
|||
|
"grid_area": null,
|
|||
|
"grid_template_columns": null,
|
|||
|
"flex": null,
|
|||
|
"_model_name": "LayoutModel",
|
|||
|
"justify_items": null,
|
|||
|
"grid_row": null,
|
|||
|
"max_height": null,
|
|||
|
"align_content": null,
|
|||
|
"visibility": null,
|
|||
|
"align_self": null,
|
|||
|
"height": null,
|
|||
|
"min_height": null,
|
|||
|
"padding": null,
|
|||
|
"grid_auto_rows": null,
|
|||
|
"grid_gap": null,
|
|||
|
"max_width": null,
|
|||
|
"order": null,
|
|||
|
"_view_module_version": "1.2.0",
|
|||
|
"grid_template_areas": null,
|
|||
|
"object_position": null,
|
|||
|
"object_fit": null,
|
|||
|
"grid_auto_columns": null,
|
|||
|
"margin": null,
|
|||
|
"display": null,
|
|||
|
"left": null
|
|||
|
}
|
|||
|
},
|
|||
|
"b597cd6f6cd443aba4bf4491ac7f957e": {
|
|||
|
"model_module": "@jupyter-widgets/controls",
|
|||
|
"model_name": "DescriptionStyleModel",
|
|||
|
"state": {
|
|||
|
"_view_name": "StyleView",
|
|||
|
"_model_name": "DescriptionStyleModel",
|
|||
|
"description_width": "",
|
|||
|
"_view_module": "@jupyter-widgets/base",
|
|||
|
"_model_module_version": "1.5.0",
|
|||
|
"_view_count": null,
|
|||
|
"_view_module_version": "1.2.0",
|
|||
|
"_model_module": "@jupyter-widgets/controls"
|
|||
|
}
|
|||
|
},
|
|||
|
"161969cae25a49f38aacd1568d3cac6c": {
|
|||
|
"model_module": "@jupyter-widgets/base",
|
|||
|
"model_name": "LayoutModel",
|
|||
|
"state": {
|
|||
|
"_view_name": "LayoutView",
|
|||
|
"grid_template_rows": null,
|
|||
|
"right": null,
|
|||
|
"justify_content": null,
|
|||
|
"_view_module": "@jupyter-widgets/base",
|
|||
|
"overflow": null,
|
|||
|
"_model_module_version": "1.2.0",
|
|||
|
"_view_count": null,
|
|||
|
"flex_flow": null,
|
|||
|
"width": null,
|
|||
|
"min_width": null,
|
|||
|
"border": null,
|
|||
|
"align_items": null,
|
|||
|
"bottom": null,
|
|||
|
"_model_module": "@jupyter-widgets/base",
|
|||
|
"top": null,
|
|||
|
"grid_column": null,
|
|||
|
"overflow_y": null,
|
|||
|
"overflow_x": null,
|
|||
|
"grid_auto_flow": null,
|
|||
|
"grid_area": null,
|
|||
|
"grid_template_columns": null,
|
|||
|
"flex": null,
|
|||
|
"_model_name": "LayoutModel",
|
|||
|
"justify_items": null,
|
|||
|
"grid_row": null,
|
|||
|
"max_height": null,
|
|||
|
"align_content": null,
|
|||
|
"visibility": null,
|
|||
|
"align_self": null,
|
|||
|
"height": null,
|
|||
|
"min_height": null,
|
|||
|
"padding": null,
|
|||
|
"grid_auto_rows": null,
|
|||
|
"grid_gap": null,
|
|||
|
"max_width": null,
|
|||
|
"order": null,
|
|||
|
"_view_module_version": "1.2.0",
|
|||
|
"grid_template_areas": null,
|
|||
|
"object_position": null,
|
|||
|
"object_fit": null,
|
|||
|
"grid_auto_columns": null,
|
|||
|
"margin": null,
|
|||
|
"display": null,
|
|||
|
"left": null
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
},
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "YPHN7PJgKOzb"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"# Interacting with CLIP\n",
|
|||
|
"\n",
|
|||
|
"This is a self-contained notebook that shows how to download and run CLIP models, calculate the similarity between arbitrary image and text inputs, and perform zero-shot image classifications."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "53N4k0pj_9qL"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"# Preparation for Colab\n",
|
|||
|
"\n",
|
|||
|
"Make sure you're running a GPU runtime; if not, select \"GPU\" as the hardware accelerator in Runtime > Change Runtime Type in the menu. The next cells will install the `clip` package and its dependencies, and print the installed PyTorch version so you can confirm that it is 1.7.1 or later."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "0BpdJkdBssk9",
|
|||
|
"outputId": "4d9b51f8-d255-4868-97f6-be0a67dadfae"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"! pip install ftfy regex tqdm\n",
|
|||
|
"! pip install git+https://github.com/openai/CLIP.git"
|
|||
|
],
|
|||
|
"execution_count": 1,
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Collecting ftfy\n",
|
|||
|
" Downloading ftfy-6.0.3.tar.gz (64 kB)\n",
|
|||
|
"\u001b[?25l\r\u001b[K |█████ | 10 kB 34.6 MB/s eta 0:00:01\r\u001b[K |██████████▏ | 20 kB 20.7 MB/s eta 0:00:01\r\u001b[K |███████████████▎ | 30 kB 16.4 MB/s eta 0:00:01\r\u001b[K |████████████████████▍ | 40 kB 15.2 MB/s eta 0:00:01\r\u001b[K |█████████████████████████▌ | 51 kB 7.0 MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▋ | 61 kB 8.2 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 64 kB 2.6 MB/s \n",
|
|||
|
"\u001b[?25hRequirement already satisfied: regex in /usr/local/lib/python3.7/dist-packages (2019.12.20)\n",
|
|||
|
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (4.41.1)\n",
|
|||
|
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.7/dist-packages (from ftfy) (0.2.5)\n",
|
|||
|
"Building wheels for collected packages: ftfy\n",
|
|||
|
" Building wheel for ftfy (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
|
|||
|
" Created wheel for ftfy: filename=ftfy-6.0.3-py3-none-any.whl size=41934 sha256=ce00f21233e5e1a5c3e84204d651ae0c17403484418c330c69ebae7297ca7003\n",
|
|||
|
" Stored in directory: /root/.cache/pip/wheels/19/f5/38/273eb3b5e76dfd850619312f693716ac4518b498f5ffb6f56d\n",
|
|||
|
"Successfully built ftfy\n",
|
|||
|
"Installing collected packages: ftfy\n",
|
|||
|
"Successfully installed ftfy-6.0.3\n",
|
|||
|
"Collecting git+https://github.com/openai/CLIP.git\n",
|
|||
|
" Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-tbmjxrgj\n",
|
|||
|
" Running command git clone -q https://github.com/openai/CLIP.git /tmp/pip-req-build-tbmjxrgj\n",
|
|||
|
"Requirement already satisfied: ftfy in /usr/local/lib/python3.7/dist-packages (from clip==1.0) (6.0.3)\n",
|
|||
|
"Requirement already satisfied: regex in /usr/local/lib/python3.7/dist-packages (from clip==1.0) (2019.12.20)\n",
|
|||
|
"Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from clip==1.0) (4.41.1)\n",
|
|||
|
"Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from clip==1.0) (1.9.0+cu102)\n",
|
|||
|
"Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (from clip==1.0) (0.10.0+cu102)\n",
|
|||
|
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.7/dist-packages (from ftfy->clip==1.0) (0.2.5)\n",
|
|||
|
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->clip==1.0) (3.7.4.3)\n",
|
|||
|
"Requirement already satisfied: pillow>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision->clip==1.0) (7.1.2)\n",
|
|||
|
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torchvision->clip==1.0) (1.19.5)\n",
|
|||
|
"Building wheels for collected packages: clip\n",
|
|||
|
" Building wheel for clip (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
|
|||
|
" Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369080 sha256=9059060f3a406b8268cfcbbf03647ab9b59e660d03f86c6120a9dd06f2b82cec\n",
|
|||
|
" Stored in directory: /tmp/pip-ephem-wheel-cache-7v3mrcvj/wheels/fd/b9/c3/5b4470e35ed76e174bff77c92f91da82098d5e35fd5bc8cdac\n",
|
|||
|
"Successfully built clip\n",
|
|||
|
"Installing collected packages: clip\n",
|
|||
|
"Successfully installed clip-1.0\n"
|
|||
|
],
|
|||
|
"name": "stdout"
|
|||
|
}
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"metadata": {
|
|||
|
"id": "C1hkDT38hSaP",
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"outputId": "70a44964-883d-4fd0-b95a-2c7f2b19aca9"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"import torch\n",
|
|||
|
"from pkg_resources import packaging\n",
|
|||
|
"\n",
|
|||
|
"print(\"Torch version:\", torch.__version__)\n"
|
|||
|
],
|
|||
|
"execution_count": 2,
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Torch version: 1.9.0+cu102\n"
|
|||
|
],
|
|||
|
"name": "stdout"
|
|||
|
}
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "eFxgLV5HAEEw"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"# Loading the model\n",
|
|||
|
"\n",
|
|||
|
"`clip.available_models()` will list the names of available CLIP models."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"metadata": {
|
|||
|
"id": "uLFS29hnhlY4",
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"outputId": "11779e1e-8bdd-4167-c18e-d26bdd6b67db"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"import clip\n",
|
|||
|
"\n",
|
|||
|
"clip.available_models()"
|
|||
|
],
|
|||
|
"execution_count": 3,
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"output_type": "execute_result",
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"tags": []
|
|||
|
},
|
|||
|
"execution_count": 3
|
|||
|
}
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "IBRVTY9lbGm8",
|
|||
|
"outputId": "f06fd2fd-6126-475b-87d0-b10aa3b7da49"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"model, preprocess = clip.load(\"ViT-B/32\")\n",
|
|||
|
"model.cuda().eval()\n",
|
|||
|
"input_resolution = model.visual.input_resolution\n",
|
|||
|
"context_length = model.context_length\n",
|
|||
|
"vocab_size = model.vocab_size\n",
|
|||
|
"\n",
|
|||
|
"print(\"Model parameters:\", f\"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}\")\n",
|
|||
|
"print(\"Input resolution:\", input_resolution)\n",
|
|||
|
"print(\"Context length:\", context_length)\n",
|
|||
|
"print(\"Vocab size:\", vocab_size)"
|
|||
|
],
|
|||
|
"execution_count": 4,
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"100%|███████████████████████████████████████| 338M/338M [00:05<00:00, 63.0MiB/s]\n"
|
|||
|
],
|
|||
|
"name": "stderr"
|
|||
|
},
|
|||
|
{
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Model parameters: 151,277,313\n",
|
|||
|
"Input resolution: 224\n",
|
|||
|
"Context length: 77\n",
|
|||
|
"Vocab size: 49408\n"
|
|||
|
],
|
|||
|
"name": "stdout"
|
|||
|
}
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "21slhZGCqANb"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"# Image Preprocessing\n",
|
|||
|
"\n",
|
|||
|
"We resize the input images and center-crop them to conform with the image resolution that the model expects. After doing so, we convert them to tensors and normalize the pixel intensity using the dataset mean and standard deviation.\n",
|
|||
|
"\n",
|
|||
|
"The second return value from `clip.load()` contains a torchvision `Transform` that performs this preprocessing.\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"metadata": {
|
|||
|
"id": "d6cpiIFHp9N6",
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"outputId": "880cb98e-1e5e-430e-8b59-4bf35fa554f9"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"preprocess"
|
|||
|
],
|
|||
|
"execution_count": 5,
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"output_type": "execute_result",
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"Compose(\n",
|
|||
|
" Resize(size=224, interpolation=bicubic, max_size=None, antialias=None)\n",
|
|||
|
" CenterCrop(size=(224, 224))\n",
|
|||
|
" <function _transform.<locals>.<lambda> at 0x7f3a24ffb440>\n",
|
|||
|
" ToTensor()\n",
|
|||
|
" Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))\n",
|
|||
|
")"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"tags": []
|
|||
|
},
|
|||
|
"execution_count": 5
|
|||
|
}
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "xwSB5jZki3Cj"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"# Text Preprocessing\n",
|
|||
|
"\n",
|
|||
|
"We use a case-insensitive tokenizer, which can be invoked using `clip.tokenize()`. By default, the outputs are padded to become 77 tokens long, which is what the CLIP models expect."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/"
|
|||
|
},
|
|||
|
"id": "qGom156-i2kL",
|
|||
|
"outputId": "050b0ce1-caba-47e1-f4ac-dba994599718"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"clip.tokenize(\"Hello World!\")"
|
|||
|
],
|
|||
|
"execution_count": 6,
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"output_type": "execute_result",
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"tensor([[49406, 3306, 1002, 256, 49407, 0, 0, 0, 0, 0,\n",
|
|||
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|||
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|||
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|||
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|||
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|||
|
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
|||
|
" 0, 0, 0, 0, 0, 0, 0]])"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"tags": []
|
|||
|
},
|
|||
|
"execution_count": 6
|
|||
|
}
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "4W8ARJVqBJXs"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"# Setting up input images and texts\n",
|
|||
|
"\n",
|
|||
|
"We are going to feed 8 example images and their textual descriptions to the model, and compare the similarity between the corresponding features.\n",
|
|||
|
"\n",
|
|||
|
"The tokenizer is case-insensitive, and we can freely give any suitable textual descriptions."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"metadata": {
|
|||
|
"id": "tMc1AXzBlhzm"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"import os\n",
|
|||
|
"import skimage\n",
|
|||
|
"import IPython.display\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"from PIL import Image\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"\n",
|
|||
|
"from collections import OrderedDict\n",
|
|||
|
"import torch\n",
|
|||
|
"\n",
|
|||
|
"%matplotlib inline\n",
|
|||
|
"%config InlineBackend.figure_format = 'retina'\n",
|
|||
|
"\n",
|
|||
|
"# images in skimage to use and their textual descriptions\n",
|
|||
|
"descriptions = {\n",
|
|||
|
" \"page\": \"a page of text about segmentation\",\n",
|
|||
|
" \"chelsea\": \"a facial photo of a tabby cat\",\n",
|
|||
|
" \"astronaut\": \"a portrait of an astronaut with the American flag\",\n",
|
|||
|
" \"rocket\": \"a rocket standing on a launchpad\",\n",
|
|||
|
" \"motorcycle_right\": \"a red motorcycle standing in a garage\",\n",
|
|||
|
" \"camera\": \"a person looking at a camera on a tripod\",\n",
|
|||
|
" \"horse\": \"a black-and-white silhouette of a horse\", \n",
|
|||
|
" \"coffee\": \"a cup of coffee on a saucer\"\n",
|
|||
|
"}"
|
|||
|
],
|
|||
|
"execution_count": 7,
|
|||
|
"outputs": []
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"metadata": {
|
|||
|
"id": "NSSrLY185jSf",
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/",
|
|||
|
"height": 368
|
|||
|
},
|
|||
|
"outputId": "06451963-5ecb-4ddc-d0a8-24e9b110af7d"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"original_images = []\n",
|
|||
|
"images = []\n",
|
|||
|
"texts = []\n",
|
|||
|
"plt.figure(figsize=(16, 5))\n",
|
|||
|
"\n",
|
|||
|
"for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(\".png\") or filename.endswith(\".jpg\")]:\n",
|
|||
|
" name = os.path.splitext(filename)[0]\n",
|
|||
|
" if name not in descriptions:\n",
|
|||
|
" continue\n",
|
|||
|
"\n",
|
|||
|
" image = Image.open(os.path.join(skimage.data_dir, filename)).convert(\"RGB\")\n",
|
|||
|
" \n",
|
|||
|
" plt.subplot(2, 4, len(images) + 1)\n",
|
|||
|
" plt.imshow(image)\n",
|
|||
|
" plt.title(f\"{filename}\\n{descriptions[name]}\")\n",
|
|||
|
" plt.xticks([])\n",
|
|||
|
" plt.yticks([])\n",
|
|||
|
"\n",
|
|||
|
" original_images.append(image)\n",
|
|||
|
" images.append(preprocess(image))\n",
|
|||
|
" texts.append(descriptions[name])\n",
|
|||
|
"\n",
|
|||
|
"plt.tight_layout()\n"
|
|||
|
],
|
|||
|
"execution_count": 8,
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"output_type": "display_data",
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAACLgAAAK+CAYAAACrNjW2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAWJQAAFiUBSVIk8AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOydebgcRdX/P18SCEsIIezKEsQFBBUBEVmjAqLoC7IoKkoURF9eFxRBQdAoKqAiijtrXJBFRFBREJGwqCjyY1cQhAiykz0hC0nO749Tnen07Z470zP3zr3kfJ6nn57p2k5XV1edqj5VJTMjCIIgCIIgCIIgCIIgCIIgCIIgCIIgCIIgCIYqK/VagCAIgiAIgiAIgiAIgiAIgiAIgiAIgiAIgiBoRhi4BEEQBEEQBEEQBEEQBEEQBEEQBEEQBEEQBEOaMHAJgiAIgiAIgiAIgiAIgiAIgiAIgiAIgiAIhjRh4BIEQRAEQRAEQRAEQRAEQRAEQRAEQRAEQRAMacLAJQiCIAiCIAiCIAiCIAiCIAiCIAiCIAiCIBjShIFLEARBEARBEARBEARBEARBEARBEARBEARBMKQJA5cgCIIgCIIgCIIgCIIgCIIgCIIgCIIgCIJgSBMGLkEQBEEQBEEQBEEQBEEQBEEQBEEQBEEQBMGQJgxcgiAIgiAIgiAIgiAIgiAIgiAIgiAIgiAIgiFNGLgEQRAEQRAEQRAEQRAEQRAEQRAEQRAEQRAEQ5owcAmCIAiCIAiCIAiCIAiCIAiCIAiCIAiCIAiGNGHgEgRBEARBEARBEARBEARBEARBEARBEARBEAxpwsAlCIIgCIIgCIIgCIIgCIIgCIIgCIIgCIIgGNKEgUsQBEEQBEEQBEEQBEEQBMEQQNIkSSZpcq9laZXhKHMQBEEQBA0kTU5t+aQVId0gCIY3I3stQBAEQRAEQRAEvUXSBGACcLuZXd5baYYXko4GxgKTzWxqj8UJgiAIgiBoC0kTgfHA5WZ2e2+lCYIgCIIgCIIgaE6s4BIEQRAEQRAEwQTg88D+PZZjOHI0nnfjeyxHEARBEARBHSbiusy2HcTxDHAf8Hg3BAqCIAiCIAiCIKgiVnAJgiAIgiAIgiAIgiAIgiAIamFm3wG+02s5giAIgiAIgiB4/hMruARBEARBEARBEARBEARBEARBEARBEARBEARDmjBwCYIgCIIgCIIhiqR1JR0l6QpJ90qaI2mepH9I+oakF1SEW1/S1yTdnfwvkPSIpD9L+qKkzZK/8ZIMX5Ye4DBJVjjG5/0m/0jaSdKlkh6XtETSNwsyvF7SZZKekLQonX8p6Q1N7ndZmpI2lXS2pP9KWijpIUlflzSmy3k1IaU5tYlcE5OfKblrk1JebJYuXVfItyllcVXEPyWFmShpbUlnSHowPbf/SjpL0kYVYSelsJPT/8Mk/TXd/2xJ10naq5/0Xy7pYklPSZqf8u8LklYtxh8EQRAEzzckTU1t3QRJL5T0vdQOL5R0e87fFpJ+mGujZ0i6QdIRkkb0k8Ymkk5PutmcdPxD0rmSXt+mvMcneRdI2q/gtp6kUyTdJWlu0oXulvRlSeMKficmXWaPdOn8gi4ztQ2ZKvWFgn63jaSLkl64IOkcJ0ka1U/8mX4zT9L0pN+8Nbkte36tyhsEQRAEKwqStpL0A0n/kvSspJlJTzhT0vYVYUZIOlrSHSnMdEm/kbRDP2mNlnSCpFskzUpt/f0prU1qyN7S2FZJuJb1oVyYWmNKLdzDcuNJSae5WT5eM0vStZL2qQhbHIer0qNWaZL+aklPuy+FeTzFsU0x/iAYTsQWRUEQBEEQBEEwdPkMcEz6vRiYDawFbJWOQyXtaWZ3ZgFSB/8vQGYQsSSFeyGwMfA64DHgB8ntSWA0sAawAJhVkGFJUShJ7wR+ivcnZhX9SPoS8Nn015Kf9YH9gf0lnWpmxze571cB5wHjgDm4Yf
74lBd7SNrZzJ4rhGk7rzpkLp536yX5ZgCLcu7Ta8S5DnALsAUwH7+PFwIfxPNtDzP7Z1VgSecAh+PPYx4wBpgA7C7pHWb2i5IwewK/BlZNl2YDmwOfA/YGptS4jyAIgiAYjrwU+DmwLvAssEzXSMYUP6fRXs7Cdafd0vFOSfub2bxipJIOBH4CrJYuLcDb+S1xHeWNuJ7TL5JOA47D2/n9zOzanNuuwBW4/gSulywFtk7HeyXtZWb3Jff5uC4zDlgZ1wHm55J7uhWZ2mBn4Cw832YDAl4GfBF4S5JtbjGQpLOBI9Lfpem+9gAmSDq6yzIGQRAEwfMGSR8FzgAyQ9x5+BjNNul4JT5mkGckcCXwJlwXWgisDewLvFHSG8zsLyVpbQX8jsYkoMUp7IuBj+JjMm8zsz+1KHs7Y1v5cO3qQxkDPqYk6Qzg6CRPFv8bgDdIOtbMvt4k7N7A5bg+OQvX3TI9ant8vK0YZi3g2uQOnherA+8E3gocWfdegqDXxAouQRAEQRAEQTB0eRg4AR90WM3M1gFGATsAV+PGFT+TpFyYz+MDAA8AuwOrmNk4vBP8CuBLwBMAZvaImW0IZJ3oi81sw8LxSIlc5+ADBpub2Vi8g/xNAEmH0DBu+Q6wvpmtnWT9drr+GUmHNrnvycDtwCvMbAxugHM4PjiyA27w0Y28qo2ZfT3lXZY/BxTy7YAa0Z4ErAm8DRhtZqPxwaaHcPl/LmnlirD7Ae8B/hcYY2ZrAS8CbsD7fd+WtNwEB0nrAhfhH+v+huf3Wnh+vwcf8PpwjfsIgiAIguHI6cDjwC5mtkZqhw+StAWN9vJ6YMuk/6wJfAjXT/YEvlWMUNLOKexqwHXAjsDqSTdbC3g78Mf+BJO0kqQf4MYtM4G9CsYtm+EGq+OA7wMvSWmuget/vwc2AS5TWm3GzC5OusyfUzQfL+gyr2kx31rle8A/gFcmfWNN4P24Uc1OwDdK7vv9NIxbTgHGJb1yQ+Bc4Gu4jhQEQRAEQQ5JBwNn4sYtlwIvN7PRqR1dBzgUuLUk6P8Br8GNIEab2Zr4JKS7cV2oTN9ZC/gtbtzy8+R/1aRLbQH8DDeS+YWksS3eQstjWzk52taHcgz0mNKrceOW02joMy8ELkjuX03GOVVcnO4tG4cbAxyPGyztJ+ktJWHOxI1b5gHvxZ/nWvhYz13Ad2veSxD0nDBwCYIgCIIgCIIhipmdaWanmNldZrY4XVtiZrfiBg3/wGeg7J4LtlM6n2hmN5rZ0hRuoZndbWYnmdnlHYp2B/AOM5ua4l5sZlNTR//k5OciM/uomT2T/Ewzs48BFyb3kyVV9UceBd5iZnfnZD8PODu5H1QMUDOvhhpjgAPN7De553Y98GZ8ps3W+CBTGWOBI8zsB2b2bAr7EPCuFHYjfOZ0no/iA1tPAW/K5fdzZvYz3Kio1cGnIAiCIBjuLMYNRzKDD8zsAfxjxxrAv3H95L7kttDMzgI+lrx/QNKLC3Gegc+EvgFva28xM0vh55jZ5Wb2gWZCJePWC3BjmqeACSUzp7+Mt9mnmtlRZvaAmS1Nx9248eydwMtxo5pesBDYx8zuAjCzRWY2GTgquR8uadPMc9IrP5f+nm1mJ5jZrBT2KTM7ArgGN7QOgiAIgiCRdIcz0t8Lzezg/GqwZjbdzC4ws2NKgo/FV4m7xMwWJf93AhOT+2vy7XXiWHw1ugvN7B1mdqeZLUlhHzSz9wBXARvQMFztjzpjW7X1oUEYUxoDnGNmn8npM4/jhifX4SvbTWoS/hbgkNw43DwzOxVfbQcK42SSXpTiBjjSzH5qaSVkM7sH2IflV+4LgmFFGLgEQRAEQRAEwTDEzBbig/oAu+ScZqfzRgwcp2eDCwW2xZefBZ9NU8YX0nk8Pou5jG+k+yuSDV5s04qQGU3yaqhxo5ndVLyYPqRdmv72Me5JPIzPiiqGfQxfnQX65lu2ysxZZjazJOwlwIMtyB0EQRAEzwd+bGZP5i8kI4
sD098zMiPSAufgxrki105L2pKGrnOc9d1esV8krQb8EjgEXzVuNzO7o+BndeBgfLn7PquggBuT0NAl9mpXj
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1152x360 with 8 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"tags": [],
|
|||
|
"image/png": {
|
|||
|
"width": 1116,
|
|||
|
"height": 351
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "WEVKsji6WOIX"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"## Building features\n",
|
|||
|
"\n",
|
|||
|
"We stack the preprocessed (already normalized) images into a single batch, tokenize each text input, and run the forward pass of the model to get the image and text features."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"metadata": {
|
|||
|
"id": "HBgCanxi8JKw"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"image_input = torch.tensor(np.stack(images)).cuda()\n",
|
|||
|
"text_tokens = clip.tokenize([\"This is \" + desc for desc in texts]).cuda()"
|
|||
|
],
|
|||
|
"execution_count": 9,
|
|||
|
"outputs": []
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"metadata": {
|
|||
|
"id": "ZN9I0nIBZ_vW"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"with torch.no_grad():\n",
|
|||
|
" image_features = model.encode_image(image_input).float()\n",
|
|||
|
" text_features = model.encode_text(text_tokens).float()"
|
|||
|
],
|
|||
|
"execution_count": 10,
|
|||
|
"outputs": []
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "cuxm2Gt4Wvzt"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"## Calculating cosine similarity\n",
|
|||
|
"\n",
|
|||
|
"We normalize the features and calculate the dot product of each pair."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"metadata": {
|
|||
|
"id": "yKAxkQR7bf3A"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"image_features /= image_features.norm(dim=-1, keepdim=True)\n",
|
|||
|
"text_features /= text_features.norm(dim=-1, keepdim=True)\n",
|
|||
|
"similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T"
|
|||
|
],
|
|||
|
"execution_count": 11,
|
|||
|
"outputs": []
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"metadata": {
|
|||
|
"id": "C5zvMxh8cU6m",
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/",
|
|||
|
"height": 831
|
|||
|
},
|
|||
|
"outputId": "22bca748-ab42-4888-c9da-8f22c21c6185"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"count = len(descriptions)\n",
|
|||
|
"\n",
|
|||
|
"plt.figure(figsize=(20, 14))\n",
|
|||
|
"plt.imshow(similarity, vmin=0.1, vmax=0.3)\n",
|
|||
|
"# plt.colorbar()\n",
|
|||
|
"plt.yticks(range(count), texts, fontsize=18)\n",
|
|||
|
"plt.xticks([])\n",
|
|||
|
"for i, image in enumerate(original_images):\n",
|
|||
|
" plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin=\"lower\")\n",
|
|||
|
"for x in range(similarity.shape[1]):\n",
|
|||
|
" for y in range(similarity.shape[0]):\n",
|
|||
|
" plt.text(x, y, f\"{similarity[y, x]:.2f}\", ha=\"center\", va=\"center\", size=12)\n",
|
|||
|
"\n",
|
|||
|
"for side in [\"left\", \"top\", \"right\", \"bottom\"]:\n",
|
|||
|
" plt.gca().spines[side].set_visible(False)\n",
|
|||
|
"\n",
|
|||
|
"plt.xlim([-0.5, count - 0.5])\n",
|
|||
|
"plt.ylim([count + 0.5, -2])\n",
|
|||
|
"\n",
|
|||
|
"plt.title(\"Cosine similarity between text and image features\", size=20)"
|
|||
|
],
|
|||
|
"execution_count": 12,
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"output_type": "execute_result",
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"Text(0.5, 1.0, 'Cosine similarity between text and image features')"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"tags": []
|
|||
|
},
|
|||
|
"execution_count": 12
|
|||
|
},
|
|||
|
{
|
|||
|
"output_type": "display_data",
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAACB0AAAY5CAYAAAATvfRQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAWJQAAFiUBSVIk8AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd5gkV3no/+/b3TOzSbvKCJRWgAgi2iLDBUnYgA02YAzYJlgkG2OM7wUM92ewr8ABA7KNL2BsTBBOmGRsw8WIKBBBiCBAZBASiiitwqaZ6XB+f5zq3eqank4Td/X9PE89011zqupU6FPddd5zTqSUkCRJkiRJkiRJkiRJGldtrTMgSZIkSZIkSZIkSZIOTAYdSJIkSZIkSZIkSZKkiRh0IEmSJEmSJEmSJEmSJmLQgSRJkiRJkiRJkiRJmohBB5IkSZIkSZIkSZIkaSIGHUiSJEmSJEmSJEmSpIkYdCBJkiRJkiRJkiRJkiZi0IEkSZIkSZIkSZIkSZqIQQeSJEmSJEmSJEmSJGkiBh1IkiRJkiRJkiRJkqSJGHQgSZIkSZIkSZIkSZImYtCBJEmSJEmSJEmSJEmaiEEHkiRJkiRJkiRJkiRpIgYdSJIkSZK0QiJia0ScFhHPjog/iIhXRsSLIuIZEfGQiNi41nlcaRGRStM5a52f24riuisf+zPXOk+TGGc/ImJ7Je1Zq5fTlXFb/vxExGWlfT9vrfMjSQe6iDinfF9ZgfVbbmtkETFV/Cb6YET8JCJ2Vb73vGGt8yhpPI21zoAkSZIkSQeTiNgMnAk8DXgggwP+WxHxFeCfgXenlHasfA4lSZIkaW1ExInAfwH3Xuu8SFo+9nQgSZIkSdIyiYhnA5cBbwIezPDf3Q3gQUX6KyPizyLikBXNpKSDki1M16eVblmsA1tEnFe6Pi5b6/yspog4q9Kqefta50nSyouIKeA/uI0FHBwsvZBJg9jTgSRJkiRJS1QMk/CPwK/2+XcH+DbwU+AGYCtwe+BkoBxgsBH4Q+B+wKNXMr+SJEmStAaeBNy39P5i4M+Kv3tK829dzUxJWjqDDiRJkiRJWoKImAE+ApxW+df3gNcAH0kp3dBnuWngdOApwDPZ/xt9ZsUyK0mSJElr51dKr+eAn08pXbtWmZG0fAw6kCRJkiRpac6mN+AgkXssODul1FpsoZTSPHAucG5EvBZ4HfD4FcznmkgpxVrn4bYopXQecMAf+4NlPybl50eSdKBIKW1f6zzogHBq6fUXDTiQDh7DxpaUJEmSJEmLiIjHAy8szUrAs1JKfzEo4KAqpfSDlNITgJcCIy8nSZIkSQeQo0uvr16zXEhadvZ0IEmSJEnSBCKiBvxVZfabU0rvmnSdKaW/jIgPLi1nkiRJkrQubSm9bq5ZLiQtO4MOJEmSJEmazBOBO5beXw3876WuNKX043HSR8R9gHuSWw1NA9cBlwJfKIZwmEhEHAn8LHAnYBtQB3aX1v/NlNLuSdc/YZ7uB9wNOBbYC1wFnJdSunEZ1n0o8FDgDsCR5H39Kbnb1yuWuv4x83IScF/gOOAQoFPk5yrgEuDb4/SksQz52Qw8HDiefGx2ABeklL4+ZLkNwMOAU8j7cQPwzWLZtKKZHlNENIC7F9PtyQ/E95D39VvA11NK7WXe5rHA/YvtHQ7cCLw7pXTLcm7nYBURDwDuQv7M7gYuBz6dUtq1DOueBh4MbCeXrR1y2ffNlNI3lrr+g0FE1MnX78nAUeTnzNcB3wMuTCl1lmEbq1Yur+T9ZbVFRAD3IZe9RwMbyOfmEvJ3Ayv5WLNy//bAg8j30w3A9eTPy7eXuN4Z8lBfJwGHAtcAPyaf72Xdh9VS+kzegdwL2HfJn8m9Q5a7F7lsuh0wC1wGfCKltHPCfARwV/J10v1eNke+Tn4AfHkp37dL27kd+bvWseTy9Ery972Ll7ruPtu6K/l75tHk6/4G4CfA+c
OO73pTfEd9KPncHEU+N9cBX0kp/WCJ674juRw9EdhKvg53kH8HXZBS2rOU9R+IirLmYeQy7Bhgnvy5/NqQ5Vbkfr4ef6/e5qWUnJycnJycnJycnJycnJycxpyA88jDKXSnV63itjcCfwhcUclDedoJvBM4bsx1Pxz4GNAesO5EfvB2ITnQojFgfeVlzhmy7b5pgWcA31kkH23g3cDxEx7LRxXnsjVgXy8CfnmFz2kAzwa+MeS4J/LDtI8CTxmwvtMqy5w5blryQ8G3ALcuko8LgZ9d5Pr8U+CWRZb7MfCLIx6XcfZjeyXtWUPWvQ14FvChAfvYnW4B/hq4wxjn9JzyOkrzHwp8kv6fsfuO8/kBzhrheuk3bS+W/7PK/F+Y4Np9UWUdz1imz8RlpXWeV8yrAb9Drujpt197gLcCh024zbsC/0ouPxc7dlcCLwamx7huR53OLK3jotL8L42Q9+q53DMoj8UyT6ws86gRtnMs8HfkIJnF9uN6chlwyITnYdnL5cU+S6zg/WVAXs6c8Po4bch6DwNeSw6CXGwdtwJvBo4asq53VZZ76xj795rKsh8Covjf9gn3/axlOvZrVe7fFfhPcqvuftv6LvDYCfZnY3HOb15kvVeSvy9OD8rfMl7bl5XWf964acnfhZ5bHI9++7MD+L1F1vc44OJFlpstrssNYxzXpwDvJVfID7pO9gDvAE6e8JjdFfgIi5d3FwG/Mskx7rNP/x/5O9hi+7IX+BeK7wjLeF2cOeQY9pvOGbLO+wMfLs7tYuv4QbHt2oj5nAIeSy7/rhqSv3ngA/T5HjzkWh91WnBul3Dut1fWfdaAtKdV0p5ZzD8U+Fvgpj55fcOA9a3I7yyW+feq0/JNa54BJycnJycnJycnJycnJ6cDbQI2sfDB8fZV2vYp5JYboz602gM8bcR1/+kED8UScOiAdY7zALEnLbnnhn8cMQ9XA6eMcRwPIVcAjLOf7wVmVuh6OneC4/71Aes8rZL2zHHSklv0DQpq6U67gTNK6zqG0QInOsBvjnBsxtmP7ZW0Zw1Z92cnOOY3lvd3yPrPKS9bzPsDBj94Xe2gg+30PrD9wATXb/l83wRsXKbPxWWl9Z5HLg/+Y8T9uwa4zxjbCuDVLF4h2G/6FotURve5bkedziyt4+zS/Bawbcg+XNBnfY8YssybSmnngE1D0j+ffE8ZdX+uAO41xnlYsXK5ssw5rOD9ZYS8nDnh9XHagHU+gf6VQYtNNwOnD1jfZnKvFeVlnjrCvj2KXL53l7kSOKL0/+0T7vtZy3Ts16Lc/1UGBzKVp/81xr6cwOIBWNXp8+RKwwX5W86JJQQdkFv3/9uI+/OWyrpeO+Jy/82QYKxifaOWDeVpD/DrYx6vX2dwpXl5OnvcY1zazkMYXolenmbH3Zch2z9zguN5ziLrmiIHF46zrvMY8FultO5XT5DPFvCSMa71kfO8lM9XZbntlXWfNSDtaZW0Z5J7xRi0DwuCDljZ+/my/151Wr7J4RUkSZIkSRrfg+gdsvAnKaXLVnqjxVAKnyJ3xV52KbkCbJY85MPPkivRILds+qeI2JxSeuuAdT8XeEVl9hzwdXLF0Sy5C9bbAfcgdzO60t5MboUKuTLwy+QKjCngXsCdS2lvD7wvIn4mDenmtuiK8+Pkh2hlNwJfI7dq21xs46TS/58MbIuIX0jL0HV4yVvJFTVlN5Erc68lP9DcSu7e9W7kyrKVdAS5J4XjivfXAF8lV5qcADyQ/df/JuC9RVe9e8nBE/cu7cOF5JaJR5Nb+G8o/hfAWyPii2mJ3d8uQa3y/lpyi+cd5Ov9UHKQT/kaOBz4SEQ8MI3ZzX5EPBV4XWnWJcX29pC7m33AWLlfBimlyyLi48Cji1m/FBFHpZSuH2X5YoiDe5dm/Wtaua6Z3ww8vnidyJ/VHwMzRR62l9IeA3w8Ih6aUvrhoJUWXWe/i/1lTdfeYhtXF+/vTC4zumXrPYAvRMT9U0o/nWSHhv
gU8JLidR14BPBf/RJGxFbgfn3+9UjgMwO28cjS64FdRUfEn7LwHtEkt1C8glxOnVjko1s+HAecHxEPSyl9a
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1440x1008 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"tags": [],
|
|||
|
"image/png": {
|
|||
|
"width": 1038,
|
|||
|
"height": 796
|
|||
|
},
|
|||
|
"needs_background": "light"
|
|||
|
}
|
|||
|
}
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {
|
|||
|
"id": "alePijoXy6AH"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"# Zero-Shot Image Classification\n",
|
|||
|
"\n",
|
|||
|
"You can classify images by using the cosine similarity between the image and text features (scaled by 100) as the logits of a softmax operation."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"metadata": {
|
|||
|
"id": "Nqu4GlfPfr-p",
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/",
|
|||
|
"height": 102,
|
|||
|
"referenced_widgets": [
|
|||
|
"1369964d45004b5e95a058910b2a33e6",
|
|||
|
"12e23e2819094ee0a079d4eb77cfc4f9",
|
|||
|
"7a5f52e56ede4ac3abe37a3ece007dc9",
|
|||
|
"ce8b0faa1a1340b5a504d7b3546b3ccb",
|
|||
|
"5e6adc4592124a4581b85f4c1f3bab4d",
|
|||
|
"4a61c10fc00c4f04bb00b82e942da210",
|
|||
|
"b597cd6f6cd443aba4bf4491ac7f957e",
|
|||
|
"161969cae25a49f38aacd1568d3cac6c"
|
|||
|
]
|
|||
|
},
|
|||
|
"outputId": "ca7a0e3c-e267-4e6e-8a1b-bbab3c0a2462"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"from torchvision.datasets import CIFAR100\n",
|
|||
|
"\n",
|
|||
|
"cifar100 = CIFAR100(os.path.expanduser(\"~/.cache\"), transform=preprocess, download=True)"
|
|||
|
],
|
|||
|
"execution_count": 13,
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /root/.cache/cifar-100-python.tar.gz\n"
|
|||
|
],
|
|||
|
"name": "stdout"
|
|||
|
},
|
|||
|
{
|
|||
|
"output_type": "display_data",
|
|||
|
"data": {
|
|||
|
"application/vnd.jupyter.widget-view+json": {
|
|||
|
"model_id": "1369964d45004b5e95a058910b2a33e6",
|
|||
|
"version_minor": 0,
|
|||
|
"version_major": 2
|
|||
|
},
|
|||
|
"text/plain": [
|
|||
|
"HBox(children=(FloatProgress(value=0.0, max=169001437.0), HTML(value='')))"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"tags": []
|
|||
|
}
|
|||
|
},
|
|||
|
{
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"\n",
|
|||
|
"Extracting /root/.cache/cifar-100-python.tar.gz to /root/.cache\n"
|
|||
|
],
|
|||
|
"name": "stdout"
|
|||
|
}
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"metadata": {
|
|||
|
"id": "C4S__zCGy2MT"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"text_descriptions = [f\"This is a photo of a {label}\" for label in cifar100.classes]\n",
|
|||
|
"text_tokens = clip.tokenize(text_descriptions).cuda()"
|
|||
|
],
|
|||
|
"execution_count": 14,
|
|||
|
"outputs": []
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"metadata": {
|
|||
|
"id": "c4z1fm9vCpSR"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"with torch.no_grad():\n",
|
|||
|
" text_features = model.encode_text(text_tokens).float()\n",
|
|||
|
" text_features /= text_features.norm(dim=-1, keepdim=True)\n",
|
|||
|
"\n",
|
|||
|
"text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)\n",
|
|||
|
"top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)"
|
|||
|
],
|
|||
|
"execution_count": 15,
|
|||
|
"outputs": []
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"metadata": {
|
|||
|
"colab": {
|
|||
|
"base_uri": "https://localhost:8080/",
|
|||
|
"height": 931
|
|||
|
},
|
|||
|
"id": "T6Ju_6IBE2Iz",
|
|||
|
"outputId": "e1a155dc-474d-409c-e03d-d41b804648c3"
|
|||
|
},
|
|||
|
"source": [
|
|||
|
"plt.figure(figsize=(16, 16))\n",
|
|||
|
"\n",
|
|||
|
"for i, image in enumerate(original_images):\n",
|
|||
|
" plt.subplot(4, 4, 2 * i + 1)\n",
|
|||
|
" plt.imshow(image)\n",
|
|||
|
" plt.axis(\"off\")\n",
|
|||
|
"\n",
|
|||
|
" plt.subplot(4, 4, 2 * i + 2)\n",
|
|||
|
" y = np.arange(top_probs.shape[-1])\n",
|
|||
|
" plt.grid()\n",
|
|||
|
" plt.barh(y, top_probs[i])\n",
|
|||
|
" plt.gca().invert_yaxis()\n",
|
|||
|
" plt.gca().set_axisbelow(True)\n",
|
|||
|
" plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])\n",
|
|||
|
" plt.xlabel(\"probability\")\n",
|
|||
|
"\n",
|
|||
|
"plt.subplots_adjust(wspace=0.5)\n",
|
|||
|
"plt.show()"
|
|||
|
],
|
|||
|
"execution_count": 16,
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"output_type": "display_data",
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAByoAAAckCAYAAAATVYMmAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAWJQAAFiUBSVIk8AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzda6xleXrX9+/z/6+9z61O3au7p7unp2c8Ho+xHMAWAZFAhkuAKEoIEREhgEJ4ByFKeIGSSESQ+E0SIaIALxIiRQJCgoOSkEgIAoGMiY0ReAZjG8Z0M93T012X7rqe676t///Ji+dZ+9TcxPTYXT0e/z5Sq6pOnbPP2mutXZozv/17HnN3RERERERERERERERERESepfJhH4CIiIiIiIiIiIiIiIiI/OKjoFJEREREREREREREREREnjkFlSIiIiIiIiIiIiIiIiLyzCmoFBEREREREREREREREZFnTkGliIiIiIiIiIiIiIiIiDxzCipFRERERERERERERERE5JlTUCkiIiIiIiIiIiIiIiIiz5yCShERERERERERERERERF55hRUioiIiIiIiIiIiIiIiMgzp6BSRERERERERERERERERJ45BZUiIiIiIiIiIiIiIiIi8swpqBQRERERERERERERERGRZ274sA9AREREfu7M7E3gMvClD/lQRES+HbwKHLv7xz/sAxERERGRbw/6uVlE5Cu8yrfJz80KKkVERL4zXN7Z2bn+fd/3fdc/7AOR9+/k5ASAw8PDD/lI5P3Stfv29IUvfIHFYvFhH4aIiIiIfHvRz83vk37eeX90vt4fna/37+fznH07/dysoFJEROQ7w5deeeWV65/73Oc+7OOQb8FnP/tZAD7zmc98qMch75+u3benH/zBH+Tzn//8lz7s4xARERGRbyv6ufl90s8774/O1/uj8/X+/Xyes2+nn5u1o1JEREREREREREREREREnjkFlSIiIiIiIiIiIiIiIiLyzCmoFBEREREREREREREREZFnTkGliIiIiIiIiIiIiIiIiDxzCipFRERERERERERERERE5JlTUCkiIiIiIiIiIiIiIiIiz5yCShERERERERERERERERF55hRUioiIiIiIiIiIiIiIiMgzp6BSRERERERERERERERERJ45BZUiIiIiIiIiIiIiIiIi8swpqBQRERERERERERERERGRZ05BpYiIiIiIiIiIiIiIiIg8cwoqRUREREREREREREREROSZU1ApIiIiIiIiIiIiIiIiIs+cgkoREREREREREREREREReeaGD/sARERE5OfHW8edV/+Tv/JhH4b8XPw1Xb9fsHTtPhBf+i//1Q/7EERERETkO4h+bv4W6eed90fn6/3R+Xr/njpn3wk/N6tRKSIiIiIiIiIiIiIiIiLPnIJKEREREREREREREREREXnmFFSKiIiIiIiIiIiIiIiIyDOnoFJEREREREREREREREREnjkFlSIiIiIiIiIiIiIiIiLyzCmoFBEREREREREREREREZFnTkGliIiIiIiIiIiIiIiIiDxzCipFRERERERERERERERE5JlTUCkiIiIiIiIiIiIiIiIiz9zwYR+AiIiIiIiIyHeiv/nX/qrfvvsuq/Wa524csrc35/zkhC/84zf463/98/zUT73J6eo9Rj8GOgOH7NYbzOoehQGn43TMCvsHc37J97/AR57f4+TshH/4U7d5+N6Kl165xa/+Vd/Dpz71ca5fvU6ZVUotWKmYFdw8fu0OON0do8THcKwYOJQK7lDKgJUC7pQ6xOc4uHdKtfh4qbg7WAHvWC2AAU6xSinx8WIGdDDDKPF7N9w7eKf7iPdG6yP0EbzS6RiOA+ZOqRXzghnx+KVShoFCfKyUgWFWKVYxDLee3yuPKL4QM8O84NYwL2Dg3emMdHe8d7wDDt17nIuhMAwDtQ4MpVLzvGLx4KUULJ83bh
hG90afHqhZfA9rbN8n7nlgeFxfN2iN5h4f7x7n3PL6ecc9ryFgxDUtNmyPZagDpRilFEqp1FJw6xQMs0LvG8BwgBbXA3eg4DS8N/KE4Pl83OPTzCqlQI+Dzq/zuE49jsjyPLj3OP89j9XA3ek06EYpRu8tHgLobU0ckrP9hnlG47lZvg4caNQ6BzOckWoDZo7lcy6FuN+KUfJ+K+Wp9+bnsXkD78bYNnQa3pzenN5H2mbM719wb7Q20vomrkkDJ/6eDj3vM/OCtw3NndY2uHd6N1ofOTt5xHv33uD+nX/A8eOfwfsas0LD6M1ZLTcslp3zMzg9c87Xhau3XuH7v//X8Inv+l6u3rjGMJvFvwHUeC1jdHq+lg1aXAu2r+9Gsfj+5Mtves05Pc9FXPfWNtv7vY+NtTfoztgbbdPoGL2PuMfrKM6B57WOe7j3EbzQS17v3uM+c8d7XLkOFIzeNvzQH/9T9i39YyoiIiLf0RRUioiIiIiIiHwAvuvVj/HC87fYrBvYyGq95NGDx9y994Q775zizajMaNSMJBudhlmNB3DoNAYGujt1BjduHHDl2pyjJxvOjt7j+PiEo+MFq9UGL2wDK8wy7OkRjtUKvUWYRc3AJ4ILKwPQiC8zzCLU6m3MEBKKFby3bXhWSo0AoswyEeyAbUM4K5ViU6DRcW95XCVCFavQNjhOsUK3Am6UjDEKZftwU9hqZthgFCv5X6UO84zgLAPRCvQ4hzZFlhl8mlOoNGsZcg0UBvBOc8dqofcW4Vh3ijvm+RjmGf7a9nyYlTy+CBDxjjWDPoV9nZ5BpRUy4IlAsngGxThePEK/bljJYM49g0OP5+RQSgUr1BIhrdUBI695LdtjcxrmRLjlDjaAj3EfAMUiEM4ECvcMJqlE2NUxd9yMzoh7hJFMYSXO9sGMCKXcaT5Sa4USQWjvbTrzed+M+T3jCmPTvdXwvJ6tj5hF+Bh/l19LPj8DY8CsxDWxiltGYZbHuE1J8xAtjr93w62BdawYpWcwbj2/Zx6ej2BO5JyFRqNScY/v6aVRyzwCPzeaN8w6xQa8NQDW65HF+TGnR+9wdvYG+HrK8qF3vDdKGahlpNZGmUFZOg/u3+Wtt1/jytUb7B7sczAMNOsUj/u7+wajYjZgPpJvO4AKrTfMM5ifQuhiU1IYoWrvGXQSYbcZeINaqX1Fp1N9Ri8GvVNtoGUwSrV808J0D8R56+Z4b3SfXreZ5NOh93j9W9FMNxEREfmGFFSKiIiIiIiIfAAOLh0w352zXo2cL0549Og9br/zLq+9fp/z04ZhGTgMGA1nxH3MMKBFKMU8IswOY+/M9nY42Jnz8ivXuPvuOQ/uP+Ho+ITlYgnuWDXqUHEKdRjobYNboXunDDN6iyZfrRV322Y6ZlmpNCjFI3csGTiY0/qGoe5GyGMlW34ZqGYz07JhCWxbkd0bhQgXox3YMcumVgZf7o1aKt0MulNqPP4UiJUMBM2I4LCwDSvdW0Yi0b7Doik4RSURXDVKmUVzsESYMrVJjRp/9ggEnYbhEbjRqLVQ6kDN5+ZYBobT864Ui7ZoG6NxV0qht5YNw3gafRyzqVriEH0KDFu0Mc1xi3PmvUXoaoZ5hVIo2WKLILdFWJfhajQaS5yrjK4czwAvGrFYhMBTY9KmEDmvvfcWAR2Wn9OzRRv3RWudahbPx0o0YHv+GbJFavQMpnDiPmmbbUhpJYNGiyZetQF34rrTwTpDqdGCzXC62AClRwBn2ey1ur2/4gHAS6c3o5fOYANOBmc2YETwWqgZ/DbcoG1vkrzPZjP6JoLSni3GuJMKZh2ngjU6Q3ysZ/jMgHXHGBk9zt5ms+D09AGnJ+/Q1sf5poA4D917tnidQoSxg8EwczbrNbdvv86tWy9x6fAq8/mcOhTi5TnGdbZ4zURjesxsNl4n0bqMcLbYLELoQvYgoZdtHI1ZPEsb4/NrhvudRrWCD5FhFhz3uLO6A30TrzO3KFU/9W+B53OzPC
aKZUe1RGtYRERE5OtQUCkiIiIiIiLygfAIY2zJarXi8aNj3njzPnfvrKh1Nz/Do/mUI087I2NbUeqQjcQcS
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1152x1152 with 16 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"tags": [],
|
|||
|
"image/png": {
|
|||
|
"width": 917,
|
|||
|
"height": 914
|
|||
|
},
|
|||
|
"needs_background": "light"
|
|||
|
}
|
|||
|
}
|
|||
|
]
|
|||
|
}
|
|||
|
]
|
|||
|
}
|