// torsh_tensor/cache_optimization.rs
1// Cache optimization module for improving memory layout and access patterns
2
3#[cfg(feature = "simd")]
4use crate::storage::SimdStorage;
5use crate::{Tensor, TensorStorage};
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9use torsh_core::{
10    dtype::TensorElement,
11    error::{Result, TorshError},
12    shape::Shape,
13};
14
15#[cfg(feature = "simd")]
16use scirs2_core::simd_aligned::AlignedVec;
17
/// Cache analysis report providing detailed performance metrics
///
/// Produced by `Tensor::analyze_cache_performance`; all values are heuristic
/// estimates derived from the tensor's shape and strides, not hardware counters.
#[derive(Debug, Clone)]
pub struct CacheAnalysisReport {
    /// Overall cache efficiency score (0.0 to 1.0); the mean of the spatial
    /// and temporal locality scores
    pub cache_efficiency: f64,
    /// Estimated number of cache misses for typical access patterns
    pub estimated_cache_misses: usize,
    /// Spatial locality score (0.0 to 1.0); 1.0 when the innermost stride is 1
    pub spatial_locality_score: f64,
    /// Temporal locality score (0.0 to 1.0); decreases logarithmically with element count
    pub temporal_locality_score: f64,
    /// Whether current memory layout is optimal (innermost stride == 1)
    pub memory_layout_optimal: bool,
    /// List of recommended optimizations
    pub recommended_optimizations: Vec<String>,
}
34
35impl<T: TensorElement + Copy> Tensor<T> {
    /// Memory layout optimization for cache efficiency
    /// Analyzes and optimizes the tensor's memory layout to improve cache performance
    ///
    /// Tensors with fewer than 1024 elements are returned unchanged. Larger
    /// tensors may have their dimensions physically reordered and their last
    /// dimension padded up to a cache-line multiple.
    ///
    /// NOTE(review): `reorder_dimensions` permutes both the data and the
    /// shape, so a non-identity order changes the tensor's logical dimension
    /// order (e.g. a 2-D tensor may come back transposed) — confirm callers
    /// expect this.
    pub fn optimize_cache_layout(&mut self) -> Result<()> {
        // Check if tensor is large enough to benefit from optimization
        if self.numel() < 1024 {
            return Ok(()); // Skip small tensors
        }

        // Analyze current access pattern and stride layout
        let current_strides = self.compute_strides();
        let optimal_order = self.determine_optimal_dimension_order(&current_strides);

        // If the suggested order is the identity permutation, nothing to do
        if optimal_order.iter().enumerate().all(|(i, &dim)| dim == i) {
            return Ok(());
        }

        // Reorganize data for better cache locality
        self.reorder_dimensions(&optimal_order)?;

        // Add padding for cache line alignment if beneficial
        self.add_cache_padding()?;

        Ok(())
    }
61
62    /// Determine optimal dimension order for cache efficiency
63    /// Prioritizes dimensions that are accessed more frequently together
64    fn determine_optimal_dimension_order(&self, strides: &[usize]) -> Vec<usize> {
65        let shape_binding = self.shape();
66        let dims = shape_binding.dims();
67        let mut dim_priorities: Vec<(usize, f64)> = (0..dims.len())
68            .map(|i| {
69                // Calculate priority based on dimension size and stride
70                let size_factor = dims[i] as f64;
71                let stride_factor = 1.0 / (strides[i] as f64 + 1.0);
72                let cache_friendliness = size_factor * stride_factor;
73                (i, cache_friendliness)
74            })
75            .collect();
76
77        // Sort by cache friendliness (higher is better)
78        dim_priorities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
79
80        dim_priorities.into_iter().map(|(dim, _)| dim).collect()
81    }
82
83    /// Reorder tensor dimensions for optimal cache access
84    fn reorder_dimensions(&mut self, optimal_order: &[usize]) -> Result<()> {
85        if optimal_order.len() != self.ndim() {
86            return Err(TorshError::InvalidOperation(
87                "Dimension order length mismatch".to_string(),
88            ));
89        }
90
91        // Create permutation for transpose operation
92        let data = self.to_vec()?;
93        let old_dims = self.shape().dims().to_vec();
94        let old_strides = self.compute_strides();
95
96        // Calculate new dimensions and create reordered data
97        let new_dims: Vec<usize> = optimal_order.iter().map(|&i| old_dims[i]).collect();
98        let new_numel = new_dims.iter().product::<usize>();
99        let mut new_data = vec![data[0]; new_numel]; // Initialize with first element
100
101        // Reorder data according to optimal dimension order
102        #[allow(clippy::needless_range_loop)]
103        for i in 0..new_numel {
104            let mut old_indices = vec![0; self.ndim()];
105            let mut remaining = i;
106
107            // Convert flat index to multi-dimensional indices in new layout
108            for (j, &dim_size) in new_dims.iter().enumerate().rev() {
109                old_indices[optimal_order[j]] = remaining % dim_size;
110                remaining /= dim_size;
111            }
112
113            // Calculate flat index in original layout
114            let old_flat_index: usize = old_indices
115                .iter()
116                .zip(old_strides.iter())
117                .map(|(&idx, &stride)| idx * stride)
118                .sum();
119
120            new_data[i] = data[old_flat_index];
121        }
122
123        // Update tensor with optimized layout
124        self.storage = TensorStorage::create_optimal(new_data)?;
125        self.shape = Shape::new(new_dims);
126
127        Ok(())
128    }
129
130    /// Add cache-line aligned padding for better memory access patterns
131    fn add_cache_padding(&mut self) -> Result<()> {
132        const CACHE_LINE_SIZE: usize = 64; // bytes
133        let element_size = std::mem::size_of::<T>();
134        let elements_per_cache_line = CACHE_LINE_SIZE / element_size;
135
136        // Only add padding if it would be beneficial
137        let shape_binding = self.shape();
138        let dims = shape_binding.dims();
139        if dims.is_empty() || dims[dims.len() - 1] % elements_per_cache_line == 0 {
140            return Ok(()); // Already aligned or no benefit
141        }
142
143        // Calculate padding needed for last dimension
144        let last_dim = dims[dims.len() - 1];
145        let padded_last_dim = last_dim.div_ceil(elements_per_cache_line) * elements_per_cache_line;
146        let padding_needed = padded_last_dim - last_dim;
147
148        // Only add padding if overhead is reasonable (< 25%)
149        if (padding_needed as f64 / last_dim as f64) > 0.25 {
150            return Ok(());
151        }
152
153        let data = self.to_vec()?;
154        let mut new_dims = dims.to_vec();
155        let last_idx = new_dims.len() - 1;
156        new_dims[last_idx] = padded_last_dim;
157
158        // Create padded data
159        let new_numel = new_dims.iter().product::<usize>();
160        let mut padded_data = Vec::with_capacity(new_numel);
161
162        let outer_size = new_numel / padded_last_dim;
163        for i in 0..outer_size {
164            let start_idx = i * last_dim;
165            let end_idx = (i + 1) * last_dim;
166
167            // Copy original data
168            padded_data.extend_from_slice(&data[start_idx..end_idx]);
169
170            // Add padding (zeros)
171            for _ in 0..padding_needed {
172                padded_data.push(data[0]); // Use first element as padding value
173            }
174        }
175
176        // Update tensor with padded layout
177        self.storage = TensorStorage::create_optimal(padded_data)?;
178        self.shape = Shape::new(new_dims);
179
180        Ok(())
181    }
182
183    /// Analyze memory access patterns and provide optimization recommendations
184    pub fn analyze_cache_performance(&self) -> CacheAnalysisReport {
185        let shape_binding = self.shape();
186        let dims = shape_binding.dims();
187        let strides = self.compute_strides();
188        let numel = self.numel();
189
190        // Calculate cache efficiency metrics
191        let mut cache_misses_estimate = 0f64;
192
193        // Estimate cache misses based on stride patterns
194        for (i, &stride) in strides.iter().enumerate() {
195            let dimension_accesses = dims[i] as f64;
196            let stride_penalty = if stride > 64 {
197                stride as f64 / 64.0
198            } else {
199                1.0
200            };
201            cache_misses_estimate += dimension_accesses * stride_penalty;
202        }
203
204        // Calculate spatial locality (how well adjacent elements are accessed together)
205        let spatial_locality_score = if strides.last().copied().unwrap_or(1) == 1usize {
206            1.0
207        } else {
208            1.0 / strides.last().copied().unwrap_or(1) as f64
209        };
210
211        // Calculate temporal locality (reuse of recently accessed data)
212        let temporal_locality_score = 1.0 / (numel as f64).log2().max(1.0);
213
214        CacheAnalysisReport {
215            cache_efficiency: (spatial_locality_score + temporal_locality_score) / 2.0,
216            estimated_cache_misses: cache_misses_estimate as usize,
217            spatial_locality_score,
218            temporal_locality_score,
219            memory_layout_optimal: strides.last().copied().unwrap_or(1) == 1usize,
220            recommended_optimizations: self.generate_optimization_recommendations(&strides),
221        }
222    }
223
224    /// Generate specific optimization recommendations based on current layout
225    fn generate_optimization_recommendations(&self, strides: &[usize]) -> Vec<String> {
226        let mut recommendations = Vec::new();
227        let shape_binding = self.shape();
228        let dims = shape_binding.dims();
229
230        // Check for non-contiguous memory layout
231        if strides.last().copied().unwrap_or(1) != 1 {
232            recommendations
233                .push("Consider using .contiguous() to ensure row-major layout".to_string());
234        }
235
236        // Check for small tensors that don't benefit from optimization
237        if self.numel() < 1024 {
238            recommendations.push("Tensor too small to benefit from cache optimization".to_string());
239        }
240
241        // Check for dimensions that could benefit from reordering
242        if dims.len() > 2 {
243            let largest_dim = dims.iter().enumerate().max_by_key(|(_, &size)| size);
244            if let Some((largest_idx, _)) = largest_dim {
245                if largest_idx != dims.len() - 1 {
246                    recommendations.push(format!(
247                        "Consider moving dimension {largest_idx} to the end for better cache locality"
248                    ));
249                }
250            }
251        }
252
253        // Check for padding opportunities
254        const CACHE_LINE_SIZE: usize = 64;
255        let element_size = std::mem::size_of::<T>();
256        let elements_per_cache_line = CACHE_LINE_SIZE / element_size;
257
258        if !dims.is_empty() {
259            let last_dim = dims[dims.len() - 1];
260            if last_dim % elements_per_cache_line != 0 {
261                recommendations
262                    .push("Consider adding cache-line padding for better alignment".to_string());
263            }
264        }
265
266        recommendations
267    }
268
269    /// Create a cache-optimized copy of the tensor
270    pub fn to_cache_optimized(&self) -> Result<Self> {
271        let mut optimized = self.clone();
272        optimized.optimize_cache_layout()?;
273        Ok(optimized)
274    }
275
    /// Get memory usage statistics for the tensor
    ///
    /// `total_bytes` counts only the element payload; `overhead_bytes` is a
    /// rough per-storage-variant estimate of the bookkeeping structures.
    pub fn memory_stats(&self) -> MemoryStats {
        let element_size = std::mem::size_of::<T>();
        let total_elements = self.numel();
        let total_bytes = total_elements * element_size;

        // Estimate memory overhead based on storage type
        let overhead_bytes = match &self.storage {
            TensorStorage::InMemory(_) => {
                // Arc + RwLock overhead (size of the Arc handle itself; the
                // lock's heap allocation is not counted)
                std::mem::size_of::<std::sync::Arc<std::sync::RwLock<Vec<T>>>>()
            }
            TensorStorage::MemoryMapped(_) => {
                // Memory mapped storage overhead
                1024 // Approximate overhead for file handles, cache, etc.
            }
            #[cfg(feature = "simd")]
            TensorStorage::Aligned(_) => {
                // Arc + RwLock + AlignedVec overhead
                std::mem::size_of::<std::sync::Arc<std::sync::RwLock<AlignedVec<T>>>>()
            }
            #[cfg(feature = "simd")]
            TensorStorage::SimdOptimized(_) => {
                // Arc + SimdStorage overhead (no RwLock, so less overhead)
                std::mem::size_of::<std::sync::Arc<SimdStorage<T>>>()
            }
        };

        MemoryStats {
            total_bytes,
            element_size,
            total_elements,
            overhead_bytes,
            is_memory_mapped: matches!(&self.storage, TensorStorage::MemoryMapped(_)),
        }
    }
312}
313
/// Memory usage statistics for a tensor
#[derive(Debug, Clone)]
pub struct MemoryStats {
    /// Total memory used by tensor data in bytes
    pub total_bytes: usize,
    /// Size of each element in bytes
    pub element_size: usize,
    /// Total number of elements
    pub total_elements: usize,
    /// Memory overhead from storage structures
    pub overhead_bytes: usize,
    /// Whether tensor uses memory-mapped storage
    pub is_memory_mapped: bool,
}

impl MemoryStats {
    /// Effective memory usage: element payload plus bookkeeping overhead.
    pub fn effective_bytes(&self) -> usize {
        self.overhead_bytes + self.total_bytes
    }

    /// Memory efficiency: fraction of the effective footprint that is
    /// actual tensor data.
    pub fn efficiency(&self) -> f64 {
        let payload = self.total_bytes as f64;
        let footprint = self.effective_bytes() as f64;
        payload / footprint
    }
}
340
/// Global memory pool for temporary tensor allocations
///
/// Buffers are bucketed by power-of-two size. Returned buffers are zeroed
/// before being made available for reuse, and the pool stops accepting
/// buffers once `max_pool_size` bytes are retained.
pub struct TensorMemoryPool {
    /// Pooled memory blocks, keyed by buffer length in bytes
    pool: Arc<Mutex<HashMap<usize, Vec<Vec<u8>>>>>,
    /// Memory allocation statistics
    stats: Arc<Mutex<PoolStatistics>>,
    /// Maximum number of bytes the pool will retain
    max_pool_size: usize,
    /// Bytes currently retained by the pool
    current_pool_size: Arc<Mutex<usize>>,
}

/// Counters describing pool activity since creation.
#[derive(Debug, Clone, Default)]
pub struct PoolStatistics {
    /// Number of `allocate` calls
    pub allocations: usize,
    /// Number of `deallocate` calls
    pub deallocations: usize,
    /// Allocations served from pooled buffers
    pub cache_hits: usize,
    /// Allocations that created fresh memory
    pub cache_misses: usize,
    /// Largest observed retained-byte total
    pub peak_memory_usage: usize,
    /// Total bytes accepted back into the pool
    pub total_memory_saved: usize,
}

impl TensorMemoryPool {
    /// Create a new memory pool with the specified maximum size in megabytes.
    pub fn new(max_size_mb: usize) -> Self {
        Self {
            pool: Arc::new(Mutex::new(HashMap::new())),
            stats: Arc::new(Mutex::new(PoolStatistics::default())),
            max_pool_size: max_size_mb * 1024 * 1024,
            current_pool_size: Arc::new(Mutex::new(0)),
        }
    }

    /// Allocate memory from the pool, or create a fresh zeroed buffer.
    ///
    /// The request is rounded up to the next power of two, so the returned
    /// buffer may be larger than `size_bytes`.
    pub fn allocate(&self, size_bytes: usize) -> Vec<u8> {
        let mut buckets = self.pool.lock().expect("lock should not be poisoned");
        let mut stats = self.stats.lock().expect("lock should not be poisoned");

        stats.allocations += 1;

        // Power-of-two buckets keep the number of distinct sizes small.
        let bucket_size = size_bytes.next_power_of_two();

        let reused = buckets
            .get_mut(&bucket_size)
            .and_then(|entries| entries.pop());

        if let Some(buffer) = reused {
            stats.cache_hits += 1;
            *self
                .current_pool_size
                .lock()
                .expect("lock should not be poisoned") -= bucket_size;
            return buffer;
        }

        stats.cache_misses += 1;
        vec![0u8; bucket_size]
    }

    /// Return memory to the pool for reuse.
    ///
    /// The buffer is zeroed before pooling; it is simply dropped when the
    /// pool is already at its retention limit.
    pub fn deallocate(&self, mut memory: Vec<u8>) {
        let size = memory.len();
        let mut buckets = self.pool.lock().expect("lock should not be poisoned");
        let mut stats = self.stats.lock().expect("lock should not be poisoned");
        let mut retained = self
            .current_pool_size
            .lock()
            .expect("lock should not be poisoned");

        stats.deallocations += 1;

        // Only retain the buffer while under the size limit.
        if *retained + size <= self.max_pool_size {
            // Scrub contents before the buffer can be handed to another caller.
            memory.fill(0);

            buckets.entry(size).or_default().push(memory);
            *retained += size;
            stats.total_memory_saved += size;
        }

        stats.peak_memory_usage = stats.peak_memory_usage.max(*retained);
    }

    /// Get a snapshot of the pool statistics.
    pub fn get_statistics(&self) -> PoolStatistics {
        self.stats
            .lock()
            .expect("lock should not be poisoned")
            .clone()
    }

    /// Drop every pooled buffer and reset the size accounting.
    pub fn clear(&self) {
        let mut buckets = self.pool.lock().expect("lock should not be poisoned");
        let mut retained = self
            .current_pool_size
            .lock()
            .expect("lock should not be poisoned");

        buckets.clear();
        *retained = 0;
    }
}
445
/// Memory pressure detection and adaptive allocation
///
/// Tracks usage samples from the last 60 seconds and derives a pressure
/// level as the average sampled usage relative to the configured limit,
/// capped at 1.0.
pub struct MemoryPressureMonitor {
    /// Timestamped memory usage samples
    samples: Arc<Mutex<Vec<(Instant, usize)>>>,
    /// Current pressure level (0.0 to 1.0)
    pressure_level: Arc<Mutex<f64>>,
    /// Usage level (bytes) corresponding to a pressure of 1.0
    high_pressure_threshold: usize,
}

impl MemoryPressureMonitor {
    /// Create a monitor with the given memory limit in megabytes.
    pub fn new(memory_limit_mb: usize) -> Self {
        Self {
            samples: Arc::new(Mutex::new(Vec::new())),
            pressure_level: Arc::new(Mutex::new(0.0)),
            high_pressure_threshold: memory_limit_mb * 1024 * 1024,
        }
    }

    /// Record a memory usage sample and refresh the pressure level.
    pub fn record_usage(&self, bytes_used: usize) {
        let mut history = self.samples.lock().expect("lock should not be poisoned");
        let mut level = self
            .pressure_level
            .lock()
            .expect("lock should not be poisoned");

        let now = Instant::now();
        history.push((now, bytes_used));

        // Discard samples older than one minute.
        history.retain(|(taken, _)| now.duration_since(*taken) < Duration::from_secs(60));

        // Pressure = mean recent usage relative to the limit, capped at 1.0.
        let total: f64 = history.iter().map(|(_, used)| *used as f64).sum();
        let mean = if history.is_empty() {
            0.0
        } else {
            total / history.len() as f64
        };

        *level = (mean / self.high_pressure_threshold as f64).min(1.0);
    }

    /// Get the current memory pressure level (0.0 to 1.0).
    pub fn get_pressure_level(&self) -> f64 {
        *self
            .pressure_level
            .lock()
            .expect("lock should not be poisoned")
    }

    /// Check if the system is under high memory pressure (level > 0.8).
    pub fn is_high_pressure(&self) -> bool {
        self.get_pressure_level() > 0.8
    }
}
502
/// NUMA-aware memory allocation hints
///
/// NOTE(review): these hints are currently advisory only — see
/// `apply_numa_optimization`, which ignores its `_hint` argument.
#[derive(Debug, Clone, Copy)]
pub enum NumaNode {
    /// Prefer the calling thread's local NUMA node (presumed meaning — confirm
    /// once a platform NUMA backend exists)
    Local,
    /// Prefer a specific NUMA node identified by index
    Node(u32),
    /// Spread the allocation across nodes
    Interleaved,
}

/// Allocation preferences passed to NUMA-aware optimization routines.
#[derive(Debug, Clone)]
pub struct NumaAllocationHint {
    /// Node placement preference
    pub preferred_node: NumaNode,
    /// Whether allocation may fall back to another node
    pub allow_fallback: bool,
    /// Whether worker threads should be bound to the preferred node
    pub bind_threads: bool,
}
517
518impl<T: TensorElement + Copy + Default> Tensor<T> {
519    /// Advanced memory optimization with NUMA awareness
520    pub fn optimize_memory_layout(&mut self, numa_hint: Option<NumaAllocationHint>) -> Result<()> {
521        // Basic cache optimization
522        self.optimize_cache_layout()?;
523
524        // Apply NUMA optimization if hint provided
525        if let Some(hint) = numa_hint {
526            self.apply_numa_optimization(hint)?;
527        }
528
529        // Memory access pattern prediction
530        self.optimize_access_patterns()?;
531
532        Ok(())
533    }
534
535    /// Apply NUMA-specific optimizations
536    fn apply_numa_optimization(&mut self, _hint: NumaAllocationHint) -> Result<()> {
537        // NUMA optimization would require platform-specific implementation
538        // For now, we'll implement basic interleaving for large tensors
539        if self.numel() > 1_000_000 {
540            // Large tensors benefit from interleaved allocation
541            // This would require platform-specific NUMA API calls
542            // For now, just ensure contiguous layout
543            if !self.is_contiguous() {
544                let contiguous_tensor = self.contiguous()?;
545                *self = contiguous_tensor;
546            }
547        }
548        Ok(())
549    }
550
551    /// Optimize memory access patterns based on predicted usage
552    fn optimize_access_patterns(&mut self) -> Result<()> {
553        let shape_binding = self.shape();
554        let dims = shape_binding.dims();
555
556        // For matrices, optimize for row-major access
557        if dims.len() == 2 && dims[0] > 64 && dims[1] > 64 {
558            // Check if we should transpose for better cache behavior
559            let row_size = dims[1] * std::mem::size_of::<T>();
560            let cache_line_size = 64;
561
562            // If rows don't align well with cache lines, consider optimization
563            if row_size % cache_line_size != 0 && row_size < cache_line_size * 4 {
564                self.add_cache_padding()?;
565            }
566        }
567
568        // For 3D+ tensors, ensure innermost dimension is cache-friendly
569        if dims.len() >= 3 {
570            let innermost_size = dims[dims.len() - 1] * std::mem::size_of::<T>();
571            if !(32..=256).contains(&innermost_size) {
572                // Consider reshaping for better cache utilization
573                self.add_cache_padding()?;
574            }
575        }
576
577        Ok(())
578    }
579
    /// Memory-mapped tensor creation with optimization hints
    ///
    /// NOTE(review): despite the name, this currently builds a regular CPU
    /// tensor via `from_data` and then runs `optimize_memory_layout`; no
    /// memory-mapped storage is created here — confirm the intended behavior
    /// or rename.
    pub fn create_memory_mapped_optimized(
        data: Vec<T>,
        shape: Vec<usize>,
        numa_hint: Option<NumaAllocationHint>,
    ) -> Result<Self> {
        let mut tensor = Self::from_data(data, shape, torsh_core::device::DeviceType::Cpu)?;
        tensor.optimize_memory_layout(numa_hint)?;
        Ok(tensor)
    }
590
591    /// Prefetch memory pages for better performance
592    pub fn prefetch_data(&self) -> Result<()> {
593        // This would use madvise/PrefetchVirtualMemory on supported platforms
594        // For now, we'll implement a simple memory access pattern
595        if self.numel() > 10_000 {
596            let data = self.to_vec()?;
597            let stride = data.len() / 100; // Sample every 1% of data
598
599            // Touch memory at regular intervals to trigger prefetch
600            let mut _sum = T::default();
601            for i in (0..data.len()).step_by(stride.max(1)) {
602                _sum = data[i]; // Simple memory access to trigger prefetch
603            }
604        }
605        Ok(())
606    }
607}
608
// Global singletons, lazily initialized on first access via OnceLock.
static GLOBAL_MEMORY_POOL: std::sync::OnceLock<TensorMemoryPool> = std::sync::OnceLock::new();
static MEMORY_PRESSURE_MONITOR: std::sync::OnceLock<MemoryPressureMonitor> =
    std::sync::OnceLock::new();

/// Get global memory pool
///
/// Initialized on first call with a 1024 MB capacity limit.
pub fn get_memory_pool() -> &'static TensorMemoryPool {
    GLOBAL_MEMORY_POOL.get_or_init(|| TensorMemoryPool::new(1024)) // 1GB default
}

/// Get memory pressure monitor
///
/// Initialized on first call with an 8192 MB high-pressure threshold.
pub fn get_memory_pressure_monitor() -> &'static MemoryPressureMonitor {
    MEMORY_PRESSURE_MONITOR.get_or_init(|| MemoryPressureMonitor::new(8192)) // 8GB default
}
623
#[cfg(test)]
mod tests {
    use crate::creation::*;

    // Smoke test: the cache-layout pass runs on a tensor large enough
    // (>= 1024 elements) to take the optimization path.
    #[test]
    fn test_cache_optimization() {
        let mut tensor = ones::<f32>(&[100, 100]).expect("ones creation should succeed");
        assert!(tensor.optimize_cache_layout().is_ok());
    }

    // The analysis report's efficiency score must stay within 0..=1.
    #[test]
    fn test_cache_analysis() {
        let tensor = ones::<f32>(&[64, 64]).expect("ones creation should succeed");
        let report = tensor.analyze_cache_performance();
        assert!(report.cache_efficiency >= 0.0 && report.cache_efficiency <= 1.0);
    }

    // Freshly created tensors are contiguous, and `.contiguous()` keeps them so.
    #[test]
    fn test_contiguous_layout() {
        let tensor = ones::<f32>(&[10, 10]).expect("ones creation should succeed");
        assert!(tensor.is_contiguous());

        let contiguous = tensor
            .contiguous()
            .expect("contiguous conversion should succeed");
        assert!(contiguous.is_contiguous());
    }

    #[test]
    fn test_memory_stats() {
        let tensor = ones::<f32>(&[100, 100]).expect("ones creation should succeed");
        let stats = tensor.memory_stats();
        assert_eq!(stats.total_elements, 10000);
        assert_eq!(stats.element_size, 4); // f32 is 4 bytes
        assert_eq!(stats.total_bytes, 40000);
    }

    // Exercises allocate/deallocate round-trips and the hit/miss counters.
    #[test]
    fn test_memory_pool() {
        use super::*;

        let pool = TensorMemoryPool::new(10); // 10 MB

        // Test allocation
        let memory1 = pool.allocate(1024);
        assert_eq!(memory1.len(), 1024);

        let memory2 = pool.allocate(2048);
        assert_eq!(memory2.len(), 2048);

        // Test deallocation and reuse
        pool.deallocate(memory1);
        let memory3 = pool.allocate(1024);
        assert_eq!(memory3.len(), 1024);

        // Check statistics
        let stats = pool.get_statistics();
        assert!(stats.allocations > 0);
        assert!(stats.deallocations > 0);

        pool.deallocate(memory2);
        pool.deallocate(memory3);
    }

    // Pressure is the average of recent samples relative to the 100 MB limit.
    #[test]
    fn test_memory_pressure_monitor() {
        use super::*;

        let monitor = MemoryPressureMonitor::new(100); // 100 MB limit

        // Test pressure calculation - monitor uses average of samples
        monitor.record_usage(50 * 1024 * 1024); // 50 MB
        assert!(monitor.get_pressure_level() < 0.6);

        monitor.record_usage(90 * 1024 * 1024); // 90 MB
        // Average of 50MB and 90MB = 70MB = 0.7 pressure
        assert!(monitor.get_pressure_level() > 0.6);
        assert!(monitor.get_pressure_level() < 0.8);
        assert!(!monitor.is_high_pressure()); // 0.7 < 0.8, so not high pressure

        // Add a higher pressure reading to trigger high pressure
        monitor.record_usage(95 * 1024 * 1024); // 95 MB
        // Average of 50MB, 90MB, and 95MB = ~78MB = 0.78 pressure (still < 0.8)
        monitor.record_usage(100 * 1024 * 1024); // 100 MB
        // This should push the average above 0.8
        assert!(monitor.is_high_pressure());
    }

    // End-to-end run of the NUMA-aware optimization entry point.
    #[test]
    fn test_advanced_memory_optimization() {
        let mut tensor = ones::<f32>(&[64, 64]).expect("ones creation should succeed");

        // Test with NUMA hint
        let numa_hint = super::NumaAllocationHint {
            preferred_node: super::NumaNode::Local,
            allow_fallback: true,
            bind_threads: false,
        };

        assert!(tensor.optimize_memory_layout(Some(numa_hint)).is_ok());
        assert!(tensor.is_contiguous());
    }

    #[test]
    fn test_cache_optimized_creation() {
        let data: Vec<f32> = (0..10000).map(|i| i as f32).collect();
        let shape = vec![100, 100];

        let numa_hint = super::NumaAllocationHint {
            preferred_node: super::NumaNode::Interleaved,
            allow_fallback: true,
            bind_threads: false,
        };

        let tensor = super::Tensor::create_memory_mapped_optimized(data, shape, Some(numa_hint));
        assert!(tensor.is_ok());

        let tensor = tensor.expect("operation should succeed");
        // Shape may be optimized with padding for cache efficiency
        let shape = tensor.shape();
        let dims = shape.dims();
        assert_eq!(dims[0], 100); // First dimension should be preserved
        assert!(dims[1] >= 100); // Second dimension may have padding
    }

    #[test]
    fn test_memory_prefetch() {
        let tensor = ones::<f32>(&[200, 200]).expect("ones creation should succeed");
        assert!(tensor.prefetch_data().is_ok());
    }

    // The lazily-initialized global singletons must be usable on first touch.
    #[test]
    fn test_global_memory_pool_access() {
        use super::*;

        let pool = get_memory_pool();
        let memory = pool.allocate(1024);
        assert_eq!(memory.len(), 1024);
        pool.deallocate(memory);

        let monitor = get_memory_pressure_monitor();
        monitor.record_usage(1024 * 1024); // 1 MB
        assert!(monitor.get_pressure_level() >= 0.0);
    }

    // Every allocation is either a cache hit or a miss — never both or neither.
    #[test]
    fn test_pool_statistics() {
        use super::*;

        let pool = TensorMemoryPool::new(5); // 5 MB

        // Perform multiple allocations and deallocations
        let mut memories = Vec::new();
        for i in 0..10 {
            let size = (i + 1) * 512;
            memories.push(pool.allocate(size));
        }

        for memory in memories {
            pool.deallocate(memory);
        }

        let stats = pool.get_statistics();
        assert_eq!(stats.allocations, 10);
        assert_eq!(stats.deallocations, 10);
        assert!(stats.cache_hits + stats.cache_misses == 10);

        pool.clear();
    }

    #[test]
    fn test_memory_efficiency_calculation() {
        let tensor = ones::<f32>(&[50, 50]).expect("ones creation should succeed");
        let stats = tensor.memory_stats();

        let efficiency = stats.efficiency();
        assert!(efficiency > 0.0 && efficiency <= 1.0);

        let effective = stats.effective_bytes();
        assert!(effective >= stats.total_bytes);
    }
}