2.86% and 15.85% on CIFAR-10 and CIFAR-100

Overview

Shake-Shake regularization

This repository contains the code for the paper Shake-Shake regularization. This arxiv paper is an extension of Shake-Shake regularization of 3-branch residual networks which was accepted as a workshop contribution at ICLR 2017.

The code is based on fb.resnet.torch.

Table of Contents

  1. Introduction
  2. Results
  3. Usage
  4. Contact

Introduction

The method introduced in this paper aims at helping deep learning practitioners faced with an overfit problem. The idea is to replace, in a multi-branch network, the standard summation of parallel branches with a stochastic affine combination. Applied to 3-branch residual networks, shake-shake regularization improves on the best single shot published results on CIFAR-10 and CIFAR-100 by reaching test errors of 2.86% and 15.85%.

shake-shake

Figure 1: Left: Forward training pass. Center: Backward training pass. Right: At test time.

Bibtex:

@article{Gastaldi17ShakeShake,
   title = {Shake-Shake regularization},
   author = {Xavier Gastaldi},
   journal = {arXiv preprint arXiv:1705.07485},
   year = 2017,
}

Results on CIFAR-10

The base network is a 26 2x32d ResNet (i.e. the network has a depth of 26, 2 residual branches and the first residual block has a width of 32). "Shake" means that all scaling coefficients are overwritten with new random numbers before the pass. "Even" means that all scaling coefficients are set to 0.5 before the pass. "Keep" means that we keep, for the backward pass, the scaling coefficients used during the forward pass. "Batch" means that, for each residual block, we apply the same scaling coefficient for all the images in the mini-batch. "Image" means that, for each residual block, we apply a different scaling coefficient for each image in the mini-batch. The numbers in the Table below represent the average of 3 runs except for the 96d models which were run 5 times.

Forward Backward Level 26 2x32d 26 2x64d 26 2x96d 26 2x112d
Even Even n\a 4.27 3.76 3.58 -
Even Shake Batch 4.44 - -
Shake Keep Batch 4.11 - - -
Shake Even Batch 3.47 3.30 - -
Shake Shake Batch 3.67 3.07 - -
Even Shake Image 4.11 - - -
Shake Keep Image 4.09 - - -
Shake Even Image 3.47 3.20 - -
Shake Shake Image 3.55 2.98 2.86 2.821

Table 1: Error rates (%) on CIFAR-10 (Top 1 of the last epoch)

Other results

Cifar-100:
29 2x4x64d: 15.85%

Reduced CIFAR-10:
26 2x96d: 17.05%1

SVHN:
26 2x96d: 1.4%1

Reduced SVHN:
26 2x96d: 12.32%1

Usage

  1. Install fb.resnet.torch, optnet and lua-stdlib.
  2. Download Shake-Shake
git clone https://github.com/xgastaldi/shake-shake.git
  1. Copy the elements in the shake-shake folder and paste them in the fb.resnet.torch folder. This will overwrite 5 files (main.lua, train.lua, opts.lua, checkpoints.lua and models/init.lua) and add 4 new files (models/shakeshake.lua, models/shakeshakeblock.lua, models/mulconstantslices.lua and models/shakeshaketable.lua).
  2. To reproduce CIFAR-10 results (e.g. 26 2x32d "Shake-Shake-Image" ResNet) on 2 GPUs:
CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar10 -nGPU 2 -batchSize 128 -depth 26 -shareGradInput false -optnet true -nEpochs 1800 -netType shakeshake -lrShape cosine -baseWidth 32 -LR 0.2 -forwardShake true -backwardShake true -shakeImage true

To get comparable results using 1 GPU, please change the batch size and the corresponding learning rate:

CUDA_VISIBLE_DEVICES=0 th main.lua -dataset cifar10 -nGPU 1 -batchSize 64 -depth 26 -shareGradInput false -optnet true -nEpochs 1800 -netType shakeshake -lrShape cosine -baseWidth 32 -LR 0.1 -forwardShake true -backwardShake true -shakeImage true

A 26 2x96d "Shake-Shake-Image" ResNet can be trained on 2 GPUs using:

CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar10 -nGPU 2 -batchSize 128 -depth 26 -shareGradInput false -optnet true -nEpochs 1800 -netType shakeshake -lrShape cosine -baseWidth 96 -LR 0.2 -forwardShake true -backwardShake true -shakeImage true
  1. To reproduce CIFAR-100 results (e.g. 29 2x4x64d "Shake-Even-Image" ResNeXt) on 2 GPUs:
CUDA_VISIBLE_DEVICES=0,1 th main.lua -dataset cifar100 -depth 29 -baseWidth 64 -groups 4 -weightDecay 5e-4 -batchSize 32 -netType shakeshake -nGPU 2 -LR 0.025 -nThreads 8 -shareGradInput true -nEpochs 1800 -lrShape cosine -forwardShake true -backwardShake false -shakeImage true

Note

Changes made to fb.resnet.torch files:

main.lua
Ln 17, 54-59, 81-100: Adds a log

train.lua
Ln 36-38 58-60 206-213: Adds the cosine learning rate function
Ln 88-89: Adds the learning rate to the elements printed on screen

opts.lua
Ln 21-64: Adds Shake-Shake options

checkpoints.lua
Ln 15-16: Adds require 'models/shakeshakeblock', 'models/shakeshaketable' and require 'std'
Ln 60-61: Avoids using the fb.resnet.torch deepcopy (it doesn't seem to be compatible with the BN in shakeshakeblock) and replaces it with the deepcopy from stdlib
Ln 67-86: Saves only the last model

models/init.lua
Ln 91-92: Adds require 'models/mulconstantslices', require 'models/shakeshakeblock' and require 'models/shakeshaketable'

The main model is in shakeshake.lua. The residual block model is in shakeshakeblock.lua. mulconstantslices.lua is just an extension of nn.mulconstant that multiplies elements of a vector with image slices of a mini-batch tensor. shakeshaketable.lua contains the method used for CIFAR-100 since the ResNeXt code uses a table implementation instead of a module version.

Reimplementations

Pytorch
https://github.com/hysts/pytorch_shake_shake

Tensorflow
https://github.com/tensorflow/models/blob/master/research/autoaugment/
https://github.com/tensorflow/tensor2tensor

Contact

xgastaldi.mba2011 at london.edu
Any discussions, suggestions and questions are welcome!

References

(1) Ekin D. Cubuk, Barret Zoph, Dandelion Mane, Vijay Vasudevan, and Quoc V. Le. AutoAugment: Learning Augmentation Policies from Data. In arXiv:1805.09501, May 2018.

Code repository for "Stable View Synthesis".

Stable View Synthesis Code repository for "Stable View Synthesis". Setup Install the following Python packages in your Python environment - numpy (1.1

Intelligent Systems Lab Org 195 Dec 24, 2022
Official implementation of the method ContIG, for self-supervised learning from medical imaging with genomics

ContIG: Self-supervised Multimodal Contrastive Learning for Medical Imaging with Genetics This is the code implementation of the paper "ContIG: Self-s

Digital Health & Machine Learning 22 Dec 13, 2022
[CVPR 2021] Unsupervised 3D Shape Completion through GAN Inversion

ShapeInversion Paper Junzhe Zhang, Xinyi Chen, Zhongang Cai, Liang Pan, Haiyu Zhao, Shuai Yi, Chai Kiat Yeo, Bo Dai, Chen Change Loy "Unsupervised 3D

100 Dec 22, 2022
Adversarial Framework for (non-) Parametric Image Stylisation Mosaics

Fully Adversarial Mosaics (FAMOS) Pytorch implementation of the paper "Copy the Old or Paint Anew? An Adversarial Framework for (non-) Parametric Imag

Zalando Research 120 Dec 24, 2022
3D dataset of humans Manipulating Objects in-the-Wild (MOW)

MOW dataset [Website] This repository maintains our 3D dataset of humans Manipulating Objects in-the-Wild (MOW). The dataset contains 512 images in th

Zhe Cao 28 Nov 06, 2022
Pytorch implementation for RelTransformer

RelTransformer Our Architecture This is a Pytorch implementation for RelTransformer The implementation for Evaluating on VG200 can be found here Requi

Vision CAIR Research Group, KAUST 21 Nov 22, 2022
List of papers, code and experiments using deep learning for time series forecasting

Deep Learning Time Series Forecasting List of state of the art papers focus on deep learning and resources, code and experiments using deep learning f

Alexander Robles 2k Jan 06, 2023
Publication describing 3 ML examples at NSLS-II and interfacing into Bluesky

Machine learning enabling high-throughput and remote operations at large-scale user facilities. Overview This repository contains the source code and

BNL 4 Sep 24, 2022
Survival analysis (SA) is a well-known statistical technique for the study of temporal events.

DAGSurv Survival analysis (SA) is a well-known statistical technique for the study of temporal events. In SA, time-to-an-event data is modeled using a

Rahul Kukreja 1 Sep 05, 2022
The Habitat-Matterport 3D Research Dataset - the largest-ever dataset of 3D indoor spaces.

Habitat-Matterport 3D Dataset (HM3D) The Habitat-Matterport 3D Research Dataset is the largest-ever dataset of 3D indoor spaces. It consists of 1,000

Meta Research 62 Dec 27, 2022
On Size-Oriented Long-Tailed Graph Classification of Graph Neural Networks

On Size-Oriented Long-Tailed Graph Classification of Graph Neural Networks We provide the code (in PyTorch) and datasets for our paper "On Size-Orient

Zemin Liu 4 Jun 18, 2022
clustering moroccan stocks time series data using k-means with dtw (dynamic time warping)

Moroccan Stocks Clustering Context Hey! we don't always have to forecast time series am I right ? We use k-means to cluster about 70 moroccan stock pr

Ayman Lafaz 7 Oct 18, 2022
Repository features UNet inspired architecture used for segmenting lungs on chest X-Ray images

Lung Segmentation (2D) Repository features UNet inspired architecture used for segmenting lungs on chest X-Ray images. Demo See the application of the

163 Sep 21, 2022
PyTorch code for Vision Transformers training with the Self-Supervised learning method DINO

Self-Supervised Vision Transformers with DINO PyTorch implementation and pretrained models for DINO. For details, see Emerging Properties in Self-Supe

Facebook Research 4.2k Jan 03, 2023
Social Fabric: Tubelet Compositions for Video Relation Detection

Social-Fabric Social Fabric: Tubelet Compositions for Video Relation Detection This repository contains the code and results for the following paper:

Shuo Chen 7 Aug 09, 2022
A novel method to tune language models. Codes and datasets for paper ``GPT understands, too''.

P-tuning A novel method to tune language models. Codes and datasets for paper ``GPT understands, too''. How to use our code We have released the code

THUDM 562 Dec 27, 2022
A collection of awesome resources image-to-image translation.

awesome image-to-image translation A collection of resources on image-to-image translation. Contributing If you think I have missed out on something (

876 Dec 28, 2022
A lane detection integrated Real-time Instance Segmentation based on YOLACT (You Only Look At CoefficienTs)

Real-time Instance Segmentation and Lane Detection This is a lane detection integrated Real-time Instance Segmentation based on YOLACT (You Only Look

Jin 4 Dec 30, 2022
JupyterNotebook - C/C++, Javascript, HTML, LaTex, Shell scripts in Jupyter Notebook Also run them on remote computer

JupyterNotebook Read, write and execute C, C++, Javascript, Shell scripts, HTML, LaTex in jupyter notebook, And also execute them on remote computer R

1 Jan 09, 2022
K-FACE Analysis Project on Pytorch

Installation Setup with Conda # create a new environment conda create --name insightKface python=3.7 # or over conda activate insightKface #install t

Jung Jun Uk 7 Nov 10, 2022