1use crate::ndarray_ext::NdArray;
44use crate::Float;
45use once_cell::sync::Lazy;
46use parking_lot::Mutex;
47use std::any::TypeId;
48use std::collections::HashMap;
49use std::fmt;
50use std::ops::{Deref, DerefMut};
51use std::sync::atomic::{AtomicU64, Ordering};
52use std::sync::Arc;
53
/// A point-in-time snapshot of a [`TensorPool`]'s counters and contents.
///
/// Produced by [`TensorPool::stats`]; the lifetime counters are updated with
/// relaxed atomics, so values are approximate under heavy concurrency.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStats {
    /// Total number of `acquire` calls since creation (or last reset).
    pub n_acquired: u64,
    /// Total number of buffers handed back to the pool (drop or explicit release).
    pub n_released: u64,
    /// Number of acquires that had to allocate a fresh buffer.
    pub n_allocated: u64,
    /// Number of acquires served by recycling a pooled buffer.
    pub n_reused: u64,
    /// Total bytes currently held by buffers parked in the pool.
    pub pool_bytes: u64,
    /// Number of distinct (shape, element type) buckets currently present.
    pub n_buckets: u64,
    /// Number of buffers currently parked across all buckets.
    pub n_pooled_buffers: u64,
}
76
77impl fmt::Display for PoolStats {
78 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79 write!(
80 f,
81 "PoolStats {{ acquired: {}, released: {}, allocated: {}, reused: {}, \
82 pool_bytes: {}, buckets: {}, pooled_buffers: {} }}",
83 self.n_acquired,
84 self.n_released,
85 self.n_allocated,
86 self.n_reused,
87 self.pool_bytes,
88 self.n_buckets,
89 self.n_pooled_buffers,
90 )
91 }
92}
93
/// Key identifying a bucket of interchangeable buffers.
///
/// Buffers are only reused for an exact (shape, element type) match, so a
/// recycled buffer is always byte-compatible with the request that pops it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct BucketKey {
    /// Logical shape the buffer was released with.
    shape: Vec<usize>,
    /// `TypeId` of the element type `F` the buffer was created for.
    type_id: TypeId,
}
104
/// A shape/type-keyed pool of tensor buffers.
///
/// Cloning the handle shares the underlying pool state; see [`global_pool`]
/// for a lazily-initialized process-wide instance.
pub struct TensorPool {
    // Shared state: `Clone` bumps the refcount instead of copying buckets.
    inner: Arc<TensorPoolInner>,
}
115
/// State shared by every clone of a [`TensorPool`] handle.
struct TensorPoolInner {
    /// Pooled buffers, grouped by (shape, element type).
    buckets: Mutex<HashMap<BucketKey, Vec<ErasedArray>>>,
    // Lifetime counters. Updated with `Ordering::Relaxed`, so they are
    // monotone but only approximately consistent with each other.
    n_acquired: AtomicU64,
    n_released: AtomicU64,
    n_allocated: AtomicU64,
    n_reused: AtomicU64,
    /// Maximum buffers retained per bucket; `0` means unbounded.
    max_per_bucket: usize,
}
127
/// A type-erased tensor buffer parked in the pool.
///
/// NOTE(review): `data` is produced by reinterpreting a `Vec<F>` allocation as
/// `Vec<u8>` (see `ndarray_to_erased`). Dropping it as a `Vec<u8>` (e.g. in
/// `TensorPool::clear`, or when a full bucket discards a buffer) deallocates
/// with an alignment different from the original allocation's, which violates
/// the allocator contract — confirm and consider storing the original `Vec<F>`
/// behind a type-erased box instead.
struct ErasedArray {
    /// Raw bytes of the buffer. Contents are stale; zeroed again on reuse.
    data: Vec<u8>,
    /// Logical shape the buffer was released with.
    shape: Vec<usize>,
    /// `size_of` the original element type; sanity-checked on reuse.
    elem_size: usize,
}
142
impl ErasedArray {
    /// Size of the pooled buffer in bytes; feeds the `pool_bytes` statistic.
    fn byte_size(&self) -> usize {
        self.data.len()
    }
}
149
150impl Clone for TensorPool {
151 fn clone(&self) -> Self {
152 Self {
153 inner: Arc::clone(&self.inner),
154 }
155 }
156}
157
// `TensorPoolInner` is `Send + Sync` automatically: every field
// (`parking_lot::Mutex<HashMap<..>>`, `AtomicU64`, `usize`) is itself
// `Send + Sync`, and the contained `ErasedArray`/`BucketKey` hold only owned
// `Vec`s, `usize`, and `TypeId`. The previous hand-written
// `unsafe impl Send`/`unsafe impl Sync` were removed: they were redundant and
// would have silently suppressed the compiler's thread-safety checking if a
// non-`Send` field were ever added.
162
163impl fmt::Debug for TensorPool {
164 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165 let stats = self.stats();
166 f.debug_struct("TensorPool").field("stats", &stats).finish()
167 }
168}
169
170impl Default for TensorPool {
171 fn default() -> Self {
172 Self::new()
173 }
174}
175
impl TensorPool {
    /// Creates a pool that retains an unlimited number of buffers per bucket.
    pub fn new() -> Self {
        // 0 is the sentinel for "no per-bucket cap" (see `release_inner`).
        Self::with_max_per_bucket(0)
    }

    /// Creates a pool retaining at most `max` buffers per (shape, element
    /// type) bucket; `max == 0` means unbounded retention.
    pub fn with_max_per_bucket(max: usize) -> Self {
        Self {
            inner: Arc::new(TensorPoolInner {
                buckets: Mutex::new(HashMap::new()),
                n_acquired: AtomicU64::new(0),
                n_released: AtomicU64::new(0),
                n_allocated: AtomicU64::new(0),
                n_reused: AtomicU64::new(0),
                max_per_bucket: max,
            }),
        }
    }

    /// Returns a zero-filled array of `shape`, recycling a pooled buffer with
    /// a matching (shape, element type) when one is available.
    ///
    /// The returned [`PooledArray`] hands the buffer back to this pool when
    /// dropped (unless extracted via [`PooledArray::into_inner`]).
    pub fn acquire<F: Float>(&self, shape: &[usize]) -> PooledArray<F> {
        self.inner.n_acquired.fetch_add(1, Ordering::Relaxed);

        let key = BucketKey {
            shape: shape.to_vec(),
            type_id: TypeId::of::<F>(),
        };

        // Scope the lock to the bucket lookup so the conversion / allocation
        // below runs without holding it.
        let maybe_erased = {
            let mut buckets = self.inner.buckets.lock();
            buckets.get_mut(&key).and_then(|v| v.pop())
        };

        let array = if let Some(erased) = maybe_erased {
            self.inner.n_reused.fetch_add(1, Ordering::Relaxed);
            // `erased_to_ndarray` zeroes the recycled buffer before reuse.
            erased_to_ndarray::<F>(erased)
        } else {
            self.inner.n_allocated.fetch_add(1, Ordering::Relaxed);
            NdArray::<F>::zeros(scirs2_core::ndarray::IxDyn(shape))
        };

        PooledArray {
            array: Some(array),
            pool: self.clone(),
        }
    }

    /// Hands an externally-created array to the pool for future reuse and
    /// counts it as a release.
    pub fn release<F: Float>(&self, array: NdArray<F>) {
        self.inner.n_released.fetch_add(1, Ordering::Relaxed);
        self.release_inner::<F>(array);
    }

    /// Buckets `array` by (shape, element type) WITHOUT touching the release
    /// counter; shared by `release` and `PooledArray::drop`, each of which
    /// increments the counter itself.
    fn release_inner<F: Float>(&self, array: NdArray<F>) {
        let key = BucketKey {
            shape: array.shape().to_vec(),
            type_id: TypeId::of::<F>(),
        };

        let erased = ndarray_to_erased(array);

        let mut buckets = self.inner.buckets.lock();
        let bucket = buckets.entry(key).or_default();

        // Respect the per-bucket cap (0 = unbounded); a full bucket simply
        // drops the incoming buffer.
        if self.inner.max_per_bucket == 0 || bucket.len() < self.inner.max_per_bucket {
            bucket.push(erased);
        }
    }

    /// Drops every pooled buffer. Statistics counters are left untouched.
    pub fn clear(&self) {
        let mut buckets = self.inner.buckets.lock();
        buckets.clear();
    }

    /// Takes a snapshot of the pool contents (computed under the bucket lock)
    /// together with the lifetime counters.
    pub fn stats(&self) -> PoolStats {
        let buckets = self.inner.buckets.lock();
        let mut pool_bytes: u64 = 0;
        let mut n_pooled_buffers: u64 = 0;
        for bucket in buckets.values() {
            for erased in bucket {
                pool_bytes = pool_bytes.saturating_add(erased.byte_size() as u64);
            }
            n_pooled_buffers = n_pooled_buffers.saturating_add(bucket.len() as u64);
        }

        PoolStats {
            n_acquired: self.inner.n_acquired.load(Ordering::Relaxed),
            n_released: self.inner.n_released.load(Ordering::Relaxed),
            n_allocated: self.inner.n_allocated.load(Ordering::Relaxed),
            n_reused: self.inner.n_reused.load(Ordering::Relaxed),
            pool_bytes,
            n_buckets: buckets.len() as u64,
            n_pooled_buffers,
        }
    }

    /// Resets all counters to zero; pooled buffers are kept.
    pub fn reset_stats(&self) {
        self.inner.n_acquired.store(0, Ordering::Relaxed);
        self.inner.n_released.store(0, Ordering::Relaxed);
        self.inner.n_allocated.store(0, Ordering::Relaxed);
        self.inner.n_reused.store(0, Ordering::Relaxed);
    }
}
299
/// Type-erases an `NdArray<F>` into a raw byte buffer plus shape/elem-size
/// metadata so it can be stored in a bucket shared across element types.
///
/// NOTE(review): `Vec::from_raw_parts` is called below with element type `u8`
/// on a pointer that was allocated for `F`. `Vec`'s safety contract requires
/// the element alignment to match the allocation's, so dropping the resulting
/// `Vec<u8>` (as happens in `TensorPool::clear` or when a full bucket discards
/// a buffer) deallocates with a mismatched layout — undefined behavior per the
/// allocator contract. Confirm, and consider keeping the original `Vec<F>`
/// behind a type-erased box instead of reinterpreting bytes.
///
/// NOTE(review): the offset from `into_raw_vec_and_offset` is discarded; for
/// arrays in a non-standard layout the raw vec length may not equal the
/// shape's element count (downstream `erased_to_ndarray` falls back to a
/// fresh allocation in that case).
fn ndarray_to_erased<F: Float>(array: NdArray<F>) -> ErasedArray {
    let shape = array.shape().to_vec();
    let elem_size = std::mem::size_of::<F>();

    // Take ownership of the backing storage; the offset component is ignored.
    let vec_f: Vec<F> = array.into_raw_vec_and_offset().0;
    let len = vec_f.len();
    let cap = vec_f.capacity();

    let ptr = vec_f.as_ptr();
    // Keep the allocation alive: ownership moves to the Vec<u8> below.
    std::mem::forget(vec_f);

    // SAFETY: `ptr` points to an allocation of `cap * elem_size` bytes of
    // which the first `len * elem_size` are initialized. See the alignment
    // caveat in the function docs: this re-wrap does not satisfy
    // `Vec::from_raw_parts`'s alignment requirement for any `F` with
    // alignment greater than 1.
    let data = unsafe { Vec::from_raw_parts(ptr as *mut u8, len * elem_size, cap * elem_size) };

    ErasedArray {
        data,
        shape,
        elem_size,
    }
}
326
/// Re-materializes a pooled byte buffer as a zero-filled `NdArray<F>`.
///
/// Only the allocation is reused: the buffer's stale contents are overwritten
/// with zeros before the array is built.
fn erased_to_ndarray<F: Float>(erased: ErasedArray) -> NdArray<F> {
    let elem_size = std::mem::size_of::<F>();
    // Bucket keys include `TypeId::of::<F>()`, so a buffer should only come
    // back to the `F` it was erased from; this is a debug-only sanity check
    // and is NOT enforced in release builds.
    debug_assert_eq!(erased.elem_size, elem_size);

    let byte_len = erased.data.len();
    let byte_cap = erased.data.capacity();
    let ptr = erased.data.as_ptr();
    // Keep the allocation alive: ownership moves to the Vec<F> below.
    std::mem::forget(erased.data);

    // Integer division: exact when the buffer really came from a `Vec<F>`
    // round-trip through `ndarray_to_erased`; a remainder would silently
    // shrink the reconstructed capacity — NOTE(review): confirm this cannot
    // happen for any release path.
    let f_len = byte_len / elem_size;
    let f_cap = byte_cap / elem_size;

    // SAFETY: inverse of the re-wrap in `ndarray_to_erased` — the pointer,
    // length, and capacity describe the original `Vec<F>` allocation, so the
    // eventual deallocation uses the layout it was allocated with. See the
    // alignment caveat documented on `ndarray_to_erased`.
    let mut vec_f: Vec<F> = unsafe { Vec::from_raw_parts(ptr as *mut F, f_len, f_cap) };

    // Pooled buffers hold stale data from their previous use; zero every
    // element so `acquire` can guarantee a zero-filled array.
    for v in vec_f.iter_mut() {
        *v = F::zero();
    }

    // If the buffer length does not match the requested shape (e.g. the
    // released array had a non-standard layout), fall back to a fresh
    // zeroed allocation rather than failing.
    NdArray::<F>::from_shape_vec(scirs2_core::ndarray::IxDyn(&erased.shape), vec_f).unwrap_or_else(
        |_| {
            NdArray::<F>::zeros(scirs2_core::ndarray::IxDyn(&erased.shape))
        },
    )
}
356
/// RAII handle around an [`NdArray`] acquired from a [`TensorPool`].
///
/// Dereferences to the inner array. On drop, the buffer is returned to its
/// pool; use [`PooledArray::into_inner`] to keep the array instead.
pub struct PooledArray<F: Float> {
    // `None` once `into_inner` has been called; `Drop` then does nothing.
    array: Option<NdArray<F>>,
    // Owning pool — a cheap clone of the shared handle.
    pool: TensorPool,
}
372
373impl<F: Float> PooledArray<F> {
374 pub fn into_inner(mut self) -> NdArray<F> {
377 self.array
379 .take()
380 .expect("PooledArray inner array already taken")
381 }
382
383 pub fn shape(&self) -> &[usize] {
385 match &self.array {
386 Some(a) => a.shape(),
387 None => &[],
388 }
389 }
390}
391
392impl<F: Float> Deref for PooledArray<F> {
393 type Target = NdArray<F>;
394
395 fn deref(&self) -> &Self::Target {
396 self.array
397 .as_ref()
398 .expect("PooledArray inner array already taken")
399 }
400}
401
402impl<F: Float> DerefMut for PooledArray<F> {
403 fn deref_mut(&mut self) -> &mut Self::Target {
404 self.array
405 .as_mut()
406 .expect("PooledArray inner array already taken")
407 }
408}
409
410impl<F: Float> Drop for PooledArray<F> {
411 fn drop(&mut self) {
412 if let Some(array) = self.array.take() {
413 self.pool.inner.n_released.fetch_add(1, Ordering::Relaxed);
414 self.pool.release_inner::<F>(array);
415 }
416 }
417}
418
419impl<F: Float> fmt::Debug for PooledArray<F> {
420 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
421 match &self.array {
422 Some(a) => write!(f, "PooledArray(shape={:?})", a.shape()),
423 None => write!(f, "PooledArray(<taken>)"),
424 }
425 }
426}
427
/// Lazily-initialized, process-wide pool backing [`global_pool`].
static GLOBAL_POOL: Lazy<TensorPool> = Lazy::new(TensorPool::new);
434
435pub fn global_pool() -> &'static TensorPool {
440 &GLOBAL_POOL
441}
442
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly acquired buffer must always be zero-filled.
    #[test]
    fn test_acquire_returns_zero_array() {
        let pool = TensorPool::new();
        let buf: PooledArray<f64> = pool.acquire(&[3, 4]);
        assert_eq!(buf.shape(), &[3, 4]);
        for &v in buf.iter() {
            assert!((v - 0.0).abs() < f64::EPSILON);
        }
    }

    // Dropping a PooledArray returns its buffer, and the next acquire of the
    // same (shape, type) reuses it instead of allocating.
    #[test]
    fn test_acquire_release_reuse_cycle() {
        let pool = TensorPool::new();

        let buf1: PooledArray<f64> = pool.acquire(&[8, 16]);
        let stats1 = pool.stats();
        assert_eq!(stats1.n_acquired, 1);
        assert_eq!(stats1.n_allocated, 1);
        assert_eq!(stats1.n_reused, 0);

        drop(buf1);
        let stats2 = pool.stats();
        assert_eq!(stats2.n_released, 1);
        assert_eq!(stats2.n_pooled_buffers, 1);

        let buf2: PooledArray<f64> = pool.acquire(&[8, 16]);
        let stats3 = pool.stats();
        assert_eq!(stats3.n_acquired, 2);
        assert_eq!(stats3.n_allocated, 1);
        assert_eq!(stats3.n_reused, 1);

        // The recycled buffer must come back zeroed.
        for &v in buf2.iter() {
            assert!((v - 0.0).abs() < f64::EPSILON);
        }
    }

    // [2, 3] and [3, 2] have the same element count but distinct bucket keys.
    #[test]
    fn test_different_shapes_get_different_buckets() {
        let pool = TensorPool::new();

        let a: PooledArray<f64> = pool.acquire(&[2, 3]);
        let b: PooledArray<f64> = pool.acquire(&[3, 2]);

        drop(a);
        drop(b);

        let stats = pool.stats();
        assert_eq!(stats.n_buckets, 2);
        assert_eq!(stats.n_pooled_buffers, 2);
    }

    // Same shape but different element types must not share a bucket.
    #[test]
    fn test_different_types_get_different_buckets() {
        let pool = TensorPool::new();

        let a: PooledArray<f32> = pool.acquire(&[4, 4]);
        let b: PooledArray<f64> = pool.acquire(&[4, 4]);

        drop(a);
        drop(b);

        let stats = pool.stats();
        assert_eq!(stats.n_buckets, 2);
    }

    // Arrays created outside the pool can be donated via `release` and are
    // then reused by a matching acquire (no fresh allocation counted).
    #[test]
    fn test_manual_release() {
        let pool = TensorPool::new();
        let arr: NdArray<f64> = NdArray::zeros(scirs2_core::ndarray::IxDyn(&[5, 5]));
        pool.release(arr);

        let stats = pool.stats();
        assert_eq!(stats.n_released, 1);
        assert_eq!(stats.n_pooled_buffers, 1);

        let buf: PooledArray<f64> = pool.acquire(&[5, 5]);
        let stats2 = pool.stats();
        assert_eq!(stats2.n_reused, 1);
        assert_eq!(stats2.n_allocated, 0);
        drop(buf);
    }

    // `clear` drops all pooled buffers and buckets.
    #[test]
    fn test_clear_empties_pool() {
        let pool = TensorPool::new();

        let a: PooledArray<f64> = pool.acquire(&[10, 10]);
        drop(a);

        assert_eq!(pool.stats().n_pooled_buffers, 1);

        pool.clear();

        assert_eq!(pool.stats().n_pooled_buffers, 0);
        assert_eq!(pool.stats().n_buckets, 0);
    }

    // `into_inner` extracts the array: no release is counted and nothing is
    // returned to the pool when the handle is dropped.
    #[test]
    fn test_into_inner_does_not_return_to_pool() {
        let pool = TensorPool::new();

        let buf: PooledArray<f64> = pool.acquire(&[3, 3]);
        let _arr: NdArray<f64> = buf.into_inner();

        let stats = pool.stats();
        assert_eq!(stats.n_released, 0);
        assert_eq!(stats.n_pooled_buffers, 0);
    }

    #[test]
    fn test_stats_display() {
        let pool = TensorPool::new();
        let _a: PooledArray<f64> = pool.acquire(&[2]);
        let display = format!("{}", pool.stats());
        assert!(display.contains("acquired: 1"));
    }

    // 100 f64 elements = 800 bytes parked in the pool.
    #[test]
    fn test_pool_stats_pool_bytes() {
        let pool = TensorPool::new();

        let buf: PooledArray<f64> = pool.acquire(&[100]);
        drop(buf);

        let stats = pool.stats();
        assert_eq!(stats.pool_bytes, 800);
    }

    // `reset_stats` zeroes counters but keeps pooled buffers.
    #[test]
    fn test_reset_stats() {
        let pool = TensorPool::new();

        let buf: PooledArray<f64> = pool.acquire(&[4]);
        drop(buf);

        pool.reset_stats();

        let stats = pool.stats();
        assert_eq!(stats.n_acquired, 0);
        assert_eq!(stats.n_released, 0);
        assert_eq!(stats.n_allocated, 0);
        assert_eq!(stats.n_reused, 0);
        assert_eq!(stats.n_pooled_buffers, 1);
    }

    // With a cap of 2, repeated release cycles never park more than 2 buffers.
    #[test]
    fn test_max_per_bucket() {
        let pool = TensorPool::with_max_per_bucket(2);

        for _ in 0..5 {
            let buf: PooledArray<f64> = pool.acquire(&[10]);
            drop(buf);
        }

        assert!(pool.stats().n_pooled_buffers <= 2);
    }

    #[test]
    fn test_global_pool_accessible() {
        let pool = global_pool();
        let _buf: PooledArray<f64> = pool.acquire(&[1]);
    }

    // DerefMut exposes the inner array for in-place writes.
    #[test]
    fn test_deref_mut() {
        let pool = TensorPool::new();
        let mut buf: PooledArray<f64> = pool.acquire(&[3]);

        buf[[0]] = 42.0;
        assert!((buf[[0]] - 42.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_debug_format() {
        let pool = TensorPool::new();
        let buf: PooledArray<f64> = pool.acquire(&[2, 3]);
        let dbg = format!("{:?}", buf);
        assert!(dbg.contains("PooledArray"));
        assert!(dbg.contains("[2, 3]"));
    }

    #[test]
    fn test_pool_debug_format() {
        let pool = TensorPool::new();
        let dbg = format!("{:?}", pool);
        assert!(dbg.contains("TensorPool"));
    }

    // Clones are handles to the same pool: activity through one handle is
    // visible in the other's statistics.
    #[test]
    fn test_pool_clone_shares_state() {
        let pool1 = TensorPool::new();
        let pool2 = pool1.clone();

        let buf: PooledArray<f64> = pool1.acquire(&[4]);
        drop(buf);

        let stats = pool2.stats();
        assert_eq!(stats.n_acquired, 1);
        assert_eq!(stats.n_released, 1);
        assert_eq!(stats.n_pooled_buffers, 1);
    }

    // Zero-dimensional (scalar) shapes go through the same reuse path.
    #[test]
    fn test_scalar_shape() {
        let pool = TensorPool::new();
        let buf: PooledArray<f64> = pool.acquire(&[]);
        assert_eq!(buf.shape(), &[] as &[usize]);
        drop(buf);

        let buf2: PooledArray<f64> = pool.acquire(&[]);
        assert_eq!(pool.stats().n_reused, 1);
        drop(buf2);
    }

    // 25 f32 elements = 100 bytes parked in the pool.
    #[test]
    fn test_f32_pool() {
        let pool = TensorPool::new();
        let buf: PooledArray<f32> = pool.acquire(&[5, 5]);
        assert_eq!(buf.shape(), &[5, 5]);
        for &v in buf.iter() {
            assert!((v - 0.0f32).abs() < f32::EPSILON);
        }
        drop(buf);

        let stats = pool.stats();
        assert_eq!(stats.pool_bytes, 100);
    }

    // Hammer one pool from several threads; counters must stay consistent
    // (every acquire was either an allocation or a reuse, and every buffer
    // came back).
    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;
        use std::thread;

        let pool = Arc::new(TensorPool::new());
        let n_threads = 8;
        let n_ops_per_thread = 100;

        let mut handles = Vec::with_capacity(n_threads);

        for _ in 0..n_threads {
            let pool = Arc::clone(&pool);
            handles.push(thread::spawn(move || {
                for i in 0..n_ops_per_thread {
                    // Rotate through a few shapes to exercise multiple buckets.
                    let shape = match i % 3 {
                        0 => vec![16, 32],
                        1 => vec![32, 16],
                        _ => vec![64],
                    };
                    let mut buf: PooledArray<f64> = pool.acquire(&shape);
                    if let Some(v) = buf.iter_mut().next() {
                        *v = 1.0;
                    }
                    drop(buf);
                }
            }));
        }

        for h in handles {
            h.join().expect("thread panicked");
        }

        let stats = pool.stats();
        assert_eq!(stats.n_acquired, (n_threads * n_ops_per_thread) as u64,);
        assert_eq!(stats.n_acquired, stats.n_allocated + stats.n_reused);
        assert_eq!(stats.n_released, stats.n_acquired);
    }

    // f32 and f64 traffic on the same pool concurrently.
    #[test]
    fn test_concurrent_mixed_types() {
        use std::sync::Arc;
        use std::thread;

        let pool = Arc::new(TensorPool::new());
        let n_threads = 4;
        let n_ops = 50;

        let mut handles = Vec::with_capacity(n_threads * 2);

        for _ in 0..n_threads {
            let pool = Arc::clone(&pool);
            handles.push(thread::spawn(move || {
                for _ in 0..n_ops {
                    let buf: PooledArray<f64> = pool.acquire(&[8, 8]);
                    drop(buf);
                }
            }));
        }

        for _ in 0..n_threads {
            let pool = Arc::clone(&pool);
            handles.push(thread::spawn(move || {
                for _ in 0..n_ops {
                    let buf: PooledArray<f32> = pool.acquire(&[8, 8]);
                    drop(buf);
                }
            }));
        }

        for h in handles {
            h.join().expect("thread panicked");
        }

        let stats = pool.stats();
        let total_ops = (n_threads * 2 * n_ops) as u64;
        assert_eq!(stats.n_acquired, total_ops);
    }

    #[test]
    fn test_large_shape() {
        let pool = TensorPool::new();
        let buf: PooledArray<f64> = pool.acquire(&[256, 256]);
        assert_eq!(buf.shape(), &[256, 256]);
        assert_eq!(buf.len(), 256 * 256);
        drop(buf);

        let stats = pool.stats();
        assert_eq!(stats.pool_bytes, (256 * 256 * 8) as u64);
    }

    // Writes into a buffer must not leak into the next acquire of the same
    // bucket — the pool promises zeroed arrays.
    #[test]
    fn test_reused_buffer_is_zeroed() {
        let pool = TensorPool::new();

        let mut buf: PooledArray<f64> = pool.acquire(&[4]);
        buf[[0]] = 99.0;
        buf[[1]] = 88.0;
        buf[[2]] = 77.0;
        buf[[3]] = 66.0;
        drop(buf);

        let buf2: PooledArray<f64> = pool.acquire(&[4]);
        for &v in buf2.iter() {
            assert!((v - 0.0).abs() < f64::EPSILON, "expected zero, got {}", v);
        }
    }

    // A bucket can hold several buffers; acquires pop them one at a time and
    // drops return them all.
    #[test]
    fn test_multiple_buffers_same_shape() {
        let pool = TensorPool::new();

        for _ in 0..5 {
            let arr: NdArray<f64> = NdArray::zeros(scirs2_core::ndarray::IxDyn(&[3]));
            pool.release(arr);
        }

        assert_eq!(pool.stats().n_pooled_buffers, 5);

        let mut held: Vec<PooledArray<f64>> = Vec::with_capacity(5);
        for i in 0..5 {
            held.push(pool.acquire(&[3]));
            assert_eq!(pool.stats().n_pooled_buffers, 4 - i as u64);
        }
        drop(held);
        assert_eq!(pool.stats().n_pooled_buffers, 5);
    }
}