1use crate::{Result, Shape};
17#[cfg(feature = "serde")]
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::sync::{Arc, Mutex, RwLock};
21use std::time::{Duration, Instant};
22
23#[derive(Debug, Clone)]
25#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
26pub struct OperationMetrics {
27 pub op_name: String,
29 pub input_shapes: Vec<Shape>,
31 pub duration_ns: u64,
33 pub memory_bandwidth: f64,
35 pub cpu_utilization: f32,
37 pub cache_hit_rate: f32,
39 pub hardware_features: Vec<String>,
41}
42
/// How an operation should be executed.
///
/// Doubles as part of the key in the learned performance map, hence the
/// `Eq`/`Hash` derives.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
// Unqualified names so the feature-gated `use serde::{Deserialize, Serialize}`
// import is actually used (avoids an unused-import warning under `--features serde`).
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ExecutionStrategy {
    /// Plain single-threaded execution.
    Sequential,
    /// Multi-threaded execution with a fixed thread count.
    Parallel { num_threads: usize },
    /// Vectorized execution using the named instruction set (e.g. "avx2", "neon").
    Simd { instruction_set: String },
    /// Execution on the GPU with the given device index.
    Gpu { device_id: u32 },
    /// Work split between CPU and accelerator; `cpu_ratio_percent` is the
    /// CPU's share (0-100).
    Hybrid { cpu_ratio_percent: u8 },
    /// Escape hatch for a caller-defined algorithm.
    Custom { algorithm: String },
}
60
/// Maps (operation name, input shapes, strategy) to an exponentially-smoothed
/// execution time in nanoseconds — lower is better.
type PerformanceMap = HashMap<(String, Vec<Shape>, ExecutionStrategy), f64>;
63
/// Learns which [`ExecutionStrategy`] runs fastest for each operation/shape
/// combination from observed execution times.
///
/// `Clone` shares the underlying state (both fields are `Arc`s), so clones
/// see each other's updates.
#[derive(Debug, Clone)]
pub struct PerformancePredictor {
    // Append-only log of observations; trimmed once it exceeds 10k entries.
    metrics_history: Arc<RwLock<Vec<OperationMetrics>>>,
    // EMA-smoothed duration (ns) per (op, shapes, strategy) key.
    strategy_performance: Arc<RwLock<PerformanceMap>>,
    // Weight of the newest sample in the smoothing update.
    learning_rate: f64,
}
74
75impl Default for PerformancePredictor {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
81impl PerformancePredictor {
82 pub fn new() -> Self {
84 Self {
85 metrics_history: Arc::new(RwLock::new(Vec::new())),
86 strategy_performance: Arc::new(RwLock::new(HashMap::new())),
87 learning_rate: 0.1,
88 }
89 }
90
91 pub fn predict_best_strategy(&self, op_name: &str, shapes: &[Shape]) -> ExecutionStrategy {
93 let performance_map = self
94 .strategy_performance
95 .read()
96 .expect("read lock should not be poisoned");
97
98 let mut best_strategy = ExecutionStrategy::Sequential;
100 let mut best_performance = f64::INFINITY;
101
102 for ((stored_op, stored_shapes, strategy), &performance) in performance_map.iter() {
103 if stored_op == op_name
104 && self.shapes_match(stored_shapes, shapes)
105 && performance < best_performance
106 {
107 best_performance = performance;
108 best_strategy = strategy.clone();
109 }
110 }
111
112 if best_performance == f64::INFINITY {
114 self.heuristic_strategy_selection(shapes)
115 } else {
116 best_strategy
117 }
118 }
119
120 pub fn update_performance(&self, metrics: &OperationMetrics, strategy: ExecutionStrategy) {
122 let mut history = self
123 .metrics_history
124 .write()
125 .expect("write lock should not be poisoned");
126 history.push(metrics.clone());
127
128 let mut performance_map = self
129 .strategy_performance
130 .write()
131 .expect("write lock should not be poisoned");
132 let key = (
133 metrics.op_name.clone(),
134 metrics.input_shapes.clone(),
135 strategy,
136 );
137
138 let new_performance = metrics.duration_ns as f64;
140 let entry = performance_map.entry(key).or_insert(new_performance);
141 *entry = (1.0 - self.learning_rate) * *entry + self.learning_rate * new_performance;
142
143 if history.len() > 10000 {
145 history.drain(..1000);
146 }
147 }
148
149 fn shapes_match(&self, historical: &[Shape], current: &[Shape]) -> bool {
151 if historical.len() != current.len() {
152 return false;
153 }
154
155 for (hist_shape, curr_shape) in historical.iter().zip(current.iter()) {
157 if hist_shape.dims() != curr_shape.dims() {
158 return false;
159 }
160
161 let hist_size: usize = hist_shape.size();
163 let curr_size: usize = curr_shape.size();
164 let size_ratio = (hist_size.max(curr_size) as f64) / (hist_size.min(curr_size) as f64);
165
166 if size_ratio > 1.2 {
167 return false;
168 }
169 }
170
171 true
172 }
173
174 fn heuristic_strategy_selection(&self, shapes: &[Shape]) -> ExecutionStrategy {
176 let total_elements: usize = shapes.iter().map(|s| s.size()).sum();
177
178 match total_elements {
179 0..=1000 => ExecutionStrategy::Sequential,
181
182 1001..=100000 => {
184 if self.has_avx2() {
185 ExecutionStrategy::Simd {
186 instruction_set: "avx2".to_string(),
187 }
188 } else if self.has_neon() {
189 ExecutionStrategy::Simd {
190 instruction_set: "neon".to_string(),
191 }
192 } else {
193 ExecutionStrategy::Parallel { num_threads: 4 }
194 }
195 }
196
197 100001..=10000000 => ExecutionStrategy::Parallel {
199 num_threads: num_cpus::get().min(16),
200 },
201
202 _ => ExecutionStrategy::Hybrid {
204 cpu_ratio_percent: 30,
205 },
206 }
207 }
208
209 fn has_avx2(&self) -> bool {
211 #[cfg(target_arch = "x86_64")]
212 {
213 is_x86_feature_detected!("avx2")
214 }
215 #[cfg(not(target_arch = "x86_64"))]
216 false
217 }
218
219 fn has_neon(&self) -> bool {
221 #[cfg(target_arch = "aarch64")]
222 {
223 std::arch::is_aarch64_feature_detected!("neon")
224 }
225 #[cfg(not(target_arch = "aarch64"))]
226 false
227 }
228}
229
/// Executes operations with the strategy its [`PerformancePredictor`] expects
/// to be fastest, optionally profiling each run to improve future picks.
pub struct AdaptiveTuner {
    // Learns per-(op, shapes, strategy) timings from profiled runs.
    predictor: PerformancePredictor,
    // Memoized strategy per "op:shapes" string key (see `create_cache_key`).
    active_strategies: Arc<Mutex<HashMap<String, ExecutionStrategy>>>,
    // When true, every execution is timed and fed back into the predictor.
    profiling_enabled: bool,
}
239
240impl AdaptiveTuner {
241 pub fn new() -> Self {
243 Self {
244 predictor: PerformancePredictor::new(),
245 active_strategies: Arc::new(Mutex::new(HashMap::new())),
246 profiling_enabled: true,
247 }
248 }
249
250 pub fn execute_with_tuning<F, T>(
252 &self,
253 op_name: &str,
254 shapes: &[Shape],
255 operation: F,
256 ) -> Result<T>
257 where
258 F: Fn(ExecutionStrategy) -> Result<T>,
259 {
260 let cache_key = self.create_cache_key(op_name, shapes);
262
263 let strategy = {
265 let cache = self
266 .active_strategies
267 .lock()
268 .expect("lock should not be poisoned");
269 cache.get(&cache_key).cloned()
270 }
271 .unwrap_or_else(|| {
272 self.predictor.predict_best_strategy(op_name, shapes)
274 });
275
276 let start_time = Instant::now();
278 let result = operation(strategy.clone())?;
279 let duration = start_time.elapsed();
280
281 if self.profiling_enabled {
283 let metrics = OperationMetrics {
284 op_name: op_name.to_string(),
285 input_shapes: shapes.to_vec(),
286 duration_ns: duration.as_nanos() as u64,
287 memory_bandwidth: self.estimate_memory_bandwidth(shapes, duration),
288 cpu_utilization: self.get_cpu_utilization(),
289 cache_hit_rate: 0.95, hardware_features: self.get_active_features(&strategy),
291 };
292
293 self.predictor
294 .update_performance(&metrics, strategy.clone());
295
296 let mut cache = self
298 .active_strategies
299 .lock()
300 .expect("lock should not be poisoned");
301 cache.insert(cache_key, strategy);
302 }
303
304 Ok(result)
305 }
306
307 pub fn set_profiling_enabled(&mut self, enabled: bool) {
309 self.profiling_enabled = enabled;
310 }
311
312 fn create_cache_key(&self, op_name: &str, shapes: &[Shape]) -> String {
314 let shapes_str = shapes
315 .iter()
316 .map(|shape| format!("{shape:?}"))
317 .collect::<Vec<_>>()
318 .join(",");
319 format!("{op_name}:{shapes_str}")
320 }
321
322 pub fn clear_strategy_cache(&self) {
324 let mut cache = self
325 .active_strategies
326 .lock()
327 .expect("lock should not be poisoned");
328 cache.clear();
329 }
330
331 pub fn get_performance_stats(&self) -> Result<String> {
333 let history = self
334 .predictor
335 .metrics_history
336 .read()
337 .expect("read lock should not be poisoned");
338
339 if history.is_empty() {
340 return Ok("No performance data collected yet.".to_string());
341 }
342
343 let mut stats = String::new();
344 stats.push_str("Adaptive Tuning Performance Statistics\n");
345 stats.push_str("======================================\n");
346 stats.push_str(&format!("Total operations profiled: {}\n", history.len()));
347
348 let mut op_stats: HashMap<String, Vec<&OperationMetrics>> = HashMap::new();
350 for metrics in history.iter() {
351 op_stats
352 .entry(metrics.op_name.clone())
353 .or_default()
354 .push(metrics);
355 }
356
357 for (op_name, metrics) in op_stats {
358 let avg_duration =
359 metrics.iter().map(|m| m.duration_ns).sum::<u64>() / metrics.len() as u64;
360 let avg_bandwidth =
361 metrics.iter().map(|m| m.memory_bandwidth).sum::<f64>() / metrics.len() as f64;
362
363 stats.push_str(&format!(
364 "\n{}: {} executions, avg {:.2}ms, {:.2} GB/s\n",
365 op_name,
366 metrics.len(),
367 avg_duration as f64 / 1_000_000.0,
368 avg_bandwidth / 1e9
369 ));
370 }
371
372 Ok(stats)
373 }
374
375 fn estimate_memory_bandwidth(&self, shapes: &[Shape], duration: Duration) -> f64 {
377 let total_elements: usize = shapes.iter().map(|s| s.size()).sum();
378 let estimated_bytes = total_elements * 8; if duration.as_nanos() == 0 {
381 0.0
382 } else {
383 (estimated_bytes as f64) / (duration.as_secs_f64())
384 }
385 }
386
387 fn get_cpu_utilization(&self) -> f32 {
389 0.8
392 }
393
394 fn get_active_features(&self, strategy: &ExecutionStrategy) -> Vec<String> {
396 match strategy {
397 ExecutionStrategy::Simd { instruction_set } => vec![instruction_set.clone()],
398 ExecutionStrategy::Gpu { .. } => vec!["gpu".to_string()],
399 ExecutionStrategy::Parallel { .. } => vec!["multi-thread".to_string()],
400 _ => vec![],
401 }
402 }
403}
404
405impl Default for AdaptiveTuner {
406 fn default() -> Self {
407 Self::new()
408 }
409}
410
lazy_static::lazy_static! {
    /// Process-wide tuner shared by [`execute_with_adaptive_tuning`].
    ///
    /// NOTE(review): the `Mutex` is held for the full duration of each tuned
    /// operation (see `execute_with_adaptive_tuning`), so all globally tuned
    /// operations are serialized — confirm this is acceptable for the
    /// intended workloads.
    pub static ref GLOBAL_TUNER: Arc<Mutex<AdaptiveTuner>> =
        Arc::new(Mutex::new(AdaptiveTuner::new()));
}
416
417pub fn execute_with_adaptive_tuning<F, T>(
419 op_name: &str,
420 shapes: &[Shape],
421 operation: F,
422) -> Result<T>
423where
424 F: Fn(ExecutionStrategy) -> Result<T>,
425{
426 let tuner = GLOBAL_TUNER.lock().expect("lock should not be poisoned");
427 tuner.execute_with_tuning(op_name, shapes, operation)
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_performance_predictor_creation() {
        // A fresh predictor has no history, so prediction falls back to the
        // size heuristic: 10 * 10 = 100 elements.
        let predictor = PerformancePredictor::new();
        let strategy = predictor.predict_best_strategy("test_op", &[Shape::from_slice(&[10, 10])]);

        // Sequential is expected for <= 1000 elements; Simd is accepted
        // defensively in case the heuristic thresholds shift.
        assert!(matches!(
            strategy,
            ExecutionStrategy::Sequential | ExecutionStrategy::Simd { .. }
        ));
    }

    #[test]
    fn test_adaptive_tuner_execution() {
        let tuner = AdaptiveTuner::new();

        // The closure stands in for a real kernel; it just echoes the
        // strategy it was handed.
        let result =
            tuner.execute_with_tuning("test_add", &[Shape::from_slice(&[100])], |strategy| {
                Ok(format!("Executed with {:?}", strategy))
            });

        assert!(result.is_ok());
        assert!(result
            .expect("test: operation should succeed")
            .contains("Executed with"));
    }

    #[test]
    fn test_heuristic_strategy_selection() {
        let predictor = PerformancePredictor::new();

        // 10 elements falls in the 0..=1000 bucket -> Sequential.
        let small_strategy = predictor.heuristic_strategy_selection(&[Shape::from_slice(&[10])]);
        assert!(matches!(small_strategy, ExecutionStrategy::Sequential));

        // 10_000 elements falls in the 1001..=100_000 bucket -> Simd when the
        // host supports it, otherwise Parallel.
        let large_strategy = predictor.heuristic_strategy_selection(&[Shape::from_slice(&[10000])]);
        assert!(matches!(
            large_strategy,
            ExecutionStrategy::Parallel { .. } | ExecutionStrategy::Simd { .. }
        ));
    }

    #[test]
    fn test_performance_metrics_update() {
        let predictor = PerformancePredictor::new();

        let metrics = OperationMetrics {
            op_name: "test_op".to_string(),
            input_shapes: vec![Shape::from_slice(&[100, 100])],
            duration_ns: 1000000,
            memory_bandwidth: 1e9,
            cpu_utilization: 0.8,
            cache_hit_rate: 0.95,
            hardware_features: vec!["avx2".to_string()],
        };

        predictor.update_performance(
            &metrics,
            ExecutionStrategy::Simd {
                instruction_set: "avx2".to_string(),
            },
        );

        // The recorded Simd entry is now the only (and thus best) candidate
        // for this op/shape combination.
        let predicted =
            predictor.predict_best_strategy("test_op", &[Shape::from_slice(&[100, 100])]);
        assert!(matches!(predicted, ExecutionStrategy::Simd { .. }));
    }
}