Repository for "Improving evidential deep learning via multi-task learning," published in AAAI2022

Overview

Improving evidential deep learning via multi task learning

It is a repository of AAAI2022 paper, “Improving evidential deep learning via multi-task learning”, by Dongpin Oh and Bonggun Shin.

This repository contains the code to reproduce the Multi-task evidential neural network (MT-ENet), which uses the Lipschitz MSE loss function as the additional loss function of the evidential regression network (ENet). The Lipschitz MSE loss function can improve the accuracy of the ENet while preserving its uncertainty estimation capability, by avoiding gradient conflict with the NLL loss function—the original loss function of the ENet.

drawing

Setup

Please refer to "requirements.txt" for requring packages of this repo.

pip install -r requirements.txt

Training the ENet with the Lipschitz-MSE loss: example

from mtevi.mtevi import EvidentialMarginalLikelihood, EvidenceRegularizer, modified_mse
...
net = EvidentialNetwork() ## Evidential regression network
nll_loss = EvidentialMarginalLikelihood() ## original loss, NLL loss
reg = EvidenceRegularizer() ## evidential regularizer
mmse_loss = modified_mse ## lipschitz MSE loss
...
for inputs, labels in dataloader:
	gamma, nu, alpha, beta = net(inputs)
	loss = nll_loss(gamma, nu, alpha, beta, labels)
	loss += reg(gamma, nu, alpha, beta, labels)
	loss += mmse_loss(gamma, nu, alpha, beta, labels)
	loss.backward()	

Quick start

  • Synthetic data experiment.
python synthetic_exp.py
  • UCI regression benchmark experiments.
python uci_exp_norm -p energy
  • Drug target affinity (DTA) regression task on KIBA and Davis datasets.
python train_evinet.py -o test --type davis -f 0 --evi # ENet
python train_evinet.py -o test --type davis -f 0  # MT-ENet
  • Gradient conflict experiment on the DTA benchmarks
python check_conflict.py --type davis -f 0 # Conflict between the Lipschitz MSE (proposed) and NLL loss. 
python check_conflict.py --type davis -f 0 --abl # Conflict between the simple MSE loss and NLL loss.

Characteristic of the Lipschitz MSE loss

drawing

  • The Lipschitz MSE loss function can support training the ENet to more accurately predicts target values.
  • It regularizes its gradient to prevent gradient conflict with the NLL loss--the original loss function--if the NLL loss increases predictive uncertainty of the ENet.
  • Please check our paper for details.
Owner
deargen
deargen
Heart Arrhythmia Classification

This program takes and input of an ECG in European Data Format (EDF) and outputs the classification for heartbeats into normal vs different types of arrhythmia . It uses a deep learning model for cla

4 Nov 02, 2022
ktrain is a Python library that makes deep learning and AI more accessible and easier to apply

Overview | Tutorials | Examples | Installation | FAQ | How to Cite Welcome to ktrain News and Announcements 2020-11-08: ktrain v0.25.x is released and

Arun S. Maiya 1.1k Jan 02, 2023
Source code of CIKM2021 Long Paper "PSSL: Self-supervised Learning for Personalized Search with Contrastive Sampling".

PSSL Source code of CIKM2021 Long Paper "PSSL: Self-supervised Learning for Personalized Search with Contrastive Sampling". It consists of the pre-tra

2 Dec 21, 2021
GyroSPD: Vector-valued Distance and Gyrocalculus on the Space of Symmetric Positive Definite Matrices

GyroSPD Code for the paper "Vector-valued Distance and Gyrocalculus on the Space of Symmetric Positive Definite Matrices" accepted at NeurIPS 2021. Re

Federico Lopez 12 Dec 12, 2022
Official PyTorch implementation of "Edge Rewiring Goes Neural: Boosting Network Resilience via Policy Gradient".

Edge Rewiring Goes Neural: Boosting Network Resilience via Policy Gradient This repository is the official PyTorch implementation of "Edge Rewiring Go

Shanchao Yang 4 Dec 12, 2022
Conversion between units used in magnetism

convmag Conversion between various units used in magnetism The conversions between base units available are: T - G : 1e4

0 Jul 15, 2021
Simulation-based inference for the Galactic Center Excess

Simulation-based inference for the Galactic Center Excess Siddharth Mishra-Sharma and Kyle Cranmer Abstract The nature of the Fermi gamma-ray Galactic

Siddharth Mishra-Sharma 3 Jan 21, 2022
Structure Information is the Key: Self-Attention RoI Feature Extractor in 3D Object Detection

Structure Information is the Key: Self-Attention RoI Feature Extractor in 3D Object Detection abstract:Unlike 2D object detection where all RoI featur

DK. Zhang 2 Oct 07, 2022
Tello Drone Trajectory Tracking

With this library you can track the trajectory of your tello drone or swarm of drones in real time.

Kamran Asgarov 2 Oct 12, 2022
InDuDoNet+: A Model-Driven Interpretable Dual Domain Network for Metal Artifact Reduction in CT Images

InDuDoNet+: A Model-Driven Interpretable Dual Domain Network for Metal Artifact Reduction in CT Images Hong Wang, Yuexiang Li, Haimiao Zhang, Deyu Men

Hong Wang 4 Dec 27, 2022
A python bot to move your mouse every few seconds to appear active on Skype, Teams or Zoom as you go AFK. 🐭 🤖

PyMouseBot If you're from GT and annoyed with SGVPN idle timeouts while working on development laptop, You might find this useful. A python cli bot to

Oaker Min 6 Oct 24, 2022
(CVPR 2022) Energy-based Latent Aligner for Incremental Learning

Energy-based Latent Aligner for Incremental Learning Accepted to CVPR 2022 We illustrate an Incremental Learning model trained on a continuum of tasks

Joseph K J 37 Jan 03, 2023
Simultaneous Demand Prediction and Planning

Simultaneous Demand Prediction and Planning Dependencies Python packages: Pytorch, scikit-learn, Pandas, Numpy, PyYAML Data POI: data/poi Road network

Yizong Wang 1 Sep 01, 2022
Numerical Methods with Python, Numpy and Matplotlib

Numerical Bric-a-Brac Collections of numerical techniques with Python and standard computational packages (Numpy, SciPy, Numba, Matplotlib ...). Diffe

Vincent Bonnet 10 Dec 20, 2021
Implementation of popular SOTA self-supervised learning algorithms as Fastai Callbacks.

Self Supervised Learning with Fastai Implementation of popular SOTA self-supervised learning algorithms as Fastai Callbacks. Install pip install self-

Kerem Turgutlu 276 Dec 23, 2022
Tensors and Dynamic neural networks in Python with strong GPU acceleration

PyTorch is a Python package that provides two high-level features: Tensor computation (like NumPy) with strong GPU acceleration Deep neural networks b

61.4k Jan 04, 2023
This code is part of the reproducibility package for the SANER 2022 paper "Generating Clarifying Questions for Query Refinement in Source Code Search".

Clarifying Questions for Query Refinement in Source Code Search This code is part of the reproducibility package for the SANER 2022 paper "Generating

Zachary Eberhart 0 Dec 04, 2021
Fast Neural Style for Image Style Transform by Pytorch

FastNeuralStyle by Pytorch Fast Neural Style for Image Style Transform by Pytorch This is famous Fast Neural Style of Paper Perceptual Losses for Real

Bengxy 81 Sep 03, 2022
Source code of generalized shuffled linear regression

Generalized-Shuffled-Linear-Regression Code for the ICCV 2021 paper: Generalized Shuffled Linear Regression. Authors: Feiran Li, Kent Fujiwara, Fumio

FEI 7 Oct 26, 2022
DUE: End-to-End Document Understanding Benchmark

This is the repository that provide tools to download data, reproduce the baseline results and evaluation. What can you achieve with this guide Based

21 Dec 29, 2022