Mini-hmc-jax - A simple implementation of Hamiltonian Monte Carlo in JAX

Overview

mini-hmc-jax

This is a simple implementation of Hamiltonian Monte Carlo in JAX that is vectorized and supports pytree parameters (i.e. tree-like structures).

Here's a minimal example to sample from a distribution:

import jax
import jax.numpy as jnp
from hmc import hmc_sampler

# define target distribution
def target_log_pdf(params):
    return jax.scipy.stats.t.logpdf(params, df=1).sum()

# run HMC
params_init = jnp.zeros(10)
key = jax.random.PRNGKey(0)
chain = hmc_sampler(params_init, target_log_pdf, n_steps=10, n_leapfrog_steps=100, step_size=0.1, key=key)

Inspired by the paper What Are Bayesian Neural Network Posteriors Really Like?

@article{izmailov2021bayesian,
  title={What Are Bayesian Neural Network Posteriors Really Like?},
  author={Izmailov, Pavel and Vikram, Sharad and Hoffman, Matthew D and Wilson, Andrew Gordon},
  journal={arXiv preprint arXiv:2104.14421},
  year={2021}
}
Owner
Martin Marek
Martin Marek
University of Rochester 2021 Summer REU focusing on music sentiment transfer using CycleGAN

Music-Sentiment-Transfer University of Rochester 2021 Summer REU focusing on music sentiment transfer using CycleGAN Poster: Music Sentiment Transfer

Miles Sigel 2 Jan 24, 2022
Official page of Patchwork (RA-L'21 w/ IROS'21)

Patchwork Official page of "Patchwork: Concentric Zone-based Region-wise Ground Segmentation with Ground Likelihood Estimation Using a 3D LiDAR Sensor

Hyungtae Lim 254 Jan 05, 2023
ECAENet (TensorFlow and Keras)

ECAENet: EfficientNet with Efficient Channel Attention for Plant Species Recognition (SCI:Q3) (Journal of Intelligent & Fuzzy Systems)

4 Dec 22, 2022
This project aims to explore the deployment of Swin-Transformer based on TensorRT, including the test results of FP16 and INT8.

Swin Transformer This project aims to explore the deployment of SwinTransformer based on TensorRT, including the test results of FP16 and INT8. Introd

maggiez 87 Dec 21, 2022
Codes for our paper "SentiLARE: Sentiment-Aware Language Representation Learning with Linguistic Knowledge" (EMNLP 2020)

SentiLARE: Sentiment-Aware Language Representation Learning with Linguistic Knowledge Introduction SentiLARE is a sentiment-aware pre-trained language

74 Dec 30, 2022
A curated list of awesome open source libraries to deploy, monitor, version and scale your machine learning

Awesome production machine learning This repository contains a curated list of awesome open source libraries that will help you deploy, monitor, versi

The Institute for Ethical Machine Learning 12.9k Jan 04, 2023
SmartSim Infrastructure Library.

Home Install Documentation Slack Invite Cray Labs SmartSim SmartSim makes it easier to use common Machine Learning (ML) libraries like PyTorch and Ten

Cray Labs 139 Jan 01, 2023
Automatic Image Background Subtraction

Automatic Image Background Subtraction This repo contains set of scripts for automatic one-shot image background subtraction task using the following

Oleg Sémery 6 Dec 05, 2022
PArallel Distributed Deep LEarning: Machine Learning Framework from Industrial Practice (『飞桨』核心框架,深度学习&机器学习高性能单机、分布式训练和跨平台部署)

English | 简体中文 Welcome to the PaddlePaddle GitHub. PaddlePaddle, as the only independent R&D deep learning platform in China, has been officially open

19.4k Jan 04, 2023
Neural network pruning for finding a sparse computational model for controlling a biological motor task.

MothPruning Scientific Overview Originally inspired by biological nervous systems, deep neural networks (DNNs) are powerful computational tools for mo

Olivia Thomas 0 Dec 14, 2022
Official PyTorch implementation of "Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets" (ICLR 2021)

Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets This is the official PyTorch implementation for the paper Rapid Neural A

48 Dec 26, 2022
Official PyTorch implementation of paper: Standardized Max Logits: A Simple yet Effective Approach for Identifying Unexpected Road Obstacles in Urban-Scene Segmentation (ICCV 2021 Oral Presentation)

SML (ICCV 2021, Oral) : Official Pytorch Implementation This repository provides the official PyTorch implementation of the following paper: Standardi

SangHun 61 Dec 27, 2022
A python library for highly configurable transformers - easing model architecture search and experimentation.

A python library for highly configurable transformers - easing model architecture search and experimentation.

Anthony Fuller 51 Nov 20, 2022
A Rao-Blackwellized Particle Filter for 6D Object Pose Tracking

PoseRBPF: A Rao-Blackwellized Particle Filter for 6D Object Pose Tracking PoseRBPF Paper Self-supervision Paper Pose Estimation Video Robot Manipulati

NVIDIA Research Projects 107 Dec 25, 2022
Multimodal Co-Attention Transformer (MCAT) for Survival Prediction in Gigapixel Whole Slide Images

Multimodal Co-Attention Transformer (MCAT) for Survival Prediction in Gigapixel Whole Slide Images [ICCV 2021] © Mahmood Lab - This code is made avail

Mahmood Lab @ Harvard/BWH 63 Dec 01, 2022
2D Human Pose estimation using transformers. Implementation in Pytorch

PE-former: Pose Estimation Transformer Vision transformer architectures perform very well for image classification tasks. Efforts to solve more challe

Panteleris Paschalis 23 Oct 17, 2022
A Temporal Extension Library for PyTorch Geometric

Documentation | External Resources | Datasets PyTorch Geometric Temporal is a temporal (dynamic) extension library for PyTorch Geometric. The library

Benedek Rozemberczki 1.9k Jan 07, 2023
Codes for AAAI22 paper "Learning to Solve Travelling Salesman Problem with Hardness-Adaptive Curriculum"

Paper For more details, please see our paper Learning to Solve Travelling Salesman Problem with Hardness-Adaptive Curriculum which has been accepted a

14 Sep 30, 2022
Train an RL agent to execute natural language instructions in a 3D Environment (PyTorch)

Gated-Attention Architectures for Task-Oriented Language Grounding This is a PyTorch implementation of the AAAI-18 paper: Gated-Attention Architecture

Devendra Chaplot 234 Nov 05, 2022
Local Attention - Flax module for Jax

Local Attention - Flax Autoregressive Local Attention - Flax module for Jax Install $ pip install local-attention-flax Usage from jax import random fr

Phil Wang 16 Jun 16, 2022