Hands-on Guide To GANSynth: An Adversarial Neural Audio Synthesis Technique



GANSynth is a state-of-the-art method for synthesizing high-fidelity, locally coherent audio with Generative Adversarial Networks (GANs), hence the name GANSynth (a GAN for audio synthesis). It was introduced in 2019 by Jesse Engel, Kumar Krishna Agrawal, Shuo Chen, Ishaan Gulrajani, Chris Donahue and Adam Roberts, researchers at Google AI (research paper).

Autoregressive models such as WaveNet generate audio one sample at a time. GANSynth, by contrast, generates the whole sequence in parallel, synthesizing audio significantly faster than real time on a GPU. It generates an entire audio clip from a single latent vector, which makes it easier to disentangle global features such as pitch and timbre (tone quality). It uses a progressive GAN architecture. GANs already offer global latent conditioning and efficient parallel sampling, but earlier GANs struggled to synthesize locally coherent audio waveforms. GANSynth overcomes this by generating log-magnitude spectrograms and instantaneous frequencies (the rate of change of phase) in the spectral domain, with sufficient frequency resolution, instead of raw waveforms.
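GANSynth's phase representation can be illustrated with a minimal NumPy sketch (a simplification, not Magenta's own implementation). The wrapped phase of a steady tone keeps jumping between -π and π, but once it is unwrapped along time, its frame-to-frame difference, the instantaneous frequency, is constant and therefore much easier for a network to model. The helper name `instantaneous_frequency` and the toy signal are illustrative assumptions.

```python
import numpy as np

def instantaneous_frequency(phase, time_axis=-1):
    """Hypothetical helper: unwrap phase along time, then take frame-to-frame differences.

    Returns the phase advance per frame in radians (one value fewer than the input).
    """
    unwrapped = np.unwrap(phase, axis=time_axis)  # remove the 2*pi jumps
    return np.diff(unwrapped, axis=time_axis)     # rate of change of phase

# Toy example: a steady tone whose phase advances by 0.5 rad per frame
true_phase = 0.5 * np.arange(100)
wrapped = np.angle(np.exp(1j * true_phase))  # what a spectrogram reports, in (-pi, pi]

print(np.abs(np.diff(wrapped)).max() > np.pi)              # True: wrapped phase has 2*pi jumps
print(np.allclose(instantaneous_frequency(wrapped), 0.5))  # True: the IF is constant
```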


Are you interested in understanding the detailed workings of GANSynth? Refer to this page before proceeding!

Practical Implementation of GANSynth

Here is a demonstration of how GANSynth, trained on the NSynth dataset (a large-scale, high-quality dataset of annotated musical notes), produces musical notes of individual instruments. The GAN learns to use its latent space to represent different instrument timbres. The demo synthesizes audio from MIDI files and interpolates between different instruments. The code has been implemented in Google Colab using Python 3.7.10.

Step-wise explanation of the code is as follows:

1) Install Magenta (an open-source Python library powered by TensorFlow)
 # Create working directories for the MIDI files and generated samples
 !rm -r /content/gansynth &>/dev/null
 !mkdir /content/gansynth
 !mkdir /content/gansynth/midi
 !mkdir /content/gansynth/samples
 # Load the default MIDI file (Bach Prelude)
 # curl downloads the file at the given URL
 !curl -o /content/gansynth/midi/bach.mid 

The -o option of curl saves the downloaded file (here, on the Colab virtual machine) under the file name given as its argument.

 SONG = '/content/gansynth/midi/bach.mid'
 !curl -o /content/gansynth/midi/riff-default.mid
 RIFF = '/content/gansynth/midi/riff-default.mid'
 !pip install -q -U magenta 
2) Import the required libraries and classes
 import os  # interacting with the operating system
 # To upload files from the local machine into Colab
 from google.colab import files
 import librosa  # Python library for music and audio analysis
 from magenta.models.nsynth.utils import load_audio
 from magenta.models.gansynth.lib import flags as lib_flags
 from magenta.models.gansynth.lib import generate_util as gu
 from magenta.models.gansynth.lib import model as lib_model
 from magenta.models.gansynth.lib import util
 import matplotlib.pyplot as plt  # for visualization
 import note_seq
 # colab_play() inserts an HTML audio widget to play a sound in Colab
 from note_seq.notebook_utils import colab_play as play
 import numpy as np
 import tensorflow.compat.v1 as tf
 # disable_v2_behavior() switches all global behaviors that differ between
 # TensorFlow 1.x and 2.x to behave as in 1.x
 tf.disable_v2_behavior()
3) Define a function for uploading MIDI files
 def upload():
   uploaded = files.upload()  # dict mapping each uploaded file name to its content (bytes)
   file_list = []  # names of the saved files
   # Use items() to iterate over the (name, content) pairs of the uploaded files
   for key, val in uploaded.items():
     filename = os.path.join('/content/gansynth/midi', key)
     with open(filename, 'wb') as f:  # open the file in binary write mode
       print('Writing the file {}'.format(filename))
       f.write(val)  # write the uploaded content to the file
     file_list.append(filename)  # add the file name to the list
   return file_list
4) Define global variables
 # Checkpoint directory of the pre-trained model
 CHECKPOINT_DIR = 'gs://magentadata/models/gansynth/acoustic_only'
 OP_DIR = '/content/gansynth/samples'  # output directory
 BATCH_SIZE = 16  # batch size used when generating samples
 SR = 16000  # SR stands for sample rate
5) Create an output directory if it does not exist
 # Expand the output directory path using util.expand_path()
 OP_DIR = util.expand_path(OP_DIR)
 # tf.gfile.Exists() checks whether a file or directory exists
 if not tf.gfile.Exists(OP_DIR):
   # Create the directory using tf.gfile.MakeDirs()
   tf.gfile.MakeDirs(OP_DIR)
6) Load the model
 # Clear the default graph stack and reset the global default graph
 tf.reset_default_graph()
 # Flags object for storing and accessing configuration values
 myflags = lib_flags.Flags({
     'batch_size_schedule': [BATCH_SIZE],
     'tfds_data_dir': "gs://tfds-data/datasets",
 })
 # Create a GAN model using the flags and the weights of a saved checkpoint
 model = lib_model.Model.load_from_path(CHECKPOINT_DIR, myflags)
7) Define a function for loading a MIDI file as a NoteSequence
 def load_midi(path, minimumPitch=36, maximumPitch=84):
   midiPath = util.expand_path(path)  # expand the file path
   noteSequence = note_seq.midi_file_to_sequence_proto(midiPath)
   # NumPy arrays of the pitches, velocities, start times and end times of the notes
   pitches = np.array([nt.pitch for nt in noteSequence.notes])
   velo = np.array([nt.velocity for nt in noteSequence.notes])
   startTimes = np.array([nt.start_time for nt in noteSequence.notes])
   endTimes = np.array([nt.end_time for nt in noteSequence.notes])
   # Keep only the notes within the required pitch range
   valid = np.logical_and(pitches >= minimumPitch, pitches <= maximumPitch)
   # Store the valid notes' features in a dictionary
   notes = {'pitches': pitches[valid],
            'velocities': velo[valid],
            'startTimes': startTimes[valid],
            'endTimes': endTimes[valid]}
   return noteSequence, notes
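The pitch filtering in the MIDI-loading function relies on a NumPy boolean mask: `np.logical_and` marks the notes whose pitch lies in the inclusive range, and indexing every feature array with the same mask keeps the arrays aligned. A minimal sketch on a few hypothetical notes (the note values are made up for illustration):

```python
import numpy as np

# Hypothetical (pitch, velocity, start_time, end_time) tuples standing in for noteSequence.notes
raw_notes = [(30, 80, 0.0, 0.5), (60, 100, 0.5, 1.0), (72, 90, 1.0, 1.5), (90, 70, 1.5, 2.0)]

pitches = np.array([n[0] for n in raw_notes])
velocities = np.array([n[1] for n in raw_notes])
start_times = np.array([n[2] for n in raw_notes])

min_pitch, max_pitch = 36, 84  # same defaults as the loader above (C2 to C6)
valid = np.logical_and(pitches >= min_pitch, pitches <= max_pitch)

print(valid.tolist())               # [False, True, True, False]
print(pitches[valid].tolist())      # [60, 72]
print(velocities[valid].tolist())   # [100, 90]
print(start_times[valid].tolist())  # [0.5, 1.0]
```

The out-of-range notes (MIDI 30 and 90) are dropped from every array at once, so index i still refers to the same note in each filtered array.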
8) Create an attack-sustain-release amplitude envelope (attack, sustain and release are stages of an envelope generator)

‘Attack’ is the part of the envelope in which the amplitude rises to its peak. ‘Sustain’ is the duration for which the sound is held before it fades out. ‘Release’ is the final decay of the amplitude to zero.

 def createEnvelope(note_length, attack=0.010, release=0.3, sr=16000):
   # sr is the sample rate; note_length, attack and release are in seconds
   note_len = min(note_length, 3.0)  # cap the sustain at 3 seconds
   attack = int(sr * attack)
   sustain = int(sr * note_len)
   release = int(sr * release)
   total = sustain + release  # the attack overlaps the sustain, so it doesn't add to the sound length
   env = np.ones(total)  # start with an array of 1's of length total
   # Linear attack: 'attack' evenly spaced values from 0 to 1
   env[:attack] = np.linspace(0.0, 1.0, attack)
   # Linear release: 'release' evenly spaced values from 1 to 0
   env[sustain:total] = np.linspace(1.0, 0.0, release)
   return env
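To see the shape that createEnvelope() produces, here is a standalone copy of its logic evaluated at a toy sample rate of 10 Hz, so the whole envelope fits on one line. The only deliberate change is `int(round(...))` in place of `int(...)`, an assumption added to guard against a floating-point product landing just below an integer.

```python
import numpy as np

def create_envelope(note_length, attack=0.010, release=0.3, sr=16000):
    """Attack/sustain/release envelope mirroring createEnvelope() above."""
    note_len = min(note_length, 3.0)          # cap the sustain at 3 seconds
    n_attack = int(round(sr * attack))
    n_sustain = int(round(sr * note_len))
    n_release = int(round(sr * release))
    total = n_sustain + n_release             # the attack overlaps the sustain
    env = np.ones(total)
    env[:n_attack] = np.linspace(0.0, 1.0, n_attack)         # linear rise
    env[n_sustain:total] = np.linspace(1.0, 0.0, n_release)  # linear fade
    return env

# Toy sample rate of 10 Hz: 2 attack samples, 5 sustain samples, 3 release samples
env = create_envelope(0.5, attack=0.2, release=0.3, sr=10)
print(env.tolist())  # [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.0]
```

At the real sample rate of 16000 Hz, a 1-second note yields 16000 + 4800 = 20800 envelope samples, and any note longer than 3 seconds is capped at 3 × 16000 + 4800 = 52800.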
9) Define a function to combine multiple notes into a single audio clip
 def combine_notes(audio, start, end, velo, sr=16000):
   # 'audio' is an array of audio notes, 'start' is an array of the notes' start
   # times in seconds, 'end' is an array of their end times in seconds,
   # 'velo' holds their MIDI velocities and 'sr' is the sample rate (integer)
   numberOfNotes = len(audio)  # number of notes
   clipLen = end.max() + 3.0  # length of the audio clip in seconds
   clip = np.zeros(int(clipLen * sr))  # empty audio clip buffer
   for t_start, t_end, velocity, i in zip(start, end, velo, range(numberOfNotes)):
     # Generate an amplitude envelope using createEnvelope() defined above
     noteLen = t_end - t_start  # note length in seconds
     env = createEnvelope(noteLen, sr=sr)
     length = len(env)  # length of the generated envelope in samples
     audio_note = audio[i, :length] * env
     # Normalize the note and scale it by its velocity
     audio_note /= audio_note.max()
     audio_note *= (velocity / 127.0)
     clipStart = int(t_start * sr)  # start index of the note in the clip
     clipEnd = clipStart + length  # end index of the note in the clip
     # Add the audio note to the clip buffer
     clip[clipStart:clipEnd] += audio_note
   # Normalize the audio clip once all notes have been added
   clip /= clip.max()
   clip /= 2.0
   return clip  # array of combined audio samples
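The overlap-add idea behind combine_notes() can be tried without the model. The sketch below uses a hypothetical, simplified helper, `mix_notes`, that drops the amplitude envelope and normalizes by the absolute peak (the article's function applies the envelope and uses `.max()`). It uses a toy 100 Hz sample rate, with synthetic sine tones standing in for GANSynth output.

```python
import numpy as np

def mix_notes(notes_audio, starts, velocities, sr):
    """Hypothetical simplified mixer: place each note at its onset and sum overlaps."""
    ends = [s + len(a) / sr for s, a in zip(starts, notes_audio)]
    clip = np.zeros(int(np.ceil(max(ends) * sr)))
    for audio, t_start, velocity in zip(notes_audio, starts, velocities):
        note = audio / np.abs(audio).max()  # peak-normalize each note
        note = note * (velocity / 127.0)    # scale by MIDI velocity (0 to 127)
        i0 = int(t_start * sr)              # onset in samples
        clip[i0:i0 + len(note)] += note     # overlapping notes add up
    clip /= np.abs(clip).max()              # normalize the whole mix
    return clip / 2.0                       # halve it to leave headroom

sr = 100                                    # toy sample rate
t = np.arange(50) / sr                      # 0.5 s per note
a = np.sin(2 * np.pi * 5 * t)
b = np.sin(2 * np.pi * 8 * t)
mix = mix_notes([a, b], starts=[0.0, 0.25], velocities=[100, 60], sr=sr)
print(mix.shape)                            # (75,)
```

Normalizing once after summing, as combine_notes() also does, keeps overlapping notes from clipping while preserving their relative velocities.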
10) Define a function to plot a spectrogram
 def spectrogram(audioClip):
   minPitch = 36  # lowest MIDI pitch to display
   maxPitch = 84  # highest MIDI pitch to display
   # Convert MIDI pitches to frequencies in Hertz (Hz)
   minF = librosa.midi_to_hz(minPitch)  # minimum frequency
   maxF = 2 * librosa.midi_to_hz(maxPitch)  # maximum frequency (one octave above maxPitch)
   # Number of octaves between minF and maxF
   octaves = int(np.ceil(np.log2(maxF) - np.log2(minF)))
   binsPerOctave = 36  # number of frequency bins per octave
   nBins = int(binsPerOctave * octaves)  # total number of bins
   # Calculate the constant-Q transform (CQT) of the audio signal:
   # 'y' is the audio time series, 'sr' is its sampling rate,
   # 'hop_length' is the number of samples between successive CQT columns,
   # 'fmin' is the minimum frequency and 'n_bins' is the number of frequency bins
   C = librosa.cqt(y=audioClip, sr=SR, hop_length=2048, fmin=minF,
                   n_bins=nBins, bins_per_octave=binsPerOctave)
   # Compute the power of the signal in decibels
   power = 10 * np.log10(np.abs(C)**2 + 1e-6)
   # Display the 'power' array as a matrix using matplotlib's matshow(),
   # flipped so that low frequencies are at the bottom
   plt.matshow(power[::-1, 2:-2], aspect='auto')
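The bin arithmetic in spectrogram() can be checked by hand. With the standard equal-temperament mapping f = 440 · 2^((m − 69)/12), MIDI 36 is about 65.41 Hz and twice the frequency of MIDI 84 is about 2093 Hz. Doubling a frequency adds 12 semitones, so the range spans (84 + 12 − 36) / 12 = 5 octaves, or 5 × 36 = 180 bins. A small pure-Python sketch (the `midi_to_hz` helper is written out here for illustration rather than imported from librosa):

```python
import math

def midi_to_hz(m):
    """Equal-temperament MIDI-to-frequency mapping with A4 (MIDI 69) = 440 Hz."""
    return 440.0 * 2.0 ** ((m - 69) / 12.0)

min_pitch, max_pitch = 36, 84
f_min = midi_to_hz(min_pitch)        # C2, about 65.41 Hz
f_max = 2 * midi_to_hz(max_pitch)    # one octave above C6, about 2093.0 Hz

# Octaves from the frequency ratio, and the same count from MIDI numbers:
# doubling f_max adds 12 semitones, so the top of the range is MIDI 96
octaves_from_hz = math.log2(f_max / f_min)
octaves_from_midi = (max_pitch + 12 - min_pitch) // 12

bins_per_octave = 36                 # 3 bins per semitone
n_bins = bins_per_octave * octaves_from_midi

print(round(f_min, 2), round(f_max, 1))                      # 65.41 2093.0
print(round(octaves_from_hz, 6), octaves_from_midi, n_bins)  # 5.0 5 180
```

Because the article's code applies np.ceil to a floating-point difference, a tiny rounding error could in principle push the octave count to 6; comparing against this hand calculation is a quick sanity check.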
11) Choose the MIDI file
midi_file = "Arpeggio (Default)" #@param ["Arpeggio (Default)", "Upload your own"]

This Colab form field lets you either use the default MIDI file (the riff downloaded in step 1) or upload a file of your choice:

 # Path of the default MIDI file
 midi_path = RIFF
 # If the user chooses the 'Upload your own' option
 if midi_file == "Upload your own":
   try:
     fileList = upload()  # upload your file
     midi_path = fileList[0]  # path of the uploaded file
     # Load the uploaded file
     noteSequence, notes = load_midi(midi_path)
   except Exception as e:  # print a message if the upload fails or is cancelled
     print('Upload Cancelled')
 else:
   # Load the default file, but slow it down by 30%
   noteSequence, notes = load_midi(midi_path)
   notes['startTimes'] *= 1.3
   notes['endTimes'] *= 1.3
 # Plot the NoteSequence
 note_seq.plot_sequence(noteSequence)


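12) Choose some random instruments to generate a custom interpolation

Audio ‘interpolation’ here means gradually morphing from one instrument's timbre to another's by interpolating between their latent vectors over the course of the piece.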

 #Select number of instruments
 number_of_random_instruments = 10 #@param {type:"slider", min:4, max:16, step:1} 

The #@param annotation displays a slider that lets you choose between 4 and 16 instruments, in steps of 1.

 pitchPreview = 60  # preview every instrument at MIDI pitch 60 (middle C)
 num = number_of_random_instruments
 pitches = [pitchPreview] * num  # one pitch per instrument
 # Generate one random latent vector per instrument
 latent_vector = model.generate_z(num)
 # Generate fake samples for the latent vectors and pitches of all the instruments
 audio_notes = model.generate_samples_from_z(latent_vector, pitches)
 for i, audio_note in enumerate(audio_notes):
   # Print the instrument number
   print("Instrument: {}".format(i))
   # Insert an HTML audio widget for the instrument; pass the array of float
   # samples (audio_note) and the sample rate as parameters
   play(audio_note, sample_rate=SR)


The output shows an audio widget for each instrument's sound; you can play the audio, adjust its volume and download it using the widgets.

13) Create a list of instruments to interpolate between

 instruments = [0, 2, 4, 0]

Place each instrument at a point in time, expressed as a fraction of the clip length (from 0 to 1.0):

 times = [0, 0.3, 0.6, 1.0]

Pin the first and last points to the start and end of the synthesized audio:


 times[0] = -0.001
 times[-1] = 1.0 

14) Get the latent vectors of the selected instruments

 z_instruments = np.array([latent_vector[i] for i in instruments])

Scale the fractional times by the end time of the last note to get each selected instrument's position in seconds:

 t_instruments = np.array([notes['endTimes'][-1] * t for t in times])

Get the interpolated latent vector for each note:

 z_notes = gu.get_z_notes(notes['startTimes'], z_instruments, t_instruments)
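The idea behind per-note latent interpolation can be sketched in NumPy. The helper below, `interpolate_latents`, is hypothetical and uses plain linear interpolation; Magenta's `gu.get_z_notes` may use a different scheme (for example spherical interpolation), so treat this only as an illustration of the concept. Each note's start time is located between two instrument anchor times, and its latent vector is a weighted mix of the two anchors' latent vectors. The `np.clip` call keeps a note that starts exactly at the first anchor time inside the first segment.

```python
import numpy as np

def interpolate_latents(start_times, z_anchors, t_anchors):
    """Hypothetical helper: linearly interpolate anchor latents at each note's start time.

    start_times: (num_notes,) note onsets in seconds
    z_anchors:   (num_anchors, latent_dim) latent vectors of the chosen instruments
    t_anchors:   (num_anchors,) increasing anchor times in seconds
    """
    t_anchors = np.asarray(t_anchors, dtype=float)
    z_anchors = np.asarray(z_anchors, dtype=float)
    z_notes = []
    for t in start_times:
        # Index i of the segment [t_anchors[i], t_anchors[i + 1]] that contains t
        i = np.searchsorted(t_anchors, t, side='left') - 1
        i = int(np.clip(i, 0, len(t_anchors) - 2))  # keep edge notes in the first/last segment
        alpha = (t - t_anchors[i]) / (t_anchors[i + 1] - t_anchors[i])  # 0 at left anchor, 1 at right
        z_notes.append((1.0 - alpha) * z_anchors[i] + alpha * z_anchors[i + 1])
    return np.vstack(z_notes)

# Toy example: three 2-D "latent vectors" anchored at 0 s, 1 s and 2 s
z = interpolate_latents([0.0, 0.5, 1.5], [[0, 0], [1, 1], [0, 2]], [0.0, 1.0, 2.0])
print(z.tolist())  # [[0.0, 0.0], [0.5, 0.5], [0.5, 1.5]]
```

A note halfway between two anchors gets an equal blend of the two instruments' latent vectors, which is what makes the timbre morph gradually across the piece.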

15) Generate audio for each note

 print('Generating {} samples...'.format(len(z_notes)))
 audio_notes = model.generate_samples_from_z(z_notes, notes['pitches']) 

16) Combine the audio samples of all instruments into a single audio clip

 ac = combine_notes(audio_notes, notes['startTimes'], notes['endTimes'], notes['velocities'])

17) Play the synthesized audio

 #Create audio widget; pass the clip and specify the sample rate
 play(ac, sample_rate=SR) 

18) Plot the spectrogram using spectrogram() function defined in step (10)

 print('CQT Spectrogram:')
 spectrogram(ac)




Copyright Analytics India Magazine Pvt Ltd
