Differentiable Factor Graph Optimization for Learning Smoothers


Figure: overview of the training pipeline. Five stages, left to right: (1) system models, (2) factor graphs for state estimation, (3) MAP inference, (4) state estimates, and (5) errors with respect to ground truth. Gradients are backpropagated from the final error directly back to the parameters of the system models.

Overview

Code release for our IROS 2021 conference paper:

Brent Yi1, Michelle A. Lee1, Alina Kloss2, Roberto Martín-Martín1, and Jeannette Bohg1. Differentiable Factor Graph Optimization for Learning Smoothers. Proceedings of the IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS), October 2021.

1Stanford University, {brentyi,michellelee,robertom,bohg}@cs.stanford.edu
2Max Planck Institute for Intelligent Systems, [email protected]


This repository contains models, training scripts, and experimental results, and can be used either to reproduce our results or as a reference for implementation details.
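
As the pipeline figure above illustrates, the key idea is that MAP inference on the factor graph is differentiable, so errors between state estimates and ground truth can be backpropagated all the way to the parameters of the system models. The following is a minimal, self-contained toy sketch of that idea in JAX; it is not the repository's actual implementation (which is built on jaxfg), and the 1D toy system and all names here are illustrative assumptions:

    import jax
    import jax.numpy as jnp

    def map_estimate(observations, log_noise_scale, num_gn_steps=10):
        # Smooth a 1D random-walk trajectory by minimizing a sum of squared
        # observation and dynamics residuals: a tiny two-factor-type graph.
        dynamics_weight = jnp.exp(-log_noise_scale)  # learnable system-model parameter

        def residuals(states):
            r_obs = states - observations                          # observation factors
            r_dyn = dynamics_weight * (states[1:] - states[:-1])   # dynamics factors
            return jnp.concatenate([r_obs, r_dyn])

        states = observations  # initialization
        for _ in range(num_gn_steps):
            J = jax.jacobian(residuals)(states)
            r = residuals(states)
            # Gauss-Newton update: solve (J^T J) dx = -J^T r.
            states = states + jnp.linalg.solve(J.T @ J, -J.T @ r)
        return states

    def loss(log_noise_scale, observations, ground_truth):
        estimate = map_estimate(observations, log_noise_scale)
        return jnp.mean((estimate - ground_truth) ** 2)

    # End-to-end gradient of the estimation error w.r.t. the model parameter.
    grad_fn = jax.grad(loss)

The training scripts described below follow the same pattern, but with learned system models and virtual sensors, optional heteroscedastic noise models, and the joint NLL or surrogate objectives selected by the --noise-model and --loss flags.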

Significant chunks of the code written for this paper have been factored out of this repository and released as standalone libraries, which may be useful for building on our work. You can find each of them linked here, with a small usage sketch after the list:

  • jaxfg is our core factor graph optimization library.
  • jaxlie is our Lie theory library for working with rigid body transformations.
  • jax_dataclasses is our library for building JAX pytrees as dataclasses. It's similar to flax.struct, but has workflow improvements for static analysis and nested structures.
  • jax-ekf contains our EKF implementation.
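
As a rough illustration of how the first two libraries above fit together (a sketch based on assumptions about their public APIs, not code taken from this repository), a jax_dataclasses pytree can hold jaxlie transformation objects and still pass through standard JAX transforms:

    import jax
    import jax.numpy as jnp
    import jax_dataclasses as jdc
    import jaxlie

    @jdc.pytree_dataclass
    class State:
        pose: jaxlie.SE2        # planar robot pose
        velocity: jnp.ndarray   # tangent-space velocity, shape (3,)

    def predict(state: State, dt: float) -> State:
        # Constant-velocity motion model, integrated on the SE(2) manifold.
        new_pose = state.pose @ jaxlie.SE2.exp(state.velocity * dt)
        return State(pose=new_pose, velocity=state.velocity)

    state = State(pose=jaxlie.SE2.identity(), velocity=jnp.array([1.0, 0.0, 0.1]))
    state = jax.jit(predict)(state, 0.1)

Because both the dataclass and the jaxlie group element are registered pytrees, jit, grad, and vmap treat the whole state uniformly.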

Status

Included in this repo for the disk task:

  • Smoother training & results
    • Training: python train_disk_fg.py --help
    • Evaluation: python cross_validate.py --experiment-paths ./experiments/disk/fg/**/
  • Filter baseline training & results
    • Training: python train_disk_ekf.py --help
    • Evaluation: python cross_validate.py --experiment-paths ./experiments/disk/ekf/**/
  • LSTM baseline training & results
    • Training: python train_disk_lstm.py --help
    • Evaluation: python cross_validate.py --experiment-paths ./experiments/disk/lstm/**/

And, for the visual odometry task:

  • Smoother training & results (including ablations)
    • Training: python train_kitti_fg.py --help
    • Evaluation: python cross_validate.py --experiment-paths ./experiments/kitti/fg/**/
  • EKF baseline training & results
    • Training: python train_kitti_ekf.py --help
    • Evaluation: python cross_validate.py --experiment-paths ./experiments/kitti/ekf/**/
  • LSTM baseline training & results
    • Training: python train_kitti_lstm.py --help
    • Evaluation: python cross_validate.py --experiment-paths ./experiments/kitti/lstm/**/

Note that **/ indicates a recursive glob in zsh. This can be emulated in bash >= 4 via the globstar option (shopt -s globstar).

We've done our best to make our research code easy to parse, but it's still being iterated on! If you have questions, suggestions, or any general comments, please reach out or file an issue.

Setup

We use Python 3.8 and miniconda for development.

  1. Any calls to CHOLMOD (via scikit-sparse, sometimes used for eval but never for training itself) will require SuiteSparse:

    # Mac
    brew install suite-sparse
    
    # Debian
    sudo apt-get install -y libsuitesparse-dev
  2. Dependencies can be installed via pip:

    pip install -r requirements.txt

    In addition to JAX and the first-party dependencies listed above, note that this also includes various other helpers:

    • datargs (currently forked) is super useful for building type-safe argument parsers; a small sketch follows this list.
    • torch's Dataset and DataLoader interfaces are used for training.
    • fannypack contains some utilities for downloading datasets, working with PDB, and polling repository commit hashes.
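
For reference, the datargs pattern used by the training scripts looks roughly like the following (a sketch assuming the upstream datargs interface; the field names here are illustrative, not the repository's actual configurations):

    import dataclasses
    import datargs

    @dataclasses.dataclass
    class TrainingArgs:
        batch_size: int = 32
        learning_rate: float = 1e-4
        num_epochs: int = 30

    if __name__ == "__main__":
        # Builds an argparse-based CLI from the dataclass fields.
        args = datargs.parse(TrainingArgs)
        print(args.batch_size, args.learning_rate)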

The requirements.txt provided will install the CPU version of JAX by default. For CUDA support, please see instructions from the JAX team.

Datasets

Datasets are synced from Google Drive and loaded via h5py automatically as needed. If you're interested in downloading them manually, see lib/kitti/data_loading.py and lib/disk/data_loading.py.
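
For a sense of the format, a generic h5py read looks like the following; the file name and dataset keys here are hypothetical, and the data_loading modules above define the actual schema:

    import h5py

    with h5py.File("disk_trajectories.hdf5", "r") as f:
        print(list(f.keys()))              # inspect available groups/datasets
        states = f["states"][:]            # load a full dataset as a numpy array
        observations = f["observations"][:]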

Training

The naming convention for training scripts is as follows: train_{task}_{model type}.py.

All of the training scripts provide a command-line interface for configuring experiment details and hyperparameters. The --help flag will summarize these settings and their default values. For example, to run the training script for factor graphs on the disk task, try:

> python train_disk_fg.py --help

Factor graph training script for disk task.

optional arguments:
  -h, --help            show this help message and exit
  --experiment-identifier EXPERIMENT_IDENTIFIER
                        (default: disk/fg/default_experiment/fold_{dataset_fold})
  --random-seed RANDOM_SEED
                        (default: 94305)
  --dataset-fold {0,1,2,3,4,5,6,7,8,9}
                        (default: 0)
  --batch-size BATCH_SIZE
                        (default: 32)
  --train-sequence-length TRAIN_SEQUENCE_LENGTH
                        (default: 20)
  --num-epochs NUM_EPOCHS
                        (default: 30)
  --learning-rate LEARNING_RATE
                        (default: 0.0001)
  --warmup-steps WARMUP_STEPS
                        (default: 50)
  --max-gradient-norm MAX_GRADIENT_NORM
                        (default: 10.0)
  --noise-model {CONSTANT,HETEROSCEDASTIC}
                        (default: CONSTANT)
  --loss {JOINT_NLL,SURROGATE_LOSS}
                        (default: SURROGATE_LOSS)
  --pretrained-virtual-sensor-identifier PRETRAINED_VIRTUAL_SENSOR_IDENTIFIER
                        (default: disk/pretrain_virtual_sensor/fold_{dataset_fold})

When run, training scripts serialize experiment configurations to an experiment_config.yaml file. Hyperparameters for all of the results presented in our paper can be found in the experiments/ directory.
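
As an example, the serialized configuration for the default disk factor graph experiment could be inspected roughly as follows (the exact path and field names are assumptions based on the defaults shown above):

    import yaml

    path = "experiments/disk/fg/default_experiment/fold_0/experiment_config.yaml"
    with open(path) as f:
        config = yaml.safe_load(f)
    print(config)  # e.g. batch_size, learning_rate, noise_model, loss, ...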

Evaluation

All evaluation metrics are recorded at train time. The cross_validate.py script can be used to compute metrics across folds:

# Summarize all experiments with means and standard errors of recorded metrics.
python cross_validate.py

# Include statistics for every fold -- this is much more data!
python cross_validate.py --disaggregate

# We can also glob for a partial set of experiments; for example, all of the
# disk experiments.
# Note that the ** wildcard may fail in bash; see above for a fix.
python cross_validate.py --experiment-paths ./experiments/disk/**/
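
The summary statistics reported by cross_validate.py are means and standard errors of each recorded metric across folds; a minimal sketch of that computation (with made-up per-fold values) is:

    import numpy as np

    fold_rmse = np.array([0.21, 0.19, 0.24, 0.20, 0.22])  # one value per fold (made up)
    mean = fold_rmse.mean()
    standard_error = fold_rmse.std(ddof=1) / np.sqrt(len(fold_rmse))
    print(f"RMSE: {mean:.3f} +/- {standard_error:.3f}")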

Acknowledgements

We'd like to thank Rika Antonova, Kevin Zakka, Nick Heppert, Angelina Wang, and Philipp Wu for discussions and feedback on both our paper and codebase. Our software design also benefits from ideas from several open-source projects, including Sophus, GTSAM, Ceres Solver, minisam, and SwiftFusion.

This work is partially supported by the Toyota Research Institute (TRI) and Google. This article solely reflects the opinions and conclusions of its authors and not TRI, Google, or any entity associated with TRI or Google.
