Skip to main content

tensorlogic_compiler/
debug.rs

//! Debug utilities for tracing compilation.
//!
//! This module provides tools for debugging and understanding the compilation process,
//! including intermediate state tracking, step-by-step tracing, detailed text reports,
//! and stdout dumps of compiler-context and graph state.
5
6use std::collections::HashMap;
7use tensorlogic_ir::{EinsumGraph, TLExpr};
8
9use crate::context::CompilerContext;
10
11/// Compilation trace for debugging.
12///
13/// Captures the state of compilation at each major step, allowing
14/// developers to understand how expressions are transformed into tensor graphs.
15#[derive(Debug, Clone)]
16pub struct CompilationTrace {
17    /// Original expression before compilation
18    pub input_expr: String,
19    /// Compilation steps with intermediate states
20    pub steps: Vec<CompilationStep>,
21    /// Final compiled graph (if successful)
22    pub final_graph: Option<String>,
23    /// Errors encountered during compilation
24    pub errors: Vec<String>,
25    /// Total compilation time (if measured)
26    pub duration_ms: Option<f64>,
27}
28
29/// A single step in the compilation process.
30#[derive(Debug, Clone)]
31pub struct CompilationStep {
32    /// Step number (0-indexed)
33    pub step_num: usize,
34    /// Name of this compilation phase
35    pub phase: String,
36    /// Description of what happened
37    pub description: String,
38    /// State snapshot at this step
39    pub state: StepState,
40    /// Duration of this step (if measured)
41    pub duration_us: Option<u64>,
42}
43
/// Counts describing the compiler's context and graph at one recorded step.
///
/// Every field is a plain count or a string map. That makes the snapshot
/// cheap to build with `Default` and lets callers compare snapshots with
/// `==`, e.g. to check that a pass left the graph size unchanged.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StepState {
    /// Number of tensors in the graph.
    pub tensor_count: usize,
    /// Number of nodes in the graph.
    pub node_count: usize,
    /// Number of domains known to the compiler context.
    pub domain_count: usize,
    /// Number of variables bound to a domain.
    pub bound_vars: usize,
    /// Number of variables assigned to a tensor axis.
    pub axis_assignments: usize,
    /// Arbitrary extra key/value annotations attached to this step.
    pub metadata: HashMap<String, String>,
}
60
61impl CompilationTrace {
62    /// Create a new empty trace.
63    pub fn new(input_expr: &TLExpr) -> Self {
64        Self {
65            input_expr: format!("{:?}", input_expr),
66            steps: Vec::new(),
67            final_graph: None,
68            errors: Vec::new(),
69            duration_ms: None,
70        }
71    }
72
73    /// Add a compilation step to the trace.
74    pub fn add_step(
75        &mut self,
76        phase: impl Into<String>,
77        description: impl Into<String>,
78        ctx: &CompilerContext,
79        graph: &EinsumGraph,
80    ) {
81        let state = StepState {
82            tensor_count: graph.tensors.len(),
83            node_count: graph.nodes.len(),
84            domain_count: ctx.domains.len(),
85            bound_vars: ctx.var_to_domain.len(),
86            axis_assignments: ctx.var_to_axis.len(),
87            metadata: HashMap::new(),
88        };
89
90        self.steps.push(CompilationStep {
91            step_num: self.steps.len(),
92            phase: phase.into(),
93            description: description.into(),
94            state,
95            duration_us: None,
96        });
97    }
98
99    /// Add an error to the trace.
100    pub fn add_error(&mut self, error: impl Into<String>) {
101        self.errors.push(error.into());
102    }
103
104    /// Set the final compiled graph.
105    pub fn set_final_graph(&mut self, graph: &EinsumGraph) {
106        self.final_graph = Some(format!("{:?}", graph));
107    }
108
109    /// Set the total compilation duration.
110    pub fn set_duration(&mut self, duration_ms: f64) {
111        self.duration_ms = Some(duration_ms);
112    }
113
114    /// Print a summary of the compilation trace.
115    pub fn print_summary(&self) {
116        println!("=== Compilation Trace Summary ===");
117        println!("Input: {}", truncate(&self.input_expr, 100));
118        println!("Steps: {}", self.steps.len());
119        println!("Errors: {}", self.errors.len());
120
121        if let Some(dur) = self.duration_ms {
122            println!("Duration: {:.3}ms", dur);
123        }
124
125        println!("\n--- Steps ---");
126        for step in &self.steps {
127            println!(
128                "{:2}. {} - {} (T:{}, N:{})",
129                step.step_num,
130                step.phase,
131                step.description,
132                step.state.tensor_count,
133                step.state.node_count
134            );
135        }
136
137        if !self.errors.is_empty() {
138            println!("\n--- Errors ---");
139            for (i, error) in self.errors.iter().enumerate() {
140                println!("{}. {}", i + 1, error);
141            }
142        }
143
144        if let Some(ref graph) = self.final_graph {
145            println!("\n--- Final Graph ---");
146            println!("{}", truncate(graph, 200));
147        }
148
149        println!("================================");
150    }
151
152    /// Generate a detailed report with all intermediate states.
153    pub fn detailed_report(&self) -> String {
154        let mut report = String::new();
155
156        report.push_str("╔════════════════════════════════════════╗\n");
157        report.push_str("║   COMPILATION TRACE - DETAILED REPORT   ║\n");
158        report.push_str("╚════════════════════════════════════════╝\n\n");
159
160        report.push_str(&format!("Input Expression:\n  {}\n\n", self.input_expr));
161
162        if let Some(dur) = self.duration_ms {
163            report.push_str(&format!("Total Duration: {:.3}ms\n\n", dur));
164        }
165
166        report.push_str("Compilation Steps:\n");
167        report.push_str("─────────────────────────────────────────\n\n");
168
169        for step in &self.steps {
170            report.push_str(&format!("Step {}: {}\n", step.step_num, step.phase));
171            report.push_str(&format!("  Description: {}\n", step.description));
172            report.push_str("  State:\n");
173            report.push_str(&format!("    Tensors: {}\n", step.state.tensor_count));
174            report.push_str(&format!("    Nodes: {}\n", step.state.node_count));
175            report.push_str(&format!("    Domains: {}\n", step.state.domain_count));
176            report.push_str(&format!("    Bound Variables: {}\n", step.state.bound_vars));
177            report.push_str(&format!(
178                "    Axis Assignments: {}\n",
179                step.state.axis_assignments
180            ));
181
182            if !step.state.metadata.is_empty() {
183                report.push_str("    Metadata:\n");
184                for (key, value) in &step.state.metadata {
185                    report.push_str(&format!("      {}: {}\n", key, value));
186                }
187            }
188
189            if let Some(dur) = step.duration_us {
190                report.push_str(&format!("  Duration: {}μs\n", dur));
191            }
192
193            report.push('\n');
194        }
195
196        if !self.errors.is_empty() {
197            report.push_str("Errors Encountered:\n");
198            report.push_str("─────────────────────────────────────────\n");
199            for (i, error) in self.errors.iter().enumerate() {
200                report.push_str(&format!("{}. {}\n", i + 1, error));
201            }
202            report.push('\n');
203        }
204
205        if let Some(ref graph) = self.final_graph {
206            report.push_str("Final Graph:\n");
207            report.push_str("─────────────────────────────────────────\n");
208            report.push_str(graph);
209            report.push('\n');
210        }
211
212        report
213    }
214}
215
/// Shorten `s` to at most `max_len` characters, appending `"..."` if anything was cut.
///
/// Lengths are counted in `char`s, not bytes. The cut therefore always lands
/// on a UTF-8 character boundary and never panics on multi-byte text such as
/// non-ASCII identifiers in a `Debug` rendering. A string of exactly
/// `max_len` characters is returned unchanged.
fn truncate(s: &str, max_len: usize) -> String {
    // `nth(max_len)` yields the byte offset of the first char that does NOT fit;
    // `None` means the whole string is within the limit.
    match s.char_indices().nth(max_len) {
        Some((cut, _)) => format!("{}...", &s[..cut]),
        None => s.to_string(),
    }
}
224
225/// Compilation tracer that can be enabled/disabled.
226///
227/// Use this to instrument compilation with tracing:
228///
229/// ```ignore
230/// let mut tracer = CompilationTracer::new(true); // enabled
231/// tracer.start(&expr);
232///
233/// // During compilation:
234/// tracer.record_step("Parse", "Parsed expression", &ctx, &graph);
235/// tracer.record_step("Optimize", "Applied CSE", &ctx, &graph);
236///
237/// let trace = tracer.finish(&graph);
238/// trace.print_summary();
239/// ```
240pub struct CompilationTracer {
241    enabled: bool,
242    trace: Option<CompilationTrace>,
243    start_time: Option<std::time::Instant>,
244}
245
246impl CompilationTracer {
247    /// Create a new tracer.
248    pub fn new(enabled: bool) -> Self {
249        Self {
250            enabled,
251            trace: None,
252            start_time: None,
253        }
254    }
255
256    /// Start tracing for the given expression.
257    pub fn start(&mut self, expr: &TLExpr) {
258        if self.enabled {
259            self.trace = Some(CompilationTrace::new(expr));
260            self.start_time = Some(std::time::Instant::now());
261        }
262    }
263
264    /// Record a compilation step.
265    pub fn record_step(
266        &mut self,
267        phase: impl Into<String>,
268        description: impl Into<String>,
269        ctx: &CompilerContext,
270        graph: &EinsumGraph,
271    ) {
272        if self.enabled {
273            if let Some(ref mut trace) = self.trace {
274                trace.add_step(phase, description, ctx, graph);
275            }
276        }
277    }
278
279    /// Record an error.
280    pub fn record_error(&mut self, error: impl Into<String>) {
281        if self.enabled {
282            if let Some(ref mut trace) = self.trace {
283                trace.add_error(error);
284            }
285        }
286    }
287
288    /// Finish tracing and return the trace.
289    pub fn finish(&mut self, graph: &EinsumGraph) -> Option<CompilationTrace> {
290        if !self.enabled {
291            return None;
292        }
293
294        if let Some(ref mut trace) = self.trace {
295            trace.set_final_graph(graph);
296
297            if let Some(start) = self.start_time {
298                let duration = start.elapsed();
299                trace.set_duration(duration.as_secs_f64() * 1000.0);
300            }
301        }
302
303        self.trace.take()
304    }
305}
306
307/// Print the compiler context state for debugging.
308pub fn print_context_state(ctx: &CompilerContext, label: &str) {
309    println!("\n=== Context State: {} ===", label);
310    println!("Domains: {}", ctx.domains.len());
311    for (name, info) in &ctx.domains {
312        println!("  - {} (cardinality: {})", name, info.cardinality);
313    }
314
315    println!("Var->Domain bindings: {}", ctx.var_to_domain.len());
316    for (var, domain) in &ctx.var_to_domain {
317        println!("  - {} -> {}", var, domain);
318    }
319
320    println!("Var->Axis assignments: {}", ctx.var_to_axis.len());
321    for (var, axis) in &ctx.var_to_axis {
322        println!("  - {} -> axis '{}'", var, axis);
323    }
324
325    println!("Config: {:?}", ctx.config.and_strategy);
326    println!("========================\n");
327}
328
329/// Print the graph state for debugging.
330pub fn print_graph_state(graph: &EinsumGraph, label: &str) {
331    println!("\n=== Graph State: {} ===", label);
332    println!("Tensors: {}", graph.tensors.len());
333    for (i, tensor) in graph.tensors.iter().enumerate() {
334        println!("  [{:3}] {}", i, tensor);
335    }
336
337    println!("Nodes: {}", graph.nodes.len());
338    for (i, node) in graph.nodes.iter().enumerate() {
339        println!("  [{:3}] {:?}", i, node.op);
340        println!(
341            "        inputs: {:?}, outputs: {:?}",
342            node.inputs, node.outputs
343        );
344    }
345
346    println!("Inputs: {:?}", graph.inputs);
347    println!("Outputs: {:?}", graph.outputs);
348    println!("========================\n");
349}
350
351/// Diff two graphs and print the differences.
352pub fn print_graph_diff(before: &EinsumGraph, after: &EinsumGraph, label: &str) {
353    println!("\n=== Graph Diff: {} ===", label);
354
355    let tensor_diff = after.tensors.len() as i32 - before.tensors.len() as i32;
356    let node_diff = after.nodes.len() as i32 - before.nodes.len() as i32;
357
358    println!(
359        "Tensors: {} -> {} ({:+})",
360        before.tensors.len(),
361        after.tensors.len(),
362        tensor_diff
363    );
364    println!(
365        "Nodes: {} -> {} ({:+})",
366        before.nodes.len(),
367        after.nodes.len(),
368        node_diff
369    );
370
371    if tensor_diff > 0 {
372        println!("New tensors:");
373        for tensor in &after.tensors[before.tensors.len()..] {
374            println!("  + {}", tensor);
375        }
376    }
377
378    if node_diff > 0 {
379        println!("New nodes:");
380        for (i, node) in after.nodes[before.nodes.len()..].iter().enumerate() {
381            let idx = before.nodes.len() + i;
382            println!("  + [{:3}] {:?}", idx, node.op);
383        }
384    }
385
386    println!("========================\n");
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CompilerContext;
    use tensorlogic_ir::{EinsumGraph, EinsumNode, TLExpr, Term};

    /// A freshly created trace has no steps, no errors and no final graph.
    #[test]
    fn test_compilation_trace_creation() {
        let source = TLExpr::pred("P", vec![Term::var("x")]);
        let fresh = CompilationTrace::new(&source);

        assert!(fresh.steps.is_empty());
        assert!(fresh.errors.is_empty());
        assert!(fresh.final_graph.is_none());
    }

    /// `add_step` appends exactly one step carrying the given phase and description.
    #[test]
    fn test_add_compilation_step() {
        let source = TLExpr::pred("P", vec![Term::var("x")]);
        let context = CompilerContext::new();
        let empty_graph = EinsumGraph::new();

        let mut recorded = CompilationTrace::new(&source);
        recorded.add_step("Parse", "Parsed expression", &context, &empty_graph);

        assert_eq!(recorded.steps.len(), 1);
        let only_step = &recorded.steps[0];
        assert_eq!(only_step.phase, "Parse");
        assert_eq!(only_step.description, "Parsed expression");
    }

    /// A disabled tracer ignores everything and yields no trace.
    #[test]
    fn test_compilation_tracer_disabled() {
        let source = TLExpr::pred("P", vec![Term::var("x")]);
        let context = CompilerContext::new();
        let empty_graph = EinsumGraph::new();

        let mut tracer = CompilationTracer::new(false);
        tracer.start(&source);
        tracer.record_step("Test", "Description", &context, &empty_graph);

        assert!(tracer.finish(&empty_graph).is_none());
    }

    /// An enabled tracer returns every recorded step plus a measured duration.
    #[test]
    fn test_compilation_tracer_enabled() {
        let source = TLExpr::pred("P", vec![Term::var("x")]);
        let context = CompilerContext::new();
        let empty_graph = EinsumGraph::new();

        let mut tracer = CompilationTracer::new(true);
        tracer.start(&source);
        for &(phase, description) in [("Phase1", "First step"), ("Phase2", "Second step")].iter() {
            tracer.record_step(phase, description, &context, &empty_graph);
        }

        let outcome = tracer.finish(&empty_graph);
        assert!(outcome.is_some());

        let finished = outcome.unwrap();
        assert_eq!(finished.steps.len(), 2);
        assert!(finished.duration_ms.is_some());
    }

    /// Dumping a context with one domain and one binding must not panic.
    #[test]
    fn test_print_context_state() {
        let mut context = CompilerContext::new();
        context.add_domain("D1".to_string(), 10);
        // bind_var requires a domain name, not an axis number
        let _ = context.bind_var("x", "D1");

        print_context_state(&context, "Test");
    }

    /// Dumping a two-tensor, one-node graph must not panic.
    #[test]
    fn test_print_graph_state() {
        let mut g = EinsumGraph::new();
        let src_tensor = g.add_tensor("input".to_string());
        let dst_tensor = g.add_tensor("output".to_string());
        g.add_node(EinsumNode::elem_unary("relu", src_tensor, dst_tensor))
            .unwrap();

        print_graph_state(&g, "Test");
    }

    /// Short strings pass through; long ones are cut and suffixed with "...".
    #[test]
    fn test_truncate() {
        assert_eq!(truncate("hello", 10), "hello");
        assert_eq!(truncate("hello world this is long", 10), "hello worl...");
    }
}