Official PyTorch implementation of "Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets" (ICLR 2021)

Related tags

Deep LearningMetaD2A
Overview

Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets

This is the official PyTorch implementation for the paper Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets (ICLR 2021) : https://openreview.net/forum?id=rkQuFUmUOg3.

Abstract

Despite the success of recent Neural Architecture Search (NAS) methods on various tasks which have shown to output networks that largely outperform human-designed networks, conventional NAS methods have mostly tackled the optimization of searching for the network architecture for a single task (dataset), which does not generalize well across multiple tasks (datasets). Moreover, since such task-specific methods search for a neural architecture from scratch for every given task, they incur a large computational cost, which is problematic when the time and monetary budget are limited. In this paper, we propose an efficient NAS framework that is trained once on a database consisting of datasets and pretrained networks and can rapidly search a neural architecture for a novel dataset. The proposed MetaD2A (Meta Dataset-to-Architecture) model can stochastically generate graphs (architectures) from a given set (dataset) via a cross-modal latent space learned with amortized meta-learning. Moreover, we also propose a meta-performance predictor to estimate and select the best architecture without direct training on target datasets. The experimental results demonstrate that our model meta-learned on subsets of ImageNet-1K and architectures from NAS-Bench 201 search space successfully generalizes to multiple benchmark datasets including CIFAR-10 and CIFAR-100, with an average search time of 33 GPU seconds. Even under a large search space, MetaD2A is 5.5K times faster than NSGANetV2, a transferable NAS method, with comparable performance. We believe that the MetaD2A proposes a new research direction for rapid NAS as well as ways to utilize the knowledge from rich databases of datasets and architectures accumulated over the past years.

Framework of MetaD2A Model

Prerequisites

  • Python 3.6 (Anaconda)
  • PyTorch 1.6.0
  • CUDA 10.2
  • python-igraph==0.8.2
  • tqdm==4.50.2
  • torchvision==0.7.0
  • python-igraph==0.8.2
  • nas-bench-201==1.3
  • scipy==1.5.2

If you are not familiar with preparing conda environment, please follow the below instructions

$ conda create --name metad2a python=3.6
$ conda activate metad2a
$ conda install pytorch==1.6.0 torchvision cudatoolkit=10.2 -c pytorch
$ pip install nas-bench-201
$ conda install -c conda-forge tqdm
$ conda install -c conda-forge python-igraph
$ pip install scipy

And for data preprocessing,

$ pip install requests

Hardware Spec used for experiments of the paper

  • GPU: A single Nvidia GeForce RTX 2080Ti
  • CPU: Intel(R) Xeon(R) Silver 4114 CPU @ 2.20GHz

NAS-Bench-201

Go to the folder for NAS-Bench-201 experiments (i.e. MetaD2A_nas_bench_201)

$ cd MetaD2A_nas_bench_201

Data Preparation

To download preprocessed data files, run get_files/get_preprocessed_data.py:

$ python get_files/get_preprocessed_data.py

It will take some time to download and preprocess each dataset.

To download MNIST, Pets and Aircraft Datasets, run get_files/get_{DATASET}.py

$ python get_files/get_mnist.py
$ python get_files/get_aircraft.py
$ python get_files/get_pets.py

Other datasets such as Cifar10, Cifar100, SVHN will be automatically downloaded when you load dataloader by torchvision.

If you want to use your own dataset, please first make your own preprocessed data, by modifying process_dataset.py .

$ process_dataset.py

MetaD2A Evaluation (Meta-Test)

You can download trained checkpoint files for generator and predictor

$ python get_files/get_checkpoint.py
$ python get_files/get_predictor_checkpoint.py

1. Evaluation on Cifar10 and Cifar100

By set --data-name as the name of dataset (i.e. cifar10, cifar100), you can evaluate the specific dataset only

# Meta-testing for generator 
$ python main.py --gpu 0 --model generator --hs 56 --nz 56 --test --load-epoch 400 --num-gen-arch 500 --data-name {DATASET_NAME}

After neural architecture generation is completed, meta-performance predictor selects high-performing architectures among the candidates

# Meta-testing for predictor
$ python main.py --gpu 0 --model predictor --hs 512 --nz 56 --test --num-gen-arch 500 --data-name {DATASET_NAME}

2. Evaluation on Other Datasets

By set --data-name as the name of dataset (i.e. mnist, svhn, aircraft, pets), you can evaluate the specific dataset only

# Meta-testing for generator
$ python main.py --gpu 0 --model generator --hs 56 --nz 56 --test --load-epoch 400 --num-gen-arch 50 --data-name {DATASET_NAME}

After neural architecture generation is completed, meta-performance predictor selects high-performing architectures among the candidates

# Meta-testing for predictor
$ python main.py --gpu 0 --model predictor --hs 512 --nz 56 --test --num-gen-arch 50 --data-name {DATASET_NAME}

Meta-Training MetaD2A Model

You can train the generator and predictor as follows

# Meta-training for generator
$ python main.py --gpu 0 --model generator --hs 56 --nz 56 
                 
# Meta-training for predictor
$ python main.py --gpu 0 --model predictor --hs 512 --nz 56 

Results

The results of training architectures which are searched by meta-trained MetaD2A model for each dataset

Accuracy

CIFAR10 CIFAR100 MNIST SVHN Aircraft Oxford-IIT Pets
PC-DARTS 93.66±0.17 66.64±0.04 99.66±0.04 95.40±0.67 46.08±7.00 25.31±1.38
MetaD2A (Ours) 94.37±0.03 73.51±0.00 99.71±0.08 96.34±0.37 58.43±1.18 41.50±4.39

Search Time (GPU Sec)

CIFAR10 CIFAR100 MNIST SVHN Aircraft Oxford-IIT Pets
PC-DARTS 10395 19951 24857 31124 3524 2844
MetaD2A (Ours) 69 96 7 7 10 8

MobileNetV3 Search Space

Go to the folder for MobileNetV3 Search Space experiments (i.e. MetaD2A_mobilenetV3)

$ cd MetaD2A_mobilenetV3

And follow README.md written for experiments of MobileNetV3 Search Space

Citation

If you found the provided code useful, please cite our work.

@inproceedings{
    lee2021rapid,
    title={Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets},
    author={Hayeon Lee and Eunyoung Hyung and Sung Ju Hwang},
    booktitle={ICLR},
    year={2021}
}

Reference

Owner
Ph.D. student @ School of Computing, Korea Advanced Institute of Science and Technology (KAIST)
An official implementation of "Background-Aware Pooling and Noise-Aware Loss for Weakly-Supervised Semantic Segmentation" (CVPR 2021) in PyTorch.

BANA This is the implementation of the paper "Background-Aware Pooling and Noise-Aware Loss for Weakly-Supervised Semantic Segmentation". For more inf

CV Lab @ Yonsei University 59 Dec 12, 2022
Exporter for Storage Area Network (SAN)

SAN Exporter Prometheus exporter for Storage Area Network (SAN). We all know that each SAN Storage vendor has their own glossary of terms, health/perf

vCloud 32 Dec 16, 2022
Radar-to-Lidar: Heterogeneous Place Recognition via Joint Learning

radar-to-lidar-place-recognition This page is the coder of a pre-print, implemented by PyTorch. If you have some questions on this project, please fee

Huan Yin 37 Oct 09, 2022
An example of Scatterbrain implementation (combining local attention and Performer)

An example of Scatterbrain implementation (combining local attention and Performer)

HazyResearch 97 Jan 02, 2023
Wenzhou-Kean University AI-LAB

AI-LAB This is Wenzhou-Kean University AI-LAB. Our research interests are in Computer Vision and Natural Language Processing. Computer Vision Please g

WKU AI-LAB 10 May 05, 2022
Learning from Guided Play: A Scheduled Hierarchical Approach for Improving Exploration in Adversarial Imitation Learning Source Code

Learning from Guided Play: A Scheduled Hierarchical Approach for Improving Exploration in Adversarial Imitation Learning Source Code

STARS Laboratory 8 Sep 14, 2022
Unofficial PyTorch implementation of Guided Dropout

Unofficial PyTorch implementation of Guided Dropout This is a simple implementation of Guided Dropout for research. We try to reproduce the algorithm

2 Jan 07, 2022
Matlab Python Heuristic Battery Opt - SMOP conversion and manual conversion

SMOP is Small Matlab and Octave to Python compiler. SMOP translates matlab to py

Tom Xu 1 Jan 12, 2022
This is the official PyTorch implementation for "Mesa: A Memory-saving Training Framework for Transformers".

A Memory-saving Training Framework for Transformers This is the official PyTorch implementation for Mesa: A Memory-saving Training Framework for Trans

Zhuang AI Group 105 Dec 06, 2022
Repo for "TableParser: Automatic Table Parsing with Weak Supervision from Spreadsheets" at [email protected]

TableParser Repo for "TableParser: Automatic Table Parsing with Weak Supervision from Spreadsheets" at DS3 Lab 11 Dec 13, 2022

Codes and models of NeurIPS2021 paper - DominoSearch: Find layer-wise fine-grained N:M sparse schemes from dense neural networks

DominoSearch This is repository for codes and models of NeurIPS2021 paper - DominoSearch: Find layer-wise fine-grained N:M sparse schemes from dense n

11 Sep 10, 2022
Cave Generation using metaballs in Blender. Originally created by sdfgeoff, Edited by Myself (Archie Jaskowicz).

Blender-Cave-Generation Cave Generation using metaballs in Blender. Originally created by sdfgeoff, Edited by Myself (Archie Jaskowicz). Installation

2 Dec 28, 2022
MinkLoc++: Lidar and Monocular Image Fusion for Place Recognition

MinkLoc++: Lidar and Monocular Image Fusion for Place Recognition Paper: MinkLoc++: Lidar and Monocular Image Fusion for Place Recognition accepted fo

64 Dec 18, 2022
The first machine learning framework that encourages learning ML concepts instead of memorizing class functions.

SeaLion is designed to teach today's aspiring ml-engineers the popular machine learning concepts of today in a way that gives both intuition and ways of application. We do this through concise algori

Anish 324 Dec 27, 2022
Repository containing the PhD Thesis "Formal Verification of Deep Reinforcement Learning Agents"

Getting Started This repository contains the code used for the following publications: Probabilistic Guarantees for Safe Deep Reinforcement Learning (

Edoardo Bacci 5 Aug 31, 2022
Demystifying How Self-Supervised Features Improve Training from Noisy Labels

Demystifying How Self-Supervised Features Improve Training from Noisy Labels This code is a PyTorch implementation of the paper "[Demystifying How Sel

<a href=[email protected]"> 4 Oct 14, 2022
Repository for GNSS-based position estimation using a Deep Neural Network

Code repository accompanying our work on 'Improving GNSS Positioning using Neural Network-based Corrections'. In this paper, we present a Deep Neural

32 Dec 13, 2022
4st place solution for the PBVS 2022 Multi-modal Aerial View Object Classification Challenge - Track 1 (SAR) at PBVS2022

A Two-Stage Shake-Shake Network for Long-tailed Recognition of SAR Aerial View Objects 4st place solution for the PBVS 2022 Multi-modal Aerial View Ob

LinpengPan 5 Nov 09, 2022
Point cloud processing tool library.

Point Cloud ToolBox This point cloud processing tool library can be used to process point clouds, 3d meshes, and voxels. Environment python 3.7.5 Dep

ZhangXinyun 40 Dec 09, 2022
Base pretrained models and datasets in pytorch (MNIST, SVHN, CIFAR10, CIFAR100, STL10, AlexNet, VGG16, VGG19, ResNet, Inception, SqueezeNet)

This is a playground for pytorch beginners, which contains predefined models on popular dataset. Currently we support mnist, svhn cifar10, cifar100 st

Aaron Chen 2.4k Dec 28, 2022