1use crate::shape::Shape;
21
22#[cfg(feature = "std")]
23use std::cell::RefCell;
24#[cfg(feature = "std")]
25use std::sync::Mutex;
26
27#[cfg(not(feature = "std"))]
28use core::cell::RefCell;
29
/// Maximum dimension count supported by the stack-allocated shape helpers.
/// NOTE(review): not referenced elsewhere in this file — presumably callers
/// use it as the `N` parameter for `StackShape`; confirm before removing.
pub const MAX_STACK_DIMS: usize = 8;
32
/// Fixed-capacity shape stored inline on the stack (no heap allocation).
///
/// Only the first `ndim` entries of `dims` are meaningful; the constructors
/// zero the unused tail. Note the derived `PartialEq` compares the whole
/// `dims` array, including the padding entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StackShape<const N: usize> {
    // Dimension storage; entries at index >= `ndim` are padding.
    pub dims: [usize; N],
    // Number of valid leading entries in `dims` (<= N for well-formed values).
    pub ndim: usize,
}
41
42impl<const N: usize> StackShape<N> {
43 #[inline]
45 pub const fn new(dims: [usize; N]) -> Self {
46 Self { dims, ndim: N }
47 }
48
49 #[inline]
51 pub fn from_slice(dims: &[usize]) -> Option<Self> {
52 if dims.len() > N {
53 return None;
54 }
55 let mut stack_dims = [0; N];
56 let mut i = 0;
57 while i < dims.len() {
58 stack_dims[i] = dims[i];
59 i += 1;
60 }
61 Some(Self {
62 dims: stack_dims,
63 ndim: dims.len(),
64 })
65 }
66
67 #[inline]
69 pub fn as_slice(&self) -> &[usize] {
70 &self.dims[..self.ndim]
71 }
72
73 #[inline]
75 pub const fn numel(&self) -> usize {
76 let mut product = 1;
77 let mut i = 0;
78 while i < self.ndim {
79 product *= self.dims[i];
80 i += 1;
81 }
82 product
83 }
84
85 #[inline]
87 pub fn to_shape(&self) -> Shape {
88 Shape::new(self.as_slice().to_vec())
89 }
90
91 #[inline]
93 pub fn broadcast_compatible<const M: usize>(&self, other: &StackShape<M>) -> bool {
94 let max_ndim = self.ndim.max(other.ndim);
95
96 for i in 0..max_ndim {
97 let dim1 = if i < self.ndim {
98 self.dims[self.ndim - 1 - i]
99 } else {
100 1
101 };
102 let dim2 = if i < other.ndim {
103 other.dims[other.ndim - 1 - i]
104 } else {
105 1
106 };
107
108 if dim1 != dim2 && dim1 != 1 && dim2 != 1 {
109 return false;
110 }
111 }
112 true
113 }
114}
115
/// Clone-on-write shape: either a borrowed `'static` dimension list
/// (zero allocation) or an owned heap-backed `Shape`.
#[derive(Debug, Clone)]
pub enum CowShape {
    /// Dimensions borrowed from static data; never allocates.
    Borrowed(&'static [usize]),
    /// Fully owned shape.
    Owned(Shape),
}
124
125impl CowShape {
126 #[inline]
128 pub const fn from_static(dims: &'static [usize]) -> Self {
129 CowShape::Borrowed(dims)
130 }
131
132 #[inline]
134 pub fn from_owned(shape: Shape) -> Self {
135 CowShape::Owned(shape)
136 }
137
138 #[inline]
140 pub fn as_slice(&self) -> &[usize] {
141 match self {
142 CowShape::Borrowed(dims) => dims,
143 CowShape::Owned(shape) => shape.dims(),
144 }
145 }
146
147 #[inline]
149 pub fn into_owned(self) -> Shape {
150 match self {
151 CowShape::Borrowed(dims) => Shape::new(dims.to_vec()),
152 CowShape::Owned(shape) => shape,
153 }
154 }
155
156 #[inline]
158 pub fn numel(&self) -> usize {
159 self.as_slice().iter().product()
160 }
161}
162
/// Running counters describing allocation behavior, used to decide which
/// shape optimizations are worthwhile.
#[derive(Debug, Default, Clone, Copy)]
pub struct AllocationStats {
    // Count of every recorded allocation.
    pub total_allocations: u64,
    // Sum of all recorded allocation sizes, in bytes.
    pub total_bytes: u64,
    // Allocations the caller flagged as avoidable.
    pub avoidable_allocations: u64,
    // Bytes belonging to avoidable allocations.
    pub avoidable_bytes: u64,
    // Allocations under 64 bytes.
    pub small_allocations: u64,
    // Allocations of 64..1024 bytes.
    pub medium_allocations: u64,
    // Allocations of 1024 bytes or more.
    pub large_allocations: u64,
}

impl AllocationStats {
    /// Records one allocation of `bytes` bytes; `avoidable` marks it as one
    /// that pooling or stack storage could have eliminated.
    #[inline]
    pub fn record_allocation(&mut self, bytes: usize, avoidable: bool) {
        let size = bytes as u64;
        self.total_allocations += 1;
        self.total_bytes += size;

        if avoidable {
            self.avoidable_allocations += 1;
            self.avoidable_bytes += size;
        }

        // Bucket by size class: small < 64, medium < 1024, large otherwise.
        match bytes {
            0..=63 => self.small_allocations += 1,
            64..=1023 => self.medium_allocations += 1,
            _ => self.large_allocations += 1,
        }
    }

    /// Percentage of recorded bytes that were avoidable; 0.0 when nothing
    /// has been recorded yet.
    pub fn waste_percentage(&self) -> f64 {
        if self.total_bytes == 0 {
            return 0.0;
        }
        self.avoidable_bytes as f64 / self.total_bytes as f64 * 100.0
    }

    /// Renders a human-readable multi-line summary of the counters.
    pub fn report(&self) -> String {
        format!(
            "Allocation Statistics:\n\
             Total: {} allocations, {} bytes\n\
             Avoidable: {} allocations, {} bytes ({:.1}% waste)\n\
             Size distribution: {} small, {} medium, {} large",
            self.total_allocations,
            self.total_bytes,
            self.avoidable_allocations,
            self.avoidable_bytes,
            self.waste_percentage(),
            self.small_allocations,
            self.medium_allocations,
            self.large_allocations
        )
    }
}
230
#[cfg(feature = "std")]
thread_local! {
    // Per-thread counters so tracking needs no synchronization; each thread
    // observes only its own statistics.
    static ALLOC_STATS: RefCell<AllocationStats> = RefCell::new(AllocationStats::default());
}
236
/// Records one allocation in the current thread's statistics.
#[cfg(feature = "std")]
#[inline]
pub fn track_allocation(bytes: usize, avoidable: bool) {
    ALLOC_STATS.with(|cell| cell.borrow_mut().record_allocation(bytes, avoidable));
}
245
/// Returns a snapshot of the current thread's allocation statistics
/// (`AllocationStats` is `Copy`, so this is a plain value copy).
#[cfg(feature = "std")]
pub fn get_allocation_stats() -> AllocationStats {
    ALLOC_STATS.with(|cell| *cell.borrow())
}
251
/// Resets the current thread's allocation statistics to zero.
#[cfg(feature = "std")]
pub fn reset_allocation_stats() {
    // `RefCell::take` swaps in the `Default` value and drops the old one.
    ALLOC_STATS.with(|cell| {
        cell.take();
    });
}
259
/// A simple thread-safe pool of reusable `Vec<T>` buffers.
#[cfg(feature = "std")]
pub struct BufferPool<T> {
    // Idle buffers available for reuse; guarded by a mutex.
    buffers: Mutex<Vec<Vec<T>>>,
    // Maximum number of idle buffers retained; extras are dropped on release.
    max_pool_size: usize,
    // Capacity pre-reserved for freshly created buffers.
    buffer_capacity: usize,
}
270
// NOTE: the previous `T: Clone + Default` bounds were unused — nothing in
// this impl clones or default-constructs a `T` — and needlessly restricted
// which element types can be pooled. Loosening the bounds is backward
// compatible for all existing callers.
#[cfg(feature = "std")]
impl<T> BufferPool<T> {
    /// Creates a pool whose fresh buffers reserve `buffer_capacity` elements
    /// and which retains at most `max_pool_size` idle buffers.
    pub fn new(buffer_capacity: usize, max_pool_size: usize) -> Self {
        Self {
            buffers: Mutex::new(Vec::new()),
            max_pool_size,
            buffer_capacity,
        }
    }

    /// Takes an idle buffer from the pool, or allocates a new one with the
    /// configured capacity. The returned buffer is always empty (released
    /// buffers are cleared before being pooled).
    pub fn acquire(&self) -> Vec<T> {
        let mut buffers = self.buffers.lock().expect("lock should not be poisoned");
        buffers
            .pop()
            .unwrap_or_else(|| Vec::with_capacity(self.buffer_capacity))
    }

    /// Clears `buffer` and returns it to the pool; drops it instead when the
    /// pool already holds `max_pool_size` idle buffers.
    pub fn release(&self, mut buffer: Vec<T>) {
        buffer.clear();

        let mut buffers = self.buffers.lock().expect("lock should not be poisoned");
        if buffers.len() < self.max_pool_size {
            buffers.push(buffer);
        }
    }

    /// Returns `(idle_buffer_count, max_pool_size)`.
    pub fn stats(&self) -> (usize, usize) {
        let buffers = self.buffers.lock().expect("lock should not be poisoned");
        (buffers.len(), self.max_pool_size)
    }
}
307
// Global pool of shape-sized buffers: each fresh buffer reserves capacity 8
// (the same value as `MAX_STACK_DIMS`), and at most 100 idle buffers are kept.
#[cfg(feature = "std")]
static SHAPE_BUFFER_POOL: once_cell::sync::Lazy<BufferPool<usize>> =
    once_cell::sync::Lazy::new(|| BufferPool::new(8, 100));
312
/// Borrows an empty `Vec<usize>` from the global shape buffer pool.
/// Pair with [`release_shape_buffer`] so the buffer can be reused.
#[cfg(feature = "std")]
#[inline]
pub fn acquire_shape_buffer() -> Vec<usize> {
    SHAPE_BUFFER_POOL.acquire()
}
319
/// Returns a buffer obtained from [`acquire_shape_buffer`] to the global
/// pool; the pool clears it before storing.
#[cfg(feature = "std")]
#[inline]
pub fn release_shape_buffer(buffer: Vec<usize>) {
    SHAPE_BUFFER_POOL.release(buffer);
}
326
/// RAII guard that borrows a buffer from a `BufferPool` on construction and
/// returns it to the pool when dropped.
#[cfg(feature = "std")]
pub struct ScopedBuffer<T: Clone + Default + 'static> {
    // Always `Some` until `Drop` takes the buffer back; the `Option` exists
    // so `drop` can move the buffer out of `&mut self`.
    buffer: Option<Vec<T>>,
    // Pool the buffer is returned to on drop.
    pool: &'static BufferPool<T>,
}
333
#[cfg(feature = "std")]
impl<T: Clone + Default + 'static> ScopedBuffer<T> {
    /// Acquires a buffer from `pool`, holding it until this guard drops.
    pub fn new(pool: &'static BufferPool<T>) -> Self {
        let buffer = pool.acquire();
        Self {
            buffer: Some(buffer),
            pool,
        }
    }

    /// Shared access to the held buffer.
    pub fn get(&self) -> &Vec<T> {
        self.buffer.as_ref().expect("buffer should be present before drop")
    }

    /// Mutable access to the held buffer.
    pub fn get_mut(&mut self) -> &mut Vec<T> {
        self.buffer.as_mut().expect("buffer should be present before drop")
    }
}
358
#[cfg(feature = "std")]
impl<T: Clone + Default + 'static> Drop for ScopedBuffer<T> {
    // Return the buffer to its pool; `take` guarantees this happens once.
    fn drop(&mut self) {
        match self.buffer.take() {
            Some(buffer) => self.pool.release(buffer),
            None => {}
        }
    }
}
367
/// Optimization suggestions derived from observed allocation statistics.
#[derive(Debug, Clone)]
pub struct OptimizationRecommendations {
    // Suggest stack-allocated shapes when most allocations are small.
    pub use_stack_shapes: bool,
    // Suggest buffer pooling when many allocations are avoidable.
    pub use_buffer_pools: bool,
    // Suggest copy-on-write shapes when allocation volume is high.
    pub use_cow_shapes: bool,
    // Rough multiplicative speedup estimate (1.0 = no change expected).
    pub estimated_speedup: f64,
    // Bytes that could be saved (the avoidable-byte count from the stats).
    pub estimated_memory_savings: u64,
}
382
383impl OptimizationRecommendations {
384 pub fn from_stats(stats: &AllocationStats) -> Self {
386 let use_stack_shapes = stats.small_allocations > stats.total_allocations / 2;
387 let use_buffer_pools = stats.avoidable_allocations > stats.total_allocations / 3;
388 let use_cow_shapes = stats.total_allocations > 100;
389
390 let mut estimated_speedup = 1.0;
391 if use_stack_shapes {
392 estimated_speedup *= 2.0;
393 }
394 if use_buffer_pools {
395 estimated_speedup *= 1.5;
396 }
397 if use_cow_shapes {
398 estimated_speedup *= 1.2;
399 }
400
401 Self {
402 use_stack_shapes,
403 use_buffer_pools,
404 use_cow_shapes,
405 estimated_speedup,
406 estimated_memory_savings: stats.avoidable_bytes,
407 }
408 }
409
410 pub fn report(&self) -> String {
412 let mut recommendations = Vec::new();
413
414 if self.use_stack_shapes {
415 recommendations.push("Use StackShape for operations with ≤8 dimensions");
416 }
417 if self.use_buffer_pools {
418 recommendations.push("Use buffer pools for temporary allocations");
419 }
420 if self.use_cow_shapes {
421 recommendations.push("Use CowShape for borrowed/static shapes");
422 }
423
424 format!(
425 "Optimization Recommendations:\n\
426 {}\n\
427 Estimated speedup: {:.1}x\n\
428 Estimated memory savings: {} bytes",
429 recommendations.join("\n"),
430 self.estimated_speedup,
431 self.estimated_memory_savings
432 )
433 }
434}
435
// Formatting restored to rustfmt conventions: several statements had been
// fused onto single lines and braces misplaced. Assertions are unchanged.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stack_shape_creation() {
        let shape = StackShape::<4>::new([2, 3, 4, 5]);
        assert_eq!(shape.ndim, 4);
        assert_eq!(shape.as_slice(), &[2, 3, 4, 5]);
        assert_eq!(shape.numel(), 120);
    }

    #[test]
    fn test_stack_shape_from_slice() {
        let dims = vec![2, 3, 4];
        let shape = StackShape::<8>::from_slice(&dims).expect("from_slice should succeed");
        assert_eq!(shape.ndim, 3);
        assert_eq!(shape.as_slice(), &[2, 3, 4]);
    }

    #[test]
    fn test_stack_shape_broadcast_compatible() {
        let shape1 = StackShape::<4>::new([3, 1, 4, 1]);
        let shape2 = StackShape::<3>::from_slice(&[2, 4, 5]).expect("from_slice should succeed");

        assert!(shape1.broadcast_compatible(&shape2));

        let shape3 = StackShape::<3>::from_slice(&[1, 4, 5]).expect("from_slice should succeed");
        assert!(shape1.broadcast_compatible(&shape3));

        // Mismatched non-1 dimensions (3 vs 5) cannot broadcast.
        let shape4 = StackShape::<3>::from_slice(&[2, 3, 4]).expect("from_slice should succeed");
        let shape5 = StackShape::<3>::from_slice(&[2, 5, 4]).expect("from_slice should succeed");
        assert!(!shape4.broadcast_compatible(&shape5));
    }

    #[test]
    fn test_cow_shape_borrowed() {
        static DIMS: [usize; 3] = [2, 3, 4];
        let cow = CowShape::from_static(&DIMS);
        assert_eq!(cow.as_slice(), &[2, 3, 4]);
        assert_eq!(cow.numel(), 24);
    }

    #[test]
    fn test_cow_shape_owned() {
        let shape = Shape::new(vec![2, 3, 4]);
        let cow = CowShape::from_owned(shape);
        assert_eq!(cow.as_slice(), &[2, 3, 4]);
    }

    #[test]
    fn test_allocation_stats() {
        let mut stats = AllocationStats::default();

        // One allocation in each size class; two flagged avoidable.
        stats.record_allocation(32, true);
        stats.record_allocation(128, false);
        stats.record_allocation(2048, true);

        assert_eq!(stats.total_allocations, 3);
        assert_eq!(stats.avoidable_allocations, 2);
        assert_eq!(stats.small_allocations, 1);
        assert_eq!(stats.medium_allocations, 1);
        assert_eq!(stats.large_allocations, 1);

        // 2080 of 2208 bytes avoidable -> ~94% waste.
        let waste = stats.waste_percentage();
        assert!(waste > 90.0);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_buffer_pool() {
        let pool = BufferPool::<usize>::new(10, 5);

        let mut buffer1 = pool.acquire();
        buffer1.extend_from_slice(&[1, 2, 3]);
        pool.release(buffer1);

        let (available, max) = pool.stats();
        assert_eq!(available, 1);
        assert_eq!(max, 5);

        // Released buffers come back cleared.
        let buffer2 = pool.acquire();
        assert!(buffer2.is_empty());
        let (available, _) = pool.stats();
        assert_eq!(available, 0);

        pool.release(buffer2);

        let (available, _) = pool.stats();
        assert_eq!(available, 1);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_scoped_buffer() {
        static POOL: once_cell::sync::Lazy<BufferPool<usize>> =
            once_cell::sync::Lazy::new(|| BufferPool::new(10, 5));

        {
            let mut scoped = ScopedBuffer::new(&*POOL);
            scoped.get_mut().push(42);
            assert_eq!(scoped.get()[0], 42);
        }
        // Dropping the guard returns the buffer to the pool.
        let (available, _) = POOL.stats();
        assert_eq!(available, 1);
    }

    #[test]
    fn test_optimization_recommendations() {
        let mut stats = AllocationStats::default();

        // 100 small, avoidable allocations trigger the stack/pool heuristics.
        for _ in 0..100 {
            stats.record_allocation(32, true);
        }

        let recommendations = OptimizationRecommendations::from_stats(&stats);
        assert!(recommendations.use_stack_shapes);
        assert!(recommendations.use_buffer_pools);
        assert!(recommendations.estimated_speedup > 1.5);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_global_shape_buffer_pool() {
        let mut buffer = acquire_shape_buffer();
        buffer.extend_from_slice(&[1, 2, 3, 4]);
        assert_eq!(buffer.len(), 4);

        release_shape_buffer(buffer);

        // Reacquired buffers are always empty.
        let buffer2 = acquire_shape_buffer();
        assert_eq!(buffer2.len(), 0);
        release_shape_buffer(buffer2);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_allocation_tracking() {
        reset_allocation_stats();

        track_allocation(64, false);
        track_allocation(128, true);

        let stats = get_allocation_stats();
        assert_eq!(stats.total_allocations, 2);
        assert_eq!(stats.avoidable_allocations, 1);
    }
}