Source code for diffFx_pytorch.processors.eq.graphicEQ

import torch 
import torch.nn as nn
import numpy as np 
from enum import Enum
from typing import Dict, Union
from ..base_utils import check_params
from ..base import ProcessorsBase, EffectParam
from ..filters import BiquadFilter


class GraphicEQType(Enum):
    """Frequency-spacing schemes selectable for the graphic equalizer.

    The member values are the strings accepted as an ``eq_type`` argument.
    """

    # ISO standard frequencies
    ISO = 'iso'

    # Octave spacing
    OCTAVE = 'octave'

    # 1/3 octave spacing
    THIRD_OCTAVE = 'third_octave'
    
    
class GraphicEqualizer(ProcessorsBase):
    """Differentiable implementation of a multi-band graphic equalizer.

    Implementation is based on the following book:

    .. [1] Reiss, Joshua D., and Andrew McPherson.
        Audio effects: theory, implementation and application. CRC Press, 2014.

    This processor implements a parallel bank of peak filters to create a
    graphic equalizer, allowing independent gain control over multiple
    frequency bands. The implementation supports different frequency spacing
    schemes including ISO standard frequencies, octave spacing, and
    third-octave spacing.

    The equalizer uses second-order IIR peak filters for each band with
    transfer function:

    .. math::

        H(z) = \\frac{b_0 + b_1z^{-1} + b_2z^{-2}}{1 + a_1z^{-1} + a_2z^{-2}}

    where coefficients are computed based on:

    - Center frequency of each band
    - Q factor (bandwidth)
    - Gain setting for each band

    Args:
        sample_rate (int): Audio sample rate in Hz. Defaults to 44100.
        param_range: Optional parameter-range override forwarded unchanged to
            the base class constructor. Defaults to None.
        num_bands (int): Number of frequency bands (must be >= 1, and at most
            10 for ``eq_type='iso'``). Defaults to 10.
        q_factors (float): Q factor for band filters. Controls bandwidth.
            Defaults to None, in which case a constant-Q value
            ``sqrt(R) / (R - 1)`` is derived from the band ratio ``R``
            (2 for 'iso' and 'octave', ``2**(1/3)`` for 'third_octave').
        eq_type (str): Frequency spacing scheme. Must be one of:

            - 'iso': ISO standard frequencies (31.5, 63, 125, 250, 500,
              1000, 2000, 4000, 8000, 16000 Hz)
            - 'octave': Octave-spaced bands
            - 'third_octave': Third-octave spaced bands

            A ``GraphicEQType`` member is accepted as well.
            Defaults to 'octave'.

    Raises:
        ValueError: If ``eq_type`` is not a valid ``GraphicEQType`` value, if
            ``num_bands`` is smaller than 1, or if ``num_bands`` exceeds the
            number of center frequencies available for the chosen type.

    Parameters Details:
        band_X_gain_db: Gain for band X (where X is 1 to num_bands)

            - Range: -12.0 to 12.0 dB
            - Controls gain at that frequency band
            - Positive values boost, negative values cut

    Note:
        The processor supports three types of frequency spacing:

        - ISO: Standard audio frequencies
        - Octave: ``num_bands`` geometrically spaced frequencies between
          20 Hz and 20 kHz
        - Third-octave: a grid of ``3 * num_bands`` geometrically spaced
          frequencies between 20 Hz and 20 kHz

        Each band uses a constant-Q design where the relative bandwidth
        remains consistent across frequencies.

    Warning:
        When using with neural networks:

        - norm_params must be in range [0, 1]
        - Parameters will be automatically mapped to their dB ranges
        - Ensure your network output is properly normalized
          (e.g., using sigmoid)
        - Parameter order must match _register_default_parameters()

    Examples:
        Basic DSP Usage:
            >>> # Create a 10-band graphic EQ with ISO frequencies
            >>> eq = GraphicEqualizer(
            ...     sample_rate=44100,
            ...     num_bands=10,
            ...     q_factors=2.0,
            ...     eq_type='iso'
            ... )
            >>> # Process audio with dsp parameters
            >>> params = {f'band_{i+1}_gain_db': 6.0 for i in range(10)}  # Boost all bands by 6dB
            >>> output = eq(input_audio, dsp_params=params)

        Neural Network Control:
            >>> # 1. Simple parameter prediction
            >>> class GraphicEQController(nn.Module):
            ...     def __init__(self, input_size, num_bands):
            ...         super().__init__()
            ...         self.net = nn.Sequential(
            ...             nn.Linear(input_size, 32),
            ...             nn.ReLU(),
            ...             nn.Linear(32, num_bands),
            ...             nn.Sigmoid()  # Ensures output is in [0,1] range
            ...         )
            ...
            ...     def forward(self, x):
            ...         return self.net(x)
            >>>
            >>> # Initialize controller
            >>> eq = GraphicEqualizer(num_bands=10)
            >>> controller = GraphicEQController(input_size=16, num_bands=10)
            >>>
            >>> # Process with features
            >>> features = torch.randn(batch_size, 16)  # Audio features
            >>> norm_params = controller(features)
            >>> output = eq(input_audio, norm_params=norm_params)
    """

    # Center frequencies (Hz) used when eq_type is 'iso'.
    _ISO_FREQUENCIES = (31.5, 63, 125, 250, 500, 1000, 2000, 4000, 8000, 16000)

    def __init__(self, sample_rate=44100, param_range=None, num_bands=10,
                 q_factors=None, eq_type='octave'):
        """Build the filter bank for the requested spacing scheme.

        See the class docstring for the meaning of each argument.
        """
        if num_bands < 1:
            raise ValueError(f"num_bands must be >= 1, got {num_bands}")

        # num_bands is assigned before the base constructor runs because
        # _register_default_parameters() reads it.
        # NOTE(review): presumably the base constructor is what invokes
        # _register_default_parameters() - confirm against ProcessorsBase.
        self.num_bands = num_bands
        super().__init__(sample_rate, param_range)

        self.eq_type = GraphicEQType(eq_type)

        # Ratio between adjacent band centers. Branching on the parsed enum
        # (rather than on the raw argument) guarantees R exists for every
        # supported type; the ISO center frequencies are octave-spaced.
        if self.eq_type is GraphicEQType.THIRD_OCTAVE:
            self.R = 2 ** (1 / 3)
        else:
            self.R = 2

        # Constant Q design: derive Q from the band ratio unless the caller
        # supplied an explicit value.
        if q_factors is None:
            self.band_q = np.sqrt(self.R) / (self.R - 1)
        else:
            self.band_q = q_factors

        self.fixed_frequencies = self._get_frequencies()
        available = len(self.fixed_frequencies)
        if num_bands > available:
            # Fail at construction instead of with an IndexError in process().
            raise ValueError(
                f"eq_type '{self.eq_type.value}' provides only {available} "
                f"center frequencies, but num_bands={num_bands} was requested"
            )

        # One peak filter per band.
        self.band_filters = nn.ModuleList([
            BiquadFilter(
                sample_rate=self.sample_rate,
                filter_type='PK',
            )
            for _ in range(num_bands)
        ])

    def _get_frequencies(self) -> list:
        """Get frequency bands based on equalizer type.

        - ISO: the fixed list of ten ISO center frequencies
        - Octave: ``num_bands`` values geometrically spaced over 20 Hz-20 kHz
        - Third-octave: ``num_bands * 3`` values geometrically spaced over
          20 Hz-20 kHz

        Returns:
            list: Center frequencies in Hz.

        Raises:
            ValueError: If eq_type is not recognized.
        """
        if self.eq_type is GraphicEQType.ISO:
            return list(self._ISO_FREQUENCIES)
        if self.eq_type is GraphicEQType.OCTAVE:
            return np.geomspace(20, 20000, self.num_bands).tolist()
        if self.eq_type is GraphicEQType.THIRD_OCTAVE:
            # NOTE(review): this yields 3 * num_bands frequencies, but only
            # the first num_bands entries are consumed by process(), so the
            # active bands cover just the lower part of the range. Kept as-is
            # until the intended third-octave layout is confirmed.
            return np.geomspace(20, 20000, self.num_bands * 3).tolist()
        raise ValueError(f"Unknown EQ type: {self.eq_type}")

    def _register_default_parameters(self):
        """Register gain parameters for each frequency band.

        Creates a gain parameter for each band with range -12 dB to +12 dB.
        Parameter names are formatted as 'band_X_gain_db' where X is the band
        number starting from 1.
        """
        self.params = {
            f'band_{i+1}_gain_db': EffectParam(min_val=-12.0, max_val=12.0)
            for i in range(self.num_bands)
        }

    def _prepare_band_parameters(self,
                                 band_idx: int,
                                 params: Dict[str, torch.Tensor],
                                 device: torch.device
                                 ) -> Dict[str, torch.Tensor]:
        """Prepare filter parameters for a single frequency band.

        Args:
            band_idx (int): Zero-based index of the band.
            params (Dict[str, torch.Tensor]): All EQ parameters, keyed
                'band_X_gain_db' with X = band_idx + 1.
            device (torch.device): Device to place the created tensors on.

        Returns:
            Dict[str, torch.Tensor]: Parameters for the band's peak filter:

            - gain_db: Gain in dB (passed through untouched)
            - frequency: Center frequency in Hz
            - q_factor: Q factor for bandwidth

        Note:
            The batch size is read from ``shape[0]`` of the band's gain, so
            the gain must be a tensor with a leading batch dimension. The
            frequency and Q scalars are expanded to that batch size.
        """
        gain_key = f'band_{band_idx+1}_gain_db'
        gain_db = params[gain_key]
        batch_size = gain_db.shape[0]

        freq = torch.tensor(self.fixed_frequencies[band_idx], device=device)
        q = torch.tensor(self.band_q, device=device)

        return {
            'gain_db': gain_db,
            'frequency': freq.expand(batch_size).float(),
            'q_factor': q.expand(batch_size).float(),
        }

    def process(self,
                x: torch.Tensor,
                norm_params: Union[Dict[str, torch.Tensor], None] = None,
                dsp_params: Union[Dict[str, torch.Tensor], None] = None):
        """Process input signal through the graphic equalizer.

        Args:
            x (torch.Tensor): Input audio tensor.
                Shape: (batch, channels, samples)
            norm_params (Dict[str, torch.Tensor]): Normalized parameters
                (0 to 1). Dictionary with keys 'band_X_gain_db' for X in
                range(1, num_bands+1). Each value should be a tensor of shape
                (batch_size,). Values are mapped through ``map_parameters``
                to the registered -12.0 to 12.0 dB range.
            dsp_params (Dict[str, torch.Tensor], optional): Direct DSP
                parameters with the same keys, used as given. If provided,
                norm_params must be None.

        Returns:
            torch.Tensor: Average of the per-band filter outputs.

        Note:
            This method does not expand or move parameters itself; each band
            gain is indexed with ``.shape[0]``.
            NOTE(review): any float / 0-D tensor handling must therefore
            happen before this method is reached - confirm in ProcessorsBase.
        """
        check_params(norm_params, dsp_params)

        # Map normalized values to dB, or take the DSP values directly.
        if norm_params is not None:
            params = self.map_parameters(norm_params)
        else:
            params = dsp_params

        # Run every band filter on the same input (parallel topology).
        outputs = [
            self.band_filters[i](
                x, None,
                dsp_params=self._prepare_band_parameters(i, params, x.device),
            )
            for i in range(self.num_bands)
        ]

        # Sum all band outputs and normalize by the number of bands.
        return torch.stack(outputs).sum(dim=0) / self.num_bands

    @property
    def frequencies(self) -> list:
        """Get the list of center frequencies.

        Returns:
            list: Center frequencies in Hz as computed at initialization.

        Note:
            Frequencies depend on the equalizer type ('iso', 'octave', or
            'third_octave') and remain fixed after initialization.
        """
        return self.fixed_frequencies