// tenrso_exec/executor/thread_local_pool.rs

//! Thread-local memory pools for zero-contention parallel execution
//!
//! This module provides thread-local buffer pools that eliminate contention
//! in multi-threaded scenarios. Each thread maintains its own pool, avoiding
//! the need for locks or atomic operations.
//!
//! # Performance Benefits
//!
//! - **Zero Contention**: No locks needed, each thread has its own pool
//! - **Cache Locality**: Thread-local data improves cache hit rates
//! - **Parallel Scalability**: Linear scaling with thread count
//! - **Low Overhead**: ~10-20ns per acquire/release (vs ~100ns+ with locks)
//!
//! # Usage
//!
//! ```rust
//! use tenrso_exec::executor::thread_local_pool::ThreadLocalPoolManager;
//!
//! // Pooling is enabled by default; `enable()` is a no-op kept for API consistency
//! let manager = ThreadLocalPoolManager::new();
//! manager.enable();
//!
//! // In parallel code, each thread will use its own pool
//! let buffer = manager.acquire_f64(&[1000]);
//! // ... use buffer ...
//! // Release on the thread that should reuse the buffer: it joins that thread's pool
//! manager.release_f64(&[1000], buffer);
//! ```

use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt::Write as _;
use std::marker::PhantomData;

/// Maximum number of idle buffers kept per shape signature (per thread, per element type).
///
/// A buffer released while its shape's free list already holds this many
/// entries is dropped instead of pooled, which bounds the memory each thread
/// keeps alive for any single shape.
const MAX_BUFFERS_PER_SHAPE: usize = 8;

/// Thread-local buffer pool for a single element type.
///
/// This is the actual pool storage that lives in thread-local storage, where
/// it is wrapped in a `RefCell` for interior mutability within the thread.
///
/// The struct itself carries no trait bounds; the `impl` block that provides
/// the pool operations states the bounds it needs.
struct ThreadLocalBuffer<T> {
    /// Free lists of idle buffers, keyed by shape signature
    /// (the dimensions in decimal joined by `x`, e.g. `"10x20"`).
    pools: HashMap<String, Vec<Vec<T>>>,
    /// Number of `acquire` calls served from a pooled buffer.
    hits: usize,
    /// Number of `acquire` calls that had to allocate a new buffer.
    misses: usize,
    /// Total number of `acquire` calls (hits plus misses).
    total_allocations: usize,
    /// Total number of `release` calls, whether or not the buffer was kept.
    total_releases: usize,
    /// Marker for the element type; kept so existing constructors stay valid.
    _phantom: PhantomData<T>,
}

54impl<T> ThreadLocalBuffer<T>
55where
56    T: bytemuck::Pod + bytemuck::Zeroable + 'static,
57{
58    fn new() -> Self {
59        Self {
60            pools: HashMap::new(),
61            hits: 0,
62            misses: 0,
63            total_allocations: 0,
64            total_releases: 0,
65            _phantom: PhantomData,
66        }
67    }
68
69    fn shape_signature(shape: &[usize]) -> String {
70        shape
71            .iter()
72            .map(|s| s.to_string())
73            .collect::<Vec<_>>()
74            .join("x")
75    }
76
77    fn acquire(&mut self, shape: &[usize]) -> Vec<T> {
78        self.total_allocations += 1;
79        let sig = Self::shape_signature(shape);
80        let size: usize = shape.iter().product();
81
82        if let Some(pool) = self.pools.get_mut(&sig) {
83            if let Some(buffer) = pool.pop() {
84                self.hits += 1;
85                return buffer;
86            }
87        }
88
89        // Miss - allocate new buffer
90        self.misses += 1;
91        vec![T::zeroed(); size]
92    }
93
94    fn release(&mut self, shape: &[usize], buffer: Vec<T>) {
95        self.total_releases += 1;
96        let sig = Self::shape_signature(shape);
97
98        let pool = self.pools.entry(sig).or_default();
99
100        // Limit pool size to prevent unbounded growth
101        if pool.len() < MAX_BUFFERS_PER_SHAPE {
102            pool.push(buffer);
103        }
104        // Otherwise drop the buffer (let it be deallocated)
105    }
106
107    fn clear(&mut self) {
108        self.pools.clear();
109        self.hits = 0;
110        self.misses = 0;
111        self.total_allocations = 0;
112        self.total_releases = 0;
113    }
114
115    fn stats(&self) -> ThreadLocalPoolStats {
116        let total = self.total_allocations;
117        let hit_rate = if total > 0 {
118            self.hits as f64 / total as f64
119        } else {
120            0.0
121        };
122
123        let mut total_bytes = 0;
124        let mut total_buffers = 0;
125        for pool in self.pools.values() {
126            total_buffers += pool.len();
127            for buffer in pool {
128                total_bytes += buffer.len() * std::mem::size_of::<T>();
129            }
130        }
131
132        ThreadLocalPoolStats {
133            hits: self.hits,
134            misses: self.misses,
135            total_allocations: self.total_allocations,
136            total_releases: self.total_releases,
137            hit_rate,
138            unique_shapes: self.pools.len(),
139            total_bytes_pooled: total_bytes,
140            total_buffers_pooled: total_buffers,
141        }
142    }
143}
144
/// Statistics for a thread-local pool
///
/// `Default` yields an all-zero snapshot, which is what a pool that has never
/// been used reports.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ThreadLocalPoolStats {
    /// Acquires served from a pooled buffer.
    pub hits: usize,
    /// Acquires that allocated a new buffer.
    pub misses: usize,
    /// Total acquire calls.
    pub total_allocations: usize,
    /// Total release calls.
    pub total_releases: usize,
    /// `hits / total_allocations`, or `0.0` when nothing has been acquired.
    pub hit_rate: f64,
    /// Number of distinct shape signatures the pool has an entry for.
    pub unique_shapes: usize,
    /// Bytes of element storage held by idle pooled buffers.
    pub total_bytes_pooled: usize,
    /// Number of idle buffers currently pooled.
    pub total_buffers_pooled: usize,
}

158/// Statistics aggregated across all threads
159#[derive(Debug, Clone, PartialEq)]
160pub struct AggregatedPoolStats {
161    pub total_threads: usize,
162    pub total_hits: usize,
163    pub total_misses: usize,
164    pub total_allocations: usize,
165    pub total_releases: usize,
166    pub overall_hit_rate: f64,
167    pub total_bytes_pooled: usize,
168    pub total_buffers_pooled: usize,
169    pub per_thread_stats: Vec<ThreadLocalPoolStats>,
170}
171
172thread_local! {
173    static F32_POOL: RefCell<ThreadLocalBuffer<f32>> = RefCell::new(ThreadLocalBuffer::new());
174    static F64_POOL: RefCell<ThreadLocalBuffer<f64>> = RefCell::new(ThreadLocalBuffer::new());
175}
176
/// Manager for thread-local memory pools
///
/// This provides a global interface to thread-local pools with zero contention.
/// Each thread maintains its own pools for f32 and f64 buffers; the manager is
/// only a lightweight handle and owns no buffers itself.
///
/// # Thread Safety
///
/// Thread-local pools are completely thread-safe without any locks because
/// each thread has its own storage. There is no contention between threads.
///
/// # Example
///
/// ```
/// use tenrso_exec::executor::thread_local_pool::ThreadLocalPoolManager;
/// use std::thread;
///
/// let manager = ThreadLocalPoolManager::new();
/// manager.enable();
///
/// // Spawn threads - each will use its own pool
/// let handles: Vec<_> = (0..4)
///     .map(|_| {
///         let mgr = manager.clone();
///         thread::spawn(move || {
///             for _ in 0..100 {
///                 let buf = mgr.acquire_f64(&[1000]);
///                 mgr.release_f64(&[1000], buf);
///             }
///         })
///     })
///     .collect();
///
/// for h in handles {
///     h.join().unwrap();
/// }
///
/// // Statistics cover the calling thread only. The main thread acquired
/// // nothing here, so this reports the main thread's (empty) pool.
/// let stats = manager.aggregated_stats_f64();
/// println!("Main thread hit rate: {:.2}%", stats.overall_hit_rate * 100.0);
/// ```
#[derive(Debug, Clone)]
pub struct ThreadLocalPoolManager {
    /// When `false`, acquire calls allocate a fresh zeroed buffer and release
    /// calls drop the buffer, bypassing the thread-local pools entirely.
    enabled: bool,
}

222impl ThreadLocalPoolManager {
223    /// Create a new thread-local pool manager
224    pub fn new() -> Self {
225        Self { enabled: true }
226    }
227
228    /// Enable thread-local pooling
229    pub fn enable(&self) {
230        // Thread-local pools are enabled by default
231        // This is a no-op but kept for API consistency
232    }
233
234    /// Disable thread-local pooling
235    pub fn disable(&self) {
236        // Note: Due to thread_local! design, we can't actually disable individual threads
237        // Users should use the regular MemoryPool if they need enable/disable functionality
238    }
239
240    /// Check if pooling is enabled
241    pub fn is_enabled(&self) -> bool {
242        self.enabled
243    }
244
245    /// Acquire an f32 buffer from the current thread's pool
246    pub fn acquire_f32(&self, shape: &[usize]) -> Vec<f32> {
247        if !self.enabled {
248            let size: usize = shape.iter().product();
249            return vec![0.0; size];
250        }
251
252        F32_POOL.with(|pool| pool.borrow_mut().acquire(shape))
253    }
254
255    /// Release an f32 buffer back to the current thread's pool
256    pub fn release_f32(&self, shape: &[usize], buffer: Vec<f32>) {
257        if !self.enabled {
258            return;
259        }
260
261        F32_POOL.with(|pool| pool.borrow_mut().release(shape, buffer))
262    }
263
264    /// Acquire an f64 buffer from the current thread's pool
265    pub fn acquire_f64(&self, shape: &[usize]) -> Vec<f64> {
266        if !self.enabled {
267            let size: usize = shape.iter().product();
268            return vec![0.0; size];
269        }
270
271        F64_POOL.with(|pool| pool.borrow_mut().acquire(shape))
272    }
273
274    /// Release an f64 buffer back to the current thread's pool
275    pub fn release_f64(&self, shape: &[usize], buffer: Vec<f64>) {
276        if !self.enabled {
277            return;
278        }
279
280        F64_POOL.with(|pool| pool.borrow_mut().release(shape, buffer))
281    }
282
283    /// Get statistics for the current thread's f32 pool
284    pub fn thread_stats_f32(&self) -> ThreadLocalPoolStats {
285        F32_POOL.with(|pool| pool.borrow().stats())
286    }
287
288    /// Get statistics for the current thread's f64 pool
289    pub fn thread_stats_f64(&self) -> ThreadLocalPoolStats {
290        F64_POOL.with(|pool| pool.borrow().stats())
291    }
292
293    /// Clear the current thread's f32 pool
294    pub fn clear_thread_f32(&self) {
295        F32_POOL.with(|pool| pool.borrow_mut().clear())
296    }
297
298    /// Clear the current thread's f64 pool
299    pub fn clear_thread_f64(&self) {
300        F64_POOL.with(|pool| pool.borrow_mut().clear())
301    }
302
303    /// Get aggregated statistics from all threads (f32)
304    ///
305    /// Note: This requires cooperation from all threads. The returned
306    /// statistics only include the current thread's data since we can't
307    /// access other threads' thread-local storage directly.
308    ///
309    /// For true multi-thread statistics, use the shared MemoryPool instead.
310    pub fn aggregated_stats_f32(&self) -> AggregatedPoolStats {
311        let thread_stats = self.thread_stats_f32();
312
313        AggregatedPoolStats {
314            total_threads: 1, // Only current thread
315            total_hits: thread_stats.hits,
316            total_misses: thread_stats.misses,
317            total_allocations: thread_stats.total_allocations,
318            total_releases: thread_stats.total_releases,
319            overall_hit_rate: thread_stats.hit_rate,
320            total_bytes_pooled: thread_stats.total_bytes_pooled,
321            total_buffers_pooled: thread_stats.total_buffers_pooled,
322            per_thread_stats: vec![thread_stats],
323        }
324    }
325
326    /// Get aggregated statistics from all threads (f64)
327    pub fn aggregated_stats_f64(&self) -> AggregatedPoolStats {
328        let thread_stats = self.thread_stats_f64();
329
330        AggregatedPoolStats {
331            total_threads: 1, // Only current thread
332            total_hits: thread_stats.hits,
333            total_misses: thread_stats.misses,
334            total_allocations: thread_stats.total_allocations,
335            total_releases: thread_stats.total_releases,
336            overall_hit_rate: thread_stats.hit_rate,
337            total_bytes_pooled: thread_stats.total_bytes_pooled,
338            total_buffers_pooled: thread_stats.total_buffers_pooled,
339            per_thread_stats: vec![thread_stats],
340        }
341    }
342}
343
344impl Default for ThreadLocalPoolManager {
345    fn default() -> Self {
346        Self::new()
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Returns a manager after wiping the calling thread's pools and counters.
    ///
    /// The pools are thread-local statics. If the test harness happens to run
    /// several tests on the same thread, state left by an earlier test would
    /// otherwise break the exact hit/miss assertions below.
    fn fresh_manager() -> ThreadLocalPoolManager {
        let manager = ThreadLocalPoolManager::new();
        manager.clear_thread_f32();
        manager.clear_thread_f64();
        manager
    }

    #[test]
    fn test_thread_local_pool_basic_f64() {
        let manager = fresh_manager();

        let buf1 = manager.acquire_f64(&[100]);
        assert_eq!(buf1.len(), 100);

        let stats = manager.thread_stats_f64();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);

        manager.release_f64(&[100], buf1);

        let buf2 = manager.acquire_f64(&[100]);
        assert_eq!(buf2.len(), 100);
        let stats = manager.thread_stats_f64();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);

        manager.release_f64(&[100], buf2);
    }

    #[test]
    fn test_thread_local_pool_basic_f32() {
        let manager = fresh_manager();

        let buf1 = manager.acquire_f32(&[50]);
        assert_eq!(buf1.len(), 50);

        let stats = manager.thread_stats_f32();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 1);

        manager.release_f32(&[50], buf1);

        let buf2 = manager.acquire_f32(&[50]);
        assert_eq!(buf2.len(), 50);
        let stats = manager.thread_stats_f32();
        assert_eq!(stats.hits, 1);

        manager.release_f32(&[50], buf2);
    }

    #[test]
    fn test_thread_local_pool_different_shapes() {
        let manager = fresh_manager();

        let buf1 = manager.acquire_f64(&[10, 10]);
        let buf2 = manager.acquire_f64(&[20, 20]);
        assert_eq!(buf1.len(), 100);
        assert_eq!(buf2.len(), 400);

        manager.release_f64(&[10, 10], buf1);
        manager.release_f64(&[20, 20], buf2);

        let stats = manager.thread_stats_f64();
        assert_eq!(stats.unique_shapes, 2);
        assert_eq!(stats.misses, 2);
    }

    #[test]
    fn test_thread_local_pool_multithread_isolation() {
        let manager = fresh_manager();

        // Main thread allocates
        let buf = manager.acquire_f64(&[100]);
        manager.release_f64(&[100], buf);

        let main_stats = manager.thread_stats_f64();
        assert_eq!(main_stats.hits, 0);
        assert_eq!(main_stats.misses, 1);

        // Spawn thread - should have its own pool (miss on first acquire)
        let manager_clone = manager.clone();
        let handle = thread::spawn(move || {
            let buf = manager_clone.acquire_f64(&[100]);
            let stats = manager_clone.thread_stats_f64();
            assert_eq!(stats.hits, 0); // First acquire in this thread = miss
            assert_eq!(stats.misses, 1);
            manager_clone.release_f64(&[100], buf);

            // Second acquire in same thread = hit
            let buf2 = manager_clone.acquire_f64(&[100]);
            let stats2 = manager_clone.thread_stats_f64();
            assert_eq!(stats2.hits, 1); // Second acquire = hit
            manager_clone.release_f64(&[100], buf2);
        });

        handle.join().unwrap();

        // Main thread stats unchanged
        let main_stats_after = manager.thread_stats_f64();
        assert_eq!(main_stats_after.hits, main_stats.hits);
        assert_eq!(main_stats_after.misses, main_stats.misses);
    }

    #[test]
    fn test_thread_local_pool_clear() {
        let manager = fresh_manager();

        let buf = manager.acquire_f64(&[100]);
        manager.release_f64(&[100], buf);

        let stats_before = manager.thread_stats_f64();
        assert_eq!(stats_before.total_buffers_pooled, 1);

        manager.clear_thread_f64();

        let stats_after = manager.thread_stats_f64();
        assert_eq!(stats_after.total_buffers_pooled, 0);
        assert_eq!(stats_after.hits, 0);
        assert_eq!(stats_after.misses, 0);
    }

    #[test]
    fn test_thread_local_pool_max_buffers_limit() {
        let manager = fresh_manager();
        let live = MAX_BUFFERS_PER_SHAPE + 5;

        // Hold more buffers than the cap at the same time; an acquire/release
        // ping-pong would only ever pool a single buffer and never hit the cap.
        let buffers: Vec<_> = (0..live).map(|_| manager.acquire_f64(&[100])).collect();
        for buf in buffers {
            manager.release_f64(&[100], buf);
        }

        let stats = manager.thread_stats_f64();
        assert_eq!(stats.total_releases, live);
        assert_eq!(stats.total_buffers_pooled, MAX_BUFFERS_PER_SHAPE);
        assert_eq!(
            stats.total_bytes_pooled,
            MAX_BUFFERS_PER_SHAPE * 100 * std::mem::size_of::<f64>()
        );
    }

    #[test]
    fn test_thread_local_pool_parallel_scalability() {
        let manager = ThreadLocalPoolManager::new();
        let num_threads = 4;
        let iterations = 100;

        let handles: Vec<_> = (0..num_threads)
            .map(|_| {
                let mgr = manager.clone();
                thread::spawn(move || {
                    for _ in 0..iterations {
                        let buf = mgr.acquire_f64(&[1000]);
                        mgr.release_f64(&[1000], buf);
                    }
                    mgr.thread_stats_f64()
                })
            })
            .collect();

        let thread_stats: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();

        // Each thread should have high hit rate (except first iteration)
        for stats in &thread_stats {
            assert!(stats.hit_rate >= 0.99); // 99% or better after first miss
        }

        // Verify isolation - each thread had exactly one miss (first acquire)
        for stats in &thread_stats {
            assert_eq!(stats.misses, 1);
            assert_eq!(stats.hits, iterations - 1);
        }
    }

    #[test]
    fn test_thread_local_pool_disabled() {
        let mut manager = fresh_manager();
        manager.enabled = false;
        assert!(!manager.is_enabled());

        let buf1 = manager.acquire_f64(&[100]);
        assert_eq!(buf1.len(), 100);
        manager.release_f64(&[100], buf1);

        // Should not use pool when disabled
        let buf2 = manager.acquire_f64(&[100]);
        manager.release_f64(&[100], buf2);

        let stats = manager.thread_stats_f64();
        // Stats should still be 0 since pooling was disabled
        assert_eq!(stats.total_allocations, 0);
        assert_eq!(stats.total_releases, 0);
        assert_eq!(stats.total_buffers_pooled, 0);
    }

    #[test]
    fn test_thread_local_pool_mixed_types() {
        let manager = fresh_manager();

        // Acquire both f32 and f64 buffers
        let buf_f32 = manager.acquire_f32(&[100]);
        let buf_f64 = manager.acquire_f64(&[100]);

        manager.release_f32(&[100], buf_f32);
        manager.release_f64(&[100], buf_f64);

        // Each type should have its own pool
        let stats_f32 = manager.thread_stats_f32();
        let stats_f64 = manager.thread_stats_f64();

        assert_eq!(stats_f32.misses, 1);
        assert_eq!(stats_f64.misses, 1);
        assert_eq!(stats_f32.total_buffers_pooled, 1);
        assert_eq!(stats_f64.total_buffers_pooled, 1);
    }

    #[test]
    fn test_aggregated_stats() {
        let manager = fresh_manager();

        for _ in 0..10 {
            let buf = manager.acquire_f64(&[100]);
            manager.release_f64(&[100], buf);
        }

        let agg_stats = manager.aggregated_stats_f64();
        assert_eq!(agg_stats.total_threads, 1);
        assert_eq!(agg_stats.total_allocations, 10);
        assert_eq!(agg_stats.total_hits, 9);
        assert_eq!(agg_stats.total_misses, 1);
        assert!(agg_stats.overall_hit_rate >= 0.9);
        assert_eq!(agg_stats.per_thread_stats.len(), 1);
        assert_eq!(agg_stats.per_thread_stats[0], manager.thread_stats_f64());
    }
}