DiffStride: Learning strides in convolutional neural networks

Overview

DiffStride is a pooling layer with learnable strides. Unlike strided convolutions, average pooling, or max-pooling, which require cross-validating stride values at each layer, DiffStride can be initialized with an arbitrary value at each layer (e.g. (2, 2)), and its strides are then optimized during training for the task at hand.

We describe DiffStride in our ICLR 2022 paper, Learning Strides in Convolutional Neural Networks. Compared to the experiments described in the paper, this implementation uses a Pre-Act ResNet and Mixup during training.
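For intuition, the fixed spectral pooling baseline used later in this README can be sketched in one dimension: transform to the Fourier domain, keep only the lowest frequencies, and transform back. The sketch below is not the library's implementation. The helper names (`dft`, `idft`, `spectral_pool_1d`) and the rule `output length = round(n / stride)` are assumptions made for illustration, and it uses a naive DFT from the standard library instead of the TensorFlow FFT. According to the paper's abstract, DiffStride goes further and learns the size of the cropping mask in the Fourier domain; that learnable part is not shown here.

```python
import cmath
import math


def dft(signal):
    """Naive discrete Fourier transform of a 1D sequence."""
    n = len(signal)
    return [
        sum(x * cmath.exp(-2j * math.pi * k * t / n) for t, x in enumerate(signal))
        for k in range(n)
    ]


def idft(spectrum):
    """Naive inverse discrete Fourier transform."""
    n = len(spectrum)
    return [
        sum(c * cmath.exp(2j * math.pi * k * t / n) for k, c in enumerate(spectrum)) / n
        for t in range(n)
    ]


def spectral_pool_1d(signal, stride):
    """Downsamples a real 1D signal by cropping its spectrum (hypothetical helper).

    Assumption for this sketch: the output length is round(n / stride), so a
    non-integer stride is allowed and only changes how many frequencies are kept.
    """
    n = len(signal)
    m = max(1, round(n / stride))
    spectrum = dft(signal)
    # Low positive frequencies sit at the start of the spectrum,
    # low negative frequencies at the end.
    n_pos = (m + 1) // 2  # includes the zero-frequency (mean) term
    n_neg = m - n_pos
    kept = spectrum[:n_pos] + (spectrum[n - n_neg:] if n_neg else [])
    # Rescale by m / n so that the mean of the signal is preserved, and drop
    # the small imaginary residue that cropping a real signal can leave.
    return [(value * m / n).real for value in idft(kept)]


constant = spectral_pool_1d([1.0] * 8, stride=2)
print(len(constant), [round(v, 6) for v in constant])

ramp = spectral_pool_1d([float(i) for i in range(12)], stride=1.5)
print(len(ramp))
```

Because the stride only controls how many frequencies survive the crop, nothing in this formulation forces it to be an integer, which is what makes a learnable stride plausible in the first place.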

Installation

To install the diffstride library, first git clone this repo:

git clone https://github.com/google-research/diffstride.git

Then cd into the repository root and run:

pip install -e .

Example training

To run an example training on CIFAR-10 and log the results to TensorBoard:

python3 -m diffstride.examples.main \
  --gin_config=cifar10.gin \
  --gin_bindings="train.workdir = '/tmp/exp/diffstride/resnet18/'"

Using custom parameters

This implementation uses Gin to parametrize the model, data processing and training loop. To use custom parameters, one should edit examples/cifar10.gin.

For example, to train with SpectralPooling on cifar100:

data.load_datasets:
  name = 'cifar100'

resnet.Resnet:
  pooling_cls = @pooling.FixedSpectralPooling

Or to train with strided convolutions and without Mixup:

data.load_datasets:
  mixup_alpha = 0.0

resnet.Resnet:
  pooling_cls = None

Results

The current implementation gives the following accuracies on CIFAR-10 and CIFAR-100, averaged over three runs. To show the robustness of DiffStride to stride initialization, we run with both the standard ResNet strides (resnet.resnet18.strides = '1, 1, 2, 2, 2') and a 'poor' choice of strides (resnet.resnet18.strides = '1, 1, 3, 2, 3'). Unlike strided convolutions and fixed spectral pooling, DiffStride reaches essentially the same accuracy with either stride initialization.
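The two stride settings can be written out as Gin binding lines with a small helper. This is only a sketch: `gin_binding`, `bindings_for`, and the per-setting work directory names are hypothetical and not part of the repository. Only the parameter names `train.workdir` and `resnet.resnet18.strides` and the stride strings come from this README.

```python
# Stride strings quoted in this README.
STRIDE_SETTINGS = {
    "standard": "1, 1, 2, 2, 2",
    "poor": "1, 1, 3, 2, 3",
}


def gin_binding(name, value):
    """Formats one Gin binding line of the form: name = 'value'."""
    return f"{name} = {value!r}"


def bindings_for(setting, workdir_root="/tmp/exp/diffstride"):
    """Returns the binding lines for one stride setting.

    The work directory layout is an assumption made for this example.
    """
    workdir = f"{workdir_root}/resnet18_{setting}/"
    return [
        gin_binding("train.workdir", workdir),
        gin_binding("resnet.resnet18.strides", STRIDE_SETTINGS[setting]),
    ]


for setting in STRIDE_SETTINGS:
    print(f"# {setting} strides")
    for line in bindings_for(setting):
        print(line)
```

The printed lines could be added to examples/cifar10.gin. Passing them through --gin_bindings, as is done for train.workdir in the training command, should also work, but this README only shows that flag with a single binding.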

CIFAR-10

| Pooling | Test Accuracy (%) w/ strides = (1, 1, 2, 2, 2) | Test Accuracy (%) w/ strides = (1, 1, 3, 2, 3) |
| --- | --- | --- |
| Strided Convolution (Baseline) | 91.06 ± 0.04 | 89.21 ± 0.27 |
| Spectral Pooling | 93.49 ± 0.05 | 92.00 ± 0.08 |
| DiffStride | 94.20 ± 0.06 | 94.19 ± 0.15 |

CIFAR-100

| Pooling | Test Accuracy (%) w/ strides = (1, 1, 2, 2, 2) | Test Accuracy (%) w/ strides = (1, 1, 3, 2, 3) |
| --- | --- | --- |
| Strided Convolution (Baseline) | 65.75 ± 0.39 | 60.82 ± 0.42 |
| Spectral Pooling | 72.86 ± 0.23 | 67.74 ± 0.43 |
| DiffStride | 76.08 ± 0.23 | 76.09 ± 0.06 |
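The robustness claim can be checked with simple arithmetic on the mean accuracies reported above. The short script below computes the drop in mean test accuracy when moving from the standard strides to the 'poor' strides; the dictionary layout and the function name are chosen for this example only.

```python
# Mean test accuracies (%) copied from the tables above:
# (standard strides, 'poor' strides).
RESULTS = {
    "CIFAR-10": {
        "Strided Convolution (Baseline)": (91.06, 89.21),
        "Spectral Pooling": (93.49, 92.00),
        "DiffStride": (94.20, 94.19),
    },
    "CIFAR-100": {
        "Strided Convolution (Baseline)": (65.75, 60.82),
        "Spectral Pooling": (72.86, 67.74),
        "DiffStride": (76.08, 76.09),
    },
}


def accuracy_drop(standard, poor):
    """Accuracy lost (in percentage points) with the 'poor' strides."""
    return round(standard - poor, 2)


for dataset, rows in RESULTS.items():
    print(dataset)
    for pooling, (standard, poor) in rows.items():
        print(f"  {pooling}: {accuracy_drop(standard, poor):+.2f} points")
```

The fixed-stride methods lose between 1.49 and 5.12 points, while the DiffStride means differ by 0.01 points in either direction, which is well within the reported standard deviations.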

CPU/GPU Warning

We rely on the TensorFlow FFT implementation, which requires the input data to be in the channels_first format. This is not the usual data format of most datasets (including CIFAR), and running with channels_first also prevents the use of convolutions on CPU. Therefore, although we support the channels_last data format for CPU compatibility, we encourage users to run with the channels_first data format on GPU.
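To make the two layouts concrete, the sketch below reorders a tiny batch from channels_last (N, H, W, C) to channels_first (N, C, H, W) using plain nested lists. The function names are hypothetical and this is not how the library converts data; with TensorFlow tensors the same reordering is normally done with tf.transpose and perm=[0, 3, 1, 2].

```python
def channels_last_to_first(batch):
    """Reorders a nested-list batch from (N, H, W, C) to (N, C, H, W)."""
    converted = []
    for image in batch:
        num_channels = len(image[0][0])
        converted.append([
            # For each channel, collect that channel's value at every pixel.
            [[pixel[c] for pixel in row] for row in image]
            for c in range(num_channels)
        ])
    return converted


def shape(nested):
    """Returns the shape of a regular nested list as a tuple."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


# One 2 x 2 image with 3 channels, stored channels_last.
batch = [[[[1, 2, 3], [4, 5, 6]],
          [[7, 8, 9], [10, 11, 12]]]]
converted = channels_last_to_first(batch)
print(shape(batch), "->", shape(converted))
```

For a CIFAR batch this corresponds to going from (N, 32, 32, 3) to (N, 3, 32, 32).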

Reference

If you use this repository, please consider citing:

@article{riad2022diffstride,
  title={Learning Strides in Convolutional Neural Networks},
  author={Riad, Rachid and Teboul, Olivier and Grangier, David and Zeghidour, Neil},
  journal={ICLR},
  year={2022}
}

Disclaimer

This is not an official Google product.
