Skip to main content

tensorlogic_compiler/optimize/
canonical.rs

1//! Expression canonicalization for improved cache hit rates.
2//!
3//! Puts TLExpr into a canonical structural form so that semantically
4//! equivalent expressions (e.g., AND(a,b) vs AND(b,a)) have identical
5//! fingerprints. This is purely structural normalization — not constant
6//! folding or algebraic simplification (those exist in separate passes).
7//!
8//! Three canonicalization rules:
9//! 1. **Double negation elimination**: NOT(NOT(x)) → x
10//! 2. **Nested same-op flattening**: AND(AND(a,b), c) → AND(a, AND(b,c)) sorted
11//! 3. **Commutative sorting**: AND/OR operands sorted by canonical_order_key
12
13use tensorlogic_ir::TLExpr;
14
/// Rewrite counters collected while canonicalizing one expression.
///
/// A fresh instance (via `Default`) has every counter at zero.
#[derive(Debug, Clone, Default)]
pub struct CanonicalStats {
    /// Number of double negations removed
    pub double_neg_removed: usize,
    /// Number of commutative sorts applied
    pub commutative_sorted: usize,
    /// Number of nested same-op flattened
    pub nested_flattened: usize,
    /// Total rewrites performed
    pub total_rewrites: usize,
}

impl CanonicalStats {
    /// Accumulate `other` into `self`, adding each counter field-by-field.
    pub fn merge(&mut self, other: &CanonicalStats) {
        // Exhaustive destructuring: a newly added counter that is not merged
        // here becomes a compile error instead of a silently dropped value.
        let CanonicalStats {
            double_neg_removed,
            commutative_sorted,
            nested_flattened,
            total_rewrites,
        } = other;
        self.double_neg_removed += *double_neg_removed;
        self.commutative_sorted += *commutative_sorted;
        self.nested_flattened += *nested_flattened;
        self.total_rewrites += *total_rewrites;
    }
}

impl std::fmt::Display for CanonicalStats {
    /// Render all four counters on one line, in declaration order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            double_neg_removed,
            commutative_sorted,
            nested_flattened,
            total_rewrites,
        } = self;
        write!(
            f,
            "CanonicalStats {{ double_neg: {}, comm_sorted: {}, flattened: {}, total: {} }}",
            double_neg_removed, commutative_sorted, nested_flattened, total_rewrites
        )
    }
}
50
51/// Expression canonicalizer with configurable rules.
52#[derive(Debug, Clone)]
53pub struct Canonicalizer {
54    /// Whether to sort commutative operands (AND, OR).
55    pub sort_commutative: bool,
56    /// Whether to flatten nested same-op expressions.
57    pub flatten_nested: bool,
58    /// Whether to eliminate double negations.
59    pub elim_double_neg: bool,
60}
61
62impl Default for Canonicalizer {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68impl Canonicalizer {
69    /// Create a new canonicalizer with all rules enabled.
70    pub fn new() -> Self {
71        Canonicalizer {
72            sort_commutative: true,
73            flatten_nested: true,
74            elim_double_neg: true,
75        }
76    }
77
78    /// Set whether to sort commutative operands.
79    pub fn with_sort_commutative(mut self, v: bool) -> Self {
80        self.sort_commutative = v;
81        self
82    }
83
84    /// Set whether to flatten nested same-op expressions.
85    pub fn with_flatten_nested(mut self, v: bool) -> Self {
86        self.flatten_nested = v;
87        self
88    }
89
90    /// Set whether to eliminate double negations.
91    pub fn with_elim_double_neg(mut self, v: bool) -> Self {
92        self.elim_double_neg = v;
93        self
94    }
95
96    /// Canonicalize an expression, returning the normalized form and stats.
97    pub fn canonicalize(&self, expr: &TLExpr) -> (TLExpr, CanonicalStats) {
98        let mut stats = CanonicalStats::default();
99        let result = self.normalize(expr, &mut stats);
100        (result, stats)
101    }
102
103    /// Produce a deterministic canonical key string for cache use.
104    pub fn canonical_key(&self, expr: &TLExpr) -> String {
105        let (normalized, _) = self.canonicalize(expr);
106        format!("{:?}", normalized)
107    }
108
109    fn normalize(&self, expr: &TLExpr, stats: &mut CanonicalStats) -> TLExpr {
110        match expr {
111            // Double negation elimination: NOT(NOT(x)) → x
112            TLExpr::Not(inner) => {
113                if self.elim_double_neg {
114                    if let TLExpr::Not(inner_inner) = inner.as_ref() {
115                        stats.double_neg_removed += 1;
116                        stats.total_rewrites += 1;
117                        return self.normalize(inner_inner, stats);
118                    }
119                }
120                TLExpr::negate(self.normalize(inner, stats))
121            }
122
123            // AND: flatten nested AND, then sort commutative operands
124            TLExpr::And(a, b) => {
125                let norm_a = self.normalize(a, stats);
126                let norm_b = self.normalize(b, stats);
127                let mut operands = Vec::new();
128                if self.flatten_nested {
129                    self.collect_and_operands(&norm_a, &mut operands, stats);
130                    self.collect_and_operands(&norm_b, &mut operands, stats);
131                } else {
132                    operands.push(norm_a);
133                    operands.push(norm_b);
134                }
135                if self.sort_commutative {
136                    let before = operands.iter().map(canonical_order_key).collect::<Vec<_>>();
137                    operands.sort_by_key(canonical_order_key);
138                    let after = operands.iter().map(canonical_order_key).collect::<Vec<_>>();
139                    if before != after {
140                        stats.commutative_sorted += 1;
141                        stats.total_rewrites += 1;
142                    }
143                }
144                self.build_right_leaning_and(operands)
145            }
146
147            // OR: flatten nested OR, then sort commutative operands
148            TLExpr::Or(a, b) => {
149                let norm_a = self.normalize(a, stats);
150                let norm_b = self.normalize(b, stats);
151                let mut operands = Vec::new();
152                if self.flatten_nested {
153                    self.collect_or_operands(&norm_a, &mut operands, stats);
154                    self.collect_or_operands(&norm_b, &mut operands, stats);
155                } else {
156                    operands.push(norm_a);
157                    operands.push(norm_b);
158                }
159                if self.sort_commutative {
160                    let before = operands.iter().map(canonical_order_key).collect::<Vec<_>>();
161                    operands.sort_by_key(canonical_order_key);
162                    let after = operands.iter().map(canonical_order_key).collect::<Vec<_>>();
163                    if before != after {
164                        stats.commutative_sorted += 1;
165                        stats.total_rewrites += 1;
166                    }
167                }
168                self.build_right_leaning_or(operands)
169            }
170
171            // Recurse into binary operators (non-commutative in canonicalization sense)
172            TLExpr::Imply(a, b) => {
173                TLExpr::imply(self.normalize(a, stats), self.normalize(b, stats))
174            }
175            TLExpr::Add(a, b) => TLExpr::add(self.normalize(a, stats), self.normalize(b, stats)),
176            TLExpr::Sub(a, b) => TLExpr::sub(self.normalize(a, stats), self.normalize(b, stats)),
177            TLExpr::Mul(a, b) => TLExpr::mul(self.normalize(a, stats), self.normalize(b, stats)),
178            TLExpr::Div(a, b) => TLExpr::div(self.normalize(a, stats), self.normalize(b, stats)),
179            TLExpr::Pow(a, b) => TLExpr::pow(self.normalize(a, stats), self.normalize(b, stats)),
180            TLExpr::Mod(a, b) => TLExpr::modulo(self.normalize(a, stats), self.normalize(b, stats)),
181            TLExpr::Min(a, b) => TLExpr::min(self.normalize(a, stats), self.normalize(b, stats)),
182            TLExpr::Max(a, b) => TLExpr::max(self.normalize(a, stats), self.normalize(b, stats)),
183            TLExpr::Eq(a, b) => TLExpr::eq(self.normalize(a, stats), self.normalize(b, stats)),
184            TLExpr::Lt(a, b) => TLExpr::lt(self.normalize(a, stats), self.normalize(b, stats)),
185            TLExpr::Gt(a, b) => TLExpr::gt(self.normalize(a, stats), self.normalize(b, stats)),
186            TLExpr::Lte(a, b) => TLExpr::lte(self.normalize(a, stats), self.normalize(b, stats)),
187            TLExpr::Gte(a, b) => TLExpr::gte(self.normalize(a, stats), self.normalize(b, stats)),
188
189            // Unary math
190            TLExpr::Abs(inner) => TLExpr::abs(self.normalize(inner, stats)),
191            TLExpr::Floor(inner) => TLExpr::floor(self.normalize(inner, stats)),
192            TLExpr::Ceil(inner) => TLExpr::ceil(self.normalize(inner, stats)),
193            TLExpr::Round(inner) => TLExpr::round(self.normalize(inner, stats)),
194            TLExpr::Sqrt(inner) => TLExpr::sqrt(self.normalize(inner, stats)),
195            TLExpr::Exp(inner) => TLExpr::exp(self.normalize(inner, stats)),
196            TLExpr::Log(inner) => TLExpr::log(self.normalize(inner, stats)),
197            TLExpr::Sin(inner) => TLExpr::sin(self.normalize(inner, stats)),
198            TLExpr::Cos(inner) => TLExpr::cos(self.normalize(inner, stats)),
199            TLExpr::Tan(inner) => TLExpr::tan(self.normalize(inner, stats)),
200            TLExpr::Score(inner) => TLExpr::score(self.normalize(inner, stats)),
201
202            // Quantifiers
203            TLExpr::Exists { var, domain, body } => {
204                TLExpr::exists(var.clone(), domain.clone(), self.normalize(body, stats))
205            }
206            TLExpr::ForAll { var, domain, body } => {
207                TLExpr::forall(var.clone(), domain.clone(), self.normalize(body, stats))
208            }
209
210            // Conditional
211            TLExpr::IfThenElse {
212                condition,
213                then_branch,
214                else_branch,
215            } => TLExpr::if_then_else(
216                self.normalize(condition, stats),
217                self.normalize(then_branch, stats),
218                self.normalize(else_branch, stats),
219            ),
220
221            // Let binding
222            TLExpr::Let { var, value, body } => TLExpr::let_binding(
223                var.clone(),
224                self.normalize(value, stats),
225                self.normalize(body, stats),
226            ),
227
228            // Aggregate
229            TLExpr::Aggregate {
230                op,
231                var,
232                domain,
233                body,
234                group_by,
235            } => TLExpr::Aggregate {
236                op: op.clone(),
237                var: var.clone(),
238                domain: domain.clone(),
239                body: Box::new(self.normalize(body, stats)),
240                group_by: group_by.clone(),
241            },
242
243            // Fuzzy logic operators
244            TLExpr::TNorm { kind, left, right } => TLExpr::TNorm {
245                kind: *kind,
246                left: Box::new(self.normalize(left, stats)),
247                right: Box::new(self.normalize(right, stats)),
248            },
249            TLExpr::TCoNorm { kind, left, right } => TLExpr::TCoNorm {
250                kind: *kind,
251                left: Box::new(self.normalize(left, stats)),
252                right: Box::new(self.normalize(right, stats)),
253            },
254            TLExpr::FuzzyNot { kind, expr: inner } => TLExpr::FuzzyNot {
255                kind: *kind,
256                expr: Box::new(self.normalize(inner, stats)),
257            },
258            TLExpr::FuzzyImplication {
259                kind,
260                premise,
261                conclusion,
262            } => TLExpr::FuzzyImplication {
263                kind: *kind,
264                premise: Box::new(self.normalize(premise, stats)),
265                conclusion: Box::new(self.normalize(conclusion, stats)),
266            },
267
268            // Soft quantifiers
269            TLExpr::SoftExists {
270                var,
271                domain,
272                body,
273                temperature,
274            } => TLExpr::SoftExists {
275                var: var.clone(),
276                domain: domain.clone(),
277                body: Box::new(self.normalize(body, stats)),
278                temperature: *temperature,
279            },
280            TLExpr::SoftForAll {
281                var,
282                domain,
283                body,
284                temperature,
285            } => TLExpr::SoftForAll {
286                var: var.clone(),
287                domain: domain.clone(),
288                body: Box::new(self.normalize(body, stats)),
289                temperature: *temperature,
290            },
291
292            // Weighted rule
293            TLExpr::WeightedRule { weight, rule } => TLExpr::WeightedRule {
294                weight: *weight,
295                rule: Box::new(self.normalize(rule, stats)),
296            },
297
298            // Probabilistic choice
299            TLExpr::ProbabilisticChoice { alternatives } => {
300                let norm_alts: Vec<_> = alternatives
301                    .iter()
302                    .map(|(w, e)| (*w, self.normalize(e, stats)))
303                    .collect();
304                TLExpr::ProbabilisticChoice {
305                    alternatives: norm_alts,
306                }
307            }
308
309            // Modal logic
310            TLExpr::Box(inner) => TLExpr::Box(Box::new(self.normalize(inner, stats))),
311            TLExpr::Diamond(inner) => TLExpr::Diamond(Box::new(self.normalize(inner, stats))),
312
313            // Temporal logic
314            TLExpr::Next(inner) => TLExpr::Next(Box::new(self.normalize(inner, stats))),
315            TLExpr::Eventually(inner) => TLExpr::Eventually(Box::new(self.normalize(inner, stats))),
316            TLExpr::Always(inner) => TLExpr::Always(Box::new(self.normalize(inner, stats))),
317            TLExpr::Until { before, after } => TLExpr::Until {
318                before: Box::new(self.normalize(before, stats)),
319                after: Box::new(self.normalize(after, stats)),
320            },
321            TLExpr::Release { released, releaser } => TLExpr::Release {
322                released: Box::new(self.normalize(released, stats)),
323                releaser: Box::new(self.normalize(releaser, stats)),
324            },
325            TLExpr::WeakUntil { before, after } => TLExpr::WeakUntil {
326                before: Box::new(self.normalize(before, stats)),
327                after: Box::new(self.normalize(after, stats)),
328            },
329            TLExpr::StrongRelease { released, releaser } => TLExpr::StrongRelease {
330                released: Box::new(self.normalize(released, stats)),
331                releaser: Box::new(self.normalize(releaser, stats)),
332            },
333
334            // Higher-order
335            TLExpr::Lambda {
336                var,
337                var_type,
338                body,
339            } => TLExpr::Lambda {
340                var: var.clone(),
341                var_type: var_type.clone(),
342                body: Box::new(self.normalize(body, stats)),
343            },
344            TLExpr::Apply { function, argument } => TLExpr::Apply {
345                function: Box::new(self.normalize(function, stats)),
346                argument: Box::new(self.normalize(argument, stats)),
347            },
348
349            // Set operations
350            TLExpr::SetMembership { element, set } => TLExpr::SetMembership {
351                element: Box::new(self.normalize(element, stats)),
352                set: Box::new(self.normalize(set, stats)),
353            },
354            TLExpr::SetUnion { left, right } => TLExpr::SetUnion {
355                left: Box::new(self.normalize(left, stats)),
356                right: Box::new(self.normalize(right, stats)),
357            },
358            TLExpr::SetIntersection { left, right } => TLExpr::SetIntersection {
359                left: Box::new(self.normalize(left, stats)),
360                right: Box::new(self.normalize(right, stats)),
361            },
362            TLExpr::SetDifference { left, right } => TLExpr::SetDifference {
363                left: Box::new(self.normalize(left, stats)),
364                right: Box::new(self.normalize(right, stats)),
365            },
366            TLExpr::SetCardinality { set } => TLExpr::SetCardinality {
367                set: Box::new(self.normalize(set, stats)),
368            },
369            TLExpr::SetComprehension {
370                var,
371                domain,
372                condition,
373            } => TLExpr::SetComprehension {
374                var: var.clone(),
375                domain: domain.clone(),
376                condition: Box::new(self.normalize(condition, stats)),
377            },
378
379            // Counting quantifiers
380            TLExpr::CountingExists {
381                var,
382                domain,
383                body,
384                min_count,
385            } => TLExpr::CountingExists {
386                var: var.clone(),
387                domain: domain.clone(),
388                body: Box::new(self.normalize(body, stats)),
389                min_count: *min_count,
390            },
391            TLExpr::CountingForAll {
392                var,
393                domain,
394                body,
395                min_count,
396            } => TLExpr::CountingForAll {
397                var: var.clone(),
398                domain: domain.clone(),
399                body: Box::new(self.normalize(body, stats)),
400                min_count: *min_count,
401            },
402            TLExpr::ExactCount {
403                var,
404                domain,
405                body,
406                count,
407            } => TLExpr::ExactCount {
408                var: var.clone(),
409                domain: domain.clone(),
410                body: Box::new(self.normalize(body, stats)),
411                count: *count,
412            },
413            TLExpr::Majority { var, domain, body } => TLExpr::Majority {
414                var: var.clone(),
415                domain: domain.clone(),
416                body: Box::new(self.normalize(body, stats)),
417            },
418
419            // Fixed-point operators
420            TLExpr::LeastFixpoint { var, body } => TLExpr::LeastFixpoint {
421                var: var.clone(),
422                body: Box::new(self.normalize(body, stats)),
423            },
424            TLExpr::GreatestFixpoint { var, body } => TLExpr::GreatestFixpoint {
425                var: var.clone(),
426                body: Box::new(self.normalize(body, stats)),
427            },
428
429            // Hybrid logic
430            TLExpr::At { nominal, formula } => TLExpr::At {
431                nominal: nominal.clone(),
432                formula: Box::new(self.normalize(formula, stats)),
433            },
434            TLExpr::Somewhere { formula } => TLExpr::Somewhere {
435                formula: Box::new(self.normalize(formula, stats)),
436            },
437            TLExpr::Everywhere { formula } => TLExpr::Everywhere {
438                formula: Box::new(self.normalize(formula, stats)),
439            },
440            TLExpr::Explain { formula } => TLExpr::Explain {
441                formula: Box::new(self.normalize(formula, stats)),
442            },
443
444            // Leaves and remaining variants
445            TLExpr::Pred { .. }
446            | TLExpr::Constant(_)
447            | TLExpr::EmptySet
448            | TLExpr::Nominal { .. }
449            | TLExpr::AllDifferent { .. }
450            | TLExpr::GlobalCardinality { .. }
451            | TLExpr::Abducible { .. }
452            | TLExpr::SymbolLiteral(_) => expr.clone(),
453
454            TLExpr::Match { scrutinee, arms } => TLExpr::Match {
455                scrutinee: Box::new(self.normalize(scrutinee, stats)),
456                arms: arms
457                    .iter()
458                    .map(|(p, b)| (p.clone(), Box::new(self.normalize(b, stats))))
459                    .collect(),
460            },
461        }
462    }
463
464    /// Collect all operands from nested AND expressions (flattening).
465    fn collect_and_operands(
466        &self,
467        expr: &TLExpr,
468        operands: &mut Vec<TLExpr>,
469        stats: &mut CanonicalStats,
470    ) {
471        if let TLExpr::And(a, b) = expr {
472            stats.nested_flattened += 1;
473            stats.total_rewrites += 1;
474            self.collect_and_operands(a, operands, stats);
475            self.collect_and_operands(b, operands, stats);
476        } else {
477            operands.push(expr.clone());
478        }
479    }
480
481    /// Collect all operands from nested OR expressions (flattening).
482    fn collect_or_operands(
483        &self,
484        expr: &TLExpr,
485        operands: &mut Vec<TLExpr>,
486        stats: &mut CanonicalStats,
487    ) {
488        if let TLExpr::Or(a, b) = expr {
489            stats.nested_flattened += 1;
490            stats.total_rewrites += 1;
491            self.collect_or_operands(a, operands, stats);
492            self.collect_or_operands(b, operands, stats);
493        } else {
494            operands.push(expr.clone());
495        }
496    }
497
498    /// Build a right-leaning AND tree from a list of operands.
499    fn build_right_leaning_and(&self, mut operands: Vec<TLExpr>) -> TLExpr {
500        match operands.len() {
501            0 => TLExpr::Constant(1.0), // identity for AND (true)
502            1 => operands.remove(0),
503            _ => {
504                // Build right-leaning: AND(a, AND(b, AND(c, d)))
505                let last = operands.pop();
506                operands.into_iter().rev().fold(
507                    // Safe: len >= 2 so pop always returns Some
508                    last.unwrap_or(TLExpr::Constant(1.0)),
509                    |acc, elem| TLExpr::and(elem, acc),
510                )
511            }
512        }
513    }
514
515    /// Build a right-leaning OR tree from a list of operands.
516    fn build_right_leaning_or(&self, mut operands: Vec<TLExpr>) -> TLExpr {
517        match operands.len() {
518            0 => TLExpr::Constant(0.0), // identity for OR (false)
519            1 => operands.remove(0),
520            _ => {
521                let last = operands.pop();
522                operands
523                    .into_iter()
524                    .rev()
525                    .fold(last.unwrap_or(TLExpr::Constant(0.0)), |acc, elem| {
526                        TLExpr::or(elem, acc)
527                    })
528            }
529        }
530    }
531}
532
533/// Compute a canonical ordering key for sorting commutative children.
534///
535/// Produces a deterministic string representation suitable for ordering.
536/// This ensures AND(a,b) and AND(b,a) sort to the same canonical form.
537pub fn canonical_order_key(expr: &TLExpr) -> String {
538    match expr {
539        TLExpr::Pred { name, args } => format!("P:{}:{}", name, args.len()),
540        TLExpr::Constant(v) => {
541            // Use a canonical float representation
542            if v.is_nan() {
543                "C:NaN".to_string()
544            } else {
545                format!("C:{}", v)
546            }
547        }
548        TLExpr::Not(inner) => format!("Op:Not({})", canonical_order_key(inner)),
549        TLExpr::And(a, b) => format!(
550            "Op:And({},{})",
551            canonical_order_key(a),
552            canonical_order_key(b)
553        ),
554        TLExpr::Or(a, b) => format!(
555            "Op:Or({},{})",
556            canonical_order_key(a),
557            canonical_order_key(b)
558        ),
559        TLExpr::Imply(a, b) => format!(
560            "Op:Imply({},{})",
561            canonical_order_key(a),
562            canonical_order_key(b)
563        ),
564        TLExpr::Exists { var, domain, body } => {
565            format!("Q:Exists({},{},{})", var, domain, canonical_order_key(body))
566        }
567        TLExpr::ForAll { var, domain, body } => {
568            format!("Q:ForAll({},{},{})", var, domain, canonical_order_key(body))
569        }
570        TLExpr::Score(inner) => format!("Op:Score({})", canonical_order_key(inner)),
571        TLExpr::Add(a, b) => format!(
572            "Op:Add({},{})",
573            canonical_order_key(a),
574            canonical_order_key(b)
575        ),
576        TLExpr::Sub(a, b) => format!(
577            "Op:Sub({},{})",
578            canonical_order_key(a),
579            canonical_order_key(b)
580        ),
581        TLExpr::Mul(a, b) => format!(
582            "Op:Mul({},{})",
583            canonical_order_key(a),
584            canonical_order_key(b)
585        ),
586        TLExpr::Div(a, b) => format!(
587            "Op:Div({},{})",
588            canonical_order_key(a),
589            canonical_order_key(b)
590        ),
591        TLExpr::Pow(a, b) => format!(
592            "Op:Pow({},{})",
593            canonical_order_key(a),
594            canonical_order_key(b)
595        ),
596        TLExpr::Mod(a, b) => format!(
597            "Op:Mod({},{})",
598            canonical_order_key(a),
599            canonical_order_key(b)
600        ),
601        TLExpr::Min(a, b) => format!(
602            "Op:Min({},{})",
603            canonical_order_key(a),
604            canonical_order_key(b)
605        ),
606        TLExpr::Max(a, b) => format!(
607            "Op:Max({},{})",
608            canonical_order_key(a),
609            canonical_order_key(b)
610        ),
611        TLExpr::Eq(a, b) => format!(
612            "Op:Eq({},{})",
613            canonical_order_key(a),
614            canonical_order_key(b)
615        ),
616        TLExpr::Lt(a, b) => format!(
617            "Op:Lt({},{})",
618            canonical_order_key(a),
619            canonical_order_key(b)
620        ),
621        TLExpr::Gt(a, b) => format!(
622            "Op:Gt({},{})",
623            canonical_order_key(a),
624            canonical_order_key(b)
625        ),
626        TLExpr::Lte(a, b) => format!(
627            "Op:Lte({},{})",
628            canonical_order_key(a),
629            canonical_order_key(b)
630        ),
631        TLExpr::Gte(a, b) => format!(
632            "Op:Gte({},{})",
633            canonical_order_key(a),
634            canonical_order_key(b)
635        ),
636        // Unary math
637        TLExpr::Abs(inner) => format!("Op:Abs({})", canonical_order_key(inner)),
638        TLExpr::Floor(inner) => format!("Op:Floor({})", canonical_order_key(inner)),
639        TLExpr::Ceil(inner) => format!("Op:Ceil({})", canonical_order_key(inner)),
640        TLExpr::Round(inner) => format!("Op:Round({})", canonical_order_key(inner)),
641        TLExpr::Sqrt(inner) => format!("Op:Sqrt({})", canonical_order_key(inner)),
642        TLExpr::Exp(inner) => format!("Op:Exp({})", canonical_order_key(inner)),
643        TLExpr::Log(inner) => format!("Op:Log({})", canonical_order_key(inner)),
644        TLExpr::Sin(inner) => format!("Op:Sin({})", canonical_order_key(inner)),
645        TLExpr::Cos(inner) => format!("Op:Cos({})", canonical_order_key(inner)),
646        TLExpr::Tan(inner) => format!("Op:Tan({})", canonical_order_key(inner)),
647        TLExpr::EmptySet => "L:EmptySet".to_string(),
648        TLExpr::Nominal { name } => format!("L:Nominal({})", name),
649        // For all other complex variants, use Debug for deterministic ordering
650        other => format!("X:{:?}", other),
651    }
652}
653
654/// Convenience function: canonicalize and return the result.
655pub fn canonicalize(expr: &TLExpr) -> (TLExpr, CanonicalStats) {
656    Canonicalizer::new().canonicalize(expr)
657}
658
659#[cfg(test)]
660mod tests {
661    use super::*;
662    use tensorlogic_ir::Term;
663
664    fn pred_a() -> TLExpr {
665        TLExpr::pred("a", vec![Term::var("x")])
666    }
667
668    fn pred_b() -> TLExpr {
669        TLExpr::pred("b", vec![Term::var("x")])
670    }
671
672    fn pred_c() -> TLExpr {
673        TLExpr::pred("c", vec![Term::var("x")])
674    }
675
676    #[test]
677    fn test_double_neg_elimination() {
678        let p = pred_a();
679        let expr = TLExpr::negate(TLExpr::negate(p.clone()));
680        let (result, stats) = canonicalize(&expr);
681        assert_eq!(result, p);
682        assert_eq!(stats.double_neg_removed, 1);
683    }
684
685    #[test]
686    fn test_double_neg_nested_three() {
687        // NOT(NOT(NOT(pred))) → NOT(pred)
688        let p = pred_a();
689        let expr = TLExpr::negate(TLExpr::negate(TLExpr::negate(p.clone())));
690        let (result, _stats) = canonicalize(&expr);
691        assert_eq!(result, TLExpr::negate(p));
692    }
693
694    #[test]
695    fn test_and_commutative_sorted() {
696        let a = pred_a();
697        let b = pred_b();
698        let c = Canonicalizer::new();
699        let key1 = c.canonical_key(&TLExpr::and(b.clone(), a.clone()));
700        let key2 = c.canonical_key(&TLExpr::and(a.clone(), b.clone()));
701        assert_eq!(key1, key2);
702    }
703
704    #[test]
705    fn test_or_commutative_sorted() {
706        let a = pred_a();
707        let b = pred_b();
708        let c = Canonicalizer::new();
709        let key1 = c.canonical_key(&TLExpr::or(b.clone(), a.clone()));
710        let key2 = c.canonical_key(&TLExpr::or(a.clone(), b.clone()));
711        assert_eq!(key1, key2);
712    }
713
714    #[test]
715    fn test_nested_and_consistent() {
716        // AND(AND(a,b), c) key == AND(a, AND(b,c)) key (both flatten + sort)
717        let a = pred_a();
718        let b = pred_b();
719        let c = pred_c();
720        let can = Canonicalizer::new();
721        let left_nested = TLExpr::and(TLExpr::and(a.clone(), b.clone()), c.clone());
722        let right_nested = TLExpr::and(a.clone(), TLExpr::and(b.clone(), c.clone()));
723        let key1 = can.canonical_key(&left_nested);
724        let key2 = can.canonical_key(&right_nested);
725        assert_eq!(key1, key2);
726    }
727
728    #[test]
729    fn test_canonical_key_deterministic() {
730        let expr = TLExpr::and(pred_a(), TLExpr::or(pred_b(), pred_c()));
731        let c = Canonicalizer::new();
732        let key1 = c.canonical_key(&expr);
733        let key2 = c.canonical_key(&expr);
734        assert_eq!(key1, key2);
735    }
736
737    #[test]
738    fn test_canonical_key_different_exprs() {
739        let c = Canonicalizer::new();
740        let key1 = c.canonical_key(&TLExpr::and(pred_a(), pred_b()));
741        let key2 = c.canonical_key(&TLExpr::or(pred_a(), pred_b()));
742        assert_ne!(key1, key2);
743    }
744
745    #[test]
746    fn test_stats_double_neg_counted() {
747        let expr = TLExpr::negate(TLExpr::negate(pred_a()));
748        let (_result, stats) = canonicalize(&expr);
749        assert_eq!(stats.double_neg_removed, 1);
750        assert!(stats.total_rewrites >= 1);
751    }
752
753    #[test]
754    fn test_stats_commutative_counted() {
755        // AND(b, a) should sort to AND(a, b), incrementing commutative_sorted
756        let a = pred_a();
757        let b = pred_b();
758        let expr = TLExpr::and(b, a);
759        let (_result, stats) = canonicalize(&expr);
760        assert_eq!(stats.commutative_sorted, 1);
761    }
762
763    #[test]
764    fn test_stats_merge() {
765        let mut s1 = CanonicalStats {
766            double_neg_removed: 2,
767            commutative_sorted: 1,
768            nested_flattened: 3,
769            total_rewrites: 6,
770        };
771        let s2 = CanonicalStats {
772            double_neg_removed: 1,
773            commutative_sorted: 4,
774            nested_flattened: 0,
775            total_rewrites: 5,
776        };
777        s1.merge(&s2);
778        assert_eq!(s1.double_neg_removed, 3);
779        assert_eq!(s1.commutative_sorted, 5);
780        assert_eq!(s1.nested_flattened, 3);
781        assert_eq!(s1.total_rewrites, 11);
782    }
783
784    #[test]
785    fn test_canonicalize_pred_unchanged() {
786        let p = pred_a();
787        let (result, stats) = canonicalize(&p);
788        assert_eq!(result, p);
789        assert_eq!(stats.total_rewrites, 0);
790    }
791
792    #[test]
793    fn test_canonicalize_constant_unchanged() {
794        let c = TLExpr::Constant(42.0);
795        let (result, stats) = canonicalize(&c);
796        assert_eq!(result, c);
797        assert_eq!(stats.total_rewrites, 0);
798    }
799
800    #[test]
801    fn test_canonicalize_exists_recurses() {
802        // exists body containing double neg should be canonicalized
803        let body = TLExpr::negate(TLExpr::negate(pred_a()));
804        let expr = TLExpr::exists("x", "D", body);
805        let (result, stats) = canonicalize(&expr);
806        assert_eq!(stats.double_neg_removed, 1);
807        if let TLExpr::Exists { body, .. } = &result {
808            assert!(matches!(body.as_ref(), TLExpr::Pred { .. }));
809        } else {
810            panic!("Expected Exists");
811        }
812    }
813
814    #[test]
815    fn test_canonicalize_forall_recurses() {
816        let body = TLExpr::negate(TLExpr::negate(pred_a()));
817        let expr = TLExpr::forall("x", "D", body);
818        let (result, stats) = canonicalize(&expr);
819        assert_eq!(stats.double_neg_removed, 1);
820        if let TLExpr::ForAll { body, .. } = &result {
821            assert!(matches!(body.as_ref(), TLExpr::Pred { .. }));
822        } else {
823            panic!("Expected ForAll");
824        }
825    }
826
827    #[test]
828    fn test_canonicalize_implication_recurses() {
829        let premise = TLExpr::negate(TLExpr::negate(pred_a()));
830        let conclusion = TLExpr::negate(TLExpr::negate(pred_b()));
831        let expr = TLExpr::imply(premise, conclusion);
832        let (result, stats) = canonicalize(&expr);
833        assert_eq!(stats.double_neg_removed, 2);
834        if let TLExpr::Imply(a, b) = &result {
835            assert!(matches!(a.as_ref(), TLExpr::Pred { .. }));
836            assert!(matches!(b.as_ref(), TLExpr::Pred { .. }));
837        } else {
838            panic!("Expected Imply");
839        }
840    }
841
842    #[test]
843    fn test_canonicalize_deep_nesting() {
844        // Build a deeply nested expression: AND(AND(AND(...), b), c)
845        let mut expr = pred_a();
846        for i in 0..50 {
847            let p = TLExpr::pred(format!("p{}", i), vec![Term::var("x")]);
848            expr = TLExpr::and(expr, p);
849        }
850        // Should not stack overflow
851        let (result, _stats) = canonicalize(&expr);
852        // Result should be valid (just check it doesn't panic)
853        let _ = canonical_order_key(&result);
854    }
855
856    #[test]
857    fn test_canonical_order_key_pred() {
858        let p = pred_a();
859        let key = canonical_order_key(&p);
860        assert!(
861            key.starts_with("P:"),
862            "Expected key to start with 'P:', got: {}",
863            key
864        );
865        assert!(key.contains("a"));
866    }
867
868    #[test]
869    fn test_canonical_order_key_constant() {
870        let c = TLExpr::Constant(42.5);
871        let key = canonical_order_key(&c);
872        assert!(
873            key.starts_with("C:"),
874            "Expected key to start with 'C:', got: {}",
875            key
876        );
877    }
878
879    #[test]
880    fn test_convenience_fn() {
881        let expr = TLExpr::negate(TLExpr::negate(pred_a()));
882        let (result, stats) = canonicalize(&expr);
883        assert_eq!(result, pred_a());
884        assert_eq!(stats.double_neg_removed, 1);
885    }
886
887    #[test]
888    fn test_disabled_rules() {
889        let a = pred_a();
890        let b = pred_b();
891        // With sort disabled, AND(b, a) should NOT be sorted
892        let c = Canonicalizer::new().with_sort_commutative(false);
893        let expr = TLExpr::and(b.clone(), a.clone());
894        let (result, stats) = c.canonicalize(&expr);
895        assert_eq!(stats.commutative_sorted, 0);
896        // The result should still be AND(b, a), not AND(a, b)
897        if let TLExpr::And(left, right) = &result {
898            assert_eq!(left.as_ref(), &b);
899            assert_eq!(right.as_ref(), &a);
900        } else {
901            panic!("Expected And");
902        }
903    }
904}