tensorlogic_compiler/optimize/
pipeline.rs1use super::{
43 algebraic::{simplify_algebraic, AlgebraicSimplificationStats},
44 constant_folding::{fold_constants, ConstantFoldingStats},
45 negation::{optimize_negations, NegationOptStats},
46};
47use tensorlogic_ir::TLExpr;
48
/// Settings that choose which optimization passes run and how often they repeat.
///
/// Start from a preset ([`PipelineConfig::all`], [`PipelineConfig::none`], ...) or from
/// [`Default::default`], then adjust single fields with the chainable `with_*` setters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineConfig {
    /// Run the negation-optimization pass in each iteration.
    pub enable_negation_opt: bool,
    /// Run the constant-folding pass in each iteration.
    pub enable_constant_folding: bool,
    /// Run the algebraic-simplification pass in each iteration.
    pub enable_algebraic_simplification: bool,
    /// Upper bound on the number of iterations the pipeline performs.
    pub max_iterations: usize,
    /// Stop as soon as an iteration leaves the expression unchanged.
    pub stop_on_fixed_point: bool,
}

impl Default for PipelineConfig {
    /// All passes enabled, at most 10 iterations, stopping at a fixed point.
    fn default() -> Self {
        Self {
            enable_negation_opt: true,
            enable_constant_folding: true,
            enable_algebraic_simplification: true,
            max_iterations: 10,
            stop_on_fixed_point: true,
        }
    }
}

impl PipelineConfig {
    /// Every pass enabled; identical to [`Default::default`].
    pub fn all() -> Self {
        Self::default()
    }

    /// Every pass disabled, with a single iteration.
    pub fn none() -> Self {
        Self {
            enable_negation_opt: false,
            enable_constant_folding: false,
            enable_algebraic_simplification: false,
            max_iterations: 1,
            stop_on_fixed_point: true,
        }
    }

    /// Only constant folding, run for a single iteration.
    pub fn constant_folding_only() -> Self {
        Self {
            enable_constant_folding: true,
            ..Self::none()
        }
    }

    /// Only algebraic simplification, run for a single iteration.
    pub fn algebraic_only() -> Self {
        Self {
            enable_algebraic_simplification: true,
            ..Self::none()
        }
    }

    /// Every pass enabled with a budget of 20 iterations.
    pub fn aggressive() -> Self {
        Self {
            max_iterations: 20,
            ..Self::default()
        }
    }

    /// Returns the configuration with the negation pass switched on or off.
    #[must_use = "this returns the updated configuration instead of modifying it in place"]
    pub fn with_negation_opt(mut self, enable: bool) -> Self {
        self.enable_negation_opt = enable;
        self
    }

    /// Returns the configuration with constant folding switched on or off.
    #[must_use = "this returns the updated configuration instead of modifying it in place"]
    pub fn with_constant_folding(mut self, enable: bool) -> Self {
        self.enable_constant_folding = enable;
        self
    }

    /// Returns the configuration with algebraic simplification switched on or off.
    #[must_use = "this returns the updated configuration instead of modifying it in place"]
    pub fn with_algebraic_simplification(mut self, enable: bool) -> Self {
        self.enable_algebraic_simplification = enable;
        self
    }

    /// Returns the configuration with the iteration budget set to `max`.
    #[must_use = "this returns the updated configuration instead of modifying it in place"]
    pub fn with_max_iterations(mut self, max: usize) -> Self {
        self.max_iterations = max;
        self
    }

    /// Returns the configuration with early termination at a fixed point set to `stop`.
    #[must_use = "this returns the updated configuration instead of modifying it in place"]
    pub fn with_stop_on_fixed_point(mut self, stop: bool) -> Self {
        self.stop_on_fixed_point = stop;
        self
    }
}
156
157#[derive(Debug, Clone, Default)]
159pub struct IterationStats {
160 pub negation: NegationOptStats,
162 pub constant_folding: ConstantFoldingStats,
164 pub algebraic: AlgebraicSimplificationStats,
166}
167
168impl IterationStats {
169 pub fn made_changes(&self) -> bool {
171 self.negation.double_negations_eliminated > 0
172 || self.negation.demorgans_applied > 0
173 || self.negation.quantifier_negations_pushed > 0
174 || self.constant_folding.binary_ops_folded > 0
175 || self.constant_folding.unary_ops_folded > 0
176 || self.algebraic.identities_eliminated > 0
177 || self.algebraic.annihilations_applied > 0
178 || self.algebraic.idempotent_simplified > 0
179 }
180
181 pub fn total_optimizations(&self) -> usize {
183 self.negation.double_negations_eliminated
184 + self.negation.demorgans_applied
185 + self.negation.quantifier_negations_pushed
186 + self.constant_folding.binary_ops_folded
187 + self.constant_folding.unary_ops_folded
188 + self.algebraic.identities_eliminated
189 + self.algebraic.annihilations_applied
190 + self.algebraic.idempotent_simplified
191 }
192}
193
194#[derive(Debug, Clone, Default)]
196pub struct PipelineStats {
197 pub total_iterations: usize,
199 pub negation: NegationOptStats,
201 pub constant_folding: ConstantFoldingStats,
203 pub algebraic: AlgebraicSimplificationStats,
205 pub iterations: Vec<IterationStats>,
207 pub reached_fixed_point: bool,
209 pub stopped_at_max_iterations: bool,
211}
212
213impl PipelineStats {
214 pub fn total_optimizations(&self) -> usize {
216 self.negation.double_negations_eliminated
217 + self.negation.demorgans_applied
218 + self.negation.quantifier_negations_pushed
219 + self.constant_folding.binary_ops_folded
220 + self.constant_folding.unary_ops_folded
221 + self.algebraic.identities_eliminated
222 + self.algebraic.annihilations_applied
223 + self.algebraic.idempotent_simplified
224 }
225
226 pub fn most_productive_iteration(&self) -> Option<(usize, &IterationStats)> {
228 self.iterations
229 .iter()
230 .enumerate()
231 .max_by_key(|(_, stats)| stats.total_optimizations())
232 }
233}
234
235impl std::fmt::Display for PipelineStats {
236 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237 writeln!(f, "Pipeline Statistics:")?;
238 writeln!(f, " Iterations: {}", self.total_iterations)?;
239 writeln!(f, " Reached fixed point: {}", self.reached_fixed_point)?;
240 writeln!(f, " Total optimizations: {}", self.total_optimizations())?;
241 writeln!(f, "\nNegation Optimization:")?;
242 writeln!(
243 f,
244 " Double negations eliminated: {}",
245 self.negation.double_negations_eliminated
246 )?;
247 writeln!(
248 f,
249 " De Morgan's laws applied: {}",
250 self.negation.demorgans_applied
251 )?;
252 writeln!(
253 f,
254 " Quantifier negations pushed: {}",
255 self.negation.quantifier_negations_pushed
256 )?;
257 writeln!(f, "\nConstant Folding:")?;
258 writeln!(
259 f,
260 " Binary ops folded: {}",
261 self.constant_folding.binary_ops_folded
262 )?;
263 writeln!(
264 f,
265 " Unary ops folded: {}",
266 self.constant_folding.unary_ops_folded
267 )?;
268 writeln!(f, "\nAlgebraic Simplification:")?;
269 writeln!(
270 f,
271 " Identities eliminated: {}",
272 self.algebraic.identities_eliminated
273 )?;
274 writeln!(
275 f,
276 " Annihilations applied: {}",
277 self.algebraic.annihilations_applied
278 )?;
279 writeln!(
280 f,
281 " Idempotent simplified: {}",
282 self.algebraic.idempotent_simplified
283 )?;
284 Ok(())
285 }
286}
287
288pub struct OptimizationPipeline {
333 config: PipelineConfig,
334}
335
336impl OptimizationPipeline {
337 pub fn new() -> Self {
339 Self {
340 config: PipelineConfig::default(),
341 }
342 }
343
344 pub fn with_config(config: PipelineConfig) -> Self {
346 Self { config }
347 }
348
349 pub fn optimize(&self, expr: &TLExpr) -> (TLExpr, PipelineStats) {
353 let mut current = expr.clone();
354 let mut stats = PipelineStats::default();
355
356 for iteration in 0..self.config.max_iterations {
357 let mut iter_stats = IterationStats::default();
358 let mut changed = false;
359
360 if self.config.enable_negation_opt {
362 let (optimized, neg_stats) = optimize_negations(¤t);
363 iter_stats.negation = neg_stats;
364
365 if optimized != current {
366 current = optimized;
367 changed = true;
368 }
369 }
370
371 if self.config.enable_constant_folding {
373 let (optimized, fold_stats) = fold_constants(¤t);
374 iter_stats.constant_folding = fold_stats;
375
376 if optimized != current {
377 current = optimized;
378 changed = true;
379 }
380 }
381
382 if self.config.enable_algebraic_simplification {
384 let (optimized, alg_stats) = simplify_algebraic(¤t);
385 iter_stats.algebraic = alg_stats;
386
387 if optimized != current {
388 current = optimized;
389 changed = true;
390 }
391 }
392
393 stats.total_iterations = iteration + 1;
395 stats.negation.double_negations_eliminated +=
396 iter_stats.negation.double_negations_eliminated;
397 stats.negation.demorgans_applied += iter_stats.negation.demorgans_applied;
398 stats.negation.quantifier_negations_pushed +=
399 iter_stats.negation.quantifier_negations_pushed;
400 stats.constant_folding.binary_ops_folded +=
401 iter_stats.constant_folding.binary_ops_folded;
402 stats.constant_folding.unary_ops_folded += iter_stats.constant_folding.unary_ops_folded;
403 stats.constant_folding.total_processed += iter_stats.constant_folding.total_processed;
404 stats.algebraic.identities_eliminated += iter_stats.algebraic.identities_eliminated;
405 stats.algebraic.annihilations_applied += iter_stats.algebraic.annihilations_applied;
406 stats.algebraic.idempotent_simplified += iter_stats.algebraic.idempotent_simplified;
407 stats.algebraic.total_processed += iter_stats.algebraic.total_processed;
408 stats.iterations.push(iter_stats);
409
410 if self.config.stop_on_fixed_point && !changed {
412 stats.reached_fixed_point = true;
413 break;
414 }
415
416 if iteration + 1 >= self.config.max_iterations {
418 stats.stopped_at_max_iterations = true;
419 }
420 }
421
422 (current, stats)
423 }
424
425 pub fn config(&self) -> &PipelineConfig {
427 &self.config
428 }
429}
430
431impl Default for OptimizationPipeline {
432 fn default() -> Self {
433 Self::new()
434 }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// Shorthand for the predicate `x(i)` used as the opaque operand in most tests.
    fn pred_x() -> TLExpr {
        TLExpr::pred("x", vec![Term::var("i")])
    }

    #[test]
    fn test_pipeline_with_all_passes() {
        // NOT((x + 0) AND (2 * 3)): expected to give every pass something to do.
        let sum_with_zero = TLExpr::add(pred_x(), TLExpr::Constant(0.0));
        let constant_product = TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
        let expr = TLExpr::negate(TLExpr::and(sum_with_zero, constant_product));

        let (optimized, stats) = OptimizationPipeline::new().optimize(&expr);

        assert!(stats.total_iterations > 0);
        assert!(stats.constant_folding.binary_ops_folded > 0);
        assert!(stats.algebraic.identities_eliminated > 0);
        assert!(stats.negation.demorgans_applied > 0);

        assert_ne!(optimized, expr);
    }

    #[test]
    fn test_constant_folding_only() {
        // 2 + (3 * 4): two nested constant operations.
        let product = TLExpr::mul(TLExpr::Constant(3.0), TLExpr::Constant(4.0));
        let expr = TLExpr::add(TLExpr::Constant(2.0), product);

        let pipeline = OptimizationPipeline::with_config(PipelineConfig::constant_folding_only());
        let (optimized, stats) = pipeline.optimize(&expr);

        assert!(matches!(optimized, TLExpr::Constant(_)));
        assert_eq!(stats.constant_folding.binary_ops_folded, 2);
        assert_eq!(stats.algebraic.identities_eliminated, 0);
        assert_eq!(stats.negation.demorgans_applied, 0);
    }

    #[test]
    fn test_algebraic_only() {
        // (x + 0) * 1: one additive and one multiplicative identity.
        let sum_with_zero = TLExpr::add(pred_x(), TLExpr::Constant(0.0));
        let expr = TLExpr::mul(sum_with_zero, TLExpr::Constant(1.0));

        let pipeline = OptimizationPipeline::with_config(PipelineConfig::algebraic_only());
        let (_optimized, stats) = pipeline.optimize(&expr);

        assert_eq!(stats.algebraic.identities_eliminated, 2);
        assert_eq!(stats.constant_folding.binary_ops_folded, 0);
    }

    #[test]
    fn test_fixed_point_detection() {
        // A bare predicate offers nothing to rewrite, so the first iteration is a fixed point.
        let x = pred_x();

        let pipeline =
            OptimizationPipeline::with_config(PipelineConfig::default().with_max_iterations(10));
        let (optimized, stats) = pipeline.optimize(&x);

        assert_eq!(stats.total_iterations, 1);
        assert!(stats.reached_fixed_point);
        assert!(!stats.stopped_at_max_iterations);
        assert_eq!(optimized, x);
    }

    #[test]
    fn test_max_iterations_limit() {
        // NOT(NOT(x + 0)) with a budget of exactly one iteration.
        let sum_with_zero = TLExpr::add(pred_x(), TLExpr::Constant(0.0));
        let expr = TLExpr::negate(TLExpr::negate(sum_with_zero));

        let pipeline =
            OptimizationPipeline::with_config(PipelineConfig::default().with_max_iterations(1));
        let (_, stats) = pipeline.optimize(&expr);

        assert_eq!(stats.total_iterations, 1);
        assert!(stats.stopped_at_max_iterations);
    }

    #[test]
    fn test_aggressive_optimization() {
        // NOT(NOT(x + 0) AND NOT((2 * 3) * x)) + (1 * 1)
        let x = pred_x();
        let left = TLExpr::negate(TLExpr::add(x.clone(), TLExpr::Constant(0.0)));
        let right = TLExpr::negate(TLExpr::mul(
            TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0)),
            x,
        ));
        let expr = TLExpr::add(
            TLExpr::negate(TLExpr::and(left, right)),
            TLExpr::mul(TLExpr::Constant(1.0), TLExpr::Constant(1.0)),
        );

        let pipeline = OptimizationPipeline::with_config(PipelineConfig::aggressive());
        let (_, stats) = pipeline.optimize(&expr);

        assert!(
            stats.total_optimizations() >= 4,
            "Expected at least 4 optimizations, got {}",
            stats.total_optimizations()
        );
        assert!(stats.total_iterations >= 1);
    }

    #[test]
    fn test_no_optimization() {
        // With every pass disabled the expression must come back untouched.
        let expr = TLExpr::add(pred_x(), TLExpr::Constant(1.0));

        let pipeline = OptimizationPipeline::with_config(PipelineConfig::none());
        let (optimized, stats) = pipeline.optimize(&expr);

        assert_eq!(optimized, expr);
        assert_eq!(stats.total_optimizations(), 0);
    }

    #[test]
    fn test_iteration_stats() {
        // (2 * 3) + 0
        let product = TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
        let expr = TLExpr::add(product, TLExpr::Constant(0.0));

        let (_, stats) = OptimizationPipeline::new().optimize(&expr);

        assert!(!stats.iterations.is_empty());
        let first = &stats.iterations[0];
        assert!(first.made_changes());
        assert!(first.total_optimizations() > 0);
    }

    #[test]
    fn test_most_productive_iteration() {
        // NOT(NOT((2 * 3) + (x * 1)))
        let constant_product = TLExpr::mul(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
        let times_one = TLExpr::mul(pred_x(), TLExpr::Constant(1.0));
        let expr = TLExpr::negate(TLExpr::negate(TLExpr::add(constant_product, times_one)));

        let (_, stats) = OptimizationPipeline::new().optimize(&expr);

        let (iter_idx, iter_stats) = stats.most_productive_iteration().unwrap();
        assert!(iter_stats.total_optimizations() > 0);
        assert!(iter_idx < stats.total_iterations);
    }

    #[test]
    fn test_pipeline_display() {
        let expr = TLExpr::add(TLExpr::Constant(2.0), TLExpr::Constant(3.0));
        let (_, stats) = OptimizationPipeline::new().optimize(&expr);

        let output = format!("{}", stats);
        for needle in ["Pipeline Statistics", "Iterations:", "Total optimizations:"] {
            assert!(output.contains(needle));
        }
    }

    #[test]
    fn test_builder_pattern() {
        let config = PipelineConfig::default()
            .with_negation_opt(false)
            .with_constant_folding(true)
            .with_algebraic_simplification(false)
            .with_max_iterations(5)
            .with_stop_on_fixed_point(false);

        assert!(!config.enable_negation_opt);
        assert!(config.enable_constant_folding);
        assert!(!config.enable_algebraic_simplification);
        assert_eq!(config.max_iterations, 5);
        assert!(!config.stop_on_fixed_point);
    }

    #[test]
    fn test_complex_real_world_expression() {
        // exp((x - max) / temp) with temp = 1.0; the test expects at least one identity
        // to be eliminated and the expression to change.
        let max_val = TLExpr::pred("max", vec![]);
        let temp = TLExpr::Constant(1.0);
        let shifted = TLExpr::sub(pred_x(), max_val);
        let expr = TLExpr::exp(TLExpr::div(shifted, temp));

        let (optimized, stats) = OptimizationPipeline::new().optimize(&expr);

        assert!(stats.algebraic.identities_eliminated > 0);
        assert_ne!(optimized, expr);
    }
}