Skip to main content

uni_query/query/df_graph/
locy_fixpoint.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Fixpoint iteration operator for recursive Locy strata.
5//!
6//! `FixpointExec` drives semi-naive evaluation: it repeatedly evaluates the rules
7//! in a recursive stratum, feeding back deltas until no new facts are produced.
8
9use crate::query::df_graph::GraphExecutionContext;
10use crate::query::df_graph::common::{
11    ScalarKey, arrow_err, collect_all_partitions, compute_plan_properties, execute_subplan,
12    extract_scalar_key,
13};
14use crate::query::df_graph::locy_best_by::{BestByExec, SortCriterion};
15use crate::query::df_graph::locy_errors::LocyRuntimeError;
16use crate::query::df_graph::locy_explain::{
17    ProofTerm, ProvenanceAnnotation, ProvenanceStore, compute_proof_probability,
18};
19use crate::query::df_graph::locy_fold::{FoldBinding, FoldExec};
20use crate::query::df_graph::locy_priority::PriorityExec;
21use crate::query::df_graph::locy_program::interruption;
22use crate::query::planner::LogicalPlan;
23use arrow_array::RecordBatch;
24use arrow_row::{RowConverter, SortField};
25use arrow_schema::SchemaRef;
26use datafusion::common::JoinType;
27use datafusion::common::Result as DFResult;
28use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
29use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
30use datafusion::physical_plan::memory::MemoryStream;
31use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
32use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
33use futures::Stream;
34use parking_lot::RwLock;
35use std::any::Any;
36use std::collections::{HashMap, HashSet};
37use std::fmt;
38use std::pin::Pin;
39use std::sync::{Arc, RwLock as StdRwLock};
40use std::task::{Context, Poll};
41use std::time::{Duration, Instant};
42use uni_common::Value;
43use uni_common::core::schema::Schema as UniSchema;
44use uni_cypher::ast::Expr;
45use uni_locy::{
46    ClassifierRegistry, ModelInvocation, ModelInvocationCache, RuntimeWarning, RuntimeWarningCode,
47    SemiringKind,
48};
49use uni_store::storage::manager::StorageManager;
50
51// ---------------------------------------------------------------------------
52// DerivedScanRegistry — injection point for IS-ref data into subplans
53// ---------------------------------------------------------------------------
54
/// A single entry in the derived scan registry.
///
/// Each entry corresponds to one `LocyDerivedScan` node in the logical plan tree.
/// The `data` handle is shared with the logical plan node so that writing data here
/// makes it visible when the subplan is re-planned and executed.
///
/// Entries are held by [`DerivedScanRegistry`], which looks them up either by
/// `scan_index` or by `rule_name`.
#[derive(Debug)]
pub struct DerivedScanEntry {
    /// Index matching the `scan_index` in `LocyDerivedScan`. This is the key
    /// used by [`DerivedScanRegistry::get`] and
    /// [`DerivedScanRegistry::write_data`].
    pub scan_index: usize,
    /// Name of the rule this scan reads from. Used by
    /// [`DerivedScanRegistry::entries_for_rule`] to select entries.
    pub rule_name: String,
    /// Whether this is a self-referential scan (rule references itself).
    pub is_self_ref: bool,
    /// Shared data handle — write batches here to inject into subplans.
    /// [`DerivedScanRegistry::write_data`] replaces the whole vector under
    /// the write lock.
    pub data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Schema of the derived relation.
    pub schema: SchemaRef,
}
73
/// Registry of derived scan handles for fixpoint iteration.
///
/// During fixpoint, each clause body may reference derived relations via
/// `LocyDerivedScan` nodes. The registry maps scan indices to shared data
/// handles so the fixpoint loop can inject delta/full facts before each
/// iteration.
///
/// The `Default` value is an empty registry.
#[derive(Debug, Default)]
pub struct DerivedScanRegistry {
    /// Registered entries in insertion order. Lookups scan this vector
    /// linearly and return the first match.
    entries: Vec<DerivedScanEntry>,
}
84
85impl DerivedScanRegistry {
86    /// Create a new empty registry.
87    pub fn new() -> Self {
88        Self::default()
89    }
90
91    /// Add an entry to the registry.
92    pub fn add(&mut self, entry: DerivedScanEntry) {
93        self.entries.push(entry);
94    }
95
96    /// Get an entry by scan index.
97    pub fn get(&self, scan_index: usize) -> Option<&DerivedScanEntry> {
98        self.entries.iter().find(|e| e.scan_index == scan_index)
99    }
100
101    /// Write data into a scan entry's shared handle.
102    pub fn write_data(&self, scan_index: usize, batches: Vec<RecordBatch>) {
103        if let Some(entry) = self.get(scan_index) {
104            let mut guard = entry.data.write();
105            *guard = batches;
106        }
107    }
108
109    /// Get all entries for a given rule name.
110    pub fn entries_for_rule(&self, rule_name: &str) -> Vec<&DerivedScanEntry> {
111        self.entries
112            .iter()
113            .filter(|e| e.rule_name == rule_name)
114            .collect()
115    }
116}
117
118// ---------------------------------------------------------------------------
119// MonotonicAggState — tracking monotonic aggregates across iterations
120// ---------------------------------------------------------------------------
121
/// Monotonic aggregate binding: maps a fold name to its aggregate
/// trait object and input column.
///
/// Dispatches purely through [`uni_plugin::traits::locy::LocyAggregate`]
/// (`update_step` / `initial_accum_f64` / `is_probability_aggregate` /
/// `is_noisy_or`).
#[derive(Debug, Clone)]
pub struct MonotonicFoldBinding {
    /// Name of the fold. Together with the group key it forms the
    /// accumulator key inside [`MonotonicAggState`].
    pub fold_name: String,
    /// Aggregate implementation that seeds and steps the accumulator.
    pub aggregate: std::sync::Arc<dyn uni_plugin::traits::locy::LocyAggregate>,
    /// Positional index of the input column. Used when `input_col_name` is
    /// `None` or is not present in the batch schema.
    pub input_col_index: usize,
    /// Column name for name-based resolution (more robust than positional index).
    pub input_col_name: Option<String>,
}
136
/// Tracks monotonic aggregate accumulators across fixpoint iterations.
///
/// After each iteration, accumulators are updated and compared to their previous
/// snapshot. The fixpoint has converged (w.r.t. aggregates) when all accumulators
/// are stable (no change between iterations).
///
/// Stability is judged by `is_stable`, which compares every accumulator with
/// its snapshot value using an `f64::EPSILON` tolerance.
#[derive(Debug)]
pub struct MonotonicAggState {
    /// Current accumulator values keyed by (group_key, fold_name).
    accumulators: HashMap<(Vec<ScalarKey>, String), f64>,
    /// Snapshot from the previous iteration for stability check.
    /// Overwritten with a copy of `accumulators` by `snapshot`.
    prev_snapshot: HashMap<(Vec<ScalarKey>, String), f64>,
    /// Bindings describing which aggregates to track.
    bindings: Vec<MonotonicFoldBinding>,
}
151
152impl MonotonicAggState {
153    /// Create a new monotonic aggregate state.
154    pub fn new(bindings: Vec<MonotonicFoldBinding>) -> Self {
155        Self {
156            accumulators: HashMap::new(),
157            prev_snapshot: HashMap::new(),
158            bindings,
159        }
160    }
161
162    /// Update accumulators with new delta batches.
163    ///
164    /// Returns `true` if any accumulator value changed. When `strict` is
165    /// `true`, MNOR/MPROD inputs outside `[0, 1]` produce an error
166    /// instead of being clamped.
167    ///
168    /// `semiring_kind` selects the probability semiring for probability
169    /// aggregates: `AddMultProb` (default, Phase 1/2 noisy-OR/product) or
170    /// `MaxMinProb` (Viterbi/fuzzy — opt-in, callers emit
171    /// `FuzzyNotProbabilistic`).
172    ///
173    /// Dispatch goes through each binding's `Arc<dyn LocyAggregate>` trait
174    /// object via [`uni_plugin::traits::locy::LocyAggregate::update_step`].
175    /// The trait object's `initial_accum_f64()` seeds the per-group
176    /// accumulator. Under `MaxMinProb`, probability aggregates (MNOR /
177    /// MPROD) bypass `update_step` and fold via the `MaxMinProb` semiring's
178    /// `plus` (max) / `times` (min) instead — preserving the opt-in
179    /// Viterbi/fuzzy semantics that `update_step`'s built-in noisy-OR /
180    /// product path does not implement.
181    ///
182    /// Aggregates whose `update_step` returns `Err(CODE_UNKNOWN_FUNCTION)`
183    /// (default impl — no row-level fast path; e.g., `AVG`, `COLLECT`)
184    /// are skipped silently here — those run through the batch-shape
185    /// [`uni_plugin::traits::locy::LocyAggState::ingest`] path in
186    /// `apply_post_fixpoint_chain` instead.
187    pub fn update(
188        &mut self,
189        key_indices: &[usize],
190        delta_batches: &[RecordBatch],
191        strict: bool,
192        semiring_kind: SemiringKind,
193    ) -> DFResult<bool> {
194        let mut changed = false;
195        for batch in delta_batches {
196            for row_idx in 0..batch.num_rows() {
197                let group_key = extract_scalar_key(batch, key_indices, row_idx);
198                for binding in &self.bindings {
199                    let idx = binding
200                        .input_col_name
201                        .as_ref()
202                        .and_then(|name| batch.schema().index_of(name).ok())
203                        .unwrap_or(binding.input_col_index);
204                    if idx >= batch.num_columns() {
205                        continue;
206                    }
207                    let col = batch.column(idx);
208                    let val = extract_f64(col.as_ref(), row_idx);
209                    if let Some(val) = val {
210                        let map_key = (group_key.clone(), binding.fold_name.clone());
211                        let initial = binding.aggregate.initial_accum_f64().unwrap_or(0.0);
212                        let entry = self.accumulators.entry(map_key).or_insert(initial);
213                        let old = *entry;
214                        // Under `MaxMinProb`, probability aggregates (MNOR /
215                        // MPROD) fold via the Viterbi/fuzzy semiring (max /
216                        // min) rather than the trait object's built-in
217                        // noisy-OR / product `update_step`. The inline domain
218                        // checks below preserve the exact strict-mode error
219                        // and clamp-warning literals. `is_noisy_or()`
220                        // distinguishes MNOR (disjunction → max) from MPROD
221                        // (conjunction → min). All other aggregates — and the
222                        // default `AddMultProb` semiring — dispatch through
223                        // the trait object's `update_step`.
224                        if matches!(semiring_kind, SemiringKind::MaxMinProb)
225                            && binding.aggregate.is_probability_aggregate()
226                        {
227                            use uni_locy::LocySemiring;
228                            let sr = uni_locy::MaxMinProb;
229                            let is_nor = binding.aggregate.is_noisy_or();
230                            let label = if is_nor { "MNOR" } else { "MPROD" };
231                            if strict && !(0.0..=1.0).contains(&val) {
232                                return Err(datafusion::error::DataFusionError::Execution(
233                                    format!(
234                                        "strict_probability_domain: {label} input {val} is outside [0, 1]"
235                                    ),
236                                ));
237                            }
238                            if !strict && !(0.0..=1.0).contains(&val) {
239                                tracing::warn!(
240                                    "{label} input {val} outside [0,1], clamped to {}",
241                                    val.clamp(0.0, 1.0)
242                                );
243                            }
244                            let p = val.clamp(0.0, 1.0);
245                            // MaxMinProb: MNOR -> max (plus), MPROD -> min (times).
246                            *entry = if is_nor {
247                                sr.plus(entry, &p)
248                            } else {
249                                sr.times(entry, &p)
250                            };
251                            if (*entry - old).abs() > f64::EPSILON {
252                                changed = true;
253                            }
254                            continue;
255                        }
256                        match binding.aggregate.update_step(*entry, val, strict) {
257                            Ok(new_val) => {
258                                *entry = new_val;
259                                if (*entry - old).abs() > f64::EPSILON {
260                                    changed = true;
261                                }
262                            }
263                            Err(e) if e.code == uni_plugin::FnError::CODE_UNKNOWN_FUNCTION => {
264                                // Aggregate has no row-level fast path (AVG,
265                                // COLLECT). Those run through the
266                                // batch-shape `ingest` path elsewhere; skip.
267                            }
268                            Err(e) => {
269                                // Strict-mode probability-domain violation,
270                                // or another aggregate-specific failure.
271                                return Err(datafusion::error::DataFusionError::Execution(
272                                    e.message,
273                                ));
274                            }
275                        }
276                    }
277                }
278            }
279        }
280        Ok(changed)
281    }
282
283    /// Take a snapshot of current accumulators for stability comparison.
284    pub fn snapshot(&mut self) {
285        self.prev_snapshot = self.accumulators.clone();
286    }
287
288    /// Check if accumulators are stable (no change since last snapshot).
289    pub fn is_stable(&self) -> bool {
290        if self.accumulators.len() != self.prev_snapshot.len() {
291            return false;
292        }
293        for (key, val) in &self.accumulators {
294            match self.prev_snapshot.get(key) {
295                Some(prev) if (*val - *prev).abs() <= f64::EPSILON => {}
296                _ => return false,
297            }
298        }
299        true
300    }
301
302    /// Test-only accessor for accumulator values.
303    #[cfg(test)]
304    pub(crate) fn get_accumulator(&self, key: &(Vec<ScalarKey>, String)) -> Option<f64> {
305        self.accumulators.get(key).copied()
306    }
307}
308
309/// Extract f64 value from an Arrow column at a given row index.
310fn extract_f64(col: &dyn arrow_array::Array, row_idx: usize) -> Option<f64> {
311    if col.is_null(row_idx) {
312        return None;
313    }
314    if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Float64Array>() {
315        Some(arr.value(row_idx))
316    } else {
317        col.as_any()
318            .downcast_ref::<arrow_array::Int64Array>()
319            .map(|arr| arr.value(row_idx) as f64)
320    }
321}
322
323// ---------------------------------------------------------------------------
324// RowDedupState — Arrow RowConverter-based persistent dedup set
325// ---------------------------------------------------------------------------
326
/// Arrow-native row deduplication using [`RowConverter`].
///
/// Unlike the legacy `HashSet<Vec<ScalarKey>>` approach, this struct maintains a
/// persistent `seen` set across iterations so per-iteration cost is O(M) where M
/// is the number of candidate rows — the full facts table is never re-scanned.
struct RowDedupState {
    /// Converter built from the schema's column data types; encodes every
    /// column of a batch into one byte row per record.
    converter: RowConverter,
    /// Encoded byte rows of every record accepted so far.
    seen: HashSet<Box<[u8]>>,
}
336
337impl RowDedupState {
338    /// Try to build a `RowDedupState` for the given schema.
339    ///
340    /// Returns `None` if any column type is not supported by `RowConverter`
341    /// (triggers legacy fallback).
342    fn try_new(schema: &SchemaRef) -> Option<Self> {
343        let fields: Vec<SortField> = schema
344            .fields()
345            .iter()
346            .map(|f| SortField::new(f.data_type().clone()))
347            .collect();
348        match RowConverter::new(fields) {
349            Ok(converter) => Some(Self {
350                converter,
351                seen: HashSet::new(),
352            }),
353            Err(e) => {
354                tracing::warn!(
355                    "RowDedupState: RowConverter unsupported for schema, falling back to legacy dedup: {}",
356                    e
357                );
358                None
359            }
360        }
361    }
362
363    /// Populate the seen set from existing fact batches.
364    ///
365    /// Used after BEST BY in-loop pruning replaces the fact set, so that delta
366    /// computation in subsequent iterations correctly recognizes surviving facts.
367    fn ingest_existing(&mut self, facts: &[RecordBatch], _schema: &SchemaRef) {
368        self.seen.clear();
369        for batch in facts {
370            if batch.num_rows() == 0 {
371                continue;
372            }
373            let arrays: Vec<_> = batch.columns().to_vec();
374            if let Ok(rows) = self.converter.convert_columns(&arrays) {
375                for row_idx in 0..batch.num_rows() {
376                    let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
377                    self.seen.insert(row_bytes);
378                }
379            }
380        }
381    }
382
383    /// Filter `candidates` to only rows not yet seen, updating the persistent set.
384    ///
385    /// Both cross-iteration dedup (rows already accepted in prior iterations) and
386    /// within-batch dedup (duplicate rows in a single candidate batch) are handled
387    /// in a single pass.
388    fn compute_delta(
389        &mut self,
390        candidates: &[RecordBatch],
391        schema: &SchemaRef,
392    ) -> DFResult<Vec<RecordBatch>> {
393        let mut delta_batches = Vec::new();
394        for batch in candidates {
395            if batch.num_rows() == 0 {
396                continue;
397            }
398
399            // Vectorized encoding of all rows in this batch.
400            let arrays: Vec<_> = batch.columns().to_vec();
401            let rows = self.converter.convert_columns(&arrays).map_err(arrow_err)?;
402
403            // One pass: check+insert into persistent seen set.
404            let mut keep = Vec::with_capacity(batch.num_rows());
405            for row_idx in 0..batch.num_rows() {
406                let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
407                keep.push(self.seen.insert(row_bytes));
408            }
409
410            let keep_mask = arrow_array::BooleanArray::from(keep);
411            let new_cols = batch
412                .columns()
413                .iter()
414                .map(|col| {
415                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
416                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
417                    })
418                })
419                .collect::<DFResult<Vec<_>>>()?;
420
421            if new_cols.first().is_some_and(|c| !c.is_empty()) {
422                let filtered = RecordBatch::try_new(Arc::clone(schema), new_cols).map_err(|e| {
423                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
424                })?;
425                delta_batches.push(filtered);
426            }
427        }
428        Ok(delta_batches)
429    }
430}
431
432// ---------------------------------------------------------------------------
433// FixpointState — per-rule delta tracking during fixpoint iteration
434// ---------------------------------------------------------------------------
435
/// Per-rule state for fixpoint iteration.
///
/// Tracks accumulated facts and the delta (new facts from the latest iteration).
/// Deduplication uses Arrow [`RowConverter`] with a persistent seen set (O(M) per
/// iteration) when supported, with a legacy `HashSet<Vec<ScalarKey>>` fallback.
pub struct FixpointState {
    /// Name of the rule; appears in log fields and memory-limit errors.
    rule_name: String,
    /// All facts accepted so far; each merged delta is appended here.
    facts: Vec<RecordBatch>,
    /// Rows newly accepted by the most recent merge; cleared when a merge
    /// adds nothing.
    delta: Vec<RecordBatch>,
    /// Schema used when rebuilding filtered batches. Replaced by
    /// `reconcile_schema` when the physical plan output differs.
    schema: SchemaRef,
    /// Positions of the KEY columns within `schema`.
    key_column_indices: Vec<usize>,
    /// KEY column names for recomputing indices after schema reconciliation.
    key_column_names: Vec<String>,
    /// All column indices for full-row dedup (legacy path only).
    all_column_indices: Vec<usize>,
    /// Running total of facts bytes for memory limit tracking.
    facts_bytes: usize,
    /// Maximum bytes allowed for this derived relation.
    max_derived_bytes: usize,
    /// Optional monotonic aggregate tracking.
    monotonic_agg: Option<MonotonicAggState>,
    /// Arrow RowConverter-based dedup state; `None` triggers legacy fallback.
    row_dedup: Option<RowDedupState>,
    /// Whether strict probability domain checks are enabled.
    strict_probability_domain: bool,
    /// Active probability semiring for this rule's MNOR/MPROD math.
    semiring_kind: SemiringKind,
}
464
465impl FixpointState {
466    /// Create a new fixpoint state for a rule. Existing tests call this
467    /// with the Phase 1/2 default; the fixpoint planner uses
468    /// [`FixpointState::new_with_semiring`] to thread the configured
469    /// semiring through.
470    pub fn new(
471        rule_name: String,
472        schema: SchemaRef,
473        key_column_indices: Vec<usize>,
474        max_derived_bytes: usize,
475        monotonic_agg: Option<MonotonicAggState>,
476        strict_probability_domain: bool,
477    ) -> Self {
478        Self::new_with_semiring(
479            rule_name,
480            schema,
481            key_column_indices,
482            max_derived_bytes,
483            monotonic_agg,
484            strict_probability_domain,
485            SemiringKind::AddMultProb,
486        )
487    }
488
489    pub fn new_with_semiring(
490        rule_name: String,
491        schema: SchemaRef,
492        key_column_indices: Vec<usize>,
493        max_derived_bytes: usize,
494        monotonic_agg: Option<MonotonicAggState>,
495        strict_probability_domain: bool,
496        semiring_kind: SemiringKind,
497    ) -> Self {
498        let num_cols = schema.fields().len();
499        let row_dedup = RowDedupState::try_new(&schema);
500        let key_column_names: Vec<String> = key_column_indices
501            .iter()
502            .filter_map(|&i| schema.fields().get(i).map(|f| f.name().clone()))
503            .collect();
504        Self {
505            rule_name,
506            facts: Vec::new(),
507            delta: Vec::new(),
508            schema,
509            key_column_indices,
510            key_column_names,
511            all_column_indices: (0..num_cols).collect(),
512            facts_bytes: 0,
513            max_derived_bytes,
514            monotonic_agg,
515            row_dedup,
516            strict_probability_domain,
517            semiring_kind,
518        }
519    }
520
521    /// Reconcile the pre-computed schema with the actual physical plan output.
522    ///
523    /// `infer_expr_type` may guess wrong (e.g. `Property → Float64` for a
524    /// string column).  When the first real batch arrives with a different
525    /// schema, update ours so that `RowDedupState` / `RecordBatch::try_new`
526    /// use the correct types.
527    fn reconcile_schema(&mut self, actual_schema: &SchemaRef) {
528        if self.schema.fields() != actual_schema.fields() {
529            tracing::debug!(
530                rule = %self.rule_name,
531                "Reconciling fixpoint schema from physical plan output",
532            );
533            self.schema = Arc::clone(actual_schema);
534            self.row_dedup = RowDedupState::try_new(&self.schema);
535            // Recompute key_column_indices from stored KEY column names.
536            // Without this, FoldExec groups by wrong columns when the
537            // physical plan reorders columns vs the pre-inferred schema.
538            let new_indices: Vec<usize> = self
539                .key_column_names
540                .iter()
541                .filter_map(|name| actual_schema.index_of(name).ok())
542                .collect();
543            if new_indices.len() == self.key_column_names.len() {
544                self.key_column_indices = new_indices;
545            }
546            // else: not all KEY columns found in new schema — keep original indices
547            let num_cols = actual_schema.fields().len();
548            self.all_column_indices = (0..num_cols).collect();
549        }
550    }
551
552    /// Merge candidate rows into facts, computing delta (truly new rows).
553    ///
554    /// Returns `true` if any new facts were added.
555    pub async fn merge_delta(
556        &mut self,
557        candidates: Vec<RecordBatch>,
558        task_ctx: Option<Arc<TaskContext>>,
559    ) -> DFResult<bool> {
560        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
561            self.delta.clear();
562            return Ok(false);
563        }
564
565        // Reconcile schema from the first non-empty candidate batch.
566        // The physical plan's output types are authoritative over the
567        // planner's inferred types.
568        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
569            self.reconcile_schema(&first.schema());
570        }
571
572        // Round floats for stable dedup
573        let candidates = round_float_columns(&candidates);
574
575        // Compute delta: rows in candidates not already in facts
576        let delta = self.compute_delta(&candidates, task_ctx.as_ref()).await?;
577
578        if delta.is_empty() || delta.iter().all(|b| b.num_rows() == 0) {
579            self.delta.clear();
580            // Update monotonic aggs even with empty delta (for stability check)
581            if let Some(ref mut agg) = self.monotonic_agg {
582                agg.snapshot();
583            }
584            return Ok(false);
585        }
586
587        // Check memory limit
588        let delta_bytes: usize = delta.iter().map(batch_byte_size).sum();
589        if self.facts_bytes + delta_bytes > self.max_derived_bytes {
590            return Err(datafusion::error::DataFusionError::Execution(
591                LocyRuntimeError::MemoryLimitExceeded {
592                    rule: self.rule_name.clone(),
593                    bytes: self.facts_bytes + delta_bytes,
594                    limit: self.max_derived_bytes,
595                }
596                .to_string(),
597            ));
598        }
599
600        // Update monotonic aggs
601        if let Some(ref mut agg) = self.monotonic_agg {
602            agg.snapshot();
603            agg.update(
604                &self.key_column_indices,
605                &delta,
606                self.strict_probability_domain,
607                self.semiring_kind,
608            )?;
609        }
610
611        // Append delta to facts
612        self.facts_bytes += delta_bytes;
613        self.facts.extend(delta.iter().cloned());
614        self.delta = delta;
615
616        Ok(true)
617    }
618
619    /// Dispatch to vectorized LeftAntiJoin, Arrow RowConverter dedup, or legacy ScalarKey dedup.
620    ///
621    /// Priority order:
622    /// 1. `arrow_left_anti_dedup` when `total_existing >= DEDUP_ANTI_JOIN_THRESHOLD` and task_ctx available.
623    /// 2. `RowDedupState` (persistent HashSet, O(M) per iteration) when schema is supported.
624    /// 3. `compute_delta_legacy` (rebuilds from facts, fallback for unsupported column types).
625    async fn compute_delta(
626        &mut self,
627        candidates: &[RecordBatch],
628        task_ctx: Option<&Arc<TaskContext>>,
629    ) -> DFResult<Vec<RecordBatch>> {
630        let total_existing: usize = self.facts.iter().map(|b| b.num_rows()).sum();
631        if total_existing >= DEDUP_ANTI_JOIN_THRESHOLD
632            && let Some(ctx) = task_ctx
633        {
634            return arrow_left_anti_dedup(candidates.to_vec(), &self.facts, &self.schema, ctx)
635                .await;
636        }
637        if let Some(ref mut rd) = self.row_dedup {
638            rd.compute_delta(candidates, &self.schema)
639        } else {
640            self.compute_delta_legacy(candidates)
641        }
642    }
643
644    /// Legacy dedup: rebuild a `HashSet<Vec<ScalarKey>>` from all facts each call.
645    ///
646    /// Used as fallback when `RowConverter` does not support the schema's column types.
647    fn compute_delta_legacy(&self, candidates: &[RecordBatch]) -> DFResult<Vec<RecordBatch>> {
648        // Build set of existing fact row keys (ALL columns)
649        let mut existing: HashSet<Vec<ScalarKey>> = HashSet::new();
650        for batch in &self.facts {
651            for row_idx in 0..batch.num_rows() {
652                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
653                existing.insert(key);
654            }
655        }
656
657        let mut delta_batches = Vec::new();
658        for batch in candidates {
659            if batch.num_rows() == 0 {
660                continue;
661            }
662            // Filter to only new rows
663            let mut keep = Vec::with_capacity(batch.num_rows());
664            for row_idx in 0..batch.num_rows() {
665                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
666                keep.push(!existing.contains(&key));
667            }
668
669            // Also dedup within the candidate batch itself
670            for (row_idx, kept) in keep.iter_mut().enumerate() {
671                if *kept {
672                    let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
673                    if !existing.insert(key) {
674                        *kept = false;
675                    }
676                }
677            }
678
679            let keep_mask = arrow_array::BooleanArray::from(keep);
680            let new_rows = batch
681                .columns()
682                .iter()
683                .map(|col| {
684                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
685                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
686                    })
687                })
688                .collect::<DFResult<Vec<_>>>()?;
689
690            if new_rows.first().is_some_and(|c| !c.is_empty()) {
691                let filtered =
692                    RecordBatch::try_new(Arc::clone(&self.schema), new_rows).map_err(|e| {
693                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
694                    })?;
695                delta_batches.push(filtered);
696            }
697        }
698
699        Ok(delta_batches)
700    }
701
702    /// Check if this rule has converged (no new facts and aggs stable).
703    pub fn is_converged(&self) -> bool {
704        let delta_empty = self.delta.is_empty() || self.delta.iter().all(|b| b.num_rows() == 0);
705        let agg_stable = self.monotonic_agg.as_ref().is_none_or(|a| a.is_stable());
706        delta_empty && agg_stable
707    }
708
709    /// Get all accumulated facts.
710    pub fn all_facts(&self) -> &[RecordBatch] {
711        &self.facts
712    }
713
714    /// Get the delta from the latest iteration.
715    pub fn all_delta(&self) -> &[RecordBatch] {
716        &self.delta
717    }
718
719    /// Consume self and return facts.
720    pub fn into_facts(self) -> Vec<RecordBatch> {
721        self.facts
722    }
723
724    /// Merge candidates using BEST BY semantics.
725    ///
726    /// Combines existing facts with new candidates, keeping only the best row
727    /// per KEY group according to `sort_criteria`. Returns `true` if the
728    /// best-per-KEY fact set actually changed (a genuinely better value was
729    /// found or a new KEY appeared).
730    ///
731    /// This replaces `merge_delta` for rules with BEST BY, enabling convergence
732    /// on cyclic graphs where dominated ALONG values would otherwise produce an
733    /// unbounded stream of "new" full-row facts.
734    pub fn merge_best_by(
735        &mut self,
736        candidates: Vec<RecordBatch>,
737        sort_criteria: &[SortCriterion],
738    ) -> DFResult<bool> {
739        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
740            self.delta.clear();
741            return Ok(false);
742        }
743
744        // Reconcile schema from the first non-empty candidate batch.
745        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
746            self.reconcile_schema(&first.schema());
747        }
748
749        // Round floats for stable dedup.
750        let candidates = round_float_columns(&candidates);
751
752        // Snapshot existing best-per-KEY facts for change detection.
753        let old_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> =
754            self.build_key_criteria_map(sort_criteria);
755
756        // Concat existing facts + new candidates.
757        let mut all_batches = self.facts.clone();
758        all_batches.extend(candidates);
759        let all_batches: Vec<_> = all_batches
760            .into_iter()
761            .filter(|b| b.num_rows() > 0)
762            .collect();
763        if all_batches.is_empty() {
764            self.delta.clear();
765            return Ok(false);
766        }
767
768        let combined = arrow::compute::concat_batches(&self.schema, &all_batches)
769            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
770
771        if combined.num_rows() == 0 {
772            self.delta.clear();
773            return Ok(false);
774        }
775
776        // Sort by KEY ASC then criteria, so the best row per KEY group comes
777        // first.
778        let mut sort_columns = Vec::new();
779        for &ki in &self.key_column_indices {
780            if ki >= combined.num_columns() {
781                continue;
782            }
783            sort_columns.push(arrow::compute::SortColumn {
784                values: Arc::clone(combined.column(ki)),
785                options: Some(arrow::compute::SortOptions {
786                    descending: false,
787                    nulls_first: false,
788                }),
789            });
790        }
791        for criterion in sort_criteria {
792            if criterion.col_index >= combined.num_columns() {
793                continue;
794            }
795            sort_columns.push(arrow::compute::SortColumn {
796                values: Arc::clone(combined.column(criterion.col_index)),
797                options: Some(arrow::compute::SortOptions {
798                    descending: !criterion.ascending,
799                    nulls_first: criterion.nulls_first,
800                }),
801            });
802        }
803
804        let sorted_indices =
805            arrow::compute::lexsort_to_indices(&sort_columns, None).map_err(arrow_err)?;
806        let sorted_columns: Vec<_> = combined
807            .columns()
808            .iter()
809            .map(|col| arrow::compute::take(col.as_ref(), &sorted_indices, None))
810            .collect::<Result<Vec<_>, _>>()
811            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
812        let sorted = RecordBatch::try_new(Arc::clone(&self.schema), sorted_columns)
813            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
814
815        // Dedup: keep first (best) row per KEY group.
816        let mut keep_indices: Vec<u32> = Vec::new();
817        let mut prev_key: Option<Vec<ScalarKey>> = None;
818        for row_idx in 0..sorted.num_rows() {
819            let key = extract_scalar_key(&sorted, &self.key_column_indices, row_idx);
820            let is_new_group = match &prev_key {
821                None => true,
822                Some(prev) => *prev != key,
823            };
824            if is_new_group {
825                keep_indices.push(row_idx as u32);
826                prev_key = Some(key);
827            }
828        }
829
830        let keep_array = arrow_array::UInt32Array::from(keep_indices);
831        let output_columns: Vec<_> = sorted
832            .columns()
833            .iter()
834            .map(|col| arrow::compute::take(col.as_ref(), &keep_array, None))
835            .collect::<Result<Vec<_>, _>>()
836            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
837        let pruned = RecordBatch::try_new(Arc::clone(&self.schema), output_columns)
838            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
839
840        // Detect whether the best-per-KEY set actually changed.
841        let new_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> = {
842            let mut map = HashMap::new();
843            for row_idx in 0..pruned.num_rows() {
844                let key = extract_scalar_key(&pruned, &self.key_column_indices, row_idx);
845                let criteria: Vec<ScalarKey> = sort_criteria
846                    .iter()
847                    .flat_map(|c| extract_scalar_key(&pruned, &[c.col_index], row_idx))
848                    .collect();
849                map.insert(key, criteria);
850            }
851            map
852        };
853        let changed = old_best != new_best;
854
855        tracing::debug!(
856            rule = %self.rule_name,
857            old_keys = old_best.len(),
858            new_keys = new_best.len(),
859            changed = changed,
860            "BEST BY merge"
861        );
862
863        // Replace facts with the pruned set.
864        self.facts_bytes = batch_byte_size(&pruned);
865        self.facts = vec![pruned];
866        if changed {
867            // Delta is conceptually the new/improved facts, but since we
868            // replaced the entire set, just mark delta non-empty.
869            self.delta = self.facts.clone();
870        } else {
871            self.delta.clear();
872        }
873
874        // Rebuild row dedup from pruned facts for consistency.
875        self.row_dedup = RowDedupState::try_new(&self.schema);
876        if let Some(ref mut rd) = self.row_dedup {
877            rd.ingest_existing(&self.facts, &self.schema);
878        }
879
880        Ok(changed)
881    }
882
883    /// Build a map from KEY column values to sort criteria values.
884    fn build_key_criteria_map(
885        &self,
886        sort_criteria: &[SortCriterion],
887    ) -> HashMap<Vec<ScalarKey>, Vec<ScalarKey>> {
888        let mut map = HashMap::new();
889        for batch in &self.facts {
890            for row_idx in 0..batch.num_rows() {
891                let key = extract_scalar_key(batch, &self.key_column_indices, row_idx);
892                let criteria: Vec<ScalarKey> = sort_criteria
893                    .iter()
894                    .flat_map(|c| extract_scalar_key(batch, &[c.col_index], row_idx))
895                    .collect();
896                map.insert(key, criteria);
897            }
898        }
899        map
900    }
901}
902
903/// Estimate byte size of a RecordBatch.
904fn batch_byte_size(batch: &RecordBatch) -> usize {
905    batch
906        .columns()
907        .iter()
908        .map(|col| col.get_buffer_memory_size())
909        .sum()
910}
911
912// ---------------------------------------------------------------------------
913// Float rounding for stable dedup
914// ---------------------------------------------------------------------------
915
916/// Round all Float64 columns to 12 decimal places for stable dedup.
917fn round_float_columns(batches: &[RecordBatch]) -> Vec<RecordBatch> {
918    batches
919        .iter()
920        .map(|batch| {
921            let schema = batch.schema();
922            let has_float = schema
923                .fields()
924                .iter()
925                .any(|f| *f.data_type() == arrow_schema::DataType::Float64);
926            if !has_float {
927                return batch.clone();
928            }
929
930            let columns: Vec<arrow_array::ArrayRef> = batch
931                .columns()
932                .iter()
933                .enumerate()
934                .map(|(i, col)| {
935                    if *schema.field(i).data_type() == arrow_schema::DataType::Float64 {
936                        let arr = col
937                            .as_any()
938                            .downcast_ref::<arrow_array::Float64Array>()
939                            .unwrap();
940                        let rounded: arrow_array::Float64Array = arr
941                            .iter()
942                            .map(|v| v.map(|f| (f * 1e12).round() / 1e12))
943                            .collect();
944                        Arc::new(rounded) as arrow_array::ArrayRef
945                    } else {
946                        Arc::clone(col)
947                    }
948                })
949                .collect();
950
951            RecordBatch::try_new(schema, columns).unwrap_or_else(|_| batch.clone())
952        })
953        .collect()
954}
955
956// ---------------------------------------------------------------------------
957// LeftAntiJoin delta deduplication
958// ---------------------------------------------------------------------------
959
/// Row threshold above which the vectorized Arrow LeftAntiJoin dedup path is used.
///
/// Below this threshold the persistent `RowDedupState` HashSet is O(M) and
/// avoids rebuilding the existing-row set; above it DataFusion's vectorized
/// HashJoinExec is more cache-efficient.
///
/// NOTE(review): the comparison against this constant is not part of this
/// section. A comment in `dedup_batches_all_columns` describes the vectorized
/// path as handling fact sets `>=` this value, so the boundary is presumably
/// inclusive — confirm at the comparison site.
const DEDUP_ANTI_JOIN_THRESHOLD: usize = 300;
966
967/// Deduplicate `candidates` against `existing` using DataFusion's HashJoinExec.
968///
969/// Returns rows in `candidates` that do not appear in `existing` (LeftAnti semantics).
970/// `null_equals_null = true` so NULLs are treated as equal for dedup purposes.
971/// Dedup `batches` by all columns (set semantics), keeping the first occurrence.
972///
973/// `arrow_left_anti_dedup` removes candidate rows that match the existing fact
974/// set, but a single semi-naive iteration can emit the same row many times — e.g.
975/// a transitive-closure rule derives the same `(a, b)` pair via every intermediate
976/// `mid` on a path. A `LeftAnti` join does not remove these *within-candidate*
977/// duplicates, so they would leak into the fact set. The `RowDedupState` and legacy
978/// paths both dedup within the candidate batch ([`RowDedupState::compute_delta`],
979/// [`FixpointState::compute_delta_legacy`]); this keeps the `arrow_left_anti_dedup`
980/// path identical so dedup behavior does not change across `DEDUP_ANTI_JOIN_THRESHOLD`.
981fn dedup_batches_all_columns(
982    batches: Vec<RecordBatch>,
983    schema: &SchemaRef,
984) -> DFResult<Vec<RecordBatch>> {
985    let fields: Vec<SortField> = schema
986        .fields()
987        .iter()
988        .map(|f| SortField::new(f.data_type().clone()))
989        .collect();
990    // Unsupported column types: leave as-is. This path is only reached for facts
991    // sets >= DEDUP_ANTI_JOIN_THRESHOLD; for those types the <threshold path uses
992    // `compute_delta_legacy`, which dedups via `ScalarKey` instead.
993    let Ok(converter) = RowConverter::new(fields) else {
994        return Ok(batches);
995    };
996    let mut seen: HashSet<Box<[u8]>> = HashSet::new();
997    let mut out = Vec::with_capacity(batches.len());
998    for batch in batches {
999        if batch.num_rows() == 0 {
1000            continue;
1001        }
1002        let rows = converter
1003            .convert_columns(batch.columns())
1004            .map_err(arrow_err)?;
1005        let mut keep = Vec::with_capacity(batch.num_rows());
1006        for row_idx in 0..batch.num_rows() {
1007            let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
1008            keep.push(seen.insert(row_bytes));
1009        }
1010        let keep_mask = arrow_array::BooleanArray::from(keep);
1011        let cols = batch
1012            .columns()
1013            .iter()
1014            .map(|c| arrow::compute::filter(c.as_ref(), &keep_mask).map_err(arrow_err))
1015            .collect::<DFResult<Vec<_>>>()?;
1016        if cols.first().is_some_and(|c| !c.is_empty()) {
1017            out.push(RecordBatch::try_new(Arc::clone(schema), cols).map_err(arrow_err)?);
1018        }
1019    }
1020    Ok(out)
1021}
1022
1023async fn arrow_left_anti_dedup(
1024    candidates: Vec<RecordBatch>,
1025    existing: &[RecordBatch],
1026    schema: &SchemaRef,
1027    task_ctx: &Arc<TaskContext>,
1028) -> DFResult<Vec<RecordBatch>> {
1029    if existing.is_empty() || existing.iter().all(|b| b.num_rows() == 0) {
1030        // No existing facts to anti-join against, but still dedup the candidates
1031        // among themselves (a single iteration may emit duplicate rows).
1032        return dedup_batches_all_columns(candidates, schema);
1033    }
1034
1035    let left: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(candidates, Arc::clone(schema)));
1036    let right: Arc<dyn ExecutionPlan> =
1037        Arc::new(InMemoryExec::new(existing.to_vec(), Arc::clone(schema)));
1038
1039    let on: Vec<(
1040        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
1041        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
1042    )> = schema
1043        .fields()
1044        .iter()
1045        .enumerate()
1046        .map(|(i, field)| {
1047            let l: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
1048                datafusion::physical_plan::expressions::Column::new(field.name(), i),
1049            );
1050            let r: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
1051                datafusion::physical_plan::expressions::Column::new(field.name(), i),
1052            );
1053            (l, r)
1054        })
1055        .collect();
1056
1057    if on.is_empty() {
1058        return Ok(vec![]);
1059    }
1060
1061    let join = HashJoinExec::try_new(
1062        left,
1063        right,
1064        on,
1065        None,
1066        &JoinType::LeftAnti,
1067        None,
1068        PartitionMode::CollectLeft,
1069        datafusion::common::NullEquality::NullEqualsNull,
1070        // null_aware = false: this is a set-difference dedup (NOT EXISTS), not a
1071        // SQL NOT-IN. `NullEqualsNull` already makes NULL keys dedup against NULL
1072        // keys; enabling null-aware semantics would wrongly annihilate all rows
1073        // whenever the existing-fact side contains a NULL key.
1074        false,
1075    )?;
1076
1077    let join_arc: Arc<dyn ExecutionPlan> = Arc::new(join);
1078    // LeftAnti removes candidates that match `existing`, but not duplicate rows
1079    // within the candidate set — dedup those to match the other delta strategies.
1080    let anti = collect_all_partitions(&join_arc, task_ctx.clone()).await?;
1081    dedup_batches_all_columns(anti, schema)
1082}
1083
1084// ---------------------------------------------------------------------------
1085// Plan types for fixpoint rules
1086// ---------------------------------------------------------------------------
1087
/// IS-ref binding: a reference from a clause body to a derived relation.
///
/// The fixpoint loop consults these bindings after a clause body has been
/// executed: a binding that is `negated`, has non-empty `anti_join_cols`, and
/// resolves to a registry entry with facts is applied either as a
/// probabilistic complement or as an anti-join.
#[derive(Debug, Clone)]
pub struct IsRefBinding {
    /// Index into the DerivedScanRegistry.
    pub derived_scan_index: usize,
    /// Name of the rule being referenced.
    pub rule_name: String,
    /// Whether this is a self-reference (rule references itself).
    pub is_self_ref: bool,
    /// Whether this is a negated reference (NOT IS).
    pub negated: bool,
    /// For negated IS-refs: `(left_body_col, right_derived_col)` pairs for anti-join filtering.
    ///
    /// `left_body_col` is the VID column in the clause body (e.g., `"n._vid"`);
    /// `right_derived_col` is the corresponding KEY column in the negated rule's facts (e.g., `"n"`).
    /// Empty for non-negated IS-refs.
    pub anti_join_cols: Vec<(String, String)>,
    /// Whether the target rule has a PROB column.
    ///
    /// For a negated binding, the fixpoint loop applies a probabilistic
    /// complement only when this is set and the referencing rule has a PROB
    /// column too; otherwise it falls back to a boolean anti-join.
    pub target_has_prob: bool,
    /// Name of the PROB column in the target rule, if any.
    ///
    /// If `target_has_prob` is set but this is `None`, the fixpoint loop
    /// falls back to the anti-join.
    pub target_prob_col: Option<String>,
    /// `(body_col, derived_col)` pairs for provenance tracking.
    ///
    /// Used by shared-proof detection to find which source facts a derived row
    /// consumed. Populated for all IS-refs (not just negated ones).
    pub provenance_join_cols: Vec<(String, String)>,
}
1115
/// A single clause (body) within a fixpoint rule.
///
/// Each fixpoint iteration executes `body_logical` for every clause of a rule
/// and pools the resulting batches as that rule's candidate rows.
#[derive(Debug)]
pub struct FixpointClausePlan {
    /// The logical plan for the clause body.
    pub body_logical: LogicalPlan,
    /// IS-ref bindings used by this clause.
    pub is_ref_bindings: Vec<IsRefBinding>,
    /// Priority value for this clause (if PRIORITY semantics apply).
    pub priority: Option<i64>,
    /// ALONG binding variable names propagated from the planner.
    pub along_bindings: Vec<String>,
    /// Phase B Slice 3: neural-model invocations lifted out of YIELD
    /// items by the compiler. Each entry is evaluated per row after the
    /// clause body produces batches and before IS-ref handling.
    ///
    /// NOTE(review): a comment in `run_fixpoint_loop` states that the planner
    /// now inserts `LogicalPlan::LocyModelInvoke` into the body plan so the
    /// invocation runs inline during `execute_subplan` — confirm whether the
    /// "after the clause body" wording above is still accurate.
    pub model_invocations: Vec<ModelInvocation>,
}
1132
/// Physical plan for a single rule in a fixpoint stratum.
///
/// One per-rule fixpoint state is created from each plan (name, yield schema,
/// and KEY column indices), and the rule's clauses are re-evaluated on every
/// iteration of the fixpoint loop.
#[derive(Debug)]
pub struct FixpointRulePlan {
    /// Rule name.
    pub name: String,
    /// Clause bodies (each evaluates to candidate rows).
    pub clauses: Vec<FixpointClausePlan>,
    /// Output schema for this rule's derived relation.
    pub yield_schema: SchemaRef,
    /// Indices of KEY columns within yield_schema.
    pub key_column_indices: Vec<usize>,
    /// Priority value (if PRIORITY semantics apply).
    pub priority: Option<i64>,
    /// Whether this rule has FOLD semantics.
    pub has_fold: bool,
    /// FOLD bindings for post-fixpoint aggregation.
    ///
    /// When non-empty, the fixpoint loop also builds a monotonic aggregate
    /// state from these bindings for the rule.
    pub fold_bindings: Vec<FoldBinding>,
    /// Post-FOLD filter expressions (HAVING semantics).
    pub having: Vec<Expr>,
    /// Whether this rule has BEST BY semantics.
    ///
    /// The fixpoint loop uses the BEST BY merge only when this is set *and*
    /// `best_by_criteria` is non-empty; otherwise the regular delta merge runs.
    pub has_best_by: bool,
    /// BEST BY sort criteria for post-fixpoint selection.
    pub best_by_criteria: Vec<SortCriterion>,
    /// Whether this rule has PRIORITY semantics.
    pub has_priority: bool,
    /// Whether BEST BY should apply a deterministic secondary sort for
    /// tie-breaking. When false, tied rows are selected non-deterministically
    /// (faster but not repeatable across runs).
    pub deterministic: bool,
    /// Name of the PROB column in this rule's yield schema, if any.
    pub prob_column_name: Option<String>,
    /// True when any clause of this rule has ≥2 positive same-stratum
    /// IS-refs (non-linear recursion, e.g. `tc(a,b) :- tc(a,m), tc(m,b)`).
    /// Such rules get FULL facts (naive evaluation) instead of the latest
    /// delta on their self-ref scans: a delta-only join computes Δ×Δ and
    /// misses the Δ×F_old combinations, silently under-deriving. Covers
    /// both `p :- p, p` and `p :- p, q` (q in the same SCC) shapes.
    pub non_linear: bool,
}
1172
1173// ---------------------------------------------------------------------------
1174// run_fixpoint_loop — the core semi-naive iteration algorithm
1175// ---------------------------------------------------------------------------
1176
1177/// Run the semi-naive fixpoint iteration loop.
1178///
1179/// Evaluates all rules in a stratum repeatedly, feeding deltas back through
1180/// derived scan handles until convergence or limits are reached.
1181#[expect(clippy::too_many_arguments, reason = "Fixpoint loop needs all context")]
1182async fn run_fixpoint_loop(
1183    rules: Vec<FixpointRulePlan>,
1184    max_iterations: usize,
1185    timeout: Duration,
1186    graph_ctx: Arc<GraphExecutionContext>,
1187    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
1188    storage: Arc<StorageManager>,
1189    schema_info: Arc<UniSchema>,
1190    params: HashMap<String, Value>,
1191    registry: Arc<DerivedScanRegistry>,
1192    output_schema: SchemaRef,
1193    max_derived_bytes: usize,
1194    derivation_tracker: Option<Arc<ProvenanceStore>>,
1195    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
1196    strict_probability_domain: bool,
1197    probability_epsilon: f64,
1198    exact_probability: bool,
1199    max_bdd_variables: usize,
1200    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
1201    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
1202    top_k_proofs: usize,
1203    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
1204    semiring_kind: SemiringKind,
1205    classifier_registry: Arc<ClassifierRegistry>,
1206    classifier_cache: Option<Arc<ModelInvocationCache>>,
1207    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
1208) -> DFResult<Vec<RecordBatch>> {
1209    let start = Instant::now();
1210    let task_ctx = session_ctx.read().task_ctx();
1211
1212    // IMPORTANT: per rollout D-9 the FuzzyNotProbabilistic warning emitted
1213    // below is unsuppressible — do not gate on any suppression mechanism.
1214    // Fuzzy truth values are not probabilities; silent conflation is the
1215    // dominant pitfall in neuro-symbolic systems (LTN, NTP).
1216    if semiring_kind == SemiringKind::MaxMinProb {
1217        let mut warnings = warnings_slot.write().unwrap_or_else(|e| e.into_inner());
1218        let mut already_warned: HashSet<String> = warnings
1219            .iter()
1220            .filter(|w| w.code == RuntimeWarningCode::FuzzyNotProbabilistic)
1221            .map(|w| w.rule_name.clone())
1222            .collect();
1223        for rule in &rules {
1224            if rule.prob_column_name.is_some() && !already_warned.contains(&rule.name) {
1225                warnings.push(RuntimeWarning {
1226                    code: RuntimeWarningCode::FuzzyNotProbabilistic,
1227                    message: format!(
1228                        "rule '{}' carries a PROB column but is being evaluated under \
1229                         the MaxMinProb (fuzzy / Viterbi) semiring; outputs are fuzzy \
1230                         truth values, not probabilities",
1231                        rule.name
1232                    ),
1233                    rule_name: rule.name.clone(),
1234                    variable_count: None,
1235                    key_group: None,
1236                });
1237                already_warned.insert(rule.name.clone());
1238            }
1239        }
1240    }
1241
1242    // Initialize per-rule state
1243    let mut states: Vec<FixpointState> = rules
1244        .iter()
1245        .map(|rule| {
1246            let monotonic_agg = if !rule.fold_bindings.is_empty() {
1247                let bindings: Vec<MonotonicFoldBinding> = rule
1248                    .fold_bindings
1249                    .iter()
1250                    .map(|fb| MonotonicFoldBinding {
1251                        fold_name: fb.output_name.clone(),
1252                        aggregate: std::sync::Arc::clone(&fb.aggregate),
1253                        input_col_index: fb.input_col_index,
1254                        input_col_name: fb.input_col_name.clone(),
1255                    })
1256                    .collect();
1257                Some(MonotonicAggState::new(bindings))
1258            } else {
1259                None
1260            };
1261            FixpointState::new_with_semiring(
1262                rule.name.clone(),
1263                Arc::clone(&rule.yield_schema),
1264                rule.key_column_indices.clone(),
1265                max_derived_bytes,
1266                monotonic_agg,
1267                strict_probability_domain,
1268                semiring_kind,
1269            )
1270        })
1271        .collect();
1272
1273    // Main iteration loop
1274    let mut converged = false;
1275    let mut total_iters = 0usize;
1276    for iteration in 0..max_iterations {
1277        total_iters = iteration + 1;
1278        tracing::debug!("fixpoint iteration {}", iteration);
1279        let mut any_changed = false;
1280
1281        for rule_idx in 0..rules.len() {
1282            let rule = &rules[rule_idx];
1283
1284            // Update derived scan handles for this rule's clauses
1285            update_derived_scan_handles(&registry, &states, rule_idx, &rules);
1286
1287            // Evaluate clause bodies, tracking per-clause candidates for provenance.
1288            let mut all_candidates = Vec::new();
1289            let mut clause_candidates: Vec<Vec<RecordBatch>> = Vec::new();
1290            for clause in &rule.clauses {
1291                // Phase B A4 follow-up: the planner inserts
1292                // `LogicalPlan::LocyModelInvoke` between the body and
1293                // `LocyProject` when this clause has neural-model
1294                // invocations, so `execute_subplan` runs the invocation
1295                // inline as part of the body plan tree.
1296                let mut batches = execute_subplan(
1297                    &clause.body_logical,
1298                    &params,
1299                    &HashMap::new(),
1300                    &graph_ctx,
1301                    &session_ctx,
1302                    &storage,
1303                    &schema_info,
1304                    None, // Locy fixpoint clause body is read-only
1305                )
1306                .await?;
1307                // Apply negated IS-ref semantics: probabilistic complement or anti-join.
1308                for binding in &clause.is_ref_bindings {
1309                    if binding.negated
1310                        && !binding.anti_join_cols.is_empty()
1311                        && let Some(entry) = registry.get(binding.derived_scan_index)
1312                    {
1313                        let neg_facts = entry.data.read().clone();
1314                        if !neg_facts.is_empty() {
1315                            if binding.target_has_prob && rule.prob_column_name.is_some() {
1316                                // Probabilistic complement: add 1-p column instead of filtering.
1317                                let complement_col =
1318                                    format!("__prob_complement_{}", binding.rule_name);
1319                                if let Some(prob_col) = &binding.target_prob_col {
1320                                    batches = apply_prob_complement_composite(
1321                                        batches,
1322                                        &neg_facts,
1323                                        &binding.anti_join_cols,
1324                                        prob_col,
1325                                        &complement_col,
1326                                    )?;
1327                                } else {
1328                                    // target_has_prob but no prob_col: fall back to anti-join.
1329                                    batches = apply_anti_join_composite(
1330                                        batches,
1331                                        &neg_facts,
1332                                        &binding.anti_join_cols,
1333                                    )?;
1334                                }
1335                            } else {
1336                                // Boolean exclusion: anti-join (existing behavior)
1337                                batches = apply_anti_join_composite(
1338                                    batches,
1339                                    &neg_facts,
1340                                    &binding.anti_join_cols,
1341                                )?;
1342                            }
1343                        }
1344                    }
1345                }
1346                // Multiply complement columns into the PROB column (if any) and clean up
1347                let complement_cols: Vec<String> = if !batches.is_empty() {
1348                    batches[0]
1349                        .schema()
1350                        .fields()
1351                        .iter()
1352                        .filter(|f| f.name().starts_with("__prob_complement_"))
1353                        .map(|f| f.name().clone())
1354                        .collect()
1355                } else {
1356                    vec![]
1357                };
1358                if !complement_cols.is_empty() {
1359                    batches = multiply_prob_factors(
1360                        batches,
1361                        rule.prob_column_name.as_deref(),
1362                        &complement_cols,
1363                    )?;
1364                }
1365
1366                clause_candidates.push(batches.clone());
1367                all_candidates.extend(batches);
1368            }
1369
1370            // Merge candidates into facts.
1371            // For BEST BY rules, use a specialized merge that keeps only the
1372            // best row per KEY group, enabling convergence on cyclic graphs.
1373            let changed = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
1374                states[rule_idx].merge_best_by(all_candidates, &rule.best_by_criteria)?
1375            } else {
1376                states[rule_idx]
1377                    .merge_delta(all_candidates, Some(Arc::clone(&task_ctx)))
1378                    .await?
1379            };
1380            if changed {
1381                any_changed = true;
1382                // Record provenance for newly derived facts when tracker is present.
1383                if let Some(ref tracker) = derivation_tracker {
1384                    record_provenance(
1385                        ProvenanceCtx {
1386                            tracker,
1387                            registry: &registry,
1388                            warnings_slot: &warnings_slot,
1389                        },
1390                        rule,
1391                        &states[rule_idx],
1392                        &clause_candidates,
1393                        iteration,
1394                        top_k_proofs,
1395                        ClassifierRefs {
1396                            registry: &classifier_registry,
1397                            cache: classifier_cache.as_ref(),
1398                            provenance_store: classifier_provenance_store.as_ref(),
1399                        },
1400                    )
1401                    .await;
1402                }
1403            }
1404        }
1405
1406        // Check convergence
1407        if !any_changed && states.iter().all(|s| s.is_converged()) {
1408            tracing::debug!("fixpoint converged after {} iterations", iteration + 1);
1409            converged = true;
1410            break;
1411        }
1412
1413        // Check timeout
1414        if start.elapsed() > timeout {
1415            tracing::warn!(
1416                "fixpoint timeout after {} iterations; returning partial results",
1417                iteration + 1,
1418            );
1419            interruption::set(&timeout_flag, interruption::TIMEOUT);
1420            break;
1421        }
1422    }
1423
1424    // Write per-rule iteration counts to the shared slot.
1425    if let Ok(mut counts) = iteration_counts.write() {
1426        for rule in &rules {
1427            counts.insert(rule.name.clone(), total_iters);
1428        }
1429    }
1430
1431    // If we exhausted all iterations without converging, record the iteration
1432    // limit (distinct from a wall-clock timeout) and proceed with partial
1433    // results rather than discarding all work. `set` is first-wins, so a
1434    // wall-clock timeout recorded above is not overwritten here.
1435    if !converged && interruption::reason(&timeout_flag).is_none() {
1436        tracing::warn!(
1437            "fixpoint did not converge after {max_iterations} iterations; returning partial results",
1438        );
1439        interruption::set(&timeout_flag, interruption::ITERATION_LIMIT);
1440    }
1441
1442    // Post-fixpoint processing per rule and collect output
1443    let task_ctx = session_ctx.read().task_ctx();
1444    let mut all_output = Vec::new();
1445
1446    for (rule_idx, state) in states.into_iter().enumerate() {
1447        let rule = &rules[rule_idx];
1448        let mut facts = state.into_facts();
1449        if facts.is_empty() {
1450            continue;
1451        }
1452
1453        // Detect shared proofs before FOLD collapses groups.
1454        //
1455        // TODO(C0-stage2): swap `detect_shared_lineage` for `TopKTag`
1456        // DNF inspection when `semiring_kind == TopKProofs { k }`.
1457        // The library-layer tag math has landed in
1458        // `crates/uni-locy/src/top_k_proofs.rs` (Phase C C0 Stage 1);
1459        // Stage 2 plumbs `TopKTag` through `MonotonicAggState` /
1460        // `FoldExec` so per-row dependency DNFs are available here.
1461        // Until Stage 2, this scalar `ProvenanceStore` path runs for
1462        // every semiring including `TopKProofs` (per rollout D-4
1463        // "graceful migration").
1464        //
1465        // Phase-3 shared-proof detection is meaningful only under
1466        // `AddMultProb` (and `BddExact`, which is the AddMultProb math
1467        // plus a WMC post-correction). Under `MaxMinProb`, `plus = max`
1468        // is idempotent — shared proofs don't double-count — so the
1469        // warning is moot and we skip the work.
1470        let shared_info = if semiring_kind == SemiringKind::MaxMinProb {
1471            None
1472        } else if let Some(ref tracker) = derivation_tracker {
1473            detect_shared_lineage(rule, &facts, tracker, &warnings_slot, semiring_kind)
1474        } else {
1475            None
1476        };
1477
1478        // Apply BDD for shared groups if exact_probability is enabled.
1479        if exact_probability
1480            && let Some(ref info) = shared_info
1481            && let Some(ref tracker) = derivation_tracker
1482        {
1483            facts = apply_exact_wmc(
1484                facts,
1485                rule,
1486                info,
1487                tracker,
1488                max_bdd_variables,
1489                &warnings_slot,
1490                &approximate_slot,
1491            )?;
1492        }
1493
1494        let processed = apply_post_fixpoint_chain(
1495            facts,
1496            rule,
1497            &task_ctx,
1498            strict_probability_domain,
1499            probability_epsilon,
1500            semiring_kind,
1501            derivation_tracker.as_ref().map(Arc::clone),
1502            top_k_proofs,
1503            Some(Arc::clone(&registry)),
1504        )
1505        .await?;
1506        all_output.extend(processed);
1507    }
1508
1509    // If no output, return empty batch with output schema
1510    if all_output.is_empty() {
1511        all_output.push(RecordBatch::new_empty(output_schema));
1512    }
1513
1514    Ok(all_output)
1515}
1516
1517// ---------------------------------------------------------------------------
1518// Provenance recording helpers
1519// ---------------------------------------------------------------------------
1520
// NOTE: the paragraph below documents `record_provenance` (defined further
// down in this file), not this struct. It is kept as a plain comment so that
// rustdoc does not attribute it to `ClassifierRefs`.
//
// Record provenance for all newly derived facts (rows in the current delta).
//
// Called after `merge_delta` returns `true`. Attributes each new fact to the
// clause most likely to have produced it, using first-derivation-wins semantics.

/// Borrowed bundle of classifier-side runtime state used by
/// provenance / EXPLAIN-reconstruction code paths. Keeps function
/// signatures under the too-many-arguments threshold.
pub(crate) struct ClassifierRefs<'a> {
    /// Classifier registry; classifiers are looked up in it by model name.
    pub registry: &'a Arc<ClassifierRegistry>,
    /// Optional memoization cache of classifier outputs, addressed by
    /// model name and input hash. `None` disables caching.
    pub cache: Option<&'a Arc<uni_locy::ModelInvocationCache>>,
    /// Phase C B1-B3 follow-up: when `Some`, EXPLAIN's neural_calls
    /// collection consults the side-channel provenance store first
    /// (populated by `apply_model_invocations`). This is the only way
    /// to surface NeuralProvenance for Python-registered classifiers,
    /// whose model_invocations may be rewritten away by the planner
    /// and so wouldn't trigger the re-invocation fallback.
    pub provenance_store: Option<&'a Arc<uni_locy::NeuralProvenanceStore>>,
}
1539
/// Borrowed bundle of provenance-recording state: the in-flight
/// tracker, the derived-scan registry (used to resolve IS-ref inputs),
/// and the shared warnings slot. Bundled to keep
/// `record_provenance` / `record_and_detect_lineage_nonrecursive`
/// under the too-many-arguments threshold.
pub(crate) struct ProvenanceCtx<'a> {
    /// Store that receives the per-fact provenance annotations.
    pub tracker: &'a Arc<ProvenanceStore>,
    /// Registry of derived-scan entries, indexed by an IS-ref binding's
    /// `derived_scan_index` when resolving its source rows.
    pub registry: &'a Arc<DerivedScanRegistry>,
    /// Shared, lock-protected list that runtime warnings are appended to.
    pub warnings_slot: &'a Arc<StdRwLock<Vec<RuntimeWarning>>>,
}
1550
1551async fn record_provenance(
1552    prov: ProvenanceCtx<'_>,
1553    rule: &FixpointRulePlan,
1554    state: &FixpointState,
1555    clause_candidates: &[Vec<RecordBatch>],
1556    iteration: usize,
1557    top_k_proofs: usize,
1558    classifiers: ClassifierRefs<'_>,
1559) {
1560    let tracker = prov.tracker;
1561    let registry = prov.registry;
1562    let warnings_slot = prov.warnings_slot;
1563    let classifier_registry = classifiers.registry;
1564    let classifier_cache = classifiers.cache;
1565    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
1566
1567    // Pre-compute base fact probabilities for top-k mode.
1568    let base_probs = if top_k_proofs > 0 {
1569        tracker.base_fact_probs()
1570    } else {
1571        HashMap::new()
1572    };
1573
1574    let mut topk_acc = TopKProofAccumulator::new();
1575
1576    for delta_batch in state.all_delta() {
1577        for row_idx in 0..delta_batch.num_rows() {
1578            let row_hash = format!(
1579                "{:?}",
1580                extract_scalar_key(delta_batch, &all_indices, row_idx)
1581            )
1582            .into_bytes();
1583            let fact_row = batch_row_to_value_map(delta_batch, row_idx);
1584            let clause_index =
1585                find_clause_for_row(delta_batch, row_idx, &all_indices, clause_candidates);
1586
1587            let support = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
1588
1589            let proof_probability = if top_k_proofs > 0 {
1590                compute_proof_probability(&support, &base_probs)
1591            } else {
1592                None
1593            };
1594
1595            let entry = ProvenanceAnnotation {
1596                rule_name: rule.name.clone(),
1597                clause_index,
1598                support,
1599                along_values: {
1600                    let along_names: Vec<String> = rule
1601                        .clauses
1602                        .get(clause_index)
1603                        .map(|c| c.along_bindings.clone())
1604                        .unwrap_or_default();
1605                    along_names
1606                        .iter()
1607                        .filter_map(|name| fact_row.get(name).map(|v| (name.clone(), v.clone())))
1608                        .collect()
1609                },
1610                iteration,
1611                fact_row: fact_row.clone(),
1612                proof_probability,
1613                neural_calls: collect_neural_calls_for_row(
1614                    rule,
1615                    clause_index,
1616                    &fact_row,
1617                    classifier_registry,
1618                    classifier_cache,
1619                    classifiers.provenance_store,
1620                )
1621                .await,
1622            };
1623            if top_k_proofs > 0 {
1624                topk_acc.accumulate(&entry, &row_hash);
1625                tracker.record_top_k(row_hash, entry, top_k_proofs);
1626            } else {
1627                tracker.record(row_hash, entry);
1628            }
1629        }
1630    }
1631
1632    topk_acc.emit_warning_if_any(rule, top_k_proofs, warnings_slot);
1633}
1634
/// Phase C C0 Stage 2: collects per-row `Proof` tags during the
/// fixpoint row walk, then surfaces `TopKPruningCrossedDependency`
/// when post-walk top-K merging would drop a proof whose base RVs
/// overlap a retained one. The shared `BaseRv` interner is what
/// makes the overlap detectable — proofs grounded in the same
/// `base_fact_id` get the same `BaseRv`.
struct TopKProofAccumulator {
    /// Proofs collected so far, grouped by the fact's row-hash bytes.
    per_fact: HashMap<Vec<u8>, Vec<uni_locy::Proof>>,
    /// Interner mapping a support term's `base_fact_id` bytes to its `BaseRv`.
    base_rv_interner: HashMap<Vec<u8>, uni_locy::BaseRv>,
    /// Numeric id handed to the next `base_fact_id` not yet interned.
    next_rv: u32,
}
1646
1647impl TopKProofAccumulator {
1648    fn new() -> Self {
1649        Self {
1650            per_fact: HashMap::new(),
1651            base_rv_interner: HashMap::new(),
1652            next_rv: 0,
1653        }
1654    }
1655
1656    fn accumulate(&mut self, entry: &ProvenanceAnnotation, row_hash: &[u8]) {
1657        let mut base_rvs = uni_locy::BaseRvSet::empty();
1658        for term in &entry.support {
1659            let rv = *self
1660                .base_rv_interner
1661                .entry(term.base_fact_id.clone())
1662                .or_insert_with(|| {
1663                    let r = uni_locy::BaseRv(self.next_rv);
1664                    self.next_rv += 1;
1665                    r
1666                });
1667            base_rvs.insert(rv);
1668        }
1669        self.per_fact
1670            .entry(row_hash.to_vec())
1671            .or_default()
1672            .push(uni_locy::Proof {
1673                weight: entry.proof_probability.unwrap_or(0.0),
1674                base_rvs,
1675                neural_calls: Vec::new(),
1676            });
1677    }
1678
1679    fn emit_warning_if_any(
1680        &self,
1681        rule: &FixpointRulePlan,
1682        top_k_proofs: usize,
1683        warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
1684    ) {
1685        if top_k_proofs == 0 || self.per_fact.is_empty() {
1686            return;
1687        }
1688        let crossed_facts = self
1689            .per_fact
1690            .values()
1691            .filter(|proofs| {
1692                let (_kept, notice) =
1693                    uni_locy::merge_top_k_runtime(Vec::new(), (*proofs).clone(), top_k_proofs);
1694                notice == uni_locy::PruneNotice::CrossedDependency
1695            })
1696            .count();
1697        if crossed_facts == 0 {
1698            return;
1699        }
1700        let Ok(mut w) = warnings_slot.write() else {
1701            return;
1702        };
1703        let already = w.iter().any(|rw| {
1704            matches!(
1705                rw.code,
1706                uni_locy::types::RuntimeWarningCode::TopKPruningCrossedDependency
1707            ) && rw.rule_name == rule.name
1708        });
1709        if already {
1710            return;
1711        }
1712        w.push(RuntimeWarning {
1713            code: uni_locy::types::RuntimeWarningCode::TopKPruningCrossedDependency,
1714            rule_name: rule.name.clone(),
1715            message: format!(
1716                "rule '{}': top-K proof pruning (k={}) discarded {} fact(s) \
1717                 whose dependencies overlap retained proofs. The retained \
1718                 top-{} under-counts the true joint probability for those \
1719                 facts (Scallop, Huang et al. 2021). Increase k to recover.",
1720                rule.name, top_k_proofs, crossed_facts, top_k_proofs
1721            ),
1722            variable_count: None,
1723            key_group: None,
1724        });
1725    }
1726}
1727
1728/// Collect IS-ref input facts for a derived row using provenance join columns.
1729///
1730/// For each non-negated IS-ref binding in the clause, extracts body-side key
1731/// values from the delta row and finds matching source rows in the registry.
1732/// Returns a `ProofTerm` for each match (with the source fact hash).
1733/// Phase C B1–B3: build [`uni_locy::NeuralProvenance`] entries for
1734/// the model invocations on this clause by reading each
1735/// invocation's output column from the post-LocyModelInvoke row.
1736/// `raw_probability` is the classifier's direct output;
1737/// `calibrated_probability` and `confidence_band` come from the
1738/// active Calibrator (when the classifier wraps one).
1739///
1740/// Phase C B1-B3 follow-up rewrite: re-evaluate the classifier
1741/// per fact using the ORIGINAL pre-rewrite feature expressions
1742/// (stored as `invocation.original_feature_exprs`). This works
1743/// for invocations in YIELD, ALONG, and FOLD positions uniformly
1744/// — the original args carry the input bindings regardless of
1745/// where the synthetic `__model_<n>` column ends up in the plan
1746/// tree. Memoization via `ModelInvocationCache` (already threaded)
1747/// absorbs repeat costs; EXPLAIN typically operates on small
1748/// derivation trees so the per-fact classifier call is bounded.
1749async fn collect_neural_calls_for_row(
1750    rule: &FixpointRulePlan,
1751    clause_index: usize,
1752    fact_row: &uni_locy::FactRow,
1753    classifier_registry: &Arc<ClassifierRegistry>,
1754    classifier_cache: Option<&Arc<uni_locy::ModelInvocationCache>>,
1755    provenance_store: Option<&Arc<uni_locy::NeuralProvenanceStore>>,
1756) -> Vec<uni_locy::NeuralProvenance> {
1757    let Some(clause) = rule.clauses.get(clause_index) else {
1758        return Vec::new();
1759    };
1760    if clause.model_invocations.is_empty() {
1761        return Vec::new();
1762    }
1763    let mut out = Vec::with_capacity(clause.model_invocations.len());
1764    for invocation in &clause.model_invocations {
1765        // Build ClassifyInput from the REWRITTEN feature expressions
1766        // (referencing the synthetic `__feat_*` hidden columns that the
1767        // planner lifts model-call args into). This matches the writer's
1768        // path in `apply_model_invocations`, which iterates
1769        // `invocation.feature_exprs` to compute the same `input_hash`
1770        // that gets stored. Using the pre-rewrite `original_feature_exprs`
1771        // here would compute a different hash for YIELD-position
1772        // invocations (where the pre-rewrite expr references properties
1773        // not materialised into fact_row), causing the store lookup
1774        // below to miss and `neural_calls` to come back empty.
1775        let mut features = std::collections::HashMap::new();
1776        for (binding_name, feat_expr) in invocation
1777            .feature_names
1778            .iter()
1779            .zip(invocation.feature_exprs.iter())
1780        {
1781            features.insert(
1782                binding_name.clone(),
1783                eval_feature_expr_against_fact_row(feat_expr, fact_row),
1784            );
1785        }
1786        let input = uni_locy::ClassifyInput { features };
1787        let input_hash = input.stable_hash();
1788
1789        // Store-first read path. `apply_model_invocations` writes a
1790        // NeuralProvenanceRecord per (model, input_hash) into the
1791        // side-channel store during fixpoint. If we find a record
1792        // there, surface it directly — this is the only path that
1793        // populates calibrated_probability + confidence_band for
1794        // Python-registered classifiers.
1795        if let Some(store) = provenance_store
1796            && let Some(record) = store.get(&invocation.model_name, input_hash)
1797        {
1798            out.push(uni_locy::NeuralProvenance {
1799                model_name: invocation.model_name.clone(),
1800                raw_probability: record.raw_probability,
1801                calibrated_probability: record.calibrated_probability,
1802                confidence_band: record.confidence_band,
1803            });
1804            continue;
1805        }
1806
1807        // Fallback: re-invoke the classifier. Only reached when the
1808        // store wasn't populated for this (model, input) — e.g. older
1809        // sessions where the store wasn't threaded, or when EXPLAIN
1810        // runs against a row that fixpoint never touched.
1811        let Some(classifier) = classifier_registry.get(&invocation.model_name) else {
1812            continue;
1813        };
1814        let raw = if let Some(v) =
1815            classifier_cache.and_then(|c| c.get(&invocation.model_name, input_hash))
1816        {
1817            v
1818        } else {
1819            match classifier.classify(std::slice::from_ref(&input)).await {
1820                Ok(probs) => {
1821                    let v = probs.first().copied().unwrap_or(0.0);
1822                    if let Some(c) = classifier_cache {
1823                        c.insert(&invocation.model_name, input_hash, v);
1824                    }
1825                    v
1826                }
1827                Err(_) => continue,
1828            }
1829        };
1830        let calibrator = classifier.get_calibrator();
1831        let calibrated_probability = calibrator.as_ref().map(|_| raw);
1832        let confidence_band = calibrator.as_ref().and_then(|c| c.confidence_band(raw));
1833        out.push(uni_locy::NeuralProvenance {
1834            model_name: invocation.model_name.clone(),
1835            raw_probability: raw,
1836            calibrated_probability,
1837            confidence_band,
1838        });
1839    }
1840    out
1841}
1842
1843/// Phase C B1-B3 follow-up: evaluate a model's pre-rewrite feature
1844/// expression against a fact_row, producing a `FeatureValue` for
1845/// classifier input reconstruction. Mirrors the compile-time
1846/// acceptance set in `validate_features`:
1847/// - `Variable(name)` → fact_row[name], coerced.
1848/// - `Property(Variable(v), prop)` → fact_row["v.prop"] (the
1849///   materialized property column).
1850fn eval_feature_expr_against_fact_row(
1851    expr: &uni_cypher::ast::Expr,
1852    fact_row: &uni_locy::FactRow,
1853) -> uni_locy::FeatureValue {
1854    use uni_cypher::ast::Expr;
1855    use uni_locy::FeatureValue;
1856    let value_to_feature = |v: Option<&uni_common::Value>| -> FeatureValue {
1857        match v {
1858            Some(uni_common::Value::Float(f)) => FeatureValue::Float(*f),
1859            Some(uni_common::Value::Int(i)) => FeatureValue::Int(*i),
1860            Some(uni_common::Value::Bool(b)) => FeatureValue::Bool(*b),
1861            Some(uni_common::Value::String(s)) => FeatureValue::String(s.clone()),
1862            Some(uni_common::Value::Node(n)) => {
1863                // Encode node by vid for `scorer(s)` style.
1864                FeatureValue::Int(n.vid.as_u64() as i64)
1865            }
1866            _ => FeatureValue::Null,
1867        }
1868    };
1869    // Phase D D1: resolve a sub-expression to its raw `uni_common::Value`
1870    // for the `similar_to` UDF input. Falls back to the node's property
1871    // when the materialized column key isn't directly present.
1872    let resolve_value = |sub: &Expr| -> uni_common::Value {
1873        match sub {
1874            Expr::Variable(name) => fact_row
1875                .get(name)
1876                .cloned()
1877                .unwrap_or(uni_common::Value::Null),
1878            Expr::Property(boxed, prop) if matches!(boxed.as_ref(), Expr::Variable(_)) => {
1879                let Expr::Variable(v) = boxed.as_ref() else {
1880                    unreachable!()
1881                };
1882                let key = format!("{}.{}", v, prop);
1883                if let Some(val) = fact_row.get(&key) {
1884                    return val.clone();
1885                }
1886                if let Some(uni_common::Value::Node(n)) = fact_row.get(v) {
1887                    return n
1888                        .properties
1889                        .get(prop)
1890                        .cloned()
1891                        .unwrap_or(uni_common::Value::Null);
1892                }
1893                uni_common::Value::Null
1894            }
1895            Expr::Literal(lit) => lit.to_value(),
1896            Expr::List(items) => {
1897                let mut out = Vec::with_capacity(items.len());
1898                for it in items {
1899                    out.push(match it {
1900                        Expr::Literal(lit) => lit.to_value(),
1901                        _ => uni_common::Value::Null,
1902                    });
1903                }
1904                uni_common::Value::List(out)
1905            }
1906            _ => uni_common::Value::Null,
1907        }
1908    };
1909
1910    match expr {
1911        Expr::Variable(name) => value_to_feature(fact_row.get(name)),
1912        Expr::Property(boxed, prop) => {
1913            if let Expr::Variable(v) = boxed.as_ref() {
1914                // Try the materialized property column first.
1915                let key = format!("{}.{}", v, prop);
1916                if let Some(val) = fact_row.get(&key) {
1917                    return value_to_feature(Some(val));
1918                }
1919                // Fallback: try the synthetic hidden column that the
1920                // planner injects for property-access feature args
1921                // (`__feat_<var>_<prop>`). The writer side
1922                // (`apply_model_invocations`) already uses this
1923                // fallback (see `resolve_src` in the same file), so
1924                // mirroring it here keeps reader/writer input_hash
1925                // symmetric — without it, the YIELD-position case
1926                // (where `fact_row[v]` is a vid Int, not a Node)
1927                // returns Null and the store-lookup misses.
1928                let hidden_key = format!("__feat_{}_{}", v, prop);
1929                if let Some(val) = fact_row.get(&hidden_key) {
1930                    return value_to_feature(Some(val));
1931                }
1932                // Final fallback: read property directly from the
1933                // node value (works when fact_row carries the Node
1934                // rather than a vid Int).
1935                if let Some(uni_common::Value::Node(n)) = fact_row.get(v) {
1936                    return value_to_feature(n.properties.get(prop));
1937                }
1938            }
1939            FeatureValue::Null
1940        }
1941        Expr::FunctionCall { name, args, .. } if name == "similar_to" && args.len() == 2 => {
1942            let lv = resolve_value(&args[0]);
1943            let rv = resolve_value(&args[1]);
1944            match crate::query::similar_to::eval_similar_to_pure(&lv, &rv) {
1945                Ok(uni_common::Value::Float(f)) => FeatureValue::Float(f),
1946                _ => FeatureValue::Null,
1947            }
1948        }
1949        // `semantic_match` requires the Xervo embedder at this scope, which
1950        // is not threaded into the EXPLAIN re-evaluation path. Surface as
1951        // Null so neural-provenance still renders for the rest of the row.
1952        //
1953        // Phase D D1 graph-structural FunctionCalls (`degree_centrality`,
1954        // `pagerank_score`, `closeness_centrality`, `avg_neighbor`,
1955        // `max_neighbor`, `sum_neighbor`) require the `GraphAlgoHandle`
1956        // (algorithm registry + storage + PropertyManager) and an async
1957        // re-precompute pass — none of which are reachable from this
1958        // synchronous fact-row evaluator. Mode B re-evaluation surfaces
1959        // them as Null; the authoritative hot-path values are recorded
1960        // in `NeuralProvenanceStore` per fact (the EXPLAIN renderer
1961        // consults the store first when configured, falling back to
1962        // Mode B re-evaluation only as a backup).
1963        Expr::FunctionCall { name, .. }
1964            if matches!(
1965                name.as_str(),
1966                "degree_centrality"
1967                    | "pagerank_score"
1968                    | "closeness_centrality"
1969                    | "betweenness_centrality"
1970                    | "eigenvector_centrality"
1971                    | "harmonic_centrality"
1972                    | "katz_centrality"
1973                    | "avg_neighbor"
1974                    | "max_neighbor"
1975                    | "sum_neighbor"
1976            ) =>
1977        {
1978            FeatureValue::Null
1979        }
1980        _ => FeatureValue::Null,
1981    }
1982}
1983
1984fn collect_is_ref_inputs(
1985    rule: &FixpointRulePlan,
1986    clause_index: usize,
1987    delta_batch: &RecordBatch,
1988    row_idx: usize,
1989    registry: &Arc<DerivedScanRegistry>,
1990) -> Vec<ProofTerm> {
1991    let clause = match rule.clauses.get(clause_index) {
1992        Some(c) => c,
1993        None => return vec![],
1994    };
1995
1996    let mut inputs = Vec::new();
1997    let delta_schema = delta_batch.schema();
1998
1999    for binding in &clause.is_ref_bindings {
2000        if binding.negated {
2001            continue;
2002        }
2003        if binding.provenance_join_cols.is_empty() {
2004            continue;
2005        }
2006
2007        // Extract body-side values from the delta row for each provenance join col.
2008        let body_values: Vec<(String, ScalarKey)> = binding
2009            .provenance_join_cols
2010            .iter()
2011            .filter_map(|(body_col, _derived_col)| {
2012                let col_idx = delta_schema
2013                    .fields()
2014                    .iter()
2015                    .position(|f| f.name() == body_col)?;
2016                let key = extract_scalar_key(delta_batch, &[col_idx], row_idx);
2017                Some((body_col.clone(), key.into_iter().next()?))
2018            })
2019            .collect();
2020
2021        if body_values.len() != binding.provenance_join_cols.len() {
2022            continue;
2023        }
2024
2025        // Read current data from the registry entry for this IS-ref's rule.
2026        let entry = match registry.get(binding.derived_scan_index) {
2027            Some(e) => e,
2028            None => continue,
2029        };
2030        let source_batches = entry.data.read();
2031        let source_schema = &entry.schema;
2032
2033        // Find matching source rows and hash them.
2034        for src_batch in source_batches.iter() {
2035            let all_src_indices: Vec<usize> = (0..src_batch.num_columns()).collect();
2036            for src_row in 0..src_batch.num_rows() {
2037                let matches = binding.provenance_join_cols.iter().enumerate().all(
2038                    |(i, (_body_col, derived_col))| {
2039                        let src_col_idx = source_schema
2040                            .fields()
2041                            .iter()
2042                            .position(|f| f.name() == derived_col);
2043                        match src_col_idx {
2044                            Some(idx) => {
2045                                let src_key = extract_scalar_key(src_batch, &[idx], src_row);
2046                                src_key.first() == Some(&body_values[i].1)
2047                            }
2048                            None => false,
2049                        }
2050                    },
2051                );
2052                if matches {
2053                    let fact_hash = format!(
2054                        "{:?}",
2055                        extract_scalar_key(src_batch, &all_src_indices, src_row)
2056                    )
2057                    .into_bytes();
2058                    inputs.push(ProofTerm {
2059                        source_rule: binding.rule_name.clone(),
2060                        base_fact_id: fact_hash,
2061                    });
2062                }
2063            }
2064        }
2065    }
2066
2067    inputs
2068}
2069
2070/// Phase D D-C0: per-body-row variant of [`collect_is_ref_inputs`] used
2071/// to pre-populate `FoldExec`'s `body_support_map` for TopKProofs MNOR.
2072///
2073/// At FOLD time, the rule's own facts haven't been recorded in the
2074/// `ProvenanceStore` yet (`record_provenance` runs after fact
2075/// materialization, and is keyed by post-YIELD hashes anyway), so the
2076/// support set for each pre-fold body row must be reconstructed
2077/// directly from the rule's IS-ref bindings + the source rules'
2078/// registry data.
2079///
2080/// We don't know which clause produced each body row at this point —
2081/// the iteration-local `clause_candidates` are gone — so we iterate
2082/// **every** clause's `is_ref_bindings`. The `provenance_join_cols`
2083/// schema check inside `collect_is_ref_inputs` already skips bindings
2084/// whose body columns aren't in the row's schema, so cross-clause
2085/// contamination is bounded (a binding only matches if its body cols
2086/// are present and the values join). For the single-clause TopKProofs
2087/// scenarios in TCK this is exact; for multi-clause TopKProofs rules
2088/// it is a conservative over-approximation that may inflate base-RV
2089/// counts (treated as the same RV under interning) — acceptable
2090/// because the DNF math collapses duplicates by inclusion-exclusion.
2091fn collect_is_ref_inputs_for_body_row(
2092    rule: &FixpointRulePlan,
2093    delta_batch: &RecordBatch,
2094    row_idx: usize,
2095    registry: &Arc<DerivedScanRegistry>,
2096) -> Vec<ProofTerm> {
2097    let mut combined: Vec<ProofTerm> = Vec::new();
2098    for clause_index in 0..rule.clauses.len() {
2099        let part = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
2100        combined.extend(part);
2101    }
2102    combined
2103}
2104
2105// ---------------------------------------------------------------------------
2106// Shared-lineage detection
2107// ---------------------------------------------------------------------------
2108
// NOTE: the paragraph below documents `detect_shared_lineage` (defined
// further down in this file), not this struct. It is kept as a plain
// comment so that rustdoc does not attribute it to `SharedGroupRow`.
//
// Detect KEY groups in a rule's pre-fold facts where recursive derivation
// may violate the independence assumption of MNOR/MPROD.
//
// Uses a two-tier strategy:
// 1. **Precise**: If the `ProvenanceStore` has populated `support` for facts
//    in the group, we recursively compute lineage (Cui & Widom 2000) and
//    check for pairwise overlap. A shared base fact proves a dependency.
// 2. **Structural fallback**: When lineage tracking is unavailable (e.g., the
//    IS-ref subject variables were projected away), we check whether any fact
//    in a multi-row group was derived by a clause that has IS-ref bindings.
//    Recursive derivation through shared relations is a strong signal that
//    proof paths may share intermediate nodes.

/// Per-row data collected during shared-lineage detection.
#[expect(
    dead_code,
    reason = "Fields accessed via SharedLineageInfo in detect_shared_lineage"
)]
pub(crate) struct SharedGroupRow {
    /// Byte key identifying the row — presumably the value produced by
    /// `fact_hash_key`; confirm at the construction site.
    pub fact_hash: Vec<u8>,
    /// Byte keys making up this row's lineage — presumably the hashes of
    /// the base facts it depends on; confirm at the construction site.
    pub lineage: HashSet<Vec<u8>>,
}
2131
/// Information about groups with shared proofs, returned by `detect_shared_lineage`.
pub(crate) struct SharedLineageInfo {
    /// KEY group → rows with their base fact sets. The map key is the
    /// scalar key extracted from the rule's KEY columns.
    pub shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>>,
}
2137
2138/// Build a byte key that uniquely identifies a row across all columns.
2139pub(crate) fn fact_hash_key(batch: &RecordBatch, all_indices: &[usize], row_idx: usize) -> Vec<u8> {
2140    format!("{:?}", extract_scalar_key(batch, all_indices, row_idx)).into_bytes()
2141}
2142
2143/// Emits at most one `SharedProbabilisticDependency` warning per rule.
2144/// Returns `Some(SharedLineageInfo)` if any group has shared proofs.
2145fn detect_shared_lineage(
2146    rule: &FixpointRulePlan,
2147    pre_fold_facts: &[RecordBatch],
2148    tracker: &Arc<ProvenanceStore>,
2149    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2150    semiring_kind: SemiringKind,
2151) -> Option<SharedLineageInfo> {
2152    use uni_locy::{RuntimeWarning, RuntimeWarningCode};
2153
2154    // Only check rules with probability-domain fold bindings. M3:
2155    // dispatches via the `LocyAggregate` trait so user-authored
2156    // probability aggregates participate automatically — selected by the
2157    // trait's `is_probability_aggregate()` flag, not by hardcoded name.
2158    let has_prob_fold = rule
2159        .fold_bindings
2160        .iter()
2161        .any(|fb| fb.aggregate.is_probability_aggregate());
2162    if !has_prob_fold {
2163        return None;
2164    }
2165
2166    // Group facts by KEY columns.
2167    let key_indices = &rule.key_column_indices;
2168    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2169
2170    let mut groups: HashMap<Vec<ScalarKey>, Vec<Vec<u8>>> = HashMap::new();
2171    for batch in pre_fold_facts {
2172        for row_idx in 0..batch.num_rows() {
2173            let key = extract_scalar_key(batch, key_indices, row_idx);
2174            let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
2175            groups.entry(key).or_default().push(fact_hash);
2176        }
2177    }
2178
2179    let mut shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>> = HashMap::new();
2180    let mut any_shared = false;
2181
2182    // Check each group with ≥2 rows.
2183    for (key, fact_hashes) in &groups {
2184        if fact_hashes.len() < 2 {
2185            continue;
2186        }
2187
2188        // Tier 1: precise base-fact overlap detection via tracker inputs.
2189        let mut has_inputs = false;
2190        let mut per_row_bases: Vec<HashSet<Vec<u8>>> = Vec::new();
2191        for fh in fact_hashes {
2192            let bases = compute_lineage(fh, tracker, &mut HashSet::new());
2193            if let Some(entry) = tracker.lookup(fh)
2194                && !entry.support.is_empty()
2195            {
2196                has_inputs = true;
2197            }
2198            per_row_bases.push(bases);
2199        }
2200
2201        let shared_found = if has_inputs {
2202            // At least some facts have tracked inputs — do precise comparison.
2203            let mut found = false;
2204            'outer: for i in 0..per_row_bases.len() {
2205                for j in (i + 1)..per_row_bases.len() {
2206                    if !per_row_bases[i].is_disjoint(&per_row_bases[j]) {
2207                        found = true;
2208                        break 'outer;
2209                    }
2210                }
2211            }
2212            found
2213        } else {
2214            // Tier 2: structural fallback — check if any fact in the group was
2215            // derived by a clause with IS-ref bindings (recursive derivation).
2216            fact_hashes.iter().any(|fh| {
2217                tracker.lookup(fh).is_some_and(|entry| {
2218                    rule.clauses
2219                        .get(entry.clause_index)
2220                        .is_some_and(|clause| clause.is_ref_bindings.iter().any(|b| !b.negated))
2221                })
2222            })
2223        };
2224
2225        if shared_found {
2226            any_shared = true;
2227            // Collect the group rows with their base facts for BDD use.
2228            let rows: Vec<SharedGroupRow> = fact_hashes
2229                .iter()
2230                .zip(per_row_bases)
2231                .map(|(fh, bases)| SharedGroupRow {
2232                    fact_hash: fh.clone(),
2233                    lineage: bases,
2234                })
2235                .collect();
2236            shared_groups.insert(key.clone(), rows);
2237        }
2238    }
2239
2240    // Phase 5: Cross-group correlation warning.
2241    // Check if any IS-ref input fact appears in multiple KEY groups.
2242    // This is independent of within-group sharing: even rules whose KEY groups
2243    // each have only one post-fold row can exhibit cross-group correlation when
2244    // different groups consume the same IS-ref base fact.
2245    {
2246        let mut input_to_groups: HashMap<Vec<u8>, HashSet<Vec<ScalarKey>>> = HashMap::new();
2247        for (key, fact_hashes) in &groups {
2248            for fh in fact_hashes {
2249                if let Some(entry) = tracker.lookup(fh) {
2250                    for input in &entry.support {
2251                        input_to_groups
2252                            .entry(input.base_fact_id.clone())
2253                            .or_default()
2254                            .insert(key.clone());
2255                    }
2256                }
2257            }
2258        }
2259        let has_cross_group = input_to_groups.values().any(|g| g.len() > 1);
2260        if has_cross_group && let Ok(mut warnings) = warnings_slot.write() {
2261            let already_warned = warnings.iter().any(|w| {
2262                w.code == RuntimeWarningCode::CrossGroupCorrelationNotExact
2263                    && w.rule_name == rule.name
2264            });
2265            if !already_warned {
2266                // Phase D F3: pick one canonical example of a shared
2267                // input fact and the KEY groups it bridges, so users
2268                // can correlate the warning with EXPLAIN output.
2269                let example =
2270                    input_to_groups
2271                        .iter()
2272                        .find(|(_, g)| g.len() > 1)
2273                        .map(|(input, groups)| {
2274                            let short = input
2275                                .iter()
2276                                .take(8)
2277                                .map(|b| format!("{:02x}", b))
2278                                .collect::<String>();
2279                            let mut group_strs: Vec<String> =
2280                                groups.iter().map(|k| format!("{:?}", k)).collect();
2281                            group_strs.sort();
2282                            format!(
2283                                "input {} shared by groups [{}]",
2284                                short,
2285                                group_strs.join(", ")
2286                            )
2287                        });
2288                // Phase D F3 case 1 BDD-time deepening: count distinct
2289                // base facts (= BDD variables) that cross groups, so the
2290                // warning carries structured metadata mirroring
2291                // `BddLimitExceeded`. Users can correlate
2292                // `variable_count` with EXPLAIN's BDD output.
2293                let shared_variable_count =
2294                    input_to_groups.values().filter(|g| g.len() > 1).count();
2295                warnings.push(RuntimeWarning {
2296                    code: RuntimeWarningCode::CrossGroupCorrelationNotExact,
2297                    message: format!(
2298                        "Rule '{}': {} IS-ref base fact(s) are shared across different \
2299                         KEY groups. BDD corrects per-group probabilities but cannot \
2300                         account for cross-group correlations.",
2301                        rule.name, shared_variable_count
2302                    ),
2303                    rule_name: rule.name.clone(),
2304                    variable_count: Some(shared_variable_count),
2305                    key_group: example,
2306                });
2307            }
2308        }
2309    }
2310
2311    if any_shared {
2312        // Phase D D-C0b: under `SemiringKind::TopKProofs`, the FOLD-time
2313        // DNF inclusion-exclusion math (shipped in D-C0) auto-corrects
2314        // for within-group base-fact sharing — the "Results may
2315        // overestimate" premise of `SharedProbabilisticDependency`
2316        // is no longer true. Suppress the warning under TopK; users
2317        // who chose TopKProofs explicitly opted into the
2318        // correctness-preserving path. Cross-group correlation
2319        // (`CrossGroupCorrelationNotExact`) still fires above because
2320        // D-C0 doesn't span KEY-group boundaries.
2321        let suppress_under_topk = matches!(semiring_kind, SemiringKind::TopKProofs { .. });
2322        if !suppress_under_topk && let Ok(mut warnings) = warnings_slot.write() {
2323            let already_warned = warnings.iter().any(|w| {
2324                w.code == RuntimeWarningCode::SharedProbabilisticDependency
2325                    && w.rule_name == rule.name
2326            });
2327            if !already_warned {
2328                warnings.push(RuntimeWarning {
2329                    code: RuntimeWarningCode::SharedProbabilisticDependency,
2330                    message: format!(
2331                        "Rule '{}' aggregates with MNOR/MPROD but some proof paths \
2332                         share intermediate facts, violating the independence assumption. \
2333                         Results may overestimate probability.",
2334                        rule.name
2335                    ),
2336                    rule_name: rule.name.clone(),
2337                    variable_count: None,
2338                    key_group: None,
2339                });
2340            }
2341        }
2342        Some(SharedLineageInfo { shared_groups })
2343    } else {
2344        None
2345    }
2346}
2347
2348/// Record provenance and detect shared proofs for non-recursive strata.
2349///
2350/// Non-recursive rules are evaluated in a single pass (no fixpoint loop), so
2351/// the regular `record_provenance` + `detect_shared_lineage` path is never hit.
2352/// This function bridges that gap by recording a `ProvenanceAnnotation` for every
2353/// fact produced by each clause and then running the same two-tier detection
2354/// logic used by the recursive path.
2355#[allow(
2356    clippy::too_many_arguments,
2357    reason = "context bundle would be over-engineering for one call site"
2358)]
2359pub(crate) async fn record_and_detect_lineage_nonrecursive(
2360    rule: &FixpointRulePlan,
2361    tagged_clause_facts: &[(usize, Vec<RecordBatch>)],
2362    tracker: &Arc<ProvenanceStore>,
2363    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2364    registry: &Arc<DerivedScanRegistry>,
2365    top_k_proofs: usize,
2366    classifiers: ClassifierRefs<'_>,
2367    semiring_kind: SemiringKind,
2368) -> Option<SharedLineageInfo> {
2369    let classifier_registry = classifiers.registry;
2370    let classifier_cache = classifiers.cache;
2371    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2372
2373    // Pre-compute base fact probabilities for top-k mode.
2374    let base_probs = if top_k_proofs > 0 {
2375        tracker.base_fact_probs()
2376    } else {
2377        HashMap::new()
2378    };
2379
2380    let mut topk_acc = TopKProofAccumulator::new();
2381
2382    // Record provenance for each clause's facts.
2383    for (clause_index, batches) in tagged_clause_facts {
2384        for batch in batches {
2385            for row_idx in 0..batch.num_rows() {
2386                let row_hash = fact_hash_key(batch, &all_indices, row_idx);
2387                let fact_row = batch_row_to_value_map(batch, row_idx);
2388
2389                let support = collect_is_ref_inputs(rule, *clause_index, batch, row_idx, registry);
2390
2391                let proof_probability = if top_k_proofs > 0 {
2392                    compute_proof_probability(&support, &base_probs)
2393                } else {
2394                    None
2395                };
2396
2397                let entry = ProvenanceAnnotation {
2398                    rule_name: rule.name.clone(),
2399                    clause_index: *clause_index,
2400                    support,
2401                    along_values: {
2402                        let along_names: Vec<String> = rule
2403                            .clauses
2404                            .get(*clause_index)
2405                            .map(|c| c.along_bindings.clone())
2406                            .unwrap_or_default();
2407                        along_names
2408                            .iter()
2409                            .filter_map(|name| {
2410                                fact_row.get(name).map(|v| (name.clone(), v.clone()))
2411                            })
2412                            .collect()
2413                    },
2414                    iteration: 0,
2415                    fact_row: fact_row.clone(),
2416                    proof_probability,
2417                    neural_calls: collect_neural_calls_for_row(
2418                        rule,
2419                        *clause_index,
2420                        &fact_row,
2421                        classifier_registry,
2422                        classifier_cache,
2423                        classifiers.provenance_store,
2424                    )
2425                    .await,
2426                };
2427                if top_k_proofs > 0 {
2428                    topk_acc.accumulate(&entry, &row_hash);
2429                    tracker.record_top_k(row_hash, entry, top_k_proofs);
2430                } else {
2431                    tracker.record(row_hash, entry);
2432                }
2433            }
2434        }
2435    }
2436
2437    topk_acc.emit_warning_if_any(rule, top_k_proofs, warnings_slot);
2438
2439    // Flatten all clause facts and run detection.
2440    let all_facts: Vec<RecordBatch> = tagged_clause_facts
2441        .iter()
2442        .flat_map(|(_, batches)| batches.iter().cloned())
2443        .collect();
2444    detect_shared_lineage(rule, &all_facts, tracker, warnings_slot, semiring_kind)
2445}
2446
2447/// Apply exact weighted model counting (WMC) for shared-lineage groups.
2448///
2449/// Replaces multiple rows in groups with shared lineage with a single
2450/// representative row whose PROB column carries the BDD-computed exact
2451/// probability (Sang et al. 2005). For groups that exceed
2452/// `max_bdd_variables`, rows are left unchanged and a `BddLimitExceeded`
2453/// warning is emitted.
2454pub(crate) fn apply_exact_wmc(
2455    pre_fold_facts: Vec<RecordBatch>,
2456    rule: &FixpointRulePlan,
2457    shared_info: &SharedLineageInfo,
2458    tracker: &Arc<ProvenanceStore>,
2459    max_bdd_variables: usize,
2460    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2461    approximate_slot: &Arc<StdRwLock<HashMap<String, Vec<String>>>>,
2462) -> DFResult<Vec<RecordBatch>> {
2463    use crate::query::df_graph::locy_bdd::{SemiringOp, weighted_model_count};
2464    use uni_locy::{RuntimeWarning, RuntimeWarningCode};
2465
2466    // Find the probability-domain fold binding to know which column
2467    // to overwrite. M3: dispatch through the `LocyAggregate` trait so
2468    // user-authored probability aggregates participate.
2469    let prob_fold = rule
2470        .fold_bindings
2471        .iter()
2472        .find(|fb| fb.aggregate.is_probability_aggregate());
2473    let prob_fold = match prob_fold {
2474        Some(f) => f,
2475        None => return Ok(pre_fold_facts),
2476    };
2477    let semiring_op = if prob_fold.aggregate.is_noisy_or() {
2478        SemiringOp::Disjunction
2479    } else {
2480        SemiringOp::Conjunction
2481    };
2482    let prob_col_idx = prob_fold.input_col_index;
2483    let prob_col_name = rule.yield_schema.field(prob_col_idx).name().clone();
2484
2485    let key_indices = &rule.key_column_indices;
2486    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2487
2488    // Build a set of shared group keys for quick lookup.
2489    let shared_keys: HashSet<Vec<ScalarKey>> = shared_info.shared_groups.keys().cloned().collect();
2490
2491    // Phase 1: Collect all rows for each shared KEY group across all batches.
2492    // Store (batch_index, row_index) pairs for each group.
2493    struct GroupAccum {
2494        base_facts: Vec<HashSet<Vec<u8>>>,
2495        base_probs: HashMap<Vec<u8>, f64>,
2496        /// First occurrence: (batch_index, row_index) — used as representative.
2497        representative: (usize, usize),
2498        row_locations: Vec<(usize, usize)>,
2499    }
2500
2501    let mut group_accums: HashMap<Vec<ScalarKey>, GroupAccum> = HashMap::new();
2502    let mut non_shared_rows: Vec<(usize, usize)> = Vec::new(); // (batch_idx, row_idx)
2503
2504    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
2505        for row_idx in 0..batch.num_rows() {
2506            let key = extract_scalar_key(batch, key_indices, row_idx);
2507            if shared_keys.contains(&key) {
2508                let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
2509                let bases = compute_lineage(&fact_hash, tracker, &mut HashSet::new());
2510
2511                let accum = group_accums.entry(key).or_insert_with(|| GroupAccum {
2512                    base_facts: Vec::new(),
2513                    base_probs: HashMap::new(),
2514                    representative: (batch_idx, row_idx),
2515                    row_locations: Vec::new(),
2516                });
2517
2518                // Look up probabilities for base facts.
2519                for bf in &bases {
2520                    if !accum.base_probs.contains_key(bf)
2521                        && let Some(entry) = tracker.lookup(bf)
2522                        && let Some(val) = entry.fact_row.get(&prob_col_name)
2523                        && let Some(p) = value_to_f64(val)
2524                    {
2525                        accum.base_probs.insert(bf.clone(), p);
2526                    }
2527                }
2528
2529                accum.base_facts.push(bases);
2530                accum.row_locations.push((batch_idx, row_idx));
2531            } else {
2532                non_shared_rows.push((batch_idx, row_idx));
2533            }
2534        }
2535    }
2536
2537    // Phase 2: Compute BDD for each shared group (across all batches).
2538    // Track which (batch_idx, row_idx) pairs to keep vs drop.
2539    let mut keep_rows: HashSet<(usize, usize)> = HashSet::new();
2540    // Map of (batch_idx, row_idx) → overridden PROB value (for BDD-succeeded groups).
2541    let mut overrides: HashMap<(usize, usize), f64> = HashMap::new();
2542
2543    // All non-shared rows are kept.
2544    for &loc in &non_shared_rows {
2545        keep_rows.insert(loc);
2546    }
2547
2548    for (key, accum) in &group_accums {
2549        let bdd_result = weighted_model_count(
2550            &accum.base_facts,
2551            &accum.base_probs,
2552            semiring_op,
2553            max_bdd_variables,
2554        );
2555
2556        if bdd_result.approximated {
2557            // Emit BddLimitExceeded warning (one per key group).
2558            if let Ok(mut warnings) = warnings_slot.write() {
2559                let key_desc = format!("{key:?}");
2560                let already_warned = warnings.iter().any(|w| {
2561                    w.code == RuntimeWarningCode::BddLimitExceeded
2562                        && w.rule_name == rule.name
2563                        && w.key_group.as_deref() == Some(&key_desc)
2564                });
2565                if !already_warned {
2566                    warnings.push(RuntimeWarning {
2567                        code: RuntimeWarningCode::BddLimitExceeded,
2568                        message: format!(
2569                            "Rule '{}': BDD variable limit exceeded ({} > {}). \
2570                             Falling back to independence-mode result.",
2571                            rule.name, bdd_result.variable_count, max_bdd_variables
2572                        ),
2573                        rule_name: rule.name.clone(),
2574                        variable_count: Some(bdd_result.variable_count),
2575                        key_group: Some(key_desc),
2576                    });
2577                }
2578            }
2579            if let Ok(mut approx) = approximate_slot.write() {
2580                let key_desc = format!("{key:?}");
2581                approx.entry(rule.name.clone()).or_default().push(key_desc);
2582            }
2583            // Keep all rows unchanged.
2584            for &loc in &accum.row_locations {
2585                keep_rows.insert(loc);
2586            }
2587        } else {
2588            // BDD succeeded: keep one representative row with overridden PROB.
2589            keep_rows.insert(accum.representative);
2590            overrides.insert(accum.representative, bdd_result.probability);
2591        }
2592    }
2593
2594    // Phase 3: Build output batches by filtering kept rows per batch.
2595    let mut result_batches = Vec::new();
2596    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
2597        let kept_indices: Vec<usize> = (0..batch.num_rows())
2598            .filter(|&row_idx| keep_rows.contains(&(batch_idx, row_idx)))
2599            .collect();
2600
2601        if kept_indices.is_empty() {
2602            continue;
2603        }
2604
2605        let indices = arrow::array::UInt32Array::from(
2606            kept_indices.iter().map(|&i| i as u32).collect::<Vec<_>>(),
2607        );
2608        let mut columns: Vec<arrow::array::ArrayRef> = batch
2609            .columns()
2610            .iter()
2611            .map(|col| arrow::compute::take(col, &indices, None))
2612            .collect::<Result<Vec<_>, _>>()
2613            .map_err(arrow_err)?;
2614
2615        // Check if any kept rows have PROB overrides.
2616        let override_map: Vec<Option<f64>> = kept_indices
2617            .iter()
2618            .map(|&row_idx| overrides.get(&(batch_idx, row_idx)).copied())
2619            .collect();
2620
2621        if override_map.iter().any(|o| o.is_some()) && prob_col_idx < columns.len() {
2622            // Rebuild the PROB column with overrides.
2623            let existing_prob = columns[prob_col_idx]
2624                .as_any()
2625                .downcast_ref::<arrow::array::Float64Array>();
2626            let new_values: Vec<f64> = override_map
2627                .iter()
2628                .enumerate()
2629                .map(|(i, ov)| match ov {
2630                    Some(p) => *p,
2631                    None => existing_prob.map(|arr| arr.value(i)).unwrap_or(0.0),
2632                })
2633                .collect();
2634            columns[prob_col_idx] = Arc::new(arrow::array::Float64Array::from(new_values));
2635        }
2636
2637        let result_batch = RecordBatch::try_new(batch.schema(), columns).map_err(arrow_err)?;
2638        result_batches.push(result_batch);
2639    }
2640
2641    Ok(result_batches)
2642}
2643
2644/// Extract an f64 from a `Value`, supporting Float and Int.
2645fn value_to_f64(val: &uni_common::Value) -> Option<f64> {
2646    match val {
2647        uni_common::Value::Float(f) => Some(*f),
2648        uni_common::Value::Int(i) => Some(*i as f64),
2649        _ => None,
2650    }
2651}
2652
2653/// Compute the lineage of a derived fact (Cui & Widom 2000).
2654///
2655/// Recursively traverses the provenance store to collect the set of base-level
2656/// fact hashes that contribute to this derivation. A base fact is one with no
2657/// IS-ref support (a graph-level fact). Intermediate facts are expanded
2658/// transitively through the store.
2659fn compute_lineage(
2660    fact_hash: &[u8],
2661    tracker: &Arc<ProvenanceStore>,
2662    visited: &mut HashSet<Vec<u8>>,
2663) -> HashSet<Vec<u8>> {
2664    if !visited.insert(fact_hash.to_vec()) {
2665        return HashSet::new(); // Cycle guard.
2666    }
2667
2668    match tracker.lookup(fact_hash) {
2669        Some(entry) if !entry.support.is_empty() => {
2670            let mut bases = HashSet::new();
2671            for input in &entry.support {
2672                let child_bases = compute_lineage(&input.base_fact_id, tracker, visited);
2673                bases.extend(child_bases);
2674            }
2675            bases
2676        }
2677        _ => {
2678            // Base fact (no tracker entry or no inputs).
2679            let mut set = HashSet::new();
2680            set.insert(fact_hash.to_vec());
2681            set
2682        }
2683    }
2684}
2685
2686/// Determine which clause produced a given row by checking each clause's candidates.
2687///
2688/// Returns the index of the first clause whose candidates contain a matching row.
2689/// Falls back to 0 if no match is found.
2690fn find_clause_for_row(
2691    delta_batch: &RecordBatch,
2692    row_idx: usize,
2693    all_indices: &[usize],
2694    clause_candidates: &[Vec<RecordBatch>],
2695) -> usize {
2696    let target_key = extract_scalar_key(delta_batch, all_indices, row_idx);
2697    for (clause_idx, batches) in clause_candidates.iter().enumerate() {
2698        for batch in batches {
2699            if batch.num_columns() != all_indices.len() {
2700                continue;
2701            }
2702            for r in 0..batch.num_rows() {
2703                if extract_scalar_key(batch, all_indices, r) == target_key {
2704                    return clause_idx;
2705                }
2706            }
2707        }
2708    }
2709    0
2710}
2711
2712/// Convert a single row from a `RecordBatch` at `row_idx` into a `HashMap<String, Value>`.
2713fn batch_row_to_value_map(
2714    batch: &RecordBatch,
2715    row_idx: usize,
2716) -> std::collections::HashMap<String, Value> {
2717    use uni_store::storage::arrow_convert::arrow_to_value;
2718
2719    let schema = batch.schema();
2720    schema
2721        .fields()
2722        .iter()
2723        .enumerate()
2724        .map(|(col_idx, field)| {
2725            let col = batch.column(col_idx);
2726            let val = arrow_to_value(col.as_ref(), row_idx, None);
2727            (field.name().clone(), val)
2728        })
2729        .collect()
2730}
2731
2732/// Filter `batches` to exclude rows where `left_col` VID appears in `neg_facts[right_col]`.
2733///
2734/// Implements anti-join semantics for negated IS-refs (`n IS NOT rule`): keeps only
2735/// rows whose subject VID is NOT present in the negated rule's fully-converged facts.
2736pub fn apply_anti_join(
2737    batches: Vec<RecordBatch>,
2738    neg_facts: &[RecordBatch],
2739    left_col: &str,
2740    right_col: &str,
2741) -> datafusion::error::Result<Vec<RecordBatch>> {
2742    use arrow::compute::filter_record_batch;
2743    use arrow_array::{Array as _, BooleanArray, UInt64Array};
2744
2745    // Collect right-side VIDs from the negated rule's derived facts.
2746    let mut banned: std::collections::HashSet<u64> = std::collections::HashSet::new();
2747    for batch in neg_facts {
2748        let Ok(idx) = batch.schema().index_of(right_col) else {
2749            continue;
2750        };
2751        let arr = batch.column(idx);
2752        let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2753            continue;
2754        };
2755        for i in 0..vids.len() {
2756            if !vids.is_null(i) {
2757                banned.insert(vids.value(i));
2758            }
2759        }
2760    }
2761
2762    if banned.is_empty() {
2763        return Ok(batches);
2764    }
2765
2766    // Filter body batches: keep rows where left_col NOT IN banned.
2767    let mut result = Vec::new();
2768    for batch in batches {
2769        let Ok(idx) = batch.schema().index_of(left_col) else {
2770            result.push(batch);
2771            continue;
2772        };
2773        let arr = batch.column(idx);
2774        let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2775            result.push(batch);
2776            continue;
2777        };
2778        let keep: Vec<bool> = (0..vids.len())
2779            .map(|i| vids.is_null(i) || !banned.contains(&vids.value(i)))
2780            .collect();
2781        let keep_arr = BooleanArray::from(keep);
2782        let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
2783        if filtered.num_rows() > 0 {
2784            result.push(filtered);
2785        }
2786    }
2787    Ok(result)
2788}
2789
2790// ─── Phase B Slice 3: neural-model invocation pass ───────────────────────
2791//
2792// `apply_model_invocations` runs every `ModelInvocation` lifted from a
2793// clause's YIELD items against the body's output batches. For each
2794// invocation it:
2795//
2796//   1. Resolves each `feature_expr` to a column in the batch — Slice 3
2797//      supports plain `Expr::Variable("name")` references; richer
2798//      expressions (property access, nested calls) are deferred.
2799//   2. Builds one `ClassifyInput` per row keyed by the model's input
2800//      binding names.
2801//   3. Issues a single batched `NeuralClassifier::classify` call.
2802//   4. Appends the resulting `Float64Array` as a new column matching
2803//      `invocation.output_column`.
2804//
2805// Errors:
2806//   * `UnknownNeuralModel`: the model name isn't in the registry.
2807//   * Mismatched feature-expr / column: returned as a DataFusion
2808//     Execution error.
2809
2810#[allow(clippy::too_many_arguments)]
2811pub(crate) async fn apply_model_invocations(
2812    batches: Vec<RecordBatch>,
2813    invocations: &[uni_locy::ModelInvocation],
2814    registry: &Arc<ClassifierRegistry>,
2815    cache: Option<&Arc<uni_locy::ModelInvocationCache>>,
2816    provenance_store: Option<&Arc<uni_locy::NeuralProvenanceStore>>,
2817    path_context_handles: &HashMap<
2818        String,
2819        crate::query::df_graph::locy_model_invoke::PathContextHandle,
2820    >,
2821    xervo_runtime: &crate::query::df_graph::locy_model_invoke::XervoRuntimeHandle,
2822    graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
2823) -> DFResult<Vec<RecordBatch>> {
2824    use uni_locy::ClassifyInput;
2825    if batches.is_empty() || invocations.is_empty() {
2826        return Ok(batches);
2827    }
2828    // Phase D D2: pre-embed all unique `semantic_match` query literals
2829    // once per call. Resolvers below lower each `semantic_match(prop,
2830    // 'text')` into a `SimilarTo { left: prop_col, right: Const(Vector) }`.
2831    let semantic_match_embeddings =
2832        pre_embed_semantic_match_queries(invocations, xervo_runtime).await?;
2833    // Phase D D1 graph-structural: pre-compute topology scores and
2834    // neighbor-property maps for every distinct (fn_name, args) tuple
2835    // appearing in any FEATURE FunctionCall. One pass per call; reused
2836    // across every row of every batch.
2837    let graph_feature_maps = precompute_graph_feature_maps(invocations, graph_algo).await?;
2838    let neighbor_feature_maps =
2839        precompute_neighbor_feature_maps(invocations, &batches, graph_algo).await?;
2840    let mut out_batches = Vec::with_capacity(batches.len());
2841    for batch in batches {
2842        let mut current = batch;
2843        for invocation in invocations {
2844            let classifier = registry.get(&invocation.model_name).ok_or_else(|| {
2845                datafusion::error::DataFusionError::Execution(format!(
2846                    "neural classifier '{}' not registered; \
2847                         add it to LocyConfig::classifier_registry",
2848                    invocation.model_name
2849                ))
2850            })?;
2851
2852            // Resolve each feature_expr to a per-row evaluator.
2853            // Supported shapes (validated at compile time by
2854            // `extract_model_invocations` / `validate_features`):
2855            //   * `Expr::Variable("name")` — direct column reference.
2856            //   * `Expr::Property(Variable(v), prop)` — looked up by the
2857            //     conventional `"v.prop"` column name materialized by
2858            //     the planner's `translate_property_access` pipeline.
2859            //   * `Expr::FunctionCall { name: "similar_to"|"semantic_match", ... }`
2860            //     — Phase D D1/D2 retrieval-backed feature; both args
2861            //     resolved to columns; UDF evaluated per row against
2862            //     the row's `uni_common::Value` payloads.
2863            let resolvers = build_feature_resolvers(
2864                &current,
2865                invocation,
2866                path_context_handles,
2867                &semantic_match_embeddings,
2868                &graph_feature_maps,
2869                &neighbor_feature_maps,
2870            )?;
2871
2872            // Build one ClassifyInput per row.
2873            let n_rows = current.num_rows();
2874            let mut inputs: Vec<ClassifyInput> = Vec::with_capacity(n_rows);
2875            let mut input_hashes: Vec<u64> = Vec::with_capacity(n_rows);
2876            for row_idx in 0..n_rows {
2877                let mut features = std::collections::HashMap::new();
2878                for resolver in &resolvers {
2879                    let value = resolver.eval_row(&current, row_idx)?;
2880                    features.insert(resolver.binding_name.clone(), value);
2881                }
2882                let input = ClassifyInput { features };
2883                input_hashes.push(input.stable_hash());
2884                inputs.push(input);
2885            }
2886
2887            // Slice 2: memoization. Split inputs into cache hits and
2888            // misses; only call `classify` on the misses, then weave
2889            // the cached values back in by original row index.
2890            let mut probs: Vec<f64> = vec![0.0; n_rows];
2891            let mut miss_inputs: Vec<ClassifyInput> = Vec::new();
2892            let mut miss_row_indices: Vec<usize> = Vec::new();
2893            if let Some(c) = cache {
2894                for (row_idx, h) in input_hashes.iter().enumerate() {
2895                    match c.get(&invocation.model_name, *h) {
2896                        Some(v) => probs[row_idx] = v,
2897                        None => {
2898                            miss_row_indices.push(row_idx);
2899                            miss_inputs.push(inputs[row_idx].clone());
2900                        }
2901                    }
2902                }
2903            } else {
2904                miss_row_indices = (0..n_rows).collect();
2905                miss_inputs = inputs.clone();
2906            }
2907
2908            // Phase C C-RawCalibratedSeparation: when a calibrator is
2909            // present, route through `raw_and_calibrated` so the
2910            // provenance store records both the base classifier's raw
2911            // output AND the post-calibrator value. Bare classifiers
2912            // (no calibrator) keep using `classify`. The downstream
2913            // `probs[row]` is always the *calibrated* value when
2914            // available — that's what the rule's PROB output column
2915            // and the memoization cache carry.
2916            let calibrator = classifier.get_calibrator();
2917            let (miss_raws, miss_calibrated) = if miss_inputs.is_empty() {
2918                (Vec::new(), Vec::new())
2919            } else if calibrator.is_some() {
2920                let pairs = classifier
2921                    .raw_and_calibrated(&miss_inputs)
2922                    .await
2923                    .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2924                if pairs.len() != miss_inputs.len() {
2925                    return Err(datafusion::error::DataFusionError::Execution(format!(
2926                        "classifier '{}' raw_and_calibrated returned {} outputs for {} inputs",
2927                        invocation.model_name,
2928                        pairs.len(),
2929                        miss_inputs.len()
2930                    )));
2931                }
2932                let raws: Vec<f64> = pairs.iter().map(|(r, _)| *r).collect();
2933                let cals: Vec<f64> = pairs.iter().map(|(r, c)| c.unwrap_or(*r)).collect();
2934                (raws, cals)
2935            } else {
2936                let r = classifier
2937                    .classify(&miss_inputs)
2938                    .await
2939                    .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2940                if r.len() != miss_inputs.len() {
2941                    return Err(datafusion::error::DataFusionError::Execution(format!(
2942                        "classifier '{}' returned {} outputs for {} inputs",
2943                        invocation.model_name,
2944                        r.len(),
2945                        miss_inputs.len()
2946                    )));
2947                }
2948                // No calibrator → raw == final.
2949                (r.clone(), r)
2950            };
2951            // The memoization cache stores the *final* (calibrated when
2952            // available) value, matching what downstream rules consume.
2953            // Track per-miss raws alongside so the provenance store
2954            // sees both. For cache hits, we don't have raws — the
2955            // provenance record for that row gets None for `raw` and
2956            // the cached value as `calibrated` (the only thing we
2957            // remembered). Future slice can extend the cache to carry
2958            // both; current behavior matches pre-fix correctness for
2959            // EXPLAIN of cache hits.
2960            let mut row_raw: Vec<Option<f64>> = vec![None; n_rows];
2961            for (i, &row_idx) in miss_row_indices.iter().enumerate() {
2962                probs[row_idx] = miss_calibrated[i];
2963                row_raw[row_idx] = Some(miss_raws[i]);
2964                if let Some(c) = cache {
2965                    c.insert(
2966                        &invocation.model_name,
2967                        input_hashes[row_idx],
2968                        miss_calibrated[i],
2969                    );
2970                }
2971            }
2972
2973            // Phase C B1-B3 follow-up: when a provenance store is
2974            // configured, record (raw, calibrated, confidence_band)
2975            // per row. With C-RawCalibratedSeparation (above),
2976            // `row_raw[i]` carries the *pre-calibrator* value when we
2977            // computed it on this call; `probs[i]` is the
2978            // post-calibrator value. `confidence_band` comes from the
2979            // active Calibrator's `confidence_band(p)`.
2980            if let Some(store) = provenance_store {
2981                for row_idx in 0..n_rows {
2982                    let calibrated_value = probs[row_idx];
2983                    let (raw_value, calibrated) = match (row_raw[row_idx], &calibrator) {
2984                        (Some(raw), Some(_)) => (raw, Some(calibrated_value)),
2985                        (Some(raw), None) => (raw, None),
2986                        // Cache hit: we only have the calibrated value.
2987                        // Report it as raw with `calibrated == raw` to
2988                        // preserve telemetry shape; document this in
2989                        // the field doc.
2990                        (None, _) => (
2991                            calibrated_value,
2992                            calibrator.as_ref().map(|_| calibrated_value),
2993                        ),
2994                    };
2995                    let band = calibrator
2996                        .as_ref()
2997                        .and_then(|c| c.confidence_band(calibrated_value));
2998                    store.record(
2999                        &invocation.model_name,
3000                        input_hashes[row_idx],
3001                        uni_locy::NeuralProvenanceRecord {
3002                            raw_probability: raw_value,
3003                            calibrated_probability: calibrated,
3004                            confidence_band: band,
3005                            // Phase 12 EXPLAIN follow-up: stash the
3006                            // per-binding `FeatureValue` map that fed
3007                            // the classifier so Mode B re-evaluation
3008                            // can surface graph-structural feature
3009                            // values without re-precomputing topology
3010                            // or neighbor maps (the hot-path data is
3011                            // authoritative).
3012                            feature_inputs: inputs[row_idx].features.clone(),
3013                        },
3014                    );
3015                }
3016            }
3017
3018            // Overwrite the placeholder column put there by the
3019            // compile-time YIELD rewrite. If the column doesn't yet
3020            // exist (defensive — shouldn't happen for well-formed
3021            // plans), fall back to appending.
3022            let out_col: Arc<dyn arrow_array::Array> =
3023                Arc::new(arrow_array::Float64Array::from(probs));
3024            let schema = current.schema();
3025            let target_idx = schema.index_of(&invocation.output_column).ok();
3026            let mut columns: Vec<Arc<dyn arrow_array::Array>> = current.columns().to_vec();
3027            let mut fields: Vec<Arc<arrow_schema::Field>> =
3028                schema.fields().iter().cloned().collect();
3029            match target_idx {
3030                Some(idx) => {
3031                    columns[idx] = out_col;
3032                    // Force the field's data type to Float64 in case
3033                    // the placeholder was inferred at a wider type.
3034                    fields[idx] = Arc::new(arrow_schema::Field::new(
3035                        &invocation.output_column,
3036                        arrow_schema::DataType::Float64,
3037                        true,
3038                    ));
3039                }
3040                None => {
3041                    columns.push(out_col);
3042                    fields.push(Arc::new(arrow_schema::Field::new(
3043                        &invocation.output_column,
3044                        arrow_schema::DataType::Float64,
3045                        true,
3046                    )));
3047                }
3048            }
3049            let new_schema = Arc::new(arrow_schema::Schema::new(fields));
3050            current = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
3051        }
3052        out_batches.push(current);
3053    }
3054    Ok(out_batches)
3055}
3056
/// Per-feature evaluator built once from a clause's `feature_exprs`
/// and reused for every row in the batch.
///
/// [`FeatureResolver::eval_row`] dispatches on `kind`; see
/// [`FeatureResolverKind`] for the supported shapes (plain column reads,
/// `similar_to`, path-context lookups, graph-algorithm scores and
/// neighbor aggregates). A `semantic_match` feature has no variant of its
/// own: the resolver builder lowers it to `SimilarTo` with the
/// pre-embedded query vector as a constant right-hand side.
struct FeatureResolver {
    /// Key under which the evaluated value is inserted into the per-row
    /// feature map that is handed to the classifier.
    binding_name: String,
    /// How the value is produced for a given row.
    kind: FeatureResolverKind,
}
3068
/// Strategy used by [`FeatureResolver::eval_row`] to produce one feature
/// value per row.
enum FeatureResolverKind {
    /// Plain column read: the payload is the column index in the body
    /// batch; the cell is converted with `extract_feature_value`.
    Direct(usize),
    /// `similar_to(left, right)` evaluated per row (also the lowered form
    /// of `semantic_match`). A `Float` result becomes the feature value,
    /// any other successful result becomes `Null`, and an evaluation
    /// error is surfaced as an execution error.
    SimilarTo {
        left: FeatureValueSrc,
        right: FeatureValueSrc,
    },
    /// Phase D D3 runtime: pull `column` from the source rule's
    /// derived facts via a pre-built `vid → FeatureValue` lookup. The
    /// `subject_col` is the index of `<subject_var>._vid` in the body
    /// batch (the builder falls back to the bare `<subject_var>` column
    /// when no `_vid` column exists); the lookup runs once per row.
    PathContext {
        subject_col: usize,
        vid_to_value: Arc<HashMap<u64, uni_locy::FeatureValue>>,
    },
    /// Phase D D1 graph-structural: look up the subject's pre-computed
    /// topology score (e.g. degree/pagerank/closeness). `subject_col`
    /// indexes the row's `<var>._vid` (or the bare `<var>` column when no
    /// `_vid` column exists); `vid_to_score` is the whole-graph
    /// procedure output built once per `apply_model_invocations` call.
    GraphAlgoScore {
        subject_col: usize,
        vid_to_score: Arc<HashMap<u64, f64>>,
    },
    /// Phase D D1 graph-structural: aggregate a numeric property over
    /// each subject's one-hop neighborhood along a named edge type
    /// (outgoing by default; the builder also accepts an optional
    /// direction argument). `vid_to_values` maps subject vid → the list
    /// of numeric neighbor property values collected at precompute time;
    /// the per-row resolver applies the configured `op` (avg/max/sum).
    NeighborAggregate {
        subject_col: usize,
        op: NeighborAgg,
        vid_to_values: Arc<HashMap<u64, Vec<f64>>>,
    },
}
3102
/// Aggregation applied to the numeric neighbor property values collected
/// for a subject. One variant per supported feature function:
/// `avg_neighbor`, `max_neighbor` and `sum_neighbor`.
#[derive(Debug, Clone, Copy)]
enum NeighborAgg {
    /// Arithmetic mean of the values.
    Avg,
    /// Largest of the values.
    Max,
    /// Sum of the values.
    Sum,
}
3109
3110/// Direction for one-hop neighborhood traversal. Mirrors
3111/// `uni_store::storage::direction::Direction` but is independent so
3112/// the typecheck / planner layer doesn't depend on uni-store.
3113#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
3114enum NeighborDirection {
3115    Outgoing,
3116    Incoming,
3117    Both,
3118}
3119
3120impl NeighborDirection {
3121    fn store_directions(self) -> &'static [uni_store::storage::direction::Direction] {
3122        use uni_store::storage::direction::Direction;
3123        match self {
3124            NeighborDirection::Outgoing => &[Direction::Outgoing],
3125            NeighborDirection::Incoming => &[Direction::Incoming],
3126            NeighborDirection::Both => &[Direction::Outgoing, Direction::Incoming],
3127        }
3128    }
3129}
3130
3131impl NeighborAgg {
3132    fn from_fn_name(name: &str) -> Option<Self> {
3133        match name {
3134            "avg_neighbor" => Some(NeighborAgg::Avg),
3135            "max_neighbor" => Some(NeighborAgg::Max),
3136            "sum_neighbor" => Some(NeighborAgg::Sum),
3137            _ => None,
3138        }
3139    }
3140
3141    fn apply(self, values: &[f64]) -> Option<f64> {
3142        if values.is_empty() {
3143            return None;
3144        }
3145        match self {
3146            NeighborAgg::Avg => Some(values.iter().sum::<f64>() / values.len() as f64),
3147            NeighborAgg::Max => values.iter().copied().reduce(f64::max),
3148            NeighborAgg::Sum => Some(values.iter().sum()),
3149        }
3150    }
3151}
3152
3153/// One side of a `similar_to` feature: either a column index in the
3154/// per-row batch or a constant value lifted from a literal expression.
3155enum FeatureValueSrc {
3156    Col(usize),
3157    Const(uni_common::Value),
3158}
3159
3160impl FeatureValueSrc {
3161    fn resolve(&self, batch: &RecordBatch, row_idx: usize) -> uni_common::Value {
3162        match self {
3163            FeatureValueSrc::Col(idx) => extract_common_value(batch.column(*idx).as_ref(), row_idx),
3164            FeatureValueSrc::Const(v) => v.clone(),
3165        }
3166    }
3167}
3168
3169impl FeatureResolver {
3170    fn eval_row(&self, batch: &RecordBatch, row_idx: usize) -> DFResult<uni_locy::FeatureValue> {
3171        match &self.kind {
3172            FeatureResolverKind::Direct(idx) => {
3173                Ok(extract_feature_value(batch.column(*idx).as_ref(), row_idx))
3174            }
3175            FeatureResolverKind::SimilarTo { left, right } => {
3176                let lv = left.resolve(batch, row_idx);
3177                let rv = right.resolve(batch, row_idx);
3178                match crate::query::similar_to::eval_similar_to_pure(&lv, &rv) {
3179                    Ok(uni_common::Value::Float(f)) => Ok(uni_locy::FeatureValue::Float(f)),
3180                    Ok(_) => Ok(uni_locy::FeatureValue::Null),
3181                    Err(e) => Err(datafusion::error::DataFusionError::Execution(format!(
3182                        "similar_to UDF failed: {e}"
3183                    ))),
3184                }
3185            }
3186            FeatureResolverKind::PathContext {
3187                subject_col,
3188                vid_to_value,
3189            } => {
3190                let col = batch.column(*subject_col);
3191                if col.is_null(row_idx) {
3192                    return Ok(uni_locy::FeatureValue::Null);
3193                }
3194                if let Some(arr) = col.as_any().downcast_ref::<arrow_array::UInt64Array>() {
3195                    let vid = arr.value(row_idx);
3196                    Ok(vid_to_value
3197                        .get(&vid)
3198                        .cloned()
3199                        .unwrap_or(uni_locy::FeatureValue::Null))
3200                } else if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3201                    let vid = arr.value(row_idx) as u64;
3202                    Ok(vid_to_value
3203                        .get(&vid)
3204                        .cloned()
3205                        .unwrap_or(uni_locy::FeatureValue::Null))
3206                } else {
3207                    Ok(uni_locy::FeatureValue::Null)
3208                }
3209            }
3210            FeatureResolverKind::GraphAlgoScore {
3211                subject_col,
3212                vid_to_score,
3213            } => {
3214                let col = batch.column(*subject_col);
3215                if col.is_null(row_idx) {
3216                    return Ok(uni_locy::FeatureValue::Null);
3217                }
3218                let vid_opt: Option<u64> = if let Some(arr) =
3219                    col.as_any().downcast_ref::<arrow_array::UInt64Array>()
3220                {
3221                    Some(arr.value(row_idx))
3222                } else if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3223                    Some(arr.value(row_idx) as u64)
3224                } else {
3225                    // Fallback: subject column carries a Node-encoded
3226                    // `uni_common::Value` (LargeBinary via codec). Decode
3227                    // and pull the VID. This is the common case for
3228                    // bare-variable subjects where no `_vid` hidden
3229                    // column was materialized.
3230                    match extract_common_value(col.as_ref(), row_idx) {
3231                        uni_common::Value::Node(n) => Some(n.vid.as_u64()),
3232                        uni_common::Value::Int(i) => Some(i as u64),
3233                        _ => None,
3234                    }
3235                };
3236                Ok(vid_opt
3237                    .and_then(|v| vid_to_score.get(&v).copied())
3238                    .map(uni_locy::FeatureValue::Float)
3239                    .unwrap_or(uni_locy::FeatureValue::Null))
3240            }
3241            FeatureResolverKind::NeighborAggregate {
3242                subject_col,
3243                op,
3244                vid_to_values,
3245            } => {
3246                let vid_opt = extract_vid_from_column(batch.column(*subject_col).as_ref(), row_idx);
3247                Ok(vid_opt
3248                    .and_then(|v| vid_to_values.get(&v))
3249                    .and_then(|values| op.apply(values))
3250                    .map(uni_locy::FeatureValue::Float)
3251                    .unwrap_or(uni_locy::FeatureValue::Null))
3252            }
3253        }
3254    }
3255}
3256
3257/// Extract a node VID from a per-row batch column. Handles the three
3258/// common shapes: `_vid` UInt64 columns, Int64 columns, and Node-encoded
3259/// LargeBinary columns (the standard `uni_common::Value::Node` codec
3260/// representation for a bare variable column).
3261fn extract_vid_from_column(col: &dyn arrow_array::Array, row_idx: usize) -> Option<u64> {
3262    if col.is_null(row_idx) {
3263        return None;
3264    }
3265    if let Some(arr) = col.as_any().downcast_ref::<arrow_array::UInt64Array>() {
3266        return Some(arr.value(row_idx));
3267    }
3268    if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3269        return Some(arr.value(row_idx) as u64);
3270    }
3271    match extract_common_value(col, row_idx) {
3272        uni_common::Value::Node(n) => Some(n.vid.as_u64()),
3273        uni_common::Value::Int(i) => Some(i as u64),
3274        _ => None,
3275    }
3276}
3277
3278#[allow(clippy::too_many_arguments)]
3279fn build_feature_resolvers(
3280    batch: &RecordBatch,
3281    invocation: &uni_locy::ModelInvocation,
3282    path_context_handles: &HashMap<
3283        String,
3284        crate::query::df_graph::locy_model_invoke::PathContextHandle,
3285    >,
3286    semantic_match_embeddings: &HashMap<String, Vec<f32>>,
3287    graph_feature_maps: &HashMap<String, Arc<HashMap<u64, f64>>>,
3288    neighbor_feature_maps: &NeighborFeatureMaps,
3289) -> DFResult<Vec<FeatureResolver>> {
3290    use uni_cypher::ast::Expr;
3291    let schema = batch.schema();
3292    let lookup_col = |name_or_property: String| -> DFResult<usize> {
3293        schema.index_of(&name_or_property).map_err(|_| {
3294            datafusion::error::DataFusionError::Execution(format!(
3295                "feature column '{name_or_property}' not found in clause body output schema"
3296            ))
3297        })
3298    };
3299    // Resolve a feature sub-expression to a per-row value source. Variables
3300    // and property accesses map to batch columns; list/scalar literals
3301    // become inline constants — required so `similar_to(s.embedding, [1,0,0])`
3302    // works without a hidden column for the literal vector.
3303    let resolve_src = |expr: &Expr| -> DFResult<FeatureValueSrc> {
3304        match expr {
3305            Expr::Variable(name) => {
3306                let col = if schema.index_of(name).is_ok() {
3307                    name.clone()
3308                } else {
3309                    let vid_name = format!("{}._vid", name);
3310                    if schema.index_of(&vid_name).is_ok() {
3311                        vid_name
3312                    } else {
3313                        name.clone()
3314                    }
3315                };
3316                Ok(FeatureValueSrc::Col(lookup_col(col)?))
3317            }
3318            Expr::Property(boxed, prop) if matches!(boxed.as_ref(), Expr::Variable(_)) => {
3319                let Expr::Variable(v) = boxed.as_ref() else {
3320                    unreachable!()
3321                };
3322                let direct = format!("{}.{}", v, prop);
3323                let col = if schema.index_of(&direct).is_ok() {
3324                    direct
3325                } else {
3326                    format!("__feat_{}_{}", v, prop)
3327                };
3328                Ok(FeatureValueSrc::Col(lookup_col(col)?))
3329            }
3330            Expr::Literal(lit) => Ok(FeatureValueSrc::Const(lit.to_value())),
3331            Expr::List(items) => {
3332                let mut out = Vec::with_capacity(items.len());
3333                for it in items {
3334                    out.push(match it {
3335                        Expr::Literal(lit) => lit.to_value(),
3336                        _ => uni_common::Value::Null,
3337                    });
3338                }
3339                Ok(FeatureValueSrc::Const(uni_common::Value::List(out)))
3340            }
3341            other => Err(datafusion::error::DataFusionError::Execution(format!(
3342                "unsupported feature sub-expression: {other:?}"
3343            ))),
3344        }
3345    };
3346
3347    // Phase D D3 runtime: when the model declares a path-context
3348    // feature, build a `vid → FeatureValue` lookup once from the
3349    // source rule's converged facts and wrap it in an Arc so the
3350    // per-row resolver does a single hash lookup. The model's
3351    // `INPUT` bindings are unused under this form for MVP — the
3352    // resolver's binding name is the column name (matches how the
3353    // mock-classifier feature-driver pattern in TCK consumes it).
3354    if let Some(pc) = &invocation.path_context {
3355        let handle = path_context_handles.get(&pc.source_rule).ok_or_else(|| {
3356            datafusion::error::DataFusionError::Execution(format!(
3357                "model '{}' path_context references rule '{}' but no DerivedScanHandle \
3358                 was registered; this should never happen — the build_clause path \
3359                 mints a handle for every distinct source_rule in the invocation set",
3360                invocation.model_name, pc.source_rule
3361            ))
3362        })?;
3363        let subject_col = schema
3364            .index_of(&format!("{}._vid", pc.subject_var))
3365            .or_else(|_| schema.index_of(&pc.subject_var))
3366            .map_err(|_| {
3367                datafusion::error::DataFusionError::Execution(format!(
3368                    "model '{}' path_context: subject column '{}' (or '{0}._vid') not \
3369                     in body batch schema",
3370                    invocation.model_name, pc.subject_var
3371                ))
3372            })?;
3373        let vid_to_value =
3374            build_path_context_lookup(handle, &pc.subject_var, &pc.column, &invocation.model_name)?;
3375        return Ok(vec![FeatureResolver {
3376            binding_name: pc.column.clone(),
3377            kind: FeatureResolverKind::PathContext {
3378                subject_col,
3379                vid_to_value: Arc::new(vid_to_value),
3380            },
3381        }]);
3382    }
3383
3384    let mut out = Vec::with_capacity(invocation.feature_exprs.len());
3385    for (i, fexpr) in invocation.feature_exprs.iter().enumerate() {
3386        let binding_name = invocation.feature_names[i].clone();
3387        let kind = match fexpr {
3388            Expr::FunctionCall { name, args, .. } if name == "similar_to" => {
3389                if args.len() != 2 {
3390                    return Err(datafusion::error::DataFusionError::Execution(format!(
3391                        "similar_to expects 2 args, got {}",
3392                        args.len()
3393                    )));
3394                }
3395                FeatureResolverKind::SimilarTo {
3396                    left: resolve_src(&args[0])?,
3397                    right: resolve_src(&args[1])?,
3398                }
3399            }
3400            Expr::FunctionCall { name, args, .. } if name == "semantic_match" => {
3401                // Phase D D2: lower `semantic_match(prop, 'text')` to a
3402                // `SimilarTo` resolver with the pre-embedded query
3403                // vector as the right side. The literal text was embedded
3404                // once via the Xervo runtime in `pre_embed_semantic_match_queries`.
3405                if args.len() != 2 {
3406                    return Err(datafusion::error::DataFusionError::Execution(format!(
3407                        "semantic_match expects 2 args, got {}",
3408                        args.len()
3409                    )));
3410                }
3411                let text = match &args[1] {
3412                    Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3413                    other => {
3414                        return Err(datafusion::error::DataFusionError::Execution(format!(
3415                            "semantic_match: 2nd arg must be a string literal, got {other:?}"
3416                        )));
3417                    }
3418                };
3419                let embedded = semantic_match_embeddings.get(&text).ok_or_else(|| {
3420                    datafusion::error::DataFusionError::Execution(format!(
3421                        "semantic_match: query text '{text}' was not pre-embedded. \
3422                         This is a bug — `apply_model_invocations` should have \
3423                         embedded all unique semantic_match texts up front. Most \
3424                         likely the Xervo runtime is not configured (configure \
3425                         via `LocyConfig::xervo_runtime` or its equivalent)."
3426                    ))
3427                })?;
3428                let right_vec: Vec<f32> = embedded.clone();
3429                FeatureResolverKind::SimilarTo {
3430                    left: resolve_src(&args[0])?,
3431                    right: FeatureValueSrc::Const(uni_common::Value::Vector(right_vec)),
3432                }
3433            }
3434            Expr::FunctionCall { name, args, .. }
3435                if matches!(
3436                    name.as_str(),
3437                    "degree_centrality"
3438                        | "pagerank_score"
3439                        | "closeness_centrality"
3440                        | "betweenness_centrality"
3441                        | "eigenvector_centrality"
3442                        | "harmonic_centrality"
3443                        | "katz_centrality"
3444                ) =>
3445            {
3446                if args.len() != 1 {
3447                    return Err(datafusion::error::DataFusionError::Execution(format!(
3448                        "{name} expects 1 arg, got {}",
3449                        args.len()
3450                    )));
3451                }
3452                let Expr::Variable(v) = &args[0] else {
3453                    return Err(datafusion::error::DataFusionError::Execution(format!(
3454                        "{name}(...) argument must be a node variable, got {:?}",
3455                        args[0]
3456                    )));
3457                };
3458                let subject_col = {
3459                    let direct = schema.index_of(v).ok();
3460                    let vid_name = format!("{}._vid", v);
3461                    let vid_col = schema.index_of(&vid_name).ok();
3462                    vid_col.or(direct).ok_or_else(|| {
3463                        datafusion::error::DataFusionError::Execution(format!(
3464                            "{name}: subject column '{v}' (or '{v}._vid') not in body batch schema"
3465                        ))
3466                    })?
3467                };
3468                let vid_to_score = graph_feature_maps.get(name).cloned().ok_or_else(|| {
3469                    datafusion::error::DataFusionError::Execution(format!(
3470                        "{name}: pre-computed score map missing. This is a bug — \
3471                         `apply_model_invocations` should have called \
3472                         `precompute_graph_feature_maps` for every graph-structural \
3473                         feature before building resolvers. Most likely the graph \
3474                         algorithm registry is not configured."
3475                    ))
3476                })?;
3477                FeatureResolverKind::GraphAlgoScore {
3478                    subject_col,
3479                    vid_to_score,
3480                }
3481            }
3482            Expr::FunctionCall { name, args, .. }
3483                if matches!(
3484                    name.as_str(),
3485                    "avg_neighbor" | "max_neighbor" | "sum_neighbor"
3486                ) =>
3487            {
3488                if args.len() != 3 && args.len() != 4 {
3489                    return Err(datafusion::error::DataFusionError::Execution(format!(
3490                        "{name} expects 3 or 4 args, got {}",
3491                        args.len()
3492                    )));
3493                }
3494                let Expr::Variable(v) = &args[0] else {
3495                    return Err(datafusion::error::DataFusionError::Execution(format!(
3496                        "{name}(...) first argument must be a node variable, got {:?}",
3497                        args[0]
3498                    )));
3499                };
3500                let rel_type = match &args[1] {
3501                    Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3502                    other => {
3503                        return Err(datafusion::error::DataFusionError::Execution(format!(
3504                            "{name}: 2nd arg must be a string literal (rel-type), got {other:?}"
3505                        )));
3506                    }
3507                };
3508                let prop_name = match &args[2] {
3509                    Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3510                    other => {
3511                        return Err(datafusion::error::DataFusionError::Execution(format!(
3512                            "{name}: 3rd arg must be a string literal (property), got {other:?}"
3513                        )));
3514                    }
3515                };
3516                let direction_arg = match args.get(3) {
3517                    None => NeighborDirection::Outgoing,
3518                    Some(Expr::Literal(uni_cypher::ast::CypherLiteral::String(d))) => {
3519                        match d.to_uppercase().as_str() {
3520                            "OUTGOING" => NeighborDirection::Outgoing,
3521                            "INCOMING" => NeighborDirection::Incoming,
3522                            "BOTH" => NeighborDirection::Both,
3523                            other => {
3524                                return Err(datafusion::error::DataFusionError::Execution(
3525                                    format!(
3526                                        "{name}: direction must be OUTGOING|INCOMING|BOTH, got '{other}'"
3527                                    ),
3528                                ));
3529                            }
3530                        }
3531                    }
3532                    Some(other) => {
3533                        return Err(datafusion::error::DataFusionError::Execution(format!(
3534                            "{name}: 4th arg must be a string literal (direction), got {other:?}"
3535                        )));
3536                    }
3537                };
3538                let subject_col = {
3539                    let direct = schema.index_of(v).ok();
3540                    let vid_name = format!("{}._vid", v);
3541                    let vid_col = schema.index_of(&vid_name).ok();
3542                    vid_col.or(direct).ok_or_else(|| {
3543                        datafusion::error::DataFusionError::Execution(format!(
3544                            "{name}: subject column '{v}' (or '{v}._vid') not in body batch schema"
3545                        ))
3546                    })?
3547                };
3548                let vid_to_values = neighbor_feature_maps
3549                    .get(&(rel_type.clone(), prop_name.clone(), direction_arg))
3550                    .cloned()
3551                    .ok_or_else(|| {
3552                        datafusion::error::DataFusionError::Execution(format!(
3553                            "{name}: pre-computed neighbor map missing for ({rel_type}, {prop_name}, {direction_arg:?}). \
3554                             This is a bug — `apply_model_invocations` should have called \
3555                             `precompute_neighbor_feature_maps` for every neighbor-aggregator \
3556                             feature before building resolvers."
3557                        ))
3558                    })?;
3559                let op = NeighborAgg::from_fn_name(name).unwrap();
3560                FeatureResolverKind::NeighborAggregate {
3561                    subject_col,
3562                    op,
3563                    vid_to_values,
3564                }
3565            }
3566            other => match resolve_src(other)? {
3567                FeatureValueSrc::Col(idx) => FeatureResolverKind::Direct(idx),
3568                FeatureValueSrc::Const(_) => {
3569                    return Err(datafusion::error::DataFusionError::Execution(format!(
3570                        "model '{}' feature must reference a variable or property — got a literal",
3571                        invocation.model_name
3572                    )));
3573                }
3574            },
3575        };
3576        out.push(FeatureResolver { binding_name, kind });
3577    }
3578    Ok(out)
3579}
3580
3581/// Phase D D2: scan invocations' feature expressions for
3582/// `semantic_match(prop, 'text')` calls and embed each distinct
3583/// query string once via the Xervo runtime. Returns a
3584/// `text → Vec<f32>` map consumed at resolver-build time. Errors
3585/// cleanly when `semantic_match` is used without a configured
3586/// Xervo runtime.
3587async fn pre_embed_semantic_match_queries(
3588    invocations: &[uni_locy::ModelInvocation],
3589    xervo_runtime: &crate::query::df_graph::locy_model_invoke::XervoRuntimeHandle,
3590) -> DFResult<HashMap<String, Vec<f32>>> {
3591    use uni_cypher::ast::{CypherLiteral, Expr};
3592    // Collect (text, embedder_alias) pairs. The alias is per-model
3593    // (Phase D D2 follow-up): each invocation's `embedder_alias` overrides
3594    // the runtime-wide `"default"`. Two invocations sharing the same
3595    // (text, alias) reuse one embed call; same text under different
3596    // aliases is embedded twice. The cache key remains plain `text` —
3597    // a model's resolver fetches embeddings via the text it knows, and
3598    // mixed-alias re-embed of identical text under a different alias is
3599    // a rare-enough edge case that the "last writer wins" cache shape
3600    // is acceptable (documented).
3601    let mut needed: Vec<(String, String)> = Vec::new();
3602    for inv in invocations {
3603        let alias = inv
3604            .embedder_alias
3605            .clone()
3606            .unwrap_or_else(|| "default".to_string());
3607        for fexpr in &inv.feature_exprs {
3608            if let Expr::FunctionCall { name, args, .. } = fexpr
3609                && name == "semantic_match"
3610                && args.len() == 2
3611                && let Expr::Literal(CypherLiteral::String(s)) = &args[1]
3612            {
3613                let tuple = (s.clone(), alias.clone());
3614                if !needed.contains(&tuple) {
3615                    needed.push(tuple);
3616                }
3617            }
3618        }
3619    }
3620    if needed.is_empty() {
3621        return Ok(HashMap::new());
3622    }
3623    let runtime = xervo_runtime.as_ref().ok_or_else(|| {
3624        datafusion::error::DataFusionError::Execution(
3625            "semantic_match: Uni-Xervo runtime not configured. Either provide \
3626             one via `LocyConfig::xervo_runtime` (or its equivalent setup \
3627             path) or pre-compute the query embedding and pass it via \
3628             `similar_to(prop, <literal_vector>)`."
3629                .to_string(),
3630        )
3631    })?;
3632    // Group needed (text, alias) by alias so each embedder is consulted
3633    // exactly once.
3634    let mut by_alias: HashMap<String, Vec<String>> = HashMap::new();
3635    for (text, alias) in &needed {
3636        by_alias
3637            .entry(alias.clone())
3638            .or_default()
3639            .push(text.clone());
3640    }
3641    let mut out: HashMap<String, Vec<f32>> = HashMap::new();
3642    for (alias, texts) in by_alias {
3643        let embedder = runtime.embedding(&alias).await.map_err(|e| {
3644            datafusion::error::DataFusionError::Execution(format!(
3645                "semantic_match: failed to obtain embedder for alias '{alias}': {e}. \
3646                 Register an embedder under that alias in your Uni-Xervo runtime, or \
3647                 pre-compute the query embedding and pass via similar_to."
3648            ))
3649        })?;
3650        let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
3651        let embeddings = embedder.embed(text_refs).await.map_err(|e| {
3652            datafusion::error::DataFusionError::Execution(format!(
3653                "semantic_match: embedder '{alias}' call failed: {e}"
3654            ))
3655        })?;
3656        if embeddings.len() != texts.len() {
3657            return Err(datafusion::error::DataFusionError::Execution(format!(
3658                "semantic_match: embedder '{alias}' returned {} vectors for {} queries",
3659                embeddings.len(),
3660                texts.len()
3661            )));
3662        }
3663        for (text, vec) in texts.into_iter().zip(embeddings) {
3664            out.insert(text, vec);
3665        }
3666    }
3667    Ok(out)
3668}
3669
3670/// Phase D D1 graph-structural: scan invocations' feature expressions
3671/// for `degree_centrality(n)` / `pagerank_score(n)` / `closeness_centrality(n)`
3672/// calls and invoke the corresponding `uni.algo.*` procedure on the
3673/// configured `AlgorithmRegistry` once per distinct call. Returns a
3674/// `fn_name → Arc<HashMap<vid, score>>` map consumed at resolver-build
3675/// time. Errors cleanly when a graph-structural FEATURE is used
3676/// without a configured registry or storage handle.
3677///
3678/// Pre-computation is `O(graph)` per call. Across fixpoint iterations
3679/// the graph state can change, so the cache lives for the lifetime of
3680/// one `apply_model_invocations` call only — same lifetime as the
3681/// D2 query-embedding cache (`pre_embed_semantic_match_queries`).
3682async fn precompute_graph_feature_maps(
3683    invocations: &[uni_locy::ModelInvocation],
3684    graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
3685) -> DFResult<HashMap<String, Arc<HashMap<u64, f64>>>> {
3686    use futures::StreamExt;
3687    use uni_algo::algo::procedures::AlgoContext;
3688    use uni_cypher::ast::Expr;
3689
3690    // Map our user-facing FEATURE function names to the canonical
3691    // `uni.algo.*` procedure names registered in `AlgorithmRegistry`.
3692    fn procedure_for(fn_name: &str) -> Option<&'static str> {
3693        match fn_name {
3694            "degree_centrality" => Some("uni.algo.degreeCentrality"),
3695            "pagerank_score" => Some("uni.algo.pageRank"),
3696            "closeness_centrality" => Some("uni.algo.closeness"),
3697            "betweenness_centrality" => Some("uni.algo.betweenness"),
3698            "eigenvector_centrality" => Some("uni.algo.eigenvectorCentrality"),
3699            "harmonic_centrality" => Some("uni.algo.harmonicCentrality"),
3700            "katz_centrality" => Some("uni.algo.katzCentrality"),
3701            _ => None,
3702        }
3703    }
3704
3705    // Collect the set of distinct topology-FEATURE names referenced
3706    // across all invocations. Args are always a single Variable, so
3707    // the precomputation key is just the function name.
3708    let mut needed: Vec<String> = Vec::new();
3709    for inv in invocations {
3710        for fexpr in &inv.feature_exprs {
3711            if let Expr::FunctionCall { name, .. } = fexpr
3712                && procedure_for(name).is_some()
3713                && !needed.contains(name)
3714            {
3715                needed.push(name.clone());
3716            }
3717        }
3718    }
3719    if needed.is_empty() {
3720        return Ok(HashMap::new());
3721    }
3722
3723    let registry = graph_algo.registry.as_ref().ok_or_else(|| {
3724        datafusion::error::DataFusionError::Execution(
3725            "graph-structural FEATURE invoked but no `AlgorithmRegistry` is \
3726             configured. Configure one on `GraphExecutionContext::with_algo_registry`."
3727                .to_string(),
3728        )
3729    })?;
3730    let storage = graph_algo.storage.as_ref().ok_or_else(|| {
3731        datafusion::error::DataFusionError::Execution(
3732            "graph-structural FEATURE invoked but no storage handle was \
3733             threaded into the FEATURE runtime. This is a bug in df_planner."
3734                .to_string(),
3735        )
3736    })?;
3737
3738    let mut out: HashMap<String, Arc<HashMap<u64, f64>>> = HashMap::new();
3739    for fn_name in needed {
3740        let proc_name = procedure_for(&fn_name).unwrap();
3741        let procedure = registry.get(proc_name).ok_or_else(|| {
3742            datafusion::error::DataFusionError::Execution(format!(
3743                "graph-structural FEATURE '{fn_name}' resolves to procedure \
3744                 '{proc_name}' which is not in the algorithm registry"
3745            ))
3746        })?;
3747        // Topology procedures take (nodeLabels[], relationshipTypes[],
3748        // [direction], [...]) — pass empty arrays for nodeLabels and
3749        // relationshipTypes to mean "all". The procedure fills the
3750        // remaining optional args from its signature defaults.
3751        let args: Vec<serde_json::Value> = vec![
3752            serde_json::Value::Array(Vec::new()),
3753            serde_json::Value::Array(Vec::new()),
3754        ];
3755        let algo_ctx = AlgoContext::new(
3756            storage.clone(),
3757            graph_algo.l0_manager.as_ref().map(Arc::clone),
3758        );
3759        // The AlgoProcedure trait routes direct (nodeLabels, edgeTypes)
3760        // args through the V2 projection entry point: build a projection
3761        // from the direct args, then execute against it.
3762        //
3763        // Fill optional algorithm-specific args (e.g. degree_centrality's
3764        // `direction`, eigenvector/katz `weightProperty`) with their schema
3765        // defaults for the projection build: `build_projection_from_direct_args`
3766        // feeds the specific args (`args[2..]`) to the adapter's
3767        // `customize_projection`, which indexes them positionally — the two
3768        // empty placeholder arrays alone would leave that slice empty and
3769        // panic. `validate_args` fills missing optionals WITHOUT type-checking
3770        // the defaults (some are `Null`-typed sentinels), so this never errors
3771        // for the placeholder shape.
3772        //
3773        // We pass the ORIGINAL `args` (not the filled ones) to
3774        // `execute_with_projection`, which re-runs `validate_args` internally:
3775        // re-feeding already-filled args would make those defaults look
3776        // "provided" and trip the type-check (`weightProperty: Null` vs
3777        // `String`). This mirrors the (now-removed) legacy
3778        // `AlgoProcedure::execute`, which validated once then built + ran.
3779        let filled_args = procedure
3780            .signature()
3781            .validate_args(args.clone())
3782            .map_err(|e| {
3783                datafusion::error::DataFusionError::Execution(format!(
3784                    "graph-structural FEATURE '{fn_name}': argument validation failed: {e}"
3785                ))
3786            })?;
3787        let projection = uni_algo::algo::procedure_template::build_projection_from_direct_args(
3788            procedure.as_ref(),
3789            &algo_ctx,
3790            &filled_args,
3791        )
3792        .await
3793        .map_err(|e| {
3794            datafusion::error::DataFusionError::Execution(format!(
3795                "graph-structural FEATURE '{fn_name}': projection build failed: {e}"
3796            ))
3797        })?;
3798        let mut stream = procedure.execute_with_projection(algo_ctx, args, projection);
3799        let mut score_map: HashMap<u64, f64> = HashMap::new();
3800        let sig = procedure.signature();
3801        let node_idx = sig
3802            .yields
3803            .iter()
3804            .position(|(n, _)| *n == "nodeId")
3805            .ok_or_else(|| {
3806                datafusion::error::DataFusionError::Execution(format!(
3807                    "procedure '{proc_name}' yield schema missing 'nodeId'"
3808                ))
3809            })?;
3810        // Most `uni.algo.*` centrality procedures yield `score`; the
3811        // `harmonicCentrality` family yields `centrality` instead. Accept
3812        // either to keep this dispatch independent of procedure-internal
3813        // naming choices.
3814        let score_idx = sig
3815            .yields
3816            .iter()
3817            .position(|(n, _)| *n == "score" || *n == "centrality")
3818            .ok_or_else(|| {
3819                datafusion::error::DataFusionError::Execution(format!(
3820                    "procedure '{proc_name}' yield schema missing a numeric score column \
3821                     (expected 'score' or 'centrality')"
3822                ))
3823            })?;
3824        while let Some(row_res) = stream.next().await {
3825            let row = row_res.map_err(|e| {
3826                datafusion::error::DataFusionError::Execution(format!(
3827                    "graph-structural FEATURE '{fn_name}': procedure '{proc_name}' failed: {e}"
3828                ))
3829            })?;
3830            let vid_v = row.values.get(node_idx);
3831            let score_v = row.values.get(score_idx);
3832            let (Some(vid_v), Some(score_v)) = (vid_v, score_v) else {
3833                continue;
3834            };
3835            let vid = vid_v.as_u64().or_else(|| vid_v.as_i64().map(|i| i as u64));
3836            let score = score_v
3837                .as_f64()
3838                .or_else(|| score_v.as_i64().map(|i| i as f64));
3839            if let (Some(vid), Some(score)) = (vid, score) {
3840                score_map.insert(vid, score);
3841            }
3842        }
3843        out.insert(fn_name, Arc::new(score_map));
3844    }
3845    Ok(out)
3846}
3847
3848/// Phase D D1 graph-structural: one-hop neighborhood aggregator
3849/// precompute. Scans invocations' feature expressions for
3850/// `avg_neighbor` / `max_neighbor` / `sum_neighbor` FunctionCalls,
3851/// collects the distinct `(rel_type, prop_name)` pairs they need,
3852/// resolves each rel-type to a schema edge-type id, warms the
3853/// outgoing-adjacency CSR, and for every subject vid present in the
3854/// body batches walks the one-hop neighborhood and fetches the
3855/// requested property from each neighbor via `PropertyManager`.
3856/// Non-numeric neighbor property values are filtered out via
3857/// `Value::as_f64`.
3858///
3859/// Returns `Arc<HashMap<u64, Vec<f64>>>` keyed by `(rel_type, prop_name)`.
3860/// The resolver's runtime cost per row is then a single hash lookup
3861/// plus an `avg`/`max`/`sum` over the cached `Vec<f64>`.
3862///
3863/// Scope: **subject-set-only** — we only collect for vids that appear
3864/// in the body batches' subject columns (avoids pre-walking the entire
3865/// graph). Subjects with no outgoing edges of the named type land in
3866/// the map with an empty `Vec` so the resolver's `Null` semantics
3867/// remain crisp (empty → `Null` → classifier interprets per its
3868/// feature contract).
3869/// Per-`(rel_type, prop_name, direction)` cache of neighbor property
3870/// values keyed by subject vid, produced by
3871/// `precompute_neighbor_feature_maps` and consumed by
3872/// `FeatureResolverKind::NeighborAggregate` resolvers.
3873type NeighborFeatureMaps =
3874    HashMap<(String, String, NeighborDirection), Arc<HashMap<u64, Vec<f64>>>>;
3875
3876async fn precompute_neighbor_feature_maps(
3877    invocations: &[uni_locy::ModelInvocation],
3878    batches: &[RecordBatch],
3879    graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
3880) -> DFResult<NeighborFeatureMaps> {
3881    use uni_cypher::ast::{CypherLiteral, Expr};
3882
3883    // Collect distinct (subject_var, rel_type, prop_name, direction)
3884    // tuples needed across all invocations. The subject_var tells us
3885    // which body batch column to scan for subject vids; the direction
3886    // is optional in the AST (defaults to OUTGOING).
3887    let parse_direction = |arg: Option<&Expr>| -> Option<NeighborDirection> {
3888        match arg {
3889            None => Some(NeighborDirection::Outgoing),
3890            Some(Expr::Literal(CypherLiteral::String(d))) => match d.to_uppercase().as_str() {
3891                "OUTGOING" => Some(NeighborDirection::Outgoing),
3892                "INCOMING" => Some(NeighborDirection::Incoming),
3893                "BOTH" => Some(NeighborDirection::Both),
3894                _ => None,
3895            },
3896            _ => None,
3897        }
3898    };
3899    let mut needed: Vec<(String, String, String, NeighborDirection)> = Vec::new();
3900    for inv in invocations {
3901        for fexpr in &inv.feature_exprs {
3902            if let Expr::FunctionCall { name, args, .. } = fexpr
3903                && NeighborAgg::from_fn_name(name).is_some()
3904                && (args.len() == 3 || args.len() == 4)
3905                && let Expr::Variable(v) = &args[0]
3906                && let Expr::Literal(CypherLiteral::String(rel)) = &args[1]
3907                && let Expr::Literal(CypherLiteral::String(prop)) = &args[2]
3908                && let Some(direction) = parse_direction(args.get(3))
3909            {
3910                let tuple = (v.clone(), rel.clone(), prop.clone(), direction);
3911                if !needed.contains(&tuple) {
3912                    needed.push(tuple);
3913                }
3914            }
3915        }
3916    }
3917    if needed.is_empty() {
3918        return Ok(HashMap::new());
3919    }
3920
3921    let storage = graph_algo.storage.as_ref().ok_or_else(|| {
3922        datafusion::error::DataFusionError::Execution(
3923            "neighbor-aggregator FEATURE invoked but no storage handle was \
3924             threaded into the FEATURE runtime. This is a bug in df_planner."
3925                .to_string(),
3926        )
3927    })?;
3928    let property_manager = graph_algo.property_manager.as_ref().ok_or_else(|| {
3929        datafusion::error::DataFusionError::Execution(
3930            "neighbor-aggregator FEATURE invoked but no PropertyManager was \
3931             threaded into the FEATURE runtime. This is a bug in df_planner."
3932                .to_string(),
3933        )
3934    })?;
3935    // Build a QueryContext snapshot so L0-resident vertex properties
3936    // are visible to `get_vertex_prop_with_ctx`. Without a ctx, L0
3937    // property data is silently invisible (returns Null), which is
3938    // why the topology trio's `AlgoContext` consumes L0 via
3939    // `L0Manager` whereas property reads need this separate path.
3940    let query_ctx = graph_algo.l0_buffers.as_ref().map(|bufs| {
3941        uni_store::runtime::context::QueryContext::new_with_pending(
3942            bufs.current.clone(),
3943            bufs.transaction.clone(),
3944            bufs.pending_flush.clone(),
3945        )
3946    });
3947
3948    // Group needed tuples by (rel_type, prop_name, direction) — one
3949    // precomputed map per key, regardless of which subject_var binding
3950    // points at it (the subject vids are unioned).
3951    let mut by_key: HashMap<(String, String, NeighborDirection), Vec<String>> = HashMap::new();
3952    for (subject_var, rel, prop, direction) in needed {
3953        by_key
3954            .entry((rel, prop, direction))
3955            .or_default()
3956            .push(subject_var);
3957    }
3958
3959    let mut out: NeighborFeatureMaps = HashMap::new();
3960    for ((rel_type, prop_name, direction), subject_vars) in by_key {
3961        // Resolve edge_type_id from schema.
3962        let schema = storage.schema_manager().schema();
3963        let Some(edge_meta) = schema.edge_types.get(&rel_type) else {
3964            // Unregistered rel-type → empty map. The resolver surfaces
3965            // Null at row time, consistent with the no-neighbor case.
3966            out.insert((rel_type, prop_name, direction), Arc::new(HashMap::new()));
3967            continue;
3968        };
3969        let edge_type_id = edge_meta.id;
3970
3971        // Warm adjacency for every direction we'll traverse. Mirrors
3972        // the pattern in projection.rs / procedure_template.rs.
3973        let edge_ver = storage.get_edge_version_by_id(edge_type_id);
3974        for dir in direction.store_directions() {
3975            storage
3976                .warm_adjacency(edge_type_id, *dir, edge_ver)
3977                .await
3978                .map_err(|e| {
3979                    datafusion::error::DataFusionError::Execution(format!(
3980                        "neighbor-aggregator warm_adjacency for '{rel_type}' / {dir:?} failed: {e}"
3981                    ))
3982                })?;
3983        }
3984
3985        // Collect distinct subject vids from body batches across every
3986        // subject_var binding that this (rel, prop) pair uses.
3987        let mut subject_vids: std::collections::HashSet<u64> = std::collections::HashSet::new();
3988        for subject_var in &subject_vars {
3989            for batch in batches {
3990                let schema = batch.schema();
3991                let col_idx = schema
3992                    .index_of(&format!("{}._vid", subject_var))
3993                    .ok()
3994                    .or_else(|| schema.index_of(subject_var).ok());
3995                let Some(col_idx) = col_idx else { continue };
3996                let col = batch.column(col_idx);
3997                for row in 0..batch.num_rows() {
3998                    if let Some(v) = extract_vid_from_column(col.as_ref(), row) {
3999                        subject_vids.insert(v);
4000                    }
4001                }
4002            }
4003        }
4004
4005        // For each subject, walk edges in the configured direction(s),
4006        // fetch neighbor property, coerce to f64, accumulate. Subjects
4007        // with no numeric neighbors retain an empty Vec (→ Null at
4008        // row time).
4009        let mut vid_to_values: HashMap<u64, Vec<f64>> = HashMap::new();
4010        let adj = storage.adjacency_manager();
4011        for subject_vid in subject_vids {
4012            let mut neighbors: Vec<(uni_common::core::id::Vid, uni_common::core::id::Eid)> =
4013                Vec::new();
4014            for dir in direction.store_directions() {
4015                neighbors.extend(adj.get_neighbors(
4016                    uni_common::core::id::Vid::from(subject_vid),
4017                    edge_type_id,
4018                    *dir,
4019                ));
4020            }
4021            let mut values: Vec<f64> = Vec::with_capacity(neighbors.len());
4022            for (neighbor_vid, _eid) in neighbors {
4023                let val = property_manager
4024                    .get_vertex_prop_with_ctx(neighbor_vid, &prop_name, query_ctx.as_ref())
4025                    .await
4026                    .map_err(|e| {
4027                        datafusion::error::DataFusionError::Execution(format!(
4028                            "neighbor-aggregator: failed to read property \
4029                             '{prop_name}' on neighbor vid {neighbor_vid:?}: {e}"
4030                        ))
4031                    })?;
4032                if let Some(f) = val.as_f64()
4033                    && !f.is_nan()
4034                {
4035                    values.push(f);
4036                }
4037            }
4038            vid_to_values.insert(subject_vid, values);
4039        }
4040        out.insert((rel_type, prop_name, direction), Arc::new(vid_to_values));
4041    }
4042    Ok(out)
4043}
4044
4045/// Phase D D3: walk the source rule's converged batches and build
4046/// a `vid → FeatureValue` lookup for the named column. The subject
4047/// column in the derived rule's schema holds VIDs (UInt64) for node
4048/// variables; the value column type follows the rule's yield-schema
4049/// inference (typically Float64 / Int64 / Bool / String).
4050fn build_path_context_lookup(
4051    handle: &crate::query::df_graph::locy_model_invoke::PathContextHandle,
4052    _subject_var: &str,
4053    column: &str,
4054    model_name: &str,
4055) -> DFResult<HashMap<u64, uni_locy::FeatureValue>> {
4056    // The source rule's KEY column is its first yield column by
4057    // convention (`infer_yield_schema` orders KEYs first). The model's
4058    // local `subject_var` is just a binding alias — the actual join
4059    // matches the body row's VID against this canonical column.
4060    if handle.schema.fields().is_empty() {
4061        return Err(datafusion::error::DataFusionError::Execution(format!(
4062            "model '{model_name}' path_context: source rule has empty yield schema"
4063        )));
4064    }
4065    let subj_idx = 0_usize;
4066    let col_idx = handle.schema.index_of(column).map_err(|_| {
4067        datafusion::error::DataFusionError::Execution(format!(
4068            "model '{model_name}' path_context: column '{column}' not in \
4069             source rule's yield schema (have: {:?})",
4070            handle
4071                .schema
4072                .fields()
4073                .iter()
4074                .map(|f| f.name().clone())
4075                .collect::<Vec<_>>()
4076        ))
4077    })?;
4078    let batches = handle.data.read();
4079    let mut out: HashMap<u64, uni_locy::FeatureValue> = HashMap::new();
4080    for batch in batches.iter() {
4081        let subj_col = batch.column(subj_idx);
4082        let value_col = batch.column(col_idx);
4083        for row in 0..batch.num_rows() {
4084            if subj_col.is_null(row) {
4085                continue;
4086            }
4087            let vid = if let Some(a) = subj_col.as_any().downcast_ref::<arrow_array::UInt64Array>()
4088            {
4089                a.value(row)
4090            } else if let Some(a) = subj_col.as_any().downcast_ref::<arrow_array::Int64Array>() {
4091                a.value(row) as u64
4092            } else {
4093                continue;
4094            };
4095            let v = extract_feature_value(value_col.as_ref(), row);
4096            // Last write wins on duplicates; derived rules typically have
4097            // unique KEY values, so this is a defensive guard.
4098            out.insert(vid, v);
4099        }
4100    }
4101    Ok(out)
4102}
4103
4104/// Extract a `uni_common::Value` from one row of an Arrow column.
4105/// Used by the Phase D `similar_to` feature resolver, which needs
4106/// the raw `Value` (especially `Value::Vector(Vec<f32>)`) to feed
4107/// `eval_similar_to_pure`.
4108fn extract_common_value(col: &dyn arrow_array::Array, row_idx: usize) -> uni_common::Value {
4109    use arrow_array::{BooleanArray, Float64Array, Int64Array, LargeStringArray, StringArray};
4110    if col.is_null(row_idx) {
4111        return uni_common::Value::Null;
4112    }
4113    if let Some(a) = col.as_any().downcast_ref::<Float64Array>() {
4114        return uni_common::Value::Float(a.value(row_idx));
4115    }
4116    if let Some(a) = col.as_any().downcast_ref::<Int64Array>() {
4117        return uni_common::Value::Int(a.value(row_idx));
4118    }
4119    if let Some(a) = col.as_any().downcast_ref::<BooleanArray>() {
4120        return uni_common::Value::Bool(a.value(row_idx));
4121    }
4122    if let Some(a) = col.as_any().downcast_ref::<StringArray>() {
4123        return uni_common::Value::String(a.value(row_idx).to_string());
4124    }
4125    if let Some(a) = col.as_any().downcast_ref::<LargeStringArray>() {
4126        return uni_common::Value::String(a.value(row_idx).to_string());
4127    }
4128    if let Some(b) = col.as_any().downcast_ref::<arrow_array::LargeBinaryArray>() {
4129        let bytes = b.value(row_idx);
4130        if bytes.is_empty() {
4131            return uni_common::Value::Null;
4132        }
4133        return uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null);
4134    }
4135    uni_common::Value::Null
4136}
4137
4138fn extract_feature_value(col: &dyn arrow_array::Array, row_idx: usize) -> uni_locy::FeatureValue {
4139    use arrow_array::{BooleanArray, Float64Array, Int64Array, LargeStringArray, StringArray};
4140    if col.is_null(row_idx) {
4141        return uni_locy::FeatureValue::Null;
4142    }
4143    if let Some(a) = col.as_any().downcast_ref::<Float64Array>() {
4144        return uni_locy::FeatureValue::Float(a.value(row_idx));
4145    }
4146    if let Some(a) = col.as_any().downcast_ref::<Int64Array>() {
4147        return uni_locy::FeatureValue::Int(a.value(row_idx));
4148    }
4149    if let Some(a) = col.as_any().downcast_ref::<BooleanArray>() {
4150        return uni_locy::FeatureValue::Bool(a.value(row_idx));
4151    }
4152    if let Some(a) = col.as_any().downcast_ref::<StringArray>() {
4153        return uni_locy::FeatureValue::String(a.value(row_idx).to_string());
4154    }
4155    if let Some(a) = col.as_any().downcast_ref::<LargeStringArray>() {
4156        return uni_locy::FeatureValue::String(a.value(row_idx).to_string());
4157    }
4158    // Schema-less property storage: values arrive as LargeBinary
4159    // MessagePack-encoded `CypherValue`. Decode via the standard codec
4160    // and project the result to the matching `FeatureValue` variant.
4161    if let Some(b) = col.as_any().downcast_ref::<arrow_array::LargeBinaryArray>() {
4162        let bytes = b.value(row_idx);
4163        if bytes.is_empty() {
4164            return uni_locy::FeatureValue::Null;
4165        }
4166        let v = uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null);
4167        return match v {
4168            uni_common::Value::Float(f) => uni_locy::FeatureValue::Float(f),
4169            uni_common::Value::Int(i) => uni_locy::FeatureValue::Int(i),
4170            uni_common::Value::Bool(b) => uni_locy::FeatureValue::Bool(b),
4171            uni_common::Value::String(s) => uni_locy::FeatureValue::String(s),
4172            uni_common::Value::Null => uni_locy::FeatureValue::Null,
4173            _ => uni_locy::FeatureValue::Null,
4174        };
4175    }
4176    uni_locy::FeatureValue::Null
4177}
4178
4179/// Probabilistic complement for negated IS-refs targeting PROB rules.
4180///
4181/// Instead of filtering out matching VIDs (anti-join), this adds a complement
4182/// column `__prob_complement_{rule_name}` with value `1 - p` for each matching
4183/// VID, and `1.0` for VIDs not present in the negated rule's facts. Implements
4184/// `IS NOT risk` on a PROB rule: the probability that the entity is NOT risky.
4185pub fn apply_prob_complement(
4186    batches: Vec<RecordBatch>,
4187    neg_facts: &[RecordBatch],
4188    left_col: &str,
4189    right_col: &str,
4190    prob_col: &str,
4191    complement_col_name: &str,
4192) -> datafusion::error::Result<Vec<RecordBatch>> {
4193    use arrow_array::{Array as _, Float64Array, UInt64Array};
4194
4195    // Build VID → probability lookup from negative facts
4196    let mut prob_map: std::collections::HashMap<u64, f64> = std::collections::HashMap::new();
4197    for batch in neg_facts {
4198        let Ok(vid_idx) = batch.schema().index_of(right_col) else {
4199            continue;
4200        };
4201        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
4202            continue;
4203        };
4204        let Some(vids) = batch.column(vid_idx).as_any().downcast_ref::<UInt64Array>() else {
4205            continue;
4206        };
4207        let prob_arr = batch.column(prob_idx);
4208        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
4209        for i in 0..vids.len() {
4210            if !vids.is_null(i) {
4211                let p = probs
4212                    .and_then(|arr| {
4213                        if arr.is_null(i) {
4214                            None
4215                        } else {
4216                            Some(arr.value(i))
4217                        }
4218                    })
4219                    .unwrap_or(0.0);
4220                // If multiple facts for same VID, use noisy-OR combination:
4221                // combined = 1 - (1 - existing) * (1 - new)
4222                prob_map
4223                    .entry(vids.value(i))
4224                    .and_modify(|existing| {
4225                        *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
4226                    })
4227                    .or_insert(p);
4228            }
4229        }
4230    }
4231
4232    // Add complement column to each batch
4233    let mut result = Vec::new();
4234    for batch in batches {
4235        let Ok(idx) = batch.schema().index_of(left_col) else {
4236            result.push(batch);
4237            continue;
4238        };
4239        let Some(vids) = batch.column(idx).as_any().downcast_ref::<UInt64Array>() else {
4240            result.push(batch);
4241            continue;
4242        };
4243
4244        // Compute complement values: 1 - p for matched VIDs, 1.0 for absent
4245        let complements: Vec<f64> = (0..vids.len())
4246            .map(|i| {
4247                if vids.is_null(i) {
4248                    1.0
4249                } else {
4250                    let p = prob_map.get(&vids.value(i)).copied().unwrap_or(0.0);
4251                    1.0 - p
4252                }
4253            })
4254            .collect();
4255
4256        let complement_arr = Float64Array::from(complements);
4257
4258        // Add the complement column to the batch
4259        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
4260        columns.push(std::sync::Arc::new(complement_arr));
4261
4262        let mut fields: Vec<std::sync::Arc<arrow_schema::Field>> =
4263            batch.schema().fields().iter().cloned().collect();
4264        fields.push(std::sync::Arc::new(arrow_schema::Field::new(
4265            complement_col_name,
4266            arrow_schema::DataType::Float64,
4267            true,
4268        )));
4269
4270        let new_schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4271        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
4272        result.push(new_batch);
4273    }
4274    Ok(result)
4275}
4276
4277/// Probabilistic complement for composite (multi-column) join keys.
4278///
4279/// Builds a composite key from all `join_cols` right-side columns in
4280/// `neg_facts`, maps each composite key to a probability via noisy-OR
4281/// combination, then adds a single `complement_col_name` column with
4282/// `1 - p` for matched keys and `1.0` for absent keys.
4283pub fn apply_prob_complement_composite(
4284    batches: Vec<RecordBatch>,
4285    neg_facts: &[RecordBatch],
4286    join_cols: &[(String, String)],
4287    prob_col: &str,
4288    complement_col_name: &str,
4289) -> datafusion::error::Result<Vec<RecordBatch>> {
4290    use arrow_array::{Array as _, Float64Array, UInt64Array};
4291
4292    // Build composite-key → probability lookup from negative facts.
4293    let mut prob_map: HashMap<Vec<u64>, f64> = HashMap::new();
4294    for batch in neg_facts {
4295        let right_indices: Vec<usize> = join_cols
4296            .iter()
4297            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
4298            .collect();
4299        if right_indices.len() != join_cols.len() {
4300            continue;
4301        }
4302        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
4303            continue;
4304        };
4305        let prob_arr = batch.column(prob_idx);
4306        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
4307        for row in 0..batch.num_rows() {
4308            let mut key = Vec::with_capacity(right_indices.len());
4309            let mut valid = true;
4310            for &ci in &right_indices {
4311                let col = batch.column(ci);
4312                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
4313                    if vids.is_null(row) {
4314                        valid = false;
4315                        break;
4316                    }
4317                    key.push(vids.value(row));
4318                } else {
4319                    valid = false;
4320                    break;
4321                }
4322            }
4323            if !valid {
4324                continue;
4325            }
4326            let p = probs
4327                .and_then(|arr| {
4328                    if arr.is_null(row) {
4329                        None
4330                    } else {
4331                        Some(arr.value(row))
4332                    }
4333                })
4334                .unwrap_or(0.0);
4335            // Noisy-OR combination for duplicate composite keys.
4336            prob_map
4337                .entry(key)
4338                .and_modify(|existing| {
4339                    *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
4340                })
4341                .or_insert(p);
4342        }
4343    }
4344
4345    // Add complement column to each batch.
4346    let mut result = Vec::new();
4347    for batch in batches {
4348        let left_indices: Vec<usize> = join_cols
4349            .iter()
4350            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
4351            .collect();
4352        if left_indices.len() != join_cols.len() {
4353            result.push(batch);
4354            continue;
4355        }
4356        let all_u64 = left_indices.iter().all(|&ci| {
4357            batch
4358                .column(ci)
4359                .as_any()
4360                .downcast_ref::<UInt64Array>()
4361                .is_some()
4362        });
4363        if !all_u64 {
4364            result.push(batch);
4365            continue;
4366        }
4367
4368        let complements: Vec<f64> = (0..batch.num_rows())
4369            .map(|row| {
4370                let mut key = Vec::with_capacity(left_indices.len());
4371                for &ci in &left_indices {
4372                    let vids = batch
4373                        .column(ci)
4374                        .as_any()
4375                        .downcast_ref::<UInt64Array>()
4376                        .unwrap();
4377                    if vids.is_null(row) {
4378                        return 1.0;
4379                    }
4380                    key.push(vids.value(row));
4381                }
4382                let p = prob_map.get(&key).copied().unwrap_or(0.0);
4383                1.0 - p
4384            })
4385            .collect();
4386
4387        let complement_arr = Float64Array::from(complements);
4388        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
4389        columns.push(Arc::new(complement_arr));
4390
4391        let mut fields: Vec<Arc<arrow_schema::Field>> =
4392            batch.schema().fields().iter().cloned().collect();
4393        fields.push(Arc::new(arrow_schema::Field::new(
4394            complement_col_name,
4395            arrow_schema::DataType::Float64,
4396            true,
4397        )));
4398
4399        let new_schema = Arc::new(arrow_schema::Schema::new(fields));
4400        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
4401        result.push(new_batch);
4402    }
4403    Ok(result)
4404}
4405
4406/// Boolean anti-join for composite (multi-column) join keys.
4407///
4408/// Builds a `HashSet<Vec<u64>>` from `neg_facts` using all right-side
4409/// columns in `join_cols`, then filters `batches` to keep only rows
4410/// whose composite left-side key is NOT in the set.
4411pub fn apply_anti_join_composite(
4412    batches: Vec<RecordBatch>,
4413    neg_facts: &[RecordBatch],
4414    join_cols: &[(String, String)],
4415) -> datafusion::error::Result<Vec<RecordBatch>> {
4416    use arrow::compute::filter_record_batch;
4417    use arrow_array::{Array as _, BooleanArray, UInt64Array};
4418
4419    // Collect composite keys from the negated rule's derived facts.
4420    let mut banned: HashSet<Vec<u64>> = HashSet::new();
4421    for batch in neg_facts {
4422        let right_indices: Vec<usize> = join_cols
4423            .iter()
4424            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
4425            .collect();
4426        if right_indices.len() != join_cols.len() {
4427            continue;
4428        }
4429        for row in 0..batch.num_rows() {
4430            let mut key = Vec::with_capacity(right_indices.len());
4431            let mut valid = true;
4432            for &ci in &right_indices {
4433                let col = batch.column(ci);
4434                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
4435                    if vids.is_null(row) {
4436                        valid = false;
4437                        break;
4438                    }
4439                    key.push(vids.value(row));
4440                } else {
4441                    valid = false;
4442                    break;
4443                }
4444            }
4445            if valid {
4446                banned.insert(key);
4447            }
4448        }
4449    }
4450
4451    if banned.is_empty() {
4452        return Ok(batches);
4453    }
4454
4455    // Filter body batches: keep rows where composite left key NOT IN banned.
4456    let mut result = Vec::new();
4457    for batch in batches {
4458        let left_indices: Vec<usize> = join_cols
4459            .iter()
4460            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
4461            .collect();
4462        if left_indices.len() != join_cols.len() {
4463            result.push(batch);
4464            continue;
4465        }
4466        let all_u64 = left_indices.iter().all(|&ci| {
4467            batch
4468                .column(ci)
4469                .as_any()
4470                .downcast_ref::<UInt64Array>()
4471                .is_some()
4472        });
4473        if !all_u64 {
4474            result.push(batch);
4475            continue;
4476        }
4477
4478        let keep: Vec<bool> = (0..batch.num_rows())
4479            .map(|row| {
4480                let mut key = Vec::with_capacity(left_indices.len());
4481                for &ci in &left_indices {
4482                    let vids = batch
4483                        .column(ci)
4484                        .as_any()
4485                        .downcast_ref::<UInt64Array>()
4486                        .unwrap();
4487                    if vids.is_null(row) {
4488                        return true; // null keys are never banned
4489                    }
4490                    key.push(vids.value(row));
4491                }
4492                !banned.contains(&key)
4493            })
4494            .collect();
4495        let keep_arr = BooleanArray::from(keep);
4496        let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
4497        if filtered.num_rows() > 0 {
4498            result.push(filtered);
4499        }
4500    }
4501    Ok(result)
4502}
4503
4504/// Multiply `__prob_complement_*` columns into the rule's PROB column and clean up.
4505///
4506/// After IS NOT probabilistic complement semantics have added `__prob_complement_*`
4507/// columns to clause results, this function:
4508/// 1. Computes the product of all complement factor columns
4509/// 2. Multiplies the product into the existing PROB column (if any)
4510/// 3. Removes the internal `__prob_complement_*` columns from the output
4511///
4512/// If the rule has no PROB column, complement columns are simply removed
4513/// (the complement information is discarded and IS NOT acts as a keep-all).
4514pub fn multiply_prob_factors(
4515    batches: Vec<RecordBatch>,
4516    prob_col: Option<&str>,
4517    complement_cols: &[String],
4518) -> datafusion::error::Result<Vec<RecordBatch>> {
4519    use arrow_array::{Array as _, Float64Array};
4520
4521    let mut result = Vec::with_capacity(batches.len());
4522
4523    for batch in batches {
4524        if batch.num_rows() == 0 {
4525            // Remove complement columns from empty batches
4526            let keep: Vec<usize> = batch
4527                .schema()
4528                .fields()
4529                .iter()
4530                .enumerate()
4531                .filter(|(_, f)| !complement_cols.contains(f.name()))
4532                .map(|(i, _)| i)
4533                .collect();
4534            let fields: Vec<_> = keep
4535                .iter()
4536                .map(|&i| batch.schema().field(i).clone())
4537                .collect();
4538            let cols: Vec<_> = keep.iter().map(|&i| batch.column(i).clone()).collect();
4539            let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4540            result.push(
4541                RecordBatch::try_new(schema, cols).map_err(|e| {
4542                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
4543                })?,
4544            );
4545            continue;
4546        }
4547
4548        let num_rows = batch.num_rows();
4549
4550        // 1. Compute product of all complement factors
4551        let mut combined = vec![1.0f64; num_rows];
4552        for col_name in complement_cols {
4553            if let Ok(idx) = batch.schema().index_of(col_name) {
4554                let arr = batch
4555                    .column(idx)
4556                    .as_any()
4557                    .downcast_ref::<Float64Array>()
4558                    .ok_or_else(|| {
4559                        datafusion::error::DataFusionError::Internal(format!(
4560                            "Expected Float64 for complement column {col_name}"
4561                        ))
4562                    })?;
4563                for (i, val) in combined.iter_mut().enumerate().take(num_rows) {
4564                    if !arr.is_null(i) {
4565                        *val *= arr.value(i);
4566                    }
4567                }
4568            }
4569        }
4570
4571        // 2. If there's a PROB column, multiply combined into it
4572        let final_prob: Vec<f64> = if let Some(prob_name) = prob_col {
4573            if let Ok(idx) = batch.schema().index_of(prob_name) {
4574                let arr = batch
4575                    .column(idx)
4576                    .as_any()
4577                    .downcast_ref::<Float64Array>()
4578                    .ok_or_else(|| {
4579                        datafusion::error::DataFusionError::Internal(format!(
4580                            "Expected Float64 for PROB column {prob_name}"
4581                        ))
4582                    })?;
4583                (0..num_rows)
4584                    .map(|i| {
4585                        if arr.is_null(i) {
4586                            combined[i]
4587                        } else {
4588                            arr.value(i) * combined[i]
4589                        }
4590                    })
4591                    .collect()
4592            } else {
4593                combined
4594            }
4595        } else {
4596            combined
4597        };
4598
4599        let new_prob_array: arrow_array::ArrayRef =
4600            std::sync::Arc::new(Float64Array::from(final_prob));
4601
4602        // 3. Build output: replace PROB column, remove complement columns
4603        let mut fields = Vec::new();
4604        let mut columns = Vec::new();
4605
4606        for (idx, field) in batch.schema().fields().iter().enumerate() {
4607            if complement_cols.contains(field.name()) {
4608                continue;
4609            }
4610            if prob_col.is_some_and(|p| field.name() == p) {
4611                fields.push(field.clone());
4612                columns.push(new_prob_array.clone());
4613            } else {
4614                fields.push(field.clone());
4615                columns.push(batch.column(idx).clone());
4616            }
4617        }
4618
4619        let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4620        result.push(RecordBatch::try_new(schema, columns).map_err(arrow_err)?);
4621    }
4622
4623    Ok(result)
4624}
4625
4626/// Update derived scan handles before evaluating a rule's clause bodies.
4627///
4628/// For self-references: inject delta (semi-naive optimization).
4629/// For cross-references: inject full facts.
4630fn update_derived_scan_handles(
4631    registry: &DerivedScanRegistry,
4632    states: &[FixpointState],
4633    current_rule_idx: usize,
4634    rules: &[FixpointRulePlan],
4635) {
4636    let current_rule_name = &rules[current_rule_idx].name;
4637
4638    for entry in &registry.entries {
4639        // Find the state for this entry's rule
4640        let source_state_idx = rules.iter().position(|r| r.name == entry.rule_name);
4641        let Some(source_idx) = source_state_idx else {
4642            continue;
4643        };
4644
4645        let is_self = entry.rule_name == *current_rule_name;
4646        let data = if is_self && !rules[current_rule_idx].non_linear {
4647            // Self-ref in a linear rule: inject delta for semi-naive
4648            states[source_idx].all_delta().to_vec()
4649        } else {
4650            // Cross-ref, or self-ref of a non-linear rule (≥2 same-stratum
4651            // refs in one clause — Δ×Δ would miss Δ×F_old): inject full facts
4652            states[source_idx].all_facts().to_vec()
4653        };
4654
4655        // If empty, write an empty batch so the scan returns zero rows
4656        let data = if data.is_empty() || data.iter().all(|b| b.num_rows() == 0) {
4657            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
4658        } else {
4659            data
4660        };
4661
4662        let mut guard = entry.data.write();
4663        *guard = data;
4664    }
4665}
4666
4667// ---------------------------------------------------------------------------
4668// DerivedScanExec — physical plan that reads from shared data at execution time
4669// ---------------------------------------------------------------------------
4670
4671/// Physical plan for `LocyDerivedScan` that reads from a shared `Arc<RwLock>` at
4672/// execution time (not at plan creation time).
4673///
4674/// This is critical for fixpoint iteration: the data handle is updated between
4675/// iterations, and each re-execution of the subplan must read the latest data.
4676pub struct DerivedScanExec {
4677    data: Arc<RwLock<Vec<RecordBatch>>>,
4678    schema: SchemaRef,
4679    properties: Arc<PlanProperties>,
4680}
4681
4682impl DerivedScanExec {
4683    pub fn new(data: Arc<RwLock<Vec<RecordBatch>>>, schema: SchemaRef) -> Self {
4684        let properties = compute_plan_properties(Arc::clone(&schema));
4685        Self {
4686            data,
4687            schema,
4688            properties,
4689        }
4690    }
4691}
4692
4693impl fmt::Debug for DerivedScanExec {
4694    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4695        f.debug_struct("DerivedScanExec")
4696            .field("schema", &self.schema)
4697            .finish()
4698    }
4699}
4700
4701impl DisplayAs for DerivedScanExec {
4702    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4703        write!(f, "DerivedScanExec")
4704    }
4705}
4706
4707impl ExecutionPlan for DerivedScanExec {
4708    fn name(&self) -> &str {
4709        "DerivedScanExec"
4710    }
4711    fn as_any(&self) -> &dyn Any {
4712        self
4713    }
4714    fn schema(&self) -> SchemaRef {
4715        Arc::clone(&self.schema)
4716    }
4717    fn properties(&self) -> &Arc<PlanProperties> {
4718        &self.properties
4719    }
4720    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
4721        vec![]
4722    }
4723    fn with_new_children(
4724        self: Arc<Self>,
4725        _children: Vec<Arc<dyn ExecutionPlan>>,
4726    ) -> DFResult<Arc<dyn ExecutionPlan>> {
4727        Ok(self)
4728    }
4729    fn execute(
4730        &self,
4731        _partition: usize,
4732        _context: Arc<TaskContext>,
4733    ) -> DFResult<SendableRecordBatchStream> {
4734        let batches = {
4735            let guard = self.data.read();
4736            if guard.is_empty() {
4737                vec![RecordBatch::new_empty(Arc::clone(&self.schema))]
4738            } else {
4739                // Re-stamp every batch with this exec's schema. The shared
4740                // data Arc always holds batches with the rule's original
4741                // yield-schema names, but this scan may carry per-occurrence
4742                // aliased column names (multi-IS-ref clauses). Zero-copy:
4743                // only the schema pointer changes, never the columns.
4744                guard
4745                    .iter()
4746                    .map(|b| {
4747                        RecordBatch::try_new(Arc::clone(&self.schema), b.columns().to_vec())
4748                            .map_err(|e| {
4749                                datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
4750                            })
4751                    })
4752                    .collect::<DFResult<Vec<_>>>()?
4753            }
4754        };
4755        Ok(Box::pin(MemoryStream::try_new(
4756            batches,
4757            Arc::clone(&self.schema),
4758            None,
4759        )?))
4760    }
4761}
4762
4763// ---------------------------------------------------------------------------
4764// InMemoryExec — wrapper to feed Vec<RecordBatch> into operator chains
4765// ---------------------------------------------------------------------------
4766
4767/// Simple in-memory execution plan that serves pre-computed batches.
4768///
4769/// Used internally to feed fixpoint results into post-fixpoint operator chains
4770/// (FOLD, BEST BY). Not exported — only used within this module.
4771struct InMemoryExec {
4772    batches: Vec<RecordBatch>,
4773    schema: SchemaRef,
4774    properties: Arc<PlanProperties>,
4775}
4776
4777impl InMemoryExec {
4778    fn new(batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
4779        let properties = compute_plan_properties(Arc::clone(&schema));
4780        Self {
4781            batches,
4782            schema,
4783            properties,
4784        }
4785    }
4786}
4787
4788impl fmt::Debug for InMemoryExec {
4789    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4790        f.debug_struct("InMemoryExec")
4791            .field("num_batches", &self.batches.len())
4792            .field("schema", &self.schema)
4793            .finish()
4794    }
4795}
4796
4797impl DisplayAs for InMemoryExec {
4798    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4799        write!(f, "InMemoryExec: batches={}", self.batches.len())
4800    }
4801}
4802
4803impl ExecutionPlan for InMemoryExec {
4804    fn name(&self) -> &str {
4805        "InMemoryExec"
4806    }
4807    fn as_any(&self) -> &dyn Any {
4808        self
4809    }
4810    fn schema(&self) -> SchemaRef {
4811        Arc::clone(&self.schema)
4812    }
4813    fn properties(&self) -> &Arc<PlanProperties> {
4814        &self.properties
4815    }
4816    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
4817        vec![]
4818    }
4819    fn with_new_children(
4820        self: Arc<Self>,
4821        _children: Vec<Arc<dyn ExecutionPlan>>,
4822    ) -> DFResult<Arc<dyn ExecutionPlan>> {
4823        Ok(self)
4824    }
4825    fn execute(
4826        &self,
4827        _partition: usize,
4828        _context: Arc<TaskContext>,
4829    ) -> DFResult<SendableRecordBatchStream> {
4830        Ok(Box::pin(MemoryStream::try_new(
4831            self.batches.clone(),
4832            Arc::clone(&self.schema),
4833            None,
4834        )?))
4835    }
4836}
4837
4838// ---------------------------------------------------------------------------
4839// Post-fixpoint chain — FOLD and BEST BY on converged facts
4840// ---------------------------------------------------------------------------
4841
4842/// Apply post-FOLD WHERE (HAVING) filter to aggregated batches.
4843///
4844/// Converts each Cypher HAVING expression to a DataFusion physical expression
4845/// via `cypher_expr_to_df` → type coercion → `create_physical_expr`, evaluates
4846/// against the FOLD output, and keeps only rows where all conditions hold.
4847fn apply_having_filter(
4848    batches: Vec<RecordBatch>,
4849    having_exprs: &[Expr],
4850    schema: &SchemaRef,
4851    task_ctx: &Arc<TaskContext>,
4852) -> DFResult<Vec<RecordBatch>> {
4853    use arrow::compute::{and, filter_record_batch};
4854    use arrow_array::BooleanArray;
4855    use datafusion::common::DFSchema;
4856    use datafusion::logical_expr::LogicalPlanBuilder;
4857    use datafusion::logical_expr::execution_props::ExecutionProps;
4858    use datafusion::optimizer::AnalyzerRule;
4859    use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
4860    use datafusion::physical_expr::create_physical_expr;
4861
4862    if batches.is_empty() {
4863        return Ok(batches);
4864    }
4865
4866    // Build DFSchema from the FOLD output Arrow schema.
4867    let df_schema = DFSchema::try_from(schema.as_ref().clone()).map_err(|e| {
4868        datafusion::common::DataFusionError::Internal(format!("HAVING schema conversion: {e}"))
4869    })?;
4870
4871    // Use the active TaskContext's config rather than allocating a fresh
4872    // `SessionContext` per HAVING evaluation (~130 µs/call). HAVING uses only
4873    // built-in DataFusion arithmetic — no Cypher UDFs — so a default
4874    // `ExecutionProps` is sufficient (it's documented as cheap to construct).
4875    let config = (**task_ctx.session_config().options()).clone();
4876    let props = ExecutionProps::new();
4877
4878    // Cypher Expr → DataFusion DfExpr → type-coerced DfExpr → PhysicalExpr.
4879    //
4880    // Type coercion is needed because FOLD aggregates produce Float64 (SUM,
4881    // AVG) or Int64 (COUNT), and literal comparisons like `total >= 100`
4882    // may mix Float64 columns with Int64 literals.
4883    let physical_exprs: Vec<Arc<dyn datafusion::physical_expr::PhysicalExpr>> = having_exprs
4884        .iter()
4885        .map(|expr| {
4886            let df_expr = crate::query::df_expr::cypher_expr_to_df(expr, None).map_err(|e| {
4887                datafusion::common::DataFusionError::Internal(format!(
4888                    "HAVING expression conversion: {e}"
4889                ))
4890            })?;
4891
4892            // Run DataFusion's type coercion by wrapping in a Filter plan,
4893            // applying the TypeCoercion analyzer rule, then extracting the
4894            // coerced predicate.
4895            let empty = datafusion::logical_expr::LogicalPlan::EmptyRelation(
4896                datafusion::logical_expr::EmptyRelation {
4897                    produce_one_row: false,
4898                    schema: Arc::new(df_schema.clone()),
4899                },
4900            );
4901            let filter_plan = LogicalPlanBuilder::from(empty)
4902                .filter(df_expr.clone())?
4903                .build()?;
4904            let coerced_expr = match TypeCoercion::new().analyze(filter_plan, &config) {
4905                Ok(datafusion::logical_expr::LogicalPlan::Filter(f)) => f.predicate,
4906                _ => df_expr,
4907            };
4908
4909            create_physical_expr(&coerced_expr, &df_schema, &props)
4910        })
4911        .collect::<DFResult<Vec<_>>>()?;
4912
4913    let mut result = Vec::new();
4914    for batch in batches {
4915        // Evaluate each condition and AND the boolean masks.
4916        let mut mask: Option<BooleanArray> = None;
4917        for phys_expr in &physical_exprs {
4918            let value = phys_expr.evaluate(&batch)?;
4919            let arr = value.into_array(batch.num_rows())?;
4920            let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().ok_or_else(|| {
4921                datafusion::common::DataFusionError::Internal(
4922                    "HAVING condition must evaluate to boolean".into(),
4923                )
4924            })?;
4925            mask = Some(match mask {
4926                None => bool_arr.clone(),
4927                Some(prev) => and(&prev, bool_arr).map_err(arrow_err)?,
4928            });
4929        }
4930        if let Some(ref m) = mask {
4931            let filtered = filter_record_batch(&batch, m).map_err(arrow_err)?;
4932            if filtered.num_rows() > 0 {
4933                result.push(filtered);
4934            }
4935        } else {
4936            result.push(batch);
4937        }
4938    }
4939    Ok(result)
4940}
4941
4942/// Apply post-fixpoint operators (FOLD, HAVING, BEST BY, PRIORITY) to converged facts.
4943#[allow(
4944    clippy::too_many_arguments,
4945    reason = "context bundle would be over-engineering for one call site"
4946)]
4947pub(crate) async fn apply_post_fixpoint_chain(
4948    facts: Vec<RecordBatch>,
4949    rule: &FixpointRulePlan,
4950    task_ctx: &Arc<TaskContext>,
4951    strict_probability_domain: bool,
4952    probability_epsilon: f64,
4953    semiring_kind: SemiringKind,
4954    provenance_tracker: Option<Arc<ProvenanceStore>>,
4955    top_k_proofs_k: usize,
4956    registry: Option<Arc<DerivedScanRegistry>>,
4957) -> DFResult<Vec<RecordBatch>> {
4958    if !rule.has_fold && !rule.has_best_by && !rule.has_priority && rule.having.is_empty() {
4959        return Ok(facts);
4960    }
4961
4962    // Wrap facts in InMemoryExec.
4963    // Prefer the actual batch schema (from physical execution) over the
4964    // pre-computed yield_schema, which may have wrong inferred types
4965    // (e.g. Float64 for a string property).
4966    let schema = facts
4967        .iter()
4968        .find(|b| b.num_rows() > 0)
4969        .map(|b| b.schema())
4970        .unwrap_or_else(|| Arc::clone(&rule.yield_schema));
4971
4972    // Phase D D-C0: pre-compute body-row → IS-ref support map for
4973    // TopKProofs MNOR's DNF inclusion-exclusion math. Must be built
4974    // here because `facts` is moved into `InMemoryExec` on the next
4975    // line. The map is keyed by a full-column row hash — only
4976    // meaningful when no downstream plan node strips/adds columns
4977    // between this batch view and the FoldExec input. PRIORITY drops
4978    // the `__priority` column, which would change row hashes; until
4979    // we plumb the map past PRIORITY, skip map construction for
4980    // PRIORITY rules (the failing TCK test doesn't use PRIORITY).
4981    // Read the active K from `semiring_kind` rather than the separate
4982    // `top_k_proofs_k` parameter — the latter is not always threaded
4983    // from the LocyProgram config (the semiring's `k` is the source of
4984    // truth).
4985    let topk_k: Option<usize> = match semiring_kind {
4986        SemiringKind::TopKProofs { k } if k > 0 => Some(k as usize),
4987        _ => None,
4988    };
4989    let body_support_map: Option<Arc<HashMap<Vec<u8>, Vec<ProofTerm>>>> = if topk_k.is_some()
4990        && !rule.has_priority
4991        && let Some(registry) = registry.as_ref()
4992    {
4993        let mut map: HashMap<Vec<u8>, Vec<ProofTerm>> = HashMap::new();
4994        for batch in &facts {
4995            let all_indices: Vec<usize> = (0..batch.num_columns()).collect();
4996            for row_idx in 0..batch.num_rows() {
4997                let support = collect_is_ref_inputs_for_body_row(rule, batch, row_idx, registry);
4998                if support.is_empty() {
4999                    continue;
5000                }
5001                let hash = fact_hash_key(batch, &all_indices, row_idx);
5002                map.insert(hash, support);
5003            }
5004        }
5005        if map.is_empty() {
5006            None
5007        } else {
5008            Some(Arc::new(map))
5009        }
5010    } else {
5011        None
5012    };
5013
5014    let input: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(facts, schema.clone()));
5015
5016    // Reconcile key indices: rule's indices are yield-schema positions but
5017    // the actual batch may have different column ordering after schema
5018    // reconciliation during fixpoint iteration (same pattern as
5019    // FixpointState::reconcile_schema).
5020    let key_column_indices: Vec<usize> = rule
5021        .key_column_indices
5022        .iter()
5023        .filter_map(|&i| {
5024            let name = rule.yield_schema.field(i).name();
5025            schema.index_of(name).ok()
5026        })
5027        .collect();
5028
5029    // Apply PRIORITY first — keeps only rows with max __priority per KEY group,
5030    // then strips the __priority column from output.
5031    // Must run before FOLD so that the __priority column is still present.
5032    let current: Arc<dyn ExecutionPlan> = if rule.has_priority {
5033        let priority_schema = input.schema();
5034        let priority_idx = priority_schema.index_of("__priority").map_err(|_| {
5035            datafusion::common::DataFusionError::Internal(
5036                "PRIORITY rule missing __priority column".to_string(),
5037            )
5038        })?;
5039        Arc::new(PriorityExec::new(
5040            input,
5041            key_column_indices.clone(),
5042            priority_idx,
5043        ))
5044    } else {
5045        input
5046    };
5047
5048    // Apply FOLD
5049    let current: Arc<dyn ExecutionPlan> = if rule.has_fold && !rule.fold_bindings.is_empty() {
5050        Arc::new(FoldExec::new_with_topk(
5051            current,
5052            key_column_indices.clone(),
5053            rule.fold_bindings.clone(),
5054            strict_probability_domain,
5055            probability_epsilon,
5056            semiring_kind,
5057            provenance_tracker.clone(),
5058            topk_k.unwrap_or(top_k_proofs_k),
5059            body_support_map.clone(),
5060        ))
5061    } else {
5062        current
5063    };
5064
5065    // Apply HAVING (post-FOLD WHERE filter)
5066    let current: Arc<dyn ExecutionPlan> = if !rule.having.is_empty() {
5067        let batches = collect_all_partitions(&current, Arc::clone(task_ctx)).await?;
5068        let filtered = apply_having_filter(batches, &rule.having, &current.schema(), task_ctx)?;
5069        if filtered.is_empty() {
5070            return Ok(filtered);
5071        }
5072        Arc::new(InMemoryExec::new(filtered, Arc::clone(&current.schema())))
5073    } else {
5074        current
5075    };
5076
5077    // Apply BEST BY
5078    let current: Arc<dyn ExecutionPlan> = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
5079        Arc::new(BestByExec::new(
5080            current,
5081            key_column_indices.clone(),
5082            rule.best_by_criteria.clone(),
5083            rule.deterministic,
5084        ))
5085    } else {
5086        current
5087    };
5088
5089    collect_all_partitions(&current, Arc::clone(task_ctx)).await
5090}
5091
5092// ---------------------------------------------------------------------------
5093// FixpointExec — DataFusion ExecutionPlan
5094// ---------------------------------------------------------------------------
5095
/// DataFusion `ExecutionPlan` that drives semi-naive fixpoint iteration.
///
/// Has no physical children: clause bodies are re-planned from logical plans
/// on each iteration (same pattern as `RecursiveCTEExec` and `GraphApplyExec`).
///
/// Every field below is captured by `execute` and handed to
/// `run_fixpoint_loop`, except `properties` and `metrics`, which serve the
/// `ExecutionPlan` trait methods directly.
pub struct FixpointExec {
    /// Rule plans of the stratum; deep-copied into the fixpoint future on
    /// every `execute` call.
    rules: Vec<FixpointRulePlan>,
    /// Iteration limit forwarded to `run_fixpoint_loop` (shown as `max_iter`
    /// in the plan display).
    max_iterations: usize,
    /// Time budget forwarded to `run_fixpoint_loop`.
    timeout: Duration,
    /// Graph execution context forwarded to `run_fixpoint_loop`.
    graph_ctx: Arc<GraphExecutionContext>,
    /// Shared DataFusion session context forwarded to `run_fixpoint_loop`.
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    /// Storage manager forwarded to `run_fixpoint_loop`.
    storage: Arc<StorageManager>,
    /// Uni schema forwarded to `run_fixpoint_loop`.
    schema_info: Arc<UniSchema>,
    /// Query parameters; cloned for each execution.
    params: HashMap<String, Value>,
    /// Registry of derived scans, shared with the post-fixpoint chain.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    /// Schema reported by `schema()` and by the result stream.
    output_schema: SchemaRef,
    /// Plan properties computed once from `output_schema` at construction.
    properties: Arc<PlanProperties>,
    /// DataFusion metrics set; `execute` registers per-partition baseline
    /// metrics against it.
    metrics: ExecutionPlanMetricsSet,
    /// Byte limit for derived facts, forwarded to `run_fixpoint_loop`.
    /// NOTE(review): exact enforcement semantics live in the loop — confirm there.
    max_derived_bytes: usize,
    /// Optional provenance tracker populated during fixpoint iteration.
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    /// Shared slot written with per-rule iteration counts after convergence.
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    /// Probability-domain strictness flag forwarded to `run_fixpoint_loop`.
    strict_probability_domain: bool,
    /// Probability tolerance forwarded to `run_fixpoint_loop`.
    probability_epsilon: f64,
    /// Exact-probability switch forwarded to `run_fixpoint_loop`.
    exact_probability: bool,
    /// BDD variable limit forwarded to `run_fixpoint_loop`.
    max_bdd_variables: usize,
    /// Shared slot for runtime warnings collected during fixpoint iteration.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    /// Shared slot for groups where BDD fell back to independence mode.
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    /// When > 0, retain at most this many proofs per fact (top-k provenance).
    top_k_proofs: usize,
    /// Shared timeout signal used to report partial results. Stored as an
    /// `AtomicU8`, not a boolean.
    /// NOTE(review): the value encoding is defined by whoever writes the
    /// flag — confirm against `run_fixpoint_loop`.
    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
    /// Active probability semiring (rollout D-7).
    semiring_kind: SemiringKind,
    /// Phase B Slice 3 registry of neural classifiers, keyed by the
    /// model name from `CREATE MODEL`. Held by `Arc` so executor clones
    /// share the same underlying map.
    classifier_registry: Arc<ClassifierRegistry>,
    /// Phase B follow-up: optional per-evaluation memoization cache
    /// for classifier outputs keyed by `(model_name, feature_hash)`.
    /// `None` → no caching; `Some` → cache shared across fixpoint
    /// iterations (and optionally across the entire query / multiple
    /// queries, when the caller threads it via `LocyConfig`).
    classifier_cache: Option<Arc<ModelInvocationCache>>,
    /// Phase C B1-B3 follow-up: per-query side-channel store
    /// for (raw, calibrated, confidence_band) records. Read by
    /// EXPLAIN; not used by the fixpoint inner loop directly
    /// (LocyModelInvokeExec writes; this struct just carries
    /// the Arc to keep the type wiring consistent across the
    /// LocyProgramExec/FixpointExec boundary).
    #[allow(
        dead_code,
        reason = "boundary plumbing; read by EXPLAIN via LocyModelInvokeExec"
    )]
    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
}
5154
5155impl fmt::Debug for FixpointExec {
5156    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5157        f.debug_struct("FixpointExec")
5158            .field("rules_count", &self.rules.len())
5159            .field("max_iterations", &self.max_iterations)
5160            .field("timeout", &self.timeout)
5161            .field("output_schema", &self.output_schema)
5162            .field("max_derived_bytes", &self.max_derived_bytes)
5163            .finish_non_exhaustive()
5164    }
5165}
5166
5167impl FixpointExec {
5168    /// Create a new `FixpointExec`.
5169    #[expect(
5170        clippy::too_many_arguments,
5171        reason = "FixpointExec configuration needs all context"
5172    )]
5173    #[deprecated(
5174        note = "use `new_with_semiring_classifiers_and_cache` (or the lighter \
5175                `new_with_semiring_and_classifiers` / `new_with_semiring`) — \
5176                this legacy ctor defaults the semiring to AddMultProb and \
5177                ships no classifier registry, which the Phase B+ runtime needs \
5178                explicitly. To be removed after C0 Stage 2."
5179    )]
5180    pub fn new(
5181        rules: Vec<FixpointRulePlan>,
5182        max_iterations: usize,
5183        timeout: Duration,
5184        graph_ctx: Arc<GraphExecutionContext>,
5185        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5186        storage: Arc<StorageManager>,
5187        schema_info: Arc<UniSchema>,
5188        params: HashMap<String, Value>,
5189        derived_scan_registry: Arc<DerivedScanRegistry>,
5190        output_schema: SchemaRef,
5191        max_derived_bytes: usize,
5192        derivation_tracker: Option<Arc<ProvenanceStore>>,
5193        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5194        strict_probability_domain: bool,
5195        probability_epsilon: f64,
5196        exact_probability: bool,
5197        max_bdd_variables: usize,
5198        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5199        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5200        top_k_proofs: usize,
5201        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5202    ) -> Self {
5203        Self::new_with_semiring_and_classifiers(
5204            rules,
5205            max_iterations,
5206            timeout,
5207            graph_ctx,
5208            session_ctx,
5209            storage,
5210            schema_info,
5211            params,
5212            derived_scan_registry,
5213            output_schema,
5214            max_derived_bytes,
5215            derivation_tracker,
5216            iteration_counts,
5217            strict_probability_domain,
5218            probability_epsilon,
5219            exact_probability,
5220            max_bdd_variables,
5221            warnings_slot,
5222            approximate_slot,
5223            top_k_proofs,
5224            timeout_flag,
5225            SemiringKind::AddMultProb,
5226            Arc::new(ClassifierRegistry::new()),
5227        )
5228    }
5229
5230    /// Variant accepting an explicit [`SemiringKind`]. Empty classifier
5231    /// registry; for the full variant call
5232    /// [`FixpointExec::new_with_semiring_and_classifiers`].
5233    #[expect(
5234        clippy::too_many_arguments,
5235        reason = "FixpointExec configuration needs all context"
5236    )]
5237    pub fn new_with_semiring(
5238        rules: Vec<FixpointRulePlan>,
5239        max_iterations: usize,
5240        timeout: Duration,
5241        graph_ctx: Arc<GraphExecutionContext>,
5242        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5243        storage: Arc<StorageManager>,
5244        schema_info: Arc<UniSchema>,
5245        params: HashMap<String, Value>,
5246        derived_scan_registry: Arc<DerivedScanRegistry>,
5247        output_schema: SchemaRef,
5248        max_derived_bytes: usize,
5249        derivation_tracker: Option<Arc<ProvenanceStore>>,
5250        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5251        strict_probability_domain: bool,
5252        probability_epsilon: f64,
5253        exact_probability: bool,
5254        max_bdd_variables: usize,
5255        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5256        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5257        top_k_proofs: usize,
5258        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5259        semiring_kind: SemiringKind,
5260    ) -> Self {
5261        Self::new_with_semiring_and_classifiers(
5262            rules,
5263            max_iterations,
5264            timeout,
5265            graph_ctx,
5266            session_ctx,
5267            storage,
5268            schema_info,
5269            params,
5270            derived_scan_registry,
5271            output_schema,
5272            max_derived_bytes,
5273            derivation_tracker,
5274            iteration_counts,
5275            strict_probability_domain,
5276            probability_epsilon,
5277            exact_probability,
5278            max_bdd_variables,
5279            warnings_slot,
5280            approximate_slot,
5281            top_k_proofs,
5282            timeout_flag,
5283            semiring_kind,
5284            Arc::new(ClassifierRegistry::new()),
5285        )
5286    }
5287
5288    /// Phase B Slice 3 entry: accepts both the semiring kind and the
5289    /// runtime classifier registry. The planner uses this when the
5290    /// program contains `CREATE MODEL` declarations.
5291    #[expect(
5292        clippy::too_many_arguments,
5293        reason = "FixpointExec configuration needs all context"
5294    )]
5295    pub fn new_with_semiring_and_classifiers(
5296        rules: Vec<FixpointRulePlan>,
5297        max_iterations: usize,
5298        timeout: Duration,
5299        graph_ctx: Arc<GraphExecutionContext>,
5300        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5301        storage: Arc<StorageManager>,
5302        schema_info: Arc<UniSchema>,
5303        params: HashMap<String, Value>,
5304        derived_scan_registry: Arc<DerivedScanRegistry>,
5305        output_schema: SchemaRef,
5306        max_derived_bytes: usize,
5307        derivation_tracker: Option<Arc<ProvenanceStore>>,
5308        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5309        strict_probability_domain: bool,
5310        probability_epsilon: f64,
5311        exact_probability: bool,
5312        max_bdd_variables: usize,
5313        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5314        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5315        top_k_proofs: usize,
5316        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5317        semiring_kind: SemiringKind,
5318        classifier_registry: Arc<ClassifierRegistry>,
5319    ) -> Self {
5320        Self::new_with_semiring_classifiers_and_cache(
5321            rules,
5322            max_iterations,
5323            timeout,
5324            graph_ctx,
5325            session_ctx,
5326            storage,
5327            schema_info,
5328            params,
5329            derived_scan_registry,
5330            output_schema,
5331            max_derived_bytes,
5332            derivation_tracker,
5333            iteration_counts,
5334            strict_probability_domain,
5335            probability_epsilon,
5336            exact_probability,
5337            max_bdd_variables,
5338            warnings_slot,
5339            approximate_slot,
5340            top_k_proofs,
5341            timeout_flag,
5342            semiring_kind,
5343            classifier_registry,
5344            None,
5345            None,
5346        )
5347    }
5348
5349    /// Phase B follow-up: full constructor accepting an optional
5350    /// memoization cache. Existing callers default to `None` (no cache);
5351    /// the impl_locy.rs entry passes the user's `config.classifier_cache`.
5352    #[expect(
5353        clippy::too_many_arguments,
5354        reason = "FixpointExec configuration needs all context"
5355    )]
5356    pub fn new_with_semiring_classifiers_and_cache(
5357        rules: Vec<FixpointRulePlan>,
5358        max_iterations: usize,
5359        timeout: Duration,
5360        graph_ctx: Arc<GraphExecutionContext>,
5361        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5362        storage: Arc<StorageManager>,
5363        schema_info: Arc<UniSchema>,
5364        params: HashMap<String, Value>,
5365        derived_scan_registry: Arc<DerivedScanRegistry>,
5366        output_schema: SchemaRef,
5367        max_derived_bytes: usize,
5368        derivation_tracker: Option<Arc<ProvenanceStore>>,
5369        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5370        strict_probability_domain: bool,
5371        probability_epsilon: f64,
5372        exact_probability: bool,
5373        max_bdd_variables: usize,
5374        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5375        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5376        top_k_proofs: usize,
5377        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5378        semiring_kind: SemiringKind,
5379        classifier_registry: Arc<ClassifierRegistry>,
5380        classifier_cache: Option<Arc<ModelInvocationCache>>,
5381        classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
5382    ) -> Self {
5383        let properties = compute_plan_properties(Arc::clone(&output_schema));
5384        Self {
5385            rules,
5386            max_iterations,
5387            timeout,
5388            graph_ctx,
5389            session_ctx,
5390            storage,
5391            schema_info,
5392            params,
5393            derived_scan_registry,
5394            output_schema,
5395            properties,
5396            metrics: ExecutionPlanMetricsSet::new(),
5397            max_derived_bytes,
5398            derivation_tracker,
5399            iteration_counts,
5400            strict_probability_domain,
5401            probability_epsilon,
5402            exact_probability,
5403            max_bdd_variables,
5404            warnings_slot,
5405            approximate_slot,
5406            top_k_proofs,
5407            timeout_flag,
5408            semiring_kind,
5409            classifier_registry,
5410            classifier_cache,
5411            classifier_provenance_store,
5412        }
5413    }
5414
5415    /// Returns the shared iteration counts slot for post-execution inspection.
5416    pub fn iteration_counts(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
5417        Arc::clone(&self.iteration_counts)
5418    }
5419}
5420
5421impl DisplayAs for FixpointExec {
5422    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5423        write!(
5424            f,
5425            "FixpointExec: rules=[{}], max_iter={}, timeout={:?}",
5426            self.rules
5427                .iter()
5428                .map(|r| r.name.as_str())
5429                .collect::<Vec<_>>()
5430                .join(", "),
5431            self.max_iterations,
5432            self.timeout,
5433        )
5434    }
5435}
5436
5437impl ExecutionPlan for FixpointExec {
5438    fn name(&self) -> &str {
5439        "FixpointExec"
5440    }
5441
5442    fn as_any(&self) -> &dyn Any {
5443        self
5444    }
5445
5446    fn schema(&self) -> SchemaRef {
5447        Arc::clone(&self.output_schema)
5448    }
5449
5450    fn properties(&self) -> &Arc<PlanProperties> {
5451        &self.properties
5452    }
5453
5454    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
5455        // No physical children — clause bodies are re-planned each iteration
5456        vec![]
5457    }
5458
5459    fn with_new_children(
5460        self: Arc<Self>,
5461        children: Vec<Arc<dyn ExecutionPlan>>,
5462    ) -> DFResult<Arc<dyn ExecutionPlan>> {
5463        if !children.is_empty() {
5464            return Err(datafusion::error::DataFusionError::Plan(
5465                "FixpointExec has no children".to_string(),
5466            ));
5467        }
5468        Ok(self)
5469    }
5470
5471    fn execute(
5472        &self,
5473        partition: usize,
5474        _context: Arc<TaskContext>,
5475    ) -> DFResult<SendableRecordBatchStream> {
5476        let metrics = BaselineMetrics::new(&self.metrics, partition);
5477
5478        // Clone all fields for the async closure
5479        let rules = self
5480            .rules
5481            .iter()
5482            .map(|r| {
5483                // We need to clone the FixpointRulePlan, but it contains LogicalPlan
5484                // which doesn't implement Clone traditionally. However, our LogicalPlan
5485                // does implement Clone since it's an enum.
5486                FixpointRulePlan {
5487                    name: r.name.clone(),
5488                    clauses: r
5489                        .clauses
5490                        .iter()
5491                        .map(|c| FixpointClausePlan {
5492                            body_logical: c.body_logical.clone(),
5493                            is_ref_bindings: c.is_ref_bindings.clone(),
5494                            priority: c.priority,
5495                            along_bindings: c.along_bindings.clone(),
5496                            model_invocations: c.model_invocations.clone(),
5497                        })
5498                        .collect(),
5499                    yield_schema: Arc::clone(&r.yield_schema),
5500                    key_column_indices: r.key_column_indices.clone(),
5501                    priority: r.priority,
5502                    has_fold: r.has_fold,
5503                    fold_bindings: r.fold_bindings.clone(),
5504                    having: r.having.clone(),
5505                    has_best_by: r.has_best_by,
5506                    best_by_criteria: r.best_by_criteria.clone(),
5507                    has_priority: r.has_priority,
5508                    deterministic: r.deterministic,
5509                    prob_column_name: r.prob_column_name.clone(),
5510                    non_linear: r.non_linear,
5511                }
5512            })
5513            .collect();
5514
5515        let max_iterations = self.max_iterations;
5516        let timeout = self.timeout;
5517        let graph_ctx = Arc::clone(&self.graph_ctx);
5518        let session_ctx = Arc::clone(&self.session_ctx);
5519        let storage = Arc::clone(&self.storage);
5520        let schema_info = Arc::clone(&self.schema_info);
5521        let params = self.params.clone();
5522        let registry = Arc::clone(&self.derived_scan_registry);
5523        let output_schema = Arc::clone(&self.output_schema);
5524        let max_derived_bytes = self.max_derived_bytes;
5525        let derivation_tracker = self.derivation_tracker.clone();
5526        let iteration_counts = Arc::clone(&self.iteration_counts);
5527        let strict_probability_domain = self.strict_probability_domain;
5528        let probability_epsilon = self.probability_epsilon;
5529        let exact_probability = self.exact_probability;
5530        let max_bdd_variables = self.max_bdd_variables;
5531        let warnings_slot = Arc::clone(&self.warnings_slot);
5532        let approximate_slot = Arc::clone(&self.approximate_slot);
5533        let top_k_proofs = self.top_k_proofs;
5534        let timeout_flag = Arc::clone(&self.timeout_flag);
5535        let semiring_kind = self.semiring_kind;
5536        let classifier_registry = Arc::clone(&self.classifier_registry);
5537        let classifier_cache = self.classifier_cache.as_ref().map(Arc::clone);
5538        let classifier_provenance_store = self.classifier_provenance_store.as_ref().map(Arc::clone);
5539
5540        let fut = async move {
5541            run_fixpoint_loop(
5542                rules,
5543                max_iterations,
5544                timeout,
5545                graph_ctx,
5546                session_ctx,
5547                storage,
5548                schema_info,
5549                params,
5550                registry,
5551                output_schema,
5552                max_derived_bytes,
5553                derivation_tracker,
5554                iteration_counts,
5555                strict_probability_domain,
5556                probability_epsilon,
5557                exact_probability,
5558                max_bdd_variables,
5559                warnings_slot,
5560                approximate_slot,
5561                top_k_proofs,
5562                timeout_flag,
5563                semiring_kind,
5564                classifier_registry,
5565                classifier_cache,
5566                classifier_provenance_store,
5567            )
5568            .await
5569        };
5570
5571        Ok(Box::pin(FixpointStream {
5572            state: FixpointStreamState::Running(Box::pin(fut)),
5573            schema: Arc::clone(&self.output_schema),
5574            metrics,
5575        }))
5576    }
5577
5578    fn metrics(&self) -> Option<MetricsSet> {
5579        Some(self.metrics.clone_inner())
5580    }
5581}
5582
5583// ---------------------------------------------------------------------------
5584// FixpointStream — async state machine for streaming results
5585// ---------------------------------------------------------------------------
5586
/// Phases of a `FixpointStream`, advanced by its `poll_next`:
/// `Running` → `Emitting` → `Done`, or straight to `Done` when the loop
/// fails or yields no batches.
enum FixpointStreamState {
    /// Fixpoint loop is running.
    ///
    /// Holds the boxed future that resolves to every result batch at once.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// Emitting accumulated result batches one at a time.
    ///
    /// The `usize` is the index of the next batch to hand out.
    Emitting(Vec<RecordBatch>, usize),
    /// All batches emitted.
    Done,
}
5595
/// Record-batch stream returned by `FixpointExec::execute`.
struct FixpointStream {
    /// Current phase of the stream (running, emitting, or done).
    state: FixpointStreamState,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Baseline metrics: compute time is timed per poll and output rows are
    /// recorded per emitted batch.
    metrics: BaselineMetrics,
}
5601
5602impl Stream for FixpointStream {
5603    type Item = DFResult<RecordBatch>;
5604
5605    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
5606        let this = self.get_mut();
5607        let metrics = this.metrics.clone();
5608        let _timer = metrics.elapsed_compute().timer();
5609        loop {
5610            match &mut this.state {
5611                FixpointStreamState::Running(fut) => match fut.as_mut().poll(cx) {
5612                    Poll::Ready(Ok(batches)) => {
5613                        if batches.is_empty() {
5614                            this.state = FixpointStreamState::Done;
5615                            return Poll::Ready(None);
5616                        }
5617                        this.state = FixpointStreamState::Emitting(batches, 0);
5618                        // Loop to emit first batch
5619                    }
5620                    Poll::Ready(Err(e)) => {
5621                        this.state = FixpointStreamState::Done;
5622                        return Poll::Ready(Some(Err(e)));
5623                    }
5624                    Poll::Pending => return Poll::Pending,
5625                },
5626                FixpointStreamState::Emitting(batches, idx) => {
5627                    if *idx >= batches.len() {
5628                        this.state = FixpointStreamState::Done;
5629                        return Poll::Ready(None);
5630                    }
5631                    let batch = batches[*idx].clone();
5632                    *idx += 1;
5633                    this.metrics.record_output(batch.num_rows());
5634                    return Poll::Ready(Some(Ok(batch)));
5635                }
5636                FixpointStreamState::Done => return Poll::Ready(None),
5637            }
5638        }
5639    }
5640}
5641
5642impl RecordBatchStream for FixpointStream {
5643    fn schema(&self) -> SchemaRef {
5644        Arc::clone(&self.schema)
5645    }
5646}
5647
5648// ---------------------------------------------------------------------------
5649// Unit tests
5650// ---------------------------------------------------------------------------
5651
5652#[cfg(test)]
5653mod tests {
5654    use super::*;
5655    use arrow_array::{Float64Array, Int64Array, StringArray};
5656    use arrow_schema::{DataType, Field, Schema};
5657
5658    fn test_schema() -> SchemaRef {
5659        Arc::new(Schema::new(vec![
5660            Field::new("name", DataType::Utf8, true),
5661            Field::new("value", DataType::Int64, true),
5662        ]))
5663    }
5664
5665    fn make_batch(names: &[&str], values: &[i64]) -> RecordBatch {
5666        RecordBatch::try_new(
5667            test_schema(),
5668            vec![
5669                Arc::new(StringArray::from(
5670                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
5671                )),
5672                Arc::new(Int64Array::from(values.to_vec())),
5673            ],
5674        )
5675        .unwrap()
5676    }
5677
5678    // --- FixpointState dedup tests ---
5679
5680    #[tokio::test]
5681    async fn test_fixpoint_state_empty_facts_adds_all() {
5682        let schema = test_schema();
5683        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5684
5685        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
5686        let changed = state.merge_delta(vec![batch], None).await.unwrap();
5687
5688        assert!(changed);
5689        assert_eq!(state.all_facts().len(), 1);
5690        assert_eq!(state.all_facts()[0].num_rows(), 3);
5691        assert_eq!(state.all_delta().len(), 1);
5692        assert_eq!(state.all_delta()[0].num_rows(), 3);
5693    }
5694
5695    #[tokio::test]
5696    async fn test_fixpoint_state_exact_duplicates_excluded() {
5697        let schema = test_schema();
5698        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5699
5700        let batch1 = make_batch(&["a", "b"], &[1, 2]);
5701        state.merge_delta(vec![batch1], None).await.unwrap();
5702
5703        // Same rows again
5704        let batch2 = make_batch(&["a", "b"], &[1, 2]);
5705        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
5706        assert!(!changed);
5707        assert!(
5708            state.all_delta().is_empty() || state.all_delta().iter().all(|b| b.num_rows() == 0)
5709        );
5710    }
5711
5712    #[tokio::test]
5713    async fn test_fixpoint_state_partial_overlap() {
5714        let schema = test_schema();
5715        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5716
5717        let batch1 = make_batch(&["a", "b"], &[1, 2]);
5718        state.merge_delta(vec![batch1], None).await.unwrap();
5719
5720        // "a":1 is duplicate, "c":3 is new
5721        let batch2 = make_batch(&["a", "c"], &[1, 3]);
5722        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
5723        assert!(changed);
5724
5725        // Delta should have only "c":3
5726        let delta_rows: usize = state.all_delta().iter().map(|b| b.num_rows()).sum();
5727        assert_eq!(delta_rows, 1);
5728
5729        // Total facts: a:1, b:2, c:3
5730        let total_rows: usize = state.all_facts().iter().map(|b| b.num_rows()).sum();
5731        assert_eq!(total_rows, 3);
5732    }
5733
5734    #[tokio::test]
5735    async fn test_fixpoint_state_convergence() {
5736        let schema = test_schema();
5737        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5738
5739        let batch = make_batch(&["a"], &[1]);
5740        state.merge_delta(vec![batch], None).await.unwrap();
5741
5742        // Empty candidates → converged
5743        let changed = state.merge_delta(vec![], None).await.unwrap();
5744        assert!(!changed);
5745        assert!(state.is_converged());
5746    }
5747
5748    // --- RowDedupState tests ---
5749
5750    #[test]
5751    fn test_row_dedup_persistent_across_calls() {
5752        // RowDedupState should remember rows from the first call so the second
5753        // call does not re-accept them (O(M) per iteration, no facts re-scan).
5754        let schema = test_schema();
5755        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5756
5757        let batch1 = make_batch(&["a", "b"], &[1, 2]);
5758        let delta1 = rd.compute_delta(&[batch1], &schema).unwrap();
5759        // First call: both rows are new.
5760        let rows1: usize = delta1.iter().map(|b| b.num_rows()).sum();
5761        assert_eq!(rows1, 2);
5762
5763        // Second call with same rows: seen set already has them → empty delta.
5764        let batch2 = make_batch(&["a", "b"], &[1, 2]);
5765        let delta2 = rd.compute_delta(&[batch2], &schema).unwrap();
5766        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
5767        assert_eq!(rows2, 0);
5768
5769        // Third call with one old + one new: only the new row is returned.
5770        let batch3 = make_batch(&["a", "c"], &[1, 3]);
5771        let delta3 = rd.compute_delta(&[batch3], &schema).unwrap();
5772        let rows3: usize = delta3.iter().map(|b| b.num_rows()).sum();
5773        assert_eq!(rows3, 1);
5774    }
5775
5776    #[test]
5777    fn test_row_dedup_null_handling() {
5778        use arrow_array::StringArray;
5779        use arrow_schema::{DataType, Field, Schema};
5780
5781        let schema: SchemaRef = Arc::new(Schema::new(vec![
5782            Field::new("a", DataType::Utf8, true),
5783            Field::new("b", DataType::Int64, true),
5784        ]));
5785        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5786
5787        // Two rows: (NULL, 1) and (NULL, 1) — same NULLs → duplicate.
5788        let batch_nulls = RecordBatch::try_new(
5789            Arc::clone(&schema),
5790            vec![
5791                Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])),
5792                Arc::new(arrow_array::Int64Array::from(vec![1i64, 1i64])),
5793            ],
5794        )
5795        .unwrap();
5796        let delta = rd.compute_delta(&[batch_nulls], &schema).unwrap();
5797        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
5798        assert_eq!(rows, 1, "two identical NULL rows should be deduped to one");
5799
5800        // (NULL, 2) — NULL in same col but different non-null col → distinct.
5801        let batch_diff = RecordBatch::try_new(
5802            Arc::clone(&schema),
5803            vec![
5804                Arc::new(StringArray::from(vec![None::<&str>])),
5805                Arc::new(arrow_array::Int64Array::from(vec![2i64])),
5806            ],
5807        )
5808        .unwrap();
5809        let delta2 = rd.compute_delta(&[batch_diff], &schema).unwrap();
5810        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
5811        assert_eq!(rows2, 1, "(NULL, 2) is distinct from (NULL, 1)");
5812    }
5813
5814    #[test]
5815    fn test_row_dedup_within_candidate_dedup() {
5816        // Duplicate rows within a single candidate batch should be collapsed to one.
5817        let schema = test_schema();
5818        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5819
5820        // Batch with three rows: a:1, a:1, b:2 — "a:1" appears twice.
5821        let batch = make_batch(&["a", "a", "b"], &[1, 1, 2]);
5822        let delta = rd.compute_delta(&[batch], &schema).unwrap();
5823        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
5824        assert_eq!(rows, 2, "within-batch dup should be collapsed: a:1, b:2");
5825    }
5826
5827    // --- Float rounding tests ---
5828
5829    #[test]
5830    fn test_round_float_columns_near_duplicates() {
5831        let schema = Arc::new(Schema::new(vec![
5832            Field::new("name", DataType::Utf8, true),
5833            Field::new("dist", DataType::Float64, true),
5834        ]));
5835        let batch = RecordBatch::try_new(
5836            Arc::clone(&schema),
5837            vec![
5838                Arc::new(StringArray::from(vec![Some("a"), Some("a")])),
5839                Arc::new(Float64Array::from(vec![1.0000000000001, 1.0000000000002])),
5840            ],
5841        )
5842        .unwrap();
5843
5844        let rounded = round_float_columns(&[batch]);
5845        assert_eq!(rounded.len(), 1);
5846        let col = rounded[0]
5847            .column(1)
5848            .as_any()
5849            .downcast_ref::<Float64Array>()
5850            .unwrap();
5851        // Both should round to same value
5852        assert_eq!(col.value(0), col.value(1));
5853    }
5854
5855    // --- DerivedScanRegistry tests ---
5856
5857    #[test]
5858    fn test_registry_write_read_round_trip() {
5859        let schema = test_schema();
5860        let data = Arc::new(RwLock::new(Vec::new()));
5861        let mut reg = DerivedScanRegistry::new();
5862        reg.add(DerivedScanEntry {
5863            scan_index: 0,
5864            rule_name: "reachable".into(),
5865            is_self_ref: true,
5866            data: Arc::clone(&data),
5867            schema: Arc::clone(&schema),
5868        });
5869
5870        let batch = make_batch(&["x"], &[42]);
5871        reg.write_data(0, vec![batch.clone()]);
5872
5873        let entry = reg.get(0).unwrap();
5874        let guard = entry.data.read();
5875        assert_eq!(guard.len(), 1);
5876        assert_eq!(guard[0].num_rows(), 1);
5877    }
5878
5879    #[test]
5880    fn test_registry_entries_for_rule() {
5881        let schema = test_schema();
5882        let mut reg = DerivedScanRegistry::new();
5883        reg.add(DerivedScanEntry {
5884            scan_index: 0,
5885            rule_name: "r1".into(),
5886            is_self_ref: true,
5887            data: Arc::new(RwLock::new(Vec::new())),
5888            schema: Arc::clone(&schema),
5889        });
5890        reg.add(DerivedScanEntry {
5891            scan_index: 1,
5892            rule_name: "r2".into(),
5893            is_self_ref: false,
5894            data: Arc::new(RwLock::new(Vec::new())),
5895            schema: Arc::clone(&schema),
5896        });
5897        reg.add(DerivedScanEntry {
5898            scan_index: 2,
5899            rule_name: "r1".into(),
5900            is_self_ref: false,
5901            data: Arc::new(RwLock::new(Vec::new())),
5902            schema: Arc::clone(&schema),
5903        });
5904
5905        assert_eq!(reg.entries_for_rule("r1").len(), 2);
5906        assert_eq!(reg.entries_for_rule("r2").len(), 1);
5907        assert_eq!(reg.entries_for_rule("r3").len(), 0);
5908    }
5909
5910    // --- MonotonicAggState tests ---
5911
5912    #[test]
5913    fn test_monotonic_agg_update_and_stability() {
5914        let bindings = vec![MonotonicFoldBinding {
5915            fold_name: "total".into(),
5916            aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::SumAgg),
5917            input_col_index: 1,
5918            input_col_name: None,
5919        }];
5920        let mut agg = MonotonicAggState::new(bindings);
5921
5922        // First update
5923        let batch = make_batch(&["a"], &[10]);
5924        agg.snapshot();
5925        let changed = agg
5926            .update(&[0], &[batch], false, SemiringKind::AddMultProb)
5927            .unwrap();
5928        assert!(changed);
5929        assert!(!agg.is_stable()); // changed since snapshot
5930
5931        // Snapshot and check stability with no new data
5932        agg.snapshot();
5933        let changed = agg
5934            .update(&[0], &[], false, SemiringKind::AddMultProb)
5935            .unwrap();
5936        assert!(!changed);
5937        assert!(agg.is_stable());
5938    }
5939
5940    // --- Memory limit test ---
5941
5942    #[tokio::test]
5943    async fn test_memory_limit_exceeded() {
5944        let schema = test_schema();
5945        // Set a tiny limit
5946        let mut state = FixpointState::new("test".into(), schema, vec![0], 1, None, false);
5947
5948        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
5949        let result = state.merge_delta(vec![batch], None).await;
5950        assert!(result.is_err());
5951        let err = result.unwrap_err().to_string();
5952        assert!(err.contains("memory limit"), "Error was: {}", err);
5953    }
5954
5955    // --- FixpointStream lifecycle test ---
5956
5957    #[tokio::test]
5958    async fn test_fixpoint_stream_emitting() {
5959        use futures::StreamExt;
5960
5961        let schema = test_schema();
5962        let batch1 = make_batch(&["a"], &[1]);
5963        let batch2 = make_batch(&["b"], &[2]);
5964
5965        let metrics = ExecutionPlanMetricsSet::new();
5966        let baseline = BaselineMetrics::new(&metrics, 0);
5967
5968        let mut stream = FixpointStream {
5969            state: FixpointStreamState::Emitting(vec![batch1, batch2], 0),
5970            schema,
5971            metrics: baseline,
5972        };
5973
5974        let stream = Pin::new(&mut stream);
5975        let batches: Vec<RecordBatch> = stream.filter_map(|r| async { r.ok() }).collect().await;
5976
5977        assert_eq!(batches.len(), 2);
5978        assert_eq!(batches[0].num_rows(), 1);
5979        assert_eq!(batches[1].num_rows(), 1);
5980    }
5981
5982    // ── MonotonicAggState MNOR/MPROD tests ──────────────────────────────
5983
5984    fn make_f64_batch(names: &[&str], values: &[f64]) -> RecordBatch {
5985        let schema = Arc::new(Schema::new(vec![
5986            Field::new("name", DataType::Utf8, true),
5987            Field::new("value", DataType::Float64, true),
5988        ]));
5989        RecordBatch::try_new(
5990            schema,
5991            vec![
5992                Arc::new(StringArray::from(
5993                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
5994                )),
5995                Arc::new(Float64Array::from(values.to_vec())),
5996            ],
5997        )
5998        .unwrap()
5999    }
6000
6001    fn make_nor_binding() -> Vec<MonotonicFoldBinding> {
6002        vec![MonotonicFoldBinding {
6003            fold_name: "prob".into(),
6004            aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::MnorAgg),
6005            input_col_index: 1,
6006            input_col_name: None,
6007        }]
6008    }
6009
6010    fn make_prod_binding() -> Vec<MonotonicFoldBinding> {
6011        vec![MonotonicFoldBinding {
6012            fold_name: "prob".into(),
6013            aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::MprodAgg),
6014            input_col_index: 1,
6015            input_col_name: None,
6016        }]
6017    }
6018
6019    fn acc_key(name: &str) -> (Vec<ScalarKey>, String) {
6020        (vec![ScalarKey::Utf8(name.to_string())], "prob".to_string())
6021    }
6022
6023    #[test]
6024    fn test_monotonic_nor_first_update() {
6025        let mut agg = MonotonicAggState::new(make_nor_binding());
6026        let batch = make_f64_batch(&["a"], &[0.3]);
6027        let changed = agg
6028            .update(&[0], &[batch], false, SemiringKind::AddMultProb)
6029            .unwrap();
6030        assert!(changed);
6031        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6032        assert!((val - 0.3).abs() < 1e-10, "expected 0.3, got {}", val);
6033    }
6034
6035    #[test]
6036    fn test_monotonic_nor_two_updates() {
6037        // Incremental NOR: acc = 1-(1-0.3)(1-0.5) = 0.65
6038        let mut agg = MonotonicAggState::new(make_nor_binding());
6039        let batch1 = make_f64_batch(&["a"], &[0.3]);
6040        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6041            .unwrap();
6042        let batch2 = make_f64_batch(&["a"], &[0.5]);
6043        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6044            .unwrap();
6045        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6046        assert!((val - 0.65).abs() < 1e-10, "expected 0.65, got {}", val);
6047    }
6048
6049    #[test]
6050    fn test_monotonic_prod_first_update() {
6051        let mut agg = MonotonicAggState::new(make_prod_binding());
6052        let batch = make_f64_batch(&["a"], &[0.6]);
6053        let changed = agg
6054            .update(&[0], &[batch], false, SemiringKind::AddMultProb)
6055            .unwrap();
6056        assert!(changed);
6057        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6058        assert!((val - 0.6).abs() < 1e-10, "expected 0.6, got {}", val);
6059    }
6060
6061    #[test]
6062    fn test_monotonic_prod_two_updates() {
6063        // Incremental PROD: acc = 0.6 * 0.8 = 0.48
6064        let mut agg = MonotonicAggState::new(make_prod_binding());
6065        let batch1 = make_f64_batch(&["a"], &[0.6]);
6066        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6067            .unwrap();
6068        let batch2 = make_f64_batch(&["a"], &[0.8]);
6069        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6070            .unwrap();
6071        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6072        assert!((val - 0.48).abs() < 1e-10, "expected 0.48, got {}", val);
6073    }
6074
6075    #[test]
6076    fn test_monotonic_nor_stability() {
6077        let mut agg = MonotonicAggState::new(make_nor_binding());
6078        let batch = make_f64_batch(&["a"], &[0.3]);
6079        agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6080            .unwrap();
6081        agg.snapshot();
6082        let changed = agg
6083            .update(&[0], &[], false, SemiringKind::AddMultProb)
6084            .unwrap();
6085        assert!(!changed);
6086        assert!(agg.is_stable());
6087    }
6088
6089    #[test]
6090    fn test_monotonic_prod_stability() {
6091        let mut agg = MonotonicAggState::new(make_prod_binding());
6092        let batch = make_f64_batch(&["a"], &[0.6]);
6093        agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6094            .unwrap();
6095        agg.snapshot();
6096        let changed = agg
6097            .update(&[0], &[], false, SemiringKind::AddMultProb)
6098            .unwrap();
6099        assert!(!changed);
6100        assert!(agg.is_stable());
6101    }
6102
6103    #[test]
6104    fn test_monotonic_nor_multi_group() {
6105        // (a,0.3),(b,0.5) then (a,0.5),(b,0.2) → a=0.65, b=0.6
6106        let mut agg = MonotonicAggState::new(make_nor_binding());
6107        let batch1 = make_f64_batch(&["a", "b"], &[0.3, 0.5]);
6108        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6109            .unwrap();
6110        let batch2 = make_f64_batch(&["a", "b"], &[0.5, 0.2]);
6111        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6112            .unwrap();
6113
6114        let val_a = agg.get_accumulator(&acc_key("a")).unwrap();
6115        let val_b = agg.get_accumulator(&acc_key("b")).unwrap();
6116        assert!(
6117            (val_a - 0.65).abs() < 1e-10,
6118            "expected a=0.65, got {}",
6119            val_a
6120        );
6121        assert!((val_b - 0.6).abs() < 1e-10, "expected b=0.6, got {}", val_b);
6122    }
6123
6124    #[test]
6125    fn test_monotonic_prod_zero_absorbing() {
6126        // Zero absorbs: once 0.0, all further updates stay 0.0
6127        let mut agg = MonotonicAggState::new(make_prod_binding());
6128        let batch1 = make_f64_batch(&["a"], &[0.5]);
6129        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6130            .unwrap();
6131        let batch2 = make_f64_batch(&["a"], &[0.0]);
6132        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6133            .unwrap();
6134
6135        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6136        assert!((val - 0.0).abs() < 1e-10, "expected 0.0, got {}", val);
6137
6138        // Further updates don't change the absorbing zero
6139        agg.snapshot();
6140        let batch3 = make_f64_batch(&["a"], &[0.5]);
6141        let changed = agg
6142            .update(&[0], &[batch3], false, SemiringKind::AddMultProb)
6143            .unwrap();
6144        assert!(!changed);
6145        assert!(agg.is_stable());
6146    }
6147
6148    #[test]
6149    fn test_monotonic_nor_clamping() {
6150        // 1.5 clamped to 1.0: acc = 1-(1-0)(1-1) = 1.0
6151        let mut agg = MonotonicAggState::new(make_nor_binding());
6152        let batch = make_f64_batch(&["a"], &[1.5]);
6153        agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6154            .unwrap();
6155        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6156        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
6157    }
6158
6159    #[test]
6160    fn test_monotonic_nor_absorbing() {
6161        // p=1.0 absorbs: 0.3 then 1.0 → 1.0
6162        let mut agg = MonotonicAggState::new(make_nor_binding());
6163        let batch1 = make_f64_batch(&["a"], &[0.3]);
6164        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6165            .unwrap();
6166        let batch2 = make_f64_batch(&["a"], &[1.0]);
6167        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6168            .unwrap();
6169        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6170        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
6171    }
6172
6173    // ── MonotonicAggState strict mode tests (Phase 5) ─────────────────────
6174
6175    #[test]
6176    fn test_monotonic_agg_strict_nor_rejects() {
6177        let mut agg = MonotonicAggState::new(make_nor_binding());
6178        let batch = make_f64_batch(&["a"], &[1.5]);
6179        let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6180        assert!(result.is_err());
6181        let err = result.unwrap_err().to_string();
6182        assert!(
6183            err.contains("strict_probability_domain"),
6184            "Expected strict error, got: {}",
6185            err
6186        );
6187    }
6188
6189    #[test]
6190    fn test_monotonic_agg_strict_prod_rejects() {
6191        let mut agg = MonotonicAggState::new(make_prod_binding());
6192        let batch = make_f64_batch(&["a"], &[2.0]);
6193        let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6194        assert!(result.is_err());
6195        let err = result.unwrap_err().to_string();
6196        assert!(
6197            err.contains("strict_probability_domain"),
6198            "Expected strict error, got: {}",
6199            err
6200        );
6201    }
6202
6203    #[test]
6204    fn test_monotonic_agg_strict_accepts_valid() {
6205        let mut agg = MonotonicAggState::new(make_nor_binding());
6206        let batch = make_f64_batch(&["a"], &[0.5]);
6207        let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6208        assert!(result.is_ok());
6209        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6210        assert!((val - 0.5).abs() < 1e-10, "expected 0.5, got {}", val);
6211    }
6212
6213    // ── Complement function unit tests (Phase 4) ──────────────────────────
6214
6215    fn make_vid_prob_batch(vids: &[u64], probs: &[f64]) -> RecordBatch {
6216        use arrow_array::UInt64Array;
6217        let schema = Arc::new(Schema::new(vec![
6218            Field::new("vid", DataType::UInt64, true),
6219            Field::new("prob", DataType::Float64, true),
6220        ]));
6221        RecordBatch::try_new(
6222            schema,
6223            vec![
6224                Arc::new(UInt64Array::from(vids.to_vec())),
6225                Arc::new(Float64Array::from(probs.to_vec())),
6226            ],
6227        )
6228        .unwrap()
6229    }
6230
6231    #[test]
6232    fn test_prob_complement_basic() {
6233        // neg has VID=1 with prob=0.7 → complement=0.3; VID=2 absent → complement=1.0
6234        let body = make_vid_prob_batch(&[1, 2], &[0.9, 0.8]);
6235        let neg = make_vid_prob_batch(&[1], &[0.7]);
6236        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6237        let result = apply_prob_complement_composite(
6238            vec![body],
6239            &[neg],
6240            &join_cols,
6241            "prob",
6242            "__complement_0",
6243        )
6244        .unwrap();
6245        assert_eq!(result.len(), 1);
6246        let batch = &result[0];
6247        let complement = batch
6248            .column_by_name("__complement_0")
6249            .unwrap()
6250            .as_any()
6251            .downcast_ref::<Float64Array>()
6252            .unwrap();
6253        // VID=1: complement = 1 - 0.7 = 0.3
6254        assert!(
6255            (complement.value(0) - 0.3).abs() < 1e-10,
6256            "expected 0.3, got {}",
6257            complement.value(0)
6258        );
6259        // VID=2: absent from neg → complement = 1.0
6260        assert!(
6261            (complement.value(1) - 1.0).abs() < 1e-10,
6262            "expected 1.0, got {}",
6263            complement.value(1)
6264        );
6265    }
6266
6267    #[test]
6268    fn test_prob_complement_noisy_or_duplicates() {
6269        // neg has VID=1 twice with prob=0.3 and prob=0.5
6270        // Combined via noisy-OR: 1-(1-0.3)(1-0.5) = 0.65
6271        // Complement = 1 - 0.65 = 0.35
6272        let body = make_vid_prob_batch(&[1], &[0.9]);
6273        let neg = make_vid_prob_batch(&[1, 1], &[0.3, 0.5]);
6274        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6275        let result = apply_prob_complement_composite(
6276            vec![body],
6277            &[neg],
6278            &join_cols,
6279            "prob",
6280            "__complement_0",
6281        )
6282        .unwrap();
6283        let batch = &result[0];
6284        let complement = batch
6285            .column_by_name("__complement_0")
6286            .unwrap()
6287            .as_any()
6288            .downcast_ref::<Float64Array>()
6289            .unwrap();
6290        assert!(
6291            (complement.value(0) - 0.35).abs() < 1e-10,
6292            "expected 0.35, got {}",
6293            complement.value(0)
6294        );
6295    }
6296
6297    #[test]
6298    fn test_prob_complement_empty_neg() {
6299        // Empty neg_facts → body passes through with complement=1.0
6300        let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
6301        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6302        let result =
6303            apply_prob_complement_composite(vec![body], &[], &join_cols, "prob", "__complement_0")
6304                .unwrap();
6305        let batch = &result[0];
6306        let complement = batch
6307            .column_by_name("__complement_0")
6308            .unwrap()
6309            .as_any()
6310            .downcast_ref::<Float64Array>()
6311            .unwrap();
6312        for i in 0..2 {
6313            assert!(
6314                (complement.value(i) - 1.0).abs() < 1e-10,
6315                "row {}: expected 1.0, got {}",
6316                i,
6317                complement.value(i)
6318            );
6319        }
6320    }
6321
6322    #[test]
6323    fn test_anti_join_basic() {
6324        // body [1,2,3], neg [2] → result [1,3]
6325        use arrow_array::UInt64Array;
6326        let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
6327        let neg = make_vid_prob_batch(&[2], &[0.0]);
6328        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6329        let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
6330        assert_eq!(result.len(), 1);
6331        let batch = &result[0];
6332        assert_eq!(batch.num_rows(), 2);
6333        let vids = batch
6334            .column_by_name("vid")
6335            .unwrap()
6336            .as_any()
6337            .downcast_ref::<UInt64Array>()
6338            .unwrap();
6339        assert_eq!(vids.value(0), 1);
6340        assert_eq!(vids.value(1), 3);
6341    }
6342
6343    #[test]
6344    fn test_anti_join_empty_neg() {
6345        // Empty neg → all rows kept
6346        let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
6347        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6348        let result = apply_anti_join_composite(vec![body], &[], &join_cols).unwrap();
6349        assert_eq!(result.len(), 1);
6350        assert_eq!(result[0].num_rows(), 3);
6351    }
6352
6353    #[test]
6354    fn test_anti_join_all_excluded() {
6355        // neg covers all body rows → empty result
6356        let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
6357        let neg = make_vid_prob_batch(&[1, 2], &[0.0, 0.0]);
6358        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6359        let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
6360        let total: usize = result.iter().map(|b| b.num_rows()).sum();
6361        assert_eq!(total, 0);
6362    }
6363
6364    #[test]
6365    fn test_multiply_prob_single_complement() {
6366        // prob=0.8, complement=0.5 → output prob=0.4; complement col removed
6367        let body = make_vid_prob_batch(&[1], &[0.8]);
6368        // Add a complement column
6369        let complement_arr = Float64Array::from(vec![0.5]);
6370        let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
6371        cols.push(Arc::new(complement_arr));
6372        let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
6373        fields.push(Arc::new(Field::new(
6374            "__complement_0",
6375            DataType::Float64,
6376            true,
6377        )));
6378        let schema = Arc::new(Schema::new(fields));
6379        let batch = RecordBatch::try_new(schema, cols).unwrap();
6380
6381        let result =
6382            multiply_prob_factors(vec![batch], Some("prob"), &["__complement_0".to_string()])
6383                .unwrap();
6384        assert_eq!(result.len(), 1);
6385        let out = &result[0];
6386        // Complement column should be removed
6387        assert!(out.column_by_name("__complement_0").is_none());
6388        let prob = out
6389            .column_by_name("prob")
6390            .unwrap()
6391            .as_any()
6392            .downcast_ref::<Float64Array>()
6393            .unwrap();
6394        assert!(
6395            (prob.value(0) - 0.4).abs() < 1e-10,
6396            "expected 0.4, got {}",
6397            prob.value(0)
6398        );
6399    }
6400
6401    #[test]
6402    fn test_multiply_prob_multiple_complements() {
6403        // prob=0.8, c1=0.5, c2=0.6 → 0.8×0.5×0.6=0.24
6404        let body = make_vid_prob_batch(&[1], &[0.8]);
6405        let c1 = Float64Array::from(vec![0.5]);
6406        let c2 = Float64Array::from(vec![0.6]);
6407        let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
6408        cols.push(Arc::new(c1));
6409        cols.push(Arc::new(c2));
6410        let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
6411        fields.push(Arc::new(Field::new("__c1", DataType::Float64, true)));
6412        fields.push(Arc::new(Field::new("__c2", DataType::Float64, true)));
6413        let schema = Arc::new(Schema::new(fields));
6414        let batch = RecordBatch::try_new(schema, cols).unwrap();
6415
6416        let result = multiply_prob_factors(
6417            vec![batch],
6418            Some("prob"),
6419            &["__c1".to_string(), "__c2".to_string()],
6420        )
6421        .unwrap();
6422        let out = &result[0];
6423        assert!(out.column_by_name("__c1").is_none());
6424        assert!(out.column_by_name("__c2").is_none());
6425        let prob = out
6426            .column_by_name("prob")
6427            .unwrap()
6428            .as_any()
6429            .downcast_ref::<Float64Array>()
6430            .unwrap();
6431        assert!(
6432            (prob.value(0) - 0.24).abs() < 1e-10,
6433            "expected 0.24, got {}",
6434            prob.value(0)
6435        );
6436    }
6437
6438    #[test]
6439    fn test_multiply_prob_no_prob_column() {
6440        // No prob column → combined complements become the output
6441        use arrow_array::UInt64Array;
6442        let schema = Arc::new(Schema::new(vec![
6443            Field::new("vid", DataType::UInt64, true),
6444            Field::new("__c1", DataType::Float64, true),
6445        ]));
6446        let batch = RecordBatch::try_new(
6447            schema,
6448            vec![
6449                Arc::new(UInt64Array::from(vec![1u64])),
6450                Arc::new(Float64Array::from(vec![0.7])),
6451            ],
6452        )
6453        .unwrap();
6454
6455        let result = multiply_prob_factors(vec![batch], None, &["__c1".to_string()]).unwrap();
6456        let out = &result[0];
6457        // __c1 should be removed since it's a complement column
6458        assert!(out.column_by_name("__c1").is_none());
6459        // Only vid column remains
6460        assert_eq!(out.num_columns(), 1);
6461    }
6462}