import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Tuple, Union
from enum import Enum
from ..base_utils import check_params
from ..base import ProcessorsBase, EffectParam
from ..core.midside import *
from ..core.phase import unwrap_phase
# Haas Effect
[docs]class StereoEnhancer(ProcessorsBase):
"""Differentiable implementation of stereo enhancement using the Haas effect.
This processor implements stereo enhancement using the Haas effect (precedence effect),
which creates an enhanced sense of stereo width by introducing small time delays between
channels. The implementation combines mid-side processing with frequency-domain delay
to achieve precise control over the stereo image.
The Haas effect exploits the human auditory system's precedence effect, where delays
between 1-30ms affect spatial perception without creating distinct echoes. The processor
applies the delay in the frequency domain for artifact-free time shifting.
Processing Chain:
1. Convert L/R to M/S representation
2. Apply frequency-domain delay to side signal
3. Apply width scaling to delayed side signal
4. Convert back to L/R representation
The frequency domain delay is implemented as:
.. math::
S_{delayed}(f) = S(f) * e^{-j2\\pi f \\tau}
where:
- S(f) is the side signal in frequency domain
- f is frequency
- τ is the delay time in seconds
- Phase is unwrapped to ensure continuous delay
Args:
sample_rate (int): Audio sample rate in Hz
Parameters Details:
delay_ms: Delay time for the Haas effect
- Range: 0 to 30 milliseconds
- Values around 10-15ms typically most effective
- Controls the perceived spatial width
- Based on psychoacoustic precedence effect
width: Overall stereo width control
- Range: 0.0 to 1.0
- 0.5: No enhancement (original signal)
- 1.0: Maximum enhancement
- Scales the processed side signal
Note:
- Input must be stereo (two channels)
- Uses frequency domain processing for precise delays
- Phase unwrapping ensures continuous delay response
- Delay range chosen based on psychoacoustic research
- Maintains mono compatibility
- Most effective on transient-rich material
Warning:
When using with neural networks:
- norm_params must be in range [0, 1]
- Parameters will be automatically mapped to their ranges
- Ensure your network output is properly normalized (e.g., using sigmoid)
- Parameter order must match _register_default_parameters()
High delay values (>20ms) may cause noticeable separation of channels,
particularly on transient material.
Examples:
Basic DSP Usage:
>>> # Create a stereo enhancer
>>> enhancer = StereoEnhancer(sample_rate=44100)
>>> # Process with moderate Haas effect
>>> output = enhancer(input_audio, dsp_params={
... 'delay_ms': 12.0, # 12ms delay for natural width
... 'width': 0.7 # 70% enhancement amount
... })
Neural Network Control:
>>> # 1. Simple parameter prediction
>>> class EnhancerController(nn.Module):
... def __init__(self, input_size):
... super().__init__()
... self.net = nn.Sequential(
... nn.Linear(input_size, 32),
... nn.ReLU(),
... nn.Linear(32, 2), # 2 parameters: delay and width
... nn.Sigmoid() # Ensures output is in [0,1] range
... )
...
... def forward(self, x):
... return self.net(x)
>>>
>>> # Initialize controller
>>> enhancer = StereoEnhancer(sample_rate=44100)
>>> controller = EnhancerController(input_size=16)
>>>
>>> # Process with features
>>> features = torch.randn(batch_size, 16) # Audio features
>>> norm_params = controller(features)
>>> output = enhancer(input_audio, norm_params=norm_params)
"""
[docs] def _register_default_parameters(self):
"""Register delay and width parameters.
Sets up:
- delay_ms: Haas effect delay time (0 to 30 ms)
- width: Enhancement amount (0.0 to 1.0)
"""
self.params = {
'delay_ms': EffectParam(min_val=0.0, max_val=30.0),
'width': EffectParam(min_val=0.0, max_val=1.0)
}
[docs] def process(self, x: torch.Tensor, norm_params: Union[Dict[str, torch.Tensor], None]=None, dsp_params: Union[Dict[str, torch.Tensor], None] = None):
"""Process input signal through the stereo enhancer.
Args:
x (torch.Tensor): Input audio tensor. Shape: (batch, 2, samples)
norm_params (Dict[str, torch.Tensor]): Normalized parameters (0 to 1)
Must contain the following keys:
- 'delay_ms': Delay time for side signal (0 to 1)
- 'width': Stereo width enhancement (0 to 1)
Each value should be a tensor of shape (batch_size,)
dsp_params (Dict[str, Union[float, torch.Tensor]], optional): Direct DSP parameters.
Can specify enhancer parameters as:
- float/int: Single value applied to entire batch
- 0D tensor: Single value applied to entire batch
- 1D tensor: Batch of values matching input batch size
Parameters will be automatically expanded to match batch size
and moved to input device if necessary.
If provided, norm_params must be None.
Returns:
torch.Tensor: Processed stereo audio tensor. Shape: (batch, 2, samples)
Raises:
AssertionError: If input is not stereo (two channels)
"""
check_params(norm_params, dsp_params)
if norm_params is not None:
params = self.map_parameters(norm_params)
else:
params = dsp_params
bs, chs, seq_len = x.size()
assert chs == 2, "Input tensor must have shape (bs, 2, seq_len)"
# Convert to M/S
x_ms = lr_to_ms(x, mult=1/np.sqrt(2))
mid, side = torch.split(x_ms, (1, 1), -2)
# Apply delay to side channel
# Calculate FFT size (next power of 2 for efficiency)
max_delay_samples = max(
1,
int(torch.max(params['delay_ms']) * self.sample_rate / 1000)
)
fft_size = 2 ** int(np.ceil(np.log2(side.shape[-1] + max_delay_samples)))
# Pad input signal to FFT size
pad_right = fft_size - (x.shape[-1] + max_delay_samples)
side_padded = torch.nn.functional.pad(side, (max_delay_samples, pad_right))
Side = torch.fft.rfft(side_padded, n=fft_size)
freqs = torch.fft.rfftfreq(side_padded.shape[-1], 1/self.sample_rate).to(x.device)
phase = -2 * np.pi * freqs * params['delay_ms'].view(-1, 1, 1) / 1000
phase = unwrap_phase(phase, dim=-1)
Side = Side * torch.exp(1j * phase).to(Side.dtype)
# Convert back to time domain with width control
side_delayed = torch.fft.irfft(Side, n=fft_size)
side_delayed = side_delayed[..., max_delay_samples:max_delay_samples + side.shape[-1]]
width = params['width'].view(-1, 1, 1)
x_ms_new = torch.cat([mid, side_delayed * width], -2)
x_lr = ms_to_lr(x_ms_new, mult=1/np.sqrt(2))
return x_lr