JMP is a Mixed Precision library for JAX.

Mixed precision training in JAX

Installation | Examples | Policies | Loss scaling | Citing JMP | References

Mixed precision training [0] is a technique that mixes the use of full and half precision floating point numbers during training to reduce the memory bandwidth requirements and improve the computational efficiency of a given model.
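Here is a rough illustration of the memory side of this trade-off. A float16 value takes 2 bytes and a float32 value takes 4, so the same tensor held in half precision needs half the bytes to store and move. This is a minimal sketch using only jax.numpy, and the 1024 x 1024 shape is arbitrary:

```python
import jax.numpy as jnp

# The same 1024 x 1024 matrix in full (f32) and half (f16) precision.
w_full = jnp.ones((1024, 1024), dtype=jnp.float32)
w_half = w_full.astype(jnp.float16)

print(w_full.nbytes)  # 4194304 bytes (4 MiB)
print(w_half.nbytes)  # 2097152 bytes (2 MiB)
```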

This library implements support for mixed precision training in JAX by providing two key abstractions: mixed precision "policies" and loss scaling. Neural network libraries (such as Haiku) can integrate with jmp and provide "Automatic Mixed Precision (AMP)" support, automating or simplifying the application of policies to modules.

All code examples below assume the following:

import jax
import jax.numpy as jnp
import jmp

half = jnp.float16  # On TPU this should be jnp.bfloat16.
full = jnp.float32

Installation

JMP is written in pure Python, but depends on C++ code via JAX and NumPy.

Because JAX installation is different depending on your CUDA version, JMP does not list JAX as a dependency in requirements.txt.

First, follow the JAX installation instructions to install JAX with the relevant accelerator support.

Then, install JMP using pip:

$ pip install git+https://github.com/deepmind/jmp

Examples

You can find a fully worked JMP example in Haiku, which shows how to use mixed f32/f16 precision to halve training time on GPU and mixed f32/bf16 to reduce training time on TPU by a third.

Policies

A mixed precision policy encapsulates the configuration in a mixed precision experiment.

# Our policy specifies that we will store parameters in full precision but will
# compute and return output in half precision.
my_policy = jmp.Policy(compute_dtype=half,
                       param_dtype=full,
                       output_dtype=half)

The policy object can be used to cast pytrees:

def layer(params, x):
  params, x = my_policy.cast_to_compute((params, x))
  w, b = params["w"], params["b"]
  y = x @ w + b
  return my_policy.cast_to_output(y)

params = {"w": jnp.ones([3, 3], dtype=my_policy.param_dtype),
          "b": jnp.zeros([3], dtype=my_policy.param_dtype)}
x = jnp.ones([2, 3], dtype=full)
y = layer(params, x)
assert y.dtype == half

You can replace the output type of a given policy:

my_policy = my_policy.with_output_dtype(full)

You can also define a policy via a string, which may be useful for specifying a policy as a command-line argument or as a hyperparameter to your experiment:

my_policy = jmp.get_policy("params=float32,compute=float16,output=float32")
f16_policy = jmp.get_policy("float16")  # Everything in f16.
half_policy = jmp.get_policy("half")    # Everything in half (f16 or bf16).

Loss scaling

When training with reduced precision, consider whether gradients will need to be shifted into the representable range of the format that you are using. This is particularly important when training with float16 and less important for bfloat16. See the NVIDIA mixed precision user guide [1] for more details.

The easiest way to shift gradients is with loss scaling. This multiplies your loss by a factor S before differentiation and then multiplies the resulting gradients by 1/S.
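To see why the shift matters, consider a single gradient value that is too small for float16. The following is a minimal sketch with jax.numpy. The value 1e-8 and S = 2**15 are illustrative. Real loss scaling scales the loss, so every gradient in backprop is multiplied by S, rather than one value in isolation:

```python
import jax.numpy as jnp

tiny_grad = 1e-8  # smaller than half of float16's smallest subnormal (~6e-8)
S = 2.0 ** 15     # illustrative loss scale

unscaled = jnp.float16(tiny_grad)                # rounds to 0.0 in float16
scaled = jnp.float16(tiny_grad * S)              # ~3.3e-4, representable
recovered = scaled.astype(jnp.float32) / S       # ~1e-8 again, in float32
```

Unscaling is done in float32 here, which is why the recovered value survives. Dividing back down inside float16 would underflow to zero again.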

def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: jmp.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.
  params = apply_optimizer(params, grads)
  return params

loss_scale = jmp.StaticLossScale(2 ** 15)
for _ in range(num_steps):
  params = train_step(params, loss_scale, ...)

The appropriate value for S depends on your model, loss, batch size and potentially other factors, and you can determine it by trial and error. As a rule of thumb, you want the largest value of S that does not introduce overflow during backprop. NVIDIA [1] recommend computing statistics about the gradients of your model (in full precision) and picking S such that its product with the maximum norm of your gradients is below 65,504, the largest finite float16 value.
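The rule of thumb can be turned into a small calculation. In this sketch we read "maximum norm" as the largest absolute gradient entry (the max, or infinity, norm). We also restrict S to powers of two. That is a common convention, not a requirement stated here. Multiplying by a power of two only changes a float's exponent, so scaling and unscaling introduce no rounding error unless a value over- or underflows. The gradient values below are made up:

```python
import math

import jax
import jax.numpy as jnp

FLOAT16_MAX = 65504.0  # largest finite float16 value

def max_abs_grad(grads):
  """Largest absolute entry across all gradient leaves, computed in f32."""
  leaves = jax.tree_util.tree_leaves(grads)
  return max(float(jnp.max(jnp.abs(g.astype(jnp.float32)))) for g in leaves)

def pick_static_loss_scale(max_grad):
  """Largest power of two S with S * max_grad < FLOAT16_MAX."""
  scale = 2.0 ** math.floor(math.log2(FLOAT16_MAX / max_grad))
  while scale * max_grad >= FLOAT16_MAX:  # guard the exact-boundary case
    scale /= 2
  return scale

# Hypothetical statistic gathered from a full-precision run.
observed = max_abs_grad({"w": jnp.array([0.5, -3.7]), "b": jnp.array([1.2])})
S = pick_static_loss_scale(observed)  # 16384.0
```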

We provide a dynamic loss scale, which adjusts the loss scale periodically during training to find the largest value for S that produces finite gradients. This is more convenient and robust than picking a static loss scale, but has a small performance impact (between 1% and 5%).

def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
  loss = ...
  # You should apply regularization etc before scaling.
  loss = loss_scale.scale(loss)
  return loss

def train_step(params, loss_scale: jmp.LossScale, ...):
  grads = jax.grad(my_loss_fn)(...)
  grads = loss_scale.unscale(grads)
  # You should put gradient clipping etc after unscaling.

  # You definitely want to skip non-finite updates with the dynamic loss scale,
  # but you might also want to consider skipping them when using a static loss
  # scale if you experience NaNs when training.
  skip_nonfinite_updates = isinstance(loss_scale, jmp.DynamicLossScale)

  if skip_nonfinite_updates:
    grads_finite = jmp.all_finite(grads)
    # Adjust our loss scale depending on whether gradients were finite. The
    # loss scale will be periodically increased if gradients remain finite and
    # will be decreased if not.
    loss_scale = loss_scale.adjust(grads_finite)
    # Only apply our optimizer if grads are finite, if any element of any
    # gradient is non-finite the whole update is discarded.
    params = jmp.select_tree(grads_finite, apply_optimizer(params, grads), params)
  else:
    # With static or no loss scaling just apply our optimizer.
    params = apply_optimizer(params, grads)

  # Since our loss scale is dynamic we need to return the new value from
  # each step. All loss scales are `PyTree`s.
  return params, loss_scale

loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15))
for _ in range(num_steps):
  params, loss_scale = train_step(params, loss_scale, ...)

In general using a static loss scale should offer the best speed, but we have optimized dynamic loss scaling to make it competitive. We recommend you start with dynamic loss scaling and move to static loss scaling if performance is an issue.

Finally, we offer a no-op loss scale, which you can use as a drop-in replacement. It does nothing (apart from implementing the jmp.LossScale API):

loss_scale = jmp.NoOpLossScale()
assert loss is loss_scale.scale(loss)
assert grads is loss_scale.unscale(grads)
assert loss_scale is loss_scale.adjust(grads_finite)
assert loss_scale.loss_scale == 1

Citing JMP

This repository is part of the DeepMind JAX Ecosystem. To cite JMP, please use the DeepMind JAX Ecosystem citation.

References

[0] Paulius Micikevicius, Sharan Narang, Jonah Alben, Gregory Diamos, Erich Elsen, David Garcia, Boris Ginsburg, Michael Houston, Oleksii Kuchaiev, Ganesh Venkatesh, Hao Wu: "Mixed Precision Training", 2017; arXiv:1710.03740 https://arxiv.org/abs/1710.03740.

[1] "Training With Mixed Precision", NVIDIA Deep Learning Performance Documentation, 2020; https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/.

Comments
  • Questions around speedup

    Hi,

    Thanks for creating this amazing library!

    So, from what I understand, this is the minimum needed to benefit from the JMP speedup:

    policy = jmp.Policy(param_dtype=jnp.float32, compute_dtype=jnp.float16, output_dtype=jnp.float16)
    ...
    # creating a network, and creating a loss_fn using the network
    data = policy.cast_to_compute(data)
    params = policy.cast_to_compute(params)
    
    grads = jax.grad(loss_fn)(data, params)
    ...
    grads = policy.cast_to_param(grads)
    

    The loss scale is only needed if we experience NaN or inf values. And thus we should see a difference in the time needed for a jax.grad operation when the data and params are in float16.

    Please correct me if I'm wrong or if I've missed anything.

    Considering that, I have some questions:

    • Can we expect to see a 2x or any speedup if we time the jax.grad operation with (params, data) in float16 and in float32 or is the speedup only achieved at scale ?
    • Can we expect to see a 2x or any speedup with small networks and small experiments ? (say 2-layer nets and experiments that only need 2-3 minutes on a single gpu)
    • Can we expect to see a 2x or any speedup with all types of networks or only some with specific architectures ? (say a 50 layer MLP, Transformers, etc)
    • Is it necessary to apply the mixed precision policy to the network and thus to use Haiku for hk.mixed_precision.set_policy or is having the parameters and data in float16 sufficient to have a speedup even with a network created by us ?

    Thank you and have a nice day!

    opened by 1m1ne 3
  • Basic question about the use of my_policy

    Hi,

    Thanks for creating this amazing functionality!

    I have a basic question about the use of policy functions to set certain precision levels. As stated in the example of your README.md:

    def layer(params, x):
      params, x = my_policy.cast_to_compute((params, x))
      w, b = params
      y = x @ w + b
      return my_policy.cast_to_output(y)
    

    I am new to JAX, but the first thing I learned is that JAX likes pure functions. Does the use of my_policy violate the pure functions paradigm?

    Should it become:

    def layer(params, x, my_policy):
      params, x = my_policy.cast_to_compute((params, x))
      w, b = params
      y = x @ w + b
      return my_policy.cast_to_output(y)
    

    The function can be jitted by using partial.

    from functools import partial
    layer_compiled = jit(partial(layer, my_policy=my_policy))
    

    Fundamental question

    A more fundamental question (which I maybe need to ask at JAX ) is: How should functions as input to functions be handled in JAX?

    def func_a(x, w, func):
      ...
    
    func_a_compiled = jit(partial(func_a, func=func_b))
    

    In order to jit this, I came up with the partial solution above. I assume your use of my_policy is valid, as you probably have more experience with JAX. But that creates some magic, which is undesirable for my use case. Is the jit(partial()) solution valid or is there a better way to handle functions as input to functions?

    Have a nice day! J

    opened by JSchuurmans 2
  • jmp-0.0.2.tar.gz on PyPI doesn't contain requirements.txt

    When trying to build my own wheel from the tar-ball of jmp-0.0.2 that has been uploaded to PyPI I get the following error:

    Collecting jmp==0.0.2
      Downloading jmp-0.0.2.tar.gz (13 kB)
        Running command python setup.py egg_info
        Traceback (most recent call last):
          File "<string>", line 1, in <module>
          File "/tmp/pip-download-zowwz99o/jmp_55ae4a2c4cdb44aa99c65fbf7a5ee9bb/setup.py", line 55, in <module>
            install_requires=_parse_requirements('requirements.txt'),
          File "/tmp/pip-download-zowwz99o/jmp_55ae4a2c4cdb44aa99c65fbf7a5ee9bb/setup.py", line 32, in _parse_requirements
            with open(requirements_txt_path) as fp:
        FileNotFoundError: [Errno 2] No such file or directory: 'requirements.txt'
    WARNING: Discarding https://files.pythonhosted.org/packages/7c/ba/a6bfcaeedca8551e2fb4054d1fd061a0dd97d26dd44002b3e92d13b51877/jmp-0.0.2.tar.gz#sha256=fdb5cec0d10aab4116c2770f24b2adf4f503fcfbb96ce8ef583e1879bdbf1b9b (from https://pypi.org/simple/jmp/). Command errored out with exit status 1: python setup.py egg_info Check the logs for full command output.
    ERROR: Could not find a version that satisfies the requirement jmp==0.0.2 (from versions: 0.0.1, 0.0.2)
    ERROR: No matching distribution found for jmp==0.0.2
    

    After downloading and extracting jmp-0.0.2.tar.gz from PyPI and looking at setup.py, I see that it contains lines that reference requirements.txt:

    install_requires=_parse_requirements('requirements.txt')
    

    however that file is not part of the source distribution and my build fails.

    Oliver

    P.S. Yes I know that there is a wheel for jmp-0.0.2 on PyPI, which I ended up using, but I'm using our Wheels_builder script that will recursively build wheels for dependencies for our systems. My point is that the source distribution of jmp 0.0.2 is incomplete as it lacks the information that is contained in the requirements file.

    opened by ostueker 1
  • Casting numpy array

    I'll first start by saying that I've just started with JAX, so I might be doing something wrong.

    When I run the following code:

    precision = jmp.get_policy('params=float32,compute=float16,output=float16')
    some_input = np.arange(15).reshape((5, 3))
    @jax.jit
    def some_function(some_input):
        some_input = precision.cast_to_compute(some_input)
        print(some_input.dtype)
        # prints float32
        return a_float16_compute_model(some_input)  # fails
    

    It seems that at least on the first (tracing) run, the numpy array doesn't get cast to float16, maybe because it is being treated as a tree, here

    https://github.com/deepmind/jmp/blob/4b94370b8de29b79d6f840b09d1990b91c1afddd/jmp/_src/policy.py#L26

    Surprisingly, if I run

    some_input = np.arange(15).reshape((5, 3)).astype(precision.compute_dtype)
    print(some_input.dtype)
    
    #prints float16
    

    the cast succeeds. I think that the expected behavior is that precision.cast_to_compute(some_input) should return a numpy array with compute_dtype, but I might be missing something.

    opened by yardenas 1
  • [JAX] Fix test failure due to upcoming change to JAX.

    An upcoming change to JAX adds a @jit decorator around a number of array operators, in this case division (/). A side effect of that change is that large Python integer constants that overflow an int32 or int64 type may produce an error. The workaround is either to explicitly cast the large constants to a specific type (e.g., np.float64), or in this case we can just do the math in question in classic NumPy since it is computing test expectations.

    opened by copybara-service[bot] 0
  • Bump version to `0.0.3.dev`.

    This is so if users install from GitHub we can see they are not using a stable version in bug reports:

    $ pip install git+https://github.com/deepmind/jmp
    $ python -c 'import jmp ; print(jmp.__version__)'
    0.0.3.dev
    
    opened by copybara-service[bot] 0
  • Cut 0.0.2 release and create GitHub action to publish to PyPI.

    We are not using 0.0.1 because this version was already used by the previous owner of the PyPI package (https://pypi.org/project/jmp/0.0.1/). As such, our releases will start at 0.0.2.

    opened by copybara-service[bot] 0
  • Replace jax.lax.select with jnp.where

    Thanks for the awesome work!

    This PR fixes an issue where jax.lax.select complains about dtypes not being equal when adjusting the DynamicLossScale.

    The exception's stack trace ends with:

    File ".../jmp/_src/loss_scale.py", line 147, in adjust
        loss_scale = jax.lax.select(
    TypeError: lax.select requires arguments to have the same dtypes, got float32, int32. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).
    

    My code looks similar to

    scale = jmp.DynamicLossScale(jnp.asarray(2 ** 15))
    ...
    gradients, scale = gradient_fn(..., scale)
    gradients = scale.unscale(gradients)
    gradients_finite = jmp.all_finite(gradients)
    scale = scale.adjust(gradients_finite)  # This line throws the exception
    ...
    
    opened by nlsfnr 6
Releases (v0.0.2)
  • v0.0.2 (Apr 15, 2021)

    Initial release of JMP.

    Changelog:

    • Add jmp.Policy abstraction and jmp.get_policy(..) factory.
    • Add jmp.LossScale and three implementations thereof (noop, static and dynamic).
    • Add various utilities (jmp.all_finite) to support common tasks in mixed precision codebases.
Owner
DeepMind