Source code of the paper Meta-learning with an Adaptive Task Scheduler.

ATS

If you find this repository useful in your research, please cite the following paper:

@inproceedings{yao2021adaptive,
  title={Meta-learning with an Adaptive Task Scheduler},
  author={Yao, Huaxiu and Wang, Yu and Wei, Ying and Zhao, Peilin and Mahdavi, Mehrdad and Lian, Defu and Finn, Chelsea},
  booktitle={Proceedings of the Thirty-fifth Conference on Neural Information Processing Systems},
  year={2021} 
}

Miniimagenet

The processed miniimagenet dataset can be downloaded here. Assume the dataset has been downloaded and unzipped to /data/miniimagenet, with the following file structure:

-- miniimagenet  // /data/miniimagenet
  -- miniImagenet
    -- train_task_id.pkl
    -- test_task_id.pkl
    -- mini_imagenet_train.pkl
    -- mini_imagenet_test.pkl
    -- mini_imagenet_val.pkl
    -- training_classes_20000_2_new.npz
    -- training_classes_20000_4_new.npz

Then $datadir in the following commands should be set to /data/miniimagenet.
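
Before launching any run, it can help to confirm that the unzipped data matches the layout above. The following is a minimal POSIX sh sketch; the function name check_miniimagenet is hypothetical and not part of the repository, and the file list is copied from the tree above.

```shell
#!/bin/sh
# Hypothetical helper (not part of the ATS repo): check the miniImagenet layout
# listed above. $1 is the value you intend to pass as $datadir.
check_miniimagenet() {
  local root="$1" status=0 f
  for f in train_task_id.pkl test_task_id.pkl \
           mini_imagenet_train.pkl mini_imagenet_test.pkl mini_imagenet_val.pkl \
           training_classes_20000_2_new.npz training_classes_20000_4_new.npz; do
    # All files are expected inside the nested miniImagenet/ directory.
    if [ ! -f "$root/miniImagenet/$f" ]; then
      echo "missing: $root/miniImagenet/$f" >&2
      status=1
    fi
  done
  return "$status"
}
```

For example, `check_miniimagenet /data/miniimagenet && echo "layout OK"` reports each missing file on stderr and returns non-zero if any file is absent.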

ATS with noise = 0.6

First, the model needs to be pretrained without noise. The pretrained model has been uploaded to this repo, but you can also pretrain it yourself with the following commands:
(1) 1 shot:

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.0

(2) 5 shot:

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.0

Then move the model to the current directory:
(1) 1 shot:

mv $logdir/ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_1.metalr0.001.innerlr0.01.hidden32/model20000 ./model20000_1shot

(2) 5 shot:

mv $logdir/ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_5.metalr0.001.innerlr0.01.hidden32/model10000 ./model10000_5shot
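
The two mv commands above follow one pattern: the run directory encodes the shot count as ubs_<shot>, and the checkpoint iteration is 20000 for 1 shot and 10000 for 5 shot. The sketch below prints the mv command for either setting. The function name move_pretrained_cmd is hypothetical, and the directory pattern is inferred from the two commands above, so compare it with what actually appears under $logdir.

```shell
#!/bin/sh
# Hypothetical helper: print the mv command for a pretrained checkpoint.
# The run-directory pattern is inferred from the two mv commands above.
move_pretrained_cmd() {
  local shot="$1" iter run_dir
  case "$shot" in
    1) iter=20000 ;;  # 1-shot checkpoint moved above
    5) iter=10000 ;;  # 5-shot checkpoint moved above
    *) echo "unsupported shot: $shot" >&2; return 1 ;;
  esac
  run_dir="ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_${shot}.metalr0.001.innerlr0.01.hidden32"
  echo "mv $logdir/$run_dir/model$iter ./model${iter}_${shot}shot"
}
```

Once the printed command looks right, it can be executed with `eval "$(move_pretrained_cmd 1)"`, assuming $logdir contains no spaces.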

With this pretrained model, we can run both uniform sampling and ATS sampling. For ATS, the commands are:
(1) 1 shot

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0 --num_classes 5 --metatrain_iterations 30000 --replace 0 --noise 0.6 --logdir $logdir --sampling_method ATS --buffer_size 10  --temperature 0.1 --scheduler_lr 0.001 --warmup 2000 --pretrain_iter 20000

(2) 5 shot

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --replace 0 --noise 0.6 --logdir $logdir --sampling_method ATS --buffer_size 10 --utility_function sample --temperature 0.1 --scheduler_lr 0.001 --warmup 2000 --pretrain_iter 10000
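
The two ATS commands differ in --update_batch_size, in --pretrain_iter (20000 for 1 shot, 10000 for 5 shot, matching the iteration of the checkpoint moved above), and in the --utility_function sample flag that only the 5-shot command sets. The sketch below prints either command; ats_noise_cmd is a hypothetical name, and the pairing between --pretrain_iter and the ./model<iter>_<shot>shot file is an assumption based on the matching numbers, not something confirmed here.

```shell
#!/bin/sh
# Hypothetical helper: print the ATS (noise = 0.6) command for a shot count.
# Flags are copied from the two commands above; $datadir and $logdir must be set.
ats_noise_cmd() {
  local shot="$1" iter extra=""
  case "$shot" in
    1) iter=20000 ;;
    5) iter=10000; extra=" --utility_function sample" ;;  # only the 5-shot command sets this
    *) echo "unsupported shot: $shot" >&2; return 1 ;;
  esac
  echo "python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir" \
       "--num_updates 5 --num_updates_test 10 --update_batch_size $shot" \
       "--update_batch_size_eval 15 --resume 0 --num_classes 5 --metatrain_iterations 30000" \
       "--replace 0 --noise 0.6 --logdir $logdir --sampling_method ATS --buffer_size 10${extra}" \
       "--temperature 0.1 --scheduler_lr 0.001 --warmup 2000 --pretrain_iter $iter"
}
```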

For uniform sampling, the model trained with uniform sampling is then finetuned on the validation set. The commands are:
(1) 1 shot

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0 --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.6
mkdir -p models
mv $logdir/ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_1.metalr0.001.innerlr0.01.hidden32_noise0.6/model30000 ./models/ANIL_0.4_model_1shot
python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0 --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.6 --finetune

(2) 5 shot

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.6
mkdir -p models  # creates "models" only if it does not exist
mv $logdir/ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_5.metalr0.001.innerlr0.01.hidden32_noise0.6/model30000 ./models/ANIL_0.4_model_5shot
python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.6 --finetune
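
Each uniform-sampling block above trains with --noise 0.6, moves the 30000-iteration checkpoint into ./models under the name ANIL_0.4_model_<shot>shot used above, and then reruns with --finetune. The POSIX sh sketch below prints these steps by default and executes them only with DRY_RUN=0. The names run and uniform_noise_pipeline and the DRY_RUN variable are hypothetical, and the sketch assumes the training run writes its checkpoint under $logdir, as the pretraining mv commands suggest.

```shell
#!/bin/sh
# Hypothetical dry-run wrapper: print the command when DRY_RUN=1 (the default),
# execute it when DRY_RUN=0.
run() {
  if [ "${DRY_RUN:-1}" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Hypothetical helper: uniform sampling with noise 0.6, then finetuning.
# $1 = shot count; $datadir and $logdir must already be set.
uniform_noise_pipeline() {
  local shot="$1" run_dir
  run_dir="ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_${shot}.metalr0.001.innerlr0.01.hidden32_noise0.6"
  # Reuse the positional parameters to hold the shared training command.
  set -- python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir "$datadir" \
    --num_updates 5 --num_updates_test 10 --update_batch_size "$shot" \
    --update_batch_size_eval 15 --resume 0 --num_classes 5 \
    --metatrain_iterations 30000 --logdir "$logdir" --noise 0.6
  run "$@" || return            # 1) train with uniform sampling
  run mkdir -p models || return # 2) create ./models only if it is missing
  # 3) move the final checkpoint; assumes the run directory sits under $logdir
  run mv "$logdir/$run_dir/model30000" "./models/ANIL_0.4_model_${shot}shot" || return
  run "$@" --finetune           # 4) finetune using the validation set
}
```

Calling `uniform_noise_pipeline 1` only prints the four commands; rerun with DRY_RUN=0 set once they look right.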

ATS with limited budgets

In this setting, no pretraining is needed. You can run the following commands directly:
uniform sampling, 1 shot

python3 main.py --meta_batch_size 3 --datasource miniimagenet --datadir ./miniimagenet/ --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --limit_data 1 --logdir ../train_logs --limit_classes 16

uniform sampling, 5 shot

python3 main.py --meta_batch_size 3 --datasource miniimagenet --datadir ./miniimagenet/ --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --limit_data 1 --logdir ../train_logs --limit_classes 16

ATS 1 shot

python3 main.py --meta_batch_size 3 --datasource miniimagenet --datadir ./miniimagenet/ --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --replace 0 --limit_data 1 --logdir ../train_logs --sampling_method ATS --buffer_size 6 --utility_function sample --temperature 1 --warmup 0 --limit_classes 16

ATS 5 shot

python3 main.py --meta_batch_size 3 --datasource miniimagenet --datadir ./miniimagenet/ --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --replace 0 --limit_data 1 --logdir ../train_logs --sampling_method ATS --buffer_size 6 --utility_function sample --temperature 0.1 --warmup 0 --limit_classes 16
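
The four limited-budget commands share most flags. The ATS runs add --replace 0 and the scheduler flags, and the 1-shot ATS run uses --temperature 1 while the 5-shot run uses 0.1. The sketch below prints any of the four commands with the flags in the same order as above; limited_budget_cmd is a hypothetical name, and any method other than ATS yields the uniform-sampling command.

```shell
#!/bin/sh
# Hypothetical helper: print one limited-budget command (no pretraining needed).
# $1 = ATS or uniform, $2 = shot count. Flags and their order follow the commands above.
limited_budget_cmd() {
  local method="$1" shot="$2" replace="" ats="" temp
  if [ "$method" = "ATS" ]; then
    case "$shot" in
      1) temp=1 ;;    # the 1-shot ATS command uses temperature 1
      5) temp=0.1 ;;  # the 5-shot ATS command uses temperature 0.1
      *) echo "unsupported shot: $shot" >&2; return 1 ;;
    esac
    replace=" --replace 0"
    ats=" --sampling_method ATS --buffer_size 6 --utility_function sample --temperature $temp --warmup 0"
  fi
  echo "python3 main.py --meta_batch_size 3 --datasource miniimagenet --datadir ./miniimagenet/" \
       "--num_updates 5 --num_updates_test 10 --update_batch_size $shot --update_batch_size_eval 15" \
       "--resume 0 --num_classes 5 --metatrain_iterations 30000${replace} --limit_data 1" \
       "--logdir ../train_logs${ats} --limit_classes 16"
}
```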

Drug

The processed dataset can be downloaded here. Assume the dataset has been downloaded and unzipped to /data/drug, with the following structure:

-- drug  // /data/drug
  -- ci9b00375_si_001.txt  
  -- compound_fp.npy               
  -- drug_split_id_group2.pickle  
  -- drug_split_id_group6.pickle
  -- ci9b00375_si_002.txt  
  -- drug_split_id_group17.pickle  
  -- drug_split_id_group3.pickle  
  -- drug_split_id_group9.pickle
  -- ci9b00375_si_003.txt  
  -- drug_split_id_group1.pickle   
  -- drug_split_id_group4.pickle  
  -- important_readme.md

Then $datadir in the following commands should be set to /data/ (the directory that contains drug/).
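
Unlike the miniImagenet setup, $datadir here is the parent of the drug/ directory. The sketch below checks for the data files listed above (important_readme.md is skipped); check_drug is a hypothetical name and not part of the repository.

```shell
#!/bin/sh
# Hypothetical helper (not part of the ATS repo): check the drug files listed above.
# $1 is the value passed as --data_dir, i.e. the PARENT of the drug/ directory.
check_drug() {
  local root="${1%/}/drug" status=0 f g  # strip a trailing slash, then append drug
  for f in ci9b00375_si_001.txt ci9b00375_si_002.txt ci9b00375_si_003.txt compound_fp.npy; do
    [ -f "$root/$f" ] || { echo "missing: $root/$f" >&2; status=1; }
  done
  # Split files present in the listing above: groups 1, 2, 3, 4, 6, 9 and 17.
  for g in 1 2 3 4 6 9 17; do
    [ -f "$root/drug_split_id_group$g.pickle" ] || { echo "missing: $root/drug_split_id_group$g.pickle" >&2; status=1; }
  done
  return "$status"
}
```

For example, `check_drug /data/` checks /data/drug/compound_fp.npy and the other files.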

ATS with noise = 4

Uniform Sampling:

python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --noise 4 --data_dir $datadir
python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --noise 4 --data_dir $datadir --train 0

ATS:

python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --sampling_method ATS --noise 4 --data_dir $datadir
python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --sampling_method ATS --noise 4 --data_dir $datadir --train 0
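
Each drug setting is a pair of commands: the second repeats the first with --train 0 appended, presumably to evaluate the trained model instead of training it. The sketch below prints such a pair for whatever extra flags are passed in, such as --sampling_method ATS or --noise 4; drug_cmds is a hypothetical name, and the base flags are copied from the commands above.

```shell
#!/bin/sh
# Hypothetical helper: print a drug training command and its --train 0 counterpart.
# Extra flags (e.g. --sampling_method ATS, --noise 4) are inserted before --data_dir.
drug_cmds() {
  local base extra cmd
  base="python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17"
  extra="$*"
  cmd="$base${extra:+ $extra} --data_dir $datadir"
  echo "$cmd"            # first command: training
  echo "$cmd --train 0"  # second command: same flags plus --train 0
}
```

With $datadir set, `drug_cmds --sampling_method ATS --noise 4` reproduces the ATS pair above, and `drug_cmds` with no arguments gives the full-budget uniform-sampling pair.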

ATS with full budgets

Uniform Sampling:

python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --data_dir $datadir
python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --data_dir $datadir --train 0

ATS:

python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --sampling_method ATS --data_dir $datadir
python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --sampling_method ATS --data_dir $datadir --train 0

For ATS, to change how the loss that serves as the scheduler's input is calculated, add --simple_loss to the end of the ATS commands above.

Owner: Huaxiu Yao, Postdoctoral Scholar