tensorlogic_cli/
snapshot.rs

1//! Snapshot testing for output consistency
2//!
3//! This module provides snapshot testing capabilities to ensure that compilation
4//! outputs remain consistent across code changes and refactorings.
5
6use anyhow::Result;
7use serde::{Deserialize, Serialize};
8use std::fs;
9use std::path::{Path, PathBuf};
10use tensorlogic_compiler::{compile_to_einsum_with_context, CompilerContext};
11use tensorlogic_ir::{export_to_dot, TLExpr};
12
13use crate::analysis::GraphMetrics;
14
15/// A snapshot of compilation output for testing
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct CompilationSnapshot {
18    /// The expression that was compiled
19    pub expression: String,
20
21    /// Compilation strategy used
22    pub strategy: String,
23
24    /// Domains defined
25    pub domains: Vec<(String, usize)>,
26
27    /// Number of tensors in the graph
28    pub tensor_count: usize,
29
30    /// Number of nodes in the graph
31    pub node_count: usize,
32
33    /// Graph depth
34    pub depth: usize,
35
36    /// Operation breakdown
37    pub operations: std::collections::HashMap<String, usize>,
38
39    /// Estimated FLOPs
40    pub estimated_flops: u64,
41
42    /// Estimated memory (bytes)
43    pub estimated_memory: u64,
44
45    /// DOT format representation (for structural comparison)
46    pub dot_output: String,
47
48    /// JSON serialization (for complete graph structure)
49    pub json_output: String,
50
51    /// Creation timestamp
52    pub created_at: String,
53}
54
55impl CompilationSnapshot {
56    /// Create a snapshot from an expression and context
57    pub fn create(expr: &TLExpr, context: &CompilerContext, expr_string: &str) -> Result<Self> {
58        // Compile the expression
59        let mut ctx = context.clone();
60        let graph = compile_to_einsum_with_context(expr, &mut ctx)?;
61
62        // Analyze the graph
63        let metrics = GraphMetrics::analyze(&graph);
64
65        // Generate outputs
66        let dot_output = export_to_dot(&graph);
67        let json_output = serde_json::to_string_pretty(&graph)?;
68
69        // Extract domains
70        let domains: Vec<(String, usize)> = context
71            .domains
72            .iter()
73            .map(|(k, v)| (k.clone(), v.cardinality))
74            .collect();
75
76        Ok(Self {
77            expression: expr_string.to_string(),
78            strategy: format!("{:?}", context.config.and_strategy),
79            domains,
80            tensor_count: metrics.tensor_count,
81            node_count: metrics.node_count,
82            depth: metrics.depth,
83            operations: metrics.op_breakdown,
84            estimated_flops: metrics.estimated_flops,
85            estimated_memory: metrics.estimated_memory,
86            dot_output,
87            json_output,
88            created_at: chrono::Utc::now().to_rfc3339(),
89        })
90    }
91
92    /// Save snapshot to a file
93    pub fn save(&self, path: &Path) -> Result<()> {
94        let json = serde_json::to_string_pretty(self)?;
95        fs::write(path, json)?;
96        Ok(())
97    }
98
99    /// Load snapshot from a file
100    pub fn load(path: &Path) -> Result<Self> {
101        let content = fs::read_to_string(path)?;
102        let snapshot: Self = serde_json::from_str(&content)?;
103        Ok(snapshot)
104    }
105
106    /// Compare this snapshot with another (strict mode)
107    ///
108    /// Note: Timestamps are intentionally excluded from comparison
109    pub fn compare(&self, other: &Self) -> SnapshotDiff {
110        self.compare_with_options(other, true)
111    }
112
113    /// Compare this snapshot with another with options
114    ///
115    /// - `strict_dot`: If true, compare DOT output character-by-character
116    ///   If false, skip DOT comparison (useful when DOT has non-deterministic ordering)
117    pub fn compare_with_options(&self, other: &Self, strict_dot: bool) -> SnapshotDiff {
118        let mut differences = Vec::new();
119
120        // Check expression
121        if self.expression != other.expression {
122            differences.push(format!(
123                "Expression changed: '{}' -> '{}'",
124                self.expression, other.expression
125            ));
126        }
127
128        // Check strategy
129        if self.strategy != other.strategy {
130            differences.push(format!(
131                "Strategy changed: {} -> {}",
132                self.strategy, other.strategy
133            ));
134        }
135
136        // Check domains
137        if self.domains != other.domains {
138            differences.push(format!(
139                "Domains changed: {:?} -> {:?}",
140                self.domains, other.domains
141            ));
142        }
143
144        // Check tensor count
145        if self.tensor_count != other.tensor_count {
146            differences.push(format!(
147                "Tensor count changed: {} -> {}",
148                self.tensor_count, other.tensor_count
149            ));
150        }
151
152        // Check node count
153        if self.node_count != other.node_count {
154            differences.push(format!(
155                "Node count changed: {} -> {}",
156                self.node_count, other.node_count
157            ));
158        }
159
160        // Check depth
161        if self.depth != other.depth {
162            differences.push(format!("Depth changed: {} -> {}", self.depth, other.depth));
163        }
164
165        // Check operations
166        if self.operations != other.operations {
167            differences.push(format!(
168                "Operations changed: {:?} -> {:?}",
169                self.operations, other.operations
170            ));
171        }
172
173        // Check FLOPs (allow small variation due to floating point)
174        let flops_diff = self.estimated_flops.abs_diff(other.estimated_flops);
175
176        if flops_diff > 100 {
177            // Allow 100 FLOPs tolerance
178            differences.push(format!(
179                "Estimated FLOPs changed significantly: {} -> {}",
180                self.estimated_flops, other.estimated_flops
181            ));
182        }
183
184        // Check memory (allow small variation)
185        let mem_diff = self.estimated_memory.abs_diff(other.estimated_memory);
186
187        if mem_diff > 1000 {
188            // Allow 1KB tolerance
189            differences.push(format!(
190                "Estimated memory changed significantly: {} -> {}",
191                self.estimated_memory, other.estimated_memory
192            ));
193        }
194
195        // Check DOT output (structural comparison) - only if strict mode
196        if strict_dot && self.dot_output != other.dot_output {
197            differences.push("DOT output structure changed".to_string());
198        }
199
200        SnapshotDiff {
201            identical: differences.is_empty(),
202            differences,
203        }
204    }
205}
206
/// Outcome of comparing two snapshots.
#[derive(Debug, Clone)]
pub struct SnapshotDiff {
    /// Whether the snapshots are identical
    pub identical: bool,

    /// Human-readable description of each difference found
    pub differences: Vec<String>,
}

impl SnapshotDiff {
    /// Report whether the compared snapshots matched.
    pub fn is_match(&self) -> bool {
        self.identical
    }

    /// Print the comparison result.
    ///
    /// A match is reported with a single line on stdout. A mismatch is
    /// reported on stderr: a header line followed by one line per difference.
    #[allow(dead_code)]
    pub fn print_diff(&self) {
        if self.identical {
            println!("✓ Snapshots match");
            return;
        }

        eprintln!("✗ Snapshots differ:");
        self.differences
            .iter()
            .for_each(|entry| eprintln!("  - {}", entry));
    }
}
236
237/// Snapshot test suite manager
238pub struct SnapshotSuite {
239    /// Directory where snapshots are stored
240    snapshot_dir: PathBuf,
241
242    /// Test suite name
243    name: String,
244}
245
246impl SnapshotSuite {
247    /// Create a new snapshot suite
248    pub fn new(name: &str, snapshot_dir: PathBuf) -> Self {
249        Self {
250            snapshot_dir,
251            name: name.to_string(),
252        }
253    }
254
255    /// Get the path for a snapshot file
256    fn snapshot_path(&self, test_name: &str) -> PathBuf {
257        self.snapshot_dir
258            .join(format!("{}_{}.json", self.name, test_name))
259    }
260
261    /// Record a snapshot
262    pub fn record(
263        &self,
264        test_name: &str,
265        expr: &TLExpr,
266        context: &CompilerContext,
267        expr_string: &str,
268    ) -> Result<()> {
269        // Ensure snapshot directory exists
270        fs::create_dir_all(&self.snapshot_dir)?;
271
272        let snapshot = CompilationSnapshot::create(expr, context, expr_string)?;
273        let path = self.snapshot_path(test_name);
274        snapshot.save(&path)?;
275
276        println!("✓ Recorded snapshot: {}", test_name);
277        Ok(())
278    }
279
280    /// Verify against a recorded snapshot
281    pub fn verify(
282        &self,
283        test_name: &str,
284        expr: &TLExpr,
285        context: &CompilerContext,
286        expr_string: &str,
287    ) -> Result<SnapshotDiff> {
288        let path = self.snapshot_path(test_name);
289
290        if !path.exists() {
291            anyhow::bail!(
292                "Snapshot not found: {}. Run in record mode first.",
293                test_name
294            );
295        }
296
297        let recorded = CompilationSnapshot::load(&path)?;
298        let current = CompilationSnapshot::create(expr, context, expr_string)?;
299
300        Ok(recorded.compare(&current))
301    }
302
303    /// Update a snapshot (re-record)
304    pub fn update(
305        &self,
306        test_name: &str,
307        expr: &TLExpr,
308        context: &CompilerContext,
309        expr_string: &str,
310    ) -> Result<()> {
311        self.record(test_name, expr, context, expr_string)
312    }
313
314    /// List all snapshots in the suite
315    pub fn list_snapshots(&self) -> Result<Vec<String>> {
316        let mut snapshots = Vec::new();
317
318        if !self.snapshot_dir.exists() {
319            return Ok(snapshots);
320        }
321
322        for entry in fs::read_dir(&self.snapshot_dir)? {
323            let entry = entry?;
324            let path = entry.path();
325
326            if let Some(filename) = path.file_name() {
327                if let Some(name) = filename.to_str() {
328                    if name.starts_with(&self.name) && name.ends_with(".json") {
329                        // Extract test name
330                        let test_name = name
331                            .trim_start_matches(&format!("{}_", self.name))
332                            .trim_end_matches(".json");
333                        snapshots.push(test_name.to_string());
334                    }
335                }
336            }
337        }
338
339        Ok(snapshots)
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_compiler::CompilationConfig;
    use tensorlogic_ir::Term;

    /// Textual form of the expression built by `create_test_expr`.
    const TEST_EXPR: &str = "knows(x, y) AND likes(y, z)";

    /// Build a predicate over two variables.
    fn binary_pred(name: &str, first: &str, second: &str) -> TLExpr {
        TLExpr::Pred {
            name: name.to_string(),
            args: vec![Term::Var(first.to_string()), Term::Var(second.to_string())],
        }
    }

    /// Build `knows(x, y) AND likes(y, z)`.
    fn create_test_expr() -> TLExpr {
        TLExpr::And(
            Box::new(binary_pred("knows", "x", "y")),
            Box::new(binary_pred("likes", "y", "z")),
        )
    }

    /// Build a context with a single domain `D` of cardinality 100.
    fn create_test_context() -> CompilerContext {
        let config = CompilationConfig::soft_differentiable();
        let mut ctx = CompilerContext::with_config(config);
        ctx.add_domain("D", 100);
        ctx
    }

    /// Build a temp-dir path that is unique to this test process, so that
    /// concurrent test runs do not overwrite each other's files.
    fn unique_temp_path(label: &str) -> PathBuf {
        std::env::temp_dir().join(format!("tensorlogic_{}_{}", std::process::id(), label))
    }

    #[test]
    fn test_snapshot_creation() {
        let expr = create_test_expr();
        let context = create_test_context();

        let snapshot = CompilationSnapshot::create(&expr, &context, TEST_EXPR).unwrap();

        assert_eq!(snapshot.expression, TEST_EXPR);
        assert!(snapshot.tensor_count > 0);
        assert!(snapshot.node_count > 0);
        assert!(!snapshot.dot_output.is_empty());
        assert!(!snapshot.json_output.is_empty());
    }

    #[test]
    fn test_snapshot_save_load() {
        let expr = create_test_expr();
        let context = create_test_context();
        let snapshot = CompilationSnapshot::create(&expr, &context, TEST_EXPR).unwrap();

        let path = unique_temp_path("test_snapshot.json");

        snapshot.save(&path).unwrap();
        let loaded = CompilationSnapshot::load(&path);

        // Cleanup before asserting so a failure does not leave the file behind
        let _ = fs::remove_file(&path);

        let loaded = loaded.unwrap();
        assert_eq!(snapshot.expression, loaded.expression);
        assert_eq!(snapshot.tensor_count, loaded.tensor_count);
        assert_eq!(snapshot.node_count, loaded.node_count);

        // A round-tripped snapshot is an exact copy, so even the strict
        // comparison must report a match.
        assert!(snapshot.compare(&loaded).is_match());
    }

    #[test]
    fn test_snapshot_comparison_identical() {
        let expr = create_test_expr();
        let context = create_test_context();

        let snapshot1 = CompilationSnapshot::create(&expr, &context, TEST_EXPR).unwrap();
        let snapshot2 = CompilationSnapshot::create(&expr, &context, TEST_EXPR).unwrap();

        // Skip strict DOT comparison since identical compilations might have
        // different internal orderings (e.g., HashMap iteration order)
        let diff = snapshot1.compare_with_options(&snapshot2, false);
        assert!(
            diff.is_match(),
            "Differences found: {:?}",
            diff.differences
        );
        assert!(diff.differences.is_empty());
    }

    #[test]
    fn test_snapshot_comparison_different() {
        let expr1 = create_test_expr();
        let expr2 = binary_pred("knows", "x", "y");

        let context = create_test_context();

        let snapshot1 = CompilationSnapshot::create(&expr1, &context, TEST_EXPR).unwrap();
        let snapshot2 = CompilationSnapshot::create(&expr2, &context, "knows(x, y)").unwrap();

        let diff = snapshot1.compare(&snapshot2);
        assert!(!diff.is_match());
        assert!(!diff.differences.is_empty());
    }

    #[test]
    fn test_snapshot_suite() {
        let temp_dir = unique_temp_path("snapshots_test");
        let suite = SnapshotSuite::new("test_suite", temp_dir.clone());

        let expr = create_test_expr();
        let context = create_test_context();

        // Record a snapshot
        suite.record("test1", &expr, &context, TEST_EXPR).unwrap();

        // Compare against the recorded snapshot without strict DOT comparison,
        // since DOT output may be ordered differently between compilations.
        let current = CompilationSnapshot::create(&expr, &context, TEST_EXPR).unwrap();
        let recorded = CompilationSnapshot::load(&suite.snapshot_path("test1"));
        let snapshots = suite.list_snapshots();

        // Cleanup before asserting so a failure does not leave the directory behind
        let _ = fs::remove_dir_all(&temp_dir);

        let diff = recorded.unwrap().compare_with_options(&current, false);
        assert!(
            diff.is_match(),
            "Differences found: {:?}",
            diff.differences
        );

        // List snapshots
        assert!(snapshots.unwrap().contains(&"test1".to_string()));
    }
}
477}