Non-Intrusive Speech Intelligibility Prediction from Discrete Latent Representations

Official repository for the paper "Non-Intrusive Speech Intelligibility Prediction from Discrete Latent Representations".

Overview

This public repository is a work in progress! Results here bear no resemblance to results in the paper!

We predict the intelligibility of binaural speech signals by first extracting latent representations from the raw audio, then training a lightweight predictor over these representations. This outperforms prediction on spectral features of the audio, even though the feature extractor is not explicitly trained for this task. In certain cases, a single layer is sufficient to obtain strong correlations between the predictions and the ground-truth scores.

This repository contains:

  • vqcpc/ - Module for VQCPC model in PyTorch
  • stoi/ - Module for Small and SeqPool predictor model in PyTorch
  • data.py - File containing various PyTorch custom datasets
  • main-vqcpc.py - Script for VQCPC training
  • create-latents.py - Script for generating latent dataset from trained VQCPC
  • plot-latents.py - Script for visualizing extracted latent representations
  • main-stoi.py - Script for STOI predictor training
  • main-test.py - Script for evaluating models
  • compute-correlations.py - Script for computing metrics for many models
  • checkpoints/ - Directory containing trained checkpoints of the VQCPC and STOI predictor models
  • config/ - Directory containing various configuration files for experiments
  • results/ - Directory containing official results from experiments
  • dataset/ - Directory containing metadata files for the dataset
  • data-generator/ - Directory containing dataset generation scripts (MATLAB)

All models are implemented in PyTorch. The training scripts are implemented using ptpt - a lightweight framework around PyTorch.

Figure: visualisation of binaural waveform, predicted per-frame STOI, and latent representation.

Usage

VQ-CPC Training

Begin VQ-CPC training using the configuration defined in config.toml:

python main-vqcpc.py --cfg-path config.toml

Other useful arguments:

--resume            # resume from specified checkpoint
--no-save           # do not save training progress (useful for debugging)
--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--nb-workers        # number of workers for data loading (default: 8)
--detect-anomaly    # detect autograd anomalies and terminate if encountered
--seed              # random seed (default: 12345)
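
For example, a run that resumes from an earlier checkpoint with a fixed seed and fewer data-loading workers might look like the following (the checkpoint path is a placeholder, and we assume --resume takes the checkpoint path as its argument):

python main-vqcpc.py --cfg-path config.toml --resume vqcpc-checkpoint.pt --seed 42 --nb-workers 4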

Latent Dataset Generation

Begin latent dataset generation using the pre-trained VQCPC checkpoint model-checkpoint.pt on the dataset wav-dataset, outputting to latent-dataset, with the configuration defined in config.toml:

python create-latents.py model-checkpoint.pt wav-dataset latent-dataset --cfg-path config.toml

As above, but distributed across n processes, with this script instance taking rank r:

python create-latents.py model-checkpoint.pt wav-dataset latent-dataset --cfg-path config.toml --array-size n --array-rank r

Other useful arguments:

--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--no-tqdm           # disable progress bars
--detect-anomaly    # detect autograd anomalies and terminate if encountered
-n                  # alias for `--array-size`
-r                  # alias for `--array-rank`
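
As a sketch of one way to use the array options, the following launches four local processes that each handle one shard of the dataset (this assumes ranks are zero-indexed; adjust if your setup differs):

for r in 0 1 2 3; do
    python create-latents.py model-checkpoint.pt wav-dataset latent-dataset --cfg-path config.toml -n 4 -r "$r" &
done
wait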

Latent Plotting

Begin the interactive VQCPC latent visualisation script using the pre-trained model model-checkpoint.pt on the dataset wav-dataset, with the configuration defined in config.toml:

python plot-latents.py model-checkpoint.pt wav-dataset --cfg-path config.toml

If you additionally have a pre-trained per-frame STOI score predictor (not a SeqPool predictor), you can specify its checkpoint stoi-checkpoint.pt and additional configuration stoi-config.toml to plot per-frame scores alongside the waveform and latent features:

python plot-latents.py model-checkpoint.pt wav-dataset --cfg-path config.toml --stoi stoi-checkpoint.pt --stoi-cfg stoi-config.toml

Other useful arguments:

--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--cmap              # define matplotlib colourmap
--style             # define matplotlib style
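
For example, to render the plots with a different matplotlib colourmap and style (viridis and ggplot are standard matplotlib names; any installed colourmap or style should work):

python plot-latents.py model-checkpoint.pt wav-dataset --cfg-path config.toml --cmap viridis --style ggplot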

STOI Predictor Training

Begin intelligibility score predictor training using the configuration in config.toml:

python main-stoi.py --cfg-path config.toml

Other useful arguments:

--resume            # resume from specified checkpoint
--no-save           # do not save training progress (useful for debugging)
--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--nb-workers        # number of workers for data loading (default: 8)
--detect-anomaly    # detect autograd anomalies and terminate if encountered
--seed              # random seed (default: 12345)
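
For a quick debugging run that neither writes checkpoints nor touches the GPU, the flags above can be combined as:

python main-stoi.py --cfg-path config.toml --no-save --no-cuda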

Predictor Evaluation

Begin evaluation of a pre-trained STOI score predictor using the checkpoint stoi-checkpoint.pt on the dataset dataset-root, with the configuration in stoi-config.toml:

python main-test.py stoi-checkpoint.pt dataset-root --cfg-path stoi-config.toml

Other useful arguments:

--no-save           # do not save training progress (useful for debugging)
--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--no-tqdm           # disable progress bars
--nb-workers        # number of workers for data loading (default: 8)
--detect-anomaly    # detect autograd anomalies and terminate if encountered
--batch-size        # control dataloader batch size
--seed              # random seed (default: 12345)
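
For example, to evaluate with a larger batch size and without progress bars (the batch size here is arbitrary):

python main-test.py stoi-checkpoint.pt dataset-root --cfg-path stoi-config.toml --batch-size 32 --no-tqdm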

Overall Evaluation

Compare the results from many result files produced by main-test.py against the dataset ground truth:

python compute-correlations.py ground-truth.csv pred-1.csv ... pred-n.csv --names pred-1 ... pred-n
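
For example, to compare two predictors against the ground truth (the result file names below are hypothetical; use whichever CSV files main-test.py produced for your runs):

python compute-correlations.py ground-truth.csv results/small.csv results/seqpool.csv --names small seqpool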

Configuration

Example configurations for all experiments can be found in the config/ directory.

We use toml files to define configurations. Each one consists of three sections (sketched below):

  • [trainer]: configuration options for ptpt.TrainerConfig.
  • [data]: configuration options for the dataset.
  • [vqcpc] or [stoi]: configuration options for the VQCPC and predictor models respectively.
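
As an illustrative sketch only (the comments are placeholders; refer to the files in config/ and to ptpt.TrainerConfig for the actual options), a configuration file is laid out as follows:

[trainer]
# options forwarded to ptpt.TrainerConfig (optimiser, checkpointing, etc.)

[data]
# dataset location and loading options

[vqcpc]  # or [stoi] for predictor experiments
# model hyperparameters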

Checkpoints

Pre-trained checkpoints for all models can be found in the checkpoints/ directory.

Citation

TODO: add citation once paper published / arXiv-ed :)
