tensorlogic_compiler/passes/
post_compilation.rs

1//! Post-compilation validation and optimization passes.
2//!
3//! This module provides validation and optimization passes that run after
4//! the initial compilation to ensure correctness and improve performance.
5
6use anyhow::{bail, Result};
7use std::collections::HashSet;
8use tensorlogic_ir::{validate_graph, EinsumGraph, OpType, ValidationReport};
9
10use crate::CompilerContext;
11
/// Post-compilation validation options.
///
/// Each flag gates one stage of `post_compilation_passes`. The [`Default`]
/// configuration enables every pass and leaves strict mode off.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PostCompilationOptions {
    /// Run the structural validation provided by `tensorlogic_ir::validate_graph`.
    pub validate_graph_structure: bool,
    /// Check that einsum axis labels resolve to consistently sized domains.
    pub validate_axes: bool,
    /// Run basic rank-compatibility checks over the graph's nodes.
    pub validate_shapes: bool,
    /// Run graph optimization passes.
    pub apply_optimizations: bool,
    /// Treat reported warnings as validation failures.
    pub strict_mode: bool,
}

impl Default for PostCompilationOptions {
    /// All validation and optimization passes enabled; strict mode disabled.
    fn default() -> Self {
        Self {
            validate_graph_structure: true,
            validate_axes: true,
            validate_shapes: true,
            apply_optimizations: true,
            strict_mode: false,
        }
    }
}
38
39/// Result of post-compilation passes.
40#[derive(Debug, Clone)]
41pub struct PostCompilationResult {
42    /// Validation report
43    pub validation_report: ValidationReport,
44    /// Whether the graph passed all checks
45    pub is_valid: bool,
46    /// Number of optimizations applied
47    pub optimizations_applied: usize,
48    /// Detailed messages
49    pub messages: Vec<String>,
50}
51
52/// Run post-compilation validation and optimization passes.
53///
54/// # Examples
55///
56/// ```
57/// use tensorlogic_compiler::{compile_to_einsum_with_context, CompilerContext};
58/// use tensorlogic_compiler::passes::{
59///     post_compilation_passes, PostCompilationOptions
60/// };
61/// use tensorlogic_ir::{TLExpr, Term};
62///
63/// let mut ctx = CompilerContext::new();
64/// ctx.add_domain("Person", 100);
65///
66/// let expr = TLExpr::exists(
67///     "y",
68///     "Person",
69///     TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
70/// );
71///
72/// let mut graph = compile_to_einsum_with_context(&expr, &mut ctx).unwrap();
73///
74/// let options = PostCompilationOptions::default();
75/// let result = post_compilation_passes(&mut graph, &ctx, options).unwrap();
76///
77/// assert!(result.is_valid);
78/// ```
79pub fn post_compilation_passes(
80    graph: &mut EinsumGraph,
81    ctx: &CompilerContext,
82    options: PostCompilationOptions,
83) -> Result<PostCompilationResult> {
84    let mut messages = Vec::new();
85    let mut optimizations_applied = 0;
86
87    // 1. Validate graph structure
88    let validation_report = if options.validate_graph_structure {
89        let report = validate_graph(graph);
90
91        // For simple predicate expressions, the output might be an input tensor
92        // with no producer node, which is valid. Filter out such errors.
93        let has_simple_passthrough = graph.nodes.is_empty()
94            || (graph.outputs.len() == 1 && graph.inputs.contains(&graph.outputs[0]));
95
96        let filtered_errors: Vec<_> = report
97            .errors
98            .into_iter()
99            .filter(|error| {
100                // Allow "no producer" errors for simple passthrough graphs
101                if has_simple_passthrough && error.message.contains("has no producer") {
102                    return false; // Filter out this error
103                }
104                true // Keep all other errors
105            })
106            .collect();
107
108        for error in &filtered_errors {
109            messages.push(format!("ERROR: {}", error.message));
110        }
111
112        if !report.warnings.is_empty() {
113            for warning in &report.warnings {
114                messages.push(format!("WARNING: {}", warning.message));
115            }
116        }
117
118        ValidationReport {
119            checks_performed: report.checks_performed,
120            errors: filtered_errors,
121            warnings: report.warnings,
122            stats: report.stats,
123        }
124    } else {
125        ValidationReport {
126            checks_performed: 0,
127            errors: vec![],
128            warnings: vec![],
129            stats: Default::default(),
130        }
131    };
132
133    // 2. Validate axis consistency
134    if options.validate_axes {
135        validate_axis_consistency(graph, ctx, &mut messages)?;
136    }
137
138    // 3. Validate shape compatibility (basic checks)
139    if options.validate_shapes {
140        validate_shape_compatibility(graph, ctx, &mut messages)?;
141    }
142
143    // 4. Apply optimization passes
144    if options.apply_optimizations {
145        optimizations_applied += apply_optimization_passes(graph, &mut messages)?;
146    }
147
148    // Check if valid
149    let is_valid = validation_report.is_valid()
150        && (!options.strict_mode || validation_report.warnings.is_empty());
151
152    if !is_valid {
153        bail!(
154            "Post-compilation validation failed:\n{}",
155            messages.join("\n")
156        );
157    }
158
159    Ok(PostCompilationResult {
160        validation_report,
161        is_valid,
162        optimizations_applied,
163        messages,
164    })
165}
166
167/// Validate axis consistency across the graph.
168fn validate_axis_consistency(
169    graph: &EinsumGraph,
170    ctx: &CompilerContext,
171    messages: &mut Vec<String>,
172) -> Result<()> {
173    // Track which axes are used and their expected sizes
174    let mut axis_domains = std::collections::HashMap::new();
175
176    for node in &graph.nodes {
177        if let OpType::Einsum { spec, .. } = &node.op {
178            // Extract axes from einsum spec
179            let axes = extract_axes_from_spec(spec);
180
181            for axis_char in axes {
182                // Check if this axis character is used by a variable
183                for (var, &var_axis_char) in &ctx.var_to_axis {
184                    if var_axis_char == axis_char {
185                        // Get the domain for this variable
186                        if let Some(domain_name) = ctx.var_to_domain.get(var) {
187                            if let Some(domain_info) = ctx.domains.get(domain_name) {
188                                let size = domain_info.cardinality;
189
190                                // Track or validate axis size
191                                if let Some(&existing_size) = axis_domains.get(&axis_char) {
192                                    if existing_size != size {
193                                        messages.push(format!(
194                                            "WARNING: Axis '{}' has inconsistent domain sizes: {} vs {}",
195                                            axis_char, existing_size, size
196                                        ));
197                                    }
198                                } else {
199                                    axis_domains.insert(axis_char, size);
200                                }
201                            }
202                        }
203                        break;
204                    }
205                }
206            }
207        }
208    }
209
210    Ok(())
211}
212
/// Extract axis labels from einsum specification.
///
/// Only the input side of an explicit `inputs->output` spec is scanned; for
/// example `"ij,jk->ik"` yields `['i', 'j', 'k']`. Labels are ASCII lowercase
/// letters; commas and any other characters are skipped. The returned labels
/// are sorted and unique. A spec without `->` yields an empty vector.
fn extract_axes_from_spec(spec: &str) -> Vec<char> {
    match spec.split_once("->") {
        Some((lhs, _rhs)) => lhs
            .chars()
            .filter(char::is_ascii_lowercase)
            .collect::<std::collections::BTreeSet<char>>()
            .into_iter()
            .collect(),
        None => Vec::new(),
    }
}
232
233/// Validate basic shape compatibility.
234fn validate_shape_compatibility(
235    graph: &EinsumGraph,
236    _ctx: &CompilerContext,
237    messages: &mut Vec<String>,
238) -> Result<()> {
239    // Track tensor shapes (if known)
240    let mut tensor_ranks = std::collections::HashMap::new();
241
242    for node in &graph.nodes {
243        match &node.op {
244            OpType::Einsum { spec } => {
245                // Parse spec to determine output rank
246                if let Some((_inputs, output)) = spec.split_once("->") {
247                    let output_rank = output.chars().filter(|c| c.is_alphabetic()).count();
248                    if let Some(&output_idx) = node.outputs.first() {
249                        tensor_ranks.insert(output_idx, output_rank);
250                    }
251                }
252            }
253            OpType::ElemUnary { .. } => {
254                // Unary ops preserve rank
255                if let Some(&input_idx) = node.inputs.first() {
256                    if let Some(&rank) = tensor_ranks.get(&input_idx) {
257                        if let Some(&output_idx) = node.outputs.first() {
258                            tensor_ranks.insert(output_idx, rank);
259                        }
260                    }
261                }
262            }
263            OpType::ElemBinary { .. } => {
264                // Binary ops require compatible ranks
265                if node.inputs.len() >= 2 {
266                    let left_rank = tensor_ranks.get(&node.inputs[0]);
267                    let right_rank = tensor_ranks.get(&node.inputs[1]);
268
269                    if let (Some(&l), Some(&r)) = (left_rank, right_rank) {
270                        if l != r && l != 0 && r != 0 {
271                            messages.push(format!(
272                                "WARNING: Element-wise binary op has mismatched ranks: {} vs {}",
273                                l, r
274                            ));
275                        }
276                        if let Some(&output_idx) = node.outputs.first() {
277                            tensor_ranks.insert(output_idx, l.max(r));
278                        }
279                    }
280                }
281            }
282            OpType::Reduce { .. } => {
283                // Reduce decreases rank by 1
284                if let Some(&input_idx) = node.inputs.first() {
285                    if let Some(&rank) = tensor_ranks.get(&input_idx) {
286                        if let Some(&output_idx) = node.outputs.first() {
287                            tensor_ranks.insert(output_idx, rank.saturating_sub(1));
288                        }
289                    }
290                }
291            }
292        }
293    }
294
295    Ok(())
296}
297
298/// Apply optimization passes to the graph.
299fn apply_optimization_passes(
300    _graph: &mut EinsumGraph,
301    messages: &mut Vec<String>,
302) -> Result<usize> {
303    // Note: Graph optimization methods (eliminate_dead_code, eliminate_common_subexpressions,
304    // simplify_identities) are not yet available in the current tensorlogic-ir API.
305    // These optimizations can be added when the IR supports them.
306
307    messages.push("Graph optimizations: currently disabled (awaiting IR API support)".to_string());
308
309    Ok(0)
310}
311
312/// Quick validation check (used internally).
313pub fn quick_validate(graph: &EinsumGraph) -> Result<()> {
314    // Check for cycles
315    if has_cycle(graph) {
316        bail!("Graph contains cycles");
317    }
318
319    // Check that all tensor references are valid
320    for node in &graph.nodes {
321        for &input_idx in &node.inputs {
322            if input_idx >= graph.tensors.len() {
323                bail!(
324                    "Invalid tensor reference: {} (graph has {} tensors)",
325                    input_idx,
326                    graph.tensors.len()
327                );
328            }
329        }
330    }
331
332    // Check that outputs are valid
333    for &output_idx in &graph.outputs {
334        if output_idx >= graph.tensors.len() {
335            bail!(
336                "Invalid output reference: {} (graph has {} tensors)",
337                output_idx,
338                graph.tensors.len()
339            );
340        }
341    }
342
343    Ok(())
344}
345
346/// Check if graph contains cycles (basic DFS).
347fn has_cycle(graph: &EinsumGraph) -> bool {
348    let mut visited = HashSet::new();
349    let mut rec_stack = HashSet::new();
350
351    for node in &graph.nodes {
352        for &output_idx in &node.outputs {
353            if !visited.contains(&output_idx)
354                && has_cycle_util(graph, output_idx, &mut visited, &mut rec_stack)
355            {
356                return true;
357            }
358        }
359    }
360
361    false
362}
363
364fn has_cycle_util(
365    graph: &EinsumGraph,
366    tensor_idx: usize,
367    visited: &mut HashSet<usize>,
368    rec_stack: &mut HashSet<usize>,
369) -> bool {
370    visited.insert(tensor_idx);
371    rec_stack.insert(tensor_idx);
372
373    // Find nodes that produce this tensor
374    for node in &graph.nodes {
375        if node.outputs.contains(&tensor_idx) {
376            for &input_idx in &node.inputs {
377                if !visited.contains(&input_idx) {
378                    if has_cycle_util(graph, input_idx, visited, rec_stack) {
379                        return true;
380                    }
381                } else if rec_stack.contains(&input_idx) {
382                    return true;
383                }
384            }
385        }
386    }
387
388    rec_stack.remove(&tensor_idx);
389    false
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{compile_to_einsum_with_context, CompilerContext};
    use tensorlogic_ir::{TLExpr, Term};

    /// A fresh context containing the `Person` domain of size 100.
    fn person_ctx() -> CompilerContext {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        ctx
    }

    /// The binary predicate `knows(x, y)` shared by several tests.
    fn knows_xy() -> TLExpr {
        TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")])
    }

    /// Compile `expr` against `ctx`, then run the post-compilation passes on
    /// the resulting graph with `options`.
    fn compile_and_run(
        expr: &TLExpr,
        ctx: &mut CompilerContext,
        options: PostCompilationOptions,
    ) -> anyhow::Result<PostCompilationResult> {
        let mut graph = compile_to_einsum_with_context(expr, ctx).unwrap();
        post_compilation_passes(&mut graph, ctx, options)
    }

    #[test]
    fn test_post_compilation_simple() {
        let mut ctx = person_ctx();
        let result =
            compile_and_run(&knows_xy(), &mut ctx, PostCompilationOptions::default()).unwrap();
        assert!(result.is_valid);
    }

    #[test]
    fn test_post_compilation_with_quantifier() {
        let mut ctx = person_ctx();
        let expr = TLExpr::exists("y", "Person", knows_xy());
        let result = compile_and_run(&expr, &mut ctx, PostCompilationOptions::default()).unwrap();
        assert!(result.is_valid);
    }

    #[test]
    fn test_quick_validate_success() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D", 10);

        let expr = TLExpr::pred("p", vec![Term::var("x")]);
        let graph = compile_to_einsum_with_context(&expr, &mut ctx).unwrap();

        assert!(quick_validate(&graph).is_ok());
    }

    #[test]
    fn test_extract_axes_from_spec() {
        assert_eq!(extract_axes_from_spec("ab,bc->ac"), vec!['a', 'b', 'c']);
        assert_eq!(extract_axes_from_spec("ij->i"), vec!['i', 'j']);
    }

    #[test]
    fn test_post_compilation_optimizations() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D", 10);

        // Create expression that will have optimizable structure
        let expr = TLExpr::And(
            Box::new(TLExpr::pred("p", vec![Term::var("x")])),
            Box::new(TLExpr::pred("q", vec![Term::var("y")])),
        );
        let options = PostCompilationOptions {
            apply_optimizations: true,
            ..Default::default()
        };

        let result = compile_and_run(&expr, &mut ctx, options).unwrap();
        assert!(result.is_valid);
        // May or may not have optimizations depending on graph structure
    }

    #[test]
    fn test_post_compilation_strict_mode() {
        let mut ctx = person_ctx();
        let options = PostCompilationOptions {
            strict_mode: true,
            ..Default::default()
        };

        // Should pass if no warnings
        assert!(compile_and_run(&knows_xy(), &mut ctx, options).is_ok());
    }
}