1use crate::operations::{Arithmetic, arithmetic::BinaryOperation};
2use crate::traits::Shareable;
3use std::collections::HashMap;
4use std::sync::Arc;
5
/// Tuning knobs for adaptive sampling.
///
/// NOTE(review): nothing in this file reads these fields beyond storing them in a
/// `SampleContext`; the descriptions below are inferred from the names and should be
/// confirmed against the code that actually consumes the configuration.
#[derive(Debug, Clone)]
pub struct AdaptiveSampling {
    /// Presumably the lower bound on the number of samples drawn — TODO confirm.
    pub min_samples: usize,
    /// Presumably the upper bound on the number of samples drawn — TODO confirm.
    pub max_samples: usize,
    /// Presumably the target error at which sampling may stop — TODO confirm.
    pub error_threshold: f64,
    /// Presumably the multiplier applied to the sample count between rounds — TODO confirm.
    pub growth_factor: f64,
}

impl Default for AdaptiveSampling {
    /// Returns 100..=10 000 samples, a 0.01 error threshold and a 1.5 growth factor.
    fn default() -> Self {
        let (min_samples, max_samples) = (100, 10_000);
        let error_threshold = 0.01;
        let growth_factor = 1.5;
        Self {
            min_samples,
            max_samples,
            error_threshold,
            growth_factor,
        }
    }
}
29
/// Policy consulted by `SampleContext::should_cache_node` to decide whether a node's
/// result is worth caching.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CachingStrategy {
    /// Every node qualifies for caching.
    Aggressive,
    /// Only nodes deeper than two levels, or conditional nodes, qualify.
    Conservative,
    /// Only nodes whose `compute_complexity()` exceeds 5 qualify.
    Adaptive,
}
40
41pub struct SampleContext {
44 memoized_values: HashMap<uuid::Uuid, Box<dyn std::any::Any + Send>>,
46 caching_strategy: CachingStrategy,
48 adaptive_sampling: AdaptiveSampling,
50}
51
52impl SampleContext {
53 #[must_use]
55 pub fn new() -> Self {
56 Self {
57 memoized_values: HashMap::new(),
58 caching_strategy: CachingStrategy::Adaptive,
59 adaptive_sampling: AdaptiveSampling::default(),
60 }
61 }
62
63 #[must_use]
65 pub fn with_caching_strategy(strategy: CachingStrategy) -> Self {
66 Self {
67 memoized_values: HashMap::new(),
68 caching_strategy: strategy,
69 adaptive_sampling: AdaptiveSampling::default(),
70 }
71 }
72
73 #[must_use]
75 pub fn get_value<T: Clone + 'static>(&self, id: &uuid::Uuid) -> Option<T> {
76 self.memoized_values.get(id)?.downcast_ref::<T>().cloned()
77 }
78
79 pub fn set_value<T: Clone + Send + 'static>(&mut self, id: uuid::Uuid, value: T) {
81 self.memoized_values.insert(id, Box::new(value));
82 }
83
84 pub fn clear(&mut self) {
86 self.memoized_values.clear();
87 }
88
89 #[must_use]
91 pub fn len(&self) -> usize {
92 self.memoized_values.len()
93 }
94
95 #[must_use]
97 pub fn is_empty(&self) -> bool {
98 self.memoized_values.is_empty()
99 }
100
101 #[must_use]
103 pub fn should_cache_node(&self, node: &ComputationNode<impl Shareable>) -> bool {
104 match self.caching_strategy {
105 CachingStrategy::Aggressive => true,
106 CachingStrategy::Conservative => {
107 node.depth() > 2 || matches!(node, ComputationNode::Conditional { .. })
109 }
110 CachingStrategy::Adaptive => {
111 let complexity = node.compute_complexity();
112 complexity > 5 }
114 }
115 }
116
117 #[must_use]
119 pub fn adaptive_sampling(&self) -> &AdaptiveSampling {
120 &self.adaptive_sampling
121 }
122
123 pub fn set_adaptive_sampling(&mut self, config: AdaptiveSampling) {
125 self.adaptive_sampling = config;
126 }
127}
128
129impl Default for SampleContext {
130 fn default() -> Self {
131 Self::new()
132 }
133}
134
135#[derive(Clone)]
141pub enum ComputationNode<T> {
142 Leaf {
144 id: uuid::Uuid,
145 sample: Arc<dyn Fn() -> T + Send + Sync>,
146 },
147
148 BinaryOp {
150 left: Box<ComputationNode<T>>,
151 right: Box<ComputationNode<T>>,
152 operation: BinaryOperation,
153 },
154
155 UnaryOp {
157 operand: Box<ComputationNode<T>>,
158 operation: UnaryOperation<T>,
159 },
160
161 Conditional {
163 condition: Box<ComputationNode<bool>>,
164 if_true: Box<ComputationNode<T>>,
165 if_false: Box<ComputationNode<T>>,
166 },
167}
168
/// A single-input operation attached to a `ComputationNode::UnaryOp`.
///
/// Closures are held in `Arc`s, so cloning an operation shares the closure instead of
/// copying it.
#[derive(Clone)]
pub enum UnaryOperation<T> {
    /// Transforms the operand's value.
    Map(Arc<dyn Fn(T) -> T + Send + Sync>),
    /// A predicate over the operand. Every evaluator in this module passes the operand
    /// through unchanged for this variant; the predicate is never invoked there.
    Filter(Arc<dyn Fn(&T) -> bool + Send + Sync>),
}
175
176impl<T> ComputationNode<T>
177where
178 T: Shareable,
179{
180 pub fn evaluate(&self, context: &mut SampleContext) -> T {
190 match self {
191 ComputationNode::Leaf { id, sample } => {
192 if let Some(cached) = context.get_value::<T>(id) {
194 cached
195 } else {
196 let value = sample();
198 context.set_value(*id, value.clone());
199 value
200 }
201 }
202
203 ComputationNode::UnaryOp { operand, operation } => {
204 let operand_val = operand.evaluate(context);
205 match operation {
206 UnaryOperation::Map(func) => func(operand_val),
207 UnaryOperation::Filter(_) => {
208 operand_val
211 }
212 }
213 }
214
215 ComputationNode::BinaryOp { .. } => {
217 panic!(
218 "BinaryOp evaluation requires arithmetic trait bounds. Use evaluate_arithmetic instead."
219 )
220 }
221
222 ComputationNode::Conditional { .. } => {
223 panic!(
224 "Conditional evaluation requires specific handling. Use evaluate_conditional instead."
225 )
226 }
227 }
228 }
229
230 pub fn evaluate_arithmetic(&self, context: &mut SampleContext) -> T
236 where
237 T: Arithmetic,
238 {
239 match self {
240 ComputationNode::Leaf { id, sample } => {
241 if let Some(cached) = context.get_value::<T>(id) {
242 cached
243 } else {
244 let value = sample();
245 context.set_value(*id, value.clone());
246 value
247 }
248 }
249
250 ComputationNode::BinaryOp {
251 left,
252 right,
253 operation,
254 } => {
255 let left_val = left.evaluate_arithmetic(context);
256 let right_val = right.evaluate_arithmetic(context);
257 operation.apply(left_val, right_val)
258 }
259
260 ComputationNode::UnaryOp { operand, operation } => {
261 let operand_val = operand.evaluate_arithmetic(context);
262 match operation {
263 UnaryOperation::Map(func) => func(operand_val),
264 UnaryOperation::Filter(_) => operand_val,
265 }
266 }
267
268 ComputationNode::Conditional {
269 condition: _,
270 if_true: _,
271 if_false: _,
272 } => {
273 panic!(
274 "Conditional evaluation with bool condition not supported in arithmetic context"
275 )
276 }
277 }
278 }
279
280 #[must_use]
285 pub fn evaluate_fresh(&self) -> T
286 where
287 T: Arithmetic,
288 {
289 let mut context = SampleContext::new();
290 self.evaluate_conditional_with_arithmetic(&mut context)
291 }
292
293 pub fn leaf<F>(sample: F) -> Self
295 where
296 F: Fn() -> T + Send + Sync + 'static,
297 {
298 ComputationNode::Leaf {
299 id: uuid::Uuid::new_v4(),
300 sample: Arc::new(sample),
301 }
302 }
303
304 #[must_use]
306 pub fn binary_op(
307 left: ComputationNode<T>,
308 right: ComputationNode<T>,
309 operation: BinaryOperation,
310 ) -> Self {
311 ComputationNode::BinaryOp {
312 left: Box::new(left),
313 right: Box::new(right),
314 operation,
315 }
316 }
317
318 pub fn map<F>(operand: ComputationNode<T>, func: F) -> Self
320 where
321 F: Fn(T) -> T + Send + Sync + 'static,
322 {
323 ComputationNode::UnaryOp {
324 operand: Box::new(operand),
325 operation: UnaryOperation::Map(Arc::new(func)),
326 }
327 }
328
329 #[must_use]
331 pub fn conditional(
332 condition: ComputationNode<bool>,
333 if_true: ComputationNode<T>,
334 if_false: ComputationNode<T>,
335 ) -> Self {
336 ComputationNode::Conditional {
337 condition: Box::new(condition),
338 if_true: Box::new(if_true),
339 if_false: Box::new(if_false),
340 }
341 }
342
343 #[must_use]
345 pub fn node_count(&self) -> usize {
346 match self {
347 ComputationNode::Leaf { .. } => 1,
348 ComputationNode::BinaryOp { left, right, .. } => {
349 1 + left.node_count() + right.node_count()
350 }
351 ComputationNode::UnaryOp { operand, .. } => 1 + operand.node_count(),
352 ComputationNode::Conditional {
353 condition,
354 if_true,
355 if_false,
356 } => 1 + condition.node_count() + if_true.node_count() + if_false.node_count(),
357 }
358 }
359
360 #[must_use]
362 pub fn depth(&self) -> usize {
363 match self {
364 ComputationNode::Leaf { .. } => 1,
365 ComputationNode::BinaryOp { left, right, .. } => 1 + left.depth().max(right.depth()),
366 ComputationNode::UnaryOp { operand, .. } => 1 + operand.depth(),
367 ComputationNode::Conditional {
368 condition,
369 if_true,
370 if_false,
371 } => 1 + condition.depth().max(if_true.depth().max(if_false.depth())),
372 }
373 }
374
375 #[must_use]
377 pub fn has_conditionals(&self) -> bool {
378 match self {
379 ComputationNode::Leaf { .. } => false,
380 ComputationNode::BinaryOp { left, right, .. } => {
381 left.has_conditionals() || right.has_conditionals()
382 }
383 ComputationNode::UnaryOp { operand, .. } => operand.has_conditionals(),
384 ComputationNode::Conditional { .. } => true,
385 }
386 }
387
388 #[must_use]
390 pub fn compute_complexity(&self) -> usize {
391 match self {
392 ComputationNode::Leaf { .. } => 1,
393 ComputationNode::BinaryOp { left, right, .. } => {
394 2 + left.compute_complexity() + right.compute_complexity()
395 }
396 ComputationNode::UnaryOp { operand, .. } => 1 + operand.compute_complexity(),
397 ComputationNode::Conditional {
398 condition,
399 if_true,
400 if_false,
401 } => {
402 5 + condition.compute_complexity()
403 + if_true.compute_complexity()
404 + if_false.compute_complexity()
405 }
406 }
407 }
408
409 #[must_use]
411 pub fn structural_hash(&self) -> u64 {
412 use std::collections::hash_map::DefaultHasher;
413 use std::hash::Hasher;
414
415 let mut hasher = DefaultHasher::new();
416 self.hash_structure(&mut hasher);
417 hasher.finish()
418 }
419
420 fn hash_structure(&self, hasher: &mut impl std::hash::Hasher) {
421 use std::hash::Hash;
422
423 match self {
424 ComputationNode::Leaf { id, .. } => {
425 "leaf".hash(hasher);
426 id.hash(hasher);
427 }
428 ComputationNode::BinaryOp {
429 left,
430 right,
431 operation,
432 } => {
433 "binary".hash(hasher);
434 operation.hash(hasher);
435 left.hash_structure(hasher);
436 right.hash_structure(hasher);
437 }
438 ComputationNode::UnaryOp { operand, .. } => {
439 "unary".hash(hasher);
440 operand.hash_structure(hasher);
441 }
442 ComputationNode::Conditional {
443 condition,
444 if_true,
445 if_false,
446 } => {
447 "conditional".hash(hasher);
448 condition.hash_structure(hasher);
449 if_true.hash_structure(hasher);
450 if_false.hash_structure(hasher);
451 }
452 }
453 }
454}
455
456impl ComputationNode<bool> {
458 pub fn evaluate_bool(&self, context: &mut SampleContext) -> bool {
464 match self {
465 ComputationNode::Leaf { id, sample } => {
466 if let Some(cached) = context.get_value::<bool>(id) {
467 cached
468 } else {
469 let value = sample();
470 context.set_value(*id, value);
471 value
472 }
473 }
474 ComputationNode::UnaryOp { operand, operation } => {
475 let operand_val = operand.evaluate_bool(context);
476 match operation {
477 UnaryOperation::Map(func) => func(operand_val),
478 UnaryOperation::Filter(_) => operand_val,
479 }
480 }
481 ComputationNode::BinaryOp { .. } => {
482 panic!("Boolean binary operations not implemented")
483 }
484 ComputationNode::Conditional {
485 condition,
486 if_true,
487 if_false,
488 } => {
489 let condition_val = condition.evaluate_bool(context);
490 if condition_val {
491 if_true.evaluate_bool(context)
492 } else {
493 if_false.evaluate_bool(context)
494 }
495 }
496 }
497 }
498}
499
500impl<T> ComputationNode<T>
502where
503 T: Shareable,
504{
505 pub fn evaluate_conditional_with_arithmetic(&self, context: &mut SampleContext) -> T
507 where
508 T: Arithmetic,
509 {
510 match self {
511 ComputationNode::Conditional {
512 condition,
513 if_true,
514 if_false,
515 } => {
516 let condition_val = condition.evaluate_bool(context);
517 if condition_val {
518 if_true.evaluate_arithmetic(context)
519 } else {
520 if_false.evaluate_arithmetic(context)
521 }
522 }
523 _ => self.evaluate_arithmetic(context),
524 }
525 }
526}
527
528pub struct GraphOptimizer {
530 pub subexpression_cache: HashMap<u64, Box<dyn std::any::Any + Send + Sync>>,
532}
533
534impl GraphOptimizer {
535 #[must_use]
537 pub fn new() -> Self {
538 Self {
539 subexpression_cache: HashMap::new(),
540 }
541 }
542
543 #[must_use]
545 pub fn optimize<T>(&mut self, node: ComputationNode<T>) -> ComputationNode<T>
546 where
547 T: Shareable + Arithmetic + PartialEq + Clone,
548 {
549 let node = self.eliminate_common_subexpressions(node);
550 let node = Self::eliminate_identity_operations(node);
551 Self::constant_folding(node)
552 }
553
554 pub fn eliminate_common_subexpressions<T>(
556 &mut self,
557 node: ComputationNode<T>,
558 ) -> ComputationNode<T>
559 where
560 T: Shareable,
561 {
562 let hash = node.structural_hash();
563
564 if let Some(cached_node) = self.subexpression_cache.get(&hash)
566 && let Some(cached) = cached_node.downcast_ref::<ComputationNode<T>>()
567 {
568 return cached.clone();
569 }
570
571 let optimized = match node {
573 ComputationNode::BinaryOp {
574 left,
575 right,
576 operation,
577 } => {
578 let left_opt = Box::new(self.eliminate_common_subexpressions(*left));
579 let right_opt = Box::new(self.eliminate_common_subexpressions(*right));
580 ComputationNode::BinaryOp {
581 left: left_opt,
582 right: right_opt,
583 operation,
584 }
585 }
586 ComputationNode::UnaryOp { operand, operation } => {
587 let operand_opt = Box::new(self.eliminate_common_subexpressions(*operand));
588 ComputationNode::UnaryOp {
589 operand: operand_opt,
590 operation,
591 }
592 }
593 ComputationNode::Conditional {
594 condition,
595 if_true,
596 if_false,
597 } => {
598 let condition_opt = Box::new(self.eliminate_common_subexpressions(*condition));
599 let if_true_opt = Box::new(self.eliminate_common_subexpressions(*if_true));
600 let if_false_opt = Box::new(self.eliminate_common_subexpressions(*if_false));
601 ComputationNode::Conditional {
602 condition: condition_opt,
603 if_true: if_true_opt,
604 if_false: if_false_opt,
605 }
606 }
607 leaf @ ComputationNode::Leaf { .. } => leaf,
608 };
609
610 self.subexpression_cache
612 .insert(hash, Box::new(optimized.clone()));
613
614 optimized
615 }
616
617 #[allow(clippy::too_many_lines)]
619 fn eliminate_identity_operations<T>(node: ComputationNode<T>) -> ComputationNode<T>
620 where
621 T: Shareable + Arithmetic + PartialEq + Clone,
622 {
623 match node {
624 ComputationNode::BinaryOp {
625 left,
626 right,
627 operation,
628 } => Self::eliminate_identity_operations_binary(*left, *right, operation),
629 ComputationNode::UnaryOp { operand, operation } => {
630 Self::eliminate_identity_operations_unary(*operand, operation)
631 }
632 ComputationNode::Conditional {
633 condition,
634 if_true,
635 if_false,
636 } => Self::eliminate_identity_operations_conditional(*condition, *if_true, *if_false),
637 ComputationNode::Leaf { .. } => node,
638 }
639 }
640
641 fn eliminate_identity_operations_binary<T>(
643 left: ComputationNode<T>,
644 right: ComputationNode<T>,
645 operation: BinaryOperation,
646 ) -> ComputationNode<T>
647 where
648 T: Shareable + Arithmetic + PartialEq + Clone,
649 {
650 let left_opt = Self::eliminate_identity_operations(left);
651 let right_opt = Self::eliminate_identity_operations(right);
652
653 match operation {
655 BinaryOperation::Add => {
656 if let Some(result) = Self::check_addition_identities(&left_opt, &right_opt) {
657 return result;
658 }
659 }
660 BinaryOperation::Sub => {
661 if let Some(result) = Self::check_subtraction_identities(&left_opt, &right_opt) {
662 return result;
663 }
664 }
665 BinaryOperation::Mul => {
666 if let Some(result) = Self::check_multiplication_identities(&left_opt, &right_opt) {
667 return result;
668 }
669 }
670 BinaryOperation::Div => {
671 if let Some(result) = Self::check_division_identities(&left_opt, &right_opt) {
672 return result;
673 }
674 }
675 }
676
677 ComputationNode::BinaryOp {
678 left: Box::new(left_opt),
679 right: Box::new(right_opt),
680 operation,
681 }
682 }
683
684 fn check_addition_identities<T>(
686 left: &ComputationNode<T>,
687 right: &ComputationNode<T>,
688 ) -> Option<ComputationNode<T>>
689 where
690 T: Shareable + Arithmetic + PartialEq + Clone,
691 {
692 match (left, right) {
693 (
695 left,
696 ComputationNode::Leaf {
697 sample: right_sample,
698 ..
699 },
700 ) => {
701 if Self::is_constant_zero(right_sample) {
702 return Some(left.clone());
703 }
704 }
705 (
707 ComputationNode::Leaf {
708 sample: left_sample,
709 ..
710 },
711 right,
712 ) => {
713 if Self::is_constant_zero(left_sample) {
714 return Some(right.clone());
715 }
716 }
717 _ => {}
718 }
719 None
720 }
721
722 fn check_subtraction_identities<T>(
724 left: &ComputationNode<T>,
725 right: &ComputationNode<T>,
726 ) -> Option<ComputationNode<T>>
727 where
728 T: Shareable + Arithmetic + PartialEq + Clone,
729 {
730 if let (
732 left,
733 ComputationNode::Leaf {
734 sample: right_sample,
735 ..
736 },
737 ) = (left, right)
738 && Self::is_constant_zero(right_sample)
739 {
740 return Some(left.clone());
741 }
742 None
743 }
744
745 fn check_multiplication_identities<T>(
747 left: &ComputationNode<T>,
748 right: &ComputationNode<T>,
749 ) -> Option<ComputationNode<T>>
750 where
751 T: Shareable + Arithmetic + PartialEq + Clone,
752 {
753 match (left, right) {
755 (
757 _left,
758 ComputationNode::Leaf {
759 sample: right_sample,
760 ..
761 },
762 ) => {
763 if Self::is_constant_zero(right_sample) {
764 return Some(ComputationNode::leaf(|| T::zero()));
765 }
766 }
767 (
769 ComputationNode::Leaf {
770 sample: left_sample,
771 ..
772 },
773 _right,
774 ) => {
775 if Self::is_constant_zero(left_sample) {
776 return Some(ComputationNode::leaf(|| T::zero()));
777 }
778 }
779 _ => {}
780 }
781
782 match (left, right) {
784 (
786 left,
787 ComputationNode::Leaf {
788 sample: right_sample,
789 ..
790 },
791 ) => {
792 if Self::is_constant_one(right_sample) {
793 return Some(left.clone());
794 }
795 }
796 (
798 ComputationNode::Leaf {
799 sample: left_sample,
800 ..
801 },
802 right,
803 ) => {
804 if Self::is_constant_one(left_sample) {
805 return Some(right.clone());
806 }
807 }
808 _ => {}
809 }
810
811 None
812 }
813
814 fn check_division_identities<T>(
816 left: &ComputationNode<T>,
817 right: &ComputationNode<T>,
818 ) -> Option<ComputationNode<T>>
819 where
820 T: Shareable + Arithmetic + PartialEq + Clone,
821 {
822 if let (
824 left,
825 ComputationNode::Leaf {
826 sample: right_sample,
827 ..
828 },
829 ) = (left, right)
830 && Self::is_constant_one(right_sample)
831 {
832 return Some(left.clone());
833 }
834 None
835 }
836
837 fn eliminate_identity_operations_unary<T>(
839 operand: ComputationNode<T>,
840 operation: UnaryOperation<T>,
841 ) -> ComputationNode<T>
842 where
843 T: Shareable + Arithmetic + PartialEq + Clone,
844 {
845 let operand_opt = Self::eliminate_identity_operations(operand);
846 ComputationNode::UnaryOp {
847 operand: Box::new(operand_opt),
848 operation,
849 }
850 }
851
852 fn eliminate_identity_operations_conditional<T>(
854 condition: ComputationNode<bool>,
855 if_true: ComputationNode<T>,
856 if_false: ComputationNode<T>,
857 ) -> ComputationNode<T>
858 where
859 T: Shareable + Arithmetic + PartialEq + Clone,
860 {
861 let condition_opt = Self::eliminate_identity_operations_bool(condition);
863 let if_true_opt = Self::eliminate_identity_operations(if_true);
864 let if_false_opt = Self::eliminate_identity_operations(if_false);
865 ComputationNode::Conditional {
866 condition: Box::new(condition_opt),
867 if_true: Box::new(if_true_opt),
868 if_false: Box::new(if_false_opt),
869 }
870 }
871
872 fn eliminate_identity_operations_bool(node: ComputationNode<bool>) -> ComputationNode<bool> {
874 match node {
875 ComputationNode::UnaryOp { operand, operation } => {
876 let operand_opt = Self::eliminate_identity_operations_bool(*operand);
877 ComputationNode::UnaryOp {
878 operand: Box::new(operand_opt),
879 operation,
880 }
881 }
882 ComputationNode::Conditional {
883 condition,
884 if_true,
885 if_false,
886 } => {
887 let condition_opt = Self::eliminate_identity_operations_bool(*condition);
888 let if_true_opt = Self::eliminate_identity_operations_bool(*if_true);
889 let if_false_opt = Self::eliminate_identity_operations_bool(*if_false);
890 ComputationNode::Conditional {
891 condition: Box::new(condition_opt),
892 if_true: Box::new(if_true_opt),
893 if_false: Box::new(if_false_opt),
894 }
895 }
896 ComputationNode::Leaf { .. } | ComputationNode::BinaryOp { .. } => node,
897 }
898 }
899
900 fn is_constant_zero<T>(sample_fn: &Arc<dyn Fn() -> T + Send + Sync>) -> bool
902 where
903 T: PartialEq + Clone + Arithmetic,
904 {
905 for _ in 0..3 {
907 if sample_fn() != T::zero() {
908 return false;
909 }
910 }
911 true
912 }
913
914 fn is_constant_one<T>(sample_fn: &Arc<dyn Fn() -> T + Send + Sync>) -> bool
916 where
917 T: PartialEq + Clone + Arithmetic,
918 {
919 for _ in 0..3 {
921 if sample_fn() != T::one() {
922 return false;
923 }
924 }
925 true
926 }
927
928 fn constant_folding<T>(node: ComputationNode<T>) -> ComputationNode<T>
930 where
931 T: Shareable + Arithmetic + Clone + PartialEq,
932 {
933 match node {
934 ComputationNode::BinaryOp {
935 left,
936 right,
937 operation,
938 } => Self::constant_folding_binary_op(*left, *right, operation),
939 ComputationNode::UnaryOp { operand, operation } => {
940 Self::constant_folding_unary_op(*operand, operation)
941 }
942 ComputationNode::Conditional {
943 condition,
944 if_true,
945 if_false,
946 } => Self::constant_folding_conditional(*condition, *if_true, *if_false),
947 ComputationNode::Leaf { .. } => node,
948 }
949 }
950
951 fn constant_folding_binary_op<T>(
953 left: ComputationNode<T>,
954 right: ComputationNode<T>,
955 operation: BinaryOperation,
956 ) -> ComputationNode<T>
957 where
958 T: Shareable + Arithmetic + Clone + PartialEq,
959 {
960 let left_opt = Self::constant_folding(left);
961 let right_opt = Self::constant_folding(right);
962
963 if let (
964 ComputationNode::Leaf {
965 sample: left_sample,
966 ..
967 },
968 ComputationNode::Leaf {
969 sample: right_sample,
970 ..
971 },
972 ) = (&left_opt, &right_opt)
973 && Self::is_constant(left_sample)
974 && Self::is_constant(right_sample)
975 {
976 let left_val = left_sample();
977 let right_val = right_sample();
978 let result = match operation {
979 BinaryOperation::Add => left_val + right_val,
980 BinaryOperation::Sub => left_val - right_val,
981 BinaryOperation::Mul => left_val * right_val,
982 BinaryOperation::Div => left_val / right_val,
983 };
984 return ComputationNode::leaf(move || result.clone());
985 }
986
987 ComputationNode::BinaryOp {
988 left: Box::new(left_opt),
989 right: Box::new(right_opt),
990 operation,
991 }
992 }
993
994 fn constant_folding_unary_op<T>(
996 operand: ComputationNode<T>,
997 operation: UnaryOperation<T>,
998 ) -> ComputationNode<T>
999 where
1000 T: Shareable + Arithmetic + Clone + PartialEq,
1001 {
1002 let operand_opt = Self::constant_folding(operand);
1003
1004 if let ComputationNode::Leaf {
1005 sample: operand_sample,
1006 ..
1007 } = &operand_opt
1008 && Self::is_constant(operand_sample)
1009 {
1010 let operand_val = operand_sample();
1011 let result = match operation {
1012 UnaryOperation::Map(func) => func(operand_val),
1013 UnaryOperation::Filter(_) => operand_val, };
1015 return ComputationNode::leaf(move || result.clone());
1016 }
1017
1018 ComputationNode::UnaryOp {
1019 operand: Box::new(operand_opt),
1020 operation,
1021 }
1022 }
1023
1024 fn constant_folding_conditional<T>(
1026 condition: ComputationNode<bool>,
1027 if_true: ComputationNode<T>,
1028 if_false: ComputationNode<T>,
1029 ) -> ComputationNode<T>
1030 where
1031 T: Shareable + Arithmetic + Clone + PartialEq,
1032 {
1033 let condition_opt = Self::constant_folding_bool(condition);
1034 let if_true_opt = Self::constant_folding(if_true);
1035 let if_false_opt = Self::constant_folding(if_false);
1036
1037 if let ComputationNode::Leaf {
1039 sample: condition_sample,
1040 ..
1041 } = &condition_opt
1042 && Self::is_constant_bool(condition_sample)
1043 {
1044 let condition_val = condition_sample();
1045 if condition_val {
1046 return if_true_opt;
1047 }
1048 return if_false_opt;
1049 }
1050
1051 ComputationNode::Conditional {
1052 condition: Box::new(condition_opt),
1053 if_true: Box::new(if_true_opt),
1054 if_false: Box::new(if_false_opt),
1055 }
1056 }
1057
1058 fn constant_folding_bool(node: ComputationNode<bool>) -> ComputationNode<bool> {
1060 match node {
1061 ComputationNode::UnaryOp { operand, operation } => {
1062 Self::constant_folding_bool_unary_op(*operand, operation)
1063 }
1064 ComputationNode::Conditional {
1065 condition,
1066 if_true,
1067 if_false,
1068 } => Self::constant_folding_bool_conditional(*condition, *if_true, *if_false),
1069 ComputationNode::Leaf { .. } | ComputationNode::BinaryOp { .. } => node,
1070 }
1071 }
1072
1073 fn constant_folding_bool_unary_op(
1075 operand: ComputationNode<bool>,
1076 operation: UnaryOperation<bool>,
1077 ) -> ComputationNode<bool> {
1078 let operand_opt = Self::constant_folding_bool(operand);
1079
1080 if let ComputationNode::Leaf {
1081 sample: operand_sample,
1082 ..
1083 } = &operand_opt
1084 && Self::is_constant_bool(operand_sample)
1085 {
1086 let operand_val = operand_sample();
1087 let result = match operation {
1088 UnaryOperation::Map(func) => func(operand_val),
1089 UnaryOperation::Filter(_) => operand_val, };
1091 return ComputationNode::leaf(move || result);
1092 }
1093
1094 ComputationNode::UnaryOp {
1095 operand: Box::new(operand_opt),
1096 operation,
1097 }
1098 }
1099
1100 fn constant_folding_bool_conditional(
1102 condition: ComputationNode<bool>,
1103 if_true: ComputationNode<bool>,
1104 if_false: ComputationNode<bool>,
1105 ) -> ComputationNode<bool> {
1106 let condition_opt = Self::constant_folding_bool(condition);
1107 let if_true_opt = Self::constant_folding_bool(if_true);
1108 let if_false_opt = Self::constant_folding_bool(if_false);
1109
1110 if let ComputationNode::Leaf {
1111 sample: condition_sample,
1112 ..
1113 } = &condition_opt
1114 && Self::is_constant_bool(condition_sample)
1115 {
1116 let condition_val = condition_sample();
1117 if condition_val {
1118 return if_true_opt;
1119 }
1120 return if_false_opt;
1121 }
1122
1123 ComputationNode::Conditional {
1124 condition: Box::new(condition_opt),
1125 if_true: Box::new(if_true_opt),
1126 if_false: Box::new(if_false_opt),
1127 }
1128 }
1129
1130 fn is_constant<T>(sample_fn: &Arc<dyn Fn() -> T + Send + Sync>) -> bool
1132 where
1133 T: PartialEq + Clone,
1134 {
1135 let first_sample = sample_fn();
1137 for _ in 0..3 {
1138 if sample_fn() != first_sample {
1139 return false;
1140 }
1141 }
1142 true
1143 }
1144
1145 fn is_constant_bool(sample_fn: &Arc<dyn Fn() -> bool + Send + Sync>) -> bool {
1147 let first_sample = sample_fn();
1149 for _ in 0..3 {
1150 if sample_fn() != first_sample {
1151 return false;
1152 }
1153 }
1154 true
1155 }
1156}
1157
1158impl Default for GraphOptimizer {
1159 fn default() -> Self {
1160 Self::new()
1161 }
1162}
1163
1164pub struct GraphVisualizer;
1166
1167impl GraphVisualizer {
1168 #[must_use]
1170 pub fn to_dot<T>(node: &ComputationNode<T>) -> String
1171 where
1172 T: Shareable,
1173 {
1174 let mut dot = String::from("digraph G {\n");
1175 let mut node_id = 0;
1176 Self::add_node_to_dot(node, &mut dot, &mut node_id);
1177 dot.push_str("}\n");
1178 dot
1179 }
1180
1181 fn add_node_to_dot<T>(node: &ComputationNode<T>, dot: &mut String, node_id: &mut usize) -> usize
1182 where
1183 T: Shareable,
1184 {
1185 use std::fmt::Write;
1186 let current_id = *node_id;
1187 *node_id += 1;
1188
1189 match node {
1190 ComputationNode::Leaf { .. } => {
1191 writeln!(dot, " {current_id} [label=\"Leaf\", shape=circle];").unwrap();
1192 }
1193 ComputationNode::BinaryOp {
1194 left,
1195 right,
1196 operation,
1197 } => {
1198 let op_name = match operation {
1199 BinaryOperation::Add => "Add",
1200 BinaryOperation::Sub => "Sub",
1201 BinaryOperation::Mul => "Mul",
1202 BinaryOperation::Div => "Div",
1203 };
1204 writeln!(dot, " {current_id} [label=\"{op_name}\", shape=box];").unwrap();
1205
1206 let left_id = Self::add_node_to_dot(left, dot, node_id);
1207 let right_id = Self::add_node_to_dot(right, dot, node_id);
1208
1209 writeln!(dot, " {current_id} -> {left_id};").unwrap();
1210 writeln!(dot, " {current_id} -> {right_id};").unwrap();
1211 }
1212 ComputationNode::UnaryOp { operand, .. } => {
1213 writeln!(dot, " {current_id} [label=\"UnaryOp\", shape=box];").unwrap();
1214 let operand_id = Self::add_node_to_dot(operand, dot, node_id);
1215 writeln!(dot, " {current_id} -> {operand_id};").unwrap();
1216 }
1217 ComputationNode::Conditional {
1218 condition,
1219 if_true,
1220 if_false,
1221 } => {
1222 writeln!(dot, " {current_id} [label=\"If\", shape=diamond];").unwrap();
1223
1224 let cond_id = Self::add_node_to_dot(condition, dot, node_id);
1225 let true_id = Self::add_node_to_dot(if_true, dot, node_id);
1226 let false_id = Self::add_node_to_dot(if_false, dot, node_id);
1227
1228 writeln!(dot, " {current_id} -> {cond_id} [label=\"cond\"];").unwrap();
1229 writeln!(dot, " {current_id} -> {true_id} [label=\"true\"];").unwrap();
1230 writeln!(dot, " {current_id} -> {false_id} [label=\"false\"];").unwrap();
1231 }
1232 }
1233
1234 current_id
1235 }
1236
1237 pub fn print_tree<T>(node: &ComputationNode<T>, indent: usize)
1239 where
1240 T: Shareable,
1241 {
1242 let prefix = " ".repeat(indent);
1243
1244 match node {
1245 ComputationNode::Leaf { id, .. } => {
1246 println!("{prefix}Leaf({id})");
1247 }
1248 ComputationNode::BinaryOp {
1249 left,
1250 right,
1251 operation,
1252 } => {
1253 let op_name = match operation {
1254 BinaryOperation::Add => "Add",
1255 BinaryOperation::Sub => "Sub",
1256 BinaryOperation::Mul => "Mul",
1257 BinaryOperation::Div => "Div",
1258 };
1259 println!("{prefix}{op_name}");
1260 Self::print_tree(left, indent + 1);
1261 Self::print_tree(right, indent + 1);
1262 }
1263 ComputationNode::UnaryOp { operand, .. } => {
1264 println!("{prefix}UnaryOp");
1265 Self::print_tree(operand, indent + 1);
1266 }
1267 ComputationNode::Conditional {
1268 condition,
1269 if_true,
1270 if_false,
1271 } => {
1272 println!("{prefix}Conditional");
1273 println!("{prefix} Condition:");
1274 Self::print_tree(condition, indent + 2);
1275 println!("{prefix} If True:");
1276 Self::print_tree(if_true, indent + 2);
1277 println!("{prefix} If False:");
1278 Self::print_tree(if_false, indent + 2);
1279 }
1280 }
1281 }
1282}
1283
/// Collects wall-clock execution times per named operation and summarizes them.
pub struct GraphProfiler {
    /// Every recorded duration, grouped by the name passed to `profile_execution`.
    execution_times: HashMap<String, Vec<std::time::Duration>>,
}

impl GraphProfiler {
    /// Creates a profiler with no recorded timings.
    #[must_use]
    pub fn new() -> Self {
        Self {
            execution_times: HashMap::new(),
        }
    }

    /// Runs `func`, records its elapsed wall-clock time under `name`, and returns its
    /// result.
    pub fn profile_execution<T, F>(&mut self, name: &str, func: F) -> T
    where
        F: FnOnce() -> T,
    {
        let start = std::time::Instant::now();
        let result = func();
        let duration = start.elapsed();

        self.execution_times
            .entry(name.to_string())
            .or_default()
            .push(duration);

        result
    }

    /// Summary statistics for `name`, or `None` if nothing was recorded under it.
    ///
    /// For an even number of samples, the median is the mean of the two middle values.
    #[must_use]
    pub fn get_stats(&self, name: &str) -> Option<ProfileStats> {
        let times = self.execution_times.get(name)?;
        if times.is_empty() {
            return None;
        }

        let mut sorted_times = times.clone();
        sorted_times.sort_unstable();

        let count = sorted_times.len();
        let total: std::time::Duration = sorted_times.iter().sum();
        let average = Self::mean(total, count);

        let mid = count / 2;
        let median = if count % 2 == 0 {
            // Both middle values are <= total, which already summed without overflow.
            (sorted_times[mid - 1] + sorted_times[mid]) / 2
        } else {
            sorted_times[mid]
        };

        // Non-empty was checked above, so the first and last indices are valid.
        Some(ProfileStats {
            count,
            total,
            average,
            median,
            min: sorted_times[0],
            max: sorted_times[count - 1],
        })
    }

    /// `total / count` for a non-zero `count`.
    ///
    /// Falls back to exact 128-bit nanosecond arithmetic when `count` does not fit in a
    /// `u32`. The previous fallback silently divided by 1.
    fn mean(total: std::time::Duration, count: usize) -> std::time::Duration {
        if let Ok(divisor) = u32::try_from(count) {
            return total / divisor;
        }
        let divisor = u128::try_from(count).unwrap_or(u128::MAX);
        let nanos = total.as_nanos() / divisor;
        // nanos <= total.as_nanos(), so the seconds part fits in u64 and the remainder in u32.
        let secs = u64::try_from(nanos / 1_000_000_000).unwrap_or(u64::MAX);
        let subsec = u32::try_from(nanos % 1_000_000_000).unwrap_or(0);
        std::time::Duration::new(secs, subsec)
    }

    /// Prints the statistics for every recorded name to stdout, in sorted name order so
    /// the report is deterministic.
    pub fn print_report(&self) {
        println!("=== Computation Graph Profiling Report ===");
        let mut names: Vec<&String> = self.execution_times.keys().collect();
        names.sort_unstable();
        for name in names {
            if let Some(stats) = self.get_stats(name) {
                println!("\n{name}:");
                println!(" Count: {}", stats.count);
                println!(" Total: {:?}", stats.total);
                println!(" Average: {:?}", stats.average);
                println!(" Median: {:?}", stats.median);
                println!(" Min: {:?}", stats.min);
                println!(" Max: {:?}", stats.max);
            }
        }
    }
}

impl Default for GraphProfiler {
    fn default() -> Self {
        Self::new()
    }
}

/// Summary of the durations recorded under one name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProfileStats {
    /// Number of recorded executions.
    pub count: usize,
    /// Sum of all recorded durations.
    pub total: std::time::Duration,
    /// `total / count`.
    pub average: std::time::Duration,
    /// Middle duration; the mean of the two middle values for an even count.
    pub median: std::time::Duration,
    /// Shortest recorded duration.
    pub min: std::time::Duration,
    /// Longest recorded duration.
    pub max: std::time::Duration,
}
1391
#[cfg(test)]
mod tests {
    //! Unit tests for sample-context memoization, computation-graph
    //! evaluation, graph optimization, visualization, and profiling.

    use super::*;
    use crate::Uncertain;

    #[test]
    fn test_sample_context_memoization() {
        let mut context = SampleContext::new();
        let id = uuid::Uuid::new_v4();

        context.set_value(id, 42.0);

        // A stored value is returned when read back with the same type.
        assert_eq!(context.get_value::<f64>(&id), Some(42.0));

        // An id that was never stored yields `None`.
        let other_id = uuid::Uuid::new_v4();
        assert_eq!(context.get_value::<f64>(&other_id), None);
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn test_computation_node_evaluation() {
        let left = ComputationNode::leaf(|| 5.0);
        let right = ComputationNode::leaf(|| 3.0);
        let add_node = ComputationNode::binary_op(left, right, BinaryOperation::Add);

        let result = add_node.evaluate_fresh();
        assert_eq!(result, 8.0);
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn test_shared_variable_memoization() {
        let mut context = SampleContext::new();

        let leaf_id = uuid::Uuid::new_v4();
        // The leaf draws a random value, so two evaluations are expected to
        // agree only if the first result is memoized in `context`.
        let leaf = ComputationNode::Leaf {
            id: leaf_id,
            sample: Arc::new(rand::random::<f64>),
        };

        let val1 = leaf.evaluate(&mut context);
        let val2 = leaf.evaluate(&mut context);

        assert_eq!(val1, val2);
    }

    #[test]
    fn test_computation_graph_metrics() {
        let left = ComputationNode::leaf(|| 1.0);
        let right = ComputationNode::leaf(|| 2.0);
        let add_node = ComputationNode::binary_op(left, right, BinaryOperation::Add);

        // Two leaves plus the `Add` node, two levels deep, no conditionals.
        assert_eq!(add_node.node_count(), 3);
        assert_eq!(add_node.depth(), 2);
        assert!(!add_node.has_conditionals());
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn test_conditional_node() {
        let condition = ComputationNode::leaf(|| true);
        let if_true = ComputationNode::leaf(|| 10.0);
        let if_false = ComputationNode::leaf(|| 20.0);

        let conditional = ComputationNode::conditional(condition, if_true, if_false);

        // A constant-true condition selects the `if_true` branch.
        let result = conditional.evaluate_fresh();
        assert_eq!(result, 10.0);
        assert!(conditional.has_conditionals());
    }

    #[test]
    fn test_graph_visualizer_dot_output() {
        let left = ComputationNode::leaf(|| 1.0);
        let right = ComputationNode::leaf(|| 2.0);
        let add_node = ComputationNode::binary_op(left, right, BinaryOperation::Add);

        let dot = GraphVisualizer::to_dot(&add_node);

        assert!(dot.contains("digraph G"));
        assert!(dot.contains("Add"));
        assert!(dot.contains("Leaf"));
    }

    #[test]
    fn test_profiler() {
        let mut profiler = GraphProfiler::new();

        // The closure's return value is passed back to the caller.
        let result = profiler.profile_execution("test", || {
            std::thread::sleep(std::time::Duration::from_millis(10));
            42
        });

        assert_eq!(result, 42);

        let stats = profiler.get_stats("test").unwrap();
        assert_eq!(stats.count, 1);
        assert!(stats.total >= std::time::Duration::from_millis(10));
    }

    #[test]
    fn test_complex_computation_graph() {
        let x = Uncertain::normal(5.0, 1.0);
        let y = Uncertain::normal(3.0, 1.0);

        // (x + y) * (x - y): `x` and `y` each feed two sub-expressions.
        let sum = x.clone() + y.clone();
        let diff = x - y;
        let product = sum * diff;

        assert!(product.node.node_count() > 5);
        assert!(product.node.depth() > 2);

        let sample1 = product.sample();
        let sample2 = product.sample();

        // Loose sanity bounds on samples of x^2 - y^2.
        assert!(sample1 > -50.0 && sample1 < 150.0);
        assert!(sample2 > -50.0 && sample2 < 150.0);
    }

    #[test]
    fn test_sample_context_clear() {
        let mut context = SampleContext::new();
        let id = uuid::Uuid::new_v4();

        context.set_value(id, 42.0);
        assert_eq!(context.len(), 1);
        assert!(!context.is_empty());

        context.clear();
        assert_eq!(context.len(), 0);
        assert!(context.is_empty());
        assert_eq!(context.get_value::<f64>(&id), None);
    }

    #[test]
    fn test_sample_context_default() {
        let context = SampleContext::default();
        assert!(context.is_empty());
        assert_eq!(context.len(), 0);
    }

    #[test]
    #[should_panic(expected = "BinaryOp evaluation requires arithmetic trait bounds")]
    fn test_evaluate_panic_on_binary_op() {
        let left = ComputationNode::leaf(|| 1.0);
        let right = ComputationNode::leaf(|| 2.0);
        let binary_op = ComputationNode::binary_op(left, right, BinaryOperation::Add);

        let mut context = SampleContext::new();
        binary_op.evaluate(&mut context);
    }

    #[test]
    #[should_panic(expected = "Conditional evaluation requires specific handling")]
    fn test_evaluate_panic_on_conditional() {
        let condition = ComputationNode::leaf(|| true);
        let if_true = ComputationNode::leaf(|| 10.0);
        let if_false = ComputationNode::leaf(|| 20.0);
        let conditional = ComputationNode::conditional(condition, if_true, if_false);

        let mut context = SampleContext::new();
        conditional.evaluate(&mut context);
    }

    #[test]
    #[should_panic(
        expected = "Conditional evaluation with bool condition not supported in arithmetic context"
    )]
    fn test_evaluate_arithmetic_panic_on_conditional() {
        let condition = ComputationNode::leaf(|| true);
        let if_true = ComputationNode::leaf(|| 10.0);
        let if_false = ComputationNode::leaf(|| 20.0);
        let conditional = ComputationNode::conditional(condition, if_true, if_false);

        let mut context = SampleContext::new();
        conditional.evaluate_arithmetic(&mut context);
    }

    #[test]
    #[should_panic(expected = "Boolean binary operations not implemented")]
    fn test_evaluate_bool_panic_on_binary_op() {
        let left = ComputationNode::leaf(|| true);
        let right = ComputationNode::leaf(|| false);
        let binary_op = ComputationNode::binary_op(left, right, BinaryOperation::Add);

        let mut context = SampleContext::new();
        binary_op.evaluate_bool(&mut context);
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn test_unary_map_operation() {
        let operand = ComputationNode::leaf(|| 5.0);
        let mapped = ComputationNode::map(operand, |x| x * 2.0);

        let result = mapped.evaluate_fresh();
        assert_eq!(result, 10.0);
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn test_unary_filter_operation() {
        let operand = ComputationNode::leaf(|| 42.0);
        let filtered = ComputationNode::UnaryOp {
            operand: Box::new(operand),
            operation: UnaryOperation::Filter(Arc::new(|x: &f64| *x > 0.0)),
        };

        let mut context = SampleContext::new();
        let result = filtered.evaluate(&mut context);
        // 42.0 satisfies the `> 0.0` predicate, so the value comes through.
        assert_eq!(result, 42.0);
    }

    #[test]
    fn test_graph_optimizer() {
        let node = ComputationNode::leaf(|| 1.0);
        let mut optimizer = GraphOptimizer::new();
        let optimized_node = optimizer.optimize(node);
        assert_eq!(optimized_node.node_count(), 1);
    }

    #[test]
    fn test_common_subexpression_elimination() {
        let mut optimizer = GraphOptimizer::new();

        let x = ComputationNode::leaf(|| 2.0);
        let y = ComputationNode::leaf(|| 3.0);
        let sum = ComputationNode::binary_op(x, y, BinaryOperation::Add);

        // `sum` appears twice: (x + y) * (x + y).
        let expr = ComputationNode::binary_op(sum.clone(), sum, BinaryOperation::Mul);

        // Optimizing the same expression twice must yield equal results.
        let optimized1 = optimizer.eliminate_common_subexpressions(expr.clone());
        let optimized2 = optimizer.eliminate_common_subexpressions(expr);

        let result1: f64 = optimized1.evaluate_fresh();
        let result2: f64 = optimized2.evaluate_fresh();
        assert!((result1 - result2).abs() < f64::EPSILON);

        assert!(!optimizer.subexpression_cache.is_empty());
    }

    #[test]
    #[allow(clippy::similar_names)]
    fn test_common_subexpression_elimination_complex() {
        let mut optimizer = GraphOptimizer::new();

        let a = ComputationNode::leaf(|| 1.0);
        let b = ComputationNode::leaf(|| 2.0);
        let c = ComputationNode::leaf(|| 3.0);

        // `sum_ab` is shared by both products below.
        let sum_ab = ComputationNode::binary_op(a, b, BinaryOperation::Add);

        let expr1 =
            ComputationNode::binary_op(sum_ab.clone(), sum_ab.clone(), BinaryOperation::Mul);
        let expr2 = ComputationNode::binary_op(sum_ab, c, BinaryOperation::Mul);
        let final_expr = ComputationNode::binary_op(expr1, expr2, BinaryOperation::Add);

        let optimized = optimizer.eliminate_common_subexpressions(final_expr);

        let result: f64 = optimized.evaluate_fresh();
        let expected = (1.0 + 2.0) * (1.0 + 2.0) + (1.0 + 2.0) * 3.0;
        assert!((result - expected).abs() < f64::EPSILON);

        assert!(!optimizer.subexpression_cache.is_empty());
    }

    #[test]
    fn test_identity_operation_elimination() {
        let x = ComputationNode::leaf(|| 5.0);

        // x + 0 == x
        let zero = ComputationNode::leaf(|| 0.0);
        let add_zero = ComputationNode::binary_op(x.clone(), zero, BinaryOperation::Add);

        let optimized = GraphOptimizer::eliminate_identity_operations(add_zero);
        let result: f64 = optimized.evaluate_fresh();
        assert!((result - 5.0).abs() < f64::EPSILON);

        // x * 1 == x
        let one = ComputationNode::leaf(|| 1.0);
        let mul_one = ComputationNode::binary_op(x.clone(), one, BinaryOperation::Mul);

        let optimized = GraphOptimizer::eliminate_identity_operations(mul_one);
        let result: f64 = optimized.evaluate_fresh();
        assert!((result - 5.0).abs() < f64::EPSILON);

        // x - 0 == x
        let zero2 = ComputationNode::leaf(|| 0.0);
        let sub_zero = ComputationNode::binary_op(x.clone(), zero2, BinaryOperation::Sub);

        let optimized = GraphOptimizer::eliminate_identity_operations(sub_zero);
        let result: f64 = optimized.evaluate_fresh();
        assert!((result - 5.0).abs() < f64::EPSILON);

        // x / 1 == x
        let one2 = ComputationNode::leaf(|| 1.0);
        let div_one = ComputationNode::binary_op(x.clone(), one2, BinaryOperation::Div);

        let optimized = GraphOptimizer::eliminate_identity_operations(div_one);
        let result: f64 = optimized.evaluate_fresh();
        assert!((result - 5.0).abs() < f64::EPSILON);

        // x * 0 == 0
        let zero3 = ComputationNode::leaf(|| 0.0);
        let mul_zero = ComputationNode::binary_op(x, zero3, BinaryOperation::Mul);

        let optimized = GraphOptimizer::eliminate_identity_operations(mul_zero);
        let result: f64 = optimized.evaluate_fresh();
        assert!((result - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_constant_folding() {
        // 2 + 3
        let two = ComputationNode::leaf(|| 2.0);
        let three = ComputationNode::leaf(|| 3.0);
        let add_const = ComputationNode::binary_op(two, three, BinaryOperation::Add);

        let optimized = GraphOptimizer::constant_folding(add_const);
        let result: f64 = optimized.evaluate_fresh();
        assert!((result - 5.0).abs() < f64::EPSILON);

        // 4 * 5
        let four = ComputationNode::leaf(|| 4.0);
        let five = ComputationNode::leaf(|| 5.0);
        let mul_const = ComputationNode::binary_op(four, five, BinaryOperation::Mul);

        let optimized = GraphOptimizer::constant_folding(mul_const);
        let result: f64 = optimized.evaluate_fresh();
        assert!((result - 20.0).abs() < f64::EPSILON);

        // 10 / 2
        let ten = ComputationNode::leaf(|| 10.0);
        let two_div = ComputationNode::leaf(|| 2.0);
        let div_const = ComputationNode::binary_op(ten, two_div, BinaryOperation::Div);

        let optimized = GraphOptimizer::constant_folding(div_const);
        let result: f64 = optimized.evaluate_fresh();
        assert!((result - 5.0).abs() < f64::EPSILON);

        // 8 - 3
        let eight = ComputationNode::leaf(|| 8.0);
        let three_sub = ComputationNode::leaf(|| 3.0);
        let sub_const = ComputationNode::binary_op(eight, three_sub, BinaryOperation::Sub);

        let optimized = GraphOptimizer::constant_folding(sub_const);
        let result: f64 = optimized.evaluate_fresh();
        assert!((result - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_constant_folding_conditional() {
        // A constant-true condition folds to the `if_true` branch.
        let true_condition = ComputationNode::leaf(|| true);
        let if_true = ComputationNode::leaf(|| 10.0);
        let if_false = ComputationNode::leaf(|| 20.0);
        let conditional = ComputationNode::conditional(true_condition, if_true, if_false);

        let optimized = GraphOptimizer::constant_folding(conditional);
        let result: f64 = optimized.evaluate_fresh();
        assert!((result - 10.0).abs() < f64::EPSILON);

        // A constant-false condition folds to the `if_false` branch.
        let false_condition = ComputationNode::leaf(|| false);
        let if_true2 = ComputationNode::leaf(|| 10.0);
        let if_false2 = ComputationNode::leaf(|| 20.0);
        let conditional2 = ComputationNode::conditional(false_condition, if_true2, if_false2);

        let optimized = GraphOptimizer::constant_folding(conditional2);
        let result: f64 = optimized.evaluate_fresh();
        assert!((result - 20.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_constant_folding_unary() {
        // map(5, x * 2)
        let five = ComputationNode::leaf(|| 5.0);
        let double = ComputationNode::map(five, |x| x * 2.0);

        let optimized = GraphOptimizer::constant_folding(double);
        let result: f64 = optimized.evaluate_fresh();
        assert!((result - 10.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_graph_visualizer_print_tree() {
        let left = ComputationNode::leaf(|| 1.0);
        let right = ComputationNode::leaf(|| 2.0);
        let add_node = ComputationNode::binary_op(left, right, BinaryOperation::Add);

        // Smoke test: printing the tree must not panic.
        GraphVisualizer::print_tree(&add_node, 0);
    }

    #[test]
    fn test_graph_visualizer_dot_conditional() {
        let condition = ComputationNode::leaf(|| true);
        let if_true = ComputationNode::leaf(|| 10.0);
        let if_false = ComputationNode::leaf(|| 20.0);
        let conditional = ComputationNode::conditional(condition, if_true, if_false);

        let dot = GraphVisualizer::to_dot(&conditional);

        assert!(dot.contains("digraph G"));
        assert!(dot.contains("If"));
        assert!(dot.contains("diamond"));
        assert!(dot.contains("cond"));
        assert!(dot.contains("true"));
        assert!(dot.contains("false"));
    }

    #[test]
    fn test_graph_visualizer_dot_unary_op() {
        let operand = ComputationNode::leaf(|| 5.0);
        let unary = ComputationNode::map(operand, |x| x * 2.0);

        let dot = GraphVisualizer::to_dot(&unary);

        assert!(dot.contains("digraph G"));
        assert!(dot.contains("UnaryOp"));
        assert!(dot.contains("Leaf"));
    }

    #[test]
    fn test_profiler_default() {
        let profiler = GraphProfiler::default();
        assert!(profiler.get_stats("nonexistent").is_none());
    }

    #[test]
    fn test_profiler_get_stats_nonexistent() {
        let profiler = GraphProfiler::new();
        assert!(profiler.get_stats("nonexistent").is_none());
    }

    #[test]
    fn test_profiler_multiple_executions() {
        let mut profiler = GraphProfiler::new();

        profiler.profile_execution("test", || {
            std::thread::sleep(std::time::Duration::from_millis(1));
        });
        profiler.profile_execution("test", || {
            std::thread::sleep(std::time::Duration::from_millis(2));
        });
        profiler.profile_execution("test", || {
            std::thread::sleep(std::time::Duration::from_millis(3));
        });

        let stats = profiler.get_stats("test").unwrap();
        assert_eq!(stats.count, 3);
        assert!(stats.min <= stats.median);
        assert!(stats.median <= stats.max);
        assert!(stats.average.as_nanos() > 0);

        // Smoke test: the report must print without panicking.
        profiler.print_report();
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn test_conditional_evaluation_false_branch() {
        let condition = ComputationNode::leaf(|| false);
        let if_true = ComputationNode::leaf(|| 10.0);
        let if_false = ComputationNode::leaf(|| 20.0);
        let conditional = ComputationNode::conditional(condition, if_true, if_false);

        let result = conditional.evaluate_fresh();
        assert_eq!(result, 20.0);
    }

    #[test]
    fn test_bool_conditional_evaluation() {
        let condition = ComputationNode::leaf(|| true);
        let if_true = ComputationNode::leaf(|| true);
        let if_false = ComputationNode::leaf(|| false);
        let conditional = ComputationNode::conditional(condition, if_true, if_false);

        let mut context = SampleContext::new();
        let result = conditional.evaluate_bool(&mut context);
        assert!(result);
    }

    #[test]
    fn test_bool_unary_operation() {
        let operand = ComputationNode::leaf(|| true);
        let mapped = ComputationNode::map(operand, |x| !x);

        let mut context = SampleContext::new();
        let result = mapped.evaluate_bool(&mut context);
        assert!(!result);
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn test_binary_operations_subtraction() {
        let left = ComputationNode::leaf(|| 10.0);
        let right = ComputationNode::leaf(|| 3.0);
        let sub_node = ComputationNode::binary_op(left, right, BinaryOperation::Sub);

        let result = sub_node.evaluate_fresh();
        assert_eq!(result, 7.0);
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn test_binary_operations_multiplication() {
        let left = ComputationNode::leaf(|| 4.0);
        let right = ComputationNode::leaf(|| 5.0);
        let mul_node = ComputationNode::binary_op(left, right, BinaryOperation::Mul);

        let result = mul_node.evaluate_fresh();
        assert_eq!(result, 20.0);
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn test_binary_operations_division() {
        let left = ComputationNode::leaf(|| 15.0);
        let right = ComputationNode::leaf(|| 3.0);
        let div_node = ComputationNode::binary_op(left, right, BinaryOperation::Div);

        let result = div_node.evaluate_fresh();
        assert_eq!(result, 5.0);
    }

    #[test]
    fn test_nested_conditional_depth() {
        let condition1 = ComputationNode::leaf(|| true);
        let condition2 = ComputationNode::leaf(|| false);
        let leaf1 = ComputationNode::leaf(|| 1.0);
        let leaf3 = ComputationNode::leaf(|| 3.0);
        let leaf4 = ComputationNode::leaf(|| 4.0);

        // outer(condition1, leaf1, inner(condition2, leaf3, leaf4)):
        // 7 nodes in total, 3 levels deep.
        let inner_conditional = ComputationNode::conditional(condition2, leaf3, leaf4);
        let outer_conditional = ComputationNode::conditional(condition1, leaf1, inner_conditional);

        assert_eq!(outer_conditional.depth(), 3);
        assert_eq!(outer_conditional.node_count(), 7);
        assert!(outer_conditional.has_conditionals());
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn test_evaluate_conditional_with_arithmetic() {
        let condition = ComputationNode::leaf(|| true);
        let if_true = ComputationNode::leaf(|| 42.0);
        let if_false = ComputationNode::leaf(|| 24.0);
        let conditional = ComputationNode::conditional(condition, if_true, if_false);

        let mut context = SampleContext::new();
        let result = conditional.evaluate_conditional_with_arithmetic(&mut context);
        assert_eq!(result, 42.0);

        // Non-conditional nodes are accepted as well.
        let leaf = ComputationNode::leaf(|| 99.0);
        let result = leaf.evaluate_conditional_with_arithmetic(&mut context);
        assert_eq!(result, 99.0);
    }

    #[test]
    fn test_sample_context_different_types() {
        let mut context = SampleContext::new();
        let id1 = uuid::Uuid::new_v4();
        let id2 = uuid::Uuid::new_v4();

        context.set_value(id1, 42.0_f64);
        context.set_value(id2, 100_i32);

        assert_eq!(context.get_value::<f64>(&id1), Some(42.0));
        assert_eq!(context.get_value::<i32>(&id2), Some(100));

        // A value is only returned when requested with the type it was
        // stored as.
        assert_eq!(context.get_value::<f64>(&id2), None);
        assert_eq!(context.get_value::<i32>(&id1), None);
        assert_eq!(context.len(), 2);
    }

    #[test]
    fn test_profile_stats_debug() {
        let stats = ProfileStats {
            count: 5,
            total: std::time::Duration::from_millis(100),
            average: std::time::Duration::from_millis(20),
            median: std::time::Duration::from_millis(18),
            min: std::time::Duration::from_millis(15),
            max: std::time::Duration::from_millis(30),
        };

        let debug_str = format!("{stats:?}");
        assert!(debug_str.contains("ProfileStats"));

        let cloned = stats.clone();
        assert_eq!(cloned.count, stats.count);
    }
}