A tf.keras implementation of Facebook AI's MadGrad optimization algorithm

Overview

MADGRAD Optimization Algorithm For Tensorflow

This package implements the MadGrad Algorithm proposed in Adaptivity without Compromise: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization (Aaron Defazio and Samy Jelassi, 2021).

MIT License version-shield release-shield python-shield code-style

Table of Contents
  1. About The Project
  2. Getting Started
  3. Usage
  4. Contributing
  5. License
  6. Contact
  7. Citations

About The Project

The MadGrad algorithm of optimization uses Dual averaging of gradients along with momentum based adaptivity to attain results that match or outperform Adam or SGD + momentum based algorithms. This project offers a Tensorflow implementation of the algorithm along with a few usage examples and tests.



Prerequisites

Prerequisites can be installed separately through the requirements.txt file as below

pip install -r requirements.txt

Installation

This project is built with Python 3 and can be pip installed directly

pip install tf-madgrad

Usage

Open In Colab

To use the optimizer in any tf.keras model, you just need to import and instantiate the MadGrad optimizer from the tf_madgrad package.

from madgrad import MadGrad

# Create the architecture
inp = tf.keras.layers.Input(shape=shape)
...
op = tf.keras.layers.Dense(classes, activation=activation)

# Instantiate the model
model = tf.keras.models.Model(inp, op)

# Pass the MadGrad optimizer to the compile function
model.compile(optimizer=MadGrad(lr=0.01), loss=loss)

# Fit the keras model as normal
model.fit(...)

This implementation is also supported for distributed training using tf.strategy

See a MNIST example here

Contributing

Any and all contributions are welcome. Please raise an issue if the optimizer gives incorrect results or crashes unexpectedly during training.

License

Distributed under the MIT License. See LICENSE for more information.

Contact

Feel free to reach out for any issues or requests related to this implementation

Darshan Deshpande - Email | LinkedIn

Citations

@misc{defazio2021adaptivity,
      title={Adaptivity without Compromise: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization}, 
      author={Aaron Defazio and Samy Jelassi},
      year={2021},
      eprint={2101.11075},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
Owner
Helping Machines Learn Better 💻😃
FinRL­-Meta: A Universe for Data­-Driven Financial Reinforcement Learning. 🔥

FinRL-Meta: A Universe of Market Environments. FinRL-Meta is a universe of market environments for data-driven financial reinforcement learning. Users

AI4Finance Foundation 543 Jan 08, 2023
Investigating Attention Mechanism in 3D Point Cloud Object Detection (arXiv 2021)

Investigating Attention Mechanism in 3D Point Cloud Object Detection (arXiv 2021) This repository is for the following paper: "Investigating Attention

52 Nov 19, 2022
Bi-level feature alignment for versatile image translation and manipulation (Under submission of TPAMI)

Bi-level feature alignment for versatile image translation and manipulation (Under submission of TPAMI) Preparation Clone the Synchronized-BatchNorm-P

Fangneng Zhan 12 Aug 10, 2022
Python script that takes an Impulse response .wav and a input .wav to demonstrate audio convolution.

convolver Python script that takes an Impulse response .wav and a input .wav to demonstrate audio convolution. Created by Sean Higley

Sean Higley 1 Feb 23, 2022
Pytorch implementation of NEGEV method. Paper: "Negative Evidence Matters in Interpretable Histology Image Classification".

Pytorch 1.10.0 code for: Negative Evidence Matters in Interpretable Histology Image Classification (https://arxiv. org/abs/xxxx.xxxxx) Citation: @arti

Soufiane Belharbi 4 Dec 01, 2022
GLANet - The code for Global and Local Alignment Networks for Unpaired Image-to-Image Translation arxiv

GLANet The code for Global and Local Alignment Networks for Unpaired Image-to-Image Translation arxiv Framework: visualization results: Getting Starte

stanley 29 Dec 14, 2022
Replication Package for AequeVox:Automated Fariness Testing for Speech Recognition Systems

AequeVox Replication Package for AequeVox:Automated Fariness Testing for Speech Recognition Systems README under development. Python Packages Required

Sai Sathiesh 2 Aug 28, 2022
Implementation for our ICCV 2021 paper: Dual-Camera Super-Resolution with Aligned Attention Modules

DCSR: Dual Camera Super-Resolution Implementation for our ICCV 2021 oral paper: Dual-Camera Super-Resolution with Aligned Attention Modules paper | pr

Tengfei Wang 110 Dec 20, 2022
Yolov5 + Deep Sort with PyTorch

딥소트 수정중 Yolov5 + Deep Sort with PyTorch Introduction This repository contains a two-stage-tracker. The detections generated by YOLOv5, a family of obj

1 Nov 26, 2021
Mixed Transformer UNet for Medical Image Segmentation

MT-UNet Update 2021/11/19 Thank you for your interest in our work. We have uploaded the code of our MTUNet to help peers conduct further research on i

dotman 92 Dec 25, 2022
Supervised forecasting of sequential data in Python.

Supervised forecasting of sequential data in Python. Intro Supervised forecasting is the machine learning task of making predictions for sequential da

The Alan Turing Institute 54 Nov 15, 2022
A mini lib that implements several useful functions binding to PyTorch in C++.

Torch-gather A mini library that implements several useful functions binding to PyTorch in C++. What does gather do? Why do we need it? When dealing w

maxwellzh 8 Sep 07, 2022
[CVPR'22] COAP: Learning Compositional Occupancy of People

COAP: Compositional Articulated Occupancy of People Paper | Video | Project Page This is the official implementation of the CVPR 2022 paper COAP: Lear

Marko Mihajlovic 111 Dec 11, 2022
Numerical Methods with Python, Numpy and Matplotlib

Numerical Bric-a-Brac Collections of numerical techniques with Python and standard computational packages (Numpy, SciPy, Numba, Matplotlib ...). Diffe

Vincent Bonnet 10 Dec 20, 2021
PRIME: A Few Primitives Can Boost Robustness to Common Corruptions

PRIME: A Few Primitives Can Boost Robustness to Common Corruptions This is the official repository of PRIME, the data agumentation method introduced i

Apostolos Modas 34 Oct 30, 2022
State of the Art Neural Networks for Deep Learning

pyradox This python library helps you with implementing various state of the art neural networks in a totally customizable fashion using Tensorflow 2

Ritvik Rastogi 60 May 29, 2022
Degree-Quant: Quantization-Aware Training for Graph Neural Networks.

Degree-Quant This repo provides a clean re-implementation of the code associated with the paper Degree-Quant: Quantization-Aware Training for Graph Ne

35 Oct 07, 2022
Pytorch Implementation of Residual Vision Transformers(ResViT)

ResViT Official Pytorch Implementation of Residual Vision Transformers(ResViT) which is described in the following paper: Onat Dalmaz and Mahmut Yurt

ICON Lab 41 Dec 08, 2022
Distilling Motion Planner Augmented Policies into Visual Control Policies for Robot Manipulation (CoRL 2021)

Distilling Motion Planner Augmented Policies into Visual Control Policies for Robot Manipulation [Project website] [Paper] This project is a PyTorch i

Cognitive Learning for Vision and Robotics (CLVR) lab @ USC 6 Feb 28, 2022
PN-Net a neural field-based framework for depth estimation from single-view RGB images.

PN-Net We present a neural field-based framework for depth estimation from single-view RGB images. Rather than representing a 2D depth map as a single

1 Oct 02, 2021