1use crate::query::df_graph::GraphExecutionContext;
12use crate::query::df_graph::common::{
13 collect_all_partitions, compute_plan_properties, execute_subplan,
14};
15use crate::query::df_graph::locy_best_by::SortCriterion;
16use crate::query::df_graph::locy_explain::ProvenanceStore;
17use crate::query::df_graph::locy_fixpoint::{
18 DerivedScanRegistry, FixpointClausePlan, FixpointExec, FixpointRulePlan, IsRefBinding,
19};
20use crate::query::df_graph::locy_fold::{FoldBinding, resolve_locy_aggregate};
21use crate::query::planner_locy_types::{
22 LocyCommand, LocyIsRef, LocyRulePlan, LocyStratum, LocyYieldColumn,
23};
24use arrow_array::RecordBatch;
25use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
29use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
30use futures::Stream;
31use parking_lot::RwLock;
32use std::any::Any;
33use std::collections::HashMap;
34use std::fmt;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::sync::RwLock as StdRwLock;
38use std::task::{Context, Poll};
39use std::time::{Duration, Instant};
40use uni_common::Value;
41use uni_common::core::schema::Schema as UniSchema;
42use uni_cypher::ast::Expr;
43use uni_locy::{
44 ClassifierRegistry, CommandResult, FactRow, ModelInvocationCache, RuntimeWarning, SemiringKind,
45};
46use uni_plugin::PluginRegistry;
47use uni_store::storage::manager::StorageManager;
48
49pub struct DerivedStore {
58 relations: HashMap<String, Vec<RecordBatch>>,
59}
60
61impl Default for DerivedStore {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67impl DerivedStore {
68 pub fn new() -> Self {
69 Self {
70 relations: HashMap::new(),
71 }
72 }
73
74 pub fn insert(&mut self, rule_name: String, facts: Vec<RecordBatch>) {
75 self.relations.insert(rule_name, facts);
76 }
77
78 pub fn get(&self, rule_name: &str) -> Option<&Vec<RecordBatch>> {
79 self.relations.get(rule_name)
80 }
81
82 pub fn fact_count(&self, rule_name: &str) -> usize {
83 self.relations
84 .get(rule_name)
85 .map(|batches| batches.iter().map(|b| b.num_rows()).sum())
86 .unwrap_or(0)
87 }
88
89 pub fn rule_names(&self) -> impl Iterator<Item = &str> {
90 self.relations.keys().map(|s| s.as_str())
91 }
92}
93
/// DataFusion execution-plan node that evaluates a whole Locy program.
///
/// Strata are evaluated in order. Inline commands then run, and a single stats batch is
/// emitted. Side results (derived facts, warnings, command results, interruption info) are
/// published through the shared `*_slot` fields, which callers obtain via the accessor
/// methods.
// NOTE: field declaration order is also drop order; keep it stable.
pub struct LocyProgramExec {
    /// Strata to evaluate, in evaluation order.
    strata: Vec<LocyStratum>,
    /// Program commands; CYPHER / VALIDATE / CALIBRATE are handled inline after the strata.
    commands: Vec<LocyCommand>,
    /// Derived-scan entries that receive each rule's facts for later readers.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    /// Plugin registry used when converting rules to fixpoint plans.
    plugin_registry: Arc<PluginRegistry>,
    /// Graph execution context used by clause subplans and fixpoints.
    graph_ctx: Arc<GraphExecutionContext>,
    /// DataFusion session used to obtain task contexts and run subplans.
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    /// Storage backend handed to subplan execution.
    storage: Arc<StorageManager>,
    /// Graph schema used for planning inline Cypher and clause subplans.
    schema_info: Arc<UniSchema>,
    /// Query parameters visible to every subplan.
    params: HashMap<String, Value>,
    /// Schema of the stats batch this node emits.
    output_schema: SchemaRef,
    /// Plan properties computed once from `output_schema`.
    properties: Arc<PlanProperties>,
    /// DataFusion metrics for this node.
    metrics: ExecutionPlanMetricsSet,
    /// Iteration limit forwarded to each recursive stratum's fixpoint.
    max_iterations: usize,
    /// Overall time budget shared by all strata of one run.
    timeout: Duration,
    /// Derived-data size limit forwarded to each fixpoint.
    max_derived_bytes: usize,
    /// Forwarded to fixpoint plan conversion (BEST BY tie-breaking).
    deterministic_best_by: bool,
    /// Forwarded to fixpoint / post-fixpoint probability handling.
    strict_probability_domain: bool,
    /// Forwarded to fixpoint / post-fixpoint probability handling.
    probability_epsilon: f64,
    /// Enables exact weighted model counting for probabilistic rules.
    exact_probability: bool,
    /// Variable cap forwarded to exact-probability (BDD) evaluation.
    max_bdd_variables: usize,
    /// Semiring used to combine rule probabilities.
    semiring_kind: SemiringKind,
    /// Classifier registry used by fixpoints, lineage recording and CALIBRATE.
    classifier_registry: Arc<ClassifierRegistry>,
    /// Optional cache of model invocations, forwarded to classifier evaluation.
    classifier_cache: Option<Arc<ModelInvocationCache>>,
    /// Optional neural provenance store, forwarded to classifier evaluation.
    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
    /// Receives the final `DerivedStore` when a run finishes.
    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
    /// Per-rule approximation notes written by fixpoint / exact-probability stages
    /// (contents not visible in this file — TODO confirm semantics).
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    /// Optional provenance tracker, snapshotted each time `execute` is called.
    derivation_tracker: Arc<StdRwLock<Option<Arc<ProvenanceStore>>>>,
    /// Per-rule iteration counts filled by fixpoint evaluation.
    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
    /// Arrow buffer bytes retained by the derived store at the end of a run.
    peak_memory_slot: Arc<StdRwLock<usize>>,
    /// Runtime warnings accumulated during evaluation.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    /// `(command index, result)` pairs for commands executed inline.
    command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
    /// Number of proofs kept per fact by lineage tracking.
    top_k_proofs: usize,
    /// Interruption code (see the `interruption` module); first writer wins.
    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
    /// Describes what was left unevaluated when a run stops early.
    incomplete_slot: Arc<StdRwLock<Option<uni_common::LocyIncomplete>>>,
}
163
164pub(crate) mod interruption {
171 use std::sync::atomic::{AtomicU8, Ordering};
172
173 use uni_common::LocyIncompleteReason;
174
175 pub(crate) const NONE: u8 = 0;
177 pub(crate) const TIMEOUT: u8 = 1;
179 pub(crate) const ITERATION_LIMIT: u8 = 2;
181
182 pub(crate) fn reason(flag: &AtomicU8) -> Option<LocyIncompleteReason> {
184 match flag.load(Ordering::Relaxed) {
185 TIMEOUT => Some(LocyIncompleteReason::Timeout),
186 ITERATION_LIMIT => Some(LocyIncompleteReason::IterationLimit),
187 _ => None,
188 }
189 }
190
191 pub(crate) fn set(flag: &AtomicU8, code: u8) {
195 let _ = flag.compare_exchange(NONE, code, Ordering::Relaxed, Ordering::Relaxed);
196 }
197}
198
199impl fmt::Debug for LocyProgramExec {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 f.debug_struct("LocyProgramExec")
202 .field("strata_count", &self.strata.len())
203 .field("commands_count", &self.commands.len())
204 .field("max_iterations", &self.max_iterations)
205 .field("timeout", &self.timeout)
206 .field("output_schema", &self.output_schema)
207 .field("max_derived_bytes", &self.max_derived_bytes)
208 .finish_non_exhaustive()
209 }
210}
211
212impl LocyProgramExec {
213 #[expect(
214 clippy::too_many_arguments,
215 reason = "execution plan node requires full graph and session context"
216 )]
217 #[deprecated(
218 note = "use `new_with_semiring_classifiers_and_cache` (or the lighter \
219 `new_with_semiring_and_classifiers` / `new_with_semiring`) — \
220 this legacy ctor defaults the semiring to AddMultProb and \
221 ships no classifier registry. To be removed after C0 Stage 2."
222 )]
223 pub fn new(
224 strata: Vec<LocyStratum>,
225 commands: Vec<LocyCommand>,
226 derived_scan_registry: Arc<DerivedScanRegistry>,
227 plugin_registry: Arc<PluginRegistry>,
228 graph_ctx: Arc<GraphExecutionContext>,
229 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
230 storage: Arc<StorageManager>,
231 schema_info: Arc<UniSchema>,
232 params: HashMap<String, Value>,
233 output_schema: SchemaRef,
234 max_iterations: usize,
235 timeout: Duration,
236 max_derived_bytes: usize,
237 deterministic_best_by: bool,
238 strict_probability_domain: bool,
239 probability_epsilon: f64,
240 exact_probability: bool,
241 max_bdd_variables: usize,
242 top_k_proofs: usize,
243 ) -> Self {
244 Self::new_with_semiring_and_classifiers(
245 strata,
246 commands,
247 derived_scan_registry,
248 plugin_registry,
249 graph_ctx,
250 session_ctx,
251 storage,
252 schema_info,
253 params,
254 output_schema,
255 max_iterations,
256 timeout,
257 max_derived_bytes,
258 deterministic_best_by,
259 strict_probability_domain,
260 probability_epsilon,
261 exact_probability,
262 max_bdd_variables,
263 top_k_proofs,
264 SemiringKind::AddMultProb,
265 Arc::new(ClassifierRegistry::new()),
266 )
267 }
268
269 #[expect(
273 clippy::too_many_arguments,
274 reason = "execution plan node requires full graph and session context"
275 )]
276 pub fn new_with_semiring(
277 strata: Vec<LocyStratum>,
278 commands: Vec<LocyCommand>,
279 derived_scan_registry: Arc<DerivedScanRegistry>,
280 plugin_registry: Arc<PluginRegistry>,
281 graph_ctx: Arc<GraphExecutionContext>,
282 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
283 storage: Arc<StorageManager>,
284 schema_info: Arc<UniSchema>,
285 params: HashMap<String, Value>,
286 output_schema: SchemaRef,
287 max_iterations: usize,
288 timeout: Duration,
289 max_derived_bytes: usize,
290 deterministic_best_by: bool,
291 strict_probability_domain: bool,
292 probability_epsilon: f64,
293 exact_probability: bool,
294 max_bdd_variables: usize,
295 top_k_proofs: usize,
296 semiring_kind: SemiringKind,
297 ) -> Self {
298 Self::new_with_semiring_and_classifiers(
299 strata,
300 commands,
301 derived_scan_registry,
302 plugin_registry,
303 graph_ctx,
304 session_ctx,
305 storage,
306 schema_info,
307 params,
308 output_schema,
309 max_iterations,
310 timeout,
311 max_derived_bytes,
312 deterministic_best_by,
313 strict_probability_domain,
314 probability_epsilon,
315 exact_probability,
316 max_bdd_variables,
317 top_k_proofs,
318 semiring_kind,
319 Arc::new(ClassifierRegistry::new()),
320 )
321 }
322
323 #[expect(
326 clippy::too_many_arguments,
327 reason = "execution plan node requires full graph and session context"
328 )]
329 pub fn new_with_semiring_and_classifiers(
330 strata: Vec<LocyStratum>,
331 commands: Vec<LocyCommand>,
332 derived_scan_registry: Arc<DerivedScanRegistry>,
333 plugin_registry: Arc<PluginRegistry>,
334 graph_ctx: Arc<GraphExecutionContext>,
335 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
336 storage: Arc<StorageManager>,
337 schema_info: Arc<UniSchema>,
338 params: HashMap<String, Value>,
339 output_schema: SchemaRef,
340 max_iterations: usize,
341 timeout: Duration,
342 max_derived_bytes: usize,
343 deterministic_best_by: bool,
344 strict_probability_domain: bool,
345 probability_epsilon: f64,
346 exact_probability: bool,
347 max_bdd_variables: usize,
348 top_k_proofs: usize,
349 semiring_kind: SemiringKind,
350 classifier_registry: Arc<ClassifierRegistry>,
351 ) -> Self {
352 Self::new_with_semiring_classifiers_and_cache(
353 strata,
354 commands,
355 derived_scan_registry,
356 plugin_registry,
357 graph_ctx,
358 session_ctx,
359 storage,
360 schema_info,
361 params,
362 output_schema,
363 max_iterations,
364 timeout,
365 max_derived_bytes,
366 deterministic_best_by,
367 strict_probability_domain,
368 probability_epsilon,
369 exact_probability,
370 max_bdd_variables,
371 top_k_proofs,
372 semiring_kind,
373 classifier_registry,
374 None,
375 None,
376 )
377 }
378
379 #[expect(
384 clippy::too_many_arguments,
385 reason = "execution plan node requires full graph and session context"
386 )]
387 pub fn new_with_semiring_classifiers_and_cache(
388 strata: Vec<LocyStratum>,
389 commands: Vec<LocyCommand>,
390 derived_scan_registry: Arc<DerivedScanRegistry>,
391 plugin_registry: Arc<PluginRegistry>,
392 graph_ctx: Arc<GraphExecutionContext>,
393 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
394 storage: Arc<StorageManager>,
395 schema_info: Arc<UniSchema>,
396 params: HashMap<String, Value>,
397 output_schema: SchemaRef,
398 max_iterations: usize,
399 timeout: Duration,
400 max_derived_bytes: usize,
401 deterministic_best_by: bool,
402 strict_probability_domain: bool,
403 probability_epsilon: f64,
404 exact_probability: bool,
405 max_bdd_variables: usize,
406 top_k_proofs: usize,
407 semiring_kind: SemiringKind,
408 classifier_registry: Arc<ClassifierRegistry>,
409 classifier_cache: Option<Arc<ModelInvocationCache>>,
410 classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
411 ) -> Self {
412 let properties = compute_plan_properties(Arc::clone(&output_schema));
413 Self {
414 strata,
415 commands,
416 derived_scan_registry,
417 plugin_registry,
418 graph_ctx,
419 session_ctx,
420 storage,
421 schema_info,
422 params,
423 output_schema,
424 properties,
425 metrics: ExecutionPlanMetricsSet::new(),
426 max_iterations,
427 timeout,
428 max_derived_bytes,
429 deterministic_best_by,
430 strict_probability_domain,
431 probability_epsilon,
432 exact_probability,
433 max_bdd_variables,
434 semiring_kind,
435 classifier_registry,
436 classifier_cache,
437 classifier_provenance_store,
438 derived_store_slot: Arc::new(StdRwLock::new(None)),
439 approximate_slot: Arc::new(StdRwLock::new(HashMap::new())),
440 derivation_tracker: Arc::new(StdRwLock::new(None)),
441 iteration_counts_slot: Arc::new(StdRwLock::new(HashMap::new())),
442 peak_memory_slot: Arc::new(StdRwLock::new(0)),
443 warnings_slot: Arc::new(StdRwLock::new(Vec::new())),
444 command_results_slot: Arc::new(StdRwLock::new(Vec::new())),
445 top_k_proofs,
446 timeout_flag: Arc::new(std::sync::atomic::AtomicU8::new(interruption::NONE)),
447 incomplete_slot: Arc::new(StdRwLock::new(None)),
448 }
449 }
450
451 pub fn derived_store_slot(&self) -> Arc<StdRwLock<Option<DerivedStore>>> {
456 Arc::clone(&self.derived_store_slot)
457 }
458
459 pub fn set_derivation_tracker(&self, tracker: Arc<ProvenanceStore>) {
464 if let Ok(mut guard) = self.derivation_tracker.write() {
465 *guard = Some(tracker);
466 }
467 }
468
469 pub fn iteration_counts_slot(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
474 Arc::clone(&self.iteration_counts_slot)
475 }
476
477 pub fn peak_memory_slot(&self) -> Arc<StdRwLock<usize>> {
482 Arc::clone(&self.peak_memory_slot)
483 }
484
485 pub fn warnings_slot(&self) -> Arc<StdRwLock<Vec<RuntimeWarning>>> {
490 Arc::clone(&self.warnings_slot)
491 }
492
493 pub fn approximate_slot(&self) -> Arc<StdRwLock<HashMap<String, Vec<String>>>> {
498 Arc::clone(&self.approximate_slot)
499 }
500
501 pub fn command_results_slot(&self) -> Arc<StdRwLock<Vec<(usize, CommandResult)>>> {
506 Arc::clone(&self.command_results_slot)
507 }
508
509 pub fn timeout_flag(&self) -> Arc<std::sync::atomic::AtomicU8> {
515 Arc::clone(&self.timeout_flag)
516 }
517
518 pub fn incomplete_slot(&self) -> Arc<StdRwLock<Option<uni_common::LocyIncomplete>>> {
524 Arc::clone(&self.incomplete_slot)
525 }
526}
527
528impl DisplayAs for LocyProgramExec {
529 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
530 write!(
531 f,
532 "LocyProgramExec: strata={}, commands={}, max_iter={}, timeout={:?}",
533 self.strata.len(),
534 self.commands.len(),
535 self.max_iterations,
536 self.timeout,
537 )
538 }
539}
540
541impl ExecutionPlan for LocyProgramExec {
542 fn name(&self) -> &str {
543 "LocyProgramExec"
544 }
545
546 fn as_any(&self) -> &dyn Any {
547 self
548 }
549
550 fn schema(&self) -> SchemaRef {
551 Arc::clone(&self.output_schema)
552 }
553
554 fn properties(&self) -> &Arc<PlanProperties> {
555 &self.properties
556 }
557
558 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
559 vec![]
560 }
561
562 fn with_new_children(
563 self: Arc<Self>,
564 children: Vec<Arc<dyn ExecutionPlan>>,
565 ) -> DFResult<Arc<dyn ExecutionPlan>> {
566 if !children.is_empty() {
567 return Err(datafusion::error::DataFusionError::Plan(
568 "LocyProgramExec has no children".to_string(),
569 ));
570 }
571 Ok(self)
572 }
573
574 fn execute(
575 &self,
576 partition: usize,
577 _context: Arc<TaskContext>,
578 ) -> DFResult<SendableRecordBatchStream> {
579 let metrics = BaselineMetrics::new(&self.metrics, partition);
580
581 let strata = self.strata.clone();
582 let registry = Arc::clone(&self.derived_scan_registry);
583 let plugin_registry = Arc::clone(&self.plugin_registry);
584 let graph_ctx = Arc::clone(&self.graph_ctx);
585 let session_ctx = Arc::clone(&self.session_ctx);
586 let storage = Arc::clone(&self.storage);
587 let schema_info = Arc::clone(&self.schema_info);
588 let params = self.params.clone();
589 let output_schema = Arc::clone(&self.output_schema);
590 let max_iterations = self.max_iterations;
591 let timeout = self.timeout;
592 let max_derived_bytes = self.max_derived_bytes;
593 let deterministic_best_by = self.deterministic_best_by;
594 let strict_probability_domain = self.strict_probability_domain;
595 let probability_epsilon = self.probability_epsilon;
596 let exact_probability = self.exact_probability;
597 let max_bdd_variables = self.max_bdd_variables;
598 let derived_store_slot = Arc::clone(&self.derived_store_slot);
599 let approximate_slot = Arc::clone(&self.approximate_slot);
600 let iteration_counts_slot = Arc::clone(&self.iteration_counts_slot);
601 let peak_memory_slot = Arc::clone(&self.peak_memory_slot);
602 let derivation_tracker = self.derivation_tracker.read().ok().and_then(|g| g.clone());
603 let warnings_slot = Arc::clone(&self.warnings_slot);
604 let commands = self.commands.clone();
605 let command_results_slot = Arc::clone(&self.command_results_slot);
606 let top_k_proofs = self.top_k_proofs;
607 let timeout_flag = Arc::clone(&self.timeout_flag);
608 let incomplete_slot = Arc::clone(&self.incomplete_slot);
609 let semiring_kind = self.semiring_kind;
610 let classifier_registry = Arc::clone(&self.classifier_registry);
611 let classifier_cache = self.classifier_cache.as_ref().map(Arc::clone);
612 let classifier_provenance_store = self.classifier_provenance_store.as_ref().map(Arc::clone);
613
614 let fut = async move {
615 run_program(
616 strata,
617 commands,
618 registry,
619 plugin_registry,
620 graph_ctx,
621 session_ctx,
622 storage,
623 schema_info,
624 params,
625 output_schema,
626 max_iterations,
627 timeout,
628 max_derived_bytes,
629 deterministic_best_by,
630 strict_probability_domain,
631 probability_epsilon,
632 exact_probability,
633 max_bdd_variables,
634 derived_store_slot,
635 approximate_slot,
636 iteration_counts_slot,
637 peak_memory_slot,
638 derivation_tracker,
639 warnings_slot,
640 command_results_slot,
641 top_k_proofs,
642 timeout_flag,
643 incomplete_slot,
644 semiring_kind,
645 classifier_registry,
646 classifier_cache,
647 classifier_provenance_store,
648 )
649 .await
650 };
651
652 Ok(Box::pin(ProgramStream {
653 state: ProgramStreamState::Running(Box::pin(fut)),
654 schema: Arc::clone(&self.output_schema),
655 metrics,
656 }))
657 }
658
659 fn metrics(&self) -> Option<MetricsSet> {
660 Some(self.metrics.clone_inner())
661 }
662}
663
664enum ProgramStreamState {
669 Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
670 Emitting(Vec<RecordBatch>, usize),
671 Done,
672}
673
/// Record-batch stream returned by `LocyProgramExec::execute`.
///
/// It polls the program future to completion, then yields the resulting batches one by one.
struct ProgramStream {
    /// Current position in the run → emit → done lifecycle.
    state: ProgramStreamState,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Per-partition baseline metrics (compute time and output rows).
    metrics: BaselineMetrics,
}
679
680impl Stream for ProgramStream {
681 type Item = DFResult<RecordBatch>;
682
683 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
684 let this = self.get_mut();
685 let metrics = this.metrics.clone();
686 let _timer = metrics.elapsed_compute().timer();
687 loop {
688 match &mut this.state {
689 ProgramStreamState::Running(fut) => match fut.as_mut().poll(cx) {
690 Poll::Ready(Ok(batches)) => {
691 if batches.is_empty() {
692 this.state = ProgramStreamState::Done;
693 return Poll::Ready(None);
694 }
695 this.state = ProgramStreamState::Emitting(batches, 0);
696 }
697 Poll::Ready(Err(e)) => {
698 this.state = ProgramStreamState::Done;
699 return Poll::Ready(Some(Err(e)));
700 }
701 Poll::Pending => return Poll::Pending,
702 },
703 ProgramStreamState::Emitting(batches, idx) => {
704 if *idx >= batches.len() {
705 this.state = ProgramStreamState::Done;
706 return Poll::Ready(None);
707 }
708 let batch = batches[*idx].clone();
709 *idx += 1;
710 this.metrics.record_output(batch.num_rows());
711 return Poll::Ready(Some(Ok(batch)));
712 }
713 ProgramStreamState::Done => return Poll::Ready(None),
714 }
715 }
716 }
717}
718
719impl RecordBatchStream for ProgramStream {
720 fn schema(&self) -> SchemaRef {
721 Arc::clone(&self.schema)
722 }
723}
724
725async fn execute_cypher_inline(
731 query: &uni_cypher::ast::Query,
732 schema_info: &Arc<UniSchema>,
733 params: &HashMap<String, Value>,
734 graph_ctx: &Arc<GraphExecutionContext>,
735 session_ctx: &Arc<RwLock<datafusion::prelude::SessionContext>>,
736 storage: &Arc<StorageManager>,
737) -> DFResult<Vec<FactRow>> {
738 let planner = crate::query::planner::QueryPlanner::new(Arc::clone(schema_info));
739 let logical_plan = planner.plan(query.clone()).map_err(|e| {
740 datafusion::error::DataFusionError::Execution(format!("Cypher plan error: {e}"))
741 })?;
742 let batches = execute_subplan(
743 &logical_plan,
744 params,
745 &HashMap::new(),
746 graph_ctx,
747 session_ctx,
748 storage,
749 schema_info,
750 None, )
752 .await?;
753 Ok(super::locy_eval::record_batches_to_locy_rows(&batches))
754}
755
756#[expect(
761 clippy::too_many_arguments,
762 reason = "program evaluation requires full graph and session context"
763)]
764async fn run_program(
765 strata: Vec<LocyStratum>,
766 commands: Vec<LocyCommand>,
767 registry: Arc<DerivedScanRegistry>,
768 plugin_registry: Arc<PluginRegistry>,
769 graph_ctx: Arc<GraphExecutionContext>,
770 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
771 storage: Arc<StorageManager>,
772 schema_info: Arc<UniSchema>,
773 params: HashMap<String, Value>,
774 output_schema: SchemaRef,
775 max_iterations: usize,
776 timeout: Duration,
777 max_derived_bytes: usize,
778 deterministic_best_by: bool,
779 strict_probability_domain: bool,
780 probability_epsilon: f64,
781 exact_probability: bool,
782 max_bdd_variables: usize,
783 derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
784 approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
785 iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
786 peak_memory_slot: Arc<StdRwLock<usize>>,
787 derivation_tracker: Option<Arc<ProvenanceStore>>,
788 warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
789 command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
790 top_k_proofs: usize,
791 timeout_flag: Arc<std::sync::atomic::AtomicU8>,
792 incomplete_slot: Arc<StdRwLock<Option<uni_common::LocyIncomplete>>>,
793 semiring_kind: SemiringKind,
794 classifier_registry: Arc<ClassifierRegistry>,
795 classifier_cache: Option<Arc<ModelInvocationCache>>,
796 classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
797) -> DFResult<Vec<RecordBatch>> {
798 let start = Instant::now();
799 let mut derived_store = DerivedStore::new();
800
801 if semiring_kind == SemiringKind::MaxMinProb {
806 let mut warnings = warnings_slot.write().unwrap_or_else(|e| e.into_inner());
807 let mut already: std::collections::HashSet<String> = warnings
808 .iter()
809 .filter(|w| w.code == uni_locy::RuntimeWarningCode::FuzzyNotProbabilistic)
810 .map(|w| w.rule_name.clone())
811 .collect();
812 for stratum in &strata {
813 for rule in &stratum.rules {
814 let has_prob = rule.yield_schema.iter().any(|c| c.is_prob);
815 if has_prob && !already.contains(&rule.name) {
816 warnings.push(RuntimeWarning {
817 code: uni_locy::RuntimeWarningCode::FuzzyNotProbabilistic,
818 message: format!(
819 "rule '{}' carries a PROB column but is being evaluated under \
820 the MaxMinProb (fuzzy / Viterbi) semiring; outputs are fuzzy \
821 truth values, not probabilities",
822 rule.name
823 ),
824 rule_name: rule.name.clone(),
825 variable_count: None,
826 key_group: None,
827 });
828 already.insert(rule.name.clone());
829 }
830 }
831 }
832 }
833
834 let total_strata = strata.len();
838 let mut completed_strata = 0usize;
839 let mut partial_stratum: Option<usize> = None;
840 for (stratum_idx, stratum) in strata.iter().enumerate() {
841 write_cross_stratum_facts(®istry, &derived_store, stratum);
843
844 let remaining_timeout = timeout.saturating_sub(start.elapsed());
845 if remaining_timeout.is_zero() {
846 tracing::warn!("Locy program timeout exceeded during stratum evaluation");
847 interruption::set(&timeout_flag, interruption::TIMEOUT);
848 break;
849 }
850
851 if stratum.is_recursive {
852 let fixpoint_rules = convert_to_fixpoint_plans(
854 &stratum.rules,
855 ®istry,
856 &plugin_registry,
857 deterministic_best_by,
858 )?;
859 let fixpoint_schema = build_fixpoint_output_schema(&stratum.rules);
860
861 let exec = FixpointExec::new_with_semiring_classifiers_and_cache(
862 fixpoint_rules,
863 max_iterations,
864 remaining_timeout,
865 Arc::clone(&graph_ctx),
866 Arc::clone(&session_ctx),
867 Arc::clone(&storage),
868 Arc::clone(&schema_info),
869 params.clone(),
870 Arc::clone(®istry),
871 fixpoint_schema,
872 max_derived_bytes,
873 derivation_tracker.clone(),
874 Arc::clone(&iteration_counts_slot),
875 strict_probability_domain,
876 probability_epsilon,
877 exact_probability,
878 max_bdd_variables,
879 Arc::clone(&warnings_slot),
880 Arc::clone(&approximate_slot),
881 top_k_proofs,
882 Arc::clone(&timeout_flag),
883 semiring_kind,
884 Arc::clone(&classifier_registry),
885 classifier_cache.as_ref().map(Arc::clone),
886 classifier_provenance_store.as_ref().map(Arc::clone),
887 );
888
889 let task_ctx = session_ctx.read().task_ctx();
890 let exec_arc: Arc<dyn ExecutionPlan> = Arc::new(exec);
891 let batches = collect_all_partitions(&exec_arc, task_ctx).await?;
892
893 for rule in &stratum.rules {
904 if rule.yield_schema.is_empty() {
906 continue;
907 }
908 let rule_entries = registry.entries_for_rule(&rule.name);
910 for entry in rule_entries {
911 if !entry.is_self_ref {
912 let all_facts: Vec<RecordBatch> = batches
916 .iter()
917 .filter(|b| {
918 let rule_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
920 b.schema().fields().len() == rule_schema.fields().len()
921 })
922 .cloned()
923 .collect();
924 let mut guard = entry.data.write();
925 *guard = if all_facts.is_empty() {
926 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
927 } else {
928 all_facts
929 };
930 }
931 }
932 derived_store.insert(rule.name.clone(), batches.clone());
933 }
934 } else {
935 let fixpoint_rules = convert_to_fixpoint_plans(
937 &stratum.rules,
938 ®istry,
939 &plugin_registry,
940 deterministic_best_by,
941 )?;
942 let task_ctx = session_ctx.read().task_ctx();
943
944 for (rule, fp_rule) in stratum.rules.iter().zip(fixpoint_rules.iter()) {
945 if rule.yield_schema.is_empty() {
950 continue;
951 }
952
953 let mut tagged_clause_facts: Vec<(usize, Vec<RecordBatch>)> = Vec::new();
955 for (clause_idx, (clause, fp_clause)) in
956 rule.clauses.iter().zip(fp_rule.clauses.iter()).enumerate()
957 {
958 let mut batches = execute_subplan(
962 &clause.body,
963 ¶ms,
964 &HashMap::new(),
965 &graph_ctx,
966 &session_ctx,
967 &storage,
968 &schema_info,
969 None, )
971 .await?;
972
973 for binding in &fp_clause.is_ref_bindings {
975 if binding.negated
976 && !binding.anti_join_cols.is_empty()
977 && let Some(entry) = registry.get(binding.derived_scan_index)
978 {
979 let neg_facts = entry.data.read().clone();
980 if !neg_facts.is_empty() {
981 if binding.target_has_prob && fp_rule.prob_column_name.is_some() {
982 let complement_col =
983 format!("__prob_complement_{}", binding.rule_name);
984 if let Some(prob_col) = &binding.target_prob_col {
985 batches =
986 super::locy_fixpoint::apply_prob_complement_composite(
987 batches,
988 &neg_facts,
989 &binding.anti_join_cols,
990 prob_col,
991 &complement_col,
992 )?;
993 } else {
994 batches = super::locy_fixpoint::apply_anti_join_composite(
996 batches,
997 &neg_facts,
998 &binding.anti_join_cols,
999 )?;
1000 }
1001 } else {
1002 batches = super::locy_fixpoint::apply_anti_join_composite(
1003 batches,
1004 &neg_facts,
1005 &binding.anti_join_cols,
1006 )?;
1007 }
1008 }
1009 }
1010 }
1011
1012 let complement_cols: Vec<String> = if !batches.is_empty() {
1014 batches[0]
1015 .schema()
1016 .fields()
1017 .iter()
1018 .filter(|f| f.name().starts_with("__prob_complement_"))
1019 .map(|f| f.name().clone())
1020 .collect()
1021 } else {
1022 vec![]
1023 };
1024 if !complement_cols.is_empty() {
1025 batches = super::locy_fixpoint::multiply_prob_factors(
1026 batches,
1027 fp_rule.prob_column_name.as_deref(),
1028 &complement_cols,
1029 )?;
1030 }
1031
1032 tagged_clause_facts.push((clause_idx, batches));
1033 }
1034
1035 let shared_info = if semiring_kind == SemiringKind::MaxMinProb {
1048 None
1049 } else if let Some(ref tracker) = derivation_tracker {
1050 super::locy_fixpoint::record_and_detect_lineage_nonrecursive(
1051 fp_rule,
1052 &tagged_clause_facts,
1053 tracker,
1054 &warnings_slot,
1055 ®istry,
1056 top_k_proofs,
1057 super::locy_fixpoint::ClassifierRefs {
1058 registry: &classifier_registry,
1059 cache: classifier_cache.as_ref(),
1060 provenance_store: classifier_provenance_store.as_ref(),
1061 },
1062 semiring_kind,
1063 )
1064 .await
1065 } else {
1066 None
1067 };
1068
1069 let mut all_clause_facts: Vec<RecordBatch> = tagged_clause_facts
1071 .into_iter()
1072 .flat_map(|(_, batches)| batches)
1073 .collect();
1074
1075 if exact_probability
1077 && let Some(ref info) = shared_info
1078 && let Some(ref tracker) = derivation_tracker
1079 {
1080 all_clause_facts = super::locy_fixpoint::apply_exact_wmc(
1081 all_clause_facts,
1082 fp_rule,
1083 info,
1084 tracker,
1085 max_bdd_variables,
1086 &warnings_slot,
1087 &approximate_slot,
1088 )?;
1089 }
1090
1091 let facts = super::locy_fixpoint::apply_post_fixpoint_chain(
1093 all_clause_facts,
1094 fp_rule,
1095 &task_ctx,
1096 strict_probability_domain,
1097 probability_epsilon,
1098 semiring_kind,
1099 derivation_tracker.as_ref().map(Arc::clone),
1100 top_k_proofs,
1101 Some(Arc::clone(®istry)),
1102 )
1103 .await?;
1104
1105 write_facts_to_registry(®istry, &rule.name, &facts);
1107 derived_store.insert(rule.name.clone(), facts);
1108 }
1109 }
1110
1111 if interruption::reason(&timeout_flag).is_some() {
1115 partial_stratum = Some(stratum_idx);
1116 break;
1117 }
1118 completed_strata += 1;
1119 }
1120
1121 if let Some(reason) = interruption::reason(&timeout_flag) {
1125 let skipped_start = match partial_stratum {
1126 Some(i) => i + 1,
1127 None => completed_strata,
1128 };
1129 let incomplete_rules: Vec<String> = partial_stratum
1130 .map(|i| strata[i].rules.iter().map(|r| r.name.clone()).collect())
1131 .unwrap_or_default();
1132 let skipped_rules: Vec<String> = strata[skipped_start..]
1133 .iter()
1134 .flat_map(|s| s.rules.iter().map(|r| r.name.clone()))
1135 .collect();
1136 let mut complement_rules_affected = Vec::new();
1137 for idx in partial_stratum
1138 .into_iter()
1139 .chain(skipped_start..total_strata)
1140 {
1141 for rule in &strata[idx].rules {
1142 if rule
1143 .clauses
1144 .iter()
1145 .any(|c| c.is_refs.iter().any(|r| r.negated))
1146 {
1147 complement_rules_affected.push(rule.name.clone());
1148 }
1149 }
1150 }
1151 if let Ok(mut slot) = incomplete_slot.write() {
1152 *slot = Some(uni_common::LocyIncomplete {
1153 reason,
1154 elapsed_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1155 limit_ms: u64::try_from(timeout.as_millis()).unwrap_or(u64::MAX),
1156 max_iterations,
1157 completed_strata,
1158 total_strata,
1159 incomplete_rules,
1160 skipped_rules,
1161 complement_rules_affected,
1162 });
1163 }
1164 }
1165
1166 let peak_bytes: usize = derived_store
1168 .relations
1169 .values()
1170 .flat_map(|batches| batches.iter())
1171 .map(|b| {
1172 b.columns()
1173 .iter()
1174 .map(|col| col.get_buffer_memory_size())
1175 .sum::<usize>()
1176 })
1177 .sum();
1178 *peak_memory_slot.write().unwrap() = peak_bytes;
1179
1180 let first_derive_idx = commands
1190 .iter()
1191 .position(|c| matches!(c, LocyCommand::Derive { .. }));
1192 let mut inline_results: Vec<(usize, CommandResult)> = Vec::new();
1193 for (cmd_idx, cmd) in commands.iter().enumerate() {
1194 match cmd {
1195 LocyCommand::Cypher { query } => {
1196 if first_derive_idx.is_some_and(|di| cmd_idx > di) {
1199 continue;
1200 }
1201 let rows = execute_cypher_inline(
1202 query,
1203 &schema_info,
1204 ¶ms,
1205 &graph_ctx,
1206 &session_ctx,
1207 &storage,
1208 )
1209 .await?;
1210 inline_results.push((cmd_idx, CommandResult::Cypher(rows)));
1211 }
1212 LocyCommand::Validate { validate } => {
1213 let rule_key_cols: Vec<String> = strata
1217 .iter()
1218 .flat_map(|s| s.rules.iter())
1219 .find(|r| r.name == validate.rule_name)
1220 .map(|r| {
1221 r.yield_schema
1222 .iter()
1223 .filter(|c| c.is_key)
1224 .map(|c| c.name.clone())
1225 .collect()
1226 })
1227 .unwrap_or_default();
1228 let query =
1229 super::locy_validate::validate_collection_query(validate, &rule_key_cols);
1230 let target_rows = execute_cypher_inline(
1231 &query,
1232 &schema_info,
1233 ¶ms,
1234 &graph_ctx,
1235 &session_ctx,
1236 &storage,
1237 )
1238 .await?;
1239 let rule_facts: Vec<uni_locy::FactRow> = derived_store
1240 .get(&validate.rule_name)
1241 .map(|batches| super::locy_eval::record_batches_to_locy_rows(batches))
1242 .unwrap_or_default();
1243 let result = super::locy_validate::run_validate(
1244 validate,
1245 &rule_key_cols,
1246 &rule_facts,
1247 target_rows,
1248 )
1249 .map_err(|e| {
1250 datafusion::error::DataFusionError::Execution(format!("VALIDATE error: {e}"))
1251 })?;
1252 inline_results.push((cmd_idx, CommandResult::Validate(result)));
1253 }
1254 LocyCommand::Calibrate {
1255 calibrate,
1256 model_inputs,
1257 } => {
1258 let model_snapshot = uni_locy::CompiledModel {
1271 name: calibrate.model_name.clone(),
1272 inputs: model_inputs.clone(),
1273 features: vec![],
1274 path_context: None,
1275 output_type: uni_cypher::locy_ast::OutputType::Prob,
1276 output_name: String::new(),
1277 xervo_alias: String::new(),
1278 embedder_alias: None,
1279 calibration: None,
1280 version: None,
1281 annotations: Default::default(),
1282 };
1283 let query =
1284 super::locy_calibrate::calibrate_collection_query(calibrate, &model_snapshot);
1285 let rows = execute_cypher_inline(
1286 &query,
1287 &schema_info,
1288 ¶ms,
1289 &graph_ctx,
1290 &session_ctx,
1291 &storage,
1292 )
1293 .await?;
1294 let mut catalog = std::collections::HashMap::new();
1295 catalog.insert(calibrate.model_name.clone(), model_snapshot);
1296 let result = super::locy_calibrate::run_calibrate(
1297 calibrate,
1298 &catalog,
1299 &classifier_registry,
1300 rows,
1301 )
1302 .await
1303 .map_err(|e| {
1304 datafusion::error::DataFusionError::Execution(format!("CALIBRATE error: {e}"))
1305 })?;
1306 inline_results.push((cmd_idx, CommandResult::Calibrate(result)));
1307 }
1308 _ => {}
1309 }
1310 }
1311 *command_results_slot.write().unwrap() = inline_results;
1312
1313 let stats = vec![build_stats_batch(&derived_store, &strata, output_schema)];
1314 *derived_store_slot.write().unwrap() = Some(derived_store);
1315 Ok(stats)
1316}
1317
1318fn write_cross_stratum_facts(
1324 registry: &DerivedScanRegistry,
1325 derived_store: &DerivedStore,
1326 stratum: &LocyStratum,
1327) {
1328 for rule in &stratum.rules {
1330 for clause in &rule.clauses {
1331 for is_ref in &clause.is_refs {
1332 if let Some(facts) = derived_store.get(&is_ref.rule_name) {
1335 write_facts_to_registry(registry, &is_ref.rule_name, facts);
1336 }
1337 }
1338 }
1339 }
1340}
1341
1342fn write_facts_to_registry(registry: &DerivedScanRegistry, rule_name: &str, facts: &[RecordBatch]) {
1344 let entries = registry.entries_for_rule(rule_name);
1345 for entry in entries {
1346 if !entry.is_self_ref {
1347 let mut guard = entry.data.write();
1348 *guard = if facts.is_empty() || facts.iter().all(|b| b.num_rows() == 0) {
1349 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
1350 } else {
1351 facts
1356 .iter()
1357 .filter(|b| b.num_rows() > 0)
1358 .map(|b| {
1359 RecordBatch::try_new(Arc::clone(&entry.schema), b.columns().to_vec())
1360 .unwrap_or_else(|_| b.clone())
1361 })
1362 .collect()
1363 };
1364 }
1365 }
1366}
1367
1368fn convert_to_fixpoint_plans(
1374 rules: &[LocyRulePlan],
1375 registry: &DerivedScanRegistry,
1376 plugin_registry: &PluginRegistry,
1377 deterministic_best_by: bool,
1378) -> DFResult<Vec<FixpointRulePlan>> {
1379 rules
1380 .iter()
1381 .map(|rule| {
1382 let yield_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
1383 let key_column_indices: Vec<usize> = rule
1384 .yield_schema
1385 .iter()
1386 .enumerate()
1387 .filter(|(_, yc)| yc.is_key)
1388 .map(|(i, _)| i)
1389 .collect();
1390
1391 let clauses: Vec<FixpointClausePlan> = rule
1392 .clauses
1393 .iter()
1394 .map(|clause| {
1395 let is_ref_bindings = convert_is_refs(&clause.is_refs, registry)?;
1396 Ok(FixpointClausePlan {
1397 body_logical: clause.body.clone(),
1398 is_ref_bindings,
1399 priority: clause.priority,
1400 along_bindings: clause.along_bindings.clone(),
1401 model_invocations: clause.model_invocations.clone(),
1402 })
1403 })
1404 .collect::<DFResult<Vec<_>>>()?;
1405
1406 let fold_bindings =
1407 convert_fold_bindings(&rule.fold_bindings, &rule.yield_schema, plugin_registry)?;
1408 let best_by_criteria =
1409 convert_best_by_criteria(&rule.best_by_criteria, &rule.yield_schema)?;
1410
1411 let has_priority = rule.priority.is_some();
1412
1413 let yield_schema = if has_priority {
1415 let mut fields: Vec<Arc<Field>> = yield_schema.fields().iter().cloned().collect();
1416 fields.push(Arc::new(Field::new("__priority", DataType::Int64, true)));
1417 ArrowSchema::new(fields)
1418 } else {
1419 yield_schema
1420 };
1421
1422 let prob_column_name = rule
1423 .yield_schema
1424 .iter()
1425 .find(|yc| yc.is_prob)
1426 .map(|yc| yc.name.clone());
1427
1428 Ok(FixpointRulePlan {
1429 name: rule.name.clone(),
1430 clauses,
1431 yield_schema: Arc::new(yield_schema),
1432 key_column_indices,
1433 priority: rule.priority,
1434 has_fold: !rule.fold_bindings.is_empty(),
1435 fold_bindings,
1436 having: rule.having.clone(),
1437 has_best_by: !rule.best_by_criteria.is_empty(),
1438 best_by_criteria,
1439 has_priority,
1440 deterministic: deterministic_best_by,
1441 prob_column_name,
1442 })
1443 })
1444 .collect()
1445}
1446
1447fn convert_is_refs(
1449 is_refs: &[LocyIsRef],
1450 registry: &DerivedScanRegistry,
1451) -> DFResult<Vec<IsRefBinding>> {
1452 is_refs
1453 .iter()
1454 .map(|is_ref| {
1455 let entries = registry.entries_for_rule(&is_ref.rule_name);
1456 let entry = entries
1458 .iter()
1459 .find(|e| e.is_self_ref)
1460 .or_else(|| entries.first())
1461 .ok_or_else(|| {
1462 datafusion::error::DataFusionError::Plan(format!(
1463 "No derived scan entry found for IS-ref to '{}'",
1464 is_ref.rule_name
1465 ))
1466 })?;
1467
1468 let anti_join_cols = if is_ref.negated {
1473 let mut cols: Vec<(String, String)> = is_ref
1474 .subjects
1475 .iter()
1476 .enumerate()
1477 .filter_map(|(i, s)| {
1478 if let uni_cypher::ast::Expr::Variable(var) = s {
1479 let right_col = entry
1480 .schema
1481 .fields()
1482 .get(i)
1483 .map(|f| f.name().clone())
1484 .unwrap_or_else(|| var.clone());
1485 Some((var.clone(), right_col))
1488 } else {
1489 None
1490 }
1491 })
1492 .collect();
1493 if let Some(uni_cypher::ast::Expr::Variable(target_var)) = &is_ref.target {
1498 let target_idx = is_ref.subjects.len();
1499 if let Some(field) = entry.schema.fields().get(target_idx) {
1500 cols.push((target_var.clone(), field.name().clone()));
1501 }
1502 }
1503 cols
1504 } else {
1505 Vec::new()
1506 };
1507
1508 let provenance_join_cols: Vec<(String, String)> = is_ref
1512 .subjects
1513 .iter()
1514 .enumerate()
1515 .filter_map(|(i, s)| {
1516 if let uni_cypher::ast::Expr::Variable(var) = s {
1517 let right_col = entry
1518 .schema
1519 .fields()
1520 .get(i)
1521 .map(|f| f.name().clone())
1522 .unwrap_or_else(|| var.clone());
1523 Some((var.clone(), right_col))
1524 } else {
1525 None
1526 }
1527 })
1528 .collect();
1529
1530 Ok(IsRefBinding {
1531 derived_scan_index: entry.scan_index,
1532 rule_name: is_ref.rule_name.clone(),
1533 is_self_ref: entry.is_self_ref,
1534 negated: is_ref.negated,
1535 anti_join_cols,
1536 target_has_prob: is_ref.target_has_prob,
1537 target_prob_col: is_ref.target_prob_col.clone(),
1538 provenance_join_cols,
1539 })
1540 })
1541 .collect()
1542}
1543
1544fn convert_fold_bindings(
1552 fold_bindings: &[(String, String, Expr)],
1553 yield_schema: &[LocyYieldColumn],
1554 plugin_registry: &PluginRegistry,
1555) -> DFResult<Vec<FoldBinding>> {
1556 fold_bindings
1557 .iter()
1558 .map(|(name, yield_alias, expr)| {
1559 let (agg_name, _input_col_name) = parse_fold_aggregate(expr)?;
1560 let entry =
1561 resolve_locy_aggregate(plugin_registry, agg_name.as_str()).ok_or_else(|| {
1562 datafusion::error::DataFusionError::Plan(format!(
1563 "Unknown Locy aggregate '{agg_name}' — not registered in plugin registry"
1564 ))
1565 })?;
1566 let aggregate = Arc::clone(&entry.aggregate);
1567
1568 if agg_name.as_str() == "COUNTALL" {
1571 return Ok(FoldBinding {
1572 output_name: yield_alias.clone(),
1573 name: agg_name,
1574 aggregate,
1575 input_col_index: 0, input_col_name: None,
1577 });
1578 }
1579
1580 let input_col_index = yield_schema
1585 .iter()
1586 .position(|yc| yc.name == *name || yc.name == *yield_alias)
1587 .unwrap_or(0);
1588 Ok(FoldBinding {
1589 output_name: yield_alias.clone(),
1590 name: agg_name,
1591 aggregate,
1592 input_col_index,
1593 input_col_name: Some(name.clone()),
1594 })
1595 })
1596 .collect()
1597}
1598
1599fn parse_fold_aggregate(expr: &Expr) -> DFResult<(smol_str::SmolStr, String)> {
1605 match expr {
1606 Expr::FunctionCall { name, args, .. } => {
1607 let upper = name.to_uppercase();
1608 let is_count = matches!(upper.as_str(), "COUNT" | "MCOUNT");
1609
1610 if is_count && args.is_empty() {
1612 return Ok((smol_str::SmolStr::new_static("COUNTALL"), String::new()));
1613 }
1614
1615 let canonical = match upper.as_str() {
1616 "SUM" | "MSUM" => smol_str::SmolStr::new_static("SUM"),
1617 "MAX" | "MMAX" => smol_str::SmolStr::new_static("MAX"),
1618 "MIN" | "MMIN" => smol_str::SmolStr::new_static("MIN"),
1619 "COUNT" | "MCOUNT" => smol_str::SmolStr::new_static("COUNT"),
1620 "AVG" => smol_str::SmolStr::new_static("AVG"),
1621 "COLLECT" => smol_str::SmolStr::new_static("COLLECT"),
1622 "MNOR" => smol_str::SmolStr::new_static("MNOR"),
1623 "MPROD" => smol_str::SmolStr::new_static("MPROD"),
1624 _ => {
1625 return Err(datafusion::error::DataFusionError::Plan(format!(
1626 "Unknown FOLD aggregate function: {}",
1627 name
1628 )));
1629 }
1630 };
1631 let col_name = match args.first() {
1632 Some(Expr::Variable(v)) => v.clone(),
1633 Some(Expr::Property(_, prop)) => prop.clone(),
1634 Some(other) => other.to_string_repr(),
1635 None => {
1636 return Err(datafusion::error::DataFusionError::Plan(
1637 "FOLD aggregate function requires at least one argument".to_string(),
1638 ));
1639 }
1640 };
1641 Ok((canonical, col_name))
1642 }
1643 _ => Err(datafusion::error::DataFusionError::Plan(
1644 "FOLD binding must be a function call (e.g., SUM(x))".to_string(),
1645 )),
1646 }
1647}
1648
1649fn convert_best_by_criteria(
1656 criteria: &[(Expr, bool)],
1657 yield_schema: &[LocyYieldColumn],
1658) -> DFResult<Vec<SortCriterion>> {
1659 criteria
1660 .iter()
1661 .map(|(expr, ascending)| {
1662 let col_name = match expr {
1663 Expr::Property(_, prop) => prop.clone(),
1664 Expr::Variable(v) => v.clone(),
1665 _ => {
1666 return Err(datafusion::error::DataFusionError::Plan(
1667 "BEST BY criterion must be a variable or property reference".to_string(),
1668 ));
1669 }
1670 };
1671 let col_index = yield_schema
1673 .iter()
1674 .position(|yc| yc.name == col_name)
1675 .or_else(|| {
1676 let short_name = col_name.rsplit('.').next().unwrap_or(&col_name);
1677 yield_schema.iter().position(|yc| yc.name == short_name)
1678 })
1679 .ok_or_else(|| {
1680 datafusion::error::DataFusionError::Plan(format!(
1681 "BEST BY column '{}' not found in yield schema",
1682 col_name
1683 ))
1684 })?;
1685 Ok(SortCriterion {
1686 col_index,
1687 ascending: *ascending,
1688 nulls_first: false,
1689 })
1690 })
1691 .collect()
1692}
1693
1694fn yield_columns_to_arrow_schema(columns: &[LocyYieldColumn]) -> ArrowSchema {
1700 let fields: Vec<Arc<Field>> = columns
1701 .iter()
1702 .map(|yc| Arc::new(Field::new(&yc.name, yc.data_type.clone(), true)))
1703 .collect();
1704 ArrowSchema::new(fields)
1705}
1706
1707fn build_fixpoint_output_schema(rules: &[LocyRulePlan]) -> SchemaRef {
1709 if let Some(rule) = rules.first() {
1712 Arc::new(yield_columns_to_arrow_schema(&rule.yield_schema))
1713 } else {
1714 Arc::new(ArrowSchema::empty())
1715 }
1716}
1717
1718fn build_stats_batch(
1720 derived_store: &DerivedStore,
1721 _strata: &[LocyStratum],
1722 output_schema: SchemaRef,
1723) -> RecordBatch {
1724 let mut rule_names: Vec<String> = derived_store.rule_names().map(String::from).collect();
1726 rule_names.sort();
1727
1728 let name_col: arrow_array::StringArray = rule_names.iter().map(|s| Some(s.as_str())).collect();
1729 let count_col: arrow_array::Int64Array = rule_names
1730 .iter()
1731 .map(|name| Some(derived_store.fact_count(name) as i64))
1732 .collect();
1733
1734 let stats_schema = stats_schema();
1735 RecordBatch::try_new(stats_schema, vec![Arc::new(name_col), Arc::new(count_col)])
1736 .unwrap_or_else(|_| RecordBatch::new_empty(output_schema))
1737}
1738
1739pub fn stats_schema() -> SchemaRef {
1741 Arc::new(ArrowSchema::new(vec![
1742 Arc::new(Field::new("rule_name", DataType::Utf8, false)),
1743 Arc::new(Field::new("fact_count", DataType::Int64, false)),
1744 ]))
1745}
1746
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, LargeBinaryArray, StringArray};

    /// One-column batch (`x`, nullable LargeBinary) with one row per value.
    fn binary_batch(values: &[&str]) -> RecordBatch {
        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
            "x",
            DataType::LargeBinary,
            true,
        ))]));
        let rows: Vec<Option<&[u8]>> = values.iter().map(|v| Some(v.as_bytes())).collect();
        RecordBatch::try_new(schema, vec![Arc::new(LargeBinaryArray::from(rows))]).unwrap()
    }

    /// Yield column with `is_prob` cleared.
    fn yield_col(name: &str, is_key: bool, data_type: DataType) -> LocyYieldColumn {
        LocyYieldColumn {
            name: name.to_string(),
            is_key,
            is_prob: false,
            data_type,
        }
    }

    /// Single-argument aggregate call `name(arg)`.
    fn agg_call(name: &str, arg: &str) -> Expr {
        Expr::FunctionCall {
            name: name.to_string(),
            args: vec![Expr::Variable(arg.to_string())],
            distinct: false,
            window_spec: None,
        }
    }

    #[test]
    fn test_derived_store_insert_and_get() {
        let mut store = DerivedStore::new();
        assert!(store.get("test").is_none());

        store.insert("test".to_string(), vec![binary_batch(&["a", "b"])]);

        let facts = store.get("test").unwrap();
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].num_rows(), 2);
    }

    #[test]
    fn test_derived_store_fact_count() {
        let mut store = DerivedStore::new();
        assert_eq!(store.fact_count("empty"), 0);

        store.insert(
            "test".to_string(),
            vec![binary_batch(&["a"]), binary_batch(&["b", "c"])],
        );
        assert_eq!(store.fact_count("test"), 3);
    }

    #[test]
    fn test_stats_batch_schema() {
        let schema = stats_schema();
        let expected = [("rule_name", DataType::Utf8), ("fact_count", DataType::Int64)];
        assert_eq!(schema.fields().len(), expected.len());
        for (idx, (name, data_type)) in expected.iter().enumerate() {
            assert_eq!(schema.field(idx).name().as_str(), *name);
            assert_eq!(schema.field(idx).data_type(), data_type);
        }
    }

    #[test]
    fn test_stats_batch_content() {
        let mut store = DerivedStore::new();
        store.insert("reach".to_string(), vec![binary_batch(&["a", "b"])]);

        let output_schema = stats_schema();
        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
        assert_eq!(stats.num_rows(), 1);

        let names = stats.column(0).as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(names.value(0), "reach");

        let counts = stats.column(1).as_any().downcast_ref::<Int64Array>().unwrap();
        assert_eq!(counts.value(0), 2);
    }

    #[test]
    fn test_yield_columns_to_arrow_schema() {
        let columns = vec![
            yield_col("a", true, DataType::UInt64),
            yield_col("b", false, DataType::LargeUtf8),
            yield_col("c", true, DataType::Float64),
        ];

        let schema = yield_columns_to_arrow_schema(&columns);

        let expected = [
            ("a", DataType::UInt64),
            ("b", DataType::LargeUtf8),
            ("c", DataType::Float64),
        ];
        assert_eq!(schema.fields().len(), expected.len());
        for (field, (name, data_type)) in schema.fields().iter().zip(&expected) {
            assert_eq!(field.name().as_str(), *name);
            assert_eq!(field.data_type(), data_type);
            assert!(field.is_nullable());
        }
    }

    #[test]
    fn test_key_column_indices() {
        let columns = [
            yield_col("a", true, DataType::LargeBinary),
            yield_col("b", false, DataType::LargeBinary),
            yield_col("c", true, DataType::LargeBinary),
        ];

        let key_indices: Vec<usize> = columns
            .iter()
            .enumerate()
            .filter_map(|(i, yc)| yc.is_key.then_some(i))
            .collect();
        assert_eq!(key_indices, vec![0, 2]);
    }

    #[test]
    fn test_parse_fold_aggregate_sum() {
        let (kind, col) = parse_fold_aggregate(&agg_call("SUM", "cost")).unwrap();
        assert_eq!(kind.as_str(), "SUM");
        assert_eq!(col, "cost");
    }

    #[test]
    fn test_parse_fold_aggregate_monotonic() {
        let (kind, col) = parse_fold_aggregate(&agg_call("MMAX", "score")).unwrap();
        assert_eq!(kind.as_str(), "MAX");
        assert_eq!(col, "score");
    }

    #[test]
    fn test_parse_fold_aggregate_unknown() {
        assert!(parse_fold_aggregate(&agg_call("UNKNOWN_AGG", "x")).is_err());
    }

    #[test]
    fn test_no_commands_returns_stats() {
        let store = DerivedStore::new();
        let output_schema = stats_schema();
        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
        assert_eq!(stats.num_rows(), 0);
    }
}