1use crate::{
7 config::ConversionConfig,
8 models::ConversionModel,
9 types::{ConversionRequest, ConversionResult, ConversionType},
10 Error, Result,
11};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
15use std::sync::Arc;
16use std::sync::Weak;
17use std::time::{Duration, Instant};
18use tokio::sync::{Mutex, OwnedSemaphorePermit, RwLock, Semaphore};
19use tracing::{debug, info, trace, warn};
20
21pub struct MemorySafetyAuditor {
23 allocation_tracker: Arc<RwLock<AllocationTracker>>,
25 reference_tracker: Arc<RwLock<ReferenceTracker>>,
27 buffer_safety_monitor: Arc<RwLock<BufferSafetyMonitor>>,
29 audit_config: MemorySafetyConfig,
31}
32
/// Switches and limits that control what `MemorySafetyAuditor` checks.
#[derive(Debug, Clone)]
pub struct MemorySafetyConfig {
    /// When false, `track_allocation`/`track_deallocation` are no-ops and the
    /// allocation audit is skipped.
    pub enable_allocation_tracking: bool,
    /// When true, `perform_audit` includes the reference audit.
    pub enable_reference_cycle_detection: bool,
    /// When true, `perform_audit` includes the buffer audit.
    pub enable_buffer_bounds_checking: bool,
    /// Tracked usage (in bytes) above which the allocation audit sets
    /// `memory_threshold_exceeded` and logs a warning.
    pub max_memory_threshold: u64,
    /// When true, `perform_audit` finishes by running the automatic cleanup.
    pub enable_automatic_cleanup: bool,
    /// Period of the background audit started by `start_periodic_audit`.
    pub audit_interval: Duration,
}

impl Default for MemorySafetyConfig {
    /// All checks and automatic cleanup enabled, a 1 GiB usage threshold and a
    /// 30-second audit interval.
    fn default() -> Self {
        const ONE_GIB: u64 = 1024 * 1024 * 1024;
        const DEFAULT_INTERVAL_SECS: u64 = 30;

        MemorySafetyConfig {
            enable_allocation_tracking: true,
            enable_reference_cycle_detection: true,
            enable_buffer_bounds_checking: true,
            max_memory_threshold: ONE_GIB,
            enable_automatic_cleanup: true,
            audit_interval: Duration::from_secs(DEFAULT_INTERVAL_SECS),
        }
    }
}
62
/// Bookkeeping for tracked allocations: running counters plus the records of
/// allocations that have been registered and not yet released.
#[derive(Debug, Default)]
pub struct AllocationTracker {
    /// Incremented once per `track_allocation` call.
    pub total_allocations: AtomicU64,
    /// Incremented when a record is removed by `track_deallocation` or by the
    /// automatic cleanup.
    pub total_deallocations: AtomicU64,
    /// Sum of the sizes of the currently tracked allocations (reported as
    /// bytes in the audit's warning message).
    pub current_memory_usage: AtomicU64,
    /// Highest `current_memory_usage` value observed by `track_allocation`.
    pub peak_memory_usage: AtomicU64,
    /// Live allocation records, keyed by allocation id.
    pub active_allocations: HashMap<String, AllocationInfo>,
    /// Copied verbatim into the allocation audit result.
    /// NOTE(review): nothing in the visible code populates this map.
    pub allocation_patterns: HashMap<String, AllocationPattern>,
    /// Counted by `get_safety_status`.
    /// NOTE(review): nothing in the visible code pushes to this list.
    pub detected_leaks: Vec<MemoryLeak>,
}
81
/// Record describing one tracked allocation.
#[derive(Debug, Clone)]
pub struct AllocationInfo {
    /// Caller-supplied identifier; also the key in `active_allocations`.
    pub allocation_id: String,
    /// Size added to the tracker's `current_memory_usage` for this record.
    pub size: u64,
    /// Moment `track_allocation` created the record; audits derive the
    /// allocation's age from it.
    pub timestamp: Instant,
    /// Caller-supplied description of where the allocation was made.
    pub location: String,
    /// Caller-supplied category of the allocation.
    pub allocation_type: AllocationType,
    /// Id of the thread that executed `track_allocation`.
    pub thread_id: std::thread::ThreadId,
}
98
/// Category attached to a tracked allocation by the caller of
/// `track_allocation`.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so values can be used as
/// map/set keys (for example to group allocations by category).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AllocationType {
    AudioBuffer,
    ModelData,
    ConversionCache,
    TemporaryBuffer,
    ConfigurationData,
    MetricsData,
    /// Any category not covered above, described by a free-form label.
    Other(String),
}
117
/// Aggregate statistics for a named group of allocations.
///
/// NOTE(review): the visible code only clones these values into audit
/// results; how they are computed (and the units of `frequency`) is not shown
/// here.
#[derive(Debug, Clone)]
pub struct AllocationPattern {
    /// Name of the pattern.
    pub pattern_name: String,
    /// Number of allocations attributed to the pattern.
    pub allocation_count: u32,
    /// Average allocation size — presumably bytes, matching `AllocationInfo::size`.
    pub average_size: u64,
    /// Combined size of the allocations in the pattern.
    pub total_size: u64,
    /// Allocation frequency — unit not established by the visible code.
    pub frequency: f64,
    /// Typical time between allocation and release.
    pub typical_lifetime: Duration,
}
134
/// A suspected leak reported by the allocation audit for an allocation that
/// has stayed tracked for more than 300 seconds.
#[derive(Debug, Clone)]
pub struct MemoryLeak {
    /// `"leak_"` followed by the allocation's key in `active_allocations`.
    pub leak_id: String,
    /// Copy of the allocation record that triggered the report.
    pub allocation_info: AllocationInfo,
    /// Time at which the audit flagged the allocation.
    pub leak_detected_at: Instant,
    /// Age of the allocation at the moment it was flagged.
    pub estimated_leak_duration: Duration,
    /// Severity derived from the allocation size.
    pub severity: LeakSeverity,
}
149
/// Severity assigned to a suspected leak.
///
/// The allocation audit derives it from the allocation size: above 100 MiB is
/// `Critical`, above 10 MiB `High`, above 1 MiB `Medium`, otherwise `Low`.
/// The enum is a plain tag, so it is `Copy`, `Eq` and `Hash`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LeakSeverity {
    Low,
    Medium,
    High,
    Critical,
}
162
/// Bookkeeping for tracked strong and weak references and recorded cycles.
#[derive(Debug, Default)]
pub struct ReferenceTracker {
    /// Tracked strong references. The reference audit counts them and reports
    /// entries not accessed for more than 600 seconds as orphaned.
    /// Keys are presumably reference ids — TODO confirm against the code that inserts.
    pub strong_references: HashMap<String, ReferenceInfo>,
    /// Tracked weak references; the reference audit only counts them.
    pub weak_references: HashMap<String, WeakReferenceInfo>,
    /// Recorded cycles; copied into the audit result and counted by
    /// `get_safety_status`.
    pub detected_cycles: Vec<ReferenceCycle>,
    /// Copied verbatim into the reference audit result.
    pub reference_patterns: HashMap<String, ReferencePattern>,
}
175
/// Record describing one tracked strong reference.
#[derive(Debug, Clone)]
pub struct ReferenceInfo {
    /// Identifier of the reference.
    pub reference_id: String,
    /// Name of the referenced object's type.
    pub object_type: String,
    /// When the reference was created.
    pub created_at: Instant,
    /// Last recorded access; the reference audit measures idle time from it.
    pub last_accessed: Instant,
    /// Number of recorded accesses.
    pub access_count: u32,
    /// Description of where the reference was created.
    pub source_location: String,
    /// Identifiers forming the chain this reference belongs to —
    /// NOTE(review): exact meaning is not established by the visible code.
    pub reference_chain: Vec<String>,
}
194
/// Record describing one tracked weak reference.
///
/// NOTE(review): the visible code never reads or updates these fields beyond
/// counting the records; the descriptions follow the field names.
#[derive(Debug, Clone)]
pub struct WeakReferenceInfo {
    /// Identifier of the reference.
    pub reference_id: String,
    /// Name of the referenced object's type.
    pub object_type: String,
    /// When the weak reference was created.
    pub created_at: Instant,
    /// Whether the target is still considered alive.
    pub is_valid: bool,
    /// Number of upgrade attempts recorded.
    pub upgrade_attempts: u32,
    /// Number of upgrade attempts that succeeded.
    pub successful_upgrades: u32,
}
211
/// A recorded reference cycle. Each one lowers the safety score by 10 points.
#[derive(Debug, Clone)]
pub struct ReferenceCycle {
    /// Identifier of the cycle.
    pub cycle_id: String,
    /// Identifiers of the objects that take part in the cycle.
    pub objects_in_cycle: Vec<String>,
    /// Length of the cycle — presumably `objects_in_cycle.len()`; TODO confirm.
    pub cycle_length: usize,
    /// When the cycle was recorded.
    pub detected_at: Instant,
    /// Classification of the cycle.
    pub cycle_type: CycleType,
    /// Estimated memory kept alive by the cycle — presumably bytes; TODO confirm.
    pub estimated_memory_impact: u64,
}
228
/// Classification of a recorded reference cycle.
///
/// NOTE(review): no visible code constructs these variants; their precise
/// meaning follows the names only. The enum is a plain tag, so it is `Copy`,
/// `Eq` and `Hash`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CycleType {
    DirectCycle,
    IndirectCycle,
    ComplexCycle,
}
239
/// Aggregate statistics for a named group of references.
///
/// NOTE(review): the visible code only clones these values into audit
/// results; how they are computed is not shown here.
#[derive(Debug, Clone)]
pub struct ReferencePattern {
    /// Name of the pattern.
    pub pattern_name: String,
    /// How often references of this pattern are created — unit not
    /// established by the visible code.
    pub creation_frequency: f64,
    /// Average lifetime of references in the pattern.
    pub average_lifetime: Duration,
    /// How references in the pattern are typically accessed.
    pub typical_access_pattern: AccessPattern,
    /// Reference chains commonly seen for this pattern.
    pub common_reference_chains: Vec<Vec<String>>,
}
254
/// How a group of references is typically accessed over time.
///
/// NOTE(review): no visible code constructs these variants; their meaning
/// follows the names only. The enum is a plain tag, so it is `Copy`, `Eq` and
/// `Hash`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccessPattern {
    SingleAccess,
    BurstAccess,
    SteadyAccess,
    DecreasingAccess,
    PeriodicAccess,
}
269
/// Bookkeeping for buffer-related safety events.
#[derive(Debug, Default)]
pub struct BufferSafetyMonitor {
    /// Recorded bounds violations; copied into the buffer audit result.
    pub bounds_violations: Vec<BoundsViolation>,
    /// Per-buffer statistics; copied into the buffer audit result.
    /// Keys are presumably buffer ids — TODO confirm.
    pub buffer_stats: HashMap<String, BufferStats>,
    /// Recorded unsafe operations; copied into the buffer audit result.
    pub unsafe_operations: Vec<UnsafeOperation>,
    /// Per-buffer lifecycle records. The buffer audit reports entries that
    /// are not `Dropped` and are older than 1800 seconds as long-lived.
    pub buffer_lifecycle: HashMap<String, BufferLifecycle>,
}
282
/// A recorded out-of-bounds or lifetime violation on a buffer.
#[derive(Debug, Clone)]
pub struct BoundsViolation {
    /// Identifier of the violation.
    pub violation_id: String,
    /// Identifier of the buffer involved.
    pub buffer_id: String,
    /// Kind of violation.
    pub violation_type: ViolationType,
    /// Index that was attempted (signed).
    pub attempted_index: isize,
    /// Size of the buffer when the violation was recorded.
    pub buffer_size: usize,
    /// Textual stack trace captured with the violation.
    pub stack_trace: String,
    /// When the violation was recorded.
    pub detected_at: Instant,
    /// Severity; determines how many points the safety score loses.
    pub severity: ViolationSeverity,
}
303
/// Kind of buffer violation that was recorded.
///
/// The enum is a plain tag, so it is `Copy`, `Eq` and `Hash`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ViolationType {
    ReadBeyondBounds,
    WriteBeyondBounds,
    NegativeIndex,
    UseAfterFree,
    DoubleFree,
}
318
/// Severity of a recorded bounds violation.
///
/// The safety score loses 5 points per `Warning`, 15 per `Error` and 30 per
/// `Critical`. The enum is a plain tag, so it is `Copy`, `Eq` and `Hash`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ViolationSeverity {
    Warning,
    Error,
    Critical,
}
329
/// Access statistics for one buffer.
///
/// NOTE(review): the visible code only clones these values into audit
/// results; the descriptions follow the field names.
#[derive(Debug, Clone)]
pub struct BufferStats {
    /// Identifier of the buffer.
    pub buffer_id: String,
    /// Free-form description of the buffer's kind.
    pub buffer_type: String,
    /// Size of the buffer — unit (bytes vs. elements) not established here.
    pub size: usize,
    /// Total number of recorded accesses.
    pub access_count: u32,
    /// Number of recorded reads.
    pub read_operations: u32,
    /// Number of recorded writes.
    pub write_operations: u32,
    /// Number of recorded resizes.
    pub resize_operations: u32,
    /// Time of the first recorded access.
    pub first_access: Instant,
    /// Time of the most recent recorded access.
    pub last_access: Instant,
    /// Average time between accesses.
    pub average_access_interval: Duration,
}
354
/// A recorded risky operation on a buffer.
#[derive(Debug, Clone)]
pub struct UnsafeOperation {
    /// Identifier of the recorded operation.
    pub operation_id: String,
    /// Kind of risky operation.
    pub operation_type: UnsafeOperationType,
    /// Identifier of the buffer involved.
    pub buffer_id: String,
    /// When the operation was recorded.
    pub detected_at: Instant,
    /// Risk level; determines how many points the safety score loses.
    pub risk_level: RiskLevel,
    /// Description of a mitigation that was applied, if any.
    pub mitigation_applied: Option<String>,
}
371
/// Kind of risky buffer operation that was recorded.
///
/// The enum is a plain tag, so it is `Copy`, `Eq` and `Hash`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnsafeOperationType {
    UnalignedAccess,
    RacyAccess,
    DanglingPointer,
    BufferOverflow,
    UseAfterMove,
    ConcurrentMutation,
}
388
/// Risk level of a recorded unsafe operation.
///
/// The safety score loses 3 points per `Low`, 8 per `Medium`, 15 per `High`
/// and 25 per `Critical`. The enum is a plain tag, so it is `Copy`, `Eq` and
/// `Hash`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RiskLevel {
    Low,
    Medium,
    High,
    Critical,
}
401
/// Lifecycle record of one buffer, from creation to its current state.
#[derive(Debug, Clone)]
pub struct BufferLifecycle {
    /// Identifier of the buffer.
    pub buffer_id: String,
    /// Creation time; the buffer audit measures the buffer's age from it.
    pub created_at: Instant,
    /// Recorded size changes as `(when, new size)` pairs —
    /// NOTE(review): tuple meaning inferred from the field name.
    pub size_changes: Vec<(Instant, usize)>,
    /// Recorded accesses as `(when, kind)` pairs.
    pub access_pattern: Vec<(Instant, AccessType)>,
    /// Current state; buffers in `Dropped` state are skipped by the audit.
    pub current_state: BufferState,
    /// Expected lifetime of the buffer, if known.
    pub expected_lifetime: Option<Duration>,
}
418
/// Kind of access recorded in a buffer's lifecycle.
///
/// The enum is a plain tag, so it is `Copy`, `Eq` and `Hash`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccessType {
    Read,
    Write,
    Resize,
    Clone,
    Move,
}
433
/// Current state of a tracked buffer.
///
/// The buffer audit ignores buffers in the `Dropped` state. The enum is a
/// plain tag, so it is `Copy`, `Eq` and `Hash`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BufferState {
    Active,
    Borrowed,
    Moved,
    Dropped,
}
446
447impl MemorySafetyAuditor {
448 pub fn new(config: MemorySafetyConfig) -> Self {
450 Self {
451 allocation_tracker: Arc::new(RwLock::new(AllocationTracker::default())),
452 reference_tracker: Arc::new(RwLock::new(ReferenceTracker::default())),
453 buffer_safety_monitor: Arc::new(RwLock::new(BufferSafetyMonitor::default())),
454 audit_config: config,
455 }
456 }
457
458 pub async fn start_periodic_audit(&self) -> Result<()> {
460 if !self.audit_config.enable_allocation_tracking
461 && !self.audit_config.enable_reference_cycle_detection
462 && !self.audit_config.enable_buffer_bounds_checking
463 {
464 return Ok(()); }
466
467 let auditor = Self {
468 allocation_tracker: Arc::clone(&self.allocation_tracker),
469 reference_tracker: Arc::clone(&self.reference_tracker),
470 buffer_safety_monitor: Arc::clone(&self.buffer_safety_monitor),
471 audit_config: self.audit_config.clone(),
472 };
473
474 tokio::spawn(async move {
475 let mut interval = tokio::time::interval(auditor.audit_config.audit_interval);
476
477 loop {
478 interval.tick().await;
479
480 if let Err(e) = auditor.perform_audit().await {
481 warn!("Memory safety audit failed: {}", e);
482 }
483 }
484 });
485
486 info!(
487 "Started periodic memory safety audit with interval: {:?}",
488 self.audit_config.audit_interval
489 );
490 Ok(())
491 }
492
493 pub async fn perform_audit(&self) -> Result<MemorySafetyReport> {
495 let mut report = MemorySafetyReport::default();
496
497 if self.audit_config.enable_allocation_tracking {
499 report.allocation_audit = Some(self.audit_allocations().await?);
500 }
501
502 if self.audit_config.enable_reference_cycle_detection {
504 report.reference_audit = Some(self.audit_references().await?);
505 }
506
507 if self.audit_config.enable_buffer_bounds_checking {
509 report.buffer_audit = Some(self.audit_buffers().await?);
510 }
511
512 report.overall_safety_score = self.calculate_safety_score(&report);
514 report.audit_timestamp = Instant::now();
515
516 if self.audit_config.enable_automatic_cleanup {
518 self.apply_automatic_cleanup(&report).await?;
519 }
520
521 Ok(report)
522 }
523
524 async fn audit_allocations(&self) -> Result<AllocationAuditResult> {
526 let tracker = self.allocation_tracker.read().await;
527 let mut result = AllocationAuditResult::default();
528
529 let current_time = Instant::now();
531 for (id, alloc_info) in &tracker.active_allocations {
532 let age = current_time.duration_since(alloc_info.timestamp);
533
534 if age > Duration::from_secs(300) {
536 let severity = match alloc_info.size {
537 size if size > 100 * 1024 * 1024 => LeakSeverity::Critical, size if size > 10 * 1024 * 1024 => LeakSeverity::High, size if size > 1024 * 1024 => LeakSeverity::Medium, _ => LeakSeverity::Low,
541 };
542
543 let leak = MemoryLeak {
544 leak_id: format!("leak_{}", id),
545 allocation_info: alloc_info.clone(),
546 leak_detected_at: current_time,
547 estimated_leak_duration: age,
548 severity,
549 };
550
551 result.detected_leaks.push(leak);
552 }
553 }
554
555 result.total_active_allocations = tracker.active_allocations.len();
557 result.current_memory_usage = tracker.current_memory_usage.load(Ordering::Relaxed);
558 result.peak_memory_usage = tracker.peak_memory_usage.load(Ordering::Relaxed);
559 result.allocation_patterns = tracker.allocation_patterns.clone();
560
561 if result.current_memory_usage > self.audit_config.max_memory_threshold {
563 result.memory_threshold_exceeded = true;
564 warn!(
565 "Memory usage ({} bytes) exceeds threshold ({} bytes)",
566 result.current_memory_usage, self.audit_config.max_memory_threshold
567 );
568 }
569
570 Ok(result)
571 }
572
573 async fn audit_references(&self) -> Result<ReferenceAuditResult> {
575 let tracker = self.reference_tracker.read().await;
576 let mut result = ReferenceAuditResult::default();
577
578 result.detected_cycles = tracker.detected_cycles.clone();
580 result.active_strong_references = tracker.strong_references.len();
581 result.active_weak_references = tracker.weak_references.len();
582
583 let current_time = Instant::now();
585 for (id, ref_info) in &tracker.strong_references {
586 let idle_time = current_time.duration_since(ref_info.last_accessed);
587 if idle_time > Duration::from_secs(600) {
588 result.orphaned_references.push(ref_info.clone());
590 }
591 }
592
593 result.reference_patterns = tracker.reference_patterns.clone();
595
596 Ok(result)
597 }
598
599 async fn audit_buffers(&self) -> Result<BufferAuditResult> {
601 let monitor = self.buffer_safety_monitor.read().await;
602 let mut result = BufferAuditResult::default();
603
604 result.bounds_violations = monitor.bounds_violations.clone();
605 result.unsafe_operations = monitor.unsafe_operations.clone();
606 result.buffer_statistics = monitor.buffer_stats.clone();
607
608 for (id, lifecycle) in &monitor.buffer_lifecycle {
610 if lifecycle.current_state == BufferState::Dropped {
611 continue; }
613
614 let age = Instant::now().duration_since(lifecycle.created_at);
616 if age > Duration::from_secs(1800) {
617 result.long_lived_buffers.push(lifecycle.clone());
619 }
620 }
621
622 Ok(result)
623 }
624
625 fn calculate_safety_score(&self, report: &MemorySafetyReport) -> f64 {
627 let mut score = 100.0;
628
629 if let Some(ref alloc_audit) = report.allocation_audit {
630 for leak in &alloc_audit.detected_leaks {
632 let deduction = match leak.severity {
633 LeakSeverity::Critical => 25.0,
634 LeakSeverity::High => 15.0,
635 LeakSeverity::Medium => 8.0,
636 LeakSeverity::Low => 3.0,
637 };
638 score -= deduction;
639 }
640
641 if alloc_audit.memory_threshold_exceeded {
643 score -= 20.0;
644 }
645 }
646
647 if let Some(ref ref_audit) = report.reference_audit {
648 score -= ref_audit.detected_cycles.len() as f64 * 10.0;
650
651 score -= ref_audit.orphaned_references.len() as f64 * 5.0;
653 }
654
655 if let Some(ref buf_audit) = report.buffer_audit {
656 for violation in &buf_audit.bounds_violations {
658 let deduction = match violation.severity {
659 ViolationSeverity::Critical => 30.0,
660 ViolationSeverity::Error => 15.0,
661 ViolationSeverity::Warning => 5.0,
662 };
663 score -= deduction;
664 }
665
666 for operation in &buf_audit.unsafe_operations {
668 let deduction = match operation.risk_level {
669 RiskLevel::Critical => 25.0,
670 RiskLevel::High => 15.0,
671 RiskLevel::Medium => 8.0,
672 RiskLevel::Low => 3.0,
673 };
674 score -= deduction;
675 }
676 }
677
678 score.max(0.0)
680 }
681
682 async fn apply_automatic_cleanup(&self, report: &MemorySafetyReport) -> Result<()> {
684 if let Some(ref alloc_audit) = report.allocation_audit {
685 let mut tracker = self.allocation_tracker.write().await;
687 let mut cleaned_up = 0;
688
689 let current_time = Instant::now();
690 let old_allocations: Vec<String> = tracker
691 .active_allocations
692 .iter()
693 .filter(|(_, alloc)| {
694 let age = current_time.duration_since(alloc.timestamp);
695 age > Duration::from_secs(3600) && alloc.size < 1024 * 1024 })
697 .map(|(id, _)| id.clone())
698 .collect();
699
700 for id in old_allocations {
701 if let Some(alloc_info) = tracker.active_allocations.remove(&id) {
702 let current_usage = tracker.current_memory_usage.load(Ordering::Relaxed);
703 tracker.current_memory_usage.store(
704 current_usage.saturating_sub(alloc_info.size),
705 Ordering::Relaxed,
706 );
707 tracker.total_deallocations.fetch_add(1, Ordering::Relaxed);
708 cleaned_up += 1;
709 }
710 }
711
712 if cleaned_up > 0 {
713 info!("Automatic cleanup removed {} old allocations", cleaned_up);
714 }
715 }
716
717 Ok(())
718 }
719
720 pub async fn track_allocation(
722 &self,
723 allocation_id: String,
724 size: u64,
725 location: String,
726 allocation_type: AllocationType,
727 ) -> Result<()> {
728 if !self.audit_config.enable_allocation_tracking {
729 return Ok(());
730 }
731
732 let mut tracker = self.allocation_tracker.write().await;
733
734 let alloc_info = AllocationInfo {
735 allocation_id: allocation_id.clone(),
736 size,
737 timestamp: Instant::now(),
738 location,
739 allocation_type,
740 thread_id: std::thread::current().id(),
741 };
742
743 tracker.active_allocations.insert(allocation_id, alloc_info);
744 tracker.total_allocations.fetch_add(1, Ordering::Relaxed);
745
746 let new_usage = tracker
747 .current_memory_usage
748 .fetch_add(size, Ordering::Relaxed)
749 + size;
750
751 let current_peak = tracker.peak_memory_usage.load(Ordering::Relaxed);
753 if new_usage > current_peak {
754 tracker
755 .peak_memory_usage
756 .store(new_usage, Ordering::Relaxed);
757 }
758
759 Ok(())
760 }
761
762 pub async fn track_deallocation(&self, allocation_id: &str) -> Result<()> {
764 if !self.audit_config.enable_allocation_tracking {
765 return Ok(());
766 }
767
768 let mut tracker = self.allocation_tracker.write().await;
769
770 if let Some(alloc_info) = tracker.active_allocations.remove(allocation_id) {
771 tracker.total_deallocations.fetch_add(1, Ordering::Relaxed);
772 let current_usage = tracker.current_memory_usage.load(Ordering::Relaxed);
773 tracker.current_memory_usage.store(
774 current_usage.saturating_sub(alloc_info.size),
775 Ordering::Relaxed,
776 );
777 }
778
779 Ok(())
780 }
781
782 pub async fn get_safety_status(&self) -> MemorySafetyStatus {
784 let allocation_tracker = self.allocation_tracker.read().await;
785 let reference_tracker = self.reference_tracker.read().await;
786 let buffer_monitor = self.buffer_safety_monitor.read().await;
787
788 MemorySafetyStatus {
789 current_memory_usage: allocation_tracker
790 .current_memory_usage
791 .load(Ordering::Relaxed),
792 active_allocations: allocation_tracker.active_allocations.len(),
793 detected_leaks: allocation_tracker.detected_leaks.len(),
794 active_references: reference_tracker.strong_references.len(),
795 detected_cycles: reference_tracker.detected_cycles.len(),
796 bounds_violations: buffer_monitor.bounds_violations.len(),
797 unsafe_operations: buffer_monitor.unsafe_operations.len(),
798 last_audit: Instant::now(), }
800 }
801}
802
/// Result of one `MemorySafetyAuditor::perform_audit` run.
#[derive(Debug)]
pub struct MemorySafetyReport {
    /// Allocation audit result; `None` when allocation tracking is disabled.
    pub allocation_audit: Option<AllocationAuditResult>,
    /// Reference audit result; `None` when cycle detection is disabled.
    pub reference_audit: Option<ReferenceAuditResult>,
    /// Buffer audit result; `None` when bounds checking is disabled.
    pub buffer_audit: Option<BufferAuditResult>,
    /// Score from 0 to 100; 100 means no penalties were applied.
    pub overall_safety_score: f64,
    /// When the report was completed.
    pub audit_timestamp: Instant,
}
817
818impl Default for MemorySafetyReport {
819 fn default() -> Self {
820 Self {
821 allocation_audit: None,
822 reference_audit: None,
823 buffer_audit: None,
824 overall_safety_score: 0.0,
825 audit_timestamp: Instant::now(),
826 }
827 }
828}
829
/// Outcome of the allocation audit.
#[derive(Debug, Default)]
pub struct AllocationAuditResult {
    /// Allocations that were tracked for more than 300 seconds.
    pub detected_leaks: Vec<MemoryLeak>,
    /// Number of allocation records at audit time.
    pub total_active_allocations: usize,
    /// Tracked usage at audit time (bytes).
    pub current_memory_usage: u64,
    /// Peak tracked usage at audit time.
    pub peak_memory_usage: u64,
    /// Copy of the tracker's allocation patterns.
    pub allocation_patterns: HashMap<String, AllocationPattern>,
    /// True when `current_memory_usage` is above the configured threshold.
    pub memory_threshold_exceeded: bool,
}
846
/// Outcome of the reference audit.
#[derive(Debug, Default)]
pub struct ReferenceAuditResult {
    /// Copy of the cycles recorded in the tracker.
    pub detected_cycles: Vec<ReferenceCycle>,
    /// Number of tracked strong references.
    pub active_strong_references: usize,
    /// Number of tracked weak references.
    pub active_weak_references: usize,
    /// Strong references not accessed for more than 600 seconds.
    pub orphaned_references: Vec<ReferenceInfo>,
    /// Copy of the tracker's reference patterns.
    pub reference_patterns: HashMap<String, ReferencePattern>,
}
861
/// Outcome of the buffer audit.
#[derive(Debug, Default)]
pub struct BufferAuditResult {
    /// Copy of the recorded bounds violations.
    pub bounds_violations: Vec<BoundsViolation>,
    /// Copy of the recorded unsafe operations.
    pub unsafe_operations: Vec<UnsafeOperation>,
    /// Copy of the per-buffer statistics.
    pub buffer_statistics: HashMap<String, BufferStats>,
    /// Buffers that are not dropped and older than 1800 seconds.
    pub long_lived_buffers: Vec<BufferLifecycle>,
}
874
/// Summary counters returned by `MemorySafetyAuditor::get_safety_status`.
#[derive(Debug)]
pub struct MemorySafetyStatus {
    /// Tracked memory usage at the time of the call.
    pub current_memory_usage: u64,
    /// Number of allocation records.
    pub active_allocations: usize,
    /// Length of the tracker's `detected_leaks` list.
    pub detected_leaks: usize,
    /// Number of tracked strong references.
    pub active_references: usize,
    /// Number of recorded reference cycles.
    pub detected_cycles: usize,
    /// Number of recorded bounds violations.
    pub bounds_violations: usize,
    /// Number of recorded unsafe operations.
    pub unsafe_operations: usize,
    /// Filled with the time of the status call by `get_safety_status`.
    pub last_audit: Instant,
}
895
/// Cache of loaded conversion models, shared between tasks.
pub struct ThreadSafeModelManager {
    /// Loaded models, keyed by conversion type.
    models: Arc<RwLock<HashMap<ConversionType, Arc<ConversionModel>>>>,
    /// Limits how many model loads run at the same time (2 permits, see `new`).
    loading_semaphore: Arc<Semaphore>,
    /// Hit/miss/load/eviction counters.
    stats: Arc<RwLock<ModelAccessStats>>,
    /// Cache size at which `get_model` evicts an entry before inserting.
    max_cached_models: usize,
    /// Per-model usage data that drives eviction and idle cleanup.
    usage_tracker: Arc<RwLock<HashMap<ConversionType, ModelUsageInfo>>>,
}
909
/// Counters describing how the model cache has been used.
#[derive(Debug, Clone, Default)]
pub struct ModelAccessStats {
    /// Lookups answered from the cache.
    pub cache_hits: u64,
    /// Lookups that did not find the model on the first check.
    pub cache_misses: u64,
    /// Models loaded and inserted into the cache.
    pub models_loaded: u64,
    /// Models removed by eviction, `clear_cache` or idle cleanup.
    pub models_evicted: u64,
    /// Incremented once per completed load in `get_model`.
    /// NOTE(review): despite the name, the visible code never decrements it,
    /// so it mirrors `models_loaded`.
    pub concurrent_loads: u64,
    /// Running mean of the model load durations.
    pub average_load_time: Duration,
    /// Time captured at the start of the last `cleanup_unused_models` call
    /// that evicted at least one model.
    pub last_cleanup: Option<Instant>,
}
928
/// Usage data kept for each cached model.
#[derive(Debug, Clone)]
pub struct ModelUsageInfo {
    /// Last lookup time; the primary key for eviction and idle cleanup.
    pub last_accessed: Instant,
    /// Number of lookups; starts at 1 when the model is loaded.
    pub access_count: u32,
    /// Initialised to zero on load.
    /// NOTE(review): not updated anywhere in the visible code.
    pub total_processing_time: Duration,
    /// Initialised to zero on load.
    /// NOTE(review): not updated anywhere in the visible code.
    pub average_processing_time: Duration,
    /// Set to a fixed `100 * 1024 * 1024` on load — presumably bytes.
    pub memory_usage_estimate: u64,
}
943
944impl ThreadSafeModelManager {
945 pub fn new(max_cached_models: usize) -> Self {
947 Self {
948 models: Arc::new(RwLock::new(HashMap::new())),
949 loading_semaphore: Arc::new(Semaphore::new(2)), stats: Arc::new(RwLock::new(ModelAccessStats::default())),
951 max_cached_models,
952 usage_tracker: Arc::new(RwLock::new(HashMap::new())),
953 }
954 }
955
956 pub async fn get_model(
958 &self,
959 conversion_type: &ConversionType,
960 ) -> Result<Option<Arc<ConversionModel>>> {
961 {
963 let models_guard = self.models.read().await;
964 if let Some(model) = models_guard.get(conversion_type) {
965 self.update_access_stats(conversion_type, true).await;
967 return Ok(Some(Arc::clone(model)));
968 }
969 }
970
971 self.update_access_stats(conversion_type, false).await;
973
974 let _permit = self
976 .loading_semaphore
977 .acquire()
978 .await
979 .map_err(|e| Error::runtime(format!("Failed to acquire loading permit: {}", e)))?;
980
981 {
983 let models_guard = self.models.read().await;
984 if let Some(model) = models_guard.get(conversion_type) {
985 self.update_access_stats(conversion_type, true).await;
986 return Ok(Some(Arc::clone(model)));
987 }
988 }
989
990 debug!("Loading model for conversion type: {:?}", conversion_type);
992 let start_time = Instant::now();
993
994 let model = self.load_model_impl(conversion_type).await?;
996
997 let load_time = start_time.elapsed();
998
999 {
1001 let mut models_guard = self.models.write().await;
1002 let mut usage_guard = self.usage_tracker.write().await;
1003
1004 if models_guard.len() >= self.max_cached_models {
1006 self.evict_least_used_model(&mut models_guard, &mut usage_guard)
1007 .await;
1008 }
1009
1010 let model_arc = Arc::new(model);
1012 models_guard.insert(conversion_type.clone(), Arc::clone(&model_arc));
1013
1014 usage_guard.insert(
1016 conversion_type.clone(),
1017 ModelUsageInfo {
1018 last_accessed: Instant::now(),
1019 access_count: 1,
1020 total_processing_time: Duration::from_millis(0),
1021 average_processing_time: Duration::from_millis(0),
1022 memory_usage_estimate: 100 * 1024 * 1024, },
1024 );
1025
1026 {
1028 let mut stats_guard = self.stats.write().await;
1029 stats_guard.models_loaded += 1;
1030 stats_guard.concurrent_loads += 1;
1031 stats_guard.average_load_time = if stats_guard.models_loaded == 1 {
1032 load_time
1033 } else {
1034 Duration::from_nanos(
1035 (stats_guard.average_load_time.as_nanos() as u64
1036 * (stats_guard.models_loaded - 1)
1037 + load_time.as_nanos() as u64)
1038 / stats_guard.models_loaded,
1039 )
1040 };
1041 }
1042
1043 Ok(Some(model_arc))
1044 }
1045 }
1046
1047 async fn load_model_impl(&self, conversion_type: &ConversionType) -> Result<ConversionModel> {
1049 tokio::time::sleep(Duration::from_millis(100)).await; let model_type = match conversion_type {
1053 ConversionType::SpeakerConversion => crate::models::ModelType::NeuralVC,
1054 ConversionType::AgeTransformation => crate::models::ModelType::NeuralVC,
1055 ConversionType::GenderTransformation => crate::models::ModelType::NeuralVC,
1056 ConversionType::VoiceMorphing => crate::models::ModelType::AutoVC,
1057 ConversionType::EmotionalTransformation => crate::models::ModelType::Transformer,
1058 _ => crate::models::ModelType::Custom,
1059 };
1060
1061 Ok(ConversionModel::new(model_type))
1062 }
1063
1064 async fn evict_least_used_model(
1066 &self,
1067 models_guard: &mut HashMap<ConversionType, Arc<ConversionModel>>,
1068 usage_guard: &mut HashMap<ConversionType, ModelUsageInfo>,
1069 ) {
1070 if let Some((least_used_type, _)) = usage_guard
1071 .iter()
1072 .min_by_key(|(_, usage)| (usage.last_accessed, usage.access_count))
1073 {
1074 let evicted_type = least_used_type.clone();
1075 models_guard.remove(&evicted_type);
1076 usage_guard.remove(&evicted_type);
1077
1078 {
1080 let mut stats_guard = self.stats.write().await;
1081 stats_guard.models_evicted += 1;
1082 }
1083
1084 debug!("Evicted least used model: {:?}", evicted_type);
1085 }
1086 }
1087
1088 async fn update_access_stats(&self, conversion_type: &ConversionType, cache_hit: bool) {
1090 let mut stats_guard = self.stats.write().await;
1091 if cache_hit {
1092 stats_guard.cache_hits += 1;
1093 } else {
1094 stats_guard.cache_misses += 1;
1095 }
1096
1097 drop(stats_guard);
1099 let mut usage_guard = self.usage_tracker.write().await;
1100 if let Some(usage_info) = usage_guard.get_mut(conversion_type) {
1101 usage_info.last_accessed = Instant::now();
1102 usage_info.access_count += 1;
1103 }
1104 }
1105
1106 pub async fn get_stats(&self) -> ModelAccessStats {
1108 self.stats.read().await.clone()
1109 }
1110
1111 pub async fn clear_cache(&self) {
1113 let mut models_guard = self.models.write().await;
1114 let mut usage_guard = self.usage_tracker.write().await;
1115
1116 let evicted_count = models_guard.len();
1117 models_guard.clear();
1118 usage_guard.clear();
1119
1120 {
1122 let mut stats_guard = self.stats.write().await;
1123 stats_guard.models_evicted += evicted_count as u64;
1124 }
1125
1126 info!(
1127 "Cleared all cached models: {} models evicted",
1128 evicted_count
1129 );
1130 }
1131
1132 pub async fn cleanup_unused_models(&self, max_idle_time: Duration) {
1134 let now = Instant::now();
1135 let mut models_guard = self.models.write().await;
1136 let mut usage_guard = self.usage_tracker.write().await;
1137
1138 let mut to_remove = Vec::new();
1139 for (conversion_type, usage_info) in usage_guard.iter() {
1140 if now.duration_since(usage_info.last_accessed) > max_idle_time {
1141 to_remove.push(conversion_type.clone());
1142 }
1143 }
1144
1145 let mut evicted_count = 0;
1146 for conversion_type in to_remove {
1147 models_guard.remove(&conversion_type);
1148 usage_guard.remove(&conversion_type);
1149 evicted_count += 1;
1150 debug!("Evicted idle model: {:?}", conversion_type);
1151 }
1152
1153 if evicted_count > 0 {
1154 {
1156 let mut stats_guard = self.stats.write().await;
1157 stats_guard.models_evicted += evicted_count;
1158 stats_guard.last_cleanup = Some(now);
1159 }
1160
1161 info!("Cleanup evicted {} idle models", evicted_count);
1162 }
1163 }
1164}
1165
/// Handle for one in-flight operation.
///
/// Creating the guard registers the operation in the shared state and takes
/// a semaphore permit; dropping it schedules removal of the registration.
pub struct OperationGuard {
    /// Shared registry of active operations and completion counters.
    operation_state: Arc<RwLock<OperationState>>,
    /// Permit held for the guard's lifetime; never read, only kept alive.
    _permit: OwnedSemaphorePermit,
    /// Key of this operation in `active_operations`.
    operation_id: String,
    /// Taken right after the permit was obtained; used to measure duration.
    start_time: Instant,
}
1177
/// Shared registry of running operations plus completion counters.
#[derive(Debug, Default, Clone)]
pub struct OperationState {
    /// Operations currently registered, keyed by operation id.
    pub active_operations: HashMap<String, OperationInfo>,
    /// Number of operations finalized as `Completed`.
    pub completed_operations: u64,
    /// Number of operations finalized as `Failed`.
    pub failed_operations: u64,
    /// Running mean duration of completed operations.
    pub average_duration: Duration,
}
1190
/// Registration record of one running operation.
#[derive(Debug, Clone)]
pub struct OperationInfo {
    /// Identifier of the operation; also its key in `active_operations`.
    pub operation_id: String,
    /// Conversion type the operation performs.
    pub conversion_type: ConversionType,
    /// When the operation obtained its permit.
    pub start_time: Instant,
    /// Id of the thread that registered the operation.
    pub thread_id: std::thread::ThreadId,
    /// Current status; starts as `Starting`.
    pub status: OperationStatus,
}
1205
/// Stage of an operation's life.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so statuses can be
/// compared strictly and used as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OperationStatus {
    Starting,
    Processing,
    Finalizing,
    Completed,
    /// The operation failed; carries the error message.
    Failed(String),
}
1220
1221impl OperationGuard {
1222 pub async fn new(
1224 operation_state: Arc<RwLock<OperationState>>,
1225 semaphore: Arc<Semaphore>,
1226 operation_id: String,
1227 conversion_type: ConversionType,
1228 ) -> Result<Self> {
1229 let permit = semaphore
1230 .acquire_owned()
1231 .await
1232 .map_err(|e| Error::runtime(format!("Failed to acquire operation permit: {}", e)))?;
1233
1234 let start_time = Instant::now();
1235
1236 {
1238 let mut state_guard = operation_state.write().await;
1239 state_guard.active_operations.insert(
1240 operation_id.clone(),
1241 OperationInfo {
1242 operation_id: operation_id.clone(),
1243 conversion_type,
1244 start_time,
1245 thread_id: std::thread::current().id(),
1246 status: OperationStatus::Starting,
1247 },
1248 );
1249 }
1250
1251 Ok(Self {
1252 operation_state,
1253 _permit: permit,
1254 operation_id,
1255 start_time,
1256 })
1257 }
1258
1259 pub async fn update_status(&self, status: OperationStatus) {
1261 let mut state_guard = self.operation_state.write().await;
1262 if let Some(op_info) = state_guard.active_operations.get_mut(&self.operation_id) {
1263 op_info.status = status;
1264 }
1265 }
1266
1267 pub async fn complete(&self) {
1269 self.finalize_operation(OperationStatus::Completed).await;
1270 }
1271
1272 pub async fn fail(&self, error: String) {
1274 self.finalize_operation(OperationStatus::Failed(error))
1275 .await;
1276 }
1277
1278 async fn finalize_operation(&self, final_status: OperationStatus) {
1280 let duration = self.start_time.elapsed();
1281 let mut state_guard = self.operation_state.write().await;
1282
1283 state_guard.active_operations.remove(&self.operation_id);
1285
1286 match final_status {
1288 OperationStatus::Completed => {
1289 state_guard.completed_operations += 1;
1290
1291 let total_ops = state_guard.completed_operations;
1293 if total_ops == 1 {
1294 state_guard.average_duration = duration;
1295 } else {
1296 let total_nanos = state_guard.average_duration.as_nanos() as u64
1297 * (total_ops - 1)
1298 + duration.as_nanos() as u64;
1299 state_guard.average_duration = Duration::from_nanos(total_nanos / total_ops);
1300 }
1301 }
1302 OperationStatus::Failed(_) => {
1303 state_guard.failed_operations += 1;
1304 }
1305 _ => {}
1306 }
1307
1308 debug!(
1309 "Operation {} finalized with status {:?} in {:?}",
1310 self.operation_id, final_status, duration
1311 );
1312 }
1313}
1314
1315impl Drop for OperationGuard {
1316 fn drop(&mut self) {
1317 let operation_state = Arc::clone(&self.operation_state);
1319 let operation_id = self.operation_id.clone();
1320
1321 tokio::spawn(async move {
1322 let mut state_guard = operation_state.write().await;
1323 state_guard.active_operations.remove(&operation_id);
1324 });
1325 }
1326}
1327
/// Runs conversions with a cap on how many execute at the same time.
pub struct ConcurrentConversionManager {
    /// Registry shared with every `OperationGuard` this manager creates.
    operation_state: Arc<RwLock<OperationState>>,
    /// Caps the number of concurrently running operations.
    operation_semaphore: Arc<Semaphore>,
    /// Cache that supplies the conversion models.
    model_manager: Arc<ThreadSafeModelManager>,
    /// Conversion configuration.
    /// NOTE(review): not read by the methods visible in this part of the file.
    config: Arc<RwLock<ConversionConfig>>,
    /// Request, queue-time and concurrency metrics.
    metrics: Arc<RwLock<ConcurrentConversionMetrics>>,
}
1341
/// Metrics collected by `ConcurrentConversionManager`.
#[derive(Debug, Default, Clone)]
pub struct ConcurrentConversionMetrics {
    /// Requests received, counted before they wait for a permit.
    pub total_requests: u64,
    /// Conversions that returned `Ok`.
    pub successful_conversions: u64,
    /// Conversions that returned `Err`.
    pub failed_conversions: u64,
    /// Running mean of the time spent waiting for an operation guard.
    pub average_queue_time: Duration,
    /// Running mean of the processing time of conversions.
    pub average_processing_time: Duration,
    /// Highest value `current_concurrent_operations` has reached.
    pub peak_concurrent_operations: usize,
    /// Operations currently between guard acquisition and completion.
    pub current_concurrent_operations: usize,
}
1360
1361impl ConcurrentConversionManager {
1362 pub fn new(
1364 max_concurrent_operations: usize,
1365 max_cached_models: usize,
1366 config: ConversionConfig,
1367 ) -> Self {
1368 Self {
1369 operation_state: Arc::new(RwLock::new(OperationState::default())),
1370 operation_semaphore: Arc::new(Semaphore::new(max_concurrent_operations)),
1371 model_manager: Arc::new(ThreadSafeModelManager::new(max_cached_models)),
1372 config: Arc::new(RwLock::new(config)),
1373 metrics: Arc::new(RwLock::new(ConcurrentConversionMetrics::default())),
1374 }
1375 }
1376
1377 pub async fn convert_with_concurrency_control(
1379 &self,
1380 request: ConversionRequest,
1381 ) -> Result<ConversionResult> {
1382 let queue_start = Instant::now();
1383
1384 {
1386 let mut metrics_guard = self.metrics.write().await;
1387 metrics_guard.total_requests += 1;
1388 }
1389
1390 let operation_guard = OperationGuard::new(
1392 Arc::clone(&self.operation_state),
1393 Arc::clone(&self.operation_semaphore),
1394 request.id.clone(),
1395 request.conversion_type.clone(),
1396 )
1397 .await?;
1398
1399 let queue_time = queue_start.elapsed();
1400
1401 {
1403 let mut metrics_guard = self.metrics.write().await;
1404 let total_requests = metrics_guard.total_requests;
1405 let current_avg = metrics_guard.average_queue_time;
1406
1407 metrics_guard.average_queue_time = if total_requests == 1 {
1408 queue_time
1409 } else {
1410 Duration::from_nanos(
1411 (current_avg.as_nanos() as u64 * (total_requests - 1)
1412 + queue_time.as_nanos() as u64)
1413 / total_requests,
1414 )
1415 };
1416
1417 metrics_guard.current_concurrent_operations += 1;
1418 if metrics_guard.current_concurrent_operations
1419 > metrics_guard.peak_concurrent_operations
1420 {
1421 metrics_guard.peak_concurrent_operations =
1422 metrics_guard.current_concurrent_operations;
1423 }
1424 }
1425
1426 operation_guard
1427 .update_status(OperationStatus::Processing)
1428 .await;
1429
1430 let conversion_result = match self
1432 .perform_safe_conversion(&request, &operation_guard)
1433 .await
1434 {
1435 Ok(result) => {
1436 operation_guard.complete().await;
1437
1438 {
1440 let mut metrics_guard = self.metrics.write().await;
1441 metrics_guard.successful_conversions += 1;
1442 metrics_guard.current_concurrent_operations -= 1;
1443 }
1444
1445 Ok(result)
1446 }
1447 Err(e) => {
1448 operation_guard.fail(e.to_string()).await;
1449
1450 {
1452 let mut metrics_guard = self.metrics.write().await;
1453 metrics_guard.failed_conversions += 1;
1454 metrics_guard.current_concurrent_operations -= 1;
1455 }
1456
1457 Err(e)
1458 }
1459 };
1460
1461 conversion_result
1462 }
1463
1464 async fn perform_safe_conversion(
1466 &self,
1467 request: &ConversionRequest,
1468 operation_guard: &OperationGuard,
1469 ) -> Result<ConversionResult> {
1470 let processing_start = Instant::now();
1471
1472 let model = self
1474 .model_manager
1475 .get_model(&request.conversion_type)
1476 .await?;
1477
1478 operation_guard
1479 .update_status(OperationStatus::Finalizing)
1480 .await;
1481
1482 let converted_audio = self.simulate_conversion(&request.source_audio).await?;
1484
1485 let processing_time = processing_start.elapsed();
1486
1487 {
1489 let mut metrics_guard = self.metrics.write().await;
1490 let successful_conversions = metrics_guard.successful_conversions + 1; let current_avg = metrics_guard.average_processing_time;
1492
1493 metrics_guard.average_processing_time = if successful_conversions == 1 {
1494 processing_time
1495 } else {
1496 Duration::from_nanos(
1497 (current_avg.as_nanos() as u64 * (successful_conversions - 1)
1498 + processing_time.as_nanos() as u64)
1499 / successful_conversions,
1500 )
1501 };
1502 }
1503
1504 Ok(ConversionResult {
1506 request_id: request.id.clone(),
1507 converted_audio,
1508 output_sample_rate: 22050, quality_metrics: HashMap::new(),
1510 artifacts: None,
1511 objective_quality: None,
1512 processing_time,
1513 conversion_type: request.conversion_type.clone(),
1514 success: true,
1515 error_message: None,
1516 timestamp: std::time::SystemTime::now(),
1517 })
1518 }
1519
1520 async fn simulate_conversion(&self, source_audio: &[f32]) -> Result<Vec<f32>> {
1522 tokio::time::sleep(Duration::from_millis(10)).await;
1524
1525 let mut result = source_audio.to_vec();
1527 for sample in &mut result {
1528 *sample *= 0.9; }
1530
1531 Ok(result)
1532 }
1533
1534 pub async fn get_metrics(&self) -> ConcurrentConversionMetrics {
1536 let metrics_guard = self.metrics.read().await;
1537 metrics_guard.clone()
1538 }
1539
1540 pub async fn get_operation_state(&self) -> OperationState {
1542 let state_guard = self.operation_state.read().await;
1543 state_guard.clone()
1544 }
1545
1546 pub async fn update_config(&self, new_config: ConversionConfig) -> Result<()> {
1548 let mut config_guard = self.config.write().await;
1549 *config_guard = new_config;
1550 info!("Configuration updated successfully");
1551 Ok(())
1552 }
1553
1554 pub async fn get_config(&self) -> ConversionConfig {
1556 self.config.read().await.clone()
1557 }
1558
1559 pub async fn health_check(&self) -> HashMap<String, String> {
1561 let mut health = HashMap::new();
1562
1563 let metrics = self.get_metrics().await;
1564 let operation_state = self.get_operation_state().await;
1565 let model_stats = self.model_manager.get_stats().await;
1566
1567 health.insert("status".to_string(), "healthy".to_string());
1568 health.insert(
1569 "total_requests".to_string(),
1570 metrics.total_requests.to_string(),
1571 );
1572 health.insert(
1573 "success_rate".to_string(),
1574 format!(
1575 "{:.2}%",
1576 if metrics.total_requests > 0 {
1577 (metrics.successful_conversions as f64 / metrics.total_requests as f64) * 100.0
1578 } else {
1579 100.0
1580 }
1581 ),
1582 );
1583 health.insert(
1584 "active_operations".to_string(),
1585 operation_state.active_operations.len().to_string(),
1586 );
1587 health.insert(
1588 "cached_models".to_string(),
1589 format!(
1590 "{}/{}",
1591 model_stats.cache_hits + model_stats.cache_misses - model_stats.models_evicted,
1592 self.model_manager.max_cached_models
1593 ),
1594 );
1595 health.insert(
1596 "model_cache_hit_rate".to_string(),
1597 format!(
1598 "{:.2}%",
1599 if model_stats.cache_hits + model_stats.cache_misses > 0 {
1600 (model_stats.cache_hits as f64
1601 / (model_stats.cache_hits + model_stats.cache_misses) as f64)
1602 * 100.0
1603 } else {
1604 0.0
1605 }
1606 ),
1607 );
1608
1609 health
1610 }
1611
1612 pub async fn shutdown(&self) -> Result<()> {
1614 info!("Starting graceful shutdown of concurrent conversion manager");
1615
1616 let shutdown_timeout = Duration::from_secs(30);
1618 let start_time = Instant::now();
1619
1620 while start_time.elapsed() < shutdown_timeout {
1621 let operation_state = self.operation_state.read().await;
1622 if operation_state.active_operations.is_empty() {
1623 break;
1624 }
1625 drop(operation_state);
1626
1627 debug!(
1628 "Waiting for {} active operations to complete",
1629 self.operation_state.read().await.active_operations.len()
1630 );
1631 tokio::time::sleep(Duration::from_millis(100)).await;
1632 }
1633
1634 self.model_manager.clear_cache().await;
1636
1637 let final_metrics = self.get_metrics().await;
1638 info!(
1639 "Concurrent conversion manager shutdown complete. Final stats: {} total requests, {} successful, {} failed",
1640 final_metrics.total_requests, final_metrics.successful_conversions, final_metrics.failed_conversions
1641 );
1642
1643 Ok(())
1644 }
1645}
1646
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ConversionTarget, VoiceCharacteristics};

    /// Builds the four-sample, 22050 Hz pitch-shift request used by these tests.
    fn pitch_shift_request(id: String) -> ConversionRequest {
        ConversionRequest::new(
            id,
            vec![0.1, -0.1, 0.2, -0.2],
            22050,
            ConversionType::PitchShift,
            ConversionTarget::new(VoiceCharacteristics::default()),
        )
    }

    /// Two lookups of the same conversion type: one miss (load), then one hit.
    #[tokio::test]
    async fn test_thread_safe_model_manager() {
        let manager = ThreadSafeModelManager::new(3);

        // Both handles stay alive until the end of the test.
        let first_lookup = manager
            .get_model(&ConversionType::PitchShift)
            .await
            .unwrap();
        assert!(first_lookup.is_some());

        let second_lookup = manager
            .get_model(&ConversionType::PitchShift)
            .await
            .unwrap();
        assert!(second_lookup.is_some());

        let stats = manager.get_stats().await;
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.models_loaded, 1);
    }

    /// A guard registers its operation on creation and removes it on completion.
    #[tokio::test]
    async fn test_operation_guard() {
        let shared_state = Arc::new(RwLock::new(OperationState::default()));

        let guard = OperationGuard::new(
            Arc::clone(&shared_state),
            Arc::new(Semaphore::new(1)),
            "test_op".to_string(),
            ConversionType::PitchShift,
        )
        .await
        .unwrap();

        // The read guard is released at the end of this statement.
        let registered = shared_state
            .read()
            .await
            .active_operations
            .contains_key("test_op");
        assert!(registered);

        guard.complete().await;

        let snapshot = shared_state.read().await;
        assert!(!snapshot.active_operations.contains_key("test_op"));
        assert_eq!(snapshot.completed_operations, 1);
    }

    /// A single request succeeds and is reflected in the metrics.
    #[tokio::test]
    async fn test_concurrent_conversion_manager() {
        let manager = ConcurrentConversionManager::new(2, 3, ConversionConfig::default());

        let outcome = manager
            .convert_with_concurrency_control(pitch_shift_request("test_request".to_string()))
            .await;
        assert!(outcome.is_ok());

        let metrics = manager.get_metrics().await;
        assert_eq!(metrics.total_requests, 1);
        assert_eq!(metrics.successful_conversions, 1);
        assert_eq!(metrics.failed_conversions, 0);
    }

    /// Five spawned requests against a limit of three all complete successfully.
    #[tokio::test]
    async fn test_concurrent_operations() {
        let manager = Arc::new(ConcurrentConversionManager::new(
            3,
            2,
            ConversionConfig::default(),
        ));

        let handles: Vec<_> = (0..5)
            .map(|i| {
                let manager = Arc::clone(&manager);
                tokio::spawn(async move {
                    let request = pitch_shift_request(format!("test_request_{}", i));
                    manager.convert_with_concurrency_control(request).await
                })
            })
            .collect();

        let mut successful = 0;
        for handle in handles {
            if handle.await.unwrap().is_ok() {
                successful += 1;
            }
        }

        assert_eq!(successful, 5);

        let metrics = manager.get_metrics().await;
        assert_eq!(metrics.total_requests, 5);
        assert_eq!(metrics.successful_conversions, 5);
    }

    /// Requesting a third model type from a two-slot cache evicts one model.
    #[tokio::test]
    async fn test_model_cache_eviction() {
        let manager = ThreadSafeModelManager::new(2);

        for conversion_type in [
            ConversionType::PitchShift,
            ConversionType::SpeedTransformation,
            ConversionType::GenderTransformation,
        ] {
            manager.get_model(&conversion_type).await.unwrap();
        }

        let stats = manager.get_stats().await;
        assert_eq!(stats.models_evicted, 1);
    }

    /// A fresh manager reports a healthy status and the expected keys.
    #[tokio::test]
    async fn test_health_check() {
        let manager = ConcurrentConversionManager::new(2, 3, ConversionConfig::default());

        let health = manager.health_check().await;
        assert_eq!(health.get("status").map(String::as_str), Some("healthy"));
        for key in ["total_requests", "success_rate", "cached_models"] {
            assert!(health.contains_key(key));
        }
    }
}