Adaptive Attention Span for Reinforcement Learning

Overview

Adaptive Transformers in RL

Official implementation of Adaptive Transformers in RL

In this work, we replicate several results from Stabilizing Transformers for RL on both Pong and rooms_select_nonmatching_object from DMLab-30.
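Stabilizing Transformers for RL (Parisotto et al., 2019) replaces each residual connection x + y of a Transformer-XL block with a learned gate g(x, y). The GRU-type gate from that paper computes r = σ(W_r y + U_r x), z = σ(W_z y + U_z x − b_g), ĥ = tanh(W_g y + U_g (r ⊙ x)) and g(x, y) = (1 − z) ⊙ x + z ⊙ ĥ. The pure-Python sketch below shows only that gate. It assumes, without confirmation from this repository, that its Stable Transformer uses the GRU variant. The function names, the dict-of-matrices layout and the bias value 2.0 are hypothetical.

```python
import math
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Dense matrix-vector product."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def sigmoid(v: Vector) -> Vector:
    return [1.0 / (1.0 + math.exp(-a)) for a in v]


def gru_gate(x: Vector, y: Vector, W: Dict[str, Matrix], U: Dict[str, Matrix],
             gate_bias: float) -> Vector:
    """GRU-style gate g(x, y) used in place of the residual connection x + y.

    x: residual-stream input to the sublayer; y: the sublayer's output.
    W and U hold the weight matrices under the keys "r", "z" and "g".
    """
    r = sigmoid([a + b for a, b in zip(matvec(W["r"], y), matvec(U["r"], x))])
    # Subtracting a positive bias keeps z small at initialization, so the
    # gate starts close to the identity map g(x, y) = x.
    z = sigmoid([a + b - gate_bias
                 for a, b in zip(matvec(W["z"], y), matvec(U["z"], x))])
    rx = [ri * xi for ri, xi in zip(r, x)]
    h = [math.tanh(a + b) for a, b in zip(matvec(W["g"], y), matvec(U["g"], rx))]
    return [(1.0 - zi) * xi + zi * hi for zi, xi, hi in zip(z, x, h)]


# With all weights at zero, the output is (1 - sigmoid(-2)) * x, roughly 0.88 * x,
# no matter how large the sublayer output y is.
d = 3
zeros = [[0.0] * d for _ in range(d)]
W = {k: zeros for k in "rzg"}
U = {k: zeros for k in "rzg"}
print(gru_gate([1.0, -2.0, 0.5], [5.0, 5.0, 5.0], W, U, gate_bias=2.0))
```

The bias term is the design choice that matters here: it biases the block toward passing the residual stream through unchanged early in training, which is the stabilizing effect the paper attributes to gating.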

We also extend the Stable Transformer architecture with Adaptive Attention Span in a partially observable (POMDP) reinforcement learning setting. To our knowledge, this is one of the first attempts to stabilize and explore Adaptive Attention Span in an RL domain.
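Adaptive Attention Span (Sukhbaatar et al., 2019) gives each attention head a learnable span z and multiplies its attention weights by a soft mask m_z(x) = min(max((R + z − x) / R, 0), 1), where x is the distance to the key and R is a fixed ramp width. Keys within z steps get full weight, keys beyond z + R get none, and the ramp keeps the mask differentiable in z. The sketch below is a forward-only illustration of that masking for a single query; it is not this repository's code, and the function names are hypothetical.

```python
import math
from typing import List


def soft_span_mask(distance: float, z: float, ramp: float) -> float:
    """m_z(x) = clamp((R + z - x) / R, 0, 1): 1 inside the span, 0 beyond z + R."""
    return min(max((ramp + z - distance) / ramp, 0.0), 1.0)


def span_masked_attention(scores: List[float], z: float, ramp: float) -> List[float]:
    """Attention weights of one query over its past keys.

    scores[d] is the raw attention score for the key d steps in the past.
    The mask scales the unnormalized weights, which are then renormalized.
    """
    peak = max(scores)  # subtract the max score for numerical stability
    weights = [soft_span_mask(d, z, ramp) * math.exp(s - peak)
               for d, s in enumerate(scores)]
    total = sum(weights)  # the d = 0 key has mask 1 whenever z >= 0 and ramp > 0
    return [w / total for w in weights]


print(span_masked_attention([0.0] * 5, z=1.0, ramp=2.0))
# → [0.4, 0.4, 0.2, 0.0, 0.0]
```

With equal scores, the two keys inside the span share most of the weight, the key on the ramp gets half as much, and keys past z + R are ignored entirely.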

Steps to replicate our results on your own machine

  1. Downloading DMLab:

  2. Downloading Atari: see Getting Started with Gym (http://gym.openai.com/docs/#getting-started-with-gym)

  3. Execution notes:

  • The experiments take around 4 hours for 6 million environment interactions on 32 vCPUs and 2 P100 GPUs. To run without a GPU, use the flag “--disable_cuda”.
  • For details on the other flags, see the top of train.py, which describes each one.
  • All experiments use a slightly revised version of the IMPALA implementation from torchbeast.
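The IMPALA learner in torchbeast corrects for the lag between the actors' policy μ and the learner's policy π with V-trace targets (Espeholt et al., 2018). The pure-Python sketch below renders the standard V-trace recursion for one trajectory. It assumes the "slightly revised" version used here keeps standard V-trace; it is not the repository's code, and the function name and argument layout are hypothetical. For brevity it reuses one clipping threshold for ρ in both the value targets and the policy-gradient advantages.

```python
from typing import List, Tuple


def vtrace(
    rhos: List[float],
    rewards: List[float],
    discounts: List[float],
    values: List[float],
    bootstrap_value: float,
    rho_bar: float = 1.0,
    c_bar: float = 1.0,
) -> Tuple[List[float], List[float]]:
    """V-trace value targets v_s and policy-gradient advantages for one trajectory.

    rhos[t]      importance ratio pi(a_t | x_t) / mu(a_t | x_t) (learner / actor)
    discounts[t] gamma, or 0.0 where the episode terminated at step t
    values[t]    learner's estimate V(x_t); bootstrap_value is V(x_T)
    """
    n = len(rewards)
    clipped_rhos = [min(r, rho_bar) for r in rhos]  # truncated rho_t
    cs = [min(r, c_bar) for r in rhos]  # truncated trace coefficients c_t
    next_values = values[1:] + [bootstrap_value]
    deltas = [
        clipped_rhos[t] * (rewards[t] + discounts[t] * next_values[t] - values[t])
        for t in range(n)
    ]
    # Backward recursion:
    # v_s - V(x_s) = delta_s + gamma_s * c_s * (v_{s+1} - V(x_{s+1}))
    vs_minus_v = [0.0] * n
    acc = 0.0
    for t in reversed(range(n)):
        acc = deltas[t] + discounts[t] * cs[t] * acc
        vs_minus_v[t] = acc
    vs = [v + d for v, d in zip(values, vs_minus_v)]
    # Policy-gradient advantages bootstrap from the next step's V-trace target.
    next_vs = vs[1:] + [bootstrap_value]
    advantages = [
        clipped_rhos[t] * (rewards[t] + discounts[t] * next_vs[t] - values[t])
        for t in range(n)
    ]
    return vs, advantages


# On-policy (all rhos = 1) with zero value estimates, v_s reduces to the
# discounted return: about [1.9, 1.0] here.
vs, adv = vtrace([1.0, 1.0], [1.0, 1.0], [0.9, 0.9], [0.0, 0.0], 0.0)
print(vs, adv)
```

Clipping ρ at rho_bar and c at c_bar is what keeps the targets' variance bounded when actors lag many updates behind the learner.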

Snippets

Best performing adaptive attention span model on “rooms_select_nonmatching_object”:

python train.py --total_steps 20000000 \
--learning_rate 0.0001 --unroll_length 299 --num_buffers 40 --n_layer 3 \
--d_inner 1024 --xpid row85 --chunk_size 100 --action_repeat 1 \
--num_actors 32 --num_learner_threads 1 --sleep_length 20 \
--level_name rooms_select_nonmatching_object --use_adaptive \
--attn_span 400 --adapt_span_loss 0.025 --adapt_span_cache
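The Adaptive Attention Span paper adds an L1 penalty (λ / M) Σ_i z_i on the learned spans z_i of the M heads, so heads only keep long spans when the task pays for them. Reading --attn_span 400 as the maximum span and --adapt_span_loss 0.025 as λ is an assumption, not something confirmed from train.py. Likewise, --adapt_span_cache presumably trims cached keys that the soft mask zeroes out. The helpers below are hypothetical, and so is the ramp width of 32.

```python
import math
from typing import List


def span_penalty(spans: List[float], coeff: float) -> float:
    """L1 penalty (lambda / M) * sum_i z_i over the learned spans z_i of M heads."""
    return coeff * sum(spans) / len(spans)


def effective_span(z: float, ramp: float, max_span: int) -> int:
    """Number of past positions with a nonzero mask weight, capped at max_span.

    The soft mask is zero once the distance x reaches z + ramp, so only
    distances 0 .. ceil(z + ramp) - 1 need to be kept in a cache.
    """
    return min(math.ceil(z + ramp), max_span)


# Hypothetical values: two heads, max span 400, coefficient 0.025, ramp 32.
spans = [40.0, 120.0]
print(span_penalty(spans, coeff=0.025))
print([effective_span(z, ramp=32, max_span=400) for z in spans])
```

Under these assumptions, a head whose span shrinks to 40 needs only 72 cached positions instead of 400, which is where the memory and compute savings of an adaptive span come from.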

Best performing Stable Transformer on Pong:

python train.py --total_steps 10000000 \
--learning_rate 0.0004 --unroll_length 239 --num_buffers 40 \
--n_layer 3 --d_inner 1024 --xpid row82 --chunk_size 80 \
--action_repeat 1 --num_actors 32 --num_learner_threads 1 \
--sleep_length 5 --atari True
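In the snippets, unroll_length + 1 is an exact multiple of chunk_size: 240 = 3 × 80 for Pong and 300 = 3 × 100 for rooms_select_nonmatching_object. One plausible reading, assumed here rather than confirmed from train.py, is this: each torchbeast rollout stores unroll_length + 1 timesteps (the extra step is used for bootstrapping), and the learner feeds it to the transformer in pieces of chunk_size steps. The helper below is a hypothetical check of that divisibility.

```python
from typing import List, Sequence, TypeVar

Step = TypeVar("Step")


def split_into_chunks(rollout: Sequence[Step], chunk_size: int) -> List[Sequence[Step]]:
    """Split a rollout into consecutive, equally sized chunks."""
    if len(rollout) % chunk_size != 0:
        raise ValueError(
            f"rollout length {len(rollout)} is not a multiple of chunk_size {chunk_size}"
        )
    return [rollout[i:i + chunk_size] for i in range(0, len(rollout), chunk_size)]


# (unroll_length, chunk_size) pairs taken from the Pong and DMLab snippets.
for unroll_length, chunk_size in [(239, 80), (299, 100)]:
    rollout = list(range(unroll_length + 1))  # assumed: unroll_length + 1 stored steps
    chunks = split_into_chunks(rollout, chunk_size)
    print(unroll_length, chunk_size, len(chunks))
```

Both configurations split into three chunks per rollout. If you change --unroll_length, keeping unroll_length + 1 divisible by --chunk_size avoids a ragged final chunk under this reading.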

Best performing Stable Transformer on “rooms_select_nonmatching_object”:

python train.py --total_steps 20000000 \
--learning_rate 0.0001 --unroll_length 299 \
--num_buffers 40 --n_layer 3 --d_inner 1024 \
--xpid row79 --chunk_size 100 --action_repeat 1 \
--num_actors 32 --num_learner_threads 1 --sleep_length 20 \
--level_name rooms_select_nonmatching_object  --mem_len 200
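The Stable Transformer builds on Transformer-XL, which carries hidden states from earlier segments forward as a fixed-length memory. Here, --mem_len 200 is assumed to be that memory length. The class below is a hypothetical, single-buffer sketch of the bookkeeping, not the repository's code. Integers stand in for hidden-state vectors.

```python
from collections import deque
from typing import Deque, List


class SegmentMemory:
    """Keeps the most recent mem_len hidden states, Transformer-XL style.

    In a real model there is one such buffer per layer, and the stored
    states are detached so no gradient flows into earlier segments.
    """

    def __init__(self, mem_len: int) -> None:
        self.states: Deque[int] = deque(maxlen=mem_len)

    def context(self, segment: List[int]) -> List[int]:
        """Keys visible to this segment's queries (before causal masking)."""
        return list(self.states) + segment

    def update(self, hidden: List[int]) -> None:
        """Append the segment's states; the deque drops the oldest beyond mem_len."""
        self.states.extend(hidden)


# Hypothetical walk-through with chunk_size = 100 and mem_len = 200.
memory = SegmentMemory(mem_len=200)
for start in range(0, 400, 100):
    segment = list(range(start, start + 100))
    print(start, len(memory.context(segment)))
    memory.update(segment)
```

The attention context grows to 100, 200 and then 300 positions. From then on it stays at mem_len + chunk_size, so the cost per segment is bounded while information can still propagate across segment boundaries.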

Reference

If you find this repository useful, please cite it:

@article{kumar2020adaptive,
    title={Adaptive Transformers in RL},
    author={Shakti Kumar and Jerrod Parker and Panteha Naderian},
    year={2020},
    eprint={2004.03761},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}