Source code for diffFx_pytorch.processors.base

import torch 
import torch.nn as nn 
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union
from .base_utils import check_params

@dataclass
class EffectParam:
    """Value range (and optional default) of one effect parameter.

    Attributes:
        min_val: Lower bound of the parameter in DSP units.
        max_val: Upper bound of the parameter in DSP units.
        default: Optional default value; ``None`` when no default is given.
    """

    min_val: float
    max_val: float
    # ``None`` is the declared default, so the annotation must admit it.
    default: Union[float, None] = None
class _ParamShapeError(ValueError, AssertionError):
    """Raised when a parameter tensor has an unexpected shape.

    Derives from ``ValueError`` (the documented error type) and from
    ``AssertionError`` so that callers written against the earlier
    ``assert``-based checks keep working.
    """


class ProcessorsBase(nn.Module):
    """Base class for differentiable audio effect processors.

    This class provides the foundation for implementing audio effects
    processors with support for both normalized (0-1) and direct DSP
    parameter control. It handles parameter registration, validation, and
    mapping between normalized and DSP value ranges.

    The class supports two parameter interfaces:
        1. Normalized parameters (0-1 range) for neural network control
        2. Direct DSP parameters with actual audio processing values

    Args:
        sample_rate (int): Audio sample rate in Hz. Defaults to 44100.
        param_range (Dict[str, EffectParam], optional): Optional parameter
            definitions to override or extend default parameters.

    Attributes:
        sample_rate (int): Sampling rate in Hz
        params (Dict[str, EffectParam]): Registered effect parameters

    Parameter Management:
        Parameters are defined using EffectParam dataclass with:
            - min_val: Minimum DSP value
            - max_val: Maximum DSP value
            - default: Optional default value

    Example:
        Basic Implementation:
            >>> class MyEffect(ProcessorsBase):
            ...     def _register_default_parameters(self):
            ...         self.params = {
            ...             'frequency': EffectParam(min_val=20.0, max_val=20000.0),
            ...             'gain_db': EffectParam(min_val=-24.0, max_val=24.0)
            ...         }
            ...
            ...     def process(self, x, norm_params, dsp_params):
            ...         # Implement effect processing here
            ...         pass

        Parameter Usage:
            >>> effect = MyEffect(sample_rate=44100)
            >>> # Using DSP parameters
            >>> output = effect(input_audio, dsp_params={
            ...     'frequency': 1000.0,  # Direct Hz value
            ...     'gain_db': -6.0       # Direct dB value
            ... })
            >>>
            >>> # Using normalized parameters (e.g., from neural network)
            >>> norm_params = torch.tensor([[0.5, 0.3]])  # [batch, num_params]
            >>> output = effect(input_audio, norm_params=norm_params)
    """

    def __init__(self, sample_rate: int = 44100,
                 param_range: Union[Dict[str, "EffectParam"], None] = None):
        """Initialize the processor base.

        Args:
            sample_rate: Audio sample rate in Hz
            param_range: Optional parameter definitions to override defaults
        """
        super().__init__()
        self.sample_rate = sample_rate
        self.params: Dict[str, "EffectParam"] = {}
        # Subclass defaults are registered first so that ``param_range``
        # entries override or extend them.
        self._register_default_parameters()
        if param_range:
            self.params.update(param_range)

    def _register_default_parameters(self):
        """Register default parameters for the processor.

        Override this method to define processor-specific parameters using
        EffectParam dataclass instances. The base implementation registers
        nothing.
        """

    def _tensor_to_dict(self, tensor: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Split a ``[batch, num_params]`` tensor into per-parameter columns.

        Columns are assigned to parameter names in the iteration order of
        ``self.params``.

        Args:
            tensor: 2D tensor with one column per registered parameter

        Returns:
            Dictionary mapping each parameter name to its column (1D tensor)

        Raises:
            ValueError: If the tensor is not 2D or its second dimension does
                not match the number of registered parameters
        """
        # Explicit raises instead of ``assert`` so the checks survive
        # ``python -O``.
        if len(tensor.shape) != 2:
            raise _ParamShapeError("Expected 2D tensor")
        if tensor.shape[1] != len(self.params):
            raise _ParamShapeError(
                f"Expected {len(self.params)} parameters, got {tensor.shape[1]}"
            )
        return {name: tensor[:, i] for i, name in enumerate(self.params)}

    def map_parameters(self, norm_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Maps normalized parameters (0-1) to DSP parameter ranges.

        Every registered parameter is looked up in ``norm_params``.

        Args:
            norm_params: Dictionary of normalized parameter values

        Returns:
            Dictionary of mapped DSP parameter values

        Raises:
            KeyError: If a registered parameter is missing from ``norm_params``

        Note:
            Linear interpolation is used for mapping:
            dsp_value = min_val + (max_val - min_val) * norm_value
        """
        return {
            name: param.min_val + (param.max_val - param.min_val) * norm_params[name]
            for name, param in self.params.items()
        }

    def demap_parameters(self, dsp_params: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Maps DSP parameters back to normalized range (0-1).

        Every registered parameter is looked up in ``dsp_params``.

        Args:
            dsp_params: Dictionary of DSP parameter values

        Returns:
            Dictionary of normalized parameter values

        Raises:
            KeyError: If a registered parameter is missing from ``dsp_params``
        """
        return {
            name: (dsp_params[name] - param.min_val) / (param.max_val - param.min_val)
            for name, param in self.params.items()
        }

    def create_dsp_params_batch(self, params_dict: Dict[str, float], batch_size: int = 1,
                                device: Union[str, torch.device] = 'cpu') -> Dict[str, torch.Tensor]:
        """Creates batched tensor parameters from scalar DSP values.

        Args:
            params_dict: Dictionary of parameter names and scalar values
            batch_size: Number of copies in batch
            device: Target device for tensors

        Returns:
            Dictionary of batched float32 parameter tensors of shape
            ``(batch_size,)``

        Raises:
            KeyError: If parameter name not registered
            ValueError: If value outside valid range
        """
        batched_params = {}
        for name, value in params_dict.items():
            if name not in self.params:
                raise KeyError(f"Parameter '{name}' not registered in effect processor")

            param_info = self.params[name]
            if value < param_info.min_val or value > param_info.max_val:
                raise ValueError(
                    f"Parameter '{name}' value {value} is outside valid range "
                    f"[{param_info.min_val}, {param_info.max_val}]"
                )

            batched_params[name] = torch.full(
                (batch_size,), value, device=device, dtype=torch.float32
            )
        return batched_params

    def _batch_dsp_value(self, name: str, value: Union[int, float, torch.Tensor],
                         batch_size: int, device: torch.device) -> torch.Tensor:
        """Turn one DSP parameter value into a ``(batch_size,)`` tensor on ``device``.

        Raises:
            KeyError: If ``name`` is not a registered parameter
            TypeError: If ``value`` is neither a number nor a tensor
            ValueError: If a tensor value has more than one dimension or a
                1D tensor does not match ``batch_size``
        """
        if name not in self.params:
            raise KeyError(f"Unknown parameter: {name}")

        if isinstance(value, (int, float)):
            # Plain scalar: replicate across the batch.
            return torch.full((batch_size,), float(value), device=device, dtype=torch.float32)

        if not isinstance(value, torch.Tensor):
            raise TypeError(f"Parameter '{name}' has invalid type: {type(value)}")

        if value.ndim == 0:
            # 0D tensor: broadcast the single value across the batch.
            batched = value.expand(batch_size)
        elif value.ndim == 1:
            if value.shape[0] != batch_size:
                raise _ParamShapeError(
                    f"Parameter '{name}' batch size {value.shape[0]} != {batch_size}"
                )
            batched = value
        else:
            raise ValueError(f"Parameter '{name}' has too many dimensions: {value.ndim}")

        # ``Tensor.to`` returns the tensor itself when it is already on ``device``.
        return batched.to(device)

    def forward(self, x: torch.Tensor, norm_params: Union[torch.Tensor, None] = None,
                dsp_params: Union[Dict[str, Union[float, torch.Tensor]], None] = None) -> torch.Tensor:
        """Process input with either normalized or DSP parameters.

        Args:
            x: Input audio tensor [batch, channels, samples]
            norm_params: Optional normalized parameters tensor [batch, num_params]
            dsp_params: Optional DSP parameters dictionary
                Each parameter can be:
                    - float/int: Single value for all batch items
                    - 0D tensor: Single value for all batch items
                    - 1D tensor: Batch-specific values

        Returns:
            Processed audio tensor

        Raises:
            KeyError: If unknown parameter name provided
            TypeError: If parameter has invalid type
            ValueError: If tensor parameter shape invalid

        Note:
            Only one of norm_params or dsp_params should be provided. This is
            not enforced here: whatever is supplied is converted and handed
            to ``process``.
        """
        batch_size = x.shape[0]
        params_dict = None
        dsp_params_dict = None

        if norm_params is not None:
            # Shape validation happens inside _tensor_to_dict.
            params_dict = self._tensor_to_dict(norm_params)

        if dsp_params is not None:
            dsp_params_dict = {
                name: self._batch_dsp_value(name, value, batch_size, x.device)
                for name, value in dsp_params.items()
            }

        return self.process(x, params_dict, dsp_params_dict)

    def process(self, x: torch.Tensor, norm_params: Union[Dict[str, torch.Tensor], None] = None,
                dsp_params: Union[Dict[str, torch.Tensor], None] = None) -> torch.Tensor:
        """Process audio with the effect (to be implemented by subclasses).

        Args:
            x: Input audio tensor [batch, channels, samples]
            norm_params: Optional dictionary of normalized parameters
            dsp_params: Optional dictionary of DSP parameters

        Returns:
            Processed audio tensor

        Raises:
            NotImplementedError: Always, in this base class
        """
        raise NotImplementedError

    def count_num_parameters(self) -> int:
        """Returns the number of effect parameters.

        Returns:
            Number of registered parameters
        """
        return len(self.params)