1use std::collections::HashMap;
31
32use super::{
33 distributive_laws::{apply_distributive_laws, DistributiveStrategy},
34 modal_equivalences::apply_modal_equivalences,
35 normal_forms::to_nnf,
36 optimization::{algebraic_simplify, constant_fold, propagate_constants},
37 temporal_equivalences::apply_temporal_equivalences,
38 TLExpr,
39};
40
/// How much rewriting effort an optimization run spends on an expression.
///
/// Levels compare by declaration order (`None < Basic < Standard < Aggressive`); each level
/// selects its pass list through `OptimizationPass::for_level`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum OptimizationLevel {
    /// Run no passes at all.
    None,
    /// Constant folding, constant propagation and algebraic simplification.
    Basic,
    /// `Basic` plus negation normal form and modal/temporal equivalences (the default level).
    #[default]
    Standard,
    /// `Standard` plus the quantifier, modal and AND-over-OR distributive passes.
    Aggressive,
}
54
/// One rewriting pass that the optimization pipeline can run over a `TLExpr`.
///
/// Every variant corresponds to exactly one rewriting function (see `OptimizationPass::apply`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizationPass {
    /// Evaluates constant sub-expressions via `constant_fold`.
    ConstantFolding,
    /// Propagates known constants via `propagate_constants`.
    ConstantPropagation,
    /// Applies algebraic identities via `algebraic_simplify`.
    AlgebraicSimplification,
    /// Rewrites into negation normal form via `to_nnf`.
    NegationNormalForm,
    /// Applies modal-logic equivalences via `apply_modal_equivalences`.
    ModalEquivalences,
    /// Applies temporal-logic equivalences via `apply_temporal_equivalences`.
    TemporalEquivalences,
    /// Distributive laws with `DistributiveStrategy::AndOverOr`.
    DistributiveAndOverOr,
    /// Distributive laws with `DistributiveStrategy::OrOverAnd`.
    DistributiveOrOverAnd,
    /// Distributive laws with `DistributiveStrategy::Quantifiers`.
    DistributiveQuantifiers,
    /// Distributive laws with `DistributiveStrategy::Modal`.
    DistributiveModal,
}
79
80impl OptimizationPass {
81 pub fn name(&self) -> &'static str {
83 match self {
84 OptimizationPass::ConstantFolding => "constant_folding",
85 OptimizationPass::ConstantPropagation => "constant_propagation",
86 OptimizationPass::AlgebraicSimplification => "algebraic_simplification",
87 OptimizationPass::NegationNormalForm => "negation_normal_form",
88 OptimizationPass::ModalEquivalences => "modal_equivalences",
89 OptimizationPass::TemporalEquivalences => "temporal_equivalences",
90 OptimizationPass::DistributiveAndOverOr => "distributive_and_over_or",
91 OptimizationPass::DistributiveOrOverAnd => "distributive_or_over_and",
92 OptimizationPass::DistributiveQuantifiers => "distributive_quantifiers",
93 OptimizationPass::DistributiveModal => "distributive_modal",
94 }
95 }
96
97 pub fn apply(&self, expr: TLExpr) -> TLExpr {
99 match self {
100 OptimizationPass::ConstantFolding => constant_fold(&expr),
101 OptimizationPass::ConstantPropagation => propagate_constants(&expr),
102 OptimizationPass::AlgebraicSimplification => algebraic_simplify(&expr),
103 OptimizationPass::NegationNormalForm => to_nnf(&expr),
104 OptimizationPass::ModalEquivalences => apply_modal_equivalences(&expr),
105 OptimizationPass::TemporalEquivalences => apply_temporal_equivalences(&expr),
106 OptimizationPass::DistributiveAndOverOr => {
107 apply_distributive_laws(&expr, DistributiveStrategy::AndOverOr)
108 }
109 OptimizationPass::DistributiveOrOverAnd => {
110 apply_distributive_laws(&expr, DistributiveStrategy::OrOverAnd)
111 }
112 OptimizationPass::DistributiveQuantifiers => {
113 apply_distributive_laws(&expr, DistributiveStrategy::Quantifiers)
114 }
115 OptimizationPass::DistributiveModal => {
116 apply_distributive_laws(&expr, DistributiveStrategy::Modal)
117 }
118 }
119 }
120
121 pub fn priority(&self) -> u32 {
123 match self {
124 OptimizationPass::ConstantFolding => 10,
126 OptimizationPass::ConstantPropagation => 20,
127 OptimizationPass::NegationNormalForm => 30,
128 OptimizationPass::AlgebraicSimplification => 40,
130 OptimizationPass::ModalEquivalences => 50,
131 OptimizationPass::TemporalEquivalences => 60,
132 OptimizationPass::DistributiveQuantifiers => 70,
134 OptimizationPass::DistributiveModal => 80,
135 OptimizationPass::DistributiveAndOverOr => 90,
136 OptimizationPass::DistributiveOrOverAnd => 100,
137 }
138 }
139
140 pub fn for_level(level: OptimizationLevel) -> Vec<OptimizationPass> {
142 match level {
143 OptimizationLevel::None => vec![],
144 OptimizationLevel::Basic => vec![
145 OptimizationPass::ConstantFolding,
146 OptimizationPass::ConstantPropagation,
147 OptimizationPass::AlgebraicSimplification,
148 ],
149 OptimizationLevel::Standard => vec![
150 OptimizationPass::ConstantFolding,
151 OptimizationPass::ConstantPropagation,
152 OptimizationPass::NegationNormalForm,
153 OptimizationPass::AlgebraicSimplification,
154 OptimizationPass::ModalEquivalences,
155 OptimizationPass::TemporalEquivalences,
156 ],
157 OptimizationLevel::Aggressive => vec![
158 OptimizationPass::ConstantFolding,
159 OptimizationPass::ConstantPropagation,
160 OptimizationPass::NegationNormalForm,
161 OptimizationPass::AlgebraicSimplification,
162 OptimizationPass::ModalEquivalences,
163 OptimizationPass::TemporalEquivalences,
164 OptimizationPass::DistributiveQuantifiers,
165 OptimizationPass::DistributiveModal,
166 OptimizationPass::DistributiveAndOverOr,
167 ],
168 }
169 }
170}
171
/// Statistics gathered during one optimization run.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct OptimizationMetrics {
    /// Number of recorded pass applications.
    ///
    /// The pipeline records a pass only when it changed the expression.
    pub passes_applied: usize,
    /// Number of pipeline iterations that were started.
    pub iterations: usize,
    /// `true` if the run stopped early because an iteration left the expression unchanged.
    pub converged: bool,
    /// Per-pass application counts, keyed by `OptimizationPass::name`.
    pub pass_counts: HashMap<String, usize>,
    /// Node count of the input expression.
    pub initial_size: usize,
    /// Node count of the output expression.
    pub final_size: usize,
    /// `1 - final_size / initial_size`.
    ///
    /// It is `0.0` when `initial_size` is zero, and negative if the expression grew.
    pub reduction_ratio: f64,
}
190
191impl OptimizationMetrics {
192 pub fn new() -> Self {
194 Self::default()
195 }
196
197 pub fn record_pass(&mut self, pass: OptimizationPass) {
199 self.passes_applied += 1;
200 *self.pass_counts.entry(pass.name().to_string()).or_insert(0) += 1;
201 }
202
203 pub fn finalize(&mut self, initial_size: usize, final_size: usize) {
205 self.initial_size = initial_size;
206 self.final_size = final_size;
207 self.reduction_ratio = if initial_size > 0 {
208 1.0 - (final_size as f64 / initial_size as f64)
209 } else {
210 0.0
211 };
212 }
213}
214
215#[derive(Clone, Debug)]
217pub struct PipelineConfig {
218 pub level: OptimizationLevel,
220 pub max_iterations: usize,
222 pub custom_passes: Option<Vec<OptimizationPass>>,
224 pub enable_convergence: bool,
226}
227
228impl Default for PipelineConfig {
229 fn default() -> Self {
230 Self {
231 level: OptimizationLevel::Standard,
232 max_iterations: 10,
233 custom_passes: None,
234 enable_convergence: true,
235 }
236 }
237}
238
239impl PipelineConfig {
240 pub fn with_level(level: OptimizationLevel) -> Self {
242 Self {
243 level,
244 ..Default::default()
245 }
246 }
247
248 pub fn with_custom_passes(mut self, passes: Vec<OptimizationPass>) -> Self {
250 self.custom_passes = Some(passes);
251 self
252 }
253
254 pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
256 self.max_iterations = max_iterations;
257 self
258 }
259
260 pub fn without_convergence(mut self) -> Self {
262 self.enable_convergence = false;
263 self
264 }
265}
266
267#[derive(Default)]
269pub struct OptimizationPipeline {
270 config: PipelineConfig,
271}
272
273impl OptimizationPipeline {
274 pub fn new(config: PipelineConfig) -> Self {
276 Self { config }
277 }
278
279 pub fn with_level(level: OptimizationLevel) -> Self {
281 Self::new(PipelineConfig::with_level(level))
282 }
283
284 pub fn optimize(&self, expr: TLExpr) -> (TLExpr, OptimizationMetrics) {
288 let mut current = expr;
289 let mut metrics = OptimizationMetrics::new();
290 let initial_size = count_nodes(¤t);
291
292 let passes = self
294 .config
295 .custom_passes
296 .clone()
297 .unwrap_or_else(|| OptimizationPass::for_level(self.config.level));
298
299 let mut sorted_passes = passes.clone();
301 sorted_passes.sort_by_key(|p| p.priority());
302
303 for iteration in 0..self.config.max_iterations {
305 metrics.iterations = iteration + 1;
306 let previous = current.clone();
307
308 for pass in &sorted_passes {
310 let before = current.clone();
311 current = pass.apply(current);
312
313 if before != current {
315 metrics.record_pass(*pass);
316 }
317 }
318
319 if self.config.enable_convergence && current == previous {
321 metrics.converged = true;
322 break;
323 }
324 }
325
326 let final_size = count_nodes(¤t);
327 metrics.finalize(initial_size, final_size);
328
329 (current, metrics)
330 }
331
332 pub fn apply_pass(&self, expr: TLExpr, pass: OptimizationPass) -> TLExpr {
334 pass.apply(expr)
335 }
336
337 pub fn config(&self) -> &PipelineConfig {
339 &self.config
340 }
341}
342
/// Counts the nodes of an expression tree.
///
/// This is the size measure behind the pipeline's `initial_size`, `final_size` and
/// `reduction_ratio` metrics. Every variant contributes 1 for itself plus the counts of the
/// `TLExpr` children it holds.
///
/// Leaves (`Pred`, `Constant`, `EmptySet`, `Nominal`, `AllDifferent`, `Abducible`) count as 1.
/// Fields elided with `..` in the patterns below — including a predicate's arguments — are
/// not counted.
///
/// The function is recursive, so stack depth grows with the expression's nesting depth.
fn count_nodes(expr: &TLExpr) -> usize {
    match expr {
        // Leaves: a predicate counts as one node regardless of its arguments.
        TLExpr::Pred { .. } | TLExpr::Constant(_) => 1,
        // Binary tuple variants (logical, arithmetic, comparison): node + both operands.
        TLExpr::And(l, r)
        | TLExpr::Or(l, r)
        | TLExpr::Imply(l, r)
        | TLExpr::Add(l, r)
        | TLExpr::Sub(l, r)
        | TLExpr::Mul(l, r)
        | TLExpr::Div(l, r)
        | TLExpr::Pow(l, r)
        | TLExpr::Mod(l, r)
        | TLExpr::Min(l, r)
        | TLExpr::Max(l, r)
        | TLExpr::Eq(l, r)
        | TLExpr::Lt(l, r)
        | TLExpr::Gt(l, r)
        | TLExpr::Lte(l, r)
        | TLExpr::Gte(l, r) => 1 + count_nodes(l) + count_nodes(r),
        // Unary tuple variants: node + operand.
        TLExpr::Not(e)
        | TLExpr::Score(e)
        | TLExpr::Abs(e)
        | TLExpr::Floor(e)
        | TLExpr::Ceil(e)
        | TLExpr::Round(e)
        | TLExpr::Sqrt(e)
        | TLExpr::Exp(e)
        | TLExpr::Log(e)
        | TLExpr::Sin(e)
        | TLExpr::Cos(e)
        | TLExpr::Tan(e)
        | TLExpr::Box(e)
        | TLExpr::Diamond(e)
        | TLExpr::Next(e)
        | TLExpr::Eventually(e)
        | TLExpr::Always(e) => 1 + count_nodes(e),
        // Binary temporal operators. `Release`/`StrongRelease` fields are rebound to the
        // names `before`/`after` so that one or-pattern covers all four variants.
        TLExpr::Until { before, after }
        | TLExpr::Release {
            released: before,
            releaser: after,
        }
        | TLExpr::WeakUntil { before, after }
        | TLExpr::StrongRelease {
            released: before,
            releaser: after,
        } => 1 + count_nodes(before) + count_nodes(after),
        // Variants with a single sub-expression; their other fields are not counted.
        TLExpr::Exists { body, .. }
        | TLExpr::ForAll { body, .. }
        | TLExpr::SoftExists { body, .. }
        | TLExpr::SoftForAll { body, .. }
        | TLExpr::Aggregate { body, .. }
        | TLExpr::WeightedRule { rule: body, .. }
        | TLExpr::FuzzyNot { expr: body, .. } => 1 + count_nodes(body),
        // Binary fuzzy operators; non-expression fields are elided.
        TLExpr::TNorm { left, right, .. }
        | TLExpr::TCoNorm { left, right, .. }
        | TLExpr::FuzzyImplication {
            premise: left,
            conclusion: right,
            ..
        } => 1 + count_nodes(left) + count_nodes(right),
        // Sums the expression in each alternative pair. The first tuple element is ignored
        // (presumably the alternative's probability — confirm against `TLExpr`).
        TLExpr::ProbabilisticChoice { alternatives } => {
            1 + alternatives
                .iter()
                .map(|(_, e)| count_nodes(e))
                .sum::<usize>()
        }
        TLExpr::IfThenElse {
            condition,
            then_branch,
            else_branch,
        } => 1 + count_nodes(condition) + count_nodes(then_branch) + count_nodes(else_branch),
        TLExpr::Let { value, body, .. } => 1 + count_nodes(value) + count_nodes(body),

        TLExpr::Lambda { body, .. } => 1 + count_nodes(body),
        TLExpr::Apply { function, argument } => 1 + count_nodes(function) + count_nodes(argument),
        // Binary set operations; `left`/`right` are rebound to `element`/`set` so that one
        // or-pattern covers them all.
        TLExpr::SetMembership { element, set }
        | TLExpr::SetUnion {
            left: element,
            right: set,
        }
        | TLExpr::SetIntersection {
            left: element,
            right: set,
        }
        | TLExpr::SetDifference {
            left: element,
            right: set,
        } => 1 + count_nodes(element) + count_nodes(set),
        TLExpr::SetCardinality { set } => 1 + count_nodes(set),
        TLExpr::EmptySet => 1,
        TLExpr::SetComprehension { condition, .. } => 1 + count_nodes(condition),
        // Counting quantifiers: only the body is counted.
        TLExpr::CountingExists { body, .. }
        | TLExpr::CountingForAll { body, .. }
        | TLExpr::ExactCount { body, .. }
        | TLExpr::Majority { body, .. } => 1 + count_nodes(body),
        TLExpr::LeastFixpoint { body, .. } | TLExpr::GreatestFixpoint { body, .. } => {
            1 + count_nodes(body)
        }
        TLExpr::Nominal { .. } => 1,
        TLExpr::At { formula, .. } => 1 + count_nodes(formula),
        TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => 1 + count_nodes(formula),
        TLExpr::AllDifferent { .. } => 1,
        // Only the `values` sub-expressions are counted; other fields are elided.
        TLExpr::GlobalCardinality { values, .. } => {
            1 + values.iter().map(count_nodes).sum::<usize>()
        }
        TLExpr::Abducible { .. } => 1,
        TLExpr::Explain { formula } => 1 + count_nodes(formula),
    }
}
454
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Term;
    use std::collections::HashSet;

    /// Every pass variant, used by tests that need to enumerate them.
    const ALL_PASSES: [OptimizationPass; 10] = [
        OptimizationPass::ConstantFolding,
        OptimizationPass::ConstantPropagation,
        OptimizationPass::AlgebraicSimplification,
        OptimizationPass::NegationNormalForm,
        OptimizationPass::ModalEquivalences,
        OptimizationPass::TemporalEquivalences,
        OptimizationPass::DistributiveAndOverOr,
        OptimizationPass::DistributiveOrOverAnd,
        OptimizationPass::DistributiveQuantifiers,
        OptimizationPass::DistributiveModal,
    ];

    #[test]
    fn test_optimization_level_ordering() {
        assert!(OptimizationLevel::None < OptimizationLevel::Basic);
        assert!(OptimizationLevel::Basic < OptimizationLevel::Standard);
        assert!(OptimizationLevel::Standard < OptimizationLevel::Aggressive);
        assert_eq!(OptimizationLevel::default(), OptimizationLevel::Standard);
    }

    #[test]
    fn test_pass_priority_ordering() {
        // Every level's pass list must already be in strictly ascending priority order.
        for level in [
            OptimizationLevel::None,
            OptimizationLevel::Basic,
            OptimizationLevel::Standard,
            OptimizationLevel::Aggressive,
        ] {
            let priorities: Vec<u32> = OptimizationPass::for_level(level)
                .iter()
                .map(OptimizationPass::priority)
                .collect();
            assert!(
                priorities.windows(2).all(|w| w[0] < w[1]),
                "passes for {:?} are not in priority order: {:?}",
                level,
                priorities
            );
        }

        let passes = OptimizationPass::for_level(OptimizationLevel::Aggressive);
        assert_eq!(passes[0], OptimizationPass::ConstantFolding);
        assert_eq!(passes[0].priority(), 10);
    }

    #[test]
    fn test_pass_names_are_unique() {
        // Names key `pass_counts`; a duplicate would silently merge two passes' counters.
        let names: HashSet<&str> = ALL_PASSES.iter().map(OptimizationPass::name).collect();
        assert_eq!(names.len(), ALL_PASSES.len());
    }

    #[test]
    fn test_pipeline_basic_optimization() {
        let expr = TLExpr::and(
            TLExpr::constant(1.0),
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        let pipeline = OptimizationPipeline::with_level(OptimizationLevel::Basic);
        let (_optimized, metrics) = pipeline.optimize(expr);

        assert!(metrics.passes_applied > 0);
        assert!(metrics.reduction_ratio > 0.0);
        assert!(metrics.converged);
    }

    #[test]
    fn test_pipeline_no_optimization() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        let pipeline = OptimizationPipeline::with_level(OptimizationLevel::None);
        let (optimized, metrics) = pipeline.optimize(expr.clone());

        assert_eq!(optimized, expr);
        assert_eq!(metrics.passes_applied, 0);
        // With no passes, the first iteration is trivially a fixpoint.
        assert!(metrics.converged);
        assert_eq!(metrics.iterations, 1);
    }

    #[test]
    fn test_pipeline_convergence() {
        let expr = TLExpr::and(
            TLExpr::or(TLExpr::constant(1.0), TLExpr::constant(0.0)),
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        let pipeline = OptimizationPipeline::with_level(OptimizationLevel::Standard);
        let (_, metrics) = pipeline.optimize(expr);

        assert!(metrics.converged);
        assert!(metrics.iterations > 0);
    }

    #[test]
    fn test_pipeline_max_iterations() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        let config = PipelineConfig::default().with_max_iterations(5);
        let pipeline = OptimizationPipeline::new(config);
        let (_, metrics) = pipeline.optimize(expr);

        assert!(metrics.iterations <= 5);
    }

    #[test]
    fn test_pipeline_without_convergence_runs_all_iterations() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        let config = PipelineConfig::default()
            .with_max_iterations(3)
            .without_convergence();
        let (_, metrics) = OptimizationPipeline::new(config).optimize(expr);

        assert_eq!(metrics.iterations, 3);
        assert!(!metrics.converged);
    }

    #[test]
    fn test_pipeline_zero_iterations_leaves_expr_untouched() {
        let expr = TLExpr::and(
            TLExpr::constant(1.0),
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        let config = PipelineConfig::default().with_max_iterations(0);
        let (optimized, metrics) = OptimizationPipeline::new(config).optimize(expr.clone());

        assert_eq!(optimized, expr);
        assert_eq!(metrics.iterations, 0);
        assert!(!metrics.converged);
        assert_eq!(metrics.passes_applied, 0);
        assert_eq!(metrics.initial_size, 3);
        assert_eq!(metrics.final_size, 3);
        assert_eq!(metrics.reduction_ratio, 0.0);
    }

    #[test]
    fn test_custom_passes() {
        let expr = TLExpr::constant(42.0);

        let custom_passes = vec![
            OptimizationPass::ConstantFolding,
            OptimizationPass::AlgebraicSimplification,
        ];

        let config = PipelineConfig::default().with_custom_passes(custom_passes);
        let pipeline = OptimizationPipeline::new(config);
        let (_, metrics) = pipeline.optimize(expr);

        assert!(metrics.pass_counts.len() <= 2);
        // Only the configured passes may ever be recorded.
        assert!(metrics
            .pass_counts
            .keys()
            .all(|k| matches!(k.as_str(), "constant_folding" | "algebraic_simplification")));
    }

    #[test]
    fn test_pipeline_config_builders() {
        let config = PipelineConfig::default();
        assert_eq!(config.level, OptimizationLevel::Standard);
        assert_eq!(config.max_iterations, 10);
        assert!(config.custom_passes.is_none());
        assert!(config.enable_convergence);

        let config = PipelineConfig::with_level(OptimizationLevel::Aggressive)
            .with_max_iterations(3)
            .with_custom_passes(vec![OptimizationPass::ConstantFolding])
            .without_convergence();
        assert_eq!(config.level, OptimizationLevel::Aggressive);
        assert_eq!(config.max_iterations, 3);
        assert_eq!(
            config.custom_passes,
            Some(vec![OptimizationPass::ConstantFolding])
        );
        assert!(!config.enable_convergence);
    }

    #[test]
    fn test_metrics_record_and_finalize() {
        let mut metrics = OptimizationMetrics::new();
        metrics.record_pass(OptimizationPass::ConstantFolding);
        metrics.record_pass(OptimizationPass::ConstantFolding);
        metrics.record_pass(OptimizationPass::NegationNormalForm);

        assert_eq!(metrics.passes_applied, 3);
        assert_eq!(metrics.pass_counts.get("constant_folding"), Some(&2));
        assert_eq!(metrics.pass_counts.get("negation_normal_form"), Some(&1));

        metrics.finalize(10, 4);
        assert_eq!(metrics.initial_size, 10);
        assert_eq!(metrics.final_size, 4);
        assert!((metrics.reduction_ratio - 0.6).abs() < 1e-12);

        // An empty input must not divide by zero.
        metrics.finalize(0, 0);
        assert_eq!(metrics.reduction_ratio, 0.0);

        // Growth yields a negative ratio.
        metrics.finalize(2, 4);
        assert!((metrics.reduction_ratio + 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_metrics_tracking() {
        let expr = TLExpr::and(
            TLExpr::constant(1.0),
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        let pipeline = OptimizationPipeline::with_level(OptimizationLevel::Standard);
        let (_, metrics) = pipeline.optimize(expr);

        assert!(metrics.initial_size > metrics.final_size);
        assert!(metrics.reduction_ratio > 0.0);
        assert!(metrics.reduction_ratio <= 1.0);
    }

    #[test]
    fn test_count_nodes_simple() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        assert_eq!(count_nodes(&expr), 1);
    }

    #[test]
    fn test_count_nodes_complex() {
        let expr = TLExpr::and(
            TLExpr::pred("P", vec![Term::var("x")]),
            TLExpr::or(
                TLExpr::pred("Q", vec![Term::var("y")]),
                TLExpr::pred("R", vec![Term::var("z")]),
            ),
        );
        assert_eq!(count_nodes(&expr), 5);
    }

    #[test]
    fn test_pipeline_aggressive_level() {
        let expr = TLExpr::and(
            TLExpr::or(
                TLExpr::pred("P", vec![Term::var("x")]),
                TLExpr::pred("Q", vec![Term::var("x")]),
            ),
            TLExpr::pred("R", vec![Term::var("x")]),
        );

        let pipeline = OptimizationPipeline::with_level(OptimizationLevel::Aggressive);
        let (_, metrics) = pipeline.optimize(expr);

        assert!(metrics.passes_applied > 0);
    }

    #[test]
    fn test_pass_application() {
        let expr = TLExpr::constant(1.0);
        let pipeline = OptimizationPipeline::default();

        let result = pipeline.apply_pass(expr.clone(), OptimizationPass::ConstantFolding);
        assert_eq!(result, expr);
    }
}