tensorlogic_compiler/passes/
post_compilation.rs1use anyhow::{bail, Result};
7use std::collections::HashSet;
8use tensorlogic_ir::{validate_graph, EinsumGraph, OpType, ValidationReport};
9
10use crate::CompilerContext;
11
12use super::{
14 contraction_opt::{optimize_contractions_with_config, ContractionOptConfig},
15 loop_fusion::{fuse_loops_with_config, LoopFusionConfig},
16};
17
/// Switches controlling which checks and rewrites `post_compilation_passes` performs.
#[derive(Clone, Debug)]
pub struct PostCompilationOptions {
    /// Run `validate_graph` on the compiled graph and record its errors and warnings.
    pub validate_graph_structure: bool,
    /// Check that axis letters used by einsum specs map to consistently sized domains.
    pub validate_axes: bool,
    /// Propagate tensor ranks and warn about mismatched element-wise operands.
    pub validate_shapes: bool,
    /// Master switch for the graph-rewriting optimization passes.
    pub apply_optimizations: bool,
    /// Enable contraction reordering (only consulted when `apply_optimizations` is set).
    pub enable_contraction_opt: bool,
    /// Enable loop fusion (only consulted when `apply_optimizations` is set).
    pub enable_loop_fusion: bool,
    /// Treat structural-validation warnings as failures.
    pub strict_mode: bool,
}
36
37impl Default for PostCompilationOptions {
38 fn default() -> Self {
39 Self {
40 validate_graph_structure: true,
41 validate_axes: true,
42 validate_shapes: true,
43 apply_optimizations: true,
44 enable_contraction_opt: true,
45 enable_loop_fusion: true,
46 strict_mode: false,
47 }
48 }
49}
50
51#[derive(Debug, Clone)]
53pub struct PostCompilationResult {
54 pub validation_report: ValidationReport,
56 pub is_valid: bool,
58 pub optimizations_applied: usize,
60 pub messages: Vec<String>,
62}
63
64pub fn post_compilation_passes(
92 graph: &mut EinsumGraph,
93 ctx: &CompilerContext,
94 options: PostCompilationOptions,
95) -> Result<PostCompilationResult> {
96 let mut messages = Vec::new();
97 let mut optimizations_applied = 0;
98
99 let validation_report = if options.validate_graph_structure {
101 let report = validate_graph(graph);
102
103 let has_simple_passthrough = graph.nodes.is_empty()
106 || (graph.outputs.len() == 1 && graph.inputs.contains(&graph.outputs[0]));
107
108 let filtered_errors: Vec<_> = report
109 .errors
110 .into_iter()
111 .filter(|error| {
112 if has_simple_passthrough && error.message.contains("has no producer") {
114 return false; }
116 true })
118 .collect();
119
120 for error in &filtered_errors {
121 messages.push(format!("ERROR: {}", error.message));
122 }
123
124 if !report.warnings.is_empty() {
125 for warning in &report.warnings {
126 messages.push(format!("WARNING: {}", warning.message));
127 }
128 }
129
130 ValidationReport {
131 checks_performed: report.checks_performed,
132 errors: filtered_errors,
133 warnings: report.warnings,
134 stats: report.stats,
135 }
136 } else {
137 ValidationReport {
138 checks_performed: 0,
139 errors: vec![],
140 warnings: vec![],
141 stats: Default::default(),
142 }
143 };
144
145 if options.validate_axes {
147 validate_axis_consistency(graph, ctx, &mut messages)?;
148 }
149
150 if options.validate_shapes {
152 validate_shape_compatibility(graph, ctx, &mut messages)?;
153 }
154
155 if options.apply_optimizations {
157 optimizations_applied += apply_optimization_passes(graph, &options, &mut messages)?;
158 }
159
160 let is_valid = validation_report.is_valid()
162 && (!options.strict_mode || validation_report.warnings.is_empty());
163
164 if !is_valid {
165 bail!(
166 "Post-compilation validation failed:\n{}",
167 messages.join("\n")
168 );
169 }
170
171 Ok(PostCompilationResult {
172 validation_report,
173 is_valid,
174 optimizations_applied,
175 messages,
176 })
177}
178
179fn validate_axis_consistency(
181 graph: &EinsumGraph,
182 ctx: &CompilerContext,
183 messages: &mut Vec<String>,
184) -> Result<()> {
185 let mut axis_domains = std::collections::HashMap::new();
187
188 for node in &graph.nodes {
189 if let OpType::Einsum { spec, .. } = &node.op {
190 let axes = extract_axes_from_spec(spec);
192
193 for axis_char in axes {
194 for (var, &var_axis_char) in &ctx.var_to_axis {
196 if var_axis_char == axis_char {
197 if let Some(domain_name) = ctx.var_to_domain.get(var) {
199 if let Some(domain_info) = ctx.domains.get(domain_name) {
200 let size = domain_info.cardinality;
201
202 if let Some(&existing_size) = axis_domains.get(&axis_char) {
204 if existing_size != size {
205 messages.push(format!(
206 "WARNING: Axis '{}' has inconsistent domain sizes: {} vs {}",
207 axis_char, existing_size, size
208 ));
209 }
210 } else {
211 axis_domains.insert(axis_char, size);
212 }
213 }
214 }
215 break;
216 }
217 }
218 }
219 }
220 }
221
222 Ok(())
223}
224
/// Returns the distinct lowercase ASCII axis letters used by the operands of an
/// einsum `spec`, sorted ascending.
///
/// Only the part before `->` is inspected. A spec with no `->` (implicit-output
/// form such as `"ij,jk"`) consists entirely of operands, so all of it is used.
/// Any other character (commas, uppercase letters, ...) is ignored.
fn extract_axes_from_spec(spec: &str) -> Vec<char> {
    let operands = spec.split_once("->").map_or(spec, |(inputs, _output)| inputs);

    // A BTreeSet both deduplicates and orders the letters in one pass.
    let axes: std::collections::BTreeSet<char> =
        operands.chars().filter(char::is_ascii_lowercase).collect();

    axes.into_iter().collect()
}
244
245fn validate_shape_compatibility(
247 graph: &EinsumGraph,
248 _ctx: &CompilerContext,
249 messages: &mut Vec<String>,
250) -> Result<()> {
251 let mut tensor_ranks = std::collections::HashMap::new();
253
254 for node in &graph.nodes {
255 match &node.op {
256 OpType::Einsum { spec } => {
257 if let Some((_inputs, output)) = spec.split_once("->") {
259 let output_rank = output.chars().filter(|c| c.is_alphabetic()).count();
260 if let Some(&output_idx) = node.outputs.first() {
261 tensor_ranks.insert(output_idx, output_rank);
262 }
263 }
264 }
265 OpType::ElemUnary { .. } => {
266 if let Some(&input_idx) = node.inputs.first() {
268 if let Some(&rank) = tensor_ranks.get(&input_idx) {
269 if let Some(&output_idx) = node.outputs.first() {
270 tensor_ranks.insert(output_idx, rank);
271 }
272 }
273 }
274 }
275 OpType::ElemBinary { .. } => {
276 if node.inputs.len() >= 2 {
278 let left_rank = tensor_ranks.get(&node.inputs[0]);
279 let right_rank = tensor_ranks.get(&node.inputs[1]);
280
281 if let (Some(&l), Some(&r)) = (left_rank, right_rank) {
282 if l != r && l != 0 && r != 0 {
283 messages.push(format!(
284 "WARNING: Element-wise binary op has mismatched ranks: {} vs {}",
285 l, r
286 ));
287 }
288 if let Some(&output_idx) = node.outputs.first() {
289 tensor_ranks.insert(output_idx, l.max(r));
290 }
291 }
292 }
293 }
294 OpType::Reduce { .. } => {
295 if let Some(&input_idx) = node.inputs.first() {
297 if let Some(&rank) = tensor_ranks.get(&input_idx) {
298 if let Some(&output_idx) = node.outputs.first() {
299 tensor_ranks.insert(output_idx, rank.saturating_sub(1));
300 }
301 }
302 }
303 }
304 }
305 }
306
307 Ok(())
308}
309
310fn apply_optimization_passes(
312 graph: &mut EinsumGraph,
313 options: &PostCompilationOptions,
314 messages: &mut Vec<String>,
315) -> Result<usize> {
316 let mut total_optimizations = 0;
317
318 if options.enable_contraction_opt {
320 let config = ContractionOptConfig::default();
321 let (optimized_graph, stats) = optimize_contractions_with_config(graph, &config);
322 *graph = optimized_graph;
323
324 if stats.contractions_reordered > 0 {
325 messages.push(format!(
326 "Contraction optimization: {} contractions reordered, {:.1}% FLOPs reduction",
327 stats.contractions_reordered, stats.flops_reduction_percent
328 ));
329 total_optimizations += stats.total_optimizations();
330 }
331 }
332
333 if options.enable_loop_fusion {
335 let config = LoopFusionConfig::default();
336 let (optimized_graph, stats) = fuse_loops_with_config(graph, &config);
337 *graph = optimized_graph;
338
339 if stats.loops_fused > 0 {
340 messages.push(format!(
341 "Loop fusion: {} loops fused, {} intermediates eliminated",
342 stats.loops_fused, stats.intermediates_eliminated
343 ));
344 total_optimizations += stats.total_optimizations();
345 }
346 }
347
348 if total_optimizations == 0 {
349 messages.push("No graph optimizations applied (graph may already be optimal)".to_string());
350 }
351
352 Ok(total_optimizations)
353}
354
355pub fn quick_validate(graph: &EinsumGraph) -> Result<()> {
357 if has_cycle(graph) {
359 bail!("Graph contains cycles");
360 }
361
362 for node in &graph.nodes {
364 for &input_idx in &node.inputs {
365 if input_idx >= graph.tensors.len() {
366 bail!(
367 "Invalid tensor reference: {} (graph has {} tensors)",
368 input_idx,
369 graph.tensors.len()
370 );
371 }
372 }
373 }
374
375 for &output_idx in &graph.outputs {
377 if output_idx >= graph.tensors.len() {
378 bail!(
379 "Invalid output reference: {} (graph has {} tensors)",
380 output_idx,
381 graph.tensors.len()
382 );
383 }
384 }
385
386 Ok(())
387}
388
389fn has_cycle(graph: &EinsumGraph) -> bool {
391 let mut visited = HashSet::new();
392 let mut rec_stack = HashSet::new();
393
394 for node in &graph.nodes {
395 for &output_idx in &node.outputs {
396 if !visited.contains(&output_idx)
397 && has_cycle_util(graph, output_idx, &mut visited, &mut rec_stack)
398 {
399 return true;
400 }
401 }
402 }
403
404 false
405}
406
407fn has_cycle_util(
408 graph: &EinsumGraph,
409 tensor_idx: usize,
410 visited: &mut HashSet<usize>,
411 rec_stack: &mut HashSet<usize>,
412) -> bool {
413 visited.insert(tensor_idx);
414 rec_stack.insert(tensor_idx);
415
416 for node in &graph.nodes {
418 if node.outputs.contains(&tensor_idx) {
419 for &input_idx in &node.inputs {
420 if !visited.contains(&input_idx) {
421 if has_cycle_util(graph, input_idx, visited, rec_stack) {
422 return true;
423 }
424 } else if rec_stack.contains(&input_idx) {
425 return true;
426 }
427 }
428 }
429 }
430
431 rec_stack.remove(&tensor_idx);
432 false
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{compile_to_einsum_with_context, CompilerContext};
    use tensorlogic_ir::{TLExpr, Term};

    /// Compiles `expr` under `ctx`, failing the test if compilation itself errors.
    fn compile(expr: &TLExpr, ctx: &mut CompilerContext) -> EinsumGraph {
        compile_to_einsum_with_context(expr, ctx).expect("compilation should succeed")
    }

    /// A single binary predicate passes the default pipeline.
    #[test]
    fn test_post_compilation_simple() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);

        let knows = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let mut graph = compile(&knows, &mut ctx);

        let outcome = post_compilation_passes(&mut graph, &ctx, PostCompilationOptions::default())
            .expect("default passes should accept the graph");
        assert!(outcome.is_valid);
    }

    /// An existentially quantified predicate passes the default pipeline.
    #[test]
    fn test_post_compilation_with_quantifier() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);

        let someone_known = TLExpr::exists(
            "y",
            "Person",
            TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
        );
        let mut graph = compile(&someone_known, &mut ctx);

        let outcome = post_compilation_passes(&mut graph, &ctx, PostCompilationOptions::default())
            .expect("default passes should accept the graph");
        assert!(outcome.is_valid);
    }

    /// A freshly compiled unary predicate passes the cheap structural checks.
    #[test]
    fn test_quick_validate_success() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D", 10);

        let graph = compile(&TLExpr::pred("p", vec![Term::var("x")]), &mut ctx);

        assert!(quick_validate(&graph).is_ok());
    }

    /// Operand axes are collected, deduplicated and sorted.
    #[test]
    fn test_extract_axes_from_spec() {
        assert_eq!(extract_axes_from_spec("ab,bc->ac"), vec!['a', 'b', 'c']);
        assert_eq!(extract_axes_from_spec("ij->i"), vec!['i', 'j']);
    }

    /// Enabling optimizations keeps a conjunction valid.
    #[test]
    fn test_post_compilation_optimizations() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("D", 10);

        let conjunction = TLExpr::And(
            Box::new(TLExpr::pred("p", vec![Term::var("x")])),
            Box::new(TLExpr::pred("q", vec![Term::var("y")])),
        );
        let mut graph = compile(&conjunction, &mut ctx);

        let options = PostCompilationOptions {
            apply_optimizations: true,
            ..Default::default()
        };
        let outcome = post_compilation_passes(&mut graph, &ctx, options)
            .expect("optimizing passes should accept the graph");
        assert!(outcome.is_valid);
    }

    /// Strict mode still accepts a simple, warning-free graph.
    #[test]
    fn test_post_compilation_strict_mode() {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);

        let knows = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
        let mut graph = compile(&knows, &mut ctx);

        let options = PostCompilationOptions {
            strict_mode: true,
            ..Default::default()
        };
        let outcome = post_compilation_passes(&mut graph, &ctx, options);
        assert!(outcome.is_ok());
    }
}
538}