An Implementation of Transformer in Transformer in TensorFlow for image classification, attention inside local patches

Overview

Transformer-in-Transformer

A TensorFlow implementation of the Transformer in Transformer (TNT) paper by Han et al. for image classification. TNT pairs pixel-level attention inside local patches with patch-level attention across the whole image.

PyTorch Implementation

Installation

Run the following to install:

pip install tnt-tensorflow

Developing tnt-tensorflow

To install tnt-tensorflow along with the tools you need to develop and test it, run the following in your virtualenv:

git clone https://github.com/Rishit-dagli/Transformer-in-Transformer.git
# or clone your own fork

cd Transformer-in-Transformer
pip install -e ".[dev]"

Usage

import tensorflow as tf
from tnt import TNT

tnt = TNT(
    image_size=256,  # size of image
    patch_dim=512,  # dimension of patch token
    pixel_dim=24,  # dimension of pixel token
    patch_size=16,  # patch size
    pixel_size=4,  # pixel size
    depth=5,  # depth
    num_classes=1000,  # output number of classes
    attn_dropout=0.1,  # attention dropout
    ff_dropout=0.1,  # feedforward dropout
)

img = tf.random.uniform(shape=[5, 3, 256, 256])
logits = tnt(img) # (5, 1000)
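
To make the size parameters above concrete, here is a minimal sketch of the token arithmetic they imply. It assumes non-overlapping square patches and non-overlapping pixel windows inside each patch (no padding, stride equal to `pixel_size`); that is an assumption for illustration, and `tnt_token_counts` is a hypothetical helper, not part of tnt-tensorflow.

```python
def tnt_token_counts(image_size, patch_size, pixel_size):
    """Return (patches per image, pixel tokens per patch).

    Hypothetical helper for illustration only; assumes non-overlapping
    patches and non-overlapping pixel windows inside each patch.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    if patch_size % pixel_size != 0:
        raise ValueError("patch_size must be divisible by pixel_size")
    patches_per_side = image_size // patch_size
    pixels_per_side = patch_size // pixel_size
    return patches_per_side ** 2, pixels_per_side ** 2


num_patches, pixels_per_patch = tnt_token_counts(
    image_size=256, patch_size=16, pixel_size=4
)
print(num_patches, pixels_per_patch)  # 256 16
```

Under these assumptions, the settings from the usage example give 256 patch tokens of width `patch_dim=512` for the patch-level attention and 16 pixel tokens of width `pixel_dim=24` inside each patch for the pixel-level attention.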

Want to Contribute 🙋‍♂️ ?

Awesome! Contributions to this project are always welcome. See the Contributing Guidelines. You can also look at the open issues to learn about current or upcoming tasks.

Want to discuss? 💬

Have questions or doubts, or want to share your opinions and views? You're always welcome to start a discussion.

Citation

@misc{han2021transformer,
      title={Transformer in Transformer}, 
      author={Kai Han and An Xiao and Enhua Wu and Jianyuan Guo and Chunjing Xu and Yunhe Wang},
      year={2021},
      eprint={2103.00112},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

License

Copyright 2020 Rishit Dagli

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Comments
  • Add Unit Tests

    The tests should check the rank and shape of the output tensors, and the test classes should subclass tf.test.TestCase.

    • [x] #15
    • [x] #16
    • [x] #18
    • [x] #17

    Feel free to take inspiration from:

    • https://github.com/Rishit-dagli/Fast-Transformer/blob/main/fast_transformer/test_fast_transformer.py
    • For parametrization, feel free to follow https://stackoverflow.com/a/34094/11878567; the same approach works with subTest in TensorFlow
    enhancement good first issue 
    opened by Rishit-dagli 3
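
The parametrization idea from the issue above can be sketched with the standard library's `unittest` alone. Since `tf.test.TestCase` builds on `unittest.TestCase`, the same `subTest` pattern should carry over. `fake_model_output_shape` is a hypothetical stand-in for `tuple(model(images).shape)` so that the sketch runs without TensorFlow; the test cases are illustrative, not taken from the repository's test suite.

```python
import unittest


def fake_model_output_shape(image_batch_shape, num_classes):
    """Hypothetical stand-in for `tuple(model(images).shape)`.

    A real test would build the model and call it on a tensor instead.
    """
    batch_size = image_batch_shape[0]
    return (batch_size, num_classes)


class OutputShapeTest(unittest.TestCase):
    # Each case: (input batch shape, number of classes).
    CASES = [
        ((1, 3, 256, 256), 1000),
        ((5, 3, 256, 256), 1000),
        ((2, 3, 256, 256), 10),
    ]

    def test_rank_and_shape(self):
        for image_batch_shape, num_classes in self.CASES:
            # subTest reports every failing case separately instead of
            # stopping at the first failure.
            with self.subTest(shape=image_batch_shape, classes=num_classes):
                out_shape = fake_model_output_shape(image_batch_shape, num_classes)
                self.assertEqual(len(out_shape), 2)  # rank check
                self.assertEqual(out_shape, (image_batch_shape[0], num_classes))


def run_tests():
    """Run the test case programmatically and return the result object."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(OutputShapeTest)
    return unittest.TextTestRunner(verbosity=0).run(suite)
```

Calling `run_tests()` executes all parametrized cases inside a single test method. In a TensorFlow test, the base class would be `tf.test.TestCase` and the stand-in would be replaced by an actual forward pass.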
  • Update Workflows to run tests

    This issue follows #11

    Update GitHub Workflows to:

    • [ ] Run Tests before uploading to PyPI
    • [ ] Create a workflow to run tests on commits

    Feel free to take inspiration from https://github.com/Rishit-dagli/Fast-Transformer/tree/main/.github/workflows

    enhancement good first issue 
    opened by Rishit-dagli 0
  • Creates an Attention layer

    Verify output shapes just from the attention layer:

    import tensorflow as tf
    Attention(dim=256)(tf.random.normal([3,256,256]))
    
    # <tf.Tensor: shape=(3, 256, 256), dtype=float32,
    

    Closes #3

    opened by Rishit-dagli 0
  • Put together a TNT class

    Verify shapes:

    import tensorflow as tf
    from tnt import TNT

    tnt = TNT(
        image_size=256,  # size of image
        patch_dim=512,  # dimension of patch token
        pixel_dim=24,  # dimension of pixel token
        patch_size=16,  # patch size
        pixel_size=4,  # pixel size
        depth=5,  # depth
        num_classes=1000,  # output number of classes
        attn_dropout=0.1,  # attention dropout
        ff_dropout=0.1,  # feedforward dropout
    )
    
    img = tf.random.uniform(shape=[1, 3, 256, 256])
    print(tnt(img).shape)
    
    # (1, 1000)
    opened by Rishit-dagli 0
  • Create an Attention layer

    Verify output shapes just from the attention layer:

    import tensorflow as tf
    Attention(dim=256)(tf.random.normal([3,256,256]))
    
    # <tf.Tensor: shape=(3, 256, 256), dtype=float32,
    
    opened by Rishit-dagli 0
  • Create a PreNorm layer

    Verify output shapes from this layer:

    import tensorflow as tf
    PreNorm(dim=1, fn=tf.keras.layers.Dense(5))(tf.random.normal([10, 1]))
    
    # <tf.Tensor: shape=(10, 1), dtype=float32,
    
    opened by Rishit-dagli 0
Releases(v0.2.0)
  • v0.2.0(Feb 2, 2022)

    This is an interesting release for the project, including a pre-trained model on ImageNet, reproducibility of paper results, tests, and end-to-end training.

    ✅ Bug Fixes / Improvements

    • Create an end-to-end training example demonstrating how to train a TNT model for image classification through a custom training loop on the TF Flowers dataset (#14)
    • A pre-trained model that reproduces the paper results has been made available (in this release as well as on TensorFlow Hub)
    • Add an off-the-shelf inference example that shows how to use the pre-trained model directly
    • Unit Tests for the Attention class (#19)
    • Unit Tests for the main TNT class (#20)

    Full Changelog: https://github.com/Rishit-dagli/Transformer-in-Transformer/compare/v0.1.0...v0.2.0

    tnt_s_patch16_224.tar.gz(84.42 MB)
  • v0.1.0(Dec 3, 2021)

    This is the initial release of TNT TensorFlow and implements Transformers in Transformers as a subclassed TensorFlow model.

    Classes

    • Attention: Implements attention as a TensorFlow Keras Layer making some modifications.
    • PreNorm: Normalize the activations of the previous layer for each given example in a batch independently and apply some function to it, implemented as a TensorFlow Keras Layer.
    • FeedForward: Create a FeedForward neural net with two Dense layers and GELU activation, implemented as a TensorFlow Keras Layer.
    • TNT: Implements the Transformers in Transformers model using all the other classes, and converts to logits. Implemented as a TensorFlow Keras Model.
    tnt_s_patch16_224.tar.gz(84.42 MB)
Owner
Rishit Dagli
High School,TEDx,2xTED-Ed speaker | International Speaker | Microsoft Student Ambassador | Mentor, @TFUGMumbai | Organize @KotlinMumbai