tenrso_exec/executor/
pool_heuristics.rs

//! Smart pooling heuristics for automatic buffer pool management
//!
//! This module provides heuristics for deciding when and how to use
//! memory pooling, to maximize performance while minimizing memory overhead.
//!
//! # Heuristics
//!
//! - **Size-based**: Pool buffers whose size lies between a minimum and a maximum threshold
//! - **Frequency-based**: Pool shapes that are accessed frequently
//! - **Memory-aware**: Adjust pooling thresholds based on available memory
//! - **Operation-aware**: Consider operation characteristics
//!
//! # Example
//!
//! ```
//! use tenrso_exec::executor::pool_heuristics::PoolingPolicy;
//!
//! let policy = PoolingPolicy::default();
//!
//! // Should we pool a 1000-element f64 buffer?
//! if policy.should_pool(&[1000], std::mem::size_of::<f64>()) {
//!     println!("Pooling recommended");
//! }
//! ```
25
26use std::collections::HashMap;
27
/// Smallest buffer size, in bytes, that is worth pooling (4 KiB).
///
/// Below this size a fresh allocation is already cheap, so pooling buys little.
pub const DEFAULT_MIN_POOL_SIZE_BYTES: usize = 4 * 1024;

/// Largest buffer size, in bytes, that will be pooled (64 MiB).
///
/// Keeping very large buffers alive in a pool can create memory pressure.
pub const DEFAULT_MAX_POOL_SIZE_BYTES: usize = 64 << 20;

/// Minimum number of accesses a shape needs before it is worth pooling.
pub const DEFAULT_MIN_ACCESS_FREQUENCY: usize = 2;
39
/// Policy for determining when to use buffer pooling.
///
/// A buffer is a pooling candidate when its total size in bytes lies within
/// `min_size_bytes..=max_size_bytes` (both bounds inclusive) and, for
/// frequency-aware checks, its shape has been seen at least `min_frequency` times.
#[derive(Debug, Clone, PartialEq)]
pub struct PoolingPolicy {
    /// Minimum buffer size in bytes to pool (inclusive)
    pub min_size_bytes: usize,
    /// Maximum buffer size in bytes to pool (inclusive)
    pub max_size_bytes: usize,
    /// Minimum access frequency for a shape (inclusive)
    pub min_frequency: usize,
    /// Available-memory ratio (0.0-1.0) below which thresholds are tightened.
    ///
    /// When the ratio given to `with_memory_pressure` is below this value,
    /// pooling becomes more conservative.
    pub memory_pressure_threshold: f64,
    /// Enable adaptive thresholds based on runtime behavior.
    ///
    /// When `false`, `with_memory_pressure` returns an unchanged copy.
    pub adaptive: bool,
}
55
56impl PoolingPolicy {
57    /// Create a new pooling policy with default settings
58    pub fn new() -> Self {
59        Self {
60            min_size_bytes: DEFAULT_MIN_POOL_SIZE_BYTES,
61            max_size_bytes: DEFAULT_MAX_POOL_SIZE_BYTES,
62            min_frequency: DEFAULT_MIN_ACCESS_FREQUENCY,
63            memory_pressure_threshold: 0.2, // Pool less when <20% memory free
64            adaptive: true,
65        }
66    }
67
68    /// Create a conservative policy (pool only large, frequently accessed buffers)
69    pub fn conservative() -> Self {
70        Self {
71            min_size_bytes: 16384,            // 16KB minimum
72            max_size_bytes: 32 * 1024 * 1024, // 32MB max
73            min_frequency: 5,
74            memory_pressure_threshold: 0.3,
75            adaptive: false,
76        }
77    }
78
79    /// Create an aggressive policy (pool most buffers)
80    pub fn aggressive() -> Self {
81        Self {
82            min_size_bytes: 1024,              // 1KB minimum
83            max_size_bytes: 128 * 1024 * 1024, // 128MB max
84            min_frequency: 1,
85            memory_pressure_threshold: 0.1,
86            adaptive: true,
87        }
88    }
89
90    /// Create a memory-constrained policy (minimize memory usage)
91    pub fn memory_constrained() -> Self {
92        Self {
93            min_size_bytes: 8192,             // 8KB minimum
94            max_size_bytes: 16 * 1024 * 1024, // 16MB max
95            min_frequency: 3,
96            memory_pressure_threshold: 0.4,
97            adaptive: false,
98        }
99    }
100
101    /// Determine if a buffer should be pooled based on its characteristics
102    ///
103    /// # Arguments
104    ///
105    /// * `shape` - Buffer shape
106    /// * `elem_size` - Size of each element in bytes
107    ///
108    /// # Returns
109    ///
110    /// `true` if the buffer should be pooled, `false` otherwise
111    pub fn should_pool(&self, shape: &[usize], elem_size: usize) -> bool {
112        let total_elements: usize = shape.iter().product();
113        let total_bytes = total_elements * elem_size;
114
115        // Check size bounds
116        if total_bytes < self.min_size_bytes {
117            return false; // Too small - allocation overhead is minimal
118        }
119
120        if total_bytes > self.max_size_bytes {
121            return false; // Too large - can cause memory pressure
122        }
123
124        true
125    }
126
127    /// Check if we should pool based on access frequency
128    ///
129    /// This should be called after tracking access patterns.
130    pub fn should_pool_with_frequency(
131        &self,
132        shape: &[usize],
133        elem_size: usize,
134        frequency: usize,
135    ) -> bool {
136        if !self.should_pool(shape, elem_size) {
137            return false;
138        }
139
140        frequency >= self.min_frequency
141    }
142
143    /// Adjust thresholds based on memory pressure
144    ///
145    /// # Arguments
146    ///
147    /// * `available_memory_ratio` - Fraction of memory available (0.0-1.0)
148    ///
149    /// # Returns
150    ///
151    /// Adjusted policy with modified thresholds
152    pub fn with_memory_pressure(&self, available_memory_ratio: f64) -> Self {
153        if !self.adaptive {
154            return self.clone();
155        }
156
157        let mut adjusted = self.clone();
158
159        if available_memory_ratio < self.memory_pressure_threshold {
160            // High memory pressure - be more conservative
161            adjusted.min_size_bytes *= 2;
162            adjusted.max_size_bytes /= 2;
163            adjusted.min_frequency += 2;
164        } else if available_memory_ratio > 0.5 {
165            // Low memory pressure - can be more aggressive
166            adjusted.min_size_bytes = adjusted.min_size_bytes.saturating_sub(1024);
167            adjusted.max_size_bytes = adjusted.max_size_bytes.saturating_mul(2);
168        }
169
170        adjusted
171    }
172}
173
174impl Default for PoolingPolicy {
175    fn default() -> Self {
176        Self::new()
177    }
178}
179
/// Records how often each buffer shape is requested.
///
/// The per-shape counts feed shape-based pooling decisions
/// (see [`PoolingRecommender`]).
#[derive(Clone, Debug)]
pub struct AccessPatternTracker {
    /// Access count keyed by shape signature (dimensions joined with `x`, e.g. `"2x3"`)
    access_counts: HashMap<String, usize>,
    /// Number of accesses recorded across all shapes
    total_accesses: usize,
}
191
192impl AccessPatternTracker {
193    /// Create a new access pattern tracker
194    pub fn new() -> Self {
195        Self {
196            access_counts: HashMap::new(),
197            total_accesses: 0,
198        }
199    }
200
201    fn shape_signature(shape: &[usize]) -> String {
202        shape
203            .iter()
204            .map(|s| s.to_string())
205            .collect::<Vec<_>>()
206            .join("x")
207    }
208
209    /// Record an access to a shape
210    pub fn record_access(&mut self, shape: &[usize]) {
211        let sig = Self::shape_signature(shape);
212        *self.access_counts.entry(sig).or_insert(0) += 1;
213        self.total_accesses += 1;
214    }
215
216    /// Get access frequency for a shape
217    pub fn get_frequency(&self, shape: &[usize]) -> usize {
218        let sig = Self::shape_signature(shape);
219        *self.access_counts.get(&sig).unwrap_or(&0)
220    }
221
222    /// Get top N most frequently accessed shapes
223    pub fn top_shapes(&self, n: usize) -> Vec<(String, usize)> {
224        let mut sorted: Vec<_> = self
225            .access_counts
226            .iter()
227            .map(|(k, v)| (k.clone(), *v))
228            .collect();
229        sorted.sort_by(|a, b| b.1.cmp(&a.1));
230        sorted.truncate(n);
231        sorted
232    }
233
234    /// Clear all recorded access patterns
235    pub fn clear(&mut self) {
236        self.access_counts.clear();
237        self.total_accesses = 0;
238    }
239
240    /// Get total number of unique shapes accessed
241    pub fn num_unique_shapes(&self) -> usize {
242        self.access_counts.len()
243    }
244
245    /// Get total number of accesses
246    pub fn total_accesses(&self) -> usize {
247        self.total_accesses
248    }
249
250    /// Get access distribution (shape -> frequency ratio)
251    pub fn access_distribution(&self) -> HashMap<String, f64> {
252        if self.total_accesses == 0 {
253            return HashMap::new();
254        }
255
256        self.access_counts
257            .iter()
258            .map(|(k, &v)| {
259                let ratio = v as f64 / self.total_accesses as f64;
260                (k.clone(), ratio)
261            })
262            .collect()
263    }
264}
265
266impl Default for AccessPatternTracker {
267    fn default() -> Self {
268        Self::new()
269    }
270}
271
272/// Automatic pooling recommendation engine
273///
274/// Analyzes access patterns and provides recommendations for
275/// optimal pooling configuration.
276pub struct PoolingRecommender {
277    policy: PoolingPolicy,
278    tracker: AccessPatternTracker,
279}
280
281impl PoolingRecommender {
282    /// Create a new recommender with default policy
283    pub fn new() -> Self {
284        Self {
285            policy: PoolingPolicy::default(),
286            tracker: AccessPatternTracker::new(),
287        }
288    }
289
290    /// Create a recommender with custom policy
291    pub fn with_policy(policy: PoolingPolicy) -> Self {
292        Self {
293            policy,
294            tracker: AccessPatternTracker::new(),
295        }
296    }
297
298    /// Record a buffer allocation
299    pub fn record_allocation(&mut self, shape: &[usize]) {
300        self.tracker.record_access(shape);
301    }
302
303    /// Get shapes that should be pooled based on recorded patterns
304    pub fn recommend_shapes(&self, elem_size: usize) -> Vec<String> {
305        let mut recommendations = Vec::new();
306
307        for (shape_sig, &frequency) in &self.tracker.access_counts {
308            // Parse shape signature back to shape
309            let shape: Vec<usize> = shape_sig
310                .split('x')
311                .filter_map(|s| s.parse().ok())
312                .collect();
313
314            if self
315                .policy
316                .should_pool_with_frequency(&shape, elem_size, frequency)
317            {
318                recommendations.push(shape_sig.clone());
319            }
320        }
321
322        recommendations
323    }
324
325    /// Generate a report with pooling recommendations
326    pub fn generate_report(&self, elem_size: usize) -> PoolingReport {
327        let recommended_shapes = self.recommend_shapes(elem_size);
328        let top_shapes = self.tracker.top_shapes(10);
329
330        let total_poolable_accesses: usize = recommended_shapes
331            .iter()
332            .filter_map(|sig| self.tracker.access_counts.get(sig))
333            .sum();
334
335        let potential_hit_rate = if self.tracker.total_accesses > 0 {
336            total_poolable_accesses as f64 / self.tracker.total_accesses as f64
337        } else {
338            0.0
339        };
340
341        PoolingReport {
342            total_shapes_accessed: self.tracker.num_unique_shapes(),
343            total_accesses: self.tracker.total_accesses(),
344            recommended_shapes_count: recommended_shapes.len(),
345            recommended_shapes,
346            top_10_shapes: top_shapes,
347            potential_hit_rate,
348            policy: self.policy.clone(),
349        }
350    }
351
352    /// Clear all tracking data
353    pub fn clear(&mut self) {
354        self.tracker.clear();
355    }
356}
357
358impl Default for PoolingRecommender {
359    fn default() -> Self {
360        Self::new()
361    }
362}
363
364/// Report with pooling recommendations
365#[derive(Debug, Clone)]
366pub struct PoolingReport {
367    /// Total number of unique shapes accessed
368    pub total_shapes_accessed: usize,
369    /// Total number of allocation requests
370    pub total_accesses: usize,
371    /// Number of shapes recommended for pooling
372    pub recommended_shapes_count: usize,
373    /// List of recommended shape signatures
374    pub recommended_shapes: Vec<String>,
375    /// Top 10 most frequently accessed shapes
376    pub top_10_shapes: Vec<(String, usize)>,
377    /// Potential hit rate if pooling is enabled for recommended shapes
378    pub potential_hit_rate: f64,
379    /// Policy used for recommendations
380    pub policy: PoolingPolicy,
381}
382
383impl PoolingReport {
384    /// Print a formatted report to console
385    pub fn print(&self) {
386        println!("=== Pooling Recommendation Report ===");
387        println!("Total shapes accessed: {}", self.total_shapes_accessed);
388        println!("Total allocations: {}", self.total_accesses);
389        println!(
390            "Recommended for pooling: {} shapes",
391            self.recommended_shapes_count
392        );
393        println!(
394            "Potential hit rate: {:.1}%",
395            self.potential_hit_rate * 100.0
396        );
397        println!("\nTop 10 most accessed shapes:");
398        for (i, (shape, count)) in self.top_10_shapes.iter().enumerate() {
399            let is_recommended = self.recommended_shapes.contains(shape);
400            let marker = if is_recommended { "✓" } else { " " };
401            println!("  {}. [{}] {} - {} accesses", i + 1, marker, shape, count);
402        }
403        println!("\nPolicy settings:");
404        println!("  Min size: {} bytes", self.policy.min_size_bytes);
405        println!("  Max size: {} bytes", self.policy.max_size_bytes);
406        println!("  Min frequency: {}", self.policy.min_frequency);
407    }
408}
409
#[cfg(test)]
mod tests {
    // Unit tests for the pooling policy, the access-pattern tracker and the
    // recommender defined in this module.

    use super::*;

    #[test]
    fn test_pooling_policy_default() {
        let policy = PoolingPolicy::default();

        // Small buffer - should not pool
        assert!(!policy.should_pool(&[10], 8)); // 80 bytes

        // Medium buffer - should pool
        assert!(policy.should_pool(&[1000], 8)); // 8000 bytes

        // Very large buffer - should not pool (exceeds max)
        assert!(!policy.should_pool(&[10_000_000], 8)); // 80 MB
    }

    #[test]
    fn test_pooling_policy_size_bounds_are_inclusive() {
        let policy = PoolingPolicy::default();

        assert!(!policy.should_pool(&[DEFAULT_MIN_POOL_SIZE_BYTES - 1], 1));
        assert!(policy.should_pool(&[DEFAULT_MIN_POOL_SIZE_BYTES], 1));
        assert!(policy.should_pool(&[DEFAULT_MAX_POOL_SIZE_BYTES], 1));
        assert!(!policy.should_pool(&[DEFAULT_MAX_POOL_SIZE_BYTES + 1], 1));

        // Multi-dimensional shapes are sized by the product of their dimensions.
        assert!(policy.should_pool(&[64, 64], 1)); // 4096 bytes
        assert!(!policy.should_pool(&[0, 1_000_000], 8)); // zero-sized buffer
    }

    #[test]
    fn test_pooling_policy_conservative() {
        let policy = PoolingPolicy::conservative();

        // More restrictive than default
        assert!(!policy.should_pool(&[1000], 8)); // 8000 bytes - below 16 KiB min
        assert!(policy.should_pool(&[5000], 8)); // 40000 bytes - within range
    }

    #[test]
    fn test_pooling_policy_aggressive() {
        let policy = PoolingPolicy::aggressive();

        // Pool almost everything
        assert!(policy.should_pool(&[200], 8)); // 1600 bytes
        assert!(policy.should_pool(&[10_000], 8)); // 80000 bytes
    }

    #[test]
    fn test_pooling_policy_with_frequency() {
        let policy = PoolingPolicy::default();

        // Good size but low frequency
        assert!(!policy.should_pool_with_frequency(&[1000], 8, 1));

        // The frequency threshold is inclusive
        assert!(policy.should_pool_with_frequency(&[1000], 8, DEFAULT_MIN_ACCESS_FREQUENCY));

        // Good size and good frequency
        assert!(policy.should_pool_with_frequency(&[1000], 8, 5));

        // A high frequency does not rescue a buffer outside the size window
        assert!(!policy.should_pool_with_frequency(&[10], 8, 100));
    }

    #[test]
    fn test_pooling_policy_memory_pressure() {
        let policy = PoolingPolicy::default();

        // Low memory - adjust to be more conservative
        let adjusted_low = policy.with_memory_pressure(0.1); // 10% memory free
        assert!(adjusted_low.min_size_bytes > policy.min_size_bytes);
        assert!(adjusted_low.max_size_bytes < policy.max_size_bytes);
        assert!(adjusted_low.min_frequency > policy.min_frequency);

        // Moderate memory - thresholds stay as they are
        let adjusted_mid = policy.with_memory_pressure(0.3); // 30% memory free
        assert_eq!(adjusted_mid.min_size_bytes, policy.min_size_bytes);
        assert_eq!(adjusted_mid.max_size_bytes, policy.max_size_bytes);
        assert_eq!(adjusted_mid.min_frequency, policy.min_frequency);

        // High memory - can be more aggressive
        let adjusted_high = policy.with_memory_pressure(0.8); // 80% memory free
        assert!(adjusted_high.min_size_bytes <= policy.min_size_bytes);
        assert!(adjusted_high.max_size_bytes >= policy.max_size_bytes);
    }

    #[test]
    fn test_pooling_policy_memory_pressure_ignored_when_not_adaptive() {
        let policy = PoolingPolicy::conservative();
        assert!(!policy.adaptive);

        for &ratio in &[0.0_f64, 0.25, 0.9] {
            let adjusted = policy.with_memory_pressure(ratio);
            assert_eq!(adjusted.min_size_bytes, policy.min_size_bytes);
            assert_eq!(adjusted.max_size_bytes, policy.max_size_bytes);
            assert_eq!(adjusted.min_frequency, policy.min_frequency);
        }
    }

    #[test]
    fn test_access_pattern_tracker() {
        let mut tracker = AccessPatternTracker::new();

        tracker.record_access(&[100]);
        tracker.record_access(&[100]);
        tracker.record_access(&[200]);

        assert_eq!(tracker.get_frequency(&[100]), 2);
        assert_eq!(tracker.get_frequency(&[200]), 1);
        assert_eq!(tracker.get_frequency(&[300]), 0);
        assert_eq!(tracker.total_accesses(), 3);
        assert_eq!(tracker.num_unique_shapes(), 2);
    }

    #[test]
    fn test_access_pattern_multi_dim_signature() {
        let mut tracker = AccessPatternTracker::new();

        tracker.record_access(&[2, 3]);
        tracker.record_access(&[2, 3]);

        // Dimension order matters: [2, 3] and [3, 2] are different shapes.
        assert_eq!(tracker.get_frequency(&[2, 3]), 2);
        assert_eq!(tracker.get_frequency(&[3, 2]), 0);
        assert_eq!(tracker.top_shapes(1), vec![("2x3".to_string(), 2)]);
    }

    #[test]
    fn test_access_pattern_top_shapes() {
        let mut tracker = AccessPatternTracker::new();

        for _ in 0..10 {
            tracker.record_access(&[100]);
        }
        for _ in 0..5 {
            tracker.record_access(&[200]);
        }
        for _ in 0..3 {
            tracker.record_access(&[300]);
        }

        let top = tracker.top_shapes(2);
        assert_eq!(top.len(), 2);
        assert_eq!(top[0].1, 10); // [100] accessed 10 times
        assert_eq!(top[1].1, 5); // [200] accessed 5 times
    }

    #[test]
    fn test_access_pattern_clear() {
        let mut tracker = AccessPatternTracker::new();

        tracker.record_access(&[100]);
        tracker.clear();

        assert_eq!(tracker.total_accesses(), 0);
        assert_eq!(tracker.num_unique_shapes(), 0);
    }

    #[test]
    fn test_pooling_recommender_basic() {
        let mut recommender = PoolingRecommender::new();

        // Record accesses
        for _ in 0..10 {
            recommender.record_allocation(&[1000]); // 8000-byte buffer
        }
        for _ in 0..2 {
            recommender.record_allocation(&[100]); // 800-byte buffer (too small)
        }

        let recommendations = recommender.recommend_shapes(8);

        // Should recommend [1000] but not [100]
        assert!(recommendations.contains(&"1000".to_string()));
        assert!(!recommendations.contains(&"100".to_string()));
    }

    #[test]
    fn test_pooling_recommender_report() {
        let mut recommender = PoolingRecommender::new();

        for _ in 0..20 {
            recommender.record_allocation(&[1000]);
        }
        for _ in 0..10 {
            recommender.record_allocation(&[2000]);
        }

        let report = recommender.generate_report(8);

        assert_eq!(report.total_accesses, 30);
        assert_eq!(report.total_shapes_accessed, 2);
        // 8000-byte and 16000-byte buffers, each accessed often enough
        assert_eq!(report.recommended_shapes_count, 2);
        assert!((report.potential_hit_rate - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_pooling_recommender_empty_report() {
        let report = PoolingRecommender::new().generate_report(8);

        assert_eq!(report.total_accesses, 0);
        assert_eq!(report.total_shapes_accessed, 0);
        assert!(report.recommended_shapes.is_empty());
        assert!(report.top_10_shapes.is_empty());
        assert!(report.potential_hit_rate.abs() < 1e-12);
    }

    #[test]
    fn test_pooling_recommender_conservative_vs_aggressive() {
        let mut recommender_conservative =
            PoolingRecommender::with_policy(PoolingPolicy::conservative());
        let mut recommender_aggressive =
            PoolingRecommender::with_policy(PoolingPolicy::aggressive());

        // Small buffer accessed frequently
        for _ in 0..10 {
            recommender_conservative.record_allocation(&[500]); // 4000 bytes
            recommender_aggressive.record_allocation(&[500]);
        }

        let rec_conservative = recommender_conservative.recommend_shapes(8);
        let rec_aggressive = recommender_aggressive.recommend_shapes(8);

        // Aggressive should recommend small buffers, conservative should not
        assert_eq!(rec_aggressive, vec!["500".to_string()]);
        assert!(rec_conservative.is_empty());
    }

    #[test]
    fn test_access_distribution() {
        let mut tracker = AccessPatternTracker::new();

        for _ in 0..50 {
            tracker.record_access(&[100]);
        }
        for _ in 0..30 {
            tracker.record_access(&[200]);
        }
        for _ in 0..20 {
            tracker.record_access(&[300]);
        }

        let dist = tracker.access_distribution();

        assert_eq!(dist.len(), 3);
        assert!((dist["100"] - 0.5).abs() < 0.01); // 50/100 = 0.5
        assert!((dist["200"] - 0.3).abs() < 0.01); // 30/100 = 0.3
        assert!((dist["300"] - 0.2).abs() < 0.01); // 20/100 = 0.2
    }
}