MADE (Masked Autoencoder Density Estimation) implementation in PyTorch

Overview

pytorch-made

This code is an implementation of "Masked AutoEncoder for Density Estimation" by Germain et al., 2015. The core idea is that you can turn an auto-encoder into an autoregressive density model just by appropriately masking the connections in the MLP, ordering the input dimensions in some way and making sure that all outputs only depend on inputs earlier in the list. Like other autoregressive models (char-rnn, pixel cnns, etc), evaluating the likelihood is very cheap (a single forward pass), but sampling is linear in the number of dimensions.

figure 1

The authors of the paper also published code here, but it's a bit wordy, sprawling and in Theano. Hence my own shot at it with only ~150 lines of code and PyTorch <3.

examples

First we download the binarized mnist dataset. Then we can reproduce the first point on the plot of Figure 2 by training a 1-layer MLP of 500 units with only a single mask, and using a single fixed (but random) ordering as so:

python run.py --data-path binarized_mnist.npz -q 500

which converges at binary cross entropy loss of 94.5, as shown in the paper. We can then simultaneously train a larger model ensemble (with weight sharing in the one MLP) and average over all of the models at test time. For instance, we can use 10 orderings (-n 10) and also average over the 10 at inference time (-s 10):

python run.py --data-path binarized_mnist.npz -q 500 -n 10 -s 10

which gives a much better test loss of 79.3, but at the cost of multiple forward passes. I was not able to reproduce single-forward-pass gains that the paper alludes to when training with multiple masks, might be doing something wrong.

usage

The core class is MADE, found in made.py. It inherits from PyTorch nn.Module so you can "slot it into" larger architectures quite easily. To instantiate MADE on 1D inputs of MNIST digits for example (which have 28*28 pixels), using one hidden layer of 500 neurons, and using a single but random ordering we would do:

model = MADE(28*28, [500], 28*28, num_masks=1, natural_ordering=False)

The reason we plug the size of the output (3rd argument) into MADE is that one might want to use relatively complicated output distributions, for example a gaussian distribution would normally be parameterized by a mean and a standard deviation for each dimension, or you could bin the output range into buckets and output logprobs for a softmax, or mixture parameters, etc. In the simplest example in this code we use binary predictions, where are only parameterized by one number, hence the number of the input dimensions happens to equal the number of outputs.

License

MIT

Owner
Andrej
I like to train Deep Neural Nets on large datasets.
Andrej
Implementation for paper "Towards the Generalization of Contrastive Self-Supervised Learning"

Contrastive Self-Supervised Learning on CIFAR-10 Paper "Towards the Generalization of Contrastive Self-Supervised Learning", Weiran Huang, Mingyang Yi

Weiran Huang 13 Nov 30, 2022
KIDA: Knowledge Inheritance in Data Aggregation

KIDA: Knowledge Inheritance in Data Aggregation This project releases our 1st place solution on NeurIPS2021 ML4CO Dual Task. Slide and model weights a

24 Sep 08, 2022
Official PyTorch implementation of "Adversarial Reciprocal Points Learning for Open Set Recognition"

Adversarial Reciprocal Points Learning for Open Set Recognition Official PyTorch implementation of "Adversarial Reciprocal Points Learning for Open Se

Guangyao Chen 78 Dec 28, 2022
GPT-Code-Clippy (GPT-CC) is an open source version of GitHub Copilot

GPT-Code-Clippy (GPT-CC) is an open source version of GitHub Copilot, a language model -- based on GPT-3, called GPT-Codex -- that is fine-tuned on publicly available code from GitHub.

2.3k Jan 09, 2023
MASS (Mueen's Algorithm for Similarity Search) - a python 2 and 3 compatible library used for searching time series sub-sequences under z-normalized Euclidean distance for similarity.

Introduction MASS allows you to search a time series for a subquery resulting in an array of distances. These array of distances enable you to identif

Matrix Profile Foundation 79 Dec 31, 2022
Implementation of EMNLP 2017 Paper "Natural Language Does Not Emerge 'Naturally' in Multi-Agent Dialog" using PyTorch and ParlAI

Language Emergence in Multi Agent Dialog Code for the Paper Natural Language Does Not Emerge 'Naturally' in Multi-Agent Dialog Satwik Kottur, José M.

Karan Desai 105 Nov 25, 2022
Author's PyTorch implementation of TD3 for OpenAI gym tasks

Addressing Function Approximation Error in Actor-Critic Methods PyTorch implementation of Twin Delayed Deep Deterministic Policy Gradients (TD3). If y

Scott Fujimoto 1.3k Dec 25, 2022
PyoMyo - Python Opensource Myo library

PyoMyo Python module for the Thalmic Labs Myo armband. Cross platform and multithreaded and works without the Myo SDK. pip install pyomyo Documentati

PerlinWarp 81 Jan 08, 2023
MVFNet: Multi-View Fusion Network for Efficient Video Recognition (AAAI 2021)

MVFNet: Multi-View Fusion Network for Efficient Video Recognition (AAAI 2021) Overview We release the code of the MVFNet (Multi-View Fusion Network).

2 Jan 29, 2022
CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation

CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation [arxiv] This is the official repository for CDTrans: Cross-domain Transformer for

238 Dec 22, 2022
This is a repository for a semantic segmentation inference API using the OpenVINO toolkit

BMW-IntelOpenVINO-Segmentation-Inference-API This is a repository for a semantic segmentation inference API using the OpenVINO toolkit. It's supported

BMW TechOffice MUNICH 34 Nov 24, 2022
DR-GAN: Automatic Radial Distortion Rectification Using Conditional GAN in Real-Time

DR-GAN: Automatic Radial Distortion Rectification Using Conditional GAN in Real-Time Introduction This is official implementation for DR-GAN (IEEE TCS

Kang Liao 18 Dec 23, 2022
PyTorch implementation of Progressive Growing of GANs for Improved Quality, Stability, and Variation.

PyTorch implementation of Progressive Growing of GANs for Improved Quality, Stability, and Variation. Warning: the master branch might collapse. To ob

559 Dec 14, 2022
Machine Learning Model deployment for Container (TensorFlow Serving)

try_tf_serving ├───dataset │ ├───testing │ │ ├───paper │ │ ├───rock │ │ └───scissors │ └───training │ ├───paper │ ├───rock

Azhar Rizki Zulma 5 Jan 07, 2022
🔥🔥High-Performance Face Recognition Library on PaddlePaddle & PyTorch🔥🔥

face.evoLVe: High-Performance Face Recognition Library based on PaddlePaddle & PyTorch Evolve to be more comprehensive, effective and efficient for fa

Zhao Jian 3.1k Jan 02, 2023
Change is Everywhere: Single-Temporal Supervised Object Change Detection in Remote Sensing Imagery (ICCV 2021)

Change is Everywhere Single-Temporal Supervised Object Change Detection in Remote Sensing Imagery by Zhuo Zheng, Ailong Ma, Liangpei Zhang and Yanfei

Zhuo Zheng 125 Dec 13, 2022
The Official Repository for "Generalized OOD Detection: A Survey"

Generalized Out-of-Distribution Detection: A Survey 1. Overview This repository is with our survey paper: Title: Generalized Out-of-Distribution Detec

Jingkang Yang 338 Jan 03, 2023
PyTorch Implementation of Unsupervised Depth Completion with Calibrated Backprojection Layers (ORAL, ICCV 2021)

Unsupervised Depth Completion with Calibrated Backprojection Layers PyTorch implementation of Unsupervised Depth Completion with Calibrated Backprojec

80 Dec 13, 2022
Official repository for ABC-GAN

ABC-GAN The work represented in this repository is the result of a 14 week semesterthesis on photo-realistic image generation using generative adversa

IgorSusmelj 10 Jun 23, 2022