PyTorch implementation of paper "StarEnhancer: Learning Real-Time and Style-Aware Image Enhancement" (ICCV 2021 Oral)

Overview

StarEnhancer

Abstract: Image enhancement is a subjective process whose targets vary with user preferences. In this paper, we propose a deep learning-based image enhancement method covering multiple tonal styles using only a single model dubbed StarEnhancer. It can transform an image from one tonal style to another, even if that style is unseen. With a simple one-time setting, users can customize the model to make the enhanced images more in line with their aesthetics. To make the method more practical, we propose a well-designed enhancer that can process a 4K-resolution image over 200 FPS but surpasses the contemporaneous single style image enhancement methods in terms of PSNR, SSIM, and LPIPS. Finally, our proposed enhancement method has good interactability, which allows the user to fine-tune the enhanced image using intuitive options.

Getting started

Install

The code was tested with PyTorch 1.8.1, CUDA 11.1, and cuDNN 8.0.5; nearby versions should also work.

pip install -r requirements.txt

We mainly train the model on 4 × RTX 2080 Ti GPUs, but training with a smaller mini-batch size also works.

Prepare

You can generate your own dataset or download the one we generated.

The final directory layout should look like this:

┬─ save_model
│   ├─ stylish.pth.tar
│   └─ ... (model & embedding)
└─ data
    ├─ train
    │   ├─ 01-Experts-A
    │   │   ├─ a0001.jpg
    │   │   └─ ... (id.jpg)
    │   └─ ... (style folder)
    ├─ valid
    │   └─ ... (style folder)
    └─ test
        └─ ... (style folder)
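
Before training, it can help to confirm that the layout above is in place. The following is a minimal sketch, not part of the repository: `check_layout` and `EXPECTED_SPLITS` are hypothetical names, and the checks assume only what the tree shows (a `save_model/stylish.pth.tar` file, and `train`, `valid`, and `test` folders that each hold style folders of `.jpg` files). If you have not yet downloaded or produced any model files, the `save_model` message can be ignored.

```python
from pathlib import Path

# Split names taken from the tree above (assumption: no other splits exist).
EXPECTED_SPLITS = ("train", "valid", "test")


def check_layout(root):
    """Return a list of human-readable problems found under `root`."""
    root = Path(root)
    problems = []

    if not (root / "save_model" / "stylish.pth.tar").is_file():
        problems.append("missing save_model/stylish.pth.tar")

    for split in EXPECTED_SPLITS:
        split_dir = root / "data" / split
        if not split_dir.is_dir():
            problems.append(f"missing data/{split}")
            continue
        # Each split should contain at least one style folder with images.
        style_dirs = [p for p in split_dir.iterdir() if p.is_dir()]
        if not style_dirs:
            problems.append(f"no style folders in data/{split}")
        for style in style_dirs:
            if not any(style.glob("*.jpg")):
                problems.append(f"no .jpg files in data/{split}/{style.name}")

    return problems


if __name__ == "__main__":
    for problem in check_layout("."):
        print(problem)
```

Run from the repository root, the script prints one line per problem and prints nothing when the layout matches.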

Download

Data and pretrained models are available on Google Drive.

Generate

  1. Download the raw data from the MIT-Adobe FiveK Dataset.
  2. Download the modified Lightroom database fivek.lrcat and replace the original database with it.
  3. Export the dataset in JPEG format at quality 100; see this issue for details.
  4. Run generate_dataset.py in the data folder to generate the dataset.

Train

Firstly, train the style encoder:

python train_stylish.py

Secondly, fetch the style embedding for each sample in the train set:

python fetch_embedding.py

Lastly, train the curve encoder and mapping network:

python train_enhancer.py
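
The three stages must run in this order, since each depends on the output of the previous one. As a convenience, they could be chained in a small driver script. This is only a sketch under the assumption that each script runs without extra arguments, as shown above; `STAGES`, `build_commands`, and `run_pipeline` are hypothetical names, not part of the repository.

```python
import subprocess
import sys

# The three stages, in the order given above.
STAGES = ("train_stylish.py", "fetch_embedding.py", "train_enhancer.py")


def build_commands(python=sys.executable):
    """Return one argv list per stage, using the given interpreter."""
    return [[python, script] for script in STAGES]


def run_pipeline(python=sys.executable):
    """Run the stages sequentially, stopping at the first failure."""
    for cmd in build_commands(python):
        print("running:", " ".join(cmd))
        # check=True raises CalledProcessError on a non-zero exit code,
        # so a later stage never starts after a failed one.
        subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_pipeline()
```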

Test

Just run:

python test.py

Computing LPIPS requires about 10 GB of GPU memory. If an out-of-memory (OOM) error occurs, replace the following line

lpips_val = loss_fn_alex(output * 2 - 1, target_img * 2 - 1).item()

with

lpips_val = 0
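
Instead of editing the line by hand, the LPIPS call could be wrapped so that it falls back to 0 only when an OOM error actually occurs. The sketch below is an illustration, not repository code: `metric_or_fallback` is a hypothetical helper, and it assumes that a CUDA OOM is raised as a `RuntimeError` (or a subclass) whose message contains "out of memory".

```python
def metric_or_fallback(compute, fallback=0.0):
    """Call `compute()`; return `fallback` if it fails with an out-of-memory error."""
    try:
        return compute()
    except RuntimeError as err:
        # Assumption: OOM errors mention "out of memory" in their message.
        # Any other RuntimeError is re-raised unchanged.
        if "out of memory" not in str(err).lower():
            raise
        return fallback
```

In test.py this would be used as `lpips_val = metric_or_fallback(lambda: loss_fn_alex(output * 2 - 1, target_img * 2 - 1).item())`. As with the manual replacement, the reported LPIPS is not meaningful for any image where the fallback was used.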

Notes

Due to agreements, we are unable to release part of the source code. This repository provides a pure-Python implementation for research use. It differs from the paper in the following ways:

  1. The repository uses ResNet-18 without batch normalization (BN) as the backbone of the curve encoder, whereas the paper uses a more lightweight model.
  2. The paper implements the color transform function in CUDA, whereas the repository implements it with torch.gather.
  3. The repository removes some tricks used for training lightweight models.

Overall, this repository can achieve higher performance but is slightly slower.
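
To illustrate the gather-based approach mentioned in item 2, here is a minimal sketch of a per-channel curve lookup. It is not the repository's implementation: it uses NumPy's `np.take_along_axis` as a stand-in for `torch.gather`, and the function name `apply_curve`, the nearest-index quantization, and the `(C, N)` curve shape are all assumptions made for the example.

```python
import numpy as np


def apply_curve(image, curve):
    """Map each pixel through a per-channel lookup table.

    image: float array of shape (C, H, W) with values in [0, 1].
    curve: float array of shape (C, N), one N-entry table per channel.
    """
    c, h, w = image.shape
    n = curve.shape[1]
    # Quantize each pixel value to the nearest table index.
    idx = np.clip(np.rint(image * (n - 1)), 0, n - 1).astype(np.int64)
    flat_idx = idx.reshape(c, h * w)
    # Pick curve[ch, idx] for every pixel; the PyTorch analogue would be
    # torch.gather(curve, 1, flat_idx).
    out = np.take_along_axis(curve, flat_idx, axis=1)
    return out.reshape(c, h, w)
```

An identity table such as `np.tile(np.linspace(0.0, 1.0, 256), (3, 1))` leaves 8-bit pixel levels unchanged, which is a quick sanity check for the indexing.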

Comments
  • Multi-style, unpaired setting

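    Hello. In the multi-style, unpaired setting, would it be possible to swap the positions of source and target, and then pass the resulting output_A and output_B through the enhancer again to obtain recover_A and recover_B? Finally, one would compute l1_loss(source, recover_A) and l1_loss(target, recover_B), together with Triplet_loss(output_A, target, source) and Triplet_loss(output_B, source, target).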

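    import torch

    # AverageMeter, Feature_Extract_Model, and args are not defined in this snippet;
    # they are assumed to come from the poster's training script.
    def train(train_loader, mapping, enhancer, criterion, optimizer):
        losses = AverageMeter()
        # TripletMarginLoss takes (anchor, positive, negative).
        criterionTriplet = torch.nn.TripletMarginLoss(margin=1.0, p=2)
        FEModel = Feature_Extract_Model().cuda()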
    
        mapping.train()
        enhancer.train()
    
        for (source_img, source_center, target_img, target_center) in train_loader:
            source_img = source_img.cuda(non_blocking=True)
            source_center = source_center.cuda(non_blocking=True)
            target_img = target_img.cuda(non_blocking=True)
            target_center = target_center.cuda(non_blocking=True)
    
            style_A = mapping(source_center)
            style_B = mapping(target_center)
    
            output_A = enhancer(source_img, style_A, style_B)
            output_B = enhancer(target_img, style_B, style_A)
            recoverA = enhancer(output_A, style_B, style_A)
            recoverB = enhancer(output_B, style_A, style_B)
    
            source_img_feature = FEModel(source_img)
            target_img_feature = FEModel(target_img)
            output_A_feature = FEModel(output_A)
            output_B_feature = FEModel(output_B)
    
            loss_l1 = criterion(recoverA, source_img) + criterion(recoverB, target_img)
            loss_triplet = criterionTriplet(output_B_feature, source_img_feature, target_img_feature) + \
                           criterionTriplet(output_A_feature, target_img_feature, source_img_feature)
            loss = loss_l1 + loss_triplet
    
            losses.update(loss.item(), args.t_batch_size)
    
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
        return losses.avg
    
    opened by jxust01 4
  • Questions about dataset preparation

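    Hello. I would like to run your project on my own data. I currently have one set of input/output pairs. How were the other four styles among A-E in the training data generated, and can the target data for these styles be unpaired? If I have only one style, can I copy the same data into all of the A-E targets? Single-style training with train_enhancer.py requires an embeddings.npy file; is this file mandatory for single-style training?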

    opened by zener90818 4
  • Dataset processing

    Hello. In the fivek.lrcat you provided, I could not find the "(default) input with ExpertC" mentioned in the DeepUPE issue. Is the input for the single-style experiment "InputAsShotZeroed" or "(Q)InputZeroed with ExpertC WhiteBalance", as shown in the attached screenshot?

    opened by madfff 2
  • The results are not the same as the paper

    I am the author.

    Several people have emailed me to ask why the performance of the open-source model does not match the results in the paper. As stated in the README, the released model is not the model from the paper, although its performance is similar. The expected results are: PSNR: 25.41, SSIM: 0.942, LPIPS: 0.085

    If your results differ, the cause may be a different JPEG codec, which depends on the OpenCV version and how it was installed.

    You can uninstall OpenCV (whether it was installed with pip or conda) and reinstall it with pip (it must be pip, because conda installs a different JPEG codec):

    pip install opencv-python==4.5.5.62
    
    opened by IDKiro 0