1use crate::query::df_graph::GraphExecutionContext;
12use crate::query::df_graph::common::{
13 collect_all_partitions, compute_plan_properties, execute_subplan,
14};
15use crate::query::df_graph::locy_best_by::SortCriterion;
16use crate::query::df_graph::locy_explain::ProvenanceStore;
17use crate::query::df_graph::locy_fixpoint::{
18 DerivedScanRegistry, FixpointClausePlan, FixpointExec, FixpointRulePlan, IsRefBinding,
19};
20use crate::query::df_graph::locy_fold::{FoldAggKind, FoldBinding};
21use crate::query::planner_locy_types::{
22 LocyCommand, LocyIsRef, LocyRulePlan, LocyStratum, LocyYieldColumn,
23};
24use arrow_array::RecordBatch;
25use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
29use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
30use futures::Stream;
31use parking_lot::RwLock;
32use std::any::Any;
33use std::collections::HashMap;
34use std::fmt;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::sync::RwLock as StdRwLock;
38use std::task::{Context, Poll};
39use std::time::{Duration, Instant};
40use uni_common::Value;
41use uni_common::core::schema::Schema as UniSchema;
42use uni_cypher::ast::Expr;
43use uni_cypher::locy_ast::GoalQuery;
44use uni_locy::{CommandResult, FactRow, RuntimeWarning};
45use uni_store::storage::manager::StorageManager;
46
/// In-memory store of the fact batches produced for each derived rule,
/// keyed by rule name.
pub struct DerivedStore {
    // rule name -> all record batches derived for that rule
    relations: HashMap<String, Vec<RecordBatch>>,
}
58
impl Default for DerivedStore {
    /// Equivalent to [`DerivedStore::new`]: an empty store.
    fn default() -> Self {
        Self::new()
    }
}
64
65impl DerivedStore {
66 pub fn new() -> Self {
67 Self {
68 relations: HashMap::new(),
69 }
70 }
71
72 pub fn insert(&mut self, rule_name: String, facts: Vec<RecordBatch>) {
73 self.relations.insert(rule_name, facts);
74 }
75
76 pub fn get(&self, rule_name: &str) -> Option<&Vec<RecordBatch>> {
77 self.relations.get(rule_name)
78 }
79
80 pub fn fact_count(&self, rule_name: &str) -> usize {
81 self.relations
82 .get(rule_name)
83 .map(|batches| batches.iter().map(|b| b.num_rows()).sum())
84 .unwrap_or(0)
85 }
86
87 pub fn rule_names(&self) -> impl Iterator<Item = &str> {
88 self.relations.keys().map(|s| s.as_str())
89 }
90}
91
/// Physical plan node that evaluates an entire Locy program (all strata and
/// commands) and emits a single statistics batch. The derived facts and
/// diagnostics themselves are published through the shared `*_slot` fields,
/// which callers hold clones of.
pub struct LocyProgramExec {
    /// Stratified rule groups, evaluated in order.
    strata: Vec<LocyStratum>,
    /// Program-level commands (e.g. inline Cypher, DERIVE).
    commands: Vec<LocyCommand>,
    /// Registry of derived-relation scan entries shared with scan operators.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    /// Graph execution context used to run clause subplans.
    graph_ctx: Arc<GraphExecutionContext>,
    /// DataFusion session; read-locked to create task contexts.
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    /// Storage backend handed to subplan execution.
    storage: Arc<StorageManager>,
    /// Logical schema information used when planning inline Cypher.
    schema_info: Arc<UniSchema>,
    /// Query parameters, forwarded to every subplan.
    params: HashMap<String, Value>,
    /// Schema of the stats batch this node emits.
    output_schema: SchemaRef,
    /// Cached DataFusion plan properties.
    properties: PlanProperties,
    /// Execution metrics set (output rows etc.).
    metrics: ExecutionPlanMetricsSet,
    /// Iteration cap forwarded to fixpoint evaluation.
    max_iterations: usize,
    /// Wall-clock budget for the whole program.
    timeout: Duration,
    /// Byte budget forwarded to fixpoint evaluation.
    max_derived_bytes: usize,
    /// Force deterministic tie-breaking in BEST BY.
    deterministic_best_by: bool,
    /// Probability-domain strictness flag, forwarded to the fixpoint chain.
    strict_probability_domain: bool,
    /// Probability epsilon, forwarded to the post-fixpoint chain.
    probability_epsilon: f64,
    /// Use exact weighted model counting for probabilities.
    exact_probability: bool,
    /// Variable cap for BDD-based exact probability.
    max_bdd_variables: usize,
    /// Out-slot: filled with the [`DerivedStore`] when execution finishes.
    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
    /// Out-slot: rules whose probabilities are only approximate.
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    /// Optional provenance tracker, injected via `set_derivation_tracker`.
    derivation_tracker: Arc<StdRwLock<Option<Arc<ProvenanceStore>>>>,
    /// Out-slot: fixpoint iteration counts per rule.
    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
    /// Out-slot: peak derived-fact memory in bytes.
    peak_memory_slot: Arc<StdRwLock<usize>>,
    /// Out-slot: runtime warnings collected during evaluation.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    /// Out-slot: inline command results, tagged with the command index.
    command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
    /// Number of proofs to retain per fact for provenance.
    top_k_proofs: usize,
    /// Set to `true` when evaluation hits the timeout, so callers can
    /// report truncated results.
    timeout_flag: Arc<std::sync::atomic::AtomicBool>,
}
141
impl fmt::Debug for LocyProgramExec {
    // Manual impl: most fields (contexts, slots, registries) are not Debug,
    // so only summary counts and limits are printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("LocyProgramExec")
            .field("strata_count", &self.strata.len())
            .field("commands_count", &self.commands.len())
            .field("max_iterations", &self.max_iterations)
            .field("timeout", &self.timeout)
            .field("output_schema", &self.output_schema)
            .field("max_derived_bytes", &self.max_derived_bytes)
            .finish_non_exhaustive()
    }
}
154
impl LocyProgramExec {
    #[expect(
        clippy::too_many_arguments,
        reason = "execution plan node requires full graph and session context"
    )]
    /// Builds the plan node. All result slots start empty and are filled
    /// during `execute`; `properties` is derived from `output_schema`.
    pub fn new(
        strata: Vec<LocyStratum>,
        commands: Vec<LocyCommand>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        output_schema: SchemaRef,
        max_iterations: usize,
        timeout: Duration,
        max_derived_bytes: usize,
        deterministic_best_by: bool,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        top_k_proofs: usize,
    ) -> Self {
        let properties = compute_plan_properties(Arc::clone(&output_schema));
        Self {
            strata,
            commands,
            derived_scan_registry,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            output_schema,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
            max_iterations,
            timeout,
            max_derived_bytes,
            deterministic_best_by,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            derived_store_slot: Arc::new(StdRwLock::new(None)),
            approximate_slot: Arc::new(StdRwLock::new(HashMap::new())),
            derivation_tracker: Arc::new(StdRwLock::new(None)),
            iteration_counts_slot: Arc::new(StdRwLock::new(HashMap::new())),
            peak_memory_slot: Arc::new(StdRwLock::new(0)),
            warnings_slot: Arc::new(StdRwLock::new(Vec::new())),
            command_results_slot: Arc::new(StdRwLock::new(Vec::new())),
            top_k_proofs,
            timeout_flag: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        }
    }

    /// Shared slot that receives the [`DerivedStore`] once execution
    /// completes.
    pub fn derived_store_slot(&self) -> Arc<StdRwLock<Option<DerivedStore>>> {
        Arc::clone(&self.derived_store_slot)
    }

    /// Attaches a provenance tracker to be used during execution.
    /// A poisoned lock is silently ignored (no tracker is set).
    pub fn set_derivation_tracker(&self, tracker: Arc<ProvenanceStore>) {
        if let Ok(mut guard) = self.derivation_tracker.write() {
            *guard = Some(tracker);
        }
    }

    /// Shared slot with per-rule fixpoint iteration counts.
    pub fn iteration_counts_slot(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
        Arc::clone(&self.iteration_counts_slot)
    }

    /// Shared slot with the peak derived-fact memory usage (bytes).
    pub fn peak_memory_slot(&self) -> Arc<StdRwLock<usize>> {
        Arc::clone(&self.peak_memory_slot)
    }

    /// Shared slot collecting runtime warnings.
    pub fn warnings_slot(&self) -> Arc<StdRwLock<Vec<RuntimeWarning>>> {
        Arc::clone(&self.warnings_slot)
    }

    /// Shared slot mapping rule names to approximate-probability markers.
    pub fn approximate_slot(&self) -> Arc<StdRwLock<HashMap<String, Vec<String>>>> {
        Arc::clone(&self.approximate_slot)
    }

    /// Shared slot with inline command results `(command index, result)`.
    pub fn command_results_slot(&self) -> Arc<StdRwLock<Vec<(usize, CommandResult)>>> {
        Arc::clone(&self.command_results_slot)
    }

    /// Flag that becomes `true` when program evaluation hit its timeout.
    pub fn timeout_flag(&self) -> Arc<std::sync::atomic::AtomicBool> {
        Arc::clone(&self.timeout_flag)
    }
}
279
impl DisplayAs for LocyProgramExec {
    // One-line summary used in EXPLAIN output; the same text is produced
    // for every display format.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "LocyProgramExec: strata={}, commands={}, max_iter={}, timeout={:?}",
            self.strata.len(),
            self.commands.len(),
            self.max_iterations,
            self.timeout,
        )
    }
}
292
impl ExecutionPlan for LocyProgramExec {
    fn name(&self) -> &str {
        "LocyProgramExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    // Leaf node: all input comes from the scan registry / storage, not from
    // child execution plans.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "LocyProgramExec has no children".to_string(),
            ));
        }
        Ok(self)
    }

    /// Builds the program-evaluation future and wraps it in a
    /// [`ProgramStream`]. Nothing runs until the stream is first polled.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);

        // Snapshot every input the async future needs so it is 'static and
        // Send (mostly cheap Arc clones; strata/commands/params are owned
        // copies).
        let strata = self.strata.clone();
        let registry = Arc::clone(&self.derived_scan_registry);
        let graph_ctx = Arc::clone(&self.graph_ctx);
        let session_ctx = Arc::clone(&self.session_ctx);
        let storage = Arc::clone(&self.storage);
        let schema_info = Arc::clone(&self.schema_info);
        let params = self.params.clone();
        let output_schema = Arc::clone(&self.output_schema);
        let max_iterations = self.max_iterations;
        let timeout = self.timeout;
        let max_derived_bytes = self.max_derived_bytes;
        let deterministic_best_by = self.deterministic_best_by;
        let strict_probability_domain = self.strict_probability_domain;
        let probability_epsilon = self.probability_epsilon;
        let exact_probability = self.exact_probability;
        let max_bdd_variables = self.max_bdd_variables;
        let derived_store_slot = Arc::clone(&self.derived_store_slot);
        let approximate_slot = Arc::clone(&self.approximate_slot);
        let iteration_counts_slot = Arc::clone(&self.iteration_counts_slot);
        let peak_memory_slot = Arc::clone(&self.peak_memory_slot);
        // Tracker may have been injected via `set_derivation_tracker`; a
        // poisoned lock degrades to "no tracker".
        let derivation_tracker = self.derivation_tracker.read().ok().and_then(|g| g.clone());
        let warnings_slot = Arc::clone(&self.warnings_slot);
        let commands = self.commands.clone();
        let command_results_slot = Arc::clone(&self.command_results_slot);
        let top_k_proofs = self.top_k_proofs;
        let timeout_flag = Arc::clone(&self.timeout_flag);

        let fut = async move {
            run_program(
                strata,
                commands,
                registry,
                graph_ctx,
                session_ctx,
                storage,
                schema_info,
                params,
                output_schema,
                max_iterations,
                timeout,
                max_derived_bytes,
                deterministic_best_by,
                strict_probability_domain,
                probability_epsilon,
                exact_probability,
                max_bdd_variables,
                derived_store_slot,
                approximate_slot,
                iteration_counts_slot,
                peak_memory_slot,
                derivation_tracker,
                warnings_slot,
                command_results_slot,
                top_k_proofs,
                timeout_flag,
            )
            .await
        };

        Ok(Box::pin(ProgramStream {
            state: ProgramStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.output_schema),
            metrics,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
403
/// State machine for [`ProgramStream`]: first drive the program future to
/// completion, then emit the resulting batches one at a time.
enum ProgramStreamState {
    /// Program still running; holds the boxed evaluation future.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// Program finished; emitting the batch at the stored cursor index.
    Emitting(Vec<RecordBatch>, usize),
    /// All output delivered, or an error was already returned.
    Done,
}
413
/// Record-batch stream adapter that runs the Locy program to completion and
/// then streams out its result batches.
struct ProgramStream {
    /// Current phase of the run-then-emit state machine.
    state: ProgramStreamState,
    /// Schema reported to DataFusion consumers.
    schema: SchemaRef,
    /// Output-row metrics, updated as batches are emitted.
    metrics: BaselineMetrics,
}
419
impl Stream for ProgramStream {
    type Item = DFResult<RecordBatch>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // All fields are Unpin (the future itself is boxed), so get_mut is
        // available.
        let this = self.get_mut();
        loop {
            match &mut this.state {
                // Drive the program future. On success switch to emitting
                // its batches; on error terminate after yielding the error.
                ProgramStreamState::Running(fut) => match fut.as_mut().poll(cx) {
                    Poll::Ready(Ok(batches)) => {
                        if batches.is_empty() {
                            this.state = ProgramStreamState::Done;
                            return Poll::Ready(None);
                        }
                        this.state = ProgramStreamState::Emitting(batches, 0);
                    }
                    Poll::Ready(Err(e)) => {
                        this.state = ProgramStreamState::Done;
                        return Poll::Ready(Some(Err(e)));
                    }
                    Poll::Pending => return Poll::Pending,
                },
                // Yield one batch per poll, counting rows into the metrics.
                ProgramStreamState::Emitting(batches, idx) => {
                    if *idx >= batches.len() {
                        this.state = ProgramStreamState::Done;
                        return Poll::Ready(None);
                    }
                    let batch = batches[*idx].clone();
                    *idx += 1;
                    this.metrics.record_output(batch.num_rows());
                    return Poll::Ready(Some(Ok(batch)));
                }
                ProgramStreamState::Done => return Poll::Ready(None),
            }
        }
    }
}
456
impl RecordBatchStream for ProgramStream {
    // The schema is fixed for the lifetime of the stream.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
462
#[allow(dead_code)]
/// Evaluates a goal query directly against the in-memory derived store:
/// fetch the rule's batches, apply the optional WHERE filter row by row,
/// then apply the RETURN clause.
///
/// Rows whose WHERE expression fails to evaluate (or is not a boolean) are
/// dropped, i.e. treated as `false`.
fn execute_query_inline(
    query: &GoalQuery,
    derived_store: &DerivedStore,
    params: &HashMap<String, Value>,
) -> DFResult<Vec<FactRow>> {
    let rule_name = query.rule_name.to_string();
    // An unknown rule yields no rows rather than an error.
    let batches = derived_store.get(&rule_name).cloned().unwrap_or_default();
    let rows = super::locy_eval::record_batches_to_locy_rows(&batches);

    let filtered = if let Some(ref where_expr) = query.where_expr {
        rows.into_iter()
            .filter(|row| {
                // Evaluate WHERE with the row's bindings merged over the
                // query parameters.
                let merged = super::locy_query::merge_params(row, params);
                super::locy_eval::eval_expr(where_expr, &merged)
                    .map(|v| v.as_bool().unwrap_or(false))
                    .unwrap_or(false)
            })
            .collect()
    } else {
        rows
    };

    super::locy_query::apply_return_clause(filtered, &query.return_clause, params)
        .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
}
506
/// Plans and executes an embedded Cypher query against the graph, returning
/// the results converted into Locy fact rows.
///
/// # Errors
/// Returns an execution error when Cypher planning fails, and propagates any
/// error from subplan execution.
async fn execute_cypher_inline(
    query: &uni_cypher::ast::Query,
    schema_info: &Arc<UniSchema>,
    params: &HashMap<String, Value>,
    graph_ctx: &Arc<GraphExecutionContext>,
    session_ctx: &Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: &Arc<StorageManager>,
) -> DFResult<Vec<FactRow>> {
    let planner = crate::query::planner::QueryPlanner::new(Arc::clone(schema_info));
    let logical_plan = planner.plan(query.clone()).map_err(|e| {
        datafusion::error::DataFusionError::Execution(format!("Cypher plan error: {e}"))
    })?;
    // No extra variable bindings: the empty map is the binding environment.
    let batches = execute_subplan(
        &logical_plan,
        params,
        &HashMap::new(),
        graph_ctx,
        session_ctx,
        storage,
        schema_info,
    )
    .await?;
    Ok(super::locy_eval::record_batches_to_locy_rows(&batches))
}
532
533#[allow(dead_code)]
536fn needs_node_enrichment(query: &GoalQuery) -> bool {
537 let where_has_property = query
538 .where_expr
539 .as_ref()
540 .is_some_and(expr_has_property_access);
541 let return_has_property = query.return_clause.as_ref().is_some_and(|rc| {
542 rc.items.iter().any(|item| match item {
543 uni_cypher::ast::ReturnItem::Expr { expr, .. } => expr_has_property_access(expr),
544 uni_cypher::ast::ReturnItem::All => false,
545 })
546 });
547 where_has_property || return_has_property
548}
549
#[allow(dead_code)]
/// Recursively checks whether `expr` (or any sub-expression) contains a
/// property access (`x.prop`).
fn expr_has_property_access(expr: &Expr) -> bool {
    match expr {
        Expr::Property(..) => true,
        Expr::BinaryOp { left, right, .. } => {
            expr_has_property_access(left) || expr_has_property_access(right)
        }
        Expr::UnaryOp { expr, .. } => expr_has_property_access(expr),
        Expr::FunctionCall { args, .. } => args.iter().any(expr_has_property_access),
        Expr::List(items) => items.iter().any(expr_has_property_access),
        Expr::Map(entries) => entries.iter().any(|(_, e)| expr_has_property_access(e)),
        Expr::Case {
            expr: case_expr,
            when_then,
            else_expr,
        } => {
            case_expr
                .as_ref()
                .is_some_and(|e| expr_has_property_access(e))
                || when_then
                    .iter()
                    .any(|(w, t)| expr_has_property_access(w) || expr_has_property_access(t))
                || else_expr
                    .as_ref()
                    .is_some_and(|e| expr_has_property_access(e))
        }
        Expr::IsNull(e) | Expr::IsNotNull(e) | Expr::IsUnique(e) => expr_has_property_access(e),
        Expr::In { expr, list } => expr_has_property_access(expr) || expr_has_property_access(list),
        Expr::ArrayIndex { array, index } => {
            expr_has_property_access(array) || expr_has_property_access(index)
        }
        Expr::ArraySlice { array, start, end } => {
            expr_has_property_access(array)
                || start.as_ref().is_some_and(|e| expr_has_property_access(e))
                || end.as_ref().is_some_and(|e| expr_has_property_access(e))
        }
        Expr::Quantifier {
            list, predicate, ..
        } => expr_has_property_access(list) || expr_has_property_access(predicate),
        Expr::Reduce {
            init, list, expr, ..
        } => {
            expr_has_property_access(init)
                || expr_has_property_access(list)
                || expr_has_property_access(expr)
        }
        Expr::ListComprehension {
            list,
            where_clause,
            map_expr,
            ..
        } => {
            expr_has_property_access(list)
                || where_clause
                    .as_ref()
                    .is_some_and(|e| expr_has_property_access(e))
                || expr_has_property_access(map_expr)
        }
        Expr::PatternComprehension {
            where_clause,
            map_expr,
            ..
        } => {
            where_clause
                .as_ref()
                .is_some_and(|e| expr_has_property_access(e))
                || expr_has_property_access(map_expr)
        }
        Expr::ValidAt {
            entity, timestamp, ..
        } => expr_has_property_access(entity) || expr_has_property_access(timestamp),
        Expr::MapProjection { base, .. } => expr_has_property_access(base),
        Expr::LabelCheck { expr, .. } => expr_has_property_access(expr),
        // Leaf variants (variables, literals, parameters, ...) cannot hold a
        // property access. NOTE(review): any future composite Expr variant
        // silently falls here — keep this arm in mind when extending the AST.
        _ => false,
    }
}
629
#[expect(
    clippy::too_many_arguments,
    reason = "program evaluation requires full graph and session context"
)]
/// Evaluates a whole Locy program: runs every stratum in order (recursive
/// strata to a fixpoint via [`FixpointExec`], non-recursive ones clause by
/// clause), publishes derived facts into the scan registry and the shared
/// result slots, executes inline Cypher commands, and returns a single
/// statistics batch as the stream's visible output.
async fn run_program(
    strata: Vec<LocyStratum>,
    commands: Vec<LocyCommand>,
    registry: Arc<DerivedScanRegistry>,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    params: HashMap<String, Value>,
    output_schema: SchemaRef,
    max_iterations: usize,
    timeout: Duration,
    max_derived_bytes: usize,
    deterministic_best_by: bool,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
    peak_memory_slot: Arc<StdRwLock<usize>>,
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
    top_k_proofs: usize,
    timeout_flag: Arc<std::sync::atomic::AtomicBool>,
) -> DFResult<Vec<RecordBatch>> {
    let start = Instant::now();
    let mut derived_store = DerivedStore::new();

    for stratum in &strata {
        // Make facts derived by earlier strata visible to this stratum's
        // IS-refs before evaluating it.
        write_cross_stratum_facts(&registry, &derived_store, stratum);

        // Budget the remaining wall-clock time; abandon later strata once
        // the program-level timeout is exhausted (flag lets callers report
        // truncation).
        let remaining_timeout = timeout.saturating_sub(start.elapsed());
        if remaining_timeout.is_zero() {
            tracing::warn!("Locy program timeout exceeded during stratum evaluation");
            timeout_flag.store(true, std::sync::atomic::Ordering::Relaxed);
            break;
        }

        if stratum.is_recursive {
            // Recursive strata are evaluated to a fixpoint by FixpointExec.
            let fixpoint_rules =
                convert_to_fixpoint_plans(&stratum.rules, &registry, deterministic_best_by)?;
            let fixpoint_schema = build_fixpoint_output_schema(&stratum.rules);

            let exec = FixpointExec::new(
                fixpoint_rules,
                max_iterations,
                remaining_timeout,
                Arc::clone(&graph_ctx),
                Arc::clone(&session_ctx),
                Arc::clone(&storage),
                Arc::clone(&schema_info),
                params.clone(),
                Arc::clone(&registry),
                fixpoint_schema,
                max_derived_bytes,
                derivation_tracker.clone(),
                Arc::clone(&iteration_counts_slot),
                strict_probability_domain,
                probability_epsilon,
                exact_probability,
                max_bdd_variables,
                Arc::clone(&warnings_slot),
                Arc::clone(&approximate_slot),
                top_k_proofs,
                Arc::clone(&timeout_flag),
            );

            let task_ctx = session_ctx.read().task_ctx();
            let exec_arc: Arc<dyn ExecutionPlan> = Arc::new(exec);
            let batches = collect_all_partitions(&exec_arc, task_ctx).await?;

            for rule in &stratum.rules {
                if rule.yield_schema.is_empty() {
                    continue;
                }
                // Publish the fixpoint output to every non-self-referential
                // scan entry for this rule so later strata can read it.
                let rule_entries = registry.entries_for_rule(&rule.name);
                for entry in rule_entries {
                    if !entry.is_self_ref {
                        // NOTE(review): batches are attributed to a rule by
                        // matching field count only — two rules in one
                        // stratum with equal-width schemas would be
                        // conflated here. Confirm upstream guarantees.
                        let all_facts: Vec<RecordBatch> = batches
                            .iter()
                            .filter(|b| {
                                let rule_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
                                b.schema().fields().len() == rule_schema.fields().len()
                            })
                            .cloned()
                            .collect();
                        let mut guard = entry.data.write();
                        // Empty relations are represented by a single empty
                        // batch so readers always see the correct schema.
                        *guard = if all_facts.is_empty() {
                            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
                        } else {
                            all_facts
                        };
                    }
                }
                derived_store.insert(rule.name.clone(), batches.clone());
            }
        } else {
            // Non-recursive strata: each rule's clauses are executed once,
            // then combined through the post-fixpoint chain.
            let fixpoint_rules =
                convert_to_fixpoint_plans(&stratum.rules, &registry, deterministic_best_by)?;
            let task_ctx = session_ctx.read().task_ctx();

            for (rule, fp_rule) in stratum.rules.iter().zip(fixpoint_rules.iter()) {
                if rule.yield_schema.is_empty() {
                    continue;
                }

                // Facts per clause, tagged with the clause index for lineage.
                let mut tagged_clause_facts: Vec<(usize, Vec<RecordBatch>)> = Vec::new();
                for (clause_idx, (clause, fp_clause)) in
                    rule.clauses.iter().zip(fp_rule.clauses.iter()).enumerate()
                {
                    let mut batches = execute_subplan(
                        &clause.body,
                        &params,
                        &HashMap::new(),
                        &graph_ctx,
                        &session_ctx,
                        &storage,
                        &schema_info,
                    )
                    .await?;

                    // Apply negated IS-refs: probabilistic complement when
                    // both sides carry probabilities, plain anti-join
                    // otherwise.
                    for binding in &fp_clause.is_ref_bindings {
                        if binding.negated
                            && !binding.anti_join_cols.is_empty()
                            && let Some(entry) = registry.get(binding.derived_scan_index)
                        {
                            let neg_facts = entry.data.read().clone();
                            if !neg_facts.is_empty() {
                                if binding.target_has_prob && fp_rule.prob_column_name.is_some() {
                                    let complement_col =
                                        format!("__prob_complement_{}", binding.rule_name);
                                    if let Some(prob_col) = &binding.target_prob_col {
                                        batches =
                                            super::locy_fixpoint::apply_prob_complement_composite(
                                                batches,
                                                &neg_facts,
                                                &binding.anti_join_cols,
                                                prob_col,
                                                &complement_col,
                                            )?;
                                    } else {
                                        batches = super::locy_fixpoint::apply_anti_join_composite(
                                            batches,
                                            &neg_facts,
                                            &binding.anti_join_cols,
                                        )?;
                                    }
                                } else {
                                    batches = super::locy_fixpoint::apply_anti_join_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                    )?;
                                }
                            }
                        }
                    }

                    // Fold any "__prob_complement_*" columns produced above
                    // into the rule's probability column.
                    let complement_cols: Vec<String> = if !batches.is_empty() {
                        batches[0]
                            .schema()
                            .fields()
                            .iter()
                            .filter(|f| f.name().starts_with("__prob_complement_"))
                            .map(|f| f.name().clone())
                            .collect()
                    } else {
                        vec![]
                    };
                    if !complement_cols.is_empty() {
                        batches = super::locy_fixpoint::multiply_prob_factors(
                            batches,
                            fp_rule.prob_column_name.as_deref(),
                            &complement_cols,
                        )?;
                    }

                    tagged_clause_facts.push((clause_idx, batches));
                }

                // Record provenance and detect shared sub-derivations when a
                // tracker is attached.
                let shared_info = if let Some(ref tracker) = derivation_tracker {
                    super::locy_fixpoint::record_and_detect_lineage_nonrecursive(
                        fp_rule,
                        &tagged_clause_facts,
                        tracker,
                        &warnings_slot,
                        &registry,
                        top_k_proofs,
                    )
                } else {
                    None
                };

                let mut all_clause_facts: Vec<RecordBatch> = tagged_clause_facts
                    .into_iter()
                    .flat_map(|(_, batches)| batches)
                    .collect();

                // Exact-probability mode: rescore facts via weighted model
                // counting over the recorded derivations.
                if exact_probability
                    && let Some(ref info) = shared_info
                    && let Some(ref tracker) = derivation_tracker
                {
                    all_clause_facts = super::locy_fixpoint::apply_exact_wmc(
                        all_clause_facts,
                        fp_rule,
                        info,
                        tracker,
                        max_bdd_variables,
                        &warnings_slot,
                        &approximate_slot,
                    )?;
                }

                // Post-processing (dedup/fold/best-by/having) shared with
                // the recursive path.
                let facts = super::locy_fixpoint::apply_post_fixpoint_chain(
                    all_clause_facts,
                    fp_rule,
                    &task_ctx,
                    strict_probability_domain,
                    probability_epsilon,
                )
                .await?;

                write_facts_to_registry(&registry, &rule.name, &facts);
                derived_store.insert(rule.name.clone(), facts);
            }
        }
    }

    // Record the derived-fact memory footprint (buffer bytes across all
    // relations) for reporting.
    let peak_bytes: usize = derived_store
        .relations
        .values()
        .flat_map(|batches| batches.iter())
        .map(|b| {
            b.columns()
                .iter()
                .map(|col| col.get_buffer_memory_size())
                .sum::<usize>()
        })
        .sum();
    *peak_memory_slot.write().unwrap() = peak_bytes;

    // Execute inline Cypher commands that appear before the first DERIVE
    // command; any that appear after it are skipped here.
    let first_derive_idx = commands
        .iter()
        .position(|c| matches!(c, LocyCommand::Derive { .. }));
    let mut inline_results: Vec<(usize, CommandResult)> = Vec::new();
    for (cmd_idx, cmd) in commands.iter().enumerate() {
        if let LocyCommand::Cypher { query } = cmd {
            if first_derive_idx.is_some_and(|di| cmd_idx > di) {
                continue;
            }
            let rows = execute_cypher_inline(
                query,
                &schema_info,
                &params,
                &graph_ctx,
                &session_ctx,
                &storage,
            )
            .await?;
            inline_results.push((cmd_idx, CommandResult::Cypher(rows)));
        }
    }
    *command_results_slot.write().unwrap() = inline_results;

    // The stream's visible output is a single stats batch; the derived facts
    // themselves are handed back through `derived_store_slot`.
    let stats = vec![build_stats_batch(&derived_store, &strata, output_schema)];
    *derived_store_slot.write().unwrap() = Some(derived_store);
    Ok(stats)
}
954
955fn write_cross_stratum_facts(
961 registry: &DerivedScanRegistry,
962 derived_store: &DerivedStore,
963 stratum: &LocyStratum,
964) {
965 for rule in &stratum.rules {
967 for clause in &rule.clauses {
968 for is_ref in &clause.is_refs {
969 if let Some(facts) = derived_store.get(&is_ref.rule_name) {
972 write_facts_to_registry(registry, &is_ref.rule_name, facts);
973 }
974 }
975 }
976 }
977}
978
979fn write_facts_to_registry(registry: &DerivedScanRegistry, rule_name: &str, facts: &[RecordBatch]) {
981 let entries = registry.entries_for_rule(rule_name);
982 for entry in entries {
983 if !entry.is_self_ref {
984 let mut guard = entry.data.write();
985 *guard = if facts.is_empty() || facts.iter().all(|b| b.num_rows() == 0) {
986 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
987 } else {
988 facts
993 .iter()
994 .filter(|b| b.num_rows() > 0)
995 .map(|b| {
996 RecordBatch::try_new(Arc::clone(&entry.schema), b.columns().to_vec())
997 .unwrap_or_else(|_| b.clone())
998 })
999 .collect()
1000 };
1001 }
1002 }
1003}
1004
/// Lowers planner-level [`LocyRulePlan`]s into executable
/// [`FixpointRulePlan`]s: resolves IS-refs against the scan registry, builds
/// the Arrow yield schema (adding a hidden `__priority` column for
/// prioritized rules), and converts FOLD / BEST BY specifications.
fn convert_to_fixpoint_plans(
    rules: &[LocyRulePlan],
    registry: &DerivedScanRegistry,
    deterministic_best_by: bool,
) -> DFResult<Vec<FixpointRulePlan>> {
    rules
        .iter()
        .map(|rule| {
            let yield_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
            // Positions of the yield columns flagged as keys.
            let key_column_indices: Vec<usize> = rule
                .yield_schema
                .iter()
                .enumerate()
                .filter(|(_, yc)| yc.is_key)
                .map(|(i, _)| i)
                .collect();

            let clauses: Vec<FixpointClausePlan> = rule
                .clauses
                .iter()
                .map(|clause| {
                    let is_ref_bindings = convert_is_refs(&clause.is_refs, registry)?;
                    Ok(FixpointClausePlan {
                        body_logical: clause.body.clone(),
                        is_ref_bindings,
                        priority: clause.priority,
                        along_bindings: clause.along_bindings.clone(),
                    })
                })
                .collect::<DFResult<Vec<_>>>()?;

            let fold_bindings = convert_fold_bindings(&rule.fold_bindings, &rule.yield_schema)?;
            let best_by_criteria =
                convert_best_by_criteria(&rule.best_by_criteria, &rule.yield_schema)?;

            let has_priority = rule.priority.is_some();

            // Prioritized rules carry their priority in an extra hidden
            // nullable Int64 column appended to the yield schema.
            let yield_schema = if has_priority {
                let mut fields: Vec<Arc<Field>> = yield_schema.fields().iter().cloned().collect();
                fields.push(Arc::new(Field::new("__priority", DataType::Int64, true)));
                ArrowSchema::new(fields)
            } else {
                yield_schema
            };

            // First probability-flagged yield column, if any.
            let prob_column_name = rule
                .yield_schema
                .iter()
                .find(|yc| yc.is_prob)
                .map(|yc| yc.name.clone());

            Ok(FixpointRulePlan {
                name: rule.name.clone(),
                clauses,
                yield_schema: Arc::new(yield_schema),
                key_column_indices,
                priority: rule.priority,
                has_fold: !rule.fold_bindings.is_empty(),
                fold_bindings,
                having: rule.having.clone(),
                has_best_by: !rule.best_by_criteria.is_empty(),
                best_by_criteria,
                has_priority,
                deterministic: deterministic_best_by,
                prob_column_name,
            })
        })
        .collect()
}
1080
1081fn convert_is_refs(
1083 is_refs: &[LocyIsRef],
1084 registry: &DerivedScanRegistry,
1085) -> DFResult<Vec<IsRefBinding>> {
1086 is_refs
1087 .iter()
1088 .map(|is_ref| {
1089 let entries = registry.entries_for_rule(&is_ref.rule_name);
1090 let entry = entries
1092 .iter()
1093 .find(|e| e.is_self_ref)
1094 .or_else(|| entries.first())
1095 .ok_or_else(|| {
1096 datafusion::error::DataFusionError::Plan(format!(
1097 "No derived scan entry found for IS-ref to '{}'",
1098 is_ref.rule_name
1099 ))
1100 })?;
1101
1102 let anti_join_cols = if is_ref.negated {
1107 let mut cols: Vec<(String, String)> = is_ref
1108 .subjects
1109 .iter()
1110 .enumerate()
1111 .filter_map(|(i, s)| {
1112 if let uni_cypher::ast::Expr::Variable(var) = s {
1113 let right_col = entry
1114 .schema
1115 .fields()
1116 .get(i)
1117 .map(|f| f.name().clone())
1118 .unwrap_or_else(|| var.clone());
1119 Some((var.clone(), right_col))
1122 } else {
1123 None
1124 }
1125 })
1126 .collect();
1127 if let Some(uni_cypher::ast::Expr::Variable(target_var)) = &is_ref.target {
1132 let target_idx = is_ref.subjects.len();
1133 if let Some(field) = entry.schema.fields().get(target_idx) {
1134 cols.push((target_var.clone(), field.name().clone()));
1135 }
1136 }
1137 cols
1138 } else {
1139 Vec::new()
1140 };
1141
1142 let provenance_join_cols: Vec<(String, String)> = is_ref
1146 .subjects
1147 .iter()
1148 .enumerate()
1149 .filter_map(|(i, s)| {
1150 if let uni_cypher::ast::Expr::Variable(var) = s {
1151 let right_col = entry
1152 .schema
1153 .fields()
1154 .get(i)
1155 .map(|f| f.name().clone())
1156 .unwrap_or_else(|| var.clone());
1157 Some((var.clone(), right_col))
1158 } else {
1159 None
1160 }
1161 })
1162 .collect();
1163
1164 Ok(IsRefBinding {
1165 derived_scan_index: entry.scan_index,
1166 rule_name: is_ref.rule_name.clone(),
1167 is_self_ref: entry.is_self_ref,
1168 negated: is_ref.negated,
1169 anti_join_cols,
1170 target_has_prob: is_ref.target_has_prob,
1171 target_prob_col: is_ref.target_prob_col.clone(),
1172 provenance_join_cols,
1173 })
1174 })
1175 .collect()
1176}
1177
/// Converts `(input name, yield alias, aggregate expr)` FOLD triples into
/// executable [`FoldBinding`]s, resolving each input column by its position
/// in the yield schema.
fn convert_fold_bindings(
    fold_bindings: &[(String, String, Expr)],
    yield_schema: &[LocyYieldColumn],
) -> DFResult<Vec<FoldBinding>> {
    fold_bindings
        .iter()
        .map(|(name, yield_alias, expr)| {
            let (kind, _input_col_name) = parse_fold_aggregate(expr)?;

            // COUNT()/MCOUNT() without arguments counts rows, so no input
            // column is needed.
            if kind == FoldAggKind::CountAll {
                return Ok(FoldBinding {
                    output_name: yield_alias.clone(),
                    kind,
                    input_col_index: 0,
                    input_col_name: None,
                });
            }

            // NOTE(review): silently falls back to column 0 when neither the
            // input name nor the alias matches a yield column — confirm this
            // default is intended rather than an error condition.
            let input_col_index = yield_schema
                .iter()
                .position(|yc| yc.name == *name || yc.name == *yield_alias)
                .unwrap_or(0);
            Ok(FoldBinding {
                output_name: yield_alias.clone(),
                kind,
                input_col_index,
                input_col_name: Some(name.clone()),
            })
        })
        .collect()
}
1220
1221fn parse_fold_aggregate(expr: &Expr) -> DFResult<(FoldAggKind, String)> {
1223 match expr {
1224 Expr::FunctionCall { name, args, .. } => {
1225 let upper = name.to_uppercase();
1226 let is_count = matches!(upper.as_str(), "COUNT" | "MCOUNT");
1227
1228 if is_count && args.is_empty() {
1230 return Ok((FoldAggKind::CountAll, String::new()));
1231 }
1232
1233 let kind = match upper.as_str() {
1234 "SUM" | "MSUM" => FoldAggKind::Sum,
1235 "MAX" | "MMAX" => FoldAggKind::Max,
1236 "MIN" | "MMIN" => FoldAggKind::Min,
1237 "COUNT" | "MCOUNT" => FoldAggKind::Count,
1238 "AVG" => FoldAggKind::Avg,
1239 "COLLECT" => FoldAggKind::Collect,
1240 "MNOR" => FoldAggKind::Nor,
1241 "MPROD" => FoldAggKind::Prod,
1242 _ => {
1243 return Err(datafusion::error::DataFusionError::Plan(format!(
1244 "Unknown FOLD aggregate function: {}",
1245 name
1246 )));
1247 }
1248 };
1249 let col_name = match args.first() {
1250 Some(Expr::Variable(v)) => v.clone(),
1251 Some(Expr::Property(_, prop)) => prop.clone(),
1252 Some(other) => other.to_string_repr(),
1253 None => {
1254 return Err(datafusion::error::DataFusionError::Plan(
1255 "FOLD aggregate function requires at least one argument".to_string(),
1256 ));
1257 }
1258 };
1259 Ok((kind, col_name))
1260 }
1261 _ => Err(datafusion::error::DataFusionError::Plan(
1262 "FOLD binding must be a function call (e.g., SUM(x))".to_string(),
1263 )),
1264 }
1265}
1266
1267fn convert_best_by_criteria(
1274 criteria: &[(Expr, bool)],
1275 yield_schema: &[LocyYieldColumn],
1276) -> DFResult<Vec<SortCriterion>> {
1277 criteria
1278 .iter()
1279 .map(|(expr, ascending)| {
1280 let col_name = match expr {
1281 Expr::Property(_, prop) => prop.clone(),
1282 Expr::Variable(v) => v.clone(),
1283 _ => {
1284 return Err(datafusion::error::DataFusionError::Plan(
1285 "BEST BY criterion must be a variable or property reference".to_string(),
1286 ));
1287 }
1288 };
1289 let col_index = yield_schema
1291 .iter()
1292 .position(|yc| yc.name == col_name)
1293 .or_else(|| {
1294 let short_name = col_name.rsplit('.').next().unwrap_or(&col_name);
1295 yield_schema.iter().position(|yc| yc.name == short_name)
1296 })
1297 .ok_or_else(|| {
1298 datafusion::error::DataFusionError::Plan(format!(
1299 "BEST BY column '{}' not found in yield schema",
1300 col_name
1301 ))
1302 })?;
1303 Ok(SortCriterion {
1304 col_index,
1305 ascending: *ascending,
1306 nulls_first: false,
1307 })
1308 })
1309 .collect()
1310}
1311
1312fn yield_columns_to_arrow_schema(columns: &[LocyYieldColumn]) -> ArrowSchema {
1318 let fields: Vec<Arc<Field>> = columns
1319 .iter()
1320 .map(|yc| Arc::new(Field::new(&yc.name, yc.data_type.clone(), true)))
1321 .collect();
1322 ArrowSchema::new(fields)
1323}
1324
1325fn build_fixpoint_output_schema(rules: &[LocyRulePlan]) -> SchemaRef {
1327 if let Some(rule) = rules.first() {
1330 Arc::new(yield_columns_to_arrow_schema(&rule.yield_schema))
1331 } else {
1332 Arc::new(ArrowSchema::empty())
1333 }
1334}
1335
1336fn build_stats_batch(
1338 derived_store: &DerivedStore,
1339 _strata: &[LocyStratum],
1340 output_schema: SchemaRef,
1341) -> RecordBatch {
1342 let mut rule_names: Vec<String> = derived_store.rule_names().map(String::from).collect();
1344 rule_names.sort();
1345
1346 let name_col: arrow_array::StringArray = rule_names.iter().map(|s| Some(s.as_str())).collect();
1347 let count_col: arrow_array::Int64Array = rule_names
1348 .iter()
1349 .map(|name| Some(derived_store.fact_count(name) as i64))
1350 .collect();
1351
1352 let stats_schema = stats_schema();
1353 RecordBatch::try_new(stats_schema, vec![Arc::new(name_col), Arc::new(count_col)])
1354 .unwrap_or_else(|_| RecordBatch::new_empty(output_schema))
1355}
1356
1357pub fn stats_schema() -> SchemaRef {
1359 Arc::new(ArrowSchema::new(vec![
1360 Arc::new(Field::new("rule_name", DataType::Utf8, false)),
1361 Arc::new(Field::new("fact_count", DataType::Int64, false)),
1362 ]))
1363}
1364
1365#[cfg(test)]
1370mod tests {
1371 use super::*;
1372 use arrow_array::{Int64Array, LargeBinaryArray, StringArray};
1373
1374 #[test]
1375 fn test_derived_store_insert_and_get() {
1376 let mut store = DerivedStore::new();
1377 assert!(store.get("test").is_none());
1378
1379 let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
1380 "x",
1381 DataType::LargeBinary,
1382 true,
1383 ))]));
1384 let batch = RecordBatch::try_new(
1385 Arc::clone(&schema),
1386 vec![Arc::new(LargeBinaryArray::from(vec![
1387 Some(b"a" as &[u8]),
1388 Some(b"b"),
1389 ]))],
1390 )
1391 .unwrap();
1392
1393 store.insert("test".to_string(), vec![batch.clone()]);
1394
1395 let facts = store.get("test").unwrap();
1396 assert_eq!(facts.len(), 1);
1397 assert_eq!(facts[0].num_rows(), 2);
1398 }
1399
1400 #[test]
1401 fn test_derived_store_fact_count() {
1402 let mut store = DerivedStore::new();
1403 assert_eq!(store.fact_count("empty"), 0);
1404
1405 let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
1406 "x",
1407 DataType::LargeBinary,
1408 true,
1409 ))]));
1410 let batch1 = RecordBatch::try_new(
1411 Arc::clone(&schema),
1412 vec![Arc::new(LargeBinaryArray::from(vec![Some(b"a" as &[u8])]))],
1413 )
1414 .unwrap();
1415 let batch2 = RecordBatch::try_new(
1416 Arc::clone(&schema),
1417 vec![Arc::new(LargeBinaryArray::from(vec![
1418 Some(b"b" as &[u8]),
1419 Some(b"c"),
1420 ]))],
1421 )
1422 .unwrap();
1423
1424 store.insert("test".to_string(), vec![batch1, batch2]);
1425 assert_eq!(store.fact_count("test"), 3);
1426 }
1427
1428 #[test]
1429 fn test_stats_batch_schema() {
1430 let schema = stats_schema();
1431 assert_eq!(schema.fields().len(), 2);
1432 assert_eq!(schema.field(0).name(), "rule_name");
1433 assert_eq!(schema.field(1).name(), "fact_count");
1434 assert_eq!(schema.field(0).data_type(), &DataType::Utf8);
1435 assert_eq!(schema.field(1).data_type(), &DataType::Int64);
1436 }
1437
1438 #[test]
1439 fn test_stats_batch_content() {
1440 let mut store = DerivedStore::new();
1441 let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
1442 "x",
1443 DataType::LargeBinary,
1444 true,
1445 ))]));
1446 let batch = RecordBatch::try_new(
1447 Arc::clone(&schema),
1448 vec![Arc::new(LargeBinaryArray::from(vec![
1449 Some(b"a" as &[u8]),
1450 Some(b"b"),
1451 ]))],
1452 )
1453 .unwrap();
1454 store.insert("reach".to_string(), vec![batch]);
1455
1456 let output_schema = stats_schema();
1457 let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
1458 assert_eq!(stats.num_rows(), 1);
1459
1460 let names = stats
1461 .column(0)
1462 .as_any()
1463 .downcast_ref::<StringArray>()
1464 .unwrap();
1465 assert_eq!(names.value(0), "reach");
1466
1467 let counts = stats
1468 .column(1)
1469 .as_any()
1470 .downcast_ref::<Int64Array>()
1471 .unwrap();
1472 assert_eq!(counts.value(0), 2);
1473 }
1474
1475 #[test]
1476 fn test_yield_columns_to_arrow_schema() {
1477 let columns = vec![
1478 LocyYieldColumn {
1479 name: "a".to_string(),
1480 is_key: true,
1481 is_prob: false,
1482 data_type: DataType::UInt64,
1483 },
1484 LocyYieldColumn {
1485 name: "b".to_string(),
1486 is_key: false,
1487 is_prob: false,
1488 data_type: DataType::LargeUtf8,
1489 },
1490 LocyYieldColumn {
1491 name: "c".to_string(),
1492 is_key: true,
1493 is_prob: false,
1494 data_type: DataType::Float64,
1495 },
1496 ];
1497
1498 let schema = yield_columns_to_arrow_schema(&columns);
1499 assert_eq!(schema.fields().len(), 3);
1500 assert_eq!(schema.field(0).name(), "a");
1501 assert_eq!(schema.field(1).name(), "b");
1502 assert_eq!(schema.field(2).name(), "c");
1503 assert_eq!(schema.field(0).data_type(), &DataType::UInt64);
1505 assert_eq!(schema.field(1).data_type(), &DataType::LargeUtf8);
1506 assert_eq!(schema.field(2).data_type(), &DataType::Float64);
1507 for field in schema.fields() {
1508 assert!(field.is_nullable());
1509 }
1510 }
1511
1512 #[test]
1513 fn test_key_column_indices() {
1514 let columns = [
1515 LocyYieldColumn {
1516 name: "a".to_string(),
1517 is_key: true,
1518 is_prob: false,
1519 data_type: DataType::LargeBinary,
1520 },
1521 LocyYieldColumn {
1522 name: "b".to_string(),
1523 is_key: false,
1524 is_prob: false,
1525 data_type: DataType::LargeBinary,
1526 },
1527 LocyYieldColumn {
1528 name: "c".to_string(),
1529 is_key: true,
1530 is_prob: false,
1531 data_type: DataType::LargeBinary,
1532 },
1533 ];
1534
1535 let key_indices: Vec<usize> = columns
1536 .iter()
1537 .enumerate()
1538 .filter(|(_, yc)| yc.is_key)
1539 .map(|(i, _)| i)
1540 .collect();
1541 assert_eq!(key_indices, vec![0, 2]);
1542 }
1543
1544 #[test]
1545 fn test_parse_fold_aggregate_sum() {
1546 let expr = Expr::FunctionCall {
1547 name: "SUM".to_string(),
1548 args: vec![Expr::Variable("cost".to_string())],
1549 distinct: false,
1550 window_spec: None,
1551 };
1552 let (kind, col) = parse_fold_aggregate(&expr).unwrap();
1553 assert!(matches!(kind, FoldAggKind::Sum));
1554 assert_eq!(col, "cost");
1555 }
1556
1557 #[test]
1558 fn test_parse_fold_aggregate_monotonic() {
1559 let expr = Expr::FunctionCall {
1560 name: "MMAX".to_string(),
1561 args: vec![Expr::Variable("score".to_string())],
1562 distinct: false,
1563 window_spec: None,
1564 };
1565 let (kind, col) = parse_fold_aggregate(&expr).unwrap();
1566 assert!(matches!(kind, FoldAggKind::Max));
1567 assert_eq!(col, "score");
1568 }
1569
1570 #[test]
1571 fn test_parse_fold_aggregate_unknown() {
1572 let expr = Expr::FunctionCall {
1573 name: "UNKNOWN_AGG".to_string(),
1574 args: vec![Expr::Variable("x".to_string())],
1575 distinct: false,
1576 window_spec: None,
1577 };
1578 assert!(parse_fold_aggregate(&expr).is_err());
1579 }
1580
1581 #[test]
1582 fn test_no_commands_returns_stats() {
1583 let store = DerivedStore::new();
1584 let output_schema = stats_schema();
1585 let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
1586 assert_eq!(stats.num_rows(), 0);
1588 }
1589}