DrNAS: Dirichlet Neural Architecture Search

Related tags

Deep LearningDrNAS
Overview

DrNAS

About

Code accompanying the paper
ICLR'2021: DrNAS: Dirichlet Neural Architecture Search paper
Xiangning Chen*, Ruochen Wang*, Minhao Cheng*, Xiaocheng Tang, Cho-Jui Hsieh

This code is based on the implementation of NAS-Bench-201 and PC-DARTS.

This paper proposes a novel differentiable architecture search method by formulating it into a distribution learning problem. We treat the continuously relaxed architecture mixing weight as random variables, modeled by Dirichlet distribution. With recently developed pathwise derivatives, the Dirichlet parameters can be easily optimized with gradient-based optimizer in an end-to-end manner. This formulation improves the generalization ability and induces stochasticity that naturally encourages exploration in the search space. Furthermore, to alleviate the large memory consumption of differentiable NAS, we propose a simple yet effective progressive learning scheme that enables searching directly on large-scale tasks, eliminating the gap between search and evaluation phases. Extensive experiments demonstrate the effectiveness of our method. Specifically, we obtain a test error of 2.46% for CIFAR-10, 23.7% for ImageNet under the mobile setting. On NAS-Bench-201, we also achieve state-of-the-art results on all three datasets and provide insights for the effective design of neural architecture search algorithms.

Results

On NAS-Bench-201

The table below shows the test accuracy on NAS-Bench-201 space. We achieve the state-of-the-art results on all three datasets. On CIFAR-100, DrNAS even achieves the global optimal with no variance!

Method CIFAR-10 (test) CIFAR-100 (test) ImageNet-16-120 (test)
ENAS 54.30 ± 0.00 10.62 ± 0.27 16.32 ± 0.00
DARTS 54.30 ± 0.00 38.97 ± 0.00 18.41 ± 0.00
SNAS 92.77 ± 0.83 69.34 ± 1.98 43.16 ± 2.64
PC-DARTS 93.41 ± 0.30 67.48 ± 0.89 41.31 ± 0.22
DrNAS (ours) 94.36 ± 0.00 73.51 ± 0.00 46.34 ± 0.00
optimal 94.37 73.51 47.31

For every search process, we sample 100 architectures from the current Dirichlet distribution and plot their accuracy range along with the current architecture selected by Dirichlet mean (solid line). The figure below shows that the accuracy range of the sampled architectures starts very wide but narrows gradually during the search phase. It indicates that DrNAS learns to encourage exploration at the early stages and then gradually reduces it towards the end as the algorithm becomes more and more confident of the current choice. Moreover, the performance of our architectures can consistently match the best performance of the sampled architectures, indicating the effectiveness of DrNAS.

On DARTS Space (CIFAR-10)

DrNAS achieves an average test error of 2.46%, ranking top amongst recent NAS results.

Method Test Error (%) Params (M) Search Cost (GPU days)
ENAS 2.89 4.6 0.5
DARTS 2.76 ± 0.09 3.3 1.0
SNAS 2.85 ± 0.02 2.8 1.5
PC-DARTS 2.57 ± 0.07 3.6 0.1
DrNAS (ours) 2.46 ± 0.03 4.1 0.6

On DARTS Space (ImageNet)

DrNAS can perform a direct search on ImageNet and achieves a top-1 test error below 24.0%!

Method Top-1 Error (%) Params (M) Search Cost (GPU days)
DARTS* 26.7 4.7 1.0
SNAS* 27.3 4.3 1.5
PC-DARTS 24.2 5.3 3.8
DSNAS 25.7 - -
DrNAS (ours) 23.7 5.7 4.6

* not a direct search

Usage

Architecture Search

Search on NAS-Bench-201 Space: (3 datasets to choose from)

  • Data preparation: Please first download the 201 benchmark file and prepare the api follow this repository.

  • cd 201-space && python train_search.py

  • With Progressively Pruning: cd 201-space && python train_search_progressive.py

Search on DARTS Space:

  • Data preparation: For a direct search on ImageNet, we follow PC-DARTS to sample 10% and 2.5% images for earch class as train and validation.

  • CIFAR-10: cd DARTS-space && python train_search.py

  • ImageNet: cd DARTS-space && python train_search_imagenet.py

Architecture Evaluation

  • CIFAR-10: cd DARTS-space && python train.py --cutout --auxiliary

  • ImageNet: cd DARTS-space && python train_imagenet.py --auxiliary

Reference

If you find this code useful in your research please cite

@inproceedings{chen2021drnas,
    title={Dr{\{}NAS{\}}: Dirichlet Neural Architecture Search},
    author={Xiangning Chen and Ruochen Wang and Minhao Cheng and Xiaocheng Tang and Cho-Jui Hsieh},
    booktitle={International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=9FWas6YbmB3}
}

Related Publications

Owner
Xiangning Chen
UCLA CS Ph.D. Student
Xiangning Chen
Block Sparse movement pruning

Movement Pruning: Adaptive Sparsity by Fine-Tuning Magnitude pruning is a widely used strategy for reducing model size in pure supervised learning; ho

Hugging Face 54 Dec 20, 2022
Hierarchical Motion Encoder-Decoder Network for Trajectory Forecasting (HMNet)

Hierarchical Motion Encoder-Decoder Network for Trajectory Forecasting (HMNet) Our paper: https://arxiv.org/abs/2111.13324 We will release the complet

15 Oct 17, 2022
LLVM-based compiler for LightGBM gradient-boosted trees. Speeds up prediction by ≥10x.

LLVM-based compiler for LightGBM gradient-boosted trees. Speeds up prediction by ≥10x.

Simon Boehm 183 Jan 02, 2023
KITTI-360 Annotation Tool is a framework that developed based on python(cherrypy + jinja2 + sqlite3) as the server end and javascript + WebGL as the front end.

KITTI-360 Annotation Tool is a framework that developed based on python(cherrypy + jinja2 + sqlite3) as the server end and javascript + WebGL as the front end.

86 Dec 12, 2022
The code for our paper CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention.

CrossFormer This repository is the code for our paper CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention. Introduction Existin

cheerss 238 Jan 06, 2023
Using Machine Learning to Create High-Res Fine Art

BIG.art: Using Machine Learning to Create High-Res Fine Art How to use GLIDE and BSRGAN to create ultra-high-resolution paintings with fine details By

Robert A. Gonsalves 13 Nov 27, 2022
ICCV2021 Oral SA-ConvONet: Sign-Agnostic Optimization of Convolutional Occupancy Networks

Sign-Agnostic Convolutional Occupancy Networks Paper | Supplementary | Video | Teaser Video | Project Page This repository contains the implementation

64 Jan 05, 2023
PyTorch implementations of the NeRF model described in "NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis"

PyTorch NeRF and pixelNeRF NeRF: Tiny NeRF: pixelNeRF: This repository contains minimal PyTorch implementations of the NeRF model described in "NeRF:

Michael A. Alcorn 178 Dec 20, 2022
This is the official code of L2G, Unrolling and Recurrent Unrolling in Learning to Learn Graph Topologies.

Learning to Learn Graph Topologies This is the official code of L2G, Unrolling and Recurrent Unrolling in Learning to Learn Graph Topologies. Requirem

Stacy X PU 16 Dec 09, 2022
Minimal fastai code needed for working with pytorch

fastai_minima A mimal version of fastai with the barebones needed to work with Pytorch #all_slow Install pip install fastai_minima How to use This lib

Zachary Mueller 14 Oct 21, 2022
This application is the basic of automated online-class-joiner(for YıldızEdu) within the right time. Gets the ZOOM link by scheduled date and time.

This application is the basic of automated online-class-joiner(for YıldızEdu) within the right time. Gets the ZOOM link by scheduled date and time.

215355 1 Dec 16, 2021
This repo contains source code and materials for the TEmporally COherent GAN SIGGRAPH project.

TecoGAN This repository contains source code and materials for the TecoGAN project, i.e. code for a TEmporally COherent GAN for video super-resolution

Nils Thuerey 5.2k Jan 02, 2023
You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling

You Only Sample (Almost) Once: Linear Cost Self-Attention Via Bernoulli Sampling Transformer-based models are widely used in natural language processi

Zhanpeng Zeng 12 Jan 01, 2023
Code and results accompanying our paper titled Mixture Proportion Estimation and PU Learning: A Modern Approach at Neurips 2021 (Spotlight)

Mixture Proportion Estimation and PU Learning: A Modern Approach This repository is the official implementation of Mixture Proportion Estimation and P

Approximately Correct Machine Intelligence (ACMI) Lab 23 Dec 28, 2022
The implementation of the algorithm in the paper "Safe Deep Semi-Supervised Learning for Unseen-Class Unlabeled Data" published in ICML 2020.

DS3L This is the code for paper "Safe Deep Semi-Supervised Learning for Unseen-Class Unlabeled Data" published in ICML 2020. Setups The code is implem

Guolz 36 Oct 19, 2022
Focal and Global Knowledge Distillation for Detectors

FGD Paper: Focal and Global Knowledge Distillation for Detectors Install MMDetection and MS COCO2017 Our codes are based on MMDetection. Please follow

Mesopotamia 261 Dec 23, 2022
Pytorch Implementation of Residual Vision Transformers(ResViT)

ResViT Official Pytorch Implementation of Residual Vision Transformers(ResViT) which is described in the following paper: Onat Dalmaz and Mahmut Yurt

ICON Lab 41 Dec 08, 2022
No-Reference Image Quality Assessment via Transformers, Relative Ranking, and Self-Consistency

This repository contains the implementation for the paper: No-Reference Image Quality Assessment via Transformers, Relative Ranking, and Self-Consiste

Alireza Golestaneh 75 Dec 30, 2022
Cmsc11 arcade - Final Project for CMSC11

cmsc11_arcade Final Project for CMSC11 Developers: Limson, Mark Vincent Peñafiel

Gregory 1 Jan 18, 2022
Robotics with GPU computing

Robotics with GPU computing Cupoch is a library that implements rapid 3D data processing for robotics using CUDA. The goal of this library is to imple

Shirokuma 625 Jan 07, 2023