Causal Influence Detection for Improving Efficiency in Reinforcement Learning

This repository contains the code release for the paper "Causal Influence Detection for Improving Efficiency in Reinforcement Learning", published at NeurIPS 2021.

This work was done by Maximilian Seitzer, Bernhard Schölkopf and Georg Martius at the Autonomous Learning Group, Max-Planck Institute for Intelligent Systems.

If you make use of our work, please use the citation information below.

Abstract

Many reinforcement learning (RL) environments consist of independent entities that interact sparsely. In such environments, RL agents have only limited influence over other entities in any particular situation. Our idea in this work is that learning can be efficiently guided by knowing when and what the agent can influence with its actions. To achieve this, we introduce a measure of situation-dependent causal influence based on conditional mutual information and show that it can reliably detect states of influence. We then propose several ways to integrate this measure into RL algorithms to improve exploration and off-policy learning. All modified algorithms show strong increases in data efficiency on robotic manipulation tasks.

Setup

Use make_conda_env.sh to create a Conda environment with minimal dependencies:

./make_conda_env.sh minimal cid_in_rl

or recreate the environment used to get the results (more dependencies than necessary):

conda env create -f orig_environment.yml

Activate the environment with conda activate cid_in_rl.

Experiments

Causal Influence Detection

To reproduce the causal influence detection experiment, you will need to download the datasets here and extract them into the folder data/. The simplest way to run all experiments is to use the included Makefile (this will take a long time):

make -C experiments/1-influence

The results will be in the folder ./data/experiments/1-influence/.

You can also train a single model, for example

python -m cid.influence_estimation.train_model \
        --log-dir logs/eval_fetchpickandplace \
        --no-logging-subdir --seed 0 \
        --memory-path data/fetchpickandplace/memory_5k_her_agent_v2.npy \
        --val-memory-path data/fetchpickandplace/val_memory_2kof5k_her_agent_v2.npy \
        experiments/1-influence/pickandplace_model_gaussian.gin

which will train a model on FetchPickAndPlace and put the results in logs/eval_fetchpickandplace.

To evaluate the CAI score performance of the model on the validation set, use

python experiments/1-influence/pickandplace_cmi.py \
    --output-path logs/eval_fetchpickandplace \
    --model-path logs/eval_fetchpickandplace \
    --settings-path logs/eval_fetchpickandplace/eval_settings.gin \
    --memory-path data/fetchpickandplace/val_memory_2kof5k_her_agent_v2.npy \
    --variants var_prod_approx
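
For orientation, the quantity behind the CAI score can be illustrated with a small, self-contained sketch. The abstract describes the measure as based on conditional mutual information; for a fixed state s, I(S'_j; A | S = s) equals the expected KL divergence between p(s'_j | s, a) and the action-marginal p(s'_j | s). The code below estimates this by plain Monte Carlo sampling, assuming a model that predicts a diagonal Gaussian over the next state of entity j for each of K sampled actions. The function names, array shapes, and sampling scheme are assumptions made for illustration; this is not the code in cid/influence_estimation/transition_scorers and it does not reproduce the var_prod_approx variant.

```python
import numpy as np


def gaussian_logpdf(x, mean, std):
    """Log-density of a diagonal Gaussian, summed over the last axis."""
    var = std ** 2
    return np.sum(-0.5 * np.log(2.0 * np.pi * var) - 0.5 * (x - mean) ** 2 / var, axis=-1)


def cai_score(means, stds, n_samples=64, rng=None):
    """Monte Carlo estimate of I(S'_j; A | S = s) for a single state s.

    means, stds: arrays of shape (K, D) holding the predicted Gaussian over
    the next state of entity j for each of K actions (hypothetical model output).
    """
    rng = np.random.default_rng() if rng is None else rng
    n_actions, dim = means.shape
    kl_terms = []
    for k in range(n_actions):
        # Sample next states from p(s'_j | s, a_k).
        x = means[k] + stds[k] * rng.standard_normal((n_samples, dim))
        log_cond = gaussian_logpdf(x, means[k], stds[k])
        # Approximate p(s'_j | s) by the uniform mixture over the K actions.
        log_comps = np.stack([gaussian_logpdf(x, means[i], stds[i]) for i in range(n_actions)])
        log_marg = np.logaddexp.reduce(log_comps, axis=0) - np.log(n_actions)
        # Sample average of log p(x | s, a_k) - log p(x | s) estimates the KL term.
        kl_terms.append(np.mean(log_cond - log_marg))
    return float(np.mean(kl_terms))
```

If the predicted distributions are identical for all actions, the estimate is close to zero (the action has no influence on the entity). If they are well separated, it approaches log K, the largest value this K-action estimate can take.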

Reinforcement Learning

The RL experiments can be reproduced using the settings in experiments/2-prioritization, experiments/3-exploration, experiments/4-other.

To do so, pass one of these settings files to the training script:

python -m cid.train <path-to-settings-file>

By default, the output will be in the folder ./logs.

Codebase Overview

  • cid/algorithms/ddpg_agent.py contains the DDPG agent
  • cid/envs contains new environments
    • cid/envs/one_d_slide.py implements the 1D-Slide dataset
    • cid/envs/robotics/pick_and_place_rot_table.py implements the RotatingTable environment
    • cid/envs/robotics/fetch_control_detection.py contains the code for deriving ground truth control labels for FetchPickAndPlace
  • cid/influence_estimation contains code for model training, evaluation and computing the causal influence score
    • cid/influence_estimation/train_model.py is the main model training script
    • cid/influence_estimation/eval_influence.py evaluates a trained model for its classification performance
    • cid/influence_estimation/transition_scorers contains code for computing the CAI score
  • cid/memory/ contains the replay buffers, which handle prioritization and exploration bonuses
    • cid/memory/mbp implements CAI (ours)
    • cid/memory/her implements Hindsight Experience Replay
    • cid/memory/ebp implements Energy-Based Hindsight Experience Prioritization
    • cid/memory/per implements Prioritized Experience Replay
  • cid/models contains PyTorch model implementations
    • cid/bnn.py contains the implementation of VIME
  • cid/play.py lets a trained RL agent run in an environment
  • cid/train.py is the main RL training script
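
As a rough illustration of how an influence score could feed into the replay buffers listed above, the sketch below shows two generic uses: turning per-transition scores into sampling probabilities (prioritization), and adding a scaled score to the reward (exploration bonus). Proportional sampling, the function names, and the bonus_scale value are all assumptions chosen for illustration; the schemes actually implemented in cid/memory may differ.

```python
import numpy as np


def sampling_probabilities(scores, eps=1e-6):
    """Turn influence scores into replay sampling probabilities (proportional scheme)."""
    # Clip negative estimates and add eps so every transition stays reachable.
    scores = np.maximum(np.asarray(scores, dtype=float), 0.0) + eps
    return scores / scores.sum()


def add_exploration_bonus(rewards, scores, bonus_scale=0.1):
    """Add a scaled influence score to the task reward (bonus_scale is hypothetical)."""
    return np.asarray(rewards, dtype=float) + bonus_scale * np.asarray(scores, dtype=float)


rng = np.random.default_rng(0)
scores = np.array([0.0, 0.1, 2.0, 0.5])  # made-up scores for four stored transitions
probs = sampling_probabilities(scores)
batch_indices = rng.choice(len(scores), size=2, p=probs)  # prioritized minibatch indices
```

Transitions with higher scores are drawn more often, while the small eps keeps zero-score transitions from being excluded entirely.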

Citation

Please use the following citation if you make use of our work:

@inproceedings{Seitzer2021CID,
  title = {Causal Influence Detection for Improving Efficiency in Reinforcement Learning},
  author = {Seitzer, Maximilian and Sch{\"o}lkopf, Bernhard and Martius, Georg},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS 2021)},
  month = dec,
  year = {2021},
  url = {https://arxiv.org/abs/2106.03443},
  month_numeric = {12}
}

License

This implementation is licensed under the MIT license.

The robotics environments were adapted from OpenAI Gym under MIT license. The VIME implementation was adapted from https://github.com/alec-tschantz/vime under MIT license.
