// train_station/tensor/core/memory.rs

1//! High-performance memory management for tensor operations
2//!
3//! This module provides thread-local memory pools optimized for ML workloads
4//! with frequent tensor allocation and deallocation. Designed as the foundation
5//! for AGI/ASI research with zero dependencies and maximum performance.
6//!
7//! # Key Features
8//!
9//! - **Thread-Local Pools**: Eliminate contention with per-thread pools
10//! - **Size-Class Optimization**: Optimized for common ML tensor sizes (scalars to large matrices)
11//! - **Zero-Copy Integration**: Seamless integration with tensor view system
12//! - **Statistics Tracking**: Memory usage monitoring and optimization
13//! - **SIMD Alignment**: 32-byte alignment for AVX2 operations
14//! - **Research Enablement**: Predictable allocation patterns for novel architectures
15//!
16//! # Performance Characteristics
17//!
18//! - **Allocation Speed**: 5-10x faster than system allocator for pooled sizes
19//! - **Memory Efficiency**: Reduced fragmentation through ML-optimized size classes
20//! - **Cache Locality**: Better cache utilization through buffer reuse
21//! - **Thread Safety**: Lock-free through thread-local storage
22//! - **Zero Dependencies**: Pure Rust implementation with no external dependencies
23//! - **Edge Ready**: Minimal memory overhead suitable for embedded deployment
24
25use std::alloc::Layout;
26use std::cell::Cell;
27use std::cell::RefCell;
28use std::ptr::NonNull;
29use std::time::Instant;
30// no global atomics needed in simplified design
31
32// Global cross-thread counters removed for simplicity; thread-local stats remain
33
/// Memory pool statistics for performance monitoring
///
/// Counters are per-thread (each thread owns its own pool) and track request
/// volume, pool effectiveness (hits vs. misses), and usage high-water marks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStats {
    /// Total number of allocation requests
    pub allocations: usize,
    /// Total number of deallocation requests
    pub deallocations: usize,
    /// Number of successful pool hits (allocations served from pool)
    pub pool_hits: usize,
    /// Number of pool misses (allocations that required creating a new buffer
    /// or fell back to the system allocator)
    pub pool_misses: usize,
    /// Current usage; NOTE(review): callers record buffer sizes in f32
    /// elements, so this counts elements, not bytes — confirm intended unit
    pub current_usage: usize,
    /// Peak usage (same unit as `current_usage`)
    pub peak_usage: usize,
}
50
/// Size classes optimized for ML workloads
///
/// Sizes below are in f32 ELEMENTS; byte figures assume 4 bytes per element.
/// Based on analysis of common tensor sizes in ML applications:
/// - Small: Scalars, small vectors, activations (≤4KB) - covers up to 32x32 matrices
/// - Medium: Embeddings, medium matrices (4KB-256KB) - covers 64x64 to 256x256 matrices
/// - Large: Batch data, large matrices (256KB-4MB) - covers large batch processing
/// - XLarge: Very large tensors (>4MB) - covers massive models and datasets
pub const SMALL_BUFFER_SIZE: usize = 1024; // 4KB (1024 * 4 bytes) - up to 32x32 matrices
pub const MEDIUM_BUFFER_SIZE: usize = 65536; // 256KB (65536 * 4 bytes) - up to 256x256 matrices
pub const LARGE_BUFFER_SIZE: usize = 1048576; // 4MB (1048576 * 4 bytes) - large batch processing

/// **CRITICAL DESIGN PRINCIPLE**: NO MAXIMUM LIMITS
///
/// The memory pool NEVER prevents tensor creation. Instead, it uses adaptive
/// management to balance performance and memory usage. Users control memory
/// through their allocation patterns, not artificial limits.
///
/// Pool management strategy:
/// - Pools grow dynamically based on usage patterns
/// - Automatic cleanup of unused buffers during low activity
/// - Memory pressure detection for adaptive behavior
/// - User-controlled memory management through allocation patterns
///
/// Target pool sizes for optimal performance (not limits!). Counts are buffers
/// retained per class; cached totals follow from the class buffer sizes above
/// (e.g. 32 small buffers * 4KB = 128KB).
const TARGET_SMALL_BUFFERS: usize = 32; // Optimal: 32 * 4KB = 128KB cached
const TARGET_MEDIUM_BUFFERS: usize = 16; // Optimal: 16 * 256KB = 4MB cached
const TARGET_LARGE_BUFFERS: usize = 8; // Optimal: 8 * 4MB = 32MB cached

// Cleanup heuristics: conservative headroom (available buffers kept per class
// even when idle) and the cadence gates below.
const HEADROOM_SMALL: usize = 8;
const HEADROOM_MEDIUM: usize = 4;
const HEADROOM_LARGE: usize = 2;
const HEADROOM_XLARGE: usize = 1;

// Minimum operations and time between cleanup passes (hybrid gating: both
// thresholds must be exceeded before a cleanup pass runs)
const CLEANUP_MIN_OPS: u64 = 2048;
const CLEANUP_MIN_INTERVAL_MS: u64 = 2000; // 2s

// A buffer must remain unused for at least this many ops since last touch
// before it is eligible for cleanup
const UNUSED_OPS_THRESHOLD: u64 = 4096;
89
/// Pooled memory buffer with alignment guarantees and lifecycle tracking
///
/// Provides SIMD-aligned f32 memory buffers for efficient tensor operations.
/// Buffers are reused across allocations to reduce overhead and support
/// advanced view system integration.
///
/// # Key Features
/// - **Adaptive Lifecycle**: Tracks usage (via a logical op counter) for
///   intelligent management
/// - **View Integration**: Optimized for tensor view operations
/// - **Future Proof**: Extensible design for novel ML architectures
/// - **Zero Limits**: No artificial constraints on buffer creation
pub struct PooledBuffer {
    /// Owning allocation for this pooled buffer (system-owned; pool manages lifetime)
    alloc: crate::tensor::core::Allocation,
    /// Whether this buffer is currently checked out by a tensor
    in_use: bool,
    /// Last time (in pool ops, not wall clock) this buffer was touched
    /// (allocated or returned); used by cleanup to find long-idle buffers
    last_used_counter: u64,
}
109
/// Thread-local memory pool for tensor allocation with adaptive management
///
/// **REVOLUTIONARY DESIGN**: No artificial limits - pools grow and shrink based on
/// actual usage patterns. Optimized for ML workloads with intelligent view system
/// integration and future-proof extensibility.
///
/// All capacities below are in f32 elements (4 bytes each); see the
/// `*_BUFFER_SIZE` constants for the class boundaries.
///
/// # Key Features
/// - **Unlimited Growth**: Pools expand as needed until system memory exhausted
/// - **Adaptive Cleanup**: Automatic cleanup of unused buffers during low activity
/// - **View Optimization**: Special handling for tensors used in view operations
/// - **Future Proof**: Extensible design for novel ML architectures
/// - **User Control**: Memory management through allocation patterns, not limits
pub struct TensorMemoryPool {
    /// Small buffers for scalars, small vectors (≤1024 elements / 4KB)
    /// **NO SIZE LIMIT** - grows dynamically based on usage
    small_buffers: Vec<PooledBuffer>,

    /// Medium buffers for embeddings, small matrices (≤65536 elements / 256KB)
    /// **NO SIZE LIMIT** - grows dynamically based on usage
    medium_buffers: Vec<PooledBuffer>,

    /// Large buffers for batch data, large matrices (≤1048576 elements / 4MB)
    /// **NO SIZE LIMIT** - grows dynamically based on usage
    large_buffers: Vec<PooledBuffer>,

    /// Extra large, variable-sized buffers for massive tensors (>4MB)
    /// Reinstated for stress testing rapid reuse stability
    xlarge_buffers: Vec<PooledBuffer>,

    /// Statistics for this thread's pool
    stats: PoolStats,
    // Simplified: no adaptive management state; dynamic growth via Vec
    /// Monotonic (wrapping) operation counter to timestamp buffer activity
    op_counter: u64,
    /// Op-counter value at the last cleanup pass (avoids frequent passes)
    last_cleanup_counter: u64,
    /// Wall-clock last cleanup time (additional gate alongside the op counter)
    last_cleanup_instant: Instant,
}
149
150// Simplified: removed adaptive/view metrics/usage patterns to reduce complexity
151
/// Size class enumeration for pattern analysis
///
/// Thresholds are in f32 elements and mirror the `*_BUFFER_SIZE` constants
/// (byte figures assume 4 bytes per element).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SizeClass {
    Small,  // ≤1024 elements (4KB)
    Medium, // 1025..=65536 elements (4KB-256KB)
    Large,  // 65537..=1048576 elements (256KB-4MB)
    XLarge, // >1048576 elements (>4MB)
}
160
161// Removed ComprehensivePoolStats/BufferCounts; use thread_stats() when needed
162
163// Duplicate PoolStats removed - using the one defined earlier with proper field names
164
thread_local! {
    /// Per-thread tensor memory pool; lazily constructed on first access.
    static MEMORY_POOL: RefCell<TensorMemoryPool> = RefCell::new(TensorMemoryPool::new());
    /// Thread-local flag to disable memory padding and pooling for allocations made
    /// during the active context. When enabled, allocations will not add lane-size
    /// padding and will prefer exact-size system allocations over the pool.
    static NO_MEM_PADDING: Cell<bool> = const { Cell::new(false) };
    /// Thread-local flag to control whether allocations should use the memory pool.
    /// Defaults to true for efficiency. When false, allocations use the system allocator.
    static USE_POOL_ALLOC: Cell<bool> = const { Cell::new(true) };
}
175
/// Runtime SIMD capability level on the current CPU
///
/// The x86-specific variants only exist when compiled for x86_64; `Scalar`
/// is the portable fallback available on every target.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdLevel {
    /// AVX-512 Foundation available (512-bit vectors)
    #[cfg(target_arch = "x86_64")]
    Avx512,
    /// AVX2 available (256-bit vectors)
    #[cfg(target_arch = "x86_64")]
    Avx2,
    /// SSE2 available (128-bit vectors)
    #[cfg(target_arch = "x86_64")]
    Sse2,
    /// No SIMD detected, or a non-x86_64 target; scalar code paths
    Scalar,
}
187
188/// Detect the highest available SIMD level at runtime.
189///
190/// On x86_64, checks for AVX512, AVX2, then SSE2 in descending order. Falls back
191/// to `Scalar` when features are unavailable or on non-x86 targets.
192#[inline]
193pub fn detect_runtime_simd() -> SimdLevel {
194    #[cfg(target_arch = "x86_64")]
195    {
196        // Check in descending order
197        if is_x86_feature_detected!("avx512f") {
198            return SimdLevel::Avx512;
199        }
200        if is_x86_feature_detected!("avx2") {
201            return SimdLevel::Avx2;
202        }
203        if is_x86_feature_detected!("sse2") {
204            return SimdLevel::Sse2;
205        }
206
207        SimdLevel::Scalar
208    }
209    #[cfg(not(target_arch = "x86_64"))]
210    {
211        SimdLevel::Scalar
212    }
213}
214
/// Lane width (elements per vector) for the given SIMD level for f32 values
///
/// Used to round allocation sizes up to whole vectors; `Scalar` yields 1 so
/// lane-based padding becomes a no-op.
#[inline]
pub(crate) fn simd_lane_width_elems(level: SimdLevel) -> usize {
    match level {
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Avx512 => 16, // 512 / 32
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Avx2 => 8, // 256 / 32
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Sse2 => 4, // 128 / 32
        SimdLevel::Scalar => 1,
    }
}
228
/// Alignment in bytes recommended for the given SIMD level
///
/// Returns a conservative alignment for each SIMD tier to enable aligned
/// loads/stores (full vector width for each tier). On Scalar, returns 16
/// for general safety.
#[inline]
pub fn simd_alignment_bytes(level: SimdLevel) -> usize {
    match level {
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Avx512 => 64,
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Avx2 => 32,
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Sse2 => 16,
        SimdLevel::Scalar => 16, // keep at least 16 for general safety
    }
}
245
246/// Compute allocation alignment (bytes) and padded element count for a requested length.
247///
248/// When NoMemPadding is enabled, padding is disabled and the exact element count is returned.
249/// Otherwise, the element count is rounded up to the nearest SIMD lane width to improve
250/// vectorization and store alignment.
251///
252/// # Arguments
253///
254/// * `requested_elems` - Desired number of elements
255///
256/// # Returns
257///
258/// `(alignment_bytes, padded_elems)`
259#[inline]
260pub fn compute_allocation_params(requested_elems: usize) -> (usize, usize) {
261    let level = detect_runtime_simd();
262    #[cfg(target_arch = "x86_64")]
263    let mut align = simd_alignment_bytes(level);
264    #[cfg(not(target_arch = "x86_64"))]
265    let align = simd_alignment_bytes(level);
266
267    // Preserve existing alignment policy (keep minimums but avoid over-padding semantics)
268    #[cfg(target_arch = "x86_64")]
269    {
270        if is_x86_feature_detected!("avx512f") {
271            align = 64;
272        } else if is_x86_feature_detected!("avx2") {
273            align = align.max(32);
274        }
275    }
276
277    if no_mem_padding_enabled() || requested_elems == 0 {
278        (align, requested_elems)
279    } else {
280        let lane = simd_lane_width_elems(level);
281        let padded = requested_elems.div_ceil(lane) * lane;
282        (align, padded)
283    }
284}
285
/// Minimum element count at which non-temporal (streaming) stores become profitable
/// for large, linear copies on the current CPU. Returns `usize::MAX` for Scalar to
/// effectively disable streaming stores when SIMD is unavailable.
///
/// Thresholds are in f32 elements; byte figures assume 4 bytes per element.
#[inline]
#[cfg(target_arch = "x86_64")]
pub fn stream_min_elems() -> usize {
    match detect_runtime_simd() {
        SimdLevel::Avx512 => 16_384, // 64 KiB+ worth of data
        SimdLevel::Avx2 => 8_192,    // 32 KiB+
        SimdLevel::Sse2 => 4_096,    // 16 KiB+
        SimdLevel::Scalar => usize::MAX,
    }
}
299
/// Prefetch lookahead distance in elements for long, streaming loops. Tuned per
/// SIMD level to balance cache pollution and latency hiding. Returns 0 to skip
/// prefetching when in Scalar mode.
///
/// Distances are in f32 elements; byte figures assume 4 bytes per element.
#[inline]
#[cfg(target_arch = "x86_64")]
pub fn prefetch_distance_elems() -> usize {
    match detect_runtime_simd() {
        SimdLevel::Avx512 => 512, // 2 KiB lookahead
        SimdLevel::Avx2 => 256,   // 1 KiB lookahead
        SimdLevel::Sse2 => 128,   // 512 B lookahead
        SimdLevel::Scalar => 0,
    }
}
313
314/// Centralized heuristic for auto-tuned chunk sizes used by `iter_fast_chunks()`.
315///
316/// Targets approximately 64 KiB working-set per chunk, clamps to a conservative
317/// range, and rounds up to the current SIMD lane width for better alignment.
318///
319/// # Arguments
320///
321/// * `total_elems` - Total number of elements to process
322///
323/// # Returns
324///
325/// Chunk size in elements
326#[inline]
327pub fn choose_fast_chunk_size(total_elems: usize) -> usize {
328    if total_elems == 0 {
329        return 1;
330    }
331    // Start from a cache-aware baseline and scale lightly with problem size
332    let mut sz = 16_384usize; // 64 KiB of f32
333    if total_elems < 16_384 {
334        sz = 4_096;
335    } else if total_elems > 1_048_576 {
336        sz = 65_536;
337    }
338    // Align to SIMD lane width to reduce tail handling
339    let lane = simd_lane_width_elems(detect_runtime_simd());
340    if lane > 1 {
341        sz = sz.div_ceil(lane) * lane;
342    }
343    // Final clamp to practical bounds to avoid tiny or excessively large chunks
344    sz.clamp(4_096, 262_144)
345}
346
347/// Returns true if the current thread prefers using the memory pool for allocations.
348#[inline]
349pub fn use_pool_alloc_enabled() -> bool {
350    USE_POOL_ALLOC.with(|flag| flag.get())
351}
352
353impl PooledBuffer {
354    /// Creates a new pooled buffer with specified size and alignment
355    ///
356    /// **DESIGN PRINCIPLE**: Never fails due to limits - always creates buffer
357    /// if system memory is available. Users control memory through their patterns.
358    fn new(size: usize, alignment: usize) -> Self {
359        // Ensure alignment is at least align_of::<f32>()
360        let effective_alignment = alignment.max(std::mem::align_of::<f32>());
361        let layout =
362            Layout::from_size_align(size * std::mem::size_of::<f32>(), effective_alignment)
363                .expect("Invalid layout for pooled buffer");
364        // Use system allocation via Allocation; pool owns this memory
365        let alloc =
366            crate::tensor::core::Allocation::new_uninitialized(size, effective_alignment, layout);
367        // Verify alignment satisfies requested
368        let addr = alloc.ptr.as_ptr() as usize;
369        assert_eq!(
370            addr % alignment,
371            0,
372            "System allocator failed to provide {}-byte aligned memory. Got address 0x{:x} (alignment {})",
373            alignment,
374            addr,
375            addr % alignment
376        );
377        PooledBuffer {
378            alloc,
379            in_use: false,
380            last_used_counter: 0,
381        }
382    }
383
384    /// Gets the raw pointer to the buffer data
385    #[inline(always)]
386    pub fn as_ptr(&self) -> NonNull<f32> {
387        self.alloc.ptr
388    }
389
390    /// Gets the size of the buffer in elements
391    #[inline(always)]
392    pub fn size(&self) -> usize {
393        self.alloc.capacity_elems()
394    }
395
396    // Removed buffer_id tracking in simplified design
397
398    /// Allocates this buffer for tensor use
399    #[inline]
400    fn allocate_for_tensor(&mut self, now_counter: u64) -> bool {
401        if self.in_use {
402            false
403        } else {
404            self.in_use = true;
405            self.last_used_counter = now_counter;
406            true
407        }
408    }
409
410    /// Returns buffer to available state
411    #[inline]
412    fn return_to_pool(&mut self, now_counter: u64) {
413        self.in_use = false;
414        self.last_used_counter = now_counter;
415    }
416
417    /// Checks if buffer is available for allocation
418    #[inline(always)]
419    pub fn is_available(&self) -> bool {
420        !self.in_use
421    }
422}
423
424// No custom Drop needed; `alloc` owns the memory and will free on drop.
425
426impl TensorMemoryPool {
427    /// Creates a new tensor memory pool with adaptive management
428    ///
429    /// **DESIGN PRINCIPLE**: Starts with optimal capacity but grows unlimited
430    pub fn new() -> Self {
431        TensorMemoryPool {
432            // Start with target capacities for optimal performance
433            small_buffers: Vec::with_capacity(TARGET_SMALL_BUFFERS),
434            medium_buffers: Vec::with_capacity(TARGET_MEDIUM_BUFFERS),
435            large_buffers: Vec::with_capacity(TARGET_LARGE_BUFFERS),
436            xlarge_buffers: Vec::with_capacity(4),
437            stats: PoolStats::new(),
438            op_counter: 0,
439            last_cleanup_counter: 0,
440            last_cleanup_instant: Instant::now(),
441        }
442    }
443
444    /// Attempts to allocate memory from the pool
445    ///
446    /// Returns a pointer to allocated memory if a suitable buffer is available,
447    /// otherwise returns None to indicate fallback to system allocator.
448    fn try_allocate(&mut self, size: usize, alignment: usize) -> Option<NonNull<f32>> {
449        let size_class = self.classify_size(size);
450
451        self.try_allocate_internal(size, alignment, size_class)
452    }
453
454    /// Internal allocation method that avoids borrowing conflicts
455    fn try_allocate_internal(
456        &mut self,
457        size: usize,
458        alignment: usize,
459        size_class: SizeClass,
460    ) -> Option<NonNull<f32>> {
461        // Periodically attempt cleanup prior to allocation
462        self.maybe_cleanup();
463        match size_class {
464            SizeClass::Small => {
465                self.try_allocate_from_small_pool(SMALL_BUFFER_SIZE, alignment, size_class)
466            }
467            SizeClass::Medium => {
468                self.try_allocate_from_medium_pool(MEDIUM_BUFFER_SIZE, alignment, size_class)
469            }
470            SizeClass::Large => {
471                self.try_allocate_from_large_pool(LARGE_BUFFER_SIZE, alignment, size_class)
472            }
473            SizeClass::XLarge => {
474                let planned = TensorMemoryPool::planned_capacity_elems(size);
475                self.try_allocate_from_xlarge_pool(planned, alignment, size_class)
476            }
477        }
478    }
479
480    /// Allocate from small pool
481    fn try_allocate_from_small_pool(
482        &mut self,
483        buffer_size: usize,
484        alignment: usize,
485        _size_class: SizeClass,
486    ) -> Option<NonNull<f32>> {
487        let nowc = self.bump_op_counter();
488        for buffer in self.small_buffers.iter_mut() {
489            if buffer.is_available()
490                && buffer.alloc.alignment() >= alignment
491                && buffer.allocate_for_tensor(nowc)
492            {
493                self.stats.record_allocation_hit(buffer_size);
494                return Some(buffer.as_ptr());
495            }
496        }
497        let mut new_buffer = PooledBuffer::new(buffer_size, alignment);
498        if new_buffer.allocate_for_tensor(nowc) {
499            let ptr = new_buffer.as_ptr();
500            self.small_buffers.push(new_buffer);
501            self.stats
502                .record_allocation_miss(buffer_size, "new_buffer_created");
503            Some(ptr)
504        } else {
505            None
506        }
507    }
508
509    /// Allocate from medium pool
510    fn try_allocate_from_medium_pool(
511        &mut self,
512        buffer_size: usize,
513        alignment: usize,
514        _size_class: SizeClass,
515    ) -> Option<NonNull<f32>> {
516        let nowc = self.bump_op_counter();
517        for buffer in self.medium_buffers.iter_mut() {
518            if buffer.is_available()
519                && buffer.alloc.alignment() >= alignment
520                && buffer.allocate_for_tensor(nowc)
521            {
522                self.stats.record_allocation_hit(buffer_size);
523                return Some(buffer.as_ptr());
524            }
525        }
526        let mut new_buffer = PooledBuffer::new(buffer_size, alignment);
527        if new_buffer.allocate_for_tensor(nowc) {
528            let ptr = new_buffer.as_ptr();
529            self.medium_buffers.push(new_buffer);
530            self.stats
531                .record_allocation_miss(buffer_size, "new_buffer_created");
532            Some(ptr)
533        } else {
534            None
535        }
536    }
537
538    /// Allocate from large pool
539    fn try_allocate_from_large_pool(
540        &mut self,
541        buffer_size: usize,
542        alignment: usize,
543        _size_class: SizeClass,
544    ) -> Option<NonNull<f32>> {
545        let nowc = self.bump_op_counter();
546        for buffer in self.large_buffers.iter_mut() {
547            if buffer.is_available()
548                && buffer.alloc.alignment() >= alignment
549                && buffer.allocate_for_tensor(nowc)
550            {
551                self.stats.record_allocation_hit(buffer_size);
552                return Some(buffer.as_ptr());
553            }
554        }
555        let mut new_buffer = PooledBuffer::new(buffer_size, alignment);
556        if new_buffer.allocate_for_tensor(nowc) {
557            let ptr = new_buffer.as_ptr();
558            self.large_buffers.push(new_buffer);
559            self.stats
560                .record_allocation_miss(buffer_size, "new_buffer_created");
561            Some(ptr)
562        } else {
563            None
564        }
565    }
566
567    /// Allocate from xlarge pool
568    fn try_allocate_from_xlarge_pool(
569        &mut self,
570        buffer_size: usize,
571        alignment: usize,
572        _size_class: SizeClass,
573    ) -> Option<NonNull<f32>> {
574        let nowc = self.bump_op_counter();
575        for buffer in self.xlarge_buffers.iter_mut() {
576            // Only reuse when the existing buffer capacity is sufficient and alignment is compatible
577            if buffer.is_available()
578                && buffer.size() >= buffer_size
579                && buffer.alloc.alignment() >= alignment
580                && buffer.allocate_for_tensor(nowc)
581            {
582                self.stats.record_allocation_hit(buffer_size);
583                return Some(buffer.as_ptr());
584            }
585        }
586        let mut new_buffer = PooledBuffer::new(buffer_size, alignment);
587        if new_buffer.allocate_for_tensor(nowc) {
588            let ptr = new_buffer.as_ptr();
589            self.xlarge_buffers.push(new_buffer);
590            self.stats
591                .record_allocation_miss(buffer_size, "new_buffer_created");
592            Some(ptr)
593        } else {
594            None
595        }
596    }
597
598    // Removed create_new_buffer helper; creation handled inline in try_allocate_from_* functions
599
600    /// Classifies size into size class
601    #[inline]
602    fn classify_size(&self, size: usize) -> SizeClass {
603        if size <= SMALL_BUFFER_SIZE {
604            SizeClass::Small
605        } else if size <= MEDIUM_BUFFER_SIZE {
606            SizeClass::Medium
607        } else if size <= LARGE_BUFFER_SIZE {
608            SizeClass::Large
609        } else {
610            SizeClass::XLarge
611        }
612    }
613
614    #[cfg(test)]
615    fn stats(&self) -> &PoolStats {
616        &self.stats
617    }
618}
619
620/// RAII guard to temporarily disable memory padding and pooled allocations
621/// within the current thread. This trades some runtime performance for
622/// potentially lower memory usage by avoiding lane-size padding and pool rounding.
623#[allow(dead_code)]
624pub struct NoMemPaddingGuard {
625    prev: bool,
626}
627
628impl Drop for NoMemPaddingGuard {
629    fn drop(&mut self) {
630        let _ = NO_MEM_PADDING.try_with(|flag| flag.set(self.prev));
631    }
632}
633
634impl NoMemPaddingGuard {
635    /// Create a new guard that disables memory padding until dropped
636    #[allow(dead_code)]
637    pub fn new() -> Self {
638        let prev = NO_MEM_PADDING.with(|flag| {
639            let old = flag.get();
640            flag.set(true);
641            old
642        });
643        NoMemPaddingGuard { prev }
644    }
645}
646
647impl Default for NoMemPaddingGuard {
648    fn default() -> Self {
649        Self::new()
650    }
651}
652
653/// RAII guard to temporarily disable pool usage (force system allocation) in this thread.
654pub struct NoMemPoolGuard {
655    prev: bool,
656}
657
658impl Drop for NoMemPoolGuard {
659    fn drop(&mut self) {
660        let _ = USE_POOL_ALLOC.try_with(|flag| flag.set(self.prev));
661    }
662}
663
664impl NoMemPoolGuard {
665    /// Create a new guard that disables pool allocations until dropped
666    pub fn new() -> Self {
667        let prev = USE_POOL_ALLOC.with(|flag| {
668            let old = flag.get();
669            flag.set(false);
670            old
671        });
672        NoMemPoolGuard { prev }
673    }
674}
675
676impl Default for NoMemPoolGuard {
677    fn default() -> Self {
678        Self::new()
679    }
680}
681
682/// Execute a closure with the memory pool disabled for the current thread.
683#[inline]
684pub fn with_no_mem_pool<F, R>(f: F) -> R
685where
686    F: FnOnce() -> R,
687{
688    let _guard = NoMemPoolGuard::new();
689    f()
690}
691
692/// Execute a closure with memory padding disabled for the current thread.
693#[inline]
694#[allow(dead_code)]
695pub fn with_no_mem_padding<F, R>(f: F) -> R
696where
697    F: FnOnce() -> R,
698{
699    let _guard = NoMemPaddingGuard::new();
700    f()
701}
702
703/// Returns true if the current thread has memory padding disabled.
704#[inline]
705pub fn no_mem_padding_enabled() -> bool {
706    NO_MEM_PADDING.with(|flag| flag.get())
707}
708
709impl TensorMemoryPool {
710    /// Returns the planned capacity (in f32 elements) the pool will allocate for a
711    /// given requested number of elements. This mirrors the internal size-class logic.
712    pub fn planned_capacity_elems(requested_elems: usize) -> usize {
713        if requested_elems <= SMALL_BUFFER_SIZE {
714            SMALL_BUFFER_SIZE
715        } else if requested_elems <= MEDIUM_BUFFER_SIZE {
716            MEDIUM_BUFFER_SIZE
717        } else if requested_elems <= LARGE_BUFFER_SIZE {
718            LARGE_BUFFER_SIZE
719        } else {
720            // Ensure exponential growth for very large allocations
721            (requested_elems * 2).max(262144 * 2)
722        }
723    }
724}
725
726impl PoolStats {
727    fn new() -> Self {
728        PoolStats {
729            allocations: 0,
730            deallocations: 0,
731            pool_hits: 0,
732            pool_misses: 0,
733            current_usage: 0,
734            peak_usage: 0,
735        }
736    }
737
738    fn record_allocation_hit(&mut self, buffer_size: usize) {
739        self.allocations += 1;
740        self.pool_hits += 1;
741        self.current_usage += buffer_size;
742        if self.current_usage > self.peak_usage {
743            self.peak_usage = self.current_usage;
744        }
745    }
746
747    fn record_allocation_miss(&mut self, _buffer_size: usize, _reason: &str) {
748        self.allocations += 1;
749        self.pool_misses += 1;
750    }
751
752    fn record_deallocation(&mut self, size: usize) {
753        self.deallocations += 1;
754        self.current_usage = self.current_usage.saturating_sub(size);
755    }
756}
757
758/// Public interface for memory pool operations
759impl TensorMemoryPool {
760    /// Attempts to allocate memory from the thread-local pool
761    ///
762    /// Returns Some(ptr) if allocation succeeds from pool,
763    /// None if fallback to system allocator is needed.
764    pub fn allocate(size: usize, alignment: usize) -> Option<NonNull<f32>> {
765        let result = MEMORY_POOL.with(|pool| pool.borrow_mut().try_allocate(size, alignment));
766        result
767    }
768
769    /// Attempts to return memory to the thread-local pool without panicking if TLS is
770    /// unavailable (e.g., during thread shutdown). Returns Some(result) when TLS is
771    /// accessible, or None if TLS is not available.
772    pub fn try_deallocate(ptr: NonNull<f32>) -> Option<bool> {
773        MEMORY_POOL
774            .try_with(|pool| {
775                let mut pool_mut = pool.borrow_mut();
776                pool_mut.return_to_pool(ptr)
777            })
778            .ok()
779    }
780
781    /// Return buffer to the appropriate pool
782    ///
783    /// Returns true if the buffer was successfully returned to a pool,
784    /// false if the buffer was not found in any pool (indicating it
785    /// was allocated directly from the system allocator).
786    fn return_to_pool(&mut self, ptr: NonNull<f32>) -> bool {
787        // Check each pool individually to avoid borrowing conflicts
788        if self.return_to_small_pool(ptr) {
789            self.maybe_cleanup();
790            return true;
791        }
792        if self.return_to_medium_pool(ptr) {
793            self.maybe_cleanup();
794            return true;
795        }
796        if self.return_to_large_pool(ptr) {
797            self.maybe_cleanup();
798            return true;
799        }
800        if self.return_to_xlarge_pool(ptr) {
801            self.maybe_cleanup();
802            return true;
803        }
804
805        // Buffer not found in any pool - this is expected for system-allocated memory
806        false
807    }
808
809    /// Return buffer to small pool
810    fn return_to_small_pool(&mut self, ptr: NonNull<f32>) -> bool {
811        let nowc = self.bump_op_counter();
812        for buffer in self.small_buffers.iter_mut() {
813            if buffer.as_ptr() == ptr {
814                buffer.return_to_pool(nowc);
815                self.stats.record_deallocation(buffer.size());
816                return true;
817            }
818        }
819        false
820    }
821
822    /// Return buffer to medium pool
823    fn return_to_medium_pool(&mut self, ptr: NonNull<f32>) -> bool {
824        let nowc = self.bump_op_counter();
825        for buffer in self.medium_buffers.iter_mut() {
826            if buffer.as_ptr() == ptr {
827                buffer.return_to_pool(nowc);
828                self.stats.record_deallocation(buffer.size());
829                return true;
830            }
831        }
832        false
833    }
834
835    /// Return buffer to large pool
836    fn return_to_large_pool(&mut self, ptr: NonNull<f32>) -> bool {
837        let nowc = self.bump_op_counter();
838        for buffer in self.large_buffers.iter_mut() {
839            if buffer.as_ptr() == ptr {
840                buffer.return_to_pool(nowc);
841                self.stats.record_deallocation(buffer.size());
842                return true;
843            }
844        }
845        false
846    }
847
848    /// Return buffer to xlarge pool
849    fn return_to_xlarge_pool(&mut self, ptr: NonNull<f32>) -> bool {
850        let nowc = self.bump_op_counter();
851        for buffer in self.xlarge_buffers.iter_mut() {
852            if buffer.as_ptr() == ptr {
853                buffer.return_to_pool(nowc);
854                self.stats.record_deallocation(buffer.size());
855                return true;
856            }
857        }
858        false
859    }
860
861    /// Gets statistics for the current thread's pool
862    #[cfg(test)]
863    pub fn thread_stats() -> PoolStats {
864        MEMORY_POOL.with(|pool| *pool.borrow().stats())
865    }
866
867    /// Test-only helper: return current buffer counts per pool
868    #[cfg(test)]
869    pub fn pool_sizes() -> (usize, usize, usize, usize) {
870        MEMORY_POOL.with(|pool| {
871            let p = pool.borrow();
872            (
873                p.small_buffers.len(),
874                p.medium_buffers.len(),
875                p.large_buffers.len(),
876                p.xlarge_buffers.len(),
877            )
878        })
879    }
880}
881
882impl TensorMemoryPool {
883    #[inline]
884    fn bump_op_counter(&mut self) -> u64 {
885        // Wrapping add to avoid panic on very long runs; practical overflow is unlikely
886        self.op_counter = self.op_counter.wrapping_add(1);
887        self.op_counter
888    }
889
890    /// Determine if a cleanup pass should run given time and op-counter thresholds
891    #[inline]
892    fn should_cleanup(&self) -> bool {
893        let ops_since = self.op_counter.wrapping_sub(self.last_cleanup_counter);
894        if ops_since < CLEANUP_MIN_OPS {
895            return false;
896        }
897        let elapsed = self.last_cleanup_instant.elapsed();
898        elapsed.as_millis() as u64 >= CLEANUP_MIN_INTERVAL_MS
899    }
900
901    /// Attempt to free long-idle excess buffers while preserving headroom to avoid thrash.
902    fn maybe_cleanup(&mut self) {
903        if !self.should_cleanup() {
904            return;
905        }
906
907        // Cleanup strategy per size class
908        let nowc = self.op_counter;
909        Self::cleanup_pool_vec(
910            &mut self.small_buffers,
911            TARGET_SMALL_BUFFERS,
912            HEADROOM_SMALL,
913            nowc,
914        );
915        Self::cleanup_pool_vec(
916            &mut self.medium_buffers,
917            TARGET_MEDIUM_BUFFERS,
918            HEADROOM_MEDIUM,
919            nowc,
920        );
921        Self::cleanup_pool_vec(
922            &mut self.large_buffers,
923            TARGET_LARGE_BUFFERS,
924            HEADROOM_LARGE,
925            nowc,
926        );
927        // For xlarge, keep minimal headroom; usage is often bursty and large
928        Self::cleanup_pool_vec(&mut self.xlarge_buffers, 2, HEADROOM_XLARGE, nowc);
929
930        // Update cleanup gates
931        self.last_cleanup_counter = self.op_counter;
932        self.last_cleanup_instant = Instant::now();
933    }
934
935    fn cleanup_pool_vec(
936        vec: &mut Vec<PooledBuffer>,
937        target: usize,
938        headroom: usize,
939        now_counter: u64,
940    ) {
941        if vec.is_empty() {
942            return;
943        }
944        // Compute current demand and desired capacity
945        let in_use = vec.iter().filter(|b| !b.is_available()).count();
946        let desired = core::cmp::max(target, in_use.saturating_add(headroom));
947        if vec.len() <= desired {
948            return;
949        }
950
951        // Identify eligible candidates: available and long-idle
952        let mut eligible: Vec<(usize, u64)> = vec
953            .iter()
954            .enumerate()
955            .filter(|(_i, b)| b.is_available())
956            .map(|(i, b)| (i, now_counter.wrapping_sub(b.last_used_counter)))
957            .filter(|(_i, age_ops)| *age_ops >= UNUSED_OPS_THRESHOLD)
958            .collect();
959
960        if eligible.is_empty() {
961            return;
962        }
963
964        // Prefer removing the stalest buffers first
965        eligible.sort_by_key(|(_i, age)| core::cmp::Reverse(*age));
966
967        let excess = vec.len().saturating_sub(desired);
968        let to_remove = core::cmp::min(excess, eligible.len());
969        if to_remove == 0 {
970            return;
971        }
972
973        // Remove by index from highest to lowest to avoid shifting issues
974        let mut to_drop: Vec<usize> = eligible.iter().take(to_remove).map(|(i, _)| *i).collect();
975        to_drop.sort_unstable_by(|a, b| b.cmp(a));
976        for idx in to_drop {
977            vec.remove(idx);
978        }
979    }
980}
981
#[cfg(test)]
mod tests {
    //! Unit tests for the thread-local pool: padding-guard scoping,
    //! allocation-parameter computation, per-size-class statistics,
    //! xlarge-pool reuse rules, cross-thread drops, and a pooled vs.
    //! system-allocator performance comparison.

    use super::*;

    /// The no-padding flag must be set only while a `NoMemPaddingGuard`
    /// is alive, and restored when it drops.
    #[test]
    fn test_with_no_mem_padding_guard_scoping() {
        // Default should be false
        assert!(!no_mem_padding_enabled());
        {
            let _g = NoMemPaddingGuard::new();
            assert!(no_mem_padding_enabled());
        }
        // Dropping the guard restores the default.
        assert!(!no_mem_padding_enabled());
    }

    /// With padding enabled, the allocated element count rounds up to a
    /// SIMD-lane multiple; under the no-padding scope the requested count
    /// is used verbatim.
    #[test]
    fn test_compute_allocation_params_padding_behavior() {
        // With padding enabled
        let (align1, padded1) = compute_allocation_params(33);
        let lane = simd_lane_width_elems(detect_runtime_simd());
        assert!(padded1 >= 33);
        // Padded size must be a whole number of SIMD lanes.
        assert_eq!(padded1 % lane, 0);
        assert!(align1 >= 16);

        // No padding
        let res = with_no_mem_padding(|| compute_allocation_params(33));

        // Exact requested size when padding is disabled.
        assert_eq!(res.1, 33);
    }

    /// One allocation per size class on this thread should bump the
    /// thread-local allocation and deallocation counters by at least 4.
    #[test]
    fn test_same_thread_alloc_dealloc_counters_across_classes() {
        let before = TensorMemoryPool::thread_stats();
        {
            let lane = simd_lane_width_elems(detect_runtime_simd());
            // One representative size per class: small, medium, large, xlarge.
            let sizes = [
                SMALL_BUFFER_SIZE.min(8),
                MEDIUM_BUFFER_SIZE / 2,
                LARGE_BUFFER_SIZE / 2,
                LARGE_BUFFER_SIZE + lane * 3 + 7, // xlarge request
            ];
            for &n in &sizes {
                let _t = crate::tensor::Tensor::new(vec![n]);
            }
            // Tensors drop here, so deallocations are recorded before `after`.
        }
        let after = TensorMemoryPool::thread_stats();
        assert!(after.allocations >= before.allocations + 4);
        assert!(after.deallocations >= before.deallocations + 4);
    }

    /// Requesting a larger xlarge size than a previously pooled xlarge
    /// buffer must still succeed (a fresh buffer is created rather than
    /// reusing the too-small one).
    #[test]
    fn test_xlarge_pool_does_not_reuse_too_small_buffer() {
        let lane = simd_lane_width_elems(detect_runtime_simd());
        let align = simd_alignment_bytes(detect_runtime_simd());
        // First, create an xlarge buffer of some planned capacity
        let small_xlarge = LARGE_BUFFER_SIZE + lane * 2;
        let _t1 = crate::tensor::Tensor::new(vec![small_xlarge]);
        // Now request a larger xlarge size that exceeds the prior capacity
        let larger = small_xlarge * 2 + lane * 3;
        let ptr_opt = MEMORY_POOL.with(|pool| {
            let mut p = pool.borrow_mut();
            p.try_allocate_from_xlarge_pool(larger, align, SizeClass::XLarge)
        });
        // We should get Some(ptr) from a newly created buffer; this test
        // only asserts that an allocation succeeds and the pool doesn't panic/crash.
        assert!(ptr_opt.is_some());
    }

    /// A tensor allocated on this thread may be dropped on another thread
    /// without crashing (pool buffers are thread-local).
    #[test]
    fn test_cross_thread_drop_safe_no_crash() {
        use std::thread;
        let lane = simd_lane_width_elems(detect_runtime_simd());
        let n = LARGE_BUFFER_SIZE + lane * 2 + 3; // xlarge
        let t = crate::tensor::Tensor::new(vec![n]);
        let handle = thread::spawn(move || {
            // drop in another thread
            drop(t);
        });
        let _ = handle.join();
    }

    /// A pointer obtained from the pool must be recognized on deallocation.
    #[test]
    fn test_try_deallocate_returns_some_true_for_pooled() {
        let align = simd_alignment_bytes(detect_runtime_simd());
        let ptr = TensorMemoryPool::allocate(128, align).expect("pool allocate failed");
        let res = TensorMemoryPool::try_deallocate(ptr);
        assert_eq!(res, Some(true));
    }

    /// Performance smoke test: times pooled vs. system allocation across the
    /// four size classes and prints the speedup. Only asserts that both modes
    /// produce measurable, finite results — not a strict speedup threshold.
    #[test]
    fn perf_pool_vs_no_pool_by_category_over_1000_iterations() {
        use std::time::Instant;

        // Choose representative shapes per size class
        let small = vec![32, 32]; // 1,024 elems
        let medium = vec![256, 256]; // 65,536 elems
        let large = vec![1024, 1024]; // 1,048,576 elems
        let xlarge = vec![1200, 1200]; // > large

        // Times `iters` rounds of allocate -> scalar ops -> sum for `shape`.
        fn bench_shape(shape: &[usize], iters: usize) -> std::time::Duration {
            let start = Instant::now();
            let mut sink = 0.0f32;
            for i in 0..iters {
                // Allocate
                let t0 = crate::tensor::Tensor::ones(shape.to_vec());
                // Simple API ops chain to exercise read/write paths
                let t1 = t0.add_scalar((i % 5) as f32 * 0.1);
                let t2 = t1.mul_scalar(1.2345);
                // Reduce to scalar to avoid DCE and force readback
                let s = t2.sum();
                sink += s.value();
            }
            assert!(sink.is_finite());
            start.elapsed()
        }

        let iters = 1000usize;

        let cats: [(&str, Vec<usize>); 4] = [
            ("small", small),
            ("medium", medium),
            ("large", large),
            ("xlarge", xlarge),
        ];

        for (name, shape) in cats.iter() {
            let pooled = bench_shape(shape, iters);
            // Re-run the identical workload with the pool disabled for the scope.
            let system = super::with_no_mem_pool(|| bench_shape(shape, iters));
            let pooled_ms = pooled.as_secs_f64() * 1_000.0;
            let system_ms = system.as_secs_f64() * 1_000.0;
            let speedup = if pooled_ms > 0.0 {
                system_ms / pooled_ms
            } else {
                0.0
            };
            println!(
                "Perf [{} | {:?} elems]: pooled={:.2} ms, no_pool={:.2} ms, speedup={:.2}x (iters={})",
                name,
                shape.iter().product::<usize>(),
                pooled_ms,
                system_ms,
                speedup,
                iters
            );

            // Both modes must produce a measurable duration
            assert!(pooled > std::time::Duration::from_millis(0));
            assert!(system > std::time::Duration::from_millis(0));
        }
    }
}
1133
#[cfg(test)]
mod xlarge_stress_tests {
    //! Stress tests hammering the xlarge pool with sizes just above the
    //! large-buffer threshold, single- and multi-threaded.

    use super::*;

    /// Repeatedly allocate a fixed set of xlarge sizes on one thread.
    #[test]
    fn stress_xlarge_pool_various_sizes_single_thread() {
        let lane = simd_lane_width_elems(detect_runtime_simd());
        // Sizes chosen slightly above LARGE_BUFFER_SIZE so every request
        // lands in the xlarge pool.
        let size_table = [
            LARGE_BUFFER_SIZE + 1,
            LARGE_BUFFER_SIZE * 2 + lane - 1,
            LARGE_BUFFER_SIZE * 3 + 17,
            LARGE_BUFFER_SIZE * 4 + lane * 3 + 5,
            LARGE_BUFFER_SIZE * 6 + 123,
        ];
        for _round in 0..1000 {
            for &elems in size_table.iter() {
                let mut tensor = crate::tensor::Tensor::new(vec![elems]);
                // Touch the first element so we never read uninitialized memory.
                if elems > 0 {
                    tensor.set(&[0], 0.0);
                }
                assert_eq!(tensor.size(), elems);
            }
        }
    }

    /// Same workload spread across up to 8 worker threads, with per-thread
    /// size perturbation so the threads don't all request identical capacities.
    #[test]
    fn stress_xlarge_pool_multithreaded() {
        use std::thread;

        let lane = simd_lane_width_elems(detect_runtime_simd());
        let size_table = [
            LARGE_BUFFER_SIZE + 1,
            LARGE_BUFFER_SIZE * 2 + lane - 1,
            LARGE_BUFFER_SIZE * 3 + 17,
            LARGE_BUFFER_SIZE * 4 + lane * 3 + 5,
            LARGE_BUFFER_SIZE * 6 + 123,
        ];
        // Cap at 8 workers; fall back to 8 when parallelism is unknown.
        let worker_count = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(8)
            .min(8);
        let workers: Vec<_> = (0..worker_count)
            .map(|tid| {
                thread::spawn(move || {
                    for round in 0..20 {
                        for (slot, base) in size_table.iter().enumerate() {
                            // Perturb each base size per thread/round/slot.
                            let elems = base + (tid * 13 + round * 7 + slot) % lane;
                            let mut tensor = crate::tensor::Tensor::new(vec![elems]);
                            assert_eq!(tensor.size(), elems);
                            // Write a few scattered positions via the safe API
                            // to exercise the backing memory.
                            if elems > 0 {
                                let mid = elems / 2;
                                let third = (elems.saturating_sub(1)) / 3;
                                let fifth = (elems.saturating_sub(1)) / 5;
                                for (idx, val) in
                                    [(mid, 1.2345), (third, 2.3456), (fifth, 3.4567)]
                                {
                                    if idx < tensor.size() {
                                        tensor.set(&[idx], val);
                                    }
                                }
                            }
                        }
                    }
                })
            })
            .collect();
        for handle in workers {
            let _ = handle.join();
        }
    }
}
1212
#[cfg(test)]
mod additional_safety_tests {
    //! Safety and accounting invariants: balanced alloc/dealloc counters,
    //! SIMD pointer alignment, pool-disabled scopes, cross-thread drops,
    //! and bounded `current_usage` under churn.

    use super::*;

    /// Allocating and dropping one tensor per class must record matching
    /// allocations and deallocations without growing `current_usage`.
    #[test]
    fn test_pool_alloc_dealloc_balanced_small_medium_large() {
        let before = TensorMemoryPool::thread_stats();
        {
            let _s1 = crate::tensor::Tensor::new(vec![SMALL_BUFFER_SIZE.min(16)]);
            let _m1 = crate::tensor::Tensor::new(vec![MEDIUM_BUFFER_SIZE / 4]);
            let _l1 = crate::tensor::Tensor::new(vec![LARGE_BUFFER_SIZE / 4]);
            // All three tensors drop at the end of this scope.
        }
        let after = TensorMemoryPool::thread_stats();
        assert!(
            after.allocations >= before.allocations + 3,
            "allocations did not increase as expected: before={}, after={}",
            before.allocations,
            after.allocations
        );
        assert!(
            after.deallocations >= before.deallocations + 3,
            "deallocations did not increase as expected: before={}, after={}",
            before.deallocations,
            after.deallocations
        );
        // Current usage should not grow across scope
        assert!(
            after.current_usage <= before.current_usage,
            "current_usage grew: before={}, after={}",
            before.current_usage,
            after.current_usage
        );
    }

    /// Every pool-backed tensor pointer must satisfy the runtime-detected
    /// SIMD alignment, across all size classes.
    #[test]
    fn test_pointer_alignment_across_classes() {
        let align = simd_alignment_bytes(detect_runtime_simd());
        for &n in &[
            8usize,
            SMALL_BUFFER_SIZE,
            MEDIUM_BUFFER_SIZE,
            LARGE_BUFFER_SIZE + 128,
        ] {
            let t = crate::tensor::Tensor::new(vec![n]);
            unsafe {
                let addr = t.as_ptr() as usize;
                assert_eq!(
                    addr % align,
                    0,
                    "pointer not aligned to {} for n={}",
                    align,
                    n
                );
            }
        }
    }

    /// Inside a `with_no_mem_pool` scope, allocations bypass the pool and
    /// must not alter the thread-local pool statistics.
    #[test]
    fn test_with_no_mem_pool_uses_system_allocator_no_pool_stats() {
        let before = TensorMemoryPool::thread_stats();
        with_no_mem_pool(|| {
            let _t1 = crate::tensor::Tensor::new(vec![64]);
            let _t2 = crate::tensor::Tensor::new(vec![2048]);
            let _t3 = crate::tensor::Tensor::new(vec![131072]);
        });
        let after = TensorMemoryPool::thread_stats();
        // Pool should not register hits/misses when disabled within the scope
        assert_eq!(
            after.allocations, before.allocations,
            "pool allocations changed with pool disabled: before={}, after={}",
            before.allocations, after.allocations
        );
        assert_eq!(
            after.deallocations, before.deallocations,
            "pool deallocations changed with pool disabled: before={}, after={}",
            before.deallocations, after.deallocations
        );
    }

    /// A tensor allocated on a worker thread and dropped here must not
    /// touch this thread's pool statistics (pools are thread-local).
    #[test]
    fn test_cross_thread_drop_does_not_affect_this_thread_stats() {
        let before = TensorMemoryPool::thread_stats();
        // Allocate in a worker thread and drop in this thread
        let handle =
            std::thread::spawn(|| crate::tensor::Tensor::new(vec![SMALL_BUFFER_SIZE.min(32)]));
        let t = handle.join().unwrap();
        drop(t); // Drop on current thread; should not touch this thread's pool stats
        let after = TensorMemoryPool::thread_stats();
        assert_eq!(
            after.allocations, before.allocations,
            "allocations changed in current thread due to cross-thread drop: before={}, after={}",
            before.allocations, after.allocations
        );
        // Deallocation also should not be recorded in this thread
        assert_eq!(
            after.deallocations, before.deallocations,
            "deallocations changed in current thread due to cross-thread drop: before={}, after={}",
            before.deallocations, after.deallocations
        );
    }

    /// Repeated allocate/drop cycles must keep `current_usage` bounded
    /// (buffer reuse, not unbounded growth).
    #[test]
    fn test_many_alloc_dealloc_cycles_no_growth_in_current_usage() {
        let before = TensorMemoryPool::thread_stats();
        for _ in 0..100 {
            let _t1 = crate::tensor::Tensor::new(vec![SMALL_BUFFER_SIZE.min(64)]);
            let _t2 = crate::tensor::Tensor::new(vec![MEDIUM_BUFFER_SIZE / 8]);
        }
        let after = TensorMemoryPool::thread_stats();
        // current_usage should remain bounded and not monotonically grow
        assert!(
            after.current_usage <= before.current_usage + SMALL_BUFFER_SIZE + MEDIUM_BUFFER_SIZE,
            "current_usage unexpected growth: before={}, after={}",
            before.current_usage,
            after.current_usage
        );
    }
}
1331
#[cfg(test)]
mod cleanup_tests {
    //! Tests for the idle-buffer cleanup policy. These are deliberately
    //! coupled to the cleanup gates: the op-counter threshold
    //! (`CLEANUP_MIN_OPS` / `UNUSED_OPS_THRESHOLD`) and the wall-clock
    //! interval (`CLEANUP_MIN_INTERVAL_MS` — the 2100 ms sleeps below are
    //! sized to exceed it; adjust if the constant changes).

    use super::*;
    use std::thread;
    use std::time::Duration;

    // Helper to create and hold N tensors of the given element count (single-dim)
    fn hold_tensors(count: usize, elems: usize) -> Vec<crate::tensor::Tensor> {
        let mut v = Vec::with_capacity(count);
        for _ in 0..count {
            v.push(crate::tensor::Tensor::new(vec![elems]));
        }
        v
    }

    // Helper to bump pool op counters by performing lightweight small allocations
    // (each allocation+drop advances the pool's op counter).
    fn bump_ops_small_iters(iters: usize) {
        for _ in 0..iters {
            let _t = crate::tensor::Tensor::new(vec![SMALL_BUFFER_SIZE.min(8)]);
        }
    }

    /// Cleanup must never trim buffers that are still in use, and must not
    /// trim freshly released buffers whose idle age is below threshold.
    #[test]
    fn test_no_cleanup_while_many_small_buffers_in_use() {
        // Prime the pool with many small buffers held alive
        let holders = hold_tensors(40, SMALL_BUFFER_SIZE.min(32));
        let (small_before, _, _, _) = TensorMemoryPool::pool_sizes();
        assert!(
            small_before >= 40,
            "expected >=40 small buffers, got {}",
            small_before
        );

        // Bump op counters and time while these buffers remain in use
        bump_ops_small_iters(1500); // ~3000 ops
        thread::sleep(Duration::from_millis(2100));
        bump_ops_small_iters(700); // exceed thresholds

        // Trigger a cleanup attempt via an allocation in another size class (medium)
        {
            let _m = crate::tensor::Tensor::new(vec![MEDIUM_BUFFER_SIZE / 2]);
        }

        // While buffers are still in-use, no trimming should occur (len must not decrease)
        let (small_mid, _, _, _) = TensorMemoryPool::pool_sizes();
        assert!(
            small_mid >= small_before,
            "small pool shrank while heavily in-use: before={} after={}",
            small_before,
            small_mid
        );

        // Now drop the holders; their last_used timestamps are fresh, so cleanup shouldn't trim them
        drop(holders);

        // Trigger cleanup again
        let _ = crate::tensor::Tensor::new(vec![MEDIUM_BUFFER_SIZE / 2]);
        let (small_after, _, _, _) = TensorMemoryPool::pool_sizes();
        assert!(
            small_after >= small_before,
            "small pool unexpectedly trimmed active buffers: before={} after={}",
            small_before,
            small_after
        );
    }

    /// After both the op-count and time gates pass, long-idle available
    /// medium buffers must actually be trimmed.
    #[test]
    fn test_cleanup_trims_long_idle_medium_buffers() {
        // Create many medium buffers simultaneously to grow pool capacity
        {
            let _holders = hold_tensors(30, MEDIUM_BUFFER_SIZE / 2);
            // _holders dropped at end of scope, all buffers become available
        }
        let (_, med_before, _, _) = TensorMemoryPool::pool_sizes();
        assert!(
            med_before >= 30,
            "expected >=30 medium buffers, got {}",
            med_before
        );

        // Leave medium buffers idle; bump ops using small allocations and wait to satisfy time gate
        bump_ops_small_iters(2300); // ~4600 ops (> UNUSED_OPS_THRESHOLD)
        thread::sleep(Duration::from_millis(2100));

        // Trigger cleanup and observe trimming
        let _ = crate::tensor::Tensor::new(vec![SMALL_BUFFER_SIZE.min(16)]);
        let (_, med_after, _, _) = TensorMemoryPool::pool_sizes();

        assert!(
            med_after < med_before,
            "medium pool not trimmed despite long idle: before={} after={}",
            med_before,
            med_after
        );
    }
}