1use serde::{Deserialize, Serialize};
30use std::collections::{HashMap, VecDeque};
31use thiserror::Error;
32
33#[derive(Error, Debug, Clone, PartialEq)]
35pub enum SpeculativeError {
36 #[error("Speculation failed: {0}")]
37 SpeculationFailed(String),
38
39 #[error("Rollback failed: {0}")]
40 RollbackFailed(String),
41
42 #[error("Invalid prediction: {0}")]
43 InvalidPrediction(String),
44
45 #[error("Checkpoint not found: {0}")]
46 CheckpointNotFound(String),
47}
48
/// Identifier of a node; keys per-node branch history, speculative tasks and checkpoints.
pub type NodeId = String;
51
52#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
54pub enum PredictionStrategy {
55 MostFrequent,
57 HistoryBased,
59 Static,
61 Adaptive,
63 AlwaysTrue,
65 Never,
67}
68
69#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
71pub enum RollbackPolicy {
72 Immediate,
74 Lazy,
76 Checkpoint,
78}
79
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
82pub enum BranchOutcome {
83 True,
84 False,
85 Unknown,
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct SpeculativeTask {
91 pub task_id: u64,
92 pub node_id: NodeId,
93 pub predicted_branch: BranchOutcome,
94 pub confidence: f64,
95 pub started_at: u64, pub completed: bool,
97 pub correct: Option<bool>, }
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
102struct BranchHistory {
103 node_id: NodeId,
104 outcomes: VecDeque<BranchOutcome>,
105 max_history: usize,
106}
107
108impl BranchHistory {
109 fn new(node_id: NodeId, max_history: usize) -> Self {
110 Self {
111 node_id,
112 outcomes: VecDeque::new(),
113 max_history,
114 }
115 }
116
117 fn add_outcome(&mut self, outcome: BranchOutcome) {
118 if self.outcomes.len() >= self.max_history {
119 self.outcomes.pop_front();
120 }
121 self.outcomes.push_back(outcome);
122 }
123
124 fn predict(&self) -> (BranchOutcome, f64) {
125 if self.outcomes.is_empty() {
126 return (BranchOutcome::Unknown, 0.0);
127 }
128
129 let true_count = self
130 .outcomes
131 .iter()
132 .filter(|&&o| o == BranchOutcome::True)
133 .count();
134 let false_count = self
135 .outcomes
136 .iter()
137 .filter(|&&o| o == BranchOutcome::False)
138 .count();
139 let total = true_count + false_count;
140
141 if total == 0 {
142 return (BranchOutcome::Unknown, 0.0);
143 }
144
145 if true_count > false_count {
146 (BranchOutcome::True, true_count as f64 / total as f64)
147 } else {
148 (BranchOutcome::False, false_count as f64 / total as f64)
149 }
150 }
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
155pub struct SpeculationStats {
156 pub total_speculations: usize,
157 pub correct_speculations: usize,
158 pub incorrect_speculations: usize,
159 pub rollbacks: usize,
160 pub success_rate: f64,
161 pub average_confidence: f64,
162 pub time_saved_us: f64,
163 pub time_wasted_us: f64,
164}
165
166#[derive(Debug, Clone, Serialize, Deserialize)]
168struct Checkpoint {
169 checkpoint_id: u64,
170 node_id: NodeId,
171 timestamp: u64,
172 }
174
175pub struct SpeculativeExecutor {
177 strategy: PredictionStrategy,
178 rollback_policy: RollbackPolicy,
179 confidence_threshold: f64,
180 max_speculation_depth: usize,
181 branch_history: HashMap<NodeId, BranchHistory>,
182 active_tasks: HashMap<u64, SpeculativeTask>,
183 checkpoints: HashMap<u64, Checkpoint>,
184 next_task_id: u64,
185 next_checkpoint_id: u64,
186 stats: SpeculationStats,
187 history_length: usize,
188}
189
190impl SpeculativeExecutor {
191 pub fn new() -> Self {
193 Self {
194 strategy: PredictionStrategy::HistoryBased,
195 rollback_policy: RollbackPolicy::Immediate,
196 confidence_threshold: 0.6,
197 max_speculation_depth: 3,
198 branch_history: HashMap::new(),
199 active_tasks: HashMap::new(),
200 checkpoints: HashMap::new(),
201 next_task_id: 0,
202 next_checkpoint_id: 0,
203 stats: SpeculationStats {
204 total_speculations: 0,
205 correct_speculations: 0,
206 incorrect_speculations: 0,
207 rollbacks: 0,
208 success_rate: 0.0,
209 average_confidence: 0.0,
210 time_saved_us: 0.0,
211 time_wasted_us: 0.0,
212 },
213 history_length: 10,
214 }
215 }
216
217 pub fn with_strategy(mut self, strategy: PredictionStrategy) -> Self {
219 self.strategy = strategy;
220 self
221 }
222
223 pub fn with_rollback_policy(mut self, policy: RollbackPolicy) -> Self {
225 self.rollback_policy = policy;
226 self
227 }
228
229 pub fn with_confidence_threshold(mut self, threshold: f64) -> Self {
231 self.confidence_threshold = threshold.clamp(0.0, 1.0);
232 self
233 }
234
235 pub fn with_max_depth(mut self, depth: usize) -> Self {
237 self.max_speculation_depth = depth;
238 self
239 }
240
241 pub fn predict_branch(&self, node_id: &NodeId) -> (BranchOutcome, f64) {
243 match self.strategy {
244 PredictionStrategy::Never => (BranchOutcome::Unknown, 0.0),
245 PredictionStrategy::AlwaysTrue => (BranchOutcome::True, 1.0),
246 PredictionStrategy::MostFrequent => {
247 if let Some(history) = self.branch_history.get(node_id) {
248 history.predict()
249 } else {
250 (BranchOutcome::True, 0.5) }
252 }
253 PredictionStrategy::HistoryBased => {
254 if let Some(history) = self.branch_history.get(node_id) {
255 history.predict()
256 } else {
257 (BranchOutcome::Unknown, 0.0)
258 }
259 }
260 PredictionStrategy::Static | PredictionStrategy::Adaptive => {
261 if let Some(history) = self.branch_history.get(node_id) {
263 history.predict()
264 } else {
265 (BranchOutcome::True, 0.5)
266 }
267 }
268 }
269 }
270
271 pub fn speculate(&mut self, node_id: NodeId) -> Result<u64, SpeculativeError> {
273 let (predicted_branch, confidence) = self.predict_branch(&node_id);
274
275 if confidence < self.confidence_threshold {
277 return Err(SpeculativeError::SpeculationFailed(format!(
278 "Confidence {} below threshold {}",
279 confidence, self.confidence_threshold
280 )));
281 }
282
283 let active_count = self.active_tasks.values().filter(|t| !t.completed).count();
285
286 if active_count >= self.max_speculation_depth {
287 return Err(SpeculativeError::SpeculationFailed(format!(
288 "Maximum speculation depth {} reached",
289 self.max_speculation_depth
290 )));
291 }
292
293 let task_id = self.next_task_id;
295 self.next_task_id += 1;
296
297 let task = SpeculativeTask {
298 task_id,
299 node_id: node_id.clone(),
300 predicted_branch,
301 confidence,
302 started_at: 0, completed: false,
304 correct: None,
305 };
306
307 self.active_tasks.insert(task_id, task);
308 self.stats.total_speculations += 1;
309
310 Ok(task_id)
311 }
312
313 pub fn validate(
315 &mut self,
316 task_id: u64,
317 actual_branch: BranchOutcome,
318 ) -> Result<bool, SpeculativeError> {
319 let task = self.active_tasks.get_mut(&task_id).ok_or_else(|| {
320 SpeculativeError::InvalidPrediction(format!("Task {} not found", task_id))
321 })?;
322
323 let correct = task.predicted_branch == actual_branch;
324 task.correct = Some(correct);
325 task.completed = true;
326
327 let history = self
329 .branch_history
330 .entry(task.node_id.clone())
331 .or_insert_with(|| BranchHistory::new(task.node_id.clone(), self.history_length));
332 history.add_outcome(actual_branch);
333
334 if correct {
336 self.stats.correct_speculations += 1;
337 } else {
338 self.stats.incorrect_speculations += 1;
339 self.rollback(task_id)?;
341 }
342
343 self.update_stats();
344
345 Ok(correct)
346 }
347
348 fn rollback(&mut self, task_id: u64) -> Result<(), SpeculativeError> {
350 match self.rollback_policy {
351 RollbackPolicy::Immediate => {
352 self.active_tasks.remove(&task_id);
354 self.stats.rollbacks += 1;
355 Ok(())
356 }
357 RollbackPolicy::Lazy => {
358 if let Some(task) = self.active_tasks.get_mut(&task_id) {
360 task.completed = true;
361 }
362 self.stats.rollbacks += 1;
363 Ok(())
364 }
365 RollbackPolicy::Checkpoint => {
366 self.restore_checkpoint(task_id)?;
368 self.stats.rollbacks += 1;
369 Ok(())
370 }
371 }
372 }
373
374 pub fn create_checkpoint(&mut self, node_id: NodeId) -> u64 {
376 let checkpoint_id = self.next_checkpoint_id;
377 self.next_checkpoint_id += 1;
378
379 let checkpoint = Checkpoint {
380 checkpoint_id,
381 node_id,
382 timestamp: 0, };
384
385 self.checkpoints.insert(checkpoint_id, checkpoint);
386 checkpoint_id
387 }
388
389 fn restore_checkpoint(&mut self, task_id: u64) -> Result<(), SpeculativeError> {
391 let _task = self.active_tasks.get(&task_id).ok_or_else(|| {
393 SpeculativeError::CheckpointNotFound(format!("No task found for id: {}", task_id))
394 })?;
395
396 self.active_tasks.remove(&task_id);
398 Ok(())
399 }
400
401 fn update_stats(&mut self) {
403 let total = (self.stats.correct_speculations + self.stats.incorrect_speculations) as f64;
404 if total > 0.0 {
405 self.stats.success_rate = self.stats.correct_speculations as f64 / total;
406 }
407
408 let confidence_sum: f64 = self.active_tasks.values().map(|t| t.confidence).sum();
409 let task_count = self.active_tasks.len() as f64;
410 if task_count > 0.0 {
411 self.stats.average_confidence = confidence_sum / task_count;
412 }
413 }
414
415 pub fn get_stats(&self) -> &SpeculationStats {
417 &self.stats
418 }
419
420 pub fn cleanup(&mut self) {
422 self.active_tasks.retain(|_, task| !task.completed);
423 }
424
425 pub fn reset_stats(&mut self) {
427 self.stats = SpeculationStats {
428 total_speculations: 0,
429 correct_speculations: 0,
430 incorrect_speculations: 0,
431 rollbacks: 0,
432 success_rate: 0.0,
433 average_confidence: 0.0,
434 time_saved_us: 0.0,
435 time_wasted_us: 0.0,
436 };
437 }
438
439 pub fn active_speculation_count(&self) -> usize {
441 self.active_tasks.values().filter(|t| !t.completed).count()
442 }
443
444 pub fn should_speculate(&self, node_id: &NodeId) -> bool {
446 let (_, confidence) = self.predict_branch(node_id);
447 confidence >= self.confidence_threshold
448 && self.active_speculation_count() < self.max_speculation_depth
449 }
450}
451
452impl Default for SpeculativeExecutor {
453 fn default() -> Self {
454 Self::new()
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// Executor that always predicts `True` with confidence 1.0, so
    /// `speculate` succeeds whenever the depth limit allows.
    fn always_true_executor() -> SpeculativeExecutor {
        SpeculativeExecutor::new()
            .with_strategy(PredictionStrategy::AlwaysTrue)
            .with_confidence_threshold(0.5)
    }

    /// Speculates on `node` and immediately validates it against `actual`,
    /// returning whether the prediction was correct.
    fn speculate_and_validate(
        executor: &mut SpeculativeExecutor,
        node: &str,
        actual: BranchOutcome,
    ) -> bool {
        let task_id = executor.speculate(node.to_string()).unwrap();
        executor.validate(task_id, actual).unwrap()
    }

    #[test]
    fn test_speculative_executor_creation() {
        let executor = SpeculativeExecutor::new();
        assert_eq!(executor.strategy, PredictionStrategy::HistoryBased);
        assert_eq!(executor.rollback_policy, RollbackPolicy::Immediate);
        assert_eq!(executor.confidence_threshold, 0.6);
    }

    #[test]
    fn test_builder_pattern() {
        let executor = SpeculativeExecutor::new()
            .with_strategy(PredictionStrategy::Adaptive)
            .with_rollback_policy(RollbackPolicy::Checkpoint)
            .with_confidence_threshold(0.8)
            .with_max_depth(5);

        assert_eq!(executor.strategy, PredictionStrategy::Adaptive);
        assert_eq!(executor.rollback_policy, RollbackPolicy::Checkpoint);
        assert_eq!(executor.confidence_threshold, 0.8);
        assert_eq!(executor.max_speculation_depth, 5);
    }

    #[test]
    fn test_always_true_prediction() {
        let executor = SpeculativeExecutor::new().with_strategy(PredictionStrategy::AlwaysTrue);

        let (outcome, confidence) = executor.predict_branch(&"test".to_string());
        assert_eq!(outcome, BranchOutcome::True);
        assert_eq!(confidence, 1.0);
    }

    #[test]
    fn test_never_speculation() {
        let executor = SpeculativeExecutor::new().with_strategy(PredictionStrategy::Never);

        let (outcome, confidence) = executor.predict_branch(&"test".to_string());
        assert_eq!(outcome, BranchOutcome::Unknown);
        assert_eq!(confidence, 0.0);
    }

    #[test]
    fn test_speculation_below_threshold() {
        let mut executor = SpeculativeExecutor::new().with_confidence_threshold(0.9);

        let result = executor.speculate("test".to_string());
        assert!(result.is_err());
    }

    #[test]
    fn test_successful_speculation() {
        let mut executor = always_true_executor();

        let task_id = executor.speculate("test".to_string()).unwrap();
        assert_eq!(executor.stats.total_speculations, 1);
        assert!(executor.active_tasks.contains_key(&task_id));
    }

    #[test]
    fn test_correct_validation() {
        let mut executor = always_true_executor();

        let correct = speculate_and_validate(&mut executor, "test", BranchOutcome::True);

        assert!(correct);
        assert_eq!(executor.stats.correct_speculations, 1);
        assert_eq!(executor.stats.incorrect_speculations, 0);
    }

    #[test]
    fn test_incorrect_validation() {
        let mut executor = always_true_executor();

        let task_id = executor.speculate("test".to_string()).unwrap();
        let correct = executor.validate(task_id, BranchOutcome::False).unwrap();

        assert!(!correct);
        assert_eq!(executor.stats.correct_speculations, 0);
        assert_eq!(executor.stats.incorrect_speculations, 1);
        assert_eq!(executor.stats.rollbacks, 1);
        // Immediate rollback (the default) removes the task outright.
        assert!(!executor.active_tasks.contains_key(&task_id));
    }

    #[test]
    fn test_validate_unknown_task_fails() {
        let mut executor = always_true_executor();

        let result = executor.validate(42, BranchOutcome::True);
        assert_eq!(
            result,
            Err(SpeculativeError::InvalidPrediction(
                "Task 42 not found".to_string()
            ))
        );
    }

    #[test]
    fn test_history_based_prediction() {
        let mut executor = always_true_executor();

        for _ in 0..8 {
            speculate_and_validate(&mut executor, "node1", BranchOutcome::True);
        }
        for _ in 0..2 {
            speculate_and_validate(&mut executor, "node1", BranchOutcome::False);
        }

        executor.strategy = PredictionStrategy::HistoryBased;

        let (outcome, confidence) = executor.predict_branch(&"node1".to_string());
        assert_eq!(outcome, BranchOutcome::True);
        assert!(confidence > 0.7);
    }

    #[test]
    fn test_history_window_keeps_only_recent_outcomes() {
        let mut executor = always_true_executor();

        // Fill the window with `False`, then overwrite it entirely with `True`.
        for outcome in [BranchOutcome::False, BranchOutcome::True] {
            for _ in 0..executor.history_length {
                speculate_and_validate(&mut executor, "node1", outcome);
            }
        }

        executor.strategy = PredictionStrategy::HistoryBased;
        assert_eq!(
            executor.predict_branch(&"node1".to_string()),
            (BranchOutcome::True, 1.0)
        );
    }

    #[test]
    fn test_tied_history_predicts_false() {
        let mut executor = always_true_executor();

        for outcome in [BranchOutcome::True, BranchOutcome::False] {
            speculate_and_validate(&mut executor, "node1", outcome);
        }

        executor.strategy = PredictionStrategy::HistoryBased;
        assert_eq!(
            executor.predict_branch(&"node1".to_string()),
            (BranchOutcome::False, 0.5)
        );
    }

    #[test]
    fn test_max_speculation_depth() {
        let mut executor = always_true_executor().with_max_depth(2);

        executor.speculate("node1".to_string()).unwrap();
        executor.speculate("node2".to_string()).unwrap();

        let result = executor.speculate("node3".to_string());
        assert!(result.is_err());
    }

    #[test]
    fn test_checkpoint_creation() {
        let mut executor = SpeculativeExecutor::new();
        let checkpoint_id = executor.create_checkpoint("node1".to_string());

        assert!(executor.checkpoints.contains_key(&checkpoint_id));
    }

    #[test]
    fn test_cleanup() {
        let mut executor = always_true_executor();

        let task_id = executor.speculate("test".to_string()).unwrap();
        executor.validate(task_id, BranchOutcome::True).unwrap();

        assert!(executor.active_tasks.contains_key(&task_id));
        executor.cleanup();
        assert!(!executor.active_tasks.contains_key(&task_id));
    }

    #[test]
    fn test_lazy_rollback_keeps_task_until_cleanup() {
        let mut executor = always_true_executor().with_rollback_policy(RollbackPolicy::Lazy);

        let task_id = executor.speculate("test".to_string()).unwrap();
        executor.validate(task_id, BranchOutcome::False).unwrap();

        assert!(executor.active_tasks.contains_key(&task_id));
        assert_eq!(executor.active_speculation_count(), 0);
        executor.cleanup();
        assert!(!executor.active_tasks.contains_key(&task_id));
    }

    #[test]
    fn test_success_rate_calculation() {
        let mut executor = always_true_executor();

        for _ in 0..3 {
            speculate_and_validate(&mut executor, "test", BranchOutcome::True);
        }
        speculate_and_validate(&mut executor, "test", BranchOutcome::False);

        assert!((executor.stats.success_rate - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_reset_stats() {
        let mut executor = always_true_executor();

        speculate_and_validate(&mut executor, "test", BranchOutcome::True);
        assert_eq!(executor.stats.total_speculations, 1);

        executor.reset_stats();
        assert_eq!(executor.stats.total_speculations, 0);
        assert_eq!(executor.stats.correct_speculations, 0);
    }

    #[test]
    fn test_should_speculate() {
        let mut executor = always_true_executor();

        assert!(executor.should_speculate(&"test".to_string()));

        for i in 0..executor.max_speculation_depth {
            executor.speculate(format!("node{}", i)).unwrap();
        }

        assert!(!executor.should_speculate(&"test".to_string()));
    }

    #[test]
    fn test_active_speculation_count() {
        let mut executor = always_true_executor();

        assert_eq!(executor.active_speculation_count(), 0);

        executor.speculate("node1".to_string()).unwrap();
        assert_eq!(executor.active_speculation_count(), 1);

        executor.speculate("node2".to_string()).unwrap();
        assert_eq!(executor.active_speculation_count(), 2);
    }

    #[test]
    fn test_different_rollback_policies() {
        let policies = [
            RollbackPolicy::Immediate,
            RollbackPolicy::Lazy,
            RollbackPolicy::Checkpoint,
        ];

        for policy in policies {
            let mut executor = always_true_executor().with_rollback_policy(policy);

            speculate_and_validate(&mut executor, "test", BranchOutcome::False);

            assert_eq!(executor.stats.rollbacks, 1, "policy {:?}", policy);
        }
    }
}