1use tensorlogic_ir::TLExpr;
35
/// Counters collected during one [`ConstantPropagator::run`] invocation.
///
/// `PartialEq`/`Eq` are derived so that two snapshots can be compared directly
/// (for example with `assert_eq!` in tests); every field is a plain integer.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ConstPropStats {
    /// Number of binary arithmetic expressions replaced by a constant.
    pub arithmetic_folds: u64,
    /// Number of comparison expressions replaced by `1.0` / `0.0`.
    pub comparison_folds: u64,
    /// Number of folds recorded by the boolean connectives, `IfThenElse`
    /// selection and the unary math functions.
    pub boolean_folds: u64,
    /// Node count of the input expression, measured before the first pass.
    pub nodes_before: u64,
    /// Node count of the result, measured after the last pass.
    pub nodes_after: u64,
    /// Number of passes that were executed.
    pub passes: u32,
}
59
60impl ConstPropStats {
61 pub fn total_folds(&self) -> u64 {
63 self.arithmetic_folds
64 .saturating_add(self.comparison_folds)
65 .saturating_add(self.boolean_folds)
66 }
67
68 pub fn reduction_pct(&self) -> f64 {
72 if self.nodes_before == 0 {
73 return 0.0;
74 }
75 let before = self.nodes_before as f64;
76 let after = self.nodes_after as f64;
77 (((before - after) / before) * 100.0).max(0.0)
78 }
79
80 pub fn summary(&self) -> String {
82 format!(
83 "ConstProp: {} passes, {}/{} nodes kept ({:.1}% reduction) — \
84 {} arith folds, {} cmp folds, {} bool folds",
85 self.passes,
86 self.nodes_after,
87 self.nodes_before,
88 self.reduction_pct(),
89 self.arithmetic_folds,
90 self.comparison_folds,
91 self.boolean_folds,
92 )
93 }
94}
95
/// Tuning knobs for [`ConstantPropagator`].
#[derive(Debug, Clone)]
pub struct ConstPropConfig {
    /// Enables folding of binary arithmetic nodes (`Add`, `Sub`, `Mul`,
    /// `Div`, `Pow`, `Mod`, `Min`, `Max`) whose operands are both constants.
    pub fold_arithmetic: bool,
    /// Enables folding of comparison nodes (`Eq`, `Lt`, `Gt`, `Lte`, `Gte`)
    /// whose operands are both constants.
    pub fold_comparisons: bool,
    /// Enables folding of `Not`, `And`, `Or`, `Imply` and `IfThenElse` on
    /// constant operands. The unary math functions (`Abs`, `Sqrt`, ...) are
    /// also gated by this flag in the current implementation.
    pub fold_boolean: bool,
    /// Upper bound on the number of passes `run` performs; `0` disables the
    /// pass entirely.
    pub max_passes: u32,
    /// Absolute tolerance applied when folding `Eq`, `Lte` and `Gte`.
    pub float_tolerance: f64,
}
114
115impl Default for ConstPropConfig {
116 fn default() -> Self {
117 Self {
118 fold_arithmetic: true,
119 fold_comparisons: true,
120 fold_boolean: true,
121 max_passes: 20,
122 float_tolerance: 1e-12,
123 }
124 }
125}
126
127pub struct ConstantPropagator {
138 config: ConstPropConfig,
139}
140
141impl ConstantPropagator {
142 pub fn new(config: ConstPropConfig) -> Self {
144 Self { config }
145 }
146
147 pub fn with_default() -> Self {
149 Self::new(ConstPropConfig::default())
150 }
151
152 pub fn run(&self, expr: TLExpr) -> (TLExpr, ConstPropStats) {
156 let mut stats = ConstPropStats {
157 nodes_before: Self::count_nodes(&expr),
158 ..Default::default()
159 };
160
161 let mut current = expr;
162 let mut pass_count = 0u32;
163
164 loop {
165 if pass_count >= self.config.max_passes {
166 break;
167 }
168 let (next, changed) = self.run_pass(current, &mut stats);
169 pass_count = pass_count.saturating_add(1);
170 current = next;
171 if !changed {
172 break;
173 }
174 }
175
176 stats.passes = pass_count;
177 stats.nodes_after = Self::count_nodes(¤t);
178 (current, stats)
179 }
180
181 fn run_pass(&self, expr: TLExpr, stats: &mut ConstPropStats) -> (TLExpr, bool) {
186 self.propagate(expr, stats)
187 }
188
189 fn propagate(&self, expr: TLExpr, stats: &mut ConstPropStats) -> (TLExpr, bool) {
194 match expr {
195 TLExpr::Constant(_)
197 | TLExpr::Pred { .. }
198 | TLExpr::EmptySet
199 | TLExpr::AllDifferent { .. }
200 | TLExpr::Nominal { .. }
201 | TLExpr::Abducible { .. } => (expr, false),
202
203 TLExpr::Add(lhs, rhs) => self.fold_binary_arith("Add", *lhs, *rhs, stats, TLExpr::Add),
205 TLExpr::Sub(lhs, rhs) => self.fold_binary_arith("Sub", *lhs, *rhs, stats, TLExpr::Sub),
206 TLExpr::Mul(lhs, rhs) => self.fold_binary_arith("Mul", *lhs, *rhs, stats, TLExpr::Mul),
207 TLExpr::Div(lhs, rhs) => self.fold_binary_arith("Div", *lhs, *rhs, stats, TLExpr::Div),
208 TLExpr::Pow(lhs, rhs) => self.fold_binary_arith("Pow", *lhs, *rhs, stats, TLExpr::Pow),
209 TLExpr::Mod(lhs, rhs) => self.fold_binary_arith("Mod", *lhs, *rhs, stats, TLExpr::Mod),
210 TLExpr::Min(lhs, rhs) => self.fold_binary_arith("Min", *lhs, *rhs, stats, TLExpr::Min),
211 TLExpr::Max(lhs, rhs) => self.fold_binary_arith("Max", *lhs, *rhs, stats, TLExpr::Max),
212
213 TLExpr::Eq(lhs, rhs) => self.fold_binary_cmp("Eq", *lhs, *rhs, stats, TLExpr::Eq),
215 TLExpr::Lt(lhs, rhs) => self.fold_binary_cmp("Lt", *lhs, *rhs, stats, TLExpr::Lt),
216 TLExpr::Gt(lhs, rhs) => self.fold_binary_cmp("Gt", *lhs, *rhs, stats, TLExpr::Gt),
217 TLExpr::Lte(lhs, rhs) => self.fold_binary_cmp("Lte", *lhs, *rhs, stats, TLExpr::Lte),
218 TLExpr::Gte(lhs, rhs) => self.fold_binary_cmp("Gte", *lhs, *rhs, stats, TLExpr::Gte),
219
220 TLExpr::Abs(inner) => self.fold_unary_math("Abs", *inner, stats, TLExpr::Abs),
222 TLExpr::Floor(inner) => self.fold_unary_math("Floor", *inner, stats, TLExpr::Floor),
223 TLExpr::Ceil(inner) => self.fold_unary_math("Ceil", *inner, stats, TLExpr::Ceil),
224 TLExpr::Round(inner) => self.fold_unary_math("Round", *inner, stats, TLExpr::Round),
225 TLExpr::Sqrt(inner) => self.fold_unary_math("Sqrt", *inner, stats, TLExpr::Sqrt),
226 TLExpr::Exp(inner) => self.fold_unary_math("Exp", *inner, stats, TLExpr::Exp),
227 TLExpr::Log(inner) => self.fold_unary_math("Log", *inner, stats, TLExpr::Log),
228 TLExpr::Sin(inner) => self.fold_unary_math("Sin", *inner, stats, TLExpr::Sin),
229 TLExpr::Cos(inner) => self.fold_unary_math("Cos", *inner, stats, TLExpr::Cos),
230 TLExpr::Tan(inner) => self.fold_unary_math("Tan", *inner, stats, TLExpr::Tan),
231
232 TLExpr::Not(inner) => {
234 let (new_inner, child_changed) = self.propagate(*inner, stats);
235 if self.config.fold_boolean {
236 if let Some(v) = Self::as_constant(&new_inner) {
237 let result = TLExpr::Constant(1.0 - v);
240 stats.boolean_folds = stats.boolean_folds.saturating_add(1);
241 return (result, true);
242 }
243 }
244 (TLExpr::Not(Box::new(new_inner)), child_changed)
245 }
246
247 TLExpr::And(lhs, rhs) => {
249 let (new_lhs, cl) = self.propagate(*lhs, stats);
250 let (new_rhs, cr) = self.propagate(*rhs, stats);
251 if self.config.fold_boolean {
252 if let (Some(a), Some(b)) =
253 (Self::as_constant(&new_lhs), Self::as_constant(&new_rhs))
254 {
255 let result = if a != 0.0 && b != 0.0 { 1.0 } else { 0.0 };
257 stats.boolean_folds = stats.boolean_folds.saturating_add(1);
258 return (TLExpr::Constant(result), true);
259 }
260 }
261 (TLExpr::And(Box::new(new_lhs), Box::new(new_rhs)), cl || cr)
262 }
263 TLExpr::Or(lhs, rhs) => {
264 let (new_lhs, cl) = self.propagate(*lhs, stats);
265 let (new_rhs, cr) = self.propagate(*rhs, stats);
266 if self.config.fold_boolean {
267 if let (Some(a), Some(b)) =
268 (Self::as_constant(&new_lhs), Self::as_constant(&new_rhs))
269 {
270 let result = if a != 0.0 || b != 0.0 { 1.0 } else { 0.0 };
271 stats.boolean_folds = stats.boolean_folds.saturating_add(1);
272 return (TLExpr::Constant(result), true);
273 }
274 }
275 (TLExpr::Or(Box::new(new_lhs), Box::new(new_rhs)), cl || cr)
276 }
277 TLExpr::Imply(premise, conclusion) => {
278 let (new_p, cp) = self.propagate(*premise, stats);
279 let (new_c, cc) = self.propagate(*conclusion, stats);
280 if self.config.fold_boolean {
281 if let (Some(a), Some(b)) =
282 (Self::as_constant(&new_p), Self::as_constant(&new_c))
283 {
284 let result = if a == 0.0 || b != 0.0 { 1.0 } else { 0.0 };
286 stats.boolean_folds = stats.boolean_folds.saturating_add(1);
287 return (TLExpr::Constant(result), true);
288 }
289 }
290 (TLExpr::Imply(Box::new(new_p), Box::new(new_c)), cp || cc)
291 }
292
293 TLExpr::IfThenElse {
295 condition,
296 then_branch,
297 else_branch,
298 } => {
299 let (new_cond, cc) = self.propagate(*condition, stats);
300 let (new_then, ct) = self.propagate(*then_branch, stats);
301 let (new_else, ce) = self.propagate(*else_branch, stats);
302 if self.config.fold_boolean {
303 if let Some(v) = Self::as_constant(&new_cond) {
304 if v != 0.0 {
305 stats.boolean_folds = stats.boolean_folds.saturating_add(1);
307 return (new_then, true);
308 } else {
309 stats.boolean_folds = stats.boolean_folds.saturating_add(1);
311 return (new_else, true);
312 }
313 }
314 }
315 let changed = cc || ct || ce;
316 (
317 TLExpr::IfThenElse {
318 condition: Box::new(new_cond),
319 then_branch: Box::new(new_then),
320 else_branch: Box::new(new_else),
321 },
322 changed,
323 )
324 }
325
326 TLExpr::Score(inner) => {
328 let (new_inner, changed) = self.propagate(*inner, stats);
329 (TLExpr::Score(Box::new(new_inner)), changed)
330 }
331
332 TLExpr::Exists { var, domain, body } => {
334 let (new_body, changed) = self.propagate(*body, stats);
335 (
336 TLExpr::Exists {
337 var,
338 domain,
339 body: Box::new(new_body),
340 },
341 changed,
342 )
343 }
344 TLExpr::ForAll { var, domain, body } => {
345 let (new_body, changed) = self.propagate(*body, stats);
346 (
347 TLExpr::ForAll {
348 var,
349 domain,
350 body: Box::new(new_body),
351 },
352 changed,
353 )
354 }
355
356 TLExpr::Let { var, value, body } => {
358 let (new_value, cv) = self.propagate(*value, stats);
359 let (new_body, cb) = self.propagate(*body, stats);
360 (
361 TLExpr::Let {
362 var,
363 value: Box::new(new_value),
364 body: Box::new(new_body),
365 },
366 cv || cb,
367 )
368 }
369
370 TLExpr::Aggregate {
372 op,
373 var,
374 domain,
375 body,
376 group_by,
377 } => {
378 let (new_body, changed) = self.propagate(*body, stats);
379 (
380 TLExpr::Aggregate {
381 op,
382 var,
383 domain,
384 body: Box::new(new_body),
385 group_by,
386 },
387 changed,
388 )
389 }
390
391 TLExpr::Box(inner) => {
393 let (n, c) = self.propagate(*inner, stats);
394 (TLExpr::Box(Box::new(n)), c)
395 }
396 TLExpr::Diamond(inner) => {
397 let (n, c) = self.propagate(*inner, stats);
398 (TLExpr::Diamond(Box::new(n)), c)
399 }
400 TLExpr::Next(inner) => {
401 let (n, c) = self.propagate(*inner, stats);
402 (TLExpr::Next(Box::new(n)), c)
403 }
404 TLExpr::Eventually(inner) => {
405 let (n, c) = self.propagate(*inner, stats);
406 (TLExpr::Eventually(Box::new(n)), c)
407 }
408 TLExpr::Always(inner) => {
409 let (n, c) = self.propagate(*inner, stats);
410 (TLExpr::Always(Box::new(n)), c)
411 }
412 TLExpr::Until { before, after } => {
413 let (nb, cb) = self.propagate(*before, stats);
414 let (na, ca) = self.propagate(*after, stats);
415 (
416 TLExpr::Until {
417 before: Box::new(nb),
418 after: Box::new(na),
419 },
420 cb || ca,
421 )
422 }
423 TLExpr::Release { released, releaser } => {
424 let (nr, cr) = self.propagate(*released, stats);
425 let (nl, cl) = self.propagate(*releaser, stats);
426 (
427 TLExpr::Release {
428 released: Box::new(nr),
429 releaser: Box::new(nl),
430 },
431 cr || cl,
432 )
433 }
434 TLExpr::WeakUntil { before, after } => {
435 let (nb, cb) = self.propagate(*before, stats);
436 let (na, ca) = self.propagate(*after, stats);
437 (
438 TLExpr::WeakUntil {
439 before: Box::new(nb),
440 after: Box::new(na),
441 },
442 cb || ca,
443 )
444 }
445 TLExpr::StrongRelease { released, releaser } => {
446 let (nr, cr) = self.propagate(*released, stats);
447 let (nl, cl) = self.propagate(*releaser, stats);
448 (
449 TLExpr::StrongRelease {
450 released: Box::new(nr),
451 releaser: Box::new(nl),
452 },
453 cr || cl,
454 )
455 }
456
457 TLExpr::TNorm { kind, left, right } => {
458 let (nl, cl) = self.propagate(*left, stats);
459 let (nr, cr) = self.propagate(*right, stats);
460 (
461 TLExpr::TNorm {
462 kind,
463 left: Box::new(nl),
464 right: Box::new(nr),
465 },
466 cl || cr,
467 )
468 }
469 TLExpr::TCoNorm { kind, left, right } => {
470 let (nl, cl) = self.propagate(*left, stats);
471 let (nr, cr) = self.propagate(*right, stats);
472 (
473 TLExpr::TCoNorm {
474 kind,
475 left: Box::new(nl),
476 right: Box::new(nr),
477 },
478 cl || cr,
479 )
480 }
481 TLExpr::FuzzyNot { kind, expr: inner } => {
482 let (n, c) = self.propagate(*inner, stats);
483 (
484 TLExpr::FuzzyNot {
485 kind,
486 expr: Box::new(n),
487 },
488 c,
489 )
490 }
491 TLExpr::FuzzyImplication {
492 kind,
493 premise,
494 conclusion,
495 } => {
496 let (np, cp) = self.propagate(*premise, stats);
497 let (nc, cc) = self.propagate(*conclusion, stats);
498 (
499 TLExpr::FuzzyImplication {
500 kind,
501 premise: Box::new(np),
502 conclusion: Box::new(nc),
503 },
504 cp || cc,
505 )
506 }
507
508 TLExpr::SoftExists {
510 var,
511 domain,
512 body,
513 temperature,
514 } => {
515 let (nb, changed) = self.propagate(*body, stats);
516 (
517 TLExpr::SoftExists {
518 var,
519 domain,
520 body: Box::new(nb),
521 temperature,
522 },
523 changed,
524 )
525 }
526 TLExpr::SoftForAll {
527 var,
528 domain,
529 body,
530 temperature,
531 } => {
532 let (nb, changed) = self.propagate(*body, stats);
533 (
534 TLExpr::SoftForAll {
535 var,
536 domain,
537 body: Box::new(nb),
538 temperature,
539 },
540 changed,
541 )
542 }
543 TLExpr::WeightedRule { weight, rule } => {
544 let (nr, changed) = self.propagate(*rule, stats);
545 (
546 TLExpr::WeightedRule {
547 weight,
548 rule: Box::new(nr),
549 },
550 changed,
551 )
552 }
553 TLExpr::ProbabilisticChoice { alternatives } => {
554 let mut changed = false;
555 let new_alts: Vec<(f64, TLExpr)> = alternatives
556 .into_iter()
557 .map(|(p, e)| {
558 let (ne, c) = self.propagate(e, stats);
559 if c {
560 changed = true;
561 }
562 (p, ne)
563 })
564 .collect();
565 (
566 TLExpr::ProbabilisticChoice {
567 alternatives: new_alts,
568 },
569 changed,
570 )
571 }
572
573 TLExpr::Lambda {
575 var,
576 var_type,
577 body,
578 } => {
579 let (nb, changed) = self.propagate(*body, stats);
580 (
581 TLExpr::Lambda {
582 var,
583 var_type,
584 body: Box::new(nb),
585 },
586 changed,
587 )
588 }
589 TLExpr::Apply { function, argument } => {
590 let (nf, cf) = self.propagate(*function, stats);
591 let (na, ca) = self.propagate(*argument, stats);
592 (
593 TLExpr::Apply {
594 function: Box::new(nf),
595 argument: Box::new(na),
596 },
597 cf || ca,
598 )
599 }
600
601 TLExpr::SetMembership { element, set } => {
603 let (ne, ce) = self.propagate(*element, stats);
604 let (ns, cs) = self.propagate(*set, stats);
605 (
606 TLExpr::SetMembership {
607 element: Box::new(ne),
608 set: Box::new(ns),
609 },
610 ce || cs,
611 )
612 }
613 TLExpr::SetUnion { left, right } => {
614 let (nl, cl) = self.propagate(*left, stats);
615 let (nr, cr) = self.propagate(*right, stats);
616 (
617 TLExpr::SetUnion {
618 left: Box::new(nl),
619 right: Box::new(nr),
620 },
621 cl || cr,
622 )
623 }
624 TLExpr::SetIntersection { left, right } => {
625 let (nl, cl) = self.propagate(*left, stats);
626 let (nr, cr) = self.propagate(*right, stats);
627 (
628 TLExpr::SetIntersection {
629 left: Box::new(nl),
630 right: Box::new(nr),
631 },
632 cl || cr,
633 )
634 }
635 TLExpr::SetDifference { left, right } => {
636 let (nl, cl) = self.propagate(*left, stats);
637 let (nr, cr) = self.propagate(*right, stats);
638 (
639 TLExpr::SetDifference {
640 left: Box::new(nl),
641 right: Box::new(nr),
642 },
643 cl || cr,
644 )
645 }
646 TLExpr::SetCardinality { set } => {
647 let (ns, changed) = self.propagate(*set, stats);
648 (TLExpr::SetCardinality { set: Box::new(ns) }, changed)
649 }
650 TLExpr::SetComprehension {
651 var,
652 domain,
653 condition,
654 } => {
655 let (nc, changed) = self.propagate(*condition, stats);
656 (
657 TLExpr::SetComprehension {
658 var,
659 domain,
660 condition: Box::new(nc),
661 },
662 changed,
663 )
664 }
665
666 TLExpr::CountingExists {
668 var,
669 domain,
670 body,
671 min_count,
672 } => {
673 let (nb, changed) = self.propagate(*body, stats);
674 (
675 TLExpr::CountingExists {
676 var,
677 domain,
678 body: Box::new(nb),
679 min_count,
680 },
681 changed,
682 )
683 }
684 TLExpr::CountingForAll {
685 var,
686 domain,
687 body,
688 min_count,
689 } => {
690 let (nb, changed) = self.propagate(*body, stats);
691 (
692 TLExpr::CountingForAll {
693 var,
694 domain,
695 body: Box::new(nb),
696 min_count,
697 },
698 changed,
699 )
700 }
701 TLExpr::ExactCount {
702 var,
703 domain,
704 body,
705 count,
706 } => {
707 let (nb, changed) = self.propagate(*body, stats);
708 (
709 TLExpr::ExactCount {
710 var,
711 domain,
712 body: Box::new(nb),
713 count,
714 },
715 changed,
716 )
717 }
718 TLExpr::Majority { var, domain, body } => {
719 let (nb, changed) = self.propagate(*body, stats);
720 (
721 TLExpr::Majority {
722 var,
723 domain,
724 body: Box::new(nb),
725 },
726 changed,
727 )
728 }
729
730 TLExpr::LeastFixpoint { var, body } => {
732 let (nb, changed) = self.propagate(*body, stats);
733 (
734 TLExpr::LeastFixpoint {
735 var,
736 body: Box::new(nb),
737 },
738 changed,
739 )
740 }
741 TLExpr::GreatestFixpoint { var, body } => {
742 let (nb, changed) = self.propagate(*body, stats);
743 (
744 TLExpr::GreatestFixpoint {
745 var,
746 body: Box::new(nb),
747 },
748 changed,
749 )
750 }
751
752 TLExpr::At { nominal, formula } => {
754 let (nf, changed) = self.propagate(*formula, stats);
755 (
756 TLExpr::At {
757 nominal,
758 formula: Box::new(nf),
759 },
760 changed,
761 )
762 }
763 TLExpr::Somewhere { formula } => {
764 let (nf, changed) = self.propagate(*formula, stats);
765 (
766 TLExpr::Somewhere {
767 formula: Box::new(nf),
768 },
769 changed,
770 )
771 }
772 TLExpr::Everywhere { formula } => {
773 let (nf, changed) = self.propagate(*formula, stats);
774 (
775 TLExpr::Everywhere {
776 formula: Box::new(nf),
777 },
778 changed,
779 )
780 }
781
782 TLExpr::GlobalCardinality {
784 variables,
785 values,
786 min_occurrences,
787 max_occurrences,
788 } => {
789 let mut changed = false;
790 let new_values: Vec<TLExpr> = values
791 .into_iter()
792 .map(|e| {
793 let (ne, c) = self.propagate(e, stats);
794 if c {
795 changed = true;
796 }
797 ne
798 })
799 .collect();
800 (
801 TLExpr::GlobalCardinality {
802 variables,
803 values: new_values,
804 min_occurrences,
805 max_occurrences,
806 },
807 changed,
808 )
809 }
810
811 TLExpr::Explain { formula } => {
813 let (nf, changed) = self.propagate(*formula, stats);
814 (
815 TLExpr::Explain {
816 formula: Box::new(nf),
817 },
818 changed,
819 )
820 }
821
822 TLExpr::SymbolLiteral(_) => (expr, false),
824
825 TLExpr::Match { scrutinee, arms } => {
826 let (new_scrutinee, sc) = self.propagate(*scrutinee, stats);
827 let mut any_changed = sc;
828 let new_arms = arms
829 .into_iter()
830 .map(|(pat, body)| {
831 let (new_body, bc) = self.propagate(*body, stats);
832 if bc {
833 any_changed = true;
834 }
835 (pat, Box::new(new_body))
836 })
837 .collect();
838 (
839 TLExpr::Match {
840 scrutinee: Box::new(new_scrutinee),
841 arms: new_arms,
842 },
843 any_changed,
844 )
845 }
846 }
847 }
848
849 pub fn as_constant(expr: &TLExpr) -> Option<f64> {
853 if let TLExpr::Constant(v) = expr {
854 Some(*v)
855 } else {
856 None
857 }
858 }
859
860 fn fold_binary_arith(
864 &self,
865 op_name: &str,
866 lhs: TLExpr,
867 rhs: TLExpr,
868 stats: &mut ConstPropStats,
869 ctor: fn(Box<TLExpr>, Box<TLExpr>) -> TLExpr,
870 ) -> (TLExpr, bool) {
871 let (new_lhs, cl) = self.propagate(lhs, stats);
872 let (new_rhs, cr) = self.propagate(rhs, stats);
873 let child_changed = cl || cr;
874
875 if self.config.fold_arithmetic {
876 if let (Some(a), Some(b)) = (Self::as_constant(&new_lhs), Self::as_constant(&new_rhs)) {
877 if let Some(folded) = self.fold_arith_binary(op_name, a, b, stats) {
878 return (folded, true);
879 }
880 }
881 }
882 (ctor(Box::new(new_lhs), Box::new(new_rhs)), child_changed)
883 }
884
885 fn fold_binary_cmp(
887 &self,
888 op_name: &str,
889 lhs: TLExpr,
890 rhs: TLExpr,
891 stats: &mut ConstPropStats,
892 ctor: fn(Box<TLExpr>, Box<TLExpr>) -> TLExpr,
893 ) -> (TLExpr, bool) {
894 let (new_lhs, cl) = self.propagate(lhs, stats);
895 let (new_rhs, cr) = self.propagate(rhs, stats);
896 let child_changed = cl || cr;
897
898 if self.config.fold_comparisons {
899 if let (Some(a), Some(b)) = (Self::as_constant(&new_lhs), Self::as_constant(&new_rhs)) {
900 if let Some(folded) = self.fold_comparison(op_name, a, b, stats) {
901 return (folded, true);
902 }
903 }
904 }
905 (ctor(Box::new(new_lhs), Box::new(new_rhs)), child_changed)
906 }
907
908 fn fold_unary_math(
910 &self,
911 op_name: &str,
912 inner: TLExpr,
913 stats: &mut ConstPropStats,
914 ctor: fn(Box<TLExpr>) -> TLExpr,
915 ) -> (TLExpr, bool) {
916 let (new_inner, child_changed) = self.propagate(inner, stats);
917
918 if self.config.fold_boolean {
919 if let Some(v) = Self::as_constant(&new_inner) {
920 let maybe_result = Self::fold_unary_math_value(op_name, v);
921 if let Some(result) = maybe_result {
922 stats.boolean_folds = stats.boolean_folds.saturating_add(1);
923 return (TLExpr::Constant(result), true);
924 }
925 }
926 }
927 (ctor(Box::new(new_inner)), child_changed)
928 }
929
930 fn fold_unary_math_value(op_name: &str, v: f64) -> Option<f64> {
933 match op_name {
934 "Abs" => Some(v.abs()),
935 "Floor" => Some(v.floor()),
936 "Ceil" => Some(v.ceil()),
937 "Round" => Some(v.round()),
938 "Sqrt" => {
939 if v < 0.0 {
940 None
941 } else {
942 Some(v.sqrt())
943 }
944 }
945 "Exp" => Some(v.exp()),
946 "Log" => {
947 if v <= 0.0 {
948 None
949 } else {
950 Some(v.ln())
951 }
952 }
953 "Sin" => Some(v.sin()),
954 "Cos" => Some(v.cos()),
955 "Tan" => Some(v.tan()),
956 _ => None,
957 }
958 }
959
960 fn fold_arith_binary(
964 &self,
965 op_name: &str,
966 lhs: f64,
967 rhs: f64,
968 stats: &mut ConstPropStats,
969 ) -> Option<TLExpr> {
970 let result = match op_name {
971 "Add" => lhs + rhs,
972 "Sub" => lhs - rhs,
973 "Mul" => lhs * rhs,
974 "Div" => {
975 if rhs.abs() < f64::EPSILON {
976 return None; }
978 lhs / rhs
979 }
980 "Pow" => lhs.powf(rhs),
981 "Mod" => {
982 if rhs.abs() < f64::EPSILON {
983 return None; }
985 lhs % rhs
986 }
987 "Min" => lhs.min(rhs),
988 "Max" => lhs.max(rhs),
989 _ => return None,
990 };
991
992 if result.is_finite() || result.is_infinite() {
993 stats.arithmetic_folds = stats.arithmetic_folds.saturating_add(1);
996 Some(TLExpr::Constant(result))
997 } else {
998 None
1000 }
1001 }
1002
1003 fn fold_comparison(
1006 &self,
1007 op_name: &str,
1008 lhs: f64,
1009 rhs: f64,
1010 stats: &mut ConstPropStats,
1011 ) -> Option<TLExpr> {
1012 let bool_result: bool = match op_name {
1013 "Eq" => (lhs - rhs).abs() <= self.config.float_tolerance,
1014 "Lt" => lhs < rhs,
1015 "Gt" => lhs > rhs,
1016 "Lte" => lhs <= rhs || (lhs - rhs).abs() <= self.config.float_tolerance,
1017 "Gte" => lhs >= rhs || (lhs - rhs).abs() <= self.config.float_tolerance,
1018 _ => return None,
1019 };
1020 stats.comparison_folds = stats.comparison_folds.saturating_add(1);
1021 Some(TLExpr::Constant(if bool_result { 1.0 } else { 0.0 }))
1022 }
1023
1024 pub fn count_nodes(expr: &TLExpr) -> u64 {
1026 match expr {
1027 TLExpr::Constant(_)
1029 | TLExpr::EmptySet
1030 | TLExpr::AllDifferent { .. }
1031 | TLExpr::Nominal { .. }
1032 | TLExpr::Abducible { .. }
1033 | TLExpr::Pred { .. } => 1,
1034
1035 TLExpr::Not(e)
1037 | TLExpr::Score(e)
1038 | TLExpr::Abs(e)
1039 | TLExpr::Floor(e)
1040 | TLExpr::Ceil(e)
1041 | TLExpr::Round(e)
1042 | TLExpr::Sqrt(e)
1043 | TLExpr::Exp(e)
1044 | TLExpr::Log(e)
1045 | TLExpr::Sin(e)
1046 | TLExpr::Cos(e)
1047 | TLExpr::Tan(e)
1048 | TLExpr::Box(e)
1049 | TLExpr::Diamond(e)
1050 | TLExpr::Next(e)
1051 | TLExpr::Eventually(e)
1052 | TLExpr::Always(e)
1053 | TLExpr::FuzzyNot { expr: e, .. }
1054 | TLExpr::Somewhere { formula: e }
1055 | TLExpr::Everywhere { formula: e }
1056 | TLExpr::SetCardinality { set: e }
1057 | TLExpr::Explain { formula: e }
1058 | TLExpr::WeightedRule { rule: e, .. } => 1 + Self::count_nodes(e),
1059
1060 TLExpr::Exists { body: e, .. }
1062 | TLExpr::ForAll { body: e, .. }
1063 | TLExpr::SoftExists { body: e, .. }
1064 | TLExpr::SoftForAll { body: e, .. }
1065 | TLExpr::Aggregate { body: e, .. }
1066 | TLExpr::CountingExists { body: e, .. }
1067 | TLExpr::CountingForAll { body: e, .. }
1068 | TLExpr::ExactCount { body: e, .. }
1069 | TLExpr::Majority { body: e, .. }
1070 | TLExpr::LeastFixpoint { body: e, .. }
1071 | TLExpr::GreatestFixpoint { body: e, .. }
1072 | TLExpr::Lambda { body: e, .. }
1073 | TLExpr::SetComprehension { condition: e, .. }
1074 | TLExpr::At { formula: e, .. } => 1 + Self::count_nodes(e),
1075
1076 TLExpr::And(l, r)
1078 | TLExpr::Or(l, r)
1079 | TLExpr::Imply(l, r)
1080 | TLExpr::Add(l, r)
1081 | TLExpr::Sub(l, r)
1082 | TLExpr::Mul(l, r)
1083 | TLExpr::Div(l, r)
1084 | TLExpr::Pow(l, r)
1085 | TLExpr::Mod(l, r)
1086 | TLExpr::Min(l, r)
1087 | TLExpr::Max(l, r)
1088 | TLExpr::Eq(l, r)
1089 | TLExpr::Lt(l, r)
1090 | TLExpr::Gt(l, r)
1091 | TLExpr::Lte(l, r)
1092 | TLExpr::Gte(l, r)
1093 | TLExpr::Until {
1094 before: l,
1095 after: r,
1096 }
1097 | TLExpr::Release {
1098 released: l,
1099 releaser: r,
1100 }
1101 | TLExpr::WeakUntil {
1102 before: l,
1103 after: r,
1104 }
1105 | TLExpr::StrongRelease {
1106 released: l,
1107 releaser: r,
1108 }
1109 | TLExpr::SetMembership { element: l, set: r }
1110 | TLExpr::SetUnion { left: l, right: r }
1111 | TLExpr::SetIntersection { left: l, right: r }
1112 | TLExpr::SetDifference { left: l, right: r }
1113 | TLExpr::Apply {
1114 function: l,
1115 argument: r,
1116 } => 1 + Self::count_nodes(l) + Self::count_nodes(r),
1117
1118 TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
1119 1 + Self::count_nodes(left) + Self::count_nodes(right)
1120 }
1121 TLExpr::FuzzyImplication {
1122 premise,
1123 conclusion,
1124 ..
1125 } => 1 + Self::count_nodes(premise) + Self::count_nodes(conclusion),
1126
1127 TLExpr::IfThenElse {
1128 condition,
1129 then_branch,
1130 else_branch,
1131 } => {
1132 1 + Self::count_nodes(condition)
1133 + Self::count_nodes(then_branch)
1134 + Self::count_nodes(else_branch)
1135 }
1136
1137 TLExpr::Let { value, body, .. } => {
1138 1 + Self::count_nodes(value) + Self::count_nodes(body)
1139 }
1140
1141 TLExpr::ProbabilisticChoice { alternatives } => {
1142 1 + alternatives
1143 .iter()
1144 .map(|(_, e)| Self::count_nodes(e))
1145 .sum::<u64>()
1146 }
1147 TLExpr::GlobalCardinality { values, .. } => {
1148 1 + values.iter().map(Self::count_nodes).sum::<u64>()
1149 }
1150
1151 TLExpr::SymbolLiteral(_) => 1,
1152
1153 TLExpr::Match { scrutinee, arms } => {
1154 1 + Self::count_nodes(scrutinee)
1155 + arms.iter().map(|(_, b)| Self::count_nodes(b)).sum::<u64>()
1156 }
1157 }
1158 }
1159}
1160
1161impl Default for ConstantPropagator {
1162 fn default() -> Self {
1163 Self::with_default()
1164 }
1165}
1166
#[cfg(test)]
mod tests {
    //! Unit tests for the constant-propagation pass.

    use super::*;
    use tensorlogic_ir::TLExpr;

    /// Propagator with the default configuration.
    fn propagator() -> ConstantPropagator {
        ConstantPropagator::with_default()
    }

    /// Asserts that `expr` is a `Constant` within `1e-9` of `expected`.
    fn assert_constant(expr: &TLExpr, expected: f64) {
        match expr {
            TLExpr::Constant(v) => {
                let diff = (v - expected).abs();
                assert!(diff < 1e-9, "Expected constant {}, got {}", expected, v);
            }
            other => panic!("Expected Constant({}), got {:?}", expected, other),
        }
    }

    #[test]
    fn test_constant_returns_itself() {
        let (result, stats) = propagator().run(TLExpr::Constant(3.0));
        assert_constant(&result, 3.0);
        assert_eq!(stats.total_folds(), 0);
    }

    #[test]
    fn test_add_two_constants() {
        let expr = TLExpr::add(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
        let (result, stats) = propagator().run(expr);
        assert_constant(&result, 5.0);
        assert!(stats.arithmetic_folds >= 1);
    }

    #[test]
    fn test_sub_two_constants() {
        let expr = TLExpr::sub(TLExpr::Constant(5.0), TLExpr::Constant(3.0));
        let (result, _) = propagator().run(expr);
        assert_constant(&result, 2.0);
    }

    #[test]
    fn test_mul_two_constants() {
        let expr = TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(4.0));
        let (result, _) = propagator().run(expr);
        assert_constant(&result, 8.0);
    }

    #[test]
    fn test_div_two_constants() {
        let expr = TLExpr::div(TLExpr::Constant(6.0), TLExpr::Constant(2.0));
        let (result, _) = propagator().run(expr);
        assert_constant(&result, 3.0);
    }

    #[test]
    fn test_div_by_zero_no_fold() {
        // Symbolic numerator: there is nothing to fold in the first place.
        let x = TLExpr::pred("x", vec![]);
        let expr = TLExpr::div(x, TLExpr::Constant(0.0));
        let (result, stats) = propagator().run(expr);
        assert!(!matches!(result, TLExpr::Constant(_)));
        assert_eq!(stats.arithmetic_folds, 0);

        // Constant numerator: both operands are constants, so this is the
        // case that actually reaches the zero-divisor guard.
        let expr = TLExpr::div(TLExpr::Constant(1.0), TLExpr::Constant(0.0));
        let (result, stats) = propagator().run(expr);
        assert!(!matches!(result, TLExpr::Constant(_)));
        assert_eq!(stats.arithmetic_folds, 0);
    }

    #[test]
    fn test_neg_via_sub_constant() {
        let expr = TLExpr::sub(TLExpr::Constant(0.0), TLExpr::Constant(3.0));
        let (result, _) = propagator().run(expr);
        assert_constant(&result, -3.0);
    }

    #[test]
    fn test_abs_constant() {
        let expr = TLExpr::abs(TLExpr::Constant(-5.0));
        let (result, stats) = propagator().run(expr);
        assert_constant(&result, 5.0);
        // Unary math folds are accounted for under `boolean_folds`.
        assert!(stats.boolean_folds >= 1);
    }

    #[test]
    fn test_nested_arithmetic() {
        // (2 * 3) + 4
        let expr = TLExpr::add(
            TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
            TLExpr::Constant(4.0),
        );
        let (result, stats) = propagator().run(expr);
        assert_constant(&result, 10.0);
        assert!(stats.arithmetic_folds >= 2);
    }

    #[test]
    fn test_comparison_lt_true() {
        let expr = TLExpr::Lt(
            Box::new(TLExpr::Constant(1.0)),
            Box::new(TLExpr::Constant(2.0)),
        );
        let (result, stats) = propagator().run(expr);
        assert_constant(&result, 1.0);
        assert!(stats.comparison_folds >= 1);
    }

    #[test]
    fn test_comparison_gt_false() {
        let expr = TLExpr::Gt(
            Box::new(TLExpr::Constant(1.0)),
            Box::new(TLExpr::Constant(2.0)),
        );
        let (result, stats) = propagator().run(expr);
        assert_constant(&result, 0.0);
        assert!(stats.comparison_folds >= 1);
    }

    #[test]
    fn test_const_prop_stats_counts() {
        let expr = TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(1.0));
        let (_, stats) = propagator().run(expr);
        assert!(stats.arithmetic_folds > 0, "Expected arithmetic_folds > 0");
    }

    #[test]
    fn test_const_prop_stats_summary() {
        let expr = TLExpr::add(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
        let (_, stats) = propagator().run(expr);
        let summary = stats.summary();
        assert!(!summary.is_empty(), "Expected non-empty summary");
        assert!(summary.contains("ConstProp"));
    }

    #[test]
    fn test_const_prop_config_default() {
        let config = ConstPropConfig::default();
        assert_eq!(config.max_passes, 20);
        assert!(config.fold_arithmetic);
        assert!(config.fold_comparisons);
        assert!(config.fold_boolean);
        assert!((config.float_tolerance - 1e-12).abs() < 1e-20);
    }

    #[test]
    fn test_disabled_fold() {
        let config = ConstPropConfig {
            fold_arithmetic: false,
            ..Default::default()
        };
        let prop = ConstantPropagator::new(config);
        let expr = TLExpr::add(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
        let (result, stats) = prop.run(expr);
        assert!(!matches!(result, TLExpr::Constant(_)));
        assert_eq!(stats.arithmetic_folds, 0);
    }

    #[test]
    fn test_fixed_point() {
        let expr = TLExpr::add(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
        let (result1, _) = propagator().run(expr);
        let (result2, stats2) = propagator().run(result1.clone());
        // Re-running on an already folded expression must not fold again.
        assert_eq!(stats2.total_folds(), 0);
        if let TLExpr::Constant(v1) = result1 {
            if let TLExpr::Constant(v2) = result2 {
                assert!((v1 - v2).abs() < 1e-12);
            } else {
                panic!("Expected Constant in second run");
            }
        } else {
            panic!("Expected Constant in first run");
        }
    }

    #[test]
    fn test_passes_count() {
        let expr = TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(2.0));
        let (_, stats) = propagator().run(expr);
        assert!(stats.passes >= 1, "Expected at least 1 pass");
    }

    #[test]
    fn test_reduction_pct() {
        let expr = TLExpr::add(
            TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
            TLExpr::Constant(4.0),
        );
        let (_, stats) = propagator().run(expr);
        assert!(stats.nodes_before > stats.nodes_after);
        assert!(stats.reduction_pct() > 0.0);
    }

    #[test]
    fn test_non_constant_unchanged() {
        let expr = TLExpr::pred("x", vec![]);
        let (result, stats) = propagator().run(expr);
        assert_eq!(stats.total_folds(), 0);
        assert!(matches!(result, TLExpr::Pred { .. }));
    }

    #[test]
    fn test_mixed_expr() {
        // One symbolic operand keeps the addition from folding.
        let expr = TLExpr::add(TLExpr::Constant(2.0), TLExpr::pred("x", vec![]));
        let (result, stats) = propagator().run(expr);
        assert!(matches!(result, TLExpr::Add(_, _)));
        assert_eq!(stats.arithmetic_folds, 0);
    }

    #[test]
    fn test_const_prop_with_dead_code() {
        use crate::dead_code::{DceConfig, DeadCodeEliminator};

        let inner = TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(2.0));
        let expr = TLExpr::and(TLExpr::Constant(1.0), inner);

        let (after_cp, cp_stats) = propagator().run(expr);
        assert!(cp_stats.total_folds() >= 1);

        let dce = DeadCodeEliminator::new(DceConfig::default());
        let (after_dce, _dce_stats) = dce.run(after_cp);
        assert!(matches!(after_dce, TLExpr::Constant(_)));
    }

    #[test]
    fn test_pow_two_constants() {
        let expr = TLExpr::pow(TLExpr::Constant(2.0), TLExpr::Constant(10.0));
        let (result, _) = propagator().run(expr);
        assert_constant(&result, 1024.0);
    }

    #[test]
    fn test_min_max_constants() {
        let min_expr = TLExpr::min(TLExpr::Constant(3.0), TLExpr::Constant(7.0));
        let (min_result, _) = propagator().run(min_expr);
        assert_constant(&min_result, 3.0);

        let max_expr = TLExpr::max(TLExpr::Constant(3.0), TLExpr::Constant(7.0));
        let (max_result, _) = propagator().run(max_expr);
        assert_constant(&max_result, 7.0);
    }

    #[test]
    fn test_comparison_eq_true() {
        let a = 1.0_f64;
        // Differs from `a` by less than the default tolerance of 1e-12.
        let b = a + 1e-13;
        let expr = TLExpr::Eq(Box::new(TLExpr::Constant(a)), Box::new(TLExpr::Constant(b)));
        let (result, stats) = propagator().run(expr);
        assert_constant(&result, 1.0);
        assert!(stats.comparison_folds >= 1);
    }

    #[test]
    fn test_count_nodes() {
        assert_eq!(ConstantPropagator::count_nodes(&TLExpr::Constant(1.0)), 1);
        let binary = TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(2.0));
        assert_eq!(ConstantPropagator::count_nodes(&binary), 3);
    }
}