Ensemble Knowledge Guided Sub-network Search and Fine-tuning for Filter Pruning

Overview

Ensemble Knowledge Guided Sub-network Search and Fine-tuning for Filter Pruning

This repository is official Tensorflow implementation of paper:

Ensemble Knowledge Guided Sub-network Search and Fine-tuning for Filter Pruning [paper link]

and Tensorflow 2 example code for
   "Custom layers", "Custom training loop", "XLA (JIT)-compiling", "Distributed learing", and "Gradients accumulator".

Paper abstract

Conventional NAS-based pruning algorithms aim to find the sub-network with the best validation performance. However, validation performance does not successfully represent test performance, i.e., potential performance. Also, although fine-tuning the pruned network to restore the performance drop is an inevitable process, few studies have handled this issue. This paper proposes a novel sub-network search and fine-tuning method, i.e., Ensemble Knowledge Guidance (EKG). First, we experimentally prove that the fluctuation of the loss landscape is an effective metric to evaluate the potential performance. In order to search a sub-network with the smoothest loss landscape at a low cost, we propose a pseudo-supernet built by an ensemble sub-network knowledge distillation. Next, we propose a novel fine-tuning that re-uses the information of the search phase. We store the interim sub-networks, that is, the by-products of the search phase, and transfer their knowledge into the pruned network. Note that EKG is easy to be plugged-in and computationally efficient. For example, in the case of ResNet-50, about 45% of FLOPS is removed without any performance drop in only 315 GPU hours.


Conceptual visualization of the goal of the proposed method.

Contribution points and key features

  • As a new tool to measure the potential performance of sub-network in NAS-based pruning, the smoothness of the loss landscape is presented. Also, the experimental evidence that the loss landscape fluctuation has a higher correlation with the test performance than the validation performance is provided.
  • The pseudo-supernet based on an ensemble sub-network knowledge distillation is proposed to find a sub-network of smoother loss landscape without increasing complexity. It helps NAS-based pruning to prune all pre-trained networks, and also allows to find optimal sub-network(s) more accurately.
  • To our knowledge, this paper provides the world-first approach to store the information of the search phase in a memory bank and to reuse it in the fine-tuning phase of the pruned network. The proposed memory bank contributes to greatly improving the performance of the pruned network.

Requirement

  • Tensorflow >= 2.7 (I have tested on 2.7-2.8)
  • Pickle
  • tqdm

How to run

  1. Move to the codebase.
  2. Train and evaluate our model by the below command.
  # ResNet-56 on CIFAR10
  python train_cifar.py --gpu_id 0 --arch ResNet-56 --dataset CIFAR10 --search_target_rate 0.45 --train_path ../test
  python test.py --gpu_id 0 --arch ResNet-56 --dataset CIFAR10 --trained_param ../test/trained_param.pkl

Experimental results


(Left) Potential performance vs. validation loss (right) Potential performance vs. condition number. 50 sub-networks of ResNet-56 trained on CIFAR10 were used for this experiment. accurately.


Visualization of loss landscapes of sub-networks searched by various filter importance scoring algorithms.

Comparison with various pruning techniques for ResNet family trained on ImageNet.


Performance analysis in case of ResNet-50 trained on ImageNet-2012. The left plot is the FLOPs reduction rate-Top-1 accuracy, and the right plot is the GPU hours-Top-1 accuracy.

Reference

@article{lee2022ensemble,
  title        = {Ensemble Knowledge Guided Sub-network Search and Fine-tuning for Filter Pruning},
  author       = {Seunghyun Lee, Byung Cheol Song},
  year         = 2022,
  journal      = {arXiv preprint arXiv:2203.02651}
}

Owner
Seunghyun Lee
Knowledge distillation; Neural network light-weighting; Tensorflow
Seunghyun Lee
Open-sourcing the Slates Dataset for recommender systems research

FINN.no Recommender Systems Slate Dataset This repository accompany the paper "Dynamic Slate Recommendation with Gated Recurrent Units and Thompson Sa

FINN.no 48 Nov 28, 2022
This is a TensorFlow implementation for C2-Rec

This is a TensorFlow implementation for C2-Rec We refer to the repo SASRec. Requirements requirement.txt Datasets This repo includes Amazon Beauty dat

7 Nov 14, 2022
CLOCs: Camera-LiDAR Object Candidates Fusion for 3D Object Detection

CLOCs is a novel Camera-LiDAR Object Candidates fusion network. It provides a low-complexity multi-modal fusion framework that improves the performance of single-modality detectors. CLOCs operates on

Su Pang 254 Dec 16, 2022
A PyTorch Lightning Callback for pushing models to the Hugging Face Hub 🤗⚡️

hf-hub-lightning A callback for pushing lightning models to the Hugging Face Hub. Note: I made this package for myself, mostly...if folks seem to be i

Nathan Raw 27 Dec 14, 2022
Improving Factual Completeness and Consistency of Image-to-text Radiology Report Generation

Improving Factual Completeness and Consistency of Image-to-text Radiology Report Generation The reference code of Improving Factual Completeness and C

46 Dec 15, 2022
这是一个deeplabv3-plus-pytorch的源码,可以用于训练自己的模型。

DeepLabv3+:Encoder-Decoder with Atrous Separable Convolution语义分割模型在Pytorch当中的实现 目录 性能情况 Performance 所需环境 Environment 注意事项 Attention 文件下载 Download 训练步骤

Bubbliiiing 350 Dec 28, 2022
Official MegEngine implementation of CREStereo(CVPR 2022 Oral).

[CVPR 2022] Practical Stereo Matching via Cascaded Recurrent Network with Adaptive Correlation This repository contains MegEngine implementation of ou

MEGVII Research 309 Dec 30, 2022
A PyTorch implementation of the architecture of Mask RCNN

EDIT (AS OF 4th NOVEMBER 2019): This implementation has multiple errors and as of the date 4th, November 2019 is insufficient to be utilized as a reso

Sai Himal Allu 975 Dec 30, 2022
PyTorch Code for NeurIPS 2021 paper Anti-Backdoor Learning: Training Clean Models on Poisoned Data.

Anti-Backdoor Learning PyTorch Code for NeurIPS 2021 paper Anti-Backdoor Learning: Training Clean Models on Poisoned Data. The Anti-Backdoor Learning

Yige-Li 51 Dec 07, 2022
Not All Points Are Equal: Learning Highly Efficient Point-based Detectors for 3D LiDAR Point Clouds (CVPR 2022, Oral)

Not All Points Are Equal: Learning Highly Efficient Point-based Detectors for 3D LiDAR Point Clouds (CVPR 2022, Oral) This is the official implementat

Yifan Zhang 259 Dec 25, 2022
ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation

ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation This repository contains the source code of our paper, ESPNet (acc

Sachin Mehta 515 Dec 13, 2022
Official repository for "Action-Based Conversations Dataset: A Corpus for Building More In-Depth Task-Oriented Dialogue Systems"

Action-Based Conversations Dataset (ABCD) This respository contains the code and data for ABCD (Chen et al., 2021) Introduction Whereas existing goal-

ASAPP Research 49 Oct 09, 2022
Modeling Temporal Concept Receptive Field Dynamically for Untrimmed Video Analysis

Modeling Temporal Concept Receptive Field Dynamically for Untrimmed Video Analysis This is a PyTorch implementation of the model described in our pape

qzhb 6 Jul 08, 2021
PyTorch implementation of CDistNet: Perceiving Multi-Domain Character Distance for Robust Text Recognition

PyTorch implementation of CDistNet: Perceiving Multi-Domain Character Distance for Robust Text Recognition The unofficial code of CDistNet. Now, we ha

25 Jul 20, 2022
Computational Methods Course at UdeA. Forked and size reduced from:

Computational Methods for Physics & Astronomy Book version at: https://restrepo.github.io/ComputationalMethods by: Sebastian Bustamante 2014/2015 Dieg

Diego Restrepo 11 Sep 10, 2022
Validated, scalable, community developed variant calling, RNA-seq and small RNA analysis

Validated, scalable, community developed variant calling, RNA-seq and small RNA analysis. You write a high level configuration file specifying your in

Blue Collar Bioinformatics 917 Jan 03, 2023
Course content and resources for the AIAIART course.

AIAIART course This repo will house the notebooks used for the AIAIART course. Part 1 (first four lessons) ran via Discord in September/October 2021.

Jonathan Whitaker 492 Jan 06, 2023
TLXZoo - Pre-trained models based on TensorLayerX

Pre-trained models based on TensorLayerX. TensorLayerX is a multi-backend AI fra

TensorLayer Community 13 Dec 07, 2022
Code for one-stage adaptive set-based HOI detector AS-Net.

AS-Net Code for one-stage adaptive set-based HOI detector AS-Net. Mingfei Chen*, Yue Liao*, Si Liu, Zhiyuan Chen, Fei Wang, Chen Qian. "Reformulating

Mingfei Chen 45 Dec 09, 2022
Solving Zero-Shot Learning in Named Entity Recognition with Common Sense Knowledge

Zero-Shot Learning in Named Entity Recognition with Common Sense Knowledge Associated code for the paper Zero-Shot Learning in Named Entity Recognitio

Søren Hougaard Mulvad 13 Dec 25, 2022