// torsh_core/alloc_optimizer.rs

//! Heap Allocation Optimizer for ToRSh Hot Paths
//!
//! This module provides allocation-free alternatives and optimization strategies
//! for performance-critical code paths in ToRSh. It helps identify (through
//! allocation statistics recorded by callers) and eliminate unnecessary heap
//! allocations through:
//!
//! - Stack-based small shapes (`StackShape`; `MAX_STACK_DIMS` = 8 dimensions)
//! - Borrowed-or-owned shapes (`CowShape`) that defer allocation until an owned
//!   `Shape` is actually needed
//! - Arena allocators for batch operations (not implemented in this module)
//! - Reusable buffer pools for temporary allocations (`BufferPool`, `ScopedBuffer`)
//!
//! # Performance Impact
//!
//! Eliminating heap allocations in hot paths can provide (indicative figures;
//! measure on your own workload):
//! - 2-5x speedup for shape broadcasting operations
//! - 10-50x speedup for small shape manipulations
//! - Reduced memory fragmentation and allocator pressure
//! - Better cache locality and CPU pipeline utilization

20use crate::shape::Shape;
21
22#[cfg(feature = "std")]
23use std::cell::RefCell;
24#[cfg(feature = "std")]
25use std::sync::Mutex;
26
27#[cfg(not(feature = "std"))]
28use core::cell::RefCell;
29
30/// Maximum dimensions for stack allocation (covers 99% of real-world cases)
31pub const MAX_STACK_DIMS: usize = 8;
32
33/// Small shape stored on stack for zero-allocation operations
34#[derive(Debug, Clone, Copy, PartialEq, Eq)]
35pub struct StackShape<const N: usize> {
36    /// Dimensions stored on stack
37    pub dims: [usize; N],
38    /// Actual number of used dimensions (may be less than N)
39    pub ndim: usize,
40}
41
42impl<const N: usize> StackShape<N> {
43    /// Create a new stack shape
44    #[inline]
45    pub const fn new(dims: [usize; N]) -> Self {
46        Self { dims, ndim: N }
47    }
48
49    /// Create from slice with runtime length check
50    #[inline]
51    pub fn from_slice(dims: &[usize]) -> Option<Self> {
52        if dims.len() > N {
53            return None;
54        }
55        let mut stack_dims = [0; N];
56        let mut i = 0;
57        while i < dims.len() {
58            stack_dims[i] = dims[i];
59            i += 1;
60        }
61        Some(Self {
62            dims: stack_dims,
63            ndim: dims.len(),
64        })
65    }
66
67    /// Get active dimensions as slice
68    #[inline]
69    pub fn as_slice(&self) -> &[usize] {
70        &self.dims[..self.ndim]
71    }
72
73    /// Calculate total number of elements (no allocation)
74    #[inline]
75    pub const fn numel(&self) -> usize {
76        let mut product = 1;
77        let mut i = 0;
78        while i < self.ndim {
79            product *= self.dims[i];
80            i += 1;
81        }
82        product
83    }
84
85    /// Convert to heap-allocated Shape
86    #[inline]
87    pub fn to_shape(&self) -> Shape {
88        Shape::new(self.as_slice().to_vec())
89    }
90
91    /// Broadcast compatibility check (no allocation)
92    #[inline]
93    pub fn broadcast_compatible<const M: usize>(&self, other: &StackShape<M>) -> bool {
94        let max_ndim = self.ndim.max(other.ndim);
95
96        for i in 0..max_ndim {
97            let dim1 = if i < self.ndim {
98                self.dims[self.ndim - 1 - i]
99            } else {
100                1
101            };
102            let dim2 = if i < other.ndim {
103                other.dims[other.ndim - 1 - i]
104            } else {
105                1
106            };
107
108            if dim1 != dim2 && dim1 != 1 && dim2 != 1 {
109                return false;
110            }
111        }
112        true
113    }
114}
115
116/// Copy-on-write shape wrapper for deferred allocation
117#[derive(Debug, Clone)]
118pub enum CowShape {
119    /// Borrowed reference to existing shape (zero-copy)
120    Borrowed(&'static [usize]),
121    /// Owned shape data
122    Owned(Shape),
123}
124
125impl CowShape {
126    /// Create from static slice (zero allocation)
127    #[inline]
128    pub const fn from_static(dims: &'static [usize]) -> Self {
129        CowShape::Borrowed(dims)
130    }
131
132    /// Create from owned shape
133    #[inline]
134    pub fn from_owned(shape: Shape) -> Self {
135        CowShape::Owned(shape)
136    }
137
138    /// Get dimensions as slice
139    #[inline]
140    pub fn as_slice(&self) -> &[usize] {
141        match self {
142            CowShape::Borrowed(dims) => dims,
143            CowShape::Owned(shape) => shape.dims(),
144        }
145    }
146
147    /// Convert to owned shape (allocates if borrowed)
148    #[inline]
149    pub fn into_owned(self) -> Shape {
150        match self {
151            CowShape::Borrowed(dims) => Shape::new(dims.to_vec()),
152            CowShape::Owned(shape) => shape,
153        }
154    }
155
156    /// Get number of elements
157    #[inline]
158    pub fn numel(&self) -> usize {
159        self.as_slice().iter().product()
160    }
161}
162
/// Allocation statistics for hot path analysis.
///
/// Each recorded allocation is counted in exactly one size class: small
/// (`bytes < 64`), medium (`64 <= bytes < 1024`) or large (`bytes >= 1024`).
#[derive(Debug, Default, Clone, Copy)]
pub struct AllocationStats {
    /// Total allocations observed
    pub total_allocations: u64,
    /// Total bytes allocated
    pub total_bytes: u64,
    /// Allocations the caller flagged as avoidable
    pub avoidable_allocations: u64,
    /// Bytes of the allocations the caller flagged as avoidable
    pub avoidable_bytes: u64,
    /// Small allocations (`bytes < 64`)
    pub small_allocations: u64,
    /// Medium allocations (`64 <= bytes < 1024`)
    pub medium_allocations: u64,
    /// Large allocations (`bytes >= 1024`, so exactly 1024 bytes counts as large)
    pub large_allocations: u64,
}

impl AllocationStats {
    /// Exclusive upper bound, in bytes, of the small size class.
    const SMALL_LIMIT: usize = 64;
    /// Exclusive upper bound, in bytes, of the medium size class.
    const MEDIUM_LIMIT: usize = 1024;

    /// Record one allocation of `bytes` bytes.
    ///
    /// When `avoidable` is true, the allocation is also added to the
    /// `avoidable_*` counters.
    #[inline]
    pub fn record_allocation(&mut self, bytes: usize, avoidable: bool) {
        // usize -> u64 is lossless on targets whose pointers are at most 64 bits wide.
        let size = bytes as u64;

        self.total_allocations += 1;
        self.total_bytes += size;

        if avoidable {
            self.avoidable_allocations += 1;
            self.avoidable_bytes += size;
        }

        if bytes < Self::SMALL_LIMIT {
            self.small_allocations += 1;
        } else if bytes < Self::MEDIUM_LIMIT {
            self.medium_allocations += 1;
        } else {
            self.large_allocations += 1;
        }
    }

    /// Percentage of recorded bytes that were flagged as avoidable.
    ///
    /// Returns 0.0 when no bytes have been recorded. For counters filled only by
    /// [`AllocationStats::record_allocation`], the result lies in 0.0..=100.0.
    pub fn waste_percentage(&self) -> f64 {
        if self.total_bytes == 0 {
            0.0
        } else {
            (self.avoidable_bytes as f64 / self.total_bytes as f64) * 100.0
        }
    }

    /// Multi-line, human-readable summary of all counters.
    pub fn report(&self) -> String {
        format!(
            "Allocation Statistics:\n\
             Total: {} allocations, {} bytes\n\
             Avoidable: {} allocations, {} bytes ({:.1}% waste)\n\
             Size distribution: {} small, {} medium, {} large",
            self.total_allocations,
            self.total_bytes,
            self.avoidable_allocations,
            self.avoidable_bytes,
            self.waste_percentage(),
            self.small_allocations,
            self.medium_allocations,
            self.large_allocations
        )
    }
}
230
#[cfg(feature = "std")]
thread_local! {
    /// Per-thread allocation counters used by [`track_allocation`],
    /// [`get_allocation_stats`] and [`reset_allocation_stats`].
    static ALLOC_STATS: RefCell<AllocationStats> = RefCell::new(AllocationStats::default());
}

/// Add one allocation of `bytes` bytes to the calling thread's statistics.
///
/// The counters are thread-local, so allocations recorded on other threads are
/// not visible here.
#[cfg(feature = "std")]
#[inline]
pub fn track_allocation(bytes: usize, avoidable: bool) {
    ALLOC_STATS.with(|cell| cell.borrow_mut().record_allocation(bytes, avoidable));
}

/// Return a copy of the calling thread's current allocation statistics.
#[cfg(feature = "std")]
pub fn get_allocation_stats() -> AllocationStats {
    ALLOC_STATS.with(|stats_cell| -> AllocationStats { *stats_cell.borrow() })
}

/// Reset the calling thread's allocation statistics to all-zero defaults.
#[cfg(feature = "std")]
pub fn reset_allocation_stats() {
    ALLOC_STATS.with(|cell| {
        let _previous = cell.replace(AllocationStats::default());
    });
}
259
/// Thread-safe pool of reusable `Vec<T>` buffers for temporary allocations.
///
/// [`BufferPool::release`] clears a buffer (length 0, capacity kept) and stores
/// it for a later [`BufferPool::acquire`]. At most `max_pool_size` idle buffers
/// are retained; extra buffers are dropped.
#[cfg(feature = "std")]
pub struct BufferPool<T> {
    /// Idle buffers waiting to be reused
    buffers: Mutex<Vec<Vec<T>>>,
    /// Maximum number of idle buffers retained
    max_pool_size: usize,
    /// Capacity requested for buffers allocated when no idle buffer is available
    buffer_capacity: usize,
}

// No bounds on `T` are needed: the pool only moves, clears and allocates vectors.
#[cfg(feature = "std")]
impl<T> BufferPool<T> {
    /// Create an empty pool.
    ///
    /// `buffer_capacity` is the capacity of buffers freshly allocated by
    /// [`BufferPool::acquire`]. `max_pool_size` caps how many released buffers are
    /// kept for reuse.
    pub fn new(buffer_capacity: usize, max_pool_size: usize) -> Self {
        Self {
            buffers: Mutex::new(Vec::new()),
            max_pool_size,
            buffer_capacity,
        }
    }

    /// Lock the idle-buffer list, recovering the guard if the mutex is poisoned.
    ///
    /// The protected value is a list of cleared buffers that no operation here
    /// leaves half-modified. A panic while the lock was held is therefore no
    /// reason to make the whole pool permanently unusable.
    fn lock_buffers(&self) -> std::sync::MutexGuard<'_, Vec<Vec<T>>> {
        self.buffers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Take an idle buffer, or allocate one with `buffer_capacity` capacity.
    ///
    /// Recycled buffers are empty but keep the capacity they had when released.
    pub fn acquire(&self) -> Vec<T> {
        // Pop under the lock, but allocate only after the guard is dropped. This
        // way a failing allocation (e.g. capacity overflow) cannot poison the
        // mutex, and other threads are not blocked during the allocation.
        let recycled = self.lock_buffers().pop();
        recycled.unwrap_or_else(|| Vec::with_capacity(self.buffer_capacity))
    }

    /// Return a buffer to the pool.
    ///
    /// The buffer is cleared first. It is dropped instead of stored when the pool
    /// already holds `max_pool_size` idle buffers.
    pub fn release(&self, mut buffer: Vec<T>) {
        // Clear (and run element destructors) before taking the lock.
        buffer.clear();

        let mut buffers = self.lock_buffers();
        if buffers.len() < self.max_pool_size {
            buffers.push(buffer);
        }
        // Otherwise `buffer` is dropped when this function returns.
    }

    /// Pool statistics as `(idle_buffers, max_pool_size)`.
    pub fn stats(&self) -> (usize, usize) {
        let idle = self.lock_buffers().len();
        (idle, self.max_pool_size)
    }
}
307
/// Number of idle buffers the global shape buffer pool keeps for reuse.
#[cfg(feature = "std")]
const SHAPE_BUFFER_POOL_SIZE: usize = 100;

/// Global pool of `Vec<usize>` buffers for temporary shape operations.
///
/// Freshly allocated buffers get capacity [`MAX_STACK_DIMS`], so shapes with up
/// to that many dimensions fit without reallocating.
#[cfg(feature = "std")]
static SHAPE_BUFFER_POOL: once_cell::sync::Lazy<BufferPool<usize>> =
    once_cell::sync::Lazy::new(|| BufferPool::new(MAX_STACK_DIMS, SHAPE_BUFFER_POOL_SIZE));

/// Take an empty shape buffer from the global pool, allocating one if the pool
/// has none idle.
#[cfg(feature = "std")]
#[inline]
pub fn acquire_shape_buffer() -> Vec<usize> {
    SHAPE_BUFFER_POOL.acquire()
}

/// Hand a shape buffer back to the global pool.
///
/// The pool clears it, and drops it if the pool is already full.
#[cfg(feature = "std")]
#[inline]
pub fn release_shape_buffer(buffer: Vec<usize>) {
    SHAPE_BUFFER_POOL.release(buffer);
}
326
/// Guard that takes a buffer from a `'static` [`BufferPool`] and returns it to
/// that pool when dropped.
///
/// Derefs to the underlying `Vec<T>`, so it can be used like the vector itself.
#[cfg(feature = "std")]
#[must_use = "dropping a ScopedBuffer immediately returns its buffer to the pool"]
pub struct ScopedBuffer<T: Clone + Default + 'static> {
    /// The borrowed buffer; becomes `None` only when `drop` takes it
    buffer: Option<Vec<T>>,
    /// Pool the buffer is returned to on drop
    pool: &'static BufferPool<T>,
}

#[cfg(feature = "std")]
impl<T: Clone + Default + 'static> ScopedBuffer<T> {
    /// Acquire a buffer from `pool`; it is released back to `pool` on drop.
    pub fn new(pool: &'static BufferPool<T>) -> Self {
        Self {
            buffer: Some(pool.acquire()),
            pool,
        }
    }

    /// Mutable access to the buffer.
    pub fn get_mut(&mut self) -> &mut Vec<T> {
        self.buffer
            .as_mut()
            .expect("buffer should be present before drop")
    }

    /// Shared access to the buffer.
    pub fn get(&self) -> &Vec<T> {
        self.buffer
            .as_ref()
            .expect("buffer should be present before drop")
    }
}

#[cfg(feature = "std")]
impl<T: Clone + Default + 'static> std::ops::Deref for ScopedBuffer<T> {
    type Target = Vec<T>;

    fn deref(&self) -> &Vec<T> {
        self.get()
    }
}

#[cfg(feature = "std")]
impl<T: Clone + Default + 'static> std::ops::DerefMut for ScopedBuffer<T> {
    fn deref_mut(&mut self) -> &mut Vec<T> {
        self.get_mut()
    }
}

#[cfg(feature = "std")]
impl<T: Clone + Default + 'static> Drop for ScopedBuffer<T> {
    /// Return the buffer to its pool (the pool clears it).
    fn drop(&mut self) {
        if let Some(buffer) = self.buffer.take() {
            self.pool.release(buffer);
        }
    }
}
367
368/// Optimization recommendations based on allocation patterns
369#[derive(Debug, Clone)]
370pub struct OptimizationRecommendations {
371    /// Use stack allocation for small shapes
372    pub use_stack_shapes: bool,
373    /// Use buffer pools for temporary allocations
374    pub use_buffer_pools: bool,
375    /// Use copy-on-write for borrowed shapes
376    pub use_cow_shapes: bool,
377    /// Estimated speedup factor
378    pub estimated_speedup: f64,
379    /// Estimated memory savings (bytes)
380    pub estimated_memory_savings: u64,
381}
382
383impl OptimizationRecommendations {
384    /// Analyze allocation stats and generate recommendations
385    pub fn from_stats(stats: &AllocationStats) -> Self {
386        let use_stack_shapes = stats.small_allocations > stats.total_allocations / 2;
387        let use_buffer_pools = stats.avoidable_allocations > stats.total_allocations / 3;
388        let use_cow_shapes = stats.total_allocations > 100;
389
390        let mut estimated_speedup = 1.0;
391        if use_stack_shapes {
392            estimated_speedup *= 2.0;
393        }
394        if use_buffer_pools {
395            estimated_speedup *= 1.5;
396        }
397        if use_cow_shapes {
398            estimated_speedup *= 1.2;
399        }
400
401        Self {
402            use_stack_shapes,
403            use_buffer_pools,
404            use_cow_shapes,
405            estimated_speedup,
406            estimated_memory_savings: stats.avoidable_bytes,
407        }
408    }
409
410    /// Generate detailed report
411    pub fn report(&self) -> String {
412        let mut recommendations = Vec::new();
413
414        if self.use_stack_shapes {
415            recommendations.push("Use StackShape for operations with ≤8 dimensions");
416        }
417        if self.use_buffer_pools {
418            recommendations.push("Use buffer pools for temporary allocations");
419        }
420        if self.use_cow_shapes {
421            recommendations.push("Use CowShape for borrowed/static shapes");
422        }
423
424        format!(
425            "Optimization Recommendations:\n\
426             {}\n\
427             Estimated speedup: {:.1}x\n\
428             Estimated memory savings: {} bytes",
429            recommendations.join("\n"),
430            self.estimated_speedup,
431            self.estimated_memory_savings
432        )
433    }
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stack_shape_creation() {
        let shape = StackShape::<4>::new([2, 3, 4, 5]);
        assert_eq!(shape.ndim, 4);
        assert_eq!(shape.as_slice(), &[2, 3, 4, 5]);
        assert_eq!(shape.numel(), 120);
    }

    #[test]
    fn test_stack_shape_from_slice() {
        let dims = vec![2, 3, 4];
        let shape = StackShape::<8>::from_slice(&dims).expect("from_slice should succeed");
        assert_eq!(shape.ndim, 3);
        assert_eq!(shape.as_slice(), &[2, 3, 4]);
    }

    #[test]
    fn test_stack_shape_from_slice_rejects_too_many_dims() {
        assert!(StackShape::<2>::from_slice(&[1, 2, 3]).is_none());
        let exact = StackShape::<3>::from_slice(&[1, 2, 3]).expect("exactly N dims should fit");
        assert_eq!(exact, StackShape::<3>::new([1, 2, 3]));
    }

    #[test]
    fn test_stack_shape_scalar() {
        let scalar = StackShape::<4>::from_slice(&[]).expect("empty slice should fit");
        assert_eq!(scalar.ndim, 0);
        assert!(scalar.as_slice().is_empty());
        assert_eq!(scalar.numel(), 1);

        // A scalar broadcasts against any shape, in either direction.
        let other = StackShape::<3>::new([2, 3, 4]);
        assert!(scalar.broadcast_compatible(&other));
        assert!(other.broadcast_compatible(&scalar));
    }

    #[test]
    fn test_stack_shape_broadcast_compatible() {
        let shape1 = StackShape::<4>::new([3, 1, 4, 1]);
        let shape2 = StackShape::<3>::from_slice(&[2, 4, 5]).expect("from_slice should succeed");

        // Compatible: [3,1,4,1] broadcasts with [2,4,5] -> [3,2,4,5]
        assert!(shape1.broadcast_compatible(&shape2));

        let shape3 = StackShape::<3>::from_slice(&[1, 4, 5]).expect("from_slice should succeed");
        assert!(shape1.broadcast_compatible(&shape3));

        // Incompatible case: different non-1 dimensions
        let shape4 = StackShape::<3>::from_slice(&[2, 3, 4]).expect("from_slice should succeed");
        let shape5 = StackShape::<3>::from_slice(&[2, 5, 4]).expect("from_slice should succeed");
        assert!(!shape4.broadcast_compatible(&shape5)); // 3 vs 5 in middle dimension
    }

    #[test]
    fn test_cow_shape_borrowed() {
        static DIMS: [usize; 3] = [2, 3, 4];
        let cow = CowShape::from_static(&DIMS);
        assert_eq!(cow.as_slice(), &[2, 3, 4]);
        assert_eq!(cow.numel(), 24);
    }

    #[test]
    fn test_cow_shape_owned() {
        let shape = Shape::new(vec![2, 3, 4]);
        let cow = CowShape::from_owned(shape);
        assert_eq!(cow.as_slice(), &[2, 3, 4]);
    }

    #[test]
    fn test_cow_shape_into_owned() {
        static DIMS: [usize; 2] = [5, 6];
        let owned = CowShape::from_static(&DIMS).into_owned();
        assert_eq!(CowShape::from_owned(owned).as_slice(), &[5, 6]);
    }

    #[test]
    fn test_allocation_stats() {
        let mut stats = AllocationStats::default();

        // Record some allocations
        stats.record_allocation(32, true); // Small, avoidable
        stats.record_allocation(128, false); // Medium, unavoidable
        stats.record_allocation(2048, true); // Large, avoidable

        assert_eq!(stats.total_allocations, 3);
        assert_eq!(stats.avoidable_allocations, 2);
        assert_eq!(stats.small_allocations, 1);
        assert_eq!(stats.medium_allocations, 1);
        assert_eq!(stats.large_allocations, 1);

        let waste = stats.waste_percentage();
        assert!(waste > 90.0); // Most allocations were avoidable
    }

    #[test]
    fn test_allocation_stats_size_class_boundaries() {
        let mut stats = AllocationStats::default();
        stats.record_allocation(63, false); // small
        stats.record_allocation(64, false); // medium: lower bound is inclusive
        stats.record_allocation(1023, false); // medium
        stats.record_allocation(1024, false); // large: exactly 1024 bytes is large
        assert_eq!(stats.small_allocations, 1);
        assert_eq!(stats.medium_allocations, 2);
        assert_eq!(stats.large_allocations, 1);
        assert_eq!(stats.waste_percentage(), 0.0);
    }

    #[test]
    fn test_waste_percentage_without_allocations() {
        assert_eq!(AllocationStats::default().waste_percentage(), 0.0);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_buffer_pool() {
        let pool = BufferPool::<usize>::new(10, 5);

        // Acquire and release buffers
        let mut buffer1 = pool.acquire();
        buffer1.extend_from_slice(&[1, 2, 3]);
        pool.release(buffer1);

        // Pool should have 1 buffer after release
        let (available, max) = pool.stats();
        assert_eq!(available, 1);
        assert_eq!(max, 5);

        // Acquire again - should get the recycled buffer
        let buffer2 = pool.acquire();
        assert!(buffer2.is_empty()); // Should be cleared

        // Pool should now be empty (buffer2 still held)
        let (available, _) = pool.stats();
        assert_eq!(available, 0);

        // Release buffer2
        pool.release(buffer2);

        // Pool should have 1 buffer again
        let (available, _) = pool.stats();
        assert_eq!(available, 1);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_buffer_pool_drops_excess_buffers() {
        let pool = BufferPool::<usize>::new(4, 1);
        pool.release(Vec::new());
        pool.release(Vec::new()); // pool already full: this buffer is dropped
        assert_eq!(pool.stats(), (1, 1));
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_scoped_buffer() {
        static POOL: once_cell::sync::Lazy<BufferPool<usize>> =
            once_cell::sync::Lazy::new(|| BufferPool::new(10, 5));

        {
            let mut scoped = ScopedBuffer::new(&*POOL);
            scoped.get_mut().push(42);
            assert_eq!(scoped.get()[0], 42);
        }
        // Buffer automatically returned to pool on drop

        let (available, _) = POOL.stats();
        assert_eq!(available, 1);
    }

    #[test]
    fn test_optimization_recommendations() {
        let mut stats = AllocationStats::default();

        // Simulate many small avoidable allocations
        for _ in 0..100 {
            stats.record_allocation(32, true);
        }

        let recommendations = OptimizationRecommendations::from_stats(&stats);
        assert!(recommendations.use_stack_shapes);
        assert!(recommendations.use_buffer_pools);
        assert!(recommendations.estimated_speedup > 1.5);
    }

    #[test]
    fn test_optimization_recommendations_empty_stats() {
        let recommendations = OptimizationRecommendations::from_stats(&AllocationStats::default());
        assert!(!recommendations.use_stack_shapes);
        assert!(!recommendations.use_buffer_pools);
        assert!(!recommendations.use_cow_shapes);
        assert_eq!(recommendations.estimated_speedup, 1.0);
        assert_eq!(recommendations.estimated_memory_savings, 0);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_global_shape_buffer_pool() {
        let mut buffer = acquire_shape_buffer();
        buffer.extend_from_slice(&[1, 2, 3, 4]);
        assert_eq!(buffer.len(), 4);

        release_shape_buffer(buffer);

        // Acquire again - should get a clean buffer
        let buffer2 = acquire_shape_buffer();
        assert_eq!(buffer2.len(), 0);
        release_shape_buffer(buffer2);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_allocation_tracking() {
        reset_allocation_stats();

        track_allocation(64, false);
        track_allocation(128, true);

        let stats = get_allocation_stats();
        assert_eq!(stats.total_allocations, 2);
        assert_eq!(stats.avoidable_allocations, 1);
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_allocation_stats_are_thread_local() {
        reset_allocation_stats();
        track_allocation(16, false);

        // A freshly spawned thread starts with its own, empty counters.
        std::thread::spawn(|| {
            assert_eq!(get_allocation_stats().total_allocations, 0);
        })
        .join()
        .expect("spawned thread should not panic");

        assert_eq!(get_allocation_stats().total_allocations, 1);
    }
}