// scirs2_autograd/memory_pool.rs
1//! Tensor memory pool for reducing allocation pressure during training.
2//!
3//! This module provides a thread-safe memory pool that reuses gradient buffers
4//! and intermediate tensor allocations. During training loops, the same shapes
5//! are allocated and deallocated repeatedly; the pool caches these buffers
6//! by shape so subsequent requests can skip the allocator entirely.
7//!
8//! # Architecture
9//!
10//! - **`TensorPool`**: The core pool, keyed by `(shape, TypeId)`. Each bucket
11//!   holds a `Vec` of recycled `NdArray<F>` buffers. Protected by a
12//!   `parking_lot::Mutex` for low-overhead locking.
13//!
14//! - **`PooledArray<F>`**: An RAII wrapper around `NdArray<F>` that, on drop,
15//!   returns its buffer to the pool for reuse. Implements `Deref` / `DerefMut`
16//!   so it can be used transparently wherever `NdArray<F>` is expected.
17//!
18//! - **`PoolStats`**: Lightweight counters exposed via `TensorPool::stats()`.
19//!
20//! # Thread Safety
21//!
22//! The pool is `Send + Sync`. A global singleton is provided via
23//! [`global_pool`] for convenience, but callers may also create dedicated
24//! pools.
25//!
26//! # Example
27//!
28//! ```rust
29//! use scirs2_autograd::memory_pool::{global_pool, PooledArray};
30//!
31//! // Acquire a buffer with shape [64, 128]
32//! let buf: PooledArray<f64> = global_pool().acquire(&[64, 128]);
33//! assert_eq!(buf.shape(), &[64, 128]);
34//!
35//! // When `buf` is dropped it is returned to the pool automatically.
36//! drop(buf);
37//!
38//! // Next acquire with the same shape reuses the buffer (zero fresh allocations).
39//! let buf2: PooledArray<f64> = global_pool().acquire(&[64, 128]);
40//! assert_eq!(buf2.shape(), &[64, 128]);
41//! ```
42
43use crate::ndarray_ext::NdArray;
44use crate::Float;
45use once_cell::sync::Lazy;
46use parking_lot::Mutex;
47use std::any::TypeId;
48use std::collections::HashMap;
49use std::fmt;
50use std::ops::{Deref, DerefMut};
51use std::sync::atomic::{AtomicU64, Ordering};
52use std::sync::Arc;
53
54// ---------------------------------------------------------------------------
55// PoolStats
56// ---------------------------------------------------------------------------
57
/// Cumulative statistics for a [`TensorPool`].
///
/// Obtained via [`TensorPool::stats`]. The value is a point-in-time
/// snapshot taken under the pool lock, not a live view — the counters
/// keep advancing after the call returns.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStats {
    /// Total number of `acquire` calls.
    pub n_acquired: u64,
    /// Total number of `release` calls (including automatic drops from `PooledArray`).
    pub n_released: u64,
    /// Number of times a fresh allocation was required (pool miss).
    pub n_allocated: u64,
    /// Number of times an existing buffer was reused (pool hit).
    pub n_reused: u64,
    /// Approximate total bytes currently held *inside* the pool (not checked out).
    pub pool_bytes: u64,
    /// Number of distinct shape buckets.
    pub n_buckets: u64,
    /// Total buffers currently sitting in the pool.
    pub n_pooled_buffers: u64,
}
76
77impl fmt::Display for PoolStats {
78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79        write!(
80            f,
81            "PoolStats {{ acquired: {}, released: {}, allocated: {}, reused: {}, \
82             pool_bytes: {}, buckets: {}, pooled_buffers: {} }}",
83            self.n_acquired,
84            self.n_released,
85            self.n_allocated,
86            self.n_reused,
87            self.pool_bytes,
88            self.n_buckets,
89            self.n_pooled_buffers,
90        )
91    }
92}
93
94// ---------------------------------------------------------------------------
95// BucketKey
96// ---------------------------------------------------------------------------
97
/// A key for the internal bucket map: (shape, element TypeId).
///
/// Keying on `TypeId` in addition to the shape guarantees that, e.g., an
/// `f32` buffer is never handed back to an `f64` request of the same shape.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct BucketKey {
    /// Dimensions shared by all buffers in this bucket.
    shape: Vec<usize>,
    /// `TypeId` of the element type `F` the buffers were created for.
    type_id: TypeId,
}
104
105// ---------------------------------------------------------------------------
106// TensorPool
107// ---------------------------------------------------------------------------
108
/// Thread-safe pool that caches `NdArray<F>` buffers by shape.
///
/// The pool is cheap to clone (it is internally reference-counted);
/// clones share the same buckets and statistics counters.
pub struct TensorPool {
    /// Shared state; `Arc` so `Clone` hands out another handle to the same pool.
    inner: Arc<TensorPoolInner>,
}
115
/// Shared state behind every [`TensorPool`] handle.
struct TensorPoolInner {
    /// Buckets keyed by (shape, TypeId).
    buckets: Mutex<HashMap<BucketKey, Vec<ErasedArray>>>,
    // Atomic counters so stats queries do not need the lock.
    n_acquired: AtomicU64,
    n_released: AtomicU64,
    n_allocated: AtomicU64,
    n_reused: AtomicU64,
    /// Maximum number of buffers retained per bucket (0 = unlimited).
    max_per_bucket: usize,
}
127
128/// Type-erased array storage. We store the raw `Vec<u8>` backing plus
129/// the shape, and reconstruct the typed array on retrieval.
130///
131/// Safety: the `data` buffer was originally allocated as `Vec<F>` for some
132/// concrete `F: Float`. We only hand it back when the caller requests the
133/// same `TypeId`.
134struct ErasedArray {
135    /// Raw bytes. Length = num_elements * size_of::<F>().
136    data: Vec<u8>,
137    /// The shape that was used when the array was created.
138    shape: Vec<usize>,
139    /// size_of one element (used for byte-level accounting).
140    elem_size: usize,
141}
142
143impl ErasedArray {
144    /// Approximate byte size of the buffer.
145    fn byte_size(&self) -> usize {
146        self.data.len()
147    }
148}
149
150impl Clone for TensorPool {
151    fn clone(&self) -> Self {
152        Self {
153            inner: Arc::clone(&self.inner),
154        }
155    }
156}
157
158// Explicit Send + Sync: ErasedArray contains only a Vec<u8> and Vec<usize>.
159// The Mutex ensures synchronized access.
160unsafe impl Send for TensorPoolInner {}
161unsafe impl Sync for TensorPoolInner {}
162
163impl fmt::Debug for TensorPool {
164    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165        let stats = self.stats();
166        f.debug_struct("TensorPool").field("stats", &stats).finish()
167    }
168}
169
170impl Default for TensorPool {
171    fn default() -> Self {
172        Self::new()
173    }
174}
175
176impl TensorPool {
177    /// Create a new, empty pool with no per-bucket limit.
178    pub fn new() -> Self {
179        Self::with_max_per_bucket(0)
180    }
181
182    /// Create a new pool that retains at most `max` buffers per shape bucket.
183    ///
184    /// A value of `0` means unlimited.
185    pub fn with_max_per_bucket(max: usize) -> Self {
186        Self {
187            inner: Arc::new(TensorPoolInner {
188                buckets: Mutex::new(HashMap::new()),
189                n_acquired: AtomicU64::new(0),
190                n_released: AtomicU64::new(0),
191                n_allocated: AtomicU64::new(0),
192                n_reused: AtomicU64::new(0),
193                max_per_bucket: max,
194            }),
195        }
196    }
197
198    /// Acquire a zeroed buffer with the given `shape`.
199    ///
200    /// If the pool contains a buffer with a matching shape and type it is
201    /// reused (and zeroed); otherwise a fresh allocation is made.
202    ///
203    /// The returned [`PooledArray`] will automatically return its buffer to
204    /// this pool when dropped.
205    pub fn acquire<F: Float>(&self, shape: &[usize]) -> PooledArray<F> {
206        self.inner.n_acquired.fetch_add(1, Ordering::Relaxed);
207
208        let key = BucketKey {
209            shape: shape.to_vec(),
210            type_id: TypeId::of::<F>(),
211        };
212
213        let maybe_erased = {
214            let mut buckets = self.inner.buckets.lock();
215            buckets.get_mut(&key).and_then(|v| v.pop())
216        };
217
218        let array = if let Some(erased) = maybe_erased {
219            self.inner.n_reused.fetch_add(1, Ordering::Relaxed);
220            erased_to_ndarray::<F>(erased)
221        } else {
222            self.inner.n_allocated.fetch_add(1, Ordering::Relaxed);
223            NdArray::<F>::zeros(scirs2_core::ndarray::IxDyn(shape))
224        };
225
226        PooledArray {
227            array: Some(array),
228            pool: self.clone(),
229        }
230    }
231
232    /// Manually return a buffer to the pool for later reuse.
233    ///
234    /// Prefer relying on [`PooledArray`]'s `Drop` impl instead of calling
235    /// this directly. This method is useful when you have a bare `NdArray`
236    /// obtained from elsewhere.
237    pub fn release<F: Float>(&self, array: NdArray<F>) {
238        self.inner.n_released.fetch_add(1, Ordering::Relaxed);
239        self.release_inner::<F>(array);
240    }
241
242    /// Internal release that inserts into the bucket.
243    fn release_inner<F: Float>(&self, array: NdArray<F>) {
244        let key = BucketKey {
245            shape: array.shape().to_vec(),
246            type_id: TypeId::of::<F>(),
247        };
248
249        let erased = ndarray_to_erased(array);
250
251        let mut buckets = self.inner.buckets.lock();
252        let bucket = buckets.entry(key).or_default();
253
254        // Respect max_per_bucket (0 means unlimited).
255        if self.inner.max_per_bucket == 0 || bucket.len() < self.inner.max_per_bucket {
256            bucket.push(erased);
257        }
258        // else: buffer is simply dropped (deallocated).
259    }
260
261    /// Remove all cached buffers from the pool.
262    pub fn clear(&self) {
263        let mut buckets = self.inner.buckets.lock();
264        buckets.clear();
265    }
266
267    /// Return a snapshot of current pool statistics.
268    pub fn stats(&self) -> PoolStats {
269        let buckets = self.inner.buckets.lock();
270        let mut pool_bytes: u64 = 0;
271        let mut n_pooled_buffers: u64 = 0;
272        for bucket in buckets.values() {
273            for erased in bucket {
274                pool_bytes = pool_bytes.saturating_add(erased.byte_size() as u64);
275            }
276            n_pooled_buffers = n_pooled_buffers.saturating_add(bucket.len() as u64);
277        }
278
279        PoolStats {
280            n_acquired: self.inner.n_acquired.load(Ordering::Relaxed),
281            n_released: self.inner.n_released.load(Ordering::Relaxed),
282            n_allocated: self.inner.n_allocated.load(Ordering::Relaxed),
283            n_reused: self.inner.n_reused.load(Ordering::Relaxed),
284            pool_bytes,
285            n_buckets: buckets.len() as u64,
286            n_pooled_buffers,
287        }
288    }
289
290    /// Reset the atomic counters (acquired, released, allocated, reused) to
291    /// zero. Does *not* clear the pooled buffers.
292    pub fn reset_stats(&self) {
293        self.inner.n_acquired.store(0, Ordering::Relaxed);
294        self.inner.n_released.store(0, Ordering::Relaxed);
295        self.inner.n_allocated.store(0, Ordering::Relaxed);
296        self.inner.n_reused.store(0, Ordering::Relaxed);
297    }
298}
299
300// ---------------------------------------------------------------------------
301// Erased <-> Typed conversions
302// ---------------------------------------------------------------------------
303
304/// Convert a typed `NdArray<F>` into an `ErasedArray` without re-allocating.
305fn ndarray_to_erased<F: Float>(array: NdArray<F>) -> ErasedArray {
306    let shape = array.shape().to_vec();
307    let elem_size = std::mem::size_of::<F>();
308
309    // Convert the ndarray into a flat Vec<F>, then transmute to Vec<u8>.
310    let vec_f: Vec<F> = array.into_raw_vec_and_offset().0;
311    let len = vec_f.len();
312    let cap = vec_f.capacity();
313
314    let ptr = vec_f.as_ptr();
315    std::mem::forget(vec_f);
316
317    // Safety: F is Copy + sized, so reinterpreting as bytes is sound.
318    let data = unsafe { Vec::from_raw_parts(ptr as *mut u8, len * elem_size, cap * elem_size) };
319
320    ErasedArray {
321        data,
322        shape,
323        elem_size,
324    }
325}
326
327/// Reconstruct a typed `NdArray<F>` from an `ErasedArray` and zero its contents.
328fn erased_to_ndarray<F: Float>(erased: ErasedArray) -> NdArray<F> {
329    let elem_size = std::mem::size_of::<F>();
330    debug_assert_eq!(erased.elem_size, elem_size);
331
332    let byte_len = erased.data.len();
333    let byte_cap = erased.data.capacity();
334    let ptr = erased.data.as_ptr();
335    std::mem::forget(erased.data);
336
337    let f_len = byte_len / elem_size;
338    let f_cap = byte_cap / elem_size;
339
340    // Safety: the bytes were originally produced from a Vec<F>.
341    let mut vec_f: Vec<F> = unsafe { Vec::from_raw_parts(ptr as *mut F, f_len, f_cap) };
342
343    // Zero the buffer so the caller gets a clean slate.
344    for v in vec_f.iter_mut() {
345        *v = F::zero();
346    }
347
348    NdArray::<F>::from_shape_vec(scirs2_core::ndarray::IxDyn(&erased.shape), vec_f).unwrap_or_else(
349        |_| {
350            // Fallback: allocate a fresh zero array. This path should be
351            // unreachable because we preserved the original shape.
352            NdArray::<F>::zeros(scirs2_core::ndarray::IxDyn(&erased.shape))
353        },
354    )
355}
356
357// ---------------------------------------------------------------------------
358// PooledArray
359// ---------------------------------------------------------------------------
360
/// An RAII wrapper around `NdArray<F>` that returns its buffer to a
/// [`TensorPool`] when dropped.
///
/// `PooledArray<F>` dereferences to `NdArray<F>`, so it can be used
/// transparently in any context that expects `&NdArray<F>` or
/// `&mut NdArray<F>`.
pub struct PooledArray<F: Float> {
    /// `Some` while alive; taken in `Drop::drop` (or by `into_inner`) so
    /// the buffer is recycled at most once.
    array: Option<NdArray<F>>,
    /// Handle to the pool the buffer is returned to on drop.
    pool: TensorPool,
}
372
373impl<F: Float> PooledArray<F> {
374    /// Consume the wrapper and return the inner `NdArray<F>` **without**
375    /// returning it to the pool. The caller takes ownership.
376    pub fn into_inner(mut self) -> NdArray<F> {
377        // Take the array so Drop does not recycle it.
378        self.array
379            .take()
380            .expect("PooledArray inner array already taken")
381    }
382
383    /// Get a reference to the inner array shape.
384    pub fn shape(&self) -> &[usize] {
385        match &self.array {
386            Some(a) => a.shape(),
387            None => &[],
388        }
389    }
390}
391
392impl<F: Float> Deref for PooledArray<F> {
393    type Target = NdArray<F>;
394
395    fn deref(&self) -> &Self::Target {
396        self.array
397            .as_ref()
398            .expect("PooledArray inner array already taken")
399    }
400}
401
402impl<F: Float> DerefMut for PooledArray<F> {
403    fn deref_mut(&mut self) -> &mut Self::Target {
404        self.array
405            .as_mut()
406            .expect("PooledArray inner array already taken")
407    }
408}
409
410impl<F: Float> Drop for PooledArray<F> {
411    fn drop(&mut self) {
412        if let Some(array) = self.array.take() {
413            self.pool.inner.n_released.fetch_add(1, Ordering::Relaxed);
414            self.pool.release_inner::<F>(array);
415        }
416    }
417}
418
419impl<F: Float> fmt::Debug for PooledArray<F> {
420    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
421        match &self.array {
422            Some(a) => write!(f, "PooledArray(shape={:?})", a.shape()),
423            None => write!(f, "PooledArray(<taken>)"),
424        }
425    }
426}
427
428// ---------------------------------------------------------------------------
429// Global singleton
430// ---------------------------------------------------------------------------
431
/// The process-wide global tensor pool.
///
/// `Lazy` defers construction until the first [`global_pool`] call; the
/// pool lives for the remainder of the process.
static GLOBAL_POOL: Lazy<TensorPool> = Lazy::new(TensorPool::new);

/// Returns a reference to the process-wide global [`TensorPool`].
///
/// This is a convenience for the common case where a single shared pool is
/// sufficient.
pub fn global_pool() -> &'static TensorPool {
    &GLOBAL_POOL
}
442
443// ---------------------------------------------------------------------------
444// Tests
445// ---------------------------------------------------------------------------
446
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_acquire_returns_zero_array() {
        let pool = TensorPool::new();
        let buf: PooledArray<f64> = pool.acquire(&[3, 4]);
        assert_eq!(buf.shape(), &[3, 4]);
        // All elements should be zero.
        for &v in buf.iter() {
            assert!((v - 0.0).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn test_acquire_release_reuse_cycle() {
        let pool = TensorPool::new();

        // First acquire: fresh allocation.
        let buf1: PooledArray<f64> = pool.acquire(&[8, 16]);
        let stats1 = pool.stats();
        assert_eq!(stats1.n_acquired, 1);
        assert_eq!(stats1.n_allocated, 1);
        assert_eq!(stats1.n_reused, 0);

        // Drop returns to pool.
        drop(buf1);
        let stats2 = pool.stats();
        assert_eq!(stats2.n_released, 1);
        assert_eq!(stats2.n_pooled_buffers, 1);

        // Second acquire with same shape: reuse.
        let buf2: PooledArray<f64> = pool.acquire(&[8, 16]);
        let stats3 = pool.stats();
        assert_eq!(stats3.n_acquired, 2);
        assert_eq!(stats3.n_allocated, 1); // still 1
        assert_eq!(stats3.n_reused, 1);

        // The reused buffer should be zeroed.
        for &v in buf2.iter() {
            assert!((v - 0.0).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn test_different_shapes_get_different_buckets() {
        let pool = TensorPool::new();

        // [2, 3] and [3, 2] hold the same element count but are distinct keys.
        let a: PooledArray<f64> = pool.acquire(&[2, 3]);
        let b: PooledArray<f64> = pool.acquire(&[3, 2]);

        drop(a);
        drop(b);

        let stats = pool.stats();
        assert_eq!(stats.n_buckets, 2);
        assert_eq!(stats.n_pooled_buffers, 2);
    }

    #[test]
    fn test_different_types_get_different_buckets() {
        let pool = TensorPool::new();

        // Same shape, different element type: the TypeId in the key splits them.
        let a: PooledArray<f32> = pool.acquire(&[4, 4]);
        let b: PooledArray<f64> = pool.acquire(&[4, 4]);

        drop(a);
        drop(b);

        let stats = pool.stats();
        assert_eq!(stats.n_buckets, 2);
    }

    #[test]
    fn test_manual_release() {
        let pool = TensorPool::new();
        let arr: NdArray<f64> = NdArray::zeros(scirs2_core::ndarray::IxDyn(&[5, 5]));
        pool.release(arr);

        let stats = pool.stats();
        assert_eq!(stats.n_released, 1);
        assert_eq!(stats.n_pooled_buffers, 1);

        // Acquire should reuse.
        let buf: PooledArray<f64> = pool.acquire(&[5, 5]);
        let stats2 = pool.stats();
        assert_eq!(stats2.n_reused, 1);
        assert_eq!(stats2.n_allocated, 0);
        drop(buf);
    }

    #[test]
    fn test_clear_empties_pool() {
        let pool = TensorPool::new();

        let a: PooledArray<f64> = pool.acquire(&[10, 10]);
        drop(a);

        assert_eq!(pool.stats().n_pooled_buffers, 1);

        pool.clear();

        assert_eq!(pool.stats().n_pooled_buffers, 0);
        assert_eq!(pool.stats().n_buckets, 0);
    }

    #[test]
    fn test_into_inner_does_not_return_to_pool() {
        let pool = TensorPool::new();

        let buf: PooledArray<f64> = pool.acquire(&[3, 3]);
        let _arr: NdArray<f64> = buf.into_inner();

        // No release should have happened.
        let stats = pool.stats();
        assert_eq!(stats.n_released, 0);
        assert_eq!(stats.n_pooled_buffers, 0);
    }

    #[test]
    fn test_stats_display() {
        let pool = TensorPool::new();
        let _a: PooledArray<f64> = pool.acquire(&[2]);
        let display = format!("{}", pool.stats());
        assert!(display.contains("acquired: 1"));
    }

    #[test]
    fn test_pool_stats_pool_bytes() {
        let pool = TensorPool::new();

        let buf: PooledArray<f64> = pool.acquire(&[100]);
        drop(buf);

        let stats = pool.stats();
        // 100 f64 elements = 800 bytes
        assert_eq!(stats.pool_bytes, 800);
    }

    #[test]
    fn test_reset_stats() {
        let pool = TensorPool::new();

        let buf: PooledArray<f64> = pool.acquire(&[4]);
        drop(buf);

        pool.reset_stats();

        let stats = pool.stats();
        assert_eq!(stats.n_acquired, 0);
        assert_eq!(stats.n_released, 0);
        assert_eq!(stats.n_allocated, 0);
        assert_eq!(stats.n_reused, 0);
        // Buffers are still in the pool (reset_stats does not clear).
        assert_eq!(stats.n_pooled_buffers, 1);
    }

    #[test]
    fn test_max_per_bucket() {
        let pool = TensorPool::with_max_per_bucket(2);

        // Acquire/release 5 times with the same shape. (After the first
        // cycle, acquires reuse the pooled buffer, so only a few distinct
        // buffers ever exist; the pool must never retain more than the cap.)
        for _ in 0..5 {
            let buf: PooledArray<f64> = pool.acquire(&[10]);
            drop(buf);
        }

        // At most the configured cap should be retained.
        assert!(pool.stats().n_pooled_buffers <= 2);
    }

    #[test]
    fn test_global_pool_accessible() {
        let pool = global_pool();
        let _buf: PooledArray<f64> = pool.acquire(&[1]);
        // Just verify it doesn't panic.
    }

    #[test]
    fn test_deref_mut() {
        let pool = TensorPool::new();
        let mut buf: PooledArray<f64> = pool.acquire(&[3]);

        // Write through DerefMut.
        buf[[0]] = 42.0;
        assert!((buf[[0]] - 42.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_debug_format() {
        let pool = TensorPool::new();
        let buf: PooledArray<f64> = pool.acquire(&[2, 3]);
        let dbg = format!("{:?}", buf);
        assert!(dbg.contains("PooledArray"));
        assert!(dbg.contains("[2, 3]"));
    }

    #[test]
    fn test_pool_debug_format() {
        let pool = TensorPool::new();
        let dbg = format!("{:?}", pool);
        assert!(dbg.contains("TensorPool"));
    }

    #[test]
    fn test_pool_clone_shares_state() {
        let pool1 = TensorPool::new();
        let pool2 = pool1.clone();

        let buf: PooledArray<f64> = pool1.acquire(&[4]);
        drop(buf);

        // pool2 should see the same stats because they share the Arc.
        let stats = pool2.stats();
        assert_eq!(stats.n_acquired, 1);
        assert_eq!(stats.n_released, 1);
        assert_eq!(stats.n_pooled_buffers, 1);
    }

    #[test]
    fn test_scalar_shape() {
        // 0-dimensional (scalar) arrays should round-trip like any other shape.
        let pool = TensorPool::new();
        let buf: PooledArray<f64> = pool.acquire(&[]);
        assert_eq!(buf.shape(), &[] as &[usize]);
        drop(buf);

        let buf2: PooledArray<f64> = pool.acquire(&[]);
        assert_eq!(pool.stats().n_reused, 1);
        drop(buf2);
    }

    #[test]
    fn test_f32_pool() {
        let pool = TensorPool::new();
        let buf: PooledArray<f32> = pool.acquire(&[5, 5]);
        assert_eq!(buf.shape(), &[5, 5]);
        for &v in buf.iter() {
            assert!((v - 0.0f32).abs() < f32::EPSILON);
        }
        drop(buf);

        let stats = pool.stats();
        assert_eq!(stats.pool_bytes, 100); // 25 * 4 bytes
    }

    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;
        use std::thread;

        let pool = Arc::new(TensorPool::new());
        let n_threads = 8;
        let n_ops_per_thread = 100;

        let mut handles = Vec::with_capacity(n_threads);

        for _ in 0..n_threads {
            let pool = Arc::clone(&pool);
            handles.push(thread::spawn(move || {
                for i in 0..n_ops_per_thread {
                    // Use a few different shapes to create contention.
                    let shape = match i % 3 {
                        0 => vec![16, 32],
                        1 => vec![32, 16],
                        _ => vec![64],
                    };
                    let mut buf: PooledArray<f64> = pool.acquire(&shape);
                    // Do a tiny bit of work using the first element via iter_mut.
                    if let Some(v) = buf.iter_mut().next() {
                        *v = 1.0;
                    }
                    drop(buf);
                }
            }));
        }

        for h in handles {
            h.join().expect("thread panicked");
        }

        // Counter invariants: every acquire is either a miss or a hit, and
        // every acquired buffer was eventually released.
        let stats = pool.stats();
        assert_eq!(stats.n_acquired, (n_threads * n_ops_per_thread) as u64,);
        assert_eq!(stats.n_acquired, stats.n_allocated + stats.n_reused);
        assert_eq!(stats.n_released, stats.n_acquired);
    }

    #[test]
    fn test_concurrent_mixed_types() {
        use std::sync::Arc;
        use std::thread;

        let pool = Arc::new(TensorPool::new());
        let n_threads = 4;
        let n_ops = 50;

        let mut handles = Vec::with_capacity(n_threads * 2);

        // f64 threads
        for _ in 0..n_threads {
            let pool = Arc::clone(&pool);
            handles.push(thread::spawn(move || {
                for _ in 0..n_ops {
                    let buf: PooledArray<f64> = pool.acquire(&[8, 8]);
                    drop(buf);
                }
            }));
        }

        // f32 threads
        for _ in 0..n_threads {
            let pool = Arc::clone(&pool);
            handles.push(thread::spawn(move || {
                for _ in 0..n_ops {
                    let buf: PooledArray<f32> = pool.acquire(&[8, 8]);
                    drop(buf);
                }
            }));
        }

        for h in handles {
            h.join().expect("thread panicked");
        }

        let stats = pool.stats();
        let total_ops = (n_threads * 2 * n_ops) as u64;
        assert_eq!(stats.n_acquired, total_ops);
    }

    #[test]
    fn test_large_shape() {
        let pool = TensorPool::new();
        let buf: PooledArray<f64> = pool.acquire(&[256, 256]);
        assert_eq!(buf.shape(), &[256, 256]);
        assert_eq!(buf.len(), 256 * 256);
        drop(buf);

        let stats = pool.stats();
        assert_eq!(stats.pool_bytes, (256 * 256 * 8) as u64);
    }

    #[test]
    fn test_reused_buffer_is_zeroed() {
        let pool = TensorPool::new();

        // Acquire, fill with non-zero, release.
        let mut buf: PooledArray<f64> = pool.acquire(&[4]);
        buf[[0]] = 99.0;
        buf[[1]] = 88.0;
        buf[[2]] = 77.0;
        buf[[3]] = 66.0;
        drop(buf);

        // Re-acquire: should be zeroed.
        let buf2: PooledArray<f64> = pool.acquire(&[4]);
        for &v in buf2.iter() {
            assert!((v - 0.0).abs() < f64::EPSILON, "expected zero, got {}", v);
        }
    }

    #[test]
    fn test_multiple_buffers_same_shape() {
        let pool = TensorPool::new();

        // Release multiple buffers of the same shape.
        for _ in 0..5 {
            let arr: NdArray<f64> = NdArray::zeros(scirs2_core::ndarray::IxDyn(&[3]));
            pool.release(arr);
        }

        assert_eq!(pool.stats().n_pooled_buffers, 5);

        // Acquire them all at once (keeping each alive so it is not returned).
        let mut held: Vec<PooledArray<f64>> = Vec::with_capacity(5);
        for i in 0..5 {
            held.push(pool.acquire(&[3]));
            assert_eq!(pool.stats().n_pooled_buffers, 4 - i as u64);
        }
        // Drop all at once; all 5 buffers are returned to the pool.
        drop(held);
        assert_eq!(pool.stats().n_pooled_buffers, 5);
    }
}