Code for "Training Neural Networks with Fixed Sparse Masks" (NeurIPS 2021).

FISH
Overview

Fisher Induced Sparse uncHanging (FISH) Mask

This repo contains the code for Fisher Induced Sparse uncHanging (FISH) Mask training, from "Training Neural Networks with Fixed Sparse Masks" by Yi-Lin Sung, Varun Nair, and Colin Raffel. To appear in Neural Information Processing Systems (NeurIPS) 2021.

Abstract: During typical gradient-based training of deep neural networks, all of the model's parameters are updated at each iteration. Recent work has shown that it is possible to update only a small subset of the model's parameters during training, which can alleviate storage and communication requirements. In this paper, we show that it is possible to induce a fixed sparse mask on the model’s parameters that selects a subset to update over many iterations. Our method constructs the mask out of the parameters with the largest Fisher information as a simple approximation as to which parameters are most important for the task at hand. In experiments on parameter-efficient transfer learning and distributed training, we show that our approach matches or exceeds the performance of other methods for training with sparse updates while being more efficient in terms of memory usage and communication costs.
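The idea in the abstract (rank parameters by Fisher information, keep the top fraction, and update only those) can be illustrated on a toy model. The following is a minimal sketch, not the repository's implementation. It approximates each parameter's diagonal Fisher information by the average squared gradient of the log-likelihood over labeled samples (the "empirical Fisher"), then keeps the highest-scoring fraction of all parameters. The logistic-regression model and the names `diagonal_empirical_fisher` and `fish_mask` are hypothetical, chosen only for illustration.

```python
import math
from typing import List, Sequence, Tuple

Sample = Tuple[Sequence[float], int]  # (features, label in {0, 1})


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def diagonal_empirical_fisher(weights: Sequence[float], samples: List[Sample]) -> List[float]:
    """Average squared per-sample gradient of log p(y | x) for a toy logistic-regression model."""
    fisher = [0.0] * len(weights)
    for x, y in samples:
        p = sigmoid(sum(w * xi for w, xi in zip(weights, x)))
        # For logistic regression, d/dw_j log p(y | x) = (y - p) * x_j.
        for j, xj in enumerate(x):
            grad = (y - p) * xj
            fisher[j] += grad * grad
    return [f / len(samples) for f in fisher]


def fish_mask(fisher: Sequence[float], top_k_percentage: float) -> List[bool]:
    """Mark the parameters with the largest Fisher values as trainable (True)."""
    k = max(1, round(top_k_percentage * len(fisher)))
    ranked = sorted(range(len(fisher)), key=lambda i: fisher[i], reverse=True)
    keep = set(ranked[:k])
    return [i in keep for i in range(len(fisher))]


# Usage: at zero weights every prediction is 0.5, so each parameter's Fisher
# value is proportional to the mean squared value of its input feature.
samples = [([1.0, 0.0, 2.0, 0.1], 1), ([1.0, 0.0, -2.0, 0.1], 0)]
fisher = diagonal_empirical_fisher([0.0] * 4, samples)
mask = fish_mask(fisher, 0.5)
print(mask)  # [True, False, True, False]
```

In a training loop built on this sketch, updates would then be applied only where the mask is True, so the same subset of parameters is updated at every iteration.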

Setup

pip install transformers/.
pip install datasets torch==1.8.0 tqdm torchvision==0.9.0

FISH Mask: GLUE Experiments

Parameter-Efficient Transfer Learning

To run FISH Mask training on a GLUE dataset, use the following format:

$ bash transformers/examples/text-classification/scripts/run_sparse_updates.sh <dataset-name> <seed> <top_k_percentage> <num_samples_for_fisher>

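The following command, used to generate Table 1 in the paper, runs eight GLUE tasks (passed as one quoted, space-separated string) with seed 0, a FISH mask sparsity of 0.5% (<top_k_percentage> = 0.005), and 1024 samples for computing the Fisher information.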

$ bash transformers/examples/text-classification/scripts/run_sparse_updates.sh "qqp mnli rte cola stsb sst2 mrpc qnli" 0 0.005 1024
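Note that <top_k_percentage> is given as a fraction, so 0.005 corresponds to the 0.5% sparsity level above. The hypothetical helper below is a small sketch of that arithmetic; the 100-million parameter count is an arbitrary illustrative number, not the size of any model used in the paper.

```python
def masked_parameter_count(total_params: int, top_k_percentage: float) -> int:
    """Number of parameters a FISH mask of the given sparsity keeps trainable.

    top_k_percentage is a fraction of all parameters: 0.005 means 0.5%.
    """
    if total_params <= 0:
        raise ValueError("total_params must be positive")
    if not 0.0 < top_k_percentage <= 1.0:
        raise ValueError("top_k_percentage must be a fraction in (0, 1]")
    # round() absorbs floating-point error (0.005 has no exact binary representation);
    # at least one parameter is always kept.
    return max(1, round(total_params * top_k_percentage))


# Hypothetical model size, for illustration only.
print(masked_parameter_count(100_000_000, 0.005))  # 500000
```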

Distributed Training

To train with the FISH mask on GLUE in a distributed setting, use the following format:

$ bash transformers/examples/text-classification/scripts/distributed_training.sh <dataset-name> <seed> <num_workers> <training_epochs> <gpu_id>

Note that <dataset-name> can contain only one task here. For example, the following command trains on MNLI with seed 0, 2 workers, and 3.5 training epochs on GPU 0:

$ bash transformers/examples/text-classification/scripts/distributed_training.sh "mnli" 0 2 3.5 0
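Since only the parameters selected by the mask are updated, only those values need to be exchanged between workers. The following is a minimal sketch assuming a simple parameter-averaging scheme with a fixed mask; the function names are hypothetical, and distributed_training.sh may synchronize workers differently.

```python
from typing import List, Sequence


def masked_payload(params: Sequence[float], mask: Sequence[bool]) -> List[float]:
    """Values a worker sends: only the masked entries, in a fixed index order."""
    return [p for p, keep in zip(params, mask) if keep]


def average_masked(global_params: Sequence[float], payloads: List[List[float]],
                   mask: Sequence[bool]) -> List[float]:
    """Average the workers' masked values; unmasked parameters stay unchanged."""
    averaged = [sum(values) / len(payloads) for values in zip(*payloads)]
    new_params = list(global_params)
    masked_positions = [i for i, keep in enumerate(mask) if keep]
    for i, value in zip(masked_positions, averaged):
        new_params[i] = value
    return new_params


# Usage with two hypothetical workers that started from the same global parameters.
mask = [True, False, True, False]
global_params = [1.0, 2.0, 3.0, 4.0]
worker_params = [[2.0, 2.0, 3.5, 4.0], [1.0, 2.0, 2.5, 4.0]]
payloads = [masked_payload(w, mask) for w in worker_params]
print(payloads)  # [[2.0, 3.5], [1.0, 2.5]]
print(average_masked(global_params, payloads, mask))  # [1.5, 2.0, 3.0, 4.0]
```

Because the mask is fixed, every worker already knows which position each value belongs to, so in this sketch the indices never have to be sent along with the values.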

FISH Mask: CIFAR10 Experiments

Distributed Training

To run distributed training with the FISH mask on CIFAR10, use the following format:

$ bash cifar10-fast/scripts/distributed_training_fish.sh <num_samples_for_fisher> <top_k_percentage> <training_epochs> <worker_updates> <learning_rate> <num_workers>

For example, in the paper we compute a FISH mask at the 0.5% sparsity level from 256 samples and distribute training across 2 workers for a total of 50 epochs, with <worker_updates> set to 2 and a learning rate of 0.4. The command is then:

$ bash cifar10-fast/scripts/distributed_training_fish.sh 256 0.005 50 2 0.4 2
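The positional arguments are easy to mix up, since 2 appears twice in the example command. The sketch below simply pairs each placeholder from the usage line with its value from the example; it restates the documented argument order and makes no assumption about how the script uses each value.

```python
import shlex

# Argument order from the usage line of distributed_training_fish.sh.
ARG_NAMES = [
    "num_samples_for_fisher",
    "top_k_percentage",
    "training_epochs",
    "worker_updates",
    "learning_rate",
    "num_workers",
]

example = "256 0.005 50 2 0.4 2"
values = shlex.split(example)
if len(values) != len(ARG_NAMES):
    raise ValueError(f"expected {len(ARG_NAMES)} arguments, got {len(values)}")

config = dict(zip(ARG_NAMES, values))
for name, value in config.items():
    print(f"{name} = {value}")
```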

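Efficient Checkpointing

To run the efficient checkpointing experiment with the FISH mask on CIFAR10, use the following format:

$ bash cifar10-fast/scripts/small_checkpoints_fish.sh <num_samples_for_fisher> <top_k_percentage> <training_epochs> <learning_rate> <fix_mask>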

The arguments match those of the distributed training script, except that <worker_updates> and <num_workers> are omitted and <fix_mask> is added. <fix_mask> indicates whether to fix the mask: valid values are 0 and 1, where 1 fixes the mask.
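With sparse updates, parameters outside the mask never change, so a checkpoint only needs the masked values plus the positions they belong to. The minimal sketch below assumes one shared copy of the initial parameters is stored separately. With a fixed mask (<fix_mask> = 1) the indices are saved once; a mask that changes during training would presumably need its indices saved with every checkpoint. The helper names and the storage accounting are illustrative assumptions, not the format written by small_checkpoints_fish.sh.

```python
from typing import List, Sequence


def mask_indices(mask: Sequence[bool]) -> List[int]:
    return [i for i, keep in enumerate(mask) if keep]


def save_sparse(params: Sequence[float], indices: Sequence[int]) -> List[float]:
    """A compact checkpoint: only the values at the masked positions."""
    return [params[i] for i in indices]


def load_sparse(initial: Sequence[float], indices: Sequence[int],
                values: Sequence[float]) -> List[float]:
    """Rebuild full parameters from the shared initial copy and a compact checkpoint."""
    params = list(initial)
    for i, value in zip(indices, values):
        params[i] = value
    return params


def stored_numbers(num_checkpoints: int, k: int, fix_mask: bool) -> int:
    """Stored values plus indices, excluding the shared initial parameters."""
    if fix_mask:
        return k + num_checkpoints * k  # indices saved once, values per checkpoint
    return num_checkpoints * (k + k)    # indices and values saved per checkpoint


initial = [0.0, 0.0, 0.0, 0.0, 0.0]
mask = [False, True, False, True, False]
trained = [0.0, 0.7, 0.0, -0.2, 0.0]  # only masked entries changed during training
idx = mask_indices(mask)
ckpt = save_sparse(trained, idx)
print(ckpt)  # [0.7, -0.2]
print(load_sparse(initial, idx, ckpt) == trained)  # True
print(stored_numbers(10, 2, fix_mask=True), stored_numbers(10, 2, fix_mask=False))  # 22 40
```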

Replicating Results

Each table and figure in the paper can be replicated by running the following commands:

# Table 1 - Parameter-Efficient Fine-Tuning on GLUE
$ bash transformers/examples/text-classification/scripts/run_table_1.sh

# Figure 2 - Mask Sparsity Ablation and Sample Ablation
$ bash transformers/examples/text-classification/scripts/run_figure_2.sh

# Table 2 - Distributed Training on GLUE
$ bash transformers/examples/text-classification/scripts/run_table_2.sh

# Table 3 - Distributed Training on CIFAR10
$ bash cifar10-fast/scripts/distributed_training.sh

# Table 4 - Efficient Checkpointing
$ bash cifar10-fast/scripts/small_checkpoints.sh

Notes

  • For reproduction of Diff Pruning results from Table 1, see code here.

Acknowledgements

We thank Yoon Kim, Michael Matena, and Demi Guo for helpful discussions.

Owner
Varun Nair
Hi! I'm a student at Duke University studying CS. I'm interested in researching AI/ML and its applications in medicine, transportation, & education.
Breaking the Curse of Space Explosion: Towards Efficient NAS with Curriculum Search

Breaking the Curse of Space Explosion: Towards Efficient NAS with Curriculum Search. PyTorch implementation for "Breaking the Curse of Space Explosion:

guoyong 17 Jan 03, 2023
Malware Analysis Neural Network project.

MalanaNeuralNetwork Description Malware Analysis Neural Network project. Table of Contents Getting Started Requirements Installation Clone Set-Up VENV

2 Nov 13, 2021
Brain tumor detection using Convolution-Neural Network (CNN)

Detect and Classify Brain Tumor using CNN. A system performing detection and classification by using Deep Learning Algorithms using Convolution-Neural Network (CNN).

assia 1 Feb 07, 2022
Theano is a Python library that allows you to define, optimize, and evaluate mathematical expressions involving multi-dimensional arrays efficiently. It can use GPUs and perform efficient symbolic differentiation.

MILA will stop developing Theano https:

9.6k Dec 31, 2022
Facilitating Database Tuning with Hyper-Parameter Optimization: A Comprehensive Experimental Evaluation

A Comprehensive Experimental Evaluation for Database Configuration Tuning This is the source code to the paper "Facilitating Database Tuning with Hype

DAIR Lab 9 Oct 29, 2022
SMIS - Semantically Multi-modal Image Synthesis(CVPR 2020)

Semantically Multi-modal Image Synthesis Project page / Paper / Demo Semantically Multi-modal Image Synthesis(CVPR2020). Zhen Zhu, Zhiliang Xu, Anshen

316 Dec 01, 2022
A general-purpose programming language, focused on simplicity, safety and stability.

The Rivet programming language A general-purpose programming language, focused on simplicity, safety and stability. Rivet's goal is to be a very power

The Rivet programming language 17 Dec 29, 2022
Repository for "Exploring Sparsity in Image Super-Resolution for Efficient Inference", CVPR 2021

SMSR Repository for "Exploring Sparsity in Image Super-Resolution for Efficient Inference" [arXiv] Highlights Locate and skip redundant computation in S

Longguang Wang 225 Dec 26, 2022
A project which aims to protect your privacy using inexpensive hardware and easily modifiable software

Protecting your privacy using an ESP32, an IR sensor and a python script This project, which I personally call the "never-gonna-catch-me-in-the-act-ev

8 Oct 10, 2022
Multimodal Temporal Context Network (MTCN)

Multimodal Temporal Context Network (MTCN) This repository implements the model proposed in the paper: Evangelos Kazakos, Jaesung Huh, Arsha Nagrani,

Evangelos Kazakos 13 Nov 24, 2022
Yet Another Robotics and Reinforcement (YARR) learning framework for PyTorch.

Yet Another Robotics and Reinforcement (YARR) learning framework for PyTorch.

Stephen James 51 Dec 27, 2022
Download files from DSpace systems (because for some reason DSpace won't let you)

DSpaceDL A tool for downloading files from DSpace items. For some reason, DSpace systems have a dogshit UI, and Universities absolutely LOOOVE to use

Soumitra Shewale 5 Dec 01, 2022
Tracking code for the winner of track 1 in the MMP-Tracking Challenge at ICCV 2021 Workshop.

Tracking code for the winner of track 1 in the MMP-Tracking challenge. This repository contains our tracking code for the Multi-camera Multiple People Track

DamoCV 29 Nov 13, 2022
IEEE-CIS Technical Challenge on Predict+Optimize for Renewable Energy Scheduling

IEEE-CIS Technical Challenge on Predict+Optimize for Renewable Energy Scheduling This is my code, data and approach for the IEEE-CIS Technical Challen

3 Sep 18, 2022
This repository contains the needed resources to build the HIRID-ICU-Benchmark dataset

HiRID-ICU-Benchmark This repository contains the needed resources to build the HIRID-ICU-Benchmark dataset for which the manuscript can be found here.

Biomedical Informatics at ETH Zurich 30 Dec 16, 2022
Pytorch implementation of NeurIPS 2021 paper: Geometry Processing with Neural Fields.

Geometry Processing with Neural Fields Pytorch implementation for the NeurIPS 2021 paper: Geometry Processing with Neural Fields Guandao Yang, Serge B

Guandao Yang 162 Dec 16, 2022
Class activation maps for your PyTorch models (CAM, Grad-CAM, Grad-CAM++, Smooth Grad-CAM++, Score-CAM, SS-CAM, IS-CAM, XGrad-CAM, Layer-CAM)

TorchCAM: class activation explorer Simple way to leverage the class-specific activation of convolutional layers in PyTorch. Quick Tour Setting your C

F-G Fernandez 1.2k Dec 29, 2022
Supporting code for the Neograd algorithm

Neograd This repo supports the paper Neograd: Gradient Descent with a Near-Ideal Learning Rate, which introduces the algorithm "Neograd". The paper an

Michael Zimmer 12 May 01, 2022
Final project code: Implementing MAE with downscaled encoders and datasets, for ESE546 FA21 at University of Pennsylvania

546 Final Project: Masked Autoencoder Haoran Tang, Qirui Wu 1. Training To train the network, please run mae_pretraining.py. Please modify folder path

Haoran Tang 0 Apr 22, 2022