use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::{Path, PathBuf};
use tensorlogic_compiler::{compile_to_einsum_with_context, CompilerContext};
use tensorlogic_ir::{export_to_dot, TLExpr};
12
13use crate::analysis::GraphMetrics;
14
/// A recorded summary of a single compilation: the input expression, the compiler
/// settings captured from the context, metrics of the compiled graph, and its DOT and
/// JSON serializations.
///
/// Persisted as pretty-printed JSON via `save` / `load`. Field declaration order
/// determines the key order of that JSON, so do not reorder fields casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompilationSnapshot {
    /// Human-readable form of the compiled expression, exactly as supplied by the caller.
    pub expression: String,

    /// `Debug` rendering of the context's `config.and_strategy`.
    pub strategy: String,

    /// `(domain name, cardinality)` pairs taken from the caller's compiler context.
    pub domains: Vec<(String, usize)>,

    /// Tensor count of the compiled graph, as reported by `GraphMetrics`.
    pub tensor_count: usize,

    /// Node count of the compiled graph, as reported by `GraphMetrics`.
    pub node_count: usize,

    /// Graph depth, as reported by `GraphMetrics`.
    pub depth: usize,

    /// Per-operation counts (`GraphMetrics::op_breakdown`), keyed by operation name.
    pub operations: std::collections::HashMap<String, usize>,

    /// Estimated floating-point operation count reported by `GraphMetrics`.
    pub estimated_flops: u64,

    /// Estimated memory reported by `GraphMetrics` (NOTE(review): presumably bytes —
    /// confirm against `GraphMetrics`).
    pub estimated_memory: u64,

    /// Graphviz DOT export of the compiled graph (`export_to_dot`).
    pub dot_output: String,

    /// Pretty-printed JSON serialization of the compiled graph.
    pub json_output: String,

    /// Creation timestamp; `create` sets it from `chrono::Utc::now().to_rfc3339()`.
    pub created_at: String,
}
54
55impl CompilationSnapshot {
56 pub fn create(expr: &TLExpr, context: &CompilerContext, expr_string: &str) -> Result<Self> {
58 let mut ctx = context.clone();
60 let graph = compile_to_einsum_with_context(expr, &mut ctx)?;
61
62 let metrics = GraphMetrics::analyze(&graph);
64
65 let dot_output = export_to_dot(&graph);
67 let json_output = serde_json::to_string_pretty(&graph)?;
68
69 let domains: Vec<(String, usize)> = context
71 .domains
72 .iter()
73 .map(|(k, v)| (k.clone(), v.cardinality))
74 .collect();
75
76 Ok(Self {
77 expression: expr_string.to_string(),
78 strategy: format!("{:?}", context.config.and_strategy),
79 domains,
80 tensor_count: metrics.tensor_count,
81 node_count: metrics.node_count,
82 depth: metrics.depth,
83 operations: metrics.op_breakdown,
84 estimated_flops: metrics.estimated_flops,
85 estimated_memory: metrics.estimated_memory,
86 dot_output,
87 json_output,
88 created_at: chrono::Utc::now().to_rfc3339(),
89 })
90 }
91
92 pub fn save(&self, path: &Path) -> Result<()> {
94 let json = serde_json::to_string_pretty(self)?;
95 fs::write(path, json)?;
96 Ok(())
97 }
98
99 pub fn load(path: &Path) -> Result<Self> {
101 let content = fs::read_to_string(path)?;
102 let snapshot: Self = serde_json::from_str(&content)?;
103 Ok(snapshot)
104 }
105
106 pub fn compare(&self, other: &Self) -> SnapshotDiff {
110 self.compare_with_options(other, true)
111 }
112
113 pub fn compare_with_options(&self, other: &Self, strict_dot: bool) -> SnapshotDiff {
118 let mut differences = Vec::new();
119
120 if self.expression != other.expression {
122 differences.push(format!(
123 "Expression changed: '{}' -> '{}'",
124 self.expression, other.expression
125 ));
126 }
127
128 if self.strategy != other.strategy {
130 differences.push(format!(
131 "Strategy changed: {} -> {}",
132 self.strategy, other.strategy
133 ));
134 }
135
136 if self.domains != other.domains {
138 differences.push(format!(
139 "Domains changed: {:?} -> {:?}",
140 self.domains, other.domains
141 ));
142 }
143
144 if self.tensor_count != other.tensor_count {
146 differences.push(format!(
147 "Tensor count changed: {} -> {}",
148 self.tensor_count, other.tensor_count
149 ));
150 }
151
152 if self.node_count != other.node_count {
154 differences.push(format!(
155 "Node count changed: {} -> {}",
156 self.node_count, other.node_count
157 ));
158 }
159
160 if self.depth != other.depth {
162 differences.push(format!("Depth changed: {} -> {}", self.depth, other.depth));
163 }
164
165 if self.operations != other.operations {
167 differences.push(format!(
168 "Operations changed: {:?} -> {:?}",
169 self.operations, other.operations
170 ));
171 }
172
173 let flops_diff = self.estimated_flops.abs_diff(other.estimated_flops);
175
176 if flops_diff > 100 {
177 differences.push(format!(
179 "Estimated FLOPs changed significantly: {} -> {}",
180 self.estimated_flops, other.estimated_flops
181 ));
182 }
183
184 let mem_diff = self.estimated_memory.abs_diff(other.estimated_memory);
186
187 if mem_diff > 1000 {
188 differences.push(format!(
190 "Estimated memory changed significantly: {} -> {}",
191 self.estimated_memory, other.estimated_memory
192 ));
193 }
194
195 if strict_dot && self.dot_output != other.dot_output {
197 differences.push("DOT output structure changed".to_string());
198 }
199
200 SnapshotDiff {
201 identical: differences.is_empty(),
202 differences,
203 }
204 }
205}
206
/// Outcome of comparing two compilation snapshots.
#[derive(Debug, Clone)]
pub struct SnapshotDiff {
    /// `true` when the comparison recorded no differences.
    pub identical: bool,

    /// One human-readable line per detected difference.
    pub differences: Vec<String>,
}

impl SnapshotDiff {
    /// Reports whether the compared snapshots matched (mirrors `identical`).
    pub fn is_match(&self) -> bool {
        self.identical
    }

    /// Prints a match notice to stdout, or the list of differences to stderr.
    // `dead_code` is allowed because this helper may have no callers in this crate.
    #[allow(dead_code)]
    pub fn print_diff(&self) {
        if !self.is_match() {
            eprintln!("✗ Snapshots differ:");
            self.differences
                .iter()
                .for_each(|entry| eprintln!(" - {}", entry));
            return;
        }
        println!("✓ Snapshots match");
    }
}
236
237pub struct SnapshotSuite {
239 snapshot_dir: PathBuf,
241
242 name: String,
244}
245
246impl SnapshotSuite {
247 pub fn new(name: &str, snapshot_dir: PathBuf) -> Self {
249 Self {
250 snapshot_dir,
251 name: name.to_string(),
252 }
253 }
254
255 fn snapshot_path(&self, test_name: &str) -> PathBuf {
257 self.snapshot_dir
258 .join(format!("{}_{}.json", self.name, test_name))
259 }
260
261 pub fn record(
263 &self,
264 test_name: &str,
265 expr: &TLExpr,
266 context: &CompilerContext,
267 expr_string: &str,
268 ) -> Result<()> {
269 fs::create_dir_all(&self.snapshot_dir)?;
271
272 let snapshot = CompilationSnapshot::create(expr, context, expr_string)?;
273 let path = self.snapshot_path(test_name);
274 snapshot.save(&path)?;
275
276 println!("✓ Recorded snapshot: {}", test_name);
277 Ok(())
278 }
279
280 pub fn verify(
282 &self,
283 test_name: &str,
284 expr: &TLExpr,
285 context: &CompilerContext,
286 expr_string: &str,
287 ) -> Result<SnapshotDiff> {
288 let path = self.snapshot_path(test_name);
289
290 if !path.exists() {
291 anyhow::bail!(
292 "Snapshot not found: {}. Run in record mode first.",
293 test_name
294 );
295 }
296
297 let recorded = CompilationSnapshot::load(&path)?;
298 let current = CompilationSnapshot::create(expr, context, expr_string)?;
299
300 Ok(recorded.compare(¤t))
301 }
302
303 pub fn update(
305 &self,
306 test_name: &str,
307 expr: &TLExpr,
308 context: &CompilerContext,
309 expr_string: &str,
310 ) -> Result<()> {
311 self.record(test_name, expr, context, expr_string)
312 }
313
314 pub fn list_snapshots(&self) -> Result<Vec<String>> {
316 let mut snapshots = Vec::new();
317
318 if !self.snapshot_dir.exists() {
319 return Ok(snapshots);
320 }
321
322 for entry in fs::read_dir(&self.snapshot_dir)? {
323 let entry = entry?;
324 let path = entry.path();
325
326 if let Some(filename) = path.file_name() {
327 if let Some(name) = filename.to_str() {
328 if name.starts_with(&self.name) && name.ends_with(".json") {
329 let test_name = name
331 .trim_start_matches(&format!("{}_", self.name))
332 .trim_end_matches(".json");
333 snapshots.push(test_name.to_string());
334 }
335 }
336 }
337 }
338
339 Ok(snapshots)
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_compiler::CompilationConfig;
    use tensorlogic_ir::Term;

    /// Human-readable label for the expression built by `create_test_expr`.
    const TEST_EXPR_STRING: &str = "knows(x, y) AND likes(y, z)";

    /// Builds `knows(x, y) AND likes(y, z)`.
    fn create_test_expr() -> TLExpr {
        TLExpr::And(
            Box::new(TLExpr::Pred {
                name: "knows".to_string(),
                args: vec![Term::Var("x".to_string()), Term::Var("y".to_string())],
            }),
            Box::new(TLExpr::Pred {
                name: "likes".to_string(),
                args: vec![Term::Var("y".to_string()), Term::Var("z".to_string())],
            }),
        )
    }

    /// Builds a soft-differentiable context with a single domain `D` of size 100.
    fn create_test_context() -> CompilerContext {
        let config = CompilationConfig::soft_differentiable();
        let mut ctx = CompilerContext::with_config(config);
        ctx.add_domain("D", 100);
        ctx
    }

    /// Returns a path in the system temp directory prefixed with this process's id,
    /// so concurrently running test binaries do not share or delete each other's files.
    fn unique_temp_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("{}_{}", std::process::id(), name))
    }

    #[test]
    fn test_snapshot_creation() {
        let expr = create_test_expr();
        let context = create_test_context();

        let snapshot = CompilationSnapshot::create(&expr, &context, TEST_EXPR_STRING).unwrap();

        assert_eq!(snapshot.expression, TEST_EXPR_STRING);
        assert!(snapshot.tensor_count > 0);
        assert!(snapshot.node_count > 0);
        assert!(!snapshot.dot_output.is_empty());
        assert!(!snapshot.json_output.is_empty());
    }

    #[test]
    fn test_snapshot_save_load() {
        let expr = create_test_expr();
        let context = create_test_context();
        let snapshot = CompilationSnapshot::create(&expr, &context, TEST_EXPR_STRING).unwrap();

        let path = unique_temp_path("test_snapshot.json");

        snapshot.save(&path).unwrap();
        let loaded = CompilationSnapshot::load(&path);
        // Remove the file before asserting so a failing assertion does not leak it.
        let _ = fs::remove_file(&path);
        let loaded = loaded.unwrap();

        assert_eq!(snapshot.expression, loaded.expression);
        assert_eq!(snapshot.strategy, loaded.strategy);
        assert_eq!(snapshot.domains, loaded.domains);
        assert_eq!(snapshot.tensor_count, loaded.tensor_count);
        assert_eq!(snapshot.node_count, loaded.node_count);
        assert_eq!(snapshot.depth, loaded.depth);
        assert_eq!(snapshot.operations, loaded.operations);
        assert_eq!(snapshot.dot_output, loaded.dot_output);
        assert_eq!(snapshot.created_at, loaded.created_at);
    }

    #[test]
    fn test_snapshot_comparison_identical() {
        let expr = create_test_expr();
        let context = create_test_context();

        let snapshot1 = CompilationSnapshot::create(&expr, &context, TEST_EXPR_STRING).unwrap();
        let snapshot2 = CompilationSnapshot::create(&expr, &context, TEST_EXPR_STRING).unwrap();

        // DOT output is excluded (strict_dot = false). NOTE(review): presumably DOT
        // export is not guaranteed to be byte-stable across compilations — confirm
        // before enabling the strict comparison here.
        let diff = snapshot1.compare_with_options(&snapshot2, false);
        if !diff.is_match() {
            eprintln!("Differences found:");
            for d in &diff.differences {
                eprintln!(" {}", d);
            }
        }
        assert!(diff.is_match());
        assert!(diff.differences.is_empty());
    }

    #[test]
    fn test_snapshot_comparison_different() {
        let expr1 = create_test_expr();
        let expr2 = TLExpr::Pred {
            name: "knows".to_string(),
            args: vec![Term::Var("x".to_string()), Term::Var("y".to_string())],
        };

        let context = create_test_context();

        let snapshot1 = CompilationSnapshot::create(&expr1, &context, TEST_EXPR_STRING).unwrap();
        let snapshot2 = CompilationSnapshot::create(&expr2, &context, "knows(x, y)").unwrap();

        let diff = snapshot1.compare(&snapshot2);
        assert!(!diff.is_match());
        assert!(!diff.differences.is_empty());
    }

    #[test]
    fn test_snapshot_suite() {
        let temp_dir = unique_temp_path("tensorlogic_snapshots_test");
        let suite = SnapshotSuite::new("test_suite", temp_dir.clone());

        let expr = create_test_expr();
        let context = create_test_context();

        // Record a snapshot into the suite directory.
        suite
            .record("test1", &expr, &context, TEST_EXPR_STRING)
            .unwrap();

        // Verify against the recorded file with a non-strict (DOT-insensitive) comparison.
        let current = CompilationSnapshot::create(&expr, &context, TEST_EXPR_STRING).unwrap();
        let path = suite.snapshot_path("test1");
        let recorded = CompilationSnapshot::load(&path).unwrap();
        let diff = recorded.compare_with_options(&current, false);

        assert!(diff.is_match());

        // The recorded snapshot must be discoverable by its test name.
        let snapshots = suite.list_snapshots().unwrap();
        assert!(snapshots.contains(&"test1".to_string()));

        let _ = fs::remove_dir_all(&temp_dir);
    }
}