1use crate::query::df_graph::GraphExecutionContext;
12use crate::query::df_graph::common::{
13 collect_all_partitions, compute_plan_properties, execute_subplan,
14};
15use crate::query::df_graph::locy_best_by::SortCriterion;
16use crate::query::df_graph::locy_explain::ProvenanceStore;
17use crate::query::df_graph::locy_fixpoint::{
18 DerivedScanRegistry, FixpointClausePlan, FixpointExec, FixpointRulePlan, IsRefBinding,
19};
20use crate::query::df_graph::locy_fold::{FoldAggKind, FoldBinding};
21use crate::query::planner_locy_types::{
22 LocyCommand, LocyIsRef, LocyRulePlan, LocyStratum, LocyYieldColumn,
23};
24use arrow_array::RecordBatch;
25use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
29use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
30use futures::Stream;
31use parking_lot::RwLock;
32use std::any::Any;
33use std::collections::HashMap;
34use std::fmt;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::sync::RwLock as StdRwLock;
38use std::task::{Context, Poll};
39use std::time::{Duration, Instant};
40use uni_common::Value;
41use uni_common::core::schema::Schema as UniSchema;
42use uni_cypher::ast::Expr;
43use uni_cypher::locy_ast::GoalQuery;
44use uni_locy::{CommandResult, FactRow, RuntimeWarning};
45use uni_store::storage::manager::StorageManager;
46
/// In-memory store of the fact batches produced for each derived rule,
/// keyed by rule name. Populated while a Locy program runs and published to
/// the caller through `LocyProgramExec::derived_store_slot` afterwards.
pub struct DerivedStore {
    // Rule name -> all record batches derived for that rule.
    relations: HashMap<String, Vec<RecordBatch>>,
}
58
impl Default for DerivedStore {
    /// Equivalent to [`DerivedStore::new`]: an empty store.
    fn default() -> Self {
        Self::new()
    }
}
64
65impl DerivedStore {
66 pub fn new() -> Self {
67 Self {
68 relations: HashMap::new(),
69 }
70 }
71
72 pub fn insert(&mut self, rule_name: String, facts: Vec<RecordBatch>) {
73 self.relations.insert(rule_name, facts);
74 }
75
76 pub fn get(&self, rule_name: &str) -> Option<&Vec<RecordBatch>> {
77 self.relations.get(rule_name)
78 }
79
80 pub fn fact_count(&self, rule_name: &str) -> usize {
81 self.relations
82 .get(rule_name)
83 .map(|batches| batches.iter().map(|b| b.num_rows()).sum())
84 .unwrap_or(0)
85 }
86
87 pub fn rule_names(&self) -> impl Iterator<Item = &str> {
88 self.relations.keys().map(|s| s.as_str())
89 }
90}
91
/// Physical plan node that evaluates a complete Locy program: all rule
/// strata plus the program-level commands. Execution results that do not
/// flow through the output stream (derived facts, warnings, per-rule
/// iteration counts, command outputs, peak memory) are published through
/// shared "slot" fields so the caller can read them after the stream ends.
pub struct LocyProgramExec {
    /// Rule strata in evaluation order.
    strata: Vec<LocyStratum>,
    /// Program-level commands; inline Cypher commands are evaluated by
    /// `run_program`.
    commands: Vec<LocyCommand>,
    /// Registry through which derived facts are published to scan nodes.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    /// Graph execution context used when executing rule-body subplans.
    graph_ctx: Arc<GraphExecutionContext>,
    /// DataFusion session used for planning/executing subplans.
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    /// Underlying storage manager.
    storage: Arc<StorageManager>,
    /// Uni schema metadata for Cypher planning.
    schema_info: Arc<UniSchema>,
    /// Query parameters forwarded to subplans and expression evaluation.
    params: HashMap<String, Value>,
    /// Schema of the stats batch this node emits.
    output_schema: SchemaRef,
    /// Cached plan properties derived from `output_schema`.
    properties: PlanProperties,
    /// DataFusion metrics set (row counts etc.).
    metrics: ExecutionPlanMetricsSet,
    /// Fixpoint iteration cap for recursive strata.
    max_iterations: usize,
    /// Wall-clock budget for the whole program.
    timeout: Duration,
    /// Memory cap (bytes) on derived facts, enforced by the fixpoint.
    max_derived_bytes: usize,
    /// Whether BEST BY tie-breaking must be deterministic.
    deterministic_best_by: bool,
    /// Whether probabilities outside [0, 1] are rejected.
    strict_probability_domain: bool,
    /// Tolerance used when validating/comparing probabilities.
    probability_epsilon: f64,
    /// Whether to run exact weighted model counting (BDD-based).
    exact_probability: bool,
    /// Variable cap for exact WMC before falling back to approximation.
    max_bdd_variables: usize,
    // Result slots below are written once at the end of `run_program` and
    // read by the caller after the stream completes.
    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    derivation_tracker: Arc<StdRwLock<Option<Arc<ProvenanceStore>>>>,
    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
    peak_memory_slot: Arc<StdRwLock<usize>>,
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
    /// How many proofs to keep per fact when tracking provenance.
    top_k_proofs: usize,
}
138
impl fmt::Debug for LocyProgramExec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only summary counts and limits are printed; contexts, slots, and
        // plan internals are elided via `finish_non_exhaustive`.
        f.debug_struct("LocyProgramExec")
            .field("strata_count", &self.strata.len())
            .field("commands_count", &self.commands.len())
            .field("max_iterations", &self.max_iterations)
            .field("timeout", &self.timeout)
            .field("output_schema", &self.output_schema)
            .field("max_derived_bytes", &self.max_derived_bytes)
            .finish_non_exhaustive()
    }
}
151
impl LocyProgramExec {
    /// Builds the program node. Plan properties are derived from
    /// `output_schema`; all result slots start empty and are filled when the
    /// node executes.
    #[expect(
        clippy::too_many_arguments,
        reason = "execution plan node requires full graph and session context"
    )]
    pub fn new(
        strata: Vec<LocyStratum>,
        commands: Vec<LocyCommand>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        output_schema: SchemaRef,
        max_iterations: usize,
        timeout: Duration,
        max_derived_bytes: usize,
        deterministic_best_by: bool,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        top_k_proofs: usize,
    ) -> Self {
        let properties = compute_plan_properties(Arc::clone(&output_schema));
        Self {
            strata,
            commands,
            derived_scan_registry,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            output_schema,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
            max_iterations,
            timeout,
            max_derived_bytes,
            deterministic_best_by,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            derived_store_slot: Arc::new(StdRwLock::new(None)),
            approximate_slot: Arc::new(StdRwLock::new(HashMap::new())),
            derivation_tracker: Arc::new(StdRwLock::new(None)),
            iteration_counts_slot: Arc::new(StdRwLock::new(HashMap::new())),
            peak_memory_slot: Arc::new(StdRwLock::new(0)),
            warnings_slot: Arc::new(StdRwLock::new(Vec::new())),
            command_results_slot: Arc::new(StdRwLock::new(Vec::new())),
            top_k_proofs,
        }
    }

    /// Slot that receives the final [`DerivedStore`] once execution finishes.
    pub fn derived_store_slot(&self) -> Arc<StdRwLock<Option<DerivedStore>>> {
        Arc::clone(&self.derived_store_slot)
    }

    /// Installs a provenance tracker that records fact derivations during
    /// execution. Silently skipped if the slot's lock is poisoned.
    pub fn set_derivation_tracker(&self, tracker: Arc<ProvenanceStore>) {
        if let Ok(mut guard) = self.derivation_tracker.write() {
            *guard = Some(tracker);
        }
    }

    /// Slot receiving per-rule fixpoint iteration counts.
    pub fn iteration_counts_slot(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
        Arc::clone(&self.iteration_counts_slot)
    }

    /// Slot receiving the peak derived-fact memory usage in bytes.
    pub fn peak_memory_slot(&self) -> Arc<StdRwLock<usize>> {
        Arc::clone(&self.peak_memory_slot)
    }

    /// Slot receiving runtime warnings accumulated during evaluation.
    pub fn warnings_slot(&self) -> Arc<StdRwLock<Vec<RuntimeWarning>>> {
        Arc::clone(&self.warnings_slot)
    }

    /// Slot receiving, per rule, the reasons its probabilities are
    /// approximate rather than exact.
    pub fn approximate_slot(&self) -> Arc<StdRwLock<HashMap<String, Vec<String>>>> {
        Arc::clone(&self.approximate_slot)
    }

    /// Slot receiving `(command index, result)` pairs for inline commands.
    pub fn command_results_slot(&self) -> Arc<StdRwLock<Vec<(usize, CommandResult)>>> {
        Arc::clone(&self.command_results_slot)
    }
}
267
impl DisplayAs for LocyProgramExec {
    /// One-line EXPLAIN rendering; the same text is used for every
    /// `DisplayFormatType`.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "LocyProgramExec: strata={}, commands={}, max_iter={}, timeout={:?}",
            self.strata.len(),
            self.commands.len(),
            self.max_iterations,
            self.timeout,
        )
    }
}
280
impl ExecutionPlan for LocyProgramExec {
    fn name(&self) -> &str {
        "LocyProgramExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    // Leaf node: rule bodies are executed as independent subplans and facts
    // flow through the derived-scan registry, not through plan children.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "LocyProgramExec has no children".to_string(),
            ));
        }
        Ok(self)
    }

    /// Snapshots all configuration and shared slots into an async future
    /// that runs the whole program (`run_program`), and returns a stream
    /// that drives that future on first poll, then emits the resulting
    /// batches one by one.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);

        // Clone / Arc-clone everything the future captures so it is 'static
        // and independent of `self`.
        let strata = self.strata.clone();
        let registry = Arc::clone(&self.derived_scan_registry);
        let graph_ctx = Arc::clone(&self.graph_ctx);
        let session_ctx = Arc::clone(&self.session_ctx);
        let storage = Arc::clone(&self.storage);
        let schema_info = Arc::clone(&self.schema_info);
        let params = self.params.clone();
        let output_schema = Arc::clone(&self.output_schema);
        let max_iterations = self.max_iterations;
        let timeout = self.timeout;
        let max_derived_bytes = self.max_derived_bytes;
        let deterministic_best_by = self.deterministic_best_by;
        let strict_probability_domain = self.strict_probability_domain;
        let probability_epsilon = self.probability_epsilon;
        let exact_probability = self.exact_probability;
        let max_bdd_variables = self.max_bdd_variables;
        let derived_store_slot = Arc::clone(&self.derived_store_slot);
        let approximate_slot = Arc::clone(&self.approximate_slot);
        let iteration_counts_slot = Arc::clone(&self.iteration_counts_slot);
        let peak_memory_slot = Arc::clone(&self.peak_memory_slot);
        // A poisoned tracker lock degrades to "no provenance tracking".
        let derivation_tracker = self.derivation_tracker.read().ok().and_then(|g| g.clone());
        let warnings_slot = Arc::clone(&self.warnings_slot);
        let commands = self.commands.clone();
        let command_results_slot = Arc::clone(&self.command_results_slot);
        let top_k_proofs = self.top_k_proofs;

        let fut = async move {
            run_program(
                strata,
                commands,
                registry,
                graph_ctx,
                session_ctx,
                storage,
                schema_info,
                params,
                output_schema,
                max_iterations,
                timeout,
                max_derived_bytes,
                deterministic_best_by,
                strict_probability_domain,
                probability_epsilon,
                exact_probability,
                max_bdd_variables,
                derived_store_slot,
                approximate_slot,
                iteration_counts_slot,
                peak_memory_slot,
                derivation_tracker,
                warnings_slot,
                command_results_slot,
                top_k_proofs,
            )
            .await
        };

        Ok(Box::pin(ProgramStream {
            state: ProgramStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.output_schema),
            metrics,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
389
/// State machine behind [`ProgramStream`]: the whole program first runs to
/// completion (`Running` holds the future), then the collected batches are
/// emitted one at a time (`Emitting` holds the batches and the next index),
/// and finally the stream is exhausted (`Done`).
enum ProgramStreamState {
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    Emitting(Vec<RecordBatch>, usize),
    Done,
}
399
/// Record-batch stream returned by `LocyProgramExec::execute`; wraps the
/// program future and replays its result batches with output metrics.
struct ProgramStream {
    // Current phase; see `ProgramStreamState`.
    state: ProgramStreamState,
    // Schema reported to DataFusion (the node's output schema).
    schema: SchemaRef,
    // Records rows emitted for the plan's metrics set.
    metrics: BaselineMetrics,
}
405
406impl Stream for ProgramStream {
407 type Item = DFResult<RecordBatch>;
408
409 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
410 let this = self.get_mut();
411 loop {
412 match &mut this.state {
413 ProgramStreamState::Running(fut) => match fut.as_mut().poll(cx) {
414 Poll::Ready(Ok(batches)) => {
415 if batches.is_empty() {
416 this.state = ProgramStreamState::Done;
417 return Poll::Ready(None);
418 }
419 this.state = ProgramStreamState::Emitting(batches, 0);
420 }
421 Poll::Ready(Err(e)) => {
422 this.state = ProgramStreamState::Done;
423 return Poll::Ready(Some(Err(e)));
424 }
425 Poll::Pending => return Poll::Pending,
426 },
427 ProgramStreamState::Emitting(batches, idx) => {
428 if *idx >= batches.len() {
429 this.state = ProgramStreamState::Done;
430 return Poll::Ready(None);
431 }
432 let batch = batches[*idx].clone();
433 *idx += 1;
434 this.metrics.record_output(batch.num_rows());
435 return Poll::Ready(Some(Ok(batch)));
436 }
437 ProgramStreamState::Done => return Poll::Ready(None),
438 }
439 }
440 }
441}
442
impl RecordBatchStream for ProgramStream {
    /// Every batch this stream emits conforms to the node's output schema.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
448
449#[allow(dead_code)]
465fn execute_query_inline(
466 query: &GoalQuery,
467 derived_store: &DerivedStore,
468 params: &HashMap<String, Value>,
469) -> DFResult<Vec<FactRow>> {
470 let rule_name = query.rule_name.to_string();
471 let batches = derived_store.get(&rule_name).cloned().unwrap_or_default();
472 let rows = super::locy_eval::record_batches_to_locy_rows(&batches);
473
474 let filtered = if let Some(ref where_expr) = query.where_expr {
476 rows.into_iter()
477 .filter(|row| {
478 let merged = super::locy_query::merge_params(row, params);
479 super::locy_eval::eval_expr(where_expr, &merged)
480 .map(|v| v.as_bool().unwrap_or(false))
481 .unwrap_or(false)
482 })
483 .collect()
484 } else {
485 rows
486 };
487
488 super::locy_query::apply_return_clause(filtered, &query.return_clause, params)
490 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
491}
492
/// Plans and executes an embedded Cypher query, converting the resulting
/// record batches into Locy fact rows.
///
/// # Errors
/// Returns an execution error when Cypher planning fails, and propagates
/// any error from executing the resulting subplan.
async fn execute_cypher_inline(
    query: &uni_cypher::ast::Query,
    schema_info: &Arc<UniSchema>,
    params: &HashMap<String, Value>,
    graph_ctx: &Arc<GraphExecutionContext>,
    session_ctx: &Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: &Arc<StorageManager>,
) -> DFResult<Vec<FactRow>> {
    let planner = crate::query::planner::QueryPlanner::new(Arc::clone(schema_info));
    let logical_plan = planner.plan(query.clone()).map_err(|e| {
        datafusion::error::DataFusionError::Execution(format!("Cypher plan error: {e}"))
    })?;
    let batches = execute_subplan(
        &logical_plan,
        params,
        &HashMap::new(),
        graph_ctx,
        session_ctx,
        storage,
        schema_info,
    )
    .await?;
    Ok(super::locy_eval::record_batches_to_locy_rows(&batches))
}
518
519#[allow(dead_code)]
522fn needs_node_enrichment(query: &GoalQuery) -> bool {
523 let where_has_property = query
524 .where_expr
525 .as_ref()
526 .is_some_and(expr_has_property_access);
527 let return_has_property = query.return_clause.as_ref().is_some_and(|rc| {
528 rc.items.iter().any(|item| match item {
529 uni_cypher::ast::ReturnItem::Expr { expr, .. } => expr_has_property_access(expr),
530 uni_cypher::ast::ReturnItem::All => false,
531 })
532 });
533 where_has_property || return_has_property
534}
535
/// Recursively checks whether `expr` contains any property access
/// (`Expr::Property`) anywhere in its tree, walking every composite
/// expression variant. Leaf variants other than `Property` return false via
/// the catch-all arm.
#[allow(dead_code)]
fn expr_has_property_access(expr: &Expr) -> bool {
    match expr {
        Expr::Property(..) => true,
        Expr::BinaryOp { left, right, .. } => {
            expr_has_property_access(left) || expr_has_property_access(right)
        }
        Expr::UnaryOp { expr, .. } => expr_has_property_access(expr),
        Expr::FunctionCall { args, .. } => args.iter().any(expr_has_property_access),
        Expr::List(items) => items.iter().any(expr_has_property_access),
        Expr::Map(entries) => entries.iter().any(|(_, e)| expr_has_property_access(e)),
        Expr::Case {
            expr: case_expr,
            when_then,
            else_expr,
        } => {
            // Checks the scrutinee, every WHEN/THEN pair, and the ELSE arm.
            case_expr
                .as_ref()
                .is_some_and(|e| expr_has_property_access(e))
                || when_then
                    .iter()
                    .any(|(w, t)| expr_has_property_access(w) || expr_has_property_access(t))
                || else_expr
                    .as_ref()
                    .is_some_and(|e| expr_has_property_access(e))
        }
        Expr::IsNull(e) | Expr::IsNotNull(e) | Expr::IsUnique(e) => expr_has_property_access(e),
        Expr::In { expr, list } => expr_has_property_access(expr) || expr_has_property_access(list),
        Expr::ArrayIndex { array, index } => {
            expr_has_property_access(array) || expr_has_property_access(index)
        }
        Expr::ArraySlice { array, start, end } => {
            expr_has_property_access(array)
                || start.as_ref().is_some_and(|e| expr_has_property_access(e))
                || end.as_ref().is_some_and(|e| expr_has_property_access(e))
        }
        Expr::Quantifier {
            list, predicate, ..
        } => expr_has_property_access(list) || expr_has_property_access(predicate),
        Expr::Reduce {
            init, list, expr, ..
        } => {
            expr_has_property_access(init)
                || expr_has_property_access(list)
                || expr_has_property_access(expr)
        }
        Expr::ListComprehension {
            list,
            where_clause,
            map_expr,
            ..
        } => {
            expr_has_property_access(list)
                || where_clause
                    .as_ref()
                    .is_some_and(|e| expr_has_property_access(e))
                || expr_has_property_access(map_expr)
        }
        Expr::PatternComprehension {
            where_clause,
            map_expr,
            ..
        } => {
            where_clause
                .as_ref()
                .is_some_and(|e| expr_has_property_access(e))
                || expr_has_property_access(map_expr)
        }
        Expr::ValidAt {
            entity, timestamp, ..
        } => expr_has_property_access(entity) || expr_has_property_access(timestamp),
        Expr::MapProjection { base, .. } => expr_has_property_access(base),
        Expr::LabelCheck { expr, .. } => expr_has_property_access(expr),
        // Remaining leaf variants (literals, variables, parameters, ...)
        // cannot contain a property access.
        _ => false,
    }
}
615
/// Evaluates the entire Locy program: each stratum in order (recursive
/// strata via [`FixpointExec`], non-recursive strata by executing each
/// clause's subplan inline), publishing derived facts both to the registry
/// (for downstream scans) and to a [`DerivedStore`]. Afterwards it records
/// peak derived memory, runs inline Cypher commands that precede the first
/// DERIVE command, fills the shared result slots, and returns a single
/// stats batch.
///
/// # Errors
/// Fails on timeout, on any subplan/fixpoint error, and on probability
/// post-processing errors.
#[expect(
    clippy::too_many_arguments,
    reason = "program evaluation requires full graph and session context"
)]
async fn run_program(
    strata: Vec<LocyStratum>,
    commands: Vec<LocyCommand>,
    registry: Arc<DerivedScanRegistry>,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    params: HashMap<String, Value>,
    output_schema: SchemaRef,
    max_iterations: usize,
    timeout: Duration,
    max_derived_bytes: usize,
    deterministic_best_by: bool,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
    peak_memory_slot: Arc<StdRwLock<usize>>,
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
    top_k_proofs: usize,
) -> DFResult<Vec<RecordBatch>> {
    let start = Instant::now();
    let mut derived_store = DerivedStore::new();

    for stratum in &strata {
        // Make facts derived by earlier strata visible to this stratum's
        // IS-refs before evaluating it.
        write_cross_stratum_facts(&registry, &derived_store, stratum);

        // Budget check: the remaining time is also passed to the fixpoint.
        let remaining_timeout = timeout.saturating_sub(start.elapsed());
        if remaining_timeout.is_zero() {
            return Err(datafusion::error::DataFusionError::Execution(
                "Locy program timeout exceeded during stratum evaluation".to_string(),
            ));
        }

        if stratum.is_recursive {
            // Recursive stratum: run all its rules together to fixpoint.
            let fixpoint_rules =
                convert_to_fixpoint_plans(&stratum.rules, &registry, deterministic_best_by)?;
            let fixpoint_schema = build_fixpoint_output_schema(&stratum.rules);

            let exec = FixpointExec::new(
                fixpoint_rules,
                max_iterations,
                remaining_timeout,
                Arc::clone(&graph_ctx),
                Arc::clone(&session_ctx),
                Arc::clone(&storage),
                Arc::clone(&schema_info),
                params.clone(),
                Arc::clone(&registry),
                fixpoint_schema,
                max_derived_bytes,
                derivation_tracker.clone(),
                Arc::clone(&iteration_counts_slot),
                strict_probability_domain,
                probability_epsilon,
                exact_probability,
                max_bdd_variables,
                Arc::clone(&warnings_slot),
                Arc::clone(&approximate_slot),
                top_k_proofs,
            );

            let task_ctx = session_ctx.read().task_ctx();
            let exec_arc: Arc<dyn ExecutionPlan> = Arc::new(exec);
            let batches = collect_all_partitions(&exec_arc, task_ctx).await?;

            // Publish each rule's share of the fixpoint output to the
            // registry and the derived store.
            for rule in &stratum.rules {
                if rule.yield_schema.is_empty() {
                    continue;
                }
                let rule_entries = registry.entries_for_rule(&rule.name);
                for entry in rule_entries {
                    if !entry.is_self_ref {
                        // NOTE(review): batches are attributed to a rule by
                        // FIELD COUNT only, so two rules with equally wide
                        // yield schemas would both claim the same batches —
                        // confirm this is safe for this stratum shape. Also,
                        // the schema is rebuilt inside the filter closure
                        // (loop-invariant work) on every batch.
                        let all_facts: Vec<RecordBatch> = batches
                            .iter()
                            .filter(|b| {
                                let rule_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
                                b.schema().fields().len() == rule_schema.fields().len()
                            })
                            .cloned()
                            .collect();
                        let mut guard = entry.data.write();
                        *guard = if all_facts.is_empty() {
                            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
                        } else {
                            all_facts
                        };
                    }
                }
                derived_store.insert(rule.name.clone(), batches.clone());
            }
        } else {
            // Non-recursive stratum: evaluate each rule's clauses directly,
            // clause by clause, without iterating to a fixpoint.
            let fixpoint_rules =
                convert_to_fixpoint_plans(&stratum.rules, &registry, deterministic_best_by)?;
            let task_ctx = session_ctx.read().task_ctx();

            for (rule, fp_rule) in stratum.rules.iter().zip(fixpoint_rules.iter()) {
                if rule.yield_schema.is_empty() {
                    continue;
                }

                // Facts per clause, tagged with clause index for lineage.
                let mut tagged_clause_facts: Vec<(usize, Vec<RecordBatch>)> = Vec::new();
                for (clause_idx, (clause, fp_clause)) in
                    rule.clauses.iter().zip(fp_rule.clauses.iter()).enumerate()
                {
                    let mut batches = execute_subplan(
                        &clause.body,
                        &params,
                        &HashMap::new(),
                        &graph_ctx,
                        &session_ctx,
                        &storage,
                        &schema_info,
                    )
                    .await?;

                    // Apply negated IS-refs: either probabilistic complement
                    // (when the target carries a probability and the rule
                    // has a probability column) or a plain anti-join.
                    for binding in &fp_clause.is_ref_bindings {
                        if binding.negated
                            && !binding.anti_join_cols.is_empty()
                            && let Some(entry) = registry.get(binding.derived_scan_index)
                        {
                            let neg_facts = entry.data.read().clone();
                            if !neg_facts.is_empty() {
                                if binding.target_has_prob && fp_rule.prob_column_name.is_some() {
                                    let complement_col =
                                        format!("__prob_complement_{}", binding.rule_name);
                                    if let Some(prob_col) = &binding.target_prob_col {
                                        batches =
                                            super::locy_fixpoint::apply_prob_complement_composite(
                                                batches,
                                                &neg_facts,
                                                &binding.anti_join_cols,
                                                prob_col,
                                                &complement_col,
                                            )?;
                                    } else {
                                        batches = super::locy_fixpoint::apply_anti_join_composite(
                                            batches,
                                            &neg_facts,
                                            &binding.anti_join_cols,
                                        )?;
                                    }
                                } else {
                                    batches = super::locy_fixpoint::apply_anti_join_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                    )?;
                                }
                            }
                        }
                    }

                    // Fold any probability-complement columns produced above
                    // into the rule's probability column.
                    let complement_cols: Vec<String> = if !batches.is_empty() {
                        batches[0]
                            .schema()
                            .fields()
                            .iter()
                            .filter(|f| f.name().starts_with("__prob_complement_"))
                            .map(|f| f.name().clone())
                            .collect()
                    } else {
                        vec![]
                    };
                    if !complement_cols.is_empty() {
                        batches = super::locy_fixpoint::multiply_prob_factors(
                            batches,
                            fp_rule.prob_column_name.as_deref(),
                            &complement_cols,
                        )?;
                    }

                    tagged_clause_facts.push((clause_idx, batches));
                }

                // Record provenance (and detect shared subformulas) when a
                // tracker was installed.
                let shared_info = if let Some(ref tracker) = derivation_tracker {
                    super::locy_fixpoint::record_and_detect_lineage_nonrecursive(
                        fp_rule,
                        &tagged_clause_facts,
                        tracker,
                        &warnings_slot,
                        &registry,
                        top_k_proofs,
                    )
                } else {
                    None
                };

                let mut all_clause_facts: Vec<RecordBatch> = tagged_clause_facts
                    .into_iter()
                    .flat_map(|(_, batches)| batches)
                    .collect();

                // Exact weighted model counting needs both the shared-info
                // analysis and the provenance tracker.
                if exact_probability
                    && let Some(ref info) = shared_info
                    && let Some(ref tracker) = derivation_tracker
                {
                    all_clause_facts = super::locy_fixpoint::apply_exact_wmc(
                        all_clause_facts,
                        fp_rule,
                        info,
                        tracker,
                        max_bdd_variables,
                        &warnings_slot,
                        &approximate_slot,
                    )?;
                }

                // FOLD/HAVING/BEST BY/priority post-processing.
                let facts = super::locy_fixpoint::apply_post_fixpoint_chain(
                    all_clause_facts,
                    fp_rule,
                    &task_ctx,
                    strict_probability_domain,
                    probability_epsilon,
                )
                .await?;

                write_facts_to_registry(&registry, &rule.name, &facts);
                derived_store.insert(rule.name.clone(), facts);
            }
        }
    }

    // Buffer memory of everything derived; this is the value reported as
    // "peak" (measured once, at the end of all strata).
    let peak_bytes: usize = derived_store
        .relations
        .values()
        .flat_map(|batches| batches.iter())
        .map(|b| {
            b.columns()
                .iter()
                .map(|col| col.get_buffer_memory_size())
                .sum::<usize>()
        })
        .sum();
    // Poisoned slot locks are treated as a bug (panic via unwrap).
    *peak_memory_slot.write().unwrap() = peak_bytes;

    // Run inline Cypher commands, but only those appearing BEFORE the first
    // DERIVE command; later ones are handled elsewhere.
    let first_derive_idx = commands
        .iter()
        .position(|c| matches!(c, LocyCommand::Derive { .. }));
    let mut inline_results: Vec<(usize, CommandResult)> = Vec::new();
    for (cmd_idx, cmd) in commands.iter().enumerate() {
        if let LocyCommand::Cypher { query } = cmd {
            if first_derive_idx.is_some_and(|di| cmd_idx > di) {
                continue;
            }
            let rows = execute_cypher_inline(
                query,
                &schema_info,
                &params,
                &graph_ctx,
                &session_ctx,
                &storage,
            )
            .await?;
            inline_results.push((cmd_idx, CommandResult::Cypher(rows)));
        }
    }
    *command_results_slot.write().unwrap() = inline_results;

    // Publish the derived store and return the per-rule stats batch.
    let stats = vec![build_stats_batch(&derived_store, &strata, output_schema)];
    *derived_store_slot.write().unwrap() = Some(derived_store);
    Ok(stats)
}
938
939fn write_cross_stratum_facts(
945 registry: &DerivedScanRegistry,
946 derived_store: &DerivedStore,
947 stratum: &LocyStratum,
948) {
949 for rule in &stratum.rules {
951 for clause in &rule.clauses {
952 for is_ref in &clause.is_refs {
953 if let Some(facts) = derived_store.get(&is_ref.rule_name) {
956 write_facts_to_registry(registry, &is_ref.rule_name, facts);
957 }
958 }
959 }
960 }
961}
962
963fn write_facts_to_registry(registry: &DerivedScanRegistry, rule_name: &str, facts: &[RecordBatch]) {
965 let entries = registry.entries_for_rule(rule_name);
966 for entry in entries {
967 if !entry.is_self_ref {
968 let mut guard = entry.data.write();
969 *guard = if facts.is_empty() || facts.iter().all(|b| b.num_rows() == 0) {
970 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
971 } else {
972 facts
977 .iter()
978 .filter(|b| b.num_rows() > 0)
979 .map(|b| {
980 RecordBatch::try_new(Arc::clone(&entry.schema), b.columns().to_vec())
981 .unwrap_or_else(|_| b.clone())
982 })
983 .collect()
984 };
985 }
986 }
987}
988
/// Lowers logical rule plans into executable [`FixpointRulePlan`]s: builds
/// each rule's Arrow yield schema (appending a `__priority` column when the
/// rule is prioritized), resolves IS-refs against the registry, and converts
/// FOLD and BEST BY specifications into positional form.
///
/// # Errors
/// Propagates errors from IS-ref resolution and FOLD/BEST BY conversion.
fn convert_to_fixpoint_plans(
    rules: &[LocyRulePlan],
    registry: &DerivedScanRegistry,
    deterministic_best_by: bool,
) -> DFResult<Vec<FixpointRulePlan>> {
    rules
        .iter()
        .map(|rule| {
            let yield_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
            // Positions of the key columns used for fact deduplication.
            let key_column_indices: Vec<usize> = rule
                .yield_schema
                .iter()
                .enumerate()
                .filter(|(_, yc)| yc.is_key)
                .map(|(i, _)| i)
                .collect();

            let clauses: Vec<FixpointClausePlan> = rule
                .clauses
                .iter()
                .map(|clause| {
                    let is_ref_bindings = convert_is_refs(&clause.is_refs, registry)?;
                    Ok(FixpointClausePlan {
                        body_logical: clause.body.clone(),
                        is_ref_bindings,
                        priority: clause.priority,
                        along_bindings: clause.along_bindings.clone(),
                    })
                })
                .collect::<DFResult<Vec<_>>>()?;

            let fold_bindings = convert_fold_bindings(&rule.fold_bindings, &rule.yield_schema)?;
            let best_by_criteria =
                convert_best_by_criteria(&rule.best_by_criteria, &rule.yield_schema)?;

            let has_priority = rule.priority.is_some();

            // Prioritized rules carry an extra nullable `__priority` column
            // so the fixpoint can rank facts across clauses.
            let yield_schema = if has_priority {
                let mut fields: Vec<Arc<Field>> = yield_schema.fields().iter().cloned().collect();
                fields.push(Arc::new(Field::new("__priority", DataType::Int64, true)));
                ArrowSchema::new(fields)
            } else {
                yield_schema
            };

            // At most one yield column is expected to be marked as the
            // probability column; the first match wins.
            let prob_column_name = rule
                .yield_schema
                .iter()
                .find(|yc| yc.is_prob)
                .map(|yc| yc.name.clone());

            Ok(FixpointRulePlan {
                name: rule.name.clone(),
                clauses,
                yield_schema: Arc::new(yield_schema),
                key_column_indices,
                priority: rule.priority,
                has_fold: !rule.fold_bindings.is_empty(),
                fold_bindings,
                having: rule.having.clone(),
                has_best_by: !rule.best_by_criteria.is_empty(),
                best_by_criteria,
                has_priority,
                deterministic: deterministic_best_by,
                prob_column_name,
            })
        })
        .collect()
}
1064
1065fn convert_is_refs(
1067 is_refs: &[LocyIsRef],
1068 registry: &DerivedScanRegistry,
1069) -> DFResult<Vec<IsRefBinding>> {
1070 is_refs
1071 .iter()
1072 .map(|is_ref| {
1073 let entries = registry.entries_for_rule(&is_ref.rule_name);
1074 let entry = entries
1076 .iter()
1077 .find(|e| e.is_self_ref)
1078 .or_else(|| entries.first())
1079 .ok_or_else(|| {
1080 datafusion::error::DataFusionError::Plan(format!(
1081 "No derived scan entry found for IS-ref to '{}'",
1082 is_ref.rule_name
1083 ))
1084 })?;
1085
1086 let anti_join_cols = if is_ref.negated {
1091 let mut cols: Vec<(String, String)> = is_ref
1092 .subjects
1093 .iter()
1094 .enumerate()
1095 .filter_map(|(i, s)| {
1096 if let uni_cypher::ast::Expr::Variable(var) = s {
1097 let right_col = entry
1098 .schema
1099 .fields()
1100 .get(i)
1101 .map(|f| f.name().clone())
1102 .unwrap_or_else(|| var.clone());
1103 Some((var.clone(), right_col))
1106 } else {
1107 None
1108 }
1109 })
1110 .collect();
1111 if let Some(uni_cypher::ast::Expr::Variable(target_var)) = &is_ref.target {
1116 let target_idx = is_ref.subjects.len();
1117 if let Some(field) = entry.schema.fields().get(target_idx) {
1118 cols.push((target_var.clone(), field.name().clone()));
1119 }
1120 }
1121 cols
1122 } else {
1123 Vec::new()
1124 };
1125
1126 let provenance_join_cols: Vec<(String, String)> = is_ref
1130 .subjects
1131 .iter()
1132 .enumerate()
1133 .filter_map(|(i, s)| {
1134 if let uni_cypher::ast::Expr::Variable(var) = s {
1135 let right_col = entry
1136 .schema
1137 .fields()
1138 .get(i)
1139 .map(|f| f.name().clone())
1140 .unwrap_or_else(|| var.clone());
1141 Some((var.clone(), right_col))
1142 } else {
1143 None
1144 }
1145 })
1146 .collect();
1147
1148 Ok(IsRefBinding {
1149 derived_scan_index: entry.scan_index,
1150 rule_name: is_ref.rule_name.clone(),
1151 is_self_ref: entry.is_self_ref,
1152 negated: is_ref.negated,
1153 anti_join_cols,
1154 target_has_prob: is_ref.target_has_prob,
1155 target_prob_col: is_ref.target_prob_col.clone(),
1156 provenance_join_cols,
1157 })
1158 })
1159 .collect()
1160}
1161
/// Converts `(output_name, aggregate_expr)` FOLD bindings into executable
/// [`FoldBinding`]s, resolving each binding's input column position in the
/// rule's yield schema.
///
/// # Errors
/// Propagates parse errors from [`parse_fold_aggregate`].
fn convert_fold_bindings(
    fold_bindings: &[(String, Expr)],
    yield_schema: &[LocyYieldColumn],
) -> DFResult<Vec<FoldBinding>> {
    fold_bindings
        .iter()
        .map(|(name, expr)| {
            let (kind, _input_col_name) = parse_fold_aggregate(expr)?;

            // COUNT()/MCOUNT() with no arguments counts rows; the input
            // column index is unused, so 0 is just a placeholder.
            if kind == FoldAggKind::CountAll {
                return Ok(FoldBinding {
                    output_name: name.clone(),
                    kind,
                    input_col_index: 0, input_col_name: None,
                });
            }

            // NOTE(review): the input column is located by the binding's
            // OUTPUT name, and the aggregate's parsed argument name
            // (`_input_col_name`) is discarded — this assumes the yield
            // column carrying the aggregate input shares the output name.
            // Confirm this is intentional; unknown names fall back to
            // column 0 silently.
            let input_col_index = yield_schema
                .iter()
                .position(|yc| yc.name == *name)
                .unwrap_or(0);
            Ok(FoldBinding {
                output_name: name.clone(),
                kind,
                input_col_index,
                input_col_name: Some(name.clone()),
            })
        })
        .collect()
}
1204
1205fn parse_fold_aggregate(expr: &Expr) -> DFResult<(FoldAggKind, String)> {
1207 match expr {
1208 Expr::FunctionCall { name, args, .. } => {
1209 let upper = name.to_uppercase();
1210 let is_count = matches!(upper.as_str(), "COUNT" | "MCOUNT");
1211
1212 if is_count && args.is_empty() {
1214 return Ok((FoldAggKind::CountAll, String::new()));
1215 }
1216
1217 let kind = match upper.as_str() {
1218 "SUM" | "MSUM" => FoldAggKind::Sum,
1219 "MAX" | "MMAX" => FoldAggKind::Max,
1220 "MIN" | "MMIN" => FoldAggKind::Min,
1221 "COUNT" | "MCOUNT" => FoldAggKind::Count,
1222 "AVG" => FoldAggKind::Avg,
1223 "COLLECT" => FoldAggKind::Collect,
1224 "MNOR" => FoldAggKind::Nor,
1225 "MPROD" => FoldAggKind::Prod,
1226 _ => {
1227 return Err(datafusion::error::DataFusionError::Plan(format!(
1228 "Unknown FOLD aggregate function: {}",
1229 name
1230 )));
1231 }
1232 };
1233 let col_name = match args.first() {
1234 Some(Expr::Variable(v)) => v.clone(),
1235 Some(Expr::Property(_, prop)) => prop.clone(),
1236 Some(other) => other.to_string_repr(),
1237 None => {
1238 return Err(datafusion::error::DataFusionError::Plan(
1239 "FOLD aggregate function requires at least one argument".to_string(),
1240 ));
1241 }
1242 };
1243 Ok((kind, col_name))
1244 }
1245 _ => Err(datafusion::error::DataFusionError::Plan(
1246 "FOLD binding must be a function call (e.g., SUM(x))".to_string(),
1247 )),
1248 }
1249}
1250
1251fn convert_best_by_criteria(
1258 criteria: &[(Expr, bool)],
1259 yield_schema: &[LocyYieldColumn],
1260) -> DFResult<Vec<SortCriterion>> {
1261 criteria
1262 .iter()
1263 .map(|(expr, ascending)| {
1264 let col_name = match expr {
1265 Expr::Property(_, prop) => prop.clone(),
1266 Expr::Variable(v) => v.clone(),
1267 _ => {
1268 return Err(datafusion::error::DataFusionError::Plan(
1269 "BEST BY criterion must be a variable or property reference".to_string(),
1270 ));
1271 }
1272 };
1273 let col_index = yield_schema
1275 .iter()
1276 .position(|yc| yc.name == col_name)
1277 .or_else(|| {
1278 let short_name = col_name.rsplit('.').next().unwrap_or(&col_name);
1279 yield_schema.iter().position(|yc| yc.name == short_name)
1280 })
1281 .ok_or_else(|| {
1282 datafusion::error::DataFusionError::Plan(format!(
1283 "BEST BY column '{}' not found in yield schema",
1284 col_name
1285 ))
1286 })?;
1287 Ok(SortCriterion {
1288 col_index,
1289 ascending: *ascending,
1290 nulls_first: false,
1291 })
1292 })
1293 .collect()
1294}
1295
1296fn yield_columns_to_arrow_schema(columns: &[LocyYieldColumn]) -> ArrowSchema {
1302 let fields: Vec<Arc<Field>> = columns
1303 .iter()
1304 .map(|yc| Arc::new(Field::new(&yc.name, yc.data_type.clone(), true)))
1305 .collect();
1306 ArrowSchema::new(fields)
1307}
1308
1309fn build_fixpoint_output_schema(rules: &[LocyRulePlan]) -> SchemaRef {
1311 if let Some(rule) = rules.first() {
1314 Arc::new(yield_columns_to_arrow_schema(&rule.yield_schema))
1315 } else {
1316 Arc::new(ArrowSchema::empty())
1317 }
1318}
1319
1320fn build_stats_batch(
1322 derived_store: &DerivedStore,
1323 _strata: &[LocyStratum],
1324 output_schema: SchemaRef,
1325) -> RecordBatch {
1326 let mut rule_names: Vec<String> = derived_store.rule_names().map(String::from).collect();
1328 rule_names.sort();
1329
1330 let name_col: arrow_array::StringArray = rule_names.iter().map(|s| Some(s.as_str())).collect();
1331 let count_col: arrow_array::Int64Array = rule_names
1332 .iter()
1333 .map(|name| Some(derived_store.fact_count(name) as i64))
1334 .collect();
1335
1336 let stats_schema = stats_schema();
1337 RecordBatch::try_new(stats_schema, vec![Arc::new(name_col), Arc::new(count_col)])
1338 .unwrap_or_else(|_| RecordBatch::new_empty(output_schema))
1339}
1340
1341pub fn stats_schema() -> SchemaRef {
1343 Arc::new(ArrowSchema::new(vec![
1344 Arc::new(Field::new("rule_name", DataType::Utf8, false)),
1345 Arc::new(Field::new("fact_count", DataType::Int64, false)),
1346 ]))
1347}
1348
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, LargeBinaryArray, StringArray};

    /// Builds a one-column (`x: LargeBinary`, nullable) batch holding `values`.
    fn binary_batch(values: &[&[u8]]) -> RecordBatch {
        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
            "x",
            DataType::LargeBinary,
            true,
        ))]));
        let cells: Vec<Option<&[u8]>> = values.iter().map(|v| Some(*v)).collect();
        RecordBatch::try_new(schema, vec![Arc::new(LargeBinaryArray::from(cells))]).unwrap()
    }

    /// Convenience constructor for yield columns used by the schema tests.
    fn yield_col(name: &str, is_key: bool, data_type: DataType) -> LocyYieldColumn {
        LocyYieldColumn {
            name: name.to_string(),
            is_key,
            is_prob: false,
            data_type,
        }
    }

    /// Builds a FOLD-style aggregate call with a single variable argument.
    fn agg_call(name: &str, var: &str) -> Expr {
        Expr::FunctionCall {
            name: name.to_string(),
            args: vec![Expr::Variable(var.to_string())],
            distinct: false,
            window_spec: None,
        }
    }

    #[test]
    fn test_derived_store_insert_and_get() {
        let mut store = DerivedStore::new();
        assert!(store.get("test").is_none());

        store.insert("test".to_string(), vec![binary_batch(&[b"a" as &[u8], b"b"])]);

        let facts = store.get("test").unwrap();
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].num_rows(), 2);
    }

    #[test]
    fn test_derived_store_fact_count() {
        let mut store = DerivedStore::new();
        assert_eq!(store.fact_count("empty"), 0);

        // Counts are summed across all batches of a relation.
        store.insert(
            "test".to_string(),
            vec![
                binary_batch(&[b"a" as &[u8]]),
                binary_batch(&[b"b" as &[u8], b"c"]),
            ],
        );
        assert_eq!(store.fact_count("test"), 3);
    }

    #[test]
    fn test_stats_batch_schema() {
        let schema = stats_schema();
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.field(0).name(), "rule_name");
        assert_eq!(schema.field(0).data_type(), &DataType::Utf8);
        assert_eq!(schema.field(1).name(), "fact_count");
        assert_eq!(schema.field(1).data_type(), &DataType::Int64);
    }

    #[test]
    fn test_stats_batch_content() {
        let mut store = DerivedStore::new();
        store.insert(
            "reach".to_string(),
            vec![binary_batch(&[b"a" as &[u8], b"b"])],
        );

        let stats = build_stats_batch(&store, &[], stats_schema());
        assert_eq!(stats.num_rows(), 1);

        let names = stats
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let counts = stats
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(names.value(0), "reach");
        assert_eq!(counts.value(0), 2);
    }

    #[test]
    fn test_yield_columns_to_arrow_schema() {
        let columns = vec![
            yield_col("a", true, DataType::UInt64),
            yield_col("b", false, DataType::LargeUtf8),
            yield_col("c", true, DataType::Float64),
        ];

        let schema = yield_columns_to_arrow_schema(&columns);
        assert_eq!(schema.fields().len(), 3);

        let expected = [
            ("a", DataType::UInt64),
            ("b", DataType::LargeUtf8),
            ("c", DataType::Float64),
        ];
        for (i, (name, data_type)) in expected.iter().enumerate() {
            assert_eq!(schema.field(i).name(), name);
            assert_eq!(schema.field(i).data_type(), data_type);
            // All yield fields are emitted as nullable regardless of flags.
            assert!(schema.field(i).is_nullable());
        }
    }

    #[test]
    fn test_key_column_indices() {
        let columns = [
            yield_col("a", true, DataType::LargeBinary),
            yield_col("b", false, DataType::LargeBinary),
            yield_col("c", true, DataType::LargeBinary),
        ];

        let mut key_indices: Vec<usize> = Vec::new();
        for (i, yc) in columns.iter().enumerate() {
            if yc.is_key {
                key_indices.push(i);
            }
        }
        assert_eq!(key_indices, vec![0, 2]);
    }

    #[test]
    fn test_parse_fold_aggregate_sum() {
        let (kind, col) = parse_fold_aggregate(&agg_call("SUM", "cost")).unwrap();
        assert!(matches!(kind, FoldAggKind::Sum));
        assert_eq!(col, "cost");
    }

    #[test]
    fn test_parse_fold_aggregate_monotonic() {
        // M-prefixed (monotonic) names map to the same kinds as their plain forms.
        let (kind, col) = parse_fold_aggregate(&agg_call("MMAX", "score")).unwrap();
        assert!(matches!(kind, FoldAggKind::Max));
        assert_eq!(col, "score");
    }

    #[test]
    fn test_parse_fold_aggregate_unknown() {
        assert!(parse_fold_aggregate(&agg_call("UNKNOWN_AGG", "x")).is_err());
    }

    #[test]
    fn test_no_commands_returns_stats() {
        let store = DerivedStore::new();
        let stats = build_stats_batch(&store, &[], stats_schema());
        assert_eq!(stats.num_rows(), 0);
    }
}