BYOL for Audio: Self-Supervised Learning for General-Purpose Audio Representation

Overview

This is a demo implementation of BYOL for Audio (BYOL-A), a self-supervised learning method for general-purpose audio representation. It includes:

  • Training code that can train models with arbitrary audio files.
  • Evaluation code that can evaluate trained models with downstream tasks.
  • Pretrained weights.

If you find BYOL-A useful in your research, please use the following BibTeX entry for citation.

@misc{niizumi2021byol-a,
      title={BYOL for Audio: Self-Supervised Learning for General-Purpose Audio Representation}, 
      author={Daisuke Niizumi and Daiki Takeuchi and Yasunori Ohishi and Noboru Harada and Kunio Kashino},
      booktitle = {2021 International Joint Conference on Neural Networks, {IJCNN} 2021},
      year={2021},
      eprint={2103.06695},
      archivePrefix={arXiv},
      primaryClass={eess.AS}
}

Getting Started

  1. Download external source files, and apply a patch. Our implementation uses the following.

    curl -O https://raw.githubusercontent.com/lucidrains/byol-pytorch/2aa84ee18fafecaf35637da4657f92619e83876d/byol_pytorch/byol_pytorch.py
    patch < byol_a/byol_pytorch.diff
    mv byol_pytorch.py byol_a
    curl -O https://raw.githubusercontent.com/daisukelab/general-learning/7b31d31637d73e1a74aec3930793bd5175b64126/MLP/torch_mlp_clf.py
    mv torch_mlp_clf.py utils
  2. Install PyTorch 1.7.1, torchaudio, and other dependencies listed in requirements.txt.

Evaluating BYOL-A Representations

Downstream Task Evaluation

The following steps perform a downstream task evaluation in a linear-probe fashion. This example uses SPCV2 (Speech Commands dataset v2).

  1. Preprocess metadata (.csv file) and audio files. Processed files will be stored under the folder work.

    # usage: python -m utils.preprocess_ds <downstream task> <path to its dataset>
    python -m utils.preprocess_ds spcv2 /path/to/speech_commands_v0.02
  2. Run the evaluation. This first converts all .wav audio to representation embeddings, then trains a linear layer network, and finally calculates accuracy as the result.

    python evaluate.py pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth spcv2

You can also run an evaluation multiple times and average the results. The following evaluates on UrbanSound8K 10 times with a unit audio duration of 4.0 seconds.

# usage: python evaluate.py <your weight> <downstream task> <unit duration sec.> <# of iteration>
python evaluate.py pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth us8k 4.0 10
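
As a rough picture of what averaging over repeated runs means, the sketch below summarizes a list of per-run accuracies with a mean and a sample standard deviation. It is not taken from evaluate.py: `average_runs` and `fake_run` are hypothetical names, and the accuracies returned by `fake_run` are made-up placeholders standing in for real evaluation runs.

```python
import statistics


def average_runs(run_fn, n_iterations):
    """Call run_fn(i) for each run index and summarize the returned accuracies."""
    scores = [run_fn(i) for i in range(n_iterations)]
    # A standard deviation needs at least two runs.
    spread = statistics.stdev(scores) if len(scores) > 1 else 0.0
    return statistics.fmean(scores), spread


def fake_run(i):
    """Placeholder for one full evaluation; the numbers are made up."""
    return [0.5, 0.7][i % 2]


mean_acc, std_acc = average_runs(fake_run, 4)
print(f"accuracy: {mean_acc:.3f} +/- {std_acc:.3f}")
```

Reporting the spread alongside the mean makes it easier to judge whether a difference between two models is larger than the run-to-run variation.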

Evaluating Representations In Your Tasks

This is an example to calculate a feature vector for an audio sample.

import torch
import torchaudio

from byol_a.common import *
from byol_a.augmentations import PrecomputedNorm
from byol_a.models import AudioNTT2020


device = torch.device('cuda')
cfg = load_yaml_config('config.yaml')
print(cfg)

# Mean and standard deviation of the log-mel spectrogram of input audio samples, pre-computed.
# See calc_norm_stats in evaluate.py for your reference.
stats = [-5.4919195,  5.0389895]

# Preprocessor and normalizer.
to_melspec = torchaudio.transforms.MelSpectrogram(
    sample_rate=cfg.sample_rate,
    n_fft=cfg.n_fft,
    win_length=cfg.win_length,
    hop_length=cfg.hop_length,
    n_mels=cfg.n_mels,
    f_min=cfg.f_min,
    f_max=cfg.f_max,
)
normalizer = PrecomputedNorm(stats)

# Load pretrained weights.
model = AudioNTT2020(d=cfg.feature_d)
model.load_weight('pretrained_weights/AudioNTT2020-BYOLA-64x96d2048.pth', device)

# Load your audio file.
wav, sr = torchaudio.load('work/16k/spcv2/one/00176480_nohash_0.wav') # a sample from SPCV2 for now
assert sr == cfg.sample_rate, "Let's convert the audio sampling rate in advance, or do it here online."

# Convert to a log-mel spectrogram, then normalize.
lms = normalizer((to_melspec(wav) + torch.finfo(torch.float).eps).log())

# Now, convert the audio to the representation.
features = model(lms.unsqueeze(0))
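
The `stats` pair above is a mean and a standard deviation computed over log-mel spectrogram values. As a rough illustration of that idea, and not the repository's `calc_norm_stats` itself, the sketch below computes such statistics in plain Python and applies them. It assumes the normalizer computes `(x - mean) / std` and that a population standard deviation is used. `log_mel`, `calc_stats`, and `normalize` are hypothetical helper names, and the tiny input arrays are made-up placeholders.

```python
import math
import statistics

# Same value as torch.finfo(torch.float).eps (float32 machine epsilon).
FLOAT32_EPS = 2.0 ** -23


def log_mel(mel):
    """Element-wise log of a mel spectrogram given as a list of rows."""
    return [[math.log(v + FLOAT32_EPS) for v in row] for row in mel]


def calc_stats(spectrograms):
    """Mean and population std over every value of every spectrogram."""
    values = [v for spec in spectrograms for row in spec for v in row]
    return statistics.fmean(values), statistics.pstdev(values)


def normalize(spec, stats):
    """Assumed normalization: subtract the mean, divide by the std."""
    mean, std = stats
    return [[(v - mean) / std for v in row] for row in spec]


# Placeholder mel power values: two spectrograms of shape (2 bins, 3 frames).
mels = [
    [[0.0, 0.5, 2.0], [1.0, 0.1, 0.0]],
    [[3.0, 0.2, 0.7], [0.0, 4.0, 1.5]],
]
log_mels = [log_mel(m) for m in mels]
norm_stats = calc_stats(log_mels)
normalized = [normalize(s, norm_stats) for s in log_mels]
```

Adding the epsilon before the logarithm keeps silent bins (value 0.0) finite, which is why the example above does the same with `torch.finfo(torch.float).eps`.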

Training From Scratch

You can also train models. The following is an example of training on FSD50K.

  1. Convert all samples to 16 kHz. This converts all FSD50K files into the folder work/16k/fsd50k while preserving the folder structure.

    python -m utils.convert_wav /path/to/fsd50k work/16k/fsd50k
  2. Start training. This example trains with all development set audio samples from FSD50K.

    python train.py work/16k/fsd50k/FSD50K.dev_audio

Refer to Table VI in our paper for the performance of a model trained on FSD50K.

Pretrained Weights

We include three sets of pretrained weights for our encoder network.

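| Method | Dim.   | Filename                          | NSynth | US8K  | VoxCeleb1 | VoxForge | SPCV2/12 | SPCV2 | Average |
|--------|--------|-----------------------------------|--------|-------|-----------|----------|----------|-------|---------|
| BYOL-A | 512-d  | AudioNTT2020-BYOLA-64x96d512.pth  | 69.1%  | 78.2% | 33.4%     | 83.5%    | 86.5%    | 88.9% | 73.3%   |
| BYOL-A | 1024-d | AudioNTT2020-BYOLA-64x96d1024.pth | 72.7%  | 78.2% | 38.0%     | 88.5%    | 90.1%    | 91.4% | 76.5%   |
| BYOL-A | 2048-d | AudioNTT2020-BYOLA-64x96d2048.pth | 74.1%  | 79.1% | 40.1%     | 90.2%    | 91.0%    | 92.2% | 77.8%   |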

License

This implementation is provided for evaluating the BYOL-A paper; see LICENSE for details.

Acknowledgements

BYOL-A is built on top of byol-pytorch, a BYOL implementation by Phil Wang (@lucidrains). We thank Phil for open-sourcing such sophisticated code.

@misc{wang2020byol-pytorch,
  author =       {Phil Wang},
  title =        {Bootstrap Your Own Latent (BYOL), in Pytorch},
  howpublished = {\url{https://github.com/lucidrains/byol-pytorch}},
  year =         {2020}
}

Comments
  • Question for reproducing results

    Hi,

    Thanks for sharing this great work! I tried to reproduce the results using the official guidance but I failed.

    After processing the data, I ran the following commands:

    CUDA_VISIBLE_DEVICES=0 python -W ignore train.py work/16k/fsd50k/FSD50K.dev_audio
    cp lightning_logs/version_4/checkpoints/epoch\=99-step\=16099.ckpt AudioNTT2020-BYOLA-64x96d2048.pth
    CUDA_VISIBLE_DEVICES=4 python evaluate.py AudioNTT2020-BYOLA-64x96d2048.pth spcv2
    

    However, the results are far from the reported results

    Did I miss something important? Thank you very much.

    question 
    opened by ChenyangLEI 15
  • Evaluation on voxforge

    Hi,

    Thank you so much for your contribution. This work is very interesting, and your code is easy for me to follow. However, one of the downstream datasets, VoxForge, is missing from preprocess_ds.py. Could you please release the code for that dataset, too?

    Thank you again for your time.

    Best regards

    opened by Huiimin5 9
  • A mistake in RunningMean

    Thank you for the fascinating paper and the code to reproduce it!

    I think there might be a problem in RunningMean. The current formula (the same in v1 and v2) looks like this:

    $$ m_n = m_{n - 1} + \frac{a_n - m_{n - 1}}{n - 1}, $$

    which is inconsistent with the correct formula listed on StackOverflow:

    $$ m_n = m_{n - 1} + \frac{a_n - m_{n - 1}}{n}. $$

    The problem is that self.n is incremented after the new mean is computed. Could you please either correct me if I am wrong or correct the code?

    opened by WhiteTeaDragon 4
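
The off-by-one described in the RunningMean report above can be checked numerically. The sketch below is not the repository's class. It is a minimal scalar version, assuming the reported behavior is that the first sample initializes the mean and the counter is incremented only after each update. `RunningMean` here and `LateIncrementMean` are hypothetical illustrations of the two formulas.

```python
class RunningMean:
    """Incremental mean: m_n = m_{n-1} + (a_n - m_{n-1}) / n."""

    def __init__(self):
        self.n = 0
        self.mean = 0.0

    def put(self, value):
        self.n += 1  # count the new sample before dividing
        self.mean += (value - self.mean) / self.n
        return self.mean


class LateIncrementMean:
    """Variant from the report: the n-th sample is divided by n - 1."""

    def __init__(self):
        self.n = 0
        self.mean = 0.0

    def put(self, value):
        if self.n == 0:
            self.mean = float(value)  # first sample initializes the mean
        else:
            # self.n still holds n - 1 here because it is incremented below.
            self.mean += (value - self.mean) / self.n
        self.n += 1
        return self.mean


samples = [2.0, 4.0, 9.0]
correct, late = RunningMean(), LateIncrementMean()
for s in samples:
    correct.put(s)
    late.put(s)

print(correct.mean, late.mean, sum(samples) / len(samples))
```

With the late increment, the second sample simply overwrites the first (the divisor is 1), so later samples are weighted more heavily than they would be in a true mean.
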
  • a basic question: torch.randn(): argument 'size' must be tuple of ints, but found element of type list at pos 3

    Traceback (most recent call last):
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 2066, in <module>
        main()
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 2060, in main
        globals = debugger.run(setup['file'], None, None, is_module)
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 1411, in run
        return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\pydevd.py", line 1418, in _exec
        pydev_imports.execfile(file, globals, locals)  # execute the script
      File "F:\IntellIDEA\PyCharm 2019.2.2\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
        exec(compile(contents+"\n", file, 'exec'), glob, loc)
      File "E:/pythonSpace/byol-a/train.py", line 132, in <module>
        main(audio_dir=base_path + '1/', epochs=100)
      File "E:/pythonSpace/byol-a/train.py", line 112, in main
        learner = BYOLALearner(model, cfg.lr, cfg.shape,
      File "E:/pythonSpace/byol-a/train.py", line 56, in __init__
        self.learner = BYOL(model, image_size=shape, **kwargs)
      File "D:\min\envs\torch1_7_1\lib\site-packages\byol_pytorch\byol_pytorch.py", line 211, in __init__
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))
    TypeError: randn(): argument 'size' must be tuple of ints, but found element of type list at pos 3
    
    Not_an_issue 
    opened by a1030076395 3
  • Question about comments in the train.py

    https://github.com/nttcslab/byol-a/blob/master/train.py

    At line 67, there is a comment about the shape of the input.

            # in fact, it should be (B, 1, F, T), e.g. (256, 1, 64, 96) where 64 is the number of mel bins
            paired_inputs = torch.cat(paired_inputs) # [(B,1,T,F), (B,1,T,F)] -> (2*B,1,T,F)
    

    However, this differs from the description in the config.yaml file:

    # Shape of log-mel spectrogram [F, T].
    shape: [64, 96]
    
    bug 
    opened by ChenyangLEI 2
  • Doubt in paper

    Hi there,

    Section 4, subsection A, part 1 from your paper says:

     The number of frames, T, in one segment was 96 in pretraining, which corresponds to 1,014ms. 
    

    However, the previous line says the hop size used was 10ms. So according to this 96 would mean 960ms?

    Am I understanding something wrong here?

    Thank You in advance!

    question 
    opened by Sreyan88 2
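
One possible reading of the 1,014 ms figure in the question above, offered as an assumption and not as the authors' answer: if the duration counts the full analysis window of the last frame, then T frames span (T - 1) hops plus one window. The 64 ms window used below is an assumed value that is not stated in this thread, and `segment_duration_ms` is a hypothetical helper.

```python
def segment_duration_ms(n_frames, hop_ms, win_ms):
    """Span from the start of the first window to the end of the last one."""
    return (n_frames - 1) * hop_ms + win_ms


# Counting hops only, as in the question.
hop_only_ms = 96 * 10

# Counting 95 hops plus one full window (assumed to be 64 ms long).
full_span_ms = segment_duration_ms(96, hop_ms=10, win_ms=64)

print(hop_only_ms, full_span_ms)
```

Under that assumption the two numbers differ by exactly the window length minus one hop, 64 - 10 = 54 ms.
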
  • Random crop is not working.

    https://github.com/nttcslab/byol-a/blob/60cebdc514951e6b42e18e40a2537a01a39ad47b/byol_a/dataset.py#L80-L82

    If len(wav) > self.unit_length, length_adj is negative, so start is 0. If wav is shorter than the unit length before padding, length_adj == 0 after padding, so start is again 0. Since start is always 0, the code always takes the same region, cropped_wav == wav[0:self.unit_length], rather than a random crop.

    So I think line 80 should be changed to length_adj = len(wav) - self.unit_length.

    bug 
    opened by JUiscoming 2
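
The fix proposed in the report above can be sketched on a plain Python list. This is not the code from byol_a/dataset.py: `pad_or_random_crop` is a hypothetical helper, and centering the zero-padding is an illustrative choice that the report does not specify. The crop offset is drawn from the surplus `len(wav) - unit_length`, as the report suggests.

```python
import random


def pad_or_random_crop(wav, unit_length, rng=random):
    """Return exactly unit_length samples: zero-pad short input, randomly crop long input."""
    shortfall = unit_length - len(wav)
    if shortfall > 0:
        # Center the short clip by padding both sides with zeros.
        left = shortfall // 2
        wav = [0.0] * left + list(wav) + [0.0] * (shortfall - left)
    # Surplus samples available for choosing a crop offset (0 after padding).
    surplus = len(wav) - unit_length
    start = rng.randint(0, surplus) if surplus > 0 else 0
    return wav[start:start + unit_length]


long_wav = list(range(10))        # sample values double as their own indices
rng = random.Random(0)            # seeded only to make the demo repeatable
crops = [pad_or_random_crop(long_wav, 4, rng) for _ in range(200)]
print(sorted({c[0] for c in crops}))  # distinct crop offsets that were drawn
```

Because `randint` includes both endpoints, the last valid offset (`surplus`) can also be drawn, so the crop can end exactly at the final sample.
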
  • Doubt in RunningNorm

    Hi There, great repo!

    I think I may have misunderstood something about the RunningNorm function. The function expects the size of an epoch; however, your implementation passes the size of the entire dataset.

    Is it a bug? Or is there a problem with my understanding?

    Thank You!

    question 
    opened by Sreyan88 2
  • How to interpret the performance

    Hi, it's great work, but how should I understand the performance metric? For example, VoxCeleb1 is usually used for speaker verification; shouldn't we measure EER?

    opened by ranchlai 2
  • Finetuning of BYOL-A

    Hi,

    your paper is super interesting. I have a question regarding the downstream tasks. If I understand the paper correctly, you used a single linear layer for the downstream tasks which only used the sum of mean and max of the representation over time as input.

    Did you try to finetune BYOL-A end-to-end after pretraining to the downstream tasks? In the case of TRILL they were able to improve the performance even further by finetuning the whole model end-to-end. Is there a specific reason why this is not possible with BYOL-A?

    questions 
    opened by mschiwek 1
  • Missing scaling of validation samples in evaluate.py

    https://github.com/nttcslab/byol-a/blob/master/evaluate.py#L112

    It also needs X_val = scaler.transform(X_val); otherwise the validation accuracy and loss will be invalid. This could be one of the reasons why I saw lower performance when I tried to reproduce the official results.

    bug 
    opened by daisukelab 0
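
The point of the report above is that a scaler fitted on the training split has to be applied to every split. The sketch below shows that pattern with a hypothetical pure-Python `ColumnScaler`. It is not the scaler used in evaluate.py, and the feature values are made-up placeholders.

```python
import statistics


class ColumnScaler:
    """Minimal stand-in for a standard scaler: per-column zero mean, unit variance."""

    def fit(self, rows):
        columns = list(zip(*rows))
        self.means = [statistics.fmean(c) for c in columns]
        # Guard against zero variance so constant columns do not divide by zero.
        self.stds = [statistics.pstdev(c) or 1.0 for c in columns]
        return self

    def transform(self, rows):
        return [
            [(v - m) / s for v, m, s in zip(row, self.means, self.stds)]
            for row in rows
        ]


X_train = [[10.0, 200.0], [12.0, 220.0], [14.0, 240.0]]
X_val = [[12.0, 220.0]]
X_test = [[14.0, 200.0]]

scaler = ColumnScaler().fit(X_train)  # statistics come from training data only
X_train = scaler.transform(X_train)
X_val = scaler.transform(X_val)       # the step the report says was missing
X_test = scaler.transform(X_test)
print(X_val, X_test)
```

If the validation split were left unscaled, a classifier trained on standardized features would see inputs on a completely different scale, which is why its validation accuracy and loss would not be meaningful.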
Releases(v2.0.0)
Owner
NTT Communication Science Laboratories