Contrastive Learning of Structured World Models

This repository contains the official PyTorch implementation of:

Contrastive Learning of Structured World Models.
Thomas Kipf, Elise van der Pol, Max Welling.
http://arxiv.org/abs/1911.12247

Abstract: A structured understanding of our world in terms of objects, relations, and hierarchies is an important component of human cognition. Learning such a structured world model from raw sensory data remains a challenge. As a step towards this goal, we introduce Contrastively-trained Structured World Models (C-SWMs). C-SWMs utilize a contrastive approach for representation learning in environments with compositional structure. We structure each state embedding as a set of object representations and their relations, modeled by a graph neural network. This allows objects to be discovered from raw pixel observations without direct supervision as part of the learning process. We evaluate C-SWMs on compositional environments involving multiple interacting objects that can be manipulated independently by an agent, simple Atari games, and a multi-object physics simulation. Our experiments demonstrate that C-SWMs can overcome limitations of models based on pixel reconstruction and outperform typical representatives of this model class in highly structured environments, while learning interpretable object-based representations.
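
To make the contrastive idea in the abstract concrete, here is a minimal sketch of a hinge-style contrastive loss over object-factored state embeddings. It is not the code from this repository. The squared Euclidean distance, the averaging over object slots, the margin of 1.0, and all function and variable names are assumptions chosen for illustration. The sketch pulls a predicted next state towards the true next state and pushes a negative (mismatched) state at least `margin` away.

```python
from typing import Sequence

# One embedding vector per object slot (hypothetical layout for this sketch).
ObjectState = Sequence[Sequence[float]]


def energy(state_a: ObjectState, state_b: ObjectState) -> float:
    """Squared Euclidean distance per object, averaged over object slots."""
    per_object = [
        sum((x - y) ** 2 for x, y in zip(obj_a, obj_b))
        for obj_a, obj_b in zip(state_a, state_b)
    ]
    return sum(per_object) / len(per_object)


def contrastive_hinge_loss(predicted_next: ObjectState,
                           true_next: ObjectState,
                           negative: ObjectState,
                           margin: float = 1.0) -> float:
    """Low energy for the matching pair, energy of at least `margin` for the negative."""
    positive_energy = energy(predicted_next, true_next)
    negative_energy = energy(negative, true_next)
    return positive_energy + max(0.0, margin - negative_energy)


if __name__ == "__main__":
    true_next = [[0.0, 0.0], [1.0, 1.0]]      # two objects, 2-D embeddings
    far_negative = [[3.0, 0.0], [1.0, 1.0]]   # clearly a different state
    # A perfect prediction with a distant negative incurs no loss.
    print(contrastive_hinge_loss(true_next, true_next, far_negative))
```

A negative identical to the true next state is penalized by the full margin, which is what discourages the encoder from collapsing all states to one point.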

Requirements

  • Python 3.6 or 3.7
  • PyTorch version 1.2
  • OpenAI Gym version 0.12.0: pip install gym==0.12.0
  • OpenAI Atari_py version 0.1.4: pip install atari-py==0.1.4
  • Scikit-image version 0.15.0: pip install scikit-image==0.15.0
  • Matplotlib version 3.0.2: pip install matplotlib==3.0.2

Generate datasets

2D Shapes:

python data_gen/env.py --env_id ShapesTrain-v0 --fname data/shapes_train.h5 --num_episodes 1000 --seed 1
python data_gen/env.py --env_id ShapesEval-v0 --fname data/shapes_eval.h5 --num_episodes 10000 --seed 2

3D Cubes:

python data_gen/env.py --env_id CubesTrain-v0 --fname data/cubes_train.h5 --num_episodes 1000 --seed 3
python data_gen/env.py --env_id CubesEval-v0 --fname data/cubes_eval.h5 --num_episodes 10000 --seed 4

Atari Pong:

python data_gen/env.py --env_id PongDeterministic-v4 --fname data/pong_train.h5 --num_episodes 1000 --atari --seed 1
python data_gen/env.py --env_id PongDeterministic-v4 --fname data/pong_eval.h5 --num_episodes 100 --atari --seed 2

Space Invaders:

python data_gen/env.py --env_id SpaceInvadersDeterministic-v4 --fname data/spaceinvaders_train.h5 --num_episodes 1000 --atari --seed 1
python data_gen/env.py --env_id SpaceInvadersDeterministic-v4 --fname data/spaceinvaders_eval.h5 --num_episodes 100 --atari --seed 2

3-Body Gravitational Physics:

python data_gen/physics.py --num-episodes 5000 --fname data/balls_train.h5 --seed 1
python data_gen/physics.py --num-episodes 1000 --fname data/balls_eval.h5 --eval --seed 2
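
The commands above can also be assembled programmatically, which makes it easier to review them before a long generation run. The following is a hedged sketch, not part of the repository: the helper names (`build_env_command`, `build_physics_command`, `all_commands`) are hypothetical, and the script only prints the command lines. It uses only the flags and values listed above. Note that `data_gen/env.py` takes `--num_episodes` (underscore) while `data_gen/physics.py` takes `--num-episodes` (hyphen).

```python
import shlex
from typing import List


def build_env_command(env_id: str, fname: str, num_episodes: int,
                      seed: int, atari: bool = False) -> List[str]:
    """Assemble a data_gen/env.py call using the flags shown above."""
    cmd = ["python", "data_gen/env.py", "--env_id", env_id,
           "--fname", fname, "--num_episodes", str(num_episodes)]
    if atari:
        cmd.append("--atari")
    return cmd + ["--seed", str(seed)]


def build_physics_command(fname: str, num_episodes: int, seed: int,
                          eval_split: bool = False) -> List[str]:
    """Assemble a data_gen/physics.py call (hyphenated --num-episodes)."""
    cmd = ["python", "data_gen/physics.py",
           "--num-episodes", str(num_episodes), "--fname", fname]
    if eval_split:
        cmd.append("--eval")
    return cmd + ["--seed", str(seed)]


def all_commands() -> List[List[str]]:
    """The ten dataset-generation commands listed in this section."""
    return [
        build_env_command("ShapesTrain-v0", "data/shapes_train.h5", 1000, 1),
        build_env_command("ShapesEval-v0", "data/shapes_eval.h5", 10000, 2),
        build_env_command("CubesTrain-v0", "data/cubes_train.h5", 1000, 3),
        build_env_command("CubesEval-v0", "data/cubes_eval.h5", 10000, 4),
        build_env_command("PongDeterministic-v4", "data/pong_train.h5", 1000, 1, atari=True),
        build_env_command("PongDeterministic-v4", "data/pong_eval.h5", 100, 2, atari=True),
        build_env_command("SpaceInvadersDeterministic-v4", "data/spaceinvaders_train.h5", 1000, 1, atari=True),
        build_env_command("SpaceInvadersDeterministic-v4", "data/spaceinvaders_eval.h5", 100, 2, atari=True),
        build_physics_command("data/balls_train.h5", 5000, 1),
        build_physics_command("data/balls_eval.h5", 1000, 2, eval_split=True),
    ]


if __name__ == "__main__":
    # Dry run: print each command instead of executing it.
    for cmd in all_commands():
        print(" ".join(shlex.quote(part) for part in cmd))
```

To execute a command rather than print it, pass the list to `subprocess.run(cmd, check=True)` from the repository root.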

Run model training and evaluation

2D Shapes:

python train.py --dataset data/shapes_train.h5 --encoder small --name shapes
python eval.py --dataset data/shapes_eval.h5 --save-folder checkpoints/shapes --num-steps 1

3D Cubes:

python train.py --dataset data/cubes_train.h5 --encoder large --name cubes
python eval.py --dataset data/cubes_eval.h5 --save-folder checkpoints/cubes --num-steps 1

Atari Pong:

python train.py --dataset data/pong_train.h5 --encoder medium --embedding-dim 4 --action-dim 6 --num-objects 3 --copy-action --epochs 200 --name pong
python eval.py --dataset data/pong_eval.h5 --save-folder checkpoints/pong --num-steps 1

Space Invaders:

python train.py --dataset data/spaceinvaders_train.h5 --encoder medium --embedding-dim 4 --action-dim 6 --num-objects 3 --copy-action --epochs 200 --name spaceinvaders
python eval.py --dataset data/spaceinvaders_eval.h5 --save-folder checkpoints/spaceinvaders --num-steps 1

3-Body Gravitational Physics:

python train.py --dataset data/balls_train.h5 --encoder medium --embedding-dim 4 --num-objects 3 --ignore-action --name balls
python eval.py --dataset data/balls_eval.h5 --save-folder checkpoints/balls --num-steps 1
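
In each pair above, the `--name` given to `train.py` reappears as the last component of `--save-folder` in the `eval.py` call (`--name shapes` pairs with `checkpoints/shapes`). The sketch below keeps the two in sync by deriving both commands from one experiment name. The helper names are hypothetical and not part of the repository. Only `--num-steps 1` appears in the commands above, so other values for `num_steps` are an untested assumption here.

```python
from typing import List, Sequence


def train_command(name: str, dataset: str, encoder: str,
                  extra_flags: Sequence[str] = ()) -> List[str]:
    """Assemble a train.py call; extra_flags are passed through unchanged."""
    return ["python", "train.py", "--dataset", dataset,
            "--encoder", encoder, *extra_flags, "--name", name]


def eval_command(name: str, dataset: str, num_steps: int = 1) -> List[str]:
    """Assemble the matching eval.py call, pointing at checkpoints/<name>."""
    return ["python", "eval.py", "--dataset", dataset,
            "--save-folder", f"checkpoints/{name}",
            "--num-steps", str(num_steps)]


if __name__ == "__main__":
    # Flags copied from the Atari Pong training command above.
    pong_flags = ["--embedding-dim", "4", "--action-dim", "6",
                  "--num-objects", "3", "--copy-action", "--epochs", "200"]
    print(" ".join(train_command("pong", "data/pong_train.h5", "medium", pong_flags)))
    print(" ".join(eval_command("pong", "data/pong_eval.h5")))
```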

Cite

If you make use of this code in your own work, please cite our paper:

@article{kipf2019contrastive,
  title={Contrastive Learning of Structured World Models}, 
  author={Kipf, Thomas and van der Pol, Elise and Welling, Max}, 
  journal={arXiv preprint arXiv:1911.12247}, 
  year={2019} 
}