PyTorch implementation of "A Full-Band and Sub-Band Fusion Model for Real-Time Single-Channel Speech Enhancement."

Overview

FullSubNet

This Git repository is the official PyTorch implementation of "A Full-Band and Sub-Band Fusion Model for Real-Time Single-Channel Speech Enhancement", submitted to ICASSP 2021.

🌼 See the demo page in this link.

(Figure: FullSubNet workflow)

(Figure: FullSubNet results)

The repository provides the following:

  • Available models
    • FullSubNet
    • Delayed Sub-Band LSTM
    • Fullband LSTM Baseline
  • Available Datasets
    • Deep Noise Suppression Challenge - INTERSPEECH 2020
    • DEMAND + CSTR VCTK Corpus

Documentation

Citation

If you use this code for your research, please consider citing:

@misc{hao2020fullsubnet,
      title={FullSubNet: A Full-Band and Sub-Band Fusion Model for Real-Time Single-Channel Speech Enhancement}, 
      author={Xiang Hao and Xiangdong Su and Radu Horaud and Xiaofei Li},
      year={2020},
      eprint={2010.15508},
      archivePrefix={arXiv},
      primaryClass={eess.AS}
}

License

License: MIT

Comments
  • Will there be a 44.1 or 48kHz pre-trained model released?

    Hi guys! Your work is absolutely amazing and inspiring. I tried your model on data that was not in your datasets, and it performed very well.

    I did have to convert my audio to 16 kHz first since, as I understand it, the model was trained on 16 kHz audio. Is that right?

    My question is: will you release a 44.1 kHz or 48 kHz pre-trained model in the near future?

    While I could try training that model myself, I'm nowhere near as experienced at this as you are, and I feel that I'd miss many things and would not be able to create a model that generalizes as well as yours. Maybe you can prove me wrong.

    opened by youssefavx 8
  • How to Use Pretrained, pickled model in Releases with No Documentation?

    @haoxiangsnr Hi, it's already April and there still isn't any documentation for the pretrained model in releases. How do we go about using the pickled file data.pkl for inference? Thanks!

    opened by uwstudent123 3
  • Any plans about releasing the pretrained models?

    First, thanks for the open-source implementation. I saw that the pretrained model is on your TO-DO list in the baseline README. Do you have a release schedule in mind? Thanks a lot!

    Type: Documentation Priority: Critical Status: In Progress 
    opened by 121898 3
  • [Question] The real-time speech enhance is poor, Need help!

    At present, I have modified cumulative_laplace_norm and feed the STFT frames to the network chunk by chunk for streaming inference, carrying over the network's hidden and cell states. However, the results are poor. In a previous issue, you said that LSTM needs to be replaced with LSTMCell. What is the difference between the two, and why is the conversion needed? The pictures are as follows:

    • Fullwav_load: (image)

    • Stream_load: (image)

    Metrics (my experiment):

    | Method | Checkpoint | NB_PESQ | WB_PESQ | SI_SDR | STOI |
    | -- | -- | -- | -- | -- | -- |
    | FullSubNet-cum | Epoch 130 | 3.364 | 2.861 | 17.65 | 96.25 |
    | FullSubNet-cum-stream | | 3.155 | 2.466 | 14.77 | 94.30 |

    opened by Kayden-Wang 2
  • Training and Validation cRM Mismatch

    During training, with batch size 10, we observe the following shapes:

    cRM torch.Size([10, 128, 193, 2])
    noisy_real torch.Size([10, 257, 193])
    noisy_imag torch.Size([10, 257, 193])
    

    However, during validation, we see:

    cRM torch.Size([1, 257, 626, 2])
    noisy_real torch.Size([1, 257, 626])
    noisy_imag torch.Size([1, 257, 626])
    

    Why does the frequency dimension of the cRM (128) differ from that of the noisy spectrogram (257) during training, but not during validation?

    Because of this mismatch, I am unable to compute the enhanced waveform during training; the following calculation fails:

    cRM = decompress_cIRM(cRM)
    
    enhanced_real = cRM[..., 0] * noisy_real - cRM[..., 1] * noisy_imag
    enhanced_imag = cRM[..., 1] * noisy_real + cRM[..., 0] * noisy_imag
    
    opened by jhkonan 2
  • Sub-band model

    Hi, thanks for sharing this excellent project. I am very interested in the Delayed Sub-Band LSTM model presented in the paper. I've tried hard to reproduce it, but I still can't get good performance. Could you release the code for your sub-band models? Thanks a lot!

    opened by HWLhsu 2
  • error !

    Hi @haoxiangsnr, I ran the pretrained model on Google Colab following this guide: https://github.com/haoxiangsnr/FullSubNet/blob/main/docs/getting_started.md, but I got the error below.

    Command:

    !python inference.py -C /content/FullSubNet/recipes/dns_interspeech_2020/fullband_baseline/inference.toml -M /content/drive/MyDrive/Colab_Notebooks/FullSubNet/fullsubnet_best_model_58epochs.tar -O /content/drive/MyDrive/Colab_Notebooks/FullSubNet/output_dir

    Result:

    Loading inference dataset...
    Loading model...
    Traceback (most recent call last):
      File "inference.py", line 32, in <module>
        main(configuration, checkpoint_path, output_dir)
      File "inference.py", line 16, in main
        output_dir
      File "/content/FullSubNet/recipes/dns_interspeech_2020/inferencer.py", line 50, in __init__
        super().__init__(config, checkpoint_path, output_dir)
      File "/content/FullSubNet/audio_zen/inferencer/base_inferencer.py", line 27, in __init__
        self.model, epoch = self._load_model(config["model"], checkpoint_path, self.device)
      File "/content/FullSubNet/audio_zen/inferencer/base_inferencer.py", line 91, in _load_model
        model = initialize_module(model_config["path"], args=model_config["args"], initialize=True)
      File "/content/FullSubNet/audio_zen/utils.py", line 87, in initialize_module
        module = importlib.import_module(module_path)
      File "/usr/lib/python3.7/importlib/__init__.py", line 127, in import_module
        return _bootstrap._gcd_import(name[level:], package, level)
      File "<frozen importlib._bootstrap>", line 1006, in _gcd_import
      File "<frozen importlib._bootstrap>", line 983, in _find_and_load
      File "<frozen importlib._bootstrap>", line 965, in _find_and_load_unlocked
    ModuleNotFoundError: No module named 'model'

    Please give me a hand. Thanks a lot!

    opened by vinh1988 2
  • # of Epochs for training a FullBand baseline model

    Hello,

    My question is about training. I am trying to replicate the results on the DNS Challenge dataset. The number of epochs in the fullband_baseline.toml file is set to 9999, which seems a "little" high :) Could you please shed some light on this? Is this the default value?

    Thank you for sharing your work.

    B.R.

    opened by kadir-gunel 2
  • Not an issue, but I wanted to show you some impressive results...

    Greetings! I have been helping restore an old (1980) recording of a recording of an interview with an elderly person recounting stories of the early history of the Baha'i Faith in the U.S. I have surveyed and tried a number of machine learning methods to denoise and enhance the recording. I just finished processing it with FullSubNet today, and it far surpassed the other methods I tried at removing the recording noise and making the voice easier to understand. Enclosed is a graphic comparing the spectrograms of three files: the top is the original recording, the middle is the result of another method (dozens of times slower than yours and designed mainly for white noise), and the bottom is the result of FullSubNet using your pretrained checkpoint. The noise reduction from the original recording to your method's output was astonishing! I can check whether the archivist would allow me to share the recordings, so please let me know if you are interested. Thanks so much for making the code available! Regards, Steve

    (Image: spectrogram comparison, compare_orig_n2n_fullsubnet)

    opened by sjscotti 2
  • The batch size for the validation stage must be one

    Hi Hao Xiang,

    I can't run the demo because of a limit requiring GPU usage to be above 20 percent. While investigating, I found that the batch size of the validation set must be one. Can I change the batch size for validation?

    I look forward to your reply!

    opened by zc1616 2
  • Questions about the training process

    Very interesting project. Thank you for sharing.

    I have a question: what are the text files noise.txt, rir.txt, and clean_0.6.txt? Are they part of the original dataset, or files that you created specifically for training?

    Another question: is it possible to run it on Windows without the "dist" feature (using a single GPU), i.e., after commenting out all parts related to 'dist'?

    opened by ahikaml 2
  • Question about look-ahead

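    Hi, my understanding is that look-ahead means how many future frames are used. However, while reading your code, I found that two zero frames are padded at the end, noisy_mag = F.pad(noisy_mag, [0, self.look_ahead]), and in the end only the output after the second frame is kept, output = sb_mask[:, :, :, self.look_ahead:].

    During inference, is it true that no zero padding is needed, and instead we directly process 3 frames to produce one output frame (output = sb_mask[:, :, :, self.look_ahead:]), and then stream one frame in, one frame out?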

    opened by LXP-Never 0
  • Normalization

    Dear authors,

    I notice that in snr_mix, the signal level is drawn randomly from [-35, -15] dBFS, so the intensity varies randomly. However, inference.py applies normalization, which seems inconsistent. As I understand it, we should either normalize all data or none of it, so why do you normalize during inference but not during training? Maybe I have misunderstood something; please correct me if so.

    Best

    opened by lixinghe1999 2
  • Unable to fine-tune pre-trained model (fullsubnet_best_model_58epochs.tar)

    I am trying to continue training the pre-trained FullSubNet model provided by this repo:

    fullsubnet_best_model_58epochs.tar

    I can confirm that the model works for inference. However, because of how the checkpoint was saved, I run into an error when loading its state dictionary to resume training.

    Here is the error in full:

    (FullSubNet) $ torchrun --standalone --nnodes=1 --nproc_per_node=1 train.py -C fullsubnet/train.toml -R
    1 process initialized.
    Traceback (most recent call last):
      File "/home/github/FullSubNet/recipes/dns_interspeech_2020/train.py", line 99, in <module>
        entry(local_rank, configuration, args.resume, args.only_validation)
      File "/home/github/FullSubNet/recipes/dns_interspeech_2020/train.py", line 59, in entry
        trainer = trainer_class(
      File "/home/github/FullSubNet/recipes/dns_interspeech_2020/fullsubnet/trainer.py", line 17, in __init__
        super().__init__(dist, rank, config, resume, only_validation, model, loss_function, optimizer)
      File "/home/github/FullSubNet/audio_zen/trainer/base_trainer.py", line 84, in __init__
        self._resume_checkpoint()
      File "/home/github/FullSubNet/audio_zen/trainer/base_trainer.py", line 153, in _resume_checkpoint
        self.scaler.load_state_dict(checkpoint["scaler"])
      File "/home/anaconda3/envs/FullSubNet/lib/python3.9/site-packages/torch/cuda/amp/grad_scaler.py", line 502, in load_state_dict
        raise RuntimeError("The source state dict is empty, possibly because it was saved "
    RuntimeError: The source state dict is empty, possibly because it was saved from a disabled instance of GradScaler.
    ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 1822537) of binary: /home/anaconda3/envs/FullSubNet/bin/python
    Traceback (most recent call last):
      File "/home/anaconda3/envs/FullSubNet/bin/torchrun", line 33, in <module>
        sys.exit(load_entry_point('torch==1.11.0', 'console_scripts', 'torchrun')())
      File "/home/anaconda3/envs/FullSubNet/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
        return f(*args, **kwargs)
      File "/home/anaconda3/envs/FullSubNet/lib/python3.9/site-packages/torch/distributed/run.py", line 724, in main
        run(args)
      File "/home/anaconda3/envs/FullSubNet/lib/python3.9/site-packages/torch/distributed/run.py", line 715, in run
        elastic_launch(
      File "/home/anaconda3/envs/FullSubNet/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
        return launch_agent(self._config, self._entrypoint, list(args))
      File "/home/anaconda3/envs/FullSubNet/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 245, in launch_agent
        raise ChildFailedError(
    torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
    ============================================================
    train.py FAILED
    ------------------------------------------------------------
    Failures:
      <NO_OTHER_FAILURES>
    ------------------------------------------------------------
    Root Cause (first observed failure):
    [0]:
      time      : 2022-06-19_21:25:20
      host      : host-server
      rank      : 0 (local_rank: 0)
      exitcode  : 1 (pid: 1822537)
      error_file: <N/A>
      traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
    ============================================================
    
    

    Are there specific modifications that need to be made to continue training?

    Thank you for your help.

    opened by jhkonan 1
  • The repository's model scores much better than the challenge submission. What is different?

    Hello, I noticed that FullSubNet's result in DNS 2021 was a MOS of 3.06 on the dev_testset: https://www.microsoft.com/en-us/research/uploads/prod/2020/12/Challenge_Results.pdf

    However, when I actually test the model in this repository, the overall score is 3.44:
    DNSMOS_SIG: 3.790972579288795
    DNSMOS_BAK: 4.130271822666175
    DNSMOS_OVR: 3.441530761341177

    Is the model submitted to the challenge different from the one here?

    Thanks!

    opened by lhwcv 0
  • error

    Can somebody help me fix this?

    Command:

    python inference.py -C C:\Users\punnp\Desktop\FullSubNet\recipes\dns_interspeech_2020\fullsubnet/inference.toml -M C:\Users\punnp\Desktop\FullSubNet\model\fullsubnet_best_model_58epochs.tar -O C:\Users\punnp\Desktop\FullSubNet\output

    Result:

    Traceback (most recent call last):
      File "C:\Users\punnp\anaconda3\envs\FullSubNet\lib\site-packages\toml\decoder.py", line 512, in loads
        multibackslash)
      File "C:\Users\punnp\anaconda3\envs\FullSubNet\lib\site-packages\toml\decoder.py", line 778, in load_line
        value, vtype = self.load_value(pair[1], strictly_valid)
      File "C:\Users\punnp\anaconda3\envs\FullSubNet\lib\site-packages\toml\decoder.py", line 880, in load_value
        return (self.load_array(v), "array")
      File "C:\Users\punnp\anaconda3\envs\FullSubNet\lib\site-packages\toml\decoder.py", line 1026, in load_array
        nval, ntype = self.load_value(a[i])
      File "C:\Users\punnp\anaconda3\envs\FullSubNet\lib\site-packages\toml\decoder.py", line 866, in load_value
        raise ValueError("Reserved escape sequence used")
    ValueError: Reserved escape sequence used

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "inference.py", line 30, in <module>
        configuration = toml.load(config_path.as_posix())
      File "C:\Users\punnp\anaconda3\envs\FullSubNet\lib\site-packages\toml\decoder.py", line 134, in load
        return loads(ffile.read(), _dict, decoder)
      File "C:\Users\punnp\anaconda3\envs\FullSubNet\lib\site-packages\toml\decoder.py", line 514, in loads
        raise TomlDecodeError(str(err), original, pos)
    toml.decoder.TomlDecodeError: Reserved escape sequence used (line 20 column 1 char 241)

    opened by MisuyaXZ 0
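The first issue above asks whether audio must be converted to 16 kHz before inference. Assuming the released checkpoints expect 16 kHz input, as that issue's author understands, the following is a minimal sketch of polyphase resampling with SciPy's `scipy.signal.resample_poly`. The helper name `resample_waveform` is hypothetical and not part of this repository.

```python
from math import gcd

import numpy as np
from scipy.signal import resample_poly


def resample_waveform(waveform: np.ndarray, orig_sr: int, target_sr: int = 16000) -> np.ndarray:
    """Resample a mono waveform from orig_sr to target_sr (default 16 kHz)."""
    if orig_sr == target_sr:
        return np.asarray(waveform, dtype=np.float32)
    # Reduce the rate ratio to small integers, e.g. 44100 -> 16000 gives up=160, down=441
    g = gcd(orig_sr, target_sr)
    up, down = target_sr // g, orig_sr // g
    # resample_poly upsamples, applies an anti-aliasing FIR low-pass filter, then downsamples
    return resample_poly(waveform, up, down).astype(np.float32)
```

With integer ratios like these, one second of 44.1 kHz or 48 kHz audio becomes exactly 16000 samples, and the built-in low-pass filter avoids the aliasing that naive sample dropping would introduce.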
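The streaming-inference issue above ("The real-time speech enhance is poor") combines cumulative normalization with chunked processing. A streaming pipeline only matches offline inference if every stateful step carries its state forward instead of resetting it per chunk. As a sketch, assuming cumulative normalization divides each frame by the mean magnitude of all time-frequency bins seen so far (our reading, not a copy of the repository's `cumulative_laplace_norm`), the offline and frame-by-frame versions below produce the same output. Both names are hypothetical.

```python
import numpy as np


def cumulative_norm_offline(mag: np.ndarray, eps: float = 1e-10) -> np.ndarray:
    """Normalize a (F, T) magnitude spectrogram frame by frame, dividing frame t
    by the mean of all F * (t + 1) bins observed in frames 0..t."""
    num_freqs, num_frames = mag.shape
    cum_sum = np.cumsum(mag.sum(axis=0))               # running total up to each frame
    counts = num_freqs * np.arange(1, num_frames + 1)  # number of bins seen up to frame t
    cum_mean = cum_sum / counts
    return mag / (cum_mean[None, :] + eps)


class CumulativeNormStream:
    """Streaming equivalent: the only state is a running sum and a bin count,
    so it must persist across chunks rather than being reset per chunk."""

    def __init__(self, eps: float = 1e-10):
        self.eps = eps
        self.total = 0.0
        self.count = 0

    def __call__(self, frame: np.ndarray) -> np.ndarray:
        # Update the state first so the current frame is included,
        # matching the offline cumulative sum.
        self.total += float(frame.sum())
        self.count += frame.shape[0]
        return frame / (self.total / self.count + self.eps)
```

The same principle applies to the recurrent layers: PyTorch's `nn.LSTM` accepts an initial `(h, c)` tuple and returns the final one, so feeding the returned state into the next chunk continues the sequence. Any step that silently restarts per chunk (normalization statistics, recurrent state, or STFT framing at chunk boundaries) can make streaming results diverge from full-utterance results.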
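In the "Training and Validation cRM Mismatch" issue, the enhancement step is complex multiplication of the mask with the noisy STFT, which only works when the mask and the spectrogram cover the same frequency bins and frames. One possible explanation for the 128-bin training mask, not confirmed in the thread, is a training-only step that keeps a subset of frequency bins to save computation; if so, the noisy spectrogram must be reduced to the same bins before the mask is applied. A minimal NumPy sketch with an explicit shape check (`apply_crm` is a hypothetical name):

```python
import numpy as np


def apply_crm(crm: np.ndarray, noisy_real: np.ndarray, noisy_imag: np.ndarray):
    """Apply a complex ratio mask to a noisy STFT.

    crm: (..., F, T, 2), real part in [..., 0] and imaginary part in [..., 1]
    noisy_real, noisy_imag: (..., F, T)
    """
    if crm.shape[:-1] != noisy_real.shape or noisy_real.shape != noisy_imag.shape:
        raise ValueError(
            f"shape mismatch: cRM {crm.shape} vs noisy {noisy_real.shape}; "
            "the mask must cover the same frequency bins and frames"
        )
    # (Mr + jMi)(Yr + jYi) = (Mr*Yr - Mi*Yi) + j(Mi*Yr + Mr*Yi)
    enhanced_real = crm[..., 0] * noisy_real - crm[..., 1] * noisy_imag
    enhanced_imag = crm[..., 1] * noisy_real + crm[..., 0] * noisy_imag
    return enhanced_real, enhanced_imag
```

Failing early with both shapes in the message makes a bin-count mismatch like 128 vs. 257 obvious, instead of surfacing later as a broadcasting error.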
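The "batch size for the validation stage must be one" issue is consistent with validation utterances having different lengths (the cRM issue above shows a 626-frame validation spectrogram, while training clips have 193 frames). If larger validation batches are needed, one option, sketched here with hypothetical helpers rather than the repository's data loader, is to zero-pad the waveforms and keep their original lengths so the padding can be trimmed before computing metrics.

```python
import numpy as np


def pad_batch(waveforms):
    """Zero-pad 1-D waveforms of different lengths into a (B, max_len) batch.

    Returns the padded batch and the original lengths, so padded samples can
    be trimmed after enhancement and excluded from metrics.
    """
    lengths = np.array([len(w) for w in waveforms], dtype=np.int64)
    batch = np.zeros((len(waveforms), int(lengths.max())), dtype=np.float32)
    for i, w in enumerate(waveforms):
        batch[i, : len(w)] = w
    return batch, lengths


def unpad_batch(batch, lengths):
    """Trim each row of a padded batch back to its original length."""
    return [batch[i, :n] for i, n in enumerate(lengths)]
```

Trimming matters because metrics such as SI-SDR or STOI computed over zero-padded tails would be biased by the padding.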
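The look-ahead question above asks whether end padding followed by `sb_mask[:, :, :, self.look_ahead:]` corresponds to frame-by-frame streaming. The sketch below uses a toy causal model (an exponential moving average standing in for the network, not FullSubNet itself; all names are hypothetical) to show that padding `look_ahead` zero frames at the end and dropping the first `look_ahead` outputs gives the same result as streaming one frame in, one frame out with a `look_ahead`-frame delay, then flushing with `look_ahead` zero frames.

```python
import numpy as np


class ToyCausalModel:
    """Stand-in for a causal network: the output at step s depends only on
    frames 0..s (here, an exponential moving average over frames)."""

    def __init__(self, alpha: float = 0.5):
        self.alpha = alpha
        self.state = None

    def step(self, frame: np.ndarray) -> np.ndarray:
        if self.state is None:
            self.state = frame.copy()
        else:
            self.state = self.alpha * self.state + (1 - self.alpha) * frame
        return self.state.copy()


def offline_with_look_ahead(frames: np.ndarray, look_ahead: int) -> np.ndarray:
    """frames: (T, F). Mirrors padding look_ahead zero frames at the end,
    running the causal model, then keeping outputs[look_ahead:]."""
    padded = np.concatenate([frames, np.zeros((look_ahead, frames.shape[1]))], axis=0)
    model = ToyCausalModel()
    outputs = np.stack([model.step(f) for f in padded])
    return outputs[look_ahead:]


def streaming_with_look_ahead(frames: np.ndarray, look_ahead: int) -> np.ndarray:
    """One frame in, one frame out, with a fixed delay of look_ahead frames."""
    model = ToyCausalModel()
    kept = []
    for s, frame in enumerate(frames):
        out = model.step(frame)
        if s >= look_ahead:  # this output belongs to frame s - look_ahead
            kept.append(out)
    for _ in range(look_ahead):  # flush the final frames at end of stream
        kept.append(model.step(np.zeros(frames.shape[1])))
    return np.stack(kept)
```

Under this reading, no padding is needed at the start of a stream: the output for frame t is emitted once frame t + look_ahead has arrived, and the zero padding only affects the last `look_ahead` frames.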
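The "Normalization" issue notes that training mixtures are scaled to a random level in [-35, -15] dBFS. Normalizing a test input to a fixed level inside that range is one way to keep inference inputs within the training distribution. Below is a minimal sketch of RMS-based dBFS normalization; the -25 dBFS target and the helper names are assumptions, not necessarily what `inference.py` does.

```python
import numpy as np


def normalize_to_dbfs(waveform: np.ndarray, target_dbfs: float = -25.0, eps: float = 1e-8) -> np.ndarray:
    """Scale a waveform so its RMS level equals target_dbfs (dB relative to full scale 1.0)."""
    rms = np.sqrt(np.mean(waveform ** 2))
    scale = 10 ** (target_dbfs / 20) / (rms + eps)
    return waveform * scale


def rms_dbfs(waveform: np.ndarray) -> float:
    """RMS level of a waveform in dBFS."""
    return float(20 * np.log10(np.sqrt(np.mean(waveform ** 2))))
```

Because the target sits inside the training range, the network sees inputs at a level it was trained on, regardless of how loud the original recording was.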
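The fine-tuning traceback above ends with `GradScaler.load_state_dict` raising on an empty state dict. One workaround, sketched here as a hypothetical helper, is to skip restoring the scaler when its saved state is empty or missing and keep a freshly constructed scaler instead. The key names `model_state_dict` and `optimizer_state_dict` follow the release notes and the `scaler` key comes from the traceback; adjust them if the checkpoint differs.

```python
def resume_from_checkpoint(checkpoint, model, optimizer, scaler=None,
                           model_key="model_state_dict",
                           optimizer_key="optimizer_state_dict"):
    """Restore training state from an already-loaded checkpoint dict.

    Returns True if the GradScaler state was restored, False if it was skipped.
    The key names are assumptions; adjust them to the checkpoint's actual keys.
    """
    model.load_state_dict(checkpoint[model_key])
    if optimizer_key in checkpoint:
        optimizer.load_state_dict(checkpoint[optimizer_key])

    scaler_state = checkpoint.get("scaler")
    # An empty dict means the GradScaler was disabled when the checkpoint was
    # saved, and GradScaler.load_state_dict raises RuntimeError on it.
    # In that case, keep the freshly constructed scaler.
    if scaler is not None and scaler_state:
        scaler.load_state_dict(scaler_state)
        return True
    return False
```

Whether to also restore the optimizer state when fine-tuning on new data is a judgment call; starting with a fresh optimizer is another reasonable option.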
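The last issue's traceback points at line 20, column 1 of the TOML config, inside an array. A likely cause, not confirmed in the thread, is a Windows path written with backslashes inside a double-quoted TOML string. The sketch below (the helper name is hypothetical) converts such a path into a forward-slash TOML value with `pathlib.PureWindowsPath`.

```python
from pathlib import PureWindowsPath


def toml_path_value(windows_path: str) -> str:
    # In TOML double-quoted (basic) strings, a backslash starts an escape
    # sequence, so a path such as C:\Users\punnp contains "\p", which TOML
    # does not define. Converting to forward slashes avoids escaping entirely,
    # and Python's file APIs on Windows accept forward-slash paths.
    return '"' + PureWindowsPath(windows_path).as_posix() + '"'
```

TOML single-quoted literal strings, such as 'C:\Users\punnp\Desktop', are another option, since literal strings do not process escape sequences.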
Releases (v0.2)
  • v0.2 (Jan 16, 2021)

    Checkpoints

    This page has two released model checkpoints. All checkpoints include "model_state_dict", "optimizer_state_dict", and some other meta information.

    The first model checkpoint is the original model checkpoint at the 58th epoch. The performance is shown in this table:

    | | With Reverb | | | | No Reverb | | | |
    |:----------:|:-----------:|:-------:|:------:|:-----:|:---------:|:-------:|:------:|:-----:|
    | Method | WB-PESQ | NB-PESQ | SI-SDR (dB) | STOI | WB-PESQ | NB-PESQ | SI-SDR (dB) | STOI |
    | FullSubNet | 2.987 | 3.496 | 15.756 | 0.926 | 2.889 | 3.385 | 17.635 | 0.964 |

    In addition, some people are interested in the performance with cumulative normalization. The following is a FullSubNet pre-trained with cumulative normalization:

    | | With Reverb | | | | No Reverb | | | |
    |:----------:|:-----------:|:-------:|:------:|:-----:|:---------:|:-------:|:------:|:-----:|
    | Method | WB-PESQ | NB-PESQ | SI-SDR (dB) | STOI | WB-PESQ | NB-PESQ | SI-SDR (dB) | STOI |
    | FullSubNet (Cumulative Norm) | 2.978 | 3.503 | 15.820 | 0.928 | 2.863 | 3.376 | 17.913 | 0.964 |

    If you want to run inference or fine-tune based on these checkpoints, please check the usage instructions in the documentation.

    Room Impulse Responses

    As mentioned in the paper, the room impulse responses (RIRs) come from the Multichannel Impulse Response Database and the Reverb Challenge dataset. Please download the zip package "RIR (Multichannel Impulse Response Database + The REVERB challenge).zip" if you would like to retrain the FullSubNet.

    Note that the zip package includes a folder "rir" and a file "rir.txt". The folder "rir" contains all separated single-channel RIRs extracted from the above two datasets. The suffix (e.g., "m_") of each filename is the index of a microphone. The file "rir.txt" is just a list of paths to all RIRs; please modify the paths to fit your setup before using it.

    If you would like to extract the channels yourself, you can download these RIRs from the following pages:

    1. Multichannel Impulse Response Database: https://www.eng.biu.ac.il/~gannot/RIR_DATABASE/
    2. The REVERB challenge data: https://reverb2014.dereverberation.com/tools/reverb_tools_for_Generate_mcTrainData.tgz and https://reverb2014.dereverberation.com/tools/reverb_tools_for_Generate_SimData.tgz

    Enjoy ~

    Source code(tar.gz)
    Source code(zip)
    cum_fullsubnet_best_model_218epochs.tar(64.53 MB)
    fullsubnet_best_model_58epochs.tar(64.53 MB)
    RIR.Multichannel.Impulse.Response.Database.+.The.REVERB.challenge.zip(10.68 MB)
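The release notes state that each checkpoint contains `model_state_dict` and `optimizer_state_dict`, and the "How to Use Pretrained, pickled model" issue asks what to do with `data.pkl`. That file is part of the zip archive that `torch.save` writes, so the `.tar` checkpoint should be passed to `torch.load` as a whole rather than unpickled directly. Below is a minimal sketch, assuming the weights live under `model_state_dict` and may carry a `module.` prefix from distributed training; both helper names are hypothetical.

```python
from collections import OrderedDict


def extract_model_state(checkpoint: dict, key: str = "model_state_dict") -> OrderedDict:
    """Return model weights from a loaded checkpoint dict.

    Strips a leading "module." from parameter names, which appears when
    weights are saved from a DistributedDataParallel-wrapped model.
    """
    prefix = "module."
    return OrderedDict(
        (name[len(prefix):] if name.startswith(prefix) else name, value)
        for name, value in checkpoint[key].items()
    )


def load_checkpoint(path: str) -> dict:
    import torch  # imported lazily so extract_model_state has no torch dependency

    # map_location="cpu" lets a GPU-saved checkpoint load on CPU-only machines
    return torch.load(path, map_location="cpu")
```

A typical call would be `model.load_state_dict(extract_model_state(load_checkpoint("fullsubnet_best_model_58epochs.tar")))`, where `model` is constructed with the same class and arguments as in the recipe's `inference.toml`. Recent PyTorch releases default `torch.load` to `weights_only=True`; for a trusted checkpoint that stores extra metadata, passing `weights_only=False` may be required.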
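The release notes describe `rir.txt` as a plain list of RIR paths that should be edited to match your setup. Instead of editing it by hand, it can be regenerated from the unpacked `rir` folder. Below is a minimal sketch, assuming one path per line and `.wav` files (both assumptions). The same approach could apply to lists such as `noise.txt` if they follow the same format. The helper name is hypothetical.

```python
from pathlib import Path


def write_path_list(audio_dir: str, list_file: str, pattern: str = "*.wav") -> int:
    """Write one absolute path per line for every file under audio_dir
    (searched recursively) that matches pattern; return the count written."""
    paths = sorted(p.resolve() for p in Path(audio_dir).rglob(pattern) if p.is_file())
    Path(list_file).write_text("".join(f"{p.as_posix()}\n" for p in paths), encoding="utf-8")
    return len(paths)
```

Sorting the paths keeps the list stable across runs, which makes differences between generated lists easy to inspect.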
Owner
郝翔
Audio/Speech Signal Processing
SPT_LSA_ViT - Implementation for Visual Transformer for Small-size Datasets

Vision Transformer for Small-Size Datasets Seung Hoon Lee and Seunghyun Lee and Byung Cheol Song | Paper Inha University Abstract Recently, the Vision

Lee SeungHoon 87 Jan 01, 2023
Contrastive Language-Image Pretraining

CLIP [Blog] [Paper] [Model Card] [Colab] CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pair

OpenAI 11.5k Jan 08, 2023
Official Pytorch implementation of "Unbiased Classification Through Bias-Contrastive and Bias-Balanced Learning (NeurIPS 2021)

Unbiased Classification Through Bias-Contrastive and Bias-Balanced Learning (NeurIPS 2021) Official Pytorch implementation of Unbiased Classification

Youngkyu 17 Jan 01, 2023
A collection of metrics for evaluating timbre dissimilarity using the TorchMetrics API

Timbre Dissimilarity Metrics A collection of metrics for evaluating timbre dissimilarity using the TorchMetrics API Installation pip install -e . Usag

Ben Hayes 21 Jan 05, 2022
Physics-Aware Training (PAT) is a method to train real physical systems with backpropagation.

Physics-Aware Training (PAT) is a method to train real physical systems with backpropagation. It was introduced in Wright, Logan G. & Onodera, Tatsuhiro et al. (2021)1 to train Physical Neural Networ

McMahon Lab 230 Jan 05, 2023
Keras implementation of PersonLab for Multi-Person Pose Estimation and Instance Segmentation.

PersonLab This is a Keras implementation of PersonLab for Multi-Person Pose Estimation and Instance Segmentation. The model predicts heatmaps and vari

OCTI 160 Dec 21, 2022
Semantic Segmentation Architectures Implemented in PyTorch

pytorch-semseg Semantic Segmentation Algorithms Implemented in PyTorch This repository aims at mirroring popular semantic segmentation architectures i

Meet Shah 3.3k Dec 29, 2022
Repository for Traffic Accident Benchmark for Causality Recognition (ECCV 2020)

Causality In Traffic Accident (Under Construction) Repository for Traffic Accident Benchmark for Causality Recognition (ECCV 2020) Overview Data Prepa

Tackgeun 21 Nov 20, 2022
Multi-Agent Reinforcement Learning for Active Voltage Control on Power Distribution Networks (MAPDN)

Multi-Agent Reinforcement Learning for Active Voltage Control on Power Distribution Networks (MAPDN) This is the implementation of the paper Multi-Age

Future Power Networks 83 Jan 06, 2023
FSCE: Few-Shot Object Detection via Contrastive Proposal Encoding (CVPR 2021)

FSCE: Few-Shot Object Detection via Contrastive Proposal Encoding (CVPR 2021) This repo contains the implementation of our state-of-the-art fewshot ob

233 Dec 29, 2022
Human head pose estimation using Keras over TensorFlow.

RealHePoNet: a robust single-stage ConvNet for head pose estimation in the wild.

Rafael Berral Soler 71 Jan 05, 2023
QRec: A Python Framework for quick implementation of recommender systems (TensorFlow Based)

Introduction QRec is a Python framework for recommender systems (Supported by Python 3.7.4 and Tensorflow 1.14+) in which a number of influential and

Yu 1.4k Jan 01, 2023
Much faster than SORT (Simple Online and Realtime Tracking), a little worse than SORT

QSORT QSORT(Quick + Simple Online and Realtime Tracking) is a simple online and realtime tracking algorithm for 2D multiple object tracking in video s

Yonghye Kwon 8 Jul 27, 2022
Keras Model Implementation Walkthrough

Keras Model Implementation Walkthrough

Luke Wood 17 Sep 27, 2022
Unsupervised 3D Human Mesh Recovery from Noisy Point Clouds

Unsupervised 3D Human Mesh Recovery from Noisy Point Clouds Xinxin Zuo, Sen Wang, Minglun Gong, Li Cheng Prerequisites We have tested the code on Ubun

41 Dec 12, 2022
On Effective Scheduling of Model-based Reinforcement Learning

On Effective Scheduling of Model-based Reinforcement Learning Code to reproduce the experiments in On Effective Scheduling of Model-based Reinforcemen

laihang 8 Oct 07, 2022
Package to compute Mauve, a similarity score between neural text and human text. Install with `pip install mauve-text`.

MAUVE MAUVE is a library built on PyTorch and HuggingFace Transformers to measure the gap between neural text and human text with the eponymous MAUVE

Krishna Pillutla 182 Jan 02, 2023
3D-aware GANs based on NeRF (arXiv).

CIPS-3D This repository will contain the code of the paper, CIPS-3D: A 3D-Aware Generator of GANs Based on Conditionally-Independent Pixel Synthesis.

Peterou 563 Dec 31, 2022
HCQ: Hybrid Contrastive Quantization for Efficient Cross-View Video Retrieval

HCQ: Hybrid Contrastive Quantization for Efficient Cross-View Video Retrieval [toc] 1. Introduction This repository provides the code for our paper at

13 Dec 08, 2022
yufan 81 Dec 08, 2022