1use serde::{Deserialize, Serialize};
30use std::collections::{HashMap, VecDeque};
31use thiserror::Error;
32
/// Errors produced by [`SpeculativeExecutor`].
#[derive(Error, Debug, Clone, PartialEq)]
pub enum SpeculativeError {
    /// A speculation could not be started (for example the prediction confidence was
    /// below the threshold, or the maximum speculation depth was reached).
    #[error("Speculation failed: {0}")]
    SpeculationFailed(String),

    /// A rollback could not be completed.
    /// NOTE(review): not constructed anywhere in this module.
    #[error("Rollback failed: {0}")]
    RollbackFailed(String),

    /// A task could not be validated (for example its id is not among the active tasks).
    #[error("Invalid prediction: {0}")]
    InvalidPrediction(String),

    /// Returned by the checkpoint rollback path when the task to restore is not active.
    #[error("Checkpoint not found: {0}")]
    CheckpointNotFound(String),
}
48
/// Identifier of a node whose branch outcomes are predicted and tracked (a plain string key).
pub type NodeId = String;
51
/// How [`SpeculativeExecutor::predict_branch`] chooses a branch for a node.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PredictionStrategy {
    /// Majority outcome from the node's history; `(True, 0.5)` when the node has no history.
    MostFrequent,
    /// Majority outcome from the node's history; `(Unknown, 0.0)` when the node has no history.
    HistoryBased,
    /// Currently handled exactly like `MostFrequent`.
    Static,
    /// Currently handled exactly like `MostFrequent`.
    Adaptive,
    /// Always predicts `True` with confidence `1.0`.
    AlwaysTrue,
    /// Never predicts: always `(Unknown, 0.0)`.
    Never,
}
68
/// How a mispredicted task is undone when [`SpeculativeExecutor::validate`] finds it wrong.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RollbackPolicy {
    /// Remove the task from the active set right away.
    Immediate,
    /// Keep the task, marked completed, until `SpeculativeExecutor::cleanup` removes it.
    Lazy,
    /// Roll back through the checkpoint path.
    /// NOTE(review): this currently only removes the task; stored checkpoints are not consulted.
    Checkpoint,
}
79
/// A branch direction, either predicted or actually taken.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BranchOutcome {
    /// The `true` branch was (or is predicted to be) taken.
    True,
    /// The `false` branch was (or is predicted to be) taken.
    False,
    /// No outcome is available; predictors pair this with confidence `0.0`.
    Unknown,
}
87
/// One speculative execution started by [`SpeculativeExecutor::speculate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpeculativeTask {
    /// Unique id assigned from the executor's task counter.
    pub task_id: u64,
    /// Node whose branch is being speculated on.
    pub node_id: NodeId,
    /// Branch the task was started on.
    pub predicted_branch: BranchOutcome,
    /// Confidence of the prediction at the time the task was started.
    pub confidence: f64,
    /// Start time. NOTE(review): `speculate` always stores 0; no clock is read yet.
    pub started_at: u64,
    /// Set to `true` once the task has been validated.
    pub completed: bool,
    /// `None` until validated, then whether the prediction matched the actual branch.
    pub correct: Option<bool>,
}
99
/// Sliding window of the most recent resolved outcomes recorded for one node.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct BranchHistory {
    /// Node this history belongs to.
    node_id: NodeId,
    /// Outcomes in arrival order: new ones are pushed at the back, old ones evicted from the front.
    outcomes: VecDeque<BranchOutcome>,
    /// Capacity of the window; `add_outcome` evicts from the front once it is reached.
    max_history: usize,
}
107
108impl BranchHistory {
109 fn new(node_id: NodeId, max_history: usize) -> Self {
110 Self {
111 node_id,
112 outcomes: VecDeque::new(),
113 max_history,
114 }
115 }
116
117 fn add_outcome(&mut self, outcome: BranchOutcome) {
118 if self.outcomes.len() >= self.max_history {
119 self.outcomes.pop_front();
120 }
121 self.outcomes.push_back(outcome);
122 }
123
124 fn predict(&self) -> (BranchOutcome, f64) {
125 if self.outcomes.is_empty() {
126 return (BranchOutcome::Unknown, 0.0);
127 }
128
129 let true_count = self
130 .outcomes
131 .iter()
132 .filter(|&&o| o == BranchOutcome::True)
133 .count();
134 let false_count = self
135 .outcomes
136 .iter()
137 .filter(|&&o| o == BranchOutcome::False)
138 .count();
139 let total = true_count + false_count;
140
141 if total == 0 {
142 return (BranchOutcome::Unknown, 0.0);
143 }
144
145 if true_count > false_count {
146 (BranchOutcome::True, true_count as f64 / total as f64)
147 } else {
148 (BranchOutcome::False, false_count as f64 / total as f64)
149 }
150 }
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct SpeculationStats {
156 pub total_speculations: usize,
157 pub correct_speculations: usize,
158 pub incorrect_speculations: usize,
159 pub rollbacks: usize,
160 pub success_rate: f64,
161 pub average_confidence: f64,
162 pub time_saved_us: f64,
163 pub time_wasted_us: f64,
164}
165
/// A checkpoint recorded by `SpeculativeExecutor::create_checkpoint`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Checkpoint {
    /// Identifier returned to the caller of `create_checkpoint`.
    checkpoint_id: u64,
    /// Node the checkpoint was taken for.
    node_id: NodeId,
    /// Creation time. NOTE(review): always stored as 0; no clock is read yet.
    timestamp: u64,
}
174
/// Predicts branch outcomes, starts speculative tasks on the predicted branch, and
/// validates or rolls them back once the real outcome is known.
pub struct SpeculativeExecutor {
    /// Strategy used by `predict_branch`.
    strategy: PredictionStrategy,
    /// How mispredicted tasks are rolled back.
    rollback_policy: RollbackPolicy,
    /// Minimum prediction confidence `speculate` requires (clamped to `[0, 1]` by the builder).
    confidence_threshold: f64,
    /// Maximum number of started-but-not-completed tasks allowed at once.
    max_speculation_depth: usize,
    /// Per-node outcome history, filled in by `validate`.
    branch_history: HashMap<NodeId, BranchHistory>,
    /// Tasks keyed by task id. Completed tasks stay here until `cleanup` unless a
    /// rollback removed them.
    active_tasks: HashMap<u64, SpeculativeTask>,
    /// Checkpoints recorded by `create_checkpoint`, keyed by checkpoint id.
    checkpoints: HashMap<u64, Checkpoint>,
    /// Id handed to the next task; incremented on every successful `speculate`.
    next_task_id: u64,
    /// Id handed to the next checkpoint; incremented on every `create_checkpoint`.
    next_checkpoint_id: u64,
    /// Running statistics.
    stats: SpeculationStats,
    /// Window size used when a new per-node `BranchHistory` is created.
    history_length: usize,
}
189
190impl SpeculativeExecutor {
191 pub fn new() -> Self {
193 Self {
194 strategy: PredictionStrategy::HistoryBased,
195 rollback_policy: RollbackPolicy::Immediate,
196 confidence_threshold: 0.6,
197 max_speculation_depth: 3,
198 branch_history: HashMap::new(),
199 active_tasks: HashMap::new(),
200 checkpoints: HashMap::new(),
201 next_task_id: 0,
202 next_checkpoint_id: 0,
203 stats: SpeculationStats {
204 total_speculations: 0,
205 correct_speculations: 0,
206 incorrect_speculations: 0,
207 rollbacks: 0,
208 success_rate: 0.0,
209 average_confidence: 0.0,
210 time_saved_us: 0.0,
211 time_wasted_us: 0.0,
212 },
213 history_length: 10,
214 }
215 }
216
217 pub fn with_strategy(mut self, strategy: PredictionStrategy) -> Self {
219 self.strategy = strategy;
220 self
221 }
222
223 pub fn with_rollback_policy(mut self, policy: RollbackPolicy) -> Self {
225 self.rollback_policy = policy;
226 self
227 }
228
229 pub fn with_confidence_threshold(mut self, threshold: f64) -> Self {
231 self.confidence_threshold = threshold.clamp(0.0, 1.0);
232 self
233 }
234
235 pub fn with_max_depth(mut self, depth: usize) -> Self {
237 self.max_speculation_depth = depth;
238 self
239 }
240
241 pub fn predict_branch(&self, node_id: &NodeId) -> (BranchOutcome, f64) {
243 match self.strategy {
244 PredictionStrategy::Never => (BranchOutcome::Unknown, 0.0),
245 PredictionStrategy::AlwaysTrue => (BranchOutcome::True, 1.0),
246 PredictionStrategy::MostFrequent => {
247 if let Some(history) = self.branch_history.get(node_id) {
248 history.predict()
249 } else {
250 (BranchOutcome::True, 0.5) }
252 }
253 PredictionStrategy::HistoryBased => {
254 if let Some(history) = self.branch_history.get(node_id) {
255 history.predict()
256 } else {
257 (BranchOutcome::Unknown, 0.0)
258 }
259 }
260 PredictionStrategy::Static | PredictionStrategy::Adaptive => {
261 if let Some(history) = self.branch_history.get(node_id) {
263 history.predict()
264 } else {
265 (BranchOutcome::True, 0.5)
266 }
267 }
268 }
269 }
270
271 pub fn speculate(&mut self, node_id: NodeId) -> Result<u64, SpeculativeError> {
273 let (predicted_branch, confidence) = self.predict_branch(&node_id);
274
275 if confidence < self.confidence_threshold {
277 return Err(SpeculativeError::SpeculationFailed(format!(
278 "Confidence {} below threshold {}",
279 confidence, self.confidence_threshold
280 )));
281 }
282
283 let active_count = self.active_tasks.values().filter(|t| !t.completed).count();
285
286 if active_count >= self.max_speculation_depth {
287 return Err(SpeculativeError::SpeculationFailed(format!(
288 "Maximum speculation depth {} reached",
289 self.max_speculation_depth
290 )));
291 }
292
293 let task_id = self.next_task_id;
295 self.next_task_id += 1;
296
297 let task = SpeculativeTask {
298 task_id,
299 node_id: node_id.clone(),
300 predicted_branch,
301 confidence,
302 started_at: 0, completed: false,
304 correct: None,
305 };
306
307 self.active_tasks.insert(task_id, task);
308 self.stats.total_speculations += 1;
309
310 Ok(task_id)
311 }
312
313 pub fn validate(
315 &mut self,
316 task_id: u64,
317 actual_branch: BranchOutcome,
318 ) -> Result<bool, SpeculativeError> {
319 let task = self.active_tasks.get_mut(&task_id).ok_or_else(|| {
320 SpeculativeError::InvalidPrediction(format!("Task {} not found", task_id))
321 })?;
322
323 let correct = task.predicted_branch == actual_branch;
324 task.correct = Some(correct);
325 task.completed = true;
326
327 let history = self
329 .branch_history
330 .entry(task.node_id.clone())
331 .or_insert_with(|| BranchHistory::new(task.node_id.clone(), self.history_length));
332 history.add_outcome(actual_branch);
333
334 if correct {
336 self.stats.correct_speculations += 1;
337 } else {
338 self.stats.incorrect_speculations += 1;
339 self.rollback(task_id)?;
341 }
342
343 self.update_stats();
344
345 Ok(correct)
346 }
347
348 fn rollback(&mut self, task_id: u64) -> Result<(), SpeculativeError> {
350 match self.rollback_policy {
351 RollbackPolicy::Immediate => {
352 self.active_tasks.remove(&task_id);
354 self.stats.rollbacks += 1;
355 Ok(())
356 }
357 RollbackPolicy::Lazy => {
358 if let Some(task) = self.active_tasks.get_mut(&task_id) {
360 task.completed = true;
361 }
362 self.stats.rollbacks += 1;
363 Ok(())
364 }
365 RollbackPolicy::Checkpoint => {
366 self.restore_checkpoint(task_id)?;
368 self.stats.rollbacks += 1;
369 Ok(())
370 }
371 }
372 }
373
374 pub fn create_checkpoint(&mut self, node_id: NodeId) -> u64 {
376 let checkpoint_id = self.next_checkpoint_id;
377 self.next_checkpoint_id += 1;
378
379 let checkpoint = Checkpoint {
380 checkpoint_id,
381 node_id,
382 timestamp: 0, };
384
385 self.checkpoints.insert(checkpoint_id, checkpoint);
386 checkpoint_id
387 }
388
389 fn restore_checkpoint(&mut self, task_id: u64) -> Result<(), SpeculativeError> {
391 let _task = self.active_tasks.get(&task_id).ok_or_else(|| {
393 SpeculativeError::CheckpointNotFound(format!("No task found for id: {}", task_id))
394 })?;
395
396 self.active_tasks.remove(&task_id);
398 Ok(())
399 }
400
401 fn update_stats(&mut self) {
403 let total = (self.stats.correct_speculations + self.stats.incorrect_speculations) as f64;
404 if total > 0.0 {
405 self.stats.success_rate = self.stats.correct_speculations as f64 / total;
406 }
407
408 let confidence_sum: f64 = self.active_tasks.values().map(|t| t.confidence).sum();
409 let task_count = self.active_tasks.len() as f64;
410 if task_count > 0.0 {
411 self.stats.average_confidence = confidence_sum / task_count;
412 }
413 }
414
415 pub fn get_stats(&self) -> &SpeculationStats {
417 &self.stats
418 }
419
420 pub fn cleanup(&mut self) {
422 self.active_tasks.retain(|_, task| !task.completed);
423 }
424
425 pub fn reset_stats(&mut self) {
427 self.stats = SpeculationStats {
428 total_speculations: 0,
429 correct_speculations: 0,
430 incorrect_speculations: 0,
431 rollbacks: 0,
432 success_rate: 0.0,
433 average_confidence: 0.0,
434 time_saved_us: 0.0,
435 time_wasted_us: 0.0,
436 };
437 }
438
439 pub fn active_speculation_count(&self) -> usize {
441 self.active_tasks.values().filter(|t| !t.completed).count()
442 }
443
444 pub fn should_speculate(&self, node_id: &NodeId) -> bool {
446 let (_, confidence) = self.predict_branch(node_id);
447 confidence >= self.confidence_threshold
448 && self.active_speculation_count() < self.max_speculation_depth
449 }
450}
451
452impl Default for SpeculativeExecutor {
453 fn default() -> Self {
454 Self::new()
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// Executor that always predicts `True` with full confidence, so every speculation
    /// passes the threshold check.
    fn always_true_executor() -> SpeculativeExecutor {
        SpeculativeExecutor::new()
            .with_strategy(PredictionStrategy::AlwaysTrue)
            .with_confidence_threshold(0.5)
    }

    #[test]
    fn test_speculative_executor_creation() {
        let executor = SpeculativeExecutor::new();
        assert_eq!(executor.strategy, PredictionStrategy::HistoryBased);
        assert_eq!(executor.rollback_policy, RollbackPolicy::Immediate);
        assert_eq!(executor.confidence_threshold, 0.6);
        assert_eq!(executor.max_speculation_depth, 3);
        assert_eq!(executor.active_speculation_count(), 0);
    }

    #[test]
    fn test_builder_pattern() {
        let executor = SpeculativeExecutor::new()
            .with_strategy(PredictionStrategy::Adaptive)
            .with_rollback_policy(RollbackPolicy::Checkpoint)
            .with_confidence_threshold(0.8)
            .with_max_depth(5);

        assert_eq!(executor.strategy, PredictionStrategy::Adaptive);
        assert_eq!(executor.rollback_policy, RollbackPolicy::Checkpoint);
        assert_eq!(executor.confidence_threshold, 0.8);
        assert_eq!(executor.max_speculation_depth, 5);
    }

    #[test]
    fn test_confidence_threshold_is_clamped() {
        let high = SpeculativeExecutor::new().with_confidence_threshold(1.5);
        let low = SpeculativeExecutor::new().with_confidence_threshold(-0.5);
        assert_eq!(high.confidence_threshold, 1.0);
        assert_eq!(low.confidence_threshold, 0.0);
    }

    #[test]
    fn test_always_true_prediction() {
        let executor = SpeculativeExecutor::new().with_strategy(PredictionStrategy::AlwaysTrue);

        let (outcome, confidence) = executor.predict_branch(&"test".to_string());
        assert_eq!(outcome, BranchOutcome::True);
        assert_eq!(confidence, 1.0);
    }

    #[test]
    fn test_never_speculation() {
        let executor = SpeculativeExecutor::new().with_strategy(PredictionStrategy::Never);

        let (outcome, confidence) = executor.predict_branch(&"test".to_string());
        assert_eq!(outcome, BranchOutcome::Unknown);
        assert_eq!(confidence, 0.0);
    }

    #[test]
    fn test_speculation_below_threshold() {
        // Default HistoryBased strategy with no history predicts (Unknown, 0.0).
        let mut executor = SpeculativeExecutor::new().with_confidence_threshold(0.9);

        let result = executor.speculate("test".to_string());
        assert!(matches!(result, Err(SpeculativeError::SpeculationFailed(_))));
        assert_eq!(executor.stats.total_speculations, 0);
    }

    #[test]
    fn test_successful_speculation() {
        let mut executor = always_true_executor();

        let task_id = executor.speculate("test".to_string()).expect("unwrap");
        assert_eq!(executor.stats.total_speculations, 1);
        let task = &executor.active_tasks[&task_id];
        assert_eq!(task.predicted_branch, BranchOutcome::True);
        assert!(!task.completed);
        assert_eq!(task.correct, None);
    }

    #[test]
    fn test_correct_validation() {
        let mut executor = always_true_executor();

        let task_id = executor.speculate("test".to_string()).expect("unwrap");
        let correct = executor
            .validate(task_id, BranchOutcome::True)
            .expect("unwrap");

        assert!(correct);
        assert_eq!(executor.stats.correct_speculations, 1);
        assert_eq!(executor.stats.incorrect_speculations, 0);
        assert_eq!(executor.active_tasks[&task_id].correct, Some(true));
    }

    #[test]
    fn test_incorrect_validation() {
        let mut executor = always_true_executor();

        let task_id = executor.speculate("test".to_string()).expect("unwrap");
        let correct = executor
            .validate(task_id, BranchOutcome::False)
            .expect("unwrap");

        assert!(!correct);
        assert_eq!(executor.stats.correct_speculations, 0);
        assert_eq!(executor.stats.incorrect_speculations, 1);
        assert_eq!(executor.stats.rollbacks, 1);
        // The default Immediate policy discards the mispredicted task.
        assert!(!executor.active_tasks.contains_key(&task_id));
    }

    #[test]
    fn test_validate_unknown_task() {
        let mut executor = always_true_executor();

        let result = executor.validate(42, BranchOutcome::True);
        assert!(matches!(result, Err(SpeculativeError::InvalidPrediction(_))));
    }

    #[test]
    fn test_history_based_prediction() {
        // Seed the history through AlwaysTrue speculations, then switch strategy.
        let mut executor = always_true_executor();

        for _ in 0..8 {
            let task_id = executor.speculate("node1".to_string()).expect("unwrap");
            executor
                .validate(task_id, BranchOutcome::True)
                .expect("unwrap");
        }

        for _ in 0..2 {
            let task_id = executor.speculate("node1".to_string()).expect("unwrap");
            executor
                .validate(task_id, BranchOutcome::False)
                .expect("unwrap");
        }

        executor.strategy = PredictionStrategy::HistoryBased;

        let (outcome, confidence) = executor.predict_branch(&"node1".to_string());
        assert_eq!(outcome, BranchOutcome::True);
        assert!((confidence - 0.8).abs() < 1e-9);
    }

    #[test]
    fn test_branch_history_window() {
        let mut history = BranchHistory::new("node1".to_string(), 3);
        for outcome in [
            BranchOutcome::True,
            BranchOutcome::True,
            BranchOutcome::False,
            BranchOutcome::False,
        ] {
            history.add_outcome(outcome);
        }

        // The oldest `True` was evicted, leaving [True, False, False].
        assert_eq!(history.outcomes.len(), 3);
        let (outcome, confidence) = history.predict();
        assert_eq!(outcome, BranchOutcome::False);
        assert!((confidence - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_branch_history_tie_predicts_false() {
        let mut history = BranchHistory::new("node1".to_string(), 10);
        history.add_outcome(BranchOutcome::True);
        history.add_outcome(BranchOutcome::False);

        assert_eq!(history.predict(), (BranchOutcome::False, 0.5));
    }

    #[test]
    fn test_max_speculation_depth() {
        let mut executor = always_true_executor().with_max_depth(2);

        executor.speculate("node1".to_string()).expect("unwrap");
        executor.speculate("node2".to_string()).expect("unwrap");

        let result = executor.speculate("node3".to_string());
        assert!(matches!(result, Err(SpeculativeError::SpeculationFailed(_))));
    }

    #[test]
    fn test_checkpoint_creation() {
        let mut executor = SpeculativeExecutor::new();
        let first = executor.create_checkpoint("node1".to_string());
        let second = executor.create_checkpoint("node2".to_string());

        assert!(executor.checkpoints.contains_key(&first));
        assert!(executor.checkpoints.contains_key(&second));
        assert_ne!(first, second);
    }

    #[test]
    fn test_cleanup() {
        let mut executor = always_true_executor();

        let task_id = executor.speculate("test".to_string()).expect("unwrap");
        executor
            .validate(task_id, BranchOutcome::True)
            .expect("unwrap");

        assert!(executor.active_tasks.contains_key(&task_id));
        executor.cleanup();
        assert!(!executor.active_tasks.contains_key(&task_id));
    }

    #[test]
    fn test_success_rate_calculation() {
        let mut executor = always_true_executor();

        for _ in 0..3 {
            let task_id = executor.speculate("test".to_string()).expect("unwrap");
            executor
                .validate(task_id, BranchOutcome::True)
                .expect("unwrap");
        }

        let task_id = executor.speculate("test".to_string()).expect("unwrap");
        executor
            .validate(task_id, BranchOutcome::False)
            .expect("unwrap");

        assert!((executor.stats.success_rate - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_reset_stats() {
        let mut executor = always_true_executor();

        let task_id = executor.speculate("test".to_string()).expect("unwrap");
        executor
            .validate(task_id, BranchOutcome::True)
            .expect("unwrap");

        assert_eq!(executor.stats.total_speculations, 1);

        executor.reset_stats();
        assert_eq!(executor.stats.total_speculations, 0);
        assert_eq!(executor.stats.correct_speculations, 0);
        assert_eq!(executor.stats.success_rate, 0.0);
    }

    #[test]
    fn test_should_speculate() {
        let mut executor = always_true_executor();

        assert!(executor.should_speculate(&"test".to_string()));

        for i in 0..executor.max_speculation_depth {
            executor.speculate(format!("node{}", i)).expect("unwrap");
        }

        assert!(!executor.should_speculate(&"test".to_string()));
    }

    #[test]
    fn test_active_speculation_count() {
        let mut executor = always_true_executor();

        assert_eq!(executor.active_speculation_count(), 0);

        executor.speculate("node1".to_string()).expect("unwrap");
        assert_eq!(executor.active_speculation_count(), 1);

        executor.speculate("node2".to_string()).expect("unwrap");
        assert_eq!(executor.active_speculation_count(), 2);
    }

    #[test]
    fn test_different_rollback_policies() {
        let policies = [
            RollbackPolicy::Immediate,
            RollbackPolicy::Lazy,
            RollbackPolicy::Checkpoint,
        ];

        for policy in policies {
            let mut executor = always_true_executor().with_rollback_policy(policy);

            let task_id = executor.speculate("test".to_string()).expect("unwrap");
            executor
                .validate(task_id, BranchOutcome::False)
                .expect("unwrap");

            assert_eq!(executor.stats.rollbacks, 1);
            assert_eq!(executor.active_speculation_count(), 0);

            // Lazy keeps the completed task until `cleanup`; the others discard it.
            let retained = executor.active_tasks.get(&task_id);
            match policy {
                RollbackPolicy::Lazy => {
                    let task = retained.expect("lazy rollback keeps the task");
                    assert!(task.completed);
                    assert_eq!(task.correct, Some(false));
                }
                RollbackPolicy::Immediate | RollbackPolicy::Checkpoint => {
                    assert!(retained.is_none());
                }
            }
        }
    }
}