Domain Generalization with MixStyle, ICLR'21.

MixStyle

This repo contains the code of our ICLR'21 paper, "Domain Generalization with MixStyle".

The OpenReview link is https://openreview.net/forum?id=6xHJ37MVxxp.

########## Updates ############

12-04-2021: A variable self._activated is added to MixStyle to better control the computational flow. To deactivate MixStyle without modifying the model code, one can do

def deactivate_mixstyle(m):
    if isinstance(m, MixStyle):
        m.set_activation_status(False)

model.apply(deactivate_mixstyle)

Similarly, to activate MixStyle, one can do

def activate_mixstyle(m):
    if isinstance(m, MixStyle):
        m.set_activation_status(True)

model.apply(activate_mixstyle)
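When MixStyle should be off for only part of a run, the two switches can be wrapped in a context manager that restores the previous state on exit. The following is a minimal sketch and not part of this repo: `mixstyle_disabled` is a hypothetical helper, and `_StubMixStyle` is a stand-in that only mimics the `_activated` flag and `set_activation_status()` method so the snippet runs on its own. In real code, pass the actual `MixStyle` class as `mixstyle_cls`.

```python
from contextlib import contextmanager

import torch.nn as nn


class _StubMixStyle(nn.Module):
    """Stand-in exposing the same on/off switch as MixStyle (illustration only)."""

    def __init__(self):
        super().__init__()
        self._activated = True

    def set_activation_status(self, status=True):
        self._activated = status

    def forward(self, x):
        return x


@contextmanager
def mixstyle_disabled(model, mixstyle_cls=_StubMixStyle):
    """Hypothetical helper: deactivate every MixStyle layer, then restore it."""
    layers = [m for m in model.modules() if isinstance(m, mixstyle_cls)]
    previous = [m._activated for m in layers]  # remember each layer's state
    for m in layers:
        m.set_activation_status(False)
    try:
        yield model
    finally:
        # Restore the original state even if the body raised an exception.
        for m, status in zip(layers, previous):
            m.set_activation_status(status)


model = nn.Sequential(nn.Conv2d(3, 8, 3), _StubMixStyle(), nn.ReLU())

with mixstyle_disabled(model):
    inside = model[1]._activated  # False while inside the block

after = model[1]._activated  # back to True afterwards
```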

Note that MixStyle has been included in Dassl.pytorch. See the code for details.

05-03-2021: You might also be interested in our recently released survey on domain generalization at https://arxiv.org/abs/2103.02503, which summarizes ten years of development in domain generalization, covering the history, datasets, related problems, methodologies, potential directions, and so on.

##############################

A brief introduction: The key idea of MixStyle is to probabilistically mix instance-level feature statistics of training samples across source domains. MixStyle improves model robustness to domain shift by implicitly synthesizing new domains at the feature level, which regularizes the training of convolutional neural networks. This idea is largely inspired by neural style transfer, which has shown that feature statistics are closely related to image style and that arbitrary image style transfer can therefore be achieved by swapping the feature statistics between a content image and a style image.
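In symbols, the mixing can be sketched as follows. The notation here is an assumption chosen to mirror the PyTorch implementation in this README: x is the feature map of one instance, x̃ is the feature map of the instance it is paired with after shuffling the batch, and μ(·) and σ(·) are the per-channel mean and standard deviation computed over the spatial dimensions.

```latex
\begin{aligned}
\lambda &\sim \mathrm{Beta}(\alpha, \alpha), \\
\mu_{\mathrm{mix}} &= \lambda\,\mu(x) + (1 - \lambda)\,\mu(\tilde{x}), \\
\sigma_{\mathrm{mix}} &= \lambda\,\sigma(x) + (1 - \lambda)\,\sigma(\tilde{x}), \\
\mathrm{MixStyle}(x) &= \sigma_{\mathrm{mix}} \odot \frac{x - \mu(x)}{\sigma(x)} + \mu_{\mathrm{mix}}.
\end{aligned}
```

With probability 1 - p, and always at test time, the module returns x unchanged.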

MixStyle is very easy to implement. Below we show the PyTorch code of MixStyle.

import random
import torch
import torch.nn as nn


class MixStyle(nn.Module):
    """MixStyle.

    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6):
        """
        Args:
          p (float): probability of using MixStyle.
          alpha (float): parameter of the Beta distribution.
          eps (float): small constant added to the variance for numerical stability.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha

        self._activated = True

    def __repr__(self):
        return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps})'

    def set_activation_status(self, status=True):
        self._activated = status

    def forward(self, x):
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

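        B = x.size(0)

        # Per-instance, per-channel statistics over the spatial dimensions (H, W).
        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        # Detach so that no gradient flows through the statistics.
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x - mu) / sig

        # One mixing weight per instance, drawn from Beta(alpha, alpha).
        lmda = self.beta.sample((B, 1, 1, 1))
        lmda = lmda.to(x.device)

        # Shuffle the batch to pair each instance with a randomly chosen one.
        perm = torch.randperm(B)
        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu * lmda + mu2 * (1 - lmda)
        sig_mix = sig * lmda + sig2 * (1 - lmda)

        # Re-style the normalized features with the mixed statistics.
        return x_normed * sig_mix + mu_mix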
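To get a feel for the `alpha` parameter, the sketch below samples from Beta(alpha, alpha) using only the Python standard library and reports how often the mixing weight falls close to 0 or 1. The helper name `edge_fraction` and the 0.1 margin are illustrative choices, not part of the repo. For alpha well below 1 (such as the default 0.1) the distribution is U-shaped, so most instances either largely keep their own statistics (weight near 1) or largely take those of the paired instance (weight near 0).

```python
import random


def edge_fraction(alpha, n=20000, margin=0.1, seed=0):
    """Fraction of Beta(alpha, alpha) samples within `margin` of 0 or 1."""
    rng = random.Random(seed)  # local generator, so global state is untouched
    samples = (rng.betavariate(alpha, alpha) for _ in range(n))
    return sum(s < margin or s > 1 - margin for s in samples) / n


for alpha in (0.1, 1.0, 10.0):
    print(f"alpha={alpha}: {edge_fraction(alpha):.2f} of samples near 0 or 1")
```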

How to apply MixStyle to your CNN models? Say you are using ResNet as the CNN architecture and want to apply MixStyle after the 1st and 2nd residual blocks. You can first instantiate the MixStyle module using

self.mixstyle = MixStyle(p=0.5, alpha=0.1)

during network construction (in __init__()), and then apply MixStyle in the forward pass like

def forward(self, x):
    x = self.conv1(x) # 1st convolution layer
    x = self.res1(x) # 1st residual block
    x = self.mixstyle(x)
    x = self.res2(x) # 2nd residual block
    x = self.mixstyle(x)
    x = self.res3(x) # 3rd residual block
    x = self.res4(x) # 4th residual block
    ...

In our paper, we have demonstrated the effectiveness of MixStyle on three tasks: image classification, person re-identification, and reinforcement learning. The source code for reproducing all experiments can be found in mixstyle-release/imcls, mixstyle-release/reid, and mixstyle-release/rl, respectively.

Takeaways on applying MixStyle to your tasks:

  • Applying MixStyle to multiple lower layers is generally better
  • Do not apply MixStyle to the last layer that is the closest to the prediction layer
  • Different tasks might favor different combinations
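Since the best placement depends on the task, it can help to make the MixStyle positions a constructor argument so that different combinations are easy to try. The sketch below is only an illustration under stated assumptions: `ToyNet`, its stage names, the `ms_layers` argument, and the choice of 7 output classes are all hypothetical, and the `MixStyle` class is a compact copy of the module from this README with the activation switch left out for brevity.

```python
import random

import torch
import torch.nn as nn


class MixStyle(nn.Module):
    """Compact copy of the README's MixStyle (activation switch omitted)."""

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6):
        super().__init__()
        self.p, self.eps = p, eps
        self.beta = torch.distributions.Beta(alpha, alpha)

    def forward(self, x):
        if not self.training or random.random() > self.p:
            return x
        B = x.size(0)
        mu = x.mean(dim=[2, 3], keepdim=True).detach()
        sig = (x.var(dim=[2, 3], keepdim=True) + self.eps).sqrt().detach()
        lmda = self.beta.sample((B, 1, 1, 1)).to(x.device)
        perm = torch.randperm(B, device=x.device)
        mu_mix = mu * lmda + mu[perm] * (1 - lmda)
        sig_mix = sig * lmda + sig[perm] * (1 - lmda)
        return (x - mu) / sig * sig_mix + mu_mix


class ToyNet(nn.Module):
    """Hypothetical 4-stage CNN; `ms_layers` names the stages followed by MixStyle."""

    def __init__(self, ms_layers=("stage1", "stage2"), num_classes=7):
        super().__init__()
        widths = [3, 16, 32, 64, 128]
        self.stages = nn.ModuleDict({
            f"stage{i}": nn.Sequential(
                nn.Conv2d(widths[i - 1], widths[i], 3, stride=2, padding=1),
                nn.BatchNorm2d(widths[i]),
                nn.ReLU(),
            )
            for i in range(1, 5)
        })
        unknown = set(ms_layers) - set(self.stages.keys())
        if unknown:
            raise ValueError(f"unknown stage names: {sorted(unknown)}")
        self.ms_layers = set(ms_layers)
        self.mixstyle = MixStyle(p=0.5, alpha=0.1)
        self.fc = nn.Linear(widths[-1], num_classes)

    def forward(self, x):
        for name, stage in self.stages.items():
            x = stage(x)
            if name in self.ms_layers:
                x = self.mixstyle(x)  # only after the selected stages
        return self.fc(x.mean(dim=[2, 3]))  # global average pooling + classifier


# MixStyle after the two lowest stages, none next to the prediction layer.
model = ToyNet(ms_layers=("stage1", "stage2"))
images = torch.randn(8, 3, 32, 32)

model.train()
train_logits = model(images)  # MixStyle may mix feature statistics here

model.eval()
with torch.no_grad():
    eval_logits = model(images)  # MixStyle is the identity in eval mode
```

Trying another combination is then a one-line change, for example `ToyNet(ms_layers=("stage1", "stage2", "stage3"))`.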

For more analytical studies, please read our paper at https://openreview.net/forum?id=6xHJ37MVxxp.

To cite MixStyle in your publications, please use the following BibTeX entry:

@inproceedings{zhou2021mixstyle,
  title={Domain Generalization with MixStyle},
  author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
  booktitle={ICLR},
  year={2021}
}