Explainability for Vision Transformers (in PyTorch)

Overview

This repository implements methods for explainability in Vision Transformers.

See also https://jacobgil.github.io/deeplearning/vision-transformer-explainability

Currently implemented:

  • Attention Rollout.

  • Gradient Attention Rollout for class specific explainability. This is our attempt to further build upon and improve Attention Rollout.

  • Attention Flow (TBD, work in progress).

Includes some tweaks and tricks to get it working:

  • Different attention head fusion methods.
  • Removing the lowest attentions.

Usage

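  • From code:
import torch
from vit_grad_rollout import VITAttentionGradRollout

# Load the pretrained DeiT 'Tiny' model from torch hub and switch to inference mode.
model = torch.hub.load('facebookresearch/deit:main',
                       'deit_tiny_patch16_224', pretrained=True)
model.eval()
grad_rollout = VITAttentionGradRollout(model, discard_ratio=0.9, head_fusion='max')
# input_tensor is a preprocessed image tensor, e.g. of shape (1, 3, 224, 224).
mask = grad_rollout(input_tensor, category_index=243)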
  • From the command line:
python vit_explain.py --image_path <image path> --head_fusion <mean, min or max> --discard_ratio <number between 0 and 1> --category_index <category index>

If category_index isn't specified, Attention Rollout is used; otherwise, Gradient Attention Rollout is used.

Note that by default this uses the 'Tiny' model from "Training data-efficient image transformers & distillation through attention" (DeiT), hosted on torch hub.

Where did the Transformer pay attention in this image?

Image | Vanilla Attention Rollout | With discard_ratio + max fusion

Gradient Attention Rollout for class specific explainability

The attention that flows through the transformer passes along information belonging to different classes. Attention Rollout lets us see which locations the network paid attention to, but it tells us nothing about whether the network ended up using those locations for the final classification.

We can multiply the attention by the gradient of the target class output and take the average over the attention heads (while masking out negative attentions), so that only attention that contributes to the target category (or categories) is kept.
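The weighting step described above can be sketched for a single layer as follows. This is a minimal NumPy illustration, not the repository's PyTorch code: the function name `grad_weighted_fusion` and the `(heads, tokens, tokens)` array layout are assumptions, as is the choice to mask negative values after averaging over the heads.

```python
import numpy as np

def grad_weighted_fusion(attention, grad):
    """Fuse the heads of one layer using gradient-weighted attention.

    attention, grad: arrays of shape (heads, tokens, tokens), where grad is
    the gradient of the target class output with respect to the attention.
    Returns an array of shape (tokens, tokens).
    """
    weighted = attention * grad          # weight each attention value by its gradient
    fused = weighted.mean(axis=0)        # average over the attention heads
    # Assumption: negatives are masked after averaging, not per head.
    return np.maximum(fused, 0.0)
```

The fused map can then take the place of the plain head average in an ordinary rollout across layers.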

Where does the Transformer see a Dog (category 243), and a Cat (category 282)?

Where does the Transformer see a Musket dog (category 161) and a Parrot (category 87)?

Tricks and Tweaks to get this working

Filtering the lowest attentions in every layer

--discard_ratio

Removes noise by discarding the lowest attention values in every layer and keeping only the strongest ones.
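A minimal NumPy sketch of this filtering, assuming the heads of a layer have already been fused into a single `(tokens, tokens)` map. The function name `discard_lowest` is hypothetical, and the repository's own implementation may differ in details.

```python
import numpy as np

def discard_lowest(fused, discard_ratio):
    """Zero out the lowest `discard_ratio` fraction of values in a fused attention map.

    fused: array of shape (tokens, tokens). A new array is returned and the
    input is left unchanged.
    """
    flat = fused.flatten()                    # flatten() returns a copy
    k = int(flat.size * discard_ratio)        # number of values to discard
    if k > 0:
        lowest = np.argsort(flat)[:k]         # indices of the k smallest values
        flat[lowest] = 0.0
    return flat.reshape(fused.shape)
```

With `discard_ratio=0.9`, as in the usage example, only the strongest tenth of the attention values in each layer would survive.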

Results for different values:

Different Attention Head Fusions

The Attention Rollout method suggests taking the average attention across the attention heads, but empirically it looks like taking the minimum value, or the maximum value combined with --discard_ratio, works better.

--head_fusion

Image | Mean Fusion | Min Fusion
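The three fusion options, and how a fused map enters the rollout, can be sketched in NumPy as below. This is an illustration under stated assumptions rather than the repository's code: the names `fuse_heads` and `attention_rollout` are hypothetical, each layer's attention is assumed to have shape `(heads, tokens, tokens)`, and the identity matrix is added to account for residual connections, following the usual formulation of Attention Rollout.

```python
import numpy as np

def fuse_heads(attention, head_fusion="mean"):
    """Combine the heads of one layer; attention has shape (heads, tokens, tokens)."""
    if head_fusion == "mean":
        return attention.mean(axis=0)
    if head_fusion == "min":
        return attention.min(axis=0)
    if head_fusion == "max":
        return attention.max(axis=0)
    raise ValueError(f"Unknown head_fusion: {head_fusion!r}")

def attention_rollout(layer_attentions, head_fusion="mean"):
    """Multiply the fused attention maps of all layers, first layer first."""
    tokens = layer_attentions[0].shape[-1]
    identity = np.eye(tokens)
    result = identity
    for attention in layer_attentions:
        fused = fuse_heads(attention, head_fusion)
        a = (fused + identity) / 2.0               # account for residual connections
        a = a / a.sum(axis=-1, keepdims=True)      # re-normalize each row to sum to 1
        result = a @ result
    return result
```

Assuming the class token is at index 0, row 0 of the result (without its first entry) gives the per-patch mask, which can be reshaped to the patch grid.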

Requirements

pip install timm
