torsh_tensor/memory_pool.rs

1// Framework infrastructure - components designed for future use
2#![allow(dead_code)]
3// Memory pooling for efficient tensor memory management with SciRS2 Memory Optimization
4
5use crate::{Tensor, TensorStorage};
6use std::alloc::{handle_alloc_error, Layout};
7use std::collections::{HashMap, VecDeque};
8use std::marker::PhantomData;
9use std::mem::{ManuallyDrop, MaybeUninit};
10use std::ptr::NonNull;
11use std::sync::{Arc, Mutex, Weak};
12use torsh_core::{device::DeviceType, dtype::TensorElement, error::Result};
13
14// ✅ SciRS2 Memory Optimization Features
15use scirs2_core::memory::GlobalBufferPool;
16use scirs2_core::memory::LeakDetector;
17// ✅ SciRS2 memory_efficient features - conditionally available
18
19// Fallback for when memory_efficient feature is not available
20#[cfg(not(feature = "memory_efficient"))]
21struct MemoryMappedArray<T> {
22    _phantom: PhantomData<T>,
23}
24
25#[cfg(not(feature = "memory_efficient"))]
26impl<T> MemoryMappedArray<T> {
27    fn new(_size: usize) -> Result<Self> {
28        Err(torsh_core::error::TorshError::General(
29            torsh_core::error::GeneralError::NotImplemented(
30                "MemoryMappedArray requires memory_efficient feature".to_string(),
31            ),
32        ))
33    }
34}
35
36// TODO: profile_section macro not available in scirs2_core yet
37// #[cfg(feature = "profiling")]
38// use scirs2_core::profiling::profile_section;
39
40/// Global memory pool for tensor allocations
41static MEMORY_POOL: std::sync::OnceLock<Arc<Mutex<GlobalMemoryPool>>> = std::sync::OnceLock::new();
42
43/// Initialize the global memory pool
44pub fn init_memory_pool() -> Arc<Mutex<GlobalMemoryPool>> {
45    let arc = MEMORY_POOL
46        .get_or_init(|| {
47            let pool = Arc::new(Mutex::new(GlobalMemoryPool::new()));
48            // Store the Weak reference back into the pool so acquire_uninit can use it
49            if let Ok(mut guard) = pool.lock() {
50                guard.self_weak = Some(Arc::downgrade(&pool));
51            }
52            pool
53        })
54        .clone();
55    arc
56}
57
58/// Get reference to the global memory pool
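///
/// # Example
///
/// A minimal usage sketch (illustrative only, so it is not compiled as a doctest):
///
/// ```ignore
/// let pool = get_memory_pool();
/// let hit_rate = pool.lock().expect("pool lock should not be poisoned").hit_rate();
/// println!("pool hit rate: {:.2}", hit_rate);
/// ```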
59pub fn get_memory_pool() -> Arc<Mutex<GlobalMemoryPool>> {
60    init_memory_pool()
61}
62
63// ─── RawEntry ────────────────────────────────────────────────────────────────
64
65/// An owned raw allocation stored in the pool's free-list.
66/// On `Drop` it deallocates the memory if it was not consumed.
67struct RawEntry {
68    ptr: NonNull<u8>,
69    capacity_bytes: usize,
70    layout: Layout,
71}
72
73/// SAFETY: `RawEntry` owns the raw pointer; transferring it to another thread is safe.
74unsafe impl Send for RawEntry {}
75
76impl Drop for RawEntry {
77    fn drop(&mut self) {
78        // SAFETY: ptr was allocated with this layout via `std::alloc::alloc`.
79        unsafe { std::alloc::dealloc(self.ptr.as_ptr(), self.layout) };
80    }
81}
82
83// ─── ReusedBuffer<T> ─────────────────────────────────────────────────────────
84
85/// A truly-pooled buffer: holds the **actual pooled allocation** without copying.
86///
87/// When dropped (or via `release_to_pool`), the buffer is returned to the global
88/// pool. Use `into_vec(len)` to take ownership as a `Vec<T>`.
89pub struct ReusedBuffer<T: 'static> {
90    ptr: NonNull<T>,
91    capacity: usize,
92    layout: Layout,
93    pool: Weak<Mutex<GlobalMemoryPool>>,
94}
95
96/// SAFETY: `ReusedBuffer<T>` owns a unique allocation; it is safe to send across threads
97/// when `T: Send`.
98unsafe impl<T: Send + 'static> Send for ReusedBuffer<T> {}
99
100impl<T: 'static> ReusedBuffer<T> {
101    /// Returns a mutable view of the buffer as uninitialized elements.
102    pub fn as_uninit_slice_mut(&mut self) -> &mut [MaybeUninit<T>] {
103        // SAFETY: ptr is valid for `capacity` elements; we have exclusive access via &mut self.
104        unsafe {
105            std::slice::from_raw_parts_mut(self.ptr.as_ptr() as *mut MaybeUninit<T>, self.capacity)
106        }
107    }
108
109    /// Capacity in elements (not bytes).
110    pub fn capacity(&self) -> usize {
111        self.capacity
112    }
113
114    /// Raw pointer access — primarily for tests to verify address identity.
115    pub fn as_ptr_raw(&self) -> *mut T {
116        self.ptr.as_ptr()
117    }
118
119    /// Consume `self` and transfer ownership of the allocation to a `Vec<T>`.
120    ///
121    /// The caller must guarantee `len <= self.capacity()` and that the first `len`
122    /// elements have been initialized.
123    ///
124    /// The `Vec` now owns the memory and will free it on drop; it is NOT returned
125    /// to the pool.
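    ///
    /// # Example
    ///
    /// A sketch of the intended pattern (marked `ignore`; it touches the global pool):
    ///
    /// ```ignore
    /// let mut buf = global_acquire_uninit::<u32>(4);
    /// for slot in buf.as_uninit_slice_mut() {
    ///     slot.write(7);
    /// }
    /// // All four elements are initialized, so `len = 4` satisfies the contract.
    /// let v: Vec<u32> = buf.into_vec(4);
    /// assert_eq!(v, vec![7, 7, 7, 7]);
    /// ```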
126    pub fn into_vec(self, len: usize) -> Vec<T> {
127        debug_assert!(len <= self.capacity, "len must not exceed capacity");
128        // Wrap self in ManuallyDrop so our Drop impl does not run.
129        let md = ManuallyDrop::new(self);
130        // SAFETY: ptr was allocated with the global allocator for `md.capacity` elements.
131        // `len` elements are initialized (caller contract). capacity matches.
132        unsafe { Vec::from_raw_parts(md.ptr.as_ptr(), len, md.capacity) }
133    }
134
135    /// Consume `self` and return the buffer to the pool.
136    ///
137    /// If the pool is gone (Arc was dropped), the allocation is freed instead.
138    pub fn release_to_pool(self) {
139        // Wrap in ManuallyDrop to prevent our Drop from running.
140        let md = ManuallyDrop::new(self);
141        let raw_entry = RawEntry {
142            ptr: NonNull::new(md.ptr.as_ptr() as *mut u8)
143                .expect("ReusedBuffer pointer is non-null by construction"),
144            capacity_bytes: md.capacity * std::mem::size_of::<T>(),
145            layout: md.layout,
146        };
147        if let Some(pool_arc) = md.pool.upgrade() {
148            if let Ok(mut guard) = pool_arc.lock() {
149                let type_id = std::any::TypeId::of::<T>();
150                let size_class = guard.find_size_class(raw_entry.capacity_bytes);
151                let pool_key = (type_id, size_class);
152                if let Some(bucket) = guard.pools.get_mut(&pool_key) {
153                    if bucket.available_buffers.len() < bucket.max_buffers {
154                        bucket.available_buffers.push_back(raw_entry);
155                        bucket.deallocations += 1;
156                        // ManuallyDrop prevents double-free: raw_entry is now owned by the bucket.
157                        return;
158                    }
159                }
160            }
161        }
162        // Pool unavailable or full — `raw_entry` drops here and frees memory via RawEntry::Drop.
163    }
164}
165
166impl<T: 'static> Drop for ReusedBuffer<T> {
167    fn drop(&mut self) {
168        // Reconstruct a RawEntry to trigger a properly-guarded dealloc-or-return.
169        // We cannot call release_to_pool(self) directly (consumes), so replicate logic.
170        let raw_entry = RawEntry {
171            ptr: NonNull::new(self.ptr.as_ptr() as *mut u8)
172                .expect("ReusedBuffer pointer is non-null by construction"),
173            capacity_bytes: self.capacity * std::mem::size_of::<T>(),
174            layout: self.layout,
175        };
176        if let Some(pool_arc) = self.pool.upgrade() {
177            if let Ok(mut guard) = pool_arc.lock() {
178                let type_id = std::any::TypeId::of::<T>();
179                let size_class = guard.find_size_class(raw_entry.capacity_bytes);
180                let pool_key = (type_id, size_class);
181                if let Some(bucket) = guard.pools.get_mut(&pool_key) {
182                    if bucket.available_buffers.len() < bucket.max_buffers {
                        // Pushing moves `raw_entry` into the bucket, so RawEntry::drop
                        // (and its dealloc) will not run for this entry here.
                        bucket.available_buffers.push_back(raw_entry);
191                        bucket.deallocations += 1;
192                        return;
193                    }
194                }
195            }
196        }
197        // raw_entry drops here → dealloc via RawEntry::Drop
198    }
199}
200
201// ─── GlobalMemoryPool ────────────────────────────────────────────────────────
202
203/// Enhanced global memory pool with SciRS2 memory optimization
204pub struct GlobalMemoryPool {
205    /// Pools organized by type ID and size class
206    pools: HashMap<(std::any::TypeId, usize), MemoryPool>,
207    /// Statistics for pool usage
208    stats: PoolStatistics,
209    /// Configuration settings
210    config: PoolConfig,
211    /// ✅ SciRS2 Global Buffer Pool integration
212    scirs2_pool: GlobalBufferPool,
213    /// ✅ SciRS2 Memory leak detector
214    leak_detector: LeakDetector,
215    /// Weak self-reference used to hand out pool handles to `ReusedBuffer`.
216    self_weak: Option<Weak<Mutex<GlobalMemoryPool>>>,
217    // ✅ SciRS2 Memory metrics collector (requires memory_efficient feature)
218    // metrics_collector: MemoryMetricsCollector,
219    // ✅ SciRS2 Adaptive chunking for large tensors (requires memory_efficient feature)
220    // adaptive_chunking: AdaptiveChunking,
221}
222
223/// Memory pool for specific data type and size class
224#[derive(Debug)]
225struct MemoryPool {
226    /// Available buffers ready for reuse (raw allocations)
227    available_buffers: VecDeque<RawEntry>,
228    /// Size class this pool manages (in bytes)
229    #[allow(dead_code)]
230    size_class: usize,
231    /// Maximum number of buffers to keep
232    max_buffers: usize,
233    /// Statistics for this pool
234    allocations: usize,
235    reuses: usize,
236    deallocations: usize,
237}
238
239/// Configuration for memory pool behavior
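///
/// # Example
///
/// Building a custom config by hand (illustrative only; the global pool currently
/// uses `PoolConfig::default()` and does not expose a setter for it):
///
/// ```ignore
/// let config = PoolConfig {
///     max_buffers_per_class: 8,
///     max_total_memory: 256 * 1024 * 1024, // 256MB pooling budget
///     auto_cleanup: true,
///     cleanup_threshold: 0.75,
///     size_classes: (10..25).map(|exp| 1usize << exp).collect(), // 1KB..16MB
/// };
/// ```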
240#[derive(Debug, Clone)]
241pub struct PoolConfig {
242    /// Maximum number of buffers per size class
243    pub max_buffers_per_class: usize,
244    /// Maximum total memory to use for pooling (in bytes)
245    pub max_total_memory: usize,
246    /// Enable automatic pool cleanup
247    pub auto_cleanup: bool,
248    /// Cleanup threshold (trigger cleanup when usage exceeds this ratio)
249    pub cleanup_threshold: f64,
250    /// Size classes (in bytes) - powers of 2 for efficient alignment
251    pub size_classes: Vec<usize>,
252}
253
254/// Statistics for memory pool usage
255#[derive(Debug, Default, Clone)]
256pub struct PoolStatistics {
257    /// Total number of allocations served
258    pub total_allocations: usize,
259    /// Number of allocations served from pool (reused)
260    pub pool_hits: usize,
261    /// Number of allocations that required new memory
262    pub pool_misses: usize,
263    /// Total bytes allocated
264    pub total_bytes_allocated: usize,
265    /// Total bytes currently in pools
266    pub bytes_in_pools: usize,
267    /// Peak memory usage
268    pub peak_memory_usage: usize,
269}
270
271/// A pooled tensor that automatically returns memory to pool when dropped
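///
/// # Example
///
/// Typical scratch-buffer usage (illustrative sketch, marked `ignore`):
///
/// ```ignore
/// let scratch = PooledTensor::<f32>::zeros(&[64, 64], DeviceType::Cpu)?;
/// let sum: f32 = scratch.tensor().to_vec()?.iter().sum();
/// // Dropping `scratch` triggers the pool-aware Drop impl.
/// drop(scratch);
/// ```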
272#[derive(Debug)]
273pub struct PooledTensor<T: TensorElement + Default> {
274    tensor: Tensor<T>,
275    pool_key: Option<(std::any::TypeId, usize)>,
276    _phantom: PhantomData<T>,
277}
278
279impl Default for PoolConfig {
280    fn default() -> Self {
281        // Generate size classes as powers of 2 from 1KB to 1GB
282        let size_classes = (10..31) // 2^10 to 2^30 bytes (1KB to 1GB)
283            .map(|exp| 1 << exp)
284            .collect();
285
286        Self {
287            max_buffers_per_class: 16,
288            max_total_memory: 1024 * 1024 * 1024, // 1GB
289            auto_cleanup: true,
290            cleanup_threshold: 0.8,
291            size_classes,
292        }
293    }
294}
295
296impl Default for GlobalMemoryPool {
297    fn default() -> Self {
298        Self::new()
299    }
300}
301
302impl GlobalMemoryPool {
303    /// Create a new enhanced global memory pool with SciRS2 integration
304    pub fn new() -> Self {
305        #[cfg(feature = "profiling")]
306        {
307            // let _profile = profile_section!("memory_pool_init");
308        }
309        Self {
310            pools: HashMap::new(),
311            stats: PoolStatistics::default(),
312            config: PoolConfig::default(),
313            // ✅ SciRS2 Memory Management Integration
314            scirs2_pool: GlobalBufferPool::new(),
315            leak_detector: LeakDetector::new(Default::default())
316                .unwrap_or_else(|_| panic!("Failed to initialize leak detector")),
317            self_weak: None,
318            // metrics_collector: MemoryMetricsCollector::new(),
319            // adaptive_chunking: AdaptiveChunking::new(),
320        }
321    }
322
323    /// ✅ SciRS2 Memory-Efficient Tensor Creation for Large Tensors
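    ///
    /// Strategy selection is driven by the total byte size. For example, a
    /// `[4096, 4096]` `f32` tensor is 64MB and therefore takes the chunked path
    /// (sketch only, not compiled as a doctest):
    ///
    /// ```ignore
    /// let mut pool = GlobalMemoryPool::new();
    /// let t = pool.create_large_tensor::<f32>(&[4096, 4096], DeviceType::Cpu)?;
    /// assert_eq!(t.numel(), 4096 * 4096);
    /// ```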
324    pub fn create_large_tensor<T: TensorElement>(
325        &mut self,
326        shape: &[usize],
327        device: DeviceType,
328    ) -> Result<Tensor<T>>
329    where
330        T: Clone + Default,
331    {
332        #[cfg(feature = "profiling")]
333        {
334            // let _profile = profile_section!("create_large_tensor");
335        }
336        let total_elements: usize = shape.iter().product();
337        let total_bytes = total_elements * std::mem::size_of::<T>();
338
339        // ✅ Use SciRS2 memory-efficient strategies based on tensor size
340        if total_bytes > 100 * 1024 * 1024 {
341            // >100MB: Use memory-mapped arrays for very large tensors
342            self.create_memory_mapped_tensor(shape, device)
343        } else if total_bytes > 10 * 1024 * 1024 {
344            // >10MB: Use chunked arrays for large tensors
345            self.create_chunked_tensor(shape, device)
346        } else if total_bytes > 1024 * 1024 {
347            // >1MB: Use SciRS2 buffer pool
348            self.create_pooled_tensor(shape, device)
349        } else {
350            // Small tensors: Use standard allocation
351            Tensor::zeros(shape, device)
352        }
353    }
354
355    /// Create memory-mapped tensor for very large data (>100MB)
356    fn create_memory_mapped_tensor<T: TensorElement>(
357        &mut self,
358        shape: &[usize],
359        device: DeviceType,
360    ) -> Result<Tensor<T>>
361    where
362        T: Clone + Default,
363    {
364        let total_elements: usize = shape.iter().product();
365
366        // ✅ SciRS2 Memory-Mapped Array for disk-backed storage
367        // TODO: Fix MemoryMappedArray::new() call - requires 4 arguments:
368        // MemoryMappedArray::new(data: Option<&Array>, path: &Path, mode: AccessMode, shape)
369        // let _mmap_array = MemoryMappedArray::<T>::new(None, path, AccessMode::ReadWrite, total_elements)?;
370
371        // Track memory usage
372        // Metrics collection temporarily disabled - feature not available
373        // self.metrics_collector.record_large_allocation(total_elements * std::mem::size_of::<T>());
374
375        // TODO: Use _mmap_array.as_slice() when full memory mapping is available
376        // For now, create regular tensor as fallback
377        let data = vec![T::default(); total_elements];
378        Tensor::from_data(data, shape.to_vec(), device)
379    }
380
381    /// Create chunked tensor for large data (10MB-100MB)
382    fn create_chunked_tensor<T: TensorElement>(
383        &mut self,
384        shape: &[usize],
385        device: DeviceType,
386    ) -> Result<Tensor<T>>
387    where
388        T: Clone + Default,
389    {
390        let total_elements: usize = shape.iter().product();
391
392        // Calculate optimal chunk size based on cache size (1MB chunks by default)
393        let chunk_size = (1024 * 1024) / std::mem::size_of::<T>().max(1); // 1MB chunks
394        let num_chunks = (total_elements + chunk_size - 1) / chunk_size;
395
396        // Creating chunked tensor with calculated parameters
397        let _ = (total_elements, num_chunks, chunk_size); // Use parameters
398
399        // Fallback: Create regular array since ChunkedArray is not available
400        let data = vec![T::default(); total_elements];
401
402        // Track chunked allocation
403        // Metrics collection temporarily disabled - feature not available
404        // self.metrics_collector.record_chunked_allocation(total_elements * std::mem::size_of::<T>(), chunk_size);
405
406        Tensor::from_data(data, shape.to_vec(), device)
407    }
408
409    /// Create pooled tensor using SciRS2 buffer pool (1MB-10MB)
410    fn create_pooled_tensor<T: TensorElement>(
411        &mut self,
412        shape: &[usize],
413        device: DeviceType,
414    ) -> Result<Tensor<T>>
415    where
416        T: Clone + Default,
417    {
418        let total_elements: usize = shape.iter().product();
419        let buffer_size = total_elements * std::mem::size_of::<T>();
420
421        // Log buffer pool allocation
422        let _ = (buffer_size, total_elements); // Use parameters
423
424        // Fallback: Create regular buffer since GlobalBufferPool methods not available
425        let data = vec![T::default(); total_elements];
426
427        // Track pool usage
428        self.stats.pool_hits += 1;
429        // Metrics collection temporarily disabled - feature not available
430        // self.metrics_collector.record_pool_allocation(buffer_size);
431
432        Tensor::from_data(data, shape.to_vec(), device)
433    }
434
435    /// ✅ SciRS2 Lazy Tensor Creation - Defer allocation until needed
436    pub fn create_lazy_tensor<T: TensorElement>(
437        &mut self,
438        shape: &[usize],
439        device: DeviceType,
440    ) -> Result<Tensor<T>>
441    where
442        T: Clone + Default,
443    {
444        #[cfg(feature = "profiling")]
445        {
446            // let _profile = profile_section!("create_lazy_tensor");
447        }
448        let total_elements: usize = shape.iter().product();
449
450        // Fallback: Create regular array since LazyArray is not available
451        let data = vec![T::default(); total_elements];
452
453        // Metrics collection temporarily disabled - feature not available
454        // self.metrics_collector.record_lazy_allocation(total_elements * std::mem::size_of::<T>());
455
456        Tensor::from_data(data, shape.to_vec(), device)
457    }
458
    /// ✅ SciRS2 Zero-Copy Operations for efficient tensor views
    ///
    /// The caller must ensure `offset + shape.iter().product::<usize>()` does not exceed
    /// the number of elements in `source`. Note that the current fallback implementation
    /// copies the selected range rather than aliasing it.
460    pub fn create_zero_copy_view<T: TensorElement>(
461        &self,
462        source: &Tensor<T>,
463        offset: usize,
464        shape: &[usize],
465    ) -> Result<Tensor<T>>
466    where
467        T: Clone,
468    {
469        #[cfg(feature = "profiling")]
470        {
471            // let _profile = profile_section!("zero_copy_view");
472        }
473
474        // Fallback: Create data copy since ZeroCopyOps is not available
475        let source_data = source.data()?;
476        let view_data = source_data[offset..offset + shape.iter().product::<usize>()].to_vec();
477
478        Tensor::from_data(view_data, shape.to_vec(), source.device())
479    }
480
481    /// Get memory usage statistics enhanced with SciRS2 metrics
482    pub fn get_enhanced_stats(&self) -> PoolStatistics {
        // Simplified: return basic stats for now; enhanced SciRS2 metrics can be added later.
484        self.stats.clone()
485    }
486
487    /// Acquire a truly-pooled, uninitialized buffer for `count` elements of type `T`.
488    ///
489    /// This is the low-level method. Prefer the free function [`global_acquire_uninit`].
490    ///
491    /// The returned [`ReusedBuffer<T>`] holds the **actual pooled allocation** — no copy
492    /// is made. Callers must initialize all elements before reading them.
493    pub fn acquire_uninit<T: 'static>(&mut self, count: usize) -> ReusedBuffer<T> {
494        let element_size = std::mem::size_of::<T>();
495        let element_align = std::mem::align_of::<T>();
496        let size_bytes = count * element_size;
497        let size_class = self.find_size_class(size_bytes);
498        let type_id = std::any::TypeId::of::<T>();
499        let pool_key = (type_id, size_class);
500
501        let layout = Layout::from_size_align(size_bytes.max(1), element_align)
502            .expect("size and align are valid for T");
503
504        // Update statistics
505        self.stats.total_allocations += 1;
506        self.stats.total_bytes_allocated += size_bytes;
507
508        // Try pool hit
509        if let Some(bucket) = self.pools.get_mut(&pool_key) {
510            // Scan for a compatible entry (may be larger than requested)
511            let mut found_idx: Option<usize> = None;
512            for (i, entry) in bucket.available_buffers.iter().enumerate() {
513                if entry.capacity_bytes >= size_bytes && entry.layout.align() >= element_align {
514                    found_idx = Some(i);
515                    break;
516                }
517            }
518            if let Some(idx) = found_idx {
519                let raw_entry = bucket
520                    .available_buffers
521                    .remove(idx)
522                    .expect("index was valid moments ago");
523                self.stats.pool_hits += 1;
524                bucket.reuses += 1;
525
526                let ptr = NonNull::new(raw_entry.ptr.as_ptr() as *mut T)
527                    .expect("RawEntry pointer is non-null by construction");
528                // The raw_entry must not drop (its ptr is now owned by ReusedBuffer)
529                let actual_capacity = raw_entry.capacity_bytes / element_size;
530                let entry_layout = raw_entry.layout;
531                std::mem::forget(raw_entry);
532
533                let weak = self.self_weak.clone().unwrap_or_else(Weak::new);
534                return ReusedBuffer {
535                    ptr,
536                    capacity: actual_capacity,
537                    layout: entry_layout,
538                    pool: weak,
539                };
540            }
541        }
542
543        // Pool miss — fresh allocation
544        self.stats.pool_misses += 1;
545
546        // Create the pool bucket if it doesn't exist yet
547        self.pools.entry(pool_key).or_insert_with(|| MemoryPool {
548            available_buffers: VecDeque::new(),
549            size_class,
550            max_buffers: self.config.max_buffers_per_class,
551            allocations: 0,
552            reuses: 0,
553            deallocations: 0,
554        });
555
556        if let Some(bucket) = self.pools.get_mut(&pool_key) {
557            bucket.allocations += 1;
558        }
559
560        // SAFETY: layout is non-zero (we used .max(1) above).
561        let raw_ptr = unsafe { std::alloc::alloc(layout) };
562        let ptr = NonNull::new(raw_ptr as *mut T).unwrap_or_else(|| handle_alloc_error(layout));
563
564        let weak = self.self_weak.clone().unwrap_or_else(Weak::new);
565        ReusedBuffer {
566            ptr,
567            capacity: count,
568            layout,
569            pool: weak,
570        }
571    }
572
573    /// Allocate memory for tensor elements.
574    ///
575    /// Returns a zero-initialized `Vec<T>`.
576    ///
577    /// # Deprecation
578    /// Use [`global_acquire_uninit`] for zero-copy buffer reuse.
579    #[deprecated = "Use global_acquire_uninit instead for zero-copy buffer reuse"]
580    pub fn allocate<T: TensorElement + Default + 'static>(&mut self, count: usize) -> Vec<T> {
581        let mut buf = self.acquire_uninit::<T>(count);
582        // Initialize all elements to Default
583        for slot in buf.as_uninit_slice_mut() {
584            slot.write(T::default());
585        }
586        buf.into_vec(count)
587    }
588
589    /// Find appropriate size class for allocation
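    ///
    /// # Example
    ///
    /// With the default size classes (powers of two from 1KB to 1GB), a 1500-byte
    /// request rounds up to the 2KB class (this sketch is not run as a doctest):
    ///
    /// ```ignore
    /// let pool = GlobalMemoryPool::new();
    /// assert_eq!(pool.find_size_class(1024), 0); // exactly 1KB -> first class
    /// assert_eq!(pool.find_size_class(1500), 1); // rounds up to the 2KB class
    /// ```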
590    pub fn find_size_class(&self, size_bytes: usize) -> usize {
591        self.config
592            .size_classes
593            .iter()
594            .position(|&class_size| size_bytes <= class_size)
595            .unwrap_or(self.config.size_classes.len() - 1)
596    }
597
598    /// Deallocate memory by dropping it (legacy; buffer is not returned to pool).
599    ///
600    /// The `deallocate` method previously attempted to store the allocation in the pool
601    /// using an unsafe `Vec<u8>` transmutation that could not reconstruct the correct
602    /// layout. Now the Vec is simply dropped. Use [`ReusedBuffer::release_to_pool`] for
603    /// true pool return.
604    pub fn deallocate<T: 'static>(&mut self, data: Vec<T>) {
605        // Just drop `data` — memory is freed by Vec's Drop.
606        drop(data);
607    }
608
609    /// Clear all pools
610    pub fn clear(&mut self) {
611        self.pools.clear();
612        self.stats = PoolStatistics::default();
613    }
614
615    /// Get basic statistics
616    pub fn get_statistics(&self) -> &PoolStatistics {
617        &self.stats
618    }
619
620    /// Get cache hit rate
621    pub fn hit_rate(&self) -> f64 {
622        if self.stats.total_allocations == 0 {
623            0.0
624        } else {
625            self.stats.pool_hits as f64 / self.stats.total_allocations as f64
626        }
627    }
628
629    /// Cleanup unused memory
630    pub fn cleanup(&mut self) {
631        if self.config.auto_cleanup {
632            let threshold_bytes =
633                (self.config.max_total_memory as f64 * self.config.cleanup_threshold) as usize;
634            if self.stats.total_bytes_allocated > threshold_bytes {
                // Drop the cached free buffers so their memory is actually released
                // (RawEntry::drop deallocates them); keep buckets for their statistics.
                for pool in self.pools.values_mut() {
                    pool.available_buffers.clear();
                }
637            }
638        }
639    }
640}
641
642impl std::fmt::Debug for GlobalMemoryPool {
643    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
644        f.debug_struct("GlobalMemoryPool")
645            .field("pools", &self.pools)
646            .field("stats", &self.stats)
647            .field("config", &self.config)
648            .field("scirs2_pool", &"<GlobalBufferPool>")
649            .field("leak_detector", &"<LeakDetector>")
650            .finish()
651    }
652}
653
654// ─── Debug impl for MemoryPool (needs RawEntry to be Debug) ──────────────────
655
656impl std::fmt::Debug for RawEntry {
657    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
658        f.debug_struct("RawEntry")
659            .field("capacity_bytes", &self.capacity_bytes)
660            .finish()
661    }
662}
663
664// ─── Public free function ─────────────────────────────────────────────────────
665
666/// Acquire an uninitialized buffer from the global memory pool.
667///
668/// This is the **preferred API** for zero-copy buffer reuse. The returned
669/// [`ReusedBuffer<T>`] holds the actual pooled allocation — no copying occurs.
670///
671/// # Safety contract on the caller
672/// Elements must be initialized before being read. Use [`ReusedBuffer::as_uninit_slice_mut`]
673/// to write values, then either:
674/// - call [`ReusedBuffer::into_vec`] to obtain an owning `Vec`, or
675/// - call [`ReusedBuffer::release_to_pool`] to return the buffer.
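///
/// # Example
///
/// A sketch of both completion paths (marked `ignore` because it mutates the
/// global pool):
///
/// ```ignore
/// // Path 1: initialize the elements and keep the memory as a Vec.
/// let mut buf = global_acquire_uninit::<f32>(8);
/// for slot in buf.as_uninit_slice_mut() {
///     slot.write(0.5);
/// }
/// let owned: Vec<f32> = buf.into_vec(8);
///
/// // Path 2: hand the allocation back so a later acquire can reuse it.
/// let scratch = global_acquire_uninit::<f32>(8);
/// scratch.release_to_pool();
/// ```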
676pub fn global_acquire_uninit<T: 'static>(count: usize) -> ReusedBuffer<T> {
677    let pool_arc = get_memory_pool();
678    let mut guard = pool_arc
679        .lock()
680        .expect("global memory pool lock should not be poisoned");
681    guard.acquire_uninit::<T>(count)
682}
683
684/// Enhanced memory statistics with SciRS2 integration
685/// Currently simplified to use basic PoolStatistics
686/// Future versions will include full SciRS2 memory metrics integration
687pub type EnhancedMemoryStats = PoolStatistics;
688
689/// ✅ Enhanced Tensor creation interface with SciRS2 memory optimization
690impl<T: TensorElement> Tensor<T> {
691    /// Create memory-efficient tensor with automatic strategy selection
692    pub fn create_efficient(shape: &[usize], device: DeviceType) -> Result<Self>
693    where
694        T: Clone + Default,
695    {
696        let binding = get_memory_pool();
697        let mut pool = binding.lock().expect("lock should not be poisoned");
698        pool.create_large_tensor::<T>(shape, device)
699    }
700
701    /// Create lazy tensor that defers allocation until first access
702    pub fn lazy(shape: &[usize], device: DeviceType) -> Result<Self>
703    where
704        T: Clone + Default,
705    {
706        let binding = get_memory_pool();
707        let mut pool = binding.lock().expect("lock should not be poisoned");
708        pool.create_lazy_tensor::<T>(shape, device)
709    }
710
    // Create zero-copy view of existing tensor (disabled due to conflict with shape_ops)
712    // pub fn view(&self, offset: usize, new_shape: &[usize]) -> Result<Self>
713    // where
714    //     T: Clone,
715    // {
716    //     let pool = get_memory_pool().lock().expect("lock should not be poisoned");
717    //     pool.create_zero_copy_view(self, offset, new_shape)
718    // }
719
720    /// ✅ SciRS2 Memory-Mapped Tensor for very large datasets
721    pub fn memory_mapped(shape: &[usize], device: DeviceType) -> Result<Self>
722    where
723        T: Clone + Default,
724    {
725        #[cfg(feature = "profiling")]
726        {
727            // let _profile = profile_section!("memory_mapped_tensor");
728        }
729
730        // Fallback: Create regular tensor since memory mapping requires additional implementation
731        let total_elements: usize = shape.iter().product();
732        let data = vec![T::default(); total_elements];
733        Self::from_data(data, shape.to_vec(), device)
734    }
735
736    /// ✅ SciRS2 Chunked Tensor for cache-efficient large data processing
737    ///
738    /// Creates a tensor optimized for chunk-wise processing with the specified chunk size.
739    /// This is useful for large tensors that benefit from cache-friendly access patterns.
740    ///
741    /// # Arguments
742    /// * `shape` - The shape of the tensor
743    /// * `chunk_size` - Preferred chunk size for processing (in elements)
744    /// * `device` - Device to allocate the tensor on
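    ///
    /// # Example
    ///
    /// Illustrative sketch (not run as a doctest):
    ///
    /// ```ignore
    /// // A chunk_size of 0 requests the default (roughly 64KB worth of elements).
    /// let t = Tensor::<f32>::chunked(&[1024, 1024], 0, DeviceType::Cpu)?;
    /// ```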
745    pub fn chunked(shape: &[usize], chunk_size: usize, device: DeviceType) -> Result<Self>
746    where
747        T: Clone + Default,
748    {
749        #[cfg(feature = "profiling")]
750        {
751            // let _profile = profile_section!("chunked_tensor");
752        }
753        let total_elements: usize = shape.iter().product();
754
755        // Validate chunk size
756        let effective_chunk_size = if chunk_size == 0 {
757            // Default to 64KB chunks for cache efficiency
758            let default_chunk_bytes = 64 * 1024;
759            let element_size = std::mem::size_of::<T>();
760            (default_chunk_bytes / element_size.max(1)).max(1)
761        } else {
762            chunk_size
763        };
764
        // Align chunk size to cache line boundaries (64 bytes typically). Clamp to at
        // least one element so the divisor below cannot be zero for very wide types.
        let cache_line_elements = (64 / std::mem::size_of::<T>().max(1)).max(1);
        let aligned_chunk_size = ((effective_chunk_size + cache_line_elements - 1)
            / cache_line_elements)
            * cache_line_elements;
770
771        // Log chunk configuration for debugging
772        let _ = (total_elements, effective_chunk_size, aligned_chunk_size); // Use parameters
773
774        // Create the tensor with default values
775        let data = vec![T::default(); total_elements];
776
        // Note: aligned_chunk_size is only computed (and logged above) for now; once
        // tensor metadata can carry it, chunk-aware operations such as process_chunked
        // can use it for better cache locality.
779        Self::from_data(data, shape.to_vec(), device)
780    }
781
782    /// ✅ SciRS2 Disk-Backed Tensor for datasets larger than RAM
783    ///
784    /// Creates a tensor that can be backed by disk storage for large datasets.
785    /// This is useful when working with datasets larger than available RAM.
786    ///
787    /// # Arguments
788    /// * `shape` - The shape of the tensor
789    /// * `device` - Device to allocate the tensor on
790    /// * `file_path` - Optional file path for persistent storage. If None, uses temporary file.
791    ///
792    /// # Note
793    /// Current implementation creates an in-memory tensor. Full memory-mapped file support
794    /// requires the `mmap-support` feature and will be used automatically when available.
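    ///
    /// # Example
    ///
    /// Illustrative sketch with a hypothetical backing path (not run as a doctest):
    ///
    /// ```ignore
    /// let t = Tensor::<f32>::disk_backed(&[8192, 8192], DeviceType::Cpu, Some("/tmp/weights.bin"))?;
    /// ```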
795    pub fn disk_backed(shape: &[usize], device: DeviceType, file_path: Option<&str>) -> Result<Self>
796    where
797        T: Clone + Default,
798    {
799        #[cfg(feature = "profiling")]
800        {
801            // let _profile = profile_section!("disk_backed_tensor");
802        }
803        let total_elements: usize = shape.iter().product();
804
805        // Determine backing file path
806        let backing_path = if let Some(path) = file_path {
807            // Use provided path
808            std::path::PathBuf::from(path)
809        } else {
810            // Generate temporary file path
811            let temp_dir = std::env::temp_dir();
812            let timestamp = std::time::SystemTime::now()
813                .duration_since(std::time::UNIX_EPOCH)
814                .unwrap_or_default()
815                .as_secs();
816            temp_dir.join(format!(
817                "torsh_tensor_{}_{}.bin",
818                timestamp,
819                std::process::id()
820            ))
821        };
822
823        // Log intent for disk backing (actual implementation depends on features)
824        let _ = (total_elements, &backing_path); // Use parameters
825
826        // Create the tensor data in memory
827        // TODO: When mmap-support feature is enabled, use memory-mapped file at backing_path
828        let data = vec![T::default(); total_elements];
829
        // TODO: record `backing_path` as tensor metadata so the tensor can track its
        // backing store; with the current implementation the data lives purely in memory.
832        let tensor = Self::from_data(data, shape.to_vec(), device)?;
833
834        Ok(tensor)
835    }
836
837    /// Process tensor in memory-efficient chunks
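    ///
    /// # Example
    ///
    /// Summing a tensor in four-element chunks (illustrative sketch, marked `ignore`):
    ///
    /// ```ignore
    /// let t = Tensor::<f32>::create_efficient(&[16], DeviceType::Cpu)?;
    /// let partial_sums = t.process_chunked(4, |chunk| Ok(chunk.iter().sum::<f32>()))?;
    /// assert_eq!(partial_sums.len(), 4);
    /// ```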
838    pub fn process_chunked<F, R>(&self, chunk_size: usize, mut processor: F) -> Result<Vec<R>>
839    where
840        F: FnMut(&[T]) -> Result<R>,
841        T: Clone,
842    {
843        #[cfg(feature = "profiling")]
844        {
845            // let _profile = profile_section!("process_chunked");
846        }
847        let data = self.data()?;
848        let mut results = Vec::new();
849
        // Fallback: use a fixed chunk size since AdaptiveChunking is not available.
        // Clamp to at least 1 because `chunks()` panics on a zero chunk size.
        let effective_chunk_size = chunk_size.max(1);
852
853        for chunk in data.chunks(effective_chunk_size) {
854            results.push(processor(chunk)?);
855        }
856
857        Ok(results)
858    }
859}
860
861impl MemoryPool {
862    fn new(size_class: usize, max_buffers: usize) -> Self {
863        Self {
864            available_buffers: VecDeque::new(),
865            size_class,
866            max_buffers,
867            allocations: 0,
868            reuses: 0,
869            deallocations: 0,
870        }
871    }
872}
873
874impl<T: TensorElement + Copy + Default> PooledTensor<T> {
875    /// Create a new pooled tensor
876    pub fn new(shape: &[usize], device: DeviceType) -> Result<Self> {
877        let numel = shape.iter().product::<usize>();
878
879        // Allocate from pool
880        let pool = get_memory_pool();
881        let data = {
882            let mut pool_guard = pool.lock().expect("lock should not be poisoned");
883            #[allow(deprecated)]
884            pool_guard.allocate::<T>(numel)
885        };
886
887        let tensor = Tensor::from_data(data, shape.to_vec(), device)?;
888        let type_id = std::any::TypeId::of::<T>();
889        let size_class = {
890            let pool_guard = pool.lock().expect("lock should not be poisoned");
891            pool_guard.find_size_class(numel * std::mem::size_of::<T>())
892        };
893
894        Ok(Self {
895            tensor,
896            pool_key: Some((type_id, size_class)),
897            _phantom: PhantomData,
898        })
899    }
900
901    /// Create pooled zeros tensor
902    pub fn zeros(shape: &[usize], device: DeviceType) -> Result<Self> {
903        let mut pooled = Self::new(shape, device)?;
904        // Initialize with zeros
905        let numel = shape.iter().product::<usize>();
906        let data = vec![T::default(); numel];
907        pooled.tensor.storage = TensorStorage::create_optimal(data)?;
908        Ok(pooled)
909    }
910
911    /// Create pooled ones tensor
912    pub fn ones(shape: &[usize], device: DeviceType) -> Result<Self>
913    where
914        T: std::ops::Add<Output = T> + From<f32>,
915    {
916        let mut pooled = Self::new(shape, device)?;
917        // Initialize with ones
918        let numel = shape.iter().product::<usize>();
919        let data = vec![T::from(1.0f32); numel];
920        pooled.tensor.storage = TensorStorage::create_optimal(data)?;
921        Ok(pooled)
922    }
923
924    /// Get reference to the underlying tensor
925    pub fn tensor(&self) -> &Tensor<T> {
926        &self.tensor
927    }
928
929    /// Get mutable reference to the underlying tensor
930    pub fn tensor_mut(&mut self) -> &mut Tensor<T> {
931        &mut self.tensor
932    }
933
934    /// Convert to owned tensor (removes from pool management)
935    pub fn into_tensor(mut self) -> Tensor<T> {
936        self.pool_key = None; // Prevent return to pool
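        // `PooledTensor` implements `Drop`, so the inner tensor cannot be moved out of
        // `self` directly; cloning is the simplest way to hand back an owned `Tensor`.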
937        self.tensor.clone()
938    }
939}
940
941impl<T: TensorElement + std::default::Default> Drop for PooledTensor<T> {
942    fn drop(&mut self) {
943        if let Some((_type_id, _size_class)) = self.pool_key {
944            // Return memory to pool via deallocate (which now simply drops).
945            if let Ok(data) = self.tensor.to_vec() {
946                let pool = get_memory_pool();
947                let mut pool_guard = pool.lock().expect("lock should not be poisoned");
948                pool_guard.deallocate(data);
949            }
950        }
951    }
952}
953
/// Convenience functions for creating pooled tensors
955impl<T: TensorElement + Copy + Default> Tensor<T> {
956    /// Create a tensor using the memory pool
957    pub fn pooled(shape: &[usize], device: DeviceType) -> Result<PooledTensor<T>> {
958        PooledTensor::new(shape, device)
959    }
960
961    /// Create temporary tensor for intermediate calculations
962    pub fn temporary(shape: &[usize], device: DeviceType) -> Result<PooledTensor<T>> {
963        PooledTensor::new(shape, device)
964    }
965}
966
// ─── Global pool management functions ───────────────────────────────────────

/// Clear the global memory pool, dropping all cached buffers and resetting statistics.
pub fn clear_memory_pool() {
969    if let Some(pool) = MEMORY_POOL.get() {
970        pool.lock().expect("lock should not be poisoned").clear();
971    }
972}
973
974pub fn get_pool_statistics() -> PoolStatistics {
975    get_memory_pool()
976        .lock()
977        .expect("lock should not be poisoned")
978        .get_statistics()
979        .clone()
980}
981
982pub fn get_pool_hit_rate() -> f64 {
983    get_memory_pool()
984        .lock()
985        .expect("lock should not be poisoned")
986        .hit_rate()
987}
988
989pub fn cleanup_memory_pool() {
990    get_memory_pool()
991        .lock()
992        .expect("lock should not be poisoned")
993        .cleanup();
994}
995
996#[cfg(test)]
997mod tests {
998    use super::*;
999
1000    // Serialise the pool-identity tests that rely on global singleton state.
1001    static TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
1002
1003    #[test]
1004    fn test_memory_pool_basic() {
1005        clear_memory_pool();
1006
1007        // Create pooled tensor
1008        let pooled = PooledTensor::<f32>::zeros(&[100, 100], DeviceType::Cpu)
1009            .expect("zeros creation should succeed");
1010        assert_eq!(pooled.tensor().numel(), 10000);
1011
1012        // Drop should return memory to pool
1013        drop(pooled);
1014
1015        // Next allocation should reuse memory
1016        let _pooled2 = PooledTensor::<f32>::zeros(&[100, 100], DeviceType::Cpu)
1017            .expect("zeros creation should succeed");
1018
1019        let stats = get_pool_statistics();
1020        assert!(stats.pool_hits > 0 || stats.pool_misses > 0);
1021    }
1022
1023    #[test]
1024    fn test_pool_statistics() {
1025        clear_memory_pool();
1026
1027        let _pooled1 = PooledTensor::<f32>::zeros(&[50, 50], DeviceType::Cpu)
1028            .expect("zeros creation should succeed");
1029        let _pooled2 = PooledTensor::<f32>::ones(&[50, 50], DeviceType::Cpu)
1030            .expect("ones creation should succeed");
1031
1032        let stats = get_pool_statistics();
1033        assert!(stats.total_allocations >= 2);
1034        assert!(stats.total_bytes_allocated > 0);
1035    }
1036
1037    #[test]
1038    fn test_pool_cleanup() {
1039        clear_memory_pool();
1040
1041        // Create many temporary tensors
1042        for _ in 0..10 {
1043            let _temp = PooledTensor::<f32>::zeros(&[100, 100], DeviceType::Cpu)
1044                .expect("zeros creation should succeed");
1045        }
1046
1047        cleanup_memory_pool();
1048        let _stats = get_pool_statistics();
1049        // After cleanup, bytes in pools should be reduced (test passes if no panic occurs)
1050    }
1051
1052    #[test]
1053    fn test_pooled_tensor_conversion() {
1054        let pooled = PooledTensor::<f32>::ones(&[10, 10], DeviceType::Cpu)
1055            .expect("ones creation should succeed");
1056        let tensor = pooled.into_tensor();
1057        assert_eq!(tensor.numel(), 100);
1058    }
1059
1060    // ── New ReusedBuffer tests ──────────────────────────────────────────────
1061
1062    #[test]
1063    fn test_acquire_truly_reuses_allocation() {
1064        let _guard = TEST_LOCK.lock().expect("test mutex should not be poisoned");
1065        clear_memory_pool();
1066
1067        let buf1: ReusedBuffer<f32> = global_acquire_uninit::<f32>(1024);
1068        let ptr1 = buf1.as_ptr_raw();
1069        buf1.release_to_pool();
1070
1071        let buf2: ReusedBuffer<f32> = global_acquire_uninit::<f32>(1024);
1072        let ptr2 = buf2.as_ptr_raw();
1073        buf2.release_to_pool();
1074
1075        assert_eq!(
1076            ptr1, ptr2,
1077            "pool should return the same allocation on second acquire"
1078        );
1079    }
1080
1081    #[test]
1082    fn test_into_vec_transfers_ownership() {
1083        let _guard = TEST_LOCK.lock().expect("test mutex should not be poisoned");
1084        clear_memory_pool();
1085
1086        let mut buf: ReusedBuffer<f32> = global_acquire_uninit::<f32>(64);
1087        // Write to the buffer
1088        for slot in buf.as_uninit_slice_mut() {
1089            slot.write(1.0_f32);
1090        }
1091        let vec = buf.into_vec(64);
1092        assert_eq!(vec.len(), 64);
1093        assert!(vec.iter().all(|&x| x == 1.0_f32));
1094    }
1095
1096    #[test]
1097    fn test_drop_returns_to_pool() {
1098        let _guard = TEST_LOCK.lock().expect("test mutex should not be poisoned");
1099        clear_memory_pool();
1100
1101        {
1102            let buf: ReusedBuffer<f32> = global_acquire_uninit::<f32>(256);
1103            // Drop without consuming — should return to pool
1104            drop(buf);
1105        }
1106
1107        // Second acquire should be a pool hit (same size class)
1108        let buf2: ReusedBuffer<f32> = global_acquire_uninit::<f32>(256);
1109        buf2.release_to_pool();
1110
1111        let stats = get_pool_statistics();
1112        assert!(
1113            stats.pool_hits >= 1,
1114            "expected at least one pool hit after drop-return"
1115        );
1116    }
1117
1118    #[test]
1119    fn test_acquire_capacity_and_uninit_slice() {
1120        let _guard = TEST_LOCK.lock().expect("test mutex should not be poisoned");
1121        clear_memory_pool();
1122
1123        let buf: ReusedBuffer<u64> = global_acquire_uninit::<u64>(32);
1124        assert_eq!(buf.capacity(), 32);
1125        buf.release_to_pool();
1126    }
1127}