Run Effective Large Batch Contrastive Learning on Limited Memory GPU

Overview

Gradient Cache

Gradient Cache is a simple technique for scaling the contrastive learning batch size virtually without limit, far beyond GPU memory constraints. This means training that used to require heavy hardware, e.g. 8 V100 GPUs, can be done on a single GPU. In addition, Gradient Cache allows users to replace GPUs with large memory by much more cost-efficient cards that have high FLOPS but little memory.
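The core idea can be sketched in plain PyTorch. This is a minimal illustration under simplifying assumptions. It is not this package's implementation, and contrastive_loss and grad_cache_step are hypothetical names. The encoders are deterministic nn.Linear layers. A real implementation must also replay random state, such as dropout masks, between the two forward passes.

The sketch works in three stages:

1. Compute representations chunk by chunk without a graph.
2. Backpropagate the full-batch loss only to those cached representations.
3. Re-encode each chunk with a graph so its cached gradient can flow into the encoder weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def contrastive_loss(q: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    # In-batch negatives: query i should score highest with passage i.
    scores = q @ p.t()
    target = torch.arange(q.size(0), device=q.device)
    return F.cross_entropy(scores, target)


def grad_cache_step(q_enc: nn.Module, p_enc: nn.Module,
                    q_in: torch.Tensor, p_in: torch.Tensor,
                    chunk_size: int) -> torch.Tensor:
    """Accumulate full-batch gradients on both encoders while only one
    chunk's activations are kept in an autograd graph at a time."""
    q_chunks = q_in.split(chunk_size)
    p_chunks = p_in.split(chunk_size)

    # Stage 1: graph-less forward over all chunks.
    with torch.no_grad():
        q_reps = torch.cat([q_enc(c) for c in q_chunks])
        p_reps = torch.cat([p_enc(c) for c in p_chunks])

    # Stage 2: full-batch loss on leaf representations; backward stops
    # at the representations and fills q_reps.grad / p_reps.grad.
    q_reps.requires_grad_()
    p_reps.requires_grad_()
    loss = contrastive_loss(q_reps, p_reps)
    loss.backward()

    # Stage 3: re-encode each chunk with a graph and push the cached
    # representation gradient into the encoder parameters.
    for enc, chunks, rep_grad in ((q_enc, q_chunks, q_reps.grad),
                                  (p_enc, p_chunks, p_reps.grad)):
        for chunk, grad in zip(chunks, rep_grad.split(chunk_size)):
            enc(chunk).backward(gradient=grad)
    return loss.detach()
```

Run grad_cache_step on deep copies of the same encoders. The resulting .grad should closely match what a single full-batch backward produces. During the backward into the encoders, only one chunk's activations are alive at a time.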

This repo holds a generic PyTorch implementation of Gradient Cache, as described in our paper Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup.

@inproceedings{gao2021scaling,
     title={Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup},
     author={Luyu Gao and Yunyi Zhang and Jiawei Han and Jamie Callan},
     booktitle={Proceedings of the 6th Workshop on Representation Learning for NLP},
     year={2021},
}

Gradient Cache has also been integrated into dense passage retrieval (DPR). Check out our GC-DPR toolkit.

Installation

The package depends only on PyTorch (pytorch>=1.6). To install, clone this repo and run pip:

git clone https://github.com/luyug/GradCache
cd GradCache
pip install .

For development,

pip install --editable .

Usage

Gradient caching functionality is implemented in the GradCache class. If you are developing a new project instead of patching an old one, also check out our functional approach, which requires less effort.

Initialization

The class's __init__ method defines the cache. It takes several functional parameters (*_fn) for easily adjusting model behavior. Alternatively, you can subclass GradCache.

grad_cache.GradCache(  
  models: List[nn.Module],  
  chunk_sizes: Union[int, List[int]],  
  loss_fn: Callable[..., Tensor],  
  split_input_fn: Callable[[Any, int], Any] = None,  
  get_rep_fn: Callable[..., Tensor] = None,  
  fp16: bool = False,  
  scaler: GradScaler = None,  
)

models - A list of encoder models to be updated with Gradient Cache.

chunk_sizes - An integer chunk size, or a list of integer chunk sizes, one per model. This controls the sub-batch size each model uses for its forward-backward passes. Set it based on available GPU memory. A value that is too small will leave the GPU underutilized.

loss_fn - A loss function that takes one representation tensor per model in models, plus an arbitrary number of keyword arguments. It should compute the loss from the input tensors. It must never modify the input tensors' relations in the autograd graph, because the gradient cache is later built from them.

split_input_fn - An optional function that splits generic model input into chunks based on the defined chunk_sizes. If not provided, this class will try its best to split inputs of supported types. See the split_inputs function.

get_rep_fn - An optional function that takes generic model output and returns representation tensors. If not provided, the generic output is assumed to be the representation tensor.

fp16 - If True, run mixed precision training, which requires scaler to also be set.

scaler - A GradScaler object for automatic mixed precision training.
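The two optional hooks can be illustrated with a hypothetical model. Its input is a dict that mixes tensors with a plain Python list. That is a Dict[str, Any], which the class does not split on its own. Its output is a (hidden_states, pooled) tuple.

This is a sketch only. The names split_mixed_input and pooled_rep, and the input and output layouts, are assumptions. The only contract taken from the text is this:

- split_input_fn(model_input, chunk_size) produces the chunks. Returning a list is our assumption.
- get_rep_fn(model_output) returns a tensor.

```python
from typing import Any, Dict, List, Tuple

import torch


def split_mixed_input(model_input: Dict[str, Any], chunk_size: int) -> List[Dict[str, Any]]:
    # Assumes every value (tensor or Python list) is batch-first, so plain
    # slicing cuts tensors along dim 0 and lists by element.
    batch_size = len(next(iter(model_input.values())))
    return [
        {key: value[start:start + chunk_size] for key, value in model_input.items()}
        for start in range(0, batch_size, chunk_size)
    ]


def pooled_rep(model_output: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    # Hypothetical output layout (hidden_states, pooled); keep only pooled.
    _, pooled = model_output
    return pooled


# Wiring (not executed here; encoder1, encoder2 and loss_fn are placeholders):
# gc = GradCache(models=[encoder1, encoder2], chunk_sizes=2, loss_fn=loss_fn,
#                split_input_fn=split_mixed_input, get_rep_fn=pooled_rep)
```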

Cache Gradient Step

To run a cached gradient computation step, call the cache_step function:

cache_step(  
  *model_inputs,  
  no_sync_except_last: bool = False,  
  **loss_kwargs  
)

Runs a single gradient cache step. On return, updates have been computed for each model in self.models, with gradients populated on the weights. The result is as if model_inputs had been run as a single huge batch on sufficiently large hardware. Calling a GradCache object directly (__call__) also invokes this function.

model_inputs - List of inputs to each encoder model, in the same order as self.models.

no_sync_except_last - If True, in a distributed setup, gradient reduction across processes is triggered for each model only on the last sub-batch's forward-backward pass. This can come in handy when dealing with a) a large model, and/or b) a non-trivial number of sub-batches.

loss_kwargs - Additional keyword arguments to the loss function loss_fn. These enable flexible loss computation (thanks to PyTorch's dynamic graph), such as reduction and weighting. Using loss_kwargs, you can also incorporate outputs from encoder models not tracked by the cache.

Return - loss, the current step's loss as a scalar tensor (detached from the graph).
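Because loss_kwargs are forwarded to loss_fn, a custom loss can expose its own settings. The sketch below shows a hypothetical loss_fn (TemperatureContrastiveLoss is not part of the package). It takes one representation tensor per model, in the same order as models, and reads reduction and temperature from keyword arguments. It only reads its inputs and never modifies them in place, so the autograd relations the cache relies on stay intact.

```python
import torch
import torch.nn.functional as F


class TemperatureContrastiveLoss:
    """Hypothetical loss_fn: one representation tensor per model, in the
    same order as `models`, plus keyword arguments from loss_kwargs."""

    def __call__(self, x: torch.Tensor, y: torch.Tensor,
                 reduction: str = "mean", temperature: float = 1.0) -> torch.Tensor:
        # Reads x and y without in-place modification.
        scores = x @ y.t() / temperature
        target = torch.arange(x.size(0), device=x.device)
        # We assume the step needs a scalar loss, so use "mean" or "sum".
        return F.cross_entropy(scores, target, reduction=reduction)


# With a GradCache built on this loss_fn, keyword arguments would be
# forwarded like: gc(xx, yy, reduction="sum", temperature=0.05)
```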

Natively Supported Input Types

  • x: Tensor - will be passed in as model(x)
  • x: List[Tensor] - will be passed in as model(*x)
  • x: Dict[str, Tensor] (or UserDict[str, Tensor]) - will be passed in as model(**x)
  • x: Tuple[List[Tensor], Dict[str, Tensor]] - will be passed in as model(*x[0], **x[1])

Other generic inputs are not fully supported. For them, we perform the model call using the following heuristics:

  • x: List[Any] - will be passed in as model(*x)
  • x: Dict[str, Any] - will be passed in as model(**x)
  • x: Tuple[List[Any], Dict[str, Any]] - will be passed in as model(*x[0], **x[1])
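The calling convention listed above can be summarized as a small helper. call_with_input is a hypothetical name. This sketch restates the documented rules and is not the package's own code. It does none of the splitting or device handling that GradCache performs.

```python
from collections import UserDict
from typing import Any, Callable

import torch


def call_with_input(model: Callable[..., Any], x: Any) -> Any:
    if isinstance(x, torch.Tensor):
        return model(x)                      # model(x)
    if isinstance(x, (dict, UserDict)):
        return model(**x)                    # model(**x)
    if isinstance(x, list):
        return model(*x)                     # model(*x)
    if (isinstance(x, tuple) and len(x) == 2
            and isinstance(x[0], list) and isinstance(x[1], (dict, UserDict))):
        return model(*x[0], **x[1])          # model(*x[0], **x[1])
    raise TypeError(f"unsupported input type: {type(x)!r}")
```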

To run with these inputs, specify split_input_fn during cache initialization to break them into smaller batches. In some rare cases, you may also need to override get_input_tensors. This is needed when its heuristic cannot collect enough tensors to cover every CUDA device that holds tensors in the input.

Example Usage with Huggingface Transformers

Learning a Bi-encoder

Say we want to learn an embedding space of labels and text. Consider the following four pairs. (In practice, you will have many more and much longer text entries.)

labels = ['fruit', 'meat', 'school', 'company']
texts = [
  'this is an apple', 
  'steak should be cooked medium rare', 
  'cmu is pittsburgh', 
  'apple sells laptop'
]

Initialize our encoder models,

from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoder1 = AutoModel.from_pretrained("bert-base-uncased").cuda()
encoder2 = AutoModel.from_pretrained("bert-base-uncased").cuda()

Initialize the GradCache object,

from grad_cache import GradCache
from grad_cache.loss import SimpleContrastiveLoss

loss_fn = SimpleContrastiveLoss()
gc = GradCache(
  models=[encoder1, encoder2], 
  chunk_sizes=2, 
  loss_fn=loss_fn, 
  get_rep_fn=lambda v: v.pooler_output
)

Here we use the get_rep_fn argument to specify a function that takes the generic Huggingface model output and returns the actual representation tensor.

Create model input,

xx = tokenizer(labels, return_tensors='pt', padding=True).to('cuda')
yy = tokenizer(texts, return_tensors='pt', padding=True).to('cuda')

Run a cache step,

gc(xx, yy, reduction='mean')

Here we pass reduction='mean' as a loss keyword argument (loss_kwargs) to control loss behavior. With a defined optimizer, the full gradient update can be done as follows:

optimizer.zero_grad()
gc(xx, yy, reduction='mean')
optimizer.step()

Use Tied Encoder?

This is naturally handled by the (magic of the) dynamic graph. Simply pass the same encoder model twice to the GradCache init method.

tied_encoder = AutoModel.from_pretrained("bert-base-uncased").cuda()
gc = GradCache(
  models=[tied_encoder, tied_encoder], 
  chunk_sizes=2, 
  loss_fn=loss_fn, 
  get_rep_fn=lambda v: v.pooler_output
)

Under the hood, distinct hooks are registered so that gradients are computed correctly.
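The dynamic-graph behavior behind tied encoders can be checked in plain PyTorch. When one module is used on both sides, gradients from separate backward passes accumulate into the same .grad.

The sketch below compares two approaches:

- One joint backward.
- Two separate backwards, each treating the other side's output as a constant. This is the shape of computation a gradient cache performs.

tied_vs_split_grads is a hypothetical helper, not part of the package.

```python
from typing import Tuple

import torch
import torch.nn as nn


def tied_vs_split_grads(encoder: nn.Module, x: torch.Tensor,
                        y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (weight grad from one joint backward,
               weight grad from two separate backwards)."""
    encoder.zero_grad()
    (encoder(x) * encoder(y)).sum().backward()
    joint = encoder.weight.grad.clone()

    encoder.zero_grad()
    # Pass 1: gradient flows only through the x path (y side held constant).
    (encoder(x) * encoder(y).detach()).sum().backward()
    # Pass 2: accumulates into the same .grad, now through the y path.
    (encoder(x).detach() * encoder(y)).sum().backward()
    split = encoder.weight.grad.clone()
    return joint, split
```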

Distributed Training with Multiple GPUs?

We expect cross-process communication of representations to be handled by the loss_fn.

from grad_cache.loss import DistributedContrastiveLoss
loss_fn_dist = DistributedContrastiveLoss()
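A distributed loss_fn has to all-gather representations from every process and compute the loss over the combined batch. The sketch below is our own minimal version built on torch.distributed.all_gather. It is not the package's DistributedContrastiveLoss, and gather_with_grad and distributed_contrastive_loss are hypothetical names. It assumes every process holds the same number of one-to-one (x, y) pairs. It falls back to the local tensors when no process group is initialized.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def _dist_ready() -> bool:
    return dist.is_available() and dist.is_initialized()


def gather_with_grad(t: torch.Tensor) -> torch.Tensor:
    """Concatenate `t` from all processes along dim 0."""
    if not _dist_ready():
        return t  # single process: nothing to gather
    gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, t.contiguous())
    # all_gather outputs are not tracked by autograd; putting the local
    # tensor back keeps this process's representations in the graph.
    gathered[dist.get_rank()] = t
    return torch.cat(gathered, dim=0)


def distributed_contrastive_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    x_all, y_all = gather_with_grad(x), gather_with_grad(y)
    scores = x_all @ y_all.t()
    target = torch.arange(x_all.size(0), device=x_all.device)
    loss = F.cross_entropy(scores, target)
    # Each process only backpropagates into its own slice, and DDP averages
    # gradients over processes; scaling by world size undoes that average.
    world_size = dist.get_world_size() if _dist_ready() else 1
    return loss * world_size
```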

Properly wrap the encoder models for gradient reduction:

from torch.nn.parallel import DistributedDataParallel

encoder1_ddp = DistributedDataParallel(
    encoder1, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
encoder2_ddp = DistributedDataParallel(
    encoder2, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

You can initialize the cache using the distributed loss and the DDP models:

gc = GradCache(
  models=[encoder1_ddp, encoder2_ddp], 
  chunk_sizes=2, 
  loss_fn=loss_fn_dist, 
  get_rep_fn=lambda v: v.pooler_output
)

Run a cache step,

gc(xx, yy, no_sync_except_last=True, reduction='mean')

Set no_sync_except_last=True to avoid unnecessary gradient reduction.
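The option can be pictured with DistributedDataParallel.no_sync(). That context manager skips the cross-process all-reduce and lets gradients accumulate locally until the next synchronized forward-backward. The helper below is hypothetical and only illustrates the pattern. It is not taken from how cache_step implements the option. Plain modules have no no_sync, so the helper falls back to a no-op context.

```python
from contextlib import nullcontext
from typing import Sequence

import torch
import torch.nn as nn


def backward_chunks_no_sync_except_last(model: nn.Module,
                                        chunks: Sequence[torch.Tensor],
                                        chunk_grads: Sequence[torch.Tensor]) -> None:
    """Backpropagate cached gradients chunk by chunk. Under DDP, only the
    last chunk's forward-backward triggers the gradient all-reduce."""
    n = len(chunks)
    for i, (chunk, grad) in enumerate(zip(chunks, chunk_grads)):
        is_last = i == n - 1
        use_no_sync = not is_last and hasattr(model, "no_sync")
        # Forward and backward both run inside the context, as DDP expects.
        with model.no_sync() if use_no_sync else nullcontext():
            model(chunk).backward(gradient=grad)
```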

Functional Approach

Decorators

If you are developing a new project, we recommend also checking out the decorators we provide for creating higher-order functions for gradient caching.

grad_cache.functional.cached(func: Callable[..., Tensor])

A decorator that turns a model call function into a cache-compatible version.

func - A function that calls the model and returns the representation tensor.

Return - A function that returns two things:
1) representation leaf tensors for cache construction, and
2) a closure function for the second forward pass and the cached backward pass.
After calling backward on the loss tensor, call 2) with 1) as its argument.
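To make the contract of cached concrete, here is a deliberately simplified stand-in. cached_sketch is a hypothetical name, not the package's implementation. It handles only a single tensor output. Unlike a full implementation, it does not restore random state (e.g. dropout) between the two forward passes.

```python
import functools
from typing import Any, Callable, Tuple

import torch
from torch import Tensor


def cached_sketch(func: Callable[..., Tensor]) -> Callable[..., Tuple[Tensor, Callable[[Tensor], None]]]:
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any):
        # 1st forward without a graph; return a leaf tensor for the cache.
        with torch.no_grad():
            rep = func(*args, **kwargs)
        leaf = rep.detach().requires_grad_()

        def closure(cached_rep: Tensor) -> None:
            # 2nd forward with a graph, then backprop the cached gradient
            # (filled in by loss.backward()) into the model parameters.
            func(*args, **kwargs).backward(gradient=cached_rep.grad)

        return leaf, closure
    return wrapper
```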

grad_cache.functional.cat_input_tensor(func: Callable[..., Tensor])

A decorator that concatenates positional and keyword arguments of type List[Tensor] into a single Tensor along the 0th dimension. This can come in handy when dealing with representation tensors from multiple cached forward passes.

func - A loss function

Return - Decorated loss function for cached results.
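A simplified stand-in for cat_input_tensor (hypothetical name cat_input_tensor_sketch, not the package's code) only needs to find list-of-tensor arguments. It then concatenates them along dim 0 before calling the wrapped loss:

```python
import functools
from typing import Any, Callable

import torch


def cat_input_tensor_sketch(func: Callable[..., Any]) -> Callable[..., Any]:
    def _maybe_cat(v: Any) -> Any:
        # Only non-empty lists made entirely of tensors are concatenated.
        if isinstance(v, list) and v and all(isinstance(t, torch.Tensor) for t in v):
            return torch.cat(v, dim=0)
        return v

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        args = tuple(_maybe_cat(a) for a in args)
        kwargs = {k: _maybe_cat(v) for k, v in kwargs.items()}
        return func(*args, **kwargs)
    return wrapper
```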

Usage

The functional decorators are particularly useful if your data loader emits small batches from which you construct the big batch. Say you also want to use automatic mixed precision. We first define the model call function and the loss function:

from grad_cache.functional import cached, cat_input_tensor

import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast

@cached
@autocast()
def call_model(model, input):
    return model(**input).pooler_output

@cat_input_tensor
@autocast()
def contrastive_loss(x, y):
    target = torch.arange(0, y.size(0), int(y.size(0) / x.size(0)), device=x.device)
    scores = torch.matmul(x, y.transpose(0, 1))
    return F.cross_entropy(scores, target=target)
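The target line in contrastive_loss above implies a particular data layout when yy holds more entries than xx. Our reading of the code is this: each x has a contiguous group of y.size(0) // x.size(0) candidates, and its positive comes first in that group. Typically that is one positive followed by hard negatives. positive_targets is a hypothetical helper that isolates that rule.

```python
import torch


def positive_targets(num_x: int, num_y: int) -> torch.Tensor:
    # Same rule as the target in contrastive_loss: with
    # group = num_y // num_x candidates per x, x[i]'s positive is y[i * group].
    group = num_y // num_x
    return torch.arange(0, num_y, group)


# 2 queries with 3 candidates each (positive first in every group).
print(positive_targets(2, 6))  # tensor([0, 3])
```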

Say you have a DataLoader, loader, that emits small batches as tuples (xx, yy), where xx holds M examples and yy holds N examples. You want to train by aggregating 16 small batches into one big batch of 16M and 16N examples:

# Assumes bert, loader, optimizer, and scaler (a torch.cuda.amp.GradScaler) are already defined.
cache_x = []
cache_y = []
closures_x = []
closures_y = []

for step, sub_batch in enumerate(loader):
    xx, yy = sub_batch
    rx, cx = call_model(bert, xx)
    ry, cy = call_model(bert, yy)

    cache_x.append(rx)
    cache_y.append(ry)
    closures_x.append(cx)
    closures_y.append(cy)

    if (step + 1) % 16 == 0:
        # Big-batch loss over all cached representations; backward only
        # reaches the cached leaf tensors.
        loss = contrastive_loss(cache_x, cache_y)
        scaler.scale(loss).backward()

        # Re-run each sub-batch forward and push its cached gradient into the model.
        for f, r in zip(closures_x, cache_x):
            f(r)
        for f, r in zip(closures_y, cache_y):
            f(r)

        cache_x = []
        cache_y = []
        closures_x = []
        closures_y = []

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

Code Structure

grad_cache/grad_cache.py - Defines the GradCache class. The code is under 300 lines including comments. For development, we encourage you to read through it.

grad_cache/functional.py - Defines decorators that create higher-order functions for gradient caching from ordinary model call functions and loss functions.

Owner
Luyu Gao
NLP Research, CMU