PyTorch implementation of Algorithm 1 of "On the Anatomy of MCMC-Based Maximum Likelihood Learning of Energy-Based Models"

Overview

Code for On the Anatomy of MCMC-Based Maximum Likelihood Learning of Energy-Based Models

This repository reproduces the main results from our paper:

On the Anatomy of MCMC-Based Maximum Likelihood Learning of Energy-Based Models
Erik Nijkamp*, Mitch Hill*, Tian Han, Song-Chun Zhu, and Ying Nian Wu (*equal contributions)
https://arxiv.org/abs/1903.12370
AAAI 2020.

The files train_data.py and train_toy.py are PyTorch implementations of Algorithm 1 for image datasets and toy 2D distributions, respectively. During training, both files measure and plot the diagnostic values $d_{s_t}$ and $r_t$ described in Section 3 of the paper. The file eval.py samples from a saved checkpoint using either unadjusted Langevin dynamics or Metropolis-Hastings adjusted Langevin dynamics. We also provide an appendix, ebm-anatomy-appendix.pdf, that contains further practical considerations and empirical observations.
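
As a rough illustration of the loop that Algorithm 1 describes, here is a toy sketch in plain Python rather than PyTorch, so it runs without dependencies. It assumes a hypothetical one-parameter energy U(x; θ) = θx²/2 in place of a neural network, noise-initialized short-run chains, and the Langevin update x ← x − (ε²/2)∂U/∂x + εz. All names and hyperparameters are illustrative and are not taken from train_data.py or train_toy.py. Each iteration draws a data batch, runs a few Langevin steps from noise, and takes a gradient step on mean U(data) − mean U(samples).

```python
import random

# Hypothetical toy energy U(x; theta) = theta * x**2 / 2, standing in for a ConvNet.
# Its normalized density exp(-U) is the Gaussian N(0, 1 / theta).

def dU_dx(x, theta):
    return theta * x

def dU_dtheta(x):
    return 0.5 * x * x

def short_run_langevin(xs, theta, n_steps, eps, rng):
    """Unadjusted Langevin: x <- x - (eps**2 / 2) * dU/dx + eps * z, with z ~ N(0, 1)."""
    h = 0.5 * eps**2
    for _ in range(n_steps):
        xs = [x - h * dU_dx(x, theta) + eps * rng.gauss(0.0, 1.0) for x in xs]
    return xs

def train(data_std=2.0, theta=1.0, n_iters=200, batch=100,
          n_steps=30, eps=0.5, lr=0.01, seed=0):
    rng = random.Random(seed)
    for _ in range(n_iters):
        x_data = [rng.gauss(0.0, data_std) for _ in range(batch)]
        # Noise initialization; a persistent variant would reuse earlier chain states.
        x_init = [rng.gauss(0.0, 1.0) for _ in range(batch)]
        x_samp = short_run_langevin(x_init, theta, n_steps, eps, rng)
        # Gradient w.r.t. theta of mean U(data) - mean U(samples), samples held fixed.
        grad = (sum(map(dU_dtheta, x_data)) - sum(map(dU_dtheta, x_samp))) / batch
        theta -= lr * grad
    return theta

theta_hat = train()
print(theta_hat)
```

With data drawn from N(0, 4), θ moves from 1.0 toward roughly 1/4. Because 30 noise-initialized steps do not fully mix, it can settle somewhat below that value, so the learned θ describes the short-run sampler as much as the data.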

Config Files

The folder config_locker has several JSON files that reproduce different convergent and non-convergent learning outcomes for image datasets and toy distributions. Config files for evaluation of pre-trained networks are also included. The files data_config.json, toy_config.json, and eval_config.json fully explain the parameters for train_data.py, train_toy.py, and eval.py respectively.

Executable Files

To run an experiment with train_data.py, train_toy.py, or eval.py, specify a name for the experiment folder and the location of the JSON config file in the script before execution:

# directory for experiment results
EXP_DIR = './name_of/new_folder/'
# json file with experiment config
CONFIG_FILE = './path_to/config.json'
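
The README does not show how the scripts consume these two variables. As a hypothetical sketch only (the helper name setup_experiment and its copy-the-config behavior are assumptions, not the repository's code), a script might create the results folder and load the JSON config like this:

```python
import json
import os
import shutil

def setup_experiment(exp_dir, config_file):
    """Hypothetical helper: create exp_dir, load the JSON config, and keep a copy of it."""
    os.makedirs(exp_dir, exist_ok=True)
    with open(config_file) as f:
        config = json.load(f)
    # Storing the config next to the results records which settings produced them.
    shutil.copy(config_file, os.path.join(exp_dir, os.path.basename(config_file)))
    return config

# Example call with the placeholder paths used in this README:
# config = setup_experiment('./name_of/new_folder/', './path_to/config.json')
```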

Other Files

Network structures are located in nets.py. A download function for Oxford Flowers 102 data, plotting functions, and a toy dataset class can be found in utils.py.

Diagnostics

Energy Difference and Langevin Gradient Magnitude: Both image and toy experiments plot the energy difference $d_{s_t}$ and the Langevin gradient magnitude $r_t$ (see Section 3 of the paper) over the course of training, along with correlation plots as in Figure 4 of the paper (showing the ACF rather than the PACF).
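
To make the two diagnostics concrete, here is a plain-Python sketch. It assumes $d_{s_t}$ is the mean energy of a data batch minus the mean energy of the corresponding short-run samples, and $r_t$ is the average L2 norm of $\partial U / \partial x$ over Langevin steps and chains; Section 3 of the paper gives the exact definitions and normalization. The acf helper computes the sample autocorrelation function behind an ACF plot. All function names are hypothetical.

```python
import math

def energy_difference(energies_data, energies_samples):
    """d_{s_t} (assumed form): mean energy of data minus mean energy of short-run samples."""
    return (sum(energies_data) / len(energies_data)
            - sum(energies_samples) / len(energies_samples))

def grad_magnitude(grads_per_step):
    """r_t (assumed form): average L2 norm of dU/dx.

    grads_per_step[step][chain] is the gradient vector of one chain at one Langevin step.
    """
    norms = [math.sqrt(sum(g * g for g in grad))
             for step in grads_per_step for grad in step]
    return sum(norms) / len(norms)

def acf(series, max_lag):
    """Sample autocorrelation of a 1D series (e.g. d_{s_t} over training) for lags 0..max_lag."""
    n = len(series)
    mean = sum(series) / n
    centered = [v - mean for v in series]
    var = sum(c * c for c in centered)
    return [sum(centered[i] * centered[i + k] for i in range(n - k)) / var
            for k in range(max_lag + 1)]
```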

Landscape Plots: Toy experiments will plot the density and log-density (negative energy) for ground-truth, learned energy, and short-run models. Kernel density estimation is used to obtain the short-run density.
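
Kernel density estimation places a kernel on each short-run sample and averages the results. Below is a minimal sketch with an isotropic Gaussian kernel in 2D; the bandwidth is a hypothetical free parameter here, and the repository's own choice of kernel and bandwidth may differ. Taking the log of the returned density gives the corresponding log-density for plotting.

```python
import math

def gaussian_kde_2d(samples, bandwidth):
    """Return a density function built from 2D samples with an isotropic Gaussian kernel."""
    n = len(samples)
    # Normalizer of a 2D Gaussian with covariance bandwidth**2 * I, averaged over n kernels.
    norm = 1.0 / (n * 2.0 * math.pi * bandwidth**2)

    def density(x, y):
        total = 0.0
        for sx, sy in samples:
            d2 = (x - sx) ** 2 + (y - sy) ** 2
            total += math.exp(-d2 / (2.0 * bandwidth**2))
        return norm * total

    return density
```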

Short-Run MCMC Samples: Image data experiments will periodically visualize the short-run MCMC samples. A batch of persistent MCMC samples will also be saved for implementations that use persistent initialization for short-run sampling.
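
Persistent initialization keeps a bank of chain states across training iterations: each short-run chain starts from a stored state instead of fresh noise, and its end state is written back. A hypothetical minimal bank (the class and method names are illustrative, not from the repository) could look like this:

```python
import random

class PersistentBank:
    """Hypothetical store of persistent chain states for short-run sampling."""

    def __init__(self, size, init_fn, rng):
        self.states = [init_fn() for _ in range(size)]
        self.rng = rng

    def sample(self, batch):
        """Pick `batch` distinct stored states to use as chain starting points."""
        idx = self.rng.sample(range(len(self.states)), batch)
        return idx, [self.states[i] for i in idx]

    def update(self, idx, new_states):
        """Write the states reached after the Langevin steps back into the bank."""
        for i, s in zip(idx, new_states):
            self.states[i] = s
```

In a training loop, sample() would supply the chain starts and update() would store the chain states after the short-run Langevin steps, so later iterations continue from where earlier chains stopped.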

Long-Run MCMC Samples: Image data experiments have the option to obtain long-run MCMC samples during training. When log_longrun is set to true in a data config file, the training implementation will generate long-run MCMC samples at a frequency determined by log_longrun_freq. The appearance of long-run MCMC samples indicates whether the energy function assigns probability mass in realistic regions of the image space.
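
The two config keys can be read as an on/off switch plus a period. The sketch below assumes long-run samples are logged every log_longrun_freq training iterations; the key names come from this README, while the values and the helper longrun_iterations are hypothetical.

```python
import json

# Hypothetical fragment of a data config; only the key names come from the README.
CONFIG_TEXT = '{"log_longrun": true, "log_longrun_freq": 1000}'

def longrun_iterations(config, n_iters):
    """Iterations at which long-run samples would be logged, assuming a fixed period."""
    if not config.get("log_longrun", False):
        return []
    freq = config["log_longrun_freq"]
    return [i for i in range(1, n_iters + 1) if i % freq == 0]

config = json.loads(CONFIG_TEXT)
```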

Pre-trained Networks

A convergent pre-trained network and non-convergent pre-trained network for the Oxford Flowers 102 dataset are available in the Releases section of the repository. The config files eval_flowers_convergent.json and eval_flowers_convergent_mh.json are set up to evaluate flowers_convergent_net.pth. The config file eval_flowers_nonconvergent.json is set up to evaluate flowers_nonconvergent_net.pth.
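
Since eval.py also supports Metropolis-Hastings adjusted Langevin dynamics, here is a 1D plain-Python sketch of that technique (often called MALA). It proposes a Langevin move and accepts it with the Metropolis-Hastings ratio, which corrects the discretization error of the unadjusted sampler. The function and its arguments are hypothetical and do not reflect eval.py's interface.

```python
import math
import random

def mala(x0, energy, grad, eps, n_steps, rng):
    """Metropolis-adjusted Langevin on a 1D energy U; returns (final x, acceptance rate)."""
    h = 0.5 * eps**2

    def log_q(to, frm):
        # Log proposal density of moving frm -> to, up to a constant:
        # to ~ N(frm - h * U'(frm), eps**2).
        mean = frm - h * grad(frm)
        return -((to - mean) ** 2) / (2.0 * eps**2)

    x, accepted = x0, 0
    for _ in range(n_steps):
        prop = x - h * grad(x) + eps * rng.gauss(0.0, 1.0)
        # log of [pi(prop) q(x | prop)] / [pi(x) q(prop | x)], with pi proportional to exp(-U).
        log_alpha = energy(x) - energy(prop) + log_q(x, prop) - log_q(prop, x)
        if rng.random() < math.exp(min(0.0, log_alpha)):
            x, accepted = prop, accepted + 1
    return x, accepted / n_steps
```

For U(x) = x²/2 the target is a standard normal, so final states from many independent chains should have a second moment close to 1.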

Contact

Please contact Mitch Hill or Erik Nijkamp with any questions.


Comments
  • Step size in Langevin Dynamics

    Hi, in your code, when you do the Langevin dynamics, you run x_s_t.data += - f_prime + config['epsilon'] * t.randn_like(x_s_t). However, does this mean that the step size for the gradient f_prime is 1? Should we run x_s_t.data += - 0.5*config['epsilon']**2*f_prime + config['epsilon'] * t.randn_like(x_s_t) instead?

    opened by XavierXiao
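
The two updates in this comment are related by a rescaling. For a target density proportional to exp(−U(x)/T), a Langevin step with step size η is x ← x − η∇U(x) + sqrt(2ηT)z. Matching the noise term εz gives T = ε²/(2η): η = ε²/2 yields T = 1 (the second update), while η = 1 yields T = ε²/2, so the first update behaves like Langevin sampling of a tempered density with step size 1, ignoring discretization error, which is not small at that step size. The sketch below only encodes this relationship; it is not a statement of which form the repository intends, and the function names are hypothetical.

```python
import math

def langevin_step(x, grad_u, eps, z, scaled=True):
    """One Langevin update.

    scaled=True:  x - (eps**2 / 2) * grad_u + eps * z  (standard discretization, target exp(-U))
    scaled=False: x - grad_u + eps * z                   (form quoted in the comment)
    """
    step = 0.5 * eps**2 if scaled else 1.0
    return x - step * grad_u + eps * z

def implied_temperature(step, eps):
    """For x - step * grad U + eps * z, the matching target is exp(-U / T), T = eps**2 / (2 * step)."""
    return eps**2 / (2.0 * step)
```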
Releases: v1.0

Owner: Mitch Hill, Assistant Professor of Statistics and Data Science at UCF