1use std::collections::HashMap;
7use tensorlogic_ir::{EinsumGraph, TLExpr};
8
9use crate::context::CompilerContext;
10
11#[derive(Debug, Clone)]
16pub struct CompilationTrace {
17 pub input_expr: String,
19 pub steps: Vec<CompilationStep>,
21 pub final_graph: Option<String>,
23 pub errors: Vec<String>,
25 pub duration_ms: Option<f64>,
27}
28
29#[derive(Debug, Clone)]
31pub struct CompilationStep {
32 pub step_num: usize,
34 pub phase: String,
36 pub description: String,
38 pub state: StepState,
40 pub duration_us: Option<u64>,
42}
43
/// Size snapshot of the compiler context and the graph, taken when a step is recorded.
///
/// Every field supports `Default`, `PartialEq` and `Eq`, so those are derived too.
/// Snapshots can then be compared directly, e.g. to detect a pass that changed
/// nothing, and built with `..Default::default()`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct StepState {
    /// Number of tensors in the graph.
    pub tensor_count: usize,
    /// Number of nodes in the graph.
    pub node_count: usize,
    /// Number of domains registered in the compiler context.
    pub domain_count: usize,
    /// Number of variable -> domain bindings in the compiler context.
    pub bound_vars: usize,
    /// Number of variable -> axis assignments in the compiler context.
    pub axis_assignments: usize,
    /// Free-form key/value annotations attached to the step.
    pub metadata: HashMap<String, String>,
}
60
61impl CompilationTrace {
62 pub fn new(input_expr: &TLExpr) -> Self {
64 Self {
65 input_expr: format!("{:?}", input_expr),
66 steps: Vec::new(),
67 final_graph: None,
68 errors: Vec::new(),
69 duration_ms: None,
70 }
71 }
72
73 pub fn add_step(
75 &mut self,
76 phase: impl Into<String>,
77 description: impl Into<String>,
78 ctx: &CompilerContext,
79 graph: &EinsumGraph,
80 ) {
81 let state = StepState {
82 tensor_count: graph.tensors.len(),
83 node_count: graph.nodes.len(),
84 domain_count: ctx.domains.len(),
85 bound_vars: ctx.var_to_domain.len(),
86 axis_assignments: ctx.var_to_axis.len(),
87 metadata: HashMap::new(),
88 };
89
90 self.steps.push(CompilationStep {
91 step_num: self.steps.len(),
92 phase: phase.into(),
93 description: description.into(),
94 state,
95 duration_us: None,
96 });
97 }
98
99 pub fn add_error(&mut self, error: impl Into<String>) {
101 self.errors.push(error.into());
102 }
103
104 pub fn set_final_graph(&mut self, graph: &EinsumGraph) {
106 self.final_graph = Some(format!("{:?}", graph));
107 }
108
109 pub fn set_duration(&mut self, duration_ms: f64) {
111 self.duration_ms = Some(duration_ms);
112 }
113
114 pub fn print_summary(&self) {
116 println!("=== Compilation Trace Summary ===");
117 println!("Input: {}", truncate(&self.input_expr, 100));
118 println!("Steps: {}", self.steps.len());
119 println!("Errors: {}", self.errors.len());
120
121 if let Some(dur) = self.duration_ms {
122 println!("Duration: {:.3}ms", dur);
123 }
124
125 println!("\n--- Steps ---");
126 for step in &self.steps {
127 println!(
128 "{:2}. {} - {} (T:{}, N:{})",
129 step.step_num,
130 step.phase,
131 step.description,
132 step.state.tensor_count,
133 step.state.node_count
134 );
135 }
136
137 if !self.errors.is_empty() {
138 println!("\n--- Errors ---");
139 for (i, error) in self.errors.iter().enumerate() {
140 println!("{}. {}", i + 1, error);
141 }
142 }
143
144 if let Some(ref graph) = self.final_graph {
145 println!("\n--- Final Graph ---");
146 println!("{}", truncate(graph, 200));
147 }
148
149 println!("================================");
150 }
151
152 pub fn detailed_report(&self) -> String {
154 let mut report = String::new();
155
156 report.push_str("╔════════════════════════════════════════╗\n");
157 report.push_str("║ COMPILATION TRACE - DETAILED REPORT ║\n");
158 report.push_str("╚════════════════════════════════════════╝\n\n");
159
160 report.push_str(&format!("Input Expression:\n {}\n\n", self.input_expr));
161
162 if let Some(dur) = self.duration_ms {
163 report.push_str(&format!("Total Duration: {:.3}ms\n\n", dur));
164 }
165
166 report.push_str("Compilation Steps:\n");
167 report.push_str("─────────────────────────────────────────\n\n");
168
169 for step in &self.steps {
170 report.push_str(&format!("Step {}: {}\n", step.step_num, step.phase));
171 report.push_str(&format!(" Description: {}\n", step.description));
172 report.push_str(" State:\n");
173 report.push_str(&format!(" Tensors: {}\n", step.state.tensor_count));
174 report.push_str(&format!(" Nodes: {}\n", step.state.node_count));
175 report.push_str(&format!(" Domains: {}\n", step.state.domain_count));
176 report.push_str(&format!(" Bound Variables: {}\n", step.state.bound_vars));
177 report.push_str(&format!(
178 " Axis Assignments: {}\n",
179 step.state.axis_assignments
180 ));
181
182 if !step.state.metadata.is_empty() {
183 report.push_str(" Metadata:\n");
184 for (key, value) in &step.state.metadata {
185 report.push_str(&format!(" {}: {}\n", key, value));
186 }
187 }
188
189 if let Some(dur) = step.duration_us {
190 report.push_str(&format!(" Duration: {}μs\n", dur));
191 }
192
193 report.push('\n');
194 }
195
196 if !self.errors.is_empty() {
197 report.push_str("Errors Encountered:\n");
198 report.push_str("─────────────────────────────────────────\n");
199 for (i, error) in self.errors.iter().enumerate() {
200 report.push_str(&format!("{}. {}\n", i + 1, error));
201 }
202 report.push('\n');
203 }
204
205 if let Some(ref graph) = self.final_graph {
206 report.push_str("Final Graph:\n");
207 report.push_str("─────────────────────────────────────────\n");
208 report.push_str(graph);
209 report.push('\n');
210 }
211
212 report
213 }
214}
215
/// Shortens `s` to at most `max_len` characters, appending `"..."` when anything was cut.
///
/// Length is counted in `char`s, not bytes, so the cut always lands on a UTF-8
/// boundary. Byte slicing (`&s[..max_len]`) panics when `max_len` falls inside
/// a multi-byte character, e.g. a non-ASCII name in a `Debug`-rendered expression.
fn truncate(s: &str, max_len: usize) -> String {
    match s.char_indices().nth(max_len) {
        // At most `max_len` characters: the string already fits.
        None => s.to_string(),
        // `cut` is the byte offset of the first character that does not fit.
        Some((cut, _)) => format!("{}...", &s[..cut]),
    }
}
224
225pub struct CompilationTracer {
241 enabled: bool,
242 trace: Option<CompilationTrace>,
243 start_time: Option<std::time::Instant>,
244}
245
246impl CompilationTracer {
247 pub fn new(enabled: bool) -> Self {
249 Self {
250 enabled,
251 trace: None,
252 start_time: None,
253 }
254 }
255
256 pub fn start(&mut self, expr: &TLExpr) {
258 if self.enabled {
259 self.trace = Some(CompilationTrace::new(expr));
260 self.start_time = Some(std::time::Instant::now());
261 }
262 }
263
264 pub fn record_step(
266 &mut self,
267 phase: impl Into<String>,
268 description: impl Into<String>,
269 ctx: &CompilerContext,
270 graph: &EinsumGraph,
271 ) {
272 if self.enabled {
273 if let Some(ref mut trace) = self.trace {
274 trace.add_step(phase, description, ctx, graph);
275 }
276 }
277 }
278
279 pub fn record_error(&mut self, error: impl Into<String>) {
281 if self.enabled {
282 if let Some(ref mut trace) = self.trace {
283 trace.add_error(error);
284 }
285 }
286 }
287
288 pub fn finish(&mut self, graph: &EinsumGraph) -> Option<CompilationTrace> {
290 if !self.enabled {
291 return None;
292 }
293
294 if let Some(ref mut trace) = self.trace {
295 trace.set_final_graph(graph);
296
297 if let Some(start) = self.start_time {
298 let duration = start.elapsed();
299 trace.set_duration(duration.as_secs_f64() * 1000.0);
300 }
301 }
302
303 self.trace.take()
304 }
305}
306
307pub fn print_context_state(ctx: &CompilerContext, label: &str) {
309 println!("\n=== Context State: {} ===", label);
310 println!("Domains: {}", ctx.domains.len());
311 for (name, info) in &ctx.domains {
312 println!(" - {} (cardinality: {})", name, info.cardinality);
313 }
314
315 println!("Var->Domain bindings: {}", ctx.var_to_domain.len());
316 for (var, domain) in &ctx.var_to_domain {
317 println!(" - {} -> {}", var, domain);
318 }
319
320 println!("Var->Axis assignments: {}", ctx.var_to_axis.len());
321 for (var, axis) in &ctx.var_to_axis {
322 println!(" - {} -> axis '{}'", var, axis);
323 }
324
325 println!("Config: {:?}", ctx.config.and_strategy);
326 println!("========================\n");
327}
328
329pub fn print_graph_state(graph: &EinsumGraph, label: &str) {
331 println!("\n=== Graph State: {} ===", label);
332 println!("Tensors: {}", graph.tensors.len());
333 for (i, tensor) in graph.tensors.iter().enumerate() {
334 println!(" [{:3}] {}", i, tensor);
335 }
336
337 println!("Nodes: {}", graph.nodes.len());
338 for (i, node) in graph.nodes.iter().enumerate() {
339 println!(" [{:3}] {:?}", i, node.op);
340 println!(
341 " inputs: {:?}, outputs: {:?}",
342 node.inputs, node.outputs
343 );
344 }
345
346 println!("Inputs: {:?}", graph.inputs);
347 println!("Outputs: {:?}", graph.outputs);
348 println!("========================\n");
349}
350
351pub fn print_graph_diff(before: &EinsumGraph, after: &EinsumGraph, label: &str) {
353 println!("\n=== Graph Diff: {} ===", label);
354
355 let tensor_diff = after.tensors.len() as i32 - before.tensors.len() as i32;
356 let node_diff = after.nodes.len() as i32 - before.nodes.len() as i32;
357
358 println!(
359 "Tensors: {} -> {} ({:+})",
360 before.tensors.len(),
361 after.tensors.len(),
362 tensor_diff
363 );
364 println!(
365 "Nodes: {} -> {} ({:+})",
366 before.nodes.len(),
367 after.nodes.len(),
368 node_diff
369 );
370
371 if tensor_diff > 0 {
372 println!("New tensors:");
373 for tensor in &after.tensors[before.tensors.len()..] {
374 println!(" + {}", tensor);
375 }
376 }
377
378 if node_diff > 0 {
379 println!("New nodes:");
380 for (i, node) in after.nodes[before.nodes.len()..].iter().enumerate() {
381 let idx = before.nodes.len() + i;
382 println!(" + [{:3}] {:?}", idx, node.op);
383 }
384 }
385
386 println!("========================\n");
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CompilerContext;
    use tensorlogic_ir::{EinsumGraph, EinsumNode, TLExpr, Term};

    #[test]
    fn test_compilation_trace_creation() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        let trace = CompilationTrace::new(&expr);
        assert_eq!(trace.steps.len(), 0);
        assert_eq!(trace.errors.len(), 0);
        assert!(trace.final_graph.is_none());
    }

    #[test]
    fn test_add_compilation_step() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        let mut trace = CompilationTrace::new(&expr);
        let ctx = CompilerContext::new();
        let graph = EinsumGraph::new();

        trace.add_step("Parse", "Parsed expression", &ctx, &graph);

        assert_eq!(trace.steps.len(), 1);
        assert_eq!(trace.steps[0].phase, "Parse");
        assert_eq!(trace.steps[0].description, "Parsed expression");
    }

    #[test]
    fn test_compilation_tracer_disabled() {
        let mut tracer = CompilationTracer::new(false);

        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        tracer.start(&expr);

        let ctx = CompilerContext::new();
        let graph = EinsumGraph::new();

        tracer.record_step("Test", "Description", &ctx, &graph);

        let result = tracer.finish(&graph);
        assert!(result.is_none());
    }

    #[test]
    fn test_compilation_tracer_enabled() {
        let mut tracer = CompilationTracer::new(true);

        let expr = TLExpr::pred("P", vec![Term::var("x")]);

        tracer.start(&expr);

        let ctx = CompilerContext::new();
        let graph = EinsumGraph::new();

        tracer.record_step("Phase1", "First step", &ctx, &graph);
        tracer.record_step("Phase2", "Second step", &ctx, &graph);

        let result = tracer.finish(&graph);
        assert!(result.is_some());

        let trace = result.unwrap();
        assert_eq!(trace.steps.len(), 2);
        assert!(trace.duration_ms.is_some());
    }

    /// An enabled tracer that was never started has nothing to hand back.
    #[test]
    fn test_compilation_tracer_finish_without_start() {
        let mut tracer = CompilationTracer::new(true);
        let graph = EinsumGraph::new();

        assert!(tracer.finish(&graph).is_none());
    }

    /// Errors recorded through the tracer end up in the finished trace.
    #[test]
    fn test_compilation_tracer_records_errors() {
        let mut tracer = CompilationTracer::new(true);
        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        let graph = EinsumGraph::new();

        tracer.start(&expr);
        tracer.record_error("boom");

        let trace = tracer
            .finish(&graph)
            .expect("an enabled, started tracer returns a trace");
        assert_eq!(trace.errors, vec!["boom".to_string()]);
        assert!(trace.final_graph.is_some());
    }

    /// The detailed report includes steps, duration, errors and the final graph.
    #[test]
    fn test_detailed_report_contents() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        let mut trace = CompilationTrace::new(&expr);
        let ctx = CompilerContext::new();
        let graph = EinsumGraph::new();

        trace.add_step("Parse", "Parsed expression", &ctx, &graph);
        trace.add_error("something went wrong");
        trace.set_final_graph(&graph);
        trace.set_duration(1.5);

        let report = trace.detailed_report();
        assert!(report.contains("Step 0: Parse"));
        assert!(report.contains("Parsed expression"));
        assert!(report.contains("Total Duration: 1.500ms"));
        assert!(report.contains("Errors Encountered:"));
        assert!(report.contains("1. something went wrong"));
        assert!(report.contains("Final Graph:"));
    }

    #[test]
    fn test_print_context_state() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D1".to_string(), 10);
        let _ = ctx.bind_var("x", "D1");

        print_context_state(&ctx, "Test");
    }

    #[test]
    fn test_print_graph_state() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input".to_string());
        let t1 = graph.add_tensor("output".to_string());

        graph
            .add_node(EinsumNode::elem_unary("relu", t0, t1))
            .unwrap();

        print_graph_state(&graph, "Test");
    }

    /// Smoke test: diffing in both directions (growth and shrink) must not panic.
    #[test]
    fn test_print_graph_diff() {
        let before = EinsumGraph::new();
        let mut after = EinsumGraph::new();
        let t0 = after.add_tensor("input".to_string());
        let t1 = after.add_tensor("output".to_string());

        after
            .add_node(EinsumNode::elem_unary("relu", t0, t1))
            .unwrap();

        print_graph_diff(&before, &after, "Grow");
        print_graph_diff(&after, &before, "Shrink");
    }

    #[test]
    fn test_truncate() {
        assert_eq!(truncate("hello", 10), "hello");
        assert_eq!(truncate("hello world this is long", 10), "hello worl...");
    }
}