Domain Generalization with MixStyle, ICLR'21.

Overview

MixStyle

This repo contains the code of our ICLR'21 paper, "Domain Generalization with MixStyle".

The OpenReview link is https://openreview.net/forum?id=6xHJ37MVxxp.

########## Updates ############

12-04-2021: A variable, self._activated, has been added to MixStyle to give finer control over the computational flow. To deactivate MixStyle without modifying the model code, one can do

def deactivate_mixstyle(m):
    if isinstance(m, MixStyle):
        m.set_activation_status(False)

model.apply(deactivate_mixstyle)

Similarly, to activate MixStyle, one can do

def activate_mixstyle(m):
    if isinstance(m, MixStyle):
        m.set_activation_status(True)

model.apply(activate_mixstyle)

Note that MixStyle has been included in Dassl.pytorch. See the code for details.

05-03-2021: You might also be interested in our recently released survey on domain generalization at https://arxiv.org/abs/2103.02503, which summarizes ten years of development in domain generalization, covering its history, datasets, related problems, methodologies, potential directions, and more.

##############################

A brief introduction: The key idea of MixStyle is to probabilistically mix instance-level feature statistics of training samples across source domains. By implicitly synthesizing new domains at the feature level, MixStyle regularizes the training of convolutional neural networks and improves their robustness to domain shift. The idea is largely inspired by neural style transfer, which has shown that feature statistics are closely related to image style; arbitrary image style transfer can therefore be achieved by swapping the feature statistics of a content image with those of a style image.
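To illustrate the style-transfer observation above, here is a minimal sketch of swapping channel-wise statistics between two feature maps, in the spirit of adaptive instance normalization (AdaIN). The `adain` helper and the random tensors are illustrative assumptions, not code from this repo. MixStyle replaces this hard swap with a random interpolation of the statistics.

```python
import torch


def adain(content: torch.Tensor, style: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical helper: give `content` the per-channel mean/std of `style`.

    Both tensors have shape (B, C, H, W); statistics are computed per sample
    and per channel over the spatial dimensions (instance-level statistics).
    """
    mu_c = content.mean(dim=[2, 3], keepdim=True)
    sig_c = (content.var(dim=[2, 3], keepdim=True) + eps).sqrt()
    mu_s = style.mean(dim=[2, 3], keepdim=True)
    sig_s = (style.var(dim=[2, 3], keepdim=True) + eps).sqrt()
    # Remove the content features' own statistics, then apply the style's.
    return (content - mu_c) / sig_c * sig_s + mu_s


torch.manual_seed(0)
content = torch.randn(2, 3, 8, 8)
style = torch.randn(2, 3, 8, 8) * 4.0 + 2.0  # features with a different "style"
out = adain(content, style)
# `out` keeps the spatial layout of `content` but carries the channel-wise
# mean and standard deviation of `style`.
```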

MixStyle is very easy to implement. Below we show the PyTorch code of MixStyle.

import random
import torch
import torch.nn as nn


class MixStyle(nn.Module):
    """MixStyle.

    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6):
        """
        Args:
          p (float): probability of using MixStyle.
          alpha (float): parameter of the Beta distribution.
          eps (float): small constant added to the variance to avoid numerical issues (e.g., division by zero).
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha

        self._activated = True

    def __repr__(self):
        return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps})'

    def set_activation_status(self, status=True):
        self._activated = status

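    def forward(self, x):
        # MixStyle is applied only in training mode and when activated.
        if not self.training or not self._activated:
            return x

        # Apply MixStyle to the current mini-batch with probability p.
        if random.random() > self.p:
            return x

        B = x.size(0)

        # Instance-level, channel-wise statistics computed over H and W.
        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        # Gradients are not propagated through the statistics.
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x-mu) / sig

        # One mixing weight per sample, drawn from Beta(alpha, alpha).
        lmda = self.beta.sample((B, 1, 1, 1))
        lmda = lmda.to(x.device)

        # Pair each sample with a randomly chosen sample from the same batch
        # and interpolate their statistics.
        perm = torch.randperm(B)
        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu*lmda + mu2 * (1-lmda)
        sig_mix = sig*lmda + sig2 * (1-lmda)

        # Re-style the normalized features with the mixed statistics.
        return x_normed*sig_mix + mu_mix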
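Written out, the forward pass above computes the following. This is a restatement of the code, not an excerpt from the paper: x~ denotes the batch shuffled by perm, and mu(.) and sigma(.) are the per-instance, per-channel mean and standard deviation over the spatial dimensions (PyTorch's x.var is the unbiased sample variance by default), with eps added to the variance.

```latex
\mu(x)_{b,c} = \frac{1}{HW}\sum_{h,w} x_{b,c,h,w}, \qquad
\sigma(x)_{b,c} = \sqrt{\operatorname{Var}_{h,w}\!\left(x_{b,c,h,w}\right) + \epsilon}

\lambda_b \sim \mathrm{Beta}(\alpha, \alpha), \qquad \tilde{x} = x[\mathrm{perm}]

\gamma_{\mathrm{mix}} = \lambda\,\sigma(x) + (1-\lambda)\,\sigma(\tilde{x}), \qquad
\beta_{\mathrm{mix}} = \lambda\,\mu(x) + (1-\lambda)\,\mu(\tilde{x})

\mathrm{MixStyle}(x) = \gamma_{\mathrm{mix}} \odot \frac{x - \mu(x)}{\sigma(x)} + \beta_{\mathrm{mix}}
```

With the default alpha=0.1, Beta(alpha, alpha) is U-shaped, so lambda is usually close to 0 or 1 and each sample's statistics end up near either its own or its partner's. Because of the detach() calls, mu(x) and sigma(x) are treated as constants during backpropagation.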

How do you apply MixStyle to your CNN model? Say you are using ResNet as the CNN architecture and want to apply MixStyle after the 1st and 2nd residual blocks. You can first instantiate the MixStyle module using

self.mixstyle = MixStyle(p=0.5, alpha=0.1)

during network construction (in __init__()), and then apply MixStyle in the forward pass like

def forward(self, x):
    x = self.conv1(x) # 1st convolution layer
    x = self.res1(x) # 1st residual block
    x = self.mixstyle(x)
    x = self.res2(x) # 2nd residual block
    x = self.mixstyle(x)
    x = self.res3(x) # 3rd residual block
    x = self.res4(x) # 4th residual block
    ...
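The pseudocode above can be made concrete with a small, self-contained sketch. It assumes a toy four-stage CNN: the `block` helper, the `ToyNet` class and its `ms_layers` argument are hypothetical and not part of this repo. A condensed copy of MixStyle (without the activation switch) is repeated so the snippet runs on its own. As in the forward pass above, a single MixStyle instance is reused after `res1` and `res2`.

```python
import random

import torch
import torch.nn as nn


class MixStyle(nn.Module):
    """Condensed copy of the MixStyle module above (activation switch omitted)."""

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6):
        super().__init__()
        self.p, self.eps = p, eps
        self.beta = torch.distributions.Beta(alpha, alpha)

    def forward(self, x):
        if not self.training or random.random() > self.p:
            return x
        B = x.size(0)
        mu = x.mean(dim=[2, 3], keepdim=True).detach()
        sig = (x.var(dim=[2, 3], keepdim=True) + self.eps).sqrt().detach()
        lmda = self.beta.sample((B, 1, 1, 1)).to(x.device)
        perm = torch.randperm(B)
        mu_mix = mu * lmda + mu[perm] * (1 - lmda)
        sig_mix = sig * lmda + sig[perm] * (1 - lmda)
        return (x - mu) / sig * sig_mix + mu_mix


def block(c_in, c_out):
    # Hypothetical stand-in for a residual stage: conv -> BN -> ReLU, halving H and W.
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(c_out),
        nn.ReLU(inplace=True),
    )


class ToyNet(nn.Module):
    """Hypothetical 4-stage CNN; MixStyle follows the stages named in ms_layers."""

    def __init__(self, num_classes=10, ms_layers=("res1", "res2")):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.res1 = block(16, 32)
        self.res2 = block(32, 64)
        self.res3 = block(64, 128)
        self.res4 = block(128, 256)
        self.ms_layers = set(ms_layers)
        # MixStyle has no learnable parameters, so one instance can be shared.
        self.mixstyle = MixStyle(p=0.5, alpha=0.1)
        self.head = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        for name in ("res1", "res2", "res3", "res4"):
            x = getattr(self, name)(x)
            if name in self.ms_layers:
                x = self.mixstyle(x)
        x = x.mean(dim=[2, 3])  # global average pooling
        return self.head(x)


model = ToyNet()
images = torch.randn(8, 3, 32, 32)
model.train()
train_logits = model(images)  # MixStyle may fire after res1 and res2
model.eval()
eval_logits = model(images)  # MixStyle returns its input unchanged in eval mode
```

Passing the stage names in `ms_layers` makes it easy to try different placements, in line with the takeaways below.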

In our paper, we have demonstrated the effectiveness of MixStyle on three tasks: image classification, person re-identification, and reinforcement learning. The source code for reproducing all experiments can be found in mixstyle-release/imcls, mixstyle-release/reid, and mixstyle-release/rl, respectively.

Takeaways on applying MixStyle to your tasks:

  • Applying MixStyle to multiple lower layers is generally better
  • Do not apply MixStyle to the last layer, i.e., the one closest to the prediction layer
  • Different tasks might favor different combinations

For more analytical studies, please read our paper at https://openreview.net/forum?id=6xHJ37MVxxp.

To cite MixStyle in your publications, please use the following BibTeX entry:

@inproceedings{zhou2021mixstyle,
  title={Domain Generalization with MixStyle},
  author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
  booktitle={ICLR},
  year={2021}
}