Skip to main content

tensorlogic_ir/
diff.rs

1//! IR diff tool for comparing graphs and expressions.
2//!
3//! This module provides utilities to compare two IR structures and
4//! identify differences, useful for debugging and validation.
5
6use crate::{EinsumGraph, EinsumNode, OpType, TLExpr};
7use std::collections::HashSet;
8
/// Difference between two expressions.
///
/// Produced by [`diff_exprs`], which reports only the first difference it
/// encounters while walking both expressions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExprDiff {
    /// Expressions are identical
    Identical,
    /// Different expression kinds (e.g. a predicate compared with a constant)
    TypeMismatch { left: String, right: String },
    /// Different predicate names or arities, each side rendered as `name(arity)`
    PredicateMismatch { left: String, right: String },
    /// Different subexpressions
    SubexprMismatch {
        /// Child labels leading from the root to the differing subexpression
        path: Vec<String>,
        /// Textual rendering of the left subexpression
        left: String,
        /// Textual rendering of the right subexpression
        right: String,
    },
    /// Different quantifier variables or domains
    QuantifierMismatch {
        /// Bound variable on the left side
        left_var: String,
        /// Bound variable on the right side
        right_var: String,
        /// Quantification domain on the left side
        left_domain: String,
        /// Quantification domain on the right side
        right_domain: String,
    },
}

/// Difference between two graphs, as produced by [`diff_graphs`].
///
/// Derives `PartialEq`/`Eq` (like [`ExprDiff`]) so diffs can be compared
/// directly, and `Default` so an "identical" diff can be built with
/// `GraphDiff::default()` / struct-update syntax.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GraphDiff {
    /// Tensors only in left graph
    pub left_only_tensors: Vec<String>,
    /// Tensors only in right graph
    pub right_only_tensors: Vec<String>,
    /// Number of nodes the left graph has beyond the right graph's node count
    pub left_only_nodes: usize,
    /// Number of nodes the right graph has beyond the left graph's node count
    pub right_only_nodes: usize,
    /// Per-position node differences over the nodes both graphs share
    pub node_differences: Vec<NodeDiff>,
    /// Human-readable descriptions of output differences
    pub output_differences: Vec<String>,
}

/// Difference in a specific node.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NodeDiff {
    /// Position of the node in both graphs' node lists
    pub node_index: usize,
    /// Human-readable explanation of how the two nodes differ
    pub description: String,
}

impl ExprDiff {
    /// Check if expressions are identical
    pub fn is_identical(&self) -> bool {
        matches!(self, ExprDiff::Identical)
    }

    /// Get a human-readable, single-line description of the difference.
    pub fn description(&self) -> String {
        match self {
            ExprDiff::Identical => "Expressions are identical".to_string(),
            ExprDiff::TypeMismatch { left, right } => {
                format!("Type mismatch: left={}, right={}", left, right)
            }
            ExprDiff::PredicateMismatch { left, right } => {
                format!("Predicate mismatch: left={}, right={}", left, right)
            }
            ExprDiff::SubexprMismatch { path, left, right } => format!(
                "Subexpression mismatch at {}: left={}, right={}",
                path.join("/"),
                left,
                right
            ),
            ExprDiff::QuantifierMismatch {
                left_var,
                right_var,
                left_domain,
                right_domain,
            } => format!(
                "Quantifier mismatch: left=({}, {}), right=({}, {})",
                left_var, left_domain, right_var, right_domain
            ),
        }
    }
}

impl GraphDiff {
    /// Check if graphs are identical, i.e. no category of difference was recorded.
    pub fn is_identical(&self) -> bool {
        self.left_only_tensors.is_empty()
            && self.right_only_tensors.is_empty()
            && self.left_only_nodes == 0
            && self.right_only_nodes == 0
            && self.node_differences.is_empty()
            && self.output_differences.is_empty()
    }

    /// Get a comma-separated summary of the non-empty difference categories,
    /// or `"Graphs are identical"` when there are none.
    pub fn summary(&self) -> String {
        if self.is_identical() {
            return "Graphs are identical".to_string();
        }

        // (count, label) pairs in reporting order; zero counts are skipped.
        let categories = [
            (self.left_only_tensors.len(), "tensors only in left"),
            (self.right_only_tensors.len(), "tensors only in right"),
            (self.left_only_nodes, "nodes only in left"),
            (self.right_only_nodes, "nodes only in right"),
            (self.node_differences.len(), "node differences"),
            (self.output_differences.len(), "output differences"),
        ];

        categories
            .iter()
            .filter(|(count, _)| *count > 0)
            .map(|(count, label)| format!("{} {}", count, label))
            .collect::<Vec<_>>()
            .join(", ")
    }
}
146
147/// Compare two expressions
148pub fn diff_exprs(left: &TLExpr, right: &TLExpr) -> ExprDiff {
149    diff_exprs_impl(left, right, &mut Vec::new())
150}
151
152fn diff_exprs_impl(left: &TLExpr, right: &TLExpr, path: &mut Vec<String>) -> ExprDiff {
153    match (left, right) {
154        (TLExpr::Pred { name: n1, args: a1 }, TLExpr::Pred { name: n2, args: a2 }) => {
155            if n1 != n2 || a1.len() != a2.len() {
156                ExprDiff::PredicateMismatch {
157                    left: format!("{}({})", n1, a1.len()),
158                    right: format!("{}({})", n2, a2.len()),
159                }
160            } else {
161                ExprDiff::Identical
162            }
163        }
164        (TLExpr::And(l1, r1), TLExpr::And(l2, r2))
165        | (TLExpr::Or(l1, r1), TLExpr::Or(l2, r2))
166        | (TLExpr::Imply(l1, r1), TLExpr::Imply(l2, r2))
167        | (TLExpr::Add(l1, r1), TLExpr::Add(l2, r2))
168        | (TLExpr::Sub(l1, r1), TLExpr::Sub(l2, r2))
169        | (TLExpr::Mul(l1, r1), TLExpr::Mul(l2, r2))
170        | (TLExpr::Div(l1, r1), TLExpr::Div(l2, r2))
171        | (TLExpr::Pow(l1, r1), TLExpr::Pow(l2, r2))
172        | (TLExpr::Mod(l1, r1), TLExpr::Mod(l2, r2))
173        | (TLExpr::Min(l1, r1), TLExpr::Min(l2, r2))
174        | (TLExpr::Max(l1, r1), TLExpr::Max(l2, r2))
175        | (TLExpr::Eq(l1, r1), TLExpr::Eq(l2, r2))
176        | (TLExpr::Lt(l1, r1), TLExpr::Lt(l2, r2))
177        | (TLExpr::Gt(l1, r1), TLExpr::Gt(l2, r2))
178        | (TLExpr::Lte(l1, r1), TLExpr::Lte(l2, r2))
179        | (TLExpr::Gte(l1, r1), TLExpr::Gte(l2, r2)) => {
180            path.push("left".to_string());
181            let left_diff = diff_exprs_impl(l1, l2, path);
182            path.pop();
183
184            if !left_diff.is_identical() {
185                return left_diff;
186            }
187
188            path.push("right".to_string());
189            let right_diff = diff_exprs_impl(r1, r2, path);
190            path.pop();
191
192            right_diff
193        }
194        (TLExpr::Not(e1), TLExpr::Not(e2))
195        | (TLExpr::Score(e1), TLExpr::Score(e2))
196        | (TLExpr::Abs(e1), TLExpr::Abs(e2))
197        | (TLExpr::Floor(e1), TLExpr::Floor(e2))
198        | (TLExpr::Ceil(e1), TLExpr::Ceil(e2))
199        | (TLExpr::Round(e1), TLExpr::Round(e2))
200        | (TLExpr::Sqrt(e1), TLExpr::Sqrt(e2))
201        | (TLExpr::Exp(e1), TLExpr::Exp(e2))
202        | (TLExpr::Log(e1), TLExpr::Log(e2))
203        | (TLExpr::Sin(e1), TLExpr::Sin(e2))
204        | (TLExpr::Cos(e1), TLExpr::Cos(e2))
205        | (TLExpr::Tan(e1), TLExpr::Tan(e2)) => {
206            path.push("inner".to_string());
207            let diff = diff_exprs_impl(e1, e2, path);
208            path.pop();
209            diff
210        }
211        (
212            TLExpr::Exists {
213                var: v1,
214                domain: d1,
215                body: b1,
216            },
217            TLExpr::Exists {
218                var: v2,
219                domain: d2,
220                body: b2,
221            },
222        )
223        | (
224            TLExpr::ForAll {
225                var: v1,
226                domain: d1,
227                body: b1,
228            },
229            TLExpr::ForAll {
230                var: v2,
231                domain: d2,
232                body: b2,
233            },
234        ) => {
235            if v1 != v2 || d1 != d2 {
236                return ExprDiff::QuantifierMismatch {
237                    left_var: v1.clone(),
238                    right_var: v2.clone(),
239                    left_domain: d1.clone(),
240                    right_domain: d2.clone(),
241                };
242            }
243
244            path.push("body".to_string());
245            let diff = diff_exprs_impl(b1, b2, path);
246            path.pop();
247            diff
248        }
249        (
250            TLExpr::IfThenElse {
251                condition: c1,
252                then_branch: t1,
253                else_branch: e1,
254            },
255            TLExpr::IfThenElse {
256                condition: c2,
257                then_branch: t2,
258                else_branch: e2,
259            },
260        ) => {
261            path.push("condition".to_string());
262            let cond_diff = diff_exprs_impl(c1, c2, path);
263            path.pop();
264
265            if !cond_diff.is_identical() {
266                return cond_diff;
267            }
268
269            path.push("then".to_string());
270            let then_diff = diff_exprs_impl(t1, t2, path);
271            path.pop();
272
273            if !then_diff.is_identical() {
274                return then_diff;
275            }
276
277            path.push("else".to_string());
278            let else_diff = diff_exprs_impl(e1, e2, path);
279            path.pop();
280
281            else_diff
282        }
283        (TLExpr::Constant(c1), TLExpr::Constant(c2)) => {
284            if (c1 - c2).abs() < f64::EPSILON {
285                ExprDiff::Identical
286            } else {
287                ExprDiff::SubexprMismatch {
288                    path: path.clone(),
289                    left: format!("{}", c1),
290                    right: format!("{}", c2),
291                }
292            }
293        }
294        _ => ExprDiff::TypeMismatch {
295            left: format!("{:?}", left).split('(').next().unwrap().to_string(),
296            right: format!("{:?}", right)
297                .split('(')
298                .next()
299                .unwrap()
300                .to_string(),
301        },
302    }
303}
304
305/// Compare two graphs
306pub fn diff_graphs(left: &EinsumGraph, right: &EinsumGraph) -> GraphDiff {
307    let left_tensors: HashSet<_> = left.tensors.iter().collect();
308    let right_tensors: HashSet<_> = right.tensors.iter().collect();
309
310    let left_only_tensors: Vec<String> = left_tensors
311        .difference(&right_tensors)
312        .map(|s| s.to_string())
313        .collect();
314    let right_only_tensors: Vec<String> = right_tensors
315        .difference(&left_tensors)
316        .map(|s| s.to_string())
317        .collect();
318
319    let node_differences = diff_nodes(&left.nodes, &right.nodes);
320
321    let left_only_nodes = if left.nodes.len() > right.nodes.len() {
322        left.nodes.len() - right.nodes.len()
323    } else {
324        0
325    };
326    let right_only_nodes = if right.nodes.len() > left.nodes.len() {
327        right.nodes.len() - left.nodes.len()
328    } else {
329        0
330    };
331
332    let output_differences = diff_outputs(&left.outputs, &right.outputs);
333
334    GraphDiff {
335        left_only_tensors,
336        right_only_tensors,
337        left_only_nodes,
338        right_only_nodes,
339        node_differences,
340        output_differences,
341    }
342}
343
344fn diff_nodes(left: &[EinsumNode], right: &[EinsumNode]) -> Vec<NodeDiff> {
345    let mut differences = Vec::new();
346    let min_len = left.len().min(right.len());
347
348    for i in 0..min_len {
349        if let Some(diff) = diff_node(&left[i], &right[i], i) {
350            differences.push(diff);
351        }
352    }
353
354    differences
355}
356
357fn diff_node(left: &EinsumNode, right: &EinsumNode, index: usize) -> Option<NodeDiff> {
358    if left.inputs != right.inputs {
359        return Some(NodeDiff {
360            node_index: index,
361            description: format!("Different inputs: {:?} vs {:?}", left.inputs, right.inputs),
362        });
363    }
364
365    if left.outputs != right.outputs {
366        return Some(NodeDiff {
367            node_index: index,
368            description: format!(
369                "Different outputs: {:?} vs {:?}",
370                left.outputs, right.outputs
371            ),
372        });
373    }
374
375    if !ops_equal(&left.op, &right.op) {
376        return Some(NodeDiff {
377            node_index: index,
378            description: format!("Different operations: {:?} vs {:?}", left.op, right.op),
379        });
380    }
381
382    None
383}
384
385fn ops_equal(left: &OpType, right: &OpType) -> bool {
386    // Simple discriminant comparison
387    std::mem::discriminant(left) == std::mem::discriminant(right)
388}
389
/// Compare two output-index lists.
///
/// Emits a count message when the lengths differ, then one message per
/// shared position whose indices disagree.
fn diff_outputs(left: &[usize], right: &[usize]) -> Vec<String> {
    let length_note = (left.len() != right.len()).then(|| {
        format!(
            "Different number of outputs: {} vs {}",
            left.len(),
            right.len()
        )
    });

    let positional = left
        .iter()
        .zip(right)
        .enumerate()
        .filter(|(_, (a, b))| a != b)
        .map(|(pos, (a, b))| format!("Output {} differs: {} vs {}", pos, a, b));

    length_note.into_iter().chain(positional).collect()
}
409
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Term;

    /// Single-argument predicate `name(var)`.
    fn unary_pred(name: &str, var: &str) -> TLExpr {
        TLExpr::pred(name, vec![Term::var(var)])
    }

    /// Node-free graph with the given tensor names and output indices.
    fn graph_with(tensors: &[&str], outputs: Vec<usize>) -> EinsumGraph {
        EinsumGraph {
            tensors: tensors.iter().map(|t| t.to_string()).collect(),
            nodes: vec![],
            inputs: vec![],
            outputs,
            tensor_metadata: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_identical_exprs() {
        let diff = diff_exprs(&unary_pred("p", "x"), &unary_pred("p", "x"));
        assert!(diff.is_identical());
    }

    #[test]
    fn test_different_predicates() {
        let diff = diff_exprs(&unary_pred("p", "x"), &unary_pred("q", "x"));
        assert!(!diff.is_identical());
        assert!(matches!(diff, ExprDiff::PredicateMismatch { .. }));
    }

    #[test]
    fn test_different_types() {
        let diff = diff_exprs(&unary_pred("p", "x"), &TLExpr::constant(1.0));
        assert!(!diff.is_identical());
        assert!(matches!(diff, ExprDiff::TypeMismatch { .. }));
    }

    #[test]
    fn test_nested_and_difference() {
        let lhs = TLExpr::and(unary_pred("p", "x"), unary_pred("q", "y"));
        let rhs = TLExpr::and(unary_pred("p", "x"), unary_pred("r", "y"));

        assert!(!diff_exprs(&lhs, &rhs).is_identical());
    }

    #[test]
    fn test_quantifier_difference() {
        let lhs = TLExpr::exists("x", "Domain1", unary_pred("p", "x"));
        let rhs = TLExpr::exists("y", "Domain2", unary_pred("p", "y"));

        let diff = diff_exprs(&lhs, &rhs);
        assert!(!diff.is_identical());
        assert!(matches!(diff, ExprDiff::QuantifierMismatch { .. }));
    }

    #[test]
    fn test_identical_graphs() {
        let diff = diff_graphs(&graph_with(&["t0"], vec![0]), &graph_with(&["t0"], vec![0]));
        assert!(diff.is_identical());
    }

    #[test]
    fn test_different_tensor_count() {
        let diff = diff_graphs(&graph_with(&["t0", "t1"], vec![]), &graph_with(&["t0"], vec![]));
        assert!(!diff.is_identical());
        assert_eq!(diff.left_only_tensors.len(), 1);
    }

    #[test]
    fn test_different_outputs() {
        let diff = diff_graphs(&graph_with(&["t0"], vec![0]), &graph_with(&["t0"], vec![1]));
        assert!(!diff.is_identical());
        assert!(!diff.output_differences.is_empty());
    }

    #[test]
    fn test_diff_summary() {
        let diff = GraphDiff {
            left_only_tensors: vec!["t1".to_string()],
            right_only_tensors: vec!["t2".to_string()],
            left_only_nodes: 0,
            right_only_nodes: 0,
            node_differences: vec![],
            output_differences: vec![],
        };

        let summary = diff.summary();
        assert!(summary.contains("tensors only in left"));
        assert!(summary.contains("tensors only in right"));
    }
}