//! Post-compilation passes for compiled einsum graphs.
//!
//! After an expression has been compiled into an [`EinsumGraph`], these passes
//! run structural validation, axis/domain consistency checks, rank (shape)
//! checks and the (currently no-op) optimization pass.
//!
//! Path: `tensorlogic_compiler/passes/post_compilation.rs`.

use std::collections::{HashMap, HashSet};

use anyhow::{bail, Result};
use tensorlogic_ir::{validate_graph, EinsumGraph, OpType, ValidationReport};

use crate::CompilerContext;

/// Switches controlling which checks and transformations
/// [`post_compilation_passes`] runs on a freshly compiled graph.
///
/// The `Default` configuration enables every check and the optimization pass,
/// with strict mode turned off.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PostCompilationOptions {
    /// Run `tensorlogic_ir::validate_graph` and report its errors and warnings.
    pub validate_graph_structure: bool,
    /// Warn when one einsum axis is bound to domains of different sizes.
    pub validate_axes: bool,
    /// Propagate tensor ranks through the graph and warn when an element-wise
    /// binary operation combines tensors of different (non-zero) ranks.
    pub validate_shapes: bool,
    /// Run the graph optimization pass (currently a no-op).
    pub apply_optimizations: bool,
    /// Treat warnings produced by structural validation as failures.
    pub strict_mode: bool,
}
26
27impl Default for PostCompilationOptions {
28 fn default() -> Self {
29 Self {
30 validate_graph_structure: true,
31 validate_axes: true,
32 validate_shapes: true,
33 apply_optimizations: true,
34 strict_mode: false,
35 }
36 }
37}
38
39#[derive(Debug, Clone)]
41pub struct PostCompilationResult {
42 pub validation_report: ValidationReport,
44 pub is_valid: bool,
46 pub optimizations_applied: usize,
48 pub messages: Vec<String>,
50}
51
52pub fn post_compilation_passes(
80 graph: &mut EinsumGraph,
81 ctx: &CompilerContext,
82 options: PostCompilationOptions,
83) -> Result<PostCompilationResult> {
84 let mut messages = Vec::new();
85 let mut optimizations_applied = 0;
86
87 let validation_report = if options.validate_graph_structure {
89 let report = validate_graph(graph);
90
91 let has_simple_passthrough = graph.nodes.is_empty()
94 || (graph.outputs.len() == 1 && graph.inputs.contains(&graph.outputs[0]));
95
96 let filtered_errors: Vec<_> = report
97 .errors
98 .into_iter()
99 .filter(|error| {
100 if has_simple_passthrough && error.message.contains("has no producer") {
102 return false; }
104 true })
106 .collect();
107
108 for error in &filtered_errors {
109 messages.push(format!("ERROR: {}", error.message));
110 }
111
112 if !report.warnings.is_empty() {
113 for warning in &report.warnings {
114 messages.push(format!("WARNING: {}", warning.message));
115 }
116 }
117
118 ValidationReport {
119 checks_performed: report.checks_performed,
120 errors: filtered_errors,
121 warnings: report.warnings,
122 stats: report.stats,
123 }
124 } else {
125 ValidationReport {
126 checks_performed: 0,
127 errors: vec![],
128 warnings: vec![],
129 stats: Default::default(),
130 }
131 };
132
133 if options.validate_axes {
135 validate_axis_consistency(graph, ctx, &mut messages)?;
136 }
137
138 if options.validate_shapes {
140 validate_shape_compatibility(graph, ctx, &mut messages)?;
141 }
142
143 if options.apply_optimizations {
145 optimizations_applied += apply_optimization_passes(graph, &mut messages)?;
146 }
147
148 let is_valid = validation_report.is_valid()
150 && (!options.strict_mode || validation_report.warnings.is_empty());
151
152 if !is_valid {
153 bail!(
154 "Post-compilation validation failed:\n{}",
155 messages.join("\n")
156 );
157 }
158
159 Ok(PostCompilationResult {
160 validation_report,
161 is_valid,
162 optimizations_applied,
163 messages,
164 })
165}
166
167fn validate_axis_consistency(
169 graph: &EinsumGraph,
170 ctx: &CompilerContext,
171 messages: &mut Vec<String>,
172) -> Result<()> {
173 let mut axis_domains = std::collections::HashMap::new();
175
176 for node in &graph.nodes {
177 if let OpType::Einsum { spec, .. } = &node.op {
178 let axes = extract_axes_from_spec(spec);
180
181 for axis_char in axes {
182 for (var, &var_axis_char) in &ctx.var_to_axis {
184 if var_axis_char == axis_char {
185 if let Some(domain_name) = ctx.var_to_domain.get(var) {
187 if let Some(domain_info) = ctx.domains.get(domain_name) {
188 let size = domain_info.cardinality;
189
190 if let Some(&existing_size) = axis_domains.get(&axis_char) {
192 if existing_size != size {
193 messages.push(format!(
194 "WARNING: Axis '{}' has inconsistent domain sizes: {} vs {}",
195 axis_char, existing_size, size
196 ));
197 }
198 } else {
199 axis_domains.insert(axis_char, size);
200 }
201 }
202 }
203 break;
204 }
205 }
206 }
207 }
208 }
209
210 Ok(())
211}
212
/// Returns the distinct lowercase ASCII axis labels used by the operands of an
/// einsum `spec`, sorted ascending.
///
/// Both explicit (`"ij,jk->ik"`) and implicit (`"ij,jk"`, no `->`) specs are
/// accepted. Only the operand side is inspected, because every output label
/// of a well-formed spec also appears among the operands. Characters outside
/// `a`..=`z` (commas, dots, uppercase letters, whitespace) are ignored.
fn extract_axes_from_spec(spec: &str) -> Vec<char> {
    let operands = spec.split_once("->").map_or(spec, |(lhs, _)| lhs);

    let mut axes: Vec<char> = operands.chars().filter(char::is_ascii_lowercase).collect();
    // Sorting first lets `dedup` remove every repeat in a single linear pass.
    axes.sort_unstable();
    axes.dedup();
    axes
}
232
233fn validate_shape_compatibility(
235 graph: &EinsumGraph,
236 _ctx: &CompilerContext,
237 messages: &mut Vec<String>,
238) -> Result<()> {
239 let mut tensor_ranks = std::collections::HashMap::new();
241
242 for node in &graph.nodes {
243 match &node.op {
244 OpType::Einsum { spec } => {
245 if let Some((_inputs, output)) = spec.split_once("->") {
247 let output_rank = output.chars().filter(|c| c.is_alphabetic()).count();
248 if let Some(&output_idx) = node.outputs.first() {
249 tensor_ranks.insert(output_idx, output_rank);
250 }
251 }
252 }
253 OpType::ElemUnary { .. } => {
254 if let Some(&input_idx) = node.inputs.first() {
256 if let Some(&rank) = tensor_ranks.get(&input_idx) {
257 if let Some(&output_idx) = node.outputs.first() {
258 tensor_ranks.insert(output_idx, rank);
259 }
260 }
261 }
262 }
263 OpType::ElemBinary { .. } => {
264 if node.inputs.len() >= 2 {
266 let left_rank = tensor_ranks.get(&node.inputs[0]);
267 let right_rank = tensor_ranks.get(&node.inputs[1]);
268
269 if let (Some(&l), Some(&r)) = (left_rank, right_rank) {
270 if l != r && l != 0 && r != 0 {
271 messages.push(format!(
272 "WARNING: Element-wise binary op has mismatched ranks: {} vs {}",
273 l, r
274 ));
275 }
276 if let Some(&output_idx) = node.outputs.first() {
277 tensor_ranks.insert(output_idx, l.max(r));
278 }
279 }
280 }
281 }
282 OpType::Reduce { .. } => {
283 if let Some(&input_idx) = node.inputs.first() {
285 if let Some(&rank) = tensor_ranks.get(&input_idx) {
286 if let Some(&output_idx) = node.outputs.first() {
287 tensor_ranks.insert(output_idx, rank.saturating_sub(1));
288 }
289 }
290 }
291 }
292 }
293 }
294
295 Ok(())
296}
297
298fn apply_optimization_passes(
300 _graph: &mut EinsumGraph,
301 messages: &mut Vec<String>,
302) -> Result<usize> {
303 messages.push("Graph optimizations: currently disabled (awaiting IR API support)".to_string());
308
309 Ok(0)
310}
311
312pub fn quick_validate(graph: &EinsumGraph) -> Result<()> {
314 if has_cycle(graph) {
316 bail!("Graph contains cycles");
317 }
318
319 for node in &graph.nodes {
321 for &input_idx in &node.inputs {
322 if input_idx >= graph.tensors.len() {
323 bail!(
324 "Invalid tensor reference: {} (graph has {} tensors)",
325 input_idx,
326 graph.tensors.len()
327 );
328 }
329 }
330 }
331
332 for &output_idx in &graph.outputs {
334 if output_idx >= graph.tensors.len() {
335 bail!(
336 "Invalid output reference: {} (graph has {} tensors)",
337 output_idx,
338 graph.tensors.len()
339 );
340 }
341 }
342
343 Ok(())
344}
345
346fn has_cycle(graph: &EinsumGraph) -> bool {
348 let mut visited = HashSet::new();
349 let mut rec_stack = HashSet::new();
350
351 for node in &graph.nodes {
352 for &output_idx in &node.outputs {
353 if !visited.contains(&output_idx)
354 && has_cycle_util(graph, output_idx, &mut visited, &mut rec_stack)
355 {
356 return true;
357 }
358 }
359 }
360
361 false
362}
363
364fn has_cycle_util(
365 graph: &EinsumGraph,
366 tensor_idx: usize,
367 visited: &mut HashSet<usize>,
368 rec_stack: &mut HashSet<usize>,
369) -> bool {
370 visited.insert(tensor_idx);
371 rec_stack.insert(tensor_idx);
372
373 for node in &graph.nodes {
375 if node.outputs.contains(&tensor_idx) {
376 for &input_idx in &node.inputs {
377 if !visited.contains(&input_idx) {
378 if has_cycle_util(graph, input_idx, visited, rec_stack) {
379 return true;
380 }
381 } else if rec_stack.contains(&input_idx) {
382 return true;
383 }
384 }
385 }
386 }
387
388 rec_stack.remove(&tensor_idx);
389 false
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{compile_to_einsum_with_context, CompilerContext};
    use tensorlogic_ir::{TLExpr, Term};

    #[test]
    fn test_post_compilation_simple() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);

        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);

        let mut graph = compile_to_einsum_with_context(&expr, &mut ctx).unwrap();

        let options = PostCompilationOptions::default();
        let result = post_compilation_passes(&mut graph, &ctx, options).unwrap();

        assert!(result.is_valid);
    }

    #[test]
    fn test_post_compilation_with_quantifier() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);

        let expr = TLExpr::exists(
            "y",
            "Person",
            TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
        );

        let mut graph = compile_to_einsum_with_context(&expr, &mut ctx).unwrap();

        let options = PostCompilationOptions::default();
        let result = post_compilation_passes(&mut graph, &ctx, options).unwrap();

        assert!(result.is_valid);
    }

    #[test]
    fn test_quick_validate_success() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D", 10);

        let expr = TLExpr::pred("p", vec![Term::var("x")]);
        let graph = compile_to_einsum_with_context(&expr, &mut ctx).unwrap();

        assert!(quick_validate(&graph).is_ok());
    }

    #[test]
    fn test_extract_axes_from_spec() {
        assert_eq!(extract_axes_from_spec("ab,bc->ac"), vec!['a', 'b', 'c']);
        assert_eq!(extract_axes_from_spec("ij->i"), vec!['i', 'j']);
        // Repeated, out-of-order labels are deduplicated and sorted.
        assert_eq!(extract_axes_from_spec("ba,ab->ab"), vec!['a', 'b']);
    }

    #[test]
    fn test_post_compilation_optimizations() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D", 10);

        let expr = TLExpr::And(
            Box::new(TLExpr::pred("p", vec![Term::var("x")])),
            Box::new(TLExpr::pred("q", vec![Term::var("y")])),
        );

        let mut graph = compile_to_einsum_with_context(&expr, &mut ctx).unwrap();

        let options = PostCompilationOptions {
            apply_optimizations: true,
            ..Default::default()
        };

        let result = post_compilation_passes(&mut graph, &ctx, options).unwrap();
        assert!(result.is_valid);
        // The optimization pass is currently a no-op that only reports itself.
        assert_eq!(result.optimizations_applied, 0);
        assert!(result
            .messages
            .iter()
            .any(|m| m.starts_with("Graph optimizations:")));
    }

    #[test]
    fn test_post_compilation_strict_mode() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);

        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);

        let mut graph = compile_to_einsum_with_context(&expr, &mut ctx).unwrap();

        let options = PostCompilationOptions {
            strict_mode: true,
            ..Default::default()
        };

        let result = post_compilation_passes(&mut graph, &ctx, options);
        assert!(result.is_ok());
    }

    #[test]
    fn test_post_compilation_all_passes_disabled() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);

        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let mut graph = compile_to_einsum_with_context(&expr, &mut ctx).unwrap();

        let options = PostCompilationOptions {
            validate_graph_structure: false,
            validate_axes: false,
            validate_shapes: false,
            apply_optimizations: false,
            strict_mode: false,
        };

        let result = post_compilation_passes(&mut graph, &ctx, options).unwrap();
        assert_eq!(result.validation_report.checks_performed, 0);
        assert_eq!(result.optimizations_applied, 0);
        assert!(result.messages.is_empty());
    }
}