PyTorch Code for SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised Domain Adaptation

Viraj Prabhu, Shivam Khare, Deeksha Kartik, Judy Hoffman

Many existing approaches for unsupervised domain adaptation (UDA) focus on adapting under only data distribution shift and offer limited success under additional cross-domain label distribution shift. Recent work based on self-training using target pseudolabels has shown promise, but on challenging shifts pseudolabels may be highly unreliable and using them for self-training may cause error accumulation and domain misalignment. We propose Selective Entropy Optimization via Committee Consistency (SENTRY), a UDA algorithm that judges the reliability of a target instance based on its predictive consistency under a committee of random image transformations. Our algorithm then selectively minimizes predictive entropy to increase confidence on highly consistent target instances, while maximizing predictive entropy to reduce confidence on highly inconsistent ones. In combination with pseudolabel-based approximate target class balancing, our approach leads to significant improvements over the state-of-the-art on 27/31 domain shifts from standard UDA benchmarks as well as benchmarks designed to stress-test adaptation under label distribution shift.
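
The selection rule described above can be illustrated with a small, framework-free sketch. It is not the repository's implementation: the strict-majority vote, the function names, and the choice to compute the entropy term on the untransformed prediction are all assumptions made for illustration.

```python
import math


def entropy(probs):
    """Shannon entropy (in nats) of a probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def argmax(probs):
    """Index of the most probable class."""
    return max(range(len(probs)), key=probs.__getitem__)


def selective_entropy_loss(original_probs, committee_probs):
    """Signed entropy term for a single target instance.

    original_probs: class probabilities for the untransformed image.
    committee_probs: one probability vector per randomly transformed copy.

    Assumption: the instance counts as consistent when a strict majority of
    committee predictions agree with the prediction on the original image.
    """
    reference = argmax(original_probs)
    agree = sum(argmax(p) == reference for p in committee_probs)
    consistent = 2 * agree > len(committee_probs)
    h = entropy(original_probs)
    # Minimizing +H raises confidence; minimizing -H lowers it.
    return h if consistent else -h
```

Minimizing this value over a batch increases confidence on instances the committee agrees on and decreases it on the rest, which is the behavior the abstract describes.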

Setup and Dependencies

  1. Create an Anaconda environment with Python 3.6: conda create -n sentry python=3.6.8 and activate it: conda activate sentry
  2. Navigate to the code directory: cd code/
  3. Install dependencies: pip install -r requirements.txt

And you're all set up!

Usage

Download data

Data for SVHN->MNIST is downloaded automatically via PyTorch. Data for the other benchmarks can be downloaded from the following links (the splits used for our experiments are already included in the data/ folder):

  1. DomainNet
  2. OfficeHome
  3. VisDA2017 (only train and validation needed)

Pretrained checkpoints

To reproduce the numbers reported in the paper, we include a few pretrained checkpoints. Source and adapted checkpoints for SVHN to MNIST (DIGITS) are in the checkpoints directory. Source and adapted checkpoints for Clipart to Sketch adaptation (from DomainNet) and Real_World to Product adaptation (from OfficeHome RS-UT) can be downloaded from this link, and should be saved to the checkpoints/source and checkpoints/SENTRY directories as appropriate.

Train and adapt model

  • Natural label distribution shift: Adapt a model from <source> to <target> for a given <benchmark> (where benchmark may be DomainNet, OfficeHome, VisDA, or DIGITS), as follows:
python train.py --id <experiment_id> \
                --source <source> \
                --target <target> \
                --img_dir <image_directory> \
                --LDS_type <LDS_type> \
                --load_from_cfg True \
                --cfg_file 'config/<benchmark>/<cfg_file>.yml' \
                --use_cuda True

SENTRY hyperparameters are provided via a sentry.yml config file in the corresponding config/<benchmark> folder (on DIGITS, we also provide a config for baseline adaptation via DANN). The valid source/target domains for each benchmark are:

  • DomainNet: real, clipart, sketch, painting
  • OfficeHome_RS_UT: Real_World, Clipart, Product
  • OfficeHome: Real_World, Clipart, Product, Art
  • VisDA2017: visda_train, visda_test
  • DIGITS: Only svhn (source) to mnist (target) adaptation is currently supported.

Pass in the path to the parent folder containing dataset images via the --img_dir <name_of_directory> flag (e.g. --img_dir '~/data/DomainNet'). Pass in the label distribution shift type via the --LDS_type flag. For DomainNet, OfficeHome (standard), and VisDA2017, pass in --LDS_type 'natural' (default). For OfficeHome RS-UT, pass in --LDS_type 'RS_UT'. For DIGITS, pass in --LDS_type as one of IF1, IF20, IF50, or IF100 to load a manually long-tailed target training split with the given imbalance factor (IF), as described in Table 4 of the paper.
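
The domain and --LDS_type rules above can be checked before launching a run. The sketch below is a hypothetical helper, not part of the repository: the function name, the dictionary names, and the benchmark keys are assumptions taken from the lists in this README, and only flags shown in the command above are emitted.

```python
# Valid domains and LDS types as listed in this README (keys are assumptions).
VALID_DOMAINS = {
    "DomainNet": {"real", "clipart", "sketch", "painting"},
    "OfficeHome_RS_UT": {"Real_World", "Clipart", "Product"},
    "OfficeHome": {"Real_World", "Clipart", "Product", "Art"},
    "VisDA2017": {"visda_train", "visda_test"},
    "DIGITS": {"svhn", "mnist"},
}
VALID_LDS_TYPES = {
    "DomainNet": {"natural"},
    "OfficeHome": {"natural"},
    "VisDA2017": {"natural"},
    "OfficeHome_RS_UT": {"RS_UT"},
    "DIGITS": {"IF1", "IF20", "IF50", "IF100"},
}


def build_train_command(benchmark, source, target, img_dir, lds_type,
                        experiment_id, cfg_file, use_cuda=True):
    """Validate the arguments and return the train.py argument list."""
    if benchmark not in VALID_DOMAINS:
        raise ValueError(f"unknown benchmark: {benchmark}")
    for domain in (source, target):
        if domain not in VALID_DOMAINS[benchmark]:
            raise ValueError(f"{domain} is not a {benchmark} domain")
    # The README states that DIGITS only supports svhn -> mnist.
    if benchmark == "DIGITS" and (source, target) != ("svhn", "mnist"):
        raise ValueError("DIGITS only supports svhn -> mnist")
    if lds_type not in VALID_LDS_TYPES[benchmark]:
        raise ValueError(f"{lds_type} is not valid for {benchmark}")
    return [
        "python", "train.py",
        "--id", experiment_id,
        "--source", source,
        "--target", target,
        "--img_dir", img_dir,
        "--LDS_type", lds_type,
        "--load_from_cfg", "True",
        "--cfg_file", cfg_file,  # path to the benchmark's .yml, supplied by the caller
        "--use_cuda", str(use_cuda),
    ]
```

The returned list can be passed to subprocess.run from the code directory; validating first avoids starting a long run with a mismatched benchmark and shift type.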

To load a pretrained DA checkpoint instead of training your own, additionally pass --load_da True and --id <benchmark_name> to the script above. Finally, the training script will log performance metrics to the console (average and aggregate accuracy), and additionally plot and save some per-class performance statistics to the results/ folder.
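
The two logged metrics can differ substantially under label distribution shift. The sketch below assumes that "aggregate" means overall accuracy across all examples and "average" means the mean of per-class accuracies; these definitions and the function names are assumptions and may not match the script's exact computation.

```python
from collections import defaultdict


def aggregate_accuracy(labels, preds):
    """Fraction of all examples predicted correctly."""
    correct = sum(int(y == p) for y, p in zip(labels, preds))
    return correct / len(labels)


def average_accuracy(labels, preds):
    """Mean of per-class accuracies, so every class counts equally."""
    total = defaultdict(int)
    hits = defaultdict(int)
    for y, p in zip(labels, preds):
        total[y] += 1
        hits[y] += int(y == p)
    return sum(hits[c] / total[c] for c in total) / len(total)
```

For example, with labels [0, 0, 0, 0, 1] and a model that always predicts class 0, aggregate accuracy is 0.8 while average accuracy is 0.5, which is why a per-class view matters on long-tailed target splits.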

Note: By default this code runs on GPU. To run on CPU pass: --use_cuda False

Reference

If you found this code useful, please consider citing:

@article{prabhu2020sentry,
   author = {Prabhu, Viraj and Khare, Shivam and Kartik, Deeksha and Hoffman, Judy},
   title = {SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised Domain Adaptation},
   year = {2020},
   journal = {arXiv preprint: 2012.11460},
}

Acknowledgements

We would like to thank the developers of PyTorch for building an excellent framework, in addition to the numerous contributors to all the open-source packages we use.

License

MIT
