// Source: uni_query/query/df_graph/locy_fixpoint.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright 2024-2026 Dragonscale Team

//! Fixpoint iteration operator for recursive Locy strata.
//!
//! `FixpointExec` drives semi-naive evaluation: it repeatedly evaluates the rules
//! in a recursive stratum, feeding back deltas until no new facts are produced.

9use crate::query::df_graph::GraphExecutionContext;
10use crate::query::df_graph::common::{
11    ScalarKey, arrow_err, collect_all_partitions, compute_plan_properties, execute_subplan,
12    extract_scalar_key,
13};
14use crate::query::df_graph::locy_best_by::{BestByExec, SortCriterion};
15use crate::query::df_graph::locy_errors::LocyRuntimeError;
16use crate::query::df_graph::locy_explain::{
17    ProofTerm, ProvenanceAnnotation, ProvenanceStore, compute_proof_probability,
18};
19use crate::query::df_graph::locy_fold::{FoldBinding, FoldExec};
20use crate::query::df_graph::locy_priority::PriorityExec;
21use crate::query::df_graph::locy_program::interruption;
22use crate::query::planner::LogicalPlan;
23use arrow_array::RecordBatch;
24use arrow_row::{RowConverter, SortField};
25use arrow_schema::SchemaRef;
26use datafusion::common::JoinType;
27use datafusion::common::Result as DFResult;
28use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
29use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
30use datafusion::physical_plan::memory::MemoryStream;
31use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
32use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
33use futures::Stream;
34use parking_lot::RwLock;
35use std::any::Any;
36use std::collections::{HashMap, HashSet};
37use std::fmt;
38use std::pin::Pin;
39use std::sync::{Arc, RwLock as StdRwLock};
40use std::task::{Context, Poll};
41use std::time::{Duration, Instant};
42use uni_common::Value;
43use uni_common::core::schema::Schema as UniSchema;
44use uni_cypher::ast::Expr;
45use uni_locy::{
46    ClassifierRegistry, ModelInvocation, ModelInvocationCache, RuntimeWarning, RuntimeWarningCode,
47    SemiringKind,
48};
49use uni_store::storage::manager::StorageManager;

// ---------------------------------------------------------------------------
// DerivedScanRegistry — injection point for IS-ref data into subplans
// ---------------------------------------------------------------------------

55/// A single entry in the derived scan registry.
56///
57/// Each entry corresponds to one `LocyDerivedScan` node in the logical plan tree.
58/// The `data` handle is shared with the logical plan node so that writing data here
59/// makes it visible when the subplan is re-planned and executed.
60#[derive(Debug)]
61pub struct DerivedScanEntry {
62    /// Index matching the `scan_index` in `LocyDerivedScan`.
63    pub scan_index: usize,
64    /// Name of the rule this scan reads from.
65    pub rule_name: String,
66    /// Whether this is a self-referential scan (rule references itself).
67    pub is_self_ref: bool,
68    /// Shared data handle — write batches here to inject into subplans.
69    pub data: Arc<RwLock<Vec<RecordBatch>>>,
70    /// Schema of the derived relation.
71    pub schema: SchemaRef,
72}
73
74/// Registry of derived scan handles for fixpoint iteration.
75///
76/// During fixpoint, each clause body may reference derived relations via
77/// `LocyDerivedScan` nodes. The registry maps scan indices to shared data
78/// handles so the fixpoint loop can inject delta/full facts before each
79/// iteration.
80#[derive(Debug, Default)]
81pub struct DerivedScanRegistry {
82    entries: Vec<DerivedScanEntry>,
83}
84
85impl DerivedScanRegistry {
86    /// Create a new empty registry.
87    pub fn new() -> Self {
88        Self::default()
89    }
90
91    /// Add an entry to the registry.
92    pub fn add(&mut self, entry: DerivedScanEntry) {
93        self.entries.push(entry);
94    }
95
96    /// Get an entry by scan index.
97    pub fn get(&self, scan_index: usize) -> Option<&DerivedScanEntry> {
98        self.entries.iter().find(|e| e.scan_index == scan_index)
99    }
100
101    /// Write data into a scan entry's shared handle.
102    pub fn write_data(&self, scan_index: usize, batches: Vec<RecordBatch>) {
103        if let Some(entry) = self.get(scan_index) {
104            let mut guard = entry.data.write();
105            *guard = batches;
106        }
107    }
108
109    /// Get all entries for a given rule name.
110    pub fn entries_for_rule(&self, rule_name: &str) -> Vec<&DerivedScanEntry> {
111        self.entries
112            .iter()
113            .filter(|e| e.rule_name == rule_name)
114            .collect()
115    }
116}

// ---------------------------------------------------------------------------
// MonotonicAggState — tracking monotonic aggregates across iterations
// ---------------------------------------------------------------------------

122/// Monotonic aggregate binding: maps a fold name to its aggregate
123/// trait object and input column.
124///
125/// Dispatches purely through [`uni_plugin::traits::locy::LocyAggregate`]
126/// (`update_step` / `initial_accum_f64` / `is_probability_aggregate` /
127/// `is_noisy_or`).
128#[derive(Debug, Clone)]
129pub struct MonotonicFoldBinding {
130    pub fold_name: String,
131    pub aggregate: std::sync::Arc<dyn uni_plugin::traits::locy::LocyAggregate>,
132    pub input_col_index: usize,
133    /// Column name for name-based resolution (more robust than positional index).
134    pub input_col_name: Option<String>,
135}
136
137/// Tracks monotonic aggregate accumulators across fixpoint iterations.
138///
139/// After each iteration, accumulators are updated and compared to their previous
140/// snapshot. The fixpoint has converged (w.r.t. aggregates) when all accumulators
141/// are stable (no change between iterations).
142#[derive(Debug)]
143pub struct MonotonicAggState {
144    /// Current accumulator values keyed by (group_key, fold_name).
145    accumulators: HashMap<(Vec<ScalarKey>, String), f64>,
146    /// Snapshot from the previous iteration for stability check.
147    prev_snapshot: HashMap<(Vec<ScalarKey>, String), f64>,
148    /// Bindings describing which aggregates to track.
149    bindings: Vec<MonotonicFoldBinding>,
150}
151
152impl MonotonicAggState {
153    /// Create a new monotonic aggregate state.
154    pub fn new(bindings: Vec<MonotonicFoldBinding>) -> Self {
155        Self {
156            accumulators: HashMap::new(),
157            prev_snapshot: HashMap::new(),
158            bindings,
159        }
160    }
161
162    /// Update accumulators with new delta batches.
163    ///
164    /// Returns `true` if any accumulator value changed. When `strict` is
165    /// `true`, MNOR/MPROD inputs outside `[0, 1]` produce an error
166    /// instead of being clamped.
167    ///
168    /// `semiring_kind` selects the probability semiring for probability
169    /// aggregates: `AddMultProb` (default, Phase 1/2 noisy-OR/product) or
170    /// `MaxMinProb` (Viterbi/fuzzy — opt-in, callers emit
171    /// `FuzzyNotProbabilistic`).
172    ///
173    /// Dispatch goes through each binding's `Arc<dyn LocyAggregate>` trait
174    /// object via [`uni_plugin::traits::locy::LocyAggregate::update_step`].
175    /// The trait object's `initial_accum_f64()` seeds the per-group
176    /// accumulator. Under `MaxMinProb`, probability aggregates (MNOR /
177    /// MPROD) bypass `update_step` and fold via the `MaxMinProb` semiring's
178    /// `plus` (max) / `times` (min) instead — preserving the opt-in
179    /// Viterbi/fuzzy semantics that `update_step`'s built-in noisy-OR /
180    /// product path does not implement.
181    ///
182    /// Aggregates whose `update_step` returns `Err(CODE_UNKNOWN_FUNCTION)`
183    /// (default impl — no row-level fast path; e.g., `AVG`, `COLLECT`)
184    /// are skipped silently here — those run through the batch-shape
185    /// [`uni_plugin::traits::locy::LocyAggState::ingest`] path in
186    /// `apply_post_fixpoint_chain` instead.
187    pub fn update(
188        &mut self,
189        key_indices: &[usize],
190        delta_batches: &[RecordBatch],
191        strict: bool,
192        semiring_kind: SemiringKind,
193    ) -> DFResult<bool> {
194        let mut changed = false;
195        for batch in delta_batches {
196            for row_idx in 0..batch.num_rows() {
197                let group_key = extract_scalar_key(batch, key_indices, row_idx);
198                for binding in &self.bindings {
199                    let idx = binding
200                        .input_col_name
201                        .as_ref()
202                        .and_then(|name| batch.schema().index_of(name).ok())
203                        .unwrap_or(binding.input_col_index);
204                    if idx >= batch.num_columns() {
205                        continue;
206                    }
207                    let col = batch.column(idx);
208                    let val = extract_f64(col.as_ref(), row_idx);
209                    if let Some(val) = val {
210                        let map_key = (group_key.clone(), binding.fold_name.clone());
211                        let initial = binding.aggregate.initial_accum_f64().unwrap_or(0.0);
212                        let entry = self.accumulators.entry(map_key).or_insert(initial);
213                        let old = *entry;
214                        // Under `MaxMinProb`, probability aggregates (MNOR /
215                        // MPROD) fold via the Viterbi/fuzzy semiring (max /
216                        // min) rather than the trait object's built-in
217                        // noisy-OR / product `update_step`. The inline domain
218                        // checks below preserve the exact strict-mode error
219                        // and clamp-warning literals. `is_noisy_or()`
220                        // distinguishes MNOR (disjunction → max) from MPROD
221                        // (conjunction → min). All other aggregates — and the
222                        // default `AddMultProb` semiring — dispatch through
223                        // the trait object's `update_step`.
224                        if matches!(semiring_kind, SemiringKind::MaxMinProb)
225                            && binding.aggregate.is_probability_aggregate()
226                        {
227                            use uni_locy::LocySemiring;
228                            let sr = uni_locy::MaxMinProb;
229                            let is_nor = binding.aggregate.is_noisy_or();
230                            let label = if is_nor { "MNOR" } else { "MPROD" };
231                            if strict && !(0.0..=1.0).contains(&val) {
232                                return Err(datafusion::error::DataFusionError::Execution(
233                                    format!(
234                                        "strict_probability_domain: {label} input {val} is outside [0, 1]"
235                                    ),
236                                ));
237                            }
238                            if !strict && !(0.0..=1.0).contains(&val) {
239                                tracing::warn!(
240                                    "{label} input {val} outside [0,1], clamped to {}",
241                                    val.clamp(0.0, 1.0)
242                                );
243                            }
244                            let p = val.clamp(0.0, 1.0);
245                            // MaxMinProb: MNOR -> max (plus), MPROD -> min (times).
246                            *entry = if is_nor {
247                                sr.plus(entry, &p)
248                            } else {
249                                sr.times(entry, &p)
250                            };
251                            if (*entry - old).abs() > f64::EPSILON {
252                                changed = true;
253                            }
254                            continue;
255                        }
256                        match binding.aggregate.update_step(*entry, val, strict) {
257                            Ok(new_val) => {
258                                *entry = new_val;
259                                if (*entry - old).abs() > f64::EPSILON {
260                                    changed = true;
261                                }
262                            }
263                            Err(e) if e.code == uni_plugin::FnError::CODE_UNKNOWN_FUNCTION => {
264                                // Aggregate has no row-level fast path (AVG,
265                                // COLLECT). Those run through the
266                                // batch-shape `ingest` path elsewhere; skip.
267                            }
268                            Err(e) => {
269                                // Strict-mode probability-domain violation,
270                                // or another aggregate-specific failure.
271                                return Err(datafusion::error::DataFusionError::Execution(
272                                    e.message,
273                                ));
274                            }
275                        }
276                    }
277                }
278            }
279        }
280        Ok(changed)
281    }
282
283    /// Take a snapshot of current accumulators for stability comparison.
284    pub fn snapshot(&mut self) {
285        self.prev_snapshot = self.accumulators.clone();
286    }
287
288    /// Check if accumulators are stable (no change since last snapshot).
289    pub fn is_stable(&self) -> bool {
290        if self.accumulators.len() != self.prev_snapshot.len() {
291            return false;
292        }
293        for (key, val) in &self.accumulators {
294            match self.prev_snapshot.get(key) {
295                Some(prev) if (*val - *prev).abs() <= f64::EPSILON => {}
296                _ => return false,
297            }
298        }
299        true
300    }
301
302    /// Test-only accessor for accumulator values.
303    #[cfg(test)]
304    pub(crate) fn get_accumulator(&self, key: &(Vec<ScalarKey>, String)) -> Option<f64> {
305        self.accumulators.get(key).copied()
306    }
307}
308
309/// Extract f64 value from an Arrow column at a given row index.
310fn extract_f64(col: &dyn arrow_array::Array, row_idx: usize) -> Option<f64> {
311    if col.is_null(row_idx) {
312        return None;
313    }
314    if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Float64Array>() {
315        Some(arr.value(row_idx))
316    } else {
317        col.as_any()
318            .downcast_ref::<arrow_array::Int64Array>()
319            .map(|arr| arr.value(row_idx) as f64)
320    }
321}

// ---------------------------------------------------------------------------
// RowDedupState — Arrow RowConverter-based persistent dedup set
// ---------------------------------------------------------------------------

327/// Arrow-native row deduplication using [`RowConverter`].
328///
329/// Unlike the legacy `HashSet<Vec<ScalarKey>>` approach, this struct maintains a
330/// persistent `seen` set across iterations so per-iteration cost is O(M) where M
331/// is the number of candidate rows — the full facts table is never re-scanned.
332struct RowDedupState {
333    converter: RowConverter,
334    seen: HashSet<Box<[u8]>>,
335}
336
337impl RowDedupState {
338    /// Try to build a `RowDedupState` for the given schema.
339    ///
340    /// Returns `None` if any column type is not supported by `RowConverter`
341    /// (triggers legacy fallback).
342    fn try_new(schema: &SchemaRef) -> Option<Self> {
343        let fields: Vec<SortField> = schema
344            .fields()
345            .iter()
346            .map(|f| SortField::new(f.data_type().clone()))
347            .collect();
348        match RowConverter::new(fields) {
349            Ok(converter) => Some(Self {
350                converter,
351                seen: HashSet::new(),
352            }),
353            Err(e) => {
354                tracing::warn!(
355                    "RowDedupState: RowConverter unsupported for schema, falling back to legacy dedup: {}",
356                    e
357                );
358                None
359            }
360        }
361    }
362
363    /// Populate the seen set from existing fact batches.
364    ///
365    /// Used after BEST BY in-loop pruning replaces the fact set, so that delta
366    /// computation in subsequent iterations correctly recognizes surviving facts.
367    fn ingest_existing(&mut self, facts: &[RecordBatch], _schema: &SchemaRef) {
368        self.seen.clear();
369        for batch in facts {
370            if batch.num_rows() == 0 {
371                continue;
372            }
373            let arrays: Vec<_> = batch.columns().to_vec();
374            if let Ok(rows) = self.converter.convert_columns(&arrays) {
375                for row_idx in 0..batch.num_rows() {
376                    let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
377                    self.seen.insert(row_bytes);
378                }
379            }
380        }
381    }
382
383    /// Filter `candidates` to only rows not yet seen, updating the persistent set.
384    ///
385    /// Both cross-iteration dedup (rows already accepted in prior iterations) and
386    /// within-batch dedup (duplicate rows in a single candidate batch) are handled
387    /// in a single pass.
388    fn compute_delta(
389        &mut self,
390        candidates: &[RecordBatch],
391        schema: &SchemaRef,
392    ) -> DFResult<Vec<RecordBatch>> {
393        let mut delta_batches = Vec::new();
394        for batch in candidates {
395            if batch.num_rows() == 0 {
396                continue;
397            }
398
399            // Vectorized encoding of all rows in this batch.
400            let arrays: Vec<_> = batch.columns().to_vec();
401            let rows = self.converter.convert_columns(&arrays).map_err(arrow_err)?;
402
403            // One pass: check+insert into persistent seen set.
404            let mut keep = Vec::with_capacity(batch.num_rows());
405            for row_idx in 0..batch.num_rows() {
406                let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
407                keep.push(self.seen.insert(row_bytes));
408            }
409
410            let keep_mask = arrow_array::BooleanArray::from(keep);
411            let new_cols = batch
412                .columns()
413                .iter()
414                .map(|col| {
415                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
416                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
417                    })
418                })
419                .collect::<DFResult<Vec<_>>>()?;
420
421            if new_cols.first().is_some_and(|c| !c.is_empty()) {
422                let filtered = RecordBatch::try_new(Arc::clone(schema), new_cols).map_err(|e| {
423                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
424                })?;
425                delta_batches.push(filtered);
426            }
427        }
428        Ok(delta_batches)
429    }
430}

// ---------------------------------------------------------------------------
// FixpointState — per-rule delta tracking during fixpoint iteration
// ---------------------------------------------------------------------------

436/// Per-rule state for fixpoint iteration.
437///
438/// Tracks accumulated facts and the delta (new facts from the latest iteration).
439/// Deduplication uses Arrow [`RowConverter`] with a persistent seen set (O(M) per
440/// iteration) when supported, with a legacy `HashSet<Vec<ScalarKey>>` fallback.
441pub struct FixpointState {
442    rule_name: String,
443    facts: Vec<RecordBatch>,
444    delta: Vec<RecordBatch>,
445    schema: SchemaRef,
446    key_column_indices: Vec<usize>,
447    /// KEY column names for recomputing indices after schema reconciliation.
448    key_column_names: Vec<String>,
449    /// All column indices for full-row dedup (legacy path only).
450    all_column_indices: Vec<usize>,
451    /// Running total of facts bytes for memory limit tracking.
452    facts_bytes: usize,
453    /// Maximum bytes allowed for this derived relation.
454    max_derived_bytes: usize,
455    /// Optional monotonic aggregate tracking.
456    monotonic_agg: Option<MonotonicAggState>,
457    /// Arrow RowConverter-based dedup state; `None` triggers legacy fallback.
458    row_dedup: Option<RowDedupState>,
459    /// Whether strict probability domain checks are enabled.
460    strict_probability_domain: bool,
461    /// Active probability semiring for this rule's MNOR/MPROD math.
462    semiring_kind: SemiringKind,
463}
464
465impl FixpointState {
466    /// Create a new fixpoint state for a rule. Existing tests call this
467    /// with the Phase 1/2 default; the fixpoint planner uses
468    /// [`FixpointState::new_with_semiring`] to thread the configured
469    /// semiring through.
470    pub fn new(
471        rule_name: String,
472        schema: SchemaRef,
473        key_column_indices: Vec<usize>,
474        max_derived_bytes: usize,
475        monotonic_agg: Option<MonotonicAggState>,
476        strict_probability_domain: bool,
477    ) -> Self {
478        Self::new_with_semiring(
479            rule_name,
480            schema,
481            key_column_indices,
482            max_derived_bytes,
483            monotonic_agg,
484            strict_probability_domain,
485            SemiringKind::AddMultProb,
486        )
487    }
488
489    pub fn new_with_semiring(
490        rule_name: String,
491        schema: SchemaRef,
492        key_column_indices: Vec<usize>,
493        max_derived_bytes: usize,
494        monotonic_agg: Option<MonotonicAggState>,
495        strict_probability_domain: bool,
496        semiring_kind: SemiringKind,
497    ) -> Self {
498        let num_cols = schema.fields().len();
499        let row_dedup = RowDedupState::try_new(&schema);
500        let key_column_names: Vec<String> = key_column_indices
501            .iter()
502            .filter_map(|&i| schema.fields().get(i).map(|f| f.name().clone()))
503            .collect();
504        Self {
505            rule_name,
506            facts: Vec::new(),
507            delta: Vec::new(),
508            schema,
509            key_column_indices,
510            key_column_names,
511            all_column_indices: (0..num_cols).collect(),
512            facts_bytes: 0,
513            max_derived_bytes,
514            monotonic_agg,
515            row_dedup,
516            strict_probability_domain,
517            semiring_kind,
518        }
519    }
520
521    /// Reconcile the pre-computed schema with the actual physical plan output.
522    ///
523    /// `infer_expr_type` may guess wrong (e.g. `Property → Float64` for a
524    /// string column).  When the first real batch arrives with a different
525    /// schema, update ours so that `RowDedupState` / `RecordBatch::try_new`
526    /// use the correct types.
527    fn reconcile_schema(&mut self, actual_schema: &SchemaRef) {
528        if self.schema.fields() != actual_schema.fields() {
529            tracing::debug!(
530                rule = %self.rule_name,
531                "Reconciling fixpoint schema from physical plan output",
532            );
533            self.schema = Arc::clone(actual_schema);
534            self.row_dedup = RowDedupState::try_new(&self.schema);
535            // Recompute key_column_indices from stored KEY column names.
536            // Without this, FoldExec groups by wrong columns when the
537            // physical plan reorders columns vs the pre-inferred schema.
538            let new_indices: Vec<usize> = self
539                .key_column_names
540                .iter()
541                .filter_map(|name| actual_schema.index_of(name).ok())
542                .collect();
543            if new_indices.len() == self.key_column_names.len() {
544                self.key_column_indices = new_indices;
545            }
546            // else: not all KEY columns found in new schema — keep original indices
547            let num_cols = actual_schema.fields().len();
548            self.all_column_indices = (0..num_cols).collect();
549        }
550    }
551
552    /// Merge candidate rows into facts, computing delta (truly new rows).
553    ///
554    /// Returns `true` if any new facts were added.
555    pub async fn merge_delta(
556        &mut self,
557        candidates: Vec<RecordBatch>,
558        task_ctx: Option<Arc<TaskContext>>,
559    ) -> DFResult<bool> {
560        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
561            self.delta.clear();
562            return Ok(false);
563        }
564
565        // Reconcile schema from the first non-empty candidate batch.
566        // The physical plan's output types are authoritative over the
567        // planner's inferred types.
568        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
569            self.reconcile_schema(&first.schema());
570        }
571
572        // Round floats for stable dedup
573        let candidates = round_float_columns(&candidates);
574
575        // Compute delta: rows in candidates not already in facts
576        let delta = self.compute_delta(&candidates, task_ctx.as_ref()).await?;
577
578        if delta.is_empty() || delta.iter().all(|b| b.num_rows() == 0) {
579            self.delta.clear();
580            // Update monotonic aggs even with empty delta (for stability check)
581            if let Some(ref mut agg) = self.monotonic_agg {
582                agg.snapshot();
583            }
584            return Ok(false);
585        }
586
587        // Check memory limit
588        let delta_bytes: usize = delta.iter().map(batch_byte_size).sum();
589        if self.facts_bytes + delta_bytes > self.max_derived_bytes {
590            return Err(datafusion::error::DataFusionError::Execution(
591                LocyRuntimeError::MemoryLimitExceeded {
592                    rule: self.rule_name.clone(),
593                    bytes: self.facts_bytes + delta_bytes,
594                    limit: self.max_derived_bytes,
595                }
596                .to_string(),
597            ));
598        }
599
600        // Update monotonic aggs
601        if let Some(ref mut agg) = self.monotonic_agg {
602            agg.snapshot();
603            agg.update(
604                &self.key_column_indices,
605                &delta,
606                self.strict_probability_domain,
607                self.semiring_kind,
608            )?;
609        }
610
611        // Append delta to facts
612        self.facts_bytes += delta_bytes;
613        self.facts.extend(delta.iter().cloned());
614        self.delta = delta;
615
616        Ok(true)
617    }
618
619    /// Dispatch to vectorized LeftAntiJoin, Arrow RowConverter dedup, or legacy ScalarKey dedup.
620    ///
621    /// Priority order:
622    /// 1. `arrow_left_anti_dedup` when `total_existing >= DEDUP_ANTI_JOIN_THRESHOLD` and task_ctx available.
623    /// 2. `RowDedupState` (persistent HashSet, O(M) per iteration) when schema is supported.
624    /// 3. `compute_delta_legacy` (rebuilds from facts, fallback for unsupported column types).
625    async fn compute_delta(
626        &mut self,
627        candidates: &[RecordBatch],
628        task_ctx: Option<&Arc<TaskContext>>,
629    ) -> DFResult<Vec<RecordBatch>> {
630        let total_existing: usize = self.facts.iter().map(|b| b.num_rows()).sum();
631        if total_existing >= DEDUP_ANTI_JOIN_THRESHOLD
632            && let Some(ctx) = task_ctx
633        {
634            return arrow_left_anti_dedup(candidates.to_vec(), &self.facts, &self.schema, ctx)
635                .await;
636        }
637        if let Some(ref mut rd) = self.row_dedup {
638            rd.compute_delta(candidates, &self.schema)
639        } else {
640            self.compute_delta_legacy(candidates)
641        }
642    }
643
644    /// Legacy dedup: rebuild a `HashSet<Vec<ScalarKey>>` from all facts each call.
645    ///
646    /// Used as fallback when `RowConverter` does not support the schema's column types.
647    fn compute_delta_legacy(&self, candidates: &[RecordBatch]) -> DFResult<Vec<RecordBatch>> {
648        // Build set of existing fact row keys (ALL columns)
649        let mut existing: HashSet<Vec<ScalarKey>> = HashSet::new();
650        for batch in &self.facts {
651            for row_idx in 0..batch.num_rows() {
652                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
653                existing.insert(key);
654            }
655        }
656
657        let mut delta_batches = Vec::new();
658        for batch in candidates {
659            if batch.num_rows() == 0 {
660                continue;
661            }
662            // Filter to only new rows
663            let mut keep = Vec::with_capacity(batch.num_rows());
664            for row_idx in 0..batch.num_rows() {
665                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
666                keep.push(!existing.contains(&key));
667            }
668
669            // Also dedup within the candidate batch itself
670            for (row_idx, kept) in keep.iter_mut().enumerate() {
671                if *kept {
672                    let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
673                    if !existing.insert(key) {
674                        *kept = false;
675                    }
676                }
677            }
678
679            let keep_mask = arrow_array::BooleanArray::from(keep);
680            let new_rows = batch
681                .columns()
682                .iter()
683                .map(|col| {
684                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
685                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
686                    })
687                })
688                .collect::<DFResult<Vec<_>>>()?;
689
690            if new_rows.first().is_some_and(|c| !c.is_empty()) {
691                let filtered =
692                    RecordBatch::try_new(Arc::clone(&self.schema), new_rows).map_err(|e| {
693                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
694                    })?;
695                delta_batches.push(filtered);
696            }
697        }
698
699        Ok(delta_batches)
700    }
701
702    /// Check if this rule has converged (no new facts and aggs stable).
703    pub fn is_converged(&self) -> bool {
704        let delta_empty = self.delta.is_empty() || self.delta.iter().all(|b| b.num_rows() == 0);
705        let agg_stable = self.monotonic_agg.as_ref().is_none_or(|a| a.is_stable());
706        delta_empty && agg_stable
707    }
708
709    /// Get all accumulated facts.
710    pub fn all_facts(&self) -> &[RecordBatch] {
711        &self.facts
712    }
713
714    /// Get the delta from the latest iteration.
715    pub fn all_delta(&self) -> &[RecordBatch] {
716        &self.delta
717    }
718
719    /// Consume self and return facts.
720    pub fn into_facts(self) -> Vec<RecordBatch> {
721        self.facts
722    }
723
724    /// Merge candidates using BEST BY semantics.
725    ///
726    /// Combines existing facts with new candidates, keeping only the best row
727    /// per KEY group according to `sort_criteria`. Returns `true` if the
728    /// best-per-KEY fact set actually changed (a genuinely better value was
729    /// found or a new KEY appeared).
730    ///
731    /// This replaces `merge_delta` for rules with BEST BY, enabling convergence
732    /// on cyclic graphs where dominated ALONG values would otherwise produce an
733    /// unbounded stream of "new" full-row facts.
734    pub fn merge_best_by(
735        &mut self,
736        candidates: Vec<RecordBatch>,
737        sort_criteria: &[SortCriterion],
738    ) -> DFResult<bool> {
739        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
740            self.delta.clear();
741            return Ok(false);
742        }
743
744        // Reconcile schema from the first non-empty candidate batch.
745        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
746            self.reconcile_schema(&first.schema());
747        }
748
749        // Round floats for stable dedup.
750        let candidates = round_float_columns(&candidates);
751
752        // Snapshot existing best-per-KEY facts for change detection.
753        let old_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> =
754            self.build_key_criteria_map(sort_criteria);
755
756        // Concat existing facts + new candidates.
757        let mut all_batches = self.facts.clone();
758        all_batches.extend(candidates);
759        let all_batches: Vec<_> = all_batches
760            .into_iter()
761            .filter(|b| b.num_rows() > 0)
762            .collect();
763        if all_batches.is_empty() {
764            self.delta.clear();
765            return Ok(false);
766        }
767
768        let combined = arrow::compute::concat_batches(&self.schema, &all_batches)
769            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
770
771        if combined.num_rows() == 0 {
772            self.delta.clear();
773            return Ok(false);
774        }
775
776        // Sort by KEY ASC then criteria, so the best row per KEY group comes
777        // first.
778        let mut sort_columns = Vec::new();
779        for &ki in &self.key_column_indices {
780            if ki >= combined.num_columns() {
781                continue;
782            }
783            sort_columns.push(arrow::compute::SortColumn {
784                values: Arc::clone(combined.column(ki)),
785                options: Some(arrow::compute::SortOptions {
786                    descending: false,
787                    nulls_first: false,
788                }),
789            });
790        }
791        for criterion in sort_criteria {
792            if criterion.col_index >= combined.num_columns() {
793                continue;
794            }
795            sort_columns.push(arrow::compute::SortColumn {
796                values: Arc::clone(combined.column(criterion.col_index)),
797                options: Some(arrow::compute::SortOptions {
798                    descending: !criterion.ascending,
799                    nulls_first: criterion.nulls_first,
800                }),
801            });
802        }
803
804        let sorted_indices =
805            arrow::compute::lexsort_to_indices(&sort_columns, None).map_err(arrow_err)?;
806        let sorted_columns: Vec<_> = combined
807            .columns()
808            .iter()
809            .map(|col| arrow::compute::take(col.as_ref(), &sorted_indices, None))
810            .collect::<Result<Vec<_>, _>>()
811            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
812        let sorted = RecordBatch::try_new(Arc::clone(&self.schema), sorted_columns)
813            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
814
815        // Dedup: keep first (best) row per KEY group.
816        let mut keep_indices: Vec<u32> = Vec::new();
817        let mut prev_key: Option<Vec<ScalarKey>> = None;
818        for row_idx in 0..sorted.num_rows() {
819            let key = extract_scalar_key(&sorted, &self.key_column_indices, row_idx);
820            let is_new_group = match &prev_key {
821                None => true,
822                Some(prev) => *prev != key,
823            };
824            if is_new_group {
825                keep_indices.push(row_idx as u32);
826                prev_key = Some(key);
827            }
828        }
829
830        let keep_array = arrow_array::UInt32Array::from(keep_indices);
831        let output_columns: Vec<_> = sorted
832            .columns()
833            .iter()
834            .map(|col| arrow::compute::take(col.as_ref(), &keep_array, None))
835            .collect::<Result<Vec<_>, _>>()
836            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
837        let pruned = RecordBatch::try_new(Arc::clone(&self.schema), output_columns)
838            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
839
840        // Detect whether the best-per-KEY set actually changed.
841        let new_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> = {
842            let mut map = HashMap::new();
843            for row_idx in 0..pruned.num_rows() {
844                let key = extract_scalar_key(&pruned, &self.key_column_indices, row_idx);
845                let criteria: Vec<ScalarKey> = sort_criteria
846                    .iter()
847                    .flat_map(|c| extract_scalar_key(&pruned, &[c.col_index], row_idx))
848                    .collect();
849                map.insert(key, criteria);
850            }
851            map
852        };
853        let changed = old_best != new_best;
854
855        tracing::debug!(
856            rule = %self.rule_name,
857            old_keys = old_best.len(),
858            new_keys = new_best.len(),
859            changed = changed,
860            "BEST BY merge"
861        );
862
863        // Replace facts with the pruned set.
864        self.facts_bytes = batch_byte_size(&pruned);
865        self.facts = vec![pruned];
866        if changed {
867            // Delta is conceptually the new/improved facts, but since we
868            // replaced the entire set, just mark delta non-empty.
869            self.delta = self.facts.clone();
870        } else {
871            self.delta.clear();
872        }
873
874        // Rebuild row dedup from pruned facts for consistency.
875        self.row_dedup = RowDedupState::try_new(&self.schema);
876        if let Some(ref mut rd) = self.row_dedup {
877            rd.ingest_existing(&self.facts, &self.schema);
878        }
879
880        Ok(changed)
881    }
882
883    /// Build a map from KEY column values to sort criteria values.
884    fn build_key_criteria_map(
885        &self,
886        sort_criteria: &[SortCriterion],
887    ) -> HashMap<Vec<ScalarKey>, Vec<ScalarKey>> {
888        let mut map = HashMap::new();
889        for batch in &self.facts {
890            for row_idx in 0..batch.num_rows() {
891                let key = extract_scalar_key(batch, &self.key_column_indices, row_idx);
892                let criteria: Vec<ScalarKey> = sort_criteria
893                    .iter()
894                    .flat_map(|c| extract_scalar_key(batch, &[c.col_index], row_idx))
895                    .collect();
896                map.insert(key, criteria);
897            }
898        }
899        map
900    }
901}
902
903/// Estimate byte size of a RecordBatch.
904fn batch_byte_size(batch: &RecordBatch) -> usize {
905    batch
906        .columns()
907        .iter()
908        .map(|col| col.get_buffer_memory_size())
909        .sum()
910}
911
912// ---------------------------------------------------------------------------
913// Float rounding for stable dedup
914// ---------------------------------------------------------------------------
915
916/// Round all Float64 columns to 12 decimal places for stable dedup.
917fn round_float_columns(batches: &[RecordBatch]) -> Vec<RecordBatch> {
918    batches
919        .iter()
920        .map(|batch| {
921            let schema = batch.schema();
922            let has_float = schema
923                .fields()
924                .iter()
925                .any(|f| *f.data_type() == arrow_schema::DataType::Float64);
926            if !has_float {
927                return batch.clone();
928            }
929
930            let columns: Vec<arrow_array::ArrayRef> = batch
931                .columns()
932                .iter()
933                .enumerate()
934                .map(|(i, col)| {
935                    if *schema.field(i).data_type() == arrow_schema::DataType::Float64 {
936                        let arr = col
937                            .as_any()
938                            .downcast_ref::<arrow_array::Float64Array>()
939                            .unwrap();
940                        let rounded: arrow_array::Float64Array = arr
941                            .iter()
942                            .map(|v| v.map(|f| (f * 1e12).round() / 1e12))
943                            .collect();
944                        Arc::new(rounded) as arrow_array::ArrayRef
945                    } else {
946                        Arc::clone(col)
947                    }
948                })
949                .collect();
950
951            RecordBatch::try_new(schema, columns).unwrap_or_else(|_| batch.clone())
952        })
953        .collect()
954}
955
956// ---------------------------------------------------------------------------
957// LeftAntiJoin delta deduplication
958// ---------------------------------------------------------------------------
959
/// Row threshold above which the vectorized Arrow LeftAntiJoin dedup path is used.
///
/// Below this threshold the persistent `RowDedupState` HashSet is O(M) and
/// avoids rebuilding the existing-row set; above it DataFusion's vectorized
/// HashJoinExec is more cache-efficient.
///
/// NOTE(review): a comment inside `dedup_batches_all_columns` describes the
/// anti-join path as reached for fact sets `>=` this threshold, i.e. an
/// inclusive boundary, whereas the wording here says "above". Confirm against
/// the comparison at the call site, which is not in this part of the file.
const DEDUP_ANTI_JOIN_THRESHOLD: usize = 300;
966
967/// Deduplicate `candidates` against `existing` using DataFusion's HashJoinExec.
968///
969/// Returns rows in `candidates` that do not appear in `existing` (LeftAnti semantics).
970/// `null_equals_null = true` so NULLs are treated as equal for dedup purposes.
971/// Dedup `batches` by all columns (set semantics), keeping the first occurrence.
972///
973/// `arrow_left_anti_dedup` removes candidate rows that match the existing fact
974/// set, but a single semi-naive iteration can emit the same row many times — e.g.
975/// a transitive-closure rule derives the same `(a, b)` pair via every intermediate
976/// `mid` on a path. A `LeftAnti` join does not remove these *within-candidate*
977/// duplicates, so they would leak into the fact set. The `RowDedupState` and legacy
978/// paths both dedup within the candidate batch ([`RowDedupState::compute_delta`],
979/// [`FixpointState::compute_delta_legacy`]); this keeps the `arrow_left_anti_dedup`
980/// path identical so dedup behavior does not change across `DEDUP_ANTI_JOIN_THRESHOLD`.
981fn dedup_batches_all_columns(
982    batches: Vec<RecordBatch>,
983    schema: &SchemaRef,
984) -> DFResult<Vec<RecordBatch>> {
985    let fields: Vec<SortField> = schema
986        .fields()
987        .iter()
988        .map(|f| SortField::new(f.data_type().clone()))
989        .collect();
990    // Unsupported column types: leave as-is. This path is only reached for facts
991    // sets >= DEDUP_ANTI_JOIN_THRESHOLD; for those types the <threshold path uses
992    // `compute_delta_legacy`, which dedups via `ScalarKey` instead.
993    let Ok(converter) = RowConverter::new(fields) else {
994        return Ok(batches);
995    };
996    let mut seen: HashSet<Box<[u8]>> = HashSet::new();
997    let mut out = Vec::with_capacity(batches.len());
998    for batch in batches {
999        if batch.num_rows() == 0 {
1000            continue;
1001        }
1002        let rows = converter
1003            .convert_columns(batch.columns())
1004            .map_err(arrow_err)?;
1005        let mut keep = Vec::with_capacity(batch.num_rows());
1006        for row_idx in 0..batch.num_rows() {
1007            let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
1008            keep.push(seen.insert(row_bytes));
1009        }
1010        let keep_mask = arrow_array::BooleanArray::from(keep);
1011        let cols = batch
1012            .columns()
1013            .iter()
1014            .map(|c| arrow::compute::filter(c.as_ref(), &keep_mask).map_err(arrow_err))
1015            .collect::<DFResult<Vec<_>>>()?;
1016        if cols.first().is_some_and(|c| !c.is_empty()) {
1017            out.push(RecordBatch::try_new(Arc::clone(schema), cols).map_err(arrow_err)?);
1018        }
1019    }
1020    Ok(out)
1021}
1022
1023async fn arrow_left_anti_dedup(
1024    candidates: Vec<RecordBatch>,
1025    existing: &[RecordBatch],
1026    schema: &SchemaRef,
1027    task_ctx: &Arc<TaskContext>,
1028) -> DFResult<Vec<RecordBatch>> {
1029    if existing.is_empty() || existing.iter().all(|b| b.num_rows() == 0) {
1030        // No existing facts to anti-join against, but still dedup the candidates
1031        // among themselves (a single iteration may emit duplicate rows).
1032        return dedup_batches_all_columns(candidates, schema);
1033    }
1034
1035    let left: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(candidates, Arc::clone(schema)));
1036    let right: Arc<dyn ExecutionPlan> =
1037        Arc::new(InMemoryExec::new(existing.to_vec(), Arc::clone(schema)));
1038
1039    let on: Vec<(
1040        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
1041        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
1042    )> = schema
1043        .fields()
1044        .iter()
1045        .enumerate()
1046        .map(|(i, field)| {
1047            let l: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
1048                datafusion::physical_plan::expressions::Column::new(field.name(), i),
1049            );
1050            let r: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
1051                datafusion::physical_plan::expressions::Column::new(field.name(), i),
1052            );
1053            (l, r)
1054        })
1055        .collect();
1056
1057    if on.is_empty() {
1058        return Ok(vec![]);
1059    }
1060
1061    let join = HashJoinExec::try_new(
1062        left,
1063        right,
1064        on,
1065        None,
1066        &JoinType::LeftAnti,
1067        None,
1068        PartitionMode::CollectLeft,
1069        datafusion::common::NullEquality::NullEqualsNull,
1070        // null_aware = false: this is a set-difference dedup (NOT EXISTS), not a
1071        // SQL NOT-IN. `NullEqualsNull` already makes NULL keys dedup against NULL
1072        // keys; enabling null-aware semantics would wrongly annihilate all rows
1073        // whenever the existing-fact side contains a NULL key.
1074        false,
1075    )?;
1076
1077    let join_arc: Arc<dyn ExecutionPlan> = Arc::new(join);
1078    // LeftAnti removes candidates that match `existing`, but not duplicate rows
1079    // within the candidate set — dedup those to match the other delta strategies.
1080    let anti = collect_all_partitions(&join_arc, task_ctx.clone()).await?;
1081    dedup_batches_all_columns(anti, schema)
1082}
1083
1084// ---------------------------------------------------------------------------
1085// Plan types for fixpoint rules
1086// ---------------------------------------------------------------------------
1087
/// IS-ref binding: a reference from a clause body to a derived relation.
///
/// The fixpoint loop inspects these per clause. A binding that is `negated`
/// and has non-empty `anti_join_cols` causes the clause's batches to be
/// post-processed against the referenced relation's current facts, which are
/// read from the registry entry at `derived_scan_index`.
#[derive(Debug, Clone)]
pub struct IsRefBinding {
    /// Index into the DerivedScanRegistry.
    pub derived_scan_index: usize,
    /// Name of the rule being referenced.
    ///
    /// Also used to name the temporary `__prob_complement_<rule_name>` column
    /// when negation is applied as a probabilistic complement.
    pub rule_name: String,
    /// Whether this is a self-reference (rule references itself).
    pub is_self_ref: bool,
    /// Whether this is a negated reference (NOT IS).
    pub negated: bool,
    /// For negated IS-refs: `(left_body_col, right_derived_col)` pairs for anti-join filtering.
    ///
    /// `left_body_col` is the VID column in the clause body (e.g., `"n._vid"`);
    /// `right_derived_col` is the corresponding KEY column in the negated rule's facts (e.g., `"n"`).
    /// Empty for non-negated IS-refs.
    pub anti_join_cols: Vec<(String, String)>,
    /// Whether the target rule has a PROB column.
    ///
    /// When this is set and the consuming rule also has a PROB column, the
    /// loop applies negation as a probabilistic complement instead of an
    /// anti-join.
    pub target_has_prob: bool,
    /// Name of the PROB column in the target rule, if any.
    ///
    /// If `target_has_prob` is set but this is `None`, the loop falls back to
    /// a plain anti-join.
    pub target_prob_col: Option<String>,
    /// `(body_col, derived_col)` pairs for provenance tracking.
    ///
    /// Used by shared-proof detection to find which source facts a derived row
    /// consumed. Populated for all IS-refs (not just negated ones).
    pub provenance_join_cols: Vec<(String, String)>,
}
1115
/// A single clause (body) within a fixpoint rule.
///
/// During each fixpoint iteration the loop executes `body_logical` as a
/// subplan, then applies the negated entries of `is_ref_bindings` to the
/// resulting batches before they are merged as candidates for the rule.
#[derive(Debug)]
pub struct FixpointClausePlan {
    /// The logical plan for the clause body.
    pub body_logical: LogicalPlan,
    /// IS-ref bindings used by this clause.
    pub is_ref_bindings: Vec<IsRefBinding>,
    /// Priority value for this clause (if PRIORITY semantics apply).
    pub priority: Option<i64>,
    /// ALONG binding variable names propagated from the planner.
    pub along_bindings: Vec<String>,
    /// Phase B Slice 3: neural-model invocations lifted out of YIELD
    /// items by the compiler. Each entry is evaluated per row after the
    /// clause body produces batches and before IS-ref handling.
    ///
    /// NOTE(review): a comment in the fixpoint loop says the planner inserts
    /// the invocation into the body plan itself, so it runs inline as part of
    /// executing `body_logical`. Confirm how this field relates to that path.
    pub model_invocations: Vec<ModelInvocation>,
}
1132
/// Physical plan for a single rule in a fixpoint stratum.
///
/// The fixpoint loop creates one per-rule state from `name`, `yield_schema`,
/// and `key_column_indices`, evaluates every entry of `clauses` on each
/// iteration, and merges the combined candidates into that state.
#[derive(Debug)]
pub struct FixpointRulePlan {
    /// Rule name.
    pub name: String,
    /// Clause bodies (each evaluates to candidate rows).
    pub clauses: Vec<FixpointClausePlan>,
    /// Output schema for this rule's derived relation.
    pub yield_schema: SchemaRef,
    /// Indices of KEY columns within yield_schema.
    pub key_column_indices: Vec<usize>,
    /// Priority value (if PRIORITY semantics apply).
    pub priority: Option<i64>,
    /// Whether this rule has FOLD semantics.
    pub has_fold: bool,
    /// FOLD bindings for post-fixpoint aggregation.
    ///
    /// When non-empty, the loop also builds a monotonic aggregate state from
    /// these bindings for the rule's fixpoint state.
    pub fold_bindings: Vec<FoldBinding>,
    /// Post-FOLD filter expressions (HAVING semantics).
    pub having: Vec<Expr>,
    /// Whether this rule has BEST BY semantics.
    ///
    /// The loop uses the BEST BY merge only when this is set and
    /// `best_by_criteria` is non-empty; otherwise it uses the regular delta
    /// merge.
    pub has_best_by: bool,
    /// BEST BY sort criteria for post-fixpoint selection.
    pub best_by_criteria: Vec<SortCriterion>,
    /// Whether this rule has PRIORITY semantics.
    pub has_priority: bool,
    /// Whether BEST BY should apply a deterministic secondary sort for
    /// tie-breaking. When false, tied rows are selected non-deterministically
    /// (faster but not repeatable across runs).
    pub deterministic: bool,
    /// Name of the PROB column in this rule's yield schema, if any.
    ///
    /// A rule with a PROB column evaluated under the MaxMinProb semiring
    /// triggers the `FuzzyNotProbabilistic` runtime warning.
    pub prob_column_name: Option<String>,
}
1165
1166// ---------------------------------------------------------------------------
1167// run_fixpoint_loop — the core semi-naive iteration algorithm
1168// ---------------------------------------------------------------------------
1169
1170/// Run the semi-naive fixpoint iteration loop.
1171///
1172/// Evaluates all rules in a stratum repeatedly, feeding deltas back through
1173/// derived scan handles until convergence or limits are reached.
1174#[expect(clippy::too_many_arguments, reason = "Fixpoint loop needs all context")]
1175async fn run_fixpoint_loop(
1176    rules: Vec<FixpointRulePlan>,
1177    max_iterations: usize,
1178    timeout: Duration,
1179    graph_ctx: Arc<GraphExecutionContext>,
1180    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
1181    storage: Arc<StorageManager>,
1182    schema_info: Arc<UniSchema>,
1183    params: HashMap<String, Value>,
1184    registry: Arc<DerivedScanRegistry>,
1185    output_schema: SchemaRef,
1186    max_derived_bytes: usize,
1187    derivation_tracker: Option<Arc<ProvenanceStore>>,
1188    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
1189    strict_probability_domain: bool,
1190    probability_epsilon: f64,
1191    exact_probability: bool,
1192    max_bdd_variables: usize,
1193    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
1194    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
1195    top_k_proofs: usize,
1196    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
1197    semiring_kind: SemiringKind,
1198    classifier_registry: Arc<ClassifierRegistry>,
1199    classifier_cache: Option<Arc<ModelInvocationCache>>,
1200    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
1201) -> DFResult<Vec<RecordBatch>> {
1202    let start = Instant::now();
1203    let task_ctx = session_ctx.read().task_ctx();
1204
1205    // IMPORTANT: per rollout D-9 the FuzzyNotProbabilistic warning emitted
1206    // below is unsuppressible — do not gate on any suppression mechanism.
1207    // Fuzzy truth values are not probabilities; silent conflation is the
1208    // dominant pitfall in neuro-symbolic systems (LTN, NTP).
1209    if semiring_kind == SemiringKind::MaxMinProb {
1210        let mut warnings = warnings_slot.write().unwrap_or_else(|e| e.into_inner());
1211        let mut already_warned: HashSet<String> = warnings
1212            .iter()
1213            .filter(|w| w.code == RuntimeWarningCode::FuzzyNotProbabilistic)
1214            .map(|w| w.rule_name.clone())
1215            .collect();
1216        for rule in &rules {
1217            if rule.prob_column_name.is_some() && !already_warned.contains(&rule.name) {
1218                warnings.push(RuntimeWarning {
1219                    code: RuntimeWarningCode::FuzzyNotProbabilistic,
1220                    message: format!(
1221                        "rule '{}' carries a PROB column but is being evaluated under \
1222                         the MaxMinProb (fuzzy / Viterbi) semiring; outputs are fuzzy \
1223                         truth values, not probabilities",
1224                        rule.name
1225                    ),
1226                    rule_name: rule.name.clone(),
1227                    variable_count: None,
1228                    key_group: None,
1229                });
1230                already_warned.insert(rule.name.clone());
1231            }
1232        }
1233    }
1234
1235    // Initialize per-rule state
1236    let mut states: Vec<FixpointState> = rules
1237        .iter()
1238        .map(|rule| {
1239            let monotonic_agg = if !rule.fold_bindings.is_empty() {
1240                let bindings: Vec<MonotonicFoldBinding> = rule
1241                    .fold_bindings
1242                    .iter()
1243                    .map(|fb| MonotonicFoldBinding {
1244                        fold_name: fb.output_name.clone(),
1245                        aggregate: std::sync::Arc::clone(&fb.aggregate),
1246                        input_col_index: fb.input_col_index,
1247                        input_col_name: fb.input_col_name.clone(),
1248                    })
1249                    .collect();
1250                Some(MonotonicAggState::new(bindings))
1251            } else {
1252                None
1253            };
1254            FixpointState::new_with_semiring(
1255                rule.name.clone(),
1256                Arc::clone(&rule.yield_schema),
1257                rule.key_column_indices.clone(),
1258                max_derived_bytes,
1259                monotonic_agg,
1260                strict_probability_domain,
1261                semiring_kind,
1262            )
1263        })
1264        .collect();
1265
1266    // Main iteration loop
1267    let mut converged = false;
1268    let mut total_iters = 0usize;
1269    for iteration in 0..max_iterations {
1270        total_iters = iteration + 1;
1271        tracing::debug!("fixpoint iteration {}", iteration);
1272        let mut any_changed = false;
1273
1274        for rule_idx in 0..rules.len() {
1275            let rule = &rules[rule_idx];
1276
1277            // Update derived scan handles for this rule's clauses
1278            update_derived_scan_handles(&registry, &states, rule_idx, &rules);
1279
1280            // Evaluate clause bodies, tracking per-clause candidates for provenance.
1281            let mut all_candidates = Vec::new();
1282            let mut clause_candidates: Vec<Vec<RecordBatch>> = Vec::new();
1283            for clause in &rule.clauses {
1284                // Phase B A4 follow-up: the planner inserts
1285                // `LogicalPlan::LocyModelInvoke` between the body and
1286                // `LocyProject` when this clause has neural-model
1287                // invocations, so `execute_subplan` runs the invocation
1288                // inline as part of the body plan tree.
1289                let mut batches = execute_subplan(
1290                    &clause.body_logical,
1291                    &params,
1292                    &HashMap::new(),
1293                    &graph_ctx,
1294                    &session_ctx,
1295                    &storage,
1296                    &schema_info,
1297                    None, // Locy fixpoint clause body is read-only
1298                )
1299                .await?;
1300                // Apply negated IS-ref semantics: probabilistic complement or anti-join.
1301                for binding in &clause.is_ref_bindings {
1302                    if binding.negated
1303                        && !binding.anti_join_cols.is_empty()
1304                        && let Some(entry) = registry.get(binding.derived_scan_index)
1305                    {
1306                        let neg_facts = entry.data.read().clone();
1307                        if !neg_facts.is_empty() {
1308                            if binding.target_has_prob && rule.prob_column_name.is_some() {
1309                                // Probabilistic complement: add 1-p column instead of filtering.
1310                                let complement_col =
1311                                    format!("__prob_complement_{}", binding.rule_name);
1312                                if let Some(prob_col) = &binding.target_prob_col {
1313                                    batches = apply_prob_complement_composite(
1314                                        batches,
1315                                        &neg_facts,
1316                                        &binding.anti_join_cols,
1317                                        prob_col,
1318                                        &complement_col,
1319                                    )?;
1320                                } else {
1321                                    // target_has_prob but no prob_col: fall back to anti-join.
1322                                    batches = apply_anti_join_composite(
1323                                        batches,
1324                                        &neg_facts,
1325                                        &binding.anti_join_cols,
1326                                    )?;
1327                                }
1328                            } else {
1329                                // Boolean exclusion: anti-join (existing behavior)
1330                                batches = apply_anti_join_composite(
1331                                    batches,
1332                                    &neg_facts,
1333                                    &binding.anti_join_cols,
1334                                )?;
1335                            }
1336                        }
1337                    }
1338                }
1339                // Multiply complement columns into the PROB column (if any) and clean up
1340                let complement_cols: Vec<String> = if !batches.is_empty() {
1341                    batches[0]
1342                        .schema()
1343                        .fields()
1344                        .iter()
1345                        .filter(|f| f.name().starts_with("__prob_complement_"))
1346                        .map(|f| f.name().clone())
1347                        .collect()
1348                } else {
1349                    vec![]
1350                };
1351                if !complement_cols.is_empty() {
1352                    batches = multiply_prob_factors(
1353                        batches,
1354                        rule.prob_column_name.as_deref(),
1355                        &complement_cols,
1356                    )?;
1357                }
1358
1359                clause_candidates.push(batches.clone());
1360                all_candidates.extend(batches);
1361            }
1362
1363            // Merge candidates into facts.
1364            // For BEST BY rules, use a specialized merge that keeps only the
1365            // best row per KEY group, enabling convergence on cyclic graphs.
1366            let changed = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
1367                states[rule_idx].merge_best_by(all_candidates, &rule.best_by_criteria)?
1368            } else {
1369                states[rule_idx]
1370                    .merge_delta(all_candidates, Some(Arc::clone(&task_ctx)))
1371                    .await?
1372            };
1373            if changed {
1374                any_changed = true;
1375                // Record provenance for newly derived facts when tracker is present.
1376                if let Some(ref tracker) = derivation_tracker {
1377                    record_provenance(
1378                        ProvenanceCtx {
1379                            tracker,
1380                            registry: &registry,
1381                            warnings_slot: &warnings_slot,
1382                        },
1383                        rule,
1384                        &states[rule_idx],
1385                        &clause_candidates,
1386                        iteration,
1387                        top_k_proofs,
1388                        ClassifierRefs {
1389                            registry: &classifier_registry,
1390                            cache: classifier_cache.as_ref(),
1391                            provenance_store: classifier_provenance_store.as_ref(),
1392                        },
1393                    )
1394                    .await;
1395                }
1396            }
1397        }
1398
1399        // Check convergence
1400        if !any_changed && states.iter().all(|s| s.is_converged()) {
1401            tracing::debug!("fixpoint converged after {} iterations", iteration + 1);
1402            converged = true;
1403            break;
1404        }
1405
1406        // Check timeout
1407        if start.elapsed() > timeout {
1408            tracing::warn!(
1409                "fixpoint timeout after {} iterations; returning partial results",
1410                iteration + 1,
1411            );
1412            interruption::set(&timeout_flag, interruption::TIMEOUT);
1413            break;
1414        }
1415    }
1416
1417    // Write per-rule iteration counts to the shared slot.
1418    if let Ok(mut counts) = iteration_counts.write() {
1419        for rule in &rules {
1420            counts.insert(rule.name.clone(), total_iters);
1421        }
1422    }
1423
1424    // If we exhausted all iterations without converging, record the iteration
1425    // limit (distinct from a wall-clock timeout) and proceed with partial
1426    // results rather than discarding all work. `set` is first-wins, so a
1427    // wall-clock timeout recorded above is not overwritten here.
1428    if !converged && interruption::reason(&timeout_flag).is_none() {
1429        tracing::warn!(
1430            "fixpoint did not converge after {max_iterations} iterations; returning partial results",
1431        );
1432        interruption::set(&timeout_flag, interruption::ITERATION_LIMIT);
1433    }
1434
1435    // Post-fixpoint processing per rule and collect output
1436    let task_ctx = session_ctx.read().task_ctx();
1437    let mut all_output = Vec::new();
1438
1439    for (rule_idx, state) in states.into_iter().enumerate() {
1440        let rule = &rules[rule_idx];
1441        let mut facts = state.into_facts();
1442        if facts.is_empty() {
1443            continue;
1444        }
1445
1446        // Detect shared proofs before FOLD collapses groups.
1447        //
1448        // TODO(C0-stage2): swap `detect_shared_lineage` for `TopKTag`
1449        // DNF inspection when `semiring_kind == TopKProofs { k }`.
1450        // The library-layer tag math has landed in
1451        // `crates/uni-locy/src/top_k_proofs.rs` (Phase C C0 Stage 1);
1452        // Stage 2 plumbs `TopKTag` through `MonotonicAggState` /
1453        // `FoldExec` so per-row dependency DNFs are available here.
1454        // Until Stage 2, this scalar `ProvenanceStore` path runs for
1455        // every semiring including `TopKProofs` (per rollout D-4
1456        // "graceful migration").
1457        //
1458        // Phase-3 shared-proof detection is meaningful only under
1459        // `AddMultProb` (and `BddExact`, which is the AddMultProb math
1460        // plus a WMC post-correction). Under `MaxMinProb`, `plus = max`
1461        // is idempotent — shared proofs don't double-count — so the
1462        // warning is moot and we skip the work.
1463        let shared_info = if semiring_kind == SemiringKind::MaxMinProb {
1464            None
1465        } else if let Some(ref tracker) = derivation_tracker {
1466            detect_shared_lineage(rule, &facts, tracker, &warnings_slot, semiring_kind)
1467        } else {
1468            None
1469        };
1470
1471        // Apply BDD for shared groups if exact_probability is enabled.
1472        if exact_probability
1473            && let Some(ref info) = shared_info
1474            && let Some(ref tracker) = derivation_tracker
1475        {
1476            facts = apply_exact_wmc(
1477                facts,
1478                rule,
1479                info,
1480                tracker,
1481                max_bdd_variables,
1482                &warnings_slot,
1483                &approximate_slot,
1484            )?;
1485        }
1486
1487        let processed = apply_post_fixpoint_chain(
1488            facts,
1489            rule,
1490            &task_ctx,
1491            strict_probability_domain,
1492            probability_epsilon,
1493            semiring_kind,
1494            derivation_tracker.as_ref().map(Arc::clone),
1495            top_k_proofs,
1496            Some(Arc::clone(&registry)),
1497        )
1498        .await?;
1499        all_output.extend(processed);
1500    }
1501
1502    // If no output, return empty batch with output schema
1503    if all_output.is_empty() {
1504        all_output.push(RecordBatch::new_empty(output_schema));
1505    }
1506
1507    Ok(all_output)
1508}
1509
1510// ---------------------------------------------------------------------------
1511// Provenance recording helpers
1512// ---------------------------------------------------------------------------
1513
1514/// Record provenance for all newly derived facts (rows in the current delta).
1515///
1516/// Called after `merge_delta` returns `true`. Attributes each new fact to the
1517/// clause most likely to have produced it, using first-derivation-wins semantics.
1518/// Borrowed bundle of classifier-side runtime state used by
1519/// provenance / EXPLAIN-reconstruction code paths. Keeps function
1520/// signatures under the too-many-arguments threshold.
1521pub(crate) struct ClassifierRefs<'a> {
1522    pub registry: &'a Arc<ClassifierRegistry>,
1523    pub cache: Option<&'a Arc<uni_locy::ModelInvocationCache>>,
1524    /// Phase C B1-B3 follow-up: when `Some`, EXPLAIN's neural_calls
1525    /// collection consults the side-channel provenance store first
1526    /// (populated by `apply_model_invocations`). This is the only way
1527    /// to surface NeuralProvenance for Python-registered classifiers,
1528    /// whose model_invocations may be rewritten away by the planner
1529    /// and so wouldn't trigger the re-invocation fallback.
1530    pub provenance_store: Option<&'a Arc<uni_locy::NeuralProvenanceStore>>,
1531}
1532
1533/// Borrowed bundle of provenance-recording state: the in-flight
1534/// tracker, the derived-scan registry (used to resolve IS-ref inputs),
1535/// and the shared warnings slot. Bundled to keep
1536/// `record_provenance` / `record_and_detect_lineage_nonrecursive`
1537/// under the too-many-arguments threshold.
1538pub(crate) struct ProvenanceCtx<'a> {
1539    pub tracker: &'a Arc<ProvenanceStore>,
1540    pub registry: &'a Arc<DerivedScanRegistry>,
1541    pub warnings_slot: &'a Arc<StdRwLock<Vec<RuntimeWarning>>>,
1542}
1543
1544async fn record_provenance(
1545    prov: ProvenanceCtx<'_>,
1546    rule: &FixpointRulePlan,
1547    state: &FixpointState,
1548    clause_candidates: &[Vec<RecordBatch>],
1549    iteration: usize,
1550    top_k_proofs: usize,
1551    classifiers: ClassifierRefs<'_>,
1552) {
1553    let tracker = prov.tracker;
1554    let registry = prov.registry;
1555    let warnings_slot = prov.warnings_slot;
1556    let classifier_registry = classifiers.registry;
1557    let classifier_cache = classifiers.cache;
1558    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
1559
1560    // Pre-compute base fact probabilities for top-k mode.
1561    let base_probs = if top_k_proofs > 0 {
1562        tracker.base_fact_probs()
1563    } else {
1564        HashMap::new()
1565    };
1566
1567    let mut topk_acc = TopKProofAccumulator::new();
1568
1569    for delta_batch in state.all_delta() {
1570        for row_idx in 0..delta_batch.num_rows() {
1571            let row_hash = format!(
1572                "{:?}",
1573                extract_scalar_key(delta_batch, &all_indices, row_idx)
1574            )
1575            .into_bytes();
1576            let fact_row = batch_row_to_value_map(delta_batch, row_idx);
1577            let clause_index =
1578                find_clause_for_row(delta_batch, row_idx, &all_indices, clause_candidates);
1579
1580            let support = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
1581
1582            let proof_probability = if top_k_proofs > 0 {
1583                compute_proof_probability(&support, &base_probs)
1584            } else {
1585                None
1586            };
1587
1588            let entry = ProvenanceAnnotation {
1589                rule_name: rule.name.clone(),
1590                clause_index,
1591                support,
1592                along_values: {
1593                    let along_names: Vec<String> = rule
1594                        .clauses
1595                        .get(clause_index)
1596                        .map(|c| c.along_bindings.clone())
1597                        .unwrap_or_default();
1598                    along_names
1599                        .iter()
1600                        .filter_map(|name| fact_row.get(name).map(|v| (name.clone(), v.clone())))
1601                        .collect()
1602                },
1603                iteration,
1604                fact_row: fact_row.clone(),
1605                proof_probability,
1606                neural_calls: collect_neural_calls_for_row(
1607                    rule,
1608                    clause_index,
1609                    &fact_row,
1610                    classifier_registry,
1611                    classifier_cache,
1612                    classifiers.provenance_store,
1613                )
1614                .await,
1615            };
1616            if top_k_proofs > 0 {
1617                topk_acc.accumulate(&entry, &row_hash);
1618                tracker.record_top_k(row_hash, entry, top_k_proofs);
1619            } else {
1620                tracker.record(row_hash, entry);
1621            }
1622        }
1623    }
1624
1625    topk_acc.emit_warning_if_any(rule, top_k_proofs, warnings_slot);
1626}
1627
1628/// Phase C C0 Stage 2: collects per-row `Proof` tags during the
1629/// fixpoint row walk, then surfaces `TopKPruningCrossedDependency`
1630/// when post-walk top-K merging would drop a proof whose base RVs
1631/// overlap a retained one. The shared `BaseRv` interner is what
1632/// makes the overlap detectable — proofs grounded in the same
1633/// `base_fact_id` get the same `BaseRv`.
1634struct TopKProofAccumulator {
1635    per_fact: HashMap<Vec<u8>, Vec<uni_locy::Proof>>,
1636    base_rv_interner: HashMap<Vec<u8>, uni_locy::BaseRv>,
1637    next_rv: u32,
1638}
1639
1640impl TopKProofAccumulator {
1641    fn new() -> Self {
1642        Self {
1643            per_fact: HashMap::new(),
1644            base_rv_interner: HashMap::new(),
1645            next_rv: 0,
1646        }
1647    }
1648
1649    fn accumulate(&mut self, entry: &ProvenanceAnnotation, row_hash: &[u8]) {
1650        let mut base_rvs = uni_locy::BaseRvSet::empty();
1651        for term in &entry.support {
1652            let rv = *self
1653                .base_rv_interner
1654                .entry(term.base_fact_id.clone())
1655                .or_insert_with(|| {
1656                    let r = uni_locy::BaseRv(self.next_rv);
1657                    self.next_rv += 1;
1658                    r
1659                });
1660            base_rvs.insert(rv);
1661        }
1662        self.per_fact
1663            .entry(row_hash.to_vec())
1664            .or_default()
1665            .push(uni_locy::Proof {
1666                weight: entry.proof_probability.unwrap_or(0.0),
1667                base_rvs,
1668                neural_calls: Vec::new(),
1669            });
1670    }
1671
1672    fn emit_warning_if_any(
1673        &self,
1674        rule: &FixpointRulePlan,
1675        top_k_proofs: usize,
1676        warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
1677    ) {
1678        if top_k_proofs == 0 || self.per_fact.is_empty() {
1679            return;
1680        }
1681        let crossed_facts = self
1682            .per_fact
1683            .values()
1684            .filter(|proofs| {
1685                let (_kept, notice) =
1686                    uni_locy::merge_top_k_runtime(Vec::new(), (*proofs).clone(), top_k_proofs);
1687                notice == uni_locy::PruneNotice::CrossedDependency
1688            })
1689            .count();
1690        if crossed_facts == 0 {
1691            return;
1692        }
1693        let Ok(mut w) = warnings_slot.write() else {
1694            return;
1695        };
1696        let already = w.iter().any(|rw| {
1697            matches!(
1698                rw.code,
1699                uni_locy::types::RuntimeWarningCode::TopKPruningCrossedDependency
1700            ) && rw.rule_name == rule.name
1701        });
1702        if already {
1703            return;
1704        }
1705        w.push(RuntimeWarning {
1706            code: uni_locy::types::RuntimeWarningCode::TopKPruningCrossedDependency,
1707            rule_name: rule.name.clone(),
1708            message: format!(
1709                "rule '{}': top-K proof pruning (k={}) discarded {} fact(s) \
1710                 whose dependencies overlap retained proofs. The retained \
1711                 top-{} under-counts the true joint probability for those \
1712                 facts (Scallop, Huang et al. 2021). Increase k to recover.",
1713                rule.name, top_k_proofs, crossed_facts, top_k_proofs
1714            ),
1715            variable_count: None,
1716            key_group: None,
1717        });
1718    }
1719}
1720
1721/// Collect IS-ref input facts for a derived row using provenance join columns.
1722///
1723/// For each non-negated IS-ref binding in the clause, extracts body-side key
1724/// values from the delta row and finds matching source rows in the registry.
1725/// Returns a `ProofTerm` for each match (with the source fact hash).
1726/// Phase C B1–B3: build [`uni_locy::NeuralProvenance`] entries for
1727/// the model invocations on this clause by reading each
1728/// invocation's output column from the post-LocyModelInvoke row.
1729/// `raw_probability` is the classifier's direct output;
1730/// `calibrated_probability` and `confidence_band` come from the
1731/// active Calibrator (when the classifier wraps one).
1732///
1733/// Phase C B1-B3 follow-up rewrite: re-evaluate the classifier
1734/// per fact using the ORIGINAL pre-rewrite feature expressions
1735/// (stored as `invocation.original_feature_exprs`). This works
1736/// for invocations in YIELD, ALONG, and FOLD positions uniformly
1737/// — the original args carry the input bindings regardless of
1738/// where the synthetic `__model_<n>` column ends up in the plan
1739/// tree. Memoization via `ModelInvocationCache` (already threaded)
1740/// absorbs repeat costs; EXPLAIN typically operates on small
1741/// derivation trees so the per-fact classifier call is bounded.
1742async fn collect_neural_calls_for_row(
1743    rule: &FixpointRulePlan,
1744    clause_index: usize,
1745    fact_row: &uni_locy::FactRow,
1746    classifier_registry: &Arc<ClassifierRegistry>,
1747    classifier_cache: Option<&Arc<uni_locy::ModelInvocationCache>>,
1748    provenance_store: Option<&Arc<uni_locy::NeuralProvenanceStore>>,
1749) -> Vec<uni_locy::NeuralProvenance> {
1750    let Some(clause) = rule.clauses.get(clause_index) else {
1751        return Vec::new();
1752    };
1753    if clause.model_invocations.is_empty() {
1754        return Vec::new();
1755    }
1756    let mut out = Vec::with_capacity(clause.model_invocations.len());
1757    for invocation in &clause.model_invocations {
1758        // Build ClassifyInput from the REWRITTEN feature expressions
1759        // (referencing the synthetic `__feat_*` hidden columns that the
1760        // planner lifts model-call args into). This matches the writer's
1761        // path in `apply_model_invocations`, which iterates
1762        // `invocation.feature_exprs` to compute the same `input_hash`
1763        // that gets stored. Using the pre-rewrite `original_feature_exprs`
1764        // here would compute a different hash for YIELD-position
1765        // invocations (where the pre-rewrite expr references properties
1766        // not materialised into fact_row), causing the store lookup
1767        // below to miss and `neural_calls` to come back empty.
1768        let mut features = std::collections::HashMap::new();
1769        for (binding_name, feat_expr) in invocation
1770            .feature_names
1771            .iter()
1772            .zip(invocation.feature_exprs.iter())
1773        {
1774            features.insert(
1775                binding_name.clone(),
1776                eval_feature_expr_against_fact_row(feat_expr, fact_row),
1777            );
1778        }
1779        let input = uni_locy::ClassifyInput { features };
1780        let input_hash = input.stable_hash();
1781
1782        // Store-first read path. `apply_model_invocations` writes a
1783        // NeuralProvenanceRecord per (model, input_hash) into the
1784        // side-channel store during fixpoint. If we find a record
1785        // there, surface it directly — this is the only path that
1786        // populates calibrated_probability + confidence_band for
1787        // Python-registered classifiers.
1788        if let Some(store) = provenance_store
1789            && let Some(record) = store.get(&invocation.model_name, input_hash)
1790        {
1791            out.push(uni_locy::NeuralProvenance {
1792                model_name: invocation.model_name.clone(),
1793                raw_probability: record.raw_probability,
1794                calibrated_probability: record.calibrated_probability,
1795                confidence_band: record.confidence_band,
1796            });
1797            continue;
1798        }
1799
1800        // Fallback: re-invoke the classifier. Only reached when the
1801        // store wasn't populated for this (model, input) — e.g. older
1802        // sessions where the store wasn't threaded, or when EXPLAIN
1803        // runs against a row that fixpoint never touched.
1804        let Some(classifier) = classifier_registry.get(&invocation.model_name) else {
1805            continue;
1806        };
1807        let raw = if let Some(v) =
1808            classifier_cache.and_then(|c| c.get(&invocation.model_name, input_hash))
1809        {
1810            v
1811        } else {
1812            match classifier.classify(std::slice::from_ref(&input)).await {
1813                Ok(probs) => {
1814                    let v = probs.first().copied().unwrap_or(0.0);
1815                    if let Some(c) = classifier_cache {
1816                        c.insert(&invocation.model_name, input_hash, v);
1817                    }
1818                    v
1819                }
1820                Err(_) => continue,
1821            }
1822        };
1823        let calibrator = classifier.get_calibrator();
1824        let calibrated_probability = calibrator.as_ref().map(|_| raw);
1825        let confidence_band = calibrator.as_ref().and_then(|c| c.confidence_band(raw));
1826        out.push(uni_locy::NeuralProvenance {
1827            model_name: invocation.model_name.clone(),
1828            raw_probability: raw,
1829            calibrated_probability,
1830            confidence_band,
1831        });
1832    }
1833    out
1834}
1835
1836/// Phase C B1-B3 follow-up: evaluate a model's pre-rewrite feature
1837/// expression against a fact_row, producing a `FeatureValue` for
1838/// classifier input reconstruction. Mirrors the compile-time
1839/// acceptance set in `validate_features`:
1840/// - `Variable(name)` → fact_row[name], coerced.
1841/// - `Property(Variable(v), prop)` → fact_row["v.prop"] (the
1842///   materialized property column).
1843fn eval_feature_expr_against_fact_row(
1844    expr: &uni_cypher::ast::Expr,
1845    fact_row: &uni_locy::FactRow,
1846) -> uni_locy::FeatureValue {
1847    use uni_cypher::ast::Expr;
1848    use uni_locy::FeatureValue;
1849    let value_to_feature = |v: Option<&uni_common::Value>| -> FeatureValue {
1850        match v {
1851            Some(uni_common::Value::Float(f)) => FeatureValue::Float(*f),
1852            Some(uni_common::Value::Int(i)) => FeatureValue::Int(*i),
1853            Some(uni_common::Value::Bool(b)) => FeatureValue::Bool(*b),
1854            Some(uni_common::Value::String(s)) => FeatureValue::String(s.clone()),
1855            Some(uni_common::Value::Node(n)) => {
1856                // Encode node by vid for `scorer(s)` style.
1857                FeatureValue::Int(n.vid.as_u64() as i64)
1858            }
1859            _ => FeatureValue::Null,
1860        }
1861    };
1862    // Phase D D1: resolve a sub-expression to its raw `uni_common::Value`
1863    // for the `similar_to` UDF input. Falls back to the node's property
1864    // when the materialized column key isn't directly present.
1865    let resolve_value = |sub: &Expr| -> uni_common::Value {
1866        match sub {
1867            Expr::Variable(name) => fact_row
1868                .get(name)
1869                .cloned()
1870                .unwrap_or(uni_common::Value::Null),
1871            Expr::Property(boxed, prop) if matches!(boxed.as_ref(), Expr::Variable(_)) => {
1872                let Expr::Variable(v) = boxed.as_ref() else {
1873                    unreachable!()
1874                };
1875                let key = format!("{}.{}", v, prop);
1876                if let Some(val) = fact_row.get(&key) {
1877                    return val.clone();
1878                }
1879                if let Some(uni_common::Value::Node(n)) = fact_row.get(v) {
1880                    return n
1881                        .properties
1882                        .get(prop)
1883                        .cloned()
1884                        .unwrap_or(uni_common::Value::Null);
1885                }
1886                uni_common::Value::Null
1887            }
1888            Expr::Literal(lit) => lit.to_value(),
1889            Expr::List(items) => {
1890                let mut out = Vec::with_capacity(items.len());
1891                for it in items {
1892                    out.push(match it {
1893                        Expr::Literal(lit) => lit.to_value(),
1894                        _ => uni_common::Value::Null,
1895                    });
1896                }
1897                uni_common::Value::List(out)
1898            }
1899            _ => uni_common::Value::Null,
1900        }
1901    };
1902
1903    match expr {
1904        Expr::Variable(name) => value_to_feature(fact_row.get(name)),
1905        Expr::Property(boxed, prop) => {
1906            if let Expr::Variable(v) = boxed.as_ref() {
1907                // Try the materialized property column first.
1908                let key = format!("{}.{}", v, prop);
1909                if let Some(val) = fact_row.get(&key) {
1910                    return value_to_feature(Some(val));
1911                }
1912                // Fallback: try the synthetic hidden column that the
1913                // planner injects for property-access feature args
1914                // (`__feat_<var>_<prop>`). The writer side
1915                // (`apply_model_invocations`) already uses this
1916                // fallback (see `resolve_src` in the same file), so
1917                // mirroring it here keeps reader/writer input_hash
1918                // symmetric — without it, the YIELD-position case
1919                // (where `fact_row[v]` is a vid Int, not a Node)
1920                // returns Null and the store-lookup misses.
1921                let hidden_key = format!("__feat_{}_{}", v, prop);
1922                if let Some(val) = fact_row.get(&hidden_key) {
1923                    return value_to_feature(Some(val));
1924                }
1925                // Final fallback: read property directly from the
1926                // node value (works when fact_row carries the Node
1927                // rather than a vid Int).
1928                if let Some(uni_common::Value::Node(n)) = fact_row.get(v) {
1929                    return value_to_feature(n.properties.get(prop));
1930                }
1931            }
1932            FeatureValue::Null
1933        }
1934        Expr::FunctionCall { name, args, .. } if name == "similar_to" && args.len() == 2 => {
1935            let lv = resolve_value(&args[0]);
1936            let rv = resolve_value(&args[1]);
1937            match crate::query::similar_to::eval_similar_to_pure(&lv, &rv) {
1938                Ok(uni_common::Value::Float(f)) => FeatureValue::Float(f),
1939                _ => FeatureValue::Null,
1940            }
1941        }
1942        // `semantic_match` requires the Xervo embedder at this scope, which
1943        // is not threaded into the EXPLAIN re-evaluation path. Surface as
1944        // Null so neural-provenance still renders for the rest of the row.
1945        //
1946        // Phase D D1 graph-structural FunctionCalls (`degree_centrality`,
1947        // `pagerank_score`, `closeness_centrality`, `avg_neighbor`,
1948        // `max_neighbor`, `sum_neighbor`) require the `GraphAlgoHandle`
1949        // (algorithm registry + storage + PropertyManager) and an async
1950        // re-precompute pass — none of which are reachable from this
1951        // synchronous fact-row evaluator. Mode B re-evaluation surfaces
1952        // them as Null; the authoritative hot-path values are recorded
1953        // in `NeuralProvenanceStore` per fact (the EXPLAIN renderer
1954        // consults the store first when configured, falling back to
1955        // Mode B re-evaluation only as a backup).
1956        Expr::FunctionCall { name, .. }
1957            if matches!(
1958                name.as_str(),
1959                "degree_centrality"
1960                    | "pagerank_score"
1961                    | "closeness_centrality"
1962                    | "betweenness_centrality"
1963                    | "eigenvector_centrality"
1964                    | "harmonic_centrality"
1965                    | "katz_centrality"
1966                    | "avg_neighbor"
1967                    | "max_neighbor"
1968                    | "sum_neighbor"
1969            ) =>
1970        {
1971            FeatureValue::Null
1972        }
1973        _ => FeatureValue::Null,
1974    }
1975}
1976
1977fn collect_is_ref_inputs(
1978    rule: &FixpointRulePlan,
1979    clause_index: usize,
1980    delta_batch: &RecordBatch,
1981    row_idx: usize,
1982    registry: &Arc<DerivedScanRegistry>,
1983) -> Vec<ProofTerm> {
1984    let clause = match rule.clauses.get(clause_index) {
1985        Some(c) => c,
1986        None => return vec![],
1987    };
1988
1989    let mut inputs = Vec::new();
1990    let delta_schema = delta_batch.schema();
1991
1992    for binding in &clause.is_ref_bindings {
1993        if binding.negated {
1994            continue;
1995        }
1996        if binding.provenance_join_cols.is_empty() {
1997            continue;
1998        }
1999
2000        // Extract body-side values from the delta row for each provenance join col.
2001        let body_values: Vec<(String, ScalarKey)> = binding
2002            .provenance_join_cols
2003            .iter()
2004            .filter_map(|(body_col, _derived_col)| {
2005                let col_idx = delta_schema
2006                    .fields()
2007                    .iter()
2008                    .position(|f| f.name() == body_col)?;
2009                let key = extract_scalar_key(delta_batch, &[col_idx], row_idx);
2010                Some((body_col.clone(), key.into_iter().next()?))
2011            })
2012            .collect();
2013
2014        if body_values.len() != binding.provenance_join_cols.len() {
2015            continue;
2016        }
2017
2018        // Read current data from the registry entry for this IS-ref's rule.
2019        let entry = match registry.get(binding.derived_scan_index) {
2020            Some(e) => e,
2021            None => continue,
2022        };
2023        let source_batches = entry.data.read();
2024        let source_schema = &entry.schema;
2025
2026        // Find matching source rows and hash them.
2027        for src_batch in source_batches.iter() {
2028            let all_src_indices: Vec<usize> = (0..src_batch.num_columns()).collect();
2029            for src_row in 0..src_batch.num_rows() {
2030                let matches = binding.provenance_join_cols.iter().enumerate().all(
2031                    |(i, (_body_col, derived_col))| {
2032                        let src_col_idx = source_schema
2033                            .fields()
2034                            .iter()
2035                            .position(|f| f.name() == derived_col);
2036                        match src_col_idx {
2037                            Some(idx) => {
2038                                let src_key = extract_scalar_key(src_batch, &[idx], src_row);
2039                                src_key.first() == Some(&body_values[i].1)
2040                            }
2041                            None => false,
2042                        }
2043                    },
2044                );
2045                if matches {
2046                    let fact_hash = format!(
2047                        "{:?}",
2048                        extract_scalar_key(src_batch, &all_src_indices, src_row)
2049                    )
2050                    .into_bytes();
2051                    inputs.push(ProofTerm {
2052                        source_rule: binding.rule_name.clone(),
2053                        base_fact_id: fact_hash,
2054                    });
2055                }
2056            }
2057        }
2058    }
2059
2060    inputs
2061}
2062
2063/// Phase D D-C0: per-body-row variant of [`collect_is_ref_inputs`] used
2064/// to pre-populate `FoldExec`'s `body_support_map` for TopKProofs MNOR.
2065///
2066/// At FOLD time, the rule's own facts haven't been recorded in the
2067/// `ProvenanceStore` yet (`record_provenance` runs after fact
2068/// materialization, and is keyed by post-YIELD hashes anyway), so the
2069/// support set for each pre-fold body row must be reconstructed
2070/// directly from the rule's IS-ref bindings + the source rules'
2071/// registry data.
2072///
2073/// We don't know which clause produced each body row at this point —
2074/// the iteration-local `clause_candidates` are gone — so we iterate
2075/// **every** clause's `is_ref_bindings`. The `provenance_join_cols`
2076/// schema check inside `collect_is_ref_inputs` already skips bindings
2077/// whose body columns aren't in the row's schema, so cross-clause
2078/// contamination is bounded (a binding only matches if its body cols
2079/// are present and the values join). For the single-clause TopKProofs
2080/// scenarios in TCK this is exact; for multi-clause TopKProofs rules
2081/// it is a conservative over-approximation that may inflate base-RV
2082/// counts (treated as the same RV under interning) — acceptable
2083/// because the DNF math collapses duplicates by inclusion-exclusion.
2084fn collect_is_ref_inputs_for_body_row(
2085    rule: &FixpointRulePlan,
2086    delta_batch: &RecordBatch,
2087    row_idx: usize,
2088    registry: &Arc<DerivedScanRegistry>,
2089) -> Vec<ProofTerm> {
2090    let mut combined: Vec<ProofTerm> = Vec::new();
2091    for clause_index in 0..rule.clauses.len() {
2092        let part = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
2093        combined.extend(part);
2094    }
2095    combined
2096}
2097
2098// ---------------------------------------------------------------------------
2099// Shared-lineage detection
2100// ---------------------------------------------------------------------------
2101
// Design note for `detect_shared_lineage` (defined later in this file).
//
// It detects KEY groups in a rule's pre-fold facts where recursive
// derivation may violate the independence assumption of MNOR/MPROD.
//
// Uses a two-tier strategy:
// 1. **Precise**: If the `ProvenanceStore` has populated `support` for facts
//    in the group, we recursively compute lineage (Cui & Widom 2000) and
//    check for pairwise overlap. A shared base fact proves a dependency.
// 2. **Structural fallback**: When lineage tracking is unavailable (e.g., the
//    IS-ref subject variables were projected away), we check whether any fact
//    in a multi-row group was derived by a clause that has IS-ref bindings.
//    Recursive derivation through shared relations is a strong signal that
//    proof paths may share intermediate nodes.

/// Per-row data collected during shared-lineage detection.
#[expect(
    dead_code,
    reason = "Fields accessed via SharedLineageInfo in detect_shared_lineage"
)]
pub(crate) struct SharedGroupRow {
    /// Byte key identifying the row (presumably produced by
    /// `fact_hash_key` — confirm at the construction site).
    pub fact_hash: Vec<u8>,
    /// Set of base-fact keys in this row's lineage (presumably the
    /// result of the lineage computation — confirm at the construction
    /// site).
    pub lineage: HashSet<Vec<u8>>,
}
2124
2125/// Information about groups with shared proofs, returned by `detect_shared_lineage`.
2126pub(crate) struct SharedLineageInfo {
2127    /// KEY group → rows with their base fact sets.
2128    pub shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>>,
2129}
2130
2131/// Build a byte key that uniquely identifies a row across all columns.
2132pub(crate) fn fact_hash_key(batch: &RecordBatch, all_indices: &[usize], row_idx: usize) -> Vec<u8> {
2133    format!("{:?}", extract_scalar_key(batch, all_indices, row_idx)).into_bytes()
2134}
2135
2136/// Emits at most one `SharedProbabilisticDependency` warning per rule.
2137/// Returns `Some(SharedLineageInfo)` if any group has shared proofs.
2138fn detect_shared_lineage(
2139    rule: &FixpointRulePlan,
2140    pre_fold_facts: &[RecordBatch],
2141    tracker: &Arc<ProvenanceStore>,
2142    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2143    semiring_kind: SemiringKind,
2144) -> Option<SharedLineageInfo> {
2145    use uni_locy::{RuntimeWarning, RuntimeWarningCode};
2146
2147    // Only check rules with probability-domain fold bindings. M3:
2148    // dispatches via the `LocyAggregate` trait so user-authored
2149    // probability aggregates participate automatically — selected by the
2150    // trait's `is_probability_aggregate()` flag, not by hardcoded name.
2151    let has_prob_fold = rule
2152        .fold_bindings
2153        .iter()
2154        .any(|fb| fb.aggregate.is_probability_aggregate());
2155    if !has_prob_fold {
2156        return None;
2157    }
2158
2159    // Group facts by KEY columns.
2160    let key_indices = &rule.key_column_indices;
2161    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2162
2163    let mut groups: HashMap<Vec<ScalarKey>, Vec<Vec<u8>>> = HashMap::new();
2164    for batch in pre_fold_facts {
2165        for row_idx in 0..batch.num_rows() {
2166            let key = extract_scalar_key(batch, key_indices, row_idx);
2167            let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
2168            groups.entry(key).or_default().push(fact_hash);
2169        }
2170    }
2171
2172    let mut shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>> = HashMap::new();
2173    let mut any_shared = false;
2174
2175    // Check each group with ≥2 rows.
2176    for (key, fact_hashes) in &groups {
2177        if fact_hashes.len() < 2 {
2178            continue;
2179        }
2180
2181        // Tier 1: precise base-fact overlap detection via tracker inputs.
2182        let mut has_inputs = false;
2183        let mut per_row_bases: Vec<HashSet<Vec<u8>>> = Vec::new();
2184        for fh in fact_hashes {
2185            let bases = compute_lineage(fh, tracker, &mut HashSet::new());
2186            if let Some(entry) = tracker.lookup(fh)
2187                && !entry.support.is_empty()
2188            {
2189                has_inputs = true;
2190            }
2191            per_row_bases.push(bases);
2192        }
2193
2194        let shared_found = if has_inputs {
2195            // At least some facts have tracked inputs — do precise comparison.
2196            let mut found = false;
2197            'outer: for i in 0..per_row_bases.len() {
2198                for j in (i + 1)..per_row_bases.len() {
2199                    if !per_row_bases[i].is_disjoint(&per_row_bases[j]) {
2200                        found = true;
2201                        break 'outer;
2202                    }
2203                }
2204            }
2205            found
2206        } else {
2207            // Tier 2: structural fallback — check if any fact in the group was
2208            // derived by a clause with IS-ref bindings (recursive derivation).
2209            fact_hashes.iter().any(|fh| {
2210                tracker.lookup(fh).is_some_and(|entry| {
2211                    rule.clauses
2212                        .get(entry.clause_index)
2213                        .is_some_and(|clause| clause.is_ref_bindings.iter().any(|b| !b.negated))
2214                })
2215            })
2216        };
2217
2218        if shared_found {
2219            any_shared = true;
2220            // Collect the group rows with their base facts for BDD use.
2221            let rows: Vec<SharedGroupRow> = fact_hashes
2222                .iter()
2223                .zip(per_row_bases)
2224                .map(|(fh, bases)| SharedGroupRow {
2225                    fact_hash: fh.clone(),
2226                    lineage: bases,
2227                })
2228                .collect();
2229            shared_groups.insert(key.clone(), rows);
2230        }
2231    }
2232
2233    // Phase 5: Cross-group correlation warning.
2234    // Check if any IS-ref input fact appears in multiple KEY groups.
2235    // This is independent of within-group sharing: even rules whose KEY groups
2236    // each have only one post-fold row can exhibit cross-group correlation when
2237    // different groups consume the same IS-ref base fact.
2238    {
2239        let mut input_to_groups: HashMap<Vec<u8>, HashSet<Vec<ScalarKey>>> = HashMap::new();
2240        for (key, fact_hashes) in &groups {
2241            for fh in fact_hashes {
2242                if let Some(entry) = tracker.lookup(fh) {
2243                    for input in &entry.support {
2244                        input_to_groups
2245                            .entry(input.base_fact_id.clone())
2246                            .or_default()
2247                            .insert(key.clone());
2248                    }
2249                }
2250            }
2251        }
2252        let has_cross_group = input_to_groups.values().any(|g| g.len() > 1);
2253        if has_cross_group && let Ok(mut warnings) = warnings_slot.write() {
2254            let already_warned = warnings.iter().any(|w| {
2255                w.code == RuntimeWarningCode::CrossGroupCorrelationNotExact
2256                    && w.rule_name == rule.name
2257            });
2258            if !already_warned {
2259                // Phase D F3: pick one canonical example of a shared
2260                // input fact and the KEY groups it bridges, so users
2261                // can correlate the warning with EXPLAIN output.
2262                let example =
2263                    input_to_groups
2264                        .iter()
2265                        .find(|(_, g)| g.len() > 1)
2266                        .map(|(input, groups)| {
2267                            let short = input
2268                                .iter()
2269                                .take(8)
2270                                .map(|b| format!("{:02x}", b))
2271                                .collect::<String>();
2272                            let mut group_strs: Vec<String> =
2273                                groups.iter().map(|k| format!("{:?}", k)).collect();
2274                            group_strs.sort();
2275                            format!(
2276                                "input {} shared by groups [{}]",
2277                                short,
2278                                group_strs.join(", ")
2279                            )
2280                        });
2281                // Phase D F3 case 1 BDD-time deepening: count distinct
2282                // base facts (= BDD variables) that cross groups, so the
2283                // warning carries structured metadata mirroring
2284                // `BddLimitExceeded`. Users can correlate
2285                // `variable_count` with EXPLAIN's BDD output.
2286                let shared_variable_count =
2287                    input_to_groups.values().filter(|g| g.len() > 1).count();
2288                warnings.push(RuntimeWarning {
2289                    code: RuntimeWarningCode::CrossGroupCorrelationNotExact,
2290                    message: format!(
2291                        "Rule '{}': {} IS-ref base fact(s) are shared across different \
2292                         KEY groups. BDD corrects per-group probabilities but cannot \
2293                         account for cross-group correlations.",
2294                        rule.name, shared_variable_count
2295                    ),
2296                    rule_name: rule.name.clone(),
2297                    variable_count: Some(shared_variable_count),
2298                    key_group: example,
2299                });
2300            }
2301        }
2302    }
2303
2304    if any_shared {
2305        // Phase D D-C0b: under `SemiringKind::TopKProofs`, the FOLD-time
2306        // DNF inclusion-exclusion math (shipped in D-C0) auto-corrects
2307        // for within-group base-fact sharing — the "Results may
2308        // overestimate" premise of `SharedProbabilisticDependency`
2309        // is no longer true. Suppress the warning under TopK; users
2310        // who chose TopKProofs explicitly opted into the
2311        // correctness-preserving path. Cross-group correlation
2312        // (`CrossGroupCorrelationNotExact`) still fires above because
2313        // D-C0 doesn't span KEY-group boundaries.
2314        let suppress_under_topk = matches!(semiring_kind, SemiringKind::TopKProofs { .. });
2315        if !suppress_under_topk && let Ok(mut warnings) = warnings_slot.write() {
2316            let already_warned = warnings.iter().any(|w| {
2317                w.code == RuntimeWarningCode::SharedProbabilisticDependency
2318                    && w.rule_name == rule.name
2319            });
2320            if !already_warned {
2321                warnings.push(RuntimeWarning {
2322                    code: RuntimeWarningCode::SharedProbabilisticDependency,
2323                    message: format!(
2324                        "Rule '{}' aggregates with MNOR/MPROD but some proof paths \
2325                         share intermediate facts, violating the independence assumption. \
2326                         Results may overestimate probability.",
2327                        rule.name
2328                    ),
2329                    rule_name: rule.name.clone(),
2330                    variable_count: None,
2331                    key_group: None,
2332                });
2333            }
2334        }
2335        Some(SharedLineageInfo { shared_groups })
2336    } else {
2337        None
2338    }
2339}
2340
2341/// Record provenance and detect shared proofs for non-recursive strata.
2342///
2343/// Non-recursive rules are evaluated in a single pass (no fixpoint loop), so
2344/// the regular `record_provenance` + `detect_shared_lineage` path is never hit.
2345/// This function bridges that gap by recording a `ProvenanceAnnotation` for every
2346/// fact produced by each clause and then running the same two-tier detection
2347/// logic used by the recursive path.
2348#[allow(
2349    clippy::too_many_arguments,
2350    reason = "context bundle would be over-engineering for one call site"
2351)]
2352pub(crate) async fn record_and_detect_lineage_nonrecursive(
2353    rule: &FixpointRulePlan,
2354    tagged_clause_facts: &[(usize, Vec<RecordBatch>)],
2355    tracker: &Arc<ProvenanceStore>,
2356    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2357    registry: &Arc<DerivedScanRegistry>,
2358    top_k_proofs: usize,
2359    classifiers: ClassifierRefs<'_>,
2360    semiring_kind: SemiringKind,
2361) -> Option<SharedLineageInfo> {
2362    let classifier_registry = classifiers.registry;
2363    let classifier_cache = classifiers.cache;
2364    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2365
2366    // Pre-compute base fact probabilities for top-k mode.
2367    let base_probs = if top_k_proofs > 0 {
2368        tracker.base_fact_probs()
2369    } else {
2370        HashMap::new()
2371    };
2372
2373    let mut topk_acc = TopKProofAccumulator::new();
2374
2375    // Record provenance for each clause's facts.
2376    for (clause_index, batches) in tagged_clause_facts {
2377        for batch in batches {
2378            for row_idx in 0..batch.num_rows() {
2379                let row_hash = fact_hash_key(batch, &all_indices, row_idx);
2380                let fact_row = batch_row_to_value_map(batch, row_idx);
2381
2382                let support = collect_is_ref_inputs(rule, *clause_index, batch, row_idx, registry);
2383
2384                let proof_probability = if top_k_proofs > 0 {
2385                    compute_proof_probability(&support, &base_probs)
2386                } else {
2387                    None
2388                };
2389
2390                let entry = ProvenanceAnnotation {
2391                    rule_name: rule.name.clone(),
2392                    clause_index: *clause_index,
2393                    support,
2394                    along_values: {
2395                        let along_names: Vec<String> = rule
2396                            .clauses
2397                            .get(*clause_index)
2398                            .map(|c| c.along_bindings.clone())
2399                            .unwrap_or_default();
2400                        along_names
2401                            .iter()
2402                            .filter_map(|name| {
2403                                fact_row.get(name).map(|v| (name.clone(), v.clone()))
2404                            })
2405                            .collect()
2406                    },
2407                    iteration: 0,
2408                    fact_row: fact_row.clone(),
2409                    proof_probability,
2410                    neural_calls: collect_neural_calls_for_row(
2411                        rule,
2412                        *clause_index,
2413                        &fact_row,
2414                        classifier_registry,
2415                        classifier_cache,
2416                        classifiers.provenance_store,
2417                    )
2418                    .await,
2419                };
2420                if top_k_proofs > 0 {
2421                    topk_acc.accumulate(&entry, &row_hash);
2422                    tracker.record_top_k(row_hash, entry, top_k_proofs);
2423                } else {
2424                    tracker.record(row_hash, entry);
2425                }
2426            }
2427        }
2428    }
2429
2430    topk_acc.emit_warning_if_any(rule, top_k_proofs, warnings_slot);
2431
2432    // Flatten all clause facts and run detection.
2433    let all_facts: Vec<RecordBatch> = tagged_clause_facts
2434        .iter()
2435        .flat_map(|(_, batches)| batches.iter().cloned())
2436        .collect();
2437    detect_shared_lineage(rule, &all_facts, tracker, warnings_slot, semiring_kind)
2438}
2439
2440/// Apply exact weighted model counting (WMC) for shared-lineage groups.
2441///
2442/// Replaces multiple rows in groups with shared lineage with a single
2443/// representative row whose PROB column carries the BDD-computed exact
2444/// probability (Sang et al. 2005). For groups that exceed
2445/// `max_bdd_variables`, rows are left unchanged and a `BddLimitExceeded`
2446/// warning is emitted.
2447pub(crate) fn apply_exact_wmc(
2448    pre_fold_facts: Vec<RecordBatch>,
2449    rule: &FixpointRulePlan,
2450    shared_info: &SharedLineageInfo,
2451    tracker: &Arc<ProvenanceStore>,
2452    max_bdd_variables: usize,
2453    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2454    approximate_slot: &Arc<StdRwLock<HashMap<String, Vec<String>>>>,
2455) -> DFResult<Vec<RecordBatch>> {
2456    use crate::query::df_graph::locy_bdd::{SemiringOp, weighted_model_count};
2457    use uni_locy::{RuntimeWarning, RuntimeWarningCode};
2458
2459    // Find the probability-domain fold binding to know which column
2460    // to overwrite. M3: dispatch through the `LocyAggregate` trait so
2461    // user-authored probability aggregates participate.
2462    let prob_fold = rule
2463        .fold_bindings
2464        .iter()
2465        .find(|fb| fb.aggregate.is_probability_aggregate());
2466    let prob_fold = match prob_fold {
2467        Some(f) => f,
2468        None => return Ok(pre_fold_facts),
2469    };
2470    let semiring_op = if prob_fold.aggregate.is_noisy_or() {
2471        SemiringOp::Disjunction
2472    } else {
2473        SemiringOp::Conjunction
2474    };
2475    let prob_col_idx = prob_fold.input_col_index;
2476    let prob_col_name = rule.yield_schema.field(prob_col_idx).name().clone();
2477
2478    let key_indices = &rule.key_column_indices;
2479    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2480
2481    // Build a set of shared group keys for quick lookup.
2482    let shared_keys: HashSet<Vec<ScalarKey>> = shared_info.shared_groups.keys().cloned().collect();
2483
2484    // Phase 1: Collect all rows for each shared KEY group across all batches.
2485    // Store (batch_index, row_index) pairs for each group.
2486    struct GroupAccum {
2487        base_facts: Vec<HashSet<Vec<u8>>>,
2488        base_probs: HashMap<Vec<u8>, f64>,
2489        /// First occurrence: (batch_index, row_index) — used as representative.
2490        representative: (usize, usize),
2491        row_locations: Vec<(usize, usize)>,
2492    }
2493
2494    let mut group_accums: HashMap<Vec<ScalarKey>, GroupAccum> = HashMap::new();
2495    let mut non_shared_rows: Vec<(usize, usize)> = Vec::new(); // (batch_idx, row_idx)
2496
2497    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
2498        for row_idx in 0..batch.num_rows() {
2499            let key = extract_scalar_key(batch, key_indices, row_idx);
2500            if shared_keys.contains(&key) {
2501                let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
2502                let bases = compute_lineage(&fact_hash, tracker, &mut HashSet::new());
2503
2504                let accum = group_accums.entry(key).or_insert_with(|| GroupAccum {
2505                    base_facts: Vec::new(),
2506                    base_probs: HashMap::new(),
2507                    representative: (batch_idx, row_idx),
2508                    row_locations: Vec::new(),
2509                });
2510
2511                // Look up probabilities for base facts.
2512                for bf in &bases {
2513                    if !accum.base_probs.contains_key(bf)
2514                        && let Some(entry) = tracker.lookup(bf)
2515                        && let Some(val) = entry.fact_row.get(&prob_col_name)
2516                        && let Some(p) = value_to_f64(val)
2517                    {
2518                        accum.base_probs.insert(bf.clone(), p);
2519                    }
2520                }
2521
2522                accum.base_facts.push(bases);
2523                accum.row_locations.push((batch_idx, row_idx));
2524            } else {
2525                non_shared_rows.push((batch_idx, row_idx));
2526            }
2527        }
2528    }
2529
2530    // Phase 2: Compute BDD for each shared group (across all batches).
2531    // Track which (batch_idx, row_idx) pairs to keep vs drop.
2532    let mut keep_rows: HashSet<(usize, usize)> = HashSet::new();
2533    // Map of (batch_idx, row_idx) → overridden PROB value (for BDD-succeeded groups).
2534    let mut overrides: HashMap<(usize, usize), f64> = HashMap::new();
2535
2536    // All non-shared rows are kept.
2537    for &loc in &non_shared_rows {
2538        keep_rows.insert(loc);
2539    }
2540
2541    for (key, accum) in &group_accums {
2542        let bdd_result = weighted_model_count(
2543            &accum.base_facts,
2544            &accum.base_probs,
2545            semiring_op,
2546            max_bdd_variables,
2547        );
2548
2549        if bdd_result.approximated {
2550            // Emit BddLimitExceeded warning (one per key group).
2551            if let Ok(mut warnings) = warnings_slot.write() {
2552                let key_desc = format!("{key:?}");
2553                let already_warned = warnings.iter().any(|w| {
2554                    w.code == RuntimeWarningCode::BddLimitExceeded
2555                        && w.rule_name == rule.name
2556                        && w.key_group.as_deref() == Some(&key_desc)
2557                });
2558                if !already_warned {
2559                    warnings.push(RuntimeWarning {
2560                        code: RuntimeWarningCode::BddLimitExceeded,
2561                        message: format!(
2562                            "Rule '{}': BDD variable limit exceeded ({} > {}). \
2563                             Falling back to independence-mode result.",
2564                            rule.name, bdd_result.variable_count, max_bdd_variables
2565                        ),
2566                        rule_name: rule.name.clone(),
2567                        variable_count: Some(bdd_result.variable_count),
2568                        key_group: Some(key_desc),
2569                    });
2570                }
2571            }
2572            if let Ok(mut approx) = approximate_slot.write() {
2573                let key_desc = format!("{key:?}");
2574                approx.entry(rule.name.clone()).or_default().push(key_desc);
2575            }
2576            // Keep all rows unchanged.
2577            for &loc in &accum.row_locations {
2578                keep_rows.insert(loc);
2579            }
2580        } else {
2581            // BDD succeeded: keep one representative row with overridden PROB.
2582            keep_rows.insert(accum.representative);
2583            overrides.insert(accum.representative, bdd_result.probability);
2584        }
2585    }
2586
2587    // Phase 3: Build output batches by filtering kept rows per batch.
2588    let mut result_batches = Vec::new();
2589    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
2590        let kept_indices: Vec<usize> = (0..batch.num_rows())
2591            .filter(|&row_idx| keep_rows.contains(&(batch_idx, row_idx)))
2592            .collect();
2593
2594        if kept_indices.is_empty() {
2595            continue;
2596        }
2597
2598        let indices = arrow::array::UInt32Array::from(
2599            kept_indices.iter().map(|&i| i as u32).collect::<Vec<_>>(),
2600        );
2601        let mut columns: Vec<arrow::array::ArrayRef> = batch
2602            .columns()
2603            .iter()
2604            .map(|col| arrow::compute::take(col, &indices, None))
2605            .collect::<Result<Vec<_>, _>>()
2606            .map_err(arrow_err)?;
2607
2608        // Check if any kept rows have PROB overrides.
2609        let override_map: Vec<Option<f64>> = kept_indices
2610            .iter()
2611            .map(|&row_idx| overrides.get(&(batch_idx, row_idx)).copied())
2612            .collect();
2613
2614        if override_map.iter().any(|o| o.is_some()) && prob_col_idx < columns.len() {
2615            // Rebuild the PROB column with overrides.
2616            let existing_prob = columns[prob_col_idx]
2617                .as_any()
2618                .downcast_ref::<arrow::array::Float64Array>();
2619            let new_values: Vec<f64> = override_map
2620                .iter()
2621                .enumerate()
2622                .map(|(i, ov)| match ov {
2623                    Some(p) => *p,
2624                    None => existing_prob.map(|arr| arr.value(i)).unwrap_or(0.0),
2625                })
2626                .collect();
2627            columns[prob_col_idx] = Arc::new(arrow::array::Float64Array::from(new_values));
2628        }
2629
2630        let result_batch = RecordBatch::try_new(batch.schema(), columns).map_err(arrow_err)?;
2631        result_batches.push(result_batch);
2632    }
2633
2634    Ok(result_batches)
2635}
2636
2637/// Extract an f64 from a `Value`, supporting Float and Int.
2638fn value_to_f64(val: &uni_common::Value) -> Option<f64> {
2639    match val {
2640        uni_common::Value::Float(f) => Some(*f),
2641        uni_common::Value::Int(i) => Some(*i as f64),
2642        _ => None,
2643    }
2644}
2645
2646/// Compute the lineage of a derived fact (Cui & Widom 2000).
2647///
2648/// Recursively traverses the provenance store to collect the set of base-level
2649/// fact hashes that contribute to this derivation. A base fact is one with no
2650/// IS-ref support (a graph-level fact). Intermediate facts are expanded
2651/// transitively through the store.
2652fn compute_lineage(
2653    fact_hash: &[u8],
2654    tracker: &Arc<ProvenanceStore>,
2655    visited: &mut HashSet<Vec<u8>>,
2656) -> HashSet<Vec<u8>> {
2657    if !visited.insert(fact_hash.to_vec()) {
2658        return HashSet::new(); // Cycle guard.
2659    }
2660
2661    match tracker.lookup(fact_hash) {
2662        Some(entry) if !entry.support.is_empty() => {
2663            let mut bases = HashSet::new();
2664            for input in &entry.support {
2665                let child_bases = compute_lineage(&input.base_fact_id, tracker, visited);
2666                bases.extend(child_bases);
2667            }
2668            bases
2669        }
2670        _ => {
2671            // Base fact (no tracker entry or no inputs).
2672            let mut set = HashSet::new();
2673            set.insert(fact_hash.to_vec());
2674            set
2675        }
2676    }
2677}
2678
2679/// Determine which clause produced a given row by checking each clause's candidates.
2680///
2681/// Returns the index of the first clause whose candidates contain a matching row.
2682/// Falls back to 0 if no match is found.
2683fn find_clause_for_row(
2684    delta_batch: &RecordBatch,
2685    row_idx: usize,
2686    all_indices: &[usize],
2687    clause_candidates: &[Vec<RecordBatch>],
2688) -> usize {
2689    let target_key = extract_scalar_key(delta_batch, all_indices, row_idx);
2690    for (clause_idx, batches) in clause_candidates.iter().enumerate() {
2691        for batch in batches {
2692            if batch.num_columns() != all_indices.len() {
2693                continue;
2694            }
2695            for r in 0..batch.num_rows() {
2696                if extract_scalar_key(batch, all_indices, r) == target_key {
2697                    return clause_idx;
2698                }
2699            }
2700        }
2701    }
2702    0
2703}
2704
2705/// Convert a single row from a `RecordBatch` at `row_idx` into a `HashMap<String, Value>`.
2706fn batch_row_to_value_map(
2707    batch: &RecordBatch,
2708    row_idx: usize,
2709) -> std::collections::HashMap<String, Value> {
2710    use uni_store::storage::arrow_convert::arrow_to_value;
2711
2712    let schema = batch.schema();
2713    schema
2714        .fields()
2715        .iter()
2716        .enumerate()
2717        .map(|(col_idx, field)| {
2718            let col = batch.column(col_idx);
2719            let val = arrow_to_value(col.as_ref(), row_idx, None);
2720            (field.name().clone(), val)
2721        })
2722        .collect()
2723}
2724
2725/// Filter `batches` to exclude rows where `left_col` VID appears in `neg_facts[right_col]`.
2726///
2727/// Implements anti-join semantics for negated IS-refs (`n IS NOT rule`): keeps only
2728/// rows whose subject VID is NOT present in the negated rule's fully-converged facts.
2729pub fn apply_anti_join(
2730    batches: Vec<RecordBatch>,
2731    neg_facts: &[RecordBatch],
2732    left_col: &str,
2733    right_col: &str,
2734) -> datafusion::error::Result<Vec<RecordBatch>> {
2735    use arrow::compute::filter_record_batch;
2736    use arrow_array::{Array as _, BooleanArray, UInt64Array};
2737
2738    // Collect right-side VIDs from the negated rule's derived facts.
2739    let mut banned: std::collections::HashSet<u64> = std::collections::HashSet::new();
2740    for batch in neg_facts {
2741        let Ok(idx) = batch.schema().index_of(right_col) else {
2742            continue;
2743        };
2744        let arr = batch.column(idx);
2745        let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2746            continue;
2747        };
2748        for i in 0..vids.len() {
2749            if !vids.is_null(i) {
2750                banned.insert(vids.value(i));
2751            }
2752        }
2753    }
2754
2755    if banned.is_empty() {
2756        return Ok(batches);
2757    }
2758
2759    // Filter body batches: keep rows where left_col NOT IN banned.
2760    let mut result = Vec::new();
2761    for batch in batches {
2762        let Ok(idx) = batch.schema().index_of(left_col) else {
2763            result.push(batch);
2764            continue;
2765        };
2766        let arr = batch.column(idx);
2767        let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2768            result.push(batch);
2769            continue;
2770        };
2771        let keep: Vec<bool> = (0..vids.len())
2772            .map(|i| vids.is_null(i) || !banned.contains(&vids.value(i)))
2773            .collect();
2774        let keep_arr = BooleanArray::from(keep);
2775        let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
2776        if filtered.num_rows() > 0 {
2777            result.push(filtered);
2778        }
2779    }
2780    Ok(result)
2781}
2782
2783// ─── Phase B Slice 3: neural-model invocation pass ───────────────────────
2784//
2785// `apply_model_invocations` runs every `ModelInvocation` lifted from a
2786// clause's YIELD items against the body's output batches. For each
2787// invocation it:
2788//
2789//   1. Resolves each `feature_expr` to a column in the batch — Slice 3
2790//      supports plain `Expr::Variable("name")` references; richer
2791//      expressions (property access, nested calls) are deferred.
2792//   2. Builds one `ClassifyInput` per row keyed by the model's input
2793//      binding names.
2794//   3. Issues a single batched `NeuralClassifier::classify` call.
2795//   4. Appends the resulting `Float64Array` as a new column matching
2796//      `invocation.output_column`.
2797//
2798// Errors:
2799//   * `UnknownNeuralModel`: the model name isn't in the registry.
2800//   * Mismatched feature-expr / column: returned as a DataFusion
2801//     Execution error.
2802
2803#[allow(clippy::too_many_arguments)]
2804pub(crate) async fn apply_model_invocations(
2805    batches: Vec<RecordBatch>,
2806    invocations: &[uni_locy::ModelInvocation],
2807    registry: &Arc<ClassifierRegistry>,
2808    cache: Option<&Arc<uni_locy::ModelInvocationCache>>,
2809    provenance_store: Option<&Arc<uni_locy::NeuralProvenanceStore>>,
2810    path_context_handles: &HashMap<
2811        String,
2812        crate::query::df_graph::locy_model_invoke::PathContextHandle,
2813    >,
2814    xervo_runtime: &crate::query::df_graph::locy_model_invoke::XervoRuntimeHandle,
2815    graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
2816) -> DFResult<Vec<RecordBatch>> {
2817    use uni_locy::ClassifyInput;
2818    if batches.is_empty() || invocations.is_empty() {
2819        return Ok(batches);
2820    }
2821    // Phase D D2: pre-embed all unique `semantic_match` query literals
2822    // once per call. Resolvers below lower each `semantic_match(prop,
2823    // 'text')` into a `SimilarTo { left: prop_col, right: Const(Vector) }`.
2824    let semantic_match_embeddings =
2825        pre_embed_semantic_match_queries(invocations, xervo_runtime).await?;
2826    // Phase D D1 graph-structural: pre-compute topology scores and
2827    // neighbor-property maps for every distinct (fn_name, args) tuple
2828    // appearing in any FEATURE FunctionCall. One pass per call; reused
2829    // across every row of every batch.
2830    let graph_feature_maps = precompute_graph_feature_maps(invocations, graph_algo).await?;
2831    let neighbor_feature_maps =
2832        precompute_neighbor_feature_maps(invocations, &batches, graph_algo).await?;
2833    let mut out_batches = Vec::with_capacity(batches.len());
2834    for batch in batches {
2835        let mut current = batch;
2836        for invocation in invocations {
2837            let classifier = registry.get(&invocation.model_name).ok_or_else(|| {
2838                datafusion::error::DataFusionError::Execution(format!(
2839                    "neural classifier '{}' not registered; \
2840                         add it to LocyConfig::classifier_registry",
2841                    invocation.model_name
2842                ))
2843            })?;
2844
2845            // Resolve each feature_expr to a per-row evaluator.
2846            // Supported shapes (validated at compile time by
2847            // `extract_model_invocations` / `validate_features`):
2848            //   * `Expr::Variable("name")` — direct column reference.
2849            //   * `Expr::Property(Variable(v), prop)` — looked up by the
2850            //     conventional `"v.prop"` column name materialized by
2851            //     the planner's `translate_property_access` pipeline.
2852            //   * `Expr::FunctionCall { name: "similar_to"|"semantic_match", ... }`
2853            //     — Phase D D1/D2 retrieval-backed feature; both args
2854            //     resolved to columns; UDF evaluated per row against
2855            //     the row's `uni_common::Value` payloads.
2856            let resolvers = build_feature_resolvers(
2857                &current,
2858                invocation,
2859                path_context_handles,
2860                &semantic_match_embeddings,
2861                &graph_feature_maps,
2862                &neighbor_feature_maps,
2863            )?;
2864
2865            // Build one ClassifyInput per row.
2866            let n_rows = current.num_rows();
2867            let mut inputs: Vec<ClassifyInput> = Vec::with_capacity(n_rows);
2868            let mut input_hashes: Vec<u64> = Vec::with_capacity(n_rows);
2869            for row_idx in 0..n_rows {
2870                let mut features = std::collections::HashMap::new();
2871                for resolver in &resolvers {
2872                    let value = resolver.eval_row(&current, row_idx)?;
2873                    features.insert(resolver.binding_name.clone(), value);
2874                }
2875                let input = ClassifyInput { features };
2876                input_hashes.push(input.stable_hash());
2877                inputs.push(input);
2878            }
2879
2880            // Slice 2: memoization. Split inputs into cache hits and
2881            // misses; only call `classify` on the misses, then weave
2882            // the cached values back in by original row index.
2883            let mut probs: Vec<f64> = vec![0.0; n_rows];
2884            let mut miss_inputs: Vec<ClassifyInput> = Vec::new();
2885            let mut miss_row_indices: Vec<usize> = Vec::new();
2886            if let Some(c) = cache {
2887                for (row_idx, h) in input_hashes.iter().enumerate() {
2888                    match c.get(&invocation.model_name, *h) {
2889                        Some(v) => probs[row_idx] = v,
2890                        None => {
2891                            miss_row_indices.push(row_idx);
2892                            miss_inputs.push(inputs[row_idx].clone());
2893                        }
2894                    }
2895                }
2896            } else {
2897                miss_row_indices = (0..n_rows).collect();
2898                miss_inputs = inputs.clone();
2899            }
2900
2901            // Phase C C-RawCalibratedSeparation: when a calibrator is
2902            // present, route through `raw_and_calibrated` so the
2903            // provenance store records both the base classifier's raw
2904            // output AND the post-calibrator value. Bare classifiers
2905            // (no calibrator) keep using `classify`. The downstream
2906            // `probs[row]` is always the *calibrated* value when
2907            // available — that's what the rule's PROB output column
2908            // and the memoization cache carry.
2909            let calibrator = classifier.get_calibrator();
2910            let (miss_raws, miss_calibrated) = if miss_inputs.is_empty() {
2911                (Vec::new(), Vec::new())
2912            } else if calibrator.is_some() {
2913                let pairs = classifier
2914                    .raw_and_calibrated(&miss_inputs)
2915                    .await
2916                    .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2917                if pairs.len() != miss_inputs.len() {
2918                    return Err(datafusion::error::DataFusionError::Execution(format!(
2919                        "classifier '{}' raw_and_calibrated returned {} outputs for {} inputs",
2920                        invocation.model_name,
2921                        pairs.len(),
2922                        miss_inputs.len()
2923                    )));
2924                }
2925                let raws: Vec<f64> = pairs.iter().map(|(r, _)| *r).collect();
2926                let cals: Vec<f64> = pairs.iter().map(|(r, c)| c.unwrap_or(*r)).collect();
2927                (raws, cals)
2928            } else {
2929                let r = classifier
2930                    .classify(&miss_inputs)
2931                    .await
2932                    .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2933                if r.len() != miss_inputs.len() {
2934                    return Err(datafusion::error::DataFusionError::Execution(format!(
2935                        "classifier '{}' returned {} outputs for {} inputs",
2936                        invocation.model_name,
2937                        r.len(),
2938                        miss_inputs.len()
2939                    )));
2940                }
2941                // No calibrator → raw == final.
2942                (r.clone(), r)
2943            };
2944            // The memoization cache stores the *final* (calibrated when
2945            // available) value, matching what downstream rules consume.
2946            // Track per-miss raws alongside so the provenance store
2947            // sees both. For cache hits, we don't have raws — the
2948            // provenance record for that row gets None for `raw` and
2949            // the cached value as `calibrated` (the only thing we
2950            // remembered). Future slice can extend the cache to carry
2951            // both; current behavior matches pre-fix correctness for
2952            // EXPLAIN of cache hits.
2953            let mut row_raw: Vec<Option<f64>> = vec![None; n_rows];
2954            for (i, &row_idx) in miss_row_indices.iter().enumerate() {
2955                probs[row_idx] = miss_calibrated[i];
2956                row_raw[row_idx] = Some(miss_raws[i]);
2957                if let Some(c) = cache {
2958                    c.insert(
2959                        &invocation.model_name,
2960                        input_hashes[row_idx],
2961                        miss_calibrated[i],
2962                    );
2963                }
2964            }
2965
2966            // Phase C B1-B3 follow-up: when a provenance store is
2967            // configured, record (raw, calibrated, confidence_band)
2968            // per row. With C-RawCalibratedSeparation (above),
2969            // `row_raw[i]` carries the *pre-calibrator* value when we
2970            // computed it on this call; `probs[i]` is the
2971            // post-calibrator value. `confidence_band` comes from the
2972            // active Calibrator's `confidence_band(p)`.
2973            if let Some(store) = provenance_store {
2974                for row_idx in 0..n_rows {
2975                    let calibrated_value = probs[row_idx];
2976                    let (raw_value, calibrated) = match (row_raw[row_idx], &calibrator) {
2977                        (Some(raw), Some(_)) => (raw, Some(calibrated_value)),
2978                        (Some(raw), None) => (raw, None),
2979                        // Cache hit: we only have the calibrated value.
2980                        // Report it as raw with `calibrated == raw` to
2981                        // preserve telemetry shape; document this in
2982                        // the field doc.
2983                        (None, _) => (
2984                            calibrated_value,
2985                            calibrator.as_ref().map(|_| calibrated_value),
2986                        ),
2987                    };
2988                    let band = calibrator
2989                        .as_ref()
2990                        .and_then(|c| c.confidence_band(calibrated_value));
2991                    store.record(
2992                        &invocation.model_name,
2993                        input_hashes[row_idx],
2994                        uni_locy::NeuralProvenanceRecord {
2995                            raw_probability: raw_value,
2996                            calibrated_probability: calibrated,
2997                            confidence_band: band,
2998                            // Phase 12 EXPLAIN follow-up: stash the
2999                            // per-binding `FeatureValue` map that fed
3000                            // the classifier so Mode B re-evaluation
3001                            // can surface graph-structural feature
3002                            // values without re-precomputing topology
3003                            // or neighbor maps (the hot-path data is
3004                            // authoritative).
3005                            feature_inputs: inputs[row_idx].features.clone(),
3006                        },
3007                    );
3008                }
3009            }
3010
3011            // Overwrite the placeholder column put there by the
3012            // compile-time YIELD rewrite. If the column doesn't yet
3013            // exist (defensive — shouldn't happen for well-formed
3014            // plans), fall back to appending.
3015            let out_col: Arc<dyn arrow_array::Array> =
3016                Arc::new(arrow_array::Float64Array::from(probs));
3017            let schema = current.schema();
3018            let target_idx = schema.index_of(&invocation.output_column).ok();
3019            let mut columns: Vec<Arc<dyn arrow_array::Array>> = current.columns().to_vec();
3020            let mut fields: Vec<Arc<arrow_schema::Field>> =
3021                schema.fields().iter().cloned().collect();
3022            match target_idx {
3023                Some(idx) => {
3024                    columns[idx] = out_col;
3025                    // Force the field's data type to Float64 in case
3026                    // the placeholder was inferred at a wider type.
3027                    fields[idx] = Arc::new(arrow_schema::Field::new(
3028                        &invocation.output_column,
3029                        arrow_schema::DataType::Float64,
3030                        true,
3031                    ));
3032                }
3033                None => {
3034                    columns.push(out_col);
3035                    fields.push(Arc::new(arrow_schema::Field::new(
3036                        &invocation.output_column,
3037                        arrow_schema::DataType::Float64,
3038                        true,
3039                    )));
3040                }
3041            }
3042            let new_schema = Arc::new(arrow_schema::Schema::new(fields));
3043            current = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
3044        }
3045        out_batches.push(current);
3046    }
3047    Ok(out_batches)
3048}
3049
3050/// Extract a [`uni_locy::FeatureValue`] from a column at a given row.
3051/// Conservative cast set matching the property-graph value types Locy
3052/// currently exposes; unsupported types fall back to `Null`.
3053/// Per-feature evaluator built once from a clause's `feature_exprs`
3054/// and reused for every row in the batch. Supports plain column
3055/// reads (`Direct`) and the Phase D D1/D2 retrieval-backed UDFs
3056/// (`SimilarTo`, `SemanticMatch`).
3057struct FeatureResolver {
3058    binding_name: String,
3059    kind: FeatureResolverKind,
3060}
3061
3062enum FeatureResolverKind {
3063    Direct(usize),
3064    SimilarTo {
3065        left: FeatureValueSrc,
3066        right: FeatureValueSrc,
3067    },
3068    /// Phase D D3 runtime: pull `column` from the source rule's
3069    /// derived facts via a pre-built `vid → FeatureValue` lookup. The
3070    /// `subject_col` is the index of `<subject_var>._vid` in the body
3071    /// batch; the lookup runs once per row.
3072    PathContext {
3073        subject_col: usize,
3074        vid_to_value: Arc<HashMap<u64, uni_locy::FeatureValue>>,
3075    },
3076    /// Phase D D1 graph-structural: look up the subject's pre-computed
3077    /// topology score (degree/pagerank/closeness). `subject_col` indexes
3078    /// the row's `<var>._vid`; `vid_to_score` is the whole-graph
3079    /// procedure output built once per `apply_model_invocations` call.
3080    GraphAlgoScore {
3081        subject_col: usize,
3082        vid_to_score: Arc<HashMap<u64, f64>>,
3083    },
3084    /// Phase D D1 graph-structural: aggregate a numeric property over
3085    /// each subject's one-hop outgoing neighborhood along a named edge
3086    /// type. `vid_to_values` maps subject vid → the list of numeric
3087    /// neighbor property values collected at precompute time; the
3088    /// per-row resolver applies the configured `op` (avg/max/sum).
3089    NeighborAggregate {
3090        subject_col: usize,
3091        op: NeighborAgg,
3092        vid_to_values: Arc<HashMap<u64, Vec<f64>>>,
3093    },
3094}
3095
/// Reduction applied to the numeric property values collected from a
/// subject's neighbors.
#[derive(Clone, Copy, Debug)]
enum NeighborAgg {
    /// Arithmetic mean of the values.
    Avg,
    /// Largest of the values.
    Max,
    /// Total of the values.
    Sum,
}
3102
/// Which edges are followed when walking a subject's one-hop neighborhood.
///
/// Kept separate from `uni_store::storage::direction::Direction`, which it
/// mirrors, so the typecheck / planner layer carries no dependency on
/// uni-store.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum NeighborDirection {
    /// Corresponds to the store's `Outgoing` direction.
    Outgoing,
    /// Corresponds to the store's `Incoming` direction.
    Incoming,
    /// Covers both store directions.
    Both,
}
3112
3113impl NeighborDirection {
3114    fn store_directions(self) -> &'static [uni_store::storage::direction::Direction] {
3115        use uni_store::storage::direction::Direction;
3116        match self {
3117            NeighborDirection::Outgoing => &[Direction::Outgoing],
3118            NeighborDirection::Incoming => &[Direction::Incoming],
3119            NeighborDirection::Both => &[Direction::Outgoing, Direction::Incoming],
3120        }
3121    }
3122}
3123
3124impl NeighborAgg {
3125    fn from_fn_name(name: &str) -> Option<Self> {
3126        match name {
3127            "avg_neighbor" => Some(NeighborAgg::Avg),
3128            "max_neighbor" => Some(NeighborAgg::Max),
3129            "sum_neighbor" => Some(NeighborAgg::Sum),
3130            _ => None,
3131        }
3132    }
3133
3134    fn apply(self, values: &[f64]) -> Option<f64> {
3135        if values.is_empty() {
3136            return None;
3137        }
3138        match self {
3139            NeighborAgg::Avg => Some(values.iter().sum::<f64>() / values.len() as f64),
3140            NeighborAgg::Max => values.iter().copied().reduce(f64::max),
3141            NeighborAgg::Sum => Some(values.iter().sum()),
3142        }
3143    }
3144}
3145
3146/// One side of a `similar_to` feature: either a column index in the
3147/// per-row batch or a constant value lifted from a literal expression.
3148enum FeatureValueSrc {
3149    Col(usize),
3150    Const(uni_common::Value),
3151}
3152
3153impl FeatureValueSrc {
3154    fn resolve(&self, batch: &RecordBatch, row_idx: usize) -> uni_common::Value {
3155        match self {
3156            FeatureValueSrc::Col(idx) => extract_common_value(batch.column(*idx).as_ref(), row_idx),
3157            FeatureValueSrc::Const(v) => v.clone(),
3158        }
3159    }
3160}
3161
3162impl FeatureResolver {
3163    fn eval_row(&self, batch: &RecordBatch, row_idx: usize) -> DFResult<uni_locy::FeatureValue> {
3164        match &self.kind {
3165            FeatureResolverKind::Direct(idx) => {
3166                Ok(extract_feature_value(batch.column(*idx).as_ref(), row_idx))
3167            }
3168            FeatureResolverKind::SimilarTo { left, right } => {
3169                let lv = left.resolve(batch, row_idx);
3170                let rv = right.resolve(batch, row_idx);
3171                match crate::query::similar_to::eval_similar_to_pure(&lv, &rv) {
3172                    Ok(uni_common::Value::Float(f)) => Ok(uni_locy::FeatureValue::Float(f)),
3173                    Ok(_) => Ok(uni_locy::FeatureValue::Null),
3174                    Err(e) => Err(datafusion::error::DataFusionError::Execution(format!(
3175                        "similar_to UDF failed: {e}"
3176                    ))),
3177                }
3178            }
3179            FeatureResolverKind::PathContext {
3180                subject_col,
3181                vid_to_value,
3182            } => {
3183                let col = batch.column(*subject_col);
3184                if col.is_null(row_idx) {
3185                    return Ok(uni_locy::FeatureValue::Null);
3186                }
3187                if let Some(arr) = col.as_any().downcast_ref::<arrow_array::UInt64Array>() {
3188                    let vid = arr.value(row_idx);
3189                    Ok(vid_to_value
3190                        .get(&vid)
3191                        .cloned()
3192                        .unwrap_or(uni_locy::FeatureValue::Null))
3193                } else if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3194                    let vid = arr.value(row_idx) as u64;
3195                    Ok(vid_to_value
3196                        .get(&vid)
3197                        .cloned()
3198                        .unwrap_or(uni_locy::FeatureValue::Null))
3199                } else {
3200                    Ok(uni_locy::FeatureValue::Null)
3201                }
3202            }
3203            FeatureResolverKind::GraphAlgoScore {
3204                subject_col,
3205                vid_to_score,
3206            } => {
3207                let col = batch.column(*subject_col);
3208                if col.is_null(row_idx) {
3209                    return Ok(uni_locy::FeatureValue::Null);
3210                }
3211                let vid_opt: Option<u64> = if let Some(arr) =
3212                    col.as_any().downcast_ref::<arrow_array::UInt64Array>()
3213                {
3214                    Some(arr.value(row_idx))
3215                } else if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3216                    Some(arr.value(row_idx) as u64)
3217                } else {
3218                    // Fallback: subject column carries a Node-encoded
3219                    // `uni_common::Value` (LargeBinary via codec). Decode
3220                    // and pull the VID. This is the common case for
3221                    // bare-variable subjects where no `_vid` hidden
3222                    // column was materialized.
3223                    match extract_common_value(col.as_ref(), row_idx) {
3224                        uni_common::Value::Node(n) => Some(n.vid.as_u64()),
3225                        uni_common::Value::Int(i) => Some(i as u64),
3226                        _ => None,
3227                    }
3228                };
3229                Ok(vid_opt
3230                    .and_then(|v| vid_to_score.get(&v).copied())
3231                    .map(uni_locy::FeatureValue::Float)
3232                    .unwrap_or(uni_locy::FeatureValue::Null))
3233            }
3234            FeatureResolverKind::NeighborAggregate {
3235                subject_col,
3236                op,
3237                vid_to_values,
3238            } => {
3239                let vid_opt = extract_vid_from_column(batch.column(*subject_col).as_ref(), row_idx);
3240                Ok(vid_opt
3241                    .and_then(|v| vid_to_values.get(&v))
3242                    .and_then(|values| op.apply(values))
3243                    .map(uni_locy::FeatureValue::Float)
3244                    .unwrap_or(uni_locy::FeatureValue::Null))
3245            }
3246        }
3247    }
3248}
3249
3250/// Extract a node VID from a per-row batch column. Handles the three
3251/// common shapes: `_vid` UInt64 columns, Int64 columns, and Node-encoded
3252/// LargeBinary columns (the standard `uni_common::Value::Node` codec
3253/// representation for a bare variable column).
3254fn extract_vid_from_column(col: &dyn arrow_array::Array, row_idx: usize) -> Option<u64> {
3255    if col.is_null(row_idx) {
3256        return None;
3257    }
3258    if let Some(arr) = col.as_any().downcast_ref::<arrow_array::UInt64Array>() {
3259        return Some(arr.value(row_idx));
3260    }
3261    if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3262        return Some(arr.value(row_idx) as u64);
3263    }
3264    match extract_common_value(col, row_idx) {
3265        uni_common::Value::Node(n) => Some(n.vid.as_u64()),
3266        uni_common::Value::Int(i) => Some(i as u64),
3267        _ => None,
3268    }
3269}
3270
3271#[allow(clippy::too_many_arguments)]
3272fn build_feature_resolvers(
3273    batch: &RecordBatch,
3274    invocation: &uni_locy::ModelInvocation,
3275    path_context_handles: &HashMap<
3276        String,
3277        crate::query::df_graph::locy_model_invoke::PathContextHandle,
3278    >,
3279    semantic_match_embeddings: &HashMap<String, Vec<f32>>,
3280    graph_feature_maps: &HashMap<String, Arc<HashMap<u64, f64>>>,
3281    neighbor_feature_maps: &NeighborFeatureMaps,
3282) -> DFResult<Vec<FeatureResolver>> {
3283    use uni_cypher::ast::Expr;
3284    let schema = batch.schema();
3285    let lookup_col = |name_or_property: String| -> DFResult<usize> {
3286        schema.index_of(&name_or_property).map_err(|_| {
3287            datafusion::error::DataFusionError::Execution(format!(
3288                "feature column '{name_or_property}' not found in clause body output schema"
3289            ))
3290        })
3291    };
3292    // Resolve a feature sub-expression to a per-row value source. Variables
3293    // and property accesses map to batch columns; list/scalar literals
3294    // become inline constants — required so `similar_to(s.embedding, [1,0,0])`
3295    // works without a hidden column for the literal vector.
3296    let resolve_src = |expr: &Expr| -> DFResult<FeatureValueSrc> {
3297        match expr {
3298            Expr::Variable(name) => {
3299                let col = if schema.index_of(name).is_ok() {
3300                    name.clone()
3301                } else {
3302                    let vid_name = format!("{}._vid", name);
3303                    if schema.index_of(&vid_name).is_ok() {
3304                        vid_name
3305                    } else {
3306                        name.clone()
3307                    }
3308                };
3309                Ok(FeatureValueSrc::Col(lookup_col(col)?))
3310            }
3311            Expr::Property(boxed, prop) if matches!(boxed.as_ref(), Expr::Variable(_)) => {
3312                let Expr::Variable(v) = boxed.as_ref() else {
3313                    unreachable!()
3314                };
3315                let direct = format!("{}.{}", v, prop);
3316                let col = if schema.index_of(&direct).is_ok() {
3317                    direct
3318                } else {
3319                    format!("__feat_{}_{}", v, prop)
3320                };
3321                Ok(FeatureValueSrc::Col(lookup_col(col)?))
3322            }
3323            Expr::Literal(lit) => Ok(FeatureValueSrc::Const(lit.to_value())),
3324            Expr::List(items) => {
3325                let mut out = Vec::with_capacity(items.len());
3326                for it in items {
3327                    out.push(match it {
3328                        Expr::Literal(lit) => lit.to_value(),
3329                        _ => uni_common::Value::Null,
3330                    });
3331                }
3332                Ok(FeatureValueSrc::Const(uni_common::Value::List(out)))
3333            }
3334            other => Err(datafusion::error::DataFusionError::Execution(format!(
3335                "unsupported feature sub-expression: {other:?}"
3336            ))),
3337        }
3338    };
3339
3340    // Phase D D3 runtime: when the model declares a path-context
3341    // feature, build a `vid → FeatureValue` lookup once from the
3342    // source rule's converged facts and wrap it in an Arc so the
3343    // per-row resolver does a single hash lookup. The model's
3344    // `INPUT` bindings are unused under this form for MVP — the
3345    // resolver's binding name is the column name (matches how the
3346    // mock-classifier feature-driver pattern in TCK consumes it).
3347    if let Some(pc) = &invocation.path_context {
3348        let handle = path_context_handles.get(&pc.source_rule).ok_or_else(|| {
3349            datafusion::error::DataFusionError::Execution(format!(
3350                "model '{}' path_context references rule '{}' but no DerivedScanHandle \
3351                 was registered; this should never happen — the build_clause path \
3352                 mints a handle for every distinct source_rule in the invocation set",
3353                invocation.model_name, pc.source_rule
3354            ))
3355        })?;
3356        let subject_col = schema
3357            .index_of(&format!("{}._vid", pc.subject_var))
3358            .or_else(|_| schema.index_of(&pc.subject_var))
3359            .map_err(|_| {
3360                datafusion::error::DataFusionError::Execution(format!(
3361                    "model '{}' path_context: subject column '{}' (or '{0}._vid') not \
3362                     in body batch schema",
3363                    invocation.model_name, pc.subject_var
3364                ))
3365            })?;
3366        let vid_to_value =
3367            build_path_context_lookup(handle, &pc.subject_var, &pc.column, &invocation.model_name)?;
3368        return Ok(vec![FeatureResolver {
3369            binding_name: pc.column.clone(),
3370            kind: FeatureResolverKind::PathContext {
3371                subject_col,
3372                vid_to_value: Arc::new(vid_to_value),
3373            },
3374        }]);
3375    }
3376
3377    let mut out = Vec::with_capacity(invocation.feature_exprs.len());
3378    for (i, fexpr) in invocation.feature_exprs.iter().enumerate() {
3379        let binding_name = invocation.feature_names[i].clone();
3380        let kind = match fexpr {
3381            Expr::FunctionCall { name, args, .. } if name == "similar_to" => {
3382                if args.len() != 2 {
3383                    return Err(datafusion::error::DataFusionError::Execution(format!(
3384                        "similar_to expects 2 args, got {}",
3385                        args.len()
3386                    )));
3387                }
3388                FeatureResolverKind::SimilarTo {
3389                    left: resolve_src(&args[0])?,
3390                    right: resolve_src(&args[1])?,
3391                }
3392            }
3393            Expr::FunctionCall { name, args, .. } if name == "semantic_match" => {
3394                // Phase D D2: lower `semantic_match(prop, 'text')` to a
3395                // `SimilarTo` resolver with the pre-embedded query
3396                // vector as the right side. The literal text was embedded
3397                // once via the Xervo runtime in `pre_embed_semantic_match_queries`.
3398                if args.len() != 2 {
3399                    return Err(datafusion::error::DataFusionError::Execution(format!(
3400                        "semantic_match expects 2 args, got {}",
3401                        args.len()
3402                    )));
3403                }
3404                let text = match &args[1] {
3405                    Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3406                    other => {
3407                        return Err(datafusion::error::DataFusionError::Execution(format!(
3408                            "semantic_match: 2nd arg must be a string literal, got {other:?}"
3409                        )));
3410                    }
3411                };
3412                let embedded = semantic_match_embeddings.get(&text).ok_or_else(|| {
3413                    datafusion::error::DataFusionError::Execution(format!(
3414                        "semantic_match: query text '{text}' was not pre-embedded. \
3415                         This is a bug — `apply_model_invocations` should have \
3416                         embedded all unique semantic_match texts up front. Most \
3417                         likely the Xervo runtime is not configured (configure \
3418                         via `LocyConfig::xervo_runtime` or its equivalent)."
3419                    ))
3420                })?;
3421                let right_vec: Vec<f32> = embedded.clone();
3422                FeatureResolverKind::SimilarTo {
3423                    left: resolve_src(&args[0])?,
3424                    right: FeatureValueSrc::Const(uni_common::Value::Vector(right_vec)),
3425                }
3426            }
3427            Expr::FunctionCall { name, args, .. }
3428                if matches!(
3429                    name.as_str(),
3430                    "degree_centrality"
3431                        | "pagerank_score"
3432                        | "closeness_centrality"
3433                        | "betweenness_centrality"
3434                        | "eigenvector_centrality"
3435                        | "harmonic_centrality"
3436                        | "katz_centrality"
3437                ) =>
3438            {
3439                if args.len() != 1 {
3440                    return Err(datafusion::error::DataFusionError::Execution(format!(
3441                        "{name} expects 1 arg, got {}",
3442                        args.len()
3443                    )));
3444                }
3445                let Expr::Variable(v) = &args[0] else {
3446                    return Err(datafusion::error::DataFusionError::Execution(format!(
3447                        "{name}(...) argument must be a node variable, got {:?}",
3448                        args[0]
3449                    )));
3450                };
3451                let subject_col = {
3452                    let direct = schema.index_of(v).ok();
3453                    let vid_name = format!("{}._vid", v);
3454                    let vid_col = schema.index_of(&vid_name).ok();
3455                    vid_col.or(direct).ok_or_else(|| {
3456                        datafusion::error::DataFusionError::Execution(format!(
3457                            "{name}: subject column '{v}' (or '{v}._vid') not in body batch schema"
3458                        ))
3459                    })?
3460                };
3461                let vid_to_score = graph_feature_maps.get(name).cloned().ok_or_else(|| {
3462                    datafusion::error::DataFusionError::Execution(format!(
3463                        "{name}: pre-computed score map missing. This is a bug — \
3464                         `apply_model_invocations` should have called \
3465                         `precompute_graph_feature_maps` for every graph-structural \
3466                         feature before building resolvers. Most likely the graph \
3467                         algorithm registry is not configured."
3468                    ))
3469                })?;
3470                FeatureResolverKind::GraphAlgoScore {
3471                    subject_col,
3472                    vid_to_score,
3473                }
3474            }
3475            Expr::FunctionCall { name, args, .. }
3476                if matches!(
3477                    name.as_str(),
3478                    "avg_neighbor" | "max_neighbor" | "sum_neighbor"
3479                ) =>
3480            {
3481                if args.len() != 3 && args.len() != 4 {
3482                    return Err(datafusion::error::DataFusionError::Execution(format!(
3483                        "{name} expects 3 or 4 args, got {}",
3484                        args.len()
3485                    )));
3486                }
3487                let Expr::Variable(v) = &args[0] else {
3488                    return Err(datafusion::error::DataFusionError::Execution(format!(
3489                        "{name}(...) first argument must be a node variable, got {:?}",
3490                        args[0]
3491                    )));
3492                };
3493                let rel_type = match &args[1] {
3494                    Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3495                    other => {
3496                        return Err(datafusion::error::DataFusionError::Execution(format!(
3497                            "{name}: 2nd arg must be a string literal (rel-type), got {other:?}"
3498                        )));
3499                    }
3500                };
3501                let prop_name = match &args[2] {
3502                    Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3503                    other => {
3504                        return Err(datafusion::error::DataFusionError::Execution(format!(
3505                            "{name}: 3rd arg must be a string literal (property), got {other:?}"
3506                        )));
3507                    }
3508                };
3509                let direction_arg = match args.get(3) {
3510                    None => NeighborDirection::Outgoing,
3511                    Some(Expr::Literal(uni_cypher::ast::CypherLiteral::String(d))) => {
3512                        match d.to_uppercase().as_str() {
3513                            "OUTGOING" => NeighborDirection::Outgoing,
3514                            "INCOMING" => NeighborDirection::Incoming,
3515                            "BOTH" => NeighborDirection::Both,
3516                            other => {
3517                                return Err(datafusion::error::DataFusionError::Execution(
3518                                    format!(
3519                                        "{name}: direction must be OUTGOING|INCOMING|BOTH, got '{other}'"
3520                                    ),
3521                                ));
3522                            }
3523                        }
3524                    }
3525                    Some(other) => {
3526                        return Err(datafusion::error::DataFusionError::Execution(format!(
3527                            "{name}: 4th arg must be a string literal (direction), got {other:?}"
3528                        )));
3529                    }
3530                };
3531                let subject_col = {
3532                    let direct = schema.index_of(v).ok();
3533                    let vid_name = format!("{}._vid", v);
3534                    let vid_col = schema.index_of(&vid_name).ok();
3535                    vid_col.or(direct).ok_or_else(|| {
3536                        datafusion::error::DataFusionError::Execution(format!(
3537                            "{name}: subject column '{v}' (or '{v}._vid') not in body batch schema"
3538                        ))
3539                    })?
3540                };
3541                let vid_to_values = neighbor_feature_maps
3542                    .get(&(rel_type.clone(), prop_name.clone(), direction_arg))
3543                    .cloned()
3544                    .ok_or_else(|| {
3545                        datafusion::error::DataFusionError::Execution(format!(
3546                            "{name}: pre-computed neighbor map missing for ({rel_type}, {prop_name}, {direction_arg:?}). \
3547                             This is a bug — `apply_model_invocations` should have called \
3548                             `precompute_neighbor_feature_maps` for every neighbor-aggregator \
3549                             feature before building resolvers."
3550                        ))
3551                    })?;
3552                let op = NeighborAgg::from_fn_name(name).unwrap();
3553                FeatureResolverKind::NeighborAggregate {
3554                    subject_col,
3555                    op,
3556                    vid_to_values,
3557                }
3558            }
3559            other => match resolve_src(other)? {
3560                FeatureValueSrc::Col(idx) => FeatureResolverKind::Direct(idx),
3561                FeatureValueSrc::Const(_) => {
3562                    return Err(datafusion::error::DataFusionError::Execution(format!(
3563                        "model '{}' feature must reference a variable or property — got a literal",
3564                        invocation.model_name
3565                    )));
3566                }
3567            },
3568        };
3569        out.push(FeatureResolver { binding_name, kind });
3570    }
3571    Ok(out)
3572}
3573
3574/// Phase D D2: scan invocations' feature expressions for
3575/// `semantic_match(prop, 'text')` calls and embed each distinct
3576/// query string once via the Xervo runtime. Returns a
3577/// `text → Vec<f32>` map consumed at resolver-build time. Errors
3578/// cleanly when `semantic_match` is used without a configured
3579/// Xervo runtime.
3580async fn pre_embed_semantic_match_queries(
3581    invocations: &[uni_locy::ModelInvocation],
3582    xervo_runtime: &crate::query::df_graph::locy_model_invoke::XervoRuntimeHandle,
3583) -> DFResult<HashMap<String, Vec<f32>>> {
3584    use uni_cypher::ast::{CypherLiteral, Expr};
3585    // Collect (text, embedder_alias) pairs. The alias is per-model
3586    // (Phase D D2 follow-up): each invocation's `embedder_alias` overrides
3587    // the runtime-wide `"default"`. Two invocations sharing the same
3588    // (text, alias) reuse one embed call; same text under different
3589    // aliases is embedded twice. The cache key remains plain `text` —
3590    // a model's resolver fetches embeddings via the text it knows, and
3591    // mixed-alias re-embed of identical text under a different alias is
3592    // a rare-enough edge case that the "last writer wins" cache shape
3593    // is acceptable (documented).
3594    let mut needed: Vec<(String, String)> = Vec::new();
3595    for inv in invocations {
3596        let alias = inv
3597            .embedder_alias
3598            .clone()
3599            .unwrap_or_else(|| "default".to_string());
3600        for fexpr in &inv.feature_exprs {
3601            if let Expr::FunctionCall { name, args, .. } = fexpr
3602                && name == "semantic_match"
3603                && args.len() == 2
3604                && let Expr::Literal(CypherLiteral::String(s)) = &args[1]
3605            {
3606                let tuple = (s.clone(), alias.clone());
3607                if !needed.contains(&tuple) {
3608                    needed.push(tuple);
3609                }
3610            }
3611        }
3612    }
3613    if needed.is_empty() {
3614        return Ok(HashMap::new());
3615    }
3616    let runtime = xervo_runtime.as_ref().ok_or_else(|| {
3617        datafusion::error::DataFusionError::Execution(
3618            "semantic_match: Uni-Xervo runtime not configured. Either provide \
3619             one via `LocyConfig::xervo_runtime` (or its equivalent setup \
3620             path) or pre-compute the query embedding and pass it via \
3621             `similar_to(prop, <literal_vector>)`."
3622                .to_string(),
3623        )
3624    })?;
3625    // Group needed (text, alias) by alias so each embedder is consulted
3626    // exactly once.
3627    let mut by_alias: HashMap<String, Vec<String>> = HashMap::new();
3628    for (text, alias) in &needed {
3629        by_alias
3630            .entry(alias.clone())
3631            .or_default()
3632            .push(text.clone());
3633    }
3634    let mut out: HashMap<String, Vec<f32>> = HashMap::new();
3635    for (alias, texts) in by_alias {
3636        let embedder = runtime.embedding(&alias).await.map_err(|e| {
3637            datafusion::error::DataFusionError::Execution(format!(
3638                "semantic_match: failed to obtain embedder for alias '{alias}': {e}. \
3639                 Register an embedder under that alias in your Uni-Xervo runtime, or \
3640                 pre-compute the query embedding and pass via similar_to."
3641            ))
3642        })?;
3643        let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
3644        let embeddings = embedder.embed(text_refs).await.map_err(|e| {
3645            datafusion::error::DataFusionError::Execution(format!(
3646                "semantic_match: embedder '{alias}' call failed: {e}"
3647            ))
3648        })?;
3649        if embeddings.len() != texts.len() {
3650            return Err(datafusion::error::DataFusionError::Execution(format!(
3651                "semantic_match: embedder '{alias}' returned {} vectors for {} queries",
3652                embeddings.len(),
3653                texts.len()
3654            )));
3655        }
3656        for (text, vec) in texts.into_iter().zip(embeddings) {
3657            out.insert(text, vec);
3658        }
3659    }
3660    Ok(out)
3661}
3662
3663/// Phase D D1 graph-structural: scan invocations' feature expressions
3664/// for `degree_centrality(n)` / `pagerank_score(n)` / `closeness_centrality(n)`
3665/// calls and invoke the corresponding `uni.algo.*` procedure on the
3666/// configured `AlgorithmRegistry` once per distinct call. Returns a
3667/// `fn_name → Arc<HashMap<vid, score>>` map consumed at resolver-build
3668/// time. Errors cleanly when a graph-structural FEATURE is used
3669/// without a configured registry or storage handle.
3670///
3671/// Pre-computation is `O(graph)` per call. Across fixpoint iterations
3672/// the graph state can change, so the cache lives for the lifetime of
3673/// one `apply_model_invocations` call only — same lifetime as the
3674/// D2 query-embedding cache (`pre_embed_semantic_match_queries`).
3675async fn precompute_graph_feature_maps(
3676    invocations: &[uni_locy::ModelInvocation],
3677    graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
3678) -> DFResult<HashMap<String, Arc<HashMap<u64, f64>>>> {
3679    use futures::StreamExt;
3680    use uni_algo::algo::procedures::AlgoContext;
3681    use uni_cypher::ast::Expr;
3682
3683    // Map our user-facing FEATURE function names to the canonical
3684    // `uni.algo.*` procedure names registered in `AlgorithmRegistry`.
3685    fn procedure_for(fn_name: &str) -> Option<&'static str> {
3686        match fn_name {
3687            "degree_centrality" => Some("uni.algo.degreeCentrality"),
3688            "pagerank_score" => Some("uni.algo.pageRank"),
3689            "closeness_centrality" => Some("uni.algo.closeness"),
3690            "betweenness_centrality" => Some("uni.algo.betweenness"),
3691            "eigenvector_centrality" => Some("uni.algo.eigenvectorCentrality"),
3692            "harmonic_centrality" => Some("uni.algo.harmonicCentrality"),
3693            "katz_centrality" => Some("uni.algo.katzCentrality"),
3694            _ => None,
3695        }
3696    }
3697
3698    // Collect the set of distinct topology-FEATURE names referenced
3699    // across all invocations. Args are always a single Variable, so
3700    // the precomputation key is just the function name.
3701    let mut needed: Vec<String> = Vec::new();
3702    for inv in invocations {
3703        for fexpr in &inv.feature_exprs {
3704            if let Expr::FunctionCall { name, .. } = fexpr
3705                && procedure_for(name).is_some()
3706                && !needed.contains(name)
3707            {
3708                needed.push(name.clone());
3709            }
3710        }
3711    }
3712    if needed.is_empty() {
3713        return Ok(HashMap::new());
3714    }
3715
3716    let registry = graph_algo.registry.as_ref().ok_or_else(|| {
3717        datafusion::error::DataFusionError::Execution(
3718            "graph-structural FEATURE invoked but no `AlgorithmRegistry` is \
3719             configured. Configure one on `GraphExecutionContext::with_algo_registry`."
3720                .to_string(),
3721        )
3722    })?;
3723    let storage = graph_algo.storage.as_ref().ok_or_else(|| {
3724        datafusion::error::DataFusionError::Execution(
3725            "graph-structural FEATURE invoked but no storage handle was \
3726             threaded into the FEATURE runtime. This is a bug in df_planner."
3727                .to_string(),
3728        )
3729    })?;
3730
3731    let mut out: HashMap<String, Arc<HashMap<u64, f64>>> = HashMap::new();
3732    for fn_name in needed {
3733        let proc_name = procedure_for(&fn_name).unwrap();
3734        let procedure = registry.get(proc_name).ok_or_else(|| {
3735            datafusion::error::DataFusionError::Execution(format!(
3736                "graph-structural FEATURE '{fn_name}' resolves to procedure \
3737                 '{proc_name}' which is not in the algorithm registry"
3738            ))
3739        })?;
3740        // Topology procedures take (nodeLabels[], relationshipTypes[],
3741        // [direction], [...]) — pass empty arrays for nodeLabels and
3742        // relationshipTypes to mean "all". The procedure fills the
3743        // remaining optional args from its signature defaults.
3744        let args: Vec<serde_json::Value> = vec![
3745            serde_json::Value::Array(Vec::new()),
3746            serde_json::Value::Array(Vec::new()),
3747        ];
3748        let algo_ctx = AlgoContext::new(
3749            storage.clone(),
3750            graph_algo.l0_manager.as_ref().map(Arc::clone),
3751        );
3752        // The AlgoProcedure trait routes direct (nodeLabels, edgeTypes)
3753        // args through the V2 projection entry point: build a projection
3754        // from the direct args, then execute against it.
3755        //
3756        // Fill optional algorithm-specific args (e.g. degree_centrality's
3757        // `direction`, eigenvector/katz `weightProperty`) with their schema
3758        // defaults for the projection build: `build_projection_from_direct_args`
3759        // feeds the specific args (`args[2..]`) to the adapter's
3760        // `customize_projection`, which indexes them positionally — the two
3761        // empty placeholder arrays alone would leave that slice empty and
3762        // panic. `validate_args` fills missing optionals WITHOUT type-checking
3763        // the defaults (some are `Null`-typed sentinels), so this never errors
3764        // for the placeholder shape.
3765        //
3766        // We pass the ORIGINAL `args` (not the filled ones) to
3767        // `execute_with_projection`, which re-runs `validate_args` internally:
3768        // re-feeding already-filled args would make those defaults look
3769        // "provided" and trip the type-check (`weightProperty: Null` vs
3770        // `String`). This mirrors the (now-removed) legacy
3771        // `AlgoProcedure::execute`, which validated once then built + ran.
3772        let filled_args = procedure
3773            .signature()
3774            .validate_args(args.clone())
3775            .map_err(|e| {
3776                datafusion::error::DataFusionError::Execution(format!(
3777                    "graph-structural FEATURE '{fn_name}': argument validation failed: {e}"
3778                ))
3779            })?;
3780        let projection = uni_algo::algo::procedure_template::build_projection_from_direct_args(
3781            procedure.as_ref(),
3782            &algo_ctx,
3783            &filled_args,
3784        )
3785        .await
3786        .map_err(|e| {
3787            datafusion::error::DataFusionError::Execution(format!(
3788                "graph-structural FEATURE '{fn_name}': projection build failed: {e}"
3789            ))
3790        })?;
3791        let mut stream = procedure.execute_with_projection(algo_ctx, args, projection);
3792        let mut score_map: HashMap<u64, f64> = HashMap::new();
3793        let sig = procedure.signature();
3794        let node_idx = sig
3795            .yields
3796            .iter()
3797            .position(|(n, _)| *n == "nodeId")
3798            .ok_or_else(|| {
3799                datafusion::error::DataFusionError::Execution(format!(
3800                    "procedure '{proc_name}' yield schema missing 'nodeId'"
3801                ))
3802            })?;
3803        // Most `uni.algo.*` centrality procedures yield `score`; the
3804        // `harmonicCentrality` family yields `centrality` instead. Accept
3805        // either to keep this dispatch independent of procedure-internal
3806        // naming choices.
3807        let score_idx = sig
3808            .yields
3809            .iter()
3810            .position(|(n, _)| *n == "score" || *n == "centrality")
3811            .ok_or_else(|| {
3812                datafusion::error::DataFusionError::Execution(format!(
3813                    "procedure '{proc_name}' yield schema missing a numeric score column \
3814                     (expected 'score' or 'centrality')"
3815                ))
3816            })?;
3817        while let Some(row_res) = stream.next().await {
3818            let row = row_res.map_err(|e| {
3819                datafusion::error::DataFusionError::Execution(format!(
3820                    "graph-structural FEATURE '{fn_name}': procedure '{proc_name}' failed: {e}"
3821                ))
3822            })?;
3823            let vid_v = row.values.get(node_idx);
3824            let score_v = row.values.get(score_idx);
3825            let (Some(vid_v), Some(score_v)) = (vid_v, score_v) else {
3826                continue;
3827            };
3828            let vid = vid_v.as_u64().or_else(|| vid_v.as_i64().map(|i| i as u64));
3829            let score = score_v
3830                .as_f64()
3831                .or_else(|| score_v.as_i64().map(|i| i as f64));
3832            if let (Some(vid), Some(score)) = (vid, score) {
3833                score_map.insert(vid, score);
3834            }
3835        }
3836        out.insert(fn_name, Arc::new(score_map));
3837    }
3838    Ok(out)
3839}
3840
/// Phase D D1 graph-structural: cache of one-hop neighbor property values
/// for the `avg_neighbor` / `max_neighbor` / `sum_neighbor` FEATURE
/// functions, produced by `precompute_neighbor_feature_maps` and consumed by
/// `FeatureResolverKind::NeighborAggregate` resolvers.
///
/// The outer map is keyed by `(rel_type, prop_name, direction)`. Each inner
/// map goes from a subject vid to the `f64` values of `prop_name` read from
/// that subject's neighbors over `rel_type` in the given direction(s).
/// Neighbor values that do not convert via `as_f64`, or that are NaN, are
/// left out.
///
/// Scope is **subject-set-only**: inner maps hold entries only for vids
/// found in the body batches' subject columns, so the whole graph is never
/// pre-walked. A subject with no numeric neighbor values is still present,
/// mapped to an empty `Vec`. A rel-type that is not registered in the
/// schema yields an empty inner map.
type NeighborFeatureMaps =
    HashMap<(String, String, NeighborDirection), Arc<HashMap<u64, Vec<f64>>>>;
3868
3869async fn precompute_neighbor_feature_maps(
3870    invocations: &[uni_locy::ModelInvocation],
3871    batches: &[RecordBatch],
3872    graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
3873) -> DFResult<NeighborFeatureMaps> {
3874    use uni_cypher::ast::{CypherLiteral, Expr};
3875
3876    // Collect distinct (subject_var, rel_type, prop_name, direction)
3877    // tuples needed across all invocations. The subject_var tells us
3878    // which body batch column to scan for subject vids; the direction
3879    // is optional in the AST (defaults to OUTGOING).
3880    let parse_direction = |arg: Option<&Expr>| -> Option<NeighborDirection> {
3881        match arg {
3882            None => Some(NeighborDirection::Outgoing),
3883            Some(Expr::Literal(CypherLiteral::String(d))) => match d.to_uppercase().as_str() {
3884                "OUTGOING" => Some(NeighborDirection::Outgoing),
3885                "INCOMING" => Some(NeighborDirection::Incoming),
3886                "BOTH" => Some(NeighborDirection::Both),
3887                _ => None,
3888            },
3889            _ => None,
3890        }
3891    };
3892    let mut needed: Vec<(String, String, String, NeighborDirection)> = Vec::new();
3893    for inv in invocations {
3894        for fexpr in &inv.feature_exprs {
3895            if let Expr::FunctionCall { name, args, .. } = fexpr
3896                && NeighborAgg::from_fn_name(name).is_some()
3897                && (args.len() == 3 || args.len() == 4)
3898                && let Expr::Variable(v) = &args[0]
3899                && let Expr::Literal(CypherLiteral::String(rel)) = &args[1]
3900                && let Expr::Literal(CypherLiteral::String(prop)) = &args[2]
3901                && let Some(direction) = parse_direction(args.get(3))
3902            {
3903                let tuple = (v.clone(), rel.clone(), prop.clone(), direction);
3904                if !needed.contains(&tuple) {
3905                    needed.push(tuple);
3906                }
3907            }
3908        }
3909    }
3910    if needed.is_empty() {
3911        return Ok(HashMap::new());
3912    }
3913
3914    let storage = graph_algo.storage.as_ref().ok_or_else(|| {
3915        datafusion::error::DataFusionError::Execution(
3916            "neighbor-aggregator FEATURE invoked but no storage handle was \
3917             threaded into the FEATURE runtime. This is a bug in df_planner."
3918                .to_string(),
3919        )
3920    })?;
3921    let property_manager = graph_algo.property_manager.as_ref().ok_or_else(|| {
3922        datafusion::error::DataFusionError::Execution(
3923            "neighbor-aggregator FEATURE invoked but no PropertyManager was \
3924             threaded into the FEATURE runtime. This is a bug in df_planner."
3925                .to_string(),
3926        )
3927    })?;
3928    // Build a QueryContext snapshot so L0-resident vertex properties
3929    // are visible to `get_vertex_prop_with_ctx`. Without a ctx, L0
3930    // property data is silently invisible (returns Null), which is
3931    // why the topology trio's `AlgoContext` consumes L0 via
3932    // `L0Manager` whereas property reads need this separate path.
3933    let query_ctx = graph_algo.l0_buffers.as_ref().map(|bufs| {
3934        uni_store::runtime::context::QueryContext::new_with_pending(
3935            bufs.current.clone(),
3936            bufs.transaction.clone(),
3937            bufs.pending_flush.clone(),
3938        )
3939    });
3940
3941    // Group needed tuples by (rel_type, prop_name, direction) — one
3942    // precomputed map per key, regardless of which subject_var binding
3943    // points at it (the subject vids are unioned).
3944    let mut by_key: HashMap<(String, String, NeighborDirection), Vec<String>> = HashMap::new();
3945    for (subject_var, rel, prop, direction) in needed {
3946        by_key
3947            .entry((rel, prop, direction))
3948            .or_default()
3949            .push(subject_var);
3950    }
3951
3952    let mut out: NeighborFeatureMaps = HashMap::new();
3953    for ((rel_type, prop_name, direction), subject_vars) in by_key {
3954        // Resolve edge_type_id from schema.
3955        let schema = storage.schema_manager().schema();
3956        let Some(edge_meta) = schema.edge_types.get(&rel_type) else {
3957            // Unregistered rel-type → empty map. The resolver surfaces
3958            // Null at row time, consistent with the no-neighbor case.
3959            out.insert((rel_type, prop_name, direction), Arc::new(HashMap::new()));
3960            continue;
3961        };
3962        let edge_type_id = edge_meta.id;
3963
3964        // Warm adjacency for every direction we'll traverse. Mirrors
3965        // the pattern in projection.rs / procedure_template.rs.
3966        let edge_ver = storage.get_edge_version_by_id(edge_type_id);
3967        for dir in direction.store_directions() {
3968            storage
3969                .warm_adjacency(edge_type_id, *dir, edge_ver)
3970                .await
3971                .map_err(|e| {
3972                    datafusion::error::DataFusionError::Execution(format!(
3973                        "neighbor-aggregator warm_adjacency for '{rel_type}' / {dir:?} failed: {e}"
3974                    ))
3975                })?;
3976        }
3977
3978        // Collect distinct subject vids from body batches across every
3979        // subject_var binding that this (rel, prop) pair uses.
3980        let mut subject_vids: std::collections::HashSet<u64> = std::collections::HashSet::new();
3981        for subject_var in &subject_vars {
3982            for batch in batches {
3983                let schema = batch.schema();
3984                let col_idx = schema
3985                    .index_of(&format!("{}._vid", subject_var))
3986                    .ok()
3987                    .or_else(|| schema.index_of(subject_var).ok());
3988                let Some(col_idx) = col_idx else { continue };
3989                let col = batch.column(col_idx);
3990                for row in 0..batch.num_rows() {
3991                    if let Some(v) = extract_vid_from_column(col.as_ref(), row) {
3992                        subject_vids.insert(v);
3993                    }
3994                }
3995            }
3996        }
3997
3998        // For each subject, walk edges in the configured direction(s),
3999        // fetch neighbor property, coerce to f64, accumulate. Subjects
4000        // with no numeric neighbors retain an empty Vec (→ Null at
4001        // row time).
4002        let mut vid_to_values: HashMap<u64, Vec<f64>> = HashMap::new();
4003        let adj = storage.adjacency_manager();
4004        for subject_vid in subject_vids {
4005            let mut neighbors: Vec<(uni_common::core::id::Vid, uni_common::core::id::Eid)> =
4006                Vec::new();
4007            for dir in direction.store_directions() {
4008                neighbors.extend(adj.get_neighbors(
4009                    uni_common::core::id::Vid::from(subject_vid),
4010                    edge_type_id,
4011                    *dir,
4012                ));
4013            }
4014            let mut values: Vec<f64> = Vec::with_capacity(neighbors.len());
4015            for (neighbor_vid, _eid) in neighbors {
4016                let val = property_manager
4017                    .get_vertex_prop_with_ctx(neighbor_vid, &prop_name, query_ctx.as_ref())
4018                    .await
4019                    .map_err(|e| {
4020                        datafusion::error::DataFusionError::Execution(format!(
4021                            "neighbor-aggregator: failed to read property \
4022                             '{prop_name}' on neighbor vid {neighbor_vid:?}: {e}"
4023                        ))
4024                    })?;
4025                if let Some(f) = val.as_f64()
4026                    && !f.is_nan()
4027                {
4028                    values.push(f);
4029                }
4030            }
4031            vid_to_values.insert(subject_vid, values);
4032        }
4033        out.insert((rel_type, prop_name, direction), Arc::new(vid_to_values));
4034    }
4035    Ok(out)
4036}
4037
4038/// Phase D D3: walk the source rule's converged batches and build
4039/// a `vid → FeatureValue` lookup for the named column. The subject
4040/// column in the derived rule's schema holds VIDs (UInt64) for node
4041/// variables; the value column type follows the rule's yield-schema
4042/// inference (typically Float64 / Int64 / Bool / String).
4043fn build_path_context_lookup(
4044    handle: &crate::query::df_graph::locy_model_invoke::PathContextHandle,
4045    _subject_var: &str,
4046    column: &str,
4047    model_name: &str,
4048) -> DFResult<HashMap<u64, uni_locy::FeatureValue>> {
4049    // The source rule's KEY column is its first yield column by
4050    // convention (`infer_yield_schema` orders KEYs first). The model's
4051    // local `subject_var` is just a binding alias — the actual join
4052    // matches the body row's VID against this canonical column.
4053    if handle.schema.fields().is_empty() {
4054        return Err(datafusion::error::DataFusionError::Execution(format!(
4055            "model '{model_name}' path_context: source rule has empty yield schema"
4056        )));
4057    }
4058    let subj_idx = 0_usize;
4059    let col_idx = handle.schema.index_of(column).map_err(|_| {
4060        datafusion::error::DataFusionError::Execution(format!(
4061            "model '{model_name}' path_context: column '{column}' not in \
4062             source rule's yield schema (have: {:?})",
4063            handle
4064                .schema
4065                .fields()
4066                .iter()
4067                .map(|f| f.name().clone())
4068                .collect::<Vec<_>>()
4069        ))
4070    })?;
4071    let batches = handle.data.read();
4072    let mut out: HashMap<u64, uni_locy::FeatureValue> = HashMap::new();
4073    for batch in batches.iter() {
4074        let subj_col = batch.column(subj_idx);
4075        let value_col = batch.column(col_idx);
4076        for row in 0..batch.num_rows() {
4077            if subj_col.is_null(row) {
4078                continue;
4079            }
4080            let vid = if let Some(a) = subj_col.as_any().downcast_ref::<arrow_array::UInt64Array>()
4081            {
4082                a.value(row)
4083            } else if let Some(a) = subj_col.as_any().downcast_ref::<arrow_array::Int64Array>() {
4084                a.value(row) as u64
4085            } else {
4086                continue;
4087            };
4088            let v = extract_feature_value(value_col.as_ref(), row);
4089            // Last write wins on duplicates; derived rules typically have
4090            // unique KEY values, so this is a defensive guard.
4091            out.insert(vid, v);
4092        }
4093    }
4094    Ok(out)
4095}
4096
4097/// Extract a `uni_common::Value` from one row of an Arrow column.
4098/// Used by the Phase D `similar_to` feature resolver, which needs
4099/// the raw `Value` (especially `Value::Vector(Vec<f32>)`) to feed
4100/// `eval_similar_to_pure`.
4101fn extract_common_value(col: &dyn arrow_array::Array, row_idx: usize) -> uni_common::Value {
4102    use arrow_array::{BooleanArray, Float64Array, Int64Array, LargeStringArray, StringArray};
4103    if col.is_null(row_idx) {
4104        return uni_common::Value::Null;
4105    }
4106    if let Some(a) = col.as_any().downcast_ref::<Float64Array>() {
4107        return uni_common::Value::Float(a.value(row_idx));
4108    }
4109    if let Some(a) = col.as_any().downcast_ref::<Int64Array>() {
4110        return uni_common::Value::Int(a.value(row_idx));
4111    }
4112    if let Some(a) = col.as_any().downcast_ref::<BooleanArray>() {
4113        return uni_common::Value::Bool(a.value(row_idx));
4114    }
4115    if let Some(a) = col.as_any().downcast_ref::<StringArray>() {
4116        return uni_common::Value::String(a.value(row_idx).to_string());
4117    }
4118    if let Some(a) = col.as_any().downcast_ref::<LargeStringArray>() {
4119        return uni_common::Value::String(a.value(row_idx).to_string());
4120    }
4121    if let Some(b) = col.as_any().downcast_ref::<arrow_array::LargeBinaryArray>() {
4122        let bytes = b.value(row_idx);
4123        if bytes.is_empty() {
4124            return uni_common::Value::Null;
4125        }
4126        return uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null);
4127    }
4128    uni_common::Value::Null
4129}
4130
4131fn extract_feature_value(col: &dyn arrow_array::Array, row_idx: usize) -> uni_locy::FeatureValue {
4132    use arrow_array::{BooleanArray, Float64Array, Int64Array, LargeStringArray, StringArray};
4133    if col.is_null(row_idx) {
4134        return uni_locy::FeatureValue::Null;
4135    }
4136    if let Some(a) = col.as_any().downcast_ref::<Float64Array>() {
4137        return uni_locy::FeatureValue::Float(a.value(row_idx));
4138    }
4139    if let Some(a) = col.as_any().downcast_ref::<Int64Array>() {
4140        return uni_locy::FeatureValue::Int(a.value(row_idx));
4141    }
4142    if let Some(a) = col.as_any().downcast_ref::<BooleanArray>() {
4143        return uni_locy::FeatureValue::Bool(a.value(row_idx));
4144    }
4145    if let Some(a) = col.as_any().downcast_ref::<StringArray>() {
4146        return uni_locy::FeatureValue::String(a.value(row_idx).to_string());
4147    }
4148    if let Some(a) = col.as_any().downcast_ref::<LargeStringArray>() {
4149        return uni_locy::FeatureValue::String(a.value(row_idx).to_string());
4150    }
4151    // Schema-less property storage: values arrive as LargeBinary
4152    // MessagePack-encoded `CypherValue`. Decode via the standard codec
4153    // and project the result to the matching `FeatureValue` variant.
4154    if let Some(b) = col.as_any().downcast_ref::<arrow_array::LargeBinaryArray>() {
4155        let bytes = b.value(row_idx);
4156        if bytes.is_empty() {
4157            return uni_locy::FeatureValue::Null;
4158        }
4159        let v = uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null);
4160        return match v {
4161            uni_common::Value::Float(f) => uni_locy::FeatureValue::Float(f),
4162            uni_common::Value::Int(i) => uni_locy::FeatureValue::Int(i),
4163            uni_common::Value::Bool(b) => uni_locy::FeatureValue::Bool(b),
4164            uni_common::Value::String(s) => uni_locy::FeatureValue::String(s),
4165            uni_common::Value::Null => uni_locy::FeatureValue::Null,
4166            _ => uni_locy::FeatureValue::Null,
4167        };
4168    }
4169    uni_locy::FeatureValue::Null
4170}
4171
4172/// Probabilistic complement for negated IS-refs targeting PROB rules.
4173///
4174/// Instead of filtering out matching VIDs (anti-join), this adds a complement
4175/// column `__prob_complement_{rule_name}` with value `1 - p` for each matching
4176/// VID, and `1.0` for VIDs not present in the negated rule's facts. Implements
4177/// `IS NOT risk` on a PROB rule: the probability that the entity is NOT risky.
4178pub fn apply_prob_complement(
4179    batches: Vec<RecordBatch>,
4180    neg_facts: &[RecordBatch],
4181    left_col: &str,
4182    right_col: &str,
4183    prob_col: &str,
4184    complement_col_name: &str,
4185) -> datafusion::error::Result<Vec<RecordBatch>> {
4186    use arrow_array::{Array as _, Float64Array, UInt64Array};
4187
4188    // Build VID → probability lookup from negative facts
4189    let mut prob_map: std::collections::HashMap<u64, f64> = std::collections::HashMap::new();
4190    for batch in neg_facts {
4191        let Ok(vid_idx) = batch.schema().index_of(right_col) else {
4192            continue;
4193        };
4194        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
4195            continue;
4196        };
4197        let Some(vids) = batch.column(vid_idx).as_any().downcast_ref::<UInt64Array>() else {
4198            continue;
4199        };
4200        let prob_arr = batch.column(prob_idx);
4201        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
4202        for i in 0..vids.len() {
4203            if !vids.is_null(i) {
4204                let p = probs
4205                    .and_then(|arr| {
4206                        if arr.is_null(i) {
4207                            None
4208                        } else {
4209                            Some(arr.value(i))
4210                        }
4211                    })
4212                    .unwrap_or(0.0);
4213                // If multiple facts for same VID, use noisy-OR combination:
4214                // combined = 1 - (1 - existing) * (1 - new)
4215                prob_map
4216                    .entry(vids.value(i))
4217                    .and_modify(|existing| {
4218                        *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
4219                    })
4220                    .or_insert(p);
4221            }
4222        }
4223    }
4224
4225    // Add complement column to each batch
4226    let mut result = Vec::new();
4227    for batch in batches {
4228        let Ok(idx) = batch.schema().index_of(left_col) else {
4229            result.push(batch);
4230            continue;
4231        };
4232        let Some(vids) = batch.column(idx).as_any().downcast_ref::<UInt64Array>() else {
4233            result.push(batch);
4234            continue;
4235        };
4236
4237        // Compute complement values: 1 - p for matched VIDs, 1.0 for absent
4238        let complements: Vec<f64> = (0..vids.len())
4239            .map(|i| {
4240                if vids.is_null(i) {
4241                    1.0
4242                } else {
4243                    let p = prob_map.get(&vids.value(i)).copied().unwrap_or(0.0);
4244                    1.0 - p
4245                }
4246            })
4247            .collect();
4248
4249        let complement_arr = Float64Array::from(complements);
4250
4251        // Add the complement column to the batch
4252        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
4253        columns.push(std::sync::Arc::new(complement_arr));
4254
4255        let mut fields: Vec<std::sync::Arc<arrow_schema::Field>> =
4256            batch.schema().fields().iter().cloned().collect();
4257        fields.push(std::sync::Arc::new(arrow_schema::Field::new(
4258            complement_col_name,
4259            arrow_schema::DataType::Float64,
4260            true,
4261        )));
4262
4263        let new_schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4264        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
4265        result.push(new_batch);
4266    }
4267    Ok(result)
4268}
4269
4270/// Probabilistic complement for composite (multi-column) join keys.
4271///
4272/// Builds a composite key from all `join_cols` right-side columns in
4273/// `neg_facts`, maps each composite key to a probability via noisy-OR
4274/// combination, then adds a single `complement_col_name` column with
4275/// `1 - p` for matched keys and `1.0` for absent keys.
4276pub fn apply_prob_complement_composite(
4277    batches: Vec<RecordBatch>,
4278    neg_facts: &[RecordBatch],
4279    join_cols: &[(String, String)],
4280    prob_col: &str,
4281    complement_col_name: &str,
4282) -> datafusion::error::Result<Vec<RecordBatch>> {
4283    use arrow_array::{Array as _, Float64Array, UInt64Array};
4284
4285    // Build composite-key → probability lookup from negative facts.
4286    let mut prob_map: HashMap<Vec<u64>, f64> = HashMap::new();
4287    for batch in neg_facts {
4288        let right_indices: Vec<usize> = join_cols
4289            .iter()
4290            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
4291            .collect();
4292        if right_indices.len() != join_cols.len() {
4293            continue;
4294        }
4295        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
4296            continue;
4297        };
4298        let prob_arr = batch.column(prob_idx);
4299        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
4300        for row in 0..batch.num_rows() {
4301            let mut key = Vec::with_capacity(right_indices.len());
4302            let mut valid = true;
4303            for &ci in &right_indices {
4304                let col = batch.column(ci);
4305                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
4306                    if vids.is_null(row) {
4307                        valid = false;
4308                        break;
4309                    }
4310                    key.push(vids.value(row));
4311                } else {
4312                    valid = false;
4313                    break;
4314                }
4315            }
4316            if !valid {
4317                continue;
4318            }
4319            let p = probs
4320                .and_then(|arr| {
4321                    if arr.is_null(row) {
4322                        None
4323                    } else {
4324                        Some(arr.value(row))
4325                    }
4326                })
4327                .unwrap_or(0.0);
4328            // Noisy-OR combination for duplicate composite keys.
4329            prob_map
4330                .entry(key)
4331                .and_modify(|existing| {
4332                    *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
4333                })
4334                .or_insert(p);
4335        }
4336    }
4337
4338    // Add complement column to each batch.
4339    let mut result = Vec::new();
4340    for batch in batches {
4341        let left_indices: Vec<usize> = join_cols
4342            .iter()
4343            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
4344            .collect();
4345        if left_indices.len() != join_cols.len() {
4346            result.push(batch);
4347            continue;
4348        }
4349        let all_u64 = left_indices.iter().all(|&ci| {
4350            batch
4351                .column(ci)
4352                .as_any()
4353                .downcast_ref::<UInt64Array>()
4354                .is_some()
4355        });
4356        if !all_u64 {
4357            result.push(batch);
4358            continue;
4359        }
4360
4361        let complements: Vec<f64> = (0..batch.num_rows())
4362            .map(|row| {
4363                let mut key = Vec::with_capacity(left_indices.len());
4364                for &ci in &left_indices {
4365                    let vids = batch
4366                        .column(ci)
4367                        .as_any()
4368                        .downcast_ref::<UInt64Array>()
4369                        .unwrap();
4370                    if vids.is_null(row) {
4371                        return 1.0;
4372                    }
4373                    key.push(vids.value(row));
4374                }
4375                let p = prob_map.get(&key).copied().unwrap_or(0.0);
4376                1.0 - p
4377            })
4378            .collect();
4379
4380        let complement_arr = Float64Array::from(complements);
4381        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
4382        columns.push(Arc::new(complement_arr));
4383
4384        let mut fields: Vec<Arc<arrow_schema::Field>> =
4385            batch.schema().fields().iter().cloned().collect();
4386        fields.push(Arc::new(arrow_schema::Field::new(
4387            complement_col_name,
4388            arrow_schema::DataType::Float64,
4389            true,
4390        )));
4391
4392        let new_schema = Arc::new(arrow_schema::Schema::new(fields));
4393        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
4394        result.push(new_batch);
4395    }
4396    Ok(result)
4397}
4398
4399/// Boolean anti-join for composite (multi-column) join keys.
4400///
4401/// Builds a `HashSet<Vec<u64>>` from `neg_facts` using all right-side
4402/// columns in `join_cols`, then filters `batches` to keep only rows
4403/// whose composite left-side key is NOT in the set.
4404pub fn apply_anti_join_composite(
4405    batches: Vec<RecordBatch>,
4406    neg_facts: &[RecordBatch],
4407    join_cols: &[(String, String)],
4408) -> datafusion::error::Result<Vec<RecordBatch>> {
4409    use arrow::compute::filter_record_batch;
4410    use arrow_array::{Array as _, BooleanArray, UInt64Array};
4411
4412    // Collect composite keys from the negated rule's derived facts.
4413    let mut banned: HashSet<Vec<u64>> = HashSet::new();
4414    for batch in neg_facts {
4415        let right_indices: Vec<usize> = join_cols
4416            .iter()
4417            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
4418            .collect();
4419        if right_indices.len() != join_cols.len() {
4420            continue;
4421        }
4422        for row in 0..batch.num_rows() {
4423            let mut key = Vec::with_capacity(right_indices.len());
4424            let mut valid = true;
4425            for &ci in &right_indices {
4426                let col = batch.column(ci);
4427                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
4428                    if vids.is_null(row) {
4429                        valid = false;
4430                        break;
4431                    }
4432                    key.push(vids.value(row));
4433                } else {
4434                    valid = false;
4435                    break;
4436                }
4437            }
4438            if valid {
4439                banned.insert(key);
4440            }
4441        }
4442    }
4443
4444    if banned.is_empty() {
4445        return Ok(batches);
4446    }
4447
4448    // Filter body batches: keep rows where composite left key NOT IN banned.
4449    let mut result = Vec::new();
4450    for batch in batches {
4451        let left_indices: Vec<usize> = join_cols
4452            .iter()
4453            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
4454            .collect();
4455        if left_indices.len() != join_cols.len() {
4456            result.push(batch);
4457            continue;
4458        }
4459        let all_u64 = left_indices.iter().all(|&ci| {
4460            batch
4461                .column(ci)
4462                .as_any()
4463                .downcast_ref::<UInt64Array>()
4464                .is_some()
4465        });
4466        if !all_u64 {
4467            result.push(batch);
4468            continue;
4469        }
4470
4471        let keep: Vec<bool> = (0..batch.num_rows())
4472            .map(|row| {
4473                let mut key = Vec::with_capacity(left_indices.len());
4474                for &ci in &left_indices {
4475                    let vids = batch
4476                        .column(ci)
4477                        .as_any()
4478                        .downcast_ref::<UInt64Array>()
4479                        .unwrap();
4480                    if vids.is_null(row) {
4481                        return true; // null keys are never banned
4482                    }
4483                    key.push(vids.value(row));
4484                }
4485                !banned.contains(&key)
4486            })
4487            .collect();
4488        let keep_arr = BooleanArray::from(keep);
4489        let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
4490        if filtered.num_rows() > 0 {
4491            result.push(filtered);
4492        }
4493    }
4494    Ok(result)
4495}
4496
4497/// Multiply `__prob_complement_*` columns into the rule's PROB column and clean up.
4498///
4499/// After IS NOT probabilistic complement semantics have added `__prob_complement_*`
4500/// columns to clause results, this function:
4501/// 1. Computes the product of all complement factor columns
4502/// 2. Multiplies the product into the existing PROB column (if any)
4503/// 3. Removes the internal `__prob_complement_*` columns from the output
4504///
4505/// If the rule has no PROB column, complement columns are simply removed
4506/// (the complement information is discarded and IS NOT acts as a keep-all).
4507pub fn multiply_prob_factors(
4508    batches: Vec<RecordBatch>,
4509    prob_col: Option<&str>,
4510    complement_cols: &[String],
4511) -> datafusion::error::Result<Vec<RecordBatch>> {
4512    use arrow_array::{Array as _, Float64Array};
4513
4514    let mut result = Vec::with_capacity(batches.len());
4515
4516    for batch in batches {
4517        if batch.num_rows() == 0 {
4518            // Remove complement columns from empty batches
4519            let keep: Vec<usize> = batch
4520                .schema()
4521                .fields()
4522                .iter()
4523                .enumerate()
4524                .filter(|(_, f)| !complement_cols.contains(f.name()))
4525                .map(|(i, _)| i)
4526                .collect();
4527            let fields: Vec<_> = keep
4528                .iter()
4529                .map(|&i| batch.schema().field(i).clone())
4530                .collect();
4531            let cols: Vec<_> = keep.iter().map(|&i| batch.column(i).clone()).collect();
4532            let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4533            result.push(
4534                RecordBatch::try_new(schema, cols).map_err(|e| {
4535                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
4536                })?,
4537            );
4538            continue;
4539        }
4540
4541        let num_rows = batch.num_rows();
4542
4543        // 1. Compute product of all complement factors
4544        let mut combined = vec![1.0f64; num_rows];
4545        for col_name in complement_cols {
4546            if let Ok(idx) = batch.schema().index_of(col_name) {
4547                let arr = batch
4548                    .column(idx)
4549                    .as_any()
4550                    .downcast_ref::<Float64Array>()
4551                    .ok_or_else(|| {
4552                        datafusion::error::DataFusionError::Internal(format!(
4553                            "Expected Float64 for complement column {col_name}"
4554                        ))
4555                    })?;
4556                for (i, val) in combined.iter_mut().enumerate().take(num_rows) {
4557                    if !arr.is_null(i) {
4558                        *val *= arr.value(i);
4559                    }
4560                }
4561            }
4562        }
4563
4564        // 2. If there's a PROB column, multiply combined into it
4565        let final_prob: Vec<f64> = if let Some(prob_name) = prob_col {
4566            if let Ok(idx) = batch.schema().index_of(prob_name) {
4567                let arr = batch
4568                    .column(idx)
4569                    .as_any()
4570                    .downcast_ref::<Float64Array>()
4571                    .ok_or_else(|| {
4572                        datafusion::error::DataFusionError::Internal(format!(
4573                            "Expected Float64 for PROB column {prob_name}"
4574                        ))
4575                    })?;
4576                (0..num_rows)
4577                    .map(|i| {
4578                        if arr.is_null(i) {
4579                            combined[i]
4580                        } else {
4581                            arr.value(i) * combined[i]
4582                        }
4583                    })
4584                    .collect()
4585            } else {
4586                combined
4587            }
4588        } else {
4589            combined
4590        };
4591
4592        let new_prob_array: arrow_array::ArrayRef =
4593            std::sync::Arc::new(Float64Array::from(final_prob));
4594
4595        // 3. Build output: replace PROB column, remove complement columns
4596        let mut fields = Vec::new();
4597        let mut columns = Vec::new();
4598
4599        for (idx, field) in batch.schema().fields().iter().enumerate() {
4600            if complement_cols.contains(field.name()) {
4601                continue;
4602            }
4603            if prob_col.is_some_and(|p| field.name() == p) {
4604                fields.push(field.clone());
4605                columns.push(new_prob_array.clone());
4606            } else {
4607                fields.push(field.clone());
4608                columns.push(batch.column(idx).clone());
4609            }
4610        }
4611
4612        let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4613        result.push(RecordBatch::try_new(schema, columns).map_err(arrow_err)?);
4614    }
4615
4616    Ok(result)
4617}
4618
4619/// Update derived scan handles before evaluating a rule's clause bodies.
4620///
4621/// For self-references: inject delta (semi-naive optimization).
4622/// For cross-references: inject full facts.
4623fn update_derived_scan_handles(
4624    registry: &DerivedScanRegistry,
4625    states: &[FixpointState],
4626    current_rule_idx: usize,
4627    rules: &[FixpointRulePlan],
4628) {
4629    let current_rule_name = &rules[current_rule_idx].name;
4630
4631    for entry in &registry.entries {
4632        // Find the state for this entry's rule
4633        let source_state_idx = rules.iter().position(|r| r.name == entry.rule_name);
4634        let Some(source_idx) = source_state_idx else {
4635            continue;
4636        };
4637
4638        let is_self = entry.rule_name == *current_rule_name;
4639        let data = if is_self {
4640            // Self-ref: inject delta for semi-naive
4641            states[source_idx].all_delta().to_vec()
4642        } else {
4643            // Cross-ref: inject full facts
4644            states[source_idx].all_facts().to_vec()
4645        };
4646
4647        // If empty, write an empty batch so the scan returns zero rows
4648        let data = if data.is_empty() || data.iter().all(|b| b.num_rows() == 0) {
4649            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
4650        } else {
4651            data
4652        };
4653
4654        let mut guard = entry.data.write();
4655        *guard = data;
4656    }
4657}
4658
4659// ---------------------------------------------------------------------------
4660// DerivedScanExec — physical plan that reads from shared data at execution time
4661// ---------------------------------------------------------------------------
4662
4663/// Physical plan for `LocyDerivedScan` that reads from a shared `Arc<RwLock>` at
4664/// execution time (not at plan creation time).
4665///
4666/// This is critical for fixpoint iteration: the data handle is updated between
4667/// iterations, and each re-execution of the subplan must read the latest data.
4668pub struct DerivedScanExec {
4669    data: Arc<RwLock<Vec<RecordBatch>>>,
4670    schema: SchemaRef,
4671    properties: Arc<PlanProperties>,
4672}
4673
4674impl DerivedScanExec {
4675    pub fn new(data: Arc<RwLock<Vec<RecordBatch>>>, schema: SchemaRef) -> Self {
4676        let properties = compute_plan_properties(Arc::clone(&schema));
4677        Self {
4678            data,
4679            schema,
4680            properties,
4681        }
4682    }
4683}
4684
4685impl fmt::Debug for DerivedScanExec {
4686    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4687        f.debug_struct("DerivedScanExec")
4688            .field("schema", &self.schema)
4689            .finish()
4690    }
4691}
4692
4693impl DisplayAs for DerivedScanExec {
4694    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4695        write!(f, "DerivedScanExec")
4696    }
4697}
4698
4699impl ExecutionPlan for DerivedScanExec {
4700    fn name(&self) -> &str {
4701        "DerivedScanExec"
4702    }
4703    fn as_any(&self) -> &dyn Any {
4704        self
4705    }
4706    fn schema(&self) -> SchemaRef {
4707        Arc::clone(&self.schema)
4708    }
4709    fn properties(&self) -> &Arc<PlanProperties> {
4710        &self.properties
4711    }
4712    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
4713        vec![]
4714    }
4715    fn with_new_children(
4716        self: Arc<Self>,
4717        _children: Vec<Arc<dyn ExecutionPlan>>,
4718    ) -> DFResult<Arc<dyn ExecutionPlan>> {
4719        Ok(self)
4720    }
4721    fn execute(
4722        &self,
4723        _partition: usize,
4724        _context: Arc<TaskContext>,
4725    ) -> DFResult<SendableRecordBatchStream> {
4726        let batches = {
4727            let guard = self.data.read();
4728            if guard.is_empty() {
4729                vec![RecordBatch::new_empty(Arc::clone(&self.schema))]
4730            } else {
4731                guard.clone()
4732            }
4733        };
4734        Ok(Box::pin(MemoryStream::try_new(
4735            batches,
4736            Arc::clone(&self.schema),
4737            None,
4738        )?))
4739    }
4740}
4741
4742// ---------------------------------------------------------------------------
4743// InMemoryExec — wrapper to feed Vec<RecordBatch> into operator chains
4744// ---------------------------------------------------------------------------
4745
4746/// Simple in-memory execution plan that serves pre-computed batches.
4747///
4748/// Used internally to feed fixpoint results into post-fixpoint operator chains
4749/// (FOLD, BEST BY). Not exported — only used within this module.
4750struct InMemoryExec {
4751    batches: Vec<RecordBatch>,
4752    schema: SchemaRef,
4753    properties: Arc<PlanProperties>,
4754}
4755
4756impl InMemoryExec {
4757    fn new(batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
4758        let properties = compute_plan_properties(Arc::clone(&schema));
4759        Self {
4760            batches,
4761            schema,
4762            properties,
4763        }
4764    }
4765}
4766
4767impl fmt::Debug for InMemoryExec {
4768    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4769        f.debug_struct("InMemoryExec")
4770            .field("num_batches", &self.batches.len())
4771            .field("schema", &self.schema)
4772            .finish()
4773    }
4774}
4775
4776impl DisplayAs for InMemoryExec {
4777    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4778        write!(f, "InMemoryExec: batches={}", self.batches.len())
4779    }
4780}
4781
4782impl ExecutionPlan for InMemoryExec {
4783    fn name(&self) -> &str {
4784        "InMemoryExec"
4785    }
4786    fn as_any(&self) -> &dyn Any {
4787        self
4788    }
4789    fn schema(&self) -> SchemaRef {
4790        Arc::clone(&self.schema)
4791    }
4792    fn properties(&self) -> &Arc<PlanProperties> {
4793        &self.properties
4794    }
4795    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
4796        vec![]
4797    }
4798    fn with_new_children(
4799        self: Arc<Self>,
4800        _children: Vec<Arc<dyn ExecutionPlan>>,
4801    ) -> DFResult<Arc<dyn ExecutionPlan>> {
4802        Ok(self)
4803    }
4804    fn execute(
4805        &self,
4806        _partition: usize,
4807        _context: Arc<TaskContext>,
4808    ) -> DFResult<SendableRecordBatchStream> {
4809        Ok(Box::pin(MemoryStream::try_new(
4810            self.batches.clone(),
4811            Arc::clone(&self.schema),
4812            None,
4813        )?))
4814    }
4815}
4816
4817// ---------------------------------------------------------------------------
4818// Post-fixpoint chain — FOLD and BEST BY on converged facts
4819// ---------------------------------------------------------------------------
4820
4821/// Apply post-FOLD WHERE (HAVING) filter to aggregated batches.
4822///
4823/// Converts each Cypher HAVING expression to a DataFusion physical expression
4824/// via `cypher_expr_to_df` → type coercion → `create_physical_expr`, evaluates
4825/// against the FOLD output, and keeps only rows where all conditions hold.
4826fn apply_having_filter(
4827    batches: Vec<RecordBatch>,
4828    having_exprs: &[Expr],
4829    schema: &SchemaRef,
4830    task_ctx: &Arc<TaskContext>,
4831) -> DFResult<Vec<RecordBatch>> {
4832    use arrow::compute::{and, filter_record_batch};
4833    use arrow_array::BooleanArray;
4834    use datafusion::common::DFSchema;
4835    use datafusion::logical_expr::LogicalPlanBuilder;
4836    use datafusion::logical_expr::execution_props::ExecutionProps;
4837    use datafusion::optimizer::AnalyzerRule;
4838    use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
4839    use datafusion::physical_expr::create_physical_expr;
4840
4841    if batches.is_empty() {
4842        return Ok(batches);
4843    }
4844
4845    // Build DFSchema from the FOLD output Arrow schema.
4846    let df_schema = DFSchema::try_from(schema.as_ref().clone()).map_err(|e| {
4847        datafusion::common::DataFusionError::Internal(format!("HAVING schema conversion: {e}"))
4848    })?;
4849
4850    // Use the active TaskContext's config rather than allocating a fresh
4851    // `SessionContext` per HAVING evaluation (~130 µs/call). HAVING uses only
4852    // built-in DataFusion arithmetic — no Cypher UDFs — so a default
4853    // `ExecutionProps` is sufficient (it's documented as cheap to construct).
4854    let config = (**task_ctx.session_config().options()).clone();
4855    let props = ExecutionProps::new();
4856
4857    // Cypher Expr → DataFusion DfExpr → type-coerced DfExpr → PhysicalExpr.
4858    //
4859    // Type coercion is needed because FOLD aggregates produce Float64 (SUM,
4860    // AVG) or Int64 (COUNT), and literal comparisons like `total >= 100`
4861    // may mix Float64 columns with Int64 literals.
4862    let physical_exprs: Vec<Arc<dyn datafusion::physical_expr::PhysicalExpr>> = having_exprs
4863        .iter()
4864        .map(|expr| {
4865            let df_expr = crate::query::df_expr::cypher_expr_to_df(expr, None).map_err(|e| {
4866                datafusion::common::DataFusionError::Internal(format!(
4867                    "HAVING expression conversion: {e}"
4868                ))
4869            })?;
4870
4871            // Run DataFusion's type coercion by wrapping in a Filter plan,
4872            // applying the TypeCoercion analyzer rule, then extracting the
4873            // coerced predicate.
4874            let empty = datafusion::logical_expr::LogicalPlan::EmptyRelation(
4875                datafusion::logical_expr::EmptyRelation {
4876                    produce_one_row: false,
4877                    schema: Arc::new(df_schema.clone()),
4878                },
4879            );
4880            let filter_plan = LogicalPlanBuilder::from(empty)
4881                .filter(df_expr.clone())?
4882                .build()?;
4883            let coerced_expr = match TypeCoercion::new().analyze(filter_plan, &config) {
4884                Ok(datafusion::logical_expr::LogicalPlan::Filter(f)) => f.predicate,
4885                _ => df_expr,
4886            };
4887
4888            create_physical_expr(&coerced_expr, &df_schema, &props)
4889        })
4890        .collect::<DFResult<Vec<_>>>()?;
4891
4892    let mut result = Vec::new();
4893    for batch in batches {
4894        // Evaluate each condition and AND the boolean masks.
4895        let mut mask: Option<BooleanArray> = None;
4896        for phys_expr in &physical_exprs {
4897            let value = phys_expr.evaluate(&batch)?;
4898            let arr = value.into_array(batch.num_rows())?;
4899            let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().ok_or_else(|| {
4900                datafusion::common::DataFusionError::Internal(
4901                    "HAVING condition must evaluate to boolean".into(),
4902                )
4903            })?;
4904            mask = Some(match mask {
4905                None => bool_arr.clone(),
4906                Some(prev) => and(&prev, bool_arr).map_err(arrow_err)?,
4907            });
4908        }
4909        if let Some(ref m) = mask {
4910            let filtered = filter_record_batch(&batch, m).map_err(arrow_err)?;
4911            if filtered.num_rows() > 0 {
4912                result.push(filtered);
4913            }
4914        } else {
4915            result.push(batch);
4916        }
4917    }
4918    Ok(result)
4919}
4920
4921/// Apply post-fixpoint operators (FOLD, HAVING, BEST BY, PRIORITY) to converged facts.
4922#[allow(
4923    clippy::too_many_arguments,
4924    reason = "context bundle would be over-engineering for one call site"
4925)]
4926pub(crate) async fn apply_post_fixpoint_chain(
4927    facts: Vec<RecordBatch>,
4928    rule: &FixpointRulePlan,
4929    task_ctx: &Arc<TaskContext>,
4930    strict_probability_domain: bool,
4931    probability_epsilon: f64,
4932    semiring_kind: SemiringKind,
4933    provenance_tracker: Option<Arc<ProvenanceStore>>,
4934    top_k_proofs_k: usize,
4935    registry: Option<Arc<DerivedScanRegistry>>,
4936) -> DFResult<Vec<RecordBatch>> {
4937    if !rule.has_fold && !rule.has_best_by && !rule.has_priority && rule.having.is_empty() {
4938        return Ok(facts);
4939    }
4940
4941    // Wrap facts in InMemoryExec.
4942    // Prefer the actual batch schema (from physical execution) over the
4943    // pre-computed yield_schema, which may have wrong inferred types
4944    // (e.g. Float64 for a string property).
4945    let schema = facts
4946        .iter()
4947        .find(|b| b.num_rows() > 0)
4948        .map(|b| b.schema())
4949        .unwrap_or_else(|| Arc::clone(&rule.yield_schema));
4950
4951    // Phase D D-C0: pre-compute body-row → IS-ref support map for
4952    // TopKProofs MNOR's DNF inclusion-exclusion math. Must be built
4953    // here because `facts` is moved into `InMemoryExec` on the next
4954    // line. The map is keyed by a full-column row hash — only
4955    // meaningful when no downstream plan node strips/adds columns
4956    // between this batch view and the FoldExec input. PRIORITY drops
4957    // the `__priority` column, which would change row hashes; until
4958    // we plumb the map past PRIORITY, skip map construction for
4959    // PRIORITY rules (the failing TCK test doesn't use PRIORITY).
4960    // Read the active K from `semiring_kind` rather than the separate
4961    // `top_k_proofs_k` parameter — the latter is not always threaded
4962    // from the LocyProgram config (the semiring's `k` is the source of
4963    // truth).
4964    let topk_k: Option<usize> = match semiring_kind {
4965        SemiringKind::TopKProofs { k } if k > 0 => Some(k as usize),
4966        _ => None,
4967    };
4968    let body_support_map: Option<Arc<HashMap<Vec<u8>, Vec<ProofTerm>>>> = if topk_k.is_some()
4969        && !rule.has_priority
4970        && let Some(registry) = registry.as_ref()
4971    {
4972        let mut map: HashMap<Vec<u8>, Vec<ProofTerm>> = HashMap::new();
4973        for batch in &facts {
4974            let all_indices: Vec<usize> = (0..batch.num_columns()).collect();
4975            for row_idx in 0..batch.num_rows() {
4976                let support = collect_is_ref_inputs_for_body_row(rule, batch, row_idx, registry);
4977                if support.is_empty() {
4978                    continue;
4979                }
4980                let hash = fact_hash_key(batch, &all_indices, row_idx);
4981                map.insert(hash, support);
4982            }
4983        }
4984        if map.is_empty() {
4985            None
4986        } else {
4987            Some(Arc::new(map))
4988        }
4989    } else {
4990        None
4991    };
4992
4993    let input: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(facts, schema.clone()));
4994
4995    // Reconcile key indices: rule's indices are yield-schema positions but
4996    // the actual batch may have different column ordering after schema
4997    // reconciliation during fixpoint iteration (same pattern as
4998    // FixpointState::reconcile_schema).
4999    let key_column_indices: Vec<usize> = rule
5000        .key_column_indices
5001        .iter()
5002        .filter_map(|&i| {
5003            let name = rule.yield_schema.field(i).name();
5004            schema.index_of(name).ok()
5005        })
5006        .collect();
5007
5008    // Apply PRIORITY first — keeps only rows with max __priority per KEY group,
5009    // then strips the __priority column from output.
5010    // Must run before FOLD so that the __priority column is still present.
5011    let current: Arc<dyn ExecutionPlan> = if rule.has_priority {
5012        let priority_schema = input.schema();
5013        let priority_idx = priority_schema.index_of("__priority").map_err(|_| {
5014            datafusion::common::DataFusionError::Internal(
5015                "PRIORITY rule missing __priority column".to_string(),
5016            )
5017        })?;
5018        Arc::new(PriorityExec::new(
5019            input,
5020            key_column_indices.clone(),
5021            priority_idx,
5022        ))
5023    } else {
5024        input
5025    };
5026
5027    // Apply FOLD
5028    let current: Arc<dyn ExecutionPlan> = if rule.has_fold && !rule.fold_bindings.is_empty() {
5029        Arc::new(FoldExec::new_with_topk(
5030            current,
5031            key_column_indices.clone(),
5032            rule.fold_bindings.clone(),
5033            strict_probability_domain,
5034            probability_epsilon,
5035            semiring_kind,
5036            provenance_tracker.clone(),
5037            topk_k.unwrap_or(top_k_proofs_k),
5038            body_support_map.clone(),
5039        ))
5040    } else {
5041        current
5042    };
5043
5044    // Apply HAVING (post-FOLD WHERE filter)
5045    let current: Arc<dyn ExecutionPlan> = if !rule.having.is_empty() {
5046        let batches = collect_all_partitions(&current, Arc::clone(task_ctx)).await?;
5047        let filtered = apply_having_filter(batches, &rule.having, &current.schema(), task_ctx)?;
5048        if filtered.is_empty() {
5049            return Ok(filtered);
5050        }
5051        Arc::new(InMemoryExec::new(filtered, Arc::clone(&current.schema())))
5052    } else {
5053        current
5054    };
5055
5056    // Apply BEST BY
5057    let current: Arc<dyn ExecutionPlan> = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
5058        Arc::new(BestByExec::new(
5059            current,
5060            key_column_indices.clone(),
5061            rule.best_by_criteria.clone(),
5062            rule.deterministic,
5063        ))
5064    } else {
5065        current
5066    };
5067
5068    collect_all_partitions(&current, Arc::clone(task_ctx)).await
5069}
5070
5071// ---------------------------------------------------------------------------
5072// FixpointExec — DataFusion ExecutionPlan
5073// ---------------------------------------------------------------------------
5074
5075/// DataFusion `ExecutionPlan` that drives semi-naive fixpoint iteration.
5076///
5077/// Has no physical children: clause bodies are re-planned from logical plans
5078/// on each iteration (same pattern as `RecursiveCTEExec` and `GraphApplyExec`).
5079pub struct FixpointExec {
5080    rules: Vec<FixpointRulePlan>,
5081    max_iterations: usize,
5082    timeout: Duration,
5083    graph_ctx: Arc<GraphExecutionContext>,
5084    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5085    storage: Arc<StorageManager>,
5086    schema_info: Arc<UniSchema>,
5087    params: HashMap<String, Value>,
5088    derived_scan_registry: Arc<DerivedScanRegistry>,
5089    output_schema: SchemaRef,
5090    properties: Arc<PlanProperties>,
5091    metrics: ExecutionPlanMetricsSet,
5092    max_derived_bytes: usize,
5093    /// Optional provenance tracker populated during fixpoint iteration.
5094    derivation_tracker: Option<Arc<ProvenanceStore>>,
5095    /// Shared slot written with per-rule iteration counts after convergence.
5096    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5097    strict_probability_domain: bool,
5098    probability_epsilon: f64,
5099    exact_probability: bool,
5100    max_bdd_variables: usize,
5101    /// Shared slot for runtime warnings collected during fixpoint iteration.
5102    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5103    /// Shared slot for groups where BDD fell back to independence mode.
5104    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5105    /// When > 0, retain at most this many proofs per fact (top-k provenance).
5106    top_k_proofs: usize,
5107    /// Shared flag: set to true on timeout to signal partial results.
5108    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5109    /// Active probability semiring (rollout D-7).
5110    semiring_kind: SemiringKind,
5111    /// Phase B Slice 3 registry of neural classifiers, keyed by the
5112    /// model name from `CREATE MODEL`. Held by `Arc` so executor clones
5113    /// share the same underlying map.
5114    classifier_registry: Arc<ClassifierRegistry>,
5115    /// Phase B follow-up: optional per-evaluation memoization cache
5116    /// for classifier outputs keyed by `(model_name, feature_hash)`.
5117    /// `None` → no caching; `Some` → cache shared across fixpoint
5118    /// iterations (and optionally across the entire query / multiple
5119    /// queries, when the caller threads it via `LocyConfig`).
5120    classifier_cache: Option<Arc<ModelInvocationCache>>,
5121    /// Phase C B1-B3 follow-up: per-query side-channel store
5122    /// for (raw, calibrated, confidence_band) records. Read by
5123    /// EXPLAIN; not used by the fixpoint inner loop directly
5124    /// (LocyModelInvokeExec writes; this struct just carries
5125    /// the Arc to keep the type wiring consistent across the
5126    /// LocyProgramExec/FixpointExec boundary).
5127    #[allow(
5128        dead_code,
5129        reason = "boundary plumbing; read by EXPLAIN via LocyModelInvokeExec"
5130    )]
5131    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
5132}
5133
5134impl fmt::Debug for FixpointExec {
5135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5136        f.debug_struct("FixpointExec")
5137            .field("rules_count", &self.rules.len())
5138            .field("max_iterations", &self.max_iterations)
5139            .field("timeout", &self.timeout)
5140            .field("output_schema", &self.output_schema)
5141            .field("max_derived_bytes", &self.max_derived_bytes)
5142            .finish_non_exhaustive()
5143    }
5144}
5145
5146impl FixpointExec {
5147    /// Create a new `FixpointExec`.
5148    #[expect(
5149        clippy::too_many_arguments,
5150        reason = "FixpointExec configuration needs all context"
5151    )]
5152    #[deprecated(
5153        note = "use `new_with_semiring_classifiers_and_cache` (or the lighter \
5154                `new_with_semiring_and_classifiers` / `new_with_semiring`) — \
5155                this legacy ctor defaults the semiring to AddMultProb and \
5156                ships no classifier registry, which the Phase B+ runtime needs \
5157                explicitly. To be removed after C0 Stage 2."
5158    )]
5159    pub fn new(
5160        rules: Vec<FixpointRulePlan>,
5161        max_iterations: usize,
5162        timeout: Duration,
5163        graph_ctx: Arc<GraphExecutionContext>,
5164        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5165        storage: Arc<StorageManager>,
5166        schema_info: Arc<UniSchema>,
5167        params: HashMap<String, Value>,
5168        derived_scan_registry: Arc<DerivedScanRegistry>,
5169        output_schema: SchemaRef,
5170        max_derived_bytes: usize,
5171        derivation_tracker: Option<Arc<ProvenanceStore>>,
5172        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5173        strict_probability_domain: bool,
5174        probability_epsilon: f64,
5175        exact_probability: bool,
5176        max_bdd_variables: usize,
5177        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5178        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5179        top_k_proofs: usize,
5180        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5181    ) -> Self {
5182        Self::new_with_semiring_and_classifiers(
5183            rules,
5184            max_iterations,
5185            timeout,
5186            graph_ctx,
5187            session_ctx,
5188            storage,
5189            schema_info,
5190            params,
5191            derived_scan_registry,
5192            output_schema,
5193            max_derived_bytes,
5194            derivation_tracker,
5195            iteration_counts,
5196            strict_probability_domain,
5197            probability_epsilon,
5198            exact_probability,
5199            max_bdd_variables,
5200            warnings_slot,
5201            approximate_slot,
5202            top_k_proofs,
5203            timeout_flag,
5204            SemiringKind::AddMultProb,
5205            Arc::new(ClassifierRegistry::new()),
5206        )
5207    }
5208
5209    /// Variant accepting an explicit [`SemiringKind`]. Empty classifier
5210    /// registry; for the full variant call
5211    /// [`FixpointExec::new_with_semiring_and_classifiers`].
5212    #[expect(
5213        clippy::too_many_arguments,
5214        reason = "FixpointExec configuration needs all context"
5215    )]
5216    pub fn new_with_semiring(
5217        rules: Vec<FixpointRulePlan>,
5218        max_iterations: usize,
5219        timeout: Duration,
5220        graph_ctx: Arc<GraphExecutionContext>,
5221        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5222        storage: Arc<StorageManager>,
5223        schema_info: Arc<UniSchema>,
5224        params: HashMap<String, Value>,
5225        derived_scan_registry: Arc<DerivedScanRegistry>,
5226        output_schema: SchemaRef,
5227        max_derived_bytes: usize,
5228        derivation_tracker: Option<Arc<ProvenanceStore>>,
5229        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5230        strict_probability_domain: bool,
5231        probability_epsilon: f64,
5232        exact_probability: bool,
5233        max_bdd_variables: usize,
5234        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5235        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5236        top_k_proofs: usize,
5237        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5238        semiring_kind: SemiringKind,
5239    ) -> Self {
5240        Self::new_with_semiring_and_classifiers(
5241            rules,
5242            max_iterations,
5243            timeout,
5244            graph_ctx,
5245            session_ctx,
5246            storage,
5247            schema_info,
5248            params,
5249            derived_scan_registry,
5250            output_schema,
5251            max_derived_bytes,
5252            derivation_tracker,
5253            iteration_counts,
5254            strict_probability_domain,
5255            probability_epsilon,
5256            exact_probability,
5257            max_bdd_variables,
5258            warnings_slot,
5259            approximate_slot,
5260            top_k_proofs,
5261            timeout_flag,
5262            semiring_kind,
5263            Arc::new(ClassifierRegistry::new()),
5264        )
5265    }
5266
5267    /// Phase B Slice 3 entry: accepts both the semiring kind and the
5268    /// runtime classifier registry. The planner uses this when the
5269    /// program contains `CREATE MODEL` declarations.
5270    #[expect(
5271        clippy::too_many_arguments,
5272        reason = "FixpointExec configuration needs all context"
5273    )]
5274    pub fn new_with_semiring_and_classifiers(
5275        rules: Vec<FixpointRulePlan>,
5276        max_iterations: usize,
5277        timeout: Duration,
5278        graph_ctx: Arc<GraphExecutionContext>,
5279        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5280        storage: Arc<StorageManager>,
5281        schema_info: Arc<UniSchema>,
5282        params: HashMap<String, Value>,
5283        derived_scan_registry: Arc<DerivedScanRegistry>,
5284        output_schema: SchemaRef,
5285        max_derived_bytes: usize,
5286        derivation_tracker: Option<Arc<ProvenanceStore>>,
5287        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5288        strict_probability_domain: bool,
5289        probability_epsilon: f64,
5290        exact_probability: bool,
5291        max_bdd_variables: usize,
5292        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5293        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5294        top_k_proofs: usize,
5295        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5296        semiring_kind: SemiringKind,
5297        classifier_registry: Arc<ClassifierRegistry>,
5298    ) -> Self {
5299        Self::new_with_semiring_classifiers_and_cache(
5300            rules,
5301            max_iterations,
5302            timeout,
5303            graph_ctx,
5304            session_ctx,
5305            storage,
5306            schema_info,
5307            params,
5308            derived_scan_registry,
5309            output_schema,
5310            max_derived_bytes,
5311            derivation_tracker,
5312            iteration_counts,
5313            strict_probability_domain,
5314            probability_epsilon,
5315            exact_probability,
5316            max_bdd_variables,
5317            warnings_slot,
5318            approximate_slot,
5319            top_k_proofs,
5320            timeout_flag,
5321            semiring_kind,
5322            classifier_registry,
5323            None,
5324            None,
5325        )
5326    }
5327
5328    /// Phase B follow-up: full constructor accepting an optional
5329    /// memoization cache. Existing callers default to `None` (no cache);
5330    /// the impl_locy.rs entry passes the user's `config.classifier_cache`.
5331    #[expect(
5332        clippy::too_many_arguments,
5333        reason = "FixpointExec configuration needs all context"
5334    )]
5335    pub fn new_with_semiring_classifiers_and_cache(
5336        rules: Vec<FixpointRulePlan>,
5337        max_iterations: usize,
5338        timeout: Duration,
5339        graph_ctx: Arc<GraphExecutionContext>,
5340        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5341        storage: Arc<StorageManager>,
5342        schema_info: Arc<UniSchema>,
5343        params: HashMap<String, Value>,
5344        derived_scan_registry: Arc<DerivedScanRegistry>,
5345        output_schema: SchemaRef,
5346        max_derived_bytes: usize,
5347        derivation_tracker: Option<Arc<ProvenanceStore>>,
5348        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5349        strict_probability_domain: bool,
5350        probability_epsilon: f64,
5351        exact_probability: bool,
5352        max_bdd_variables: usize,
5353        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5354        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5355        top_k_proofs: usize,
5356        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5357        semiring_kind: SemiringKind,
5358        classifier_registry: Arc<ClassifierRegistry>,
5359        classifier_cache: Option<Arc<ModelInvocationCache>>,
5360        classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
5361    ) -> Self {
5362        let properties = compute_plan_properties(Arc::clone(&output_schema));
5363        Self {
5364            rules,
5365            max_iterations,
5366            timeout,
5367            graph_ctx,
5368            session_ctx,
5369            storage,
5370            schema_info,
5371            params,
5372            derived_scan_registry,
5373            output_schema,
5374            properties,
5375            metrics: ExecutionPlanMetricsSet::new(),
5376            max_derived_bytes,
5377            derivation_tracker,
5378            iteration_counts,
5379            strict_probability_domain,
5380            probability_epsilon,
5381            exact_probability,
5382            max_bdd_variables,
5383            warnings_slot,
5384            approximate_slot,
5385            top_k_proofs,
5386            timeout_flag,
5387            semiring_kind,
5388            classifier_registry,
5389            classifier_cache,
5390            classifier_provenance_store,
5391        }
5392    }
5393
5394    /// Returns the shared iteration counts slot for post-execution inspection.
5395    pub fn iteration_counts(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
5396        Arc::clone(&self.iteration_counts)
5397    }
5398}
5399
5400impl DisplayAs for FixpointExec {
5401    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5402        write!(
5403            f,
5404            "FixpointExec: rules=[{}], max_iter={}, timeout={:?}",
5405            self.rules
5406                .iter()
5407                .map(|r| r.name.as_str())
5408                .collect::<Vec<_>>()
5409                .join(", "),
5410            self.max_iterations,
5411            self.timeout,
5412        )
5413    }
5414}
5415
5416impl ExecutionPlan for FixpointExec {
5417    fn name(&self) -> &str {
5418        "FixpointExec"
5419    }
5420
5421    fn as_any(&self) -> &dyn Any {
5422        self
5423    }
5424
5425    fn schema(&self) -> SchemaRef {
5426        Arc::clone(&self.output_schema)
5427    }
5428
5429    fn properties(&self) -> &Arc<PlanProperties> {
5430        &self.properties
5431    }
5432
5433    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
5434        // No physical children — clause bodies are re-planned each iteration
5435        vec![]
5436    }
5437
5438    fn with_new_children(
5439        self: Arc<Self>,
5440        children: Vec<Arc<dyn ExecutionPlan>>,
5441    ) -> DFResult<Arc<dyn ExecutionPlan>> {
5442        if !children.is_empty() {
5443            return Err(datafusion::error::DataFusionError::Plan(
5444                "FixpointExec has no children".to_string(),
5445            ));
5446        }
5447        Ok(self)
5448    }
5449
5450    fn execute(
5451        &self,
5452        partition: usize,
5453        _context: Arc<TaskContext>,
5454    ) -> DFResult<SendableRecordBatchStream> {
5455        let metrics = BaselineMetrics::new(&self.metrics, partition);
5456
5457        // Clone all fields for the async closure
5458        let rules = self
5459            .rules
5460            .iter()
5461            .map(|r| {
5462                // We need to clone the FixpointRulePlan, but it contains LogicalPlan
5463                // which doesn't implement Clone traditionally. However, our LogicalPlan
5464                // does implement Clone since it's an enum.
5465                FixpointRulePlan {
5466                    name: r.name.clone(),
5467                    clauses: r
5468                        .clauses
5469                        .iter()
5470                        .map(|c| FixpointClausePlan {
5471                            body_logical: c.body_logical.clone(),
5472                            is_ref_bindings: c.is_ref_bindings.clone(),
5473                            priority: c.priority,
5474                            along_bindings: c.along_bindings.clone(),
5475                            model_invocations: c.model_invocations.clone(),
5476                        })
5477                        .collect(),
5478                    yield_schema: Arc::clone(&r.yield_schema),
5479                    key_column_indices: r.key_column_indices.clone(),
5480                    priority: r.priority,
5481                    has_fold: r.has_fold,
5482                    fold_bindings: r.fold_bindings.clone(),
5483                    having: r.having.clone(),
5484                    has_best_by: r.has_best_by,
5485                    best_by_criteria: r.best_by_criteria.clone(),
5486                    has_priority: r.has_priority,
5487                    deterministic: r.deterministic,
5488                    prob_column_name: r.prob_column_name.clone(),
5489                }
5490            })
5491            .collect();
5492
5493        let max_iterations = self.max_iterations;
5494        let timeout = self.timeout;
5495        let graph_ctx = Arc::clone(&self.graph_ctx);
5496        let session_ctx = Arc::clone(&self.session_ctx);
5497        let storage = Arc::clone(&self.storage);
5498        let schema_info = Arc::clone(&self.schema_info);
5499        let params = self.params.clone();
5500        let registry = Arc::clone(&self.derived_scan_registry);
5501        let output_schema = Arc::clone(&self.output_schema);
5502        let max_derived_bytes = self.max_derived_bytes;
5503        let derivation_tracker = self.derivation_tracker.clone();
5504        let iteration_counts = Arc::clone(&self.iteration_counts);
5505        let strict_probability_domain = self.strict_probability_domain;
5506        let probability_epsilon = self.probability_epsilon;
5507        let exact_probability = self.exact_probability;
5508        let max_bdd_variables = self.max_bdd_variables;
5509        let warnings_slot = Arc::clone(&self.warnings_slot);
5510        let approximate_slot = Arc::clone(&self.approximate_slot);
5511        let top_k_proofs = self.top_k_proofs;
5512        let timeout_flag = Arc::clone(&self.timeout_flag);
5513        let semiring_kind = self.semiring_kind;
5514        let classifier_registry = Arc::clone(&self.classifier_registry);
5515        let classifier_cache = self.classifier_cache.as_ref().map(Arc::clone);
5516        let classifier_provenance_store = self.classifier_provenance_store.as_ref().map(Arc::clone);
5517
5518        let fut = async move {
5519            run_fixpoint_loop(
5520                rules,
5521                max_iterations,
5522                timeout,
5523                graph_ctx,
5524                session_ctx,
5525                storage,
5526                schema_info,
5527                params,
5528                registry,
5529                output_schema,
5530                max_derived_bytes,
5531                derivation_tracker,
5532                iteration_counts,
5533                strict_probability_domain,
5534                probability_epsilon,
5535                exact_probability,
5536                max_bdd_variables,
5537                warnings_slot,
5538                approximate_slot,
5539                top_k_proofs,
5540                timeout_flag,
5541                semiring_kind,
5542                classifier_registry,
5543                classifier_cache,
5544                classifier_provenance_store,
5545            )
5546            .await
5547        };
5548
5549        Ok(Box::pin(FixpointStream {
5550            state: FixpointStreamState::Running(Box::pin(fut)),
5551            schema: Arc::clone(&self.output_schema),
5552            metrics,
5553        }))
5554    }
5555
5556    fn metrics(&self) -> Option<MetricsSet> {
5557        Some(self.metrics.clone_inner())
5558    }
5559}
5560
5561// ---------------------------------------------------------------------------
5562// FixpointStream — async state machine for streaming results
5563// ---------------------------------------------------------------------------
5564
5565enum FixpointStreamState {
5566    /// Fixpoint loop is running.
5567    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
5568    /// Emitting accumulated result batches one at a time.
5569    Emitting(Vec<RecordBatch>, usize),
5570    /// All batches emitted.
5571    Done,
5572}
5573
5574struct FixpointStream {
5575    state: FixpointStreamState,
5576    schema: SchemaRef,
5577    metrics: BaselineMetrics,
5578}
5579
5580impl Stream for FixpointStream {
5581    type Item = DFResult<RecordBatch>;
5582
5583    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
5584        let this = self.get_mut();
5585        let metrics = this.metrics.clone();
5586        let _timer = metrics.elapsed_compute().timer();
5587        loop {
5588            match &mut this.state {
5589                FixpointStreamState::Running(fut) => match fut.as_mut().poll(cx) {
5590                    Poll::Ready(Ok(batches)) => {
5591                        if batches.is_empty() {
5592                            this.state = FixpointStreamState::Done;
5593                            return Poll::Ready(None);
5594                        }
5595                        this.state = FixpointStreamState::Emitting(batches, 0);
5596                        // Loop to emit first batch
5597                    }
5598                    Poll::Ready(Err(e)) => {
5599                        this.state = FixpointStreamState::Done;
5600                        return Poll::Ready(Some(Err(e)));
5601                    }
5602                    Poll::Pending => return Poll::Pending,
5603                },
5604                FixpointStreamState::Emitting(batches, idx) => {
5605                    if *idx >= batches.len() {
5606                        this.state = FixpointStreamState::Done;
5607                        return Poll::Ready(None);
5608                    }
5609                    let batch = batches[*idx].clone();
5610                    *idx += 1;
5611                    this.metrics.record_output(batch.num_rows());
5612                    return Poll::Ready(Some(Ok(batch)));
5613                }
5614                FixpointStreamState::Done => return Poll::Ready(None),
5615            }
5616        }
5617    }
5618}
5619
5620impl RecordBatchStream for FixpointStream {
5621    fn schema(&self) -> SchemaRef {
5622        Arc::clone(&self.schema)
5623    }
5624}
5625
5626// ---------------------------------------------------------------------------
5627// Unit tests
5628// ---------------------------------------------------------------------------
5629
5630#[cfg(test)]
5631mod tests {
5632    use super::*;
5633    use arrow_array::{Float64Array, Int64Array, StringArray};
5634    use arrow_schema::{DataType, Field, Schema};
5635
5636    fn test_schema() -> SchemaRef {
5637        Arc::new(Schema::new(vec![
5638            Field::new("name", DataType::Utf8, true),
5639            Field::new("value", DataType::Int64, true),
5640        ]))
5641    }
5642
5643    fn make_batch(names: &[&str], values: &[i64]) -> RecordBatch {
5644        RecordBatch::try_new(
5645            test_schema(),
5646            vec![
5647                Arc::new(StringArray::from(
5648                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
5649                )),
5650                Arc::new(Int64Array::from(values.to_vec())),
5651            ],
5652        )
5653        .unwrap()
5654    }
5655
5656    // --- FixpointState dedup tests ---
5657
5658    #[tokio::test]
5659    async fn test_fixpoint_state_empty_facts_adds_all() {
5660        let schema = test_schema();
5661        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5662
5663        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
5664        let changed = state.merge_delta(vec![batch], None).await.unwrap();
5665
5666        assert!(changed);
5667        assert_eq!(state.all_facts().len(), 1);
5668        assert_eq!(state.all_facts()[0].num_rows(), 3);
5669        assert_eq!(state.all_delta().len(), 1);
5670        assert_eq!(state.all_delta()[0].num_rows(), 3);
5671    }
5672
5673    #[tokio::test]
5674    async fn test_fixpoint_state_exact_duplicates_excluded() {
5675        let schema = test_schema();
5676        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5677
5678        let batch1 = make_batch(&["a", "b"], &[1, 2]);
5679        state.merge_delta(vec![batch1], None).await.unwrap();
5680
5681        // Same rows again
5682        let batch2 = make_batch(&["a", "b"], &[1, 2]);
5683        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
5684        assert!(!changed);
5685        assert!(
5686            state.all_delta().is_empty() || state.all_delta().iter().all(|b| b.num_rows() == 0)
5687        );
5688    }
5689
5690    #[tokio::test]
5691    async fn test_fixpoint_state_partial_overlap() {
5692        let schema = test_schema();
5693        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5694
5695        let batch1 = make_batch(&["a", "b"], &[1, 2]);
5696        state.merge_delta(vec![batch1], None).await.unwrap();
5697
5698        // "a":1 is duplicate, "c":3 is new
5699        let batch2 = make_batch(&["a", "c"], &[1, 3]);
5700        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
5701        assert!(changed);
5702
5703        // Delta should have only "c":3
5704        let delta_rows: usize = state.all_delta().iter().map(|b| b.num_rows()).sum();
5705        assert_eq!(delta_rows, 1);
5706
5707        // Total facts: a:1, b:2, c:3
5708        let total_rows: usize = state.all_facts().iter().map(|b| b.num_rows()).sum();
5709        assert_eq!(total_rows, 3);
5710    }
5711
5712    #[tokio::test]
5713    async fn test_fixpoint_state_convergence() {
5714        let schema = test_schema();
5715        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5716
5717        let batch = make_batch(&["a"], &[1]);
5718        state.merge_delta(vec![batch], None).await.unwrap();
5719
5720        // Empty candidates → converged
5721        let changed = state.merge_delta(vec![], None).await.unwrap();
5722        assert!(!changed);
5723        assert!(state.is_converged());
5724    }
5725
5726    // --- RowDedupState tests ---
5727
5728    #[test]
5729    fn test_row_dedup_persistent_across_calls() {
5730        // RowDedupState should remember rows from the first call so the second
5731        // call does not re-accept them (O(M) per iteration, no facts re-scan).
5732        let schema = test_schema();
5733        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5734
5735        let batch1 = make_batch(&["a", "b"], &[1, 2]);
5736        let delta1 = rd.compute_delta(&[batch1], &schema).unwrap();
5737        // First call: both rows are new.
5738        let rows1: usize = delta1.iter().map(|b| b.num_rows()).sum();
5739        assert_eq!(rows1, 2);
5740
5741        // Second call with same rows: seen set already has them → empty delta.
5742        let batch2 = make_batch(&["a", "b"], &[1, 2]);
5743        let delta2 = rd.compute_delta(&[batch2], &schema).unwrap();
5744        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
5745        assert_eq!(rows2, 0);
5746
5747        // Third call with one old + one new: only the new row is returned.
5748        let batch3 = make_batch(&["a", "c"], &[1, 3]);
5749        let delta3 = rd.compute_delta(&[batch3], &schema).unwrap();
5750        let rows3: usize = delta3.iter().map(|b| b.num_rows()).sum();
5751        assert_eq!(rows3, 1);
5752    }
5753
5754    #[test]
5755    fn test_row_dedup_null_handling() {
5756        use arrow_array::StringArray;
5757        use arrow_schema::{DataType, Field, Schema};
5758
5759        let schema: SchemaRef = Arc::new(Schema::new(vec![
5760            Field::new("a", DataType::Utf8, true),
5761            Field::new("b", DataType::Int64, true),
5762        ]));
5763        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5764
5765        // Two rows: (NULL, 1) and (NULL, 1) — same NULLs → duplicate.
5766        let batch_nulls = RecordBatch::try_new(
5767            Arc::clone(&schema),
5768            vec![
5769                Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])),
5770                Arc::new(arrow_array::Int64Array::from(vec![1i64, 1i64])),
5771            ],
5772        )
5773        .unwrap();
5774        let delta = rd.compute_delta(&[batch_nulls], &schema).unwrap();
5775        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
5776        assert_eq!(rows, 1, "two identical NULL rows should be deduped to one");
5777
5778        // (NULL, 2) — NULL in same col but different non-null col → distinct.
5779        let batch_diff = RecordBatch::try_new(
5780            Arc::clone(&schema),
5781            vec![
5782                Arc::new(StringArray::from(vec![None::<&str>])),
5783                Arc::new(arrow_array::Int64Array::from(vec![2i64])),
5784            ],
5785        )
5786        .unwrap();
5787        let delta2 = rd.compute_delta(&[batch_diff], &schema).unwrap();
5788        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
5789        assert_eq!(rows2, 1, "(NULL, 2) is distinct from (NULL, 1)");
5790    }
5791
5792    #[test]
5793    fn test_row_dedup_within_candidate_dedup() {
5794        // Duplicate rows within a single candidate batch should be collapsed to one.
5795        let schema = test_schema();
5796        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5797
5798        // Batch with three rows: a:1, a:1, b:2 — "a:1" appears twice.
5799        let batch = make_batch(&["a", "a", "b"], &[1, 1, 2]);
5800        let delta = rd.compute_delta(&[batch], &schema).unwrap();
5801        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
5802        assert_eq!(rows, 2, "within-batch dup should be collapsed: a:1, b:2");
5803    }
5804
5805    // --- Float rounding tests ---
5806
5807    #[test]
5808    fn test_round_float_columns_near_duplicates() {
5809        let schema = Arc::new(Schema::new(vec![
5810            Field::new("name", DataType::Utf8, true),
5811            Field::new("dist", DataType::Float64, true),
5812        ]));
5813        let batch = RecordBatch::try_new(
5814            Arc::clone(&schema),
5815            vec![
5816                Arc::new(StringArray::from(vec![Some("a"), Some("a")])),
5817                Arc::new(Float64Array::from(vec![1.0000000000001, 1.0000000000002])),
5818            ],
5819        )
5820        .unwrap();
5821
5822        let rounded = round_float_columns(&[batch]);
5823        assert_eq!(rounded.len(), 1);
5824        let col = rounded[0]
5825            .column(1)
5826            .as_any()
5827            .downcast_ref::<Float64Array>()
5828            .unwrap();
5829        // Both should round to same value
5830        assert_eq!(col.value(0), col.value(1));
5831    }
5832
5833    // --- DerivedScanRegistry tests ---
5834
5835    #[test]
5836    fn test_registry_write_read_round_trip() {
5837        let schema = test_schema();
5838        let data = Arc::new(RwLock::new(Vec::new()));
5839        let mut reg = DerivedScanRegistry::new();
5840        reg.add(DerivedScanEntry {
5841            scan_index: 0,
5842            rule_name: "reachable".into(),
5843            is_self_ref: true,
5844            data: Arc::clone(&data),
5845            schema: Arc::clone(&schema),
5846        });
5847
5848        let batch = make_batch(&["x"], &[42]);
5849        reg.write_data(0, vec![batch.clone()]);
5850
5851        let entry = reg.get(0).unwrap();
5852        let guard = entry.data.read();
5853        assert_eq!(guard.len(), 1);
5854        assert_eq!(guard[0].num_rows(), 1);
5855    }
5856
5857    #[test]
5858    fn test_registry_entries_for_rule() {
5859        let schema = test_schema();
5860        let mut reg = DerivedScanRegistry::new();
5861        reg.add(DerivedScanEntry {
5862            scan_index: 0,
5863            rule_name: "r1".into(),
5864            is_self_ref: true,
5865            data: Arc::new(RwLock::new(Vec::new())),
5866            schema: Arc::clone(&schema),
5867        });
5868        reg.add(DerivedScanEntry {
5869            scan_index: 1,
5870            rule_name: "r2".into(),
5871            is_self_ref: false,
5872            data: Arc::new(RwLock::new(Vec::new())),
5873            schema: Arc::clone(&schema),
5874        });
5875        reg.add(DerivedScanEntry {
5876            scan_index: 2,
5877            rule_name: "r1".into(),
5878            is_self_ref: false,
5879            data: Arc::new(RwLock::new(Vec::new())),
5880            schema: Arc::clone(&schema),
5881        });
5882
5883        assert_eq!(reg.entries_for_rule("r1").len(), 2);
5884        assert_eq!(reg.entries_for_rule("r2").len(), 1);
5885        assert_eq!(reg.entries_for_rule("r3").len(), 0);
5886    }
5887
5888    // --- MonotonicAggState tests ---
5889
5890    #[test]
5891    fn test_monotonic_agg_update_and_stability() {
5892        let bindings = vec![MonotonicFoldBinding {
5893            fold_name: "total".into(),
5894            aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::SumAgg),
5895            input_col_index: 1,
5896            input_col_name: None,
5897        }];
5898        let mut agg = MonotonicAggState::new(bindings);
5899
5900        // First update
5901        let batch = make_batch(&["a"], &[10]);
5902        agg.snapshot();
5903        let changed = agg
5904            .update(&[0], &[batch], false, SemiringKind::AddMultProb)
5905            .unwrap();
5906        assert!(changed);
5907        assert!(!agg.is_stable()); // changed since snapshot
5908
5909        // Snapshot and check stability with no new data
5910        agg.snapshot();
5911        let changed = agg
5912            .update(&[0], &[], false, SemiringKind::AddMultProb)
5913            .unwrap();
5914        assert!(!changed);
5915        assert!(agg.is_stable());
5916    }
5917
5918    // --- Memory limit test ---
5919
5920    #[tokio::test]
5921    async fn test_memory_limit_exceeded() {
5922        let schema = test_schema();
5923        // Set a tiny limit
5924        let mut state = FixpointState::new("test".into(), schema, vec![0], 1, None, false);
5925
5926        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
5927        let result = state.merge_delta(vec![batch], None).await;
5928        assert!(result.is_err());
5929        let err = result.unwrap_err().to_string();
5930        assert!(err.contains("memory limit"), "Error was: {}", err);
5931    }
5932
5933    // --- FixpointStream lifecycle test ---
5934
5935    #[tokio::test]
5936    async fn test_fixpoint_stream_emitting() {
5937        use futures::StreamExt;
5938
5939        let schema = test_schema();
5940        let batch1 = make_batch(&["a"], &[1]);
5941        let batch2 = make_batch(&["b"], &[2]);
5942
5943        let metrics = ExecutionPlanMetricsSet::new();
5944        let baseline = BaselineMetrics::new(&metrics, 0);
5945
5946        let mut stream = FixpointStream {
5947            state: FixpointStreamState::Emitting(vec![batch1, batch2], 0),
5948            schema,
5949            metrics: baseline,
5950        };
5951
5952        let stream = Pin::new(&mut stream);
5953        let batches: Vec<RecordBatch> = stream.filter_map(|r| async { r.ok() }).collect().await;
5954
5955        assert_eq!(batches.len(), 2);
5956        assert_eq!(batches[0].num_rows(), 1);
5957        assert_eq!(batches[1].num_rows(), 1);
5958    }
5959
5960    // ── MonotonicAggState MNOR/MPROD tests ──────────────────────────────
5961
5962    fn make_f64_batch(names: &[&str], values: &[f64]) -> RecordBatch {
5963        let schema = Arc::new(Schema::new(vec![
5964            Field::new("name", DataType::Utf8, true),
5965            Field::new("value", DataType::Float64, true),
5966        ]));
5967        RecordBatch::try_new(
5968            schema,
5969            vec![
5970                Arc::new(StringArray::from(
5971                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
5972                )),
5973                Arc::new(Float64Array::from(values.to_vec())),
5974            ],
5975        )
5976        .unwrap()
5977    }
5978
5979    fn make_nor_binding() -> Vec<MonotonicFoldBinding> {
5980        vec![MonotonicFoldBinding {
5981            fold_name: "prob".into(),
5982            aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::MnorAgg),
5983            input_col_index: 1,
5984            input_col_name: None,
5985        }]
5986    }
5987
5988    fn make_prod_binding() -> Vec<MonotonicFoldBinding> {
5989        vec![MonotonicFoldBinding {
5990            fold_name: "prob".into(),
5991            aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::MprodAgg),
5992            input_col_index: 1,
5993            input_col_name: None,
5994        }]
5995    }
5996
5997    fn acc_key(name: &str) -> (Vec<ScalarKey>, String) {
5998        (vec![ScalarKey::Utf8(name.to_string())], "prob".to_string())
5999    }
6000
6001    #[test]
6002    fn test_monotonic_nor_first_update() {
6003        let mut agg = MonotonicAggState::new(make_nor_binding());
6004        let batch = make_f64_batch(&["a"], &[0.3]);
6005        let changed = agg
6006            .update(&[0], &[batch], false, SemiringKind::AddMultProb)
6007            .unwrap();
6008        assert!(changed);
6009        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6010        assert!((val - 0.3).abs() < 1e-10, "expected 0.3, got {}", val);
6011    }
6012
6013    #[test]
6014    fn test_monotonic_nor_two_updates() {
6015        // Incremental NOR: acc = 1-(1-0.3)(1-0.5) = 0.65
6016        let mut agg = MonotonicAggState::new(make_nor_binding());
6017        let batch1 = make_f64_batch(&["a"], &[0.3]);
6018        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6019            .unwrap();
6020        let batch2 = make_f64_batch(&["a"], &[0.5]);
6021        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6022            .unwrap();
6023        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6024        assert!((val - 0.65).abs() < 1e-10, "expected 0.65, got {}", val);
6025    }
6026
6027    #[test]
6028    fn test_monotonic_prod_first_update() {
6029        let mut agg = MonotonicAggState::new(make_prod_binding());
6030        let batch = make_f64_batch(&["a"], &[0.6]);
6031        let changed = agg
6032            .update(&[0], &[batch], false, SemiringKind::AddMultProb)
6033            .unwrap();
6034        assert!(changed);
6035        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6036        assert!((val - 0.6).abs() < 1e-10, "expected 0.6, got {}", val);
6037    }
6038
6039    #[test]
6040    fn test_monotonic_prod_two_updates() {
6041        // Incremental PROD: acc = 0.6 * 0.8 = 0.48
6042        let mut agg = MonotonicAggState::new(make_prod_binding());
6043        let batch1 = make_f64_batch(&["a"], &[0.6]);
6044        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6045            .unwrap();
6046        let batch2 = make_f64_batch(&["a"], &[0.8]);
6047        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6048            .unwrap();
6049        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6050        assert!((val - 0.48).abs() < 1e-10, "expected 0.48, got {}", val);
6051    }
6052
6053    #[test]
6054    fn test_monotonic_nor_stability() {
6055        let mut agg = MonotonicAggState::new(make_nor_binding());
6056        let batch = make_f64_batch(&["a"], &[0.3]);
6057        agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6058            .unwrap();
6059        agg.snapshot();
6060        let changed = agg
6061            .update(&[0], &[], false, SemiringKind::AddMultProb)
6062            .unwrap();
6063        assert!(!changed);
6064        assert!(agg.is_stable());
6065    }
6066
6067    #[test]
6068    fn test_monotonic_prod_stability() {
6069        let mut agg = MonotonicAggState::new(make_prod_binding());
6070        let batch = make_f64_batch(&["a"], &[0.6]);
6071        agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6072            .unwrap();
6073        agg.snapshot();
6074        let changed = agg
6075            .update(&[0], &[], false, SemiringKind::AddMultProb)
6076            .unwrap();
6077        assert!(!changed);
6078        assert!(agg.is_stable());
6079    }
6080
6081    #[test]
6082    fn test_monotonic_nor_multi_group() {
6083        // (a,0.3),(b,0.5) then (a,0.5),(b,0.2) → a=0.65, b=0.6
6084        let mut agg = MonotonicAggState::new(make_nor_binding());
6085        let batch1 = make_f64_batch(&["a", "b"], &[0.3, 0.5]);
6086        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6087            .unwrap();
6088        let batch2 = make_f64_batch(&["a", "b"], &[0.5, 0.2]);
6089        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6090            .unwrap();
6091
6092        let val_a = agg.get_accumulator(&acc_key("a")).unwrap();
6093        let val_b = agg.get_accumulator(&acc_key("b")).unwrap();
6094        assert!(
6095            (val_a - 0.65).abs() < 1e-10,
6096            "expected a=0.65, got {}",
6097            val_a
6098        );
6099        assert!((val_b - 0.6).abs() < 1e-10, "expected b=0.6, got {}", val_b);
6100    }
6101
6102    #[test]
6103    fn test_monotonic_prod_zero_absorbing() {
6104        // Zero absorbs: once 0.0, all further updates stay 0.0
6105        let mut agg = MonotonicAggState::new(make_prod_binding());
6106        let batch1 = make_f64_batch(&["a"], &[0.5]);
6107        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6108            .unwrap();
6109        let batch2 = make_f64_batch(&["a"], &[0.0]);
6110        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6111            .unwrap();
6112
6113        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6114        assert!((val - 0.0).abs() < 1e-10, "expected 0.0, got {}", val);
6115
6116        // Further updates don't change the absorbing zero
6117        agg.snapshot();
6118        let batch3 = make_f64_batch(&["a"], &[0.5]);
6119        let changed = agg
6120            .update(&[0], &[batch3], false, SemiringKind::AddMultProb)
6121            .unwrap();
6122        assert!(!changed);
6123        assert!(agg.is_stable());
6124    }
6125
6126    #[test]
6127    fn test_monotonic_nor_clamping() {
6128        // 1.5 clamped to 1.0: acc = 1-(1-0)(1-1) = 1.0
6129        let mut agg = MonotonicAggState::new(make_nor_binding());
6130        let batch = make_f64_batch(&["a"], &[1.5]);
6131        agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6132            .unwrap();
6133        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6134        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
6135    }
6136
6137    #[test]
6138    fn test_monotonic_nor_absorbing() {
6139        // p=1.0 absorbs: 0.3 then 1.0 → 1.0
6140        let mut agg = MonotonicAggState::new(make_nor_binding());
6141        let batch1 = make_f64_batch(&["a"], &[0.3]);
6142        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6143            .unwrap();
6144        let batch2 = make_f64_batch(&["a"], &[1.0]);
6145        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6146            .unwrap();
6147        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6148        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
6149    }
6150
6151    // ── MonotonicAggState strict mode tests (Phase 5) ─────────────────────
6152
6153    #[test]
6154    fn test_monotonic_agg_strict_nor_rejects() {
6155        let mut agg = MonotonicAggState::new(make_nor_binding());
6156        let batch = make_f64_batch(&["a"], &[1.5]);
6157        let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6158        assert!(result.is_err());
6159        let err = result.unwrap_err().to_string();
6160        assert!(
6161            err.contains("strict_probability_domain"),
6162            "Expected strict error, got: {}",
6163            err
6164        );
6165    }
6166
6167    #[test]
6168    fn test_monotonic_agg_strict_prod_rejects() {
6169        let mut agg = MonotonicAggState::new(make_prod_binding());
6170        let batch = make_f64_batch(&["a"], &[2.0]);
6171        let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6172        assert!(result.is_err());
6173        let err = result.unwrap_err().to_string();
6174        assert!(
6175            err.contains("strict_probability_domain"),
6176            "Expected strict error, got: {}",
6177            err
6178        );
6179    }
6180
6181    #[test]
6182    fn test_monotonic_agg_strict_accepts_valid() {
6183        let mut agg = MonotonicAggState::new(make_nor_binding());
6184        let batch = make_f64_batch(&["a"], &[0.5]);
6185        let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6186        assert!(result.is_ok());
6187        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6188        assert!((val - 0.5).abs() < 1e-10, "expected 0.5, got {}", val);
6189    }
6190
6191    // ── Complement function unit tests (Phase 4) ──────────────────────────
6192
6193    fn make_vid_prob_batch(vids: &[u64], probs: &[f64]) -> RecordBatch {
6194        use arrow_array::UInt64Array;
6195        let schema = Arc::new(Schema::new(vec![
6196            Field::new("vid", DataType::UInt64, true),
6197            Field::new("prob", DataType::Float64, true),
6198        ]));
6199        RecordBatch::try_new(
6200            schema,
6201            vec![
6202                Arc::new(UInt64Array::from(vids.to_vec())),
6203                Arc::new(Float64Array::from(probs.to_vec())),
6204            ],
6205        )
6206        .unwrap()
6207    }
6208
6209    #[test]
6210    fn test_prob_complement_basic() {
6211        // neg has VID=1 with prob=0.7 → complement=0.3; VID=2 absent → complement=1.0
6212        let body = make_vid_prob_batch(&[1, 2], &[0.9, 0.8]);
6213        let neg = make_vid_prob_batch(&[1], &[0.7]);
6214        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6215        let result = apply_prob_complement_composite(
6216            vec![body],
6217            &[neg],
6218            &join_cols,
6219            "prob",
6220            "__complement_0",
6221        )
6222        .unwrap();
6223        assert_eq!(result.len(), 1);
6224        let batch = &result[0];
6225        let complement = batch
6226            .column_by_name("__complement_0")
6227            .unwrap()
6228            .as_any()
6229            .downcast_ref::<Float64Array>()
6230            .unwrap();
6231        // VID=1: complement = 1 - 0.7 = 0.3
6232        assert!(
6233            (complement.value(0) - 0.3).abs() < 1e-10,
6234            "expected 0.3, got {}",
6235            complement.value(0)
6236        );
6237        // VID=2: absent from neg → complement = 1.0
6238        assert!(
6239            (complement.value(1) - 1.0).abs() < 1e-10,
6240            "expected 1.0, got {}",
6241            complement.value(1)
6242        );
6243    }
6244
6245    #[test]
6246    fn test_prob_complement_noisy_or_duplicates() {
6247        // neg has VID=1 twice with prob=0.3 and prob=0.5
6248        // Combined via noisy-OR: 1-(1-0.3)(1-0.5) = 0.65
6249        // Complement = 1 - 0.65 = 0.35
6250        let body = make_vid_prob_batch(&[1], &[0.9]);
6251        let neg = make_vid_prob_batch(&[1, 1], &[0.3, 0.5]);
6252        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6253        let result = apply_prob_complement_composite(
6254            vec![body],
6255            &[neg],
6256            &join_cols,
6257            "prob",
6258            "__complement_0",
6259        )
6260        .unwrap();
6261        let batch = &result[0];
6262        let complement = batch
6263            .column_by_name("__complement_0")
6264            .unwrap()
6265            .as_any()
6266            .downcast_ref::<Float64Array>()
6267            .unwrap();
6268        assert!(
6269            (complement.value(0) - 0.35).abs() < 1e-10,
6270            "expected 0.35, got {}",
6271            complement.value(0)
6272        );
6273    }
6274
6275    #[test]
6276    fn test_prob_complement_empty_neg() {
6277        // Empty neg_facts → body passes through with complement=1.0
6278        let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
6279        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6280        let result =
6281            apply_prob_complement_composite(vec![body], &[], &join_cols, "prob", "__complement_0")
6282                .unwrap();
6283        let batch = &result[0];
6284        let complement = batch
6285            .column_by_name("__complement_0")
6286            .unwrap()
6287            .as_any()
6288            .downcast_ref::<Float64Array>()
6289            .unwrap();
6290        for i in 0..2 {
6291            assert!(
6292                (complement.value(i) - 1.0).abs() < 1e-10,
6293                "row {}: expected 1.0, got {}",
6294                i,
6295                complement.value(i)
6296            );
6297        }
6298    }
6299
6300    #[test]
6301    fn test_anti_join_basic() {
6302        // body [1,2,3], neg [2] → result [1,3]
6303        use arrow_array::UInt64Array;
6304        let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
6305        let neg = make_vid_prob_batch(&[2], &[0.0]);
6306        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6307        let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
6308        assert_eq!(result.len(), 1);
6309        let batch = &result[0];
6310        assert_eq!(batch.num_rows(), 2);
6311        let vids = batch
6312            .column_by_name("vid")
6313            .unwrap()
6314            .as_any()
6315            .downcast_ref::<UInt64Array>()
6316            .unwrap();
6317        assert_eq!(vids.value(0), 1);
6318        assert_eq!(vids.value(1), 3);
6319    }
6320
6321    #[test]
6322    fn test_anti_join_empty_neg() {
6323        // Empty neg → all rows kept
6324        let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
6325        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6326        let result = apply_anti_join_composite(vec![body], &[], &join_cols).unwrap();
6327        assert_eq!(result.len(), 1);
6328        assert_eq!(result[0].num_rows(), 3);
6329    }
6330
6331    #[test]
6332    fn test_anti_join_all_excluded() {
6333        // neg covers all body rows → empty result
6334        let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
6335        let neg = make_vid_prob_batch(&[1, 2], &[0.0, 0.0]);
6336        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6337        let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
6338        let total: usize = result.iter().map(|b| b.num_rows()).sum();
6339        assert_eq!(total, 0);
6340    }
6341
6342    #[test]
6343    fn test_multiply_prob_single_complement() {
6344        // prob=0.8, complement=0.5 → output prob=0.4; complement col removed
6345        let body = make_vid_prob_batch(&[1], &[0.8]);
6346        // Add a complement column
6347        let complement_arr = Float64Array::from(vec![0.5]);
6348        let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
6349        cols.push(Arc::new(complement_arr));
6350        let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
6351        fields.push(Arc::new(Field::new(
6352            "__complement_0",
6353            DataType::Float64,
6354            true,
6355        )));
6356        let schema = Arc::new(Schema::new(fields));
6357        let batch = RecordBatch::try_new(schema, cols).unwrap();
6358
6359        let result =
6360            multiply_prob_factors(vec![batch], Some("prob"), &["__complement_0".to_string()])
6361                .unwrap();
6362        assert_eq!(result.len(), 1);
6363        let out = &result[0];
6364        // Complement column should be removed
6365        assert!(out.column_by_name("__complement_0").is_none());
6366        let prob = out
6367            .column_by_name("prob")
6368            .unwrap()
6369            .as_any()
6370            .downcast_ref::<Float64Array>()
6371            .unwrap();
6372        assert!(
6373            (prob.value(0) - 0.4).abs() < 1e-10,
6374            "expected 0.4, got {}",
6375            prob.value(0)
6376        );
6377    }
6378
6379    #[test]
6380    fn test_multiply_prob_multiple_complements() {
6381        // prob=0.8, c1=0.5, c2=0.6 → 0.8×0.5×0.6=0.24
6382        let body = make_vid_prob_batch(&[1], &[0.8]);
6383        let c1 = Float64Array::from(vec![0.5]);
6384        let c2 = Float64Array::from(vec![0.6]);
6385        let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
6386        cols.push(Arc::new(c1));
6387        cols.push(Arc::new(c2));
6388        let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
6389        fields.push(Arc::new(Field::new("__c1", DataType::Float64, true)));
6390        fields.push(Arc::new(Field::new("__c2", DataType::Float64, true)));
6391        let schema = Arc::new(Schema::new(fields));
6392        let batch = RecordBatch::try_new(schema, cols).unwrap();
6393
6394        let result = multiply_prob_factors(
6395            vec![batch],
6396            Some("prob"),
6397            &["__c1".to_string(), "__c2".to_string()],
6398        )
6399        .unwrap();
6400        let out = &result[0];
6401        assert!(out.column_by_name("__c1").is_none());
6402        assert!(out.column_by_name("__c2").is_none());
6403        let prob = out
6404            .column_by_name("prob")
6405            .unwrap()
6406            .as_any()
6407            .downcast_ref::<Float64Array>()
6408            .unwrap();
6409        assert!(
6410            (prob.value(0) - 0.24).abs() < 1e-10,
6411            "expected 0.24, got {}",
6412            prob.value(0)
6413        );
6414    }
6415
6416    #[test]
6417    fn test_multiply_prob_no_prob_column() {
6418        // No prob column → combined complements become the output
6419        use arrow_array::UInt64Array;
6420        let schema = Arc::new(Schema::new(vec![
6421            Field::new("vid", DataType::UInt64, true),
6422            Field::new("__c1", DataType::Float64, true),
6423        ]));
6424        let batch = RecordBatch::try_new(
6425            schema,
6426            vec![
6427                Arc::new(UInt64Array::from(vec![1u64])),
6428                Arc::new(Float64Array::from(vec![0.7])),
6429            ],
6430        )
6431        .unwrap();
6432
6433        let result = multiply_prob_factors(vec![batch], None, &["__c1".to_string()]).unwrap();
6434        let out = &result[0];
6435        // __c1 should be removed since it's a complement column
6436        assert!(out.column_by_name("__c1").is_none());
6437        // Only vid column remains
6438        assert_eq!(out.num_columns(), 1);
6439    }
6440}