// voirs_spatial/neural/processor.rs
1//! Neural spatial audio processor implementation
2
3use super::models::*;
4use super::quality::AdaptiveQualityController;
5use super::training::{NeuralTrainer, NeuralTrainingResults};
6use super::types::*;
7use crate::{Error, Result};
8use candle_core::Device;
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11
/// Neural spatial audio processor
///
/// Owns a pluggable neural network model (selected via
/// `NeuralSpatialConfig::model_type`) plus the shared state needed to run
/// it frame-by-frame: performance metrics, a temporal-context input buffer,
/// and an adaptive quality controller. Metrics, buffer, and cache are
/// behind `Arc<RwLock<…>>` — presumably so snapshots can be shared across
/// threads; confirm against callers before relying on that.
pub struct NeuralSpatialProcessor {
    /// Configuration (model type, GPU usage, real-time constraints, …)
    config: NeuralSpatialConfig,
    /// Neural network model; boxed trait object so the concrete
    /// architecture can be swapped at runtime (see `update_config`)
    model: Box<dyn NeuralModel + Send + Sync>,
    /// Computing device (CPU/GPU) the model was built for
    device: Device,
    /// Performance metrics (frame counts, processing times, violations)
    metrics: Arc<RwLock<NeuralPerformanceMetrics>>,
    /// Input buffer holding recent frames for temporal context
    input_buffer: Arc<RwLock<Vec<NeuralInputFeatures>>>,
    /// Model cache for different configurations
    /// NOTE(review): populated nowhere in this file — verify it is used
    /// elsewhere or remove it
    model_cache: Arc<RwLock<HashMap<String, Box<dyn NeuralModel + Send + Sync>>>>,
    /// Quality adaptation controller tracking the latency budget
    quality_controller: AdaptiveQualityController,
}
29
30impl NeuralSpatialProcessor {
31    /// Create a new neural spatial processor
32    pub fn new(config: NeuralSpatialConfig) -> Result<Self> {
33        let device = if config.use_gpu {
34            // Use catch_unwind because Device::new_cuda can panic on systems without CUDA
35            std::panic::catch_unwind(|| Device::new_cuda(0))
36                .unwrap_or(Ok(Device::Cpu))
37                .unwrap_or(Device::Cpu)
38        } else {
39            Device::Cpu
40        };
41
42        let model = Self::create_model(&config, &device)?;
43        let quality_controller =
44            AdaptiveQualityController::new(config.realtime_constraints.max_latency_ms);
45
46        Ok(Self {
47            config,
48            model,
49            device,
50            metrics: Arc::new(RwLock::new(NeuralPerformanceMetrics::default())),
51            input_buffer: Arc::new(RwLock::new(Vec::new())),
52            model_cache: Arc::new(RwLock::new(HashMap::new())),
53            quality_controller,
54        })
55    }
56
57    /// Create a model based on configuration
58    fn create_model(
59        config: &NeuralSpatialConfig,
60        device: &Device,
61    ) -> Result<Box<dyn NeuralModel + Send + Sync>> {
62        match config.model_type {
63            NeuralModelType::Feedforward => Ok(Box::new(FeedforwardModel::new(
64                config.clone(),
65                device.clone(),
66            )?)),
67            NeuralModelType::Convolutional => Ok(Box::new(ConvolutionalModel::new(
68                config.clone(),
69                device.clone(),
70            )?)),
71            NeuralModelType::Transformer => Ok(Box::new(TransformerModel::new(
72                config.clone(),
73                device.clone(),
74            )?)),
75            _ => Err(Error::LegacyProcessing(format!(
76                "Neural model type {:?} not yet implemented",
77                config.model_type
78            ))),
79        }
80    }
81
82    /// Process audio with neural spatial synthesis
83    pub fn process(&mut self, input: &NeuralInputFeatures) -> Result<NeuralSpatialOutput> {
84        let start_time = std::time::Instant::now();
85
86        // Add to input buffer for temporal context
87        {
88            let mut buffer = self.input_buffer.write().map_err(|e| {
89                Error::LegacyProcessing(format!(
90                    "Failed to acquire write lock on input buffer: {}",
91                    e
92                ))
93            })?;
94            buffer.push(input.clone());
95
96            // Keep only recent frames for temporal context
97            if buffer.len() > 10 {
98                buffer.remove(0);
99            }
100        }
101
102        // Forward pass through the model
103        let mut output = self.model.forward(input)?;
104
105        // Calculate processing time
106        let processing_time = start_time.elapsed().as_secs_f32() * 1000.0;
107        output.latency_ms = processing_time;
108
109        // Update metrics
110        {
111            let mut metrics = self.metrics.write().map_err(|e| {
112                Error::LegacyProcessing(format!("Failed to acquire write lock on metrics: {}", e))
113            })?;
114            metrics.frames_processed += 1;
115            metrics.avg_processing_time_ms = (metrics.avg_processing_time_ms
116                * (metrics.frames_processed - 1) as f32
117                + processing_time)
118                / metrics.frames_processed as f32;
119            metrics.peak_processing_time_ms = metrics.peak_processing_time_ms.max(processing_time);
120
121            if processing_time > self.config.realtime_constraints.max_latency_ms {
122                metrics.realtime_violations += 1;
123            }
124        }
125
126        // Adaptive quality control
127        if self.config.realtime_constraints.adaptive_quality {
128            self.quality_controller.update(processing_time);
129            let new_quality = self.quality_controller.get_quality();
130            if (new_quality - self.quality_controller.current_quality).abs() > 0.05 {
131                self.model.set_quality(new_quality)?;
132                self.quality_controller.current_quality = new_quality;
133            }
134        }
135
136        output.quality_score = self.quality_controller.current_quality;
137        Ok(output)
138    }
139
140    /// Process batch of inputs for better efficiency
141    pub fn process_batch(
142        &mut self,
143        inputs: &[NeuralInputFeatures],
144    ) -> Result<Vec<NeuralSpatialOutput>> {
145        let mut outputs = Vec::with_capacity(inputs.len());
146
147        for input in inputs {
148            outputs.push(self.process(input)?);
149        }
150
151        Ok(outputs)
152    }
153
154    /// Get current performance metrics
155    pub fn metrics(&self) -> Result<NeuralPerformanceMetrics> {
156        Ok(self
157            .metrics
158            .read()
159            .map_err(|e| {
160                Error::LegacyProcessing(format!("Failed to acquire read lock on metrics: {}", e))
161            })?
162            .clone())
163    }
164
165    /// Reset performance metrics
166    pub fn reset_metrics(&self) -> Result<()> {
167        let mut metrics = self.metrics.write().map_err(|e| {
168            Error::LegacyProcessing(format!("Failed to acquire write lock on metrics: {}", e))
169        })?;
170        *metrics = NeuralPerformanceMetrics::default();
171        Ok(())
172    }
173
174    /// Update configuration
175    pub fn update_config(&mut self, new_config: NeuralSpatialConfig) -> Result<()> {
176        // Check if model needs to be recreated
177        if new_config.model_type != self.config.model_type
178            || new_config.hidden_dims != self.config.hidden_dims
179        {
180            self.model = Self::create_model(&new_config, &self.device)?;
181        }
182
183        self.config = new_config;
184        self.quality_controller.target_latency_ms = self.config.realtime_constraints.max_latency_ms;
185
186        Ok(())
187    }
188
189    /// Train the neural model with provided data
190    pub fn train(
191        &mut self,
192        training_data: &[(NeuralInputFeatures, Vec<Vec<f32>>)],
193    ) -> Result<NeuralTrainingResults> {
194        let config = self.config.training_config.as_ref().ok_or_else(|| {
195            Error::LegacyConfig("Training configuration not provided".to_string())
196        })?;
197
198        let mut trainer = NeuralTrainer::new(config.clone());
199        trainer.train(&mut *self.model, training_data)
200    }
201
202    /// Save the current model
203    pub fn save_model(&self, path: &str) -> Result<()> {
204        self.model.save(path)
205    }
206
207    /// Load a trained model
208    pub fn load_model(&mut self, path: &str) -> Result<()> {
209        self.model.load(path)
210    }
211}