Proximal Backpropagation - a neural network training algorithm that takes implicit instead of explicit gradient steps

Proximal Backpropagation

Proximal Backpropagation (ProxProp) is a neural network training algorithm that takes implicit instead of explicit gradient steps to update the network parameters. We have analyzed this algorithm in our ICLR 2018 paper:

Proximal Backpropagation (Thomas Frerix, Thomas Möllenhoff, Michael Moeller, Daniel Cremers; ICLR 2018) [https://arxiv.org/abs/1706.04638]
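To make the difference between the two kinds of step concrete, here is a minimal sketch on a one-dimensional quadratic f(x) = 0.5 * c * x^2. It is an illustration only and is not taken from this repository; the function names, the curvature c and the step size tau are assumptions chosen for the example. The explicit step evaluates the gradient at the current point, while the implicit step evaluates it at the new point, which has a closed form for this quadratic.

```python
def explicit_step(x, tau, curvature):
    # Explicit (forward) step: gradient of 0.5 * curvature * x**2 at the current x.
    return x - tau * curvature * x


def implicit_step(x, tau, curvature):
    # Implicit (backward) step: solves x_new = x - tau * curvature * x_new for x_new.
    return x / (1.0 + tau * curvature)


def run(step, x0=1.0, tau=0.5, curvature=10.0, n_steps=20):
    # Apply the same step rule n_steps times, starting from x0.
    x = x0
    for _ in range(n_steps):
        x = step(x, tau, curvature)
    return x


explicit_result = run(explicit_step)
implicit_result = run(implicit_step)
print("explicit:", explicit_result)
print("implicit:", implicit_result)
```

With tau * c = 5, the explicit iteration multiplies x by -4 in every step and diverges, whereas the implicit iteration multiplies it by 1/6 and contracts for every positive step size.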

tl;dr

  • We provide a PyTorch implementation of ProxProp for Python 3 and PyTorch 1.0.1.
  • The results of our paper can be reproduced by executing the script paper_experiments.sh.
  • ProxProp is implemented as a torch.nn.Module (a 'layer') and can be combined with any other layer and first-order optimizer. While a ProxPropConv2d and a ProxPropLinear layer already exist, you can generate a ProxProp layer for your favorite linear layer with one line of code.

Installation

  1. Make sure you have a running Python 3 (tested with Python 3.7) ecosystem. We recommend that you use a conda install, as this is also the recommended option to get the latest PyTorch running. For this README and for the scripts, we assume that you have conda running with Python 3.7.
  2. Clone this repository and switch to the directory.
  3. Install the dependencies via conda install --file conda_requirements.txt and pip install -r pip_requirements.txt.
  4. Install PyTorch with magma support. We have tested our code with PyTorch 1.0.1 and CUDA 10.0. You can install this setup via
    conda install -c pytorch magma-cuda100
    conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
    
  5. (optional, but necessary to reproduce paper experiments) Download the CIFAR-10 dataset by executing get_data.sh

Training neural networks with ProxProp

ProxProp is implemented as a custom linear layer (torch.nn.Module) with its own backward pass, which takes implicit gradient steps on the network parameters. Because of this design choice, it can be combined with any other layer on which one takes explicit gradient steps. Furthermore, the resulting update direction can be used with any first-order optimizer that expects a suitable update direction in parameter space. In our paper we prove that ProxProp generates a descent direction and show experiments with Nesterov SGD and Adam.
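The idea of a layer with its own backward pass can be sketched with torch.autograd.Function. The code below is not the implementation in ProxProp.py. The names ProxLinearFunction, ProxLinear and tau are hypothetical, and the closed-form direction G^T X (I + tau X^T X)^{-1} is an assumption about what an exact proximal step looks like for a bias-free linear map. It also uses torch.linalg.solve, which requires a newer PyTorch than the 1.0.1 release this README was tested with.

```python
import torch
import torch.nn as nn


class ProxLinearFunction(torch.autograd.Function):
    """Hypothetical linear map whose backward returns a proximal weight direction."""

    @staticmethod
    def forward(ctx, x, weight, tau):
        ctx.save_for_backward(x, weight)
        ctx.tau = tau
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_x = grad_out @ weight        # usual gradient w.r.t. the input
        explicit = grad_out.t() @ x       # usual gradient w.r.t. the weight
        # Assumed implicit direction D: solve D (I + tau X^T X) = G^T X.
        eye = torch.eye(x.shape[1], dtype=x.dtype, device=x.device)
        system = eye + ctx.tau * (x.t() @ x)
        direction = torch.linalg.solve(system, explicit.t()).t()
        return grad_x, direction, None    # no gradient for tau


class ProxLinear(nn.Module):
    """Hypothetical bias-free layer wrapping the function above."""

    def __init__(self, in_features, out_features, tau=1.0):
        super().__init__()
        self.weight = nn.Parameter(0.1 * torch.randn(out_features, in_features))
        self.tau = tau

    def forward(self, x):
        return ProxLinearFunction.apply(x, self.weight, self.tau)


torch.manual_seed(0)
# The custom layer sits next to standard layers and a standard optimizer.
model = nn.Sequential(ProxLinear(5, 8), nn.ReLU(), nn.Linear(8, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9, nesterov=True)
inputs, targets = torch.randn(32, 5), torch.randn(32, 1)

loss = nn.functional.mse_loss(model(inputs), targets)
optimizer.zero_grad()
loss.backward()
prox_direction = model[0].weight.grad.clone()  # what the optimizer receives
optimizer.step()
```

The optimizer never sees how the direction was produced; it only reads the .grad attribute of each parameter, which is why such a layer can be mixed with ordinary layers.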

You can use our pre-defined layers ProxPropConv2d and ProxPropLinear, corresponding to nn.Conv2d and nn.Linear, by importing

from ProxProp import ProxPropConv2d, ProxPropLinear

Besides the usual layer parameters, as detailed in the PyTorch docs, you can provide:

  • tau_prox: step size for a proximal step; default is tau_prox=1
  • optimization_mode: can be one of 'prox_exact', 'prox_cg{N}', 'gradient' for an exact proximal step, an approximate proximal step with N conjugate gradient steps and an explicit gradient step, respectively; default is optimization_mode='prox_cg1'. The 'gradient' mode is for a fair comparison with SGD, as it incurs the same overhead as the other methods in exploiting a generic implementation with the provided PyTorch API.
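A short usage sketch with these two options might look as follows. The import path and the keyword names tau_prox and optimization_mode are the ones given in this README; the helper prox_linear, the layer sizes and the random data are assumptions made for the example. The fallback to nn.Linear only exists so that the snippet also runs outside the cloned repository, in which case all layers take plain explicit gradient steps.

```python
import torch
import torch.nn as nn

try:
    # Import path as stated in this README; only available inside the repository.
    from ProxProp import ProxPropLinear

    def prox_linear(n_in, n_out):
        # Approximate proximal step with 3 conjugate gradient iterations.
        return ProxPropLinear(n_in, n_out, tau_prox=1, optimization_mode='prox_cg3')
except ImportError:
    def prox_linear(n_in, n_out):
        # Fallback for illustration only: an ordinary linear layer.
        return nn.Linear(n_in, n_out)

torch.manual_seed(0)
model = nn.Sequential(prox_linear(20, 32), nn.ReLU(), nn.Linear(32, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One training step on random data (hypothetical shapes: 16 samples, 10 classes).
inputs = torch.randn(16, 20)
labels = torch.randint(0, 10, (16,))
logits = model(inputs)
loss = nn.functional.cross_entropy(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```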

If you want to use ProxProp to optimize your favorite linear layer, you can generate the respective module with one line of code. As an example for the Conv3d layer:

from ProxProp import proxprop_module_generator
ProxPropConv3d = proxprop_module_generator(torch.nn.Conv3d)

This gives you a default implementation for the approximate conjugate gradient solver, which treats all parameters as a stacked vector. If you want to use the exact solver or want to use the conjugate gradient solver more efficiently, you have to provide the respective reshaping methods to proxprop_module_generator, as this requires specific knowledge of the layer's structure and cannot be implemented generically. As a template, take a look at the ProxProp.py file, where we have done this for the ProxPropLinear layer.
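The difference between an exact solve and a few conjugate gradient iterations on a stacked parameter vector can be illustrated with NumPy. This is a sketch under assumptions, not the solver in ProxProp.py: the function conjugate_gradient, the matrix X and the system (I + tau * X^T X) d = g are hypothetical stand-ins for the linear system behind a proximal step.

```python
import numpy as np


def conjugate_gradient(apply_A, b, n_steps):
    """Run at most n_steps of CG on A x = b from x = 0 (A symmetric positive definite)."""
    x = np.zeros_like(b)
    r = b.copy()              # residual b - A x for the initial x = 0
    p = r.copy()              # first search direction
    rs_old = r @ r
    for _ in range(n_steps):
        if rs_old < 1e-30:    # already converged, avoid dividing by zero
            break
        Ap = apply_A(p)
        alpha = rs_old / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


rng = np.random.default_rng(0)
X = rng.standard_normal((8, 4))   # hypothetical layer input: 8 samples, 4 features
g = rng.standard_normal(4)        # hypothetical stacked gradient vector
tau = 1.0


def apply_A(v):
    # Matrix-free product with I + tau * X^T X; the matrix is never formed here.
    return v + tau * (X.T @ (X @ v))


exact = np.linalg.solve(np.eye(4) + tau * X.T @ X, g)
err_cg1 = np.linalg.norm(conjugate_gradient(apply_A, g, 1) - exact)
err_cg4 = np.linalg.norm(conjugate_gradient(apply_A, g, 4) - exact)
print("error after 1 CG step:", err_cg1)
print("error after 4 CG steps:", err_cg4)
```

The generic variant only needs matrix-vector products with the system, which is why it can treat all parameters as one stacked vector. An exact solve needs the matrix itself and therefore knowledge of the layer's structure.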

By reusing the forward/backward implementations of existing PyTorch modules, ProxProp becomes readily accessible. However, we pay an overhead associated with generically constructing the backward pass using the PyTorch API. We have intentionally sided with genericity over speed.

Reproduce paper experiments

To reproduce the paper experiments execute the script paper_experiments.sh. This will run our paper's experiments, store the results in the directory paper_experiments/ and subsequently compile the results into the file paper_plots.pdf. We use an NVIDIA Titan X GPU; executing the script takes roughly 3 hours.

Acknowledgement

We want to thank Soumith Chintala for helping us track down a mysterious bug and the whole PyTorch dev team for their continued development effort and great support to the community.

Publication

If you use ProxProp, please acknowledge our paper by citing

@article{Frerix-et-al-18,
    title = {Proximal Backpropagation},
    author={Thomas Frerix and Thomas Möllenhoff and Michael Moeller and Daniel Cremers},
    journal={International Conference on Learning Representations},
    year={2018},
    url = {https://arxiv.org/abs/1706.04638}
}

Owner: Thomas Frerix