#![allow(dead_code)]

use crate::{Tensor, TensorStorage};

use std::alloc::{handle_alloc_error, Layout};
use std::collections::{HashMap, VecDeque};
use std::marker::PhantomData;
use std::mem::{ManuallyDrop, MaybeUninit};
use std::ptr::NonNull;
use std::sync::{Arc, Mutex, Weak};
use torsh_core::{device::DeviceType, dtype::TensorElement, error::Result};

use scirs2_core::memory::GlobalBufferPool;
use scirs2_core::memory::LeakDetector;

/// Fallback stub used when the `memory_efficient` feature is disabled;
/// construction always fails with a `NotImplemented` error.
#[cfg(not(feature = "memory_efficient"))]
struct MemoryMappedArray<T> {
    _phantom: PhantomData<T>,
}

#[cfg(not(feature = "memory_efficient"))]
impl<T> MemoryMappedArray<T> {
    fn new(_size: usize) -> Result<Self> {
        Err(torsh_core::error::TorshError::General(
            torsh_core::error::GeneralError::NotImplemented(
                "MemoryMappedArray requires memory_efficient feature".to_string(),
            ),
        ))
    }
}

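// The global pool lives behind a `OnceLock<Arc<Mutex<_>>>`. On first use,
// `init_memory_pool` creates it and stores a `Weak` back-reference to the
// owning `Arc` inside the pool (`self_weak`), so that buffers handed out by
// `acquire_uninit` can find their way back to the pool when they are dropped.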
static MEMORY_POOL: std::sync::OnceLock<Arc<Mutex<GlobalMemoryPool>>> = std::sync::OnceLock::new();

/// Initialise the global memory pool if necessary and return a handle to it.
pub fn init_memory_pool() -> Arc<Mutex<GlobalMemoryPool>> {
    MEMORY_POOL
        .get_or_init(|| {
            let pool = Arc::new(Mutex::new(GlobalMemoryPool::new()));
            if let Ok(mut guard) = pool.lock() {
                guard.self_weak = Some(Arc::downgrade(&pool));
            }
            pool
        })
        .clone()
}

/// Return a handle to the global memory pool (alias for [`init_memory_pool`]).
pub fn get_memory_pool() -> Arc<Mutex<GlobalMemoryPool>> {
    init_memory_pool()
}

/// Type-erased heap allocation cached inside a pool bucket.
struct RawEntry {
    ptr: NonNull<u8>,
    capacity_bytes: usize,
    layout: Layout,
}

// SAFETY: `RawEntry` is the sole owner of its allocation, so moving it to
// another thread cannot introduce aliasing.
unsafe impl Send for RawEntry {}

impl Drop for RawEntry {
    fn drop(&mut self) {
        // SAFETY: `ptr` was allocated with `layout` via the global allocator
        // and has not been deallocated elsewhere.
        unsafe { std::alloc::dealloc(self.ptr.as_ptr(), self.layout) };
    }
}

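/// A reusable, uninitialised buffer checked out from the [`GlobalMemoryPool`].
///
/// The intended lifecycle (mirroring the tests at the bottom of this file) is:
/// acquire, initialise the elements through
/// [`as_uninit_slice_mut`](ReusedBuffer::as_uninit_slice_mut), then either take
/// ownership with [`into_vec`](ReusedBuffer::into_vec) or hand the allocation
/// back with [`release_to_pool`](ReusedBuffer::release_to_pool) / `drop`.
///
/// ```ignore
/// // Sketch of the intended usage; paths are illustrative.
/// let mut buf = global_acquire_uninit::<f32>(1024);
/// for slot in buf.as_uninit_slice_mut() {
///     slot.write(0.0);
/// }
/// let data: Vec<f32> = buf.into_vec(1024);
/// ```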
pub struct ReusedBuffer<T: 'static> {
    ptr: NonNull<T>,
    capacity: usize,
    layout: Layout,
    pool: Weak<Mutex<GlobalMemoryPool>>,
}

// SAFETY: the buffer exclusively owns its allocation, so it may be sent to
// another thread whenever `T` itself is `Send`.
unsafe impl<T: Send + 'static> Send for ReusedBuffer<T> {}

impl<T: 'static> ReusedBuffer<T> {
    /// View the buffer as a mutable slice of uninitialised elements.
    pub fn as_uninit_slice_mut(&mut self) -> &mut [MaybeUninit<T>] {
        // SAFETY: the buffer exclusively owns `capacity` elements of storage,
        // and exposing owned memory as `MaybeUninit<T>` is always valid.
        unsafe {
            std::slice::from_raw_parts_mut(self.ptr.as_ptr() as *mut MaybeUninit<T>, self.capacity)
        }
    }

    /// Number of elements the buffer can hold.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Raw pointer to the start of the buffer.
    pub fn as_ptr_raw(&self) -> *mut T {
        self.ptr.as_ptr()
    }

    /// Convert the buffer into a `Vec<T>` of length `len`, taking the
    /// allocation away from the pool.
    ///
    /// The first `len` elements must have been initialised (for example via
    /// [`as_uninit_slice_mut`](Self::as_uninit_slice_mut)) before calling this.
    pub fn into_vec(self, len: usize) -> Vec<T> {
        debug_assert!(len <= self.capacity, "len must not exceed capacity");
        let md = ManuallyDrop::new(self);
        // SAFETY: the allocation was made by the global allocator with a
        // layout matching `capacity` elements of `T`, and the caller
        // guarantees the first `len` elements are initialised.
        unsafe { Vec::from_raw_parts(md.ptr.as_ptr(), len, md.capacity) }
    }

    /// Hand the allocation back to the pool without running any element
    /// destructors.
    pub fn release_to_pool(self) {
        let md = ManuallyDrop::new(self);
        let raw_entry = RawEntry {
            ptr: NonNull::new(md.ptr.as_ptr() as *mut u8)
                .expect("ReusedBuffer pointer is non-null by construction"),
            capacity_bytes: md.capacity * std::mem::size_of::<T>(),
            layout: md.layout,
        };
        if let Some(pool_arc) = md.pool.upgrade() {
            if let Ok(mut guard) = pool_arc.lock() {
                let type_id = std::any::TypeId::of::<T>();
                let size_class = guard.find_size_class(raw_entry.capacity_bytes);
                let pool_key = (type_id, size_class);
                if let Some(bucket) = guard.pools.get_mut(&pool_key) {
                    if bucket.available_buffers.len() < bucket.max_buffers {
                        bucket.available_buffers.push_back(raw_entry);
                        bucket.deallocations += 1;
                        return;
                    }
                }
            }
        }
        // The pool is gone or the bucket is full: `raw_entry` is dropped here,
        // freeing the allocation.
    }
}


impl<T: 'static> Drop for ReusedBuffer<T> {
    fn drop(&mut self) {
        let raw_entry = RawEntry {
            ptr: NonNull::new(self.ptr.as_ptr() as *mut u8)
                .expect("ReusedBuffer pointer is non-null by construction"),
            capacity_bytes: self.capacity * std::mem::size_of::<T>(),
            layout: self.layout,
        };
        if let Some(pool_arc) = self.pool.upgrade() {
            if let Ok(mut guard) = pool_arc.lock() {
                let type_id = std::any::TypeId::of::<T>();
                let size_class = guard.find_size_class(raw_entry.capacity_bytes);
                let pool_key = (type_id, size_class);
                if let Some(bucket) = guard.pools.get_mut(&pool_key) {
                    if bucket.available_buffers.len() < bucket.max_buffers {
                        // Same return path as `release_to_pool`: the entry is
                        // moved into the bucket and will be reused later.
                        bucket.available_buffers.push_back(raw_entry);
                        bucket.deallocations += 1;
                        return;
                    }
                }
            }
        }
        // The pool is gone or the bucket is full: `raw_entry` is dropped here,
        // freeing the allocation.
    }
}

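/// Process-wide pool of reusable heap allocations.
///
/// Allocations are bucketed by `(TypeId, size-class index)`: each bucket caches
/// up to `PoolConfig::max_buffers_per_class` released allocations and serves
/// them back from `acquire_uninit` before touching the system allocator.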
pub struct GlobalMemoryPool {
    /// Buckets of reusable allocations, keyed by `(element TypeId, size-class index)`.
    pools: HashMap<(std::any::TypeId, usize), MemoryPool>,
    /// Aggregate allocation and reuse statistics.
    stats: PoolStatistics,
    /// Tuning parameters.
    config: PoolConfig,
    /// scirs2 buffer-pool integration.
    scirs2_pool: GlobalBufferPool,
    /// scirs2 leak detector.
    leak_detector: LeakDetector,
    /// Weak back-reference to the `Arc` owning this pool, handed out to
    /// `ReusedBuffer`s so they can return their allocation on drop.
    self_weak: Option<Weak<Mutex<GlobalMemoryPool>>>,
}

/// A single pool bucket: cached allocations for one `(type, size class)` pair.
#[derive(Debug)]
struct MemoryPool {
    available_buffers: VecDeque<RawEntry>,
    #[allow(dead_code)]
    size_class: usize,
    max_buffers: usize,
    allocations: usize,
    reuses: usize,
    deallocations: usize,
}

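/// Tuning parameters for the global memory pool.
///
/// The default size classes are powers of two from 1 KiB (2^10) up to
/// 1 GiB (2^30); a request is assigned to the smallest class that can hold it
/// (see [`GlobalMemoryPool::find_size_class`]).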
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Maximum number of cached buffers kept per `(type, size class)` bucket.
    pub max_buffers_per_class: usize,
    /// Soft cap on total allocated bytes, used by [`GlobalMemoryPool::cleanup`].
    pub max_total_memory: usize,
    /// Whether `cleanup` is allowed to prune the pool automatically.
    pub auto_cleanup: bool,
    /// Fraction of `max_total_memory` above which `cleanup` starts pruning.
    pub cleanup_threshold: f64,
    /// Bucket size classes, in bytes, in ascending order.
    pub size_classes: Vec<usize>,
}

/// Counters describing how the pool has been used.
#[derive(Debug, Default, Clone)]
pub struct PoolStatistics {
    pub total_allocations: usize,
    pub pool_hits: usize,
    pub pool_misses: usize,
    pub total_bytes_allocated: usize,
    pub bytes_in_pools: usize,
    pub peak_memory_usage: usize,
}

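/// A tensor whose backing buffer was obtained through the global memory pool
/// and is handed back to it when the wrapper is dropped.
///
/// ```ignore
/// // Sketch of the intended usage; paths are illustrative.
/// let pooled = PooledTensor::<f32>::zeros(&[64, 64], DeviceType::Cpu)
///     .expect("allocation should succeed");
/// assert_eq!(pooled.tensor().numel(), 64 * 64);
/// ```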
#[derive(Debug)]
pub struct PooledTensor<T: TensorElement + Default> {
    tensor: Tensor<T>,
    pool_key: Option<(std::any::TypeId, usize)>,
    _phantom: PhantomData<T>,
}

impl Default for PoolConfig {
    fn default() -> Self {
        // Powers of two from 1 KiB (2^10) up to 1 GiB (2^30).
        let size_classes = (10..31).map(|exp| 1usize << exp).collect();

        Self {
            max_buffers_per_class: 16,
            max_total_memory: 1024 * 1024 * 1024, // 1 GiB
            auto_cleanup: true,
            cleanup_threshold: 0.8,
            size_classes,
        }
    }
}

impl Default for GlobalMemoryPool {
    fn default() -> Self {
        Self::new()
    }
}

impl GlobalMemoryPool {
    pub fn new() -> Self {
        #[cfg(feature = "profiling")]
        {
        }

        Self {
            pools: HashMap::new(),
            stats: PoolStatistics::default(),
            config: PoolConfig::default(),
            scirs2_pool: GlobalBufferPool::new(),
            leak_detector: LeakDetector::new(Default::default())
                .unwrap_or_else(|_| panic!("Failed to initialize leak detector")),
            self_weak: None,
        }
    }

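    /// Create a tensor using an allocation strategy chosen by its total size:
    /// above 100 MiB the memory-mapped path is used, above 10 MiB the chunked
    /// path, above 1 MiB the pooled path, and smaller tensors fall back to
    /// `Tensor::zeros`.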
    pub fn create_large_tensor<T: TensorElement>(
        &mut self,
        shape: &[usize],
        device: DeviceType,
    ) -> Result<Tensor<T>>
    where
        T: Clone + Default,
    {
        #[cfg(feature = "profiling")]
        {
        }

        let total_elements: usize = shape.iter().product();
        let total_bytes = total_elements * std::mem::size_of::<T>();

        if total_bytes > 100 * 1024 * 1024 {
            // > 100 MiB: memory-mapped strategy.
            self.create_memory_mapped_tensor(shape, device)
        } else if total_bytes > 10 * 1024 * 1024 {
            // > 10 MiB: chunked strategy.
            self.create_chunked_tensor(shape, device)
        } else if total_bytes > 1024 * 1024 {
            // > 1 MiB: pooled strategy.
            self.create_pooled_tensor(shape, device)
        } else {
            Tensor::zeros(shape, device)
        }
    }

    fn create_memory_mapped_tensor<T: TensorElement>(
        &mut self,
        shape: &[usize],
        device: DeviceType,
    ) -> Result<Tensor<T>>
    where
        T: Clone + Default,
    {
        let total_elements: usize = shape.iter().product();

        // Currently allocates eagerly in memory; the `MemoryMappedArray`
        // backing requires the `memory_efficient` feature (see the stub above).
        let data = vec![T::default(); total_elements];
        Tensor::from_data(data, shape.to_vec(), device)
    }

    fn create_chunked_tensor<T: TensorElement>(
        &mut self,
        shape: &[usize],
        device: DeviceType,
    ) -> Result<Tensor<T>>
    where
        T: Clone + Default,
    {
        let total_elements: usize = shape.iter().product();

        // Aim for roughly 1 MiB per chunk, rounding the chunk count up.
        let chunk_size = (1024 * 1024) / std::mem::size_of::<T>().max(1);
        let num_chunks = (total_elements + chunk_size - 1) / chunk_size;

        // The chunk bookkeeping is currently informational only; the data is
        // still allocated as one contiguous buffer.
        let _ = (total_elements, num_chunks, chunk_size);
        let data = vec![T::default(); total_elements];

        Tensor::from_data(data, shape.to_vec(), device)
    }

    fn create_pooled_tensor<T: TensorElement>(
        &mut self,
        shape: &[usize],
        device: DeviceType,
    ) -> Result<Tensor<T>>
    where
        T: Clone + Default,
    {
        let total_elements: usize = shape.iter().product();
        let buffer_size = total_elements * std::mem::size_of::<T>();

        // The buffer is currently allocated directly rather than drawn from a
        // pool bucket.
        let _ = (buffer_size, total_elements);
        let data = vec![T::default(); total_elements];

        self.stats.pool_hits += 1;
        Tensor::from_data(data, shape.to_vec(), device)
    }

    pub fn create_lazy_tensor<T: TensorElement>(
        &mut self,
        shape: &[usize],
        device: DeviceType,
    ) -> Result<Tensor<T>>
    where
        T: Clone + Default,
    {
        #[cfg(feature = "profiling")]
        {
        }

        let total_elements: usize = shape.iter().product();

        // Currently materialises the data eagerly rather than deferring it.
        let data = vec![T::default(); total_elements];

        Tensor::from_data(data, shape.to_vec(), device)
    }

    pub fn create_zero_copy_view<T: TensorElement>(
        &self,
        source: &Tensor<T>,
        offset: usize,
        shape: &[usize],
    ) -> Result<Tensor<T>>
    where
        T: Clone,
    {
        #[cfg(feature = "profiling")]
        {
        }

        // NOTE: this copies the selected range out of the source tensor; it is
        // not yet a true zero-copy view.
        let source_data = source.data()?;
        let view_data = source_data[offset..offset + shape.iter().product::<usize>()].to_vec();

        Tensor::from_data(view_data, shape.to_vec(), source.device())
    }

    pub fn get_enhanced_stats(&self) -> PoolStatistics {
        self.stats.clone()
    }

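    /// Acquire an uninitialised buffer with room for at least `count` elements
    /// of `T`.
    ///
    /// The pool bucket for `(T, size class)` is searched first; a cached
    /// allocation is reused when its capacity and alignment are sufficient
    /// (counted as a pool hit), otherwise a fresh allocation is made (counted
    /// as a miss). The returned [`ReusedBuffer`] gives the allocation back to
    /// the pool when dropped.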
    pub fn acquire_uninit<T: 'static>(&mut self, count: usize) -> ReusedBuffer<T> {
        let element_size = std::mem::size_of::<T>();
        let element_align = std::mem::align_of::<T>();
        let size_bytes = count * element_size;
        let size_class = self.find_size_class(size_bytes);
        let type_id = std::any::TypeId::of::<T>();
        let pool_key = (type_id, size_class);

        let layout = Layout::from_size_align(size_bytes.max(1), element_align)
            .expect("size and align are valid for T");

        self.stats.total_allocations += 1;
        self.stats.total_bytes_allocated += size_bytes;

        // Fast path: reuse a cached allocation that is large enough and at
        // least as strictly aligned as `T` requires.
        if let Some(bucket) = self.pools.get_mut(&pool_key) {
            let mut found_idx: Option<usize> = None;
            for (i, entry) in bucket.available_buffers.iter().enumerate() {
                if entry.capacity_bytes >= size_bytes && entry.layout.align() >= element_align {
                    found_idx = Some(i);
                    break;
                }
            }
            if let Some(idx) = found_idx {
                let raw_entry = bucket
                    .available_buffers
                    .remove(idx)
                    .expect("index was valid moments ago");
                self.stats.pool_hits += 1;
                bucket.reuses += 1;

                let ptr = NonNull::new(raw_entry.ptr.as_ptr() as *mut T)
                    .expect("RawEntry pointer is non-null by construction");
                let actual_capacity = raw_entry.capacity_bytes / element_size;
                let entry_layout = raw_entry.layout;
                // Ownership of the allocation moves into the returned buffer,
                // so the entry must not deallocate it here.
                std::mem::forget(raw_entry);

                let weak = self.self_weak.clone().unwrap_or_else(Weak::new);
                return ReusedBuffer {
                    ptr,
                    capacity: actual_capacity,
                    layout: entry_layout,
                    pool: weak,
                };
            }
        }

        // Slow path: nothing reusable, allocate a fresh buffer.
        self.stats.pool_misses += 1;

        self.pools.entry(pool_key).or_insert_with(|| MemoryPool {
            available_buffers: VecDeque::new(),
            size_class,
            max_buffers: self.config.max_buffers_per_class,
            allocations: 0,
            reuses: 0,
            deallocations: 0,
        });

        if let Some(bucket) = self.pools.get_mut(&pool_key) {
            bucket.allocations += 1;
        }

        // SAFETY: `layout` has a non-zero size (clamped to at least one byte).
        let raw_ptr = unsafe { std::alloc::alloc(layout) };
        let ptr = NonNull::new(raw_ptr as *mut T).unwrap_or_else(|| handle_alloc_error(layout));

        let weak = self.self_weak.clone().unwrap_or_else(Weak::new);
        ReusedBuffer {
            ptr,
            capacity: count,
            layout,
            pool: weak,
        }
    }

    #[deprecated = "Use global_acquire_uninit instead for zero-copy buffer reuse"]
    pub fn allocate<T: TensorElement + Default + 'static>(&mut self, count: usize) -> Vec<T> {
        let mut buf = self.acquire_uninit::<T>(count);
        for slot in buf.as_uninit_slice_mut() {
            slot.write(T::default());
        }
        buf.into_vec(count)
    }

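    /// Map a request size in bytes to the index of the smallest configured
    /// size class that can hold it (or the last class if the request is larger
    /// than all of them).
    ///
    /// With the default classes (1 KiB, 2 KiB, 4 KiB, ...), a 3000-byte
    /// request maps to the 4 KiB class at index 2.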
    pub fn find_size_class(&self, size_bytes: usize) -> usize {
        self.config
            .size_classes
            .iter()
            .position(|&class_size| size_bytes <= class_size)
            .unwrap_or(self.config.size_classes.len() - 1)
    }

    /// Release a previously allocated vector.
    ///
    /// Currently this simply drops it; allocation reuse is handled through
    /// [`ReusedBuffer`] instead.
    pub fn deallocate<T: 'static>(&mut self, data: Vec<T>) {
        drop(data);
    }

    /// Drop all buckets and reset the statistics.
    pub fn clear(&mut self) {
        self.pools.clear();
        self.stats = PoolStatistics::default();
    }

    pub fn get_statistics(&self) -> &PoolStatistics {
        &self.stats
    }

    /// Fraction of allocation requests served from the pool.
    pub fn hit_rate(&self) -> f64 {
        if self.stats.total_allocations == 0 {
            0.0
        } else {
            self.stats.pool_hits as f64 / self.stats.total_allocations as f64
        }
    }

    /// When auto-cleanup is enabled and total allocated bytes exceed the
    /// configured threshold, drop bookkeeping for buckets that currently hold
    /// no cached buffers.
    pub fn cleanup(&mut self) {
        if self.config.auto_cleanup {
            let threshold_bytes =
                (self.config.max_total_memory as f64 * self.config.cleanup_threshold) as usize;
            if self.stats.total_bytes_allocated > threshold_bytes {
                self.pools
                    .retain(|_, pool| !pool.available_buffers.is_empty());
            }
        }
    }
}

impl std::fmt::Debug for GlobalMemoryPool {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GlobalMemoryPool")
            .field("pools", &self.pools)
            .field("stats", &self.stats)
            .field("config", &self.config)
            .field("scirs2_pool", &"<GlobalBufferPool>")
            .field("leak_detector", &"<LeakDetector>")
            .finish()
    }
}

impl std::fmt::Debug for RawEntry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RawEntry")
            .field("capacity_bytes", &self.capacity_bytes)
            .finish()
    }
}

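/// Acquire an uninitialised buffer from the global memory pool.
///
/// This is the preferred entry point for zero-copy buffer reuse (the pool's
/// `allocate` method is deprecated in its favour).
///
/// ```ignore
/// // Sketch of the intended usage; paths are illustrative.
/// let mut buf = global_acquire_uninit::<u8>(4096);
/// for slot in buf.as_uninit_slice_mut() {
///     slot.write(0);
/// }
/// buf.release_to_pool();
/// ```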
pub fn global_acquire_uninit<T: 'static>(count: usize) -> ReusedBuffer<T> {
    let pool_arc = get_memory_pool();
    let mut guard = pool_arc
        .lock()
        .expect("global memory pool lock should not be poisoned");
    guard.acquire_uninit::<T>(count)
}

pub type EnhancedMemoryStats = PoolStatistics;

impl<T: TensorElement> Tensor<T> {
    /// Create a tensor through the global memory pool's size-tiered strategy
    /// (see [`GlobalMemoryPool::create_large_tensor`]).
    pub fn create_efficient(shape: &[usize], device: DeviceType) -> Result<Self>
    where
        T: Clone + Default,
    {
        let binding = get_memory_pool();
        let mut pool = binding.lock().expect("lock should not be poisoned");
        pool.create_large_tensor::<T>(shape, device)
    }

    /// Create a tensor via the pool's lazy path (currently materialised eagerly).
    pub fn lazy(shape: &[usize], device: DeviceType) -> Result<Self>
    where
        T: Clone + Default,
    {
        let binding = get_memory_pool();
        let mut pool = binding.lock().expect("lock should not be poisoned");
        pool.create_lazy_tensor::<T>(shape, device)
    }

    /// Create a tensor intended for memory-mapped backing. The current
    /// implementation allocates the data eagerly in memory.
    pub fn memory_mapped(shape: &[usize], device: DeviceType) -> Result<Self>
    where
        T: Clone + Default,
    {
        #[cfg(feature = "profiling")]
        {
        }

        let total_elements: usize = shape.iter().product();
        let data = vec![T::default(); total_elements];
        Self::from_data(data, shape.to_vec(), device)
    }

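    /// Create a tensor intended to be processed in chunks of roughly
    /// `chunk_size` elements (pass 0 to pick a default of about 64 KiB worth
    /// of elements). Chunk sizes are rounded up to a whole number of 64-byte
    /// cache lines; for `f32` that is a multiple of 16 elements, so a
    /// requested chunk of 1000 elements becomes 1008.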
    pub fn chunked(shape: &[usize], chunk_size: usize, device: DeviceType) -> Result<Self>
    where
        T: Clone + Default,
    {
        #[cfg(feature = "profiling")]
        {
        }

        let total_elements: usize = shape.iter().product();

        // A chunk size of zero selects a default of roughly 64 KiB of elements.
        let effective_chunk_size = if chunk_size == 0 {
            let default_chunk_bytes = 64 * 1024;
            let element_size = std::mem::size_of::<T>();
            (default_chunk_bytes / element_size.max(1)).max(1)
        } else {
            chunk_size
        };

        // Round the chunk size up to a whole number of 64-byte cache lines
        // (clamped to at least one element per line for very large T).
        let cache_line_elements = (64 / std::mem::size_of::<T>().max(1)).max(1);
        let aligned_chunk_size = ((effective_chunk_size + cache_line_elements - 1)
            / cache_line_elements)
            * cache_line_elements;

        // The chunk layout is currently informational only; the data is
        // allocated as one contiguous buffer.
        let _ = (total_elements, effective_chunk_size, aligned_chunk_size);
        let data = vec![T::default(); total_elements];

        Self::from_data(data, shape.to_vec(), device)
    }

    /// Create a tensor intended to be backed by a file on disk. The current
    /// implementation only computes the backing path and keeps the data in
    /// memory.
    pub fn disk_backed(shape: &[usize], device: DeviceType, file_path: Option<&str>) -> Result<Self>
    where
        T: Clone + Default,
    {
        #[cfg(feature = "profiling")]
        {
        }

        let total_elements: usize = shape.iter().product();

        // Use the caller-supplied path, or derive a unique file name in the
        // system temporary directory.
        let backing_path = if let Some(path) = file_path {
            std::path::PathBuf::from(path)
        } else {
            let temp_dir = std::env::temp_dir();
            let timestamp = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs();
            temp_dir.join(format!(
                "torsh_tensor_{}_{}.bin",
                timestamp,
                std::process::id()
            ))
        };

        let _ = (total_elements, &backing_path);
        let data = vec![T::default(); total_elements];

        let tensor = Self::from_data(data, shape.to_vec(), device)?;

        Ok(tensor)
    }

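    /// Apply `processor` to consecutive chunks of the tensor's data and
    /// collect the results. `chunk_size` must be non-zero.
    ///
    /// ```ignore
    /// // Sketch of the intended usage; paths are illustrative.
    /// let tensor = Tensor::<f32>::zeros(&[1024], DeviceType::Cpu)
    ///     .expect("allocation should succeed");
    /// let partial_sums = tensor
    ///     .process_chunked(256, |chunk| Ok(chunk.iter().sum::<f32>()))
    ///     .expect("processing should succeed");
    /// assert_eq!(partial_sums.len(), 4);
    /// ```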
    pub fn process_chunked<F, R>(&self, chunk_size: usize, mut processor: F) -> Result<Vec<R>>
    where
        F: FnMut(&[T]) -> Result<R>,
        T: Clone,
    {
        #[cfg(feature = "profiling")]
        {
        }

        let data = self.data()?;
        let mut results = Vec::new();

        let effective_chunk_size = chunk_size;

        for chunk in data.chunks(effective_chunk_size) {
            results.push(processor(chunk)?);
        }

        Ok(results)
    }
}

impl MemoryPool {
    fn new(size_class: usize, max_buffers: usize) -> Self {
        Self {
            available_buffers: VecDeque::new(),
            size_class,
            max_buffers,
            allocations: 0,
            reuses: 0,
            deallocations: 0,
        }
    }
}

impl<T: TensorElement + Copy + Default> PooledTensor<T> {
    pub fn new(shape: &[usize], device: DeviceType) -> Result<Self> {
        let numel = shape.iter().product::<usize>();

        let pool = get_memory_pool();
        let data = {
            let mut pool_guard = pool.lock().expect("lock should not be poisoned");
            #[allow(deprecated)]
            pool_guard.allocate::<T>(numel)
        };

        let tensor = Tensor::from_data(data, shape.to_vec(), device)?;
        let type_id = std::any::TypeId::of::<T>();
        let size_class = {
            let pool_guard = pool.lock().expect("lock should not be poisoned");
            pool_guard.find_size_class(numel * std::mem::size_of::<T>())
        };

        Ok(Self {
            tensor,
            pool_key: Some((type_id, size_class)),
            _phantom: PhantomData,
        })
    }

    pub fn zeros(shape: &[usize], device: DeviceType) -> Result<Self> {
        let mut pooled = Self::new(shape, device)?;
        let numel = shape.iter().product::<usize>();
        let data = vec![T::default(); numel];
        pooled.tensor.storage = TensorStorage::create_optimal(data)?;
        Ok(pooled)
    }

    pub fn ones(shape: &[usize], device: DeviceType) -> Result<Self>
    where
        T: std::ops::Add<Output = T> + From<f32>,
    {
        let mut pooled = Self::new(shape, device)?;
        let numel = shape.iter().product::<usize>();
        let data = vec![T::from(1.0f32); numel];
        pooled.tensor.storage = TensorStorage::create_optimal(data)?;
        Ok(pooled)
    }

    pub fn tensor(&self) -> &Tensor<T> {
        &self.tensor
    }

    pub fn tensor_mut(&mut self) -> &mut Tensor<T> {
        &mut self.tensor
    }

    /// Detach the tensor from the pool.
    ///
    /// The pool key is cleared so that `Drop` does not try to hand the data
    /// back; the tensor is cloned because the `Drop` impl prevents moving the
    /// field out of `self`.
    pub fn into_tensor(mut self) -> Tensor<T> {
        self.pool_key = None;
        self.tensor.clone()
    }
}

impl<T: TensorElement + std::default::Default> Drop for PooledTensor<T> {
    fn drop(&mut self) {
        if let Some((_type_id, _size_class)) = self.pool_key {
            // Copy the data out and hand it to the pool; `deallocate` currently
            // just drops it.
            if let Ok(data) = self.tensor.to_vec() {
                let pool = get_memory_pool();
                let mut pool_guard = pool.lock().expect("lock should not be poisoned");
                pool_guard.deallocate(data);
            }
        }
    }
}

impl<T: TensorElement + Copy + Default> Tensor<T> {
    /// Create a tensor whose buffer is managed by the global memory pool.
    pub fn pooled(shape: &[usize], device: DeviceType) -> Result<PooledTensor<T>> {
        PooledTensor::new(shape, device)
    }

    /// Create a short-lived pooled tensor (same behaviour as [`Tensor::pooled`]).
    pub fn temporary(shape: &[usize], device: DeviceType) -> Result<PooledTensor<T>> {
        PooledTensor::new(shape, device)
    }
}

/// Clear the global memory pool if it has been initialised.
pub fn clear_memory_pool() {
    if let Some(pool) = MEMORY_POOL.get() {
        pool.lock().expect("lock should not be poisoned").clear();
    }
}

/// Snapshot of the global pool's statistics.
pub fn get_pool_statistics() -> PoolStatistics {
    get_memory_pool()
        .lock()
        .expect("lock should not be poisoned")
        .get_statistics()
        .clone()
}

/// Hit rate of the global pool.
pub fn get_pool_hit_rate() -> f64 {
    get_memory_pool()
        .lock()
        .expect("lock should not be poisoned")
        .hit_rate()
}

/// Run the global pool's cleanup pass.
pub fn cleanup_memory_pool() {
    get_memory_pool()
        .lock()
        .expect("lock should not be poisoned")
        .cleanup();
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Serialises tests that touch the shared global pool.
    static TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    #[test]
    fn test_memory_pool_basic() {
        let _guard = TEST_LOCK.lock().expect("test mutex should not be poisoned");
        clear_memory_pool();

        let pooled = PooledTensor::<f32>::zeros(&[100, 100], DeviceType::Cpu)
            .expect("zeros creation should succeed");
        assert_eq!(pooled.tensor().numel(), 10000);

        drop(pooled);

        let _pooled2 = PooledTensor::<f32>::zeros(&[100, 100], DeviceType::Cpu)
            .expect("zeros creation should succeed");

        let stats = get_pool_statistics();
        assert!(stats.pool_hits > 0 || stats.pool_misses > 0);
    }

    #[test]
    fn test_pool_statistics() {
        let _guard = TEST_LOCK.lock().expect("test mutex should not be poisoned");
        clear_memory_pool();

        let _pooled1 = PooledTensor::<f32>::zeros(&[50, 50], DeviceType::Cpu)
            .expect("zeros creation should succeed");
        let _pooled2 = PooledTensor::<f32>::ones(&[50, 50], DeviceType::Cpu)
            .expect("ones creation should succeed");

        let stats = get_pool_statistics();
        assert!(stats.total_allocations >= 2);
        assert!(stats.total_bytes_allocated > 0);
    }

    #[test]
    fn test_pool_cleanup() {
        let _guard = TEST_LOCK.lock().expect("test mutex should not be poisoned");
        clear_memory_pool();

        for _ in 0..10 {
            let _temp = PooledTensor::<f32>::zeros(&[100, 100], DeviceType::Cpu)
                .expect("zeros creation should succeed");
        }

        cleanup_memory_pool();
        let _stats = get_pool_statistics();
    }

    #[test]
    fn test_pooled_tensor_conversion() {
        let pooled = PooledTensor::<f32>::ones(&[10, 10], DeviceType::Cpu)
            .expect("ones creation should succeed");
        let tensor = pooled.into_tensor();
        assert_eq!(tensor.numel(), 100);
    }

    #[test]
    fn test_acquire_truly_reuses_allocation() {
        let _guard = TEST_LOCK.lock().expect("test mutex should not be poisoned");
        clear_memory_pool();

        let buf1: ReusedBuffer<f32> = global_acquire_uninit::<f32>(1024);
        let ptr1 = buf1.as_ptr_raw();
        buf1.release_to_pool();

        let buf2: ReusedBuffer<f32> = global_acquire_uninit::<f32>(1024);
        let ptr2 = buf2.as_ptr_raw();
        buf2.release_to_pool();

        assert_eq!(
            ptr1, ptr2,
            "pool should return the same allocation on second acquire"
        );
    }

    #[test]
    fn test_into_vec_transfers_ownership() {
        let _guard = TEST_LOCK.lock().expect("test mutex should not be poisoned");
        clear_memory_pool();

        let mut buf: ReusedBuffer<f32> = global_acquire_uninit::<f32>(64);
        for slot in buf.as_uninit_slice_mut() {
            slot.write(1.0_f32);
        }
        let vec = buf.into_vec(64);
        assert_eq!(vec.len(), 64);
        assert!(vec.iter().all(|&x| x == 1.0_f32));
    }

    #[test]
    fn test_drop_returns_to_pool() {
        let _guard = TEST_LOCK.lock().expect("test mutex should not be poisoned");
        clear_memory_pool();

        {
            let buf: ReusedBuffer<f32> = global_acquire_uninit::<f32>(256);
            drop(buf);
        }

        let buf2: ReusedBuffer<f32> = global_acquire_uninit::<f32>(256);
        buf2.release_to_pool();

        let stats = get_pool_statistics();
        assert!(
            stats.pool_hits >= 1,
            "expected at least one pool hit after drop-return"
        );
    }

    #[test]
    fn test_acquire_capacity_and_uninit_slice() {
        let _guard = TEST_LOCK.lock().expect("test mutex should not be poisoned");
        clear_memory_pool();

        let buf: ReusedBuffer<u64> = global_acquire_uninit::<u64>(32);
        assert_eq!(buf.capacity(), 32);
        buf.release_to_pool();
    }
}