1use crate::{EinsumGraph, EinsumNode, OpType, TLExpr};
7use std::collections::HashSet;
8
/// Result of structurally comparing two `TLExpr` trees.
///
/// Every variant other than [`ExprDiff::Identical`] carries a textual
/// rendering of the left-hand and right-hand side of the first difference
/// that was found.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExprDiff {
    /// No difference was found.
    Identical,

    /// The two expressions are of different kinds.
    TypeMismatch {
        /// Rendering of the left expression's kind.
        left: String,
        /// Rendering of the right expression's kind.
        right: String,
    },

    /// Two predicates disagree.
    PredicateMismatch {
        /// Rendering of the left predicate.
        left: String,
        /// Rendering of the right predicate.
        right: String,
    },

    /// A nested sub-expression differs.
    SubexprMismatch {
        /// Steps from the root expression down to the differing node.
        path: Vec<String>,
        /// Rendering of the left sub-expression.
        left: String,
        /// Rendering of the right sub-expression.
        right: String,
    },

    /// Two quantifiers bind a different variable or range over a different
    /// domain.
    QuantifierMismatch {
        /// Variable bound on the left.
        left_var: String,
        /// Variable bound on the right.
        right_var: String,
        /// Domain quantified over on the left.
        left_domain: String,
        /// Domain quantified over on the right.
        right_domain: String,
    },
}
32
/// Differences found between two `EinsumGraph`s.
///
/// The value produced by `Default` has every list empty and every count zero,
/// i.e. it describes two graphs with no differences.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GraphDiff {
    /// Tensor names present only in the left graph.
    pub left_only_tensors: Vec<String>,
    /// Tensor names present only in the right graph.
    pub right_only_tensors: Vec<String>,
    /// Number of nodes only the left graph has.
    pub left_only_nodes: usize,
    /// Number of nodes only the right graph has.
    pub right_only_nodes: usize,
    /// Per-node differences.
    pub node_differences: Vec<NodeDiff>,
    /// Human-readable descriptions of differences between the output lists.
    pub output_differences: Vec<String>,
}

/// A difference between two graph nodes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NodeDiff {
    /// Index of the node the difference refers to.
    pub node_index: usize,
    /// Human-readable description of the difference.
    pub description: String,
}
56
57impl ExprDiff {
58 pub fn is_identical(&self) -> bool {
60 matches!(self, ExprDiff::Identical)
61 }
62
63 pub fn description(&self) -> String {
65 match self {
66 ExprDiff::Identical => "Expressions are identical".to_string(),
67 ExprDiff::TypeMismatch { left, right } => {
68 format!("Type mismatch: left={}, right={}", left, right)
69 }
70 ExprDiff::PredicateMismatch { left, right } => {
71 format!("Predicate mismatch: left={}, right={}", left, right)
72 }
73 ExprDiff::SubexprMismatch { path, left, right } => {
74 format!(
75 "Subexpression mismatch at {}: left={}, right={}",
76 path.join("/"),
77 left,
78 right
79 )
80 }
81 ExprDiff::QuantifierMismatch {
82 left_var,
83 right_var,
84 left_domain,
85 right_domain,
86 } => {
87 format!(
88 "Quantifier mismatch: left=({}, {}), right=({}, {})",
89 left_var, left_domain, right_var, right_domain
90 )
91 }
92 }
93 }
94}
95
96impl GraphDiff {
97 pub fn is_identical(&self) -> bool {
99 self.left_only_tensors.is_empty()
100 && self.right_only_tensors.is_empty()
101 && self.left_only_nodes == 0
102 && self.right_only_nodes == 0
103 && self.node_differences.is_empty()
104 && self.output_differences.is_empty()
105 }
106
107 pub fn summary(&self) -> String {
109 if self.is_identical() {
110 return "Graphs are identical".to_string();
111 }
112
113 let mut parts = Vec::new();
114
115 if !self.left_only_tensors.is_empty() {
116 parts.push(format!(
117 "{} tensors only in left",
118 self.left_only_tensors.len()
119 ));
120 }
121 if !self.right_only_tensors.is_empty() {
122 parts.push(format!(
123 "{} tensors only in right",
124 self.right_only_tensors.len()
125 ));
126 }
127 if self.left_only_nodes > 0 {
128 parts.push(format!("{} nodes only in left", self.left_only_nodes));
129 }
130 if self.right_only_nodes > 0 {
131 parts.push(format!("{} nodes only in right", self.right_only_nodes));
132 }
133 if !self.node_differences.is_empty() {
134 parts.push(format!("{} node differences", self.node_differences.len()));
135 }
136 if !self.output_differences.is_empty() {
137 parts.push(format!(
138 "{} output differences",
139 self.output_differences.len()
140 ));
141 }
142
143 parts.join(", ")
144 }
145}
146
147pub fn diff_exprs(left: &TLExpr, right: &TLExpr) -> ExprDiff {
149 diff_exprs_impl(left, right, &mut Vec::new())
150}
151
152fn diff_exprs_impl(left: &TLExpr, right: &TLExpr, path: &mut Vec<String>) -> ExprDiff {
153 match (left, right) {
154 (TLExpr::Pred { name: n1, args: a1 }, TLExpr::Pred { name: n2, args: a2 }) => {
155 if n1 != n2 || a1.len() != a2.len() {
156 ExprDiff::PredicateMismatch {
157 left: format!("{}({})", n1, a1.len()),
158 right: format!("{}({})", n2, a2.len()),
159 }
160 } else {
161 ExprDiff::Identical
162 }
163 }
164 (TLExpr::And(l1, r1), TLExpr::And(l2, r2))
165 | (TLExpr::Or(l1, r1), TLExpr::Or(l2, r2))
166 | (TLExpr::Imply(l1, r1), TLExpr::Imply(l2, r2))
167 | (TLExpr::Add(l1, r1), TLExpr::Add(l2, r2))
168 | (TLExpr::Sub(l1, r1), TLExpr::Sub(l2, r2))
169 | (TLExpr::Mul(l1, r1), TLExpr::Mul(l2, r2))
170 | (TLExpr::Div(l1, r1), TLExpr::Div(l2, r2))
171 | (TLExpr::Pow(l1, r1), TLExpr::Pow(l2, r2))
172 | (TLExpr::Mod(l1, r1), TLExpr::Mod(l2, r2))
173 | (TLExpr::Min(l1, r1), TLExpr::Min(l2, r2))
174 | (TLExpr::Max(l1, r1), TLExpr::Max(l2, r2))
175 | (TLExpr::Eq(l1, r1), TLExpr::Eq(l2, r2))
176 | (TLExpr::Lt(l1, r1), TLExpr::Lt(l2, r2))
177 | (TLExpr::Gt(l1, r1), TLExpr::Gt(l2, r2))
178 | (TLExpr::Lte(l1, r1), TLExpr::Lte(l2, r2))
179 | (TLExpr::Gte(l1, r1), TLExpr::Gte(l2, r2)) => {
180 path.push("left".to_string());
181 let left_diff = diff_exprs_impl(l1, l2, path);
182 path.pop();
183
184 if !left_diff.is_identical() {
185 return left_diff;
186 }
187
188 path.push("right".to_string());
189 let right_diff = diff_exprs_impl(r1, r2, path);
190 path.pop();
191
192 right_diff
193 }
194 (TLExpr::Not(e1), TLExpr::Not(e2))
195 | (TLExpr::Score(e1), TLExpr::Score(e2))
196 | (TLExpr::Abs(e1), TLExpr::Abs(e2))
197 | (TLExpr::Floor(e1), TLExpr::Floor(e2))
198 | (TLExpr::Ceil(e1), TLExpr::Ceil(e2))
199 | (TLExpr::Round(e1), TLExpr::Round(e2))
200 | (TLExpr::Sqrt(e1), TLExpr::Sqrt(e2))
201 | (TLExpr::Exp(e1), TLExpr::Exp(e2))
202 | (TLExpr::Log(e1), TLExpr::Log(e2))
203 | (TLExpr::Sin(e1), TLExpr::Sin(e2))
204 | (TLExpr::Cos(e1), TLExpr::Cos(e2))
205 | (TLExpr::Tan(e1), TLExpr::Tan(e2)) => {
206 path.push("inner".to_string());
207 let diff = diff_exprs_impl(e1, e2, path);
208 path.pop();
209 diff
210 }
211 (
212 TLExpr::Exists {
213 var: v1,
214 domain: d1,
215 body: b1,
216 },
217 TLExpr::Exists {
218 var: v2,
219 domain: d2,
220 body: b2,
221 },
222 )
223 | (
224 TLExpr::ForAll {
225 var: v1,
226 domain: d1,
227 body: b1,
228 },
229 TLExpr::ForAll {
230 var: v2,
231 domain: d2,
232 body: b2,
233 },
234 ) => {
235 if v1 != v2 || d1 != d2 {
236 return ExprDiff::QuantifierMismatch {
237 left_var: v1.clone(),
238 right_var: v2.clone(),
239 left_domain: d1.clone(),
240 right_domain: d2.clone(),
241 };
242 }
243
244 path.push("body".to_string());
245 let diff = diff_exprs_impl(b1, b2, path);
246 path.pop();
247 diff
248 }
249 (
250 TLExpr::IfThenElse {
251 condition: c1,
252 then_branch: t1,
253 else_branch: e1,
254 },
255 TLExpr::IfThenElse {
256 condition: c2,
257 then_branch: t2,
258 else_branch: e2,
259 },
260 ) => {
261 path.push("condition".to_string());
262 let cond_diff = diff_exprs_impl(c1, c2, path);
263 path.pop();
264
265 if !cond_diff.is_identical() {
266 return cond_diff;
267 }
268
269 path.push("then".to_string());
270 let then_diff = diff_exprs_impl(t1, t2, path);
271 path.pop();
272
273 if !then_diff.is_identical() {
274 return then_diff;
275 }
276
277 path.push("else".to_string());
278 let else_diff = diff_exprs_impl(e1, e2, path);
279 path.pop();
280
281 else_diff
282 }
283 (TLExpr::Constant(c1), TLExpr::Constant(c2)) => {
284 if (c1 - c2).abs() < f64::EPSILON {
285 ExprDiff::Identical
286 } else {
287 ExprDiff::SubexprMismatch {
288 path: path.clone(),
289 left: format!("{}", c1),
290 right: format!("{}", c2),
291 }
292 }
293 }
294 _ => ExprDiff::TypeMismatch {
295 left: format!("{:?}", left).split('(').next().unwrap().to_string(),
296 right: format!("{:?}", right)
297 .split('(')
298 .next()
299 .unwrap()
300 .to_string(),
301 },
302 }
303}
304
305pub fn diff_graphs(left: &EinsumGraph, right: &EinsumGraph) -> GraphDiff {
307 let left_tensors: HashSet<_> = left.tensors.iter().collect();
308 let right_tensors: HashSet<_> = right.tensors.iter().collect();
309
310 let left_only_tensors: Vec<String> = left_tensors
311 .difference(&right_tensors)
312 .map(|s| s.to_string())
313 .collect();
314 let right_only_tensors: Vec<String> = right_tensors
315 .difference(&left_tensors)
316 .map(|s| s.to_string())
317 .collect();
318
319 let node_differences = diff_nodes(&left.nodes, &right.nodes);
320
321 let left_only_nodes = if left.nodes.len() > right.nodes.len() {
322 left.nodes.len() - right.nodes.len()
323 } else {
324 0
325 };
326 let right_only_nodes = if right.nodes.len() > left.nodes.len() {
327 right.nodes.len() - left.nodes.len()
328 } else {
329 0
330 };
331
332 let output_differences = diff_outputs(&left.outputs, &right.outputs);
333
334 GraphDiff {
335 left_only_tensors,
336 right_only_tensors,
337 left_only_nodes,
338 right_only_nodes,
339 node_differences,
340 output_differences,
341 }
342}
343
344fn diff_nodes(left: &[EinsumNode], right: &[EinsumNode]) -> Vec<NodeDiff> {
345 let mut differences = Vec::new();
346 let min_len = left.len().min(right.len());
347
348 for i in 0..min_len {
349 if let Some(diff) = diff_node(&left[i], &right[i], i) {
350 differences.push(diff);
351 }
352 }
353
354 differences
355}
356
357fn diff_node(left: &EinsumNode, right: &EinsumNode, index: usize) -> Option<NodeDiff> {
358 if left.inputs != right.inputs {
359 return Some(NodeDiff {
360 node_index: index,
361 description: format!("Different inputs: {:?} vs {:?}", left.inputs, right.inputs),
362 });
363 }
364
365 if left.outputs != right.outputs {
366 return Some(NodeDiff {
367 node_index: index,
368 description: format!(
369 "Different outputs: {:?} vs {:?}",
370 left.outputs, right.outputs
371 ),
372 });
373 }
374
375 if !ops_equal(&left.op, &right.op) {
376 return Some(NodeDiff {
377 node_index: index,
378 description: format!("Different operations: {:?} vs {:?}", left.op, right.op),
379 });
380 }
381
382 None
383}
384
385fn ops_equal(left: &OpType, right: &OpType) -> bool {
386 std::mem::discriminant(left) == std::mem::discriminant(right)
388}
389
/// Describes how two output index lists differ.
///
/// A length mismatch, if any, is reported first. After that comes one entry
/// per position in the common prefix where the two indices disagree. An empty
/// result means the lists are equal.
fn diff_outputs(left: &[usize], right: &[usize]) -> Vec<String> {
    let length_note = if left.len() == right.len() {
        None
    } else {
        Some(format!(
            "Different number of outputs: {} vs {}",
            left.len(),
            right.len()
        ))
    };

    let element_notes = left
        .iter()
        .zip(right)
        .enumerate()
        .filter(|(_, (lhs, rhs))| lhs != rhs)
        .map(|(slot, (lhs, rhs))| format!("Output {} differs: {} vs {}", slot, lhs, rhs));

    length_note.into_iter().chain(element_notes).collect()
}
409
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Term;

    /// Builds the unary predicate `name(var)`.
    fn unary(name: &str, var: &str) -> TLExpr {
        TLExpr::pred(name, vec![Term::var(var)])
    }

    /// Builds a graph without nodes or inputs from tensor names and output
    /// indices.
    fn graph(tensors: &[&str], outputs: Vec<usize>) -> EinsumGraph {
        EinsumGraph {
            tensors: tensors.iter().map(|name| name.to_string()).collect(),
            nodes: vec![],
            inputs: vec![],
            outputs,
            tensor_metadata: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_identical_exprs() {
        let diff = diff_exprs(&unary("p", "x"), &unary("p", "x"));
        assert!(diff.is_identical());
    }

    #[test]
    fn test_different_predicates() {
        let diff = diff_exprs(&unary("p", "x"), &unary("q", "x"));
        assert!(!diff.is_identical());
        assert!(matches!(diff, ExprDiff::PredicateMismatch { .. }));
    }

    #[test]
    fn test_different_types() {
        let predicate = unary("p", "x");
        let constant = TLExpr::constant(1.0);

        let diff = diff_exprs(&predicate, &constant);
        assert!(!diff.is_identical());
        assert!(matches!(diff, ExprDiff::TypeMismatch { .. }));
    }

    #[test]
    fn test_nested_and_difference() {
        // The conjunctions share their left operand and differ on the right.
        let lhs = TLExpr::and(unary("p", "x"), unary("q", "y"));
        let rhs = TLExpr::and(unary("p", "x"), unary("r", "y"));

        let diff = diff_exprs(&lhs, &rhs);
        assert!(!diff.is_identical());
    }

    #[test]
    fn test_quantifier_difference() {
        let lhs = TLExpr::exists("x", "Domain1", unary("p", "x"));
        let rhs = TLExpr::exists("y", "Domain2", unary("p", "y"));

        let diff = diff_exprs(&lhs, &rhs);
        assert!(!diff.is_identical());
        assert!(matches!(diff, ExprDiff::QuantifierMismatch { .. }));
    }

    #[test]
    fn test_identical_graphs() {
        let lhs = graph(&["t0"], vec![0]);
        let rhs = graph(&["t0"], vec![0]);

        let diff = diff_graphs(&lhs, &rhs);
        assert!(diff.is_identical());
    }

    #[test]
    fn test_different_tensor_count() {
        let lhs = graph(&["t0", "t1"], vec![]);
        let rhs = graph(&["t0"], vec![]);

        let diff = diff_graphs(&lhs, &rhs);
        assert!(!diff.is_identical());
        assert_eq!(diff.left_only_tensors.len(), 1);
    }

    #[test]
    fn test_different_outputs() {
        let lhs = graph(&["t0"], vec![0]);
        let rhs = graph(&["t0"], vec![1]);

        let diff = diff_graphs(&lhs, &rhs);
        assert!(!diff.is_identical());
        assert!(!diff.output_differences.is_empty());
    }

    #[test]
    fn test_diff_summary() {
        let diff = GraphDiff {
            left_only_tensors: vec!["t1".to_string()],
            right_only_tensors: vec!["t2".to_string()],
            left_only_nodes: 0,
            right_only_nodes: 0,
            node_differences: vec![],
            output_differences: vec![],
        };

        let summary = diff.summary();
        assert!(summary.contains("tensors only in left"));
        assert!(summary.contains("tensors only in right"));
    }
}