TensorFlow implementation and notebooks for Implicit Maximum Likelihood Estimation

tf-imle

TensorFlow 2 and PyTorch implementations and Jupyter notebooks for Implicit Maximum Likelihood Estimation (I-MLE), proposed in the NeurIPS 2021 paper "Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions".

I-MLE is also available as a PyTorch library: https://github.com/uclnlp/torch-imle

Introduction

Implicit MLE (I-MLE) makes it possible to include discrete combinatorial optimization algorithms, such as Dijkstra's algorithm or integer linear programming (ILP) solvers, as well as complex discrete probability distributions in standard deep learning architectures. The figure below illustrates the setting I-MLE was developed for. An upstream standard neural network maps some input to the input parameters of a discrete combinatorial optimization algorithm or a discrete probability distribution, depicted as the black box. In the forward pass, the discrete component is executed and its discrete output is fed into a downstream neural network. With I-MLE, it is possible to estimate the gradients of a loss function with respect to the parameters of the discrete component; these gradients are used during backpropagation to update the parameters of the upstream neural network.

[Figure: Illustration of the problem addressed by I-MLE]

The core idea of I-MLE is that it defines an implicit maximum likelihood objective whose gradients are used to update upstream parameters of the model. Every instance of I-MLE requires two ingredients:

  1. A method to approximately sample from a complex and possibly intractable distribution. For this we use Perturb-and-MAP (aka the Gumbel-max trick) and propose a novel family of noise perturbations tailored to the problem at hand.
  2. A method to compute a surrogate empirical distribution: vanilla MLE minimizes the KL divergence between the empirical distribution and the current (model) distribution. Since, in our setting, we do not have access to such an empirical distribution, we have to design surrogate empirical distributions, which we term target distributions. Here we propose two families of target distributions that are widely applicable and work well in practice.
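The two ingredients can be illustrated without any deep learning framework. The following is a minimal pure-Python sketch under simplifying assumptions, not the library's implementation: the black box is a top-k (k-subset) MAP solver, the noise is plain Gumbel(0, 1) noise rather than the sum-of-gamma noise proposed for k-subsets, the upstream model is just a vector of free logits `theta`, and the downstream loss is a hypothetical squared error ||z - target||^2. `gumbel_max_frequencies` checks ingredient 1 for k = 1, where argmax(theta + Gumbel noise) is an exact sample from softmax(theta). `imle_train` uses ingredient 2 in the form also used by `IMLESubsetkLayer` below: a target distribution with parameters theta - lambda * dL/dz, sampled with the same noise, so the gradient estimate is the difference z - z' of two MAP states. All function names here are illustrative.

```python
import math
import random


def gumbel(rng):
    # Standard Gumbel(0, 1) sample by inverse CDF: -log(-log(U)), U ~ Uniform(0, 1)
    u = rng.random()
    while u == 0.0:  # guard against log(0)
        u = rng.random()
    return -math.log(-math.log(u))


def map_top_k(scores, k):
    # Black-box MAP solver for the k-subset problem: indicator of the k largest scores
    top = set(sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k])
    return [1.0 if i in top else 0.0 for i in range(len(scores))]


def perturb_and_map(theta, k, noise):
    # Perturb-and-MAP: call the MAP solver on the noise-perturbed parameters
    return map_top_k([t + e for t, e in zip(theta, noise)], k)


def softmax(theta):
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_max_frequencies(theta, n_samples, seed=0):
    # Ingredient 1 for k = 1 (Gumbel-max trick): argmax(theta + Gumbel noise)
    # is an exact sample from softmax(theta); return the empirical frequencies
    rng = random.Random(seed)
    counts = [0] * len(theta)
    for _ in range(n_samples):
        z = perturb_and_map(theta, 1, [gumbel(rng) for _ in theta])
        counts[z.index(1.0)] += 1
    return [c / n_samples for c in counts]


def imle_train(target, k, steps=200, lam=10.0, lr=0.1, seed=0):
    # Toy setup (hypothetical): the upstream "network" is a vector of free logits theta
    # and the downstream loss is L(z) = ||z - target||^2, so dL/dz = 2 * (z - target)
    rng = random.Random(seed)
    theta = [0.0] * len(target)
    for _ in range(steps):
        noise = [gumbel(rng) for _ in theta]
        z = perturb_and_map(theta, k, noise)                    # forward pass
        dl_dz = [2.0 * (zi - ti) for zi, ti in zip(z, target)]  # downstream gradient
        # Ingredient 2: target distribution with parameters theta - lam * dL/dz,
        # sampled with the SAME noise as the forward pass
        theta_prime = [t - lam * g for t, g in zip(theta, dl_dz)]
        z_prime = perturb_and_map(theta_prime, k, noise)
        grad = [a - b for a, b in zip(z, z_prime)]              # I-MLE gradient estimate
        theta = [t - lr * g for t, g in zip(theta, grad)]       # gradient descent step
    return theta


print(gumbel_max_frequencies([1.0, 0.0, -1.0], 20000))
print(softmax([1.0, 0.0, -1.0]))
target = [0.0, 1.0, 0.0, 0.0, 1.0, 0.0]
print(map_top_k(imle_train(target, k=2), 2))
```

Because the same noise is used for z and z', the estimate is zero whenever the sampled subset already equals the target, and only the logits whose indicator differs between z and z' are updated.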

Requirements:

TensorFlow 2 implementation:

  • tensorflow==2.3.0 or tensorflow-gpu==2.3.0
  • numpy==1.18.5
  • matplotlib==3.1.1
  • scikit-learn==0.24.1
  • tensorflow-probability==0.7.0

PyTorch implementation:

Example: I-MLE as a Layer

The following is an instance of I-MLE implemented as a Keras layer. In this class, the optimization problem is computing the k-subset (top-k) configuration, the target distribution is based on perturbation-based implicit differentiation, and the perturb-and-MAP noise perturbations are drawn from the sum-of-gamma distribution.

import tensorflow as tf


class IMLESubsetkLayer(tf.keras.layers.Layer):

    def __init__(self, k, _tau=10.0, _lambda=10.0):
        super(IMLESubsetkLayer, self).__init__()
        # number of 1s in a solution to the optimization problem (the subset size)
        self.k = k
        # the temperature at which we want to sample
        self._tau = _tau
        # the perturbation strength (here we use a target distribution based on
        # perturbation-based implicit differentiation)
        self._lambda = _lambda

    def sample_sum_of_gamma(self, shape):
        # t = 1, ..., 10: one Gamma(shape=1/k, scale=k/t) term per t
        t = tf.range(1.0, 11.0)
        # tf.random.gamma's `beta` is the inverse scale (rate), so scale k/t means beta = t/k;
        # the output has shape `shape + [10]`
        s = tf.random.gamma(shape, alpha=1.0 / self.k, beta=t / self.k)
        # now add the samples over the 10 Gamma terms
        s = tf.reduce_sum(s, axis=-1)
        # the log(m) term, with m = 10
        s = s - tf.math.log(10.0)
        # divide by k --> the sum of k such samples is (approximately) distributed as Gumbel(0, 1);
        # then scale by the temperature
        s = self._tau * (s / self.k)
        return s

    def map_top_k(self, perturbed_logits):
        # perturbed_logits is the input to the combinatorial optimization algorithm;
        # the next two lines compute the top-k (k-subset) MAP state for 2-D inputs of
        # shape [batch_size, n] and can be replaced by a custom black-box algorithm call
        threshold = tf.expand_dims(tf.nn.top_k(perturbed_logits, self.k, sorted=True)[0][:, -1], -1)
        return tf.cast(tf.greater_equal(perturbed_logits, threshold), tf.float32)

    def call(self, logits, training=None):

        @tf.custom_gradient
        def subset_k(logits):
            # sample the noise once; it is reused for perturb-and-MAP in the backward pass
            noise = self.sample_sum_of_gamma(tf.shape(logits))
            # sample discretely with perturb-and-MAP
            z_train = self.map_top_k(logits + noise)

            def custom_grad(dy):
                # perturb the logits (implicit differentiation) and reuse the same noise
                map_dy = self.map_top_k(logits - self._lambda * dy + noise)
                # the I-MLE gradient is the difference of the two MAP states
                return z_train - map_dy

            return z_train, custom_grad

        if training:
            # at training time we sample with perturb-and-MAP
            return subset_k(logits)
        # at test time we take the top-k of the unperturbed logits
        return self.map_top_k(logits)
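The sum-of-gamma construction in `sample_sum_of_gamma` can be checked numerically. The following pure-Python sketch is an illustration under stated assumptions (temperature tau = 1, the same m = 10 Gamma terms as the layer, helper names hypothetical). It draws noise of the form (tau / k) * (sum_{t=1..m} Gamma(1/k, k/t) - log m), with k/t as the scale, and adds k independent samples. For finite m, the mean and variance of that sum are H_m - log m and sum_{t=1..m} 1/t^2, which approach the Gumbel(0, 1) values (the Euler-Mascheroni constant and pi^2 / 6) as m grows. Note that Python's `random.gammavariate` takes the scale as its second argument, whereas `tf.random.gamma` takes the inverse scale (rate) as `beta`.

```python
import math
import random


def sum_of_gamma_noise(k, rng, m=10, tau=1.0):
    # One noise sample: (tau / k) * (sum_{t=1..m} Gamma(shape=1/k, scale=k/t) - log(m)).
    # random.gammavariate(alpha, beta) takes the SCALE as beta (mean = alpha * beta).
    s = sum(rng.gammavariate(1.0 / k, k / t) for t in range(1, m + 1))
    return (tau / k) * (s - math.log(m))


def sum_of_k_samples(k, n_samples, m=10, seed=0):
    # Add k independent noise samples (tau = 1); the result should be approximately
    # Gumbel(0, 1), with the approximation becoming exact as m -> infinity
    rng = random.Random(seed)
    return [sum(sum_of_gamma_noise(k, rng, m) for _ in range(k)) for _ in range(n_samples)]


def mean_and_variance(xs):
    mu = sum(xs) / len(xs)
    return mu, sum((x - mu) ** 2 for x in xs) / len(xs)


m = 10
xs = sum_of_k_samples(k=2, n_samples=20000, m=m)
mu, var = mean_and_variance(xs)
# For finite m: exact mean = H_m - log(m), exact variance = sum_{t=1..m} 1 / t^2
print(mu, sum(1.0 / t for t in range(1, m + 1)) - math.log(m))
print(var, sum(1.0 / t ** 2 for t in range(1, m + 1)))
# Gumbel(0, 1) reference: mean = Euler-Mascheroni constant, variance = pi^2 / 6
print(0.5772156649, math.pi ** 2 / 6)
```

With m = 10, the exact mean (about 0.626) and variance (about 1.550) are close to, but not equal to, the Gumbel(0, 1) values (about 0.577 and 1.645); increasing m trades extra sampling cost for a closer match.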

Reference

@inproceedings{niepert21imle,
  author    = {Mathias Niepert and
               Pasquale Minervini and
               Luca Franceschi},
  title     = {Implicit {MLE:} Backpropagating Through Discrete Exponential Family
               Distributions},
  booktitle = {NeurIPS},
  series    = {Proceedings of Machine Learning Research},
  publisher = {{PMLR}},
  year      = {2021}
}

Owner: NEC Laboratories Europe (research software developed at NEC Laboratories Europe)