FunMatch-Distillation

TF2 implementation of knowledge distillation using the "function matching" hypothesis from the paper Knowledge distillation: A good teacher is patient and consistent by Beyer et al.

The techniques have been demonstrated using three datasets: Flowers102, Pet37, and Food101.

This repository provides Kaggle Kernel notebooks so that we can leverage the free TPU v3-8 to run the long training schedules. Please refer to this section.
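
Since the notebooks run on TPUs, a TPU strategy needs to be initialized before any model is built. The snippet below is the standard TF2 bootstrapping idiom for a Kaggle TPU v3-8; it is a generic sketch, not code copied from these notebooks.

    import tensorflow as tf

    # Detect and initialize the TPU; assumes the notebook runs on a TPU runtime.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)

    # Anything created under this scope is replicated across the 8 TPU cores.
    with strategy.scope():
        model = tf.keras.Sequential([tf.keras.layers.Dense(10)])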

Importance

The importance of knowledge distillation lies in its practical usefulness. With the recipes from "function matching", we can now perform knowledge distillation using a principled approach that yields student models actually matching the performance of their teacher models. This essentially allows us to compress bigger models into (much) smaller ones, thereby reducing storage costs and improving inference speed.

Key ingredients

  • No use of ground-truth labels during distillation.
  • The teacher and the student should see the same images during distillation, as opposed to differently augmented views of the same images (a minimal loss sketch follows this list).
  • An aggressive form of MixUp as the key augmentation recipe, paired with "Inception-style" cropping (implemented in this script).
  • A LONG training schedule for distillation: at least 1000 epochs are needed to get good results without overfitting. As studied in the paper, the length of the schedule is paramount.
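
To make the first two ingredients concrete, here is a minimal sketch of the distillation objective, assuming teacher and student are Keras models that return logits and images is a batch that both models see identically (all names here are illustrative, not this repo's exact API):

    import tensorflow as tf

    kl = tf.keras.losses.KLDivergence()

    def distillation_loss(teacher, student, images, temperature=1.0):
        # Teacher and student see exactly the same augmented batch.
        teacher_probs = tf.nn.softmax(teacher(images, training=False) / temperature)
        student_probs = tf.nn.softmax(student(images, training=True) / temperature)
        # Pure "function matching": KL divergence between the two output
        # distributions; ground-truth labels are never used.
        return kl(teacher_probs, student_probs)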

Results

The table below summarizes the results of my experiments. In all cases, the teacher is a BiT-ResNet101x3 model and the student is a BiT-ResNet50x1 model. For fun, you can also try to distill into other model families. BiT stands for "Big Transfer", and it was proposed in this paper.
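
One way to instantiate such a teacher/student pair is via the publicly released BiT-M feature extractors on TF Hub, as sketched below; the hub handles and the zero-initialized head are assumptions for illustration, not necessarily what this repo loads.

    import tensorflow as tf
    import tensorflow_hub as hub

    def build_bit_model(hub_handle, num_classes):
        # BiT convention: a zero-initialized dense head on top of the backbone.
        return tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(224, 224, 3)),
            hub.KerasLayer(hub_handle),
            tf.keras.layers.Dense(num_classes, kernel_initializer="zeros"),
        ])

    teacher = build_bit_model("https://tfhub.dev/google/bit/m-r101x3/1", num_classes=37)
    student = build_bit_model("https://tfhub.dev/google/bit/m-r50x1/1", num_classes=37)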

Dataset      Teacher/Student          Top-1 Acc on Test    Location
Flowers102   Teacher                  98.18%               Link
Flowers102   Student (1000 epochs)    81.02%               Link
Pet37        Teacher                  90.92%               Link
Pet37        Student (300 epochs)     81.3%                Link
Pet37        Student (1000 epochs)    86%                  Link
Food101      Teacher                  85.52%               Link
Food101      Student (100 epochs)     76.06%               Link

(Location denotes the trained model location.)

These results are consistent with Table 4 of the original paper.

It should be noted that none of the above student training regimes showed signs of overfitting, so further improvements can likely be had by training even longer. The authors also showed that Shampoo can reach similar performance much more quickly than Adam during distillation. So it may well be possible to match these numbers in fewer epochs with Shampoo.

A few differences from the original implementation:

  • The authors use a BiT-ResNet152x2 as the teacher.
  • The mixup() variant I used will produce a pair of duplicate images if the number of images is even. With 8 workers, this becomes 8 duplicate pairs, which may have led to the reduced performance. This can be overcome by using tf.roll(images, 1, axis=0) instead of tf.reverse inside the mixup() function (see the sketch after this list). Thanks to Lucas Beyer for pointing this out.
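
Here is a minimal sketch of that fix, assuming images is a float batch tensor of shape (batch, height, width, channels); this is a hedged illustration, not the repo's exact mixup() implementation.

    import tensorflow as tf

    def mixup(images, alpha=0.2):
        # Sample lambda ~ Beta(alpha, alpha) from two Gamma draws,
        # since tf.random exposes no Beta sampler.
        g1 = tf.random.gamma(shape=[], alpha=alpha)
        g2 = tf.random.gamma(shape=[], alpha=alpha)
        lam = g1 / (g1 + g2)

        # tf.roll pairs every image with its neighbor, avoiding the
        # duplicated pairs that tf.reverse can create. No labels are
        # mixed here: the teacher is queried on the mixed images instead.
        return lam * images + (1.0 - lam) * tf.roll(images, shift=1, axis=0)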

About the notebooks

All the notebooks are fully runnable on Kaggle Kernels. The only requirement is a billing-enabled GCP account, since the notebooks use GCS buckets to store data.

Notebook                      Description                                                                       Kaggle Kernel
train_bit.ipynb               Shows how to train the teacher model.                                             Link
train_bit_keras_tuner.ipynb   Shows how to run hyperparameter tuning for the teacher model using Keras Tuner.  Link
funmatch_distillation.ipynb   Shows an implementation of the recipes from "function matching".                 Link

These are demonstrated only on the Pet37 dataset but will work out of the box for the other datasets too.

TFRecords

For convenience, TFRecords of different datasets are provided:

Dataset TFRecords
Flowers102 Link
Pet37 Link
Food101 Link
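
These files can be consumed with a standard tf.data pipeline along the lines of the sketch below; the feature keys, file pattern, and image size are assumptions, so check the notebooks for the actual parsing logic.

    import tensorflow as tf

    # Hypothetical feature spec; the real keys in these TFRecords may differ.
    FEATURE_SPEC = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }

    def parse_example(serialized):
        example = tf.io.parse_single_example(serialized, FEATURE_SPEC)
        image = tf.image.resize(
            tf.io.decode_jpeg(example["image"], channels=3), [224, 224]
        )
        return image, example["label"]

    files = tf.io.gfile.glob("gs://your-bucket/pet37/*.tfrec")  # hypothetical path
    dataset = (
        tf.data.TFRecordDataset(files)
        .map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(64)
        .prefetch(tf.data.AUTOTUNE)
    )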

Paper citation

@misc{beyer2021knowledge,
      title={Knowledge distillation: A good teacher is patient and consistent}, 
      author={Lucas Beyer and Xiaohua Zhai and Amélie Royer and Larisa Markeeva and Rohan Anil and Alexander Kolesnikov},
      year={2021},
      eprint={2106.05237},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgements

Huge thanks to Lucas Beyer (first author of the paper) for providing suggestions on the initial version of the implementation.

Thanks to the ML-GDE program for providing GCP credits.

Thanks to TRC for providing Cloud TPU access.
