1use crate::query::df_graph::GraphExecutionContext;
12use crate::query::df_graph::common::{
13 collect_all_partitions, compute_plan_properties, execute_subplan, execute_subplan_collecting,
14};
15use crate::query::df_graph::locy_best_by::SortCriterion;
16use crate::query::df_graph::locy_explain::ProvenanceStore;
17use crate::query::df_graph::locy_fixpoint::{
18 DerivedScanRegistry, FixpointClausePlan, FixpointExec, FixpointRulePlan, IsRefBinding,
19};
20use crate::query::df_graph::locy_fold::{FoldBinding, resolve_locy_aggregate};
21use crate::query::df_graph::locy_profile::{
22 LocyExecProfile, LocyProfileCollector, LocyStratumProfile,
23};
24use crate::query::executor::core::OperatorStats;
25use crate::query::planner_locy_types::{
26 LocyCommand, LocyIsRef, LocyRulePlan, LocyStratum, LocyYieldColumn,
27};
28use arrow_array::RecordBatch;
29use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
30use datafusion::common::Result as DFResult;
31use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
32use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
33use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
34use futures::Stream;
35use parking_lot::RwLock;
36use std::any::Any;
37use std::collections::HashMap;
38use std::fmt;
39use std::pin::Pin;
40use std::sync::Arc;
41use std::sync::RwLock as StdRwLock;
42use std::task::{Context, Poll};
43use std::time::{Duration, Instant};
44use uni_common::Value;
45use uni_common::core::schema::Schema as UniSchema;
46use uni_cypher::ast::Expr;
47use uni_locy::{
48 ClassifierRegistry, CommandResult, FactRow, ModelInvocationCache, RuntimeWarning, SemiringKind,
49};
50use uni_plugin::PluginRegistry;
51use uni_store::storage::manager::StorageManager;
52
53pub struct DerivedStore {
62 relations: HashMap<String, Vec<RecordBatch>>,
63}
64
65impl Default for DerivedStore {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl DerivedStore {
72 pub fn new() -> Self {
73 Self {
74 relations: HashMap::new(),
75 }
76 }
77
78 pub fn insert(&mut self, rule_name: String, facts: Vec<RecordBatch>) {
79 self.relations.insert(rule_name, facts);
80 }
81
82 pub fn get(&self, rule_name: &str) -> Option<&Vec<RecordBatch>> {
83 self.relations.get(rule_name)
84 }
85
86 pub fn fact_count(&self, rule_name: &str) -> usize {
87 self.relations
88 .get(rule_name)
89 .map(|batches| batches.iter().map(|b| b.num_rows()).sum())
90 .unwrap_or(0)
91 }
92
93 pub fn rule_names(&self) -> impl Iterator<Item = &str> {
94 self.relations.keys().map(|s| s.as_str())
95 }
96}
97
/// DataFusion execution-plan node that evaluates a whole Locy program: every
/// stratum in order, followed by the program's commands.
///
/// Besides the record batches it emits, the node exposes several shared
/// "slots" (`Arc<StdRwLock<..>>`) that callers can read after execution to
/// obtain side results such as warnings, iteration counts and profiles.
pub struct LocyProgramExec {
    /// Strata, evaluated in order by `run_program`.
    strata: Vec<LocyStratum>,
    /// Commands, processed after all strata have been evaluated.
    commands: Vec<LocyCommand>,
    /// Derived-scan registry; rule facts are written into its entries as strata complete.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    /// Plugin registry used when converting rules to fixpoint plans.
    plugin_registry: Arc<PluginRegistry>,
    /// Graph context for executing clause subplans.
    graph_ctx: Arc<GraphExecutionContext>,
    /// DataFusion session context; `run_program` takes its task contexts from here.
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    /// Storage manager handed to subplan execution.
    storage: Arc<StorageManager>,
    /// Graph schema used for planning and subplan execution.
    schema_info: Arc<UniSchema>,
    /// Query parameters passed to clause subplans and inline Cypher commands.
    params: HashMap<String, Value>,
    /// Schema reported by this node and by its output stream.
    output_schema: SchemaRef,
    /// Plan properties computed from `output_schema` at construction time.
    properties: Arc<PlanProperties>,
    /// DataFusion metrics registered by `execute`.
    metrics: ExecutionPlanMetricsSet,
    /// Iteration cap forwarded to each recursive stratum's `FixpointExec`.
    max_iterations: usize,
    /// Wall-clock budget for the whole program; each stratum receives the remaining budget.
    timeout: Duration,
    /// Forwarded to `FixpointExec` (presumably a cap on derived-fact bytes — TODO confirm).
    max_derived_bytes: usize,
    /// Forwarded to `convert_to_fixpoint_plans` for BEST BY handling.
    deterministic_best_by: bool,
    /// Probability-domain strictness option forwarded to fixpoint/post-fixpoint evaluation.
    strict_probability_domain: bool,
    /// Probability tolerance forwarded to fixpoint/post-fixpoint evaluation.
    probability_epsilon: f64,
    /// When set, non-recursive rules with shared lineage info are passed through `apply_exact_wmc`.
    exact_probability: bool,
    /// Variable budget forwarded to `FixpointExec` and `apply_exact_wmc`.
    max_bdd_variables: usize,
    /// Semiring used for evaluation; `MaxMinProb` emits a `FuzzyNotProbabilistic`
    /// warning for rules with a PROB column and skips non-recursive lineage recording.
    semiring_kind: SemiringKind,
    /// Classifier registry forwarded to fixpoint and lineage evaluation.
    classifier_registry: Arc<ClassifierRegistry>,
    /// Optional classifier invocation cache forwarded alongside the registry.
    classifier_cache: Option<Arc<ModelInvocationCache>>,
    /// Optional neural provenance store forwarded alongside the registry.
    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
    /// Output slot for per-rule derived facts (presumably filled by `run_program` — TODO confirm).
    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
    /// Shared with `FixpointExec` / `apply_exact_wmc`; keyed by rule name
    /// (presumably lists approximated results — TODO confirm).
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    /// Provenance store installed via `set_derivation_tracker`; snapshotted by `execute`.
    derivation_tracker: Arc<StdRwLock<Option<Arc<ProvenanceStore>>>>,
    /// Per-rule iteration counts; non-recursive rules record 1.
    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
    /// Total buffer bytes of all derived batches, written after strata evaluation.
    peak_memory_slot: Arc<StdRwLock<usize>>,
    /// Runtime warnings accumulated during evaluation.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    /// `(command index, result)` pairs (presumably filled by `run_program` — TODO confirm).
    command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
    /// Proof-count limit forwarded to fixpoint, lineage and post-fixpoint evaluation.
    top_k_proofs: usize,
    /// Interruption code shared with fixpoint execution; decode with `interruption::reason`.
    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
    /// Populated when a run is interrupted, describing completed/incomplete/skipped strata.
    incomplete_slot: Arc<StdRwLock<Option<uni_common::LocyIncomplete>>>,
    /// Whether to collect a `LocyExecProfile`; read once when `execute` builds the stream.
    profile_enabled: std::sync::atomic::AtomicBool,
    /// Receives the collected profile when profiling is enabled.
    profile_slot: Arc<StdRwLock<Option<LocyExecProfile>>>,
}
175
176pub(crate) mod interruption {
183 use std::sync::atomic::{AtomicU8, Ordering};
184
185 use uni_common::LocyIncompleteReason;
186
187 pub(crate) const NONE: u8 = 0;
189 pub(crate) const TIMEOUT: u8 = 1;
191 pub(crate) const ITERATION_LIMIT: u8 = 2;
193
194 pub(crate) fn reason(flag: &AtomicU8) -> Option<LocyIncompleteReason> {
196 match flag.load(Ordering::Relaxed) {
197 TIMEOUT => Some(LocyIncompleteReason::Timeout),
198 ITERATION_LIMIT => Some(LocyIncompleteReason::IterationLimit),
199 _ => None,
200 }
201 }
202
203 pub(crate) fn set(flag: &AtomicU8, code: u8) {
207 let _ = flag.compare_exchange(NONE, code, Ordering::Relaxed, Ordering::Relaxed);
208 }
209}
210
211impl fmt::Debug for LocyProgramExec {
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 f.debug_struct("LocyProgramExec")
214 .field("strata_count", &self.strata.len())
215 .field("commands_count", &self.commands.len())
216 .field("max_iterations", &self.max_iterations)
217 .field("timeout", &self.timeout)
218 .field("output_schema", &self.output_schema)
219 .field("max_derived_bytes", &self.max_derived_bytes)
220 .finish_non_exhaustive()
221 }
222}
223
224impl LocyProgramExec {
225 #[expect(
226 clippy::too_many_arguments,
227 reason = "execution plan node requires full graph and session context"
228 )]
229 #[deprecated(
230 note = "use `new_with_semiring_classifiers_and_cache` (or the lighter \
231 `new_with_semiring_and_classifiers` / `new_with_semiring`) — \
232 this legacy ctor defaults the semiring to AddMultProb and \
233 ships no classifier registry. To be removed after C0 Stage 2."
234 )]
235 pub fn new(
236 strata: Vec<LocyStratum>,
237 commands: Vec<LocyCommand>,
238 derived_scan_registry: Arc<DerivedScanRegistry>,
239 plugin_registry: Arc<PluginRegistry>,
240 graph_ctx: Arc<GraphExecutionContext>,
241 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
242 storage: Arc<StorageManager>,
243 schema_info: Arc<UniSchema>,
244 params: HashMap<String, Value>,
245 output_schema: SchemaRef,
246 max_iterations: usize,
247 timeout: Duration,
248 max_derived_bytes: usize,
249 deterministic_best_by: bool,
250 strict_probability_domain: bool,
251 probability_epsilon: f64,
252 exact_probability: bool,
253 max_bdd_variables: usize,
254 top_k_proofs: usize,
255 ) -> Self {
256 Self::new_with_semiring_and_classifiers(
257 strata,
258 commands,
259 derived_scan_registry,
260 plugin_registry,
261 graph_ctx,
262 session_ctx,
263 storage,
264 schema_info,
265 params,
266 output_schema,
267 max_iterations,
268 timeout,
269 max_derived_bytes,
270 deterministic_best_by,
271 strict_probability_domain,
272 probability_epsilon,
273 exact_probability,
274 max_bdd_variables,
275 top_k_proofs,
276 SemiringKind::AddMultProb,
277 Arc::new(ClassifierRegistry::new()),
278 )
279 }
280
281 #[expect(
285 clippy::too_many_arguments,
286 reason = "execution plan node requires full graph and session context"
287 )]
288 pub fn new_with_semiring(
289 strata: Vec<LocyStratum>,
290 commands: Vec<LocyCommand>,
291 derived_scan_registry: Arc<DerivedScanRegistry>,
292 plugin_registry: Arc<PluginRegistry>,
293 graph_ctx: Arc<GraphExecutionContext>,
294 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
295 storage: Arc<StorageManager>,
296 schema_info: Arc<UniSchema>,
297 params: HashMap<String, Value>,
298 output_schema: SchemaRef,
299 max_iterations: usize,
300 timeout: Duration,
301 max_derived_bytes: usize,
302 deterministic_best_by: bool,
303 strict_probability_domain: bool,
304 probability_epsilon: f64,
305 exact_probability: bool,
306 max_bdd_variables: usize,
307 top_k_proofs: usize,
308 semiring_kind: SemiringKind,
309 ) -> Self {
310 Self::new_with_semiring_and_classifiers(
311 strata,
312 commands,
313 derived_scan_registry,
314 plugin_registry,
315 graph_ctx,
316 session_ctx,
317 storage,
318 schema_info,
319 params,
320 output_schema,
321 max_iterations,
322 timeout,
323 max_derived_bytes,
324 deterministic_best_by,
325 strict_probability_domain,
326 probability_epsilon,
327 exact_probability,
328 max_bdd_variables,
329 top_k_proofs,
330 semiring_kind,
331 Arc::new(ClassifierRegistry::new()),
332 )
333 }
334
335 #[expect(
338 clippy::too_many_arguments,
339 reason = "execution plan node requires full graph and session context"
340 )]
341 pub fn new_with_semiring_and_classifiers(
342 strata: Vec<LocyStratum>,
343 commands: Vec<LocyCommand>,
344 derived_scan_registry: Arc<DerivedScanRegistry>,
345 plugin_registry: Arc<PluginRegistry>,
346 graph_ctx: Arc<GraphExecutionContext>,
347 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
348 storage: Arc<StorageManager>,
349 schema_info: Arc<UniSchema>,
350 params: HashMap<String, Value>,
351 output_schema: SchemaRef,
352 max_iterations: usize,
353 timeout: Duration,
354 max_derived_bytes: usize,
355 deterministic_best_by: bool,
356 strict_probability_domain: bool,
357 probability_epsilon: f64,
358 exact_probability: bool,
359 max_bdd_variables: usize,
360 top_k_proofs: usize,
361 semiring_kind: SemiringKind,
362 classifier_registry: Arc<ClassifierRegistry>,
363 ) -> Self {
364 Self::new_with_semiring_classifiers_and_cache(
365 strata,
366 commands,
367 derived_scan_registry,
368 plugin_registry,
369 graph_ctx,
370 session_ctx,
371 storage,
372 schema_info,
373 params,
374 output_schema,
375 max_iterations,
376 timeout,
377 max_derived_bytes,
378 deterministic_best_by,
379 strict_probability_domain,
380 probability_epsilon,
381 exact_probability,
382 max_bdd_variables,
383 top_k_proofs,
384 semiring_kind,
385 classifier_registry,
386 None,
387 None,
388 )
389 }
390
391 #[expect(
396 clippy::too_many_arguments,
397 reason = "execution plan node requires full graph and session context"
398 )]
399 pub fn new_with_semiring_classifiers_and_cache(
400 strata: Vec<LocyStratum>,
401 commands: Vec<LocyCommand>,
402 derived_scan_registry: Arc<DerivedScanRegistry>,
403 plugin_registry: Arc<PluginRegistry>,
404 graph_ctx: Arc<GraphExecutionContext>,
405 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
406 storage: Arc<StorageManager>,
407 schema_info: Arc<UniSchema>,
408 params: HashMap<String, Value>,
409 output_schema: SchemaRef,
410 max_iterations: usize,
411 timeout: Duration,
412 max_derived_bytes: usize,
413 deterministic_best_by: bool,
414 strict_probability_domain: bool,
415 probability_epsilon: f64,
416 exact_probability: bool,
417 max_bdd_variables: usize,
418 top_k_proofs: usize,
419 semiring_kind: SemiringKind,
420 classifier_registry: Arc<ClassifierRegistry>,
421 classifier_cache: Option<Arc<ModelInvocationCache>>,
422 classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
423 ) -> Self {
424 let properties = compute_plan_properties(Arc::clone(&output_schema));
425 Self {
426 strata,
427 commands,
428 derived_scan_registry,
429 plugin_registry,
430 graph_ctx,
431 session_ctx,
432 storage,
433 schema_info,
434 params,
435 output_schema,
436 properties,
437 metrics: ExecutionPlanMetricsSet::new(),
438 max_iterations,
439 timeout,
440 max_derived_bytes,
441 deterministic_best_by,
442 strict_probability_domain,
443 probability_epsilon,
444 exact_probability,
445 max_bdd_variables,
446 semiring_kind,
447 classifier_registry,
448 classifier_cache,
449 classifier_provenance_store,
450 derived_store_slot: Arc::new(StdRwLock::new(None)),
451 approximate_slot: Arc::new(StdRwLock::new(HashMap::new())),
452 derivation_tracker: Arc::new(StdRwLock::new(None)),
453 iteration_counts_slot: Arc::new(StdRwLock::new(HashMap::new())),
454 peak_memory_slot: Arc::new(StdRwLock::new(0)),
455 warnings_slot: Arc::new(StdRwLock::new(Vec::new())),
456 command_results_slot: Arc::new(StdRwLock::new(Vec::new())),
457 top_k_proofs,
458 timeout_flag: Arc::new(std::sync::atomic::AtomicU8::new(interruption::NONE)),
459 incomplete_slot: Arc::new(StdRwLock::new(None)),
460 profile_enabled: std::sync::atomic::AtomicBool::new(false),
461 profile_slot: Arc::new(StdRwLock::new(None)),
462 }
463 }
464
465 pub fn set_profile_enabled(&self, enabled: bool) {
472 self.profile_enabled
473 .store(enabled, std::sync::atomic::Ordering::Relaxed);
474 }
475
476 pub fn profile_slot(&self) -> Arc<StdRwLock<Option<LocyExecProfile>>> {
479 Arc::clone(&self.profile_slot)
480 }
481
482 pub fn derived_store_slot(&self) -> Arc<StdRwLock<Option<DerivedStore>>> {
487 Arc::clone(&self.derived_store_slot)
488 }
489
490 pub fn set_derivation_tracker(&self, tracker: Arc<ProvenanceStore>) {
495 if let Ok(mut guard) = self.derivation_tracker.write() {
496 *guard = Some(tracker);
497 }
498 }
499
500 pub fn iteration_counts_slot(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
505 Arc::clone(&self.iteration_counts_slot)
506 }
507
508 pub fn peak_memory_slot(&self) -> Arc<StdRwLock<usize>> {
513 Arc::clone(&self.peak_memory_slot)
514 }
515
516 pub fn warnings_slot(&self) -> Arc<StdRwLock<Vec<RuntimeWarning>>> {
521 Arc::clone(&self.warnings_slot)
522 }
523
524 pub fn approximate_slot(&self) -> Arc<StdRwLock<HashMap<String, Vec<String>>>> {
529 Arc::clone(&self.approximate_slot)
530 }
531
532 pub fn command_results_slot(&self) -> Arc<StdRwLock<Vec<(usize, CommandResult)>>> {
537 Arc::clone(&self.command_results_slot)
538 }
539
540 pub fn timeout_flag(&self) -> Arc<std::sync::atomic::AtomicU8> {
546 Arc::clone(&self.timeout_flag)
547 }
548
549 pub fn incomplete_slot(&self) -> Arc<StdRwLock<Option<uni_common::LocyIncomplete>>> {
555 Arc::clone(&self.incomplete_slot)
556 }
557}
558
559impl DisplayAs for LocyProgramExec {
560 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
561 write!(
562 f,
563 "LocyProgramExec: strata={}, commands={}, max_iter={}, timeout={:?}",
564 self.strata.len(),
565 self.commands.len(),
566 self.max_iterations,
567 self.timeout,
568 )
569 }
570}
571
572impl ExecutionPlan for LocyProgramExec {
573 fn name(&self) -> &str {
574 "LocyProgramExec"
575 }
576
577 fn as_any(&self) -> &dyn Any {
578 self
579 }
580
581 fn schema(&self) -> SchemaRef {
582 Arc::clone(&self.output_schema)
583 }
584
585 fn properties(&self) -> &Arc<PlanProperties> {
586 &self.properties
587 }
588
589 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
590 vec![]
591 }
592
593 fn with_new_children(
594 self: Arc<Self>,
595 children: Vec<Arc<dyn ExecutionPlan>>,
596 ) -> DFResult<Arc<dyn ExecutionPlan>> {
597 if !children.is_empty() {
598 return Err(datafusion::error::DataFusionError::Plan(
599 "LocyProgramExec has no children".to_string(),
600 ));
601 }
602 Ok(self)
603 }
604
605 fn execute(
606 &self,
607 partition: usize,
608 _context: Arc<TaskContext>,
609 ) -> DFResult<SendableRecordBatchStream> {
610 let metrics = BaselineMetrics::new(&self.metrics, partition);
611
612 let strata = self.strata.clone();
613 let registry = Arc::clone(&self.derived_scan_registry);
614 let plugin_registry = Arc::clone(&self.plugin_registry);
615 let graph_ctx = Arc::clone(&self.graph_ctx);
616 let session_ctx = Arc::clone(&self.session_ctx);
617 let storage = Arc::clone(&self.storage);
618 let schema_info = Arc::clone(&self.schema_info);
619 let params = self.params.clone();
620 let output_schema = Arc::clone(&self.output_schema);
621 let max_iterations = self.max_iterations;
622 let timeout = self.timeout;
623 let max_derived_bytes = self.max_derived_bytes;
624 let deterministic_best_by = self.deterministic_best_by;
625 let strict_probability_domain = self.strict_probability_domain;
626 let probability_epsilon = self.probability_epsilon;
627 let exact_probability = self.exact_probability;
628 let max_bdd_variables = self.max_bdd_variables;
629 let derived_store_slot = Arc::clone(&self.derived_store_slot);
630 let approximate_slot = Arc::clone(&self.approximate_slot);
631 let iteration_counts_slot = Arc::clone(&self.iteration_counts_slot);
632 let peak_memory_slot = Arc::clone(&self.peak_memory_slot);
633 let derivation_tracker = self.derivation_tracker.read().ok().and_then(|g| g.clone());
634 let warnings_slot = Arc::clone(&self.warnings_slot);
635 let commands = self.commands.clone();
636 let command_results_slot = Arc::clone(&self.command_results_slot);
637 let top_k_proofs = self.top_k_proofs;
638 let timeout_flag = Arc::clone(&self.timeout_flag);
639 let incomplete_slot = Arc::clone(&self.incomplete_slot);
640 let semiring_kind = self.semiring_kind;
641 let classifier_registry = Arc::clone(&self.classifier_registry);
642 let classifier_cache = self.classifier_cache.as_ref().map(Arc::clone);
643 let classifier_provenance_store = self.classifier_provenance_store.as_ref().map(Arc::clone);
644 let profile_enabled = self
645 .profile_enabled
646 .load(std::sync::atomic::Ordering::Relaxed);
647 let profile_slot = Arc::clone(&self.profile_slot);
648
649 let fut = async move {
650 run_program(
651 strata,
652 commands,
653 registry,
654 plugin_registry,
655 graph_ctx,
656 session_ctx,
657 storage,
658 schema_info,
659 params,
660 output_schema,
661 max_iterations,
662 timeout,
663 max_derived_bytes,
664 deterministic_best_by,
665 strict_probability_domain,
666 probability_epsilon,
667 exact_probability,
668 max_bdd_variables,
669 derived_store_slot,
670 approximate_slot,
671 iteration_counts_slot,
672 peak_memory_slot,
673 derivation_tracker,
674 warnings_slot,
675 command_results_slot,
676 top_k_proofs,
677 timeout_flag,
678 incomplete_slot,
679 semiring_kind,
680 classifier_registry,
681 classifier_cache,
682 classifier_provenance_store,
683 profile_enabled,
684 profile_slot,
685 )
686 .await
687 };
688
689 Ok(Box::pin(ProgramStream {
690 state: ProgramStreamState::Running(Box::pin(fut)),
691 schema: Arc::clone(&self.output_schema),
692 metrics,
693 }))
694 }
695
696 fn metrics(&self) -> Option<MetricsSet> {
697 Some(self.metrics.clone_inner())
698 }
699}
700
/// Lifecycle of a [`ProgramStream`].
enum ProgramStreamState {
    /// The Locy program future is still pending; it resolves to all output batches at once.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// The program has finished; batches are yielded in order, starting at the stored index.
    Emitting(Vec<RecordBatch>, usize),
    /// Exhausted, or ended after yielding an error; every further poll returns `None`.
    Done,
}
710
/// Stream returned by `LocyProgramExec::execute`.
///
/// It runs the whole program to completion, then replays the resulting batches one
/// per poll.
struct ProgramStream {
    /// Current phase of the stream.
    state: ProgramStreamState,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Records compute time spent in `poll_next` and the number of rows emitted.
    metrics: BaselineMetrics,
}
716
717impl Stream for ProgramStream {
718 type Item = DFResult<RecordBatch>;
719
720 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
721 let this = self.get_mut();
722 let metrics = this.metrics.clone();
723 let _timer = metrics.elapsed_compute().timer();
724 loop {
725 match &mut this.state {
726 ProgramStreamState::Running(fut) => match fut.as_mut().poll(cx) {
727 Poll::Ready(Ok(batches)) => {
728 if batches.is_empty() {
729 this.state = ProgramStreamState::Done;
730 return Poll::Ready(None);
731 }
732 this.state = ProgramStreamState::Emitting(batches, 0);
733 }
734 Poll::Ready(Err(e)) => {
735 this.state = ProgramStreamState::Done;
736 return Poll::Ready(Some(Err(e)));
737 }
738 Poll::Pending => return Poll::Pending,
739 },
740 ProgramStreamState::Emitting(batches, idx) => {
741 if *idx >= batches.len() {
742 this.state = ProgramStreamState::Done;
743 return Poll::Ready(None);
744 }
745 let batch = batches[*idx].clone();
746 *idx += 1;
747 this.metrics.record_output(batch.num_rows());
748 return Poll::Ready(Some(Ok(batch)));
749 }
750 ProgramStreamState::Done => return Poll::Ready(None),
751 }
752 }
753 }
754}
755
756impl RecordBatchStream for ProgramStream {
757 fn schema(&self) -> SchemaRef {
758 Arc::clone(&self.schema)
759 }
760}
761
762async fn execute_cypher_inline(
768 query: &uni_cypher::ast::Query,
769 schema_info: &Arc<UniSchema>,
770 params: &HashMap<String, Value>,
771 graph_ctx: &Arc<GraphExecutionContext>,
772 session_ctx: &Arc<RwLock<datafusion::prelude::SessionContext>>,
773 storage: &Arc<StorageManager>,
774) -> DFResult<Vec<FactRow>> {
775 let planner = crate::query::planner::QueryPlanner::new(Arc::clone(schema_info));
776 let logical_plan = planner.plan(query.clone()).map_err(|e| {
777 datafusion::error::DataFusionError::Execution(format!("Cypher plan error: {e}"))
778 })?;
779 let batches = execute_subplan(
780 &logical_plan,
781 params,
782 &HashMap::new(),
783 graph_ctx,
784 session_ctx,
785 storage,
786 schema_info,
787 None, )
789 .await?;
790 Ok(super::locy_eval::record_batches_to_locy_rows(&batches))
791}
792
793#[expect(
798 clippy::too_many_arguments,
799 reason = "program evaluation requires full graph and session context"
800)]
801async fn run_program(
802 strata: Vec<LocyStratum>,
803 commands: Vec<LocyCommand>,
804 registry: Arc<DerivedScanRegistry>,
805 plugin_registry: Arc<PluginRegistry>,
806 graph_ctx: Arc<GraphExecutionContext>,
807 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
808 storage: Arc<StorageManager>,
809 schema_info: Arc<UniSchema>,
810 params: HashMap<String, Value>,
811 output_schema: SchemaRef,
812 max_iterations: usize,
813 timeout: Duration,
814 max_derived_bytes: usize,
815 deterministic_best_by: bool,
816 strict_probability_domain: bool,
817 probability_epsilon: f64,
818 exact_probability: bool,
819 max_bdd_variables: usize,
820 derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
821 approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
822 iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
823 peak_memory_slot: Arc<StdRwLock<usize>>,
824 derivation_tracker: Option<Arc<ProvenanceStore>>,
825 warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
826 command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
827 top_k_proofs: usize,
828 timeout_flag: Arc<std::sync::atomic::AtomicU8>,
829 incomplete_slot: Arc<StdRwLock<Option<uni_common::LocyIncomplete>>>,
830 semiring_kind: SemiringKind,
831 classifier_registry: Arc<ClassifierRegistry>,
832 classifier_cache: Option<Arc<ModelInvocationCache>>,
833 classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
834 profile_enabled: bool,
835 profile_slot: Arc<StdRwLock<Option<LocyExecProfile>>>,
836) -> DFResult<Vec<RecordBatch>> {
837 let start = Instant::now();
838 let mut derived_store = DerivedStore::new();
839 let mut stratum_profiles: Vec<LocyStratumProfile> = Vec::new();
841
842 if semiring_kind == SemiringKind::MaxMinProb {
847 let mut warnings = warnings_slot.write().unwrap_or_else(|e| e.into_inner());
848 let mut already: std::collections::HashSet<String> = warnings
849 .iter()
850 .filter(|w| w.code == uni_locy::RuntimeWarningCode::FuzzyNotProbabilistic)
851 .map(|w| w.rule_name.clone())
852 .collect();
853 for stratum in &strata {
854 for rule in &stratum.rules {
855 let has_prob = rule.yield_schema.iter().any(|c| c.is_prob);
856 if has_prob && !already.contains(&rule.name) {
857 warnings.push(RuntimeWarning {
858 code: uni_locy::RuntimeWarningCode::FuzzyNotProbabilistic,
859 message: format!(
860 "rule '{}' carries a PROB column but is being evaluated under \
861 the MaxMinProb (fuzzy / Viterbi) semiring; outputs are fuzzy \
862 truth values, not probabilities",
863 rule.name
864 ),
865 rule_name: rule.name.clone(),
866 variable_count: None,
867 key_group: None,
868 });
869 already.insert(rule.name.clone());
870 }
871 }
872 }
873 }
874
875 let total_strata = strata.len();
879 let mut completed_strata = 0usize;
880 let mut partial_stratum: Option<usize> = None;
881 for (stratum_idx, stratum) in strata.iter().enumerate() {
882 write_cross_stratum_facts(®istry, &derived_store, stratum);
884
885 let stratum_start = Instant::now();
888 let collector = profile_enabled.then(|| Arc::new(LocyProfileCollector::default()));
889
890 let remaining_timeout = timeout.saturating_sub(start.elapsed());
891 if remaining_timeout.is_zero() {
892 tracing::warn!("Locy program timeout exceeded during stratum evaluation");
893 interruption::set(&timeout_flag, interruption::TIMEOUT);
894 break;
895 }
896
897 if stratum.is_recursive {
898 let fixpoint_rules = convert_to_fixpoint_plans(
900 &stratum.rules,
901 ®istry,
902 &plugin_registry,
903 deterministic_best_by,
904 )?;
905 let fixpoint_schema = build_fixpoint_output_schema(&stratum.rules);
906
907 let mut exec = FixpointExec::new_with_semiring_classifiers_and_cache(
908 fixpoint_rules,
909 max_iterations,
910 remaining_timeout,
911 Arc::clone(&graph_ctx),
912 Arc::clone(&session_ctx),
913 Arc::clone(&storage),
914 Arc::clone(&schema_info),
915 params.clone(),
916 Arc::clone(®istry),
917 fixpoint_schema,
918 max_derived_bytes,
919 derivation_tracker.clone(),
920 Arc::clone(&iteration_counts_slot),
921 strict_probability_domain,
922 probability_epsilon,
923 exact_probability,
924 max_bdd_variables,
925 Arc::clone(&warnings_slot),
926 Arc::clone(&approximate_slot),
927 top_k_proofs,
928 Arc::clone(&timeout_flag),
929 semiring_kind,
930 Arc::clone(&classifier_registry),
931 classifier_cache.as_ref().map(Arc::clone),
932 classifier_provenance_store.as_ref().map(Arc::clone),
933 );
934
935 if let Some(ref c) = collector {
936 exec.set_profile_collector(Arc::clone(c));
937 }
938 let task_ctx = session_ctx.read().task_ctx();
939 let exec_arc: Arc<dyn ExecutionPlan> = Arc::new(exec);
940 let batches = collect_all_partitions(&exec_arc, task_ctx).await?;
941
942 for rule in &stratum.rules {
953 if rule.yield_schema.is_empty() {
955 continue;
956 }
957 let rule_entries = registry.entries_for_rule(&rule.name);
959 for entry in rule_entries {
960 if !entry.is_self_ref {
961 let all_facts: Vec<RecordBatch> = batches
965 .iter()
966 .filter(|b| {
967 let rule_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
969 b.schema().fields().len() == rule_schema.fields().len()
970 })
971 .cloned()
972 .collect();
973 let mut guard = entry.data.write();
974 *guard = if all_facts.is_empty() {
975 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
976 } else {
977 all_facts
978 };
979 }
980 }
981 derived_store.insert(rule.name.clone(), batches.clone());
982 }
983 } else {
984 let fixpoint_rules = convert_to_fixpoint_plans(
986 &stratum.rules,
987 ®istry,
988 &plugin_registry,
989 deterministic_best_by,
990 )?;
991 let task_ctx = session_ctx.read().task_ctx();
992
993 for (rule, fp_rule) in stratum.rules.iter().zip(fixpoint_rules.iter()) {
994 if rule.yield_schema.is_empty() {
999 continue;
1000 }
1001
1002 if let Ok(mut counts) = iteration_counts_slot.write() {
1007 counts.insert(rule.name.clone(), 1);
1008 }
1009
1010 let rule_start = Instant::now();
1014 let mut iter_ops: Vec<OperatorStats> = Vec::new();
1015 let mut tagged_clause_facts: Vec<(usize, Vec<RecordBatch>)> = Vec::new();
1016 for (clause_idx, (clause, fp_clause)) in
1017 rule.clauses.iter().zip(fp_rule.clauses.iter()).enumerate()
1018 {
1019 let mut batches = if collector.is_some() {
1023 let (b, ops) = execute_subplan_collecting(
1024 &clause.body,
1025 ¶ms,
1026 &HashMap::new(),
1027 &graph_ctx,
1028 &session_ctx,
1029 &storage,
1030 &schema_info,
1031 None, )
1033 .await?;
1034 iter_ops.extend(ops);
1035 b
1036 } else {
1037 execute_subplan(
1038 &clause.body,
1039 ¶ms,
1040 &HashMap::new(),
1041 &graph_ctx,
1042 &session_ctx,
1043 &storage,
1044 &schema_info,
1045 None, )
1047 .await?
1048 };
1049
1050 for binding in &fp_clause.is_ref_bindings {
1052 if binding.negated
1053 && !binding.anti_join_cols.is_empty()
1054 && let Some(entry) = registry.get(binding.derived_scan_index)
1055 {
1056 let neg_facts = entry.data.read().clone();
1057 if !neg_facts.is_empty() {
1058 if binding.target_has_prob && fp_rule.prob_column_name.is_some() {
1059 let complement_col =
1060 format!("__prob_complement_{}", binding.rule_name);
1061 if let Some(prob_col) = &binding.target_prob_col {
1062 batches =
1063 super::locy_fixpoint::apply_prob_complement_composite(
1064 batches,
1065 &neg_facts,
1066 &binding.anti_join_cols,
1067 prob_col,
1068 &complement_col,
1069 )?;
1070 } else {
1071 batches = super::locy_fixpoint::apply_anti_join_composite(
1073 batches,
1074 &neg_facts,
1075 &binding.anti_join_cols,
1076 )?;
1077 }
1078 } else {
1079 batches = super::locy_fixpoint::apply_anti_join_composite(
1080 batches,
1081 &neg_facts,
1082 &binding.anti_join_cols,
1083 )?;
1084 }
1085 }
1086 }
1087 }
1088
1089 let complement_cols: Vec<String> = if !batches.is_empty() {
1091 batches[0]
1092 .schema()
1093 .fields()
1094 .iter()
1095 .filter(|f| f.name().starts_with("__prob_complement_"))
1096 .map(|f| f.name().clone())
1097 .collect()
1098 } else {
1099 vec![]
1100 };
1101 if !complement_cols.is_empty() {
1102 batches = super::locy_fixpoint::multiply_prob_factors(
1103 batches,
1104 fp_rule.prob_column_name.as_deref(),
1105 &complement_cols,
1106 )?;
1107 }
1108
1109 tagged_clause_facts.push((clause_idx, batches));
1110 }
1111
1112 let shared_info = if semiring_kind == SemiringKind::MaxMinProb {
1125 None
1126 } else if let Some(ref tracker) = derivation_tracker {
1127 super::locy_fixpoint::record_and_detect_lineage_nonrecursive(
1128 fp_rule,
1129 &tagged_clause_facts,
1130 tracker,
1131 &warnings_slot,
1132 ®istry,
1133 top_k_proofs,
1134 super::locy_fixpoint::ClassifierRefs {
1135 registry: &classifier_registry,
1136 cache: classifier_cache.as_ref(),
1137 provenance_store: classifier_provenance_store.as_ref(),
1138 },
1139 semiring_kind,
1140 )
1141 .await
1142 } else {
1143 None
1144 };
1145
1146 let mut all_clause_facts: Vec<RecordBatch> = tagged_clause_facts
1148 .into_iter()
1149 .flat_map(|(_, batches)| batches)
1150 .collect();
1151
1152 if exact_probability
1154 && let Some(ref info) = shared_info
1155 && let Some(ref tracker) = derivation_tracker
1156 {
1157 all_clause_facts = super::locy_fixpoint::apply_exact_wmc(
1158 all_clause_facts,
1159 fp_rule,
1160 info,
1161 tracker,
1162 max_bdd_variables,
1163 &warnings_slot,
1164 &approximate_slot,
1165 )?;
1166 }
1167
1168 let facts = super::locy_fixpoint::apply_post_fixpoint_chain(
1170 all_clause_facts,
1171 fp_rule,
1172 &task_ctx,
1173 strict_probability_domain,
1174 probability_epsilon,
1175 semiring_kind,
1176 derivation_tracker.as_ref().map(Arc::clone),
1177 top_k_proofs,
1178 Some(Arc::clone(®istry)),
1179 )
1180 .await?;
1181
1182 if let Some(ref c) = collector {
1184 let fact_count: usize = facts.iter().map(|b| b.num_rows()).sum();
1185 c.record(
1186 &rule.name,
1187 0,
1188 fact_count,
1189 rule_start.elapsed().as_secs_f64() * 1000.0,
1190 std::mem::take(&mut iter_ops),
1191 );
1192 c.set_final_facts(&rule.name, fact_count);
1193 }
1194
1195 write_facts_to_registry(®istry, &rule.name, &facts);
1197 derived_store.insert(rule.name.clone(), facts);
1198 }
1199 }
1200
1201 if let Some(c) = collector {
1203 let rules = c.into_rules();
1204 let iterations = rules.iter().map(|r| r.iterations.len()).max().unwrap_or(0);
1205 let facts_derived: usize = rules.iter().map(|r| r.facts).sum();
1206 stratum_profiles.push(LocyStratumProfile {
1207 index: stratum_idx,
1208 recursive: stratum.is_recursive,
1209 elapsed_ms: stratum_start.elapsed().as_secs_f64() * 1000.0,
1210 iterations,
1211 facts_derived,
1212 rules,
1213 });
1214 }
1215
1216 if interruption::reason(&timeout_flag).is_some() {
1220 partial_stratum = Some(stratum_idx);
1221 break;
1222 }
1223 completed_strata += 1;
1224 }
1225
1226 if let Some(reason) = interruption::reason(&timeout_flag) {
1230 let skipped_start = match partial_stratum {
1231 Some(i) => i + 1,
1232 None => completed_strata,
1233 };
1234 let incomplete_rules: Vec<String> = partial_stratum
1235 .map(|i| strata[i].rules.iter().map(|r| r.name.clone()).collect())
1236 .unwrap_or_default();
1237 let skipped_rules: Vec<String> = strata[skipped_start..]
1238 .iter()
1239 .flat_map(|s| s.rules.iter().map(|r| r.name.clone()))
1240 .collect();
1241 let mut complement_rules_affected = Vec::new();
1242 for idx in partial_stratum
1243 .into_iter()
1244 .chain(skipped_start..total_strata)
1245 {
1246 for rule in &strata[idx].rules {
1247 if rule
1248 .clauses
1249 .iter()
1250 .any(|c| c.is_refs.iter().any(|r| r.negated))
1251 {
1252 complement_rules_affected.push(rule.name.clone());
1253 }
1254 }
1255 }
1256 if let Ok(mut slot) = incomplete_slot.write() {
1257 *slot = Some(uni_common::LocyIncomplete {
1258 reason,
1259 elapsed_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1260 limit_ms: u64::try_from(timeout.as_millis()).unwrap_or(u64::MAX),
1261 max_iterations,
1262 completed_strata,
1263 total_strata,
1264 incomplete_rules,
1265 skipped_rules,
1266 complement_rules_affected,
1267 });
1268 }
1269 }
1270
1271 let peak_bytes: usize = derived_store
1273 .relations
1274 .values()
1275 .flat_map(|batches| batches.iter())
1276 .map(|b| {
1277 b.columns()
1278 .iter()
1279 .map(|col| col.get_buffer_memory_size())
1280 .sum::<usize>()
1281 })
1282 .sum();
1283 *peak_memory_slot.write().unwrap() = peak_bytes;
1284
1285 if profile_enabled && let Ok(mut slot) = profile_slot.write() {
1287 *slot = Some(LocyExecProfile {
1288 total_elapsed_ms: start.elapsed().as_secs_f64() * 1000.0,
1289 peak_memory_bytes: peak_bytes,
1290 strata: std::mem::take(&mut stratum_profiles),
1291 });
1292 }
1293
1294 let first_derive_idx = commands
1304 .iter()
1305 .position(|c| matches!(c, LocyCommand::Derive { .. }));
1306 let mut inline_results: Vec<(usize, CommandResult)> = Vec::new();
1307 for (cmd_idx, cmd) in commands.iter().enumerate() {
1308 match cmd {
1309 LocyCommand::Cypher { query } => {
1310 if first_derive_idx.is_some_and(|di| cmd_idx > di) {
1313 continue;
1314 }
1315 let rows = execute_cypher_inline(
1316 query,
1317 &schema_info,
1318 ¶ms,
1319 &graph_ctx,
1320 &session_ctx,
1321 &storage,
1322 )
1323 .await?;
1324 inline_results.push((cmd_idx, CommandResult::Cypher(rows)));
1325 }
1326 LocyCommand::Validate { validate } => {
1327 let rule_key_cols: Vec<String> = strata
1331 .iter()
1332 .flat_map(|s| s.rules.iter())
1333 .find(|r| r.name == validate.rule_name)
1334 .map(|r| {
1335 r.yield_schema
1336 .iter()
1337 .filter(|c| c.is_key)
1338 .map(|c| c.name.clone())
1339 .collect()
1340 })
1341 .unwrap_or_default();
1342 let query =
1343 super::locy_validate::validate_collection_query(validate, &rule_key_cols);
1344 let target_rows = execute_cypher_inline(
1345 &query,
1346 &schema_info,
1347 ¶ms,
1348 &graph_ctx,
1349 &session_ctx,
1350 &storage,
1351 )
1352 .await?;
1353 let rule_facts: Vec<uni_locy::FactRow> = derived_store
1354 .get(&validate.rule_name)
1355 .map(|batches| super::locy_eval::record_batches_to_locy_rows(batches))
1356 .unwrap_or_default();
1357 let result = super::locy_validate::run_validate(
1358 validate,
1359 &rule_key_cols,
1360 &rule_facts,
1361 target_rows,
1362 )
1363 .map_err(|e| {
1364 datafusion::error::DataFusionError::Execution(format!("VALIDATE error: {e}"))
1365 })?;
1366 inline_results.push((cmd_idx, CommandResult::Validate(result)));
1367 }
1368 LocyCommand::Calibrate {
1369 calibrate,
1370 model_inputs,
1371 } => {
1372 let model_snapshot = uni_locy::CompiledModel {
1385 name: calibrate.model_name.clone(),
1386 inputs: model_inputs.clone(),
1387 features: vec![],
1388 path_context: None,
1389 output_type: uni_cypher::locy_ast::OutputType::Prob,
1390 output_name: String::new(),
1391 xervo_alias: String::new(),
1392 embedder_alias: None,
1393 calibration: None,
1394 version: None,
1395 annotations: Default::default(),
1396 };
1397 let query =
1398 super::locy_calibrate::calibrate_collection_query(calibrate, &model_snapshot);
1399 let rows = execute_cypher_inline(
1400 &query,
1401 &schema_info,
1402 ¶ms,
1403 &graph_ctx,
1404 &session_ctx,
1405 &storage,
1406 )
1407 .await?;
1408 let mut catalog = std::collections::HashMap::new();
1409 catalog.insert(calibrate.model_name.clone(), model_snapshot);
1410 let result = super::locy_calibrate::run_calibrate(
1411 calibrate,
1412 &catalog,
1413 &classifier_registry,
1414 rows,
1415 )
1416 .await
1417 .map_err(|e| {
1418 datafusion::error::DataFusionError::Execution(format!("CALIBRATE error: {e}"))
1419 })?;
1420 inline_results.push((cmd_idx, CommandResult::Calibrate(result)));
1421 }
1422 _ => {}
1423 }
1424 }
1425 *command_results_slot.write().unwrap() = inline_results;
1426
1427 let stats = vec![build_stats_batch(&derived_store, &strata, output_schema)];
1428 *derived_store_slot.write().unwrap() = Some(derived_store);
1429 Ok(stats)
1430}
1431
1432fn write_cross_stratum_facts(
1438 registry: &DerivedScanRegistry,
1439 derived_store: &DerivedStore,
1440 stratum: &LocyStratum,
1441) {
1442 for rule in &stratum.rules {
1444 for clause in &rule.clauses {
1445 for is_ref in &clause.is_refs {
1446 if let Some(facts) = derived_store.get(&is_ref.rule_name) {
1449 write_facts_to_registry(registry, &is_ref.rule_name, facts);
1450 }
1451 }
1452 }
1453 }
1454}
1455
1456fn write_facts_to_registry(registry: &DerivedScanRegistry, rule_name: &str, facts: &[RecordBatch]) {
1458 let entries = registry.entries_for_rule(rule_name);
1459 for entry in entries {
1460 if !entry.is_self_ref {
1461 let mut guard = entry.data.write();
1462 *guard = if facts.is_empty() || facts.iter().all(|b| b.num_rows() == 0) {
1463 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
1464 } else {
1465 facts
1470 .iter()
1471 .filter(|b| b.num_rows() > 0)
1472 .map(|b| {
1473 RecordBatch::try_new(Arc::clone(&entry.schema), b.columns().to_vec())
1474 .unwrap_or_else(|_| b.clone())
1475 })
1476 .collect()
1477 };
1478 }
1479 }
1480}
1481
1482fn convert_to_fixpoint_plans(
1488 rules: &[LocyRulePlan],
1489 registry: &DerivedScanRegistry,
1490 plugin_registry: &PluginRegistry,
1491 deterministic_best_by: bool,
1492) -> DFResult<Vec<FixpointRulePlan>> {
1493 let stratum_rule_names: std::collections::HashSet<&str> =
1496 rules.iter().map(|r| r.name.as_str()).collect();
1497 rules
1498 .iter()
1499 .map(|rule| {
1500 let yield_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
1501 let key_column_indices: Vec<usize> = rule
1502 .yield_schema
1503 .iter()
1504 .enumerate()
1505 .filter(|(_, yc)| yc.is_key)
1506 .map(|(i, _)| i)
1507 .collect();
1508
1509 let clauses: Vec<FixpointClausePlan> = rule
1510 .clauses
1511 .iter()
1512 .map(|clause| {
1513 let is_ref_bindings =
1514 convert_is_refs(&clause.is_refs, registry, &stratum_rule_names)?;
1515 Ok(FixpointClausePlan {
1516 body_logical: clause.body.clone(),
1517 is_ref_bindings,
1518 priority: clause.priority,
1519 along_bindings: clause.along_bindings.clone(),
1520 model_invocations: clause.model_invocations.clone(),
1521 })
1522 })
1523 .collect::<DFResult<Vec<_>>>()?;
1524
1525 let fold_bindings =
1526 convert_fold_bindings(&rule.fold_bindings, &rule.yield_schema, plugin_registry)?;
1527 let best_by_criteria =
1528 convert_best_by_criteria(&rule.best_by_criteria, &rule.yield_schema)?;
1529
1530 let has_priority = rule.priority.is_some();
1531
1532 let yield_schema = if has_priority {
1534 let mut fields: Vec<Arc<Field>> = yield_schema.fields().iter().cloned().collect();
1535 fields.push(Arc::new(Field::new("__priority", DataType::Int64, true)));
1536 ArrowSchema::new(fields)
1537 } else {
1538 yield_schema
1539 };
1540
1541 let prob_column_name = rule
1542 .yield_schema
1543 .iter()
1544 .find(|yc| yc.is_prob)
1545 .map(|yc| yc.name.clone());
1546
1547 let non_linear = rule.clauses.iter().any(|clause| {
1551 clause
1552 .is_refs
1553 .iter()
1554 .filter(|ir| !ir.negated && stratum_rule_names.contains(ir.rule_name.as_str()))
1555 .count()
1556 >= 2
1557 });
1558
1559 Ok(FixpointRulePlan {
1560 name: rule.name.clone(),
1561 clauses,
1562 yield_schema: Arc::new(yield_schema),
1563 key_column_indices,
1564 priority: rule.priority,
1565 has_fold: !rule.fold_bindings.is_empty(),
1566 fold_bindings,
1567 having: rule.having.clone(),
1568 has_best_by: !rule.best_by_criteria.is_empty(),
1569 best_by_criteria,
1570 has_priority,
1571 deterministic: deterministic_best_by,
1572 prob_column_name,
1573 non_linear,
1574 })
1575 })
1576 .collect()
1577}
1578
1579fn convert_is_refs(
1591 is_refs: &[LocyIsRef],
1592 registry: &DerivedScanRegistry,
1593 stratum_rule_names: &std::collections::HashSet<&str>,
1594) -> DFResult<Vec<IsRefBinding>> {
1595 is_refs
1596 .iter()
1597 .map(|is_ref| {
1598 let entries = registry.entries_for_rule(&is_ref.rule_name);
1599 let want_self_ref = stratum_rule_names.contains(is_ref.rule_name.as_str());
1604 let entry = entries
1605 .iter()
1606 .find(|e| e.is_self_ref == want_self_ref)
1607 .or_else(|| entries.first())
1608 .ok_or_else(|| {
1609 datafusion::error::DataFusionError::Plan(format!(
1610 "No derived scan entry found for IS-ref to '{}'",
1611 is_ref.rule_name
1612 ))
1613 })?;
1614
1615 let anti_join_cols = if is_ref.negated {
1620 let mut cols: Vec<(String, String)> = is_ref
1621 .subjects
1622 .iter()
1623 .enumerate()
1624 .filter_map(|(i, s)| {
1625 if let uni_cypher::ast::Expr::Variable(var) = s {
1626 let right_col = entry
1627 .schema
1628 .fields()
1629 .get(i)
1630 .map(|f| f.name().clone())
1631 .unwrap_or_else(|| var.clone());
1632 Some((var.clone(), right_col))
1635 } else {
1636 None
1637 }
1638 })
1639 .collect();
1640 if let Some(uni_cypher::ast::Expr::Variable(target_var)) = &is_ref.target {
1645 let target_idx = is_ref.subjects.len();
1646 if let Some(field) = entry.schema.fields().get(target_idx) {
1647 cols.push((target_var.clone(), field.name().clone()));
1648 }
1649 }
1650 cols
1651 } else {
1652 Vec::new()
1653 };
1654
1655 let provenance_join_cols: Vec<(String, String)> = is_ref
1659 .subjects
1660 .iter()
1661 .enumerate()
1662 .filter_map(|(i, s)| {
1663 if let uni_cypher::ast::Expr::Variable(var) = s {
1664 let right_col = entry
1665 .schema
1666 .fields()
1667 .get(i)
1668 .map(|f| f.name().clone())
1669 .unwrap_or_else(|| var.clone());
1670 Some((var.clone(), right_col))
1671 } else {
1672 None
1673 }
1674 })
1675 .collect();
1676
1677 Ok(IsRefBinding {
1678 derived_scan_index: entry.scan_index,
1679 rule_name: is_ref.rule_name.clone(),
1680 is_self_ref: entry.is_self_ref,
1681 negated: is_ref.negated,
1682 anti_join_cols,
1683 target_has_prob: is_ref.target_has_prob,
1684 target_prob_col: is_ref.target_prob_col.clone(),
1685 provenance_join_cols,
1686 })
1687 })
1688 .collect()
1689}
1690
1691fn convert_fold_bindings(
1699 fold_bindings: &[(String, String, Expr)],
1700 yield_schema: &[LocyYieldColumn],
1701 plugin_registry: &PluginRegistry,
1702) -> DFResult<Vec<FoldBinding>> {
1703 fold_bindings
1704 .iter()
1705 .map(|(name, yield_alias, expr)| {
1706 let (agg_name, _input_col_name) = parse_fold_aggregate(expr)?;
1707 let entry =
1708 resolve_locy_aggregate(plugin_registry, agg_name.as_str()).ok_or_else(|| {
1709 datafusion::error::DataFusionError::Plan(format!(
1710 "Unknown Locy aggregate '{agg_name}' — not registered in plugin registry"
1711 ))
1712 })?;
1713 let aggregate = Arc::clone(&entry.aggregate);
1714
1715 if agg_name.as_str() == "COUNTALL" {
1718 return Ok(FoldBinding {
1719 output_name: yield_alias.clone(),
1720 name: agg_name,
1721 aggregate,
1722 input_col_index: 0, input_col_name: None,
1724 });
1725 }
1726
1727 let input_col_index = yield_schema
1732 .iter()
1733 .position(|yc| yc.name == *name || yc.name == *yield_alias)
1734 .unwrap_or(0);
1735 Ok(FoldBinding {
1736 output_name: yield_alias.clone(),
1737 name: agg_name,
1738 aggregate,
1739 input_col_index,
1740 input_col_name: Some(name.clone()),
1741 })
1742 })
1743 .collect()
1744}
1745
1746fn parse_fold_aggregate(expr: &Expr) -> DFResult<(smol_str::SmolStr, String)> {
1752 match expr {
1753 Expr::FunctionCall { name, args, .. } => {
1754 let upper = name.to_uppercase();
1755 let is_count = matches!(upper.as_str(), "COUNT" | "MCOUNT");
1756
1757 if is_count && args.is_empty() {
1759 return Ok((smol_str::SmolStr::new_static("COUNTALL"), String::new()));
1760 }
1761
1762 let canonical = match upper.as_str() {
1763 "SUM" | "MSUM" => smol_str::SmolStr::new_static("SUM"),
1764 "MAX" | "MMAX" => smol_str::SmolStr::new_static("MAX"),
1765 "MIN" | "MMIN" => smol_str::SmolStr::new_static("MIN"),
1766 "COUNT" | "MCOUNT" => smol_str::SmolStr::new_static("COUNT"),
1767 "AVG" => smol_str::SmolStr::new_static("AVG"),
1768 "COLLECT" => smol_str::SmolStr::new_static("COLLECT"),
1769 "MNOR" => smol_str::SmolStr::new_static("MNOR"),
1770 "MPROD" => smol_str::SmolStr::new_static("MPROD"),
1771 _ => {
1772 return Err(datafusion::error::DataFusionError::Plan(format!(
1773 "Unknown FOLD aggregate function: {}",
1774 name
1775 )));
1776 }
1777 };
1778 let col_name = match args.first() {
1779 Some(Expr::Variable(v)) => v.clone(),
1780 Some(Expr::Property(_, prop)) => prop.clone(),
1781 Some(other) => other.to_string_repr(),
1782 None => {
1783 return Err(datafusion::error::DataFusionError::Plan(
1784 "FOLD aggregate function requires at least one argument".to_string(),
1785 ));
1786 }
1787 };
1788 Ok((canonical, col_name))
1789 }
1790 _ => Err(datafusion::error::DataFusionError::Plan(
1791 "FOLD binding must be a function call (e.g., SUM(x))".to_string(),
1792 )),
1793 }
1794}
1795
1796fn convert_best_by_criteria(
1803 criteria: &[(Expr, bool)],
1804 yield_schema: &[LocyYieldColumn],
1805) -> DFResult<Vec<SortCriterion>> {
1806 criteria
1807 .iter()
1808 .map(|(expr, ascending)| {
1809 let col_name = match expr {
1810 Expr::Property(_, prop) => prop.clone(),
1811 Expr::Variable(v) => v.clone(),
1812 _ => {
1813 return Err(datafusion::error::DataFusionError::Plan(
1814 "BEST BY criterion must be a variable or property reference".to_string(),
1815 ));
1816 }
1817 };
1818 let col_index = yield_schema
1820 .iter()
1821 .position(|yc| yc.name == col_name)
1822 .or_else(|| {
1823 let short_name = col_name.rsplit('.').next().unwrap_or(&col_name);
1824 yield_schema.iter().position(|yc| yc.name == short_name)
1825 })
1826 .ok_or_else(|| {
1827 datafusion::error::DataFusionError::Plan(format!(
1828 "BEST BY column '{}' not found in yield schema",
1829 col_name
1830 ))
1831 })?;
1832 Ok(SortCriterion {
1833 col_index,
1834 ascending: *ascending,
1835 nulls_first: false,
1836 })
1837 })
1838 .collect()
1839}
1840
1841fn yield_columns_to_arrow_schema(columns: &[LocyYieldColumn]) -> ArrowSchema {
1847 let fields: Vec<Arc<Field>> = columns
1848 .iter()
1849 .map(|yc| Arc::new(Field::new(&yc.name, yc.data_type.clone(), true)))
1850 .collect();
1851 ArrowSchema::new(fields)
1852}
1853
1854fn build_fixpoint_output_schema(rules: &[LocyRulePlan]) -> SchemaRef {
1856 if let Some(rule) = rules.first() {
1859 Arc::new(yield_columns_to_arrow_schema(&rule.yield_schema))
1860 } else {
1861 Arc::new(ArrowSchema::empty())
1862 }
1863}
1864
1865fn build_stats_batch(
1867 derived_store: &DerivedStore,
1868 _strata: &[LocyStratum],
1869 output_schema: SchemaRef,
1870) -> RecordBatch {
1871 let mut rule_names: Vec<String> = derived_store.rule_names().map(String::from).collect();
1873 rule_names.sort();
1874
1875 let name_col: arrow_array::StringArray = rule_names.iter().map(|s| Some(s.as_str())).collect();
1876 let count_col: arrow_array::Int64Array = rule_names
1877 .iter()
1878 .map(|name| Some(derived_store.fact_count(name) as i64))
1879 .collect();
1880
1881 let stats_schema = stats_schema();
1882 RecordBatch::try_new(stats_schema, vec![Arc::new(name_col), Arc::new(count_col)])
1883 .unwrap_or_else(|_| RecordBatch::new_empty(output_schema))
1884}
1885
1886pub fn stats_schema() -> SchemaRef {
1888 Arc::new(ArrowSchema::new(vec![
1889 Arc::new(Field::new("rule_name", DataType::Utf8, false)),
1890 Arc::new(Field::new("fact_count", DataType::Int64, false)),
1891 ]))
1892}
1893
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, LargeBinaryArray, StringArray};

    /// One-column batch (`x`: nullable `LargeBinary`) whose rows are the UTF-8
    /// bytes of `values`.
    fn binary_batch(values: &[&str]) -> RecordBatch {
        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
            "x",
            DataType::LargeBinary,
            true,
        ))]));
        let rows: Vec<Option<&[u8]>> = values.iter().map(|v| Some(v.as_bytes())).collect();
        RecordBatch::try_new(schema, vec![Arc::new(LargeBinaryArray::from(rows))]).unwrap()
    }

    /// Non-probabilistic yield column with the given key flag and type.
    fn yield_col(name: &str, is_key: bool, data_type: DataType) -> LocyYieldColumn {
        LocyYieldColumn {
            name: name.to_string(),
            is_key,
            is_prob: false,
            data_type,
        }
    }

    /// `func(arg)` call expression with a single variable argument.
    fn single_arg_call(func: &str, arg: &str) -> Expr {
        Expr::FunctionCall {
            name: func.to_string(),
            args: vec![Expr::Variable(arg.to_string())],
            distinct: false,
            window_spec: None,
        }
    }

    #[test]
    fn test_derived_store_insert_and_get() {
        let mut store = DerivedStore::new();
        assert!(store.get("test").is_none());

        store.insert("test".to_string(), vec![binary_batch(&["a", "b"])]);

        let facts = store.get("test").unwrap();
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].num_rows(), 2);
    }

    #[test]
    fn test_derived_store_fact_count() {
        let mut store = DerivedStore::new();
        assert_eq!(store.fact_count("empty"), 0);

        store.insert(
            "test".to_string(),
            vec![binary_batch(&["a"]), binary_batch(&["b", "c"])],
        );
        assert_eq!(store.fact_count("test"), 3);
    }

    #[test]
    fn test_stats_batch_schema() {
        let schema = stats_schema();
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.field(0).name(), "rule_name");
        assert_eq!(schema.field(1).name(), "fact_count");
        assert_eq!(schema.field(0).data_type(), &DataType::Utf8);
        assert_eq!(schema.field(1).data_type(), &DataType::Int64);
    }

    #[test]
    fn test_stats_batch_content() {
        let mut store = DerivedStore::new();
        store.insert("reach".to_string(), vec![binary_batch(&["a", "b"])]);

        let output_schema = stats_schema();
        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
        assert_eq!(stats.num_rows(), 1);

        let names = stats
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(names.value(0), "reach");

        let counts = stats
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(counts.value(0), 2);
    }

    #[test]
    fn test_yield_columns_to_arrow_schema() {
        let columns = vec![
            yield_col("a", true, DataType::UInt64),
            yield_col("b", false, DataType::LargeUtf8),
            yield_col("c", true, DataType::Float64),
        ];

        let schema = yield_columns_to_arrow_schema(&columns);
        assert_eq!(schema.fields().len(), 3);
        assert_eq!(schema.field(0).name(), "a");
        assert_eq!(schema.field(1).name(), "b");
        assert_eq!(schema.field(2).name(), "c");
        assert_eq!(schema.field(0).data_type(), &DataType::UInt64);
        assert_eq!(schema.field(1).data_type(), &DataType::LargeUtf8);
        assert_eq!(schema.field(2).data_type(), &DataType::Float64);
        assert!(schema.fields().iter().all(|field| field.is_nullable()));
    }

    #[test]
    fn test_key_column_indices() {
        let columns = [
            yield_col("a", true, DataType::LargeBinary),
            yield_col("b", false, DataType::LargeBinary),
            yield_col("c", true, DataType::LargeBinary),
        ];

        let key_indices: Vec<usize> = columns
            .iter()
            .enumerate()
            .filter(|(_, yc)| yc.is_key)
            .map(|(i, _)| i)
            .collect();
        assert_eq!(key_indices, vec![0, 2]);
    }

    #[test]
    fn test_parse_fold_aggregate_sum() {
        let (kind, col) = parse_fold_aggregate(&single_arg_call("SUM", "cost")).unwrap();
        assert_eq!(kind.as_str(), "SUM");
        assert_eq!(col, "cost");
    }

    #[test]
    fn test_parse_fold_aggregate_monotonic() {
        let (kind, col) = parse_fold_aggregate(&single_arg_call("MMAX", "score")).unwrap();
        assert_eq!(kind.as_str(), "MAX");
        assert_eq!(col, "score");
    }

    #[test]
    fn test_parse_fold_aggregate_unknown() {
        assert!(parse_fold_aggregate(&single_arg_call("UNKNOWN_AGG", "x")).is_err());
    }

    #[test]
    fn test_no_commands_returns_stats() {
        let store = DerivedStore::new();
        let output_schema = stats_schema();
        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
        assert_eq!(stats.num_rows(), 0);
    }
}