// torsh_jit/probabilistic_compilation.rs

// Copyright (c) 2025 ToRSh Contributors
// SPDX-License-Identifier: Apache-2.0 OR MIT

//! # Probabilistic Compilation
//!
//! This module implements uncertainty-aware compilation using probabilistic models
//! to handle variability in hardware, workloads, and optimization effects.
//!
//! ## Key Concepts
//!
//! - **Probabilistic Performance Models**: Predict performance with confidence intervals
//! - **Bayesian Optimization**: Learn optimal compilation strategies with uncertainty
//! - **Risk-Aware Decisions**: Make compilation choices considering worst-case scenarios
//! - **Monte Carlo Sampling**: Estimate optimization effects through sampling
//! - **Credible Intervals**: Quantify uncertainty in performance predictions
//!
//! ## Architecture
//!
//! ```text
//! Program → Probabilistic Model → Distribution over Outcomes
//!              ↓                          ↓
//!         Uncertainty Quantification → Risk-Aware Decisions
//! ```
//!
//! ## Example
//!
//! ```rust,ignore
//! use torsh_jit::probabilistic_compilation::{ProbabilisticCompiler, CompilationConfig};
//!
//! let compiler = ProbabilisticCompiler::new();
//!
//! // Compile with uncertainty estimates
//! let result = compiler.compile(&graph)?;
//!
//! println!("Expected time: {} ± {} μs",
//!          result.mean_time, result.std_time);
//! println!("95% confidence: [{}, {}]",
//!          result.confidence_interval.0,
//!          result.confidence_interval.1);
//! ```

// Crate-local IR type and result alias; serde derives are applied to the
// distribution structs below; HashMap keys context features and decisions.
use crate::graph::ComputationGraph;
use crate::JitResult;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
46
// ============================================================================
// Probability Distributions
// ============================================================================
50
/// Normal (Gaussian) distribution parameterized by mean and standard deviation.
///
/// Used throughout this module to model execution time, memory usage, and
/// optimization speedups with quantified uncertainty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NormalDistribution {
    /// Mean (location parameter).
    pub mean: f64,

    /// Standard deviation (scale parameter).
    // NOTE(review): nothing here enforces std_dev >= 0; callers in this
    // module always pass non-negative values — confirm if exposed wider.
    pub std_dev: f64,
}
60
61impl NormalDistribution {
62    /// Create new normal distribution
63    pub fn new(mean: f64, std_dev: f64) -> Self {
64        Self { mean, std_dev }
65    }
66
67    /// Sample from distribution (Box-Muller transform)
68    pub fn sample(&self) -> f64 {
69        // Simplified: In production, use scirs2-core random module
70        use std::f64::consts::PI;
71        let u1 = self.uniform_sample();
72        let u2 = self.uniform_sample();
73
74        let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos();
75        self.mean + self.std_dev * z0
76    }
77
78    fn uniform_sample(&self) -> f64 {
79        use std::time::{SystemTime, UNIX_EPOCH};
80        let nanos = SystemTime::now()
81            .duration_since(UNIX_EPOCH)
82            .expect("system time should be after UNIX_EPOCH")
83            .subsec_nanos();
84        ((nanos % 10000) as f64 / 10000.0).max(0.0001)
85    }
86
87    /// Probability density function
88    pub fn pdf(&self, x: f64) -> f64 {
89        let coefficient = 1.0 / (self.std_dev * (2.0 * std::f64::consts::PI).sqrt());
90        let exponent = -((x - self.mean).powi(2)) / (2.0 * self.std_dev.powi(2));
91        coefficient * exponent.exp()
92    }
93
94    /// Cumulative distribution function (approximation)
95    pub fn cdf(&self, x: f64) -> f64 {
96        let z = (x - self.mean) / self.std_dev;
97        0.5 * (1.0 + Self::erf(z / std::f64::consts::SQRT_2))
98    }
99
100    /// Error function (approximation)
101    fn erf(x: f64) -> f64 {
102        // Abramowitz and Stegun approximation
103        let a1 = 0.254829592;
104        let a2 = -0.284496736;
105        let a3 = 1.421413741;
106        let a4 = -1.453152027;
107        let a5 = 1.061405429;
108        let p = 0.3275911;
109
110        let sign = if x < 0.0 { -1.0 } else { 1.0 };
111        let x = x.abs();
112
113        let t = 1.0 / (1.0 + p * x);
114        let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
115
116        sign * y
117    }
118
119    /// Get confidence interval
120    pub fn confidence_interval(&self, confidence: f64) -> (f64, f64) {
121        // For 95% confidence: ±1.96 std_dev
122        // For 99% confidence: ±2.576 std_dev
123        let z_score = match confidence {
124            c if c >= 0.99 => 2.576,
125            c if c >= 0.95 => 1.96,
126            c if c >= 0.90 => 1.645,
127            _ => 1.0,
128        };
129
130        let margin = z_score * self.std_dev;
131        (self.mean - margin, self.mean + margin)
132    }
133}
134
/// Beta distribution over [0, 1], used to model probabilities (e.g. the
/// chance that an optimization improves performance).
///
/// Acts as a conjugate prior for Bernoulli observations: see
/// `BetaDistribution::update`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BetaDistribution {
    /// Alpha shape parameter (pseudo-count of successes).
    pub alpha: f64,

    /// Beta shape parameter (pseudo-count of failures).
    pub beta: f64,
}
144
145impl BetaDistribution {
146    /// Create new beta distribution
147    pub fn new(alpha: f64, beta: f64) -> Self {
148        Self { alpha, beta }
149    }
150
151    /// Mean of distribution
152    pub fn mean(&self) -> f64 {
153        self.alpha / (self.alpha + self.beta)
154    }
155
156    /// Mode of distribution (most likely value)
157    pub fn mode(&self) -> f64 {
158        if self.alpha > 1.0 && self.beta > 1.0 {
159            (self.alpha - 1.0) / (self.alpha + self.beta - 2.0)
160        } else {
161            self.mean()
162        }
163    }
164
165    /// Variance
166    pub fn variance(&self) -> f64 {
167        let sum = self.alpha + self.beta;
168        (self.alpha * self.beta) / (sum.powi(2) * (sum + 1.0))
169    }
170
171    /// Update parameters with new observation (Bayesian update)
172    pub fn update(&mut self, success: bool) {
173        if success {
174            self.alpha += 1.0;
175        } else {
176            self.beta += 1.0;
177        }
178    }
179
180    /// Credible interval
181    pub fn credible_interval(&self, confidence: f64) -> (f64, f64) {
182        // Simplified approximation using normal approximation
183        let mean = self.mean();
184        let std_dev = self.variance().sqrt();
185
186        let z_score = if confidence >= 0.95 { 1.96 } else { 1.645 };
187        let margin = z_score * std_dev;
188
189        ((mean - margin).max(0.0), (mean + margin).min(1.0))
190    }
191}
192
// ============================================================================
// Probabilistic Performance Model
// ============================================================================
196
/// Performance prediction with uncertainty.
///
/// Bundles distributions over execution time, memory usage, and compilation
/// success, plus named variance factors.
#[derive(Debug, Clone)]
pub struct ProbabilisticPerformance {
    /// Execution time distribution (microseconds in this module's usage).
    pub time_dist: NormalDistribution,

    /// Memory usage distribution (bytes in this module's usage).
    pub memory_dist: NormalDistribution,

    /// Probability that compilation succeeds, as a Beta posterior.
    pub success_prob: BetaDistribution,

    /// Named factors contributing to performance variance.
    // NOTE(review): populated by no code visible in this file — presumably
    // filled in by callers; verify before relying on it.
    pub variance_factors: HashMap<String, f64>,
}
212
213impl ProbabilisticPerformance {
214    /// Create new performance model
215    pub fn new(mean_time: f64, mean_memory: f64) -> Self {
216        Self {
217            time_dist: NormalDistribution::new(mean_time, mean_time * 0.2), // 20% uncertainty
218            memory_dist: NormalDistribution::new(mean_memory, mean_memory * 0.15), // 15% uncertainty
219            success_prob: BetaDistribution::new(10.0, 1.0), // Optimistic prior
220            variance_factors: HashMap::new(),
221        }
222    }
223
224    /// Sample execution time
225    pub fn sample_time(&self) -> f64 {
226        self.time_dist.sample().max(0.0)
227    }
228
229    /// Sample memory usage
230    pub fn sample_memory(&self) -> f64 {
231        self.memory_dist.sample().max(0.0)
232    }
233
234    /// Get confidence intervals
235    pub fn time_confidence_interval(&self, confidence: f64) -> (f64, f64) {
236        self.time_dist.confidence_interval(confidence)
237    }
238
239    pub fn memory_confidence_interval(&self, confidence: f64) -> (f64, f64) {
240        self.memory_dist.confidence_interval(confidence)
241    }
242
243    /// Expected value (mean)
244    pub fn expected_time(&self) -> f64 {
245        self.time_dist.mean
246    }
247
248    pub fn expected_memory(&self) -> f64 {
249        self.memory_dist.mean
250    }
251
252    /// Value at risk (VaR) - worst case with given probability
253    pub fn value_at_risk(&self, percentile: f64) -> f64 {
254        // Approximate: mean + z*std_dev
255        let z = match percentile {
256            p if p >= 0.99 => 2.326, // 99th percentile
257            p if p >= 0.95 => 1.645, // 95th percentile
258            _ => 1.0,
259        };
260        self.time_dist.mean + z * self.time_dist.std_dev
261    }
262}
263
// ============================================================================
// Optimization Decision Under Uncertainty
// ============================================================================
267
/// Per-optimization decision model with quantified uncertainty.
///
/// Tracks the posterior probability that the optimization helps, a running
/// distribution over its speedup, and the raw observation history.
#[derive(Debug, Clone)]
pub struct UncertainDecision {
    /// Name of the optimization pass this model covers.
    pub optimization: String,

    /// Posterior probability that applying the optimization improves
    /// performance (Beta, updated per observation).
    pub prob_improvement: BetaDistribution,

    /// Running distribution of the observed speedup factor.
    pub speedup_dist: NormalDistribution,

    /// Risk score; maintained as the variance of `prob_improvement`
    /// (see `observe`).
    pub risk: f64,

    /// Full history of observed outcomes.
    pub observations: Vec<Observation>,
}
286
/// Single observation of an optimization's effect.
#[derive(Debug, Clone)]
pub struct Observation {
    /// Whether the optimization improved performance (speedup > 1).
    pub beneficial: bool,

    /// Speedup factor achieved (predicted_time / actual_time at the
    /// call sites in this file).
    pub speedup: f64,

    /// Context features present when the observation was made.
    // NOTE(review): callers in this file always pass an empty map — the
    // feature schema is not established here.
    pub context: HashMap<String, f64>,
}
299
300impl UncertainDecision {
301    /// Create new decision with prior
302    pub fn new(optimization: String) -> Self {
303        Self {
304            optimization,
305            prob_improvement: BetaDistribution::new(1.0, 1.0), // Uniform prior
306            speedup_dist: NormalDistribution::new(1.2, 0.3),   // Modest expected speedup
307            risk: 0.5,
308            observations: Vec::new(),
309        }
310    }
311
312    /// Update with new observation
313    pub fn observe(&mut self, beneficial: bool, speedup: f64, context: HashMap<String, f64>) {
314        // Bayesian update
315        self.prob_improvement.update(beneficial);
316
317        // Update speedup distribution
318        if !self.observations.is_empty() {
319            let n = self.observations.len() as f64;
320            let old_mean = self.speedup_dist.mean;
321            let new_mean = (old_mean * n + speedup) / (n + 1.0);
322            self.speedup_dist.mean = new_mean;
323
324            // Update variance
325            let old_var = self.speedup_dist.std_dev.powi(2);
326            let new_var = ((old_var * n) + (speedup - new_mean).powi(2)) / (n + 1.0);
327            self.speedup_dist.std_dev = new_var.sqrt();
328        } else {
329            self.speedup_dist.mean = speedup;
330        }
331
332        self.observations.push(Observation {
333            beneficial,
334            speedup,
335            context,
336        });
337
338        // Update risk
339        self.risk = self.prob_improvement.variance();
340    }
341
342    /// Should apply this optimization? (Thompson sampling)
343    pub fn should_apply(&self) -> bool {
344        // Sample from posterior
345        let prob = self.prob_improvement.mean();
346        let threshold = 0.6; // Apply if >60% chance of improvement
347        prob > threshold
348    }
349
350    /// Expected utility (reward - risk)
351    pub fn expected_utility(&self, risk_aversion: f64) -> f64 {
352        let expected_reward = self.speedup_dist.mean * self.prob_improvement.mean();
353        let risk_penalty = risk_aversion * self.risk;
354        expected_reward - risk_penalty
355    }
356}
357
// ============================================================================
// Probabilistic Compiler
// ============================================================================
361
/// Main probabilistic compilation engine.
///
/// Owns a decision model per optimization pass, the most recent performance
/// model, and accumulated statistics.
pub struct ProbabilisticCompiler {
    /// Compilation behavior configuration.
    config: ProbabilisticConfig,

    /// Decision models keyed by optimization name.
    decisions: HashMap<String, UncertainDecision>,

    /// Performance model from the most recent `compile()` call;
    /// `None` until the first compilation.
    performance_model: Option<ProbabilisticPerformance>,

    /// Running statistics across compilations and observations.
    stats: CompilerStatistics,
}
376
/// Configuration for the probabilistic compiler.
#[derive(Debug, Clone)]
pub struct ProbabilisticConfig {
    /// Confidence level used when reporting intervals (e.g. 0.95).
    pub confidence_level: f64,

    /// Risk aversion factor (0 = risk-neutral, 1 = risk-averse).
    pub risk_aversion: f64,

    /// Number of Monte Carlo samples.
    // NOTE(review): not read by code in this file — monte_carlo_simulation
    // takes its own sample count; presumably consumed elsewhere.
    pub num_samples: usize,

    /// When true, `compile()` delegates apply/skip choices to
    /// `UncertainDecision::should_apply`; otherwise a simple >0.5
    /// posterior-mean threshold is used.
    pub bayesian_optimization: bool,

    /// Exploration rate.
    // NOTE(review): not read by code in this file — verify intended use.
    pub exploration_rate: f64,
}
395
396impl Default for ProbabilisticConfig {
397    fn default() -> Self {
398        Self {
399            confidence_level: 0.95,
400            risk_aversion: 0.5,
401            num_samples: 1000,
402            bayesian_optimization: true,
403            exploration_rate: 0.1,
404        }
405    }
406}
407
/// Running compiler statistics.
#[derive(Debug, Clone, Default)]
pub struct CompilerStatistics {
    /// Number of `compile()` calls completed.
    pub compilations: usize,

    /// Running average of relative prediction error
    /// (|actual - predicted| / predicted).
    pub avg_prediction_error: f64,

    /// Calibration score (how well stated confidence matches reality).
    // NOTE(review): never written by code in this file — stays at its
    // Default (0.0); verify whether another module maintains it.
    pub calibration_score: f64,
}
420
421impl ProbabilisticCompiler {
422    /// Create new probabilistic compiler
423    pub fn new() -> Self {
424        Self::with_config(ProbabilisticConfig::default())
425    }
426
427    /// Create with custom configuration
428    pub fn with_config(config: ProbabilisticConfig) -> Self {
429        let mut decisions = HashMap::new();
430
431        // Initialize decision models
432        for opt in [
433            "constant_folding",
434            "dead_code_elimination",
435            "fusion",
436            "vectorization",
437            "parallelization",
438            "tiling",
439        ] {
440            decisions.insert(opt.to_string(), UncertainDecision::new(opt.to_string()));
441        }
442
443        Self {
444            config,
445            decisions,
446            performance_model: None,
447            stats: CompilerStatistics::default(),
448        }
449    }
450
451    /// Compile with uncertainty quantification
452    pub fn compile(
453        &mut self,
454        graph: &ComputationGraph,
455    ) -> JitResult<ProbabilisticCompilationResult> {
456        // Estimate base performance
457        let node_count = graph.node_count() as f64;
458        let base_time = node_count * 10.0; // microseconds per node
459        let base_memory = node_count * 1024.0; // bytes per node
460
461        // Create performance model
462        let mut perf = ProbabilisticPerformance::new(base_time, base_memory);
463
464        // Decide which optimizations to apply
465        let mut applied_opts = Vec::new();
466        let mut decisions_made = Vec::new();
467
468        for (opt_name, decision_model) in &self.decisions {
469            let should_apply = if self.config.bayesian_optimization {
470                decision_model.should_apply()
471            } else {
472                decision_model.prob_improvement.mean() > 0.5
473            };
474
475            if should_apply {
476                applied_opts.push(opt_name.clone());
477
478                // Update performance estimate
479                let speedup = decision_model.speedup_dist.sample();
480                perf.time_dist.mean /= speedup;
481            }
482
483            decisions_made.push(OptimizationDecision {
484                optimization: opt_name.clone(),
485                applied: should_apply,
486                prob_improvement: decision_model.prob_improvement.mean(),
487                expected_speedup: decision_model.speedup_dist.mean,
488                credible_interval: decision_model
489                    .prob_improvement
490                    .credible_interval(self.config.confidence_level),
491            });
492        }
493
494        self.performance_model = Some(perf.clone());
495        self.stats.compilations += 1;
496
497        Ok(ProbabilisticCompilationResult {
498            performance: perf,
499            decisions: decisions_made,
500            applied_optimizations: applied_opts,
501            confidence_level: self.config.confidence_level,
502            risk_score: self.compute_overall_risk(),
503        })
504    }
505
506    /// Update models with actual performance
507    pub fn observe_performance(
508        &mut self,
509        applied_opts: &[String],
510        actual_time: f64,
511        predicted_time: f64,
512    ) -> JitResult<()> {
513        // Compute error
514        let error = (actual_time - predicted_time).abs() / predicted_time;
515
516        // Update prediction error statistics
517        let n = self.stats.compilations as f64;
518        self.stats.avg_prediction_error = (self.stats.avg_prediction_error * (n - 1.0) + error) / n;
519
520        // Update decision models
521        for opt_name in applied_opts {
522            if let Some(decision) = self.decisions.get_mut(opt_name) {
523                let speedup = predicted_time / actual_time;
524                let beneficial = speedup > 1.0;
525                decision.observe(beneficial, speedup, HashMap::new());
526            }
527        }
528
529        // Update performance model uncertainty
530        if let Some(perf_model) = &mut self.performance_model {
531            // Reduce uncertainty after more observations
532            let confidence_boost = 0.9; // Slightly reduce std_dev
533            perf_model.time_dist.std_dev *= confidence_boost;
534        }
535
536        log::info!(
537            "Observed performance: actual={:.2}μs, predicted={:.2}μs, error={:.1}%",
538            actual_time,
539            predicted_time,
540            error * 100.0
541        );
542
543        Ok(())
544    }
545
546    /// Compute overall risk score
547    fn compute_overall_risk(&self) -> f64 {
548        let mut total_risk = 0.0;
549        let mut count = 0;
550
551        for decision in self.decisions.values() {
552            total_risk += decision.risk;
553            count += 1;
554        }
555
556        if count > 0 {
557            total_risk / count as f64
558        } else {
559            0.5
560        }
561    }
562
563    /// Monte Carlo simulation
564    pub fn monte_carlo_simulation(&self, num_samples: usize) -> MonteCarloResult {
565        if let Some(perf_model) = &self.performance_model {
566            let mut samples = Vec::with_capacity(num_samples);
567
568            for _ in 0..num_samples {
569                let time = perf_model.sample_time();
570                samples.push(time);
571            }
572
573            samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
574
575            let mean = samples.iter().sum::<f64>() / num_samples as f64;
576            let variance =
577                samples.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / num_samples as f64;
578
579            let p50 = samples[num_samples / 2];
580            let p95 = samples[(num_samples as f64 * 0.95) as usize];
581            let p99 = samples[(num_samples as f64 * 0.99) as usize];
582
583            MonteCarloResult {
584                mean,
585                std_dev: variance.sqrt(),
586                percentiles: vec![(50, p50), (95, p95), (99, p99)],
587                samples: samples.into_iter().take(100).collect(), // Keep first 100 for visualization
588            }
589        } else {
590            MonteCarloResult::default()
591        }
592    }
593
594    /// Get statistics
595    pub fn statistics(&self) -> &CompilerStatistics {
596        &self.stats
597    }
598}
599
600impl Default for ProbabilisticCompiler {
601    fn default() -> Self {
602        Self::new()
603    }
604}
605
// ============================================================================
// Compilation Results
// ============================================================================
609
/// Result of a probabilistic compilation pass.
#[derive(Debug, Clone)]
pub struct ProbabilisticCompilationResult {
    /// Performance prediction with uncertainty.
    pub performance: ProbabilisticPerformance,

    /// One entry per optimization considered (applied or not).
    pub decisions: Vec<OptimizationDecision>,

    /// Names of the optimizations that were actually applied.
    pub applied_optimizations: Vec<String>,

    /// Confidence level used for the reported intervals.
    pub confidence_level: f64,

    /// Overall risk score: average posterior variance across decision models.
    pub risk_score: f64,
}
628
/// Snapshot of a single optimization decision at compile time.
#[derive(Debug, Clone)]
pub struct OptimizationDecision {
    /// Optimization name.
    pub optimization: String,

    /// Whether the optimization was applied.
    pub applied: bool,

    /// Posterior mean probability of improvement at decision time.
    pub prob_improvement: f64,

    /// Mean of the speedup distribution at decision time.
    pub expected_speedup: f64,

    /// Credible interval for the improvement probability,
    /// clamped to [0, 1].
    pub credible_interval: (f64, f64),
}
647
/// Monte Carlo simulation result.
#[derive(Debug, Clone, Default)]
pub struct MonteCarloResult {
    /// Mean of the drawn samples.
    pub mean: f64,

    /// Standard deviation of the drawn samples.
    pub std_dev: f64,

    /// (percentile, value) pairs — the producer emits 50/95/99.
    pub percentiles: Vec<(usize, f64)>,

    /// Subset of sample values (first 100, sorted ascending) kept for
    /// visualization.
    pub samples: Vec<f64>,
}
663
// ============================================================================
// Tests
// ============================================================================
667
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphBuilder;
    use torsh_core::{DType, Shape};

    /// Basic sanity checks on the normal distribution: stored parameters,
    /// a positive sample, and an interval that brackets the mean.
    #[test]
    fn test_normal_distribution() {
        let dist = NormalDistribution::new(100.0, 10.0);
        assert_eq!(dist.mean, 100.0);
        assert_eq!(dist.std_dev, 10.0);

        let sample = dist.sample();
        // For N(100, 10), a non-positive sample is a >10-sigma event.
        assert!(sample > 0.0); // Should be positive

        let (lower, upper) = dist.confidence_interval(0.95);
        assert!(lower < dist.mean);
        assert!(upper > dist.mean);
    }

    /// Beta distribution mean skew, Bayesian update, and interval ordering.
    #[test]
    fn test_beta_distribution() {
        let mut dist = BetaDistribution::new(10.0, 2.0);
        let mean = dist.mean();
        assert!(mean > 0.5); // Should be skewed towards success

        // A success observation bumps alpha by exactly one.
        dist.update(true);
        assert!(dist.alpha == 11.0);

        let (lower, upper) = dist.credible_interval(0.95);
        assert!(lower <= upper);
    }

    /// One positive observation moves the posterior and yields positive utility.
    #[test]
    fn test_uncertain_decision() {
        let mut decision = UncertainDecision::new("fusion".to_string());

        // Observe positive outcome
        decision.observe(true, 1.5, HashMap::new());
        assert!(decision.prob_improvement.alpha > 1.0);

        let utility = decision.expected_utility(0.5);
        assert!(utility > 0.0);
    }

    /// End-to-end compile on a trivial single-input graph.
    #[test]
    fn test_probabilistic_compilation() {
        let mut compiler = ProbabilisticCompiler::new();

        let mut builder = GraphBuilder::new();
        let x = builder.add_input("x".to_string(), Shape::new(vec![10, 10]), DType::F32);
        builder.mark_output(x).unwrap();

        let graph = builder.build().unwrap();
        let result = compiler.compile(&graph).unwrap();

        assert!(!result.decisions.is_empty());
        assert!(result.performance.expected_time() > 0.0);
        // Default config uses a 95% confidence level.
        assert!(result.confidence_level == 0.95);
    }

    /// Feeding back an observed time (10% slower than predicted) succeeds.
    #[test]
    fn test_performance_observation() {
        let mut compiler = ProbabilisticCompiler::new();

        let mut builder = GraphBuilder::new();
        let x = builder.add_input("x".to_string(), Shape::new(vec![5, 5]), DType::F32);
        builder.mark_output(x).unwrap();

        let graph = builder.build().unwrap();
        let result = compiler.compile(&graph).unwrap();

        let predicted = result.performance.expected_time();
        let actual = predicted * 1.1; // 10% slower

        let obs_result =
            compiler.observe_performance(&result.applied_optimizations, actual, predicted);
        assert!(obs_result.is_ok());
    }

    /// Monte Carlo after a compile produces sane summary statistics.
    #[test]
    fn test_monte_carlo() {
        let mut compiler = ProbabilisticCompiler::new();

        let mut builder = GraphBuilder::new();
        let x = builder.add_input("x".to_string(), Shape::new(vec![8, 8]), DType::F32);
        builder.mark_output(x).unwrap();

        let graph = builder.build().unwrap();
        let _ = compiler.compile(&graph).unwrap();

        let mc_result = compiler.monte_carlo_simulation(100);
        assert!(mc_result.mean > 0.0);
        assert!(mc_result.std_dev >= 0.0);
        assert!(!mc_result.percentiles.is_empty());
    }

    /// VaR grows with the requested percentile and exceeds the mean.
    #[test]
    fn test_value_at_risk() {
        let perf = ProbabilisticPerformance::new(1000.0, 10000.0);
        let var_95 = perf.value_at_risk(0.95);
        let var_99 = perf.value_at_risk(0.99);

        assert!(var_95 > perf.expected_time());
        assert!(var_99 > var_95); // Higher percentile = more conservative
    }
}