voirs_recognizer/sdk_bridge.rs

//! # `VoiRS` SDK Bridge
//!
//! Deep integration bridge between voirs-recognizer and voirs-sdk.
//! Provides standardized interfaces for:
//! - Cross-crate configuration synchronization
//! - Common error handling patterns
//! - Unified performance optimization
//! - Shared resource management
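//!
//! ## Example
//!
//! A minimal usage sketch, mirroring the unit tests at the bottom of this file
//! (marked `ignore`; it assumes the items below are already in scope):
//!
//! ```ignore
//! // Build a bridge from default configurations.
//! let bridge = VoirsSdkBridge::new(RecognizerConfig::default(), PipelineConfig::default())
//!     .expect("bridge construction with default configs");
//!
//! // Pull common settings (GPU usage, thread count) from the SDK pipeline
//! // configuration into the recognizer configuration.
//! bridge.sync_config_from_sdk().expect("config sync");
//!
//! // Cross-crate state is exposed behind shared `Arc<RwLock<_>>` handles.
//! let threads = bridge.shared_state().read().thread_pool_size;
//! ```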

use crate::config::RecognizerConfig;
use crate::RecognitionError;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
use voirs_sdk::config::PipelineConfig;

/// SDK Bridge errors
#[derive(Debug, Error)]
pub enum SdkBridgeError {
    /// Configuration synchronization failed
    #[error("Configuration sync failed: {0}")]
    ConfigSyncFailed(String),

    /// SDK communication error
    #[error("SDK communication error: {0}")]
    CommunicationError(String),

    /// Resource conflict
    #[error("Resource conflict: {0}")]
    ResourceConflict(String),

    /// Compatibility issue
    #[error("Compatibility issue: {0}")]
    CompatibilityIssue(String),

    /// Recognition error wrapper
    #[error("Recognition error: {0}")]
    RecognitionError(#[from] RecognitionError),
}

/// SDK Bridge for deep ecosystem integration
#[derive(Debug)]
pub struct VoirsSdkBridge {
    /// Recognizer configuration
    recognizer_config: Arc<RwLock<RecognizerConfig>>,

    /// SDK pipeline configuration
    sdk_config: Arc<RwLock<PipelineConfig>>,

    /// Cross-crate shared state
    shared_state: Arc<RwLock<SharedState>>,

    /// Performance optimization settings
    optimization_settings: Arc<RwLock<OptimizationSettings>>,

    /// Error mapping registry
    error_registry: Arc<RwLock<ErrorRegistry>>,
}

/// Shared state across `VoiRS` crates
#[derive(Debug, Clone, Default)]
pub struct SharedState {
    /// Shared model cache paths
    pub model_cache_paths: HashMap<String, String>,

    /// Shared audio buffer pool size
    pub audio_buffer_pool_size: usize,

    /// Shared GPU device IDs
    pub gpu_device_ids: Vec<i32>,

    /// Shared thread pool size
    pub thread_pool_size: usize,

    /// Feature flags enabled across crates
    pub feature_flags: HashMap<String, bool>,

    /// Resource quotas
    pub resource_quotas: ResourceQuotas,
}

/// Resource quotas for cross-crate resource management
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceQuotas {
    /// Maximum total memory allocation in MB
    pub max_memory_mb: f32,

    /// Maximum GPU memory allocation in MB
    pub max_gpu_memory_mb: Option<f32>,

    /// Maximum concurrent operations
    pub max_concurrent_operations: usize,

    /// Maximum cache size in MB
    pub max_cache_size_mb: f32,
}

impl Default for ResourceQuotas {
    fn default() -> Self {
        Self {
            max_memory_mb: 2048.0,
            max_gpu_memory_mb: Some(4096.0),
            max_concurrent_operations: 4,
            max_cache_size_mb: 512.0,
        }
    }
}

/// Cross-crate optimization settings
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationSettings {
    /// Enable cross-crate model sharing
    pub enable_model_sharing: bool,

    /// Enable unified memory pooling
    pub enable_unified_memory_pool: bool,

    /// Enable cross-crate caching
    pub enable_cross_crate_cache: bool,

    /// Enable GPU resource pooling
    pub enable_gpu_pooling: bool,

    /// Batch size coordination
    pub coordinated_batch_size: Option<usize>,

    /// Unified precision policy (fp32, fp16, int8)
    pub unified_precision: PrecisionPolicy,

    /// Enable parallel pipeline execution
    pub enable_parallel_pipelines: bool,

    /// Memory pressure threshold (0.0-1.0)
    pub memory_pressure_threshold: f32,
}

impl Default for OptimizationSettings {
    fn default() -> Self {
        Self {
            enable_model_sharing: true,
            enable_unified_memory_pool: true,
            enable_cross_crate_cache: true,
            enable_gpu_pooling: true,
            coordinated_batch_size: Some(4),
            unified_precision: PrecisionPolicy::Mixed,
            enable_parallel_pipelines: false,
            memory_pressure_threshold: 0.8,
        }
    }
}

/// Unified precision policy across crates
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PrecisionPolicy {
    /// Full precision (FP32)
    Full,
    /// Half precision (FP16)
    Half,
    /// Integer quantization (INT8)
    Quantized,
    /// Mixed precision (FP32 + FP16)
    Mixed,
}

/// Error registry for standardized error handling
#[derive(Default)]
pub struct ErrorRegistry {
    /// Error code mappings
    error_codes: HashMap<String, ErrorCode>,

    /// Error handlers
    handlers: HashMap<String, ErrorHandler>,
}

impl std::fmt::Debug for ErrorRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ErrorRegistry")
            .field("error_codes", &self.error_codes)
            .field("handlers_count", &self.handlers.len())
            .finish()
    }
}

/// Standardized error codes across `VoiRS` crates
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ErrorCode {
    /// Success
    Success = 0,

    /// Generic error
    GenericError = 1000,

    /// Configuration error
    ConfigError = 2000,

    /// Model loading error
    ModelLoadError = 3000,

    /// Inference error
    InferenceError = 4000,

    /// Audio processing error
    AudioProcessingError = 5000,

    /// Memory allocation error
    MemoryError = 6000,

    /// GPU error
    GpuError = 7000,

    /// I/O error
    IoError = 8000,

    /// Timeout error
    TimeoutError = 9000,

    /// Resource exhausted error
    ResourceExhausted = 10000,
}

/// Error handler function type
pub type ErrorHandler = Arc<dyn Fn(&str) + Send + Sync>;

impl VoirsSdkBridge {
    /// Create a new SDK bridge
    pub fn new(
        recognizer_config: RecognizerConfig,
        sdk_config: PipelineConfig,
    ) -> Result<Self, SdkBridgeError> {
        Ok(Self {
            recognizer_config: Arc::new(RwLock::new(recognizer_config)),
            sdk_config: Arc::new(RwLock::new(sdk_config)),
            shared_state: Arc::new(RwLock::new(SharedState::default())),
            optimization_settings: Arc::new(RwLock::new(OptimizationSettings::default())),
            error_registry: Arc::new(RwLock::new(ErrorRegistry::default())),
        })
    }

    /// Synchronize configuration from SDK to recognizer
    pub fn sync_config_from_sdk(&self) -> Result<(), SdkBridgeError> {
        let sdk_cfg = self.sdk_config.read();
        let mut recognizer_cfg = self.recognizer_config.write();

        // Synchronize common settings
        let device_str = sdk_cfg.device.to_lowercase();
        recognizer_cfg.asr.enable_gpu = sdk_cfg.use_gpu
            || device_str.contains("gpu")
            || device_str.contains("cuda")
            || device_str.contains("metal");

        // Sync thread count if available
        if let Some(num_threads) = sdk_cfg.num_threads {
            recognizer_cfg.performance.num_threads = num_threads;
        }

        Ok(())
    }

    /// Synchronize configuration from recognizer to SDK
    pub fn sync_config_to_sdk(&self) -> Result<(), SdkBridgeError> {
        let recognizer_cfg = self.recognizer_config.read();
        let mut _sdk_cfg = self.sdk_config.write();

        // Placeholder: a full implementation would write ASR settings back into
        // the SDK configuration; for now, only log the values that would be synced.
        tracing::debug!(
            "Syncing recognizer config to SDK: GPU={}, batch_size={}",
            recognizer_cfg.asr.enable_gpu,
            recognizer_cfg.asr.batch_size
        );

        Ok(())
    }

    /// Get shared state
    #[must_use]
    pub fn shared_state(&self) -> Arc<RwLock<SharedState>> {
        Arc::clone(&self.shared_state)
    }

    /// Get optimization settings
    #[must_use]
    pub fn optimization_settings(&self) -> Arc<RwLock<OptimizationSettings>> {
        Arc::clone(&self.optimization_settings)
    }

    /// Update shared state
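    ///
    /// A minimal sketch of the closure-based update, mirroring `test_shared_state_update`
    /// below (marked `ignore`; assumes an already-constructed `bridge`):
    ///
    /// ```ignore
    /// bridge.update_shared_state(|state| {
    ///     state.thread_pool_size = 8;
    ///     // "low_latency" is an illustrative flag name, not one defined by the crate.
    ///     state.feature_flags.insert("low_latency".to_string(), true);
    /// })?;
    /// assert_eq!(bridge.shared_state().read().thread_pool_size, 8);
    /// ```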
    pub fn update_shared_state<F>(&self, updater: F) -> Result<(), SdkBridgeError>
    where
        F: FnOnce(&mut SharedState),
    {
        let mut state = self.shared_state.write();
        updater(&mut state);
        Ok(())
    }

    /// Update optimization settings
    pub fn update_optimization_settings<F>(&self, updater: F) -> Result<(), SdkBridgeError>
    where
        F: FnOnce(&mut OptimizationSettings),
    {
        let mut settings = self.optimization_settings.write();
        updater(&mut settings);
        Ok(())
    }

    /// Register error code mapping
    pub fn register_error_code(&self, error_name: String, code: ErrorCode) {
        let mut registry = self.error_registry.write();
        registry.error_codes.insert(error_name, code);
    }

    /// Register error handler
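    ///
    /// A minimal sketch (marked `ignore`; assumes an already-constructed `bridge`,
    /// and "model_load" is an illustrative error-type key, not a defined constant):
    ///
    /// ```ignore
    /// use std::sync::Arc;
    ///
    /// bridge.register_error_handler(
    ///     "model_load".to_string(),
    ///     Arc::new(|message: &str| tracing::warn!("model load problem: {message}")),
    /// );
    /// // Errors of that type are now routed to the registered closure.
    /// bridge.handle_error("model_load", "checkpoint not found");
    /// ```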
    pub fn register_error_handler(&self, error_type: String, handler: ErrorHandler) {
        let mut registry = self.error_registry.write();
        registry.handlers.insert(error_type, handler);
    }

    /// Get error code for error name
    #[must_use]
    pub fn get_error_code(&self, error_name: &str) -> Option<ErrorCode> {
        let registry = self.error_registry.read();
        registry.error_codes.get(error_name).copied()
    }

    /// Handle error with registered handlers
    pub fn handle_error(&self, error_type: &str, message: &str) {
        let registry = self.error_registry.read();
        if let Some(handler) = registry.handlers.get(error_type) {
            handler(message);
        } else {
            tracing::error!("Unhandled error [{}]: {}", error_type, message);
        }
    }

    /// Get a shared handle to the recognizer configuration
    #[must_use]
    pub fn recognizer_config(&self) -> Arc<RwLock<RecognizerConfig>> {
        Arc::clone(&self.recognizer_config)
    }

    /// Get a shared handle to the SDK configuration
    #[must_use]
    pub fn sdk_config(&self) -> Arc<RwLock<PipelineConfig>> {
        Arc::clone(&self.sdk_config)
    }

    /// Check resource availability
    #[must_use]
    pub fn check_resource_availability(&self) -> ResourceAvailability {
        let state = self.shared_state.read();
        let quotas = &state.resource_quotas;

        ResourceAvailability {
            memory_available_mb: quotas.max_memory_mb * 0.5, // Simplified calculation
            gpu_memory_available_mb: quotas.max_gpu_memory_mb.map(|m| m * 0.5),
            concurrent_slots_available: quotas.max_concurrent_operations / 2,
            cache_space_available_mb: quotas.max_cache_size_mb * 0.5,
        }
    }

    /// Allocate resources for an operation
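    ///
    /// # Errors
    ///
    /// Returns [`SdkBridgeError::ResourceConflict`] if the requested memory or GPU
    /// memory exceeds the currently reported availability.
    ///
    /// A minimal sketch, mirroring `test_resource_allocation` below (marked `ignore`;
    /// assumes an already-constructed `bridge`):
    ///
    /// ```ignore
    /// let requirements = ResourceRequirements {
    ///     memory_mb: 100.0,
    ///     gpu_memory_mb: Some(200.0),
    ///     requires_concurrent_slot: true,
    /// };
    /// let allocation = bridge.allocate_resources(&requirements)?;
    /// // ... perform the recognition work ...
    /// bridge.release_resources(&allocation);
    /// ```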
    pub fn allocate_resources(
        &self,
        requirements: &ResourceRequirements,
    ) -> Result<ResourceAllocation, SdkBridgeError> {
        let availability = self.check_resource_availability();

        // Check if resources are available
        if requirements.memory_mb > availability.memory_available_mb {
            return Err(SdkBridgeError::ResourceConflict(
                "Insufficient memory".to_string(),
            ));
        }

        if let (Some(required_gpu), Some(available_gpu)) = (
            requirements.gpu_memory_mb,
            availability.gpu_memory_available_mb,
        ) {
            if required_gpu > available_gpu {
                return Err(SdkBridgeError::ResourceConflict(
                    "Insufficient GPU memory".to_string(),
                ));
            }
        }

        Ok(ResourceAllocation {
            allocation_id: uuid::Uuid::new_v4().to_string(),
            memory_allocated_mb: requirements.memory_mb,
            gpu_memory_allocated_mb: requirements.gpu_memory_mb,
            concurrent_slot: true,
        })
    }

    /// Release allocated resources
    pub fn release_resources(&self, allocation: &ResourceAllocation) {
        tracing::debug!(
            "Releasing resources: allocation_id={}, memory={} MB",
            allocation.allocation_id,
            allocation.memory_allocated_mb
        );
        // In a real implementation, this would update resource tracking
    }
}

/// Resource availability information
#[derive(Debug, Clone)]
pub struct ResourceAvailability {
    /// Available memory in MB
    pub memory_available_mb: f32,

    /// Available GPU memory in MB
    pub gpu_memory_available_mb: Option<f32>,

    /// Available concurrent operation slots
    pub concurrent_slots_available: usize,

    /// Available cache space in MB
    pub cache_space_available_mb: f32,
}

/// Resource requirements
#[derive(Debug, Clone)]
pub struct ResourceRequirements {
    /// Memory requirement in MB
    pub memory_mb: f32,

    /// GPU memory requirement in MB
    pub gpu_memory_mb: Option<f32>,

    /// Requires concurrent slot
    pub requires_concurrent_slot: bool,
}

/// Resource allocation result
#[derive(Debug, Clone)]
pub struct ResourceAllocation {
    /// Allocation ID
    pub allocation_id: String,

    /// Memory allocated in MB
    pub memory_allocated_mb: f32,

    /// GPU memory allocated in MB
    pub gpu_memory_allocated_mb: Option<f32>,

    /// Concurrent slot allocated
    pub concurrent_slot: bool,
}

/// Builder for SDK bridge configuration
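///
/// A minimal sketch, mirroring `test_builder_pattern` below (marked `ignore`;
/// the 1024 MB quota is an illustrative value, not a recommended default):
///
/// ```ignore
/// let bridge = SdkBridgeBuilder::new()
///     .with_recognizer_config(RecognizerConfig::default())
///     .with_sdk_config(PipelineConfig::default())
///     .with_resource_quotas(ResourceQuotas {
///         max_memory_mb: 1024.0,
///         ..ResourceQuotas::default()
///     })
///     .build()?;
/// assert_eq!(bridge.shared_state().read().resource_quotas.max_memory_mb, 1024.0);
/// ```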
pub struct SdkBridgeBuilder {
    recognizer_config: Option<RecognizerConfig>,
    sdk_config: Option<PipelineConfig>,
    optimization_settings: OptimizationSettings,
    resource_quotas: ResourceQuotas,
}

impl SdkBridgeBuilder {
    /// Create a new builder
    #[must_use]
    pub fn new() -> Self {
        Self {
            recognizer_config: None,
            sdk_config: None,
            optimization_settings: OptimizationSettings::default(),
            resource_quotas: ResourceQuotas::default(),
        }
    }

    /// Set recognizer configuration
    #[must_use]
    pub fn with_recognizer_config(mut self, config: RecognizerConfig) -> Self {
        self.recognizer_config = Some(config);
        self
    }

    /// Set SDK configuration
    #[must_use]
    pub fn with_sdk_config(mut self, config: PipelineConfig) -> Self {
        self.sdk_config = Some(config);
        self
    }

    /// Set optimization settings
    #[must_use]
    pub fn with_optimization_settings(mut self, settings: OptimizationSettings) -> Self {
        self.optimization_settings = settings;
        self
    }

    /// Set resource quotas
    #[must_use]
    pub fn with_resource_quotas(mut self, quotas: ResourceQuotas) -> Self {
        self.resource_quotas = quotas;
        self
    }

    /// Build the SDK bridge
    pub fn build(self) -> Result<VoirsSdkBridge, SdkBridgeError> {
        let recognizer_config = self.recognizer_config.unwrap_or_default();
        let sdk_config = self.sdk_config.unwrap_or_default();

        let bridge = VoirsSdkBridge::new(recognizer_config, sdk_config)?;

        // Apply optimization settings
        bridge.update_optimization_settings(|settings| {
            *settings = self.optimization_settings.clone();
        })?;

        // Apply resource quotas
        bridge.update_shared_state(|state| {
            state.resource_quotas = self.resource_quotas.clone();
        })?;

        Ok(bridge)
    }
}

impl Default for SdkBridgeBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sdk_bridge_creation() {
        let bridge = VoirsSdkBridge::new(RecognizerConfig::default(), PipelineConfig::default());
        assert!(bridge.is_ok());
    }

    #[test]
    fn test_config_synchronization() {
        let bridge =
            VoirsSdkBridge::new(RecognizerConfig::default(), PipelineConfig::default()).unwrap();

        assert!(bridge.sync_config_from_sdk().is_ok());
        assert!(bridge.sync_config_to_sdk().is_ok());
    }

    #[test]
    fn test_shared_state_update() {
        let bridge =
            VoirsSdkBridge::new(RecognizerConfig::default(), PipelineConfig::default()).unwrap();

        let result = bridge.update_shared_state(|state| {
            state.thread_pool_size = 8;
        });

        assert!(result.is_ok());
        assert_eq!(bridge.shared_state().read().thread_pool_size, 8);
    }

    #[test]
    fn test_optimization_settings_update() {
        let bridge =
            VoirsSdkBridge::new(RecognizerConfig::default(), PipelineConfig::default()).unwrap();

        let result = bridge.update_optimization_settings(|settings| {
            settings.enable_model_sharing = false;
        });

        assert!(result.is_ok());
        assert!(!bridge.optimization_settings().read().enable_model_sharing);
    }

    #[test]
    fn test_error_registry() {
        let bridge =
            VoirsSdkBridge::new(RecognizerConfig::default(), PipelineConfig::default()).unwrap();

        bridge.register_error_code("test_error".to_string(), ErrorCode::GenericError);

        let code = bridge.get_error_code("test_error");
        assert_eq!(code, Some(ErrorCode::GenericError));
    }

    #[test]
    fn test_resource_allocation() {
        let bridge =
            VoirsSdkBridge::new(RecognizerConfig::default(), PipelineConfig::default()).unwrap();

        let requirements = ResourceRequirements {
            memory_mb: 100.0,
            gpu_memory_mb: Some(200.0),
            requires_concurrent_slot: true,
        };

        let allocation = bridge.allocate_resources(&requirements);
        assert!(allocation.is_ok());

        if let Ok(alloc) = allocation {
            bridge.release_resources(&alloc);
        }
    }

    #[test]
    fn test_resource_exhaustion() {
        let bridge =
            VoirsSdkBridge::new(RecognizerConfig::default(), PipelineConfig::default()).unwrap();

        // Try to allocate more memory than available
        let requirements = ResourceRequirements {
            memory_mb: 10000.0, // More than default quota
            gpu_memory_mb: None,
            requires_concurrent_slot: true,
        };

        let allocation = bridge.allocate_resources(&requirements);
        assert!(allocation.is_err());
    }

    #[test]
    fn test_builder_pattern() {
        let builder = SdkBridgeBuilder::new()
            .with_recognizer_config(RecognizerConfig::default())
            .with_sdk_config(PipelineConfig::default())
            .with_optimization_settings(OptimizationSettings {
                enable_model_sharing: false,
                ..OptimizationSettings::default()
            });

        let bridge = builder.build();
        assert!(bridge.is_ok());

        if let Ok(b) = bridge {
            assert!(!b.optimization_settings().read().enable_model_sharing);
        }
    }
}