1use crate::{ComputationGraph, JitError, JitResult, Node, NodeId};
8use petgraph::graph::NodeIndex;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::{
12 atomic::{AtomicBool, AtomicU64, Ordering},
13 Arc, RwLock,
14};
15
16pub struct SpeculativeOptimizer {
18 config: SpeculativeConfig,
19 assumptions: Arc<RwLock<HashMap<AssumptionId, Assumption>>>,
20 guards: Arc<RwLock<HashMap<NodeId, Vec<Guard>>>>,
21 deopt_counter: AtomicU64,
22 enabled: AtomicBool,
23}
24
25#[derive(Debug, Clone)]
27pub struct SpeculativeConfig {
28 pub max_assumptions: usize,
30
31 pub deopt_threshold: u64,
33
34 pub confidence_threshold: f64,
36
37 pub enable_type_speculation: bool,
39
40 pub enable_shape_speculation: bool,
42
43 pub enable_value_speculation: bool,
45
46 pub enable_nullability_speculation: bool,
48
49 pub enable_branch_speculation: bool,
51
52 pub enable_loop_speculation: bool,
54
55 pub aggressiveness: f64,
57}
58
59impl Default for SpeculativeConfig {
60 fn default() -> Self {
61 Self {
62 max_assumptions: 1000,
63 deopt_threshold: 100,
64 confidence_threshold: 0.8,
65 enable_type_speculation: true,
66 enable_shape_speculation: true,
67 enable_value_speculation: false, enable_nullability_speculation: true,
69 enable_branch_speculation: true,
70 enable_loop_speculation: true,
71 aggressiveness: 0.7,
72 }
73 }
74}
75
76#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
78pub struct AssumptionId(pub u64);
79
80#[derive(Debug, Clone)]
82pub struct Assumption {
83 pub id: AssumptionId,
84 pub assumption_type: AssumptionType,
85 pub node_id: NodeId,
86 pub confidence: f64,
87 pub success_count: u64,
88 pub failure_count: u64,
89 pub created_at: std::time::SystemTime,
90 pub metadata: HashMap<String, String>,
91}
92
93#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
95pub enum AssumptionType {
96 TypeSpeculation { expected_type: String },
98
99 ShapeSpeculation { expected_shape: Vec<usize> },
101
102 ValueSpeculation { expected_value: f64, tolerance: f64 },
104
105 NullabilitySpeculation,
107
108 BranchSpeculation {
110 usually_taken: bool,
111 probability: f64,
112 },
113
114 LoopSpeculation {
116 expected_iterations: u64,
117 tolerance: u64,
118 },
119
120 MemorySpeculation { access_pattern: MemoryAccessPattern },
122}
123
124#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
126pub enum MemoryAccessPattern {
127 Sequential,
128 Random,
129 Strided { stride: usize },
130 Clustered { cluster_size: usize },
131}
132
133#[derive(Debug, Clone)]
135pub struct Guard {
136 pub assumption_id: AssumptionId,
137 pub guard_type: GuardType,
138 pub check_frequency: GuardFrequency,
139}
140
141#[derive(Debug, Clone, PartialEq)]
143pub enum GuardType {
144 TypeCheck,
146
147 ShapeCheck,
149
150 ValueCheck,
152
153 NullabilityCheck,
155
156 BranchCheck,
158
159 LoopCheck,
161
162 MemoryCheck,
164}
165
166#[derive(Debug, Clone, PartialEq)]
168pub enum GuardFrequency {
169 Always,
171
172 Probabilistic(f64),
174
175 Periodic(u64),
177
178 InitialOnly(u64),
180}
181
182#[derive(Debug, Clone)]
184pub struct SpeculationResult {
185 pub assumptions_made: Vec<AssumptionId>,
186 pub optimizations_applied: Vec<SpeculativeOptimization>,
187 pub guards_installed: Vec<Guard>,
188 pub estimated_speedup: f64,
189}
190
191#[derive(Debug, Clone)]
193pub struct SpeculativeOptimization {
194 pub optimization_type: SpeculativeOptimizationType,
195 pub node_id: NodeId,
196 pub description: String,
197 pub estimated_benefit: f64,
198}
199
200#[derive(Debug, Clone, PartialEq)]
202pub enum SpeculativeOptimizationType {
203 TypeCheckElimination,
205
206 BoundsCheckElimination,
208
209 ShapeSpecialization,
211
212 ConstantPropagation,
214
215 DeadCodeElimination,
217
218 LoopUnrolling,
220
221 BranchElimination,
223
224 MemoryPrefetching,
226
227 VectorizationOptimization,
229}
230
231#[derive(Debug, Clone)]
233pub struct DeoptimizationEvent {
234 pub assumption_id: AssumptionId,
235 pub node_id: NodeId,
236 pub reason: String,
237 pub timestamp: std::time::SystemTime,
238 pub execution_count: u64,
239}
240
241impl SpeculativeOptimizer {
242 pub fn new(config: SpeculativeConfig) -> Self {
244 Self {
245 config,
246 assumptions: Arc::new(RwLock::new(HashMap::new())),
247 guards: Arc::new(RwLock::new(HashMap::new())),
248 deopt_counter: AtomicU64::new(0),
249 enabled: AtomicBool::new(true),
250 }
251 }
252
253 pub fn analyze_and_speculate(
255 &self,
256 graph: &ComputationGraph,
257 execution_history: &ExecutionHistory,
258 ) -> JitResult<SpeculationResult> {
259 if !self.enabled.load(Ordering::Relaxed) {
260 return Ok(SpeculationResult {
261 assumptions_made: Vec::new(),
262 optimizations_applied: Vec::new(),
263 guards_installed: Vec::new(),
264 estimated_speedup: 1.0,
265 });
266 }
267
268 let mut assumptions_made = Vec::new();
269 let mut optimizations = Vec::new();
270 let mut guards = Vec::new();
271 let mut total_speedup = 1.0;
272
273 for (node_id, node) in graph.nodes() {
274 if let Some(node_history) = execution_history.get_node_history(node_id) {
276 if self.config.enable_type_speculation {
278 if let Some(spec_result) =
279 self.analyze_type_speculation(node_id, node, node_history)?
280 {
281 assumptions_made.extend(spec_result.assumptions_made);
282 optimizations.extend(spec_result.optimizations_applied);
283 guards.extend(spec_result.guards_installed);
284 total_speedup *= spec_result.estimated_speedup;
285 }
286 }
287
288 if self.config.enable_shape_speculation {
290 if let Some(spec_result) =
291 self.analyze_shape_speculation(node_id, node, node_history)?
292 {
293 assumptions_made.extend(spec_result.assumptions_made);
294 optimizations.extend(spec_result.optimizations_applied);
295 guards.extend(spec_result.guards_installed);
296 total_speedup *= spec_result.estimated_speedup;
297 }
298 }
299
300 if self.config.enable_value_speculation {
302 if let Some(spec_result) =
303 self.analyze_value_speculation(node_id, node, node_history)?
304 {
305 assumptions_made.extend(spec_result.assumptions_made);
306 optimizations.extend(spec_result.optimizations_applied);
307 guards.extend(spec_result.guards_installed);
308 total_speedup *= spec_result.estimated_speedup;
309 }
310 }
311
312 if self.config.enable_branch_speculation {
314 if let Some(spec_result) =
315 self.analyze_branch_speculation(node_id, node, node_history)?
316 {
317 assumptions_made.extend(spec_result.assumptions_made);
318 optimizations.extend(spec_result.optimizations_applied);
319 guards.extend(spec_result.guards_installed);
320 total_speedup *= spec_result.estimated_speedup;
321 }
322 }
323 }
324 }
325
326 if let Ok(mut guard_map) = self.guards.write() {
328 for guard in &guards {
329 guard_map
330 .entry(NodeIndex::new(guard.assumption_id.0 as usize))
331 .or_insert_with(Vec::new)
332 .push(guard.clone());
333 }
334 }
335
336 if let Ok(mut assumption_map) = self.assumptions.write() {
338 for assumption_id in &assumptions_made {
339 if let Some(assumption) = self.create_assumption(*assumption_id) {
340 assumption_map.insert(*assumption_id, assumption);
341 }
342 }
343 }
344
345 Ok(SpeculationResult {
346 assumptions_made,
347 optimizations_applied: optimizations,
348 guards_installed: guards,
349 estimated_speedup: total_speedup,
350 })
351 }
352
353 pub fn apply_speculative_optimizations(
355 &self,
356 graph: &mut ComputationGraph,
357 result: &SpeculationResult,
358 ) -> JitResult<usize> {
359 let mut applied_count = 0;
360
361 for optimization in &result.optimizations_applied {
362 match optimization.optimization_type {
363 SpeculativeOptimizationType::TypeCheckElimination => {
364 if self.apply_type_check_elimination(graph, optimization)? {
365 applied_count += 1;
366 }
367 }
368 SpeculativeOptimizationType::ShapeSpecialization => {
369 if self.apply_shape_specialization(graph, optimization)? {
370 applied_count += 1;
371 }
372 }
373 SpeculativeOptimizationType::ConstantPropagation => {
374 if self.apply_constant_propagation(graph, optimization)? {
375 applied_count += 1;
376 }
377 }
378 SpeculativeOptimizationType::BranchElimination => {
379 if self.apply_branch_elimination(graph, optimization)? {
380 applied_count += 1;
381 }
382 }
383 _ => {
384 }
386 }
387 }
388
389 Ok(applied_count)
390 }
391
392 pub fn check_guards(&self, node_id: NodeId, runtime_info: &RuntimeInfo) -> JitResult<bool> {
394 let guard_map = self
395 .guards
396 .read()
397 .map_err(|_| JitError::RuntimeError("Failed to read guards".to_string()))?;
398
399 if let Some(node_guards) = guard_map.get(&node_id) {
400 for guard in node_guards {
401 if self.should_check_guard(guard, runtime_info.execution_count) {
402 let check_passed = self.execute_guard_check(guard, runtime_info)?;
403
404 if !check_passed {
405 self.handle_deoptimization(
406 guard.assumption_id,
407 node_id,
408 "Guard check failed",
409 )?;
410 return Ok(false);
411 }
412 }
413 }
414 }
415
416 Ok(true)
417 }
418
419 pub fn record_success(&self, assumption_id: AssumptionId) {
421 if let Ok(mut assumptions) = self.assumptions.write() {
422 if let Some(assumption) = assumptions.get_mut(&assumption_id) {
423 assumption.success_count += 1;
424 assumption.confidence = self.calculate_confidence(assumption);
425 }
426 }
427 }
428
429 pub fn handle_deoptimization(
431 &self,
432 assumption_id: AssumptionId,
433 node_id: NodeId,
434 reason: &str,
435 ) -> JitResult<()> {
436 let deopt_count = self.deopt_counter.fetch_add(1, Ordering::Relaxed);
437
438 if let Ok(mut assumptions) = self.assumptions.write() {
440 if let Some(assumption) = assumptions.get_mut(&assumption_id) {
441 assumption.failure_count += 1;
442 assumption.confidence = self.calculate_confidence(assumption);
443
444 if assumption.confidence < 0.3 {
446 assumptions.remove(&assumption_id);
447 }
448 }
449 }
450
451 if deopt_count > self.config.deopt_threshold {
453 self.enabled.store(false, Ordering::Relaxed);
454 }
455
456 let event = DeoptimizationEvent {
458 assumption_id,
459 node_id,
460 reason: reason.to_string(),
461 timestamp: std::time::SystemTime::now(),
462 execution_count: deopt_count,
463 };
464
465 self.log_deoptimization_event(&event);
466
467 Ok(())
468 }
469
470 pub fn get_statistics(&self) -> JitResult<SpeculationStatistics> {
472 let assumptions = self
473 .assumptions
474 .read()
475 .map_err(|_| JitError::RuntimeError("Failed to read assumptions".to_string()))?;
476
477 let active_assumptions = assumptions.len();
478 let total_successes = assumptions.values().map(|a| a.success_count).sum();
479 let total_failures = assumptions.values().map(|a| a.failure_count).sum();
480 let avg_confidence = if !assumptions.is_empty() {
481 assumptions.values().map(|a| a.confidence).sum::<f64>() / assumptions.len() as f64
482 } else {
483 0.0
484 };
485
486 let deopt_count = self.deopt_counter.load(Ordering::Relaxed);
487 let enabled = self.enabled.load(Ordering::Relaxed);
488
489 Ok(SpeculationStatistics {
490 active_assumptions,
491 total_successes,
492 total_failures,
493 avg_confidence,
494 deoptimization_count: deopt_count,
495 enabled,
496 })
497 }
498
499 fn analyze_type_speculation(
501 &self,
502 node_id: NodeId,
503 _node: &Node,
504 history: &NodeExecutionHistory,
505 ) -> JitResult<Option<SpeculationResult>> {
506 if let Some(dominant_type) = history.get_dominant_type(self.config.confidence_threshold) {
508 let assumption_id = self.generate_assumption_id();
509
510 let optimization = SpeculativeOptimization {
511 optimization_type: SpeculativeOptimizationType::TypeCheckElimination,
512 node_id,
513 description: format!("Assume type is always {}", dominant_type),
514 estimated_benefit: 0.05, };
516
517 let guard = Guard {
518 assumption_id,
519 guard_type: GuardType::TypeCheck,
520 check_frequency: GuardFrequency::Probabilistic(0.1), };
522
523 return Ok(Some(SpeculationResult {
524 assumptions_made: vec![assumption_id],
525 optimizations_applied: vec![optimization],
526 guards_installed: vec![guard],
527 estimated_speedup: 1.05,
528 }));
529 }
530
531 Ok(None)
532 }
533
534 fn analyze_shape_speculation(
535 &self,
536 node_id: NodeId,
537 _node: &Node,
538 history: &NodeExecutionHistory,
539 ) -> JitResult<Option<SpeculationResult>> {
540 if let Some(dominant_shape) = history.get_dominant_shape(self.config.confidence_threshold) {
542 let assumption_id = self.generate_assumption_id();
543
544 let optimization = SpeculativeOptimization {
545 optimization_type: SpeculativeOptimizationType::ShapeSpecialization,
546 node_id,
547 description: format!("Specialize for shape {:?}", dominant_shape),
548 estimated_benefit: 0.15, };
550
551 let guard = Guard {
552 assumption_id,
553 guard_type: GuardType::ShapeCheck,
554 check_frequency: GuardFrequency::Always, };
556
557 return Ok(Some(SpeculationResult {
558 assumptions_made: vec![assumption_id],
559 optimizations_applied: vec![optimization],
560 guards_installed: vec![guard],
561 estimated_speedup: 1.15,
562 }));
563 }
564
565 Ok(None)
566 }
567
568 fn analyze_value_speculation(
569 &self,
570 node_id: NodeId,
571 _node: &Node,
572 history: &NodeExecutionHistory,
573 ) -> JitResult<Option<SpeculationResult>> {
574 if let Some(constant_value) = history.get_constant_value(self.config.confidence_threshold) {
576 let assumption_id = self.generate_assumption_id();
577
578 let optimization = SpeculativeOptimization {
579 optimization_type: SpeculativeOptimizationType::ConstantPropagation,
580 node_id,
581 description: format!("Assume constant value {}", constant_value),
582 estimated_benefit: 0.20, };
584
585 let guard = Guard {
586 assumption_id,
587 guard_type: GuardType::ValueCheck,
588 check_frequency: GuardFrequency::Periodic(100), };
590
591 return Ok(Some(SpeculationResult {
592 assumptions_made: vec![assumption_id],
593 optimizations_applied: vec![optimization],
594 guards_installed: vec![guard],
595 estimated_speedup: 1.20,
596 }));
597 }
598
599 Ok(None)
600 }
601
602 fn analyze_branch_speculation(
603 &self,
604 node_id: NodeId,
605 _node: &Node,
606 history: &NodeExecutionHistory,
607 ) -> JitResult<Option<SpeculationResult>> {
608 if let Some(branch_bias) = history.get_branch_bias(self.config.confidence_threshold) {
610 let assumption_id = self.generate_assumption_id();
611
612 let optimization = SpeculativeOptimization {
613 optimization_type: SpeculativeOptimizationType::BranchElimination,
614 node_id,
615 description: format!(
616 "Assume branch is usually {}",
617 if branch_bias > 0.5 {
618 "taken"
619 } else {
620 "not taken"
621 }
622 ),
623 estimated_benefit: 0.10, };
625
626 let guard = Guard {
627 assumption_id,
628 guard_type: GuardType::BranchCheck,
629 check_frequency: GuardFrequency::Probabilistic(0.05), };
631
632 return Ok(Some(SpeculationResult {
633 assumptions_made: vec![assumption_id],
634 optimizations_applied: vec![optimization],
635 guards_installed: vec![guard],
636 estimated_speedup: 1.10,
637 }));
638 }
639
640 Ok(None)
641 }
642
643 fn apply_type_check_elimination(
644 &self,
645 graph: &mut ComputationGraph,
646 optimization: &SpeculativeOptimization,
647 ) -> JitResult<bool> {
648 let node_id = optimization.node_id;
649
650 if let Some(node) = graph.node_mut(node_id) {
651 node.set_optimization_hint("eliminate_type_checks", "true")?;
653 node.set_optimization_hint("assumed_type_stable", "true")?;
654
655 node.set_optimization_hint("add_type_guard", "true")?;
657 node.set_optimization_hint("guard_frequency", "low")?;
658
659 return Ok(true);
660 }
661
662 Ok(false)
663 }
664
665 fn apply_shape_specialization(
666 &self,
667 graph: &mut ComputationGraph,
668 optimization: &SpeculativeOptimization,
669 ) -> JitResult<bool> {
670 let node_id = optimization.node_id;
671
672 if let Some(node) = graph.node_mut(node_id) {
673 node.set_optimization_hint("shape_specialized", "true")?;
675 node.set_optimization_hint("eliminate_shape_checks", "true")?;
676
677 if optimization.description.contains("shape") {
679 node.set_optimization_hint("specialized_shape_source", "speculation")?;
680 node.set_optimization_hint("add_shape_guard", "true")?;
681 }
682
683 return Ok(true);
684 }
685
686 Ok(false)
687 }
688
689 fn apply_constant_propagation(
690 &self,
691 graph: &mut ComputationGraph,
692 optimization: &SpeculativeOptimization,
693 ) -> JitResult<bool> {
694 let node_id = optimization.node_id;
695
696 if let Some(node) = graph.node_mut(node_id) {
697 node.set_optimization_hint("constant_propagation", "true")?;
699 node.set_optimization_hint("assumed_constant", "true")?;
700
701 if let Some(start) = optimization.description.find("value ") {
703 if let Some(end) = optimization.description[start + 6..].find(' ') {
704 let value_str = &optimization.description[start + 6..start + 6 + end];
705 node.set_optimization_hint("assumed_constant_value", value_str)?;
706 }
707 }
708
709 node.set_optimization_hint("add_value_guard", "true")?;
711 node.set_optimization_hint("guard_tolerance", "1e-10")?;
712
713 return Ok(true);
714 }
715
716 Ok(false)
717 }
718
719 fn apply_branch_elimination(
720 &self,
721 graph: &mut ComputationGraph,
722 optimization: &SpeculativeOptimization,
723 ) -> JitResult<bool> {
724 let node_id = optimization.node_id;
725
726 if let Some(node) = graph.node_mut(node_id) {
727 let usually_taken = optimization.description.contains("usually taken");
729
730 if usually_taken {
731 node.set_optimization_hint("branch_likely", "true")?;
732 node.set_optimization_hint("optimize_taken_path", "true")?;
733 } else {
734 node.set_optimization_hint("branch_unlikely", "true")?;
735 node.set_optimization_hint("optimize_not_taken_path", "true")?;
736 }
737
738 if optimization.estimated_benefit > 0.08 {
740 node.set_optimization_hint("branch_elimination_candidate", "true")?;
742 node.set_optimization_hint("speculative_branch_elimination", "true")?;
743 }
744
745 node.set_optimization_hint("add_branch_guard", "true")?;
747
748 return Ok(true);
749 }
750
751 Ok(false)
752 }
753
754 fn should_check_guard(&self, guard: &Guard, execution_count: u64) -> bool {
755 match guard.check_frequency {
756 GuardFrequency::Always => true,
757 GuardFrequency::Probabilistic(probability) => {
758 use std::collections::hash_map::DefaultHasher;
759 use std::hash::{Hash, Hasher};
760
761 let mut hasher = DefaultHasher::new();
762 execution_count.hash(&mut hasher);
763 let hash = hasher.finish();
764 (hash as f64 / u64::MAX as f64) < probability
765 }
766 GuardFrequency::Periodic(period) => execution_count % period == 0,
767 GuardFrequency::InitialOnly(limit) => execution_count < limit,
768 }
769 }
770
771 fn execute_guard_check(&self, guard: &Guard, runtime_info: &RuntimeInfo) -> JitResult<bool> {
772 match guard.guard_type {
773 GuardType::TypeCheck => {
774 Ok(runtime_info.actual_type == runtime_info.expected_type)
776 }
777 GuardType::ShapeCheck => {
778 Ok(runtime_info.actual_shape == runtime_info.expected_shape)
780 }
781 GuardType::ValueCheck => {
782 Ok(
784 (runtime_info.actual_value - runtime_info.expected_value).abs()
785 < runtime_info.tolerance,
786 )
787 }
788 GuardType::NullabilityCheck => {
789 Ok(!runtime_info.actual_value.is_nan() && runtime_info.actual_value.is_finite())
791 }
792 GuardType::BranchCheck => {
793 Ok(runtime_info.branch_taken == runtime_info.expected_branch_taken)
795 }
796 GuardType::LoopCheck => {
797 let diff = (runtime_info.actual_iterations as i64
799 - runtime_info.expected_iterations as i64)
800 .abs();
801 Ok(diff <= runtime_info.iteration_tolerance as i64)
802 }
803 GuardType::MemoryCheck => {
804 Ok(runtime_info.memory_pattern == runtime_info.expected_memory_pattern)
806 }
807 }
808 }
809
810 fn generate_assumption_id(&self) -> AssumptionId {
811 use std::sync::atomic::{AtomicU64, Ordering};
812 static COUNTER: AtomicU64 = AtomicU64::new(0);
813 AssumptionId(COUNTER.fetch_add(1, Ordering::Relaxed))
814 }
815
816 fn create_assumption(&self, id: AssumptionId) -> Option<Assumption> {
817 Some(Assumption {
820 id,
821 assumption_type: AssumptionType::NullabilitySpeculation,
822 node_id: NodeIndex::new(0),
823 confidence: 0.8,
824 success_count: 0,
825 failure_count: 0,
826 created_at: std::time::SystemTime::now(),
827 metadata: HashMap::new(),
828 })
829 }
830
831 fn calculate_confidence(&self, assumption: &Assumption) -> f64 {
832 let total = assumption.success_count + assumption.failure_count;
833 if total == 0 {
834 return 0.5; }
836
837 assumption.success_count as f64 / total as f64
838 }
839
840 fn log_deoptimization_event(&self, event: &DeoptimizationEvent) {
841 eprintln!("Deoptimization: {:?}", event);
843 }
844}
845
846#[derive(Debug, Clone)]
848pub struct RuntimeInfo {
849 pub execution_count: u64,
850 pub actual_type: String,
851 pub expected_type: String,
852 pub actual_shape: Vec<usize>,
853 pub expected_shape: Vec<usize>,
854 pub actual_value: f64,
855 pub expected_value: f64,
856 pub tolerance: f64,
857 pub branch_taken: bool,
858 pub expected_branch_taken: bool,
859 pub actual_iterations: u64,
860 pub expected_iterations: u64,
861 pub iteration_tolerance: u64,
862 pub memory_pattern: MemoryAccessPattern,
863 pub expected_memory_pattern: MemoryAccessPattern,
864}
865
866#[derive(Debug, Clone)]
868pub struct NodeExecutionHistory {
869 types: Vec<String>,
870 shapes: Vec<Vec<usize>>,
871 values: Vec<f64>,
872 branch_outcomes: Vec<bool>,
873 loop_iterations: Vec<u64>,
874}
875
876impl NodeExecutionHistory {
877 pub fn get_dominant_type(&self, threshold: f64) -> Option<String> {
878 let mut type_counts = HashMap::new();
879 for type_name in &self.types {
880 *type_counts.entry(type_name.clone()).or_insert(0) += 1;
881 }
882
883 if let Some((dominant_type, count)) = type_counts.iter().max_by_key(|(_, &count)| count) {
884 if *count as f64 / self.types.len() as f64 >= threshold {
885 return Some(dominant_type.clone());
886 }
887 }
888
889 None
890 }
891
892 pub fn get_dominant_shape(&self, threshold: f64) -> Option<Vec<usize>> {
893 let mut shape_counts = HashMap::new();
894 for shape in &self.shapes {
895 *shape_counts.entry(shape.clone()).or_insert(0) += 1;
896 }
897
898 if let Some((dominant_shape, count)) = shape_counts.iter().max_by_key(|(_, &count)| count) {
899 if *count as f64 / self.shapes.len() as f64 >= threshold {
900 return Some(dominant_shape.clone());
901 }
902 }
903
904 None
905 }
906
907 pub fn get_constant_value(&self, threshold: f64) -> Option<f64> {
908 if self.values.is_empty() {
909 return None;
910 }
911
912 let first_value = self.values[0];
914 let tolerance = 1e-10;
915 let constant_count = self
916 .values
917 .iter()
918 .filter(|&&v| (v - first_value).abs() < tolerance)
919 .count();
920
921 if constant_count as f64 / self.values.len() as f64 >= threshold {
922 Some(first_value)
923 } else {
924 None
925 }
926 }
927
928 pub fn get_branch_bias(&self, threshold: f64) -> Option<f64> {
929 if self.branch_outcomes.is_empty() {
930 return None;
931 }
932
933 let taken_count = self.branch_outcomes.iter().filter(|&&taken| taken).count();
934 let bias = taken_count as f64 / self.branch_outcomes.len() as f64;
935
936 if (bias - 0.5).abs() >= (threshold - 0.5) {
938 Some(bias)
939 } else {
940 None
941 }
942 }
943}
944
945#[derive(Debug, Clone)]
947pub struct ExecutionHistory {
948 node_histories: HashMap<NodeId, NodeExecutionHistory>,
949}
950
951impl ExecutionHistory {
952 pub fn new() -> Self {
953 Self {
954 node_histories: HashMap::new(),
955 }
956 }
957
958 pub fn get_node_history(&self, node_id: NodeId) -> Option<&NodeExecutionHistory> {
959 self.node_histories.get(&node_id)
960 }
961
962 pub fn record_execution(&mut self, node_id: NodeId, info: NodeExecutionInfo) {
963 let history = self
964 .node_histories
965 .entry(node_id)
966 .or_insert_with(|| NodeExecutionHistory {
967 types: Vec::new(),
968 shapes: Vec::new(),
969 values: Vec::new(),
970 branch_outcomes: Vec::new(),
971 loop_iterations: Vec::new(),
972 });
973
974 if let Some(type_name) = info.type_name {
975 history.types.push(type_name);
976 }
977 if let Some(shape) = info.shape {
978 history.shapes.push(shape);
979 }
980 if let Some(value) = info.value {
981 history.values.push(value);
982 }
983 if let Some(branch_taken) = info.branch_taken {
984 history.branch_outcomes.push(branch_taken);
985 }
986 if let Some(iterations) = info.loop_iterations {
987 history.loop_iterations.push(iterations);
988 }
989 }
990}
991
992#[derive(Debug, Clone)]
994pub struct NodeExecutionInfo {
995 pub type_name: Option<String>,
996 pub shape: Option<Vec<usize>>,
997 pub value: Option<f64>,
998 pub branch_taken: Option<bool>,
999 pub loop_iterations: Option<u64>,
1000}
1001
1002#[derive(Debug, Clone)]
1004pub struct SpeculationStatistics {
1005 pub active_assumptions: usize,
1006 pub total_successes: u64,
1007 pub total_failures: u64,
1008 pub avg_confidence: f64,
1009 pub deoptimization_count: u64,
1010 pub enabled: bool,
1011}
1012
1013#[cfg(test)]
1014mod tests {
1015 use super::*;
1016
1017 #[test]
1018 fn test_speculative_optimizer_creation() {
1019 let config = SpeculativeConfig::default();
1020 let optimizer = SpeculativeOptimizer::new(config);
1021 assert!(optimizer.enabled.load(Ordering::Relaxed));
1022 assert_eq!(optimizer.deopt_counter.load(Ordering::Relaxed), 0);
1023 }
1024
1025 #[test]
1026 fn test_assumption_id_generation() {
1027 let optimizer = SpeculativeOptimizer::new(SpeculativeConfig::default());
1028 let id1 = optimizer.generate_assumption_id();
1029 let id2 = optimizer.generate_assumption_id();
1030 assert_ne!(id1, id2);
1031 }
1032
1033 #[test]
1034 fn test_guard_frequency_checking() {
1035 let optimizer = SpeculativeOptimizer::new(SpeculativeConfig::default());
1036
1037 let always_guard = Guard {
1038 assumption_id: AssumptionId(1),
1039 guard_type: GuardType::TypeCheck,
1040 check_frequency: GuardFrequency::Always,
1041 };
1042 assert!(optimizer.should_check_guard(&always_guard, 100));
1043
1044 let periodic_guard = Guard {
1045 assumption_id: AssumptionId(2),
1046 guard_type: GuardType::TypeCheck,
1047 check_frequency: GuardFrequency::Periodic(10),
1048 };
1049 assert!(optimizer.should_check_guard(&periodic_guard, 100));
1050 assert!(!optimizer.should_check_guard(&periodic_guard, 101));
1051 }
1052
1053 #[test]
1054 fn test_execution_history() {
1055 let mut history = ExecutionHistory::new();
1056 let node_id = NodeId::new(1);
1057
1058 history.record_execution(
1060 node_id,
1061 NodeExecutionInfo {
1062 type_name: Some("f32".to_string()),
1063 shape: Some(vec![10, 20]),
1064 value: Some(1.0),
1065 branch_taken: Some(true),
1066 loop_iterations: Some(5),
1067 },
1068 );
1069
1070 history.record_execution(
1071 node_id,
1072 NodeExecutionInfo {
1073 type_name: Some("f32".to_string()),
1074 shape: Some(vec![10, 20]),
1075 value: Some(1.0),
1076 branch_taken: Some(true),
1077 loop_iterations: Some(5),
1078 },
1079 );
1080
1081 let node_history = history.get_node_history(node_id).unwrap();
1082 assert_eq!(node_history.get_dominant_type(0.8), Some("f32".to_string()));
1083 assert_eq!(node_history.get_dominant_shape(0.8), Some(vec![10, 20]));
1084 assert_eq!(node_history.get_constant_value(0.8), Some(1.0));
1085 assert_eq!(node_history.get_branch_bias(0.8), Some(1.0));
1086 }
1087}