Code for "Unsupervised State Representation Learning in Atari"

Overview

Unsupervised State Representation Learning in Atari

Ankesh Anand*, Evan Racah*, Sherjil Ozair*, Yoshua Bengio, Marc-Alexandre Cรดtรฉ, R Devon Hjelm

This repo provides code for the benchmark and techniques introduced in the paper Unsupervised State Representation Learning in Atari

Install

AtariARI Wrapper

You can do a minimal install to get just the AtariARI (Atari Annotated RAM Interface) wrapper by doing:

pip install 'gym[atari]'
pip install git+git://github.com/mila-iqia/atari-representation-learning.git

This just requires gym[atari] and it gives you the ability to play around with the AtariARI wrapper. If you want to use the code for training representation learning methods and probing them, you will need a full installation:

Full installation (AtariARI Wrapper + Training & Probing Code)

# PyTorch and scikit learn
conda install pytorch torchvision -c pytorch
conda install scikit-learn

# Baselines for Atari preprocessing
# Tensorflow is a dependency, but you don't need to install the GPU version
conda install tensorflow
pip install git+git://github.com/openai/baselines

# pytorch-a2c-ppo-acktr for RL utils
pip install git+git://github.com/ankeshanand/pytorch-a2c-ppo-acktr-gail

# Clone and install our package
pip install -r requirements.txt
pip install git+git://github.com/mila-iqia/atari-representation-learning.git

Usage

Atari Annotated RAM Interface (AtariARI):

AtariARI exposes the ground truth labels for different state variables for each observation. We have made AtariARI available as a Gym wrapper, to use it simply wrap an Atari gym env with AtariARIWrapper.

import gym
from atariari.benchmark.wrapper import AtariARIWrapper
env = AtariARIWrapper(gym.make('MsPacmanNoFrameskip-v4'))
obs = env.reset()
obs, reward, done, info = env.step(1)

Now, info is a dictionary of the form:

{'ale.lives': 3,
 'labels': {'enemy_sue_x': 88,
  'enemy_inky_x': 88,
  'enemy_pinky_x': 88,
  'enemy_blinky_x': 88,
  'enemy_sue_y': 80,
  'enemy_inky_y': 80,
  'enemy_pinky_y': 80,
  'enemy_blinky_y': 50,
  'player_x': 88,
  'player_y': 98,
  'fruit_x': 0,
  'fruit_y': 0,
  'ghosts_count': 3,
  'player_direction': 3,
  'dots_eaten_count': 0,
  'player_score': 0,
  'num_lives': 2}}

Note: In our experiments, we use additional preprocessing for Atari environments mainly following Minh et. al, 2014. See atariari/benchmark/envs.py for more info!

If you want the raw RAM annotations (which parts of ram correspond to each state variable), check out atariari/benchmark/ram_annotations.py

Probing


โš ๏ธ Important โš ๏ธ : The RAM labels are meant for full-sized Atari observations (210 * 160). Probing results won't be accurate if you downsample the observations.

We provide an interface for the included probing tasks.

First, get episodes for train, val and, test:

from atariari.benchmark.episodes import get_episodes

tr_episodes, val_episodes,\
tr_labels, val_labels,\
test_episodes, test_labels = get_episodes(env_name="PitfallNoFrameskip-v4", 
                                     steps=50000, 
                                     collect_mode="random_agent")

Then probe them using ProbeTrainer and your encoder (my_encoder):

from atariari.benchmark.probe import ProbeTrainer

probe_trainer = ProbeTrainer(my_encoder, representation_len=my_encoder.feature_size)
probe_trainer.train(tr_episodes, val_episodes,
                     tr_labels, val_labels,)
final_accuracies, final_f1_scores = probe_trainer.test(test_episodes, test_labels)

To see how we use ProbeTrainer, check out scripts/run_probe.py

Here is an example of my_encoder:

# get your encoder
import torch.nn as nn
import torch
class MyEncoder(nn.Module):
    def __init__(self, input_channels, feature_size):
        super().__init__()
        self.feature_size = feature_size
        self.input_channels = input_channels
        self.final_conv_size = 64 * 9 * 6
        self.cnn = nn.Sequential(
            nn.Conv2d(input_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, stride=1),
            nn.ReLU()
        )
        self.fc = nn.Linear(self.final_conv_size, self.feature_size)

    def forward(self, inputs):
        x = self.cnn(inputs)
        x = x.view(x.size(0), -1)
        return self.fc(x)
        

my_encoder = MyEncoder(input_channels=1,feature_size=256)
# load in weights
my_encoder.load_state_dict(torch.load(open("path/to/my/weights.pt", "rb")))

Spatio-Temporal DeepInfoMax:

src/ contains implementations of several representation learning methods, along with ST-DIM. Here's a sample usage:

python -m scripts.run_probe --method infonce-stdim --env-name {env_name}

where env_name is of the form {game}NoFrameskip-v4, such as PongNoFrameskip-v4

Citation

@article{anand2019unsupervised,
  title={Unsupervised State Representation Learning in Atari},
  author={Anand, Ankesh and Racah, Evan and Ozair, Sherjil and Bengio, Yoshua and C{\^o}t{\'e}, Marc-Alexandre and Hjelm, R Devon},
  journal={arXiv preprint arXiv:1906.08226},
  year={2019}
}
Owner
Mila
Quebec Artificial Intelligence Institute
Mila
Airbus Ship Detection Challenge

Airbus Ship Detection Challenge This is an open solution to the Airbus Ship Detection Challenge. Our goals We are building entirely open solution to t

minerva.ml 55 Nov 29, 2022
Approximate Nearest Neighbors in C++/Python optimized for memory usage and loading/saving to disk

Annoy Annoy (Approximate Nearest Neighbors Oh Yeah) is a C++ library with Python bindings to search for points in space that are close to a given quer

Spotify 10.6k Jan 04, 2023
Deep Learning Training Scripts With Python

Deep Learning Training Scripts DNN Frameworks Caffe PyTorch Tensorflow CNN Models VGG ResNet DenseNet Inception Language Modeling GatedCNN-LM Attentio

Multicore Computing Research Lab 16 Dec 15, 2022
GPU-accelerated PyTorch implementation of Zero-shot User Intent Detection via Capsule Neural Networks

GPU-accelerated PyTorch implementation of Zero-shot User Intent Detection via Capsule Neural Networks This repository implements a capsule model Inten

Joel Huang 15 Dec 24, 2022
Current state of supervised and unsupervised depth completion methods

Awesome Depth Completion Table of Contents About Sparse-to-Dense Depth Completion Current State of Depth Completion Unsupervised VOID Benchmark Superv

224 Dec 28, 2022
Code for our paper "Sematic Representation for Dialogue Modeling" in ACL2021

AMR-Dialogue An implementation for paper "Semantic Representation for Dialogue Modeling". You may find our paper here. Requirements python 3.6 pytorch

xfbai 45 Dec 26, 2022
PyTorch implementation of U-TAE and PaPs for satellite image time series panoptic segmentation.

Panoptic Segmentation of Satellite Image Time Series with Convolutional Temporal Attention Networks (ICCV 2021) This repository is the official implem

71 Jan 04, 2023
Official code for "End-to-End Optimization of Scene Layout" -- including VAE, Diff Render, SPADE for colorization (CVPR 2020 Oral)

End-to-End Optimization of Scene Layout Code release for: End-to-End Optimization of Scene Layout CVPR 2020 (Oral) Project site, Bibtex For help conta

Andrew Luo 41 Dec 09, 2022
Learning Time-Critical Responses for Interactive Character Control

Learning Time-Critical Responses for Interactive Character Control Abstract This code implements the paper Learning Time-Critical Responses for Intera

Movement Research Lab 227 Dec 31, 2022
MNE: Magnetoencephalography (MEG) and Electroencephalography (EEG) in Python

MNE-Python MNE-Python software is an open-source Python package for exploring, visualizing, and analyzing human neurophysiological data such as MEG, E

MNE tools for MEG and EEG data analysis 2.1k Dec 28, 2022
Riemann Noise Injection With PyTorch

Riemann Noise Injection - PyTorch A module for modeling GAN noise injection based on Riemann geometry, as described in Ruili Feng, Deli Zhao, and Zhen

2 May 27, 2022
Job Assignment System by Real-time Emotion Detection

Emotion-Detection Job Assignment System by Real-time Emotion Detection Emotion is the essential role of facial expression and it could provide a lot o

1 Feb 08, 2022
A library for uncertainty quantification based on PyTorch

Torchuq [logo here] TorchUQ is an extensive library for uncertainty quantification (UQ) based on pytorch. TorchUQ currently supports 10 representation

TorchUQ 96 Dec 12, 2022
Keeping it safe - AI Based COVID-19 Tracker using Deep Learning and facial recognition

Keeping it safe - AI Based COVID-19 Tracker using Deep Learning and facial recognition

Vansh Wassan 15 Jun 17, 2021
Pmapper is a super-resolution and deconvolution toolkit for python 3.6+

pmapper pmapper is a super-resolution and deconvolution toolkit for python 3.6+. PMAP stands for Poisson Maximum A-Posteriori, a highly flexible and a

NASA Jet Propulsion Laboratory 8 Nov 06, 2022
Implicit Model Specialization through DAG-based Decentralized Federated Learning

Federated Learning DAG Experiments This repository contains software artifacts to reproduce the experiments presented in the Middleware '21 paper "Imp

Operating Systems and Middleware Group 5 Oct 16, 2022
๐Ÿ’ƒ VALSE: A Task-Independent Benchmark for Vision and Language Models Centered on Linguistic Phenomena

๐Ÿ’ƒ VALSE: A Task-Independent Benchmark for Vision and Language Models Centered on Linguistic Phenomena.

Heidelberg-NLP 17 Nov 07, 2022
Using CNN to mimic the driver based on training data from Torcs

Behavioural-Cloning-in-autonomous-driving Using CNN to mimic the driver based on training data from Torcs. Approach First, the data was collected from

Sudharshan 2 Jan 05, 2022
Pytorch implementation of winner from VQA Chllange Workshop in CVPR'17

2017 VQA Challenge Winner (CVPR'17 Workshop) pytorch implementation of Tips and Tricks for Visual Question Answering: Learnings from the 2017 Challeng

Mark Dong 166 Dec 11, 2022
PyTorch implementation for Score-Based Generative Modeling through Stochastic Differential Equations (ICLR 2021, Oral)

Score-Based Generative Modeling through Stochastic Differential Equations This repo contains a PyTorch implementation for the paper Score-Based Genera

Yang Song 757 Jan 04, 2023