Skip to main content

tensorlogic_compiler/
expr_diff.rs

//! Structural diff between two `TLExpr` trees.
//!
//! Identifies additions, removals, and modifications at each
//! node in the expression tree. Useful for debugging incremental
//! compilation and tracking rule evolution.
6
7use serde::{Deserialize, Serialize};
8use tensorlogic_ir::TLExpr;
9
10/// Kind of change between two expression nodes.
11#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
12pub enum DiffKind {
13    /// No change
14    Unchanged,
15    /// Node was added (present in new, absent in old)
16    Added,
17    /// Node was removed (present in old, absent in new)
18    Removed,
19    /// Node type changed (e.g., And -> Or)
20    TypeChanged { old_type: String, new_type: String },
21    /// Node parameters changed (e.g., different predicate name)
22    ParameterChanged {
23        old_value: String,
24        new_value: String,
25    },
26    /// Children changed (recurse into sub-diffs)
27    ChildrenChanged,
28}
29
30impl DiffKind {
31    /// Returns `true` if this represents an actual change (not `Unchanged`).
32    pub fn is_change(&self) -> bool {
33        !matches!(self, DiffKind::Unchanged)
34    }
35}
36
37/// A single diff entry at a specific path in the expression tree.
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct DiffEntry {
40    /// Path from root (e.g., `["left", "body", "arg0"]`)
41    pub path: Vec<String>,
42    /// Kind of change
43    pub kind: DiffKind,
44    /// Human-readable description
45    pub description: String,
46}
47
48/// Complete diff result between two expressions.
49#[derive(Debug, Clone, Default, Serialize, Deserialize)]
50pub struct ExprDiff {
51    /// All diff entries found during comparison.
52    pub entries: Vec<DiffEntry>,
53}
54
55impl ExprDiff {
56    /// Create an empty diff.
57    pub fn new() -> Self {
58        Self::default()
59    }
60
61    /// Returns `true` if the two expressions are structurally identical.
62    pub fn is_identical(&self) -> bool {
63        self.entries.is_empty() || self.entries.iter().all(|e| !e.kind.is_change())
64    }
65
66    /// Count the number of actual changes (excluding `Unchanged` entries).
67    pub fn change_count(&self) -> usize {
68        self.entries.iter().filter(|e| e.kind.is_change()).count()
69    }
70
71    /// Filter to only addition entries.
72    pub fn additions(&self) -> Vec<&DiffEntry> {
73        self.entries
74            .iter()
75            .filter(|e| matches!(e.kind, DiffKind::Added))
76            .collect()
77    }
78
79    /// Filter to only removal entries.
80    pub fn removals(&self) -> Vec<&DiffEntry> {
81        self.entries
82            .iter()
83            .filter(|e| matches!(e.kind, DiffKind::Removed))
84            .collect()
85    }
86
87    /// Filter to only modification entries (TypeChanged or ParameterChanged).
88    pub fn modifications(&self) -> Vec<&DiffEntry> {
89        self.entries
90            .iter()
91            .filter(|e| {
92                matches!(
93                    e.kind,
94                    DiffKind::TypeChanged { .. } | DiffKind::ParameterChanged { .. }
95                )
96            })
97            .collect()
98    }
99
100    /// Human-readable summary of the diff.
101    pub fn summary(&self) -> String {
102        format!(
103            "{} changes ({} added, {} removed, {} modified)",
104            self.change_count(),
105            self.additions().len(),
106            self.removals().len(),
107            self.modifications().len()
108        )
109    }
110}
111
112/// Compute the structural diff between two expressions.
113pub fn expr_diff(old: &TLExpr, new: &TLExpr) -> ExprDiff {
114    let mut diff = ExprDiff::new();
115    compare_recursive(old, new, &[], &mut diff);
116    diff
117}
118
119/// Get a short type tag for an expression variant.
120pub fn expr_type_tag(expr: &TLExpr) -> String {
121    match expr {
122        TLExpr::Pred { .. } => "Pred".to_string(),
123        TLExpr::And(..) => "And".to_string(),
124        TLExpr::Or(..) => "Or".to_string(),
125        TLExpr::Not(..) => "Not".to_string(),
126        TLExpr::Exists { .. } => "Exists".to_string(),
127        TLExpr::ForAll { .. } => "ForAll".to_string(),
128        TLExpr::Imply(..) => "Imply".to_string(),
129        TLExpr::Score(..) => "Score".to_string(),
130        TLExpr::Add(..) => "Add".to_string(),
131        TLExpr::Sub(..) => "Sub".to_string(),
132        TLExpr::Mul(..) => "Mul".to_string(),
133        TLExpr::Div(..) => "Div".to_string(),
134        TLExpr::Pow(..) => "Pow".to_string(),
135        TLExpr::Mod(..) => "Mod".to_string(),
136        TLExpr::Min(..) => "Min".to_string(),
137        TLExpr::Max(..) => "Max".to_string(),
138        TLExpr::Abs(..) => "Abs".to_string(),
139        TLExpr::Floor(..) => "Floor".to_string(),
140        TLExpr::Ceil(..) => "Ceil".to_string(),
141        TLExpr::Round(..) => "Round".to_string(),
142        TLExpr::Sqrt(..) => "Sqrt".to_string(),
143        TLExpr::Exp(..) => "Exp".to_string(),
144        TLExpr::Log(..) => "Log".to_string(),
145        TLExpr::Sin(..) => "Sin".to_string(),
146        TLExpr::Cos(..) => "Cos".to_string(),
147        TLExpr::Tan(..) => "Tan".to_string(),
148        TLExpr::Eq(..) => "Eq".to_string(),
149        TLExpr::Lt(..) => "Lt".to_string(),
150        TLExpr::Gt(..) => "Gt".to_string(),
151        TLExpr::Lte(..) => "Lte".to_string(),
152        TLExpr::Gte(..) => "Gte".to_string(),
153        TLExpr::IfThenElse { .. } => "IfThenElse".to_string(),
154        TLExpr::Constant(..) => "Constant".to_string(),
155        TLExpr::Aggregate { .. } => "Aggregate".to_string(),
156        TLExpr::Let { .. } => "Let".to_string(),
157        TLExpr::Box(..) => "Box".to_string(),
158        TLExpr::Diamond(..) => "Diamond".to_string(),
159        TLExpr::Next(..) => "Next".to_string(),
160        TLExpr::Eventually(..) => "Eventually".to_string(),
161        TLExpr::Always(..) => "Always".to_string(),
162        TLExpr::Until { .. } => "Until".to_string(),
163        TLExpr::TNorm { .. } => "TNorm".to_string(),
164        TLExpr::TCoNorm { .. } => "TCoNorm".to_string(),
165        TLExpr::FuzzyNot { .. } => "FuzzyNot".to_string(),
166        TLExpr::FuzzyImplication { .. } => "FuzzyImplication".to_string(),
167        TLExpr::SoftExists { .. } => "SoftExists".to_string(),
168        TLExpr::SoftForAll { .. } => "SoftForAll".to_string(),
169        TLExpr::WeightedRule { .. } => "WeightedRule".to_string(),
170        TLExpr::ProbabilisticChoice { .. } => "ProbabilisticChoice".to_string(),
171        TLExpr::Release { .. } => "Release".to_string(),
172        TLExpr::WeakUntil { .. } => "WeakUntil".to_string(),
173        TLExpr::StrongRelease { .. } => "StrongRelease".to_string(),
174        TLExpr::Lambda { .. } => "Lambda".to_string(),
175        TLExpr::Apply { .. } => "Apply".to_string(),
176        TLExpr::SetMembership { .. } => "SetMembership".to_string(),
177        TLExpr::SetUnion { .. } => "SetUnion".to_string(),
178        TLExpr::SetIntersection { .. } => "SetIntersection".to_string(),
179        TLExpr::SetDifference { .. } => "SetDifference".to_string(),
180        TLExpr::SetCardinality { .. } => "SetCardinality".to_string(),
181        TLExpr::EmptySet => "EmptySet".to_string(),
182        TLExpr::SetComprehension { .. } => "SetComprehension".to_string(),
183        TLExpr::CountingExists { .. } => "CountingExists".to_string(),
184        TLExpr::CountingForAll { .. } => "CountingForAll".to_string(),
185        TLExpr::ExactCount { .. } => "ExactCount".to_string(),
186        TLExpr::Majority { .. } => "Majority".to_string(),
187        TLExpr::LeastFixpoint { .. } => "LeastFixpoint".to_string(),
188        TLExpr::GreatestFixpoint { .. } => "GreatestFixpoint".to_string(),
189        TLExpr::Nominal { .. } => "Nominal".to_string(),
190        TLExpr::At { .. } => "At".to_string(),
191        TLExpr::Somewhere { .. } => "Somewhere".to_string(),
192        TLExpr::Everywhere { .. } => "Everywhere".to_string(),
193        TLExpr::AllDifferent { .. } => "AllDifferent".to_string(),
194        TLExpr::GlobalCardinality { .. } => "GlobalCardinality".to_string(),
195        TLExpr::Abducible { .. } => "Abducible".to_string(),
196        TLExpr::Explain { .. } => "Explain".to_string(),
197        TLExpr::SymbolLiteral(_) => "SymbolLiteral".to_string(),
198        TLExpr::Match { .. } => "Match".to_string(),
199    }
200}
201
202/// Compare two children at a named path position.
203fn compare_child(
204    old: &TLExpr,
205    new: &TLExpr,
206    parent_path: &[String],
207    child_name: &str,
208    diff: &mut ExprDiff,
209) {
210    let mut path = parent_path.to_vec();
211    path.push(child_name.to_string());
212    compare_recursive(old, new, &path, diff);
213}
214
215/// Record an addition entry at the given path.
216fn record_added(path: &[String], child_name: &str, desc: &str, diff: &mut ExprDiff) {
217    let mut p = path.to_vec();
218    p.push(child_name.to_string());
219    diff.entries.push(DiffEntry {
220        path: p,
221        kind: DiffKind::Added,
222        description: desc.to_string(),
223    });
224}
225
226/// Record a removal entry at the given path.
227fn record_removed(path: &[String], child_name: &str, desc: &str, diff: &mut ExprDiff) {
228    let mut p = path.to_vec();
229    p.push(child_name.to_string());
230    diff.entries.push(DiffEntry {
231        path: p,
232        kind: DiffKind::Removed,
233        description: desc.to_string(),
234    });
235}
236
237/// Compare arguments lists, reporting per-element changes.
238fn compare_args(
239    old_args: &[tensorlogic_ir::Term],
240    new_args: &[tensorlogic_ir::Term],
241    path: &[String],
242    diff: &mut ExprDiff,
243) {
244    let common_len = old_args.len().min(new_args.len());
245    for i in 0..common_len {
246        if old_args[i] != new_args[i] {
247            let mut p = path.to_vec();
248            p.push(format!("arg{}", i));
249            diff.entries.push(DiffEntry {
250                path: p,
251                kind: DiffKind::ParameterChanged {
252                    old_value: format!("{:?}", old_args[i]),
253                    new_value: format!("{:?}", new_args[i]),
254                },
255                description: format!("Arg {} changed", i),
256            });
257        }
258    }
259    for i in common_len..new_args.len() {
260        record_added(
261            path,
262            &format!("arg{}", i),
263            &format!("Arg {} added", i),
264            diff,
265        );
266    }
267    for i in common_len..old_args.len() {
268        record_removed(
269            path,
270            &format!("arg{}", i),
271            &format!("Arg {} removed", i),
272            diff,
273        );
274    }
275}
276
277/// Compare string parameters at a given field name.
278fn compare_string_param(
279    old_val: &str,
280    new_val: &str,
281    path: &[String],
282    field: &str,
283    label: &str,
284    diff: &mut ExprDiff,
285) {
286    if old_val != new_val {
287        let mut p = path.to_vec();
288        p.push(field.to_string());
289        diff.entries.push(DiffEntry {
290            path: p,
291            kind: DiffKind::ParameterChanged {
292                old_value: old_val.to_string(),
293                new_value: new_val.to_string(),
294            },
295            description: format!("{}: {} -> {}", label, old_val, new_val),
296        });
297    }
298}
299
300/// Compare f64 parameters at a given field name.
301fn compare_f64_param(
302    old_val: f64,
303    new_val: f64,
304    path: &[String],
305    field: &str,
306    label: &str,
307    diff: &mut ExprDiff,
308) {
309    if (old_val - new_val).abs() > f64::EPSILON {
310        let mut p = path.to_vec();
311        p.push(field.to_string());
312        diff.entries.push(DiffEntry {
313            path: p,
314            kind: DiffKind::ParameterChanged {
315                old_value: format!("{}", old_val),
316                new_value: format!("{}", new_val),
317            },
318            description: format!("{}: {} -> {}", label, old_val, new_val),
319        });
320    }
321}
322
323/// Compare usize parameters at a given field name.
324fn compare_usize_param(
325    old_val: usize,
326    new_val: usize,
327    path: &[String],
328    field: &str,
329    label: &str,
330    diff: &mut ExprDiff,
331) {
332    if old_val != new_val {
333        let mut p = path.to_vec();
334        p.push(field.to_string());
335        diff.entries.push(DiffEntry {
336            path: p,
337            kind: DiffKind::ParameterChanged {
338                old_value: format!("{}", old_val),
339                new_value: format!("{}", new_val),
340            },
341            description: format!("{}: {} -> {}", label, old_val, new_val),
342        });
343    }
344}
345
346/// Recursively compare two expression trees.
347fn compare_recursive(old: &TLExpr, new: &TLExpr, path: &[String], diff: &mut ExprDiff) {
348    let old_tag = expr_type_tag(old);
349    let new_tag = expr_type_tag(new);
350
351    if old_tag != new_tag {
352        diff.entries.push(DiffEntry {
353            path: path.to_vec(),
354            kind: DiffKind::TypeChanged {
355                old_type: old_tag.clone(),
356                new_type: new_tag.clone(),
357            },
358            description: format!("Changed from {} to {}", old_tag, new_tag),
359        });
360        return;
361    }
362
363    match (old, new) {
364        // Pred: compare name and args
365        (TLExpr::Pred { name: n1, args: a1 }, TLExpr::Pred { name: n2, args: a2 }) => {
366            compare_string_param(n1, n2, path, "name", "Predicate name", diff);
367            compare_args(a1, a2, path, diff);
368        }
369
370        // Binary logical/arithmetic ops
371        (TLExpr::And(l1, r1), TLExpr::And(l2, r2))
372        | (TLExpr::Or(l1, r1), TLExpr::Or(l2, r2))
373        | (TLExpr::Imply(l1, r1), TLExpr::Imply(l2, r2))
374        | (TLExpr::Add(l1, r1), TLExpr::Add(l2, r2))
375        | (TLExpr::Sub(l1, r1), TLExpr::Sub(l2, r2))
376        | (TLExpr::Mul(l1, r1), TLExpr::Mul(l2, r2))
377        | (TLExpr::Div(l1, r1), TLExpr::Div(l2, r2))
378        | (TLExpr::Pow(l1, r1), TLExpr::Pow(l2, r2))
379        | (TLExpr::Mod(l1, r1), TLExpr::Mod(l2, r2))
380        | (TLExpr::Min(l1, r1), TLExpr::Min(l2, r2))
381        | (TLExpr::Max(l1, r1), TLExpr::Max(l2, r2))
382        | (TLExpr::Eq(l1, r1), TLExpr::Eq(l2, r2))
383        | (TLExpr::Lt(l1, r1), TLExpr::Lt(l2, r2))
384        | (TLExpr::Gt(l1, r1), TLExpr::Gt(l2, r2))
385        | (TLExpr::Lte(l1, r1), TLExpr::Lte(l2, r2))
386        | (TLExpr::Gte(l1, r1), TLExpr::Gte(l2, r2)) => {
387            compare_child(l1, l2, path, "left", diff);
388            compare_child(r1, r2, path, "right", diff);
389        }
390
391        // Unary ops
392        (TLExpr::Not(c1), TLExpr::Not(c2))
393        | (TLExpr::Score(c1), TLExpr::Score(c2))
394        | (TLExpr::Abs(c1), TLExpr::Abs(c2))
395        | (TLExpr::Floor(c1), TLExpr::Floor(c2))
396        | (TLExpr::Ceil(c1), TLExpr::Ceil(c2))
397        | (TLExpr::Round(c1), TLExpr::Round(c2))
398        | (TLExpr::Sqrt(c1), TLExpr::Sqrt(c2))
399        | (TLExpr::Exp(c1), TLExpr::Exp(c2))
400        | (TLExpr::Log(c1), TLExpr::Log(c2))
401        | (TLExpr::Sin(c1), TLExpr::Sin(c2))
402        | (TLExpr::Cos(c1), TLExpr::Cos(c2))
403        | (TLExpr::Tan(c1), TLExpr::Tan(c2))
404        | (TLExpr::Box(c1), TLExpr::Box(c2))
405        | (TLExpr::Diamond(c1), TLExpr::Diamond(c2))
406        | (TLExpr::Next(c1), TLExpr::Next(c2))
407        | (TLExpr::Eventually(c1), TLExpr::Eventually(c2))
408        | (TLExpr::Always(c1), TLExpr::Always(c2)) => {
409            compare_child(c1, c2, path, "child", diff);
410        }
411
412        // Quantifiers: Exists / ForAll
413        (
414            TLExpr::Exists {
415                var: v1,
416                domain: d1,
417                body: b1,
418            },
419            TLExpr::Exists {
420                var: v2,
421                domain: d2,
422                body: b2,
423            },
424        )
425        | (
426            TLExpr::ForAll {
427                var: v1,
428                domain: d1,
429                body: b1,
430            },
431            TLExpr::ForAll {
432                var: v2,
433                domain: d2,
434                body: b2,
435            },
436        ) => {
437            compare_string_param(v1, v2, path, "var", "Variable", diff);
438            compare_string_param(d1, d2, path, "domain", "Domain", diff);
439            compare_child(b1, b2, path, "body", diff);
440        }
441
442        // Constant
443        (TLExpr::Constant(v1), TLExpr::Constant(v2)) => {
444            compare_f64_param(*v1, *v2, path, "value", "Constant", diff);
445        }
446
447        // IfThenElse
448        (
449            TLExpr::IfThenElse {
450                condition: c1,
451                then_branch: t1,
452                else_branch: e1,
453            },
454            TLExpr::IfThenElse {
455                condition: c2,
456                then_branch: t2,
457                else_branch: e2,
458            },
459        ) => {
460            compare_child(c1, c2, path, "condition", diff);
461            compare_child(t1, t2, path, "then_branch", diff);
462            compare_child(e1, e2, path, "else_branch", diff);
463        }
464
465        // Aggregate
466        (
467            TLExpr::Aggregate {
468                op: op1,
469                var: v1,
470                domain: d1,
471                body: b1,
472                group_by: g1,
473            },
474            TLExpr::Aggregate {
475                op: op2,
476                var: v2,
477                domain: d2,
478                body: b2,
479                group_by: g2,
480            },
481        ) => {
482            if op1 != op2 {
483                let mut p = path.to_vec();
484                p.push("op".to_string());
485                diff.entries.push(DiffEntry {
486                    path: p,
487                    kind: DiffKind::ParameterChanged {
488                        old_value: format!("{:?}", op1),
489                        new_value: format!("{:?}", op2),
490                    },
491                    description: format!("Aggregate op: {:?} -> {:?}", op1, op2),
492                });
493            }
494            compare_string_param(v1, v2, path, "var", "Variable", diff);
495            compare_string_param(d1, d2, path, "domain", "Domain", diff);
496            compare_child(b1, b2, path, "body", diff);
497            if g1 != g2 {
498                let mut p = path.to_vec();
499                p.push("group_by".to_string());
500                diff.entries.push(DiffEntry {
501                    path: p,
502                    kind: DiffKind::ParameterChanged {
503                        old_value: format!("{:?}", g1),
504                        new_value: format!("{:?}", g2),
505                    },
506                    description: "Group-by changed".to_string(),
507                });
508            }
509        }
510
511        // Let binding
512        (
513            TLExpr::Let {
514                var: v1,
515                value: val1,
516                body: b1,
517            },
518            TLExpr::Let {
519                var: v2,
520                value: val2,
521                body: b2,
522            },
523        ) => {
524            compare_string_param(v1, v2, path, "var", "Variable", diff);
525            compare_child(val1, val2, path, "value", diff);
526            compare_child(b1, b2, path, "body", diff);
527        }
528
529        // Until / WeakUntil
530        (
531            TLExpr::Until {
532                before: b1,
533                after: a1,
534            },
535            TLExpr::Until {
536                before: b2,
537                after: a2,
538            },
539        )
540        | (
541            TLExpr::WeakUntil {
542                before: b1,
543                after: a1,
544            },
545            TLExpr::WeakUntil {
546                before: b2,
547                after: a2,
548            },
549        ) => {
550            compare_child(b1, b2, path, "before", diff);
551            compare_child(a1, a2, path, "after", diff);
552        }
553
554        // Release / StrongRelease
555        (
556            TLExpr::Release {
557                released: r1,
558                releaser: l1,
559            },
560            TLExpr::Release {
561                released: r2,
562                releaser: l2,
563            },
564        )
565        | (
566            TLExpr::StrongRelease {
567                released: r1,
568                releaser: l1,
569            },
570            TLExpr::StrongRelease {
571                released: r2,
572                releaser: l2,
573            },
574        ) => {
575            compare_child(r1, r2, path, "released", diff);
576            compare_child(l1, l2, path, "releaser", diff);
577        }
578
579        // TNorm
580        (
581            TLExpr::TNorm {
582                kind: k1,
583                left: l1,
584                right: r1,
585            },
586            TLExpr::TNorm {
587                kind: k2,
588                left: l2,
589                right: r2,
590            },
591        ) => {
592            if k1 != k2 {
593                let mut p = path.to_vec();
594                p.push("kind".to_string());
595                diff.entries.push(DiffEntry {
596                    path: p,
597                    kind: DiffKind::ParameterChanged {
598                        old_value: format!("{:?}", k1),
599                        new_value: format!("{:?}", k2),
600                    },
601                    description: format!("TNorm kind: {:?} -> {:?}", k1, k2),
602                });
603            }
604            compare_child(l1, l2, path, "left", diff);
605            compare_child(r1, r2, path, "right", diff);
606        }
607
608        // TCoNorm
609        (
610            TLExpr::TCoNorm {
611                kind: k1,
612                left: l1,
613                right: r1,
614            },
615            TLExpr::TCoNorm {
616                kind: k2,
617                left: l2,
618                right: r2,
619            },
620        ) => {
621            if k1 != k2 {
622                let mut p = path.to_vec();
623                p.push("kind".to_string());
624                diff.entries.push(DiffEntry {
625                    path: p,
626                    kind: DiffKind::ParameterChanged {
627                        old_value: format!("{:?}", k1),
628                        new_value: format!("{:?}", k2),
629                    },
630                    description: format!("TCoNorm kind: {:?} -> {:?}", k1, k2),
631                });
632            }
633            compare_child(l1, l2, path, "left", diff);
634            compare_child(r1, r2, path, "right", diff);
635        }
636
637        // FuzzyNot
638        (TLExpr::FuzzyNot { kind: k1, expr: e1 }, TLExpr::FuzzyNot { kind: k2, expr: e2 }) => {
639            if k1 != k2 {
640                let mut p = path.to_vec();
641                p.push("kind".to_string());
642                diff.entries.push(DiffEntry {
643                    path: p,
644                    kind: DiffKind::ParameterChanged {
645                        old_value: format!("{:?}", k1),
646                        new_value: format!("{:?}", k2),
647                    },
648                    description: format!("FuzzyNot kind: {:?} -> {:?}", k1, k2),
649                });
650            }
651            compare_child(e1, e2, path, "expr", diff);
652        }
653
654        // FuzzyImplication
655        (
656            TLExpr::FuzzyImplication {
657                kind: k1,
658                premise: p1,
659                conclusion: c1,
660            },
661            TLExpr::FuzzyImplication {
662                kind: k2,
663                premise: p2,
664                conclusion: c2,
665            },
666        ) => {
667            if k1 != k2 {
668                let mut p = path.to_vec();
669                p.push("kind".to_string());
670                diff.entries.push(DiffEntry {
671                    path: p,
672                    kind: DiffKind::ParameterChanged {
673                        old_value: format!("{:?}", k1),
674                        new_value: format!("{:?}", k2),
675                    },
676                    description: format!("FuzzyImplication kind: {:?} -> {:?}", k1, k2),
677                });
678            }
679            compare_child(p1, p2, path, "premise", diff);
680            compare_child(c1, c2, path, "conclusion", diff);
681        }
682
683        // SoftExists
684        (
685            TLExpr::SoftExists {
686                var: v1,
687                domain: d1,
688                body: b1,
689                temperature: t1,
690            },
691            TLExpr::SoftExists {
692                var: v2,
693                domain: d2,
694                body: b2,
695                temperature: t2,
696            },
697        ) => {
698            compare_string_param(v1, v2, path, "var", "Variable", diff);
699            compare_string_param(d1, d2, path, "domain", "Domain", diff);
700            compare_child(b1, b2, path, "body", diff);
701            compare_f64_param(*t1, *t2, path, "temperature", "Temperature", diff);
702        }
703
704        // SoftForAll
705        (
706            TLExpr::SoftForAll {
707                var: v1,
708                domain: d1,
709                body: b1,
710                temperature: t1,
711            },
712            TLExpr::SoftForAll {
713                var: v2,
714                domain: d2,
715                body: b2,
716                temperature: t2,
717            },
718        ) => {
719            compare_string_param(v1, v2, path, "var", "Variable", diff);
720            compare_string_param(d1, d2, path, "domain", "Domain", diff);
721            compare_child(b1, b2, path, "body", diff);
722            compare_f64_param(*t1, *t2, path, "temperature", "Temperature", diff);
723        }
724
725        // WeightedRule
726        (
727            TLExpr::WeightedRule {
728                weight: w1,
729                rule: r1,
730            },
731            TLExpr::WeightedRule {
732                weight: w2,
733                rule: r2,
734            },
735        ) => {
736            compare_f64_param(*w1, *w2, path, "weight", "Weight", diff);
737            compare_child(r1, r2, path, "rule", diff);
738        }
739
740        // ProbabilisticChoice
741        (
742            TLExpr::ProbabilisticChoice { alternatives: a1 },
743            TLExpr::ProbabilisticChoice { alternatives: a2 },
744        ) => {
745            let common_len = a1.len().min(a2.len());
746            for i in 0..common_len {
747                compare_f64_param(
748                    a1[i].0,
749                    a2[i].0,
750                    path,
751                    &format!("alt{}_prob", i),
752                    &format!("Alternative {} probability", i),
753                    diff,
754                );
755                compare_child(&a1[i].1, &a2[i].1, path, &format!("alt{}_expr", i), diff);
756            }
757            for i in common_len..a2.len() {
758                record_added(
759                    path,
760                    &format!("alt{}", i),
761                    &format!("Alternative {} added", i),
762                    diff,
763                );
764            }
765            for i in common_len..a1.len() {
766                record_removed(
767                    path,
768                    &format!("alt{}", i),
769                    &format!("Alternative {} removed", i),
770                    diff,
771                );
772            }
773        }
774
775        // Lambda
776        (
777            TLExpr::Lambda {
778                var: v1,
779                var_type: t1,
780                body: b1,
781            },
782            TLExpr::Lambda {
783                var: v2,
784                var_type: t2,
785                body: b2,
786            },
787        ) => {
788            compare_string_param(v1, v2, path, "var", "Variable", diff);
789            if t1 != t2 {
790                let mut p = path.to_vec();
791                p.push("var_type".to_string());
792                diff.entries.push(DiffEntry {
793                    path: p,
794                    kind: DiffKind::ParameterChanged {
795                        old_value: format!("{:?}", t1),
796                        new_value: format!("{:?}", t2),
797                    },
798                    description: format!("Type annotation: {:?} -> {:?}", t1, t2),
799                });
800            }
801            compare_child(b1, b2, path, "body", diff);
802        }
803
804        // Apply
805        (
806            TLExpr::Apply {
807                function: f1,
808                argument: a1,
809            },
810            TLExpr::Apply {
811                function: f2,
812                argument: a2,
813            },
814        ) => {
815            compare_child(f1, f2, path, "function", diff);
816            compare_child(a1, a2, path, "argument", diff);
817        }
818
819        // Set ops with left/right
820        (
821            TLExpr::SetMembership {
822                element: e1,
823                set: s1,
824            },
825            TLExpr::SetMembership {
826                element: e2,
827                set: s2,
828            },
829        ) => {
830            compare_child(e1, e2, path, "element", diff);
831            compare_child(s1, s2, path, "set", diff);
832        }
833
834        (
835            TLExpr::SetUnion {
836                left: l1,
837                right: r1,
838            },
839            TLExpr::SetUnion {
840                left: l2,
841                right: r2,
842            },
843        )
844        | (
845            TLExpr::SetIntersection {
846                left: l1,
847                right: r1,
848            },
849            TLExpr::SetIntersection {
850                left: l2,
851                right: r2,
852            },
853        )
854        | (
855            TLExpr::SetDifference {
856                left: l1,
857                right: r1,
858            },
859            TLExpr::SetDifference {
860                left: l2,
861                right: r2,
862            },
863        ) => {
864            compare_child(l1, l2, path, "left", diff);
865            compare_child(r1, r2, path, "right", diff);
866        }
867
868        // SetCardinality
869        (TLExpr::SetCardinality { set: s1 }, TLExpr::SetCardinality { set: s2 }) => {
870            compare_child(s1, s2, path, "set", diff);
871        }
872
873        // EmptySet
874        (TLExpr::EmptySet, TLExpr::EmptySet) => {
875            // identical
876        }
877
878        // SetComprehension
879        (
880            TLExpr::SetComprehension {
881                var: v1,
882                domain: d1,
883                condition: c1,
884            },
885            TLExpr::SetComprehension {
886                var: v2,
887                domain: d2,
888                condition: c2,
889            },
890        ) => {
891            compare_string_param(v1, v2, path, "var", "Variable", diff);
892            compare_string_param(d1, d2, path, "domain", "Domain", diff);
893            compare_child(c1, c2, path, "condition", diff);
894        }
895
896        // Counting quantifiers
897        (
898            TLExpr::CountingExists {
899                var: v1,
900                domain: d1,
901                body: b1,
902                min_count: mc1,
903            },
904            TLExpr::CountingExists {
905                var: v2,
906                domain: d2,
907                body: b2,
908                min_count: mc2,
909            },
910        )
911        | (
912            TLExpr::CountingForAll {
913                var: v1,
914                domain: d1,
915                body: b1,
916                min_count: mc1,
917            },
918            TLExpr::CountingForAll {
919                var: v2,
920                domain: d2,
921                body: b2,
922                min_count: mc2,
923            },
924        ) => {
925            compare_string_param(v1, v2, path, "var", "Variable", diff);
926            compare_string_param(d1, d2, path, "domain", "Domain", diff);
927            compare_child(b1, b2, path, "body", diff);
928            compare_usize_param(*mc1, *mc2, path, "min_count", "Min count", diff);
929        }
930
931        // ExactCount
932        (
933            TLExpr::ExactCount {
934                var: v1,
935                domain: d1,
936                body: b1,
937                count: c1,
938            },
939            TLExpr::ExactCount {
940                var: v2,
941                domain: d2,
942                body: b2,
943                count: c2,
944            },
945        ) => {
946            compare_string_param(v1, v2, path, "var", "Variable", diff);
947            compare_string_param(d1, d2, path, "domain", "Domain", diff);
948            compare_child(b1, b2, path, "body", diff);
949            compare_usize_param(*c1, *c2, path, "count", "Count", diff);
950        }
951
952        // Majority
953        (
954            TLExpr::Majority {
955                var: v1,
956                domain: d1,
957                body: b1,
958            },
959            TLExpr::Majority {
960                var: v2,
961                domain: d2,
962                body: b2,
963            },
964        ) => {
965            compare_string_param(v1, v2, path, "var", "Variable", diff);
966            compare_string_param(d1, d2, path, "domain", "Domain", diff);
967            compare_child(b1, b2, path, "body", diff);
968        }
969
970        // Fixed-point operators
971        (
972            TLExpr::LeastFixpoint { var: v1, body: b1 },
973            TLExpr::LeastFixpoint { var: v2, body: b2 },
974        )
975        | (
976            TLExpr::GreatestFixpoint { var: v1, body: b1 },
977            TLExpr::GreatestFixpoint { var: v2, body: b2 },
978        ) => {
979            compare_string_param(v1, v2, path, "var", "Variable", diff);
980            compare_child(b1, b2, path, "body", diff);
981        }
982
983        // Nominal
984        (TLExpr::Nominal { name: n1 }, TLExpr::Nominal { name: n2 }) => {
985            compare_string_param(n1, n2, path, "name", "Nominal name", diff);
986        }
987
988        // At
989        (
990            TLExpr::At {
991                nominal: n1,
992                formula: f1,
993            },
994            TLExpr::At {
995                nominal: n2,
996                formula: f2,
997            },
998        ) => {
999            compare_string_param(n1, n2, path, "nominal", "Nominal", diff);
1000            compare_child(f1, f2, path, "formula", diff);
1001        }
1002
1003        // Somewhere / Everywhere
1004        (TLExpr::Somewhere { formula: f1 }, TLExpr::Somewhere { formula: f2 })
1005        | (TLExpr::Everywhere { formula: f1 }, TLExpr::Everywhere { formula: f2 }) => {
1006            compare_child(f1, f2, path, "formula", diff);
1007        }
1008
1009        // Explain
1010        (TLExpr::Explain { formula: f1 }, TLExpr::Explain { formula: f2 }) => {
1011            compare_child(f1, f2, path, "formula", diff);
1012        }
1013
1014        // AllDifferent
1015        (TLExpr::AllDifferent { variables: vars1 }, TLExpr::AllDifferent { variables: vars2 }) => {
1016            if vars1 != vars2 {
1017                let mut p = path.to_vec();
1018                p.push("variables".to_string());
1019                diff.entries.push(DiffEntry {
1020                    path: p,
1021                    kind: DiffKind::ParameterChanged {
1022                        old_value: format!("{:?}", vars1),
1023                        new_value: format!("{:?}", vars2),
1024                    },
1025                    description: "Variables list changed".to_string(),
1026                });
1027            }
1028        }
1029
1030        // GlobalCardinality
1031        (
1032            TLExpr::GlobalCardinality {
1033                variables: vars1,
1034                values: vals1,
1035                min_occurrences: min1,
1036                max_occurrences: max1,
1037            },
1038            TLExpr::GlobalCardinality {
1039                variables: vars2,
1040                values: vals2,
1041                min_occurrences: min2,
1042                max_occurrences: max2,
1043            },
1044        ) => {
1045            if vars1 != vars2 {
1046                let mut p = path.to_vec();
1047                p.push("variables".to_string());
1048                diff.entries.push(DiffEntry {
1049                    path: p,
1050                    kind: DiffKind::ParameterChanged {
1051                        old_value: format!("{:?}", vars1),
1052                        new_value: format!("{:?}", vars2),
1053                    },
1054                    description: "Variables list changed".to_string(),
1055                });
1056            }
1057            if vals1 != vals2 {
1058                let mut p = path.to_vec();
1059                p.push("values".to_string());
1060                diff.entries.push(DiffEntry {
1061                    path: p,
1062                    kind: DiffKind::ParameterChanged {
1063                        old_value: format!("{:?}", vals1),
1064                        new_value: format!("{:?}", vals2),
1065                    },
1066                    description: "Values list changed".to_string(),
1067                });
1068            }
1069            if min1 != min2 {
1070                let mut p = path.to_vec();
1071                p.push("min_occurrences".to_string());
1072                diff.entries.push(DiffEntry {
1073                    path: p,
1074                    kind: DiffKind::ParameterChanged {
1075                        old_value: format!("{:?}", min1),
1076                        new_value: format!("{:?}", min2),
1077                    },
1078                    description: "Min occurrences changed".to_string(),
1079                });
1080            }
1081            if max1 != max2 {
1082                let mut p = path.to_vec();
1083                p.push("max_occurrences".to_string());
1084                diff.entries.push(DiffEntry {
1085                    path: p,
1086                    kind: DiffKind::ParameterChanged {
1087                        old_value: format!("{:?}", max1),
1088                        new_value: format!("{:?}", max2),
1089                    },
1090                    description: "Max occurrences changed".to_string(),
1091                });
1092            }
1093        }
1094
1095        // Abducible
1096        (TLExpr::Abducible { name: n1, cost: c1 }, TLExpr::Abducible { name: n2, cost: c2 }) => {
1097            compare_string_param(n1, n2, path, "name", "Abducible name", diff);
1098            compare_f64_param(*c1, *c2, path, "cost", "Cost", diff);
1099        }
1100
1101        // SymbolLiteral
1102        (TLExpr::SymbolLiteral(s1), TLExpr::SymbolLiteral(s2)) => {
1103            compare_string_param(s1, s2, path, "symbol", "Symbol", diff);
1104        }
1105
1106        // Match
1107        (
1108            TLExpr::Match {
1109                scrutinee: sc1,
1110                arms: a1,
1111            },
1112            TLExpr::Match {
1113                scrutinee: sc2,
1114                arms: a2,
1115            },
1116        ) => {
1117            compare_child(sc1, sc2, path, "scrutinee", diff);
1118            if a1.len() != a2.len() {
1119                diff.entries.push(DiffEntry {
1120                    path: path.to_vec(),
1121                    kind: DiffKind::ParameterChanged {
1122                        old_value: format!("{} arms", a1.len()),
1123                        new_value: format!("{} arms", a2.len()),
1124                    },
1125                    description: "Match arm count changed".to_string(),
1126                });
1127            } else {
1128                for (i, ((p1, b1), (p2, b2))) in a1.iter().zip(a2.iter()).enumerate() {
1129                    if p1 != p2 {
1130                        diff.entries.push(DiffEntry {
1131                            path: path.to_vec(),
1132                            kind: DiffKind::ParameterChanged {
1133                                old_value: format!("{p1}"),
1134                                new_value: format!("{p2}"),
1135                            },
1136                            description: format!("arm[{i}] pattern changed"),
1137                        });
1138                    }
1139                    compare_child(b1, b2, path, &format!("arm[{i}]"), diff);
1140                }
1141            }
1142        }
1143
1144        // Catch-all: same type tag but not matched above (should not happen
1145        // if all variants are covered, but kept for safety)
1146        _ => {
1147            let old_dbg = format!("{:?}", old);
1148            let new_dbg = format!("{:?}", new);
1149            if old_dbg != new_dbg {
1150                diff.entries.push(DiffEntry {
1151                    path: path.to_vec(),
1152                    kind: DiffKind::ParameterChanged {
1153                        old_value: old_dbg,
1154                        new_value: new_dbg,
1155                    },
1156                    description: "Expression content changed".to_string(),
1157                });
1158            }
1159        }
1160    }
1161}
1162
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// `a(x)` — baseline predicate shared by most tests.
    fn pred_a() -> TLExpr {
        TLExpr::pred("a", vec![Term::var("x")])
    }

    /// `b(x)` — differs from `pred_a` only in the predicate name.
    fn pred_b() -> TLExpr {
        TLExpr::pred("b", vec![Term::var("x")])
    }

    /// `c(y)` — differs from `pred_b` in both predicate name and argument.
    fn pred_c() -> TLExpr {
        TLExpr::pred("c", vec![Term::var("y")])
    }

    /// Returns `true` if `entry` records a `ParameterChanged` from `old` to `new`.
    fn is_param_change(entry: &DiffEntry, old: &str, new: &str) -> bool {
        matches!(&entry.kind, DiffKind::ParameterChanged { old_value, new_value }
            if old_value == old && new_value == new)
    }

    /// Returns `true` if the first path segment of `entry` equals `segment`.
    fn path_starts_with(entry: &DiffEntry, segment: &str) -> bool {
        entry.path.first().is_some_and(|p| p == segment)
    }

    #[test]
    fn test_diff_identical() {
        let e = pred_a();
        let diff = expr_diff(&e, &e);
        assert!(diff.is_identical());
        assert_eq!(diff.change_count(), 0);
    }

    #[test]
    fn test_diff_different_type() {
        let old = TLExpr::and(pred_a(), pred_b());
        let new = TLExpr::or(pred_a(), pred_b());
        let diff = expr_diff(&old, &new);
        assert!(!diff.is_identical());
        assert!(diff.entries.iter().any(
            |e| matches!(&e.kind, DiffKind::TypeChanged { old_type, new_type }
                if old_type == "And" && new_type == "Or")
        ));
    }

    #[test]
    fn test_diff_pred_name_change() {
        let old = TLExpr::pred("a", vec![Term::var("x")]);
        let new = TLExpr::pred("b", vec![Term::var("x")]);
        let diff = expr_diff(&old, &new);
        assert!(!diff.is_identical());
        assert!(diff.entries.iter().any(|e| is_param_change(e, "a", "b")));
    }

    #[test]
    fn test_diff_pred_arg_change() {
        let old = TLExpr::pred("p", vec![Term::var("x")]);
        let new = TLExpr::pred("p", vec![Term::var("y")]);
        let diff = expr_diff(&old, &new);
        assert!(!diff.is_identical());
        assert_eq!(diff.change_count(), 1);
        let entry = &diff.entries[0];
        assert_eq!(entry.path, vec!["arg0".to_string()]);
        assert!(matches!(&entry.kind, DiffKind::ParameterChanged { .. }));
    }

    #[test]
    fn test_diff_constant_change() {
        let old = TLExpr::Constant(1.0);
        let new = TLExpr::Constant(2.0);
        let diff = expr_diff(&old, &new);
        assert!(!diff.is_identical());
        assert_eq!(diff.change_count(), 1);
        assert!(diff
            .entries
            .iter()
            .any(|e| matches!(&e.kind, DiffKind::ParameterChanged { .. })));
    }

    #[test]
    fn test_diff_added_not() {
        let old = pred_a();
        let new = TLExpr::negate(pred_a());
        let diff = expr_diff(&old, &new);
        assert!(!diff.is_identical());
        assert!(diff.entries.iter().any(
            |e| matches!(&e.kind, DiffKind::TypeChanged { old_type, new_type }
                if old_type == "Pred" && new_type == "Not")
        ));
    }

    #[test]
    fn test_diff_change_count() {
        let old = TLExpr::and(pred_a(), pred_b());
        let new = TLExpr::and(pred_b(), pred_c());
        let diff = expr_diff(&old, &new);
        // left child: name a->b; right child: name b->c, arg0 x->y
        assert!(diff.change_count() >= 2);
        // Each rename must be attributed to the child it happened in, not
        // merely counted somewhere in the diff.
        assert!(diff
            .entries
            .iter()
            .any(|e| path_starts_with(e, "left") && is_param_change(e, "a", "b")));
        assert!(diff
            .entries
            .iter()
            .any(|e| path_starts_with(e, "right") && is_param_change(e, "b", "c")));
    }

    #[test]
    fn test_diff_summary() {
        let old = TLExpr::Constant(1.0);
        let new = TLExpr::Constant(2.0);
        let diff = expr_diff(&old, &new);
        let s = diff.summary();
        assert!(s.contains("changes"));
        assert!(s.contains("modified"));
    }

    #[test]
    fn test_diff_additions() {
        let mut diff = ExprDiff::new();
        diff.entries.push(DiffEntry {
            path: vec!["a".to_string()],
            kind: DiffKind::Added,
            description: "added".to_string(),
        });
        diff.entries.push(DiffEntry {
            path: vec!["b".to_string()],
            kind: DiffKind::Removed,
            description: "removed".to_string(),
        });
        assert_eq!(diff.additions().len(), 1);
        assert_eq!(diff.additions()[0].path, vec!["a".to_string()]);
    }

    #[test]
    fn test_diff_removals() {
        let mut diff = ExprDiff::new();
        diff.entries.push(DiffEntry {
            path: vec!["a".to_string()],
            kind: DiffKind::Added,
            description: "added".to_string(),
        });
        diff.entries.push(DiffEntry {
            path: vec!["b".to_string()],
            kind: DiffKind::Removed,
            description: "removed".to_string(),
        });
        assert_eq!(diff.removals().len(), 1);
        assert_eq!(diff.removals()[0].path, vec!["b".to_string()]);
    }

    #[test]
    fn test_diff_modifications() {
        let mut diff = ExprDiff::new();
        diff.entries.push(DiffEntry {
            path: vec![],
            kind: DiffKind::TypeChanged {
                old_type: "And".to_string(),
                new_type: "Or".to_string(),
            },
            description: "type".to_string(),
        });
        diff.entries.push(DiffEntry {
            path: vec![],
            kind: DiffKind::ParameterChanged {
                old_value: "a".to_string(),
                new_value: "b".to_string(),
            },
            description: "param".to_string(),
        });
        diff.entries.push(DiffEntry {
            path: vec![],
            kind: DiffKind::Added,
            description: "added".to_string(),
        });
        assert_eq!(diff.modifications().len(), 2);
    }

    #[test]
    fn test_diff_kind_is_change() {
        assert!(!DiffKind::Unchanged.is_change());
        assert!(DiffKind::Added.is_change());
        assert!(DiffKind::Removed.is_change());
        assert!(DiffKind::ChildrenChanged.is_change());
        assert!((DiffKind::TypeChanged {
            old_type: "A".to_string(),
            new_type: "B".to_string(),
        })
        .is_change());
        assert!((DiffKind::ParameterChanged {
            old_value: "a".to_string(),
            new_value: "b".to_string(),
        })
        .is_change());
    }

    #[test]
    fn test_diff_nested_change() {
        let old = TLExpr::and(pred_a(), pred_b());
        let new = TLExpr::and(pred_a(), pred_c());
        let diff = expr_diff(&old, &new);
        assert!(!diff.is_identical());
        // Right child changed: name b->c, arg y instead of x
        assert!(diff.change_count() >= 1);
        // Check that at least one path starts with "right"
        assert!(diff.entries.iter().any(|e| path_starts_with(e, "right")));
        // The rename must be reported under the right child.
        assert!(diff
            .entries
            .iter()
            .any(|e| path_starts_with(e, "right") && is_param_change(e, "b", "c")));
        // The left child is identical and must not contribute any change.
        assert!(!diff
            .entries
            .iter()
            .any(|e| path_starts_with(e, "left") && e.kind.is_change()));
    }

    #[test]
    fn test_diff_quantifier_change() {
        let body = pred_a();
        let old = TLExpr::exists("x", "D", body.clone());
        let new = TLExpr::forall("x", "D", body);
        let diff = expr_diff(&old, &new);
        assert!(!diff.is_identical());
        assert!(diff.entries.iter().any(
            |e| matches!(&e.kind, DiffKind::TypeChanged { old_type, new_type }
                if old_type == "Exists" && new_type == "ForAll")
        ));
    }

    #[test]
    fn test_diff_entry_path() {
        let old = TLExpr::and(TLExpr::or(pred_a(), pred_b()), TLExpr::Constant(1.0));
        let new = TLExpr::and(TLExpr::or(pred_a(), pred_c()), TLExpr::Constant(1.0));
        let diff = expr_diff(&old, &new);
        assert!(!diff.is_identical());
        // The change is at left -> right -> (name or arg0)
        assert!(diff
            .entries
            .iter()
            .any(|e| e.path.len() >= 2 && e.path[0] == "left" && e.path[1] == "right"));
    }

    #[test]
    fn test_diff_empty_set_identical() {
        let diff = expr_diff(&TLExpr::EmptySet, &TLExpr::EmptySet);
        assert!(diff.is_identical());
        assert_eq!(diff.change_count(), 0);
    }

    #[test]
    fn test_diff_nominal_name_change() {
        let old = TLExpr::Nominal {
            name: "w1".to_string(),
        };
        let new = TLExpr::Nominal {
            name: "w2".to_string(),
        };
        let diff = expr_diff(&old, &new);
        assert!(!diff.is_identical());
        assert!(diff.entries.iter().any(|e| is_param_change(e, "w1", "w2")));
    }

    #[test]
    fn test_diff_abducible_cost_change() {
        let old = TLExpr::Abducible {
            name: "h".to_string(),
            cost: 1.0,
        };
        // Same name and cost: no change.
        assert!(expr_diff(&old, &old.clone()).is_identical());

        let new = TLExpr::Abducible {
            name: "h".to_string(),
            cost: 2.5,
        };
        let diff = expr_diff(&old, &new);
        assert!(!diff.is_identical());
        assert_eq!(diff.change_count(), 1);
    }

    #[test]
    fn test_diff_all_different_variables_change() {
        let old = TLExpr::AllDifferent {
            variables: vec!["x".to_string(), "y".to_string()],
        };
        let new = TLExpr::AllDifferent {
            variables: vec!["x".to_string(), "z".to_string()],
        };
        let diff = expr_diff(&old, &new);
        assert_eq!(diff.change_count(), 1);
        let entry = &diff.entries[0];
        assert_eq!(entry.path, vec!["variables".to_string()]);
        assert!(matches!(&entry.kind, DiffKind::ParameterChanged { .. }));
    }

    #[test]
    fn test_expr_type_tag_pred() {
        let e = pred_a();
        assert_eq!(expr_type_tag(&e), "Pred");
    }

    #[test]
    fn test_expr_type_tag_and() {
        let e = TLExpr::and(pred_a(), pred_b());
        assert_eq!(expr_type_tag(&e), "And");
    }

    #[test]
    fn test_diff_default_empty() {
        let diff = ExprDiff::new();
        assert!(diff.entries.is_empty());
        assert!(diff.is_identical());
        assert_eq!(diff.change_count(), 0);
    }
}