Skip to main content

uni_query/query/df_graph/
locy_fixpoint.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Fixpoint iteration operator for recursive Locy strata.
5//!
6//! `FixpointExec` drives semi-naive evaluation: it repeatedly evaluates the rules
7//! in a recursive stratum, feeding back deltas until no new facts are produced.
8
9use crate::query::df_graph::GraphExecutionContext;
10use crate::query::df_graph::common::{
11    ScalarKey, arrow_err, collect_all_partitions, compute_plan_properties, execute_subplan,
12    execute_subplan_collecting, extract_scalar_key,
13};
14use crate::query::df_graph::locy_best_by::{BestByExec, SortCriterion};
15use crate::query::df_graph::locy_errors::LocyRuntimeError;
16use crate::query::df_graph::locy_explain::{
17    ProofTerm, ProvenanceAnnotation, ProvenanceStore, compute_proof_probability,
18};
19use crate::query::df_graph::locy_fold::{FoldBinding, FoldExec};
20use crate::query::df_graph::locy_priority::PriorityExec;
21use crate::query::df_graph::locy_profile::LocyProfileCollector;
22use crate::query::df_graph::locy_program::interruption;
23use crate::query::executor::core::OperatorStats;
24use crate::query::planner::LogicalPlan;
25use arrow_array::RecordBatch;
26use arrow_row::{RowConverter, SortField};
27use arrow_schema::SchemaRef;
28use datafusion::common::JoinType;
29use datafusion::common::Result as DFResult;
30use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
31use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
32use datafusion::physical_plan::memory::MemoryStream;
33use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
34use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
35use futures::Stream;
36use parking_lot::RwLock;
37use std::any::Any;
38use std::collections::{HashMap, HashSet};
39use std::fmt;
40use std::pin::Pin;
41use std::sync::{Arc, RwLock as StdRwLock};
42use std::task::{Context, Poll};
43use std::time::{Duration, Instant};
44use uni_common::Value;
45use uni_common::core::schema::Schema as UniSchema;
46use uni_cypher::ast::Expr;
47use uni_locy::{
48    ClassifierRegistry, ModelInvocation, ModelInvocationCache, RuntimeWarning, RuntimeWarningCode,
49    SemiringKind,
50};
51use uni_store::storage::manager::StorageManager;
52
53// ---------------------------------------------------------------------------
54// DerivedScanRegistry — injection point for IS-ref data into subplans
55// ---------------------------------------------------------------------------
56
57/// A single entry in the derived scan registry.
58///
59/// Each entry corresponds to one `LocyDerivedScan` node in the logical plan tree.
60/// The `data` handle is shared with the logical plan node so that writing data here
61/// makes it visible when the subplan is re-planned and executed.
62#[derive(Debug)]
63pub struct DerivedScanEntry {
64    /// Index matching the `scan_index` in `LocyDerivedScan`.
65    pub scan_index: usize,
66    /// Name of the rule this scan reads from.
67    pub rule_name: String,
68    /// Whether this is a self-referential scan (rule references itself).
69    pub is_self_ref: bool,
70    /// Shared data handle — write batches here to inject into subplans.
71    pub data: Arc<RwLock<Vec<RecordBatch>>>,
72    /// Schema of the derived relation.
73    pub schema: SchemaRef,
74}
75
76/// Registry of derived scan handles for fixpoint iteration.
77///
78/// During fixpoint, each clause body may reference derived relations via
79/// `LocyDerivedScan` nodes. The registry maps scan indices to shared data
80/// handles so the fixpoint loop can inject delta/full facts before each
81/// iteration.
82#[derive(Debug, Default)]
83pub struct DerivedScanRegistry {
84    entries: Vec<DerivedScanEntry>,
85}
86
87impl DerivedScanRegistry {
88    /// Create a new empty registry.
89    pub fn new() -> Self {
90        Self::default()
91    }
92
93    /// Add an entry to the registry.
94    pub fn add(&mut self, entry: DerivedScanEntry) {
95        self.entries.push(entry);
96    }
97
98    /// Get an entry by scan index.
99    pub fn get(&self, scan_index: usize) -> Option<&DerivedScanEntry> {
100        self.entries.iter().find(|e| e.scan_index == scan_index)
101    }
102
103    /// Write data into a scan entry's shared handle.
104    pub fn write_data(&self, scan_index: usize, batches: Vec<RecordBatch>) {
105        if let Some(entry) = self.get(scan_index) {
106            let mut guard = entry.data.write();
107            *guard = batches;
108        }
109    }
110
111    /// Get all entries for a given rule name.
112    pub fn entries_for_rule(&self, rule_name: &str) -> Vec<&DerivedScanEntry> {
113        self.entries
114            .iter()
115            .filter(|e| e.rule_name == rule_name)
116            .collect()
117    }
118}
119
120// ---------------------------------------------------------------------------
121// MonotonicAggState — tracking monotonic aggregates across iterations
122// ---------------------------------------------------------------------------
123
124/// Monotonic aggregate binding: maps a fold name to its aggregate
125/// trait object and input column.
126///
127/// Dispatches purely through [`uni_plugin::traits::locy::LocyAggregate`]
128/// (`update_step` / `initial_accum_f64` / `is_probability_aggregate` /
129/// `is_noisy_or`).
130#[derive(Debug, Clone)]
131pub struct MonotonicFoldBinding {
132    pub fold_name: String,
133    pub aggregate: std::sync::Arc<dyn uni_plugin::traits::locy::LocyAggregate>,
134    pub input_col_index: usize,
135    /// Column name for name-based resolution (more robust than positional index).
136    pub input_col_name: Option<String>,
137}
138
139/// Tracks monotonic aggregate accumulators across fixpoint iterations.
140///
141/// After each iteration, accumulators are updated and compared to their previous
142/// snapshot. The fixpoint has converged (w.r.t. aggregates) when all accumulators
143/// are stable (no change between iterations).
144#[derive(Debug)]
145pub struct MonotonicAggState {
146    /// Current accumulator values keyed by (group_key, fold_name).
147    accumulators: HashMap<(Vec<ScalarKey>, String), f64>,
148    /// Snapshot from the previous iteration for stability check.
149    prev_snapshot: HashMap<(Vec<ScalarKey>, String), f64>,
150    /// Bindings describing which aggregates to track.
151    bindings: Vec<MonotonicFoldBinding>,
152}
153
154impl MonotonicAggState {
155    /// Create a new monotonic aggregate state.
156    pub fn new(bindings: Vec<MonotonicFoldBinding>) -> Self {
157        Self {
158            accumulators: HashMap::new(),
159            prev_snapshot: HashMap::new(),
160            bindings,
161        }
162    }
163
164    /// Update accumulators with new delta batches.
165    ///
166    /// Returns `true` if any accumulator value changed. When `strict` is
167    /// `true`, MNOR/MPROD inputs outside `[0, 1]` produce an error
168    /// instead of being clamped.
169    ///
170    /// `semiring_kind` selects the probability semiring for probability
171    /// aggregates: `AddMultProb` (default, Phase 1/2 noisy-OR/product) or
172    /// `MaxMinProb` (Viterbi/fuzzy — opt-in, callers emit
173    /// `FuzzyNotProbabilistic`).
174    ///
175    /// Dispatch goes through each binding's `Arc<dyn LocyAggregate>` trait
176    /// object via [`uni_plugin::traits::locy::LocyAggregate::update_step`].
177    /// The trait object's `initial_accum_f64()` seeds the per-group
178    /// accumulator. Under `MaxMinProb`, probability aggregates (MNOR /
179    /// MPROD) bypass `update_step` and fold via the `MaxMinProb` semiring's
180    /// `plus` (max) / `times` (min) instead — preserving the opt-in
181    /// Viterbi/fuzzy semantics that `update_step`'s built-in noisy-OR /
182    /// product path does not implement.
183    ///
184    /// Aggregates whose `update_step` returns `Err(CODE_UNKNOWN_FUNCTION)`
185    /// (default impl — no row-level fast path; e.g., `AVG`, `COLLECT`)
186    /// are skipped silently here — those run through the batch-shape
187    /// [`uni_plugin::traits::locy::LocyAggState::ingest`] path in
188    /// `apply_post_fixpoint_chain` instead.
189    pub fn update(
190        &mut self,
191        key_indices: &[usize],
192        delta_batches: &[RecordBatch],
193        strict: bool,
194        semiring_kind: SemiringKind,
195    ) -> DFResult<bool> {
196        let mut changed = false;
197        for batch in delta_batches {
198            for row_idx in 0..batch.num_rows() {
199                let group_key = extract_scalar_key(batch, key_indices, row_idx);
200                for binding in &self.bindings {
201                    let idx = binding
202                        .input_col_name
203                        .as_ref()
204                        .and_then(|name| batch.schema().index_of(name).ok())
205                        .unwrap_or(binding.input_col_index);
206                    if idx >= batch.num_columns() {
207                        continue;
208                    }
209                    let col = batch.column(idx);
210                    let val = extract_f64(col.as_ref(), row_idx);
211                    if let Some(val) = val {
212                        let map_key = (group_key.clone(), binding.fold_name.clone());
213                        let initial = binding.aggregate.initial_accum_f64().unwrap_or(0.0);
214                        let entry = self.accumulators.entry(map_key).or_insert(initial);
215                        let old = *entry;
216                        // Under `MaxMinProb`, probability aggregates (MNOR /
217                        // MPROD) fold via the Viterbi/fuzzy semiring (max /
218                        // min) rather than the trait object's built-in
219                        // noisy-OR / product `update_step`. The inline domain
220                        // checks below preserve the exact strict-mode error
221                        // and clamp-warning literals. `is_noisy_or()`
222                        // distinguishes MNOR (disjunction → max) from MPROD
223                        // (conjunction → min). All other aggregates — and the
224                        // default `AddMultProb` semiring — dispatch through
225                        // the trait object's `update_step`.
226                        if matches!(semiring_kind, SemiringKind::MaxMinProb)
227                            && binding.aggregate.is_probability_aggregate()
228                        {
229                            use uni_locy::LocySemiring;
230                            let sr = uni_locy::MaxMinProb;
231                            let is_nor = binding.aggregate.is_noisy_or();
232                            let label = if is_nor { "MNOR" } else { "MPROD" };
233                            if strict && !(0.0..=1.0).contains(&val) {
234                                return Err(datafusion::error::DataFusionError::Execution(
235                                    format!(
236                                        "strict_probability_domain: {label} input {val} is outside [0, 1]"
237                                    ),
238                                ));
239                            }
240                            if !strict && !(0.0..=1.0).contains(&val) {
241                                tracing::warn!(
242                                    "{label} input {val} outside [0,1], clamped to {}",
243                                    val.clamp(0.0, 1.0)
244                                );
245                            }
246                            let p = val.clamp(0.0, 1.0);
247                            // MaxMinProb: MNOR -> max (plus), MPROD -> min (times).
248                            *entry = if is_nor {
249                                sr.plus(entry, &p)
250                            } else {
251                                sr.times(entry, &p)
252                            };
253                            if (*entry - old).abs() > f64::EPSILON {
254                                changed = true;
255                            }
256                            continue;
257                        }
258                        match binding.aggregate.update_step(*entry, val, strict) {
259                            Ok(new_val) => {
260                                *entry = new_val;
261                                if (*entry - old).abs() > f64::EPSILON {
262                                    changed = true;
263                                }
264                            }
265                            Err(e) if e.code == uni_plugin::FnError::CODE_UNKNOWN_FUNCTION => {
266                                // Aggregate has no row-level fast path (AVG,
267                                // COLLECT). Those run through the
268                                // batch-shape `ingest` path elsewhere; skip.
269                            }
270                            Err(e) => {
271                                // Strict-mode probability-domain violation,
272                                // or another aggregate-specific failure.
273                                return Err(datafusion::error::DataFusionError::Execution(
274                                    e.message,
275                                ));
276                            }
277                        }
278                    }
279                }
280            }
281        }
282        Ok(changed)
283    }
284
285    /// Take a snapshot of current accumulators for stability comparison.
286    pub fn snapshot(&mut self) {
287        self.prev_snapshot = self.accumulators.clone();
288    }
289
290    /// Check if accumulators are stable (no change since last snapshot).
291    pub fn is_stable(&self) -> bool {
292        if self.accumulators.len() != self.prev_snapshot.len() {
293            return false;
294        }
295        for (key, val) in &self.accumulators {
296            match self.prev_snapshot.get(key) {
297                Some(prev) if (*val - *prev).abs() <= f64::EPSILON => {}
298                _ => return false,
299            }
300        }
301        true
302    }
303
304    /// Test-only accessor for accumulator values.
305    #[cfg(test)]
306    pub(crate) fn get_accumulator(&self, key: &(Vec<ScalarKey>, String)) -> Option<f64> {
307        self.accumulators.get(key).copied()
308    }
309}
310
311/// Extract f64 value from an Arrow column at a given row index.
312fn extract_f64(col: &dyn arrow_array::Array, row_idx: usize) -> Option<f64> {
313    if col.is_null(row_idx) {
314        return None;
315    }
316    if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Float64Array>() {
317        Some(arr.value(row_idx))
318    } else {
319        col.as_any()
320            .downcast_ref::<arrow_array::Int64Array>()
321            .map(|arr| arr.value(row_idx) as f64)
322    }
323}
324
325// ---------------------------------------------------------------------------
326// RowDedupState — Arrow RowConverter-based persistent dedup set
327// ---------------------------------------------------------------------------
328
329/// Arrow-native row deduplication using [`RowConverter`].
330///
331/// Unlike the legacy `HashSet<Vec<ScalarKey>>` approach, this struct maintains a
332/// persistent `seen` set across iterations so per-iteration cost is O(M) where M
333/// is the number of candidate rows — the full facts table is never re-scanned.
334struct RowDedupState {
335    converter: RowConverter,
336    seen: HashSet<Box<[u8]>>,
337}
338
339impl RowDedupState {
340    /// Try to build a `RowDedupState` for the given schema.
341    ///
342    /// Returns `None` if any column type is not supported by `RowConverter`
343    /// (triggers legacy fallback).
344    fn try_new(schema: &SchemaRef) -> Option<Self> {
345        let fields: Vec<SortField> = schema
346            .fields()
347            .iter()
348            .map(|f| SortField::new(f.data_type().clone()))
349            .collect();
350        match RowConverter::new(fields) {
351            Ok(converter) => Some(Self {
352                converter,
353                seen: HashSet::new(),
354            }),
355            Err(e) => {
356                tracing::warn!(
357                    "RowDedupState: RowConverter unsupported for schema, falling back to legacy dedup: {}",
358                    e
359                );
360                None
361            }
362        }
363    }
364
365    /// Populate the seen set from existing fact batches.
366    ///
367    /// Used after BEST BY in-loop pruning replaces the fact set, so that delta
368    /// computation in subsequent iterations correctly recognizes surviving facts.
369    fn ingest_existing(&mut self, facts: &[RecordBatch], _schema: &SchemaRef) {
370        self.seen.clear();
371        for batch in facts {
372            if batch.num_rows() == 0 {
373                continue;
374            }
375            let arrays: Vec<_> = batch.columns().to_vec();
376            if let Ok(rows) = self.converter.convert_columns(&arrays) {
377                for row_idx in 0..batch.num_rows() {
378                    let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
379                    self.seen.insert(row_bytes);
380                }
381            }
382        }
383    }
384
385    /// Filter `candidates` to only rows not yet seen, updating the persistent set.
386    ///
387    /// Both cross-iteration dedup (rows already accepted in prior iterations) and
388    /// within-batch dedup (duplicate rows in a single candidate batch) are handled
389    /// in a single pass.
390    fn compute_delta(
391        &mut self,
392        candidates: &[RecordBatch],
393        schema: &SchemaRef,
394    ) -> DFResult<Vec<RecordBatch>> {
395        let mut delta_batches = Vec::new();
396        for batch in candidates {
397            if batch.num_rows() == 0 {
398                continue;
399            }
400
401            // Vectorized encoding of all rows in this batch.
402            let arrays: Vec<_> = batch.columns().to_vec();
403            let rows = self.converter.convert_columns(&arrays).map_err(arrow_err)?;
404
405            // One pass: check+insert into persistent seen set.
406            let mut keep = Vec::with_capacity(batch.num_rows());
407            for row_idx in 0..batch.num_rows() {
408                let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
409                keep.push(self.seen.insert(row_bytes));
410            }
411
412            let keep_mask = arrow_array::BooleanArray::from(keep);
413            let new_cols = batch
414                .columns()
415                .iter()
416                .map(|col| {
417                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
418                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
419                    })
420                })
421                .collect::<DFResult<Vec<_>>>()?;
422
423            if new_cols.first().is_some_and(|c| !c.is_empty()) {
424                let filtered = RecordBatch::try_new(Arc::clone(schema), new_cols).map_err(|e| {
425                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
426                })?;
427                delta_batches.push(filtered);
428            }
429        }
430        Ok(delta_batches)
431    }
432}
433
434// ---------------------------------------------------------------------------
435// FixpointState — per-rule delta tracking during fixpoint iteration
436// ---------------------------------------------------------------------------
437
438/// Per-rule state for fixpoint iteration.
439///
440/// Tracks accumulated facts and the delta (new facts from the latest iteration).
441/// Deduplication uses Arrow [`RowConverter`] with a persistent seen set (O(M) per
442/// iteration) when supported, with a legacy `HashSet<Vec<ScalarKey>>` fallback.
443pub struct FixpointState {
444    rule_name: String,
445    facts: Vec<RecordBatch>,
446    delta: Vec<RecordBatch>,
447    schema: SchemaRef,
448    key_column_indices: Vec<usize>,
449    /// KEY column names for recomputing indices after schema reconciliation.
450    key_column_names: Vec<String>,
451    /// All column indices for full-row dedup (legacy path only).
452    all_column_indices: Vec<usize>,
453    /// Running total of facts bytes for memory limit tracking.
454    facts_bytes: usize,
455    /// Maximum bytes allowed for this derived relation.
456    max_derived_bytes: usize,
457    /// Optional monotonic aggregate tracking.
458    monotonic_agg: Option<MonotonicAggState>,
459    /// Arrow RowConverter-based dedup state; `None` triggers legacy fallback.
460    row_dedup: Option<RowDedupState>,
461    /// Whether strict probability domain checks are enabled.
462    strict_probability_domain: bool,
463    /// Active probability semiring for this rule's MNOR/MPROD math.
464    semiring_kind: SemiringKind,
465}
466
467impl FixpointState {
468    /// Create a new fixpoint state for a rule. Existing tests call this
469    /// with the Phase 1/2 default; the fixpoint planner uses
470    /// [`FixpointState::new_with_semiring`] to thread the configured
471    /// semiring through.
472    pub fn new(
473        rule_name: String,
474        schema: SchemaRef,
475        key_column_indices: Vec<usize>,
476        max_derived_bytes: usize,
477        monotonic_agg: Option<MonotonicAggState>,
478        strict_probability_domain: bool,
479    ) -> Self {
480        Self::new_with_semiring(
481            rule_name,
482            schema,
483            key_column_indices,
484            max_derived_bytes,
485            monotonic_agg,
486            strict_probability_domain,
487            SemiringKind::AddMultProb,
488        )
489    }
490
491    pub fn new_with_semiring(
492        rule_name: String,
493        schema: SchemaRef,
494        key_column_indices: Vec<usize>,
495        max_derived_bytes: usize,
496        monotonic_agg: Option<MonotonicAggState>,
497        strict_probability_domain: bool,
498        semiring_kind: SemiringKind,
499    ) -> Self {
500        let num_cols = schema.fields().len();
501        let row_dedup = RowDedupState::try_new(&schema);
502        let key_column_names: Vec<String> = key_column_indices
503            .iter()
504            .filter_map(|&i| schema.fields().get(i).map(|f| f.name().clone()))
505            .collect();
506        Self {
507            rule_name,
508            facts: Vec::new(),
509            delta: Vec::new(),
510            schema,
511            key_column_indices,
512            key_column_names,
513            all_column_indices: (0..num_cols).collect(),
514            facts_bytes: 0,
515            max_derived_bytes,
516            monotonic_agg,
517            row_dedup,
518            strict_probability_domain,
519            semiring_kind,
520        }
521    }
522
523    /// Reconcile the pre-computed schema with the actual physical plan output.
524    ///
525    /// `infer_expr_type` may guess wrong (e.g. `Property → Float64` for a
526    /// string column).  When the first real batch arrives with a different
527    /// schema, update ours so that `RowDedupState` / `RecordBatch::try_new`
528    /// use the correct types.
529    fn reconcile_schema(&mut self, actual_schema: &SchemaRef) {
530        if self.schema.fields() != actual_schema.fields() {
531            tracing::debug!(
532                rule = %self.rule_name,
533                "Reconciling fixpoint schema from physical plan output",
534            );
535            self.schema = Arc::clone(actual_schema);
536            self.row_dedup = RowDedupState::try_new(&self.schema);
537            // Recompute key_column_indices from stored KEY column names.
538            // Without this, FoldExec groups by wrong columns when the
539            // physical plan reorders columns vs the pre-inferred schema.
540            let new_indices: Vec<usize> = self
541                .key_column_names
542                .iter()
543                .filter_map(|name| actual_schema.index_of(name).ok())
544                .collect();
545            if new_indices.len() == self.key_column_names.len() {
546                self.key_column_indices = new_indices;
547            }
548            // else: not all KEY columns found in new schema — keep original indices
549            let num_cols = actual_schema.fields().len();
550            self.all_column_indices = (0..num_cols).collect();
551        }
552    }
553
554    /// Merge candidate rows into facts, computing delta (truly new rows).
555    ///
556    /// Returns `true` if any new facts were added.
557    pub async fn merge_delta(
558        &mut self,
559        candidates: Vec<RecordBatch>,
560        task_ctx: Option<Arc<TaskContext>>,
561    ) -> DFResult<bool> {
562        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
563            self.delta.clear();
564            return Ok(false);
565        }
566
567        // Reconcile schema from the first non-empty candidate batch.
568        // The physical plan's output types are authoritative over the
569        // planner's inferred types.
570        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
571            self.reconcile_schema(&first.schema());
572        }
573
574        // Round floats for stable dedup
575        let candidates = round_float_columns(&candidates);
576
577        // Compute delta: rows in candidates not already in facts
578        let delta = self.compute_delta(&candidates, task_ctx.as_ref()).await?;
579
580        if delta.is_empty() || delta.iter().all(|b| b.num_rows() == 0) {
581            self.delta.clear();
582            // Update monotonic aggs even with empty delta (for stability check)
583            if let Some(ref mut agg) = self.monotonic_agg {
584                agg.snapshot();
585            }
586            return Ok(false);
587        }
588
589        // Check memory limit
590        let delta_bytes: usize = delta.iter().map(batch_byte_size).sum();
591        if self.facts_bytes + delta_bytes > self.max_derived_bytes {
592            return Err(datafusion::error::DataFusionError::Execution(
593                LocyRuntimeError::MemoryLimitExceeded {
594                    rule: self.rule_name.clone(),
595                    bytes: self.facts_bytes + delta_bytes,
596                    limit: self.max_derived_bytes,
597                }
598                .to_string(),
599            ));
600        }
601
602        // Update monotonic aggs
603        if let Some(ref mut agg) = self.monotonic_agg {
604            agg.snapshot();
605            agg.update(
606                &self.key_column_indices,
607                &delta,
608                self.strict_probability_domain,
609                self.semiring_kind,
610            )?;
611        }
612
613        // Append delta to facts
614        self.facts_bytes += delta_bytes;
615        self.facts.extend(delta.iter().cloned());
616        self.delta = delta;
617
618        Ok(true)
619    }
620
621    /// Dispatch to vectorized LeftAntiJoin, Arrow RowConverter dedup, or legacy ScalarKey dedup.
622    ///
623    /// Priority order:
624    /// 1. `arrow_left_anti_dedup` when `total_existing >= DEDUP_ANTI_JOIN_THRESHOLD` and task_ctx available.
625    /// 2. `RowDedupState` (persistent HashSet, O(M) per iteration) when schema is supported.
626    /// 3. `compute_delta_legacy` (rebuilds from facts, fallback for unsupported column types).
627    async fn compute_delta(
628        &mut self,
629        candidates: &[RecordBatch],
630        task_ctx: Option<&Arc<TaskContext>>,
631    ) -> DFResult<Vec<RecordBatch>> {
632        let total_existing: usize = self.facts.iter().map(|b| b.num_rows()).sum();
633        if total_existing >= DEDUP_ANTI_JOIN_THRESHOLD
634            && let Some(ctx) = task_ctx
635        {
636            return arrow_left_anti_dedup(candidates.to_vec(), &self.facts, &self.schema, ctx)
637                .await;
638        }
639        if let Some(ref mut rd) = self.row_dedup {
640            rd.compute_delta(candidates, &self.schema)
641        } else {
642            self.compute_delta_legacy(candidates)
643        }
644    }
645
646    /// Legacy dedup: rebuild a `HashSet<Vec<ScalarKey>>` from all facts each call.
647    ///
648    /// Used as fallback when `RowConverter` does not support the schema's column types.
649    fn compute_delta_legacy(&self, candidates: &[RecordBatch]) -> DFResult<Vec<RecordBatch>> {
650        // Build set of existing fact row keys (ALL columns)
651        let mut existing: HashSet<Vec<ScalarKey>> = HashSet::new();
652        for batch in &self.facts {
653            for row_idx in 0..batch.num_rows() {
654                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
655                existing.insert(key);
656            }
657        }
658
659        let mut delta_batches = Vec::new();
660        for batch in candidates {
661            if batch.num_rows() == 0 {
662                continue;
663            }
664            // Filter to only new rows
665            let mut keep = Vec::with_capacity(batch.num_rows());
666            for row_idx in 0..batch.num_rows() {
667                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
668                keep.push(!existing.contains(&key));
669            }
670
671            // Also dedup within the candidate batch itself
672            for (row_idx, kept) in keep.iter_mut().enumerate() {
673                if *kept {
674                    let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
675                    if !existing.insert(key) {
676                        *kept = false;
677                    }
678                }
679            }
680
681            let keep_mask = arrow_array::BooleanArray::from(keep);
682            let new_rows = batch
683                .columns()
684                .iter()
685                .map(|col| {
686                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
687                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
688                    })
689                })
690                .collect::<DFResult<Vec<_>>>()?;
691
692            if new_rows.first().is_some_and(|c| !c.is_empty()) {
693                let filtered =
694                    RecordBatch::try_new(Arc::clone(&self.schema), new_rows).map_err(|e| {
695                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
696                    })?;
697                delta_batches.push(filtered);
698            }
699        }
700
701        Ok(delta_batches)
702    }
703
704    /// Check if this rule has converged (no new facts and aggs stable).
705    pub fn is_converged(&self) -> bool {
706        let delta_empty = self.delta.is_empty() || self.delta.iter().all(|b| b.num_rows() == 0);
707        let agg_stable = self.monotonic_agg.as_ref().is_none_or(|a| a.is_stable());
708        delta_empty && agg_stable
709    }
710
711    /// Get all accumulated facts.
712    pub fn all_facts(&self) -> &[RecordBatch] {
713        &self.facts
714    }
715
716    /// Get the delta from the latest iteration.
717    pub fn all_delta(&self) -> &[RecordBatch] {
718        &self.delta
719    }
720
721    /// Consume self and return facts.
722    pub fn into_facts(self) -> Vec<RecordBatch> {
723        self.facts
724    }
725
726    /// Merge candidates using BEST BY semantics.
727    ///
728    /// Combines existing facts with new candidates, keeping only the best row
729    /// per KEY group according to `sort_criteria`. Returns `true` if the
730    /// best-per-KEY fact set actually changed (a genuinely better value was
731    /// found or a new KEY appeared).
732    ///
733    /// This replaces `merge_delta` for rules with BEST BY, enabling convergence
734    /// on cyclic graphs where dominated ALONG values would otherwise produce an
735    /// unbounded stream of "new" full-row facts.
736    pub fn merge_best_by(
737        &mut self,
738        candidates: Vec<RecordBatch>,
739        sort_criteria: &[SortCriterion],
740    ) -> DFResult<bool> {
741        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
742            self.delta.clear();
743            return Ok(false);
744        }
745
746        // Reconcile schema from the first non-empty candidate batch.
747        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
748            self.reconcile_schema(&first.schema());
749        }
750
751        // Round floats for stable dedup.
752        let candidates = round_float_columns(&candidates);
753
754        // Snapshot existing best-per-KEY facts for change detection.
755        let old_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> =
756            self.build_key_criteria_map(sort_criteria);
757
758        // Concat existing facts + new candidates.
759        let mut all_batches = self.facts.clone();
760        all_batches.extend(candidates);
761        let all_batches: Vec<_> = all_batches
762            .into_iter()
763            .filter(|b| b.num_rows() > 0)
764            .collect();
765        if all_batches.is_empty() {
766            self.delta.clear();
767            return Ok(false);
768        }
769
770        let combined = arrow::compute::concat_batches(&self.schema, &all_batches)
771            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
772
773        if combined.num_rows() == 0 {
774            self.delta.clear();
775            return Ok(false);
776        }
777
778        // Sort by KEY ASC then criteria, so the best row per KEY group comes
779        // first.
780        let mut sort_columns = Vec::new();
781        for &ki in &self.key_column_indices {
782            if ki >= combined.num_columns() {
783                continue;
784            }
785            sort_columns.push(arrow::compute::SortColumn {
786                values: Arc::clone(combined.column(ki)),
787                options: Some(arrow::compute::SortOptions {
788                    descending: false,
789                    nulls_first: false,
790                }),
791            });
792        }
793        for criterion in sort_criteria {
794            if criterion.col_index >= combined.num_columns() {
795                continue;
796            }
797            sort_columns.push(arrow::compute::SortColumn {
798                values: Arc::clone(combined.column(criterion.col_index)),
799                options: Some(arrow::compute::SortOptions {
800                    descending: !criterion.ascending,
801                    nulls_first: criterion.nulls_first,
802                }),
803            });
804        }
805
806        let sorted_indices =
807            arrow::compute::lexsort_to_indices(&sort_columns, None).map_err(arrow_err)?;
808        let sorted_columns: Vec<_> = combined
809            .columns()
810            .iter()
811            .map(|col| arrow::compute::take(col.as_ref(), &sorted_indices, None))
812            .collect::<Result<Vec<_>, _>>()
813            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
814        let sorted = RecordBatch::try_new(Arc::clone(&self.schema), sorted_columns)
815            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
816
817        // Dedup: keep first (best) row per KEY group.
818        let mut keep_indices: Vec<u32> = Vec::new();
819        let mut prev_key: Option<Vec<ScalarKey>> = None;
820        for row_idx in 0..sorted.num_rows() {
821            let key = extract_scalar_key(&sorted, &self.key_column_indices, row_idx);
822            let is_new_group = match &prev_key {
823                None => true,
824                Some(prev) => *prev != key,
825            };
826            if is_new_group {
827                keep_indices.push(row_idx as u32);
828                prev_key = Some(key);
829            }
830        }
831
832        let keep_array = arrow_array::UInt32Array::from(keep_indices);
833        let output_columns: Vec<_> = sorted
834            .columns()
835            .iter()
836            .map(|col| arrow::compute::take(col.as_ref(), &keep_array, None))
837            .collect::<Result<Vec<_>, _>>()
838            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
839        let pruned = RecordBatch::try_new(Arc::clone(&self.schema), output_columns)
840            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
841
842        // Detect whether the best-per-KEY set actually changed.
843        let new_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> = {
844            let mut map = HashMap::new();
845            for row_idx in 0..pruned.num_rows() {
846                let key = extract_scalar_key(&pruned, &self.key_column_indices, row_idx);
847                let criteria: Vec<ScalarKey> = sort_criteria
848                    .iter()
849                    .flat_map(|c| extract_scalar_key(&pruned, &[c.col_index], row_idx))
850                    .collect();
851                map.insert(key, criteria);
852            }
853            map
854        };
855        let changed = old_best != new_best;
856
857        tracing::debug!(
858            rule = %self.rule_name,
859            old_keys = old_best.len(),
860            new_keys = new_best.len(),
861            changed = changed,
862            "BEST BY merge"
863        );
864
865        // Replace facts with the pruned set.
866        self.facts_bytes = batch_byte_size(&pruned);
867        self.facts = vec![pruned];
868        if changed {
869            // Delta is conceptually the new/improved facts, but since we
870            // replaced the entire set, just mark delta non-empty.
871            self.delta = self.facts.clone();
872        } else {
873            self.delta.clear();
874        }
875
876        // Rebuild row dedup from pruned facts for consistency.
877        self.row_dedup = RowDedupState::try_new(&self.schema);
878        if let Some(ref mut rd) = self.row_dedup {
879            rd.ingest_existing(&self.facts, &self.schema);
880        }
881
882        Ok(changed)
883    }
884
885    /// Build a map from KEY column values to sort criteria values.
886    fn build_key_criteria_map(
887        &self,
888        sort_criteria: &[SortCriterion],
889    ) -> HashMap<Vec<ScalarKey>, Vec<ScalarKey>> {
890        let mut map = HashMap::new();
891        for batch in &self.facts {
892            for row_idx in 0..batch.num_rows() {
893                let key = extract_scalar_key(batch, &self.key_column_indices, row_idx);
894                let criteria: Vec<ScalarKey> = sort_criteria
895                    .iter()
896                    .flat_map(|c| extract_scalar_key(batch, &[c.col_index], row_idx))
897                    .collect();
898                map.insert(key, criteria);
899            }
900        }
901        map
902    }
903}
904
905/// Estimate byte size of a RecordBatch.
906fn batch_byte_size(batch: &RecordBatch) -> usize {
907    batch
908        .columns()
909        .iter()
910        .map(|col| col.get_buffer_memory_size())
911        .sum()
912}
913
914// ---------------------------------------------------------------------------
915// Float rounding for stable dedup
916// ---------------------------------------------------------------------------
917
918/// Round all Float64 columns to 12 decimal places for stable dedup.
919fn round_float_columns(batches: &[RecordBatch]) -> Vec<RecordBatch> {
920    batches
921        .iter()
922        .map(|batch| {
923            let schema = batch.schema();
924            let has_float = schema
925                .fields()
926                .iter()
927                .any(|f| *f.data_type() == arrow_schema::DataType::Float64);
928            if !has_float {
929                return batch.clone();
930            }
931
932            let columns: Vec<arrow_array::ArrayRef> = batch
933                .columns()
934                .iter()
935                .enumerate()
936                .map(|(i, col)| {
937                    if *schema.field(i).data_type() == arrow_schema::DataType::Float64 {
938                        let arr = col
939                            .as_any()
940                            .downcast_ref::<arrow_array::Float64Array>()
941                            .unwrap();
942                        let rounded: arrow_array::Float64Array = arr
943                            .iter()
944                            .map(|v| v.map(|f| (f * 1e12).round() / 1e12))
945                            .collect();
946                        Arc::new(rounded) as arrow_array::ArrayRef
947                    } else {
948                        Arc::clone(col)
949                    }
950                })
951                .collect();
952
953            RecordBatch::try_new(schema, columns).unwrap_or_else(|_| batch.clone())
954        })
955        .collect()
956}
957
958// ---------------------------------------------------------------------------
959// LeftAntiJoin delta deduplication
960// ---------------------------------------------------------------------------
961
/// Row threshold above which the vectorized Arrow LeftAntiJoin dedup path is used.
///
/// Below this threshold the persistent `RowDedupState` HashSet is O(M) and
/// avoids rebuilding the existing-row set; above it DataFusion's vectorized
/// HashJoinExec is more cache-efficient.
///
/// NOTE(review): the comparison site is outside this view. The note in
/// `dedup_batches_all_columns` indicates the threshold is compared against the
/// size of the existing fact set, and M is presumably the candidate-row count
/// — confirm both. Treat the value as a tuning knob and re-measure before
/// changing it.
const DEDUP_ANTI_JOIN_THRESHOLD: usize = 300;
968
969/// Deduplicate `candidates` against `existing` using DataFusion's HashJoinExec.
970///
971/// Returns rows in `candidates` that do not appear in `existing` (LeftAnti semantics).
972/// `null_equals_null = true` so NULLs are treated as equal for dedup purposes.
973/// Dedup `batches` by all columns (set semantics), keeping the first occurrence.
974///
975/// `arrow_left_anti_dedup` removes candidate rows that match the existing fact
976/// set, but a single semi-naive iteration can emit the same row many times — e.g.
977/// a transitive-closure rule derives the same `(a, b)` pair via every intermediate
978/// `mid` on a path. A `LeftAnti` join does not remove these *within-candidate*
979/// duplicates, so they would leak into the fact set. The `RowDedupState` and legacy
980/// paths both dedup within the candidate batch ([`RowDedupState::compute_delta`],
981/// [`FixpointState::compute_delta_legacy`]); this keeps the `arrow_left_anti_dedup`
982/// path identical so dedup behavior does not change across `DEDUP_ANTI_JOIN_THRESHOLD`.
983fn dedup_batches_all_columns(
984    batches: Vec<RecordBatch>,
985    schema: &SchemaRef,
986) -> DFResult<Vec<RecordBatch>> {
987    let fields: Vec<SortField> = schema
988        .fields()
989        .iter()
990        .map(|f| SortField::new(f.data_type().clone()))
991        .collect();
992    // Unsupported column types: leave as-is. This path is only reached for facts
993    // sets >= DEDUP_ANTI_JOIN_THRESHOLD; for those types the <threshold path uses
994    // `compute_delta_legacy`, which dedups via `ScalarKey` instead.
995    let Ok(converter) = RowConverter::new(fields) else {
996        return Ok(batches);
997    };
998    let mut seen: HashSet<Box<[u8]>> = HashSet::new();
999    let mut out = Vec::with_capacity(batches.len());
1000    for batch in batches {
1001        if batch.num_rows() == 0 {
1002            continue;
1003        }
1004        let rows = converter
1005            .convert_columns(batch.columns())
1006            .map_err(arrow_err)?;
1007        let mut keep = Vec::with_capacity(batch.num_rows());
1008        for row_idx in 0..batch.num_rows() {
1009            let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
1010            keep.push(seen.insert(row_bytes));
1011        }
1012        let keep_mask = arrow_array::BooleanArray::from(keep);
1013        let cols = batch
1014            .columns()
1015            .iter()
1016            .map(|c| arrow::compute::filter(c.as_ref(), &keep_mask).map_err(arrow_err))
1017            .collect::<DFResult<Vec<_>>>()?;
1018        if cols.first().is_some_and(|c| !c.is_empty()) {
1019            out.push(RecordBatch::try_new(Arc::clone(schema), cols).map_err(arrow_err)?);
1020        }
1021    }
1022    Ok(out)
1023}
1024
1025async fn arrow_left_anti_dedup(
1026    candidates: Vec<RecordBatch>,
1027    existing: &[RecordBatch],
1028    schema: &SchemaRef,
1029    task_ctx: &Arc<TaskContext>,
1030) -> DFResult<Vec<RecordBatch>> {
1031    if existing.is_empty() || existing.iter().all(|b| b.num_rows() == 0) {
1032        // No existing facts to anti-join against, but still dedup the candidates
1033        // among themselves (a single iteration may emit duplicate rows).
1034        return dedup_batches_all_columns(candidates, schema);
1035    }
1036
1037    let left: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(candidates, Arc::clone(schema)));
1038    let right: Arc<dyn ExecutionPlan> =
1039        Arc::new(InMemoryExec::new(existing.to_vec(), Arc::clone(schema)));
1040
1041    let on: Vec<(
1042        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
1043        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
1044    )> = schema
1045        .fields()
1046        .iter()
1047        .enumerate()
1048        .map(|(i, field)| {
1049            let l: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
1050                datafusion::physical_plan::expressions::Column::new(field.name(), i),
1051            );
1052            let r: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
1053                datafusion::physical_plan::expressions::Column::new(field.name(), i),
1054            );
1055            (l, r)
1056        })
1057        .collect();
1058
1059    if on.is_empty() {
1060        return Ok(vec![]);
1061    }
1062
1063    let join = HashJoinExec::try_new(
1064        left,
1065        right,
1066        on,
1067        None,
1068        &JoinType::LeftAnti,
1069        None,
1070        PartitionMode::CollectLeft,
1071        datafusion::common::NullEquality::NullEqualsNull,
1072        // null_aware = false: this is a set-difference dedup (NOT EXISTS), not a
1073        // SQL NOT-IN. `NullEqualsNull` already makes NULL keys dedup against NULL
1074        // keys; enabling null-aware semantics would wrongly annihilate all rows
1075        // whenever the existing-fact side contains a NULL key.
1076        false,
1077    )?;
1078
1079    let join_arc: Arc<dyn ExecutionPlan> = Arc::new(join);
1080    // LeftAnti removes candidates that match `existing`, but not duplicate rows
1081    // within the candidate set — dedup those to match the other delta strategies.
1082    let anti = collect_all_partitions(&join_arc, task_ctx.clone()).await?;
1083    dedup_batches_all_columns(anti, schema)
1084}
1085
1086// ---------------------------------------------------------------------------
1087// Plan types for fixpoint rules
1088// ---------------------------------------------------------------------------
1089
/// IS-ref binding: a reference from a clause body to a derived relation.
///
/// A clause's bindings are listed in `FixpointClausePlan::is_ref_bindings`.
/// During fixpoint iteration, negated bindings with a non-empty
/// `anti_join_cols` are applied to the clause-body output. Depending on the
/// PROB fields below, this is either a probabilistic complement or an
/// anti-join.
#[derive(Debug, Clone)]
pub struct IsRefBinding {
    /// Index into the DerivedScanRegistry.
    ///
    /// Passed to `DerivedScanRegistry::get` to read the referenced rule's
    /// current facts.
    pub derived_scan_index: usize,
    /// Name of the rule being referenced.
    ///
    /// Also used to name the temporary `__prob_complement_<rule_name>` column
    /// created for probabilistic negation.
    pub rule_name: String,
    /// Whether this is a self-reference (rule references itself).
    pub is_self_ref: bool,
    /// Whether this is a negated reference (NOT IS).
    pub negated: bool,
    /// For negated IS-refs: `(left_body_col, right_derived_col)` pairs for anti-join filtering.
    ///
    /// `left_body_col` is the VID column in the clause body (e.g., `"n._vid"`);
    /// `right_derived_col` is the corresponding KEY column in the negated rule's facts (e.g., `"n"`).
    /// Empty for non-negated IS-refs.
    pub anti_join_cols: Vec<(String, String)>,
    /// Whether the target rule has a PROB column.
    ///
    /// Suppose a negated reference's target has a PROB column and the
    /// referencing rule does too. Then the reference contributes a `1 - p`
    /// complement factor instead of filtering rows out.
    pub target_has_prob: bool,
    /// Name of the PROB column in the target rule, if any.
    ///
    /// If `target_has_prob` is set but this is `None`, negation falls back to
    /// a plain anti-join.
    pub target_prob_col: Option<String>,
    /// `(body_col, derived_col)` pairs for provenance tracking.
    ///
    /// Used by shared-proof detection to find which source facts a derived row
    /// consumed. Populated for all IS-refs (not just negated ones).
    pub provenance_join_cols: Vec<(String, String)>,
}
1117
/// A single clause (body) within a fixpoint rule.
///
/// On every fixpoint iteration `body_logical` is executed to produce candidate
/// rows for the owning rule. Negated IS-ref handling from `is_ref_bindings` is
/// then applied to those rows.
#[derive(Debug)]
pub struct FixpointClausePlan {
    /// The logical plan for the clause body.
    pub body_logical: LogicalPlan,
    /// IS-ref bindings used by this clause.
    pub is_ref_bindings: Vec<IsRefBinding>,
    /// Priority value for this clause (if PRIORITY semantics apply).
    pub priority: Option<i64>,
    /// ALONG binding variable names propagated from the planner.
    pub along_bindings: Vec<String>,
    /// Phase B Slice 3: neural-model invocations lifted out of YIELD
    /// items by the compiler. Each entry is evaluated per row after the
    /// clause body produces batches and before IS-ref handling.
    ///
    /// The planner inserts a `LogicalPlan::LocyModelInvoke` node for these
    /// between the body and `LocyProject`. Executing `body_logical`
    /// therefore runs the invocations inline as part of the body plan tree.
    pub model_invocations: Vec<ModelInvocation>,
}
1134
/// Physical plan for a single rule in a fixpoint stratum.
///
/// The fixpoint loop keeps one state per rule plan. It evaluates every
/// clause each iteration and merges the resulting candidates into that
/// rule's facts.
#[derive(Debug)]
pub struct FixpointRulePlan {
    /// Rule name.
    pub name: String,
    /// Clause bodies (each evaluates to candidate rows).
    pub clauses: Vec<FixpointClausePlan>,
    /// Output schema for this rule's derived relation.
    pub yield_schema: SchemaRef,
    /// Indices of KEY columns within yield_schema.
    pub key_column_indices: Vec<usize>,
    /// Priority value (if PRIORITY semantics apply).
    pub priority: Option<i64>,
    /// Whether this rule has FOLD semantics.
    pub has_fold: bool,
    /// FOLD bindings for post-fixpoint aggregation.
    ///
    /// When non-empty, these bindings also seed the rule's monotonic-aggregate
    /// state during iteration. That state must report itself stable before
    /// the rule counts as converged.
    pub fold_bindings: Vec<FoldBinding>,
    /// Post-FOLD filter expressions (HAVING semantics).
    pub having: Vec<Expr>,
    /// Whether this rule has BEST BY semantics.
    pub has_best_by: bool,
    /// BEST BY sort criteria for post-fixpoint selection.
    ///
    /// Also used during iteration. When `has_best_by` is set and this list is
    /// non-empty, candidates are merged with `FixpointState::merge_best_by`
    /// (best row per KEY) instead of the generic delta merge.
    pub best_by_criteria: Vec<SortCriterion>,
    /// Whether this rule has PRIORITY semantics.
    pub has_priority: bool,
    /// Whether BEST BY should apply a deterministic secondary sort for
    /// tie-breaking. When false, tied rows are selected non-deterministically
    /// (faster but not repeatable across runs).
    pub deterministic: bool,
    /// Name of the PROB column in this rule's yield schema, if any.
    ///
    /// Negated-reference complement factors are multiplied into this column.
    /// Under the MaxMinProb semiring, its presence triggers a
    /// `FuzzyNotProbabilistic` warning.
    pub prob_column_name: Option<String>,
    /// True when any clause of this rule has ≥2 positive same-stratum
    /// IS-refs (non-linear recursion, e.g. `tc(a,b) :- tc(a,m), tc(m,b)`).
    /// Such rules get FULL facts (naive evaluation) instead of the latest
    /// delta on their self-ref scans: a delta-only join computes Δ×Δ and
    /// misses the Δ×F_old combinations, silently under-deriving. Covers
    /// both `p :- p, p` and `p :- p, q` (q in the same SCC) shapes.
    pub non_linear: bool,
}
1174
1175// ---------------------------------------------------------------------------
1176// run_fixpoint_loop — the core semi-naive iteration algorithm
1177// ---------------------------------------------------------------------------
1178
1179/// Run the semi-naive fixpoint iteration loop.
1180///
1181/// Evaluates all rules in a stratum repeatedly, feeding deltas back through
1182/// derived scan handles until convergence or limits are reached.
1183#[expect(clippy::too_many_arguments, reason = "Fixpoint loop needs all context")]
1184async fn run_fixpoint_loop(
1185    rules: Vec<FixpointRulePlan>,
1186    max_iterations: usize,
1187    timeout: Duration,
1188    graph_ctx: Arc<GraphExecutionContext>,
1189    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
1190    storage: Arc<StorageManager>,
1191    schema_info: Arc<UniSchema>,
1192    params: HashMap<String, Value>,
1193    registry: Arc<DerivedScanRegistry>,
1194    output_schema: SchemaRef,
1195    max_derived_bytes: usize,
1196    derivation_tracker: Option<Arc<ProvenanceStore>>,
1197    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
1198    strict_probability_domain: bool,
1199    probability_epsilon: f64,
1200    exact_probability: bool,
1201    max_bdd_variables: usize,
1202    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
1203    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
1204    top_k_proofs: usize,
1205    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
1206    semiring_kind: SemiringKind,
1207    classifier_registry: Arc<ClassifierRegistry>,
1208    classifier_cache: Option<Arc<ModelInvocationCache>>,
1209    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
1210    profile_collector: Option<Arc<LocyProfileCollector>>,
1211) -> DFResult<Vec<RecordBatch>> {
1212    let start = Instant::now();
1213    let task_ctx = session_ctx.read().task_ctx();
1214
1215    // IMPORTANT: per rollout D-9 the FuzzyNotProbabilistic warning emitted
1216    // below is unsuppressible — do not gate on any suppression mechanism.
1217    // Fuzzy truth values are not probabilities; silent conflation is the
1218    // dominant pitfall in neuro-symbolic systems (LTN, NTP).
1219    if semiring_kind == SemiringKind::MaxMinProb {
1220        let mut warnings = warnings_slot.write().unwrap_or_else(|e| e.into_inner());
1221        let mut already_warned: HashSet<String> = warnings
1222            .iter()
1223            .filter(|w| w.code == RuntimeWarningCode::FuzzyNotProbabilistic)
1224            .map(|w| w.rule_name.clone())
1225            .collect();
1226        for rule in &rules {
1227            if rule.prob_column_name.is_some() && !already_warned.contains(&rule.name) {
1228                warnings.push(RuntimeWarning {
1229                    code: RuntimeWarningCode::FuzzyNotProbabilistic,
1230                    message: format!(
1231                        "rule '{}' carries a PROB column but is being evaluated under \
1232                         the MaxMinProb (fuzzy / Viterbi) semiring; outputs are fuzzy \
1233                         truth values, not probabilities",
1234                        rule.name
1235                    ),
1236                    rule_name: rule.name.clone(),
1237                    variable_count: None,
1238                    key_group: None,
1239                });
1240                already_warned.insert(rule.name.clone());
1241            }
1242        }
1243    }
1244
1245    // Initialize per-rule state
1246    let mut states: Vec<FixpointState> = rules
1247        .iter()
1248        .map(|rule| {
1249            let monotonic_agg = if !rule.fold_bindings.is_empty() {
1250                let bindings: Vec<MonotonicFoldBinding> = rule
1251                    .fold_bindings
1252                    .iter()
1253                    .map(|fb| MonotonicFoldBinding {
1254                        fold_name: fb.output_name.clone(),
1255                        aggregate: std::sync::Arc::clone(&fb.aggregate),
1256                        input_col_index: fb.input_col_index,
1257                        input_col_name: fb.input_col_name.clone(),
1258                    })
1259                    .collect();
1260                Some(MonotonicAggState::new(bindings))
1261            } else {
1262                None
1263            };
1264            FixpointState::new_with_semiring(
1265                rule.name.clone(),
1266                Arc::clone(&rule.yield_schema),
1267                rule.key_column_indices.clone(),
1268                max_derived_bytes,
1269                monotonic_agg,
1270                strict_probability_domain,
1271                semiring_kind,
1272            )
1273        })
1274        .collect();
1275
1276    // Main iteration loop
1277    let mut converged = false;
1278    let mut total_iters = 0usize;
1279    for iteration in 0..max_iterations {
1280        total_iters = iteration + 1;
1281        tracing::debug!("fixpoint iteration {}", iteration);
1282        let mut any_changed = false;
1283
1284        for rule_idx in 0..rules.len() {
1285            let rule = &rules[rule_idx];
1286
1287            // Update derived scan handles for this rule's clauses
1288            update_derived_scan_handles(&registry, &states, rule_idx, &rules);
1289
1290            // Evaluate clause bodies, tracking per-clause candidates for provenance.
1291            let mut all_candidates = Vec::new();
1292            let mut clause_candidates: Vec<Vec<RecordBatch>> = Vec::new();
1293            // Profiling (profile() path only): time this rule's evaluation in
1294            // this iteration and accumulate the clause-body operator trees.
1295            let rule_start = Instant::now();
1296            let mut iter_ops: Vec<OperatorStats> = Vec::new();
1297            for clause in &rule.clauses {
1298                // Phase B A4 follow-up: the planner inserts
1299                // `LogicalPlan::LocyModelInvoke` between the body and
1300                // `LocyProject` when this clause has neural-model
1301                // invocations, so `execute_subplan` runs the invocation
1302                // inline as part of the body plan tree.
1303                let mut batches = if profile_collector.is_some() {
1304                    let (b, ops) = execute_subplan_collecting(
1305                        &clause.body_logical,
1306                        &params,
1307                        &HashMap::new(),
1308                        &graph_ctx,
1309                        &session_ctx,
1310                        &storage,
1311                        &schema_info,
1312                        None, // Locy fixpoint clause body is read-only
1313                    )
1314                    .await?;
1315                    iter_ops.extend(ops);
1316                    b
1317                } else {
1318                    execute_subplan(
1319                        &clause.body_logical,
1320                        &params,
1321                        &HashMap::new(),
1322                        &graph_ctx,
1323                        &session_ctx,
1324                        &storage,
1325                        &schema_info,
1326                        None, // Locy fixpoint clause body is read-only
1327                    )
1328                    .await?
1329                };
1330                // Apply negated IS-ref semantics: probabilistic complement or anti-join.
1331                for binding in &clause.is_ref_bindings {
1332                    if binding.negated
1333                        && !binding.anti_join_cols.is_empty()
1334                        && let Some(entry) = registry.get(binding.derived_scan_index)
1335                    {
1336                        let neg_facts = entry.data.read().clone();
1337                        if !neg_facts.is_empty() {
1338                            if binding.target_has_prob && rule.prob_column_name.is_some() {
1339                                // Probabilistic complement: add 1-p column instead of filtering.
1340                                let complement_col =
1341                                    format!("__prob_complement_{}", binding.rule_name);
1342                                if let Some(prob_col) = &binding.target_prob_col {
1343                                    batches = apply_prob_complement_composite(
1344                                        batches,
1345                                        &neg_facts,
1346                                        &binding.anti_join_cols,
1347                                        prob_col,
1348                                        &complement_col,
1349                                    )?;
1350                                } else {
1351                                    // target_has_prob but no prob_col: fall back to anti-join.
1352                                    batches = apply_anti_join_composite(
1353                                        batches,
1354                                        &neg_facts,
1355                                        &binding.anti_join_cols,
1356                                    )?;
1357                                }
1358                            } else {
1359                                // Boolean exclusion: anti-join (existing behavior)
1360                                batches = apply_anti_join_composite(
1361                                    batches,
1362                                    &neg_facts,
1363                                    &binding.anti_join_cols,
1364                                )?;
1365                            }
1366                        }
1367                    }
1368                }
1369                // Multiply complement columns into the PROB column (if any) and clean up
1370                let complement_cols: Vec<String> = if !batches.is_empty() {
1371                    batches[0]
1372                        .schema()
1373                        .fields()
1374                        .iter()
1375                        .filter(|f| f.name().starts_with("__prob_complement_"))
1376                        .map(|f| f.name().clone())
1377                        .collect()
1378                } else {
1379                    vec![]
1380                };
1381                if !complement_cols.is_empty() {
1382                    batches = multiply_prob_factors(
1383                        batches,
1384                        rule.prob_column_name.as_deref(),
1385                        &complement_cols,
1386                    )?;
1387                }
1388
1389                clause_candidates.push(batches.clone());
1390                all_candidates.extend(batches);
1391            }
1392
1393            // Merge candidates into facts.
1394            // For BEST BY rules, use a specialized merge that keeps only the
1395            // best row per KEY group, enabling convergence on cyclic graphs.
1396            let changed = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
1397                states[rule_idx].merge_best_by(all_candidates, &rule.best_by_criteria)?
1398            } else {
1399                states[rule_idx]
1400                    .merge_delta(all_candidates, Some(Arc::clone(&task_ctx)))
1401                    .await?
1402            };
1403            if changed {
1404                any_changed = true;
1405                // Record provenance for newly derived facts when tracker is present.
1406                if let Some(ref tracker) = derivation_tracker {
1407                    record_provenance(
1408                        ProvenanceCtx {
1409                            tracker,
1410                            registry: &registry,
1411                            warnings_slot: &warnings_slot,
1412                        },
1413                        rule,
1414                        &states[rule_idx],
1415                        &clause_candidates,
1416                        iteration,
1417                        top_k_proofs,
1418                        ClassifierRefs {
1419                            registry: &classifier_registry,
1420                            cache: classifier_cache.as_ref(),
1421                            provenance_store: classifier_provenance_store.as_ref(),
1422                        },
1423                    )
1424                    .await;
1425                }
1426            }
1427
1428            // Profiling: record this rule's per-iteration row (delta = net-new
1429            // facts merged this pass, plus the captured clause-body operators).
1430            if let Some(ref collector) = profile_collector {
1431                let delta_facts: usize = states[rule_idx]
1432                    .all_delta()
1433                    .iter()
1434                    .map(|b| b.num_rows())
1435                    .sum();
1436                collector.record(
1437                    &rule.name,
1438                    iteration,
1439                    delta_facts,
1440                    rule_start.elapsed().as_secs_f64() * 1000.0,
1441                    iter_ops,
1442                );
1443            }
1444        }
1445
1446        // Check convergence
1447        if !any_changed && states.iter().all(|s| s.is_converged()) {
1448            tracing::debug!("fixpoint converged after {} iterations", iteration + 1);
1449            converged = true;
1450            break;
1451        }
1452
1453        // Check timeout
1454        if start.elapsed() > timeout {
1455            tracing::warn!(
1456                "fixpoint timeout after {} iterations; returning partial results",
1457                iteration + 1,
1458            );
1459            interruption::set(&timeout_flag, interruption::TIMEOUT);
1460            break;
1461        }
1462    }
1463
1464    // Write per-rule iteration counts to the shared slot.
1465    if let Ok(mut counts) = iteration_counts.write() {
1466        for rule in &rules {
1467            counts.insert(rule.name.clone(), total_iters);
1468        }
1469    }
1470
1471    // Profiling: record each rule's final converged fact count.
1472    if let Some(ref collector) = profile_collector {
1473        for (idx, rule) in rules.iter().enumerate() {
1474            let facts: usize = states[idx].all_facts().iter().map(|b| b.num_rows()).sum();
1475            collector.set_final_facts(&rule.name, facts);
1476        }
1477    }
1478
1479    // If we exhausted all iterations without converging, record the iteration
1480    // limit (distinct from a wall-clock timeout) and proceed with partial
1481    // results rather than discarding all work. `set` is first-wins, so a
1482    // wall-clock timeout recorded above is not overwritten here.
1483    if !converged && interruption::reason(&timeout_flag).is_none() {
1484        tracing::warn!(
1485            "fixpoint did not converge after {max_iterations} iterations; returning partial results",
1486        );
1487        interruption::set(&timeout_flag, interruption::ITERATION_LIMIT);
1488    }
1489
1490    // Post-fixpoint processing per rule and collect output
1491    let task_ctx = session_ctx.read().task_ctx();
1492    let mut all_output = Vec::new();
1493
1494    for (rule_idx, state) in states.into_iter().enumerate() {
1495        let rule = &rules[rule_idx];
1496        let mut facts = state.into_facts();
1497        if facts.is_empty() {
1498            continue;
1499        }
1500
1501        // Detect shared proofs before FOLD collapses groups.
1502        //
1503        // TODO(C0-stage2): swap `detect_shared_lineage` for `TopKTag`
1504        // DNF inspection when `semiring_kind == TopKProofs { k }`.
1505        // The library-layer tag math has landed in
1506        // `crates/uni-locy/src/top_k_proofs.rs` (Phase C C0 Stage 1);
1507        // Stage 2 plumbs `TopKTag` through `MonotonicAggState` /
1508        // `FoldExec` so per-row dependency DNFs are available here.
1509        // Until Stage 2, this scalar `ProvenanceStore` path runs for
1510        // every semiring including `TopKProofs` (per rollout D-4
1511        // "graceful migration").
1512        //
1513        // Phase-3 shared-proof detection is meaningful only under
1514        // `AddMultProb` (and `BddExact`, which is the AddMultProb math
1515        // plus a WMC post-correction). Under `MaxMinProb`, `plus = max`
1516        // is idempotent — shared proofs don't double-count — so the
1517        // warning is moot and we skip the work.
1518        let shared_info = if semiring_kind == SemiringKind::MaxMinProb {
1519            None
1520        } else if let Some(ref tracker) = derivation_tracker {
1521            detect_shared_lineage(rule, &facts, tracker, &warnings_slot, semiring_kind)
1522        } else {
1523            None
1524        };
1525
1526        // Apply BDD for shared groups if exact_probability is enabled.
1527        if exact_probability
1528            && let Some(ref info) = shared_info
1529            && let Some(ref tracker) = derivation_tracker
1530        {
1531            facts = apply_exact_wmc(
1532                facts,
1533                rule,
1534                info,
1535                tracker,
1536                max_bdd_variables,
1537                &warnings_slot,
1538                &approximate_slot,
1539            )?;
1540        }
1541
1542        let processed = apply_post_fixpoint_chain(
1543            facts,
1544            rule,
1545            &task_ctx,
1546            strict_probability_domain,
1547            probability_epsilon,
1548            semiring_kind,
1549            derivation_tracker.as_ref().map(Arc::clone),
1550            top_k_proofs,
1551            Some(Arc::clone(&registry)),
1552        )
1553        .await?;
1554        all_output.extend(processed);
1555    }
1556
1557    // If no output, return empty batch with output schema
1558    if all_output.is_empty() {
1559        all_output.push(RecordBatch::new_empty(output_schema));
1560    }
1561
1562    Ok(all_output)
1563}
1564
1565// ---------------------------------------------------------------------------
1566// Provenance recording helpers
1567// ---------------------------------------------------------------------------
1568
/// Borrowed bundle of classifier-side runtime state used by
/// provenance / EXPLAIN-reconstruction code paths. Keeps function
/// signatures under the too-many-arguments threshold.
pub(crate) struct ClassifierRefs<'a> {
    /// Registry used to look up a classifier by model name when
    /// neural provenance has to be reconstructed by re-invoking it.
    pub registry: &'a Arc<ClassifierRegistry>,
    /// Optional memoization cache of classifier outputs, keyed by
    /// `(model_name, input_hash)`. Consulted before re-invoking a
    /// classifier and populated after a successful re-invocation.
    pub cache: Option<&'a Arc<uni_locy::ModelInvocationCache>>,
    /// Phase C B1-B3 follow-up: when `Some`, EXPLAIN's neural_calls
    /// collection consults the side-channel provenance store first
    /// (populated by `apply_model_invocations`). This is the only way
    /// to surface NeuralProvenance for Python-registered classifiers,
    /// whose model_invocations may be rewritten away by the planner
    /// and so wouldn't trigger the re-invocation fallback.
    pub provenance_store: Option<&'a Arc<uni_locy::NeuralProvenanceStore>>,
}
1587
/// Borrowed bundle of provenance-recording state: the in-flight
/// tracker, the derived-scan registry (used to resolve IS-ref inputs),
/// and the shared warnings slot. Bundled to keep
/// `record_provenance` / `record_and_detect_lineage_nonrecursive`
/// under the too-many-arguments threshold.
pub(crate) struct ProvenanceCtx<'a> {
    /// Provenance store that receives one annotation per newly derived
    /// fact (via `record`, or `record_top_k` when top-k proofs are on).
    pub tracker: &'a Arc<ProvenanceStore>,
    /// Derived-scan registry; IS-ref source rows are read from its
    /// entries when reconstructing a fact's support set.
    pub registry: &'a Arc<DerivedScanRegistry>,
    /// Shared sink for runtime warnings raised while recording
    /// provenance (e.g. top-K pruning across dependencies).
    pub warnings_slot: &'a Arc<StdRwLock<Vec<RuntimeWarning>>>,
}
1598
1599async fn record_provenance(
1600    prov: ProvenanceCtx<'_>,
1601    rule: &FixpointRulePlan,
1602    state: &FixpointState,
1603    clause_candidates: &[Vec<RecordBatch>],
1604    iteration: usize,
1605    top_k_proofs: usize,
1606    classifiers: ClassifierRefs<'_>,
1607) {
1608    let tracker = prov.tracker;
1609    let registry = prov.registry;
1610    let warnings_slot = prov.warnings_slot;
1611    let classifier_registry = classifiers.registry;
1612    let classifier_cache = classifiers.cache;
1613    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
1614
1615    // Pre-compute base fact probabilities for top-k mode.
1616    let base_probs = if top_k_proofs > 0 {
1617        tracker.base_fact_probs()
1618    } else {
1619        HashMap::new()
1620    };
1621
1622    let mut topk_acc = TopKProofAccumulator::new();
1623
1624    for delta_batch in state.all_delta() {
1625        for row_idx in 0..delta_batch.num_rows() {
1626            let row_hash = format!(
1627                "{:?}",
1628                extract_scalar_key(delta_batch, &all_indices, row_idx)
1629            )
1630            .into_bytes();
1631            let fact_row = batch_row_to_value_map(delta_batch, row_idx);
1632            let clause_index =
1633                find_clause_for_row(delta_batch, row_idx, &all_indices, clause_candidates);
1634
1635            let support = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
1636
1637            let proof_probability = if top_k_proofs > 0 {
1638                compute_proof_probability(&support, &base_probs)
1639            } else {
1640                None
1641            };
1642
1643            let entry = ProvenanceAnnotation {
1644                rule_name: rule.name.clone(),
1645                clause_index,
1646                support,
1647                along_values: {
1648                    let along_names: Vec<String> = rule
1649                        .clauses
1650                        .get(clause_index)
1651                        .map(|c| c.along_bindings.clone())
1652                        .unwrap_or_default();
1653                    along_names
1654                        .iter()
1655                        .filter_map(|name| fact_row.get(name).map(|v| (name.clone(), v.clone())))
1656                        .collect()
1657                },
1658                iteration,
1659                fact_row: fact_row.clone(),
1660                proof_probability,
1661                neural_calls: collect_neural_calls_for_row(
1662                    rule,
1663                    clause_index,
1664                    &fact_row,
1665                    classifier_registry,
1666                    classifier_cache,
1667                    classifiers.provenance_store,
1668                )
1669                .await,
1670            };
1671            if top_k_proofs > 0 {
1672                topk_acc.accumulate(&entry, &row_hash);
1673                tracker.record_top_k(row_hash, entry, top_k_proofs);
1674            } else {
1675                tracker.record(row_hash, entry);
1676            }
1677        }
1678    }
1679
1680    topk_acc.emit_warning_if_any(rule, top_k_proofs, warnings_slot);
1681}
1682
1683/// Phase C C0 Stage 2: collects per-row `Proof` tags during the
1684/// fixpoint row walk, then surfaces `TopKPruningCrossedDependency`
1685/// when post-walk top-K merging would drop a proof whose base RVs
1686/// overlap a retained one. The shared `BaseRv` interner is what
1687/// makes the overlap detectable — proofs grounded in the same
1688/// `base_fact_id` get the same `BaseRv`.
1689struct TopKProofAccumulator {
1690    per_fact: HashMap<Vec<u8>, Vec<uni_locy::Proof>>,
1691    base_rv_interner: HashMap<Vec<u8>, uni_locy::BaseRv>,
1692    next_rv: u32,
1693}
1694
1695impl TopKProofAccumulator {
1696    fn new() -> Self {
1697        Self {
1698            per_fact: HashMap::new(),
1699            base_rv_interner: HashMap::new(),
1700            next_rv: 0,
1701        }
1702    }
1703
1704    fn accumulate(&mut self, entry: &ProvenanceAnnotation, row_hash: &[u8]) {
1705        let mut base_rvs = uni_locy::BaseRvSet::empty();
1706        for term in &entry.support {
1707            let rv = *self
1708                .base_rv_interner
1709                .entry(term.base_fact_id.clone())
1710                .or_insert_with(|| {
1711                    let r = uni_locy::BaseRv(self.next_rv);
1712                    self.next_rv += 1;
1713                    r
1714                });
1715            base_rvs.insert(rv);
1716        }
1717        self.per_fact
1718            .entry(row_hash.to_vec())
1719            .or_default()
1720            .push(uni_locy::Proof {
1721                weight: entry.proof_probability.unwrap_or(0.0),
1722                base_rvs,
1723                neural_calls: Vec::new(),
1724            });
1725    }
1726
1727    fn emit_warning_if_any(
1728        &self,
1729        rule: &FixpointRulePlan,
1730        top_k_proofs: usize,
1731        warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
1732    ) {
1733        if top_k_proofs == 0 || self.per_fact.is_empty() {
1734            return;
1735        }
1736        let crossed_facts = self
1737            .per_fact
1738            .values()
1739            .filter(|proofs| {
1740                let (_kept, notice) =
1741                    uni_locy::merge_top_k_runtime(Vec::new(), (*proofs).clone(), top_k_proofs);
1742                notice == uni_locy::PruneNotice::CrossedDependency
1743            })
1744            .count();
1745        if crossed_facts == 0 {
1746            return;
1747        }
1748        let Ok(mut w) = warnings_slot.write() else {
1749            return;
1750        };
1751        let already = w.iter().any(|rw| {
1752            matches!(
1753                rw.code,
1754                uni_locy::types::RuntimeWarningCode::TopKPruningCrossedDependency
1755            ) && rw.rule_name == rule.name
1756        });
1757        if already {
1758            return;
1759        }
1760        w.push(RuntimeWarning {
1761            code: uni_locy::types::RuntimeWarningCode::TopKPruningCrossedDependency,
1762            rule_name: rule.name.clone(),
1763            message: format!(
1764                "rule '{}': top-K proof pruning (k={}) discarded {} fact(s) \
1765                 whose dependencies overlap retained proofs. The retained \
1766                 top-{} under-counts the true joint probability for those \
1767                 facts (Scallop, Huang et al. 2021). Increase k to recover.",
1768                rule.name, top_k_proofs, crossed_facts, top_k_proofs
1769            ),
1770            variable_count: None,
1771            key_group: None,
1772        });
1773    }
1774}
1775
1776/// Collect IS-ref input facts for a derived row using provenance join columns.
1777///
1778/// For each non-negated IS-ref binding in the clause, extracts body-side key
1779/// values from the delta row and finds matching source rows in the registry.
1780/// Returns a `ProofTerm` for each match (with the source fact hash).
1781/// Phase C B1–B3: build [`uni_locy::NeuralProvenance`] entries for
1782/// the model invocations on this clause by reading each
1783/// invocation's output column from the post-LocyModelInvoke row.
1784/// `raw_probability` is the classifier's direct output;
1785/// `calibrated_probability` and `confidence_band` come from the
1786/// active Calibrator (when the classifier wraps one).
1787///
1788/// Phase C B1-B3 follow-up rewrite: re-evaluate the classifier
1789/// per fact using the ORIGINAL pre-rewrite feature expressions
1790/// (stored as `invocation.original_feature_exprs`). This works
1791/// for invocations in YIELD, ALONG, and FOLD positions uniformly
1792/// — the original args carry the input bindings regardless of
1793/// where the synthetic `__model_<n>` column ends up in the plan
1794/// tree. Memoization via `ModelInvocationCache` (already threaded)
1795/// absorbs repeat costs; EXPLAIN typically operates on small
1796/// derivation trees so the per-fact classifier call is bounded.
1797async fn collect_neural_calls_for_row(
1798    rule: &FixpointRulePlan,
1799    clause_index: usize,
1800    fact_row: &uni_locy::FactRow,
1801    classifier_registry: &Arc<ClassifierRegistry>,
1802    classifier_cache: Option<&Arc<uni_locy::ModelInvocationCache>>,
1803    provenance_store: Option<&Arc<uni_locy::NeuralProvenanceStore>>,
1804) -> Vec<uni_locy::NeuralProvenance> {
1805    let Some(clause) = rule.clauses.get(clause_index) else {
1806        return Vec::new();
1807    };
1808    if clause.model_invocations.is_empty() {
1809        return Vec::new();
1810    }
1811    let mut out = Vec::with_capacity(clause.model_invocations.len());
1812    for invocation in &clause.model_invocations {
1813        // Build ClassifyInput from the REWRITTEN feature expressions
1814        // (referencing the synthetic `__feat_*` hidden columns that the
1815        // planner lifts model-call args into). This matches the writer's
1816        // path in `apply_model_invocations`, which iterates
1817        // `invocation.feature_exprs` to compute the same `input_hash`
1818        // that gets stored. Using the pre-rewrite `original_feature_exprs`
1819        // here would compute a different hash for YIELD-position
1820        // invocations (where the pre-rewrite expr references properties
1821        // not materialised into fact_row), causing the store lookup
1822        // below to miss and `neural_calls` to come back empty.
1823        let mut features = std::collections::HashMap::new();
1824        for (binding_name, feat_expr) in invocation
1825            .feature_names
1826            .iter()
1827            .zip(invocation.feature_exprs.iter())
1828        {
1829            features.insert(
1830                binding_name.clone(),
1831                eval_feature_expr_against_fact_row(feat_expr, fact_row),
1832            );
1833        }
1834        let input = uni_locy::ClassifyInput { features };
1835        let input_hash = input.stable_hash();
1836
1837        // Store-first read path. `apply_model_invocations` writes a
1838        // NeuralProvenanceRecord per (model, input_hash) into the
1839        // side-channel store during fixpoint. If we find a record
1840        // there, surface it directly — this is the only path that
1841        // populates calibrated_probability + confidence_band for
1842        // Python-registered classifiers.
1843        if let Some(store) = provenance_store
1844            && let Some(record) = store.get(&invocation.model_name, input_hash)
1845        {
1846            out.push(uni_locy::NeuralProvenance {
1847                model_name: invocation.model_name.clone(),
1848                raw_probability: record.raw_probability,
1849                calibrated_probability: record.calibrated_probability,
1850                confidence_band: record.confidence_band,
1851            });
1852            continue;
1853        }
1854
1855        // Fallback: re-invoke the classifier. Only reached when the
1856        // store wasn't populated for this (model, input) — e.g. older
1857        // sessions where the store wasn't threaded, or when EXPLAIN
1858        // runs against a row that fixpoint never touched.
1859        let Some(classifier) = classifier_registry.get(&invocation.model_name) else {
1860            continue;
1861        };
1862        let raw = if let Some(v) =
1863            classifier_cache.and_then(|c| c.get(&invocation.model_name, input_hash))
1864        {
1865            v
1866        } else {
1867            match classifier.classify(std::slice::from_ref(&input)).await {
1868                Ok(probs) => {
1869                    let v = probs.first().copied().unwrap_or(0.0);
1870                    if let Some(c) = classifier_cache {
1871                        c.insert(&invocation.model_name, input_hash, v);
1872                    }
1873                    v
1874                }
1875                Err(_) => continue,
1876            }
1877        };
1878        let calibrator = classifier.get_calibrator();
1879        let calibrated_probability = calibrator.as_ref().map(|_| raw);
1880        let confidence_band = calibrator.as_ref().and_then(|c| c.confidence_band(raw));
1881        out.push(uni_locy::NeuralProvenance {
1882            model_name: invocation.model_name.clone(),
1883            raw_probability: raw,
1884            calibrated_probability,
1885            confidence_band,
1886        });
1887    }
1888    out
1889}
1890
1891/// Phase C B1-B3 follow-up: evaluate a model's pre-rewrite feature
1892/// expression against a fact_row, producing a `FeatureValue` for
1893/// classifier input reconstruction. Mirrors the compile-time
1894/// acceptance set in `validate_features`:
1895/// - `Variable(name)` → fact_row[name], coerced.
1896/// - `Property(Variable(v), prop)` → fact_row["v.prop"] (the
1897///   materialized property column).
1898fn eval_feature_expr_against_fact_row(
1899    expr: &uni_cypher::ast::Expr,
1900    fact_row: &uni_locy::FactRow,
1901) -> uni_locy::FeatureValue {
1902    use uni_cypher::ast::Expr;
1903    use uni_locy::FeatureValue;
1904    let value_to_feature = |v: Option<&uni_common::Value>| -> FeatureValue {
1905        match v {
1906            Some(uni_common::Value::Float(f)) => FeatureValue::Float(*f),
1907            Some(uni_common::Value::Int(i)) => FeatureValue::Int(*i),
1908            Some(uni_common::Value::Bool(b)) => FeatureValue::Bool(*b),
1909            Some(uni_common::Value::String(s)) => FeatureValue::String(s.clone()),
1910            Some(uni_common::Value::Node(n)) => {
1911                // Encode node by vid for `scorer(s)` style.
1912                FeatureValue::Int(n.vid.as_u64() as i64)
1913            }
1914            _ => FeatureValue::Null,
1915        }
1916    };
1917    // Phase D D1: resolve a sub-expression to its raw `uni_common::Value`
1918    // for the `similar_to` UDF input. Falls back to the node's property
1919    // when the materialized column key isn't directly present.
1920    let resolve_value = |sub: &Expr| -> uni_common::Value {
1921        match sub {
1922            Expr::Variable(name) => fact_row
1923                .get(name)
1924                .cloned()
1925                .unwrap_or(uni_common::Value::Null),
1926            Expr::Property(boxed, prop) if matches!(boxed.as_ref(), Expr::Variable(_)) => {
1927                let Expr::Variable(v) = boxed.as_ref() else {
1928                    unreachable!()
1929                };
1930                let key = format!("{}.{}", v, prop);
1931                if let Some(val) = fact_row.get(&key) {
1932                    return val.clone();
1933                }
1934                if let Some(uni_common::Value::Node(n)) = fact_row.get(v) {
1935                    return n
1936                        .properties
1937                        .get(prop)
1938                        .cloned()
1939                        .unwrap_or(uni_common::Value::Null);
1940                }
1941                uni_common::Value::Null
1942            }
1943            Expr::Literal(lit) => lit.to_value(),
1944            Expr::List(items) => {
1945                let mut out = Vec::with_capacity(items.len());
1946                for it in items {
1947                    out.push(match it {
1948                        Expr::Literal(lit) => lit.to_value(),
1949                        _ => uni_common::Value::Null,
1950                    });
1951                }
1952                uni_common::Value::List(out)
1953            }
1954            _ => uni_common::Value::Null,
1955        }
1956    };
1957
1958    match expr {
1959        Expr::Variable(name) => value_to_feature(fact_row.get(name)),
1960        Expr::Property(boxed, prop) => {
1961            if let Expr::Variable(v) = boxed.as_ref() {
1962                // Try the materialized property column first.
1963                let key = format!("{}.{}", v, prop);
1964                if let Some(val) = fact_row.get(&key) {
1965                    return value_to_feature(Some(val));
1966                }
1967                // Fallback: try the synthetic hidden column that the
1968                // planner injects for property-access feature args
1969                // (`__feat_<var>_<prop>`). The writer side
1970                // (`apply_model_invocations`) already uses this
1971                // fallback (see `resolve_src` in the same file), so
1972                // mirroring it here keeps reader/writer input_hash
1973                // symmetric — without it, the YIELD-position case
1974                // (where `fact_row[v]` is a vid Int, not a Node)
1975                // returns Null and the store-lookup misses.
1976                let hidden_key = format!("__feat_{}_{}", v, prop);
1977                if let Some(val) = fact_row.get(&hidden_key) {
1978                    return value_to_feature(Some(val));
1979                }
1980                // Final fallback: read property directly from the
1981                // node value (works when fact_row carries the Node
1982                // rather than a vid Int).
1983                if let Some(uni_common::Value::Node(n)) = fact_row.get(v) {
1984                    return value_to_feature(n.properties.get(prop));
1985                }
1986            }
1987            FeatureValue::Null
1988        }
1989        Expr::FunctionCall { name, args, .. } if name == "similar_to" && args.len() == 2 => {
1990            let lv = resolve_value(&args[0]);
1991            let rv = resolve_value(&args[1]);
1992            match crate::query::similar_to::eval_similar_to_pure(&lv, &rv) {
1993                Ok(uni_common::Value::Float(f)) => FeatureValue::Float(f),
1994                _ => FeatureValue::Null,
1995            }
1996        }
1997        // `semantic_match` requires the Xervo embedder at this scope, which
1998        // is not threaded into the EXPLAIN re-evaluation path. Surface as
1999        // Null so neural-provenance still renders for the rest of the row.
2000        //
2001        // Phase D D1 graph-structural FunctionCalls (`degree_centrality`,
2002        // `pagerank_score`, `closeness_centrality`, `avg_neighbor`,
2003        // `max_neighbor`, `sum_neighbor`) require the `GraphAlgoHandle`
2004        // (algorithm registry + storage + PropertyManager) and an async
2005        // re-precompute pass — none of which are reachable from this
2006        // synchronous fact-row evaluator. Mode B re-evaluation surfaces
2007        // them as Null; the authoritative hot-path values are recorded
2008        // in `NeuralProvenanceStore` per fact (the EXPLAIN renderer
2009        // consults the store first when configured, falling back to
2010        // Mode B re-evaluation only as a backup).
2011        Expr::FunctionCall { name, .. }
2012            if matches!(
2013                name.as_str(),
2014                "degree_centrality"
2015                    | "pagerank_score"
2016                    | "closeness_centrality"
2017                    | "betweenness_centrality"
2018                    | "eigenvector_centrality"
2019                    | "harmonic_centrality"
2020                    | "katz_centrality"
2021                    | "avg_neighbor"
2022                    | "max_neighbor"
2023                    | "sum_neighbor"
2024            ) =>
2025        {
2026            FeatureValue::Null
2027        }
2028        _ => FeatureValue::Null,
2029    }
2030}
2031
2032fn collect_is_ref_inputs(
2033    rule: &FixpointRulePlan,
2034    clause_index: usize,
2035    delta_batch: &RecordBatch,
2036    row_idx: usize,
2037    registry: &Arc<DerivedScanRegistry>,
2038) -> Vec<ProofTerm> {
2039    let clause = match rule.clauses.get(clause_index) {
2040        Some(c) => c,
2041        None => return vec![],
2042    };
2043
2044    let mut inputs = Vec::new();
2045    let delta_schema = delta_batch.schema();
2046
2047    for binding in &clause.is_ref_bindings {
2048        if binding.negated {
2049            continue;
2050        }
2051        if binding.provenance_join_cols.is_empty() {
2052            continue;
2053        }
2054
2055        // Extract body-side values from the delta row for each provenance join col.
2056        let body_values: Vec<(String, ScalarKey)> = binding
2057            .provenance_join_cols
2058            .iter()
2059            .filter_map(|(body_col, _derived_col)| {
2060                let col_idx = delta_schema
2061                    .fields()
2062                    .iter()
2063                    .position(|f| f.name() == body_col)?;
2064                let key = extract_scalar_key(delta_batch, &[col_idx], row_idx);
2065                Some((body_col.clone(), key.into_iter().next()?))
2066            })
2067            .collect();
2068
2069        if body_values.len() != binding.provenance_join_cols.len() {
2070            continue;
2071        }
2072
2073        // Read current data from the registry entry for this IS-ref's rule.
2074        let entry = match registry.get(binding.derived_scan_index) {
2075            Some(e) => e,
2076            None => continue,
2077        };
2078        let source_batches = entry.data.read();
2079        let source_schema = &entry.schema;
2080
2081        // Find matching source rows and hash them.
2082        for src_batch in source_batches.iter() {
2083            let all_src_indices: Vec<usize> = (0..src_batch.num_columns()).collect();
2084            for src_row in 0..src_batch.num_rows() {
2085                let matches = binding.provenance_join_cols.iter().enumerate().all(
2086                    |(i, (_body_col, derived_col))| {
2087                        let src_col_idx = source_schema
2088                            .fields()
2089                            .iter()
2090                            .position(|f| f.name() == derived_col);
2091                        match src_col_idx {
2092                            Some(idx) => {
2093                                let src_key = extract_scalar_key(src_batch, &[idx], src_row);
2094                                src_key.first() == Some(&body_values[i].1)
2095                            }
2096                            None => false,
2097                        }
2098                    },
2099                );
2100                if matches {
2101                    let fact_hash = format!(
2102                        "{:?}",
2103                        extract_scalar_key(src_batch, &all_src_indices, src_row)
2104                    )
2105                    .into_bytes();
2106                    inputs.push(ProofTerm {
2107                        source_rule: binding.rule_name.clone(),
2108                        base_fact_id: fact_hash,
2109                    });
2110                }
2111            }
2112        }
2113    }
2114
2115    inputs
2116}
2117
2118/// Phase D D-C0: per-body-row variant of [`collect_is_ref_inputs`] used
2119/// to pre-populate `FoldExec`'s `body_support_map` for TopKProofs MNOR.
2120///
2121/// At FOLD time, the rule's own facts haven't been recorded in the
2122/// `ProvenanceStore` yet (`record_provenance` runs after fact
2123/// materialization, and is keyed by post-YIELD hashes anyway), so the
2124/// support set for each pre-fold body row must be reconstructed
2125/// directly from the rule's IS-ref bindings + the source rules'
2126/// registry data.
2127///
2128/// We don't know which clause produced each body row at this point —
2129/// the iteration-local `clause_candidates` are gone — so we iterate
2130/// **every** clause's `is_ref_bindings`. The `provenance_join_cols`
2131/// schema check inside `collect_is_ref_inputs` already skips bindings
2132/// whose body columns aren't in the row's schema, so cross-clause
2133/// contamination is bounded (a binding only matches if its body cols
2134/// are present and the values join). For the single-clause TopKProofs
2135/// scenarios in TCK this is exact; for multi-clause TopKProofs rules
2136/// it is a conservative over-approximation that may inflate base-RV
2137/// counts (treated as the same RV under interning) — acceptable
2138/// because the DNF math collapses duplicates by inclusion-exclusion.
2139fn collect_is_ref_inputs_for_body_row(
2140    rule: &FixpointRulePlan,
2141    delta_batch: &RecordBatch,
2142    row_idx: usize,
2143    registry: &Arc<DerivedScanRegistry>,
2144) -> Vec<ProofTerm> {
2145    let mut combined: Vec<ProofTerm> = Vec::new();
2146    for clause_index in 0..rule.clauses.len() {
2147        let part = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
2148        combined.extend(part);
2149    }
2150    combined
2151}
2152
2153// ---------------------------------------------------------------------------
2154// Shared-lineage detection
2155// ---------------------------------------------------------------------------
2156
/// Per-row data collected during shared-lineage detection.
///
/// `detect_shared_lineage` builds one `SharedGroupRow` for each pre-fold
/// fact in a KEY group that it classified as sharing proofs. Classification
/// is two-tier:
/// 1. **Precise**: if the `ProvenanceStore` has a populated `support` for
///    facts in the group, each row's lineage is computed (Cui & Widom 2000).
///    The group is shared when two rows' lineages overlap.
/// 2. **Structural fallback**: otherwise, the group is shared when any fact
///    in the multi-row group was derived by a clause with non-negated IS-ref
///    bindings.
#[expect(
    dead_code,
    reason = "Fields accessed via SharedLineageInfo in detect_shared_lineage"
)]
pub(crate) struct SharedGroupRow {
    /// Full-row byte key of the fact, as produced by `fact_hash_key`.
    pub fact_hash: Vec<u8>,
    /// Base-fact hashes reached from this fact via `compute_lineage`.
    pub lineage: HashSet<Vec<u8>>,
}
2179
/// Information about groups with shared proofs, returned by `detect_shared_lineage`.
///
/// Only KEY groups classified as sharing proofs appear in `shared_groups`.
/// `apply_exact_wmc` uses the map's keys to select which groups get an
/// exact (BDD-based) probability.
pub(crate) struct SharedLineageInfo {
    /// KEY group → rows with their base fact sets.
    pub shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>>,
}
2185
2186/// Build a byte key that uniquely identifies a row across all columns.
2187pub(crate) fn fact_hash_key(batch: &RecordBatch, all_indices: &[usize], row_idx: usize) -> Vec<u8> {
2188    format!("{:?}", extract_scalar_key(batch, all_indices, row_idx)).into_bytes()
2189}
2190
2191/// Emits at most one `SharedProbabilisticDependency` warning per rule.
2192/// Returns `Some(SharedLineageInfo)` if any group has shared proofs.
2193fn detect_shared_lineage(
2194    rule: &FixpointRulePlan,
2195    pre_fold_facts: &[RecordBatch],
2196    tracker: &Arc<ProvenanceStore>,
2197    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2198    semiring_kind: SemiringKind,
2199) -> Option<SharedLineageInfo> {
2200    use uni_locy::{RuntimeWarning, RuntimeWarningCode};
2201
2202    // Only check rules with probability-domain fold bindings. M3:
2203    // dispatches via the `LocyAggregate` trait so user-authored
2204    // probability aggregates participate automatically — selected by the
2205    // trait's `is_probability_aggregate()` flag, not by hardcoded name.
2206    let has_prob_fold = rule
2207        .fold_bindings
2208        .iter()
2209        .any(|fb| fb.aggregate.is_probability_aggregate());
2210    if !has_prob_fold {
2211        return None;
2212    }
2213
2214    // Group facts by KEY columns.
2215    let key_indices = &rule.key_column_indices;
2216    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2217
2218    let mut groups: HashMap<Vec<ScalarKey>, Vec<Vec<u8>>> = HashMap::new();
2219    for batch in pre_fold_facts {
2220        for row_idx in 0..batch.num_rows() {
2221            let key = extract_scalar_key(batch, key_indices, row_idx);
2222            let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
2223            groups.entry(key).or_default().push(fact_hash);
2224        }
2225    }
2226
2227    let mut shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>> = HashMap::new();
2228    let mut any_shared = false;
2229
2230    // Check each group with ≥2 rows.
2231    for (key, fact_hashes) in &groups {
2232        if fact_hashes.len() < 2 {
2233            continue;
2234        }
2235
2236        // Tier 1: precise base-fact overlap detection via tracker inputs.
2237        let mut has_inputs = false;
2238        let mut per_row_bases: Vec<HashSet<Vec<u8>>> = Vec::new();
2239        for fh in fact_hashes {
2240            let bases = compute_lineage(fh, tracker, &mut HashSet::new());
2241            if let Some(entry) = tracker.lookup(fh)
2242                && !entry.support.is_empty()
2243            {
2244                has_inputs = true;
2245            }
2246            per_row_bases.push(bases);
2247        }
2248
2249        let shared_found = if has_inputs {
2250            // At least some facts have tracked inputs — do precise comparison.
2251            let mut found = false;
2252            'outer: for i in 0..per_row_bases.len() {
2253                for j in (i + 1)..per_row_bases.len() {
2254                    if !per_row_bases[i].is_disjoint(&per_row_bases[j]) {
2255                        found = true;
2256                        break 'outer;
2257                    }
2258                }
2259            }
2260            found
2261        } else {
2262            // Tier 2: structural fallback — check if any fact in the group was
2263            // derived by a clause with IS-ref bindings (recursive derivation).
2264            fact_hashes.iter().any(|fh| {
2265                tracker.lookup(fh).is_some_and(|entry| {
2266                    rule.clauses
2267                        .get(entry.clause_index)
2268                        .is_some_and(|clause| clause.is_ref_bindings.iter().any(|b| !b.negated))
2269                })
2270            })
2271        };
2272
2273        if shared_found {
2274            any_shared = true;
2275            // Collect the group rows with their base facts for BDD use.
2276            let rows: Vec<SharedGroupRow> = fact_hashes
2277                .iter()
2278                .zip(per_row_bases)
2279                .map(|(fh, bases)| SharedGroupRow {
2280                    fact_hash: fh.clone(),
2281                    lineage: bases,
2282                })
2283                .collect();
2284            shared_groups.insert(key.clone(), rows);
2285        }
2286    }
2287
2288    // Phase 5: Cross-group correlation warning.
2289    // Check if any IS-ref input fact appears in multiple KEY groups.
2290    // This is independent of within-group sharing: even rules whose KEY groups
2291    // each have only one post-fold row can exhibit cross-group correlation when
2292    // different groups consume the same IS-ref base fact.
2293    {
2294        let mut input_to_groups: HashMap<Vec<u8>, HashSet<Vec<ScalarKey>>> = HashMap::new();
2295        for (key, fact_hashes) in &groups {
2296            for fh in fact_hashes {
2297                if let Some(entry) = tracker.lookup(fh) {
2298                    for input in &entry.support {
2299                        input_to_groups
2300                            .entry(input.base_fact_id.clone())
2301                            .or_default()
2302                            .insert(key.clone());
2303                    }
2304                }
2305            }
2306        }
2307        let has_cross_group = input_to_groups.values().any(|g| g.len() > 1);
2308        if has_cross_group && let Ok(mut warnings) = warnings_slot.write() {
2309            let already_warned = warnings.iter().any(|w| {
2310                w.code == RuntimeWarningCode::CrossGroupCorrelationNotExact
2311                    && w.rule_name == rule.name
2312            });
2313            if !already_warned {
2314                // Phase D F3: pick one canonical example of a shared
2315                // input fact and the KEY groups it bridges, so users
2316                // can correlate the warning with EXPLAIN output.
2317                let example =
2318                    input_to_groups
2319                        .iter()
2320                        .find(|(_, g)| g.len() > 1)
2321                        .map(|(input, groups)| {
2322                            let short = input
2323                                .iter()
2324                                .take(8)
2325                                .map(|b| format!("{:02x}", b))
2326                                .collect::<String>();
2327                            let mut group_strs: Vec<String> =
2328                                groups.iter().map(|k| format!("{:?}", k)).collect();
2329                            group_strs.sort();
2330                            format!(
2331                                "input {} shared by groups [{}]",
2332                                short,
2333                                group_strs.join(", ")
2334                            )
2335                        });
2336                // Phase D F3 case 1 BDD-time deepening: count distinct
2337                // base facts (= BDD variables) that cross groups, so the
2338                // warning carries structured metadata mirroring
2339                // `BddLimitExceeded`. Users can correlate
2340                // `variable_count` with EXPLAIN's BDD output.
2341                let shared_variable_count =
2342                    input_to_groups.values().filter(|g| g.len() > 1).count();
2343                warnings.push(RuntimeWarning {
2344                    code: RuntimeWarningCode::CrossGroupCorrelationNotExact,
2345                    message: format!(
2346                        "Rule '{}': {} IS-ref base fact(s) are shared across different \
2347                         KEY groups. BDD corrects per-group probabilities but cannot \
2348                         account for cross-group correlations.",
2349                        rule.name, shared_variable_count
2350                    ),
2351                    rule_name: rule.name.clone(),
2352                    variable_count: Some(shared_variable_count),
2353                    key_group: example,
2354                });
2355            }
2356        }
2357    }
2358
2359    if any_shared {
2360        // Phase D D-C0b: under `SemiringKind::TopKProofs`, the FOLD-time
2361        // DNF inclusion-exclusion math (shipped in D-C0) auto-corrects
2362        // for within-group base-fact sharing — the "Results may
2363        // overestimate" premise of `SharedProbabilisticDependency`
2364        // is no longer true. Suppress the warning under TopK; users
2365        // who chose TopKProofs explicitly opted into the
2366        // correctness-preserving path. Cross-group correlation
2367        // (`CrossGroupCorrelationNotExact`) still fires above because
2368        // D-C0 doesn't span KEY-group boundaries.
2369        let suppress_under_topk = matches!(semiring_kind, SemiringKind::TopKProofs { .. });
2370        if !suppress_under_topk && let Ok(mut warnings) = warnings_slot.write() {
2371            let already_warned = warnings.iter().any(|w| {
2372                w.code == RuntimeWarningCode::SharedProbabilisticDependency
2373                    && w.rule_name == rule.name
2374            });
2375            if !already_warned {
2376                warnings.push(RuntimeWarning {
2377                    code: RuntimeWarningCode::SharedProbabilisticDependency,
2378                    message: format!(
2379                        "Rule '{}' aggregates with MNOR/MPROD but some proof paths \
2380                         share intermediate facts, violating the independence assumption. \
2381                         Results may overestimate probability.",
2382                        rule.name
2383                    ),
2384                    rule_name: rule.name.clone(),
2385                    variable_count: None,
2386                    key_group: None,
2387                });
2388            }
2389        }
2390        Some(SharedLineageInfo { shared_groups })
2391    } else {
2392        None
2393    }
2394}
2395
2396/// Record provenance and detect shared proofs for non-recursive strata.
2397///
2398/// Non-recursive rules are evaluated in a single pass (no fixpoint loop), so
2399/// the regular `record_provenance` + `detect_shared_lineage` path is never hit.
2400/// This function bridges that gap by recording a `ProvenanceAnnotation` for every
2401/// fact produced by each clause and then running the same two-tier detection
2402/// logic used by the recursive path.
2403#[allow(
2404    clippy::too_many_arguments,
2405    reason = "context bundle would be over-engineering for one call site"
2406)]
2407pub(crate) async fn record_and_detect_lineage_nonrecursive(
2408    rule: &FixpointRulePlan,
2409    tagged_clause_facts: &[(usize, Vec<RecordBatch>)],
2410    tracker: &Arc<ProvenanceStore>,
2411    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2412    registry: &Arc<DerivedScanRegistry>,
2413    top_k_proofs: usize,
2414    classifiers: ClassifierRefs<'_>,
2415    semiring_kind: SemiringKind,
2416) -> Option<SharedLineageInfo> {
2417    let classifier_registry = classifiers.registry;
2418    let classifier_cache = classifiers.cache;
2419    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2420
2421    // Pre-compute base fact probabilities for top-k mode.
2422    let base_probs = if top_k_proofs > 0 {
2423        tracker.base_fact_probs()
2424    } else {
2425        HashMap::new()
2426    };
2427
2428    let mut topk_acc = TopKProofAccumulator::new();
2429
2430    // Record provenance for each clause's facts.
2431    for (clause_index, batches) in tagged_clause_facts {
2432        for batch in batches {
2433            for row_idx in 0..batch.num_rows() {
2434                let row_hash = fact_hash_key(batch, &all_indices, row_idx);
2435                let fact_row = batch_row_to_value_map(batch, row_idx);
2436
2437                let support = collect_is_ref_inputs(rule, *clause_index, batch, row_idx, registry);
2438
2439                let proof_probability = if top_k_proofs > 0 {
2440                    compute_proof_probability(&support, &base_probs)
2441                } else {
2442                    None
2443                };
2444
2445                let entry = ProvenanceAnnotation {
2446                    rule_name: rule.name.clone(),
2447                    clause_index: *clause_index,
2448                    support,
2449                    along_values: {
2450                        let along_names: Vec<String> = rule
2451                            .clauses
2452                            .get(*clause_index)
2453                            .map(|c| c.along_bindings.clone())
2454                            .unwrap_or_default();
2455                        along_names
2456                            .iter()
2457                            .filter_map(|name| {
2458                                fact_row.get(name).map(|v| (name.clone(), v.clone()))
2459                            })
2460                            .collect()
2461                    },
2462                    iteration: 0,
2463                    fact_row: fact_row.clone(),
2464                    proof_probability,
2465                    neural_calls: collect_neural_calls_for_row(
2466                        rule,
2467                        *clause_index,
2468                        &fact_row,
2469                        classifier_registry,
2470                        classifier_cache,
2471                        classifiers.provenance_store,
2472                    )
2473                    .await,
2474                };
2475                if top_k_proofs > 0 {
2476                    topk_acc.accumulate(&entry, &row_hash);
2477                    tracker.record_top_k(row_hash, entry, top_k_proofs);
2478                } else {
2479                    tracker.record(row_hash, entry);
2480                }
2481            }
2482        }
2483    }
2484
2485    topk_acc.emit_warning_if_any(rule, top_k_proofs, warnings_slot);
2486
2487    // Flatten all clause facts and run detection.
2488    let all_facts: Vec<RecordBatch> = tagged_clause_facts
2489        .iter()
2490        .flat_map(|(_, batches)| batches.iter().cloned())
2491        .collect();
2492    detect_shared_lineage(rule, &all_facts, tracker, warnings_slot, semiring_kind)
2493}
2494
2495/// Apply exact weighted model counting (WMC) for shared-lineage groups.
2496///
2497/// Replaces multiple rows in groups with shared lineage with a single
2498/// representative row whose PROB column carries the BDD-computed exact
2499/// probability (Sang et al. 2005). For groups that exceed
2500/// `max_bdd_variables`, rows are left unchanged and a `BddLimitExceeded`
2501/// warning is emitted.
2502pub(crate) fn apply_exact_wmc(
2503    pre_fold_facts: Vec<RecordBatch>,
2504    rule: &FixpointRulePlan,
2505    shared_info: &SharedLineageInfo,
2506    tracker: &Arc<ProvenanceStore>,
2507    max_bdd_variables: usize,
2508    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2509    approximate_slot: &Arc<StdRwLock<HashMap<String, Vec<String>>>>,
2510) -> DFResult<Vec<RecordBatch>> {
2511    use crate::query::df_graph::locy_bdd::{SemiringOp, weighted_model_count};
2512    use uni_locy::{RuntimeWarning, RuntimeWarningCode};
2513
2514    // Find the probability-domain fold binding to know which column
2515    // to overwrite. M3: dispatch through the `LocyAggregate` trait so
2516    // user-authored probability aggregates participate.
2517    let prob_fold = rule
2518        .fold_bindings
2519        .iter()
2520        .find(|fb| fb.aggregate.is_probability_aggregate());
2521    let prob_fold = match prob_fold {
2522        Some(f) => f,
2523        None => return Ok(pre_fold_facts),
2524    };
2525    let semiring_op = if prob_fold.aggregate.is_noisy_or() {
2526        SemiringOp::Disjunction
2527    } else {
2528        SemiringOp::Conjunction
2529    };
2530    let prob_col_idx = prob_fold.input_col_index;
2531    let prob_col_name = rule.yield_schema.field(prob_col_idx).name().clone();
2532
2533    let key_indices = &rule.key_column_indices;
2534    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2535
2536    // Build a set of shared group keys for quick lookup.
2537    let shared_keys: HashSet<Vec<ScalarKey>> = shared_info.shared_groups.keys().cloned().collect();
2538
2539    // Phase 1: Collect all rows for each shared KEY group across all batches.
2540    // Store (batch_index, row_index) pairs for each group.
2541    struct GroupAccum {
2542        base_facts: Vec<HashSet<Vec<u8>>>,
2543        base_probs: HashMap<Vec<u8>, f64>,
2544        /// First occurrence: (batch_index, row_index) — used as representative.
2545        representative: (usize, usize),
2546        row_locations: Vec<(usize, usize)>,
2547    }
2548
2549    let mut group_accums: HashMap<Vec<ScalarKey>, GroupAccum> = HashMap::new();
2550    let mut non_shared_rows: Vec<(usize, usize)> = Vec::new(); // (batch_idx, row_idx)
2551
2552    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
2553        for row_idx in 0..batch.num_rows() {
2554            let key = extract_scalar_key(batch, key_indices, row_idx);
2555            if shared_keys.contains(&key) {
2556                let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
2557                let bases = compute_lineage(&fact_hash, tracker, &mut HashSet::new());
2558
2559                let accum = group_accums.entry(key).or_insert_with(|| GroupAccum {
2560                    base_facts: Vec::new(),
2561                    base_probs: HashMap::new(),
2562                    representative: (batch_idx, row_idx),
2563                    row_locations: Vec::new(),
2564                });
2565
2566                // Look up probabilities for base facts.
2567                for bf in &bases {
2568                    if !accum.base_probs.contains_key(bf)
2569                        && let Some(entry) = tracker.lookup(bf)
2570                        && let Some(val) = entry.fact_row.get(&prob_col_name)
2571                        && let Some(p) = value_to_f64(val)
2572                    {
2573                        accum.base_probs.insert(bf.clone(), p);
2574                    }
2575                }
2576
2577                accum.base_facts.push(bases);
2578                accum.row_locations.push((batch_idx, row_idx));
2579            } else {
2580                non_shared_rows.push((batch_idx, row_idx));
2581            }
2582        }
2583    }
2584
2585    // Phase 2: Compute BDD for each shared group (across all batches).
2586    // Track which (batch_idx, row_idx) pairs to keep vs drop.
2587    let mut keep_rows: HashSet<(usize, usize)> = HashSet::new();
2588    // Map of (batch_idx, row_idx) → overridden PROB value (for BDD-succeeded groups).
2589    let mut overrides: HashMap<(usize, usize), f64> = HashMap::new();
2590
2591    // All non-shared rows are kept.
2592    for &loc in &non_shared_rows {
2593        keep_rows.insert(loc);
2594    }
2595
2596    for (key, accum) in &group_accums {
2597        let bdd_result = weighted_model_count(
2598            &accum.base_facts,
2599            &accum.base_probs,
2600            semiring_op,
2601            max_bdd_variables,
2602        );
2603
2604        if bdd_result.approximated {
2605            // Emit BddLimitExceeded warning (one per key group).
2606            if let Ok(mut warnings) = warnings_slot.write() {
2607                let key_desc = format!("{key:?}");
2608                let already_warned = warnings.iter().any(|w| {
2609                    w.code == RuntimeWarningCode::BddLimitExceeded
2610                        && w.rule_name == rule.name
2611                        && w.key_group.as_deref() == Some(&key_desc)
2612                });
2613                if !already_warned {
2614                    warnings.push(RuntimeWarning {
2615                        code: RuntimeWarningCode::BddLimitExceeded,
2616                        message: format!(
2617                            "Rule '{}': BDD variable limit exceeded ({} > {}). \
2618                             Falling back to independence-mode result.",
2619                            rule.name, bdd_result.variable_count, max_bdd_variables
2620                        ),
2621                        rule_name: rule.name.clone(),
2622                        variable_count: Some(bdd_result.variable_count),
2623                        key_group: Some(key_desc),
2624                    });
2625                }
2626            }
2627            if let Ok(mut approx) = approximate_slot.write() {
2628                let key_desc = format!("{key:?}");
2629                approx.entry(rule.name.clone()).or_default().push(key_desc);
2630            }
2631            // Keep all rows unchanged.
2632            for &loc in &accum.row_locations {
2633                keep_rows.insert(loc);
2634            }
2635        } else {
2636            // BDD succeeded: keep one representative row with overridden PROB.
2637            keep_rows.insert(accum.representative);
2638            overrides.insert(accum.representative, bdd_result.probability);
2639        }
2640    }
2641
2642    // Phase 3: Build output batches by filtering kept rows per batch.
2643    let mut result_batches = Vec::new();
2644    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
2645        let kept_indices: Vec<usize> = (0..batch.num_rows())
2646            .filter(|&row_idx| keep_rows.contains(&(batch_idx, row_idx)))
2647            .collect();
2648
2649        if kept_indices.is_empty() {
2650            continue;
2651        }
2652
2653        let indices = arrow::array::UInt32Array::from(
2654            kept_indices.iter().map(|&i| i as u32).collect::<Vec<_>>(),
2655        );
2656        let mut columns: Vec<arrow::array::ArrayRef> = batch
2657            .columns()
2658            .iter()
2659            .map(|col| arrow::compute::take(col, &indices, None))
2660            .collect::<Result<Vec<_>, _>>()
2661            .map_err(arrow_err)?;
2662
2663        // Check if any kept rows have PROB overrides.
2664        let override_map: Vec<Option<f64>> = kept_indices
2665            .iter()
2666            .map(|&row_idx| overrides.get(&(batch_idx, row_idx)).copied())
2667            .collect();
2668
2669        if override_map.iter().any(|o| o.is_some()) && prob_col_idx < columns.len() {
2670            // Rebuild the PROB column with overrides.
2671            let existing_prob = columns[prob_col_idx]
2672                .as_any()
2673                .downcast_ref::<arrow::array::Float64Array>();
2674            let new_values: Vec<f64> = override_map
2675                .iter()
2676                .enumerate()
2677                .map(|(i, ov)| match ov {
2678                    Some(p) => *p,
2679                    None => existing_prob.map(|arr| arr.value(i)).unwrap_or(0.0),
2680                })
2681                .collect();
2682            columns[prob_col_idx] = Arc::new(arrow::array::Float64Array::from(new_values));
2683        }
2684
2685        let result_batch = RecordBatch::try_new(batch.schema(), columns).map_err(arrow_err)?;
2686        result_batches.push(result_batch);
2687    }
2688
2689    Ok(result_batches)
2690}
2691
2692/// Extract an f64 from a `Value`, supporting Float and Int.
2693fn value_to_f64(val: &uni_common::Value) -> Option<f64> {
2694    match val {
2695        uni_common::Value::Float(f) => Some(*f),
2696        uni_common::Value::Int(i) => Some(*i as f64),
2697        _ => None,
2698    }
2699}
2700
2701/// Compute the lineage of a derived fact (Cui & Widom 2000).
2702///
2703/// Recursively traverses the provenance store to collect the set of base-level
2704/// fact hashes that contribute to this derivation. A base fact is one with no
2705/// IS-ref support (a graph-level fact). Intermediate facts are expanded
2706/// transitively through the store.
2707fn compute_lineage(
2708    fact_hash: &[u8],
2709    tracker: &Arc<ProvenanceStore>,
2710    visited: &mut HashSet<Vec<u8>>,
2711) -> HashSet<Vec<u8>> {
2712    if !visited.insert(fact_hash.to_vec()) {
2713        return HashSet::new(); // Cycle guard.
2714    }
2715
2716    match tracker.lookup(fact_hash) {
2717        Some(entry) if !entry.support.is_empty() => {
2718            let mut bases = HashSet::new();
2719            for input in &entry.support {
2720                let child_bases = compute_lineage(&input.base_fact_id, tracker, visited);
2721                bases.extend(child_bases);
2722            }
2723            bases
2724        }
2725        _ => {
2726            // Base fact (no tracker entry or no inputs).
2727            let mut set = HashSet::new();
2728            set.insert(fact_hash.to_vec());
2729            set
2730        }
2731    }
2732}
2733
2734/// Determine which clause produced a given row by checking each clause's candidates.
2735///
2736/// Returns the index of the first clause whose candidates contain a matching row.
2737/// Falls back to 0 if no match is found.
2738fn find_clause_for_row(
2739    delta_batch: &RecordBatch,
2740    row_idx: usize,
2741    all_indices: &[usize],
2742    clause_candidates: &[Vec<RecordBatch>],
2743) -> usize {
2744    let target_key = extract_scalar_key(delta_batch, all_indices, row_idx);
2745    for (clause_idx, batches) in clause_candidates.iter().enumerate() {
2746        for batch in batches {
2747            if batch.num_columns() != all_indices.len() {
2748                continue;
2749            }
2750            for r in 0..batch.num_rows() {
2751                if extract_scalar_key(batch, all_indices, r) == target_key {
2752                    return clause_idx;
2753                }
2754            }
2755        }
2756    }
2757    0
2758}
2759
2760/// Convert a single row from a `RecordBatch` at `row_idx` into a `HashMap<String, Value>`.
2761fn batch_row_to_value_map(
2762    batch: &RecordBatch,
2763    row_idx: usize,
2764) -> std::collections::HashMap<String, Value> {
2765    use uni_store::storage::arrow_convert::arrow_to_value;
2766
2767    let schema = batch.schema();
2768    schema
2769        .fields()
2770        .iter()
2771        .enumerate()
2772        .map(|(col_idx, field)| {
2773            let col = batch.column(col_idx);
2774            let val = arrow_to_value(col.as_ref(), row_idx, None);
2775            (field.name().clone(), val)
2776        })
2777        .collect()
2778}
2779
2780/// Filter `batches` to exclude rows where `left_col` VID appears in `neg_facts[right_col]`.
2781///
2782/// Implements anti-join semantics for negated IS-refs (`n IS NOT rule`): keeps only
2783/// rows whose subject VID is NOT present in the negated rule's fully-converged facts.
2784pub fn apply_anti_join(
2785    batches: Vec<RecordBatch>,
2786    neg_facts: &[RecordBatch],
2787    left_col: &str,
2788    right_col: &str,
2789) -> datafusion::error::Result<Vec<RecordBatch>> {
2790    use arrow::compute::filter_record_batch;
2791    use arrow_array::{Array as _, BooleanArray, UInt64Array};
2792
2793    // Collect right-side VIDs from the negated rule's derived facts.
2794    let mut banned: std::collections::HashSet<u64> = std::collections::HashSet::new();
2795    for batch in neg_facts {
2796        let Ok(idx) = batch.schema().index_of(right_col) else {
2797            continue;
2798        };
2799        let arr = batch.column(idx);
2800        let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2801            continue;
2802        };
2803        for i in 0..vids.len() {
2804            if !vids.is_null(i) {
2805                banned.insert(vids.value(i));
2806            }
2807        }
2808    }
2809
2810    if banned.is_empty() {
2811        return Ok(batches);
2812    }
2813
2814    // Filter body batches: keep rows where left_col NOT IN banned.
2815    let mut result = Vec::new();
2816    for batch in batches {
2817        let Ok(idx) = batch.schema().index_of(left_col) else {
2818            result.push(batch);
2819            continue;
2820        };
2821        let arr = batch.column(idx);
2822        let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2823            result.push(batch);
2824            continue;
2825        };
2826        let keep: Vec<bool> = (0..vids.len())
2827            .map(|i| vids.is_null(i) || !banned.contains(&vids.value(i)))
2828            .collect();
2829        let keep_arr = BooleanArray::from(keep);
2830        let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
2831        if filtered.num_rows() > 0 {
2832            result.push(filtered);
2833        }
2834    }
2835    Ok(result)
2836}
2837
2838// ─── Phase B Slice 3: neural-model invocation pass ───────────────────────
2839//
2840// `apply_model_invocations` runs every `ModelInvocation` lifted from a
2841// clause's YIELD items against the body's output batches. For each
2842// invocation it:
2843//
2844//   1. Resolves each `feature_expr` to a column in the batch — Slice 3
2845//      supports plain `Expr::Variable("name")` references; richer
2846//      expressions (property access, nested calls) are deferred.
2847//   2. Builds one `ClassifyInput` per row keyed by the model's input
2848//      binding names.
2849//   3. Issues a single batched `NeuralClassifier::classify` call.
2850//   4. Appends the resulting `Float64Array` as a new column matching
2851//      `invocation.output_column`.
2852//
2853// Errors:
2854//   * `UnknownNeuralModel`: the model name isn't in the registry.
2855//   * Mismatched feature-expr / column: returned as a DataFusion
2856//     Execution error.
2857
2858#[allow(clippy::too_many_arguments)]
2859pub(crate) async fn apply_model_invocations(
2860    batches: Vec<RecordBatch>,
2861    invocations: &[uni_locy::ModelInvocation],
2862    registry: &Arc<ClassifierRegistry>,
2863    cache: Option<&Arc<uni_locy::ModelInvocationCache>>,
2864    provenance_store: Option<&Arc<uni_locy::NeuralProvenanceStore>>,
2865    path_context_handles: &HashMap<
2866        String,
2867        crate::query::df_graph::locy_model_invoke::PathContextHandle,
2868    >,
2869    xervo_runtime: &crate::query::df_graph::locy_model_invoke::XervoRuntimeHandle,
2870    graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
2871) -> DFResult<Vec<RecordBatch>> {
2872    use uni_locy::ClassifyInput;
2873    if batches.is_empty() || invocations.is_empty() {
2874        return Ok(batches);
2875    }
2876    // Phase D D2: pre-embed all unique `semantic_match` query literals
2877    // once per call. Resolvers below lower each `semantic_match(prop,
2878    // 'text')` into a `SimilarTo { left: prop_col, right: Const(Vector) }`.
2879    let semantic_match_embeddings =
2880        pre_embed_semantic_match_queries(invocations, xervo_runtime).await?;
2881    // Phase D D1 graph-structural: pre-compute topology scores and
2882    // neighbor-property maps for every distinct (fn_name, args) tuple
2883    // appearing in any FEATURE FunctionCall. One pass per call; reused
2884    // across every row of every batch.
2885    let graph_feature_maps = precompute_graph_feature_maps(invocations, graph_algo).await?;
2886    let neighbor_feature_maps =
2887        precompute_neighbor_feature_maps(invocations, &batches, graph_algo).await?;
2888    let mut out_batches = Vec::with_capacity(batches.len());
2889    for batch in batches {
2890        let mut current = batch;
2891        for invocation in invocations {
2892            let classifier = registry.get(&invocation.model_name).ok_or_else(|| {
2893                datafusion::error::DataFusionError::Execution(format!(
2894                    "neural classifier '{}' not registered; \
2895                         add it to LocyConfig::classifier_registry",
2896                    invocation.model_name
2897                ))
2898            })?;
2899
2900            // Resolve each feature_expr to a per-row evaluator.
2901            // Supported shapes (validated at compile time by
2902            // `extract_model_invocations` / `validate_features`):
2903            //   * `Expr::Variable("name")` — direct column reference.
2904            //   * `Expr::Property(Variable(v), prop)` — looked up by the
2905            //     conventional `"v.prop"` column name materialized by
2906            //     the planner's `translate_property_access` pipeline.
2907            //   * `Expr::FunctionCall { name: "similar_to"|"semantic_match", ... }`
2908            //     — Phase D D1/D2 retrieval-backed feature; both args
2909            //     resolved to columns; UDF evaluated per row against
2910            //     the row's `uni_common::Value` payloads.
2911            let resolvers = build_feature_resolvers(
2912                &current,
2913                invocation,
2914                path_context_handles,
2915                &semantic_match_embeddings,
2916                &graph_feature_maps,
2917                &neighbor_feature_maps,
2918            )?;
2919
2920            // Build one ClassifyInput per row.
2921            let n_rows = current.num_rows();
2922            let mut inputs: Vec<ClassifyInput> = Vec::with_capacity(n_rows);
2923            let mut input_hashes: Vec<u64> = Vec::with_capacity(n_rows);
2924            for row_idx in 0..n_rows {
2925                let mut features = std::collections::HashMap::new();
2926                for resolver in &resolvers {
2927                    let value = resolver.eval_row(&current, row_idx)?;
2928                    features.insert(resolver.binding_name.clone(), value);
2929                }
2930                let input = ClassifyInput { features };
2931                input_hashes.push(input.stable_hash());
2932                inputs.push(input);
2933            }
2934
2935            // Slice 2: memoization. Split inputs into cache hits and
2936            // misses; only call `classify` on the misses, then weave
2937            // the cached values back in by original row index.
2938            let mut probs: Vec<f64> = vec![0.0; n_rows];
2939            let mut miss_inputs: Vec<ClassifyInput> = Vec::new();
2940            let mut miss_row_indices: Vec<usize> = Vec::new();
2941            if let Some(c) = cache {
2942                for (row_idx, h) in input_hashes.iter().enumerate() {
2943                    match c.get(&invocation.model_name, *h) {
2944                        Some(v) => probs[row_idx] = v,
2945                        None => {
2946                            miss_row_indices.push(row_idx);
2947                            miss_inputs.push(inputs[row_idx].clone());
2948                        }
2949                    }
2950                }
2951            } else {
2952                miss_row_indices = (0..n_rows).collect();
2953                miss_inputs = inputs.clone();
2954            }
2955
2956            // Phase C C-RawCalibratedSeparation: when a calibrator is
2957            // present, route through `raw_and_calibrated` so the
2958            // provenance store records both the base classifier's raw
2959            // output AND the post-calibrator value. Bare classifiers
2960            // (no calibrator) keep using `classify`. The downstream
2961            // `probs[row]` is always the *calibrated* value when
2962            // available — that's what the rule's PROB output column
2963            // and the memoization cache carry.
2964            let calibrator = classifier.get_calibrator();
2965            let (miss_raws, miss_calibrated) = if miss_inputs.is_empty() {
2966                (Vec::new(), Vec::new())
2967            } else if calibrator.is_some() {
2968                let pairs = classifier
2969                    .raw_and_calibrated(&miss_inputs)
2970                    .await
2971                    .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2972                if pairs.len() != miss_inputs.len() {
2973                    return Err(datafusion::error::DataFusionError::Execution(format!(
2974                        "classifier '{}' raw_and_calibrated returned {} outputs for {} inputs",
2975                        invocation.model_name,
2976                        pairs.len(),
2977                        miss_inputs.len()
2978                    )));
2979                }
2980                let raws: Vec<f64> = pairs.iter().map(|(r, _)| *r).collect();
2981                let cals: Vec<f64> = pairs.iter().map(|(r, c)| c.unwrap_or(*r)).collect();
2982                (raws, cals)
2983            } else {
2984                let r = classifier
2985                    .classify(&miss_inputs)
2986                    .await
2987                    .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2988                if r.len() != miss_inputs.len() {
2989                    return Err(datafusion::error::DataFusionError::Execution(format!(
2990                        "classifier '{}' returned {} outputs for {} inputs",
2991                        invocation.model_name,
2992                        r.len(),
2993                        miss_inputs.len()
2994                    )));
2995                }
2996                // No calibrator → raw == final.
2997                (r.clone(), r)
2998            };
2999            // The memoization cache stores the *final* (calibrated when
3000            // available) value, matching what downstream rules consume.
3001            // Track per-miss raws alongside so the provenance store
3002            // sees both. For cache hits, we don't have raws — the
3003            // provenance record for that row gets None for `raw` and
3004            // the cached value as `calibrated` (the only thing we
3005            // remembered). Future slice can extend the cache to carry
3006            // both; current behavior matches pre-fix correctness for
3007            // EXPLAIN of cache hits.
3008            let mut row_raw: Vec<Option<f64>> = vec![None; n_rows];
3009            for (i, &row_idx) in miss_row_indices.iter().enumerate() {
3010                probs[row_idx] = miss_calibrated[i];
3011                row_raw[row_idx] = Some(miss_raws[i]);
3012                if let Some(c) = cache {
3013                    c.insert(
3014                        &invocation.model_name,
3015                        input_hashes[row_idx],
3016                        miss_calibrated[i],
3017                    );
3018                }
3019            }
3020
3021            // Phase C B1-B3 follow-up: when a provenance store is
3022            // configured, record (raw, calibrated, confidence_band)
3023            // per row. With C-RawCalibratedSeparation (above),
3024            // `row_raw[i]` carries the *pre-calibrator* value when we
3025            // computed it on this call; `probs[i]` is the
3026            // post-calibrator value. `confidence_band` comes from the
3027            // active Calibrator's `confidence_band(p)`.
3028            if let Some(store) = provenance_store {
3029                for row_idx in 0..n_rows {
3030                    let calibrated_value = probs[row_idx];
3031                    let (raw_value, calibrated) = match (row_raw[row_idx], &calibrator) {
3032                        (Some(raw), Some(_)) => (raw, Some(calibrated_value)),
3033                        (Some(raw), None) => (raw, None),
3034                        // Cache hit: we only have the calibrated value.
3035                        // Report it as raw with `calibrated == raw` to
3036                        // preserve telemetry shape; document this in
3037                        // the field doc.
3038                        (None, _) => (
3039                            calibrated_value,
3040                            calibrator.as_ref().map(|_| calibrated_value),
3041                        ),
3042                    };
3043                    let band = calibrator
3044                        .as_ref()
3045                        .and_then(|c| c.confidence_band(calibrated_value));
3046                    store.record(
3047                        &invocation.model_name,
3048                        input_hashes[row_idx],
3049                        uni_locy::NeuralProvenanceRecord {
3050                            raw_probability: raw_value,
3051                            calibrated_probability: calibrated,
3052                            confidence_band: band,
3053                            // Phase 12 EXPLAIN follow-up: stash the
3054                            // per-binding `FeatureValue` map that fed
3055                            // the classifier so Mode B re-evaluation
3056                            // can surface graph-structural feature
3057                            // values without re-precomputing topology
3058                            // or neighbor maps (the hot-path data is
3059                            // authoritative).
3060                            feature_inputs: inputs[row_idx].features.clone(),
3061                        },
3062                    );
3063                }
3064            }
3065
3066            // Overwrite the placeholder column put there by the
3067            // compile-time YIELD rewrite. If the column doesn't yet
3068            // exist (defensive — shouldn't happen for well-formed
3069            // plans), fall back to appending.
3070            let out_col: Arc<dyn arrow_array::Array> =
3071                Arc::new(arrow_array::Float64Array::from(probs));
3072            let schema = current.schema();
3073            let target_idx = schema.index_of(&invocation.output_column).ok();
3074            let mut columns: Vec<Arc<dyn arrow_array::Array>> = current.columns().to_vec();
3075            let mut fields: Vec<Arc<arrow_schema::Field>> =
3076                schema.fields().iter().cloned().collect();
3077            match target_idx {
3078                Some(idx) => {
3079                    columns[idx] = out_col;
3080                    // Force the field's data type to Float64 in case
3081                    // the placeholder was inferred at a wider type.
3082                    fields[idx] = Arc::new(arrow_schema::Field::new(
3083                        &invocation.output_column,
3084                        arrow_schema::DataType::Float64,
3085                        true,
3086                    ));
3087                }
3088                None => {
3089                    columns.push(out_col);
3090                    fields.push(Arc::new(arrow_schema::Field::new(
3091                        &invocation.output_column,
3092                        arrow_schema::DataType::Float64,
3093                        true,
3094                    )));
3095                }
3096            }
3097            let new_schema = Arc::new(arrow_schema::Schema::new(fields));
3098            current = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
3099        }
3100        out_batches.push(current);
3101    }
3102    Ok(out_batches)
3103}
3104
/// Per-feature evaluator built once from a clause's `feature_exprs` and
/// reused for every row in the batch.
///
/// Besides plain column reads (`Direct`), it covers:
/// - the Phase D D1/D2 retrieval-backed UDFs (`similar_to`, and
///   `semantic_match` lowered to `SimilarTo`);
/// - path-context lookups;
/// - graph-algorithm scores;
/// - neighbor aggregates.
struct FeatureResolver {
    /// Key under which the evaluated value is inserted into the row's
    /// `ClassifyInput::features` map.
    binding_name: String,
    /// Strategy used to compute the value for each row.
    kind: FeatureResolverKind,
}
3111/// (`SimilarTo`, `SemanticMatch`).
3112struct FeatureResolver {
3113    binding_name: String,
3114    kind: FeatureResolverKind,
3115}
3116
3117enum FeatureResolverKind {
3118    Direct(usize),
3119    SimilarTo {
3120        left: FeatureValueSrc,
3121        right: FeatureValueSrc,
3122    },
3123    /// Phase D D3 runtime: pull `column` from the source rule's
3124    /// derived facts via a pre-built `vid → FeatureValue` lookup. The
3125    /// `subject_col` is the index of `<subject_var>._vid` in the body
3126    /// batch; the lookup runs once per row.
3127    PathContext {
3128        subject_col: usize,
3129        vid_to_value: Arc<HashMap<u64, uni_locy::FeatureValue>>,
3130    },
3131    /// Phase D D1 graph-structural: look up the subject's pre-computed
3132    /// topology score (degree/pagerank/closeness). `subject_col` indexes
3133    /// the row's `<var>._vid`; `vid_to_score` is the whole-graph
3134    /// procedure output built once per `apply_model_invocations` call.
3135    GraphAlgoScore {
3136        subject_col: usize,
3137        vid_to_score: Arc<HashMap<u64, f64>>,
3138    },
3139    /// Phase D D1 graph-structural: aggregate a numeric property over
3140    /// each subject's one-hop outgoing neighborhood along a named edge
3141    /// type. `vid_to_values` maps subject vid → the list of numeric
3142    /// neighbor property values collected at precompute time; the
3143    /// per-row resolver applies the configured `op` (avg/max/sum).
3144    NeighborAggregate {
3145        subject_col: usize,
3146        op: NeighborAgg,
3147        vid_to_values: Arc<HashMap<u64, Vec<f64>>>,
3148    },
3149}
3150
3151#[derive(Debug, Clone, Copy)]
3152enum NeighborAgg {
3153    Avg,
3154    Max,
3155    Sum,
3156}
3157
3158/// Direction for one-hop neighborhood traversal. Mirrors
3159/// `uni_store::storage::direction::Direction` but is independent so
3160/// the typecheck / planner layer doesn't depend on uni-store.
3161#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
3162enum NeighborDirection {
3163    Outgoing,
3164    Incoming,
3165    Both,
3166}
3167
3168impl NeighborDirection {
3169    fn store_directions(self) -> &'static [uni_store::storage::direction::Direction] {
3170        use uni_store::storage::direction::Direction;
3171        match self {
3172            NeighborDirection::Outgoing => &[Direction::Outgoing],
3173            NeighborDirection::Incoming => &[Direction::Incoming],
3174            NeighborDirection::Both => &[Direction::Outgoing, Direction::Incoming],
3175        }
3176    }
3177}
3178
3179impl NeighborAgg {
3180    fn from_fn_name(name: &str) -> Option<Self> {
3181        match name {
3182            "avg_neighbor" => Some(NeighborAgg::Avg),
3183            "max_neighbor" => Some(NeighborAgg::Max),
3184            "sum_neighbor" => Some(NeighborAgg::Sum),
3185            _ => None,
3186        }
3187    }
3188
3189    fn apply(self, values: &[f64]) -> Option<f64> {
3190        if values.is_empty() {
3191            return None;
3192        }
3193        match self {
3194            NeighborAgg::Avg => Some(values.iter().sum::<f64>() / values.len() as f64),
3195            NeighborAgg::Max => values.iter().copied().reduce(f64::max),
3196            NeighborAgg::Sum => Some(values.iter().sum()),
3197        }
3198    }
3199}
3200
3201/// One side of a `similar_to` feature: either a column index in the
3202/// per-row batch or a constant value lifted from a literal expression.
3203enum FeatureValueSrc {
3204    Col(usize),
3205    Const(uni_common::Value),
3206}
3207
3208impl FeatureValueSrc {
3209    fn resolve(&self, batch: &RecordBatch, row_idx: usize) -> uni_common::Value {
3210        match self {
3211            FeatureValueSrc::Col(idx) => extract_common_value(batch.column(*idx).as_ref(), row_idx),
3212            FeatureValueSrc::Const(v) => v.clone(),
3213        }
3214    }
3215}
3216
3217impl FeatureResolver {
3218    fn eval_row(&self, batch: &RecordBatch, row_idx: usize) -> DFResult<uni_locy::FeatureValue> {
3219        match &self.kind {
3220            FeatureResolverKind::Direct(idx) => {
3221                Ok(extract_feature_value(batch.column(*idx).as_ref(), row_idx))
3222            }
3223            FeatureResolverKind::SimilarTo { left, right } => {
3224                let lv = left.resolve(batch, row_idx);
3225                let rv = right.resolve(batch, row_idx);
3226                match crate::query::similar_to::eval_similar_to_pure(&lv, &rv) {
3227                    Ok(uni_common::Value::Float(f)) => Ok(uni_locy::FeatureValue::Float(f)),
3228                    Ok(_) => Ok(uni_locy::FeatureValue::Null),
3229                    Err(e) => Err(datafusion::error::DataFusionError::Execution(format!(
3230                        "similar_to UDF failed: {e}"
3231                    ))),
3232                }
3233            }
3234            FeatureResolverKind::PathContext {
3235                subject_col,
3236                vid_to_value,
3237            } => {
3238                let col = batch.column(*subject_col);
3239                if col.is_null(row_idx) {
3240                    return Ok(uni_locy::FeatureValue::Null);
3241                }
3242                if let Some(arr) = col.as_any().downcast_ref::<arrow_array::UInt64Array>() {
3243                    let vid = arr.value(row_idx);
3244                    Ok(vid_to_value
3245                        .get(&vid)
3246                        .cloned()
3247                        .unwrap_or(uni_locy::FeatureValue::Null))
3248                } else if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3249                    let vid = arr.value(row_idx) as u64;
3250                    Ok(vid_to_value
3251                        .get(&vid)
3252                        .cloned()
3253                        .unwrap_or(uni_locy::FeatureValue::Null))
3254                } else {
3255                    Ok(uni_locy::FeatureValue::Null)
3256                }
3257            }
3258            FeatureResolverKind::GraphAlgoScore {
3259                subject_col,
3260                vid_to_score,
3261            } => {
3262                let col = batch.column(*subject_col);
3263                if col.is_null(row_idx) {
3264                    return Ok(uni_locy::FeatureValue::Null);
3265                }
3266                let vid_opt: Option<u64> = if let Some(arr) =
3267                    col.as_any().downcast_ref::<arrow_array::UInt64Array>()
3268                {
3269                    Some(arr.value(row_idx))
3270                } else if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3271                    Some(arr.value(row_idx) as u64)
3272                } else {
3273                    // Fallback: subject column carries a Node-encoded
3274                    // `uni_common::Value` (LargeBinary via codec). Decode
3275                    // and pull the VID. This is the common case for
3276                    // bare-variable subjects where no `_vid` hidden
3277                    // column was materialized.
3278                    match extract_common_value(col.as_ref(), row_idx) {
3279                        uni_common::Value::Node(n) => Some(n.vid.as_u64()),
3280                        uni_common::Value::Int(i) => Some(i as u64),
3281                        _ => None,
3282                    }
3283                };
3284                Ok(vid_opt
3285                    .and_then(|v| vid_to_score.get(&v).copied())
3286                    .map(uni_locy::FeatureValue::Float)
3287                    .unwrap_or(uni_locy::FeatureValue::Null))
3288            }
3289            FeatureResolverKind::NeighborAggregate {
3290                subject_col,
3291                op,
3292                vid_to_values,
3293            } => {
3294                let vid_opt = extract_vid_from_column(batch.column(*subject_col).as_ref(), row_idx);
3295                Ok(vid_opt
3296                    .and_then(|v| vid_to_values.get(&v))
3297                    .and_then(|values| op.apply(values))
3298                    .map(uni_locy::FeatureValue::Float)
3299                    .unwrap_or(uni_locy::FeatureValue::Null))
3300            }
3301        }
3302    }
3303}
3304
3305/// Extract a node VID from a per-row batch column. Handles the three
3306/// common shapes: `_vid` UInt64 columns, Int64 columns, and Node-encoded
3307/// LargeBinary columns (the standard `uni_common::Value::Node` codec
3308/// representation for a bare variable column).
3309fn extract_vid_from_column(col: &dyn arrow_array::Array, row_idx: usize) -> Option<u64> {
3310    if col.is_null(row_idx) {
3311        return None;
3312    }
3313    if let Some(arr) = col.as_any().downcast_ref::<arrow_array::UInt64Array>() {
3314        return Some(arr.value(row_idx));
3315    }
3316    if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3317        return Some(arr.value(row_idx) as u64);
3318    }
3319    match extract_common_value(col, row_idx) {
3320        uni_common::Value::Node(n) => Some(n.vid.as_u64()),
3321        uni_common::Value::Int(i) => Some(i as u64),
3322        _ => None,
3323    }
3324}
3325
3326#[allow(clippy::too_many_arguments)]
3327fn build_feature_resolvers(
3328    batch: &RecordBatch,
3329    invocation: &uni_locy::ModelInvocation,
3330    path_context_handles: &HashMap<
3331        String,
3332        crate::query::df_graph::locy_model_invoke::PathContextHandle,
3333    >,
3334    semantic_match_embeddings: &HashMap<String, Vec<f32>>,
3335    graph_feature_maps: &HashMap<String, Arc<HashMap<u64, f64>>>,
3336    neighbor_feature_maps: &NeighborFeatureMaps,
3337) -> DFResult<Vec<FeatureResolver>> {
3338    use uni_cypher::ast::Expr;
3339    let schema = batch.schema();
3340    let lookup_col = |name_or_property: String| -> DFResult<usize> {
3341        schema.index_of(&name_or_property).map_err(|_| {
3342            datafusion::error::DataFusionError::Execution(format!(
3343                "feature column '{name_or_property}' not found in clause body output schema"
3344            ))
3345        })
3346    };
3347    // Resolve a feature sub-expression to a per-row value source. Variables
3348    // and property accesses map to batch columns; list/scalar literals
3349    // become inline constants — required so `similar_to(s.embedding, [1,0,0])`
3350    // works without a hidden column for the literal vector.
3351    let resolve_src = |expr: &Expr| -> DFResult<FeatureValueSrc> {
3352        match expr {
3353            Expr::Variable(name) => {
3354                let col = if schema.index_of(name).is_ok() {
3355                    name.clone()
3356                } else {
3357                    let vid_name = format!("{}._vid", name);
3358                    if schema.index_of(&vid_name).is_ok() {
3359                        vid_name
3360                    } else {
3361                        name.clone()
3362                    }
3363                };
3364                Ok(FeatureValueSrc::Col(lookup_col(col)?))
3365            }
3366            Expr::Property(boxed, prop) if matches!(boxed.as_ref(), Expr::Variable(_)) => {
3367                let Expr::Variable(v) = boxed.as_ref() else {
3368                    unreachable!()
3369                };
3370                let direct = format!("{}.{}", v, prop);
3371                let col = if schema.index_of(&direct).is_ok() {
3372                    direct
3373                } else {
3374                    format!("__feat_{}_{}", v, prop)
3375                };
3376                Ok(FeatureValueSrc::Col(lookup_col(col)?))
3377            }
3378            Expr::Literal(lit) => Ok(FeatureValueSrc::Const(lit.to_value())),
3379            Expr::List(items) => {
3380                let mut out = Vec::with_capacity(items.len());
3381                for it in items {
3382                    out.push(match it {
3383                        Expr::Literal(lit) => lit.to_value(),
3384                        _ => uni_common::Value::Null,
3385                    });
3386                }
3387                Ok(FeatureValueSrc::Const(uni_common::Value::List(out)))
3388            }
3389            other => Err(datafusion::error::DataFusionError::Execution(format!(
3390                "unsupported feature sub-expression: {other:?}"
3391            ))),
3392        }
3393    };
3394
3395    // Phase D D3 runtime: when the model declares a path-context
3396    // feature, build a `vid → FeatureValue` lookup once from the
3397    // source rule's converged facts and wrap it in an Arc so the
3398    // per-row resolver does a single hash lookup. The model's
3399    // `INPUT` bindings are unused under this form for MVP — the
3400    // resolver's binding name is the column name (matches how the
3401    // mock-classifier feature-driver pattern in TCK consumes it).
3402    if let Some(pc) = &invocation.path_context {
3403        let handle = path_context_handles.get(&pc.source_rule).ok_or_else(|| {
3404            datafusion::error::DataFusionError::Execution(format!(
3405                "model '{}' path_context references rule '{}' but no DerivedScanHandle \
3406                 was registered; this should never happen — the build_clause path \
3407                 mints a handle for every distinct source_rule in the invocation set",
3408                invocation.model_name, pc.source_rule
3409            ))
3410        })?;
3411        let subject_col = schema
3412            .index_of(&format!("{}._vid", pc.subject_var))
3413            .or_else(|_| schema.index_of(&pc.subject_var))
3414            .map_err(|_| {
3415                datafusion::error::DataFusionError::Execution(format!(
3416                    "model '{}' path_context: subject column '{}' (or '{0}._vid') not \
3417                     in body batch schema",
3418                    invocation.model_name, pc.subject_var
3419                ))
3420            })?;
3421        let vid_to_value =
3422            build_path_context_lookup(handle, &pc.subject_var, &pc.column, &invocation.model_name)?;
3423        return Ok(vec![FeatureResolver {
3424            binding_name: pc.column.clone(),
3425            kind: FeatureResolverKind::PathContext {
3426                subject_col,
3427                vid_to_value: Arc::new(vid_to_value),
3428            },
3429        }]);
3430    }
3431
3432    let mut out = Vec::with_capacity(invocation.feature_exprs.len());
3433    for (i, fexpr) in invocation.feature_exprs.iter().enumerate() {
3434        let binding_name = invocation.feature_names[i].clone();
3435        let kind = match fexpr {
3436            Expr::FunctionCall { name, args, .. } if name == "similar_to" => {
3437                if args.len() != 2 {
3438                    return Err(datafusion::error::DataFusionError::Execution(format!(
3439                        "similar_to expects 2 args, got {}",
3440                        args.len()
3441                    )));
3442                }
3443                FeatureResolverKind::SimilarTo {
3444                    left: resolve_src(&args[0])?,
3445                    right: resolve_src(&args[1])?,
3446                }
3447            }
3448            Expr::FunctionCall { name, args, .. } if name == "semantic_match" => {
3449                // Phase D D2: lower `semantic_match(prop, 'text')` to a
3450                // `SimilarTo` resolver with the pre-embedded query
3451                // vector as the right side. The literal text was embedded
3452                // once via the Xervo runtime in `pre_embed_semantic_match_queries`.
3453                if args.len() != 2 {
3454                    return Err(datafusion::error::DataFusionError::Execution(format!(
3455                        "semantic_match expects 2 args, got {}",
3456                        args.len()
3457                    )));
3458                }
3459                let text = match &args[1] {
3460                    Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3461                    other => {
3462                        return Err(datafusion::error::DataFusionError::Execution(format!(
3463                            "semantic_match: 2nd arg must be a string literal, got {other:?}"
3464                        )));
3465                    }
3466                };
3467                let embedded = semantic_match_embeddings.get(&text).ok_or_else(|| {
3468                    datafusion::error::DataFusionError::Execution(format!(
3469                        "semantic_match: query text '{text}' was not pre-embedded. \
3470                         This is a bug — `apply_model_invocations` should have \
3471                         embedded all unique semantic_match texts up front. Most \
3472                         likely the Xervo runtime is not configured (configure \
3473                         via `LocyConfig::xervo_runtime` or its equivalent)."
3474                    ))
3475                })?;
3476                let right_vec: Vec<f32> = embedded.clone();
3477                FeatureResolverKind::SimilarTo {
3478                    left: resolve_src(&args[0])?,
3479                    right: FeatureValueSrc::Const(uni_common::Value::Vector(right_vec)),
3480                }
3481            }
3482            Expr::FunctionCall { name, args, .. }
3483                if matches!(
3484                    name.as_str(),
3485                    "degree_centrality"
3486                        | "pagerank_score"
3487                        | "closeness_centrality"
3488                        | "betweenness_centrality"
3489                        | "eigenvector_centrality"
3490                        | "harmonic_centrality"
3491                        | "katz_centrality"
3492                ) =>
3493            {
3494                if args.len() != 1 {
3495                    return Err(datafusion::error::DataFusionError::Execution(format!(
3496                        "{name} expects 1 arg, got {}",
3497                        args.len()
3498                    )));
3499                }
3500                let Expr::Variable(v) = &args[0] else {
3501                    return Err(datafusion::error::DataFusionError::Execution(format!(
3502                        "{name}(...) argument must be a node variable, got {:?}",
3503                        args[0]
3504                    )));
3505                };
3506                let subject_col = {
3507                    let direct = schema.index_of(v).ok();
3508                    let vid_name = format!("{}._vid", v);
3509                    let vid_col = schema.index_of(&vid_name).ok();
3510                    vid_col.or(direct).ok_or_else(|| {
3511                        datafusion::error::DataFusionError::Execution(format!(
3512                            "{name}: subject column '{v}' (or '{v}._vid') not in body batch schema"
3513                        ))
3514                    })?
3515                };
3516                let vid_to_score = graph_feature_maps.get(name).cloned().ok_or_else(|| {
3517                    datafusion::error::DataFusionError::Execution(format!(
3518                        "{name}: pre-computed score map missing. This is a bug — \
3519                         `apply_model_invocations` should have called \
3520                         `precompute_graph_feature_maps` for every graph-structural \
3521                         feature before building resolvers. Most likely the graph \
3522                         algorithm registry is not configured."
3523                    ))
3524                })?;
3525                FeatureResolverKind::GraphAlgoScore {
3526                    subject_col,
3527                    vid_to_score,
3528                }
3529            }
3530            Expr::FunctionCall { name, args, .. }
3531                if matches!(
3532                    name.as_str(),
3533                    "avg_neighbor" | "max_neighbor" | "sum_neighbor"
3534                ) =>
3535            {
3536                if args.len() != 3 && args.len() != 4 {
3537                    return Err(datafusion::error::DataFusionError::Execution(format!(
3538                        "{name} expects 3 or 4 args, got {}",
3539                        args.len()
3540                    )));
3541                }
3542                let Expr::Variable(v) = &args[0] else {
3543                    return Err(datafusion::error::DataFusionError::Execution(format!(
3544                        "{name}(...) first argument must be a node variable, got {:?}",
3545                        args[0]
3546                    )));
3547                };
3548                let rel_type = match &args[1] {
3549                    Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3550                    other => {
3551                        return Err(datafusion::error::DataFusionError::Execution(format!(
3552                            "{name}: 2nd arg must be a string literal (rel-type), got {other:?}"
3553                        )));
3554                    }
3555                };
3556                let prop_name = match &args[2] {
3557                    Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3558                    other => {
3559                        return Err(datafusion::error::DataFusionError::Execution(format!(
3560                            "{name}: 3rd arg must be a string literal (property), got {other:?}"
3561                        )));
3562                    }
3563                };
3564                let direction_arg = match args.get(3) {
3565                    None => NeighborDirection::Outgoing,
3566                    Some(Expr::Literal(uni_cypher::ast::CypherLiteral::String(d))) => {
3567                        match d.to_uppercase().as_str() {
3568                            "OUTGOING" => NeighborDirection::Outgoing,
3569                            "INCOMING" => NeighborDirection::Incoming,
3570                            "BOTH" => NeighborDirection::Both,
3571                            other => {
3572                                return Err(datafusion::error::DataFusionError::Execution(
3573                                    format!(
3574                                        "{name}: direction must be OUTGOING|INCOMING|BOTH, got '{other}'"
3575                                    ),
3576                                ));
3577                            }
3578                        }
3579                    }
3580                    Some(other) => {
3581                        return Err(datafusion::error::DataFusionError::Execution(format!(
3582                            "{name}: 4th arg must be a string literal (direction), got {other:?}"
3583                        )));
3584                    }
3585                };
3586                let subject_col = {
3587                    let direct = schema.index_of(v).ok();
3588                    let vid_name = format!("{}._vid", v);
3589                    let vid_col = schema.index_of(&vid_name).ok();
3590                    vid_col.or(direct).ok_or_else(|| {
3591                        datafusion::error::DataFusionError::Execution(format!(
3592                            "{name}: subject column '{v}' (or '{v}._vid') not in body batch schema"
3593                        ))
3594                    })?
3595                };
3596                let vid_to_values = neighbor_feature_maps
3597                    .get(&(rel_type.clone(), prop_name.clone(), direction_arg))
3598                    .cloned()
3599                    .ok_or_else(|| {
3600                        datafusion::error::DataFusionError::Execution(format!(
3601                            "{name}: pre-computed neighbor map missing for ({rel_type}, {prop_name}, {direction_arg:?}). \
3602                             This is a bug — `apply_model_invocations` should have called \
3603                             `precompute_neighbor_feature_maps` for every neighbor-aggregator \
3604                             feature before building resolvers."
3605                        ))
3606                    })?;
3607                let op = NeighborAgg::from_fn_name(name).unwrap();
3608                FeatureResolverKind::NeighborAggregate {
3609                    subject_col,
3610                    op,
3611                    vid_to_values,
3612                }
3613            }
3614            other => match resolve_src(other)? {
3615                FeatureValueSrc::Col(idx) => FeatureResolverKind::Direct(idx),
3616                FeatureValueSrc::Const(_) => {
3617                    return Err(datafusion::error::DataFusionError::Execution(format!(
3618                        "model '{}' feature must reference a variable or property — got a literal",
3619                        invocation.model_name
3620                    )));
3621                }
3622            },
3623        };
3624        out.push(FeatureResolver { binding_name, kind });
3625    }
3626    Ok(out)
3627}
3628
3629/// Phase D D2: scan invocations' feature expressions for
3630/// `semantic_match(prop, 'text')` calls and embed each distinct
3631/// query string once via the Xervo runtime. Returns a
3632/// `text → Vec<f32>` map consumed at resolver-build time. Errors
3633/// cleanly when `semantic_match` is used without a configured
3634/// Xervo runtime.
3635async fn pre_embed_semantic_match_queries(
3636    invocations: &[uni_locy::ModelInvocation],
3637    xervo_runtime: &crate::query::df_graph::locy_model_invoke::XervoRuntimeHandle,
3638) -> DFResult<HashMap<String, Vec<f32>>> {
3639    use uni_cypher::ast::{CypherLiteral, Expr};
3640    // Collect (text, embedder_alias) pairs. The alias is per-model
3641    // (Phase D D2 follow-up): each invocation's `embedder_alias` overrides
3642    // the runtime-wide `"default"`. Two invocations sharing the same
3643    // (text, alias) reuse one embed call; same text under different
3644    // aliases is embedded twice. The cache key remains plain `text` —
3645    // a model's resolver fetches embeddings via the text it knows, and
3646    // mixed-alias re-embed of identical text under a different alias is
3647    // a rare-enough edge case that the "last writer wins" cache shape
3648    // is acceptable (documented).
3649    let mut needed: Vec<(String, String)> = Vec::new();
3650    for inv in invocations {
3651        let alias = inv
3652            .embedder_alias
3653            .clone()
3654            .unwrap_or_else(|| "default".to_string());
3655        for fexpr in &inv.feature_exprs {
3656            if let Expr::FunctionCall { name, args, .. } = fexpr
3657                && name == "semantic_match"
3658                && args.len() == 2
3659                && let Expr::Literal(CypherLiteral::String(s)) = &args[1]
3660            {
3661                let tuple = (s.clone(), alias.clone());
3662                if !needed.contains(&tuple) {
3663                    needed.push(tuple);
3664                }
3665            }
3666        }
3667    }
3668    if needed.is_empty() {
3669        return Ok(HashMap::new());
3670    }
3671    let runtime = xervo_runtime.as_ref().ok_or_else(|| {
3672        datafusion::error::DataFusionError::Execution(
3673            "semantic_match: Uni-Xervo runtime not configured. Either provide \
3674             one via `LocyConfig::xervo_runtime` (or its equivalent setup \
3675             path) or pre-compute the query embedding and pass it via \
3676             `similar_to(prop, <literal_vector>)`."
3677                .to_string(),
3678        )
3679    })?;
3680    // Group needed (text, alias) by alias so each embedder is consulted
3681    // exactly once.
3682    let mut by_alias: HashMap<String, Vec<String>> = HashMap::new();
3683    for (text, alias) in &needed {
3684        by_alias
3685            .entry(alias.clone())
3686            .or_default()
3687            .push(text.clone());
3688    }
3689    let mut out: HashMap<String, Vec<f32>> = HashMap::new();
3690    for (alias, texts) in by_alias {
3691        let embedder = runtime.embedding(&alias).await.map_err(|e| {
3692            datafusion::error::DataFusionError::Execution(format!(
3693                "semantic_match: failed to obtain embedder for alias '{alias}': {e}. \
3694                 Register an embedder under that alias in your Uni-Xervo runtime, or \
3695                 pre-compute the query embedding and pass via similar_to."
3696            ))
3697        })?;
3698        let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
3699        let embeddings = embedder.embed(text_refs).await.map_err(|e| {
3700            datafusion::error::DataFusionError::Execution(format!(
3701                "semantic_match: embedder '{alias}' call failed: {e}"
3702            ))
3703        })?;
3704        if embeddings.len() != texts.len() {
3705            return Err(datafusion::error::DataFusionError::Execution(format!(
3706                "semantic_match: embedder '{alias}' returned {} vectors for {} queries",
3707                embeddings.len(),
3708                texts.len()
3709            )));
3710        }
3711        for (text, vec) in texts.into_iter().zip(embeddings) {
3712            out.insert(text, vec);
3713        }
3714    }
3715    Ok(out)
3716}
3717
3718/// Phase D D1 graph-structural: scan invocations' feature expressions
3719/// for `degree_centrality(n)` / `pagerank_score(n)` / `closeness_centrality(n)`
3720/// calls and invoke the corresponding `uni.algo.*` procedure on the
3721/// configured `AlgorithmRegistry` once per distinct call. Returns a
3722/// `fn_name → Arc<HashMap<vid, score>>` map consumed at resolver-build
3723/// time. Errors cleanly when a graph-structural FEATURE is used
3724/// without a configured registry or storage handle.
3725///
3726/// Pre-computation is `O(graph)` per call. Across fixpoint iterations
3727/// the graph state can change, so the cache lives for the lifetime of
3728/// one `apply_model_invocations` call only — same lifetime as the
3729/// D2 query-embedding cache (`pre_embed_semantic_match_queries`).
3730async fn precompute_graph_feature_maps(
3731    invocations: &[uni_locy::ModelInvocation],
3732    graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
3733) -> DFResult<HashMap<String, Arc<HashMap<u64, f64>>>> {
3734    use futures::StreamExt;
3735    use uni_algo::algo::procedures::AlgoContext;
3736    use uni_cypher::ast::Expr;
3737
3738    // Map our user-facing FEATURE function names to the canonical
3739    // `uni.algo.*` procedure names registered in `AlgorithmRegistry`.
3740    fn procedure_for(fn_name: &str) -> Option<&'static str> {
3741        match fn_name {
3742            "degree_centrality" => Some("uni.algo.degreeCentrality"),
3743            "pagerank_score" => Some("uni.algo.pageRank"),
3744            "closeness_centrality" => Some("uni.algo.closeness"),
3745            "betweenness_centrality" => Some("uni.algo.betweenness"),
3746            "eigenvector_centrality" => Some("uni.algo.eigenvectorCentrality"),
3747            "harmonic_centrality" => Some("uni.algo.harmonicCentrality"),
3748            "katz_centrality" => Some("uni.algo.katzCentrality"),
3749            _ => None,
3750        }
3751    }
3752
3753    // Collect the set of distinct topology-FEATURE names referenced
3754    // across all invocations. Args are always a single Variable, so
3755    // the precomputation key is just the function name.
3756    let mut needed: Vec<String> = Vec::new();
3757    for inv in invocations {
3758        for fexpr in &inv.feature_exprs {
3759            if let Expr::FunctionCall { name, .. } = fexpr
3760                && procedure_for(name).is_some()
3761                && !needed.contains(name)
3762            {
3763                needed.push(name.clone());
3764            }
3765        }
3766    }
3767    if needed.is_empty() {
3768        return Ok(HashMap::new());
3769    }
3770
3771    let registry = graph_algo.registry.as_ref().ok_or_else(|| {
3772        datafusion::error::DataFusionError::Execution(
3773            "graph-structural FEATURE invoked but no `AlgorithmRegistry` is \
3774             configured. Configure one on `GraphExecutionContext::with_algo_registry`."
3775                .to_string(),
3776        )
3777    })?;
3778    let storage = graph_algo.storage.as_ref().ok_or_else(|| {
3779        datafusion::error::DataFusionError::Execution(
3780            "graph-structural FEATURE invoked but no storage handle was \
3781             threaded into the FEATURE runtime. This is a bug in df_planner."
3782                .to_string(),
3783        )
3784    })?;
3785
3786    let mut out: HashMap<String, Arc<HashMap<u64, f64>>> = HashMap::new();
3787    for fn_name in needed {
3788        let proc_name = procedure_for(&fn_name).unwrap();
3789        let procedure = registry.get(proc_name).ok_or_else(|| {
3790            datafusion::error::DataFusionError::Execution(format!(
3791                "graph-structural FEATURE '{fn_name}' resolves to procedure \
3792                 '{proc_name}' which is not in the algorithm registry"
3793            ))
3794        })?;
3795        // Topology procedures take (nodeLabels[], relationshipTypes[],
3796        // [direction], [...]) — pass empty arrays for nodeLabels and
3797        // relationshipTypes to mean "all". The procedure fills the
3798        // remaining optional args from its signature defaults.
3799        let args: Vec<serde_json::Value> = vec![
3800            serde_json::Value::Array(Vec::new()),
3801            serde_json::Value::Array(Vec::new()),
3802        ];
3803        let algo_ctx = AlgoContext::new(
3804            storage.clone(),
3805            graph_algo.l0_manager.as_ref().map(Arc::clone),
3806        );
3807        // The AlgoProcedure trait routes direct (nodeLabels, edgeTypes)
3808        // args through the V2 projection entry point: build a projection
3809        // from the direct args, then execute against it.
3810        //
3811        // Fill optional algorithm-specific args (e.g. degree_centrality's
3812        // `direction`, eigenvector/katz `weightProperty`) with their schema
3813        // defaults for the projection build: `build_projection_from_direct_args`
3814        // feeds the specific args (`args[2..]`) to the adapter's
3815        // `customize_projection`, which indexes them positionally — the two
3816        // empty placeholder arrays alone would leave that slice empty and
3817        // panic. `validate_args` fills missing optionals WITHOUT type-checking
3818        // the defaults (some are `Null`-typed sentinels), so this never errors
3819        // for the placeholder shape.
3820        //
3821        // We pass the ORIGINAL `args` (not the filled ones) to
3822        // `execute_with_projection`, which re-runs `validate_args` internally:
3823        // re-feeding already-filled args would make those defaults look
3824        // "provided" and trip the type-check (`weightProperty: Null` vs
3825        // `String`). This mirrors the (now-removed) legacy
3826        // `AlgoProcedure::execute`, which validated once then built + ran.
3827        let filled_args = procedure
3828            .signature()
3829            .validate_args(args.clone())
3830            .map_err(|e| {
3831                datafusion::error::DataFusionError::Execution(format!(
3832                    "graph-structural FEATURE '{fn_name}': argument validation failed: {e}"
3833                ))
3834            })?;
3835        let projection = uni_algo::algo::procedure_template::build_projection_from_direct_args(
3836            procedure.as_ref(),
3837            &algo_ctx,
3838            &filled_args,
3839        )
3840        .await
3841        .map_err(|e| {
3842            datafusion::error::DataFusionError::Execution(format!(
3843                "graph-structural FEATURE '{fn_name}': projection build failed: {e}"
3844            ))
3845        })?;
3846        let mut stream = procedure.execute_with_projection(algo_ctx, args, projection);
3847        let mut score_map: HashMap<u64, f64> = HashMap::new();
3848        let sig = procedure.signature();
3849        let node_idx = sig
3850            .yields
3851            .iter()
3852            .position(|(n, _)| *n == "nodeId")
3853            .ok_or_else(|| {
3854                datafusion::error::DataFusionError::Execution(format!(
3855                    "procedure '{proc_name}' yield schema missing 'nodeId'"
3856                ))
3857            })?;
3858        // Most `uni.algo.*` centrality procedures yield `score`; the
3859        // `harmonicCentrality` family yields `centrality` instead. Accept
3860        // either to keep this dispatch independent of procedure-internal
3861        // naming choices.
3862        let score_idx = sig
3863            .yields
3864            .iter()
3865            .position(|(n, _)| *n == "score" || *n == "centrality")
3866            .ok_or_else(|| {
3867                datafusion::error::DataFusionError::Execution(format!(
3868                    "procedure '{proc_name}' yield schema missing a numeric score column \
3869                     (expected 'score' or 'centrality')"
3870                ))
3871            })?;
3872        while let Some(row_res) = stream.next().await {
3873            let row = row_res.map_err(|e| {
3874                datafusion::error::DataFusionError::Execution(format!(
3875                    "graph-structural FEATURE '{fn_name}': procedure '{proc_name}' failed: {e}"
3876                ))
3877            })?;
3878            let vid_v = row.values.get(node_idx);
3879            let score_v = row.values.get(score_idx);
3880            let (Some(vid_v), Some(score_v)) = (vid_v, score_v) else {
3881                continue;
3882            };
3883            let vid = vid_v.as_u64().or_else(|| vid_v.as_i64().map(|i| i as u64));
3884            let score = score_v
3885                .as_f64()
3886                .or_else(|| score_v.as_i64().map(|i| i as f64));
3887            if let (Some(vid), Some(score)) = (vid, score) {
3888                score_map.insert(vid, score);
3889            }
3890        }
3891        out.insert(fn_name, Arc::new(score_map));
3892    }
3893    Ok(out)
3894}
3895
3896/// Phase D D1 graph-structural: one-hop neighborhood aggregator
3897/// precompute. Scans invocations' feature expressions for
3898/// `avg_neighbor` / `max_neighbor` / `sum_neighbor` FunctionCalls,
3899/// collects the distinct `(rel_type, prop_name)` pairs they need,
3900/// resolves each rel-type to a schema edge-type id, warms the
3901/// outgoing-adjacency CSR, and for every subject vid present in the
3902/// body batches walks the one-hop neighborhood and fetches the
3903/// requested property from each neighbor via `PropertyManager`.
3904/// Non-numeric neighbor property values are filtered out via
3905/// `Value::as_f64`.
3906///
3907/// Returns `Arc<HashMap<u64, Vec<f64>>>` keyed by `(rel_type, prop_name)`.
3908/// The resolver's runtime cost per row is then a single hash lookup
3909/// plus an `avg`/`max`/`sum` over the cached `Vec<f64>`.
3910///
3911/// Scope: **subject-set-only** — we only collect for vids that appear
3912/// in the body batches' subject columns (avoids pre-walking the entire
3913/// graph). Subjects with no outgoing edges of the named type land in
3914/// the map with an empty `Vec` so the resolver's `Null` semantics
3915/// remain crisp (empty → `Null` → classifier interprets per its
3916/// feature contract).
3917/// Per-`(rel_type, prop_name, direction)` cache of neighbor property
3918/// values keyed by subject vid, produced by
3919/// `precompute_neighbor_feature_maps` and consumed by
3920/// `FeatureResolverKind::NeighborAggregate` resolvers.
3921type NeighborFeatureMaps =
3922    HashMap<(String, String, NeighborDirection), Arc<HashMap<u64, Vec<f64>>>>;
3923
3924async fn precompute_neighbor_feature_maps(
3925    invocations: &[uni_locy::ModelInvocation],
3926    batches: &[RecordBatch],
3927    graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
3928) -> DFResult<NeighborFeatureMaps> {
3929    use uni_cypher::ast::{CypherLiteral, Expr};
3930
3931    // Collect distinct (subject_var, rel_type, prop_name, direction)
3932    // tuples needed across all invocations. The subject_var tells us
3933    // which body batch column to scan for subject vids; the direction
3934    // is optional in the AST (defaults to OUTGOING).
3935    let parse_direction = |arg: Option<&Expr>| -> Option<NeighborDirection> {
3936        match arg {
3937            None => Some(NeighborDirection::Outgoing),
3938            Some(Expr::Literal(CypherLiteral::String(d))) => match d.to_uppercase().as_str() {
3939                "OUTGOING" => Some(NeighborDirection::Outgoing),
3940                "INCOMING" => Some(NeighborDirection::Incoming),
3941                "BOTH" => Some(NeighborDirection::Both),
3942                _ => None,
3943            },
3944            _ => None,
3945        }
3946    };
3947    let mut needed: Vec<(String, String, String, NeighborDirection)> = Vec::new();
3948    for inv in invocations {
3949        for fexpr in &inv.feature_exprs {
3950            if let Expr::FunctionCall { name, args, .. } = fexpr
3951                && NeighborAgg::from_fn_name(name).is_some()
3952                && (args.len() == 3 || args.len() == 4)
3953                && let Expr::Variable(v) = &args[0]
3954                && let Expr::Literal(CypherLiteral::String(rel)) = &args[1]
3955                && let Expr::Literal(CypherLiteral::String(prop)) = &args[2]
3956                && let Some(direction) = parse_direction(args.get(3))
3957            {
3958                let tuple = (v.clone(), rel.clone(), prop.clone(), direction);
3959                if !needed.contains(&tuple) {
3960                    needed.push(tuple);
3961                }
3962            }
3963        }
3964    }
3965    if needed.is_empty() {
3966        return Ok(HashMap::new());
3967    }
3968
3969    let storage = graph_algo.storage.as_ref().ok_or_else(|| {
3970        datafusion::error::DataFusionError::Execution(
3971            "neighbor-aggregator FEATURE invoked but no storage handle was \
3972             threaded into the FEATURE runtime. This is a bug in df_planner."
3973                .to_string(),
3974        )
3975    })?;
3976    let property_manager = graph_algo.property_manager.as_ref().ok_or_else(|| {
3977        datafusion::error::DataFusionError::Execution(
3978            "neighbor-aggregator FEATURE invoked but no PropertyManager was \
3979             threaded into the FEATURE runtime. This is a bug in df_planner."
3980                .to_string(),
3981        )
3982    })?;
3983    // Build a QueryContext snapshot so L0-resident vertex properties
3984    // are visible to `get_vertex_prop_with_ctx`. Without a ctx, L0
3985    // property data is silently invisible (returns Null), which is
3986    // why the topology trio's `AlgoContext` consumes L0 via
3987    // `L0Manager` whereas property reads need this separate path.
3988    let query_ctx = graph_algo.l0_buffers.as_ref().map(|bufs| {
3989        uni_store::runtime::context::QueryContext::new_with_pending(
3990            bufs.current.clone(),
3991            bufs.transaction.clone(),
3992            bufs.pending_flush.clone(),
3993        )
3994    });
3995
3996    // Group needed tuples by (rel_type, prop_name, direction) — one
3997    // precomputed map per key, regardless of which subject_var binding
3998    // points at it (the subject vids are unioned).
3999    let mut by_key: HashMap<(String, String, NeighborDirection), Vec<String>> = HashMap::new();
4000    for (subject_var, rel, prop, direction) in needed {
4001        by_key
4002            .entry((rel, prop, direction))
4003            .or_default()
4004            .push(subject_var);
4005    }
4006
4007    let mut out: NeighborFeatureMaps = HashMap::new();
4008    for ((rel_type, prop_name, direction), subject_vars) in by_key {
4009        // Resolve edge_type_id from schema.
4010        let schema = storage.schema_manager().schema();
4011        let Some(edge_meta) = schema.edge_types.get(&rel_type) else {
4012            // Unregistered rel-type → empty map. The resolver surfaces
4013            // Null at row time, consistent with the no-neighbor case.
4014            out.insert((rel_type, prop_name, direction), Arc::new(HashMap::new()));
4015            continue;
4016        };
4017        let edge_type_id = edge_meta.id;
4018
4019        // Warm adjacency for every direction we'll traverse. Mirrors
4020        // the pattern in projection.rs / procedure_template.rs.
4021        let edge_ver = storage.get_edge_version_by_id(edge_type_id);
4022        for dir in direction.store_directions() {
4023            storage
4024                .warm_adjacency(edge_type_id, *dir, edge_ver)
4025                .await
4026                .map_err(|e| {
4027                    datafusion::error::DataFusionError::Execution(format!(
4028                        "neighbor-aggregator warm_adjacency for '{rel_type}' / {dir:?} failed: {e}"
4029                    ))
4030                })?;
4031        }
4032
4033        // Collect distinct subject vids from body batches across every
4034        // subject_var binding that this (rel, prop) pair uses.
4035        let mut subject_vids: std::collections::HashSet<u64> = std::collections::HashSet::new();
4036        for subject_var in &subject_vars {
4037            for batch in batches {
4038                let schema = batch.schema();
4039                let col_idx = schema
4040                    .index_of(&format!("{}._vid", subject_var))
4041                    .ok()
4042                    .or_else(|| schema.index_of(subject_var).ok());
4043                let Some(col_idx) = col_idx else { continue };
4044                let col = batch.column(col_idx);
4045                for row in 0..batch.num_rows() {
4046                    if let Some(v) = extract_vid_from_column(col.as_ref(), row) {
4047                        subject_vids.insert(v);
4048                    }
4049                }
4050            }
4051        }
4052
4053        // For each subject, walk edges in the configured direction(s),
4054        // fetch neighbor property, coerce to f64, accumulate. Subjects
4055        // with no numeric neighbors retain an empty Vec (→ Null at
4056        // row time).
4057        let mut vid_to_values: HashMap<u64, Vec<f64>> = HashMap::new();
4058        let adj = storage.adjacency_manager();
4059        for subject_vid in subject_vids {
4060            let mut neighbors: Vec<(uni_common::core::id::Vid, uni_common::core::id::Eid)> =
4061                Vec::new();
4062            for dir in direction.store_directions() {
4063                neighbors.extend(adj.get_neighbors(
4064                    uni_common::core::id::Vid::from(subject_vid),
4065                    edge_type_id,
4066                    *dir,
4067                ));
4068            }
4069            let mut values: Vec<f64> = Vec::with_capacity(neighbors.len());
4070            for (neighbor_vid, _eid) in neighbors {
4071                let val = property_manager
4072                    .get_vertex_prop_with_ctx(neighbor_vid, &prop_name, query_ctx.as_ref())
4073                    .await
4074                    .map_err(|e| {
4075                        datafusion::error::DataFusionError::Execution(format!(
4076                            "neighbor-aggregator: failed to read property \
4077                             '{prop_name}' on neighbor vid {neighbor_vid:?}: {e}"
4078                        ))
4079                    })?;
4080                if let Some(f) = val.as_f64()
4081                    && !f.is_nan()
4082                {
4083                    values.push(f);
4084                }
4085            }
4086            vid_to_values.insert(subject_vid, values);
4087        }
4088        out.insert((rel_type, prop_name, direction), Arc::new(vid_to_values));
4089    }
4090    Ok(out)
4091}
4092
4093/// Phase D D3: walk the source rule's converged batches and build
4094/// a `vid → FeatureValue` lookup for the named column. The subject
4095/// column in the derived rule's schema holds VIDs (UInt64) for node
4096/// variables; the value column type follows the rule's yield-schema
4097/// inference (typically Float64 / Int64 / Bool / String).
4098fn build_path_context_lookup(
4099    handle: &crate::query::df_graph::locy_model_invoke::PathContextHandle,
4100    _subject_var: &str,
4101    column: &str,
4102    model_name: &str,
4103) -> DFResult<HashMap<u64, uni_locy::FeatureValue>> {
4104    // The source rule's KEY column is its first yield column by
4105    // convention (`infer_yield_schema` orders KEYs first). The model's
4106    // local `subject_var` is just a binding alias — the actual join
4107    // matches the body row's VID against this canonical column.
4108    if handle.schema.fields().is_empty() {
4109        return Err(datafusion::error::DataFusionError::Execution(format!(
4110            "model '{model_name}' path_context: source rule has empty yield schema"
4111        )));
4112    }
4113    let subj_idx = 0_usize;
4114    let col_idx = handle.schema.index_of(column).map_err(|_| {
4115        datafusion::error::DataFusionError::Execution(format!(
4116            "model '{model_name}' path_context: column '{column}' not in \
4117             source rule's yield schema (have: {:?})",
4118            handle
4119                .schema
4120                .fields()
4121                .iter()
4122                .map(|f| f.name().clone())
4123                .collect::<Vec<_>>()
4124        ))
4125    })?;
4126    let batches = handle.data.read();
4127    let mut out: HashMap<u64, uni_locy::FeatureValue> = HashMap::new();
4128    for batch in batches.iter() {
4129        let subj_col = batch.column(subj_idx);
4130        let value_col = batch.column(col_idx);
4131        for row in 0..batch.num_rows() {
4132            if subj_col.is_null(row) {
4133                continue;
4134            }
4135            let vid = if let Some(a) = subj_col.as_any().downcast_ref::<arrow_array::UInt64Array>()
4136            {
4137                a.value(row)
4138            } else if let Some(a) = subj_col.as_any().downcast_ref::<arrow_array::Int64Array>() {
4139                a.value(row) as u64
4140            } else {
4141                continue;
4142            };
4143            let v = extract_feature_value(value_col.as_ref(), row);
4144            // Last write wins on duplicates; derived rules typically have
4145            // unique KEY values, so this is a defensive guard.
4146            out.insert(vid, v);
4147        }
4148    }
4149    Ok(out)
4150}
4151
4152/// Extract a `uni_common::Value` from one row of an Arrow column.
4153/// Used by the Phase D `similar_to` feature resolver, which needs
4154/// the raw `Value` (especially `Value::Vector(Vec<f32>)`) to feed
4155/// `eval_similar_to_pure`.
4156fn extract_common_value(col: &dyn arrow_array::Array, row_idx: usize) -> uni_common::Value {
4157    use arrow_array::{BooleanArray, Float64Array, Int64Array, LargeStringArray, StringArray};
4158    if col.is_null(row_idx) {
4159        return uni_common::Value::Null;
4160    }
4161    if let Some(a) = col.as_any().downcast_ref::<Float64Array>() {
4162        return uni_common::Value::Float(a.value(row_idx));
4163    }
4164    if let Some(a) = col.as_any().downcast_ref::<Int64Array>() {
4165        return uni_common::Value::Int(a.value(row_idx));
4166    }
4167    if let Some(a) = col.as_any().downcast_ref::<BooleanArray>() {
4168        return uni_common::Value::Bool(a.value(row_idx));
4169    }
4170    if let Some(a) = col.as_any().downcast_ref::<StringArray>() {
4171        return uni_common::Value::String(a.value(row_idx).to_string());
4172    }
4173    if let Some(a) = col.as_any().downcast_ref::<LargeStringArray>() {
4174        return uni_common::Value::String(a.value(row_idx).to_string());
4175    }
4176    if let Some(b) = col.as_any().downcast_ref::<arrow_array::LargeBinaryArray>() {
4177        let bytes = b.value(row_idx);
4178        if bytes.is_empty() {
4179            return uni_common::Value::Null;
4180        }
4181        return uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null);
4182    }
4183    uni_common::Value::Null
4184}
4185
4186fn extract_feature_value(col: &dyn arrow_array::Array, row_idx: usize) -> uni_locy::FeatureValue {
4187    use arrow_array::{BooleanArray, Float64Array, Int64Array, LargeStringArray, StringArray};
4188    if col.is_null(row_idx) {
4189        return uni_locy::FeatureValue::Null;
4190    }
4191    if let Some(a) = col.as_any().downcast_ref::<Float64Array>() {
4192        return uni_locy::FeatureValue::Float(a.value(row_idx));
4193    }
4194    if let Some(a) = col.as_any().downcast_ref::<Int64Array>() {
4195        return uni_locy::FeatureValue::Int(a.value(row_idx));
4196    }
4197    if let Some(a) = col.as_any().downcast_ref::<BooleanArray>() {
4198        return uni_locy::FeatureValue::Bool(a.value(row_idx));
4199    }
4200    if let Some(a) = col.as_any().downcast_ref::<StringArray>() {
4201        return uni_locy::FeatureValue::String(a.value(row_idx).to_string());
4202    }
4203    if let Some(a) = col.as_any().downcast_ref::<LargeStringArray>() {
4204        return uni_locy::FeatureValue::String(a.value(row_idx).to_string());
4205    }
4206    // Schema-less property storage: values arrive as LargeBinary
4207    // MessagePack-encoded `CypherValue`. Decode via the standard codec
4208    // and project the result to the matching `FeatureValue` variant.
4209    if let Some(b) = col.as_any().downcast_ref::<arrow_array::LargeBinaryArray>() {
4210        let bytes = b.value(row_idx);
4211        if bytes.is_empty() {
4212            return uni_locy::FeatureValue::Null;
4213        }
4214        let v = uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null);
4215        return match v {
4216            uni_common::Value::Float(f) => uni_locy::FeatureValue::Float(f),
4217            uni_common::Value::Int(i) => uni_locy::FeatureValue::Int(i),
4218            uni_common::Value::Bool(b) => uni_locy::FeatureValue::Bool(b),
4219            uni_common::Value::String(s) => uni_locy::FeatureValue::String(s),
4220            uni_common::Value::Null => uni_locy::FeatureValue::Null,
4221            _ => uni_locy::FeatureValue::Null,
4222        };
4223    }
4224    uni_locy::FeatureValue::Null
4225}
4226
4227/// Probabilistic complement for negated IS-refs targeting PROB rules.
4228///
4229/// Instead of filtering out matching VIDs (anti-join), this adds a complement
4230/// column `__prob_complement_{rule_name}` with value `1 - p` for each matching
4231/// VID, and `1.0` for VIDs not present in the negated rule's facts. Implements
4232/// `IS NOT risk` on a PROB rule: the probability that the entity is NOT risky.
4233pub fn apply_prob_complement(
4234    batches: Vec<RecordBatch>,
4235    neg_facts: &[RecordBatch],
4236    left_col: &str,
4237    right_col: &str,
4238    prob_col: &str,
4239    complement_col_name: &str,
4240) -> datafusion::error::Result<Vec<RecordBatch>> {
4241    use arrow_array::{Array as _, Float64Array, UInt64Array};
4242
4243    // Build VID → probability lookup from negative facts
4244    let mut prob_map: std::collections::HashMap<u64, f64> = std::collections::HashMap::new();
4245    for batch in neg_facts {
4246        let Ok(vid_idx) = batch.schema().index_of(right_col) else {
4247            continue;
4248        };
4249        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
4250            continue;
4251        };
4252        let Some(vids) = batch.column(vid_idx).as_any().downcast_ref::<UInt64Array>() else {
4253            continue;
4254        };
4255        let prob_arr = batch.column(prob_idx);
4256        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
4257        for i in 0..vids.len() {
4258            if !vids.is_null(i) {
4259                let p = probs
4260                    .and_then(|arr| {
4261                        if arr.is_null(i) {
4262                            None
4263                        } else {
4264                            Some(arr.value(i))
4265                        }
4266                    })
4267                    .unwrap_or(0.0);
4268                // If multiple facts for same VID, use noisy-OR combination:
4269                // combined = 1 - (1 - existing) * (1 - new)
4270                prob_map
4271                    .entry(vids.value(i))
4272                    .and_modify(|existing| {
4273                        *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
4274                    })
4275                    .or_insert(p);
4276            }
4277        }
4278    }
4279
4280    // Add complement column to each batch
4281    let mut result = Vec::new();
4282    for batch in batches {
4283        let Ok(idx) = batch.schema().index_of(left_col) else {
4284            result.push(batch);
4285            continue;
4286        };
4287        let Some(vids) = batch.column(idx).as_any().downcast_ref::<UInt64Array>() else {
4288            result.push(batch);
4289            continue;
4290        };
4291
4292        // Compute complement values: 1 - p for matched VIDs, 1.0 for absent
4293        let complements: Vec<f64> = (0..vids.len())
4294            .map(|i| {
4295                if vids.is_null(i) {
4296                    1.0
4297                } else {
4298                    let p = prob_map.get(&vids.value(i)).copied().unwrap_or(0.0);
4299                    1.0 - p
4300                }
4301            })
4302            .collect();
4303
4304        let complement_arr = Float64Array::from(complements);
4305
4306        // Add the complement column to the batch
4307        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
4308        columns.push(std::sync::Arc::new(complement_arr));
4309
4310        let mut fields: Vec<std::sync::Arc<arrow_schema::Field>> =
4311            batch.schema().fields().iter().cloned().collect();
4312        fields.push(std::sync::Arc::new(arrow_schema::Field::new(
4313            complement_col_name,
4314            arrow_schema::DataType::Float64,
4315            true,
4316        )));
4317
4318        let new_schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4319        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
4320        result.push(new_batch);
4321    }
4322    Ok(result)
4323}
4324
4325/// Probabilistic complement for composite (multi-column) join keys.
4326///
4327/// Builds a composite key from all `join_cols` right-side columns in
4328/// `neg_facts`, maps each composite key to a probability via noisy-OR
4329/// combination, then adds a single `complement_col_name` column with
4330/// `1 - p` for matched keys and `1.0` for absent keys.
4331pub fn apply_prob_complement_composite(
4332    batches: Vec<RecordBatch>,
4333    neg_facts: &[RecordBatch],
4334    join_cols: &[(String, String)],
4335    prob_col: &str,
4336    complement_col_name: &str,
4337) -> datafusion::error::Result<Vec<RecordBatch>> {
4338    use arrow_array::{Array as _, Float64Array, UInt64Array};
4339
4340    // Build composite-key → probability lookup from negative facts.
4341    let mut prob_map: HashMap<Vec<u64>, f64> = HashMap::new();
4342    for batch in neg_facts {
4343        let right_indices: Vec<usize> = join_cols
4344            .iter()
4345            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
4346            .collect();
4347        if right_indices.len() != join_cols.len() {
4348            continue;
4349        }
4350        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
4351            continue;
4352        };
4353        let prob_arr = batch.column(prob_idx);
4354        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
4355        for row in 0..batch.num_rows() {
4356            let mut key = Vec::with_capacity(right_indices.len());
4357            let mut valid = true;
4358            for &ci in &right_indices {
4359                let col = batch.column(ci);
4360                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
4361                    if vids.is_null(row) {
4362                        valid = false;
4363                        break;
4364                    }
4365                    key.push(vids.value(row));
4366                } else {
4367                    valid = false;
4368                    break;
4369                }
4370            }
4371            if !valid {
4372                continue;
4373            }
4374            let p = probs
4375                .and_then(|arr| {
4376                    if arr.is_null(row) {
4377                        None
4378                    } else {
4379                        Some(arr.value(row))
4380                    }
4381                })
4382                .unwrap_or(0.0);
4383            // Noisy-OR combination for duplicate composite keys.
4384            prob_map
4385                .entry(key)
4386                .and_modify(|existing| {
4387                    *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
4388                })
4389                .or_insert(p);
4390        }
4391    }
4392
4393    // Add complement column to each batch.
4394    let mut result = Vec::new();
4395    for batch in batches {
4396        let left_indices: Vec<usize> = join_cols
4397            .iter()
4398            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
4399            .collect();
4400        if left_indices.len() != join_cols.len() {
4401            result.push(batch);
4402            continue;
4403        }
4404        let all_u64 = left_indices.iter().all(|&ci| {
4405            batch
4406                .column(ci)
4407                .as_any()
4408                .downcast_ref::<UInt64Array>()
4409                .is_some()
4410        });
4411        if !all_u64 {
4412            result.push(batch);
4413            continue;
4414        }
4415
4416        let complements: Vec<f64> = (0..batch.num_rows())
4417            .map(|row| {
4418                let mut key = Vec::with_capacity(left_indices.len());
4419                for &ci in &left_indices {
4420                    let vids = batch
4421                        .column(ci)
4422                        .as_any()
4423                        .downcast_ref::<UInt64Array>()
4424                        .unwrap();
4425                    if vids.is_null(row) {
4426                        return 1.0;
4427                    }
4428                    key.push(vids.value(row));
4429                }
4430                let p = prob_map.get(&key).copied().unwrap_or(0.0);
4431                1.0 - p
4432            })
4433            .collect();
4434
4435        let complement_arr = Float64Array::from(complements);
4436        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
4437        columns.push(Arc::new(complement_arr));
4438
4439        let mut fields: Vec<Arc<arrow_schema::Field>> =
4440            batch.schema().fields().iter().cloned().collect();
4441        fields.push(Arc::new(arrow_schema::Field::new(
4442            complement_col_name,
4443            arrow_schema::DataType::Float64,
4444            true,
4445        )));
4446
4447        let new_schema = Arc::new(arrow_schema::Schema::new(fields));
4448        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
4449        result.push(new_batch);
4450    }
4451    Ok(result)
4452}
4453
4454/// Boolean anti-join for composite (multi-column) join keys.
4455///
4456/// Builds a `HashSet<Vec<u64>>` from `neg_facts` using all right-side
4457/// columns in `join_cols`, then filters `batches` to keep only rows
4458/// whose composite left-side key is NOT in the set.
4459pub fn apply_anti_join_composite(
4460    batches: Vec<RecordBatch>,
4461    neg_facts: &[RecordBatch],
4462    join_cols: &[(String, String)],
4463) -> datafusion::error::Result<Vec<RecordBatch>> {
4464    use arrow::compute::filter_record_batch;
4465    use arrow_array::{Array as _, BooleanArray, UInt64Array};
4466
4467    // Collect composite keys from the negated rule's derived facts.
4468    let mut banned: HashSet<Vec<u64>> = HashSet::new();
4469    for batch in neg_facts {
4470        let right_indices: Vec<usize> = join_cols
4471            .iter()
4472            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
4473            .collect();
4474        if right_indices.len() != join_cols.len() {
4475            continue;
4476        }
4477        for row in 0..batch.num_rows() {
4478            let mut key = Vec::with_capacity(right_indices.len());
4479            let mut valid = true;
4480            for &ci in &right_indices {
4481                let col = batch.column(ci);
4482                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
4483                    if vids.is_null(row) {
4484                        valid = false;
4485                        break;
4486                    }
4487                    key.push(vids.value(row));
4488                } else {
4489                    valid = false;
4490                    break;
4491                }
4492            }
4493            if valid {
4494                banned.insert(key);
4495            }
4496        }
4497    }
4498
4499    if banned.is_empty() {
4500        return Ok(batches);
4501    }
4502
4503    // Filter body batches: keep rows where composite left key NOT IN banned.
4504    let mut result = Vec::new();
4505    for batch in batches {
4506        let left_indices: Vec<usize> = join_cols
4507            .iter()
4508            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
4509            .collect();
4510        if left_indices.len() != join_cols.len() {
4511            result.push(batch);
4512            continue;
4513        }
4514        let all_u64 = left_indices.iter().all(|&ci| {
4515            batch
4516                .column(ci)
4517                .as_any()
4518                .downcast_ref::<UInt64Array>()
4519                .is_some()
4520        });
4521        if !all_u64 {
4522            result.push(batch);
4523            continue;
4524        }
4525
4526        let keep: Vec<bool> = (0..batch.num_rows())
4527            .map(|row| {
4528                let mut key = Vec::with_capacity(left_indices.len());
4529                for &ci in &left_indices {
4530                    let vids = batch
4531                        .column(ci)
4532                        .as_any()
4533                        .downcast_ref::<UInt64Array>()
4534                        .unwrap();
4535                    if vids.is_null(row) {
4536                        return true; // null keys are never banned
4537                    }
4538                    key.push(vids.value(row));
4539                }
4540                !banned.contains(&key)
4541            })
4542            .collect();
4543        let keep_arr = BooleanArray::from(keep);
4544        let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
4545        if filtered.num_rows() > 0 {
4546            result.push(filtered);
4547        }
4548    }
4549    Ok(result)
4550}
4551
4552/// Multiply `__prob_complement_*` columns into the rule's PROB column and clean up.
4553///
4554/// After IS NOT probabilistic complement semantics have added `__prob_complement_*`
4555/// columns to clause results, this function:
4556/// 1. Computes the product of all complement factor columns
4557/// 2. Multiplies the product into the existing PROB column (if any)
4558/// 3. Removes the internal `__prob_complement_*` columns from the output
4559///
4560/// If the rule has no PROB column, complement columns are simply removed
4561/// (the complement information is discarded and IS NOT acts as a keep-all).
4562pub fn multiply_prob_factors(
4563    batches: Vec<RecordBatch>,
4564    prob_col: Option<&str>,
4565    complement_cols: &[String],
4566) -> datafusion::error::Result<Vec<RecordBatch>> {
4567    use arrow_array::{Array as _, Float64Array};
4568
4569    let mut result = Vec::with_capacity(batches.len());
4570
4571    for batch in batches {
4572        if batch.num_rows() == 0 {
4573            // Remove complement columns from empty batches
4574            let keep: Vec<usize> = batch
4575                .schema()
4576                .fields()
4577                .iter()
4578                .enumerate()
4579                .filter(|(_, f)| !complement_cols.contains(f.name()))
4580                .map(|(i, _)| i)
4581                .collect();
4582            let fields: Vec<_> = keep
4583                .iter()
4584                .map(|&i| batch.schema().field(i).clone())
4585                .collect();
4586            let cols: Vec<_> = keep.iter().map(|&i| batch.column(i).clone()).collect();
4587            let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4588            result.push(
4589                RecordBatch::try_new(schema, cols).map_err(|e| {
4590                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
4591                })?,
4592            );
4593            continue;
4594        }
4595
4596        let num_rows = batch.num_rows();
4597
4598        // 1. Compute product of all complement factors
4599        let mut combined = vec![1.0f64; num_rows];
4600        for col_name in complement_cols {
4601            if let Ok(idx) = batch.schema().index_of(col_name) {
4602                let arr = batch
4603                    .column(idx)
4604                    .as_any()
4605                    .downcast_ref::<Float64Array>()
4606                    .ok_or_else(|| {
4607                        datafusion::error::DataFusionError::Internal(format!(
4608                            "Expected Float64 for complement column {col_name}"
4609                        ))
4610                    })?;
4611                for (i, val) in combined.iter_mut().enumerate().take(num_rows) {
4612                    if !arr.is_null(i) {
4613                        *val *= arr.value(i);
4614                    }
4615                }
4616            }
4617        }
4618
4619        // 2. If there's a PROB column, multiply combined into it
4620        let final_prob: Vec<f64> = if let Some(prob_name) = prob_col {
4621            if let Ok(idx) = batch.schema().index_of(prob_name) {
4622                let arr = batch
4623                    .column(idx)
4624                    .as_any()
4625                    .downcast_ref::<Float64Array>()
4626                    .ok_or_else(|| {
4627                        datafusion::error::DataFusionError::Internal(format!(
4628                            "Expected Float64 for PROB column {prob_name}"
4629                        ))
4630                    })?;
4631                (0..num_rows)
4632                    .map(|i| {
4633                        if arr.is_null(i) {
4634                            combined[i]
4635                        } else {
4636                            arr.value(i) * combined[i]
4637                        }
4638                    })
4639                    .collect()
4640            } else {
4641                combined
4642            }
4643        } else {
4644            combined
4645        };
4646
4647        let new_prob_array: arrow_array::ArrayRef =
4648            std::sync::Arc::new(Float64Array::from(final_prob));
4649
4650        // 3. Build output: replace PROB column, remove complement columns
4651        let mut fields = Vec::new();
4652        let mut columns = Vec::new();
4653
4654        for (idx, field) in batch.schema().fields().iter().enumerate() {
4655            if complement_cols.contains(field.name()) {
4656                continue;
4657            }
4658            if prob_col.is_some_and(|p| field.name() == p) {
4659                fields.push(field.clone());
4660                columns.push(new_prob_array.clone());
4661            } else {
4662                fields.push(field.clone());
4663                columns.push(batch.column(idx).clone());
4664            }
4665        }
4666
4667        let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4668        result.push(RecordBatch::try_new(schema, columns).map_err(arrow_err)?);
4669    }
4670
4671    Ok(result)
4672}
4673
4674/// Update derived scan handles before evaluating a rule's clause bodies.
4675///
4676/// For self-references: inject delta (semi-naive optimization).
4677/// For cross-references: inject full facts.
4678fn update_derived_scan_handles(
4679    registry: &DerivedScanRegistry,
4680    states: &[FixpointState],
4681    current_rule_idx: usize,
4682    rules: &[FixpointRulePlan],
4683) {
4684    let current_rule_name = &rules[current_rule_idx].name;
4685
4686    for entry in &registry.entries {
4687        // Find the state for this entry's rule
4688        let source_state_idx = rules.iter().position(|r| r.name == entry.rule_name);
4689        let Some(source_idx) = source_state_idx else {
4690            continue;
4691        };
4692
4693        let is_self = entry.rule_name == *current_rule_name;
4694        let data = if is_self && !rules[current_rule_idx].non_linear {
4695            // Self-ref in a linear rule: inject delta for semi-naive
4696            states[source_idx].all_delta().to_vec()
4697        } else {
4698            // Cross-ref, or self-ref of a non-linear rule (≥2 same-stratum
4699            // refs in one clause — Δ×Δ would miss Δ×F_old): inject full facts
4700            states[source_idx].all_facts().to_vec()
4701        };
4702
4703        // If empty, write an empty batch so the scan returns zero rows
4704        let data = if data.is_empty() || data.iter().all(|b| b.num_rows() == 0) {
4705            vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
4706        } else {
4707            data
4708        };
4709
4710        let mut guard = entry.data.write();
4711        *guard = data;
4712    }
4713}
4714
4715// ---------------------------------------------------------------------------
4716// DerivedScanExec — physical plan that reads from shared data at execution time
4717// ---------------------------------------------------------------------------
4718
4719/// Physical plan for `LocyDerivedScan` that reads from a shared `Arc<RwLock>` at
4720/// execution time (not at plan creation time).
4721///
4722/// This is critical for fixpoint iteration: the data handle is updated between
4723/// iterations, and each re-execution of the subplan must read the latest data.
4724pub struct DerivedScanExec {
4725    data: Arc<RwLock<Vec<RecordBatch>>>,
4726    schema: SchemaRef,
4727    properties: Arc<PlanProperties>,
4728}
4729
4730impl DerivedScanExec {
4731    pub fn new(data: Arc<RwLock<Vec<RecordBatch>>>, schema: SchemaRef) -> Self {
4732        let properties = compute_plan_properties(Arc::clone(&schema));
4733        Self {
4734            data,
4735            schema,
4736            properties,
4737        }
4738    }
4739}
4740
4741impl fmt::Debug for DerivedScanExec {
4742    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4743        f.debug_struct("DerivedScanExec")
4744            .field("schema", &self.schema)
4745            .finish()
4746    }
4747}
4748
4749impl DisplayAs for DerivedScanExec {
4750    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4751        write!(f, "DerivedScanExec")
4752    }
4753}
4754
4755impl ExecutionPlan for DerivedScanExec {
4756    fn name(&self) -> &str {
4757        "DerivedScanExec"
4758    }
4759    fn as_any(&self) -> &dyn Any {
4760        self
4761    }
4762    fn schema(&self) -> SchemaRef {
4763        Arc::clone(&self.schema)
4764    }
4765    fn properties(&self) -> &Arc<PlanProperties> {
4766        &self.properties
4767    }
4768    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
4769        vec![]
4770    }
4771    fn with_new_children(
4772        self: Arc<Self>,
4773        _children: Vec<Arc<dyn ExecutionPlan>>,
4774    ) -> DFResult<Arc<dyn ExecutionPlan>> {
4775        Ok(self)
4776    }
4777    fn execute(
4778        &self,
4779        _partition: usize,
4780        _context: Arc<TaskContext>,
4781    ) -> DFResult<SendableRecordBatchStream> {
4782        let batches = {
4783            let guard = self.data.read();
4784            if guard.is_empty() {
4785                vec![RecordBatch::new_empty(Arc::clone(&self.schema))]
4786            } else {
4787                // Re-stamp every batch with this exec's schema. The shared
4788                // data Arc always holds batches with the rule's original
4789                // yield-schema names, but this scan may carry per-occurrence
4790                // aliased column names (multi-IS-ref clauses). Zero-copy:
4791                // only the schema pointer changes, never the columns.
4792                guard
4793                    .iter()
4794                    .map(|b| {
4795                        RecordBatch::try_new(Arc::clone(&self.schema), b.columns().to_vec())
4796                            .map_err(|e| {
4797                                datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
4798                            })
4799                    })
4800                    .collect::<DFResult<Vec<_>>>()?
4801            }
4802        };
4803        Ok(Box::pin(MemoryStream::try_new(
4804            batches,
4805            Arc::clone(&self.schema),
4806            None,
4807        )?))
4808    }
4809}
4810
4811// ---------------------------------------------------------------------------
4812// InMemoryExec — wrapper to feed Vec<RecordBatch> into operator chains
4813// ---------------------------------------------------------------------------
4814
4815/// Simple in-memory execution plan that serves pre-computed batches.
4816///
4817/// Used internally to feed fixpoint results into post-fixpoint operator chains
4818/// (FOLD, BEST BY). Not exported — only used within this module.
4819struct InMemoryExec {
4820    batches: Vec<RecordBatch>,
4821    schema: SchemaRef,
4822    properties: Arc<PlanProperties>,
4823}
4824
4825impl InMemoryExec {
4826    fn new(batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
4827        let properties = compute_plan_properties(Arc::clone(&schema));
4828        Self {
4829            batches,
4830            schema,
4831            properties,
4832        }
4833    }
4834}
4835
4836impl fmt::Debug for InMemoryExec {
4837    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4838        f.debug_struct("InMemoryExec")
4839            .field("num_batches", &self.batches.len())
4840            .field("schema", &self.schema)
4841            .finish()
4842    }
4843}
4844
4845impl DisplayAs for InMemoryExec {
4846    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4847        write!(f, "InMemoryExec: batches={}", self.batches.len())
4848    }
4849}
4850
4851impl ExecutionPlan for InMemoryExec {
4852    fn name(&self) -> &str {
4853        "InMemoryExec"
4854    }
4855    fn as_any(&self) -> &dyn Any {
4856        self
4857    }
4858    fn schema(&self) -> SchemaRef {
4859        Arc::clone(&self.schema)
4860    }
4861    fn properties(&self) -> &Arc<PlanProperties> {
4862        &self.properties
4863    }
4864    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
4865        vec![]
4866    }
4867    fn with_new_children(
4868        self: Arc<Self>,
4869        _children: Vec<Arc<dyn ExecutionPlan>>,
4870    ) -> DFResult<Arc<dyn ExecutionPlan>> {
4871        Ok(self)
4872    }
4873    fn execute(
4874        &self,
4875        _partition: usize,
4876        _context: Arc<TaskContext>,
4877    ) -> DFResult<SendableRecordBatchStream> {
4878        Ok(Box::pin(MemoryStream::try_new(
4879            self.batches.clone(),
4880            Arc::clone(&self.schema),
4881            None,
4882        )?))
4883    }
4884}
4885
4886// ---------------------------------------------------------------------------
4887// Post-fixpoint chain — FOLD and BEST BY on converged facts
4888// ---------------------------------------------------------------------------
4889
4890/// Apply post-FOLD WHERE (HAVING) filter to aggregated batches.
4891///
4892/// Converts each Cypher HAVING expression to a DataFusion physical expression
4893/// via `cypher_expr_to_df` → type coercion → `create_physical_expr`, evaluates
4894/// against the FOLD output, and keeps only rows where all conditions hold.
4895fn apply_having_filter(
4896    batches: Vec<RecordBatch>,
4897    having_exprs: &[Expr],
4898    schema: &SchemaRef,
4899    task_ctx: &Arc<TaskContext>,
4900) -> DFResult<Vec<RecordBatch>> {
4901    use arrow::compute::{and, filter_record_batch};
4902    use arrow_array::BooleanArray;
4903    use datafusion::common::DFSchema;
4904    use datafusion::logical_expr::LogicalPlanBuilder;
4905    use datafusion::logical_expr::execution_props::ExecutionProps;
4906    use datafusion::optimizer::AnalyzerRule;
4907    use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
4908    use datafusion::physical_expr::create_physical_expr;
4909
4910    if batches.is_empty() {
4911        return Ok(batches);
4912    }
4913
4914    // Build DFSchema from the FOLD output Arrow schema.
4915    let df_schema = DFSchema::try_from(schema.as_ref().clone()).map_err(|e| {
4916        datafusion::common::DataFusionError::Internal(format!("HAVING schema conversion: {e}"))
4917    })?;
4918
4919    // Use the active TaskContext's config rather than allocating a fresh
4920    // `SessionContext` per HAVING evaluation (~130 µs/call). HAVING uses only
4921    // built-in DataFusion arithmetic — no Cypher UDFs — so a default
4922    // `ExecutionProps` is sufficient (it's documented as cheap to construct).
4923    let config = (**task_ctx.session_config().options()).clone();
4924    let props = ExecutionProps::new();
4925
4926    // Cypher Expr → DataFusion DfExpr → type-coerced DfExpr → PhysicalExpr.
4927    //
4928    // Type coercion is needed because FOLD aggregates produce Float64 (SUM,
4929    // AVG) or Int64 (COUNT), and literal comparisons like `total >= 100`
4930    // may mix Float64 columns with Int64 literals.
4931    let physical_exprs: Vec<Arc<dyn datafusion::physical_expr::PhysicalExpr>> = having_exprs
4932        .iter()
4933        .map(|expr| {
4934            let df_expr = crate::query::df_expr::cypher_expr_to_df(expr, None).map_err(|e| {
4935                datafusion::common::DataFusionError::Internal(format!(
4936                    "HAVING expression conversion: {e}"
4937                ))
4938            })?;
4939
4940            // Run DataFusion's type coercion by wrapping in a Filter plan,
4941            // applying the TypeCoercion analyzer rule, then extracting the
4942            // coerced predicate.
4943            let empty = datafusion::logical_expr::LogicalPlan::EmptyRelation(
4944                datafusion::logical_expr::EmptyRelation {
4945                    produce_one_row: false,
4946                    schema: Arc::new(df_schema.clone()),
4947                },
4948            );
4949            let filter_plan = LogicalPlanBuilder::from(empty)
4950                .filter(df_expr.clone())?
4951                .build()?;
4952            let coerced_expr = match TypeCoercion::new().analyze(filter_plan, &config) {
4953                Ok(datafusion::logical_expr::LogicalPlan::Filter(f)) => f.predicate,
4954                _ => df_expr,
4955            };
4956
4957            create_physical_expr(&coerced_expr, &df_schema, &props)
4958        })
4959        .collect::<DFResult<Vec<_>>>()?;
4960
4961    let mut result = Vec::new();
4962    for batch in batches {
4963        // Evaluate each condition and AND the boolean masks.
4964        let mut mask: Option<BooleanArray> = None;
4965        for phys_expr in &physical_exprs {
4966            let value = phys_expr.evaluate(&batch)?;
4967            let arr = value.into_array(batch.num_rows())?;
4968            let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().ok_or_else(|| {
4969                datafusion::common::DataFusionError::Internal(
4970                    "HAVING condition must evaluate to boolean".into(),
4971                )
4972            })?;
4973            mask = Some(match mask {
4974                None => bool_arr.clone(),
4975                Some(prev) => and(&prev, bool_arr).map_err(arrow_err)?,
4976            });
4977        }
4978        if let Some(ref m) = mask {
4979            let filtered = filter_record_batch(&batch, m).map_err(arrow_err)?;
4980            if filtered.num_rows() > 0 {
4981                result.push(filtered);
4982            }
4983        } else {
4984            result.push(batch);
4985        }
4986    }
4987    Ok(result)
4988}
4989
4990/// Apply post-fixpoint operators (FOLD, HAVING, BEST BY, PRIORITY) to converged facts.
4991#[allow(
4992    clippy::too_many_arguments,
4993    reason = "context bundle would be over-engineering for one call site"
4994)]
4995pub(crate) async fn apply_post_fixpoint_chain(
4996    facts: Vec<RecordBatch>,
4997    rule: &FixpointRulePlan,
4998    task_ctx: &Arc<TaskContext>,
4999    strict_probability_domain: bool,
5000    probability_epsilon: f64,
5001    semiring_kind: SemiringKind,
5002    provenance_tracker: Option<Arc<ProvenanceStore>>,
5003    top_k_proofs_k: usize,
5004    registry: Option<Arc<DerivedScanRegistry>>,
5005) -> DFResult<Vec<RecordBatch>> {
5006    if !rule.has_fold && !rule.has_best_by && !rule.has_priority && rule.having.is_empty() {
5007        return Ok(facts);
5008    }
5009
5010    // Wrap facts in InMemoryExec.
5011    // Prefer the actual batch schema (from physical execution) over the
5012    // pre-computed yield_schema, which may have wrong inferred types
5013    // (e.g. Float64 for a string property).
5014    let schema = facts
5015        .iter()
5016        .find(|b| b.num_rows() > 0)
5017        .map(|b| b.schema())
5018        .unwrap_or_else(|| Arc::clone(&rule.yield_schema));
5019
5020    // Phase D D-C0: pre-compute body-row → IS-ref support map for
5021    // TopKProofs MNOR's DNF inclusion-exclusion math. Must be built
5022    // here because `facts` is moved into `InMemoryExec` on the next
5023    // line. The map is keyed by a full-column row hash — only
5024    // meaningful when no downstream plan node strips/adds columns
5025    // between this batch view and the FoldExec input. PRIORITY drops
5026    // the `__priority` column, which would change row hashes; until
5027    // we plumb the map past PRIORITY, skip map construction for
5028    // PRIORITY rules (the failing TCK test doesn't use PRIORITY).
5029    // Read the active K from `semiring_kind` rather than the separate
5030    // `top_k_proofs_k` parameter — the latter is not always threaded
5031    // from the LocyProgram config (the semiring's `k` is the source of
5032    // truth).
5033    let topk_k: Option<usize> = match semiring_kind {
5034        SemiringKind::TopKProofs { k } if k > 0 => Some(k as usize),
5035        _ => None,
5036    };
5037    let body_support_map: Option<Arc<HashMap<Vec<u8>, Vec<ProofTerm>>>> = if topk_k.is_some()
5038        && !rule.has_priority
5039        && let Some(registry) = registry.as_ref()
5040    {
5041        let mut map: HashMap<Vec<u8>, Vec<ProofTerm>> = HashMap::new();
5042        for batch in &facts {
5043            let all_indices: Vec<usize> = (0..batch.num_columns()).collect();
5044            for row_idx in 0..batch.num_rows() {
5045                let support = collect_is_ref_inputs_for_body_row(rule, batch, row_idx, registry);
5046                if support.is_empty() {
5047                    continue;
5048                }
5049                let hash = fact_hash_key(batch, &all_indices, row_idx);
5050                map.insert(hash, support);
5051            }
5052        }
5053        if map.is_empty() {
5054            None
5055        } else {
5056            Some(Arc::new(map))
5057        }
5058    } else {
5059        None
5060    };
5061
5062    let input: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(facts, schema.clone()));
5063
5064    // Reconcile key indices: rule's indices are yield-schema positions but
5065    // the actual batch may have different column ordering after schema
5066    // reconciliation during fixpoint iteration (same pattern as
5067    // FixpointState::reconcile_schema).
5068    let key_column_indices: Vec<usize> = rule
5069        .key_column_indices
5070        .iter()
5071        .filter_map(|&i| {
5072            let name = rule.yield_schema.field(i).name();
5073            schema.index_of(name).ok()
5074        })
5075        .collect();
5076
5077    // Apply PRIORITY first — keeps only rows with max __priority per KEY group,
5078    // then strips the __priority column from output.
5079    // Must run before FOLD so that the __priority column is still present.
5080    let current: Arc<dyn ExecutionPlan> = if rule.has_priority {
5081        let priority_schema = input.schema();
5082        let priority_idx = priority_schema.index_of("__priority").map_err(|_| {
5083            datafusion::common::DataFusionError::Internal(
5084                "PRIORITY rule missing __priority column".to_string(),
5085            )
5086        })?;
5087        Arc::new(PriorityExec::new(
5088            input,
5089            key_column_indices.clone(),
5090            priority_idx,
5091        ))
5092    } else {
5093        input
5094    };
5095
5096    // Apply FOLD
5097    let current: Arc<dyn ExecutionPlan> = if rule.has_fold && !rule.fold_bindings.is_empty() {
5098        Arc::new(FoldExec::new_with_topk(
5099            current,
5100            key_column_indices.clone(),
5101            rule.fold_bindings.clone(),
5102            strict_probability_domain,
5103            probability_epsilon,
5104            semiring_kind,
5105            provenance_tracker.clone(),
5106            topk_k.unwrap_or(top_k_proofs_k),
5107            body_support_map.clone(),
5108        ))
5109    } else {
5110        current
5111    };
5112
5113    // Apply HAVING (post-FOLD WHERE filter)
5114    let current: Arc<dyn ExecutionPlan> = if !rule.having.is_empty() {
5115        let batches = collect_all_partitions(&current, Arc::clone(task_ctx)).await?;
5116        let filtered = apply_having_filter(batches, &rule.having, &current.schema(), task_ctx)?;
5117        if filtered.is_empty() {
5118            return Ok(filtered);
5119        }
5120        Arc::new(InMemoryExec::new(filtered, Arc::clone(&current.schema())))
5121    } else {
5122        current
5123    };
5124
5125    // Apply BEST BY
5126    let current: Arc<dyn ExecutionPlan> = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
5127        Arc::new(BestByExec::new(
5128            current,
5129            key_column_indices.clone(),
5130            rule.best_by_criteria.clone(),
5131            rule.deterministic,
5132        ))
5133    } else {
5134        current
5135    };
5136
5137    collect_all_partitions(&current, Arc::clone(task_ctx)).await
5138}
5139
5140// ---------------------------------------------------------------------------
5141// FixpointExec — DataFusion ExecutionPlan
5142// ---------------------------------------------------------------------------
5143
/// DataFusion `ExecutionPlan` that drives semi-naive fixpoint iteration.
///
/// Has no physical children: clause bodies are re-planned from logical plans
/// on each iteration (same pattern as `RecursiveCTEExec` and `GraphApplyExec`).
/// `execute` clones this configuration into a future running
/// `run_fixpoint_loop` and streams the resulting batches.
pub struct FixpointExec {
    /// Rules of the recursive stratum, evaluated together by the fixpoint loop.
    rules: Vec<FixpointRulePlan>,
    /// Iteration cap handed to `run_fixpoint_loop`.
    max_iterations: usize,
    /// Time budget handed to `run_fixpoint_loop`; see `timeout_flag`.
    timeout: Duration,
    /// Graph execution context forwarded to the fixpoint loop.
    graph_ctx: Arc<GraphExecutionContext>,
    /// Session used by the fixpoint loop (presumably to re-plan clause bodies
    /// each iteration — TODO confirm in `run_fixpoint_loop`).
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    /// Storage handle forwarded to the fixpoint loop.
    storage: Arc<StorageManager>,
    /// Graph schema forwarded to the fixpoint loop.
    schema_info: Arc<UniSchema>,
    /// Query parameters forwarded to the fixpoint loop.
    params: HashMap<String, Value>,
    /// Registry forwarded to the fixpoint loop as `registry`.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    /// Schema of the batches this node emits; `properties` is derived from it.
    output_schema: SchemaRef,
    /// Plan properties computed from `output_schema` at construction time.
    properties: Arc<PlanProperties>,
    /// Metrics set; `execute` registers per-partition baseline metrics on it.
    metrics: ExecutionPlanMetricsSet,
    /// Size limit forwarded to the fixpoint loop (presumably a cap on derived
    /// fact bytes — confirm in `run_fixpoint_loop`).
    max_derived_bytes: usize,
    /// Optional provenance tracker populated during fixpoint iteration.
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    /// Shared slot written with per-rule iteration counts after convergence.
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    /// Probability-domain strictness flag forwarded to the fixpoint loop.
    strict_probability_domain: bool,
    /// Probability tolerance forwarded to the fixpoint loop.
    probability_epsilon: f64,
    /// Exact-probability flag forwarded to the fixpoint loop (presumably
    /// selects BDD-based computation; see `approximate_slot`).
    exact_probability: bool,
    /// BDD variable limit forwarded to the fixpoint loop.
    max_bdd_variables: usize,
    /// Shared slot for runtime warnings collected during fixpoint iteration.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    /// Shared slot for groups where BDD fell back to independence mode.
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    /// When > 0, retain at most this many proofs per fact (top-k provenance).
    top_k_proofs: usize,
    /// Shared flag: set to true on timeout to signal partial results.
    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
    /// Active probability semiring (rollout D-7).
    semiring_kind: SemiringKind,
    /// Phase B Slice 3 registry of neural classifiers, keyed by the
    /// model name from `CREATE MODEL`. Held by `Arc` so executor clones
    /// share the same underlying map.
    classifier_registry: Arc<ClassifierRegistry>,
    /// Phase B follow-up: optional per-evaluation memoization cache
    /// for classifier outputs keyed by `(model_name, feature_hash)`.
    /// `None` → no caching; `Some` → cache shared across fixpoint
    /// iterations (and optionally across the entire query / multiple
    /// queries, when the caller threads it via `LocyConfig`).
    classifier_cache: Option<Arc<ModelInvocationCache>>,
    /// Phase C B1-B3 follow-up: per-query side-channel store
    /// for (raw, calibrated, confidence_band) records. Read by
    /// EXPLAIN; not used by the fixpoint inner loop directly
    /// (LocyModelInvokeExec writes; this struct just carries
    /// the Arc to keep the type wiring consistent across the
    /// LocyProgramExec/FixpointExec boundary).
    ///
    /// NOTE(review): `execute` clones this field and forwards it to
    /// `run_fixpoint_loop`, so the field is read and the `dead_code` allow
    /// below may be stale — confirm and drop it if so.
    #[allow(
        dead_code,
        reason = "boundary plumbing; read by EXPLAIN via LocyModelInvokeExec"
    )]
    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
    /// Optional per-stratum profile collector. `Some` only on the `profile()`
    /// path; when set, each fixpoint iteration records per-rule timing, delta
    /// facts, and the clause-body operator tree into it. `None` → zero overhead.
    profile_collector: Option<Arc<LocyProfileCollector>>,
}
5206
5207impl fmt::Debug for FixpointExec {
5208    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5209        f.debug_struct("FixpointExec")
5210            .field("rules_count", &self.rules.len())
5211            .field("max_iterations", &self.max_iterations)
5212            .field("timeout", &self.timeout)
5213            .field("output_schema", &self.output_schema)
5214            .field("max_derived_bytes", &self.max_derived_bytes)
5215            .finish_non_exhaustive()
5216    }
5217}
5218
5219impl FixpointExec {
5220    /// Create a new `FixpointExec`.
5221    #[expect(
5222        clippy::too_many_arguments,
5223        reason = "FixpointExec configuration needs all context"
5224    )]
5225    #[deprecated(
5226        note = "use `new_with_semiring_classifiers_and_cache` (or the lighter \
5227                `new_with_semiring_and_classifiers` / `new_with_semiring`) — \
5228                this legacy ctor defaults the semiring to AddMultProb and \
5229                ships no classifier registry, which the Phase B+ runtime needs \
5230                explicitly. To be removed after C0 Stage 2."
5231    )]
5232    pub fn new(
5233        rules: Vec<FixpointRulePlan>,
5234        max_iterations: usize,
5235        timeout: Duration,
5236        graph_ctx: Arc<GraphExecutionContext>,
5237        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5238        storage: Arc<StorageManager>,
5239        schema_info: Arc<UniSchema>,
5240        params: HashMap<String, Value>,
5241        derived_scan_registry: Arc<DerivedScanRegistry>,
5242        output_schema: SchemaRef,
5243        max_derived_bytes: usize,
5244        derivation_tracker: Option<Arc<ProvenanceStore>>,
5245        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5246        strict_probability_domain: bool,
5247        probability_epsilon: f64,
5248        exact_probability: bool,
5249        max_bdd_variables: usize,
5250        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5251        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5252        top_k_proofs: usize,
5253        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5254    ) -> Self {
5255        Self::new_with_semiring_and_classifiers(
5256            rules,
5257            max_iterations,
5258            timeout,
5259            graph_ctx,
5260            session_ctx,
5261            storage,
5262            schema_info,
5263            params,
5264            derived_scan_registry,
5265            output_schema,
5266            max_derived_bytes,
5267            derivation_tracker,
5268            iteration_counts,
5269            strict_probability_domain,
5270            probability_epsilon,
5271            exact_probability,
5272            max_bdd_variables,
5273            warnings_slot,
5274            approximate_slot,
5275            top_k_proofs,
5276            timeout_flag,
5277            SemiringKind::AddMultProb,
5278            Arc::new(ClassifierRegistry::new()),
5279        )
5280    }
5281
5282    /// Variant accepting an explicit [`SemiringKind`]. Empty classifier
5283    /// registry; for the full variant call
5284    /// [`FixpointExec::new_with_semiring_and_classifiers`].
5285    #[expect(
5286        clippy::too_many_arguments,
5287        reason = "FixpointExec configuration needs all context"
5288    )]
5289    pub fn new_with_semiring(
5290        rules: Vec<FixpointRulePlan>,
5291        max_iterations: usize,
5292        timeout: Duration,
5293        graph_ctx: Arc<GraphExecutionContext>,
5294        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5295        storage: Arc<StorageManager>,
5296        schema_info: Arc<UniSchema>,
5297        params: HashMap<String, Value>,
5298        derived_scan_registry: Arc<DerivedScanRegistry>,
5299        output_schema: SchemaRef,
5300        max_derived_bytes: usize,
5301        derivation_tracker: Option<Arc<ProvenanceStore>>,
5302        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5303        strict_probability_domain: bool,
5304        probability_epsilon: f64,
5305        exact_probability: bool,
5306        max_bdd_variables: usize,
5307        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5308        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5309        top_k_proofs: usize,
5310        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5311        semiring_kind: SemiringKind,
5312    ) -> Self {
5313        Self::new_with_semiring_and_classifiers(
5314            rules,
5315            max_iterations,
5316            timeout,
5317            graph_ctx,
5318            session_ctx,
5319            storage,
5320            schema_info,
5321            params,
5322            derived_scan_registry,
5323            output_schema,
5324            max_derived_bytes,
5325            derivation_tracker,
5326            iteration_counts,
5327            strict_probability_domain,
5328            probability_epsilon,
5329            exact_probability,
5330            max_bdd_variables,
5331            warnings_slot,
5332            approximate_slot,
5333            top_k_proofs,
5334            timeout_flag,
5335            semiring_kind,
5336            Arc::new(ClassifierRegistry::new()),
5337        )
5338    }
5339
5340    /// Phase B Slice 3 entry: accepts both the semiring kind and the
5341    /// runtime classifier registry. The planner uses this when the
5342    /// program contains `CREATE MODEL` declarations.
5343    #[expect(
5344        clippy::too_many_arguments,
5345        reason = "FixpointExec configuration needs all context"
5346    )]
5347    pub fn new_with_semiring_and_classifiers(
5348        rules: Vec<FixpointRulePlan>,
5349        max_iterations: usize,
5350        timeout: Duration,
5351        graph_ctx: Arc<GraphExecutionContext>,
5352        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5353        storage: Arc<StorageManager>,
5354        schema_info: Arc<UniSchema>,
5355        params: HashMap<String, Value>,
5356        derived_scan_registry: Arc<DerivedScanRegistry>,
5357        output_schema: SchemaRef,
5358        max_derived_bytes: usize,
5359        derivation_tracker: Option<Arc<ProvenanceStore>>,
5360        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5361        strict_probability_domain: bool,
5362        probability_epsilon: f64,
5363        exact_probability: bool,
5364        max_bdd_variables: usize,
5365        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5366        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5367        top_k_proofs: usize,
5368        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5369        semiring_kind: SemiringKind,
5370        classifier_registry: Arc<ClassifierRegistry>,
5371    ) -> Self {
5372        Self::new_with_semiring_classifiers_and_cache(
5373            rules,
5374            max_iterations,
5375            timeout,
5376            graph_ctx,
5377            session_ctx,
5378            storage,
5379            schema_info,
5380            params,
5381            derived_scan_registry,
5382            output_schema,
5383            max_derived_bytes,
5384            derivation_tracker,
5385            iteration_counts,
5386            strict_probability_domain,
5387            probability_epsilon,
5388            exact_probability,
5389            max_bdd_variables,
5390            warnings_slot,
5391            approximate_slot,
5392            top_k_proofs,
5393            timeout_flag,
5394            semiring_kind,
5395            classifier_registry,
5396            None,
5397            None,
5398        )
5399    }
5400
5401    /// Phase B follow-up: full constructor accepting an optional
5402    /// memoization cache. Existing callers default to `None` (no cache);
5403    /// the impl_locy.rs entry passes the user's `config.classifier_cache`.
5404    #[expect(
5405        clippy::too_many_arguments,
5406        reason = "FixpointExec configuration needs all context"
5407    )]
5408    pub fn new_with_semiring_classifiers_and_cache(
5409        rules: Vec<FixpointRulePlan>,
5410        max_iterations: usize,
5411        timeout: Duration,
5412        graph_ctx: Arc<GraphExecutionContext>,
5413        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5414        storage: Arc<StorageManager>,
5415        schema_info: Arc<UniSchema>,
5416        params: HashMap<String, Value>,
5417        derived_scan_registry: Arc<DerivedScanRegistry>,
5418        output_schema: SchemaRef,
5419        max_derived_bytes: usize,
5420        derivation_tracker: Option<Arc<ProvenanceStore>>,
5421        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5422        strict_probability_domain: bool,
5423        probability_epsilon: f64,
5424        exact_probability: bool,
5425        max_bdd_variables: usize,
5426        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5427        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5428        top_k_proofs: usize,
5429        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5430        semiring_kind: SemiringKind,
5431        classifier_registry: Arc<ClassifierRegistry>,
5432        classifier_cache: Option<Arc<ModelInvocationCache>>,
5433        classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
5434    ) -> Self {
5435        let properties = compute_plan_properties(Arc::clone(&output_schema));
5436        Self {
5437            rules,
5438            max_iterations,
5439            timeout,
5440            graph_ctx,
5441            session_ctx,
5442            storage,
5443            schema_info,
5444            params,
5445            derived_scan_registry,
5446            output_schema,
5447            properties,
5448            metrics: ExecutionPlanMetricsSet::new(),
5449            max_derived_bytes,
5450            derivation_tracker,
5451            iteration_counts,
5452            strict_probability_domain,
5453            probability_epsilon,
5454            exact_probability,
5455            max_bdd_variables,
5456            warnings_slot,
5457            approximate_slot,
5458            top_k_proofs,
5459            timeout_flag,
5460            semiring_kind,
5461            classifier_registry,
5462            classifier_cache,
5463            classifier_provenance_store,
5464            profile_collector: None,
5465        }
5466    }
5467
5468    /// Returns the shared iteration counts slot for post-execution inspection.
5469    pub fn iteration_counts(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
5470        Arc::clone(&self.iteration_counts)
5471    }
5472
5473    /// Attach a profile collector so this stratum's fixpoint records per-rule,
5474    /// per-iteration timing, delta facts, and clause-body operator metrics.
5475    ///
5476    /// Mirrors `set_derivation_tracker`: call before wrapping the exec in an
5477    /// `Arc` and executing. Only the Locy `profile()` path sets this.
5478    pub fn set_profile_collector(&mut self, collector: Arc<LocyProfileCollector>) {
5479        self.profile_collector = Some(collector);
5480    }
5481}
5482
5483impl DisplayAs for FixpointExec {
5484    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5485        write!(
5486            f,
5487            "FixpointExec: rules=[{}], max_iter={}, timeout={:?}",
5488            self.rules
5489                .iter()
5490                .map(|r| r.name.as_str())
5491                .collect::<Vec<_>>()
5492                .join(", "),
5493            self.max_iterations,
5494            self.timeout,
5495        )
5496    }
5497}
5498
5499impl ExecutionPlan for FixpointExec {
5500    fn name(&self) -> &str {
5501        "FixpointExec"
5502    }
5503
5504    fn as_any(&self) -> &dyn Any {
5505        self
5506    }
5507
5508    fn schema(&self) -> SchemaRef {
5509        Arc::clone(&self.output_schema)
5510    }
5511
5512    fn properties(&self) -> &Arc<PlanProperties> {
5513        &self.properties
5514    }
5515
5516    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
5517        // No physical children — clause bodies are re-planned each iteration
5518        vec![]
5519    }
5520
5521    fn with_new_children(
5522        self: Arc<Self>,
5523        children: Vec<Arc<dyn ExecutionPlan>>,
5524    ) -> DFResult<Arc<dyn ExecutionPlan>> {
5525        if !children.is_empty() {
5526            return Err(datafusion::error::DataFusionError::Plan(
5527                "FixpointExec has no children".to_string(),
5528            ));
5529        }
5530        Ok(self)
5531    }
5532
5533    fn execute(
5534        &self,
5535        partition: usize,
5536        _context: Arc<TaskContext>,
5537    ) -> DFResult<SendableRecordBatchStream> {
5538        let metrics = BaselineMetrics::new(&self.metrics, partition);
5539
5540        // Clone all fields for the async closure
5541        let rules = self
5542            .rules
5543            .iter()
5544            .map(|r| {
5545                // We need to clone the FixpointRulePlan, but it contains LogicalPlan
5546                // which doesn't implement Clone traditionally. However, our LogicalPlan
5547                // does implement Clone since it's an enum.
5548                FixpointRulePlan {
5549                    name: r.name.clone(),
5550                    clauses: r
5551                        .clauses
5552                        .iter()
5553                        .map(|c| FixpointClausePlan {
5554                            body_logical: c.body_logical.clone(),
5555                            is_ref_bindings: c.is_ref_bindings.clone(),
5556                            priority: c.priority,
5557                            along_bindings: c.along_bindings.clone(),
5558                            model_invocations: c.model_invocations.clone(),
5559                        })
5560                        .collect(),
5561                    yield_schema: Arc::clone(&r.yield_schema),
5562                    key_column_indices: r.key_column_indices.clone(),
5563                    priority: r.priority,
5564                    has_fold: r.has_fold,
5565                    fold_bindings: r.fold_bindings.clone(),
5566                    having: r.having.clone(),
5567                    has_best_by: r.has_best_by,
5568                    best_by_criteria: r.best_by_criteria.clone(),
5569                    has_priority: r.has_priority,
5570                    deterministic: r.deterministic,
5571                    prob_column_name: r.prob_column_name.clone(),
5572                    non_linear: r.non_linear,
5573                }
5574            })
5575            .collect();
5576
5577        let max_iterations = self.max_iterations;
5578        let timeout = self.timeout;
5579        let graph_ctx = Arc::clone(&self.graph_ctx);
5580        let session_ctx = Arc::clone(&self.session_ctx);
5581        let storage = Arc::clone(&self.storage);
5582        let schema_info = Arc::clone(&self.schema_info);
5583        let params = self.params.clone();
5584        let registry = Arc::clone(&self.derived_scan_registry);
5585        let output_schema = Arc::clone(&self.output_schema);
5586        let max_derived_bytes = self.max_derived_bytes;
5587        let derivation_tracker = self.derivation_tracker.clone();
5588        let iteration_counts = Arc::clone(&self.iteration_counts);
5589        let strict_probability_domain = self.strict_probability_domain;
5590        let probability_epsilon = self.probability_epsilon;
5591        let exact_probability = self.exact_probability;
5592        let max_bdd_variables = self.max_bdd_variables;
5593        let warnings_slot = Arc::clone(&self.warnings_slot);
5594        let approximate_slot = Arc::clone(&self.approximate_slot);
5595        let top_k_proofs = self.top_k_proofs;
5596        let timeout_flag = Arc::clone(&self.timeout_flag);
5597        let semiring_kind = self.semiring_kind;
5598        let classifier_registry = Arc::clone(&self.classifier_registry);
5599        let classifier_cache = self.classifier_cache.as_ref().map(Arc::clone);
5600        let classifier_provenance_store = self.classifier_provenance_store.as_ref().map(Arc::clone);
5601        let profile_collector = self.profile_collector.as_ref().map(Arc::clone);
5602
5603        let fut = async move {
5604            run_fixpoint_loop(
5605                rules,
5606                max_iterations,
5607                timeout,
5608                graph_ctx,
5609                session_ctx,
5610                storage,
5611                schema_info,
5612                params,
5613                registry,
5614                output_schema,
5615                max_derived_bytes,
5616                derivation_tracker,
5617                iteration_counts,
5618                strict_probability_domain,
5619                probability_epsilon,
5620                exact_probability,
5621                max_bdd_variables,
5622                warnings_slot,
5623                approximate_slot,
5624                top_k_proofs,
5625                timeout_flag,
5626                semiring_kind,
5627                classifier_registry,
5628                classifier_cache,
5629                classifier_provenance_store,
5630                profile_collector,
5631            )
5632            .await
5633        };
5634
5635        Ok(Box::pin(FixpointStream {
5636            state: FixpointStreamState::Running(Box::pin(fut)),
5637            schema: Arc::clone(&self.output_schema),
5638            metrics,
5639        }))
5640    }
5641
5642    fn metrics(&self) -> Option<MetricsSet> {
5643        Some(self.metrics.clone_inner())
5644    }
5645}
5646
5647// ---------------------------------------------------------------------------
5648// FixpointStream — async state machine for streaming results
5649// ---------------------------------------------------------------------------
5650
/// Phases of a `FixpointStream`: run the fixpoint future to completion, then
/// hand out its batches one per poll.
enum FixpointStreamState {
    /// Fixpoint loop is running; the future resolves to all result batches
    /// at once (or an error).
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// Emitting accumulated result batches one at a time; the `usize` is the
    /// index of the next batch to yield.
    Emitting(Vec<RecordBatch>, usize),
    /// All batches emitted. Also entered after an error or an empty result.
    Done,
}
5659
/// Record-batch stream returned by `FixpointExec::execute`.
struct FixpointStream {
    /// Current phase: running the fixpoint future, emitting batches, or done.
    state: FixpointStreamState,
    /// Output schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Per-partition baseline metrics (elapsed compute time, output rows).
    metrics: BaselineMetrics,
}
5665
5666impl Stream for FixpointStream {
5667    type Item = DFResult<RecordBatch>;
5668
5669    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
5670        let this = self.get_mut();
5671        let metrics = this.metrics.clone();
5672        let _timer = metrics.elapsed_compute().timer();
5673        loop {
5674            match &mut this.state {
5675                FixpointStreamState::Running(fut) => match fut.as_mut().poll(cx) {
5676                    Poll::Ready(Ok(batches)) => {
5677                        if batches.is_empty() {
5678                            this.state = FixpointStreamState::Done;
5679                            return Poll::Ready(None);
5680                        }
5681                        this.state = FixpointStreamState::Emitting(batches, 0);
5682                        // Loop to emit first batch
5683                    }
5684                    Poll::Ready(Err(e)) => {
5685                        this.state = FixpointStreamState::Done;
5686                        return Poll::Ready(Some(Err(e)));
5687                    }
5688                    Poll::Pending => return Poll::Pending,
5689                },
5690                FixpointStreamState::Emitting(batches, idx) => {
5691                    if *idx >= batches.len() {
5692                        this.state = FixpointStreamState::Done;
5693                        return Poll::Ready(None);
5694                    }
5695                    let batch = batches[*idx].clone();
5696                    *idx += 1;
5697                    this.metrics.record_output(batch.num_rows());
5698                    return Poll::Ready(Some(Ok(batch)));
5699                }
5700                FixpointStreamState::Done => return Poll::Ready(None),
5701            }
5702        }
5703    }
5704}
5705
5706impl RecordBatchStream for FixpointStream {
5707    fn schema(&self) -> SchemaRef {
5708        Arc::clone(&self.schema)
5709    }
5710}
5711
5712// ---------------------------------------------------------------------------
5713// Unit tests
5714// ---------------------------------------------------------------------------
5715
5716#[cfg(test)]
5717mod tests {
5718    use super::*;
5719    use arrow_array::{Float64Array, Int64Array, StringArray};
5720    use arrow_schema::{DataType, Field, Schema};
5721
5722    fn test_schema() -> SchemaRef {
5723        Arc::new(Schema::new(vec![
5724            Field::new("name", DataType::Utf8, true),
5725            Field::new("value", DataType::Int64, true),
5726        ]))
5727    }
5728
5729    fn make_batch(names: &[&str], values: &[i64]) -> RecordBatch {
5730        RecordBatch::try_new(
5731            test_schema(),
5732            vec![
5733                Arc::new(StringArray::from(
5734                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
5735                )),
5736                Arc::new(Int64Array::from(values.to_vec())),
5737            ],
5738        )
5739        .unwrap()
5740    }
5741
5742    // --- FixpointState dedup tests ---
5743
5744    #[tokio::test]
5745    async fn test_fixpoint_state_empty_facts_adds_all() {
5746        let schema = test_schema();
5747        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5748
5749        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
5750        let changed = state.merge_delta(vec![batch], None).await.unwrap();
5751
5752        assert!(changed);
5753        assert_eq!(state.all_facts().len(), 1);
5754        assert_eq!(state.all_facts()[0].num_rows(), 3);
5755        assert_eq!(state.all_delta().len(), 1);
5756        assert_eq!(state.all_delta()[0].num_rows(), 3);
5757    }
5758
5759    #[tokio::test]
5760    async fn test_fixpoint_state_exact_duplicates_excluded() {
5761        let schema = test_schema();
5762        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5763
5764        let batch1 = make_batch(&["a", "b"], &[1, 2]);
5765        state.merge_delta(vec![batch1], None).await.unwrap();
5766
5767        // Same rows again
5768        let batch2 = make_batch(&["a", "b"], &[1, 2]);
5769        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
5770        assert!(!changed);
5771        assert!(
5772            state.all_delta().is_empty() || state.all_delta().iter().all(|b| b.num_rows() == 0)
5773        );
5774    }
5775
5776    #[tokio::test]
5777    async fn test_fixpoint_state_partial_overlap() {
5778        let schema = test_schema();
5779        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5780
5781        let batch1 = make_batch(&["a", "b"], &[1, 2]);
5782        state.merge_delta(vec![batch1], None).await.unwrap();
5783
5784        // "a":1 is duplicate, "c":3 is new
5785        let batch2 = make_batch(&["a", "c"], &[1, 3]);
5786        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
5787        assert!(changed);
5788
5789        // Delta should have only "c":3
5790        let delta_rows: usize = state.all_delta().iter().map(|b| b.num_rows()).sum();
5791        assert_eq!(delta_rows, 1);
5792
5793        // Total facts: a:1, b:2, c:3
5794        let total_rows: usize = state.all_facts().iter().map(|b| b.num_rows()).sum();
5795        assert_eq!(total_rows, 3);
5796    }
5797
5798    #[tokio::test]
5799    async fn test_fixpoint_state_convergence() {
5800        let schema = test_schema();
5801        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5802
5803        let batch = make_batch(&["a"], &[1]);
5804        state.merge_delta(vec![batch], None).await.unwrap();
5805
5806        // Empty candidates → converged
5807        let changed = state.merge_delta(vec![], None).await.unwrap();
5808        assert!(!changed);
5809        assert!(state.is_converged());
5810    }
5811
5812    // --- RowDedupState tests ---
5813
5814    #[test]
5815    fn test_row_dedup_persistent_across_calls() {
5816        // RowDedupState should remember rows from the first call so the second
5817        // call does not re-accept them (O(M) per iteration, no facts re-scan).
5818        let schema = test_schema();
5819        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5820
5821        let batch1 = make_batch(&["a", "b"], &[1, 2]);
5822        let delta1 = rd.compute_delta(&[batch1], &schema).unwrap();
5823        // First call: both rows are new.
5824        let rows1: usize = delta1.iter().map(|b| b.num_rows()).sum();
5825        assert_eq!(rows1, 2);
5826
5827        // Second call with same rows: seen set already has them → empty delta.
5828        let batch2 = make_batch(&["a", "b"], &[1, 2]);
5829        let delta2 = rd.compute_delta(&[batch2], &schema).unwrap();
5830        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
5831        assert_eq!(rows2, 0);
5832
5833        // Third call with one old + one new: only the new row is returned.
5834        let batch3 = make_batch(&["a", "c"], &[1, 3]);
5835        let delta3 = rd.compute_delta(&[batch3], &schema).unwrap();
5836        let rows3: usize = delta3.iter().map(|b| b.num_rows()).sum();
5837        assert_eq!(rows3, 1);
5838    }
5839
5840    #[test]
5841    fn test_row_dedup_null_handling() {
5842        use arrow_array::StringArray;
5843        use arrow_schema::{DataType, Field, Schema};
5844
5845        let schema: SchemaRef = Arc::new(Schema::new(vec![
5846            Field::new("a", DataType::Utf8, true),
5847            Field::new("b", DataType::Int64, true),
5848        ]));
5849        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5850
5851        // Two rows: (NULL, 1) and (NULL, 1) — same NULLs → duplicate.
5852        let batch_nulls = RecordBatch::try_new(
5853            Arc::clone(&schema),
5854            vec![
5855                Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])),
5856                Arc::new(arrow_array::Int64Array::from(vec![1i64, 1i64])),
5857            ],
5858        )
5859        .unwrap();
5860        let delta = rd.compute_delta(&[batch_nulls], &schema).unwrap();
5861        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
5862        assert_eq!(rows, 1, "two identical NULL rows should be deduped to one");
5863
5864        // (NULL, 2) — NULL in same col but different non-null col → distinct.
5865        let batch_diff = RecordBatch::try_new(
5866            Arc::clone(&schema),
5867            vec![
5868                Arc::new(StringArray::from(vec![None::<&str>])),
5869                Arc::new(arrow_array::Int64Array::from(vec![2i64])),
5870            ],
5871        )
5872        .unwrap();
5873        let delta2 = rd.compute_delta(&[batch_diff], &schema).unwrap();
5874        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
5875        assert_eq!(rows2, 1, "(NULL, 2) is distinct from (NULL, 1)");
5876    }
5877
5878    #[test]
5879    fn test_row_dedup_within_candidate_dedup() {
5880        // Duplicate rows within a single candidate batch should be collapsed to one.
5881        let schema = test_schema();
5882        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5883
5884        // Batch with three rows: a:1, a:1, b:2 — "a:1" appears twice.
5885        let batch = make_batch(&["a", "a", "b"], &[1, 1, 2]);
5886        let delta = rd.compute_delta(&[batch], &schema).unwrap();
5887        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
5888        assert_eq!(rows, 2, "within-batch dup should be collapsed: a:1, b:2");
5889    }
5890
5891    // --- Float rounding tests ---
5892
5893    #[test]
5894    fn test_round_float_columns_near_duplicates() {
5895        let schema = Arc::new(Schema::new(vec![
5896            Field::new("name", DataType::Utf8, true),
5897            Field::new("dist", DataType::Float64, true),
5898        ]));
5899        let batch = RecordBatch::try_new(
5900            Arc::clone(&schema),
5901            vec![
5902                Arc::new(StringArray::from(vec![Some("a"), Some("a")])),
5903                Arc::new(Float64Array::from(vec![1.0000000000001, 1.0000000000002])),
5904            ],
5905        )
5906        .unwrap();
5907
5908        let rounded = round_float_columns(&[batch]);
5909        assert_eq!(rounded.len(), 1);
5910        let col = rounded[0]
5911            .column(1)
5912            .as_any()
5913            .downcast_ref::<Float64Array>()
5914            .unwrap();
5915        // Both should round to same value
5916        assert_eq!(col.value(0), col.value(1));
5917    }
5918
5919    // --- DerivedScanRegistry tests ---
5920
5921    #[test]
5922    fn test_registry_write_read_round_trip() {
5923        let schema = test_schema();
5924        let data = Arc::new(RwLock::new(Vec::new()));
5925        let mut reg = DerivedScanRegistry::new();
5926        reg.add(DerivedScanEntry {
5927            scan_index: 0,
5928            rule_name: "reachable".into(),
5929            is_self_ref: true,
5930            data: Arc::clone(&data),
5931            schema: Arc::clone(&schema),
5932        });
5933
5934        let batch = make_batch(&["x"], &[42]);
5935        reg.write_data(0, vec![batch.clone()]);
5936
5937        let entry = reg.get(0).unwrap();
5938        let guard = entry.data.read();
5939        assert_eq!(guard.len(), 1);
5940        assert_eq!(guard[0].num_rows(), 1);
5941    }
5942
5943    #[test]
5944    fn test_registry_entries_for_rule() {
5945        let schema = test_schema();
5946        let mut reg = DerivedScanRegistry::new();
5947        reg.add(DerivedScanEntry {
5948            scan_index: 0,
5949            rule_name: "r1".into(),
5950            is_self_ref: true,
5951            data: Arc::new(RwLock::new(Vec::new())),
5952            schema: Arc::clone(&schema),
5953        });
5954        reg.add(DerivedScanEntry {
5955            scan_index: 1,
5956            rule_name: "r2".into(),
5957            is_self_ref: false,
5958            data: Arc::new(RwLock::new(Vec::new())),
5959            schema: Arc::clone(&schema),
5960        });
5961        reg.add(DerivedScanEntry {
5962            scan_index: 2,
5963            rule_name: "r1".into(),
5964            is_self_ref: false,
5965            data: Arc::new(RwLock::new(Vec::new())),
5966            schema: Arc::clone(&schema),
5967        });
5968
5969        assert_eq!(reg.entries_for_rule("r1").len(), 2);
5970        assert_eq!(reg.entries_for_rule("r2").len(), 1);
5971        assert_eq!(reg.entries_for_rule("r3").len(), 0);
5972    }
5973
5974    // --- MonotonicAggState tests ---
5975
5976    #[test]
5977    fn test_monotonic_agg_update_and_stability() {
5978        let bindings = vec![MonotonicFoldBinding {
5979            fold_name: "total".into(),
5980            aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::SumAgg),
5981            input_col_index: 1,
5982            input_col_name: None,
5983        }];
5984        let mut agg = MonotonicAggState::new(bindings);
5985
5986        // First update
5987        let batch = make_batch(&["a"], &[10]);
5988        agg.snapshot();
5989        let changed = agg
5990            .update(&[0], &[batch], false, SemiringKind::AddMultProb)
5991            .unwrap();
5992        assert!(changed);
5993        assert!(!agg.is_stable()); // changed since snapshot
5994
5995        // Snapshot and check stability with no new data
5996        agg.snapshot();
5997        let changed = agg
5998            .update(&[0], &[], false, SemiringKind::AddMultProb)
5999            .unwrap();
6000        assert!(!changed);
6001        assert!(agg.is_stable());
6002    }
6003
6004    // --- Memory limit test ---
6005
6006    #[tokio::test]
6007    async fn test_memory_limit_exceeded() {
6008        let schema = test_schema();
6009        // Set a tiny limit
6010        let mut state = FixpointState::new("test".into(), schema, vec![0], 1, None, false);
6011
6012        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
6013        let result = state.merge_delta(vec![batch], None).await;
6014        assert!(result.is_err());
6015        let err = result.unwrap_err().to_string();
6016        assert!(err.contains("memory limit"), "Error was: {}", err);
6017    }
6018
6019    // --- FixpointStream lifecycle test ---
6020
6021    #[tokio::test]
6022    async fn test_fixpoint_stream_emitting() {
6023        use futures::StreamExt;
6024
6025        let schema = test_schema();
6026        let batch1 = make_batch(&["a"], &[1]);
6027        let batch2 = make_batch(&["b"], &[2]);
6028
6029        let metrics = ExecutionPlanMetricsSet::new();
6030        let baseline = BaselineMetrics::new(&metrics, 0);
6031
6032        let mut stream = FixpointStream {
6033            state: FixpointStreamState::Emitting(vec![batch1, batch2], 0),
6034            schema,
6035            metrics: baseline,
6036        };
6037
6038        let stream = Pin::new(&mut stream);
6039        let batches: Vec<RecordBatch> = stream.filter_map(|r| async { r.ok() }).collect().await;
6040
6041        assert_eq!(batches.len(), 2);
6042        assert_eq!(batches[0].num_rows(), 1);
6043        assert_eq!(batches[1].num_rows(), 1);
6044    }
6045
6046    // ── MonotonicAggState MNOR/MPROD tests ──────────────────────────────
6047
6048    fn make_f64_batch(names: &[&str], values: &[f64]) -> RecordBatch {
6049        let schema = Arc::new(Schema::new(vec![
6050            Field::new("name", DataType::Utf8, true),
6051            Field::new("value", DataType::Float64, true),
6052        ]));
6053        RecordBatch::try_new(
6054            schema,
6055            vec![
6056                Arc::new(StringArray::from(
6057                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
6058                )),
6059                Arc::new(Float64Array::from(values.to_vec())),
6060            ],
6061        )
6062        .unwrap()
6063    }
6064
6065    fn make_nor_binding() -> Vec<MonotonicFoldBinding> {
6066        vec![MonotonicFoldBinding {
6067            fold_name: "prob".into(),
6068            aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::MnorAgg),
6069            input_col_index: 1,
6070            input_col_name: None,
6071        }]
6072    }
6073
6074    fn make_prod_binding() -> Vec<MonotonicFoldBinding> {
6075        vec![MonotonicFoldBinding {
6076            fold_name: "prob".into(),
6077            aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::MprodAgg),
6078            input_col_index: 1,
6079            input_col_name: None,
6080        }]
6081    }
6082
6083    fn acc_key(name: &str) -> (Vec<ScalarKey>, String) {
6084        (vec![ScalarKey::Utf8(name.to_string())], "prob".to_string())
6085    }
6086
6087    #[test]
6088    fn test_monotonic_nor_first_update() {
6089        let mut agg = MonotonicAggState::new(make_nor_binding());
6090        let batch = make_f64_batch(&["a"], &[0.3]);
6091        let changed = agg
6092            .update(&[0], &[batch], false, SemiringKind::AddMultProb)
6093            .unwrap();
6094        assert!(changed);
6095        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6096        assert!((val - 0.3).abs() < 1e-10, "expected 0.3, got {}", val);
6097    }
6098
6099    #[test]
6100    fn test_monotonic_nor_two_updates() {
6101        // Incremental NOR: acc = 1-(1-0.3)(1-0.5) = 0.65
6102        let mut agg = MonotonicAggState::new(make_nor_binding());
6103        let batch1 = make_f64_batch(&["a"], &[0.3]);
6104        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6105            .unwrap();
6106        let batch2 = make_f64_batch(&["a"], &[0.5]);
6107        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6108            .unwrap();
6109        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6110        assert!((val - 0.65).abs() < 1e-10, "expected 0.65, got {}", val);
6111    }
6112
6113    #[test]
6114    fn test_monotonic_prod_first_update() {
6115        let mut agg = MonotonicAggState::new(make_prod_binding());
6116        let batch = make_f64_batch(&["a"], &[0.6]);
6117        let changed = agg
6118            .update(&[0], &[batch], false, SemiringKind::AddMultProb)
6119            .unwrap();
6120        assert!(changed);
6121        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6122        assert!((val - 0.6).abs() < 1e-10, "expected 0.6, got {}", val);
6123    }
6124
6125    #[test]
6126    fn test_monotonic_prod_two_updates() {
6127        // Incremental PROD: acc = 0.6 * 0.8 = 0.48
6128        let mut agg = MonotonicAggState::new(make_prod_binding());
6129        let batch1 = make_f64_batch(&["a"], &[0.6]);
6130        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6131            .unwrap();
6132        let batch2 = make_f64_batch(&["a"], &[0.8]);
6133        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6134            .unwrap();
6135        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6136        assert!((val - 0.48).abs() < 1e-10, "expected 0.48, got {}", val);
6137    }
6138
6139    #[test]
6140    fn test_monotonic_nor_stability() {
6141        let mut agg = MonotonicAggState::new(make_nor_binding());
6142        let batch = make_f64_batch(&["a"], &[0.3]);
6143        agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6144            .unwrap();
6145        agg.snapshot();
6146        let changed = agg
6147            .update(&[0], &[], false, SemiringKind::AddMultProb)
6148            .unwrap();
6149        assert!(!changed);
6150        assert!(agg.is_stable());
6151    }
6152
6153    #[test]
6154    fn test_monotonic_prod_stability() {
6155        let mut agg = MonotonicAggState::new(make_prod_binding());
6156        let batch = make_f64_batch(&["a"], &[0.6]);
6157        agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6158            .unwrap();
6159        agg.snapshot();
6160        let changed = agg
6161            .update(&[0], &[], false, SemiringKind::AddMultProb)
6162            .unwrap();
6163        assert!(!changed);
6164        assert!(agg.is_stable());
6165    }
6166
6167    #[test]
6168    fn test_monotonic_nor_multi_group() {
6169        // (a,0.3),(b,0.5) then (a,0.5),(b,0.2) → a=0.65, b=0.6
6170        let mut agg = MonotonicAggState::new(make_nor_binding());
6171        let batch1 = make_f64_batch(&["a", "b"], &[0.3, 0.5]);
6172        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6173            .unwrap();
6174        let batch2 = make_f64_batch(&["a", "b"], &[0.5, 0.2]);
6175        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6176            .unwrap();
6177
6178        let val_a = agg.get_accumulator(&acc_key("a")).unwrap();
6179        let val_b = agg.get_accumulator(&acc_key("b")).unwrap();
6180        assert!(
6181            (val_a - 0.65).abs() < 1e-10,
6182            "expected a=0.65, got {}",
6183            val_a
6184        );
6185        assert!((val_b - 0.6).abs() < 1e-10, "expected b=0.6, got {}", val_b);
6186    }
6187
6188    #[test]
6189    fn test_monotonic_prod_zero_absorbing() {
6190        // Zero absorbs: once 0.0, all further updates stay 0.0
6191        let mut agg = MonotonicAggState::new(make_prod_binding());
6192        let batch1 = make_f64_batch(&["a"], &[0.5]);
6193        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6194            .unwrap();
6195        let batch2 = make_f64_batch(&["a"], &[0.0]);
6196        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6197            .unwrap();
6198
6199        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6200        assert!((val - 0.0).abs() < 1e-10, "expected 0.0, got {}", val);
6201
6202        // Further updates don't change the absorbing zero
6203        agg.snapshot();
6204        let batch3 = make_f64_batch(&["a"], &[0.5]);
6205        let changed = agg
6206            .update(&[0], &[batch3], false, SemiringKind::AddMultProb)
6207            .unwrap();
6208        assert!(!changed);
6209        assert!(agg.is_stable());
6210    }
6211
6212    #[test]
6213    fn test_monotonic_nor_clamping() {
6214        // 1.5 clamped to 1.0: acc = 1-(1-0)(1-1) = 1.0
6215        let mut agg = MonotonicAggState::new(make_nor_binding());
6216        let batch = make_f64_batch(&["a"], &[1.5]);
6217        agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6218            .unwrap();
6219        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6220        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
6221    }
6222
6223    #[test]
6224    fn test_monotonic_nor_absorbing() {
6225        // p=1.0 absorbs: 0.3 then 1.0 → 1.0
6226        let mut agg = MonotonicAggState::new(make_nor_binding());
6227        let batch1 = make_f64_batch(&["a"], &[0.3]);
6228        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6229            .unwrap();
6230        let batch2 = make_f64_batch(&["a"], &[1.0]);
6231        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6232            .unwrap();
6233        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6234        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
6235    }
6236
6237    // ── MonotonicAggState strict mode tests (Phase 5) ─────────────────────
6238
6239    #[test]
6240    fn test_monotonic_agg_strict_nor_rejects() {
6241        let mut agg = MonotonicAggState::new(make_nor_binding());
6242        let batch = make_f64_batch(&["a"], &[1.5]);
6243        let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6244        assert!(result.is_err());
6245        let err = result.unwrap_err().to_string();
6246        assert!(
6247            err.contains("strict_probability_domain"),
6248            "Expected strict error, got: {}",
6249            err
6250        );
6251    }
6252
6253    #[test]
6254    fn test_monotonic_agg_strict_prod_rejects() {
6255        let mut agg = MonotonicAggState::new(make_prod_binding());
6256        let batch = make_f64_batch(&["a"], &[2.0]);
6257        let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6258        assert!(result.is_err());
6259        let err = result.unwrap_err().to_string();
6260        assert!(
6261            err.contains("strict_probability_domain"),
6262            "Expected strict error, got: {}",
6263            err
6264        );
6265    }
6266
6267    #[test]
6268    fn test_monotonic_agg_strict_accepts_valid() {
6269        let mut agg = MonotonicAggState::new(make_nor_binding());
6270        let batch = make_f64_batch(&["a"], &[0.5]);
6271        let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6272        assert!(result.is_ok());
6273        let val = agg.get_accumulator(&acc_key("a")).unwrap();
6274        assert!((val - 0.5).abs() < 1e-10, "expected 0.5, got {}", val);
6275    }
6276
6277    // ── Complement function unit tests (Phase 4) ──────────────────────────
6278
6279    fn make_vid_prob_batch(vids: &[u64], probs: &[f64]) -> RecordBatch {
6280        use arrow_array::UInt64Array;
6281        let schema = Arc::new(Schema::new(vec![
6282            Field::new("vid", DataType::UInt64, true),
6283            Field::new("prob", DataType::Float64, true),
6284        ]));
6285        RecordBatch::try_new(
6286            schema,
6287            vec![
6288                Arc::new(UInt64Array::from(vids.to_vec())),
6289                Arc::new(Float64Array::from(probs.to_vec())),
6290            ],
6291        )
6292        .unwrap()
6293    }
6294
6295    #[test]
6296    fn test_prob_complement_basic() {
6297        // neg has VID=1 with prob=0.7 → complement=0.3; VID=2 absent → complement=1.0
6298        let body = make_vid_prob_batch(&[1, 2], &[0.9, 0.8]);
6299        let neg = make_vid_prob_batch(&[1], &[0.7]);
6300        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6301        let result = apply_prob_complement_composite(
6302            vec![body],
6303            &[neg],
6304            &join_cols,
6305            "prob",
6306            "__complement_0",
6307        )
6308        .unwrap();
6309        assert_eq!(result.len(), 1);
6310        let batch = &result[0];
6311        let complement = batch
6312            .column_by_name("__complement_0")
6313            .unwrap()
6314            .as_any()
6315            .downcast_ref::<Float64Array>()
6316            .unwrap();
6317        // VID=1: complement = 1 - 0.7 = 0.3
6318        assert!(
6319            (complement.value(0) - 0.3).abs() < 1e-10,
6320            "expected 0.3, got {}",
6321            complement.value(0)
6322        );
6323        // VID=2: absent from neg → complement = 1.0
6324        assert!(
6325            (complement.value(1) - 1.0).abs() < 1e-10,
6326            "expected 1.0, got {}",
6327            complement.value(1)
6328        );
6329    }
6330
6331    #[test]
6332    fn test_prob_complement_noisy_or_duplicates() {
6333        // neg has VID=1 twice with prob=0.3 and prob=0.5
6334        // Combined via noisy-OR: 1-(1-0.3)(1-0.5) = 0.65
6335        // Complement = 1 - 0.65 = 0.35
6336        let body = make_vid_prob_batch(&[1], &[0.9]);
6337        let neg = make_vid_prob_batch(&[1, 1], &[0.3, 0.5]);
6338        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6339        let result = apply_prob_complement_composite(
6340            vec![body],
6341            &[neg],
6342            &join_cols,
6343            "prob",
6344            "__complement_0",
6345        )
6346        .unwrap();
6347        let batch = &result[0];
6348        let complement = batch
6349            .column_by_name("__complement_0")
6350            .unwrap()
6351            .as_any()
6352            .downcast_ref::<Float64Array>()
6353            .unwrap();
6354        assert!(
6355            (complement.value(0) - 0.35).abs() < 1e-10,
6356            "expected 0.35, got {}",
6357            complement.value(0)
6358        );
6359    }
6360
6361    #[test]
6362    fn test_prob_complement_empty_neg() {
6363        // Empty neg_facts → body passes through with complement=1.0
6364        let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
6365        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6366        let result =
6367            apply_prob_complement_composite(vec![body], &[], &join_cols, "prob", "__complement_0")
6368                .unwrap();
6369        let batch = &result[0];
6370        let complement = batch
6371            .column_by_name("__complement_0")
6372            .unwrap()
6373            .as_any()
6374            .downcast_ref::<Float64Array>()
6375            .unwrap();
6376        for i in 0..2 {
6377            assert!(
6378                (complement.value(i) - 1.0).abs() < 1e-10,
6379                "row {}: expected 1.0, got {}",
6380                i,
6381                complement.value(i)
6382            );
6383        }
6384    }
6385
6386    #[test]
6387    fn test_anti_join_basic() {
6388        // body [1,2,3], neg [2] → result [1,3]
6389        use arrow_array::UInt64Array;
6390        let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
6391        let neg = make_vid_prob_batch(&[2], &[0.0]);
6392        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6393        let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
6394        assert_eq!(result.len(), 1);
6395        let batch = &result[0];
6396        assert_eq!(batch.num_rows(), 2);
6397        let vids = batch
6398            .column_by_name("vid")
6399            .unwrap()
6400            .as_any()
6401            .downcast_ref::<UInt64Array>()
6402            .unwrap();
6403        assert_eq!(vids.value(0), 1);
6404        assert_eq!(vids.value(1), 3);
6405    }
6406
6407    #[test]
6408    fn test_anti_join_empty_neg() {
6409        // Empty neg → all rows kept
6410        let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
6411        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6412        let result = apply_anti_join_composite(vec![body], &[], &join_cols).unwrap();
6413        assert_eq!(result.len(), 1);
6414        assert_eq!(result[0].num_rows(), 3);
6415    }
6416
6417    #[test]
6418    fn test_anti_join_all_excluded() {
6419        // neg covers all body rows → empty result
6420        let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
6421        let neg = make_vid_prob_batch(&[1, 2], &[0.0, 0.0]);
6422        let join_cols = vec![("vid".to_string(), "vid".to_string())];
6423        let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
6424        let total: usize = result.iter().map(|b| b.num_rows()).sum();
6425        assert_eq!(total, 0);
6426    }
6427
6428    #[test]
6429    fn test_multiply_prob_single_complement() {
6430        // prob=0.8, complement=0.5 → output prob=0.4; complement col removed
6431        let body = make_vid_prob_batch(&[1], &[0.8]);
6432        // Add a complement column
6433        let complement_arr = Float64Array::from(vec![0.5]);
6434        let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
6435        cols.push(Arc::new(complement_arr));
6436        let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
6437        fields.push(Arc::new(Field::new(
6438            "__complement_0",
6439            DataType::Float64,
6440            true,
6441        )));
6442        let schema = Arc::new(Schema::new(fields));
6443        let batch = RecordBatch::try_new(schema, cols).unwrap();
6444
6445        let result =
6446            multiply_prob_factors(vec![batch], Some("prob"), &["__complement_0".to_string()])
6447                .unwrap();
6448        assert_eq!(result.len(), 1);
6449        let out = &result[0];
6450        // Complement column should be removed
6451        assert!(out.column_by_name("__complement_0").is_none());
6452        let prob = out
6453            .column_by_name("prob")
6454            .unwrap()
6455            .as_any()
6456            .downcast_ref::<Float64Array>()
6457            .unwrap();
6458        assert!(
6459            (prob.value(0) - 0.4).abs() < 1e-10,
6460            "expected 0.4, got {}",
6461            prob.value(0)
6462        );
6463    }
6464
6465    #[test]
6466    fn test_multiply_prob_multiple_complements() {
6467        // prob=0.8, c1=0.5, c2=0.6 → 0.8×0.5×0.6=0.24
6468        let body = make_vid_prob_batch(&[1], &[0.8]);
6469        let c1 = Float64Array::from(vec![0.5]);
6470        let c2 = Float64Array::from(vec![0.6]);
6471        let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
6472        cols.push(Arc::new(c1));
6473        cols.push(Arc::new(c2));
6474        let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
6475        fields.push(Arc::new(Field::new("__c1", DataType::Float64, true)));
6476        fields.push(Arc::new(Field::new("__c2", DataType::Float64, true)));
6477        let schema = Arc::new(Schema::new(fields));
6478        let batch = RecordBatch::try_new(schema, cols).unwrap();
6479
6480        let result = multiply_prob_factors(
6481            vec![batch],
6482            Some("prob"),
6483            &["__c1".to_string(), "__c2".to_string()],
6484        )
6485        .unwrap();
6486        let out = &result[0];
6487        assert!(out.column_by_name("__c1").is_none());
6488        assert!(out.column_by_name("__c2").is_none());
6489        let prob = out
6490            .column_by_name("prob")
6491            .unwrap()
6492            .as_any()
6493            .downcast_ref::<Float64Array>()
6494            .unwrap();
6495        assert!(
6496            (prob.value(0) - 0.24).abs() < 1e-10,
6497            "expected 0.24, got {}",
6498            prob.value(0)
6499        );
6500    }
6501
6502    #[test]
6503    fn test_multiply_prob_no_prob_column() {
6504        // No prob column → combined complements become the output
6505        use arrow_array::UInt64Array;
6506        let schema = Arc::new(Schema::new(vec![
6507            Field::new("vid", DataType::UInt64, true),
6508            Field::new("__c1", DataType::Float64, true),
6509        ]));
6510        let batch = RecordBatch::try_new(
6511            schema,
6512            vec![
6513                Arc::new(UInt64Array::from(vec![1u64])),
6514                Arc::new(Float64Array::from(vec![0.7])),
6515            ],
6516        )
6517        .unwrap();
6518
6519        let result = multiply_prob_factors(vec![batch], None, &["__c1".to_string()]).unwrap();
6520        let out = &result[0];
6521        // __c1 should be removed since it's a complement column
6522        assert!(out.column_by_name("__c1").is_none());
6523        // Only vid column remains
6524        assert_eq!(out.num_columns(), 1);
6525    }
6526}