// voirs_spatial/memory.rs
1//! Memory Management and Cache Optimization Module
2//!
3//! This module provides comprehensive memory management and cache optimization
4//! for high-performance spatial audio processing, including object pools,
5//! cache-friendly data structures, and memory usage monitoring.
6
7use crate::types::Position3D;
8use scirs2_core::ndarray::{Array1, Array2};
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, VecDeque};
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13use tokio::sync::RwLock;
14
/// Memory manager for spatial audio processing.
///
/// Owns pools of reusable 1-D/2-D `f32` buffers and caches for expensive
/// computation results (HRTF interpolation, distance attenuation, room
/// impulse responses). All shared state sits behind `tokio::sync::RwLock`,
/// which is why every accessor on this type is `async`.
pub struct MemoryManager {
    /// Buffer pools, keyed by buffer length in elements
    buffer_pools: Arc<RwLock<HashMap<usize, BufferPool<f32>>>>,
    /// Pools of 2D arrays, keyed by (rows, cols)
    array2d_pools: Arc<RwLock<HashMap<(usize, usize), Array2Pool>>>,
    /// Cache manager for expensive computations
    cache_manager: Arc<RwLock<CacheManager>>,
    /// Memory usage statistics (snapshot-cloned by `get_memory_stats`)
    memory_stats: Arc<RwLock<MemoryStatistics>>,
    /// Configuration; read-only after construction
    config: MemoryConfig,
}
28
/// Buffer pool for reusing audio buffers of a single fixed length.
///
/// The buffer length itself is not stored here; `MemoryManager` keys its
/// pool map by length instead.
pub struct BufferPool<T> {
    /// Buffers ready for reuse
    available: VecDeque<Array1<T>>,
    /// Maximum number of buffers retained in `available`
    max_size: usize,
    /// Number of fresh allocations (pool misses)
    total_allocations: u64,
    /// Number of requests served by reusing a pooled buffer
    pool_hits: u64,
}
40
/// Pool of reusable 2D `f32` arrays of a single fixed shape.
pub struct Array2Pool {
    /// Arrays ready for reuse
    available: VecDeque<Array2<f32>>,
    /// (rows, cols) shape recorded at construction
    /// (informational — `MemoryManager` keys its pool map by the same tuple)
    dimensions: (usize, usize),
    /// Maximum number of arrays retained in `available`
    max_size: usize,
    /// Number of fresh allocations (pool misses)
    total_allocations: u64,
    /// Number of requests served by reusing a pooled array
    pool_hits: u64,
}
54
/// Cache manager for expensive computations.
///
/// Holds three independent caches; `max_cache_size` bounds each cache
/// separately, not their combined size.
pub struct CacheManager {
    /// HRTF interpolation cache (LRU-evicted via `last_accessed`)
    hrtf_cache: HashMap<HrtfCacheKey, HrtfCacheEntry>,
    /// Distance attenuation cache
    distance_cache: HashMap<DistanceCacheKey, f32>,
    /// Room impulse response cache
    /// NOTE(review): no code in this module reads or writes `room_cache` —
    /// verify it is populated elsewhere before relying on it.
    room_cache: HashMap<RoomCacheKey, Array1<f32>>,
    /// Aggregate hit/miss/eviction counters
    cache_stats: CacheStatistics,
    /// Maximum entry count per cache type
    max_cache_size: usize,
}
68
/// Memory management configuration.
///
/// See `MemoryConfig::default()` for the default values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Maximum number of buffers retained per pool
    /// (also used as the limit for 2D array pools)
    pub max_buffer_pool_size: usize,
    /// Maximum entry count per cache type
    pub max_cache_size: usize,
    /// Enable memory usage monitoring
    /// NOTE(review): not read by any code visible in this module — confirm use
    pub enable_monitoring: bool,
    /// Memory pressure threshold (0.0-1.0); `check_memory_pressure` triggers
    /// cleanup when the measured pressure exceeds this value
    pub memory_pressure_threshold: f32,
    /// Cache eviction policy
    pub cache_policy: CachePolicy,
    /// Buffer alignment in bytes for SIMD operations (e.g. 32 for AVX)
    /// NOTE(review): not read by any code visible in this module — confirm use
    pub buffer_alignment: usize,
}
85
/// Cache eviction policies.
///
/// NOTE(review): the eviction code visible in this module (`CacheManager`)
/// always evicts by least-recent access and never reads the configured
/// policy, so variants other than `LRU` appear to be aspirational — confirm
/// before relying on them.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum CachePolicy {
    /// Least Recently Used
    LRU,
    /// Least Frequently Used
    LFU,
    /// Time-based expiration
    TTL,
    /// Size-based eviction
    SizeBased,
}
98
/// HRTF cache key for interpolation results.
///
/// Built by `MemoryManager` from an `(azimuth, elevation, distance)` tuple;
/// the `f32` distance (meters) is quantized to whole millimeters so the key
/// can be hashed and compared exactly.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct HrtfCacheKey {
    /// Azimuth (integer units as supplied by the caller)
    azimuth: i32,
    /// Elevation (integer units as supplied by the caller)
    elevation: i32,
    /// Distance in millimeters (meters * 1000, truncated toward zero)
    distance: u32,
}
106
/// Cached HRTF interpolation result plus LRU bookkeeping.
#[derive(Debug, Clone)]
struct HrtfCacheEntry {
    /// Left-ear head-related impulse response
    left_hrir: Array1<f32>,
    /// Right-ear head-related impulse response
    right_hrir: Array1<f32>,
    /// Refreshed on every cache hit; drives LRU eviction ordering
    last_accessed: Instant,
    /// Number of times this entry has been served
    access_count: u64,
}
115
/// Distance attenuation cache key.
///
/// The `f32` distance is quantized to millimeters so the key is hashable;
/// lookups must use the exact same quantization (see
/// `MemoryManager::get_cached_distance_attenuation`).
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct DistanceCacheKey {
    distance_mm: u32, // Store distance in millimeters for precision
    model_type: u8,   // Attenuation model type
}
122
/// Room acoustic cache key.
///
/// All three components are pre-computed hashes; how they are derived is not
/// visible in this module (no code here constructs a `RoomCacheKey`).
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
struct RoomCacheKey {
    /// Hash of the room configuration
    room_hash: u64,
    /// Hash of the source position
    source_position_hash: u64,
    /// Hash of the listener position
    listener_position_hash: u64,
}
130
/// Memory usage statistics.
///
/// Cloned as a snapshot by `MemoryManager::get_memory_stats`, hence `Clone`.
#[derive(Debug, Clone)]
pub struct MemoryStatistics {
    /// Total bytes ever allocated for fresh (non-pooled) buffers
    pub total_allocated: u64,
    /// Bytes currently held by idle buffers sitting in the pools
    pub memory_in_use: u64,
    /// High-water mark of `memory_in_use` (bytes)
    pub peak_memory_usage: u64,
    /// Per-buffer-size pool statistics, keyed by buffer length
    pub buffer_pool_stats: HashMap<usize, BufferPoolStats>,
    /// Cache hit rates by cache name
    /// NOTE(review): no code in this module writes this map — confirm who does
    pub cache_hit_rates: HashMap<String, f64>,
    /// Memory pressure level (0.0-1.0); see `update_memory_stats` for the model
    pub memory_pressure: f32,
    /// When these statistics were last refreshed
    pub last_updated: Instant,
}
149
150impl Default for MemoryStatistics {
151    fn default() -> Self {
152        Self {
153            total_allocated: 0,
154            memory_in_use: 0,
155            peak_memory_usage: 0,
156            buffer_pool_stats: HashMap::new(),
157            cache_hit_rates: HashMap::new(),
158            memory_pressure: 0.0,
159            last_updated: Instant::now(),
160        }
161    }
162}
163
/// Per-buffer-size pool statistics.
#[derive(Debug, Default, Clone)]
pub struct BufferPoolStats {
    /// Fresh allocations (pool misses)
    pub total_allocations: u64,
    /// Requests served by reusing a pooled buffer
    pub pool_hits: u64,
    /// Buffers currently idle in the pool
    pub current_pool_size: usize,
    /// Pool hit rate; derived by `MemoryManager::update_buffer_stats`
    pub hit_rate: f64,
}
176
/// Aggregate cache statistics.
///
/// NOTE(review): not every counter is maintained by all code paths visible in
/// this module (e.g. `memory_usage` has no visible writer) — verify before
/// reporting these values.
#[derive(Debug, Default)]
struct CacheStatistics {
    /// Total cache requests
    total_requests: u64,
    /// Cache hits
    cache_hits: u64,
    /// Cache misses
    cache_misses: u64,
    /// Entries removed by eviction
    cache_evictions: u64,
    /// Memory used by caches (bytes)
    memory_usage: u64,
}
191
192impl Default for MemoryConfig {
193    fn default() -> Self {
194        Self {
195            max_buffer_pool_size: 128,
196            max_cache_size: 1024,
197            enable_monitoring: true,
198            memory_pressure_threshold: 0.8,
199            cache_policy: CachePolicy::LRU,
200            buffer_alignment: 32, // 32-byte alignment for AVX
201        }
202    }
203}
204
impl Default for MemoryManager {
    /// Equivalent to `MemoryManager::new(MemoryConfig::default())`.
    fn default() -> Self {
        Self::new(MemoryConfig::default())
    }
}
210
211impl MemoryManager {
212    /// Create new memory manager
213    pub fn new(config: MemoryConfig) -> Self {
214        Self {
215            buffer_pools: Arc::new(RwLock::new(HashMap::new())),
216            array2d_pools: Arc::new(RwLock::new(HashMap::new())),
217            cache_manager: Arc::new(RwLock::new(CacheManager::new(&config))),
218            memory_stats: Arc::new(RwLock::new(MemoryStatistics::default())),
219            config,
220        }
221    }
222
223    /// Get buffer from pool or create new one
224    pub async fn get_buffer(&self, size: usize) -> Array1<f32> {
225        let mut pools = self.buffer_pools.write().await;
226        let pool = pools
227            .entry(size)
228            .or_insert_with(|| BufferPool::new(size, self.config.max_buffer_pool_size));
229
230        if let Some(mut buffer) = pool.available.pop_front() {
231            // Clear the buffer for reuse
232            buffer.fill(0.0);
233            pool.pool_hits += 1;
234            self.update_buffer_stats(size, false).await;
235            buffer
236        } else {
237            // Create new buffer
238            pool.total_allocations += 1;
239            self.update_buffer_stats(size, true).await;
240            Array1::zeros(size)
241        }
242    }
243
244    /// Return buffer to pool
245    pub async fn return_buffer(&self, buffer: Array1<f32>) {
246        let size = buffer.len();
247        let mut pools = self.buffer_pools.write().await;
248
249        if let Some(pool) = pools.get_mut(&size) {
250            if pool.available.len() < pool.max_size {
251                pool.available.push_back(buffer);
252            }
253            // If pool is full, buffer will be dropped
254        }
255    }
256
257    /// Get 2D array from pool or create new one
258    pub async fn get_array2d(&self, rows: usize, cols: usize) -> Array2<f32> {
259        let dims = (rows, cols);
260        let mut pools = self.array2d_pools.write().await;
261        let pool = pools
262            .entry(dims)
263            .or_insert_with(|| Array2Pool::new(dims, self.config.max_buffer_pool_size));
264
265        if let Some(mut array) = pool.available.pop_front() {
266            // Clear the array for reuse
267            array.fill(0.0);
268            pool.pool_hits += 1;
269            array
270        } else {
271            // Create new array
272            pool.total_allocations += 1;
273            Array2::zeros(dims)
274        }
275    }
276
277    /// Return 2D array to pool
278    pub async fn return_array2d(&self, array: Array2<f32>) {
279        let dims = array.dim();
280        let mut pools = self.array2d_pools.write().await;
281
282        if let Some(pool) = pools.get_mut(&dims) {
283            if pool.available.len() < pool.max_size {
284                pool.available.push_back(array);
285            }
286        }
287    }
288
289    /// Cache HRTF interpolation result
290    pub async fn cache_hrtf(
291        &self,
292        key: (i32, i32, f32),
293        left_hrir: Array1<f32>,
294        right_hrir: Array1<f32>,
295    ) {
296        let cache_key = HrtfCacheKey {
297            azimuth: key.0,
298            elevation: key.1,
299            distance: (key.2 * 1000.0) as u32, // Store in millimeters
300        };
301
302        let entry = HrtfCacheEntry {
303            left_hrir,
304            right_hrir,
305            last_accessed: Instant::now(),
306            access_count: 1,
307        };
308
309        let mut cache_manager = self.cache_manager.write().await;
310        cache_manager.cache_hrtf(cache_key, entry).await;
311    }
312
313    /// Get cached HRTF result
314    pub async fn get_cached_hrtf(
315        &self,
316        key: (i32, i32, f32),
317    ) -> Option<(Array1<f32>, Array1<f32>)> {
318        let cache_key = HrtfCacheKey {
319            azimuth: key.0,
320            elevation: key.1,
321            distance: (key.2 * 1000.0) as u32,
322        };
323
324        let mut cache_manager = self.cache_manager.write().await;
325        cache_manager.get_hrtf(&cache_key).await
326    }
327
328    /// Cache distance attenuation result
329    pub async fn cache_distance_attenuation(
330        &self,
331        distance: f32,
332        model_type: u8,
333        attenuation: f32,
334    ) {
335        let key = DistanceCacheKey {
336            distance_mm: (distance * 1000.0) as u32,
337            model_type,
338        };
339
340        let mut cache_manager = self.cache_manager.write().await;
341        cache_manager.cache_distance(key, attenuation).await;
342    }
343
344    /// Get cached distance attenuation
345    pub async fn get_cached_distance_attenuation(
346        &self,
347        distance: f32,
348        model_type: u8,
349    ) -> Option<f32> {
350        let key = DistanceCacheKey {
351            distance_mm: (distance * 1000.0) as u32,
352            model_type,
353        };
354
355        let cache_manager = self.cache_manager.read().await;
356        cache_manager.get_distance(&key)
357    }
358
359    /// Get memory statistics
360    pub async fn get_memory_stats(&self) -> MemoryStatistics {
361        let stats = self.memory_stats.read().await;
362        stats.clone()
363    }
364
365    /// Check memory pressure and trigger cleanup if needed
366    pub async fn check_memory_pressure(&self) -> bool {
367        let stats = self.memory_stats.read().await;
368        if stats.memory_pressure > self.config.memory_pressure_threshold {
369            drop(stats); // Release read lock
370            self.cleanup_memory().await;
371            true
372        } else {
373            false
374        }
375    }
376
377    /// Cleanup memory when under pressure
378    async fn cleanup_memory(&self) {
379        // Clear least recently used cache entries
380        let mut cache_manager = self.cache_manager.write().await;
381        cache_manager
382            .evict_lru_entries(self.config.max_cache_size / 2)
383            .await;
384
385        // Trim buffer pools
386        self.trim_buffer_pools().await;
387
388        // Update statistics
389        self.update_memory_stats().await;
390    }
391
392    /// Trim buffer pools to free memory
393    async fn trim_buffer_pools(&self) {
394        let mut pools = self.buffer_pools.write().await;
395        for pool in pools.values_mut() {
396            pool.available.truncate(pool.max_size / 2);
397        }
398
399        let mut array_pools = self.array2d_pools.write().await;
400        for pool in array_pools.values_mut() {
401            pool.available.truncate(pool.max_size / 2);
402        }
403    }
404
405    /// Update buffer pool statistics
406    async fn update_buffer_stats(&self, size: usize, is_new_allocation: bool) {
407        let mut stats = self.memory_stats.write().await;
408
409        // Handle memory allocation tracking
410        if is_new_allocation {
411            stats.total_allocated += (size * std::mem::size_of::<f32>()) as u64;
412        }
413
414        // Update pool-specific stats
415        {
416            let pool_stats = stats.buffer_pool_stats.entry(size).or_default();
417            if is_new_allocation {
418                pool_stats.total_allocations += 1;
419            } else {
420                pool_stats.pool_hits += 1;
421            }
422            pool_stats.hit_rate =
423                pool_stats.pool_hits as f64 / pool_stats.total_allocations.max(1) as f64;
424        }
425
426        stats.last_updated = Instant::now();
427    }
428
429    /// Update memory statistics
430    async fn update_memory_stats(&self) {
431        let mut stats = self.memory_stats.write().await;
432
433        // Calculate memory usage from pools
434        let pools = self.buffer_pools.read().await;
435        let mut memory_in_use = 0u64;
436
437        for (size, pool) in pools.iter() {
438            let pool_memory = (pool.available.len() * size * std::mem::size_of::<f32>()) as u64;
439            memory_in_use += pool_memory;
440
441            let pool_stats = stats.buffer_pool_stats.entry(*size).or_default();
442            pool_stats.current_pool_size = pool.available.len();
443        }
444
445        stats.memory_in_use = memory_in_use;
446        if memory_in_use > stats.peak_memory_usage {
447            stats.peak_memory_usage = memory_in_use;
448        }
449
450        // Calculate memory pressure (simplified)
451        stats.memory_pressure = (memory_in_use as f32 / (1024.0 * 1024.0 * 1024.0)).min(1.0); // Normalize to GB
452        stats.last_updated = Instant::now();
453    }
454}
455
456impl<T> BufferPool<T> {
457    fn new(size: usize, max_size: usize) -> Self {
458        Self {
459            available: VecDeque::with_capacity(max_size),
460            max_size,
461            total_allocations: 0,
462            pool_hits: 0,
463        }
464    }
465}
466
467impl Array2Pool {
468    fn new(dimensions: (usize, usize), max_size: usize) -> Self {
469        Self {
470            available: VecDeque::with_capacity(max_size),
471            dimensions,
472            max_size,
473            total_allocations: 0,
474            pool_hits: 0,
475        }
476    }
477}
478
479impl CacheManager {
480    fn new(config: &MemoryConfig) -> Self {
481        Self {
482            hrtf_cache: HashMap::new(),
483            distance_cache: HashMap::new(),
484            room_cache: HashMap::new(),
485            cache_stats: CacheStatistics::default(),
486            max_cache_size: config.max_cache_size,
487        }
488    }
489
490    async fn cache_hrtf(&mut self, key: HrtfCacheKey, entry: HrtfCacheEntry) {
491        if self.hrtf_cache.len() >= self.max_cache_size {
492            self.evict_lru_hrtf().await;
493        }
494        self.hrtf_cache.insert(key, entry);
495    }
496
497    async fn get_hrtf(&mut self, key: &HrtfCacheKey) -> Option<(Array1<f32>, Array1<f32>)> {
498        if let Some(entry) = self.hrtf_cache.get_mut(key) {
499            entry.last_accessed = Instant::now();
500            entry.access_count += 1;
501            self.cache_stats.cache_hits += 1;
502            Some((entry.left_hrir.clone(), entry.right_hrir.clone()))
503        } else {
504            self.cache_stats.cache_misses += 1;
505            None
506        }
507    }
508
509    async fn cache_distance(&mut self, key: DistanceCacheKey, value: f32) {
510        if self.distance_cache.len() >= self.max_cache_size {
511            // Simple eviction - remove oldest entries
512            if self.distance_cache.len() > self.max_cache_size * 3 / 4 {
513                let keys: Vec<_> = self.distance_cache.keys().cloned().collect();
514                for key in keys.iter().take(self.max_cache_size / 4) {
515                    self.distance_cache.remove(key);
516                }
517            }
518        }
519        self.distance_cache.insert(key, value);
520    }
521
522    fn get_distance(&self, key: &DistanceCacheKey) -> Option<f32> {
523        self.distance_cache.get(key).copied()
524    }
525
526    async fn evict_lru_entries(&mut self, count: usize) {
527        // Evict LRU HRTF entries
528        let mut entries: Vec<_> = self.hrtf_cache.iter().collect();
529        entries.sort_by_key(|a| a.1.last_accessed);
530
531        let to_remove: Vec<_> = entries
532            .iter()
533            .take(count.min(entries.len()))
534            .map(|(k, _)| (*k).clone())
535            .collect();
536        for key in to_remove {
537            self.hrtf_cache.remove(&key);
538            self.cache_stats.cache_evictions += 1;
539        }
540    }
541
542    async fn evict_lru_hrtf(&mut self) {
543        if let Some((oldest_key, _)) = self
544            .hrtf_cache
545            .iter()
546            .min_by_key(|(_, entry)| entry.last_accessed)
547        {
548            let key_to_remove = oldest_key.clone();
549            self.hrtf_cache.remove(&key_to_remove);
550            self.cache_stats.cache_evictions += 1;
551        }
552    }
553}
554
555/// Cache-friendly data layout optimization utilities
556pub mod cache_optimization {
557    use super::*;
558
559    /// Struct-of-Arrays pattern for better cache locality
560    #[derive(Debug)]
561    pub struct SoAPositions {
562        /// X coordinates
563        pub x: Vec<f32>,
564        /// Y coordinates  
565        pub y: Vec<f32>,
566        /// Z coordinates
567        pub z: Vec<f32>,
568        /// Capacity
569        pub capacity: usize,
570    }
571
572    impl SoAPositions {
573        /// Create new SoA position array
574        pub fn with_capacity(capacity: usize) -> Self {
575            Self {
576                x: Vec::with_capacity(capacity),
577                y: Vec::with_capacity(capacity),
578                z: Vec::with_capacity(capacity),
579                capacity,
580            }
581        }
582
583        /// Add position
584        pub fn push(&mut self, pos: Position3D) {
585            self.x.push(pos.x);
586            self.y.push(pos.y);
587            self.z.push(pos.z);
588        }
589
590        /// Get position by index
591        pub fn get(&self, index: usize) -> Option<Position3D> {
592            if index < self.len() {
593                Some(Position3D::new(self.x[index], self.y[index], self.z[index]))
594            } else {
595                None
596            }
597        }
598
599        /// Length
600        pub fn len(&self) -> usize {
601            self.x.len()
602        }
603
604        /// Is empty
605        pub fn is_empty(&self) -> bool {
606            self.len() == 0
607        }
608
609        /// Clear all positions
610        pub fn clear(&mut self) {
611            self.x.clear();
612            self.y.clear();
613            self.z.clear();
614        }
615    }
616
617    /// Prefetch data for cache optimization
618    #[cfg(target_arch = "x86_64")]
619    #[allow(unsafe_code)]
620    pub fn prefetch_data<T>(data: *const T) {
621        #[cfg(target_feature = "sse")]
622        unsafe {
623            std::arch::x86_64::_mm_prefetch(data as *const i8, std::arch::x86_64::_MM_HINT_T0);
624        }
625    }
626
627    #[cfg(not(target_arch = "x86_64"))]
628    /// Prefetch data (no-op on non-x86 architectures)
629    pub fn prefetch_data<T>(_data: *const T) {
630        // No-op on non-x86 architectures
631    }
632}
633
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_memory_manager_creation() {
        let config = MemoryConfig::default();
        let manager = MemoryManager::new(config);

        let stats = manager.get_memory_stats().await;
        assert_eq!(stats.total_allocated, 0);
    }

    #[tokio::test]
    async fn test_buffer_pool_reuse() {
        let config = MemoryConfig::default();
        let manager = MemoryManager::new(config);

        // Get a buffer and return it to the pool.
        let buffer = manager.get_buffer(1024).await;
        assert_eq!(buffer.len(), 1024);
        manager.return_buffer(buffer).await;

        // A second request of the same size should be servable from the pool.
        let buffer2 = manager.get_buffer(1024).await;
        assert_eq!(buffer2.len(), 1024);

        let stats = manager.get_memory_stats().await;
        assert!(stats.buffer_pool_stats.contains_key(&1024));
    }

    #[tokio::test]
    async fn test_hrtf_cache() {
        let config = MemoryConfig::default();
        let manager = MemoryManager::new(config);

        let left = Array1::zeros(256);
        let right = Array1::zeros(256);

        // Cache an HRTF pair.
        manager
            .cache_hrtf((45, 0, 2.0), left.clone(), right.clone())
            .await;

        // Retrieve it from the cache.
        let cached = manager.get_cached_hrtf((45, 0, 2.0)).await;
        assert!(cached.is_some());

        let (cached_left, cached_right) = cached.expect("Cached HRTF should be available");
        assert_eq!(cached_left.len(), 256);
        assert_eq!(cached_right.len(), 256);
    }

    #[tokio::test]
    async fn test_distance_cache() {
        let config = MemoryConfig::default();
        let manager = MemoryManager::new(config);

        // Cache a distance attenuation value.
        manager.cache_distance_attenuation(5.0, 1, 0.2).await;

        // Exact-key lookup hits.
        let cached = manager.get_cached_distance_attenuation(5.0, 1).await;
        assert_eq!(cached, Some(0.2));

        // A distance that was never cached misses.
        let not_cached = manager.get_cached_distance_attenuation(10.0, 1).await;
        assert_eq!(not_cached, None);
    }

    #[tokio::test]
    async fn test_memory_pressure() {
        // Struct-update syntax instead of mutating a default binding
        // (avoids clippy::field_reassign_with_default); a low threshold so
        // the pressure path is at least reachable.
        let config = MemoryConfig {
            memory_pressure_threshold: 0.1,
            ..MemoryConfig::default()
        };
        let manager = MemoryManager::new(config);

        // Allocate many buffers; they remain checked out of the pool.
        let mut buffers = Vec::new();
        for _ in 0..100 {
            buffers.push(manager.get_buffer(1024).await);
        }

        // Refresh stats manually (in real usage this would be automatic).
        manager.update_memory_stats().await;

        // Exercise the pressure check. This simplified test only verifies it
        // runs without panicking, so the result is intentionally ignored.
        let _pressure_detected = manager.check_memory_pressure().await;
    }

    #[tokio::test]
    async fn test_soa_positions() {
        let mut positions = cache_optimization::SoAPositions::with_capacity(10);

        positions.push(Position3D::new(1.0, 2.0, 3.0));
        positions.push(Position3D::new(4.0, 5.0, 6.0));

        assert_eq!(positions.len(), 2);

        let pos = positions.get(0).expect("First position should exist");
        assert_eq!(pos.x, 1.0);
        assert_eq!(pos.y, 2.0);
        assert_eq!(pos.z, 3.0);
    }
}