PyTorch deep learning projects made easy.

Overview

PyTorch Template Project

PyTorch deep learning project made easy.

Requirements

  • Python >= 3.5 (3.6 recommended)
  • PyTorch >= 0.4 (1.2 recommended)
  • tqdm (Optional for test.py)
  • tensorboard >= 1.14 (see Tensorboard Visualization)

Features

  • Clear folder structure which is suitable for many deep learning projects.
  • .json config file support for convenient parameter tuning.
  • Customizable command line options for more convenient parameter tuning.
  • Checkpoint saving and resuming.
  • Abstract base classes for faster development:
    • BaseTrainer handles checkpoint saving/resuming, training process logging, and more.
    • BaseDataLoader handles batch generation, data shuffling, and validation data splitting.
    • BaseModel provides basic model summary.

Folder Structure

pytorch-template/
│
├── train.py - main script to start training
├── test.py - evaluation of trained model
│
├── config.json - holds configuration for training
├── parse_config.py - class to handle config file and cli options
│
├── new_project.py - initialize new project with template files
│
├── base/ - abstract base classes
│   ├── base_data_loader.py
│   ├── base_model.py
│   └── base_trainer.py
│
├── data_loader/ - anything about data loading goes here
│   └── data_loaders.py
│
├── data/ - default directory for storing input data
│
├── model/ - models, losses, and metrics
│   ├── model.py
│   ├── metric.py
│   └── loss.py
│
├── saved/
│   ├── models/ - trained models are saved here
│   └── log/ - default logdir for tensorboard and logging output
│
├── trainer/ - trainers
│   └── trainer.py
│
├── logger/ - module for tensorboard visualization and logging
│   ├── visualization.py
│   ├── logger.py
│   └── logger_config.json
│  
└── utils/ - small utility functions
    ├── util.py
    └── ...

Usage

The code in this repo is an MNIST example of the template. Try python train.py -c config.json to run code.

Config file format

Config files are in .json format:

{
  "name": "Mnist_LeNet",        // training session name
  "n_gpu": 1,                   // number of GPUs to use for training.
  
  "arch": {
    "type": "MnistModel",       // name of model architecture to train
    "args": {

    }                
  },
  "data_loader": {
    "type": "MnistDataLoader",         // selecting data loader
    "args":{
      "data_dir": "data/",             // dataset path
      "batch_size": 64,                // batch size
      "shuffle": true,                 // shuffle training data before splitting
      "validation_split": 0.1          // size of validation dataset. float(portion) or int(number of samples)
      "num_workers": 2,                // number of cpu processes to be used for data loading
    }
  },
  "optimizer": {
    "type": "Adam",
    "args":{
      "lr": 0.001,                     // learning rate
      "weight_decay": 0,               // (optional) weight decay
      "amsgrad": true
    }
  },
  "loss": "nll_loss",                  // loss
  "metrics": [
    "accuracy", "top_k_acc"            // list of metrics to evaluate
  ],                         
  "lr_scheduler": {
    "type": "StepLR",                  // learning rate scheduler
    "args":{
      "step_size": 50,          
      "gamma": 0.1
    }
  },
  "trainer": {
    "epochs": 100,                     // number of training epochs
    "save_dir": "saved/",              // checkpoints are saved in save_dir/models/name
    "save_freq": 1,                    // save checkpoints every save_freq epochs
    "verbosity": 2,                    // 0: quiet, 1: per epoch, 2: full
  
    "monitor": "min val_loss"          // mode and metric for model performance monitoring. set 'off' to disable.
    "early_stop": 10	                 // number of epochs to wait before early stop. set 0 to disable.
  
    "tensorboard": true,               // enable tensorboard visualization
  }
}

Add addional configurations if you need.

Using config files

Modify the configurations in .json config files, then run:

python train.py --config config.json

Resuming from checkpoints

You can resume from a previously saved checkpoint by:

python train.py --resume path/to/checkpoint

Using Multiple GPU

You can enable multi-GPU training by setting n_gpu argument of the config file to larger number. If configured to use smaller number of gpu than available, first n devices will be used by default. Specify indices of available GPUs by cuda environmental variable.

python train.py --device 2,3 -c config.json

This is equivalent to

CUDA_VISIBLE_DEVICES=2,3 python train.py -c config.py

Customization

Project initialization

Use the new_project.py script to make your new project directory with template files. python new_project.py ../NewProject then a new project folder named 'NewProject' will be made. This script will filter out unneccessary files like cache, git files or readme file.

Custom CLI options

Changing values of config file is a clean, safe and easy way of tuning hyperparameters. However, sometimes it is better to have command line options if some values need to be changed too often or quickly.

This template uses the configurations stored in the json file by default, but by registering custom options as follows you can change some of them using CLI flags.

# simple class-like object having 3 attributes, `flags`, `type`, `target`.
CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
options = [
    CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')),
    CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size'))
    # options added here can be modified by command line flags.
]

target argument should be sequence of keys, which are used to access that option in the config dict. In this example, target for the learning rate option is ('optimizer', 'args', 'lr') because config['optimizer']['args']['lr'] points to the learning rate. python train.py -c config.json --bs 256 runs training with options given in config.json except for the batch size which is increased to 256 by command line options.

Data Loader

  • Writing your own data loader
  1. Inherit BaseDataLoader

    BaseDataLoader is a subclass of torch.utils.data.DataLoader, you can use either of them.

    BaseDataLoader handles:

    • Generating next batch
    • Data shuffling
    • Generating validation data loader by calling BaseDataLoader.split_validation()
  • DataLoader Usage

    BaseDataLoader is an iterator, to iterate through batches:

    for batch_idx, (x_batch, y_batch) in data_loader:
        pass
  • Example

    Please refer to data_loader/data_loaders.py for an MNIST data loading example.

Trainer

  • Writing your own trainer
  1. Inherit BaseTrainer

    BaseTrainer handles:

    • Training process logging
    • Checkpoint saving
    • Checkpoint resuming
    • Reconfigurable performance monitoring for saving current best model, and early stop training.
      • If config monitor is set to max val_accuracy, which means then the trainer will save a checkpoint model_best.pth when validation accuracy of epoch replaces current maximum.
      • If config early_stop is set, training will be automatically terminated when model performance does not improve for given number of epochs. This feature can be turned off by passing 0 to the early_stop option, or just deleting the line of config.
  2. Implementing abstract methods

    You need to implement _train_epoch() for your training process, if you need validation then you can implement _valid_epoch() as in trainer/trainer.py

  • Example

    Please refer to trainer/trainer.py for MNIST training.

  • Iteration-based training

    Trainer.__init__ takes an optional argument, len_epoch which controls number of batches(steps) in each epoch.

Model

  • Writing your own model
  1. Inherit BaseModel

    BaseModel handles:

    • Inherited from torch.nn.Module
    • __str__: Modify native print function to prints the number of trainable parameters.
  2. Implementing abstract methods

    Implement the foward pass method forward()

  • Example

    Please refer to model/model.py for a LeNet example.

Loss

Custom loss functions can be implemented in 'model/loss.py'. Use them by changing the name given in "loss" in config file, to corresponding name.

Metrics

Metric functions are located in 'model/metric.py'.

You can monitor multiple metrics by providing a list in the configuration file, e.g.:

"metrics": ["accuracy", "top_k_acc"],

Additional logging

If you have additional information to be logged, in _train_epoch() of your trainer class, merge them with log as shown below before returning:

additional_log = {"gradient_norm": g, "sensitivity": s}
log.update(additional_log)
return log

Testing

You can test trained model by running test.py passing path to the trained checkpoint by --resume argument.

Validation data

To split validation data from a data loader, call BaseDataLoader.split_validation(), then it will return a data loader for validation of size specified in your config file. The validation_split can be a ratio of validation set per total data(0.0 <= float < 1.0), or the number of samples (0 <= int < n_total_samples).

Note: the split_validation() method will modify the original data loader Note: split_validation() will return None if "validation_split" is set to 0

Checkpoints

You can specify the name of the training session in config files:

"name": "MNIST_LeNet",

The checkpoints will be saved in save_dir/name/timestamp/checkpoint_epoch_n, with timestamp in mmdd_HHMMSS format.

A copy of config file will be saved in the same folder.

Note: checkpoints contain:

{
  'arch': arch,
  'epoch': epoch,
  'state_dict': self.model.state_dict(),
  'optimizer': self.optimizer.state_dict(),
  'monitor_best': self.mnt_best,
  'config': self.config
}

Tensorboard Visualization

This template supports Tensorboard visualization by using either torch.utils.tensorboard or TensorboardX.

  1. Install

    If you are using pytorch 1.1 or higher, install tensorboard by 'pip install tensorboard>=1.14.0'.

    Otherwise, you should install tensorboardx. Follow installation guide in TensorboardX.

  2. Run training

    Make sure that tensorboard option in the config file is turned on.

     "tensorboard" : true
    
  3. Open Tensorboard server

    Type tensorboard --logdir saved/log/ at the project root, then server will open at http://localhost:6006

By default, values of loss and metrics specified in config file, input images, and histogram of model parameters will be logged. If you need more visualizations, use add_scalar('tag', data), add_image('tag', image), etc in the trainer._train_epoch method. add_something() methods in this template are basically wrappers for those of tensorboardX.SummaryWriter and torch.utils.tensorboard.SummaryWriter modules.

Note: You don't have to specify current steps, since WriterTensorboard class defined at logger/visualization.py will track current steps.

Contribution

Feel free to contribute any kind of function or enhancement, here the coding style follows PEP8

Code should pass the Flake8 check before committing.

TODOs

  • Multiple optimizers
  • Support more tensorboard functions
  • Using fixed random seed
  • Support pytorch native tensorboard
  • tensorboardX logger support
  • Configurable logging layout, checkpoint naming
  • Iteration-based training (instead of epoch-based)
  • Adding command line option for fine-tuning

License

This project is licensed under the MIT License. See LICENSE for more details

Acknowledgements

This project is inspired by the project Tensorflow-Project-Template by Mahmoud Gemy

Owner
Victor Huang
Victor Huang
Unified API to facilitate usage of pre-trained "perceptor" models, a la CLIP

mmc installation git clone https://github.com/dmarx/Multi-Modal-Comparators cd 'Multi-Modal-Comparators' pip install poetry poetry build pip install d

David Marx 37 Nov 25, 2022
A wrapper around SageMaker ML Lineage Tracking extending ML Lineage to end-to-end ML lifecycles, including additional capabilities around Feature Store groups, queries, and other relevant artifacts.

ML Lineage Helper This library is a wrapper around the SageMaker SDK to support ease of lineage tracking across the ML lifecycle. Lineage artifacts in

AWS Samples 12 Nov 01, 2022
Colour detection is necessary to recognize objects, it is also used as a tool in various image editing and drawing apps.

Colour Detection On Image Colour detection is the process of detecting the name of any color. Simple isn’t it? Well, for humans this is an extremely e

Astitva Veer Garg 1 Jan 13, 2022
Code to reproduce the results for Compositional Attention

Compositional-Attention This repository contains the official implementation for the paper Compositional Attention: Disentangling Search and Retrieval

Sarthak Mittal 58 Nov 30, 2022
DAFNe: A One-Stage Anchor-Free Deep Model for Oriented Object Detection

DAFNe: A One-Stage Anchor-Free Deep Model for Oriented Object Detection Code for our Paper DAFNe: A One-Stage Anchor-Free Deep Model for Oriented Obje

Steven Lang 58 Dec 19, 2022
LibMTL: A PyTorch Library for Multi-Task Learning

LibMTL LibMTL is an open-source library built on PyTorch for Multi-Task Learning (MTL). See the latest documentation for detailed introductions and AP

765 Jan 06, 2023
Official Implementation of DE-DETR and DELA-DETR in "Towards Data-Efficient Detection Transformers"

DE-DETRs By Wen Wang, Jing Zhang, Yang Cao, Yongliang Shen, and Dacheng Tao This repository is an official implementation of DE-DETR and DELA-DETR in

Wen Wang 61 Dec 12, 2022
Blender add-on: Add to Cameras menu: View → Camera, View → Add Camera, Camera → View, Previous Camera, Next Camera

Blender add-on: Camera additions In 3D view, it adds these actions to the View|Cameras menu: View → Camera : set the current camera to the 3D view Vie

German Bauer 11 Feb 08, 2022
A Free and Open Source Python Library for Multiobjective Optimization

Platypus What is Platypus? Platypus is a framework for evolutionary computing in Python with a focus on multiobjective evolutionary algorithms (MOEAs)

Project Platypus 424 Dec 18, 2022
[WWW 2021] Source code for "Graph Contrastive Learning with Adaptive Augmentation"

GCA Source code for Graph Contrastive Learning with Adaptive Augmentation (WWW 2021) For example, to run GCA-Degree under WikiCS, execute: python trai

Big Data and Multi-modal Computing Group, CRIPAC 97 Jan 07, 2023
Implementation of PersonaGPT Dialog Model

PersonaGPT An open-domain conversational agent with many personalities PersonaGPT is an open-domain conversational agent cpable of decoding personaliz

ILLIDAN Lab 42 Jan 01, 2023
PyExplainer: A Local Rule-Based Model-Agnostic Technique (Explainable AI)

PyExplainer PyExplainer is a local rule-based model-agnostic technique for generating explanations (i.e., why a commit is predicted as defective) of J

AI Wizards for Software Management (AWSM) Research Group 14 Nov 13, 2022
QuakeLabeler is a Python package to create and manage your seismic training data, processes, and visualization in a single place — so you can focus on building the next big thing.

QuakeLabeler Quake Labeler was born from the need for seismologists and developers who are not AI specialists to easily, quickly, and independently bu

Hao Mai 15 Nov 04, 2022
Implementation based on Paper - Learning a Probabilistic Latent Space of Object Shapes via 3D Generative-Adversarial Modeling

Implementation based on Paper - Learning a Probabilistic Latent Space of Object Shapes via 3D Generative-Adversarial Modeling

HamasKhan 3 Jul 08, 2022
A library for implementing Decentralized Graph Neural Network algorithms.

decentralized-gnn A package for implementing and simulating decentralized Graph Neural Network algorithms for classification of peer-to-peer nodes. De

Multimedia Knowledge and Social Analytics Lab 5 Nov 07, 2022
Recovering Brain Structure Network Using Functional Connectivity

Recovering-Brain-Structure-Network-Using-Functional-Connectivity Framework: Papers: This repository provides a PyTorch implementation of the models ad

5 Nov 30, 2022
(CVPR2021) DANNet: A One-Stage Domain Adaptation Network for Unsupervised Nighttime Semantic Segmentation

DANNet: A One-Stage Domain Adaptation Network for Unsupervised Nighttime Semantic Segmentation CVPR2021(oral) [arxiv] Requirements python3.7 pytorch==

W-zx-Y 85 Dec 07, 2022
Model that predicts the probability of a Twitter user being anti-vaccination.

stylebody {text-align: justify}/style AVAXTAR: Anti-VAXx Tweet AnalyzeR AVAXTAR is a python package to identify anti-vaccine users on twitter. The

10 Sep 27, 2022
[2021][ICCV][FSNet] Full-Duplex Strategy for Video Object Segmentation

Full-Duplex Strategy for Video Object Segmentation (ICCV, 2021) Authors: Ge-Peng Ji, Keren Fu, Zhe Wu, Deng-Ping Fan*, Jianbing Shen, & Ling Shao This

Daniel-Ji 55 Dec 22, 2022
This repo contains source code and materials for the TEmporally COherent GAN SIGGRAPH project.

TecoGAN This repository contains source code and materials for the TecoGAN project, i.e. code for a TEmporally COherent GAN for video super-resolution

Nils Thuerey 5.2k Jan 02, 2023