1use crate::config::{AsrConfig, PreprocessingConfig, RecognizerConfig};
11use crate::RecognitionError;
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use thiserror::Error;
17use voirs_sdk::config::PipelineConfig;
18
19#[derive(Debug, Error)]
21pub enum SdkBridgeError {
22 #[error("Configuration sync failed: {0}")]
24 ConfigSyncFailed(String),
25
26 #[error("SDK communication error: {0}")]
28 CommunicationError(String),
29
30 #[error("Resource conflict: {0}")]
32 ResourceConflict(String),
33
34 #[error("Compatibility issue: {0}")]
36 CompatibilityIssue(String),
37
38 #[error("Recognition error: {0}")]
40 RecognitionError(#[from] RecognitionError),
41}
42
43#[derive(Debug)]
45pub struct VoirsSdkBridge {
46 recognizer_config: Arc<RwLock<RecognizerConfig>>,
48
49 sdk_config: Arc<RwLock<PipelineConfig>>,
51
52 shared_state: Arc<RwLock<SharedState>>,
54
55 optimization_settings: Arc<RwLock<OptimizationSettings>>,
57
58 error_registry: Arc<RwLock<ErrorRegistry>>,
60}
61
62#[derive(Debug, Clone, Default)]
64pub struct SharedState {
65 pub model_cache_paths: HashMap<String, String>,
67
68 pub audio_buffer_pool_size: usize,
70
71 pub gpu_device_ids: Vec<i32>,
73
74 pub thread_pool_size: usize,
76
77 pub feature_flags: HashMap<String, bool>,
79
80 pub resource_quotas: ResourceQuotas,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct ResourceQuotas {
87 pub max_memory_mb: f32,
89
90 pub max_gpu_memory_mb: Option<f32>,
92
93 pub max_concurrent_operations: usize,
95
96 pub max_cache_size_mb: f32,
98}
99
100impl Default for ResourceQuotas {
101 fn default() -> Self {
102 Self {
103 max_memory_mb: 2048.0,
104 max_gpu_memory_mb: Some(4096.0),
105 max_concurrent_operations: 4,
106 max_cache_size_mb: 512.0,
107 }
108 }
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct OptimizationSettings {
114 pub enable_model_sharing: bool,
116
117 pub enable_unified_memory_pool: bool,
119
120 pub enable_cross_crate_cache: bool,
122
123 pub enable_gpu_pooling: bool,
125
126 pub coordinated_batch_size: Option<usize>,
128
129 pub unified_precision: PrecisionPolicy,
131
132 pub enable_parallel_pipelines: bool,
134
135 pub memory_pressure_threshold: f32,
137}
138
139impl Default for OptimizationSettings {
140 fn default() -> Self {
141 Self {
142 enable_model_sharing: true,
143 enable_unified_memory_pool: true,
144 enable_cross_crate_cache: true,
145 enable_gpu_pooling: true,
146 coordinated_batch_size: Some(4),
147 unified_precision: PrecisionPolicy::Mixed,
148 enable_parallel_pipelines: false,
149 memory_pressure_threshold: 0.8,
150 }
151 }
152}
153
154#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
156pub enum PrecisionPolicy {
157 Full,
159 Half,
161 Quantized,
163 Mixed,
165}
166
167#[derive(Default)]
169pub struct ErrorRegistry {
170 error_codes: HashMap<String, ErrorCode>,
172
173 handlers: HashMap<String, ErrorHandler>,
175}
176
177impl std::fmt::Debug for ErrorRegistry {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 f.debug_struct("ErrorRegistry")
180 .field("error_codes", &self.error_codes)
181 .field("handlers_count", &self.handlers.len())
182 .finish()
183 }
184}
185
186#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
188pub enum ErrorCode {
189 Success = 0,
191
192 GenericError = 1000,
194
195 ConfigError = 2000,
197
198 ModelLoadError = 3000,
200
201 InferenceError = 4000,
203
204 AudioProcessingError = 5000,
206
207 MemoryError = 6000,
209
210 GpuError = 7000,
212
213 IoError = 8000,
215
216 TimeoutError = 9000,
218
219 ResourceExhausted = 10000,
221}
222
/// Shareable, thread-safe callback invoked with an error message.
///
/// The explicit `'static` bound is what `Arc<dyn Trait>` defaults to anyway.
pub type ErrorHandler = Arc<dyn Fn(&str) + Send + Sync + 'static>;
225
226impl VoirsSdkBridge {
227 pub fn new(
229 recognizer_config: RecognizerConfig,
230 sdk_config: PipelineConfig,
231 ) -> Result<Self, SdkBridgeError> {
232 Ok(Self {
233 recognizer_config: Arc::new(RwLock::new(recognizer_config)),
234 sdk_config: Arc::new(RwLock::new(sdk_config)),
235 shared_state: Arc::new(RwLock::new(SharedState::default())),
236 optimization_settings: Arc::new(RwLock::new(OptimizationSettings::default())),
237 error_registry: Arc::new(RwLock::new(ErrorRegistry::default())),
238 })
239 }
240
241 pub fn sync_config_from_sdk(&self) -> Result<(), SdkBridgeError> {
243 let sdk_cfg = self.sdk_config.read();
244 let mut recognizer_cfg = self.recognizer_config.write();
245
246 let device_str = sdk_cfg.device.to_lowercase();
248 recognizer_cfg.asr.enable_gpu = sdk_cfg.use_gpu
249 || device_str.contains("gpu")
250 || device_str.contains("cuda")
251 || device_str.contains("metal");
252
253 if let Some(num_threads) = sdk_cfg.num_threads {
255 recognizer_cfg.performance.num_threads = num_threads;
256 }
257
258 Ok(())
259 }
260
261 pub fn sync_config_to_sdk(&self) -> Result<(), SdkBridgeError> {
263 let recognizer_cfg = self.recognizer_config.read();
264 let mut _sdk_cfg = self.sdk_config.write();
265
266 tracing::debug!(
269 "Syncing recognizer config to SDK: GPU={}, batch_size={}",
270 recognizer_cfg.asr.enable_gpu,
271 recognizer_cfg.asr.batch_size
272 );
273
274 Ok(())
275 }
276
277 #[must_use]
279 pub fn shared_state(&self) -> Arc<RwLock<SharedState>> {
280 Arc::clone(&self.shared_state)
281 }
282
283 #[must_use]
285 pub fn optimization_settings(&self) -> Arc<RwLock<OptimizationSettings>> {
286 Arc::clone(&self.optimization_settings)
287 }
288
289 pub fn update_shared_state<F>(&self, updater: F) -> Result<(), SdkBridgeError>
291 where
292 F: FnOnce(&mut SharedState),
293 {
294 let mut state = self.shared_state.write();
295 updater(&mut state);
296 Ok(())
297 }
298
299 pub fn update_optimization_settings<F>(&self, updater: F) -> Result<(), SdkBridgeError>
301 where
302 F: FnOnce(&mut OptimizationSettings),
303 {
304 let mut settings = self.optimization_settings.write();
305 updater(&mut settings);
306 Ok(())
307 }
308
309 pub fn register_error_code(&self, error_name: String, code: ErrorCode) {
311 let mut registry = self.error_registry.write();
312 registry.error_codes.insert(error_name, code);
313 }
314
315 pub fn register_error_handler(&self, error_type: String, handler: ErrorHandler) {
317 let mut registry = self.error_registry.write();
318 registry.handlers.insert(error_type, handler);
319 }
320
321 #[must_use]
323 pub fn get_error_code(&self, error_name: &str) -> Option<ErrorCode> {
324 let registry = self.error_registry.read();
325 registry.error_codes.get(error_name).copied()
326 }
327
328 pub fn handle_error(&self, error_type: &str, message: &str) {
330 let registry = self.error_registry.read();
331 if let Some(handler) = registry.handlers.get(error_type) {
332 handler(message);
333 } else {
334 tracing::error!("Unhandled error [{}]: {}", error_type, message);
335 }
336 }
337
338 #[must_use]
340 pub fn recognizer_config(&self) -> Arc<RwLock<RecognizerConfig>> {
341 Arc::clone(&self.recognizer_config)
342 }
343
344 #[must_use]
346 pub fn sdk_config(&self) -> Arc<RwLock<PipelineConfig>> {
347 Arc::clone(&self.sdk_config)
348 }
349
350 #[must_use]
352 pub fn check_resource_availability(&self) -> ResourceAvailability {
353 let state = self.shared_state.read();
354 let quotas = &state.resource_quotas;
355
356 ResourceAvailability {
357 memory_available_mb: quotas.max_memory_mb * 0.5, gpu_memory_available_mb: quotas.max_gpu_memory_mb.map(|m| m * 0.5),
359 concurrent_slots_available: quotas.max_concurrent_operations / 2,
360 cache_space_available_mb: quotas.max_cache_size_mb * 0.5,
361 }
362 }
363
364 pub fn allocate_resources(
366 &self,
367 requirements: &ResourceRequirements,
368 ) -> Result<ResourceAllocation, SdkBridgeError> {
369 let availability = self.check_resource_availability();
370
371 if requirements.memory_mb > availability.memory_available_mb {
373 return Err(SdkBridgeError::ResourceConflict(
374 "Insufficient memory".to_string(),
375 ));
376 }
377
378 if let (Some(required_gpu), Some(available_gpu)) = (
379 requirements.gpu_memory_mb,
380 availability.gpu_memory_available_mb,
381 ) {
382 if required_gpu > available_gpu {
383 return Err(SdkBridgeError::ResourceConflict(
384 "Insufficient GPU memory".to_string(),
385 ));
386 }
387 }
388
389 Ok(ResourceAllocation {
390 allocation_id: uuid::Uuid::new_v4().to_string(),
391 memory_allocated_mb: requirements.memory_mb,
392 gpu_memory_allocated_mb: requirements.gpu_memory_mb,
393 concurrent_slot: true,
394 })
395 }
396
397 pub fn release_resources(&self, allocation: &ResourceAllocation) {
399 tracing::debug!(
400 "Releasing resources: allocation_id={}, memory={} MB",
401 allocation.allocation_id,
402 allocation.memory_allocated_mb
403 );
404 }
406}
407
/// Snapshot of resources estimated to be available; sizes in megabytes.
#[derive(Debug, Clone)]
pub struct ResourceAvailability {
    /// Host memory available, in MB.
    pub memory_available_mb: f32,

    /// GPU memory available, in MB; `None` when no GPU quota is configured.
    pub gpu_memory_available_mb: Option<f32>,

    /// Number of concurrent operation slots available.
    pub concurrent_slots_available: usize,

    /// Cache space available, in MB.
    pub cache_space_available_mb: f32,
}
423
/// Resources requested for an operation; sizes in megabytes.
#[derive(Debug, Clone)]
pub struct ResourceRequirements {
    /// Host memory needed, in MB.
    pub memory_mb: f32,

    /// GPU memory needed, in MB; `None` when no GPU memory is requested.
    pub gpu_memory_mb: Option<f32>,

    /// Whether the operation needs one of the concurrent operation slots.
    pub requires_concurrent_slot: bool,
}
436
/// Record of a granted resource allocation; sizes in megabytes.
#[derive(Debug, Clone)]
pub struct ResourceAllocation {
    /// Unique identifier of this allocation.
    pub allocation_id: String,

    /// Host memory granted, in MB.
    pub memory_allocated_mb: f32,

    /// GPU memory granted, in MB, if any was requested.
    pub gpu_memory_allocated_mb: Option<f32>,

    /// Whether a concurrent operation slot is held by this allocation.
    pub concurrent_slot: bool,
}
452
453pub struct SdkBridgeBuilder {
455 recognizer_config: Option<RecognizerConfig>,
456 sdk_config: Option<PipelineConfig>,
457 optimization_settings: OptimizationSettings,
458 resource_quotas: ResourceQuotas,
459}
460
461impl SdkBridgeBuilder {
462 #[must_use]
464 pub fn new() -> Self {
465 Self {
466 recognizer_config: None,
467 sdk_config: None,
468 optimization_settings: OptimizationSettings::default(),
469 resource_quotas: ResourceQuotas::default(),
470 }
471 }
472
473 #[must_use]
475 pub fn with_recognizer_config(mut self, config: RecognizerConfig) -> Self {
476 self.recognizer_config = Some(config);
477 self
478 }
479
480 #[must_use]
482 pub fn with_sdk_config(mut self, config: PipelineConfig) -> Self {
483 self.sdk_config = Some(config);
484 self
485 }
486
487 #[must_use]
489 pub fn with_optimization_settings(mut self, settings: OptimizationSettings) -> Self {
490 self.optimization_settings = settings;
491 self
492 }
493
494 #[must_use]
496 pub fn with_resource_quotas(mut self, quotas: ResourceQuotas) -> Self {
497 self.resource_quotas = quotas;
498 self
499 }
500
501 pub fn build(self) -> Result<VoirsSdkBridge, SdkBridgeError> {
503 let recognizer_config = self.recognizer_config.unwrap_or_default();
504 let sdk_config = self.sdk_config.unwrap_or_default();
505
506 let bridge = VoirsSdkBridge::new(recognizer_config, sdk_config)?;
507
508 bridge.update_optimization_settings(|settings| {
510 *settings = self.optimization_settings.clone();
511 })?;
512
513 bridge.update_shared_state(|state| {
515 state.resource_quotas = self.resource_quotas.clone();
516 })?;
517
518 Ok(bridge)
519 }
520}
521
522impl Default for SdkBridgeBuilder {
523 fn default() -> Self {
524 Self::new()
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Builds a bridge from default recognizer and SDK configs.
    fn default_bridge() -> VoirsSdkBridge {
        VoirsSdkBridge::new(RecognizerConfig::default(), PipelineConfig::default())
            .expect("bridge construction from default configs should succeed")
    }

    #[test]
    fn test_sdk_bridge_creation() {
        let bridge = VoirsSdkBridge::new(RecognizerConfig::default(), PipelineConfig::default());
        assert!(bridge.is_ok());
    }

    #[test]
    fn test_config_synchronization() {
        let bridge = default_bridge();

        assert!(bridge.sync_config_from_sdk().is_ok());
        assert!(bridge.sync_config_to_sdk().is_ok());
    }

    #[test]
    fn test_sync_from_sdk_propagates_gpu_flag() {
        let mut sdk_config = PipelineConfig::default();
        sdk_config.use_gpu = true;
        let bridge = VoirsSdkBridge::new(RecognizerConfig::default(), sdk_config)
            .expect("bridge construction should succeed");

        bridge
            .sync_config_from_sdk()
            .expect("sync from SDK should succeed");
        assert!(bridge.recognizer_config().read().asr.enable_gpu);
    }

    #[test]
    fn test_shared_state_update() {
        let bridge = default_bridge();

        let result = bridge.update_shared_state(|state| {
            state.thread_pool_size = 8;
        });

        assert!(result.is_ok());
        assert_eq!(bridge.shared_state().read().thread_pool_size, 8);
    }

    #[test]
    fn test_optimization_settings_update() {
        let bridge = default_bridge();

        let result = bridge.update_optimization_settings(|settings| {
            settings.enable_model_sharing = false;
        });

        assert!(result.is_ok());
        assert!(!bridge.optimization_settings().read().enable_model_sharing);
    }

    #[test]
    fn test_error_registry() {
        let bridge = default_bridge();

        bridge.register_error_code("test_error".to_string(), ErrorCode::GenericError);

        assert_eq!(
            bridge.get_error_code("test_error"),
            Some(ErrorCode::GenericError)
        );
        assert_eq!(bridge.get_error_code("unknown_error"), None);
    }

    #[test]
    fn test_error_handler_dispatch() {
        let bridge = default_bridge();
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);

        bridge.register_error_handler(
            "io".to_string(),
            Arc::new(move |message: &str| {
                assert_eq!(message, "disk full");
                counter.fetch_add(1, Ordering::SeqCst);
            }),
        );

        bridge.handle_error("io", "disk full");
        // No handler registered: must fall back to logging, not invoke the "io" handler.
        bridge.handle_error("unregistered", "only logged");

        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_resource_allocation() {
        let bridge = default_bridge();

        let requirements = ResourceRequirements {
            memory_mb: 100.0,
            gpu_memory_mb: Some(200.0),
            requires_concurrent_slot: true,
        };

        let allocation = bridge
            .allocate_resources(&requirements)
            .expect("requirements fit within default quotas");
        assert_eq!(allocation.gpu_memory_allocated_mb, Some(200.0));
        assert!(allocation.concurrent_slot);

        bridge.release_resources(&allocation);
    }

    #[test]
    fn test_resource_exhaustion() {
        let bridge = default_bridge();

        let requirements = ResourceRequirements {
            memory_mb: 10000.0,
            gpu_memory_mb: None,
            requires_concurrent_slot: true,
        };

        let allocation = bridge.allocate_resources(&requirements);
        assert!(matches!(
            allocation,
            Err(SdkBridgeError::ResourceConflict(_))
        ));
    }

    #[test]
    fn test_builder_pattern() {
        let builder = SdkBridgeBuilder::new()
            .with_recognizer_config(RecognizerConfig::default())
            .with_sdk_config(PipelineConfig::default())
            .with_optimization_settings(OptimizationSettings {
                enable_model_sharing: false,
                ..OptimizationSettings::default()
            });

        let bridge = builder.build().expect("builder with defaults should succeed");
        assert!(!bridge.optimization_settings().read().enable_model_sharing);
    }
}