1use crate::query::df_graph::GraphExecutionContext;
12use crate::query::df_graph::common::{
13 collect_all_partitions, compute_plan_properties, execute_subplan,
14};
15use crate::query::df_graph::locy_best_by::SortCriterion;
16use crate::query::df_graph::locy_explain::ProvenanceStore;
17use crate::query::df_graph::locy_fixpoint::{
18 DerivedScanRegistry, FixpointClausePlan, FixpointExec, FixpointRulePlan, IsRefBinding,
19};
20use crate::query::df_graph::locy_fold::{FoldBinding, resolve_locy_aggregate};
21use crate::query::planner_locy_types::{
22 LocyCommand, LocyIsRef, LocyRulePlan, LocyStratum, LocyYieldColumn,
23};
24use arrow_array::RecordBatch;
25use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
29use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
30use futures::Stream;
31use parking_lot::RwLock;
32use std::any::Any;
33use std::collections::HashMap;
34use std::fmt;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::sync::RwLock as StdRwLock;
38use std::task::{Context, Poll};
39use std::time::{Duration, Instant};
40use uni_common::Value;
41use uni_common::core::schema::Schema as UniSchema;
42use uni_cypher::ast::Expr;
43use uni_locy::{
44 ClassifierRegistry, CommandResult, FactRow, ModelInvocationCache, RuntimeWarning, SemiringKind,
45};
46use uni_plugin::PluginRegistry;
47use uni_store::storage::manager::StorageManager;
48
/// In-memory collection of derived facts, keyed by rule name.
///
/// Each rule name maps to the Arrow record batches that hold the facts
/// stored for that rule.
pub struct DerivedStore {
    /// Rule name -> record batches stored for that rule.
    relations: HashMap<String, Vec<RecordBatch>>,
}
60
61impl Default for DerivedStore {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
67impl DerivedStore {
68 pub fn new() -> Self {
69 Self {
70 relations: HashMap::new(),
71 }
72 }
73
74 pub fn insert(&mut self, rule_name: String, facts: Vec<RecordBatch>) {
75 self.relations.insert(rule_name, facts);
76 }
77
78 pub fn get(&self, rule_name: &str) -> Option<&Vec<RecordBatch>> {
79 self.relations.get(rule_name)
80 }
81
82 pub fn fact_count(&self, rule_name: &str) -> usize {
83 self.relations
84 .get(rule_name)
85 .map(|batches| batches.iter().map(|b| b.num_rows()).sum())
86 .unwrap_or(0)
87 }
88
89 pub fn rule_names(&self) -> impl Iterator<Item = &str> {
90 self.relations.keys().map(|s| s.as_str())
91 }
92}
93
/// DataFusion execution-plan node that evaluates a Locy program.
///
/// Holds the strata and commands of the program, the graph/session context
/// needed to run them, the evaluation limits, and a set of shared slots that
/// `run_program` fills in while (or after) the plan executes.
pub struct LocyProgramExec {
    /// Strata of rules, evaluated in order by `run_program`.
    strata: Vec<LocyStratum>,
    /// Program commands; Cypher / Validate / Calibrate commands are
    /// dispatched inline once strata evaluation has finished.
    commands: Vec<LocyCommand>,
    /// Registry of derived-scan entries that receive the facts of each rule.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    /// Plugin registry, handed to `convert_to_fixpoint_plans`.
    plugin_registry: Arc<PluginRegistry>,
    /// Graph execution context forwarded to sub-plan execution.
    graph_ctx: Arc<GraphExecutionContext>,
    /// Shared DataFusion session context; also the source of the task context.
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    /// Storage manager forwarded to sub-plan execution.
    storage: Arc<StorageManager>,
    /// Schema handed to the Cypher planner and to sub-plan execution.
    schema_info: Arc<UniSchema>,
    /// Query parameters passed to every sub-plan.
    params: HashMap<String, Value>,
    /// Arrow schema of the batches this node emits.
    output_schema: SchemaRef,
    /// Plan properties computed from `output_schema` at construction time.
    properties: Arc<PlanProperties>,
    /// DataFusion metrics set for this node.
    metrics: ExecutionPlanMetricsSet,
    /// Iteration limit passed to the fixpoint executor of recursive strata.
    max_iterations: usize,
    /// Overall time budget; each stratum receives what is left of it.
    timeout: Duration,
    /// Forwarded to the fixpoint executor (presumably a cap on the size of
    /// derived facts -- confirm against `FixpointExec`).
    max_derived_bytes: usize,
    /// Forwarded to `convert_to_fixpoint_plans`.
    deterministic_best_by: bool,
    /// Forwarded to the fixpoint executor and the post-fixpoint chain.
    strict_probability_domain: bool,
    /// Forwarded to the fixpoint executor and the post-fixpoint chain.
    probability_epsilon: f64,
    /// When set, non-recursive rules with lineage info go through
    /// `apply_exact_wmc`; also forwarded to the fixpoint executor.
    exact_probability: bool,
    /// Forwarded to `apply_exact_wmc` and the fixpoint executor.
    max_bdd_variables: usize,
    /// Semiring under which the program is evaluated.
    semiring_kind: SemiringKind,
    /// Classifier registry used by lineage recording and CALIBRATE commands.
    classifier_registry: Arc<ClassifierRegistry>,
    /// Optional model-invocation cache passed along with the classifiers.
    classifier_cache: Option<Arc<ModelInvocationCache>>,
    /// Optional neural provenance store passed along with the classifiers.
    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
    /// Receives the final `DerivedStore` when `run_program` completes.
    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
    /// Shared map handed to the fixpoint executor and `apply_exact_wmc`
    /// (presumably rules whose result was approximated -- confirm).
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    /// Optional provenance store; read once when `execute` is called.
    derivation_tracker: Arc<StdRwLock<Option<Arc<ProvenanceStore>>>>,
    /// Shared map handed to the fixpoint executor for iteration counts.
    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
    /// Receives the summed buffer memory size of all stored derived batches.
    peak_memory_slot: Arc<StdRwLock<usize>>,
    /// Runtime warnings accumulated during evaluation.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    /// Receives `(command index, result)` pairs of inline-dispatched commands.
    command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
    /// Forwarded to lineage recording, the fixpoint executor and the
    /// post-fixpoint chain.
    top_k_proofs: usize,
    /// Interruption code (see the `interruption` module); `NONE` until a
    /// timeout or iteration limit is recorded.
    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
    /// Receives a `LocyIncomplete` report when evaluation was interrupted.
    incomplete_slot: Arc<StdRwLock<Option<uni_common::LocyIncomplete>>>,
}
163
164pub(crate) mod interruption {
171 use std::sync::atomic::{AtomicU8, Ordering};
172
173 use uni_common::LocyIncompleteReason;
174
175 pub(crate) const NONE: u8 = 0;
177 pub(crate) const TIMEOUT: u8 = 1;
179 pub(crate) const ITERATION_LIMIT: u8 = 2;
181
182 pub(crate) fn reason(flag: &AtomicU8) -> Option<LocyIncompleteReason> {
184 match flag.load(Ordering::Relaxed) {
185 TIMEOUT => Some(LocyIncompleteReason::Timeout),
186 ITERATION_LIMIT => Some(LocyIncompleteReason::IterationLimit),
187 _ => None,
188 }
189 }
190
191 pub(crate) fn set(flag: &AtomicU8, code: u8) {
195 let _ = flag.compare_exchange(NONE, code, Ordering::Relaxed, Ordering::Relaxed);
196 }
197}
198
199impl fmt::Debug for LocyProgramExec {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 f.debug_struct("LocyProgramExec")
202 .field("strata_count", &self.strata.len())
203 .field("commands_count", &self.commands.len())
204 .field("max_iterations", &self.max_iterations)
205 .field("timeout", &self.timeout)
206 .field("output_schema", &self.output_schema)
207 .field("max_derived_bytes", &self.max_derived_bytes)
208 .finish_non_exhaustive()
209 }
210}
211
212impl LocyProgramExec {
213 #[expect(
214 clippy::too_many_arguments,
215 reason = "execution plan node requires full graph and session context"
216 )]
217 #[deprecated(
218 note = "use `new_with_semiring_classifiers_and_cache` (or the lighter \
219 `new_with_semiring_and_classifiers` / `new_with_semiring`) — \
220 this legacy ctor defaults the semiring to AddMultProb and \
221 ships no classifier registry. To be removed after C0 Stage 2."
222 )]
223 pub fn new(
224 strata: Vec<LocyStratum>,
225 commands: Vec<LocyCommand>,
226 derived_scan_registry: Arc<DerivedScanRegistry>,
227 plugin_registry: Arc<PluginRegistry>,
228 graph_ctx: Arc<GraphExecutionContext>,
229 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
230 storage: Arc<StorageManager>,
231 schema_info: Arc<UniSchema>,
232 params: HashMap<String, Value>,
233 output_schema: SchemaRef,
234 max_iterations: usize,
235 timeout: Duration,
236 max_derived_bytes: usize,
237 deterministic_best_by: bool,
238 strict_probability_domain: bool,
239 probability_epsilon: f64,
240 exact_probability: bool,
241 max_bdd_variables: usize,
242 top_k_proofs: usize,
243 ) -> Self {
244 Self::new_with_semiring_and_classifiers(
245 strata,
246 commands,
247 derived_scan_registry,
248 plugin_registry,
249 graph_ctx,
250 session_ctx,
251 storage,
252 schema_info,
253 params,
254 output_schema,
255 max_iterations,
256 timeout,
257 max_derived_bytes,
258 deterministic_best_by,
259 strict_probability_domain,
260 probability_epsilon,
261 exact_probability,
262 max_bdd_variables,
263 top_k_proofs,
264 SemiringKind::AddMultProb,
265 Arc::new(ClassifierRegistry::new()),
266 )
267 }
268
269 #[expect(
273 clippy::too_many_arguments,
274 reason = "execution plan node requires full graph and session context"
275 )]
276 pub fn new_with_semiring(
277 strata: Vec<LocyStratum>,
278 commands: Vec<LocyCommand>,
279 derived_scan_registry: Arc<DerivedScanRegistry>,
280 plugin_registry: Arc<PluginRegistry>,
281 graph_ctx: Arc<GraphExecutionContext>,
282 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
283 storage: Arc<StorageManager>,
284 schema_info: Arc<UniSchema>,
285 params: HashMap<String, Value>,
286 output_schema: SchemaRef,
287 max_iterations: usize,
288 timeout: Duration,
289 max_derived_bytes: usize,
290 deterministic_best_by: bool,
291 strict_probability_domain: bool,
292 probability_epsilon: f64,
293 exact_probability: bool,
294 max_bdd_variables: usize,
295 top_k_proofs: usize,
296 semiring_kind: SemiringKind,
297 ) -> Self {
298 Self::new_with_semiring_and_classifiers(
299 strata,
300 commands,
301 derived_scan_registry,
302 plugin_registry,
303 graph_ctx,
304 session_ctx,
305 storage,
306 schema_info,
307 params,
308 output_schema,
309 max_iterations,
310 timeout,
311 max_derived_bytes,
312 deterministic_best_by,
313 strict_probability_domain,
314 probability_epsilon,
315 exact_probability,
316 max_bdd_variables,
317 top_k_proofs,
318 semiring_kind,
319 Arc::new(ClassifierRegistry::new()),
320 )
321 }
322
323 #[expect(
326 clippy::too_many_arguments,
327 reason = "execution plan node requires full graph and session context"
328 )]
329 pub fn new_with_semiring_and_classifiers(
330 strata: Vec<LocyStratum>,
331 commands: Vec<LocyCommand>,
332 derived_scan_registry: Arc<DerivedScanRegistry>,
333 plugin_registry: Arc<PluginRegistry>,
334 graph_ctx: Arc<GraphExecutionContext>,
335 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
336 storage: Arc<StorageManager>,
337 schema_info: Arc<UniSchema>,
338 params: HashMap<String, Value>,
339 output_schema: SchemaRef,
340 max_iterations: usize,
341 timeout: Duration,
342 max_derived_bytes: usize,
343 deterministic_best_by: bool,
344 strict_probability_domain: bool,
345 probability_epsilon: f64,
346 exact_probability: bool,
347 max_bdd_variables: usize,
348 top_k_proofs: usize,
349 semiring_kind: SemiringKind,
350 classifier_registry: Arc<ClassifierRegistry>,
351 ) -> Self {
352 Self::new_with_semiring_classifiers_and_cache(
353 strata,
354 commands,
355 derived_scan_registry,
356 plugin_registry,
357 graph_ctx,
358 session_ctx,
359 storage,
360 schema_info,
361 params,
362 output_schema,
363 max_iterations,
364 timeout,
365 max_derived_bytes,
366 deterministic_best_by,
367 strict_probability_domain,
368 probability_epsilon,
369 exact_probability,
370 max_bdd_variables,
371 top_k_proofs,
372 semiring_kind,
373 classifier_registry,
374 None,
375 None,
376 )
377 }
378
379 #[expect(
384 clippy::too_many_arguments,
385 reason = "execution plan node requires full graph and session context"
386 )]
387 pub fn new_with_semiring_classifiers_and_cache(
388 strata: Vec<LocyStratum>,
389 commands: Vec<LocyCommand>,
390 derived_scan_registry: Arc<DerivedScanRegistry>,
391 plugin_registry: Arc<PluginRegistry>,
392 graph_ctx: Arc<GraphExecutionContext>,
393 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
394 storage: Arc<StorageManager>,
395 schema_info: Arc<UniSchema>,
396 params: HashMap<String, Value>,
397 output_schema: SchemaRef,
398 max_iterations: usize,
399 timeout: Duration,
400 max_derived_bytes: usize,
401 deterministic_best_by: bool,
402 strict_probability_domain: bool,
403 probability_epsilon: f64,
404 exact_probability: bool,
405 max_bdd_variables: usize,
406 top_k_proofs: usize,
407 semiring_kind: SemiringKind,
408 classifier_registry: Arc<ClassifierRegistry>,
409 classifier_cache: Option<Arc<ModelInvocationCache>>,
410 classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
411 ) -> Self {
412 let properties = compute_plan_properties(Arc::clone(&output_schema));
413 Self {
414 strata,
415 commands,
416 derived_scan_registry,
417 plugin_registry,
418 graph_ctx,
419 session_ctx,
420 storage,
421 schema_info,
422 params,
423 output_schema,
424 properties,
425 metrics: ExecutionPlanMetricsSet::new(),
426 max_iterations,
427 timeout,
428 max_derived_bytes,
429 deterministic_best_by,
430 strict_probability_domain,
431 probability_epsilon,
432 exact_probability,
433 max_bdd_variables,
434 semiring_kind,
435 classifier_registry,
436 classifier_cache,
437 classifier_provenance_store,
438 derived_store_slot: Arc::new(StdRwLock::new(None)),
439 approximate_slot: Arc::new(StdRwLock::new(HashMap::new())),
440 derivation_tracker: Arc::new(StdRwLock::new(None)),
441 iteration_counts_slot: Arc::new(StdRwLock::new(HashMap::new())),
442 peak_memory_slot: Arc::new(StdRwLock::new(0)),
443 warnings_slot: Arc::new(StdRwLock::new(Vec::new())),
444 command_results_slot: Arc::new(StdRwLock::new(Vec::new())),
445 top_k_proofs,
446 timeout_flag: Arc::new(std::sync::atomic::AtomicU8::new(interruption::NONE)),
447 incomplete_slot: Arc::new(StdRwLock::new(None)),
448 }
449 }
450
451 pub fn derived_store_slot(&self) -> Arc<StdRwLock<Option<DerivedStore>>> {
456 Arc::clone(&self.derived_store_slot)
457 }
458
459 pub fn set_derivation_tracker(&self, tracker: Arc<ProvenanceStore>) {
464 if let Ok(mut guard) = self.derivation_tracker.write() {
465 *guard = Some(tracker);
466 }
467 }
468
469 pub fn iteration_counts_slot(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
474 Arc::clone(&self.iteration_counts_slot)
475 }
476
477 pub fn peak_memory_slot(&self) -> Arc<StdRwLock<usize>> {
482 Arc::clone(&self.peak_memory_slot)
483 }
484
485 pub fn warnings_slot(&self) -> Arc<StdRwLock<Vec<RuntimeWarning>>> {
490 Arc::clone(&self.warnings_slot)
491 }
492
493 pub fn approximate_slot(&self) -> Arc<StdRwLock<HashMap<String, Vec<String>>>> {
498 Arc::clone(&self.approximate_slot)
499 }
500
501 pub fn command_results_slot(&self) -> Arc<StdRwLock<Vec<(usize, CommandResult)>>> {
506 Arc::clone(&self.command_results_slot)
507 }
508
509 pub fn timeout_flag(&self) -> Arc<std::sync::atomic::AtomicU8> {
515 Arc::clone(&self.timeout_flag)
516 }
517
518 pub fn incomplete_slot(&self) -> Arc<StdRwLock<Option<uni_common::LocyIncomplete>>> {
524 Arc::clone(&self.incomplete_slot)
525 }
526}
527
528impl DisplayAs for LocyProgramExec {
529 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
530 write!(
531 f,
532 "LocyProgramExec: strata={}, commands={}, max_iter={}, timeout={:?}",
533 self.strata.len(),
534 self.commands.len(),
535 self.max_iterations,
536 self.timeout,
537 )
538 }
539}
540
541impl ExecutionPlan for LocyProgramExec {
542 fn name(&self) -> &str {
543 "LocyProgramExec"
544 }
545
546 fn as_any(&self) -> &dyn Any {
547 self
548 }
549
550 fn schema(&self) -> SchemaRef {
551 Arc::clone(&self.output_schema)
552 }
553
554 fn properties(&self) -> &Arc<PlanProperties> {
555 &self.properties
556 }
557
558 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
559 vec![]
560 }
561
562 fn with_new_children(
563 self: Arc<Self>,
564 children: Vec<Arc<dyn ExecutionPlan>>,
565 ) -> DFResult<Arc<dyn ExecutionPlan>> {
566 if !children.is_empty() {
567 return Err(datafusion::error::DataFusionError::Plan(
568 "LocyProgramExec has no children".to_string(),
569 ));
570 }
571 Ok(self)
572 }
573
574 fn execute(
575 &self,
576 partition: usize,
577 _context: Arc<TaskContext>,
578 ) -> DFResult<SendableRecordBatchStream> {
579 let metrics = BaselineMetrics::new(&self.metrics, partition);
580
581 let strata = self.strata.clone();
582 let registry = Arc::clone(&self.derived_scan_registry);
583 let plugin_registry = Arc::clone(&self.plugin_registry);
584 let graph_ctx = Arc::clone(&self.graph_ctx);
585 let session_ctx = Arc::clone(&self.session_ctx);
586 let storage = Arc::clone(&self.storage);
587 let schema_info = Arc::clone(&self.schema_info);
588 let params = self.params.clone();
589 let output_schema = Arc::clone(&self.output_schema);
590 let max_iterations = self.max_iterations;
591 let timeout = self.timeout;
592 let max_derived_bytes = self.max_derived_bytes;
593 let deterministic_best_by = self.deterministic_best_by;
594 let strict_probability_domain = self.strict_probability_domain;
595 let probability_epsilon = self.probability_epsilon;
596 let exact_probability = self.exact_probability;
597 let max_bdd_variables = self.max_bdd_variables;
598 let derived_store_slot = Arc::clone(&self.derived_store_slot);
599 let approximate_slot = Arc::clone(&self.approximate_slot);
600 let iteration_counts_slot = Arc::clone(&self.iteration_counts_slot);
601 let peak_memory_slot = Arc::clone(&self.peak_memory_slot);
602 let derivation_tracker = self.derivation_tracker.read().ok().and_then(|g| g.clone());
603 let warnings_slot = Arc::clone(&self.warnings_slot);
604 let commands = self.commands.clone();
605 let command_results_slot = Arc::clone(&self.command_results_slot);
606 let top_k_proofs = self.top_k_proofs;
607 let timeout_flag = Arc::clone(&self.timeout_flag);
608 let incomplete_slot = Arc::clone(&self.incomplete_slot);
609 let semiring_kind = self.semiring_kind;
610 let classifier_registry = Arc::clone(&self.classifier_registry);
611 let classifier_cache = self.classifier_cache.as_ref().map(Arc::clone);
612 let classifier_provenance_store = self.classifier_provenance_store.as_ref().map(Arc::clone);
613
614 let fut = async move {
615 run_program(
616 strata,
617 commands,
618 registry,
619 plugin_registry,
620 graph_ctx,
621 session_ctx,
622 storage,
623 schema_info,
624 params,
625 output_schema,
626 max_iterations,
627 timeout,
628 max_derived_bytes,
629 deterministic_best_by,
630 strict_probability_domain,
631 probability_epsilon,
632 exact_probability,
633 max_bdd_variables,
634 derived_store_slot,
635 approximate_slot,
636 iteration_counts_slot,
637 peak_memory_slot,
638 derivation_tracker,
639 warnings_slot,
640 command_results_slot,
641 top_k_proofs,
642 timeout_flag,
643 incomplete_slot,
644 semiring_kind,
645 classifier_registry,
646 classifier_cache,
647 classifier_provenance_store,
648 )
649 .await
650 };
651
652 Ok(Box::pin(ProgramStream {
653 state: ProgramStreamState::Running(Box::pin(fut)),
654 schema: Arc::clone(&self.output_schema),
655 metrics,
656 }))
657 }
658
659 fn metrics(&self) -> Option<MetricsSet> {
660 Some(self.metrics.clone_inner())
661 }
662}
663
/// State machine driving [`ProgramStream`].
enum ProgramStreamState {
    /// The program future is still running; it resolves to the batches to emit.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// Batches are being handed out; the `usize` is the index of the next
    /// batch to yield.
    Emitting(Vec<RecordBatch>, usize),
    /// Terminal state: every further poll yields `None`.
    Done,
}
673
/// Record-batch stream returned by `LocyProgramExec::execute`.
struct ProgramStream {
    /// Current position in the run -> emit -> done life cycle.
    state: ProgramStreamState,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Baseline metrics: compute time per poll and output row counts.
    metrics: BaselineMetrics,
}
679
680impl Stream for ProgramStream {
681 type Item = DFResult<RecordBatch>;
682
683 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
684 let this = self.get_mut();
685 let metrics = this.metrics.clone();
686 let _timer = metrics.elapsed_compute().timer();
687 loop {
688 match &mut this.state {
689 ProgramStreamState::Running(fut) => match fut.as_mut().poll(cx) {
690 Poll::Ready(Ok(batches)) => {
691 if batches.is_empty() {
692 this.state = ProgramStreamState::Done;
693 return Poll::Ready(None);
694 }
695 this.state = ProgramStreamState::Emitting(batches, 0);
696 }
697 Poll::Ready(Err(e)) => {
698 this.state = ProgramStreamState::Done;
699 return Poll::Ready(Some(Err(e)));
700 }
701 Poll::Pending => return Poll::Pending,
702 },
703 ProgramStreamState::Emitting(batches, idx) => {
704 if *idx >= batches.len() {
705 this.state = ProgramStreamState::Done;
706 return Poll::Ready(None);
707 }
708 let batch = batches[*idx].clone();
709 *idx += 1;
710 this.metrics.record_output(batch.num_rows());
711 return Poll::Ready(Some(Ok(batch)));
712 }
713 ProgramStreamState::Done => return Poll::Ready(None),
714 }
715 }
716 }
717}
718
719impl RecordBatchStream for ProgramStream {
720 fn schema(&self) -> SchemaRef {
721 Arc::clone(&self.schema)
722 }
723}
724
725async fn execute_cypher_inline(
731 query: &uni_cypher::ast::Query,
732 schema_info: &Arc<UniSchema>,
733 params: &HashMap<String, Value>,
734 graph_ctx: &Arc<GraphExecutionContext>,
735 session_ctx: &Arc<RwLock<datafusion::prelude::SessionContext>>,
736 storage: &Arc<StorageManager>,
737) -> DFResult<Vec<FactRow>> {
738 let planner = crate::query::planner::QueryPlanner::new(Arc::clone(schema_info));
739 let logical_plan = planner.plan(query.clone()).map_err(|e| {
740 datafusion::error::DataFusionError::Execution(format!("Cypher plan error: {e}"))
741 })?;
742 let batches = execute_subplan(
743 &logical_plan,
744 params,
745 &HashMap::new(),
746 graph_ctx,
747 session_ctx,
748 storage,
749 schema_info,
750 None, )
752 .await?;
753 Ok(super::locy_eval::record_batches_to_locy_rows(&batches))
754}
755
756#[expect(
761 clippy::too_many_arguments,
762 reason = "program evaluation requires full graph and session context"
763)]
764async fn run_program(
765 strata: Vec<LocyStratum>,
766 commands: Vec<LocyCommand>,
767 registry: Arc<DerivedScanRegistry>,
768 plugin_registry: Arc<PluginRegistry>,
769 graph_ctx: Arc<GraphExecutionContext>,
770 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
771 storage: Arc<StorageManager>,
772 schema_info: Arc<UniSchema>,
773 params: HashMap<String, Value>,
774 output_schema: SchemaRef,
775 max_iterations: usize,
776 timeout: Duration,
777 max_derived_bytes: usize,
778 deterministic_best_by: bool,
779 strict_probability_domain: bool,
780 probability_epsilon: f64,
781 exact_probability: bool,
782 max_bdd_variables: usize,
783 derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
784 approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
785 iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
786 peak_memory_slot: Arc<StdRwLock<usize>>,
787 derivation_tracker: Option<Arc<ProvenanceStore>>,
788 warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
789 command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
790 top_k_proofs: usize,
791 timeout_flag: Arc<std::sync::atomic::AtomicU8>,
792 incomplete_slot: Arc<StdRwLock<Option<uni_common::LocyIncomplete>>>,
793 semiring_kind: SemiringKind,
794 classifier_registry: Arc<ClassifierRegistry>,
795 classifier_cache: Option<Arc<ModelInvocationCache>>,
796 classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
797) -> DFResult<Vec<RecordBatch>> {
798 let start = Instant::now();
799 let mut derived_store = DerivedStore::new();
800
801 if semiring_kind == SemiringKind::MaxMinProb {
806 let mut warnings = warnings_slot.write().unwrap_or_else(|e| e.into_inner());
807 let mut already: std::collections::HashSet<String> = warnings
808 .iter()
809 .filter(|w| w.code == uni_locy::RuntimeWarningCode::FuzzyNotProbabilistic)
810 .map(|w| w.rule_name.clone())
811 .collect();
812 for stratum in &strata {
813 for rule in &stratum.rules {
814 let has_prob = rule.yield_schema.iter().any(|c| c.is_prob);
815 if has_prob && !already.contains(&rule.name) {
816 warnings.push(RuntimeWarning {
817 code: uni_locy::RuntimeWarningCode::FuzzyNotProbabilistic,
818 message: format!(
819 "rule '{}' carries a PROB column but is being evaluated under \
820 the MaxMinProb (fuzzy / Viterbi) semiring; outputs are fuzzy \
821 truth values, not probabilities",
822 rule.name
823 ),
824 rule_name: rule.name.clone(),
825 variable_count: None,
826 key_group: None,
827 });
828 already.insert(rule.name.clone());
829 }
830 }
831 }
832 }
833
834 let total_strata = strata.len();
838 let mut completed_strata = 0usize;
839 let mut partial_stratum: Option<usize> = None;
840 for (stratum_idx, stratum) in strata.iter().enumerate() {
841 write_cross_stratum_facts(®istry, &derived_store, stratum);
843
844 let remaining_timeout = timeout.saturating_sub(start.elapsed());
845 if remaining_timeout.is_zero() {
846 tracing::warn!("Locy program timeout exceeded during stratum evaluation");
847 interruption::set(&timeout_flag, interruption::TIMEOUT);
848 break;
849 }
850
851 if stratum.is_recursive {
852 let fixpoint_rules = convert_to_fixpoint_plans(
854 &stratum.rules,
855 ®istry,
856 &plugin_registry,
857 deterministic_best_by,
858 )?;
859 let fixpoint_schema = build_fixpoint_output_schema(&stratum.rules);
860
861 let exec = FixpointExec::new_with_semiring_classifiers_and_cache(
862 fixpoint_rules,
863 max_iterations,
864 remaining_timeout,
865 Arc::clone(&graph_ctx),
866 Arc::clone(&session_ctx),
867 Arc::clone(&storage),
868 Arc::clone(&schema_info),
869 params.clone(),
870 Arc::clone(®istry),
871 fixpoint_schema,
872 max_derived_bytes,
873 derivation_tracker.clone(),
874 Arc::clone(&iteration_counts_slot),
875 strict_probability_domain,
876 probability_epsilon,
877 exact_probability,
878 max_bdd_variables,
879 Arc::clone(&warnings_slot),
880 Arc::clone(&approximate_slot),
881 top_k_proofs,
882 Arc::clone(&timeout_flag),
883 semiring_kind,
884 Arc::clone(&classifier_registry),
885 classifier_cache.as_ref().map(Arc::clone),
886 classifier_provenance_store.as_ref().map(Arc::clone),
887 );
888
889 let task_ctx = session_ctx.read().task_ctx();
890 let exec_arc: Arc<dyn ExecutionPlan> = Arc::new(exec);
891 let batches = collect_all_partitions(&exec_arc, task_ctx).await?;
892
893 for rule in &stratum.rules {
904 if rule.yield_schema.is_empty() {
906 continue;
907 }
908 let rule_entries = registry.entries_for_rule(&rule.name);
910 for entry in rule_entries {
911 if !entry.is_self_ref {
912 let all_facts: Vec<RecordBatch> = batches
916 .iter()
917 .filter(|b| {
918 let rule_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
920 b.schema().fields().len() == rule_schema.fields().len()
921 })
922 .cloned()
923 .collect();
924 let mut guard = entry.data.write();
925 *guard = if all_facts.is_empty() {
926 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
927 } else {
928 all_facts
929 };
930 }
931 }
932 derived_store.insert(rule.name.clone(), batches.clone());
933 }
934 } else {
935 let fixpoint_rules = convert_to_fixpoint_plans(
937 &stratum.rules,
938 ®istry,
939 &plugin_registry,
940 deterministic_best_by,
941 )?;
942 let task_ctx = session_ctx.read().task_ctx();
943
944 for (rule, fp_rule) in stratum.rules.iter().zip(fixpoint_rules.iter()) {
945 if rule.yield_schema.is_empty() {
950 continue;
951 }
952
953 let mut tagged_clause_facts: Vec<(usize, Vec<RecordBatch>)> = Vec::new();
955 for (clause_idx, (clause, fp_clause)) in
956 rule.clauses.iter().zip(fp_rule.clauses.iter()).enumerate()
957 {
958 let mut batches = execute_subplan(
962 &clause.body,
963 ¶ms,
964 &HashMap::new(),
965 &graph_ctx,
966 &session_ctx,
967 &storage,
968 &schema_info,
969 None, )
971 .await?;
972
973 for binding in &fp_clause.is_ref_bindings {
975 if binding.negated
976 && !binding.anti_join_cols.is_empty()
977 && let Some(entry) = registry.get(binding.derived_scan_index)
978 {
979 let neg_facts = entry.data.read().clone();
980 if !neg_facts.is_empty() {
981 if binding.target_has_prob && fp_rule.prob_column_name.is_some() {
982 let complement_col =
983 format!("__prob_complement_{}", binding.rule_name);
984 if let Some(prob_col) = &binding.target_prob_col {
985 batches =
986 super::locy_fixpoint::apply_prob_complement_composite(
987 batches,
988 &neg_facts,
989 &binding.anti_join_cols,
990 prob_col,
991 &complement_col,
992 )?;
993 } else {
994 batches = super::locy_fixpoint::apply_anti_join_composite(
996 batches,
997 &neg_facts,
998 &binding.anti_join_cols,
999 )?;
1000 }
1001 } else {
1002 batches = super::locy_fixpoint::apply_anti_join_composite(
1003 batches,
1004 &neg_facts,
1005 &binding.anti_join_cols,
1006 )?;
1007 }
1008 }
1009 }
1010 }
1011
1012 let complement_cols: Vec<String> = if !batches.is_empty() {
1014 batches[0]
1015 .schema()
1016 .fields()
1017 .iter()
1018 .filter(|f| f.name().starts_with("__prob_complement_"))
1019 .map(|f| f.name().clone())
1020 .collect()
1021 } else {
1022 vec![]
1023 };
1024 if !complement_cols.is_empty() {
1025 batches = super::locy_fixpoint::multiply_prob_factors(
1026 batches,
1027 fp_rule.prob_column_name.as_deref(),
1028 &complement_cols,
1029 )?;
1030 }
1031
1032 tagged_clause_facts.push((clause_idx, batches));
1033 }
1034
1035 let shared_info = if semiring_kind == SemiringKind::MaxMinProb {
1048 None
1049 } else if let Some(ref tracker) = derivation_tracker {
1050 super::locy_fixpoint::record_and_detect_lineage_nonrecursive(
1051 fp_rule,
1052 &tagged_clause_facts,
1053 tracker,
1054 &warnings_slot,
1055 ®istry,
1056 top_k_proofs,
1057 super::locy_fixpoint::ClassifierRefs {
1058 registry: &classifier_registry,
1059 cache: classifier_cache.as_ref(),
1060 provenance_store: classifier_provenance_store.as_ref(),
1061 },
1062 semiring_kind,
1063 )
1064 .await
1065 } else {
1066 None
1067 };
1068
1069 let mut all_clause_facts: Vec<RecordBatch> = tagged_clause_facts
1071 .into_iter()
1072 .flat_map(|(_, batches)| batches)
1073 .collect();
1074
1075 if exact_probability
1077 && let Some(ref info) = shared_info
1078 && let Some(ref tracker) = derivation_tracker
1079 {
1080 all_clause_facts = super::locy_fixpoint::apply_exact_wmc(
1081 all_clause_facts,
1082 fp_rule,
1083 info,
1084 tracker,
1085 max_bdd_variables,
1086 &warnings_slot,
1087 &approximate_slot,
1088 )?;
1089 }
1090
1091 let facts = super::locy_fixpoint::apply_post_fixpoint_chain(
1093 all_clause_facts,
1094 fp_rule,
1095 &task_ctx,
1096 strict_probability_domain,
1097 probability_epsilon,
1098 semiring_kind,
1099 derivation_tracker.as_ref().map(Arc::clone),
1100 top_k_proofs,
1101 Some(Arc::clone(®istry)),
1102 )
1103 .await?;
1104
1105 write_facts_to_registry(®istry, &rule.name, &facts);
1107 derived_store.insert(rule.name.clone(), facts);
1108 }
1109 }
1110
1111 if interruption::reason(&timeout_flag).is_some() {
1115 partial_stratum = Some(stratum_idx);
1116 break;
1117 }
1118 completed_strata += 1;
1119 }
1120
1121 if let Some(reason) = interruption::reason(&timeout_flag) {
1125 let skipped_start = match partial_stratum {
1126 Some(i) => i + 1,
1127 None => completed_strata,
1128 };
1129 let incomplete_rules: Vec<String> = partial_stratum
1130 .map(|i| strata[i].rules.iter().map(|r| r.name.clone()).collect())
1131 .unwrap_or_default();
1132 let skipped_rules: Vec<String> = strata[skipped_start..]
1133 .iter()
1134 .flat_map(|s| s.rules.iter().map(|r| r.name.clone()))
1135 .collect();
1136 let mut complement_rules_affected = Vec::new();
1137 for idx in partial_stratum
1138 .into_iter()
1139 .chain(skipped_start..total_strata)
1140 {
1141 for rule in &strata[idx].rules {
1142 if rule
1143 .clauses
1144 .iter()
1145 .any(|c| c.is_refs.iter().any(|r| r.negated))
1146 {
1147 complement_rules_affected.push(rule.name.clone());
1148 }
1149 }
1150 }
1151 if let Ok(mut slot) = incomplete_slot.write() {
1152 *slot = Some(uni_common::LocyIncomplete {
1153 reason,
1154 elapsed_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1155 limit_ms: u64::try_from(timeout.as_millis()).unwrap_or(u64::MAX),
1156 max_iterations,
1157 completed_strata,
1158 total_strata,
1159 incomplete_rules,
1160 skipped_rules,
1161 complement_rules_affected,
1162 });
1163 }
1164 }
1165
1166 let peak_bytes: usize = derived_store
1168 .relations
1169 .values()
1170 .flat_map(|batches| batches.iter())
1171 .map(|b| {
1172 b.columns()
1173 .iter()
1174 .map(|col| col.get_buffer_memory_size())
1175 .sum::<usize>()
1176 })
1177 .sum();
1178 *peak_memory_slot.write().unwrap() = peak_bytes;
1179
1180 let first_derive_idx = commands
1190 .iter()
1191 .position(|c| matches!(c, LocyCommand::Derive { .. }));
1192 let mut inline_results: Vec<(usize, CommandResult)> = Vec::new();
1193 for (cmd_idx, cmd) in commands.iter().enumerate() {
1194 match cmd {
1195 LocyCommand::Cypher { query } => {
1196 if first_derive_idx.is_some_and(|di| cmd_idx > di) {
1199 continue;
1200 }
1201 let rows = execute_cypher_inline(
1202 query,
1203 &schema_info,
1204 ¶ms,
1205 &graph_ctx,
1206 &session_ctx,
1207 &storage,
1208 )
1209 .await?;
1210 inline_results.push((cmd_idx, CommandResult::Cypher(rows)));
1211 }
1212 LocyCommand::Validate { validate } => {
1213 let rule_key_cols: Vec<String> = strata
1217 .iter()
1218 .flat_map(|s| s.rules.iter())
1219 .find(|r| r.name == validate.rule_name)
1220 .map(|r| {
1221 r.yield_schema
1222 .iter()
1223 .filter(|c| c.is_key)
1224 .map(|c| c.name.clone())
1225 .collect()
1226 })
1227 .unwrap_or_default();
1228 let query =
1229 super::locy_validate::validate_collection_query(validate, &rule_key_cols);
1230 let target_rows = execute_cypher_inline(
1231 &query,
1232 &schema_info,
1233 ¶ms,
1234 &graph_ctx,
1235 &session_ctx,
1236 &storage,
1237 )
1238 .await?;
1239 let rule_facts: Vec<uni_locy::FactRow> = derived_store
1240 .get(&validate.rule_name)
1241 .map(|batches| super::locy_eval::record_batches_to_locy_rows(batches))
1242 .unwrap_or_default();
1243 let result = super::locy_validate::run_validate(
1244 validate,
1245 &rule_key_cols,
1246 &rule_facts,
1247 target_rows,
1248 )
1249 .map_err(|e| {
1250 datafusion::error::DataFusionError::Execution(format!("VALIDATE error: {e}"))
1251 })?;
1252 inline_results.push((cmd_idx, CommandResult::Validate(result)));
1253 }
1254 LocyCommand::Calibrate {
1255 calibrate,
1256 model_inputs,
1257 } => {
1258 let model_snapshot = uni_locy::CompiledModel {
1271 name: calibrate.model_name.clone(),
1272 inputs: model_inputs.clone(),
1273 features: vec![],
1274 path_context: None,
1275 output_type: uni_cypher::locy_ast::OutputType::Prob,
1276 output_name: String::new(),
1277 xervo_alias: String::new(),
1278 embedder_alias: None,
1279 calibration: None,
1280 version: None,
1281 annotations: Default::default(),
1282 };
1283 let query =
1284 super::locy_calibrate::calibrate_collection_query(calibrate, &model_snapshot);
1285 let rows = execute_cypher_inline(
1286 &query,
1287 &schema_info,
1288 ¶ms,
1289 &graph_ctx,
1290 &session_ctx,
1291 &storage,
1292 )
1293 .await?;
1294 let mut catalog = std::collections::HashMap::new();
1295 catalog.insert(calibrate.model_name.clone(), model_snapshot);
1296 let result = super::locy_calibrate::run_calibrate(
1297 calibrate,
1298 &catalog,
1299 &classifier_registry,
1300 rows,
1301 )
1302 .await
1303 .map_err(|e| {
1304 datafusion::error::DataFusionError::Execution(format!("CALIBRATE error: {e}"))
1305 })?;
1306 inline_results.push((cmd_idx, CommandResult::Calibrate(result)));
1307 }
1308 _ => {}
1309 }
1310 }
1311 *command_results_slot.write().unwrap() = inline_results;
1312
1313 let stats = vec![build_stats_batch(&derived_store, &strata, output_schema)];
1314 *derived_store_slot.write().unwrap() = Some(derived_store);
1315 Ok(stats)
1316}
1317
1318fn write_cross_stratum_facts(
1324 registry: &DerivedScanRegistry,
1325 derived_store: &DerivedStore,
1326 stratum: &LocyStratum,
1327) {
1328 for rule in &stratum.rules {
1330 for clause in &rule.clauses {
1331 for is_ref in &clause.is_refs {
1332 if let Some(facts) = derived_store.get(&is_ref.rule_name) {
1335 write_facts_to_registry(registry, &is_ref.rule_name, facts);
1336 }
1337 }
1338 }
1339 }
1340}
1341
1342fn write_facts_to_registry(registry: &DerivedScanRegistry, rule_name: &str, facts: &[RecordBatch]) {
1344 let entries = registry.entries_for_rule(rule_name);
1345 for entry in entries {
1346 if !entry.is_self_ref {
1347 let mut guard = entry.data.write();
1348 *guard = if facts.is_empty() || facts.iter().all(|b| b.num_rows() == 0) {
1349 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
1350 } else {
1351 facts
1356 .iter()
1357 .filter(|b| b.num_rows() > 0)
1358 .map(|b| {
1359 RecordBatch::try_new(Arc::clone(&entry.schema), b.columns().to_vec())
1360 .unwrap_or_else(|_| b.clone())
1361 })
1362 .collect()
1363 };
1364 }
1365 }
1366}
1367
1368fn convert_to_fixpoint_plans(
1374 rules: &[LocyRulePlan],
1375 registry: &DerivedScanRegistry,
1376 plugin_registry: &PluginRegistry,
1377 deterministic_best_by: bool,
1378) -> DFResult<Vec<FixpointRulePlan>> {
1379 let stratum_rule_names: std::collections::HashSet<&str> =
1382 rules.iter().map(|r| r.name.as_str()).collect();
1383 rules
1384 .iter()
1385 .map(|rule| {
1386 let yield_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
1387 let key_column_indices: Vec<usize> = rule
1388 .yield_schema
1389 .iter()
1390 .enumerate()
1391 .filter(|(_, yc)| yc.is_key)
1392 .map(|(i, _)| i)
1393 .collect();
1394
1395 let clauses: Vec<FixpointClausePlan> = rule
1396 .clauses
1397 .iter()
1398 .map(|clause| {
1399 let is_ref_bindings =
1400 convert_is_refs(&clause.is_refs, registry, &stratum_rule_names)?;
1401 Ok(FixpointClausePlan {
1402 body_logical: clause.body.clone(),
1403 is_ref_bindings,
1404 priority: clause.priority,
1405 along_bindings: clause.along_bindings.clone(),
1406 model_invocations: clause.model_invocations.clone(),
1407 })
1408 })
1409 .collect::<DFResult<Vec<_>>>()?;
1410
1411 let fold_bindings =
1412 convert_fold_bindings(&rule.fold_bindings, &rule.yield_schema, plugin_registry)?;
1413 let best_by_criteria =
1414 convert_best_by_criteria(&rule.best_by_criteria, &rule.yield_schema)?;
1415
1416 let has_priority = rule.priority.is_some();
1417
1418 let yield_schema = if has_priority {
1420 let mut fields: Vec<Arc<Field>> = yield_schema.fields().iter().cloned().collect();
1421 fields.push(Arc::new(Field::new("__priority", DataType::Int64, true)));
1422 ArrowSchema::new(fields)
1423 } else {
1424 yield_schema
1425 };
1426
1427 let prob_column_name = rule
1428 .yield_schema
1429 .iter()
1430 .find(|yc| yc.is_prob)
1431 .map(|yc| yc.name.clone());
1432
1433 let non_linear = rule.clauses.iter().any(|clause| {
1437 clause
1438 .is_refs
1439 .iter()
1440 .filter(|ir| !ir.negated && stratum_rule_names.contains(ir.rule_name.as_str()))
1441 .count()
1442 >= 2
1443 });
1444
1445 Ok(FixpointRulePlan {
1446 name: rule.name.clone(),
1447 clauses,
1448 yield_schema: Arc::new(yield_schema),
1449 key_column_indices,
1450 priority: rule.priority,
1451 has_fold: !rule.fold_bindings.is_empty(),
1452 fold_bindings,
1453 having: rule.having.clone(),
1454 has_best_by: !rule.best_by_criteria.is_empty(),
1455 best_by_criteria,
1456 has_priority,
1457 deterministic: deterministic_best_by,
1458 prob_column_name,
1459 non_linear,
1460 })
1461 })
1462 .collect()
1463}
1464
1465fn convert_is_refs(
1477 is_refs: &[LocyIsRef],
1478 registry: &DerivedScanRegistry,
1479 stratum_rule_names: &std::collections::HashSet<&str>,
1480) -> DFResult<Vec<IsRefBinding>> {
1481 is_refs
1482 .iter()
1483 .map(|is_ref| {
1484 let entries = registry.entries_for_rule(&is_ref.rule_name);
1485 let want_self_ref = stratum_rule_names.contains(is_ref.rule_name.as_str());
1490 let entry = entries
1491 .iter()
1492 .find(|e| e.is_self_ref == want_self_ref)
1493 .or_else(|| entries.first())
1494 .ok_or_else(|| {
1495 datafusion::error::DataFusionError::Plan(format!(
1496 "No derived scan entry found for IS-ref to '{}'",
1497 is_ref.rule_name
1498 ))
1499 })?;
1500
1501 let anti_join_cols = if is_ref.negated {
1506 let mut cols: Vec<(String, String)> = is_ref
1507 .subjects
1508 .iter()
1509 .enumerate()
1510 .filter_map(|(i, s)| {
1511 if let uni_cypher::ast::Expr::Variable(var) = s {
1512 let right_col = entry
1513 .schema
1514 .fields()
1515 .get(i)
1516 .map(|f| f.name().clone())
1517 .unwrap_or_else(|| var.clone());
1518 Some((var.clone(), right_col))
1521 } else {
1522 None
1523 }
1524 })
1525 .collect();
1526 if let Some(uni_cypher::ast::Expr::Variable(target_var)) = &is_ref.target {
1531 let target_idx = is_ref.subjects.len();
1532 if let Some(field) = entry.schema.fields().get(target_idx) {
1533 cols.push((target_var.clone(), field.name().clone()));
1534 }
1535 }
1536 cols
1537 } else {
1538 Vec::new()
1539 };
1540
1541 let provenance_join_cols: Vec<(String, String)> = is_ref
1545 .subjects
1546 .iter()
1547 .enumerate()
1548 .filter_map(|(i, s)| {
1549 if let uni_cypher::ast::Expr::Variable(var) = s {
1550 let right_col = entry
1551 .schema
1552 .fields()
1553 .get(i)
1554 .map(|f| f.name().clone())
1555 .unwrap_or_else(|| var.clone());
1556 Some((var.clone(), right_col))
1557 } else {
1558 None
1559 }
1560 })
1561 .collect();
1562
1563 Ok(IsRefBinding {
1564 derived_scan_index: entry.scan_index,
1565 rule_name: is_ref.rule_name.clone(),
1566 is_self_ref: entry.is_self_ref,
1567 negated: is_ref.negated,
1568 anti_join_cols,
1569 target_has_prob: is_ref.target_has_prob,
1570 target_prob_col: is_ref.target_prob_col.clone(),
1571 provenance_join_cols,
1572 })
1573 })
1574 .collect()
1575}
1576
1577fn convert_fold_bindings(
1585 fold_bindings: &[(String, String, Expr)],
1586 yield_schema: &[LocyYieldColumn],
1587 plugin_registry: &PluginRegistry,
1588) -> DFResult<Vec<FoldBinding>> {
1589 fold_bindings
1590 .iter()
1591 .map(|(name, yield_alias, expr)| {
1592 let (agg_name, _input_col_name) = parse_fold_aggregate(expr)?;
1593 let entry =
1594 resolve_locy_aggregate(plugin_registry, agg_name.as_str()).ok_or_else(|| {
1595 datafusion::error::DataFusionError::Plan(format!(
1596 "Unknown Locy aggregate '{agg_name}' — not registered in plugin registry"
1597 ))
1598 })?;
1599 let aggregate = Arc::clone(&entry.aggregate);
1600
1601 if agg_name.as_str() == "COUNTALL" {
1604 return Ok(FoldBinding {
1605 output_name: yield_alias.clone(),
1606 name: agg_name,
1607 aggregate,
1608 input_col_index: 0, input_col_name: None,
1610 });
1611 }
1612
1613 let input_col_index = yield_schema
1618 .iter()
1619 .position(|yc| yc.name == *name || yc.name == *yield_alias)
1620 .unwrap_or(0);
1621 Ok(FoldBinding {
1622 output_name: yield_alias.clone(),
1623 name: agg_name,
1624 aggregate,
1625 input_col_index,
1626 input_col_name: Some(name.clone()),
1627 })
1628 })
1629 .collect()
1630}
1631
1632fn parse_fold_aggregate(expr: &Expr) -> DFResult<(smol_str::SmolStr, String)> {
1638 match expr {
1639 Expr::FunctionCall { name, args, .. } => {
1640 let upper = name.to_uppercase();
1641 let is_count = matches!(upper.as_str(), "COUNT" | "MCOUNT");
1642
1643 if is_count && args.is_empty() {
1645 return Ok((smol_str::SmolStr::new_static("COUNTALL"), String::new()));
1646 }
1647
1648 let canonical = match upper.as_str() {
1649 "SUM" | "MSUM" => smol_str::SmolStr::new_static("SUM"),
1650 "MAX" | "MMAX" => smol_str::SmolStr::new_static("MAX"),
1651 "MIN" | "MMIN" => smol_str::SmolStr::new_static("MIN"),
1652 "COUNT" | "MCOUNT" => smol_str::SmolStr::new_static("COUNT"),
1653 "AVG" => smol_str::SmolStr::new_static("AVG"),
1654 "COLLECT" => smol_str::SmolStr::new_static("COLLECT"),
1655 "MNOR" => smol_str::SmolStr::new_static("MNOR"),
1656 "MPROD" => smol_str::SmolStr::new_static("MPROD"),
1657 _ => {
1658 return Err(datafusion::error::DataFusionError::Plan(format!(
1659 "Unknown FOLD aggregate function: {}",
1660 name
1661 )));
1662 }
1663 };
1664 let col_name = match args.first() {
1665 Some(Expr::Variable(v)) => v.clone(),
1666 Some(Expr::Property(_, prop)) => prop.clone(),
1667 Some(other) => other.to_string_repr(),
1668 None => {
1669 return Err(datafusion::error::DataFusionError::Plan(
1670 "FOLD aggregate function requires at least one argument".to_string(),
1671 ));
1672 }
1673 };
1674 Ok((canonical, col_name))
1675 }
1676 _ => Err(datafusion::error::DataFusionError::Plan(
1677 "FOLD binding must be a function call (e.g., SUM(x))".to_string(),
1678 )),
1679 }
1680}
1681
1682fn convert_best_by_criteria(
1689 criteria: &[(Expr, bool)],
1690 yield_schema: &[LocyYieldColumn],
1691) -> DFResult<Vec<SortCriterion>> {
1692 criteria
1693 .iter()
1694 .map(|(expr, ascending)| {
1695 let col_name = match expr {
1696 Expr::Property(_, prop) => prop.clone(),
1697 Expr::Variable(v) => v.clone(),
1698 _ => {
1699 return Err(datafusion::error::DataFusionError::Plan(
1700 "BEST BY criterion must be a variable or property reference".to_string(),
1701 ));
1702 }
1703 };
1704 let col_index = yield_schema
1706 .iter()
1707 .position(|yc| yc.name == col_name)
1708 .or_else(|| {
1709 let short_name = col_name.rsplit('.').next().unwrap_or(&col_name);
1710 yield_schema.iter().position(|yc| yc.name == short_name)
1711 })
1712 .ok_or_else(|| {
1713 datafusion::error::DataFusionError::Plan(format!(
1714 "BEST BY column '{}' not found in yield schema",
1715 col_name
1716 ))
1717 })?;
1718 Ok(SortCriterion {
1719 col_index,
1720 ascending: *ascending,
1721 nulls_first: false,
1722 })
1723 })
1724 .collect()
1725}
1726
1727fn yield_columns_to_arrow_schema(columns: &[LocyYieldColumn]) -> ArrowSchema {
1733 let fields: Vec<Arc<Field>> = columns
1734 .iter()
1735 .map(|yc| Arc::new(Field::new(&yc.name, yc.data_type.clone(), true)))
1736 .collect();
1737 ArrowSchema::new(fields)
1738}
1739
1740fn build_fixpoint_output_schema(rules: &[LocyRulePlan]) -> SchemaRef {
1742 if let Some(rule) = rules.first() {
1745 Arc::new(yield_columns_to_arrow_schema(&rule.yield_schema))
1746 } else {
1747 Arc::new(ArrowSchema::empty())
1748 }
1749}
1750
1751fn build_stats_batch(
1753 derived_store: &DerivedStore,
1754 _strata: &[LocyStratum],
1755 output_schema: SchemaRef,
1756) -> RecordBatch {
1757 let mut rule_names: Vec<String> = derived_store.rule_names().map(String::from).collect();
1759 rule_names.sort();
1760
1761 let name_col: arrow_array::StringArray = rule_names.iter().map(|s| Some(s.as_str())).collect();
1762 let count_col: arrow_array::Int64Array = rule_names
1763 .iter()
1764 .map(|name| Some(derived_store.fact_count(name) as i64))
1765 .collect();
1766
1767 let stats_schema = stats_schema();
1768 RecordBatch::try_new(stats_schema, vec![Arc::new(name_col), Arc::new(count_col)])
1769 .unwrap_or_else(|_| RecordBatch::new_empty(output_schema))
1770}
1771
1772pub fn stats_schema() -> SchemaRef {
1774 Arc::new(ArrowSchema::new(vec![
1775 Arc::new(Field::new("rule_name", DataType::Utf8, false)),
1776 Arc::new(Field::new("fact_count", DataType::Int64, false)),
1777 ]))
1778}
1779
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, LargeBinaryArray, StringArray};

    /// Builds a batch with a single nullable `LargeBinary` column `x` whose
    /// rows are the UTF-8 bytes of `values`.
    fn binary_batch(values: &[&str]) -> RecordBatch {
        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
            "x",
            DataType::LargeBinary,
            true,
        ))]));
        let cells: Vec<Option<&[u8]>> = values.iter().map(|v| Some(v.as_bytes())).collect();
        RecordBatch::try_new(schema, vec![Arc::new(LargeBinaryArray::from(cells))]).unwrap()
    }

    /// Builds a non-probabilistic yield column.
    fn yield_col(name: &str, is_key: bool, data_type: DataType) -> LocyYieldColumn {
        LocyYieldColumn {
            name: name.to_string(),
            is_key,
            is_prob: false,
            data_type,
        }
    }

    /// Builds `function(argument)` as a plain, non-distinct call on a variable.
    fn fold_call(function: &str, argument: &str) -> Expr {
        Expr::FunctionCall {
            name: function.to_string(),
            args: vec![Expr::Variable(argument.to_string())],
            distinct: false,
            window_spec: None,
        }
    }

    #[test]
    fn test_derived_store_insert_and_get() {
        let mut store = DerivedStore::new();
        assert!(store.get("test").is_none());

        let batch = binary_batch(&["a", "b"]);
        store.insert("test".to_string(), vec![batch.clone()]);

        let facts = store.get("test").unwrap();
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].num_rows(), 2);
    }

    #[test]
    fn test_derived_store_fact_count() {
        let mut store = DerivedStore::new();
        assert_eq!(store.fact_count("empty"), 0);

        let one_row = binary_batch(&["a"]);
        let two_rows = binary_batch(&["b", "c"]);

        store.insert("test".to_string(), vec![one_row, two_rows]);
        assert_eq!(store.fact_count("test"), 3);
    }

    #[test]
    fn test_stats_batch_schema() {
        let schema = stats_schema();
        assert_eq!(schema.fields().len(), 2);

        let rule_name = schema.field(0);
        assert_eq!(rule_name.name(), "rule_name");
        assert_eq!(rule_name.data_type(), &DataType::Utf8);

        let fact_count = schema.field(1);
        assert_eq!(fact_count.name(), "fact_count");
        assert_eq!(fact_count.data_type(), &DataType::Int64);
    }

    #[test]
    fn test_stats_batch_content() {
        let mut store = DerivedStore::new();
        store.insert("reach".to_string(), vec![binary_batch(&["a", "b"])]);

        let output_schema = stats_schema();
        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
        assert_eq!(stats.num_rows(), 1);

        let names = stats
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(names.value(0), "reach");

        let counts = stats
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(counts.value(0), 2);
    }

    #[test]
    fn test_yield_columns_to_arrow_schema() {
        let columns = vec![
            yield_col("a", true, DataType::UInt64),
            yield_col("b", false, DataType::LargeUtf8),
            yield_col("c", true, DataType::Float64),
        ];

        let schema = yield_columns_to_arrow_schema(&columns);
        assert_eq!(schema.fields().len(), 3);

        let names: Vec<&str> = schema.fields().iter().map(|f| f.name().as_str()).collect();
        assert_eq!(names, ["a", "b", "c"]);

        let types: Vec<&DataType> = schema.fields().iter().map(|f| f.data_type()).collect();
        assert_eq!(
            types,
            [&DataType::UInt64, &DataType::LargeUtf8, &DataType::Float64]
        );

        assert!(schema.fields().iter().all(|f| f.is_nullable()));
    }

    #[test]
    fn test_key_column_indices() {
        let columns = [
            yield_col("a", true, DataType::LargeBinary),
            yield_col("b", false, DataType::LargeBinary),
            yield_col("c", true, DataType::LargeBinary),
        ];

        let mut key_indices: Vec<usize> = Vec::new();
        for (idx, column) in columns.iter().enumerate() {
            if column.is_key {
                key_indices.push(idx);
            }
        }
        assert_eq!(key_indices, vec![0, 2]);
    }

    #[test]
    fn test_parse_fold_aggregate_sum() {
        let (kind, col) = parse_fold_aggregate(&fold_call("SUM", "cost")).unwrap();
        assert_eq!(kind.as_str(), "SUM");
        assert_eq!(col, "cost");
    }

    #[test]
    fn test_parse_fold_aggregate_monotonic() {
        let (kind, col) = parse_fold_aggregate(&fold_call("MMAX", "score")).unwrap();
        assert_eq!(kind.as_str(), "MAX");
        assert_eq!(col, "score");
    }

    #[test]
    fn test_parse_fold_aggregate_unknown() {
        let outcome = parse_fold_aggregate(&fold_call("UNKNOWN_AGG", "x"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_no_commands_returns_stats() {
        let store = DerivedStore::new();
        let output_schema = stats_schema();
        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
        assert_eq!(stats.num_rows(), 0);
    }
}