[IJCAI-2021] A benchmark of data-free knowledge distillation from paper "Contrastive Model Inversion for Data-Free Knowledge Distillation"

DataFree

Authors: Gongfan Fang, Jie Song, Xinchao Wang, Chengchao Shen, Xingen Wang, Mingli Song

[Figure panels: CMI (this work), DeepInv, ZSKT, DFQ]

Results

1. CIFAR-10

Method      resnet-34   vgg-11      wrn-40-2    wrn-40-2    wrn-40-2    (teacher)
            resnet-18   resnet-18   wrn-16-1    wrn-40-1    wrn-16-2    (student)
T. Scratch  95.70       92.25       94.87       94.87       94.87
S. Scratch  95.20       95.20       91.12       93.94       93.95
DAFL        92.22       81.10       65.71       81.33       81.55
ZSKT        93.32       89.46       83.74       86.07       89.66
DeepInv     93.26       90.36       83.04       86.85       89.72
DFQ         94.61       90.84       86.14       91.69       92.01
CMI         94.84       91.13       90.01       92.78       92.52

2. CIFAR-100

Method      resnet-34   vgg-11      wrn-40-2    wrn-40-2    wrn-40-2    (teacher)
            resnet-18   resnet-18   wrn-16-1    wrn-40-1    wrn-16-2    (student)
T. Scratch  78.05       71.32       75.83       75.83       75.83
S. Scratch  77.10       77.01       65.31       72.19       73.56
DAFL        74.47       57.29       22.50       34.66       40.00
ZSKT        67.74       34.72       30.15       29.73       28.44
DeepInv     61.32       54.13       53.77       61.33       61.34
DFQ         77.01       68.32       54.77       62.92       59.01
CMI         77.04       70.56       57.91       68.88       68.75

Quick Start

1. Visualize the inverted samples

Results will be saved as checkpoints/datafree-cmi/synthetic-cmi_for_vis.png

bash scripts/cmi/cmi_cifar10_for_vis.sh

2. Reproduce our results

Note: This repo was refactored from our experimental code and is still under development. We are still searching for appropriate hyperparameters for every method (°ー°〃). So far, we provide hyperparameters only for reproducing the CIFAR-10 results for wrn-40-2 => wrn-16-1. You may need to tune the hyperparameters for other models and datasets. More resources will be uploaded in a future update.

To reproduce our results, please download the pre-trained teacher models from Dropbox-Models (266 MB) and extract them to checkpoints/pretrained. A pre-inverted dataset with ~50k samples is also available for the wrn-40-2 teacher on CIFAR-10. You can download it from Dropbox-Data (133 MB) and extract it to run/cmi-preinverted-wrn402/.
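Before launching the scripts, it can help to confirm that both archives ended up where the instructions above place them. The following is a minimal sketch, not part of the repository: the helper name `missing_paths` is hypothetical, and only the two directory names are taken from the text above.

```python
from pathlib import Path

# Directories named in the download instructions, relative to the repo root.
REQUIRED_DIRS = ("checkpoints/pretrained", "run/cmi-preinverted-wrn402")


def missing_paths(root=".", required=REQUIRED_DIRS):
    """Return the required directories that do not exist under `root`.

    Hypothetical helper for illustration; the repository does not ship it.
    """
    root = Path(root)
    return [rel for rel in required if not (root / rel).is_dir()]
```

Calling `missing_paths()` from the repository root returns an empty list once both archives have been extracted.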

  • Non-adversarial CMI: train a student model directly on the inverted data. It should reach ~87.38% accuracy on CIFAR-10, as reported in Figure 3.

    bash scripts/cmi/nonadv_cmi_cifar10_wrn402_wrn161.sh
    
  • Adversarial CMI: alternatively, apply adversarial distillation on top of the pre-inverted data, generating ~10k (256x40) new samples to improve the student. It should reach ~90.01% accuracy on CIFAR-10, as reported in Table 1.

    bash scripts/cmi/adv_cmi_cifar10_wrn402_wrn161.sh
    
  • Scratch CMI: you can also run the CMI algorithm without any pre-inverted data, but the student may overfit to early samples because of the limited amount of data. It should reach ~88.82% accuracy on CIFAR-10, slightly below our reported result (90.01%).

    bash scripts/cmi/scratch_cmi_cifar10_wrn402_wrn161.sh
    

3. Scratch training

python train_scratch.py --model wrn40_2 --dataset cifar10 --batch-size 256 --lr 0.1 --epoch 200 --gpu 0

4. Vanilla KD

# KD with original training data (beta>0 to use hard targets)
python vanilla_kd.py --teacher wrn40_2 --student wrn16_1 --dataset cifar10 --transfer_set cifar10 --beta 0.1 --batch-size 128 --lr 0.1 --epoch 200 --gpu 0 

# KD with unlabeled data
python vanilla_kd.py --teacher wrn40_2 --student wrn16_1 --dataset cifar10 --transfer_set cifar100 --beta 0 --batch-size 128 --lr 0.1 --epoch 200 --gpu 0 

# KD with unlabeled data from a specified folder
python vanilla_kd.py --teacher wrn40_2 --student wrn16_1 --dataset cifar10 --transfer_set run/cmi --beta 0 --batch-size 128 --lr 0.1 --epoch 200 --gpu 0 
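The `--beta` flag controls whether hard targets contribute to the loss. The sketch below illustrates one common form of the KD objective in plain Python, under the assumption that the loss is a temperature-scaled KL divergence to the teacher plus `beta` times the cross-entropy on the ground-truth label. The exact form used in `vanilla_kd.py` may differ; the function names and the temperature parameter `T` are illustrative.

```python
import math


def log_softmax(logits, T=1.0):
    """Log-probabilities of softmax(logits / T), computed stably."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    log_sum = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_sum for z in scaled]


def kd_loss(student_logits, teacher_logits, label=None, beta=0.0, T=1.0):
    """Soft-target KL term plus beta times a hard-target cross-entropy.

    Assumed form for illustration; not copied from vanilla_kd.py.
    """
    log_p_s = log_softmax(student_logits, T)
    log_p_t = log_softmax(teacher_logits, T)
    # KL(teacher || student), scaled by T^2 as is common in KD.
    soft = T * T * sum(
        math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s)
    )
    if beta == 0:
        return soft  # unlabeled transfer set: soft targets only
    if label is None:
        raise ValueError("beta > 0 requires a ground-truth label")
    hard = -log_softmax(student_logits)[label]
    return soft + beta * hard
```

With `beta=0` no label is needed, which is why the unlabeled and folder-based transfer sets above use `--beta 0`.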

5. Data-free KD

bash scripts/xxx/xxx.sh # e.g. scripts/zskt/zskt_cifar10_wrn402_wrn161.sh

Hyper-parameters used by different methods:

Method adv bn oh balance act cr GAN Example
DAFL - - - scripts/dafl_cifar10.sh
ZSKT - - - - - scripts/zskt_cifar10.sh
DeepInv - - - - scripts/deepinv_cifar10.sh
DFQ - - scripts/dfq_cifar10.sh
CMI - - scripts/cmi_cifar10_scratch.sh

6. Use your models/datasets

You can register your models and datasets in registry.py by modifying NORMALIZE_DICT, MODEL_DICT and get_dataset. You can then run the commands above to train your own models. Because DAFL requires intermediate features from the penultimate layer, your model should accept a return_features=True parameter and return a (logits, features) tuple when it is set.
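As a rough illustration of the `return_features` contract, here is a minimal PyTorch sketch. `TinyNet` and its layer sizes are hypothetical and are not one of the repository's models; only the `forward` signature and the `(logits, features)` return value follow the description above. How the class is added to `MODEL_DICT` depends on registry.py and is not shown.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy classifier; only the forward() contract matters."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x, return_features=False):
        features = self.backbone(x)  # penultimate-layer features
        logits = self.fc(features)
        if return_features:
            return logits, features
        return logits
```

Calling `TinyNet()(x)` returns logits only, so the same model still works with methods that do not need features.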

7. Implement your algorithms

Your algorithm should inherit from datafree.synthesis.BaseSynthesizer and implement two interfaces: 1) BaseSynthesizer.synthesize takes several steps to craft new samples and returns an image dict for visualization; 2) BaseSynthesizer.sample fetches a batch of training data for KD.
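The two-method contract can be sketched without any dependencies. In the example below, `SynthesizerInterface` is a stand-in written for illustration, not the real `datafree.synthesis.BaseSynthesizer`, whose constructor arguments and exact signatures are not described here. The toy `NoiseSynthesizer`, the `"synthetic"` dict key, and the sample bank are likewise assumptions.

```python
import random
from abc import ABC, abstractmethod


class SynthesizerInterface(ABC):
    """Stand-in for datafree.synthesis.BaseSynthesizer (assumed shape)."""

    @abstractmethod
    def synthesize(self):
        """Craft new samples; return a dict of images for visualization."""

    @abstractmethod
    def sample(self):
        """Return one batch of training data for KD."""


class NoiseSynthesizer(SynthesizerInterface):
    """Toy synthesizer: 'crafts' random vectors and keeps them in a bank."""

    def __init__(self, sample_dim=4, synthesis_batch_size=8,
                 sample_batch_size=4, seed=0):
        self.sample_dim = sample_dim
        self.synthesis_batch_size = synthesis_batch_size
        self.sample_batch_size = sample_batch_size
        self.rng = random.Random(seed)
        self.bank = []  # every sample crafted so far

    def synthesize(self):
        new = [[self.rng.gauss(0.0, 1.0) for _ in range(self.sample_dim)]
               for _ in range(self.synthesis_batch_size)]
        self.bank.extend(new)
        return {"synthetic": new}  # key name is an assumption

    def sample(self):
        if not self.bank:
            raise RuntimeError("call synthesize() before sample()")
        k = min(self.sample_batch_size, len(self.bank))
        return self.rng.sample(self.bank, k)
```

A training loop would typically alternate the two calls: one `synthesize()` to grow the pool of samples, then several `sample()` calls to draw KD batches from it.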

Citation

If you found this work useful for your research, please cite our paper:

@misc{fang2021contrastive,
      title={Contrastive Model Inversion for Data-Free Knowledge Distillation}, 
      author={Gongfan Fang and Jie Song and Xinchao Wang and Chengchao Shen and Xingen Wang and Mingli Song},
      year={2021},
      eprint={2105.08584},
      archivePrefix={arXiv},
      primaryClass={cs.AI}
}
