Contextual Attention Network: Transformer Meets U-Net

Overview

Contextual Attention Network: Transformer Meets U-Net

Contexual attention network for medical image segmentation with state of the art results on skin lesion segmentation, multiple myeloma cell segmentation. This method incorpotrates the transformer module into a U-Net structure so as to concomitantly capture long-range dependency along with resplendent local informations. If this code helps with your research please consider citing the following paper:

R. Azad, Moein Heidari, Yuli Wu and Dorit Merhof , "Contextual Attention Network: Transformer Meets U-Net", download link.

@article{reza2022contextual,
  title={Contextual Attention Network: Transformer Meets U-Net},
  author={Reza, Azad and Moein, Heidari and Yuli, Wu and Dorit, Merhof},
  journal={arXiv preprint arXiv:2203.01932},
  year={2022}
}

Please consider starring us, if you found it useful. Thanks

Updates

This code has been implemented in python language using Pytorch library and tested in ubuntu OS, though should be compatible with related environment. following Environement and Library needed to run the code:

  • Python 3
  • Pytorch

Run Demo

For training deep model and evaluating on each data set follow the bellow steps:
1- Download the ISIC 2018 train dataset from this link and extract both training dataset and ground truth folders inside the dataset_isic18.
2- Run Prepare_ISIC2018.py for data preperation and dividing data to train,validation and test sets.
3- Run train_skin.py for training the model using trainng and validation sets. The model will be train for 100 epochs and it will save the best weights for the valiation set.
4- For performance calculation and producing segmentation result, run evaluate_skin.py. It will represent performance measures and will saves related results in results folder.

Notice: For training and evaluating on ISIC 2017 and ph2 follow the bellow steps :

ISIC 2017- Download the ISIC 2017 train dataset from this link and extract both training dataset and ground truth folders inside the dataset_isic18\7.
then Run Prepare_ISIC2017.py for data preperation and dividing data to train,validation and test sets.
ph2- Download the ph2 dataset from this link and extract it then Run Prepare_ph2.py for data preperation and dividing data to train,validation and test sets.
Follow step 3 and 4 for model traing and performance estimation. For ph2 dataset you need to first train the model with ISIC 2017 data set and then fine-tune the trained model using ph2 dataset.

Quick Overview

Diagram of the proposed method

Perceptual visualization of the proposed Contextual Attention module.

Diagram of the proposed method

Results

For evaluating the performance of the proposed method, Two challenging task in medical image segmentaion has been considered. In bellow, results of the proposed approach illustrated.

Task 1: SKin Lesion Segmentation

Performance Comparision on SKin Lesion Segmentation

In order to compare the proposed method with state of the art appraoches on SKin Lesion Segmentation, we considered Drive dataset.

Methods (On ISIC 2017) Dice-Score Sensivity Specificaty Accuracy
Ronneberger and et. all U-net 0.8159 0.8172 0.9680 0.9164
Oktay et. all Attention U-net 0.8082 0.7998 0.9776 0.9145
Lei et. all DAGAN 0.8425 0.8363 0.9716 0.9304
Chen et. all TransU-net 0.8123 0.8263 0.9577 0.9207
Asadi et. all MCGU-Net 0.8927 0.8502 0.9855 0.9570
Valanarasu et. all MedT 0.8037 0.8064 0.9546 0.9090
Wu et. all FAT-Net 0.8500 0.8392 0.9725 0.9326
Azad et. all Proposed TMUnet 0.9164 0.9128 0.9789 0.9660

For more results on ISIC 2018 and PH2 dataset, please refer to the paper

SKin Lesion Segmentation segmentation result on test data

SKin Lesion Segmentation  result (a) Input images. (b) Ground truth. (c) U-net. (d) Gated Axial-Attention. (e) Proposed method without a contextual attention module and (f) Proposed method.

Multiple Myeloma Cell Segmentation

Performance Evalution on the Multiple Myeloma Cell Segmentation task

Methods mIOU
Frequency recalibration U-Net 0.9392
XLAB Insights 0.9360
DSC-IITISM 0.9356
Multi-scale attention deeplabv3+ 0.9065
U-Net 0.7665
Baseline 0.9172
Proposed 0.9395

Multiple Myeloma Cell Segmentation results

Multiple Myeloma Cell Segmentation result

Model weights

You can download the learned weights for each dataset in the following table.

Dataset Learned weights
ISIC 2018 TMUnet
ISIC 2017 TMUnet
Ph2 TMUnet

Query

All implementations are done by Reza Azad and Moein Heidari. For any query please contact us for more information.

rezazad68@gmail.com
moeinheidari7829@gmail.com
Owner
Reza Azad
Deep Learning and Computer Vision Researcher
Reza Azad
Vowpal Wabbit is a machine learning system which pushes the frontier of machine learning with techniques such as online, hashing, allreduce, reductions, learning2search, active, and interactive learning.

This is the Vowpal Wabbit fast online learning code. Why Vowpal Wabbit? Vowpal Wabbit is a machine learning system which pushes the frontier of machin

Vowpal Wabbit 8.1k Jan 06, 2023
Catbird is an open source paraphrase generation toolkit based on PyTorch.

Catbird is an open source paraphrase generation toolkit based on PyTorch. Quick Start Requirements and Installation The project is based on PyTorch 1.

Afonso Salgado de Sousa 5 Dec 15, 2022
This project hosts the code for implementing the ISAL algorithm for object detection and image classification

Influence Selection for Active Learning (ISAL) This project hosts the code for implementing the ISAL algorithm for object detection and image classifi

25 Sep 11, 2022
Human Action Controller - A human action controller running on different platforms.

Human Action Controller (HAC) Goal A human action controller running on different platforms. Fun Easy-to-use Accurate Anywhere Fun Examples Mouse Cont

27 Jul 20, 2022
Differentiable rasterization applied to 3D model simplification tasks

nvdiffmodeling Differentiable rasterization applied to 3D model simplification tasks, as described in the paper: Appearance-Driven Automatic 3D Model

NVIDIA Research Projects 336 Dec 30, 2022
This project provides the code and datasets for 'CapSal: Leveraging Captioning to Boost Semantics for Salient Object Detection', CVPR 2019.

Code-and-Dataset-for-CapSal This project provides the code and datasets for 'CapSal: Leveraging Captioning to Boost Semantics for Salient Object Detec

lu zhang 48 Aug 19, 2022
用opencv的dnn模块做yolov5目标检测,包含C++和Python两个版本的程序

yolov5-dnn-cpp-py yolov5s,yolov5l,yolov5m,yolov5x的onnx文件在百度云盘下载, 链接:https://pan.baidu.com/s/1d67LUlOoPFQy0MV39gpJiw 提取码:bayj python版本的主程序是main_yolov5.

365 Jan 04, 2023
Learning Features with Parameter-Free Layers (ICLR 2022)

Learning Features with Parameter-Free Layers (ICLR 2022) Dongyoon Han, YoungJoon Yoo, Beomyoung Kim, Byeongho Heo | Paper NAVER AI Lab, NAVER CLOVA Up

NAVER AI 65 Dec 07, 2022
HALO: A Skeleton-Driven Neural Occupancy Representation for Articulated Hands

HALO: A Skeleton-Driven Neural Occupancy Representation for Articulated Hands Oral Presentation, 3DV 2021 Korrawe Karunratanakul, Adrian Spurr, Zicong

Korrawe Karunratanakul 43 Oct 07, 2022
Python package for visualizing the loss landscape of parameterized quantum algorithms.

orqviz A Python package for easily visualizing the loss landscape of Variational Quantum Algorithms by Zapata Computing Inc. orqviz provides a collect

Zapata Computing, Inc. 75 Dec 30, 2022
Set of models for classifcation of 3D volumes

Classification models 3D Zoo - Keras and TF.Keras This repository contains 3D variants of popular CNN models for classification like ResNets, DenseNet

69 Dec 28, 2022
Multi-resolution SeqMatch based long-term Place Recognition

MRS-SLAM for long-term place recognition In this work, we imply an multi-resolution sambling based visual place recognition method. This work is based

METASLAM 6 Dec 06, 2022
Immortal tracker

Immortal_tracker Prerequisite Our code is tested for Python 3.6. To install required liabraries: pip install -r requirements.txt Waymo Open Dataset P

74 Dec 03, 2022
The Medical Detection Toolkit contains 2D + 3D implementations of prevalent object detectors such as Mask R-CNN, Retina Net, Retina U-Net, as well as a training and inference framework focused on dealing with medical images.

The Medical Detection Toolkit contains 2D + 3D implementations of prevalent object detectors such as Mask R-CNN, Retina Net, Retina U-Net, as well as a training and inference framework focused on dea

MIC-DKFZ 1.2k Jan 04, 2023
An attempt at the implementation of Glom, Geoffrey Hinton's new idea that integrates neural fields, predictive coding, top-down-bottom-up, and attention (consensus between columns)

GLOM - Pytorch (wip) An attempt at the implementation of Glom, Geoffrey Hinton's new idea that integrates neural fields, predictive coding,

Phil Wang 173 Dec 14, 2022
Toward Realistic Single-View 3D Object Reconstruction with Unsupervised Learning from Multiple Images (ICCV 2021)

Table of Content Introduction Getting Started Datasets Installation Experiments Training & Testing Pretrained models Texture fine-tuning Demo Toward R

VinAI Research 42 Dec 05, 2022
A simple, clean TensorFlow implementation of Generative Adversarial Networks with a focus on modeling illustrations.

IllustrationGAN A simple, clean TensorFlow implementation of Generative Adversarial Networks with a focus on modeling illustrations. Generated Images

268 Nov 27, 2022
A repository for interferometer controller code.

dses-interferometer-controller A repository for interferometer controller code, hardware, and simulations. See dses.science for more information on th

Eli Reed 1 Jan 17, 2022
Sharpness-Aware Minimization for Efficiently Improving Generalization

Sharpness-Aware-Minimization-TensorFlow This repository provides a minimal implementation of sharpness-aware minimization (SAM) (Sharpness-Aware Minim

Sayak Paul 54 Dec 08, 2022
A parallel framework for population-based multi-agent reinforcement learning.

MALib: A parallel framework for population-based multi-agent reinforcement learning MALib is a parallel framework of population-based learning nested

MARL @ SJTU 348 Jan 08, 2023