A library to inspect itermediate layers of PyTorch models.

Overview

A library to inspect itermediate layers of PyTorch models.

Why?

It's often the case that we want to inspect intermediate layers of a model without modifying the code e.g. visualize attention matrices of language models, get values from an intermediate layer to feed to another layer, or applying a loss function to intermediate layers.

Install

$ pip install surgeon-pytorch

PyPI - Python Version

Usage

Inspect

Given a PyTorch model we can display all layers using get_layers:

import torch
import torch.nn as nn

from surgeon_pytorch import Inspect, get_layers

class SomeModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(5, 3)
        self.layer2 = nn.Linear(3, 2)
        self.layer3 = nn.Linear(2, 1)

    def forward(self, x):
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        y = self.layer3(x2)
        return y


model = SomeModel()
print(get_layers(model)) # ['layer1', 'layer2', 'layer3']

Then we can wrap our model to be inspected using Inspect and in every forward call the new model we will also output the provided layer outputs (in second return value):

model_wrapped = Inspect(model, layer='layer2')
x = torch.rand(1, 5)
y, x2 = model_wrapped(x)
print(x2) # tensor([[-0.2726,  0.0910]], grad_fn=<AddmmBackward0>)

We can also provide a list of layers:

model_wrapped = Inspect(model, layer=['layer1', 'layer2'])
x = torch.rand(1, 5)
y, [x1, x2] = model_wrapped(x)
print(x1) # tensor([[ 0.1739,  0.3844, -0.4724]], grad_fn=<AddmmBackward0>)
print(x2) # tensor([[-0.2238,  0.0107]], grad_fn=<AddmmBackward0>)

Or a dictionary to get named outputs:

model_wrapped = Inspect(model, layer={'x1': 'layer1', 'x2': 'layer2'})
x = torch.rand(1, 5)
y, layers = model_wrapped(x)
print(layers)
"""
{
    'x1': tensor([[ 0.3707,  0.6584, -0.2970]], grad_fn=<AddmmBackward0>),
    'x2': tensor([[-0.1953, -0.3408]], grad_fn=<AddmmBackward0>)
}
"""

TODO

  • add extract function to get intermediate block
You might also like...
Ever felt tired after preprocessing the dataset, and not wanting to write any code further to train your model? Ever encountered a situation where you wanted to record the hyperparameters of the trained model and able to retrieve it afterward? Models Playground is here to help you do that. Models playground allows you to train your models right from the browser. pyhsmm - library for approximate unsupervised inference in Bayesian Hidden Markov Models (HMMs) and explicit-duration Hidden semi-Markov Models (HSMMs), focusing on the Bayesian Nonparametric extensions, the HDP-HMM and HDP-HSMM, mostly with weak-limit approximations. PyTorch implementation and pretrained models for XCiT models. See XCiT: Cross-Covariance Image Transformer
TorchGeo is a PyTorch domain library, similar to torchvision, that provides datasets, transforms, samplers, and pre-trained models specific to geospatial data.

TorchGeo is a PyTorch domain library, similar to torchvision, that provides datasets, transforms, samplers, and pre-trained models specific to geospatial data.

Pytorch library for end-to-end transformer models training and serving

Pytorch library for end-to-end transformer models training and serving

This repository provides an efficient PyTorch-based library for training deep models.

An Efficient Library for Training Deep Models This repository provides an efficient PyTorch-based library for training deep models. Installation Make

TorchMultimodal is a PyTorch library for training state-of-the-art multimodal multi-task models at scale.

TorchMultimodal (Alpha Release) Introduction TorchMultimodal is a PyTorch library for training state-of-the-art multimodal multi-task models at scale.

Objective of the repository is to learn and build machine learning models using Pytorch. 30DaysofML Using Pytorch
Objective of the repository is to learn and build machine learning models using Pytorch. 30DaysofML Using Pytorch

30 Days Of Machine Learning Using Pytorch Objective of the repository is to learn and build machine learning models using Pytorch. List of Algorithms

Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch
Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch

Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch

Comments
  • Use one backbone with different heads

    Use one backbone with different heads

    Is it possible to save the results from the backbone and apply them on the heads of the all the other models. My goal was to try to save time by avoiding repeating the backbone part. Instead of running the 3 complete models (left), only run the backbone 1 time and switch only the heads for the 3 models (right), therefore not repeating executing the backbone every time in yolov5 model.

    Thank you for the help!

    question 
    opened by brunopatricio2012 4
  • Support for DataParallel?

    Support for DataParallel?

    Hi, I noticed that the current version does not support parallel models (at least those created using torch.nn.DataParallel) since the forward hook does not differentiate between the different copies of the model and a model wrapped with Inspect will just return the intermediate features of the last copy of the parallelized model to run.

    Are you planning on fixing this issue/supporting this use case?

    opened by zimmerrol 1
Releases(0.0.4)
Owner
archinet.ai
AI Research Group
archinet.ai
Unsupervised Learning of Probably Symmetric Deformable 3D Objects from Images in the Wild

Unsupervised Learning of Probably Symmetric Deformable 3D Objects from Images in the Wild

1.1k Jan 03, 2023
Implementation of fast algorithms for Maximum Spanning Tree (MST) parsing that includes fast ArcMax+Reweighting+Tarjan algorithm for single-root dependency parsing.

Fast MST Algorithm Implementation of fast algorithms for (Maximum Spanning Tree) MST parsing that includes fast ArcMax+Reweighting+Tarjan algorithm fo

Miloš Stanojević 11 Oct 14, 2022
(CVPR 2022) Pytorch implementation of "Self-supervised transformers for unsupervised object discovery using normalized cut"

(CVPR 2022) TokenCut Pytorch implementation of Tokencut: Self-supervised Transformers for Unsupervised Object Discovery using Normalized Cut Yangtao W

YANGTAO WANG 200 Jan 02, 2023
The (Official) PyTorch Implementation of the paper "Deep Extraction of Manga Structural Lines"

MangaLineExtraction_PyTorch The (Official) PyTorch Implementation of the paper "Deep Extraction of Manga Structural Lines" Usage model_torch.py [sourc

Miaomiao Li 82 Jan 02, 2023
Manipulation OpenAI Gym environments to simulate robots at the STARS lab

Manipulator Learning This repository contains a set of manipulation environments that are compatible with OpenAI Gym and simulated in pybullet. In par

STARS Laboratory 5 Dec 08, 2022
Categorical Depth Distribution Network for Monocular 3D Object Detection

CaDDN CaDDN is a monocular-based 3D object detection method. This repository is based off of [OpenPCDet]. Categorical Depth Distribution Network for M

Toronto Robotics and AI Laboratory 289 Jan 05, 2023
Code for the paper "Regularizing Variational Autoencoder with Diversity and Uncertainty Awareness"

DU-VAE This is the pytorch implementation of the paper "Regularizing Variational Autoencoder with Diversity and Uncertainty Awareness" Acknowledgement

Dazhong Shen 4 Oct 19, 2022
Implementation of "Meta-rPPG: Remote Heart Rate Estimation Using a Transductive Meta-Learner"

Meta-rPPG: Remote Heart Rate Estimation Using a Transductive Meta-Learner This repository is the official implementation of Meta-rPPG: Remote Heart Ra

Eugene Lee 137 Dec 13, 2022
PyTorch implementation of Value Iteration Networks (VIN): Clean, Simple and Modular. Visualization in Visdom.

VIN: Value Iteration Networks This is an implementation of Value Iteration Networks (VIN) in PyTorch to reproduce the results.(TensorFlow version) Key

Xingdong Zuo 215 Dec 07, 2022
Official implementation of our neural-network-based fast diffuse room impulse response generator (FAST-RIR)

This is the official implementation of our neural-network-based fast diffuse room impulse response generator (FAST-RIR) for generating room impulse responses (RIRs) for a given acoustic environment.

12 Jan 13, 2022
Real-Time-Student-Attendence-System - Real Time Student Attendence System

Real-Time-Student-Attendence-System The Student Attendance Management System Pro

Rounak Das 1 Feb 15, 2022
[AAAI 2022] Sparse Structure Learning via Graph Neural Networks for Inductive Document Classification

Sparse Structure Learning via Graph Neural Networks for inductive document classification Make graph dataset create co-occurrence graph for datasets.

16 Dec 22, 2022
Interpretation of T cell states using reference single-cell atlases

Interpretation of T cell states using reference single-cell atlases ProjecTILs is a computational method to project scRNA-seq data into reference sing

Cancer Systems Immunology Lab 139 Jan 03, 2023
Artificial Intelligence search algorithm base on Pacman

Pacman Search Artificial Intelligence search algorithm base on Pacman Source The Pacman Projects by the University of California, Berkeley. Layouts Di

Day Fundora 6 Nov 17, 2022
This is the repo for the paper "Improving the Accuracy-Memory Trade-Off of Random Forests Via Leaf-Refinement".

Improving the Accuracy-Memory Trade-Off of Random Forests Via Leaf-Refinement This is the repository for the paper "Improving the Accuracy-Memory Trad

3 Dec 29, 2022
Embeds a story into a music playlist by sorting the playlist so that the order of the music follows a narrative arc.

playlist-story-builder This project attempts to embed a story into a music playlist by sorting the playlist so that the order of the music follows a n

Dylan R. Ashley 0 Oct 28, 2021
Breast Cancer Detection 🔬 ITI "AI_Pro" Graduation Project

BreastCancerDetection - This program is designed to predict two severity of abnormalities associated with breast cancer cells: benign and malignant. Mammograms from MIAS is preprocessed and features

6 Nov 29, 2022
PointCNN: Convolution On X-Transformed Points (NeurIPS 2018)

PointCNN: Convolution On X-Transformed Points Created by Yangyan Li, Rui Bu, Mingchao Sun, Wei Wu, Xinhan Di, and Baoquan Chen. Introduction PointCNN

Yangyan Li 1.3k Dec 21, 2022
Official code release for "GRAF: Generative Radiance Fields for 3D-Aware Image Synthesis"

GRAF This repository contains official code for the paper GRAF: Generative Radiance Fields for 3D-Aware Image Synthesis. You can find detailed usage i

349 Dec 29, 2022
Propagate Yourself: Exploring Pixel-Level Consistency for Unsupervised Visual Representation Learning, CVPR 2021

Propagate Yourself: Exploring Pixel-Level Consistency for Unsupervised Visual Representation Learning By Zhenda Xie*, Yutong Lin*, Zheng Zhang, Yue Ca

Zhenda Xie 293 Dec 20, 2022