1use std::collections::HashMap;
31
32use super::{
33 distributive_laws::{apply_distributive_laws, DistributiveStrategy},
34 modal_equivalences::apply_modal_equivalences,
35 normal_forms::to_nnf,
36 optimization::{algebraic_simplify, constant_fold, propagate_constants},
37 temporal_equivalences::apply_temporal_equivalences,
38 TLExpr,
39};
40
/// How much work the optimization pipeline does on an expression.
///
/// Variants are declared from least to most aggressive, so the derived
/// ordering is `None < Basic < Standard < Aggressive`. `Standard` is the
/// default.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum OptimizationLevel {
    /// Run no passes at all.
    None,
    /// Constant folding, constant propagation and algebraic simplification.
    Basic,
    /// `Basic` plus negation normal form and modal/temporal equivalences.
    #[default]
    Standard,
    /// `Standard` plus the quantifier, modal and AND-over-OR distributive passes.
    Aggressive,
}
54
/// One rewrite pass that the optimization pipeline can run over a `TLExpr`.
///
/// Each variant dispatches to a rewrite function in a sibling module; the
/// function is named in the variant's doc.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum OptimizationPass {
    /// Runs `constant_fold`.
    ConstantFolding,
    /// Runs `propagate_constants`.
    ConstantPropagation,
    /// Runs `algebraic_simplify`.
    AlgebraicSimplification,
    /// Runs `to_nnf`.
    NegationNormalForm,
    /// Runs `apply_modal_equivalences`.
    ModalEquivalences,
    /// Runs `apply_temporal_equivalences`.
    TemporalEquivalences,
    /// Runs `apply_distributive_laws` with `DistributiveStrategy::AndOverOr`.
    DistributiveAndOverOr,
    /// Runs `apply_distributive_laws` with `DistributiveStrategy::OrOverAnd`.
    DistributiveOrOverAnd,
    /// Runs `apply_distributive_laws` with `DistributiveStrategy::Quantifiers`.
    DistributiveQuantifiers,
    /// Runs `apply_distributive_laws` with `DistributiveStrategy::Modal`.
    DistributiveModal,
}
79
80impl OptimizationPass {
81 pub fn name(&self) -> &'static str {
83 match self {
84 OptimizationPass::ConstantFolding => "constant_folding",
85 OptimizationPass::ConstantPropagation => "constant_propagation",
86 OptimizationPass::AlgebraicSimplification => "algebraic_simplification",
87 OptimizationPass::NegationNormalForm => "negation_normal_form",
88 OptimizationPass::ModalEquivalences => "modal_equivalences",
89 OptimizationPass::TemporalEquivalences => "temporal_equivalences",
90 OptimizationPass::DistributiveAndOverOr => "distributive_and_over_or",
91 OptimizationPass::DistributiveOrOverAnd => "distributive_or_over_and",
92 OptimizationPass::DistributiveQuantifiers => "distributive_quantifiers",
93 OptimizationPass::DistributiveModal => "distributive_modal",
94 }
95 }
96
97 pub fn apply(&self, expr: TLExpr) -> TLExpr {
99 match self {
100 OptimizationPass::ConstantFolding => constant_fold(&expr),
101 OptimizationPass::ConstantPropagation => propagate_constants(&expr),
102 OptimizationPass::AlgebraicSimplification => algebraic_simplify(&expr),
103 OptimizationPass::NegationNormalForm => to_nnf(&expr),
104 OptimizationPass::ModalEquivalences => apply_modal_equivalences(&expr),
105 OptimizationPass::TemporalEquivalences => apply_temporal_equivalences(&expr),
106 OptimizationPass::DistributiveAndOverOr => {
107 apply_distributive_laws(&expr, DistributiveStrategy::AndOverOr)
108 }
109 OptimizationPass::DistributiveOrOverAnd => {
110 apply_distributive_laws(&expr, DistributiveStrategy::OrOverAnd)
111 }
112 OptimizationPass::DistributiveQuantifiers => {
113 apply_distributive_laws(&expr, DistributiveStrategy::Quantifiers)
114 }
115 OptimizationPass::DistributiveModal => {
116 apply_distributive_laws(&expr, DistributiveStrategy::Modal)
117 }
118 }
119 }
120
121 pub fn priority(&self) -> u32 {
123 match self {
124 OptimizationPass::ConstantFolding => 10,
126 OptimizationPass::ConstantPropagation => 20,
127 OptimizationPass::NegationNormalForm => 30,
128 OptimizationPass::AlgebraicSimplification => 40,
130 OptimizationPass::ModalEquivalences => 50,
131 OptimizationPass::TemporalEquivalences => 60,
132 OptimizationPass::DistributiveQuantifiers => 70,
134 OptimizationPass::DistributiveModal => 80,
135 OptimizationPass::DistributiveAndOverOr => 90,
136 OptimizationPass::DistributiveOrOverAnd => 100,
137 }
138 }
139
140 pub fn for_level(level: OptimizationLevel) -> Vec<OptimizationPass> {
142 match level {
143 OptimizationLevel::None => vec![],
144 OptimizationLevel::Basic => vec![
145 OptimizationPass::ConstantFolding,
146 OptimizationPass::ConstantPropagation,
147 OptimizationPass::AlgebraicSimplification,
148 ],
149 OptimizationLevel::Standard => vec![
150 OptimizationPass::ConstantFolding,
151 OptimizationPass::ConstantPropagation,
152 OptimizationPass::NegationNormalForm,
153 OptimizationPass::AlgebraicSimplification,
154 OptimizationPass::ModalEquivalences,
155 OptimizationPass::TemporalEquivalences,
156 ],
157 OptimizationLevel::Aggressive => vec![
158 OptimizationPass::ConstantFolding,
159 OptimizationPass::ConstantPropagation,
160 OptimizationPass::NegationNormalForm,
161 OptimizationPass::AlgebraicSimplification,
162 OptimizationPass::ModalEquivalences,
163 OptimizationPass::TemporalEquivalences,
164 OptimizationPass::DistributiveQuantifiers,
165 OptimizationPass::DistributiveModal,
166 OptimizationPass::DistributiveAndOverOr,
167 ],
168 }
169 }
170}
171
/// Statistics gathered by one run of the optimization pipeline.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct OptimizationMetrics {
    /// Total number of recorded pass applications (the pipeline records only
    /// passes that changed the expression).
    pub passes_applied: usize,
    /// Number of pipeline iterations that were started.
    pub iterations: usize,
    /// Set when convergence checking is enabled and a full iteration left the
    /// expression unchanged.
    pub converged: bool,
    /// Recorded applications per pass, keyed by `OptimizationPass::name`.
    pub pass_counts: HashMap<String, usize>,
    /// Node count of the expression before optimization.
    pub initial_size: usize,
    /// Node count of the expression after optimization.
    pub final_size: usize,
    /// `1 - final_size / initial_size`. Negative if the expression grew;
    /// `0.0` when `initial_size` is zero.
    pub reduction_ratio: f64,
}
190
191impl OptimizationMetrics {
192 pub fn new() -> Self {
194 Self::default()
195 }
196
197 pub fn record_pass(&mut self, pass: OptimizationPass) {
199 self.passes_applied += 1;
200 *self.pass_counts.entry(pass.name().to_string()).or_insert(0) += 1;
201 }
202
203 pub fn finalize(&mut self, initial_size: usize, final_size: usize) {
205 self.initial_size = initial_size;
206 self.final_size = final_size;
207 self.reduction_ratio = if initial_size > 0 {
208 1.0 - (final_size as f64 / initial_size as f64)
209 } else {
210 0.0
211 };
212 }
213}
214
215#[derive(Clone, Debug)]
217pub struct PipelineConfig {
218 pub level: OptimizationLevel,
220 pub max_iterations: usize,
222 pub custom_passes: Option<Vec<OptimizationPass>>,
224 pub enable_convergence: bool,
226}
227
228impl Default for PipelineConfig {
229 fn default() -> Self {
230 Self {
231 level: OptimizationLevel::Standard,
232 max_iterations: 10,
233 custom_passes: None,
234 enable_convergence: true,
235 }
236 }
237}
238
239impl PipelineConfig {
240 pub fn with_level(level: OptimizationLevel) -> Self {
242 Self {
243 level,
244 ..Default::default()
245 }
246 }
247
248 pub fn with_custom_passes(mut self, passes: Vec<OptimizationPass>) -> Self {
250 self.custom_passes = Some(passes);
251 self
252 }
253
254 pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
256 self.max_iterations = max_iterations;
257 self
258 }
259
260 pub fn without_convergence(mut self) -> Self {
262 self.enable_convergence = false;
263 self
264 }
265}
266
267#[derive(Default)]
269pub struct OptimizationPipeline {
270 config: PipelineConfig,
271}
272
273impl OptimizationPipeline {
274 pub fn new(config: PipelineConfig) -> Self {
276 Self { config }
277 }
278
279 pub fn with_level(level: OptimizationLevel) -> Self {
281 Self::new(PipelineConfig::with_level(level))
282 }
283
284 pub fn optimize(&self, expr: TLExpr) -> (TLExpr, OptimizationMetrics) {
288 let mut current = expr;
289 let mut metrics = OptimizationMetrics::new();
290 let initial_size = count_nodes(¤t);
291
292 let passes = self
294 .config
295 .custom_passes
296 .clone()
297 .unwrap_or_else(|| OptimizationPass::for_level(self.config.level));
298
299 let mut sorted_passes = passes.clone();
301 sorted_passes.sort_by_key(|p| p.priority());
302
303 for iteration in 0..self.config.max_iterations {
305 metrics.iterations = iteration + 1;
306 let previous = current.clone();
307
308 for pass in &sorted_passes {
310 let before = current.clone();
311 current = pass.apply(current);
312
313 if before != current {
315 metrics.record_pass(*pass);
316 }
317 }
318
319 if self.config.enable_convergence && current == previous {
321 metrics.converged = true;
322 break;
323 }
324 }
325
326 let final_size = count_nodes(¤t);
327 metrics.finalize(initial_size, final_size);
328
329 (current, metrics)
330 }
331
332 pub fn apply_pass(&self, expr: TLExpr, pass: OptimizationPass) -> TLExpr {
334 pass.apply(expr)
335 }
336
337 pub fn config(&self) -> &PipelineConfig {
339 &self.config
340 }
341}
342
/// Counts the nodes in the expression tree rooted at `expr`.
///
/// Every `TLExpr` variant contributes one node, plus the nodes of its
/// `TLExpr` children. Payloads that are not sub-expressions are not counted:
/// predicate arguments, fields matched with `..`, and the first element of
/// each `ProbabilisticChoice` alternative. The optimization pipeline uses this
/// for its size metrics.
///
/// The function is recursive, so stack depth grows with the depth of the tree.
fn count_nodes(expr: &TLExpr) -> usize {
    match expr {
        // Leaves.
        TLExpr::Pred { .. } | TLExpr::Constant(_) => 1,
        // Two-operand variants: logical connectives, arithmetic and comparisons.
        TLExpr::And(l, r)
        | TLExpr::Or(l, r)
        | TLExpr::Imply(l, r)
        | TLExpr::Add(l, r)
        | TLExpr::Sub(l, r)
        | TLExpr::Mul(l, r)
        | TLExpr::Div(l, r)
        | TLExpr::Pow(l, r)
        | TLExpr::Mod(l, r)
        | TLExpr::Min(l, r)
        | TLExpr::Max(l, r)
        | TLExpr::Eq(l, r)
        | TLExpr::Lt(l, r)
        | TLExpr::Gt(l, r)
        | TLExpr::Lte(l, r)
        | TLExpr::Gte(l, r) => 1 + count_nodes(l) + count_nodes(r),
        // Single-operand variants.
        TLExpr::Not(e)
        | TLExpr::Score(e)
        | TLExpr::Abs(e)
        | TLExpr::Floor(e)
        | TLExpr::Ceil(e)
        | TLExpr::Round(e)
        | TLExpr::Sqrt(e)
        | TLExpr::Exp(e)
        | TLExpr::Log(e)
        | TLExpr::Sin(e)
        | TLExpr::Cos(e)
        | TLExpr::Tan(e)
        | TLExpr::Box(e)
        | TLExpr::Diamond(e)
        | TLExpr::Next(e)
        | TLExpr::Eventually(e)
        | TLExpr::Always(e) => 1 + count_nodes(e),
        // Two-operand temporal variants. The Release fields are bound to
        // `before`/`after` only so that all four can share one arm.
        TLExpr::Until { before, after }
        | TLExpr::Release {
            released: before,
            releaser: after,
        }
        | TLExpr::WeakUntil { before, after }
        | TLExpr::StrongRelease {
            released: before,
            releaser: after,
        } => 1 + count_nodes(before) + count_nodes(after),
        // Variants with a single sub-expression body; their other fields are ignored.
        TLExpr::Exists { body, .. }
        | TLExpr::ForAll { body, .. }
        | TLExpr::SoftExists { body, .. }
        | TLExpr::SoftForAll { body, .. }
        | TLExpr::Aggregate { body, .. }
        | TLExpr::WeightedRule { rule: body, .. }
        | TLExpr::FuzzyNot { expr: body, .. } => 1 + count_nodes(body),
        // Fuzzy two-operand variants; their other fields are ignored.
        TLExpr::TNorm { left, right, .. }
        | TLExpr::TCoNorm { left, right, .. }
        | TLExpr::FuzzyImplication {
            premise: left,
            conclusion: right,
            ..
        } => 1 + count_nodes(left) + count_nodes(right),
        // One node for the choice itself plus each alternative's expression.
        TLExpr::ProbabilisticChoice { alternatives } => {
            1 + alternatives
                .iter()
                .map(|(_, e)| count_nodes(e))
                .sum::<usize>()
        }
        TLExpr::IfThenElse {
            condition,
            then_branch,
            else_branch,
        } => 1 + count_nodes(condition) + count_nodes(then_branch) + count_nodes(else_branch),
        TLExpr::Let { value, body, .. } => 1 + count_nodes(value) + count_nodes(body),
    }
}
418
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Term;
    use std::collections::HashSet;

    /// Every pass variant, for exhaustive checks.
    const ALL_PASSES: [OptimizationPass; 10] = [
        OptimizationPass::ConstantFolding,
        OptimizationPass::ConstantPropagation,
        OptimizationPass::AlgebraicSimplification,
        OptimizationPass::NegationNormalForm,
        OptimizationPass::ModalEquivalences,
        OptimizationPass::TemporalEquivalences,
        OptimizationPass::DistributiveAndOverOr,
        OptimizationPass::DistributiveOrOverAnd,
        OptimizationPass::DistributiveQuantifiers,
        OptimizationPass::DistributiveModal,
    ];

    #[test]
    fn test_optimization_level_ordering() {
        assert!(OptimizationLevel::None < OptimizationLevel::Basic);
        assert!(OptimizationLevel::Basic < OptimizationLevel::Standard);
        assert!(OptimizationLevel::Standard < OptimizationLevel::Aggressive);
    }

    #[test]
    fn test_defaults_use_standard_level() {
        assert_eq!(OptimizationLevel::default(), OptimizationLevel::Standard);

        let config = PipelineConfig::default();
        assert_eq!(config.level, OptimizationLevel::Standard);
        assert_eq!(config.max_iterations, 10);
        assert!(config.custom_passes.is_none());
        assert!(config.enable_convergence);
    }

    #[test]
    fn test_pass_priority_ordering() {
        // Every level's list is already in strictly ascending priority order.
        for level in [
            OptimizationLevel::None,
            OptimizationLevel::Basic,
            OptimizationLevel::Standard,
            OptimizationLevel::Aggressive,
        ] {
            let priorities: Vec<u32> = OptimizationPass::for_level(level)
                .iter()
                .map(|p| p.priority())
                .collect();
            assert!(
                priorities.windows(2).all(|w| w[0] < w[1]),
                "{level:?}: {priorities:?}"
            );
        }

        let passes = OptimizationPass::for_level(OptimizationLevel::Aggressive);
        assert_eq!(passes[0], OptimizationPass::ConstantFolding);
        assert_eq!(passes[0].priority(), 10);
    }

    #[test]
    fn test_levels_are_nested_and_exclude_or_over_and() {
        let basic = OptimizationPass::for_level(OptimizationLevel::Basic);
        let standard = OptimizationPass::for_level(OptimizationLevel::Standard);
        let aggressive = OptimizationPass::for_level(OptimizationLevel::Aggressive);

        assert!(OptimizationPass::for_level(OptimizationLevel::None).is_empty());
        assert!(basic.iter().all(|p| standard.contains(p)));
        assert!(standard.iter().all(|p| aggressive.contains(p)));
        assert!(!aggressive.contains(&OptimizationPass::DistributiveOrOverAnd));
    }

    #[test]
    fn test_pass_names_are_unique() {
        let names: HashSet<&str> = ALL_PASSES.iter().map(|p| p.name()).collect();
        assert_eq!(names.len(), ALL_PASSES.len());
    }

    #[test]
    fn test_pipeline_basic_optimization() {
        let expr = TLExpr::and(
            TLExpr::constant(1.0),
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        let pipeline = OptimizationPipeline::with_level(OptimizationLevel::Basic);
        let (_optimized, metrics) = pipeline.optimize(expr);

        assert!(metrics.passes_applied > 0);
        assert!(metrics.reduction_ratio > 0.0);
        assert!(metrics.converged);
    }

    #[test]
    fn test_pipeline_no_optimization() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        let pipeline = OptimizationPipeline::with_level(OptimizationLevel::None);
        let (optimized, metrics) = pipeline.optimize(expr.clone());

        assert_eq!(optimized, expr);
        assert_eq!(metrics.passes_applied, 0);
        // With no passes, the first sweep trivially leaves the expression unchanged.
        assert_eq!(metrics.iterations, 1);
        assert!(metrics.converged);
        assert_eq!(metrics.reduction_ratio, 0.0);
    }

    #[test]
    fn test_pipeline_convergence() {
        let expr = TLExpr::and(
            TLExpr::or(TLExpr::constant(1.0), TLExpr::constant(0.0)),
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        let pipeline = OptimizationPipeline::with_level(OptimizationLevel::Standard);
        let (_, metrics) = pipeline.optimize(expr);

        assert!(metrics.converged);
        assert!(metrics.iterations > 0);
    }

    #[test]
    fn test_pipeline_max_iterations() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        let config = PipelineConfig::default().with_max_iterations(5);
        let pipeline = OptimizationPipeline::new(config);
        let (_, metrics) = pipeline.optimize(expr);

        assert!(metrics.iterations <= 5);
    }

    #[test]
    fn test_pipeline_without_convergence_runs_every_iteration() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        let config = PipelineConfig::default()
            .with_max_iterations(5)
            .without_convergence();
        let (_, metrics) = OptimizationPipeline::new(config).optimize(expr);

        // Without the early exit, the iteration cap is actually reached.
        assert_eq!(metrics.iterations, 5);
        assert!(!metrics.converged);
    }

    #[test]
    fn test_pipeline_zero_iterations_returns_input() {
        let expr = TLExpr::and(
            TLExpr::constant(1.0),
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        let config = PipelineConfig::default().with_max_iterations(0);
        let (optimized, metrics) = OptimizationPipeline::new(config).optimize(expr.clone());

        assert_eq!(optimized, expr);
        assert_eq!(metrics.iterations, 0);
        assert_eq!(metrics.passes_applied, 0);
        assert!(!metrics.converged);
        assert_eq!(metrics.initial_size, metrics.final_size);
    }

    #[test]
    fn test_custom_passes() {
        let expr = TLExpr::constant(42.0);

        let custom_passes = vec![
            OptimizationPass::ConstantFolding,
            OptimizationPass::AlgebraicSimplification,
        ];

        let config = PipelineConfig::default().with_custom_passes(custom_passes);
        let pipeline = OptimizationPipeline::new(config);
        let (_, metrics) = pipeline.optimize(expr);

        assert!(metrics.pass_counts.len() <= 2);
        // Only the requested passes may ever be recorded.
        assert!(metrics
            .pass_counts
            .keys()
            .all(|name| name == "constant_folding" || name == "algebraic_simplification"));
    }

    #[test]
    fn test_metrics_tracking() {
        let expr = TLExpr::and(
            TLExpr::constant(1.0),
            TLExpr::pred("P", vec![Term::var("x")]),
        );

        let pipeline = OptimizationPipeline::with_level(OptimizationLevel::Standard);
        let (_, metrics) = pipeline.optimize(expr);

        assert!(metrics.initial_size > metrics.final_size);
        assert!(metrics.reduction_ratio > 0.0);
        assert!(metrics.reduction_ratio <= 1.0);
        // The per-pass counts always add up to the total.
        assert_eq!(
            metrics.pass_counts.values().sum::<usize>(),
            metrics.passes_applied
        );
    }

    #[test]
    fn test_metrics_record_and_finalize() {
        let mut metrics = OptimizationMetrics::new();
        metrics.record_pass(OptimizationPass::ConstantFolding);
        metrics.record_pass(OptimizationPass::ConstantFolding);
        metrics.record_pass(OptimizationPass::NegationNormalForm);

        assert_eq!(metrics.passes_applied, 3);
        assert_eq!(metrics.pass_counts.get("constant_folding"), Some(&2));
        assert_eq!(metrics.pass_counts.get("negation_normal_form"), Some(&1));

        metrics.finalize(10, 5);
        assert_eq!(metrics.initial_size, 10);
        assert_eq!(metrics.final_size, 5);
        assert_eq!(metrics.reduction_ratio, 0.5);

        // An empty input must not divide by zero.
        metrics.finalize(0, 0);
        assert_eq!(metrics.reduction_ratio, 0.0);
    }

    #[test]
    fn test_count_nodes_simple() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        assert_eq!(count_nodes(&expr), 1);
    }

    #[test]
    fn test_count_nodes_complex() {
        let expr = TLExpr::and(
            TLExpr::pred("P", vec![Term::var("x")]),
            TLExpr::or(
                TLExpr::pred("Q", vec![Term::var("y")]),
                TLExpr::pred("R", vec![Term::var("z")]),
            ),
        );
        assert_eq!(count_nodes(&expr), 5);
    }

    #[test]
    fn test_pipeline_aggressive_level() {
        let expr = TLExpr::and(
            TLExpr::or(
                TLExpr::pred("P", vec![Term::var("x")]),
                TLExpr::pred("Q", vec![Term::var("x")]),
            ),
            TLExpr::pred("R", vec![Term::var("x")]),
        );

        let pipeline = OptimizationPipeline::with_level(OptimizationLevel::Aggressive);
        let (_, metrics) = pipeline.optimize(expr);

        assert!(metrics.passes_applied > 0);
    }

    #[test]
    fn test_pass_application() {
        let expr = TLExpr::constant(1.0);
        let pipeline = OptimizationPipeline::default();

        let result = pipeline.apply_pass(expr.clone(), OptimizationPass::ConstantFolding);
        assert_eq!(result, expr);
    }
}