Skip to main content

tensorlogic_ir/
diff.rs

1//! IR diff tool for comparing graphs and expressions.
2//!
3//! This module provides utilities to compare two IR structures and
4//! identify differences, useful for debugging and validation.
5
6use crate::{EinsumGraph, EinsumNode, OpType, TLExpr};
7use std::collections::HashSet;
8
/// Result of comparing two expressions with [`diff_exprs`].
///
/// Apart from [`ExprDiff::Identical`], every variant describes a single
/// difference: the comparison stops at the first one it finds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExprDiff {
    /// The two expressions are identical.
    Identical,
    /// The two expressions are different kinds of node.
    TypeMismatch {
        /// Kind of the left expression (taken from its `Debug` output).
        left: String,
        /// Kind of the right expression (taken from its `Debug` output).
        right: String,
    },
    /// Both sides are predicates, but their names or arities differ.
    PredicateMismatch {
        /// Left predicate rendered as `name(arity)`.
        left: String,
        /// Right predicate rendered as `name(arity)`.
        right: String,
    },
    /// Two corresponding subexpressions differ.
    SubexprMismatch {
        /// Child labels leading from the root to the differing subexpression.
        path: Vec<String>,
        /// Rendering of the left-hand value at that location.
        left: String,
        /// Rendering of the right-hand value at that location.
        right: String,
    },
    /// Both sides are quantifiers, but their bound variables or domains differ.
    QuantifierMismatch {
        /// Variable bound on the left.
        left_var: String,
        /// Variable bound on the right.
        right_var: String,
        /// Domain quantified over on the left.
        left_domain: String,
        /// Domain quantified over on the right.
        right_domain: String,
    },
}
32
/// Result of comparing two graphs with [`diff_graphs`].
///
/// `GraphDiff::default()` has no recorded differences in any category.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GraphDiff {
    /// Tensor names present only in the left graph.
    pub left_only_tensors: Vec<String>,
    /// Tensor names present only in the right graph.
    pub right_only_tensors: Vec<String>,
    /// Number of surplus nodes in the left graph (nodes beyond the length of
    /// the right graph's node list).
    pub left_only_nodes: usize,
    /// Number of surplus nodes in the right graph (nodes beyond the length of
    /// the left graph's node list).
    pub right_only_nodes: usize,
    /// Nodes at the same position in both graphs that differ.
    pub node_differences: Vec<NodeDiff>,
    /// Human-readable descriptions of output-list differences.
    pub output_differences: Vec<String>,
}

/// Difference found between the nodes at one position of both graphs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NodeDiff {
    /// Position of the compared node in both graphs' node lists.
    pub node_index: usize,
    /// Human-readable explanation of what differs.
    pub description: String,
}
56
57impl ExprDiff {
58    /// Check if expressions are identical
59    pub fn is_identical(&self) -> bool {
60        matches!(self, ExprDiff::Identical)
61    }
62
63    /// Get human-readable description
64    pub fn description(&self) -> String {
65        match self {
66            ExprDiff::Identical => "Expressions are identical".to_string(),
67            ExprDiff::TypeMismatch { left, right } => {
68                format!("Type mismatch: left={}, right={}", left, right)
69            }
70            ExprDiff::PredicateMismatch { left, right } => {
71                format!("Predicate mismatch: left={}, right={}", left, right)
72            }
73            ExprDiff::SubexprMismatch { path, left, right } => {
74                format!(
75                    "Subexpression mismatch at {}: left={}, right={}",
76                    path.join("/"),
77                    left,
78                    right
79                )
80            }
81            ExprDiff::QuantifierMismatch {
82                left_var,
83                right_var,
84                left_domain,
85                right_domain,
86            } => {
87                format!(
88                    "Quantifier mismatch: left=({}, {}), right=({}, {})",
89                    left_var, left_domain, right_var, right_domain
90                )
91            }
92        }
93    }
94}
95
96impl GraphDiff {
97    /// Check if graphs are identical
98    pub fn is_identical(&self) -> bool {
99        self.left_only_tensors.is_empty()
100            && self.right_only_tensors.is_empty()
101            && self.left_only_nodes == 0
102            && self.right_only_nodes == 0
103            && self.node_differences.is_empty()
104            && self.output_differences.is_empty()
105    }
106
107    /// Get summary of differences
108    pub fn summary(&self) -> String {
109        if self.is_identical() {
110            return "Graphs are identical".to_string();
111        }
112
113        let mut parts = Vec::new();
114
115        if !self.left_only_tensors.is_empty() {
116            parts.push(format!(
117                "{} tensors only in left",
118                self.left_only_tensors.len()
119            ));
120        }
121        if !self.right_only_tensors.is_empty() {
122            parts.push(format!(
123                "{} tensors only in right",
124                self.right_only_tensors.len()
125            ));
126        }
127        if self.left_only_nodes > 0 {
128            parts.push(format!("{} nodes only in left", self.left_only_nodes));
129        }
130        if self.right_only_nodes > 0 {
131            parts.push(format!("{} nodes only in right", self.right_only_nodes));
132        }
133        if !self.node_differences.is_empty() {
134            parts.push(format!("{} node differences", self.node_differences.len()));
135        }
136        if !self.output_differences.is_empty() {
137            parts.push(format!(
138                "{} output differences",
139                self.output_differences.len()
140            ));
141        }
142
143        parts.join(", ")
144    }
145}
146
147/// Compare two expressions
148pub fn diff_exprs(left: &TLExpr, right: &TLExpr) -> ExprDiff {
149    diff_exprs_impl(left, right, &mut Vec::new())
150}
151
152fn diff_exprs_impl(left: &TLExpr, right: &TLExpr, path: &mut Vec<String>) -> ExprDiff {
153    match (left, right) {
154        (TLExpr::Pred { name: n1, args: a1 }, TLExpr::Pred { name: n2, args: a2 }) => {
155            if n1 != n2 || a1.len() != a2.len() {
156                ExprDiff::PredicateMismatch {
157                    left: format!("{}({})", n1, a1.len()),
158                    right: format!("{}({})", n2, a2.len()),
159                }
160            } else {
161                ExprDiff::Identical
162            }
163        }
164        (TLExpr::And(l1, r1), TLExpr::And(l2, r2))
165        | (TLExpr::Or(l1, r1), TLExpr::Or(l2, r2))
166        | (TLExpr::Imply(l1, r1), TLExpr::Imply(l2, r2))
167        | (TLExpr::Add(l1, r1), TLExpr::Add(l2, r2))
168        | (TLExpr::Sub(l1, r1), TLExpr::Sub(l2, r2))
169        | (TLExpr::Mul(l1, r1), TLExpr::Mul(l2, r2))
170        | (TLExpr::Div(l1, r1), TLExpr::Div(l2, r2))
171        | (TLExpr::Pow(l1, r1), TLExpr::Pow(l2, r2))
172        | (TLExpr::Mod(l1, r1), TLExpr::Mod(l2, r2))
173        | (TLExpr::Min(l1, r1), TLExpr::Min(l2, r2))
174        | (TLExpr::Max(l1, r1), TLExpr::Max(l2, r2))
175        | (TLExpr::Eq(l1, r1), TLExpr::Eq(l2, r2))
176        | (TLExpr::Lt(l1, r1), TLExpr::Lt(l2, r2))
177        | (TLExpr::Gt(l1, r1), TLExpr::Gt(l2, r2))
178        | (TLExpr::Lte(l1, r1), TLExpr::Lte(l2, r2))
179        | (TLExpr::Gte(l1, r1), TLExpr::Gte(l2, r2)) => {
180            path.push("left".to_string());
181            let left_diff = diff_exprs_impl(l1, l2, path);
182            path.pop();
183
184            if !left_diff.is_identical() {
185                return left_diff;
186            }
187
188            path.push("right".to_string());
189            let right_diff = diff_exprs_impl(r1, r2, path);
190            path.pop();
191
192            right_diff
193        }
194        (TLExpr::Not(e1), TLExpr::Not(e2))
195        | (TLExpr::Score(e1), TLExpr::Score(e2))
196        | (TLExpr::Abs(e1), TLExpr::Abs(e2))
197        | (TLExpr::Floor(e1), TLExpr::Floor(e2))
198        | (TLExpr::Ceil(e1), TLExpr::Ceil(e2))
199        | (TLExpr::Round(e1), TLExpr::Round(e2))
200        | (TLExpr::Sqrt(e1), TLExpr::Sqrt(e2))
201        | (TLExpr::Exp(e1), TLExpr::Exp(e2))
202        | (TLExpr::Log(e1), TLExpr::Log(e2))
203        | (TLExpr::Sin(e1), TLExpr::Sin(e2))
204        | (TLExpr::Cos(e1), TLExpr::Cos(e2))
205        | (TLExpr::Tan(e1), TLExpr::Tan(e2)) => {
206            path.push("inner".to_string());
207            let diff = diff_exprs_impl(e1, e2, path);
208            path.pop();
209            diff
210        }
211        (
212            TLExpr::Exists {
213                var: v1,
214                domain: d1,
215                body: b1,
216            },
217            TLExpr::Exists {
218                var: v2,
219                domain: d2,
220                body: b2,
221            },
222        )
223        | (
224            TLExpr::ForAll {
225                var: v1,
226                domain: d1,
227                body: b1,
228            },
229            TLExpr::ForAll {
230                var: v2,
231                domain: d2,
232                body: b2,
233            },
234        ) => {
235            if v1 != v2 || d1 != d2 {
236                return ExprDiff::QuantifierMismatch {
237                    left_var: v1.clone(),
238                    right_var: v2.clone(),
239                    left_domain: d1.clone(),
240                    right_domain: d2.clone(),
241                };
242            }
243
244            path.push("body".to_string());
245            let diff = diff_exprs_impl(b1, b2, path);
246            path.pop();
247            diff
248        }
249        (
250            TLExpr::IfThenElse {
251                condition: c1,
252                then_branch: t1,
253                else_branch: e1,
254            },
255            TLExpr::IfThenElse {
256                condition: c2,
257                then_branch: t2,
258                else_branch: e2,
259            },
260        ) => {
261            path.push("condition".to_string());
262            let cond_diff = diff_exprs_impl(c1, c2, path);
263            path.pop();
264
265            if !cond_diff.is_identical() {
266                return cond_diff;
267            }
268
269            path.push("then".to_string());
270            let then_diff = diff_exprs_impl(t1, t2, path);
271            path.pop();
272
273            if !then_diff.is_identical() {
274                return then_diff;
275            }
276
277            path.push("else".to_string());
278            let else_diff = diff_exprs_impl(e1, e2, path);
279            path.pop();
280
281            else_diff
282        }
283        (TLExpr::Constant(c1), TLExpr::Constant(c2)) => {
284            if (c1 - c2).abs() < f64::EPSILON {
285                ExprDiff::Identical
286            } else {
287                ExprDiff::SubexprMismatch {
288                    path: path.clone(),
289                    left: format!("{}", c1),
290                    right: format!("{}", c2),
291                }
292            }
293        }
294        (TLExpr::SymbolLiteral(s1), TLExpr::SymbolLiteral(s2)) => {
295            if s1 == s2 {
296                ExprDiff::Identical
297            } else {
298                ExprDiff::SubexprMismatch {
299                    path: path.clone(),
300                    left: format!(":{s1}"),
301                    right: format!(":{s2}"),
302                }
303            }
304        }
305        (
306            TLExpr::Match {
307                scrutinee: s1,
308                arms: a1,
309            },
310            TLExpr::Match {
311                scrutinee: s2,
312                arms: a2,
313            },
314        ) => {
315            path.push("scrutinee".to_string());
316            let sd = diff_exprs_impl(s1, s2, path);
317            path.pop();
318            if !matches!(sd, ExprDiff::Identical) {
319                return sd;
320            }
321            if a1.len() != a2.len() {
322                return ExprDiff::SubexprMismatch {
323                    path: path.clone(),
324                    left: format!("{} arms", a1.len()),
325                    right: format!("{} arms", a2.len()),
326                };
327            }
328            for (i, ((p1, b1), (p2, b2))) in a1.iter().zip(a2.iter()).enumerate() {
329                if p1 != p2 {
330                    return ExprDiff::SubexprMismatch {
331                        path: path.clone(),
332                        left: format!("arm[{i}] pattern {p1}"),
333                        right: format!("arm[{i}] pattern {p2}"),
334                    };
335                }
336                path.push(format!("arm[{i}]"));
337                let bd = diff_exprs_impl(b1, b2, path);
338                path.pop();
339                if !matches!(bd, ExprDiff::Identical) {
340                    return bd;
341                }
342            }
343            ExprDiff::Identical
344        }
345        _ => ExprDiff::TypeMismatch {
346            left: format!("{:?}", left)
347                .split('(')
348                .next()
349                .unwrap_or("unknown")
350                .to_string(),
351            right: format!("{:?}", right)
352                .split('(')
353                .next()
354                .unwrap_or("unknown")
355                .to_string(),
356        },
357    }
358}
359
360/// Compare two graphs
361pub fn diff_graphs(left: &EinsumGraph, right: &EinsumGraph) -> GraphDiff {
362    let left_tensors: HashSet<_> = left.tensors.iter().collect();
363    let right_tensors: HashSet<_> = right.tensors.iter().collect();
364
365    let left_only_tensors: Vec<String> = left_tensors
366        .difference(&right_tensors)
367        .map(|s| s.to_string())
368        .collect();
369    let right_only_tensors: Vec<String> = right_tensors
370        .difference(&left_tensors)
371        .map(|s| s.to_string())
372        .collect();
373
374    let node_differences = diff_nodes(&left.nodes, &right.nodes);
375
376    let left_only_nodes = if left.nodes.len() > right.nodes.len() {
377        left.nodes.len() - right.nodes.len()
378    } else {
379        0
380    };
381    let right_only_nodes = if right.nodes.len() > left.nodes.len() {
382        right.nodes.len() - left.nodes.len()
383    } else {
384        0
385    };
386
387    let output_differences = diff_outputs(&left.outputs, &right.outputs);
388
389    GraphDiff {
390        left_only_tensors,
391        right_only_tensors,
392        left_only_nodes,
393        right_only_nodes,
394        node_differences,
395        output_differences,
396    }
397}
398
399fn diff_nodes(left: &[EinsumNode], right: &[EinsumNode]) -> Vec<NodeDiff> {
400    let mut differences = Vec::new();
401    let min_len = left.len().min(right.len());
402
403    for i in 0..min_len {
404        if let Some(diff) = diff_node(&left[i], &right[i], i) {
405            differences.push(diff);
406        }
407    }
408
409    differences
410}
411
412fn diff_node(left: &EinsumNode, right: &EinsumNode, index: usize) -> Option<NodeDiff> {
413    if left.inputs != right.inputs {
414        return Some(NodeDiff {
415            node_index: index,
416            description: format!("Different inputs: {:?} vs {:?}", left.inputs, right.inputs),
417        });
418    }
419
420    if left.outputs != right.outputs {
421        return Some(NodeDiff {
422            node_index: index,
423            description: format!(
424                "Different outputs: {:?} vs {:?}",
425                left.outputs, right.outputs
426            ),
427        });
428    }
429
430    if !ops_equal(&left.op, &right.op) {
431        return Some(NodeDiff {
432            node_index: index,
433            description: format!("Different operations: {:?} vs {:?}", left.op, right.op),
434        });
435    }
436
437    None
438}
439
440fn ops_equal(left: &OpType, right: &OpType) -> bool {
441    // Simple discriminant comparison
442    std::mem::discriminant(left) == std::mem::discriminant(right)
443}
444
/// Describes how two output-index lists differ.
///
/// If the lengths differ, the length message comes first. It is followed by
/// one message per position, up to the shorter length, where the indices
/// disagree.
fn diff_outputs(left: &[usize], right: &[usize]) -> Vec<String> {
    let length_note = (left.len() != right.len()).then(|| {
        format!(
            "Different number of outputs: {} vs {}",
            left.len(),
            right.len()
        )
    });

    let positional = left
        .iter()
        .zip(right)
        .enumerate()
        .filter(|(_, (l, r))| l != r)
        .map(|(idx, (l, r))| format!("Output {} differs: {} vs {}", idx, l, r));

    length_note.into_iter().chain(positional).collect()
}
464
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Term;
    use std::collections::HashMap;

    /// Builds the one-argument predicate `name(var)`.
    fn unary_pred(name: &str, var: &str) -> TLExpr {
        TLExpr::pred(name, vec![Term::var(var)])
    }

    /// Builds a node-less, input-less graph with the given tensors and outputs.
    fn graph(tensors: &[&str], outputs: Vec<usize>) -> EinsumGraph {
        EinsumGraph {
            tensors: tensors.iter().map(|t| t.to_string()).collect(),
            nodes: vec![],
            inputs: vec![],
            outputs,
            tensor_metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_identical_exprs() {
        let diff = diff_exprs(&unary_pred("p", "x"), &unary_pred("p", "x"));
        assert!(diff.is_identical());
    }

    #[test]
    fn test_different_predicates() {
        let diff = diff_exprs(&unary_pred("p", "x"), &unary_pred("q", "x"));
        assert!(!diff.is_identical());
        assert!(matches!(diff, ExprDiff::PredicateMismatch { .. }));
    }

    #[test]
    fn test_different_types() {
        let diff = diff_exprs(&unary_pred("p", "x"), &TLExpr::constant(1.0));
        assert!(!diff.is_identical());
        assert!(matches!(diff, ExprDiff::TypeMismatch { .. }));
    }

    #[test]
    fn test_nested_and_difference() {
        let lhs = TLExpr::and(unary_pred("p", "x"), unary_pred("q", "y"));
        let rhs = TLExpr::and(unary_pred("p", "x"), unary_pred("r", "y"));

        let diff = diff_exprs(&lhs, &rhs);
        assert!(!diff.is_identical());
    }

    #[test]
    fn test_quantifier_difference() {
        let lhs = TLExpr::exists("x", "Domain1", unary_pred("p", "x"));
        let rhs = TLExpr::exists("y", "Domain2", unary_pred("p", "y"));

        let diff = diff_exprs(&lhs, &rhs);
        assert!(!diff.is_identical());
        assert!(matches!(diff, ExprDiff::QuantifierMismatch { .. }));
    }

    #[test]
    fn test_identical_graphs() {
        let diff = diff_graphs(&graph(&["t0"], vec![0]), &graph(&["t0"], vec![0]));
        assert!(diff.is_identical());
    }

    #[test]
    fn test_different_tensor_count() {
        let diff = diff_graphs(&graph(&["t0", "t1"], vec![]), &graph(&["t0"], vec![]));
        assert!(!diff.is_identical());
        assert_eq!(diff.left_only_tensors.len(), 1);
    }

    #[test]
    fn test_different_outputs() {
        let diff = diff_graphs(&graph(&["t0"], vec![0]), &graph(&["t0"], vec![1]));
        assert!(!diff.is_identical());
        assert!(!diff.output_differences.is_empty());
    }

    #[test]
    fn test_diff_summary() {
        let diff = GraphDiff {
            left_only_tensors: vec!["t1".to_string()],
            right_only_tensors: vec!["t2".to_string()],
            left_only_nodes: 0,
            right_only_nodes: 0,
            node_differences: vec![],
            output_differences: vec![],
        };

        let summary = diff.summary();
        assert!(summary.contains("tensors only in left"));
        assert!(summary.contains("tensors only in right"));
    }
}