Code for "Unsupervised State Representation Learning in Atari"

Overview

Unsupervised State Representation Learning in Atari

Ankesh Anand*, Evan Racah*, Sherjil Ozair*, Yoshua Bengio, Marc-Alexandre Côté, R Devon Hjelm

This repo provides code for the benchmark and techniques introduced in the paper "Unsupervised State Representation Learning in Atari".

Install

AtariARI Wrapper

For a minimal install that provides only the AtariARI (Atari Annotated RAM Interface) wrapper, run:

pip install 'gym[atari]'
pip install git+https://github.com/mila-iqia/atari-representation-learning.git

This requires only gym[atari] and lets you experiment with the AtariARI wrapper. To use the code for training representation learning methods and probing them, you will need the full installation:

Full installation (AtariARI Wrapper + Training & Probing Code)

# PyTorch and scikit-learn
conda install pytorch torchvision -c pytorch
conda install scikit-learn

# Baselines for Atari preprocessing
# Tensorflow is a dependency, but you don't need to install the GPU version
conda install tensorflow
pip install git+https://github.com/openai/baselines

# pytorch-a2c-ppo-acktr for RL utils
pip install git+https://github.com/ankeshanand/pytorch-a2c-ppo-acktr-gail

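# Clone the repository (requirements.txt lives in it) and install our package
git clone https://github.com/mila-iqia/atari-representation-learning.git
cd atari-representation-learning
pip install -r requirements.txt
pip install git+https://github.com/mila-iqia/atari-representation-learning.git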

Usage

Atari Annotated RAM Interface (AtariARI):

AtariARI exposes the ground-truth labels for the state variables of each observation. It is available as a Gym wrapper: to use it, wrap an Atari gym env with AtariARIWrapper.

import gym
from atariari.benchmark.wrapper import AtariARIWrapper
env = AtariARIWrapper(gym.make('MsPacmanNoFrameskip-v4'))
obs = env.reset()
obs, reward, done, info = env.step(1)

Now, info is a dictionary of the form:

{'ale.lives': 3,
 'labels': {'enemy_sue_x': 88,
  'enemy_inky_x': 88,
  'enemy_pinky_x': 88,
  'enemy_blinky_x': 88,
  'enemy_sue_y': 80,
  'enemy_inky_y': 80,
  'enemy_pinky_y': 80,
  'enemy_blinky_y': 50,
  'player_x': 88,
  'player_y': 98,
  'fruit_x': 0,
  'fruit_y': 0,
  'ghosts_count': 3,
  'player_direction': 3,
  'dots_eaten_count': 0,
  'player_score': 0,
  'num_lives': 2}}
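
As a minimal sketch of how these labels might be consumed, the snippet below gathers the labels entry from a sequence of info dicts into one list per state variable. The helper name collect_labels and the hard-coded sample dicts are illustrative assumptions, not part of the package; in practice the dicts would come from successive env.step calls.

```python
from collections import defaultdict


def collect_labels(infos):
    """Group per-step label dicts into one list of values per state variable.

    `infos` is an iterable of dicts shaped like the `info` shown above.
    Hypothetical helper for illustration; not part of atariari.
    """
    per_variable = defaultdict(list)
    for info in infos:
        for name, value in info["labels"].items():
            per_variable[name].append(value)
    return dict(per_variable)


# Stand-in for the `info` dicts returned by consecutive env.step(...) calls.
sample_infos = [
    {"ale.lives": 3, "labels": {"player_x": 88, "player_y": 98}},
    {"ale.lives": 3, "labels": {"player_x": 89, "player_y": 98}},
]
labels_by_variable = collect_labels(sample_infos)
print(labels_by_variable["player_x"])  # [88, 89]
```

Keeping one list per variable makes it straightforward to inspect how a single quantity, such as player_x, evolves over an episode.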

Note: In our experiments, we apply additional preprocessing to the Atari environments, mainly following Mnih et al., 2014. See atariari/benchmark/envs.py for more info!

If you want the raw RAM annotations (which parts of RAM correspond to each state variable), check out atariari/benchmark/ram_annotations.py.
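
To illustrate the idea behind the RAM annotations, here is a minimal sketch that reads state variables out of a 128-byte RAM snapshot using a name-to-index mapping. The indices and the helper ram_to_labels are made up for illustration; the real mappings are the ones in atariari/benchmark/ram_annotations.py.

```python
def ram_to_labels(ram, annotations):
    """Look up each annotated RAM index and return {variable_name: byte_value}."""
    return {name: int(ram[index]) for name, index in annotations.items()}


# Hypothetical annotation table: variable name -> RAM byte index.
example_annotations = {"player_x": 10, "player_y": 16}

# The Atari 2600 has 128 bytes of RAM; fake a snapshot for the example.
ram = [0] * 128
ram[10] = 88
ram[16] = 98

labels = ram_to_labels(ram, example_annotations)
print(labels)  # {'player_x': 88, 'player_y': 98}
```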

Probing


⚠️ Important ⚠️ : The RAM labels are meant for full-sized Atari observations (210 × 160). Probing results won't be accurate if you downsample the observations.

We provide an interface for the included probing tasks.

First, get episodes for train, validation, and test:

from atariari.benchmark.episodes import get_episodes

tr_episodes, val_episodes,\
tr_labels, val_labels,\
test_episodes, test_labels = get_episodes(env_name="PitfallNoFrameskip-v4", 
                                     steps=50000, 
                                     collect_mode="random_agent")

Then probe them using ProbeTrainer and your encoder (my_encoder):

from atariari.benchmark.probe import ProbeTrainer

probe_trainer = ProbeTrainer(my_encoder, representation_len=my_encoder.feature_size)
probe_trainer.train(tr_episodes, val_episodes,
                     tr_labels, val_labels,)
final_accuracies, final_f1_scores = probe_trainer.test(test_episodes, test_labels)
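
A short sketch of how the returned scores could be summarized, assuming each return value is a flat dict mapping a state-variable name to a float score (this is an assumption; inspect the objects returned by your installed version). The helper summarize_scores and the sample numbers are hypothetical.

```python
def summarize_scores(scores):
    """Return (mean score, list of (name, score) pairs sorted worst to best)."""
    ranked = sorted(scores.items(), key=lambda item: item[1])
    mean = sum(scores.values()) / len(scores)
    return mean, ranked


# Made-up numbers standing in for `final_f1_scores`.
example_f1 = {"player_x": 0.5, "player_y": 0.75, "player_score": 1.0}
mean_f1, ranked = summarize_scores(example_f1)

print(f"mean F1: {mean_f1:.2f}")
for name, score in ranked:
    print(f"{name:>14}: {score:.2f}")
```

Sorting from worst to best puts the state variables the encoder represents least well at the top of the printout.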

To see how we use ProbeTrainer, check out scripts/run_probe.py

Here is an example of my_encoder:

# get your encoder
import torch.nn as nn
import torch
class MyEncoder(nn.Module):
    def __init__(self, input_channels, feature_size):
        super().__init__()
        self.feature_size = feature_size
        self.input_channels = input_channels
        self.final_conv_size = 64 * 9 * 6
        self.cnn = nn.Sequential(
            nn.Conv2d(input_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, stride=1),
            nn.ReLU()
        )
        self.fc = nn.Linear(self.final_conv_size, self.feature_size)

    def forward(self, inputs):
        x = self.cnn(inputs)
        x = x.view(x.size(0), -1)
        return self.fc(x)
        

my_encoder = MyEncoder(input_channels=1, feature_size=256)
# load in weights (map_location="cpu" lets the checkpoint load on a machine without a GPU)
my_encoder.load_state_dict(torch.load("path/to/my/weights.pt", map_location="cpu"))
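
The constant final_conv_size = 64 * 9 * 6 in the encoder follows from pushing a full-sized 210 × 160 frame through its four unpadded convolutions. The sketch below reproduces that arithmetic with the standard output-size formula floor((n - kernel) / stride) + 1; the helper name conv_out is ours.

```python
def conv_out(size, kernel, stride):
    """Output length of an unpadded convolution along one dimension."""
    return (size - kernel) // stride + 1


# (kernel, stride) for each Conv2d layer in MyEncoder.
layers = [(8, 4), (4, 2), (4, 2), (3, 1)]

height, width = 210, 160  # full-sized Atari frame
for kernel, stride in layers:
    height = conv_out(height, kernel, stride)
    width = conv_out(width, kernel, stride)

final_conv_size = 64 * height * width  # 64 output channels in the last layer
print(height, width, final_conv_size)  # 9 6 3456
```

If you feed the encoder frames of a different size, rerun this calculation and adjust final_conv_size to match.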

Spatio-Temporal DeepInfoMax:

src/ contains implementations of several representation learning methods, along with ST-DIM. Here's a sample usage:

python -m scripts.run_probe --method infonce-stdim --env-name {env_name}

where env_name is of the form {game}NoFrameskip-v4, such as PongNoFrameskip-v4.
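
If you want to probe several games in a row, a small sketch like the following can assemble the environment names and command lines. The helpers make_env_name and probe_command are hypothetical conveniences, not part of the repository; they only print the commands rather than run them.

```python
def make_env_name(game):
    """Build a Gym environment id of the form {game}NoFrameskip-v4."""
    return f"{game}NoFrameskip-v4"


def probe_command(game, method="infonce-stdim"):
    """Assemble the run_probe command line shown above as an argument list."""
    return [
        "python", "-m", "scripts.run_probe",
        "--method", method,
        "--env-name", make_env_name(game),
    ]


for game in ["Pong", "MsPacman", "Pitfall"]:
    print(" ".join(probe_command(game)))
```

Returning an argument list instead of a single string means the result can be passed directly to subprocess.run if you later decide to launch the runs from Python.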

Citation

@article{anand2019unsupervised,
  title={Unsupervised State Representation Learning in Atari},
  author={Anand, Ankesh and Racah, Evan and Ozair, Sherjil and Bengio, Yoshua and C{\^o}t{\'e}, Marc-Alexandre and Hjelm, R Devon},
  journal={arXiv preprint arXiv:1906.08226},
  year={2019}
}