voirs_spatial/neural/
processor.rs1use super::models::*;
4use super::quality::AdaptiveQualityController;
5use super::training::{NeuralTrainer, NeuralTrainingResults};
6use super::types::*;
7use crate::{Error, Result};
8use candle_core::Device;
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11
12pub struct NeuralSpatialProcessor {
14 config: NeuralSpatialConfig,
16 model: Box<dyn NeuralModel + Send + Sync>,
18 device: Device,
20 metrics: Arc<RwLock<NeuralPerformanceMetrics>>,
22 input_buffer: Arc<RwLock<Vec<NeuralInputFeatures>>>,
24 model_cache: Arc<RwLock<HashMap<String, Box<dyn NeuralModel + Send + Sync>>>>,
26 quality_controller: AdaptiveQualityController,
28}
29
30impl NeuralSpatialProcessor {
31 pub fn new(config: NeuralSpatialConfig) -> Result<Self> {
33 let device = if config.use_gpu {
34 std::panic::catch_unwind(|| Device::new_cuda(0))
36 .unwrap_or(Ok(Device::Cpu))
37 .unwrap_or(Device::Cpu)
38 } else {
39 Device::Cpu
40 };
41
42 let model = Self::create_model(&config, &device)?;
43 let quality_controller =
44 AdaptiveQualityController::new(config.realtime_constraints.max_latency_ms);
45
46 Ok(Self {
47 config,
48 model,
49 device,
50 metrics: Arc::new(RwLock::new(NeuralPerformanceMetrics::default())),
51 input_buffer: Arc::new(RwLock::new(Vec::new())),
52 model_cache: Arc::new(RwLock::new(HashMap::new())),
53 quality_controller,
54 })
55 }
56
57 fn create_model(
59 config: &NeuralSpatialConfig,
60 device: &Device,
61 ) -> Result<Box<dyn NeuralModel + Send + Sync>> {
62 match config.model_type {
63 NeuralModelType::Feedforward => Ok(Box::new(FeedforwardModel::new(
64 config.clone(),
65 device.clone(),
66 )?)),
67 NeuralModelType::Convolutional => Ok(Box::new(ConvolutionalModel::new(
68 config.clone(),
69 device.clone(),
70 )?)),
71 NeuralModelType::Transformer => Ok(Box::new(TransformerModel::new(
72 config.clone(),
73 device.clone(),
74 )?)),
75 _ => Err(Error::LegacyProcessing(format!(
76 "Neural model type {:?} not yet implemented",
77 config.model_type
78 ))),
79 }
80 }
81
82 pub fn process(&mut self, input: &NeuralInputFeatures) -> Result<NeuralSpatialOutput> {
84 let start_time = std::time::Instant::now();
85
86 {
88 let mut buffer = self.input_buffer.write().map_err(|e| {
89 Error::LegacyProcessing(format!(
90 "Failed to acquire write lock on input buffer: {}",
91 e
92 ))
93 })?;
94 buffer.push(input.clone());
95
96 if buffer.len() > 10 {
98 buffer.remove(0);
99 }
100 }
101
102 let mut output = self.model.forward(input)?;
104
105 let processing_time = start_time.elapsed().as_secs_f32() * 1000.0;
107 output.latency_ms = processing_time;
108
109 {
111 let mut metrics = self.metrics.write().map_err(|e| {
112 Error::LegacyProcessing(format!("Failed to acquire write lock on metrics: {}", e))
113 })?;
114 metrics.frames_processed += 1;
115 metrics.avg_processing_time_ms = (metrics.avg_processing_time_ms
116 * (metrics.frames_processed - 1) as f32
117 + processing_time)
118 / metrics.frames_processed as f32;
119 metrics.peak_processing_time_ms = metrics.peak_processing_time_ms.max(processing_time);
120
121 if processing_time > self.config.realtime_constraints.max_latency_ms {
122 metrics.realtime_violations += 1;
123 }
124 }
125
126 if self.config.realtime_constraints.adaptive_quality {
128 self.quality_controller.update(processing_time);
129 let new_quality = self.quality_controller.get_quality();
130 if (new_quality - self.quality_controller.current_quality).abs() > 0.05 {
131 self.model.set_quality(new_quality)?;
132 self.quality_controller.current_quality = new_quality;
133 }
134 }
135
136 output.quality_score = self.quality_controller.current_quality;
137 Ok(output)
138 }
139
140 pub fn process_batch(
142 &mut self,
143 inputs: &[NeuralInputFeatures],
144 ) -> Result<Vec<NeuralSpatialOutput>> {
145 let mut outputs = Vec::with_capacity(inputs.len());
146
147 for input in inputs {
148 outputs.push(self.process(input)?);
149 }
150
151 Ok(outputs)
152 }
153
154 pub fn metrics(&self) -> Result<NeuralPerformanceMetrics> {
156 Ok(self
157 .metrics
158 .read()
159 .map_err(|e| {
160 Error::LegacyProcessing(format!("Failed to acquire read lock on metrics: {}", e))
161 })?
162 .clone())
163 }
164
165 pub fn reset_metrics(&self) -> Result<()> {
167 let mut metrics = self.metrics.write().map_err(|e| {
168 Error::LegacyProcessing(format!("Failed to acquire write lock on metrics: {}", e))
169 })?;
170 *metrics = NeuralPerformanceMetrics::default();
171 Ok(())
172 }
173
174 pub fn update_config(&mut self, new_config: NeuralSpatialConfig) -> Result<()> {
176 if new_config.model_type != self.config.model_type
178 || new_config.hidden_dims != self.config.hidden_dims
179 {
180 self.model = Self::create_model(&new_config, &self.device)?;
181 }
182
183 self.config = new_config;
184 self.quality_controller.target_latency_ms = self.config.realtime_constraints.max_latency_ms;
185
186 Ok(())
187 }
188
189 pub fn train(
191 &mut self,
192 training_data: &[(NeuralInputFeatures, Vec<Vec<f32>>)],
193 ) -> Result<NeuralTrainingResults> {
194 let config = self.config.training_config.as_ref().ok_or_else(|| {
195 Error::LegacyConfig("Training configuration not provided".to_string())
196 })?;
197
198 let mut trainer = NeuralTrainer::new(config.clone());
199 trainer.train(&mut *self.model, training_data)
200 }
201
202 pub fn save_model(&self, path: &str) -> Result<()> {
204 self.model.save(path)
205 }
206
207 pub fn load_model(&mut self, path: &str) -> Result<()> {
209 self.model.load(path)
210 }
211}