1use crate::graph::ComputationGraph;
43use crate::JitResult;
44use serde::{Deserialize, Serialize};
45use std::collections::HashMap;
46
/// Normal (Gaussian) distribution parameterized by mean and standard deviation.
///
/// Used throughout this module to model execution time and memory usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormalDistribution {
    /// Mean (center) of the distribution.
    pub mean: f64,

    /// Standard deviation (spread); expected to be positive — `pdf`/`cdf`
    /// divide by it.
    pub std_dev: f64,
}
60
61impl NormalDistribution {
62 pub fn new(mean: f64, std_dev: f64) -> Self {
64 Self { mean, std_dev }
65 }
66
67 pub fn sample(&self) -> f64 {
69 use std::f64::consts::PI;
71 let u1 = self.uniform_sample();
72 let u2 = self.uniform_sample();
73
74 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
75 self.mean + self.std_dev * z0
76 }
77
78 fn uniform_sample(&self) -> f64 {
79 use std::time::{SystemTime, UNIX_EPOCH};
80 let nanos = SystemTime::now()
81 .duration_since(UNIX_EPOCH)
82 .expect("system time should be after UNIX_EPOCH")
83 .subsec_nanos();
84 ((nanos % 10000) as f64 / 10000.0).max(0.0001)
85 }
86
87 pub fn pdf(&self, x: f64) -> f64 {
89 let coefficient = 1.0 / (self.std_dev * (2.0 * std::f64::consts::PI).sqrt());
90 let exponent = -((x - self.mean).powi(2)) / (2.0 * self.std_dev.powi(2));
91 coefficient * exponent.exp()
92 }
93
94 pub fn cdf(&self, x: f64) -> f64 {
96 let z = (x - self.mean) / self.std_dev;
97 0.5 * (1.0 + Self::erf(z / std::f64::consts::SQRT_2))
98 }
99
100 fn erf(x: f64) -> f64 {
102 let a1 = 0.254829592;
104 let a2 = -0.284496736;
105 let a3 = 1.421413741;
106 let a4 = -1.453152027;
107 let a5 = 1.061405429;
108 let p = 0.3275911;
109
110 let sign = if x < 0.0 { -1.0 } else { 1.0 };
111 let x = x.abs();
112
113 let t = 1.0 / (1.0 + p * x);
114 let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
115
116 sign * y
117 }
118
119 pub fn confidence_interval(&self, confidence: f64) -> (f64, f64) {
121 let z_score = match confidence {
124 c if c >= 0.99 => 2.576,
125 c if c >= 0.95 => 1.96,
126 c if c >= 0.90 => 1.645,
127 _ => 1.0,
128 };
129
130 let margin = z_score * self.std_dev;
131 (self.mean - margin, self.mean + margin)
132 }
133}
134
/// Beta distribution, used as a conjugate prior for Bernoulli outcomes
/// (see `update`, which adds one pseudo-count per observation).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BetaDistribution {
    /// Success pseudo-count (shape parameter alpha).
    pub alpha: f64,

    /// Failure pseudo-count (shape parameter beta).
    pub beta: f64,
}
144
145impl BetaDistribution {
146 pub fn new(alpha: f64, beta: f64) -> Self {
148 Self { alpha, beta }
149 }
150
151 pub fn mean(&self) -> f64 {
153 self.alpha / (self.alpha + self.beta)
154 }
155
156 pub fn mode(&self) -> f64 {
158 if self.alpha > 1.0 && self.beta > 1.0 {
159 (self.alpha - 1.0) / (self.alpha + self.beta - 2.0)
160 } else {
161 self.mean()
162 }
163 }
164
165 pub fn variance(&self) -> f64 {
167 let sum = self.alpha + self.beta;
168 (self.alpha * self.beta) / (sum.powi(2) * (sum + 1.0))
169 }
170
171 pub fn update(&mut self, success: bool) {
173 if success {
174 self.alpha += 1.0;
175 } else {
176 self.beta += 1.0;
177 }
178 }
179
180 pub fn credible_interval(&self, confidence: f64) -> (f64, f64) {
182 let mean = self.mean();
184 let std_dev = self.variance().sqrt();
185
186 let z_score = if confidence >= 0.95 { 1.96 } else { 1.645 };
187 let margin = z_score * std_dev;
188
189 ((mean - margin).max(0.0), (mean + margin).min(1.0))
190 }
191}
192
/// Probabilistic model of compiled-code performance: distributions over
/// execution time, memory usage, and success probability.
#[derive(Debug, Clone)]
pub struct ProbabilisticPerformance {
    /// Execution-time distribution (units follow the caller; `compile` seeds
    /// it in microseconds).
    pub time_dist: NormalDistribution,

    /// Memory-usage distribution (seeded in bytes by `compile`).
    pub memory_dist: NormalDistribution,

    /// Probability that execution succeeds, as a Beta posterior.
    pub success_prob: BetaDistribution,

    /// Named variance contributions. NOTE(review): not read anywhere in this
    /// file — presumably populated by external callers; confirm intended use.
    pub variance_factors: HashMap<String, f64>,
}
212
213impl ProbabilisticPerformance {
214 pub fn new(mean_time: f64, mean_memory: f64) -> Self {
216 Self {
217 time_dist: NormalDistribution::new(mean_time, mean_time * 0.2), memory_dist: NormalDistribution::new(mean_memory, mean_memory * 0.15), success_prob: BetaDistribution::new(10.0, 1.0), variance_factors: HashMap::new(),
221 }
222 }
223
224 pub fn sample_time(&self) -> f64 {
226 self.time_dist.sample().max(0.0)
227 }
228
229 pub fn sample_memory(&self) -> f64 {
231 self.memory_dist.sample().max(0.0)
232 }
233
234 pub fn time_confidence_interval(&self, confidence: f64) -> (f64, f64) {
236 self.time_dist.confidence_interval(confidence)
237 }
238
239 pub fn memory_confidence_interval(&self, confidence: f64) -> (f64, f64) {
240 self.memory_dist.confidence_interval(confidence)
241 }
242
243 pub fn expected_time(&self) -> f64 {
245 self.time_dist.mean
246 }
247
248 pub fn expected_memory(&self) -> f64 {
249 self.memory_dist.mean
250 }
251
252 pub fn value_at_risk(&self, percentile: f64) -> f64 {
254 let z = match percentile {
256 p if p >= 0.99 => 2.326, p if p >= 0.95 => 1.645, _ => 1.0,
259 };
260 self.time_dist.mean + z * self.time_dist.std_dev
261 }
262}
263
/// Bayesian decision model for a single optimization pass: tracks how likely
/// the pass is to help and how much speedup it delivers.
#[derive(Debug, Clone)]
pub struct UncertainDecision {
    /// Name of the optimization this decision governs.
    pub optimization: String,

    /// Posterior probability that applying the optimization improves
    /// performance (Beta-Bernoulli).
    pub prob_improvement: BetaDistribution,

    /// Distribution over the speedup factor when the optimization is applied.
    pub speedup_dist: NormalDistribution,

    /// Current risk score; maintained as the variance of `prob_improvement`
    /// (see `observe`).
    pub risk: f64,

    /// History of all recorded observations for this optimization.
    pub observations: Vec<Observation>,
}
286
/// One recorded outcome of applying an optimization.
#[derive(Debug, Clone)]
pub struct Observation {
    /// Whether the optimization actually improved performance.
    pub beneficial: bool,

    /// Measured speedup factor (>1.0 means faster).
    pub speedup: f64,

    /// Free-form context features attached to this observation.
    pub context: HashMap<String, f64>,
}
299
300impl UncertainDecision {
301 pub fn new(optimization: String) -> Self {
303 Self {
304 optimization,
305 prob_improvement: BetaDistribution::new(1.0, 1.0), speedup_dist: NormalDistribution::new(1.2, 0.3), risk: 0.5,
308 observations: Vec::new(),
309 }
310 }
311
312 pub fn observe(&mut self, beneficial: bool, speedup: f64, context: HashMap<String, f64>) {
314 self.prob_improvement.update(beneficial);
316
317 if !self.observations.is_empty() {
319 let n = self.observations.len() as f64;
320 let old_mean = self.speedup_dist.mean;
321 let new_mean = (old_mean * n + speedup) / (n + 1.0);
322 self.speedup_dist.mean = new_mean;
323
324 let old_var = self.speedup_dist.std_dev.powi(2);
326 let new_var = ((old_var * n) + (speedup - new_mean).powi(2)) / (n + 1.0);
327 self.speedup_dist.std_dev = new_var.sqrt();
328 } else {
329 self.speedup_dist.mean = speedup;
330 }
331
332 self.observations.push(Observation {
333 beneficial,
334 speedup,
335 context,
336 });
337
338 self.risk = self.prob_improvement.variance();
340 }
341
342 pub fn should_apply(&self) -> bool {
344 let prob = self.prob_improvement.mean();
346 let threshold = 0.6; prob > threshold
348 }
349
350 pub fn expected_utility(&self, risk_aversion: f64) -> f64 {
352 let expected_reward = self.speedup_dist.mean * self.prob_improvement.mean();
353 let risk_penalty = risk_aversion * self.risk;
354 expected_reward - risk_penalty
355 }
356}
357
/// JIT compiler that treats optimization decisions as Bayesian inference
/// problems, learning from observed performance over time.
pub struct ProbabilisticCompiler {
    /// Tunable parameters controlling decision making.
    config: ProbabilisticConfig,

    /// Per-optimization decision models, keyed by optimization name.
    decisions: HashMap<String, UncertainDecision>,

    /// Performance model from the most recent `compile` call, if any.
    performance_model: Option<ProbabilisticPerformance>,

    /// Aggregate statistics across all compilations.
    stats: CompilerStatistics,
}
376
/// Configuration for the probabilistic compiler.
#[derive(Debug, Clone)]
pub struct ProbabilisticConfig {
    /// Confidence level for reported intervals (e.g. 0.95).
    pub confidence_level: f64,

    /// Weight of the risk penalty in utility calculations.
    pub risk_aversion: f64,

    /// Default number of Monte Carlo samples.
    /// NOTE(review): not read anywhere in this file — callers pass an explicit
    /// count to `monte_carlo_simulation`; confirm intended use.
    pub num_samples: usize,

    /// When true, `compile` uses the Bayesian `should_apply` rule; otherwise
    /// a plain posterior-mean > 0.5 test.
    pub bayesian_optimization: bool,

    /// Exploration rate. NOTE(review): not read anywhere in this file —
    /// presumably for an epsilon-greedy policy; confirm intended use.
    pub exploration_rate: f64,
}
395
impl Default for ProbabilisticConfig {
    /// Sensible defaults: 95% confidence, moderate risk aversion, 1000 Monte
    /// Carlo samples, Bayesian decisions enabled, 10% exploration.
    fn default() -> Self {
        Self {
            confidence_level: 0.95,
            risk_aversion: 0.5,
            num_samples: 1000,
            bayesian_optimization: true,
            exploration_rate: 0.1,
        }
    }
}
407
/// Aggregate statistics collected across compilations.
#[derive(Debug, Clone, Default)]
pub struct CompilerStatistics {
    /// Number of `compile` calls completed.
    pub compilations: usize,

    /// Running mean of relative prediction error, updated by
    /// `observe_performance`.
    pub avg_prediction_error: f64,

    /// Calibration score. NOTE(review): never written in this file — always
    /// the `Default` value of 0.0; confirm intended producer.
    pub calibration_score: f64,
}
420
421impl ProbabilisticCompiler {
422 pub fn new() -> Self {
424 Self::with_config(ProbabilisticConfig::default())
425 }
426
427 pub fn with_config(config: ProbabilisticConfig) -> Self {
429 let mut decisions = HashMap::new();
430
431 for opt in [
433 "constant_folding",
434 "dead_code_elimination",
435 "fusion",
436 "vectorization",
437 "parallelization",
438 "tiling",
439 ] {
440 decisions.insert(opt.to_string(), UncertainDecision::new(opt.to_string()));
441 }
442
443 Self {
444 config,
445 decisions,
446 performance_model: None,
447 stats: CompilerStatistics::default(),
448 }
449 }
450
451 pub fn compile(
453 &mut self,
454 graph: &ComputationGraph,
455 ) -> JitResult<ProbabilisticCompilationResult> {
456 let node_count = graph.node_count() as f64;
458 let base_time = node_count * 10.0; let base_memory = node_count * 1024.0; let mut perf = ProbabilisticPerformance::new(base_time, base_memory);
463
464 let mut applied_opts = Vec::new();
466 let mut decisions_made = Vec::new();
467
468 for (opt_name, decision_model) in &self.decisions {
469 let should_apply = if self.config.bayesian_optimization {
470 decision_model.should_apply()
471 } else {
472 decision_model.prob_improvement.mean() > 0.5
473 };
474
475 if should_apply {
476 applied_opts.push(opt_name.clone());
477
478 let speedup = decision_model.speedup_dist.sample();
480 perf.time_dist.mean /= speedup;
481 }
482
483 decisions_made.push(OptimizationDecision {
484 optimization: opt_name.clone(),
485 applied: should_apply,
486 prob_improvement: decision_model.prob_improvement.mean(),
487 expected_speedup: decision_model.speedup_dist.mean,
488 credible_interval: decision_model
489 .prob_improvement
490 .credible_interval(self.config.confidence_level),
491 });
492 }
493
494 self.performance_model = Some(perf.clone());
495 self.stats.compilations += 1;
496
497 Ok(ProbabilisticCompilationResult {
498 performance: perf,
499 decisions: decisions_made,
500 applied_optimizations: applied_opts,
501 confidence_level: self.config.confidence_level,
502 risk_score: self.compute_overall_risk(),
503 })
504 }
505
506 pub fn observe_performance(
508 &mut self,
509 applied_opts: &[String],
510 actual_time: f64,
511 predicted_time: f64,
512 ) -> JitResult<()> {
513 let error = (actual_time - predicted_time).abs() / predicted_time;
515
516 let n = self.stats.compilations as f64;
518 self.stats.avg_prediction_error = (self.stats.avg_prediction_error * (n - 1.0) + error) / n;
519
520 for opt_name in applied_opts {
522 if let Some(decision) = self.decisions.get_mut(opt_name) {
523 let speedup = predicted_time / actual_time;
524 let beneficial = speedup > 1.0;
525 decision.observe(beneficial, speedup, HashMap::new());
526 }
527 }
528
529 if let Some(perf_model) = &mut self.performance_model {
531 let confidence_boost = 0.9; perf_model.time_dist.std_dev *= confidence_boost;
534 }
535
536 log::info!(
537 "Observed performance: actual={:.2}μs, predicted={:.2}μs, error={:.1}%",
538 actual_time,
539 predicted_time,
540 error * 100.0
541 );
542
543 Ok(())
544 }
545
546 fn compute_overall_risk(&self) -> f64 {
548 let mut total_risk = 0.0;
549 let mut count = 0;
550
551 for decision in self.decisions.values() {
552 total_risk += decision.risk;
553 count += 1;
554 }
555
556 if count > 0 {
557 total_risk / count as f64
558 } else {
559 0.5
560 }
561 }
562
563 pub fn monte_carlo_simulation(&self, num_samples: usize) -> MonteCarloResult {
565 if let Some(perf_model) = &self.performance_model {
566 let mut samples = Vec::with_capacity(num_samples);
567
568 for _ in 0..num_samples {
569 let time = perf_model.sample_time();
570 samples.push(time);
571 }
572
573 samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
574
575 let mean = samples.iter().sum::<f64>() / num_samples as f64;
576 let variance =
577 samples.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / num_samples as f64;
578
579 let p50 = samples[num_samples / 2];
580 let p95 = samples[(num_samples as f64 * 0.95) as usize];
581 let p99 = samples[(num_samples as f64 * 0.99) as usize];
582
583 MonteCarloResult {
584 mean,
585 std_dev: variance.sqrt(),
586 percentiles: vec![(50, p50), (95, p95), (99, p99)],
587 samples: samples.into_iter().take(100).collect(), }
589 } else {
590 MonteCarloResult::default()
591 }
592 }
593
594 pub fn statistics(&self) -> &CompilerStatistics {
596 &self.stats
597 }
598}
599
impl Default for ProbabilisticCompiler {
    /// Equivalent to `ProbabilisticCompiler::new()`.
    fn default() -> Self {
        Self::new()
    }
}
605
/// Outcome of a probabilistic compilation: the predicted performance plus the
/// full decision trail.
#[derive(Debug, Clone)]
pub struct ProbabilisticCompilationResult {
    /// Predicted performance model for the compiled graph.
    pub performance: ProbabilisticPerformance,

    /// One entry per known optimization, applied or not.
    pub decisions: Vec<OptimizationDecision>,

    /// Names of the optimizations that were actually applied.
    pub applied_optimizations: Vec<String>,

    /// Confidence level used for the reported intervals.
    pub confidence_level: f64,

    /// Mean risk across all decision models at compile time.
    pub risk_score: f64,
}
628
/// Record of a single optimization decision made during compilation.
#[derive(Debug, Clone)]
pub struct OptimizationDecision {
    /// Name of the optimization considered.
    pub optimization: String,

    /// Whether the optimization was applied.
    pub applied: bool,

    /// Posterior mean probability that the optimization improves performance.
    pub prob_improvement: f64,

    /// Posterior mean of the speedup distribution.
    pub expected_speedup: f64,

    /// Credible interval for the improvement probability, clipped to [0, 1].
    pub credible_interval: (f64, f64),
}
647
/// Summary statistics from a Monte Carlo execution-time simulation.
#[derive(Debug, Clone, Default)]
pub struct MonteCarloResult {
    /// Sample mean of simulated execution times.
    pub mean: f64,

    /// Sample standard deviation.
    pub std_dev: f64,

    /// Selected percentiles as `(percentile, value)` pairs.
    pub percentiles: Vec<(usize, f64)>,

    /// Subset of raw samples (capped at 100 by the producer).
    pub samples: Vec<f64>,
}
663
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphBuilder;
    use torsh_core::{DType, Shape};

    /// Construction, sampling, and confidence intervals of the normal model.
    #[test]
    fn test_normal_distribution() {
        let dist = NormalDistribution::new(100.0, 10.0);
        assert_eq!(dist.mean, 100.0);
        assert_eq!(dist.std_dev, 10.0);

        // With mean 100 and sigma 10 a sample is overwhelmingly positive.
        let sample = dist.sample();
        assert!(sample > 0.0);
        let (lower, upper) = dist.confidence_interval(0.95);
        assert!(lower < dist.mean);
        assert!(upper > dist.mean);
    }

    /// Mean, conjugate update, and credible interval of the Beta model.
    #[test]
    fn test_beta_distribution() {
        let mut dist = BetaDistribution::new(10.0, 2.0);
        let mean = dist.mean();
        // 10 / (10 + 2) ≈ 0.83.
        assert!(mean > 0.5);
        dist.update(true);
        assert!(dist.alpha == 11.0);

        let (lower, upper) = dist.credible_interval(0.95);
        assert!(lower <= upper);
    }

    /// A beneficial observation raises alpha and yields positive utility.
    #[test]
    fn test_uncertain_decision() {
        let mut decision = UncertainDecision::new("fusion".to_string());

        decision.observe(true, 1.5, HashMap::new());
        assert!(decision.prob_improvement.alpha > 1.0);

        let utility = decision.expected_utility(0.5);
        assert!(utility > 0.0);
    }

    /// End-to-end compilation of a trivial single-input graph.
    #[test]
    fn test_probabilistic_compilation() {
        let mut compiler = ProbabilisticCompiler::new();

        let mut builder = GraphBuilder::new();
        let x = builder.add_input("x".to_string(), Shape::new(vec![10, 10]), DType::F32);
        builder.mark_output(x).unwrap();

        let graph = builder.build().unwrap();
        let result = compiler.compile(&graph).unwrap();

        assert!(!result.decisions.is_empty());
        assert!(result.performance.expected_time() > 0.0);
        // Default config uses a 95% confidence level.
        assert!(result.confidence_level == 0.95);
    }

    /// Performance feedback after a compilation should be accepted.
    #[test]
    fn test_performance_observation() {
        let mut compiler = ProbabilisticCompiler::new();

        let mut builder = GraphBuilder::new();
        let x = builder.add_input("x".to_string(), Shape::new(vec![5, 5]), DType::F32);
        builder.mark_output(x).unwrap();

        let graph = builder.build().unwrap();
        let result = compiler.compile(&graph).unwrap();

        // Simulate an actual run 10% slower than predicted.
        let predicted = result.performance.expected_time();
        let actual = predicted * 1.1;
        let obs_result =
            compiler.observe_performance(&result.applied_optimizations, actual, predicted);
        assert!(obs_result.is_ok());
    }

    /// Monte Carlo simulation after a compile yields sane summary statistics.
    #[test]
    fn test_monte_carlo() {
        let mut compiler = ProbabilisticCompiler::new();

        let mut builder = GraphBuilder::new();
        let x = builder.add_input("x".to_string(), Shape::new(vec![8, 8]), DType::F32);
        builder.mark_output(x).unwrap();

        let graph = builder.build().unwrap();
        let _ = compiler.compile(&graph).unwrap();

        let mc_result = compiler.monte_carlo_simulation(100);
        assert!(mc_result.mean > 0.0);
        assert!(mc_result.std_dev >= 0.0);
        assert!(!mc_result.percentiles.is_empty());
    }

    /// VaR must exceed the expected time and grow with the percentile.
    #[test]
    fn test_value_at_risk() {
        let perf = ProbabilisticPerformance::new(1000.0, 10000.0);
        let var_95 = perf.value_at_risk(0.95);
        let var_99 = perf.value_at_risk(0.99);

        assert!(var_95 > perf.expected_time());
        assert!(var_99 > var_95);
    }
}