# Source code for diffFx_pytorch.processors.delay.multitaps

import torch 
import torch.nn as nn
import torch.nn.functional as F
import numpy as np 
from typing import Dict, List, Tuple, Union
from ..base import ProcessorsBase, EffectParam
from ..base_utils import check_params
from ..core.phase import unwrap_phase

# MultiTapsDelay
class MultiTapsDelay(ProcessorsBase):
    """Differentiable implementation of a multi-tap delay effect.

    This processor implements a parallel delay structure with multiple taps,
    where each tap represents an independent echo with its own delay time and
    gain. The implementation uses frequency-domain processing for precise
    timing control and efficient computation.

    Implementation is based on:

    .. [1] Reiss, Joshua D., and Andrew McPherson.
        Audio effects: theory, implementation and application.
        CRC Press, 2014.
    .. [2] Smith, Julius O. "Digital Audio Effects."
        https://ccrma.stanford.edu/~jos/fp3/Phase_Unwrapping.html

    The transfer function is a sum of delayed signals:

    .. math::

        H(\\omega) = \\sum_{i=0}^{N-1} g_i e^{-j\\omega\\tau_i}

    where:
        - N is the number of taps
        - g_i is the gain of tap i
        - τ_i is the delay time of tap i
        - Phase is unwrapped for each tap

    Processing Chain:
        1. Zero-pad input for maximum delay buffer
        2. Convert to frequency domain
        3. Calculate phase shifts for each tap
        4. Apply gains and sum delayed signals
        5. Convert back to time domain
        6. Mix processed signal with original

    Args:
        sample_rate (int): Audio sample rate in Hz
        num_taps (int): Number of independent delay taps. Defaults to 4.
        param_range (Dict[str, EffectParam], optional): Parameter ranges.

    Parameters Details:
        For each tap i (where i ranges from 0 to num_taps-1):

        i_tap_delays_ms: Delay time for tap i
            - Range: 50.0 to 500.0 milliseconds
            - Controls timing of each echo
            - Independent control per tap
            - Can create complex rhythmic patterns
        i_tap_gains: Gain for tap i
            - Range: 0.0 to 1.0
            - Controls amplitude of each echo
            - Allows creation of complex patterns
            - Can be used for amplitude envelopes
        mix: Overall wet/dry mix ratio
            - Range: 0.0 to 1.0
            - 0.0: Only original signal
            - 1.0: Only processed signal
            - Controls overall effect intensity

    Note:
        - Uses FFT-based delay for precise time shifting
        - Phase unwrapping prevents discontinuities
        - Automatic padding handles all delay times
        - Particularly effective for:
            - Creating complex rhythmic patterns
            - Building custom echo sequences
            - Designing unique delay textures
        - Each tap can be independently controlled
        - System is stable for all parameter values

    Examples:
        Basic DSP Usage:
            >>> # Create a 4-tap delay
            >>> delay = MultiTapsDelay(sample_rate=44100, num_taps=4)
            >>> # Process with rhythmic pattern
            >>> params = {
            ...     '0_tap_delays_ms': 125.0,  # Eighth note at 120 BPM
            ...     '0_tap_gains': 0.8,
            ...     '1_tap_delays_ms': 250.0,  # Quarter note
            ...     '1_tap_gains': 0.6,
            ...     '2_tap_delays_ms': 375.0,  # Dotted quarter
            ...     '2_tap_gains': 0.4,
            ...     '3_tap_delays_ms': 500.0,  # Half note
            ...     '3_tap_gains': 0.2,
            ...     'mix': 0.5
            ... }
            >>> output = delay(input_audio, dsp_params=params)

        Neural Network Control:
            >>> # 1. Simple parameter prediction
            >>> class MultiTapController(nn.Module):
            ...     def __init__(self, input_size, num_taps):
            ...         super().__init__()
            ...         num_params = 2 * num_taps + 1  # delays, gains, and mix
            ...         self.net = nn.Sequential(
            ...             nn.Linear(input_size, 32),
            ...             nn.ReLU(),
            ...             nn.Linear(32, num_params),
            ...             nn.Sigmoid()  # Ensures output is in [0,1] range
            ...         )
            ...
            ...     def forward(self, x):
            ...         return self.net(x)
            >>>
            >>> # Process with features
            >>> controller = MultiTapController(input_size=16, num_taps=4)
            >>> features = torch.randn(batch_size, 16)
            >>> norm_params = controller(features)
            >>> output = delay(input_audio, norm_params=norm_params)
    """

    def __init__(
        self,
        sample_rate: int,
        param_range: Union[Dict[str, EffectParam], None] = None,
        num_taps: int = 4,
    ):
        """Store the tap count, then defer to the base-class constructor.

        Args:
            sample_rate (int): Audio sample rate in Hz.
            param_range (Dict[str, EffectParam], optional): Parameter ranges,
                forwarded unchanged to the base class.
            num_taps (int): Number of independent delay taps. Defaults to 4.
        """
        # NOTE(review): num_taps is assigned before super().__init__ on
        # purpose -- presumably the base constructor triggers
        # _register_default_parameters(), which reads it. Keep this order.
        self.num_taps = num_taps
        super().__init__(sample_rate, param_range)

    def _register_default_parameters(self):
        """Register parameters for all taps and mix.

        Creates parameters for each tap:
            - i_tap_delays_ms: Delay time (50.0 to 500.0 ms)
            - i_tap_gains: Tap gain (0.0 to 1.0)

        Plus overall mix parameter.
        Total parameters = 2 * num_taps + 1
        """
        self.params = {}
        for i in range(self.num_taps):
            self.params[f'{i}_tap_delays_ms'] = EffectParam(
                min_val=50.0, max_val=500.0
            )
            self.params[f'{i}_tap_gains'] = EffectParam(
                min_val=0.0, max_val=1.0
            )
        self.params['mix'] = EffectParam(min_val=0.0, max_val=1.0)

    @staticmethod
    def _as_batch_param(value, x: torch.Tensor) -> torch.Tensor:
        """Normalise one parameter value to a 1-D tensor on ``x``'s device.

        Accepts a Python number, a 0-D tensor or an N-D tensor. The value is
        moved to ``x.device``, non-floating values are cast to ``x.dtype``,
        the result is flattened, and a single value is expanded to the batch
        size ``x.shape[0]`` so that all parameters can be stacked together.

        Args:
            value: Parameter value (float, int or ``torch.Tensor``).
            x (torch.Tensor): Input audio; supplies device, dtype and batch
                size.

        Returns:
            torch.Tensor: 1-D tensor. Its length is the batch size when a
            single value was given, otherwise the flattened length of
            ``value``.
        """
        param = torch.as_tensor(value, device=x.device)
        if not param.is_floating_point():
            # Integer inputs (e.g. a gain given as ``1``) must become
            # floating point before entering the phase/gain arithmetic.
            param = param.to(x.dtype)
        param = param.reshape(-1)
        if param.numel() == 1:
            param = param.expand(x.shape[0])
        return param

    def process(
        self,
        x: torch.Tensor,
        norm_params: Union[Dict[str, torch.Tensor], None] = None,
        dsp_params: Union[Dict[str, torch.Tensor], None] = None,
    ) -> torch.Tensor:
        """Process input signal through the multi-tap delay.

        Args:
            x (torch.Tensor): Input audio tensor.
                Shape: (batch, channels, samples)
            norm_params (Dict[str, torch.Tensor]): Normalized parameters
                (0 to 1). Must contain the following keys:
                    - '{i}_tap_delays_ms': Delay time for each tap (0 to 1)
                    - '{i}_tap_gains': Tap gain for each tap (0 to 1)
                    - 'mix': Wet/dry balance (0 to 1)
                Each value should be a tensor of shape (batch_size,)
            dsp_params (Dict[str, Union[float, torch.Tensor]], optional):
                Direct DSP parameters. Each value may be given as:
                    - float/int: Single value applied to entire batch
                    - 0D tensor: Single value applied to entire batch
                    - 1D tensor: Batch of values matching input batch size
                Values are expanded to match the batch size and moved to the
                input device. If provided, norm_params must be None.

        Returns:
            torch.Tensor: Processed audio tensor of same shape as input
        """
        check_params(norm_params, dsp_params)
        if norm_params is not None:
            params = self.map_parameters(norm_params)
        else:
            params = dsp_params

        # Unpacking also enforces the documented 3-D input layout.
        _, _, num_samples = x.shape

        # Normalise every parameter so scalars, 0-D and 1-D tensors can all
        # be stacked into (num_taps, batch) tensors.
        tap_delays_ms = torch.stack([
            self._as_batch_param(params[f'{i}_tap_delays_ms'], x)
            for i in range(self.num_taps)
        ])
        tap_gains = torch.stack([
            self._as_batch_param(params[f'{i}_tap_gains'], x)
            for i in range(self.num_taps)
        ])
        mix = self._as_batch_param(params['mix'], x).view(-1, 1, 1)

        # Longest delay across all taps and batch items, in whole samples
        # (at least one so the padding below is never empty).
        max_delay_samples = max(
            1,
            int(torch.max(tap_delays_ms) * self.sample_rate / 1000)
        )

        # Calculate FFT size (next power of 2 for efficiency)
        fft_size = 2 ** int(np.ceil(np.log2(num_samples + max_delay_samples)))

        # Left-pad by the longest delay and right-pad up to the FFT size.
        # Any circular wrap-around produced by the delay lands in the left
        # padding region, which is discarded when the output is sliced.
        pad_right = fft_size - (num_samples + max_delay_samples)
        x_padded = F.pad(x, (max_delay_samples, pad_right))

        X = torch.fft.rfft(x_padded, n=fft_size)
        freqs = torch.fft.rfftfreq(
            x_padded.shape[-1], 1 / self.sample_rate
        ).to(x.device)

        # Loop-invariant part of the phase term -2*pi*f*tau.
        neg_omega = -2 * np.pi * freqs

        wet_spectrum = torch.zeros_like(X)
        for delay_ms, gain in zip(tap_delays_ms, tap_gains):
            # Linear phase ramp that delays the signal by delay_ms.
            phase = neg_omega * delay_ms.view(-1, 1, 1) / 1000
            phase = unwrap_phase(phase, dim=-1).double()
            shift = torch.exp(1j * phase).to(X.dtype)
            wet_spectrum += gain.view(-1, 1, 1) * shift * X

        wet = torch.fft.irfft(wet_spectrum, n=fft_size)
        # Drop the left padding and trim back to the input length.
        wet = wet[..., max_delay_samples:max_delay_samples + num_samples]

        return (1 - mix) * x + mix * wet