Skip to main content

uni_query/query/df_graph/
locy_program.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Top-level Locy program executor.
5//!
6//! `LocyProgramExec` orchestrates the full evaluation of a Locy program:
7//! it evaluates strata in dependency order, runs fixpoint for recursive strata,
8//! applies post-fixpoint operators (FOLD, PRIORITY, BEST BY), and then
9//! executes the program's commands (goal queries, DERIVE, ASSUME, etc.).
10
11use crate::query::df_graph::GraphExecutionContext;
12use crate::query::df_graph::common::{
13    collect_all_partitions, compute_plan_properties, execute_subplan,
14};
15use crate::query::df_graph::locy_best_by::SortCriterion;
16use crate::query::df_graph::locy_explain::ProvenanceStore;
17use crate::query::df_graph::locy_fixpoint::{
18    DerivedScanRegistry, FixpointClausePlan, FixpointExec, FixpointRulePlan, IsRefBinding,
19};
20use crate::query::df_graph::locy_fold::{FoldAggKind, FoldBinding};
21use crate::query::planner_locy_types::{
22    LocyCommand, LocyIsRef, LocyRulePlan, LocyStratum, LocyYieldColumn,
23};
24use arrow_array::RecordBatch;
25use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
29use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
30use futures::Stream;
31use parking_lot::RwLock;
32use std::any::Any;
33use std::collections::HashMap;
34use std::fmt;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::sync::RwLock as StdRwLock;
38use std::task::{Context, Poll};
39use std::time::{Duration, Instant};
40use uni_common::Value;
41use uni_common::core::schema::Schema as UniSchema;
42use uni_cypher::ast::Expr;
43use uni_cypher::locy_ast::GoalQuery;
44use uni_locy::{CommandResult, FactRow, RuntimeWarning};
45use uni_store::storage::manager::StorageManager;
46
47// ---------------------------------------------------------------------------
48// DerivedStore — cross-stratum fact sharing
49// ---------------------------------------------------------------------------
50
51/// Simple store for derived relation facts across strata.
52///
53/// Each rule's converged facts are stored here after its stratum completes,
54/// making them available for later strata that depend on them.
55pub struct DerivedStore {
56    relations: HashMap<String, Vec<RecordBatch>>,
57}
58
59impl Default for DerivedStore {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65impl DerivedStore {
66    pub fn new() -> Self {
67        Self {
68            relations: HashMap::new(),
69        }
70    }
71
72    pub fn insert(&mut self, rule_name: String, facts: Vec<RecordBatch>) {
73        self.relations.insert(rule_name, facts);
74    }
75
76    pub fn get(&self, rule_name: &str) -> Option<&Vec<RecordBatch>> {
77        self.relations.get(rule_name)
78    }
79
80    pub fn fact_count(&self, rule_name: &str) -> usize {
81        self.relations
82            .get(rule_name)
83            .map(|batches| batches.iter().map(|b| b.num_rows()).sum())
84            .unwrap_or(0)
85    }
86
87    pub fn rule_names(&self) -> impl Iterator<Item = &str> {
88        self.relations.keys().map(|s| s.as_str())
89    }
90}
91
92// ---------------------------------------------------------------------------
93// LocyProgramExec — DataFusion ExecutionPlan
94// ---------------------------------------------------------------------------
95
96/// DataFusion `ExecutionPlan` that runs an entire Locy program.
97///
98/// Evaluates strata in dependency order, using `FixpointExec` for recursive
99/// strata and direct subplan execution for non-recursive ones. After all
100/// strata converge, dispatches commands.
101pub struct LocyProgramExec {
102    strata: Vec<LocyStratum>,
103    commands: Vec<LocyCommand>,
104    derived_scan_registry: Arc<DerivedScanRegistry>,
105    graph_ctx: Arc<GraphExecutionContext>,
106    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
107    storage: Arc<StorageManager>,
108    schema_info: Arc<UniSchema>,
109    params: HashMap<String, Value>,
110    output_schema: SchemaRef,
111    properties: PlanProperties,
112    metrics: ExecutionPlanMetricsSet,
113    max_iterations: usize,
114    timeout: Duration,
115    max_derived_bytes: usize,
116    deterministic_best_by: bool,
117    strict_probability_domain: bool,
118    probability_epsilon: f64,
119    exact_probability: bool,
120    max_bdd_variables: usize,
121    /// Shared slot for extracting the DerivedStore after execution completes.
122    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
123    /// Shared slot for groups where BDD fell back to independence mode.
124    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
125    /// Optional provenance tracker injected after construction (via `set_derivation_tracker`).
126    derivation_tracker: Arc<StdRwLock<Option<Arc<ProvenanceStore>>>>,
127    /// Shared slot written with per-rule iteration counts after fixpoint convergence.
128    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
129    /// Shared slot written with peak memory bytes after fixpoint completes.
130    peak_memory_slot: Arc<StdRwLock<usize>>,
131    /// Shared slot for runtime warnings collected during evaluation.
132    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
133    /// Shared slot for inline command results (QUERY, Cypher) executed inside `run_program()`.
134    command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
135    /// Top-k proof filtering: 0 = unlimited (default), >0 = retain at most k proofs per fact.
136    top_k_proofs: usize,
137    /// Shared flag set to true when the evaluation is cut short by a timeout.
138    /// Checked after execution to populate `LocyResult.timed_out`.
139    timeout_flag: Arc<std::sync::atomic::AtomicBool>,
140}
141
142impl fmt::Debug for LocyProgramExec {
143    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144        f.debug_struct("LocyProgramExec")
145            .field("strata_count", &self.strata.len())
146            .field("commands_count", &self.commands.len())
147            .field("max_iterations", &self.max_iterations)
148            .field("timeout", &self.timeout)
149            .field("output_schema", &self.output_schema)
150            .field("max_derived_bytes", &self.max_derived_bytes)
151            .finish_non_exhaustive()
152    }
153}
154
155impl LocyProgramExec {
156    #[expect(
157        clippy::too_many_arguments,
158        reason = "execution plan node requires full graph and session context"
159    )]
160    pub fn new(
161        strata: Vec<LocyStratum>,
162        commands: Vec<LocyCommand>,
163        derived_scan_registry: Arc<DerivedScanRegistry>,
164        graph_ctx: Arc<GraphExecutionContext>,
165        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
166        storage: Arc<StorageManager>,
167        schema_info: Arc<UniSchema>,
168        params: HashMap<String, Value>,
169        output_schema: SchemaRef,
170        max_iterations: usize,
171        timeout: Duration,
172        max_derived_bytes: usize,
173        deterministic_best_by: bool,
174        strict_probability_domain: bool,
175        probability_epsilon: f64,
176        exact_probability: bool,
177        max_bdd_variables: usize,
178        top_k_proofs: usize,
179    ) -> Self {
180        let properties = compute_plan_properties(Arc::clone(&output_schema));
181        Self {
182            strata,
183            commands,
184            derived_scan_registry,
185            graph_ctx,
186            session_ctx,
187            storage,
188            schema_info,
189            params,
190            output_schema,
191            properties,
192            metrics: ExecutionPlanMetricsSet::new(),
193            max_iterations,
194            timeout,
195            max_derived_bytes,
196            deterministic_best_by,
197            strict_probability_domain,
198            probability_epsilon,
199            exact_probability,
200            max_bdd_variables,
201            derived_store_slot: Arc::new(StdRwLock::new(None)),
202            approximate_slot: Arc::new(StdRwLock::new(HashMap::new())),
203            derivation_tracker: Arc::new(StdRwLock::new(None)),
204            iteration_counts_slot: Arc::new(StdRwLock::new(HashMap::new())),
205            peak_memory_slot: Arc::new(StdRwLock::new(0)),
206            warnings_slot: Arc::new(StdRwLock::new(Vec::new())),
207            command_results_slot: Arc::new(StdRwLock::new(Vec::new())),
208            top_k_proofs,
209            timeout_flag: Arc::new(std::sync::atomic::AtomicBool::new(false)),
210        }
211    }
212
213    /// Returns a shared handle to the derived store slot.
214    ///
215    /// After execution completes, the slot contains the `DerivedStore` with all
216    /// converged facts. Read it with `slot.read().unwrap()`.
217    pub fn derived_store_slot(&self) -> Arc<StdRwLock<Option<DerivedStore>>> {
218        Arc::clone(&self.derived_store_slot)
219    }
220
221    /// Inject a `ProvenanceStore` to record provenance during fixpoint iteration.
222    ///
223    /// Must be called before `execute()` is invoked (i.e., before DataFusion runs
224    /// the physical plan). Uses interior mutability so it works through `&self`.
225    pub fn set_derivation_tracker(&self, tracker: Arc<ProvenanceStore>) {
226        if let Ok(mut guard) = self.derivation_tracker.write() {
227            *guard = Some(tracker);
228        }
229    }
230
231    /// Returns the shared iteration counts slot.
232    ///
233    /// After execution, the slot contains per-rule iteration counts from the
234    /// most recent fixpoint convergence. Sum the values for `total_iterations`.
235    pub fn iteration_counts_slot(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
236        Arc::clone(&self.iteration_counts_slot)
237    }
238
239    /// Returns the shared peak memory slot.
240    ///
241    /// After execution, the slot contains the peak byte count of derived facts
242    /// across all strata. Read it with `slot.read().unwrap()`.
243    pub fn peak_memory_slot(&self) -> Arc<StdRwLock<usize>> {
244        Arc::clone(&self.peak_memory_slot)
245    }
246
247    /// Returns the shared runtime warnings slot.
248    ///
249    /// After execution, the slot contains warnings collected during fixpoint
250    /// iteration (e.g. shared probabilistic dependencies).
251    pub fn warnings_slot(&self) -> Arc<StdRwLock<Vec<RuntimeWarning>>> {
252        Arc::clone(&self.warnings_slot)
253    }
254
255    /// Returns the shared approximate groups slot.
256    ///
257    /// After execution, the slot contains rule→key group descriptions for
258    /// groups where BDD computation fell back to independence mode.
259    pub fn approximate_slot(&self) -> Arc<StdRwLock<HashMap<String, Vec<String>>>> {
260        Arc::clone(&self.approximate_slot)
261    }
262
263    /// Returns the shared command results slot.
264    ///
265    /// After execution, the slot contains `(command_index, CommandResult)` pairs
266    /// for commands that were executed inline by `run_program()` (QUERY, Cypher).
267    pub fn command_results_slot(&self) -> Arc<StdRwLock<Vec<(usize, CommandResult)>>> {
268        Arc::clone(&self.command_results_slot)
269    }
270
271    /// Returns the shared timeout flag.
272    ///
273    /// After execution, if this flag is true, the evaluation was cut short by
274    /// a timeout and the derived store contains partial results.
275    pub fn timeout_flag(&self) -> Arc<std::sync::atomic::AtomicBool> {
276        Arc::clone(&self.timeout_flag)
277    }
278}
279
280impl DisplayAs for LocyProgramExec {
281    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
282        write!(
283            f,
284            "LocyProgramExec: strata={}, commands={}, max_iter={}, timeout={:?}",
285            self.strata.len(),
286            self.commands.len(),
287            self.max_iterations,
288            self.timeout,
289        )
290    }
291}
292
293impl ExecutionPlan for LocyProgramExec {
294    fn name(&self) -> &str {
295        "LocyProgramExec"
296    }
297
298    fn as_any(&self) -> &dyn Any {
299        self
300    }
301
302    fn schema(&self) -> SchemaRef {
303        Arc::clone(&self.output_schema)
304    }
305
306    fn properties(&self) -> &PlanProperties {
307        &self.properties
308    }
309
310    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
311        vec![]
312    }
313
314    fn with_new_children(
315        self: Arc<Self>,
316        children: Vec<Arc<dyn ExecutionPlan>>,
317    ) -> DFResult<Arc<dyn ExecutionPlan>> {
318        if !children.is_empty() {
319            return Err(datafusion::error::DataFusionError::Plan(
320                "LocyProgramExec has no children".to_string(),
321            ));
322        }
323        Ok(self)
324    }
325
326    fn execute(
327        &self,
328        partition: usize,
329        _context: Arc<TaskContext>,
330    ) -> DFResult<SendableRecordBatchStream> {
331        let metrics = BaselineMetrics::new(&self.metrics, partition);
332
333        let strata = self.strata.clone();
334        let registry = Arc::clone(&self.derived_scan_registry);
335        let graph_ctx = Arc::clone(&self.graph_ctx);
336        let session_ctx = Arc::clone(&self.session_ctx);
337        let storage = Arc::clone(&self.storage);
338        let schema_info = Arc::clone(&self.schema_info);
339        let params = self.params.clone();
340        let output_schema = Arc::clone(&self.output_schema);
341        let max_iterations = self.max_iterations;
342        let timeout = self.timeout;
343        let max_derived_bytes = self.max_derived_bytes;
344        let deterministic_best_by = self.deterministic_best_by;
345        let strict_probability_domain = self.strict_probability_domain;
346        let probability_epsilon = self.probability_epsilon;
347        let exact_probability = self.exact_probability;
348        let max_bdd_variables = self.max_bdd_variables;
349        let derived_store_slot = Arc::clone(&self.derived_store_slot);
350        let approximate_slot = Arc::clone(&self.approximate_slot);
351        let iteration_counts_slot = Arc::clone(&self.iteration_counts_slot);
352        let peak_memory_slot = Arc::clone(&self.peak_memory_slot);
353        let derivation_tracker = self.derivation_tracker.read().ok().and_then(|g| g.clone());
354        let warnings_slot = Arc::clone(&self.warnings_slot);
355        let commands = self.commands.clone();
356        let command_results_slot = Arc::clone(&self.command_results_slot);
357        let top_k_proofs = self.top_k_proofs;
358        let timeout_flag = Arc::clone(&self.timeout_flag);
359
360        let fut = async move {
361            run_program(
362                strata,
363                commands,
364                registry,
365                graph_ctx,
366                session_ctx,
367                storage,
368                schema_info,
369                params,
370                output_schema,
371                max_iterations,
372                timeout,
373                max_derived_bytes,
374                deterministic_best_by,
375                strict_probability_domain,
376                probability_epsilon,
377                exact_probability,
378                max_bdd_variables,
379                derived_store_slot,
380                approximate_slot,
381                iteration_counts_slot,
382                peak_memory_slot,
383                derivation_tracker,
384                warnings_slot,
385                command_results_slot,
386                top_k_proofs,
387                timeout_flag,
388            )
389            .await
390        };
391
392        Ok(Box::pin(ProgramStream {
393            state: ProgramStreamState::Running(Box::pin(fut)),
394            schema: Arc::clone(&self.output_schema),
395            metrics,
396        }))
397    }
398
399    fn metrics(&self) -> Option<MetricsSet> {
400        Some(self.metrics.clone_inner())
401    }
402}
403
404// ---------------------------------------------------------------------------
405// ProgramStream — async state machine for streaming results
406// ---------------------------------------------------------------------------
407
408enum ProgramStreamState {
409    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
410    Emitting(Vec<RecordBatch>, usize),
411    Done,
412}
413
414struct ProgramStream {
415    state: ProgramStreamState,
416    schema: SchemaRef,
417    metrics: BaselineMetrics,
418}
419
420impl Stream for ProgramStream {
421    type Item = DFResult<RecordBatch>;
422
423    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
424        let this = self.get_mut();
425        loop {
426            match &mut this.state {
427                ProgramStreamState::Running(fut) => match fut.as_mut().poll(cx) {
428                    Poll::Ready(Ok(batches)) => {
429                        if batches.is_empty() {
430                            this.state = ProgramStreamState::Done;
431                            return Poll::Ready(None);
432                        }
433                        this.state = ProgramStreamState::Emitting(batches, 0);
434                    }
435                    Poll::Ready(Err(e)) => {
436                        this.state = ProgramStreamState::Done;
437                        return Poll::Ready(Some(Err(e)));
438                    }
439                    Poll::Pending => return Poll::Pending,
440                },
441                ProgramStreamState::Emitting(batches, idx) => {
442                    if *idx >= batches.len() {
443                        this.state = ProgramStreamState::Done;
444                        return Poll::Ready(None);
445                    }
446                    let batch = batches[*idx].clone();
447                    *idx += 1;
448                    this.metrics.record_output(batch.num_rows());
449                    return Poll::Ready(Some(Ok(batch)));
450                }
451                ProgramStreamState::Done => return Poll::Ready(None),
452            }
453        }
454    }
455}
456
457impl RecordBatchStream for ProgramStream {
458    fn schema(&self) -> SchemaRef {
459        Arc::clone(&self.schema)
460    }
461}
462
463// ---------------------------------------------------------------------------
464// Inline command execution helpers
465// ---------------------------------------------------------------------------
466
467/// Execute QUERY by scanning converged DerivedStore directly (no SLG).
468///
469/// Strata have already converged all facts, so we can scan the DerivedStore
470/// without re-derivation. WHERE filtering and RETURN projection are applied
471/// in-memory.
472///
473/// NOTE: Currently unused because the DerivedStore uses inferred Arrow types
474/// (Float64 for all property-derived columns), so string property values are
475/// not preserved. Once `infer_expr_type` is improved to use actual schema types,
476/// this function can be re-enabled for QUERYs whose WHERE/RETURN only reference
477/// reliably-typed columns.
478#[allow(dead_code)]
479fn execute_query_inline(
480    query: &GoalQuery,
481    derived_store: &DerivedStore,
482    params: &HashMap<String, Value>,
483) -> DFResult<Vec<FactRow>> {
484    let rule_name = query.rule_name.to_string();
485    let batches = derived_store.get(&rule_name).cloned().unwrap_or_default();
486    let rows = super::locy_eval::record_batches_to_locy_rows(&batches);
487
488    // Apply WHERE filter
489    let filtered = if let Some(ref where_expr) = query.where_expr {
490        rows.into_iter()
491            .filter(|row| {
492                let merged = super::locy_query::merge_params(row, params);
493                super::locy_eval::eval_expr(where_expr, &merged)
494                    .map(|v| v.as_bool().unwrap_or(false))
495                    .unwrap_or(false)
496            })
497            .collect()
498    } else {
499        rows
500    };
501
502    // Apply RETURN (project, distinct, order, skip, limit)
503    super::locy_query::apply_return_clause(filtered, &query.return_clause, params)
504        .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
505}
506
507/// Execute Cypher passthrough via execute_subplan.
508async fn execute_cypher_inline(
509    query: &uni_cypher::ast::Query,
510    schema_info: &Arc<UniSchema>,
511    params: &HashMap<String, Value>,
512    graph_ctx: &Arc<GraphExecutionContext>,
513    session_ctx: &Arc<RwLock<datafusion::prelude::SessionContext>>,
514    storage: &Arc<StorageManager>,
515) -> DFResult<Vec<FactRow>> {
516    let planner = crate::query::planner::QueryPlanner::new(Arc::clone(schema_info));
517    let logical_plan = planner.plan(query.clone()).map_err(|e| {
518        datafusion::error::DataFusionError::Execution(format!("Cypher plan error: {e}"))
519    })?;
520    let batches = execute_subplan(
521        &logical_plan,
522        params,
523        &HashMap::new(),
524        graph_ctx,
525        session_ctx,
526        storage,
527        schema_info,
528    )
529    .await?;
530    Ok(super::locy_eval::record_batches_to_locy_rows(&batches))
531}
532
533/// Returns true if the WHERE or RETURN expressions contain property access
534/// (e.g. `a.name`) that requires full node objects not available in DerivedStore.
535#[allow(dead_code)]
536fn needs_node_enrichment(query: &GoalQuery) -> bool {
537    let where_has_property = query
538        .where_expr
539        .as_ref()
540        .is_some_and(expr_has_property_access);
541    let return_has_property = query.return_clause.as_ref().is_some_and(|rc| {
542        rc.items.iter().any(|item| match item {
543            uni_cypher::ast::ReturnItem::Expr { expr, .. } => expr_has_property_access(expr),
544            uni_cypher::ast::ReturnItem::All => false,
545        })
546    });
547    where_has_property || return_has_property
548}
549
550/// Recursively check whether an expression contains property access (`a.name`).
551#[allow(dead_code)]
552fn expr_has_property_access(expr: &Expr) -> bool {
553    match expr {
554        Expr::Property(..) => true,
555        Expr::BinaryOp { left, right, .. } => {
556            expr_has_property_access(left) || expr_has_property_access(right)
557        }
558        Expr::UnaryOp { expr, .. } => expr_has_property_access(expr),
559        Expr::FunctionCall { args, .. } => args.iter().any(expr_has_property_access),
560        Expr::List(items) => items.iter().any(expr_has_property_access),
561        Expr::Map(entries) => entries.iter().any(|(_, e)| expr_has_property_access(e)),
562        Expr::Case {
563            expr: case_expr,
564            when_then,
565            else_expr,
566        } => {
567            case_expr
568                .as_ref()
569                .is_some_and(|e| expr_has_property_access(e))
570                || when_then
571                    .iter()
572                    .any(|(w, t)| expr_has_property_access(w) || expr_has_property_access(t))
573                || else_expr
574                    .as_ref()
575                    .is_some_and(|e| expr_has_property_access(e))
576        }
577        Expr::IsNull(e) | Expr::IsNotNull(e) | Expr::IsUnique(e) => expr_has_property_access(e),
578        Expr::In { expr, list } => expr_has_property_access(expr) || expr_has_property_access(list),
579        Expr::ArrayIndex { array, index } => {
580            expr_has_property_access(array) || expr_has_property_access(index)
581        }
582        Expr::ArraySlice { array, start, end } => {
583            expr_has_property_access(array)
584                || start.as_ref().is_some_and(|e| expr_has_property_access(e))
585                || end.as_ref().is_some_and(|e| expr_has_property_access(e))
586        }
587        Expr::Quantifier {
588            list, predicate, ..
589        } => expr_has_property_access(list) || expr_has_property_access(predicate),
590        Expr::Reduce {
591            init, list, expr, ..
592        } => {
593            expr_has_property_access(init)
594                || expr_has_property_access(list)
595                || expr_has_property_access(expr)
596        }
597        Expr::ListComprehension {
598            list,
599            where_clause,
600            map_expr,
601            ..
602        } => {
603            expr_has_property_access(list)
604                || where_clause
605                    .as_ref()
606                    .is_some_and(|e| expr_has_property_access(e))
607                || expr_has_property_access(map_expr)
608        }
609        Expr::PatternComprehension {
610            where_clause,
611            map_expr,
612            ..
613        } => {
614            where_clause
615                .as_ref()
616                .is_some_and(|e| expr_has_property_access(e))
617                || expr_has_property_access(map_expr)
618        }
619        Expr::ValidAt {
620            entity, timestamp, ..
621        } => expr_has_property_access(entity) || expr_has_property_access(timestamp),
622        Expr::MapProjection { base, .. } => expr_has_property_access(base),
623        Expr::LabelCheck { expr, .. } => expr_has_property_access(expr),
624        // Leaf nodes (Literal, Parameter, Variable, Wildcard) and subqueries
625        // (Exists, CountSubquery, CollectSubquery) — no property access.
626        _ => false,
627    }
628}
629
630// ---------------------------------------------------------------------------
631// run_program — core evaluation algorithm
632// ---------------------------------------------------------------------------
633
634#[expect(
635    clippy::too_many_arguments,
636    reason = "program evaluation requires full graph and session context"
637)]
638async fn run_program(
639    strata: Vec<LocyStratum>,
640    commands: Vec<LocyCommand>,
641    registry: Arc<DerivedScanRegistry>,
642    graph_ctx: Arc<GraphExecutionContext>,
643    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
644    storage: Arc<StorageManager>,
645    schema_info: Arc<UniSchema>,
646    params: HashMap<String, Value>,
647    output_schema: SchemaRef,
648    max_iterations: usize,
649    timeout: Duration,
650    max_derived_bytes: usize,
651    deterministic_best_by: bool,
652    strict_probability_domain: bool,
653    probability_epsilon: f64,
654    exact_probability: bool,
655    max_bdd_variables: usize,
656    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
657    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
658    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
659    peak_memory_slot: Arc<StdRwLock<usize>>,
660    derivation_tracker: Option<Arc<ProvenanceStore>>,
661    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
662    command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
663    top_k_proofs: usize,
664    timeout_flag: Arc<std::sync::atomic::AtomicBool>,
665) -> DFResult<Vec<RecordBatch>> {
666    let start = Instant::now();
667    let mut derived_store = DerivedStore::new();
668
669    // Evaluate each stratum in topological order
670    for stratum in &strata {
671        // Write cross-stratum facts into registry handles for strata we depend on
672        write_cross_stratum_facts(&registry, &derived_store, stratum);
673
674        let remaining_timeout = timeout.saturating_sub(start.elapsed());
675        if remaining_timeout.is_zero() {
676            tracing::warn!("Locy program timeout exceeded during stratum evaluation");
677            timeout_flag.store(true, std::sync::atomic::Ordering::Relaxed);
678            break;
679        }
680
681        if stratum.is_recursive {
682            // Convert LocyRulePlan → FixpointRulePlan and run fixpoint
683            let fixpoint_rules =
684                convert_to_fixpoint_plans(&stratum.rules, &registry, deterministic_best_by)?;
685            let fixpoint_schema = build_fixpoint_output_schema(&stratum.rules);
686
687            let exec = FixpointExec::new(
688                fixpoint_rules,
689                max_iterations,
690                remaining_timeout,
691                Arc::clone(&graph_ctx),
692                Arc::clone(&session_ctx),
693                Arc::clone(&storage),
694                Arc::clone(&schema_info),
695                params.clone(),
696                Arc::clone(&registry),
697                fixpoint_schema,
698                max_derived_bytes,
699                derivation_tracker.clone(),
700                Arc::clone(&iteration_counts_slot),
701                strict_probability_domain,
702                probability_epsilon,
703                exact_probability,
704                max_bdd_variables,
705                Arc::clone(&warnings_slot),
706                Arc::clone(&approximate_slot),
707                top_k_proofs,
708                Arc::clone(&timeout_flag),
709            );
710
711            let task_ctx = session_ctx.read().task_ctx();
712            let exec_arc: Arc<dyn ExecutionPlan> = Arc::new(exec);
713            let batches = collect_all_partitions(&exec_arc, task_ctx).await?;
714
715            // FixpointExec concatenates all rules' output; store per-rule.
716            // For now, store all output under each rule name (since FixpointExec
717            // handles per-rule state internally, the output is already correct).
718            // NOTE(deferred): Per-rule fact demultiplexing is not yet implemented.
719            // FixpointExec concatenates all rules' output into a single batch stream.
720            // Proper demux requires FixpointExec to tag output batches with rule identity
721            // (e.g. an extra column or side-channel), which is a non-trivial change to
722            // run_fixpoint_loop. The current schema-field-count heuristic (filter below)
723            // works because recursive stratum rules share compatible schemas.
724            // Revisit when cross-stratum consumption of individual recursive rules is needed.
725            for rule in &stratum.rules {
726                // Skip DERIVE-only rules (empty yield_schema).
727                if rule.yield_schema.is_empty() {
728                    continue;
729                }
730                // Write converged facts into registry handles for cross-stratum consumers
731                let rule_entries = registry.entries_for_rule(&rule.name);
732                for entry in rule_entries {
733                    if !entry.is_self_ref {
734                        // Cross-stratum handles get the full fixpoint output
735                        // In practice, FixpointExec already wrote self-ref handles;
736                        // we need to write non-self-ref handles for later strata.
737                        let all_facts: Vec<RecordBatch> = batches
738                            .iter()
739                            .filter(|b| {
740                                // If schemas match, this batch belongs to this rule
741                                let rule_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
742                                b.schema().fields().len() == rule_schema.fields().len()
743                            })
744                            .cloned()
745                            .collect();
746                        let mut guard = entry.data.write();
747                        *guard = if all_facts.is_empty() {
748                            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
749                        } else {
750                            all_facts
751                        };
752                    }
753                }
754                derived_store.insert(rule.name.clone(), batches.clone());
755            }
756        } else {
757            // Non-recursive: single-pass evaluation
758            let fixpoint_rules =
759                convert_to_fixpoint_plans(&stratum.rules, &registry, deterministic_best_by)?;
760            let task_ctx = session_ctx.read().task_ctx();
761
762            for (rule, fp_rule) in stratum.rules.iter().zip(fixpoint_rules.iter()) {
763                // DERIVE-only rules have empty yield_schema (the compiler's
764                // infer_yield_schema only matches RuleOutput::Yield). Skip them
765                // in the fixpoint loop — DERIVE materialization is handled by
766                // the DERIVE command dispatch, not by the fixpoint.
767                if rule.yield_schema.is_empty() {
768                    continue;
769                }
770
771                // Process each clause independently (per-clause IS NOT).
772                let mut tagged_clause_facts: Vec<(usize, Vec<RecordBatch>)> = Vec::new();
773                for (clause_idx, (clause, fp_clause)) in
774                    rule.clauses.iter().zip(fp_rule.clauses.iter()).enumerate()
775                {
776                    let mut batches = execute_subplan(
777                        &clause.body,
778                        &params,
779                        &HashMap::new(),
780                        &graph_ctx,
781                        &session_ctx,
782                        &storage,
783                        &schema_info,
784                    )
785                    .await?;
786
787                    // Apply negated IS-ref semantics per-clause.
788                    for binding in &fp_clause.is_ref_bindings {
789                        if binding.negated
790                            && !binding.anti_join_cols.is_empty()
791                            && let Some(entry) = registry.get(binding.derived_scan_index)
792                        {
793                            let neg_facts = entry.data.read().clone();
794                            if !neg_facts.is_empty() {
795                                if binding.target_has_prob && fp_rule.prob_column_name.is_some() {
796                                    let complement_col =
797                                        format!("__prob_complement_{}", binding.rule_name);
798                                    if let Some(prob_col) = &binding.target_prob_col {
799                                        batches =
800                                            super::locy_fixpoint::apply_prob_complement_composite(
801                                                batches,
802                                                &neg_facts,
803                                                &binding.anti_join_cols,
804                                                prob_col,
805                                                &complement_col,
806                                            )?;
807                                    } else {
808                                        // target_has_prob but no prob_col: fall back to anti-join.
809                                        batches = super::locy_fixpoint::apply_anti_join_composite(
810                                            batches,
811                                            &neg_facts,
812                                            &binding.anti_join_cols,
813                                        )?;
814                                    }
815                                } else {
816                                    batches = super::locy_fixpoint::apply_anti_join_composite(
817                                        batches,
818                                        &neg_facts,
819                                        &binding.anti_join_cols,
820                                    )?;
821                                }
822                            }
823                        }
824                    }
825
826                    // Multiply complement columns into PROB per-clause.
827                    let complement_cols: Vec<String> = if !batches.is_empty() {
828                        batches[0]
829                            .schema()
830                            .fields()
831                            .iter()
832                            .filter(|f| f.name().starts_with("__prob_complement_"))
833                            .map(|f| f.name().clone())
834                            .collect()
835                    } else {
836                        vec![]
837                    };
838                    if !complement_cols.is_empty() {
839                        batches = super::locy_fixpoint::multiply_prob_factors(
840                            batches,
841                            fp_rule.prob_column_name.as_deref(),
842                            &complement_cols,
843                        )?;
844                    }
845
846                    tagged_clause_facts.push((clause_idx, batches));
847                }
848
849                // Record provenance and detect shared proofs for non-recursive rules.
850                let shared_info = if let Some(ref tracker) = derivation_tracker {
851                    super::locy_fixpoint::record_and_detect_lineage_nonrecursive(
852                        fp_rule,
853                        &tagged_clause_facts,
854                        tracker,
855                        &warnings_slot,
856                        &registry,
857                        top_k_proofs,
858                    )
859                } else {
860                    None
861                };
862
863                // Flatten tagged facts for post-fixpoint chain.
864                let mut all_clause_facts: Vec<RecordBatch> = tagged_clause_facts
865                    .into_iter()
866                    .flat_map(|(_, batches)| batches)
867                    .collect();
868
869                // Apply BDD for shared groups if exact_probability is enabled.
870                if exact_probability
871                    && let Some(ref info) = shared_info
872                    && let Some(ref tracker) = derivation_tracker
873                {
874                    all_clause_facts = super::locy_fixpoint::apply_exact_wmc(
875                        all_clause_facts,
876                        fp_rule,
877                        info,
878                        tracker,
879                        max_bdd_variables,
880                        &warnings_slot,
881                        &approximate_slot,
882                    )?;
883                }
884
885                // Apply post-fixpoint operators (PRIORITY, FOLD, BEST BY) on union.
886                let facts = super::locy_fixpoint::apply_post_fixpoint_chain(
887                    all_clause_facts,
888                    fp_rule,
889                    &task_ctx,
890                    strict_probability_domain,
891                    probability_epsilon,
892                )
893                .await?;
894
895                // Write facts into registry handles for later strata
896                write_facts_to_registry(&registry, &rule.name, &facts);
897                derived_store.insert(rule.name.clone(), facts);
898            }
899        }
900    }
901
902    // Compute peak memory from derived store byte sizes
903    let peak_bytes: usize = derived_store
904        .relations
905        .values()
906        .flat_map(|batches| batches.iter())
907        .map(|b| {
908            b.columns()
909                .iter()
910                .map(|col| col.get_buffer_memory_size())
911                .sum::<usize>()
912        })
913        .sum();
914    *peak_memory_slot.write().unwrap() = peak_bytes;
915
916    // Execute inline Cypher commands via execute_subplan.
917    // QUERY is deferred to the orchestrator: the DerivedStore uses inferred types
918    // (e.g. Float64 for property-derived columns) which don't preserve the actual
919    // property values. The orchestrator's SLG path re-derives with correct types.
920    // DERIVE/ASSUME/EXPLAIN/ABDUCE are also deferred (need L0 fork/restore, tree output, etc.).
921    //
922    // Cypher commands that appear AFTER a DERIVE command are also deferred:
923    // they need the ephemeral L0 overlay populated by DERIVE to see derived
924    // edges, which is only available in the orchestrator's dispatch loop.
925    let first_derive_idx = commands
926        .iter()
927        .position(|c| matches!(c, LocyCommand::Derive { .. }));
928    let mut inline_results: Vec<(usize, CommandResult)> = Vec::new();
929    for (cmd_idx, cmd) in commands.iter().enumerate() {
930        if let LocyCommand::Cypher { query } = cmd {
931            // Defer Cypher commands that follow a DERIVE to the dispatch loop
932            // so they can read from the ephemeral L0 overlay.
933            if first_derive_idx.is_some_and(|di| cmd_idx > di) {
934                continue;
935            }
936            let rows = execute_cypher_inline(
937                query,
938                &schema_info,
939                &params,
940                &graph_ctx,
941                &session_ctx,
942                &storage,
943            )
944            .await?;
945            inline_results.push((cmd_idx, CommandResult::Cypher(rows)));
946        }
947    }
948    *command_results_slot.write().unwrap() = inline_results;
949
950    let stats = vec![build_stats_batch(&derived_store, &strata, output_schema)];
951    *derived_store_slot.write().unwrap() = Some(derived_store);
952    Ok(stats)
953}
954
955// ---------------------------------------------------------------------------
956// Cross-stratum fact injection
957// ---------------------------------------------------------------------------
958
959/// Write already-evaluated facts into registry handles for cross-stratum IS-refs.
960fn write_cross_stratum_facts(
961    registry: &DerivedScanRegistry,
962    derived_store: &DerivedStore,
963    stratum: &LocyStratum,
964) {
965    // For each rule in this stratum, find IS-refs to rules in other strata
966    for rule in &stratum.rules {
967        for clause in &rule.clauses {
968            for is_ref in &clause.is_refs {
969                // If this IS-ref points to a rule already in the derived store
970                // (i.e., from a previous stratum), write its facts into the registry
971                if let Some(facts) = derived_store.get(&is_ref.rule_name) {
972                    write_facts_to_registry(registry, &is_ref.rule_name, facts);
973                }
974            }
975        }
976    }
977}
978
979/// Write facts into non-self-ref registry handles for a given rule.
980fn write_facts_to_registry(registry: &DerivedScanRegistry, rule_name: &str, facts: &[RecordBatch]) {
981    let entries = registry.entries_for_rule(rule_name);
982    for entry in entries {
983        if !entry.is_self_ref {
984            let mut guard = entry.data.write();
985            *guard = if facts.is_empty() || facts.iter().all(|b| b.num_rows() == 0) {
986                vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
987            } else {
988                // Try to re-wrap batches with the entry's schema for column name
989                // alignment. If the types don't match (e.g. inferred Float64 vs
990                // actual Utf8 from schema mode), fall back to the batch's own
991                // schema to avoid silent data loss.
992                facts
993                    .iter()
994                    .filter(|b| b.num_rows() > 0)
995                    .map(|b| {
996                        RecordBatch::try_new(Arc::clone(&entry.schema), b.columns().to_vec())
997                            .unwrap_or_else(|_| b.clone())
998                    })
999                    .collect()
1000            };
1001        }
1002    }
1003}
1004
1005// ---------------------------------------------------------------------------
1006// LocyRulePlan → FixpointRulePlan conversion
1007// ---------------------------------------------------------------------------
1008
1009/// Convert logical `LocyRulePlan` types to physical `FixpointRulePlan` types.
1010fn convert_to_fixpoint_plans(
1011    rules: &[LocyRulePlan],
1012    registry: &DerivedScanRegistry,
1013    deterministic_best_by: bool,
1014) -> DFResult<Vec<FixpointRulePlan>> {
1015    rules
1016        .iter()
1017        .map(|rule| {
1018            let yield_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
1019            let key_column_indices: Vec<usize> = rule
1020                .yield_schema
1021                .iter()
1022                .enumerate()
1023                .filter(|(_, yc)| yc.is_key)
1024                .map(|(i, _)| i)
1025                .collect();
1026
1027            let clauses: Vec<FixpointClausePlan> = rule
1028                .clauses
1029                .iter()
1030                .map(|clause| {
1031                    let is_ref_bindings = convert_is_refs(&clause.is_refs, registry)?;
1032                    Ok(FixpointClausePlan {
1033                        body_logical: clause.body.clone(),
1034                        is_ref_bindings,
1035                        priority: clause.priority,
1036                        along_bindings: clause.along_bindings.clone(),
1037                    })
1038                })
1039                .collect::<DFResult<Vec<_>>>()?;
1040
1041            let fold_bindings = convert_fold_bindings(&rule.fold_bindings, &rule.yield_schema)?;
1042            let best_by_criteria =
1043                convert_best_by_criteria(&rule.best_by_criteria, &rule.yield_schema)?;
1044
1045            let has_priority = rule.priority.is_some();
1046
1047            // Add __priority column to yield schema if PRIORITY is used
1048            let yield_schema = if has_priority {
1049                let mut fields: Vec<Arc<Field>> = yield_schema.fields().iter().cloned().collect();
1050                fields.push(Arc::new(Field::new("__priority", DataType::Int64, true)));
1051                ArrowSchema::new(fields)
1052            } else {
1053                yield_schema
1054            };
1055
1056            let prob_column_name = rule
1057                .yield_schema
1058                .iter()
1059                .find(|yc| yc.is_prob)
1060                .map(|yc| yc.name.clone());
1061
1062            Ok(FixpointRulePlan {
1063                name: rule.name.clone(),
1064                clauses,
1065                yield_schema: Arc::new(yield_schema),
1066                key_column_indices,
1067                priority: rule.priority,
1068                has_fold: !rule.fold_bindings.is_empty(),
1069                fold_bindings,
1070                having: rule.having.clone(),
1071                has_best_by: !rule.best_by_criteria.is_empty(),
1072                best_by_criteria,
1073                has_priority,
1074                deterministic: deterministic_best_by,
1075                prob_column_name,
1076            })
1077        })
1078        .collect()
1079}
1080
1081/// Convert `LocyIsRef` to `IsRefBinding` by looking up scan indices in the registry.
1082fn convert_is_refs(
1083    is_refs: &[LocyIsRef],
1084    registry: &DerivedScanRegistry,
1085) -> DFResult<Vec<IsRefBinding>> {
1086    is_refs
1087        .iter()
1088        .map(|is_ref| {
1089            let entries = registry.entries_for_rule(&is_ref.rule_name);
1090            // Find the matching entry (prefer self-ref for same-stratum rules)
1091            let entry = entries
1092                .iter()
1093                .find(|e| e.is_self_ref)
1094                .or_else(|| entries.first())
1095                .ok_or_else(|| {
1096                    datafusion::error::DataFusionError::Plan(format!(
1097                        "No derived scan entry found for IS-ref to '{}'",
1098                        is_ref.rule_name
1099                    ))
1100                })?;
1101
1102            // For negated IS-refs, compute (left_body_col, right_derived_col) pairs for
1103            // anti-join filtering. Subject vars are assumed to be node variables, so
1104            // the body column is `{var}._vid` (UInt64). The derived column name is taken
1105            // positionally from the registry entry's schema (KEY columns come first).
1106            let anti_join_cols = if is_ref.negated {
1107                let mut cols: Vec<(String, String)> = is_ref
1108                    .subjects
1109                    .iter()
1110                    .enumerate()
1111                    .filter_map(|(i, s)| {
1112                        if let uni_cypher::ast::Expr::Variable(var) = s {
1113                            let right_col = entry
1114                                .schema
1115                                .fields()
1116                                .get(i)
1117                                .map(|f| f.name().clone())
1118                                .unwrap_or_else(|| var.clone());
1119                            // After LocyProject the subject column is renamed to the yield
1120                            // column name (just `var`, not `var._vid`). Use bare var as left.
1121                            Some((var.clone(), right_col))
1122                        } else {
1123                            None
1124                        }
1125                    })
1126                    .collect();
1127                // Include target variable in anti-join for composite-key IS NOT.
1128                // Without this, `d IS NOT known TO dis` only checks d, not (d, dis),
1129                // filtering ALL pairs where the drug has ANY indication regardless
1130                // of disease.
1131                if let Some(uni_cypher::ast::Expr::Variable(target_var)) = &is_ref.target {
1132                    let target_idx = is_ref.subjects.len();
1133                    if let Some(field) = entry.schema.fields().get(target_idx) {
1134                        cols.push((target_var.clone(), field.name().clone()));
1135                    }
1136                }
1137                cols
1138            } else {
1139                Vec::new()
1140            };
1141
1142            // Provenance join cols: for ALL IS-refs (not just negated), compute
1143            // (body_col, derived_col) pairs so shared-proof detection can trace
1144            // which source facts contributed to each derived row.
1145            let provenance_join_cols: Vec<(String, String)> = is_ref
1146                .subjects
1147                .iter()
1148                .enumerate()
1149                .filter_map(|(i, s)| {
1150                    if let uni_cypher::ast::Expr::Variable(var) = s {
1151                        let right_col = entry
1152                            .schema
1153                            .fields()
1154                            .get(i)
1155                            .map(|f| f.name().clone())
1156                            .unwrap_or_else(|| var.clone());
1157                        Some((var.clone(), right_col))
1158                    } else {
1159                        None
1160                    }
1161                })
1162                .collect();
1163
1164            Ok(IsRefBinding {
1165                derived_scan_index: entry.scan_index,
1166                rule_name: is_ref.rule_name.clone(),
1167                is_self_ref: entry.is_self_ref,
1168                negated: is_ref.negated,
1169                anti_join_cols,
1170                target_has_prob: is_ref.target_has_prob,
1171                target_prob_col: is_ref.target_prob_col.clone(),
1172                provenance_join_cols,
1173            })
1174        })
1175        .collect()
1176}
1177
1178/// Convert fold binding expressions to physical `FoldBinding`.
1179///
1180/// The input column is looked up by the fold binding's output name (e.g., "total")
1181/// in the yield schema, since the LocyProject aliases the aggregate input expression
1182/// to the fold output name.
1183fn convert_fold_bindings(
1184    fold_bindings: &[(String, String, Expr)],
1185    yield_schema: &[LocyYieldColumn],
1186) -> DFResult<Vec<FoldBinding>> {
1187    fold_bindings
1188        .iter()
1189        .map(|(name, yield_alias, expr)| {
1190            let (kind, _input_col_name) = parse_fold_aggregate(expr)?;
1191
1192            // CountAll has no input column — LocyProject skips the output column
1193            // entirely, so there is nothing to look up.
1194            if kind == FoldAggKind::CountAll {
1195                return Ok(FoldBinding {
1196                    output_name: yield_alias.clone(),
1197                    kind,
1198                    input_col_index: 0, // unused for CountAll
1199                    input_col_name: None,
1200                });
1201            }
1202
1203            // The LocyProject projects the aggregate input expression AS the fold
1204            // output name, so the input column index matches the yield schema position.
1205            // Also store the column name for name-based resolution at execution time
1206            // (more robust when schema reconciliation changes column ordering).
1207            let input_col_index = yield_schema
1208                .iter()
1209                .position(|yc| yc.name == *name || yc.name == *yield_alias)
1210                .unwrap_or(0);
1211            Ok(FoldBinding {
1212                output_name: yield_alias.clone(),
1213                kind,
1214                input_col_index,
1215                input_col_name: Some(name.clone()),
1216            })
1217        })
1218        .collect()
1219}
1220
1221/// Parse a fold aggregate expression into (kind, input_column_name).
1222fn parse_fold_aggregate(expr: &Expr) -> DFResult<(FoldAggKind, String)> {
1223    match expr {
1224        Expr::FunctionCall { name, args, .. } => {
1225            let upper = name.to_uppercase();
1226            let is_count = matches!(upper.as_str(), "COUNT" | "MCOUNT");
1227
1228            // COUNT/MCOUNT with zero args → CountAll (like SQL COUNT(*))
1229            if is_count && args.is_empty() {
1230                return Ok((FoldAggKind::CountAll, String::new()));
1231            }
1232
1233            let kind = match upper.as_str() {
1234                "SUM" | "MSUM" => FoldAggKind::Sum,
1235                "MAX" | "MMAX" => FoldAggKind::Max,
1236                "MIN" | "MMIN" => FoldAggKind::Min,
1237                "COUNT" | "MCOUNT" => FoldAggKind::Count,
1238                "AVG" => FoldAggKind::Avg,
1239                "COLLECT" => FoldAggKind::Collect,
1240                "MNOR" => FoldAggKind::Nor,
1241                "MPROD" => FoldAggKind::Prod,
1242                _ => {
1243                    return Err(datafusion::error::DataFusionError::Plan(format!(
1244                        "Unknown FOLD aggregate function: {}",
1245                        name
1246                    )));
1247                }
1248            };
1249            let col_name = match args.first() {
1250                Some(Expr::Variable(v)) => v.clone(),
1251                Some(Expr::Property(_, prop)) => prop.clone(),
1252                Some(other) => other.to_string_repr(),
1253                None => {
1254                    return Err(datafusion::error::DataFusionError::Plan(
1255                        "FOLD aggregate function requires at least one argument".to_string(),
1256                    ));
1257                }
1258            };
1259            Ok((kind, col_name))
1260        }
1261        _ => Err(datafusion::error::DataFusionError::Plan(
1262            "FOLD binding must be a function call (e.g., SUM(x))".to_string(),
1263        )),
1264    }
1265}
1266
1267/// Convert best-by criteria expressions to physical `SortCriterion`.
1268///
1269/// Resolves the criteria column by trying:
1270/// 1. Property name (e.g., `e.cost` → "cost")
1271/// 2. Variable name (e.g., `cost`)
1272/// 3. Full expression string (e.g., "e.cost" as a variable name)
1273fn convert_best_by_criteria(
1274    criteria: &[(Expr, bool)],
1275    yield_schema: &[LocyYieldColumn],
1276) -> DFResult<Vec<SortCriterion>> {
1277    criteria
1278        .iter()
1279        .map(|(expr, ascending)| {
1280            let col_name = match expr {
1281                Expr::Property(_, prop) => prop.clone(),
1282                Expr::Variable(v) => v.clone(),
1283                _ => {
1284                    return Err(datafusion::error::DataFusionError::Plan(
1285                        "BEST BY criterion must be a variable or property reference".to_string(),
1286                    ));
1287                }
1288            };
1289            // Try exact match first, then try just the last component after '.'
1290            let col_index = yield_schema
1291                .iter()
1292                .position(|yc| yc.name == col_name)
1293                .or_else(|| {
1294                    let short_name = col_name.rsplit('.').next().unwrap_or(&col_name);
1295                    yield_schema.iter().position(|yc| yc.name == short_name)
1296                })
1297                .ok_or_else(|| {
1298                    datafusion::error::DataFusionError::Plan(format!(
1299                        "BEST BY column '{}' not found in yield schema",
1300                        col_name
1301                    ))
1302                })?;
1303            Ok(SortCriterion {
1304                col_index,
1305                ascending: *ascending,
1306                nulls_first: false,
1307            })
1308        })
1309        .collect()
1310}
1311
1312// ---------------------------------------------------------------------------
1313// Schema helpers
1314// ---------------------------------------------------------------------------
1315
1316/// Convert `LocyYieldColumn` slice to Arrow schema using inferred types.
1317fn yield_columns_to_arrow_schema(columns: &[LocyYieldColumn]) -> ArrowSchema {
1318    let fields: Vec<Arc<Field>> = columns
1319        .iter()
1320        .map(|yc| Arc::new(Field::new(&yc.name, yc.data_type.clone(), true)))
1321        .collect();
1322    ArrowSchema::new(fields)
1323}
1324
1325/// Build a combined output schema for fixpoint (union of all rules' schemas).
1326fn build_fixpoint_output_schema(rules: &[LocyRulePlan]) -> SchemaRef {
1327    // FixpointExec concatenates all rules' output, using the first rule's schema
1328    // as the output schema (all rules in a recursive stratum share compatible schemas).
1329    if let Some(rule) = rules.first() {
1330        Arc::new(yield_columns_to_arrow_schema(&rule.yield_schema))
1331    } else {
1332        Arc::new(ArrowSchema::empty())
1333    }
1334}
1335
1336/// Build a stats RecordBatch summarizing derived relation counts.
1337fn build_stats_batch(
1338    derived_store: &DerivedStore,
1339    _strata: &[LocyStratum],
1340    output_schema: SchemaRef,
1341) -> RecordBatch {
1342    // Build a simple stats batch with rule_name and fact_count columns
1343    let mut rule_names: Vec<String> = derived_store.rule_names().map(String::from).collect();
1344    rule_names.sort();
1345
1346    let name_col: arrow_array::StringArray = rule_names.iter().map(|s| Some(s.as_str())).collect();
1347    let count_col: arrow_array::Int64Array = rule_names
1348        .iter()
1349        .map(|name| Some(derived_store.fact_count(name) as i64))
1350        .collect();
1351
1352    let stats_schema = stats_schema();
1353    RecordBatch::try_new(stats_schema, vec![Arc::new(name_col), Arc::new(count_col)])
1354        .unwrap_or_else(|_| RecordBatch::new_empty(output_schema))
1355}
1356
1357/// Schema for the stats batch returned when no commands are present.
1358pub fn stats_schema() -> SchemaRef {
1359    Arc::new(ArrowSchema::new(vec![
1360        Arc::new(Field::new("rule_name", DataType::Utf8, false)),
1361        Arc::new(Field::new("fact_count", DataType::Int64, false)),
1362    ]))
1363}
1364
1365// ---------------------------------------------------------------------------
1366// Unit tests
1367// ---------------------------------------------------------------------------
1368
1369#[cfg(test)]
1370mod tests {
1371    use super::*;
1372    use arrow_array::{Int64Array, LargeBinaryArray, StringArray};
1373
1374    #[test]
1375    fn test_derived_store_insert_and_get() {
1376        let mut store = DerivedStore::new();
1377        assert!(store.get("test").is_none());
1378
1379        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
1380            "x",
1381            DataType::LargeBinary,
1382            true,
1383        ))]));
1384        let batch = RecordBatch::try_new(
1385            Arc::clone(&schema),
1386            vec![Arc::new(LargeBinaryArray::from(vec![
1387                Some(b"a" as &[u8]),
1388                Some(b"b"),
1389            ]))],
1390        )
1391        .unwrap();
1392
1393        store.insert("test".to_string(), vec![batch.clone()]);
1394
1395        let facts = store.get("test").unwrap();
1396        assert_eq!(facts.len(), 1);
1397        assert_eq!(facts[0].num_rows(), 2);
1398    }
1399
1400    #[test]
1401    fn test_derived_store_fact_count() {
1402        let mut store = DerivedStore::new();
1403        assert_eq!(store.fact_count("empty"), 0);
1404
1405        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
1406            "x",
1407            DataType::LargeBinary,
1408            true,
1409        ))]));
1410        let batch1 = RecordBatch::try_new(
1411            Arc::clone(&schema),
1412            vec![Arc::new(LargeBinaryArray::from(vec![Some(b"a" as &[u8])]))],
1413        )
1414        .unwrap();
1415        let batch2 = RecordBatch::try_new(
1416            Arc::clone(&schema),
1417            vec![Arc::new(LargeBinaryArray::from(vec![
1418                Some(b"b" as &[u8]),
1419                Some(b"c"),
1420            ]))],
1421        )
1422        .unwrap();
1423
1424        store.insert("test".to_string(), vec![batch1, batch2]);
1425        assert_eq!(store.fact_count("test"), 3);
1426    }
1427
1428    #[test]
1429    fn test_stats_batch_schema() {
1430        let schema = stats_schema();
1431        assert_eq!(schema.fields().len(), 2);
1432        assert_eq!(schema.field(0).name(), "rule_name");
1433        assert_eq!(schema.field(1).name(), "fact_count");
1434        assert_eq!(schema.field(0).data_type(), &DataType::Utf8);
1435        assert_eq!(schema.field(1).data_type(), &DataType::Int64);
1436    }
1437
1438    #[test]
1439    fn test_stats_batch_content() {
1440        let mut store = DerivedStore::new();
1441        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
1442            "x",
1443            DataType::LargeBinary,
1444            true,
1445        ))]));
1446        let batch = RecordBatch::try_new(
1447            Arc::clone(&schema),
1448            vec![Arc::new(LargeBinaryArray::from(vec![
1449                Some(b"a" as &[u8]),
1450                Some(b"b"),
1451            ]))],
1452        )
1453        .unwrap();
1454        store.insert("reach".to_string(), vec![batch]);
1455
1456        let output_schema = stats_schema();
1457        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
1458        assert_eq!(stats.num_rows(), 1);
1459
1460        let names = stats
1461            .column(0)
1462            .as_any()
1463            .downcast_ref::<StringArray>()
1464            .unwrap();
1465        assert_eq!(names.value(0), "reach");
1466
1467        let counts = stats
1468            .column(1)
1469            .as_any()
1470            .downcast_ref::<Int64Array>()
1471            .unwrap();
1472        assert_eq!(counts.value(0), 2);
1473    }
1474
1475    #[test]
1476    fn test_yield_columns_to_arrow_schema() {
1477        let columns = vec![
1478            LocyYieldColumn {
1479                name: "a".to_string(),
1480                is_key: true,
1481                is_prob: false,
1482                data_type: DataType::UInt64,
1483            },
1484            LocyYieldColumn {
1485                name: "b".to_string(),
1486                is_key: false,
1487                is_prob: false,
1488                data_type: DataType::LargeUtf8,
1489            },
1490            LocyYieldColumn {
1491                name: "c".to_string(),
1492                is_key: true,
1493                is_prob: false,
1494                data_type: DataType::Float64,
1495            },
1496        ];
1497
1498        let schema = yield_columns_to_arrow_schema(&columns);
1499        assert_eq!(schema.fields().len(), 3);
1500        assert_eq!(schema.field(0).name(), "a");
1501        assert_eq!(schema.field(1).name(), "b");
1502        assert_eq!(schema.field(2).name(), "c");
1503        // Fields use inferred types
1504        assert_eq!(schema.field(0).data_type(), &DataType::UInt64);
1505        assert_eq!(schema.field(1).data_type(), &DataType::LargeUtf8);
1506        assert_eq!(schema.field(2).data_type(), &DataType::Float64);
1507        for field in schema.fields() {
1508            assert!(field.is_nullable());
1509        }
1510    }
1511
1512    #[test]
1513    fn test_key_column_indices() {
1514        let columns = [
1515            LocyYieldColumn {
1516                name: "a".to_string(),
1517                is_key: true,
1518                is_prob: false,
1519                data_type: DataType::LargeBinary,
1520            },
1521            LocyYieldColumn {
1522                name: "b".to_string(),
1523                is_key: false,
1524                is_prob: false,
1525                data_type: DataType::LargeBinary,
1526            },
1527            LocyYieldColumn {
1528                name: "c".to_string(),
1529                is_key: true,
1530                is_prob: false,
1531                data_type: DataType::LargeBinary,
1532            },
1533        ];
1534
1535        let key_indices: Vec<usize> = columns
1536            .iter()
1537            .enumerate()
1538            .filter(|(_, yc)| yc.is_key)
1539            .map(|(i, _)| i)
1540            .collect();
1541        assert_eq!(key_indices, vec![0, 2]);
1542    }
1543
1544    #[test]
1545    fn test_parse_fold_aggregate_sum() {
1546        let expr = Expr::FunctionCall {
1547            name: "SUM".to_string(),
1548            args: vec![Expr::Variable("cost".to_string())],
1549            distinct: false,
1550            window_spec: None,
1551        };
1552        let (kind, col) = parse_fold_aggregate(&expr).unwrap();
1553        assert!(matches!(kind, FoldAggKind::Sum));
1554        assert_eq!(col, "cost");
1555    }
1556
1557    #[test]
1558    fn test_parse_fold_aggregate_monotonic() {
1559        let expr = Expr::FunctionCall {
1560            name: "MMAX".to_string(),
1561            args: vec![Expr::Variable("score".to_string())],
1562            distinct: false,
1563            window_spec: None,
1564        };
1565        let (kind, col) = parse_fold_aggregate(&expr).unwrap();
1566        assert!(matches!(kind, FoldAggKind::Max));
1567        assert_eq!(col, "score");
1568    }
1569
1570    #[test]
1571    fn test_parse_fold_aggregate_unknown() {
1572        let expr = Expr::FunctionCall {
1573            name: "UNKNOWN_AGG".to_string(),
1574            args: vec![Expr::Variable("x".to_string())],
1575            distinct: false,
1576            window_spec: None,
1577        };
1578        assert!(parse_fold_aggregate(&expr).is_err());
1579    }
1580
1581    #[test]
1582    fn test_no_commands_returns_stats() {
1583        let store = DerivedStore::new();
1584        let output_schema = stats_schema();
1585        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
1586        // Empty store → 0 rows
1587        assert_eq!(stats.num_rows(), 0);
1588    }
1589}