// voirs_spatial/neural/types.rs
//! Core types and configurations for neural spatial audio processing

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::types::Position3D;
6
7/// Configuration for neural spatial audio processing
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct NeuralSpatialConfig {
10    /// Model architecture type
11    pub model_type: NeuralModelType,
12    /// Hidden layer dimensions
13    pub hidden_dims: Vec<usize>,
14    /// Input feature dimensions
15    pub input_dim: usize,
16    /// Output audio channels (typically 2 for binaural)
17    pub output_channels: usize,
18    /// Sample rate in Hz
19    pub sample_rate: u32,
20    /// Buffer size for processing
21    pub buffer_size: usize,
22    /// Whether to use GPU acceleration
23    pub use_gpu: bool,
24    /// Model quality setting (0.0-1.0)
25    pub quality: f32,
26    /// Real-time processing constraints
27    pub realtime_constraints: RealtimeConstraints,
28    /// Training parameters
29    pub training_config: Option<TrainingConfig>,
30}
31
32/// Types of neural models available
33#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
34pub enum NeuralModelType {
35    /// Feedforward neural network for basic spatial synthesis
36    Feedforward,
37    /// Convolutional neural network for temporal-spatial processing
38    Convolutional,
39    /// Recurrent neural network for temporal modeling
40    Recurrent,
41    /// Transformer model for attention-based spatial processing
42    Transformer,
43    /// Generative Adversarial Network for high-quality synthesis
44    GAN,
45    /// Variational Autoencoder for latent space spatial modeling
46    VAE,
47    /// Diffusion model for high-fidelity spatial audio generation
48    Diffusion,
49    /// Hybrid model combining multiple architectures
50    Hybrid,
51}
52
53/// Real-time processing constraints
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct RealtimeConstraints {
56    /// Maximum latency in milliseconds
57    pub max_latency_ms: f32,
58    /// Maximum CPU usage percentage
59    pub max_cpu_usage: f32,
60    /// Maximum memory usage in MB
61    pub max_memory_mb: usize,
62    /// Target frame rate for processing
63    pub target_fps: u32,
64    /// Enable adaptive quality adjustment
65    pub adaptive_quality: bool,
66}
67
68/// Training configuration for neural models
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct TrainingConfig {
71    /// Learning rate
72    pub learning_rate: f64,
73    /// Batch size
74    pub batch_size: usize,
75    /// Number of training epochs
76    pub epochs: usize,
77    /// Validation split ratio
78    pub validation_split: f32,
79    /// Loss function type
80    pub loss_function: LossFunction,
81    /// Optimizer type
82    pub optimizer: OptimizerType,
83    /// Early stopping patience
84    pub early_stopping_patience: usize,
85    /// Data augmentation settings
86    pub augmentation: AugmentationConfig,
87}
88
89/// Neural network loss functions
90#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
91pub enum LossFunction {
92    /// Mean Squared Error for regression
93    MSE,
94    /// Mean Absolute Error
95    MAE,
96    /// Spectral loss for audio quality
97    SpectralLoss,
98    /// Perceptual loss based on human auditory system
99    PerceptualLoss,
100    /// Multi-scale spectral loss
101    MultiScaleSpectralLoss,
102    /// Combined loss function
103    Combined,
104}
105
106/// Optimizer types for training
107#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
108pub enum OptimizerType {
109    /// Adam optimizer
110    Adam,
111    /// Stochastic Gradient Descent
112    SGD,
113    /// AdamW with weight decay
114    AdamW,
115    /// RMSprop optimizer
116    RMSprop,
117}
118
119/// Data augmentation configuration
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct AugmentationConfig {
122    /// Enable noise injection
123    pub noise_injection: bool,
124    /// Enable time stretching
125    pub time_stretching: bool,
126    /// Enable pitch shifting
127    pub pitch_shifting: bool,
128    /// Enable reverb augmentation
129    pub reverb_augmentation: bool,
130    /// Random gain variation range
131    pub gain_variation: f32,
132}
133
134/// Input features for neural spatial processing
135#[derive(Debug, Clone)]
136pub struct NeuralInputFeatures {
137    /// 3D position of the sound source
138    pub position: Position3D,
139    /// Listener orientation (quaternion: w, x, y, z)
140    pub listener_orientation: [f32; 4],
141    /// Audio content features (e.g., spectral features)
142    pub audio_features: Vec<f32>,
143    /// Room acoustics parameters
144    pub room_features: Vec<f32>,
145    /// HRTF parameters if available
146    pub hrtf_features: Option<Vec<f32>>,
147    /// Temporal context from previous frames
148    pub temporal_context: Vec<f32>,
149    /// User-specific features (age, head size, etc.)
150    pub user_features: Option<Vec<f32>>,
151}
152
/// Result produced by one pass of neural spatial processing.
#[derive(Debug, Clone)]
pub struct NeuralSpatialOutput {
    /// Synthesized binaural audio, one inner `Vec` per channel (left, right).
    pub binaural_audio: Vec<Vec<f32>>,
    /// How confident the model is in this synthesis.
    pub confidence: f32,
    /// Time spent producing this output, in milliseconds.
    pub latency_ms: f32,
    /// Estimated output quality in the range 0.0–1.0.
    pub quality_score: f32,
    /// Free-form named metrics attached to this result.
    pub metadata: HashMap<String, f32>,
}
167
168/// Performance metrics for neural processing
169#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
170pub struct NeuralPerformanceMetrics {
171    /// Total number of processed frames
172    pub frames_processed: u64,
173    /// Average processing time per frame (ms)
174    pub avg_processing_time_ms: f32,
175    /// Peak processing time (ms)
176    pub peak_processing_time_ms: f32,
177    /// Memory usage in MB
178    pub memory_usage_mb: f32,
179    /// GPU utilization percentage
180    pub gpu_utilization: f32,
181    /// Model inference time (ms)
182    pub inference_time_ms: f32,
183    /// Quality degradation events
184    pub quality_degradations: u32,
185    /// Real-time violations
186    pub realtime_violations: u32,
187    /// Last updated timestamp (seconds since UNIX epoch)
188    pub last_updated: u64,
189}
190
/// Summary of a completed (or early-stopped) training run.
#[derive(Debug, Clone)]
pub struct NeuralTrainingResults {
    /// Training loss recorded after each epoch.
    pub training_loss: Vec<f32>,
    /// Validation loss recorded after each epoch.
    pub validation_loss: Vec<f32>,
    /// Accuracy achieved at the end of training.
    pub final_accuracy: f32,
    /// Wall-clock training time, in seconds.
    pub training_duration_secs: f32,
    /// How many epochs actually ran.
    pub epochs_completed: usize,
    /// `true` if early stopping ended the run before the epoch limit.
    pub early_stopped: bool,
}
207
208/// Builder for neural spatial processor configuration
209pub struct NeuralSpatialConfigBuilder {
210    config: NeuralSpatialConfig,
211}
212
213impl Default for NeuralSpatialConfig {
214    fn default() -> Self {
215        Self {
216            model_type: NeuralModelType::Feedforward,
217            hidden_dims: vec![512, 256, 128],
218            input_dim: 128,
219            output_channels: 2,
220            sample_rate: 48000,
221            buffer_size: 1024,
222            use_gpu: true,
223            quality: 0.8,
224            realtime_constraints: RealtimeConstraints {
225                max_latency_ms: 20.0,
226                max_cpu_usage: 25.0,
227                max_memory_mb: 512,
228                target_fps: 60,
229                adaptive_quality: true,
230            },
231            training_config: None,
232        }
233    }
234}
235
236impl Default for RealtimeConstraints {
237    fn default() -> Self {
238        Self {
239            max_latency_ms: 20.0,
240            max_cpu_usage: 25.0,
241            max_memory_mb: 512,
242            target_fps: 60,
243            adaptive_quality: true,
244        }
245    }
246}
247
248impl Default for TrainingConfig {
249    fn default() -> Self {
250        Self {
251            learning_rate: 0.001,
252            batch_size: 32,
253            epochs: 100,
254            validation_split: 0.2,
255            loss_function: LossFunction::MultiScaleSpectralLoss,
256            optimizer: OptimizerType::Adam,
257            early_stopping_patience: 10,
258            augmentation: AugmentationConfig::default(),
259        }
260    }
261}
262
263impl Default for AugmentationConfig {
264    fn default() -> Self {
265        Self {
266            noise_injection: true,
267            time_stretching: true,
268            pitch_shifting: true,
269            reverb_augmentation: true,
270            gain_variation: 0.1,
271        }
272    }
273}
274
275impl NeuralSpatialConfigBuilder {
276    /// Create a new configuration builder
277    pub fn new() -> Self {
278        Self {
279            config: NeuralSpatialConfig::default(),
280        }
281    }
282
283    /// Set the neural model type
284    pub fn model_type(mut self, model_type: NeuralModelType) -> Self {
285        self.config.model_type = model_type;
286        self
287    }
288
289    /// Set the hidden layer dimensions
290    pub fn hidden_dims(mut self, dims: Vec<usize>) -> Self {
291        self.config.hidden_dims = dims;
292        self
293    }
294
295    /// Set the input dimension
296    pub fn input_dim(mut self, dim: usize) -> Self {
297        self.config.input_dim = dim;
298        self
299    }
300
301    /// Set the number of output channels
302    pub fn output_channels(mut self, channels: usize) -> Self {
303        self.config.output_channels = channels;
304        self
305    }
306
307    /// Set the audio sample rate
308    pub fn sample_rate(mut self, rate: u32) -> Self {
309        self.config.sample_rate = rate;
310        self
311    }
312
313    /// Set the audio buffer size
314    pub fn buffer_size(mut self, size: usize) -> Self {
315        self.config.buffer_size = size;
316        self
317    }
318
319    /// Enable or disable GPU usage
320    pub fn use_gpu(mut self, use_gpu: bool) -> Self {
321        self.config.use_gpu = use_gpu;
322        self
323    }
324
325    /// Set the quality level (0.0-1.0)
326    pub fn quality(mut self, quality: f32) -> Self {
327        self.config.quality = quality.clamp(0.0, 1.0);
328        self
329    }
330
331    /// Set the maximum latency in milliseconds
332    pub fn max_latency_ms(mut self, latency: f32) -> Self {
333        self.config.realtime_constraints.max_latency_ms = latency;
334        self
335    }
336
337    /// Set the training configuration
338    pub fn training_config(mut self, training_config: TrainingConfig) -> Self {
339        self.config.training_config = Some(training_config);
340        self
341    }
342
343    /// Build the neural spatial configuration
344    pub fn build(self) -> NeuralSpatialConfig {
345        self.config
346    }
347}
348
349impl Default for NeuralSpatialConfigBuilder {
350    fn default() -> Self {
351        Self::new()
352    }
353}