1use crate::query::df_graph::GraphExecutionContext;
12use crate::query::df_graph::common::{
13 collect_all_partitions, compute_plan_properties, execute_subplan,
14};
15use crate::query::df_graph::locy_best_by::SortCriterion;
16use crate::query::df_graph::locy_explain::ProvenanceStore;
17use crate::query::df_graph::locy_fixpoint::{
18 DerivedScanRegistry, FixpointClausePlan, FixpointExec, FixpointRulePlan, IsRefBinding,
19};
20use crate::query::df_graph::locy_fold::{FoldAggKind, FoldBinding};
21use crate::query::planner_locy_types::{
22 LocyCommand, LocyIsRef, LocyRulePlan, LocyStratum, LocyYieldColumn,
23};
24use arrow_array::RecordBatch;
25use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
29use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
30use futures::Stream;
31use parking_lot::RwLock;
32use std::any::Any;
33use std::collections::HashMap;
34use std::fmt;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::sync::RwLock as StdRwLock;
38use std::task::{Context, Poll};
39use std::time::{Duration, Instant};
40use uni_common::Value;
41use uni_common::core::schema::Schema as UniSchema;
42use uni_cypher::ast::Expr;
43use uni_locy::RuntimeWarning;
44use uni_store::storage::manager::StorageManager;
45
46pub struct DerivedStore {
55 relations: HashMap<String, Vec<RecordBatch>>,
56}
57
58impl Default for DerivedStore {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl DerivedStore {
65 pub fn new() -> Self {
66 Self {
67 relations: HashMap::new(),
68 }
69 }
70
71 pub fn insert(&mut self, rule_name: String, facts: Vec<RecordBatch>) {
72 self.relations.insert(rule_name, facts);
73 }
74
75 pub fn get(&self, rule_name: &str) -> Option<&Vec<RecordBatch>> {
76 self.relations.get(rule_name)
77 }
78
79 pub fn fact_count(&self, rule_name: &str) -> usize {
80 self.relations
81 .get(rule_name)
82 .map(|batches| batches.iter().map(|b| b.num_rows()).sum())
83 .unwrap_or(0)
84 }
85
86 pub fn rule_names(&self) -> impl Iterator<Item = &str> {
87 self.relations.keys().map(|s| s.as_str())
88 }
89}
90
/// Physical plan node that evaluates an entire Locy program: every stratum in
/// order, then emits a per-rule statistics batch (see [`stats_schema`]).
///
/// Results other than the stats batch are delivered through shared "slot"
/// fields that callers can read after the output stream is drained.
pub struct LocyProgramExec {
    // Strata in dependency order; each is evaluated to completion before the next.
    strata: Vec<LocyStratum>,
    // Program-level commands (only counted in Debug/Display here).
    commands: Vec<LocyCommand>,
    // Registry of derived-relation scan entries that sub-plans read from.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    // Query parameters forwarded to every sub-plan execution.
    params: HashMap<String, Value>,
    output_schema: SchemaRef,
    properties: PlanProperties,
    metrics: ExecutionPlanMetricsSet,
    // Resource limits applied during fixpoint evaluation.
    max_iterations: usize,
    timeout: Duration,
    max_derived_bytes: usize,
    deterministic_best_by: bool,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    // Side-channel slots populated during execute(); empty until then.
    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    derivation_tracker: Arc<StdRwLock<Option<Arc<ProvenanceStore>>>>,
    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
    peak_memory_slot: Arc<StdRwLock<usize>>,
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    top_k_proofs: usize,
}
135
136impl fmt::Debug for LocyProgramExec {
137 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138 f.debug_struct("LocyProgramExec")
139 .field("strata_count", &self.strata.len())
140 .field("commands_count", &self.commands.len())
141 .field("max_iterations", &self.max_iterations)
142 .field("timeout", &self.timeout)
143 .field("output_schema", &self.output_schema)
144 .field("max_derived_bytes", &self.max_derived_bytes)
145 .finish_non_exhaustive()
146 }
147}
148
impl LocyProgramExec {
    /// Creates the program node. All result/side-channel slots (derived store,
    /// warnings, iteration counts, peak memory, approximation info) start
    /// empty and are filled in during `execute`.
    #[expect(
        clippy::too_many_arguments,
        reason = "execution plan node requires full graph and session context"
    )]
    pub fn new(
        strata: Vec<LocyStratum>,
        commands: Vec<LocyCommand>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        output_schema: SchemaRef,
        max_iterations: usize,
        timeout: Duration,
        max_derived_bytes: usize,
        deterministic_best_by: bool,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        top_k_proofs: usize,
    ) -> Self {
        let properties = compute_plan_properties(Arc::clone(&output_schema));
        Self {
            strata,
            commands,
            derived_scan_registry,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            output_schema,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
            max_iterations,
            timeout,
            max_derived_bytes,
            deterministic_best_by,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            derived_store_slot: Arc::new(StdRwLock::new(None)),
            approximate_slot: Arc::new(StdRwLock::new(HashMap::new())),
            derivation_tracker: Arc::new(StdRwLock::new(None)),
            iteration_counts_slot: Arc::new(StdRwLock::new(HashMap::new())),
            peak_memory_slot: Arc::new(StdRwLock::new(0)),
            warnings_slot: Arc::new(StdRwLock::new(Vec::new())),
            top_k_proofs,
        }
    }

    /// Shared slot that receives the final `DerivedStore` once evaluation
    /// completes; callers read it after draining the output stream.
    pub fn derived_store_slot(&self) -> Arc<StdRwLock<Option<DerivedStore>>> {
        Arc::clone(&self.derived_store_slot)
    }

    /// Installs a provenance tracker used to record fact lineage during
    /// evaluation. A poisoned lock is silently ignored — tracking is
    /// best-effort and must not fail the query.
    pub fn set_derivation_tracker(&self, tracker: Arc<ProvenanceStore>) {
        if let Ok(mut guard) = self.derivation_tracker.write() {
            *guard = Some(tracker);
        }
    }

    /// Shared slot receiving per-rule fixpoint iteration counts.
    pub fn iteration_counts_slot(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
        Arc::clone(&self.iteration_counts_slot)
    }

    /// Shared slot receiving the peak derived-fact memory estimate (bytes).
    pub fn peak_memory_slot(&self) -> Arc<StdRwLock<usize>> {
        Arc::clone(&self.peak_memory_slot)
    }

    /// Shared slot receiving runtime warnings raised during evaluation.
    pub fn warnings_slot(&self) -> Arc<StdRwLock<Vec<RuntimeWarning>>> {
        Arc::clone(&self.warnings_slot)
    }

    /// Shared slot receiving, per rule, the reasons its probabilities were
    /// computed approximately rather than exactly.
    pub fn approximate_slot(&self) -> Arc<StdRwLock<HashMap<String, Vec<String>>>> {
        Arc::clone(&self.approximate_slot)
    }
}
255
256impl DisplayAs for LocyProgramExec {
257 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
258 write!(
259 f,
260 "LocyProgramExec: strata={}, commands={}, max_iter={}, timeout={:?}",
261 self.strata.len(),
262 self.commands.len(),
263 self.max_iterations,
264 self.timeout,
265 )
266 }
267}
268
impl ExecutionPlan for LocyProgramExec {
    fn name(&self) -> &str {
        "LocyProgramExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    // Leaf node: program evaluation drives its own sub-plans internally.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        // Reject any attempt to attach children; this node is always a leaf.
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "LocyProgramExec has no children".to_string(),
            ));
        }
        Ok(self)
    }

    /// Kicks off whole-program evaluation as a single future and returns a
    /// stream that yields the buffered result batches once it completes.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);

        // Clone/share everything the evaluation future needs so it is
        // 'static and independent of `self`. The order of these captures
        // mirrors `run_program`'s parameter list below.
        let strata = self.strata.clone();
        let registry = Arc::clone(&self.derived_scan_registry);
        let graph_ctx = Arc::clone(&self.graph_ctx);
        let session_ctx = Arc::clone(&self.session_ctx);
        let storage = Arc::clone(&self.storage);
        let schema_info = Arc::clone(&self.schema_info);
        let params = self.params.clone();
        let output_schema = Arc::clone(&self.output_schema);
        let max_iterations = self.max_iterations;
        let timeout = self.timeout;
        let max_derived_bytes = self.max_derived_bytes;
        let deterministic_best_by = self.deterministic_best_by;
        let strict_probability_domain = self.strict_probability_domain;
        let probability_epsilon = self.probability_epsilon;
        let exact_probability = self.exact_probability;
        let max_bdd_variables = self.max_bdd_variables;
        let derived_store_slot = Arc::clone(&self.derived_store_slot);
        let approximate_slot = Arc::clone(&self.approximate_slot);
        let iteration_counts_slot = Arc::clone(&self.iteration_counts_slot);
        let peak_memory_slot = Arc::clone(&self.peak_memory_slot);
        // Snapshot the optional tracker; a poisoned lock degrades to None
        // (provenance tracking is best-effort).
        let derivation_tracker = self.derivation_tracker.read().ok().and_then(|g| g.clone());
        let warnings_slot = Arc::clone(&self.warnings_slot);
        let top_k_proofs = self.top_k_proofs;

        let fut = async move {
            run_program(
                strata,
                registry,
                graph_ctx,
                session_ctx,
                storage,
                schema_info,
                params,
                output_schema,
                max_iterations,
                timeout,
                max_derived_bytes,
                deterministic_best_by,
                strict_probability_domain,
                probability_epsilon,
                exact_probability,
                max_bdd_variables,
                derived_store_slot,
                approximate_slot,
                iteration_counts_slot,
                peak_memory_slot,
                derivation_tracker,
                warnings_slot,
                top_k_proofs,
            )
            .await
        };

        Ok(Box::pin(ProgramStream {
            state: ProgramStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.output_schema),
            metrics,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
373
/// State machine for [`ProgramStream`]: the whole program runs to completion
/// first (`Running`), then the buffered result batches are emitted one at a
/// time (`Emitting` holds the batches and the next index), then `Done`.
enum ProgramStreamState {
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    Emitting(Vec<RecordBatch>, usize),
    Done,
}

/// Record-batch stream adapter over the program-evaluation future.
struct ProgramStream {
    state: ProgramStreamState,
    // Output schema reported to DataFusion (the stats schema).
    schema: SchemaRef,
    // Output row counts are recorded here as batches are emitted.
    metrics: BaselineMetrics,
}
389
390impl Stream for ProgramStream {
391 type Item = DFResult<RecordBatch>;
392
393 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
394 let this = self.get_mut();
395 loop {
396 match &mut this.state {
397 ProgramStreamState::Running(fut) => match fut.as_mut().poll(cx) {
398 Poll::Ready(Ok(batches)) => {
399 if batches.is_empty() {
400 this.state = ProgramStreamState::Done;
401 return Poll::Ready(None);
402 }
403 this.state = ProgramStreamState::Emitting(batches, 0);
404 }
405 Poll::Ready(Err(e)) => {
406 this.state = ProgramStreamState::Done;
407 return Poll::Ready(Some(Err(e)));
408 }
409 Poll::Pending => return Poll::Pending,
410 },
411 ProgramStreamState::Emitting(batches, idx) => {
412 if *idx >= batches.len() {
413 this.state = ProgramStreamState::Done;
414 return Poll::Ready(None);
415 }
416 let batch = batches[*idx].clone();
417 *idx += 1;
418 this.metrics.record_output(batch.num_rows());
419 return Poll::Ready(Some(Ok(batch)));
420 }
421 ProgramStreamState::Done => return Poll::Ready(None),
422 }
423 }
424 }
425}
426
impl RecordBatchStream for ProgramStream {
    /// Schema of emitted batches (the output schema captured at construction).
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
432
/// Evaluates every stratum of the Locy program in dependency order and
/// returns the per-rule statistics batch (see [`build_stats_batch`]).
///
/// For each stratum: facts derived by earlier strata are first pushed into
/// the scan registry, then the stratum is evaluated either via [`FixpointExec`]
/// (recursive strata) or clause-by-clause sub-plan execution (non-recursive
/// strata). Final facts land in the `DerivedStore`, which is published to
/// `derived_store_slot` before returning.
///
/// # Errors
/// Fails when the overall `timeout` is exceeded, or when any sub-plan,
/// anti-join, probability, or post-fixpoint step fails.
#[expect(
    clippy::too_many_arguments,
    reason = "program evaluation requires full graph and session context"
)]
async fn run_program(
    strata: Vec<LocyStratum>,
    registry: Arc<DerivedScanRegistry>,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    params: HashMap<String, Value>,
    output_schema: SchemaRef,
    max_iterations: usize,
    timeout: Duration,
    max_derived_bytes: usize,
    deterministic_best_by: bool,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
    peak_memory_slot: Arc<StdRwLock<usize>>,
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    top_k_proofs: usize,
) -> DFResult<Vec<RecordBatch>> {
    let start = Instant::now();
    let mut derived_store = DerivedStore::new();

    for stratum in &strata {
        // Make facts from earlier strata visible to this stratum's IS-refs.
        write_cross_stratum_facts(&registry, &derived_store, stratum);

        // The timeout budget is shared across all strata.
        let remaining_timeout = timeout.saturating_sub(start.elapsed());
        if remaining_timeout.is_zero() {
            return Err(datafusion::error::DataFusionError::Execution(
                "Locy program timeout exceeded during stratum evaluation".to_string(),
            ));
        }

        if stratum.is_recursive {
            // Recursive stratum: evaluate all of its rules together to a
            // fixpoint with the remaining time budget.
            let fixpoint_rules =
                convert_to_fixpoint_plans(&stratum.rules, &registry, deterministic_best_by)?;
            let fixpoint_schema = build_fixpoint_output_schema(&stratum.rules);

            let exec = FixpointExec::new(
                fixpoint_rules,
                max_iterations,
                remaining_timeout,
                Arc::clone(&graph_ctx),
                Arc::clone(&session_ctx),
                Arc::clone(&storage),
                Arc::clone(&schema_info),
                params.clone(),
                Arc::clone(&registry),
                fixpoint_schema,
                max_derived_bytes,
                derivation_tracker.clone(),
                Arc::clone(&iteration_counts_slot),
                strict_probability_domain,
                probability_epsilon,
                exact_probability,
                max_bdd_variables,
                Arc::clone(&warnings_slot),
                Arc::clone(&approximate_slot),
                top_k_proofs,
            );

            let task_ctx = session_ctx.read().task_ctx();
            let exec_arc: Arc<dyn ExecutionPlan> = Arc::new(exec);
            let batches = collect_all_partitions(&exec_arc, task_ctx).await?;

            // Push the fixpoint results into each rule's non-self-ref scan
            // entries so later strata can read them.
            for rule in &stratum.rules {
                let rule_entries = registry.entries_for_rule(&rule.name);
                for entry in rule_entries {
                    if !entry.is_self_ref {
                        // NOTE(review): batches are attributed to a rule only
                        // by matching column COUNT, and the rule schema is
                        // rebuilt inside the filter for every batch
                        // (loop-invariant work) — confirm rules in one
                        // recursive stratum can't share an arity.
                        let all_facts: Vec<RecordBatch> = batches
                            .iter()
                            .filter(|b| {
                                let rule_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
                                b.schema().fields().len() == rule_schema.fields().len()
                            })
                            .cloned()
                            .collect();
                        let mut guard = entry.data.write();
                        *guard = if all_facts.is_empty() {
                            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
                        } else {
                            all_facts
                        };
                    }
                }
                derived_store.insert(rule.name.clone(), batches.clone());
            }
        } else {
            // Non-recursive stratum: each rule's clauses are executed once as
            // ordinary sub-plans, then combined and post-processed.
            let fixpoint_rules =
                convert_to_fixpoint_plans(&stratum.rules, &registry, deterministic_best_by)?;
            let task_ctx = session_ctx.read().task_ctx();

            for (rule, fp_rule) in stratum.rules.iter().zip(fixpoint_rules.iter()) {
                // Clause results tagged with their clause index, for lineage.
                let mut tagged_clause_facts: Vec<(usize, Vec<RecordBatch>)> = Vec::new();
                for (clause_idx, (clause, fp_clause)) in
                    rule.clauses.iter().zip(fp_rule.clauses.iter()).enumerate()
                {
                    let mut batches = execute_subplan(
                        &clause.body,
                        &params,
                        &HashMap::new(),
                        &graph_ctx,
                        &session_ctx,
                        &storage,
                        &schema_info,
                    )
                    .await?;

                    // Apply negated IS-refs: probabilistic targets get a
                    // complement factor column; deterministic ones a plain
                    // anti-join.
                    for binding in &fp_clause.is_ref_bindings {
                        if binding.negated
                            && !binding.anti_join_cols.is_empty()
                            && let Some(entry) = registry.get(binding.derived_scan_index)
                        {
                            let neg_facts = entry.data.read().clone();
                            if !neg_facts.is_empty() {
                                if binding.target_has_prob && fp_rule.prob_column_name.is_some() {
                                    let complement_col =
                                        format!("__prob_complement_{}", binding.rule_name);
                                    if let Some(prob_col) = &binding.target_prob_col {
                                        batches =
                                            super::locy_fixpoint::apply_prob_complement_composite(
                                                batches,
                                                &neg_facts,
                                                &binding.anti_join_cols,
                                                prob_col,
                                                &complement_col,
                                            )?;
                                    } else {
                                        batches = super::locy_fixpoint::apply_anti_join_composite(
                                            batches,
                                            &neg_facts,
                                            &binding.anti_join_cols,
                                        )?;
                                    }
                                } else {
                                    batches = super::locy_fixpoint::apply_anti_join_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                    )?;
                                }
                            }
                        }
                    }

                    // Fold any accumulated complement factors into the rule's
                    // probability column, then drop the temp columns.
                    let complement_cols: Vec<String> = if !batches.is_empty() {
                        batches[0]
                            .schema()
                            .fields()
                            .iter()
                            .filter(|f| f.name().starts_with("__prob_complement_"))
                            .map(|f| f.name().clone())
                            .collect()
                    } else {
                        vec![]
                    };
                    if !complement_cols.is_empty() {
                        batches = super::locy_fixpoint::multiply_prob_factors(
                            batches,
                            fp_rule.prob_column_name.as_deref(),
                            &complement_cols,
                        )?;
                    }

                    tagged_clause_facts.push((clause_idx, batches));
                }

                // Record lineage and detect shared sub-derivations (only when
                // a provenance tracker is installed).
                let shared_info = if let Some(ref tracker) = derivation_tracker {
                    super::locy_fixpoint::record_and_detect_lineage_nonrecursive(
                        fp_rule,
                        &tagged_clause_facts,
                        tracker,
                        &warnings_slot,
                        &registry,
                        top_k_proofs,
                    )
                } else {
                    None
                };

                let mut all_clause_facts: Vec<RecordBatch> = tagged_clause_facts
                    .into_iter()
                    .flat_map(|(_, batches)| batches)
                    .collect();

                // Exact probabilities require both shared-derivation info and
                // a tracker; otherwise the approximate path below stands.
                if exact_probability
                    && let Some(ref info) = shared_info
                    && let Some(ref tracker) = derivation_tracker
                {
                    all_clause_facts = super::locy_fixpoint::apply_exact_wmc(
                        all_clause_facts,
                        fp_rule,
                        info,
                        tracker,
                        max_bdd_variables,
                        &warnings_slot,
                        &approximate_slot,
                    )?;
                }

                // Dedup / fold / best-by / priority post-processing.
                let facts = super::locy_fixpoint::apply_post_fixpoint_chain(
                    all_clause_facts,
                    fp_rule,
                    &task_ctx,
                    strict_probability_domain,
                    probability_epsilon,
                )
                .await?;

                write_facts_to_registry(&registry, &rule.name, &facts);
                derived_store.insert(rule.name.clone(), facts);
            }
        }
    }

    // Estimate of the memory held by all derived facts (buffer sizes only).
    let peak_bytes: usize = derived_store
        .relations
        .values()
        .flat_map(|batches| batches.iter())
        .map(|b| {
            b.columns()
                .iter()
                .map(|col| col.get_buffer_memory_size())
                .sum::<usize>()
        })
        .sum();
    *peak_memory_slot.write().unwrap() = peak_bytes;

    // Build the stats output before moving the store into its slot.
    let stats = vec![build_stats_batch(&derived_store, &strata, output_schema)];
    *derived_store_slot.write().unwrap() = Some(derived_store);
    Ok(stats)
}
705
706fn write_cross_stratum_facts(
712 registry: &DerivedScanRegistry,
713 derived_store: &DerivedStore,
714 stratum: &LocyStratum,
715) {
716 for rule in &stratum.rules {
718 for clause in &rule.clauses {
719 for is_ref in &clause.is_refs {
720 if let Some(facts) = derived_store.get(&is_ref.rule_name) {
723 write_facts_to_registry(registry, &is_ref.rule_name, facts);
724 }
725 }
726 }
727 }
728}
729
730fn write_facts_to_registry(registry: &DerivedScanRegistry, rule_name: &str, facts: &[RecordBatch]) {
732 let entries = registry.entries_for_rule(rule_name);
733 for entry in entries {
734 if !entry.is_self_ref {
735 let mut guard = entry.data.write();
736 *guard = if facts.is_empty() || facts.iter().all(|b| b.num_rows() == 0) {
737 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
738 } else {
739 facts
744 .iter()
745 .filter(|b| b.num_rows() > 0)
746 .map(|b| {
747 RecordBatch::try_new(Arc::clone(&entry.schema), b.columns().to_vec())
748 .unwrap_or_else(|_| b.clone())
749 })
750 .collect()
751 };
752 }
753 }
754}
755
/// Lowers planner-level rule plans into the executable [`FixpointRulePlan`]
/// form: resolves IS-ref bindings against the registry, converts fold and
/// BEST BY specifications to column indices, and derives the rule's Arrow
/// output schema (with an extra `__priority` column for prioritized rules).
///
/// # Errors
/// Propagates failures from IS-ref, fold, or BEST BY conversion (e.g. a
/// referenced column missing from the yield schema).
fn convert_to_fixpoint_plans(
    rules: &[LocyRulePlan],
    registry: &DerivedScanRegistry,
    deterministic_best_by: bool,
) -> DFResult<Vec<FixpointRulePlan>> {
    rules
        .iter()
        .map(|rule| {
            let yield_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
            // Positions of the columns that form the fact key.
            let key_column_indices: Vec<usize> = rule
                .yield_schema
                .iter()
                .enumerate()
                .filter(|(_, yc)| yc.is_key)
                .map(|(i, _)| i)
                .collect();

            let clauses: Vec<FixpointClausePlan> = rule
                .clauses
                .iter()
                .map(|clause| {
                    let is_ref_bindings = convert_is_refs(&clause.is_refs, registry)?;
                    Ok(FixpointClausePlan {
                        body_logical: clause.body.clone(),
                        is_ref_bindings,
                        priority: clause.priority,
                        along_bindings: clause.along_bindings.clone(),
                    })
                })
                .collect::<DFResult<Vec<_>>>()?;

            let fold_bindings = convert_fold_bindings(&rule.fold_bindings, &rule.yield_schema)?;
            let best_by_criteria =
                convert_best_by_criteria(&rule.best_by_criteria, &rule.yield_schema)?;

            let has_priority = rule.priority.is_some();

            // Prioritized rules carry an extra nullable Int64 `__priority`
            // column; the base schema is shadowed with the extended one.
            let yield_schema = if has_priority {
                let mut fields: Vec<Arc<Field>> = yield_schema.fields().iter().cloned().collect();
                fields.push(Arc::new(Field::new("__priority", DataType::Int64, true)));
                ArrowSchema::new(fields)
            } else {
                yield_schema
            };

            // First yield column flagged as a probability, if any.
            let prob_column_name = rule
                .yield_schema
                .iter()
                .find(|yc| yc.is_prob)
                .map(|yc| yc.name.clone());

            Ok(FixpointRulePlan {
                name: rule.name.clone(),
                clauses,
                yield_schema: Arc::new(yield_schema),
                key_column_indices,
                priority: rule.priority,
                has_fold: !rule.fold_bindings.is_empty(),
                fold_bindings,
                has_best_by: !rule.best_by_criteria.is_empty(),
                best_by_criteria,
                has_priority,
                deterministic: deterministic_best_by,
                prob_column_name,
            })
        })
        .collect()
}
830
831fn convert_is_refs(
833 is_refs: &[LocyIsRef],
834 registry: &DerivedScanRegistry,
835) -> DFResult<Vec<IsRefBinding>> {
836 is_refs
837 .iter()
838 .map(|is_ref| {
839 let entries = registry.entries_for_rule(&is_ref.rule_name);
840 let entry = entries
842 .iter()
843 .find(|e| e.is_self_ref)
844 .or_else(|| entries.first())
845 .ok_or_else(|| {
846 datafusion::error::DataFusionError::Plan(format!(
847 "No derived scan entry found for IS-ref to '{}'",
848 is_ref.rule_name
849 ))
850 })?;
851
852 let anti_join_cols = if is_ref.negated {
857 is_ref
858 .subjects
859 .iter()
860 .enumerate()
861 .filter_map(|(i, s)| {
862 if let uni_cypher::ast::Expr::Variable(var) = s {
863 let right_col = entry
864 .schema
865 .fields()
866 .get(i)
867 .map(|f| f.name().clone())
868 .unwrap_or_else(|| var.clone());
869 Some((var.clone(), right_col))
872 } else {
873 None
874 }
875 })
876 .collect()
877 } else {
878 Vec::new()
879 };
880
881 let provenance_join_cols: Vec<(String, String)> = is_ref
885 .subjects
886 .iter()
887 .enumerate()
888 .filter_map(|(i, s)| {
889 if let uni_cypher::ast::Expr::Variable(var) = s {
890 let right_col = entry
891 .schema
892 .fields()
893 .get(i)
894 .map(|f| f.name().clone())
895 .unwrap_or_else(|| var.clone());
896 Some((var.clone(), right_col))
897 } else {
898 None
899 }
900 })
901 .collect();
902
903 Ok(IsRefBinding {
904 derived_scan_index: entry.scan_index,
905 rule_name: is_ref.rule_name.clone(),
906 is_self_ref: entry.is_self_ref,
907 negated: is_ref.negated,
908 anti_join_cols,
909 target_has_prob: is_ref.target_has_prob,
910 target_prob_col: is_ref.target_prob_col.clone(),
911 provenance_join_cols,
912 })
913 })
914 .collect()
915}
916
917fn convert_fold_bindings(
923 fold_bindings: &[(String, Expr)],
924 yield_schema: &[LocyYieldColumn],
925) -> DFResult<Vec<FoldBinding>> {
926 fold_bindings
927 .iter()
928 .map(|(name, expr)| {
929 let (kind, _input_col_name) = parse_fold_aggregate(expr)?;
930
931 if kind == FoldAggKind::CountAll {
934 return Ok(FoldBinding {
935 output_name: name.clone(),
936 kind,
937 input_col_index: 0, });
939 }
940
941 let input_col_index = yield_schema
944 .iter()
945 .position(|yc| yc.name == *name)
946 .ok_or_else(|| {
947 datafusion::error::DataFusionError::Plan(format!(
948 "FOLD column '{}' not found in yield schema",
949 name
950 ))
951 })?;
952 Ok(FoldBinding {
953 output_name: name.clone(),
954 kind,
955 input_col_index,
956 })
957 })
958 .collect()
959}
960
961fn parse_fold_aggregate(expr: &Expr) -> DFResult<(FoldAggKind, String)> {
963 match expr {
964 Expr::FunctionCall { name, args, .. } => {
965 let upper = name.to_uppercase();
966 let is_count = matches!(upper.as_str(), "COUNT" | "MCOUNT");
967
968 if is_count && args.is_empty() {
970 return Ok((FoldAggKind::CountAll, String::new()));
971 }
972
973 let kind = match upper.as_str() {
974 "SUM" | "MSUM" => FoldAggKind::Sum,
975 "MAX" | "MMAX" => FoldAggKind::Max,
976 "MIN" | "MMIN" => FoldAggKind::Min,
977 "COUNT" | "MCOUNT" => FoldAggKind::Count,
978 "AVG" => FoldAggKind::Avg,
979 "COLLECT" => FoldAggKind::Collect,
980 "MNOR" => FoldAggKind::Nor,
981 "MPROD" => FoldAggKind::Prod,
982 _ => {
983 return Err(datafusion::error::DataFusionError::Plan(format!(
984 "Unknown FOLD aggregate function: {}",
985 name
986 )));
987 }
988 };
989 let col_name = match args.first() {
990 Some(Expr::Variable(v)) => v.clone(),
991 Some(Expr::Property(_, prop)) => prop.clone(),
992 Some(other) => other.to_string_repr(),
993 None => {
994 return Err(datafusion::error::DataFusionError::Plan(
995 "FOLD aggregate function requires at least one argument".to_string(),
996 ));
997 }
998 };
999 Ok((kind, col_name))
1000 }
1001 _ => Err(datafusion::error::DataFusionError::Plan(
1002 "FOLD binding must be a function call (e.g., SUM(x))".to_string(),
1003 )),
1004 }
1005}
1006
1007fn convert_best_by_criteria(
1014 criteria: &[(Expr, bool)],
1015 yield_schema: &[LocyYieldColumn],
1016) -> DFResult<Vec<SortCriterion>> {
1017 criteria
1018 .iter()
1019 .map(|(expr, ascending)| {
1020 let col_name = match expr {
1021 Expr::Property(_, prop) => prop.clone(),
1022 Expr::Variable(v) => v.clone(),
1023 _ => {
1024 return Err(datafusion::error::DataFusionError::Plan(
1025 "BEST BY criterion must be a variable or property reference".to_string(),
1026 ));
1027 }
1028 };
1029 let col_index = yield_schema
1031 .iter()
1032 .position(|yc| yc.name == col_name)
1033 .or_else(|| {
1034 let short_name = col_name.rsplit('.').next().unwrap_or(&col_name);
1035 yield_schema.iter().position(|yc| yc.name == short_name)
1036 })
1037 .ok_or_else(|| {
1038 datafusion::error::DataFusionError::Plan(format!(
1039 "BEST BY column '{}' not found in yield schema",
1040 col_name
1041 ))
1042 })?;
1043 Ok(SortCriterion {
1044 col_index,
1045 ascending: *ascending,
1046 nulls_first: false,
1047 })
1048 })
1049 .collect()
1050}
1051
1052fn yield_columns_to_arrow_schema(columns: &[LocyYieldColumn]) -> ArrowSchema {
1058 let fields: Vec<Arc<Field>> = columns
1059 .iter()
1060 .map(|yc| Arc::new(Field::new(&yc.name, yc.data_type.clone(), true)))
1061 .collect();
1062 ArrowSchema::new(fields)
1063}
1064
1065fn build_fixpoint_output_schema(rules: &[LocyRulePlan]) -> SchemaRef {
1067 if let Some(rule) = rules.first() {
1070 Arc::new(yield_columns_to_arrow_schema(&rule.yield_schema))
1071 } else {
1072 Arc::new(ArrowSchema::empty())
1073 }
1074}
1075
1076fn build_stats_batch(
1078 derived_store: &DerivedStore,
1079 _strata: &[LocyStratum],
1080 output_schema: SchemaRef,
1081) -> RecordBatch {
1082 let mut rule_names: Vec<String> = derived_store.rule_names().map(String::from).collect();
1084 rule_names.sort();
1085
1086 let name_col: arrow_array::StringArray = rule_names.iter().map(|s| Some(s.as_str())).collect();
1087 let count_col: arrow_array::Int64Array = rule_names
1088 .iter()
1089 .map(|name| Some(derived_store.fact_count(name) as i64))
1090 .collect();
1091
1092 let stats_schema = stats_schema();
1093 RecordBatch::try_new(stats_schema, vec![Arc::new(name_col), Arc::new(count_col)])
1094 .unwrap_or_else(|_| RecordBatch::new_empty(output_schema))
1095}
1096
1097pub fn stats_schema() -> SchemaRef {
1099 Arc::new(ArrowSchema::new(vec![
1100 Arc::new(Field::new("rule_name", DataType::Utf8, false)),
1101 Arc::new(Field::new("fact_count", DataType::Int64, false)),
1102 ]))
1103}
1104
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, LargeBinaryArray, StringArray};

    // Insert then read back a single batch; row count survives the round trip.
    #[test]
    fn test_derived_store_insert_and_get() {
        let mut store = DerivedStore::new();
        assert!(store.get("test").is_none());

        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
            "x",
            DataType::LargeBinary,
            true,
        ))]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(LargeBinaryArray::from(vec![
                Some(b"a" as &[u8]),
                Some(b"b"),
            ]))],
        )
        .unwrap();

        store.insert("test".to_string(), vec![batch.clone()]);

        let facts = store.get("test").unwrap();
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].num_rows(), 2);
    }

    // fact_count sums rows across multiple batches; unknown rules count 0.
    #[test]
    fn test_derived_store_fact_count() {
        let mut store = DerivedStore::new();
        assert_eq!(store.fact_count("empty"), 0);

        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
            "x",
            DataType::LargeBinary,
            true,
        ))]));
        let batch1 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(LargeBinaryArray::from(vec![Some(b"a" as &[u8])]))],
        )
        .unwrap();
        let batch2 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(LargeBinaryArray::from(vec![
                Some(b"b" as &[u8]),
                Some(b"c"),
            ]))],
        )
        .unwrap();

        store.insert("test".to_string(), vec![batch1, batch2]);
        assert_eq!(store.fact_count("test"), 3);
    }

    // Stats schema shape: (rule_name Utf8, fact_count Int64).
    #[test]
    fn test_stats_batch_schema() {
        let schema = stats_schema();
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.field(0).name(), "rule_name");
        assert_eq!(schema.field(1).name(), "fact_count");
        assert_eq!(schema.field(0).data_type(), &DataType::Utf8);
        assert_eq!(schema.field(1).data_type(), &DataType::Int64);
    }

    // One stats row per stored rule, carrying its total fact count.
    #[test]
    fn test_stats_batch_content() {
        let mut store = DerivedStore::new();
        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
            "x",
            DataType::LargeBinary,
            true,
        ))]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(LargeBinaryArray::from(vec![
                Some(b"a" as &[u8]),
                Some(b"b"),
            ]))],
        )
        .unwrap();
        store.insert("reach".to_string(), vec![batch]);

        let output_schema = stats_schema();
        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
        assert_eq!(stats.num_rows(), 1);

        let names = stats
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(names.value(0), "reach");

        let counts = stats
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(counts.value(0), 2);
    }

    // Yield columns map positionally to nullable Arrow fields, preserving
    // names and data types.
    #[test]
    fn test_yield_columns_to_arrow_schema() {
        let columns = vec![
            LocyYieldColumn {
                name: "a".to_string(),
                is_key: true,
                is_prob: false,
                data_type: DataType::UInt64,
            },
            LocyYieldColumn {
                name: "b".to_string(),
                is_key: false,
                is_prob: false,
                data_type: DataType::LargeUtf8,
            },
            LocyYieldColumn {
                name: "c".to_string(),
                is_key: true,
                is_prob: false,
                data_type: DataType::Float64,
            },
        ];

        let schema = yield_columns_to_arrow_schema(&columns);
        assert_eq!(schema.fields().len(), 3);
        assert_eq!(schema.field(0).name(), "a");
        assert_eq!(schema.field(1).name(), "b");
        assert_eq!(schema.field(2).name(), "c");
        assert_eq!(schema.field(0).data_type(), &DataType::UInt64);
        assert_eq!(schema.field(1).data_type(), &DataType::LargeUtf8);
        assert_eq!(schema.field(2).data_type(), &DataType::Float64);
        for field in schema.fields() {
            assert!(field.is_nullable());
        }
    }

    // Mirrors the key-index extraction done in convert_to_fixpoint_plans.
    #[test]
    fn test_key_column_indices() {
        let columns = [
            LocyYieldColumn {
                name: "a".to_string(),
                is_key: true,
                is_prob: false,
                data_type: DataType::LargeBinary,
            },
            LocyYieldColumn {
                name: "b".to_string(),
                is_key: false,
                is_prob: false,
                data_type: DataType::LargeBinary,
            },
            LocyYieldColumn {
                name: "c".to_string(),
                is_key: true,
                is_prob: false,
                data_type: DataType::LargeBinary,
            },
        ];

        let key_indices: Vec<usize> = columns
            .iter()
            .enumerate()
            .filter(|(_, yc)| yc.is_key)
            .map(|(i, _)| i)
            .collect();
        assert_eq!(key_indices, vec![0, 2]);
    }

    // SUM(cost) parses to Sum with input column "cost".
    #[test]
    fn test_parse_fold_aggregate_sum() {
        let expr = Expr::FunctionCall {
            name: "SUM".to_string(),
            args: vec![Expr::Variable("cost".to_string())],
            distinct: false,
            window_spec: None,
        };
        let (kind, col) = parse_fold_aggregate(&expr).unwrap();
        assert!(matches!(kind, FoldAggKind::Sum));
        assert_eq!(col, "cost");
    }

    // Monotonic variant MMAX maps to the same Max kind.
    #[test]
    fn test_parse_fold_aggregate_monotonic() {
        let expr = Expr::FunctionCall {
            name: "MMAX".to_string(),
            args: vec![Expr::Variable("score".to_string())],
            distinct: false,
            window_spec: None,
        };
        let (kind, col) = parse_fold_aggregate(&expr).unwrap();
        assert!(matches!(kind, FoldAggKind::Max));
        assert_eq!(col, "score");
    }

    // Unrecognized aggregate names are rejected with a plan error.
    #[test]
    fn test_parse_fold_aggregate_unknown() {
        let expr = Expr::FunctionCall {
            name: "UNKNOWN_AGG".to_string(),
            args: vec![Expr::Variable("x".to_string())],
            distinct: false,
            window_spec: None,
        };
        assert!(parse_fold_aggregate(&expr).is_err());
    }

    // An empty store still yields a valid (zero-row) stats batch.
    #[test]
    fn test_no_commands_returns_stats() {
        let store = DerivedStore::new();
        let output_schema = stats_schema();
        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
        assert_eq!(stats.num_rows(), 0);
    }
}
1329}