// uni_query/query/df_graph/locy_program.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Top-level Locy program executor.
5//!
6//! `LocyProgramExec` orchestrates the full evaluation of a Locy program:
7//! it evaluates strata in dependency order, runs fixpoint for recursive strata,
8//! applies post-fixpoint operators (FOLD, PRIORITY, BEST BY), and then
9//! executes the program's commands (goal queries, DERIVE, ASSUME, etc.).
10
11use crate::query::df_graph::GraphExecutionContext;
12use crate::query::df_graph::common::{
13    collect_all_partitions, compute_plan_properties, execute_subplan,
14};
15use crate::query::df_graph::locy_best_by::SortCriterion;
16use crate::query::df_graph::locy_explain::DerivationTracker;
17use crate::query::df_graph::locy_fixpoint::{
18    DerivedScanRegistry, FixpointClausePlan, FixpointExec, FixpointRulePlan, IsRefBinding,
19};
20use crate::query::df_graph::locy_fold::{FoldAggKind, FoldBinding};
21use crate::query::planner_locy_types::{
22    LocyCommand, LocyIsRef, LocyRulePlan, LocyStratum, LocyYieldColumn,
23};
24use arrow_array::RecordBatch;
25use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
29use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
30use futures::Stream;
31use parking_lot::RwLock;
32use std::any::Any;
33use std::collections::HashMap;
34use std::fmt;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::sync::RwLock as StdRwLock;
38use std::task::{Context, Poll};
39use std::time::{Duration, Instant};
40use uni_common::Value;
41use uni_common::core::schema::Schema as UniSchema;
42use uni_cypher::ast::Expr;
43use uni_store::storage::manager::StorageManager;
44
45// ---------------------------------------------------------------------------
46// DerivedStore — cross-stratum fact sharing
47// ---------------------------------------------------------------------------
48
/// Simple store for derived relation facts across strata.
///
/// Each rule's converged facts are stored here after its stratum completes,
/// making them available for later strata that depend on them.
pub struct DerivedStore {
    /// Rule name → converged fact batches for that rule.
    relations: HashMap<String, Vec<RecordBatch>>,
}
56
57impl Default for DerivedStore {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63impl DerivedStore {
64    pub fn new() -> Self {
65        Self {
66            relations: HashMap::new(),
67        }
68    }
69
70    pub fn insert(&mut self, rule_name: String, facts: Vec<RecordBatch>) {
71        self.relations.insert(rule_name, facts);
72    }
73
74    pub fn get(&self, rule_name: &str) -> Option<&Vec<RecordBatch>> {
75        self.relations.get(rule_name)
76    }
77
78    pub fn fact_count(&self, rule_name: &str) -> usize {
79        self.relations
80            .get(rule_name)
81            .map(|batches| batches.iter().map(|b| b.num_rows()).sum())
82            .unwrap_or(0)
83    }
84
85    pub fn rule_names(&self) -> impl Iterator<Item = &str> {
86        self.relations.keys().map(|s| s.as_str())
87    }
88}
89
90// ---------------------------------------------------------------------------
91// LocyProgramExec — DataFusion ExecutionPlan
92// ---------------------------------------------------------------------------
93
/// DataFusion `ExecutionPlan` that runs an entire Locy program.
///
/// Evaluates strata in dependency order, using `FixpointExec` for recursive
/// strata and direct subplan execution for non-recursive ones. After all
/// strata converge, dispatches commands.
pub struct LocyProgramExec {
    /// Strata in topological (dependency) order.
    strata: Vec<LocyStratum>,
    /// Program commands; dispatched by the caller after strata evaluation,
    /// not by this plan itself (see `run_program`).
    commands: Vec<LocyCommand>,
    /// Registry of derived-scan handles used to share facts between strata.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    /// Query parameters by name.
    params: HashMap<String, Value>,
    /// Schema of the stats batch this plan emits.
    output_schema: SchemaRef,
    properties: PlanProperties,
    metrics: ExecutionPlanMetricsSet,
    /// Fixpoint iteration cap for recursive strata.
    max_iterations: usize,
    /// Wall-clock budget for the whole program evaluation.
    timeout: Duration,
    /// Byte budget for derived facts, forwarded to `FixpointExec`.
    max_derived_bytes: usize,
    /// Forwarded to fixpoint plans as their `deterministic` flag.
    deterministic_best_by: bool,
    /// Shared slot for extracting the DerivedStore after execution completes.
    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
    /// Optional provenance tracker injected after construction (via `set_derivation_tracker`).
    derivation_tracker: Arc<StdRwLock<Option<Arc<DerivationTracker>>>>,
    /// Shared slot written with per-rule iteration counts after fixpoint convergence.
    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
    /// Shared slot written with peak memory bytes after fixpoint completes.
    peak_memory_slot: Arc<StdRwLock<usize>>,
}
124
125impl fmt::Debug for LocyProgramExec {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        f.debug_struct("LocyProgramExec")
128            .field("strata_count", &self.strata.len())
129            .field("commands_count", &self.commands.len())
130            .field("max_iterations", &self.max_iterations)
131            .field("timeout", &self.timeout)
132            .field("output_schema", &self.output_schema)
133            .field("max_derived_bytes", &self.max_derived_bytes)
134            .finish_non_exhaustive()
135    }
136}
137
138impl LocyProgramExec {
139    #[allow(clippy::too_many_arguments)]
140    pub fn new(
141        strata: Vec<LocyStratum>,
142        commands: Vec<LocyCommand>,
143        derived_scan_registry: Arc<DerivedScanRegistry>,
144        graph_ctx: Arc<GraphExecutionContext>,
145        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
146        storage: Arc<StorageManager>,
147        schema_info: Arc<UniSchema>,
148        params: HashMap<String, Value>,
149        output_schema: SchemaRef,
150        max_iterations: usize,
151        timeout: Duration,
152        max_derived_bytes: usize,
153        deterministic_best_by: bool,
154    ) -> Self {
155        let properties = compute_plan_properties(Arc::clone(&output_schema));
156        Self {
157            strata,
158            commands,
159            derived_scan_registry,
160            graph_ctx,
161            session_ctx,
162            storage,
163            schema_info,
164            params,
165            output_schema,
166            properties,
167            metrics: ExecutionPlanMetricsSet::new(),
168            max_iterations,
169            timeout,
170            max_derived_bytes,
171            deterministic_best_by,
172            derived_store_slot: Arc::new(StdRwLock::new(None)),
173            derivation_tracker: Arc::new(StdRwLock::new(None)),
174            iteration_counts_slot: Arc::new(StdRwLock::new(HashMap::new())),
175            peak_memory_slot: Arc::new(StdRwLock::new(0)),
176        }
177    }
178
179    /// Returns a shared handle to the derived store slot.
180    ///
181    /// After execution completes, the slot contains the `DerivedStore` with all
182    /// converged facts. Read it with `slot.read().unwrap()`.
183    pub fn derived_store_slot(&self) -> Arc<StdRwLock<Option<DerivedStore>>> {
184        Arc::clone(&self.derived_store_slot)
185    }
186
187    /// Inject a `DerivationTracker` to record provenance during fixpoint iteration.
188    ///
189    /// Must be called before `execute()` is invoked (i.e., before DataFusion runs
190    /// the physical plan). Uses interior mutability so it works through `&self`.
191    pub fn set_derivation_tracker(&self, tracker: Arc<DerivationTracker>) {
192        if let Ok(mut guard) = self.derivation_tracker.write() {
193            *guard = Some(tracker);
194        }
195    }
196
197    /// Returns the shared iteration counts slot.
198    ///
199    /// After execution, the slot contains per-rule iteration counts from the
200    /// most recent fixpoint convergence. Sum the values for `total_iterations`.
201    pub fn iteration_counts_slot(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
202        Arc::clone(&self.iteration_counts_slot)
203    }
204
205    /// Returns the shared peak memory slot.
206    ///
207    /// After execution, the slot contains the peak byte count of derived facts
208    /// across all strata. Read it with `slot.read().unwrap()`.
209    pub fn peak_memory_slot(&self) -> Arc<StdRwLock<usize>> {
210        Arc::clone(&self.peak_memory_slot)
211    }
212}
213
214impl DisplayAs for LocyProgramExec {
215    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216        write!(
217            f,
218            "LocyProgramExec: strata={}, commands={}, max_iter={}, timeout={:?}",
219            self.strata.len(),
220            self.commands.len(),
221            self.max_iterations,
222            self.timeout,
223        )
224    }
225}
226
impl ExecutionPlan for LocyProgramExec {
    fn name(&self) -> &str {
        "LocyProgramExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    /// Leaf node: the whole program is evaluated internally, so there are no
    /// child plans.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        // A leaf plan cannot accept children; reject rather than silently drop.
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "LocyProgramExec has no children".to_string(),
            ));
        }
        Ok(self)
    }

    /// Kicks off full program evaluation and returns a stream over its result.
    ///
    /// Everything the async body needs is cloned/Arc-cloned up front so the
    /// future owns its captures (needed for the boxed `'static + Send` future
    /// inside `ProgramStream`).
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);

        let strata = self.strata.clone();
        let registry = Arc::clone(&self.derived_scan_registry);
        let graph_ctx = Arc::clone(&self.graph_ctx);
        let session_ctx = Arc::clone(&self.session_ctx);
        let storage = Arc::clone(&self.storage);
        let schema_info = Arc::clone(&self.schema_info);
        let params = self.params.clone();
        let output_schema = Arc::clone(&self.output_schema);
        let max_iterations = self.max_iterations;
        let timeout = self.timeout;
        let max_derived_bytes = self.max_derived_bytes;
        let deterministic_best_by = self.deterministic_best_by;
        let derived_store_slot = Arc::clone(&self.derived_store_slot);
        let iteration_counts_slot = Arc::clone(&self.iteration_counts_slot);
        let peak_memory_slot = Arc::clone(&self.peak_memory_slot);
        // NOTE(review): a poisoned tracker lock degrades to "no tracker"
        // (provenance silently disabled) rather than surfacing an error.
        let derivation_tracker = self.derivation_tracker.read().ok().and_then(|g| g.clone());

        let fut = async move {
            run_program(
                strata,
                registry,
                graph_ctx,
                session_ctx,
                storage,
                schema_info,
                params,
                output_schema,
                max_iterations,
                timeout,
                max_derived_bytes,
                deterministic_best_by,
                derived_store_slot,
                iteration_counts_slot,
                peak_memory_slot,
                derivation_tracker,
            )
            .await
        };

        // The stream starts in Running state; evaluation happens on first poll.
        Ok(Box::pin(ProgramStream {
            state: ProgramStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.output_schema),
            metrics,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
317
318// ---------------------------------------------------------------------------
319// ProgramStream — async state machine for streaming results
320// ---------------------------------------------------------------------------
321
/// Phases of `ProgramStream`'s lifecycle.
enum ProgramStreamState {
    /// Program evaluation still in flight; holds the boxed `run_program` future.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// Evaluation finished; batches are yielded one at a time from the cursor index.
    Emitting(Vec<RecordBatch>, usize),
    /// Stream exhausted (or errored); always yields `None`.
    Done,
}
327
/// Adapter exposing `run_program`'s batched result as a `RecordBatchStream`.
struct ProgramStream {
    /// Current phase of the run → emit → done state machine.
    state: ProgramStreamState,
    /// Program output schema reported to DataFusion.
    schema: SchemaRef,
    /// Records output row counts for the plan's metrics set.
    metrics: BaselineMetrics,
}
333
impl Stream for ProgramStream {
    type Item = DFResult<RecordBatch>;

    /// Drives the three-phase state machine:
    /// `Running` (await `run_program`) → `Emitting` (yield batches one by one,
    /// recording output metrics) → `Done` (terminal).
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        loop {
            match &mut this.state {
                ProgramStreamState::Running(fut) => match fut.as_mut().poll(cx) {
                    Poll::Ready(Ok(batches)) => {
                        if batches.is_empty() {
                            // No results at all: finish without yielding.
                            this.state = ProgramStreamState::Done;
                            return Poll::Ready(None);
                        }
                        // Loop around and start emitting from index 0.
                        this.state = ProgramStreamState::Emitting(batches, 0);
                    }
                    Poll::Ready(Err(e)) => {
                        // Surface the error once, then terminate the stream.
                        this.state = ProgramStreamState::Done;
                        return Poll::Ready(Some(Err(e)));
                    }
                    Poll::Pending => return Poll::Pending,
                },
                ProgramStreamState::Emitting(batches, idx) => {
                    if *idx >= batches.len() {
                        this.state = ProgramStreamState::Done;
                        return Poll::Ready(None);
                    }
                    // RecordBatch clone shares the underlying column buffers,
                    // so this is cheap.
                    let batch = batches[*idx].clone();
                    *idx += 1;
                    this.metrics.record_output(batch.num_rows());
                    return Poll::Ready(Some(Ok(batch)));
                }
                ProgramStreamState::Done => return Poll::Ready(None),
            }
        }
    }
}
370
371impl RecordBatchStream for ProgramStream {
372    fn schema(&self) -> SchemaRef {
373        Arc::clone(&self.schema)
374    }
375}
376
377// ---------------------------------------------------------------------------
378// run_program — core evaluation algorithm
379// ---------------------------------------------------------------------------
380
/// Core evaluation loop: runs every stratum in dependency order, publishes
/// converged facts into the registry and the shared result slots, and returns
/// a single stats batch describing the run.
#[allow(clippy::too_many_arguments)]
async fn run_program(
    strata: Vec<LocyStratum>,
    registry: Arc<DerivedScanRegistry>,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    params: HashMap<String, Value>,
    output_schema: SchemaRef,
    max_iterations: usize,
    timeout: Duration,
    max_derived_bytes: usize,
    deterministic_best_by: bool,
    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
    peak_memory_slot: Arc<StdRwLock<usize>>,
    derivation_tracker: Option<Arc<DerivationTracker>>,
) -> DFResult<Vec<RecordBatch>> {
    let start = Instant::now();
    let mut derived_store = DerivedStore::new();

    // Evaluate each stratum in topological order
    for stratum in &strata {
        // Write cross-stratum facts into registry handles for strata we depend on
        write_cross_stratum_facts(&registry, &derived_store, stratum);

        // Enforce the overall wall-clock budget before starting the stratum;
        // the remainder is also passed down to FixpointExec below.
        let remaining_timeout = timeout.saturating_sub(start.elapsed());
        if remaining_timeout.is_zero() {
            return Err(datafusion::error::DataFusionError::Execution(
                "Locy program timeout exceeded during stratum evaluation".to_string(),
            ));
        }

        if stratum.is_recursive {
            // Convert LocyRulePlan → FixpointRulePlan and run fixpoint
            let fixpoint_rules =
                convert_to_fixpoint_plans(&stratum.rules, &registry, deterministic_best_by)?;
            let fixpoint_schema = build_fixpoint_output_schema(&stratum.rules);

            let exec = FixpointExec::new(
                fixpoint_rules,
                max_iterations,
                remaining_timeout,
                Arc::clone(&graph_ctx),
                Arc::clone(&session_ctx),
                Arc::clone(&storage),
                Arc::clone(&schema_info),
                params.clone(),
                Arc::clone(&registry),
                fixpoint_schema,
                max_derived_bytes,
                derivation_tracker.clone(),
                Arc::clone(&iteration_counts_slot),
            );

            let task_ctx = session_ctx.read().task_ctx();
            let exec_arc: Arc<dyn ExecutionPlan> = Arc::new(exec);
            let batches = collect_all_partitions(&exec_arc, task_ctx).await?;

            // FixpointExec concatenates all rules' output; store per-rule.
            // For now, store all output under each rule name (since FixpointExec
            // handles per-rule state internally, the output is already correct).
            // TODO: parse output back into per-rule facts when needed for
            // cross-stratum consumption of individual rules from recursive strata.
            for rule in &stratum.rules {
                // Write converged facts into registry handles for cross-stratum consumers
                let rule_entries = registry.entries_for_rule(&rule.name);
                for entry in rule_entries {
                    if !entry.is_self_ref {
                        // Cross-stratum handles get the full fixpoint output
                        // In practice, FixpointExec already wrote self-ref handles;
                        // we need to write non-self-ref handles for later strata.
                        // NOTE(review): batch → rule matching is by field COUNT
                        // only, so rules with equally-wide schemas can pick up
                        // each other's batches; also the rule schema is rebuilt
                        // per batch inside the filter — could be hoisted.
                        let all_facts: Vec<RecordBatch> = batches
                            .iter()
                            .filter(|b| {
                                // If schemas match, this batch belongs to this rule
                                let rule_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
                                b.schema().fields().len() == rule_schema.fields().len()
                            })
                            .cloned()
                            .collect();
                        let mut guard = entry.data.write();
                        *guard = if all_facts.is_empty() {
                            // Keep an empty batch so downstream scans see the schema.
                            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
                        } else {
                            all_facts
                        };
                    }
                }
                derived_store.insert(rule.name.clone(), batches.clone());
            }
        } else {
            // Non-recursive: single-pass evaluation
            let fixpoint_rules =
                convert_to_fixpoint_plans(&stratum.rules, &registry, deterministic_best_by)?;
            let task_ctx = session_ctx.read().task_ctx();

            for (rule, fp_rule) in stratum.rules.iter().zip(fixpoint_rules.iter()) {
                let mut facts = evaluate_non_recursive_rule(
                    rule,
                    &params,
                    &graph_ctx,
                    &session_ctx,
                    &storage,
                    &schema_info,
                )
                .await?;

                // Apply anti-joins for negated IS-refs (IS NOT semantics).
                // For non-recursive rules, the negated rule is always in a lower stratum,
                // so its facts are already in the registry from write_cross_stratum_facts.
                for clause in &fp_rule.clauses {
                    for binding in &clause.is_ref_bindings {
                        if binding.negated
                            && !binding.anti_join_cols.is_empty()
                            && let Some(entry) = registry.get(binding.derived_scan_index)
                        {
                            let neg_facts = entry.data.read().clone();
                            if !neg_facts.is_empty() {
                                // One anti-join per subject column pair.
                                for (left_col, right_col) in &binding.anti_join_cols {
                                    facts = super::locy_fixpoint::apply_anti_join(
                                        facts, &neg_facts, left_col, right_col,
                                    )?;
                                }
                            }
                        }
                    }
                }

                // Apply post-fixpoint operators (PRIORITY, FOLD, BEST BY)
                let facts =
                    super::locy_fixpoint::apply_post_fixpoint_chain(facts, fp_rule, &task_ctx)
                        .await?;

                // Write facts into registry handles for later strata
                write_facts_to_registry(&registry, &rule.name, &facts);
                derived_store.insert(rule.name.clone(), facts);
            }
        }
    }

    // Compute peak memory from derived store byte sizes
    let peak_bytes: usize = derived_store
        .relations
        .values()
        .flat_map(|batches| batches.iter())
        .map(|b| {
            b.columns()
                .iter()
                .map(|col| col.get_buffer_memory_size())
                .sum::<usize>()
        })
        .sum();
    *peak_memory_slot.write().unwrap() = peak_bytes;

    // Commands are dispatched by the caller (e.g., evaluate_native) via the
    // orchestrator after DataFusion strata evaluation, so run_program only handles
    // strata evaluation and stores converged facts.
    let stats = vec![build_stats_batch(&derived_store, &strata, output_schema)];
    *derived_store_slot.write().unwrap() = Some(derived_store);
    Ok(stats)
}
544
545// ---------------------------------------------------------------------------
546// Non-recursive stratum evaluation
547// ---------------------------------------------------------------------------
548
549async fn evaluate_non_recursive_rule(
550    rule: &LocyRulePlan,
551    params: &HashMap<String, Value>,
552    graph_ctx: &Arc<GraphExecutionContext>,
553    session_ctx: &Arc<RwLock<datafusion::prelude::SessionContext>>,
554    storage: &Arc<StorageManager>,
555    schema_info: &Arc<UniSchema>,
556) -> DFResult<Vec<RecordBatch>> {
557    let mut all_batches = Vec::new();
558
559    for clause in &rule.clauses {
560        let batches = execute_subplan(
561            &clause.body,
562            params,
563            &HashMap::new(),
564            graph_ctx,
565            session_ctx,
566            storage,
567            schema_info,
568        )
569        .await?;
570        all_batches.extend(batches);
571    }
572
573    Ok(all_batches)
574}
575
576// ---------------------------------------------------------------------------
577// Cross-stratum fact injection
578// ---------------------------------------------------------------------------
579
580/// Write already-evaluated facts into registry handles for cross-stratum IS-refs.
581fn write_cross_stratum_facts(
582    registry: &DerivedScanRegistry,
583    derived_store: &DerivedStore,
584    stratum: &LocyStratum,
585) {
586    // For each rule in this stratum, find IS-refs to rules in other strata
587    for rule in &stratum.rules {
588        for clause in &rule.clauses {
589            for is_ref in &clause.is_refs {
590                // If this IS-ref points to a rule already in the derived store
591                // (i.e., from a previous stratum), write its facts into the registry
592                if let Some(facts) = derived_store.get(&is_ref.rule_name) {
593                    write_facts_to_registry(registry, &is_ref.rule_name, facts);
594                }
595            }
596        }
597    }
598}
599
600/// Write facts into non-self-ref registry handles for a given rule.
601fn write_facts_to_registry(registry: &DerivedScanRegistry, rule_name: &str, facts: &[RecordBatch]) {
602    let entries = registry.entries_for_rule(rule_name);
603    for entry in entries {
604        if !entry.is_self_ref {
605            let mut guard = entry.data.write();
606            *guard = if facts.is_empty() || facts.iter().all(|b| b.num_rows() == 0) {
607                vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
608            } else {
609                // Re-wrap batches with the entry's schema to ensure column names and
610                // types match exactly. The column data is preserved; only the schema
611                // metadata (field names) is replaced.
612                facts
613                    .iter()
614                    .filter_map(|b| {
615                        RecordBatch::try_new(Arc::clone(&entry.schema), b.columns().to_vec()).ok()
616                    })
617                    .collect()
618            };
619        }
620    }
621}
622
623// ---------------------------------------------------------------------------
624// LocyRulePlan → FixpointRulePlan conversion
625// ---------------------------------------------------------------------------
626
/// Convert logical `LocyRulePlan` types to physical `FixpointRulePlan` types.
///
/// # Errors
/// Returns a plan error when an IS-ref has no registry entry, or when a FOLD
/// binding cannot be resolved against the rule's yield schema.
fn convert_to_fixpoint_plans(
    rules: &[LocyRulePlan],
    registry: &DerivedScanRegistry,
    deterministic_best_by: bool,
) -> DFResult<Vec<FixpointRulePlan>> {
    rules
        .iter()
        .map(|rule| {
            let yield_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
            // Positions of columns flagged `is_key` within the yield schema.
            let key_column_indices: Vec<usize> = rule
                .yield_schema
                .iter()
                .enumerate()
                .filter(|(_, yc)| yc.is_key)
                .map(|(i, _)| i)
                .collect();

            // One physical clause per logical clause; IS-refs are resolved to
            // registry scan indices here, and the first conversion error aborts.
            let clauses: Vec<FixpointClausePlan> = rule
                .clauses
                .iter()
                .map(|clause| {
                    let is_ref_bindings = convert_is_refs(&clause.is_refs, registry)?;
                    Ok(FixpointClausePlan {
                        body_logical: clause.body.clone(),
                        is_ref_bindings,
                        priority: clause.priority,
                    })
                })
                .collect::<DFResult<Vec<_>>>()?;

            let fold_bindings = convert_fold_bindings(&rule.fold_bindings, &rule.yield_schema)?;
            let best_by_criteria =
                convert_best_by_criteria(&rule.best_by_criteria, &rule.yield_schema)?;

            let has_priority = rule.priority.is_some();

            // Add __priority column to yield schema if PRIORITY is used
            // (shadows the plain schema built above).
            let yield_schema = if has_priority {
                let mut fields: Vec<Arc<Field>> = yield_schema.fields().iter().cloned().collect();
                fields.push(Arc::new(Field::new("__priority", DataType::Int64, true)));
                ArrowSchema::new(fields)
            } else {
                yield_schema
            };

            Ok(FixpointRulePlan {
                name: rule.name.clone(),
                clauses,
                yield_schema: Arc::new(yield_schema),
                key_column_indices,
                priority: rule.priority,
                has_fold: !rule.fold_bindings.is_empty(),
                fold_bindings,
                has_best_by: !rule.best_by_criteria.is_empty(),
                best_by_criteria,
                has_priority,
                deterministic: deterministic_best_by,
            })
        })
        .collect()
}
689
/// Convert `LocyIsRef` to `IsRefBinding` by looking up scan indices in the registry.
///
/// # Errors
/// Returns a plan error when a referenced rule has no registry entry at all.
fn convert_is_refs(
    is_refs: &[LocyIsRef],
    registry: &DerivedScanRegistry,
) -> DFResult<Vec<IsRefBinding>> {
    is_refs
        .iter()
        .map(|is_ref| {
            let entries = registry.entries_for_rule(&is_ref.rule_name);
            // Find the matching entry (prefer self-ref for same-stratum rules)
            let entry = entries
                .iter()
                .find(|e| e.is_self_ref)
                .or_else(|| entries.first())
                .ok_or_else(|| {
                    datafusion::error::DataFusionError::Plan(format!(
                        "No derived scan entry found for IS-ref to '{}'",
                        is_ref.rule_name
                    ))
                })?;

            // For negated IS-refs, compute (left_body_col, right_derived_col) pairs for
            // anti-join filtering. Subject vars are assumed to be node variables, so
            // the body column is `{var}._vid` (UInt64). The derived column name is taken
            // positionally from the registry entry's schema (KEY columns come first).
            let anti_join_cols = if is_ref.negated {
                is_ref
                    .subjects
                    .iter()
                    .enumerate()
                    .filter_map(|(i, s)| {
                        // Non-variable subjects produce no anti-join pair.
                        if let uni_cypher::ast::Expr::Variable(var) = s {
                            // Positional lookup; falls back to the variable name
                            // when the schema has fewer fields than subjects.
                            let right_col = entry
                                .schema
                                .fields()
                                .get(i)
                                .map(|f| f.name().clone())
                                .unwrap_or_else(|| var.clone());
                            // After LocyProject the subject column is renamed to the yield
                            // column name (just `var`, not `var._vid`). Use bare var as left.
                            Some((var.clone(), right_col))
                        } else {
                            None
                        }
                    })
                    .collect()
            } else {
                // Positive refs need no anti-join columns.
                Vec::new()
            };

            Ok(IsRefBinding {
                derived_scan_index: entry.scan_index,
                rule_name: is_ref.rule_name.clone(),
                is_self_ref: entry.is_self_ref,
                negated: is_ref.negated,
                anti_join_cols,
            })
        })
        .collect()
}
750
751/// Convert fold binding expressions to physical `FoldBinding`.
752///
753/// The input column is looked up by the fold binding's output name (e.g., "total")
754/// in the yield schema, since the LocyProject aliases the aggregate input expression
755/// to the fold output name.
756fn convert_fold_bindings(
757    fold_bindings: &[(String, Expr)],
758    yield_schema: &[LocyYieldColumn],
759) -> DFResult<Vec<FoldBinding>> {
760    fold_bindings
761        .iter()
762        .map(|(name, expr)| {
763            let (kind, _input_col_name) = parse_fold_aggregate(expr)?;
764            // The LocyProject projects the aggregate input expression AS the fold
765            // output name, so the input column index matches the yield schema position.
766            let input_col_index = yield_schema
767                .iter()
768                .position(|yc| yc.name == *name)
769                .ok_or_else(|| {
770                    datafusion::error::DataFusionError::Plan(format!(
771                        "FOLD column '{}' not found in yield schema",
772                        name
773                    ))
774                })?;
775            Ok(FoldBinding {
776                output_name: name.clone(),
777                kind,
778                input_col_index,
779            })
780        })
781        .collect()
782}
783
784/// Parse a fold aggregate expression into (kind, input_column_name).
785fn parse_fold_aggregate(expr: &Expr) -> DFResult<(FoldAggKind, String)> {
786    match expr {
787        Expr::FunctionCall { name, args, .. } => {
788            let kind = match name.to_uppercase().as_str() {
789                "SUM" | "MSUM" => FoldAggKind::Sum,
790                "MAX" | "MMAX" => FoldAggKind::Max,
791                "MIN" | "MMIN" => FoldAggKind::Min,
792                "COUNT" | "MCOUNT" => FoldAggKind::Count,
793                "AVG" => FoldAggKind::Avg,
794                "COLLECT" => FoldAggKind::Collect,
795                _ => {
796                    return Err(datafusion::error::DataFusionError::Plan(format!(
797                        "Unknown FOLD aggregate function: {}",
798                        name
799                    )));
800                }
801            };
802            let col_name = match args.first() {
803                Some(Expr::Variable(v)) => v.clone(),
804                Some(Expr::Property(_, prop)) => prop.clone(),
805                _ => {
806                    return Err(datafusion::error::DataFusionError::Plan(
807                        "FOLD aggregate argument must be a variable or property reference"
808                            .to_string(),
809                    ));
810                }
811            };
812            Ok((kind, col_name))
813        }
814        _ => Err(datafusion::error::DataFusionError::Plan(
815            "FOLD binding must be a function call (e.g., SUM(x))".to_string(),
816        )),
817    }
818}
819
820/// Convert best-by criteria expressions to physical `SortCriterion`.
821///
822/// Resolves the criteria column by trying:
823/// 1. Property name (e.g., `e.cost` → "cost")
824/// 2. Variable name (e.g., `cost`)
825/// 3. Full expression string (e.g., "e.cost" as a variable name)
826fn convert_best_by_criteria(
827    criteria: &[(Expr, bool)],
828    yield_schema: &[LocyYieldColumn],
829) -> DFResult<Vec<SortCriterion>> {
830    criteria
831        .iter()
832        .map(|(expr, ascending)| {
833            let col_name = match expr {
834                Expr::Property(_, prop) => prop.clone(),
835                Expr::Variable(v) => v.clone(),
836                _ => {
837                    return Err(datafusion::error::DataFusionError::Plan(
838                        "BEST BY criterion must be a variable or property reference".to_string(),
839                    ));
840                }
841            };
842            // Try exact match first, then try just the last component after '.'
843            let col_index = yield_schema
844                .iter()
845                .position(|yc| yc.name == col_name)
846                .or_else(|| {
847                    let short_name = col_name.rsplit('.').next().unwrap_or(&col_name);
848                    yield_schema.iter().position(|yc| yc.name == short_name)
849                })
850                .ok_or_else(|| {
851                    datafusion::error::DataFusionError::Plan(format!(
852                        "BEST BY column '{}' not found in yield schema",
853                        col_name
854                    ))
855                })?;
856            Ok(SortCriterion {
857                col_index,
858                ascending: *ascending,
859                nulls_first: false,
860            })
861        })
862        .collect()
863}
864
865// ---------------------------------------------------------------------------
866// Schema helpers
867// ---------------------------------------------------------------------------
868
869/// Convert `LocyYieldColumn` slice to Arrow schema using inferred types.
870fn yield_columns_to_arrow_schema(columns: &[LocyYieldColumn]) -> ArrowSchema {
871    let fields: Vec<Arc<Field>> = columns
872        .iter()
873        .map(|yc| Arc::new(Field::new(&yc.name, yc.data_type.clone(), true)))
874        .collect();
875    ArrowSchema::new(fields)
876}
877
878/// Build a combined output schema for fixpoint (union of all rules' schemas).
879fn build_fixpoint_output_schema(rules: &[LocyRulePlan]) -> SchemaRef {
880    // FixpointExec concatenates all rules' output, using the first rule's schema
881    // as the output schema (all rules in a recursive stratum share compatible schemas).
882    if let Some(rule) = rules.first() {
883        Arc::new(yield_columns_to_arrow_schema(&rule.yield_schema))
884    } else {
885        Arc::new(ArrowSchema::empty())
886    }
887}
888
889/// Build a stats RecordBatch summarizing derived relation counts.
890fn build_stats_batch(
891    derived_store: &DerivedStore,
892    _strata: &[LocyStratum],
893    output_schema: SchemaRef,
894) -> RecordBatch {
895    // Build a simple stats batch with rule_name and fact_count columns
896    let mut rule_names: Vec<String> = derived_store.rule_names().map(String::from).collect();
897    rule_names.sort();
898
899    let name_col: arrow_array::StringArray = rule_names.iter().map(|s| Some(s.as_str())).collect();
900    let count_col: arrow_array::Int64Array = rule_names
901        .iter()
902        .map(|name| Some(derived_store.fact_count(name) as i64))
903        .collect();
904
905    let stats_schema = stats_schema();
906    RecordBatch::try_new(stats_schema, vec![Arc::new(name_col), Arc::new(count_col)])
907        .unwrap_or_else(|_| RecordBatch::new_empty(output_schema))
908}
909
910/// Schema for the stats batch returned when no commands are present.
911pub fn stats_schema() -> SchemaRef {
912    Arc::new(ArrowSchema::new(vec![
913        Arc::new(Field::new("rule_name", DataType::Utf8, false)),
914        Arc::new(Field::new("fact_count", DataType::Int64, false)),
915    ]))
916}
917
918// ---------------------------------------------------------------------------
919// Unit tests
920// ---------------------------------------------------------------------------
921
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, LargeBinaryArray, StringArray};

    /// Build a single-column nullable LargeBinary batch; each string's bytes
    /// become one row. Shared fixture for the DerivedStore tests.
    fn binary_batch(values: &[&str]) -> RecordBatch {
        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
            "x",
            DataType::LargeBinary,
            true,
        ))]));
        let cells: Vec<Option<&[u8]>> = values.iter().map(|v| Some(v.as_bytes())).collect();
        RecordBatch::try_new(schema, vec![Arc::new(LargeBinaryArray::from(cells))]).unwrap()
    }

    /// Shorthand constructor for a `LocyYieldColumn`.
    fn yield_col(name: &str, is_key: bool, data_type: DataType) -> LocyYieldColumn {
        LocyYieldColumn {
            name: name.to_string(),
            is_key,
            data_type,
        }
    }

    /// Shorthand for an aggregate call expression with one variable argument.
    fn agg_call(func: &str, arg: &str) -> Expr {
        Expr::FunctionCall {
            name: func.to_string(),
            args: vec![Expr::Variable(arg.to_string())],
            distinct: false,
            window_spec: None,
        }
    }

    #[test]
    fn test_derived_store_insert_and_get() {
        let mut store = DerivedStore::new();
        assert!(store.get("test").is_none());

        store.insert("test".to_string(), vec![binary_batch(&["a", "b"])]);

        let facts = store.get("test").unwrap();
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].num_rows(), 2);
    }

    #[test]
    fn test_derived_store_fact_count() {
        let mut store = DerivedStore::new();
        assert_eq!(store.fact_count("empty"), 0);

        // Fact count spans all batches for a rule: 1 + 2 = 3 rows.
        let batches = vec![binary_batch(&["a"]), binary_batch(&["b", "c"])];
        store.insert("test".to_string(), batches);
        assert_eq!(store.fact_count("test"), 3);
    }

    #[test]
    fn test_stats_batch_schema() {
        let schema = stats_schema();
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.field(0).name(), "rule_name");
        assert_eq!(schema.field(0).data_type(), &DataType::Utf8);
        assert_eq!(schema.field(1).name(), "fact_count");
        assert_eq!(schema.field(1).data_type(), &DataType::Int64);
    }

    #[test]
    fn test_stats_batch_content() {
        let mut store = DerivedStore::new();
        store.insert("reach".to_string(), vec![binary_batch(&["a", "b"])]);

        let stats = build_stats_batch(&store, &[], stats_schema());
        assert_eq!(stats.num_rows(), 1);

        let names = stats
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let counts = stats
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(names.value(0), "reach");
        assert_eq!(counts.value(0), 2);
    }

    #[test]
    fn test_yield_columns_to_arrow_schema() {
        let columns = vec![
            yield_col("a", true, DataType::UInt64),
            yield_col("b", false, DataType::LargeUtf8),
            yield_col("c", true, DataType::Float64),
        ];

        let schema = yield_columns_to_arrow_schema(&columns);
        assert_eq!(schema.fields().len(), 3);
        assert_eq!(schema.field(0).name(), "a");
        assert_eq!(schema.field(1).name(), "b");
        assert_eq!(schema.field(2).name(), "c");
        // Fields keep the inferred types …
        assert_eq!(schema.field(0).data_type(), &DataType::UInt64);
        assert_eq!(schema.field(1).data_type(), &DataType::LargeUtf8);
        assert_eq!(schema.field(2).data_type(), &DataType::Float64);
        // … and are always nullable.
        for field in schema.fields() {
            assert!(field.is_nullable());
        }
    }

    #[test]
    fn test_key_column_indices() {
        let columns = [
            yield_col("a", true, DataType::LargeBinary),
            yield_col("b", false, DataType::LargeBinary),
            yield_col("c", true, DataType::LargeBinary),
        ];

        let mut key_indices = Vec::new();
        for (i, yc) in columns.iter().enumerate() {
            if yc.is_key {
                key_indices.push(i);
            }
        }
        assert_eq!(key_indices, vec![0, 2]);
    }

    #[test]
    fn test_parse_fold_aggregate_sum() {
        let (kind, col) = parse_fold_aggregate(&agg_call("SUM", "cost")).unwrap();
        assert!(matches!(kind, FoldAggKind::Sum));
        assert_eq!(col, "cost");
    }

    #[test]
    fn test_parse_fold_aggregate_monotonic() {
        // The monotonic spelling MMAX maps to the same physical kind as MAX.
        let (kind, col) = parse_fold_aggregate(&agg_call("MMAX", "score")).unwrap();
        assert!(matches!(kind, FoldAggKind::Max));
        assert_eq!(col, "score");
    }

    #[test]
    fn test_parse_fold_aggregate_unknown() {
        assert!(parse_fold_aggregate(&agg_call("UNKNOWN_AGG", "x")).is_err());
    }

    #[test]
    fn test_no_commands_returns_stats() {
        let store = DerivedStore::new();
        let stats = build_stats_batch(&store, &[], stats_schema());
        // Empty store → 0 rows
        assert_eq!(stats.num_rows(), 0);
    }
}