Skip to main content

uni_query/query/df_graph/
locy_fixpoint.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Fixpoint iteration operator for recursive Locy strata.
5//!
6//! `FixpointExec` drives semi-naive evaluation: it repeatedly evaluates the rules
7//! in a recursive stratum, feeding back deltas until no new facts are produced.
8
9use crate::query::df_graph::GraphExecutionContext;
10use crate::query::df_graph::common::{
11    ScalarKey, arrow_err, collect_all_partitions, compute_plan_properties, execute_subplan,
12    execute_subplan_collecting, extract_scalar_key,
13};
14use crate::query::df_graph::locy_best_by::{BestByExec, SortCriterion};
15use crate::query::df_graph::locy_errors::LocyRuntimeError;
16use crate::query::df_graph::locy_explain::{
17    ProofTerm, ProvenanceAnnotation, ProvenanceStore, compute_proof_probability,
18};
19use crate::query::df_graph::locy_fold::{FoldBinding, FoldExec};
20use crate::query::df_graph::locy_priority::PriorityExec;
21use crate::query::df_graph::locy_profile::LocyProfileCollector;
22use crate::query::df_graph::locy_program::interruption;
23use crate::query::executor::core::OperatorStats;
24use crate::query::planner::LogicalPlan;
25use arrow_array::RecordBatch;
26use arrow_row::{RowConverter, SortField};
27use arrow_schema::SchemaRef;
28use datafusion::common::JoinType;
29use datafusion::common::Result as DFResult;
30use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
31use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
32use datafusion::physical_plan::memory::MemoryStream;
33use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
34use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
35use futures::Stream;
36use parking_lot::RwLock;
37use std::any::Any;
38use std::collections::{HashMap, HashSet};
39use std::fmt;
40use std::pin::Pin;
41use std::sync::{Arc, RwLock as StdRwLock};
42use std::task::{Context, Poll};
43use std::time::{Duration, Instant};
44use uni_common::Value;
45use uni_common::core::schema::Schema as UniSchema;
46use uni_cypher::ast::Expr;
47use uni_locy::{
48    ClassifierRegistry, ModelInvocation, ModelInvocationCache, RuntimeWarning, RuntimeWarningCode,
49    SemiringKind,
50};
51use uni_store::storage::manager::StorageManager;
52
53// ---------------------------------------------------------------------------
54// DerivedScanRegistry — injection point for IS-ref data into subplans
55// ---------------------------------------------------------------------------
56
/// A single entry in the derived scan registry.
///
/// Each entry corresponds to one `LocyDerivedScan` node in the logical plan tree.
/// The `data` handle is shared with the logical plan node so that writing data here
/// makes it visible when the subplan is re-planned and executed.
///
/// Entries are held by [`DerivedScanRegistry`], which finds them either by
/// `scan_index` or by `rule_name`.
#[derive(Debug)]
pub struct DerivedScanEntry {
    /// Index matching the `scan_index` in `LocyDerivedScan`.
    pub scan_index: usize,
    /// Name of the rule this scan reads from.
    pub rule_name: String,
    /// Whether this is a self-referential scan (rule references itself).
    pub is_self_ref: bool,
    /// Shared data handle — write batches here to inject into subplans.
    ///
    /// [`DerivedScanRegistry::write_data`] replaces the whole vector under the
    /// write lock rather than appending to it.
    pub data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Schema of the derived relation.
    pub schema: SchemaRef,
}
75
/// Registry of derived scan handles for fixpoint iteration.
///
/// During fixpoint, each clause body may reference derived relations via
/// `LocyDerivedScan` nodes. The registry maps scan indices to shared data
/// handles so the fixpoint loop can inject delta/full facts before each
/// iteration.
#[derive(Debug, Default)]
pub struct DerivedScanRegistry {
    /// Registered entries in insertion order. Lookups by scan index or rule
    /// name walk this list linearly.
    entries: Vec<DerivedScanEntry>,
}
86
87impl DerivedScanRegistry {
88    /// Create a new empty registry.
89    pub fn new() -> Self {
90        Self::default()
91    }
92
93    /// Add an entry to the registry.
94    pub fn add(&mut self, entry: DerivedScanEntry) {
95        self.entries.push(entry);
96    }
97
98    /// Get an entry by scan index.
99    pub fn get(&self, scan_index: usize) -> Option<&DerivedScanEntry> {
100        self.entries.iter().find(|e| e.scan_index == scan_index)
101    }
102
103    /// Write data into a scan entry's shared handle.
104    pub fn write_data(&self, scan_index: usize, batches: Vec<RecordBatch>) {
105        if let Some(entry) = self.get(scan_index) {
106            let mut guard = entry.data.write();
107            *guard = batches;
108        }
109    }
110
111    /// Get all entries for a given rule name.
112    pub fn entries_for_rule(&self, rule_name: &str) -> Vec<&DerivedScanEntry> {
113        self.entries
114            .iter()
115            .filter(|e| e.rule_name == rule_name)
116            .collect()
117    }
118}
119
120// ---------------------------------------------------------------------------
121// MonotonicAggState — tracking monotonic aggregates across iterations
122// ---------------------------------------------------------------------------
123
/// Monotonic aggregate binding: maps a fold name to its aggregate
/// trait object and input column.
///
/// Dispatches purely through [`uni_plugin::traits::locy::LocyAggregate`]
/// (`update_step` / `initial_accum_f64` / `is_probability_aggregate` /
/// `is_noisy_or`).
#[derive(Debug, Clone)]
pub struct MonotonicFoldBinding {
    /// Name of the fold; paired with the group key to identify an accumulator
    /// in [`MonotonicAggState`].
    pub fold_name: String,
    /// Aggregate implementation that seeds and advances the accumulator.
    pub aggregate: std::sync::Arc<dyn uni_plugin::traits::locy::LocyAggregate>,
    /// Positional index of the input column. Used when `input_col_name` is
    /// `None` or is not present in the batch schema.
    pub input_col_index: usize,
    /// Column name for name-based resolution (more robust than positional index).
    pub input_col_name: Option<String>,
}
138
/// Tracks monotonic aggregate accumulators across fixpoint iterations.
///
/// After each iteration, accumulators are updated and compared to their previous
/// snapshot. The fixpoint has converged (w.r.t. aggregates) when all accumulators
/// are stable (no change between iterations).
///
/// Stability is judged with an absolute tolerance of `f64::EPSILON` per
/// accumulator, and additionally requires that the set of accumulator keys
/// has the same size as in the snapshot.
#[derive(Debug)]
pub struct MonotonicAggState {
    /// Current accumulator values keyed by (group_key, fold_name).
    accumulators: HashMap<(Vec<ScalarKey>, String), f64>,
    /// Snapshot from the previous iteration for stability check.
    prev_snapshot: HashMap<(Vec<ScalarKey>, String), f64>,
    /// Bindings describing which aggregates to track.
    bindings: Vec<MonotonicFoldBinding>,
}
153
154impl MonotonicAggState {
155    /// Create a new monotonic aggregate state.
156    pub fn new(bindings: Vec<MonotonicFoldBinding>) -> Self {
157        Self {
158            accumulators: HashMap::new(),
159            prev_snapshot: HashMap::new(),
160            bindings,
161        }
162    }
163
164    /// Update accumulators with new delta batches.
165    ///
166    /// Returns `true` if any accumulator value changed. When `strict` is
167    /// `true`, MNOR/MPROD inputs outside `[0, 1]` produce an error
168    /// instead of being clamped.
169    ///
170    /// `semiring_kind` selects the probability semiring for probability
171    /// aggregates: `AddMultProb` (default, Phase 1/2 noisy-OR/product) or
172    /// `MaxMinProb` (Viterbi/fuzzy — opt-in, callers emit
173    /// `FuzzyNotProbabilistic`).
174    ///
175    /// Dispatch goes through each binding's `Arc<dyn LocyAggregate>` trait
176    /// object via [`uni_plugin::traits::locy::LocyAggregate::update_step`].
177    /// The trait object's `initial_accum_f64()` seeds the per-group
178    /// accumulator. Under `MaxMinProb`, probability aggregates (MNOR /
179    /// MPROD) bypass `update_step` and fold via the `MaxMinProb` semiring's
180    /// `plus` (max) / `times` (min) instead — preserving the opt-in
181    /// Viterbi/fuzzy semantics that `update_step`'s built-in noisy-OR /
182    /// product path does not implement.
183    ///
184    /// Aggregates whose `update_step` returns `Err(CODE_UNKNOWN_FUNCTION)`
185    /// (default impl — no row-level fast path; e.g., `AVG`, `COLLECT`)
186    /// are skipped silently here — those run through the batch-shape
187    /// [`uni_plugin::traits::locy::LocyAggState::ingest`] path in
188    /// `apply_post_fixpoint_chain` instead.
189    pub fn update(
190        &mut self,
191        key_indices: &[usize],
192        delta_batches: &[RecordBatch],
193        strict: bool,
194        semiring_kind: SemiringKind,
195    ) -> DFResult<bool> {
196        let mut changed = false;
197        for batch in delta_batches {
198            for row_idx in 0..batch.num_rows() {
199                let group_key = extract_scalar_key(batch, key_indices, row_idx);
200                for binding in &self.bindings {
201                    let idx = binding
202                        .input_col_name
203                        .as_ref()
204                        .and_then(|name| batch.schema().index_of(name).ok())
205                        .unwrap_or(binding.input_col_index);
206                    if idx >= batch.num_columns() {
207                        continue;
208                    }
209                    let col = batch.column(idx);
210                    let val = extract_f64(col.as_ref(), row_idx);
211                    if let Some(val) = val {
212                        let map_key = (group_key.clone(), binding.fold_name.clone());
213                        let initial = binding.aggregate.initial_accum_f64().unwrap_or(0.0);
214                        let entry = self.accumulators.entry(map_key).or_insert(initial);
215                        let old = *entry;
216                        // Under `MaxMinProb`, probability aggregates (MNOR /
217                        // MPROD) fold via the Viterbi/fuzzy semiring (max /
218                        // min) rather than the trait object's built-in
219                        // noisy-OR / product `update_step`. The inline domain
220                        // checks below preserve the exact strict-mode error
221                        // and clamp-warning literals. `is_noisy_or()`
222                        // distinguishes MNOR (disjunction → max) from MPROD
223                        // (conjunction → min). All other aggregates — and the
224                        // default `AddMultProb` semiring — dispatch through
225                        // the trait object's `update_step`.
226                        if matches!(semiring_kind, SemiringKind::MaxMinProb)
227                            && binding.aggregate.is_probability_aggregate()
228                        {
229                            use uni_locy::LocySemiring;
230                            let sr = uni_locy::MaxMinProb;
231                            let is_nor = binding.aggregate.is_noisy_or();
232                            let label = if is_nor { "MNOR" } else { "MPROD" };
233                            if strict && !(0.0..=1.0).contains(&val) {
234                                return Err(datafusion::error::DataFusionError::Execution(
235                                    format!(
236                                        "strict_probability_domain: {label} input {val} is outside [0, 1]"
237                                    ),
238                                ));
239                            }
240                            if !strict && !(0.0..=1.0).contains(&val) {
241                                tracing::warn!(
242                                    "{label} input {val} outside [0,1], clamped to {}",
243                                    val.clamp(0.0, 1.0)
244                                );
245                            }
246                            let p = val.clamp(0.0, 1.0);
247                            // MaxMinProb: MNOR -> max (plus), MPROD -> min (times).
248                            *entry = if is_nor {
249                                sr.plus(entry, &p)
250                            } else {
251                                sr.times(entry, &p)
252                            };
253                            if (*entry - old).abs() > f64::EPSILON {
254                                changed = true;
255                            }
256                            continue;
257                        }
258                        match binding.aggregate.update_step(*entry, val, strict) {
259                            Ok(new_val) => {
260                                *entry = new_val;
261                                if (*entry - old).abs() > f64::EPSILON {
262                                    changed = true;
263                                }
264                            }
265                            Err(e) if e.code == uni_plugin::FnError::CODE_UNKNOWN_FUNCTION => {
266                                // Aggregate has no row-level fast path (AVG,
267                                // COLLECT). Those run through the
268                                // batch-shape `ingest` path elsewhere; skip.
269                            }
270                            Err(e) => {
271                                // Strict-mode probability-domain violation,
272                                // or another aggregate-specific failure.
273                                return Err(datafusion::error::DataFusionError::Execution(
274                                    e.message,
275                                ));
276                            }
277                        }
278                    }
279                }
280            }
281        }
282        Ok(changed)
283    }
284
285    /// Take a snapshot of current accumulators for stability comparison.
286    pub fn snapshot(&mut self) {
287        self.prev_snapshot = self.accumulators.clone();
288    }
289
290    /// Check if accumulators are stable (no change since last snapshot).
291    pub fn is_stable(&self) -> bool {
292        if self.accumulators.len() != self.prev_snapshot.len() {
293            return false;
294        }
295        for (key, val) in &self.accumulators {
296            match self.prev_snapshot.get(key) {
297                Some(prev) if (*val - *prev).abs() <= f64::EPSILON => {}
298                _ => return false,
299            }
300        }
301        true
302    }
303
304    /// Test-only accessor for accumulator values.
305    #[cfg(test)]
306    pub(crate) fn get_accumulator(&self, key: &(Vec<ScalarKey>, String)) -> Option<f64> {
307        self.accumulators.get(key).copied()
308    }
309}
310
311/// Extract f64 value from an Arrow column at a given row index.
312fn extract_f64(col: &dyn arrow_array::Array, row_idx: usize) -> Option<f64> {
313    if col.is_null(row_idx) {
314        return None;
315    }
316    if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Float64Array>() {
317        Some(arr.value(row_idx))
318    } else {
319        col.as_any()
320            .downcast_ref::<arrow_array::Int64Array>()
321            .map(|arr| arr.value(row_idx) as f64)
322    }
323}
324
325// ---------------------------------------------------------------------------
326// RowDedupState — Arrow RowConverter-based persistent dedup set
327// ---------------------------------------------------------------------------
328
/// Arrow-native row deduplication using [`RowConverter`].
///
/// Unlike the legacy `HashSet<Vec<ScalarKey>>` approach, this struct maintains a
/// persistent `seen` set across iterations so per-iteration cost is O(M) where M
/// is the number of candidate rows — the full facts table is never re-scanned.
struct RowDedupState {
    /// Encodes every column of a batch into one byte string per row.
    converter: RowConverter,
    /// Encoded bytes of every row accepted so far; membership means "already a fact".
    seen: HashSet<Box<[u8]>>,
}
338
339impl RowDedupState {
340    /// Try to build a `RowDedupState` for the given schema.
341    ///
342    /// Returns `None` if any column type is not supported by `RowConverter`
343    /// (triggers legacy fallback).
344    fn try_new(schema: &SchemaRef) -> Option<Self> {
345        let fields: Vec<SortField> = schema
346            .fields()
347            .iter()
348            .map(|f| SortField::new(f.data_type().clone()))
349            .collect();
350        match RowConverter::new(fields) {
351            Ok(converter) => Some(Self {
352                converter,
353                seen: HashSet::new(),
354            }),
355            Err(e) => {
356                tracing::warn!(
357                    "RowDedupState: RowConverter unsupported for schema, falling back to legacy dedup: {}",
358                    e
359                );
360                None
361            }
362        }
363    }
364
365    /// Populate the seen set from existing fact batches.
366    ///
367    /// Used after BEST BY in-loop pruning replaces the fact set, so that delta
368    /// computation in subsequent iterations correctly recognizes surviving facts.
369    fn ingest_existing(&mut self, facts: &[RecordBatch], _schema: &SchemaRef) {
370        self.seen.clear();
371        for batch in facts {
372            if batch.num_rows() == 0 {
373                continue;
374            }
375            let arrays: Vec<_> = batch.columns().to_vec();
376            if let Ok(rows) = self.converter.convert_columns(&arrays) {
377                for row_idx in 0..batch.num_rows() {
378                    let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
379                    self.seen.insert(row_bytes);
380                }
381            }
382        }
383    }
384
385    /// Filter `candidates` to only rows not yet seen, updating the persistent set.
386    ///
387    /// Both cross-iteration dedup (rows already accepted in prior iterations) and
388    /// within-batch dedup (duplicate rows in a single candidate batch) are handled
389    /// in a single pass.
390    fn compute_delta(
391        &mut self,
392        candidates: &[RecordBatch],
393        schema: &SchemaRef,
394    ) -> DFResult<Vec<RecordBatch>> {
395        let mut delta_batches = Vec::new();
396        for batch in candidates {
397            if batch.num_rows() == 0 {
398                continue;
399            }
400
401            // Vectorized encoding of all rows in this batch.
402            let arrays: Vec<_> = batch.columns().to_vec();
403            let rows = self.converter.convert_columns(&arrays).map_err(arrow_err)?;
404
405            // One pass: check+insert into persistent seen set.
406            let mut keep = Vec::with_capacity(batch.num_rows());
407            for row_idx in 0..batch.num_rows() {
408                let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
409                keep.push(self.seen.insert(row_bytes));
410            }
411
412            let keep_mask = arrow_array::BooleanArray::from(keep);
413            let new_cols = batch
414                .columns()
415                .iter()
416                .map(|col| {
417                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
418                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
419                    })
420                })
421                .collect::<DFResult<Vec<_>>>()?;
422
423            if new_cols.first().is_some_and(|c| !c.is_empty()) {
424                let filtered = RecordBatch::try_new(Arc::clone(schema), new_cols).map_err(|e| {
425                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
426                })?;
427                delta_batches.push(filtered);
428            }
429        }
430        Ok(delta_batches)
431    }
432}
433
434// ---------------------------------------------------------------------------
435// FixpointState — per-rule delta tracking during fixpoint iteration
436// ---------------------------------------------------------------------------
437
/// Per-rule state for fixpoint iteration.
///
/// Tracks accumulated facts and the delta (new facts from the latest iteration).
/// Deduplication uses Arrow [`RowConverter`] with a persistent seen set (O(M) per
/// iteration) when supported, with a legacy `HashSet<Vec<ScalarKey>>` fallback.
pub struct FixpointState {
    /// Rule this state belongs to; appears in log output and in the
    /// memory-limit error.
    rule_name: String,
    /// All facts accepted so far, as a list of batches.
    facts: Vec<RecordBatch>,
    /// Facts accepted by the most recent merge; emptied when a merge adds nothing.
    delta: Vec<RecordBatch>,
    /// Schema used to rebuild filtered batches. Starts as the planner's
    /// inferred schema and is replaced by `reconcile_schema` when the first
    /// real batch disagrees with it.
    schema: SchemaRef,
    /// Positions of the KEY columns in `schema`; passed to the monotonic
    /// aggregate tracker as the group key.
    key_column_indices: Vec<usize>,
    /// KEY column names for recomputing indices after schema reconciliation.
    key_column_names: Vec<String>,
    /// All column indices for full-row dedup (legacy path only).
    all_column_indices: Vec<usize>,
    /// Running total of facts bytes for memory limit tracking.
    facts_bytes: usize,
    /// Maximum bytes allowed for this derived relation.
    max_derived_bytes: usize,
    /// Optional monotonic aggregate tracking.
    monotonic_agg: Option<MonotonicAggState>,
    /// Arrow RowConverter-based dedup state; `None` triggers legacy fallback.
    row_dedup: Option<RowDedupState>,
    /// Whether strict probability domain checks are enabled.
    strict_probability_domain: bool,
    /// Active probability semiring for this rule's MNOR/MPROD math.
    semiring_kind: SemiringKind,
}
466
467impl FixpointState {
468    /// Create a new fixpoint state for a rule. Existing tests call this
469    /// with the Phase 1/2 default; the fixpoint planner uses
470    /// [`FixpointState::new_with_semiring`] to thread the configured
471    /// semiring through.
472    pub fn new(
473        rule_name: String,
474        schema: SchemaRef,
475        key_column_indices: Vec<usize>,
476        max_derived_bytes: usize,
477        monotonic_agg: Option<MonotonicAggState>,
478        strict_probability_domain: bool,
479    ) -> Self {
480        Self::new_with_semiring(
481            rule_name,
482            schema,
483            key_column_indices,
484            max_derived_bytes,
485            monotonic_agg,
486            strict_probability_domain,
487            SemiringKind::AddMultProb,
488        )
489    }
490
491    pub fn new_with_semiring(
492        rule_name: String,
493        schema: SchemaRef,
494        key_column_indices: Vec<usize>,
495        max_derived_bytes: usize,
496        monotonic_agg: Option<MonotonicAggState>,
497        strict_probability_domain: bool,
498        semiring_kind: SemiringKind,
499    ) -> Self {
500        let num_cols = schema.fields().len();
501        let row_dedup = RowDedupState::try_new(&schema);
502        let key_column_names: Vec<String> = key_column_indices
503            .iter()
504            .filter_map(|&i| schema.fields().get(i).map(|f| f.name().clone()))
505            .collect();
506        Self {
507            rule_name,
508            facts: Vec::new(),
509            delta: Vec::new(),
510            schema,
511            key_column_indices,
512            key_column_names,
513            all_column_indices: (0..num_cols).collect(),
514            facts_bytes: 0,
515            max_derived_bytes,
516            monotonic_agg,
517            row_dedup,
518            strict_probability_domain,
519            semiring_kind,
520        }
521    }
522
523    /// Reconcile the pre-computed schema with the actual physical plan output.
524    ///
525    /// `infer_expr_type` may guess wrong (e.g. `Property → Float64` for a
526    /// string column).  When the first real batch arrives with a different
527    /// schema, update ours so that `RowDedupState` / `RecordBatch::try_new`
528    /// use the correct types.
529    fn reconcile_schema(&mut self, actual_schema: &SchemaRef) {
530        if self.schema.fields() != actual_schema.fields() {
531            tracing::debug!(
532                rule = %self.rule_name,
533                "Reconciling fixpoint schema from physical plan output",
534            );
535            self.schema = Arc::clone(actual_schema);
536            self.row_dedup = RowDedupState::try_new(&self.schema);
537            // Recompute key_column_indices from stored KEY column names.
538            // Without this, FoldExec groups by wrong columns when the
539            // physical plan reorders columns vs the pre-inferred schema.
540            let new_indices: Vec<usize> = self
541                .key_column_names
542                .iter()
543                .filter_map(|name| actual_schema.index_of(name).ok())
544                .collect();
545            if new_indices.len() == self.key_column_names.len() {
546                self.key_column_indices = new_indices;
547            }
548            // else: not all KEY columns found in new schema — keep original indices
549            let num_cols = actual_schema.fields().len();
550            self.all_column_indices = (0..num_cols).collect();
551        }
552    }
553
554    /// Merge candidate rows into facts, computing delta (truly new rows).
555    ///
556    /// Returns `true` if any new facts were added.
557    pub async fn merge_delta(
558        &mut self,
559        candidates: Vec<RecordBatch>,
560        task_ctx: Option<Arc<TaskContext>>,
561    ) -> DFResult<bool> {
562        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
563            self.delta.clear();
564            return Ok(false);
565        }
566
567        // Reconcile schema from the first non-empty candidate batch.
568        // The physical plan's output types are authoritative over the
569        // planner's inferred types.
570        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
571            self.reconcile_schema(&first.schema());
572        }
573
574        // Round floats for stable dedup
575        let candidates = round_float_columns(&candidates);
576
577        // Compute delta: rows in candidates not already in facts
578        let delta = self.compute_delta(&candidates, task_ctx.as_ref()).await?;
579
580        if delta.is_empty() || delta.iter().all(|b| b.num_rows() == 0) {
581            self.delta.clear();
582            // Update monotonic aggs even with empty delta (for stability check)
583            if let Some(ref mut agg) = self.monotonic_agg {
584                agg.snapshot();
585            }
586            return Ok(false);
587        }
588
589        // Check memory limit
590        let delta_bytes: usize = delta.iter().map(batch_byte_size).sum();
591        if self.facts_bytes + delta_bytes > self.max_derived_bytes {
592            return Err(datafusion::error::DataFusionError::Execution(
593                LocyRuntimeError::MemoryLimitExceeded {
594                    rule: self.rule_name.clone(),
595                    bytes: self.facts_bytes + delta_bytes,
596                    limit: self.max_derived_bytes,
597                }
598                .to_string(),
599            ));
600        }
601
602        // Update monotonic aggs
603        if let Some(ref mut agg) = self.monotonic_agg {
604            agg.snapshot();
605            agg.update(
606                &self.key_column_indices,
607                &delta,
608                self.strict_probability_domain,
609                self.semiring_kind,
610            )?;
611        }
612
613        // Append delta to facts
614        self.facts_bytes += delta_bytes;
615        self.facts.extend(delta.iter().cloned());
616        self.delta = delta;
617
618        Ok(true)
619    }
620
621    /// Dispatch to vectorized LeftAntiJoin, Arrow RowConverter dedup, or legacy ScalarKey dedup.
622    ///
623    /// Priority order:
624    /// 1. `arrow_left_anti_dedup` when `total_existing >= DEDUP_ANTI_JOIN_THRESHOLD` and task_ctx available.
625    /// 2. `RowDedupState` (persistent HashSet, O(M) per iteration) when schema is supported.
626    /// 3. `compute_delta_legacy` (rebuilds from facts, fallback for unsupported column types).
627    async fn compute_delta(
628        &mut self,
629        candidates: &[RecordBatch],
630        task_ctx: Option<&Arc<TaskContext>>,
631    ) -> DFResult<Vec<RecordBatch>> {
632        let total_existing: usize = self.facts.iter().map(|b| b.num_rows()).sum();
633        if total_existing >= DEDUP_ANTI_JOIN_THRESHOLD
634            && let Some(ctx) = task_ctx
635        {
636            return arrow_left_anti_dedup(candidates.to_vec(), &self.facts, &self.schema, ctx)
637                .await;
638        }
639        if let Some(ref mut rd) = self.row_dedup {
640            rd.compute_delta(candidates, &self.schema)
641        } else {
642            self.compute_delta_legacy(candidates)
643        }
644    }
645
646    /// Legacy dedup: rebuild a `HashSet<Vec<ScalarKey>>` from all facts each call.
647    ///
648    /// Used as fallback when `RowConverter` does not support the schema's column types.
649    fn compute_delta_legacy(&self, candidates: &[RecordBatch]) -> DFResult<Vec<RecordBatch>> {
650        // Build set of existing fact row keys (ALL columns)
651        let mut existing: HashSet<Vec<ScalarKey>> = HashSet::new();
652        for batch in &self.facts {
653            for row_idx in 0..batch.num_rows() {
654                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
655                existing.insert(key);
656            }
657        }
658
659        let mut delta_batches = Vec::new();
660        for batch in candidates {
661            if batch.num_rows() == 0 {
662                continue;
663            }
664            // Filter to only new rows
665            let mut keep = Vec::with_capacity(batch.num_rows());
666            for row_idx in 0..batch.num_rows() {
667                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
668                keep.push(!existing.contains(&key));
669            }
670
671            // Also dedup within the candidate batch itself
672            for (row_idx, kept) in keep.iter_mut().enumerate() {
673                if *kept {
674                    let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
675                    if !existing.insert(key) {
676                        *kept = false;
677                    }
678                }
679            }
680
681            let keep_mask = arrow_array::BooleanArray::from(keep);
682            let new_rows = batch
683                .columns()
684                .iter()
685                .map(|col| {
686                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
687                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
688                    })
689                })
690                .collect::<DFResult<Vec<_>>>()?;
691
692            if new_rows.first().is_some_and(|c| !c.is_empty()) {
693                let filtered =
694                    RecordBatch::try_new(Arc::clone(&self.schema), new_rows).map_err(|e| {
695                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
696                    })?;
697                delta_batches.push(filtered);
698            }
699        }
700
701        Ok(delta_batches)
702    }
703
704    /// Check if this rule has converged (no new facts and aggs stable).
705    pub fn is_converged(&self) -> bool {
706        let delta_empty = self.delta.is_empty() || self.delta.iter().all(|b| b.num_rows() == 0);
707        let agg_stable = self.monotonic_agg.as_ref().is_none_or(|a| a.is_stable());
708        delta_empty && agg_stable
709    }
710
711    /// Get all accumulated facts.
712    pub fn all_facts(&self) -> &[RecordBatch] {
713        &self.facts
714    }
715
716    /// Get the delta from the latest iteration.
717    pub fn all_delta(&self) -> &[RecordBatch] {
718        &self.delta
719    }
720
721    /// Consume self and return facts.
722    pub fn into_facts(self) -> Vec<RecordBatch> {
723        self.facts
724    }
725
726    /// Merge candidates using BEST BY semantics.
727    ///
728    /// Combines existing facts with new candidates, keeping only the best row
729    /// per KEY group according to `sort_criteria`. Returns `true` if the
730    /// best-per-KEY fact set actually changed (a genuinely better value was
731    /// found or a new KEY appeared).
732    ///
733    /// This replaces `merge_delta` for rules with BEST BY, enabling convergence
734    /// on cyclic graphs where dominated ALONG values would otherwise produce an
735    /// unbounded stream of "new" full-row facts.
736    pub fn merge_best_by(
737        &mut self,
738        candidates: Vec<RecordBatch>,
739        sort_criteria: &[SortCriterion],
740    ) -> DFResult<bool> {
741        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
742            self.delta.clear();
743            return Ok(false);
744        }
745
746        // Reconcile schema from the first non-empty candidate batch.
747        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
748            self.reconcile_schema(&first.schema());
749        }
750
751        // Round floats for stable dedup.
752        let candidates = round_float_columns(&candidates);
753
754        // Snapshot existing best-per-KEY facts for change detection.
755        let old_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> =
756            self.build_key_criteria_map(sort_criteria);
757
758        // Concat existing facts + new candidates.
759        let mut all_batches = self.facts.clone();
760        all_batches.extend(candidates);
761        let all_batches: Vec<_> = all_batches
762            .into_iter()
763            .filter(|b| b.num_rows() > 0)
764            .collect();
765        if all_batches.is_empty() {
766            self.delta.clear();
767            return Ok(false);
768        }
769
770        let combined = arrow::compute::concat_batches(&self.schema, &all_batches)
771            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
772
773        if combined.num_rows() == 0 {
774            self.delta.clear();
775            return Ok(false);
776        }
777
778        // Sort by KEY ASC then criteria, so the best row per KEY group comes
779        // first.
780        let mut sort_columns = Vec::new();
781        for &ki in &self.key_column_indices {
782            if ki >= combined.num_columns() {
783                continue;
784            }
785            sort_columns.push(arrow::compute::SortColumn {
786                values: Arc::clone(combined.column(ki)),
787                options: Some(arrow::compute::SortOptions {
788                    descending: false,
789                    nulls_first: false,
790                }),
791            });
792        }
793        for criterion in sort_criteria {
794            if criterion.col_index >= combined.num_columns() {
795                continue;
796            }
797            sort_columns.push(arrow::compute::SortColumn {
798                values: Arc::clone(combined.column(criterion.col_index)),
799                options: Some(arrow::compute::SortOptions {
800                    descending: !criterion.ascending,
801                    nulls_first: criterion.nulls_first,
802                }),
803            });
804        }
805
806        let sorted_indices =
807            arrow::compute::lexsort_to_indices(&sort_columns, None).map_err(arrow_err)?;
808        let sorted_columns: Vec<_> = combined
809            .columns()
810            .iter()
811            .map(|col| arrow::compute::take(col.as_ref(), &sorted_indices, None))
812            .collect::<Result<Vec<_>, _>>()
813            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
814        let sorted = RecordBatch::try_new(Arc::clone(&self.schema), sorted_columns)
815            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
816
817        // Dedup: keep first (best) row per KEY group.
818        let mut keep_indices: Vec<u32> = Vec::new();
819        let mut prev_key: Option<Vec<ScalarKey>> = None;
820        for row_idx in 0..sorted.num_rows() {
821            let key = extract_scalar_key(&sorted, &self.key_column_indices, row_idx);
822            let is_new_group = match &prev_key {
823                None => true,
824                Some(prev) => *prev != key,
825            };
826            if is_new_group {
827                keep_indices.push(row_idx as u32);
828                prev_key = Some(key);
829            }
830        }
831
832        let keep_array = arrow_array::UInt32Array::from(keep_indices);
833        let output_columns: Vec<_> = sorted
834            .columns()
835            .iter()
836            .map(|col| arrow::compute::take(col.as_ref(), &keep_array, None))
837            .collect::<Result<Vec<_>, _>>()
838            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
839        let pruned = RecordBatch::try_new(Arc::clone(&self.schema), output_columns)
840            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
841
842        // Detect whether the best-per-KEY set actually changed.
843        let new_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> = {
844            let mut map = HashMap::new();
845            for row_idx in 0..pruned.num_rows() {
846                let key = extract_scalar_key(&pruned, &self.key_column_indices, row_idx);
847                let criteria: Vec<ScalarKey> = sort_criteria
848                    .iter()
849                    .flat_map(|c| extract_scalar_key(&pruned, &[c.col_index], row_idx))
850                    .collect();
851                map.insert(key, criteria);
852            }
853            map
854        };
855        let changed = old_best != new_best;
856
857        tracing::debug!(
858            rule = %self.rule_name,
859            old_keys = old_best.len(),
860            new_keys = new_best.len(),
861            changed = changed,
862            "BEST BY merge"
863        );
864
865        // Replace facts with the pruned set.
866        self.facts_bytes = batch_byte_size(&pruned);
867        self.facts = vec![pruned];
868        if changed {
869            // Delta is conceptually the new/improved facts, but since we
870            // replaced the entire set, just mark delta non-empty.
871            self.delta = self.facts.clone();
872        } else {
873            self.delta.clear();
874        }
875
876        // Rebuild row dedup from pruned facts for consistency.
877        self.row_dedup = RowDedupState::try_new(&self.schema);
878        if let Some(ref mut rd) = self.row_dedup {
879            rd.ingest_existing(&self.facts, &self.schema);
880        }
881
882        Ok(changed)
883    }
884
885    /// Build a map from KEY column values to sort criteria values.
886    fn build_key_criteria_map(
887        &self,
888        sort_criteria: &[SortCriterion],
889    ) -> HashMap<Vec<ScalarKey>, Vec<ScalarKey>> {
890        let mut map = HashMap::new();
891        for batch in &self.facts {
892            for row_idx in 0..batch.num_rows() {
893                let key = extract_scalar_key(batch, &self.key_column_indices, row_idx);
894                let criteria: Vec<ScalarKey> = sort_criteria
895                    .iter()
896                    .flat_map(|c| extract_scalar_key(batch, &[c.col_index], row_idx))
897                    .collect();
898                map.insert(key, criteria);
899            }
900        }
901        map
902    }
903}
904
905/// Estimate byte size of a RecordBatch.
906fn batch_byte_size(batch: &RecordBatch) -> usize {
907    batch
908        .columns()
909        .iter()
910        .map(|col| col.get_buffer_memory_size())
911        .sum()
912}
913
914// ---------------------------------------------------------------------------
915// Float rounding for stable dedup
916// ---------------------------------------------------------------------------
917
918/// Round all Float64 columns to 12 decimal places for stable dedup.
919fn round_float_columns(batches: &[RecordBatch]) -> Vec<RecordBatch> {
920    batches
921        .iter()
922        .map(|batch| {
923            let schema = batch.schema();
924            let has_float = schema
925                .fields()
926                .iter()
927                .any(|f| *f.data_type() == arrow_schema::DataType::Float64);
928            if !has_float {
929                return batch.clone();
930            }
931
932            let columns: Vec<arrow_array::ArrayRef> = batch
933                .columns()
934                .iter()
935                .enumerate()
936                .map(|(i, col)| {
937                    if *schema.field(i).data_type() == arrow_schema::DataType::Float64 {
938                        let arr = col
939                            .as_any()
940                            .downcast_ref::<arrow_array::Float64Array>()
941                            .unwrap();
942                        let rounded: arrow_array::Float64Array = arr
943                            .iter()
944                            .map(|v| v.map(|f| (f * 1e12).round() / 1e12))
945                            .collect();
946                        Arc::new(rounded) as arrow_array::ArrayRef
947                    } else {
948                        Arc::clone(col)
949                    }
950                })
951                .collect();
952
953            RecordBatch::try_new(schema, columns).unwrap_or_else(|_| batch.clone())
954        })
955        .collect()
956}
957
958// ---------------------------------------------------------------------------
959// LeftAntiJoin delta deduplication
960// ---------------------------------------------------------------------------
961
/// Row threshold above which the vectorized Arrow LeftAntiJoin dedup path is used.
///
/// Below this threshold the persistent `RowDedupState` HashSet is O(M) and
/// avoids rebuilding the existing-row set; above it DataFusion's vectorized
/// HashJoinExec is more cache-efficient.
///
/// NOTE(review): `M` is not defined here — presumably the number of candidate
/// rows in one iteration; confirm. A comment inside
/// `dedup_batches_all_columns` describes the anti-join path as reached for
/// fact sets `>=` this value, while "above" here suggests `>` — confirm the
/// comparison at the use site.
const DEDUP_ANTI_JOIN_THRESHOLD: usize = 300;
968
969/// Deduplicate `candidates` against `existing` using DataFusion's HashJoinExec.
970///
971/// Returns rows in `candidates` that do not appear in `existing` (LeftAnti semantics).
972/// `null_equals_null = true` so NULLs are treated as equal for dedup purposes.
973/// Dedup `batches` by all columns (set semantics), keeping the first occurrence.
974///
975/// `arrow_left_anti_dedup` removes candidate rows that match the existing fact
976/// set, but a single semi-naive iteration can emit the same row many times — e.g.
977/// a transitive-closure rule derives the same `(a, b)` pair via every intermediate
978/// `mid` on a path. A `LeftAnti` join does not remove these *within-candidate*
979/// duplicates, so they would leak into the fact set. The `RowDedupState` and legacy
980/// paths both dedup within the candidate batch ([`RowDedupState::compute_delta`],
981/// [`FixpointState::compute_delta_legacy`]); this keeps the `arrow_left_anti_dedup`
982/// path identical so dedup behavior does not change across `DEDUP_ANTI_JOIN_THRESHOLD`.
983fn dedup_batches_all_columns(
984    batches: Vec<RecordBatch>,
985    schema: &SchemaRef,
986) -> DFResult<Vec<RecordBatch>> {
987    let fields: Vec<SortField> = schema
988        .fields()
989        .iter()
990        .map(|f| SortField::new(f.data_type().clone()))
991        .collect();
992    // Unsupported column types: leave as-is. This path is only reached for facts
993    // sets >= DEDUP_ANTI_JOIN_THRESHOLD; for those types the <threshold path uses
994    // `compute_delta_legacy`, which dedups via `ScalarKey` instead.
995    let Ok(converter) = RowConverter::new(fields) else {
996        return Ok(batches);
997    };
998    let mut seen: HashSet<Box<[u8]>> = HashSet::new();
999    let mut out = Vec::with_capacity(batches.len());
1000    for batch in batches {
1001        if batch.num_rows() == 0 {
1002            continue;
1003        }
1004        let rows = converter
1005            .convert_columns(batch.columns())
1006            .map_err(arrow_err)?;
1007        let mut keep = Vec::with_capacity(batch.num_rows());
1008        for row_idx in 0..batch.num_rows() {
1009            let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
1010            keep.push(seen.insert(row_bytes));
1011        }
1012        let keep_mask = arrow_array::BooleanArray::from(keep);
1013        let cols = batch
1014            .columns()
1015            .iter()
1016            .map(|c| arrow::compute::filter(c.as_ref(), &keep_mask).map_err(arrow_err))
1017            .collect::<DFResult<Vec<_>>>()?;
1018        if cols.first().is_some_and(|c| !c.is_empty()) {
1019            out.push(RecordBatch::try_new(Arc::clone(schema), cols).map_err(arrow_err)?);
1020        }
1021    }
1022    Ok(out)
1023}
1024
1025async fn arrow_left_anti_dedup(
1026    candidates: Vec<RecordBatch>,
1027    existing: &[RecordBatch],
1028    schema: &SchemaRef,
1029    task_ctx: &Arc<TaskContext>,
1030) -> DFResult<Vec<RecordBatch>> {
1031    if existing.is_empty() || existing.iter().all(|b| b.num_rows() == 0) {
1032        // No existing facts to anti-join against, but still dedup the candidates
1033        // among themselves (a single iteration may emit duplicate rows).
1034        return dedup_batches_all_columns(candidates, schema);
1035    }
1036
1037    let left: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(candidates, Arc::clone(schema)));
1038    let right: Arc<dyn ExecutionPlan> =
1039        Arc::new(InMemoryExec::new(existing.to_vec(), Arc::clone(schema)));
1040
1041    let on: Vec<(
1042        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
1043        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
1044    )> = schema
1045        .fields()
1046        .iter()
1047        .enumerate()
1048        .map(|(i, field)| {
1049            let l: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
1050                datafusion::physical_plan::expressions::Column::new(field.name(), i),
1051            );
1052            let r: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
1053                datafusion::physical_plan::expressions::Column::new(field.name(), i),
1054            );
1055            (l, r)
1056        })
1057        .collect();
1058
1059    if on.is_empty() {
1060        return Ok(vec![]);
1061    }
1062
1063    let join = HashJoinExec::try_new(
1064        left,
1065        right,
1066        on,
1067        None,
1068        &JoinType::LeftAnti,
1069        None,
1070        PartitionMode::CollectLeft,
1071        datafusion::common::NullEquality::NullEqualsNull,
1072        // null_aware = false: this is a set-difference dedup (NOT EXISTS), not a
1073        // SQL NOT-IN. `NullEqualsNull` already makes NULL keys dedup against NULL
1074        // keys; enabling null-aware semantics would wrongly annihilate all rows
1075        // whenever the existing-fact side contains a NULL key.
1076        false,
1077    )?;
1078
1079    let join_arc: Arc<dyn ExecutionPlan> = Arc::new(join);
1080    // LeftAnti removes candidates that match `existing`, but not duplicate rows
1081    // within the candidate set — dedup those to match the other delta strategies.
1082    let anti = collect_all_partitions(&join_arc, task_ctx.clone()).await?;
1083    dedup_batches_all_columns(anti, schema)
1084}
1085
1086// ---------------------------------------------------------------------------
1087// Plan types for fixpoint rules
1088// ---------------------------------------------------------------------------
1089
/// IS-ref binding: a reference from a clause body to a derived relation.
///
/// During clause evaluation in the fixpoint loop, a binding receives extra
/// processing only when it is `negated` and has a non-empty `anti_join_cols`:
/// the referenced relation's current data is then used either to filter the
/// clause's rows (anti-join) or to attach a probabilistic complement.
#[derive(Debug, Clone)]
pub struct IsRefBinding {
    /// Index into the DerivedScanRegistry.
    ///
    /// Used to look up the registry entry whose data holds the referenced
    /// relation's rows.
    pub derived_scan_index: usize,
    /// Name of the rule being referenced.
    pub rule_name: String,
    /// Whether this is a self-reference (rule references itself).
    pub is_self_ref: bool,
    /// Whether this is a negated reference (NOT IS).
    pub negated: bool,
    /// For negated IS-refs: `(left_body_col, right_derived_col)` pairs for anti-join filtering.
    ///
    /// `left_body_col` is the VID column in the clause body (e.g., `"n._vid"`);
    /// `right_derived_col` is the corresponding KEY column in the negated rule's facts (e.g., `"n"`).
    /// Empty for non-negated IS-refs.
    pub anti_join_cols: Vec<(String, String)>,
    /// Whether the target rule has a PROB column.
    ///
    /// When this is set and the referencing rule also has a PROB column, a
    /// negated reference is evaluated as a probabilistic complement instead
    /// of a plain filter.
    pub target_has_prob: bool,
    /// Name of the PROB column in the target rule, if any.
    ///
    /// If this is `None` while `target_has_prob` is set, the negated
    /// reference falls back to a plain anti-join.
    pub target_prob_col: Option<String>,
    /// `(body_col, derived_col)` pairs for provenance tracking.
    ///
    /// Used by shared-proof detection to find which source facts a derived row
    /// consumed. Populated for all IS-refs (not just negated ones).
    pub provenance_join_cols: Vec<(String, String)>,
}
1117
/// A single clause (body) within a fixpoint rule.
///
/// Each clause is executed as its own sub-plan on every fixpoint iteration;
/// the resulting batches are the clause's candidate rows for its rule.
#[derive(Debug)]
pub struct FixpointClausePlan {
    /// The logical plan for the clause body.
    pub body_logical: LogicalPlan,
    /// IS-ref bindings used by this clause.
    pub is_ref_bindings: Vec<IsRefBinding>,
    /// Priority value for this clause (if PRIORITY semantics apply).
    pub priority: Option<i64>,
    /// ALONG binding variable names propagated from the planner.
    pub along_bindings: Vec<String>,
    /// Phase B Slice 3: neural-model invocations lifted out of YIELD
    /// items by the compiler. Each entry is evaluated per row after the
    /// clause body produces batches and before IS-ref handling.
    ///
    /// NOTE(review): comments in the fixpoint loop state that the planner
    /// inserts `LogicalPlan::LocyModelInvoke` into the body plan so the
    /// invocation runs inline during sub-plan execution — confirm whether
    /// this field is still consumed at runtime or is kept for reference only.
    pub model_invocations: Vec<ModelInvocation>,
}
1134
/// Physical plan for a single rule in a fixpoint stratum.
///
/// One fixpoint state is created per rule plan, seeded with the rule's name,
/// yield schema and KEY column indices; the rule's clauses are then evaluated
/// on every iteration and their candidates merged into that state.
#[derive(Debug)]
pub struct FixpointRulePlan {
    /// Rule name.
    pub name: String,
    /// Clause bodies (each evaluates to candidate rows).
    pub clauses: Vec<FixpointClausePlan>,
    /// Output schema for this rule's derived relation.
    pub yield_schema: SchemaRef,
    /// Indices of KEY columns within yield_schema.
    pub key_column_indices: Vec<usize>,
    /// Priority value (if PRIORITY semantics apply).
    pub priority: Option<i64>,
    /// Whether this rule has FOLD semantics.
    pub has_fold: bool,
    /// FOLD bindings for post-fixpoint aggregation.
    ///
    /// When non-empty, the rule's fixpoint state is also given a monotonic
    /// aggregate tracker built from these bindings.
    pub fold_bindings: Vec<FoldBinding>,
    /// Post-FOLD filter expressions (HAVING semantics).
    pub having: Vec<Expr>,
    /// Whether this rule has BEST BY semantics.
    ///
    /// The BEST BY merge is used during iteration only when this is set and
    /// `best_by_criteria` is non-empty; otherwise the regular delta merge runs.
    pub has_best_by: bool,
    /// BEST BY sort criteria for post-fixpoint selection.
    pub best_by_criteria: Vec<SortCriterion>,
    /// Whether this rule has PRIORITY semantics.
    pub has_priority: bool,
    /// Whether BEST BY should apply a deterministic secondary sort for
    /// tie-breaking. When false, tied rows are selected non-deterministically
    /// (faster but not repeatable across runs).
    pub deterministic: bool,
    /// Name of the PROB column in this rule's yield schema, if any.
    ///
    /// A rule with a PROB column evaluated under the MaxMinProb semiring
    /// triggers a `FuzzyNotProbabilistic` runtime warning.
    pub prob_column_name: Option<String>,
    /// True when any clause of this rule has ≥2 positive same-stratum
    /// IS-refs (non-linear recursion, e.g. `tc(a,b) :- tc(a,m), tc(m,b)`).
    /// Such rules get FULL facts (naive evaluation) instead of the latest
    /// delta on their self-ref scans: a delta-only join computes Δ×Δ and
    /// misses the Δ×F_old combinations, silently under-deriving. Covers
    /// both `p :- p, p` and `p :- p, q` (q in the same SCC) shapes.
    pub non_linear: bool,
}
1174
1175// ---------------------------------------------------------------------------
1176// run_fixpoint_loop — the core semi-naive iteration algorithm
1177// ---------------------------------------------------------------------------
1178
1179/// Run the semi-naive fixpoint iteration loop.
1180///
1181/// Evaluates all rules in a stratum repeatedly, feeding deltas back through
1182/// derived scan handles until convergence or limits are reached.
1183#[expect(clippy::too_many_arguments, reason = "Fixpoint loop needs all context")]
1184async fn run_fixpoint_loop(
1185    rules: Vec<FixpointRulePlan>,
1186    max_iterations: usize,
1187    timeout: Duration,
1188    graph_ctx: Arc<GraphExecutionContext>,
1189    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
1190    storage: Arc<StorageManager>,
1191    schema_info: Arc<UniSchema>,
1192    params: HashMap<String, Value>,
1193    registry: Arc<DerivedScanRegistry>,
1194    output_schema: SchemaRef,
1195    max_derived_bytes: usize,
1196    derivation_tracker: Option<Arc<ProvenanceStore>>,
1197    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
1198    strict_probability_domain: bool,
1199    probability_epsilon: f64,
1200    exact_probability: bool,
1201    max_bdd_variables: usize,
1202    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
1203    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
1204    top_k_proofs: usize,
1205    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
1206    semiring_kind: SemiringKind,
1207    classifier_registry: Arc<ClassifierRegistry>,
1208    classifier_cache: Option<Arc<ModelInvocationCache>>,
1209    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
1210    profile_collector: Option<Arc<LocyProfileCollector>>,
1211) -> DFResult<Vec<RecordBatch>> {
1212    let start = Instant::now();
1213    let task_ctx = session_ctx.read().task_ctx();
1214
1215    // IMPORTANT: per rollout D-9 the FuzzyNotProbabilistic warning emitted
1216    // below is unsuppressible — do not gate on any suppression mechanism.
1217    // Fuzzy truth values are not probabilities; silent conflation is the
1218    // dominant pitfall in neuro-symbolic systems (LTN, NTP).
1219    if semiring_kind == SemiringKind::MaxMinProb {
1220        let mut warnings = warnings_slot.write().unwrap_or_else(|e| e.into_inner());
1221        let mut already_warned: HashSet<String> = warnings
1222            .iter()
1223            .filter(|w| w.code == RuntimeWarningCode::FuzzyNotProbabilistic)
1224            .map(|w| w.rule_name.clone())
1225            .collect();
1226        for rule in &rules {
1227            if rule.prob_column_name.is_some() && !already_warned.contains(&rule.name) {
1228                warnings.push(RuntimeWarning {
1229                    code: RuntimeWarningCode::FuzzyNotProbabilistic,
1230                    message: format!(
1231                        "rule '{}' carries a PROB column but is being evaluated under \
1232                         the MaxMinProb (fuzzy / Viterbi) semiring; outputs are fuzzy \
1233                         truth values, not probabilities",
1234                        rule.name
1235                    ),
1236                    rule_name: rule.name.clone(),
1237                    variable_count: None,
1238                    key_group: None,
1239                });
1240                already_warned.insert(rule.name.clone());
1241            }
1242        }
1243    }
1244
1245    // Initialize per-rule state
1246    let mut states: Vec<FixpointState> = rules
1247        .iter()
1248        .map(|rule| {
1249            let monotonic_agg = if !rule.fold_bindings.is_empty() {
1250                let bindings: Vec<MonotonicFoldBinding> = rule
1251                    .fold_bindings
1252                    .iter()
1253                    .map(|fb| MonotonicFoldBinding {
1254                        fold_name: fb.output_name.clone(),
1255                        aggregate: std::sync::Arc::clone(&fb.aggregate),
1256                        input_col_index: fb.input_col_index,
1257                        input_col_name: fb.input_col_name.clone(),
1258                    })
1259                    .collect();
1260                Some(MonotonicAggState::new(bindings))
1261            } else {
1262                None
1263            };
1264            FixpointState::new_with_semiring(
1265                rule.name.clone(),
1266                Arc::clone(&rule.yield_schema),
1267                rule.key_column_indices.clone(),
1268                max_derived_bytes,
1269                monotonic_agg,
1270                strict_probability_domain,
1271                semiring_kind,
1272            )
1273        })
1274        .collect();
1275
1276    // Main iteration loop
1277    let mut converged = false;
1278    let mut total_iters = 0usize;
1279    for iteration in 0..max_iterations {
1280        total_iters = iteration + 1;
1281        tracing::debug!("fixpoint iteration {}", iteration);
1282        let mut any_changed = false;
1283
1284        for rule_idx in 0..rules.len() {
1285            let rule = &rules[rule_idx];
1286
1287            // Update derived scan handles for this rule's clauses
1288            update_derived_scan_handles(&registry, &states, rule_idx, &rules);
1289
1290            // Evaluate clause bodies, tracking per-clause candidates for provenance.
1291            let mut all_candidates = Vec::new();
1292            let mut clause_candidates: Vec<Vec<RecordBatch>> = Vec::new();
1293            // Profiling (profile() path only): time this rule's evaluation in
1294            // this iteration and accumulate the clause-body operator trees.
1295            let rule_start = Instant::now();
1296            let mut iter_ops: Vec<OperatorStats> = Vec::new();
1297            for clause in &rule.clauses {
1298                // Phase B A4 follow-up: the planner inserts
1299                // `LogicalPlan::LocyModelInvoke` between the body and
1300                // `LocyProject` when this clause has neural-model
1301                // invocations, so `execute_subplan` runs the invocation
1302                // inline as part of the body plan tree.
1303                let mut batches = if profile_collector.is_some() {
1304                    let (b, ops) = execute_subplan_collecting(
1305                        &clause.body_logical,
1306                        &params,
1307                        &HashMap::new(),
1308                        &graph_ctx,
1309                        &session_ctx,
1310                        &storage,
1311                        &schema_info,
1312                        None, // Locy fixpoint clause body is read-only
1313                    )
1314                    .await?;
1315                    iter_ops.extend(ops);
1316                    b
1317                } else {
1318                    execute_subplan(
1319                        &clause.body_logical,
1320                        &params,
1321                        &HashMap::new(),
1322                        &graph_ctx,
1323                        &session_ctx,
1324                        &storage,
1325                        &schema_info,
1326                        None, // Locy fixpoint clause body is read-only
1327                    )
1328                    .await?
1329                };
1330                // Apply negated IS-ref semantics: probabilistic complement or anti-join.
1331                for binding in &clause.is_ref_bindings {
1332                    if binding.negated
1333                        && !binding.anti_join_cols.is_empty()
1334                        && let Some(entry) = registry.get(binding.derived_scan_index)
1335                    {
1336                        let neg_facts = entry.data.read().clone();
1337                        if !neg_facts.is_empty() {
1338                            if binding.target_has_prob && rule.prob_column_name.is_some() {
1339                                // Probabilistic complement: add 1-p column instead of filtering.
1340                                let complement_col =
1341                                    format!("__prob_complement_{}", binding.rule_name);
1342                                if let Some(prob_col) = &binding.target_prob_col {
1343                                    batches = apply_prob_complement_composite(
1344                                        batches,
1345                                        &neg_facts,
1346                                        &binding.anti_join_cols,
1347                                        prob_col,
1348                                        &complement_col,
1349                                    )?;
1350                                } else {
1351                                    // target_has_prob but no prob_col: fall back to anti-join.
1352                                    batches = apply_anti_join_composite(
1353                                        batches,
1354                                        &neg_facts,
1355                                        &binding.anti_join_cols,
1356                                    )?;
1357                                }
1358                            } else {
1359                                // Boolean exclusion: anti-join (existing behavior)
1360                                batches = apply_anti_join_composite(
1361                                    batches,
1362                                    &neg_facts,
1363                                    &binding.anti_join_cols,
1364                                )?;
1365                            }
1366                        }
1367                    }
1368                }
1369                // Multiply complement columns into the PROB column (if any) and clean up
1370                let complement_cols: Vec<String> = if !batches.is_empty() {
1371                    batches[0]
1372                        .schema()
1373                        .fields()
1374                        .iter()
1375                        .filter(|f| f.name().starts_with("__prob_complement_"))
1376                        .map(|f| f.name().clone())
1377                        .collect()
1378                } else {
1379                    vec![]
1380                };
1381                if !complement_cols.is_empty() {
1382                    batches = multiply_prob_factors(
1383                        batches,
1384                        rule.prob_column_name.as_deref(),
1385                        &complement_cols,
1386                    )?;
1387                }
1388
1389                clause_candidates.push(batches.clone());
1390                all_candidates.extend(batches);
1391            }
1392
1393            // Merge candidates into facts.
1394            // For BEST BY rules, use a specialized merge that keeps only the
1395            // best row per KEY group, enabling convergence on cyclic graphs.
1396            let changed = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
1397                states[rule_idx].merge_best_by(all_candidates, &rule.best_by_criteria)?
1398            } else {
1399                states[rule_idx]
1400                    .merge_delta(all_candidates, Some(Arc::clone(&task_ctx)))
1401                    .await?
1402            };
1403            if changed {
1404                any_changed = true;
1405                // Record provenance for newly derived facts when tracker is present.
1406                if let Some(ref tracker) = derivation_tracker {
1407                    record_provenance(
1408                        ProvenanceCtx {
1409                            tracker,
1410                            registry: &registry,
1411                            warnings_slot: &warnings_slot,
1412                        },
1413                        rule,
1414                        &states[rule_idx],
1415                        &clause_candidates,
1416                        iteration,
1417                        top_k_proofs,
1418                        ClassifierRefs {
1419                            registry: &classifier_registry,
1420                            cache: classifier_cache.as_ref(),
1421                            provenance_store: classifier_provenance_store.as_ref(),
1422                        },
1423                    )
1424                    .await;
1425                }
1426            }
1427
1428            // Profiling: record this rule's per-iteration row (delta = net-new
1429            // facts merged this pass, plus the captured clause-body operators).
1430            if let Some(ref collector) = profile_collector {
1431                let delta_facts: usize = states[rule_idx]
1432                    .all_delta()
1433                    .iter()
1434                    .map(|b| b.num_rows())
1435                    .sum();
1436                collector.record(
1437                    &rule.name,
1438                    iteration,
1439                    delta_facts,
1440                    rule_start.elapsed().as_secs_f64() * 1000.0,
1441                    iter_ops,
1442                );
1443            }
1444        }
1445
1446        // Check convergence
1447        if !any_changed && states.iter().all(|s| s.is_converged()) {
1448            tracing::debug!("fixpoint converged after {} iterations", iteration + 1);
1449            converged = true;
1450            break;
1451        }
1452
1453        // Check timeout
1454        if start.elapsed() > timeout {
1455            tracing::warn!(
1456                "fixpoint timeout after {} iterations; returning partial results",
1457                iteration + 1,
1458            );
1459            interruption::set(&timeout_flag, interruption::TIMEOUT);
1460            break;
1461        }
1462    }
1463
1464    // Write per-rule iteration counts to the shared slot.
1465    if let Ok(mut counts) = iteration_counts.write() {
1466        for rule in &rules {
1467            counts.insert(rule.name.clone(), total_iters);
1468        }
1469    }
1470
1471    // Profiling: record each rule's final converged fact count.
1472    if let Some(ref collector) = profile_collector {
1473        for (idx, rule) in rules.iter().enumerate() {
1474            let facts: usize = states[idx].all_facts().iter().map(|b| b.num_rows()).sum();
1475            collector.set_final_facts(&rule.name, facts);
1476        }
1477    }
1478
1479    // If we exhausted all iterations without converging, record the iteration
1480    // limit (distinct from a wall-clock timeout) and proceed with partial
1481    // results rather than discarding all work. `set` is first-wins, so a
1482    // wall-clock timeout recorded above is not overwritten here.
1483    if !converged && interruption::reason(&timeout_flag).is_none() {
1484        tracing::warn!(
1485            "fixpoint did not converge after {max_iterations} iterations; returning partial results",
1486        );
1487        interruption::set(&timeout_flag, interruption::ITERATION_LIMIT);
1488    }
1489
1490    // Post-fixpoint processing per rule and collect output
1491    let task_ctx = session_ctx.read().task_ctx();
1492    let mut all_output = Vec::new();
1493
1494    for (rule_idx, state) in states.into_iter().enumerate() {
1495        let rule = &rules[rule_idx];
1496        let mut facts = state.into_facts();
1497        if facts.is_empty() {
1498            continue;
1499        }
1500
1501        // Detect shared proofs before FOLD collapses groups.
1502        //
1503        // TODO(C0-stage2): swap `detect_shared_lineage` for `TopKTag`
1504        // DNF inspection when `semiring_kind == TopKProofs { k }`.
1505        // The library-layer tag math has landed in
1506        // `crates/uni-locy/src/top_k_proofs.rs` (Phase C C0 Stage 1);
1507        // Stage 2 plumbs `TopKTag` through `MonotonicAggState` /
1508        // `FoldExec` so per-row dependency DNFs are available here.
1509        // Until Stage 2, this scalar `ProvenanceStore` path runs for
1510        // every semiring including `TopKProofs` (per rollout D-4
1511        // "graceful migration").
1512        //
1513        // Phase-3 shared-proof detection is meaningful only under
1514        // `AddMultProb` (and `BddExact`, which is the AddMultProb math
1515        // plus a WMC post-correction). Under `MaxMinProb`, `plus = max`
1516        // is idempotent — shared proofs don't double-count — so the
1517        // warning is moot and we skip the work.
1518        let shared_info = if semiring_kind == SemiringKind::MaxMinProb {
1519            None
1520        } else if let Some(ref tracker) = derivation_tracker {
1521            detect_shared_lineage(rule, &facts, tracker, &warnings_slot, semiring_kind)
1522        } else {
1523            None
1524        };
1525
1526        // Apply BDD for shared groups if exact_probability is enabled.
1527        if exact_probability
1528            && let Some(ref info) = shared_info
1529            && let Some(ref tracker) = derivation_tracker
1530        {
1531            facts = apply_exact_wmc(
1532                facts,
1533                rule,
1534                info,
1535                tracker,
1536                max_bdd_variables,
1537                &warnings_slot,
1538                &approximate_slot,
1539            )?;
1540        }
1541
1542        let processed = apply_post_fixpoint_chain(
1543            facts,
1544            rule,
1545            &task_ctx,
1546            strict_probability_domain,
1547            probability_epsilon,
1548            semiring_kind,
1549            derivation_tracker.as_ref().map(Arc::clone),
1550            top_k_proofs,
1551            Some(Arc::clone(&registry)),
1552        )
1553        .await?;
1554        all_output.extend(processed);
1555    }
1556
1557    // If no output, return empty batch with output schema
1558    if all_output.is_empty() {
1559        all_output.push(RecordBatch::new_empty(output_schema));
1560    }
1561
1562    Ok(all_output)
1563}
1564
1565// ---------------------------------------------------------------------------
1566// Provenance recording helpers
1567// ---------------------------------------------------------------------------
1568
1569/// Record provenance for all newly derived facts (rows in the current delta).
1570///
1571/// Called after `merge_delta` returns `true`. Attributes each new fact to the
1572/// clause most likely to have produced it, using first-derivation-wins semantics.
1573/// Borrowed bundle of classifier-side runtime state used by
1574/// provenance / EXPLAIN-reconstruction code paths. Keeps function
1575/// signatures under the too-many-arguments threshold.
1576pub(crate) struct ClassifierRefs<'a> {
1577    pub registry: &'a Arc<ClassifierRegistry>,
1578    pub cache: Option<&'a Arc<uni_locy::ModelInvocationCache>>,
1579    /// Phase C B1-B3 follow-up: when `Some`, EXPLAIN's neural_calls
1580    /// collection consults the side-channel provenance store first
1581    /// (populated by `apply_model_invocations`). This is the only way
1582    /// to surface NeuralProvenance for Python-registered classifiers,
1583    /// whose model_invocations may be rewritten away by the planner
1584    /// and so wouldn't trigger the re-invocation fallback.
1585    pub provenance_store: Option<&'a Arc<uni_locy::NeuralProvenanceStore>>,
1586}
1587
1588/// Borrowed bundle of provenance-recording state: the in-flight
1589/// tracker, the derived-scan registry (used to resolve IS-ref inputs),
1590/// and the shared warnings slot. Bundled to keep
1591/// `record_provenance` / `record_and_detect_lineage_nonrecursive`
1592/// under the too-many-arguments threshold.
1593pub(crate) struct ProvenanceCtx<'a> {
1594    pub tracker: &'a Arc<ProvenanceStore>,
1595    pub registry: &'a Arc<DerivedScanRegistry>,
1596    pub warnings_slot: &'a Arc<StdRwLock<Vec<RuntimeWarning>>>,
1597}
1598
1599async fn record_provenance(
1600    prov: ProvenanceCtx<'_>,
1601    rule: &FixpointRulePlan,
1602    state: &FixpointState,
1603    clause_candidates: &[Vec<RecordBatch>],
1604    iteration: usize,
1605    top_k_proofs: usize,
1606    classifiers: ClassifierRefs<'_>,
1607) {
1608    let tracker = prov.tracker;
1609    let registry = prov.registry;
1610    let warnings_slot = prov.warnings_slot;
1611    let classifier_registry = classifiers.registry;
1612    let classifier_cache = classifiers.cache;
1613    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
1614
1615    // Pre-compute base fact probabilities for top-k mode.
1616    let base_probs = if top_k_proofs > 0 {
1617        tracker.base_fact_probs()
1618    } else {
1619        HashMap::new()
1620    };
1621
1622    let mut topk_acc = TopKProofAccumulator::new();
1623
1624    for delta_batch in state.all_delta() {
1625        for row_idx in 0..delta_batch.num_rows() {
1626            let row_hash = format!(
1627                "{:?}",
1628                extract_scalar_key(delta_batch, &all_indices, row_idx)
1629            )
1630            .into_bytes();
1631            let fact_row = batch_row_to_value_map(delta_batch, row_idx);
1632            let clause_index =
1633                find_clause_for_row(delta_batch, row_idx, &all_indices, clause_candidates);
1634
1635            let support = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
1636
1637            let proof_probability = if top_k_proofs > 0 {
1638                compute_proof_probability(&support, &base_probs)
1639            } else {
1640                None
1641            };
1642
1643            let entry = ProvenanceAnnotation {
1644                rule_name: rule.name.clone(),
1645                clause_index,
1646                support,
1647                along_values: {
1648                    let along_names: Vec<String> = rule
1649                        .clauses
1650                        .get(clause_index)
1651                        .map(|c| c.along_bindings.clone())
1652                        .unwrap_or_default();
1653                    along_names
1654                        .iter()
1655                        .filter_map(|name| fact_row.get(name).map(|v| (name.clone(), v.clone())))
1656                        .collect()
1657                },
1658                iteration,
1659                fact_row: fact_row.clone(),
1660                proof_probability,
1661                neural_calls: collect_neural_calls_for_row(
1662                    rule,
1663                    clause_index,
1664                    &fact_row,
1665                    classifier_registry,
1666                    classifier_cache,
1667                    classifiers.provenance_store,
1668                )
1669                .await,
1670            };
1671            if top_k_proofs > 0 {
1672                topk_acc.accumulate(&entry, &row_hash);
1673                tracker.record_top_k(row_hash, entry, top_k_proofs);
1674            } else {
1675                tracker.record(row_hash, entry);
1676            }
1677        }
1678    }
1679
1680    topk_acc.emit_warning_if_any(rule, top_k_proofs, warnings_slot);
1681}
1682
1683/// Phase C C0 Stage 2: collects per-row `Proof` tags during the
1684/// fixpoint row walk, then surfaces `TopKPruningCrossedDependency`
1685/// when post-walk top-K merging would drop a proof whose base RVs
1686/// overlap a retained one. The shared `BaseRv` interner is what
1687/// makes the overlap detectable — proofs grounded in the same
1688/// `base_fact_id` get the same `BaseRv`.
1689struct TopKProofAccumulator {
1690    per_fact: HashMap<Vec<u8>, Vec<uni_locy::Proof>>,
1691    base_rv_interner: HashMap<Vec<u8>, uni_locy::BaseRv>,
1692    next_rv: u32,
1693}
1694
1695impl TopKProofAccumulator {
1696    fn new() -> Self {
1697        Self {
1698            per_fact: HashMap::new(),
1699            base_rv_interner: HashMap::new(),
1700            next_rv: 0,
1701        }
1702    }
1703
1704    fn accumulate(&mut self, entry: &ProvenanceAnnotation, row_hash: &[u8]) {
1705        let mut base_rvs = uni_locy::BaseRvSet::empty();
1706        for term in &entry.support {
1707            let rv = *self
1708                .base_rv_interner
1709                .entry(term.base_fact_id.clone())
1710                .or_insert_with(|| {
1711                    let r = uni_locy::BaseRv(self.next_rv);
1712                    self.next_rv += 1;
1713                    r
1714                });
1715            base_rvs.insert(rv);
1716        }
1717        self.per_fact
1718            .entry(row_hash.to_vec())
1719            .or_default()
1720            .push(uni_locy::Proof {
1721                weight: entry.proof_probability.unwrap_or(0.0),
1722                base_rvs,
1723                neural_calls: Vec::new(),
1724            });
1725    }
1726
1727    fn emit_warning_if_any(
1728        &self,
1729        rule: &FixpointRulePlan,
1730        top_k_proofs: usize,
1731        warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
1732    ) {
1733        if top_k_proofs == 0 || self.per_fact.is_empty() {
1734            return;
1735        }
1736        let crossed_facts = self
1737            .per_fact
1738            .values()
1739            .filter(|proofs| {
1740                let (_kept, notice) =
1741                    uni_locy::merge_top_k_runtime(Vec::new(), (*proofs).clone(), top_k_proofs);
1742                notice == uni_locy::PruneNotice::CrossedDependency
1743            })
1744            .count();
1745        if crossed_facts == 0 {
1746            return;
1747        }
1748        let Ok(mut w) = warnings_slot.write() else {
1749            return;
1750        };
1751        let already = w.iter().any(|rw| {
1752            matches!(
1753                rw.code,
1754                uni_locy::types::RuntimeWarningCode::TopKPruningCrossedDependency
1755            ) && rw.rule_name == rule.name
1756        });
1757        if already {
1758            return;
1759        }
1760        w.push(RuntimeWarning {
1761            code: uni_locy::types::RuntimeWarningCode::TopKPruningCrossedDependency,
1762            rule_name: rule.name.clone(),
1763            message: format!(
1764                "rule '{}': top-K proof pruning (k={}) discarded {} fact(s) \
1765                 whose dependencies overlap retained proofs. The retained \
1766                 top-{} under-counts the true joint probability for those \
1767                 facts (Scallop, Huang et al. 2021). Increase k to recover.",
1768                rule.name, top_k_proofs, crossed_facts, top_k_proofs
1769            ),
1770            variable_count: None,
1771            key_group: None,
1772        });
1773    }
1774}
1775
1776/// Collect IS-ref input facts for a derived row using provenance join columns.
1777///
1778/// For each non-negated IS-ref binding in the clause, extracts body-side key
1779/// values from the delta row and finds matching source rows in the registry.
1780/// Returns a `ProofTerm` for each match (with the source fact hash).
1781/// Phase C B1–B3: build [`uni_locy::NeuralProvenance`] entries for
1782/// the model invocations on this clause by reading each
1783/// invocation's output column from the post-LocyModelInvoke row.
1784/// `raw_probability` is the classifier's direct output;
1785/// `calibrated_probability` and `confidence_band` come from the
1786/// active Calibrator (when the classifier wraps one).
1787///
1788/// Phase C B1-B3 follow-up rewrite: re-evaluate the classifier
1789/// per fact using the ORIGINAL pre-rewrite feature expressions
1790/// (stored as `invocation.original_feature_exprs`). This works
1791/// for invocations in YIELD, ALONG, and FOLD positions uniformly
1792/// — the original args carry the input bindings regardless of
1793/// where the synthetic `__model_<n>` column ends up in the plan
1794/// tree. Memoization via `ModelInvocationCache` (already threaded)
1795/// absorbs repeat costs; EXPLAIN typically operates on small
1796/// derivation trees so the per-fact classifier call is bounded.
1797async fn collect_neural_calls_for_row(
1798    rule: &FixpointRulePlan,
1799    clause_index: usize,
1800    fact_row: &uni_locy::FactRow,
1801    classifier_registry: &Arc<ClassifierRegistry>,
1802    classifier_cache: Option<&Arc<uni_locy::ModelInvocationCache>>,
1803    provenance_store: Option<&Arc<uni_locy::NeuralProvenanceStore>>,
1804) -> Vec<uni_locy::NeuralProvenance> {
1805    let Some(clause) = rule.clauses.get(clause_index) else {
1806        return Vec::new();
1807    };
1808    if clause.model_invocations.is_empty() {
1809        return Vec::new();
1810    }
1811    let mut out = Vec::with_capacity(clause.model_invocations.len());
1812    for invocation in &clause.model_invocations {
1813        // Build ClassifyInput from the REWRITTEN feature expressions
1814        // (referencing the synthetic `__feat_*` hidden columns that the
1815        // planner lifts model-call args into). This matches the writer's
1816        // path in `apply_model_invocations`, which iterates
1817        // `invocation.feature_exprs` to compute the same `input_hash`
1818        // that gets stored. Using the pre-rewrite `original_feature_exprs`
1819        // here would compute a different hash for YIELD-position
1820        // invocations (where the pre-rewrite expr references properties
1821        // not materialised into fact_row), causing the store lookup
1822        // below to miss and `neural_calls` to come back empty.
1823        let mut features = std::collections::HashMap::new();
1824        for (binding_name, feat_expr) in invocation
1825            .feature_names
1826            .iter()
1827            .zip(invocation.feature_exprs.iter())
1828        {
1829            features.insert(
1830                binding_name.clone(),
1831                eval_feature_expr_against_fact_row(feat_expr, fact_row),
1832            );
1833        }
1834        let input = uni_locy::ClassifyInput { features };
1835        let input_hash = input.stable_hash();
1836
1837        // Store-first read path. `apply_model_invocations` writes a
1838        // NeuralProvenanceRecord per (model, input_hash) into the
1839        // side-channel store during fixpoint. If we find a record
1840        // there, surface it directly — this is the only path that
1841        // populates calibrated_probability + confidence_band for
1842        // Python-registered classifiers.
1843        if let Some(store) = provenance_store
1844            && let Some(record) = store.get(&invocation.model_name, input_hash)
1845        {
1846            out.push(uni_locy::NeuralProvenance {
1847                model_name: invocation.model_name.clone(),
1848                raw_probability: record.raw_probability,
1849                calibrated_probability: record.calibrated_probability,
1850                confidence_band: record.confidence_band,
1851            });
1852            continue;
1853        }
1854
1855        // Fallback: re-invoke the classifier. Only reached when the
1856        // store wasn't populated for this (model, input) — e.g. older
1857        // sessions where the store wasn't threaded, or when EXPLAIN
1858        // runs against a row that fixpoint never touched.
1859        let Some(classifier) = classifier_registry.get(&invocation.model_name) else {
1860            continue;
1861        };
1862        let raw = if let Some(v) =
1863            classifier_cache.and_then(|c| c.get(&invocation.model_name, input_hash))
1864        {
1865            v
1866        } else {
1867            match classifier.classify(std::slice::from_ref(&input)).await {
1868                Ok(probs) => {
1869                    let v = probs.first().copied().unwrap_or(0.0);
1870                    if let Some(c) = classifier_cache {
1871                        c.insert(&invocation.model_name, input_hash, v);
1872                    }
1873                    v
1874                }
1875                Err(_) => continue,
1876            }
1877        };
1878        let calibrator = classifier.get_calibrator();
1879        let calibrated_probability = calibrator.as_ref().map(|_| raw);
1880        let confidence_band = calibrator.as_ref().and_then(|c| c.confidence_band(raw));
1881        out.push(uni_locy::NeuralProvenance {
1882            model_name: invocation.model_name.clone(),
1883            raw_probability: raw,
1884            calibrated_probability,
1885            confidence_band,
1886        });
1887    }
1888    out
1889}
1890
1891/// Phase C B1-B3 follow-up: evaluate a model's pre-rewrite feature
1892/// expression against a fact_row, producing a `FeatureValue` for
1893/// classifier input reconstruction. Mirrors the compile-time
1894/// acceptance set in `validate_features`:
1895/// - `Variable(name)` → fact_row[name], coerced.
1896/// - `Property(Variable(v), prop)` → fact_row["v.prop"] (the
1897///   materialized property column).
1898fn eval_feature_expr_against_fact_row(
1899    expr: &uni_cypher::ast::Expr,
1900    fact_row: &uni_locy::FactRow,
1901) -> uni_locy::FeatureValue {
1902    use uni_cypher::ast::Expr;
1903    use uni_locy::FeatureValue;
1904    let value_to_feature = |v: Option<&uni_common::Value>| -> FeatureValue {
1905        match v {
1906            Some(uni_common::Value::Float(f)) => FeatureValue::Float(*f),
1907            Some(uni_common::Value::Int(i)) => FeatureValue::Int(*i),
1908            Some(uni_common::Value::Bool(b)) => FeatureValue::Bool(*b),
1909            Some(uni_common::Value::String(s)) => FeatureValue::String(s.clone()),
1910            Some(uni_common::Value::Node(n)) => {
1911                // Encode node by vid for `scorer(s)` style.
1912                FeatureValue::Int(n.vid.as_u64() as i64)
1913            }
1914            _ => FeatureValue::Null,
1915        }
1916    };
1917    // Phase D D1: resolve a sub-expression to its raw `uni_common::Value`
1918    // for the `similar_to` UDF input. Falls back to the node's property
1919    // when the materialized column key isn't directly present.
1920    let resolve_value = |sub: &Expr| -> uni_common::Value {
1921        match sub {
1922            Expr::Variable(name) => fact_row
1923                .get(name)
1924                .cloned()
1925                .unwrap_or(uni_common::Value::Null),
1926            Expr::Property(boxed, prop) if matches!(boxed.as_ref(), Expr::Variable(_)) => {
1927                let Expr::Variable(v) = boxed.as_ref() else {
1928                    unreachable!()
1929                };
1930                let key = format!("{}.{}", v, prop);
1931                if let Some(val) = fact_row.get(&key) {
1932                    return val.clone();
1933                }
1934                if let Some(uni_common::Value::Node(n)) = fact_row.get(v) {
1935                    return n
1936                        .properties
1937                        .get(prop)
1938                        .cloned()
1939                        .unwrap_or(uni_common::Value::Null);
1940                }
1941                uni_common::Value::Null
1942            }
1943            Expr::Literal(lit) => lit.to_value(),
1944            Expr::List(items) => {
1945                let mut out = Vec::with_capacity(items.len());
1946                for it in items {
1947                    out.push(match it {
1948                        Expr::Literal(lit) => lit.to_value(),
1949                        _ => uni_common::Value::Null,
1950                    });
1951                }
1952                uni_common::Value::List(out)
1953            }
1954            _ => uni_common::Value::Null,
1955        }
1956    };
1957
1958    match expr {
1959        Expr::Variable(name) => value_to_feature(fact_row.get(name)),
1960        Expr::Property(boxed, prop) => {
1961            if let Expr::Variable(v) = boxed.as_ref() {
1962                // Try the materialized property column first.
1963                let key = format!("{}.{}", v, prop);
1964                if let Some(val) = fact_row.get(&key) {
1965                    return value_to_feature(Some(val));
1966                }
1967                // Fallback: try the synthetic hidden column that the
1968                // planner injects for property-access feature args
1969                // (`__feat_<var>_<prop>`). The writer side
1970                // (`apply_model_invocations`) already uses this
1971                // fallback (see `resolve_src` in the same file), so
1972                // mirroring it here keeps reader/writer input_hash
1973                // symmetric — without it, the YIELD-position case
1974                // (where `fact_row[v]` is a vid Int, not a Node)
1975                // returns Null and the store-lookup misses.
1976                let hidden_key = format!("__feat_{}_{}", v, prop);
1977                if let Some(val) = fact_row.get(&hidden_key) {
1978                    return value_to_feature(Some(val));
1979                }
1980                // Final fallback: read property directly from the
1981                // node value (works when fact_row carries the Node
1982                // rather than a vid Int).
1983                if let Some(uni_common::Value::Node(n)) = fact_row.get(v) {
1984                    return value_to_feature(n.properties.get(prop));
1985                }
1986            }
1987            FeatureValue::Null
1988        }
1989        Expr::FunctionCall { name, args, .. } if name == "similar_to" && args.len() == 2 => {
1990            let lv = resolve_value(&args[0]);
1991            let rv = resolve_value(&args[1]);
1992            match crate::query::similar_to::eval_similar_to_pure(&lv, &rv) {
1993                Ok(uni_common::Value::Float(f)) => FeatureValue::Float(f),
1994                _ => FeatureValue::Null,
1995            }
1996        }
1997        // `semantic_match` requires the Xervo embedder at this scope, which
1998        // is not threaded into the EXPLAIN re-evaluation path. Surface as
1999        // Null so neural-provenance still renders for the rest of the row.
2000        //
2001        // Phase D D1 graph-structural FunctionCalls (`degree_centrality`,
2002        // `pagerank_score`, `closeness_centrality`, `avg_neighbor`,
2003        // `max_neighbor`, `sum_neighbor`) require the `GraphAlgoHandle`
2004        // (algorithm registry + storage + PropertyManager) and an async
2005        // re-precompute pass — none of which are reachable from this
2006        // synchronous fact-row evaluator. Mode B re-evaluation surfaces
2007        // them as Null; the authoritative hot-path values are recorded
2008        // in `NeuralProvenanceStore` per fact (the EXPLAIN renderer
2009        // consults the store first when configured, falling back to
2010        // Mode B re-evaluation only as a backup).
2011        Expr::FunctionCall { name, .. }
2012            if matches!(
2013                name.as_str(),
2014                "degree_centrality"
2015                    | "pagerank_score"
2016                    | "closeness_centrality"
2017                    | "betweenness_centrality"
2018                    | "eigenvector_centrality"
2019                    | "harmonic_centrality"
2020                    | "katz_centrality"
2021                    | "avg_neighbor"
2022                    | "max_neighbor"
2023                    | "sum_neighbor"
2024            ) =>
2025        {
2026            FeatureValue::Null
2027        }
2028        _ => FeatureValue::Null,
2029    }
2030}
2031
2032fn collect_is_ref_inputs(
2033    rule: &FixpointRulePlan,
2034    clause_index: usize,
2035    delta_batch: &RecordBatch,
2036    row_idx: usize,
2037    registry: &Arc<DerivedScanRegistry>,
2038) -> Vec<ProofTerm> {
2039    let clause = match rule.clauses.get(clause_index) {
2040        Some(c) => c,
2041        None => return vec![],
2042    };
2043
2044    let mut inputs = Vec::new();
2045    let delta_schema = delta_batch.schema();
2046
2047    for binding in &clause.is_ref_bindings {
2048        if binding.negated {
2049            continue;
2050        }
2051        if binding.provenance_join_cols.is_empty() {
2052            continue;
2053        }
2054
2055        // Extract body-side values from the delta row for each provenance join col.
2056        let body_values: Vec<(String, ScalarKey)> = binding
2057            .provenance_join_cols
2058            .iter()
2059            .filter_map(|(body_col, _derived_col)| {
2060                let col_idx = delta_schema
2061                    .fields()
2062                    .iter()
2063                    .position(|f| f.name() == body_col)?;
2064                let key = extract_scalar_key(delta_batch, &[col_idx], row_idx);
2065                Some((body_col.clone(), key.into_iter().next()?))
2066            })
2067            .collect();
2068
2069        if body_values.len() != binding.provenance_join_cols.len() {
2070            continue;
2071        }
2072
2073        // Read current data from the registry entry for this IS-ref's rule.
2074        let entry = match registry.get(binding.derived_scan_index) {
2075            Some(e) => e,
2076            None => continue,
2077        };
2078        let source_batches = entry.data.read();
2079        let source_schema = &entry.schema;
2080
2081        // Find matching source rows and hash them.
2082        for src_batch in source_batches.iter() {
2083            let all_src_indices: Vec<usize> = (0..src_batch.num_columns()).collect();
2084            for src_row in 0..src_batch.num_rows() {
2085                let matches = binding.provenance_join_cols.iter().enumerate().all(
2086                    |(i, (_body_col, derived_col))| {
2087                        let src_col_idx = source_schema
2088                            .fields()
2089                            .iter()
2090                            .position(|f| f.name() == derived_col);
2091                        match src_col_idx {
2092                            Some(idx) => {
2093                                let src_key = extract_scalar_key(src_batch, &[idx], src_row);
2094                                src_key.first() == Some(&body_values[i].1)
2095                            }
2096                            None => false,
2097                        }
2098                    },
2099                );
2100                if matches {
2101                    let fact_hash = format!(
2102                        "{:?}",
2103                        extract_scalar_key(src_batch, &all_src_indices, src_row)
2104                    )
2105                    .into_bytes();
2106                    inputs.push(ProofTerm {
2107                        source_rule: binding.rule_name.clone(),
2108                        base_fact_id: fact_hash,
2109                    });
2110                }
2111            }
2112        }
2113    }
2114
2115    inputs
2116}
2117
2118/// Phase D D-C0: per-body-row variant of [`collect_is_ref_inputs`] used
2119/// to pre-populate `FoldExec`'s `body_support_map` for TopKProofs MNOR.
2120///
2121/// At FOLD time, the rule's own facts haven't been recorded in the
2122/// `ProvenanceStore` yet (`record_provenance` runs after fact
2123/// materialization, and is keyed by post-YIELD hashes anyway), so the
2124/// support set for each pre-fold body row must be reconstructed
2125/// directly from the rule's IS-ref bindings + the source rules'
2126/// registry data.
2127///
2128/// We don't know which clause produced each body row at this point —
2129/// the iteration-local `clause_candidates` are gone — so we iterate
2130/// **every** clause's `is_ref_bindings`. The `provenance_join_cols`
2131/// schema check inside `collect_is_ref_inputs` already skips bindings
2132/// whose body columns aren't in the row's schema, so cross-clause
2133/// contamination is bounded (a binding only matches if its body cols
2134/// are present and the values join). For the single-clause TopKProofs
2135/// scenarios in TCK this is exact; for multi-clause TopKProofs rules
2136/// it is a conservative over-approximation that may inflate base-RV
2137/// counts (treated as the same RV under interning) — acceptable
2138/// because the DNF math collapses duplicates by inclusion-exclusion.
2139fn collect_is_ref_inputs_for_body_row(
2140    rule: &FixpointRulePlan,
2141    delta_batch: &RecordBatch,
2142    row_idx: usize,
2143    registry: &Arc<DerivedScanRegistry>,
2144) -> Vec<ProofTerm> {
2145    let mut combined: Vec<ProofTerm> = Vec::new();
2146    for clause_index in 0..rule.clauses.len() {
2147        let part = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
2148        combined.extend(part);
2149    }
2150    combined
2151}
2152
2153// ---------------------------------------------------------------------------
2154// Shared-lineage detection
2155// ---------------------------------------------------------------------------
2156
/// Per-row data collected during shared-lineage detection.
///
/// One value is built for each pre-fold row of a KEY group that
/// `detect_shared_lineage` reports as sharing proofs.
///
/// For context, `detect_shared_lineage` detects KEY groups in a rule's
/// pre-fold facts where recursive derivation may violate the independence
/// assumption of MNOR/MPROD. It uses a two-tier strategy:
/// 1. **Precise**: If the `ProvenanceStore` has populated `support` for facts
///    in the group, lineage is computed recursively (Cui & Widom 2000) and
///    checked for pairwise overlap. A shared base fact proves a dependency.
/// 2. **Structural fallback**: When lineage tracking is unavailable (e.g., the
///    IS-ref subject variables were projected away), it checks whether any fact
///    in a multi-row group was derived by a clause that has IS-ref bindings.
///    Recursive derivation through shared relations is a strong signal that
///    proof paths may share intermediate nodes.
#[expect(
    dead_code,
    reason = "Fields accessed via SharedLineageInfo in detect_shared_lineage"
)]
pub(crate) struct SharedGroupRow {
    /// Byte key of the row, as produced by `fact_hash_key` over all yield columns.
    pub fact_hash: Vec<u8>,
    /// Base-fact hashes this row derives from, as returned by `compute_lineage`.
    pub lineage: HashSet<Vec<u8>>,
}
2179
/// Information about groups with shared proofs, returned by `detect_shared_lineage`.
///
/// Only KEY groups for which sharing was detected appear in the map; groups
/// that passed the check are absent.
pub(crate) struct SharedLineageInfo {
    /// KEY group → rows with their base fact sets.
    pub shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>>,
}
2185
2186/// Build a byte key that uniquely identifies a row across all columns.
2187pub(crate) fn fact_hash_key(batch: &RecordBatch, all_indices: &[usize], row_idx: usize) -> Vec<u8> {
2188    format!("{:?}", extract_scalar_key(batch, all_indices, row_idx)).into_bytes()
2189}
2190
2191/// Emits at most one `SharedProbabilisticDependency` warning per rule.
2192/// Returns `Some(SharedLineageInfo)` if any group has shared proofs.
2193fn detect_shared_lineage(
2194    rule: &FixpointRulePlan,
2195    pre_fold_facts: &[RecordBatch],
2196    tracker: &Arc<ProvenanceStore>,
2197    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2198    semiring_kind: SemiringKind,
2199) -> Option<SharedLineageInfo> {
2200    use uni_locy::{RuntimeWarning, RuntimeWarningCode};
2201
2202    // Only check rules with probability-domain fold bindings. M3:
2203    // dispatches via the `LocyAggregate` trait so user-authored
2204    // probability aggregates participate automatically — selected by the
2205    // trait's `is_probability_aggregate()` flag, not by hardcoded name.
2206    let has_prob_fold = rule
2207        .fold_bindings
2208        .iter()
2209        .any(|fb| fb.aggregate.is_probability_aggregate());
2210    if !has_prob_fold {
2211        return None;
2212    }
2213
2214    // Group facts by KEY columns.
2215    let key_indices = &rule.key_column_indices;
2216    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2217
2218    let mut groups: HashMap<Vec<ScalarKey>, Vec<Vec<u8>>> = HashMap::new();
2219    for batch in pre_fold_facts {
2220        for row_idx in 0..batch.num_rows() {
2221            let key = extract_scalar_key(batch, key_indices, row_idx);
2222            let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
2223            groups.entry(key).or_default().push(fact_hash);
2224        }
2225    }
2226
2227    let mut shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>> = HashMap::new();
2228    let mut any_shared = false;
2229
2230    // Check each group with ≥2 rows.
2231    for (key, fact_hashes) in &groups {
2232        if fact_hashes.len() < 2 {
2233            continue;
2234        }
2235
2236        // Tier 1: precise base-fact overlap detection via tracker inputs.
2237        let mut has_inputs = false;
2238        let mut per_row_bases: Vec<HashSet<Vec<u8>>> = Vec::new();
2239        for fh in fact_hashes {
2240            let bases = compute_lineage(fh, tracker, &mut HashSet::new());
2241            if let Some(entry) = tracker.lookup(fh)
2242                && !entry.support.is_empty()
2243            {
2244                has_inputs = true;
2245            }
2246            per_row_bases.push(bases);
2247        }
2248
2249        let shared_found = if has_inputs {
2250            // At least some facts have tracked inputs — do precise comparison.
2251            let mut found = false;
2252            'outer: for i in 0..per_row_bases.len() {
2253                for j in (i + 1)..per_row_bases.len() {
2254                    if !per_row_bases[i].is_disjoint(&per_row_bases[j]) {
2255                        found = true;
2256                        break 'outer;
2257                    }
2258                }
2259            }
2260            found
2261        } else {
2262            // Tier 2: structural fallback — check if any fact in the group was
2263            // derived by a clause with IS-ref bindings (recursive derivation).
2264            fact_hashes.iter().any(|fh| {
2265                tracker.lookup(fh).is_some_and(|entry| {
2266                    rule.clauses
2267                        .get(entry.clause_index)
2268                        .is_some_and(|clause| clause.is_ref_bindings.iter().any(|b| !b.negated))
2269                })
2270            })
2271        };
2272
2273        if shared_found {
2274            any_shared = true;
2275            // Collect the group rows with their base facts for BDD use.
2276            let rows: Vec<SharedGroupRow> = fact_hashes
2277                .iter()
2278                .zip(per_row_bases)
2279                .map(|(fh, bases)| SharedGroupRow {
2280                    fact_hash: fh.clone(),
2281                    lineage: bases,
2282                })
2283                .collect();
2284            shared_groups.insert(key.clone(), rows);
2285        }
2286    }
2287
2288    // Phase 5: Cross-group correlation warning.
2289    // Check if any IS-ref input fact appears in multiple KEY groups.
2290    // This is independent of within-group sharing: even rules whose KEY groups
2291    // each have only one post-fold row can exhibit cross-group correlation when
2292    // different groups consume the same IS-ref base fact.
2293    {
2294        let mut input_to_groups: HashMap<Vec<u8>, HashSet<Vec<ScalarKey>>> = HashMap::new();
2295        for (key, fact_hashes) in &groups {
2296            for fh in fact_hashes {
2297                if let Some(entry) = tracker.lookup(fh) {
2298                    for input in &entry.support {
2299                        input_to_groups
2300                            .entry(input.base_fact_id.clone())
2301                            .or_default()
2302                            .insert(key.clone());
2303                    }
2304                }
2305            }
2306        }
2307        let has_cross_group = input_to_groups.values().any(|g| g.len() > 1);
2308        if has_cross_group && let Ok(mut warnings) = warnings_slot.write() {
2309            let already_warned = warnings.iter().any(|w| {
2310                w.code == RuntimeWarningCode::CrossGroupCorrelationNotExact
2311                    && w.rule_name == rule.name
2312            });
2313            if !already_warned {
2314                // Phase D F3: pick one canonical example of a shared
2315                // input fact and the KEY groups it bridges, so users
2316                // can correlate the warning with EXPLAIN output.
2317                let example =
2318                    input_to_groups
2319                        .iter()
2320                        .find(|(_, g)| g.len() > 1)
2321                        .map(|(input, groups)| {
2322                            let short = input
2323                                .iter()
2324                                .take(8)
2325                                .map(|b| format!("{:02x}", b))
2326                                .collect::<String>();
2327                            let mut group_strs: Vec<String> =
2328                                groups.iter().map(|k| format!("{:?}", k)).collect();
2329                            group_strs.sort();
2330                            format!(
2331                                "input {} shared by groups [{}]",
2332                                short,
2333                                group_strs.join(", ")
2334                            )
2335                        });
2336                // Phase D F3 case 1 BDD-time deepening: count distinct
2337                // base facts (= BDD variables) that cross groups, so the
2338                // warning carries structured metadata mirroring
2339                // `BddLimitExceeded`. Users can correlate
2340                // `variable_count` with EXPLAIN's BDD output.
2341                let shared_variable_count =
2342                    input_to_groups.values().filter(|g| g.len() > 1).count();
2343                warnings.push(RuntimeWarning {
2344                    code: RuntimeWarningCode::CrossGroupCorrelationNotExact,
2345                    message: format!(
2346                        "Rule '{}': {} IS-ref base fact(s) are shared across different \
2347                         KEY groups. BDD corrects per-group probabilities but cannot \
2348                         account for cross-group correlations.",
2349                        rule.name, shared_variable_count
2350                    ),
2351                    rule_name: rule.name.clone(),
2352                    variable_count: Some(shared_variable_count),
2353                    key_group: example,
2354                });
2355            }
2356        }
2357    }
2358
2359    if any_shared {
2360        // Phase D D-C0b: under `SemiringKind::TopKProofs`, the FOLD-time
2361        // DNF inclusion-exclusion math (shipped in D-C0) auto-corrects
2362        // for within-group base-fact sharing — the "Results may
2363        // overestimate" premise of `SharedProbabilisticDependency`
2364        // is no longer true. Suppress the warning under TopK; users
2365        // who chose TopKProofs explicitly opted into the
2366        // correctness-preserving path. Cross-group correlation
2367        // (`CrossGroupCorrelationNotExact`) still fires above because
2368        // D-C0 doesn't span KEY-group boundaries.
2369        let suppress_under_topk = matches!(semiring_kind, SemiringKind::TopKProofs { .. });
2370        if !suppress_under_topk && let Ok(mut warnings) = warnings_slot.write() {
2371            let already_warned = warnings.iter().any(|w| {
2372                w.code == RuntimeWarningCode::SharedProbabilisticDependency
2373                    && w.rule_name == rule.name
2374            });
2375            if !already_warned {
2376                warnings.push(RuntimeWarning {
2377                    code: RuntimeWarningCode::SharedProbabilisticDependency,
2378                    message: format!(
2379                        "Rule '{}' aggregates with MNOR/MPROD but some proof paths \
2380                         share intermediate facts, violating the independence assumption. \
2381                         Results may overestimate probability.",
2382                        rule.name
2383                    ),
2384                    rule_name: rule.name.clone(),
2385                    variable_count: None,
2386                    key_group: None,
2387                });
2388            }
2389        }
2390        Some(SharedLineageInfo { shared_groups })
2391    } else {
2392        None
2393    }
2394}
2395
2396/// Record provenance and detect shared proofs for non-recursive strata.
2397///
2398/// Non-recursive rules are evaluated in a single pass (no fixpoint loop), so
2399/// the regular `record_provenance` + `detect_shared_lineage` path is never hit.
2400/// This function bridges that gap by recording a `ProvenanceAnnotation` for every
2401/// fact produced by each clause and then running the same two-tier detection
2402/// logic used by the recursive path.
2403#[allow(
2404    clippy::too_many_arguments,
2405    reason = "context bundle would be over-engineering for one call site"
2406)]
2407pub(crate) async fn record_and_detect_lineage_nonrecursive(
2408    rule: &FixpointRulePlan,
2409    tagged_clause_facts: &[(usize, Vec<RecordBatch>)],
2410    tracker: &Arc<ProvenanceStore>,
2411    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2412    registry: &Arc<DerivedScanRegistry>,
2413    top_k_proofs: usize,
2414    classifiers: ClassifierRefs<'_>,
2415    semiring_kind: SemiringKind,
2416) -> Option<SharedLineageInfo> {
2417    let classifier_registry = classifiers.registry;
2418    let classifier_cache = classifiers.cache;
2419    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2420
2421    // Pre-compute base fact probabilities for top-k mode.
2422    let base_probs = if top_k_proofs > 0 {
2423        tracker.base_fact_probs()
2424    } else {
2425        HashMap::new()
2426    };
2427
2428    let mut topk_acc = TopKProofAccumulator::new();
2429
2430    // Record provenance for each clause's facts.
2431    for (clause_index, batches) in tagged_clause_facts {
2432        for batch in batches {
2433            for row_idx in 0..batch.num_rows() {
2434                let row_hash = fact_hash_key(batch, &all_indices, row_idx);
2435                let fact_row = batch_row_to_value_map(batch, row_idx);
2436
2437                let support = collect_is_ref_inputs(rule, *clause_index, batch, row_idx, registry);
2438
2439                let proof_probability = if top_k_proofs > 0 {
2440                    compute_proof_probability(&support, &base_probs)
2441                } else {
2442                    None
2443                };
2444
2445                let entry = ProvenanceAnnotation {
2446                    rule_name: rule.name.clone(),
2447                    clause_index: *clause_index,
2448                    support,
2449                    along_values: {
2450                        let along_names: Vec<String> = rule
2451                            .clauses
2452                            .get(*clause_index)
2453                            .map(|c| c.along_bindings.clone())
2454                            .unwrap_or_default();
2455                        along_names
2456                            .iter()
2457                            .filter_map(|name| {
2458                                fact_row.get(name).map(|v| (name.clone(), v.clone()))
2459                            })
2460                            .collect()
2461                    },
2462                    iteration: 0,
2463                    fact_row: fact_row.clone(),
2464                    proof_probability,
2465                    neural_calls: collect_neural_calls_for_row(
2466                        rule,
2467                        *clause_index,
2468                        &fact_row,
2469                        classifier_registry,
2470                        classifier_cache,
2471                        classifiers.provenance_store,
2472                    )
2473                    .await,
2474                };
2475                if top_k_proofs > 0 {
2476                    topk_acc.accumulate(&entry, &row_hash);
2477                    tracker.record_top_k(row_hash, entry, top_k_proofs);
2478                } else {
2479                    tracker.record(row_hash, entry);
2480                }
2481            }
2482        }
2483    }
2484
2485    topk_acc.emit_warning_if_any(rule, top_k_proofs, warnings_slot);
2486
2487    // Flatten all clause facts and run detection.
2488    let all_facts: Vec<RecordBatch> = tagged_clause_facts
2489        .iter()
2490        .flat_map(|(_, batches)| batches.iter().cloned())
2491        .collect();
2492    detect_shared_lineage(rule, &all_facts, tracker, warnings_slot, semiring_kind)
2493}
2494
2495/// Apply exact weighted model counting (WMC) for shared-lineage groups.
2496///
2497/// Replaces multiple rows in groups with shared lineage with a single
2498/// representative row whose PROB column carries the BDD-computed exact
2499/// probability (Sang et al. 2005). For groups that exceed
2500/// `max_bdd_variables`, rows are left unchanged and a `BddLimitExceeded`
2501/// warning is emitted.
2502pub(crate) fn apply_exact_wmc(
2503    pre_fold_facts: Vec<RecordBatch>,
2504    rule: &FixpointRulePlan,
2505    shared_info: &SharedLineageInfo,
2506    tracker: &Arc<ProvenanceStore>,
2507    max_bdd_variables: usize,
2508    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2509    approximate_slot: &Arc<StdRwLock<HashMap<String, Vec<String>>>>,
2510) -> DFResult<Vec<RecordBatch>> {
2511    use crate::query::df_graph::locy_bdd::{SemiringOp, weighted_model_count};
2512    use uni_locy::{RuntimeWarning, RuntimeWarningCode};
2513
2514    // Find the probability-domain fold binding to know which column
2515    // to overwrite. M3: dispatch through the `LocyAggregate` trait so
2516    // user-authored probability aggregates participate.
2517    let prob_fold = rule
2518        .fold_bindings
2519        .iter()
2520        .find(|fb| fb.aggregate.is_probability_aggregate());
2521    let prob_fold = match prob_fold {
2522        Some(f) => f,
2523        None => return Ok(pre_fold_facts),
2524    };
2525    let semiring_op = if prob_fold.aggregate.is_noisy_or() {
2526        SemiringOp::Disjunction
2527    } else {
2528        SemiringOp::Conjunction
2529    };
2530    let prob_col_idx = prob_fold.input_col_index;
2531    let prob_col_name = rule.yield_schema.field(prob_col_idx).name().clone();
2532
2533    let key_indices = &rule.key_column_indices;
2534    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2535
2536    // Build a set of shared group keys for quick lookup.
2537    let shared_keys: HashSet<Vec<ScalarKey>> = shared_info.shared_groups.keys().cloned().collect();
2538
2539    // Phase 1: Collect all rows for each shared KEY group across all batches.
2540    // Store (batch_index, row_index) pairs for each group.
2541    struct GroupAccum {
2542        base_facts: Vec<HashSet<Vec<u8>>>,
2543        base_probs: HashMap<Vec<u8>, f64>,
2544        /// First occurrence: (batch_index, row_index) — used as representative.
2545        representative: (usize, usize),
2546        row_locations: Vec<(usize, usize)>,
2547    }
2548
2549    let mut group_accums: HashMap<Vec<ScalarKey>, GroupAccum> = HashMap::new();
2550    let mut non_shared_rows: Vec<(usize, usize)> = Vec::new(); // (batch_idx, row_idx)
2551
2552    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
2553        for row_idx in 0..batch.num_rows() {
2554            let key = extract_scalar_key(batch, key_indices, row_idx);
2555            if shared_keys.contains(&key) {
2556                let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
2557                let bases = compute_lineage(&fact_hash, tracker, &mut HashSet::new());
2558
2559                let accum = group_accums.entry(key).or_insert_with(|| GroupAccum {
2560                    base_facts: Vec::new(),
2561                    base_probs: HashMap::new(),
2562                    representative: (batch_idx, row_idx),
2563                    row_locations: Vec::new(),
2564                });
2565
2566                // Look up probabilities for base facts.
2567                for bf in &bases {
2568                    if !accum.base_probs.contains_key(bf)
2569                        && let Some(entry) = tracker.lookup(bf)
2570                        && let Some(val) = entry.fact_row.get(&prob_col_name)
2571                        && let Some(p) = value_to_f64(val)
2572                    {
2573                        accum.base_probs.insert(bf.clone(), p);
2574                    }
2575                }
2576
2577                accum.base_facts.push(bases);
2578                accum.row_locations.push((batch_idx, row_idx));
2579            } else {
2580                non_shared_rows.push((batch_idx, row_idx));
2581            }
2582        }
2583    }
2584
2585    // Phase 2: Compute BDD for each shared group (across all batches).
2586    // Track which (batch_idx, row_idx) pairs to keep vs drop.
2587    let mut keep_rows: HashSet<(usize, usize)> = HashSet::new();
2588    // Map of (batch_idx, row_idx) → overridden PROB value (for BDD-succeeded groups).
2589    let mut overrides: HashMap<(usize, usize), f64> = HashMap::new();
2590
2591    // All non-shared rows are kept.
2592    for &loc in &non_shared_rows {
2593        keep_rows.insert(loc);
2594    }
2595
2596    for (key, accum) in &group_accums {
2597        let bdd_result = weighted_model_count(
2598            &accum.base_facts,
2599            &accum.base_probs,
2600            semiring_op,
2601            max_bdd_variables,
2602        );
2603
2604        if bdd_result.approximated {
2605            // Emit BddLimitExceeded warning (one per key group).
2606            if let Ok(mut warnings) = warnings_slot.write() {
2607                let key_desc = format!("{key:?}");
2608                let already_warned = warnings.iter().any(|w| {
2609                    w.code == RuntimeWarningCode::BddLimitExceeded
2610                        && w.rule_name == rule.name
2611                        && w.key_group.as_deref() == Some(&key_desc)
2612                });
2613                if !already_warned {
2614                    warnings.push(RuntimeWarning {
2615                        code: RuntimeWarningCode::BddLimitExceeded,
2616                        message: format!(
2617                            "Rule '{}': BDD variable limit exceeded ({} > {}). \
2618                             Falling back to independence-mode result.",
2619                            rule.name, bdd_result.variable_count, max_bdd_variables
2620                        ),
2621                        rule_name: rule.name.clone(),
2622                        variable_count: Some(bdd_result.variable_count),
2623                        key_group: Some(key_desc),
2624                    });
2625                }
2626            }
2627            if let Ok(mut approx) = approximate_slot.write() {
2628                let key_desc = format!("{key:?}");
2629                approx.entry(rule.name.clone()).or_default().push(key_desc);
2630            }
2631            // Keep all rows unchanged.
2632            for &loc in &accum.row_locations {
2633                keep_rows.insert(loc);
2634            }
2635        } else {
2636            // BDD succeeded: keep one representative row with overridden PROB.
2637            keep_rows.insert(accum.representative);
2638            overrides.insert(accum.representative, bdd_result.probability);
2639        }
2640    }
2641
2642    // Phase 3: Build output batches by filtering kept rows per batch.
2643    let mut result_batches = Vec::new();
2644    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
2645        let kept_indices: Vec<usize> = (0..batch.num_rows())
2646            .filter(|&row_idx| keep_rows.contains(&(batch_idx, row_idx)))
2647            .collect();
2648
2649        if kept_indices.is_empty() {
2650            continue;
2651        }
2652
2653        let indices = arrow::array::UInt32Array::from(
2654            kept_indices.iter().map(|&i| i as u32).collect::<Vec<_>>(),
2655        );
2656        let mut columns: Vec<arrow::array::ArrayRef> = batch
2657            .columns()
2658            .iter()
2659            .map(|col| arrow::compute::take(col, &indices, None))
2660            .collect::<Result<Vec<_>, _>>()
2661            .map_err(arrow_err)?;
2662
2663        // Check if any kept rows have PROB overrides.
2664        let override_map: Vec<Option<f64>> = kept_indices
2665            .iter()
2666            .map(|&row_idx| overrides.get(&(batch_idx, row_idx)).copied())
2667            .collect();
2668
2669        if override_map.iter().any(|o| o.is_some()) && prob_col_idx < columns.len() {
2670            // Rebuild the PROB column with overrides.
2671            let existing_prob = columns[prob_col_idx]
2672                .as_any()
2673                .downcast_ref::<arrow::array::Float64Array>();
2674            let new_values: Vec<f64> = override_map
2675                .iter()
2676                .enumerate()
2677                .map(|(i, ov)| match ov {
2678                    Some(p) => *p,
2679                    None => existing_prob.map(|arr| arr.value(i)).unwrap_or(0.0),
2680                })
2681                .collect();
2682            columns[prob_col_idx] = Arc::new(arrow::array::Float64Array::from(new_values));
2683        }
2684
2685        let result_batch = RecordBatch::try_new(batch.schema(), columns).map_err(arrow_err)?;
2686        result_batches.push(result_batch);
2687    }
2688
2689    Ok(result_batches)
2690}
2691
2692/// Extract an f64 from a `Value`, supporting Float and Int.
2693fn value_to_f64(val: &uni_common::Value) -> Option<f64> {
2694    match val {
2695        uni_common::Value::Float(f) => Some(*f),
2696        uni_common::Value::Int(i) => Some(*i as f64),
2697        _ => None,
2698    }
2699}
2700
2701/// Compute the lineage of a derived fact (Cui & Widom 2000).
2702///
2703/// Recursively traverses the provenance store to collect the set of base-level
2704/// fact hashes that contribute to this derivation. A base fact is one with no
2705/// IS-ref support (a graph-level fact). Intermediate facts are expanded
2706/// transitively through the store.
2707fn compute_lineage(
2708    fact_hash: &[u8],
2709    tracker: &Arc<ProvenanceStore>,
2710    visited: &mut HashSet<Vec<u8>>,
2711) -> HashSet<Vec<u8>> {
2712    if !visited.insert(fact_hash.to_vec()) {
2713        return HashSet::new(); // Cycle guard.
2714    }
2715
2716    match tracker.lookup(fact_hash) {
2717        Some(entry) if !entry.support.is_empty() => {
2718            let mut bases = HashSet::new();
2719            for input in &entry.support {
2720                let child_bases = compute_lineage(&input.base_fact_id, tracker, visited);
2721                bases.extend(child_bases);
2722            }
2723            bases
2724        }
2725        _ => {
2726            // Base fact (no tracker entry or no inputs).
2727            let mut set = HashSet::new();
2728            set.insert(fact_hash.to_vec());
2729            set
2730        }
2731    }
2732}
2733
2734/// Determine which clause produced a given row by checking each clause's candidates.
2735///
2736/// Returns the index of the first clause whose candidates contain a matching row.
2737/// Falls back to 0 if no match is found.
2738fn find_clause_for_row(
2739    delta_batch: &RecordBatch,
2740    row_idx: usize,
2741    all_indices: &[usize],
2742    clause_candidates: &[Vec<RecordBatch>],
2743) -> usize {
2744    let target_key = extract_scalar_key(delta_batch, all_indices, row_idx);
2745    for (clause_idx, batches) in clause_candidates.iter().enumerate() {
2746        for batch in batches {
2747            if batch.num_columns() != all_indices.len() {
2748                continue;
2749            }
2750            for r in 0..batch.num_rows() {
2751                if extract_scalar_key(batch, all_indices, r) == target_key {
2752                    return clause_idx;
2753                }
2754            }
2755        }
2756    }
2757    0
2758}
2759
2760/// Convert a single row from a `RecordBatch` at `row_idx` into a `HashMap<String, Value>`.
2761fn batch_row_to_value_map(
2762    batch: &RecordBatch,
2763    row_idx: usize,
2764) -> std::collections::HashMap<String, Value> {
2765    use uni_store::storage::arrow_convert::arrow_to_value;
2766
2767    let schema = batch.schema();
2768    schema
2769        .fields()
2770        .iter()
2771        .enumerate()
2772        .map(|(col_idx, field)| {
2773            let col = batch.column(col_idx);
2774            let val = arrow_to_value(col.as_ref(), row_idx, None);
2775            (field.name().clone(), val)
2776        })
2777        .collect()
2778}
2779
2780/// Filter `batches` to exclude rows where `left_col` VID appears in `neg_facts[right_col]`.
2781///
2782/// Implements anti-join semantics for negated IS-refs (`n IS NOT rule`): keeps only
2783/// rows whose subject VID is NOT present in the negated rule's fully-converged facts.
2784pub fn apply_anti_join(
2785    batches: Vec<RecordBatch>,
2786    neg_facts: &[RecordBatch],
2787    left_col: &str,
2788    right_col: &str,
2789) -> datafusion::error::Result<Vec<RecordBatch>> {
2790    use arrow::compute::filter_record_batch;
2791    use arrow_array::{Array as _, BooleanArray, UInt64Array};
2792
2793    // Collect right-side VIDs from the negated rule's derived facts.
2794    let mut banned: std::collections::HashSet<u64> = std::collections::HashSet::new();
2795    for batch in neg_facts {
2796        let Ok(idx) = batch.schema().index_of(right_col) else {
2797            continue;
2798        };
2799        let arr = batch.column(idx);
2800        let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2801            continue;
2802        };
2803        for i in 0..vids.len() {
2804            if !vids.is_null(i) {
2805                banned.insert(vids.value(i));
2806            }
2807        }
2808    }
2809
2810    if banned.is_empty() {
2811        return Ok(batches);
2812    }
2813
2814    // Filter body batches: keep rows where left_col NOT IN banned.
2815    let mut result = Vec::new();
2816    for batch in batches {
2817        let Ok(idx) = batch.schema().index_of(left_col) else {
2818            result.push(batch);
2819            continue;
2820        };
2821        let arr = batch.column(idx);
2822        let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2823            result.push(batch);
2824            continue;
2825        };
2826        let keep: Vec<bool> = (0..vids.len())
2827            .map(|i| vids.is_null(i) || !banned.contains(&vids.value(i)))
2828            .collect();
2829        let keep_arr = BooleanArray::from(keep);
2830        let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
2831        if filtered.num_rows() > 0 {
2832            result.push(filtered);
2833        }
2834    }
2835    Ok(result)
2836}
2837
2838// ─── Phase B Slice 3: neural-model invocation pass ───────────────────────
2839//
2840// `apply_model_invocations` runs every `ModelInvocation` lifted from a
2841// clause's YIELD items against the body's output batches. For each
2842// invocation it:
2843//
2844//   1. Resolves each `feature_expr` to a column in the batch — Slice 3
2845//      supports plain `Expr::Variable("name")` references; richer
2846//      expressions (property access, nested calls) are deferred.
2847//   2. Builds one `ClassifyInput` per row keyed by the model's input
2848//      binding names.
2849//   3. Issues a single batched `NeuralClassifier::classify` call.
2850//   4. Appends the resulting `Float64Array` as a new column matching
2851//      `invocation.output_column`.
2852//
2853// Errors:
2854//   * `UnknownNeuralModel`: the model name isn't in the registry.
2855//   * Mismatched feature-expr / column: returned as a DataFusion
2856//     Execution error.
2857
2858#[allow(clippy::too_many_arguments)]
2859pub(crate) async fn apply_model_invocations(
2860    batches: Vec<RecordBatch>,
2861    invocations: &[uni_locy::ModelInvocation],
2862    registry: &Arc<ClassifierRegistry>,
2863    cache: Option<&Arc<uni_locy::ModelInvocationCache>>,
2864    provenance_store: Option<&Arc<uni_locy::NeuralProvenanceStore>>,
2865    path_context_handles: &HashMap<
2866        String,
2867        crate::query::df_graph::locy_model_invoke::PathContextHandle,
2868    >,
2869    xervo_runtime: &crate::query::df_graph::locy_model_invoke::XervoRuntimeHandle,
2870    graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
2871) -> DFResult<Vec<RecordBatch>> {
2872    use uni_locy::ClassifyInput;
2873    if batches.is_empty() || invocations.is_empty() {
2874        return Ok(batches);
2875    }
2876    // Phase D D2: pre-embed all unique `semantic_match` query literals
2877    // once per call. Resolvers below lower each `semantic_match(prop,
2878    // 'text')` into a `SimilarTo { left: prop_col, right: Const(Vector) }`.
2879    let semantic_match_embeddings =
2880        pre_embed_semantic_match_queries(invocations, xervo_runtime).await?;
2881    // Phase D D1 graph-structural: pre-compute topology scores and
2882    // neighbor-property maps for every distinct (fn_name, args) tuple
2883    // appearing in any FEATURE FunctionCall. One pass per call; reused
2884    // across every row of every batch.
2885    let graph_feature_maps = precompute_graph_feature_maps(invocations, graph_algo).await?;
2886    let neighbor_feature_maps =
2887        precompute_neighbor_feature_maps(invocations, &batches, graph_algo).await?;
2888    let mut out_batches = Vec::with_capacity(batches.len());
2889    for batch in batches {
2890        let mut current = batch;
2891        for invocation in invocations {
2892            let classifier = registry.get(&invocation.model_name).ok_or_else(|| {
2893                datafusion::error::DataFusionError::Execution(format!(
2894                    "neural classifier '{}' not registered; \
2895                         add it to LocyConfig::classifier_registry",
2896                    invocation.model_name
2897                ))
2898            })?;
2899
2900            // Resolve each feature_expr to a per-row evaluator.
2901            // Supported shapes (validated at compile time by
2902            // `extract_model_invocations` / `validate_features`):
2903            //   * `Expr::Variable("name")` — direct column reference.
2904            //   * `Expr::Property(Variable(v), prop)` — looked up by the
2905            //     conventional `"v.prop"` column name materialized by
2906            //     the planner's `translate_property_access` pipeline.
2907            //   * `Expr::FunctionCall { name: "similar_to"|"semantic_match", ... }`
2908            //     — Phase D D1/D2 retrieval-backed feature; both args
2909            //     resolved to columns; UDF evaluated per row against
2910            //     the row's `uni_common::Value` payloads.
2911            let resolvers = build_feature_resolvers(
2912                &current,
2913                invocation,
2914                path_context_handles,
2915                &semantic_match_embeddings,
2916                &graph_feature_maps,
2917                &neighbor_feature_maps,
2918            )?;
2919
2920            // Build one ClassifyInput per row.
2921            let n_rows = current.num_rows();
2922            let mut inputs: Vec<ClassifyInput> = Vec::with_capacity(n_rows);
2923            let mut input_hashes: Vec<u64> = Vec::with_capacity(n_rows);
2924            for row_idx in 0..n_rows {
2925                let mut features = std::collections::HashMap::new();
2926                for resolver in &resolvers {
2927                    let value = resolver.eval_row(&current, row_idx)?;
2928                    features.insert(resolver.binding_name.clone(), value);
2929                }
2930                let input = ClassifyInput { features };
2931                input_hashes.push(input.stable_hash());
2932                inputs.push(input);
2933            }
2934
2935            // Slice 2: memoization. Split inputs into cache hits and
2936            // misses; only call `classify` on the misses, then weave
2937            // the cached values back in by original row index.
2938            let mut probs: Vec<f64> = vec![0.0; n_rows];
2939            let mut miss_inputs: Vec<ClassifyInput> = Vec::new();
2940            let mut miss_row_indices: Vec<usize> = Vec::new();
2941            if let Some(c) = cache {
2942                for (row_idx, h) in input_hashes.iter().enumerate() {
2943                    match c.get(&invocation.model_name, *h) {
2944                        Some(v) => probs[row_idx] = v,
2945                        None => {
2946                            miss_row_indices.push(row_idx);
2947                            miss_inputs.push(inputs[row_idx].clone());
2948                        }
2949                    }
2950                }
2951            } else {
2952                miss_row_indices = (0..n_rows).collect();
2953                miss_inputs = inputs.clone();
2954            }
2955
2956            // Phase C C-RawCalibratedSeparation: when a calibrator is
2957            // present, route through `raw_and_calibrated` so the
2958            // provenance store records both the base classifier's raw
2959            // output AND the post-calibrator value. Bare classifiers
2960            // (no calibrator) keep using `classify`. The downstream
2961            // `probs[row]` is always the *calibrated* value when
2962            // available — that's what the rule's PROB output column
2963            // and the memoization cache carry.
2964            let calibrator = classifier.get_calibrator();
2965            let (miss_raws, miss_calibrated) = if miss_inputs.is_empty() {
2966                (Vec::new(), Vec::new())
2967            } else if calibrator.is_some() {
2968                let pairs = classifier
2969                    .raw_and_calibrated(&miss_inputs)
2970                    .await
2971                    .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2972                if pairs.len() != miss_inputs.len() {
2973                    return Err(datafusion::error::DataFusionError::Execution(format!(
2974                        "classifier '{}' raw_and_calibrated returned {} outputs for {} inputs",
2975                        invocation.model_name,
2976                        pairs.len(),
2977                        miss_inputs.len()
2978                    )));
2979                }
2980                let raws: Vec<f64> = pairs.iter().map(|(r, _)| *r).collect();
2981                let cals: Vec<f64> = pairs.iter().map(|(r, c)| c.unwrap_or(*r)).collect();
2982                (raws, cals)
2983            } else {
2984                let r = classifier
2985                    .classify(&miss_inputs)
2986                    .await
2987                    .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2988                if r.len() != miss_inputs.len() {
2989                    return Err(datafusion::error::DataFusionError::Execution(format!(
2990                        "classifier '{}' returned {} outputs for {} inputs",
2991                        invocation.model_name,
2992                        r.len(),
2993                        miss_inputs.len()
2994                    )));
2995                }
2996                // No calibrator → raw == final.
2997                (r.clone(), r)
2998            };
2999            // The memoization cache stores the *final* (calibrated when
3000            // available) value, matching what downstream rules consume.
3001            // Track per-miss raws alongside so the provenance store
3002            // sees both. For cache hits, we don't have raws — the
3003            // provenance record for that row gets None for `raw` and
3004            // the cached value as `calibrated` (the only thing we
3005            // remembered). Future slice can extend the cache to carry
3006            // both; current behavior matches pre-fix correctness for
3007            // EXPLAIN of cache hits.
3008            let mut row_raw: Vec<Option<f64>> = vec![None; n_rows];
3009            for (i, &row_idx) in miss_row_indices.iter().enumerate() {
3010                probs[row_idx] = miss_calibrated[i];
3011                row_raw[row_idx] = Some(miss_raws[i]);
3012                if let Some(c) = cache {
3013                    c.insert(
3014                        &invocation.model_name,
3015                        input_hashes[row_idx],
3016                        miss_calibrated[i],
3017                    );
3018                }
3019            }
3020
3021            // Phase C B1-B3 follow-up: when a provenance store is
3022            // configured, record (raw, calibrated, confidence_band)
3023            // per row. With C-RawCalibratedSeparation (above),
3024            // `row_raw[i]` carries the *pre-calibrator* value when we
3025            // computed it on this call; `probs[i]` is the
3026            // post-calibrator value. `confidence_band` comes from the
3027            // active Calibrator's `confidence_band(p)`.
3028            if let Some(store) = provenance_store {
3029                for row_idx in 0..n_rows {
3030                    let calibrated_value = probs[row_idx];
3031                    let (raw_value, calibrated) = match (row_raw[row_idx], &calibrator) {
3032                        (Some(raw), Some(_)) => (raw, Some(calibrated_value)),
3033                        (Some(raw), None) => (raw, None),
3034                        // Cache hit: we only have the calibrated value.
3035                        // Report it as raw with `calibrated == raw` to
3036                        // preserve telemetry shape; document this in
3037                        // the field doc.
3038                        (None, _) => (
3039                            calibrated_value,
3040                            calibrator.as_ref().map(|_| calibrated_value),
3041                        ),
3042                    };
3043                    let band = calibrator
3044                        .as_ref()
3045                        .and_then(|c| c.confidence_band(calibrated_value));
3046                    store.record(
3047                        &invocation.model_name,
3048                        input_hashes[row_idx],
3049                        uni_locy::NeuralProvenanceRecord {
3050                            raw_probability: raw_value,
3051                            calibrated_probability: calibrated,
3052                            confidence_band: band,
3053                            // Phase 12 EXPLAIN follow-up: stash the
3054                            // per-binding `FeatureValue` map that fed
3055                            // the classifier so Mode B re-evaluation
3056                            // can surface graph-structural feature
3057                            // values without re-precomputing topology
3058                            // or neighbor maps (the hot-path data is
3059                            // authoritative).
3060                            feature_inputs: inputs[row_idx].features.clone(),
3061                        },
3062                    );
3063                }
3064            }
3065
3066            // Overwrite the placeholder column put there by the
3067            // compile-time YIELD rewrite. If the column doesn't yet
3068            // exist (defensive — shouldn't happen for well-formed
3069            // plans), fall back to appending.
3070            let out_col: Arc<dyn arrow_array::Array> =
3071                Arc::new(arrow_array::Float64Array::from(probs));
3072            let schema = current.schema();
3073            let target_idx = schema.index_of(&invocation.output_column).ok();
3074            let mut columns: Vec<Arc<dyn arrow_array::Array>> = current.columns().to_vec();
3075            let mut fields: Vec<Arc<arrow_schema::Field>> =
3076                schema.fields().iter().cloned().collect();
3077            match target_idx {
3078                Some(idx) => {
3079                    columns[idx] = out_col;
3080                    // Force the field's data type to Float64 in case
3081                    // the placeholder was inferred at a wider type.
3082                    fields[idx] = Arc::new(arrow_schema::Field::new(
3083                        &invocation.output_column,
3084                        arrow_schema::DataType::Float64,
3085                        true,
3086                    ));
3087                }
3088                None => {
3089                    columns.push(out_col);
3090                    fields.push(Arc::new(arrow_schema::Field::new(
3091                        &invocation.output_column,
3092                        arrow_schema::DataType::Float64,
3093                        true,
3094                    )));
3095                }
3096            }
3097            let new_schema = Arc::new(arrow_schema::Schema::new(fields));
3098            current = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
3099        }
3100        out_batches.push(current);
3101    }
3102    Ok(out_batches)
3103}
3104
3105/// Extract a [`uni_locy::FeatureValue`] from a column at a given row.
3106/// Conservative cast set matching the property-graph value types Locy
3107/// currently exposes; unsupported types fall back to `Null`.
3108/// Per-feature evaluator built once from a clause's `feature_exprs`
3109/// and reused for every row in the batch. Supports plain column
3110/// reads (`Direct`) and the Phase D D1/D2 retrieval-backed UDFs
3111/// (`SimilarTo`, `SemanticMatch`).
3112struct FeatureResolver {
3113    binding_name: String,
3114    kind: FeatureResolverKind,
3115}
3116
3117enum FeatureResolverKind {
3118    Direct(usize),
3119    SimilarTo {
3120        left: FeatureValueSrc,
3121        right: FeatureValueSrc,
3122    },
3123    /// Phase D D3 runtime: pull `column` from the source rule's
3124    /// derived facts via a pre-built `vid → FeatureValue` lookup. The
3125    /// `subject_col` is the index of `<subject_var>._vid` in the body
3126    /// batch; the lookup runs once per row.
3127    PathContext {
3128        subject_col: usize,
3129        vid_to_value: Arc<HashMap<u64, uni_locy::FeatureValue>>,
3130    },
3131    /// Phase D D1 graph-structural: look up the subject's pre-computed
3132    /// topology score (degree/pagerank/closeness). `subject_col` indexes
3133    /// the row's `<var>._vid`; `vid_to_score` is the whole-graph
3134    /// procedure output built once per `apply_model_invocations` call.
3135    GraphAlgoScore {
3136        subject_col: usize,
3137        vid_to_score: Arc<HashMap<u64, f64>>,
3138    },
3139    /// Phase D D1 graph-structural: aggregate a numeric property over
3140    /// each subject's one-hop outgoing neighborhood along a named edge
3141    /// type. `vid_to_values` maps subject vid → the list of numeric
3142    /// neighbor property values collected at precompute time; the
3143    /// per-row resolver applies the configured `op` (avg/max/sum).
3144    NeighborAggregate {
3145        subject_col: usize,
3146        op: NeighborAgg,
3147        vid_to_values: Arc<HashMap<u64, Vec<f64>>>,
3148    },
3149}
3150
/// Aggregation applied to the numeric property values gathered from a
/// subject's neighbors.
#[derive(Debug, Clone, Copy)]
enum NeighborAgg {
    /// Arithmetic mean of the values.
    Avg,
    /// Largest value.
    Max,
    /// Sum of the values.
    Sum,
}
3157
3158/// Direction for one-hop neighborhood traversal. Mirrors
3159/// `uni_store::storage::direction::Direction` but is independent so
3160/// the typecheck / planner layer doesn't depend on uni-store.
3161#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
3162enum NeighborDirection {
3163    Outgoing,
3164    Incoming,
3165    Both,
3166}
3167
3168impl NeighborDirection {
3169    fn store_directions(self) -> &'static [uni_store::storage::direction::Direction] {
3170        use uni_store::storage::direction::Direction;
3171        match self {
3172            NeighborDirection::Outgoing => &[Direction::Outgoing],
3173            NeighborDirection::Incoming => &[Direction::Incoming],
3174            NeighborDirection::Both => &[Direction::Outgoing, Direction::Incoming],
3175        }
3176    }
3177}
3178
3179impl NeighborAgg {
3180    fn from_fn_name(name: &str) -> Option<Self> {
3181        match name {
3182            "avg_neighbor" => Some(NeighborAgg::Avg),
3183            "max_neighbor" => Some(NeighborAgg::Max),
3184            "sum_neighbor" => Some(NeighborAgg::Sum),
3185            _ => None,
3186        }
3187    }
3188
3189    fn apply(self, values: &[f64]) -> Option<f64> {
3190        if values.is_empty() {
3191            return None;
3192        }
3193        match self {
3194            NeighborAgg::Avg => Some(values.iter().sum::<f64>() / values.len() as f64),
3195            NeighborAgg::Max => values.iter().copied().reduce(f64::max),
3196            NeighborAgg::Sum => Some(values.iter().sum()),
3197        }
3198    }
3199}
3200
3201/// One side of a `similar_to` feature: either a column index in the
3202/// per-row batch or a constant value lifted from a literal expression.
3203enum FeatureValueSrc {
3204    Col(usize),
3205    Const(uni_common::Value),
3206}
3207
3208impl FeatureValueSrc {
3209    fn resolve(&self, batch: &RecordBatch, row_idx: usize) -> uni_common::Value {
3210        match self {
3211            FeatureValueSrc::Col(idx) => extract_common_value(batch.column(*idx).as_ref(), row_idx),
3212            FeatureValueSrc::Const(v) => v.clone(),
3213        }
3214    }
3215}
3216
3217impl FeatureResolver {
3218    fn eval_row(&self, batch: &RecordBatch, row_idx: usize) -> DFResult<uni_locy::FeatureValue> {
3219        match &self.kind {
3220            FeatureResolverKind::Direct(idx) => {
3221                Ok(extract_feature_value(batch.column(*idx).as_ref(), row_idx))
3222            }
3223            FeatureResolverKind::SimilarTo { left, right } => {
3224                let lv = left.resolve(batch, row_idx);
3225                let rv = right.resolve(batch, row_idx);
3226                match crate::query::similar_to::eval_similar_to_pure(&lv, &rv) {
3227                    Ok(uni_common::Value::Float(f)) => Ok(uni_locy::FeatureValue::Float(f)),
3228                    Ok(_) => Ok(uni_locy::FeatureValue::Null),
3229                    Err(e) => Err(datafusion::error::DataFusionError::Execution(format!(
3230                        "similar_to UDF failed: {e}"
3231                    ))),
3232                }
3233            }
3234            FeatureResolverKind::PathContext {
3235                subject_col,
3236                vid_to_value,
3237            } => {
3238                let col = batch.column(*subject_col);
3239                if col.is_null(row_idx) {
3240                    return Ok(uni_locy::FeatureValue::Null);
3241                }
3242                if let Some(arr) = col.as_any().downcast_ref::<arrow_array::UInt64Array>() {
3243                    let vid = arr.value(row_idx);
3244                    Ok(vid_to_value
3245                        .get(&vid)
3246                        .cloned()
3247                        .unwrap_or(uni_locy::FeatureValue::Null))
3248                } else if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3249                    let vid = arr.value(row_idx) as u64;
3250                    Ok(vid_to_value
3251                        .get(&vid)
3252                        .cloned()
3253                        .unwrap_or(uni_locy::FeatureValue::Null))
3254                } else {
3255                    Ok(uni_locy::FeatureValue::Null)
3256                }
3257            }
3258            FeatureResolverKind::GraphAlgoScore {
3259                subject_col,
3260                vid_to_score,
3261            } => {
3262                let col = batch.column(*subject_col);
3263                if col.is_null(row_idx) {
3264                    return Ok(uni_locy::FeatureValue::Null);
3265                }
3266                let vid_opt: Option<u64> = if let Some(arr) =
3267                    col.as_any().downcast_ref::<arrow_array::UInt64Array>()
3268                {
3269                    Some(arr.value(row_idx))
3270                } else if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3271                    Some(arr.value(row_idx) as u64)
3272                } else {
3273                    // Fallback: subject column carries a Node-encoded
3274                    // `uni_common::Value` (LargeBinary via codec). Decode
3275                    // and pull the VID. This is the common case for
3276                    // bare-variable subjects where no `_vid` hidden
3277                    // column was materialized.
3278                    match extract_common_value(col.as_ref(), row_idx) {
3279                        uni_common::Value::Node(n) => Some(n.vid.as_u64()),
3280                        uni_common::Value::Int(i) => Some(i as u64),
3281                        _ => None,
3282                    }
3283                };
3284                Ok(vid_opt
3285                    .and_then(|v| vid_to_score.get(&v).copied())
3286                    .map(uni_locy::FeatureValue::Float)
3287                    .unwrap_or(uni_locy::FeatureValue::Null))
3288            }
3289            FeatureResolverKind::NeighborAggregate {
3290                subject_col,
3291                op,
3292                vid_to_values,
3293            } => {
3294                let vid_opt = extract_vid_from_column(batch.column(*subject_col).as_ref(), row_idx);
3295                Ok(vid_opt
3296                    .and_then(|v| vid_to_values.get(&v))
3297                    .and_then(|values| op.apply(values))
3298                    .map(uni_locy::FeatureValue::Float)
3299                    .unwrap_or(uni_locy::FeatureValue::Null))
3300            }
3301        }
3302    }
3303}
3304
3305/// Extract a node VID from a per-row batch column. Handles the three
3306/// common shapes: `_vid` UInt64 columns, Int64 columns, and Node-encoded
3307/// LargeBinary columns (the standard `uni_common::Value::Node` codec
3308/// representation for a bare variable column).
3309fn extract_vid_from_column(col: &dyn arrow_array::Array, row_idx: usize) -> Option<u64> {
3310    if col.is_null(row_idx) {
3311        return None;
3312    }
3313    if let Some(arr) = col.as_any().downcast_ref::<arrow_array::UInt64Array>() {
3314        return Some(arr.value(row_idx));
3315    }
3316    if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3317        return Some(arr.value(row_idx) as u64);
3318    }
3319    match extract_common_value(col, row_idx) {
3320        uni_common::Value::Node(n) => Some(n.vid.as_u64()),
3321        uni_common::Value::Int(i) => Some(i as u64),
3322        _ => None,
3323    }
3324}
3325
3326#[allow(clippy::too_many_arguments)]
3327fn build_feature_resolvers(
3328    batch: &RecordBatch,
3329    invocation: &uni_locy::ModelInvocation,
3330    path_context_handles: &HashMap<
3331        String,
3332        crate::query::df_graph::locy_model_invoke::PathContextHandle,
3333    >,
3334    semantic_match_embeddings: &HashMap<String, Vec<f32>>,
3335    graph_feature_maps: &HashMap<String, Arc<HashMap<u64, f64>>>,
3336    neighbor_feature_maps: &NeighborFeatureMaps,
3337) -> DFResult<Vec<FeatureResolver>> {
3338    use uni_cypher::ast::Expr;
3339    let schema = batch.schema();
3340    let lookup_col = |name_or_property: String| -> DFResult<usize> {
3341        schema.index_of(&name_or_property).map_err(|_| {
3342            datafusion::error::DataFusionError::Execution(format!(
3343                "feature column '{name_or_property}' not found in clause body output schema"
3344            ))
3345        })
3346    };
3347    // Resolve a feature sub-expression to a per-row value source. Variables
3348    // and property accesses map to batch columns; list/scalar literals
3349    // become inline constants — required so `similar_to(s.embedding, [1,0,0])`
3350    // works without a hidden column for the literal vector.
3351    let resolve_src = |expr: &Expr| -> DFResult<FeatureValueSrc> {
3352        match expr {
3353            Expr::Variable(name) => {
3354                let col = if schema.index_of(name).is_ok() {
3355                    name.clone()
3356                } else {
3357                    let vid_name = format!("{}._vid", name);
3358                    if schema.index_of(&vid_name).is_ok() {
3359                        vid_name
3360                    } else {
3361                        name.clone()
3362                    }
3363                };
3364                Ok(FeatureValueSrc::Col(lookup_col(col)?))
3365            }
3366            Expr::Property(boxed, prop) if matches!(boxed.as_ref(), Expr::Variable(_)) => {
3367                let Expr::Variable(v) = boxed.as_ref() else {
3368                    unreachable!()
3369                };
3370                let direct = format!("{}.{}", v, prop);
3371                let col = if schema.index_of(&direct).is_ok() {
3372                    direct
3373                } else {
3374                    format!("__feat_{}_{}", v, prop)
3375                };
3376                Ok(FeatureValueSrc::Col(lookup_col(col)?))
3377            }
3378            Expr::Literal(lit) => Ok(FeatureValueSrc::Const(lit.to_value())),
3379            Expr::List(items) => {
3380                let mut out = Vec::with_capacity(items.len());
3381                for it in items {
3382                    out.push(match it {
3383                        Expr::Literal(lit) => lit.to_value(),
3384                        _ => uni_common::Value::Null,
3385                    });
3386                }
3387                Ok(FeatureValueSrc::Const(uni_common::Value::List(out)))
3388            }
3389            other => Err(datafusion::error::DataFusionError::Execution(format!(
3390                "unsupported feature sub-expression: {other:?}"
3391            ))),
3392        }
3393    };
3394
3395    // Phase D D3 runtime: when the model declares a path-context
3396    // feature, build a `vid → FeatureValue` lookup once from the
3397    // source rule's converged facts and wrap it in an Arc so the
3398    // per-row resolver does a single hash lookup. The model's
3399    // `INPUT` bindings are unused under this form for MVP — the
3400    // resolver's binding name is the column name (matches how the
3401    // mock-classifier feature-driver pattern in TCK consumes it).
3402    if let Some(pc) = &invocation.path_context {
3403        let handle = path_context_handles.get(&pc.source_rule).ok_or_else(|| {
3404            datafusion::error::DataFusionError::Execution(format!(
3405                "model '{}' path_context references rule '{}' but no DerivedScanHandle \
3406                 was registered; this should never happen — the build_clause path \
3407                 mints a handle for every distinct source_rule in the invocation set",
3408                invocation.model_name, pc.source_rule
3409            ))
3410        })?;
3411        let subject_col = schema
3412            .index_of(&format!("{}._vid", pc.subject_var))
3413            .or_else(|_| schema.index_of(&pc.subject_var))
3414            .map_err(|_| {
3415                datafusion::error::DataFusionError::Execution(format!(
3416                    "model '{}' path_context: subject column '{}' (or '{0}._vid') not \
3417                     in body batch schema",
3418                    invocation.model_name, pc.subject_var
3419                ))
3420            })?;
3421        let vid_to_value =
3422            build_path_context_lookup(handle, &pc.subject_var, &pc.column, &invocation.model_name)?;
3423        return Ok(vec![FeatureResolver {
3424            binding_name: pc.column.clone(),
3425            kind: FeatureResolverKind::PathContext {
3426                subject_col,
3427                vid_to_value: Arc::new(vid_to_value),
3428            },
3429        }]);
3430    }
3431
3432    let mut out = Vec::with_capacity(invocation.feature_exprs.len());
3433    for (i, fexpr) in invocation.feature_exprs.iter().enumerate() {
3434        let binding_name = invocation.feature_names[i].clone();
3435        let kind = match fexpr {
3436            Expr::FunctionCall { name, args, .. } if name == "similar_to" => {
3437                if args.len() != 2 {
3438                    return Err(datafusion::error::DataFusionError::Execution(format!(
3439                        "similar_to expects 2 args, got {}",
3440                        args.len()
3441                    )));
3442                }
3443                FeatureResolverKind::SimilarTo {
3444                    left: resolve_src(&args[0])?,
3445                    right: resolve_src(&args[1])?,
3446                }
3447            }
3448            Expr::FunctionCall { name, args, .. } if name == "semantic_match" => {
3449                // Phase D D2: lower `semantic_match(prop, 'text')` to a
3450                // `SimilarTo` resolver with the pre-embedded query
3451                // vector as the right side. The literal text was embedded
3452                // once via the Xervo runtime in `pre_embed_semantic_match_queries`.
3453                if args.len() != 2 {
3454                    return Err(datafusion::error::DataFusionError::Execution(format!(
3455                        "semantic_match expects 2 args, got {}",
3456                        args.len()
3457                    )));
3458                }
3459                let text = match &args[1] {
3460                    Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3461                    other => {
3462                        return Err(datafusion::error::DataFusionError::Execution(format!(
3463                            "semantic_match: 2nd arg must be a string literal, got {other:?}"
3464                        )));
3465                    }
3466                };
3467                let embedded = semantic_match_embeddings.get(&text).ok_or_else(|| {
3468                    datafusion::error::DataFusionError::Execution(format!(
3469                        "semantic_match: query text '{text}' was not pre-embedded. \
3470                         This is a bug — `apply_model_invocations` should have \
3471                         embedded all unique semantic_match texts up front. Most \
3472                         likely the Xervo runtime is not configured (configure \
3473                         via `LocyConfig::xervo_runtime` or its equivalent)."
3474                    ))
3475                })?;
3476                let right_vec: Vec<f32> = embedded.clone();
3477                FeatureResolverKind::SimilarTo {
3478                    left: resolve_src(&args[0])?,
3479                    right: FeatureValueSrc::Const(uni_common::Value::Vector(right_vec)),
3480                }
3481            }
3482            Expr::FunctionCall { name, args, .. }
3483                if matches!(
3484                    name.as_str(),
3485                    "degree_centrality"
3486                        | "pagerank_score"
3487                        | "closeness_centrality"
3488                        | "betweenness_centrality"
3489                        | "eigenvector_centrality"
3490                        | "harmonic_centrality"
3491                        | "katz_centrality"
3492                ) =>
3493            {
3494                if args.len() != 1 {
3495                    return Err(datafusion::error::DataFusionError::Execution(format!(
3496                        "{name} expects 1 arg, got {}",
3497                        args.len()
3498                    )));
3499                }
3500                let Expr::Variable(v) = &args[0] else {
3501                    return Err(datafusion::error::DataFusionError::Execution(format!(
3502                        "{name}(...) argument must be a node variable, got {:?}",
3503                        args[0]
3504                    )));
3505                };
3506                let subject_col = {
3507                    let direct = schema.index_of(v).ok();
3508                    let vid_name = format!("{}._vid", v);
3509                    let vid_col = schema.index_of(&vid_name).ok();
3510                    vid_col.or(direct).ok_or_else(|| {
3511                        datafusion::error::DataFusionError::Execution(format!(
3512                            "{name}: subject column '{v}' (or '{v}._vid') not in body batch schema"
3513                        ))
3514                    })?
3515                };
3516                let vid_to_score = graph_feature_maps.get(name).cloned().ok_or_else(|| {
3517                    datafusion::error::DataFusionError::Execution(format!(
3518                        "{name}: pre-computed score map missing. This is a bug — \
3519                         `apply_model_invocations` should have called \
3520                         `precompute_graph_feature_maps` for every graph-structural \
3521                         feature before building resolvers. Most likely the graph \
3522                         algorithm registry is not configured."
3523                    ))
3524                })?;
3525                FeatureResolverKind::GraphAlgoScore {
3526                    subject_col,
3527                    vid_to_score,
3528                }
3529            }
3530            Expr::FunctionCall { name, args, .. }
3531                if matches!(
3532                    name.as_str(),
3533                    "avg_neighbor" | "max_neighbor" | "sum_neighbor"
3534                ) =>
3535            {
3536                if args.len() != 3 && args.len() != 4 {
3537                    return Err(datafusion::error::DataFusionError::Execution(format!(
3538                        "{name} expects 3 or 4 args, got {}",
3539                        args.len()
3540                    )));
3541                }
3542                let Expr::Variable(v) = &args[0] else {
3543                    return Err(datafusion::error::DataFusionError::Execution(format!(
3544                        "{name}(...) first argument must be a node variable, got {:?}",
3545                        args[0]
3546                    )));
3547                };
3548                let rel_type = match &args[1] {
3549                    Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3550                    other => {
3551                        return Err(datafusion::error::DataFusionError::Execution(format!(
3552                            "{name}: 2nd arg must be a string literal (rel-type), got {other:?}"
3553                        )));
3554                    }
3555                };
3556                let prop_name = match &args[2] {
3557                    Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3558                    other => {
3559                        return Err(datafusion::error::DataFusionError::Execution(format!(
3560                            "{name}: 3rd arg must be a string literal (property), got {other:?}"
3561                        )));
3562                    }
3563                };
3564                let direction_arg = match args.get(3) {
3565                    None => NeighborDirection::Outgoing,
3566                    Some(Expr::Literal(uni_cypher::ast::CypherLiteral::String(d))) => {
3567                        match d.to_uppercase().as_str() {
3568                            "OUTGOING" => NeighborDirection::Outgoing,
3569                            "INCOMING" => NeighborDirection::Incoming,
3570                            "BOTH" => NeighborDirection::Both,
3571                            other => {
3572                                return Err(datafusion::error::DataFusionError::Execution(
3573                                    format!(
3574                                        "{name}: direction must be OUTGOING|INCOMING|BOTH, got '{other}'"
3575                                    ),
3576                                ));
3577                            }
3578                        }
3579                    }
3580                    Some(other) => {
3581                        return Err(datafusion::error::DataFusionError::Execution(format!(
3582                            "{name}: 4th arg must be a string literal (direction), got {other:?}"
3583                        )));
3584                    }
3585                };
3586                let subject_col = {
3587                    let direct = schema.index_of(v).ok();
3588                    let vid_name = format!("{}._vid", v);
3589                    let vid_col = schema.index_of(&vid_name).ok();
3590                    vid_col.or(direct).ok_or_else(|| {
3591                        datafusion::error::DataFusionError::Execution(format!(
3592                            "{name}: subject column '{v}' (or '{v}._vid') not in body batch schema"
3593                        ))
3594                    })?
3595                };
3596                let vid_to_values = neighbor_feature_maps
3597                    .get(&(rel_type.clone(), prop_name.clone(), direction_arg))
3598                    .cloned()
3599                    .ok_or_else(|| {
3600                        datafusion::error::DataFusionError::Execution(format!(
3601                            "{name}: pre-computed neighbor map missing for ({rel_type}, {prop_name}, {direction_arg:?}). \
3602                             This is a bug — `apply_model_invocations` should have called \
3603                             `precompute_neighbor_feature_maps` for every neighbor-aggregator \
3604                             feature before building resolvers."
3605                        ))
3606                    })?;
3607                let op = NeighborAgg::from_fn_name(name).unwrap();
3608                FeatureResolverKind::NeighborAggregate {
3609                    subject_col,
3610                    op,
3611                    vid_to_values,
3612                }
3613            }
3614            other => match resolve_src(other)? {
3615                FeatureValueSrc::Col(idx) => FeatureResolverKind::Direct(idx),
3616                FeatureValueSrc::Const(_) => {
3617                    return Err(datafusion::error::DataFusionError::Execution(format!(
3618                        "model '{}' feature must reference a variable or property — got a literal",
3619                        invocation.model_name
3620                    )));
3621                }
3622            },
3623        };
3624        out.push(FeatureResolver { binding_name, kind });
3625    }
3626    Ok(out)
3627}
3628
3629/// Phase D D2: scan invocations' feature expressions for
3630/// `semantic_match(prop, 'text')` calls and embed each distinct
3631/// query string once via the Xervo runtime. Returns a
3632/// `text → Vec<f32>` map consumed at resolver-build time. Errors
3633/// cleanly when `semantic_match` is used without a configured
3634/// Xervo runtime.
3635async fn pre_embed_semantic_match_queries(
3636    invocations: &[uni_locy::ModelInvocation],
3637    xervo_runtime: &crate::query::df_graph::locy_model_invoke::XervoRuntimeHandle,
3638) -> DFResult<HashMap<String, Vec<f32>>> {
3639    use uni_cypher::ast::{CypherLiteral, Expr};
3640    // Collect (text, embedder_alias) pairs. The alias is per-model
3641    // (Phase D D2 follow-up): each invocation's `embedder_alias` overrides
3642    // the runtime-wide `"default"`. Two invocations sharing the same
3643    // (text, alias) reuse one embed call; same text under different
3644    // aliases is embedded twice. The cache key remains plain `text` —
3645    // a model's resolver fetches embeddings via the text it knows, and
3646    // mixed-alias re-embed of identical text under a different alias is
3647    // a rare-enough edge case that the "last writer wins" cache shape
3648    // is acceptable (documented).
3649    let mut needed: Vec<(String, String)> = Vec::new();
3650    for inv in invocations {
3651        let alias = inv
3652            .embedder_alias
3653            .clone()
3654            .unwrap_or_else(|| "default".to_string());
3655        for fexpr in &inv.feature_exprs {
3656            if let Expr::FunctionCall { name, args, .. } = fexpr
3657                && name == "semantic_match"
3658                && args.len() == 2
3659                && let Expr::Literal(CypherLiteral::String(s)) = &args[1]
3660            {
3661                let tuple = (s.clone(), alias.clone());
3662                if !needed.contains(&tuple) {
3663                    needed.push(tuple);
3664                }
3665            }
3666        }
3667    }
3668    if needed.is_empty() {
3669        return Ok(HashMap::new());
3670    }
3671    let runtime = xervo_runtime.as_ref().ok_or_else(|| {
3672        datafusion::error::DataFusionError::Execution(
3673            "semantic_match: Uni-Xervo runtime not configured. Either provide \
3674             one via `LocyConfig::xervo_runtime` (or its equivalent setup \
3675             path) or pre-compute the query embedding and pass it via \
3676             `similar_to(prop, <literal_vector>)`."
3677                .to_string(),
3678        )
3679    })?;
3680    // Group needed (text, alias) by alias so each embedder is consulted
3681    // exactly once.
3682    let mut by_alias: HashMap<String, Vec<String>> = HashMap::new();
3683    for (text, alias) in &needed {
3684        by_alias
3685            .entry(alias.clone())
3686            .or_default()
3687            .push(text.clone());
3688    }
3689    let mut out: HashMap<String, Vec<f32>> = HashMap::new();
3690    for (alias, texts) in by_alias {
3691        let embedder = runtime.embedding(&alias).await.map_err(|e| {
3692            datafusion::error::DataFusionError::Execution(format!(
3693                "semantic_match: failed to obtain embedder for alias '{alias}': {e}. \
3694                 Register an embedder under that alias in your Uni-Xervo runtime, or \
3695                 pre-compute the query embedding and pass via similar_to."
3696            ))
3697        })?;
3698        let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
3699        let embeddings = embedder
3700            .embed(&text_refs)
3701            .await
3702            .map_err(|e| {
3703                datafusion::error::DataFusionError::Execution(format!(
3704                    "semantic_match: embedder '{alias}' call failed: {e}"
3705                ))
3706            })?
3707            .vectors;
3708        if embeddings.len() != texts.len() {
3709            return Err(datafusion::error::DataFusionError::Execution(format!(
3710                "semantic_match: embedder '{alias}' returned {} vectors for {} queries",
3711                embeddings.len(),
3712                texts.len()
3713            )));
3714        }
3715        for (text, vec) in texts.into_iter().zip(embeddings) {
3716            out.insert(text, vec);
3717        }
3718    }
3719    Ok(out)
3720}
3721
3722/// Phase D D1 graph-structural: scan invocations' feature expressions
3723/// for `degree_centrality(n)` / `pagerank_score(n)` / `closeness_centrality(n)`
3724/// calls and invoke the corresponding `uni.algo.*` procedure on the
3725/// configured `AlgorithmRegistry` once per distinct call. Returns a
3726/// `fn_name → Arc<HashMap<vid, score>>` map consumed at resolver-build
3727/// time. Errors cleanly when a graph-structural FEATURE is used
3728/// without a configured registry or storage handle.
3729///
3730/// Pre-computation is `O(graph)` per call. Across fixpoint iterations
3731/// the graph state can change, so the cache lives for the lifetime of
3732/// one `apply_model_invocations` call only — same lifetime as the
3733/// D2 query-embedding cache (`pre_embed_semantic_match_queries`).
3734async fn precompute_graph_feature_maps(
3735    invocations: &[uni_locy::ModelInvocation],
3736    graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
3737) -> DFResult<HashMap<String, Arc<HashMap<u64, f64>>>> {
3738    use futures::StreamExt;
3739    use uni_algo::algo::procedures::AlgoContext;
3740    use uni_cypher::ast::Expr;
3741
3742    // Map our user-facing FEATURE function names to the canonical
3743    // `uni.algo.*` procedure names registered in `AlgorithmRegistry`.
3744    fn procedure_for(fn_name: &str) -> Option<&'static str> {
3745        match fn_name {
3746            "degree_centrality" => Some("uni.algo.degreeCentrality"),
3747            "pagerank_score" => Some("uni.algo.pageRank"),
3748            "closeness_centrality" => Some("uni.algo.closeness"),
3749            "betweenness_centrality" => Some("uni.algo.betweenness"),
3750            "eigenvector_centrality" => Some("uni.algo.eigenvectorCentrality"),
3751            "harmonic_centrality" => Some("uni.algo.harmonicCentrality"),
3752            "katz_centrality" => Some("uni.algo.katzCentrality"),
3753            _ => None,
3754        }
3755    }
3756
3757    // Collect the set of distinct topology-FEATURE names referenced
3758    // across all invocations. Args are always a single Variable, so
3759    // the precomputation key is just the function name.
3760    let mut needed: Vec<String> = Vec::new();
3761    for inv in invocations {
3762        for fexpr in &inv.feature_exprs {
3763            if let Expr::FunctionCall { name, .. } = fexpr
3764                && procedure_for(name).is_some()
3765                && !needed.contains(name)
3766            {
3767                needed.push(name.clone());
3768            }
3769        }
3770    }
3771    if needed.is_empty() {
3772        return Ok(HashMap::new());
3773    }
3774
3775    let registry = graph_algo.registry.as_ref().ok_or_else(|| {
3776        datafusion::error::DataFusionError::Execution(
3777            "graph-structural FEATURE invoked but no `AlgorithmRegistry` is \
3778             configured. Configure one on `GraphExecutionContext::with_algo_registry`."
3779                .to_string(),
3780        )
3781    })?;
3782    let storage = graph_algo.storage.as_ref().ok_or_else(|| {
3783        datafusion::error::DataFusionError::Execution(
3784            "graph-structural FEATURE invoked but no storage handle was \
3785             threaded into the FEATURE runtime. This is a bug in df_planner."
3786                .to_string(),
3787        )
3788    })?;
3789
3790    let mut out: HashMap<String, Arc<HashMap<u64, f64>>> = HashMap::new();
3791    for fn_name in needed {
3792        let proc_name = procedure_for(&fn_name).unwrap();
3793        let procedure = registry.get(proc_name).ok_or_else(|| {
3794            datafusion::error::DataFusionError::Execution(format!(
3795                "graph-structural FEATURE '{fn_name}' resolves to procedure \
3796                 '{proc_name}' which is not in the algorithm registry"
3797            ))
3798        })?;
3799        // Topology procedures take (nodeLabels[], relationshipTypes[],
3800        // [direction], [...]) — pass empty arrays for nodeLabels and
3801        // relationshipTypes to mean "all". The procedure fills the
3802        // remaining optional args from its signature defaults.
3803        let args: Vec<serde_json::Value> = vec![
3804            serde_json::Value::Array(Vec::new()),
3805            serde_json::Value::Array(Vec::new()),
3806        ];
3807        let algo_ctx = AlgoContext::new(
3808            storage.clone(),
3809            graph_algo.l0_manager.as_ref().map(Arc::clone),
3810        );
3811        // The AlgoProcedure trait routes direct (nodeLabels, edgeTypes)
3812        // args through the V2 projection entry point: build a projection
3813        // from the direct args, then execute against it.
3814        //
3815        // Fill optional algorithm-specific args (e.g. degree_centrality's
3816        // `direction`, eigenvector/katz `weightProperty`) with their schema
3817        // defaults for the projection build: `build_projection_from_direct_args`
3818        // feeds the specific args (`args[2..]`) to the adapter's
3819        // `customize_projection`, which indexes them positionally — the two
3820        // empty placeholder arrays alone would leave that slice empty and
3821        // panic. `validate_args` fills missing optionals WITHOUT type-checking
3822        // the defaults (some are `Null`-typed sentinels), so this never errors
3823        // for the placeholder shape.
3824        //
3825        // We pass the ORIGINAL `args` (not the filled ones) to
3826        // `execute_with_projection`, which re-runs `validate_args` internally:
3827        // re-feeding already-filled args would make those defaults look
3828        // "provided" and trip the type-check (`weightProperty: Null` vs
3829        // `String`). This mirrors the (now-removed) legacy
3830        // `AlgoProcedure::execute`, which validated once then built + ran.
3831        let filled_args = procedure
3832            .signature()
3833            .validate_args(args.clone())
3834            .map_err(|e| {
3835                datafusion::error::DataFusionError::Execution(format!(
3836                    "graph-structural FEATURE '{fn_name}': argument validation failed: {e}"
3837                ))
3838            })?;
3839        let projection = uni_algo::algo::procedure_template::build_projection_from_direct_args(
3840            procedure.as_ref(),
3841            &algo_ctx,
3842            &filled_args,
3843        )
3844        .await
3845        .map_err(|e| {
3846            datafusion::error::DataFusionError::Execution(format!(
3847                "graph-structural FEATURE '{fn_name}': projection build failed: {e}"
3848            ))
3849        })?;
3850        let mut stream = procedure.execute_with_projection(algo_ctx, args, projection);
3851        let mut score_map: HashMap<u64, f64> = HashMap::new();
3852        let sig = procedure.signature();
3853        let node_idx = sig
3854            .yields
3855            .iter()
3856            .position(|(n, _)| *n == "nodeId")
3857            .ok_or_else(|| {
3858                datafusion::error::DataFusionError::Execution(format!(
3859                    "procedure '{proc_name}' yield schema missing 'nodeId'"
3860                ))
3861            })?;
3862        // Most `uni.algo.*` centrality procedures yield `score`; the
3863        // `harmonicCentrality` family yields `centrality` instead. Accept
3864        // either to keep this dispatch independent of procedure-internal
3865        // naming choices.
3866        let score_idx = sig
3867            .yields
3868            .iter()
3869            .position(|(n, _)| *n == "score" || *n == "centrality")
3870            .ok_or_else(|| {
3871                datafusion::error::DataFusionError::Execution(format!(
3872                    "procedure '{proc_name}' yield schema missing a numeric score column \
3873                     (expected 'score' or 'centrality')"
3874                ))
3875            })?;
3876        while let Some(row_res) = stream.next().await {
3877            let row = row_res.map_err(|e| {
3878                datafusion::error::DataFusionError::Execution(format!(
3879                    "graph-structural FEATURE '{fn_name}': procedure '{proc_name}' failed: {e}"
3880                ))
3881            })?;
3882            let vid_v = row.values.get(node_idx);
3883            let score_v = row.values.get(score_idx);
3884            let (Some(vid_v), Some(score_v)) = (vid_v, score_v) else {
3885                continue;
3886            };
3887            let vid = vid_v.as_u64().or_else(|| vid_v.as_i64().map(|i| i as u64));
3888            let score = score_v
3889                .as_f64()
3890                .or_else(|| score_v.as_i64().map(|i| i as f64));
3891            if let (Some(vid), Some(score)) = (vid, score) {
3892                score_map.insert(vid, score);
3893            }
3894        }
3895        out.insert(fn_name, Arc::new(score_map));
3896    }
3897    Ok(out)
3898}
3899
/// Per-`(rel_type, prop_name, direction)` cache of neighbor property
/// values keyed by subject vid, produced by
/// `precompute_neighbor_feature_maps` and consumed by
/// `FeatureResolverKind::NeighborAggregate` resolvers.
///
/// Phase D D1 graph-structural: one-hop neighborhood aggregator
/// precompute. The producer scans invocations' feature expressions for
/// `avg_neighbor` / `max_neighbor` / `sum_neighbor` FunctionCalls,
/// collects the distinct `(rel_type, prop_name, direction)` keys they
/// need (direction defaults to outgoing when the call omits it),
/// resolves each rel-type to a schema edge-type id, warms the adjacency
/// for every store direction the key traverses, and for every subject
/// vid present in the body batches walks the one-hop neighborhood and
/// fetches the requested property from each neighbor via
/// `PropertyManager`. Neighbor property values that do not convert via
/// `as_f64`, and NaN values, are filtered out.
///
/// Each inner value is an `Arc<HashMap<u64, Vec<f64>>>` mapping subject
/// vid to the collected neighbor values. The resolver's runtime cost per
/// row is then a single hash lookup plus an `avg`/`max`/`sum` over the
/// cached `Vec<f64>`.
///
/// Scope: **subject-set-only** — values are only collected for vids that
/// appear in the body batches' subject columns (avoids pre-walking the
/// entire graph). Subjects with no numeric neighbors in the requested
/// direction(s) land in the map with an empty `Vec` so the resolver's
/// `Null` semantics remain crisp (empty → `Null` → classifier interprets
/// per its feature contract). A rel-type that is not registered in the
/// schema maps to an entirely empty inner map.
type NeighborFeatureMaps =
    HashMap<(String, String, NeighborDirection), Arc<HashMap<u64, Vec<f64>>>>;
3927
3928async fn precompute_neighbor_feature_maps(
3929    invocations: &[uni_locy::ModelInvocation],
3930    batches: &[RecordBatch],
3931    graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
3932) -> DFResult<NeighborFeatureMaps> {
3933    use uni_cypher::ast::{CypherLiteral, Expr};
3934
3935    // Collect distinct (subject_var, rel_type, prop_name, direction)
3936    // tuples needed across all invocations. The subject_var tells us
3937    // which body batch column to scan for subject vids; the direction
3938    // is optional in the AST (defaults to OUTGOING).
3939    let parse_direction = |arg: Option<&Expr>| -> Option<NeighborDirection> {
3940        match arg {
3941            None => Some(NeighborDirection::Outgoing),
3942            Some(Expr::Literal(CypherLiteral::String(d))) => match d.to_uppercase().as_str() {
3943                "OUTGOING" => Some(NeighborDirection::Outgoing),
3944                "INCOMING" => Some(NeighborDirection::Incoming),
3945                "BOTH" => Some(NeighborDirection::Both),
3946                _ => None,
3947            },
3948            _ => None,
3949        }
3950    };
3951    let mut needed: Vec<(String, String, String, NeighborDirection)> = Vec::new();
3952    for inv in invocations {
3953        for fexpr in &inv.feature_exprs {
3954            if let Expr::FunctionCall { name, args, .. } = fexpr
3955                && NeighborAgg::from_fn_name(name).is_some()
3956                && (args.len() == 3 || args.len() == 4)
3957                && let Expr::Variable(v) = &args[0]
3958                && let Expr::Literal(CypherLiteral::String(rel)) = &args[1]
3959                && let Expr::Literal(CypherLiteral::String(prop)) = &args[2]
3960                && let Some(direction) = parse_direction(args.get(3))
3961            {
3962                let tuple = (v.clone(), rel.clone(), prop.clone(), direction);
3963                if !needed.contains(&tuple) {
3964                    needed.push(tuple);
3965                }
3966            }
3967        }
3968    }
3969    if needed.is_empty() {
3970        return Ok(HashMap::new());
3971    }
3972
3973    let storage = graph_algo.storage.as_ref().ok_or_else(|| {
3974        datafusion::error::DataFusionError::Execution(
3975            "neighbor-aggregator FEATURE invoked but no storage handle was \
3976             threaded into the FEATURE runtime. This is a bug in df_planner."
3977                .to_string(),
3978        )
3979    })?;
3980    let property_manager = graph_algo.property_manager.as_ref().ok_or_else(|| {
3981        datafusion::error::DataFusionError::Execution(
3982            "neighbor-aggregator FEATURE invoked but no PropertyManager was \
3983             threaded into the FEATURE runtime. This is a bug in df_planner."
3984                .to_string(),
3985        )
3986    })?;
3987    // Build a QueryContext snapshot so L0-resident vertex properties
3988    // are visible to `get_vertex_prop_with_ctx`. Without a ctx, L0
3989    // property data is silently invisible (returns Null), which is
3990    // why the topology trio's `AlgoContext` consumes L0 via
3991    // `L0Manager` whereas property reads need this separate path.
3992    let query_ctx = graph_algo.l0_buffers.as_ref().map(|bufs| {
3993        uni_store::runtime::context::QueryContext::new_with_pending(
3994            bufs.current.clone(),
3995            bufs.transaction.clone(),
3996            bufs.pending_flush.clone(),
3997        )
3998    });
3999
4000    // Group needed tuples by (rel_type, prop_name, direction) — one
4001    // precomputed map per key, regardless of which subject_var binding
4002    // points at it (the subject vids are unioned).
4003    let mut by_key: HashMap<(String, String, NeighborDirection), Vec<String>> = HashMap::new();
4004    for (subject_var, rel, prop, direction) in needed {
4005        by_key
4006            .entry((rel, prop, direction))
4007            .or_default()
4008            .push(subject_var);
4009    }
4010
4011    let mut out: NeighborFeatureMaps = HashMap::new();
4012    for ((rel_type, prop_name, direction), subject_vars) in by_key {
4013        // Resolve edge_type_id from schema.
4014        let schema = storage.schema_manager().schema();
4015        let Some(edge_meta) = schema.edge_types.get(&rel_type) else {
4016            // Unregistered rel-type → empty map. The resolver surfaces
4017            // Null at row time, consistent with the no-neighbor case.
4018            out.insert((rel_type, prop_name, direction), Arc::new(HashMap::new()));
4019            continue;
4020        };
4021        let edge_type_id = edge_meta.id;
4022
4023        // Warm adjacency for every direction we'll traverse. Mirrors
4024        // the pattern in projection.rs / procedure_template.rs.
4025        let edge_ver = storage.get_edge_version_by_id(edge_type_id);
4026        for dir in direction.store_directions() {
4027            storage
4028                .warm_adjacency(edge_type_id, *dir, edge_ver)
4029                .await
4030                .map_err(|e| {
4031                    datafusion::error::DataFusionError::Execution(format!(
4032                        "neighbor-aggregator warm_adjacency for '{rel_type}' / {dir:?} failed: {e}"
4033                    ))
4034                })?;
4035        }
4036
4037        // Collect distinct subject vids from body batches across every
4038        // subject_var binding that this (rel, prop) pair uses.
4039        let mut subject_vids: std::collections::HashSet<u64> = std::collections::HashSet::new();
4040        for subject_var in &subject_vars {
4041            for batch in batches {
4042                let schema = batch.schema();
4043                let col_idx = schema
4044                    .index_of(&format!("{}._vid", subject_var))
4045                    .ok()
4046                    .or_else(|| schema.index_of(subject_var).ok());
4047                let Some(col_idx) = col_idx else { continue };
4048                let col = batch.column(col_idx);
4049                for row in 0..batch.num_rows() {
4050                    if let Some(v) = extract_vid_from_column(col.as_ref(), row) {
4051                        subject_vids.insert(v);
4052                    }
4053                }
4054            }
4055        }
4056
4057        // For each subject, walk edges in the configured direction(s),
4058        // fetch neighbor property, coerce to f64, accumulate. Subjects
4059        // with no numeric neighbors retain an empty Vec (→ Null at
4060        // row time).
4061        let mut vid_to_values: HashMap<u64, Vec<f64>> = HashMap::new();
4062        let adj = storage.adjacency_manager();
4063        for subject_vid in subject_vids {
4064            let mut neighbors: Vec<(uni_common::core::id::Vid, uni_common::core::id::Eid)> =
4065                Vec::new();
4066            for dir in direction.store_directions() {
4067                neighbors.extend(adj.get_neighbors(
4068                    uni_common::core::id::Vid::from(subject_vid),
4069                    edge_type_id,
4070                    *dir,
4071                ));
4072            }
4073            let mut values: Vec<f64> = Vec::with_capacity(neighbors.len());
4074            for (neighbor_vid, _eid) in neighbors {
4075                let val = property_manager
4076                    .get_vertex_prop_with_ctx(neighbor_vid, &prop_name, query_ctx.as_ref())
4077                    .await
4078                    .map_err(|e| {
4079                        datafusion::error::DataFusionError::Execution(format!(
4080                            "neighbor-aggregator: failed to read property \
4081                             '{prop_name}' on neighbor vid {neighbor_vid:?}: {e}"
4082                        ))
4083                    })?;
4084                if let Some(f) = val.as_f64()
4085                    && !f.is_nan()
4086                {
4087                    values.push(f);
4088                }
4089            }
4090            vid_to_values.insert(subject_vid, values);
4091        }
4092        out.insert((rel_type, prop_name, direction), Arc::new(vid_to_values));
4093    }
4094    Ok(out)
4095}
4096
4097/// Phase D D3: walk the source rule's converged batches and build
4098/// a `vid → FeatureValue` lookup for the named column. The subject
4099/// column in the derived rule's schema holds VIDs (UInt64) for node
4100/// variables; the value column type follows the rule's yield-schema
4101/// inference (typically Float64 / Int64 / Bool / String).
4102fn build_path_context_lookup(
4103    handle: &crate::query::df_graph::locy_model_invoke::PathContextHandle,
4104    _subject_var: &str,
4105    column: &str,
4106    model_name: &str,
4107) -> DFResult<HashMap<u64, uni_locy::FeatureValue>> {
4108    // The source rule's KEY column is its first yield column by
4109    // convention (`infer_yield_schema` orders KEYs first). The model's
4110    // local `subject_var` is just a binding alias — the actual join
4111    // matches the body row's VID against this canonical column.
4112    if handle.schema.fields().is_empty() {
4113        return Err(datafusion::error::DataFusionError::Execution(format!(
4114            "model '{model_name}' path_context: source rule has empty yield schema"
4115        )));
4116    }
4117    let subj_idx = 0_usize;
4118    let col_idx = handle.schema.index_of(column).map_err(|_| {
4119        datafusion::error::DataFusionError::Execution(format!(
4120            "model '{model_name}' path_context: column '{column}' not in \
4121             source rule's yield schema (have: {:?})",
4122            handle
4123                .schema
4124                .fields()
4125                .iter()
4126                .map(|f| f.name().clone())
4127                .collect::<Vec<_>>()
4128        ))
4129    })?;
4130    let batches = handle.data.read();
4131    let mut out: HashMap<u64, uni_locy::FeatureValue> = HashMap::new();
4132    for batch in batches.iter() {
4133        let subj_col = batch.column(subj_idx);
4134        let value_col = batch.column(col_idx);
4135        for row in 0..batch.num_rows() {
4136            if subj_col.is_null(row) {
4137                continue;
4138            }
4139            let vid = if let Some(a) = subj_col.as_any().downcast_ref::<arrow_array::UInt64Array>()
4140            {
4141                a.value(row)
4142            } else if let Some(a) = subj_col.as_any().downcast_ref::<arrow_array::Int64Array>() {
4143                a.value(row) as u64
4144            } else {
4145                continue;
4146            };
4147            let v = extract_feature_value(value_col.as_ref(), row);
4148            // Last write wins on duplicates; derived rules typically have
4149            // unique KEY values, so this is a defensive guard.
4150            out.insert(vid, v);
4151        }
4152    }
4153    Ok(out)
4154}
4155
4156/// Extract a `uni_common::Value` from one row of an Arrow column.
4157/// Used by the Phase D `similar_to` feature resolver, which needs
4158/// the raw `Value` (especially `Value::Vector(Vec<f32>)`) to feed
4159/// `eval_similar_to_pure`.
4160fn extract_common_value(col: &dyn arrow_array::Array, row_idx: usize) -> uni_common::Value {
4161    use arrow_array::{BooleanArray, Float64Array, Int64Array, LargeStringArray, StringArray};
4162    if col.is_null(row_idx) {
4163        return uni_common::Value::Null;
4164    }
4165    if let Some(a) = col.as_any().downcast_ref::<Float64Array>() {
4166        return uni_common::Value::Float(a.value(row_idx));
4167    }
4168    if let Some(a) = col.as_any().downcast_ref::<Int64Array>() {
4169        return uni_common::Value::Int(a.value(row_idx));
4170    }
4171    if let Some(a) = col.as_any().downcast_ref::<BooleanArray>() {
4172        return uni_common::Value::Bool(a.value(row_idx));
4173    }
4174    if let Some(a) = col.as_any().downcast_ref::<StringArray>() {
4175        return uni_common::Value::String(a.value(row_idx).to_string());
4176    }
4177    if let Some(a) = col.as_any().downcast_ref::<LargeStringArray>() {
4178        return uni_common::Value::String(a.value(row_idx).to_string());
4179    }
4180    if let Some(b) = col.as_any().downcast_ref::<arrow_array::LargeBinaryArray>() {
4181        let bytes = b.value(row_idx);
4182        if bytes.is_empty() {
4183            return uni_common::Value::Null;
4184        }
4185        return uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null);
4186    }
4187    uni_common::Value::Null
4188}
4189
4190fn extract_feature_value(col: &dyn arrow_array::Array, row_idx: usize) -> uni_locy::FeatureValue {
4191    use arrow_array::{BooleanArray, Float64Array, Int64Array, LargeStringArray, StringArray};
4192    if col.is_null(row_idx) {
4193        return uni_locy::FeatureValue::Null;
4194    }
4195    if let Some(a) = col.as_any().downcast_ref::<Float64Array>() {
4196        return uni_locy::FeatureValue::Float(a.value(row_idx));
4197    }
4198    if let Some(a) = col.as_any().downcast_ref::<Int64Array>() {
4199        return uni_locy::FeatureValue::Int(a.value(row_idx));
4200    }
4201    if let Some(a) = col.as_any().downcast_ref::<BooleanArray>() {
4202        return uni_locy::FeatureValue::Bool(a.value(row_idx));
4203    }
4204    if let Some(a) = col.as_any().downcast_ref::<StringArray>() {
4205        return uni_locy::FeatureValue::String(a.value(row_idx).to_string());
4206    }
4207    if let Some(a) = col.as_any().downcast_ref::<LargeStringArray>() {
4208        return uni_locy::FeatureValue::String(a.value(row_idx).to_string());
4209    }
4210    // Schema-less property storage: values arrive as LargeBinary
4211    // MessagePack-encoded `CypherValue`. Decode via the standard codec
4212    // and project the result to the matching `FeatureValue` variant.
4213    if let Some(b) = col.as_any().downcast_ref::<arrow_array::LargeBinaryArray>() {
4214        let bytes = b.value(row_idx);
4215        if bytes.is_empty() {
4216            return uni_locy::FeatureValue::Null;
4217        }
4218        let v = uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null);
4219        return match v {
4220            uni_common::Value::Float(f) => uni_locy::FeatureValue::Float(f),
4221            uni_common::Value::Int(i) => uni_locy::FeatureValue::Int(i),
4222            uni_common::Value::Bool(b) => uni_locy::FeatureValue::Bool(b),
4223            uni_common::Value::String(s) => uni_locy::FeatureValue::String(s),
4224            uni_common::Value::Null => uni_locy::FeatureValue::Null,
4225            _ => uni_locy::FeatureValue::Null,
4226        };
4227    }
4228    uni_locy::FeatureValue::Null
4229}
4230
4231/// Probabilistic complement for negated IS-refs targeting PROB rules.
4232///
4233/// Instead of filtering out matching VIDs (anti-join), this adds a complement
4234/// column `__prob_complement_{rule_name}` with value `1 - p` for each matching
4235/// VID, and `1.0` for VIDs not present in the negated rule's facts. Implements
4236/// `IS NOT risk` on a PROB rule: the probability that the entity is NOT risky.
4237pub fn apply_prob_complement(
4238    batches: Vec<RecordBatch>,
4239    neg_facts: &[RecordBatch],
4240    left_col: &str,
4241    right_col: &str,
4242    prob_col: &str,
4243    complement_col_name: &str,
4244) -> datafusion::error::Result<Vec<RecordBatch>> {
4245    use arrow_array::{Array as _, Float64Array, UInt64Array};
4246
4247    // Build VID → probability lookup from negative facts
4248    let mut prob_map: std::collections::HashMap<u64, f64> = std::collections::HashMap::new();
4249    for batch in neg_facts {
4250        let Ok(vid_idx) = batch.schema().index_of(right_col) else {
4251            continue;
4252        };
4253        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
4254            continue;
4255        };
4256        let Some(vids) = batch.column(vid_idx).as_any().downcast_ref::<UInt64Array>() else {
4257            continue;
4258        };
4259        let prob_arr = batch.column(prob_idx);
4260        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
4261        for i in 0..vids.len() {
4262            if !vids.is_null(i) {
4263                let p = probs
4264                    .and_then(|arr| {
4265                        if arr.is_null(i) {
4266                            None
4267                        } else {
4268                            Some(arr.value(i))
4269                        }
4270                    })
4271                    .unwrap_or(0.0);
4272                // If multiple facts for same VID, use noisy-OR combination:
4273                // combined = 1 - (1 - existing) * (1 - new)
4274                prob_map
4275                    .entry(vids.value(i))
4276                    .and_modify(|existing| {
4277                        *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
4278                    })
4279                    .or_insert(p);
4280            }
4281        }
4282    }
4283
4284    // Add complement column to each batch
4285    let mut result = Vec::new();
4286    for batch in batches {
4287        let Ok(idx) = batch.schema().index_of(left_col) else {
4288            result.push(batch);
4289            continue;
4290        };
4291        let Some(vids) = batch.column(idx).as_any().downcast_ref::<UInt64Array>() else {
4292            result.push(batch);
4293            continue;
4294        };
4295
4296        // Compute complement values: 1 - p for matched VIDs, 1.0 for absent
4297        let complements: Vec<f64> = (0..vids.len())
4298            .map(|i| {
4299                if vids.is_null(i) {
4300                    1.0
4301                } else {
4302                    let p = prob_map.get(&vids.value(i)).copied().unwrap_or(0.0);
4303                    1.0 - p
4304                }
4305            })
4306            .collect();
4307
4308        let complement_arr = Float64Array::from(complements);
4309
4310        // Add the complement column to the batch
4311        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
4312        columns.push(std::sync::Arc::new(complement_arr));
4313
4314        let mut fields: Vec<std::sync::Arc<arrow_schema::Field>> =
4315            batch.schema().fields().iter().cloned().collect();
4316        fields.push(std::sync::Arc::new(arrow_schema::Field::new(
4317            complement_col_name,
4318            arrow_schema::DataType::Float64,
4319            true,
4320        )));
4321
4322        let new_schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4323        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
4324        result.push(new_batch);
4325    }
4326    Ok(result)
4327}
4328
4329/// Probabilistic complement for composite (multi-column) join keys.
4330///
4331/// Builds a composite key from all `join_cols` right-side columns in
4332/// `neg_facts`, maps each composite key to a probability via noisy-OR
4333/// combination, then adds a single `complement_col_name` column with
4334/// `1 - p` for matched keys and `1.0` for absent keys.
4335pub fn apply_prob_complement_composite(
4336    batches: Vec<RecordBatch>,
4337    neg_facts: &[RecordBatch],
4338    join_cols: &[(String, String)],
4339    prob_col: &str,
4340    complement_col_name: &str,
4341) -> datafusion::error::Result<Vec<RecordBatch>> {
4342    use arrow_array::{Array as _, Float64Array, UInt64Array};
4343
4344    // Build composite-key → probability lookup from negative facts.
4345    let mut prob_map: HashMap<Vec<u64>, f64> = HashMap::new();
4346    for batch in neg_facts {
4347        let right_indices: Vec<usize> = join_cols
4348            .iter()
4349            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
4350            .collect();
4351        if right_indices.len() != join_cols.len() {
4352            continue;
4353        }
4354        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
4355            continue;
4356        };
4357        let prob_arr = batch.column(prob_idx);
4358        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
4359        for row in 0..batch.num_rows() {
4360            let mut key = Vec::with_capacity(right_indices.len());
4361            let mut valid = true;
4362            for &ci in &right_indices {
4363                let col = batch.column(ci);
4364                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
4365                    if vids.is_null(row) {
4366                        valid = false;
4367                        break;
4368                    }
4369                    key.push(vids.value(row));
4370                } else {
4371                    valid = false;
4372                    break;
4373                }
4374            }
4375            if !valid {
4376                continue;
4377            }
4378            let p = probs
4379                .and_then(|arr| {
4380                    if arr.is_null(row) {
4381                        None
4382                    } else {
4383                        Some(arr.value(row))
4384                    }
4385                })
4386                .unwrap_or(0.0);
4387            // Noisy-OR combination for duplicate composite keys.
4388            prob_map
4389                .entry(key)
4390                .and_modify(|existing| {
4391                    *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
4392                })
4393                .or_insert(p);
4394        }
4395    }
4396
4397    // Add complement column to each batch.
4398    let mut result = Vec::new();
4399    for batch in batches {
4400        let left_indices: Vec<usize> = join_cols
4401            .iter()
4402            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
4403            .collect();
4404        if left_indices.len() != join_cols.len() {
4405            result.push(batch);
4406            continue;
4407        }
4408        let all_u64 = left_indices.iter().all(|&ci| {
4409            batch
4410                .column(ci)
4411                .as_any()
4412                .downcast_ref::<UInt64Array>()
4413                .is_some()
4414        });
4415        if !all_u64 {
4416            result.push(batch);
4417            continue;
4418        }
4419
4420        let complements: Vec<f64> = (0..batch.num_rows())
4421            .map(|row| {
4422                let mut key = Vec::with_capacity(left_indices.len());
4423                for &ci in &left_indices {
4424                    let vids = batch
4425                        .column(ci)
4426                        .as_any()
4427                        .downcast_ref::<UInt64Array>()
4428                        .unwrap();
4429                    if vids.is_null(row) {
4430                        return 1.0;
4431                    }
4432                    key.push(vids.value(row));
4433                }
4434                let p = prob_map.get(&key).copied().unwrap_or(0.0);
4435                1.0 - p
4436            })
4437            .collect();
4438
4439        let complement_arr = Float64Array::from(complements);
4440        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
4441        columns.push(Arc::new(complement_arr));
4442
4443        let mut fields: Vec<Arc<arrow_schema::Field>> =
4444            batch.schema().fields().iter().cloned().collect();
4445        fields.push(Arc::new(arrow_schema::Field::new(
4446            complement_col_name,
4447            arrow_schema::DataType::Float64,
4448            true,
4449        )));
4450
4451        let new_schema = Arc::new(arrow_schema::Schema::new(fields));
4452        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
4453        result.push(new_batch);
4454    }
4455    Ok(result)
4456}
4457
4458/// Boolean anti-join for composite (multi-column) join keys.
4459///
4460/// Builds a `HashSet<Vec<u64>>` from `neg_facts` using all right-side
4461/// columns in `join_cols`, then filters `batches` to keep only rows
4462/// whose composite left-side key is NOT in the set.
4463pub fn apply_anti_join_composite(
4464    batches: Vec<RecordBatch>,
4465    neg_facts: &[RecordBatch],
4466    join_cols: &[(String, String)],
4467) -> datafusion::error::Result<Vec<RecordBatch>> {
4468    use arrow::compute::filter_record_batch;
4469    use arrow_array::{Array as _, BooleanArray, UInt64Array};
4470
4471    // Collect composite keys from the negated rule's derived facts.
4472    let mut banned: HashSet<Vec<u64>> = HashSet::new();
4473    for batch in neg_facts {
4474        let right_indices: Vec<usize> = join_cols
4475            .iter()
4476            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
4477            .collect();
4478        if right_indices.len() != join_cols.len() {
4479            continue;
4480        }
4481        for row in 0..batch.num_rows() {
4482            let mut key = Vec::with_capacity(right_indices.len());
4483            let mut valid = true;
4484            for &ci in &right_indices {
4485                let col = batch.column(ci);
4486                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
4487                    if vids.is_null(row) {
4488                        valid = false;
4489                        break;
4490                    }
4491                    key.push(vids.value(row));
4492                } else {
4493                    valid = false;
4494                    break;
4495                }
4496            }
4497            if valid {
4498                banned.insert(key);
4499            }
4500        }
4501    }
4502
4503    if banned.is_empty() {
4504        return Ok(batches);
4505    }
4506
4507    // Filter body batches: keep rows where composite left key NOT IN banned.
4508    let mut result = Vec::new();
4509    for batch in batches {
4510        let left_indices: Vec<usize> = join_cols
4511            .iter()
4512            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
4513            .collect();
4514        if left_indices.len() != join_cols.len() {
4515            result.push(batch);
4516            continue;
4517        }
4518        let all_u64 = left_indices.iter().all(|&ci| {
4519            batch
4520                .column(ci)
4521                .as_any()
4522                .downcast_ref::<UInt64Array>()
4523                .is_some()
4524        });
4525        if !all_u64 {
4526            result.push(batch);
4527            continue;
4528        }
4529
4530        let keep: Vec<bool> = (0..batch.num_rows())
4531            .map(|row| {
4532                let mut key = Vec::with_capacity(left_indices.len());
4533                for &ci in &left_indices {
4534                    let vids = batch
4535                        .column(ci)
4536                        .as_any()
4537                        .downcast_ref::<UInt64Array>()
4538                        .unwrap();
4539                    if vids.is_null(row) {
4540                        return true; // null keys are never banned
4541                    }
4542                    key.push(vids.value(row));
4543                }
4544                !banned.contains(&key)
4545            })
4546            .collect();
4547        let keep_arr = BooleanArray::from(keep);
4548        let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
4549        if filtered.num_rows() > 0 {
4550            result.push(filtered);
4551        }
4552    }
4553    Ok(result)
4554}
4555
4556/// Multiply `__prob_complement_*` columns into the rule's PROB column and clean up.
4557///
4558/// After IS NOT probabilistic complement semantics have added `__prob_complement_*`
4559/// columns to clause results, this function:
4560/// 1. Computes the product of all complement factor columns
4561/// 2. Multiplies the product into the existing PROB column (if any)
4562/// 3. Removes the internal `__prob_complement_*` columns from the output
4563///
4564/// If the rule has no PROB column, complement columns are simply removed
4565/// (the complement information is discarded and IS NOT acts as a keep-all).
4566pub fn multiply_prob_factors(
4567    batches: Vec<RecordBatch>,
4568    prob_col: Option<&str>,
4569    complement_cols: &[String],
4570) -> datafusion::error::Result<Vec<RecordBatch>> {
4571    use arrow_array::{Array as _, Float64Array};
4572
4573    let mut result = Vec::with_capacity(batches.len());
4574
4575    for batch in batches {
4576        if batch.num_rows() == 0 {
4577            // Remove complement columns from empty batches
4578            let keep: Vec<usize> = batch
4579                .schema()
4580                .fields()
4581                .iter()
4582                .enumerate()
4583                .filter(|(_, f)| !complement_cols.contains(f.name()))
4584                .map(|(i, _)| i)
4585                .collect();
4586            let fields: Vec<_> = keep
4587                .iter()
4588                .map(|&i| batch.schema().field(i).clone())
4589                .collect();
4590            let cols: Vec<_> = keep.iter().map(|&i| batch.column(i).clone()).collect();
4591            let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4592            result.push(
4593                RecordBatch::try_new(schema, cols).map_err(|e| {
4594                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
4595                })?,
4596            );
4597            continue;
4598        }
4599
4600        let num_rows = batch.num_rows();
4601
4602        // 1. Compute product of all complement factors
4603        let mut combined = vec![1.0f64; num_rows];
4604        for col_name in complement_cols {
4605            if let Ok(idx) = batch.schema().index_of(col_name) {
4606                let arr = batch
4607                    .column(idx)
4608                    .as_any()
4609                    .downcast_ref::<Float64Array>()
4610                    .ok_or_else(|| {
4611                        datafusion::error::DataFusionError::Internal(format!(
4612                            "Expected Float64 for complement column {col_name}"
4613                        ))
4614                    })?;
4615                for (i, val) in combined.iter_mut().enumerate().take(num_rows) {
4616                    if !arr.is_null(i) {
4617                        *val *= arr.value(i);
4618                    }
4619                }
4620            }
4621        }
4622
4623        // 2. If there's a PROB column, multiply combined into it
4624        let final_prob: Vec<f64> = if let Some(prob_name) = prob_col {
4625            if let Ok(idx) = batch.schema().index_of(prob_name) {
4626                let arr = batch
4627                    .column(idx)
4628                    .as_any()
4629                    .downcast_ref::<Float64Array>()
4630                    .ok_or_else(|| {
4631                        datafusion::error::DataFusionError::Internal(format!(
4632                            "Expected Float64 for PROB column {prob_name}"
4633                        ))
4634                    })?;
4635                (0..num_rows)
4636                    .map(|i| {
4637                        if arr.is_null(i) {
4638                            combined[i]
4639                        } else {
4640                            arr.value(i) * combined[i]
4641                        }
4642                    })
4643                    .collect()
4644            } else {
4645                combined
4646            }
4647        } else {
4648            combined
4649        };
4650
4651        let new_prob_array: arrow_array::ArrayRef =
4652            std::sync::Arc::new(Float64Array::from(final_prob));
4653
4654        // 3. Build output: replace PROB column, remove complement columns
4655        let mut fields = Vec::new();
4656        let mut columns = Vec::new();
4657
4658        for (idx, field) in batch.schema().fields().iter().enumerate() {
4659            if complement_cols.contains(field.name()) {
4660                continue;
4661            }
4662            if prob_col.is_some_and(|p| field.name() == p) {
4663                fields.push(field.clone());
4664                columns.push(new_prob_array.clone());
4665            } else {
4666                fields.push(field.clone());
4667                columns.push(batch.column(idx).clone());
4668            }
4669        }
4670
4671        let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4672        result.push(RecordBatch::try_new(schema, columns).map_err(arrow_err)?);
4673    }
4674
4675    Ok(result)
4676}
4677
4678/// Update derived scan handles before evaluating a rule's clause bodies.
4679///
4680/// For self-references: inject delta (semi-naive optimization).
4681/// For cross-references: inject full facts.
4682fn update_derived_scan_handles(
4683    registry: &DerivedScanRegistry,
4684    states: &[FixpointState],
4685    current_rule_idx: usize,
4686    rules: &[FixpointRulePlan],
4687) {
4688    let current_rule_name = &rules[current_rule_idx].name;
4689
4690    for entry in &registry.entries {
4691        // Find the state for this entry's rule
4692        let source_state_idx = rules.iter().position(|r| r.name == entry.rule_name);
4693        let Some(source_idx) = source_state_idx else {
4694            continue;
4695        };
4696
4697        let is_self = entry.rule_name == *current_rule_name;
4698        let data = if is_self && !rules[current_rule_idx].non_linear {
4699            // Self-ref in a linear rule: inject delta for semi-naive
4700            states[source_idx].all_delta().to_vec()
4701        } else {
4702            // Cross-ref, or self-ref of a non-linear rule (≥2 same-stratum
4703            // refs in one clause — Δ×Δ would miss Δ×F_old): inject full facts
4704            states[source_idx].all_facts().to_vec()
4705        };
4706
4707        // If empty, write an empty batch so the scan returns zero rows
4708        let data = if data.is_empty() || data.iter().all(|b| b.num_rows() == 0) {
4709            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
4710        } else {
4711            data
4712        };
4713
4714        let mut guard = entry.data.write();
4715        *guard = data;
4716    }
4717}
4718
4719// ---------------------------------------------------------------------------
4720// DerivedScanExec — physical plan that reads from shared data at execution time
4721// ---------------------------------------------------------------------------
4722
4723/// Physical plan for `LocyDerivedScan` that reads from a shared `Arc<RwLock>` at
4724/// execution time (not at plan creation time).
4725///
4726/// This is critical for fixpoint iteration: the data handle is updated between
4727/// iterations, and each re-execution of the subplan must read the latest data.
4728pub struct DerivedScanExec {
4729    data: Arc<RwLock<Vec<RecordBatch>>>,
4730    schema: SchemaRef,
4731    properties: Arc<PlanProperties>,
4732}
4733
4734impl DerivedScanExec {
4735    pub fn new(data: Arc<RwLock<Vec<RecordBatch>>>, schema: SchemaRef) -> Self {
4736        let properties = compute_plan_properties(Arc::clone(&schema));
4737        Self {
4738            data,
4739            schema,
4740            properties,
4741        }
4742    }
4743}
4744
4745impl fmt::Debug for DerivedScanExec {
4746    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4747        f.debug_struct("DerivedScanExec")
4748            .field("schema", &self.schema)
4749            .finish()
4750    }
4751}
4752
4753impl DisplayAs for DerivedScanExec {
4754    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4755        write!(f, "DerivedScanExec")
4756    }
4757}
4758
4759impl ExecutionPlan for DerivedScanExec {
4760    fn name(&self) -> &str {
4761        "DerivedScanExec"
4762    }
4763    fn as_any(&self) -> &dyn Any {
4764        self
4765    }
4766    fn schema(&self) -> SchemaRef {
4767        Arc::clone(&self.schema)
4768    }
4769    fn properties(&self) -> &Arc<PlanProperties> {
4770        &self.properties
4771    }
4772    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
4773        vec![]
4774    }
4775    fn with_new_children(
4776        self: Arc<Self>,
4777        _children: Vec<Arc<dyn ExecutionPlan>>,
4778    ) -> DFResult<Arc<dyn ExecutionPlan>> {
4779        Ok(self)
4780    }
4781    fn execute(
4782        &self,
4783        _partition: usize,
4784        _context: Arc<TaskContext>,
4785    ) -> DFResult<SendableRecordBatchStream> {
4786        let batches = {
4787            let guard = self.data.read();
4788            if guard.is_empty() {
4789                vec![RecordBatch::new_empty(Arc::clone(&self.schema))]
4790            } else {
4791                // Re-stamp every batch with this exec's schema. The shared
4792                // data Arc always holds batches with the rule's original
4793                // yield-schema names, but this scan may carry per-occurrence
4794                // aliased column names (multi-IS-ref clauses). Zero-copy:
4795                // only the schema pointer changes, never the columns.
4796                guard
4797                    .iter()
4798                    .map(|b| {
4799                        RecordBatch::try_new(Arc::clone(&self.schema), b.columns().to_vec())
4800                            .map_err(|e| {
4801                                datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
4802                            })
4803                    })
4804                    .collect::<DFResult<Vec<_>>>()?
4805            }
4806        };
4807        Ok(Box::pin(MemoryStream::try_new(
4808            batches,
4809            Arc::clone(&self.schema),
4810            None,
4811        )?))
4812    }
4813}
4814
4815// ---------------------------------------------------------------------------
4816// InMemoryExec — wrapper to feed Vec<RecordBatch> into operator chains
4817// ---------------------------------------------------------------------------
4818
4819/// Simple in-memory execution plan that serves pre-computed batches.
4820///
4821/// Used internally to feed fixpoint results into post-fixpoint operator chains
4822/// (FOLD, BEST BY). Not exported — only used within this module.
4823struct InMemoryExec {
4824    batches: Vec<RecordBatch>,
4825    schema: SchemaRef,
4826    properties: Arc<PlanProperties>,
4827}
4828
4829impl InMemoryExec {
4830    fn new(batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
4831        let properties = compute_plan_properties(Arc::clone(&schema));
4832        Self {
4833            batches,
4834            schema,
4835            properties,
4836        }
4837    }
4838}
4839
4840impl fmt::Debug for InMemoryExec {
4841    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4842        f.debug_struct("InMemoryExec")
4843            .field("num_batches", &self.batches.len())
4844            .field("schema", &self.schema)
4845            .finish()
4846    }
4847}
4848
4849impl DisplayAs for InMemoryExec {
4850    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4851        write!(f, "InMemoryExec: batches={}", self.batches.len())
4852    }
4853}
4854
4855impl ExecutionPlan for InMemoryExec {
4856    fn name(&self) -> &str {
4857        "InMemoryExec"
4858    }
4859    fn as_any(&self) -> &dyn Any {
4860        self
4861    }
4862    fn schema(&self) -> SchemaRef {
4863        Arc::clone(&self.schema)
4864    }
4865    fn properties(&self) -> &Arc<PlanProperties> {
4866        &self.properties
4867    }
4868    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
4869        vec![]
4870    }
4871    fn with_new_children(
4872        self: Arc<Self>,
4873        _children: Vec<Arc<dyn ExecutionPlan>>,
4874    ) -> DFResult<Arc<dyn ExecutionPlan>> {
4875        Ok(self)
4876    }
4877    fn execute(
4878        &self,
4879        _partition: usize,
4880        _context: Arc<TaskContext>,
4881    ) -> DFResult<SendableRecordBatchStream> {
4882        Ok(Box::pin(MemoryStream::try_new(
4883            self.batches.clone(),
4884            Arc::clone(&self.schema),
4885            None,
4886        )?))
4887    }
4888}
4889
4890// ---------------------------------------------------------------------------
4891// Post-fixpoint chain — FOLD and BEST BY on converged facts
4892// ---------------------------------------------------------------------------
4893
4894/// Apply post-FOLD WHERE (HAVING) filter to aggregated batches.
4895///
4896/// Converts each Cypher HAVING expression to a DataFusion physical expression
4897/// via `cypher_expr_to_df` → type coercion → `create_physical_expr`, evaluates
4898/// against the FOLD output, and keeps only rows where all conditions hold.
4899fn apply_having_filter(
4900    batches: Vec<RecordBatch>,
4901    having_exprs: &[Expr],
4902    schema: &SchemaRef,
4903    task_ctx: &Arc<TaskContext>,
4904) -> DFResult<Vec<RecordBatch>> {
4905    use arrow::compute::{and, filter_record_batch};
4906    use arrow_array::BooleanArray;
4907    use datafusion::common::DFSchema;
4908    use datafusion::logical_expr::LogicalPlanBuilder;
4909    use datafusion::logical_expr::execution_props::ExecutionProps;
4910    use datafusion::optimizer::AnalyzerRule;
4911    use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
4912    use datafusion::physical_expr::create_physical_expr;
4913
4914    if batches.is_empty() {
4915        return Ok(batches);
4916    }
4917
4918    // Build DFSchema from the FOLD output Arrow schema.
4919    let df_schema = DFSchema::try_from(schema.as_ref().clone()).map_err(|e| {
4920        datafusion::common::DataFusionError::Internal(format!("HAVING schema conversion: {e}"))
4921    })?;
4922
4923    // Use the active TaskContext's config rather than allocating a fresh
4924    // `SessionContext` per HAVING evaluation (~130 µs/call). HAVING uses only
4925    // built-in DataFusion arithmetic — no Cypher UDFs — so a default
4926    // `ExecutionProps` is sufficient (it's documented as cheap to construct).
4927    let config = (**task_ctx.session_config().options()).clone();
4928    let props = ExecutionProps::new();
4929
4930    // Cypher Expr → DataFusion DfExpr → type-coerced DfExpr → PhysicalExpr.
4931    //
4932    // Type coercion is needed because FOLD aggregates produce Float64 (SUM,
4933    // AVG) or Int64 (COUNT), and literal comparisons like `total >= 100`
4934    // may mix Float64 columns with Int64 literals.
4935    let physical_exprs: Vec<Arc<dyn datafusion::physical_expr::PhysicalExpr>> = having_exprs
4936        .iter()
4937        .map(|expr| {
4938            let df_expr = crate::query::df_expr::cypher_expr_to_df(expr, None).map_err(|e| {
4939                datafusion::common::DataFusionError::Internal(format!(
4940                    "HAVING expression conversion: {e}"
4941                ))
4942            })?;
4943
4944            // Run DataFusion's type coercion by wrapping in a Filter plan,
4945            // applying the TypeCoercion analyzer rule, then extracting the
4946            // coerced predicate.
4947            let empty = datafusion::logical_expr::LogicalPlan::EmptyRelation(
4948                datafusion::logical_expr::EmptyRelation {
4949                    produce_one_row: false,
4950                    schema: Arc::new(df_schema.clone()),
4951                },
4952            );
4953            let filter_plan = LogicalPlanBuilder::from(empty)
4954                .filter(df_expr.clone())?
4955                .build()?;
4956            let coerced_expr = match TypeCoercion::new().analyze(filter_plan, &config) {
4957                Ok(datafusion::logical_expr::LogicalPlan::Filter(f)) => f.predicate,
4958                _ => df_expr,
4959            };
4960
4961            create_physical_expr(&coerced_expr, &df_schema, &props)
4962        })
4963        .collect::<DFResult<Vec<_>>>()?;
4964
4965    let mut result = Vec::new();
4966    for batch in batches {
4967        // Evaluate each condition and AND the boolean masks.
4968        let mut mask: Option<BooleanArray> = None;
4969        for phys_expr in &physical_exprs {
4970            let value = phys_expr.evaluate(&batch)?;
4971            let arr = value.into_array(batch.num_rows())?;
4972            let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().ok_or_else(|| {
4973                datafusion::common::DataFusionError::Internal(
4974                    "HAVING condition must evaluate to boolean".into(),
4975                )
4976            })?;
4977            mask = Some(match mask {
4978                None => bool_arr.clone(),
4979                Some(prev) => and(&prev, bool_arr).map_err(arrow_err)?,
4980            });
4981        }
4982        if let Some(ref m) = mask {
4983            let filtered = filter_record_batch(&batch, m).map_err(arrow_err)?;
4984            if filtered.num_rows() > 0 {
4985                result.push(filtered);
4986            }
4987        } else {
4988            result.push(batch);
4989        }
4990    }
4991    Ok(result)
4992}
4993
4994/// Apply post-fixpoint operators (FOLD, HAVING, BEST BY, PRIORITY) to converged facts.
4995#[allow(
4996    clippy::too_many_arguments,
4997    reason = "context bundle would be over-engineering for one call site"
4998)]
4999pub(crate) async fn apply_post_fixpoint_chain(
5000    facts: Vec<RecordBatch>,
5001    rule: &FixpointRulePlan,
5002    task_ctx: &Arc<TaskContext>,
5003    strict_probability_domain: bool,
5004    probability_epsilon: f64,
5005    semiring_kind: SemiringKind,
5006    provenance_tracker: Option<Arc<ProvenanceStore>>,
5007    top_k_proofs_k: usize,
5008    registry: Option<Arc<DerivedScanRegistry>>,
5009) -> DFResult<Vec<RecordBatch>> {
5010    if !rule.has_fold && !rule.has_best_by && !rule.has_priority && rule.having.is_empty() {
5011        return Ok(facts);
5012    }
5013
5014    // Wrap facts in InMemoryExec.
5015    // Prefer the actual batch schema (from physical execution) over the
5016    // pre-computed yield_schema, which may have wrong inferred types
5017    // (e.g. Float64 for a string property).
5018    let schema = facts
5019        .iter()
5020        .find(|b| b.num_rows() > 0)
5021        .map(|b| b.schema())
5022        .unwrap_or_else(|| Arc::clone(&rule.yield_schema));
5023
5024    // Phase D D-C0: pre-compute body-row → IS-ref support map for
5025    // TopKProofs MNOR's DNF inclusion-exclusion math. Must be built
5026    // here because `facts` is moved into `InMemoryExec` on the next
5027    // line. The map is keyed by a full-column row hash — only
5028    // meaningful when no downstream plan node strips/adds columns
5029    // between this batch view and the FoldExec input. PRIORITY drops
5030    // the `__priority` column, which would change row hashes; until
5031    // we plumb the map past PRIORITY, skip map construction for
5032    // PRIORITY rules (the failing TCK test doesn't use PRIORITY).
5033    // Read the active K from `semiring_kind` rather than the separate
5034    // `top_k_proofs_k` parameter — the latter is not always threaded
5035    // from the LocyProgram config (the semiring's `k` is the source of
5036    // truth).
5037    let topk_k: Option<usize> = match semiring_kind {
5038        SemiringKind::TopKProofs { k } if k > 0 => Some(k as usize),
5039        _ => None,
5040    };
5041    let body_support_map: Option<Arc<HashMap<Vec<u8>, Vec<ProofTerm>>>> = if topk_k.is_some()
5042        && !rule.has_priority
5043        && let Some(registry) = registry.as_ref()
5044    {
5045        let mut map: HashMap<Vec<u8>, Vec<ProofTerm>> = HashMap::new();
5046        for batch in &facts {
5047            let all_indices: Vec<usize> = (0..batch.num_columns()).collect();
5048            for row_idx in 0..batch.num_rows() {
5049                let support = collect_is_ref_inputs_for_body_row(rule, batch, row_idx, registry);
5050                if support.is_empty() {
5051                    continue;
5052                }
5053                let hash = fact_hash_key(batch, &all_indices, row_idx);
5054                map.insert(hash, support);
5055            }
5056        }
5057        if map.is_empty() {
5058            None
5059        } else {
5060            Some(Arc::new(map))
5061        }
5062    } else {
5063        None
5064    };
5065
5066    let input: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(facts, schema.clone()));
5067
5068    // Reconcile key indices: rule's indices are yield-schema positions but
5069    // the actual batch may have different column ordering after schema
5070    // reconciliation during fixpoint iteration (same pattern as
5071    // FixpointState::reconcile_schema).
5072    let key_column_indices: Vec<usize> = rule
5073        .key_column_indices
5074        .iter()
5075        .filter_map(|&i| {
5076            let name = rule.yield_schema.field(i).name();
5077            schema.index_of(name).ok()
5078        })
5079        .collect();
5080
5081    // Apply PRIORITY first — keeps only rows with max __priority per KEY group,
5082    // then strips the __priority column from output.
5083    // Must run before FOLD so that the __priority column is still present.
5084    let current: Arc<dyn ExecutionPlan> = if rule.has_priority {
5085        let priority_schema = input.schema();
5086        let priority_idx = priority_schema.index_of("__priority").map_err(|_| {
5087            datafusion::common::DataFusionError::Internal(
5088                "PRIORITY rule missing __priority column".to_string(),
5089            )
5090        })?;
5091        Arc::new(PriorityExec::new(
5092            input,
5093            key_column_indices.clone(),
5094            priority_idx,
5095        ))
5096    } else {
5097        input
5098    };
5099
5100    // Apply FOLD
5101    let current: Arc<dyn ExecutionPlan> = if rule.has_fold && !rule.fold_bindings.is_empty() {
5102        Arc::new(FoldExec::new_with_topk(
5103            current,
5104            key_column_indices.clone(),
5105            rule.fold_bindings.clone(),
5106            strict_probability_domain,
5107            probability_epsilon,
5108            semiring_kind,
5109            provenance_tracker.clone(),
5110            topk_k.unwrap_or(top_k_proofs_k),
5111            body_support_map.clone(),
5112        ))
5113    } else {
5114        current
5115    };
5116
5117    // Apply HAVING (post-FOLD WHERE filter)
5118    let current: Arc<dyn ExecutionPlan> = if !rule.having.is_empty() {
5119        let batches = collect_all_partitions(&current, Arc::clone(task_ctx)).await?;
5120        let filtered = apply_having_filter(batches, &rule.having, &current.schema(), task_ctx)?;
5121        if filtered.is_empty() {
5122            return Ok(filtered);
5123        }
5124        Arc::new(InMemoryExec::new(filtered, Arc::clone(&current.schema())))
5125    } else {
5126        current
5127    };
5128
5129    // Apply BEST BY
5130    let current: Arc<dyn ExecutionPlan> = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
5131        Arc::new(BestByExec::new(
5132            current,
5133            key_column_indices.clone(),
5134            rule.best_by_criteria.clone(),
5135            rule.deterministic,
5136        ))
5137    } else {
5138        current
5139    };
5140
5141    collect_all_partitions(&current, Arc::clone(task_ctx)).await
5142}
5143
5144// ---------------------------------------------------------------------------
5145// FixpointExec — DataFusion ExecutionPlan
5146// ---------------------------------------------------------------------------
5147
/// DataFusion `ExecutionPlan` that drives semi-naive fixpoint iteration.
///
/// Has no physical children: clause bodies are re-planned from logical plans
/// on each iteration (same pattern as `RecursiveCTEExec` and `GraphApplyExec`).
///
/// The fields are configuration values and shared handles. `execute` copies
/// or `Arc`-clones each of them into the future that runs
/// `run_fixpoint_loop`.
pub struct FixpointExec {
    /// Rule plans of this stratum; deep-copied into the loop future on `execute`.
    rules: Vec<FixpointRulePlan>,
    /// Iteration limit forwarded to `run_fixpoint_loop` (shown as `max_iter`
    /// in the plan display).
    max_iterations: usize,
    /// Time budget forwarded to `run_fixpoint_loop`.
    timeout: Duration,
    /// Graph execution context forwarded to `run_fixpoint_loop`.
    graph_ctx: Arc<GraphExecutionContext>,
    /// Shared DataFusion session context forwarded to `run_fixpoint_loop`.
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    /// Storage manager handle forwarded to `run_fixpoint_loop`.
    storage: Arc<StorageManager>,
    /// Schema handle forwarded to `run_fixpoint_loop`.
    schema_info: Arc<UniSchema>,
    /// Query parameters by name, forwarded to `run_fixpoint_loop`.
    params: HashMap<String, Value>,
    /// Derived-scan registry forwarded to `run_fixpoint_loop`.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    /// Schema reported by `schema()` and by the result stream.
    output_schema: SchemaRef,
    /// Plan properties computed from `output_schema` at construction.
    properties: Arc<PlanProperties>,
    /// DataFusion metrics set; `execute` registers per-partition baseline
    /// metrics against it.
    metrics: ExecutionPlanMetricsSet,
    /// Byte limit forwarded to `run_fixpoint_loop`
    /// (NOTE(review): presumably a cap on derived-fact memory — confirm there).
    max_derived_bytes: usize,
    /// Optional provenance tracker populated during fixpoint iteration.
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    /// Shared slot written with per-rule iteration counts after convergence.
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    /// Probability-domain strictness flag forwarded to `run_fixpoint_loop`.
    strict_probability_domain: bool,
    /// Probability tolerance forwarded to `run_fixpoint_loop`.
    probability_epsilon: f64,
    /// Exact-probability flag forwarded to `run_fixpoint_loop`.
    exact_probability: bool,
    /// BDD variable limit forwarded to `run_fixpoint_loop`.
    max_bdd_variables: usize,
    /// Shared slot for runtime warnings collected during fixpoint iteration.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    /// Shared slot for groups where BDD fell back to independence mode.
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    /// When > 0, retain at most this many proofs per fact (top-k provenance).
    top_k_proofs: usize,
    /// Shared flag: set to true on timeout to signal partial results.
    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
    /// Active probability semiring (rollout D-7).
    semiring_kind: SemiringKind,
    /// Phase B Slice 3 registry of neural classifiers, keyed by the
    /// model name from `CREATE MODEL`. Held by `Arc` so executor clones
    /// share the same underlying map.
    classifier_registry: Arc<ClassifierRegistry>,
    /// Phase B follow-up: optional per-evaluation memoization cache
    /// for classifier outputs keyed by `(model_name, feature_hash)`.
    /// `None` → no caching; `Some` → cache shared across fixpoint
    /// iterations (and optionally across the entire query / multiple
    /// queries, when the caller threads it via `LocyConfig`).
    classifier_cache: Option<Arc<ModelInvocationCache>>,
    /// Phase C B1-B3 follow-up: per-query side-channel store
    /// for (raw, calibrated, confidence_band) records. Read by
    /// EXPLAIN; not used by the fixpoint inner loop directly
    /// (LocyModelInvokeExec writes; this struct just carries
    /// the Arc to keep the type wiring consistent across the
    /// LocyProgramExec/FixpointExec boundary).
    #[allow(
        dead_code,
        reason = "boundary plumbing; read by EXPLAIN via LocyModelInvokeExec"
    )]
    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
    /// Optional per-stratum profile collector. `Some` only on the `profile()`
    /// path; when set, each fixpoint iteration records per-rule timing, delta
    /// facts, and the clause-body operator tree into it. `None` → zero overhead.
    profile_collector: Option<Arc<LocyProfileCollector>>,
}
5210
5211impl fmt::Debug for FixpointExec {
5212    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5213        f.debug_struct("FixpointExec")
5214            .field("rules_count", &self.rules.len())
5215            .field("max_iterations", &self.max_iterations)
5216            .field("timeout", &self.timeout)
5217            .field("output_schema", &self.output_schema)
5218            .field("max_derived_bytes", &self.max_derived_bytes)
5219            .finish_non_exhaustive()
5220    }
5221}
5222
5223impl FixpointExec {
5224    /// Create a new `FixpointExec`.
5225    #[expect(
5226        clippy::too_many_arguments,
5227        reason = "FixpointExec configuration needs all context"
5228    )]
5229    #[deprecated(
5230        note = "use `new_with_semiring_classifiers_and_cache` (or the lighter \
5231                `new_with_semiring_and_classifiers` / `new_with_semiring`) — \
5232                this legacy ctor defaults the semiring to AddMultProb and \
5233                ships no classifier registry, which the Phase B+ runtime needs \
5234                explicitly. To be removed after C0 Stage 2."
5235    )]
5236    pub fn new(
5237        rules: Vec<FixpointRulePlan>,
5238        max_iterations: usize,
5239        timeout: Duration,
5240        graph_ctx: Arc<GraphExecutionContext>,
5241        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5242        storage: Arc<StorageManager>,
5243        schema_info: Arc<UniSchema>,
5244        params: HashMap<String, Value>,
5245        derived_scan_registry: Arc<DerivedScanRegistry>,
5246        output_schema: SchemaRef,
5247        max_derived_bytes: usize,
5248        derivation_tracker: Option<Arc<ProvenanceStore>>,
5249        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5250        strict_probability_domain: bool,
5251        probability_epsilon: f64,
5252        exact_probability: bool,
5253        max_bdd_variables: usize,
5254        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5255        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5256        top_k_proofs: usize,
5257        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5258    ) -> Self {
5259        Self::new_with_semiring_and_classifiers(
5260            rules,
5261            max_iterations,
5262            timeout,
5263            graph_ctx,
5264            session_ctx,
5265            storage,
5266            schema_info,
5267            params,
5268            derived_scan_registry,
5269            output_schema,
5270            max_derived_bytes,
5271            derivation_tracker,
5272            iteration_counts,
5273            strict_probability_domain,
5274            probability_epsilon,
5275            exact_probability,
5276            max_bdd_variables,
5277            warnings_slot,
5278            approximate_slot,
5279            top_k_proofs,
5280            timeout_flag,
5281            SemiringKind::AddMultProb,
5282            Arc::new(ClassifierRegistry::new()),
5283        )
5284    }
5285
    /// Variant accepting an explicit [`SemiringKind`].
    ///
    /// Forwards every argument unchanged to
    /// [`FixpointExec::new_with_semiring_and_classifiers`] and supplies a
    /// freshly created, empty [`ClassifierRegistry`]. Use that constructor
    /// (or [`FixpointExec::new_with_semiring_classifiers_and_cache`]) when
    /// classifiers, a classifier cache, or a provenance store are needed.
    #[expect(
        clippy::too_many_arguments,
        reason = "FixpointExec configuration needs all context"
    )]
    pub fn new_with_semiring(
        rules: Vec<FixpointRulePlan>,
        max_iterations: usize,
        timeout: Duration,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        output_schema: SchemaRef,
        max_derived_bytes: usize,
        derivation_tracker: Option<Arc<ProvenanceStore>>,
        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
        top_k_proofs: usize,
        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
        semiring_kind: SemiringKind,
    ) -> Self {
        // Pure positional forwarding; the only value added here is the
        // empty classifier registry in the final slot.
        Self::new_with_semiring_and_classifiers(
            rules,
            max_iterations,
            timeout,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            derived_scan_registry,
            output_schema,
            max_derived_bytes,
            derivation_tracker,
            iteration_counts,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            warnings_slot,
            approximate_slot,
            top_k_proofs,
            timeout_flag,
            semiring_kind,
            Arc::new(ClassifierRegistry::new()),
        )
    }
5343
    /// Phase B Slice 3 entry: accepts both the semiring kind and the
    /// runtime classifier registry. The planner uses this when the
    /// program contains `CREATE MODEL` declarations.
    ///
    /// Forwards every argument unchanged to
    /// [`FixpointExec::new_with_semiring_classifiers_and_cache`], passing
    /// `None` for both the classifier cache and the classifier provenance
    /// store.
    #[expect(
        clippy::too_many_arguments,
        reason = "FixpointExec configuration needs all context"
    )]
    pub fn new_with_semiring_and_classifiers(
        rules: Vec<FixpointRulePlan>,
        max_iterations: usize,
        timeout: Duration,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        output_schema: SchemaRef,
        max_derived_bytes: usize,
        derivation_tracker: Option<Arc<ProvenanceStore>>,
        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
        top_k_proofs: usize,
        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
        semiring_kind: SemiringKind,
        classifier_registry: Arc<ClassifierRegistry>,
    ) -> Self {
        Self::new_with_semiring_classifiers_and_cache(
            rules,
            max_iterations,
            timeout,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            derived_scan_registry,
            output_schema,
            max_derived_bytes,
            derivation_tracker,
            iteration_counts,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            warnings_slot,
            approximate_slot,
            top_k_proofs,
            timeout_flag,
            semiring_kind,
            classifier_registry,
            // classifier_cache: no memoization cache.
            None,
            // classifier_provenance_store: no provenance side-channel.
            None,
        )
    }
5404
5405    /// Phase B follow-up: full constructor accepting an optional
5406    /// memoization cache. Existing callers default to `None` (no cache);
5407    /// the impl_locy.rs entry passes the user's `config.classifier_cache`.
5408    #[expect(
5409        clippy::too_many_arguments,
5410        reason = "FixpointExec configuration needs all context"
5411    )]
5412    pub fn new_with_semiring_classifiers_and_cache(
5413        rules: Vec<FixpointRulePlan>,
5414        max_iterations: usize,
5415        timeout: Duration,
5416        graph_ctx: Arc<GraphExecutionContext>,
5417        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5418        storage: Arc<StorageManager>,
5419        schema_info: Arc<UniSchema>,
5420        params: HashMap<String, Value>,
5421        derived_scan_registry: Arc<DerivedScanRegistry>,
5422        output_schema: SchemaRef,
5423        max_derived_bytes: usize,
5424        derivation_tracker: Option<Arc<ProvenanceStore>>,
5425        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5426        strict_probability_domain: bool,
5427        probability_epsilon: f64,
5428        exact_probability: bool,
5429        max_bdd_variables: usize,
5430        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5431        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5432        top_k_proofs: usize,
5433        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5434        semiring_kind: SemiringKind,
5435        classifier_registry: Arc<ClassifierRegistry>,
5436        classifier_cache: Option<Arc<ModelInvocationCache>>,
5437        classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
5438    ) -> Self {
5439        let properties = compute_plan_properties(Arc::clone(&output_schema));
5440        Self {
5441            rules,
5442            max_iterations,
5443            timeout,
5444            graph_ctx,
5445            session_ctx,
5446            storage,
5447            schema_info,
5448            params,
5449            derived_scan_registry,
5450            output_schema,
5451            properties,
5452            metrics: ExecutionPlanMetricsSet::new(),
5453            max_derived_bytes,
5454            derivation_tracker,
5455            iteration_counts,
5456            strict_probability_domain,
5457            probability_epsilon,
5458            exact_probability,
5459            max_bdd_variables,
5460            warnings_slot,
5461            approximate_slot,
5462            top_k_proofs,
5463            timeout_flag,
5464            semiring_kind,
5465            classifier_registry,
5466            classifier_cache,
5467            classifier_provenance_store,
5468            profile_collector: None,
5469        }
5470    }
5471
    /// Returns the shared iteration counts slot for post-execution inspection.
    ///
    /// The returned `Arc` points at the same map this operator hands to the
    /// fixpoint loop, so counts written there are visible through it.
    pub fn iteration_counts(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
        Arc::clone(&self.iteration_counts)
    }
5476
    /// Attach a profile collector so this stratum's fixpoint records per-rule,
    /// per-iteration timing, delta facts, and clause-body operator metrics.
    ///
    /// Mirrors `set_derivation_tracker`: call before wrapping the exec in an
    /// `Arc` and executing. Only the Locy `profile()` path sets this.
    ///
    /// Any previously attached collector is replaced.
    pub fn set_profile_collector(&mut self, collector: Arc<LocyProfileCollector>) {
        self.profile_collector = Some(collector);
    }
5485}
5486
5487impl DisplayAs for FixpointExec {
5488    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5489        write!(
5490            f,
5491            "FixpointExec: rules=[{}], max_iter={}, timeout={:?}",
5492            self.rules
5493                .iter()
5494                .map(|r| r.name.as_str())
5495                .collect::<Vec<_>>()
5496                .join(", "),
5497            self.max_iterations,
5498            self.timeout,
5499        )
5500    }
5501}
5502
5503impl ExecutionPlan for FixpointExec {
5504    fn name(&self) -> &str {
5505        "FixpointExec"
5506    }
5507
5508    fn as_any(&self) -> &dyn Any {
5509        self
5510    }
5511
5512    fn schema(&self) -> SchemaRef {
5513        Arc::clone(&self.output_schema)
5514    }
5515
5516    fn properties(&self) -> &Arc<PlanProperties> {
5517        &self.properties
5518    }
5519
5520    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
5521        // No physical children — clause bodies are re-planned each iteration
5522        vec![]
5523    }
5524
5525    fn with_new_children(
5526        self: Arc<Self>,
5527        children: Vec<Arc<dyn ExecutionPlan>>,
5528    ) -> DFResult<Arc<dyn ExecutionPlan>> {
5529        if !children.is_empty() {
5530            return Err(datafusion::error::DataFusionError::Plan(
5531                "FixpointExec has no children".to_string(),
5532            ));
5533        }
5534        Ok(self)
5535    }
5536
5537    fn execute(
5538        &self,
5539        partition: usize,
5540        _context: Arc<TaskContext>,
5541    ) -> DFResult<SendableRecordBatchStream> {
5542        let metrics = BaselineMetrics::new(&self.metrics, partition);
5543
5544        // Clone all fields for the async closure
5545        let rules = self
5546            .rules
5547            .iter()
5548            .map(|r| {
5549                // We need to clone the FixpointRulePlan, but it contains LogicalPlan
5550                // which doesn't implement Clone traditionally. However, our LogicalPlan
5551                // does implement Clone since it's an enum.
5552                FixpointRulePlan {
5553                    name: r.name.clone(),
5554                    clauses: r
5555                        .clauses
5556                        .iter()
5557                        .map(|c| FixpointClausePlan {
5558                            body_logical: c.body_logical.clone(),
5559                            is_ref_bindings: c.is_ref_bindings.clone(),
5560                            priority: c.priority,
5561                            along_bindings: c.along_bindings.clone(),
5562                            model_invocations: c.model_invocations.clone(),
5563                        })
5564                        .collect(),
5565                    yield_schema: Arc::clone(&r.yield_schema),
5566                    key_column_indices: r.key_column_indices.clone(),
5567                    priority: r.priority,
5568                    has_fold: r.has_fold,
5569                    fold_bindings: r.fold_bindings.clone(),
5570                    having: r.having.clone(),
5571                    has_best_by: r.has_best_by,
5572                    best_by_criteria: r.best_by_criteria.clone(),
5573                    has_priority: r.has_priority,
5574                    deterministic: r.deterministic,
5575                    prob_column_name: r.prob_column_name.clone(),
5576                    non_linear: r.non_linear,
5577                }
5578            })
5579            .collect();
5580
5581        let max_iterations = self.max_iterations;
5582        let timeout = self.timeout;
5583        let graph_ctx = Arc::clone(&self.graph_ctx);
5584        let session_ctx = Arc::clone(&self.session_ctx);
5585        let storage = Arc::clone(&self.storage);
5586        let schema_info = Arc::clone(&self.schema_info);
5587        let params = self.params.clone();
5588        let registry = Arc::clone(&self.derived_scan_registry);
5589        let output_schema = Arc::clone(&self.output_schema);
5590        let max_derived_bytes = self.max_derived_bytes;
5591        let derivation_tracker = self.derivation_tracker.clone();
5592        let iteration_counts = Arc::clone(&self.iteration_counts);
5593        let strict_probability_domain = self.strict_probability_domain;
5594        let probability_epsilon = self.probability_epsilon;
5595        let exact_probability = self.exact_probability;
5596        let max_bdd_variables = self.max_bdd_variables;
5597        let warnings_slot = Arc::clone(&self.warnings_slot);
5598        let approximate_slot = Arc::clone(&self.approximate_slot);
5599        let top_k_proofs = self.top_k_proofs;
5600        let timeout_flag = Arc::clone(&self.timeout_flag);
5601        let semiring_kind = self.semiring_kind;
5602        let classifier_registry = Arc::clone(&self.classifier_registry);
5603        let classifier_cache = self.classifier_cache.as_ref().map(Arc::clone);
5604        let classifier_provenance_store = self.classifier_provenance_store.as_ref().map(Arc::clone);
5605        let profile_collector = self.profile_collector.as_ref().map(Arc::clone);
5606
5607        let fut = async move {
5608            run_fixpoint_loop(
5609                rules,
5610                max_iterations,
5611                timeout,
5612                graph_ctx,
5613                session_ctx,
5614                storage,
5615                schema_info,
5616                params,
5617                registry,
5618                output_schema,
5619                max_derived_bytes,
5620                derivation_tracker,
5621                iteration_counts,
5622                strict_probability_domain,
5623                probability_epsilon,
5624                exact_probability,
5625                max_bdd_variables,
5626                warnings_slot,
5627                approximate_slot,
5628                top_k_proofs,
5629                timeout_flag,
5630                semiring_kind,
5631                classifier_registry,
5632                classifier_cache,
5633                classifier_provenance_store,
5634                profile_collector,
5635            )
5636            .await
5637        };
5638
5639        Ok(Box::pin(FixpointStream {
5640            state: FixpointStreamState::Running(Box::pin(fut)),
5641            schema: Arc::clone(&self.output_schema),
5642            metrics,
5643        }))
5644    }
5645
5646    fn metrics(&self) -> Option<MetricsSet> {
5647        Some(self.metrics.clone_inner())
5648    }
5649}
5650
5651// ---------------------------------------------------------------------------
5652// FixpointStream — async state machine for streaming results
5653// ---------------------------------------------------------------------------
5654
/// Phases of a [`FixpointStream`]. The stream moves `Running` → `Emitting`
/// → `Done`, or straight from `Running` to `Done` when the loop fails or
/// yields no batches.
enum FixpointStreamState {
    /// Fixpoint loop is running.
    ///
    /// Holds the boxed future that resolves to all result batches.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// Emitting accumulated result batches one at a time.
    ///
    /// The `usize` is the index of the next batch to emit.
    Emitting(Vec<RecordBatch>, usize),
    /// All batches emitted.
    Done,
}
5663
/// Record-batch stream returned by `FixpointExec::execute`: drives the
/// fixpoint future to completion, then hands out its batches one per poll.
struct FixpointStream {
    /// Current phase (running the loop, emitting batches, or finished).
    state: FixpointStreamState,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Baseline metrics: compute time is timed on every poll and output rows
    /// are recorded per emitted batch.
    metrics: BaselineMetrics,
}
5669
5670impl Stream for FixpointStream {
5671    type Item = DFResult<RecordBatch>;
5672
5673    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
5674        let this = self.get_mut();
5675        let metrics = this.metrics.clone();
5676        let _timer = metrics.elapsed_compute().timer();
5677        loop {
5678            match &mut this.state {
5679                FixpointStreamState::Running(fut) => match fut.as_mut().poll(cx) {
5680                    Poll::Ready(Ok(batches)) => {
5681                        if batches.is_empty() {
5682                            this.state = FixpointStreamState::Done;
5683                            return Poll::Ready(None);
5684                        }
5685                        this.state = FixpointStreamState::Emitting(batches, 0);
5686                        // Loop to emit first batch
5687                    }
5688                    Poll::Ready(Err(e)) => {
5689                        this.state = FixpointStreamState::Done;
5690                        return Poll::Ready(Some(Err(e)));
5691                    }
5692                    Poll::Pending => return Poll::Pending,
5693                },
5694                FixpointStreamState::Emitting(batches, idx) => {
5695                    if *idx >= batches.len() {
5696                        this.state = FixpointStreamState::Done;
5697                        return Poll::Ready(None);
5698                    }
5699                    let batch = batches[*idx].clone();
5700                    *idx += 1;
5701                    this.metrics.record_output(batch.num_rows());
5702                    return Poll::Ready(Some(Ok(batch)));
5703                }
5704                FixpointStreamState::Done => return Poll::Ready(None),
5705            }
5706        }
5707    }
5708}
5709
5710impl RecordBatchStream for FixpointStream {
5711    fn schema(&self) -> SchemaRef {
5712        Arc::clone(&self.schema)
5713    }
5714}
5715
5716// ---------------------------------------------------------------------------
5717// Unit tests
5718// ---------------------------------------------------------------------------
5719
5720#[cfg(test)]
5721mod tests {
5722    use super::*;
5723    use arrow_array::{Float64Array, Int64Array, StringArray};
5724    use arrow_schema::{DataType, Field, Schema};
5725
5726    fn test_schema() -> SchemaRef {
5727        Arc::new(Schema::new(vec![
5728            Field::new("name", DataType::Utf8, true),
5729            Field::new("value", DataType::Int64, true),
5730        ]))
5731    }
5732
5733    fn make_batch(names: &[&str], values: &[i64]) -> RecordBatch {
5734        RecordBatch::try_new(
5735            test_schema(),
5736            vec![
5737                Arc::new(StringArray::from(
5738                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
5739                )),
5740                Arc::new(Int64Array::from(values.to_vec())),
5741            ],
5742        )
5743        .unwrap()
5744    }
5745
5746    // --- FixpointState dedup tests ---
5747
5748    #[tokio::test]
5749    async fn test_fixpoint_state_empty_facts_adds_all() {
5750        let schema = test_schema();
5751        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5752
5753        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
5754        let changed = state.merge_delta(vec![batch], None).await.unwrap();
5755
5756        assert!(changed);
5757        assert_eq!(state.all_facts().len(), 1);
5758        assert_eq!(state.all_facts()[0].num_rows(), 3);
5759        assert_eq!(state.all_delta().len(), 1);
5760        assert_eq!(state.all_delta()[0].num_rows(), 3);
5761    }
5762
5763    #[tokio::test]
5764    async fn test_fixpoint_state_exact_duplicates_excluded() {
5765        let schema = test_schema();
5766        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5767
5768        let batch1 = make_batch(&["a", "b"], &[1, 2]);
5769        state.merge_delta(vec![batch1], None).await.unwrap();
5770
5771        // Same rows again
5772        let batch2 = make_batch(&["a", "b"], &[1, 2]);
5773        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
5774        assert!(!changed);
5775        assert!(
5776            state.all_delta().is_empty() || state.all_delta().iter().all(|b| b.num_rows() == 0)
5777        );
5778    }
5779
5780    #[tokio::test]
5781    async fn test_fixpoint_state_partial_overlap() {
5782        let schema = test_schema();
5783        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5784
5785        let batch1 = make_batch(&["a", "b"], &[1, 2]);
5786        state.merge_delta(vec![batch1], None).await.unwrap();
5787
5788        // "a":1 is duplicate, "c":3 is new
5789        let batch2 = make_batch(&["a", "c"], &[1, 3]);
5790        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
5791        assert!(changed);
5792
5793        // Delta should have only "c":3
5794        let delta_rows: usize = state.all_delta().iter().map(|b| b.num_rows()).sum();
5795        assert_eq!(delta_rows, 1);
5796
5797        // Total facts: a:1, b:2, c:3
5798        let total_rows: usize = state.all_facts().iter().map(|b| b.num_rows()).sum();
5799        assert_eq!(total_rows, 3);
5800    }
5801
5802    #[tokio::test]
5803    async fn test_fixpoint_state_convergence() {
5804        let schema = test_schema();
5805        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5806
5807        let batch = make_batch(&["a"], &[1]);
5808        state.merge_delta(vec![batch], None).await.unwrap();
5809
5810        // Empty candidates → converged
5811        let changed = state.merge_delta(vec![], None).await.unwrap();
5812        assert!(!changed);
5813        assert!(state.is_converged());
5814    }
5815
5816    // --- RowDedupState tests ---
5817
5818    #[test]
5819    fn test_row_dedup_persistent_across_calls() {
5820        // RowDedupState should remember rows from the first call so the second
5821        // call does not re-accept them (O(M) per iteration, no facts re-scan).
5822        let schema = test_schema();
5823        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5824
5825        let batch1 = make_batch(&["a", "b"], &[1, 2]);
5826        let delta1 = rd.compute_delta(&[batch1], &schema).unwrap();
5827        // First call: both rows are new.
5828        let rows1: usize = delta1.iter().map(|b| b.num_rows()).sum();
5829        assert_eq!(rows1, 2);
5830
5831        // Second call with same rows: seen set already has them → empty delta.
5832        let batch2 = make_batch(&["a", "b"], &[1, 2]);
5833        let delta2 = rd.compute_delta(&[batch2], &schema).unwrap();
5834        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
5835        assert_eq!(rows2, 0);
5836
5837        // Third call with one old + one new: only the new row is returned.
5838        let batch3 = make_batch(&["a", "c"], &[1, 3]);
5839        let delta3 = rd.compute_delta(&[batch3], &schema).unwrap();
5840        let rows3: usize = delta3.iter().map(|b| b.num_rows()).sum();
5841        assert_eq!(rows3, 1);
5842    }
5843
5844    #[test]
5845    fn test_row_dedup_null_handling() {
5846        use arrow_array::StringArray;
5847        use arrow_schema::{DataType, Field, Schema};
5848
5849        let schema: SchemaRef = Arc::new(Schema::new(vec![
5850            Field::new("a", DataType::Utf8, true),
5851            Field::new("b", DataType::Int64, true),
5852        ]));
5853        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5854
5855        // Two rows: (NULL, 1) and (NULL, 1) — same NULLs → duplicate.
5856        let batch_nulls = RecordBatch::try_new(
5857            Arc::clone(&schema),
5858            vec![
5859                Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])),
5860                Arc::new(arrow_array::Int64Array::from(vec![1i64, 1i64])),
5861            ],
5862        )
5863        .unwrap();
5864        let delta = rd.compute_delta(&[batch_nulls], &schema).unwrap();
5865        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
5866        assert_eq!(rows, 1, "two identical NULL rows should be deduped to one");
5867
5868        // (NULL, 2) — NULL in same col but different non-null col → distinct.
5869        let batch_diff = RecordBatch::try_new(
5870            Arc::clone(&schema),
5871            vec![
5872                Arc::new(StringArray::from(vec![None::<&str>])),
5873                Arc::new(arrow_array::Int64Array::from(vec![2i64])),
5874            ],
5875        )
5876        .unwrap();
5877        let delta2 = rd.compute_delta(&[batch_diff], &schema).unwrap();
5878        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
5879        assert_eq!(rows2, 1, "(NULL, 2) is distinct from (NULL, 1)");
5880    }
5881
5882    #[test]
5883    fn test_row_dedup_within_candidate_dedup() {
5884        // Duplicate rows within a single candidate batch should be collapsed to one.
5885        let schema = test_schema();
5886        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5887
5888        // Batch with three rows: a:1, a:1, b:2 — "a:1" appears twice.
5889        let batch = make_batch(&["a", "a", "b"], &[1, 1, 2]);
5890        let delta = rd.compute_delta(&[batch], &schema).unwrap();
5891        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
5892        assert_eq!(rows, 2, "within-batch dup should be collapsed: a:1, b:2");
5893    }
5894
5895    // --- Float rounding tests ---
5896
5897    #[test]
5898    fn test_round_float_columns_near_duplicates() {
5899        let schema = Arc::new(Schema::new(vec![
5900            Field::new("name", DataType::Utf8, true),
5901            Field::new("dist", DataType::Float64, true),
5902        ]));
5903        let batch = RecordBatch::try_new(
5904            Arc::clone(&schema),
5905            vec![
5906                Arc::new(StringArray::from(vec![Some("a"), Some("a")])),
5907                Arc::new(Float64Array::from(vec![1.0000000000001, 1.0000000000002])),
5908            ],
5909        )
5910        .unwrap();
5911
5912        let rounded = round_float_columns(&[batch]);
5913        assert_eq!(rounded.len(), 1);
5914        let col = rounded[0]
5915            .column(1)
5916            .as_any()
5917            .downcast_ref::<Float64Array>()
5918            .unwrap();
5919        // Both should round to same value
5920        assert_eq!(col.value(0), col.value(1));
5921    }
5922
5923    // --- DerivedScanRegistry tests ---
5924
5925    #[test]
5926    fn test_registry_write_read_round_trip() {
5927        let schema = test_schema();
5928        let data = Arc::new(RwLock::new(Vec::new()));
5929        let mut reg = DerivedScanRegistry::new();
5930        reg.add(DerivedScanEntry {
5931            scan_index: 0,
5932            rule_name: "reachable".into(),
5933            is_self_ref: true,
5934            data: Arc::clone(&data),
5935            schema: Arc::clone(&schema),
5936        });
5937
5938        let batch = make_batch(&["x"], &[42]);
5939        reg.write_data(0, vec![batch.clone()]);
5940
5941        let entry = reg.get(0).unwrap();
5942        let guard = entry.data.read();
5943        assert_eq!(guard.len(), 1);
5944        assert_eq!(guard[0].num_rows(), 1);
5945    }
5946
5947    #[test]
5948    fn test_registry_entries_for_rule() {
5949        let schema = test_schema();
5950        let mut reg = DerivedScanRegistry::new();
5951        reg.add(DerivedScanEntry {
5952            scan_index: 0,
5953            rule_name: "r1".into(),
5954            is_self_ref: true,
5955            data: Arc::new(RwLock::new(Vec::new())),
5956            schema: Arc::clone(&schema),
5957        });
5958        reg.add(DerivedScanEntry {
5959            scan_index: 1,
5960            rule_name: "r2".into(),
5961            is_self_ref: false,
5962            data: Arc::new(RwLock::new(Vec::new())),
5963            schema: Arc::clone(&schema),
5964        });
5965        reg.add(DerivedScanEntry {
5966            scan_index: 2,
5967            rule_name: "r1".into(),
5968            is_self_ref: false,
5969            data: Arc::new(RwLock::new(Vec::new())),
5970            schema: Arc::clone(&schema),
5971        });
5972
5973        assert_eq!(reg.entries_for_rule("r1").len(), 2);
5974        assert_eq!(reg.entries_for_rule("r2").len(), 1);
5975        assert_eq!(reg.entries_for_rule("r3").len(), 0);
5976    }
5977
5978    // --- MonotonicAggState tests ---
5979
5980    #[test]
5981    fn test_monotonic_agg_update_and_stability() {
5982        let bindings = vec![MonotonicFoldBinding {
5983            fold_name: "total".into(),
5984            aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::SumAgg),
5985            input_col_index: 1,
5986            input_col_name: None,
5987        }];
5988        let mut agg = MonotonicAggState::new(bindings);
5989
5990        // First update
5991        let batch = make_batch(&["a"], &[10]);
5992        agg.snapshot();
5993        let changed = agg
5994            .update(&[0], &[batch], false, SemiringKind::AddMultProb)
5995            .unwrap();
5996        assert!(changed);
5997        assert!(!agg.is_stable()); // changed since snapshot
5998
5999        // Snapshot and check stability with no new data
6000        agg.snapshot();
6001        let changed = agg
6002            .update(&[0], &[], false, SemiringKind::AddMultProb)
6003            .unwrap();
6004        assert!(!changed);
6005        assert!(agg.is_stable());
6006    }
6007
6008    // --- Memory limit test ---
6009
6010    #[tokio::test]
6011    async fn test_memory_limit_exceeded() {
6012        let schema = test_schema();
6013        // Set a tiny limit
6014        let mut state = FixpointState::new("test".into(), schema, vec![0], 1, None, false);
6015
6016        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
6017        let result = state.merge_delta(vec![batch], None).await;
6018        assert!(result.is_err());
6019        let err = result.unwrap_err().to_string();
6020        assert!(err.contains("memory limit"), "Error was: {}", err);
6021    }
6022
6023    // --- FixpointStream lifecycle test ---
6024
6025    #[tokio::test]
6026    async fn test_fixpoint_stream_emitting() {
6027        use futures::StreamExt;
6028
6029        let schema = test_schema();
6030        let batch1 = make_batch(&["a"], &[1]);
6031        let batch2 = make_batch(&["b"], &[2]);
6032
6033        let metrics = ExecutionPlanMetricsSet::new();
6034        let baseline = BaselineMetrics::new(&metrics, 0);
6035
6036        let mut stream = FixpointStream {
6037            state: FixpointStreamState::Emitting(vec![batch1, batch2], 0),
6038            schema,
6039            metrics: baseline,
6040        };
6041
6042        let stream = Pin::new(&mut stream);
6043        let batches: Vec<RecordBatch> = stream.filter_map(|r| async { r.ok() }).collect().await;
6044
6045        assert_eq!(batches.len(), 2);
6046        assert_eq!(batches[0].num_rows(), 1);
6047        assert_eq!(batches[1].num_rows(), 1);
6048    }
6049
6050    // ── MonotonicAggState MNOR/MPROD tests ──────────────────────────────
6051
6052    fn make_f64_batch(names: &[&str], values: &[f64]) -> RecordBatch {
6053        let schema = Arc::new(Schema::new(vec![
6054            Field::new("name", DataType::Utf8, true),
6055            Field::new("value", DataType::Float64, true),
6056        ]));
6057        RecordBatch::try_new(
6058            schema,
6059            vec![
6060                Arc::new(StringArray::from(
6061                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
6062                )),
6063                Arc::new(Float64Array::from(values.to_vec())),
6064            ],
6065        )
6066        .unwrap()
6067    }
6068
6069    fn make_nor_binding() -> Vec<MonotonicFoldBinding> {
6070        vec![MonotonicFoldBinding {
6071            fold_name: "prob".into(),
6072            aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::MnorAgg),
6073            input_col_index: 1,
6074            input_col_name: None,
6075        }]
6076    }
6077
6078    fn make_prod_binding() -> Vec<MonotonicFoldBinding> {
6079        vec![MonotonicFoldBinding {
6080            fold_name: "prob".into(),
6081            aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::MprodAgg),
6082            input_col_index: 1,
6083            input_col_name: None,
6084        }]
6085    }
6086
6087    fn acc_key(name: &str) -> (Vec<ScalarKey>, String) {
6088        (vec![ScalarKey::Utf8(name.to_string())], "prob".to_string())
6089    }
6090
6091    #[test]
6092    fn test_monotonic_nor_first_update() {
6093        let mut agg = MonotonicAggState::new(make_nor_binding());
6094        let batch = make_f64_batch(&["a"], &[0.3]);
6095        let changed = agg
6096            .update(&[0], &[batch], false, SemiringKind::AddMultProb)
6097            .unwrap();
6098        assert!(changed);
6099        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6100        assert!((val - 0.3).abs() < 1e-10, "expected 0.3, got {}", val);
6101    }
6102
6103    #[test]
6104    fn test_monotonic_nor_two_updates() {
6105        // Incremental NOR: acc = 1-(1-0.3)(1-0.5) = 0.65
6106        let mut agg = MonotonicAggState::new(make_nor_binding());
6107        let batch1 = make_f64_batch(&["a"], &[0.3]);
6108        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6109            .unwrap();
6110        let batch2 = make_f64_batch(&["a"], &[0.5]);
6111        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6112            .unwrap();
6113        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6114        assert!((val - 0.65).abs() < 1e-10, "expected 0.65, got {}", val);
6115    }
6116
6117    #[test]
6118    fn test_monotonic_prod_first_update() {
6119        let mut agg = MonotonicAggState::new(make_prod_binding());
6120        let batch = make_f64_batch(&["a"], &[0.6]);
6121        let changed = agg
6122            .update(&[0], &[batch], false, SemiringKind::AddMultProb)
6123            .unwrap();
6124        assert!(changed);
6125        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6126        assert!((val - 0.6).abs() < 1e-10, "expected 0.6, got {}", val);
6127    }
6128
6129    #[test]
6130    fn test_monotonic_prod_two_updates() {
6131        // Incremental PROD: acc = 0.6 * 0.8 = 0.48
6132        let mut agg = MonotonicAggState::new(make_prod_binding());
6133        let batch1 = make_f64_batch(&["a"], &[0.6]);
6134        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6135            .unwrap();
6136        let batch2 = make_f64_batch(&["a"], &[0.8]);
6137        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6138            .unwrap();
6139        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6140        assert!((val - 0.48).abs() < 1e-10, "expected 0.48, got {}", val);
6141    }
6142
6143    #[test]
6144    fn test_monotonic_nor_stability() {
6145        let mut agg = MonotonicAggState::new(make_nor_binding());
6146        let batch = make_f64_batch(&["a"], &[0.3]);
6147        agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6148            .unwrap();
6149        agg.snapshot();
6150        let changed = agg
6151            .update(&[0], &[], false, SemiringKind::AddMultProb)
6152            .unwrap();
6153        assert!(!changed);
6154        assert!(agg.is_stable());
6155    }
6156
6157    #[test]
6158    fn test_monotonic_prod_stability() {
6159        let mut agg = MonotonicAggState::new(make_prod_binding());
6160        let batch = make_f64_batch(&["a"], &[0.6]);
6161        agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6162            .unwrap();
6163        agg.snapshot();
6164        let changed = agg
6165            .update(&[0], &[], false, SemiringKind::AddMultProb)
6166            .unwrap();
6167        assert!(!changed);
6168        assert!(agg.is_stable());
6169    }
6170
6171    #[test]
6172    fn test_monotonic_nor_multi_group() {
6173        // (a,0.3),(b,0.5) then (a,0.5),(b,0.2) → a=0.65, b=0.6
6174        let mut agg = MonotonicAggState::new(make_nor_binding());
6175        let batch1 = make_f64_batch(&["a", "b"], &[0.3, 0.5]);
6176        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6177            .unwrap();
6178        let batch2 = make_f64_batch(&["a", "b"], &[0.5, 0.2]);
6179        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6180            .unwrap();
6181
6182        let val_a = agg.get_accumulator(&acc_key("a")).unwrap();
6183        let val_b = agg.get_accumulator(&acc_key("b")).unwrap();
6184        assert!(
6185            (val_a - 0.65).abs() < 1e-10,
6186            "expected a=0.65, got {}",
6187            val_a
6188        );
6189        assert!((val_b - 0.6).abs() < 1e-10, "expected b=0.6, got {}", val_b);
6190    }
6191
6192    #[test]
6193    fn test_monotonic_prod_zero_absorbing() {
6194        // Zero absorbs: once 0.0, all further updates stay 0.0
6195        let mut agg = MonotonicAggState::new(make_prod_binding());
6196        let batch1 = make_f64_batch(&["a"], &[0.5]);
6197        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6198            .unwrap();
6199        let batch2 = make_f64_batch(&["a"], &[0.0]);
6200        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6201            .unwrap();
6202
6203        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6204        assert!((val - 0.0).abs() < 1e-10, "expected 0.0, got {}", val);
6205
6206        // Further updates don't change the absorbing zero
6207        agg.snapshot();
6208        let batch3 = make_f64_batch(&["a"], &[0.5]);
6209        let changed = agg
6210            .update(&[0], &[batch3], false, SemiringKind::AddMultProb)
6211            .unwrap();
6212        assert!(!changed);
6213        assert!(agg.is_stable());
6214    }
6215
6216    #[test]
6217    fn test_monotonic_nor_clamping() {
6218        // 1.5 clamped to 1.0: acc = 1-(1-0)(1-1) = 1.0
6219        let mut agg = MonotonicAggState::new(make_nor_binding());
6220        let batch = make_f64_batch(&["a"], &[1.5]);
6221        agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6222            .unwrap();
6223        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6224        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
6225    }
6226
6227    #[test]
6228    fn test_monotonic_nor_absorbing() {
6229        // p=1.0 absorbs: 0.3 then 1.0 → 1.0
6230        let mut agg = MonotonicAggState::new(make_nor_binding());
6231        let batch1 = make_f64_batch(&["a"], &[0.3]);
6232        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6233            .unwrap();
6234        let batch2 = make_f64_batch(&["a"], &[1.0]);
6235        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6236            .unwrap();
6237        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6238        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
6239    }
6240
6241    // ── MonotonicAggState strict mode tests (Phase 5) ─────────────────────
6242
6243    #[test]
6244    fn test_monotonic_agg_strict_nor_rejects() {
6245        let mut agg = MonotonicAggState::new(make_nor_binding());
6246        let batch = make_f64_batch(&["a"], &[1.5]);
6247        let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6248        assert!(result.is_err());
6249        let err = result.unwrap_err().to_string();
6250        assert!(
6251            err.contains("strict_probability_domain"),
6252            "Expected strict error, got: {}",
6253            err
6254        );
6255    }
6256
6257    #[test]
6258    fn test_monotonic_agg_strict_prod_rejects() {
6259        let mut agg = MonotonicAggState::new(make_prod_binding());
6260        let batch = make_f64_batch(&["a"], &[2.0]);
6261        let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6262        assert!(result.is_err());
6263        let err = result.unwrap_err().to_string();
6264        assert!(
6265            err.contains("strict_probability_domain"),
6266            "Expected strict error, got: {}",
6267            err
6268        );
6269    }
6270
6271    #[test]
6272    fn test_monotonic_agg_strict_accepts_valid() {
6273        let mut agg = MonotonicAggState::new(make_nor_binding());
6274        let batch = make_f64_batch(&["a"], &[0.5]);
6275        let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6276        assert!(result.is_ok());
6277        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6278        assert!((val - 0.5).abs() < 1e-10, "expected 0.5, got {}", val);
6279    }
6280
6281    // ── Complement function unit tests (Phase 4) ──────────────────────────
6282
6283    fn make_vid_prob_batch(vids: &[u64], probs: &[f64]) -> RecordBatch {
6284        use arrow_array::UInt64Array;
6285        let schema = Arc::new(Schema::new(vec![
6286            Field::new("vid", DataType::UInt64, true),
6287            Field::new("prob", DataType::Float64, true),
6288        ]));
6289        RecordBatch::try_new(
6290            schema,
6291            vec![
6292                Arc::new(UInt64Array::from(vids.to_vec())),
6293                Arc::new(Float64Array::from(probs.to_vec())),
6294            ],
6295        )
6296        .unwrap()
6297    }
6298
6299    #[test]
6300    fn test_prob_complement_basic() {
6301        // neg has VID=1 with prob=0.7 → complement=0.3; VID=2 absent → complement=1.0
6302        let body = make_vid_prob_batch(&[1, 2], &[0.9, 0.8]);
6303        let neg = make_vid_prob_batch(&[1], &[0.7]);
6304        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6305        let result = apply_prob_complement_composite(
6306            vec![body],
6307            &[neg],
6308            &join_cols,
6309            "prob",
6310            "__complement_0",
6311        )
6312        .unwrap();
6313        assert_eq!(result.len(), 1);
6314        let batch = &result[0];
6315        let complement = batch
6316            .column_by_name("__complement_0")
6317            .unwrap()
6318            .as_any()
6319            .downcast_ref::<Float64Array>()
6320            .unwrap();
6321        // VID=1: complement = 1 - 0.7 = 0.3
6322        assert!(
6323            (complement.value(0) - 0.3).abs() < 1e-10,
6324            "expected 0.3, got {}",
6325            complement.value(0)
6326        );
6327        // VID=2: absent from neg → complement = 1.0
6328        assert!(
6329            (complement.value(1) - 1.0).abs() < 1e-10,
6330            "expected 1.0, got {}",
6331            complement.value(1)
6332        );
6333    }
6334
6335    #[test]
6336    fn test_prob_complement_noisy_or_duplicates() {
6337        // neg has VID=1 twice with prob=0.3 and prob=0.5
6338        // Combined via noisy-OR: 1-(1-0.3)(1-0.5) = 0.65
6339        // Complement = 1 - 0.65 = 0.35
6340        let body = make_vid_prob_batch(&[1], &[0.9]);
6341        let neg = make_vid_prob_batch(&[1, 1], &[0.3, 0.5]);
6342        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6343        let result = apply_prob_complement_composite(
6344            vec![body],
6345            &[neg],
6346            &join_cols,
6347            "prob",
6348            "__complement_0",
6349        )
6350        .unwrap();
6351        let batch = &result[0];
6352        let complement = batch
6353            .column_by_name("__complement_0")
6354            .unwrap()
6355            .as_any()
6356            .downcast_ref::<Float64Array>()
6357            .unwrap();
6358        assert!(
6359            (complement.value(0) - 0.35).abs() < 1e-10,
6360            "expected 0.35, got {}",
6361            complement.value(0)
6362        );
6363    }
6364
6365    #[test]
6366    fn test_prob_complement_empty_neg() {
6367        // Empty neg_facts → body passes through with complement=1.0
6368        let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
6369        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6370        let result =
6371            apply_prob_complement_composite(vec![body], &[], &join_cols, "prob", "__complement_0")
6372                .unwrap();
6373        let batch = &result[0];
6374        let complement = batch
6375            .column_by_name("__complement_0")
6376            .unwrap()
6377            .as_any()
6378            .downcast_ref::<Float64Array>()
6379            .unwrap();
6380        for i in 0..2 {
6381            assert!(
6382                (complement.value(i) - 1.0).abs() < 1e-10,
6383                "row {}: expected 1.0, got {}",
6384                i,
6385                complement.value(i)
6386            );
6387        }
6388    }
6389
6390    #[test]
6391    fn test_anti_join_basic() {
6392        // body [1,2,3], neg [2] → result [1,3]
6393        use arrow_array::UInt64Array;
6394        let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
6395        let neg = make_vid_prob_batch(&[2], &[0.0]);
6396        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6397        let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
6398        assert_eq!(result.len(), 1);
6399        let batch = &result[0];
6400        assert_eq!(batch.num_rows(), 2);
6401        let vids = batch
6402            .column_by_name("vid")
6403            .unwrap()
6404            .as_any()
6405            .downcast_ref::<UInt64Array>()
6406            .unwrap();
6407        assert_eq!(vids.value(0), 1);
6408        assert_eq!(vids.value(1), 3);
6409    }
6410
6411    #[test]
6412    fn test_anti_join_empty_neg() {
6413        // Empty neg → all rows kept
6414        let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
6415        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6416        let result = apply_anti_join_composite(vec![body], &[], &join_cols).unwrap();
6417        assert_eq!(result.len(), 1);
6418        assert_eq!(result[0].num_rows(), 3);
6419    }
6420
6421    #[test]
6422    fn test_anti_join_all_excluded() {
6423        // neg covers all body rows → empty result
6424        let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
6425        let neg = make_vid_prob_batch(&[1, 2], &[0.0, 0.0]);
6426        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6427        let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
6428        let total: usize = result.iter().map(|b| b.num_rows()).sum();
6429        assert_eq!(total, 0);
6430    }
6431
6432    #[test]
6433    fn test_multiply_prob_single_complement() {
6434        // prob=0.8, complement=0.5 → output prob=0.4; complement col removed
6435        let body = make_vid_prob_batch(&[1], &[0.8]);
6436        // Add a complement column
6437        let complement_arr = Float64Array::from(vec![0.5]);
6438        let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
6439        cols.push(Arc::new(complement_arr));
6440        let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
6441        fields.push(Arc::new(Field::new(
6442            "__complement_0",
6443            DataType::Float64,
6444            true,
6445        )));
6446        let schema = Arc::new(Schema::new(fields));
6447        let batch = RecordBatch::try_new(schema, cols).unwrap();
6448
6449        let result =
6450            multiply_prob_factors(vec![batch], Some("prob"), &["__complement_0".to_string()])
6451                .unwrap();
6452        assert_eq!(result.len(), 1);
6453        let out = &result[0];
6454        // Complement column should be removed
6455        assert!(out.column_by_name("__complement_0").is_none());
6456        let prob = out
6457            .column_by_name("prob")
6458            .unwrap()
6459            .as_any()
6460            .downcast_ref::<Float64Array>()
6461            .unwrap();
6462        assert!(
6463            (prob.value(0) - 0.4).abs() < 1e-10,
6464            "expected 0.4, got {}",
6465            prob.value(0)
6466        );
6467    }
6468
6469    #[test]
6470    fn test_multiply_prob_multiple_complements() {
6471        // prob=0.8, c1=0.5, c2=0.6 → 0.8×0.5×0.6=0.24
6472        let body = make_vid_prob_batch(&[1], &[0.8]);
6473        let c1 = Float64Array::from(vec![0.5]);
6474        let c2 = Float64Array::from(vec![0.6]);
6475        let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
6476        cols.push(Arc::new(c1));
6477        cols.push(Arc::new(c2));
6478        let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
6479        fields.push(Arc::new(Field::new("__c1", DataType::Float64, true)));
6480        fields.push(Arc::new(Field::new("__c2", DataType::Float64, true)));
6481        let schema = Arc::new(Schema::new(fields));
6482        let batch = RecordBatch::try_new(schema, cols).unwrap();
6483
6484        let result = multiply_prob_factors(
6485            vec![batch],
6486            Some("prob"),
6487            &["__c1".to_string(), "__c2".to_string()],
6488        )
6489        .unwrap();
6490        let out = &result[0];
6491        assert!(out.column_by_name("__c1").is_none());
6492        assert!(out.column_by_name("__c2").is_none());
6493        let prob = out
6494            .column_by_name("prob")
6495            .unwrap()
6496            .as_any()
6497            .downcast_ref::<Float64Array>()
6498            .unwrap();
6499        assert!(
6500            (prob.value(0) - 0.24).abs() < 1e-10,
6501            "expected 0.24, got {}",
6502            prob.value(0)
6503        );
6504    }
6505
6506    #[test]
6507    fn test_multiply_prob_no_prob_column() {
6508        // No prob column → combined complements become the output
6509        use arrow_array::UInt64Array;
6510        let schema = Arc::new(Schema::new(vec![
6511            Field::new("vid", DataType::UInt64, true),
6512            Field::new("__c1", DataType::Float64, true),
6513        ]));
6514        let batch = RecordBatch::try_new(
6515            schema,
6516            vec![
6517                Arc::new(UInt64Array::from(vec![1u64])),
6518                Arc::new(Float64Array::from(vec![0.7])),
6519            ],
6520        )
6521        .unwrap();
6522
6523        let result = multiply_prob_factors(vec![batch], None, &["__c1".to_string()]).unwrap();
6524        let out = &result[0];
6525        // __c1 should be removed since it's a complement column
6526        assert!(out.column_by_name("__c1").is_none());
6527        // Only vid column remains
6528        assert_eq!(out.num_columns(), 1);
6529    }
6530}