Invariant Causal Prediction for Block MDPs

Overview

MISA

Abstract

Generalization across environments is critical to the successful application of reinforcement learning algorithms to real-world challenges. In this paper, we consider the problem of learning abstractions that generalize in block MDPs, families of environments with a shared latent state space and dynamics structure over that latent space, but varying observations. We leverage tools from causal inference to propose a method of invariant prediction to learn model-irrelevance state abstractions (MISA) that generalize to novel observations in the multi-environment setting. We prove that for certain classes of environments, this approach outputs with high probability a state abstraction corresponding to the causal feature set with respect to the return. We further provide more general bounds on model error and generalization error in the multi-environment setting, in the process showing a connection between causal variable selection and the state abstraction framework for MDPs. We give empirical evidence that our methods work in both linear and nonlinear settings, attaining improved generalization over single-and multi-task baselines.

Citation

@inproceedings{zhang2020invariant,
    title={Invariant Causal Prediction for Block MDPs},
    author={Amy Zhang and Clare Lyle and Shagun Sodhani and Angelos Filos and Marta Kwiatkowska and Joelle Pineau and Yarin Gal and Doina Precup},
    year={2020},
    booktitle={International Conference on Machine Learning (ICML)},
}

Experiments

The three sets of experiments on model learning, imitation learning, and reinforcement learning can be found in their respective folder. To install requirements, create a new conda environment and run

pip install -e requirements.txt

In model learning, there are two sets of experiments, linear MISA and nonlinear MISA. The code is in model_learning. First cd model_learning.

The main experiment with linear MISA can be run with

ICPAbstractMDP.ipynb

The main experiment with nonlinear MISA can be run with

python main.py

For running the imitation learning experiments, first cd imitation_learning. Then install the baselines by running cd baselines && pip install tensorflow==1.14 && pip install -e . The main experiments can be run in imitation_learning directory with:

python train_expert.py --save_model --save_model_path models # Training the expert model

#Lets say the model was trained for 150K steps.

mkdir -p buffers/train/0 buffers/train/1 buffers/eval/0 # Directory to hold the buffer data

python collect_data_using_expert_policy.py --load_model_path models_150000 --save_buffer --save_buffer_path buffers  # Collecting the trajectories using the expert model

python train.py --use_single_encoder_decoder --num_train_envs 1 --num_eval_envs 1 --load_buffer_path buffers # MISA One Env

python train.py --use_single_encoder_decoder --num_train_envs 2 --num_eval_envs 1 --load_buffer_path buffers # Baseline One Decoder 

python train.py --use_discriminator --num_train_envs 2 --num_eval_envs 1 --load_buffer_path buffers # Proposed Approach

python train.py --use_irm_loss --num_train_envs 2 --num_eval_envs 1 --load_buffer_path buffers # IRM

In reinforcement learning, the main experiment can be run in reinforcement_learning directory with

./run_local.sh

LICENSE

Attribution-NonCommercial 4.0 International

Owner
Meta Research
Meta Research
Help you understand Manual and w/ Clutch point while driving.

简体中文 forza_auto_gear forza_auto_gear is a tool for Forza Horizon 5. It will help us understand the best gear shift point using Manual or w/ Clutch in

15 Oct 08, 2022
Modular Probabilistic Programming on MXNet

MXFusion | | | | Tutorials | Documentation | Contribution Guide MXFusion is a modular deep probabilistic programming library. With MXFusion Modules yo

Amazon 100 Dec 10, 2022
Must-read Papers on Physics-Informed Neural Networks.

PINNpapers Contributed by IDRL lab. Introduction Physics-Informed Neural Network (PINN) has achieved great success in scientific computing since 2017.

IDRL 330 Jan 07, 2023
Official repository for the NeurIPS 2021 paper Get Fooled for the Right Reason: Improving Adversarial Robustness through a Teacher-guided curriculum Learning Approach

Get Fooled for the Right Reason Official repository for the NeurIPS 2021 paper Get Fooled for the Right Reason: Improving Adversarial Robustness throu

Sowrya Gali 1 Apr 25, 2022
Official Pytorch Implementation of Unsupervised Image Denoising with Frequency Domain Knowledge

Unsupervised Image Denoising with Frequency Domain Knowledge (BMVC 2021 Oral) : Official Project Page This repository provides the official PyTorch im

Donggon Jang 12 Sep 26, 2022
The code for "Deep Level Set for Box-supervised Instance Segmentation in Aerial Images".

Deep Levelset for Box-supervised Instance Segmentation in Aerial Images Wentong Li, Yijie Chen, Wenyu Liu, Jianke Zhu* Any questions or discussions ar

sunshine.lwt 112 Jan 05, 2023
Efficiently computes derivatives of numpy code.

Note: Autograd is still being maintained but is no longer actively developed. The main developers (Dougal Maclaurin, David Duvenaud, Matt Johnson, and

Formerly: Harvard Intelligent Probabilistic Systems Group -- Now at Princeton 6.1k Jan 08, 2023
Data-depth-inference - Data depth inference with python

Welcome! This readme will guide you through the use of the code in this reposito

Marco 3 Feb 08, 2022
Predicting Event Memorability from Contextual Visual Semantics

Predicting Event Memorability from Contextual Visual Semantics

0 Oct 06, 2021
The code from the paper Character Transformations for Non-Autoregressive GEC Tagging

Character Transformations for Non-Autoregressive GEC Tagging Milan Straka, Jakub Náplava, Jana Straková Charles University Faculty of Mathematics and

ÚFAL 5 Dec 10, 2022
Original code for "Zero-Shot Domain Adaptation with a Physics Prior"

Zero-Shot Domain Adaptation with a Physics Prior [arXiv] [sup. material] - ICCV 2021 Oral paper, by Attila Lengyel, Sourav Garg, Michael Milford and J

Attila Lengyel 40 Dec 21, 2022
Code for the paper Language as a Cognitive Tool to Imagine Goals in Curiosity Driven Exploration

IMAGINE: Language as a Cognitive Tool to Imagine Goals in Curiosity Driven Exploration This repo contains the code base of the paper Language as a Cog

Flowers Team 26 Dec 22, 2022
FwordCTF 2021 Infrastructure and Source code of Web/Bash challenges

FwordCTF 2021 You can find here the source code of the challenges I wrote (Web and Bash) in FwordCTF 2021 and the source code of the platform with our

Kahla 5 Nov 25, 2022
Official repository for the paper, MidiBERT-Piano: Large-scale Pre-training for Symbolic Music Understanding.

MidiBERT-Piano Authors: Yi-Hui (Sophia) Chou, I-Chun (Bronwin) Chen Introduction This is the official repository for the paper, MidiBERT-Piano: Large-

137 Dec 15, 2022
RL Algorithms with examples in Python / Pytorch / Unity ML agents

Reinforcement Learning Project This project was created to make it easier to get started with Reinforcement Learning. It now contains: An implementati

Rogier Wachters 3 Aug 19, 2022
Lex Rosetta: Transfer of Predictive Models Across Languages, Jurisdictions, and Legal Domains

Lex Rosetta: Transfer of Predictive Models Across Languages, Jurisdictions, and Legal Domains This is an accompanying repository to the ICAIL 2021 pap

4 Dec 16, 2021
Learning to Disambiguate Strongly Interacting Hands via Probabilistic Per-Pixel Part Segmentation [3DV 2021 Oral]

Learning to Disambiguate Strongly Interacting Hands via Probabilistic Per-Pixel Part Segmentation [3DV 2021 Oral] Learning to Disambiguate Strongly In

Zicong Fan 40 Dec 22, 2022
Time-Optimal Planning for Quadrotor Waypoint Flight

Time-Optimal Planning for Quadrotor Waypoint Flight This is an example implementation of the paper "Time-Optimal Planning for Quadrotor Waypoint Fligh

Robotics and Perception Group 38 Dec 02, 2022
Mesh Graphormer is a new transformer-based method for human pose and mesh reconsruction from an input image

MeshGraphormer ✨ ✨ This is our research code of Mesh Graphormer. Mesh Graphormer is a new transformer-based method for human pose and mesh reconsructi

Microsoft 251 Jan 08, 2023
Implementation of Auto-Conditioned Recurrent Networks for Extended Complex Human Motion Synthesis

acLSTM_motion This folder contains an implementation of acRNN for the CMU motion database written in Pytorch. See the following links for more backgro

Yi_Zhou 61 Sep 07, 2022