Signal-to-Noise Ratio

Imagine developing a song recognition application. The software’s goal is to recognize a song even when it’s played in a noisy environment, similar to Shazam. To achieve this, you want to enhance the audio quality by reducing the noise, and then evaluate the improvement using the Signal-to-Noise Ratio (SNR).

In this example, we will demonstrate how to generate a clean signal, add varying levels of noise to simulate the noisy recording, use FFT for noise reduction, and then evaluate the quality of the reconstructed audio using SNR.

Import necessary libraries

from typing import Tuple

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchmetrics.audio import SignalNoiseRatio

# Seed both PyTorch and NumPy so the random noise drawn below (and therefore
# the SNR values printed later) is the same on every run.
torch.manual_seed(42)
np.random.seed(42)

Generate a clean signal (simulating a high-quality recording)

def generate_clean_signal(length: int = 1000) -> Tuple[np.ndarray, np.ndarray]:
    """Generate a clean signal (sine wave).

    Args:
        length: Number of samples to generate.

    Returns:
        A ``(t, signal)`` pair: ``t`` holds ``length`` evenly spaced time
        points from 0 to 1 (both ends included) and ``signal`` is a 10 Hz sine
        wave evaluated at those points, standing in for the clean recording.
    """
    frequency_hz = 10
    t = np.linspace(0, 1, length)
    signal = np.sin(2 * np.pi * frequency_hz * t)
    return t, signal

Add Gaussian noise to the signal to simulate the noisy environment

def add_noise(signal: np.ndarray, noise_level: float = 0.5) -> np.ndarray:
    """Add Gaussian noise to the signal.

    Args:
        signal: Samples of the clean signal.
        noise_level: Factor applied to the standard-normal noise samples, i.e.
            the standard deviation of the added noise.

    Returns:
        A new array equal to ``signal`` plus the scaled noise; ``signal``
        itself is not modified.
    """
    # One noise sample per entry along the first axis, drawn from NumPy's
    # global generator so that ``np.random.seed`` makes the result repeatable.
    noise = noise_level * np.random.randn(signal.shape[0])
    return signal + noise

Apply FFT to filter out the noise

def fft_denoise(noisy_signal: np.ndarray, threshold: float) -> np.ndarray:
    """Denoise the signal using FFT.

    Args:
        noisy_signal: Real-valued samples of the noisy signal.
        threshold: Magnitude cut-off in the frequency domain. Components whose
            FFT magnitude is not strictly greater than this value are zeroed.

    Returns:
        The real part of the inverse FFT of the filtered spectrum, with the
        same length as ``noisy_signal``.
    """
    freq_domain = np.fft.fft(noisy_signal)  # Move to the frequency domain
    magnitude = np.abs(freq_domain)
    # Filter frequencies: multiplying by the boolean mask keeps strong
    # components unchanged and sets weak ones to zero.
    filtered_freq_domain = freq_domain * (magnitude > threshold)
    # Perform inverse FFT to reconstruct the signal; only the real part is returned.
    return np.fft.ifft(filtered_freq_domain).real

Generate and plot clean, noisy, and denoised signals to visualize the reconstruction

# Build the three signals: the clean reference, a noisy copy of it, and the
# FFT-filtered reconstruction of the noisy copy.
length = 1000
t, clean_signal = generate_clean_signal(length)
noisy_signal = add_noise(clean_signal, noise_level=0.5)
denoised_signal = fft_denoise(noisy_signal, threshold=10)

# Overlay all three on one figure. The clean signal is plotted last with a
# thicker line so it stays visible on top of the other two.
plt.figure(figsize=(12, 4))
plt.plot(t, noisy_signal, label="Noisy environment", color="blue", alpha=0.7)
plt.plot(t, denoised_signal, label="Denoised signal", color="green", alpha=0.7)
plt.plot(t, clean_signal, label="Clean song", color="red", linewidth=3)
plt.xlabel("Time")
plt.ylabel("Amplitude")
plt.title("Clean Song vs. Noisy Environment vs. Denoised Signal")
plt.legend()
plt.show()
Clean Song vs. Noisy Environment vs. Denoised Signal

Convert the signals to PyTorch tensors and calculate the SNR (in decibels) of the noisy and the denoised signal, each measured against the clean signal

# The metric is called on torch tensors; the NumPy arrays are converted and
# cast to float32.
clean_signal_tensor = torch.tensor(clean_signal).float()
noisy_signal_tensor = torch.tensor(noisy_signal).float()
denoised_signal_tensor = torch.tensor(denoised_signal).float()

# Score both the noisy and the denoised signal against the clean reference so
# the two values can be compared directly.
snr = SignalNoiseRatio()
initial_snr = snr(preds=noisy_signal_tensor, target=clean_signal_tensor)
reconstructed_snr = snr(preds=denoised_signal_tensor, target=clean_signal_tensor)
print(f"Initial SNR: {initial_snr:.2f}")
print(f"Reconstructed SNR: {reconstructed_snr:.2f}")
Initial SNR: 3.19
Reconstructed SNR: 3.49

To show the effect of different noise levels on the SNR, we create an animation that iterates over different noise levels and updates the plot accordingly:

 82 fig, ax = plt.subplots(figsize=(12, 4))
 83 noise_levels = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
 84
 85
 86 def update(num: int) -> tuple:
 87     """Update the plot for each frame."""
 88     t, clean_signal = generate_clean_signal(length)
 89     noisy_signal = add_noise(clean_signal, noise_level=noise_levels[num])
 90     denoised_signal = fft_denoise(noisy_signal, threshold=10)
 91
 92     clean_signal_tensor = torch.tensor(clean_signal).float()
 93     noisy_signal_tensor = torch.tensor(noisy_signal).float()
 94     denoised_signal_tensor = torch.tensor(denoised_signal).float()
 95     initial_snr = snr(preds=noisy_signal_tensor, target=clean_signal_tensor)
 96     reconstructed_snr = snr(preds=denoised_signal_tensor, target=clean_signal_tensor)
 97
 98     ax.clear()
 99     (noisy,) = plt.plot(t, noisy_signal, label="Noisy Environment", color="blue", alpha=0.7)
100     (denoised,) = plt.plot(t, denoised_signal, label="Denoised Signal", color="green", alpha=0.7)
101     (clean,) = plt.plot(t, clean_signal, label="Clean Song", color="red", linewidth=3)
102     ax.set_xlabel("Time")
103     ax.set_ylabel("Amplitude")
104     ax.set_title(
105         f"Initial SNR: {initial_snr:.2f} - Reconstructed SNR: {reconstructed_snr:.2f} - Noise level: {noise_levels[num]}"
106     )
107     ax.legend(loc="upper right")
108     ax.set_ylim(-3, 3)
109     return noisy, denoised, clean
110
111
# One frame per noise level, shown 1000 ms apart. The animation object is
# bound to a name so that it stays referenced after this statement.
ani = animation.FuncAnimation(fig, update, frames=len(noise_levels), interval=1000)

Total running time of the script: (0 minutes 2.961 seconds)

Gallery generated by Sphinx-Gallery