tree-math: mathematical operations for JAX pytrees

Overview

tree-math: mathematical operations for JAX pytrees

tree-math makes it easy to implement numerical algorithms that work on JAX pytrees, such as iterative methods for optimization and equation solving. It does so by providing a wrapper class tree_math.Vector that defines array operations such as infix arithmetic and dot-products on pytrees as if they were vectors.

Why tree-math

In a library like SciPy, numerical algorithms are typically written to handle fixed-rank arrays, e.g., scipy.integrate.solve_ivp requires inputs of shape (n,). This is convenient for implementors of numerical methods, but not for users, because 1d arrays are typically not the best way to keep track of state for non-trivial functions (e.g., neural networks or PDE solvers).

tree-math provides an alternative to flattening and unflattening these more complex data structures ("pytrees") for use in numerical algorithms. Instead, the numerical algorithm itself can be written in way to handle arbitrary collections of arrays stored in pytrees. This avoids unnecessary memory copies, and gives the user more control over the memory layouts used in computation. In practice, this can often makes a big difference for computational efficiency as well, which is why support for flexible data structures is so prevalent inside libraries that use JAX.

Installation

tree-math is implemented in pure Python, and only depends upon JAX.

You can install it from PyPI: pip install tree-math.

User guide

tree-math is simple to use. Just pass arbitrary pytree objects into tree_math.Vector to create an a object that arithmetic as if all leaves of the pytree were flattened and concatenated together:

>>> import tree_math as tm
>>> import jax.numpy as jnp
>>> v = tm.Vector({'x': 1, 'y': jnp.arange(2, 4)})
>>> v
tree_math.Vector({'x': 1, 'y': DeviceArray([2, 3], dtype=int32)})
>>> v + 1
tree_math.Vector({'x': 2, 'y': DeviceArray([3, 4], dtype=int32)})
>>> v.sum()
DeviceArray(6, dtype=int32)

You can also find a few functions defined on vectors in tree_math.numpy, which implements a very restricted subset of jax.numpy. If you're interested in more functionality, please open an issue to discuss before sending a pull request. (In the long term, this separate module might disappear if we can support Vector objects directly inside jax.numpy.)

Vector objects are pytrees themselves, which means the are compatible with JAX transformations like jit, vmap and grad, and control flow like while_loop and cond.

When you're done manipulating vectors, you can pull out the underlying pytrees from the .tree property:

>>> v.tree
{'x': 1, 'y': DeviceArray([2, 3], dtype=int32)}

As an alternative to manipulating Vector objects directly, you can also use the functional transformations wrap and unwrap (see the "Example usage" below).

One important difference between tree_math and jax.numpy is that dot products in tree_math default to full precision on all platforms, rather than defaulting to bfloat16 precision on TPUs. This is useful for writing most numerical algorithms, and will likely be JAX's default behavior in the future.

In the near-term, we also plan to add a Matrix class that will make it possible to use tree-math for numerical algorithms such as L-BFGS which use matrices to represent stacks of vectors.

Example usage

Here is how we could write the preconditioned conjugate gradient method. Notice how similar the implementation is to the pseudocode from Wikipedia, unlike the implementation in JAX:

atol2) & (k < maxiter) def body_fun(value): x, r, gamma, p, k = value Ap = A(p) alpha = gamma / (p.conj() @ Ap) x_ = x + alpha * p r_ = r - alpha * Ap z_ = M(r_) gamma_ = r_.conj() @ z_ beta_ = gamma_ / gamma p_ = z_ + beta_ * p return x_, r_, gamma_, p_, k + 1 r0 = b - A(x0) p0 = z0 = M(r0) gamma0 = r0 @ z0 initial_value = (x0, r0, gamma0, p0, 0) x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value) return x_final">
import functools
from jax import lax
import tree_math as tm
import tree_math.numpy as tnp

@functools.partial(tm.wrap, vector_argnames=['b', 'x0'])
def cg(A, b, x0, M=lambda x: x, maxiter=5, tol=1e-5, atol=0.0):
  """jax.scipy.sparse.linalg.cg, written with tree_math."""
  A = tm.unwrap(A)
  M = tm.unwrap(M)

  atol2 = tnp.maximum(tol**2 * (b @ b), atol**2)

  def cond_fun(value):
    x, r, gamma, p, k = value
    return (r @ r > atol2) & (k < maxiter)

  def body_fun(value):
    x, r, gamma, p, k = value
    Ap = A(p)
    alpha = gamma / (p.conj() @ Ap)
    x_ = x + alpha * p
    r_ = r - alpha * Ap
    z_ = M(r_)
    gamma_ = r_.conj() @ z_
    beta_ = gamma_ / gamma
    p_ = z_ + beta_ * p
    return x_, r_, gamma_, p_, k + 1

  r0 = b - A(x0)
  p0 = z0 = M(r0)
  gamma0 = r0 @ z0
  initial_value = (x0, r0, gamma0, p0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
  return x_final
Owner
Google
Google ❤️ Open Source
Google
Computer Vision and Pattern Recognition, NUS CS4243, 2022

CS4243_2022 Computer Vision and Pattern Recognition, NUS CS4243, 2022 Cloud Machine #1 : Google Colab (Free GPU) Follow this Notebook installation : h

Xavier Bresson 142 Dec 15, 2022
A Fast and Accurate One-Stage Approach to Visual Grounding, ICCV 2019 (Oral)

One-Stage Visual Grounding ***** New: Our recent work on One-stage VG is available at ReSC.***** A Fast and Accurate One-Stage Approach to Visual Grou

Zhengyuan Yang 118 Dec 05, 2022
A general python framework for visual object tracking and video object segmentation, based on PyTorch

PyTracking A general python framework for visual object tracking and video object segmentation, based on PyTorch. 📣 Two tracking/VOS papers accepted

2.6k Jan 04, 2023
Efficient 3D human pose estimation in video using 2D keypoint trajectories

3D human pose estimation in video with temporal convolutions and semi-supervised training This is the implementation of the approach described in the

Meta Research 3.1k Dec 29, 2022
DeepLabv3+:Encoder-Decoder with Atrous Separable Convolution语义分割模型在tensorflow2当中的实现

DeepLabv3+:Encoder-Decoder with Atrous Separable Convolution语义分割模型在tensorflow2当中的实现 目录 性能情况 Performance 所需环境 Environment 注意事项 Attention 文件下载 Download

Bubbliiiing 31 Nov 25, 2022
Automatic library of congress classification, using word embeddings from book titles and synopses.

Automatic Library of Congress Classification The Library of Congress Classification (LCC) is a comprehensive classification system that was first deve

Ahmad Pourihosseini 3 Oct 01, 2022
Unofficial implementation of the ImageNet, CIFAR 10 and SVHN Augmentation Policies learned by AutoAugment using pillow

AutoAugment - Learning Augmentation Policies from Data Unofficial implementation of the ImageNet, CIFAR10 and SVHN Augmentation Policies learned by Au

Philip Popien 1.3k Jan 02, 2023
Complementary Patch for Weakly Supervised Semantic Segmentation, ICCV21 (poster)

CPN (ICCV2021) This is an implementation of Complementary Patch for Weakly Supervised Semantic Segmentation, which is accepted by ICCV2021 poster. Thi

Ferenas 20 Dec 12, 2022
Ludwig is a toolbox that allows to train and evaluate deep learning models without the need to write code.

Translated in 🇰🇷 Korean/ Ludwig is a toolbox that allows users to train and test deep learning models without the need to write code. It is built on

Ludwig 8.7k Jan 05, 2023
Implementation of Monocular Direct Sparse Localization in a Prior 3D Surfel Map (DSL)

DSL Project page: https://sites.google.com/view/dsl-ram-lab/ Monocular Direct Sparse Localization in a Prior 3D Surfel Map Authors: Haoyang Ye, Huaiya

Haoyang Ye 93 Nov 30, 2022
Using some basic methods to show linkages and transformations of robotic arms

roboticArmVisualizer Python GUI application to create custom linkages and adjust joint angles. In the future, I plan to add 2d inverse kinematics solv

Sandesh Banskota 1 Nov 19, 2021
PICK: Processing Key Information Extraction from Documents using Improved Graph Learning-Convolutional Networks

Code for the paper "PICK: Processing Key Information Extraction from Documents using Improved Graph Learning-Convolutional Networks" (ICPR 2020)

Wenwen Yu 498 Dec 24, 2022
Distributing Deep Learning Hyperparameter Tuning for 3D Medical Image Segmentation

DistMIS Distributing Deep Learning Hyperparameter Tuning for 3D Medical Image Segmentation. DistriMIS Distributing Deep Learning Hyperparameter Tuning

HiEST 2 Sep 09, 2022
keyframes-CNN-RNN(action recognition)

keyframes-CNN-RNN(action recognition) Environment: python=3.7 pytorch=1.2 Datasets: Following the format of UCF101 action recognition. Run steps: Mo

4 Feb 09, 2022
P-Tuning v2: Prompt Tuning Can Be Comparable to Finetuning Universally Across Scales and Tasks

P-tuning v2 P-Tuning v2: Prompt Tuning Can Be Comparable to Finetuning Universally Across Scales and Tasks An optimized prompt tuning strategy for sma

THUDM 540 Dec 30, 2022
TensorFlow ROCm port

Documentation TensorFlow is an end-to-end open source platform for machine learning. It has a comprehensive, flexible ecosystem of tools, libraries, a

ROCm Software Platform 622 Jan 09, 2023
GNEE - GAT Neural Event Embeddings

GNEE - GAT Neural Event Embeddings This repository contains source code for the GNEE (GAT Neural Event Embeddings) method introduced in the paper: "Se

João Pedro Rodrigues Mattos 0 Sep 15, 2021
MegEngine implementation of YOLOX

Introduction YOLOX is an anchor-free version of YOLO, with a simpler design but better performance! It aims to bridge the gap between research and ind

旷视天元 MegEngine 77 Nov 22, 2022
Accelerated deep learning R&D

Accelerated deep learning R&D PyTorch framework for Deep Learning research and development. It focuses on reproducibility, rapid experimentation, and

Catalyst-Team 3.1k Jan 06, 2023
An Active Automata Learning Library Written in Python

AALpy An Active Automata Learning Library AALpy is a light-weight active automata learning library written in pure Python. You can start learning auto

TU Graz - SAL Dependable Embedded Systems Lab (DES Lab) 78 Dec 30, 2022