Code for NeurIPS 2020 article "Contrastive learning of global and local features for medical image segmentation with limited annotations"

Overview

Contrastive learning of global and local features for medical image segmentation with limited annotations

The code is for the article "Contrastive learning of global and local features for medical image segmentation with limited annotations" which got accepted as an Oral presentation at NeurIPS 2020 (33rd international conference on Neural Information Processing Systems). With the proposed pre-training method using Contrastive learning, we get competitive segmentation performance with just 2 labeled training volumes compared to a benchmark that is trained with many labeled volumes.
https://arxiv.org/abs/2006.10511

Observations / Conclusions:

  1. For medical image segmentation, the proposed contrastive pre-training strategy incorporating domain knowledge present naturally across medical volumes yields better performance than baseline, other pre-training methods, semi-supervised, and data augmentation methods.
  2. Proposed local contrastive loss, an extension of global loss, provides an additional boost in performance by learning distinctive local-level representation to distinguish between neighbouring regions.
  3. The proposed pre-training strategy is complementary to semi-supervised and data augmentation methods. Combining them yields a further boost in accuracy.

Authors:
Krishna Chaitanya (email),
Ertunc Erdil,
Neerav Karani,
Ender Konukoglu.

Requirements:
Python 3.6.1,
Tensorflow 1.12.0,
rest of the requirements are mentioned in the "requirements.txt" file.

I) To clone the git repository.
git clone https://github.com/krishnabits001/domain_specific_dl.git

II) Install python, required packages and tensorflow.
Then, install python packages required using below command or the packages mentioned in the file.
pip install -r requirements.txt

To install tensorflow
pip install tensorflow-gpu=1.12.0

III) Dataset download.
To download the ACDC Cardiac dataset, check the website :
https://www.creatis.insa-lyon.fr/Challenge/acdc.

To download the Medical Decathlon Prostate dataset, check the website :
http://medicaldecathlon.com/

To download the MMWHS Cardiac dataset, check the website :
http://www.sdspeople.fudan.edu.cn/zhuangxiahai/0/mmwhs/

All the images were bias corrected using N4 algorithm with a threshold value of 0.001. For more details, refer to the "N4_bias_correction.py" file in scripts.
Image and label pairs are re-sampled (to chosen target resolution) and cropped/zero-padded to a fixed size using "create_cropped_imgs.py" file.

IV) Train the models.
Below commands are an example for ACDC dataset.
The models need to be trained sequentially as follows (check "train_model/pretrain_and_fine_tune_script.sh" script for commands)
Steps :

  1. Step 1: To pre-train the encoder with global loss by incorporating proposed domain knowledge when defining positive and negative pairs.
    cd train_model/
    python pretr_encoder_global_contrastive_loss.py --dataset=acdc --no_of_tr_imgs=tr52 --global_loss_exp_no=2 --n_parts=4 --temp_fac=0.1 --bt_size=12

  2. Step 2: After step 1, we pre-train the decoder with proposed local loss to aid segmentation task by learning distinctive local-level representations.
    python pretr_decoder_local_contrastive_loss.py --dataset=acdc --no_of_tr_imgs=tr52 --pretr_no_of_tr_imgs=tr52 --local_reg_size=1 --no_of_local_regions=13 --temp_fac=0.1 --global_loss_exp_no=2 --local_loss_exp_no=0 --no_of_decoder_blocks=3 --no_of_neg_local_regions=5 --bt_size=12

  3. Step 3: We use the pre-trained encoder and decoder weights as initialization and fine-tune to segmentation task using limited annotations.
    python ft_pretr_encoder_decoder_net_local_loss.py --dataset=acdc --pretr_no_of_tr_imgs=tr52 --local_reg_size=1 --no_of_local_regions=13 --temp_fac=0.1 --global_loss_exp_no=2 --local_loss_exp_no=0 --no_of_decoder_blocks=3 --no_of_neg_local_regions=5 --no_of_tr_imgs=tr1 --comb_tr_imgs=c1 --ver=0

To train the baseline with affine and random deformations & intensity transformations for comparison, use the below code file.
cd train_model/
python tr_baseline.py --dataset=acdc --no_of_tr_imgs=tr1 --comb_tr_imgs=c1 --ver=0

V) Config files contents.
One can modify the contents of the below 2 config files to run the required experiments.
experiment_init directory contains 2 files.
Example for ACDC dataset:

  1. init_acdc.py
    --> contains the config details like target resolution, image dimensions, data path where the dataset is stored and path to save the trained models.
  2. data_cfg_acdc.py
    --> contains an example of data config details where one can set the patient ids which they want to use as train, validation and test images.

Bibtex citation:

@article{chaitanya2020contrastive,
  title={Contrastive learning of global and local features for medical image segmentation with limited annotations},
  author={Chaitanya, Krishna and Erdil, Ertunc and Karani, Neerav and Konukoglu, Ender},
  journal={Advances in Neural Information Processing Systems},
  volume={33},
  year={2020}
}
Owner
Krishna Chaitanya
Doctoral Student, ETH Zurich
Krishna Chaitanya
Learning trajectory representations using self-supervision and programmatic supervision.

Trajectory Embedding for Behavior Analysis (TREBA) Implementation from the paper: Jennifer J. Sun, Ann Kennedy, Eric Zhan, David J. Anderson, Yisong Y

58 Jan 06, 2023
Python implementation of 3D facial mesh exaggeration using the techniques described in the paper: Computational Caricaturization of Surfaces.

Python implementation of 3D facial mesh exaggeration using the techniques described in the paper: Computational Caricaturization of Surfaces.

Wonjong Jang 8 Nov 01, 2022
Few-shot Learning of GPT-3

Few-shot Learning With Language Models This is a codebase to perform few-shot "in-context" learning using language models similar to the GPT-3 paper.

Tony Z. Zhao 224 Dec 28, 2022
Training vision models with full-batch gradient descent and regularization

Stochastic Training is Not Necessary for Generalization -- Training competitive vision models without stochasticity This repository implements trainin

Jonas Geiping 32 Jan 06, 2023
A Java implementation of the experiments for the paper "k-Center Clustering with Outliers in Sliding Windows"

OutliersSlidingWindows A Java implementation of the experiments for the paper "k-Center Clustering with Outliers in Sliding Windows" Dataset generatio

PaoloPellizzoni 0 Jan 05, 2022
Implementation of our paper "DMT: Dynamic Mutual Training for Semi-Supervised Learning"

DMT: Dynamic Mutual Training for Semi-Supervised Learning This repository contains the code for our paper DMT: Dynamic Mutual Training for Semi-Superv

Zhengyang Feng 120 Dec 30, 2022
Large scale embeddings on a single machine.

Marius Marius is a system under active development for training embeddings for large-scale graphs on a single machine. Training on large scale graphs

Marius 107 Jan 03, 2023
TensorFlow implementation of "Attention is all you need (Transformer)"

[TensorFlow 2] Attention is all you need (Transformer) TensorFlow implementation of "Attention is all you need (Transformer)" Dataset The MNIST datase

YeongHyeon Park 4 Jan 05, 2022
Self-Adaptable Point Processes with Nonparametric Time Decays

NPPDecay This is our implementation for the paper Self-Adaptable Point Processes with Nonparametric Time Decays, by Zhimeng Pan, Zheng Wang, Jeff M. P

zpan 2 Sep 24, 2022
Towards Part-Based Understanding of RGB-D Scans

Towards Part-Based Understanding of RGB-D Scans (CVPR 2021) We propose the task of part-based scene understanding of real-world 3D environments: from

26 Nov 23, 2022
The author's officially unofficial PyTorch BigGAN implementation.

BigGAN-PyTorch The author's officially unofficial PyTorch BigGAN implementation. This repo contains code for 4-8 GPU training of BigGANs from Large Sc

Andy Brock 2.6k Jan 02, 2023
Yet Another Reinforcement Learning Tutorial

This repo contains self-contained RL implementations

Sungjoon 65 Dec 10, 2022
PolyphonicFormer: Unified Query Learning for Depth-aware Video Panoptic Segmentation

PolyphonicFormer: Unified Query Learning for Depth-aware Video Panoptic Segmentation Winner method of the ICCV-2021 SemKITTI-DVPS Challenge. [arxiv] [

Yuan Haobo 38 Jan 03, 2023
DeepGNN is a framework for training machine learning models on large scale graph data.

DeepGNN Overview DeepGNN is a framework for training machine learning models on large scale graph data. DeepGNN contains all the necessary features in

Microsoft 45 Jan 01, 2023
An introduction to satellite image analysis using Python + OpenCV and JavaScript + Google Earth Engine

A Gentle Introduction to Satellite Image Processing Welcome to this introductory course on Satellite Image Analysis! Satellite imagery has become a pr

Edward Oughton 32 Jan 03, 2023
BasicNeuralNetwork - This project looks over the basic structure of a neural network and how machine learning training algorithms work

BasicNeuralNetwork - This project looks over the basic structure of a neural network and how machine learning training algorithms work. For this project, I used the sigmoid function as an activation

Manas Bommakanti 1 Jan 22, 2022
This repo includes our code for evaluating and improving transferability in domain generalization (NeurIPS 2021)

Transferability for domain generalization This repo is for evaluating and improving transferability in domain generalization (NeurIPS 2021), based on

gordon 9 Nov 29, 2022
In this project we investigate the performance of the SetCon model on realistic video footage. Therefore, we implemented the model in PyTorch and tested the model on two example videos.

Contrastive Learning of Object Representations Supervisor: Prof. Dr. Gemma Roig Institutions: Goethe University CVAI - Computational Vision & Artifici

Dirk Neuhäuser 6 Dec 08, 2022
Multi-Task Learning as a Bargaining Game

Nash-MTL Official implementation of "Multi-Task Learning as a Bargaining Game". Setup environment conda create -n nashmtl python=3.9.7 conda activate

Aviv Navon 87 Dec 26, 2022