Robust fine-tuning of zero-shot models

This repository contains code for the paper "Robust fine-tuning of zero-shot models" by Mitchell Wortsman*, Gabriel Ilharco*, Jong Wook Kim, Mike Li, Simon Kornblith, Rebecca Roelofs, Raphael Gontijo-Lopes, Hannaneh Hajishirzi, Ali Farhadi, Hongseok Namkoong, Ludwig Schmidt.

Abstract

Large pre-trained models such as CLIP offer consistent accuracy across a range of data distributions when performing zero-shot inference (i.e., without fine-tuning on a specific dataset). Although existing fine-tuning approaches substantially improve accuracy in-distribution, they also reduce out-of-distribution robustness. We address this tension by introducing a simple and effective method for improving robustness: ensembling the weights of the zero-shot and fine-tuned models. Compared to standard fine-tuning, the resulting weight-space ensembles provide large accuracy improvements out-of-distribution, while matching or improving in-distribution accuracy. On ImageNet and five derived distribution shifts, weight-space ensembles improve out-of-distribution accuracy by 2 to 10 percentage points while increasing in-distribution accuracy by nearly 1 percentage point relative to standard fine-tuning. These improvements come at no additional computational cost during fine-tuning or inference.

Summary figure

Compared to standard fine-tuning, weight-space ensembles for fine-tuning (WiSE-FT) improve out-of-distribution (OOD) accuracy without decreasing in-distribution (ID) performance. Top left: Zero-shot CLIP models exhibit high effective robustness and moderate in-distribution accuracy, while standard fine-tuning (end-to-end or with a linear classifier) attains higher ID accuracy and less effective robustness. Top right: Our method linearly interpolates between the zero-shot and fine-tuned models with a mixing coefficient alpha in [0,1]. Bottom: On five distribution shifts derived from ImageNet (ImageNetV2, ImageNet-R, ImageNet Sketch, ObjectNet, and ImageNet-A), WiSE-FT improves average OOD accuracy by 8.7 percentage points (pp) when fine-tuning end-to-end (+2.1 pp when fine-tuning a linear classifier) while maintaining ID accuracy.
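
Written out, the interpolation described in the caption is a convex combination of the two sets of weights. With θ_0 denoting the zero-shot weights and θ_1 the fine-tuned weights (the same names used in the code under "Code"), the ensembled weights for a given alpha are:

```latex
\theta_{\alpha} = (1 - \alpha)\,\theta_0 + \alpha\,\theta_1, \qquad \alpha \in [0, 1]
```

Setting alpha = 0 recovers the zero-shot model and alpha = 1 recovers the fine-tuned model. For example, for a single weight with θ_0 = 0.2 and θ_1 = 1.0, alpha = 0.5 gives 0.5 · 0.2 + 0.5 · 1.0 = 0.6. Because the result is a single model with the original architecture, inference costs the same as for either endpoint.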

Code

Overview

WiSE-FT can be implemented in a few lines of code in addition to standard fine-tuning, as shown below. See src/wise_ft.py for more details.

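# Load models
# (imports as in src/wise_ft.py; assumes the repository root is on PYTHONPATH,
# see "Install dependencies" below)
from src.models.modeling import ImageClassifier
from src.models.eval import evaluate

alpha = 0.5  # mixing coefficient in [0, 1]: 0 = zero-shot, 1 = fine-tuned

# zeroshot_checkpoint / finetuned_checkpoint are paths to saved models;
# args holds the parsed command-line arguments (see src/wise_ft.py)
zeroshot = ImageClassifier.load(zeroshot_checkpoint)
finetuned = ImageClassifier.load(finetuned_checkpoint)
theta_0 = zeroshot.state_dict()
theta_1 = finetuned.state_dict()

# make sure checkpoints are compatible
assert set(theta_0.keys()) == set(theta_1.keys())

# interpolate between checkpoints with mixing coefficient alpha
theta = {
    key: (1 - alpha) * theta_0[key] + alpha * theta_1[key]
    for key in theta_0.keys()
}

# update the model according to the new weights
finetuned.load_state_dict(theta)

# evaluate
evaluate(finetuned, args)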

Install dependencies

conda env create
conda activate wiseft

Add the repository directory to PYTHONPATH:

cd wise-ft
export PYTHONPATH="$PYTHONPATH:$PWD"

Download data

When necessary, please refer to datasets.md for instructions on how to download datasets.

Run WiSE-FT

Sample command when the zero-shot and fine-tuned models are already available:

python src/wise_ft.py   \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --load=models/zeroshot.pt,models/finetuned.pt  \
    --results-db=results.jsonl  \
    --save=models/wiseft  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0
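
The --alpha flag above evaluates one interpolated model per listed coefficient. The block below is a rough, framework-agnostic sketch of that idea, not the code in src/wise_ft.py: `interpolate`, `sweep_alphas`, and `evaluate_fn` are hypothetical names, and the weights can be any values that support `*` and `+` (plain floats here, but NumPy arrays or PyTorch tensors in a state dict behave the same way).

```python
from typing import Callable, Dict, Mapping, Sequence


def interpolate(theta_0: Mapping, theta_1: Mapping, alpha: float) -> Dict:
    """Weight-space interpolation: (1 - alpha) * theta_0 + alpha * theta_1, key by key."""
    if set(theta_0) != set(theta_1):
        raise ValueError("checkpoints have different parameter names")
    return {k: (1 - alpha) * theta_0[k] + alpha * theta_1[k] for k in theta_0}


def sweep_alphas(
    theta_0: Mapping,
    theta_1: Mapping,
    alphas: Sequence[float],
    evaluate_fn: Callable[[Dict], float],
) -> Dict[float, float]:
    """Build and score one interpolated model per alpha; returns {alpha: score}."""
    results = {}
    for alpha in alphas:
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        results[alpha] = evaluate_fn(interpolate(theta_0, theta_1, alpha))
    return results


if __name__ == "__main__":
    # Toy example: one scalar "weight" per model and a made-up score that peaks at w = 0.3.
    theta_0 = {"w": 0.0}  # stands in for the zero-shot weights
    theta_1 = {"w": 1.0}  # stands in for the fine-tuned weights
    alphas = [i / 10 for i in range(11)]  # 0.0, 0.1, ..., 1.0
    scores = sweep_alphas(theta_0, theta_1, alphas, lambda th: -(th["w"] - 0.3) ** 2)
    print(max(scores, key=scores.get))  # prints 0.3
```

In practice `evaluate_fn` would load the interpolated weights into the model and measure accuracy on each evaluation dataset; the toy score here only illustrates the control flow.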

Sample command for running WiSE-FT from scratch using ViT-B/32:

python src/wise_ft.py   \
    --train-dataset=ImageNet  \
    --epochs=10  \
    --lr=0.00003  \
    --batch-size=512  \
    --cache-dir=cache  \
    --model=ViT-B/32  \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --template=openai_imagenet_template  \
    --results-db=results.jsonl  \
    --save=models/wiseft/ViTB32  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

Note: with the --freeze-encoder flag, only a linear classifier is fine-tuned on top of the frozen encoder; without it, all weights are fine-tuned end-to-end.

Plotting results

Sample command for generating a scatter plot:

python src/scatter_plot.py  \
    --eval-datasets=ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --results-db=results.jsonl  \
    --save plots

Below are samples of the expected behavior when running the commands above with ViT-B/16 (models can be downloaded here):

[Scatter plots: ImageNet-Sketch, ImageNet-A, ImageNet-R, ImageNetV2, ObjectNet]

Citing

If you found this repository useful, please consider citing:

@article{wortsman2021robust,
  title={Robust fine-tuning of zero-shot models},
  author={Wortsman, Mitchell and Ilharco, Gabriel and Kim, Jong Wook and Li, Mike and Kornblith, Simon and Roelofs, Rebecca and Gontijo-Lopes, Raphael and Hajishirzi, Hannaneh and Farhadi, Ali and Namkoong, Hongseok and Schmidt, Ludwig},
  journal={arXiv preprint arXiv:2109.01903},
  note={\url{https://arxiv.org/abs/2109.01903}},
  year={2021}
}