The official PyTorch implementation of recent paper - SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training

Overview

This repository is the official PyTorch implementation of SAINT. Find the paper on arxiv

SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training

Overview

Requirements

We recommend using anaconda or miniconda for python. Our code has been tested with python=3.8 on linux.

To create a new environment with conda

conda create -n saint_env python=3.8
conda activate saint_env

We recommend installing the latest pytorch, torchvision, einops, pandas, wget, sklearn packages.

You can install them using

conda install pytorch torchvision -c pytorch
conda install -c conda-forge einops 
conda install -c conda-forge pandas 
conda install -c conda-forge python-wget 
conda install -c anaconda scikit-learn 

Make sure the following requirements are met

  • torch>=1.8.1
  • torchvision>=0.9.1

Optional

We used wandb to update our logs. But it is optional.

conda install -c conda-forge wandb 

Training & Evaluation

In each of our experiments, we use a single Nvidia GeForce RTX 2080Ti GPU.

First download the processed datasets from this link into the folder ./data

To train the model(s) in the paper, run this command:

python train.py  --dataset <dataset_name> --attentiontype <attention_type> 

Pretraining is useful when there are few training data samples. Sample code looks like this

python train.py  --dataset <dataset_name> --attentiontype <attention_type> --pretrain --pt_tasks <pretraining_task_touse> --pt_aug <augmentations_on_data_touse> --ssl_avail_y <Number_of_labeled_samples>

Train all 16 datasets by running bash files. train.sh for supervised learning and train_pt.sh for pretraining and semi-supervised learning

bash train.sh
bash train_pt.sh

Arguments

  • --dataset : Dataset name. We support only the 16 datasets discussed in the paper. Supported datasets are ['1995_income','bank_marketing','qsar_bio','online_shoppers','blastchar','htru2','shrutime','spambase','philippine','mnist','arcene','volkert','creditcard','arrhythmia','forest','kdd99']
  • --embedding_size : Size of the feature embeddings
  • --transformer_depth : Depth of the model. Number of stages.
  • --attention_heads : Number of attention heads in each Attention layer.
  • --cont_embeddings : Style of embedding continuous data.
  • --attentiontype : Variant of SAINT. 'col' refers to SAINT-s variant, 'row' is SAINT-i, and 'colrow' refers to SAINT.
  • --pretrain : To enable pretraining
  • --pt_tasks : Losses we want to use for pretraining. Multiple arguments can be passed.
  • --pt_aug : Types of data augmentations used in pretraining. Multiple arguments are allowed. We support only mixup and CutMix right now.
  • --ssl_avail_y : Number of labeled samples used in semi-supervised experiments. Default is 0, which means all samples are labeled and is supervised case.
  • --pt_projhead_style : Projection head style used in contrastive pipeline.
  • --nce_temp : Temperature used in contrastive loss function.
  • --active_log : To update the logs onto wandb. This is optional

Evaluation

We choose the best model by evaluating the model on validation dataset. The AUROC(for binary classification datasets) and Accuracy (for multiclass classification datasets) of the best model on test datasets is printed after training is completed. If wandb is enabled, they are logged to 'test_auroc_bestep', 'test_accuracy_bestep' variables.

Acknowledgements

We would like to thank the following public repo from which we borrowed various utilites.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Cite us

@article{somepalli2021saint,
  title={SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training},
  author={Somepalli, Gowthami and Goldblum, Micah and Schwarzschild, Avi and Bruss, C Bayan and Goldstein, Tom},
  journal={arXiv preprint arXiv:2106.01342},
  year={2021}
}

Owner
Gowthami Somepalli
Gowthami Somepalli
Improving Machine Translation Systems via Isotopic Replacement

CAT (Improving Machine Translation Systems via Isotopic Replacement) Machine translation plays an essential role in people’s daily international commu

Zeyu Sun 10 Nov 30, 2022
The AugNet Python module contains functions for the fast computation of image similarity.

AugNet AugNet: End-to-End Unsupervised Visual Representation Learning with Image Augmentation arxiv link In our work, we propose AugNet, a new deep le

Ming 74 Dec 28, 2022
Learning hierarchical attention for weakly-supervised chest X-ray abnormality localization and diagnosis

Hierarchical Attention Mining (HAM) for weakly-supervised abnormality localization This is the official PyTorch implementation for the HAM method. Pap

Xi Ouyang 22 Jan 02, 2023
Apache Spark - A unified analytics engine for large-scale data processing

Apache Spark Spark is a unified analytics engine for large-scale data processing. It provides high-level APIs in Scala, Java, Python, and R, and an op

The Apache Software Foundation 34.7k Jan 04, 2023
The Python ensemble sampling toolkit for affine-invariant MCMC

emcee The Python ensemble sampling toolkit for affine-invariant MCMC emcee is a stable, well tested Python implementation of the affine-invariant ense

Dan Foreman-Mackey 1.3k Dec 31, 2022
Python Actor concurrency library

Thespian Actor Library This library provides the framework of an Actor model for use by applications implementing Actors. Thespian Site with Documenta

Kevin Quick 177 Dec 11, 2022
Code, environments, and scripts for the paper: "How Private Is Your RL Policy? An Inverse RL Based Analysis Framework"

Privacy-Aware Inverse RL (PRIL) Analysis Framework Code, environments, and scripts for the paper: "How Private Is Your RL Policy? An Inverse RL Based

1 Dec 06, 2021
Deep learning library for solving differential equations and more

DeepXDE Voting on whether we should have a Slack channel for discussion. DeepXDE is a library for scientific machine learning. Use DeepXDE if you need

Lu Lu 1.4k Dec 29, 2022
[CVPR 2021 Oral] ForgeryNet: A Versatile Benchmark for Comprehensive Forgery Analysis

ForgeryNet: A Versatile Benchmark for Comprehensive Forgery Analysis ForgeryNet: A Versatile Benchmark for Comprehensive Forgery Analysis [arxiv|pdf|v

Yinan He 78 Dec 22, 2022
Semi-Supervised Semantic Segmentation via Adaptive Equalization Learning, NeurIPS 2021 (Spotlight)

Semi-Supervised Semantic Segmentation via Adaptive Equalization Learning, NeurIPS 2021 (Spotlight) Abstract Due to the limited and even imbalanced dat

Hanzhe Hu 99 Dec 12, 2022
Self-Learned Video Rain Streak Removal: When Cyclic Consistency Meets Temporal Correspondence

In this paper, we address the problem of rain streaks removal in video by developing a self-learned rain streak removal method, which does not require any clean groundtruth images in the training pro

Yang Wenhan 44 Dec 06, 2022
The AWS Certified SysOps Administrator

The AWS Certified SysOps Administrator – Associate (SOA-C02) exam is intended for system administrators in a cloud operations role who have at least 1 year of hands-on experience with deployment, man

Aiden Pearce 32 Dec 11, 2022
Simple PyTorch hierarchical models.

A python package adding basic hierarchal networks in pytorch for classification tasks. It implements a simple hierarchal network structure based on feed-backward outputs.

Rajiv Sarvepalli 5 Mar 06, 2022
PyTorch implementation of CloudWalk's recent work DenseBody

densebody_pytorch PyTorch implementation of CloudWalk's recent paper DenseBody. Note: For most recent updates, please check out the dev branch. Update

Lingbo Yang 401 Nov 19, 2022
transfer attack; adversarial examples; black-box attack; unrestricted Adversarial Attacks on ImageNet; CVPR2021 天池黑盒竞赛

transfer_adv CVPR-2021 AIC-VI: unrestricted Adversarial Attacks on ImageNet CVPR2021 安全AI挑战者计划第六期赛道2:ImageNet无限制对抗攻击 介绍 : 深度神经网络已经在各种视觉识别问题上取得了最先进的性能。

25 Dec 08, 2022
Improving Compound Activity Classification via Deep Transfer and Representation Learning

Improving Compound Activity Classification via Deep Transfer and Representation Learning This repository is the official implementation of Improving C

NingLab 2 Nov 24, 2021
Continuous Security Group Rule Change Detection & Response at scale

Introduction Get notified of Security Group Changes across all AWS Accounts & Regions in an AWS Organization, with the ability to respond/revert those

Raajhesh Kannaa Chidambaram 3 Aug 13, 2022
ICRA 2021 "Towards Precise and Efficient Image Guided Depth Completion"

PENet: Precise and Efficient Depth Completion This repo is the PyTorch implementation of our paper to appear in ICRA2021 on "Towards Precise and Effic

232 Dec 25, 2022
A unified framework for machine learning with time series

Welcome to sktime A unified framework for machine learning with time series We provide specialized time series algorithms and scikit-learn compatible

The Alan Turing Institute 6k Jan 08, 2023
Algorithmic Trading using RNN

Deep-Trading This an implementation adapted from Rachnog Neural networks for algorithmic trading. Part One — Simple time series forecasting and this c

Hazem Nomer 29 Sep 04, 2022