Training a Resilient Q-Network against Observational Interference, Causal Inference Q-Networks

Obs-Causal-Q-Network

AAAI 2022 - Training a Resilient Q-Network against Observational Interference

Preprint | Slides | Colab Demo | PyTorch

Environment Setup

  • Option 1: create the environment from the conda .yml file (under conda 10.2 and Python 3.6)
conda env create -f obs-causal-q-conda.yml 
  • Option 2: start from a clean Python 3.6 install, and follow the UnityAgent 3D environment setup for the Banana Navigator task
pip install torch torchvision torchaudio
pip install dowhy
pip install gym
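
After either setup option, a quick check can confirm that the packages are visible to the interpreter before launching a training run. The following is a minimal sketch, not part of the repository: the helper names `missing_packages` and `python_matches` are hypothetical, and the package list simply mirrors the pip commands above.

```python
import importlib.util
import sys

# Import names of the packages installed in option 2 above.
REQUIRED = ["torch", "torchvision", "torchaudio", "dowhy", "gym"]


def missing_packages(names):
    """Return the subset of top-level module names that cannot be located."""
    return [n for n in names if importlib.util.find_spec(n) is None]


def python_matches(major, minor, version_info=sys.version_info):
    """Check the interpreter version against the one this README targets."""
    return (version_info[0], version_info[1]) == (major, minor)


if __name__ == "__main__":
    print("Python 3.6:", python_matches(3, 6))
    print("Missing packages:", missing_packages(REQUIRED) or "none")
```

An empty "missing" list only shows that the modules can be found; it does not verify versions or the Unity environment setup.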

1. Example of Training Causal Inference Q-Network (CIQ) on Cartpole

  • Run Causal Inference Q-Network Training (--network 1 for Treatment Inference Q-network)
python 0-cartpole-main.py --network 1
  • Causal Inference Q-Network Architecture

  • Output Logs
observation space: Box(4,)
action space: Discrete(2)
Timing Atk Ratio: 10%
Using CEQNetwork_1. Number of Params: 41872
 Interference Type: 1  Use baseline:  0 use CGM:  1
With:  10.42 % timing attack
Episode 0   Score: 48.00, Average Score: 48.00, Loss: 1.71
With:  0.0 % timing attack
Episode 20   Score: 15.00, Average Score: 18.71, Loss: 30.56
With:  3.57 % timing attack
Episode 40   Score: 28.00, Average Score: 19.83, Loss: 36.36
With:  8.5 % timing attack
Episode 60   Score: 200.00, Average Score: 43.65, Loss: 263.29
With:  9.0 % timing attack
Episode 80   Score: 200.00, Average Score: 103.53, Loss: 116.35
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 193.4
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 164.2
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 147.8
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 193.4
With:  9.5 % timing attack
Episode 100   Score: 200.00, Average Score: 163.20, Loss: 77.38
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 198.4
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 197.8
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 197.6
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 198.6
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 199.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 186.8
Using CEQNetwork_1. Number of Params: 41872
### Evaluation Phase & Report DQNs Test Score : 200.0

Environment solved in 114 episodes!     Average Score: 195.55
Environment solved in 114 episodes!     Average Score: 195.55 +- 25.07
############# Basic Evaluate #############
Using CEQNetwork_1. Number of Params: 41872
Evaluate Score : 200.0
############# Noise Evaluate #############
Using CEQNetwork_1. Number of Params: 41872
Robust Score : 200.0
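
To plot or compare runs, the log lines above can be turned into records. This is a small sketch that assumes the line formats are exactly as printed in this README ("With: … % timing attack" followed by an "Episode …" line); the function `parse_log` and the field names are illustrative and do not come from the repository.

```python
import re

# Matches e.g. "Episode 20   Score: 15.00, Average Score: 18.71, Loss: 30.56".
# The Loss field is optional because the Banana Navigator log omits it.
EPISODE_RE = re.compile(
    r"Episode\s+(\d+)\s+Score:\s*([-\d.]+),\s*Average Score:\s*([-\d.]+)"
    r"(?:,\s*Loss:\s*([-\d.]+))?"
)
# Matches e.g. "With:  10.42 % timing attack".
ATTACK_RE = re.compile(r"With:\s*([\d.]+)\s*%\s*timing attack")


def parse_log(text):
    """Return one dict per 'Episode' line, paired with the preceding attack line."""
    records = []
    attack = None
    for line in text.splitlines():
        m = ATTACK_RE.search(line)
        if m:
            attack = float(m.group(1))
            continue
        m = EPISODE_RE.search(line)
        if m:
            records.append({
                "episode": int(m.group(1)),
                "score": float(m.group(2)),
                "avg_score": float(m.group(3)),
                "loss": float(m.group(4)) if m.group(4) else None,
                "attack_pct": attack,
            })
            attack = None  # each attack line belongs to one episode line
    return records


# Lines copied from the CartPole log above.
SAMPLE = """With:  10.42 % timing attack
Episode 0   Score: 48.00, Average Score: 48.00, Loss: 1.71
With:  0.0 % timing attack
Episode 20   Score: 15.00, Average Score: 18.71, Loss: 30.56
"""

if __name__ == "__main__":
    for rec in parse_log(SAMPLE):
        print(rec)
```

Episode lines with no preceding "With:" line (such as the one after the elided part of the Banana Navigator log) get `attack_pct` set to `None` rather than a guessed value.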

2. Example of Training a "Variational" Causal Inference Q-Network on Unity 3D Banana Navigator

  • Run Variational Causal Inference Q-Networks (VCIQs) Training (--network 3 for Causal Variational Inference)
python 1-banana-navigator-main.py --network 3
  • Variational Causal Inference Q-Network Architecture

  • Output Logs
'Academy' started successfully!
Unity Academy name: Academy
        Number of Brains: 1
        Number of External Brains : 1
        Lesson number : 0
        Reset Parameters :

Unity brain name: BananaBrain
        Number of Visual Observations (per agent): 0
        Vector Observation space type: continuous
        Vector Observation space size (per agent): 37
        Number of stacked Vector Observation: 1
        Vector Action space type: discrete
        Vector Action space size (per agent): 4
        Vector Action descriptions: , , , 
Timing Atk Ratio: 10%
Using CEVAE_QNetwork.
Unity Worker id: 10  T: 1  Use baseline:  0  CEVAE:  1
With:  9.67 % timing attack
Episode 0   Score: 0.00, Average Score: 0.00
With:  11.0 % timing attack
Episode 5   Score: 1.00, Average Score: 0.17
With:  11.33 % timing attack
Episode 10   Score: 0.00, Average Score: 0.36
With:  10.33 % timing attack
Episode 15   Score: 0.00, Average Score: 0.56
...
Episode 205   Score: 10.00, Average Score: 9.25
With:  9.33 % timing attack
Episode 210   Score: 9.00, Average Score: 9.70
With:  9.0 % timing attack
Episode 215   Score: 10.00, Average Score: 11.10
With:  8.33 % timing attack
Episode 220   Score: 14.00, Average Score: 10.85
With:  12.33 % timing attack
Episode 225   Score: 19.00, Average Score: 11.70
With:  11.0 % timing attack
Episode 230   Score: 18.00, Average Score: 12.10
With:  7.67 % timing attack
Episode 235   Score: 21.00, Average Score: 11.60
With:  9.67 % timing attack
Episode 240   Score: 16.00, Average Score: 12.05

Environment solved in 242 episodes!     Average Score: 12.50
Environment solved in 242 episodes!     Average Score: 12.50 +- 4.87
############# Basic Evaluate #############
Using CEVAE_QNetwork.
Evaluate Score : 12.6
############# Noise Evaluate #############
Using CEVAE_QNetwork.
Robust Score : 12.5
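
The logs report a nominal "Timing Atk Ratio: 10%" alongside per-episode values such as "With: 9.67 % timing attack". One plausible reading, stated here as an assumption rather than a description of the repository's code, is that each step's observation is interfered with independently at the nominal probability, so the realized percentage varies from episode to episode. The sketch below illustrates that reading; the function names and the two interference functions (blackout and Gaussian noise) are hypothetical examples.

```python
import random


def interfere_episode(observations, ratio, interfere, rng):
    """Apply `interfere` to each observation independently with probability `ratio`.

    Returns the (possibly altered) observations and the realized percentage
    of interfered steps, analogous to the "With: X % timing attack" log line.
    """
    out, hits = [], 0
    for obs in observations:
        if rng.random() < ratio:
            out.append(interfere(obs, rng))
            hits += 1
        else:
            out.append(obs)
    pct = 100.0 * hits / len(observations) if observations else 0.0
    return out, pct


def blackout(obs, rng):
    """Illustrative interference: replace the observation with zeros."""
    return [0.0] * len(obs)


def gaussian_noise(obs, rng, sigma=0.1):
    """Illustrative interference: add zero-mean Gaussian noise to each entry."""
    return [x + rng.gauss(0.0, sigma) for x in obs]


if __name__ == "__main__":
    rng = random.Random(0)
    # 300 dummy steps of a 37-dimensional observation (the Banana vector size).
    episode = [[1.0] * 37 for _ in range(300)]
    _, pct = interfere_episode(episode, 0.10, blackout, rng)
    print(f"With: {pct:.2f} % timing attack")
```

Under this assumption, short episodes can show realized percentages far from the nominal ratio, including 0%.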

Reference

This fun work was initiated when Danny and I first read about the Causal Variational Model between 2018 and 2019, with help from Dr. Yi Ouyang and Dr. Pin-Yu Chen.

Please consider citing the paper if you find this work helpful or relevant to your research.

@article{yang2021causal,
  title={Causal Inference Q-Network: Toward Resilient Reinforcement Learning},
  author={Yang, Chao-Han Huck and Hung, I-Te Danny and Ouyang, Yi and Chen, Pin-Yu},
  journal={arXiv preprint arXiv:2102.09677},
  year={2021}
}