use std::cell::RefCell;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, Ordering};

33const MAX_BUFFERS_PER_SHAPE: usize = 8;
35
36struct ThreadLocalBuffer<T>
41where
42 T: bytemuck::Pod + bytemuck::Zeroable + 'static,
43{
44 pools: HashMap<String, Vec<Vec<T>>>,
46 hits: usize,
48 misses: usize,
49 total_allocations: usize,
50 total_releases: usize,
51 _phantom: PhantomData<T>,
52}
53
54impl<T> ThreadLocalBuffer<T>
55where
56 T: bytemuck::Pod + bytemuck::Zeroable + 'static,
57{
58 fn new() -> Self {
59 Self {
60 pools: HashMap::new(),
61 hits: 0,
62 misses: 0,
63 total_allocations: 0,
64 total_releases: 0,
65 _phantom: PhantomData,
66 }
67 }
68
69 fn shape_signature(shape: &[usize]) -> String {
70 shape
71 .iter()
72 .map(|s| s.to_string())
73 .collect::<Vec<_>>()
74 .join("x")
75 }
76
77 fn acquire(&mut self, shape: &[usize]) -> Vec<T> {
78 self.total_allocations += 1;
79 let sig = Self::shape_signature(shape);
80 let size: usize = shape.iter().product();
81
82 if let Some(pool) = self.pools.get_mut(&sig) {
83 if let Some(buffer) = pool.pop() {
84 self.hits += 1;
85 return buffer;
86 }
87 }
88
89 self.misses += 1;
91 vec![T::zeroed(); size]
92 }
93
94 fn release(&mut self, shape: &[usize], buffer: Vec<T>) {
95 self.total_releases += 1;
96 let sig = Self::shape_signature(shape);
97
98 let pool = self.pools.entry(sig).or_default();
99
100 if pool.len() < MAX_BUFFERS_PER_SHAPE {
102 pool.push(buffer);
103 }
104 }
106
107 fn clear(&mut self) {
108 self.pools.clear();
109 self.hits = 0;
110 self.misses = 0;
111 self.total_allocations = 0;
112 self.total_releases = 0;
113 }
114
115 fn stats(&self) -> ThreadLocalPoolStats {
116 let total = self.total_allocations;
117 let hit_rate = if total > 0 {
118 self.hits as f64 / total as f64
119 } else {
120 0.0
121 };
122
123 let mut total_bytes = 0;
124 let mut total_buffers = 0;
125 for pool in self.pools.values() {
126 total_buffers += pool.len();
127 for buffer in pool {
128 total_bytes += buffer.len() * std::mem::size_of::<T>();
129 }
130 }
131
132 ThreadLocalPoolStats {
133 hits: self.hits,
134 misses: self.misses,
135 total_allocations: self.total_allocations,
136 total_releases: self.total_releases,
137 hit_rate,
138 unique_shapes: self.pools.len(),
139 total_bytes_pooled: total_bytes,
140 total_buffers_pooled: total_buffers,
141 }
142 }
143}
144
/// Point-in-time counters for one thread's pool of a single element type.
#[derive(Clone, Debug, PartialEq)]
pub struct ThreadLocalPoolStats {
    /// Acquisitions satisfied by reusing a pooled buffer.
    pub hits: usize,
    /// Acquisitions that had to allocate a new buffer.
    pub misses: usize,
    /// Total acquisitions recorded.
    pub total_allocations: usize,
    /// Total releases recorded.
    pub total_releases: usize,
    /// `hits / total_allocations`, or `0.0` when nothing has been acquired.
    pub hit_rate: f64,
    /// Number of distinct shapes that have a pool entry.
    pub unique_shapes: usize,
    /// Bytes held by idle pooled buffers.
    pub total_bytes_pooled: usize,
    /// Count of idle pooled buffers.
    pub total_buffers_pooled: usize,
}

158#[derive(Debug, Clone, PartialEq)]
160pub struct AggregatedPoolStats {
161 pub total_threads: usize,
162 pub total_hits: usize,
163 pub total_misses: usize,
164 pub total_allocations: usize,
165 pub total_releases: usize,
166 pub overall_hit_rate: f64,
167 pub total_bytes_pooled: usize,
168 pub total_buffers_pooled: usize,
169 pub per_thread_stats: Vec<ThreadLocalPoolStats>,
170}
171
172thread_local! {
173 static F32_POOL: RefCell<ThreadLocalBuffer<f32>> = RefCell::new(ThreadLocalBuffer::new());
174 static F64_POOL: RefCell<ThreadLocalBuffer<f64>> = RefCell::new(ThreadLocalBuffer::new());
175}
176
177#[derive(Clone)]
218pub struct ThreadLocalPoolManager {
219 enabled: bool,
220}
221
222impl ThreadLocalPoolManager {
223 pub fn new() -> Self {
225 Self { enabled: true }
226 }
227
228 pub fn enable(&self) {
230 }
233
234 pub fn disable(&self) {
236 }
239
240 pub fn is_enabled(&self) -> bool {
242 self.enabled
243 }
244
245 pub fn acquire_f32(&self, shape: &[usize]) -> Vec<f32> {
247 if !self.enabled {
248 let size: usize = shape.iter().product();
249 return vec![0.0; size];
250 }
251
252 F32_POOL.with(|pool| pool.borrow_mut().acquire(shape))
253 }
254
255 pub fn release_f32(&self, shape: &[usize], buffer: Vec<f32>) {
257 if !self.enabled {
258 return;
259 }
260
261 F32_POOL.with(|pool| pool.borrow_mut().release(shape, buffer))
262 }
263
264 pub fn acquire_f64(&self, shape: &[usize]) -> Vec<f64> {
266 if !self.enabled {
267 let size: usize = shape.iter().product();
268 return vec![0.0; size];
269 }
270
271 F64_POOL.with(|pool| pool.borrow_mut().acquire(shape))
272 }
273
274 pub fn release_f64(&self, shape: &[usize], buffer: Vec<f64>) {
276 if !self.enabled {
277 return;
278 }
279
280 F64_POOL.with(|pool| pool.borrow_mut().release(shape, buffer))
281 }
282
283 pub fn thread_stats_f32(&self) -> ThreadLocalPoolStats {
285 F32_POOL.with(|pool| pool.borrow().stats())
286 }
287
288 pub fn thread_stats_f64(&self) -> ThreadLocalPoolStats {
290 F64_POOL.with(|pool| pool.borrow().stats())
291 }
292
293 pub fn clear_thread_f32(&self) {
295 F32_POOL.with(|pool| pool.borrow_mut().clear())
296 }
297
298 pub fn clear_thread_f64(&self) {
300 F64_POOL.with(|pool| pool.borrow_mut().clear())
301 }
302
303 pub fn aggregated_stats_f32(&self) -> AggregatedPoolStats {
311 let thread_stats = self.thread_stats_f32();
312
313 AggregatedPoolStats {
314 total_threads: 1, total_hits: thread_stats.hits,
316 total_misses: thread_stats.misses,
317 total_allocations: thread_stats.total_allocations,
318 total_releases: thread_stats.total_releases,
319 overall_hit_rate: thread_stats.hit_rate,
320 total_bytes_pooled: thread_stats.total_bytes_pooled,
321 total_buffers_pooled: thread_stats.total_buffers_pooled,
322 per_thread_stats: vec![thread_stats],
323 }
324 }
325
326 pub fn aggregated_stats_f64(&self) -> AggregatedPoolStats {
328 let thread_stats = self.thread_stats_f64();
329
330 AggregatedPoolStats {
331 total_threads: 1, total_hits: thread_stats.hits,
333 total_misses: thread_stats.misses,
334 total_allocations: thread_stats.total_allocations,
335 total_releases: thread_stats.total_releases,
336 overall_hit_rate: thread_stats.hit_rate,
337 total_bytes_pooled: thread_stats.total_bytes_pooled,
338 total_buffers_pooled: thread_stats.total_buffers_pooled,
339 per_thread_stats: vec![thread_stats],
340 }
341 }
342}
343
344impl Default for ThreadLocalPoolManager {
345 fn default() -> Self {
346 Self::new()
347 }
348}
349
350#[cfg(test)]
351mod tests {
352 use super::*;
353 use std::thread;
354
355 #[test]
356 fn test_thread_local_pool_basic_f64() {
357 let manager = ThreadLocalPoolManager::new();
358
359 let buf1 = manager.acquire_f64(&[100]);
360 assert_eq!(buf1.len(), 100);
361
362 let stats = manager.thread_stats_f64();
363 assert_eq!(stats.hits, 0);
364 assert_eq!(stats.misses, 1);
365
366 manager.release_f64(&[100], buf1);
367
368 let buf2 = manager.acquire_f64(&[100]);
369 let stats = manager.thread_stats_f64();
370 assert_eq!(stats.hits, 1);
371 assert_eq!(stats.misses, 1);
372
373 manager.release_f64(&[100], buf2);
374 }
375
376 #[test]
377 fn test_thread_local_pool_basic_f32() {
378 let manager = ThreadLocalPoolManager::new();
379
380 let buf1 = manager.acquire_f32(&[50]);
381 assert_eq!(buf1.len(), 50);
382
383 let stats = manager.thread_stats_f32();
384 assert_eq!(stats.hits, 0);
385 assert_eq!(stats.misses, 1);
386
387 manager.release_f32(&[50], buf1);
388
389 let buf2 = manager.acquire_f32(&[50]);
390 let stats = manager.thread_stats_f32();
391 assert_eq!(stats.hits, 1);
392
393 manager.release_f32(&[50], buf2);
394 }
395
396 #[test]
397 fn test_thread_local_pool_different_shapes() {
398 let manager = ThreadLocalPoolManager::new();
399
400 let buf1 = manager.acquire_f64(&[10, 10]);
401 let buf2 = manager.acquire_f64(&[20, 20]);
402
403 manager.release_f64(&[10, 10], buf1);
404 manager.release_f64(&[20, 20], buf2);
405
406 let stats = manager.thread_stats_f64();
407 assert_eq!(stats.unique_shapes, 2);
408 assert_eq!(stats.misses, 2);
409 }
410
411 #[test]
412 fn test_thread_local_pool_multithread_isolation() {
413 let manager = ThreadLocalPoolManager::new();
414
415 let buf = manager.acquire_f64(&[100]);
417 manager.release_f64(&[100], buf);
418
419 let main_stats = manager.thread_stats_f64();
420 assert_eq!(main_stats.hits, 0);
421 assert_eq!(main_stats.misses, 1);
422
423 let manager_clone = manager.clone();
425 let handle = thread::spawn(move || {
426 let buf = manager_clone.acquire_f64(&[100]);
427 let stats = manager_clone.thread_stats_f64();
428 assert_eq!(stats.hits, 0); assert_eq!(stats.misses, 1);
430 manager_clone.release_f64(&[100], buf);
431
432 let buf2 = manager_clone.acquire_f64(&[100]);
434 let stats2 = manager_clone.thread_stats_f64();
435 assert_eq!(stats2.hits, 1); manager_clone.release_f64(&[100], buf2);
437 });
438
439 handle.join().unwrap();
440
441 let main_stats_after = manager.thread_stats_f64();
443 assert_eq!(main_stats_after.hits, main_stats.hits);
444 assert_eq!(main_stats_after.misses, main_stats.misses);
445 }
446
447 #[test]
448 fn test_thread_local_pool_clear() {
449 let manager = ThreadLocalPoolManager::new();
450
451 let buf = manager.acquire_f64(&[100]);
452 manager.release_f64(&[100], buf);
453
454 let stats_before = manager.thread_stats_f64();
455 assert_eq!(stats_before.total_buffers_pooled, 1);
456
457 manager.clear_thread_f64();
458
459 let stats_after = manager.thread_stats_f64();
460 assert_eq!(stats_after.total_buffers_pooled, 0);
461 assert_eq!(stats_after.hits, 0);
462 assert_eq!(stats_after.misses, 0);
463 }
464
465 #[test]
466 fn test_thread_local_pool_max_buffers_limit() {
467 let manager = ThreadLocalPoolManager::new();
468
469 for _ in 0..(MAX_BUFFERS_PER_SHAPE + 5) {
471 let buf = manager.acquire_f64(&[100]);
472 manager.release_f64(&[100], buf);
473 }
474
475 let stats = manager.thread_stats_f64();
476 assert!(stats.total_buffers_pooled <= MAX_BUFFERS_PER_SHAPE);
477 }
478
479 #[test]
480 fn test_thread_local_pool_parallel_scalability() {
481 let manager = ThreadLocalPoolManager::new();
482 let num_threads = 4;
483 let iterations = 100;
484
485 let handles: Vec<_> = (0..num_threads)
486 .map(|_| {
487 let mgr = manager.clone();
488 thread::spawn(move || {
489 for _ in 0..iterations {
490 let buf = mgr.acquire_f64(&[1000]);
491 mgr.release_f64(&[1000], buf);
492 }
493 mgr.thread_stats_f64()
494 })
495 })
496 .collect();
497
498 let thread_stats: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();
499
500 for stats in &thread_stats {
502 assert!(stats.hit_rate >= 0.99); }
504
505 for stats in &thread_stats {
507 assert_eq!(stats.misses, 1);
508 assert_eq!(stats.hits, iterations - 1);
509 }
510 }
511
512 #[test]
513 fn test_thread_local_pool_disabled() {
514 let mut manager = ThreadLocalPoolManager::new();
515 manager.enabled = false;
516
517 let buf1 = manager.acquire_f64(&[100]);
518 manager.release_f64(&[100], buf1);
519
520 let buf2 = manager.acquire_f64(&[100]);
522 manager.release_f64(&[100], buf2);
523
524 let stats = manager.thread_stats_f64();
525 assert_eq!(stats.total_allocations, 0);
527 }
528
529 #[test]
530 fn test_thread_local_pool_mixed_types() {
531 let manager = ThreadLocalPoolManager::new();
532
533 let buf_f32 = manager.acquire_f32(&[100]);
535 let buf_f64 = manager.acquire_f64(&[100]);
536
537 manager.release_f32(&[100], buf_f32);
538 manager.release_f64(&[100], buf_f64);
539
540 let stats_f32 = manager.thread_stats_f32();
542 let stats_f64 = manager.thread_stats_f64();
543
544 assert_eq!(stats_f32.misses, 1);
545 assert_eq!(stats_f64.misses, 1);
546 assert_eq!(stats_f32.total_buffers_pooled, 1);
547 assert_eq!(stats_f64.total_buffers_pooled, 1);
548 }
549
550 #[test]
551 fn test_aggregated_stats() {
552 let manager = ThreadLocalPoolManager::new();
553
554 for _ in 0..10 {
555 let buf = manager.acquire_f64(&[100]);
556 manager.release_f64(&[100], buf);
557 }
558
559 let agg_stats = manager.aggregated_stats_f64();
560 assert_eq!(agg_stats.total_threads, 1);
561 assert_eq!(agg_stats.total_allocations, 10);
562 assert_eq!(agg_stats.total_hits, 9);
563 assert_eq!(agg_stats.total_misses, 1);
564 assert!(agg_stats.overall_hit_rate >= 0.9);
565 }
566}