1use crate::{EinsumGraph, EinsumNode, OpType, TLExpr};
7use std::collections::HashSet;
8
/// Outcome of structurally comparing two expressions.
///
/// Only the first difference encountered is reported; the variant tells what
/// kind of difference it was.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExprDiff {
    /// No difference was found.
    Identical,
    /// The two expressions are of different kinds.
    TypeMismatch {
        /// Kind of the left expression.
        left: String,
        /// Kind of the right expression.
        right: String,
    },
    /// Two predicates differ.
    PredicateMismatch {
        /// Rendering of the left predicate.
        left: String,
        /// Rendering of the right predicate.
        right: String,
    },
    /// A nested subexpression differs.
    SubexprMismatch {
        /// Child labels leading from the root to the differing subexpression.
        path: Vec<String>,
        /// Rendering of the left side.
        left: String,
        /// Rendering of the right side.
        right: String,
    },
    /// Two quantifiers bind different variables or range over different domains.
    QuantifierMismatch {
        /// Variable bound on the left.
        left_var: String,
        /// Variable bound on the right.
        right_var: String,
        /// Domain quantified over on the left.
        left_domain: String,
        /// Domain quantified over on the right.
        right_domain: String,
    },
}
32
/// Summary of the differences between two einsum graphs.
///
/// The default value has every list empty and every counter at zero, i.e. it
/// describes two graphs with no recorded differences.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GraphDiff {
    /// Tensor names present only in the left graph.
    pub left_only_tensors: Vec<String>,
    /// Tensor names present only in the right graph.
    pub right_only_tensors: Vec<String>,
    /// Number of nodes the left graph has beyond the right graph's node count.
    pub left_only_nodes: usize,
    /// Number of nodes the right graph has beyond the left graph's node count.
    pub right_only_nodes: usize,
    /// Differences between nodes that exist at the same index in both graphs.
    pub node_differences: Vec<NodeDiff>,
    /// Human-readable descriptions of differences in the graph outputs.
    pub output_differences: Vec<String>,
}

/// A difference between two nodes located at the same index.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NodeDiff {
    /// Index of the node in both graphs' node lists.
    pub node_index: usize,
    /// Human-readable description of what differs.
    pub description: String,
}
56
57impl ExprDiff {
58 pub fn is_identical(&self) -> bool {
60 matches!(self, ExprDiff::Identical)
61 }
62
63 pub fn description(&self) -> String {
65 match self {
66 ExprDiff::Identical => "Expressions are identical".to_string(),
67 ExprDiff::TypeMismatch { left, right } => {
68 format!("Type mismatch: left={}, right={}", left, right)
69 }
70 ExprDiff::PredicateMismatch { left, right } => {
71 format!("Predicate mismatch: left={}, right={}", left, right)
72 }
73 ExprDiff::SubexprMismatch { path, left, right } => {
74 format!(
75 "Subexpression mismatch at {}: left={}, right={}",
76 path.join("/"),
77 left,
78 right
79 )
80 }
81 ExprDiff::QuantifierMismatch {
82 left_var,
83 right_var,
84 left_domain,
85 right_domain,
86 } => {
87 format!(
88 "Quantifier mismatch: left=({}, {}), right=({}, {})",
89 left_var, left_domain, right_var, right_domain
90 )
91 }
92 }
93 }
94}
95
96impl GraphDiff {
97 pub fn is_identical(&self) -> bool {
99 self.left_only_tensors.is_empty()
100 && self.right_only_tensors.is_empty()
101 && self.left_only_nodes == 0
102 && self.right_only_nodes == 0
103 && self.node_differences.is_empty()
104 && self.output_differences.is_empty()
105 }
106
107 pub fn summary(&self) -> String {
109 if self.is_identical() {
110 return "Graphs are identical".to_string();
111 }
112
113 let mut parts = Vec::new();
114
115 if !self.left_only_tensors.is_empty() {
116 parts.push(format!(
117 "{} tensors only in left",
118 self.left_only_tensors.len()
119 ));
120 }
121 if !self.right_only_tensors.is_empty() {
122 parts.push(format!(
123 "{} tensors only in right",
124 self.right_only_tensors.len()
125 ));
126 }
127 if self.left_only_nodes > 0 {
128 parts.push(format!("{} nodes only in left", self.left_only_nodes));
129 }
130 if self.right_only_nodes > 0 {
131 parts.push(format!("{} nodes only in right", self.right_only_nodes));
132 }
133 if !self.node_differences.is_empty() {
134 parts.push(format!("{} node differences", self.node_differences.len()));
135 }
136 if !self.output_differences.is_empty() {
137 parts.push(format!(
138 "{} output differences",
139 self.output_differences.len()
140 ));
141 }
142
143 parts.join(", ")
144 }
145}
146
147pub fn diff_exprs(left: &TLExpr, right: &TLExpr) -> ExprDiff {
149 diff_exprs_impl(left, right, &mut Vec::new())
150}
151
152fn diff_exprs_impl(left: &TLExpr, right: &TLExpr, path: &mut Vec<String>) -> ExprDiff {
153 match (left, right) {
154 (TLExpr::Pred { name: n1, args: a1 }, TLExpr::Pred { name: n2, args: a2 }) => {
155 if n1 != n2 || a1.len() != a2.len() {
156 ExprDiff::PredicateMismatch {
157 left: format!("{}({})", n1, a1.len()),
158 right: format!("{}({})", n2, a2.len()),
159 }
160 } else {
161 ExprDiff::Identical
162 }
163 }
164 (TLExpr::And(l1, r1), TLExpr::And(l2, r2))
165 | (TLExpr::Or(l1, r1), TLExpr::Or(l2, r2))
166 | (TLExpr::Imply(l1, r1), TLExpr::Imply(l2, r2))
167 | (TLExpr::Add(l1, r1), TLExpr::Add(l2, r2))
168 | (TLExpr::Sub(l1, r1), TLExpr::Sub(l2, r2))
169 | (TLExpr::Mul(l1, r1), TLExpr::Mul(l2, r2))
170 | (TLExpr::Div(l1, r1), TLExpr::Div(l2, r2))
171 | (TLExpr::Pow(l1, r1), TLExpr::Pow(l2, r2))
172 | (TLExpr::Mod(l1, r1), TLExpr::Mod(l2, r2))
173 | (TLExpr::Min(l1, r1), TLExpr::Min(l2, r2))
174 | (TLExpr::Max(l1, r1), TLExpr::Max(l2, r2))
175 | (TLExpr::Eq(l1, r1), TLExpr::Eq(l2, r2))
176 | (TLExpr::Lt(l1, r1), TLExpr::Lt(l2, r2))
177 | (TLExpr::Gt(l1, r1), TLExpr::Gt(l2, r2))
178 | (TLExpr::Lte(l1, r1), TLExpr::Lte(l2, r2))
179 | (TLExpr::Gte(l1, r1), TLExpr::Gte(l2, r2)) => {
180 path.push("left".to_string());
181 let left_diff = diff_exprs_impl(l1, l2, path);
182 path.pop();
183
184 if !left_diff.is_identical() {
185 return left_diff;
186 }
187
188 path.push("right".to_string());
189 let right_diff = diff_exprs_impl(r1, r2, path);
190 path.pop();
191
192 right_diff
193 }
194 (TLExpr::Not(e1), TLExpr::Not(e2))
195 | (TLExpr::Score(e1), TLExpr::Score(e2))
196 | (TLExpr::Abs(e1), TLExpr::Abs(e2))
197 | (TLExpr::Floor(e1), TLExpr::Floor(e2))
198 | (TLExpr::Ceil(e1), TLExpr::Ceil(e2))
199 | (TLExpr::Round(e1), TLExpr::Round(e2))
200 | (TLExpr::Sqrt(e1), TLExpr::Sqrt(e2))
201 | (TLExpr::Exp(e1), TLExpr::Exp(e2))
202 | (TLExpr::Log(e1), TLExpr::Log(e2))
203 | (TLExpr::Sin(e1), TLExpr::Sin(e2))
204 | (TLExpr::Cos(e1), TLExpr::Cos(e2))
205 | (TLExpr::Tan(e1), TLExpr::Tan(e2)) => {
206 path.push("inner".to_string());
207 let diff = diff_exprs_impl(e1, e2, path);
208 path.pop();
209 diff
210 }
211 (
212 TLExpr::Exists {
213 var: v1,
214 domain: d1,
215 body: b1,
216 },
217 TLExpr::Exists {
218 var: v2,
219 domain: d2,
220 body: b2,
221 },
222 )
223 | (
224 TLExpr::ForAll {
225 var: v1,
226 domain: d1,
227 body: b1,
228 },
229 TLExpr::ForAll {
230 var: v2,
231 domain: d2,
232 body: b2,
233 },
234 ) => {
235 if v1 != v2 || d1 != d2 {
236 return ExprDiff::QuantifierMismatch {
237 left_var: v1.clone(),
238 right_var: v2.clone(),
239 left_domain: d1.clone(),
240 right_domain: d2.clone(),
241 };
242 }
243
244 path.push("body".to_string());
245 let diff = diff_exprs_impl(b1, b2, path);
246 path.pop();
247 diff
248 }
249 (
250 TLExpr::IfThenElse {
251 condition: c1,
252 then_branch: t1,
253 else_branch: e1,
254 },
255 TLExpr::IfThenElse {
256 condition: c2,
257 then_branch: t2,
258 else_branch: e2,
259 },
260 ) => {
261 path.push("condition".to_string());
262 let cond_diff = diff_exprs_impl(c1, c2, path);
263 path.pop();
264
265 if !cond_diff.is_identical() {
266 return cond_diff;
267 }
268
269 path.push("then".to_string());
270 let then_diff = diff_exprs_impl(t1, t2, path);
271 path.pop();
272
273 if !then_diff.is_identical() {
274 return then_diff;
275 }
276
277 path.push("else".to_string());
278 let else_diff = diff_exprs_impl(e1, e2, path);
279 path.pop();
280
281 else_diff
282 }
283 (TLExpr::Constant(c1), TLExpr::Constant(c2)) => {
284 if (c1 - c2).abs() < f64::EPSILON {
285 ExprDiff::Identical
286 } else {
287 ExprDiff::SubexprMismatch {
288 path: path.clone(),
289 left: format!("{}", c1),
290 right: format!("{}", c2),
291 }
292 }
293 }
294 (TLExpr::SymbolLiteral(s1), TLExpr::SymbolLiteral(s2)) => {
295 if s1 == s2 {
296 ExprDiff::Identical
297 } else {
298 ExprDiff::SubexprMismatch {
299 path: path.clone(),
300 left: format!(":{s1}"),
301 right: format!(":{s2}"),
302 }
303 }
304 }
305 (
306 TLExpr::Match {
307 scrutinee: s1,
308 arms: a1,
309 },
310 TLExpr::Match {
311 scrutinee: s2,
312 arms: a2,
313 },
314 ) => {
315 path.push("scrutinee".to_string());
316 let sd = diff_exprs_impl(s1, s2, path);
317 path.pop();
318 if !matches!(sd, ExprDiff::Identical) {
319 return sd;
320 }
321 if a1.len() != a2.len() {
322 return ExprDiff::SubexprMismatch {
323 path: path.clone(),
324 left: format!("{} arms", a1.len()),
325 right: format!("{} arms", a2.len()),
326 };
327 }
328 for (i, ((p1, b1), (p2, b2))) in a1.iter().zip(a2.iter()).enumerate() {
329 if p1 != p2 {
330 return ExprDiff::SubexprMismatch {
331 path: path.clone(),
332 left: format!("arm[{i}] pattern {p1}"),
333 right: format!("arm[{i}] pattern {p2}"),
334 };
335 }
336 path.push(format!("arm[{i}]"));
337 let bd = diff_exprs_impl(b1, b2, path);
338 path.pop();
339 if !matches!(bd, ExprDiff::Identical) {
340 return bd;
341 }
342 }
343 ExprDiff::Identical
344 }
345 _ => ExprDiff::TypeMismatch {
346 left: format!("{:?}", left)
347 .split('(')
348 .next()
349 .unwrap_or("unknown")
350 .to_string(),
351 right: format!("{:?}", right)
352 .split('(')
353 .next()
354 .unwrap_or("unknown")
355 .to_string(),
356 },
357 }
358}
359
360pub fn diff_graphs(left: &EinsumGraph, right: &EinsumGraph) -> GraphDiff {
362 let left_tensors: HashSet<_> = left.tensors.iter().collect();
363 let right_tensors: HashSet<_> = right.tensors.iter().collect();
364
365 let left_only_tensors: Vec<String> = left_tensors
366 .difference(&right_tensors)
367 .map(|s| s.to_string())
368 .collect();
369 let right_only_tensors: Vec<String> = right_tensors
370 .difference(&left_tensors)
371 .map(|s| s.to_string())
372 .collect();
373
374 let node_differences = diff_nodes(&left.nodes, &right.nodes);
375
376 let left_only_nodes = if left.nodes.len() > right.nodes.len() {
377 left.nodes.len() - right.nodes.len()
378 } else {
379 0
380 };
381 let right_only_nodes = if right.nodes.len() > left.nodes.len() {
382 right.nodes.len() - left.nodes.len()
383 } else {
384 0
385 };
386
387 let output_differences = diff_outputs(&left.outputs, &right.outputs);
388
389 GraphDiff {
390 left_only_tensors,
391 right_only_tensors,
392 left_only_nodes,
393 right_only_nodes,
394 node_differences,
395 output_differences,
396 }
397}
398
399fn diff_nodes(left: &[EinsumNode], right: &[EinsumNode]) -> Vec<NodeDiff> {
400 let mut differences = Vec::new();
401 let min_len = left.len().min(right.len());
402
403 for i in 0..min_len {
404 if let Some(diff) = diff_node(&left[i], &right[i], i) {
405 differences.push(diff);
406 }
407 }
408
409 differences
410}
411
412fn diff_node(left: &EinsumNode, right: &EinsumNode, index: usize) -> Option<NodeDiff> {
413 if left.inputs != right.inputs {
414 return Some(NodeDiff {
415 node_index: index,
416 description: format!("Different inputs: {:?} vs {:?}", left.inputs, right.inputs),
417 });
418 }
419
420 if left.outputs != right.outputs {
421 return Some(NodeDiff {
422 node_index: index,
423 description: format!(
424 "Different outputs: {:?} vs {:?}",
425 left.outputs, right.outputs
426 ),
427 });
428 }
429
430 if !ops_equal(&left.op, &right.op) {
431 return Some(NodeDiff {
432 node_index: index,
433 description: format!("Different operations: {:?} vs {:?}", left.op, right.op),
434 });
435 }
436
437 None
438}
439
440fn ops_equal(left: &OpType, right: &OpType) -> bool {
441 std::mem::discriminant(left) == std::mem::discriminant(right)
443}
444
/// Describes how two output index lists differ.
///
/// A length mismatch, if any, is reported first. After that, every position
/// present in both lists whose values differ gets its own entry, in order.
/// An empty result means the lists are equal.
fn diff_outputs(left: &[usize], right: &[usize]) -> Vec<String> {
    let count_mismatch = (left.len() != right.len()).then(|| {
        format!(
            "Different number of outputs: {} vs {}",
            left.len(),
            right.len()
        )
    });

    let positional = left
        .iter()
        .zip(right.iter())
        .enumerate()
        .filter(|(_, (l, r))| l != r)
        .map(|(i, (l, r))| format!("Output {} differs: {} vs {}", i, l, r));

    count_mismatch.into_iter().chain(positional).collect()
}
464
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Term;
    use std::collections::HashMap;

    /// Builds a one-argument predicate `name(var)`.
    fn unary_pred(name: &str, var: &str) -> TLExpr {
        TLExpr::pred(name, vec![Term::var(var)])
    }

    /// Builds a graph with no nodes and no inputs, holding only the given
    /// tensor names and output indices.
    fn nodeless_graph(tensors: &[&str], outputs: Vec<usize>) -> EinsumGraph {
        EinsumGraph {
            tensors: tensors.iter().map(|name| name.to_string()).collect(),
            nodes: vec![],
            inputs: vec![],
            outputs,
            tensor_metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_identical_exprs() {
        let diff = diff_exprs(&unary_pred("p", "x"), &unary_pred("p", "x"));
        assert!(diff.is_identical());
    }

    #[test]
    fn test_different_predicates() {
        let diff = diff_exprs(&unary_pred("p", "x"), &unary_pred("q", "x"));
        assert!(!diff.is_identical());
        assert!(matches!(diff, ExprDiff::PredicateMismatch { .. }));
    }

    #[test]
    fn test_different_types() {
        let predicate = unary_pred("p", "x");
        let constant = TLExpr::constant(1.0);

        let diff = diff_exprs(&predicate, &constant);
        assert!(!diff.is_identical());
        assert!(matches!(diff, ExprDiff::TypeMismatch { .. }));
    }

    #[test]
    fn test_nested_and_difference() {
        // The two conjunctions differ only in their right-hand predicate.
        let lhs = TLExpr::and(unary_pred("p", "x"), unary_pred("q", "y"));
        let rhs = TLExpr::and(unary_pred("p", "x"), unary_pred("r", "y"));

        let diff = diff_exprs(&lhs, &rhs);
        assert!(!diff.is_identical());
    }

    #[test]
    fn test_quantifier_difference() {
        let lhs = TLExpr::exists("x", "Domain1", unary_pred("p", "x"));
        let rhs = TLExpr::exists("y", "Domain2", unary_pred("p", "y"));

        let diff = diff_exprs(&lhs, &rhs);
        assert!(!diff.is_identical());
        assert!(matches!(diff, ExprDiff::QuantifierMismatch { .. }));
    }

    #[test]
    fn test_identical_graphs() {
        let lhs = nodeless_graph(&["t0"], vec![0]);
        let rhs = nodeless_graph(&["t0"], vec![0]);

        let diff = diff_graphs(&lhs, &rhs);
        assert!(diff.is_identical());
    }

    #[test]
    fn test_different_tensor_count() {
        let lhs = nodeless_graph(&["t0", "t1"], vec![]);
        let rhs = nodeless_graph(&["t0"], vec![]);

        let diff = diff_graphs(&lhs, &rhs);
        assert!(!diff.is_identical());
        assert_eq!(diff.left_only_tensors.len(), 1);
    }

    #[test]
    fn test_different_outputs() {
        let lhs = nodeless_graph(&["t0"], vec![0]);
        let rhs = nodeless_graph(&["t0"], vec![1]);

        let diff = diff_graphs(&lhs, &rhs);
        assert!(!diff.is_identical());
        assert!(!diff.output_differences.is_empty());
    }

    #[test]
    fn test_diff_summary() {
        let diff = GraphDiff {
            left_only_tensors: vec!["t1".to_string()],
            right_only_tensors: vec!["t2".to_string()],
            left_only_nodes: 0,
            right_only_nodes: 0,
            node_differences: vec![],
            output_differences: vec![],
        };

        let summary = diff.summary();
        assert!(summary.contains("tensors only in left"));
        assert!(summary.contains("tensors only in right"));
    }
}