
Facebook’s Recently Released Voice Separation Technique SVoice – Complete Guide With Python Code

SVoice is Facebook Research’s newly achieved state-of-the-art speech separation technique for multiple voices speaking simultaneously in a single audio sequence


SVoice is Facebook Research’s newly achieved state-of-the-art speech separation technique for multiple voices speaking simultaneously in a single audio sequence. The technique was presented at ICML (International Conference on Machine Learning) in the research paper titled “Voice Separation with an Unknown Number of Multiple Speakers” by Eliya Nachmani, Yossi Adi, and Lior Wolf.

Voice separation has been a challenging problem for ages. Earlier methods relied on unsupervised learning or on multiple microphones; with the advent of neural networks and deep learning, better solutions emerged. SVoice performs supervised voice separation from a single microphone, i.e., single-channel source separation of mixed voices. It is a mask-free method built on RNNs.

A separate model is trained for each number of speakers, and the model trained on the largest number of speakers can reveal the actual number of distinct voices in the original sample. The speaker assigned to each output channel remains fixed across the output. The method sets a new benchmark: using multiple loss functions, it improves SI-SNR (scale-invariant signal-to-noise ratio, a common measure of separation quality) by more than 1.5 dB (decibels) over previous state-of-the-art voice separation methods. So far, the model works for up to 5 voices. The datasets used are WHAM!, WHAMR!, and WSJ0-2mix, WSJ0-3mix, WSJ0-4mix, and WSJ0-5mix. Have a look at the following demonstration video along with the working architecture.

https://www.youtube.com/watch?v=e7QT7dD9-J8
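Because the model trained on the largest number of speakers produces near-silent output channels when fewer voices are present, the number of speakers can be estimated from channel activity. Here is a minimal sketch of that idea (our own illustration; the helper name and the threshold value are assumptions, not taken from the paper’s code):

 import torch

 def count_active_speakers(estimates: torch.Tensor, threshold_db: float = -40.0) -> int:
     # estimates: (C, T) output channels from the model trained on the most speakers;
     # channels whose energy relative to the loudest channel falls below the
     # threshold are treated as silent
     power = estimates.pow(2).mean(dim=-1)
     rel_db = 10 * torch.log10(power / power.max().clamp_min(1e-8) + 1e-8)
     return int((rel_db > threshold_db).sum().item())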

Architecture

The processing pipeline consists of encoding the waveform, chunking it, and applying RNN blocks to the resulting tensor. The RNNs here have dual heads, no masking is used, and the training losses differ from those of earlier methods.

The MULCAT block, short for multiply-and-concat, works as follows: in the odd blocks, the 3D tensor obtained from chunking is fed to two bi-directional LSTMs that operate along the second dimension. Their outputs are multiplied element-wise, followed by a concatenation with the original signal along the third dimension. A linear projection along this dimension is then applied to restore a tensor of the same size as the input. In the even blocks, the same operations take place along the chunking axis.
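In tensor terms, the core operation looks like this (a toy illustration of ours, with arbitrary sizes, not code from the repository):

 import torch
 import torch.nn as nn

 a = torch.randn(8, 100, 64)        # projected output of LSTM 1
 b = torch.randn(8, 100, 64)        # projected output of LSTM 2
 x = torch.randn(8, 100, 64)        # original block input
 y = torch.cat([a * b, x], dim=2)   # multiply element-wise, then concat: (8, 100, 128)
 y = nn.Linear(128, 64)(y)          # linear projection back to the input size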

Speaker Classification Loss Terms

Consider the training losses for the case of C = 2 speakers. The mixed signal x combines the two input audio signals s1 and s2, and the model separates it into two output channels ŝ1 and ŝ2. The permutation-invariant training loss (called uPIT) computes the SI-SNR between the output channels and the ground-truth channels under the channel permutation π that minimizes the loss. Lastly, an identity loss is computed on the matching channels after they have been ordered by π.
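As a concrete reference, here is a minimal sketch of a permutation-invariant SI-SNR loss (our own simplified version with hypothetical helper names, not the repository’s implementation):

 import itertools
 import torch

 def si_snr(est, ref, eps=1e-8):
     # scale-invariant SNR (in dB) between estimated and reference sources, shape (..., T)
     est = est - est.mean(dim=-1, keepdim=True)
     ref = ref - ref.mean(dim=-1, keepdim=True)
     proj = (est * ref).sum(-1, keepdim=True) * ref / (ref.pow(2).sum(-1, keepdim=True) + eps)
     noise = est - proj
     return 10 * torch.log10(proj.pow(2).sum(-1) / (noise.pow(2).sum(-1) + eps) + eps)

 def upit_si_snr_loss(est, ref):
     # est, ref: (B, C, T); search all channel permutations π and keep the best one
     C = est.shape[1]
     losses = [-si_snr(est[:, perm, :], ref).mean(dim=1)
               for perm in itertools.permutations(range(C))]
     return torch.stack(losses, dim=1).min(dim=1).values.mean()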

Code Snippet:

The code for svoice is implemented in PyTorch. 

# importing libraries

 import sys
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd import Variable
 from ..utils import overlap_and_add
 from ..utils import capture_init 

 class ByPass(nn.Module):
     # identity module, used below when input normalization is disabled
     def __init__(self):
         super(ByPass, self).__init__()

     def forward(self, input):
         return input

Each RNN block is built around the MULCAT block together with a skip connection; the overall model comprises two sub-networks, an encoder-decoder and a separator. Inside the MULCAT block, two separate bidirectional LSTMs process the input; their outputs are multiplied element-wise, and the product is then concatenated with the input to produce the module output.

 class MulCatBlock(nn.Module):
     def __init__(self, input_size, hidden_size, dropout=0, bidirectional=False):
         super(MulCatBlock, self).__init__()
         self.input_size = input_size
         self.hidden_size = hidden_size
         self.num_direction = int(bidirectional) + 1
         self.rnn = nn.LSTM(input_size, hidden_size, 1, dropout=dropout,
                            batch_first=True, bidirectional=bidirectional)
         self.rnn_proj = nn.Linear(hidden_size * self.num_direction, input_size)
         self.gate_rnn = nn.LSTM(input_size, hidden_size, num_layers=1,
                                 batch_first=True, dropout=dropout, bidirectional=bidirectional)
         self.gate_rnn_proj = nn.Linear(
             hidden_size * self.num_direction, input_size)
         self.block_projection = nn.Linear(input_size * 2, input_size) 
     def forward(self, input):
         output = input
         # run rnn module
         rnn_output, _ = self.rnn(output)
         rnn_output = self.rnn_proj(rnn_output.contiguous(
         ).view(-1, rnn_output.shape[2])).view(output.shape).contiguous() 
         # run gate rnn module
         gate_rnn_output, _ = self.gate_rnn(output)
         gate_rnn_output = self.gate_rnn_proj(gate_rnn_output.contiguous(
         ).view(-1, gate_rnn_output.shape[2])).view(output.shape).contiguous() 
         # apply gated rnn
         gated_output = torch.mul(rnn_output, gate_rnn_output)
         gated_output = torch.cat([gated_output, output], 2)
         gated_output = self.block_projection(
             gated_output.contiguous().view(-1, gated_output.shape[2])).view(output.shape)
         return gated_output 
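As a quick sanity check (our own toy example, with arbitrary sizes), the block maps a (batch, steps, features) tensor to one of the same shape:

 block = MulCatBlock(input_size=64, hidden_size=128, bidirectional=True)
 x = torch.randn(8, 100, 64)    # (batch, time steps, features)
 print(block(x).shape)          # torch.Size([8, 100, 64]) – same size as the input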
 class DPMulCat(nn.Module):
     def __init__(self, input_size, hidden_size, output_size, num_spk,
                  dropout=0, num_layers=1, bidirectional=True, input_normalize=False):
         super(DPMulCat, self).__init__()
         self.input_size = input_size
         self.output_size = output_size
         self.hidden_size = hidden_size
         self.in_norm = input_normalize
         self.num_layers = num_layers
         self.rows_grnn = nn.ModuleList([])
         self.cols_grnn = nn.ModuleList([])
         self.rows_normalization = nn.ModuleList([])
         self.cols_normalization = nn.ModuleList([]) 

         # creating the dual path pipeline

         for i in range(num_layers):
             self.rows_grnn.append(MulCatBlock(
                 input_size, hidden_size, dropout, bidirectional=bidirectional))
             self.cols_grnn.append(MulCatBlock(
                 input_size, hidden_size, dropout, bidirectional=bidirectional))
             if self.in_norm:
                 self.rows_normalization.append(
                     nn.GroupNorm(1, input_size, eps=1e-8))
                 self.cols_normalization.append(
                     nn.GroupNorm(1, input_size, eps=1e-8)) 

A multi-scale loss is deployed, which requires reconstructing the original audio after each pair of blocks. Before reconstruction, the 3D tensor passes through a PReLU non-linear activation function followed by a 1×1 convolution.

             else:
                 # disabling normalization
                 self.rows_normalization.append(ByPass())
                 self.cols_normalization.append(ByPass())
         self.output = nn.Sequential(
             nn.PReLU(), nn.Conv2d(input_size, output_size * num_spk, 1))
     def forward(self, input):
         batch_size, _, d1, d2 = input.shape
         output = input
         output_all = []
         for i in range(self.num_layers):
             row_input = output.permute(0, 3, 2, 1).contiguous().view(
                 batch_size * d2, d1, -1)
             row_output = self.rows_grnn[i](row_input)
             row_output = row_output.view(
                 batch_size, d2, d1, -1).permute(0, 3, 2, 1).contiguous()
             row_output = self.rows_normalization[i](row_output) 

             # applying a skip connection

             output = output + row_output
             col_input = output.permute(0, 2, 3, 1).contiguous().view(
                 batch_size * d1, d2, -1)
             col_output = self.cols_grnn[i](col_input)
             col_output = col_output.view(
                 batch_size, d1, d2, -1).permute(0, 3, 1, 2).contiguous()
             col_output = self.cols_normalization[i](col_output).contiguous() 

             # applying a skip connection

             output = output + col_output
             output_i = self.output(output)
             output_all.append(output_i)
         return output_all 
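During training, each tensor in the returned list is decoded back to audio and scored, which yields the multi-scale loss mentioned above. A hedged sketch, reusing the hypothetical upit_si_snr_loss helper from earlier:

 def multiscale_loss(outputs, targets):
     # outputs: list of (B, C, T) waveform estimates, one per pair of RNN blocks
     # targets: (B, C, T) ground-truth sources
     return sum(upit_si_snr_loss(o, targets) for o in outputs) / len(outputs)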

The separation network consists of B RNN blocks. The odd blocks apply the RNN along the time-dependent dimension of size R, while the even blocks are applied along the chunking dimension of size K. Processing the second dimension yields a short-term representation, while processing the third dimension yields a long-term representation.
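To make the chunking concrete, here is a small standalone sketch (ours, with arbitrary sizes) of splitting a feature tensor into 50%-overlapping chunks:

 import torch
 import torch.nn.functional as F

 T, K = 1000, 126                  # sequence length and chunk size
 stride = K // 2                   # 50% overlap
 x = torch.randn(1, 64, T)         # (batch, N, T) encoder features
 pad = (stride - (T - K) % stride) % stride
 x = F.pad(x, (0, pad))            # pad so the last chunk is full
 chunks = x.unfold(2, K, stride)   # (batch, N, num_chunks, K)
 print(chunks.shape)               # torch.Size([1, 64, 15, 126])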

 class Separator(nn.Module):
     def __init__(self, input_dim, feature_dim, hidden_dim, output_dim, num_spk=2,
                  layer=4, segment_size=100, input_normalize=False, bidirectional=True):
         super(Separator, self).__init__()
         self.input_dim = input_dim
         self.feature_dim = feature_dim
         self.hidden_dim = hidden_dim
         self.output_dim = output_dim
         self.layer = layer
         self.segment_size = segment_size
         self.num_spk = num_spk
         self.input_normalize = input_normalize
         self.rnn_model = DPMulCat(self.feature_dim, self.hidden_dim,
                                   self.feature_dim, self.num_spk, num_layers=layer, bidirectional=bidirectional, input_normalize=input_normalize) 
     def pad_segment(self, input, segment_size):
         # input features: (B, N, T)
         batch_size, dim, seq_len = input.shape
         segment_stride = segment_size // 2
         rest = segment_size - (segment_stride + seq_len %
                                segment_size) % segment_size
         if rest > 0:
             pad = Variable(torch.zeros(batch_size, dim, rest)
                            ).type(input.type())
             input = torch.cat([input, pad], 2)
         pad_aux = Variable(torch.zeros(
             batch_size, dim, segment_stride)).type(input.type())
         input = torch.cat([pad_aux, input, pad_aux], 2)
         return input, rest 
     def create_chuncks(self, input, segment_size):
         # splitting the feature into chunks of the segment size
         # input features: (B, N, T)
         input, rest = self.pad_segment(input, segment_size)
         batch_size, dim, seq_len = input.shape
         segment_stride = segment_size // 2
         segments1 = input[:, :, :-segment_stride].contiguous().view(
             batch_size, dim, -1, segment_size)
         segments2 = input[:, :, segment_stride:].contiguous().view(
             batch_size, dim, -1, segment_size)
         segments = torch.cat([segments1, segments2], 3).view(
             batch_size, dim, -1, segment_size).transpose(2, 3)
         return segments.contiguous(), rest 
     def merge_chuncks(self, input, rest):
         # merging the split features into a full utterance
         # input features: (B, N, L, K)
         batch_size, dim, segment_size, _ = input.shape
         segment_stride = segment_size // 2
         # (B, N, K, L)
         input = input.transpose(2, 3).contiguous().view(
             batch_size, dim, -1, segment_size * 2)
         input1 = input[:, :, :, :segment_size].contiguous().view(
             batch_size, dim, -1)[:, :, segment_stride:]
         input2 = input[:, :, :, segment_size:].contiguous().view(
             batch_size, dim, -1)[:, :, :-segment_stride]
         output = input1 + input2
         if rest > 0:
             output = output[:, :, :-rest]
         # (B, N, T)
         return output.contiguous()  
     def forward(self, input):
         # create chunks
         enc_segments, enc_rest = self.create_chuncks(
             input, self.segment_size)
         # separation
         output_all = self.rnn_model(enc_segments)
         # merging back audio files
         output_all_wav = []
         for ii in range(len(output_all)):
             output_ii = self.merge_chuncks(
                 output_all[ii], enc_rest)
             output_all_wav.append(output_ii)
         return output_all_wav 
 class SWave(nn.Module):
     def __init__(self, N, L, H, R, C, sr, segment, input_normalize):
         super(SWave, self).__init__()
         # hyper-parameter declaration
         self.N, self.L, self.H, self.R, self.C, self.sr, self.segment = N, L, H, R, C, sr, segment
         self.input_normalize = input_normalize
         self.context_len = 2 * self.sr / 1000
         self.context = int(self.sr * self.context_len / 1000)
         self.layer = self.R
         self.filter_dim = self.context * 2 + 1
         self.num_spk = self.C
         # setting the chunk size to sqrt(2 * number of frames)
         self.segment_size = int(
             np.sqrt(2 * self.sr * self.segment / (self.L / 2))) 

         # model sub-networks

         self.encoder = Encoder(L, N)
         self.decoder = Decoder(L)
         self.separator = Separator(self.filter_dim + self.N, self.N, self.H,
                                    self.filter_dim, self.num_spk, self.layer, self.segment_size, self.input_normalize)
         for p in self.parameters():
             if p.dim() > 1:
                 nn.init.xavier_normal_(p)
     def forward(self, mixture):
         mixture_w = self.encoder(mixture)
         output_all = self.separator(mixture_w)
         # fixing the time dimension, which might change due to convolution operations
         T_mix = mixture.size(-1)
         # generating a wav after each RNN block and optimizing the loss
         outputs = []
         for ii in range(len(output_all)):
             output_ii = output_all[ii].view(
                 mixture.shape[0], self.C, self.N, mixture_w.shape[2])
             output_ii = self.decoder(output_ii)
             T_est = output_ii.size(-1)
             output_ii = F.pad(output_ii, (0, T_mix - T_est))
             outputs.append(output_ii)
         return torch.stack(outputs) 

The encoder network E takes the mixture waveform as input and outputs an N-dimensional latent representation. E is a 1-dimensional convolutional layer with a kernel size L and a stride of L/2, followed by the ReLU non-linear activation function.

 class Encoder(nn.Module):
     def __init__(self, L, N):
         super(Encoder, self).__init__()
         self.L, self.N = L, N
         # setting 50% overlap
         self.conv = nn.Conv1d(
             1, N, kernel_size=L, stride=L // 2, bias=False)
     def forward(self, mixture):
         mixture = torch.unsqueeze(mixture, 1)
         mixture_w = F.relu(self.conv(mixture))
         return mixture_w 
 class Decoder(nn.Module):
     def __init__(self, L):
         super(Decoder, self).__init__()
         self.L = L
     def forward(self, est_source):
         est_source = torch.transpose(est_source, 2, 3)
         est_source = nn.AvgPool2d((1, self.L))(est_source)
         est_source = overlap_and_add(est_source, self.L//2)
         return est_source 
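The overlap_and_add helper imported at the top comes from the repository’s utils module; conceptually, it sums overlapping frames back into a waveform. A minimal sketch of the idea (our own simplified version, not the repo’s implementation):

 def overlap_and_add_sketch(frames, frame_step):
     # frames: (..., num_frames, frame_length); the output length is
     # (num_frames - 1) * frame_step + frame_length
     *batch, num_frames, frame_length = frames.shape
     out = frames.new_zeros(*batch, (num_frames - 1) * frame_step + frame_length)
     for i in range(num_frames):
         start = i * frame_step
         out[..., start:start + frame_length] += frames[..., i, :]
     return out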

Training

The Hydra framework is designed for building research applications, and svoice uses it during training to manage hierarchical configurations.

 import json
 import logging
 import os
 import subprocess as sp
 from omegaconf import DictConfig, OmegaConf
 import hydra
 from svoice.executor import start_ddp_workers 

 logger = logging.getLogger(__name__)
 def run(args):
     import torch
     from svoice import distrib
     from svoice.data.data import Trainset, Validset
     from svoice.models.swave import SWave
     from svoice.solver import Solver
     if args.model == "swave":
         kwargs = dict(args.swave)
         kwargs['sr'] = args.sample_rate
         kwargs['segment'] = args.segment
         model = SWave(**kwargs)
     else:
         logger.fatal("Invalid model name %s", args.model)
         os._exit(1) 

     # a specific number of samples is required to avoid 0 padding during training

     if hasattr(model, 'valid_length'):
         segment_len = int(args.segment * args.sample_rate)
         segment_len = model.valid_length(segment_len)
         args.segment = segment_len / args.sample_rate 
     if args.show:
         logger.info(model)
         mb = sum(p.numel() for p in model.parameters()) * 4 / 2**20
         logger.info('Size: %.1f MB', mb)
         if hasattr(model, 'valid_length'):
             field = model.valid_length(1)
             logger.info('Field: %.1f ms', field / args.sample_rate * 1000)
         return
     assert args.batch_size % distrib.world_size == 0
     args.batch_size //= distrib.world_size 

     # Building datasets and loaders

     tr_dataset = Trainset(
         args.dset.train, sample_rate=args.sample_rate, segment=args.segment, stride=args.stride, pad=args.pad)
     tr_loader = distrib.loader(
         tr_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

     # batch_size=1 -> use less GPU memory to do cv
     cv_dataset = Validset(args.dset.valid)
     tt_dataset = Validset(args.dset.test)
     cv_loader = distrib.loader(
         cv_dataset, batch_size=1, num_workers=args.num_workers)
     tt_loader = distrib.loader(
         tt_dataset, batch_size=1, num_workers=args.num_workers)
     data = {"tr_loader": tr_loader,
             "cv_loader": cv_loader, "tt_loader": tt_loader}

     # initializing optimizer
     if args.optim == "adam":
         optimizer = torch.optim.Adam(
             model.parameters(), lr=args.lr, betas=(0.9, args.beta2))
     else:
         logger.fatal('Invalid optimizer %s', args.optim)
         os._exit(1) 

     # Constructing Solver

     solver = Solver(data, model, optimizer, args)
     solver.train() 
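For completeness, here is a sketch of the Hydra entry point that wires everything to run() (the decorator arguments and config names are assumptions based on the standard Hydra pattern, not copied from the repository; the real script also dispatches to start_ddp_workers for distributed training):

 @hydra.main(config_path="conf", config_name="config")
 def main(args: DictConfig):
     logger.info(OmegaConf.to_yaml(args))  # log the resolved hierarchical config
     run(args)

 if __name__ == "__main__":
     main()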

Benchmark Results

Comparison of the performance of various models for different numbers of speakers; the baselines are taken from the respective published papers.

Starred results (*) indicate models retrained by the paper’s authors using the code published by each method’s authors.

Comparison of svoice against several benchmarks on the WHAM! and WHAMR! datasets.

SI-SNRi curve
