use std::alloc::Layout;
use std::cell::Cell;
use std::cell::RefCell;
use std::ptr::NonNull;
use std::time::Instant;
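
/// Per-thread statistics for the tensor memory pool: allocation and
/// deallocation counts, pool hit/miss counts, and current/peak usage in f32
/// elements.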
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStats {
    pub allocations: usize,
    pub deallocations: usize,
    pub pool_hits: usize,
    pub pool_misses: usize,
    pub current_usage: usize,
    pub peak_usage: usize,
}

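// Size-class capacities (in f32 elements), target pool sizes per class, and
// the headroom of spare buffers kept above the in-use count during cleanup.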
pub const SMALL_BUFFER_SIZE: usize = 1024;
pub const MEDIUM_BUFFER_SIZE: usize = 65536;
pub const LARGE_BUFFER_SIZE: usize = 1048576;
const TARGET_SMALL_BUFFERS: usize = 32;
const TARGET_MEDIUM_BUFFERS: usize = 16;
const TARGET_LARGE_BUFFERS: usize = 8;
const HEADROOM_SMALL: usize = 8;
const HEADROOM_MEDIUM: usize = 4;
const HEADROOM_LARGE: usize = 2;
const HEADROOM_XLARGE: usize = 1;

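// Cleanup runs only after at least CLEANUP_MIN_OPS pool operations and
// CLEANUP_MIN_INTERVAL_MS milliseconds have passed since the last cleanup;
// buffers idle for more than UNUSED_OPS_THRESHOLD operations are eligible
// for trimming.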
const CLEANUP_MIN_OPS: u64 = 2048;
const CLEANUP_MIN_INTERVAL_MS: u64 = 2000;
const UNUSED_OPS_THRESHOLD: u64 = 4096;

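/// A single reusable, SIMD-aligned f32 buffer owned by the pool.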
pub struct PooledBuffer {
    alloc: crate::tensor::core::Allocation,
    in_use: bool,
    last_used_counter: u64,
}

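/// Thread-local pool of reusable buffers, bucketed into four size classes.
/// An operation counter and the time of the last cleanup are tracked so idle
/// buffers can be trimmed periodically.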
pub struct TensorMemoryPool {
    small_buffers: Vec<PooledBuffer>,
    medium_buffers: Vec<PooledBuffer>,
    large_buffers: Vec<PooledBuffer>,
    xlarge_buffers: Vec<PooledBuffer>,
    stats: PoolStats,
    op_counter: u64,
    last_cleanup_counter: u64,
    last_cleanup_instant: Instant,
}

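/// Size classes used to bucket allocation requests.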
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SizeClass {
    Small,
    Medium,
    Large,
    XLarge,
}

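// Thread-local pool instance plus two flags: NO_MEM_PADDING disables SIMD
// lane padding in compute_allocation_params, and USE_POOL_ALLOC toggles pool
// allocation on or off for the current thread.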
thread_local! {
    static MEMORY_POOL: RefCell<TensorMemoryPool> = RefCell::new(TensorMemoryPool::new());
    static NO_MEM_PADDING: Cell<bool> = const { Cell::new(false) };
    static USE_POOL_ALLOC: Cell<bool> = const { Cell::new(true) };
}

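/// SIMD capability levels detected at runtime, from widest to narrowest.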
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdLevel {
    #[cfg(target_arch = "x86_64")]
    Avx512,
    #[cfg(target_arch = "x86_64")]
    Avx2,
    #[cfg(target_arch = "x86_64")]
    Sse2,
    Scalar,
}

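/// Detects the best SIMD level available on the current CPU. On non-x86_64
/// targets this always returns `SimdLevel::Scalar`.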
#[inline]
pub fn detect_runtime_simd() -> SimdLevel {
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") {
            return SimdLevel::Avx512;
        }
        if is_x86_feature_detected!("avx2") {
            return SimdLevel::Avx2;
        }
        if is_x86_feature_detected!("sse2") {
            return SimdLevel::Sse2;
        }

        SimdLevel::Scalar
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        SimdLevel::Scalar
    }
}

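/// Number of f32 lanes per SIMD register at the given level.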
#[inline]
pub(crate) fn simd_lane_width_elems(level: SimdLevel) -> usize {
    match level {
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Avx512 => 16,
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Avx2 => 8,
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Sse2 => 4,
        SimdLevel::Scalar => 1,
    }
}

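/// Allocation alignment, in bytes, for the given SIMD level. The scalar
/// fallback still uses 16-byte alignment.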
#[inline]
pub fn simd_alignment_bytes(level: SimdLevel) -> usize {
    match level {
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Avx512 => 64,
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Avx2 => 32,
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Sse2 => 16,
        SimdLevel::Scalar => 16,
    }
}

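/// Computes `(alignment_bytes, capacity_elems)` for an allocation of
/// `requested_elems` f32 values: alignment follows the detected SIMD level,
/// and the element count is rounded up to a whole number of SIMD lanes
/// unless padding is disabled via `NoMemPaddingGuard`.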
#[inline]
pub fn compute_allocation_params(requested_elems: usize) -> (usize, usize) {
    let level = detect_runtime_simd();
    #[cfg(target_arch = "x86_64")]
    let mut align = simd_alignment_bytes(level);
    #[cfg(not(target_arch = "x86_64"))]
    let align = simd_alignment_bytes(level);

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx512f") {
            align = 64;
        } else if is_x86_feature_detected!("avx2") {
            align = align.max(32);
        }
    }

    if no_mem_padding_enabled() || requested_elems == 0 {
        (align, requested_elems)
    } else {
        let lane = simd_lane_width_elems(level);
        let padded = requested_elems.div_ceil(lane) * lane;
        (align, padded)
    }
}

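/// Minimum element count at which streaming (non-temporal) stores are
/// considered worthwhile for the detected SIMD level; `usize::MAX`
/// effectively disables them in the scalar case.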
#[inline]
#[cfg(target_arch = "x86_64")]
pub fn stream_min_elems() -> usize {
    match detect_runtime_simd() {
        SimdLevel::Avx512 => 16_384,
        SimdLevel::Avx2 => 8_192,
        SimdLevel::Sse2 => 4_096,
        SimdLevel::Scalar => usize::MAX,
    }
}

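/// Software prefetch distance, in elements, tuned per SIMD level; zero
/// disables prefetching in the scalar case.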
#[inline]
#[cfg(target_arch = "x86_64")]
pub fn prefetch_distance_elems() -> usize {
    match detect_runtime_simd() {
        SimdLevel::Avx512 => 512,
        SimdLevel::Avx2 => 256,
        SimdLevel::Sse2 => 128,
        SimdLevel::Scalar => 0,
    }
}

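/// Chooses a chunk size for chunked element-wise processing: smaller chunks
/// for small tensors, larger ones for very large tensors, rounded up to a
/// whole number of SIMD lanes and clamped to [4_096, 262_144].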
#[inline]
pub fn choose_fast_chunk_size(total_elems: usize) -> usize {
    if total_elems == 0 {
        return 1;
    }
    let mut sz = 16_384usize;
    if total_elems < 16_384 {
        sz = 4_096;
    } else if total_elems > 1_048_576 {
        sz = 65_536;
    }
    let lane = simd_lane_width_elems(detect_runtime_simd());
    if lane > 1 {
        sz = sz.div_ceil(lane) * lane;
    }
    sz.clamp(4_096, 262_144)
}

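/// Returns whether pool allocation is enabled for the current thread.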
#[inline]
pub fn use_pool_alloc_enabled() -> bool {
    USE_POOL_ALLOC.with(|flag| flag.get())
}

impl PooledBuffer {
    fn new(size: usize, alignment: usize) -> Self {
        let effective_alignment = alignment.max(std::mem::align_of::<f32>());
        let layout =
            Layout::from_size_align(size * std::mem::size_of::<f32>(), effective_alignment)
                .expect("Invalid layout for pooled buffer");
        let alloc =
            crate::tensor::core::Allocation::new_uninitialized(size, effective_alignment, layout);
        let addr = alloc.ptr.as_ptr() as usize;
        assert_eq!(
            addr % alignment,
            0,
            "System allocator failed to provide {}-byte aligned memory. Got address 0x{:x} (misaligned by {} bytes)",
            alignment,
            addr,
            addr % alignment
        );
        PooledBuffer {
            alloc,
            in_use: false,
            last_used_counter: 0,
        }
    }

    #[inline(always)]
    pub fn as_ptr(&self) -> NonNull<f32> {
        self.alloc.ptr
    }

    #[inline(always)]
    pub fn size(&self) -> usize {
        self.alloc.capacity_elems()
    }

    #[inline]
    fn allocate_for_tensor(&mut self, now_counter: u64) -> bool {
        if self.in_use {
            false
        } else {
            self.in_use = true;
            self.last_used_counter = now_counter;
            true
        }
    }

    #[inline]
    fn return_to_pool(&mut self, now_counter: u64) {
        self.in_use = false;
        self.last_used_counter = now_counter;
    }

    #[inline(always)]
    pub fn is_available(&self) -> bool {
        !self.in_use
    }
}

impl TensorMemoryPool {
    pub fn new() -> Self {
        TensorMemoryPool {
            small_buffers: Vec::with_capacity(TARGET_SMALL_BUFFERS),
            medium_buffers: Vec::with_capacity(TARGET_MEDIUM_BUFFERS),
            large_buffers: Vec::with_capacity(TARGET_LARGE_BUFFERS),
            xlarge_buffers: Vec::with_capacity(4),
            stats: PoolStats::new(),
            op_counter: 0,
            last_cleanup_counter: 0,
            last_cleanup_instant: Instant::now(),
        }
    }

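    /// Attempts to hand out a pooled buffer of at least `size` elements with
    /// at least `alignment` bytes of alignment, after classifying the
    /// request into one of the four size classes.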
    fn try_allocate(&mut self, size: usize, alignment: usize) -> Option<NonNull<f32>> {
        let size_class = self.classify_size(size);

        self.try_allocate_internal(size, alignment, size_class)
    }

    fn try_allocate_internal(
        &mut self,
        size: usize,
        alignment: usize,
        size_class: SizeClass,
    ) -> Option<NonNull<f32>> {
        self.maybe_cleanup();
        match size_class {
            SizeClass::Small => {
                self.try_allocate_from_small_pool(SMALL_BUFFER_SIZE, alignment, size_class)
            }
            SizeClass::Medium => {
                self.try_allocate_from_medium_pool(MEDIUM_BUFFER_SIZE, alignment, size_class)
            }
            SizeClass::Large => {
                self.try_allocate_from_large_pool(LARGE_BUFFER_SIZE, alignment, size_class)
            }
            SizeClass::XLarge => {
                let planned = TensorMemoryPool::planned_capacity_elems(size);
                self.try_allocate_from_xlarge_pool(planned, alignment, size_class)
            }
        }
    }

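    // The four per-class allocators below share the same shape: reuse the
    // first available buffer with sufficient alignment (a pool hit), or grow
    // the pool with a freshly created buffer (a pool miss). The xlarge pool
    // additionally requires the candidate buffer to be large enough, since
    // xlarge buffers have varying capacities.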
    fn try_allocate_from_small_pool(
        &mut self,
        buffer_size: usize,
        alignment: usize,
        _size_class: SizeClass,
    ) -> Option<NonNull<f32>> {
        let nowc = self.bump_op_counter();
        for buffer in self.small_buffers.iter_mut() {
            if buffer.is_available()
                && buffer.alloc.alignment() >= alignment
                && buffer.allocate_for_tensor(nowc)
            {
                self.stats.record_allocation_hit(buffer_size);
                return Some(buffer.as_ptr());
            }
        }
        let mut new_buffer = PooledBuffer::new(buffer_size, alignment);
        if new_buffer.allocate_for_tensor(nowc) {
            let ptr = new_buffer.as_ptr();
            self.small_buffers.push(new_buffer);
            self.stats
                .record_allocation_miss(buffer_size, "new_buffer_created");
            Some(ptr)
        } else {
            None
        }
    }

    fn try_allocate_from_medium_pool(
        &mut self,
        buffer_size: usize,
        alignment: usize,
        _size_class: SizeClass,
    ) -> Option<NonNull<f32>> {
        let nowc = self.bump_op_counter();
        for buffer in self.medium_buffers.iter_mut() {
            if buffer.is_available()
                && buffer.alloc.alignment() >= alignment
                && buffer.allocate_for_tensor(nowc)
            {
                self.stats.record_allocation_hit(buffer_size);
                return Some(buffer.as_ptr());
            }
        }
        let mut new_buffer = PooledBuffer::new(buffer_size, alignment);
        if new_buffer.allocate_for_tensor(nowc) {
            let ptr = new_buffer.as_ptr();
            self.medium_buffers.push(new_buffer);
            self.stats
                .record_allocation_miss(buffer_size, "new_buffer_created");
            Some(ptr)
        } else {
            None
        }
    }

    fn try_allocate_from_large_pool(
        &mut self,
        buffer_size: usize,
        alignment: usize,
        _size_class: SizeClass,
    ) -> Option<NonNull<f32>> {
        let nowc = self.bump_op_counter();
        for buffer in self.large_buffers.iter_mut() {
            if buffer.is_available()
                && buffer.alloc.alignment() >= alignment
                && buffer.allocate_for_tensor(nowc)
            {
                self.stats.record_allocation_hit(buffer_size);
                return Some(buffer.as_ptr());
            }
        }
        let mut new_buffer = PooledBuffer::new(buffer_size, alignment);
        if new_buffer.allocate_for_tensor(nowc) {
            let ptr = new_buffer.as_ptr();
            self.large_buffers.push(new_buffer);
            self.stats
                .record_allocation_miss(buffer_size, "new_buffer_created");
            Some(ptr)
        } else {
            None
        }
    }

    fn try_allocate_from_xlarge_pool(
        &mut self,
        buffer_size: usize,
        alignment: usize,
        _size_class: SizeClass,
    ) -> Option<NonNull<f32>> {
        let nowc = self.bump_op_counter();
        for buffer in self.xlarge_buffers.iter_mut() {
            if buffer.is_available()
                && buffer.size() >= buffer_size
                && buffer.alloc.alignment() >= alignment
                && buffer.allocate_for_tensor(nowc)
            {
                self.stats.record_allocation_hit(buffer_size);
                return Some(buffer.as_ptr());
            }
        }
        let mut new_buffer = PooledBuffer::new(buffer_size, alignment);
        if new_buffer.allocate_for_tensor(nowc) {
            let ptr = new_buffer.as_ptr();
            self.xlarge_buffers.push(new_buffer);
            self.stats
                .record_allocation_miss(buffer_size, "new_buffer_created");
            Some(ptr)
        } else {
            None
        }
    }

    #[inline]
    fn classify_size(&self, size: usize) -> SizeClass {
        if size <= SMALL_BUFFER_SIZE {
            SizeClass::Small
        } else if size <= MEDIUM_BUFFER_SIZE {
            SizeClass::Medium
        } else if size <= LARGE_BUFFER_SIZE {
            SizeClass::Large
        } else {
            SizeClass::XLarge
        }
    }

    #[cfg(test)]
    fn stats(&self) -> &PoolStats {
        &self.stats
    }
}

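/// RAII guard that disables SIMD lane padding for the current thread while
/// it is alive, restoring the previous setting on drop.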
#[allow(dead_code)]
pub struct NoMemPaddingGuard {
    prev: bool,
}

impl Drop for NoMemPaddingGuard {
    fn drop(&mut self) {
        let _ = NO_MEM_PADDING.try_with(|flag| flag.set(self.prev));
    }
}

impl NoMemPaddingGuard {
    #[allow(dead_code)]
    pub fn new() -> Self {
        let prev = NO_MEM_PADDING.with(|flag| {
            let old = flag.get();
            flag.set(true);
            old
        });
        NoMemPaddingGuard { prev }
    }
}

impl Default for NoMemPaddingGuard {
    fn default() -> Self {
        Self::new()
    }
}

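/// RAII guard that disables pool allocation for the current thread while it
/// is alive, restoring the previous setting on drop.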
pub struct NoMemPoolGuard {
    prev: bool,
}

impl Drop for NoMemPoolGuard {
    fn drop(&mut self) {
        let _ = USE_POOL_ALLOC.try_with(|flag| flag.set(self.prev));
    }
}

impl NoMemPoolGuard {
    pub fn new() -> Self {
        let prev = USE_POOL_ALLOC.with(|flag| {
            let old = flag.get();
            flag.set(false);
            old
        });
        NoMemPoolGuard { prev }
    }
}

impl Default for NoMemPoolGuard {
    fn default() -> Self {
        Self::new()
    }
}

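/// Runs `f` with pool allocation disabled for the current thread.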
#[inline]
pub fn with_no_mem_pool<F, R>(f: F) -> R
where
    F: FnOnce() -> R,
{
    let _guard = NoMemPoolGuard::new();
    f()
}

#[inline]
#[allow(dead_code)]
pub fn with_no_mem_padding<F, R>(f: F) -> R
where
    F: FnOnce() -> R,
{
    let _guard = NoMemPaddingGuard::new();
    f()
}

#[inline]
pub fn no_mem_padding_enabled() -> bool {
    NO_MEM_PADDING.with(|flag| flag.get())
}

impl TensorMemoryPool {
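    /// Rounds a requested element count up to the capacity that will actually
    /// be reserved for it: the fixed small/medium/large class sizes, or twice
    /// the request (with a floor) for xlarge allocations.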
    pub fn planned_capacity_elems(requested_elems: usize) -> usize {
        if requested_elems <= SMALL_BUFFER_SIZE {
            SMALL_BUFFER_SIZE
        } else if requested_elems <= MEDIUM_BUFFER_SIZE {
            MEDIUM_BUFFER_SIZE
        } else if requested_elems <= LARGE_BUFFER_SIZE {
            LARGE_BUFFER_SIZE
        } else {
            (requested_elems * 2).max(262144 * 2)
        }
    }
}

impl PoolStats {
    fn new() -> Self {
        PoolStats {
            allocations: 0,
            deallocations: 0,
            pool_hits: 0,
            pool_misses: 0,
            current_usage: 0,
            peak_usage: 0,
        }
    }

    fn record_allocation_hit(&mut self, buffer_size: usize) {
        self.allocations += 1;
        self.pool_hits += 1;
        self.current_usage += buffer_size;
        if self.current_usage > self.peak_usage {
            self.peak_usage = self.current_usage;
        }
    }

    fn record_allocation_miss(&mut self, _buffer_size: usize, _reason: &str) {
        self.allocations += 1;
        self.pool_misses += 1;
    }

    fn record_deallocation(&mut self, size: usize) {
        self.deallocations += 1;
        self.current_usage = self.current_usage.saturating_sub(size);
    }
}

impl TensorMemoryPool {
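    /// Allocates a buffer from the current thread's pool, returning a pointer
    /// to at least `size` f32 elements aligned to `alignment` bytes.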
    pub fn allocate(size: usize, alignment: usize) -> Option<NonNull<f32>> {
        MEMORY_POOL.with(|pool| pool.borrow_mut().try_allocate(size, alignment))
    }

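    /// Returns a previously allocated pointer to the current thread's pool:
    /// `Some(true)` if it was recognized and returned, `Some(false)` if it
    /// was not found in any pool, and `None` if the thread-local pool is no
    /// longer accessible (e.g. during thread teardown).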
    pub fn try_deallocate(ptr: NonNull<f32>) -> Option<bool> {
        MEMORY_POOL
            .try_with(|pool| {
                let mut pool_mut = pool.borrow_mut();
                pool_mut.return_to_pool(ptr)
            })
            .ok()
    }

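    /// Marks the buffer at `ptr` as available again, searching each size
    /// class in turn, and gives the pool a chance to trim idle buffers.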
    fn return_to_pool(&mut self, ptr: NonNull<f32>) -> bool {
        if self.return_to_small_pool(ptr) {
            self.maybe_cleanup();
            return true;
        }
        if self.return_to_medium_pool(ptr) {
            self.maybe_cleanup();
            return true;
        }
        if self.return_to_large_pool(ptr) {
            self.maybe_cleanup();
            return true;
        }
        if self.return_to_xlarge_pool(ptr) {
            self.maybe_cleanup();
            return true;
        }

        false
    }

    fn return_to_small_pool(&mut self, ptr: NonNull<f32>) -> bool {
        let nowc = self.bump_op_counter();
        for buffer in self.small_buffers.iter_mut() {
            if buffer.as_ptr() == ptr {
                buffer.return_to_pool(nowc);
                self.stats.record_deallocation(buffer.size());
                return true;
            }
        }
        false
    }

    fn return_to_medium_pool(&mut self, ptr: NonNull<f32>) -> bool {
        let nowc = self.bump_op_counter();
        for buffer in self.medium_buffers.iter_mut() {
            if buffer.as_ptr() == ptr {
                buffer.return_to_pool(nowc);
                self.stats.record_deallocation(buffer.size());
                return true;
            }
        }
        false
    }

    fn return_to_large_pool(&mut self, ptr: NonNull<f32>) -> bool {
        let nowc = self.bump_op_counter();
        for buffer in self.large_buffers.iter_mut() {
            if buffer.as_ptr() == ptr {
                buffer.return_to_pool(nowc);
                self.stats.record_deallocation(buffer.size());
                return true;
            }
        }
        false
    }

    fn return_to_xlarge_pool(&mut self, ptr: NonNull<f32>) -> bool {
        let nowc = self.bump_op_counter();
        for buffer in self.xlarge_buffers.iter_mut() {
            if buffer.as_ptr() == ptr {
                buffer.return_to_pool(nowc);
                self.stats.record_deallocation(buffer.size());
                return true;
            }
        }
        false
    }

    #[cfg(test)]
    pub fn thread_stats() -> PoolStats {
        MEMORY_POOL.with(|pool| *pool.borrow().stats())
    }

    #[cfg(test)]
    pub fn pool_sizes() -> (usize, usize, usize, usize) {
        MEMORY_POOL.with(|pool| {
            let p = pool.borrow();
            (
                p.small_buffers.len(),
                p.medium_buffers.len(),
                p.large_buffers.len(),
                p.xlarge_buffers.len(),
            )
        })
    }
}

impl TensorMemoryPool {
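    /// Advances the operation counter and returns the new value; buffer
    /// timestamps and cleanup decisions are expressed in these "pool
    /// operations".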
    #[inline]
    fn bump_op_counter(&mut self) -> u64 {
        self.op_counter = self.op_counter.wrapping_add(1);
        self.op_counter
    }

    #[inline]
    fn should_cleanup(&self) -> bool {
        let ops_since = self.op_counter.wrapping_sub(self.last_cleanup_counter);
        if ops_since < CLEANUP_MIN_OPS {
            return false;
        }
        let elapsed = self.last_cleanup_instant.elapsed();
        elapsed.as_millis() as u64 >= CLEANUP_MIN_INTERVAL_MS
    }

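    /// If enough operations and enough wall time have elapsed, trims each
    /// size-class pool toward its target size and resets the cleanup counter
    /// and timestamp.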
    fn maybe_cleanup(&mut self) {
        if !self.should_cleanup() {
            return;
        }

        let nowc = self.op_counter;
        Self::cleanup_pool_vec(
            &mut self.small_buffers,
            TARGET_SMALL_BUFFERS,
            HEADROOM_SMALL,
            nowc,
        );
        Self::cleanup_pool_vec(
            &mut self.medium_buffers,
            TARGET_MEDIUM_BUFFERS,
            HEADROOM_MEDIUM,
            nowc,
        );
        Self::cleanup_pool_vec(
            &mut self.large_buffers,
            TARGET_LARGE_BUFFERS,
            HEADROOM_LARGE,
            nowc,
        );
        Self::cleanup_pool_vec(&mut self.xlarge_buffers, 2, HEADROOM_XLARGE, nowc);

        self.last_cleanup_counter = self.op_counter;
        self.last_cleanup_instant = Instant::now();
    }

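    /// Drops the oldest idle buffers in `vec` until its length is back to
    /// `max(target, in_use + headroom)`. Only buffers idle for at least
    /// `UNUSED_OPS_THRESHOLD` operations are eligible for removal.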
    fn cleanup_pool_vec(
        vec: &mut Vec<PooledBuffer>,
        target: usize,
        headroom: usize,
        now_counter: u64,
    ) {
        if vec.is_empty() {
            return;
        }
        let in_use = vec.iter().filter(|b| !b.is_available()).count();
        let desired = core::cmp::max(target, in_use.saturating_add(headroom));
        if vec.len() <= desired {
            return;
        }

        let mut eligible: Vec<(usize, u64)> = vec
            .iter()
            .enumerate()
            .filter(|(_i, b)| b.is_available())
            .map(|(i, b)| (i, now_counter.wrapping_sub(b.last_used_counter)))
            .filter(|(_i, age_ops)| *age_ops >= UNUSED_OPS_THRESHOLD)
            .collect();

        if eligible.is_empty() {
            return;
        }

        eligible.sort_by_key(|(_i, age)| core::cmp::Reverse(*age));

        let excess = vec.len().saturating_sub(desired);
        let to_remove = core::cmp::min(excess, eligible.len());
        if to_remove == 0 {
            return;
        }

        let mut to_drop: Vec<usize> = eligible.iter().take(to_remove).map(|(i, _)| *i).collect();
        to_drop.sort_unstable_by(|a, b| b.cmp(a));
        for idx in to_drop {
            vec.remove(idx);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_with_no_mem_padding_guard_scoping() {
        assert!(!no_mem_padding_enabled());
        {
            let _g = NoMemPaddingGuard::new();
            assert!(no_mem_padding_enabled());
        }
        assert!(!no_mem_padding_enabled());
    }

    #[test]
    fn test_compute_allocation_params_padding_behavior() {
        let (align1, padded1) = compute_allocation_params(33);
        let lane = simd_lane_width_elems(detect_runtime_simd());
        assert!(padded1 >= 33);
        assert_eq!(padded1 % lane, 0);
        assert!(align1 >= 16);

        let res = with_no_mem_padding(|| compute_allocation_params(33));

        assert_eq!(res.1, 33);
    }

    #[test]
    fn test_same_thread_alloc_dealloc_counters_across_classes() {
        let before = TensorMemoryPool::thread_stats();
        {
            let lane = simd_lane_width_elems(detect_runtime_simd());
            let sizes = [
                SMALL_BUFFER_SIZE.min(8),
                MEDIUM_BUFFER_SIZE / 2,
                LARGE_BUFFER_SIZE / 2,
                LARGE_BUFFER_SIZE + lane * 3 + 7,
            ];
            for &n in &sizes {
                let _t = crate::tensor::Tensor::new(vec![n]);
            }
        }
        let after = TensorMemoryPool::thread_stats();
        assert!(after.allocations >= before.allocations + 4);
        assert!(after.deallocations >= before.deallocations + 4);
    }

    #[test]
    fn test_xlarge_pool_does_not_reuse_too_small_buffer() {
        let lane = simd_lane_width_elems(detect_runtime_simd());
        let align = simd_alignment_bytes(detect_runtime_simd());
        let small_xlarge = LARGE_BUFFER_SIZE + lane * 2;
        let _t1 = crate::tensor::Tensor::new(vec![small_xlarge]);
        let larger = small_xlarge * 2 + lane * 3;
        let ptr_opt = MEMORY_POOL.with(|pool| {
            let mut p = pool.borrow_mut();
            p.try_allocate_from_xlarge_pool(larger, align, SizeClass::XLarge)
        });
        assert!(ptr_opt.is_some());
    }

    #[test]
    fn test_cross_thread_drop_safe_no_crash() {
        use std::thread;
        let lane = simd_lane_width_elems(detect_runtime_simd());
        let n = LARGE_BUFFER_SIZE + lane * 2 + 3;
        let t = crate::tensor::Tensor::new(vec![n]);
        let handle = thread::spawn(move || {
            drop(t);
        });
        let _ = handle.join();
    }

    #[test]
    fn test_try_deallocate_returns_some_true_for_pooled() {
        let align = simd_alignment_bytes(detect_runtime_simd());
        let ptr = TensorMemoryPool::allocate(128, align).expect("pool allocate failed");
        let res = TensorMemoryPool::try_deallocate(ptr);
        assert_eq!(res, Some(true));
    }

    #[test]
    fn perf_pool_vs_no_pool_by_category_over_1000_iterations() {
        use std::time::Instant;

        let small = vec![32, 32];
        let medium = vec![256, 256];
        let large = vec![1024, 1024];
        let xlarge = vec![1200, 1200];

        fn bench_shape(shape: &[usize], iters: usize) -> std::time::Duration {
            let start = Instant::now();
            let mut sink = 0.0f32;
            for i in 0..iters {
                let t0 = crate::tensor::Tensor::ones(shape.to_vec());
                let t1 = t0.add_scalar((i % 5) as f32 * 0.1);
                let t2 = t1.mul_scalar(1.2345);
                let s = t2.sum();
                sink += s.value();
            }
            assert!(sink.is_finite());
            start.elapsed()
        }

        let iters = 1000usize;

        let cats: [(&str, Vec<usize>); 4] = [
            ("small", small),
            ("medium", medium),
            ("large", large),
            ("xlarge", xlarge),
        ];

        for (name, shape) in cats.iter() {
            let pooled = bench_shape(shape, iters);
            let system = super::with_no_mem_pool(|| bench_shape(shape, iters));
            let pooled_ms = pooled.as_secs_f64() * 1_000.0;
            let system_ms = system.as_secs_f64() * 1_000.0;
            let speedup = if pooled_ms > 0.0 {
                system_ms / pooled_ms
            } else {
                0.0
            };
            println!(
                "Perf [{} | {:?} elems]: pooled={:.2} ms, no_pool={:.2} ms, speedup={:.2}x (iters={})",
                name,
                shape.iter().product::<usize>(),
                pooled_ms,
                system_ms,
                speedup,
                iters
            );

            assert!(pooled > std::time::Duration::from_millis(0));
            assert!(system > std::time::Duration::from_millis(0));
        }
    }
}

#[cfg(test)]
mod xlarge_stress_tests {
    use super::*;

    #[test]
    fn stress_xlarge_pool_various_sizes_single_thread() {
        let lane = simd_lane_width_elems(detect_runtime_simd());
        let sizes = [
            LARGE_BUFFER_SIZE + 1,
            LARGE_BUFFER_SIZE * 2 + lane - 1,
            LARGE_BUFFER_SIZE * 3 + 17,
            LARGE_BUFFER_SIZE * 4 + lane * 3 + 5,
            LARGE_BUFFER_SIZE * 6 + 123,
        ];
        for _ in 0..1000 {
            for &n in &sizes {
                let elems = n;
                let mut t = crate::tensor::Tensor::new(vec![elems]);
                if elems > 0 {
                    t.set(&[0], 0.0);
                }
                assert_eq!(t.size(), elems);
            }
        }
    }

    #[test]
    fn stress_xlarge_pool_multithreaded() {
        use std::thread;
        let lane = simd_lane_width_elems(detect_runtime_simd());
        let sizes = [
            LARGE_BUFFER_SIZE + 1,
            LARGE_BUFFER_SIZE * 2 + lane - 1,
            LARGE_BUFFER_SIZE * 3 + 17,
            LARGE_BUFFER_SIZE * 4 + lane * 3 + 5,
            LARGE_BUFFER_SIZE * 6 + 123,
        ];
        let threads = 8usize.min(
            std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(8),
        );
        let mut handles = Vec::new();
        for tid in 0..threads {
            let sizes_clone = sizes;
            handles.push(thread::spawn(move || {
                for r in 0..20 {
                    for (i, n) in sizes_clone.iter().enumerate() {
                        let elems = n + (tid * 13 + r * 7 + i) % lane;
                        let mut t = crate::tensor::Tensor::new(vec![elems]);
                        assert_eq!(t.size(), elems);
                        if elems > 0 {
                            let idx0 = elems / 2;
                            let idx1 = (elems.saturating_sub(1)) / 3;
                            let idx2 = (elems.saturating_sub(1)) / 5;
                            if idx0 < t.size() {
                                t.set(&[idx0], 1.2345);
                            }
                            if idx1 < t.size() {
                                t.set(&[idx1], 2.3456);
                            }
                            if idx2 < t.size() {
                                t.set(&[idx2], 3.4567);
                            }
                        }
                    }
                }
            }));
        }
        for h in handles {
            let _ = h.join();
        }
    }
}

#[cfg(test)]
mod additional_safety_tests {
    use super::*;

    #[test]
    fn test_pool_alloc_dealloc_balanced_small_medium_large() {
        let before = TensorMemoryPool::thread_stats();
        {
            let _s1 = crate::tensor::Tensor::new(vec![SMALL_BUFFER_SIZE.min(16)]);
            let _m1 = crate::tensor::Tensor::new(vec![MEDIUM_BUFFER_SIZE / 4]);
            let _l1 = crate::tensor::Tensor::new(vec![LARGE_BUFFER_SIZE / 4]);
        }
        let after = TensorMemoryPool::thread_stats();
        assert!(
            after.allocations >= before.allocations + 3,
            "allocations did not increase as expected: before={}, after={}",
            before.allocations,
            after.allocations
        );
        assert!(
            after.deallocations >= before.deallocations + 3,
            "deallocations did not increase as expected: before={}, after={}",
            before.deallocations,
            after.deallocations
        );
        assert!(
            after.current_usage <= before.current_usage,
            "current_usage grew: before={}, after={}",
            before.current_usage,
            after.current_usage
        );
    }

    #[test]
    fn test_pointer_alignment_across_classes() {
        let align = simd_alignment_bytes(detect_runtime_simd());
        for &n in &[
            8usize,
            SMALL_BUFFER_SIZE,
            MEDIUM_BUFFER_SIZE,
            LARGE_BUFFER_SIZE + 128,
        ] {
            let t = crate::tensor::Tensor::new(vec![n]);
            unsafe {
                let addr = t.as_ptr() as usize;
                assert_eq!(
                    addr % align,
                    0,
                    "pointer not aligned to {} for n={}",
                    align,
                    n
                );
            }
        }
    }

    #[test]
    fn test_with_no_mem_pool_uses_system_allocator_no_pool_stats() {
        let before = TensorMemoryPool::thread_stats();
        with_no_mem_pool(|| {
            let _t1 = crate::tensor::Tensor::new(vec![64]);
            let _t2 = crate::tensor::Tensor::new(vec![2048]);
            let _t3 = crate::tensor::Tensor::new(vec![131072]);
        });
        let after = TensorMemoryPool::thread_stats();
        assert_eq!(
            after.allocations, before.allocations,
            "pool allocations changed with pool disabled: before={}, after={}",
            before.allocations, after.allocations
        );
        assert_eq!(
            after.deallocations, before.deallocations,
            "pool deallocations changed with pool disabled: before={}, after={}",
            before.deallocations, after.deallocations
        );
    }

    #[test]
    fn test_cross_thread_drop_does_not_affect_this_thread_stats() {
        let before = TensorMemoryPool::thread_stats();
        let handle =
            std::thread::spawn(|| crate::tensor::Tensor::new(vec![SMALL_BUFFER_SIZE.min(32)]));
        let t = handle.join().unwrap();
        drop(t);
        let after = TensorMemoryPool::thread_stats();
        assert_eq!(
            after.allocations, before.allocations,
            "allocations changed in current thread due to cross-thread drop: before={}, after={}",
            before.allocations, after.allocations
        );
        assert_eq!(
            after.deallocations, before.deallocations,
            "deallocations changed in current thread due to cross-thread drop: before={}, after={}",
            before.deallocations, after.deallocations
        );
    }

    #[test]
    fn test_many_alloc_dealloc_cycles_no_growth_in_current_usage() {
        let before = TensorMemoryPool::thread_stats();
        for _ in 0..100 {
            let _t1 = crate::tensor::Tensor::new(vec![SMALL_BUFFER_SIZE.min(64)]);
            let _t2 = crate::tensor::Tensor::new(vec![MEDIUM_BUFFER_SIZE / 8]);
        }
        let after = TensorMemoryPool::thread_stats();
        assert!(
            after.current_usage <= before.current_usage + SMALL_BUFFER_SIZE + MEDIUM_BUFFER_SIZE,
            "current_usage unexpected growth: before={}, after={}",
            before.current_usage,
            after.current_usage
        );
    }
}

#[cfg(test)]
mod cleanup_tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    fn hold_tensors(count: usize, elems: usize) -> Vec<crate::tensor::Tensor> {
        let mut v = Vec::with_capacity(count);
        for _ in 0..count {
            v.push(crate::tensor::Tensor::new(vec![elems]));
        }
        v
    }

    fn bump_ops_small_iters(iters: usize) {
        for _ in 0..iters {
            let _t = crate::tensor::Tensor::new(vec![SMALL_BUFFER_SIZE.min(8)]);
        }
    }

    #[test]
    fn test_no_cleanup_while_many_small_buffers_in_use() {
        let holders = hold_tensors(40, SMALL_BUFFER_SIZE.min(32));
        let (small_before, _, _, _) = TensorMemoryPool::pool_sizes();
        assert!(
            small_before >= 40,
            "expected >=40 small buffers, got {}",
            small_before
        );

        bump_ops_small_iters(1500);
        thread::sleep(Duration::from_millis(2100));
        bump_ops_small_iters(700);
        {
            let _m = crate::tensor::Tensor::new(vec![MEDIUM_BUFFER_SIZE / 2]);
        }

        let (small_mid, _, _, _) = TensorMemoryPool::pool_sizes();
        assert!(
            small_mid >= small_before,
            "small pool shrank while heavily in-use: before={} after={}",
            small_before,
            small_mid
        );

        drop(holders);

        let _ = crate::tensor::Tensor::new(vec![MEDIUM_BUFFER_SIZE / 2]);
        let (small_after, _, _, _) = TensorMemoryPool::pool_sizes();
        assert!(
            small_after >= small_before,
            "small pool unexpectedly trimmed active buffers: before={} after={}",
            small_before,
            small_after
        );
    }

    #[test]
    fn test_cleanup_trims_long_idle_medium_buffers() {
        {
            let _holders = hold_tensors(30, MEDIUM_BUFFER_SIZE / 2);
        }
        let (_, med_before, _, _) = TensorMemoryPool::pool_sizes();
        assert!(
            med_before >= 30,
            "expected >=30 medium buffers, got {}",
            med_before
        );

        bump_ops_small_iters(2300);
        thread::sleep(Duration::from_millis(2100));

        let _ = crate::tensor::Tensor::new(vec![SMALL_BUFFER_SIZE.min(16)]);
        let (_, med_after, _, _) = TensorMemoryPool::pool_sizes();

        assert!(
            med_after < med_before,
            "medium pool not trimmed despite long idle: before={} after={}",
            med_before,
            med_after
        );
    }
}