Contrastive Language-Image Pretraining

Related tags

Deep LearningCLIP
Overview

CLIP

[Blog] [Paper] [Model Card] [Colab]

CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be instructed in natural language to predict the most relevant text snippet, given an image, without directly optimizing for the task, similarly to the zero-shot capabilities of GPT-2 and 3. We found CLIP matches the performance of the original ResNet50 on ImageNet “zero-shot” without using any of the original 1.28M labeled examples, overcoming several major challenges in computer vision.

Approach

CLIP

Usage

First, install PyTorch 1.7.1 and torchvision, as well as small additional dependencies, and then install this repo as a Python package. On a CUDA GPU machine, the following will do the trick:

$ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
$ pip install ftfy regex tqdm
$ pip install git+https://github.com/openai/CLIP.git

Replace cudatoolkit=11.0 above with the appropriate CUDA version on your machine or cpuonly when installing on a machine without a GPU.

import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

API

The CLIP module clip provides the following methods:

clip.available_models()

Returns the names of the available CLIP models.

clip.load(name, device=..., jit=False)

Returns the model and the TorchVision transform needed by the model, specified by the model name returned by clip.available_models(). It will download the model as necessary. The name argument can also be a path to a local checkpoint.

The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU. When jit is False, a non-JIT version of the model will be loaded.

clip.tokenize(text: Union[str, List[str]], context_length=77)

Returns a LongTensor containing tokenized sequences of given text input(s). This can be used as the input to the model


The model returned by clip.load() supports the following methods:

model.encode_image(image: Tensor)

Given a batch of images, returns the image features encoded by the vision portion of the CLIP model.

model.encode_text(text: Tensor)

Given a batch of text tokens, returns the text features encoded by the language portion of the CLIP model.

model(image: Tensor, text: Tensor)

Given a batch of images and a batch of text tokens, returns two Tensors, containing the logit scores corresponding to each image and text input. The values are cosine similarities between the corresponding image and text features, times 100.

More Examples

Zero-Shot Prediction

The code below performs zero-shot prediction using CLIP, as shown in Appendix B in the paper. This example takes an image from the CIFAR-100 dataset, and predicts the most likely labels among the 100 textual labels from the dataset.

import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

The output will look like the following (the exact numbers may be slightly different depending on the compute device):

Top predictions:

           snake: 65.31%
          turtle: 12.29%
    sweet_pepper: 3.83%
          lizard: 1.88%
       crocodile: 1.75%

Note that this example uses the encode_image() and encode_text() methods that return the encoded features of given inputs.

Linear-probe evaluation

The example below uses scikit-learn to perform logistic regression on image features.

import os
import clip
import torch

import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)


def get_features(dataset):
    all_features = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
            features = model.encode_image(images.to(device))

            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()

# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)

# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)

# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(np.float)) * 100.
print(f"Accuracy = {accuracy:.3f}")

Note that the C value should be determined via a hyperparameter sweep using a validation split.

Owner
OpenAI
OpenAI
This is an (re-)implementation of DeepLab-ResNet in TensorFlow for semantic image segmentation on the PASCAL VOC dataset.

DeepLab-ResNet-TensorFlow This is an (re-)implementation of DeepLab-ResNet in TensorFlow for semantic image segmentation on the PASCAL VOC dataset. Up

19 Jan 16, 2022
The `rtdl` library + The official implementation of the paper

The `rtdl` library + The official implementation of the paper "Revisiting Deep Learning Models for Tabular Data"

Yandex Research 510 Dec 30, 2022
“袋鼯麻麻——智能购物平台”能够精准地定位识别每一个商品

“袋鼯麻麻——智能购物平台”能够精准地定位识别每一个商品,并且能够返回完整地购物清单及顾客应付的实际商品总价格,极大地降低零售行业实际运营过程中巨大的人力成本,提升零售行业无人化、自动化、智能化水平。

thomas-yanxin 192 Jan 05, 2023
OpenMMLab's Next Generation Video Understanding Toolbox and Benchmark

Introduction English | 简体中文 MMAction2 is an open-source toolbox for video understanding based on PyTorch. It is a part of the OpenMMLab project. The m

OpenMMLab 2.7k Jan 07, 2023
Research code for the paper "How Good is Your Tokenizer? On the Monolingual Performance of Multilingual Language Models"

Introduction This repository contains research code for the ACL 2021 paper "How Good is Your Tokenizer? On the Monolingual Performance of Multilingual

AdapterHub 20 Aug 04, 2022
Temporally Efficient Vision Transformer for Video Instance Segmentation, CVPR 2022, Oral

Temporally Efficient Vision Transformer for Video Instance Segmentation Temporally Efficient Vision Transformer for Video Instance Segmentation (CVPR

Hust Visual Learning Team 203 Dec 31, 2022
CityLearn Challenge Multi-Agent Reinforcement Learning for Intelligent Energy Management, 2020, PikaPika team

Citylearn Challenge This is the PyTorch implementation for PikaPika team, CityLearn Challenge Multi-Agent Reinforcement Learning for Intelligent Energ

bigAIdream projects 10 Oct 10, 2022
CrossNorm and SelfNorm for Generalization under Distribution Shifts (ICCV 2021)

CrossNorm (CN) and SelfNorm (SN) (Accepted at ICCV 2021) This is the official PyTorch implementation of our CNSN paper, in which we propose CrossNorm

100 Dec 28, 2022
Convex optimization for fun and profit.

CFMM Optimal Routing This repository contains the code needed to generate the figures used in the paper Optimal Routing for Constant Function Market M

Guillermo Angeris 183 Dec 29, 2022
Latent Execution for Neural Program Synthesis

Latent Execution for Neural Program Synthesis This repo provides the code to replicate the experiments in the paper Xinyun Chen, Dawn Song, Yuandong T

Xinyun Chen 16 Oct 02, 2022
Code for Quantifying Ignorance in Individual-Level Causal-Effect Estimates under Hidden Confounding

🍐 quince Code for Quantifying Ignorance in Individual-Level Causal-Effect Estimates under Hidden Confounding 🍐 Installation $ git clone

Andrew Jesson 19 Jun 23, 2022
Evidential Softmax for Sparse Multimodal Distributions in Deep Generative Models

Evidential Softmax for Sparse Multimodal Distributions in Deep Generative Models Abstract Many applications of generative models rely on the marginali

Stanford Intelligent Systems Laboratory 9 Jun 06, 2022
This repo is about implementing different approaches of pose estimation and also is a sub-task of the smart hospital bed project :smile:

Pose-Estimation This repo is a sub-task of the smart hospital bed project which is about implementing the task of pose estimation 😄 Many thanks to th

Max 11 Oct 17, 2022
A simple rest api that classifies pneumonia infection weather it is Normal, Pneumonia Virus or Pneumonia Bacteria from a chest-x-ray image.

This is a simple rest api that classifies pneumonia infection weather it is Normal, Pneumonia Virus or Pneumonia Bacteria from a chest-x-ray image.

crispengari 3 Jan 08, 2022
ObsPy: A Python Toolbox for seismology/seismological observatories.

ObsPy is an open-source project dedicated to provide a Python framework for processing seismological data. It provides parsers for common file formats

ObsPy 979 Jan 07, 2023
GPU-Accelerated Deep Learning Library in Python

Hebel GPU-Accelerated Deep Learning Library in Python Hebel is a library for deep learning with neural networks in Python using GPU acceleration with

Hannes Bretschneider 1.2k Dec 21, 2022
A simple AI that will give you si ple task and this is made with python

Crystal-AI A simple AI that will give you si ple task and this is made with python Prerequsites: Python3.6.2 pyttsx3 pip install pyttsx3 pyaudio pip i

CrystalAnd 1 Dec 25, 2021
CausaLM: Causal Model Explanation Through Counterfactual Language Models

CausaLM: Causal Model Explanation Through Counterfactual Language Models Authors: Amir Feder, Nadav Oved, Uri Shalit, Roi Reichart Abstract: Understan

Amir Feder 39 Jul 10, 2022
deep learning model with only python and numpy with test accuracy 99 % on mnist dataset and different optimization choices

deep_nn_model_with_only_python_100%_test_accuracy deep learning model with only python and numpy with test accuracy 99 % on mnist dataset and differen

0 Aug 28, 2022
Implementation for Panoptic-PolarNet (CVPR 2021)

Panoptic-PolarNet This is the official implementation of Panoptic-PolarNet. [ArXiv paper] Introduction Panoptic-PolarNet is a fast and robust LiDAR po

Zixiang Zhou 126 Jan 01, 2023