1use std::path::{Path, PathBuf};
14
15use xlog_core::{Result, XlogError};
16use xlog_ir::ExecutionPlan;
17use xlog_stats::{StatsManager, StatsSnapshot};
18
19use crate::compiler_config::CompilerConfig;
20use crate::list_normalize::normalize_v085_lists;
21use crate::lower::Lowerer;
22use crate::magic_sets::rewrite_v085_magic_sets;
23use crate::meta_normalize::normalize_v085_meta;
24use crate::module::ModuleError;
25use crate::optimizer::Optimizer;
26use crate::parser::parse_program;
27use crate::resolver::ModuleResolver;
28use crate::stratify::stratify;
29use crate::{BodyLiteral, Program, Query, Rule as AstRule, Term};
30
31pub struct Compiler {
47 lowerer: Lowerer,
48}
49
50use std::collections::{HashMap, HashSet};
51use std::sync::Arc;
52use xlog_core::{RelId, Schema};
53
54impl Default for Compiler {
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
60impl Compiler {
61 pub fn new() -> Self {
63 Self {
64 lowerer: Lowerer::new(),
65 }
66 }
67
68 pub fn set_max_active_rules(&mut self, max: usize) {
70 self.lowerer.set_max_active_rules(max);
71 }
72
73 pub fn compile(&mut self, source: &str) -> Result<ExecutionPlan> {
108 self.compile_with_stats_snapshot(source, None)
109 }
110
111 pub fn compile_with_stats_snapshot(
118 &mut self,
119 source: &str,
120 stats_snapshot: Option<&StatsSnapshot>,
121 ) -> Result<ExecutionPlan> {
122 self.compile_with_config_and_stats_snapshot(
123 source,
124 &CompilerConfig::default(),
125 stats_snapshot,
126 )
127 }
128
129 pub fn compile_with_config_and_stats_snapshot(
136 &mut self,
137 source: &str,
138 config: &CompilerConfig,
139 stats_snapshot: Option<&StatsSnapshot>,
140 ) -> Result<ExecutionPlan> {
141 let program = parse_program(source)?;
142 self.compile_program_with_config_and_stats_snapshot(&program, config, stats_snapshot)
143 }
144
145 pub fn compile_program(&mut self, program: &Program) -> Result<ExecutionPlan> {
150 self.compile_program_with_stats_snapshot(program, None)
151 }
152
153 pub fn compile_program_with_stats_snapshot(
159 &mut self,
160 program: &Program,
161 stats_snapshot: Option<&StatsSnapshot>,
162 ) -> Result<ExecutionPlan> {
163 self.compile_program_with_config_and_stats_snapshot(
164 program,
165 &CompilerConfig::default(),
166 stats_snapshot,
167 )
168 }
169
170 pub fn compile_program_with_config_and_stats_snapshot(
177 &mut self,
178 program: &Program,
179 config: &CompilerConfig,
180 stats_snapshot: Option<&StatsSnapshot>,
181 ) -> Result<ExecutionPlan> {
182 let program = desugar_queries_and_constraints(program);
183 let program = normalize_v085_meta(&program)?;
184 let program = normalize_v085_lists(&program)?;
185 let program = rewrite_v085_magic_sets(&program)?.program;
186 validate_v085_naf_safety(&program)?;
187
188 let strata = stratify(&program).map_err(map_stratification_to_naf_error)?;
190
191 let strata_preds: Vec<Vec<String>> = strata.into_iter().map(|s| s.predicates).collect();
193
194 self.lowerer.set_strata(strata_preds);
196
197 let mut cardinality_hints: HashMap<String, u64> = HashMap::new();
200 if let Some(snapshot) = stats_snapshot {
201 if !snapshot.rel_names.is_empty() {
202 let rel_name_by_id: HashMap<RelId, &str> = snapshot
203 .rel_names
204 .iter()
205 .map(|(id, name)| (*id, name.as_str()))
206 .collect();
207 for rel in &snapshot.relations {
208 if let Some(name) = rel_name_by_id.get(&rel.rel_id) {
209 cardinality_hints.insert((*name).to_string(), rel.cardinality);
210 }
211 }
212 }
213 }
214 self.lowerer.set_cardinality_hints(cardinality_hints);
215
216 let mut plan = self.lowerer.lower_program(&program)?;
217
218 let mut mgr = StatsManager::new();
223 let mut fact_counts: HashMap<String, u64> = HashMap::new();
224 for fact in program.facts() {
225 *fact_counts.entry(fact.head.predicate.clone()).or_insert(0) += 1;
226 }
227
228 for (pred, rel_id) in self.lowerer.rel_ids() {
229 mgr.register_relation(*rel_id);
230 let rows = fact_counts.get(pred).copied().unwrap_or(0);
231 if rows > 0 {
232 mgr.update_cardinality(*rel_id, rows);
233 if let Some(schema) = self.lowerer.schemas().get(pred) {
234 mgr.update_byte_size(*rel_id, rows * schema.row_size_bytes() as u64);
235 }
236 }
237 }
238
239 if let Some(snapshot) = stats_snapshot {
240 if snapshot.rel_names.is_empty() {
241 mgr.merge_snapshot(snapshot);
242 } else {
243 let rel_name_by_id: HashMap<RelId, &str> = snapshot
244 .rel_names
245 .iter()
246 .map(|(id, name)| (*id, name.as_str()))
247 .collect();
248
249 for rel in &snapshot.relations {
250 let Some(pred) = rel_name_by_id.get(&rel.rel_id) else {
251 continue;
252 };
253 let Some(rel_id) = self.lowerer.rel_ids().get(*pred) else {
254 continue;
255 };
256
257 let mut remapped = rel.clone();
258 remapped.rel_id = *rel_id;
259
260 if let Some(schema) = self.lowerer.schemas().get(*pred) {
261 remapped.column_stats.retain(|col| {
262 col.col_idx < schema.arity()
263 && schema.column_type(col.col_idx) == Some(col.dtype)
264 });
265 } else {
266 remapped.column_stats.clear();
267 }
268
269 mgr.register_relation(*rel_id);
270 if let Some(stats) = mgr.get_relation_stats_mut(*rel_id) {
271 *stats = remapped;
272 }
273 }
274
275 for js in &snapshot.join_selectivities {
276 if js.left_keys.len() != js.right_keys.len() {
277 continue;
278 }
279
280 let Some(left_pred) = rel_name_by_id.get(&js.left_rel) else {
281 continue;
282 };
283 let Some(right_pred) = rel_name_by_id.get(&js.right_rel) else {
284 continue;
285 };
286 let Some(&left_id) = self.lowerer.rel_ids().get(*left_pred) else {
287 continue;
288 };
289 let Some(&right_id) = self.lowerer.rel_ids().get(*right_pred) else {
290 continue;
291 };
292
293 let Some(left_schema) = self.lowerer.schemas().get(*left_pred) else {
294 continue;
295 };
296 let Some(right_schema) = self.lowerer.schemas().get(*right_pred) else {
297 continue;
298 };
299 if js.left_keys.iter().any(|&k| k >= left_schema.arity())
300 || js.right_keys.iter().any(|&k| k >= right_schema.arity())
301 {
302 continue;
303 }
304
305 mgr.set_join_selectivity(
306 left_id,
307 right_id,
308 js.left_keys.clone(),
309 js.right_keys.clone(),
310 js.selectivity,
311 );
312 }
313 }
314 }
315
316 let schemas_by_rel_id: HashMap<RelId, Schema> = self
318 .lowerer
319 .rel_ids()
320 .iter()
321 .filter_map(|(pred, rel_id)| {
322 self.lowerer
323 .schemas()
324 .get(pred)
325 .map(|schema| (*rel_id, schema.clone()))
326 })
327 .collect();
328
329 let stats_arc = Arc::new(mgr);
330
331 crate::optimizer::helper_split_pass::run(
332 &mut plan,
333 &schemas_by_rel_id,
334 &stats_arc,
335 |schema| self.lowerer.create_helper_relation(schema),
336 );
337
338 let schemas_by_rel_id: HashMap<RelId, Schema> = self
339 .lowerer
340 .rel_ids()
341 .iter()
342 .filter_map(|(pred, rel_id)| {
343 self.lowerer
344 .schemas()
345 .get(pred)
346 .map(|schema| (*rel_id, schema.clone()))
347 })
348 .collect();
349
350 let mut optimizer = Optimizer::new(Arc::clone(&stats_arc));
351 optimizer.set_schemas(schemas_by_rel_id);
352 for rules in &mut plan.rules_by_scc {
353 for rule in rules {
354 rule.body = optimizer.optimize(rule.body.clone());
355 }
356 }
357
358 crate::optimizer::selectivity_pass::run(&mut plan, &stats_arc, self.lowerer.rel_ids());
367
368 crate::promote::promote_multiway(&mut plan, self.lowerer.rel_ids(), &stats_arc, config);
383
384 let schemas_by_rel_id: HashMap<RelId, Schema> = self
385 .lowerer
386 .rel_ids()
387 .iter()
388 .filter_map(|(pred, rel_id)| {
389 self.lowerer
390 .schemas()
391 .get(pred)
392 .map(|schema| (*rel_id, schema.clone()))
393 })
394 .collect();
395
396 crate::optimizer::helper_split_pass::run_kclique_specs(
397 &mut plan,
398 &schemas_by_rel_id,
399 |schema| self.lowerer.create_helper_relation(schema),
400 );
401
402 Ok(plan)
403 }
404
405 pub fn reset(&mut self) {
410 self.lowerer = Lowerer::new();
411 }
412
413 pub fn rel_ids(&self) -> &HashMap<String, RelId> {
418 self.lowerer.rel_ids()
419 }
420
421 pub fn schemas(&self) -> &HashMap<String, Schema> {
425 self.lowerer.schemas()
426 }
427}
428
429fn desugar_queries_and_constraints(program: &Program) -> Program {
430 let mut out = program.clone();
431
432 for (i, constraint) in program.constraints.iter().enumerate() {
434 let pred = format!("__xlog_constraint_{}", i);
435 out.rules.push(AstRule {
436 head: crate::ast::Atom {
437 predicate: pred,
438 terms: vec![Term::Integer(1)],
439 },
440 body: constraint.body.clone(),
441 });
442 }
443
444 for (i, Query { atom }) in program.queries.iter().enumerate() {
446 let pred = format!("__xlog_query_{}", i);
447
448 let mut head_terms: Vec<Term> = Vec::new();
449 let mut seen: std::collections::HashSet<&str> = std::collections::HashSet::new();
450
451 for term in &atom.terms {
452 for name in term.variables() {
453 if seen.insert(name) {
454 head_terms.push(Term::Variable(name.to_string()));
455 }
456 }
457 }
458
459 if head_terms.is_empty() {
460 head_terms.push(Term::Integer(1));
461 }
462
463 out.rules.push(AstRule {
464 head: crate::ast::Atom {
465 predicate: pred,
466 terms: head_terms,
467 },
468 body: vec![BodyLiteral::Positive(atom.clone())],
469 });
470 }
471
472 out
473}
474
475fn validate_v085_naf_safety(program: &Program) -> Result<()> {
476 for rule in &program.rules {
477 validate_body_naf_safety(&rule.body, &format!("rule {}", rule.head.predicate))?;
478 }
479 for (idx, constraint) in program.constraints.iter().enumerate() {
480 validate_body_naf_safety(&constraint.body, &format!("constraint {}", idx))?;
481 }
482 for (idx, learnable) in program.learnable_rules.iter().enumerate() {
483 validate_body_naf_safety(&learnable.body, &format!("learnable rule {}", idx))?;
484 }
485 Ok(())
486}
487
488fn validate_body_naf_safety(body: &[BodyLiteral], context: &str) -> Result<()> {
489 let mut bound: HashSet<String> = HashSet::new();
490 for lit in body {
491 match lit {
492 BodyLiteral::Positive(atom) => {
493 for name in atom.variables() {
494 bound.insert(name.to_string());
495 }
496 }
497 BodyLiteral::Negated(atom) => {
498 for name in atom.variables() {
499 if !bound.contains(name) {
500 return Err(naf_error(format!(
501 "unbound variable {} in negated atom {}/{} in {}; bind it before not with a positive atom or deterministic is expression, or use '_' for existential positions",
502 name,
503 atom.predicate,
504 atom.arity(),
505 context
506 )));
507 }
508 }
509 }
510 BodyLiteral::IsExpr(is_expr) => {
511 bound.insert(is_expr.target.clone());
512 }
513 BodyLiteral::Epistemic(_) => {}
514 BodyLiteral::Comparison(_) | BodyLiteral::Univ(_) => {}
515 }
516 }
517 Ok(())
518}
519
520fn map_stratification_to_naf_error(err: XlogError) -> XlogError {
521 match err {
522 XlogError::StratificationCycle(cycle) => naf_error(format!(
523 "deterministic not atom must be stratified; cycle through negation or aggregation: {}",
524 cycle.join(" -> ")
525 )),
526 other => other,
527 }
528}
529
530fn naf_error(message: impl Into<String>) -> XlogError {
531 XlogError::Compilation(format!("v0.8.5 naf error: {}", message.into()))
532}
533
534pub fn compile(source: &str) -> Result<ExecutionPlan> {
547 let mut compiler = Compiler::new();
548 compiler.compile(source)
549}
550
551pub fn load_modules(
568 entry_file: &Path,
569 search_paths: Vec<PathBuf>,
570) -> std::result::Result<ModuleResolver, ModuleError> {
571 let mut resolver = ModuleResolver::new(search_paths);
572
573 let base_dir = entry_file.parent().unwrap_or(Path::new("."));
575 let module_name = entry_file
576 .file_stem()
577 .and_then(|s| s.to_str())
578 .unwrap_or("main");
579
580 resolver.load_module(base_dir, &[module_name.to_string()])?;
582
583 Ok(resolver)
584}
585
#[cfg(test)]
mod tests {
    use super::*;
    use xlog_core::ScalarType;
    use xlog_ir::RirNode;
    use xlog_stats::ColumnStats;
    use xlog_stats::RelationStats;
    use xlog_stats::StatsManager;

    #[test]
    fn test_compiler_new() {
        let compiler = Compiler::new();
        drop(compiler);
    }

    #[test]
    fn test_compiler_default_compiles() {
        let mut compiler = Compiler::default();
        let result = compiler.compile("edge(1, 2).");
        assert!(
            result.is_ok(),
            "Default compiler failed to compile fact: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_compile_fact() {
        let mut compiler = Compiler::new();
        let result = compiler.compile("edge(1, 2).");
        assert!(result.is_ok(), "Failed to compile fact: {:?}", result.err());
    }

    #[test]
    fn test_compile_simple_rule() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            edge(1, 2).
            reach(X, Y) :- edge(X, Y).
        "#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile simple rule: {:?}",
            result.err()
        );

        let plan = result.unwrap();
        assert!(!plan.sccs.is_empty(), "Expected at least one SCC");
    }

    #[test]
    fn test_compile_transitive_closure() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            edge(1, 2).
            edge(2, 3).
            edge(3, 4).
            reach(X, Y) :- edge(X, Y).
            reach(X, Z) :- reach(X, Y), edge(Y, Z).
        "#,
        );
        assert!(result.is_ok(), "Failed to compile TC: {:?}", result.err());

        let plan = result.unwrap();
        assert!(!plan.sccs.is_empty());
    }

    #[test]
    fn test_compile_with_negation() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            node(1).
            node(2).
            node(3).
            edge(1, 2).
            isolated(X) :- node(X), not edge(X, _).
        "#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile with negation: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_compile_unbound_negated_variable_fails() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            q(1).
            p(X) :- not q(X).
        "#,
        );
        assert!(
            result.is_err(),
            "Negated atom with an unbound variable should be rejected"
        );
    }

    #[test]
    fn test_compile_with_comparison() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            value(1).
            value(5).
            value(10).
            value(15).
            small(X) :- value(X), X < 10.
        "#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile with comparison: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_schema_infers_from_rule_body_types() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            edge(1, 2).
            edge(2, 3).
            reach(X, Y) :- edge(X, Y).
        "#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile rule for schema inference: {:?}",
            result.err()
        );

        let schema = compiler
            .schemas()
            .get("reach")
            .expect("missing reach schema");
        assert_eq!(
            schema.column_type(0),
            Some(ScalarType::U32),
            "reach column 0 should match edge column type"
        );
        assert_eq!(
            schema.column_type(1),
            Some(ScalarType::U32),
            "reach column 1 should match edge column type"
        );
    }

    #[test]
    fn test_compile_unstratifiable_fails() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            p :- not q.
            q :- not p.
        "#,
        );
        assert!(result.is_err(), "Should fail with stratification cycle");
    }

    #[test]
    fn test_compile_syntax_error_fails() {
        let mut compiler = Compiler::new();
        let result = compiler.compile("edge(1, 2");
        assert!(result.is_err(), "Should fail with syntax error");
    }

    #[test]
    fn test_compile_convenience_function() {
        let result = compile("edge(1, 2).");
        assert!(
            result.is_ok(),
            "Convenience compile failed: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_compiler_reset() {
        let mut compiler = Compiler::new();

        let result1 = compiler.compile("edge(1, 2).");
        assert!(result1.is_ok());

        compiler.reset();
        let result2 = compiler.compile("node(1). node(2).");
        assert!(result2.is_ok());
    }

    #[test]
    fn test_compile_with_pred_decl() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            pred edge(u32, u32).
            edge(1, 2).
            edge(2, 3).
            reach(X, Y) :- edge(X, Y).
        "#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile with pred decl: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_compile_multi_stratum() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            // Base facts
            edge(1, 2).
            edge(2, 3).
            edge(3, 1).

            // Stratum 0: edge (base)
            // Stratum 1: reach (depends on edge, recursive)
            reach(X, Y) :- edge(X, Y).
            reach(X, Z) :- reach(X, Y), edge(Y, Z).

            // Stratum 2: non_reach (negates reach)
            all_pairs(X, Y) :- edge(X, Z), edge(Y, W).
            non_reach(X, Y) :- all_pairs(X, Y), not reach(X, Y).
        "#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile multi-stratum: {:?}",
            result.err()
        );

        let plan = result.unwrap();
        // Only non-emptiness is asserted here; the exact stratum count is not pinned.
        assert!(!plan.strata.is_empty(), "Expected at least one stratum");
    }

    #[test]
    fn test_compile_aggregation() {
        let mut compiler = Compiler::new();
        let result = compiler.compile(
            r#"
            edge(1, 2).
            edge(1, 3).
            edge(2, 3).
            out_degree(X, count(Y)) :- edge(X, Y).
        "#,
        );
        assert!(
            result.is_ok(),
            "Failed to compile with aggregation: {:?}",
            result.err()
        );

        let plan = result.unwrap();
        let out_degree_rules: Vec<_> = plan
            .rules_by_scc
            .iter()
            .flatten()
            .filter(|r| r.head == "out_degree")
            .collect();
        assert_eq!(out_degree_rules.len(), 1, "Expected one out_degree rule");

        let body = &out_degree_rules[0].body;
        match body {
            RirNode::Project { input, .. } => {
                assert!(
                    matches!(input.as_ref(), RirNode::GroupBy { .. }),
                    "Expected Project(GroupBy(..)), got {:?}",
                    input
                );
            }
            other => panic!("Expected Project(GroupBy(..)), got {:?}", other),
        }
    }

    #[test]
    fn test_compile_with_stats_snapshot() {
        let mut compiler = Compiler::new();
        let source = r#"
            edge(1, 2).
            edge(2, 3).
            reach(X, Y) :- edge(X, Y).
        "#;

        let _ = compiler.compile(source).expect("Initial compile failed");
        let edge_id = *compiler.rel_ids().get("edge").expect("edge rel_id missing");

        let mut mgr = StatsManager::new();
        mgr.register_relation(edge_id);
        mgr.update_cardinality(edge_id, 42);
        let snapshot = mgr.snapshot();

        let plan = compiler
            .compile_with_stats_snapshot(source, Some(&snapshot))
            .expect("Compile with snapshot failed");
        assert!(!plan.sccs.is_empty());
    }

    #[test]
    fn test_compile_with_named_stats_snapshot_reorders_joins() {
        let mut compiler = Compiler::new();
        let source = r#"
            foo(1).
            edge(1).
            out(X) :- edge(X), foo(X).
        "#;

        let mut edge_stats = RelationStats::new(RelId(0));
        edge_stats.update_cardinality(10);
        let mut foo_stats = RelationStats::new(RelId(1));
        foo_stats.update_cardinality(10_000);

        let snapshot = StatsSnapshot {
            relations: vec![edge_stats, foo_stats],
            join_selectivities: Vec::new(),
            rel_names: vec![
                (RelId(0), "edge".to_string()),
                (RelId(1), "foo".to_string()),
            ],
        };

        let plan = compiler
            .compile_with_stats_snapshot(source, Some(&snapshot))
            .expect("Compile with named snapshot failed");

        let foo_id = *compiler.rel_ids().get("foo").expect("foo rel_id missing");
        let edge_id = *compiler.rel_ids().get("edge").expect("edge rel_id missing");

        let out_rule = plan
            .rules_by_scc
            .iter()
            .flatten()
            .find(|r| r.head == "out")
            .expect("out rule missing");

        let mut node = &out_rule.body;
        while let RirNode::Project { input, .. } = node {
            node = input;
        }

        match node {
            RirNode::ChainJoin {
                left,
                right,
                fallback,
                ..
            } => {
                assert!(matches!(**left, RirNode::Scan { rel } if rel == foo_id));
                assert!(matches!(**right, RirNode::Scan { rel } if rel == edge_id));

                let mut fallback_node = fallback.as_ref();
                while let RirNode::Project { input, .. } = fallback_node {
                    fallback_node = input;
                }
                match fallback_node {
                    RirNode::Join { left, right, .. } => {
                        assert!(matches!(**left, RirNode::Scan { rel } if rel == foo_id));
                        assert!(matches!(**right, RirNode::Scan { rel } if rel == edge_id));
                    }
                    other => panic!("Expected ChainJoin fallback Join node, got {:?}", other),
                }
            }
            RirNode::Join { left, right, .. } => {
                assert!(matches!(**left, RirNode::Scan { rel } if rel == foo_id));
                assert!(matches!(**right, RirNode::Scan { rel } if rel == edge_id));
            }
            other => panic!("Expected Join node, got {:?}", other),
        }
    }

    fn helper_split_source() -> &'static str {
        r#"
            ab(0, 0). bc(0, 0). cd(0, 0). de(0, 0). ef(0, 0). af(0, 0).
            out(A, B, C, D, F) :-
                ab(A, B),
                bc(B, C),
                cd(C, D),
                de(D, E),
                ef(E, F),
                af(A, F).
        "#
    }

    fn helper_split_snapshot(distinct_d: u64) -> StatsSnapshot {
        let mut snapshot_relations = Vec::new();
        for (idx, name) in ["ab", "bc", "cd", "de", "ef", "af"].iter().enumerate() {
            let mut rel_stats = RelationStats::new(RelId(idx as u32));
            rel_stats.update_cardinality(8192);
            if *name == "de" {
                let mut d_col = ColumnStats::new(0, ScalarType::U32);
                d_col.update_distinct(distinct_d);
                rel_stats.add_column(d_col);
            }
            snapshot_relations.push(rel_stats);
        }
        StatsSnapshot {
            relations: snapshot_relations,
            join_selectivities: Vec::new(),
            rel_names: ["ab", "bc", "cd", "de", "ef", "af"]
                .iter()
                .enumerate()
                .map(|(idx, name)| (RelId(idx as u32), (*name).to_string()))
                .collect(),
        }
    }

    #[test]
    fn test_compile_with_named_stats_snapshot_creates_helper_relation() {
        let mut compiler = Compiler::new();
        let snapshot = helper_split_snapshot(1);
        let plan = compiler
            .compile_with_stats_snapshot(helper_split_source(), Some(&snapshot))
            .expect("compile with helper stats");
        let helper = compiler
            .rel_ids()
            .iter()
            .find_map(|(name, rel)| {
                name.starts_with("__w37_helper_")
                    .then_some((name.clone(), *rel))
            })
            .expect("helper relation allocated");

        let helper_rule_count = plan
            .rules_by_scc
            .iter()
            .flatten()
            .filter(|rule| rule.head == helper.0)
            .count();
        assert_eq!(helper_rule_count, 1);

        let helper_rule = plan
            .rules_by_scc
            .iter()
            .flatten()
            .find(|rule| rule.head == helper.0)
            .expect("helper rule");
        assert!(
            matches!(helper_rule.body, RirNode::ChainJoin { .. }),
            "helper split output should be eligible for W63 ChainJoin promotion"
        );

        let out_rule = plan
            .rules_by_scc
            .iter()
            .flatten()
            .find(|rule| rule.head == "out")
            .expect("out rule");
        assert!(contains_scan(&out_rule.body, helper.1));
    }

    #[test]
    fn test_compile_with_flat_named_stats_keeps_original_rule() {
        let mut compiler = Compiler::new();
        let snapshot = helper_split_snapshot(8192);
        let plan = compiler
            .compile_with_stats_snapshot(helper_split_source(), Some(&snapshot))
            .expect("compile with flat stats");

        assert!(!compiler
            .rel_ids()
            .keys()
            .any(|name| name.starts_with("__w37_helper_")));
        let out_rules = plan
            .rules_by_scc
            .iter()
            .flatten()
            .filter(|rule| rule.head == "out")
            .count();
        assert_eq!(out_rules, 1);
    }

    /// Returns true if `node` scans relation `rel` anywhere in its subtree.
    fn contains_scan(node: &RirNode, rel: RelId) -> bool {
        match node {
            RirNode::Scan { rel: scan_rel } => *scan_rel == rel,
            RirNode::Join { left, right, .. } | RirNode::ChainJoin { left, right, .. } => {
                contains_scan(left, rel) || contains_scan(right, rel)
            }
            RirNode::Project { input, .. }
            | RirNode::Filter { input, .. }
            | RirNode::Distinct { input, .. }
            | RirNode::GroupBy { input, .. } => contains_scan(input, rel),
            RirNode::Union { inputs } => inputs.iter().any(|input| contains_scan(input, rel)),
            RirNode::Diff { left, right } => contains_scan(left, rel) || contains_scan(right, rel),
            RirNode::Fixpoint {
                base, recursive, ..
            } => contains_scan(base, rel) || contains_scan(recursive, rel),
            RirNode::MultiWayJoin { inputs, .. } => {
                inputs.iter().any(|input| contains_scan(input, rel))
            }
            RirNode::TensorMaskedJoin { rel_index, .. } => {
                rel_index.iter().any(|(input_rel, _)| *input_rel == rel)
            }
            RirNode::Unit => false,
        }
    }
}
1076}