Skip to main content

tensorlogic_compiler/passes/
post_compilation.rs

1//! Post-compilation validation and optimization passes.
2//!
3//! This module provides validation and optimization passes that run after
4//! the initial compilation to ensure correctness and improve performance.
5
6use anyhow::{bail, Result};
7use std::collections::HashSet;
8use tensorlogic_ir::{validate_graph, EinsumGraph, OpType, ValidationReport};
9
10use crate::CompilerContext;
11
12// Import our new optimization passes
13use super::{
14    contraction_opt::{optimize_contractions_with_config, ContractionOptConfig},
15    loop_fusion::{fuse_loops_with_config, LoopFusionConfig},
16};
17
/// Switches controlling which post-compilation passes run.
///
/// The [`Default`] configuration enables every validation and optimization pass
/// and leaves `strict_mode` off.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PostCompilationOptions {
    /// Run structural graph validation.
    pub validate_graph_structure: bool,
    /// Run axis consistency checks.
    pub validate_axes: bool,
    /// Run basic shape (rank) compatibility checks.
    pub validate_shapes: bool,
    /// Master switch for the optimization passes.
    pub apply_optimizations: bool,
    /// Run contraction optimization (only when `apply_optimizations` is set).
    pub enable_contraction_opt: bool,
    /// Run loop fusion (only when `apply_optimizations` is set).
    pub enable_loop_fusion: bool,
    /// Fail validation when warnings are reported.
    pub strict_mode: bool,
}

impl Default for PostCompilationOptions {
    /// Every pass enabled; warnings do not cause failure.
    fn default() -> Self {
        Self {
            validate_graph_structure: true,
            validate_axes: true,
            validate_shapes: true,
            apply_optimizations: true,
            enable_contraction_opt: true,
            enable_loop_fusion: true,
            strict_mode: false,
        }
    }
}
50
51/// Result of post-compilation passes.
52#[derive(Debug, Clone)]
53pub struct PostCompilationResult {
54    /// Validation report
55    pub validation_report: ValidationReport,
56    /// Whether the graph passed all checks
57    pub is_valid: bool,
58    /// Number of optimizations applied
59    pub optimizations_applied: usize,
60    /// Detailed messages
61    pub messages: Vec<String>,
62}
63
64/// Run post-compilation validation and optimization passes.
65///
66/// # Examples
67///
68/// ```
69/// use tensorlogic_compiler::{compile_to_einsum_with_context, CompilerContext};
70/// use tensorlogic_compiler::passes::{
71///     post_compilation_passes, PostCompilationOptions
72/// };
73/// use tensorlogic_ir::{TLExpr, Term};
74///
75/// let mut ctx = CompilerContext::new();
76/// ctx.add_domain("Person", 100);
77///
78/// let expr = TLExpr::exists(
79///     "y",
80///     "Person",
81///     TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
82/// );
83///
84/// let mut graph = compile_to_einsum_with_context(&expr, &mut ctx).unwrap();
85///
86/// let options = PostCompilationOptions::default();
87/// let result = post_compilation_passes(&mut graph, &ctx, options).unwrap();
88///
89/// assert!(result.is_valid);
90/// ```
91pub fn post_compilation_passes(
92    graph: &mut EinsumGraph,
93    ctx: &CompilerContext,
94    options: PostCompilationOptions,
95) -> Result<PostCompilationResult> {
96    let mut messages = Vec::new();
97    let mut optimizations_applied = 0;
98
99    // 1. Validate graph structure
100    let validation_report = if options.validate_graph_structure {
101        let report = validate_graph(graph);
102
103        // For simple predicate expressions, the output might be an input tensor
104        // with no producer node, which is valid. Filter out such errors.
105        let has_simple_passthrough = graph.nodes.is_empty()
106            || (graph.outputs.len() == 1 && graph.inputs.contains(&graph.outputs[0]));
107
108        let filtered_errors: Vec<_> = report
109            .errors
110            .into_iter()
111            .filter(|error| {
112                // Allow "no producer" errors for simple passthrough graphs
113                if has_simple_passthrough && error.message.contains("has no producer") {
114                    return false; // Filter out this error
115                }
116                true // Keep all other errors
117            })
118            .collect();
119
120        for error in &filtered_errors {
121            messages.push(format!("ERROR: {}", error.message));
122        }
123
124        if !report.warnings.is_empty() {
125            for warning in &report.warnings {
126                messages.push(format!("WARNING: {}", warning.message));
127            }
128        }
129
130        ValidationReport {
131            checks_performed: report.checks_performed,
132            errors: filtered_errors,
133            warnings: report.warnings,
134            stats: report.stats,
135        }
136    } else {
137        ValidationReport {
138            checks_performed: 0,
139            errors: vec![],
140            warnings: vec![],
141            stats: Default::default(),
142        }
143    };
144
145    // 2. Validate axis consistency
146    if options.validate_axes {
147        validate_axis_consistency(graph, ctx, &mut messages)?;
148    }
149
150    // 3. Validate shape compatibility (basic checks)
151    if options.validate_shapes {
152        validate_shape_compatibility(graph, ctx, &mut messages)?;
153    }
154
155    // 4. Apply optimization passes
156    if options.apply_optimizations {
157        optimizations_applied += apply_optimization_passes(graph, &options, &mut messages)?;
158    }
159
160    // Check if valid
161    let is_valid = validation_report.is_valid()
162        && (!options.strict_mode || validation_report.warnings.is_empty());
163
164    if !is_valid {
165        bail!(
166            "Post-compilation validation failed:\n{}",
167            messages.join("\n")
168        );
169    }
170
171    Ok(PostCompilationResult {
172        validation_report,
173        is_valid,
174        optimizations_applied,
175        messages,
176    })
177}
178
179/// Validate axis consistency across the graph.
180fn validate_axis_consistency(
181    graph: &EinsumGraph,
182    ctx: &CompilerContext,
183    messages: &mut Vec<String>,
184) -> Result<()> {
185    // Track which axes are used and their expected sizes
186    let mut axis_domains = std::collections::HashMap::new();
187
188    for node in &graph.nodes {
189        if let OpType::Einsum { spec, .. } = &node.op {
190            // Extract axes from einsum spec
191            let axes = extract_axes_from_spec(spec);
192
193            for axis_char in axes {
194                // Check if this axis character is used by a variable
195                for (var, &var_axis_char) in &ctx.var_to_axis {
196                    if var_axis_char == axis_char {
197                        // Get the domain for this variable
198                        if let Some(domain_name) = ctx.var_to_domain.get(var) {
199                            if let Some(domain_info) = ctx.domains.get(domain_name) {
200                                let size = domain_info.cardinality;
201
202                                // Track or validate axis size
203                                if let Some(&existing_size) = axis_domains.get(&axis_char) {
204                                    if existing_size != size {
205                                        messages.push(format!(
206                                            "WARNING: Axis '{}' has inconsistent domain sizes: {} vs {}",
207                                            axis_char, existing_size, size
208                                        ));
209                                    }
210                                } else {
211                                    axis_domains.insert(axis_char, size);
212                                }
213                            }
214                        }
215                        break;
216                    }
217                }
218            }
219        }
220    }
221
222    Ok(())
223}
224
/// Extract axis labels from an einsum specification.
///
/// Only the input side of an explicit `"inputs->output"` spec is inspected, and
/// only ASCII lowercase letters count as labels. The result is sorted and
/// duplicate-free. A spec without `->` yields no labels.
///
/// For example, `"ij,jk->ik"` gives `['i', 'j', 'k']`.
fn extract_axes_from_spec(spec: &str) -> Vec<char> {
    let input_side = match spec.split_once("->") {
        Some((lhs, _rhs)) => lhs,
        None => return Vec::new(),
    };

    // Commas separating operands are not lowercase, so they are filtered out here.
    let mut labels: Vec<char> = input_side
        .chars()
        .filter(char::is_ascii_lowercase)
        .collect();
    labels.sort_unstable();
    labels.dedup();
    labels
}
244
245/// Validate basic shape compatibility.
246fn validate_shape_compatibility(
247    graph: &EinsumGraph,
248    _ctx: &CompilerContext,
249    messages: &mut Vec<String>,
250) -> Result<()> {
251    // Track tensor shapes (if known)
252    let mut tensor_ranks = std::collections::HashMap::new();
253
254    for node in &graph.nodes {
255        match &node.op {
256            OpType::Einsum { spec } => {
257                // Parse spec to determine output rank
258                if let Some((_inputs, output)) = spec.split_once("->") {
259                    let output_rank = output.chars().filter(|c| c.is_alphabetic()).count();
260                    if let Some(&output_idx) = node.outputs.first() {
261                        tensor_ranks.insert(output_idx, output_rank);
262                    }
263                }
264            }
265            OpType::ElemUnary { .. } => {
266                // Unary ops preserve rank
267                if let Some(&input_idx) = node.inputs.first() {
268                    if let Some(&rank) = tensor_ranks.get(&input_idx) {
269                        if let Some(&output_idx) = node.outputs.first() {
270                            tensor_ranks.insert(output_idx, rank);
271                        }
272                    }
273                }
274            }
275            OpType::ElemBinary { .. } => {
276                // Binary ops require compatible ranks
277                if node.inputs.len() >= 2 {
278                    let left_rank = tensor_ranks.get(&node.inputs[0]);
279                    let right_rank = tensor_ranks.get(&node.inputs[1]);
280
281                    if let (Some(&l), Some(&r)) = (left_rank, right_rank) {
282                        if l != r && l != 0 && r != 0 {
283                            messages.push(format!(
284                                "WARNING: Element-wise binary op has mismatched ranks: {} vs {}",
285                                l, r
286                            ));
287                        }
288                        if let Some(&output_idx) = node.outputs.first() {
289                            tensor_ranks.insert(output_idx, l.max(r));
290                        }
291                    }
292                }
293            }
294            OpType::Reduce { .. } => {
295                // Reduce decreases rank by 1
296                if let Some(&input_idx) = node.inputs.first() {
297                    if let Some(&rank) = tensor_ranks.get(&input_idx) {
298                        if let Some(&output_idx) = node.outputs.first() {
299                            tensor_ranks.insert(output_idx, rank.saturating_sub(1));
300                        }
301                    }
302                }
303            }
304        }
305    }
306
307    Ok(())
308}
309
310/// Apply optimization passes to the graph.
311fn apply_optimization_passes(
312    graph: &mut EinsumGraph,
313    options: &PostCompilationOptions,
314    messages: &mut Vec<String>,
315) -> Result<usize> {
316    let mut total_optimizations = 0;
317
318    // 1. Contraction optimization
319    if options.enable_contraction_opt {
320        let config = ContractionOptConfig::default();
321        let (optimized_graph, stats) = optimize_contractions_with_config(graph, &config);
322        *graph = optimized_graph;
323
324        if stats.contractions_reordered > 0 {
325            messages.push(format!(
326                "Contraction optimization: {} contractions reordered, {:.1}% FLOPs reduction",
327                stats.contractions_reordered, stats.flops_reduction_percent
328            ));
329            total_optimizations += stats.total_optimizations();
330        }
331    }
332
333    // 2. Loop fusion
334    if options.enable_loop_fusion {
335        let config = LoopFusionConfig::default();
336        let (optimized_graph, stats) = fuse_loops_with_config(graph, &config);
337        *graph = optimized_graph;
338
339        if stats.loops_fused > 0 {
340            messages.push(format!(
341                "Loop fusion: {} loops fused, {} intermediates eliminated",
342                stats.loops_fused, stats.intermediates_eliminated
343            ));
344            total_optimizations += stats.total_optimizations();
345        }
346    }
347
348    if total_optimizations == 0 {
349        messages.push("No graph optimizations applied (graph may already be optimal)".to_string());
350    }
351
352    Ok(total_optimizations)
353}
354
355/// Quick validation check (used internally).
356pub fn quick_validate(graph: &EinsumGraph) -> Result<()> {
357    // Check for cycles
358    if has_cycle(graph) {
359        bail!("Graph contains cycles");
360    }
361
362    // Check that all tensor references are valid
363    for node in &graph.nodes {
364        for &input_idx in &node.inputs {
365            if input_idx >= graph.tensors.len() {
366                bail!(
367                    "Invalid tensor reference: {} (graph has {} tensors)",
368                    input_idx,
369                    graph.tensors.len()
370                );
371            }
372        }
373    }
374
375    // Check that outputs are valid
376    for &output_idx in &graph.outputs {
377        if output_idx >= graph.tensors.len() {
378            bail!(
379                "Invalid output reference: {} (graph has {} tensors)",
380                output_idx,
381                graph.tensors.len()
382            );
383        }
384    }
385
386    Ok(())
387}
388
389/// Check if graph contains cycles (basic DFS).
390fn has_cycle(graph: &EinsumGraph) -> bool {
391    let mut visited = HashSet::new();
392    let mut rec_stack = HashSet::new();
393
394    for node in &graph.nodes {
395        for &output_idx in &node.outputs {
396            if !visited.contains(&output_idx)
397                && has_cycle_util(graph, output_idx, &mut visited, &mut rec_stack)
398            {
399                return true;
400            }
401        }
402    }
403
404    false
405}
406
407fn has_cycle_util(
408    graph: &EinsumGraph,
409    tensor_idx: usize,
410    visited: &mut HashSet<usize>,
411    rec_stack: &mut HashSet<usize>,
412) -> bool {
413    visited.insert(tensor_idx);
414    rec_stack.insert(tensor_idx);
415
416    // Find nodes that produce this tensor
417    for node in &graph.nodes {
418        if node.outputs.contains(&tensor_idx) {
419            for &input_idx in &node.inputs {
420                if !visited.contains(&input_idx) {
421                    if has_cycle_util(graph, input_idx, visited, rec_stack) {
422                        return true;
423                    }
424                } else if rec_stack.contains(&input_idx) {
425                    return true;
426                }
427            }
428        }
429    }
430
431    rec_stack.remove(&tensor_idx);
432    false
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{compile_to_einsum_with_context, CompilerContext};
    use tensorlogic_ir::{TLExpr, Term};

    /// Context holding a single 100-element `Person` domain.
    fn person_ctx() -> CompilerContext {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        ctx
    }

    /// `knows(x, y)`: a binary predicate over two free variables.
    fn knows_xy() -> TLExpr {
        TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")])
    }

    /// Compile `expr` in `ctx`, then run the post-compilation passes with `options`.
    fn compile_and_check(
        expr: &TLExpr,
        ctx: &mut CompilerContext,
        options: PostCompilationOptions,
    ) -> Result<PostCompilationResult> {
        let mut graph = compile_to_einsum_with_context(expr, ctx).unwrap();
        post_compilation_passes(&mut graph, ctx, options)
    }

    #[test]
    fn test_post_compilation_simple() {
        let mut ctx = person_ctx();
        let outcome =
            compile_and_check(&knows_xy(), &mut ctx, PostCompilationOptions::default()).unwrap();
        assert!(outcome.is_valid);
    }

    #[test]
    fn test_post_compilation_with_quantifier() {
        let mut ctx = person_ctx();
        let expr = TLExpr::exists("y", "Person", knows_xy());
        let outcome =
            compile_and_check(&expr, &mut ctx, PostCompilationOptions::default()).unwrap();
        assert!(outcome.is_valid);
    }

    #[test]
    fn test_quick_validate_success() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D", 10);

        let expr = TLExpr::pred("p", vec![Term::var("x")]);
        let graph = compile_to_einsum_with_context(&expr, &mut ctx).unwrap();

        assert!(quick_validate(&graph).is_ok());
    }

    #[test]
    fn test_extract_axes_from_spec() {
        assert_eq!(extract_axes_from_spec("ab,bc->ac"), vec!['a', 'b', 'c']);
        assert_eq!(extract_axes_from_spec("ij->i"), vec!['i', 'j']);
    }

    #[test]
    fn test_post_compilation_optimizations() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D", 10);

        // Conjunction of unrelated predicates; whether anything is optimizable
        // depends on the compiled graph, so only validity is asserted.
        let expr = TLExpr::And(
            Box::new(TLExpr::pred("p", vec![Term::var("x")])),
            Box::new(TLExpr::pred("q", vec![Term::var("y")])),
        );
        let options = PostCompilationOptions {
            apply_optimizations: true,
            ..Default::default()
        };

        let outcome = compile_and_check(&expr, &mut ctx, options).unwrap();
        assert!(outcome.is_valid);
    }

    #[test]
    fn test_post_compilation_strict_mode() {
        let mut ctx = person_ctx();
        let options = PostCompilationOptions {
            strict_mode: true,
            ..Default::default()
        };

        // A plain predicate should produce no warnings, so strict mode passes.
        assert!(compile_and_check(&knows_xy(), &mut ctx, options).is_ok());
    }
}