Skip to main content

trustformers_core/
memory.rs

1use crate::errors::{Result, TrustformersError};
2use crate::tensor::Tensor;
3use scirs2_core::ndarray::{s, IxDyn};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fs::File;
7use std::io::{Read, Seek, SeekFrom};
8use std::sync::{Arc, Mutex, RwLock};
9use std::time::{Duration, Instant};
10
// Memory optimization utilities for TrustformeRS
//
// This module provides high-priority memory optimizations:
// - Zero-copy tensor views for slice operations
// - Memory mapping for large model weights
// - Custom allocators for tensor allocation patterns
// - Tensor memory recycling pool
//
// NOTE(review): these lines were `///` doc comments, which made this module
// overview attach to the first item below instead of the module; plain `//`
// comments avoid that (`//!` is not permitted after the `use` items above).
/// Eviction policy for memory pool
///
/// Selected via `MemoryConfig::eviction_policy`. Pool entries are ranked by
/// `PoolEntry::eviction_priority`; entries with the lowest priority value are
/// evicted first when the pool exceeds its size budget.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryEvictionPolicy {
    /// Least Recently Used - evict tensors not used for longest time
    LRU,
    /// Least Frequently Used - evict tensors with lowest access count
    LFU,
    /// Size-based - evict largest tensors first to free more memory
    SizeBased,
    /// Adaptive Replacement Cache - balance between recency and frequency
    ARC,
    /// Hybrid - weighted combination of recency, access frequency, and size
    Hybrid,
}
33
/// Adaptive strategy for dynamic pool sizing
///
/// Controls how `TensorMemoryPool` moves its dynamic size budget between
/// `MemoryConfig::min_pool_size` and `MemoryConfig::max_pool_size`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AdaptiveStrategy {
    /// Fixed pool size (no adaptation)
    Fixed,
    /// Grow/shrink based on memory pressure (current pool utilization)
    MemoryPressure,
    /// Adapt based on hit/miss rates
    HitRate,
    /// Predict size based on access patterns
    Predictive,
}
46
/// Configuration for memory optimizations
///
/// Passed to `TensorMemoryPool::new`; see the `Default` impl for recommended
/// starting values. All sizes are in bytes.
#[derive(Debug, Clone)]
pub struct MemoryConfig {
    /// Enable memory pool for tensor recycling
    pub enable_memory_pool: bool,
    /// Maximum size of memory pool in bytes (hard ceiling for adaptive sizing)
    pub max_pool_size: usize,
    /// Minimum size of memory pool (floor for adaptive strategies)
    pub min_pool_size: usize,
    /// Enable zero-copy tensor views
    pub enable_zero_copy: bool,
    /// Enable memory mapping for large tensors
    pub enable_mmap: bool,
    /// Minimum size for memory mapping (in bytes)
    pub mmap_threshold: usize,
    /// Pool cleanup interval (cleanup may also run early when over budget)
    pub cleanup_interval: Duration,
    /// Eviction policy to use
    pub eviction_policy: MemoryEvictionPolicy,
    /// Adaptive strategy for dynamic sizing
    pub adaptive_strategy: AdaptiveStrategy,
    /// Target hit rate for adaptive sizing (0.0 to 1.0)
    pub target_hit_rate: f64,
    /// Enable prefetching based on access patterns
    pub enable_prefetching: bool,
    /// Enable automatic defragmentation
    pub enable_defragmentation: bool,
}
75
76impl Default for MemoryConfig {
77    fn default() -> Self {
78        Self {
79            enable_memory_pool: true,
80            max_pool_size: 1024 * 1024 * 1024, // 1GB
81            min_pool_size: 64 * 1024 * 1024,   // 64MB
82            enable_zero_copy: true,
83            enable_mmap: true,
84            mmap_threshold: 100 * 1024 * 1024, // 100MB
85            cleanup_interval: Duration::from_secs(60),
86            eviction_policy: MemoryEvictionPolicy::Hybrid,
87            adaptive_strategy: AdaptiveStrategy::HitRate,
88            target_hit_rate: 0.85, // 85% target hit rate
89            enable_prefetching: true,
90            enable_defragmentation: true,
91        }
92    }
93}
94
/// Memory pool entry for tensor recycling (enhanced with adaptive metrics)
#[derive(Debug, Clone)]
struct PoolEntry {
    // The recycled tensor itself.
    tensor: Tensor,
    // When this entry was last handed out (recency signal for LRU/ARC/Hybrid).
    last_used: Instant,
    // Outstanding references; only entries with `ref_count == 0` are
    // considered for eviction.
    ref_count: usize,
    /// Access frequency counter (for LFU and ARC policies)
    access_count: usize,
    /// Creation time (for age-based eviction)
    #[allow(dead_code)]
    created_at: Instant,
    /// Total time in pool (for efficiency metrics)
    #[allow(dead_code)]
    pool_time: Duration,
    /// Tensor size in bytes (cached for quick eviction decisions)
    size_bytes: usize,
}
112
113impl PoolEntry {
114    fn new(tensor: Tensor, size_bytes: usize) -> Self {
115        let now = Instant::now();
116        Self {
117            tensor,
118            last_used: now,
119            ref_count: 0,
120            access_count: 0,
121            created_at: now,
122            pool_time: Duration::ZERO,
123            size_bytes,
124        }
125    }
126
127    fn mark_accessed(&mut self) {
128        self.last_used = Instant::now();
129        self.access_count += 1;
130    }
131
132    /// Calculate eviction priority (lower = evict first)
133    fn eviction_priority(&self, policy: MemoryEvictionPolicy) -> f64 {
134        match policy {
135            MemoryEvictionPolicy::LRU => {
136                // Recency: older = lower priority
137                -(self.last_used.elapsed().as_secs_f64())
138            },
139            MemoryEvictionPolicy::LFU => {
140                // Frequency: less used = lower priority
141                -(self.access_count as f64)
142            },
143            MemoryEvictionPolicy::SizeBased => {
144                // Size: larger = lower priority (to free more space)
145                -(self.size_bytes as f64)
146            },
147            MemoryEvictionPolicy::ARC => {
148                // Adaptive: balance recency and frequency
149                let recency_score = 1.0 / (1.0 + self.last_used.elapsed().as_secs_f64());
150                let frequency_score = self.access_count as f64;
151                -(recency_score + frequency_score)
152            },
153            MemoryEvictionPolicy::Hybrid => {
154                // Hybrid: combine recency, frequency, and size
155                let recency = 1.0 / (1.0 + self.last_used.elapsed().as_secs_f64());
156                let frequency = self.access_count as f64;
157                let size_factor = 1.0 / (1.0 + (self.size_bytes as f64 / 1_000_000.0));
158                -(recency * 0.4 + frequency * 0.4 + size_factor * 0.2)
159            },
160        }
161    }
162}
163
/// Zero-copy tensor view for slice operations
///
/// Holds an `Arc` to the source tensor plus offset/shape/stride metadata;
/// no element data is copied until `TensorView::as_tensor` is called.
#[derive(Debug)]
pub struct TensorView {
    /// Original tensor reference
    original: Arc<Tensor>,
    /// Offset in the original tensor (in elements over the flattened data)
    offset: usize,
    /// Shape of the view
    shape: Vec<usize>,
    /// Strides for the view (currently always contiguous, i.e. `[1]`)
    #[allow(dead_code)]
    strides: Vec<usize>,
}
177
178impl TensorView {
179    /// Create a new zero-copy view of a tensor slice
180    pub fn slice(tensor: Arc<Tensor>, start: usize, end: usize) -> Result<Self> {
181        let original_shape = tensor.shape();
182        if start >= end || end > original_shape.iter().product::<usize>() {
183            return Err(TrustformersError::invalid_input(
184                "Invalid slice bounds".to_string(),
185            ));
186        }
187
188        let slice_len = end - start;
189        Ok(Self {
190            original: tensor,
191            offset: start,
192            shape: vec![slice_len],
193            strides: vec![1],
194        })
195    }
196
197    /// Get the shape of the view
198    pub fn shape(&self) -> &[usize] {
199        &self.shape
200    }
201
202    /// Get the underlying tensor data (zero-copy)
203    pub fn as_tensor(&self) -> Result<Tensor> {
204        // This would implement actual zero-copy slicing
205        // For now, return a simple implementation
206        match &*self.original {
207            Tensor::F32(arr) => {
208                let flat = arr
209                    .view()
210                    .into_shape_with_order(arr.len())
211                    .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
212                let slice = flat.slice(s![
213                    self.offset..self.offset + self.shape.iter().product::<usize>()
214                ]);
215                let sliced_arr = slice
216                    .to_owned()
217                    .into_shape_with_order(IxDyn(&self.shape))
218                    .map_err(|e| TrustformersError::shape_error(e.to_string()))?;
219                Ok(Tensor::F32(sliced_arr))
220            },
221            _ => Err(TrustformersError::tensor_op_error(
222                "Zero-copy slicing not implemented for this tensor type",
223                "zero_copy_slice",
224            )),
225        }
226    }
227}
228
/// Enhanced statistics for adaptive memory pool
///
/// Internal counters kept behind `TensorMemoryPool::statistics`; a public
/// snapshot is exposed through `MemoryPoolStats`.
#[derive(Debug, Clone)]
struct PoolStatistics {
    // Total `get_tensor` requests observed.
    total_requests: usize,
    // Requests satisfied by a pooled tensor.
    cache_hits: usize,
    // Requests that fell through to a fresh allocation.
    cache_misses: usize,
    // Entries evicted across all cleanup passes.
    total_evictions: usize,
    // Eviction counts keyed by `format!("{:?}", policy)`.
    evictions_by_policy: HashMap<String, usize>,
    // Cumulative bytes returned to the pool (monotonically increasing).
    total_allocated_bytes: usize,
    // Highest pooled-byte count ever observed.
    peak_memory_usage: usize,
    #[allow(dead_code)]
    average_tensor_lifetime: Duration,
    #[allow(dead_code)]
    last_reset: Instant,
}
244
245impl Default for PoolStatistics {
246    fn default() -> Self {
247        Self {
248            total_requests: 0,
249            cache_hits: 0,
250            cache_misses: 0,
251            total_evictions: 0,
252            evictions_by_policy: HashMap::new(),
253            total_allocated_bytes: 0,
254            peak_memory_usage: 0,
255            average_tensor_lifetime: Duration::ZERO,
256            last_reset: Instant::now(),
257        }
258    }
259}
260
261impl PoolStatistics {
262    fn hit_rate(&self) -> f64 {
263        if self.total_requests == 0 {
264            0.0
265        } else {
266            self.cache_hits as f64 / self.total_requests as f64
267        }
268    }
269
270    fn miss_rate(&self) -> f64 {
271        if self.total_requests == 0 {
272            0.0
273        } else {
274            self.cache_misses as f64 / self.total_requests as f64
275        }
276    }
277}
278
/// Memory pool for tensor recycling (enhanced with adaptive strategies)
///
/// Thread-safe: all internal state lives behind `Arc<Mutex<..>>` /
/// `Arc<RwLock<..>>`, so a pool may be shared across threads.
pub struct TensorMemoryPool {
    // Static configuration (limits, policies, feature toggles).
    config: MemoryConfig,
    // Recycled tensors, bucketed by exact shape.
    pool: Arc<RwLock<HashMap<Vec<usize>, Vec<PoolEntry>>>>,
    // Total bytes currently held by pooled tensors.
    current_size: Arc<Mutex<usize>>,
    // Timestamp of the last eviction pass.
    last_cleanup: Arc<Mutex<Instant>>,
    /// Enhanced statistics for adaptive behavior
    statistics: Arc<Mutex<PoolStatistics>>,
    /// Access pattern tracking for prefetching
    access_patterns: Arc<Mutex<HashMap<Vec<usize>, Vec<Instant>>>>,
    /// Dynamic pool size (for adaptive strategies)
    dynamic_max_size: Arc<Mutex<usize>>,
}
292
293impl TensorMemoryPool {
294    /// Create a new memory pool with enhanced adaptive strategies
295    pub fn new(config: MemoryConfig) -> Self {
296        let dynamic_max_size = config.max_pool_size;
297        Self {
298            config,
299            pool: Arc::new(RwLock::new(HashMap::new())),
300            current_size: Arc::new(Mutex::new(0)),
301            last_cleanup: Arc::new(Mutex::new(Instant::now())),
302            statistics: Arc::new(Mutex::new(PoolStatistics::default())),
303            access_patterns: Arc::new(Mutex::new(HashMap::new())),
304            dynamic_max_size: Arc::new(Mutex::new(dynamic_max_size)),
305        }
306    }
307
308    /// Get a tensor from the pool or create a new one (enhanced with statistics tracking)
309    pub fn get_tensor(&self, shape: &[usize], dtype: crate::tensor::DType) -> Result<Tensor> {
310        // Track access pattern for prefetching
311        if self.config.enable_prefetching {
312            let mut patterns = self.access_patterns.lock().expect("lock should not be poisoned");
313            patterns.entry(shape.to_vec()).or_default().push(Instant::now());
314        }
315
316        // Update statistics
317        {
318            let mut stats = self.statistics.lock().expect("lock should not be poisoned");
319            stats.total_requests += 1;
320        }
321
322        if !self.config.enable_memory_pool {
323            return self.create_tensor(shape, dtype);
324        }
325
326        // Try to get from pool first
327        if let Some(tensor) = self.try_get_from_pool(shape)? {
328            // Cache hit!
329            let mut stats = self.statistics.lock().expect("lock should not be poisoned");
330            stats.cache_hits += 1;
331            return Ok(tensor);
332        }
333
334        // Cache miss
335        {
336            let mut stats = self.statistics.lock().expect("lock should not be poisoned");
337            stats.cache_misses += 1;
338        }
339
340        // Apply adaptive pool sizing based on hit rate
341        self.apply_adaptive_sizing()?;
342
343        // Create new tensor if none available in pool
344        self.create_tensor(shape, dtype)
345    }
346
347    /// Return a tensor to the pool for recycling (enhanced tracking)
348    pub fn return_tensor(&self, tensor: Tensor) -> Result<()> {
349        if !self.config.enable_memory_pool {
350            return Ok(()); // Just drop the tensor
351        }
352
353        let shape = tensor.shape().to_vec();
354
355        // Calculate tensor size before moving
356        let tensor_size = self.estimate_tensor_size(&tensor);
357
358        // Create enhanced pool entry
359        let entry = PoolEntry::new(tensor, tensor_size);
360
361        let mut pool = self.pool.write().expect("lock should not be poisoned");
362        pool.entry(shape).or_default().push(entry);
363
364        // Update current size and peak usage
365        {
366            let mut current = self.current_size.lock().expect("lock should not be poisoned");
367            *current += tensor_size;
368
369            let mut stats = self.statistics.lock().expect("lock should not be poisoned");
370            if *current > stats.peak_memory_usage {
371                stats.peak_memory_usage = *current;
372            }
373            stats.total_allocated_bytes += tensor_size;
374        }
375
376        // Cleanup if needed (with enhanced eviction policies)
377        self.cleanup_if_needed()?;
378
379        Ok(())
380    }
381
382    /// Try to get a tensor from the pool (enhanced with access tracking)
383    fn try_get_from_pool(&self, shape: &[usize]) -> Result<Option<Tensor>> {
384        let mut pool = self.pool.write().expect("lock should not be poisoned");
385
386        if let Some(entries) = pool.get_mut(shape) {
387            if let Some(mut entry) = entries.pop() {
388                // Mark as accessed for LFU tracking
389                entry.mark_accessed();
390
391                let tensor_size = entry.size_bytes;
392                *self.current_size.lock().expect("lock should not be poisoned") -= tensor_size;
393                return Ok(Some(entry.tensor));
394            }
395        }
396
397        Ok(None)
398    }
399
    /// Create a new tensor
    ///
    /// Allocates a zero-initialized tensor of the requested shape for any
    /// supported dtype; unsupported dtypes yield a descriptive error rather
    /// than panicking.
    fn create_tensor(&self, shape: &[usize], dtype: crate::tensor::DType) -> Result<Tensor> {
        match dtype {
            crate::tensor::DType::F32 => Tensor::zeros(shape),
            crate::tensor::DType::F64 => Tensor::zeros_f64(shape),
            crate::tensor::DType::F16 => Tensor::zeros_f16(shape),
            crate::tensor::DType::BF16 => Tensor::zeros_bf16(shape),
            crate::tensor::DType::I64 => Tensor::zeros_i64(shape),
            crate::tensor::DType::C32 => Tensor::zeros_c32(shape),
            crate::tensor::DType::C64 => Tensor::zeros_c64(shape),
            crate::tensor::DType::CF16 => Tensor::zeros_cf16(shape),
            crate::tensor::DType::CBF16 => Tensor::zeros_cbf16(shape),
            // Catch-all for any dtype without a zeros constructor here.
            _ => Err(TrustformersError::tensor_op_error(
                &format!("Tensor creation not implemented for dtype: {:?} - only supported types are F32, F64, F16, BF16, I64, C32, C64, CF16, CBF16", dtype),
                "create_tensor"
            )),
        }
    }
418
    /// Estimate the memory size of a tensor
    ///
    /// Returns element count × per-element byte width; allocator/metadata
    /// overhead is ignored. Backend tensors (Torch/Candle) are assumed to be
    /// 32-bit per element — TODO(review): confirm against the backend dtype.
    fn estimate_tensor_size(&self, tensor: &Tensor) -> usize {
        let elements = tensor.shape().iter().product::<usize>();
        match tensor {
            Tensor::F32(_) => elements * 4,   // 32-bit float
            Tensor::F64(_) => elements * 8,   // 64-bit float
            Tensor::F16(_) => elements * 2,   // 16-bit float
            Tensor::BF16(_) => elements * 2,  // 16-bit bfloat
            Tensor::I64(_) => elements * 8,   // 64-bit integer
            Tensor::C32(_) => elements * 8,   // 2 * 32-bit complex
            Tensor::C64(_) => elements * 16,  // 2 * 64-bit complex
            Tensor::CF16(_) => elements * 4,  // 2 * 16-bit complex
            Tensor::CBF16(_) => elements * 4, // 2 * 16-bit bfloat complex
            #[cfg(feature = "torch")]
            Tensor::Torch(_) => elements * 4, // Default to 32-bit
            #[cfg(feature = "candle")]
            Tensor::Candle(_) => elements * 4, // Default to 32-bit
            #[cfg(all(target_os = "macos", feature = "metal"))]
            Tensor::Metal(data) => elements * data.dtype.size_in_bytes(),
            #[cfg(feature = "cuda")]
            Tensor::CUDA(data) => elements * data.dtype.size_in_bytes(),
            Tensor::Sparse(sparse) => {
                // For sparse tensors, estimate based on non-zero elements
                let nnz = sparse.nnz();
                nnz * 4 + nnz * std::mem::size_of::<usize>() // values + indices
            },
        }
    }
447
448    /// Cleanup old entries if needed (enhanced with adaptive eviction policies)
449    fn cleanup_if_needed(&self) -> Result<()> {
450        let mut last_cleanup = self.last_cleanup.lock().expect("lock should not be poisoned");
451        let should_cleanup_time = last_cleanup.elapsed() >= self.config.cleanup_interval;
452
453        let current_size = *self.current_size.lock().expect("lock should not be poisoned");
454        let dynamic_max = *self.dynamic_max_size.lock().expect("lock should not be poisoned");
455        let should_cleanup_size = current_size > dynamic_max;
456
457        if !should_cleanup_time && !should_cleanup_size {
458            return Ok(());
459        }
460
461        // Enhanced cleanup using configured eviction policy
462        let mut pool = self.pool.write().expect("lock should not be poisoned");
463        let mut total_freed = 0;
464        let mut eviction_count = 0;
465        let policy = self.config.eviction_policy;
466
467        // Calculate how much memory we need to free
468        let target_size = (dynamic_max as f64 * 0.85) as usize; // Target 85% of max
469        let need_to_free = current_size.saturating_sub(target_size);
470
471        // Collect all entries with their priorities
472        let mut all_entries: Vec<(Vec<usize>, usize, f64)> = Vec::new();
473
474        for (shape, entries) in pool.iter() {
475            for (idx, entry) in entries.iter().enumerate() {
476                if entry.ref_count == 0 {
477                    let priority = entry.eviction_priority(policy);
478                    all_entries.push((shape.clone(), idx, priority));
479                }
480            }
481        }
482
483        // Sort by eviction priority (lowest first = evict first)
484        all_entries.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
485
486        // Evict entries until we've freed enough memory
487        let mut freed_so_far = 0;
488        let mut shapes_to_remove: Vec<Vec<usize>> = Vec::new();
489
490        for (shape, _, _) in all_entries.iter() {
491            if freed_so_far >= need_to_free {
492                break;
493            }
494
495            if let Some(entries) = pool.get_mut(shape) {
496                if let Some(entry) = entries.first() {
497                    if entry.ref_count == 0 {
498                        let size = entry.size_bytes;
499                        freed_so_far += size;
500                        total_freed += size;
501                        eviction_count += 1;
502                        shapes_to_remove.push(shape.clone());
503                    }
504                }
505            }
506        }
507
508        // Remove marked entries
509        for shape in shapes_to_remove {
510            if let Some(entries) = pool.get_mut(&shape) {
511                if !entries.is_empty() {
512                    entries.remove(0);
513                }
514            }
515        }
516
517        // Remove empty entries
518        pool.retain(|_, entries| !entries.is_empty());
519
520        drop(pool); // Release write lock
521
522        // Update statistics
523        {
524            let mut stats = self.statistics.lock().expect("lock should not be poisoned");
525            stats.total_evictions += eviction_count;
526            *stats.evictions_by_policy.entry(format!("{:?}", policy)).or_insert(0) +=
527                eviction_count;
528        }
529
530        // Update size
531        *self.current_size.lock().expect("lock should not be poisoned") -= total_freed;
532        *last_cleanup = Instant::now();
533
534        // Run defragmentation if enabled
535        if self.config.enable_defragmentation {
536            self.defragment_pool()?;
537        }
538
539        Ok(())
540    }
541
542    /// Apply adaptive pool sizing based on configured strategy
543    fn apply_adaptive_sizing(&self) -> Result<()> {
544        match self.config.adaptive_strategy {
545            AdaptiveStrategy::Fixed => Ok(()), // No adaptation
546            AdaptiveStrategy::HitRate => self.adapt_by_hit_rate(),
547            AdaptiveStrategy::MemoryPressure => self.adapt_by_memory_pressure(),
548            AdaptiveStrategy::Predictive => self.adapt_by_prediction(),
549        }
550    }
551
552    /// Adapt pool size based on hit rate
553    fn adapt_by_hit_rate(&self) -> Result<()> {
554        let stats = self.statistics.lock().expect("lock should not be poisoned");
555        let hit_rate = stats.hit_rate();
556        drop(stats);
557
558        let mut dynamic_max = self.dynamic_max_size.lock().expect("lock should not be poisoned");
559        let target_rate = self.config.target_hit_rate;
560
561        if hit_rate < target_rate {
562            // Low hit rate: increase pool size
563            let increase = (*dynamic_max as f64 * 0.1) as usize;
564            let new_size = (*dynamic_max + increase).min(self.config.max_pool_size);
565            if new_size > *dynamic_max {
566                *dynamic_max = new_size;
567            }
568        } else if hit_rate > target_rate + 0.1 {
569            // Very high hit rate: can decrease pool size
570            let decrease = (*dynamic_max as f64 * 0.05) as usize;
571            let new_size = (*dynamic_max - decrease).max(self.config.min_pool_size);
572            if new_size < *dynamic_max {
573                *dynamic_max = new_size;
574            }
575        }
576
577        Ok(())
578    }
579
580    /// Adapt pool size based on system memory pressure
581    fn adapt_by_memory_pressure(&self) -> Result<()> {
582        // Simplified memory pressure detection
583        // In production, this would query OS for available memory
584        let current_size = *self.current_size.lock().expect("lock should not be poisoned");
585        let mut dynamic_max = self.dynamic_max_size.lock().expect("lock should not be poisoned");
586
587        let utilization = current_size as f64 / *dynamic_max as f64;
588
589        if utilization > 0.9 {
590            // High pressure: decrease pool size
591            let new_size = (*dynamic_max as f64 * 0.9) as usize;
592            *dynamic_max = new_size.max(self.config.min_pool_size);
593        } else if utilization < 0.5 {
594            // Low pressure: increase pool size
595            let new_size = (*dynamic_max as f64 * 1.1) as usize;
596            *dynamic_max = new_size.min(self.config.max_pool_size);
597        }
598
599        Ok(())
600    }
601
602    /// Adapt pool size based on access pattern prediction
603    fn adapt_by_prediction(&self) -> Result<()> {
604        let patterns = self.access_patterns.lock().expect("lock should not be poisoned");
605
606        // Analyze access patterns to predict future needs
607        let mut total_recent_accesses = 0;
608        let recent_window = Duration::from_secs(60);
609        let now = Instant::now();
610
611        for timestamps in patterns.values() {
612            total_recent_accesses +=
613                timestamps.iter().filter(|t| now.duration_since(**t) < recent_window).count();
614        }
615
616        drop(patterns);
617
618        // Adjust based on activity level
619        let mut dynamic_max = self.dynamic_max_size.lock().expect("lock should not be poisoned");
620
621        if total_recent_accesses > 1000 {
622            // High activity: increase pool
623            let new_size = (*dynamic_max as f64 * 1.15) as usize;
624            *dynamic_max = new_size.min(self.config.max_pool_size);
625        } else if total_recent_accesses < 100 {
626            // Low activity: decrease pool
627            let new_size = (*dynamic_max as f64 * 0.9) as usize;
628            *dynamic_max = new_size.max(self.config.min_pool_size);
629        }
630
631        Ok(())
632    }
633
634    /// Defragment the pool by reorganizing entries
635    fn defragment_pool(&self) -> Result<()> {
636        // Simplified defragmentation: consolidate shape groups
637        let mut pool = self.pool.write().expect("lock should not be poisoned");
638
639        for entries in pool.values_mut() {
640            // Sort entries by access count (most accessed first)
641            entries.sort_by_key(|entry| std::cmp::Reverse(entry.access_count));
642        }
643
644        Ok(())
645    }
646
647    /// Get enhanced memory pool statistics
648    pub fn get_stats(&self) -> MemoryPoolStats {
649        let pool = self.pool.read().expect("lock should not be poisoned");
650        let current_size = *self.current_size.lock().expect("lock should not be poisoned");
651        let stats = self.statistics.lock().expect("lock should not be poisoned");
652        let dynamic_max = *self.dynamic_max_size.lock().expect("lock should not be poisoned");
653
654        let total_tensors = pool.values().map(|v| v.len()).sum();
655        let total_shapes = pool.len();
656
657        MemoryPoolStats {
658            total_tensors,
659            total_shapes,
660            current_size_bytes: current_size,
661            max_size_bytes: self.config.max_pool_size,
662            dynamic_max_size_bytes: dynamic_max,
663            utilization: current_size as f64 / dynamic_max as f64,
664            hit_rate: stats.hit_rate(),
665            miss_rate: stats.miss_rate(),
666            total_requests: stats.total_requests,
667            cache_hits: stats.cache_hits,
668            cache_misses: stats.cache_misses,
669            total_evictions: stats.total_evictions,
670            peak_memory_usage_bytes: stats.peak_memory_usage,
671            eviction_policy: self.config.eviction_policy,
672            adaptive_strategy: self.config.adaptive_strategy,
673        }
674    }
675
676    /// Reset statistics counters
677    pub fn reset_statistics(&self) {
678        let mut stats = self.statistics.lock().expect("Lock poisoned");
679        *stats = PoolStatistics::default();
680    }
681
682    /// Get current hit rate
683    pub fn hit_rate(&self) -> f64 {
684        let stats = self.statistics.lock().expect("lock should not be poisoned");
685        stats.hit_rate()
686    }
687
    /// Get current eviction policy
    ///
    /// Fixed at construction via `MemoryConfig`; it does not change at runtime.
    pub fn eviction_policy(&self) -> MemoryEvictionPolicy {
        self.config.eviction_policy
    }

    /// Get current adaptive strategy
    ///
    /// Fixed at construction via `MemoryConfig`; it does not change at runtime.
    pub fn adaptive_strategy(&self) -> AdaptiveStrategy {
        self.config.adaptive_strategy
    }
697
698    /// Get predicted shapes based on access patterns
699    pub fn get_predicted_shapes(&self, window: Duration) -> Vec<Vec<usize>> {
700        let patterns = self.access_patterns.lock().expect("lock should not be poisoned");
701        let now = Instant::now();
702
703        let mut frequent_shapes: Vec<(Vec<usize>, usize)> = patterns
704            .iter()
705            .map(|(shape, timestamps)| {
706                let count = timestamps.iter().filter(|t| now.duration_since(**t) < window).count();
707                (shape.clone(), count)
708            })
709            .filter(|(_, count)| *count > 0)
710            .collect();
711
712        frequent_shapes.sort_by_key(|item| std::cmp::Reverse(item.1));
713        frequent_shapes.into_iter().map(|(shape, _)| shape).collect()
714    }
715}
716
/// Enhanced statistics for memory pool
///
/// Read-only snapshot produced by `TensorMemoryPool::get_stats`.
#[derive(Debug, Clone)]
pub struct MemoryPoolStats {
    /// Total tensors currently in pool
    pub total_tensors: usize,
    /// Number of different tensor shapes in pool
    pub total_shapes: usize,
    /// Current memory usage in bytes
    pub current_size_bytes: usize,
    /// Maximum configured pool size in bytes
    pub max_size_bytes: usize,
    /// Current dynamic maximum size (for adaptive strategies)
    pub dynamic_max_size_bytes: usize,
    /// Pool utilization (0.0 to 1.0+; may exceed 1.0 when over budget)
    pub utilization: f64,
    /// Cache hit rate (0.0 to 1.0)
    pub hit_rate: f64,
    /// Cache miss rate (0.0 to 1.0)
    pub miss_rate: f64,
    /// Total number of tensor requests
    pub total_requests: usize,
    /// Number of cache hits
    pub cache_hits: usize,
    /// Number of cache misses
    pub cache_misses: usize,
    /// Total number of evictions
    pub total_evictions: usize,
    /// Peak memory usage observed (bytes)
    pub peak_memory_usage_bytes: usize,
    /// Current eviction policy
    pub eviction_policy: MemoryEvictionPolicy,
    /// Current adaptive strategy
    pub adaptive_strategy: AdaptiveStrategy,
}
751
/// Memory mapped tensor for large model weights
///
/// NOTE(review): despite the name, the current implementation keeps a plain
/// `File` handle and `load()` reads the whole file with standard I/O rather
/// than establishing an OS-level memory mapping.
pub struct MemoryMappedTensor {
    /// File path for the memory mapped data
    file_path: String,
    /// Shape of the tensor
    shape: Vec<usize>,
    /// Data type
    dtype: crate::tensor::DType,
    /// File handle for memory mapped data (held open for the tensor's lifetime)
    _file: Option<File>,
    /// Size of the file in bytes
    file_size: u64,
}
765
766impl MemoryMappedTensor {
767    /// Create a new memory mapped tensor
768    pub fn new(file_path: String, shape: Vec<usize>, dtype: crate::tensor::DType) -> Result<Self> {
769        // Open the file for reading
770        let mut file = File::open(&file_path).map_err(|e| {
771            TrustformersError::tensor_op_error(
772                &format!("Failed to open file for memory mapping: {}", e),
773                "mmap_new",
774            )
775        })?;
776
777        // Get file size
778        let file_size = file.seek(SeekFrom::End(0)).map_err(|e| {
779            TrustformersError::tensor_op_error(
780                &format!("Failed to get file size: {}", e),
781                "mmap_new",
782            )
783        })?;
784
785        // Verify file size matches tensor size
786        let element_size = dtype.size_in_bytes();
787        let total_elements: usize = shape.iter().product();
788        let expected_size = total_elements * element_size;
789
790        if file_size != expected_size as u64 {
791            return Err(TrustformersError::tensor_op_error(
792                &format!(
793                    "File size {} doesn't match expected tensor size {}",
794                    file_size, expected_size
795                ),
796                "mmap_new",
797            ));
798        }
799
800        Ok(Self {
801            file_path,
802            shape,
803            dtype,
804            _file: Some(file),
805            file_size,
806        })
807    }
808
809    /// Load the tensor data (lazy loading)
810    pub fn load(&self) -> Result<Tensor> {
811        // Read the entire file content
812        let mut file = File::open(&self.file_path).map_err(|e| {
813            TrustformersError::tensor_op_error(
814                &format!("Failed to open file for reading: {}", e),
815                "mmap_load",
816            )
817        })?;
818
819        let mut buffer = vec![0u8; self.file_size as usize];
820        file.read_exact(&mut buffer).map_err(|e| {
821            TrustformersError::tensor_op_error(
822                &format!("Failed to read file data: {}", e),
823                "mmap_load",
824            )
825        })?;
826
827        // Convert bytes to appropriate tensor type
828        match self.dtype {
829            crate::tensor::DType::F32 => {
830                let float_data = buffer
831                    .chunks_exact(4)
832                    .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
833                    .collect::<Vec<f32>>();
834                Tensor::from_slice(&float_data, &self.shape)
835            },
836            crate::tensor::DType::F64 => {
837                let float_data = buffer
838                    .chunks_exact(8)
839                    .map(|chunk| {
840                        f64::from_le_bytes([
841                            chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
842                            chunk[7],
843                        ])
844                    })
845                    .collect::<Vec<f64>>();
846                Tensor::from_slice_f64(&float_data, &self.shape)
847            },
848            crate::tensor::DType::I64 => {
849                let int_data = buffer
850                    .chunks_exact(8)
851                    .map(|chunk| {
852                        i64::from_le_bytes([
853                            chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6],
854                            chunk[7],
855                        ])
856                    })
857                    .collect::<Vec<i64>>();
858                Tensor::from_slice_i64(&int_data, &self.shape)
859            },
860            crate::tensor::DType::I32 => {
861                let int_data = buffer
862                    .chunks_exact(4)
863                    .map(|chunk| i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
864                    .collect::<Vec<i32>>();
865                Tensor::from_slice_i32(&int_data, &self.shape)
866            },
867            _ => Err(TrustformersError::tensor_op_error(
868                "Unsupported dtype for memory mapped tensor",
869                "mmap_load",
870            )),
871        }
872    }
873
874    /// Get the shape of the tensor
875    pub fn shape(&self) -> &[usize] {
876        &self.shape
877    }
878
879    /// Get the file path
880    pub fn file_path(&self) -> &str {
881        &self.file_path
882    }
883}
884
/// Global memory manager instance; set at most once via
/// `init_memory_manager` and shared process-wide thereafter.
static MEMORY_MANAGER: std::sync::OnceLock<TensorMemoryPool> = std::sync::OnceLock::new();
887
888/// Initialize the global memory manager
889pub fn init_memory_manager(config: MemoryConfig) -> Result<()> {
890    let pool = TensorMemoryPool::new(config);
891    MEMORY_MANAGER.set(pool).map_err(|_| {
892        TrustformersError::invalid_input("Memory manager already initialized".to_string())
893    })?;
894    Ok(())
895}
896
/// Get the global memory manager, or `None` if `init_memory_manager`
/// has not been called yet.
pub fn get_memory_manager() -> Option<&'static TensorMemoryPool> {
    MEMORY_MANAGER.get()
}
901
902/// Convenience function to get a tensor from the global pool
903pub fn get_tensor(shape: &[usize], dtype: crate::tensor::DType) -> Result<Tensor> {
904    if let Some(manager) = get_memory_manager() {
905        manager.get_tensor(shape, dtype)
906    } else {
907        // Fallback to direct creation
908        match dtype {
909            crate::tensor::DType::F32 => Tensor::zeros(shape),
910            crate::tensor::DType::F64 => Tensor::zeros_f64(shape),
911            crate::tensor::DType::I64 => Tensor::zeros_i64(shape),
912            _ => Err(TrustformersError::tensor_op_error(
913                "Unsupported dtype",
914                "get_tensor",
915            )),
916        }
917    }
918}
919
920/// Convenience function to return a tensor to the global pool
921pub fn return_tensor(tensor: Tensor) -> Result<()> {
922    if let Some(manager) = get_memory_manager() {
923        manager.return_tensor(tensor)
924    } else {
925        Ok(()) // Just drop the tensor
926    }
927}
928
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_config_default() {
        let config = MemoryConfig::default();
        assert!(config.enable_memory_pool);
        assert!(config.enable_zero_copy);
        assert!(config.enable_mmap);
        assert_eq!(config.max_pool_size, 1024 * 1024 * 1024);
    }

    #[test]
    fn test_tensor_pool_creation() {
        let config = MemoryConfig::default();
        let pool = TensorMemoryPool::new(config);
        let stats = pool.get_stats();
        assert_eq!(stats.total_tensors, 0);
        assert_eq!(stats.current_size_bytes, 0);
    }

    #[test]
    fn test_tensor_pool_get_and_return() -> Result<()> {
        let config = MemoryConfig::default();
        let pool = TensorMemoryPool::new(config);

        // Get a tensor
        let shape = vec![2, 3];
        let tensor = pool.get_tensor(&shape, crate::tensor::DType::F32)?;
        assert_eq!(tensor.shape(), shape.as_slice());

        // Return it to pool
        pool.return_tensor(tensor)?;

        // Get it again (should come from pool)
        let tensor2 = pool.get_tensor(&shape, crate::tensor::DType::F32)?;
        assert_eq!(tensor2.shape(), shape.as_slice());

        Ok(())
    }

    #[test]
    fn test_zero_copy_tensor_view() -> Result<()> {
        let tensor = Arc::new(Tensor::ones(&[10])?);
        let view = TensorView::slice(tensor, 2, 8)?;
        assert_eq!(view.shape(), &[6]);

        let viewed_tensor = view.as_tensor()?;
        assert_eq!(viewed_tensor.shape(), &[6]);

        Ok(())
    }

    #[test]
    fn test_memory_mapped_tensor() -> Result<()> {
        use std::fs::File;
        use std::io::Write;

        // Use a unique path under the OS temp dir instead of a fixed name
        // in the working directory, so concurrent test invocations don't
        // collide and failures don't litter the repo checkout.
        let temp_path = std::env::temp_dir().join(format!(
            "trustformers_mmap_test_{}.bin",
            std::process::id()
        ));
        let temp_file = temp_path.to_string_lossy().into_owned();
        let data_size = 100 * 100 * std::mem::size_of::<f32>();
        let data: Vec<u8> = vec![0; data_size];

        {
            let mut file = File::create(&temp_file).map_err(|e| {
                TrustformersError::tensor_op_error(
                    &format!("Failed to create test file: {}", e),
                    "test_setup",
                )
            })?;
            file.write_all(&data).map_err(|e| {
                TrustformersError::tensor_op_error(
                    &format!("Failed to write test data: {}", e),
                    "test_setup",
                )
            })?;
        }

        let mmap_tensor = MemoryMappedTensor::new(
            temp_file.clone(),
            vec![100, 100],
            crate::tensor::DType::F32,
        )?;

        assert_eq!(mmap_tensor.shape(), &[100, 100]);
        assert_eq!(mmap_tensor.file_path(), temp_file.as_str());

        let loaded = mmap_tensor.load()?;
        assert_eq!(loaded.shape(), &[100, 100]);

        // Clean up
        std::fs::remove_file(&temp_file).ok();

        Ok(())
    }

    #[test]
    fn test_global_memory_manager() -> Result<()> {
        let config = MemoryConfig::default();
        init_memory_manager(config)?;

        let tensor = get_tensor(&[5, 5], crate::tensor::DType::F32)?;
        assert_eq!(tensor.shape(), [5, 5].as_slice());

        return_tensor(tensor)?;

        Ok(())
    }
}