1use tensorlogic_ir::TLExpr;
14
/// Counters describing which rewrites fired during one (or, after
/// [`CanonicalStats::merge`], several) canonicalization runs.
///
/// Every rewrite increments both its specific counter and `total_rewrites`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CanonicalStats {
    /// Number of `Not(Not(x))` => `x` eliminations.
    pub double_neg_removed: usize,
    /// Number of `And`/`Or` operand lists whose order was changed by sorting.
    pub commutative_sorted: usize,
    /// Number of nested `And`/`Or` flattening rewrites recorded.
    pub nested_flattened: usize,
    /// Sum of all rewrites recorded above.
    pub total_rewrites: usize,
}
27
28impl CanonicalStats {
29 pub fn merge(&mut self, other: &CanonicalStats) {
31 self.double_neg_removed += other.double_neg_removed;
32 self.commutative_sorted += other.commutative_sorted;
33 self.nested_flattened += other.nested_flattened;
34 self.total_rewrites += other.total_rewrites;
35 }
36}
37
38impl std::fmt::Display for CanonicalStats {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 write!(
41 f,
42 "CanonicalStats {{ double_neg: {}, comm_sorted: {}, flattened: {}, total: {} }}",
43 self.double_neg_removed,
44 self.commutative_sorted,
45 self.nested_flattened,
46 self.total_rewrites
47 )
48 }
49}
50
/// Configurable rewriter that brings a `TLExpr` into a canonical form.
///
/// Each flag enables one rewrite rule; all are enabled by `Canonicalizer::new`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Canonicalizer {
    /// When true, operands of `And`/`Or` are stably sorted by `canonical_order_key`.
    pub sort_commutative: bool,
    /// When true, nested `And`/`Or` chains are flattened and rebuilt right-leaning.
    pub flatten_nested: bool,
    /// When true, `Not(Not(x))` is rewritten to `x`.
    pub elim_double_neg: bool,
}
61
62impl Default for Canonicalizer {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68impl Canonicalizer {
69 pub fn new() -> Self {
71 Canonicalizer {
72 sort_commutative: true,
73 flatten_nested: true,
74 elim_double_neg: true,
75 }
76 }
77
78 pub fn with_sort_commutative(mut self, v: bool) -> Self {
80 self.sort_commutative = v;
81 self
82 }
83
84 pub fn with_flatten_nested(mut self, v: bool) -> Self {
86 self.flatten_nested = v;
87 self
88 }
89
90 pub fn with_elim_double_neg(mut self, v: bool) -> Self {
92 self.elim_double_neg = v;
93 self
94 }
95
96 pub fn canonicalize(&self, expr: &TLExpr) -> (TLExpr, CanonicalStats) {
98 let mut stats = CanonicalStats::default();
99 let result = self.normalize(expr, &mut stats);
100 (result, stats)
101 }
102
103 pub fn canonical_key(&self, expr: &TLExpr) -> String {
105 let (normalized, _) = self.canonicalize(expr);
106 format!("{:?}", normalized)
107 }
108
109 fn normalize(&self, expr: &TLExpr, stats: &mut CanonicalStats) -> TLExpr {
110 match expr {
111 TLExpr::Not(inner) => {
113 if self.elim_double_neg {
114 if let TLExpr::Not(inner_inner) = inner.as_ref() {
115 stats.double_neg_removed += 1;
116 stats.total_rewrites += 1;
117 return self.normalize(inner_inner, stats);
118 }
119 }
120 TLExpr::negate(self.normalize(inner, stats))
121 }
122
123 TLExpr::And(a, b) => {
125 let norm_a = self.normalize(a, stats);
126 let norm_b = self.normalize(b, stats);
127 let mut operands = Vec::new();
128 if self.flatten_nested {
129 self.collect_and_operands(&norm_a, &mut operands, stats);
130 self.collect_and_operands(&norm_b, &mut operands, stats);
131 } else {
132 operands.push(norm_a);
133 operands.push(norm_b);
134 }
135 if self.sort_commutative {
136 let before = operands.iter().map(canonical_order_key).collect::<Vec<_>>();
137 operands.sort_by_key(canonical_order_key);
138 let after = operands.iter().map(canonical_order_key).collect::<Vec<_>>();
139 if before != after {
140 stats.commutative_sorted += 1;
141 stats.total_rewrites += 1;
142 }
143 }
144 self.build_right_leaning_and(operands)
145 }
146
147 TLExpr::Or(a, b) => {
149 let norm_a = self.normalize(a, stats);
150 let norm_b = self.normalize(b, stats);
151 let mut operands = Vec::new();
152 if self.flatten_nested {
153 self.collect_or_operands(&norm_a, &mut operands, stats);
154 self.collect_or_operands(&norm_b, &mut operands, stats);
155 } else {
156 operands.push(norm_a);
157 operands.push(norm_b);
158 }
159 if self.sort_commutative {
160 let before = operands.iter().map(canonical_order_key).collect::<Vec<_>>();
161 operands.sort_by_key(canonical_order_key);
162 let after = operands.iter().map(canonical_order_key).collect::<Vec<_>>();
163 if before != after {
164 stats.commutative_sorted += 1;
165 stats.total_rewrites += 1;
166 }
167 }
168 self.build_right_leaning_or(operands)
169 }
170
171 TLExpr::Imply(a, b) => {
173 TLExpr::imply(self.normalize(a, stats), self.normalize(b, stats))
174 }
175 TLExpr::Add(a, b) => TLExpr::add(self.normalize(a, stats), self.normalize(b, stats)),
176 TLExpr::Sub(a, b) => TLExpr::sub(self.normalize(a, stats), self.normalize(b, stats)),
177 TLExpr::Mul(a, b) => TLExpr::mul(self.normalize(a, stats), self.normalize(b, stats)),
178 TLExpr::Div(a, b) => TLExpr::div(self.normalize(a, stats), self.normalize(b, stats)),
179 TLExpr::Pow(a, b) => TLExpr::pow(self.normalize(a, stats), self.normalize(b, stats)),
180 TLExpr::Mod(a, b) => TLExpr::modulo(self.normalize(a, stats), self.normalize(b, stats)),
181 TLExpr::Min(a, b) => TLExpr::min(self.normalize(a, stats), self.normalize(b, stats)),
182 TLExpr::Max(a, b) => TLExpr::max(self.normalize(a, stats), self.normalize(b, stats)),
183 TLExpr::Eq(a, b) => TLExpr::eq(self.normalize(a, stats), self.normalize(b, stats)),
184 TLExpr::Lt(a, b) => TLExpr::lt(self.normalize(a, stats), self.normalize(b, stats)),
185 TLExpr::Gt(a, b) => TLExpr::gt(self.normalize(a, stats), self.normalize(b, stats)),
186 TLExpr::Lte(a, b) => TLExpr::lte(self.normalize(a, stats), self.normalize(b, stats)),
187 TLExpr::Gte(a, b) => TLExpr::gte(self.normalize(a, stats), self.normalize(b, stats)),
188
189 TLExpr::Abs(inner) => TLExpr::abs(self.normalize(inner, stats)),
191 TLExpr::Floor(inner) => TLExpr::floor(self.normalize(inner, stats)),
192 TLExpr::Ceil(inner) => TLExpr::ceil(self.normalize(inner, stats)),
193 TLExpr::Round(inner) => TLExpr::round(self.normalize(inner, stats)),
194 TLExpr::Sqrt(inner) => TLExpr::sqrt(self.normalize(inner, stats)),
195 TLExpr::Exp(inner) => TLExpr::exp(self.normalize(inner, stats)),
196 TLExpr::Log(inner) => TLExpr::log(self.normalize(inner, stats)),
197 TLExpr::Sin(inner) => TLExpr::sin(self.normalize(inner, stats)),
198 TLExpr::Cos(inner) => TLExpr::cos(self.normalize(inner, stats)),
199 TLExpr::Tan(inner) => TLExpr::tan(self.normalize(inner, stats)),
200 TLExpr::Score(inner) => TLExpr::score(self.normalize(inner, stats)),
201
202 TLExpr::Exists { var, domain, body } => {
204 TLExpr::exists(var.clone(), domain.clone(), self.normalize(body, stats))
205 }
206 TLExpr::ForAll { var, domain, body } => {
207 TLExpr::forall(var.clone(), domain.clone(), self.normalize(body, stats))
208 }
209
210 TLExpr::IfThenElse {
212 condition,
213 then_branch,
214 else_branch,
215 } => TLExpr::if_then_else(
216 self.normalize(condition, stats),
217 self.normalize(then_branch, stats),
218 self.normalize(else_branch, stats),
219 ),
220
221 TLExpr::Let { var, value, body } => TLExpr::let_binding(
223 var.clone(),
224 self.normalize(value, stats),
225 self.normalize(body, stats),
226 ),
227
228 TLExpr::Aggregate {
230 op,
231 var,
232 domain,
233 body,
234 group_by,
235 } => TLExpr::Aggregate {
236 op: op.clone(),
237 var: var.clone(),
238 domain: domain.clone(),
239 body: Box::new(self.normalize(body, stats)),
240 group_by: group_by.clone(),
241 },
242
243 TLExpr::TNorm { kind, left, right } => TLExpr::TNorm {
245 kind: *kind,
246 left: Box::new(self.normalize(left, stats)),
247 right: Box::new(self.normalize(right, stats)),
248 },
249 TLExpr::TCoNorm { kind, left, right } => TLExpr::TCoNorm {
250 kind: *kind,
251 left: Box::new(self.normalize(left, stats)),
252 right: Box::new(self.normalize(right, stats)),
253 },
254 TLExpr::FuzzyNot { kind, expr: inner } => TLExpr::FuzzyNot {
255 kind: *kind,
256 expr: Box::new(self.normalize(inner, stats)),
257 },
258 TLExpr::FuzzyImplication {
259 kind,
260 premise,
261 conclusion,
262 } => TLExpr::FuzzyImplication {
263 kind: *kind,
264 premise: Box::new(self.normalize(premise, stats)),
265 conclusion: Box::new(self.normalize(conclusion, stats)),
266 },
267
268 TLExpr::SoftExists {
270 var,
271 domain,
272 body,
273 temperature,
274 } => TLExpr::SoftExists {
275 var: var.clone(),
276 domain: domain.clone(),
277 body: Box::new(self.normalize(body, stats)),
278 temperature: *temperature,
279 },
280 TLExpr::SoftForAll {
281 var,
282 domain,
283 body,
284 temperature,
285 } => TLExpr::SoftForAll {
286 var: var.clone(),
287 domain: domain.clone(),
288 body: Box::new(self.normalize(body, stats)),
289 temperature: *temperature,
290 },
291
292 TLExpr::WeightedRule { weight, rule } => TLExpr::WeightedRule {
294 weight: *weight,
295 rule: Box::new(self.normalize(rule, stats)),
296 },
297
298 TLExpr::ProbabilisticChoice { alternatives } => {
300 let norm_alts: Vec<_> = alternatives
301 .iter()
302 .map(|(w, e)| (*w, self.normalize(e, stats)))
303 .collect();
304 TLExpr::ProbabilisticChoice {
305 alternatives: norm_alts,
306 }
307 }
308
309 TLExpr::Box(inner) => TLExpr::Box(Box::new(self.normalize(inner, stats))),
311 TLExpr::Diamond(inner) => TLExpr::Diamond(Box::new(self.normalize(inner, stats))),
312
313 TLExpr::Next(inner) => TLExpr::Next(Box::new(self.normalize(inner, stats))),
315 TLExpr::Eventually(inner) => TLExpr::Eventually(Box::new(self.normalize(inner, stats))),
316 TLExpr::Always(inner) => TLExpr::Always(Box::new(self.normalize(inner, stats))),
317 TLExpr::Until { before, after } => TLExpr::Until {
318 before: Box::new(self.normalize(before, stats)),
319 after: Box::new(self.normalize(after, stats)),
320 },
321 TLExpr::Release { released, releaser } => TLExpr::Release {
322 released: Box::new(self.normalize(released, stats)),
323 releaser: Box::new(self.normalize(releaser, stats)),
324 },
325 TLExpr::WeakUntil { before, after } => TLExpr::WeakUntil {
326 before: Box::new(self.normalize(before, stats)),
327 after: Box::new(self.normalize(after, stats)),
328 },
329 TLExpr::StrongRelease { released, releaser } => TLExpr::StrongRelease {
330 released: Box::new(self.normalize(released, stats)),
331 releaser: Box::new(self.normalize(releaser, stats)),
332 },
333
334 TLExpr::Lambda {
336 var,
337 var_type,
338 body,
339 } => TLExpr::Lambda {
340 var: var.clone(),
341 var_type: var_type.clone(),
342 body: Box::new(self.normalize(body, stats)),
343 },
344 TLExpr::Apply { function, argument } => TLExpr::Apply {
345 function: Box::new(self.normalize(function, stats)),
346 argument: Box::new(self.normalize(argument, stats)),
347 },
348
349 TLExpr::SetMembership { element, set } => TLExpr::SetMembership {
351 element: Box::new(self.normalize(element, stats)),
352 set: Box::new(self.normalize(set, stats)),
353 },
354 TLExpr::SetUnion { left, right } => TLExpr::SetUnion {
355 left: Box::new(self.normalize(left, stats)),
356 right: Box::new(self.normalize(right, stats)),
357 },
358 TLExpr::SetIntersection { left, right } => TLExpr::SetIntersection {
359 left: Box::new(self.normalize(left, stats)),
360 right: Box::new(self.normalize(right, stats)),
361 },
362 TLExpr::SetDifference { left, right } => TLExpr::SetDifference {
363 left: Box::new(self.normalize(left, stats)),
364 right: Box::new(self.normalize(right, stats)),
365 },
366 TLExpr::SetCardinality { set } => TLExpr::SetCardinality {
367 set: Box::new(self.normalize(set, stats)),
368 },
369 TLExpr::SetComprehension {
370 var,
371 domain,
372 condition,
373 } => TLExpr::SetComprehension {
374 var: var.clone(),
375 domain: domain.clone(),
376 condition: Box::new(self.normalize(condition, stats)),
377 },
378
379 TLExpr::CountingExists {
381 var,
382 domain,
383 body,
384 min_count,
385 } => TLExpr::CountingExists {
386 var: var.clone(),
387 domain: domain.clone(),
388 body: Box::new(self.normalize(body, stats)),
389 min_count: *min_count,
390 },
391 TLExpr::CountingForAll {
392 var,
393 domain,
394 body,
395 min_count,
396 } => TLExpr::CountingForAll {
397 var: var.clone(),
398 domain: domain.clone(),
399 body: Box::new(self.normalize(body, stats)),
400 min_count: *min_count,
401 },
402 TLExpr::ExactCount {
403 var,
404 domain,
405 body,
406 count,
407 } => TLExpr::ExactCount {
408 var: var.clone(),
409 domain: domain.clone(),
410 body: Box::new(self.normalize(body, stats)),
411 count: *count,
412 },
413 TLExpr::Majority { var, domain, body } => TLExpr::Majority {
414 var: var.clone(),
415 domain: domain.clone(),
416 body: Box::new(self.normalize(body, stats)),
417 },
418
419 TLExpr::LeastFixpoint { var, body } => TLExpr::LeastFixpoint {
421 var: var.clone(),
422 body: Box::new(self.normalize(body, stats)),
423 },
424 TLExpr::GreatestFixpoint { var, body } => TLExpr::GreatestFixpoint {
425 var: var.clone(),
426 body: Box::new(self.normalize(body, stats)),
427 },
428
429 TLExpr::At { nominal, formula } => TLExpr::At {
431 nominal: nominal.clone(),
432 formula: Box::new(self.normalize(formula, stats)),
433 },
434 TLExpr::Somewhere { formula } => TLExpr::Somewhere {
435 formula: Box::new(self.normalize(formula, stats)),
436 },
437 TLExpr::Everywhere { formula } => TLExpr::Everywhere {
438 formula: Box::new(self.normalize(formula, stats)),
439 },
440 TLExpr::Explain { formula } => TLExpr::Explain {
441 formula: Box::new(self.normalize(formula, stats)),
442 },
443
444 TLExpr::Pred { .. }
446 | TLExpr::Constant(_)
447 | TLExpr::EmptySet
448 | TLExpr::Nominal { .. }
449 | TLExpr::AllDifferent { .. }
450 | TLExpr::GlobalCardinality { .. }
451 | TLExpr::Abducible { .. }
452 | TLExpr::SymbolLiteral(_) => expr.clone(),
453
454 TLExpr::Match { scrutinee, arms } => TLExpr::Match {
455 scrutinee: Box::new(self.normalize(scrutinee, stats)),
456 arms: arms
457 .iter()
458 .map(|(p, b)| (p.clone(), Box::new(self.normalize(b, stats))))
459 .collect(),
460 },
461 }
462 }
463
464 fn collect_and_operands(
466 &self,
467 expr: &TLExpr,
468 operands: &mut Vec<TLExpr>,
469 stats: &mut CanonicalStats,
470 ) {
471 if let TLExpr::And(a, b) = expr {
472 stats.nested_flattened += 1;
473 stats.total_rewrites += 1;
474 self.collect_and_operands(a, operands, stats);
475 self.collect_and_operands(b, operands, stats);
476 } else {
477 operands.push(expr.clone());
478 }
479 }
480
481 fn collect_or_operands(
483 &self,
484 expr: &TLExpr,
485 operands: &mut Vec<TLExpr>,
486 stats: &mut CanonicalStats,
487 ) {
488 if let TLExpr::Or(a, b) = expr {
489 stats.nested_flattened += 1;
490 stats.total_rewrites += 1;
491 self.collect_or_operands(a, operands, stats);
492 self.collect_or_operands(b, operands, stats);
493 } else {
494 operands.push(expr.clone());
495 }
496 }
497
498 fn build_right_leaning_and(&self, mut operands: Vec<TLExpr>) -> TLExpr {
500 match operands.len() {
501 0 => TLExpr::Constant(1.0), 1 => operands.remove(0),
503 _ => {
504 let last = operands.pop();
506 operands.into_iter().rev().fold(
507 last.unwrap_or(TLExpr::Constant(1.0)),
509 |acc, elem| TLExpr::and(elem, acc),
510 )
511 }
512 }
513 }
514
515 fn build_right_leaning_or(&self, mut operands: Vec<TLExpr>) -> TLExpr {
517 match operands.len() {
518 0 => TLExpr::Constant(0.0), 1 => operands.remove(0),
520 _ => {
521 let last = operands.pop();
522 operands
523 .into_iter()
524 .rev()
525 .fold(last.unwrap_or(TLExpr::Constant(0.0)), |acc, elem| {
526 TLExpr::or(elem, acc)
527 })
528 }
529 }
530 }
531}
532
533pub fn canonical_order_key(expr: &TLExpr) -> String {
538 match expr {
539 TLExpr::Pred { name, args } => format!("P:{}:{}", name, args.len()),
540 TLExpr::Constant(v) => {
541 if v.is_nan() {
543 "C:NaN".to_string()
544 } else {
545 format!("C:{}", v)
546 }
547 }
548 TLExpr::Not(inner) => format!("Op:Not({})", canonical_order_key(inner)),
549 TLExpr::And(a, b) => format!(
550 "Op:And({},{})",
551 canonical_order_key(a),
552 canonical_order_key(b)
553 ),
554 TLExpr::Or(a, b) => format!(
555 "Op:Or({},{})",
556 canonical_order_key(a),
557 canonical_order_key(b)
558 ),
559 TLExpr::Imply(a, b) => format!(
560 "Op:Imply({},{})",
561 canonical_order_key(a),
562 canonical_order_key(b)
563 ),
564 TLExpr::Exists { var, domain, body } => {
565 format!("Q:Exists({},{},{})", var, domain, canonical_order_key(body))
566 }
567 TLExpr::ForAll { var, domain, body } => {
568 format!("Q:ForAll({},{},{})", var, domain, canonical_order_key(body))
569 }
570 TLExpr::Score(inner) => format!("Op:Score({})", canonical_order_key(inner)),
571 TLExpr::Add(a, b) => format!(
572 "Op:Add({},{})",
573 canonical_order_key(a),
574 canonical_order_key(b)
575 ),
576 TLExpr::Sub(a, b) => format!(
577 "Op:Sub({},{})",
578 canonical_order_key(a),
579 canonical_order_key(b)
580 ),
581 TLExpr::Mul(a, b) => format!(
582 "Op:Mul({},{})",
583 canonical_order_key(a),
584 canonical_order_key(b)
585 ),
586 TLExpr::Div(a, b) => format!(
587 "Op:Div({},{})",
588 canonical_order_key(a),
589 canonical_order_key(b)
590 ),
591 TLExpr::Pow(a, b) => format!(
592 "Op:Pow({},{})",
593 canonical_order_key(a),
594 canonical_order_key(b)
595 ),
596 TLExpr::Mod(a, b) => format!(
597 "Op:Mod({},{})",
598 canonical_order_key(a),
599 canonical_order_key(b)
600 ),
601 TLExpr::Min(a, b) => format!(
602 "Op:Min({},{})",
603 canonical_order_key(a),
604 canonical_order_key(b)
605 ),
606 TLExpr::Max(a, b) => format!(
607 "Op:Max({},{})",
608 canonical_order_key(a),
609 canonical_order_key(b)
610 ),
611 TLExpr::Eq(a, b) => format!(
612 "Op:Eq({},{})",
613 canonical_order_key(a),
614 canonical_order_key(b)
615 ),
616 TLExpr::Lt(a, b) => format!(
617 "Op:Lt({},{})",
618 canonical_order_key(a),
619 canonical_order_key(b)
620 ),
621 TLExpr::Gt(a, b) => format!(
622 "Op:Gt({},{})",
623 canonical_order_key(a),
624 canonical_order_key(b)
625 ),
626 TLExpr::Lte(a, b) => format!(
627 "Op:Lte({},{})",
628 canonical_order_key(a),
629 canonical_order_key(b)
630 ),
631 TLExpr::Gte(a, b) => format!(
632 "Op:Gte({},{})",
633 canonical_order_key(a),
634 canonical_order_key(b)
635 ),
636 TLExpr::Abs(inner) => format!("Op:Abs({})", canonical_order_key(inner)),
638 TLExpr::Floor(inner) => format!("Op:Floor({})", canonical_order_key(inner)),
639 TLExpr::Ceil(inner) => format!("Op:Ceil({})", canonical_order_key(inner)),
640 TLExpr::Round(inner) => format!("Op:Round({})", canonical_order_key(inner)),
641 TLExpr::Sqrt(inner) => format!("Op:Sqrt({})", canonical_order_key(inner)),
642 TLExpr::Exp(inner) => format!("Op:Exp({})", canonical_order_key(inner)),
643 TLExpr::Log(inner) => format!("Op:Log({})", canonical_order_key(inner)),
644 TLExpr::Sin(inner) => format!("Op:Sin({})", canonical_order_key(inner)),
645 TLExpr::Cos(inner) => format!("Op:Cos({})", canonical_order_key(inner)),
646 TLExpr::Tan(inner) => format!("Op:Tan({})", canonical_order_key(inner)),
647 TLExpr::EmptySet => "L:EmptySet".to_string(),
648 TLExpr::Nominal { name } => format!("L:Nominal({})", name),
649 other => format!("X:{:?}", other),
651 }
652}
653
654pub fn canonicalize(expr: &TLExpr) -> (TLExpr, CanonicalStats) {
656 Canonicalizer::new().canonicalize(expr)
657}
658
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// The unary predicate `name(x)`, used as an opaque leaf throughout.
    fn atom(name: &'static str) -> TLExpr {
        TLExpr::pred(name, vec![Term::var("x")])
    }

    /// `Not(Not(e))`.
    fn not_not(e: TLExpr) -> TLExpr {
        TLExpr::negate(TLExpr::negate(e))
    }

    /// Canonical keys of `lhs` and `rhs`, computed by one default canonicalizer.
    fn key_pair(lhs: &TLExpr, rhs: &TLExpr) -> (String, String) {
        let canon = Canonicalizer::new();
        (canon.canonical_key(lhs), canon.canonical_key(rhs))
    }

    #[test]
    fn test_double_neg_elimination() {
        let (result, stats) = canonicalize(&not_not(atom("a")));
        assert_eq!(result, atom("a"));
        assert_eq!(stats.double_neg_removed, 1);
    }

    #[test]
    fn test_double_neg_nested_three() {
        // Three negations: one pair cancels and a single Not survives.
        let (result, _) = canonicalize(&TLExpr::negate(not_not(atom("a"))));
        assert_eq!(result, TLExpr::negate(atom("a")));
    }

    #[test]
    fn test_and_commutative_sorted() {
        let (first, second) = key_pair(
            &TLExpr::and(atom("b"), atom("a")),
            &TLExpr::and(atom("a"), atom("b")),
        );
        assert_eq!(first, second);
    }

    #[test]
    fn test_or_commutative_sorted() {
        let (first, second) = key_pair(
            &TLExpr::or(atom("b"), atom("a")),
            &TLExpr::or(atom("a"), atom("b")),
        );
        assert_eq!(first, second);
    }

    #[test]
    fn test_nested_and_consistent() {
        // (a & b) & c and a & (b & c) must share one canonical form.
        let left_nested = TLExpr::and(TLExpr::and(atom("a"), atom("b")), atom("c"));
        let right_nested = TLExpr::and(atom("a"), TLExpr::and(atom("b"), atom("c")));
        let (first, second) = key_pair(&left_nested, &right_nested);
        assert_eq!(first, second);
    }

    #[test]
    fn test_canonical_key_deterministic() {
        let expr = TLExpr::and(atom("a"), TLExpr::or(atom("b"), atom("c")));
        let (first, second) = key_pair(&expr, &expr);
        assert_eq!(first, second);
    }

    #[test]
    fn test_canonical_key_different_exprs() {
        let (conj, disj) = key_pair(
            &TLExpr::and(atom("a"), atom("b")),
            &TLExpr::or(atom("a"), atom("b")),
        );
        assert_ne!(conj, disj);
    }

    #[test]
    fn test_stats_double_neg_counted() {
        let (_, stats) = canonicalize(&not_not(atom("a")));
        assert_eq!(stats.double_neg_removed, 1);
        assert!(stats.total_rewrites >= 1);
    }

    #[test]
    fn test_stats_commutative_counted() {
        let (_, stats) = canonicalize(&TLExpr::and(atom("b"), atom("a")));
        assert_eq!(stats.commutative_sorted, 1);
    }

    #[test]
    fn test_stats_merge() {
        let mut acc = CanonicalStats {
            double_neg_removed: 2,
            commutative_sorted: 1,
            nested_flattened: 3,
            total_rewrites: 6,
        };
        let extra = CanonicalStats {
            double_neg_removed: 1,
            commutative_sorted: 4,
            nested_flattened: 0,
            total_rewrites: 5,
        };
        acc.merge(&extra);
        assert_eq!(acc.double_neg_removed, 3);
        assert_eq!(acc.commutative_sorted, 5);
        assert_eq!(acc.nested_flattened, 3);
        assert_eq!(acc.total_rewrites, 11);
    }

    #[test]
    fn test_canonicalize_pred_unchanged() {
        let leaf = atom("a");
        let (result, stats) = canonicalize(&leaf);
        assert_eq!(result, leaf);
        assert_eq!(stats.total_rewrites, 0);
    }

    #[test]
    fn test_canonicalize_constant_unchanged() {
        let constant = TLExpr::Constant(42.0);
        let (result, stats) = canonicalize(&constant);
        assert_eq!(result, constant);
        assert_eq!(stats.total_rewrites, 0);
    }

    #[test]
    fn test_canonicalize_exists_recurses() {
        let expr = TLExpr::exists("x", "D", not_not(atom("a")));
        let (result, stats) = canonicalize(&expr);
        assert_eq!(stats.double_neg_removed, 1);
        match &result {
            TLExpr::Exists { body, .. } => {
                assert!(matches!(body.as_ref(), TLExpr::Pred { .. }))
            }
            _ => panic!("Expected Exists"),
        }
    }

    #[test]
    fn test_canonicalize_forall_recurses() {
        let expr = TLExpr::forall("x", "D", not_not(atom("a")));
        let (result, stats) = canonicalize(&expr);
        assert_eq!(stats.double_neg_removed, 1);
        match &result {
            TLExpr::ForAll { body, .. } => {
                assert!(matches!(body.as_ref(), TLExpr::Pred { .. }))
            }
            _ => panic!("Expected ForAll"),
        }
    }

    #[test]
    fn test_canonicalize_implication_recurses() {
        let expr = TLExpr::imply(not_not(atom("a")), not_not(atom("b")));
        let (result, stats) = canonicalize(&expr);
        assert_eq!(stats.double_neg_removed, 2);
        match &result {
            TLExpr::Imply(premise, conclusion) => {
                assert!(matches!(premise.as_ref(), TLExpr::Pred { .. }));
                assert!(matches!(conclusion.as_ref(), TLExpr::Pred { .. }));
            }
            _ => panic!("Expected Imply"),
        }
    }

    #[test]
    fn test_canonicalize_deep_nesting() {
        // Build a 51-operand left-nested conjunction.
        let expr = (0..50).fold(atom("a"), |acc, i| {
            TLExpr::and(acc, TLExpr::pred(format!("p{}", i), vec![Term::var("x")]))
        });
        let (result, _) = canonicalize(&expr);
        let _ = canonical_order_key(&result);
    }

    #[test]
    fn test_canonical_order_key_pred() {
        let key = canonical_order_key(&atom("a"));
        assert!(
            key.starts_with("P:"),
            "Expected key to start with 'P:', got: {}",
            key
        );
        assert!(key.contains("a"));
    }

    #[test]
    fn test_canonical_order_key_constant() {
        let key = canonical_order_key(&TLExpr::Constant(42.5));
        assert!(
            key.starts_with("C:"),
            "Expected key to start with 'C:', got: {}",
            key
        );
    }

    #[test]
    fn test_convenience_fn() {
        let (result, stats) = canonicalize(&not_not(atom("a")));
        assert_eq!(result, atom("a"));
        assert_eq!(stats.double_neg_removed, 1);
    }

    #[test]
    fn test_disabled_rules() {
        // With sorting off, operand order must be preserved as written.
        let canon = Canonicalizer::new().with_sort_commutative(false);
        let (result, stats) = canon.canonicalize(&TLExpr::and(atom("b"), atom("a")));
        assert_eq!(stats.commutative_sorted, 0);
        match &result {
            TLExpr::And(left, right) => {
                assert_eq!(left.as_ref(), &atom("b"));
                assert_eq!(right.as_ref(), &atom("a"));
            }
            _ => panic!("Expected And"),
        }
    }
}