1use std::collections::HashMap;
27
/// Default smallest allocation size, in bytes, considered for pooling (4 KiB).
pub const DEFAULT_MIN_POOL_SIZE_BYTES: usize = 4096;

/// Default largest allocation size, in bytes, considered for pooling (64 MiB).
pub const DEFAULT_MAX_POOL_SIZE_BYTES: usize = 64 * 1024 * 1024;

/// Default minimum number of recorded accesses before a shape is worth pooling.
pub const DEFAULT_MIN_ACCESS_FREQUENCY: usize = 2;

/// Size- and frequency-based rules deciding whether an allocation should be pooled.
///
/// An allocation is described by its `shape` (dimension sizes) and the size of one
/// element in bytes; its total size is the product of all of those values.
#[derive(Debug, Clone)]
pub struct PoolingPolicy {
    /// Smallest total size, in bytes, that is eligible for pooling (inclusive).
    pub min_size_bytes: usize,
    /// Largest total size, in bytes, that is eligible for pooling (inclusive).
    pub max_size_bytes: usize,
    /// Minimum access count required by [`PoolingPolicy::should_pool_with_frequency`]
    /// (inclusive).
    pub min_frequency: usize,
    /// Available-memory ratio below which [`PoolingPolicy::with_memory_pressure`]
    /// tightens the limits.
    pub memory_pressure_threshold: f64,
    /// Whether [`PoolingPolicy::with_memory_pressure`] may adjust the limits at all.
    pub adaptive: bool,
}

impl PoolingPolicy {
    /// Creates the default policy, built from the `DEFAULT_*` constants, with a
    /// memory-pressure threshold of `0.2` and adaptation enabled.
    pub fn new() -> Self {
        Self {
            min_size_bytes: DEFAULT_MIN_POOL_SIZE_BYTES,
            max_size_bytes: DEFAULT_MAX_POOL_SIZE_BYTES,
            min_frequency: DEFAULT_MIN_ACCESS_FREQUENCY,
            memory_pressure_threshold: 0.2,
            adaptive: true,
        }
    }

    /// Creates a non-adaptive policy with a narrower window than the default:
    /// 16 KiB to 32 MiB, at least 5 accesses, pressure threshold `0.3`.
    pub fn conservative() -> Self {
        Self {
            min_size_bytes: 16384,
            max_size_bytes: 32 * 1024 * 1024,
            min_frequency: 5,
            memory_pressure_threshold: 0.3,
            adaptive: false,
        }
    }

    /// Creates an adaptive policy with a wider window than the default:
    /// 1 KiB to 128 MiB, at least 1 access, pressure threshold `0.1`.
    pub fn aggressive() -> Self {
        Self {
            min_size_bytes: 1024,
            max_size_bytes: 128 * 1024 * 1024,
            min_frequency: 1,
            memory_pressure_threshold: 0.1,
            adaptive: true,
        }
    }

    /// Creates a non-adaptive policy with a small upper bound:
    /// 8 KiB to 16 MiB, at least 3 accesses, pressure threshold `0.4`.
    pub fn memory_constrained() -> Self {
        Self {
            min_size_bytes: 8192,
            max_size_bytes: 16 * 1024 * 1024,
            min_frequency: 3,
            memory_pressure_threshold: 0.4,
            adaptive: false,
        }
    }

    /// Computes `elem_size * product(shape)` in bytes.
    ///
    /// Returns `None` when the true size does not fit in `usize`. Any zero factor
    /// yields `Some(0)` regardless of the other factors, so an empty allocation is
    /// never reported as an overflow.
    fn total_bytes(shape: &[usize], elem_size: usize) -> Option<usize> {
        if elem_size == 0 || shape.contains(&0) {
            return Some(0);
        }
        shape
            .iter()
            .try_fold(elem_size, |bytes, &dim| bytes.checked_mul(dim))
    }

    /// Returns `true` when an allocation of `shape` with `elem_size`-byte elements
    /// falls inside `[min_size_bytes, max_size_bytes]`.
    ///
    /// A size that overflows `usize` is treated as larger than any limit and is
    /// therefore never pooled; this avoids both the debug-build overflow panic and
    /// a wrapped value accidentally landing inside the window in release builds.
    pub fn should_pool(&self, shape: &[usize], elem_size: usize) -> bool {
        match Self::total_bytes(shape, elem_size) {
            Some(bytes) => self.min_size_bytes <= bytes && bytes <= self.max_size_bytes,
            None => false,
        }
    }

    /// Like [`PoolingPolicy::should_pool`], but additionally requires that
    /// `frequency` is at least `min_frequency`.
    pub fn should_pool_with_frequency(
        &self,
        shape: &[usize],
        elem_size: usize,
        frequency: usize,
    ) -> bool {
        self.should_pool(shape, elem_size) && frequency >= self.min_frequency
    }

    /// Returns a copy of this policy adjusted for the given available-memory ratio.
    ///
    /// * Non-adaptive policies are returned unchanged.
    /// * Below `memory_pressure_threshold` the policy is tightened: the minimum
    ///   size doubles, the maximum size halves and the minimum frequency grows by 2.
    /// * Above `0.5` the policy is relaxed: the minimum size drops by 1024 bytes
    ///   and the maximum size doubles.
    /// * Otherwise (including a NaN ratio, for which both comparisons are false)
    ///   the limits are left as they are.
    ///
    /// All adjustments saturate at the bounds of `usize` instead of overflowing.
    pub fn with_memory_pressure(&self, available_memory_ratio: f64) -> Self {
        let mut adjusted = self.clone();
        if !self.adaptive {
            return adjusted;
        }

        if available_memory_ratio < self.memory_pressure_threshold {
            adjusted.min_size_bytes = adjusted.min_size_bytes.saturating_mul(2);
            adjusted.max_size_bytes /= 2;
            adjusted.min_frequency = adjusted.min_frequency.saturating_add(2);
        } else if available_memory_ratio > 0.5 {
            adjusted.min_size_bytes = adjusted.min_size_bytes.saturating_sub(1024);
            adjusted.max_size_bytes = adjusted.max_size_bytes.saturating_mul(2);
        }

        adjusted
    }
}

impl Default for PoolingPolicy {
    /// Equivalent to [`PoolingPolicy::new`].
    fn default() -> Self {
        Self::new()
    }
}
179
/// Counts how many times each shape has been accessed.
///
/// Shapes are keyed by a textual signature: the dimensions joined with `x`
/// (for example `[2, 3]` becomes `"2x3"`; an empty shape becomes `""`).
#[derive(Debug, Clone)]
pub struct AccessPatternTracker {
    /// Access count per shape signature.
    access_counts: HashMap<String, usize>,
    /// Sum of all recorded accesses across every shape.
    total_accesses: usize,
}

impl AccessPatternTracker {
    /// Creates a tracker with no recorded accesses.
    pub fn new() -> Self {
        Self {
            access_counts: HashMap::new(),
            total_accesses: 0,
        }
    }

    /// Builds the map key for `shape`: its dimensions joined with `x`.
    fn shape_signature(shape: &[usize]) -> String {
        shape
            .iter()
            .map(|dim| dim.to_string())
            .collect::<Vec<_>>()
            .join("x")
    }

    /// Records one access of `shape`.
    pub fn record_access(&mut self, shape: &[usize]) {
        let signature = Self::shape_signature(shape);
        *self.access_counts.entry(signature).or_insert(0) += 1;
        self.total_accesses += 1;
    }

    /// Returns how many times `shape` has been recorded, or `0` if never.
    pub fn get_frequency(&self, shape: &[usize]) -> usize {
        let signature = Self::shape_signature(shape);
        self.access_counts.get(&signature).copied().unwrap_or(0)
    }

    /// Returns up to `n` `(signature, count)` pairs, most frequently accessed first.
    ///
    /// Shapes with equal counts are ordered by signature (ascending, lexicographic)
    /// so the result does not depend on the hash map's arbitrary iteration order
    /// and is identical from run to run.
    pub fn top_shapes(&self, n: usize) -> Vec<(String, usize)> {
        let mut ranked: Vec<(&String, usize)> = self
            .access_counts
            .iter()
            .map(|(signature, &count)| (signature, count))
            .collect();
        ranked.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));
        // Clone only the entries that survive the cut-off.
        ranked
            .into_iter()
            .take(n)
            .map(|(signature, count)| (signature.clone(), count))
            .collect()
    }

    /// Forgets every recorded access.
    pub fn clear(&mut self) {
        self.access_counts.clear();
        self.total_accesses = 0;
    }

    /// Returns the number of distinct shape signatures recorded so far.
    pub fn num_unique_shapes(&self) -> usize {
        self.access_counts.len()
    }

    /// Returns the total number of recorded accesses.
    pub fn total_accesses(&self) -> usize {
        self.total_accesses
    }

    /// Returns, per shape signature, its share of all recorded accesses.
    ///
    /// The map is empty when nothing has been recorded, which also avoids a
    /// division by zero.
    pub fn access_distribution(&self) -> HashMap<String, f64> {
        if self.total_accesses == 0 {
            return HashMap::new();
        }

        let total = self.total_accesses as f64;
        self.access_counts
            .iter()
            .map(|(signature, &count)| (signature.clone(), count as f64 / total))
            .collect()
    }
}

impl Default for AccessPatternTracker {
    /// Equivalent to [`AccessPatternTracker::new`].
    fn default() -> Self {
        Self::new()
    }
}
271
272pub struct PoolingRecommender {
277 policy: PoolingPolicy,
278 tracker: AccessPatternTracker,
279}
280
281impl PoolingRecommender {
282 pub fn new() -> Self {
284 Self {
285 policy: PoolingPolicy::default(),
286 tracker: AccessPatternTracker::new(),
287 }
288 }
289
290 pub fn with_policy(policy: PoolingPolicy) -> Self {
292 Self {
293 policy,
294 tracker: AccessPatternTracker::new(),
295 }
296 }
297
298 pub fn record_allocation(&mut self, shape: &[usize]) {
300 self.tracker.record_access(shape);
301 }
302
303 pub fn recommend_shapes(&self, elem_size: usize) -> Vec<String> {
305 let mut recommendations = Vec::new();
306
307 for (shape_sig, &frequency) in &self.tracker.access_counts {
308 let shape: Vec<usize> = shape_sig
310 .split('x')
311 .filter_map(|s| s.parse().ok())
312 .collect();
313
314 if self
315 .policy
316 .should_pool_with_frequency(&shape, elem_size, frequency)
317 {
318 recommendations.push(shape_sig.clone());
319 }
320 }
321
322 recommendations
323 }
324
325 pub fn generate_report(&self, elem_size: usize) -> PoolingReport {
327 let recommended_shapes = self.recommend_shapes(elem_size);
328 let top_shapes = self.tracker.top_shapes(10);
329
330 let total_poolable_accesses: usize = recommended_shapes
331 .iter()
332 .filter_map(|sig| self.tracker.access_counts.get(sig))
333 .sum();
334
335 let potential_hit_rate = if self.tracker.total_accesses > 0 {
336 total_poolable_accesses as f64 / self.tracker.total_accesses as f64
337 } else {
338 0.0
339 };
340
341 PoolingReport {
342 total_shapes_accessed: self.tracker.num_unique_shapes(),
343 total_accesses: self.tracker.total_accesses(),
344 recommended_shapes_count: recommended_shapes.len(),
345 recommended_shapes,
346 top_10_shapes: top_shapes,
347 potential_hit_rate,
348 policy: self.policy.clone(),
349 }
350 }
351
352 pub fn clear(&mut self) {
354 self.tracker.clear();
355 }
356}
357
358impl Default for PoolingRecommender {
359 fn default() -> Self {
360 Self::new()
361 }
362}
363
364#[derive(Debug, Clone)]
366pub struct PoolingReport {
367 pub total_shapes_accessed: usize,
369 pub total_accesses: usize,
371 pub recommended_shapes_count: usize,
373 pub recommended_shapes: Vec<String>,
375 pub top_10_shapes: Vec<(String, usize)>,
377 pub potential_hit_rate: f64,
379 pub policy: PoolingPolicy,
381}
382
383impl PoolingReport {
384 pub fn print(&self) {
386 println!("=== Pooling Recommendation Report ===");
387 println!("Total shapes accessed: {}", self.total_shapes_accessed);
388 println!("Total allocations: {}", self.total_accesses);
389 println!(
390 "Recommended for pooling: {} shapes",
391 self.recommended_shapes_count
392 );
393 println!(
394 "Potential hit rate: {:.1}%",
395 self.potential_hit_rate * 100.0
396 );
397 println!("\nTop 10 most accessed shapes:");
398 for (i, (shape, count)) in self.top_10_shapes.iter().enumerate() {
399 let is_recommended = self.recommended_shapes.contains(shape);
400 let marker = if is_recommended { "✓" } else { " " };
401 println!(" {}. [{}] {} - {} accesses", i + 1, marker, shape, count);
402 }
403 println!("\nPolicy settings:");
404 println!(" Min size: {} bytes", self.policy.min_size_bytes);
405 println!(" Max size: {} bytes", self.policy.max_size_bytes);
406 println!(" Min frequency: {}", self.policy.min_frequency);
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    /// Records `times` accesses of `shape` on `tracker`.
    fn record_accesses(tracker: &mut AccessPatternTracker, shape: &[usize], times: usize) {
        for _ in 0..times {
            tracker.record_access(shape);
        }
    }

    /// Records `times` allocations of `shape` on `recommender`.
    fn record_allocations(recommender: &mut PoolingRecommender, shape: &[usize], times: usize) {
        for _ in 0..times {
            recommender.record_allocation(shape);
        }
    }

    #[test]
    fn test_pooling_policy_default() {
        let default_policy = PoolingPolicy::default();

        // 80 bytes: below the 4096-byte minimum.
        assert!(!default_policy.should_pool(&[10], 8));
        // 8000 bytes: inside the window.
        assert!(default_policy.should_pool(&[1000], 8));
        // 80_000_000 bytes: above the 64 MiB maximum.
        assert!(!default_policy.should_pool(&[10_000_000], 8));
    }

    #[test]
    fn test_pooling_policy_conservative() {
        let conservative = PoolingPolicy::conservative();

        // 8000 bytes: below the 16384-byte minimum.
        assert!(!conservative.should_pool(&[1000], 8));
        // 40_000 bytes: inside the window.
        assert!(conservative.should_pool(&[5000], 8));
    }

    #[test]
    fn test_pooling_policy_aggressive() {
        let aggressive = PoolingPolicy::aggressive();

        // 1600 bytes: already above the 1024-byte minimum.
        assert!(aggressive.should_pool(&[200], 8));
        // 80_000 bytes: inside the window.
        assert!(aggressive.should_pool(&[10_000], 8));
    }

    #[test]
    fn test_pooling_policy_with_frequency() {
        let default_policy = PoolingPolicy::default();

        // Size qualifies, but one access is below the minimum frequency of 2.
        assert!(!default_policy.should_pool_with_frequency(&[1000], 8, 1));

        // Size qualifies and five accesses clear the minimum frequency.
        assert!(default_policy.should_pool_with_frequency(&[1000], 8, 5));
    }

    #[test]
    fn test_pooling_policy_memory_pressure() {
        let base = PoolingPolicy::default();

        // Ratio 0.1 is below the 0.2 threshold: limits tighten.
        let under_pressure = base.with_memory_pressure(0.1);
        assert!(under_pressure.min_size_bytes > base.min_size_bytes);
        assert!(under_pressure.max_size_bytes < base.max_size_bytes);

        // Ratio 0.8 is above 0.5: limits relax.
        let plenty_free = base.with_memory_pressure(0.8);
        assert!(plenty_free.min_size_bytes <= base.min_size_bytes);
        assert!(plenty_free.max_size_bytes >= base.max_size_bytes);
    }

    #[test]
    fn test_access_pattern_tracker() {
        let mut counts = AccessPatternTracker::new();

        record_accesses(&mut counts, &[100], 2);
        record_accesses(&mut counts, &[200], 1);

        assert_eq!(counts.get_frequency(&[100]), 2);
        assert_eq!(counts.get_frequency(&[200]), 1);
        assert_eq!(counts.get_frequency(&[300]), 0);
        assert_eq!(counts.total_accesses(), 3);
        assert_eq!(counts.num_unique_shapes(), 2);
    }

    #[test]
    fn test_access_pattern_top_shapes() {
        let mut counts = AccessPatternTracker::new();

        record_accesses(&mut counts, &[100], 10);
        record_accesses(&mut counts, &[200], 5);
        record_accesses(&mut counts, &[300], 3);

        let ranked = counts.top_shapes(2);
        assert_eq!(ranked.len(), 2);
        assert_eq!(ranked[0].1, 10);
        assert_eq!(ranked[1].1, 5);
    }

    #[test]
    fn test_access_pattern_clear() {
        let mut counts = AccessPatternTracker::new();

        record_accesses(&mut counts, &[100], 1);
        counts.clear();

        assert_eq!(counts.total_accesses(), 0);
        assert_eq!(counts.num_unique_shapes(), 0);
    }

    #[test]
    fn test_pooling_recommender_basic() {
        let mut advisor = PoolingRecommender::new();

        // 8000 bytes, accessed often: should be recommended.
        record_allocations(&mut advisor, &[1000], 10);
        // 800 bytes: too small for the default policy.
        record_allocations(&mut advisor, &[100], 2);

        let recommended = advisor.recommend_shapes(8);

        assert!(recommended.contains(&"1000".to_string()));
        assert!(!recommended.contains(&"100".to_string()));
    }

    #[test]
    fn test_pooling_recommender_report() {
        let mut advisor = PoolingRecommender::new();

        record_allocations(&mut advisor, &[1000], 20);
        record_allocations(&mut advisor, &[2000], 10);

        let summary = advisor.generate_report(8);

        assert_eq!(summary.total_accesses, 30);
        assert_eq!(summary.total_shapes_accessed, 2);
        assert!(summary.recommended_shapes_count > 0);
        assert!(summary.potential_hit_rate > 0.0);
    }

    #[test]
    fn test_pooling_recommender_conservative_vs_aggressive() {
        let mut cautious = PoolingRecommender::with_policy(PoolingPolicy::conservative());
        let mut eager = PoolingRecommender::with_policy(PoolingPolicy::aggressive());

        // 4000 bytes: only the aggressive policy's window includes this size.
        record_allocations(&mut cautious, &[500], 10);
        record_allocations(&mut eager, &[500], 10);

        let cautious_picks = cautious.recommend_shapes(8);
        let eager_picks = eager.recommend_shapes(8);

        assert!(eager_picks.len() >= cautious_picks.len());
    }

    #[test]
    fn test_access_distribution() {
        let mut counts = AccessPatternTracker::new();

        // 100 accesses in total, split 50 / 30 / 20.
        record_accesses(&mut counts, &[100], 50);
        record_accesses(&mut counts, &[200], 30);
        record_accesses(&mut counts, &[300], 20);

        let shares = counts.access_distribution();

        assert_eq!(shares.len(), 3);
        assert!((shares["100"] - 0.5).abs() < 0.01);
        assert!((shares["200"] - 0.3).abs() < 0.01);
        assert!((shares["300"] - 0.2).abs() < 0.01);
    }
}