1use super::{
51 algebraic::{simplify_algebraic, AlgebraicSimplificationStats},
52 constant_folding::{fold_constants, ConstantFoldingStats},
53 dead_code::{eliminate_dead_code, DeadCodeStats},
54 distributivity::{optimize_distributivity, DistributivityStats},
55 negation::{optimize_negations, NegationOptStats},
56 quantifier_opt::{optimize_quantifiers, QuantifierOptStats},
57 strength_reduction::{reduce_strength, StrengthReductionStats},
58};
59use tensorlogic_ir::TLExpr;
60
/// Switches and iteration limits that control an [`OptimizationPipeline`].
///
/// Each `enable_*` flag turns one optimization pass on or off; the remaining
/// fields bound how many rounds over the enabled passes are performed.
///
/// Derives `PartialEq`/`Eq` so configurations can be compared directly
/// (e.g. to check a builder chain produced the expected preset).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineConfig {
    /// Run the negation pass (`optimize_negations`).
    pub enable_negation_opt: bool,
    /// Run the constant-folding pass (`fold_constants`).
    pub enable_constant_folding: bool,
    /// Run the algebraic-simplification pass (`simplify_algebraic`).
    pub enable_algebraic_simplification: bool,
    /// Run the strength-reduction pass (`reduce_strength`).
    pub enable_strength_reduction: bool,
    /// Run the distributivity pass (`optimize_distributivity`).
    pub enable_distributivity: bool,
    /// Run the quantifier pass (`optimize_quantifiers`).
    pub enable_quantifier_opt: bool,
    /// Run the dead-code pass (`eliminate_dead_code`).
    pub enable_dead_code_elimination: bool,
    /// Upper bound on the number of rounds over the enabled passes.
    pub max_iterations: usize,
    /// End early once a round leaves the expression unchanged.
    pub stop_on_fixed_point: bool,
}
83
84impl Default for PipelineConfig {
85 fn default() -> Self {
86 Self {
87 enable_negation_opt: true,
88 enable_constant_folding: true,
89 enable_algebraic_simplification: true,
90 enable_strength_reduction: true,
91 enable_distributivity: true,
92 enable_quantifier_opt: true,
93 enable_dead_code_elimination: true,
94 max_iterations: 10,
95 stop_on_fixed_point: true,
96 }
97 }
98}
99
100impl PipelineConfig {
101 pub fn all() -> Self {
103 Self::default()
104 }
105
106 pub fn none() -> Self {
108 Self {
109 enable_negation_opt: false,
110 enable_constant_folding: false,
111 enable_algebraic_simplification: false,
112 enable_strength_reduction: false,
113 enable_distributivity: false,
114 enable_quantifier_opt: false,
115 enable_dead_code_elimination: false,
116 max_iterations: 1,
117 stop_on_fixed_point: true,
118 }
119 }
120
121 pub fn constant_folding_only() -> Self {
123 Self {
124 enable_negation_opt: false,
125 enable_constant_folding: true,
126 enable_algebraic_simplification: false,
127 enable_strength_reduction: false,
128 enable_distributivity: false,
129 enable_quantifier_opt: false,
130 enable_dead_code_elimination: false,
131 max_iterations: 1,
132 stop_on_fixed_point: true,
133 }
134 }
135
136 pub fn algebraic_only() -> Self {
138 Self {
139 enable_negation_opt: false,
140 enable_constant_folding: false,
141 enable_algebraic_simplification: true,
142 enable_strength_reduction: false,
143 enable_distributivity: false,
144 enable_quantifier_opt: false,
145 enable_dead_code_elimination: false,
146 max_iterations: 1,
147 stop_on_fixed_point: true,
148 }
149 }
150
151 pub fn aggressive() -> Self {
153 Self {
154 enable_negation_opt: true,
155 enable_constant_folding: true,
156 enable_algebraic_simplification: true,
157 enable_strength_reduction: true,
158 enable_distributivity: true,
159 enable_quantifier_opt: true,
160 enable_dead_code_elimination: true,
161 max_iterations: 20,
162 stop_on_fixed_point: true,
163 }
164 }
165
166 pub fn with_negation_opt(mut self, enable: bool) -> Self {
168 self.enable_negation_opt = enable;
169 self
170 }
171
172 pub fn with_constant_folding(mut self, enable: bool) -> Self {
174 self.enable_constant_folding = enable;
175 self
176 }
177
178 pub fn with_algebraic_simplification(mut self, enable: bool) -> Self {
180 self.enable_algebraic_simplification = enable;
181 self
182 }
183
184 pub fn with_max_iterations(mut self, max: usize) -> Self {
186 self.max_iterations = max;
187 self
188 }
189
190 pub fn with_stop_on_fixed_point(mut self, stop: bool) -> Self {
192 self.stop_on_fixed_point = stop;
193 self
194 }
195
196 pub fn with_strength_reduction(mut self, enable: bool) -> Self {
198 self.enable_strength_reduction = enable;
199 self
200 }
201
202 pub fn with_distributivity(mut self, enable: bool) -> Self {
204 self.enable_distributivity = enable;
205 self
206 }
207
208 pub fn with_quantifier_opt(mut self, enable: bool) -> Self {
210 self.enable_quantifier_opt = enable;
211 self
212 }
213
214 pub fn with_dead_code_elimination(mut self, enable: bool) -> Self {
216 self.enable_dead_code_elimination = enable;
217 self
218 }
219}
220
/// Per-pass statistics collected during a single pipeline iteration.
///
/// A field keeps its `Default` value when the corresponding pass is disabled
/// in the pipeline configuration, because `optimize` starts each iteration
/// from `IterationStats::default()` and only overwrites enabled passes.
#[derive(Debug, Clone, Default)]
pub struct IterationStats {
    /// Statistics returned by the negation pass.
    pub negation: NegationOptStats,
    /// Statistics returned by the constant-folding pass.
    pub constant_folding: ConstantFoldingStats,
    /// Statistics returned by the algebraic-simplification pass.
    pub algebraic: AlgebraicSimplificationStats,
    /// Statistics returned by the strength-reduction pass.
    pub strength_reduction: StrengthReductionStats,
    /// Statistics returned by the distributivity pass.
    pub distributivity: DistributivityStats,
    /// Statistics returned by the quantifier pass.
    pub quantifier_opt: QuantifierOptStats,
    /// Statistics returned by the dead-code pass.
    pub dead_code: DeadCodeStats,
}
239
240impl IterationStats {
241 pub fn made_changes(&self) -> bool {
243 self.negation.double_negations_eliminated > 0
244 || self.negation.demorgans_applied > 0
245 || self.negation.quantifier_negations_pushed > 0
246 || self.constant_folding.binary_ops_folded > 0
247 || self.constant_folding.unary_ops_folded > 0
248 || self.algebraic.identities_eliminated > 0
249 || self.algebraic.annihilations_applied > 0
250 || self.algebraic.idempotent_simplified > 0
251 || self.strength_reduction.total_optimizations() > 0
252 || self.distributivity.total_optimizations() > 0
253 || self.quantifier_opt.total_optimizations() > 0
254 || self.dead_code.total_optimizations() > 0
255 }
256
257 pub fn total_optimizations(&self) -> usize {
259 self.negation.double_negations_eliminated
260 + self.negation.demorgans_applied
261 + self.negation.quantifier_negations_pushed
262 + self.constant_folding.binary_ops_folded
263 + self.constant_folding.unary_ops_folded
264 + self.algebraic.identities_eliminated
265 + self.algebraic.annihilations_applied
266 + self.algebraic.idempotent_simplified
267 + self.strength_reduction.total_optimizations()
268 + self.distributivity.total_optimizations()
269 + self.quantifier_opt.total_optimizations()
270 + self.dead_code.total_optimizations()
271 }
272}
273
/// Statistics for a complete `OptimizationPipeline::optimize` run.
///
/// The per-pass fields hold counters that `optimize` accumulates across all
/// executed iterations; `iterations` keeps the individual per-iteration
/// snapshots in execution order.
#[derive(Debug, Clone, Default)]
pub struct PipelineStats {
    /// Number of iterations that actually ran.
    pub total_iterations: usize,
    /// Negation-pass counters accumulated over all iterations.
    pub negation: NegationOptStats,
    /// Constant-folding counters accumulated over all iterations.
    pub constant_folding: ConstantFoldingStats,
    /// Algebraic-simplification counters accumulated over all iterations.
    pub algebraic: AlgebraicSimplificationStats,
    /// Strength-reduction counters accumulated over all iterations.
    pub strength_reduction: StrengthReductionStats,
    /// Distributivity counters accumulated over all iterations.
    pub distributivity: DistributivityStats,
    /// Quantifier-pass counters accumulated over all iterations.
    pub quantifier_opt: QuantifierOptStats,
    /// Dead-code counters accumulated over all iterations.
    pub dead_code: DeadCodeStats,
    /// One entry per executed iteration, in execution order.
    pub iterations: Vec<IterationStats>,
    /// Fixed-point flag reported by `optimize`; see that method for exactly
    /// when it is set.
    pub reached_fixed_point: bool,
    /// Set by `optimize` when the last permitted iteration ran without an
    /// early fixed-point stop.
    pub stopped_at_max_iterations: bool,
}
300
301impl PipelineStats {
302 pub fn total_optimizations(&self) -> usize {
304 self.negation.double_negations_eliminated
305 + self.negation.demorgans_applied
306 + self.negation.quantifier_negations_pushed
307 + self.constant_folding.binary_ops_folded
308 + self.constant_folding.unary_ops_folded
309 + self.algebraic.identities_eliminated
310 + self.algebraic.annihilations_applied
311 + self.algebraic.idempotent_simplified
312 + self.strength_reduction.total_optimizations()
313 + self.distributivity.total_optimizations()
314 + self.quantifier_opt.total_optimizations()
315 + self.dead_code.total_optimizations()
316 }
317
318 pub fn most_productive_iteration(&self) -> Option<(usize, &IterationStats)> {
320 self.iterations
321 .iter()
322 .enumerate()
323 .max_by_key(|(_, stats)| stats.total_optimizations())
324 }
325}
326
327impl std::fmt::Display for PipelineStats {
328 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
329 writeln!(f, "Pipeline Statistics:")?;
330 writeln!(f, " Iterations: {}", self.total_iterations)?;
331 writeln!(f, " Reached fixed point: {}", self.reached_fixed_point)?;
332 writeln!(f, " Total optimizations: {}", self.total_optimizations())?;
333 writeln!(f, "\nNegation Optimization:")?;
334 writeln!(
335 f,
336 " Double negations eliminated: {}",
337 self.negation.double_negations_eliminated
338 )?;
339 writeln!(
340 f,
341 " De Morgan's laws applied: {}",
342 self.negation.demorgans_applied
343 )?;
344 writeln!(
345 f,
346 " Quantifier negations pushed: {}",
347 self.negation.quantifier_negations_pushed
348 )?;
349 writeln!(f, "\nConstant Folding:")?;
350 writeln!(
351 f,
352 " Binary ops folded: {}",
353 self.constant_folding.binary_ops_folded
354 )?;
355 writeln!(
356 f,
357 " Unary ops folded: {}",
358 self.constant_folding.unary_ops_folded
359 )?;
360 writeln!(f, "\nAlgebraic Simplification:")?;
361 writeln!(
362 f,
363 " Identities eliminated: {}",
364 self.algebraic.identities_eliminated
365 )?;
366 writeln!(
367 f,
368 " Annihilations applied: {}",
369 self.algebraic.annihilations_applied
370 )?;
371 writeln!(
372 f,
373 " Idempotent simplified: {}",
374 self.algebraic.idempotent_simplified
375 )?;
376 writeln!(f, "\nStrength Reduction:")?;
377 writeln!(
378 f,
379 " Power reductions: {}",
380 self.strength_reduction.power_reductions
381 )?;
382 writeln!(
383 f,
384 " Operations eliminated: {}",
385 self.strength_reduction.operations_eliminated
386 )?;
387 writeln!(
388 f,
389 " Special function optimizations: {}",
390 self.strength_reduction.special_function_optimizations
391 )?;
392 writeln!(f, "\nDistributivity:")?;
393 writeln!(
394 f,
395 " Expressions factored: {}",
396 self.distributivity.expressions_factored
397 )?;
398 writeln!(
399 f,
400 " Expressions expanded: {}",
401 self.distributivity.expressions_expanded
402 )?;
403 writeln!(f, "\nQuantifier Optimization:")?;
404 writeln!(
405 f,
406 " Invariants hoisted: {}",
407 self.quantifier_opt.invariants_hoisted
408 )?;
409 writeln!(
410 f,
411 " Quantifiers reordered: {}",
412 self.quantifier_opt.quantifiers_reordered
413 )?;
414 writeln!(f, "\nDead Code Elimination:")?;
415 writeln!(
416 f,
417 " Branches eliminated: {}",
418 self.dead_code.branches_eliminated
419 )?;
420 writeln!(f, " Short circuits: {}", self.dead_code.short_circuits)?;
421 writeln!(
422 f,
423 " Unused quantifiers removed: {}",
424 self.dead_code.unused_quantifiers_removed
425 )?;
426 Ok(())
427 }
428}
429
430pub struct OptimizationPipeline {
487 config: PipelineConfig,
488}
489
490impl OptimizationPipeline {
491 pub fn new() -> Self {
493 Self {
494 config: PipelineConfig::default(),
495 }
496 }
497
498 pub fn with_config(config: PipelineConfig) -> Self {
500 Self { config }
501 }
502
503 pub fn optimize(&self, expr: &TLExpr) -> (TLExpr, PipelineStats) {
507 let mut current = expr.clone();
508 let mut stats = PipelineStats::default();
509
510 for iteration in 0..self.config.max_iterations {
511 let mut iter_stats = IterationStats::default();
512 let mut changed = false;
513
514 if self.config.enable_negation_opt {
516 let (optimized, neg_stats) = optimize_negations(¤t);
517 iter_stats.negation = neg_stats;
518
519 if optimized != current {
520 current = optimized;
521 changed = true;
522 }
523 }
524
525 if self.config.enable_constant_folding {
527 let (optimized, fold_stats) = fold_constants(¤t);
528 iter_stats.constant_folding = fold_stats;
529
530 if optimized != current {
531 current = optimized;
532 changed = true;
533 }
534 }
535
536 if self.config.enable_algebraic_simplification {
538 let (optimized, alg_stats) = simplify_algebraic(¤t);
539 iter_stats.algebraic = alg_stats;
540
541 if optimized != current {
542 current = optimized;
543 changed = true;
544 }
545 }
546
547 if self.config.enable_strength_reduction {
549 let (optimized, sr_stats) = reduce_strength(¤t);
550 iter_stats.strength_reduction = sr_stats;
551
552 if optimized != current {
553 current = optimized;
554 changed = true;
555 }
556 }
557
558 if self.config.enable_distributivity {
560 let (optimized, dist_stats) = optimize_distributivity(¤t);
561 iter_stats.distributivity = dist_stats;
562
563 if optimized != current {
564 current = optimized;
565 changed = true;
566 }
567 }
568
569 if self.config.enable_quantifier_opt {
571 let (optimized, quant_stats) = optimize_quantifiers(¤t);
572 iter_stats.quantifier_opt = quant_stats;
573
574 if optimized != current {
575 current = optimized;
576 changed = true;
577 }
578 }
579
580 if self.config.enable_dead_code_elimination {
582 let (optimized, dead_stats) = eliminate_dead_code(¤t);
583 iter_stats.dead_code = dead_stats;
584
585 if optimized != current {
586 current = optimized;
587 changed = true;
588 }
589 }
590
591 stats.total_iterations = iteration + 1;
593 stats.negation.double_negations_eliminated +=
594 iter_stats.negation.double_negations_eliminated;
595 stats.negation.demorgans_applied += iter_stats.negation.demorgans_applied;
596 stats.negation.quantifier_negations_pushed +=
597 iter_stats.negation.quantifier_negations_pushed;
598 stats.constant_folding.binary_ops_folded +=
599 iter_stats.constant_folding.binary_ops_folded;
600 stats.constant_folding.unary_ops_folded += iter_stats.constant_folding.unary_ops_folded;
601 stats.constant_folding.total_processed += iter_stats.constant_folding.total_processed;
602 stats.algebraic.identities_eliminated += iter_stats.algebraic.identities_eliminated;
603 stats.algebraic.annihilations_applied += iter_stats.algebraic.annihilations_applied;
604 stats.algebraic.idempotent_simplified += iter_stats.algebraic.idempotent_simplified;
605 stats.algebraic.total_processed += iter_stats.algebraic.total_processed;
606 stats.strength_reduction.power_reductions +=
607 iter_stats.strength_reduction.power_reductions;
608 stats.strength_reduction.operations_eliminated +=
609 iter_stats.strength_reduction.operations_eliminated;
610 stats.strength_reduction.special_function_optimizations +=
611 iter_stats.strength_reduction.special_function_optimizations;
612 stats.strength_reduction.total_processed +=
613 iter_stats.strength_reduction.total_processed;
614 stats.distributivity.expressions_factored +=
615 iter_stats.distributivity.expressions_factored;
616 stats.distributivity.expressions_expanded +=
617 iter_stats.distributivity.expressions_expanded;
618 stats.distributivity.common_terms_extracted +=
619 iter_stats.distributivity.common_terms_extracted;
620 stats.distributivity.total_processed += iter_stats.distributivity.total_processed;
621 stats.quantifier_opt.invariants_hoisted += iter_stats.quantifier_opt.invariants_hoisted;
622 stats.quantifier_opt.quantifiers_reordered +=
623 iter_stats.quantifier_opt.quantifiers_reordered;
624 stats.quantifier_opt.quantifiers_fused += iter_stats.quantifier_opt.quantifiers_fused;
625 stats.quantifier_opt.total_processed += iter_stats.quantifier_opt.total_processed;
626 stats.dead_code.branches_eliminated += iter_stats.dead_code.branches_eliminated;
627 stats.dead_code.short_circuits += iter_stats.dead_code.short_circuits;
628 stats.dead_code.unused_quantifiers_removed +=
629 iter_stats.dead_code.unused_quantifiers_removed;
630 stats.dead_code.identity_simplifications +=
631 iter_stats.dead_code.identity_simplifications;
632 stats.dead_code.total_processed += iter_stats.dead_code.total_processed;
633 stats.iterations.push(iter_stats);
634
635 if self.config.stop_on_fixed_point && !changed {
637 stats.reached_fixed_point = true;
638 break;
639 }
640
641 if iteration + 1 >= self.config.max_iterations {
643 stats.stopped_at_max_iterations = true;
644 }
645 }
646
647 (current, stats)
648 }
649
650 pub fn config(&self) -> &PipelineConfig {
652 &self.config
653 }
654}
655
656impl Default for OptimizationPipeline {
657 fn default() -> Self {
658 Self::new()
659 }
660}
661
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// Builds the unary predicate `name(i)` used throughout these tests.
    fn atom(name: &'static str) -> TLExpr {
        TLExpr::pred(name, vec![Term::var("i")])
    }

    /// Runs `expr` through a pipeline built from `config`.
    fn run_with(config: PipelineConfig, expr: &TLExpr) -> (TLExpr, PipelineStats) {
        OptimizationPipeline::with_config(config).optimize(expr)
    }

    #[test]
    fn test_pipeline_with_all_passes() {
        // NOT((x + 0) AND (2 * 3))
        let input = TLExpr::negate(TLExpr::and(
            TLExpr::add(atom("x"), TLExpr::Constant(0.0)),
            TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
        ));

        let (result, stats) = OptimizationPipeline::new().optimize(&input);

        assert!(stats.total_iterations > 0);
        assert!(stats.constant_folding.binary_ops_folded > 0);
        assert!(stats.algebraic.identities_eliminated > 0);
        assert!(stats.negation.demorgans_applied > 0);
        assert_ne!(result, input);
    }

    #[test]
    fn test_constant_folding_only() {
        // 2 + (3 * 4)
        let input = TLExpr::add(
            TLExpr::Constant(2.0),
            TLExpr::mul(TLExpr::Constant(3.0), TLExpr::Constant(4.0)),
        );

        let (result, stats) = run_with(PipelineConfig::constant_folding_only(), &input);

        assert!(matches!(result, TLExpr::Constant(_)));
        assert_eq!(stats.constant_folding.binary_ops_folded, 2);
        assert_eq!(stats.algebraic.identities_eliminated, 0);
        assert_eq!(stats.negation.demorgans_applied, 0);
    }

    #[test]
    fn test_algebraic_only() {
        // (x + 0) * 1
        let input = TLExpr::mul(
            TLExpr::add(atom("x"), TLExpr::Constant(0.0)),
            TLExpr::Constant(1.0),
        );

        let (_, stats) = run_with(PipelineConfig::algebraic_only(), &input);

        assert_eq!(stats.algebraic.identities_eliminated, 2);
        assert_eq!(stats.constant_folding.binary_ops_folded, 0);
    }

    #[test]
    fn test_fixed_point_detection() {
        // A bare predicate offers nothing to optimize.
        let leaf = atom("x");

        let (result, stats) = run_with(PipelineConfig::default().with_max_iterations(10), &leaf);

        assert_eq!(stats.total_iterations, 1);
        assert!(stats.reached_fixed_point);
        assert!(!stats.stopped_at_max_iterations);
        assert_eq!(result, leaf);
    }

    #[test]
    fn test_max_iterations_limit() {
        // NOT(NOT(x + 0)) is rewritten by the single permitted round.
        let input = TLExpr::negate(TLExpr::negate(TLExpr::add(
            atom("x"),
            TLExpr::Constant(0.0),
        )));

        let (_, stats) = run_with(PipelineConfig::default().with_max_iterations(1), &input);

        assert_eq!(stats.total_iterations, 1);
        assert!(stats.stopped_at_max_iterations);
    }

    #[test]
    fn test_aggressive_optimization() {
        let x = atom("x");
        // NOT(NOT(x + 0) AND NOT((2 * 3) * x)) + (1 * 1)
        let input = TLExpr::add(
            TLExpr::negate(TLExpr::and(
                TLExpr::negate(TLExpr::add(x.clone(), TLExpr::Constant(0.0))),
                TLExpr::negate(TLExpr::mul(
                    TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
                    x,
                )),
            )),
            TLExpr::mul(TLExpr::Constant(1.0), TLExpr::Constant(1.0)),
        );

        let (_, stats) = run_with(PipelineConfig::aggressive(), &input);
        let total = stats.total_optimizations();

        assert!(total >= 4, "Expected at least 4 optimizations, got {}", total);
        assert!(stats.total_iterations >= 1);
    }

    #[test]
    fn test_no_optimization() {
        let input = TLExpr::add(atom("x"), TLExpr::Constant(1.0));

        let (result, stats) = run_with(PipelineConfig::none(), &input);

        assert_eq!(result, input);
        assert_eq!(stats.total_optimizations(), 0);
    }

    #[test]
    fn test_iteration_stats() {
        // (2 * 3) + 0
        let input = TLExpr::add(
            TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
            TLExpr::Constant(0.0),
        );

        let (_, stats) = OptimizationPipeline::new().optimize(&input);

        assert!(!stats.iterations.is_empty());
        let first = &stats.iterations[0];
        assert!(first.made_changes());
        assert!(first.total_optimizations() > 0);
    }

    #[test]
    fn test_most_productive_iteration() {
        // NOT(NOT((2 * 3) + (x * 1)))
        let input = TLExpr::negate(TLExpr::negate(TLExpr::add(
            TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
            TLExpr::mul(atom("x"), TLExpr::Constant(1.0)),
        )));

        let (_, stats) = OptimizationPipeline::new().optimize(&input);
        let (best_idx, best) = stats.most_productive_iteration().unwrap();

        assert!(best.total_optimizations() > 0);
        assert!(best_idx < stats.total_iterations);
    }

    #[test]
    fn test_pipeline_display() {
        let input = TLExpr::add(TLExpr::Constant(2.0), TLExpr::Constant(3.0));

        let (_, stats) = OptimizationPipeline::new().optimize(&input);
        let rendered = stats.to_string();

        for needle in ["Pipeline Statistics", "Iterations:", "Total optimizations:"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_builder_pattern() {
        let config = PipelineConfig::default()
            .with_negation_opt(false)
            .with_constant_folding(true)
            .with_algebraic_simplification(false)
            .with_max_iterations(5)
            .with_stop_on_fixed_point(false);

        assert_eq!(config.max_iterations, 5);
        assert!(!config.stop_on_fixed_point);
        assert!(config.enable_constant_folding);
        assert!(!config.enable_negation_opt);
        assert!(!config.enable_algebraic_simplification);
    }

    #[test]
    fn test_complex_real_world_expression() {
        // exp((x - max) / 1.0): the division by a constant 1.0 is an identity.
        let numerator = TLExpr::sub(atom("x"), TLExpr::pred("max", vec![]));
        let input = TLExpr::exp(TLExpr::div(numerator, TLExpr::Constant(1.0)));

        let (result, stats) = OptimizationPipeline::new().optimize(&input);

        assert!(stats.algebraic.identities_eliminated > 0);
        assert_ne!(result, input);
    }

    #[test]
    fn test_dead_code_elimination_integration() {
        // if 1.0 then a(i) else b(i)
        let input = TLExpr::IfThenElse {
            condition: Box::new(TLExpr::Constant(1.0)),
            then_branch: Box::new(atom("a")),
            else_branch: Box::new(atom("b")),
        };

        let (result, stats) = OptimizationPipeline::new().optimize(&input);

        assert!(stats.dead_code.branches_eliminated > 0);
        assert!(matches!(result, TLExpr::Pred { .. }));
    }

    #[test]
    fn test_all_passes_together() {
        let (a, b, c) = (atom("a"), atom("b"), atom("c"));

        // NOT(NOT(x^2 + 0))
        let squared = TLExpr::negate(TLExpr::negate(TLExpr::add(
            TLExpr::pow(atom("x"), TLExpr::Constant(2.0)),
            TLExpr::Constant(0.0),
        )));
        // a*b + a*c, a candidate for factoring out `a`
        let products = TLExpr::add(TLExpr::mul(a.clone(), b), TLExpr::mul(a, c));

        let input = TLExpr::IfThenElse {
            condition: Box::new(TLExpr::Constant(1.0)),
            then_branch: Box::new(TLExpr::and(squared, products)),
            else_branch: Box::new(TLExpr::Constant(0.0)),
        };

        let (_, stats) = OptimizationPipeline::new().optimize(&input);

        assert!(
            stats.dead_code.branches_eliminated > 0,
            "Dead code elimination should apply"
        );
        assert!(
            stats.negation.double_negations_eliminated > 0,
            "Negation optimization should apply"
        );
        assert!(
            stats.algebraic.identities_eliminated > 0,
            "Algebraic simplification should apply"
        );
        assert!(
            stats.strength_reduction.power_reductions > 0,
            "Strength reduction should apply"
        );
        assert!(
            stats.distributivity.expressions_factored > 0,
            "Distributivity should apply"
        );
        assert!(
            stats.total_optimizations() >= 5,
            "Should apply at least 5 optimizations"
        );
    }
}