Self-supervised learning

Self-supervised learning algorithms provide a way to train deep neural networks in an unsupervised way using contrastive losses. The idea is to learn a representation that stays as close as possible to augmented and transformed versions of the same input while discriminating it from negative examples. In this approach, we first pre-train a ViT-based network on an unlabeled dataset and then fine-tune it on a relatively small labeled one. This drastically reduces the amount of labeled data required, and the scarcity of labeled data is a major obstacle to applying deep learning in the real world. Surprisingly, compared with fully supervised counterparts, this approach actually leads to an increase in robustness as well as raw performance, even with the same architecture.

If you want to skip pre-training, the pre-trained weights can be downloaded from here and used directly for the fine-tuning tasks. In that case, skip to the second part of the tutorial, which uses 'ssl_finetune_train.py'.

Steps to run the tutorial

1.) Download the two datasets, TCIA-Covid19 and BTCV (more details in the Data section)
2.) Modify the paths for data_root, json_path & logdir in 'ssl_script_train.py'
3.) Run 'ssl_script_train.py'
4.) Modify the paths for data_root, json_path, pre-trained_weights_path (pointing to the weights saved in the logdir set in step 2) and logdir_path in 'ssl_finetuning_train.py'
5.) Run 'ssl_finetuning_script.py'
6.) And that's all, folks: use the model for your own needs

1. Data

Pre-training Dataset: The TCIA Covid-19 dataset was used to generate the pre-trained weights. It contains a total of 771 3D CT volumes, which were split into training and validation sets of 600 and 171 volumes, respectively. The data is available for download at this link. If you use this dataset in your work, please cite [1]. A JSON file containing the training and validation splits used for training is provided in the json_files directory of the self-supervised training tutorial.
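The split files can be read with a few lines of standard-library Python. The sketch below makes two assumptions. The first is a Decathlon-style layout: a top-level "training" and "validation" list whose entries map keys such as "image" to paths relative to data_root. The second is that the helper name load_split is hypothetical. Check the key names against the actual files in json_files before relying on it.

```python
import json
import os
from typing import Dict, List


def load_split(json_path: str, data_root: str, key: str) -> List[Dict[str, str]]:
    """Return one split ("training" or "validation") with every path joined to data_root.

    Assumed file layout (Decathlon style, not confirmed by the tutorial):
        {"training": [{"image": "relative/path.nii.gz"}, ...], "validation": [...]}
    Every value in an entry is assumed to be a relative file path.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        splits = json.load(f)
    return [
        {name: os.path.join(data_root, rel) for name, rel in item.items()}
        for item in splits[key]
    ]


# Hypothetical usage; the file name and data_root below are placeholders:
# train_files = load_split("json_files/dataset_split.json", "/data/covid19", "training")
# val_files = load_split("json_files/dataset_split.json", "/data/covid19", "validation")
```

If MONAI is already installed, monai.data.load_decathlon_datalist reads the same kind of layout and resolves paths against its base_dir argument.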

Fine-tuning Dataset: The dataset from the Beyond the Cranial Vault (BTCV) Challenge, hosted at MICCAI 2015, was used for a fully supervised fine-tuning task starting from the pre-trained weights. The dataset consists of 30 3D volumes with annotated labels for up to 13 different organs [2]. Three JSON files are provided for this dataset in the json_files directory. They use 6, 12 and 24 training volumes, respectively, and all three share the same validation split.

References:

1.) Harmon, Stephanie A., et al. "Artificial intelligence for the detection of COVID-19 pneumonia on chest CT using multinational datasets." Nature Communications 11.1 (2020): 1-7.

2.) Tang, Yucheng, et al. "High-resolution 3D abdominal segmentation with random patch network fusion." Medical Image Analysis 69 (2021): 101894.

2. Network Architectures

For pre-training, a modified version of ViT [1] was used; its implementation is available here in MONAI. The original ViT was modified by attaching two 3D transposed convolution layers so that the reconstruction has the same size as the input image. The ViT also serves as the backbone of the UNETR [2] architecture, which was used for the fully supervised fine-tuning tasks.
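Two transposed convolutions are enough to bring the token grid back to image size, as the shape arithmetic below shows. The settings are assumptions, not confirmed by this README:

- a 96x96x96 input, matching the patch size in Section 3;
- a ViT patch size of 16;
- two transposed convolutions, each with kernel size = stride = 4.

Check these against the MONAI version you use.

```python
def conv_transpose_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output length of a transposed convolution along one axis (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


def reconstruction_size(img_size: int = 96, patch_size: int = 16, up: int = 4) -> int:
    """Spatial size after the ViT token grid is upsampled by two transposed convolutions.

    Assumed configuration: kernel size = stride = `up` for both layers, no padding.
    """
    grid = img_size // patch_size                  # tokens per axis: 96 // 16 = 6
    first = conv_transpose_out(grid, up, up)       # 6 -> 24
    return conv_transpose_out(first, up, up)       # 24 -> 96


print(reconstruction_size())  # 96
```

With kernel size equal to stride and no padding, each layer multiplies the size by 4. Two layers therefore undo the factor of 16 lost to patch embedding.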

The pre-trained ViT backbone weights were loaded into UNETR, while the decoder head is still randomly initialized so that it can adapt to the new downstream task. This flexibility also allows users to plug the pre-trained ViT backbone into their own custom network architectures.
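One common way to load such a backbone works by parameter name:

- copy every pre-trained tensor whose name also exists in UNETR's ViT;
- drop the rest, including the transposed-convolution reconstruction head;
- leave backbone-only entries at their initial values.

The framework-free sketch below shows this filtering on plain dictionaries. The helper select_backbone_weights and all key names are hypothetical; with PyTorch the dictionaries would be state dicts.

```python
from typing import Dict, List, Tuple, TypeVar

V = TypeVar("V")


def select_backbone_weights(
    pretrained: Dict[str, V], backbone: Dict[str, V]
) -> Tuple[Dict[str, V], List[str]]:
    """Overlay pre-trained entries onto the backbone's own state, matching by parameter name.

    Entries that exist only in the checkpoint (e.g. a reconstruction head) are dropped;
    entries that exist only in the backbone keep their current (initial) values.
    """
    dropped = sorted(k for k in pretrained if k not in backbone)
    merged = dict(backbone)
    merged.update({k: v for k, v in pretrained.items() if k in backbone})
    return merged, dropped


# Hypothetical key names; real ones come from the checkpoint and the model.
pretrained = {"patch_embedding.w": "pt0", "blocks.0.w": "pt1", "conv3d_transpose.w": "pt2"}
backbone = {"patch_embedding.w": "init0", "blocks.0.w": "init1", "norm.w": "init2"}
merged, dropped = select_backbone_weights(pretrained, backbone)
print(merged)   # {'patch_embedding.w': 'pt0', 'blocks.0.w': 'pt1', 'norm.w': 'init2'}
print(dropped)  # ['conv3d_transpose.w']
```

In a PyTorch setting, the pieces would be wired together as follows:

- pretrained would typically come from torch.load on the saved checkpoint;
- backbone would come from model.vit.state_dict() of a MONAI UNETR model (the vit attribute name is our assumption);
- merged would then be passed to model.vit.load_state_dict.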

References:

1.) Dosovitskiy, Alexey, et al. "An image is worth 16x16 words: Transformers for image recognition at scale." arXiv preprint arXiv:2010.11929 (2020).

2.) Hatamizadeh, Ali, et al. "UNETR: Transformers for 3D medical image segmentation." arXiv preprint arXiv:2103.10504 (2021).

3. Self-supervised Tasks

The pre-training pipeline has two aspects (see the figure below). First, it uses augmentations (top row) to mutate the data. Second, it uses a regularized contrastive loss [3] to learn feature representations of the unlabeled data. The augmentations are applied to a randomly selected 3D foreground patch from a 3D volume. Two augmented views of the same 3D patch are generated for the contrastive loss (CL). The CL pulls two views closer to each other if they were generated from the same patch; otherwise it pushes them apart, maximizing their disagreement. The CL does this across a mini-batch, so the views of the other patches in the batch act as negatives.

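[Figure: SSL pre-training pipeline, with the augmentations in the top row and the regularized contrastive loss]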
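The contrastive term follows SimCLR [3]. Below is a NumPy sketch of the NT-Xent loss for a mini-batch of N patch pairs:

- each view's positive is its partner view from the same patch;
- the remaining 2N - 2 views in the batch act as negatives.

This is an illustrative re-implementation, not the code of the training script. MONAI provides monai.losses.ContrastiveLoss for this purpose.

```python
import numpy as np


def nt_xent_loss(z1: np.ndarray, z2: np.ndarray, temperature: float = 0.5) -> float:
    """NT-Xent (normalized temperature-scaled cross-entropy) loss from SimCLR.

    z1, z2: (N, D) embeddings of the two augmented views. Row k of z1 and row k of z2
    come from the same patch (positive pair); every other row acts as a negative.
    """
    n = z1.shape[0]
    z = np.concatenate([z1, z2], axis=0)                 # (2N, D)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)     # unit vectors -> dot product = cosine
    logits = (z @ z.T) / temperature                     # (2N, 2N) scaled similarities
    np.fill_diagonal(logits, -np.inf)                    # a view is never compared with itself
    partner = np.concatenate([np.arange(n, 2 * n), np.arange(n)])  # positive index per row
    # Numerically stable log-softmax over each row.
    row_max = logits.max(axis=1, keepdims=True)
    log_norm = row_max[:, 0] + np.log(np.exp(logits - row_max).sum(axis=1))
    log_prob = logits[np.arange(2 * n), partner] - log_norm
    return float(-log_prob.mean())


# Orthonormal toy embeddings: matched views vs. views paired with the wrong patch.
z = np.eye(4, 8)
aligned = nt_xent_loss(z, z)                       # positives identical, negatives orthogonal
mismatched = nt_xent_loss(z, np.roll(z, 1, axis=0))
print(round(aligned, 3), round(mismatched, 3))     # 0.594 2.594
```

Correctly paired views give the lower loss. The tutorial's temperature setting would correspond to temperature=0.005.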

The augmentations mutate the 3D patch in various ways, and the primary task of the network is to reconstruct the original image. The augmentations used are classical techniques: in-painting [1], out-painting [1], and noise augmentation by local pixel shuffling [2]. The secondary task of the network is to make the reconstructions of the two augmented views as similar to each other as possible. It does this via the regularized contrastive loss [3], whose objective is to maximize their agreement. The term "regularized" is used here because the contrastive loss is scaled by the reconstruction loss, which acts as a dynamic weight.
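A minimal sketch of how the two terms could be combined, assuming the form total = r + c * r, where r is the L1 reconstruction loss and c is the contrastive loss. This is our reading of "dynamic weight" above, not confirmed by this README; verify it against ssl_script_train.py. The function names are hypothetical.

```python
import numpy as np


def l1_reconstruction_loss(recon: np.ndarray, original: np.ndarray) -> float:
    """Mean absolute error between the reconstructed and the original patch."""
    return float(np.mean(np.abs(recon - original)))


def regularized_total_loss(recon_loss: float, contrastive_loss: float) -> float:
    """Assumed combination: the contrastive term is weighted by the reconstruction loss.

    total = r + c * r, so the contrastive term's influence shrinks as reconstruction improves.
    """
    return recon_loss + contrastive_loss * recon_loss


original = np.zeros((4, 4, 4))
recon = np.full((4, 4, 4), 0.5)
r = l1_reconstruction_loss(recon, original)   # 0.5
print(regularized_total_loss(r, 2.0))         # 1.5
```

Under this form, a perfect reconstruction (r = 0) also switches off the contrastive term. That is what makes the weight "dynamic".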

The example image below shows the augmentation pipeline drawing two augmented views of the same 3D patch:

[Figure: original 96x96x96 patch and two augmented views of it]

Multiple axial slices of a 96x96x96 patch are shown before augmentation ('Original Patch' in the figure above). Augmented Views 1 and 2 are different augmentations generated by the transforms from the same cubic patch. The objective of the SSL network is to reconstruct the original (top-row) image from the first view. Meanwhile, the contrastive loss maximizes agreement between the reconstructions of the two augmented views. The figure was created with matshow3d from monai.visualize; a tutorial on using it can be found here.
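The three corruptions can be sketched in NumPy as below. Several details are illustrative assumptions, not the tutorial's settings:

- the number of holes and their sizes;
- the shuffle block size;
- how the corruptions are composed: local pixel shuffling, then either in-painting or out-painting.

All function names are hypothetical.

```python
import numpy as np


def inpaint(patch, rng, holes=6, min_size=4, max_size=16):
    """In-painting corruption: overwrite random sub-cubes with uniform noise."""
    out = patch.copy()
    for _ in range(holes):
        size = [int(s) for s in rng.integers(min_size, max_size + 1, size=3)]
        start = [int(rng.integers(0, dim - s + 1)) for dim, s in zip(patch.shape, size)]
        region = tuple(slice(a, a + s) for a, s in zip(start, size))
        out[region] = rng.random(tuple(size))
    return out


def outpaint(patch, rng, min_frac=0.5):
    """Out-painting corruption: keep one random inner box, replace everything else with noise."""
    out = rng.random(patch.shape).astype(patch.dtype)
    size = [int(rng.integers(int(dim * min_frac), dim + 1)) for dim in patch.shape]
    start = [int(rng.integers(0, dim - s + 1)) for dim, s in zip(patch.shape, size)]
    region = tuple(slice(a, a + s) for a, s in zip(start, size))
    out[region] = patch[region]
    return out


def local_pixel_shuffle(patch, rng, blocks=10, block=8):
    """Local pixel shuffling: randomly permute the voxels inside small random blocks."""
    out = patch.copy()
    for _ in range(blocks):
        start = [int(rng.integers(0, dim - block + 1)) for dim in patch.shape]
        region = tuple(slice(a, a + block) for a in start)
        out[region] = rng.permutation(out[region].ravel()).reshape(block, block, block)
    return out


def two_views(patch, rng):
    """Two independently corrupted views of the same patch (composition is an assumption)."""
    views = []
    for _ in range(2):
        v = local_pixel_shuffle(patch, rng)
        v = inpaint(v, rng) if rng.random() < 0.5 else outpaint(v, rng)
        views.append(v)
    return views


rng = np.random.default_rng(0)
patch = rng.random((96, 96, 96)).astype(np.float32)  # stand-in for a CT foreground patch
view1, view2 = two_views(patch, rng)
print(view1.shape, view2.shape)  # (96, 96, 96) (96, 96, 96)
```

Because both views start from the same patch but draw independent corruptions, they form a positive pair for the contrastive loss.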

References:

1.) Pathak, Deepak, et al. "Context encoders: Feature learning by inpainting." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2016.

2.) Chen, Liang, et al. "Self-supervised learning for medical image analysis using image context restoration." Medical Image Analysis 58 (2019): 101539.

3.) Chen, Ting, et al. "A simple framework for contrastive learning of visual representations." International Conference on Machine Learning. PMLR, 2020.

4. Experiment Hyper-parameters

Training Hyper-Parameters for SSL:
Epochs: 300
Validation Frequency: 2
Learning Rate: 1e-4
Batch size: 4 3D Volumes (total of 8, as 2 samples were drawn per 3D Volume)
Loss Function: L1 Loss and Contrastive Loss
Contrastive Loss Temperature: 0.005

Training Hyper-parameters for the BTCV Fine-tuning Task (all settings were kept consistent with the earlier UNETR 3D Segmentation tutorial):
Number of Steps: 30000
Validation Frequency: 100 steps
Batch Size: 1 3D Volume (4 samples are drawn per 3D volume)
Learning Rate: 1e-4
Loss Function: DiceCELoss
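The hyper-parameters above imply a few derived quantities, computed below. The sketch rests on three assumptions:

- every batch is full;
- the SSL validation frequency of 2 is measured in epochs;
- the fine-tuning run uses the 24-volume BTCV split.

The dictionary keys are illustrative names, not options of the training scripts.

```python
# Values copied from the hyper-parameter lists; key names are illustrative only.
ssl = {"epochs": 300, "train_volumes": 600, "batch_volumes": 4,
       "samples_per_volume": 2, "val_every_epochs": 2}
finetune = {"steps": 30000, "val_every_steps": 100, "train_volumes": 24,
            "batch_volumes": 1, "samples_per_volume": 4}

ssl_steps_per_epoch = ssl["train_volumes"] // ssl["batch_volumes"]          # 150
ssl_total_steps = ssl_steps_per_epoch * ssl["epochs"]                       # 45000
ssl_patches_per_step = ssl["batch_volumes"] * ssl["samples_per_volume"]     # 8
ssl_validations = ssl["epochs"] // ssl["val_every_epochs"]                  # 150

# Passes over the 24 training volumes made by 30000 single-volume steps.
ft_epochs = finetune["steps"] * finetune["batch_volumes"] / finetune["train_volumes"]  # 1250.0
ft_validations = finetune["steps"] // finetune["val_every_steps"]           # 300

print(ssl_total_steps, ssl_validations, ft_epochs, ft_validations)  # 45000 150 1250.0 300
```

The fine-tuning budget is fixed in steps rather than epochs. The 6- and 12-volume splits therefore see each volume proportionally more often than the 24-volume split does.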

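5. Training & Validation Curves for SSL Pre-training

[Figure: training and validation L1 curves for SSL pre-training]

The curves report the L1 reconstruction error on the training and validation sets during SSL pre-training. Note that the contrastive loss is not an L1 loss, so it is not what these curves measure.

6. Results of Fine-tuning vs. Random Initialization on BTCV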

Training Volumes | Validation Volumes | Random Init Dice Score (%) | Pre-trained Dice Score (%) | Relative Performance Improvement
---|---|---|---|---
6 | 6 | 63.07 | 70.09 | ~11.13%
12 | 6 | 76.06 | 79.55 | ~4.58%
24 | 6 | 78.91 | 82.30 | ~4.29%
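The last column is the relative gain of the pre-trained Dice score over the randomly initialized one:

(pre-trained - random) / random x 100

A short check using the table's values:

```python
def relative_improvement(baseline: float, improved: float) -> float:
    """Relative gain of `improved` over `baseline`, in percent."""
    return (improved - baseline) / baseline * 100.0


# (training volumes, random-init Dice, pre-trained Dice) from the table above
rows = [(6, 63.07, 70.09), (12, 76.06, 79.55), (24, 78.91, 82.30)]
for n_train, random_init, pretrained in rows:
    print(n_train, round(relative_improvement(random_init, pretrained), 2))
# 6 11.13
# 12 4.59
# 24 4.3
```

The last two rows come out at about 4.59% and 4.30%. The table's ~4.58% and ~4.29% therefore appear to be truncated rather than rounded, a negligible difference.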

Citation

@misc{das2020selfsupervised,
  title={Self-supervised learning for medical data},
  author={Arijit Das},
  howpublished={https://github.com/das-projects/selfsupervised-learning},
  year={2020}
}