BYOL for Audio: Self-Supervised Learning for General-Purpose Audio Representation

Overview

This is a demo implementation of BYOL for Audio (BYOL-A), a self-supervised learning method for general-purpose audio representation. It includes:

  • Training code that can train models with arbitrary audio files.
  • Evaluation code that can evaluate trained models with downstream tasks.
  • Pretrained weights.

If you find BYOL-A useful in your research, please use the following BibTeX entry for citation.

@misc{niizumi2021byol-a,
      title={BYOL for Audio: Self-Supervised Learning for General-Purpose Audio Representation}, 
      author={Daisuke Niizumi and Daiki Takeuchi and Yasunori Ohishi and Noboru Harada and Kunio Kashino},
      booktitle = {2021 International Joint Conference on Neural Networks, {IJCNN} 2021},
      year={2021},
      eprint={2103.06695},
      archivePrefix={arXiv},
      primaryClass={eess.AS}
}

Getting Started

  1. Download the external source files and apply a patch. Our implementation uses the following files:

    curl -O https://raw.githubusercontent.com/lucidrains/byol-pytorch/2aa84ee18fafecaf35637da4657f92619e83876d/byol_pytorch/byol_pytorch.py
    patch < byol_a/byol_pytorch.diff
    mv byol_pytorch.py byol_a
    curl -O https://raw.githubusercontent.com/daisukelab/general-learning/7b31d31637d73e1a74aec3930793bd5175b64126/MLP/torch_mlp_clf.py
    mv torch_mlp_clf.py utils
  2. Install PyTorch 1.7.1, torchaudio, and the other dependencies listed in requirements.txt.

Evaluating BYOL-A Representations

Downstream Task Evaluation

The following steps perform a downstream task evaluation in a linear-probe fashion. This example uses SPCV2 (Speech Commands dataset v2).

  1. Preprocess the metadata (.csv file) and audio files. The processed files will be stored under the folder work.

    # usage: python -m utils.preprocess_ds <downstream task> <path to its dataset>
    python -m utils.preprocess_ds spcv2 /path/to/speech_commands_v0.02
  2. Run the evaluation. This first converts all .wav audio files to representation embeddings, then trains a linear layer network, and finally calculates the resulting accuracy.

    python evaluate.py pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth spcv2

You can also run an evaluation multiple times and average the results. The following evaluates on UrbanSound8K 10 times, with a unit audio duration of 4.0 seconds.

# usage: python evaluate.py <your weight> <downstream task> <unit duration sec.> <# of iteration>
python evaluate.py pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth us8k 4.0 10
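For orientation, a unit duration maps to a fixed number of samples and log-mel frames. The sketch below assumes a 16 kHz sampling rate and a 160-sample (10 ms) hop, and counts frames the way torchaudio's MelSpectrogram does with its default center=True padding; check these values against your config.yaml. The helper names unit_samples and num_frames are hypothetical, not part of this repository.

```python
def unit_samples(duration_sec: float, sample_rate: int = 16000) -> int:
    """Number of audio samples in one unit of `duration_sec` seconds."""
    return int(round(duration_sec * sample_rate))


def num_frames(n_samples: int, hop_length: int = 160) -> int:
    """Frame count of a centered STFT (assumes center=True padding)."""
    return 1 + n_samples // hop_length


samples = unit_samples(4.0)   # the US8K example above: 4.0 s at 16 kHz
frames = num_frames(samples)  # frames per unit, assuming a 160-sample hop
print(samples, frames)  # 64000 401
```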

Evaluating Representations In Your Tasks

This example calculates a feature vector for an audio sample.

import torch
import torchaudio
from byol_a.common import *
from byol_a.augmentations import PrecomputedNorm
from byol_a.models import AudioNTT2020


# Use the GPU if available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cfg = load_yaml_config('config.yaml')
print(cfg)

# Mean and standard deviation of the log-mel spectrogram of input audio samples, pre-computed.
# See calc_norm_stats in evaluate.py for your reference.
stats = [-5.4919195, 5.0389895]

# Preprocessor and normalizer.
to_melspec = torchaudio.transforms.MelSpectrogram(
    sample_rate=cfg.sample_rate,
    n_fft=cfg.n_fft,
    win_length=cfg.win_length,
    hop_length=cfg.hop_length,
    n_mels=cfg.n_mels,
    f_min=cfg.f_min,
    f_max=cfg.f_max,
)
normalizer = PrecomputedNorm(stats)

# Load pretrained weights, then move the model to the target device.
# cfg.feature_d must match the dimension of the weight file (2048 here).
model = AudioNTT2020(d=cfg.feature_d)
model.load_weight('pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth', device)
model = model.to(device).eval()

# Load your audio file.
wav, sr = torchaudio.load('work/16k/spcv2/one/00176480_nohash_0.wav') # a sample from SPCV2 for now
assert sr == cfg.sample_rate, "Convert the audio sampling rate in advance, or resample it here on the fly."

# Convert to a log-mel spectrogram (eps avoids log(0)), then normalize.
lms = normalizer((to_melspec(wav) + torch.finfo(torch.float).eps).log())

# Now, convert the audio to the representation.
# For mono audio, lms is (1, n_mels, frames); unsqueeze(0) adds the batch axis.
with torch.no_grad():
    features = model(lms.unsqueeze(0).to(device))
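The comment in the example points to calc_norm_stats in evaluate.py for how stats was obtained. As a dependency-free sketch of the same idea (not that function), assuming each log-mel spectrogram is a nested list of floats and that PrecomputedNorm standardizes as (x - mean) / std, the statistics are taken over every value of a set of spectrograms and then applied. The names calc_stats and normalize are hypothetical.

```python
import math
from typing import List

Spec = List[List[float]]  # one log-mel spectrogram laid out as [n_mels][n_frames]


def calc_stats(specs: List[Spec]) -> List[float]:
    """Mean and (population) standard deviation over every value of every spectrogram."""
    values = [v for spec in specs for row in spec for v in row]
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [mean, math.sqrt(var)]


def normalize(spec: Spec, stats: List[float]) -> Spec:
    """Standardize one spectrogram with precomputed [mean, std]."""
    mean, std = stats
    return [[(v - mean) / std for v in row] for row in spec]


# Tiny illustrative inputs, not real audio.
specs = [[[1.0, 3.0]], [[3.0, 1.0]]]
stats = calc_stats(specs)
print(stats)                          # [2.0, 1.0]
print(normalize([[1.0, 3.0]], stats))  # [[-1.0, 1.0]]
```

In practice the statistics would be computed once over a sample of training files and then reused, which is what hard-coding stats in the example above amounts to.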

Training From Scratch

You can also train models. The following is an example of training on FSD50K.

  1. Convert all samples to 16 kHz. This converts all FSD50K files into the folder work/16k/fsd50k while preserving the folder structure.

    python -m utils.convert_wav /path/to/fsd50k work/16k/fsd50k
  2. Start training. This example trains on all development-set audio samples from FSD50K.

    python train.py work/16k/fsd50k/FSD50K.dev_audio

Refer to Table VI in our paper for the performance of a model trained on FSD50K.

Pretrained Weights

We include three sets of pretrained weights for our encoder network. The table lists their downstream task accuracies.

Method  Dim.    Filename                           NSynth  US8K   VoxCeleb1  VoxForge  SPCV2/12  SPCV2  Average
BYOL-A  512-d   AudioNTT2020-BYOLA-64x96d512.pth   69.1%   78.2%  33.4%      83.5%     86.5%     88.9%  73.3%
BYOL-A  1024-d  AudioNTT2020-BYOLA-64x96d1024.pth  72.7%   78.2%  38.0%      88.5%     90.1%     91.4%  76.5%
BYOL-A  2048-d  AudioNTT2020-BYOLA-64x96d2048.pth  74.1%   79.1%  40.1%      90.2%     91.0%     92.2%  77.8%
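The Average column appears to be the plain mean of the six task accuracies in each row. A quick check using only the numbers from the table:

```python
# Accuracies (%) copied from the table, in column order:
# NSynth, US8K, VoxCeleb1, VoxForge, SPCV2/12, SPCV2.
results = {
    "512-d": [69.1, 78.2, 33.4, 83.5, 86.5, 88.9],
    "1024-d": [72.7, 78.2, 38.0, 88.5, 90.1, 91.4],
    "2048-d": [74.1, 79.1, 40.1, 90.2, 91.0, 92.2],
}

for dim, accs in results.items():
    average = sum(accs) / len(accs)  # unweighted mean over the six tasks
    print(f"{dim}: {average:.1f}%")
# 512-d: 73.3%
# 1024-d: 76.5%
# 2048-d: 77.8%
```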

License

This implementation is provided for evaluating the BYOL-A paper; see LICENSE for details.

Acknowledgements

BYOL-A is built on top of byol-pytorch, a BYOL implementation by Phil Wang (@lucidrains). We thank Phil for open-sourcing his sophisticated code.

@misc{wang2020byol-pytorch,
  author =       {Phil Wang},
  title =        {Bootstrap Your Own Latent (BYOL), in Pytorch},
  howpublished = {\url{https://github.com/lucidrains/byol-pytorch}},
  year =         {2020}
}

Comments
  • Question for reproducing results

    Hi,

    Thanks for sharing this great work! I tried to reproduce the results following the official guidance, but failed.

    After processing the data, I ran the following commands:

    CUDA_VISIBLE_DEVICES=0 python -W ignore train.py work/16k/fsd50k/FSD50K.dev_audio
    cp lightning_logs/version_4/checkpoints/epoch\=99-step\=16099.ckpt AudioNTT2020-BYOLA-64x96d2048.pth
    CUDA_VISIBLE_DEVICES=4 python evaluate.py AudioNTT2020-BYOLA-64x96d2048.pth spcv2
    

    However, my results are far from the reported ones.

    Did I miss something important? Thank you very much.

    question 
    opened by ChenyangLEI 15
  • Evaluation on voxforge

    Hi,

    Thank you so much for your contribution. This work is very interesting, and your code is easy for me to follow. However, one of the downstream datasets, VoxForge, is missing from preprocess_ds.py. Could you please release the code for that dataset, too?

    Thank you again for your time.

    Best regards

    opened by Huiimin5 9
  • A mistake in RunningMean

    Thank you for the fascinating paper and the code to reproduce it!

    I think there might be a problem in RunningMean. The current formula (the same in v1 and v2) looks like this:

    $$ m_n = m_{n - 1} + \frac{a_n - m_{n - 1}}{n - 1}, $$

    which is inconsistent with the correct formula listed on StackOverflow:

    $$ m_n = m_{n - 1} + \frac{a_n - m_{n - 1}}{n}. $$

    The problem is that self.n is incremented after the new mean is computed. Could you please either correct me if I am wrong or correct the code?

    opened by WhiteTeaDragon 4
  • A basic question: torch.randn(): argument 'size' must be tuple of ints, but found element of type list at pos 3

    Traceback (most recent call last):
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 2066, in <module>
        main()
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 2060, in main
        globals = debugger.run(setup['file'], None, None, is_module)
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 1411, in run
        return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 1418, in _exec
        pydev_imports.execfile(file, globals, locals)  # execute the script
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
        exec(compile(contents+"\n", file, 'exec'), glob, loc)
      File "E:/pythonSpace/byol-a/train.py", line 132, in <module>
        main(audio_dir=base_path + '1/', epochs=100)
      File "E:/pythonSpace/byol-a/train.py", line 112, in main
        learner = BYOLALearner(model, cfg.lr, cfg.shape,
      File "E:/pythonSpace/byol-a/train.py", line 56, in __init__
        self.learner = BYOL(model, image_size=shape, **kwargs)
      File "D:\min\envs\torch1_7_1\lib\site-packages\byol_pytorch\byol_pytorch.py", line 211, in __init__
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))
    TypeError: randn(): argument 'size' must be tuple of ints, but found element of type list at pos 3
    
    Not_an_issue 
    opened by a1030076395 3
  • Question about comments in the train.py

    https://github.com/nttcslab/byol-a/blob/master/train.py

    At line 67, there is a comment about the input shape:

            # in fact, it should be (B, 1, F, T), e.g. (256, 1, 64, 96) where 64 is the number of mel bins
            paired_inputs = torch.cat(paired_inputs) # [(B,1,T,F), (B,1,T,F)] -> (2*B,1,T,F)
    

    However, it differs from the description in the config.yaml file:

    # Shape of loh-mel spectrogram [F, T].
    shape: [64, 96]
    
    bug 
    opened by ChenyangLEI 2
  • Doubt in paper

    Hi there,

    Section 4, subsection A, part 1 from your paper says:

     The number of frames, T, in one segment was 96 in pretraining, which corresponds to 1,014ms. 
    

    However, the previous line says the hop size was 10ms, so wouldn't 96 frames correspond to 960ms?

    Am I understanding something wrong here?

    Thank You in advance!

    question 
    opened by Sreyan88 2
  • Random crop is not working.

    https://github.com/nttcslab/byol-a/blob/60cebdc514951e6b42e18e40a2537a01a39ad47b/byol_a/dataset.py#L80-L82

    If len(wav) > self.unit_length, length_adj is negative, so start is 0. If wav (before padding) is shorter than the unit length, length_adj == 0 after padding, so start is again 0. Therefore start is always 0, and the code always crops the fixed region from 0 to self.unit_length (cropped_wav == wav[0: self.unit_length]) instead of a random one.

    So I think line 80 should be changed to length_adj = len(wav) - self.unit_length.

    bug 
    opened by JUiscoming 2
  • Doubt in RunningNorm

    Hi There, great repo!

    I think I have misunderstood something about the RunningNorm function. The function expects the size of an epoch; however, your implementation passes the size of the entire dataset.

    Is it a bug? Or is there a problem with my understanding?

    Thank You!

    question 
    opened by Sreyan88 2
  • How to interpret the performance

    Hi, this is great work, but how should I interpret the performance metric? For example, VoxCeleb1 is usually used for speaker verification; shouldn't we measure EER?

    opened by ranchlai 2
  • Finetuning of BYOL-A

    Hi,

    Your paper is super interesting. I have a question regarding the downstream tasks. If I understand the paper correctly, you used a single linear layer for the downstream tasks, which takes only the sum of the mean and max of the representation over time as input.

    Did you try fine-tuning BYOL-A end-to-end on the downstream tasks after pretraining? In the case of TRILL, they were able to improve the performance even further by fine-tuning the whole model end-to-end. Is there a specific reason why this is not possible with BYOL-A?

    questions 
    opened by mschiwek 1
  • Missing scaling of validation samples in evaluate.py

    https://github.com/nttcslab/byol-a/blob/master/evaluate.py#L112

    It also needs X_val = scaler.transform(X_val); otherwise, the validation accuracy and loss will be invalid. This may be one of the reasons for the lower performance seen when trying to reproduce the official results.

    bug 
    opened by daisukelab 0
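Regarding "A mistake in RunningMean": the standard incremental mean divides by the new count n, which means the counter must be incremented before the update. A minimal sketch of that rule; IncrementalMean is a hypothetical class, not the repository's RunningMean.

```python
class IncrementalMean:
    """Running mean using m_n = m_{n-1} + (a_n - m_{n-1}) / n."""

    def __init__(self) -> None:
        self.n = 0
        self.mean = 0.0

    def update(self, value: float) -> float:
        self.n += 1  # increment first, so the divisor is n, not n - 1
        self.mean += (value - self.mean) / self.n
        return self.mean


rm = IncrementalMean()
for a in [2.0, 4.0, 6.0, 8.0]:
    rm.update(a)
print(rm.mean)  # 5.0
```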
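Regarding "Doubt in paper": one plausible reading is that 1,014 ms is the span covered by 96 overlapping frames, that is, one full analysis window plus 95 hops, rather than 96 hops alone. This is an interpretation, not the authors' answer; it assumes a 64 ms window (1,024 samples at 16 kHz) and the 10 ms hop quoted in the issue. The function name is hypothetical.

```python
def frames_span_ms(n_frames: int, win_ms: float, hop_ms: float) -> float:
    """Time covered by n_frames overlapping windows: one full window
    plus (n_frames - 1) hops between consecutive window starts."""
    return win_ms + (n_frames - 1) * hop_ms


print(frames_span_ms(96, win_ms=64.0, hop_ms=10.0))  # 1014.0
print(96 * 10.0)  # 960.0, counting hops only
```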
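Regarding "Random crop is not working": with the fix proposed there, the slack is len(wav) - unit_length, so clips longer than the unit get a random start. A pure-Python sketch of pad-then-random-crop under that fix; the function name and the choice to split zero-padding across both ends are assumptions, not the repository's code.

```python
import random
from typing import List, Optional


def pad_and_random_crop(wav: List[float], unit_length: int,
                        rng: Optional[random.Random] = None) -> List[float]:
    """Zero-pad short clips, then take a random unit_length window."""
    rng = rng if rng is not None else random.Random()
    # Pad short clips with zeros split across both ends (an assumption).
    shortfall = unit_length - len(wav)
    if shortfall > 0:
        left = shortfall // 2
        wav = [0.0] * left + list(wav) + [0.0] * (shortfall - left)
    # Slack is len(wav) - unit_length (the fix proposed in the issue),
    # so longer clips get a random start in [0, slack].
    slack = len(wav) - unit_length
    start = rng.randint(0, slack) if slack > 0 else 0
    return list(wav[start:start + unit_length])


clip = [float(i) for i in range(10)]
print(pad_and_random_crop(clip, 4, random.Random(0)))
```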
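Regarding "Finetuning of BYOL-A": the question describes the linear-evaluation input as the sum of the mean and the max of the representation over time. A dependency-free sketch of that pooling, assuming frame-wise features laid out as [time][dim]; the function name is hypothetical.

```python
from typing import List


def mean_plus_max_over_time(frames: List[List[float]]) -> List[float]:
    """Pool [T][D] frame features into one D-dim vector: mean over T + max over T."""
    n_frames = len(frames)
    columns = list(zip(*frames))  # D tuples, each holding the T values of one dim
    return [sum(col) / n_frames + max(col) for col in columns]


features = [[1.0, -2.0], [3.0, 0.0]]  # T=2 frames, D=2 dims
print(mean_plus_max_over_time(features))  # [5.0, -1.0]
```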
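Regarding "Missing scaling of validation samples in evaluate.py": the scaler fitted on the training features must also transform the validation features. A minimal pure-Python stand-in for that fit-on-train, transform-both pattern; SimpleStandardScaler is a hypothetical illustration, not scikit-learn's StandardScaler, though it follows the same fit/transform idea.

```python
import math
from typing import List

Matrix = List[List[float]]  # rows are samples, columns are features


class SimpleStandardScaler:
    """Per-feature standardization fitted on training data only."""

    def fit(self, X: Matrix) -> "SimpleStandardScaler":
        cols = list(zip(*X))
        self.mean_ = [sum(c) / len(c) for c in cols]
        # Zero-variance features get scale 1.0 to avoid division by zero.
        self.scale_ = [
            math.sqrt(sum((v - m) ** 2 for v in c) / len(c)) or 1.0
            for c, m in zip(cols, self.mean_)
        ]
        return self

    def transform(self, X: Matrix) -> Matrix:
        return [[(v - m) / s for v, m, s in zip(row, self.mean_, self.scale_)]
                for row in X]


X_train = [[0.0, 10.0], [2.0, 30.0]]
X_val = [[1.0, 20.0]]
scaler = SimpleStandardScaler().fit(X_train)  # fit on training data only
X_train = scaler.transform(X_train)
X_val = scaler.transform(X_val)  # the step reported missing in the issue
print(X_train, X_val)  # [[-1.0, -1.0], [1.0, 1.0]] [[0.0, 0.0]]
```

Fitting on the training split only and reusing the same statistics elsewhere keeps validation and test features on the scale the linear layer was trained on.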
Releases: v2.0.0
Owner: NTT Communication Science Laboratories