Skip to main content

xlog_logic/
compile.rs

1//! Compilation pipeline for XLOG programs
2//!
3//! This module provides the main entry point for compiling XLOG source code
4//! into execution plans. The compilation process consists of:
5//!
6//! 1. **Parsing**: Convert source text to AST (`parser::parse_program`)
7//! 2. **Stratification**: Analyze negation/aggregation dependencies (`stratify::stratify`)
8//! 3. **Lowering**: Transform AST to Relational IR (`lower::Lowerer::lower_program`)
9//!
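//! In addition to the three phases above, the AST is desugared and normalized
//! (queries/constraints, meta, lists, magic sets) and checked for NAF safety
//! before stratification, and the lowered plan is run through the optimizer
//! passes and multiway-join promotion before it is returned.
//!
//! The `Compiler` struct orchestrates these phases and provides a single
//! entry point via the `compile` method.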
12
13use std::path::{Path, PathBuf};
14
15use xlog_core::{Result, XlogError};
16use xlog_ir::ExecutionPlan;
17use xlog_stats::{StatsManager, StatsSnapshot};
18
19use crate::compiler_config::CompilerConfig;
20use crate::list_normalize::normalize_v085_lists;
21use crate::lower::Lowerer;
22use crate::magic_sets::rewrite_v085_magic_sets;
23use crate::meta_normalize::normalize_v085_meta;
24use crate::module::ModuleError;
25use crate::optimizer::Optimizer;
26use crate::parser::parse_program;
27use crate::resolver::ModuleResolver;
28use crate::stratify::stratify;
29use crate::{BodyLiteral, Program, Query, Rule as AstRule, Term};
30
/// The XLOG compiler orchestrates the full compilation pipeline.
///
/// A `Compiler` owns a single [`Lowerer`] that is reused by every `compile*`
/// call made on the same instance; the lowerer is only replaced when
/// [`Compiler::reset`] is called.
///
/// # Example
///
/// ```ignore
/// use xlog_logic::compile::Compiler;
///
/// let mut compiler = Compiler::new();
/// let plan = compiler.compile(r#"
///     edge(1, 2).
///     edge(2, 3).
///     reach(X, Y) :- edge(X, Y).
///     reach(X, Z) :- reach(X, Y), edge(Y, Z).
/// "#)?;
/// ```
pub struct Compiler {
    /// Lowering state (AST → execution plan). Also the backing store for the
    /// maps returned by `rel_ids()` and `schemas()`.
    lowerer: Lowerer,
}
49
50use std::collections::{HashMap, HashSet};
51use std::sync::Arc;
52use xlog_core::{RelId, Schema};
53
54impl Default for Compiler {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60impl Compiler {
61    /// Create a new compiler instance.
62    pub fn new() -> Self {
63        Self {
64            lowerer: Lowerer::new(),
65        }
66    }
67
68    /// Set the maximum active rules for TensorMaskedJoin (16..=128).
69    pub fn set_max_active_rules(&mut self, max: usize) {
70        self.lowerer.set_max_active_rules(max);
71    }
72
    /// Compile XLOG source code into an execution plan.
    ///
    /// This is the main entry point for compilation. It delegates to
    /// [`Self::compile_with_stats_snapshot`] with no statistics snapshot, which
    /// runs the full pipeline:
    /// 1. Parsing (source → AST)
    /// 2. Desugaring and normalization of the AST, plus NAF safety validation
    /// 3. Stratification (analyze dependencies, check for cycles)
    /// 4. Lowering (AST → Relational IR execution plan)
    /// 5. Optimization passes over the lowered plan
    ///
    /// # Arguments
    ///
    /// * `source` - The XLOG source code as a string
    ///
    /// # Returns
    ///
    /// * `Ok(ExecutionPlan)` - The compiled execution plan ready for execution
    /// * `Err(XlogError)` - If any compilation phase fails:
    ///   - Errors from the parser are propagated unchanged (presumably
    ///     `XlogError::Parse` for syntax errors — confirm against the parser)
    ///   - `XlogError::Compilation` - Semantic errors. This includes unsafe
    ///     negation and unstratifiable negation/aggregation: a
    ///     `XlogError::StratificationCycle` raised by the stratifier is
    ///     converted into a `Compilation` error whose message starts with
    ///     `"v0.8.5 naf error: "`, so callers do not observe
    ///     `StratificationCycle` from that phase.
    ///
    /// # Example
    ///
    /// ```ignore
    /// let mut compiler = Compiler::new();
    ///
    /// // Compile a simple transitive closure program
    /// let plan = compiler.compile(r#"
    ///     edge(1, 2).
    ///     edge(2, 3).
    ///     reach(X, Y) :- edge(X, Y).
    ///     reach(X, Z) :- reach(X, Y), edge(Y, Z).
    /// "#)?;
    ///
    /// // The plan can now be executed by xlog-runtime
    /// ```
    pub fn compile(&mut self, source: &str) -> Result<ExecutionPlan> {
        self.compile_with_stats_snapshot(source, None)
    }
110
111    /// Compile XLOG source code into an execution plan, optionally seeding the optimizer
112    /// with a runtime statistics snapshot.
113    ///
114    /// W2.1: this entry point delegates through the new composable API
115    /// with `CompilerConfig::default()`, which preserves slice
116    /// 1/2/4/W2.2 behavior bit-identically.
117    pub fn compile_with_stats_snapshot(
118        &mut self,
119        source: &str,
120        stats_snapshot: Option<&StatsSnapshot>,
121    ) -> Result<ExecutionPlan> {
122        self.compile_with_config_and_stats_snapshot(
123            source,
124            &CompilerConfig::default(),
125            stats_snapshot,
126        )
127    }
128
129    /// W2.1: composable entry point that accepts a `CompilerConfig`.
130    ///
131    /// Default-config callers should keep using `compile()` /
132    /// `compile_with_stats_snapshot()`. This entry point exists so
133    /// W2.1 can flip the variable-ordering cost model on per-call
134    /// without an env override.
135    pub fn compile_with_config_and_stats_snapshot(
136        &mut self,
137        source: &str,
138        config: &CompilerConfig,
139        stats_snapshot: Option<&StatsSnapshot>,
140    ) -> Result<ExecutionPlan> {
141        let program = parse_program(source)?;
142        self.compile_program_with_config_and_stats_snapshot(&program, config, stats_snapshot)
143    }
144
145    /// Compile a parsed XLOG program into an execution plan.
146    ///
147    /// This is useful for callers that want to inspect the AST (facts, queries,
148    /// constraints) while compiling without reparsing.
149    pub fn compile_program(&mut self, program: &Program) -> Result<ExecutionPlan> {
150        self.compile_program_with_stats_snapshot(program, None)
151    }
152
153    /// Compile a parsed XLOG program into an execution plan, optionally seeding the optimizer.
154    ///
155    /// W2.1: delegates to
156    /// [`Self::compile_program_with_config_and_stats_snapshot`] with
157    /// `CompilerConfig::default()`.
158    pub fn compile_program_with_stats_snapshot(
159        &mut self,
160        program: &Program,
161        stats_snapshot: Option<&StatsSnapshot>,
162    ) -> Result<ExecutionPlan> {
163        self.compile_program_with_config_and_stats_snapshot(
164            program,
165            &CompilerConfig::default(),
166            stats_snapshot,
167        )
168    }
169
170    /// W2.1: composable program-level entry point.
171    ///
172    /// `config` is currently consumed only by the promoter
173    /// (W2.1 step 5) when it wires the variable-ordering cost
174    /// model. With `CompilerConfig::default()`, the promoter
175    /// behaves identically to pre-W2.1.
176    pub fn compile_program_with_config_and_stats_snapshot(
177        &mut self,
178        program: &Program,
179        config: &CompilerConfig,
180        stats_snapshot: Option<&StatsSnapshot>,
181    ) -> Result<ExecutionPlan> {
182        let program = desugar_queries_and_constraints(program);
183        let program = normalize_v085_meta(&program)?;
184        let program = normalize_v085_lists(&program)?;
185        let program = rewrite_v085_magic_sets(&program)?.program;
186        validate_v085_naf_safety(&program)?;
187
188        // Phase 2: Stratify (analyze dependencies, detect cycles)
189        let strata = stratify(&program).map_err(map_stratification_to_naf_error)?;
190
191        // Convert strata to the format expected by the lowerer
192        let strata_preds: Vec<Vec<String>> = strata.into_iter().map(|s| s.predicates).collect();
193
194        // Phase 3: Lower AST to execution plan
195        self.lowerer.set_strata(strata_preds);
196
197        // If we have predicate names for the snapshot, use them to seed lowering-time
198        // join ordering with better cardinality estimates.
199        let mut cardinality_hints: HashMap<String, u64> = HashMap::new();
200        if let Some(snapshot) = stats_snapshot {
201            if !snapshot.rel_names.is_empty() {
202                let rel_name_by_id: HashMap<RelId, &str> = snapshot
203                    .rel_names
204                    .iter()
205                    .map(|(id, name)| (*id, name.as_str()))
206                    .collect();
207                for rel in &snapshot.relations {
208                    if let Some(name) = rel_name_by_id.get(&rel.rel_id) {
209                        cardinality_hints.insert((*name).to_string(), rel.cardinality);
210                    }
211                }
212            }
213        }
214        self.lowerer.set_cardinality_hints(cardinality_hints);
215
216        let mut plan = self.lowerer.lower_program(&program)?;
217
218        // Phase 4: Optimize (predicate pushdown + cost-aware rewrites)
219        //
220        // Seed statistics with any known fact cardinalities so cost estimation has
221        // at least a baseline for EDB relations.
222        let mut mgr = StatsManager::new();
223        let mut fact_counts: HashMap<String, u64> = HashMap::new();
224        for fact in program.facts() {
225            *fact_counts.entry(fact.head.predicate.clone()).or_insert(0) += 1;
226        }
227
228        for (pred, rel_id) in self.lowerer.rel_ids() {
229            mgr.register_relation(*rel_id);
230            let rows = fact_counts.get(pred).copied().unwrap_or(0);
231            if rows > 0 {
232                mgr.update_cardinality(*rel_id, rows);
233                if let Some(schema) = self.lowerer.schemas().get(pred) {
234                    mgr.update_byte_size(*rel_id, rows * schema.row_size_bytes() as u64);
235                }
236            }
237        }
238
239        if let Some(snapshot) = stats_snapshot {
240            if snapshot.rel_names.is_empty() {
241                mgr.merge_snapshot(snapshot);
242            } else {
243                let rel_name_by_id: HashMap<RelId, &str> = snapshot
244                    .rel_names
245                    .iter()
246                    .map(|(id, name)| (*id, name.as_str()))
247                    .collect();
248
249                for rel in &snapshot.relations {
250                    let Some(pred) = rel_name_by_id.get(&rel.rel_id) else {
251                        continue;
252                    };
253                    let Some(rel_id) = self.lowerer.rel_ids().get(*pred) else {
254                        continue;
255                    };
256
257                    let mut remapped = rel.clone();
258                    remapped.rel_id = *rel_id;
259
260                    if let Some(schema) = self.lowerer.schemas().get(*pred) {
261                        remapped.column_stats.retain(|col| {
262                            col.col_idx < schema.arity()
263                                && schema.column_type(col.col_idx) == Some(col.dtype)
264                        });
265                    } else {
266                        remapped.column_stats.clear();
267                    }
268
269                    mgr.register_relation(*rel_id);
270                    if let Some(stats) = mgr.get_relation_stats_mut(*rel_id) {
271                        *stats = remapped;
272                    }
273                }
274
275                for js in &snapshot.join_selectivities {
276                    if js.left_keys.len() != js.right_keys.len() {
277                        continue;
278                    }
279
280                    let Some(left_pred) = rel_name_by_id.get(&js.left_rel) else {
281                        continue;
282                    };
283                    let Some(right_pred) = rel_name_by_id.get(&js.right_rel) else {
284                        continue;
285                    };
286                    let Some(&left_id) = self.lowerer.rel_ids().get(*left_pred) else {
287                        continue;
288                    };
289                    let Some(&right_id) = self.lowerer.rel_ids().get(*right_pred) else {
290                        continue;
291                    };
292
293                    let Some(left_schema) = self.lowerer.schemas().get(*left_pred) else {
294                        continue;
295                    };
296                    let Some(right_schema) = self.lowerer.schemas().get(*right_pred) else {
297                        continue;
298                    };
299                    if js.left_keys.iter().any(|&k| k >= left_schema.arity())
300                        || js.right_keys.iter().any(|&k| k >= right_schema.arity())
301                    {
302                        continue;
303                    }
304
305                    mgr.set_join_selectivity(
306                        left_id,
307                        right_id,
308                        js.left_keys.clone(),
309                        js.right_keys.clone(),
310                        js.selectivity,
311                    );
312                }
313            }
314        }
315
316        // Build schemas by RelId for the optimizer
317        let schemas_by_rel_id: HashMap<RelId, Schema> = self
318            .lowerer
319            .rel_ids()
320            .iter()
321            .filter_map(|(pred, rel_id)| {
322                self.lowerer
323                    .schemas()
324                    .get(pred)
325                    .map(|schema| (*rel_id, schema.clone()))
326            })
327            .collect();
328
329        let stats_arc = Arc::new(mgr);
330
331        crate::optimizer::helper_split_pass::run(
332            &mut plan,
333            &schemas_by_rel_id,
334            &stats_arc,
335            |schema| self.lowerer.create_helper_relation(schema),
336        );
337
338        let schemas_by_rel_id: HashMap<RelId, Schema> = self
339            .lowerer
340            .rel_ids()
341            .iter()
342            .filter_map(|(pred, rel_id)| {
343                self.lowerer
344                    .schemas()
345                    .get(pred)
346                    .map(|schema| (*rel_id, schema.clone()))
347            })
348            .collect();
349
350        let mut optimizer = Optimizer::new(Arc::clone(&stats_arc));
351        optimizer.set_schemas(schemas_by_rel_id);
352        for rules in &mut plan.rules_by_scc {
353            for rule in rules {
354                rule.body = optimizer.optimize(rule.body.clone());
355            }
356        }
357
358        // v0.6.5 slice 3: selectivity-aware reordering pass. Runs
359        // BETWEEN the optimizer loop and promote_multiway.
360        // Locked compile-pipeline ordering:
361        //   lower → helper_split_pass → optimizer → selectivity_pass → promote_multiway
362        //
363        // v0.6.5 W2.2: takes `rel_ids` so per-body Scans can be
364        // resolved against `StatsManager`. Behavior on empty
365        // stats / unseeded relations is no-op (safety floor).
366        crate::optimizer::selectivity_pass::run(&mut plan, &stats_arc, self.lowerer.rel_ids());
367
368        // v0.6.5 slice 1: promote eligible triangle subtrees to
369        // RirNode::MultiWayJoin. Runs *after* the optimizer so the
370        // optimizer never has to learn the new variant. Fallback
371        // identity preserves v0.6.2 binary-join semantics on
372        // dispatch decline.
373        //
374        // v0.6.5 slice 4: pass the lowerer's predicate→RelId map
375        // so the promoter can gate recursive-SCC bodies on the
376        // count of in-SCC Scans (≤ 1 = promote, ≥ 2 = skip).
377        //
378        // W2.1: also pass `&stats_arc` and the caller-provided
379        // `&CompilerConfig`. With `CompilerConfig::default()`
380        // (`Disabled`), the promoter never sets `var_order` and
381        // slice 1/2/4/W2.2 dispatch is bit-identical.
382        crate::promote::promote_multiway(&mut plan, self.lowerer.rel_ids(), &stats_arc, config);
383
384        let schemas_by_rel_id: HashMap<RelId, Schema> = self
385            .lowerer
386            .rel_ids()
387            .iter()
388            .filter_map(|(pred, rel_id)| {
389                self.lowerer
390                    .schemas()
391                    .get(pred)
392                    .map(|schema| (*rel_id, schema.clone()))
393            })
394            .collect();
395
396        crate::optimizer::helper_split_pass::run_kclique_specs(
397            &mut plan,
398            &schemas_by_rel_id,
399            |schema| self.lowerer.create_helper_relation(schema),
400        );
401
402        Ok(plan)
403    }
404
405    /// Reset the compiler state for a fresh compilation.
406    ///
407    /// This creates a new lowerer, clearing any cached schemas or relation IDs
408    /// from previous compilations.
409    pub fn reset(&mut self) {
410        self.lowerer = Lowerer::new();
411    }
412
413    /// Get the mapping from predicate names to relation IDs after compilation.
414    ///
415    /// This mapping is needed to register relations in the executor with
416    /// the correct RelIds.
417    pub fn rel_ids(&self) -> &HashMap<String, RelId> {
418        self.lowerer.rel_ids()
419    }
420
421    /// Get the inferred schemas for predicates after compilation.
422    ///
423    /// These schemas are needed to create GPU buffers with correct column types.
424    pub fn schemas(&self) -> &HashMap<String, Schema> {
425        self.lowerer.schemas()
426    }
427}
428
429fn desugar_queries_and_constraints(program: &Program) -> Program {
430    let mut out = program.clone();
431
432    // Constraints: `:- body.` becomes `__xlog_constraint_i(1) :- body.`
433    for (i, constraint) in program.constraints.iter().enumerate() {
434        let pred = format!("__xlog_constraint_{}", i);
435        out.rules.push(AstRule {
436            head: crate::ast::Atom {
437                predicate: pred,
438                terms: vec![Term::Integer(1)],
439            },
440            body: constraint.body.clone(),
441        });
442    }
443
444    // Queries: `?- atom.` becomes `__xlog_query_i(Vars...) :- atom.`
445    for (i, Query { atom }) in program.queries.iter().enumerate() {
446        let pred = format!("__xlog_query_{}", i);
447
448        let mut head_terms: Vec<Term> = Vec::new();
449        let mut seen: std::collections::HashSet<&str> = std::collections::HashSet::new();
450
451        for term in &atom.terms {
452            for name in term.variables() {
453                if seen.insert(name) {
454                    head_terms.push(Term::Variable(name.to_string()));
455                }
456            }
457        }
458
459        if head_terms.is_empty() {
460            head_terms.push(Term::Integer(1));
461        }
462
463        out.rules.push(AstRule {
464            head: crate::ast::Atom {
465                predicate: pred,
466                terms: head_terms,
467            },
468            body: vec![BodyLiteral::Positive(atom.clone())],
469        });
470    }
471
472    out
473}
474
475fn validate_v085_naf_safety(program: &Program) -> Result<()> {
476    for rule in &program.rules {
477        validate_body_naf_safety(&rule.body, &format!("rule {}", rule.head.predicate))?;
478    }
479    for (idx, constraint) in program.constraints.iter().enumerate() {
480        validate_body_naf_safety(&constraint.body, &format!("constraint {}", idx))?;
481    }
482    for (idx, learnable) in program.learnable_rules.iter().enumerate() {
483        validate_body_naf_safety(&learnable.body, &format!("learnable rule {}", idx))?;
484    }
485    Ok(())
486}
487
488fn validate_body_naf_safety(body: &[BodyLiteral], context: &str) -> Result<()> {
489    let mut bound: HashSet<String> = HashSet::new();
490    for lit in body {
491        match lit {
492            BodyLiteral::Positive(atom) => {
493                for name in atom.variables() {
494                    bound.insert(name.to_string());
495                }
496            }
497            BodyLiteral::Negated(atom) => {
498                for name in atom.variables() {
499                    if !bound.contains(name) {
500                        return Err(naf_error(format!(
501                            "unbound variable {} in negated atom {}/{} in {}; bind it before not with a positive atom or deterministic is expression, or use '_' for existential positions",
502                            name,
503                            atom.predicate,
504                            atom.arity(),
505                            context
506                        )));
507                    }
508                }
509            }
510            BodyLiteral::IsExpr(is_expr) => {
511                bound.insert(is_expr.target.clone());
512            }
513            BodyLiteral::Epistemic(_) => {}
514            BodyLiteral::Comparison(_) | BodyLiteral::Univ(_) => {}
515        }
516    }
517    Ok(())
518}
519
520fn map_stratification_to_naf_error(err: XlogError) -> XlogError {
521    match err {
522        XlogError::StratificationCycle(cycle) => naf_error(format!(
523            "deterministic not atom must be stratified; cycle through negation or aggregation: {}",
524            cycle.join(" -> ")
525        )),
526        other => other,
527    }
528}
529
530fn naf_error(message: impl Into<String>) -> XlogError {
531    XlogError::Compilation(format!("v0.8.5 naf error: {}", message.into()))
532}
533
534/// Convenience function to compile source in one call.
535///
536/// This creates a short-lived compiler and compiles the source.
537/// For multiple compilations, prefer creating a `Compiler` instance directly.
538///
539/// # Example
540///
541/// ```ignore
542/// use xlog_logic::compile::compile;
543///
544/// let plan = compile("edge(1, 2). reach(X, Y) :- edge(X, Y).")?;
545/// ```
546pub fn compile(source: &str) -> Result<ExecutionPlan> {
547    let mut compiler = Compiler::new();
548    compiler.compile(source)
549}
550
551/// Load and validate modules for a source file.
552///
553/// This function:
554/// 1. Determines the module path from the entry file name
555/// 2. Loads the entry module and all its dependencies
556/// 3. Validates imports (checks for conflicts, private predicates, etc.)
557///
558/// # Arguments
559///
560/// * `entry_file` - Path to the main .xlog file
561/// * `search_paths` - Additional directories to search for modules
562///
563/// # Returns
564///
565/// The loaded module resolver with all dependencies resolved, or an error
566/// if module resolution fails.
567pub fn load_modules(
568    entry_file: &Path,
569    search_paths: Vec<PathBuf>,
570) -> std::result::Result<ModuleResolver, ModuleError> {
571    let mut resolver = ModuleResolver::new(search_paths);
572
573    // Determine base directory and module path
574    let base_dir = entry_file.parent().unwrap_or(Path::new("."));
575    let module_name = entry_file
576        .file_stem()
577        .and_then(|s| s.to_str())
578        .unwrap_or("main");
579
580    // Load entry module (recursively loads dependencies)
581    resolver.load_module(base_dir, &[module_name.to_string()])?;
582
583    Ok(resolver)
584}
585
586#[cfg(test)]
587mod tests {
588    use super::*;
589    use xlog_core::ScalarType;
590    use xlog_ir::RirNode;
591    use xlog_stats::ColumnStats;
592    use xlog_stats::RelationStats;
593    use xlog_stats::StatsManager;
594
595    #[test]
596    fn test_compiler_new() {
597        let compiler = Compiler::new();
598        // Just verify it can be created
599        drop(compiler);
600    }
601
602    #[test]
603    fn test_compile_fact() {
604        let mut compiler = Compiler::new();
605        let result = compiler.compile("edge(1, 2).");
606        assert!(result.is_ok(), "Failed to compile fact: {:?}", result.err());
607    }
608
609    #[test]
610    fn test_compile_simple_rule() {
611        let mut compiler = Compiler::new();
612        let result = compiler.compile(
613            r#"
614            edge(1, 2).
615            reach(X, Y) :- edge(X, Y).
616        "#,
617        );
618        assert!(
619            result.is_ok(),
620            "Failed to compile simple rule: {:?}",
621            result.err()
622        );
623
624        let plan = result.unwrap();
625        assert!(!plan.sccs.is_empty(), "Expected at least one SCC");
626    }
627
628    #[test]
629    fn test_compile_transitive_closure() {
630        let mut compiler = Compiler::new();
631        let result = compiler.compile(
632            r#"
633            edge(1, 2).
634            edge(2, 3).
635            edge(3, 4).
636            reach(X, Y) :- edge(X, Y).
637            reach(X, Z) :- reach(X, Y), edge(Y, Z).
638        "#,
639        );
640        assert!(result.is_ok(), "Failed to compile TC: {:?}", result.err());
641
642        let plan = result.unwrap();
643        // Should have SCCs for edge and reach
644        assert!(!plan.sccs.is_empty());
645    }
646
647    #[test]
648    fn test_compile_with_negation() {
649        let mut compiler = Compiler::new();
650        let result = compiler.compile(
651            r#"
652            node(1).
653            node(2).
654            node(3).
655            edge(1, 2).
656            isolated(X) :- node(X), not edge(X, _).
657        "#,
658        );
659        assert!(
660            result.is_ok(),
661            "Failed to compile with negation: {:?}",
662            result.err()
663        );
664    }
665
666    #[test]
667    fn test_compile_with_comparison() {
668        let mut compiler = Compiler::new();
669        let result = compiler.compile(
670            r#"
671            value(1).
672            value(5).
673            value(10).
674            value(15).
675            small(X) :- value(X), X < 10.
676        "#,
677        );
678        assert!(
679            result.is_ok(),
680            "Failed to compile with comparison: {:?}",
681            result.err()
682        );
683    }
684
685    #[test]
686    fn test_schema_infers_from_rule_body_types() {
687        let mut compiler = Compiler::new();
688        let result = compiler.compile(
689            r#"
690            edge(1, 2).
691            edge(2, 3).
692            reach(X, Y) :- edge(X, Y).
693        "#,
694        );
695        assert!(
696            result.is_ok(),
697            "Failed to compile rule for schema inference: {:?}",
698            result.err()
699        );
700
701        let schema = compiler
702            .schemas()
703            .get("reach")
704            .expect("missing reach schema");
705        assert_eq!(
706            schema.column_type(0),
707            Some(ScalarType::U32),
708            "reach column 0 should match edge column type"
709        );
710        assert_eq!(
711            schema.column_type(1),
712            Some(ScalarType::U32),
713            "reach column 1 should match edge column type"
714        );
715    }
716
717    #[test]
718    fn test_compile_unstratifiable_fails() {
719        let mut compiler = Compiler::new();
720        let result = compiler.compile(
721            r#"
722            p :- not q.
723            q :- not p.
724        "#,
725        );
726        assert!(result.is_err(), "Should fail with stratification cycle");
727    }
728
729    #[test]
730    fn test_compile_syntax_error_fails() {
731        let mut compiler = Compiler::new();
732        let result = compiler.compile("edge(1, 2"); // Missing closing paren and period
733        assert!(result.is_err(), "Should fail with syntax error");
734    }
735
736    #[test]
737    fn test_compile_convenience_function() {
738        let result = compile("edge(1, 2).");
739        assert!(
740            result.is_ok(),
741            "Convenience compile failed: {:?}",
742            result.err()
743        );
744    }
745
746    #[test]
747    fn test_compiler_reset() {
748        let mut compiler = Compiler::new();
749
750        // First compilation
751        let result1 = compiler.compile("edge(1, 2).");
752        assert!(result1.is_ok());
753
754        // Reset and compile again
755        compiler.reset();
756        let result2 = compiler.compile("node(1). node(2).");
757        assert!(result2.is_ok());
758    }
759
760    #[test]
761    fn test_compile_with_pred_decl() {
762        let mut compiler = Compiler::new();
763        let result = compiler.compile(
764            r#"
765            pred edge(u32, u32).
766            edge(1, 2).
767            edge(2, 3).
768            reach(X, Y) :- edge(X, Y).
769        "#,
770        );
771        assert!(
772            result.is_ok(),
773            "Failed to compile with pred decl: {:?}",
774            result.err()
775        );
776    }
777
778    #[test]
779    fn test_compile_multi_stratum() {
780        let mut compiler = Compiler::new();
781        let result = compiler.compile(
782            r#"
783            // Base facts
784            edge(1, 2).
785            edge(2, 3).
786            edge(3, 1).
787
788            // Stratum 0: edge (base)
789            // Stratum 1: reach (depends on edge, recursive)
790            reach(X, Y) :- edge(X, Y).
791            reach(X, Z) :- reach(X, Y), edge(Y, Z).
792
793            // Stratum 2: non_reach (negates reach)
794            all_pairs(X, Y) :- edge(X, Z), edge(Y, W).
795            non_reach(X, Y) :- all_pairs(X, Y), not reach(X, Y).
796        "#,
797        );
798        assert!(
799            result.is_ok(),
800            "Failed to compile multi-stratum: {:?}",
801            result.err()
802        );
803
804        let plan = result.unwrap();
805        // Should have multiple strata
806        assert!(!plan.strata.is_empty(), "Expected multiple strata");
807    }
808
809    #[test]
810    fn test_compile_aggregation() {
811        let mut compiler = Compiler::new();
812        let result = compiler.compile(
813            r#"
814            edge(1, 2).
815            edge(1, 3).
816            edge(2, 3).
817            out_degree(X, count(Y)) :- edge(X, Y).
818        "#,
819        );
820        assert!(
821            result.is_ok(),
822            "Failed to compile with aggregation: {:?}",
823            result.err()
824        );
825
826        let plan = result.unwrap();
827        let out_degree_rules: Vec<_> = plan
828            .rules_by_scc
829            .iter()
830            .flatten()
831            .filter(|r| r.head == "out_degree")
832            .collect();
833        assert_eq!(out_degree_rules.len(), 1, "Expected one out_degree rule");
834
835        // Aggregation lowering should produce a GroupBy node (wrapped in a Project to match head order).
836        let body = &out_degree_rules[0].body;
837        match body {
838            RirNode::Project { input, .. } => {
839                assert!(
840                    matches!(input.as_ref(), RirNode::GroupBy { .. }),
841                    "Expected Project(GroupBy(..)), got {:?}",
842                    input
843                );
844            }
845            other => panic!("Expected Project(GroupBy(..)), got {:?}", other),
846        }
847    }
848
849    #[test]
850    fn test_compile_with_stats_snapshot() {
851        let mut compiler = Compiler::new();
852        let source = r#"
853            edge(1, 2).
854            edge(2, 3).
855            reach(X, Y) :- edge(X, Y).
856        "#;
857
858        let _ = compiler.compile(source).expect("Initial compile failed");
859        let edge_id = *compiler.rel_ids().get("edge").expect("edge rel_id missing");
860
861        let mut mgr = StatsManager::new();
862        mgr.register_relation(edge_id);
863        mgr.update_cardinality(edge_id, 42);
864        let snapshot = mgr.snapshot();
865
866        let plan = compiler
867            .compile_with_stats_snapshot(source, Some(&snapshot))
868            .expect("Compile with snapshot failed");
869        assert!(!plan.sccs.is_empty());
870    }
871
/// A snapshot whose statistics are keyed by relation *name* must still drive
/// join ordering: the small `edge` relation is expected on the right side of
/// the join and the large `foo` relation on the left, both for the node itself
/// and, when the join was wrapped in a `ChainJoin`, for its captured fallback.
#[test]
fn test_compile_with_named_stats_snapshot_reorders_joins() {
    /// Skips any chain of `Project` wrappers and returns the first other node.
    fn strip_projections(mut node: &RirNode) -> &RirNode {
        while let RirNode::Project { input, .. } = node {
            node = input;
        }
        node
    }

    const PROGRAM: &str = r#"
            foo(1).
            edge(1).
            out(X) :- edge(X), foo(X).
        "#;

    // The snapshot carries its own RelIds, which need not match the ones the
    // compiler assigns for this program; `rel_names` supplies the mapping:
    // RelId(0) -> edge (small), RelId(1) -> foo (big).
    let small_edge = {
        let mut stats = RelationStats::new(RelId(0));
        stats.update_cardinality(10);
        stats
    };
    let large_foo = {
        let mut stats = RelationStats::new(RelId(1));
        stats.update_cardinality(10_000);
        stats
    };
    let snapshot = StatsSnapshot {
        relations: vec![small_edge, large_foo],
        join_selectivities: Vec::new(),
        rel_names: vec![
            (RelId(0), String::from("edge")),
            (RelId(1), String::from("foo")),
        ],
    };

    let mut compiler = Compiler::new();
    let compiled = compiler.compile_with_stats_snapshot(PROGRAM, Some(&snapshot));
    let plan = compiled.expect("Compile with named snapshot failed");

    let rel_ids = compiler.rel_ids();
    let foo_id = *rel_ids.get("foo").expect("foo rel_id missing");
    let edge_id = *rel_ids.get("edge").expect("edge rel_id missing");

    let out_rule = plan
        .rules_by_scc
        .iter()
        .flatten()
        .find(|candidate| candidate.head == "out")
        .expect("out rule missing");

    // Either a plain Join or a ChainJoin (which additionally carries a
    // fallback plan) is acceptable at the root once projections are peeled.
    let (left, right, chain_fallback) = match strip_projections(&out_rule.body) {
        RirNode::ChainJoin {
            left,
            right,
            fallback,
            ..
        } => (left, right, Some(fallback)),
        RirNode::Join { left, right, .. } => (left, right, None),
        other => panic!("Expected Join node, got {:?}", other),
    };

    // Prefer building on the smaller relation (right/build side).
    assert!(matches!(**left, RirNode::Scan { rel } if rel == foo_id));
    assert!(matches!(**right, RirNode::Scan { rel } if rel == edge_id));

    // W63 wraps eligible two-atom joins after stats-aware ordering. The chain
    // node and its captured fallback must agree on the build-side choice.
    if let Some(fallback) = chain_fallback {
        match strip_projections(fallback.as_ref()) {
            RirNode::Join { left, right, .. } => {
                assert!(matches!(**left, RirNode::Scan { rel } if rel == foo_id));
                assert!(matches!(**right, RirNode::Scan { rel } if rel == edge_id));
            }
            other => panic!("Expected ChainJoin fallback Join node, got {:?}", other),
        }
    }
}
950
/// Program text shared by the helper-split tests: one fact for each of the six
/// binary relations `ab`, `bc`, `cd`, `de`, `ef`, `af`, plus a single `out`
/// rule that joins all six (chained A-B-C-D-E-F and closed by `af(A, F)`).
fn helper_split_source() -> &'static str {
    const HELPER_SPLIT_PROGRAM: &str = r#"
            ab(0, 0). bc(0, 0). cd(0, 0). de(0, 0). ef(0, 0). af(0, 0).
            out(A, B, C, D, F) :-
                ab(A, B),
                bc(B, C),
                cd(C, D),
                de(D, E),
                ef(E, F),
                af(A, F).
        "#;
    HELPER_SPLIT_PROGRAM
}
963
964    fn helper_split_snapshot(distinct_d: u64) -> StatsSnapshot {
965        let mut snapshot_relations = Vec::new();
966        for (idx, name) in ["ab", "bc", "cd", "de", "ef", "af"].iter().enumerate() {
967            let mut rel_stats = RelationStats::new(RelId(idx as u32));
968            rel_stats.update_cardinality(8192);
969            if *name == "de" {
970                let mut d_col = ColumnStats::new(0, ScalarType::U32);
971                d_col.update_distinct(distinct_d);
972                rel_stats.add_column(d_col);
973            }
974            snapshot_relations.push(rel_stats);
975        }
976        StatsSnapshot {
977            relations: snapshot_relations,
978            join_selectivities: Vec::new(),
979            rel_names: ["ab", "bc", "cd", "de", "ef", "af"]
980                .iter()
981                .enumerate()
982                .map(|(idx, name)| (RelId(idx as u32), (*name).to_string()))
983                .collect(),
984        }
985    }
986
/// With a distinct count of 1 on the `de` column statistics, compilation is
/// expected to allocate a `__w37_helper_*` relation, emit exactly one rule for
/// it whose body is a `ChainJoin`, and make the `out` rule scan that helper.
#[test]
fn test_compile_with_named_stats_snapshot_creates_helper_relation() {
    let snapshot = helper_split_snapshot(1);
    let mut compiler = Compiler::new();
    let compiled = compiler.compile_with_stats_snapshot(helper_split_source(), Some(&snapshot));
    let plan = compiled.expect("compile with helper stats");

    // (name, RelId) of the first relation carrying the helper prefix.
    let helper = compiler
        .rel_ids()
        .iter()
        .find(|(name, _)| name.starts_with("__w37_helper_"))
        .map(|(name, rel)| (name.clone(), *rel))
        .expect("helper relation allocated");

    // Gather every rule that defines the helper; there must be exactly one.
    let helper_rules: Vec<_> = plan
        .rules_by_scc
        .iter()
        .flatten()
        .filter(|rule| rule.head == helper.0)
        .collect();
    assert_eq!(helper_rules.len(), 1);

    let helper_rule = helper_rules.first().expect("helper rule");
    assert!(
        matches!(helper_rule.body, RirNode::ChainJoin { .. }),
        "helper split output should be eligible for W63 ChainJoin promotion"
    );

    // The original rule must now read from the helper relation.
    let out_rule = plan
        .rules_by_scc
        .iter()
        .flatten()
        .find(|rule| rule.head == "out")
        .expect("out rule");
    assert!(contains_scan(&out_rule.body, helper.1));
}
1030
/// With a distinct count of 8192 on the `de` column statistics (equal to the
/// relation cardinality), no `__w37_helper_*` relation may be allocated and
/// the plan must contain exactly one rule for `out`.
#[test]
fn test_compile_with_flat_named_stats_keeps_original_rule() {
    let snapshot = helper_split_snapshot(8192);
    let mut compiler = Compiler::new();
    let compiled = compiler.compile_with_stats_snapshot(helper_split_source(), Some(&snapshot));
    let plan = compiled.expect("compile with flat stats");

    // No helper relation should have been introduced.
    assert!(!compiler
        .rel_ids()
        .keys()
        .any(|name| name.starts_with("__w37_helper_")));

    // Count the rules defining `out` across all SCCs.
    let mut out_rules = 0usize;
    for rule in plan.rules_by_scc.iter().flatten() {
        if rule.head == "out" {
            out_rules += 1;
        }
    }
    assert_eq!(out_rules, 1);
}
1051
1052    fn contains_scan(node: &RirNode, rel: RelId) -> bool {
1053        match node {
1054            RirNode::Scan { rel: scan_rel } => *scan_rel == rel,
1055            RirNode::Join { left, right, .. } | RirNode::ChainJoin { left, right, .. } => {
1056                contains_scan(left, rel) || contains_scan(right, rel)
1057            }
1058            RirNode::Project { input, .. }
1059            | RirNode::Filter { input, .. }
1060            | RirNode::Distinct { input, .. }
1061            | RirNode::GroupBy { input, .. } => contains_scan(input, rel),
1062            RirNode::Union { inputs } => inputs.iter().any(|input| contains_scan(input, rel)),
1063            RirNode::Diff { left, right } => contains_scan(left, rel) || contains_scan(right, rel),
1064            RirNode::Fixpoint {
1065                base, recursive, ..
1066            } => contains_scan(base, rel) || contains_scan(recursive, rel),
1067            RirNode::MultiWayJoin { inputs, .. } => {
1068                inputs.iter().any(|input| contains_scan(input, rel))
1069            }
1070            RirNode::TensorMaskedJoin { rel_index, .. } => {
1071                rel_index.iter().any(|(input_rel, _)| *input_rel == rel)
1072            }
1073            RirNode::Unit => false,
1074        }
1075    }
1076}