1use crate::query::df_graph::GraphExecutionContext;
12use crate::query::df_graph::common::{
13 collect_all_partitions, compute_plan_properties, execute_subplan,
14};
15use crate::query::df_graph::locy_best_by::SortCriterion;
16use crate::query::df_graph::locy_explain::ProvenanceStore;
17use crate::query::df_graph::locy_fixpoint::{
18 DerivedScanRegistry, FixpointClausePlan, FixpointExec, FixpointRulePlan, IsRefBinding,
19};
20use crate::query::df_graph::locy_fold::{FoldBinding, resolve_locy_aggregate};
21use crate::query::planner_locy_types::{
22 LocyCommand, LocyIsRef, LocyRulePlan, LocyStratum, LocyYieldColumn,
23};
24use arrow_array::RecordBatch;
25use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
29use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
30use futures::Stream;
31use parking_lot::RwLock;
32use std::any::Any;
33use std::collections::HashMap;
34use std::fmt;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::sync::RwLock as StdRwLock;
38use std::task::{Context, Poll};
39use std::time::{Duration, Instant};
40use uni_common::Value;
41use uni_common::core::schema::Schema as UniSchema;
42use uni_cypher::ast::Expr;
43use uni_locy::{
44 ClassifierRegistry, CommandResult, FactRow, ModelInvocationCache, RuntimeWarning, SemiringKind,
45};
46use uni_plugin::PluginRegistry;
47use uni_store::storage::manager::StorageManager;
48
/// In-memory collection of derived facts, keyed by rule name.
///
/// Each rule name maps to the list of Arrow record batches stored for that
/// rule. `run_program` in this file fills one of these while evaluating
/// strata and publishes it through `LocyProgramExec::derived_store_slot`.
pub struct DerivedStore {
    /// Rule name -> record batches holding that rule's facts.
    relations: HashMap<String, Vec<RecordBatch>>,
}
60
61impl Default for DerivedStore {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67impl DerivedStore {
68 pub fn new() -> Self {
69 Self {
70 relations: HashMap::new(),
71 }
72 }
73
74 pub fn insert(&mut self, rule_name: String, facts: Vec<RecordBatch>) {
75 self.relations.insert(rule_name, facts);
76 }
77
78 pub fn get(&self, rule_name: &str) -> Option<&Vec<RecordBatch>> {
79 self.relations.get(rule_name)
80 }
81
82 pub fn fact_count(&self, rule_name: &str) -> usize {
83 self.relations
84 .get(rule_name)
85 .map(|batches| batches.iter().map(|b| b.num_rows()).sum())
86 .unwrap_or(0)
87 }
88
89 pub fn rule_names(&self) -> impl Iterator<Item = &str> {
90 self.relations.keys().map(|s| s.as_str())
91 }
92}
93
/// Physical plan node that evaluates a whole Locy program.
///
/// The node has no child plans. When executed it hands clones of its
/// configuration and shared slots to `run_program`, which evaluates the
/// strata in order, runs the inline commands and emits a single statistics
/// batch. Side results (derived facts, warnings, iteration counts, ...) are
/// published through the `*_slot` fields, which callers can obtain via the
/// accessor methods before executing the plan.
pub struct LocyProgramExec {
    /// Strata of the program, evaluated in slice order.
    strata: Vec<LocyStratum>,
    /// Commands processed after strata evaluation.
    commands: Vec<LocyCommand>,
    /// Registry of derived-scan entries that receives the facts of each rule.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    /// Plugin registry forwarded to `convert_to_fixpoint_plans`.
    plugin_registry: Arc<PluginRegistry>,
    /// Graph execution context forwarded to sub-plan execution.
    graph_ctx: Arc<GraphExecutionContext>,
    /// DataFusion session; also the source of the task context for sub-plans.
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    /// Storage manager forwarded to sub-plan execution.
    storage: Arc<StorageManager>,
    /// Graph schema used for planning inline Cypher and executing sub-plans.
    schema_info: Arc<UniSchema>,
    /// Query parameters forwarded to every sub-plan.
    params: HashMap<String, Value>,
    /// Schema reported by this node and used for the statistics batch.
    output_schema: SchemaRef,
    /// Plan properties computed from `output_schema` at construction time.
    properties: Arc<PlanProperties>,
    /// DataFusion metrics set for this node.
    metrics: ExecutionPlanMetricsSet,
    /// Iteration limit forwarded to the fixpoint operator of recursive strata.
    max_iterations: usize,
    /// Time budget for the whole program; each stratum gets what is left.
    timeout: Duration,
    /// Derived-data size limit forwarded to the fixpoint operator.
    max_derived_bytes: usize,
    /// Forwarded to `convert_to_fixpoint_plans`.
    deterministic_best_by: bool,
    /// Forwarded to the fixpoint operator and the post-fixpoint chain.
    strict_probability_domain: bool,
    /// Forwarded to the fixpoint operator and the post-fixpoint chain.
    probability_epsilon: f64,
    /// When set (and lineage info is available), non-recursive rules go
    /// through `apply_exact_wmc`.
    exact_probability: bool,
    /// Variable limit passed to `apply_exact_wmc` and the fixpoint operator.
    max_bdd_variables: usize,
    /// Semiring under which probabilities are combined.
    semiring_kind: SemiringKind,
    /// Classifier registry used by lineage recording and `CALIBRATE`.
    classifier_registry: Arc<ClassifierRegistry>,
    /// Optional classifier invocation cache.
    classifier_cache: Option<Arc<ModelInvocationCache>>,
    /// Optional store for classifier provenance.
    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
    /// Receives the final `DerivedStore` once the program has run.
    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
    /// Shared map handed to the fixpoint operator and `apply_exact_wmc`.
    // NOTE(review): presumably rule name -> key groups that were approximated; confirm.
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    /// Optional provenance store; read once when `execute` is called.
    derivation_tracker: Arc<StdRwLock<Option<Arc<ProvenanceStore>>>>,
    /// Rule name -> iteration count (non-recursive rules are recorded as 1).
    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
    /// Receives the summed buffer size of all derived batches after evaluation.
    peak_memory_slot: Arc<StdRwLock<usize>>,
    /// Runtime warnings accumulated during evaluation.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    /// Results of inline commands, tagged with the command's index.
    command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
    /// Forwarded to lineage recording and the post-fixpoint chain.
    top_k_proofs: usize,
    /// Interruption code (see the `interruption` module); starts at `NONE`.
    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
    /// Filled with a `LocyIncomplete` report when evaluation was interrupted.
    incomplete_slot: Arc<StdRwLock<Option<uni_common::LocyIncomplete>>>,
}
163
164pub(crate) mod interruption {
171 use std::sync::atomic::{AtomicU8, Ordering};
172
173 use uni_common::LocyIncompleteReason;
174
175 pub(crate) const NONE: u8 = 0;
177 pub(crate) const TIMEOUT: u8 = 1;
179 pub(crate) const ITERATION_LIMIT: u8 = 2;
181
182 pub(crate) fn reason(flag: &AtomicU8) -> Option<LocyIncompleteReason> {
184 match flag.load(Ordering::Relaxed) {
185 TIMEOUT => Some(LocyIncompleteReason::Timeout),
186 ITERATION_LIMIT => Some(LocyIncompleteReason::IterationLimit),
187 _ => None,
188 }
189 }
190
191 pub(crate) fn set(flag: &AtomicU8, code: u8) {
195 let _ = flag.compare_exchange(NONE, code, Ordering::Relaxed, Ordering::Relaxed);
196 }
197}
198
199impl fmt::Debug for LocyProgramExec {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 f.debug_struct("LocyProgramExec")
202 .field("strata_count", &self.strata.len())
203 .field("commands_count", &self.commands.len())
204 .field("max_iterations", &self.max_iterations)
205 .field("timeout", &self.timeout)
206 .field("output_schema", &self.output_schema)
207 .field("max_derived_bytes", &self.max_derived_bytes)
208 .finish_non_exhaustive()
209 }
210}
211
212impl LocyProgramExec {
213 #[expect(
214 clippy::too_many_arguments,
215 reason = "execution plan node requires full graph and session context"
216 )]
217 #[deprecated(
218 note = "use `new_with_semiring_classifiers_and_cache` (or the lighter \
219 `new_with_semiring_and_classifiers` / `new_with_semiring`) — \
220 this legacy ctor defaults the semiring to AddMultProb and \
221 ships no classifier registry. To be removed after C0 Stage 2."
222 )]
223 pub fn new(
224 strata: Vec<LocyStratum>,
225 commands: Vec<LocyCommand>,
226 derived_scan_registry: Arc<DerivedScanRegistry>,
227 plugin_registry: Arc<PluginRegistry>,
228 graph_ctx: Arc<GraphExecutionContext>,
229 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
230 storage: Arc<StorageManager>,
231 schema_info: Arc<UniSchema>,
232 params: HashMap<String, Value>,
233 output_schema: SchemaRef,
234 max_iterations: usize,
235 timeout: Duration,
236 max_derived_bytes: usize,
237 deterministic_best_by: bool,
238 strict_probability_domain: bool,
239 probability_epsilon: f64,
240 exact_probability: bool,
241 max_bdd_variables: usize,
242 top_k_proofs: usize,
243 ) -> Self {
244 Self::new_with_semiring_and_classifiers(
245 strata,
246 commands,
247 derived_scan_registry,
248 plugin_registry,
249 graph_ctx,
250 session_ctx,
251 storage,
252 schema_info,
253 params,
254 output_schema,
255 max_iterations,
256 timeout,
257 max_derived_bytes,
258 deterministic_best_by,
259 strict_probability_domain,
260 probability_epsilon,
261 exact_probability,
262 max_bdd_variables,
263 top_k_proofs,
264 SemiringKind::AddMultProb,
265 Arc::new(ClassifierRegistry::new()),
266 )
267 }
268
269 #[expect(
273 clippy::too_many_arguments,
274 reason = "execution plan node requires full graph and session context"
275 )]
276 pub fn new_with_semiring(
277 strata: Vec<LocyStratum>,
278 commands: Vec<LocyCommand>,
279 derived_scan_registry: Arc<DerivedScanRegistry>,
280 plugin_registry: Arc<PluginRegistry>,
281 graph_ctx: Arc<GraphExecutionContext>,
282 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
283 storage: Arc<StorageManager>,
284 schema_info: Arc<UniSchema>,
285 params: HashMap<String, Value>,
286 output_schema: SchemaRef,
287 max_iterations: usize,
288 timeout: Duration,
289 max_derived_bytes: usize,
290 deterministic_best_by: bool,
291 strict_probability_domain: bool,
292 probability_epsilon: f64,
293 exact_probability: bool,
294 max_bdd_variables: usize,
295 top_k_proofs: usize,
296 semiring_kind: SemiringKind,
297 ) -> Self {
298 Self::new_with_semiring_and_classifiers(
299 strata,
300 commands,
301 derived_scan_registry,
302 plugin_registry,
303 graph_ctx,
304 session_ctx,
305 storage,
306 schema_info,
307 params,
308 output_schema,
309 max_iterations,
310 timeout,
311 max_derived_bytes,
312 deterministic_best_by,
313 strict_probability_domain,
314 probability_epsilon,
315 exact_probability,
316 max_bdd_variables,
317 top_k_proofs,
318 semiring_kind,
319 Arc::new(ClassifierRegistry::new()),
320 )
321 }
322
323 #[expect(
326 clippy::too_many_arguments,
327 reason = "execution plan node requires full graph and session context"
328 )]
329 pub fn new_with_semiring_and_classifiers(
330 strata: Vec<LocyStratum>,
331 commands: Vec<LocyCommand>,
332 derived_scan_registry: Arc<DerivedScanRegistry>,
333 plugin_registry: Arc<PluginRegistry>,
334 graph_ctx: Arc<GraphExecutionContext>,
335 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
336 storage: Arc<StorageManager>,
337 schema_info: Arc<UniSchema>,
338 params: HashMap<String, Value>,
339 output_schema: SchemaRef,
340 max_iterations: usize,
341 timeout: Duration,
342 max_derived_bytes: usize,
343 deterministic_best_by: bool,
344 strict_probability_domain: bool,
345 probability_epsilon: f64,
346 exact_probability: bool,
347 max_bdd_variables: usize,
348 top_k_proofs: usize,
349 semiring_kind: SemiringKind,
350 classifier_registry: Arc<ClassifierRegistry>,
351 ) -> Self {
352 Self::new_with_semiring_classifiers_and_cache(
353 strata,
354 commands,
355 derived_scan_registry,
356 plugin_registry,
357 graph_ctx,
358 session_ctx,
359 storage,
360 schema_info,
361 params,
362 output_schema,
363 max_iterations,
364 timeout,
365 max_derived_bytes,
366 deterministic_best_by,
367 strict_probability_domain,
368 probability_epsilon,
369 exact_probability,
370 max_bdd_variables,
371 top_k_proofs,
372 semiring_kind,
373 classifier_registry,
374 None,
375 None,
376 )
377 }
378
379 #[expect(
384 clippy::too_many_arguments,
385 reason = "execution plan node requires full graph and session context"
386 )]
387 pub fn new_with_semiring_classifiers_and_cache(
388 strata: Vec<LocyStratum>,
389 commands: Vec<LocyCommand>,
390 derived_scan_registry: Arc<DerivedScanRegistry>,
391 plugin_registry: Arc<PluginRegistry>,
392 graph_ctx: Arc<GraphExecutionContext>,
393 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
394 storage: Arc<StorageManager>,
395 schema_info: Arc<UniSchema>,
396 params: HashMap<String, Value>,
397 output_schema: SchemaRef,
398 max_iterations: usize,
399 timeout: Duration,
400 max_derived_bytes: usize,
401 deterministic_best_by: bool,
402 strict_probability_domain: bool,
403 probability_epsilon: f64,
404 exact_probability: bool,
405 max_bdd_variables: usize,
406 top_k_proofs: usize,
407 semiring_kind: SemiringKind,
408 classifier_registry: Arc<ClassifierRegistry>,
409 classifier_cache: Option<Arc<ModelInvocationCache>>,
410 classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
411 ) -> Self {
412 let properties = compute_plan_properties(Arc::clone(&output_schema));
413 Self {
414 strata,
415 commands,
416 derived_scan_registry,
417 plugin_registry,
418 graph_ctx,
419 session_ctx,
420 storage,
421 schema_info,
422 params,
423 output_schema,
424 properties,
425 metrics: ExecutionPlanMetricsSet::new(),
426 max_iterations,
427 timeout,
428 max_derived_bytes,
429 deterministic_best_by,
430 strict_probability_domain,
431 probability_epsilon,
432 exact_probability,
433 max_bdd_variables,
434 semiring_kind,
435 classifier_registry,
436 classifier_cache,
437 classifier_provenance_store,
438 derived_store_slot: Arc::new(StdRwLock::new(None)),
439 approximate_slot: Arc::new(StdRwLock::new(HashMap::new())),
440 derivation_tracker: Arc::new(StdRwLock::new(None)),
441 iteration_counts_slot: Arc::new(StdRwLock::new(HashMap::new())),
442 peak_memory_slot: Arc::new(StdRwLock::new(0)),
443 warnings_slot: Arc::new(StdRwLock::new(Vec::new())),
444 command_results_slot: Arc::new(StdRwLock::new(Vec::new())),
445 top_k_proofs,
446 timeout_flag: Arc::new(std::sync::atomic::AtomicU8::new(interruption::NONE)),
447 incomplete_slot: Arc::new(StdRwLock::new(None)),
448 }
449 }
450
451 pub fn derived_store_slot(&self) -> Arc<StdRwLock<Option<DerivedStore>>> {
456 Arc::clone(&self.derived_store_slot)
457 }
458
459 pub fn set_derivation_tracker(&self, tracker: Arc<ProvenanceStore>) {
464 if let Ok(mut guard) = self.derivation_tracker.write() {
465 *guard = Some(tracker);
466 }
467 }
468
469 pub fn iteration_counts_slot(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
474 Arc::clone(&self.iteration_counts_slot)
475 }
476
477 pub fn peak_memory_slot(&self) -> Arc<StdRwLock<usize>> {
482 Arc::clone(&self.peak_memory_slot)
483 }
484
485 pub fn warnings_slot(&self) -> Arc<StdRwLock<Vec<RuntimeWarning>>> {
490 Arc::clone(&self.warnings_slot)
491 }
492
493 pub fn approximate_slot(&self) -> Arc<StdRwLock<HashMap<String, Vec<String>>>> {
498 Arc::clone(&self.approximate_slot)
499 }
500
501 pub fn command_results_slot(&self) -> Arc<StdRwLock<Vec<(usize, CommandResult)>>> {
506 Arc::clone(&self.command_results_slot)
507 }
508
509 pub fn timeout_flag(&self) -> Arc<std::sync::atomic::AtomicU8> {
515 Arc::clone(&self.timeout_flag)
516 }
517
518 pub fn incomplete_slot(&self) -> Arc<StdRwLock<Option<uni_common::LocyIncomplete>>> {
524 Arc::clone(&self.incomplete_slot)
525 }
526}
527
528impl DisplayAs for LocyProgramExec {
529 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
530 write!(
531 f,
532 "LocyProgramExec: strata={}, commands={}, max_iter={}, timeout={:?}",
533 self.strata.len(),
534 self.commands.len(),
535 self.max_iterations,
536 self.timeout,
537 )
538 }
539}
540
541impl ExecutionPlan for LocyProgramExec {
542 fn name(&self) -> &str {
543 "LocyProgramExec"
544 }
545
546 fn as_any(&self) -> &dyn Any {
547 self
548 }
549
550 fn schema(&self) -> SchemaRef {
551 Arc::clone(&self.output_schema)
552 }
553
554 fn properties(&self) -> &Arc<PlanProperties> {
555 &self.properties
556 }
557
558 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
559 vec![]
560 }
561
562 fn with_new_children(
563 self: Arc<Self>,
564 children: Vec<Arc<dyn ExecutionPlan>>,
565 ) -> DFResult<Arc<dyn ExecutionPlan>> {
566 if !children.is_empty() {
567 return Err(datafusion::error::DataFusionError::Plan(
568 "LocyProgramExec has no children".to_string(),
569 ));
570 }
571 Ok(self)
572 }
573
574 fn execute(
575 &self,
576 partition: usize,
577 _context: Arc<TaskContext>,
578 ) -> DFResult<SendableRecordBatchStream> {
579 let metrics = BaselineMetrics::new(&self.metrics, partition);
580
581 let strata = self.strata.clone();
582 let registry = Arc::clone(&self.derived_scan_registry);
583 let plugin_registry = Arc::clone(&self.plugin_registry);
584 let graph_ctx = Arc::clone(&self.graph_ctx);
585 let session_ctx = Arc::clone(&self.session_ctx);
586 let storage = Arc::clone(&self.storage);
587 let schema_info = Arc::clone(&self.schema_info);
588 let params = self.params.clone();
589 let output_schema = Arc::clone(&self.output_schema);
590 let max_iterations = self.max_iterations;
591 let timeout = self.timeout;
592 let max_derived_bytes = self.max_derived_bytes;
593 let deterministic_best_by = self.deterministic_best_by;
594 let strict_probability_domain = self.strict_probability_domain;
595 let probability_epsilon = self.probability_epsilon;
596 let exact_probability = self.exact_probability;
597 let max_bdd_variables = self.max_bdd_variables;
598 let derived_store_slot = Arc::clone(&self.derived_store_slot);
599 let approximate_slot = Arc::clone(&self.approximate_slot);
600 let iteration_counts_slot = Arc::clone(&self.iteration_counts_slot);
601 let peak_memory_slot = Arc::clone(&self.peak_memory_slot);
602 let derivation_tracker = self.derivation_tracker.read().ok().and_then(|g| g.clone());
603 let warnings_slot = Arc::clone(&self.warnings_slot);
604 let commands = self.commands.clone();
605 let command_results_slot = Arc::clone(&self.command_results_slot);
606 let top_k_proofs = self.top_k_proofs;
607 let timeout_flag = Arc::clone(&self.timeout_flag);
608 let incomplete_slot = Arc::clone(&self.incomplete_slot);
609 let semiring_kind = self.semiring_kind;
610 let classifier_registry = Arc::clone(&self.classifier_registry);
611 let classifier_cache = self.classifier_cache.as_ref().map(Arc::clone);
612 let classifier_provenance_store = self.classifier_provenance_store.as_ref().map(Arc::clone);
613
614 let fut = async move {
615 run_program(
616 strata,
617 commands,
618 registry,
619 plugin_registry,
620 graph_ctx,
621 session_ctx,
622 storage,
623 schema_info,
624 params,
625 output_schema,
626 max_iterations,
627 timeout,
628 max_derived_bytes,
629 deterministic_best_by,
630 strict_probability_domain,
631 probability_epsilon,
632 exact_probability,
633 max_bdd_variables,
634 derived_store_slot,
635 approximate_slot,
636 iteration_counts_slot,
637 peak_memory_slot,
638 derivation_tracker,
639 warnings_slot,
640 command_results_slot,
641 top_k_proofs,
642 timeout_flag,
643 incomplete_slot,
644 semiring_kind,
645 classifier_registry,
646 classifier_cache,
647 classifier_provenance_store,
648 )
649 .await
650 };
651
652 Ok(Box::pin(ProgramStream {
653 state: ProgramStreamState::Running(Box::pin(fut)),
654 schema: Arc::clone(&self.output_schema),
655 metrics,
656 }))
657 }
658
659 fn metrics(&self) -> Option<MetricsSet> {
660 Some(self.metrics.clone_inner())
661 }
662}
663
/// State machine driving `ProgramStream`.
enum ProgramStreamState {
    /// The program future is still being polled; it resolves to all output
    /// batches at once (or an error).
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// Batches are being handed out one per poll; the `usize` is the index of
    /// the next batch to emit.
    Emitting(Vec<RecordBatch>, usize),
    /// The stream is exhausted (or failed) and yields nothing further.
    Done,
}
673
/// Record-batch stream returned by `LocyProgramExec::execute`.
struct ProgramStream {
    /// Current position in the run-then-emit state machine.
    state: ProgramStreamState,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Baseline metrics: compute time is timed on every poll and output rows
    /// are recorded for every emitted batch.
    metrics: BaselineMetrics,
}
679
680impl Stream for ProgramStream {
681 type Item = DFResult<RecordBatch>;
682
683 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
684 let this = self.get_mut();
685 let metrics = this.metrics.clone();
686 let _timer = metrics.elapsed_compute().timer();
687 loop {
688 match &mut this.state {
689 ProgramStreamState::Running(fut) => match fut.as_mut().poll(cx) {
690 Poll::Ready(Ok(batches)) => {
691 if batches.is_empty() {
692 this.state = ProgramStreamState::Done;
693 return Poll::Ready(None);
694 }
695 this.state = ProgramStreamState::Emitting(batches, 0);
696 }
697 Poll::Ready(Err(e)) => {
698 this.state = ProgramStreamState::Done;
699 return Poll::Ready(Some(Err(e)));
700 }
701 Poll::Pending => return Poll::Pending,
702 },
703 ProgramStreamState::Emitting(batches, idx) => {
704 if *idx >= batches.len() {
705 this.state = ProgramStreamState::Done;
706 return Poll::Ready(None);
707 }
708 let batch = batches[*idx].clone();
709 *idx += 1;
710 this.metrics.record_output(batch.num_rows());
711 return Poll::Ready(Some(Ok(batch)));
712 }
713 ProgramStreamState::Done => return Poll::Ready(None),
714 }
715 }
716 }
717}
718
719impl RecordBatchStream for ProgramStream {
720 fn schema(&self) -> SchemaRef {
721 Arc::clone(&self.schema)
722 }
723}
724
725async fn execute_cypher_inline(
731 query: &uni_cypher::ast::Query,
732 schema_info: &Arc<UniSchema>,
733 params: &HashMap<String, Value>,
734 graph_ctx: &Arc<GraphExecutionContext>,
735 session_ctx: &Arc<RwLock<datafusion::prelude::SessionContext>>,
736 storage: &Arc<StorageManager>,
737) -> DFResult<Vec<FactRow>> {
738 let planner = crate::query::planner::QueryPlanner::new(Arc::clone(schema_info));
739 let logical_plan = planner.plan(query.clone()).map_err(|e| {
740 datafusion::error::DataFusionError::Execution(format!("Cypher plan error: {e}"))
741 })?;
742 let batches = execute_subplan(
743 &logical_plan,
744 params,
745 &HashMap::new(),
746 graph_ctx,
747 session_ctx,
748 storage,
749 schema_info,
750 None, )
752 .await?;
753 Ok(super::locy_eval::record_batches_to_locy_rows(&batches))
754}
755
756#[expect(
761 clippy::too_many_arguments,
762 reason = "program evaluation requires full graph and session context"
763)]
764async fn run_program(
765 strata: Vec<LocyStratum>,
766 commands: Vec<LocyCommand>,
767 registry: Arc<DerivedScanRegistry>,
768 plugin_registry: Arc<PluginRegistry>,
769 graph_ctx: Arc<GraphExecutionContext>,
770 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
771 storage: Arc<StorageManager>,
772 schema_info: Arc<UniSchema>,
773 params: HashMap<String, Value>,
774 output_schema: SchemaRef,
775 max_iterations: usize,
776 timeout: Duration,
777 max_derived_bytes: usize,
778 deterministic_best_by: bool,
779 strict_probability_domain: bool,
780 probability_epsilon: f64,
781 exact_probability: bool,
782 max_bdd_variables: usize,
783 derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
784 approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
785 iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
786 peak_memory_slot: Arc<StdRwLock<usize>>,
787 derivation_tracker: Option<Arc<ProvenanceStore>>,
788 warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
789 command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
790 top_k_proofs: usize,
791 timeout_flag: Arc<std::sync::atomic::AtomicU8>,
792 incomplete_slot: Arc<StdRwLock<Option<uni_common::LocyIncomplete>>>,
793 semiring_kind: SemiringKind,
794 classifier_registry: Arc<ClassifierRegistry>,
795 classifier_cache: Option<Arc<ModelInvocationCache>>,
796 classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
797) -> DFResult<Vec<RecordBatch>> {
798 let start = Instant::now();
799 let mut derived_store = DerivedStore::new();
800
801 if semiring_kind == SemiringKind::MaxMinProb {
806 let mut warnings = warnings_slot.write().unwrap_or_else(|e| e.into_inner());
807 let mut already: std::collections::HashSet<String> = warnings
808 .iter()
809 .filter(|w| w.code == uni_locy::RuntimeWarningCode::FuzzyNotProbabilistic)
810 .map(|w| w.rule_name.clone())
811 .collect();
812 for stratum in &strata {
813 for rule in &stratum.rules {
814 let has_prob = rule.yield_schema.iter().any(|c| c.is_prob);
815 if has_prob && !already.contains(&rule.name) {
816 warnings.push(RuntimeWarning {
817 code: uni_locy::RuntimeWarningCode::FuzzyNotProbabilistic,
818 message: format!(
819 "rule '{}' carries a PROB column but is being evaluated under \
820 the MaxMinProb (fuzzy / Viterbi) semiring; outputs are fuzzy \
821 truth values, not probabilities",
822 rule.name
823 ),
824 rule_name: rule.name.clone(),
825 variable_count: None,
826 key_group: None,
827 });
828 already.insert(rule.name.clone());
829 }
830 }
831 }
832 }
833
834 let total_strata = strata.len();
838 let mut completed_strata = 0usize;
839 let mut partial_stratum: Option<usize> = None;
840 for (stratum_idx, stratum) in strata.iter().enumerate() {
841 write_cross_stratum_facts(®istry, &derived_store, stratum);
843
844 let remaining_timeout = timeout.saturating_sub(start.elapsed());
845 if remaining_timeout.is_zero() {
846 tracing::warn!("Locy program timeout exceeded during stratum evaluation");
847 interruption::set(&timeout_flag, interruption::TIMEOUT);
848 break;
849 }
850
851 if stratum.is_recursive {
852 let fixpoint_rules = convert_to_fixpoint_plans(
854 &stratum.rules,
855 ®istry,
856 &plugin_registry,
857 deterministic_best_by,
858 )?;
859 let fixpoint_schema = build_fixpoint_output_schema(&stratum.rules);
860
861 let exec = FixpointExec::new_with_semiring_classifiers_and_cache(
862 fixpoint_rules,
863 max_iterations,
864 remaining_timeout,
865 Arc::clone(&graph_ctx),
866 Arc::clone(&session_ctx),
867 Arc::clone(&storage),
868 Arc::clone(&schema_info),
869 params.clone(),
870 Arc::clone(®istry),
871 fixpoint_schema,
872 max_derived_bytes,
873 derivation_tracker.clone(),
874 Arc::clone(&iteration_counts_slot),
875 strict_probability_domain,
876 probability_epsilon,
877 exact_probability,
878 max_bdd_variables,
879 Arc::clone(&warnings_slot),
880 Arc::clone(&approximate_slot),
881 top_k_proofs,
882 Arc::clone(&timeout_flag),
883 semiring_kind,
884 Arc::clone(&classifier_registry),
885 classifier_cache.as_ref().map(Arc::clone),
886 classifier_provenance_store.as_ref().map(Arc::clone),
887 );
888
889 let task_ctx = session_ctx.read().task_ctx();
890 let exec_arc: Arc<dyn ExecutionPlan> = Arc::new(exec);
891 let batches = collect_all_partitions(&exec_arc, task_ctx).await?;
892
893 for rule in &stratum.rules {
904 if rule.yield_schema.is_empty() {
906 continue;
907 }
908 let rule_entries = registry.entries_for_rule(&rule.name);
910 for entry in rule_entries {
911 if !entry.is_self_ref {
912 let all_facts: Vec<RecordBatch> = batches
916 .iter()
917 .filter(|b| {
918 let rule_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
920 b.schema().fields().len() == rule_schema.fields().len()
921 })
922 .cloned()
923 .collect();
924 let mut guard = entry.data.write();
925 *guard = if all_facts.is_empty() {
926 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
927 } else {
928 all_facts
929 };
930 }
931 }
932 derived_store.insert(rule.name.clone(), batches.clone());
933 }
934 } else {
935 let fixpoint_rules = convert_to_fixpoint_plans(
937 &stratum.rules,
938 ®istry,
939 &plugin_registry,
940 deterministic_best_by,
941 )?;
942 let task_ctx = session_ctx.read().task_ctx();
943
944 for (rule, fp_rule) in stratum.rules.iter().zip(fixpoint_rules.iter()) {
945 if rule.yield_schema.is_empty() {
950 continue;
951 }
952
953 if let Ok(mut counts) = iteration_counts_slot.write() {
958 counts.insert(rule.name.clone(), 1);
959 }
960
961 let mut tagged_clause_facts: Vec<(usize, Vec<RecordBatch>)> = Vec::new();
963 for (clause_idx, (clause, fp_clause)) in
964 rule.clauses.iter().zip(fp_rule.clauses.iter()).enumerate()
965 {
966 let mut batches = execute_subplan(
970 &clause.body,
971 ¶ms,
972 &HashMap::new(),
973 &graph_ctx,
974 &session_ctx,
975 &storage,
976 &schema_info,
977 None, )
979 .await?;
980
981 for binding in &fp_clause.is_ref_bindings {
983 if binding.negated
984 && !binding.anti_join_cols.is_empty()
985 && let Some(entry) = registry.get(binding.derived_scan_index)
986 {
987 let neg_facts = entry.data.read().clone();
988 if !neg_facts.is_empty() {
989 if binding.target_has_prob && fp_rule.prob_column_name.is_some() {
990 let complement_col =
991 format!("__prob_complement_{}", binding.rule_name);
992 if let Some(prob_col) = &binding.target_prob_col {
993 batches =
994 super::locy_fixpoint::apply_prob_complement_composite(
995 batches,
996 &neg_facts,
997 &binding.anti_join_cols,
998 prob_col,
999 &complement_col,
1000 )?;
1001 } else {
1002 batches = super::locy_fixpoint::apply_anti_join_composite(
1004 batches,
1005 &neg_facts,
1006 &binding.anti_join_cols,
1007 )?;
1008 }
1009 } else {
1010 batches = super::locy_fixpoint::apply_anti_join_composite(
1011 batches,
1012 &neg_facts,
1013 &binding.anti_join_cols,
1014 )?;
1015 }
1016 }
1017 }
1018 }
1019
1020 let complement_cols: Vec<String> = if !batches.is_empty() {
1022 batches[0]
1023 .schema()
1024 .fields()
1025 .iter()
1026 .filter(|f| f.name().starts_with("__prob_complement_"))
1027 .map(|f| f.name().clone())
1028 .collect()
1029 } else {
1030 vec![]
1031 };
1032 if !complement_cols.is_empty() {
1033 batches = super::locy_fixpoint::multiply_prob_factors(
1034 batches,
1035 fp_rule.prob_column_name.as_deref(),
1036 &complement_cols,
1037 )?;
1038 }
1039
1040 tagged_clause_facts.push((clause_idx, batches));
1041 }
1042
1043 let shared_info = if semiring_kind == SemiringKind::MaxMinProb {
1056 None
1057 } else if let Some(ref tracker) = derivation_tracker {
1058 super::locy_fixpoint::record_and_detect_lineage_nonrecursive(
1059 fp_rule,
1060 &tagged_clause_facts,
1061 tracker,
1062 &warnings_slot,
1063 ®istry,
1064 top_k_proofs,
1065 super::locy_fixpoint::ClassifierRefs {
1066 registry: &classifier_registry,
1067 cache: classifier_cache.as_ref(),
1068 provenance_store: classifier_provenance_store.as_ref(),
1069 },
1070 semiring_kind,
1071 )
1072 .await
1073 } else {
1074 None
1075 };
1076
1077 let mut all_clause_facts: Vec<RecordBatch> = tagged_clause_facts
1079 .into_iter()
1080 .flat_map(|(_, batches)| batches)
1081 .collect();
1082
1083 if exact_probability
1085 && let Some(ref info) = shared_info
1086 && let Some(ref tracker) = derivation_tracker
1087 {
1088 all_clause_facts = super::locy_fixpoint::apply_exact_wmc(
1089 all_clause_facts,
1090 fp_rule,
1091 info,
1092 tracker,
1093 max_bdd_variables,
1094 &warnings_slot,
1095 &approximate_slot,
1096 )?;
1097 }
1098
1099 let facts = super::locy_fixpoint::apply_post_fixpoint_chain(
1101 all_clause_facts,
1102 fp_rule,
1103 &task_ctx,
1104 strict_probability_domain,
1105 probability_epsilon,
1106 semiring_kind,
1107 derivation_tracker.as_ref().map(Arc::clone),
1108 top_k_proofs,
1109 Some(Arc::clone(®istry)),
1110 )
1111 .await?;
1112
1113 write_facts_to_registry(®istry, &rule.name, &facts);
1115 derived_store.insert(rule.name.clone(), facts);
1116 }
1117 }
1118
1119 if interruption::reason(&timeout_flag).is_some() {
1123 partial_stratum = Some(stratum_idx);
1124 break;
1125 }
1126 completed_strata += 1;
1127 }
1128
1129 if let Some(reason) = interruption::reason(&timeout_flag) {
1133 let skipped_start = match partial_stratum {
1134 Some(i) => i + 1,
1135 None => completed_strata,
1136 };
1137 let incomplete_rules: Vec<String> = partial_stratum
1138 .map(|i| strata[i].rules.iter().map(|r| r.name.clone()).collect())
1139 .unwrap_or_default();
1140 let skipped_rules: Vec<String> = strata[skipped_start..]
1141 .iter()
1142 .flat_map(|s| s.rules.iter().map(|r| r.name.clone()))
1143 .collect();
1144 let mut complement_rules_affected = Vec::new();
1145 for idx in partial_stratum
1146 .into_iter()
1147 .chain(skipped_start..total_strata)
1148 {
1149 for rule in &strata[idx].rules {
1150 if rule
1151 .clauses
1152 .iter()
1153 .any(|c| c.is_refs.iter().any(|r| r.negated))
1154 {
1155 complement_rules_affected.push(rule.name.clone());
1156 }
1157 }
1158 }
1159 if let Ok(mut slot) = incomplete_slot.write() {
1160 *slot = Some(uni_common::LocyIncomplete {
1161 reason,
1162 elapsed_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1163 limit_ms: u64::try_from(timeout.as_millis()).unwrap_or(u64::MAX),
1164 max_iterations,
1165 completed_strata,
1166 total_strata,
1167 incomplete_rules,
1168 skipped_rules,
1169 complement_rules_affected,
1170 });
1171 }
1172 }
1173
1174 let peak_bytes: usize = derived_store
1176 .relations
1177 .values()
1178 .flat_map(|batches| batches.iter())
1179 .map(|b| {
1180 b.columns()
1181 .iter()
1182 .map(|col| col.get_buffer_memory_size())
1183 .sum::<usize>()
1184 })
1185 .sum();
1186 *peak_memory_slot.write().unwrap() = peak_bytes;
1187
1188 let first_derive_idx = commands
1198 .iter()
1199 .position(|c| matches!(c, LocyCommand::Derive { .. }));
1200 let mut inline_results: Vec<(usize, CommandResult)> = Vec::new();
1201 for (cmd_idx, cmd) in commands.iter().enumerate() {
1202 match cmd {
1203 LocyCommand::Cypher { query } => {
1204 if first_derive_idx.is_some_and(|di| cmd_idx > di) {
1207 continue;
1208 }
1209 let rows = execute_cypher_inline(
1210 query,
1211 &schema_info,
1212 ¶ms,
1213 &graph_ctx,
1214 &session_ctx,
1215 &storage,
1216 )
1217 .await?;
1218 inline_results.push((cmd_idx, CommandResult::Cypher(rows)));
1219 }
1220 LocyCommand::Validate { validate } => {
1221 let rule_key_cols: Vec<String> = strata
1225 .iter()
1226 .flat_map(|s| s.rules.iter())
1227 .find(|r| r.name == validate.rule_name)
1228 .map(|r| {
1229 r.yield_schema
1230 .iter()
1231 .filter(|c| c.is_key)
1232 .map(|c| c.name.clone())
1233 .collect()
1234 })
1235 .unwrap_or_default();
1236 let query =
1237 super::locy_validate::validate_collection_query(validate, &rule_key_cols);
1238 let target_rows = execute_cypher_inline(
1239 &query,
1240 &schema_info,
1241 ¶ms,
1242 &graph_ctx,
1243 &session_ctx,
1244 &storage,
1245 )
1246 .await?;
1247 let rule_facts: Vec<uni_locy::FactRow> = derived_store
1248 .get(&validate.rule_name)
1249 .map(|batches| super::locy_eval::record_batches_to_locy_rows(batches))
1250 .unwrap_or_default();
1251 let result = super::locy_validate::run_validate(
1252 validate,
1253 &rule_key_cols,
1254 &rule_facts,
1255 target_rows,
1256 )
1257 .map_err(|e| {
1258 datafusion::error::DataFusionError::Execution(format!("VALIDATE error: {e}"))
1259 })?;
1260 inline_results.push((cmd_idx, CommandResult::Validate(result)));
1261 }
1262 LocyCommand::Calibrate {
1263 calibrate,
1264 model_inputs,
1265 } => {
1266 let model_snapshot = uni_locy::CompiledModel {
1279 name: calibrate.model_name.clone(),
1280 inputs: model_inputs.clone(),
1281 features: vec![],
1282 path_context: None,
1283 output_type: uni_cypher::locy_ast::OutputType::Prob,
1284 output_name: String::new(),
1285 xervo_alias: String::new(),
1286 embedder_alias: None,
1287 calibration: None,
1288 version: None,
1289 annotations: Default::default(),
1290 };
1291 let query =
1292 super::locy_calibrate::calibrate_collection_query(calibrate, &model_snapshot);
1293 let rows = execute_cypher_inline(
1294 &query,
1295 &schema_info,
1296 ¶ms,
1297 &graph_ctx,
1298 &session_ctx,
1299 &storage,
1300 )
1301 .await?;
1302 let mut catalog = std::collections::HashMap::new();
1303 catalog.insert(calibrate.model_name.clone(), model_snapshot);
1304 let result = super::locy_calibrate::run_calibrate(
1305 calibrate,
1306 &catalog,
1307 &classifier_registry,
1308 rows,
1309 )
1310 .await
1311 .map_err(|e| {
1312 datafusion::error::DataFusionError::Execution(format!("CALIBRATE error: {e}"))
1313 })?;
1314 inline_results.push((cmd_idx, CommandResult::Calibrate(result)));
1315 }
1316 _ => {}
1317 }
1318 }
1319 *command_results_slot.write().unwrap() = inline_results;
1320
1321 let stats = vec![build_stats_batch(&derived_store, &strata, output_schema)];
1322 *derived_store_slot.write().unwrap() = Some(derived_store);
1323 Ok(stats)
1324}
1325
1326fn write_cross_stratum_facts(
1332 registry: &DerivedScanRegistry,
1333 derived_store: &DerivedStore,
1334 stratum: &LocyStratum,
1335) {
1336 for rule in &stratum.rules {
1338 for clause in &rule.clauses {
1339 for is_ref in &clause.is_refs {
1340 if let Some(facts) = derived_store.get(&is_ref.rule_name) {
1343 write_facts_to_registry(registry, &is_ref.rule_name, facts);
1344 }
1345 }
1346 }
1347 }
1348}
1349
1350fn write_facts_to_registry(registry: &DerivedScanRegistry, rule_name: &str, facts: &[RecordBatch]) {
1352 let entries = registry.entries_for_rule(rule_name);
1353 for entry in entries {
1354 if !entry.is_self_ref {
1355 let mut guard = entry.data.write();
1356 *guard = if facts.is_empty() || facts.iter().all(|b| b.num_rows() == 0) {
1357 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
1358 } else {
1359 facts
1364 .iter()
1365 .filter(|b| b.num_rows() > 0)
1366 .map(|b| {
1367 RecordBatch::try_new(Arc::clone(&entry.schema), b.columns().to_vec())
1368 .unwrap_or_else(|_| b.clone())
1369 })
1370 .collect()
1371 };
1372 }
1373 }
1374}
1375
1376fn convert_to_fixpoint_plans(
1382 rules: &[LocyRulePlan],
1383 registry: &DerivedScanRegistry,
1384 plugin_registry: &PluginRegistry,
1385 deterministic_best_by: bool,
1386) -> DFResult<Vec<FixpointRulePlan>> {
1387 let stratum_rule_names: std::collections::HashSet<&str> =
1390 rules.iter().map(|r| r.name.as_str()).collect();
1391 rules
1392 .iter()
1393 .map(|rule| {
1394 let yield_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
1395 let key_column_indices: Vec<usize> = rule
1396 .yield_schema
1397 .iter()
1398 .enumerate()
1399 .filter(|(_, yc)| yc.is_key)
1400 .map(|(i, _)| i)
1401 .collect();
1402
1403 let clauses: Vec<FixpointClausePlan> = rule
1404 .clauses
1405 .iter()
1406 .map(|clause| {
1407 let is_ref_bindings =
1408 convert_is_refs(&clause.is_refs, registry, &stratum_rule_names)?;
1409 Ok(FixpointClausePlan {
1410 body_logical: clause.body.clone(),
1411 is_ref_bindings,
1412 priority: clause.priority,
1413 along_bindings: clause.along_bindings.clone(),
1414 model_invocations: clause.model_invocations.clone(),
1415 })
1416 })
1417 .collect::<DFResult<Vec<_>>>()?;
1418
1419 let fold_bindings =
1420 convert_fold_bindings(&rule.fold_bindings, &rule.yield_schema, plugin_registry)?;
1421 let best_by_criteria =
1422 convert_best_by_criteria(&rule.best_by_criteria, &rule.yield_schema)?;
1423
1424 let has_priority = rule.priority.is_some();
1425
1426 let yield_schema = if has_priority {
1428 let mut fields: Vec<Arc<Field>> = yield_schema.fields().iter().cloned().collect();
1429 fields.push(Arc::new(Field::new("__priority", DataType::Int64, true)));
1430 ArrowSchema::new(fields)
1431 } else {
1432 yield_schema
1433 };
1434
1435 let prob_column_name = rule
1436 .yield_schema
1437 .iter()
1438 .find(|yc| yc.is_prob)
1439 .map(|yc| yc.name.clone());
1440
1441 let non_linear = rule.clauses.iter().any(|clause| {
1445 clause
1446 .is_refs
1447 .iter()
1448 .filter(|ir| !ir.negated && stratum_rule_names.contains(ir.rule_name.as_str()))
1449 .count()
1450 >= 2
1451 });
1452
1453 Ok(FixpointRulePlan {
1454 name: rule.name.clone(),
1455 clauses,
1456 yield_schema: Arc::new(yield_schema),
1457 key_column_indices,
1458 priority: rule.priority,
1459 has_fold: !rule.fold_bindings.is_empty(),
1460 fold_bindings,
1461 having: rule.having.clone(),
1462 has_best_by: !rule.best_by_criteria.is_empty(),
1463 best_by_criteria,
1464 has_priority,
1465 deterministic: deterministic_best_by,
1466 prob_column_name,
1467 non_linear,
1468 })
1469 })
1470 .collect()
1471}
1472
1473fn convert_is_refs(
1485 is_refs: &[LocyIsRef],
1486 registry: &DerivedScanRegistry,
1487 stratum_rule_names: &std::collections::HashSet<&str>,
1488) -> DFResult<Vec<IsRefBinding>> {
1489 is_refs
1490 .iter()
1491 .map(|is_ref| {
1492 let entries = registry.entries_for_rule(&is_ref.rule_name);
1493 let want_self_ref = stratum_rule_names.contains(is_ref.rule_name.as_str());
1498 let entry = entries
1499 .iter()
1500 .find(|e| e.is_self_ref == want_self_ref)
1501 .or_else(|| entries.first())
1502 .ok_or_else(|| {
1503 datafusion::error::DataFusionError::Plan(format!(
1504 "No derived scan entry found for IS-ref to '{}'",
1505 is_ref.rule_name
1506 ))
1507 })?;
1508
1509 let anti_join_cols = if is_ref.negated {
1514 let mut cols: Vec<(String, String)> = is_ref
1515 .subjects
1516 .iter()
1517 .enumerate()
1518 .filter_map(|(i, s)| {
1519 if let uni_cypher::ast::Expr::Variable(var) = s {
1520 let right_col = entry
1521 .schema
1522 .fields()
1523 .get(i)
1524 .map(|f| f.name().clone())
1525 .unwrap_or_else(|| var.clone());
1526 Some((var.clone(), right_col))
1529 } else {
1530 None
1531 }
1532 })
1533 .collect();
1534 if let Some(uni_cypher::ast::Expr::Variable(target_var)) = &is_ref.target {
1539 let target_idx = is_ref.subjects.len();
1540 if let Some(field) = entry.schema.fields().get(target_idx) {
1541 cols.push((target_var.clone(), field.name().clone()));
1542 }
1543 }
1544 cols
1545 } else {
1546 Vec::new()
1547 };
1548
1549 let provenance_join_cols: Vec<(String, String)> = is_ref
1553 .subjects
1554 .iter()
1555 .enumerate()
1556 .filter_map(|(i, s)| {
1557 if let uni_cypher::ast::Expr::Variable(var) = s {
1558 let right_col = entry
1559 .schema
1560 .fields()
1561 .get(i)
1562 .map(|f| f.name().clone())
1563 .unwrap_or_else(|| var.clone());
1564 Some((var.clone(), right_col))
1565 } else {
1566 None
1567 }
1568 })
1569 .collect();
1570
1571 Ok(IsRefBinding {
1572 derived_scan_index: entry.scan_index,
1573 rule_name: is_ref.rule_name.clone(),
1574 is_self_ref: entry.is_self_ref,
1575 negated: is_ref.negated,
1576 anti_join_cols,
1577 target_has_prob: is_ref.target_has_prob,
1578 target_prob_col: is_ref.target_prob_col.clone(),
1579 provenance_join_cols,
1580 })
1581 })
1582 .collect()
1583}
1584
1585fn convert_fold_bindings(
1593 fold_bindings: &[(String, String, Expr)],
1594 yield_schema: &[LocyYieldColumn],
1595 plugin_registry: &PluginRegistry,
1596) -> DFResult<Vec<FoldBinding>> {
1597 fold_bindings
1598 .iter()
1599 .map(|(name, yield_alias, expr)| {
1600 let (agg_name, _input_col_name) = parse_fold_aggregate(expr)?;
1601 let entry =
1602 resolve_locy_aggregate(plugin_registry, agg_name.as_str()).ok_or_else(|| {
1603 datafusion::error::DataFusionError::Plan(format!(
1604 "Unknown Locy aggregate '{agg_name}' — not registered in plugin registry"
1605 ))
1606 })?;
1607 let aggregate = Arc::clone(&entry.aggregate);
1608
1609 if agg_name.as_str() == "COUNTALL" {
1612 return Ok(FoldBinding {
1613 output_name: yield_alias.clone(),
1614 name: agg_name,
1615 aggregate,
1616 input_col_index: 0, input_col_name: None,
1618 });
1619 }
1620
1621 let input_col_index = yield_schema
1626 .iter()
1627 .position(|yc| yc.name == *name || yc.name == *yield_alias)
1628 .unwrap_or(0);
1629 Ok(FoldBinding {
1630 output_name: yield_alias.clone(),
1631 name: agg_name,
1632 aggregate,
1633 input_col_index,
1634 input_col_name: Some(name.clone()),
1635 })
1636 })
1637 .collect()
1638}
1639
1640fn parse_fold_aggregate(expr: &Expr) -> DFResult<(smol_str::SmolStr, String)> {
1646 match expr {
1647 Expr::FunctionCall { name, args, .. } => {
1648 let upper = name.to_uppercase();
1649 let is_count = matches!(upper.as_str(), "COUNT" | "MCOUNT");
1650
1651 if is_count && args.is_empty() {
1653 return Ok((smol_str::SmolStr::new_static("COUNTALL"), String::new()));
1654 }
1655
1656 let canonical = match upper.as_str() {
1657 "SUM" | "MSUM" => smol_str::SmolStr::new_static("SUM"),
1658 "MAX" | "MMAX" => smol_str::SmolStr::new_static("MAX"),
1659 "MIN" | "MMIN" => smol_str::SmolStr::new_static("MIN"),
1660 "COUNT" | "MCOUNT" => smol_str::SmolStr::new_static("COUNT"),
1661 "AVG" => smol_str::SmolStr::new_static("AVG"),
1662 "COLLECT" => smol_str::SmolStr::new_static("COLLECT"),
1663 "MNOR" => smol_str::SmolStr::new_static("MNOR"),
1664 "MPROD" => smol_str::SmolStr::new_static("MPROD"),
1665 _ => {
1666 return Err(datafusion::error::DataFusionError::Plan(format!(
1667 "Unknown FOLD aggregate function: {}",
1668 name
1669 )));
1670 }
1671 };
1672 let col_name = match args.first() {
1673 Some(Expr::Variable(v)) => v.clone(),
1674 Some(Expr::Property(_, prop)) => prop.clone(),
1675 Some(other) => other.to_string_repr(),
1676 None => {
1677 return Err(datafusion::error::DataFusionError::Plan(
1678 "FOLD aggregate function requires at least one argument".to_string(),
1679 ));
1680 }
1681 };
1682 Ok((canonical, col_name))
1683 }
1684 _ => Err(datafusion::error::DataFusionError::Plan(
1685 "FOLD binding must be a function call (e.g., SUM(x))".to_string(),
1686 )),
1687 }
1688}
1689
1690fn convert_best_by_criteria(
1697 criteria: &[(Expr, bool)],
1698 yield_schema: &[LocyYieldColumn],
1699) -> DFResult<Vec<SortCriterion>> {
1700 criteria
1701 .iter()
1702 .map(|(expr, ascending)| {
1703 let col_name = match expr {
1704 Expr::Property(_, prop) => prop.clone(),
1705 Expr::Variable(v) => v.clone(),
1706 _ => {
1707 return Err(datafusion::error::DataFusionError::Plan(
1708 "BEST BY criterion must be a variable or property reference".to_string(),
1709 ));
1710 }
1711 };
1712 let col_index = yield_schema
1714 .iter()
1715 .position(|yc| yc.name == col_name)
1716 .or_else(|| {
1717 let short_name = col_name.rsplit('.').next().unwrap_or(&col_name);
1718 yield_schema.iter().position(|yc| yc.name == short_name)
1719 })
1720 .ok_or_else(|| {
1721 datafusion::error::DataFusionError::Plan(format!(
1722 "BEST BY column '{}' not found in yield schema",
1723 col_name
1724 ))
1725 })?;
1726 Ok(SortCriterion {
1727 col_index,
1728 ascending: *ascending,
1729 nulls_first: false,
1730 })
1731 })
1732 .collect()
1733}
1734
1735fn yield_columns_to_arrow_schema(columns: &[LocyYieldColumn]) -> ArrowSchema {
1741 let fields: Vec<Arc<Field>> = columns
1742 .iter()
1743 .map(|yc| Arc::new(Field::new(&yc.name, yc.data_type.clone(), true)))
1744 .collect();
1745 ArrowSchema::new(fields)
1746}
1747
1748fn build_fixpoint_output_schema(rules: &[LocyRulePlan]) -> SchemaRef {
1750 if let Some(rule) = rules.first() {
1753 Arc::new(yield_columns_to_arrow_schema(&rule.yield_schema))
1754 } else {
1755 Arc::new(ArrowSchema::empty())
1756 }
1757}
1758
1759fn build_stats_batch(
1761 derived_store: &DerivedStore,
1762 _strata: &[LocyStratum],
1763 output_schema: SchemaRef,
1764) -> RecordBatch {
1765 let mut rule_names: Vec<String> = derived_store.rule_names().map(String::from).collect();
1767 rule_names.sort();
1768
1769 let name_col: arrow_array::StringArray = rule_names.iter().map(|s| Some(s.as_str())).collect();
1770 let count_col: arrow_array::Int64Array = rule_names
1771 .iter()
1772 .map(|name| Some(derived_store.fact_count(name) as i64))
1773 .collect();
1774
1775 let stats_schema = stats_schema();
1776 RecordBatch::try_new(stats_schema, vec![Arc::new(name_col), Arc::new(count_col)])
1777 .unwrap_or_else(|_| RecordBatch::new_empty(output_schema))
1778}
1779
1780pub fn stats_schema() -> SchemaRef {
1782 Arc::new(ArrowSchema::new(vec![
1783 Arc::new(Field::new("rule_name", DataType::Utf8, false)),
1784 Arc::new(Field::new("fact_count", DataType::Int64, false)),
1785 ]))
1786}
1787
/// Unit tests for the derived store, the statistics batch, yield-schema
/// conversion and FOLD aggregate parsing.
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, LargeBinaryArray, StringArray};

    /// Builds a batch with a single nullable LargeBinary column `x` holding `values`.
    fn binary_batch(values: &[&[u8]]) -> RecordBatch {
        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
            "x",
            DataType::LargeBinary,
            true,
        ))]));
        let cells: Vec<Option<&[u8]>> = values.iter().copied().map(Some).collect();
        RecordBatch::try_new(schema, vec![Arc::new(LargeBinaryArray::from(cells))]).unwrap()
    }

    /// Builds a non-probabilistic yield column.
    fn yield_column(name: &str, is_key: bool, data_type: DataType) -> LocyYieldColumn {
        LocyYieldColumn {
            name: name.to_string(),
            is_key,
            is_prob: false,
            data_type,
        }
    }

    /// Builds the call expression `function(argument)` with a variable argument.
    fn fold_call(function: &str, argument: &str) -> Expr {
        Expr::FunctionCall {
            name: function.to_string(),
            args: vec![Expr::Variable(argument.to_string())],
            distinct: false,
            window_spec: None,
        }
    }

    #[test]
    fn test_derived_store_insert_and_get() {
        let mut store = DerivedStore::new();
        assert!(store.get("test").is_none());

        let batch = binary_batch(&[b"a" as &[u8], b"b"]);
        store.insert("test".to_string(), vec![batch]);

        let facts = store.get("test").unwrap();
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].num_rows(), 2);
    }

    #[test]
    fn test_derived_store_fact_count() {
        let mut store = DerivedStore::new();
        assert_eq!(store.fact_count("empty"), 0);

        let one_row = binary_batch(&[b"a" as &[u8]]);
        let two_rows = binary_batch(&[b"b" as &[u8], b"c"]);

        store.insert("test".to_string(), vec![one_row, two_rows]);
        assert_eq!(store.fact_count("test"), 3);
    }

    #[test]
    fn test_stats_batch_schema() {
        let schema = stats_schema();
        assert_eq!(schema.fields().len(), 2);

        let rule_name = schema.field(0);
        assert_eq!(rule_name.name(), "rule_name");
        assert_eq!(rule_name.data_type(), &DataType::Utf8);

        let fact_count = schema.field(1);
        assert_eq!(fact_count.name(), "fact_count");
        assert_eq!(fact_count.data_type(), &DataType::Int64);
    }

    #[test]
    fn test_stats_batch_content() {
        let mut store = DerivedStore::new();
        store.insert(
            "reach".to_string(),
            vec![binary_batch(&[b"a" as &[u8], b"b"])],
        );

        let output_schema = stats_schema();
        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
        assert_eq!(stats.num_rows(), 1);

        let names = stats
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(names.value(0), "reach");

        let counts = stats
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(counts.value(0), 2);
    }

    #[test]
    fn test_yield_columns_to_arrow_schema() {
        let columns = vec![
            yield_column("a", true, DataType::UInt64),
            yield_column("b", false, DataType::LargeUtf8),
            yield_column("c", true, DataType::Float64),
        ];

        let schema = yield_columns_to_arrow_schema(&columns);
        assert_eq!(schema.fields().len(), 3);

        let expected = [
            ("a", DataType::UInt64),
            ("b", DataType::LargeUtf8),
            ("c", DataType::Float64),
        ];
        for (idx, (name, data_type)) in expected.iter().enumerate() {
            let field = schema.field(idx);
            assert_eq!(field.name().as_str(), *name);
            assert_eq!(field.data_type(), data_type);
        }
        for field in schema.fields() {
            assert!(field.is_nullable());
        }
    }

    #[test]
    fn test_key_column_indices() {
        let columns = [
            yield_column("a", true, DataType::LargeBinary),
            yield_column("b", false, DataType::LargeBinary),
            yield_column("c", true, DataType::LargeBinary),
        ];

        let key_indices: Vec<usize> = columns
            .iter()
            .enumerate()
            .filter(|(_, yc)| yc.is_key)
            .map(|(i, _)| i)
            .collect();
        assert_eq!(key_indices, vec![0, 2]);
    }

    #[test]
    fn test_parse_fold_aggregate_sum() {
        let (kind, col) = parse_fold_aggregate(&fold_call("SUM", "cost")).unwrap();
        assert_eq!(kind.as_str(), "SUM");
        assert_eq!(col, "cost");
    }

    #[test]
    fn test_parse_fold_aggregate_monotonic() {
        let (kind, col) = parse_fold_aggregate(&fold_call("MMAX", "score")).unwrap();
        assert_eq!(kind.as_str(), "MAX");
        assert_eq!(col, "score");
    }

    #[test]
    fn test_parse_fold_aggregate_unknown() {
        assert!(parse_fold_aggregate(&fold_call("UNKNOWN_AGG", "x")).is_err());
    }

    #[test]
    fn test_no_commands_returns_stats() {
        let store = DerivedStore::new();
        let output_schema = stats_schema();
        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
        assert_eq!(stats.num_rows(), 0);
    }
}