Skip to main content

uni_query/query/df_graph/
locy_program.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Top-level Locy program executor.
5//!
6//! `LocyProgramExec` orchestrates the full evaluation of a Locy program:
7//! it evaluates strata in dependency order, runs fixpoint for recursive strata,
8//! applies post-fixpoint operators (FOLD, PRIORITY, BEST BY), and then
9//! executes the program's commands (goal queries, DERIVE, ASSUME, etc.).
10
11use crate::query::df_graph::GraphExecutionContext;
12use crate::query::df_graph::common::{
13    collect_all_partitions, compute_plan_properties, execute_subplan,
14};
15use crate::query::df_graph::locy_best_by::SortCriterion;
16use crate::query::df_graph::locy_explain::ProvenanceStore;
17use crate::query::df_graph::locy_fixpoint::{
18    DerivedScanRegistry, FixpointClausePlan, FixpointExec, FixpointRulePlan, IsRefBinding,
19};
20use crate::query::df_graph::locy_fold::{FoldBinding, resolve_locy_aggregate};
21use crate::query::planner_locy_types::{
22    LocyCommand, LocyIsRef, LocyRulePlan, LocyStratum, LocyYieldColumn,
23};
24use arrow_array::RecordBatch;
25use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
29use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
30use futures::Stream;
31use parking_lot::RwLock;
32use std::any::Any;
33use std::collections::HashMap;
34use std::fmt;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::sync::RwLock as StdRwLock;
38use std::task::{Context, Poll};
39use std::time::{Duration, Instant};
40use uni_common::Value;
41use uni_common::core::schema::Schema as UniSchema;
42use uni_cypher::ast::Expr;
43use uni_locy::{
44    ClassifierRegistry, CommandResult, FactRow, ModelInvocationCache, RuntimeWarning, SemiringKind,
45};
46use uni_plugin::PluginRegistry;
47use uni_store::storage::manager::StorageManager;
48
49// ---------------------------------------------------------------------------
50// DerivedStore — cross-stratum fact sharing
51// ---------------------------------------------------------------------------
52
/// Simple store for derived relation facts across strata.
///
/// Each rule's converged facts are stored here after its stratum completes,
/// making them available for later strata that depend on them. Facts are
/// keyed by rule name and kept as a list of Arrow `RecordBatch`es.
pub struct DerivedStore {
    /// Rule name → the batches most recently inserted for that rule.
    relations: HashMap<String, Vec<RecordBatch>>,
}
60
61impl Default for DerivedStore {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67impl DerivedStore {
68    pub fn new() -> Self {
69        Self {
70            relations: HashMap::new(),
71        }
72    }
73
74    pub fn insert(&mut self, rule_name: String, facts: Vec<RecordBatch>) {
75        self.relations.insert(rule_name, facts);
76    }
77
78    pub fn get(&self, rule_name: &str) -> Option<&Vec<RecordBatch>> {
79        self.relations.get(rule_name)
80    }
81
82    pub fn fact_count(&self, rule_name: &str) -> usize {
83        self.relations
84            .get(rule_name)
85            .map(|batches| batches.iter().map(|b| b.num_rows()).sum())
86            .unwrap_or(0)
87    }
88
89    pub fn rule_names(&self) -> impl Iterator<Item = &str> {
90        self.relations.keys().map(|s| s.as_str())
91    }
92}
93
94// ---------------------------------------------------------------------------
95// LocyProgramExec — DataFusion ExecutionPlan
96// ---------------------------------------------------------------------------
97
/// DataFusion `ExecutionPlan` that runs an entire Locy program.
///
/// Evaluates strata in dependency order, using `FixpointExec` for recursive
/// strata and direct subplan execution for non-recursive ones. After all
/// strata converge, dispatches commands.
///
/// Besides the output batches, evaluation publishes side results through the
/// `Arc<StdRwLock<..>>` "slot" fields below. Each slot has an accessor that
/// returns a clone of the `Arc`, so a caller can keep a handle and read the
/// slot after execution.
pub struct LocyProgramExec {
    /// Program strata; cloned into the evaluation future by `execute`.
    strata: Vec<LocyStratum>,
    /// Program commands; cloned into the evaluation future by `execute`.
    commands: Vec<LocyCommand>,
    // Execution context. These are handed to `run_program` unchanged on every
    // `execute` call. Note that `session_ctx` is guarded by
    // `parking_lot::RwLock`, whereas the slot fields further down use
    // `std::sync::RwLock` (imported as `StdRwLock`).
    derived_scan_registry: Arc<DerivedScanRegistry>,
    plugin_registry: Arc<PluginRegistry>,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    params: HashMap<String, Value>,
    /// Schema reported by `ExecutionPlan::schema` and by the result stream.
    output_schema: SchemaRef,
    /// Plan properties computed from `output_schema` at construction time.
    properties: Arc<PlanProperties>,
    /// DataFusion metrics set; `execute` registers baseline metrics in it.
    metrics: ExecutionPlanMetricsSet,
    /// Iteration cap for recursive strata; hitting it without converging is
    /// signalled as `interruption::ITERATION_LIMIT`.
    max_iterations: usize,
    /// Wall-clock budget; exhausting it is signalled as `interruption::TIMEOUT`.
    timeout: Duration,
    // Evaluation settings forwarded verbatim to `run_program`.
    max_derived_bytes: usize,
    deterministic_best_by: bool,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    /// Active probability semiring (rollout D-7). Defaults to `AddMultProb`
    /// — the Phase 1/2 byte-identical behavior.
    semiring_kind: SemiringKind,
    /// Phase B Slice 3: runtime registry of `NeuralClassifier` impls
    /// keyed by model name. Held by `Arc` so executor clones share the
    /// same map.
    classifier_registry: Arc<ClassifierRegistry>,
    /// Phase B follow-up: optional memoization cache for classifier
    /// outputs. `None` → no caching.
    classifier_cache: Option<Arc<ModelInvocationCache>>,
    /// Phase C B1-B3 follow-up: per-query side-channel store for
    /// (raw, calibrated, confidence_band) records. Threaded to
    /// `FixpointExec` so EXPLAIN can read from it.
    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
    /// Shared slot for extracting the DerivedStore after execution completes.
    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
    /// Shared slot for groups where BDD fell back to independence mode.
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    /// Optional provenance tracker injected after construction (via `set_derivation_tracker`).
    derivation_tracker: Arc<StdRwLock<Option<Arc<ProvenanceStore>>>>,
    /// Shared slot written with per-rule iteration counts after fixpoint convergence.
    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
    /// Shared slot written with peak memory bytes after fixpoint completes.
    peak_memory_slot: Arc<StdRwLock<usize>>,
    /// Shared slot for runtime warnings collected during evaluation.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    /// Shared slot for inline command results (QUERY, Cypher) executed inside `run_program()`.
    command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
    /// Top-k proof filtering: 0 = unlimited (default), >0 = retain at most k proofs per fact.
    top_k_proofs: usize,
    /// Shared interruption signal (see [`interruption`]): `interruption::NONE`
    /// while running, non-zero once the stratum loop or fixpoint is cut short.
    /// Decoded after execution to populate `incomplete_slot`.
    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
    /// Shared slot populated when evaluation stops before completing. Holds the
    /// stop reason plus the skipped / unsound-complement rule lists; read after
    /// execution to populate `LocyResult.incomplete`. `None` for a complete run.
    incomplete_slot: Arc<StdRwLock<Option<uni_common::LocyIncomplete>>>,
}
163
164/// Encoding for the shared interruption signal threaded through the stratum
165/// loop and the recursive fixpoint as an `Arc<AtomicU8>`.
166///
167/// A single atomic byte records *why* evaluation stopped so the two layers can
168/// agree on a reason without a second channel. `NONE` means "running or
169/// completed normally".
170pub(crate) mod interruption {
171    use std::sync::atomic::{AtomicU8, Ordering};
172
173    use uni_common::LocyIncompleteReason;
174
175    /// No interruption: evaluation is running or completed normally.
176    pub(crate) const NONE: u8 = 0;
177    /// The wall-clock `timeout` budget was exhausted.
178    pub(crate) const TIMEOUT: u8 = 1;
179    /// A recursive stratum hit `max_iterations` without converging.
180    pub(crate) const ITERATION_LIMIT: u8 = 2;
181
182    /// Decodes the current interruption reason, if any.
183    pub(crate) fn reason(flag: &AtomicU8) -> Option<LocyIncompleteReason> {
184        match flag.load(Ordering::Relaxed) {
185            TIMEOUT => Some(LocyIncompleteReason::Timeout),
186            ITERATION_LIMIT => Some(LocyIncompleteReason::IterationLimit),
187            _ => None,
188        }
189    }
190
191    /// Records an interruption reason. First reason wins: a later, lower-priority
192    /// signal (non-convergence) never overwrites an earlier wall-clock timeout,
193    /// preserving the original precedence.
194    pub(crate) fn set(flag: &AtomicU8, code: u8) {
195        let _ = flag.compare_exchange(NONE, code, Ordering::Relaxed, Ordering::Relaxed);
196    }
197}
198
199impl fmt::Debug for LocyProgramExec {
200    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201        f.debug_struct("LocyProgramExec")
202            .field("strata_count", &self.strata.len())
203            .field("commands_count", &self.commands.len())
204            .field("max_iterations", &self.max_iterations)
205            .field("timeout", &self.timeout)
206            .field("output_schema", &self.output_schema)
207            .field("max_derived_bytes", &self.max_derived_bytes)
208            .finish_non_exhaustive()
209    }
210}
211
212impl LocyProgramExec {
    /// Legacy constructor kept for existing callers.
    ///
    /// Forwards every argument unchanged to
    /// [`Self::new_with_semiring_and_classifiers`], supplying
    /// `SemiringKind::AddMultProb` as the semiring and a newly created
    /// `ClassifierRegistry`.
    #[expect(
        clippy::too_many_arguments,
        reason = "execution plan node requires full graph and session context"
    )]
    #[deprecated(
        note = "use `new_with_semiring_classifiers_and_cache` (or the lighter \
                `new_with_semiring_and_classifiers` / `new_with_semiring`) — \
                this legacy ctor defaults the semiring to AddMultProb and \
                ships no classifier registry. To be removed after C0 Stage 2."
    )]
    pub fn new(
        strata: Vec<LocyStratum>,
        commands: Vec<LocyCommand>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        plugin_registry: Arc<PluginRegistry>,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        output_schema: SchemaRef,
        max_iterations: usize,
        timeout: Duration,
        max_derived_bytes: usize,
        deterministic_best_by: bool,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        top_k_proofs: usize,
    ) -> Self {
        Self::new_with_semiring_and_classifiers(
            strata,
            commands,
            derived_scan_registry,
            plugin_registry,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            output_schema,
            max_iterations,
            timeout,
            max_derived_bytes,
            deterministic_best_by,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            top_k_proofs,
            // Legacy default semiring.
            SemiringKind::AddMultProb,
            // No classifiers are registered through this entry point.
            Arc::new(ClassifierRegistry::new()),
        )
    }
268
    /// Constructor accepting an explicit semiring. Empty classifier
    /// registry; for the full Slice 3 variant call
    /// [`Self::new_with_semiring_and_classifiers`].
    ///
    /// All other arguments are forwarded unchanged to
    /// [`Self::new_with_semiring_and_classifiers`], which in turn leaves the
    /// classifier cache and classifier provenance store unset (`None`).
    #[expect(
        clippy::too_many_arguments,
        reason = "execution plan node requires full graph and session context"
    )]
    pub fn new_with_semiring(
        strata: Vec<LocyStratum>,
        commands: Vec<LocyCommand>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        plugin_registry: Arc<PluginRegistry>,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        output_schema: SchemaRef,
        max_iterations: usize,
        timeout: Duration,
        max_derived_bytes: usize,
        deterministic_best_by: bool,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        top_k_proofs: usize,
        semiring_kind: SemiringKind,
    ) -> Self {
        Self::new_with_semiring_and_classifiers(
            strata,
            commands,
            derived_scan_registry,
            plugin_registry,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            output_schema,
            max_iterations,
            timeout,
            max_derived_bytes,
            deterministic_best_by,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            top_k_proofs,
            semiring_kind,
            // No classifiers are registered through this entry point.
            Arc::new(ClassifierRegistry::new()),
        )
    }
322
    /// Phase B Slice 3 entry: accepts both the semiring kind and the
    /// runtime classifier registry.
    ///
    /// Forwards every argument unchanged to
    /// [`Self::new_with_semiring_classifiers_and_cache`] and passes `None`
    /// for both the classifier cache and the classifier provenance store.
    #[expect(
        clippy::too_many_arguments,
        reason = "execution plan node requires full graph and session context"
    )]
    pub fn new_with_semiring_and_classifiers(
        strata: Vec<LocyStratum>,
        commands: Vec<LocyCommand>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        plugin_registry: Arc<PluginRegistry>,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        output_schema: SchemaRef,
        max_iterations: usize,
        timeout: Duration,
        max_derived_bytes: usize,
        deterministic_best_by: bool,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        top_k_proofs: usize,
        semiring_kind: SemiringKind,
        classifier_registry: Arc<ClassifierRegistry>,
    ) -> Self {
        Self::new_with_semiring_classifiers_and_cache(
            strata,
            commands,
            derived_scan_registry,
            plugin_registry,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            output_schema,
            max_iterations,
            timeout,
            max_derived_bytes,
            deterministic_best_by,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            top_k_proofs,
            semiring_kind,
            classifier_registry,
            // classifier_cache: no memoization.
            None,
            // classifier_provenance_store: no provenance side channel.
            None,
        )
    }
378
    /// Phase B follow-up: full constructor accepting the optional
    /// memoization cache. Existing callers default to `None` (no
    /// cache); `impl_locy.rs` threads `LocyConfig.classifier_cache`
    /// here.
    ///
    /// This is the constructor every other `new*` variant ends up in. It
    /// also accepts the optional `classifier_provenance_store`. Plan
    /// properties are computed from `output_schema`, and every result slot
    /// is created empty (`None`, empty map / vector, `0`, or
    /// `interruption::NONE`).
    #[expect(
        clippy::too_many_arguments,
        reason = "execution plan node requires full graph and session context"
    )]
    pub fn new_with_semiring_classifiers_and_cache(
        strata: Vec<LocyStratum>,
        commands: Vec<LocyCommand>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        plugin_registry: Arc<PluginRegistry>,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        output_schema: SchemaRef,
        max_iterations: usize,
        timeout: Duration,
        max_derived_bytes: usize,
        deterministic_best_by: bool,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        top_k_proofs: usize,
        semiring_kind: SemiringKind,
        classifier_registry: Arc<ClassifierRegistry>,
        classifier_cache: Option<Arc<ModelInvocationCache>>,
        classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
    ) -> Self {
        // Derived before `output_schema` is moved into the struct below.
        let properties = compute_plan_properties(Arc::clone(&output_schema));
        Self {
            strata,
            commands,
            derived_scan_registry,
            plugin_registry,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            output_schema,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
            max_iterations,
            timeout,
            max_derived_bytes,
            deterministic_best_by,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            semiring_kind,
            classifier_registry,
            classifier_cache,
            classifier_provenance_store,
            // Result slots start out empty.
            derived_store_slot: Arc::new(StdRwLock::new(None)),
            approximate_slot: Arc::new(StdRwLock::new(HashMap::new())),
            derivation_tracker: Arc::new(StdRwLock::new(None)),
            iteration_counts_slot: Arc::new(StdRwLock::new(HashMap::new())),
            peak_memory_slot: Arc::new(StdRwLock::new(0)),
            warnings_slot: Arc::new(StdRwLock::new(Vec::new())),
            command_results_slot: Arc::new(StdRwLock::new(Vec::new())),
            top_k_proofs,
            timeout_flag: Arc::new(std::sync::atomic::AtomicU8::new(interruption::NONE)),
            incomplete_slot: Arc::new(StdRwLock::new(None)),
        }
    }
450
    /// Returns a shared handle to the derived store slot.
    ///
    /// The returned `Arc` refers to the same slot this plan owns; the slot is
    /// created holding `None`.
    ///
    /// After execution completes, the slot contains the `DerivedStore` with all
    /// converged facts. Read it with `slot.read().unwrap()`.
    pub fn derived_store_slot(&self) -> Arc<StdRwLock<Option<DerivedStore>>> {
        Arc::clone(&self.derived_store_slot)
    }
458
459    /// Inject a `ProvenanceStore` to record provenance during fixpoint iteration.
460    ///
461    /// Must be called before `execute()` is invoked (i.e., before DataFusion runs
462    /// the physical plan). Uses interior mutability so it works through `&self`.
463    pub fn set_derivation_tracker(&self, tracker: Arc<ProvenanceStore>) {
464        if let Ok(mut guard) = self.derivation_tracker.write() {
465            *guard = Some(tracker);
466        }
467    }
468
    /// Returns the shared iteration counts slot.
    ///
    /// The returned `Arc` refers to the same map this plan owns; the map is
    /// created empty.
    ///
    /// After execution, the slot contains per-rule iteration counts from the
    /// most recent fixpoint convergence. Sum the values for `total_iterations`.
    pub fn iteration_counts_slot(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
        Arc::clone(&self.iteration_counts_slot)
    }
476
    /// Returns the shared peak memory slot.
    ///
    /// The returned `Arc` refers to the same counter this plan owns; it is
    /// created holding `0`.
    ///
    /// After execution, the slot contains the peak byte count of derived facts
    /// across all strata. Read it with `slot.read().unwrap()`.
    pub fn peak_memory_slot(&self) -> Arc<StdRwLock<usize>> {
        Arc::clone(&self.peak_memory_slot)
    }
484
    /// Returns the shared runtime warnings slot.
    ///
    /// The returned `Arc` refers to the same vector this plan owns; it is
    /// created empty.
    ///
    /// After execution, the slot contains warnings collected during fixpoint
    /// iteration (e.g. shared probabilistic dependencies).
    pub fn warnings_slot(&self) -> Arc<StdRwLock<Vec<RuntimeWarning>>> {
        Arc::clone(&self.warnings_slot)
    }
492
    /// Returns the shared approximate groups slot.
    ///
    /// The returned `Arc` refers to the same map this plan owns; it is
    /// created empty.
    ///
    /// After execution, the slot contains rule→key group descriptions for
    /// groups where BDD computation fell back to independence mode.
    pub fn approximate_slot(&self) -> Arc<StdRwLock<HashMap<String, Vec<String>>>> {
        Arc::clone(&self.approximate_slot)
    }
500
    /// Returns the shared command results slot.
    ///
    /// The returned `Arc` refers to the same vector this plan owns; it is
    /// created empty.
    ///
    /// After execution, the slot contains `(command_index, CommandResult)` pairs
    /// for commands that were executed inline by `run_program()` (QUERY, Cypher).
    pub fn command_results_slot(&self) -> Arc<StdRwLock<Vec<(usize, CommandResult)>>> {
        Arc::clone(&self.command_results_slot)
    }
508
    /// Returns the shared interruption signal.
    ///
    /// The byte uses the `interruption` encoding (`NONE`, `TIMEOUT`,
    /// `ITERATION_LIMIT`) and is created as `interruption::NONE`.
    ///
    /// After execution, a non-zero value means the evaluation was cut short
    /// (timeout or iteration limit) and the derived store holds partial results.
    /// Prefer [`LocyProgramExec::incomplete_slot`] for the decoded diagnostics.
    pub fn timeout_flag(&self) -> Arc<std::sync::atomic::AtomicU8> {
        Arc::clone(&self.timeout_flag)
    }
517
    /// Returns the shared incomplete-evaluation diagnostics slot.
    ///
    /// The returned `Arc` refers to the same slot this plan owns; the slot is
    /// created holding `None`.
    ///
    /// After execution, `Some(detail)` means evaluation stopped before
    /// completing; `detail` names the skipped / unsound-complement rules and the
    /// stop reason. `None` for a complete run.
    pub fn incomplete_slot(&self) -> Arc<StdRwLock<Option<uni_common::LocyIncomplete>>> {
        Arc::clone(&self.incomplete_slot)
    }
526}
527
528impl DisplayAs for LocyProgramExec {
529    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
530        write!(
531            f,
532            "LocyProgramExec: strata={}, commands={}, max_iter={}, timeout={:?}",
533            self.strata.len(),
534            self.commands.len(),
535            self.max_iterations,
536            self.timeout,
537        )
538    }
539}
540
541impl ExecutionPlan for LocyProgramExec {
542    fn name(&self) -> &str {
543        "LocyProgramExec"
544    }
545
546    fn as_any(&self) -> &dyn Any {
547        self
548    }
549
550    fn schema(&self) -> SchemaRef {
551        Arc::clone(&self.output_schema)
552    }
553
554    fn properties(&self) -> &Arc<PlanProperties> {
555        &self.properties
556    }
557
558    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
559        vec![]
560    }
561
562    fn with_new_children(
563        self: Arc<Self>,
564        children: Vec<Arc<dyn ExecutionPlan>>,
565    ) -> DFResult<Arc<dyn ExecutionPlan>> {
566        if !children.is_empty() {
567            return Err(datafusion::error::DataFusionError::Plan(
568                "LocyProgramExec has no children".to_string(),
569            ));
570        }
571        Ok(self)
572    }
573
574    fn execute(
575        &self,
576        partition: usize,
577        _context: Arc<TaskContext>,
578    ) -> DFResult<SendableRecordBatchStream> {
579        let metrics = BaselineMetrics::new(&self.metrics, partition);
580
581        let strata = self.strata.clone();
582        let registry = Arc::clone(&self.derived_scan_registry);
583        let plugin_registry = Arc::clone(&self.plugin_registry);
584        let graph_ctx = Arc::clone(&self.graph_ctx);
585        let session_ctx = Arc::clone(&self.session_ctx);
586        let storage = Arc::clone(&self.storage);
587        let schema_info = Arc::clone(&self.schema_info);
588        let params = self.params.clone();
589        let output_schema = Arc::clone(&self.output_schema);
590        let max_iterations = self.max_iterations;
591        let timeout = self.timeout;
592        let max_derived_bytes = self.max_derived_bytes;
593        let deterministic_best_by = self.deterministic_best_by;
594        let strict_probability_domain = self.strict_probability_domain;
595        let probability_epsilon = self.probability_epsilon;
596        let exact_probability = self.exact_probability;
597        let max_bdd_variables = self.max_bdd_variables;
598        let derived_store_slot = Arc::clone(&self.derived_store_slot);
599        let approximate_slot = Arc::clone(&self.approximate_slot);
600        let iteration_counts_slot = Arc::clone(&self.iteration_counts_slot);
601        let peak_memory_slot = Arc::clone(&self.peak_memory_slot);
602        let derivation_tracker = self.derivation_tracker.read().ok().and_then(|g| g.clone());
603        let warnings_slot = Arc::clone(&self.warnings_slot);
604        let commands = self.commands.clone();
605        let command_results_slot = Arc::clone(&self.command_results_slot);
606        let top_k_proofs = self.top_k_proofs;
607        let timeout_flag = Arc::clone(&self.timeout_flag);
608        let incomplete_slot = Arc::clone(&self.incomplete_slot);
609        let semiring_kind = self.semiring_kind;
610        let classifier_registry = Arc::clone(&self.classifier_registry);
611        let classifier_cache = self.classifier_cache.as_ref().map(Arc::clone);
612        let classifier_provenance_store = self.classifier_provenance_store.as_ref().map(Arc::clone);
613
614        let fut = async move {
615            run_program(
616                strata,
617                commands,
618                registry,
619                plugin_registry,
620                graph_ctx,
621                session_ctx,
622                storage,
623                schema_info,
624                params,
625                output_schema,
626                max_iterations,
627                timeout,
628                max_derived_bytes,
629                deterministic_best_by,
630                strict_probability_domain,
631                probability_epsilon,
632                exact_probability,
633                max_bdd_variables,
634                derived_store_slot,
635                approximate_slot,
636                iteration_counts_slot,
637                peak_memory_slot,
638                derivation_tracker,
639                warnings_slot,
640                command_results_slot,
641                top_k_proofs,
642                timeout_flag,
643                incomplete_slot,
644                semiring_kind,
645                classifier_registry,
646                classifier_cache,
647                classifier_provenance_store,
648            )
649            .await
650        };
651
652        Ok(Box::pin(ProgramStream {
653            state: ProgramStreamState::Running(Box::pin(fut)),
654            schema: Arc::clone(&self.output_schema),
655            metrics,
656        }))
657    }
658
659    fn metrics(&self) -> Option<MetricsSet> {
660        Some(self.metrics.clone_inner())
661    }
662}
663
664// ---------------------------------------------------------------------------
665// ProgramStream — async state machine for streaming results
666// ---------------------------------------------------------------------------
667
/// States of [`ProgramStream`].
///
/// The stream starts in `Running`, moves to `Emitting` once the program
/// future yields a non-empty batch list, and ends in `Done`.
enum ProgramStreamState {
    /// The program future is still being polled; it resolves to all output batches at once.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// The batches to hand out, and the index of the next batch to emit.
    Emitting(Vec<RecordBatch>, usize),
    /// Nothing left to emit; every further poll yields `None`.
    Done,
}
673
/// Record-batch stream returned by `LocyProgramExec::execute`.
///
/// Wraps the program future and replays its result one batch per poll.
struct ProgramStream {
    /// Current position in the Running → Emitting → Done state machine.
    state: ProgramStreamState,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Baseline metrics: compute time per poll and emitted row counts.
    metrics: BaselineMetrics,
}
679
680impl Stream for ProgramStream {
681    type Item = DFResult<RecordBatch>;
682
683    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
684        let this = self.get_mut();
685        let metrics = this.metrics.clone();
686        let _timer = metrics.elapsed_compute().timer();
687        loop {
688            match &mut this.state {
689                ProgramStreamState::Running(fut) => match fut.as_mut().poll(cx) {
690                    Poll::Ready(Ok(batches)) => {
691                        if batches.is_empty() {
692                            this.state = ProgramStreamState::Done;
693                            return Poll::Ready(None);
694                        }
695                        this.state = ProgramStreamState::Emitting(batches, 0);
696                    }
697                    Poll::Ready(Err(e)) => {
698                        this.state = ProgramStreamState::Done;
699                        return Poll::Ready(Some(Err(e)));
700                    }
701                    Poll::Pending => return Poll::Pending,
702                },
703                ProgramStreamState::Emitting(batches, idx) => {
704                    if *idx >= batches.len() {
705                        this.state = ProgramStreamState::Done;
706                        return Poll::Ready(None);
707                    }
708                    let batch = batches[*idx].clone();
709                    *idx += 1;
710                    this.metrics.record_output(batch.num_rows());
711                    return Poll::Ready(Some(Ok(batch)));
712                }
713                ProgramStreamState::Done => return Poll::Ready(None),
714            }
715        }
716    }
717}
718
impl RecordBatchStream for ProgramStream {
    /// Schema captured when the stream was created in `LocyProgramExec::execute`.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
724
725// ---------------------------------------------------------------------------
726// Inline command execution helpers
727// ---------------------------------------------------------------------------
728
729/// Execute Cypher passthrough via execute_subplan.
730async fn execute_cypher_inline(
731    query: &uni_cypher::ast::Query,
732    schema_info: &Arc<UniSchema>,
733    params: &HashMap<String, Value>,
734    graph_ctx: &Arc<GraphExecutionContext>,
735    session_ctx: &Arc<RwLock<datafusion::prelude::SessionContext>>,
736    storage: &Arc<StorageManager>,
737) -> DFResult<Vec<FactRow>> {
738    let planner = crate::query::planner::QueryPlanner::new(Arc::clone(schema_info));
739    let logical_plan = planner.plan(query.clone()).map_err(|e| {
740        datafusion::error::DataFusionError::Execution(format!("Cypher plan error: {e}"))
741    })?;
742    let batches = execute_subplan(
743        &logical_plan,
744        params,
745        &HashMap::new(),
746        graph_ctx,
747        session_ctx,
748        storage,
749        schema_info,
750        None, // Locy paths are read-only (queries + fact extraction)
751    )
752    .await?;
753    Ok(super::locy_eval::record_batches_to_locy_rows(&batches))
754}
755
756// ---------------------------------------------------------------------------
757// run_program — core evaluation algorithm
758// ---------------------------------------------------------------------------
759
760#[expect(
761    clippy::too_many_arguments,
762    reason = "program evaluation requires full graph and session context"
763)]
764async fn run_program(
765    strata: Vec<LocyStratum>,
766    commands: Vec<LocyCommand>,
767    registry: Arc<DerivedScanRegistry>,
768    plugin_registry: Arc<PluginRegistry>,
769    graph_ctx: Arc<GraphExecutionContext>,
770    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
771    storage: Arc<StorageManager>,
772    schema_info: Arc<UniSchema>,
773    params: HashMap<String, Value>,
774    output_schema: SchemaRef,
775    max_iterations: usize,
776    timeout: Duration,
777    max_derived_bytes: usize,
778    deterministic_best_by: bool,
779    strict_probability_domain: bool,
780    probability_epsilon: f64,
781    exact_probability: bool,
782    max_bdd_variables: usize,
783    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
784    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
785    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
786    peak_memory_slot: Arc<StdRwLock<usize>>,
787    derivation_tracker: Option<Arc<ProvenanceStore>>,
788    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
789    command_results_slot: Arc<StdRwLock<Vec<(usize, CommandResult)>>>,
790    top_k_proofs: usize,
791    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
792    incomplete_slot: Arc<StdRwLock<Option<uni_common::LocyIncomplete>>>,
793    semiring_kind: SemiringKind,
794    classifier_registry: Arc<ClassifierRegistry>,
795    classifier_cache: Option<Arc<ModelInvocationCache>>,
796    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
797) -> DFResult<Vec<RecordBatch>> {
798    let start = Instant::now();
799    let mut derived_store = DerivedStore::new();
800
801    // IMPORTANT: per rollout D-9 the FuzzyNotProbabilistic warning is
802    // unsuppressible. Emit one warning per PROB-bearing rule at program
803    // start under MaxMinProb. The recursive path in `run_fixpoint_loop`
804    // dedups against this set.
805    if semiring_kind == SemiringKind::MaxMinProb {
806        let mut warnings = warnings_slot.write().unwrap_or_else(|e| e.into_inner());
807        let mut already: std::collections::HashSet<String> = warnings
808            .iter()
809            .filter(|w| w.code == uni_locy::RuntimeWarningCode::FuzzyNotProbabilistic)
810            .map(|w| w.rule_name.clone())
811            .collect();
812        for stratum in &strata {
813            for rule in &stratum.rules {
814                let has_prob = rule.yield_schema.iter().any(|c| c.is_prob);
815                if has_prob && !already.contains(&rule.name) {
816                    warnings.push(RuntimeWarning {
817                        code: uni_locy::RuntimeWarningCode::FuzzyNotProbabilistic,
818                        message: format!(
819                            "rule '{}' carries a PROB column but is being evaluated under \
820                             the MaxMinProb (fuzzy / Viterbi) semiring; outputs are fuzzy \
821                             truth values, not probabilities",
822                            rule.name
823                        ),
824                        rule_name: rule.name.clone(),
825                        variable_count: None,
826                        key_group: None,
827                    });
828                    already.insert(rule.name.clone());
829                }
830            }
831        }
832    }
833
834    // Evaluate each stratum in topological order, tracking how far we get so an
835    // interruption can distinguish rules left incomplete (partial fixpoint) from
836    // rules never reached (skipped) — neither is "genuinely empty".
837    let total_strata = strata.len();
838    let mut completed_strata = 0usize;
839    let mut partial_stratum: Option<usize> = None;
840    for (stratum_idx, stratum) in strata.iter().enumerate() {
841        // Write cross-stratum facts into registry handles for strata we depend on
842        write_cross_stratum_facts(&registry, &derived_store, stratum);
843
844        let remaining_timeout = timeout.saturating_sub(start.elapsed());
845        if remaining_timeout.is_zero() {
846            tracing::warn!("Locy program timeout exceeded during stratum evaluation");
847            interruption::set(&timeout_flag, interruption::TIMEOUT);
848            break;
849        }
850
851        if stratum.is_recursive {
852            // Convert LocyRulePlan → FixpointRulePlan and run fixpoint
853            let fixpoint_rules = convert_to_fixpoint_plans(
854                &stratum.rules,
855                &registry,
856                &plugin_registry,
857                deterministic_best_by,
858            )?;
859            let fixpoint_schema = build_fixpoint_output_schema(&stratum.rules);
860
861            let exec = FixpointExec::new_with_semiring_classifiers_and_cache(
862                fixpoint_rules,
863                max_iterations,
864                remaining_timeout,
865                Arc::clone(&graph_ctx),
866                Arc::clone(&session_ctx),
867                Arc::clone(&storage),
868                Arc::clone(&schema_info),
869                params.clone(),
870                Arc::clone(&registry),
871                fixpoint_schema,
872                max_derived_bytes,
873                derivation_tracker.clone(),
874                Arc::clone(&iteration_counts_slot),
875                strict_probability_domain,
876                probability_epsilon,
877                exact_probability,
878                max_bdd_variables,
879                Arc::clone(&warnings_slot),
880                Arc::clone(&approximate_slot),
881                top_k_proofs,
882                Arc::clone(&timeout_flag),
883                semiring_kind,
884                Arc::clone(&classifier_registry),
885                classifier_cache.as_ref().map(Arc::clone),
886                classifier_provenance_store.as_ref().map(Arc::clone),
887            );
888
889            let task_ctx = session_ctx.read().task_ctx();
890            let exec_arc: Arc<dyn ExecutionPlan> = Arc::new(exec);
891            let batches = collect_all_partitions(&exec_arc, task_ctx).await?;
892
893            // FixpointExec concatenates all rules' output; store per-rule.
894            // For now, store all output under each rule name (since FixpointExec
895            // handles per-rule state internally, the output is already correct).
896            // NOTE(deferred): Per-rule fact demultiplexing is not yet implemented.
897            // FixpointExec concatenates all rules' output into a single batch stream.
898            // Proper demux requires FixpointExec to tag output batches with rule identity
899            // (e.g. an extra column or side-channel), which is a non-trivial change to
900            // run_fixpoint_loop. The current schema-field-count heuristic (filter below)
901            // works because recursive stratum rules share compatible schemas.
902            // Revisit when cross-stratum consumption of individual recursive rules is needed.
903            for rule in &stratum.rules {
904                // Skip DERIVE-only rules (empty yield_schema).
905                if rule.yield_schema.is_empty() {
906                    continue;
907                }
908                // Write converged facts into registry handles for cross-stratum consumers
909                let rule_entries = registry.entries_for_rule(&rule.name);
910                for entry in rule_entries {
911                    if !entry.is_self_ref {
912                        // Cross-stratum handles get the full fixpoint output
913                        // In practice, FixpointExec already wrote self-ref handles;
914                        // we need to write non-self-ref handles for later strata.
915                        let all_facts: Vec<RecordBatch> = batches
916                            .iter()
917                            .filter(|b| {
918                                // If schemas match, this batch belongs to this rule
919                                let rule_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
920                                b.schema().fields().len() == rule_schema.fields().len()
921                            })
922                            .cloned()
923                            .collect();
924                        let mut guard = entry.data.write();
925                        *guard = if all_facts.is_empty() {
926                            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
927                        } else {
928                            all_facts
929                        };
930                    }
931                }
932                derived_store.insert(rule.name.clone(), batches.clone());
933            }
934        } else {
935            // Non-recursive: single-pass evaluation
936            let fixpoint_rules = convert_to_fixpoint_plans(
937                &stratum.rules,
938                &registry,
939                &plugin_registry,
940                deterministic_best_by,
941            )?;
942            let task_ctx = session_ctx.read().task_ctx();
943
944            for (rule, fp_rule) in stratum.rules.iter().zip(fixpoint_rules.iter()) {
945                // DERIVE-only rules have empty yield_schema (the compiler's
946                // infer_yield_schema only matches RuleOutput::Yield). Skip them
947                // in the fixpoint loop — DERIVE materialization is handled by
948                // the DERIVE command dispatch, not by the fixpoint.
949                if rule.yield_schema.is_empty() {
950                    continue;
951                }
952
953                // Record the single evaluation pass for this non-recursive rule.
954                // The recursive branch writes per-rule fixpoint counts to this slot;
955                // a non-recursive rule is evaluated exactly once, so without this a
956                // purely non-recursive program would report `total_iterations == 0`.
957                if let Ok(mut counts) = iteration_counts_slot.write() {
958                    counts.insert(rule.name.clone(), 1);
959                }
960
961                // Process each clause independently (per-clause IS NOT).
962                let mut tagged_clause_facts: Vec<(usize, Vec<RecordBatch>)> = Vec::new();
963                for (clause_idx, (clause, fp_clause)) in
964                    rule.clauses.iter().zip(fp_rule.clauses.iter()).enumerate()
965                {
966                    // Phase B A4 follow-up: the planner inserts
967                    // `LocyModelInvoke` between body and `LocyProject`
968                    // when the clause has neural invocations.
969                    let mut batches = execute_subplan(
970                        &clause.body,
971                        &params,
972                        &HashMap::new(),
973                        &graph_ctx,
974                        &session_ctx,
975                        &storage,
976                        &schema_info,
977                        None, // Locy clause body is read-only
978                    )
979                    .await?;
980
981                    // Apply negated IS-ref semantics per-clause.
982                    for binding in &fp_clause.is_ref_bindings {
983                        if binding.negated
984                            && !binding.anti_join_cols.is_empty()
985                            && let Some(entry) = registry.get(binding.derived_scan_index)
986                        {
987                            let neg_facts = entry.data.read().clone();
988                            if !neg_facts.is_empty() {
989                                if binding.target_has_prob && fp_rule.prob_column_name.is_some() {
990                                    let complement_col =
991                                        format!("__prob_complement_{}", binding.rule_name);
992                                    if let Some(prob_col) = &binding.target_prob_col {
993                                        batches =
994                                            super::locy_fixpoint::apply_prob_complement_composite(
995                                                batches,
996                                                &neg_facts,
997                                                &binding.anti_join_cols,
998                                                prob_col,
999                                                &complement_col,
1000                                            )?;
1001                                    } else {
1002                                        // target_has_prob but no prob_col: fall back to anti-join.
1003                                        batches = super::locy_fixpoint::apply_anti_join_composite(
1004                                            batches,
1005                                            &neg_facts,
1006                                            &binding.anti_join_cols,
1007                                        )?;
1008                                    }
1009                                } else {
1010                                    batches = super::locy_fixpoint::apply_anti_join_composite(
1011                                        batches,
1012                                        &neg_facts,
1013                                        &binding.anti_join_cols,
1014                                    )?;
1015                                }
1016                            }
1017                        }
1018                    }
1019
1020                    // Multiply complement columns into PROB per-clause.
1021                    let complement_cols: Vec<String> = if !batches.is_empty() {
1022                        batches[0]
1023                            .schema()
1024                            .fields()
1025                            .iter()
1026                            .filter(|f| f.name().starts_with("__prob_complement_"))
1027                            .map(|f| f.name().clone())
1028                            .collect()
1029                    } else {
1030                        vec![]
1031                    };
1032                    if !complement_cols.is_empty() {
1033                        batches = super::locy_fixpoint::multiply_prob_factors(
1034                            batches,
1035                            fp_rule.prob_column_name.as_deref(),
1036                            &complement_cols,
1037                        )?;
1038                    }
1039
1040                    tagged_clause_facts.push((clause_idx, batches));
1041                }
1042
1043                // Record provenance and detect shared proofs for non-recursive rules.
1044                //
1045                // TODO(C0-stage2): swap `record_and_detect_lineage_nonrecursive`
1046                // for `TopKTag` DNF inspection when
1047                // `semiring_kind == TopKProofs { k }`. Library-layer
1048                // tag math landed in
1049                // `crates/uni-locy/src/top_k_proofs.rs` (Phase C C0
1050                // Stage 1); Stage 2 wires per-row tags through the
1051                // runtime so dependencies are visible here.
1052                //
1053                // Under MaxMinProb, `plus = max` is idempotent so shared
1054                // proofs don't double-count — skip the (misleading) warning.
1055                let shared_info = if semiring_kind == SemiringKind::MaxMinProb {
1056                    None
1057                } else if let Some(ref tracker) = derivation_tracker {
1058                    super::locy_fixpoint::record_and_detect_lineage_nonrecursive(
1059                        fp_rule,
1060                        &tagged_clause_facts,
1061                        tracker,
1062                        &warnings_slot,
1063                        &registry,
1064                        top_k_proofs,
1065                        super::locy_fixpoint::ClassifierRefs {
1066                            registry: &classifier_registry,
1067                            cache: classifier_cache.as_ref(),
1068                            provenance_store: classifier_provenance_store.as_ref(),
1069                        },
1070                        semiring_kind,
1071                    )
1072                    .await
1073                } else {
1074                    None
1075                };
1076
1077                // Flatten tagged facts for post-fixpoint chain.
1078                let mut all_clause_facts: Vec<RecordBatch> = tagged_clause_facts
1079                    .into_iter()
1080                    .flat_map(|(_, batches)| batches)
1081                    .collect();
1082
1083                // Apply BDD for shared groups if exact_probability is enabled.
1084                if exact_probability
1085                    && let Some(ref info) = shared_info
1086                    && let Some(ref tracker) = derivation_tracker
1087                {
1088                    all_clause_facts = super::locy_fixpoint::apply_exact_wmc(
1089                        all_clause_facts,
1090                        fp_rule,
1091                        info,
1092                        tracker,
1093                        max_bdd_variables,
1094                        &warnings_slot,
1095                        &approximate_slot,
1096                    )?;
1097                }
1098
1099                // Apply post-fixpoint operators (PRIORITY, FOLD, BEST BY) on union.
1100                let facts = super::locy_fixpoint::apply_post_fixpoint_chain(
1101                    all_clause_facts,
1102                    fp_rule,
1103                    &task_ctx,
1104                    strict_probability_domain,
1105                    probability_epsilon,
1106                    semiring_kind,
1107                    derivation_tracker.as_ref().map(Arc::clone),
1108                    top_k_proofs,
1109                    Some(Arc::clone(&registry)),
1110                )
1111                .await?;
1112
1113                // Write facts into registry handles for later strata
1114                write_facts_to_registry(&registry, &rule.name, &facts);
1115                derived_store.insert(rule.name.clone(), facts);
1116            }
1117        }
1118
1119        // The recursive fixpoint can set the interruption flag mid-stratum (the
1120        // non-recursive branch cannot). Stop here either way so later strata are
1121        // recorded as skipped rather than passed off as empty.
1122        if interruption::reason(&timeout_flag).is_some() {
1123            partial_stratum = Some(stratum_idx);
1124            break;
1125        }
1126        completed_strata += 1;
1127    }
1128
1129    // If evaluation was cut short, record which rules were left incomplete vs.
1130    // never reached, flagging any complement (`IS NOT`) rules among them as
1131    // unsound. Read by impl_locy to choose Err(LocyIncomplete) vs. Ok(partial).
1132    if let Some(reason) = interruption::reason(&timeout_flag) {
1133        let skipped_start = match partial_stratum {
1134            Some(i) => i + 1,
1135            None => completed_strata,
1136        };
1137        let incomplete_rules: Vec<String> = partial_stratum
1138            .map(|i| strata[i].rules.iter().map(|r| r.name.clone()).collect())
1139            .unwrap_or_default();
1140        let skipped_rules: Vec<String> = strata[skipped_start..]
1141            .iter()
1142            .flat_map(|s| s.rules.iter().map(|r| r.name.clone()))
1143            .collect();
1144        let mut complement_rules_affected = Vec::new();
1145        for idx in partial_stratum
1146            .into_iter()
1147            .chain(skipped_start..total_strata)
1148        {
1149            for rule in &strata[idx].rules {
1150                if rule
1151                    .clauses
1152                    .iter()
1153                    .any(|c| c.is_refs.iter().any(|r| r.negated))
1154                {
1155                    complement_rules_affected.push(rule.name.clone());
1156                }
1157            }
1158        }
1159        if let Ok(mut slot) = incomplete_slot.write() {
1160            *slot = Some(uni_common::LocyIncomplete {
1161                reason,
1162                elapsed_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
1163                limit_ms: u64::try_from(timeout.as_millis()).unwrap_or(u64::MAX),
1164                max_iterations,
1165                completed_strata,
1166                total_strata,
1167                incomplete_rules,
1168                skipped_rules,
1169                complement_rules_affected,
1170            });
1171        }
1172    }
1173
1174    // Compute peak memory from derived store byte sizes
1175    let peak_bytes: usize = derived_store
1176        .relations
1177        .values()
1178        .flat_map(|batches| batches.iter())
1179        .map(|b| {
1180            b.columns()
1181                .iter()
1182                .map(|col| col.get_buffer_memory_size())
1183                .sum::<usize>()
1184        })
1185        .sum();
1186    *peak_memory_slot.write().unwrap() = peak_bytes;
1187
1188    // Execute inline Cypher commands via execute_subplan.
1189    // QUERY is deferred to the orchestrator: the DerivedStore uses inferred types
1190    // (e.g. Float64 for property-derived columns) which don't preserve the actual
1191    // property values. The orchestrator's SLG path re-derives with correct types.
1192    // DERIVE/ASSUME/EXPLAIN/ABDUCE are also deferred (need L0 fork/restore, tree output, etc.).
1193    //
1194    // Cypher commands that appear AFTER a DERIVE command are also deferred:
1195    // they need the ephemeral L0 overlay populated by DERIVE to see derived
1196    // edges, which is only available in the orchestrator's dispatch loop.
1197    let first_derive_idx = commands
1198        .iter()
1199        .position(|c| matches!(c, LocyCommand::Derive { .. }));
1200    let mut inline_results: Vec<(usize, CommandResult)> = Vec::new();
1201    for (cmd_idx, cmd) in commands.iter().enumerate() {
1202        match cmd {
1203            LocyCommand::Cypher { query } => {
1204                // Defer Cypher commands that follow a DERIVE to the dispatch loop
1205                // so they can read from the ephemeral L0 overlay.
1206                if first_derive_idx.is_some_and(|di| cmd_idx > di) {
1207                    continue;
1208                }
1209                let rows = execute_cypher_inline(
1210                    query,
1211                    &schema_info,
1212                    &params,
1213                    &graph_ctx,
1214                    &session_ctx,
1215                    &storage,
1216                )
1217                .await?;
1218                inline_results.push((cmd_idx, CommandResult::Cypher(rows)));
1219            }
1220            LocyCommand::Validate { validate } => {
1221                // Phase C C3: collect ground-truth pairs via a
1222                // MATCH+TARGET query, join with the rule's derived
1223                // facts on KEY columns, compute metrics.
1224                let rule_key_cols: Vec<String> = strata
1225                    .iter()
1226                    .flat_map(|s| s.rules.iter())
1227                    .find(|r| r.name == validate.rule_name)
1228                    .map(|r| {
1229                        r.yield_schema
1230                            .iter()
1231                            .filter(|c| c.is_key)
1232                            .map(|c| c.name.clone())
1233                            .collect()
1234                    })
1235                    .unwrap_or_default();
1236                let query =
1237                    super::locy_validate::validate_collection_query(validate, &rule_key_cols);
1238                let target_rows = execute_cypher_inline(
1239                    &query,
1240                    &schema_info,
1241                    &params,
1242                    &graph_ctx,
1243                    &session_ctx,
1244                    &storage,
1245                )
1246                .await?;
1247                let rule_facts: Vec<uni_locy::FactRow> = derived_store
1248                    .get(&validate.rule_name)
1249                    .map(|batches| super::locy_eval::record_batches_to_locy_rows(batches))
1250                    .unwrap_or_default();
1251                let result = super::locy_validate::run_validate(
1252                    validate,
1253                    &rule_key_cols,
1254                    &rule_facts,
1255                    target_rows,
1256                )
1257                .map_err(|e| {
1258                    datafusion::error::DataFusionError::Execution(format!("VALIDATE error: {e}"))
1259                })?;
1260                inline_results.push((cmd_idx, CommandResult::Validate(result)));
1261            }
1262            LocyCommand::Calibrate {
1263                calibrate,
1264                model_inputs,
1265            } => {
1266                // Phase C C2: dispatch a CALIBRATE command. Build a
1267                // Cypher MATCH+RETURN query that projects the model's
1268                // input variables + the TARGET expression, execute
1269                // it, then drive `run_calibrate` over the collected
1270                // rows. The fitted calibrator + holdout metrics
1271                // surface as `CommandResult::Calibrate(...)`.
1272                //
1273                // Synthesize a CompiledModel snapshot from the carried
1274                // model_inputs so we can build the collection query
1275                // without lugging the full catalog through this call
1276                // site. Other fields the runtime doesn't read are
1277                // filled with defaults.
1278                let model_snapshot = uni_locy::CompiledModel {
1279                    name: calibrate.model_name.clone(),
1280                    inputs: model_inputs.clone(),
1281                    features: vec![],
1282                    path_context: None,
1283                    output_type: uni_cypher::locy_ast::OutputType::Prob,
1284                    output_name: String::new(),
1285                    xervo_alias: String::new(),
1286                    embedder_alias: None,
1287                    calibration: None,
1288                    version: None,
1289                    annotations: Default::default(),
1290                };
1291                let query =
1292                    super::locy_calibrate::calibrate_collection_query(calibrate, &model_snapshot);
1293                let rows = execute_cypher_inline(
1294                    &query,
1295                    &schema_info,
1296                    &params,
1297                    &graph_ctx,
1298                    &session_ctx,
1299                    &storage,
1300                )
1301                .await?;
1302                let mut catalog = std::collections::HashMap::new();
1303                catalog.insert(calibrate.model_name.clone(), model_snapshot);
1304                let result = super::locy_calibrate::run_calibrate(
1305                    calibrate,
1306                    &catalog,
1307                    &classifier_registry,
1308                    rows,
1309                )
1310                .await
1311                .map_err(|e| {
1312                    datafusion::error::DataFusionError::Execution(format!("CALIBRATE error: {e}"))
1313                })?;
1314                inline_results.push((cmd_idx, CommandResult::Calibrate(result)));
1315            }
1316            _ => {}
1317        }
1318    }
1319    *command_results_slot.write().unwrap() = inline_results;
1320
1321    let stats = vec![build_stats_batch(&derived_store, &strata, output_schema)];
1322    *derived_store_slot.write().unwrap() = Some(derived_store);
1323    Ok(stats)
1324}
1325
1326// ---------------------------------------------------------------------------
1327// Cross-stratum fact injection
1328// ---------------------------------------------------------------------------
1329
1330/// Write already-evaluated facts into registry handles for cross-stratum IS-refs.
1331fn write_cross_stratum_facts(
1332    registry: &DerivedScanRegistry,
1333    derived_store: &DerivedStore,
1334    stratum: &LocyStratum,
1335) {
1336    // For each rule in this stratum, find IS-refs to rules in other strata
1337    for rule in &stratum.rules {
1338        for clause in &rule.clauses {
1339            for is_ref in &clause.is_refs {
1340                // If this IS-ref points to a rule already in the derived store
1341                // (i.e., from a previous stratum), write its facts into the registry
1342                if let Some(facts) = derived_store.get(&is_ref.rule_name) {
1343                    write_facts_to_registry(registry, &is_ref.rule_name, facts);
1344                }
1345            }
1346        }
1347    }
1348}
1349
1350/// Write facts into non-self-ref registry handles for a given rule.
1351fn write_facts_to_registry(registry: &DerivedScanRegistry, rule_name: &str, facts: &[RecordBatch]) {
1352    let entries = registry.entries_for_rule(rule_name);
1353    for entry in entries {
1354        if !entry.is_self_ref {
1355            let mut guard = entry.data.write();
1356            *guard = if facts.is_empty() || facts.iter().all(|b| b.num_rows() == 0) {
1357                vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
1358            } else {
1359                // Try to re-wrap batches with the entry's schema for column name
1360                // alignment. If the types don't match (e.g. inferred Float64 vs
1361                // actual Utf8 from schema mode), fall back to the batch's own
1362                // schema to avoid silent data loss.
1363                facts
1364                    .iter()
1365                    .filter(|b| b.num_rows() > 0)
1366                    .map(|b| {
1367                        RecordBatch::try_new(Arc::clone(&entry.schema), b.columns().to_vec())
1368                            .unwrap_or_else(|_| b.clone())
1369                    })
1370                    .collect()
1371            };
1372        }
1373    }
1374}
1375
1376// ---------------------------------------------------------------------------
1377// LocyRulePlan → FixpointRulePlan conversion
1378// ---------------------------------------------------------------------------
1379
1380/// Convert logical `LocyRulePlan` types to physical `FixpointRulePlan` types.
1381fn convert_to_fixpoint_plans(
1382    rules: &[LocyRulePlan],
1383    registry: &DerivedScanRegistry,
1384    plugin_registry: &PluginRegistry,
1385    deterministic_best_by: bool,
1386) -> DFResult<Vec<FixpointRulePlan>> {
1387    // `rules` is one stratum's rule set, so membership here means
1388    // "same stratum" — the recursion-detection set for `non_linear`.
1389    let stratum_rule_names: std::collections::HashSet<&str> =
1390        rules.iter().map(|r| r.name.as_str()).collect();
1391    rules
1392        .iter()
1393        .map(|rule| {
1394            let yield_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
1395            let key_column_indices: Vec<usize> = rule
1396                .yield_schema
1397                .iter()
1398                .enumerate()
1399                .filter(|(_, yc)| yc.is_key)
1400                .map(|(i, _)| i)
1401                .collect();
1402
1403            let clauses: Vec<FixpointClausePlan> = rule
1404                .clauses
1405                .iter()
1406                .map(|clause| {
1407                    let is_ref_bindings =
1408                        convert_is_refs(&clause.is_refs, registry, &stratum_rule_names)?;
1409                    Ok(FixpointClausePlan {
1410                        body_logical: clause.body.clone(),
1411                        is_ref_bindings,
1412                        priority: clause.priority,
1413                        along_bindings: clause.along_bindings.clone(),
1414                        model_invocations: clause.model_invocations.clone(),
1415                    })
1416                })
1417                .collect::<DFResult<Vec<_>>>()?;
1418
1419            let fold_bindings =
1420                convert_fold_bindings(&rule.fold_bindings, &rule.yield_schema, plugin_registry)?;
1421            let best_by_criteria =
1422                convert_best_by_criteria(&rule.best_by_criteria, &rule.yield_schema)?;
1423
1424            let has_priority = rule.priority.is_some();
1425
1426            // Add __priority column to yield schema if PRIORITY is used
1427            let yield_schema = if has_priority {
1428                let mut fields: Vec<Arc<Field>> = yield_schema.fields().iter().cloned().collect();
1429                fields.push(Arc::new(Field::new("__priority", DataType::Int64, true)));
1430                ArrowSchema::new(fields)
1431            } else {
1432                yield_schema
1433            };
1434
1435            let prob_column_name = rule
1436                .yield_schema
1437                .iter()
1438                .find(|yc| yc.is_prob)
1439                .map(|yc| yc.name.clone());
1440
1441            // Non-linear recursion: any clause joining ≥2 positive
1442            // same-stratum IS-refs needs full facts on its self-ref scans
1443            // (see `FixpointRulePlan::non_linear`).
1444            let non_linear = rule.clauses.iter().any(|clause| {
1445                clause
1446                    .is_refs
1447                    .iter()
1448                    .filter(|ir| !ir.negated && stratum_rule_names.contains(ir.rule_name.as_str()))
1449                    .count()
1450                    >= 2
1451            });
1452
1453            Ok(FixpointRulePlan {
1454                name: rule.name.clone(),
1455                clauses,
1456                yield_schema: Arc::new(yield_schema),
1457                key_column_indices,
1458                priority: rule.priority,
1459                has_fold: !rule.fold_bindings.is_empty(),
1460                fold_bindings,
1461                having: rule.having.clone(),
1462                has_best_by: !rule.best_by_criteria.is_empty(),
1463                best_by_criteria,
1464                has_priority,
1465                deterministic: deterministic_best_by,
1466                prob_column_name,
1467                non_linear,
1468            })
1469        })
1470        .collect()
1471}
1472
1473/// Convert `LocyIsRef` to `IsRefBinding` by looking up scan indices in the registry.
1474///
1475/// `stratum_rule_names` is the set of rule names in the stratum being converted.
1476/// A reference is self-referential exactly when its target is in that set — the
1477/// same rule the planner used to mint the handle (see `get_or_create_derived_scan_handle`).
1478/// Selecting the entry whose `is_self_ref` matches that decision is essential for
1479/// negation: a recursive rule has BOTH a self-ref handle (carrying the final,
1480/// usually-empty semi-naive delta) and a non-self-ref handle (carrying the
1481/// converged facts). An `IS NOT <recursive rule>` reference is cross-stratum
1482/// (`is_self_ref == false`), so it must anti-join against the converged facts —
1483/// not the delta, which would silently under-filter.
1484fn convert_is_refs(
1485    is_refs: &[LocyIsRef],
1486    registry: &DerivedScanRegistry,
1487    stratum_rule_names: &std::collections::HashSet<&str>,
1488) -> DFResult<Vec<IsRefBinding>> {
1489    is_refs
1490        .iter()
1491        .map(|is_ref| {
1492            let entries = registry.entries_for_rule(&is_ref.rule_name);
1493            // Select the handle matching the planner's self-ref decision for this
1494            // reference: same-stratum targets use the delta (self-ref) handle for
1495            // semi-naive evaluation; cross-stratum targets (including every IS NOT
1496            // against a lower-stratum recursive rule) use the converged-facts handle.
1497            let want_self_ref = stratum_rule_names.contains(is_ref.rule_name.as_str());
1498            let entry = entries
1499                .iter()
1500                .find(|e| e.is_self_ref == want_self_ref)
1501                .or_else(|| entries.first())
1502                .ok_or_else(|| {
1503                    datafusion::error::DataFusionError::Plan(format!(
1504                        "No derived scan entry found for IS-ref to '{}'",
1505                        is_ref.rule_name
1506                    ))
1507                })?;
1508
1509            // For negated IS-refs, compute (left_body_col, right_derived_col) pairs for
1510            // anti-join filtering. Subject vars are assumed to be node variables, so
1511            // the body column is `{var}._vid` (UInt64). The derived column name is taken
1512            // positionally from the registry entry's schema (KEY columns come first).
1513            let anti_join_cols = if is_ref.negated {
1514                let mut cols: Vec<(String, String)> = is_ref
1515                    .subjects
1516                    .iter()
1517                    .enumerate()
1518                    .filter_map(|(i, s)| {
1519                        if let uni_cypher::ast::Expr::Variable(var) = s {
1520                            let right_col = entry
1521                                .schema
1522                                .fields()
1523                                .get(i)
1524                                .map(|f| f.name().clone())
1525                                .unwrap_or_else(|| var.clone());
1526                            // After LocyProject the subject column is renamed to the yield
1527                            // column name (just `var`, not `var._vid`). Use bare var as left.
1528                            Some((var.clone(), right_col))
1529                        } else {
1530                            None
1531                        }
1532                    })
1533                    .collect();
1534                // Include target variable in anti-join for composite-key IS NOT.
1535                // Without this, `d IS NOT known TO dis` only checks d, not (d, dis),
1536                // filtering ALL pairs where the drug has ANY indication regardless
1537                // of disease.
1538                if let Some(uni_cypher::ast::Expr::Variable(target_var)) = &is_ref.target {
1539                    let target_idx = is_ref.subjects.len();
1540                    if let Some(field) = entry.schema.fields().get(target_idx) {
1541                        cols.push((target_var.clone(), field.name().clone()));
1542                    }
1543                }
1544                cols
1545            } else {
1546                Vec::new()
1547            };
1548
1549            // Provenance join cols: for ALL IS-refs (not just negated), compute
1550            // (body_col, derived_col) pairs so shared-proof detection can trace
1551            // which source facts contributed to each derived row.
1552            let provenance_join_cols: Vec<(String, String)> = is_ref
1553                .subjects
1554                .iter()
1555                .enumerate()
1556                .filter_map(|(i, s)| {
1557                    if let uni_cypher::ast::Expr::Variable(var) = s {
1558                        let right_col = entry
1559                            .schema
1560                            .fields()
1561                            .get(i)
1562                            .map(|f| f.name().clone())
1563                            .unwrap_or_else(|| var.clone());
1564                        Some((var.clone(), right_col))
1565                    } else {
1566                        None
1567                    }
1568                })
1569                .collect();
1570
1571            Ok(IsRefBinding {
1572                derived_scan_index: entry.scan_index,
1573                rule_name: is_ref.rule_name.clone(),
1574                is_self_ref: entry.is_self_ref,
1575                negated: is_ref.negated,
1576                anti_join_cols,
1577                target_has_prob: is_ref.target_has_prob,
1578                target_prob_col: is_ref.target_prob_col.clone(),
1579                provenance_join_cols,
1580            })
1581        })
1582        .collect()
1583}
1584
1585/// Convert fold binding expressions to physical `FoldBinding`.
1586///
1587/// The input column is looked up by the fold binding's output name (e.g., "total")
1588/// in the yield schema, since the LocyProject aliases the aggregate input expression
1589/// to the fold output name. The aggregate name is resolved against
1590/// `plugin_registry` to obtain the [`uni_plugin::traits::locy::LocyAggregate`]
1591/// trait object at plan time.
1592fn convert_fold_bindings(
1593    fold_bindings: &[(String, String, Expr)],
1594    yield_schema: &[LocyYieldColumn],
1595    plugin_registry: &PluginRegistry,
1596) -> DFResult<Vec<FoldBinding>> {
1597    fold_bindings
1598        .iter()
1599        .map(|(name, yield_alias, expr)| {
1600            let (agg_name, _input_col_name) = parse_fold_aggregate(expr)?;
1601            let entry =
1602                resolve_locy_aggregate(plugin_registry, agg_name.as_str()).ok_or_else(|| {
1603                    datafusion::error::DataFusionError::Plan(format!(
1604                        "Unknown Locy aggregate '{agg_name}' — not registered in plugin registry"
1605                    ))
1606                })?;
1607            let aggregate = Arc::clone(&entry.aggregate);
1608
1609            // CountAll has no input column — LocyProject skips the output column
1610            // entirely, so there is nothing to look up.
1611            if agg_name.as_str() == "COUNTALL" {
1612                return Ok(FoldBinding {
1613                    output_name: yield_alias.clone(),
1614                    name: agg_name,
1615                    aggregate,
1616                    input_col_index: 0, // unused for CountAll
1617                    input_col_name: None,
1618                });
1619            }
1620
1621            // The LocyProject projects the aggregate input expression AS the fold
1622            // output name, so the input column index matches the yield schema position.
1623            // Also store the column name for name-based resolution at execution time
1624            // (more robust when schema reconciliation changes column ordering).
1625            let input_col_index = yield_schema
1626                .iter()
1627                .position(|yc| yc.name == *name || yc.name == *yield_alias)
1628                .unwrap_or(0);
1629            Ok(FoldBinding {
1630                output_name: yield_alias.clone(),
1631                name: agg_name,
1632                aggregate,
1633                input_col_index,
1634                input_col_name: Some(name.clone()),
1635            })
1636        })
1637        .collect()
1638}
1639
1640/// Parse a fold aggregate expression into (canonical_name, input_column_name).
1641///
1642/// Normalizes grammar aliases to canonical names: `MSUM`→`SUM`, `MMAX`→`MAX`,
1643/// `MMIN`→`MIN`, `MCOUNT`→`COUNT`. The zero-arg `COUNT()`/`MCOUNT()` form
1644/// returns the `COUNTALL` sentinel. `MNOR`/`MPROD` are already canonical.
1645fn parse_fold_aggregate(expr: &Expr) -> DFResult<(smol_str::SmolStr, String)> {
1646    match expr {
1647        Expr::FunctionCall { name, args, .. } => {
1648            let upper = name.to_uppercase();
1649            let is_count = matches!(upper.as_str(), "COUNT" | "MCOUNT");
1650
1651            // COUNT/MCOUNT with zero args → CountAll (like SQL COUNT(*))
1652            if is_count && args.is_empty() {
1653                return Ok((smol_str::SmolStr::new_static("COUNTALL"), String::new()));
1654            }
1655
1656            let canonical = match upper.as_str() {
1657                "SUM" | "MSUM" => smol_str::SmolStr::new_static("SUM"),
1658                "MAX" | "MMAX" => smol_str::SmolStr::new_static("MAX"),
1659                "MIN" | "MMIN" => smol_str::SmolStr::new_static("MIN"),
1660                "COUNT" | "MCOUNT" => smol_str::SmolStr::new_static("COUNT"),
1661                "AVG" => smol_str::SmolStr::new_static("AVG"),
1662                "COLLECT" => smol_str::SmolStr::new_static("COLLECT"),
1663                "MNOR" => smol_str::SmolStr::new_static("MNOR"),
1664                "MPROD" => smol_str::SmolStr::new_static("MPROD"),
1665                _ => {
1666                    return Err(datafusion::error::DataFusionError::Plan(format!(
1667                        "Unknown FOLD aggregate function: {}",
1668                        name
1669                    )));
1670                }
1671            };
1672            let col_name = match args.first() {
1673                Some(Expr::Variable(v)) => v.clone(),
1674                Some(Expr::Property(_, prop)) => prop.clone(),
1675                Some(other) => other.to_string_repr(),
1676                None => {
1677                    return Err(datafusion::error::DataFusionError::Plan(
1678                        "FOLD aggregate function requires at least one argument".to_string(),
1679                    ));
1680                }
1681            };
1682            Ok((canonical, col_name))
1683        }
1684        _ => Err(datafusion::error::DataFusionError::Plan(
1685            "FOLD binding must be a function call (e.g., SUM(x))".to_string(),
1686        )),
1687    }
1688}
1689
1690/// Convert best-by criteria expressions to physical `SortCriterion`.
1691///
1692/// Resolves the criteria column by trying:
1693/// 1. Property name (e.g., `e.cost` → "cost")
1694/// 2. Variable name (e.g., `cost`)
1695/// 3. Full expression string (e.g., "e.cost" as a variable name)
1696fn convert_best_by_criteria(
1697    criteria: &[(Expr, bool)],
1698    yield_schema: &[LocyYieldColumn],
1699) -> DFResult<Vec<SortCriterion>> {
1700    criteria
1701        .iter()
1702        .map(|(expr, ascending)| {
1703            let col_name = match expr {
1704                Expr::Property(_, prop) => prop.clone(),
1705                Expr::Variable(v) => v.clone(),
1706                _ => {
1707                    return Err(datafusion::error::DataFusionError::Plan(
1708                        "BEST BY criterion must be a variable or property reference".to_string(),
1709                    ));
1710                }
1711            };
1712            // Try exact match first, then try just the last component after '.'
1713            let col_index = yield_schema
1714                .iter()
1715                .position(|yc| yc.name == col_name)
1716                .or_else(|| {
1717                    let short_name = col_name.rsplit('.').next().unwrap_or(&col_name);
1718                    yield_schema.iter().position(|yc| yc.name == short_name)
1719                })
1720                .ok_or_else(|| {
1721                    datafusion::error::DataFusionError::Plan(format!(
1722                        "BEST BY column '{}' not found in yield schema",
1723                        col_name
1724                    ))
1725                })?;
1726            Ok(SortCriterion {
1727                col_index,
1728                ascending: *ascending,
1729                nulls_first: false,
1730            })
1731        })
1732        .collect()
1733}
1734
1735// ---------------------------------------------------------------------------
1736// Schema helpers
1737// ---------------------------------------------------------------------------
1738
1739/// Convert `LocyYieldColumn` slice to Arrow schema using inferred types.
1740fn yield_columns_to_arrow_schema(columns: &[LocyYieldColumn]) -> ArrowSchema {
1741    let fields: Vec<Arc<Field>> = columns
1742        .iter()
1743        .map(|yc| Arc::new(Field::new(&yc.name, yc.data_type.clone(), true)))
1744        .collect();
1745    ArrowSchema::new(fields)
1746}
1747
1748/// Build a combined output schema for fixpoint (union of all rules' schemas).
1749fn build_fixpoint_output_schema(rules: &[LocyRulePlan]) -> SchemaRef {
1750    // FixpointExec concatenates all rules' output, using the first rule's schema
1751    // as the output schema (all rules in a recursive stratum share compatible schemas).
1752    if let Some(rule) = rules.first() {
1753        Arc::new(yield_columns_to_arrow_schema(&rule.yield_schema))
1754    } else {
1755        Arc::new(ArrowSchema::empty())
1756    }
1757}
1758
1759/// Build a stats RecordBatch summarizing derived relation counts.
1760fn build_stats_batch(
1761    derived_store: &DerivedStore,
1762    _strata: &[LocyStratum],
1763    output_schema: SchemaRef,
1764) -> RecordBatch {
1765    // Build a simple stats batch with rule_name and fact_count columns
1766    let mut rule_names: Vec<String> = derived_store.rule_names().map(String::from).collect();
1767    rule_names.sort();
1768
1769    let name_col: arrow_array::StringArray = rule_names.iter().map(|s| Some(s.as_str())).collect();
1770    let count_col: arrow_array::Int64Array = rule_names
1771        .iter()
1772        .map(|name| Some(derived_store.fact_count(name) as i64))
1773        .collect();
1774
1775    let stats_schema = stats_schema();
1776    RecordBatch::try_new(stats_schema, vec![Arc::new(name_col), Arc::new(count_col)])
1777        .unwrap_or_else(|_| RecordBatch::new_empty(output_schema))
1778}
1779
1780/// Schema for the stats batch returned when no commands are present.
1781pub fn stats_schema() -> SchemaRef {
1782    Arc::new(ArrowSchema::new(vec![
1783        Arc::new(Field::new("rule_name", DataType::Utf8, false)),
1784        Arc::new(Field::new("fact_count", DataType::Int64, false)),
1785    ]))
1786}
1787
1788// ---------------------------------------------------------------------------
1789// Unit tests
1790// ---------------------------------------------------------------------------
1791
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, LargeBinaryArray, StringArray};

    /// Build a batch with a single nullable `x: LargeBinary` column holding
    /// the UTF-8 bytes of each entry in `values`.
    fn binary_batch(values: &[&str]) -> RecordBatch {
        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
            "x",
            DataType::LargeBinary,
            true,
        ))]));
        let column: Vec<Option<&[u8]>> = values.iter().map(|v| Some(v.as_bytes())).collect();
        RecordBatch::try_new(schema, vec![Arc::new(LargeBinaryArray::from(column))]).unwrap()
    }

    /// Build a non-PROB yield column.
    fn yield_column(name: &str, is_key: bool, data_type: DataType) -> LocyYieldColumn {
        LocyYieldColumn {
            name: name.to_string(),
            is_key,
            is_prob: false,
            data_type,
        }
    }

    /// Build a plain (non-DISTINCT, non-window) function-call expression.
    fn call(name: &str, args: Vec<Expr>) -> Expr {
        Expr::FunctionCall {
            name: name.to_string(),
            args,
            distinct: false,
            window_spec: None,
        }
    }

    /// Shorthand for a variable expression.
    fn var(name: &str) -> Expr {
        Expr::Variable(name.to_string())
    }

    #[test]
    fn test_derived_store_insert_and_get() {
        let mut store = DerivedStore::new();
        assert!(store.get("test").is_none());

        store.insert("test".to_string(), vec![binary_batch(&["a", "b"])]);

        let facts = store.get("test").unwrap();
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].num_rows(), 2);
    }

    #[test]
    fn test_derived_store_fact_count() {
        let mut store = DerivedStore::new();
        assert_eq!(store.fact_count("empty"), 0);

        store.insert(
            "test".to_string(),
            vec![binary_batch(&["a"]), binary_batch(&["b", "c"])],
        );
        assert_eq!(store.fact_count("test"), 3);
    }

    #[test]
    fn test_stats_batch_schema() {
        let schema = stats_schema();
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.field(0).name(), "rule_name");
        assert_eq!(schema.field(1).name(), "fact_count");
        assert_eq!(schema.field(0).data_type(), &DataType::Utf8);
        assert_eq!(schema.field(1).data_type(), &DataType::Int64);
    }

    #[test]
    fn test_stats_batch_content() {
        let mut store = DerivedStore::new();
        store.insert("reach".to_string(), vec![binary_batch(&["a", "b"])]);

        let output_schema = stats_schema();
        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
        assert_eq!(stats.num_rows(), 1);

        let names = stats
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(names.value(0), "reach");

        let counts = stats
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(counts.value(0), 2);
    }

    #[test]
    fn test_stats_batch_sorted_by_rule_name() {
        // Insert out of order; the stats batch must list rules alphabetically.
        let mut store = DerivedStore::new();
        store.insert("zeta".to_string(), vec![binary_batch(&["a"])]);
        store.insert("alpha".to_string(), vec![binary_batch(&["a", "b", "c"])]);

        let stats = build_stats_batch(&store, &[], stats_schema());
        assert_eq!(stats.num_rows(), 2);

        let names = stats
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(names.value(0), "alpha");
        assert_eq!(names.value(1), "zeta");

        let counts = stats
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(counts.value(0), 3);
        assert_eq!(counts.value(1), 1);
    }

    #[test]
    fn test_yield_columns_to_arrow_schema() {
        let columns = vec![
            yield_column("a", true, DataType::UInt64),
            yield_column("b", false, DataType::LargeUtf8),
            yield_column("c", true, DataType::Float64),
        ];

        let schema = yield_columns_to_arrow_schema(&columns);
        assert_eq!(schema.fields().len(), 3);
        assert_eq!(schema.field(0).name(), "a");
        assert_eq!(schema.field(1).name(), "b");
        assert_eq!(schema.field(2).name(), "c");
        // Fields use inferred types
        assert_eq!(schema.field(0).data_type(), &DataType::UInt64);
        assert_eq!(schema.field(1).data_type(), &DataType::LargeUtf8);
        assert_eq!(schema.field(2).data_type(), &DataType::Float64);
        for field in schema.fields() {
            assert!(field.is_nullable());
        }
    }

    #[test]
    fn test_build_fixpoint_output_schema_empty() {
        // No rules → empty schema.
        let schema = build_fixpoint_output_schema(&[]);
        assert_eq!(schema.fields().len(), 0);
    }

    #[test]
    fn test_key_column_indices() {
        let columns = [
            yield_column("a", true, DataType::LargeBinary),
            yield_column("b", false, DataType::LargeBinary),
            yield_column("c", true, DataType::LargeBinary),
        ];

        let key_indices: Vec<usize> = columns
            .iter()
            .enumerate()
            .filter(|(_, yc)| yc.is_key)
            .map(|(i, _)| i)
            .collect();
        assert_eq!(key_indices, vec![0, 2]);
    }

    #[test]
    fn test_parse_fold_aggregate_sum() {
        let expr = call("SUM", vec![var("cost")]);
        let (kind, col) = parse_fold_aggregate(&expr).unwrap();
        assert_eq!(kind.as_str(), "SUM");
        assert_eq!(col, "cost");
    }

    #[test]
    fn test_parse_fold_aggregate_monotonic() {
        let expr = call("MMAX", vec![var("score")]);
        let (kind, col) = parse_fold_aggregate(&expr).unwrap();
        assert_eq!(kind.as_str(), "MAX");
        assert_eq!(col, "score");
    }

    #[test]
    fn test_parse_fold_aggregate_case_insensitive() {
        let expr = call("msum", vec![var("x")]);
        let (kind, col) = parse_fold_aggregate(&expr).unwrap();
        assert_eq!(kind.as_str(), "SUM");
        assert_eq!(col, "x");
    }

    #[test]
    fn test_parse_fold_aggregate_count_all() {
        // Zero-arg COUNT()/MCOUNT() → COUNTALL sentinel with no input column.
        for name in ["COUNT", "mcount"] {
            let (kind, col) = parse_fold_aggregate(&call(name, vec![])).unwrap();
            assert_eq!(kind.as_str(), "COUNTALL");
            assert!(col.is_empty());
        }
    }

    #[test]
    fn test_parse_fold_aggregate_count_with_argument() {
        let (kind, col) = parse_fold_aggregate(&call("MCOUNT", vec![var("n")])).unwrap();
        assert_eq!(kind.as_str(), "COUNT");
        assert_eq!(col, "n");
    }

    #[test]
    fn test_parse_fold_aggregate_missing_argument() {
        // Only the count family may be called without arguments.
        assert!(parse_fold_aggregate(&call("SUM", vec![])).is_err());
    }

    #[test]
    fn test_parse_fold_aggregate_rejects_non_function() {
        assert!(parse_fold_aggregate(&var("x")).is_err());
    }

    #[test]
    fn test_parse_fold_aggregate_unknown() {
        let expr = call("UNKNOWN_AGG", vec![var("x")]);
        assert!(parse_fold_aggregate(&expr).is_err());
    }

    #[test]
    fn test_no_commands_returns_stats() {
        let store = DerivedStore::new();
        let output_schema = stats_schema();
        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
        // Empty store → 0 rows
        assert_eq!(stats.num_rows(), 0);
    }
}