Pretraining Representations For Data-Efficient Reinforcement Learning


Max Schwarzer, Nitarshan Rajkumar, Michael Noukhovitch, Ankesh Anand, Laurent Charlin, Devon Hjelm, Philip Bachman & Aaron Courville

This repo provides code for implementing SGI, which pretrains representations using a combination of self-predictive representations (SPR), goal-conditioned RL, and inverse dynamics modeling objectives.

Install

To install the requirements, follow these steps:

# Set the locale
export LANG=C.UTF-8
# Install PyTorch with CUDA 11.1 support
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
# Install the remaining requirements
pip install -r requirements.txt

# Finally, install the project
pip install --user -e .
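
After installation, you can optionally verify that the CUDA-enabled PyTorch build is active. This check is a suggestion, not part of the original instructions:

import torch

# Expect something like 1.8.1+cu111; cuda.is_available() should be True
# on a machine with a working CUDA 11.1 setup.
print(torch.__version__)
print(torch.cuda.is_available())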

Usage

The release branch is the default branch and contains the latest stable changes.

  • To run SGI:
  1. Download the DQN replay dataset from https://research.google/tools/datasets/dqn-replay/
    • Or substitute your own pre-training data! The codebase expects a series of .gz files, one each for observations, actions and terminals (see the loading sketch after the commands below).
  2. To pretrain with SGI (options prefixed with + are Hydra overrides that add keys not present in the default config):
python -m scripts.run public=True model_folder=./ offline.runner.save_every=2500 \
    env.game=pong seed=1 offline_model_save={your model name} \
    offline.runner.epochs=10 offline.runner.dataloader.games=[Pong] \
    offline.runner.no_eval=1 \
    +offline.algo.goal_weight=1 \
    +offline.algo.inverse_model_weight=1 \
    +offline.algo.spr_weight=1 \
    +offline.algo.target_update_tau=0.01 \
    +offline.agent.model_kwargs.momentum_tau=0.01 \
    do_online=False \
    algo.batch_size=256 \
    +offline.agent.model_kwargs.noisy_nets_std=0 \
    offline.runner.dataloader.dataset_on_disk=True \
    offline.runner.dataloader.samples=1000000 \
    offline.runner.dataloader.checkpoints='{your checkpoints}' \
    offline.runner.dataloader.num_workers=2 \
    offline.runner.dataloader.data_path={your data dir} \
    offline.runner.dataloader.tmp_data_path=./ 
  3. To fine-tune with SGI:
python -m scripts.run public=True env.game=pong seed=1 num_logs=10  \
    model_load={your_model_name} model_folder=./ \
    algo.encoder_lr=0.000001 algo.q_l1_lr=0.00003 algo.clip_grad_norm=-1 algo.clip_model_grad_norm=-1
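
If you substitute your own pre-training data (step 1 above), the sketch below shows one way to read the expected format: gzipped numpy arrays, one file each for observations, actions and terminals. The file names here are hypothetical; check src/offline_dataset.py for the exact naming and layout the dataloader expects.

import gzip

import numpy as np

# Hypothetical file names; the codebase expects one gzipped numpy
# array per quantity (observations, actions, terminals).
with gzip.open("observation_ckpt.0.gz", "rb") as f:
    observations = np.load(f)  # e.g. (N, 84, 84) uint8 frames
with gzip.open("action_ckpt.0.gz", "rb") as f:
    actions = np.load(f)       # (N,) discrete action indices
with gzip.open("terminal_ckpt.0.gz", "rb") as f:
    terminals = np.load(f)     # (N,) end-of-episode flags

print(observations.shape, actions.shape, terminals.shape)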

When reporting scores, we average across 10 fine-tuning seeds.
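
As a minimal illustration, assuming one final score per seed has already been collected:

import numpy as np

# Hypothetical final scores from 10 fine-tuning seeds on one game.
scores = np.array([1.2, 0.9, 1.1, 1.3, 1.0, 0.8, 1.2, 1.4, 0.9, 1.1])
print(f"mean: {scores.mean():.2f} +/- {scores.std():.2f} over {len(scores)} seeds")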

./scripts/experiments contains a number of example configurations, including for SGI-M, SGI-M/L and SGI-W, for both pre-training and fine-tuning. Each of these scripts can be launched by providing a game and seed, e.g., ./scripts/experiments/sgim_pretrain.sh pong 1. These scripts are provided primarily to illustrate the hyperparameters used for different experiments; you will likely need to modify the arguments in these scripts to point to your data and model directories.

Data for SGI-R and SGI-E is not included due to its size, but can be re-generated locally. Contact us for details.

What does each file do?

.
├── scripts
│   ├── run.py                # The main runner script to launch jobs.
│   ├── config.yaml           # The hydra configuration file, listing hyperparameters and options.
│   └── experiments           # Configurations for various experiments done by SGI.
│
├── src
│   ├── agent.py              # Implements the Agent API for action selection.
│   ├── algos.py              # Distributional RL loss and optimization.
│   ├── models.py             # Forward passes, network initialization.
│   ├── networks.py           # Network architecture and forward passes.
│   ├── offline_dataset.py    # Dataloader for offline data.
│   ├── gcrl.py               # Utils for SGI's goal-conditioned RL objective.
│   ├── rlpyt_atari_env.py    # Slightly modified Atari env from rlpyt.
│   ├── rlpyt_utils.py        # Utility methods that we use to extend rlpyt's functionality.
│   └── utils.py              # Command line arguments and helper functions.
│
└── requirements.txt          # Dependencies