{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Signal-to-Noise Ratio\n\nImagine developing a song recognition application that, like Shazam, recognizes a song even when it is played in a noisy environment. To achieve this, you want to enhance the audio quality by reducing the noise, and then measure the improvement using the Signal-to-Noise Ratio (SNR).\n\nIn this example, we generate a clean signal, add varying levels of noise to simulate a noisy recording, use the Fast Fourier Transform (FFT) for noise reduction, and then evaluate the quality of the reconstructed audio using SNR.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Import the necessary libraries\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib.animation as animation\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch\nfrom torchmetrics.audio import SignalNoiseRatio" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generate a clean signal (simulating a high-quality recording)\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def generate_clean_signal(length: int = 1000) -> tuple[np.ndarray, np.ndarray]:\n \"\"\"Generate a clean signal: a 10 Hz sine wave sampled over one second.\"\"\"\n t = np.linspace(0, 1, length) # `length` evenly spaced time samples from 0 s to 1 s\n signal = np.sin(2 * np.pi * 10 * t) # 10 Hz sine wave, representing the clean recording\n return t, signal" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Add Gaussian noise to the signal to simulate the noisy environment\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def add_noise(signal: np.ndarray, noise_level: float = 0.5) -> np.ndarray:\n \"\"\"Add zero-mean Gaussian noise to the signal.\"\"\"\n noise = noise_level * np.random.randn(signal.shape[0]) # noise_level is the noise standard deviation\n return signal + noise" ] }, { "cell_type": 
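"markdown", "metadata": {}, "source": [ "Before denoising, it helps to recall what SNR measures. A common definition, and, as far as we know, the one implemented by torchmetrics' `SignalNoiseRatio` with its default `zero_mean=False` (up to a small epsilon added for numerical stability), treats the difference between the estimate $\\hat{s}$ and the clean target $s$ as noise and reports the power ratio in decibels:\n\n$$\n\\mathrm{SNR}_{\\mathrm{dB}} = 10 \\log_{10} \\frac{\\sum_{n} s[n]^2}{\\sum_{n} \\left(s[n] - \\hat{s}[n]\\right)^2}\n$$\n\nHigher values mean the estimate is closer to the clean signal. For example, if the residual noise power is one hundredth of the signal power, the SNR is $10 \\log_{10} 100 = 20$ dB.\n" ] }, { "cell_type": 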
"markdown", "metadata": {}, "source": [ "Apply the FFT and filter out the noise by zeroing every frequency component whose magnitude falls below a threshold. Because `np.fft.fft` is unnormalized by default, the threshold is an absolute magnitude, so a suitable value depends on the signal length and the noise level.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def fft_denoise(noisy_signal: np.ndarray, threshold: float) -> np.ndarray:\n \"\"\"Denoise the signal by thresholding its FFT magnitude spectrum.\"\"\"\n freq_domain = np.fft.fft(noisy_signal) # Transform the signal to the frequency domain\n magnitude = np.abs(freq_domain)\n filtered_freq_domain = freq_domain * (magnitude > threshold) # Keep only components above the threshold\n return np.fft.ifft(filtered_freq_domain).real # Inverse FFT back to the time domain; the imaginary part is only rounding error for a real input" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Generate and plot the clean, noisy, and denoised signals to visualize the reconstruction\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "length = 1000\nt, clean_signal = generate_clean_signal(length)\nnoisy_signal = add_noise(clean_signal, noise_level=0.5)\ndenoised_signal = fft_denoise(noisy_signal, threshold=10)\n\nplt.figure(figsize=(12, 4))\nplt.plot(t, noisy_signal, label=\"Noisy environment\", color=\"blue\", alpha=0.7)\nplt.plot(t, denoised_signal, label=\"Denoised signal\", color=\"green\", alpha=0.7)\nplt.plot(t, clean_signal, label=\"Clean song\", color=\"red\", linewidth=3)\nplt.xlabel(\"Time (s)\")\nplt.ylabel(\"Amplitude\")\nplt.title(\"Clean Song vs. Noisy Environment vs. Denoised Signal\")\nplt.legend()\nplt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Convert the signals to PyTorch tensors and calculate the SNR (in decibels) of the noisy and the denoised signals, each measured against the clean signal\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "clean_signal_tensor = torch.tensor(clean_signal).float()\nnoisy_signal_tensor = torch.tensor(noisy_signal).float()\ndenoised_signal_tensor = torch.tensor(denoised_signal).float()\n\nsnr = SignalNoiseRatio()\ninitial_snr = snr(preds=noisy_signal_tensor, target=clean_signal_tensor)\nreconstructed_snr = snr(preds=denoised_signal_tensor, target=clean_signal_tensor)\nprint(f\"Initial SNR: {initial_snr:.2f} dB\")\nprint(f\"Reconstructed SNR: {reconstructed_snr:.2f} dB\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "To show how the noise level affects the SNR, we create an animation that steps through increasing noise levels and updates the plot for each one:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fig, ax = plt.subplots(figsize=(12, 4))\nnoise_levels = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n\n\ndef update(num: int) -> tuple:\n \"\"\"Update the plot for each frame.\"\"\"\n t, clean_signal = generate_clean_signal(length)\n noisy_signal = add_noise(clean_signal, noise_level=noise_levels[num])\n denoised_signal = fft_denoise(noisy_signal, threshold=10)\n\n clean_signal_tensor = torch.tensor(clean_signal).float()\n noisy_signal_tensor = torch.tensor(noisy_signal).float()\n denoised_signal_tensor = torch.tensor(denoised_signal).float()\n initial_snr = snr(preds=noisy_signal_tensor, target=clean_signal_tensor)\n reconstructed_snr = snr(preds=denoised_signal_tensor, target=clean_signal_tensor)\n\n ax.clear()\n # Draw on the animation's own axes instead of relying on the current pyplot axes\n (noisy,) = ax.plot(t, noisy_signal, label=\"Noisy Environment\", color=\"blue\", alpha=0.7)\n (denoised,) = ax.plot(t, denoised_signal, label=\"Denoised Signal\", color=\"green\", alpha=0.7)\n (clean,) = ax.plot(t, clean_signal, label=\"Clean Song\", color=\"red\", linewidth=3)\n ax.set_xlabel(\"Time (s)\")\n ax.set_ylabel(\"Amplitude\")\n ax.set_title(\n f\"Initial SNR: {initial_snr:.2f} dB - Reconstructed SNR: {reconstructed_snr:.2f} dB - Noise level: {noise_levels[num]}\"\n )\n ax.legend(loc=\"upper right\")\n ax.set_ylim(-3, 3)\n return noisy, denoised, clean\n\n\nani = animation.FuncAnimation(fig, update, frames=len(noise_levels), interval=1000)" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.20" } }, "nbformat": 4, "nbformat_minor": 0 }