// tensorlogic_infer/speculative.rs

//! Speculative execution for computation graphs.
//!
//! This module implements speculative execution techniques:
//! - **Branch prediction**: Predict conditional branches and execute speculatively
//! - **Prefetching**: Pre-execute likely future operations
//! - **Rollback mechanisms**: Discard incorrect speculative results
//! - **Confidence scoring**: Track prediction accuracy
//! - **Adaptive strategies**: Learn from prediction success/failure
//!
//! ## Example
//!
//! ```rust,ignore
//! use tensorlogic_infer::{BranchOutcome, PredictionStrategy, RollbackPolicy, SpeculativeExecutor};
//!
//! // Create speculative executor
//! let mut executor = SpeculativeExecutor::new()
//!     .with_strategy(PredictionStrategy::HistoryBased)
//!     .with_rollback_policy(RollbackPolicy::Immediate)
//!     .with_confidence_threshold(0.7);
//!
//! // Speculate on a branch node, then validate once the real outcome is known
//! let task_id = executor.speculate("branch_node".to_string())?;
//! let was_correct = executor.validate(task_id, BranchOutcome::True)?;
//!
//! // Check speculation stats
//! let stats = executor.get_stats();
//! println!("Speculation success rate: {:.1}%", stats.success_rate * 100.0);
//! ```

use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use thiserror::Error;

33/// Speculative execution errors.
34#[derive(Error, Debug, Clone, PartialEq)]
35pub enum SpeculativeError {
36    #[error("Speculation failed: {0}")]
37    SpeculationFailed(String),
38
39    #[error("Rollback failed: {0}")]
40    RollbackFailed(String),
41
42    #[error("Invalid prediction: {0}")]
43    InvalidPrediction(String),
44
45    #[error("Checkpoint not found: {0}")]
46    CheckpointNotFound(String),
47}
48
/// Identifier of a node in the computation graph (a plain owned string).
pub type NodeId = String;

52/// Prediction strategy for speculative execution.
53#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
54pub enum PredictionStrategy {
55    /// Always predict most frequent branch
56    MostFrequent,
57    /// Use recent history to predict
58    HistoryBased,
59    /// Use static analysis and heuristics
60    Static,
61    /// Adaptive strategy that learns over time
62    Adaptive,
63    /// Always speculate on true branch
64    AlwaysTrue,
65    /// Never speculate (conservative)
66    Never,
67}
68
69/// Rollback policy when speculation is incorrect.
70#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
71pub enum RollbackPolicy {
72    /// Immediately rollback on misprediction
73    Immediate,
74    /// Continue speculation and rollback later
75    Lazy,
76    /// Checkpoint-based rollback
77    Checkpoint,
78}
79
80/// Branch outcome.
81#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
82pub enum BranchOutcome {
83    True,
84    False,
85    Unknown,
86}
87
88/// Speculative task representing work done speculatively.
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct SpeculativeTask {
91    pub task_id: u64,
92    pub node_id: NodeId,
93    pub predicted_branch: BranchOutcome,
94    pub confidence: f64,
95    pub started_at: u64, // timestamp in microseconds
96    pub completed: bool,
97    pub correct: Option<bool>, // None if not yet validated
98}
99
100/// Branch history entry.
101#[derive(Debug, Clone, Serialize, Deserialize)]
102struct BranchHistory {
103    node_id: NodeId,
104    outcomes: VecDeque<BranchOutcome>,
105    max_history: usize,
106}
107
108impl BranchHistory {
109    fn new(node_id: NodeId, max_history: usize) -> Self {
110        Self {
111            node_id,
112            outcomes: VecDeque::new(),
113            max_history,
114        }
115    }
116
117    fn add_outcome(&mut self, outcome: BranchOutcome) {
118        if self.outcomes.len() >= self.max_history {
119            self.outcomes.pop_front();
120        }
121        self.outcomes.push_back(outcome);
122    }
123
124    fn predict(&self) -> (BranchOutcome, f64) {
125        if self.outcomes.is_empty() {
126            return (BranchOutcome::Unknown, 0.0);
127        }
128
129        let true_count = self
130            .outcomes
131            .iter()
132            .filter(|&&o| o == BranchOutcome::True)
133            .count();
134        let false_count = self
135            .outcomes
136            .iter()
137            .filter(|&&o| o == BranchOutcome::False)
138            .count();
139        let total = true_count + false_count;
140
141        if total == 0 {
142            return (BranchOutcome::Unknown, 0.0);
143        }
144
145        if true_count > false_count {
146            (BranchOutcome::True, true_count as f64 / total as f64)
147        } else {
148            (BranchOutcome::False, false_count as f64 / total as f64)
149        }
150    }
151}
152
153/// Speculation statistics.
154#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct SpeculationStats {
156    pub total_speculations: usize,
157    pub correct_speculations: usize,
158    pub incorrect_speculations: usize,
159    pub rollbacks: usize,
160    pub success_rate: f64,
161    pub average_confidence: f64,
162    pub time_saved_us: f64,
163    pub time_wasted_us: f64,
164}
165
166/// Checkpoint for rollback.
167#[derive(Debug, Clone, Serialize, Deserialize)]
168struct Checkpoint {
169    checkpoint_id: u64,
170    node_id: NodeId,
171    timestamp: u64,
172    // In real implementation, this would store actual state
173}
174
175/// Speculative executor.
176pub struct SpeculativeExecutor {
177    strategy: PredictionStrategy,
178    rollback_policy: RollbackPolicy,
179    confidence_threshold: f64,
180    max_speculation_depth: usize,
181    branch_history: HashMap<NodeId, BranchHistory>,
182    active_tasks: HashMap<u64, SpeculativeTask>,
183    checkpoints: HashMap<u64, Checkpoint>,
184    next_task_id: u64,
185    next_checkpoint_id: u64,
186    stats: SpeculationStats,
187    history_length: usize,
188}
189
190impl SpeculativeExecutor {
191    /// Create a new speculative executor with default settings.
192    pub fn new() -> Self {
193        Self {
194            strategy: PredictionStrategy::HistoryBased,
195            rollback_policy: RollbackPolicy::Immediate,
196            confidence_threshold: 0.6,
197            max_speculation_depth: 3,
198            branch_history: HashMap::new(),
199            active_tasks: HashMap::new(),
200            checkpoints: HashMap::new(),
201            next_task_id: 0,
202            next_checkpoint_id: 0,
203            stats: SpeculationStats {
204                total_speculations: 0,
205                correct_speculations: 0,
206                incorrect_speculations: 0,
207                rollbacks: 0,
208                success_rate: 0.0,
209                average_confidence: 0.0,
210                time_saved_us: 0.0,
211                time_wasted_us: 0.0,
212            },
213            history_length: 10,
214        }
215    }
216
217    /// Set prediction strategy.
218    pub fn with_strategy(mut self, strategy: PredictionStrategy) -> Self {
219        self.strategy = strategy;
220        self
221    }
222
223    /// Set rollback policy.
224    pub fn with_rollback_policy(mut self, policy: RollbackPolicy) -> Self {
225        self.rollback_policy = policy;
226        self
227    }
228
229    /// Set confidence threshold for speculation.
230    pub fn with_confidence_threshold(mut self, threshold: f64) -> Self {
231        self.confidence_threshold = threshold.clamp(0.0, 1.0);
232        self
233    }
234
235    /// Set maximum speculation depth.
236    pub fn with_max_depth(mut self, depth: usize) -> Self {
237        self.max_speculation_depth = depth;
238        self
239    }
240
241    /// Predict branch outcome for a node.
242    pub fn predict_branch(&self, node_id: &NodeId) -> (BranchOutcome, f64) {
243        match self.strategy {
244            PredictionStrategy::Never => (BranchOutcome::Unknown, 0.0),
245            PredictionStrategy::AlwaysTrue => (BranchOutcome::True, 1.0),
246            PredictionStrategy::MostFrequent => {
247                if let Some(history) = self.branch_history.get(node_id) {
248                    history.predict()
249                } else {
250                    (BranchOutcome::True, 0.5) // Default to true with low confidence
251                }
252            }
253            PredictionStrategy::HistoryBased => {
254                if let Some(history) = self.branch_history.get(node_id) {
255                    history.predict()
256                } else {
257                    (BranchOutcome::Unknown, 0.0)
258                }
259            }
260            PredictionStrategy::Static | PredictionStrategy::Adaptive => {
261                // Simplified: use history if available
262                if let Some(history) = self.branch_history.get(node_id) {
263                    history.predict()
264                } else {
265                    (BranchOutcome::True, 0.5)
266                }
267            }
268        }
269    }
270
271    /// Start speculative execution for a branch.
272    pub fn speculate(&mut self, node_id: NodeId) -> Result<u64, SpeculativeError> {
273        let (predicted_branch, confidence) = self.predict_branch(&node_id);
274
275        // Only speculate if confidence exceeds threshold
276        if confidence < self.confidence_threshold {
277            return Err(SpeculativeError::SpeculationFailed(format!(
278                "Confidence {} below threshold {}",
279                confidence, self.confidence_threshold
280            )));
281        }
282
283        // Check speculation depth
284        let active_count = self.active_tasks.values().filter(|t| !t.completed).count();
285
286        if active_count >= self.max_speculation_depth {
287            return Err(SpeculativeError::SpeculationFailed(format!(
288                "Maximum speculation depth {} reached",
289                self.max_speculation_depth
290            )));
291        }
292
293        // Create speculative task
294        let task_id = self.next_task_id;
295        self.next_task_id += 1;
296
297        let task = SpeculativeTask {
298            task_id,
299            node_id: node_id.clone(),
300            predicted_branch,
301            confidence,
302            started_at: 0, // Would be real timestamp
303            completed: false,
304            correct: None,
305        };
306
307        self.active_tasks.insert(task_id, task);
308        self.stats.total_speculations += 1;
309
310        Ok(task_id)
311    }
312
313    /// Validate speculative execution result.
314    pub fn validate(
315        &mut self,
316        task_id: u64,
317        actual_branch: BranchOutcome,
318    ) -> Result<bool, SpeculativeError> {
319        let task = self.active_tasks.get_mut(&task_id).ok_or_else(|| {
320            SpeculativeError::InvalidPrediction(format!("Task {} not found", task_id))
321        })?;
322
323        let correct = task.predicted_branch == actual_branch;
324        task.correct = Some(correct);
325        task.completed = true;
326
327        // Update history
328        let history = self
329            .branch_history
330            .entry(task.node_id.clone())
331            .or_insert_with(|| BranchHistory::new(task.node_id.clone(), self.history_length));
332        history.add_outcome(actual_branch);
333
334        // Update stats
335        if correct {
336            self.stats.correct_speculations += 1;
337        } else {
338            self.stats.incorrect_speculations += 1;
339            // Perform rollback if needed
340            self.rollback(task_id)?;
341        }
342
343        self.update_stats();
344
345        Ok(correct)
346    }
347
348    /// Rollback speculative execution.
349    fn rollback(&mut self, task_id: u64) -> Result<(), SpeculativeError> {
350        match self.rollback_policy {
351            RollbackPolicy::Immediate => {
352                // Immediately discard speculative work
353                self.active_tasks.remove(&task_id);
354                self.stats.rollbacks += 1;
355                Ok(())
356            }
357            RollbackPolicy::Lazy => {
358                // Mark for later cleanup
359                if let Some(task) = self.active_tasks.get_mut(&task_id) {
360                    task.completed = true;
361                }
362                self.stats.rollbacks += 1;
363                Ok(())
364            }
365            RollbackPolicy::Checkpoint => {
366                // Restore from checkpoint
367                self.restore_checkpoint(task_id)?;
368                self.stats.rollbacks += 1;
369                Ok(())
370            }
371        }
372    }
373
374    /// Create checkpoint before speculation.
375    pub fn create_checkpoint(&mut self, node_id: NodeId) -> u64 {
376        let checkpoint_id = self.next_checkpoint_id;
377        self.next_checkpoint_id += 1;
378
379        let checkpoint = Checkpoint {
380            checkpoint_id,
381            node_id,
382            timestamp: 0, // Would be real timestamp
383        };
384
385        self.checkpoints.insert(checkpoint_id, checkpoint);
386        checkpoint_id
387    }
388
389    /// Restore from checkpoint.
390    fn restore_checkpoint(&mut self, task_id: u64) -> Result<(), SpeculativeError> {
391        // Find and restore checkpoint
392        let _task = self.active_tasks.get(&task_id).ok_or_else(|| {
393            SpeculativeError::CheckpointNotFound(format!("No task found for id: {}", task_id))
394        })?;
395
396        // In real implementation, would restore actual state
397        self.active_tasks.remove(&task_id);
398        Ok(())
399    }
400
401    /// Update speculation statistics.
402    fn update_stats(&mut self) {
403        let total = (self.stats.correct_speculations + self.stats.incorrect_speculations) as f64;
404        if total > 0.0 {
405            self.stats.success_rate = self.stats.correct_speculations as f64 / total;
406        }
407
408        let confidence_sum: f64 = self.active_tasks.values().map(|t| t.confidence).sum();
409        let task_count = self.active_tasks.len() as f64;
410        if task_count > 0.0 {
411            self.stats.average_confidence = confidence_sum / task_count;
412        }
413    }
414
415    /// Get speculation statistics.
416    pub fn get_stats(&self) -> &SpeculationStats {
417        &self.stats
418    }
419
420    /// Clear completed speculative tasks.
421    pub fn cleanup(&mut self) {
422        self.active_tasks.retain(|_, task| !task.completed);
423    }
424
425    /// Reset statistics.
426    pub fn reset_stats(&mut self) {
427        self.stats = SpeculationStats {
428            total_speculations: 0,
429            correct_speculations: 0,
430            incorrect_speculations: 0,
431            rollbacks: 0,
432            success_rate: 0.0,
433            average_confidence: 0.0,
434            time_saved_us: 0.0,
435            time_wasted_us: 0.0,
436        };
437    }
438
439    /// Get active speculation count.
440    pub fn active_speculation_count(&self) -> usize {
441        self.active_tasks.values().filter(|t| !t.completed).count()
442    }
443
444    /// Check if should speculate based on current state.
445    pub fn should_speculate(&self, node_id: &NodeId) -> bool {
446        let (_, confidence) = self.predict_branch(node_id);
447        confidence >= self.confidence_threshold
448            && self.active_speculation_count() < self.max_speculation_depth
449    }
450}
451
452impl Default for SpeculativeExecutor {
453    fn default() -> Self {
454        Self::new()
455    }
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// Executor that always predicts `True` and accepts any confidence of at least 0.5.
    fn always_true_executor() -> SpeculativeExecutor {
        SpeculativeExecutor::new()
            .with_strategy(PredictionStrategy::AlwaysTrue)
            .with_confidence_threshold(0.5)
    }

    /// Speculate on `node` and immediately validate against `actual`; returns the verdict.
    fn speculate_and_validate(
        exec: &mut SpeculativeExecutor,
        node: &str,
        actual: BranchOutcome,
    ) -> bool {
        let id = exec.speculate(node.to_string()).unwrap();
        exec.validate(id, actual).unwrap()
    }

    #[test]
    fn test_speculative_executor_creation() {
        let exec = SpeculativeExecutor::new();
        assert_eq!(exec.strategy, PredictionStrategy::HistoryBased);
        assert_eq!(exec.rollback_policy, RollbackPolicy::Immediate);
        assert_eq!(exec.confidence_threshold, 0.6);
    }

    #[test]
    fn test_builder_pattern() {
        let exec = SpeculativeExecutor::new()
            .with_strategy(PredictionStrategy::Adaptive)
            .with_rollback_policy(RollbackPolicy::Checkpoint)
            .with_confidence_threshold(0.8)
            .with_max_depth(5);

        assert_eq!(exec.strategy, PredictionStrategy::Adaptive);
        assert_eq!(exec.rollback_policy, RollbackPolicy::Checkpoint);
        assert_eq!(exec.confidence_threshold, 0.8);
        assert_eq!(exec.max_speculation_depth, 5);
    }

    #[test]
    fn test_always_true_prediction() {
        let exec = SpeculativeExecutor::new().with_strategy(PredictionStrategy::AlwaysTrue);
        let node: NodeId = "test".into();

        let (outcome, confidence) = exec.predict_branch(&node);
        assert_eq!(outcome, BranchOutcome::True);
        assert_eq!(confidence, 1.0);
    }

    #[test]
    fn test_never_speculation() {
        let exec = SpeculativeExecutor::new().with_strategy(PredictionStrategy::Never);
        let node: NodeId = "test".into();

        let (outcome, confidence) = exec.predict_branch(&node);
        assert_eq!(outcome, BranchOutcome::Unknown);
        assert_eq!(confidence, 0.0);
    }

    #[test]
    fn test_speculation_below_threshold() {
        let mut exec = SpeculativeExecutor::new().with_confidence_threshold(0.9);

        // Default HistoryBased strategy with no history: confidence is too low.
        assert!(exec.speculate("test".to_string()).is_err());
    }

    #[test]
    fn test_successful_speculation() {
        let mut exec = always_true_executor();

        let id = exec.speculate("test".to_string()).unwrap();
        assert_eq!(exec.stats.total_speculations, 1);
        assert!(exec.active_tasks.contains_key(&id));
    }

    #[test]
    fn test_correct_validation() {
        let mut exec = always_true_executor();

        let verdict = speculate_and_validate(&mut exec, "test", BranchOutcome::True);

        assert!(verdict);
        assert_eq!(exec.stats.correct_speculations, 1);
        assert_eq!(exec.stats.incorrect_speculations, 0);
    }

    #[test]
    fn test_incorrect_validation() {
        let mut exec = always_true_executor();

        let verdict = speculate_and_validate(&mut exec, "test", BranchOutcome::False);

        assert!(!verdict);
        assert_eq!(exec.stats.correct_speculations, 0);
        assert_eq!(exec.stats.incorrect_speculations, 1);
        assert_eq!(exec.stats.rollbacks, 1);
    }

    #[test]
    fn test_history_based_prediction() {
        // AlwaysTrue lets every speculation start while the history is being built.
        let mut exec = always_true_executor();

        // Mostly-true history: 8 x True followed by 2 x False.
        let observed = std::iter::repeat(BranchOutcome::True)
            .take(8)
            .chain(std::iter::repeat(BranchOutcome::False).take(2));
        for actual in observed {
            speculate_and_validate(&mut exec, "node1", actual);
        }

        exec.strategy = PredictionStrategy::HistoryBased;

        let (outcome, confidence) = exec.predict_branch(&"node1".to_string());
        assert_eq!(outcome, BranchOutcome::True);
        assert!(confidence > 0.7);
    }

    #[test]
    fn test_max_speculation_depth() {
        let mut exec = always_true_executor().with_max_depth(2);

        for node in &["node1", "node2"] {
            exec.speculate(node.to_string()).unwrap();
        }

        // The depth limit is reached, so a third speculation is refused.
        assert!(exec.speculate("node3".to_string()).is_err());
    }

    #[test]
    fn test_checkpoint_creation() {
        let mut exec = SpeculativeExecutor::new();
        let id = exec.create_checkpoint("node1".to_string());

        assert!(exec.checkpoints.contains_key(&id));
    }

    #[test]
    fn test_cleanup() {
        let mut exec = always_true_executor();

        let id = exec.speculate("test".to_string()).unwrap();
        exec.validate(id, BranchOutcome::True).unwrap();

        assert!(exec.active_tasks.contains_key(&id));
        exec.cleanup();
        assert!(!exec.active_tasks.contains_key(&id));
    }

    #[test]
    fn test_success_rate_calculation() {
        let mut exec = always_true_executor();

        // 3 correct, 1 incorrect = 75% success rate
        for &actual in &[
            BranchOutcome::True,
            BranchOutcome::True,
            BranchOutcome::True,
            BranchOutcome::False,
        ] {
            speculate_and_validate(&mut exec, "test", actual);
        }

        assert!((exec.stats.success_rate - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_reset_stats() {
        let mut exec = always_true_executor();

        speculate_and_validate(&mut exec, "test", BranchOutcome::True);
        assert_eq!(exec.stats.total_speculations, 1);

        exec.reset_stats();
        assert_eq!(exec.stats.total_speculations, 0);
        assert_eq!(exec.stats.correct_speculations, 0);
    }

    #[test]
    fn test_should_speculate() {
        let mut exec = always_true_executor();
        let probe: NodeId = "test".to_string();

        assert!(exec.should_speculate(&probe));

        // Saturate the speculation depth.
        let depth = exec.max_speculation_depth;
        for i in 0..depth {
            exec.speculate(format!("node{}", i)).unwrap();
        }

        assert!(!exec.should_speculate(&probe));
    }

    #[test]
    fn test_active_speculation_count() {
        let mut exec = always_true_executor();

        assert_eq!(exec.active_speculation_count(), 0);

        for (expected, node) in (1usize..).zip(&["node1", "node2"]) {
            exec.speculate(node.to_string()).unwrap();
            assert_eq!(exec.active_speculation_count(), expected);
        }
    }

    #[test]
    fn test_different_rollback_policies() {
        for &policy in &[
            RollbackPolicy::Immediate,
            RollbackPolicy::Lazy,
            RollbackPolicy::Checkpoint,
        ] {
            let mut exec = always_true_executor().with_rollback_policy(policy);

            speculate_and_validate(&mut exec, "test", BranchOutcome::False);

            assert_eq!(exec.stats.rollbacks, 1);
        }
    }
}