Code for "Training Neural Networks with Fixed Sparse Masks" (NeurIPS 2021).

Related tags

Deep LearningFISH
Overview

Fisher Induced Sparse uncHanging (FISH) Mask

This repo contains the code for Fisher Induced Sparse uncHanging (FISH) Mask training, from "Training Neural Networks with Fixed Sparse Masks" by Yi-Lin Sung, Varun Nair, and Colin Raffel. To appear in Neural Information Processing Systems (NeurIPS) 2021.

Abstract: During typical gradient-based training of deep neural networks, all of the model's parameters are updated at each iteration. Recent work has shown that it is possible to update only a small subset of the model's parameters during training, which can alleviate storage and communication requirements. In this paper, we show that it is possible to induce a fixed sparse mask on the model’s parameters that selects a subset to update over many iterations. Our method constructs the mask out of the parameters with the largest Fisher information as a simple approximation as to which parameters are most important for the task at hand. In experiments on parameter-efficient transfer learning and distributed training, we show that our approach matches or exceeds the performance of other methods for training with sparse updates while being more efficient in terms of memory usage and communication costs.

Setup

pip install transformers/.
pip install datasets torch==1.8.0 tqdm torchvision==0.9.0

FISH Mask: GLUE Experiments

Parameter-Efficient Transfer Learning

To run the FISH Mask on a GLUE dataset, code can be run with the following format:

$ bash transformers/examples/text-classification/scripts/run_sparse_updates.sh <dataset-name> <seed> <top_k_percentage> <num_samples_for_fisher>

An example command used to generate Table 1 in the paper is as follows, where all GLUE tasks are provided at a seed of 0 and a FISH mask sparsity of 0.5%.

$ bash transformers/examples/text-classification/scripts/run_sparse_updates.sh "qqp mnli rte cola stsb sst2 mrpc qnli" 0 0.005 1024

Distributed Training

To use the FISH mask on the GLUE tasks in a distributed setting, one can use the following command.

$ bash transformers/examples/text-classification/scripts/distributed_training.sh <dataset-name> <seed> <num_workers> <training_epochs> <gpu_id>

Note the <dataset-name> here can only contain one task, so an example command could be

$ bash transformers/examples/text-classification/scripts/distributed_training.sh "mnli" 0 2 3.5 0

FISH Mask: CIFAR10 Experiments

To run the FISH mask on CIFAR10, code can be run with the following format:

Distributed Training

$ bash cifar10-fast/scripts/distributed_training_fish.sh <num_samples_for_fisher> <top_k_percentage> <training_epochs> <worker_updates> <learning_rate> <num_workers>

For example, in the paper, we compute the FISH mask of the 0.5% sparsity level by 256 samples and distribute the job to 2 workers for a total of 50 epochs training. Then the command would be

$ bash cifar10-fast/scripts/distributed_training_fish.sh 256 0.005 50 2 0.4 2

Efficient Checkpointing

$ bash cifar10-fast/scripts/small_checkpoints_fish.sh <num_samples_for_fisher> <top_k_percentage> <training_epochs> <learning_rate> <fix_mask>

The hyperparameters are almost the same as distributed training. However, the <fix_mask> is to indicate to fix the mask or not, and a valid input is either 0 or 1 (1 means to fix the mask).

Replicating Results

Replicating each of the tables and figures present in the original paper can be done by running the following:

# Table 1 - Parameter Efficient Fine-Tuning on GLUE

$ bash transformers/examples/text-classification/scripts/run_table_1.sh
# Figure 2 - Mask Sparsity Ablation and Sample Ablation

$ bash transformers/examples/text-classification/scripts/run_figure_2.sh
# Table 2 - Distributed Training on GLUE

$ bash transformers/examples/text-classification/scripts/run_table_2.sh
# Table 3 - Distributed Training on CIFAR10

$ bash cifar10-fast/scripts/distributed_training.sh

# Table 4 - Efficient Checkpointing

$ bash cifar10-fast/scripts/small_checkpoints.sh

Notes

  • For reproduction of Diff Pruning results from Table 1, see code here.

Acknowledgements

We thank Yoon Kim, Michael Matena, and Demi Guo for helpful discussions.

Owner
Varun Nair
Hi! I'm a student at Duke University studying CS. I'm interested in researching AI/ML and its applications in medicine, transportation, & education.
Varun Nair
In this tutorial, you will perform inference across 10 well-known pre-trained object detectors and fine-tune on a custom dataset. Design and train your own object detector.

Object Detection Object detection is a computer vision task for locating instances of predefined objects in images or videos. In this tutorial, you wi

Ibrahim Sobh 62 Dec 25, 2022
A Blender python script for getting asset browser custom preview images for objects and collections.

asset_snapshot A Blender python script for getting asset browser custom preview images for objects and collections. Installation: Click the code butto

Johnny Matthews 44 Nov 29, 2022
Multi-robot collaborative exploration and mapping through Voronoi partition and DRL in unknown environment

Voronoi Multi_Robot Collaborate Exploration Introduction In the unknown environment, the cooperative exploration of multiple robots is completed by Vo

PeaceWord 6 Nov 22, 2022
Submodular Subset Selection for Active Domain Adaptation (ICCV 2021)

S3VAADA: Submodular Subset Selection for Virtual Adversarial Active Domain Adaptation ICCV 2021 Harsh Rangwani, Arihant Jain*, Sumukh K Aithal*, R. Ve

Video Analytics Lab -- IISc 13 Dec 28, 2022
Generate images from texts. In Russian

ruDALL-E Generate images from texts pip install rudalle==1.1.0rc0 🤗 HF Models: ruDALL-E Malevich (XL) ruDALL-E Emojich (XL) (readme here) ruDALL-E S

AI Forever 1.6k Dec 31, 2022
Automated Melanoma Recognition in Dermoscopy Images via Very Deep Residual Networks

Introduction This repository contains the modified caffe library and network architectures for our paper "Automated Melanoma Recognition in Dermoscopy

Lequan Yu 47 Nov 24, 2022
Retinal vessel segmentation based on GT-UNet

Retinal vessel segmentation based on GT-UNet Introduction This project is a retinal blood vessel segmentation code based on UNet-like Group Transforme

Kent0n 27 Dec 18, 2022
Inference pipeline for our participation in the FeTA challenge 2021.

feta-inference Inference pipeline for our participation in the FeTA challenge 2021. Team name: TRABIT Installation Download the two folders in https:/

Lucas Fidon 2 Apr 13, 2022
DCSAU-Net: A Deeper and More Compact Split-Attention U-Net for Medical Image Segmentation

DCSAU-Net: A Deeper and More Compact Split-Attention U-Net for Medical Image Segmentation By Qing Xu, Wenting Duan and Na He Requirements pytorch==1.1

Qing Xu 20 Dec 09, 2022
the official code for ICRA 2021 Paper: "Multimodal Scale Consistency and Awareness for Monocular Self-Supervised Depth Estimation"

G2S This is the official code for ICRA 2021 Paper: Multimodal Scale Consistency and Awareness for Monocular Self-Supervised Depth Estimation by Hemang

NeurAI 4 Jul 27, 2022
TensorFlow (v2.7.0) benchmark results on an M1 Macbook Air 2020 laptop (macOS Monterey v12.1).

M1-tensorflow-benchmark TensorFlow (v2.7.0) benchmark results on an M1 Macbook Air 2020 laptop (macOS Monterey v12.1). I was initially testing if Tens

particle 2 Jan 05, 2022
Another pytorch implementation of FCN (Fully Convolutional Networks)

FCN-pytorch-easiest Trying to be the easiest FCN pytorch implementation and just in a get and use fashion Here I use a handbag semantic segmentation f

Y. Dong 158 Dec 21, 2022
This Repo is the official CUDA implementation of ICCV 2019 Oral paper for CARAFE: Content-Aware ReAssembly of FEatures

Introduction This Repo is the official CUDA implementation of ICCV 2019 Oral paper for CARAFE: Content-Aware ReAssembly of FEatures. @inproceedings{Wa

Jiaqi Wang 42 Jan 07, 2023
Probabilistic Programming and Statistical Inference in PyTorch

PtStat Probabilistic Programming and Statistical Inference in PyTorch. Introduction This project is being developed during my time at Cogent Labs. The

Stefano Peluchetti 109 Nov 26, 2022
Information Gain Filtration (IGF) is a method for filtering domain-specific data during language model finetuning. IGF shows significant improvements over baseline fine-tuning without data filtration.

Information Gain Filtration Information Gain Filtration (IGF) is a method for filtering domain-specific data during language model finetuning. IGF sho

4 Jul 28, 2022
PyTorch implementation of our ICCV 2021 paper, Interpretation of Emergent Communication in Heterogeneous Collaborative Embodied Agents.

PyTorch implementation of our ICCV 2021 paper, Interpretation of Emergent Communication in Heterogeneous Collaborative Embodied Agents.

Saim Wani 4 May 08, 2022
Python implementation of a live deep learning based age/gender/expression recognizer

TUT live age estimator Python implementation of a live deep learning based age/gender/smile/celebrity twin recognizer. All components use convolutiona

Heikki Huttunen 80 Nov 21, 2022
TinyML Cookbook, published by Packt

TinyML Cookbook This is the code repository for TinyML Cookbook, published by Packt. Author: Gian Marco Iodice Publisher: Packt About the book This bo

Packt 93 Dec 29, 2022
StackNet is a computational, scalable and analytical Meta modelling framework

StackNet This repository contains StackNet Meta modelling methodology (and software) which is part of my work as a PhD Student in the computer science

Marios Michailidis 1.3k Dec 15, 2022
automatic color-grading

color-matcher Description color-matcher enables color transfer across images which comes in handy for automatic color-grading of photographs, painting

hahnec 168 Jan 05, 2023