Source code for diffFx_pytorch.processors.spatial.imager

import torch 
import torch.nn as nn
import numpy as np 
from typing import Dict, Union, Tuple
from ..base_utils import check_params 
from ..base import ProcessorsBase, EffectParam
from ..core.midside import * 
from ..filters import LinkwitzRileyFilter


class StereoImager(ProcessorsBase):
    """Differentiable implementation of a multi-band stereo imaging processor.

    This processor implements frequency-dependent stereo width control using
    mid-side (M/S) processing combined with Linkwitz-Riley crossover filters.
    It allows independent width control over multiple frequency bands,
    enabling stereo field manipulation across the frequency spectrum. The
    processor splits the signal into frequency bands using a cascade of
    Linkwitz-Riley crossover filters, processes each band's stereo width
    independently, then recombines the bands.

    Processing Chain:
        1. Convert L/R to M/S representation
        2. Split M/S signals into frequency bands using crossovers
        3. Apply independent width control to each band
        4. Sum processed bands
        5. Convert back to L/R representation

    The width control for each band follows:

    .. math::

        M_{out} = M_{in} * 2(1 - width)

        S_{out} = S_{in} * 2(width)

    where:
        - M is the mid (mono) signal for the band
        - S is the side (difference) signal for the band
        - width is the stereo width control parameter for that band

    Args:
        sample_rate (int): Audio sample rate in Hz.
        param_range (Dict[str, EffectParam], optional): Custom parameter
            ranges, forwarded to ``ProcessorsBase``. Defaults to None.
        num_bands (int): Number of frequency bands (must be >= 1).
            Defaults to 3.

    Raises:
        ValueError: If ``num_bands`` is smaller than 1, or so large (>= 12)
            that a crossover's minimum frequency would exceed the 20 kHz
            ceiling of its range.

    Attributes:
        crossovers (nn.ModuleList): ``num_bands - 1`` Linkwitz-Riley
            crossover filters, one between each pair of adjacent bands.
        num_bands (int): Number of frequency bands.

    Parameters Details:
        For each band i:
            band{i}_width: Stereo width control for band i
                - 0.0: Mono (only the mid signal, at gain 2)
                - 0.5: Original stereo (unity gain on mid and side)
                - 1.0: Maximum width (only the side signal, at gain 2)

        For each crossover i (0 <= i < num_bands - 1):
            crossover{i}_freq: Crossover frequency between bands i and i+1
                - Range is [20 * 2**i, min(20000, 2000 * 2**i)] Hz
                - Min frequency doubles for each successive crossover
                - Max frequency is 100x the min, limited to 20 kHz
                - Adjacent ranges overlap; crossover{i} < crossover{i+1}
                  is not enforced

    Note:
        - Input must be stereo (two channels).
        - M/S conversion uses a 1/sqrt(2) scale in both directions.
        - Bands are split in series: each crossover splits the high output
          of the previous one. Lower bands receive no all-pass compensation.
          NOTE(review): with 3+ bands the recombined sum is therefore only
          approximately flat around the crossover points; confirm whether
          this is acceptable.
        - Total number of parameters = 2 * num_bands - 1.
        - Width controls affect the ratio of mid to side signal per band.

    Warning:
        When using with neural networks:
            - norm_params must be in range [0, 1]
            - Parameters will be automatically mapped to their ranges
            - Ensure your network output is properly normalized
              (e.g., using sigmoid)
            - Parameter order must match _register_default_parameters()

    Examples:
        Basic DSP Usage:
            >>> # Create a 3-band stereo imager
            >>> imager = StereoImager(
            ...     sample_rate=44100,
            ...     num_bands=3
            ... )
            >>> # Process with different width for each band
            >>> output = imager(input_audio, dsp_params={
            ...     'band0_width': 0.3,          # Reduce width in low frequencies
            ...     'band1_width': 0.5,          # Keep mids unchanged
            ...     'band2_width': 0.8,          # Enhance width in highs
            ...     'crossover0_freq': 200.0,    # Low/mid crossover
            ...     'crossover1_freq': 2000.0    # Mid/high crossover
            ... })

        Neural Network Control:
            >>> # 1. Simple parameter prediction
            >>> class ImagerController(nn.Module):
            ...     def __init__(self, input_size, num_params):
            ...         super().__init__()
            ...         self.net = nn.Sequential(
            ...             nn.Linear(input_size, 32),
            ...             nn.ReLU(),
            ...             nn.Linear(32, num_params),
            ...             nn.Sigmoid()  # Ensures output is in [0,1] range
            ...         )
            ...
            ...     def forward(self, x):
            ...         return self.net(x)
            >>>
            >>> # Initialize controller
            >>> imager = StereoImager(sample_rate=44100, num_bands=3)
            >>> num_params = imager.count_num_parameters()  # 5 parameters for 3 bands
            >>> controller = ImagerController(input_size=16, num_params=num_params)
            >>>
            >>> # Process with features
            >>> features = torch.randn(batch_size, 16)  # Audio features
            >>> norm_params = controller(features)
            >>> output = imager(input_audio, norm_params=norm_params)
    """

    def __init__(self, sample_rate, param_range: Union[Dict[str, EffectParam], None] = None,
                 num_bands=3):
        """Build the imager and its ``num_bands - 1`` crossover filters.

        Raises:
            ValueError: If ``num_bands`` is smaller than 1.
        """
        # With zero bands, process() would silently output an all-zero signal.
        if num_bands < 1:
            raise ValueError(f"num_bands must be >= 1, got {num_bands}")
        # Set before super().__init__(): _register_default_parameters reads
        # self.num_bands. Presumably the base constructor invokes it; verify
        # against ProcessorsBase.
        self.num_bands = num_bands
        super().__init__(sample_rate, param_range)
        # One crossover between each pair of adjacent bands.
        self.crossovers = nn.ModuleList([
            LinkwitzRileyFilter(sample_rate) for _ in range(num_bands - 1)
        ])

    def _register_default_parameters(self):
        """Register parameters for band widths and crossover frequencies.

        Sets up:
            - ``band{i}_width`` for each band, range 0.0 to 1.0
            - ``crossover{i}_freq`` between bands i and i+1, range
              [20 * 2**i, min(20000, 2000 * 2**i)] Hz

        Raises:
            ValueError: If a crossover's minimum frequency is not below its
                maximum. This happens for num_bands >= 12.
        """
        self.params = {}

        # Width controls for each band (0 = only mid, 1 = only side).
        for i in range(self.num_bands):
            self.params[f'band{i}_width'] = EffectParam(
                min_val=0.0,
                max_val=1.0
            )

        # Crossover frequencies between bands: the lower bound doubles per
        # crossover and the upper bound is 100x the lower bound, capped at 20 kHz.
        for i in range(self.num_bands - 1):
            min_freq = 20.0 * (2 ** i)
            max_freq = min(20000.0, min_freq * 100)
            # From i = 10 onward min_freq (20480 Hz) exceeds the 20 kHz cap,
            # which would produce an empty range.
            if min_freq >= max_freq:
                raise ValueError(
                    f"num_bands={self.num_bands} is too large: crossover{i} would "
                    f"need a minimum frequency of {min_freq} Hz, which is not below "
                    f"its {max_freq} Hz maximum"
                )
            self.params[f'crossover{i}_freq'] = EffectParam(
                min_val=min_freq,
                max_val=max_freq
            )

    def _apply_width(self, mid: torch.Tensor, side: torch.Tensor,
                     width: Union[torch.Tensor, float]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply stereo width processing to mid/side signals for a single band.

        Args:
            mid (torch.Tensor): Mid signal for the band. Shape: (batch, 1, samples)
            side (torch.Tensor): Side signal for the band. Shape: (batch, 1, samples)
            width (torch.Tensor or float): Width control parameter. Either a
                scalar or a tensor of shape (batch,).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Processed (mid, side) signals.

        Note:
            The gains are linear: mid is scaled by 2 * (1 - width) and side by
            2 * width. At width = 0.5 both gains are 1 (signal unchanged). The
            gains always sum to 2, so total energy is NOT preserved across
            width settings.
        """
        # as_tensor accepts plain floats as well as tensors. Matching the audio's
        # dtype and device avoids type promotion (e.g. float64 params with
        # float32 audio) and cross-device errors. reshape also works on
        # non-contiguous inputs, unlike view.
        width = torch.as_tensor(width, dtype=mid.dtype, device=mid.device).reshape(-1, 1, 1)
        mid_gain = 2 * (1 - width)
        side_gain = 2 * width
        return mid * mid_gain, side * side_gain

    def _split_bands(self, signal: torch.Tensor, params: Dict[str, torch.Tensor]) -> list:
        """Split a single-channel signal into ``num_bands`` bands using cascaded crossovers.

        Band 0 is the low output of crossover 0. Each subsequent crossover
        splits the high output of the previous one. The final band is the high
        output of the last crossover, or the whole signal when num_bands == 1.

        Args:
            signal (torch.Tensor): Mid or side signal. Shape: (batch, 1, samples)
            params (Dict[str, torch.Tensor]): Denormalized parameters
                containing ``crossover{i}_freq`` for every crossover.

        Returns:
            list[torch.Tensor]: ``num_bands`` band signals, ordered low to high.
        """
        bands = []
        remainder = signal
        for i, crossover in enumerate(self.crossovers):
            # The crossover returns its low and high outputs stacked on the
            # channel axis.
            low_high = crossover.process(remainder, norm_params=None, dsp_params={
                'frequency': params[f'crossover{i}_freq']
            })
            low, remainder = torch.split(low_high, (1, 1), dim=-2)
            bands.append(low)
        bands.append(remainder)
        return bands

    def process(self, x: torch.Tensor, norm_params: Union[Dict[str, torch.Tensor], None] = None,
                dsp_params: Union[Dict[str, torch.Tensor], None] = None) -> torch.Tensor:
        """Process input signal through the multi-band stereo imager.

        Args:
            x (torch.Tensor): Input audio tensor. Shape: (batch, 2, samples)
            norm_params (Dict[str, torch.Tensor]): Normalized parameters (0 to 1).
                Must contain the following keys:

                - 'band{i}_width': Width control for band i (0 to 1)
                - 'crossover{i}_freq': Frequency between band i and band i+1 (0 to 1)

                Each value should be a tensor of shape (batch_size,).
            dsp_params (Dict[str, Union[float, torch.Tensor]], optional): Direct
                DSP parameters. Can specify imager parameters as:

                - float/int: Single value applied to entire batch
                - 0D tensor: Single value applied to entire batch
                - 1D tensor: Batch of values matching input batch size

                Width values are converted to the input's dtype and device.
                If provided, norm_params must be None.

        Returns:
            torch.Tensor: Processed stereo audio tensor. Shape: (batch, 2, samples)

        Raises:
            AssertionError: If input is not stereo (two channels).
        """
        check_params(norm_params, dsp_params)
        params = self.map_parameters(norm_params) if norm_params is not None else dsp_params

        _, chs, _ = x.size()
        # Raised explicitly (not via `assert`) so the check survives `python -O`
        # while keeping the documented exception type.
        if chs != 2:
            raise AssertionError("Input tensor must have shape (batch_size, 2, seq_len)")

        # L/R -> M/S, then separate the two channels.
        x_ms = lr_to_ms(x, mult=1 / np.sqrt(2))
        mid, side = torch.split(x_ms, (1, 1), dim=-2)

        # Mid and side are split with the same crossover frequencies.
        mid_bands = self._split_bands(mid, params)
        side_bands = self._split_bands(side, params)

        processed = [
            self._apply_width(band_mid, band_side, params[f'band{i}_width'])
            for i, (band_mid, band_side) in enumerate(zip(mid_bands, side_bands))
        ]
        # Out-of-place sum of the processed bands (num_bands >= 1 guarantees
        # that each sum has at least one term).
        processed_mid = sum(band_mid for band_mid, _ in processed)
        processed_side = sum(band_side for _, band_side in processed)

        # Recombine M/S and convert back to L/R.
        x_ms_new = torch.cat([processed_mid, processed_side], dim=-2)
        return ms_to_lr(x_ms_new, mult=1 / np.sqrt(2))