1use std::collections::{HashMap, HashSet};
14
15use xlog_core::{symbol, AggOp as CoreAggOp, RelId, Result, ScalarType, Schema, XlogError};
16use xlog_ir::{
17 CompareOp, CompiledRule, ConstValue, ExecutionPlan, Expr, JoinType, PlanBuilder, ProjectExpr,
18 RirMeta, RirNode, Scc, Stratum as IrStratum,
19};
20
21use crate::ast::{
22 AggOp, ArithExpr, Atom, BodyLiteral, CompOp, Comparison, IsExpr, LearnableRule, PredColumn,
23 Program, Rule, Term, TypeRef,
24};
25use crate::stratify::{build_dependency_graph, find_sccs_for_lowering, DepType};
26
/// A candidate join plan covering a subset of a rule's positive body atoms.
///
/// Leaf plans are produced by `make_leaf_plan`; larger plans are produced by
/// combining two plans in `join_plans`.
struct JoinPlan<'a> {
    /// RIR subtree that evaluates this plan.
    node: RirNode,
    /// Atoms covered by the plan, in left-to-right leaf order.
    leaf_order: Vec<&'a Atom>,
    /// Original body positions of the atoms in `leaf_order`; the planners use
    /// this as a tie-breaker between plans of (nearly) equal cost.
    leaf_order_idx: Vec<usize>,
    /// First output column at which each named (non-`_`) variable appears.
    var_pos: HashMap<String, usize>,
    /// Number of output columns: the summed arity of the covered atoms.
    width: usize,
    /// Estimated number of rows the plan produces.
    est_rows: f64,
    /// Accumulated cost estimate for the whole subtree.
    total_cost: f64,
}
36
37fn pred_columns_for_decl(pred_decl: &crate::ast::PredDecl) -> Vec<PredColumn> {
38 if pred_decl.columns.is_empty() {
39 pred_decl
40 .types
41 .iter()
42 .cloned()
43 .map(|typ| PredColumn { name: None, typ })
44 .collect()
45 } else {
46 pred_decl.columns.clone()
47 }
48}
49
50fn resolve_pred_column_type(
51 predicate: &str,
52 index: usize,
53 typ: &TypeRef,
54 domains: &HashMap<String, ScalarType>,
55) -> Result<ScalarType> {
56 match typ {
57 TypeRef::Scalar(ty) => Ok(*ty),
58 TypeRef::Domain(name) => domains.get(name).copied().ok_or_else(|| {
59 XlogError::Compilation(format!(
60 "v0.8.5 unknown domain alias '{}' in predicate '{}' column {}",
61 name, predicate, index
62 ))
63 }),
64 TypeRef::List(_) | TypeRef::Term | TypeRef::Compound | TypeRef::PredRef => {
65 Ok(ScalarType::U64)
66 }
67 }
68}
69
70fn validate_lowerable_terms(program: &Program) -> Result<()> {
71 for rule in &program.rules {
72 validate_atom_terms(&rule.head, "rule head")?;
73 for lit in &rule.body {
74 match lit {
75 BodyLiteral::Positive(atom) => validate_atom_terms(atom, "positive body atom")?,
76 BodyLiteral::Negated(atom) => validate_atom_terms(atom, "negated body atom")?,
77 BodyLiteral::Epistemic(_) => {}
78 BodyLiteral::Comparison(cmp) => {
79 validate_term_lowerable(&cmp.left, "comparison left operand")?;
80 validate_term_lowerable(&cmp.right, "comparison right operand")?;
81 }
82 BodyLiteral::IsExpr(_) => {}
83 BodyLiteral::Univ(_) => {
84 return Err(XlogError::Compilation(
85 "v0.8.5 meta error: univ literal was not normalized before lowering"
86 .to_string(),
87 ));
88 }
89 }
90 }
91 }
92 for constraint in &program.constraints {
93 for lit in &constraint.body {
94 match lit {
95 BodyLiteral::Positive(atom) => validate_atom_terms(atom, "constraint body atom")?,
96 BodyLiteral::Negated(atom) => {
97 validate_atom_terms(atom, "constraint negated body atom")?
98 }
99 BodyLiteral::Epistemic(_) => {}
100 BodyLiteral::Comparison(cmp) => {
101 validate_term_lowerable(&cmp.left, "constraint comparison left operand")?;
102 validate_term_lowerable(&cmp.right, "constraint comparison right operand")?;
103 }
104 BodyLiteral::IsExpr(_) => {}
105 BodyLiteral::Univ(_) => {
106 return Err(XlogError::Compilation(
107 "v0.8.5 meta error: univ literal was not normalized before lowering"
108 .to_string(),
109 ));
110 }
111 }
112 }
113 }
114 for query in &program.queries {
115 validate_atom_terms(&query.atom, "query atom")?;
116 }
117 for pf in &program.prob_facts {
118 validate_atom_terms(&pf.atom, "probabilistic fact")?;
119 }
120 for ad in &program.annotated_disjunctions {
121 for choice in &ad.choices {
122 validate_atom_terms(&choice.atom, "annotated disjunction choice")?;
123 }
124 }
125 for evidence in &program.evidence {
126 validate_atom_terms(&evidence.atom, "evidence atom")?;
127 }
128 for query in &program.prob_queries {
129 validate_atom_terms(&query.atom, "probabilistic query")?;
130 }
131 for neural in &program.neural_predicates {
132 validate_atom_terms(&neural.predicate, "neural predicate")?;
133 }
134 for learnable in &program.learnable_rules {
135 validate_atom_terms(&learnable.head, "learnable rule head")?;
136 for lit in &learnable.body {
137 if let BodyLiteral::Positive(atom) = lit {
138 validate_atom_terms(atom, "learnable rule body")?;
139 }
140 }
141 }
142 Ok(())
143}
144
145fn validate_atom_terms(atom: &Atom, context: &str) -> Result<()> {
146 for term in &atom.terms {
147 validate_term_lowerable(term, context)?;
148 }
149 Ok(())
150}
151
152fn validate_term_lowerable(term: &Term, context: &str) -> Result<()> {
153 match term {
154 Term::List(_) => Err(v085_term_not_lowerable(context, "list")),
155 Term::Cons { .. } => Err(v085_term_not_lowerable(context, "cons")),
156 Term::Compound { .. } => Err(v085_term_not_lowerable(context, "compound")),
157 Term::PredRef(_) => Err(v085_term_not_lowerable(context, "predref")),
158 Term::Variable(_)
159 | Term::Anonymous
160 | Term::Integer(_)
161 | Term::Float(_)
162 | Term::String(_)
163 | Term::Symbol(_)
164 | Term::Aggregate(_) => Ok(()),
165 }
166}
167
168fn v085_term_not_lowerable(context: &str, kind: &str) -> XlogError {
169 XlogError::Compilation(format!(
170 "v0.8.5 term form '{}' in {} is parsed but not lowerable before its G085 implementation node",
171 kind, context
172 ))
173}
174
175fn v085_term_kind(term: &Term) -> &'static str {
176 match term {
177 Term::List(_) => "list",
178 Term::Cons { .. } => "cons",
179 Term::Compound { .. } => "compound",
180 Term::PredRef(_) => "predref",
181 Term::Variable(_)
182 | Term::Anonymous
183 | Term::Integer(_)
184 | Term::Float(_)
185 | Term::String(_)
186 | Term::Symbol(_)
187 | Term::Aggregate(_) => "term",
188 }
189}
190
/// Lowers a parsed `Program` into an `ExecutionPlan` (see `lower_program`).
pub struct Lowerer {
    /// Schema of every known relation, keyed by predicate name.
    schemas: HashMap<String, Schema>,
    /// Predicate names grouped per stratum, as supplied through `set_strata`.
    strata: Vec<Vec<String>>,
    /// Estimated row count per predicate, recomputed by `infer_cardinalities`.
    est_cardinality: HashMap<String, u64>,
    /// Caller-supplied row-count hints; these take precedence over fact counts
    /// when estimates are computed.
    cardinality_hints: HashMap<String, u64>,
    /// Numeric value the next freshly allocated `RelId` will receive.
    next_rel_id: u32,
    /// Relation id assigned to each predicate name.
    rel_ids: HashMap<String, RelId>,
    /// Strongly connected components, rebuilt by `build_sccs`.
    sccs: Vec<Scc>,
    /// Value copied into every `RirNode::TensorMaskedJoin` emitted for
    /// learnable rules.
    max_active_rules: usize,
}
210
211impl Default for Lowerer {
212 fn default() -> Self {
213 Self::new()
214 }
215}
216
217impl Lowerer {
218 pub fn new() -> Self {
220 Self {
221 schemas: HashMap::new(),
222 strata: Vec::new(),
223 est_cardinality: HashMap::new(),
224 cardinality_hints: HashMap::new(),
225 next_rel_id: 0,
226 rel_ids: HashMap::new(),
227 sccs: Vec::new(),
228 max_active_rules: 32,
229 }
230 }
231
    /// Overrides the `max_active_rules` value that `lower_learnable_rule`
    /// copies into each `RirNode::TensorMaskedJoin` it emits.
    pub fn set_max_active_rules(&mut self, max: usize) {
        self.max_active_rules = max;
    }
236
    /// Replaces the stratification (predicate names grouped per stratum) that
    /// `lower_program` uses when emitting strata into the plan.
    pub(crate) fn set_strata(&mut self, strata: Vec<Vec<String>>) {
        self.strata = strata;
    }
241
    /// Replaces the per-predicate row-count hints. When estimates are computed
    /// by `infer_cardinalities`, a hint wins over the number of facts found in
    /// the program.
    pub(crate) fn set_cardinality_hints(&mut self, hints: HashMap<String, u64>) {
        self.cardinality_hints = hints;
    }
248
    /// Returns the relation id assigned so far to each predicate name.
    pub fn rel_ids(&self) -> &HashMap<String, RelId> {
        &self.rel_ids
    }
253
    /// Returns the schema known so far for each predicate name.
    pub fn schemas(&self) -> &HashMap<String, Schema> {
        &self.schemas
    }
258
259 pub(crate) fn create_helper_relation(&mut self, schema: Schema) -> (String, RelId) {
260 let name = format!("__w37_helper_{}", self.next_rel_id);
261 let rel_id = self.get_or_create_rel_id(&name);
262 self.schemas.insert(name.clone(), schema);
263 (name, rel_id)
264 }
265
266 fn get_or_create_rel_id(&mut self, name: &str) -> RelId {
268 if let Some(&id) = self.rel_ids.get(name) {
269 id
270 } else {
271 let id = RelId(self.next_rel_id);
272 self.next_rel_id += 1;
273 self.rel_ids.insert(name.to_string(), id);
274 id
275 }
276 }
277
278 fn infer_schemas(&mut self, program: &Program) -> Result<()> {
280 let domains: HashMap<String, ScalarType> = program
281 .domains
282 .iter()
283 .map(|domain| (domain.name.clone(), domain.typ))
284 .collect();
285
286 for pred_decl in &program.predicates {
288 let declared_columns = pred_columns_for_decl(pred_decl);
289 let columns: Vec<(String, ScalarType)> = declared_columns
290 .iter()
291 .enumerate()
292 .map(|(i, col)| {
293 let name = col.name.clone().unwrap_or_else(|| format!("c{}", i));
294 resolve_pred_column_type(&pred_decl.name, i, &col.typ, &domains)
295 .map(|ty| (name, ty))
296 })
297 .collect::<Result<Vec<_>>>()?;
298 self.schemas
299 .insert(pred_decl.name.clone(), Schema::new(columns));
300 }
301
302 for rule in program.facts() {
304 let pred = &rule.head.predicate;
305 if !self.schemas.contains_key(pred) {
306 let columns: Vec<(String, ScalarType)> = rule
307 .head
308 .terms
309 .iter()
310 .enumerate()
311 .map(|(i, term)| {
312 let ty = infer_term_type(term);
313 (format!("c{}", i), ty)
314 })
315 .collect();
316 self.schemas.insert(pred.clone(), Schema::new(columns));
317 }
318 }
319
320 for rule in &program.rules {
322 let pred = &rule.head.predicate;
323 if !self.schemas.contains_key(pred) {
324 let columns: Vec<(String, ScalarType)> = rule
326 .head
327 .terms
328 .iter()
329 .enumerate()
330 .map(|(i, term)| {
331 let ty = match term {
332 Term::Variable(name) => self
333 .infer_head_term_type_from_body(rule, name)
334 .unwrap_or_else(|| infer_term_type(term)),
335 _ => infer_term_type(term),
336 };
337 (format!("c{}", i), ty)
338 })
339 .collect();
340 let schema = Schema::new(columns)
341 .with_sort_labels(sort_labels_from_terms(&rule.head.terms))
342 .expect("rule head sort labels match inferred schema arity");
343 self.schemas.insert(pred.clone(), schema);
344 }
345 }
346
347 for rule in &program.rules {
349 for lit in &rule.body {
350 let atom = match lit {
351 BodyLiteral::Positive(atom) | BodyLiteral::Negated(atom) => atom,
352 BodyLiteral::Epistemic(_)
353 | BodyLiteral::Comparison(_)
354 | BodyLiteral::IsExpr(_)
355 | BodyLiteral::Univ(_) => continue,
356 };
357 let pred = &atom.predicate;
358 if self.schemas.contains_key(pred) {
359 continue;
360 }
361 let columns: Vec<(String, ScalarType)> = atom
362 .terms
363 .iter()
364 .enumerate()
365 .map(|(i, term)| (format!("c{}", i), infer_term_type(term)))
366 .collect();
367 let schema = Schema::new(columns)
368 .with_sort_labels(sort_labels_from_terms(&atom.terms))
369 .expect("body sort labels match inferred schema arity");
370 self.schemas.insert(pred.clone(), schema);
371 }
372 }
373
374 for pf in &program.prob_facts {
376 let pred = &pf.atom.predicate;
377 if self.schemas.contains_key(pred) {
378 continue;
379 }
380 let columns: Vec<(String, ScalarType)> = pf
381 .atom
382 .terms
383 .iter()
384 .enumerate()
385 .map(|(i, term)| (format!("c{}", i), infer_term_type(term)))
386 .collect();
387 self.schemas.insert(pred.clone(), Schema::new(columns));
388 }
389
390 for ad in &program.annotated_disjunctions {
391 for choice in &ad.choices {
392 let pred = &choice.atom.predicate;
393 if self.schemas.contains_key(pred) {
394 continue;
395 }
396 let columns: Vec<(String, ScalarType)> = choice
397 .atom
398 .terms
399 .iter()
400 .enumerate()
401 .map(|(i, term)| (format!("c{}", i), infer_term_type(term)))
402 .collect();
403 self.schemas.insert(pred.clone(), Schema::new(columns));
404 }
405 }
406
407 Ok(())
408 }
409
410 fn infer_head_term_type_from_body(&self, rule: &Rule, var_name: &str) -> Option<ScalarType> {
411 for lit in &rule.body {
412 let atom = match lit {
413 BodyLiteral::Positive(atom) | BodyLiteral::Negated(atom) => atom,
414 BodyLiteral::Epistemic(_)
415 | BodyLiteral::Comparison(_)
416 | BodyLiteral::IsExpr(_)
417 | BodyLiteral::Univ(_) => continue,
418 };
419 let schema = self.schemas.get(&atom.predicate)?;
420 for (idx, term) in atom.terms.iter().enumerate() {
421 if let Term::Variable(name) = term {
422 if name == var_name {
423 if let Some(ty) = schema.column_type(idx) {
424 return Some(ty);
425 }
426 }
427 }
428 }
429 }
430 None
431 }
432
433 fn infer_cardinalities(&mut self, program: &Program) {
434 self.est_cardinality.clear();
435
436 let mut fact_counts: HashMap<String, u64> = HashMap::new();
437 for fact in program.facts() {
438 *fact_counts.entry(fact.head.predicate.clone()).or_insert(0) += 1;
439 }
440
441 for pred in self.schemas.keys() {
442 let est = self
443 .cardinality_hints
444 .get(pred)
445 .copied()
446 .or_else(|| fact_counts.get(pred).copied())
447 .unwrap_or(1000)
448 .max(1);
449 self.est_cardinality.insert(pred.clone(), est);
450 }
451 }
452
453 fn build_sccs(&mut self, program: &Program) {
455 let graph = build_dependency_graph(program);
456 let scc_groups = find_sccs_for_lowering(&graph);
457
458 self.sccs.clear();
459 for (id, predicates) in scc_groups.iter().enumerate() {
460 let is_recursive = if predicates.len() > 1 {
463 true
464 } else {
465 let pred = &predicates[0];
466 graph
467 .outgoing(pred)
468 .iter()
469 .any(|e| e.to == *pred && e.dep_type == DepType::Positive)
470 };
471
472 self.sccs.push(Scc {
473 id: id as u32,
474 predicates: predicates.clone(),
475 is_recursive,
476 });
477 }
478 }
479
480 pub fn lower_program(&mut self, program: &Program) -> Result<ExecutionPlan> {
482 validate_lowerable_terms(program)?;
483 self.infer_schemas(program)?;
485 self.infer_cardinalities(program);
486
487 for pred_decl in &program.predicates {
492 self.get_or_create_rel_id(&pred_decl.name);
493 }
494
495 self.build_sccs(program);
497
498 let mut builder = PlanBuilder::new();
500
501 for scc in &self.sccs {
503 builder.add_scc(scc.clone());
504 }
505
506 for (id, preds) in self.strata.iter().enumerate() {
508 let scc_ids: Vec<u32> = self
510 .sccs
511 .iter()
512 .filter(|scc| scc.predicates.iter().any(|p| preds.contains(p)))
513 .map(|scc| scc.id)
514 .collect();
515
516 if !scc_ids.is_empty() {
517 builder.add_stratum(IrStratum {
518 id: id as u32,
519 sccs: scc_ids,
520 });
521 }
522 }
523
524 let mut rules_by_pred: HashMap<String, Vec<&Rule>> = HashMap::new();
526 for rule in program.proper_rules() {
527 rules_by_pred
528 .entry(rule.head.predicate.clone())
529 .or_default()
530 .push(rule);
531 }
532
533 for fact in program.facts() {
535 let pred = &fact.head.predicate;
536 let scc_id = self.find_scc_for_predicate(pred);
537 let rel_id = self.get_or_create_rel_id(pred);
538
539 let body = RirNode::Scan { rel: rel_id };
540 let meta = self.create_meta_for_predicate(pred);
541
542 builder.add_rule(
543 scc_id,
544 CompiledRule {
545 head: pred.clone(),
546 body,
547 meta,
548 },
549 );
550 }
551
552 for (pred, rules) in &rules_by_pred {
554 let scc_id = self.find_scc_for_predicate(pred);
555
556 for rule in rules {
557 let body = self.lower_rule(rule)?;
558 let meta = self.create_meta_for_predicate(pred);
559
560 builder.add_rule(
561 scc_id,
562 CompiledRule {
563 head: pred.clone(),
564 body,
565 meta,
566 },
567 );
568 }
569 }
570
571 for learnable in &program.learnable_rules {
575 self.get_or_create_rel_id(&learnable.head.predicate);
576 for lit in &learnable.body {
577 if let BodyLiteral::Positive(atom) = lit {
578 self.get_or_create_rel_id(&atom.predicate);
579 }
580 }
581 }
582 for learnable in &program.learnable_rules {
583 let head_pred = &learnable.head.predicate;
584 let scc_id = self.find_scc_for_predicate(head_pred);
585 let body = self.lower_learnable_rule(learnable)?;
586 let meta = self.create_meta_for_predicate(head_pred);
587 builder.add_rule(
588 scc_id,
589 CompiledRule {
590 head: head_pred.clone(),
591 body,
592 meta,
593 },
594 );
595 }
596
597 Ok(builder.build())
598 }
599
600 fn find_scc_for_predicate(&self, pred: &str) -> u32 {
602 self.sccs
603 .iter()
604 .find(|scc| scc.predicates.contains(&pred.to_string()))
605 .map(|scc| scc.id)
606 .unwrap_or(0)
607 }
608
609 fn create_meta_for_predicate(&self, pred: &str) -> RirMeta {
611 let schema = self
612 .schemas
613 .get(pred)
614 .cloned()
615 .unwrap_or_else(|| Schema::new(vec![]));
616 RirMeta::with_schema(schema)
617 }
618
619 fn lower_learnable_rule(&mut self, rule: &LearnableRule) -> Result<RirNode> {
624 if rule.body.len() != 2 {
626 return Err(XlogError::Compilation(format!(
627 "learnable rule '{}' requires exactly 2 body literals, got {}",
628 rule.mask_name,
629 rule.body.len()
630 )));
631 }
632 for (idx, lit) in rule.body.iter().enumerate() {
633 match lit {
634 BodyLiteral::Positive(_) => {}
635 _ => {
636 return Err(XlogError::Compilation(format!(
637 "learnable rule '{}' body[{}]: only positive atoms allowed",
638 rule.mask_name, idx
639 )));
640 }
641 }
642 }
643
644 let mut rel_index: Vec<(RelId, String)> = self
646 .rel_ids()
647 .iter()
648 .map(|(name, id)| (*id, name.clone()))
649 .collect();
650 rel_index.sort_by_key(|(id, _)| id.0);
651 let schema_size = rel_index.len();
652
653 let (left_keys, right_keys) =
654 self.extract_template_join_keys(&rule.body[0], &rule.body[1])?;
655
656 let head_rel_name = rule.head.predicate.clone();
657 let head_rel_id = self.get_or_create_rel_id(&head_rel_name);
659
660 let left_atom = rule.body[0].atom().unwrap();
663 let right_atom = rule.body[1].atom().unwrap();
664 let left_arity = left_atom.terms.len();
665
666 let mut var_to_col: HashMap<String, usize> = HashMap::new();
668 for (i, term) in left_atom.terms.iter().enumerate() {
669 if let Some(name) = term.variable_name() {
670 var_to_col.entry(name.to_string()).or_insert(i);
671 }
672 }
673 for (i, term) in right_atom.terms.iter().enumerate() {
674 if let Some(name) = term.variable_name() {
675 var_to_col.entry(name.to_string()).or_insert(left_arity + i);
676 }
677 }
678
679 let mut head_projection: Vec<usize> = Vec::new();
680 for term in &rule.head.terms {
681 if let Some(name) = term.variable_name() {
682 let col = var_to_col.get(name).ok_or_else(|| {
683 XlogError::Compilation(format!(
684 "Learnable rule head variable '{}' not found in body atoms \
685 ({}, {}). All head variables must appear in the body.",
686 name, left_atom.predicate, right_atom.predicate,
687 ))
688 })?;
689 head_projection.push(*col);
690 } else {
691 return Err(XlogError::Compilation(format!(
692 "Learnable rule head must contain only variables, \
693 found constant {:?} in head of '{}'",
694 term, head_rel_name,
695 )));
696 }
697 }
698
699 if !self.schemas.contains_key(&head_rel_name) {
702 let columns: Vec<(String, ScalarType)> = head_projection
703 .iter()
704 .enumerate()
705 .map(|(i, &col)| {
706 let ty = if col < left_arity {
708 self.schemas
709 .get(&left_atom.predicate)
710 .and_then(|s| s.column_type(col))
711 .unwrap_or(ScalarType::U32)
712 } else {
713 self.schemas
714 .get(&right_atom.predicate)
715 .and_then(|s| s.column_type(col - left_arity))
716 .unwrap_or(ScalarType::U32)
717 };
718 (format!("c{}", i), ty)
719 })
720 .collect();
721 self.schemas
722 .insert(head_rel_name.clone(), Schema::new(columns));
723 }
724
725 Ok(RirNode::TensorMaskedJoin {
726 mask_name: rule.mask_name.clone(),
727 schema_size,
728 left_keys,
729 right_keys,
730 rel_index,
731 head_rel_name,
732 head_rel_id,
733 max_active_rules: self.max_active_rules,
734 head_projection,
735 })
736 }
737
738 fn extract_template_join_keys(
741 &self,
742 left: &BodyLiteral,
743 right: &BodyLiteral,
744 ) -> Result<(Vec<usize>, Vec<usize>)> {
745 let left_atom = left
746 .atom()
747 .ok_or_else(|| XlogError::Compilation("Learnable body[0] is not an atom".into()))?;
748 let right_atom = right
749 .atom()
750 .ok_or_else(|| XlogError::Compilation("Learnable body[1] is not an atom".into()))?;
751
752 let mut left_keys = Vec::new();
753 let mut right_keys = Vec::new();
754
755 for (li, lt) in left_atom.terms.iter().enumerate() {
756 if let Some(lname) = lt.variable_name() {
757 for (ri, rt) in right_atom.terms.iter().enumerate() {
758 if let Some(rname) = rt.variable_name() {
759 if lname == rname {
760 left_keys.push(li);
761 right_keys.push(ri);
762 }
763 }
764 }
765 }
766 }
767
768 Ok((left_keys, right_keys))
769 }
770
771 fn lower_rule(&mut self, rule: &Rule) -> Result<RirNode> {
773 if let Some(lit) = rule.body.iter().find_map(|lit| match lit {
774 BodyLiteral::Epistemic(lit) => Some(lit),
775 _ => None,
776 }) {
777 return Err(XlogError::UnsupportedEpistemicConstruct {
778 construct: "RIR lowering boundary".to_string(),
779 context: format!("{:?} {}({})", lit.op, lit.atom.predicate, lit.atom.arity()),
780 });
781 }
782
783 let (positive_atoms, negated_atoms, comparisons, is_exprs) =
785 Self::split_body_literals(&rule.body);
786
787 for lit in &rule.body {
790 match lit {
791 BodyLiteral::Positive(atom) | BodyLiteral::Negated(atom) => {
792 self.get_or_create_rel_id(&atom.predicate);
793 }
794 BodyLiteral::Epistemic(_)
795 | BodyLiteral::Comparison(_)
796 | BodyLiteral::IsExpr(_)
797 | BodyLiteral::Univ(_) => {}
798 }
799 }
800
801 let (positive_root, leaf_order) = if positive_atoms.is_empty() {
807 (RirNode::Unit, Vec::new())
808 } else {
809 self.plan_positive_atoms(&positive_atoms)?
810 };
811
812 let mut var_env = VariableEnv::new();
815 let mut current_col = 0;
816 for atom in &leaf_order {
817 let schema = self.schemas.get(&atom.predicate);
818 for (i, term) in atom.terms.iter().enumerate() {
819 if let Term::Variable(name) = term {
820 if name == "_" {
821 continue;
822 }
823 var_env.add_occurrence(name, atom.predicate.clone(), i, current_col + i);
824 if !var_env.types.contains_key(name) {
826 let typ = schema
827 .and_then(|s| s.column_type(i))
828 .unwrap_or(ScalarType::I64); var_env.types.insert(name.to_string(), typ);
830 }
831 }
832 }
833 current_col += atom.terms.len();
834 }
835 var_env.total_cols = current_col;
836
837 let body_node = self.lower_body_parts(
839 positive_root,
840 &negated_atoms,
841 &comparisons,
842 &is_exprs,
843 &mut var_env,
844 )?;
845
846 if rule.has_aggregation() {
847 return self.lower_aggregate_rule(&rule.head, body_node, &var_env);
848 }
849
850 let projection_exprs = self.compute_head_projection(&rule.head, &var_env)?;
852
853 if Self::is_identity_projection(&projection_exprs, var_env.column_count()) {
854 Ok(body_node)
855 } else {
856 Ok(RirNode::Project {
857 input: Box::new(body_node),
858 columns: projection_exprs,
859 })
860 }
861 }
862
863 fn split_body_literals(
864 body: &[BodyLiteral],
865 ) -> (Vec<&Atom>, Vec<&Atom>, Vec<&Comparison>, Vec<&IsExpr>) {
866 let mut positive_atoms: Vec<&Atom> = Vec::new();
867 let mut negated_atoms: Vec<&Atom> = Vec::new();
868 let mut comparisons: Vec<&Comparison> = Vec::new();
869 let mut is_exprs: Vec<&IsExpr> = Vec::new();
870
871 for lit in body {
872 match lit {
873 BodyLiteral::Positive(atom) => positive_atoms.push(atom),
874 BodyLiteral::Negated(atom) => negated_atoms.push(atom),
875 BodyLiteral::Epistemic(_) => {}
876 BodyLiteral::Comparison(cmp) => comparisons.push(cmp),
877 BodyLiteral::IsExpr(is_expr) => is_exprs.push(is_expr),
878 BodyLiteral::Univ(_) => {}
879 }
880 }
881
882 (positive_atoms, negated_atoms, comparisons, is_exprs)
883 }
884
885 fn atom_vars(atom: &Atom) -> std::collections::HashSet<String> {
886 atom.terms
887 .iter()
888 .flat_map(|t| t.variables().into_iter())
889 .filter(|name| *name != "_")
890 .map(ToOwned::to_owned)
891 .collect()
892 }
893
894 fn estimate_atom_rows(&self, atom: &Atom) -> f64 {
895 let base = self
896 .est_cardinality
897 .get(&atom.predicate)
898 .copied()
899 .unwrap_or(1000)
900 .max(1) as f64;
901
902 let const_count = atom
903 .terms
904 .iter()
905 .filter(|t| term_to_const_value(t).is_some())
906 .count();
907
908 let selectivity = 0.1_f64.powi(const_count as i32);
910 (base * selectivity).max(1.0)
911 }
912
913 fn build_cartesian_join(
914 &self,
915 left: RirNode,
916 right: RirNode,
917 left_width: usize,
918 right_width: usize,
919 ) -> RirNode {
920 let left_const_col =
923 ProjectExpr::Computed(Expr::Const(ConstValue::U32(0)), ScalarType::U32);
924 let right_const_col =
925 ProjectExpr::Computed(Expr::Const(ConstValue::U32(0)), ScalarType::U32);
926
927 let mut left_cols: Vec<ProjectExpr> = (0..left_width).map(ProjectExpr::Column).collect();
928 left_cols.push(left_const_col);
929 let left_aug = RirNode::Project {
930 input: Box::new(left),
931 columns: left_cols,
932 };
933
934 let mut right_cols: Vec<ProjectExpr> = (0..right_width).map(ProjectExpr::Column).collect();
935 right_cols.push(right_const_col);
936 let right_aug = RirNode::Project {
937 input: Box::new(right),
938 columns: right_cols,
939 };
940
941 let joined = RirNode::Join {
942 left: Box::new(left_aug),
943 right: Box::new(right_aug),
944 left_keys: vec![left_width],
945 right_keys: vec![right_width],
946 join_type: JoinType::Inner,
947 };
948
949 let mut keep: Vec<ProjectExpr> = Vec::with_capacity(left_width + right_width);
950 keep.extend((0..left_width).map(ProjectExpr::Column));
951 let right_start = left_width + 1;
952 keep.extend((right_start..right_start + right_width).map(ProjectExpr::Column));
953
954 RirNode::Project {
955 input: Box::new(joined),
956 columns: keep,
957 }
958 }
959
960 fn make_leaf_plan<'a>(&mut self, atom: &'a Atom, orig_idx: usize) -> Result<JoinPlan<'a>> {
961 let rel_id = self.get_or_create_rel_id(&atom.predicate);
962 let scan = RirNode::Scan { rel: rel_id };
963 let node = self.apply_constant_filters(scan, atom, 0)?;
964
965 let mut var_pos: HashMap<String, usize> = HashMap::new();
966 for (i, term) in atom.terms.iter().enumerate() {
967 if let Term::Variable(name) = term {
968 if name != "_" {
969 var_pos.entry(name.clone()).or_insert(i);
970 }
971 }
972 }
973
974 let est_rows = self.estimate_atom_rows(atom);
975 Ok(JoinPlan {
976 node,
977 leaf_order: vec![atom],
978 leaf_order_idx: vec![orig_idx],
979 var_pos,
980 width: atom.terms.len(),
981 est_rows,
982 total_cost: est_rows,
983 })
984 }
985
986 fn join_plans<'a>(&self, left: &JoinPlan<'a>, right: &JoinPlan<'a>) -> JoinPlan<'a> {
987 let shared_vars: Vec<&String> = left
988 .var_pos
989 .keys()
990 .filter(|v| right.var_pos.contains_key(*v))
991 .collect();
992
993 let node = if shared_vars.is_empty() {
994 self.build_cartesian_join(
995 left.node.clone(),
996 right.node.clone(),
997 left.width,
998 right.width,
999 )
1000 } else {
1001 let mut key_pairs: Vec<(usize, usize)> = shared_vars
1002 .iter()
1003 .filter_map(|v| {
1004 Some((
1005 left.var_pos.get(*v).copied()?,
1006 right.var_pos.get(*v).copied()?,
1007 ))
1008 })
1009 .collect();
1010 key_pairs.sort_unstable();
1011
1012 let (left_keys, right_keys): (Vec<usize>, Vec<usize>) = key_pairs.into_iter().unzip();
1013
1014 RirNode::Join {
1015 left: Box::new(left.node.clone()),
1016 right: Box::new(right.node.clone()),
1017 left_keys,
1018 right_keys,
1019 join_type: JoinType::Inner,
1020 }
1021 };
1022
1023 let mut leaf_order = left.leaf_order.clone();
1024 leaf_order.extend(right.leaf_order.iter().copied());
1025
1026 let mut leaf_order_idx = left.leaf_order_idx.clone();
1027 leaf_order_idx.extend_from_slice(&right.leaf_order_idx);
1028
1029 let mut var_pos = left.var_pos.clone();
1030 for (var, pos) in &right.var_pos {
1031 var_pos.entry(var.clone()).or_insert(left.width + *pos);
1032 }
1033
1034 let shared = shared_vars.len();
1035 let mut selectivity = if shared == 0 {
1036 1.0
1037 } else {
1038 0.1_f64.powi(shared as i32)
1039 };
1040 if shared == 0 {
1041 selectivity *= 1.0e6;
1043 }
1044
1045 let output_rows = (left.est_rows * right.est_rows * selectivity).max(1.0);
1046
1047 let build_cost = right.est_rows;
1049 let probe_cost = left.est_rows * 0.5;
1050 let total_cost = left.total_cost + right.total_cost + build_cost + probe_cost + output_rows;
1051
1052 JoinPlan {
1053 node,
1054 leaf_order,
1055 leaf_order_idx,
1056 var_pos,
1057 width: left.width + right.width,
1058 est_rows: output_rows,
1059 total_cost,
1060 }
1061 }
1062
1063 fn plan_positive_atoms_bushy<'a>(
1064 &mut self,
1065 atoms: &[&'a Atom],
1066 ) -> Result<(RirNode, Vec<&'a Atom>)> {
1067 let n = atoms.len();
1068 if n == 0 {
1069 return Err(XlogError::Compilation("Empty rule body".to_string()));
1070 }
1071 if n == 1 {
1072 let plan = self.make_leaf_plan(atoms[0], 0)?;
1073 return Ok((plan.node, plan.leaf_order));
1074 }
1075
1076 let size = 1usize << n;
1077 let mut best: Vec<Option<JoinPlan<'a>>> = (0..size).map(|_| None).collect();
1078
1079 for (i, atom) in atoms.iter().enumerate() {
1080 best[1usize << i] = Some(self.make_leaf_plan(atom, i)?);
1081 }
1082
1083 fn lex_lt(a: &[usize], b: &[usize]) -> bool {
1084 for (ai, bi) in a.iter().zip(b.iter()) {
1085 if ai != bi {
1086 return ai < bi;
1087 }
1088 }
1089 a.len() < b.len()
1090 }
1091
1092 for mask in 1..size {
1093 if mask.count_ones() <= 1 {
1094 continue;
1095 }
1096
1097 let mut best_for_mask: Option<JoinPlan<'a>> = None;
1098
1099 let mut sub = (mask - 1) & mask;
1100 while sub > 0 {
1101 let a = sub;
1102 let b = mask ^ a;
1103 if b == 0 {
1104 sub = (sub - 1) & mask;
1105 continue;
1106 }
1107
1108 let (Some(plan_a), Some(plan_b)) = (&best[a], &best[b]) else {
1109 sub = (sub - 1) & mask;
1110 continue;
1111 };
1112
1113 for (left, right) in [(plan_a, plan_b), (plan_b, plan_a)] {
1115 let cand = self.join_plans(left, right);
1116 let replace = match &best_for_mask {
1117 None => true,
1118 Some(current) => {
1119 if cand.total_cost < current.total_cost {
1120 true
1121 } else if (cand.total_cost - current.total_cost).abs() < 1e-9 {
1122 lex_lt(&cand.leaf_order_idx, ¤t.leaf_order_idx)
1123 } else {
1124 false
1125 }
1126 }
1127 };
1128
1129 if replace {
1130 best_for_mask = Some(cand);
1131 }
1132 }
1133
1134 sub = (sub - 1) & mask;
1135 }
1136
1137 best[mask] = best_for_mask;
1138 }
1139
1140 let full_mask = size - 1;
1141 if let Some(plan) = best[full_mask].take() {
1142 return Ok((plan.node, plan.leaf_order));
1143 }
1144
1145 let ordered = self.order_positive_atoms_greedy(atoms);
1147 let mut dummy_env = VariableEnv::new();
1148 let node = self.build_join_tree(&ordered, &mut dummy_env)?;
1149 Ok((node, ordered))
1150 }
1151
1152 fn plan_positive_atoms<'a>(&mut self, atoms: &[&'a Atom]) -> Result<(RirNode, Vec<&'a Atom>)> {
1153 if atoms.len() <= 1 {
1154 if atoms.is_empty() {
1155 return Err(XlogError::Compilation("Empty rule body".to_string()));
1156 }
1157 let plan = self.make_leaf_plan(atoms[0], 0)?;
1158 return Ok((plan.node, plan.leaf_order));
1159 }
1160
1161 const MAX_BUSHY_DP_ATOMS: usize = 10;
1162 if atoms.len() <= MAX_BUSHY_DP_ATOMS {
1163 return self.plan_positive_atoms_bushy(atoms);
1164 }
1165
1166 self.plan_positive_atoms_bushy_greedy(atoms)
1168 }
1169
1170 fn plan_positive_atoms_bushy_greedy<'a>(
1171 &mut self,
1172 atoms: &[&'a Atom],
1173 ) -> Result<(RirNode, Vec<&'a Atom>)> {
1174 if atoms.is_empty() {
1175 return Err(XlogError::Compilation("Empty rule body".to_string()));
1176 }
1177
1178 fn lex_lt(a: &[usize], b: &[usize]) -> bool {
1179 for (ai, bi) in a.iter().zip(b.iter()) {
1180 if ai != bi {
1181 return ai < bi;
1182 }
1183 }
1184 a.len() < b.len()
1185 }
1186
1187 let mut plans: Vec<JoinPlan<'a>> = Vec::with_capacity(atoms.len());
1188 for (idx, atom) in atoms.iter().enumerate() {
1189 plans.push(self.make_leaf_plan(atom, idx)?);
1190 }
1191
1192 while plans.len() > 1 {
1193 let mut best_pair: Option<(usize, usize, JoinPlan<'a>)> = None;
1194
1195 for i in 0..plans.len() {
1196 for j in (i + 1)..plans.len() {
1197 let a = &plans[i];
1198 let b = &plans[j];
1199
1200 let cand_ab = self.join_plans(a, b);
1201 let cand_ba = self.join_plans(b, a);
1202
1203 let cand = if cand_ab.total_cost < cand_ba.total_cost
1204 || (cand_ab.total_cost - cand_ba.total_cost).abs() < 1e-9
1205 && lex_lt(&cand_ab.leaf_order_idx, &cand_ba.leaf_order_idx)
1206 {
1207 cand_ab
1208 } else {
1209 cand_ba
1210 };
1211
1212 let replace = match &best_pair {
1213 None => true,
1214 Some((_bi, _bj, best)) => {
1215 if cand.total_cost < best.total_cost {
1216 true
1217 } else if (cand.total_cost - best.total_cost).abs() < 1e-9 {
1218 lex_lt(&cand.leaf_order_idx, &best.leaf_order_idx)
1219 } else {
1220 false
1221 }
1222 }
1223 };
1224
1225 if replace {
1226 best_pair = Some((i, j, cand));
1227 }
1228 }
1229 }
1230
1231 let Some((i, j, joined)) = best_pair else {
1232 break;
1233 };
1234
1235 let (a, b) = if i < j { (i, j) } else { (j, i) };
1237 plans.remove(b);
1238 plans.remove(a);
1239 plans.push(joined);
1240 }
1241
1242 let plan = plans
1243 .pop()
1244 .ok_or_else(|| XlogError::Compilation("Join planning failed".to_string()))?;
1245 Ok((plan.node, plan.leaf_order))
1246 }
1247
1248 fn order_positive_atoms_greedy<'a>(&self, atoms: &[&'a Atom]) -> Vec<&'a Atom> {
1249 let mut remaining: Vec<(usize, &Atom)> = atoms.iter().copied().enumerate().collect();
1250 let mut ordered: Vec<&Atom> = Vec::with_capacity(atoms.len());
1251 let mut bound_vars: HashSet<String> = HashSet::new();
1252
1253 while !remaining.is_empty() {
1254 let pick_idx = if ordered.is_empty() {
1255 remaining
1256 .iter()
1257 .enumerate()
1258 .min_by(|(_, a), (_, b)| {
1259 let (ai, aa) = **a;
1260 let (bi, bb) = **b;
1261 self.estimate_atom_rows(aa)
1262 .partial_cmp(&self.estimate_atom_rows(bb))
1263 .unwrap_or(std::cmp::Ordering::Equal)
1264 .then(ai.cmp(&bi))
1265 })
1266 .map(|(idx, _)| idx)
1267 .unwrap()
1268 } else {
1269 remaining
1270 .iter()
1271 .enumerate()
1272 .min_by(|(_, a), (_, b)| {
1273 let (ai, aa) = **a;
1274 let (bi, bb) = **b;
1275
1276 let a_vars = Self::atom_vars(aa);
1277 let b_vars = Self::atom_vars(bb);
1278
1279 let a_shared = a_vars.intersection(&bound_vars).count();
1280 let b_shared = b_vars.intersection(&bound_vars).count();
1281
1282 let a_score = if a_shared == 0 {
1283 self.estimate_atom_rows(aa) * 1.0e12
1284 } else {
1285 self.estimate_atom_rows(aa) / a_shared as f64
1286 };
1287 let b_score = if b_shared == 0 {
1288 self.estimate_atom_rows(bb) * 1.0e12
1289 } else {
1290 self.estimate_atom_rows(bb) / b_shared as f64
1291 };
1292
1293 a_score
1294 .partial_cmp(&b_score)
1295 .unwrap_or(std::cmp::Ordering::Equal)
1296 .then(ai.cmp(&bi))
1297 })
1298 .map(|(idx, _)| idx)
1299 .unwrap()
1300 };
1301
1302 let (_orig_idx, atom) = remaining.remove(pick_idx);
1303 ordered.push(atom);
1304 bound_vars.extend(Self::atom_vars(atom));
1305 }
1306
1307 ordered
1308 }
1309
1310 fn lower_body_parts(
1311 &mut self,
1312 positive_root: RirNode,
1313 negated_atoms: &[&Atom],
1314 comparisons: &[&Comparison],
1315 is_exprs: &[&IsExpr],
1316 var_env: &mut VariableEnv,
1317 ) -> Result<RirNode> {
1318 let mut result = positive_root;
1319
1320 for cmp in comparisons {
1322 result = self.apply_comparison(result, cmp, var_env)?;
1323 }
1324
1325 for is_expr in is_exprs {
1327 result = self.lower_is_expr(is_expr, result, var_env)?;
1328 }
1329
1330 for neg_atom in negated_atoms {
1332 result = self.apply_negation(result, neg_atom, var_env)?;
1333 }
1334
1335 Ok(result)
1336 }
1337
1338 fn build_join_tree(&mut self, atoms: &[&Atom], var_env: &mut VariableEnv) -> Result<RirNode> {
1340 if atoms.is_empty() {
1341 return Err(XlogError::Compilation("Empty rule body".to_string()));
1342 }
1343
1344 let first_atom = atoms[0];
1346 let rel_id = self.get_or_create_rel_id(&first_atom.predicate);
1347 let mut result = RirNode::Scan { rel: rel_id };
1348 let mut result_vars = self.collect_atom_vars(first_atom);
1349 let mut result_width = first_atom.terms.len();
1350
1351 result = self.apply_constant_filters(result, first_atom, 0)?;
1353
1354 for atom in atoms.iter().skip(1) {
1356 let right_rel_id = self.get_or_create_rel_id(&atom.predicate);
1357 let right_scan = RirNode::Scan { rel: right_rel_id };
1358
1359 let right_filtered = self.apply_constant_filters(right_scan, atom, 0)?;
1361
1362 let (left_keys, right_keys) = self.compute_join_keys(&result_vars, atom, result_width);
1364
1365 if left_keys.is_empty() {
1366 result = RirNode::Join {
1368 left: Box::new(result),
1369 right: Box::new(right_filtered),
1370 left_keys: vec![],
1371 right_keys: vec![],
1372 join_type: JoinType::Inner,
1373 };
1374 } else {
1375 result = RirNode::Join {
1376 left: Box::new(result),
1377 right: Box::new(right_filtered),
1378 left_keys,
1379 right_keys,
1380 join_type: JoinType::Inner,
1381 };
1382 }
1383
1384 for (i, term) in atom.terms.iter().enumerate() {
1386 if let Term::Variable(name) = term {
1387 result_vars.push((name.clone(), result_width + i));
1388 }
1389 }
1390 result_width += atom.terms.len();
1391 }
1392
1393 var_env.total_cols = result_width;
1395
1396 Ok(result)
1397 }
1398
1399 fn collect_atom_vars(&self, atom: &Atom) -> Vec<(String, usize)> {
1401 atom.terms
1402 .iter()
1403 .enumerate()
1404 .filter_map(|(i, term)| {
1405 if let Term::Variable(name) = term {
1406 Some((name.clone(), i))
1407 } else {
1408 None
1409 }
1410 })
1411 .collect()
1412 }
1413
1414 fn compute_join_keys(
1416 &self,
1417 left_vars: &[(String, usize)],
1418 right_atom: &Atom,
1419 _left_width: usize,
1420 ) -> (Vec<usize>, Vec<usize>) {
1421 let mut left_keys = Vec::new();
1422 let mut right_keys = Vec::new();
1423
1424 for (right_idx, term) in right_atom.terms.iter().enumerate() {
1425 if let Term::Variable(name) = term {
1426 for (left_name, left_idx) in left_vars {
1428 if left_name == name {
1429 left_keys.push(*left_idx);
1430 right_keys.push(right_idx);
1431 break; }
1433 }
1434 }
1435 }
1436
1437 (left_keys, right_keys)
1438 }
1439
1440 fn apply_constant_filters(
1442 &self,
1443 input: RirNode,
1444 atom: &Atom,
1445 _base_col: usize,
1446 ) -> Result<RirNode> {
1447 let mut filters = Vec::new();
1448 let mut first_var_col: HashMap<&str, usize> = HashMap::new();
1449 let schema = self.schemas.get(&atom.predicate).ok_or_else(|| {
1450 XlogError::Compilation(format!("Missing schema for predicate {}", atom.predicate))
1451 })?;
1452
1453 for (i, term) in atom.terms.iter().enumerate() {
1454 if let Term::Variable(name) = term {
1455 if name != "_" {
1456 if let Some(&first) = first_var_col.get(name.as_str()) {
1457 filters.push(Expr::Compare {
1458 left: Box::new(Expr::Column(first)),
1459 op: CompareOp::Eq,
1460 right: Box::new(Expr::Column(i)),
1461 });
1462 } else {
1463 first_var_col.insert(name.as_str(), i);
1464 }
1465 }
1466 }
1467
1468 let col_type = schema.column_type(i).ok_or_else(|| {
1469 XlogError::Compilation(format!(
1470 "Missing column type for {} column {}",
1471 atom.predicate, i
1472 ))
1473 })?;
1474 if let Some(const_val) = term_to_typed_const_value(term, col_type)? {
1475 filters.push(Expr::Compare {
1476 left: Box::new(Expr::Column(i)),
1477 op: CompareOp::Eq,
1478 right: Box::new(Expr::Const(const_val)),
1479 });
1480 }
1481 }
1482
1483 if filters.is_empty() {
1484 Ok(input)
1485 } else {
1486 let predicate = if filters.len() == 1 {
1487 filters.pop().unwrap()
1488 } else {
1489 Expr::And(filters)
1490 };
1491
1492 Ok(RirNode::Filter {
1493 input: Box::new(input),
1494 predicate,
1495 })
1496 }
1497 }
1498
1499 fn apply_comparison(
1501 &self,
1502 input: RirNode,
1503 cmp: &Comparison,
1504 var_env: &VariableEnv,
1505 ) -> Result<RirNode> {
1506 let (left_expr, right_expr) = match (&cmp.left, &cmp.right) {
1507 (Term::Variable(name), term) => {
1508 let col = var_env.get_column(name).ok_or_else(|| {
1509 XlogError::Compilation(format!("Variable {} not found in environment", name))
1510 })?;
1511 let typ = var_env.get_type(name).ok_or_else(|| {
1512 XlogError::Compilation(format!("Missing type for variable {}", name))
1513 })?;
1514 if let Some(const_val) = term_to_typed_const_value(term, typ)? {
1515 (Expr::Column(col), Expr::Const(const_val))
1516 } else {
1517 (
1518 self.term_to_expr(&cmp.left, var_env)?,
1519 self.term_to_expr(&cmp.right, var_env)?,
1520 )
1521 }
1522 }
1523 (term, Term::Variable(name)) => {
1524 let col = var_env.get_column(name).ok_or_else(|| {
1525 XlogError::Compilation(format!("Variable {} not found in environment", name))
1526 })?;
1527 let typ = var_env.get_type(name).ok_or_else(|| {
1528 XlogError::Compilation(format!("Missing type for variable {}", name))
1529 })?;
1530 if let Some(const_val) = term_to_typed_const_value(term, typ)? {
1531 (Expr::Const(const_val), Expr::Column(col))
1532 } else {
1533 (
1534 self.term_to_expr(&cmp.left, var_env)?,
1535 self.term_to_expr(&cmp.right, var_env)?,
1536 )
1537 }
1538 }
1539 _ => (
1540 self.term_to_expr(&cmp.left, var_env)?,
1541 self.term_to_expr(&cmp.right, var_env)?,
1542 ),
1543 };
1544
1545 let op = match cmp.op {
1546 CompOp::Eq => CompareOp::Eq,
1547 CompOp::Ne => CompareOp::Ne,
1548 CompOp::Lt => CompareOp::Lt,
1549 CompOp::Le => CompareOp::Le,
1550 CompOp::Gt => CompareOp::Gt,
1551 CompOp::Ge => CompareOp::Ge,
1552 };
1553
1554 Ok(RirNode::Filter {
1555 input: Box::new(input),
1556 predicate: Expr::Compare {
1557 left: Box::new(left_expr),
1558 op,
1559 right: Box::new(right_expr),
1560 },
1561 })
1562 }
1563
1564 fn term_to_expr(&self, term: &Term, var_env: &VariableEnv) -> Result<Expr> {
1566 match term {
1567 Term::Variable(name) => {
1568 if let Some(col) = var_env.get_column(name) {
1569 Ok(Expr::Column(col))
1570 } else {
1571 Err(XlogError::Compilation(format!(
1572 "Variable {} not found in environment",
1573 name
1574 )))
1575 }
1576 }
1577 Term::Anonymous => Err(XlogError::Compilation(
1578 "Anonymous wildcard '_' not allowed in comparisons".to_string(),
1579 )),
1580 Term::Integer(i) => Ok(Expr::Const(ConstValue::I64(*i))),
1581 Term::Float(f) => Ok(Expr::Const(ConstValue::F64(*f))),
1582 Term::String(s) => Ok(Expr::Const(ConstValue::Symbol(s.clone()))),
1583 Term::Symbol(id) => Ok(Expr::Const(ConstValue::Symbol(symbol::resolve(*id)))),
1584 Term::Aggregate(_) => Err(XlogError::Compilation(
1585 "Aggregates not allowed in comparisons".to_string(),
1586 )),
1587 Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
1588 Err(v085_term_not_lowerable("comparison", v085_term_kind(term)))
1589 }
1590 }
1591 }
1592
1593 fn apply_negation(
1595 &mut self,
1596 input: RirNode,
1597 neg_atom: &Atom,
1598 var_env: &VariableEnv,
1599 ) -> Result<RirNode> {
1600 let rel_id = self.get_or_create_rel_id(&neg_atom.predicate);
1601 let neg_scan = RirNode::Scan { rel: rel_id };
1602
1603 let neg_filtered = self.apply_constant_filters(neg_scan, neg_atom, 0)?;
1605
1606 let mut input_cols = Vec::new();
1608 let mut neg_cols = Vec::new();
1609
1610 for (neg_idx, term) in neg_atom.terms.iter().enumerate() {
1611 if let Term::Variable(name) = term {
1612 if let Some(col) = var_env.get_column(name) {
1613 input_cols.push(col);
1614 neg_cols.push(neg_idx);
1615 }
1616 }
1617 }
1618
1619 if input_cols.is_empty() {
1620 Ok(RirNode::Diff {
1624 left: Box::new(input),
1625 right: Box::new(neg_filtered),
1626 })
1627 } else {
1628 let neg_projected = if neg_cols.len() < neg_atom.terms.len() {
1630 let neg_proj_exprs: Vec<ProjectExpr> =
1631 neg_cols.iter().map(|&c| ProjectExpr::Column(c)).collect();
1632 RirNode::Project {
1633 input: Box::new(neg_filtered),
1634 columns: neg_proj_exprs,
1635 }
1636 } else {
1637 neg_filtered
1638 };
1639
1640 let input_proj_exprs: Vec<ProjectExpr> =
1648 input_cols.iter().map(|&c| ProjectExpr::Column(c)).collect();
1649 let input_projected = RirNode::Project {
1650 input: Box::new(input.clone()),
1651 columns: input_proj_exprs,
1652 };
1653
1654 let kept_keys = RirNode::Diff {
1656 left: Box::new(input_projected),
1657 right: Box::new(neg_projected),
1658 };
1659
1660 Ok(RirNode::Join {
1664 left: Box::new(input),
1665 right: Box::new(kept_keys),
1666 left_keys: input_cols.clone(),
1667 right_keys: (0..input_cols.len()).collect(),
1668 join_type: JoinType::Semi,
1669 })
1670 }
1671 }
1672
1673 fn is_identity_projection(proj: &[ProjectExpr], input_cols: usize) -> bool {
1674 if proj.len() != input_cols {
1675 return false;
1676 }
1677 proj.iter()
1678 .enumerate()
1679 .all(|(i, e)| matches!(e, ProjectExpr::Column(c) if *c == i))
1680 }
1681
1682 fn compute_head_projection(
1688 &self,
1689 head: &Atom,
1690 var_env: &VariableEnv,
1691 ) -> Result<Vec<ProjectExpr>> {
1692 let mut cols = Vec::with_capacity(head.terms.len());
1693
1694 for term in &head.terms {
1695 match term {
1696 Term::Variable(name) => {
1697 let col = var_env
1698 .get_column(name)
1699 .ok_or_else(|| XlogError::UnsafeVariable(name.clone()))?;
1700 cols.push(ProjectExpr::Column(col));
1701 }
1702 Term::Anonymous => {
1703 return Err(XlogError::Compilation(
1704 "Anonymous wildcard '_' not allowed in rule head".to_string(),
1705 ));
1706 }
1707 Term::Aggregate(_) => {
1708 return Err(XlogError::Compilation(
1709 "Aggregate term in non-aggregate rule head".to_string(),
1710 ));
1711 }
1712 Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
1713 let (expr, typ) = term_to_project_const_expr(term)?;
1714 cols.push(ProjectExpr::Computed(expr, typ));
1715 }
1716 Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
1717 return Err(v085_term_not_lowerable(
1718 "rule head projection",
1719 v085_term_kind(term),
1720 ));
1721 }
1722 }
1723 }
1724
1725 Ok(cols)
1726 }
1727
1728 fn lower_aggregate_rule(
1730 &mut self,
1731 head: &Atom,
1732 body: RirNode,
1733 var_env: &VariableEnv,
1734 ) -> Result<RirNode> {
1735 let mut key_vars: Vec<String> = Vec::new();
1737 let mut key_var_to_pos: HashMap<String, usize> = HashMap::new();
1738 let mut key_src_cols: Vec<usize> = Vec::new();
1739
1740 let mut agg_specs: Vec<(AggOp, String)> = Vec::new();
1742 let mut agg_to_pos: HashMap<(AggOp, String), usize> = HashMap::new();
1743 let mut value_vars: Vec<String> = Vec::new();
1744 let mut value_var_to_pos: HashMap<String, usize> = HashMap::new();
1745 let mut value_src_cols: Vec<usize> = Vec::new();
1746
1747 for term in &head.terms {
1748 match term {
1749 Term::Variable(name) => {
1750 if !key_var_to_pos.contains_key(name) {
1751 let col = var_env
1752 .get_column(name)
1753 .ok_or_else(|| XlogError::UnsafeVariable(name.clone()))?;
1754 let pos = key_vars.len();
1755 key_vars.push(name.clone());
1756 key_var_to_pos.insert(name.clone(), pos);
1757 key_src_cols.push(col);
1758 }
1759 }
1760 Term::Aggregate(agg) => {
1761 let key = (agg.op, agg.variable.clone());
1762 if let std::collections::hash_map::Entry::Vacant(entry) = agg_to_pos.entry(key)
1763 {
1764 let col = var_env
1766 .get_column(&agg.variable)
1767 .ok_or_else(|| XlogError::UnsafeVariable(agg.variable.clone()))?;
1768
1769 let value_pos = *value_var_to_pos
1771 .entry(agg.variable.clone())
1772 .or_insert_with(|| {
1773 let p = value_vars.len();
1774 value_vars.push(agg.variable.clone());
1775 value_src_cols.push(col);
1776 p
1777 });
1778
1779 let agg_pos = agg_specs.len();
1780 agg_specs.push((agg.op, agg.variable.clone()));
1781 entry.insert(agg_pos);
1782
1783 let _ = value_pos;
1785 }
1786 }
1787 Term::Anonymous => {
1788 return Err(XlogError::Compilation(
1789 "Anonymous wildcard '_' not allowed in rule head".to_string(),
1790 ));
1791 }
1792 Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
1793 }
1795 Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
1796 return Err(v085_term_not_lowerable(
1797 "aggregate rule head",
1798 v085_term_kind(term),
1799 ));
1800 }
1801 }
1802 }
1803
1804 if agg_specs.is_empty() {
1805 return Err(XlogError::Compilation(
1806 "Rule marked as aggregate but no aggregate terms found".to_string(),
1807 ));
1808 }
1809
1810 let mut group_input_cols: Vec<ProjectExpr> = Vec::new();
1813 let mut key_cols: Vec<usize> = Vec::new();
1814
1815 if key_src_cols.is_empty() {
1816 group_input_cols.push(ProjectExpr::Computed(
1817 Expr::Const(ConstValue::U32(0)),
1818 ScalarType::U32,
1819 ));
1820 key_cols.push(0);
1821 } else {
1822 for (i, &col) in key_src_cols.iter().enumerate() {
1823 group_input_cols.push(ProjectExpr::Column(col));
1824 key_cols.push(i);
1825 }
1826 }
1827
1828 let value_offset = group_input_cols.len();
1829 for &col in &value_src_cols {
1830 group_input_cols.push(ProjectExpr::Column(col));
1831 }
1832
1833 let group_input = RirNode::Project {
1834 input: Box::new(body),
1835 columns: group_input_cols,
1836 };
1837
1838 let mut aggs: Vec<(usize, CoreAggOp)> = Vec::with_capacity(agg_specs.len());
1840 for (op, var) in &agg_specs {
1841 let value_pos = *value_var_to_pos
1842 .get(var)
1843 .ok_or_else(|| XlogError::UnsafeVariable(var.clone()))?;
1844 let value_col = value_offset + value_pos;
1845 aggs.push((value_col, convert_agg_op(op)));
1846 }
1847
1848 let groupby = RirNode::GroupBy {
1849 input: Box::new(group_input),
1850 key_cols,
1851 aggs,
1852 };
1853
1854 let key_count = if key_src_cols.is_empty() {
1859 1
1860 } else {
1861 key_vars.len()
1862 };
1863
1864 let mut final_proj: Vec<ProjectExpr> = Vec::with_capacity(head.terms.len());
1865 for term in &head.terms {
1866 match term {
1867 Term::Variable(name) => {
1868 let idx = if key_src_cols.is_empty() {
1869 return Err(XlogError::UnsafeVariable(name.clone()));
1872 } else {
1873 *key_var_to_pos
1874 .get(name)
1875 .ok_or_else(|| XlogError::UnsafeVariable(name.clone()))?
1876 };
1877 final_proj.push(ProjectExpr::Column(idx));
1878 }
1879 Term::Aggregate(agg) => {
1880 let pos = *agg_to_pos
1881 .get(&(agg.op, agg.variable.clone()))
1882 .ok_or_else(|| XlogError::UnsafeVariable(agg.variable.clone()))?;
1883 final_proj.push(ProjectExpr::Column(key_count + pos));
1884 }
1885 Term::Anonymous => {
1886 return Err(XlogError::Compilation(
1887 "Anonymous wildcard '_' not allowed in rule head".to_string(),
1888 ));
1889 }
1890 Term::Integer(_) | Term::Float(_) | Term::String(_) | Term::Symbol(_) => {
1891 let (expr, typ) = term_to_project_const_expr(term)?;
1892 final_proj.push(ProjectExpr::Computed(expr, typ));
1893 }
1894 Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
1895 return Err(v085_term_not_lowerable(
1896 "aggregate rule projection",
1897 v085_term_kind(term),
1898 ));
1899 }
1900 }
1901 }
1902
1903 if final_proj.is_empty() {
1904 return Err(XlogError::Compilation(
1905 "Aggregate rule produced empty head projection".to_string(),
1906 ));
1907 }
1908
1909 Ok(RirNode::Project {
1910 input: Box::new(groupby),
1911 columns: final_proj,
1912 })
1913 }
1914
1915 pub(crate) fn infer_arith_type(
1917 &self,
1918 expr: &ArithExpr,
1919 var_env: &VariableEnv,
1920 ) -> Result<ScalarType> {
1921 match expr {
1922 ArithExpr::Variable(name) => var_env.get_type(name).ok_or_else(|| {
1923 XlogError::Compilation(format!("Unknown variable {} in arithmetic", name))
1924 }),
1925 ArithExpr::Integer(_) => Ok(ScalarType::I64),
1926 ArithExpr::Float(_) => Ok(ScalarType::F64),
1927
1928 ArithExpr::Add(l, r)
1929 | ArithExpr::Sub(l, r)
1930 | ArithExpr::Mul(l, r)
1931 | ArithExpr::Div(l, r) => {
1932 let lt = self.infer_arith_type(l, var_env)?;
1933 let rt = self.infer_arith_type(r, var_env)?;
1934
1935 if lt != rt {
1936 return Err(XlogError::Compilation(format!(
1937 "Type mismatch in arithmetic: {:?} vs {:?}. Use cast() for conversion.",
1938 lt, rt
1939 )));
1940 }
1941
1942 if !Self::is_numeric_type(<) {
1943 return Err(XlogError::Compilation(format!(
1944 "Arithmetic requires numeric type, got {:?}",
1945 lt
1946 )));
1947 }
1948
1949 Ok(lt)
1950 }
1951
1952 ArithExpr::Mod(l, r) => {
1953 let lt = self.infer_arith_type(l, var_env)?;
1954 let rt = self.infer_arith_type(r, var_env)?;
1955
1956 if lt != rt {
1957 return Err(XlogError::Compilation(format!(
1958 "Type mismatch in mod: {:?} vs {:?}",
1959 lt, rt
1960 )));
1961 }
1962
1963 if matches!(lt, ScalarType::F32 | ScalarType::F64) {
1964 return Err(XlogError::Compilation(
1965 "Modulo (%) not supported for floating point".into(),
1966 ));
1967 }
1968
1969 Ok(lt)
1970 }
1971
1972 ArithExpr::Abs(inner) => {
1973 let t = self.infer_arith_type(inner, var_env)?;
1974 if !Self::is_numeric_type(&t) {
1975 return Err(XlogError::Compilation(format!(
1976 "abs requires numeric type, got {:?}",
1977 t
1978 )));
1979 }
1980 Ok(t)
1981 }
1982
1983 ArithExpr::Min(l, r) | ArithExpr::Max(l, r) => {
1984 let lt = self.infer_arith_type(l, var_env)?;
1985 let rt = self.infer_arith_type(r, var_env)?;
1986
1987 if lt != rt {
1988 return Err(XlogError::Compilation(format!(
1989 "Type mismatch in min/max: {:?} vs {:?}",
1990 lt, rt
1991 )));
1992 }
1993
1994 if !Self::is_numeric_type(<) {
1995 return Err(XlogError::Compilation(format!(
1996 "min/max requires numeric type, got {:?}",
1997 lt
1998 )));
1999 }
2000
2001 Ok(lt)
2002 }
2003
2004 ArithExpr::Pow(base, exp) => {
2005 let base_t = self.infer_arith_type(base, var_env)?;
2006 let exp_t = self.infer_arith_type(exp, var_env)?;
2007
2008 if !Self::is_numeric_type(&base_t) || !Self::is_numeric_type(&exp_t) {
2009 return Err(XlogError::Compilation(format!(
2010 "pow requires numeric operands, got {:?} and {:?}",
2011 base_t, exp_t
2012 )));
2013 }
2014
2015 Ok(ScalarType::F64)
2017 }
2018
2019 ArithExpr::Cast(_, target) => Ok(*target),
2020
2021 ArithExpr::FuncCall { name, .. } => Err(XlogError::Compilation(format!(
2022 "User-defined function '{}' must be inlined before lowering",
2023 name
2024 ))),
2025
2026 ArithExpr::Conditional {
2027 then_expr,
2028 else_expr,
2029 ..
2030 } => {
2031 let then_type = self.infer_arith_type(then_expr, var_env)?;
2033 let else_type = self.infer_arith_type(else_expr, var_env)?;
2034 if then_type != else_type {
2035 return Err(XlogError::Compilation(format!(
2036 "Conditional branches have different types: {:?} vs {:?}",
2037 then_type, else_type
2038 )));
2039 }
2040 Ok(then_type)
2041 }
2042 }
2043 }
2044
2045 fn is_numeric_type(t: &ScalarType) -> bool {
2046 matches!(
2047 t,
2048 ScalarType::I32
2049 | ScalarType::I64
2050 | ScalarType::U32
2051 | ScalarType::U64
2052 | ScalarType::F32
2053 | ScalarType::F64
2054 )
2055 }
2056
2057 fn arith_to_expr(&self, arith: &ArithExpr, var_env: &VariableEnv) -> Result<Expr> {
2059 match arith {
2060 ArithExpr::Variable(name) => {
2061 let col = var_env.get_column(name).ok_or_else(|| {
2062 XlogError::Compilation(format!(
2063 "Variable {} not bound before use in arithmetic",
2064 name
2065 ))
2066 })?;
2067 Ok(Expr::Column(col))
2068 }
2069 ArithExpr::Integer(i) => Ok(Expr::Const(ConstValue::I64(*i))),
2070 ArithExpr::Float(f) => Ok(Expr::Const(ConstValue::F64(*f))),
2071
2072 ArithExpr::Add(l, r) => Ok(Expr::Add(
2073 Box::new(self.arith_to_expr(l, var_env)?),
2074 Box::new(self.arith_to_expr(r, var_env)?),
2075 )),
2076 ArithExpr::Sub(l, r) => Ok(Expr::Sub(
2077 Box::new(self.arith_to_expr(l, var_env)?),
2078 Box::new(self.arith_to_expr(r, var_env)?),
2079 )),
2080 ArithExpr::Mul(l, r) => Ok(Expr::Mul(
2081 Box::new(self.arith_to_expr(l, var_env)?),
2082 Box::new(self.arith_to_expr(r, var_env)?),
2083 )),
2084 ArithExpr::Div(l, r) => Ok(Expr::Div(
2085 Box::new(self.arith_to_expr(l, var_env)?),
2086 Box::new(self.arith_to_expr(r, var_env)?),
2087 )),
2088 ArithExpr::Mod(l, r) => Ok(Expr::Mod(
2089 Box::new(self.arith_to_expr(l, var_env)?),
2090 Box::new(self.arith_to_expr(r, var_env)?),
2091 )),
2092
2093 ArithExpr::Abs(e) => Ok(Expr::Abs(Box::new(self.arith_to_expr(e, var_env)?))),
2094 ArithExpr::Min(l, r) => Ok(Expr::Min(
2095 Box::new(self.arith_to_expr(l, var_env)?),
2096 Box::new(self.arith_to_expr(r, var_env)?),
2097 )),
2098 ArithExpr::Max(l, r) => Ok(Expr::Max(
2099 Box::new(self.arith_to_expr(l, var_env)?),
2100 Box::new(self.arith_to_expr(r, var_env)?),
2101 )),
2102 ArithExpr::Pow(l, r) => Ok(Expr::Pow(
2103 Box::new(self.arith_to_expr(l, var_env)?),
2104 Box::new(self.arith_to_expr(r, var_env)?),
2105 )),
2106 ArithExpr::Cast(e, t) => Ok(Expr::Cast(Box::new(self.arith_to_expr(e, var_env)?), *t)),
2107
2108 ArithExpr::FuncCall { name, .. } => Err(XlogError::Compilation(format!(
2109 "User-defined function '{}' must be inlined before lowering",
2110 name
2111 ))),
2112
2113 ArithExpr::Conditional {
2114 cond_left,
2115 cond_op,
2116 cond_right,
2117 then_expr,
2118 else_expr,
2119 } => {
2120 let ir_cond_op = match cond_op {
2122 CompOp::Eq => CompareOp::Eq,
2123 CompOp::Ne => CompareOp::Ne,
2124 CompOp::Lt => CompareOp::Lt,
2125 CompOp::Le => CompareOp::Le,
2126 CompOp::Gt => CompareOp::Gt,
2127 CompOp::Ge => CompareOp::Ge,
2128 };
2129
2130 let condition = Expr::Compare {
2132 left: Box::new(self.arith_to_expr(cond_left, var_env)?),
2133 op: ir_cond_op,
2134 right: Box::new(self.arith_to_expr(cond_right, var_env)?),
2135 };
2136
2137 let then_ir = self.arith_to_expr(then_expr, var_env)?;
2139 let else_ir = self.arith_to_expr(else_expr, var_env)?;
2140
2141 Ok(Expr::Conditional {
2142 condition: Box::new(condition),
2143 then_expr: Box::new(then_ir),
2144 else_expr: Box::new(else_ir),
2145 })
2146 }
2147 }
2148 }
2149
2150 fn lower_is_expr(
2152 &mut self,
2153 is_expr: &IsExpr,
2154 input: RirNode,
2155 var_env: &mut VariableEnv,
2156 ) -> Result<RirNode> {
2157 if var_env.contains(&is_expr.target) {
2159 return Err(XlogError::Compilation(format!(
2160 "Variable {} already bound; 'is' requires fresh variable",
2161 is_expr.target
2162 )));
2163 }
2164
2165 for var in is_expr.expr.variables() {
2167 if !var_env.contains(var) {
2168 return Err(XlogError::Compilation(format!(
2169 "Variable {} used in arithmetic but not bound",
2170 var
2171 )));
2172 }
2173 }
2174
2175 let result_type = self.infer_arith_type(&is_expr.expr, var_env)?;
2177
2178 let ir_expr = self.arith_to_expr(&is_expr.expr, var_env)?;
2180
2181 let num_cols = var_env.column_count();
2183 let mut proj_exprs: Vec<ProjectExpr> = (0..num_cols).map(ProjectExpr::Column).collect();
2184 proj_exprs.push(ProjectExpr::Computed(ir_expr, result_type));
2185
2186 var_env.bind(&is_expr.target, num_cols, result_type);
2188
2189 Ok(RirNode::Project {
2190 input: Box::new(input),
2191 columns: proj_exprs,
2192 })
2193 }
2194}
2195
2196pub(crate) struct VariableEnv {
2198 occurrences: HashMap<String, Vec<(String, usize, usize)>>,
2200 total_cols: usize,
2202 types: HashMap<String, ScalarType>,
2204}
2205
2206impl VariableEnv {
2207 fn new() -> Self {
2208 Self {
2209 occurrences: HashMap::new(),
2210 total_cols: 0,
2211 types: HashMap::new(),
2212 }
2213 }
2214
2215 fn add_occurrence(&mut self, var: &str, pred: String, atom_pos: usize, global_col: usize) {
2216 self.occurrences
2217 .entry(var.to_string())
2218 .or_default()
2219 .push((pred, atom_pos, global_col));
2220 }
2221
2222 fn get_column(&self, var: &str) -> Option<usize> {
2223 self.occurrences
2224 .get(var)
2225 .and_then(|occs| occs.first())
2226 .map(|(_, _, col)| *col)
2227 }
2228
2229 fn bind(&mut self, name: &str, column: usize, typ: ScalarType) {
2231 self.types.insert(name.to_string(), typ);
2232 self.occurrences
2234 .entry(name.to_string())
2235 .or_default()
2236 .push(("".to_string(), 0, column));
2237 if column >= self.total_cols {
2240 self.total_cols = column + 1;
2241 }
2242 }
2243
2244 fn get_type(&self, name: &str) -> Option<ScalarType> {
2246 self.types.get(name).copied()
2247 }
2248
2249 fn contains(&self, name: &str) -> bool {
2251 self.occurrences.contains_key(name)
2252 }
2253
2254 fn column_count(&self) -> usize {
2256 self.total_cols
2257 }
2258}
2259
2260fn infer_term_type(term: &Term) -> ScalarType {
2262 match term {
2263 Term::Variable(_) | Term::Anonymous => ScalarType::U64, Term::Integer(i) => {
2265 if *i >= 0 && *i <= u32::MAX as i64 {
2266 ScalarType::U32
2267 } else {
2268 ScalarType::I64
2269 }
2270 }
2271 Term::Float(_) => ScalarType::F64,
2272 Term::String(_) | Term::Symbol(_) => ScalarType::Symbol,
2273 Term::List(_) | Term::Cons { .. } | Term::Compound { .. } | Term::PredRef(_) => {
2274 ScalarType::U64
2275 }
2276 Term::Aggregate(agg) => match agg.op {
2277 AggOp::Count => ScalarType::U32,
2278 AggOp::Sum => ScalarType::U64,
2279 AggOp::Min | AggOp::Max => ScalarType::U32,
2280 AggOp::LogSumExp => ScalarType::F64,
2281 },
2282 }
2283}
2284
2285fn sort_labels_from_terms(terms: &[Term]) -> Vec<String> {
2286 terms
2287 .iter()
2288 .enumerate()
2289 .map(|(idx, term)| match term {
2290 Term::Variable(name) if !name.trim().is_empty() => name.clone(),
2291 Term::Aggregate(agg) => format!("{:?}_{}", agg.op, agg.variable),
2292 Term::List(_) => format!("list{}", idx),
2293 Term::Cons { .. } => format!("cons{}", idx),
2294 Term::Compound { functor, .. } => functor.clone(),
2295 Term::PredRef(name) => name.clone(),
2296 _ => format!("c{}", idx),
2297 })
2298 .collect()
2299}
2300
2301fn term_to_const_value(term: &Term) -> Option<ConstValue> {
2303 match term {
2304 Term::Integer(i) => Some(ConstValue::I64(*i)),
2305 Term::Float(f) => Some(ConstValue::F64(*f)),
2306 Term::String(s) => Some(ConstValue::Symbol(s.clone())),
2307 Term::Symbol(id) => Some(ConstValue::Symbol(symbol::resolve(*id))),
2308 Term::Variable(_)
2309 | Term::Anonymous
2310 | Term::Aggregate(_)
2311 | Term::List(_)
2312 | Term::Cons { .. }
2313 | Term::Compound { .. }
2314 | Term::PredRef(_) => None,
2315 }
2316}
2317
2318fn term_to_typed_const_value(term: &Term, expected: ScalarType) -> Result<Option<ConstValue>> {
2319 let const_val = match term {
2320 Term::Integer(i) => match expected {
2321 ScalarType::U32 => {
2322 if *i >= 0 && *i <= u32::MAX as i64 {
2323 ConstValue::U32(*i as u32)
2324 } else {
2325 return Err(XlogError::Compilation(format!(
2326 "Integer literal {} out of range for {:?}",
2327 i, expected
2328 )));
2329 }
2330 }
2331 ScalarType::U64 => {
2332 if *i >= 0 {
2333 ConstValue::U64(*i as u64)
2334 } else {
2335 return Err(XlogError::Compilation(format!(
2336 "Integer literal {} out of range for {:?}",
2337 i, expected
2338 )));
2339 }
2340 }
2341 ScalarType::I32 => {
2342 if *i >= i32::MIN as i64 && *i <= i32::MAX as i64 {
2343 ConstValue::I32(*i as i32)
2344 } else {
2345 return Err(XlogError::Compilation(format!(
2346 "Integer literal {} out of range for {:?}",
2347 i, expected
2348 )));
2349 }
2350 }
2351 ScalarType::I64 => ConstValue::I64(*i),
2352 ScalarType::F32 => {
2353 let value = *i as f64;
2354 if value < f32::MIN as f64 || value > f32::MAX as f64 {
2355 return Err(XlogError::Compilation(format!(
2356 "Integer literal {} out of range for {:?}",
2357 i, expected
2358 )));
2359 }
2360 ConstValue::F32(value as f32)
2361 }
2362 ScalarType::F64 => ConstValue::F64(*i as f64),
2363 ScalarType::Bool => {
2364 if *i == 0 || *i == 1 {
2365 ConstValue::Bool(*i == 1)
2366 } else {
2367 return Err(XlogError::Compilation(format!(
2368 "Integer literal {} not valid for {:?}",
2369 i, expected
2370 )));
2371 }
2372 }
2373 ScalarType::Symbol => {
2374 return Err(XlogError::Compilation(format!(
2375 "Integer literal {} not valid for {:?}",
2376 i, expected
2377 )));
2378 }
2379 },
2380 Term::Float(f) => match expected {
2381 ScalarType::F32 => {
2382 if !f.is_finite() {
2383 return Err(XlogError::Compilation(format!(
2384 "Float literal {} not valid for {:?}",
2385 f, expected
2386 )));
2387 }
2388 if *f < f32::MIN as f64 || *f > f32::MAX as f64 {
2389 return Err(XlogError::Compilation(format!(
2390 "Float literal {} out of range for {:?}",
2391 f, expected
2392 )));
2393 }
2394 ConstValue::F32(*f as f32)
2395 }
2396 ScalarType::F64 => ConstValue::F64(*f),
2397 ScalarType::U32
2398 | ScalarType::U64
2399 | ScalarType::I32
2400 | ScalarType::I64
2401 | ScalarType::Bool
2402 | ScalarType::Symbol => {
2403 return Err(XlogError::Compilation(format!(
2404 "Float literal {} not valid for {:?}",
2405 f, expected
2406 )));
2407 }
2408 },
2409 Term::String(s) => {
2410 if expected == ScalarType::Symbol {
2411 ConstValue::Symbol(s.clone())
2412 } else {
2413 return Err(XlogError::Compilation(format!(
2414 "String literal {} not valid for {:?}",
2415 s, expected
2416 )));
2417 }
2418 }
2419 Term::Symbol(id) => {
2420 if expected == ScalarType::Symbol {
2421 ConstValue::Symbol(symbol::resolve(*id))
2422 } else {
2423 return Err(XlogError::Compilation(format!(
2424 "Symbol literal {} not valid for {:?}",
2425 symbol::resolve(*id),
2426 expected
2427 )));
2428 }
2429 }
2430 Term::Variable(_)
2431 | Term::Anonymous
2432 | Term::Aggregate(_)
2433 | Term::List(_)
2434 | Term::Cons { .. }
2435 | Term::Compound { .. }
2436 | Term::PredRef(_) => return Ok(None),
2437 };
2438
2439 Ok(Some(const_val))
2440}
2441
2442fn term_to_project_const_expr(term: &Term) -> Result<(Expr, ScalarType)> {
2443 match term {
2444 Term::Integer(i) => {
2445 if *i >= 0 && *i <= u32::MAX as i64 {
2446 Ok((Expr::Const(ConstValue::U32(*i as u32)), ScalarType::U32))
2447 } else {
2448 Ok((Expr::Const(ConstValue::I64(*i)), ScalarType::I64))
2449 }
2450 }
2451 Term::Float(f) => Ok((Expr::Const(ConstValue::F64(*f)), ScalarType::F64)),
2452 Term::String(s) => Ok((
2453 Expr::Const(ConstValue::Symbol(s.clone())),
2454 ScalarType::Symbol,
2455 )),
2456 Term::Symbol(id) => Ok((
2457 Expr::Const(ConstValue::Symbol(symbol::resolve(*id))),
2458 ScalarType::Symbol,
2459 )),
2460 Term::Variable(_)
2461 | Term::Anonymous
2462 | Term::Aggregate(_)
2463 | Term::List(_)
2464 | Term::Cons { .. }
2465 | Term::Compound { .. }
2466 | Term::PredRef(_) => Err(XlogError::Compilation("Expected constant term".to_string())),
2467 }
2468}
2469
2470fn convert_agg_op(op: &AggOp) -> CoreAggOp {
2472 match op {
2473 AggOp::Count => CoreAggOp::Count,
2474 AggOp::Sum => CoreAggOp::Sum,
2475 AggOp::Min => CoreAggOp::Min,
2476 AggOp::Max => CoreAggOp::Max,
2477 AggOp::LogSumExp => CoreAggOp::LogSumExp,
2478 }
2479}
2480
#[cfg(test)]
mod arith_type_tests {
    use super::*;
    use crate::ast::ArithExpr;

    /// Builds the arithmetic expression `X + Y`.
    fn x_plus_y() -> ArithExpr {
        let lhs = ArithExpr::Variable("X".to_string());
        let rhs = ArithExpr::Variable("Y".to_string());
        ArithExpr::Add(Box::new(lhs), Box::new(rhs))
    }

    /// Creates an environment binding `X` to column 0 and `Y` to column 1
    /// with the given scalar types.
    fn env_for(x_ty: ScalarType, y_ty: ScalarType) -> VariableEnv {
        let mut env = VariableEnv::new();
        env.bind("X", 0, x_ty);
        env.bind("Y", 1, y_ty);
        env
    }

    /// Adding two `I64` variables must infer `I64`.
    #[test]
    fn test_arith_type_inference_same_type() {
        let lowerer = Lowerer::new();
        let env = env_for(ScalarType::I64, ScalarType::I64);

        let inferred = lowerer.infer_arith_type(&x_plus_y(), &env);
        assert!(inferred.is_ok());
        assert_eq!(inferred.unwrap(), ScalarType::I64);
    }

    /// Adding an `I64` variable to an `F64` variable must be rejected.
    #[test]
    fn test_arith_type_inference_mismatch() {
        let lowerer = Lowerer::new();
        let env = env_for(ScalarType::I64, ScalarType::F64);

        let inferred = lowerer.infer_arith_type(&x_plus_y(), &env);
        assert!(inferred.is_err());
    }
}
2522
2523#[cfg(test)]
2524mod tests {
2525 use super::*;
2526 use crate::ast::*;
2527
2528 fn pred_decl(name: &str, types: Vec<ScalarType>) -> PredDecl {
2529 let type_refs: Vec<TypeRef> = types.into_iter().map(TypeRef::Scalar).collect();
2530 let columns = type_refs
2531 .iter()
2532 .cloned()
2533 .map(|typ| PredColumn { name: None, typ })
2534 .collect();
2535 PredDecl {
2536 name: name.to_string(),
2537 types: type_refs,
2538 columns,
2539 is_private: false,
2540 }
2541 }
2542
2543 fn edge_atom(x: &str, y: &str) -> Atom {
2545 Atom {
2546 predicate: "edge".to_string(),
2547 terms: vec![Term::Variable(x.to_string()), Term::Variable(y.to_string())],
2548 }
2549 }
2550
2551 fn reach_atom(x: &str, y: &str) -> Atom {
2553 Atom {
2554 predicate: "reach".to_string(),
2555 terms: vec![Term::Variable(x.to_string()), Term::Variable(y.to_string())],
2556 }
2557 }
2558
2559 fn node_atom(x: &str) -> Atom {
2561 Atom {
2562 predicate: "node".to_string(),
2563 terms: vec![Term::Variable(x.to_string())],
2564 }
2565 }
2566
2567 #[test]
2568 fn test_lowerer_new() {
2569 let lowerer = Lowerer::new();
2570 assert!(lowerer.schemas.is_empty());
2571 assert!(lowerer.strata.is_empty());
2572 assert_eq!(lowerer.next_rel_id, 0);
2573 }
2574
2575 #[test]
2576 fn test_get_or_create_rel_id() {
2577 let mut lowerer = Lowerer::new();
2578 let id1 = lowerer.get_or_create_rel_id("edge");
2579 let id2 = lowerer.get_or_create_rel_id("reach");
2580 let id3 = lowerer.get_or_create_rel_id("edge");
2581
2582 assert_eq!(id1, RelId(0));
2583 assert_eq!(id2, RelId(1));
2584 assert_eq!(id3, RelId(0)); }
2586
2587 #[test]
2588 fn test_infer_schemas_from_facts() {
2589 let mut program = Program::new();
2590 program.rules.push(Rule {
2591 head: Atom {
2592 predicate: "edge".to_string(),
2593 terms: vec![Term::Integer(1), Term::Integer(2)],
2594 },
2595 body: vec![],
2596 });
2597
2598 let mut lowerer = Lowerer::new();
2599 lowerer.infer_schemas(&program).unwrap();
2600
2601 assert!(lowerer.schemas.contains_key("edge"));
2602 let schema = lowerer.schemas.get("edge").unwrap();
2603 assert_eq!(schema.arity(), 2);
2604 }
2605
2606 #[test]
2607 fn test_lower_simple_rule() {
2608 let rule = Rule {
2610 head: reach_atom("X", "Y"),
2611 body: vec![BodyLiteral::Positive(edge_atom("X", "Y"))],
2612 };
2613
2614 let mut lowerer = Lowerer::new();
2615 lowerer.schemas.insert(
2616 "edge".to_string(),
2617 Schema::new(vec![
2618 ("c0".to_string(), ScalarType::U32),
2619 ("c1".to_string(), ScalarType::U32),
2620 ]),
2621 );
2622
2623 let result = lowerer.lower_rule(&rule);
2624 assert!(result.is_ok());
2625
2626 let node = result.unwrap();
2627 assert!(matches!(node, RirNode::Scan { .. }));
2629 }
2630
2631 #[test]
2632 fn test_lower_join_rule() {
2633 let rule = Rule {
2635 head: Atom {
2636 predicate: "reach".to_string(),
2637 terms: vec![
2638 Term::Variable("X".to_string()),
2639 Term::Variable("Z".to_string()),
2640 ],
2641 },
2642 body: vec![
2643 BodyLiteral::Positive(reach_atom("X", "Y")),
2644 BodyLiteral::Positive(edge_atom("Y", "Z")),
2645 ],
2646 };
2647
2648 let mut lowerer = Lowerer::new();
2649 lowerer.schemas.insert(
2650 "reach".to_string(),
2651 Schema::new(vec![
2652 ("c0".to_string(), ScalarType::U32),
2653 ("c1".to_string(), ScalarType::U32),
2654 ]),
2655 );
2656 lowerer.schemas.insert(
2657 "edge".to_string(),
2658 Schema::new(vec![
2659 ("c0".to_string(), ScalarType::U32),
2660 ("c1".to_string(), ScalarType::U32),
2661 ]),
2662 );
2663
2664 let result = lowerer.lower_rule(&rule);
2665 assert!(result.is_ok());
2666
2667 let node = result.unwrap();
2668 if let RirNode::Project { input, columns } = node {
2670 assert_eq!(
2672 columns,
2673 vec![ProjectExpr::Column(0), ProjectExpr::Column(3)]
2674 );
2675 assert!(matches!(*input, RirNode::Join { .. }));
2676 if let RirNode::Join {
2677 left_keys,
2678 right_keys,
2679 ..
2680 } = *input
2681 {
2682 assert_eq!(left_keys, vec![1]); assert_eq!(right_keys, vec![0]); }
2685 } else {
2686 panic!("Expected Project node");
2687 }
2688 }
2689
2690 #[test]
2691 fn test_join_order_prefers_smaller_relation() {
2692 let rule = Rule {
2694 head: Atom {
2695 predicate: "out".to_string(),
2696 terms: vec![Term::Variable("X".to_string())],
2697 },
2698 body: vec![
2699 BodyLiteral::Positive(Atom {
2700 predicate: "big".to_string(),
2701 terms: vec![Term::Variable("X".to_string())],
2702 }),
2703 BodyLiteral::Positive(Atom {
2704 predicate: "small".to_string(),
2705 terms: vec![Term::Variable("X".to_string())],
2706 }),
2707 ],
2708 };
2709
2710 let mut lowerer = Lowerer::new();
2711 lowerer.schemas.insert(
2712 "big".to_string(),
2713 Schema::new(vec![("c0".to_string(), ScalarType::U32)]),
2714 );
2715 lowerer.schemas.insert(
2716 "small".to_string(),
2717 Schema::new(vec![("c0".to_string(), ScalarType::U32)]),
2718 );
2719
2720 let big_id = lowerer.get_or_create_rel_id("big");
2722 let small_id = lowerer.get_or_create_rel_id("small");
2723 assert_eq!(big_id, RelId(0));
2724 assert_eq!(small_id, RelId(1));
2725
2726 lowerer.est_cardinality.insert("big".to_string(), 10_000);
2728 lowerer.est_cardinality.insert("small".to_string(), 10);
2729
2730 let node = lowerer.lower_rule(&rule).unwrap();
2731 let join = match node {
2732 RirNode::Project { input, .. } => *input,
2733 other => other,
2734 };
2735
2736 match join {
2737 RirNode::Join { left, right, .. } => {
2738 assert!(matches!(*left, RirNode::Scan { rel } if rel == big_id));
2740 assert!(matches!(*right, RirNode::Scan { rel } if rel == small_id));
2741 }
2742 other => panic!("Expected Join node, got {:?}", other),
2743 }
2744 }
2745
2746 #[test]
2747 fn test_lower_negation() {
2748 let rule = Rule {
2750 head: Atom {
2751 predicate: "isolated".to_string(),
2752 terms: vec![Term::Variable("X".to_string())],
2753 },
2754 body: vec![
2755 BodyLiteral::Positive(node_atom("X")),
2756 BodyLiteral::Negated(Atom {
2757 predicate: "edge".to_string(),
2758 terms: vec![
2759 Term::Variable("X".to_string()),
2760 Term::Variable("_".to_string()),
2761 ],
2762 }),
2763 ],
2764 };
2765
2766 let mut lowerer = Lowerer::new();
2767 lowerer.schemas.insert(
2768 "node".to_string(),
2769 Schema::new(vec![("c0".to_string(), ScalarType::U32)]),
2770 );
2771 lowerer.schemas.insert(
2772 "edge".to_string(),
2773 Schema::new(vec![
2774 ("c0".to_string(), ScalarType::U32),
2775 ("c1".to_string(), ScalarType::U32),
2776 ]),
2777 );
2778
2779 let result = lowerer.lower_rule(&rule);
2780 assert!(result.is_ok());
2781
2782 let node = result.unwrap();
2784 fn contains_diff_or_semi(node: &RirNode) -> bool {
2786 match node {
2787 RirNode::Diff { .. } => true,
2788 RirNode::Join {
2789 join_type: JoinType::Semi,
2790 ..
2791 } => true,
2792 RirNode::Join { left, right, .. } => {
2793 contains_diff_or_semi(left) || contains_diff_or_semi(right)
2794 }
2795 RirNode::Project { input, .. } => contains_diff_or_semi(input),
2796 RirNode::Filter { input, .. } => contains_diff_or_semi(input),
2797 _ => false,
2798 }
2799 }
2800 assert!(contains_diff_or_semi(&node));
2801 }
2802
2803 #[test]
2804 fn test_lower_comparison() {
2805 let rule = Rule {
2807 head: Atom {
2808 predicate: "greater".to_string(),
2809 terms: vec![
2810 Term::Variable("X".to_string()),
2811 Term::Variable("Y".to_string()),
2812 ],
2813 },
2814 body: vec![
2815 BodyLiteral::Positive(Atom {
2816 predicate: "pair".to_string(),
2817 terms: vec![
2818 Term::Variable("X".to_string()),
2819 Term::Variable("Y".to_string()),
2820 ],
2821 }),
2822 BodyLiteral::Comparison(Comparison {
2823 left: Term::Variable("X".to_string()),
2824 op: CompOp::Gt,
2825 right: Term::Variable("Y".to_string()),
2826 }),
2827 ],
2828 };
2829
2830 let mut lowerer = Lowerer::new();
2831 lowerer.schemas.insert(
2832 "pair".to_string(),
2833 Schema::new(vec![
2834 ("c0".to_string(), ScalarType::U32),
2835 ("c1".to_string(), ScalarType::U32),
2836 ]),
2837 );
2838
2839 let result = lowerer.lower_rule(&rule);
2840 assert!(result.is_ok());
2841
2842 let node = result.unwrap();
2843 fn contains_filter(node: &RirNode) -> bool {
2845 match node {
2846 RirNode::Filter { .. } => true,
2847 RirNode::Project { input, .. } => contains_filter(input),
2848 RirNode::Join { left, right, .. } => {
2849 contains_filter(left) || contains_filter(right)
2850 }
2851 _ => false,
2852 }
2853 }
2854 assert!(contains_filter(&node));
2855 }
2856
2857 #[test]
2858 fn test_lower_constant_filter() {
2859 let rule = Rule {
2861 head: Atom {
2862 predicate: "specific_edge".to_string(),
2863 terms: vec![Term::Variable("Y".to_string())],
2864 },
2865 body: vec![BodyLiteral::Positive(Atom {
2866 predicate: "edge".to_string(),
2867 terms: vec![Term::Integer(1), Term::Variable("Y".to_string())],
2868 })],
2869 };
2870
2871 let mut lowerer = Lowerer::new();
2872 lowerer.schemas.insert(
2873 "edge".to_string(),
2874 Schema::new(vec![
2875 ("c0".to_string(), ScalarType::U32),
2876 ("c1".to_string(), ScalarType::U32),
2877 ]),
2878 );
2879
2880 let result = lowerer.lower_rule(&rule);
2881 assert!(result.is_ok());
2882
2883 let node = result.unwrap();
2884 fn has_const_filter(node: &RirNode) -> bool {
2886 match node {
2887 RirNode::Filter {
2888 predicate: Expr::Compare { right, .. },
2889 ..
2890 } => matches!(**right, Expr::Const(_)),
2891 RirNode::Project { input, .. } => has_const_filter(input),
2892 _ => false,
2893 }
2894 }
2895 assert!(has_const_filter(&node));
2896 }
2897
2898 #[test]
2899 fn test_lower_repeated_variable_filter() {
2900 let rule = Rule {
2902 head: Atom {
2903 predicate: "self_loop".to_string(),
2904 terms: vec![Term::Variable("X".to_string())],
2905 },
2906 body: vec![BodyLiteral::Positive(Atom {
2907 predicate: "edge".to_string(),
2908 terms: vec![
2909 Term::Variable("X".to_string()),
2910 Term::Variable("X".to_string()),
2911 ],
2912 })],
2913 };
2914
2915 let mut lowerer = Lowerer::new();
2916 lowerer.schemas.insert(
2917 "edge".to_string(),
2918 Schema::new(vec![
2919 ("c0".to_string(), ScalarType::U32),
2920 ("c1".to_string(), ScalarType::U32),
2921 ]),
2922 );
2923
2924 let node = lowerer.lower_rule(&rule).expect("lower_rule failed");
2925
2926 fn has_col_eq_filter(node: &RirNode) -> bool {
2927 match node {
2928 RirNode::Filter { predicate, .. } => match predicate {
2929 Expr::Compare {
2930 left,
2931 op: CompareOp::Eq,
2932 right,
2933 } => {
2934 matches!((&**left, &**right), (Expr::Column(0), Expr::Column(1)))
2935 || matches!((&**left, &**right), (Expr::Column(1), Expr::Column(0)))
2936 }
2937 Expr::And(exprs) => exprs.iter().any(|e| match e {
2938 Expr::Compare {
2939 left,
2940 op: CompareOp::Eq,
2941 right,
2942 } => {
2943 matches!((&**left, &**right), (Expr::Column(0), Expr::Column(1)))
2944 || matches!((&**left, &**right), (Expr::Column(1), Expr::Column(0)))
2945 }
2946 _ => false,
2947 }),
2948 _ => false,
2949 },
2950 RirNode::Project { input, .. } => has_col_eq_filter(input),
2951 _ => false,
2952 }
2953 }
2954
2955 assert!(has_col_eq_filter(&node));
2956 }
2957
2958 #[test]
2959 fn test_lower_program_simple() {
2960 let mut program = Program::new();
2961
2962 program.rules.push(Rule {
2964 head: Atom {
2965 predicate: "edge".to_string(),
2966 terms: vec![Term::Integer(1), Term::Integer(2)],
2967 },
2968 body: vec![],
2969 });
2970
2971 program.rules.push(Rule {
2973 head: reach_atom("X", "Y"),
2974 body: vec![BodyLiteral::Positive(edge_atom("X", "Y"))],
2975 });
2976
2977 let mut lowerer = Lowerer::new();
2978 lowerer.set_strata(vec![vec!["edge".to_string()], vec!["reach".to_string()]]);
2979
2980 let result = lowerer.lower_program(&program);
2981 assert!(result.is_ok());
2982
2983 let plan = result.unwrap();
2984 assert!(!plan.sccs.is_empty());
2985 }
2986
2987 #[test]
2988 fn test_variable_env() {
2989 let mut env = VariableEnv::new();
2990 env.add_occurrence("X", "edge".to_string(), 0, 0);
2991 env.add_occurrence("Y", "edge".to_string(), 1, 1);
2992 env.add_occurrence("Y", "node".to_string(), 0, 2);
2993
2994 assert_eq!(env.get_column("X"), Some(0));
2995 assert_eq!(env.get_column("Y"), Some(1)); assert_eq!(env.get_column("Z"), None);
2997 }
2998
2999 #[test]
3000 fn test_infer_term_type() {
3001 assert_eq!(
3002 infer_term_type(&Term::Variable("X".to_string())),
3003 ScalarType::U64
3004 );
3005 assert_eq!(infer_term_type(&Term::Integer(42)), ScalarType::U32);
3006 assert_eq!(infer_term_type(&Term::Integer(i64::MAX)), ScalarType::I64);
3007 assert_eq!(infer_term_type(&Term::Float(3.25)), ScalarType::F64);
3008 assert_eq!(
3009 infer_term_type(&Term::Symbol(symbol::intern("foo"))),
3010 ScalarType::Symbol
3011 );
3012 }
3013
3014 #[test]
3015 fn test_convert_agg_op() {
3016 assert_eq!(convert_agg_op(&AggOp::Count), CoreAggOp::Count);
3017 assert_eq!(convert_agg_op(&AggOp::Sum), CoreAggOp::Sum);
3018 assert_eq!(convert_agg_op(&AggOp::Min), CoreAggOp::Min);
3019 assert_eq!(convert_agg_op(&AggOp::Max), CoreAggOp::Max);
3020 assert_eq!(convert_agg_op(&AggOp::LogSumExp), CoreAggOp::LogSumExp);
3021 }
3022
3023 #[test]
3024 fn test_variable_env_bind_updates_total_cols() {
3025 let mut env = VariableEnv::new();
3027 env.total_cols = 2; env.bind("A", 2, ScalarType::I64);
3031 assert_eq!(
3032 env.column_count(),
3033 3,
3034 "total_cols should be 3 after first bind"
3035 );
3036 assert_eq!(env.get_column("A"), Some(2));
3037
3038 env.bind("B", 3, ScalarType::I64);
3040 assert_eq!(
3041 env.column_count(),
3042 4,
3043 "total_cols should be 4 after second bind"
3044 );
3045 assert_eq!(env.get_column("B"), Some(3));
3046 }
3047
3048 #[test]
3049 fn test_lower_chained_is_expressions() {
3050 let rule = Rule {
3053 head: Atom {
3054 predicate: "result".to_string(),
3055 terms: vec![
3056 Term::Variable("A".to_string()),
3057 Term::Variable("B".to_string()),
3058 ],
3059 },
3060 body: vec![
3061 BodyLiteral::Positive(Atom {
3062 predicate: "input".to_string(),
3063 terms: vec![
3064 Term::Variable("X".to_string()),
3065 Term::Variable("Y".to_string()),
3066 ],
3067 }),
3068 BodyLiteral::IsExpr(IsExpr {
3069 target: "A".to_string(),
3070 expr: ArithExpr::Add(
3071 Box::new(ArithExpr::Variable("X".to_string())),
3072 Box::new(ArithExpr::Variable("Y".to_string())),
3073 ),
3074 }),
3075 BodyLiteral::IsExpr(IsExpr {
3076 target: "B".to_string(),
3077 expr: ArithExpr::Mul(
3078 Box::new(ArithExpr::Variable("A".to_string())),
3079 Box::new(ArithExpr::Integer(2)),
3080 ),
3081 }),
3082 ],
3083 };
3084
3085 let mut lowerer = Lowerer::new();
3086 lowerer.schemas.insert(
3087 "input".to_string(),
3088 Schema::new(vec![
3089 ("c0".to_string(), ScalarType::I64),
3090 ("c1".to_string(), ScalarType::I64),
3091 ]),
3092 );
3093
3094 let result = lowerer.lower_rule(&rule);
3095 assert!(
3096 result.is_ok(),
3097 "Lowering chained is-expressions should succeed: {:?}",
3098 result.err()
3099 );
3100
3101 let node = result.unwrap();
3102
3103 fn count_projects(node: &RirNode) -> usize {
3111 match node {
3112 RirNode::Project { input, .. } => 1 + count_projects(input),
3113 _ => 0,
3114 }
3115 }
3116
3117 let project_count = count_projects(&node);
3119 assert!(
3120 project_count >= 2,
3121 "Expected at least 2 Project nodes for chained is-exprs, got {}",
3122 project_count
3123 );
3124
3125 if let RirNode::Project { columns, .. } = &node {
3127 assert_eq!(columns.len(), 2, "Head has 2 variables");
3128 assert_eq!(columns[0], ProjectExpr::Column(2), "A should be column 2");
3130 assert_eq!(columns[1], ProjectExpr::Column(3), "B should be column 3");
3131 } else {
3132 panic!("Expected top-level Project node");
3133 }
3134 }
3135
3136 #[test]
3137 fn test_u64_comparison_type_from_pred_decl() {
3138 let mut program = Program::new();
3140
3141 program.predicates.push(pred_decl(
3143 "count_data",
3144 vec![ScalarType::Symbol, ScalarType::U64],
3145 ));
3146
3147 program.rules.push(Rule {
3149 head: Atom {
3150 predicate: "count_data".to_string(),
3151 terms: vec![
3152 Term::Symbol(xlog_core::symbol::intern("alice")),
3153 Term::Integer(5),
3154 ],
3155 },
3156 body: vec![],
3157 });
3158
3159 program.predicates.push(pred_decl(
3161 "big_count",
3162 vec![ScalarType::Symbol, ScalarType::U64],
3163 ));
3164
3165 program.rules.push(Rule {
3167 head: Atom {
3168 predicate: "big_count".to_string(),
3169 terms: vec![
3170 Term::Variable("Name".to_string()),
3171 Term::Variable("Count".to_string()),
3172 ],
3173 },
3174 body: vec![
3175 BodyLiteral::Positive(Atom {
3176 predicate: "count_data".to_string(),
3177 terms: vec![
3178 Term::Variable("Name".to_string()),
3179 Term::Variable("Count".to_string()),
3180 ],
3181 }),
3182 BodyLiteral::Comparison(Comparison {
3183 left: Term::Variable("Count".to_string()),
3184 op: CompOp::Ge,
3185 right: Term::Integer(3),
3186 }),
3187 ],
3188 });
3189
3190 let mut lowerer = Lowerer::new();
3191 lowerer.infer_schemas(&program).unwrap();
3192
3193 let schema = lowerer
3195 .schemas
3196 .get("count_data")
3197 .expect("schema for count_data");
3198 assert_eq!(
3199 schema.column_type(0),
3200 Some(ScalarType::Symbol),
3201 "First column should be Symbol"
3202 );
3203 assert_eq!(
3204 schema.column_type(1),
3205 Some(ScalarType::U64),
3206 "Second column should be U64"
3207 );
3208
3209 lowerer.set_strata(vec![
3211 vec!["count_data".to_string()],
3212 vec!["big_count".to_string()],
3213 ]);
3214 lowerer.build_sccs(&program);
3215
3216 let rule = &program.rules[1]; let result = lowerer.lower_rule(rule);
3218 assert!(
3219 result.is_ok(),
3220 "Lowering should succeed: {:?}",
3221 result.err()
3222 );
3223
3224 fn find_compare_const(node: &RirNode) -> Option<&ConstValue> {
3226 match node {
3227 RirNode::Filter { predicate, input } => {
3228 if let Expr::Compare { right, .. } = predicate {
3229 if let Expr::Const(val) = right.as_ref() {
3230 return Some(val);
3231 }
3232 }
3233 find_compare_const(input)
3234 }
3235 RirNode::Project { input, .. } => find_compare_const(input),
3236 RirNode::Join { left, right, .. } => {
3237 find_compare_const(left).or_else(|| find_compare_const(right))
3238 }
3239 _ => None,
3240 }
3241 }
3242
3243 let node = result.unwrap();
3244 let const_val = find_compare_const(&node);
3245 assert!(const_val.is_some(), "Should find a constant in comparison");
3246
3247 match const_val.unwrap() {
3249 ConstValue::U64(v) => assert_eq!(*v, 3, "Value should be 3"),
3250 other => panic!("Expected U64(3), got {:?}", other),
3251 }
3252 }
3253
3254 #[test]
3255 fn test_u64_comparison_with_aggregation() {
3256 use crate::ast::AggExpr;
3257
3258 let mut program = Program::new();
3260
3261 program.predicates.push(pred_decl(
3263 "reports_to",
3264 vec![ScalarType::Symbol, ScalarType::Symbol],
3265 ));
3266
3267 program.rules.push(Rule {
3269 head: Atom {
3270 predicate: "reports_to".to_string(),
3271 terms: vec![
3272 Term::Symbol(xlog_core::symbol::intern("alice")),
3273 Term::Symbol(xlog_core::symbol::intern("bob")),
3274 ],
3275 },
3276 body: vec![],
3277 });
3278 program.rules.push(Rule {
3279 head: Atom {
3280 predicate: "reports_to".to_string(),
3281 terms: vec![
3282 Term::Symbol(xlog_core::symbol::intern("carol")),
3283 Term::Symbol(xlog_core::symbol::intern("bob")),
3284 ],
3285 },
3286 body: vec![],
3287 });
3288
3289 program.predicates.push(pred_decl(
3291 "direct_count",
3292 vec![ScalarType::Symbol, ScalarType::U64],
3293 ));
3294
3295 program.rules.push(Rule {
3297 head: Atom {
3298 predicate: "direct_count".to_string(),
3299 terms: vec![
3300 Term::Variable("Mgr".to_string()),
3301 Term::Aggregate(AggExpr {
3302 op: AggOp::Count,
3303 variable: "Emp".to_string(),
3304 }),
3305 ],
3306 },
3307 body: vec![BodyLiteral::Positive(Atom {
3308 predicate: "reports_to".to_string(),
3309 terms: vec![
3310 Term::Variable("Emp".to_string()),
3311 Term::Variable("Mgr".to_string()),
3312 ],
3313 })],
3314 });
3315
3316 program.predicates.push(pred_decl(
3318 "big_manager",
3319 vec![ScalarType::Symbol, ScalarType::U64],
3320 ));
3321
3322 program.rules.push(Rule {
3324 head: Atom {
3325 predicate: "big_manager".to_string(),
3326 terms: vec![
3327 Term::Variable("Mgr".to_string()),
3328 Term::Variable("Count".to_string()),
3329 ],
3330 },
3331 body: vec![
3332 BodyLiteral::Positive(Atom {
3333 predicate: "direct_count".to_string(),
3334 terms: vec![
3335 Term::Variable("Mgr".to_string()),
3336 Term::Variable("Count".to_string()),
3337 ],
3338 }),
3339 BodyLiteral::Comparison(Comparison {
3340 left: Term::Variable("Count".to_string()),
3341 op: CompOp::Ge,
3342 right: Term::Integer(2),
3343 }),
3344 ],
3345 });
3346
3347 let mut lowerer = Lowerer::new();
3348 lowerer.infer_schemas(&program).unwrap();
3349
3350 let schema = lowerer
3352 .schemas
3353 .get("direct_count")
3354 .expect("schema for direct_count");
3355 assert_eq!(
3356 schema.column_type(0),
3357 Some(ScalarType::Symbol),
3358 "First column should be Symbol"
3359 );
3360 assert_eq!(
3361 schema.column_type(1),
3362 Some(ScalarType::U64),
3363 "Second column should be U64"
3364 );
3365
3366 lowerer.set_strata(vec![
3367 vec!["reports_to".to_string()],
3368 vec!["direct_count".to_string()],
3369 vec!["big_manager".to_string()],
3370 ]);
3371 lowerer.build_sccs(&program);
3372
3373 let big_manager_rule = &program.rules[3];
3375 let result = lowerer.lower_rule(big_manager_rule);
3376 assert!(
3377 result.is_ok(),
3378 "Lowering should succeed: {:?}",
3379 result.err()
3380 );
3381
3382 fn find_compare_const(node: &RirNode) -> Option<&ConstValue> {
3384 match node {
3385 RirNode::Filter { predicate, input } => {
3386 if let Expr::Compare { right, .. } = predicate {
3387 if let Expr::Const(val) = right.as_ref() {
3388 return Some(val);
3389 }
3390 }
3391 find_compare_const(input)
3392 }
3393 RirNode::Project { input, .. } => find_compare_const(input),
3394 RirNode::Join { left, right, .. } => {
3395 find_compare_const(left).or_else(|| find_compare_const(right))
3396 }
3397 _ => None,
3398 }
3399 }
3400
3401 let node = result.unwrap();
3402 let const_val = find_compare_const(&node);
3403 assert!(const_val.is_some(), "Should find a constant in comparison");
3404
3405 match const_val.unwrap() {
3407 ConstValue::U64(v) => assert_eq!(*v, 2, "Value should be 2"),
3408 other => panic!("Expected U64(2), got {:?}", other),
3409 }
3410 }
3411}