Self-supervised learning of optimally robust representations for domain generalization.

Overview

OptDom: Learning Optimal Representations for Domain Generalization

This repository contains the official implementation of Optimal Representations for Covariate Shift. Our paper theoretically characterizes the minimal sufficient representations for optimal domain generalization (DG) under covariate shift and derives practical self-supervised learning (SSL) objectives for learning such representations.

We provide code for reproducing our main results, with the following highlights:

  • Finetuning pretrained SSL models (CLIP) into superior, robust DG models [minimal example]
  • A novel contrastive adversarial domain (CAD) bottleneck for learning domain-invariant representations [implementation]
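For intuition about what a contrastive domain bottleneck looks at, here is a hypothetical plain-Python sketch; it is not the repository's CAD loss (that lives in domainbed/bottlenecks.py). For each example it computes the log of the softmax similarity mass that falls on other examples from the same domain. Values close to 0 mean each example's most similar neighbours come from its own domain, i.e. the domain is easy to read off the representation. The function names, the cosine similarity and the temperature `tau` are assumptions for illustration.

```python
import math
from typing import List, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv)


def same_domain_log_prob(z: List[Sequence[float]], domains: List[int], tau: float = 0.1) -> float:
    """Average log-probability that a softmax over similarities picks a same-domain neighbour.

    Hypothetical illustration of a contrastive domain-prediction term, not the repo's objective.
    """
    n = len(z)
    total = 0.0
    for i in range(n):
        # Similarity logits from anchor i to every *other* example in the batch.
        logits = {j: cosine(z[i], z[j]) / tau for j in range(n) if j != i}
        if not any(domains[j] == domains[i] for j in logits):
            raise ValueError("every domain needs at least two examples in the batch")
        m = max(logits.values())  # subtract the max for numerical stability
        denom = sum(math.exp(l - m) for l in logits.values())
        numer = sum(math.exp(l - m) for j, l in logits.items() if domains[j] == domains[i])
        total += math.log(numer / denom)
    return total / n
```

With domain-clustered embeddings the score is close to 0; when same-domain points are not each other's nearest neighbours it drops well below 0.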

Setup

  1. Install PyTorch 1.7.1 and CLIP following their official installation instructions.
  2. Install other packages: pip install -r requirements.txt.
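Because the setup pins PyTorch 1.7.1, it can help to check the installed release before running anything. The helper below is a hypothetical convenience, not part of the repository: it compares a version string such as `torch.__version__` against the pinned release while ignoring local build tags like `+cu110`.

```python
def matches_pinned_version(version: str, pinned: str = "1.7.1") -> bool:
    """Return True if `version` (e.g. torch.__version__) is the pinned release.

    Local build tags such as "+cu110" are ignored, so "1.7.1+cu110" matches "1.7.1".
    """
    release = version.split("+", 1)[0]
    return release == pinned
```

For example, after `import torch`, call `matches_pinned_version(torch.__version__)`; after `import clip`, `clip.available_models()` lists the model names that CLIP can load (e.g. RN50, ViT-B/32).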

Finetune & Evaluate CLIP on DomainBed

Our paper derives SSL objectives for learning optimally robust representations and gives insights into the superior robustness of CLIP (Sec 4). Here we include the code for finetuning CLIP with our proposed objectives and evaluating on the DomainBed benchmark, which reproduces our experiments in Sec 6.2.
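As a point of reference for the objectives discussed here, CLIP itself is pretrained with a symmetric contrastive (InfoNCE) loss over matched image-text pairs. The plain-Python sketch below illustrates only that standard loss; it is not the repository's finetuning objective (see domainbed/clip_algorithms.py), and the temperature value is an assumption.

```python
import math
from typing import List, Sequence


def _normalize(v: Sequence[float]) -> List[float]:
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def _log_softmax_at(row: List[float], k: int) -> float:
    """log softmax(row)[k], computed stably with the log-sum-exp trick."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return row[k] - lse


def clip_style_loss(img: List[Sequence[float]], txt: List[Sequence[float]],
                    temperature: float = 0.07) -> float:
    """Symmetric InfoNCE: image i should match text i, and text i should match image i."""
    img = [_normalize(v) for v in img]
    txt = [_normalize(v) for v in txt]
    n = len(img)
    # logits[i][j] = cosine(image i, text j) / temperature
    logits = [[sum(a * b for a, b in zip(img[i], txt[j])) / temperature for j in range(n)]
              for i in range(n)]
    # Image-to-text: row i, correct column i.
    loss_i2t = -sum(_log_softmax_at(logits[i], i) for i in range(n)) / n
    # Text-to-image: column j, correct row j.
    cols = [[logits[i][j] for i in range(n)] for j in range(n)]
    loss_t2i = -sum(_log_softmax_at(cols[j], j) for j in range(n)) / n
    return (loss_i2t + loss_t2i) / 2
```

When every pair is indistinguishable the loss equals log(batch size); it approaches 0 as matched pairs become far more similar than mismatched ones.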

The implementation is in the DomainBed directory, which is largely based on the DomainBed repo. The CLIP-based models are implemented in domainbed/clip_algorithms.py, and the domain bottlenecks in domainbed/bottlenecks.py. The training script for finetuning CLIP with bottlenecks is domainbed/scripts/train_clip.py.

Preparation

Move to the DomainBed directory and download the datasets:

python -m domainbed.scripts.download --data_dir ./datasets/

By default, the following datasets are downloaded: PACS, VLCS, OfficeHome, TerraIncognita, and DomainNet.

Launch a single run

To launch a single run for debugging, run:

bash run_debug.sh

The key arguments include:

  • --dataset: dataset for finetuning and evaluation.
  • --algorithm: the CLIP-based algorithm to run; see domainbed/clip_algorithms.py.
  • --test_envs: list of left-out environments for testing; the other environments are used for training/finetuning.
  • --hparams: JSON-serialized hyperparameter dict; see domainbed/hparams_registry.py for the list of all hyperparameters.
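The arguments above can be assembled programmatically, which avoids quoting mistakes in the JSON passed to --hparams. This is a sketch under assumptions: it supposes that domainbed/scripts/train_clip.py is invoked as a module and accepts these four flags, with --test_envs taking space-separated integers as in DomainBed (check run_debug.sh for the exact invocation). The algorithm name follows the `loss + "CLIPBottleneck" + bottleneck` pattern from the minimal finetuning example; `build_train_command` is a hypothetical helper.

```python
import json
import shlex
from typing import Dict, List


def build_train_command(dataset: str, loss: str, bottleneck: str,
                        test_envs: List[int], hparams: Dict) -> List[str]:
    """Return an argv list for a single finetuning run (assumed flag set)."""
    algorithm = loss + "CLIPBottleneck" + bottleneck  # e.g. "SupCLIPBottleneckCondCAD"
    return [
        "python", "-m", "domainbed.scripts.train_clip",
        "--dataset", dataset,
        "--algorithm", algorithm,
        "--test_envs", *[str(e) for e in test_envs],
        "--hparams", json.dumps(hparams),  # serialized so the shell sees one argument
    ]
```

`print(shlex.join(build_train_command("PACS", "Sup", "CondCAD", [0], {"clip_model": "ViT-B/32"})))` prints a correctly quoted shell command (`shlex.join` needs Python 3.8+).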

Note that the result of a single run can be very sensitive to the hyperparameters and the random seed, so we recommend launching a sweep over hyperparameters and random seeds, as in DomainBed.

Launch a sweep with tuning

To launch a sweep, run:

bash run_sweep_clip.sh

A sweep over 10 hyperparameter settings and 5 random seeds is launched for each dataset and algorithm. By default, the CLIP-RN50 model is used; you can run with other models by changing the clip_model argument, e.g., ViT-B/32 for CLIP-ViT-B/32. To launch a sweep, you also need to select or implement a command launcher in domainbed/command_launchers.py and choose it with the launcher argument. If you use Slurm, we provide a Slurm launcher that you can adapt.
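To estimate how many jobs a sweep creates, and to see what a launcher receives, here is a hypothetical sketch. It assumes one run per left-out domain and per (hyperparameter seed, trial seed) pair; the --hparams_seed and --trial_seed flags mirror DomainBed's train.py and are assumptions for train_clip.py (check run_sweep_clip.sh). The launcher follows the convention in DomainBed's command_launchers.py of a function that takes a list of shell commands; this one only prints them.

```python
import itertools
from typing import Dict, List

# Assumed number of domains per dataset; one run per left-out domain.
NUM_ENVS = {"PACS": 4, "VLCS": 4, "OfficeHome": 4, "TerraIncognita": 4, "DomainNet": 6}


def enumerate_jobs(datasets: List[str], algorithms: List[str],
                   n_hparams: int = 10, n_trials: int = 5) -> List[Dict]:
    """One job per (dataset, algorithm, left-out env, hparams seed, trial seed)."""
    jobs = []
    for dataset, algorithm in itertools.product(datasets, algorithms):
        grid = itertools.product(range(NUM_ENVS[dataset]), range(n_hparams), range(n_trials))
        for test_env, hparams_seed, trial_seed in grid:
            jobs.append({"dataset": dataset, "algorithm": algorithm, "test_env": test_env,
                         "hparams_seed": hparams_seed, "trial_seed": trial_seed})
    return jobs


def to_command(job: Dict) -> str:
    # Flag names are assumptions mirroring DomainBed's train.py.
    return ("python -m domainbed.scripts.train_clip"
            f" --dataset {job['dataset']} --algorithm {job['algorithm']}"
            f" --test_envs {job['test_env']}"
            f" --hparams_seed {job['hparams_seed']} --trial_seed {job['trial_seed']}")


def dry_run_launcher(commands: List[str]) -> List[str]:
    """Launcher-style callable: receives the list of commands and only prints them."""
    for cmd in commands:
        print(cmd)
    return commands
```

Under these assumptions, one algorithm on PACS gives 4 × 10 × 5 = 200 runs, and DomainNet gives 300.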

After the sweep has finished, you can collect the results with the notebook collect_clip_results.ipynb. Note that the results may differ slightly from those in the paper due to code cleanup.
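The notebook is the reference for how results are reported. As an illustration of one common protocol, in the style of DomainBed's training-domain validation selection, the sketch below picks, for each trial seed, the hyperparameter setting with the best source-domain validation accuracy, then reports the mean and standard deviation of the selected runs' target-domain test accuracy. Whether collect_clip_results.ipynb uses exactly this protocol is an assumption, and the record field names are hypothetical.

```python
from collections import defaultdict
from statistics import mean, pstdev
from typing import Dict, List, Tuple


def select_and_aggregate(records: List[Dict]) -> Tuple[float, float]:
    """Per trial seed, keep the run with the highest val_acc; aggregate its test_acc."""
    by_trial = defaultdict(list)
    for r in records:
        by_trial[r["trial_seed"]].append(r)
    chosen = [max(runs, key=lambda r: r["val_acc"])["test_acc"] for runs in by_trial.values()]
    return mean(chosen), pstdev(chosen)
```

Selection uses only source-domain validation accuracy, so the target domain never influences which hyperparameters are kept.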

(Optional) Run CAD in DomainBed setup

You can also evaluate our proposed (conditional) CAD bottleneck in the DomainBed setup, where a ResNet-50 is trained end-to-end on the source domains and evaluated on a left-out target domain. The implementation is in domainbed/algorithms.py, and you can run it with:

bash run_sweep_e2e_dombed.sh

You can likewise collect the results with the notebook collect_e2e_results.ipynb. Note that, as argued in our paper, algorithms in this setup lack access to information about the target domain, so we do not expect our bottlenecks or other algorithms to necessarily outperform ERM. Surprisingly, however, our CAD bottleneck does lead to consistent improvements.

Finetune CLIP on LAION-400M

Coming soon!

Minimal Code for Custom Finetuning

To finetune CLIP on your own dataset with our bottlenecks, you can start from the following minimal code example:

import torch
from torch.utils.data import DataLoader, TensorDataset
import clip
from tqdm import tqdm

from domainbed import hparams_registry
from domainbed import algorithms


# 1. Determine whether you do supervised or contrastive finetuning:
#       - True: use a cross-entropy loss with a supervised dataset
#       - False: use a contrastive loss with a text-image-pair dataset
supervised_finetuning = True

if supervised_finetuning:
    loss_name = "Sup"
    dataset_name = "my supervised dataset"
else:
    loss_name = "Contrast"
    dataset_name = "my text-image pair dataset"


# 2. Choose the bottleneck you want to use (they have different properties)
bottleneck_name = "CondCAD"  # one of: Ent, CAD, CondCAD
algorithm_name = loss_name + "CLIPBottleneck" + bottleneck_name


# 3. Set hyperparameters; you can also change the hyperparameter dict and its default values
hparams = hparams_registry.default_hparams(algorithm_name, dataset_name)


# 4. Load the pretrained CLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
pretrained, preprocess = clip.load(hparams['clip_model'], device, jit=False)


# 5. Load your dataset. Each example should have the form:
#       - (image, label) for supervised finetuning
#       - (image, text) for contrastive finetuning
#    Use the CLIP `preprocess` function as the image transform. The dataset should be
#    a list of sub-datasets, one per domain (a single-element list for a single domain).
#    `load_your_dataset` is a placeholder for your own loading code.
dataset = load_your_dataset(dataset_name, preprocess)
num_envs = len(dataset)
num_classes = dataset.num_classes  # a dummy value for a text-image-pair dataset


# 6. Featurize your dataset with the CLIP model

def get_clip_feature(clip_model, x, y):
    """Compute CLIP features for a batch of images (and of texts, for contrastive finetuning)"""
    with torch.no_grad():
        z = clip_model.encode_image(x).float()
        if not supervised_finetuning:  # `y` is a batch of raw texts that must be tokenized
            y = clip_model.encode_text(clip.tokenize(y).to(x.device)).float()
    return z, y

def clip_featurize_data(clip_model, dataset, device):
    """Featurize a dataset"""
    Z, Y = [], []
    for x, y in tqdm(DataLoader(dataset, batch_size=512, num_workers=4)):
        if torch.is_tensor(y):  # labels are tensors; raw texts (strings) cannot be moved to a device
            y = y.to(device)
        z, y = get_clip_feature(clip_model, x.to(device), y)
        Z += [z.cpu()]
        Y += [y.cpu()]
    return TensorDataset(torch.cat(Z), torch.cat(Y))

def clip_precompute_splits(clip_model, splits, device):
    """Featurize each domain's sub-dataset separately"""
    return [clip_featurize_data(clip_model, ds, device) for ds in splits]

def infinite_loader(loader):
    """Cycle over a DataLoader forever so training can run for more than one epoch"""
    while True:
        for batch in loader:
            yield batch


dataset = clip_precompute_splits(pretrained, dataset, device)
train_loaders = [DataLoader(
    dataset=env,
    batch_size=hparams['batch_size'],
    shuffle=True,
    num_workers=4)
    for env in dataset]
# A plain zip(*train_loaders) would raise StopIteration once the smallest domain is exhausted
train_minibatches_iterator = zip(*[infinite_loader(loader) for loader in train_loaders])
steps_per_epoch = int(min([len(env) / hparams['batch_size'] for env in dataset]))
n_steps = hparams['max_step']


# 7. Initialize the model
algorithm_class = algorithms.get_algorithm_class(algorithm_name)
algorithm = algorithm_class(pretrained.visual.output_dim, num_classes, num_envs, hparams, pretrained, None)
algorithm.to(device)
algorithm.train()


# 8. Finetune the model (one minibatch per domain at each step)
for step in range(n_steps):
    minibatches_device = [(x.to(device), y.to(device)) for x, y in next(train_minibatches_iterator)]
    algorithm.adjust_lr(step, n_steps, steps_per_epoch)
    step_vals = algorithm.update(minibatches_device, None)
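The example above leaves load_your_dataset to you. Whatever it returns only needs to report the number of domains via len(), iterate over the per-domain sub-datasets, and expose a num_classes attribute. A minimal sketch of such a container, assuming each element is already a map-style dataset (anything with __len__ and __getitem__, such as a torch Dataset); the class name is hypothetical.

```python
from typing import Any, Iterable


class MultiDomainDataset(list):
    """A list of per-domain datasets plus the `num_classes` attribute the example reads.

    Hypothetical helper: each element can be any map-style dataset whose items are
    (image, label) for supervised finetuning or (image, text) for contrastive finetuning.
    """

    def __init__(self, domains: Iterable[Any], num_classes: int = 0):
        super().__init__(domains)
        if len(self) == 0:
            raise ValueError("need at least one domain (use a single-element list for one domain)")
        self.num_classes = num_classes  # unused for text-image-pair data
```

For example, with one image folder per domain you could return `MultiDomainDataset([torchvision.datasets.ImageFolder(path, transform=preprocess) for path in domain_dirs], num_classes=...)`, where `domain_dirs` is your own list of directories.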

Cite

If you find this work relevant to your research, please cite our paper:

@article{ruan2021optdom,
  title={Optimal Representations for Covariate Shift},
  author={Ruan, Yangjun and Dubois, Yann and Maddison, Chris J},
  journal={arXiv preprint arXiv:2201.00057},
  year={2022},
}

Acknowledgement

Our code is based on:
