Skip to main content

uni_query/query/df_graph/
locy_program.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright 2024-2026 Dragonscale Team

//! Top-level Locy program executor.
//!
//! `LocyProgramExec` orchestrates the full evaluation of a Locy program:
//! it evaluates strata in dependency order, runs fixpoint iteration for recursive
//! strata, applies post-fixpoint operators (FOLD, PRIORITY, BEST BY), and then
//! executes the program's commands (goal queries, DERIVE, ASSUME, etc.).
10
11use crate::query::df_graph::GraphExecutionContext;
12use crate::query::df_graph::common::{
13    collect_all_partitions, compute_plan_properties, execute_subplan,
14};
15use crate::query::df_graph::locy_best_by::SortCriterion;
16use crate::query::df_graph::locy_explain::ProvenanceStore;
17use crate::query::df_graph::locy_fixpoint::{
18    DerivedScanRegistry, FixpointClausePlan, FixpointExec, FixpointRulePlan, IsRefBinding,
19};
20use crate::query::df_graph::locy_fold::{FoldAggKind, FoldBinding};
21use crate::query::planner_locy_types::{
22    LocyCommand, LocyIsRef, LocyRulePlan, LocyStratum, LocyYieldColumn,
23};
24use arrow_array::RecordBatch;
25use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
29use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
30use futures::Stream;
31use parking_lot::RwLock;
32use std::any::Any;
33use std::collections::HashMap;
34use std::fmt;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::sync::RwLock as StdRwLock;
38use std::task::{Context, Poll};
39use std::time::{Duration, Instant};
40use uni_common::Value;
41use uni_common::core::schema::Schema as UniSchema;
42use uni_cypher::ast::Expr;
43use uni_cypher::locy_ast::GoalQuery;
44use uni_locy::{CommandResult, FactRow, RuntimeWarning};
45use uni_store::storage::manager::StorageManager;
46
// ---------------------------------------------------------------------------
// DerivedStore — cross-stratum fact sharing
// ---------------------------------------------------------------------------
50
51/// Simple store for derived relation facts across strata.
52///
53/// Each rule's converged facts are stored here after its stratum completes,
54/// making them available for later strata that depend on them.
55pub struct DerivedStore {
56    relations: HashMap<String, Vec<RecordBatch>>,
57}
58
59impl Default for DerivedStore {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65impl DerivedStore {
66    pub fn new() -> Self {
67        Self {
68            relations: HashMap::new(),
69        }
70    }
71
72    pub fn insert(&mut self, rule_name: String, facts: Vec<RecordBatch>) {
73        self.relations.insert(rule_name, facts);
74    }
75
76    pub fn get(&self, rule_name: &str) -> Option<&Vec<RecordBatch>> {
77        self.relations.get(rule_name)
78    }
79
80    pub fn fact_count(&self, rule_name: &str) -> usize {
81        self.relations
82            .get(rule_name)
83            .map(|batches| batches.iter().map(|b| b.num_rows()).sum())
84            .unwrap_or(0)
85    }
86
87    pub fn rule_names(&self) -> impl Iterator<Item = &str> {
88        self.relations.keys().map(|s| s.as_str())
89    }
90}
91
// ---------------------------------------------------------------------------
// LocyProgramExec — DataFusion ExecutionPlan
// ---------------------------------------------------------------------------
95
96/// DataFusion `ExecutionPlan` that runs an entire Locy program.
97///
98/// Evaluates strata in dependency order, using `FixpointExec` for recursive
99/// strata and direct subplan execution for non-recursive ones. After all
100/// strata converge, dispatches commands.
101pub struct LocyProgramExec {
102    strata: Vec<LocyStratum>,
103    commands: Vec<LocyCommand>,
104    derived_scan_registry: Arc<DerivedScanRegistry>,
105    graph_ctx: Arc<GraphExecutionContext>,
106    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
107    storage: Arc<StorageManager>,
108    schema_info: Arc<UniSchema>,
109    params: HashMap<String, Value>,
110    output_schema: SchemaRef,
111    properties: PlanProperties,
112    metrics: ExecutionPlanMetricsSet,
113    max_iterations: usize,
114    timeout: Duration,
115    max_derived_bytes: usize,
116    deterministic_best_by: bool,
117    strict_probability_domain: bool,
118    probability_epsilon: f64,
119    exact_probability: bool,
120    max_bdd_variables: usize,
121    /// Shared slot for extracting the DerivedStore after execution completes.
122    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
123    /// Shared slot for groups where BDD fell back to independence mode.
124    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
125    /// Optional provenance tracker injected after construction (via `set_derivation_tracker`).
126    derivation_tracker: Arc<StdRwLock<Option<Arc<ProvenanceStore>>>>,
127    /// Shared slot written with per-rule iteration counts after fixpoint convergence.
128    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
129    /// Shared slot written with peak memory bytes after fixpoint completes.
130    peak_memory_slot: Arc<StdRwLock<usize>>,
131    /// Shared slot for runtime warnings collected during evaluation.
132    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
133    /// Shared slot for inline command results (QUERY, Cypher) executed inside `run_program()`.
134    command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
135    /// Top-k proof filtering: 0 = unlimited (default), >0 = retain at most k proofs per fact.
136    top_k_proofs: usize,
137}
138
139impl fmt::Debug for LocyProgramExec {
140    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141        f.debug_struct("LocyProgramExec")
142            .field("strata_count", &self.strata.len())
143            .field("commands_count", &self.commands.len())
144            .field("max_iterations", &self.max_iterations)
145            .field("timeout", &self.timeout)
146            .field("output_schema", &self.output_schema)
147            .field("max_derived_bytes", &self.max_derived_bytes)
148            .finish_non_exhaustive()
149    }
150}
151
152impl LocyProgramExec {
153    #[expect(
154        clippy::too_many_arguments,
155        reason = "execution plan node requires full graph and session context"
156    )]
157    pub fn new(
158        strata: Vec<LocyStratum>,
159        commands: Vec<LocyCommand>,
160        derived_scan_registry: Arc<DerivedScanRegistry>,
161        graph_ctx: Arc<GraphExecutionContext>,
162        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
163        storage: Arc<StorageManager>,
164        schema_info: Arc<UniSchema>,
165        params: HashMap<String, Value>,
166        output_schema: SchemaRef,
167        max_iterations: usize,
168        timeout: Duration,
169        max_derived_bytes: usize,
170        deterministic_best_by: bool,
171        strict_probability_domain: bool,
172        probability_epsilon: f64,
173        exact_probability: bool,
174        max_bdd_variables: usize,
175        top_k_proofs: usize,
176    ) -> Self {
177        let properties = compute_plan_properties(Arc::clone(&output_schema));
178        Self {
179            strata,
180            commands,
181            derived_scan_registry,
182            graph_ctx,
183            session_ctx,
184            storage,
185            schema_info,
186            params,
187            output_schema,
188            properties,
189            metrics: ExecutionPlanMetricsSet::new(),
190            max_iterations,
191            timeout,
192            max_derived_bytes,
193            deterministic_best_by,
194            strict_probability_domain,
195            probability_epsilon,
196            exact_probability,
197            max_bdd_variables,
198            derived_store_slot: Arc::new(StdRwLock::new(None)),
199            approximate_slot: Arc::new(StdRwLock::new(HashMap::new())),
200            derivation_tracker: Arc::new(StdRwLock::new(None)),
201            iteration_counts_slot: Arc::new(StdRwLock::new(HashMap::new())),
202            peak_memory_slot: Arc::new(StdRwLock::new(0)),
203            warnings_slot: Arc::new(StdRwLock::new(Vec::new())),
204            command_results_slot: Arc::new(StdRwLock::new(Vec::new())),
205            top_k_proofs,
206        }
207    }
208
209    /// Returns a shared handle to the derived store slot.
210    ///
211    /// After execution completes, the slot contains the `DerivedStore` with all
212    /// converged facts. Read it with `slot.read().unwrap()`.
213    pub fn derived_store_slot(&self) -> Arc<StdRwLock<Option<DerivedStore>>> {
214        Arc::clone(&self.derived_store_slot)
215    }
216
217    /// Inject a `ProvenanceStore` to record provenance during fixpoint iteration.
218    ///
219    /// Must be called before `execute()` is invoked (i.e., before DataFusion runs
220    /// the physical plan). Uses interior mutability so it works through `&self`.
221    pub fn set_derivation_tracker(&self, tracker: Arc<ProvenanceStore>) {
222        if let Ok(mut guard) = self.derivation_tracker.write() {
223            *guard = Some(tracker);
224        }
225    }
226
227    /// Returns the shared iteration counts slot.
228    ///
229    /// After execution, the slot contains per-rule iteration counts from the
230    /// most recent fixpoint convergence. Sum the values for `total_iterations`.
231    pub fn iteration_counts_slot(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
232        Arc::clone(&self.iteration_counts_slot)
233    }
234
235    /// Returns the shared peak memory slot.
236    ///
237    /// After execution, the slot contains the peak byte count of derived facts
238    /// across all strata. Read it with `slot.read().unwrap()`.
239    pub fn peak_memory_slot(&self) -> Arc<StdRwLock<usize>> {
240        Arc::clone(&self.peak_memory_slot)
241    }
242
243    /// Returns the shared runtime warnings slot.
244    ///
245    /// After execution, the slot contains warnings collected during fixpoint
246    /// iteration (e.g. shared probabilistic dependencies).
247    pub fn warnings_slot(&self) -> Arc<StdRwLock<Vec<RuntimeWarning>>> {
248        Arc::clone(&self.warnings_slot)
249    }
250
251    /// Returns the shared approximate groups slot.
252    ///
253    /// After execution, the slot contains rule→key group descriptions for
254    /// groups where BDD computation fell back to independence mode.
255    pub fn approximate_slot(&self) -> Arc<StdRwLock<HashMap<String, Vec<String>>>> {
256        Arc::clone(&self.approximate_slot)
257    }
258
259    /// Returns the shared command results slot.
260    ///
261    /// After execution, the slot contains `(command_index, CommandResult)` pairs
262    /// for commands that were executed inline by `run_program()` (QUERY, Cypher).
263    pub fn command_results_slot(&self) -> Arc<StdRwLock<Vec<(usize, CommandResult)>>> {
264        Arc::clone(&self.command_results_slot)
265    }
266}
267
268impl DisplayAs for LocyProgramExec {
269    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270        write!(
271            f,
272            "LocyProgramExec: strata={}, commands={}, max_iter={}, timeout={:?}",
273            self.strata.len(),
274            self.commands.len(),
275            self.max_iterations,
276            self.timeout,
277        )
278    }
279}
280
281impl ExecutionPlan for LocyProgramExec {
282    fn name(&self) -> &str {
283        "LocyProgramExec"
284    }
285
286    fn as_any(&self) -> &dyn Any {
287        self
288    }
289
290    fn schema(&self) -> SchemaRef {
291        Arc::clone(&self.output_schema)
292    }
293
294    fn properties(&self) -> &PlanProperties {
295        &self.properties
296    }
297
298    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
299        vec![]
300    }
301
302    fn with_new_children(
303        self: Arc<Self>,
304        children: Vec<Arc<dyn ExecutionPlan>>,
305    ) -> DFResult<Arc<dyn ExecutionPlan>> {
306        if !children.is_empty() {
307            return Err(datafusion::error::DataFusionError::Plan(
308                "LocyProgramExec has no children".to_string(),
309            ));
310        }
311        Ok(self)
312    }
313
314    fn execute(
315        &self,
316        partition: usize,
317        _context: Arc<TaskContext>,
318    ) -> DFResult<SendableRecordBatchStream> {
319        let metrics = BaselineMetrics::new(&self.metrics, partition);
320
321        let strata = self.strata.clone();
322        let registry = Arc::clone(&self.derived_scan_registry);
323        let graph_ctx = Arc::clone(&self.graph_ctx);
324        let session_ctx = Arc::clone(&self.session_ctx);
325        let storage = Arc::clone(&self.storage);
326        let schema_info = Arc::clone(&self.schema_info);
327        let params = self.params.clone();
328        let output_schema = Arc::clone(&self.output_schema);
329        let max_iterations = self.max_iterations;
330        let timeout = self.timeout;
331        let max_derived_bytes = self.max_derived_bytes;
332        let deterministic_best_by = self.deterministic_best_by;
333        let strict_probability_domain = self.strict_probability_domain;
334        let probability_epsilon = self.probability_epsilon;
335        let exact_probability = self.exact_probability;
336        let max_bdd_variables = self.max_bdd_variables;
337        let derived_store_slot = Arc::clone(&self.derived_store_slot);
338        let approximate_slot = Arc::clone(&self.approximate_slot);
339        let iteration_counts_slot = Arc::clone(&self.iteration_counts_slot);
340        let peak_memory_slot = Arc::clone(&self.peak_memory_slot);
341        let derivation_tracker = self.derivation_tracker.read().ok().and_then(|g| g.clone());
342        let warnings_slot = Arc::clone(&self.warnings_slot);
343        let commands = self.commands.clone();
344        let command_results_slot = Arc::clone(&self.command_results_slot);
345        let top_k_proofs = self.top_k_proofs;
346
347        let fut = async move {
348            run_program(
349                strata,
350                commands,
351                registry,
352                graph_ctx,
353                session_ctx,
354                storage,
355                schema_info,
356                params,
357                output_schema,
358                max_iterations,
359                timeout,
360                max_derived_bytes,
361                deterministic_best_by,
362                strict_probability_domain,
363                probability_epsilon,
364                exact_probability,
365                max_bdd_variables,
366                derived_store_slot,
367                approximate_slot,
368                iteration_counts_slot,
369                peak_memory_slot,
370                derivation_tracker,
371                warnings_slot,
372                command_results_slot,
373                top_k_proofs,
374            )
375            .await
376        };
377
378        Ok(Box::pin(ProgramStream {
379            state: ProgramStreamState::Running(Box::pin(fut)),
380            schema: Arc::clone(&self.output_schema),
381            metrics,
382        }))
383    }
384
385    fn metrics(&self) -> Option<MetricsSet> {
386        Some(self.metrics.clone_inner())
387    }
388}
389
// ---------------------------------------------------------------------------
// ProgramStream — async state machine for streaming results
// ---------------------------------------------------------------------------
393
394enum ProgramStreamState {
395    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
396    Emitting(Vec<RecordBatch>, usize),
397    Done,
398}
399
400struct ProgramStream {
401    state: ProgramStreamState,
402    schema: SchemaRef,
403    metrics: BaselineMetrics,
404}
405
406impl Stream for ProgramStream {
407    type Item = DFResult<RecordBatch>;
408
409    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
410        let this = self.get_mut();
411        loop {
412            match &mut this.state {
413                ProgramStreamState::Running(fut) => match fut.as_mut().poll(cx) {
414                    Poll::Ready(Ok(batches)) => {
415                        if batches.is_empty() {
416                            this.state = ProgramStreamState::Done;
417                            return Poll::Ready(None);
418                        }
419                        this.state = ProgramStreamState::Emitting(batches, 0);
420                    }
421                    Poll::Ready(Err(e)) => {
422                        this.state = ProgramStreamState::Done;
423                        return Poll::Ready(Some(Err(e)));
424                    }
425                    Poll::Pending => return Poll::Pending,
426                },
427                ProgramStreamState::Emitting(batches, idx) => {
428                    if *idx >= batches.len() {
429                        this.state = ProgramStreamState::Done;
430                        return Poll::Ready(None);
431                    }
432                    let batch = batches[*idx].clone();
433                    *idx += 1;
434                    this.metrics.record_output(batch.num_rows());
435                    return Poll::Ready(Some(Ok(batch)));
436                }
437                ProgramStreamState::Done => return Poll::Ready(None),
438            }
439        }
440    }
441}
442
443impl RecordBatchStream for ProgramStream {
444    fn schema(&self) -> SchemaRef {
445        Arc::clone(&self.schema)
446    }
447}
448
// ---------------------------------------------------------------------------
// Inline command execution helpers
// ---------------------------------------------------------------------------
452
453/// Execute QUERY by scanning converged DerivedStore directly (no SLG).
454///
455/// Strata have already converged all facts, so we can scan the DerivedStore
456/// without re-derivation. WHERE filtering and RETURN projection are applied
457/// in-memory.
458///
459/// NOTE: Currently unused because the DerivedStore uses inferred Arrow types
460/// (Float64 for all property-derived columns), so string property values are
461/// not preserved. Once `infer_expr_type` is improved to use actual schema types,
462/// this function can be re-enabled for QUERYs whose WHERE/RETURN only reference
463/// reliably-typed columns.
464#[allow(dead_code)]
465fn execute_query_inline(
466    query: &GoalQuery,
467    derived_store: &DerivedStore,
468    params: &HashMap<String, Value>,
469) -> DFResult<Vec<FactRow>> {
470    let rule_name = query.rule_name.to_string();
471    let batches = derived_store.get(&rule_name).cloned().unwrap_or_default();
472    let rows = super::locy_eval::record_batches_to_locy_rows(&batches);
473
474    // Apply WHERE filter
475    let filtered = if let Some(ref where_expr) = query.where_expr {
476        rows.into_iter()
477            .filter(|row| {
478                let merged = super::locy_query::merge_params(row, params);
479                super::locy_eval::eval_expr(where_expr, &merged)
480                    .map(|v| v.as_bool().unwrap_or(false))
481                    .unwrap_or(false)
482            })
483            .collect()
484    } else {
485        rows
486    };
487
488    // Apply RETURN (project, distinct, order, skip, limit)
489    super::locy_query::apply_return_clause(filtered, &query.return_clause, params)
490        .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
491}
492
493/// Execute Cypher passthrough via execute_subplan.
494async fn execute_cypher_inline(
495    query: &uni_cypher::ast::Query,
496    schema_info: &Arc<UniSchema>,
497    params: &HashMap<String, Value>,
498    graph_ctx: &Arc<GraphExecutionContext>,
499    session_ctx: &Arc<RwLock<datafusion::prelude::SessionContext>>,
500    storage: &Arc<StorageManager>,
501) -> DFResult<Vec<FactRow>> {
502    let planner = crate::query::planner::QueryPlanner::new(Arc::clone(schema_info));
503    let logical_plan = planner.plan(query.clone()).map_err(|e| {
504        datafusion::error::DataFusionError::Execution(format!("Cypher plan error: {e}"))
505    })?;
506    let batches = execute_subplan(
507        &logical_plan,
508        params,
509        &HashMap::new(),
510        graph_ctx,
511        session_ctx,
512        storage,
513        schema_info,
514    )
515    .await?;
516    Ok(super::locy_eval::record_batches_to_locy_rows(&batches))
517}
518
519/// Returns true if the WHERE or RETURN expressions contain property access
520/// (e.g. `a.name`) that requires full node objects not available in DerivedStore.
521#[allow(dead_code)]
522fn needs_node_enrichment(query: &GoalQuery) -> bool {
523    let where_has_property = query
524        .where_expr
525        .as_ref()
526        .is_some_and(expr_has_property_access);
527    let return_has_property = query.return_clause.as_ref().is_some_and(|rc| {
528        rc.items.iter().any(|item| match item {
529            uni_cypher::ast::ReturnItem::Expr { expr, .. } => expr_has_property_access(expr),
530            uni_cypher::ast::ReturnItem::All => false,
531        })
532    });
533    where_has_property || return_has_property
534}
535
536/// Recursively check whether an expression contains property access (`a.name`).
537#[allow(dead_code)]
538fn expr_has_property_access(expr: &Expr) -> bool {
539    match expr {
540        Expr::Property(..) => true,
541        Expr::BinaryOp { left, right, .. } => {
542            expr_has_property_access(left) || expr_has_property_access(right)
543        }
544        Expr::UnaryOp { expr, .. } => expr_has_property_access(expr),
545        Expr::FunctionCall { args, .. } => args.iter().any(expr_has_property_access),
546        Expr::List(items) => items.iter().any(expr_has_property_access),
547        Expr::Map(entries) => entries.iter().any(|(_, e)| expr_has_property_access(e)),
548        Expr::Case {
549            expr: case_expr,
550            when_then,
551            else_expr,
552        } => {
553            case_expr
554                .as_ref()
555                .is_some_and(|e| expr_has_property_access(e))
556                || when_then
557                    .iter()
558                    .any(|(w, t)| expr_has_property_access(w) || expr_has_property_access(t))
559                || else_expr
560                    .as_ref()
561                    .is_some_and(|e| expr_has_property_access(e))
562        }
563        Expr::IsNull(e) | Expr::IsNotNull(e) | Expr::IsUnique(e) => expr_has_property_access(e),
564        Expr::In { expr, list } => expr_has_property_access(expr) || expr_has_property_access(list),
565        Expr::ArrayIndex { array, index } => {
566            expr_has_property_access(array) || expr_has_property_access(index)
567        }
568        Expr::ArraySlice { array, start, end } => {
569            expr_has_property_access(array)
570                || start.as_ref().is_some_and(|e| expr_has_property_access(e))
571                || end.as_ref().is_some_and(|e| expr_has_property_access(e))
572        }
573        Expr::Quantifier {
574            list, predicate, ..
575        } => expr_has_property_access(list) || expr_has_property_access(predicate),
576        Expr::Reduce {
577            init, list, expr, ..
578        } => {
579            expr_has_property_access(init)
580                || expr_has_property_access(list)
581                || expr_has_property_access(expr)
582        }
583        Expr::ListComprehension {
584            list,
585            where_clause,
586            map_expr,
587            ..
588        } => {
589            expr_has_property_access(list)
590                || where_clause
591                    .as_ref()
592                    .is_some_and(|e| expr_has_property_access(e))
593                || expr_has_property_access(map_expr)
594        }
595        Expr::PatternComprehension {
596            where_clause,
597            map_expr,
598            ..
599        } => {
600            where_clause
601                .as_ref()
602                .is_some_and(|e| expr_has_property_access(e))
603                || expr_has_property_access(map_expr)
604        }
605        Expr::ValidAt {
606            entity, timestamp, ..
607        } => expr_has_property_access(entity) || expr_has_property_access(timestamp),
608        Expr::MapProjection { base, .. } => expr_has_property_access(base),
609        Expr::LabelCheck { expr, .. } => expr_has_property_access(expr),
610        // Leaf nodes (Literal, Parameter, Variable, Wildcard) and subqueries
611        // (Exists, CountSubquery, CollectSubquery) — no property access.
612        _ => false,
613    }
614}
615
// ---------------------------------------------------------------------------
// run_program — core evaluation algorithm
// ---------------------------------------------------------------------------
619
620#[expect(
621    clippy::too_many_arguments,
622    reason = "program evaluation requires full graph and session context"
623)]
624async fn run_program(
625    strata: Vec<LocyStratum>,
626    commands: Vec<LocyCommand>,
627    registry: Arc<DerivedScanRegistry>,
628    graph_ctx: Arc<GraphExecutionContext>,
629    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
630    storage: Arc<StorageManager>,
631    schema_info: Arc<UniSchema>,
632    params: HashMap<String, Value>,
633    output_schema: SchemaRef,
634    max_iterations: usize,
635    timeout: Duration,
636    max_derived_bytes: usize,
637    deterministic_best_by: bool,
638    strict_probability_domain: bool,
639    probability_epsilon: f64,
640    exact_probability: bool,
641    max_bdd_variables: usize,
642    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
643    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
644    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
645    peak_memory_slot: Arc<StdRwLock<usize>>,
646    derivation_tracker: Option<Arc<ProvenanceStore>>,
647    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
648    command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
649    top_k_proofs: usize,
650) -> DFResult<Vec<RecordBatch>> {
651    let start = Instant::now();
652    let mut derived_store = DerivedStore::new();
653
654    // Evaluate each stratum in topological order
655    for stratum in &strata {
656        // Write cross-stratum facts into registry handles for strata we depend on
657        write_cross_stratum_facts(&registry, &derived_store, stratum);
658
659        let remaining_timeout = timeout.saturating_sub(start.elapsed());
660        if remaining_timeout.is_zero() {
661            return Err(datafusion::error::DataFusionError::Execution(
662                "Locy program timeout exceeded during stratum evaluation".to_string(),
663            ));
664        }
665
666        if stratum.is_recursive {
667            // Convert LocyRulePlan → FixpointRulePlan and run fixpoint
668            let fixpoint_rules =
669                convert_to_fixpoint_plans(&stratum.rules, &registry, deterministic_best_by)?;
670            let fixpoint_schema = build_fixpoint_output_schema(&stratum.rules);
671
672            let exec = FixpointExec::new(
673                fixpoint_rules,
674                max_iterations,
675                remaining_timeout,
676                Arc::clone(&graph_ctx),
677                Arc::clone(&session_ctx),
678                Arc::clone(&storage),
679                Arc::clone(&schema_info),
680                params.clone(),
681                Arc::clone(&registry),
682                fixpoint_schema,
683                max_derived_bytes,
684                derivation_tracker.clone(),
685                Arc::clone(&iteration_counts_slot),
686                strict_probability_domain,
687                probability_epsilon,
688                exact_probability,
689                max_bdd_variables,
690                Arc::clone(&warnings_slot),
691                Arc::clone(&approximate_slot),
692                top_k_proofs,
693            );
694
695            let task_ctx = session_ctx.read().task_ctx();
696            let exec_arc: Arc<dyn ExecutionPlan> = Arc::new(exec);
697            let batches = collect_all_partitions(&exec_arc, task_ctx).await?;
698
699            // FixpointExec concatenates all rules' output; store per-rule.
700            // For now, store all output under each rule name (since FixpointExec
701            // handles per-rule state internally, the output is already correct).
702            // NOTE(deferred): Per-rule fact demultiplexing is not yet implemented.
703            // FixpointExec concatenates all rules' output into a single batch stream.
704            // Proper demux requires FixpointExec to tag output batches with rule identity
705            // (e.g. an extra column or side-channel), which is a non-trivial change to
706            // run_fixpoint_loop. The current schema-field-count heuristic (filter below)
707            // works because recursive stratum rules share compatible schemas.
708            // Revisit when cross-stratum consumption of individual recursive rules is needed.
709            for rule in &stratum.rules {
710                // Skip DERIVE-only rules (empty yield_schema).
711                if rule.yield_schema.is_empty() {
712                    continue;
713                }
714                // Write converged facts into registry handles for cross-stratum consumers
715                let rule_entries = registry.entries_for_rule(&rule.name);
716                for entry in rule_entries {
717                    if !entry.is_self_ref {
718                        // Cross-stratum handles get the full fixpoint output
719                        // In practice, FixpointExec already wrote self-ref handles;
720                        // we need to write non-self-ref handles for later strata.
721                        let all_facts: Vec<RecordBatch> = batches
722                            .iter()
723                            .filter(|b| {
724                                // If schemas match, this batch belongs to this rule
725                                let rule_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
726                                b.schema().fields().len() == rule_schema.fields().len()
727                            })
728                            .cloned()
729                            .collect();
730                        let mut guard = entry.data.write();
731                        *guard = if all_facts.is_empty() {
732                            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
733                        } else {
734                            all_facts
735                        };
736                    }
737                }
738                derived_store.insert(rule.name.clone(), batches.clone());
739            }
740        } else {
741            // Non-recursive: single-pass evaluation
742            let fixpoint_rules =
743                convert_to_fixpoint_plans(&stratum.rules, &registry, deterministic_best_by)?;
744            let task_ctx = session_ctx.read().task_ctx();
745
746            for (rule, fp_rule) in stratum.rules.iter().zip(fixpoint_rules.iter()) {
747                // DERIVE-only rules have empty yield_schema (the compiler's
748                // infer_yield_schema only matches RuleOutput::Yield). Skip them
749                // in the fixpoint loop — DERIVE materialization is handled by
750                // the DERIVE command dispatch, not by the fixpoint.
751                if rule.yield_schema.is_empty() {
752                    continue;
753                }
754
755                // Process each clause independently (per-clause IS NOT).
756                let mut tagged_clause_facts: Vec<(usize, Vec<RecordBatch>)> = Vec::new();
757                for (clause_idx, (clause, fp_clause)) in
758                    rule.clauses.iter().zip(fp_rule.clauses.iter()).enumerate()
759                {
760                    let mut batches = execute_subplan(
761                        &clause.body,
762                        &params,
763                        &HashMap::new(),
764                        &graph_ctx,
765                        &session_ctx,
766                        &storage,
767                        &schema_info,
768                    )
769                    .await?;
770
771                    // Apply negated IS-ref semantics per-clause.
772                    for binding in &fp_clause.is_ref_bindings {
773                        if binding.negated
774                            && !binding.anti_join_cols.is_empty()
775                            && let Some(entry) = registry.get(binding.derived_scan_index)
776                        {
777                            let neg_facts = entry.data.read().clone();
778                            if !neg_facts.is_empty() {
779                                if binding.target_has_prob && fp_rule.prob_column_name.is_some() {
780                                    let complement_col =
781                                        format!("__prob_complement_{}", binding.rule_name);
782                                    if let Some(prob_col) = &binding.target_prob_col {
783                                        batches =
784                                            super::locy_fixpoint::apply_prob_complement_composite(
785                                                batches,
786                                                &neg_facts,
787                                                &binding.anti_join_cols,
788                                                prob_col,
789                                                &complement_col,
790                                            )?;
791                                    } else {
792                                        // target_has_prob but no prob_col: fall back to anti-join.
793                                        batches = super::locy_fixpoint::apply_anti_join_composite(
794                                            batches,
795                                            &neg_facts,
796                                            &binding.anti_join_cols,
797                                        )?;
798                                    }
799                                } else {
800                                    batches = super::locy_fixpoint::apply_anti_join_composite(
801                                        batches,
802                                        &neg_facts,
803                                        &binding.anti_join_cols,
804                                    )?;
805                                }
806                            }
807                        }
808                    }
809
810                    // Multiply complement columns into PROB per-clause.
811                    let complement_cols: Vec<String> = if !batches.is_empty() {
812                        batches[0]
813                            .schema()
814                            .fields()
815                            .iter()
816                            .filter(|f| f.name().starts_with("__prob_complement_"))
817                            .map(|f| f.name().clone())
818                            .collect()
819                    } else {
820                        vec![]
821                    };
822                    if !complement_cols.is_empty() {
823                        batches = super::locy_fixpoint::multiply_prob_factors(
824                            batches,
825                            fp_rule.prob_column_name.as_deref(),
826                            &complement_cols,
827                        )?;
828                    }
829
830                    tagged_clause_facts.push((clause_idx, batches));
831                }
832
833                // Record provenance and detect shared proofs for non-recursive rules.
834                let shared_info = if let Some(ref tracker) = derivation_tracker {
835                    super::locy_fixpoint::record_and_detect_lineage_nonrecursive(
836                        fp_rule,
837                        &tagged_clause_facts,
838                        tracker,
839                        &warnings_slot,
840                        &registry,
841                        top_k_proofs,
842                    )
843                } else {
844                    None
845                };
846
847                // Flatten tagged facts for post-fixpoint chain.
848                let mut all_clause_facts: Vec<RecordBatch> = tagged_clause_facts
849                    .into_iter()
850                    .flat_map(|(_, batches)| batches)
851                    .collect();
852
853                // Apply BDD for shared groups if exact_probability is enabled.
854                if exact_probability
855                    && let Some(ref info) = shared_info
856                    && let Some(ref tracker) = derivation_tracker
857                {
858                    all_clause_facts = super::locy_fixpoint::apply_exact_wmc(
859                        all_clause_facts,
860                        fp_rule,
861                        info,
862                        tracker,
863                        max_bdd_variables,
864                        &warnings_slot,
865                        &approximate_slot,
866                    )?;
867                }
868
869                // Apply post-fixpoint operators (PRIORITY, FOLD, BEST BY) on union.
870                let facts = super::locy_fixpoint::apply_post_fixpoint_chain(
871                    all_clause_facts,
872                    fp_rule,
873                    &task_ctx,
874                    strict_probability_domain,
875                    probability_epsilon,
876                )
877                .await?;
878
879                // Write facts into registry handles for later strata
880                write_facts_to_registry(&registry, &rule.name, &facts);
881                derived_store.insert(rule.name.clone(), facts);
882            }
883        }
884    }
885
886    // Compute peak memory from derived store byte sizes
887    let peak_bytes: usize = derived_store
888        .relations
889        .values()
890        .flat_map(|batches| batches.iter())
891        .map(|b| {
892            b.columns()
893                .iter()
894                .map(|col| col.get_buffer_memory_size())
895                .sum::<usize>()
896        })
897        .sum();
898    *peak_memory_slot.write().unwrap() = peak_bytes;
899
900    // Execute inline Cypher commands via execute_subplan.
901    // QUERY is deferred to the orchestrator: the DerivedStore uses inferred types
902    // (e.g. Float64 for property-derived columns) which don't preserve the actual
903    // property values. The orchestrator's SLG path re-derives with correct types.
904    // DERIVE/ASSUME/EXPLAIN/ABDUCE are also deferred (need L0 fork/restore, tree output, etc.).
905    //
906    // Cypher commands that appear AFTER a DERIVE command are also deferred:
907    // they need the ephemeral L0 overlay populated by DERIVE to see derived
908    // edges, which is only available in the orchestrator's dispatch loop.
909    let first_derive_idx = commands
910        .iter()
911        .position(|c| matches!(c, LocyCommand::Derive { .. }));
912    let mut inline_results: Vec<(usize, CommandResult)> = Vec::new();
913    for (cmd_idx, cmd) in commands.iter().enumerate() {
914        if let LocyCommand::Cypher { query } = cmd {
915            // Defer Cypher commands that follow a DERIVE to the dispatch loop
916            // so they can read from the ephemeral L0 overlay.
917            if first_derive_idx.is_some_and(|di| cmd_idx > di) {
918                continue;
919            }
920            let rows = execute_cypher_inline(
921                query,
922                &schema_info,
923                &params,
924                &graph_ctx,
925                &session_ctx,
926                &storage,
927            )
928            .await?;
929            inline_results.push((cmd_idx, CommandResult::Cypher(rows)));
930        }
931    }
932    *command_results_slot.write().unwrap() = inline_results;
933
934    let stats = vec![build_stats_batch(&derived_store, &strata, output_schema)];
935    *derived_store_slot.write().unwrap() = Some(derived_store);
936    Ok(stats)
937}
938
939// ---------------------------------------------------------------------------
940// Cross-stratum fact injection
941// ---------------------------------------------------------------------------
942
943/// Write already-evaluated facts into registry handles for cross-stratum IS-refs.
944fn write_cross_stratum_facts(
945    registry: &DerivedScanRegistry,
946    derived_store: &DerivedStore,
947    stratum: &LocyStratum,
948) {
949    // For each rule in this stratum, find IS-refs to rules in other strata
950    for rule in &stratum.rules {
951        for clause in &rule.clauses {
952            for is_ref in &clause.is_refs {
953                // If this IS-ref points to a rule already in the derived store
954                // (i.e., from a previous stratum), write its facts into the registry
955                if let Some(facts) = derived_store.get(&is_ref.rule_name) {
956                    write_facts_to_registry(registry, &is_ref.rule_name, facts);
957                }
958            }
959        }
960    }
961}
962
963/// Write facts into non-self-ref registry handles for a given rule.
964fn write_facts_to_registry(registry: &DerivedScanRegistry, rule_name: &str, facts: &[RecordBatch]) {
965    let entries = registry.entries_for_rule(rule_name);
966    for entry in entries {
967        if !entry.is_self_ref {
968            let mut guard = entry.data.write();
969            *guard = if facts.is_empty() || facts.iter().all(|b| b.num_rows() == 0) {
970                vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
971            } else {
972                // Try to re-wrap batches with the entry's schema for column name
973                // alignment. If the types don't match (e.g. inferred Float64 vs
974                // actual Utf8 from schema mode), fall back to the batch's own
975                // schema to avoid silent data loss.
976                facts
977                    .iter()
978                    .filter(|b| b.num_rows() > 0)
979                    .map(|b| {
980                        RecordBatch::try_new(Arc::clone(&entry.schema), b.columns().to_vec())
981                            .unwrap_or_else(|_| b.clone())
982                    })
983                    .collect()
984            };
985        }
986    }
987}
988
989// ---------------------------------------------------------------------------
990// LocyRulePlan → FixpointRulePlan conversion
991// ---------------------------------------------------------------------------
992
993/// Convert logical `LocyRulePlan` types to physical `FixpointRulePlan` types.
994fn convert_to_fixpoint_plans(
995    rules: &[LocyRulePlan],
996    registry: &DerivedScanRegistry,
997    deterministic_best_by: bool,
998) -> DFResult<Vec<FixpointRulePlan>> {
999    rules
1000        .iter()
1001        .map(|rule| {
1002            let yield_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
1003            let key_column_indices: Vec<usize> = rule
1004                .yield_schema
1005                .iter()
1006                .enumerate()
1007                .filter(|(_, yc)| yc.is_key)
1008                .map(|(i, _)| i)
1009                .collect();
1010
1011            let clauses: Vec<FixpointClausePlan> = rule
1012                .clauses
1013                .iter()
1014                .map(|clause| {
1015                    let is_ref_bindings = convert_is_refs(&clause.is_refs, registry)?;
1016                    Ok(FixpointClausePlan {
1017                        body_logical: clause.body.clone(),
1018                        is_ref_bindings,
1019                        priority: clause.priority,
1020                        along_bindings: clause.along_bindings.clone(),
1021                    })
1022                })
1023                .collect::<DFResult<Vec<_>>>()?;
1024
1025            let fold_bindings = convert_fold_bindings(&rule.fold_bindings, &rule.yield_schema)?;
1026            let best_by_criteria =
1027                convert_best_by_criteria(&rule.best_by_criteria, &rule.yield_schema)?;
1028
1029            let has_priority = rule.priority.is_some();
1030
1031            // Add __priority column to yield schema if PRIORITY is used
1032            let yield_schema = if has_priority {
1033                let mut fields: Vec<Arc<Field>> = yield_schema.fields().iter().cloned().collect();
1034                fields.push(Arc::new(Field::new("__priority", DataType::Int64, true)));
1035                ArrowSchema::new(fields)
1036            } else {
1037                yield_schema
1038            };
1039
1040            let prob_column_name = rule
1041                .yield_schema
1042                .iter()
1043                .find(|yc| yc.is_prob)
1044                .map(|yc| yc.name.clone());
1045
1046            Ok(FixpointRulePlan {
1047                name: rule.name.clone(),
1048                clauses,
1049                yield_schema: Arc::new(yield_schema),
1050                key_column_indices,
1051                priority: rule.priority,
1052                has_fold: !rule.fold_bindings.is_empty(),
1053                fold_bindings,
1054                having: rule.having.clone(),
1055                has_best_by: !rule.best_by_criteria.is_empty(),
1056                best_by_criteria,
1057                has_priority,
1058                deterministic: deterministic_best_by,
1059                prob_column_name,
1060            })
1061        })
1062        .collect()
1063}
1064
1065/// Convert `LocyIsRef` to `IsRefBinding` by looking up scan indices in the registry.
1066fn convert_is_refs(
1067    is_refs: &[LocyIsRef],
1068    registry: &DerivedScanRegistry,
1069) -> DFResult<Vec<IsRefBinding>> {
1070    is_refs
1071        .iter()
1072        .map(|is_ref| {
1073            let entries = registry.entries_for_rule(&is_ref.rule_name);
1074            // Find the matching entry (prefer self-ref for same-stratum rules)
1075            let entry = entries
1076                .iter()
1077                .find(|e| e.is_self_ref)
1078                .or_else(|| entries.first())
1079                .ok_or_else(|| {
1080                    datafusion::error::DataFusionError::Plan(format!(
1081                        "No derived scan entry found for IS-ref to '{}'",
1082                        is_ref.rule_name
1083                    ))
1084                })?;
1085
1086            // For negated IS-refs, compute (left_body_col, right_derived_col) pairs for
1087            // anti-join filtering. Subject vars are assumed to be node variables, so
1088            // the body column is `{var}._vid` (UInt64). The derived column name is taken
1089            // positionally from the registry entry's schema (KEY columns come first).
1090            let anti_join_cols = if is_ref.negated {
1091                let mut cols: Vec<(String, String)> = is_ref
1092                    .subjects
1093                    .iter()
1094                    .enumerate()
1095                    .filter_map(|(i, s)| {
1096                        if let uni_cypher::ast::Expr::Variable(var) = s {
1097                            let right_col = entry
1098                                .schema
1099                                .fields()
1100                                .get(i)
1101                                .map(|f| f.name().clone())
1102                                .unwrap_or_else(|| var.clone());
1103                            // After LocyProject the subject column is renamed to the yield
1104                            // column name (just `var`, not `var._vid`). Use bare var as left.
1105                            Some((var.clone(), right_col))
1106                        } else {
1107                            None
1108                        }
1109                    })
1110                    .collect();
1111                // Include target variable in anti-join for composite-key IS NOT.
1112                // Without this, `d IS NOT known TO dis` only checks d, not (d, dis),
1113                // filtering ALL pairs where the drug has ANY indication regardless
1114                // of disease.
1115                if let Some(uni_cypher::ast::Expr::Variable(target_var)) = &is_ref.target {
1116                    let target_idx = is_ref.subjects.len();
1117                    if let Some(field) = entry.schema.fields().get(target_idx) {
1118                        cols.push((target_var.clone(), field.name().clone()));
1119                    }
1120                }
1121                cols
1122            } else {
1123                Vec::new()
1124            };
1125
1126            // Provenance join cols: for ALL IS-refs (not just negated), compute
1127            // (body_col, derived_col) pairs so shared-proof detection can trace
1128            // which source facts contributed to each derived row.
1129            let provenance_join_cols: Vec<(String, String)> = is_ref
1130                .subjects
1131                .iter()
1132                .enumerate()
1133                .filter_map(|(i, s)| {
1134                    if let uni_cypher::ast::Expr::Variable(var) = s {
1135                        let right_col = entry
1136                            .schema
1137                            .fields()
1138                            .get(i)
1139                            .map(|f| f.name().clone())
1140                            .unwrap_or_else(|| var.clone());
1141                        Some((var.clone(), right_col))
1142                    } else {
1143                        None
1144                    }
1145                })
1146                .collect();
1147
1148            Ok(IsRefBinding {
1149                derived_scan_index: entry.scan_index,
1150                rule_name: is_ref.rule_name.clone(),
1151                is_self_ref: entry.is_self_ref,
1152                negated: is_ref.negated,
1153                anti_join_cols,
1154                target_has_prob: is_ref.target_has_prob,
1155                target_prob_col: is_ref.target_prob_col.clone(),
1156                provenance_join_cols,
1157            })
1158        })
1159        .collect()
1160}
1161
1162/// Convert fold binding expressions to physical `FoldBinding`.
1163///
1164/// The input column is looked up by the fold binding's output name (e.g., "total")
1165/// in the yield schema, since the LocyProject aliases the aggregate input expression
1166/// to the fold output name.
1167fn convert_fold_bindings(
1168    fold_bindings: &[(String, Expr)],
1169    yield_schema: &[LocyYieldColumn],
1170) -> DFResult<Vec<FoldBinding>> {
1171    fold_bindings
1172        .iter()
1173        .map(|(name, expr)| {
1174            let (kind, _input_col_name) = parse_fold_aggregate(expr)?;
1175
1176            // CountAll has no input column — LocyProject skips the output column
1177            // entirely, so there is nothing to look up.
1178            if kind == FoldAggKind::CountAll {
1179                return Ok(FoldBinding {
1180                    output_name: name.clone(),
1181                    kind,
1182                    input_col_index: 0, // unused for CountAll
1183                    input_col_name: None,
1184                });
1185            }
1186
1187            // The LocyProject projects the aggregate input expression AS the fold
1188            // output name, so the input column index matches the yield schema position.
1189            // Also store the column name for name-based resolution at execution time
1190            // (more robust when schema reconciliation changes column ordering).
1191            let input_col_index = yield_schema
1192                .iter()
1193                .position(|yc| yc.name == *name)
1194                .unwrap_or(0);
1195            Ok(FoldBinding {
1196                output_name: name.clone(),
1197                kind,
1198                input_col_index,
1199                input_col_name: Some(name.clone()),
1200            })
1201        })
1202        .collect()
1203}
1204
1205/// Parse a fold aggregate expression into (kind, input_column_name).
1206fn parse_fold_aggregate(expr: &Expr) -> DFResult<(FoldAggKind, String)> {
1207    match expr {
1208        Expr::FunctionCall { name, args, .. } => {
1209            let upper = name.to_uppercase();
1210            let is_count = matches!(upper.as_str(), "COUNT" | "MCOUNT");
1211
1212            // COUNT/MCOUNT with zero args → CountAll (like SQL COUNT(*))
1213            if is_count && args.is_empty() {
1214                return Ok((FoldAggKind::CountAll, String::new()));
1215            }
1216
1217            let kind = match upper.as_str() {
1218                "SUM" | "MSUM" => FoldAggKind::Sum,
1219                "MAX" | "MMAX" => FoldAggKind::Max,
1220                "MIN" | "MMIN" => FoldAggKind::Min,
1221                "COUNT" | "MCOUNT" => FoldAggKind::Count,
1222                "AVG" => FoldAggKind::Avg,
1223                "COLLECT" => FoldAggKind::Collect,
1224                "MNOR" => FoldAggKind::Nor,
1225                "MPROD" => FoldAggKind::Prod,
1226                _ => {
1227                    return Err(datafusion::error::DataFusionError::Plan(format!(
1228                        "Unknown FOLD aggregate function: {}",
1229                        name
1230                    )));
1231                }
1232            };
1233            let col_name = match args.first() {
1234                Some(Expr::Variable(v)) => v.clone(),
1235                Some(Expr::Property(_, prop)) => prop.clone(),
1236                Some(other) => other.to_string_repr(),
1237                None => {
1238                    return Err(datafusion::error::DataFusionError::Plan(
1239                        "FOLD aggregate function requires at least one argument".to_string(),
1240                    ));
1241                }
1242            };
1243            Ok((kind, col_name))
1244        }
1245        _ => Err(datafusion::error::DataFusionError::Plan(
1246            "FOLD binding must be a function call (e.g., SUM(x))".to_string(),
1247        )),
1248    }
1249}
1250
1251/// Convert best-by criteria expressions to physical `SortCriterion`.
1252///
1253/// Resolves the criteria column by trying:
1254/// 1. Property name (e.g., `e.cost` → "cost")
1255/// 2. Variable name (e.g., `cost`)
1256/// 3. Full expression string (e.g., "e.cost" as a variable name)
1257fn convert_best_by_criteria(
1258    criteria: &[(Expr, bool)],
1259    yield_schema: &[LocyYieldColumn],
1260) -> DFResult<Vec<SortCriterion>> {
1261    criteria
1262        .iter()
1263        .map(|(expr, ascending)| {
1264            let col_name = match expr {
1265                Expr::Property(_, prop) => prop.clone(),
1266                Expr::Variable(v) => v.clone(),
1267                _ => {
1268                    return Err(datafusion::error::DataFusionError::Plan(
1269                        "BEST BY criterion must be a variable or property reference".to_string(),
1270                    ));
1271                }
1272            };
1273            // Try exact match first, then try just the last component after '.'
1274            let col_index = yield_schema
1275                .iter()
1276                .position(|yc| yc.name == col_name)
1277                .or_else(|| {
1278                    let short_name = col_name.rsplit('.').next().unwrap_or(&col_name);
1279                    yield_schema.iter().position(|yc| yc.name == short_name)
1280                })
1281                .ok_or_else(|| {
1282                    datafusion::error::DataFusionError::Plan(format!(
1283                        "BEST BY column '{}' not found in yield schema",
1284                        col_name
1285                    ))
1286                })?;
1287            Ok(SortCriterion {
1288                col_index,
1289                ascending: *ascending,
1290                nulls_first: false,
1291            })
1292        })
1293        .collect()
1294}
1295
1296// ---------------------------------------------------------------------------
1297// Schema helpers
1298// ---------------------------------------------------------------------------
1299
1300/// Convert `LocyYieldColumn` slice to Arrow schema using inferred types.
1301fn yield_columns_to_arrow_schema(columns: &[LocyYieldColumn]) -> ArrowSchema {
1302    let fields: Vec<Arc<Field>> = columns
1303        .iter()
1304        .map(|yc| Arc::new(Field::new(&yc.name, yc.data_type.clone(), true)))
1305        .collect();
1306    ArrowSchema::new(fields)
1307}
1308
1309/// Build a combined output schema for fixpoint (union of all rules' schemas).
1310fn build_fixpoint_output_schema(rules: &[LocyRulePlan]) -> SchemaRef {
1311    // FixpointExec concatenates all rules' output, using the first rule's schema
1312    // as the output schema (all rules in a recursive stratum share compatible schemas).
1313    if let Some(rule) = rules.first() {
1314        Arc::new(yield_columns_to_arrow_schema(&rule.yield_schema))
1315    } else {
1316        Arc::new(ArrowSchema::empty())
1317    }
1318}
1319
1320/// Build a stats RecordBatch summarizing derived relation counts.
1321fn build_stats_batch(
1322    derived_store: &DerivedStore,
1323    _strata: &[LocyStratum],
1324    output_schema: SchemaRef,
1325) -> RecordBatch {
1326    // Build a simple stats batch with rule_name and fact_count columns
1327    let mut rule_names: Vec<String> = derived_store.rule_names().map(String::from).collect();
1328    rule_names.sort();
1329
1330    let name_col: arrow_array::StringArray = rule_names.iter().map(|s| Some(s.as_str())).collect();
1331    let count_col: arrow_array::Int64Array = rule_names
1332        .iter()
1333        .map(|name| Some(derived_store.fact_count(name) as i64))
1334        .collect();
1335
1336    let stats_schema = stats_schema();
1337    RecordBatch::try_new(stats_schema, vec![Arc::new(name_col), Arc::new(count_col)])
1338        .unwrap_or_else(|_| RecordBatch::new_empty(output_schema))
1339}
1340
1341/// Schema for the stats batch returned when no commands are present.
1342pub fn stats_schema() -> SchemaRef {
1343    Arc::new(ArrowSchema::new(vec![
1344        Arc::new(Field::new("rule_name", DataType::Utf8, false)),
1345        Arc::new(Field::new("fact_count", DataType::Int64, false)),
1346    ]))
1347}
1348
1349// ---------------------------------------------------------------------------
1350// Unit tests
1351// ---------------------------------------------------------------------------
1352
1353#[cfg(test)]
1354mod tests {
1355    use super::*;
1356    use arrow_array::{Int64Array, LargeBinaryArray, StringArray};
1357
1358    #[test]
1359    fn test_derived_store_insert_and_get() {
1360        let mut store = DerivedStore::new();
1361        assert!(store.get("test").is_none());
1362
1363        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
1364            "x",
1365            DataType::LargeBinary,
1366            true,
1367        ))]));
1368        let batch = RecordBatch::try_new(
1369            Arc::clone(&schema),
1370            vec![Arc::new(LargeBinaryArray::from(vec![
1371                Some(b"a" as &[u8]),
1372                Some(b"b"),
1373            ]))],
1374        )
1375        .unwrap();
1376
1377        store.insert("test".to_string(), vec![batch.clone()]);
1378
1379        let facts = store.get("test").unwrap();
1380        assert_eq!(facts.len(), 1);
1381        assert_eq!(facts[0].num_rows(), 2);
1382    }
1383
1384    #[test]
1385    fn test_derived_store_fact_count() {
1386        let mut store = DerivedStore::new();
1387        assert_eq!(store.fact_count("empty"), 0);
1388
1389        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
1390            "x",
1391            DataType::LargeBinary,
1392            true,
1393        ))]));
1394        let batch1 = RecordBatch::try_new(
1395            Arc::clone(&schema),
1396            vec![Arc::new(LargeBinaryArray::from(vec![Some(b"a" as &[u8])]))],
1397        )
1398        .unwrap();
1399        let batch2 = RecordBatch::try_new(
1400            Arc::clone(&schema),
1401            vec![Arc::new(LargeBinaryArray::from(vec![
1402                Some(b"b" as &[u8]),
1403                Some(b"c"),
1404            ]))],
1405        )
1406        .unwrap();
1407
1408        store.insert("test".to_string(), vec![batch1, batch2]);
1409        assert_eq!(store.fact_count("test"), 3);
1410    }
1411
1412    #[test]
1413    fn test_stats_batch_schema() {
1414        let schema = stats_schema();
1415        assert_eq!(schema.fields().len(), 2);
1416        assert_eq!(schema.field(0).name(), "rule_name");
1417        assert_eq!(schema.field(1).name(), "fact_count");
1418        assert_eq!(schema.field(0).data_type(), &DataType::Utf8);
1419        assert_eq!(schema.field(1).data_type(), &DataType::Int64);
1420    }
1421
1422    #[test]
1423    fn test_stats_batch_content() {
1424        let mut store = DerivedStore::new();
1425        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
1426            "x",
1427            DataType::LargeBinary,
1428            true,
1429        ))]));
1430        let batch = RecordBatch::try_new(
1431            Arc::clone(&schema),
1432            vec![Arc::new(LargeBinaryArray::from(vec![
1433                Some(b"a" as &[u8]),
1434                Some(b"b"),
1435            ]))],
1436        )
1437        .unwrap();
1438        store.insert("reach".to_string(), vec![batch]);
1439
1440        let output_schema = stats_schema();
1441        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
1442        assert_eq!(stats.num_rows(), 1);
1443
1444        let names = stats
1445            .column(0)
1446            .as_any()
1447            .downcast_ref::<StringArray>()
1448            .unwrap();
1449        assert_eq!(names.value(0), "reach");
1450
1451        let counts = stats
1452            .column(1)
1453            .as_any()
1454            .downcast_ref::<Int64Array>()
1455            .unwrap();
1456        assert_eq!(counts.value(0), 2);
1457    }
1458
1459    #[test]
1460    fn test_yield_columns_to_arrow_schema() {
1461        let columns = vec![
1462            LocyYieldColumn {
1463                name: "a".to_string(),
1464                is_key: true,
1465                is_prob: false,
1466                data_type: DataType::UInt64,
1467            },
1468            LocyYieldColumn {
1469                name: "b".to_string(),
1470                is_key: false,
1471                is_prob: false,
1472                data_type: DataType::LargeUtf8,
1473            },
1474            LocyYieldColumn {
1475                name: "c".to_string(),
1476                is_key: true,
1477                is_prob: false,
1478                data_type: DataType::Float64,
1479            },
1480        ];
1481
1482        let schema = yield_columns_to_arrow_schema(&columns);
1483        assert_eq!(schema.fields().len(), 3);
1484        assert_eq!(schema.field(0).name(), "a");
1485        assert_eq!(schema.field(1).name(), "b");
1486        assert_eq!(schema.field(2).name(), "c");
1487        // Fields use inferred types
1488        assert_eq!(schema.field(0).data_type(), &DataType::UInt64);
1489        assert_eq!(schema.field(1).data_type(), &DataType::LargeUtf8);
1490        assert_eq!(schema.field(2).data_type(), &DataType::Float64);
1491        for field in schema.fields() {
1492            assert!(field.is_nullable());
1493        }
1494    }
1495
1496    #[test]
1497    fn test_key_column_indices() {
1498        let columns = [
1499            LocyYieldColumn {
1500                name: "a".to_string(),
1501                is_key: true,
1502                is_prob: false,
1503                data_type: DataType::LargeBinary,
1504            },
1505            LocyYieldColumn {
1506                name: "b".to_string(),
1507                is_key: false,
1508                is_prob: false,
1509                data_type: DataType::LargeBinary,
1510            },
1511            LocyYieldColumn {
1512                name: "c".to_string(),
1513                is_key: true,
1514                is_prob: false,
1515                data_type: DataType::LargeBinary,
1516            },
1517        ];
1518
1519        let key_indices: Vec<usize> = columns
1520            .iter()
1521            .enumerate()
1522            .filter(|(_, yc)| yc.is_key)
1523            .map(|(i, _)| i)
1524            .collect();
1525        assert_eq!(key_indices, vec![0, 2]);
1526    }
1527
1528    #[test]
1529    fn test_parse_fold_aggregate_sum() {
1530        let expr = Expr::FunctionCall {
1531            name: "SUM".to_string(),
1532            args: vec![Expr::Variable("cost".to_string())],
1533            distinct: false,
1534            window_spec: None,
1535        };
1536        let (kind, col) = parse_fold_aggregate(&expr).unwrap();
1537        assert!(matches!(kind, FoldAggKind::Sum));
1538        assert_eq!(col, "cost");
1539    }
1540
1541    #[test]
1542    fn test_parse_fold_aggregate_monotonic() {
1543        let expr = Expr::FunctionCall {
1544            name: "MMAX".to_string(),
1545            args: vec![Expr::Variable("score".to_string())],
1546            distinct: false,
1547            window_spec: None,
1548        };
1549        let (kind, col) = parse_fold_aggregate(&expr).unwrap();
1550        assert!(matches!(kind, FoldAggKind::Max));
1551        assert_eq!(col, "score");
1552    }
1553
1554    #[test]
1555    fn test_parse_fold_aggregate_unknown() {
1556        let expr = Expr::FunctionCall {
1557            name: "UNKNOWN_AGG".to_string(),
1558            args: vec![Expr::Variable("x".to_string())],
1559            distinct: false,
1560            window_spec: None,
1561        };
1562        assert!(parse_fold_aggregate(&expr).is_err());
1563    }
1564
1565    #[test]
1566    fn test_no_commands_returns_stats() {
1567        let store = DerivedStore::new();
1568        let output_schema = stats_schema();
1569        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
1570        // Empty store → 0 rows
1571        assert_eq!(stats.num_rows(), 0);
1572    }
1573}