Cross-Task Consistency Learning Framework for Multi-Task Learning

Tested on

  • numpy(v1.19.1)
  • opencv-python(v4.4.0.42)
  • torch(v1.7.0)
  • torchvision(v0.8.0)
  • tqdm(v4.48.2)
  • matplotlib(v3.3.1)
  • seaborn(v0.11.0)
  • pandas(v1.1.2)

Data

Cityscapes (CS)

Download the Cityscapes dataset and put it in a subdirectory named ./data/cityscapes. The folder should contain the following subfolders:

  • RGB images in folder leftImg8bit
  • Segmentation labels in folder gtFine
  • Disparity maps in folder disparity
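
Before training, it can help to confirm that the expected folder layout is in place. The following is a minimal sketch using only the standard library; the helper name `missing_cityscapes_subfolders` is hypothetical and not part of this repository, and only the three subfolder names come from the list above.

```python
from pathlib import Path

# Subfolder names are taken from the list above; everything else is illustrative.
REQUIRED_SUBFOLDERS = ("leftImg8bit", "gtFine", "disparity")


def missing_cityscapes_subfolders(root="./data/cityscapes"):
    """Return the required subfolders that do not exist under `root`."""
    root = Path(root)
    return [name for name in REQUIRED_SUBFOLDERS if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_cityscapes_subfolders()
    if missing:
        print("Missing subfolders:", ", ".join(missing))
    else:
        print("Cityscapes folder layout looks complete.")
```

An empty list means all three subfolders were found.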

NYU

We use the preprocessed NYUv2 dataset provided by this repo. Download the dataset and put it in the folder ./data/nyu.

Model

The model consists of one encoder (ResNet) and two decoders, one for each task. The decoders output the predictions for each task ("direct predictions"), which are fed to the TaskTransferNet.
The objective of the TaskTransferNet is to predict the other task's output given a prediction as input (segmentation prediction -> depth prediction, and vice versa); I refer to these outputs as "transferred predictions".
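
The data flow described above can be sketched in a framework-agnostic way. This is only an illustration of the order in which the pieces are called, not the repository's implementation: all six callables are hypothetical stand-ins, and having one transfer function per direction is an assumption.

```python
def forward_pass(image, encoder, seg_decoder, depth_decoder,
                 seg_to_depth, depth_to_seg):
    """Return direct and transferred predictions for both tasks.

    Every callable argument is a hypothetical stand-in for a network.
    """
    features = encoder(image)                # shared encoder features
    seg_direct = seg_decoder(features)       # direct segmentation prediction
    depth_direct = depth_decoder(features)   # direct depth prediction

    # Each transfer function maps one task's direct prediction to the other task.
    depth_transferred = seg_to_depth(seg_direct)
    seg_transferred = depth_to_seg(depth_direct)

    return {
        "seg_direct": seg_direct,
        "depth_direct": depth_direct,
        "seg_transferred": seg_transferred,
        "depth_transferred": depth_transferred,
    }
```

Note that the transferred depth prediction is derived from the direct segmentation prediction, and the transferred segmentation prediction from the direct depth prediction.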

Loss function

When computing the losses, the direct predictions are compared with the target while the transferred predictions are compared with the direct predictions so that they "align themselves".
The total loss consists of 4 different losses:

  • direct segmentation loss: CrossEntropyLoss()
  • direct depth loss: L1() or MSE() or logL1() or SmoothL1()
  • transferred segmentation loss: CrossEntropyLoss() or KLDivergence()
  • transferred depth loss: L1() or SSIM()

* Label smoothing: "smooths" the one-hot target distribution by taking some of the probability from the correct class and distributing it among the other classes.
* SSIM: Structural Similarity Loss
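
As a rough illustration of the two ideas above, the sketch below shows one common form of label smoothing and a plain weighted sum of the four losses. Both function names are hypothetical. The smoothing variant (spreading the removed mass evenly over the other classes) and the simple weighted sum are assumptions rather than the repository's exact code; in particular, enabling uncertainty weights or GradNorm would change how the terms are combined.

```python
def smooth_one_hot(target_class, num_classes, smoothing=0.0):
    """Return a smoothed target distribution as a list of probabilities.

    Assumption: the removed mass `smoothing` is split evenly among the
    other classes, so the correct class keeps 1 - smoothing.
    """
    if num_classes < 2:
        raise ValueError("label smoothing needs at least two classes")
    off_value = smoothing / (num_classes - 1)
    probs = [off_value] * num_classes
    probs[target_class] = 1.0 - smoothing
    return probs


def total_loss(direct_seg, direct_depth, transferred_seg, transferred_depth,
               alpha=0.01, gamma=0.01):
    """Combine the four scalar losses as a weighted sum (an assumption).

    alpha weights the transferred depth loss and gamma weights the
    transferred segmentation loss, following the flag descriptions.
    """
    return (direct_seg + direct_depth
            + alpha * transferred_depth
            + gamma * transferred_seg)
```

With `smoothing=0.0` the target stays one-hot, which matches the default of the `label_smoothing` flag.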

Flags

The flags are the same for both datasets. The flags and their usage are listed below.

Flag Name | Usage | Comments
--- | --- | ---
input_path | Path to dataset | default is data/cityscapes (CS) or data/nyu (NYU)
height | height of prediction | default: 128 (CS) or 288 (NYU)
width | width of prediction | default: 256 (CS) or 384 (NYU)
epochs | # of epochs | default: 250 (CS) or 100 (NYU)
enc_layers | which encoder to use | default: 34, can choose from 18, 34, 50, 101, 152
use_pretrain | toggle on to use pretrained encoder weights | available for both datasets
batch_size | batch size | default: 8 (CS) or 6 (NYU)
scheduler_step_size | step size for scheduler | default: 80 (CS) or 60 (NYU), note that we use StepLR
scheduler_gamma | decay rate of scheduler | default: 0.5
alpha | weight of adding transferred depth loss | default: 0.01 (CS) or 0.0001 (NYU)
gamma | weight of adding transferred segmentation loss | default: 0.01 (CS) or 0.0001 (NYU)
label_smoothing | amount of label smoothing | default: 0.0
lp | loss fn for direct depth loss | default: L1, can choose from L1, MSE, logL1, smoothL1
tdep_loss | loss fn for transferred depth loss | default: L1, can choose from L1 or SSIM
tseg_loss | loss fn for transferred segmentation loss | default: cross, can choose from cross or kl
batch_norm | toggle to enable batch normalization layer in TaskTransferNet | slightly improves segmentation task
wider_ttnet | toggle to double the # of channels in TaskTransferNet |
uncertainty_weights | toggle to use uncertainty weights (Kendall, et al. 2018) | we used this for best results
gradnorm | toggle to use GradNorm (Chen, et al. 2018) |
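
For orientation, the flags above could be declared with `argparse` roughly as follows. This is a hypothetical reconstruction using the Cityscapes defaults from the table, not the parser shipped in the training scripts; the argument types and the use of `store_true` for the toggle flags are assumptions.

```python
import argparse


def build_parser():
    """Build a parser mirroring the flag table (Cityscapes defaults)."""
    p = argparse.ArgumentParser(description="Cross-task consistency training (sketch)")
    p.add_argument("--input_path", type=str, default="data/cityscapes")
    p.add_argument("--height", type=int, default=128)
    p.add_argument("--width", type=int, default=256)
    p.add_argument("--epochs", type=int, default=250)
    p.add_argument("--enc_layers", type=int, default=34,
                   choices=[18, 34, 50, 101, 152])
    p.add_argument("--batch_size", type=int, default=8)
    p.add_argument("--scheduler_step_size", type=int, default=80)
    p.add_argument("--scheduler_gamma", type=float, default=0.5)
    p.add_argument("--alpha", type=float, default=0.01)
    p.add_argument("--gamma", type=float, default=0.01)
    p.add_argument("--label_smoothing", type=float, default=0.0)
    p.add_argument("--lp", type=str, default="L1",
                   choices=["L1", "MSE", "logL1", "smoothL1"])
    p.add_argument("--tdep_loss", type=str, default="L1", choices=["L1", "SSIM"])
    p.add_argument("--tseg_loss", type=str, default="cross", choices=["cross", "kl"])
    # Toggle flags: assumed to be off unless passed on the command line.
    for toggle in ("use_pretrain", "batch_norm", "wider_ttnet",
                   "uncertainty_weights", "gradnorm"):
        p.add_argument("--" + toggle, action="store_true")
    return p


if __name__ == "__main__":
    print(build_parser().parse_args())
```

Parsing an empty argument list yields the Cityscapes defaults, and passing `--uncertainty_weights` switches that toggle on.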

Training

Cityscapes

For the Cityscapes dataset, there are two versions of the segmentation task: a 7-class task and a 19-class task (use the flag 'num_classes' to switch between them; the default is 7).
So far, the results are near state-of-the-art for the 7-class segmentation task combined with depth estimation.

ResNet34 was used as the encoder, L1() for direct depth loss and CrossEntropyLoss() for transferred segmentation loss.
The hyperparameter weights for both transferred predictions were 0.01.
I used Adam as my optimizer with an initial learning rate of 0.0001 and trained for 250 epochs with batch size 8. The learning rate was halved every 80 epochs.
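
The step schedule described above has a simple closed form: the learning rate is multiplied by the decay factor once per completed step. The sketch below reproduces that arithmetic in plain Python; the function name `step_lr` is hypothetical, and the training scripts themselves rely on a StepLR scheduler rather than this helper.

```python
def step_lr(epoch, initial_lr=1e-4, step_size=80, gamma=0.5):
    """Learning rate at a 0-indexed epoch under a step decay schedule."""
    return initial_lr * gamma ** (epoch // step_size)


if __name__ == "__main__":
    # Print the learning rate at the start of each decay interval.
    for epoch in (0, 80, 160, 240):
        print(epoch, step_lr(epoch))
```

With the Cityscapes settings, the rate is halved three times over 250 epochs (at epochs 80, 160, and 240). Passing `step_size=60` gives the schedule for the NYU settings.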

To reproduce the results, run:

python main_cross_cs.py --uncertainty_weights

NYU

Our results show state-of-the-art performance on the NYU dataset.

ResNet34 was used as the encoder, L1() for direct depth loss and CrossEntropyLoss() for transferred segmentation loss.
The hyperparameter weights for both transferred predictions were 0.0001.
I used Adam as my optimizer with an initial learning rate of 0.0001 and trained for 100 epochs with batch size 6. The learning rate was halved every 60 epochs.

To reproduce the results, run:

python main_cross_nyu.py --uncertainty_weights

Comparisons

Evaluation metrics are the following:

Segmentation

  • Pixel accuracy (Pix Acc): percentage of pixels with the correct label
  • mIoU: mean Intersection over Union

Depth

  • Absolute Error (Abs)
  • Absolute Relative Error (Abs Rel): Absolute error divided by ground truth depth
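
The four metrics can be written down compactly over flattened lists of pixels. This is a minimal sketch rather than the evaluation code used for the tables: the function names are hypothetical, and it assumes no ignored labels, strictly positive ground-truth depth, and averaging over all pixels.

```python
def pixel_accuracy(pred, target):
    """Fraction of pixels whose predicted label matches the target."""
    correct = sum(p == t for p, t in zip(pred, target))
    return correct / len(target)


def mean_iou(pred, target, num_classes):
    """Mean Intersection over Union across classes present in pred or target."""
    ious = []
    for c in range(num_classes):
        inter = sum(p == c and t == c for p, t in zip(pred, target))
        union = sum(p == c or t == c for p, t in zip(pred, target))
        if union > 0:  # skip classes absent from both prediction and target
            ious.append(inter / union)
    return sum(ious) / len(ious) if ious else 0.0


def abs_error(pred, target):
    """Mean absolute difference between predicted and ground-truth depth."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(target)


def abs_rel_error(pred, target):
    """Mean absolute error divided by the ground-truth depth."""
    return sum(abs(p - t) / t for p, t in zip(pred, target)) / len(target)
```

Real evaluation code typically also masks out invalid depth values and ignored segmentation labels, which this sketch leaves out.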

The results are the following:

Cityscapes

Models | mIoU | Pix Acc | Abs | Abs Rel
--- | --- | --- | --- | ---
MTAN | 53.04 | 91.11 | 0.0144 | 33.63
KD4MTL | 52.71 | 91.54 | 0.0139 | 27.33
PCGrad | 53.59 | 91.45 | 0.0171 | 31.34
AdaMT-Net | 62.53 | 94.16 | 0.0125 | 22.23
Ours | 66.51 | 93.56 | 0.0122 | 19.40

NYU

Models | mIoU | Pix Acc | Abs | Abs Rel
--- | --- | --- | --- | ---
MTAN* | 21.07 | 55.70 | 0.6035 | 0.2472
MTAN† | 20.10 | 53.73 | 0.6417 | 0.2758
KD4MTL* | 20.75 | 57.90 | 0.5816 | 0.2445
KD4MTL† | 22.44 | 57.32 | 0.6003 | 0.2601
PCGrad* | 20.17 | 56.65 | 0.5904 | 0.2467
PCGrad† | 21.29 | 54.07 | 0.6705 | 0.3000
AdaMT-Net* | 21.86 | 60.35 | 0.5933 | 0.2456
AdaMT-Net† | 20.61 | 58.91 | 0.6136 | 0.2547
Ours† | 30.31 | 63.02 | 0.5954 | 0.2235

*: Trained on 3 tasks (segmentation, depth, and surface normal)
†: Trained on 2 tasks (segmentation and depth)
Italic: Reproduced by ourselves

Scores with models trained on 3 tasks for NYU dataset are shown only as reference.

Referenced papers

MTAN: [paper][github]
KD4MTL: [paper][github]
PCGrad: [paper][github (tensorflow)][github (pytorch)]
AdaMT-Net: [paper]

Owner
Aki Nakano
Student at the University of Tokyo pursuing master's degree. Joined UC Berkeley Summer Session 2019. Researching deep learning. Python/R