Contrastive Language-Image Pretraining

Related tags

Deep LearningCLIP
Overview

CLIP

[Blog] [Paper] [Model Card] [Colab]

CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be instructed in natural language to predict the most relevant text snippet, given an image, without directly optimizing for the task, similarly to the zero-shot capabilities of GPT-2 and 3. We found CLIP matches the performance of the original ResNet50 on ImageNet “zero-shot” without using any of the original 1.28M labeled examples, overcoming several major challenges in computer vision.

Approach

CLIP

Usage

First, install PyTorch 1.7.1 and torchvision, as well as small additional dependencies, and then install this repo as a Python package. On a CUDA GPU machine, the following will do the trick:

$ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
$ pip install ftfy regex tqdm
$ pip install git+https://github.com/openai/CLIP.git

Replace cudatoolkit=11.0 above with the appropriate CUDA version on your machine or cpuonly when installing on a machine without a GPU.

import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]

API

The CLIP module clip provides the following methods:

clip.available_models()

Returns the names of the available CLIP models.

clip.load(name, device=..., jit=False)

Returns the model and the TorchVision transform needed by the model, specified by the model name returned by clip.available_models(). It will download the model as necessary. The name argument can also be a path to a local checkpoint.

The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU. When jit is False, a non-JIT version of the model will be loaded.

clip.tokenize(text: Union[str, List[str]], context_length=77)

Returns a LongTensor containing tokenized sequences of given text input(s). This can be used as the input to the model


The model returned by clip.load() supports the following methods:

model.encode_image(image: Tensor)

Given a batch of images, returns the image features encoded by the vision portion of the CLIP model.

model.encode_text(text: Tensor)

Given a batch of text tokens, returns the text features encoded by the language portion of the CLIP model.

model(image: Tensor, text: Tensor)

Given a batch of images and a batch of text tokens, returns two Tensors, containing the logit scores corresponding to each image and text input. The values are cosine similarities between the corresponding image and text features, times 100.

More Examples

Zero-Shot Prediction

The code below performs zero-shot prediction using CLIP, as shown in Appendix B in the paper. This example takes an image from the CIFAR-100 dataset, and predicts the most likely labels among the 100 textual labels from the dataset.

import os
import clip
import torch
from torchvision.datasets import CIFAR100

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

# Calculate features
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

The output will look like the following (the exact numbers may be slightly different depending on the compute device):

Top predictions:

           snake: 65.31%
          turtle: 12.29%
    sweet_pepper: 3.83%
          lizard: 1.88%
       crocodile: 1.75%

Note that this example uses the encode_image() and encode_text() methods that return the encoded features of given inputs.

Linear-probe evaluation

The example below uses scikit-learn to perform logistic regression on image features.

import os
import clip
import torch

import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm

# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Load the dataset
root = os.path.expanduser("~/.cache")
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)


def get_features(dataset):
    all_features = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
            features = model.encode_image(images.to(device))

            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()

# Calculate the image features
train_features, train_labels = get_features(train)
test_features, test_labels = get_features(test)

# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)

# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(np.float)) * 100.
print(f"Accuracy = {accuracy:.3f}")

Note that the C value should be determined via a hyperparameter sweep using a validation split.

Owner
OpenAI
OpenAI
3D Pose Estimation for Vehicles

3D Pose Estimation for Vehicles Introduction This work generates 4 key-points and 2 key-edges from vertices and edges of vehicles as ground truth. The

Jingyi Wang 1 Nov 01, 2021
Based on the given clinical dataset, Predict whether the patient having Heart Disease or Not having Heart Disease

Heart_Disease_Classification Based on the given clinical dataset, Predict whether the patient having Heart Disease or Not having Heart Disease Dataset

Ashish 1 Jan 30, 2022
NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling

NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling For Official repo of NU-Wave: A Diffusion Probabilistic Model for Neural Audio Up

Rishikesh (ऋषिकेश) 38 Oct 11, 2022
Transformer part of 12th place solution in Riiid! Answer Correctness Prediction

kaggle_riiid Transformer part of 12th place solution in Riiid! Answer Correctness Prediction. Please see here for more information. Execution You need

Sakami Kosuke 2 Apr 23, 2022
A library for uncertainty quantification based on PyTorch

Torchuq [logo here] TorchUQ is an extensive library for uncertainty quantification (UQ) based on pytorch. TorchUQ currently supports 10 representation

TorchUQ 96 Dec 12, 2022
House_prices_kaggle - Predict sales prices and practice feature engineering, RFs, and gradient boosting

House Prices - Advanced Regression Techniques Predicting House Prices with Machine Learning This project is build to enhance my knowledge about machin

Gurpreet Singh 1 Jan 01, 2022
Unsupervised Feature Loss (UFLoss) for High Fidelity Deep learning (DL)-based reconstruction

Unsupervised Feature Loss (UFLoss) for High Fidelity Deep learning (DL)-based reconstruction Official github repository for the paper High Fidelity De

28 Dec 16, 2022
PyTorch implementation of SimSiam: Exploring Simple Siamese Representation Learning

SimSiam: Exploring Simple Siamese Representation Learning This is a PyTorch implementation of the SimSiam paper: @Article{chen2020simsiam, author =

Facebook Research 834 Dec 30, 2022
(SIGIR2020) “Asymmetric Tri-training for Debiasing Missing-Not-At-Random Explicit Feedback’’

Asymmetric Tri-training for Debiasing Missing-Not-At-Random Explicit Feedback About This repository accompanies the real-world experiments conducted i

yuta-saito 19 Dec 01, 2022
A style-based Quantum Generative Adversarial Network

Style-qGAN A style based Quantum Generative Adversarial Network (style-qGAN) model for Monte Carlo event generation. Tutorial We have prepared a noteb

9 Nov 24, 2022
Discord-Protect is a simple discord bot allowing you to have some security on your discord server by ordering a captcha to the user who joins your server.

Discord-Protect Discord-Protect is a simple discord bot allowing you to have some security on your discord server by ordering a captcha to the user wh

Tir Omar 2 Oct 28, 2021
The code for 'Deep Residual Fourier Transformation for Single Image Deblurring'

Deep Residual Fourier Transformation for Single Image Deblurring Xintian Mao, Yiming Liu, Wei Shen, Qingli Li and Yan Wang code will be released soon

145 Dec 13, 2022
Simple (but Strong) Baselines for POMDPs

Recurrent Model-Free RL is a Strong Baseline for Many POMDPs Welcome to the POMDP world! This repo provides some simple baselines for POMDPs, specific

Tianwei V. Ni 172 Dec 29, 2022
Imagededup - 😎 Finding duplicate images made easy

imagededup is a python package that simplifies the task of finding exact and near duplicates in an image collection.

idealo 4.3k Jan 07, 2023
Machine learning notebooks in different subjects optimized to run in google collaboratory

Notebooks Name Description Category Link Training pix2pix This notebook shows a simple pipeline for training pix2pix on a simple dataset. Most of the

Zaid Alyafeai 363 Dec 06, 2022
🧑‍🔬 verify your TEAL program by experiment and observation

Graviton - Testing TEAL with Dry Runs Tutorial Local Installation The following instructions assume that you have make available in your local environ

Algorand 18 Jan 03, 2023
🛠️ Tools for Transformers compression using Lightning ⚡

Bert-squeeze is a repository aiming to provide code to reduce the size of Transformer-based models or decrease their latency at inference time.

Jules Belveze 66 Dec 11, 2022
A reimplementation of DCGAN in PyTorch

DCGAN in PyTorch A reimplementation of DCGAN in PyTorch. Although there is an abundant source of code and examples found online (as well as an officia

Diego Porres 6 Jan 08, 2022