tree-math: mathematical operations for JAX pytrees

Overview

tree-math makes it easy to implement numerical algorithms that work on JAX pytrees, such as iterative methods for optimization and equation solving. It does so by providing a wrapper class tree_math.Vector that defines array operations such as infix arithmetic and dot-products on pytrees as if they were vectors.

Why tree-math

In a library like SciPy, numerical algorithms are typically written to handle fixed-rank arrays, e.g., scipy.integrate.solve_ivp requires inputs of shape (n,). This is convenient for implementors of numerical methods, but not for users, because 1d arrays are typically not the best way to keep track of state for non-trivial functions (e.g., neural networks or PDE solvers).

tree-math provides an alternative to flattening and unflattening these more complex data structures ("pytrees") for use in numerical algorithms. Instead, the numerical algorithm itself can be written in a way that handles arbitrary collections of arrays stored in pytrees. This avoids unnecessary memory copies and gives the user more control over the memory layouts used in computation. In practice, this can often make a big difference for computational efficiency as well, which is why support for flexible data structures is so prevalent inside libraries that use JAX.
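For comparison, here is a minimal sketch of the flatten-and-unflatten approach that tree-math is meant to replace. It uses jax.flatten_util.ravel_pytree, which ships with JAX rather than tree-math; the state dictionary and its values are made up for illustration.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical solver state stored as a pytree (a dict of arrays).
state = {'x': jnp.array(1.0), 'y': jnp.array([2.0, 3.0])}

# Flatten: copies every leaf into a single 1d array of shape (3,).
flat, unravel = ravel_pytree(state)

# A fixed-rank algorithm would operate on `flat` here.
flat_updated = flat + 1.0

# Unflatten: copies the data back into the original structure.
restored = unravel(flat_updated)
```

Both the flatten and the unflatten step copy the data, and the algorithm in between only ever sees the 1d layout.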

Installation

tree-math is implemented in pure Python, and only depends upon JAX.

You can install it from PyPI: pip install tree-math.

User guide

tree-math is simple to use. Just pass arbitrary pytree objects into tree_math.Vector to create an object that supports arithmetic as if all leaves of the pytree were flattened and concatenated together:

>>> import tree_math as tm
>>> import jax.numpy as jnp
>>> v = tm.Vector({'x': 1, 'y': jnp.arange(2, 4)})
>>> v
tree_math.Vector({'x': 1, 'y': DeviceArray([2, 3], dtype=int32)})
>>> v + 1
tree_math.Vector({'x': 2, 'y': DeviceArray([3, 4], dtype=int32)})
>>> v.sum()
DeviceArray(6, dtype=int32)

You can also find a few functions defined on vectors in tree_math.numpy, which implements a very restricted subset of jax.numpy. If you're interested in more functionality, please open an issue to discuss before sending a pull request. (In the long term, this separate module might disappear if we can support Vector objects directly inside jax.numpy.)

Vector objects are pytrees themselves, which means they are compatible with JAX transformations like jit, vmap and grad, and control flow like while_loop and cond.

When you're done manipulating vectors, you can pull out the underlying pytrees from the .tree property:

>>> v.tree
{'x': 1, 'y': DeviceArray([2, 3], dtype=int32)}

As an alternative to manipulating Vector objects directly, you can also use the functional transformations wrap and unwrap (see "Example usage" below).

One important difference between tree_math and jax.numpy is that dot products in tree_math default to full precision on all platforms, rather than defaulting to bfloat16 precision on TPUs. This is useful for writing most numerical algorithms, and will likely be JAX's default behavior in the future.

In the near term, we also plan to add a Matrix class that will make it possible to use tree-math for numerical algorithms such as L-BFGS, which use matrices to represent stacks of vectors.

Example usage

Here is how we could write the preconditioned conjugate gradient method. Notice how similar the implementation is to the pseudocode from Wikipedia, unlike the implementation in JAX:

import functools
from jax import lax
import tree_math as tm
import tree_math.numpy as tnp

@functools.partial(tm.wrap, vector_argnames=['b', 'x0'])
def cg(A, b, x0, M=lambda x: x, maxiter=5, tol=1e-5, atol=0.0):
  """jax.scipy.sparse.linalg.cg, written with tree_math."""
  A = tm.unwrap(A)
  M = tm.unwrap(M)

  atol2 = tnp.maximum(tol**2 * (b @ b), atol**2)

  def cond_fun(value):
    x, r, gamma, p, k = value
    return (r @ r > atol2) & (k < maxiter)

  def body_fun(value):
    x, r, gamma, p, k = value
    Ap = A(p)
    alpha = gamma / (p.conj() @ Ap)
    x_ = x + alpha * p
    r_ = r - alpha * Ap
    z_ = M(r_)
    gamma_ = r_.conj() @ z_
    beta_ = gamma_ / gamma
    p_ = z_ + beta_ * p
    return x_, r_, gamma_, p_, k + 1

  r0 = b - A(x0)
  p0 = z0 = M(r0)
  gamma0 = r0 @ z0
  initial_value = (x0, r0, gamma0, p0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
  return x_final